In [None]:
# Simple forward-mode autodiff for ReLU MLPs, timed against PyTorch's reverse-mode autograd.

In [None]:
import torch
from torch import nn

In [None]:
# Forward-mode differentiation in this notebook never needs the autograd graph,
# so gradient tracking is turned off globally; torch_gradient re-enables it
# locally for its own reverse-mode pass.
# The trailing `;` stops Jupyter from echoing the returned grad-mode object.
torch.set_grad_enabled(False);

In [None]:
def torch_gradient(f, x):
    """Return the gradient of ``f(x).sum()`` with respect to ``x`` (reverse mode).

    Summing the outputs before differentiating gives, in a single backward
    pass, the per-row gradient whenever each output row of ``f`` depends only
    on the matching row of ``x`` and there is one output per row; in general
    it is the gradient of the summed output.

    Gradient tracking is enabled locally, so this works even when it is
    globally disabled. ``x`` itself is not modified; the returned tensor is
    detached and has the same shape as ``x``.
    """
    # Adapted from:
    # https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5
    # TODO(eric.cousineau): Keep storage for dnet_dx?
    with torch.enable_grad():
        x_leaf = x.detach().requires_grad_(True)
        total = f(x_leaf).sum()
        (grad_x,) = torch.autograd.grad([total], [x_leaf])
    return grad_x.detach()


def torch_make_mlp(input_size, hidden_sizes, output_size):
    """Build a ReLU multilayer perceptron as an ``nn.Sequential``.

    Each entry of ``hidden_sizes`` contributes a ``Linear -> ReLU`` pair; a
    final ``Linear`` (no activation) maps to ``output_size``. An empty
    ``hidden_sizes`` yields the single layer ``Linear(input_size, output_size)``.
    """
    widths = [input_size, *hidden_sizes]
    modules = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        modules.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
    modules.append(nn.Linear(widths[-1], output_size))
    return nn.Sequential(*modules)


def torch_col_zero(A, mask):
    """Zero, in place, the columns (last-axis entries) of ``A`` selected by ``mask``.

    Args:
        A: Tensor of shape ``(..., R, C)``; modified in place.
        mask: Boolean tensor of shape ``(..., C)``, where ``True`` marks a
            column to zero. Its leading dims must broadcast against ``A``'s
            leading dims: e.g. ``(C,)`` for a 2-D ``A``, or ``(N, C)`` for a
            batch ``A`` of shape ``(N, R, C)`` (one mask per matrix).

    Returns:
        None.
    """
    # A singleton row axis lets each mask broadcast over every row of its
    # matrix, so one masked_fill_ call replaces the per-batch Python loop.
    A.masked_fill_(mask.unsqueeze(-2), 0.0)


def torch_forward_diff(net, x, dx=None):
    """Forward-mode (tangent) propagation through ``net`` evaluated at ``x``.

    Supported modules: ``nn.Sequential`` (recursively), ``nn.Linear`` and
    ``nn.ReLU``.

    Layout: tangents are carried as matrices whose rows index seed directions
    and whose columns (last axis) index the current layer's features, so a
    linear layer acts as ``dx @ W.T`` and a ReLU zeroes whole columns.

    Args:
        net: Module to differentiate.
        x: Input of shape ``(*batch, L)`` with at least one dimension.
        dx: Optional seed tangents of shape ``(*batch, K, L)`` (``(K, L)`` for
            1-D ``x``). Defaults to the identity, which yields the full
            Jacobian. It is never modified in place.

    Returns:
        Tangents of shape ``(*batch, K, out_features)``. With the default seed,
        entry ``[..., i, j]`` is ``d net(x)[..., j] / d x[..., i]``, i.e. the
        transposed Jacobian.

    Raises:
        ValueError: If ``dx`` is None and ``x`` is 0-dimensional.
        NotImplementedError: If ``net`` is (or contains) an unsupported module.
    """
    if dx is None:
        if x.ndim == 0:
            raise ValueError("x must have at least one dimension")
        L = x.shape[-1]
        # One identity seed per batch element, built directly with x's dtype and
        # device: shape (*batch, L, L). `repeat` makes a real copy.
        dx = torch.eye(L, dtype=x.dtype, device=x.device).repeat(*x.shape[:-1], 1, 1)
    if isinstance(net, nn.Sequential):
        last = len(net) - 1
        for i, net_i in enumerate(net):
            dx = torch_forward_diff(net_i, x, dx)
            # Advance the primal input for the next layer; the last layer's
            # primal output is never needed.
            if i < last:
                x = net_i(x)
    elif isinstance(net, nn.Linear):
        # The bias is constant with respect to the input, so only the weight
        # contributes to the tangent.
        dx = dx @ net.weight.T
    elif isinstance(net, nn.ReLU):
        # Zero the tangent columns of inactive units (x <= 0, so x == 0 counts
        # as inactive). Done out of place so a caller-supplied dx is never
        # mutated; unsqueeze(-2) broadcasts the mask over the seed rows.
        dx = dx.masked_fill((x <= 0).unsqueeze(-2), 0.0)
    else:
        raise NotImplementedError(f"unsupported module type: {type(net)}")
    return dx

In [None]:
torch.random.manual_seed(0)

N = 512  # batch size
nin = 2
nout = 1
hidden_sizes = [256, 256]

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torch_make_mlp(nin, hidden_sizes, nout)
net.eval().to(device)

x = torch.randn((N, nin), device=device)

# Plain forward pass; not used below (presumably a warm-up before timing).
y = net(x)

# Sanity check before timing: the forward-mode Jacobian has shape
# (N, nin, nout); summing over outputs must equal the reverse-mode gradient
# of the summed output, shape (N, nin).
torch.testing.assert_close(
    torch_forward_diff(net, x).sum(-1),
    torch_gradient(net, x),
    rtol=1e-4,
    atol=1e-5,
)

In [None]:
# Reverse mode: one backward pass through `net` via torch.autograd.grad.
# NOTE: `.cpu()` copies the result to the host, which presumably also waits for
# pending GPU work so it is included in each timing.
%timeit -n 50 dy_dx = torch_gradient(net, x).detach().cpu()

In [None]:
# Forward mode: propagate an identity seed of size nin through every layer.
# squeeze(-1) drops the nout == 1 axis so the result has the same (N, nin)
# shape as the reverse-mode gradient above.
%timeit -n 50 dy_dx_a = torch_forward_diff(net, x).squeeze(-1).detach().cpu()