In [None]:
# Simple forward-mode autodiff.

In [None]:
# from importlib import reload
# import torch_simple_grad as m
# reload(m)

In [None]:
import os

import torch
from torch import nn

In [None]:
from torch_simple_grad import torch_forward_diff, torch_col_zero

In [None]:
# Turn autograd off globally for the notebook; torch_gradient() re-enables it
# locally for the reverse-mode reference computation.
torch.set_grad_enabled(False)

In [None]:
def torch_gradient(f, x):
    """Reverse-mode gradient of ``f(x).sum()`` with respect to ``x``.

    Adapted from:
    https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5

    Autograd is switched on only for the duration of this call, so it works
    even when gradients are disabled globally.

    Args:
        f: Callable taking a tensor and returning a tensor.
        x: Input tensor. A detached view is differentiated, so the caller's
            ``requires_grad`` flag is left untouched.

    Returns:
        A detached tensor holding d(sum(f(x)))/dx, shaped like ``x``.
    """
    # TODO(eric.cousineau): Keep storage for dnet_dx?
    with torch.set_grad_enabled(True):
        leaf = x.detach().requires_grad_(True)
        # Summing the output yields a scalar to differentiate in one pass.
        total = f(leaf).sum()
        (grad,) = torch.autograd.grad([total], [leaf])
    return grad.detach()


def torch_make_mlp(input_size, hidden_sizes, output_size):
    """Build a ReLU multi-layer perceptron as an ``nn.Sequential``.

    Args:
        input_size: Number of input features.
        hidden_sizes: Widths of the hidden layers, in order. May be empty, in
            which case the network is a single ``nn.Linear``.
        output_size: Number of output features.

    Returns:
        ``nn.Sequential`` of ``Linear -> ReLU`` pairs (one per hidden width)
        followed by a final ``Linear`` with no activation.
    """
    widths = [input_size, *hidden_sizes]
    modules = []
    # Consecutive width pairs give each hidden layer's (in, out) sizes.
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    # Output projection: no activation afterwards.
    modules.append(nn.Linear(widths[-1], output_size))
    return nn.Sequential(*modules)


def torch_forward_diff_old(net, x, dx=None):
    """Imperative forward-mode differentiation of ``net`` at ``x``.

    Recursively walks the module, pushing the tangent ``dx`` through each
    layer alongside the activations ``x``.

    Args:
        net: An ``nn.Sequential`` (possibly nested), ``nn.Linear`` or
            ``nn.ReLU``.
        x: Input to ``net``. When ``dx`` is None it must be 2-D, ``(N, L)``.
        dx: Tangent to propagate. When None it is seeded with one ``L x L``
            identity per sample, i.e. shape ``(N, L, L)``.

    Returns:
        The propagated tangent. Each ``nn.Linear`` maps the last dimension
        from ``in_features`` to ``out_features``.

    Raises:
        AssertionError: If ``net`` or one of its children is of an
            unsupported module type.
    """
    if dx is None:
        N, L = x.shape
        dx = torch.eye(L, device=x.device, dtype=x.dtype)
        # repeat() (not expand()) so every sample owns its storage; the ReLU
        # branch below appears to modify dx in place.
        dx = dx.repeat(N, 1, 1)
    if isinstance(net, nn.Sequential):
        count = len(net)
        for i, net_i in enumerate(net):
            dx = torch_forward_diff_old(net_i, x, dx)
            # The activation after the last layer is never needed, so skip it.
            if i + 1 < count:
                x = net_i(x)
    elif isinstance(net, nn.Linear):
        # The Jacobian of x @ W.T + b w.r.t. x is W; the bias drops out.
        dx = dx @ net.weight.T
    elif isinstance(net, nn.ReLU):
        # The return value is ignored, so torch_col_zero presumably zeroes the
        # masked entries of dx in place (inactive units) -- TODO confirm
        # against torch_simple_grad.
        torch_col_zero(dx, x <= 0)
    else:
        # Raise explicitly instead of `assert False`: under `python -O` the
        # assert is stripped and dx would be returned un-propagated. The
        # exception type is kept as AssertionError for compatibility.
        raise AssertionError(type(net))
    return dx

In [None]:
# Fix the RNG so the network weights and the sample batch are reproducible.
torch.random.manual_seed(0)

# Tiny configuration, handy for checking values by hand:
# N = 3
# nin = 2
# nout = 1
# hidden_sizes = [1]

N = 512  # Seems OK
# N = 512 * 8  # Slows down a ton
nin = 16
nout = 1
hidden_sizes = [512] * 8

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines without CUDA instead of raising here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torch_make_mlp(nin, hidden_sizes, nout)
net.eval().to(device)
# Forward-mode derivative of the network, plus a TorchScript-compiled copy.
dnet_dx = torch_forward_diff(net)
dnet_dx_script = torch.jit.script(dnet_dx)

x = torch.randn((N, nin), device=device)

y = net(x)
# Reverse-mode reference gradient.
dy_dx = torch_gradient(net, x)
# dy_dx_a = torch_forward_diff_old(net, x).squeeze(-1)
dy_dx_a = dnet_dx_script(x)

# Largest absolute difference between the two results, relative to the
# largest reference entry; should be close to zero.
print((dy_dx - dy_dx_a).abs().max() / dy_dx.abs().max())

In [None]:
# Number of gradient steps taken inside each profiled run (read by prof_grad).
count = 30
from simple_profiling import ProfilingTorch, ProfilingCProfile, ProfilingWallClock

# Directory that receives the profiler output files; created if missing.
d = os.path.expanduser("~/tmp/torch_prof")
os.makedirs(d, exist_ok=True)

def prof_grad(name, x0, grad):
    """Profile ``count`` small steps along ``grad``, starting from ``x0``.

    Args:
        name: Label used as the output file stem inside directory ``d``.
        x0: Starting tensor. It is cloned, so the caller's tensor is not
            modified.
        grad: Callable mapping the current tensor to an update direction that
            can be added to it in place.

    Prints the file reported by the profiler, then ``prof.dt / count``
    (presumably the mean time per step -- confirm against simple_profiling).
    """
    # Other profilers that can be swapped in here:
    #     ProfilingTorch(), ProfilingWallClock()
    profiler = ProfilingCProfile()
    step_size = 1e-8
    with profiler.context():
        state = x0.clone()
        for _ in range(count):
            state += grad(state) * step_size
        # Copy the result to the host inside the profiled region; presumably
        # this makes pending device work count towards the measurement.
        state = state.cpu()
    (saved_path,) = profiler.save_to_file(base=f"{d}/{name}")
    print(saved_path)
    print(profiler.dt / count)

In [None]:
# Baseline: reverse-mode gradient computed with torch.autograd.
prof_grad("torch_gradient", x, lambda x: torch_gradient(net, x))

In [None]:
# Imperative forward-mode implementation defined in this notebook; squeeze(-1)
# drops the trailing output dimension (nout == 1) so the shape matches x.
prof_grad("torch_forward_diff_old", x, lambda x: torch_forward_diff_old(net, x).squeeze(-1))

In [None]:
# Forward-mode derivative returned by torch_forward_diff(net), not scripted.
prof_grad("dnet_dx", x, dnet_dx)

In [None]:
# Same forward-mode derivative after compilation with torch.jit.script.
prof_grad("dnet_dx_script", x, dnet_dx_script)