In [1]:
# Simple forward mode autodiff.

In [2]:
# from imp import reload
# import torch_simple_grad as m
# reload(m)

In [3]:
import os
os.environ.update(
#     CUDA_LAUNCH_BLOCKING="1",
)

import torch
from torch import nn

In [4]:
from torch_simple_grad import torch_forward_diff, torch_col_zero

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f2d7afe9bb0>

In [6]:
def torch_gradient(f, x):
    # Adapted from:
    # https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5
    # TODO(eric.cousineau): Keep storage for dnet_dx?
    with torch.set_grad_enabled(True):
        x = x.detach().requires_grad_(True)
        y = f(x).sum()
        df_dx, = torch.autograd.grad([y], [x])
        return df_dx.detach()


def torch_make_mlp(input_size, hidden_sizes, output_size):
    layers = []
    prev_size = input_size
    for hidden_size in hidden_sizes:
        layers.append(nn.Linear(prev_size, hidden_size))
        layers.append(nn.ReLU())
        prev_size = hidden_size
    layers.append(nn.Linear(prev_size, output_size))
    return nn.Sequential(*layers)


def torch_forward_diff_old(net, x, dx=None):
    # Imperative.
    if dx is None:
        N, L = x.shape
        dx = torch.eye(L, device=x.device, dtype=x.dtype)
        dx = dx.repeat(N, 1, 1)
    if isinstance(net, nn.Sequential):
        count = len(net)
        for i, net_i in enumerate(net):
            dx = torch_forward_diff_old(net_i, x, dx)
            # Don't compute for last.
            if i + 1 < count:
                x = net_i(x)
    elif isinstance(net, nn.Linear):
        A = net.weight
        dx = dx @ A.T
    elif isinstance(net, nn.ReLU):
        torch_col_zero(dx, x <= 0)
    else:
        assert False, type(net)
    return dx

In [7]:
torch.random.manual_seed(0)

# N = 3
# nin = 2
# nout = 1
# hidden_sizes = [1]

N = 512  # Seems OK
# N = 512 * 8  # Slows down a ton
nin = 16
nout = 1
hidden_sizes = [512] * 8

device = torch.device("cuda")
net = torch_make_mlp(nin, hidden_sizes, nout)
net.eval().to(device)
dnet_dx = torch_forward_diff(net)
dnet_dx_script = torch.jit.script(dnet_dx)

x = torch.randn((N, nin), device=device)

y = net(x)
dy_dx = torch_gradient(net, x)
# dy_dx_a = torch_forward_diff_old(net, x).squeeze(-1)
dy_dx_a = dnet_dx_script(x)

print((dy_dx - dy_dx_a).abs().max() / dy_dx.abs().max())

tensor(0.0007, device='cuda:0')


In [8]:
count = 2
from simple_profiling import ProfilingTorch, ProfilingCProfile, ProfilingWallClock

d = os.path.expanduser("~/tmp/torch_prof")
os.makedirs(d, exist_ok=True)

def prof_grad(name, x0, grad):
    prof = ProfilingTorch()
#     prof = ProfilingWallClock()
#     prof = ProfilingCProfile()
    with prof.context():
        x = x0.clone()
        step_size = 1e-8
        for _ in range(count):
            x += grad(x) * step_size
        torch.cuda.synchronize()
    print(prof.prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    file, = prof.save_to_file(base=f"{d}/{name}")
    print(file)
    print(prof.dt / count)

In [9]:
prof_grad("torch_gradient", x, lambda x: torch_gradient(net, x))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                 timing        44.66%      14.297ms        57.71%      18.474ms      18.474ms      14.442ms        47.16%      18.468ms      18.468ms             1  
    autograd::engine::evaluate_function: AddmmBackward0         2.21%     708.000us        32.39%      10.368ms     576.000us       1.226ms         4.00%      10.199ms     566.611us            18  
         

In [10]:
prof_grad("torch_forward_diff_old", x, lambda x: torch_forward_diff_old(net, x).squeeze(-1))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                 timing        13.07%       1.676ms        97.14%      12.453ms      12.453ms       1.318ms        10.58%      12.445ms      12.445ms             1  
                                           aten::repeat         2.74%     351.000us        23.73%       3.042ms     169.000us     707.000us         5.68%       3.604ms     200.222us            18  
         

In [11]:
prof_grad("dnet_dx", x, dnet_dx)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                 timing        11.78%       1.482ms        97.11%      12.221ms      12.221ms       1.103ms         9.02%      12.216ms      12.216ms             1  
                                           aten::repeat         2.71%     341.000us        23.31%       2.934ms     163.000us     737.000us         6.03%       2.738ms     152.111us            18  
         

In [12]:
prof_grad("dnet_dx_script", x, dnet_dx_script)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                 timing        40.96%      10.503ms        98.60%      25.284ms      25.284ms      10.493ms        41.50%      25.278ms      25.278ms             1  
                                                forward         1.51%     386.000us        56.74%      14.551ms       7.276ms     485.000us         1.92%      14.620ms       7.310ms             2  
         