In [1]:
# Simple forward mode autodiff.

In [2]:
# NOTE: the `imp` module is deprecated and was removed in Python 3.12;
# `importlib.reload` is the supported replacement with the same call signature.
from importlib import reload
import torch_simple_grad as m
# Re-execute the already-imported module so this kernel uses its current
# on-disk source. The returned module object is displayed as the cell output.
reload(m)

<module 'torch_simple_grad' from '/home/eacousineau/proj/tri/repo/repro/python/torch/torch_simple_grad.py'>

In [3]:
import torch
from torch import nn

In [4]:
from torch_simple_grad import torch_forward_diff, torch_col_zero

In [5]:
# Turn autograd off globally for the rest of the notebook. torch_gradient()
# switches it back on locally with `with torch.set_grad_enabled(True):`
# for the one place where reverse-mode gradients are needed.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f7ec4357e50>

In [6]:
def torch_gradient(f, x):
    """Return the gradient of ``f(x).sum()`` with respect to ``x``.

    Autograd is enabled only for the duration of the computation, so this
    works even when gradients are disabled globally.

    Args:
        f: Callable mapping a tensor to a tensor (e.g. an ``nn.Module``).
        x: Input tensor. It is not modified; a detached copy that requires
            grad is used as the differentiation variable.

    Returns:
        A detached tensor with the same shape as ``x`` holding
        ``d(f(x).sum()) / dx``.
    """
    # Adapted from:
    # https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5
    # TODO(eric.cousineau): Keep storage for dnet_dx?
    with torch.set_grad_enabled(True):
        # Fresh leaf tensor, so the caller's `x` never has requires_grad set.
        leaf = x.detach().requires_grad_(True)
        # Reduce to a scalar so a single backward pass yields the gradient.
        total = f(leaf).sum()
        grads = torch.autograd.grad([total], [leaf])
    (df_dx,) = grads
    return df_dx.detach()


def torch_make_mlp(input_size, hidden_sizes, output_size):
    """Build a fully-connected ReLU network as an ``nn.Sequential``.

    Args:
        input_size: Number of input features.
        hidden_sizes: Iterable of hidden layer widths. Each hidden layer is
            an ``nn.Linear`` followed by an ``nn.ReLU``.
        output_size: Number of output features of the final ``nn.Linear``
            (no activation is applied after it).

    Returns:
        ``nn.Sequential`` containing the layers in order.
    """
    widths = [input_size, *hidden_sizes]
    modules = []
    # Pair each width with its successor: (in, h1), (h1, h2), ...
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    # Final projection from the last width (the input size if no hidden layers).
    modules.append(nn.Linear(widths[-1], output_size))
    return nn.Sequential(*modules)


def torch_forward_diff_old(net, x, dx=None):
    """Imperatively propagate input derivatives ``dx`` through ``net``.

    Supports ``nn.Sequential``, ``nn.Linear`` and ``nn.ReLU`` modules; any
    other module type trips an assertion.

    Args:
        net: Module to differentiate through.
        x: Input to ``net``. Must be 2-D when ``dx`` is None, since its shape
            is unpacked as ``(N, L)``.
        dx: Derivative tensor accumulated so far. When None, it is seeded
            with an ``L x L`` identity repeated ``N`` times, i.e. shape
            ``(N, L, L)``.

    Returns:
        The propagated derivative tensor. For ``nn.Linear`` this is
        ``dx @ weight.T``.
    """
    if dx is None:
        batch, width = x.shape
        identity = torch.eye(width, device=x.device, dtype=x.dtype)
        dx = identity.repeat(batch, 1, 1)

    if isinstance(net, nn.Sequential):
        last = len(net) - 1
        for index, layer in enumerate(net):
            # Each layer is differentiated at its own input `x`.
            dx = torch_forward_diff_old(layer, x, dx)
            # The final layer's output is never needed, so skip evaluating it.
            if index != last:
                x = layer(x)
        return dx

    if isinstance(net, nn.Linear):
        # The bias does not contribute to the derivative; only the weight does.
        return dx @ net.weight.T

    if isinstance(net, nn.ReLU):
        # The return value is ignored, so torch_col_zero presumably mutates
        # `dx` in place, zeroing entries selected by the mask of non-positive
        # inputs -- TODO confirm against torch_simple_grad.
        torch_col_zero(dx, x <= 0)
        return dx

    assert False, type(net)
    return dx

In [7]:
# Fix the RNG so network initialisation and the random inputs are reproducible.
torch.random.manual_seed(0)

# Tiny configuration, handy for debugging by hand (swap with the one below):
# N = 3
# nin = 2
# nout = 1
# hidden_sizes = [1]

N = 512  # number of rows in the input batch `x`
nin = 16  # input features of the MLP
nout = 1  # output features of the MLP
hidden_sizes = [256, 256]

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (behaviour is unchanged when CUDA is available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torch_make_mlp(nin, hidden_sizes, nout)
net.eval().to(device)
# Derivative callable built from `net`, plus a TorchScript-compiled variant.
dnet_dx = torch_forward_diff(net)
dnet_dx_script = torch.jit.script(dnet_dx)

x = torch.randn((N, nin), device=device)

y = net(x)
# Reference gradient from reverse-mode autograd.
dy_dx = torch_gradient(net, x)
# dy_dx_a = torch_forward_diff_old(net, x).squeeze(-1)
dy_dx_a = dnet_dx_script(x)

# Element-wise difference between the two gradient computations.
print(dy_dx - dy_dx_a)

tensor([[-7.9647e-06,  7.5959e-06, -5.0962e-06,  ..., -4.6147e-06,
          1.2591e-05,  5.7891e-06],
        [ 2.6040e-06, -5.8673e-06, -1.4782e-06,  ...,  3.9530e-06,
          2.0713e-06,  1.9036e-06],
        [ 9.0580e-06, -4.0885e-07,  8.7339e-06,  ..., -1.0902e-05,
          6.2212e-07,  2.1327e-07],
        ...,
        [-4.6156e-06, -2.4028e-06, -1.1645e-05,  ..., -7.6126e-06,
          5.6755e-06,  1.2004e-05],
        [-1.2450e-05, -3.0193e-06, -3.9227e-06,  ...,  1.7546e-06,
         -3.6228e-06,  3.0138e-06],
        [ 9.0079e-06, -2.1213e-06, -1.3924e-05,  ...,  1.0342e-05,
         -1.7490e-06, -4.9472e-06]], device='cuda:0')


In [8]:
# Baseline timing: reverse-mode gradient via torch.autograd.grad (torch_gradient).
# NOTE(review): `.cpu()` copies the result to the host; presumably included so
# the measurement waits for the device work to finish -- confirm.
%timeit -n 50 dy_dx = torch_gradient(net, x).cpu()

1.53 ms ± 195 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [9]:
# Timing: imperative per-layer derivative propagation (torch_forward_diff_old).
%timeit -n 50 dy_dx_a = torch_forward_diff_old(net, x).squeeze(-1).cpu()

461 µs ± 72.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [10]:
# Timing: the unscripted callable returned by torch_forward_diff(net).
%timeit -n 50 dy_dx_a = dnet_dx(x).cpu()

416 µs ± 36.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [11]:
# Timing: the torch.jit.script-compiled version of the same callable.
%timeit -n 50 dy_dx_a = dnet_dx_script(x).cpu()

585 µs ± 46.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
