In [1]:
import os
# Environment switches are set here, ahead of the TensorFlow / PyTorch imports
# further down in this cell.
# Optional debugging toggles, left disabled:
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# NOTE(review): presumably quiets TensorFlow's C++ logging - confirm the level.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

from contextlib import contextmanager
from functools import partial
import time

import numpy as np

# import functorch
import tensorflow.compat.v2 as tf
import torch
from torch import nn

from simple_profiling import ProfilingCProfile, ProfilingWallClock

In [2]:
# Turn on memory growth for every GPU that TensorFlow reports. The loop is a
# no-op when the list is empty.
devices = tf.config.list_physical_devices('GPU')
for device in devices:
    tf.config.experimental.set_memory_growth(device, True)
# Explicitly keep `tf.function`-decorated functions from being run eagerly.
tf.config.run_functions_eagerly(False)

In [3]:
def tf_make_mlp(hidden_sizes):
    """Build a Keras ``Sequential`` of ``Dense`` layers ending in one unit.

    One ``Dense`` layer is created per entry of ``hidden_sizes`` (in order),
    followed by a final ``Dense(1)``. No activation argument is passed to any
    layer.
    """
    dense = tf.keras.layers.Dense
    stack = [dense(width) for width in hidden_sizes]
    stack.append(dense(1))
    return tf.keras.Sequential(stack)


def torch_make_mlp(input_size, hidden_sizes):
    """Build an ``nn.Sequential`` of ``nn.Linear`` layers ending in one unit.

    Layer widths run ``input_size -> hidden_sizes[0] -> ... -> 1``. No
    activation modules are inserted between the linear layers.
    """
    widths = [input_size, *hidden_sizes, 1]
    # Pair each width with its successor to get (in_features, out_features).
    linears = [
        nn.Linear(fan_in, fan_out)
        for fan_in, fan_out in zip(widths[:-1], widths[1:])
    ]
    return nn.Sequential(*linears)

In [4]:
def tf_gradient(f, x):
    """Return the gradient of ``f(x, training=False)`` with respect to ``x``.

    ``f`` is called inside a ``tf.GradientTape`` on which ``x`` is explicitly
    watched, so ``x`` does not need to be a variable.
    """
    with tf.GradientTape() as tape:
        tape.watch(x)
        output = f(x, training=False)
    return tape.gradient(output, x)


@tf.function
def tf_gradient_fast(f, x):
    """`tf.function`-decorated variant of `tf_gradient`; same arguments and result."""
    return tf_gradient(f, x)


def torch_gradient(f, x):
    """Return the gradient of ``f(x).sum()`` with respect to ``x``.

    ``x`` must already require grad and autograd must be enabled by the
    caller; neither is arranged here.
    """
    # Adapted from:
    # https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5
    # TODO(eric.cousineau): Keep storage for dnet_dx?
    summed = f(x).sum()
    (grad_x,) = torch.autograd.grad([summed], [x])
    return grad_x


class TorchGradient(nn.Module):
    """Module form of the input-gradient computation for a fixed callable.

    ``forward(x)`` returns the gradient of ``self.f(x).sum()`` with respect to
    ``x``. Wrapping the callable in a module lets the computation be handed to
    ``torch.jit.script`` as a single object.
    """

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        # `x` must require grad, and autograd must be enabled by the caller.
        total = self.f(x).sum()
        (grad_x,) = torch.autograd.grad([total], [x])
        return grad_x

In [5]:
@contextmanager
def disable_param_grad(net):
    """Temporarily set ``requires_grad=False`` on every parameter of ``net``.

    On exit each parameter gets back the ``requires_grad`` value it had on
    entry, including when the body raises.

    Args:
        net: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
    """
    # Snapshot (parameter, previous flag) pairs up front so the restore step
    # touches exactly the parameters that were switched off.
    saved = [(param, param.requires_grad) for param in net.parameters()]
    for param, _ in saved:
        param.requires_grad = False
    try:
        yield
    finally:
        for param, was_enabled in saved:
            param.requires_grad = was_enabled

In [6]:
# Benchmark configuration: batch size, per-sample dimension, hidden widths.
N, DimY = 512, 2
hidden_sizes = [256, 256]

# Smaller alternative configuration, left disabled:
# N = 1
# DimY = 2
# hidden_sizes = []

# Number of update steps per call, and the multiplier applied to each gradient.
num_iter, step_size = 100, 1e-8

# Shared float32 starting point for every benchmark; seeded so runs repeat.
np.random.seed(0)
yhs_init = np.random.rand(N, DimY).astype(np.float32)

# Directory referenced by the (currently commented-out) profiler dumps.
stat_dir = "/tmp/torch_vs_tf"
os.makedirs(stat_dir, exist_ok=True)

def needs_net(f):
    """Decorator that tags ``f`` with ``needs_net = True`` and returns it.

    Hack specific to this notebook: ``torch_profile`` checks for the
    attribute and, when present, calls ``f(net)`` to obtain the function it
    actually benchmarks.
    """
    setattr(f, "needs_net", True)
    return f

In [7]:
def tf_fake_langevin(net, yhs, num_iter, grad_func=tf_gradient):
    """Apply ``num_iter`` gradient updates to ``yhs`` using TensorFlow.

    Each step evaluates ``grad_func(net, yhs)`` and rebinds ``yhs`` to
    ``yhs + gradient * step_size`` (``step_size`` is the notebook-level
    constant). The final ``yhs`` is returned.
    """
    for _step in range(num_iter):
        gradient = grad_func(net, yhs)
        yhs = yhs + gradient * step_size
    return yhs

def tf_fake_langevin_kinda_fast(net, yhs, num_iter):
    """Run `tf_fake_langevin` with `tf_gradient_fast` as the gradient function.

    The loop in `tf_fake_langevin` is called as plain Python here; only the
    per-step gradient goes through the `tf.function`-decorated
    `tf_gradient_fast`.
    """
    # Only do gradient computation graph.
    return tf_fake_langevin(net, yhs, num_iter, grad_func=tf_gradient_fast)

@tf.function
def tf_fake_langevin_fast(net, yhs, num_iter):
    """`tf.function`-decorated `tf_fake_langevin` using `tf_gradient_fast`.

    Unlike `tf_fake_langevin_kinda_fast`, the whole update loop sits inside
    the `tf.function`, not just the gradient call.
    """
    return tf_fake_langevin(net, yhs, num_iter, grad_func=tf_gradient_fast)


def torch_fake_langevin(net, yhs, num_iter, grad_func=torch_gradient):
    """Apply ``num_iter`` gradient updates to ``yhs`` using PyTorch.

    Args:
        net: Passed through to ``grad_func`` as its first argument.
        yhs: Starting tensor. It is rebound each step, never modified in place.
        num_iter: Number of update steps.
        grad_func: Callable ``grad_func(net, x)`` returning the gradient with
            respect to ``x``. Defaults to ``torch_gradient``.

    Returns:
        ``yhs`` after ``num_iter`` updates of ``yhs + gradient * step_size``
        (``step_size`` is the notebook-level constant).

    Raises:
        AssertionError: If autograd is enabled when this is called.
    """
    assert not torch.is_grad_enabled()
    for _ in range(num_iter):
        # Autograd is switched on only around the gradient evaluation; the
        # update below runs under the caller's no-grad mode.
        with torch.set_grad_enabled(True):
            # Use the caller-supplied gradient function. Hard-coding
            # torch_gradient here would silently ignore `grad_func`.
            de_dact = grad_func(
                net, yhs.detach().requires_grad_(True)
            )
        yhs = yhs + de_dact * step_size
    return yhs


@needs_net
def torch_fake_langevin_kinda_fast(net):
    """Factory for a ``fake_langevin(net, yhs, num_iter)`` with a scripted gradient.

    ``net`` is wrapped in ``TorchGradient`` and compiled with
    ``torch.jit.script``. The returned function forwards to
    ``torch_fake_langevin``, passing the scripted module as ``grad_func``.
    """
    scripted_grad = torch.jit.script(TorchGradient(net))

    def grad_via_script(f, x):
        # `f` is accepted only to match torch_gradient's (f, x) signature;
        # the scripted module already holds the net.
        return scripted_grad(x)

    def fake_langevin(net, yhs, num_iter):
        return torch_fake_langevin(
            net, yhs, num_iter, grad_func=grad_via_script
        )

    return fake_langevin


class TorchFakeLangevin(nn.Module):
    """Whole update loop as a module, so it can be passed to ``torch.jit.script``.

    ``forward(yhs, num_iter)`` applies ``num_iter`` updates of
    ``yhs + gradient * step_size``, where the gradient is that of
    ``net(yhs).sum()`` with respect to ``yhs``.
    """

    # Since we can't pass a net as argument, just wrap it into a module...
    def __init__(self, net, step_size):
        super().__init__()
        self.net = net
        self._step_size = step_size

    def _gradient(self, yhs): # -> torch.Tensor:
        """Return the detached gradient of ``self.net(yhs).sum()`` w.r.t. ``yhs``."""
        # See: https://github.com/pytorch/pytorch/issues/70223
        # Differentiate with respect to a fresh leaf so the caller's tensor
        # is not flagged as requiring grad.
        yhs = yhs.detach().requires_grad_(True)
        out = self.net(yhs).sum()
        dy, = torch.autograd.grad([out], [yhs])
        assert dy is not None
        return dy.detach()

    def forward(self, yhs, num_iter: int):
        """Run ``num_iter`` update steps and return the final tensor.

        Autograd must be enabled by the caller. The input ``yhs`` is left
        unmodified.
        """
        for i in range(num_iter):
            de_dact = self._gradient(yhs)
            # Rebind instead of `yhs += yhs + ...`. The in-place form doubled
            # yhs every step and overwrote the caller's tensor.
            yhs = yhs + de_dact * self._step_size
        return yhs

#     def forward(self, yhs, num_iter: int):
#         for i in range(num_iter):
#             yhs_tmp = yhs.detach().requires_grad_(True)
#             loss = self.net(yhs_tmp).sum()
#             de_dact, = torch.autograd.grad([loss], [yhs_tmp])
#             assert de_dact is not None
#             de_dact = de_dact.detach()
#             yhs = yhs.detach()
# #             de_dact = self._gradient(yhs)
#             yhs += yhs + de_dact * self._step_size
#         return yhs


@needs_net
def torch_fake_langevin_fast(net):
    """Factory for a ``fake_langevin(net, yhs, num_iter)`` with a scripted loop.

    ``net`` and the notebook-level ``step_size`` are wrapped in
    ``TorchFakeLangevin`` and compiled with ``torch.jit.script``. The returned
    function ignores its own ``net`` argument and calls the scripted module
    with autograd enabled.
    """
    scripted_loop = torch.jit.script(TorchFakeLangevin(net, step_size))

    def fake_langevin(net, yhs, num_iter):
        # Autograd is re-enabled here because torch_profile runs under no_grad.
        with torch.set_grad_enabled(True):
            result = scripted_loop(yhs, num_iter)
        return result

    return fake_langevin


@needs_net
def functorch_fake_langevin(net):
    """Factory for a ``fake_langevin(net, yhs, num_iter)`` built on functorch.

    ``net`` is split with ``functorch.make_functional`` and
    ``functorch.grad`` is applied to the summed output as a function of the
    input. The returned function ignores its own ``net`` argument.

    NOTE(review): ``import functorch`` is commented out in the notebook's
    import cell, so calling this fails until that import is restored.
    """
    functional_net, params = functorch.make_functional(net)

    def summed_output(x):
        return functional_net(params, x).sum()

    grad_wrt_input = functorch.grad(summed_output)

    def fake_langevin(net, yhs, num_iter):
        for _ in range(num_iter):
            yhs = yhs + grad_wrt_input(yhs) * step_size
        return yhs

    return fake_langevin

In [8]:
def tf_profile(fake_langevin):
    """Time one call of a TensorFlow ``fake_langevin`` variant and print it.

    A fresh MLP is built with ``tf_make_mlp(hidden_sizes)`` and called once on
    the input. The work is then run once as warmup and once under
    ``ProfilingWallClock``, and ``"<name>: <dt>s"`` is printed.

    Args:
        fake_langevin: Callable ``(net, yhs, num_iter) -> tensor``; its
            ``__name__`` labels the printed line.
    """
    name = fake_langevin.__name__
    net = tf_make_mlp(hidden_sizes)

    yhs = tf.convert_to_tensor(yhs_init)
    # Trace.
    net(yhs)

    def work():
        # Copy the result to host memory so the timed region includes the
        # actual device computation. TensorFlow can return from GPU ops
        # before the kernels finish. This mirrors `.cpu()` in torch_profile.
        return fake_langevin(net, yhs, num_iter).numpy()

    # Warmup.
    work()

    profiler = ProfilingWallClock()
    with profiler.context():
        work()

#     profiler.save_to_file(base=f"{stat_dir}/{name}")
    print(f"{name}: {profiler.dt:.3g}s")

@torch.no_grad()
def torch_profile(fake_langevin):
    """Time one call of a PyTorch ``fake_langevin`` variant and print it.

    A fresh MLP is built with ``torch_make_mlp(DimY, hidden_sizes)``, put in
    eval mode and moved to the device. The work is run twice as warmup and
    once under ``ProfilingWallClock``, and ``"<name>: <dt>s"`` is printed.
    The whole function runs with autograd disabled.

    Args:
        fake_langevin: Either a callable ``(net, yhs, num_iter) -> tensor``,
            or a factory tagged with ``needs_net`` that takes the net and
            returns such a callable. Its ``__name__`` labels the printed line.
    """
    # Read the name before the factory (if any) replaces `fake_langevin`.
    name = fake_langevin.__name__
    # Use the GPU when present and fall back to CPU otherwise, so the
    # notebook still runs on CPU-only machines. The TensorFlow setup likewise
    # tolerates an empty GPU list.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     device = torch.device("cpu")

    net = torch_make_mlp(DimY, hidden_sizes)
    net.eval().to(device)
    yhs = torch.from_numpy(yhs_init).to(device)

    def work():
        # Simulate device transfer to flush graph.
        return fake_langevin(net, yhs, num_iter).detach().cpu()

    # `work` looks up `fake_langevin` when called, so it sees the function
    # produced by the factory here.
    if getattr(fake_langevin, "needs_net", False):
        fake_langevin = fake_langevin(net)

    # Warmup; needs >=2 for jit on first usage?
    for _ in range(2):
        work()

    profiler = ProfilingWallClock()
    with profiler.context():
        work()

#     profiler.save_to_file(base=f"{stat_dir}/{name}")
    print(f"{name}: {profiler.dt:.3g}s")

In [15]:
# TensorFlow variants: plain, tf.function gradient only, tf.function whole loop.
for tf_variant in (
    tf_fake_langevin,
    tf_fake_langevin_kinda_fast,
    tf_fake_langevin_fast,
):
    tf_profile(tf_variant)

print()

# PyTorch variants: plain, scripted gradient only, scripted whole loop.
for torch_variant in (
    torch_fake_langevin,
    torch_fake_langevin_kinda_fast,
    # functorch_fake_langevin,
    torch_fake_langevin_fast,
):
    torch_profile(torch_variant)

torch_fake_langevin: 0.0534s
