In [1]:
import os
# Process-wide GPU / logging switches. They are written to the environment
# before tensorflow and torch are imported further down in this cell.
# Make CUDA kernel launches synchronous.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Expose only the GPU with index 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Raise the TensorFlow C++ log threshold (NOTE(review): "2" should hide
# INFO and WARNING messages -- confirm against the TF version in use).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

from functools import partial
import time

import numpy as np

import tensorflow.compat.v2 as tf
import torch
from torch import nn

from simple_profiling import ProfilingCProfile

In [2]:
# Ask TensorFlow to grow GPU memory on demand for every GPU it can see.
devices = tf.config.list_physical_devices('GPU')
for device in devices:
    tf.config.experimental.set_memory_growth(device, enable=True)

# Keep tf.function-decorated code running as traced graphs, not eagerly.
tf.config.run_functions_eagerly(run_eagerly=False)

In [3]:
def tf_make_mlp(hidden_sizes):
    """Build a Keras Sequential stack of Dense layers ending in a 1-unit layer.

    Args:
        hidden_sizes: Iterable of unit counts, one Dense layer per entry.

    Returns:
        A tf.keras.Sequential with len(hidden_sizes) + 1 Dense layers. No
        activation argument is passed to any of the layers.
    """
    stack = [tf.keras.layers.Dense(width) for width in hidden_sizes]
    stack.append(tf.keras.layers.Dense(1))
    return tf.keras.Sequential(stack)


def torch_make_mlp(input_size, hidden_sizes):
    """Build an nn.Sequential of Linear layers ending in a 1-output layer.

    Args:
        input_size: Number of input features of the first Linear layer.
        hidden_sizes: Iterable of output widths, one Linear layer per entry.

    Returns:
        nn.Sequential chaining input_size -> hidden_sizes... -> 1. Only
        nn.Linear modules are added (no activation modules in between).
    """
    # Consecutive pairs of this list are the (in, out) sizes of each layer.
    widths = [input_size, *hidden_sizes, 1]
    linears = [
        nn.Linear(fan_in, fan_out)
        for fan_in, fan_out in zip(widths[:-1], widths[1:])
    ]
    return nn.Sequential(*linears)

In [4]:
def tf_gradient(f, x):
    """Differentiate f(x, training=False) with respect to x using a GradientTape.

    Args:
        f: Callable accepting (x, training=False), e.g. a Keras model.
        x: Tensor to differentiate with respect to; it is watched explicitly.

    Returns:
        Whatever tape.gradient(f(x), x) yields for the watched input.
    """
    with tf.GradientTape() as tape:
        # x is watched explicitly, so it does not need to be a tf.Variable.
        tape.watch(x)
        output = f(x, training=False)
    return tape.gradient(output, x)


@tf.function
def tf_gradient_fast(f, x):
    """tf.function-wrapped version of tf_gradient; takes the same (f, x)."""
    df_dx = tf_gradient(f, x)
    return df_dx


def torch_gradient(f, x):
    """Return the gradient of f(x).sum() with respect to x.

    Adapted from:
    https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5

    Args:
        f: Callable mapping a tensor to a tensor.
        x: Input tensor; autograd must be able to differentiate through it.

    Returns:
        The single gradient tensor produced by torch.autograd.grad.
    """
    # TODO(eric.cousineau): Keep storage for dnet_dx?
    # Summing the output gives a scalar to differentiate.
    total = f(x).sum()
    (grad_x,) = torch.autograd.grad([total], [x])
    return grad_x


class TorchGradient(nn.Module):
    """nn.Module wrapper computing the gradient of f(x).sum() with respect to x.

    Same computation as torch_gradient, but packaged as a module;
    torch_fake_langevin_kinda_fast passes an instance to torch.jit.script.
    """
    def __init__(self, f):
        """Store the callable/module `f` whose summed output is differentiated."""
        super().__init__()
        self.f = f
    
    def forward(self, x):
        """Return d(sum(f(x)))/dx as produced by torch.autograd.grad."""
        # Reduce to a scalar so a single grad call covers the whole input.
        y = self.f(x).sum()
        # List arguments and single-element unpacking are kept as-is; this
        # method is compiled by torch.jit.script.
        df_dx, = torch.autograd.grad([y], [x])
        return df_dx

In [5]:
def needs_net(f):
    """Decorator that tags `f` with `needs_net = True` and returns it unchanged.

    torch_profile checks this attribute and, when set, calls the tagged
    function with the network first to obtain the function it benchmarks.
    """
    setattr(f, "needs_net", True)
    return f

In [6]:
import functorch

In [7]:
# Multiplier applied to the gradient in every fake-Langevin update
# (yhs + gradient * step_size); also handed to TorchFakeLangevin.
step_size = 0.001

def tf_fake_langevin(net, yhs, num_iter, grad_func=tf_gradient):
    """Apply `num_iter` gradient updates to `yhs` using the network `net`.

    Args:
        net: Model forwarded to `grad_func`.
        yhs: Starting tensor.
        num_iter: Number of update steps.
        grad_func: Callable (net, yhs) -> gradient of the same shape as yhs.

    Returns:
        The tensor after `num_iter` updates of yhs + gradient * step_size.
    """
    current = yhs
    for _ in range(num_iter):
        gradient = grad_func(net, current)
        current = current + gradient * step_size
    return current

def tf_fake_langevin_kinda_fast(net, yhs, num_iter):
    """Run tf_fake_langevin with the tf.function-wrapped gradient.

    Only the gradient computation is graph-compiled; the update loop itself
    is not decorated here.
    """
    return tf_fake_langevin(net, yhs, num_iter, tf_gradient_fast)

@tf.function
def tf_fake_langevin_fast(net, yhs, num_iter):
    """Run the whole update loop (and its gradient) under tf.function."""
    result = tf_fake_langevin(net, yhs, num_iter, tf_gradient_fast)
    return result


def torch_fake_langevin(net, yhs, num_iter, grad_func=torch_gradient):
    """Apply `num_iter` gradient updates to `yhs` using the network `net`.

    Must be called with autograd globally disabled (e.g. under
    torch.no_grad()); gradients are enabled only around each gradient call.

    Args:
        net: Network forwarded to `grad_func`.
        yhs: Starting tensor.
        num_iter: Number of update steps.
        grad_func: Callable invoked as grad_func(net, x), where x is a
            detached copy of the current iterate with requires_grad=True.
            Must return the gradient with respect to x.

    Returns:
        The tensor after `num_iter` updates of yhs + gradient * step_size.

    Raises:
        AssertionError: If autograd is enabled when the function is entered.
    """
    assert not torch.is_grad_enabled()
    for _ in range(num_iter):
        with torch.set_grad_enabled(True):
            # Honor the caller-supplied gradient function. Previously
            # torch_gradient was hard-coded here, so `grad_func` was
            # silently ignored.
            de_dact = grad_func(
                net, yhs.detach().requires_grad_(True)
            )
        # The update runs outside the grad-enabled scope, so it is not
        # recorded by autograd.
        yhs = yhs + de_dact * step_size
    return yhs


@needs_net
def torch_fake_langevin_kinda_fast(net):
    """Factory: build a fake-Langevin function around a scripted gradient module.

    A TorchGradient wrapping `net` is compiled with torch.jit.script once,
    and an adapter around it is supplied as `grad_func` to
    torch_fake_langevin.

    Args:
        net: Network to wrap and script.

    Returns:
        A function (net, yhs, num_iter) that delegates to torch_fake_langevin.
    """
    scripted_grad = torch.jit.script(TorchGradient(net))

    def func(f, x):
        # Adapter to the (f, x) signature; the scripted module already
        # holds the network, so `f` is not used.
        return scripted_grad(x)

    def fake_langevin(net, yhs, num_iter):
        return torch_fake_langevin(net, yhs, num_iter, grad_func=func)

    return fake_langevin


class TorchFakeLangevin(nn.Module):
    """Module form of the fake-Langevin loop: repeated gradient updates in forward().

    torch_fake_langevin_fast compiles an instance with torch.jit.script, so
    the code below is kept in a form that TorchScript accepts.
    """
    def __init__(self, net, step_size):
        """Store the network and the multiplier applied to each gradient."""
        super().__init__()
        self.net = net
        self._step_size = step_size

    def _gradient(self, yhs) -> torch.Tensor:
        """Return d(sum(net(yhs)))/d(yhs), or zeros when autograd yields None."""
        out = self.net(yhs).sum()
        # allow_unused=True makes autograd return None instead of raising
        # when it finds no dependence of `out` on `yhs`.
        dy, = torch.autograd.grad([out], [yhs], allow_unused=True)
        # TODO(eric): Why does this happen?
        if dy is None:
            dy = torch.zeros_like(yhs)
        return dy

    def forward(self, yhs, num_iter: int):
        """Run `num_iter` updates of yhs + gradient * step_size and return the result."""
        # Cut any incoming autograd history and make the input differentiable.
        yhs = yhs.detach().requires_grad_(True)
        for i in range(num_iter):
            # The gradient is detached before it is used in the update.
            de_dact = self._gradient(yhs).detach()
            # NOTE(review): after this line `yhs` is the result of an add, not
            # the tensor prepared above; whether that is related to the None
            # gradient handled in _gradient has not been confirmed.
            yhs = yhs + de_dact * self._step_size
        return yhs
    

@needs_net
def torch_fake_langevin_fast(net):
    """Factory: script the whole update loop (TorchFakeLangevin) for `net`.

    Args:
        net: Network handed to TorchFakeLangevin together with step_size.

    Returns:
        A function (net, yhs, num_iter) that runs the scripted module with
        autograd enabled. Its `net` argument is ignored; the network captured
        at construction time is used.
    """
    module = TorchFakeLangevin(net, step_size)
    # Scripting is done with autograd switched on.
    with torch.enable_grad():
        scripted = torch.jit.script(module)

    def fake_langevin(net, yhs, num_iter):
        # Re-enable autograd for the call; torch_profile runs under no_grad.
        with torch.enable_grad():
            return scripted(yhs, num_iter)

    return fake_langevin


@needs_net
def functorch_fake_langevin(net):
    """Factory: build a fake-Langevin function whose gradient comes from functorch.

    Args:
        net: Network converted with functorch.make_functional.

    Returns:
        A function (net, yhs, num_iter) applying `num_iter` updates of
        yhs + gradient * step_size. Its `net` argument is ignored; the
        functional form captured here is used.
    """
    functional_net, params = functorch.make_functional(net)

    def summed_output(x):
        # Scalar objective: the sum of the network output.
        return functional_net(params, x).sum()

    # functorch.grad differentiates with respect to the first argument (x).
    input_gradient = functorch.grad(summed_output)

    def fake_langevin(net, yhs, num_iter):
        current = yhs
        for _ in range(num_iter):
            gradient = input_gradient(current)
            current = current + gradient * step_size
        return current

    return fake_langevin

In [8]:
# Benchmark sizes: the input has N * L rows and DimY columns.
N, L, DimY = 32, 8, 2
# Widths of the hidden layers handed to the MLP builders.
hidden_sizes = [256, 256]
# Number of update steps per fake-Langevin call.
num_iter = 100

# Fixed seed so every profiled variant starts from the same float32 input.
np.random.seed(0)
yhs_init = np.random.rand(N * L, DimY).astype(np.float32)

In [9]:
def tf_profile(fake_langevin):
    """Profile one TensorFlow fake-Langevin variant and print its timing.

    A fresh MLP is built, the variant is run once unprofiled, then once
    inside a ProfilingCProfile context. The stats are saved under
    /tmp/tensorflow_<function name> and `profiler.dt` is printed together
    with the saved file path.

    Args:
        fake_langevin: Callable (net, yhs, num_iter); its __name__ labels
            the output file.
    """
    name = fake_langevin.__name__
    model = tf_make_mlp(hidden_sizes)
    inputs = tf.convert_to_tensor(yhs_init)

    # Call the model once before any measurement (the original note: "Trace").
    model(inputs)

    # Unprofiled warmup run.
    fake_langevin(model, inputs, num_iter)

    profiler = ProfilingCProfile()
    with profiler.context():
        fake_langevin(model, inputs, num_iter)
    (stats_path,) = profiler.save_to_file(base=f"/tmp/tensorflow_{name}")
    print(f"{profiler.dt}  ({stats_path})")

@torch.no_grad()
def torch_profile(fake_langevin):
    """Profile one PyTorch fake-Langevin variant on CUDA and print its timing.

    Runs with autograd disabled. A fresh MLP and the shared input are moved
    to the "cuda" device; the variant is run once unprofiled, then once
    inside a ProfilingCProfile context. The stats are saved under
    /tmp/torch_<function name> and `profiler.dt` is printed together with
    the saved file path.

    Args:
        fake_langevin: Either a callable (net, yhs, num_iter), or a factory
            tagged with `needs_net` that takes the network and returns one.
            Its __name__ (read before any factory call) labels the output.
    """
    name = fake_langevin.__name__
    cuda = torch.device("cuda")

    model = torch_make_mlp(DimY, hidden_sizes)
    model.to(cuda)
    inputs = torch.from_numpy(yhs_init).to(cuda)

    # Factories tagged by @needs_net are resolved into the real function here.
    if getattr(fake_langevin, "needs_net", False):
        fake_langevin = fake_langevin(model)

    # Unprofiled warmup run.
    result = fake_langevin(model, inputs, num_iter)

    profiler = ProfilingCProfile()
    with profiler.context():
        result = fake_langevin(model, inputs, num_iter)
    (stats_path,) = profiler.save_to_file(base=f"/tmp/torch_{name}")
    print(f"{profiler.dt}  ({stats_path})")

In [10]:
# TensorFlow variants, from plain eager to fully graph-compiled.
for tf_variant in (
    tf_fake_langevin,
    tf_fake_langevin_kinda_fast,
    tf_fake_langevin_fast,
):
    tf_profile(tf_variant)

# PyTorch variants that are profiled once each.
for torch_variant in (
    torch_fake_langevin,
    torch_fake_langevin_kinda_fast,
    functorch_fake_langevin,
):
    torch_profile(torch_variant)

# Need to run more than once? In the recorded output the first of these
# three passes is far slower than the two that follow.
for _ in range(3):
    torch_profile(torch_fake_langevin_fast)

0.20664525032043457  (/tmp/tensorflow_tf_fake_langevin_stats.txt)
0.08259344100952148  (/tmp/tensorflow_tf_fake_langevin_kinda_fast_stats.txt)
0.009865760803222656  (/tmp/tensorflow_tf_fake_langevin_fast_stats.txt)
0.18395233154296875  (/tmp/torch_torch_fake_langevin_stats.txt)
0.17860126495361328  (/tmp/torch_torch_fake_langevin_kinda_fast_stats.txt)
0.12345194816589355  (/tmp/torch_functorch_fake_langevin_stats.txt)
1.4889845848083496  (/tmp/torch_torch_fake_langevin_fast_stats.txt)
0.049472808837890625  (/tmp/torch_torch_fake_langevin_fast_stats.txt)
0.04242134094238281  (/tmp/torch_torch_fake_langevin_fast_stats.txt)
