Sequential backprop impl sketch #72

Open · vadimkantorov opened this issue Nov 10, 2021 · 1 comment
vadimkantorov commented Nov 10, 2021

Should something like the code below work for wrapping ResNet's last layer (the neck)? (https://gist.github.com/vadimkantorov/67fe785ed0bf31727af29a3584b87be1)

import torch
import torch.nn as nn

class SequentialBackprop(nn.Module):
    """Computes the wrapped module's forward without building a graph,
    then recomputes it mini-batch by mini-batch during backward,
    bounding peak activation memory for the wrapped module."""

    def __init__(self, module, batch_size = 1):
        super().__init__()
        self.module = module
        self.batch_size = batch_size

    def forward(self, x):
        # No graph is kept for the full-batch forward; it is rebuilt
        # one mini-batch at a time in Function.backward below.
        with torch.no_grad():
            y = self.module(x)
        return self.Function.apply(x, y, self.batch_size, self.module)

    class Function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y, batch_size, module):
            ctx.save_for_backward(x)
            ctx.batch_size = batch_size
            ctx.module = module
            return y

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            grads = []
            # Re-run the forward per mini-batch and backpropagate the matching
            # slice of grad_output; parameter gradients accumulate in
            # module.parameters() as usual.
            for x_mini, g_mini in zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size)):
                with torch.enable_grad():
                    x_mini = x_mini.detach().requires_grad_()
                    y_mini = ctx.module(x_mini)
                torch.autograd.backward(y_mini, g_mini)
                grads.append(x_mini.grad)  # x_mini is a leaf, so .grad is populated
            return torch.cat(grads), None, None, None

if __name__ == '__main__':
    backbone = nn.Linear(3, 6)
    neck = nn.Linear(6, 12)
    head = nn.Linear(12, 1)

    model = nn.Sequential(backbone, SequentialBackprop(neck, batch_size = 16), head)

    print('before', neck.weight.grad)

    x = torch.rand(512, 3)  # 512 rows split into 32 mini-batches of 16 in the neck's backward
    model(x).sum().backward()
    print('after', neck.weight.grad)
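
For reference, here is a quick sanity check (a sketch; it assumes the SequentialBackprop class above is in scope): the mini-batched gradients should match ordinary full-batch backprop up to floating-point tolerance.

import torch
import torch.nn as nn

torch.manual_seed(0)
neck = nn.Linear(6, 12)
x = torch.rand(64, 6, requires_grad = True)

# reference: ordinary full-batch backprop through the bare module
neck(x).sum().backward()
ref_weight_grad = neck.weight.grad.clone()
ref_input_grad = x.grad.clone()

# wrapped: same module, backward recomputed in mini-batches of 16
neck.weight.grad = neck.bias.grad = x.grad = None
SequentialBackprop(neck, batch_size = 16)(x).sum().backward()

print(torch.allclose(ref_weight_grad, neck.weight.grad))  # expected: True
print(torch.allclose(ref_input_grad, x.grad))             # expected: True

The trade-off: peak activation memory for the wrapped module is bounded by batch_size rows at a time, at the cost of a second forward pass during backward.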
@ck6698000 commented:
Hello vadimkantorov! I've been trying to implement this module recently and am wondering whether your SBP code works as-is, or whether it needs more modification. I'd be grateful for any help!
