In [1]:
# Simple forward-mode automatic differentiation (autodiff) for PyTorch MLPs,
# compared against reverse-mode autograd.

In [2]:
import torch
from torch import nn

In [3]:
# Turn autograd off globally. The forward-mode code below never calls autograd,
# and torch_gradient() switches it back on locally for the reverse-mode reference.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f888842c250>

In [4]:
def torch_gradient(f, x):
    """Reverse-mode gradient of ``f(x).sum()`` with respect to ``x``.

    Adapted from:
    https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5

    Autograd is enabled only for the duration of the call, so this works even
    when gradients are disabled globally. ``x`` itself is left untouched: a
    detached alias of it is the tensor that gets differentiated.

    Args:
        f: Callable mapping a tensor to a tensor.
        x: Point at which to evaluate the gradient.

    Returns:
        Detached tensor shaped like ``x`` holding d(sum of f(x)) / dx.
    """
    # TODO(eric.cousineau): Keep storage for dnet_dx?
    with torch.enable_grad():
        x_leaf = x.detach().requires_grad_(True)
        total = f(x_leaf).sum()
        (grad_x,) = torch.autograd.grad([total], [x_leaf])
    return grad_x.detach()


def torch_make_mlp(input_size, hidden_sizes, output_size):
    """Build a fully connected network as an ``nn.Sequential``.

    Each entry of ``hidden_sizes`` contributes an ``nn.Linear`` followed by an
    ``nn.ReLU``; one final ``nn.Linear`` (no activation) maps to
    ``output_size``.

    Args:
        input_size: Number of input features.
        hidden_sizes: Iterable of hidden-layer widths; may be empty.
        output_size: Number of output features.

    Returns:
        The assembled ``nn.Sequential``.
    """
    widths = [input_size, *hidden_sizes]
    modules = []
    # Consecutive width pairs give the (in, out) sizes of each hidden layer.
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    modules.append(nn.Linear(widths[-1], output_size))
    return nn.Sequential(*modules)


def torch_col_zero(A, mask):
    """Zero selected columns of every matrix in a batch, in place.

    Args:
        A: Tensor of shape ``N x R x C``; modified in place.
        mask: Boolean tensor of shape ``N x C``. Wherever ``mask[n, c]`` is
            true, the whole column ``A[n, :, c]`` is set to zero.

    Returns:
        None. The result is written into ``A``.

    Raises:
        ValueError: If ``A`` is not 3-D or ``mask`` is not 2-D.
        IndexError: If the shape of ``mask`` is not ``(N, C)``.
    """
    N, _, C = A.shape
    mask_n, mask_c = mask.shape
    # Check explicitly: masked_fill_ would otherwise silently broadcast a
    # size-1 mask dimension instead of rejecting the mismatch.
    if (mask_n, mask_c) != (N, C):
        raise IndexError(
            f"mask shape {tuple(mask.shape)} does not match "
            f"(N, C) = {(N, C)} of A with shape {tuple(A.shape)}"
        )
    # Broadcast the mask across the row axis rather than materializing an
    # N x R x C copy of it with repeat().
    A.masked_fill_(mask.unsqueeze(1), 0.0)


def torch_forward_diff(net, x, dx=None):
    """Push input tangents through ``net`` (forward-mode differentiation).

    Supported modules are ``nn.Linear``, ``nn.ReLU`` and ``nn.Sequential``
    containers of supported modules (handled recursively).

    Args:
        net: Module to differentiate.
        x: Input at which ``net`` is evaluated. Must be 2-D (``N x L``) when
            ``dx`` is omitted.
        dx: Tangents entering ``net``, laid out ``N x R x C`` where ``C`` is
            the feature size of ``x``. Defaults to one ``L x L`` identity per
            sample. The tensor passed in is never modified.

    Returns:
        Tangents leaving ``net``, ``N x R x C_out``. With the default ``dx``,
        entry ``[n, r, c]`` is d out[n, c] / d x[n, r].

    Raises:
        AssertionError: If ``net`` (or a child of it) is of an unsupported
            module type.
    """
    if dx is None:
        N, L = x.shape
        dx = torch.eye(L, device=x.device, dtype=x.dtype)
        dx = dx.repeat(N, 1, 1)
    if isinstance(net, nn.Sequential):
        count = len(net)
        for i, net_i in enumerate(net):
            dx = torch_forward_diff(net_i, x, dx)
            # The activation after the last layer is never needed, so skip
            # that forward evaluation.
            if i + 1 < count:
                x = net_i(x)
    elif isinstance(net, nn.Linear):
        # The bias does not depend on the input, so only the weight acts on
        # the tangents.
        dx = dx @ net.weight.T
    elif isinstance(net, nn.ReLU):
        # Derivative is zero wherever the pre-activation is <= 0. Use the
        # out-of-place op so a caller-supplied ``dx`` is left intact; the mask
        # broadcasts over the tangent-row axis.
        dx = dx.masked_fill((x <= 0).unsqueeze(-2), 0.0)
    else:
        # Raised explicitly (not via ``assert``) so that unsupported layers are
        # still rejected when Python runs with -O.
        raise AssertionError(type(net))
    return dx

In [5]:
torch.random.manual_seed(0)

# Alternative minimal configuration (no hidden layers), left disabled.
# N = 3
# nin = 2
# nout = 1
# hidden_sizes = []

N = 512
nin = 2
nout = 1
hidden_sizes = [256, 256]

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torch_make_mlp(nin, hidden_sizes, nout)
net.eval().to(device)

x = torch.randn((N, nin), device=device)

y = net(x)
# Reverse-mode reference gradient, N x nin.
dy_dx = torch_gradient(net, x)
# Forward-mode result is N x nin x nout; with nout == 1 the trailing axis is
# dropped so both results line up.
dy_dx_a = torch_forward_diff(net, x).squeeze(-1)
assert dy_dx.shape == dy_dx_a.shape, (dy_dx.shape, dy_dx_a.shape)

# Element-wise discrepancy between the two methods.
print(dy_dx - dy_dx_a)

tensor([[-2.9323e-05, -1.9170e-05],
        [-1.2907e-05, -1.6205e-05],
        [-6.4762e-05, -4.5732e-05],
        ...,
        [ 4.7311e-07, -1.2666e-05],
        [ 6.1169e-06,  2.9355e-05],
        [-2.8305e-05, -9.4622e-06]], device='cuda:0')


In [6]:
# Time the reverse-mode (autograd) gradient, including copying the result to the CPU.
%timeit -n 50 dy_dx = torch_gradient(net, x).cpu()

1.64 ms ± 78 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [7]:
# Time the hand-written forward-mode pass, including copying the result to the CPU.
%timeit -n 50 dy_dx_a = torch_forward_diff(net, x).squeeze(-1).cpu()

389 µs ± 96 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
