In [1]:
import os
# These must be set before TensorFlow / PyTorch are imported below, so that the
# CUDA runtime and TensorFlow's logger see them when they initialize.
os.environ.update({
    # Make CUDA kernel launches synchronous. This is a debugging aid that also
    # slows GPU code; here it presumably keeps the wall-clock timings honest
    # without explicit device synchronization -- TODO confirm intent.
    "CUDA_LAUNCH_BLOCKING": "1",
    # Expose only physical GPU 1 (it then appears as the single device "cuda").
    "CUDA_VISIBLE_DEVICES": "1",
    # Hide TensorFlow's C++ INFO and WARNING log messages.
    "TF_CPP_MIN_LOG_LEVEL": "2",
})

import time

import numpy as np

import tensorflow.compat.v2 as tf
import torch
from torch import nn

In [2]:
# Let TensorFlow grow its GPU memory pool on demand instead of reserving
# (nearly) all device memory up front -- presumably so PyTorch, used later in
# this same process, can still allocate on the GPU.
devices = tf.config.list_physical_devices('GPU')
for gpu in devices:
    tf.config.experimental.set_memory_growth(gpu, True)
# Execute tf.function-wrapped code as compiled graphs (not eagerly), so the
# tf.function variant benchmarked below really runs in graph mode.
tf.config.run_functions_eagerly(False)

In [3]:
def tf_make_mlp(hidden_sizes, activation=None):
    """Build a Keras MLP that maps each input row to a single scalar.

    One ``Dense(hidden_size)`` layer is added per entry of `hidden_sizes`,
    followed by a ``Dense(1)`` output layer. No input shape is given, so Keras
    creates the weights on the model's first call.

    Args:
        hidden_sizes: iterable of hidden-layer widths (may be empty).
        activation: activation for every hidden layer, in any form accepted by
            ``tf.keras.layers.Dense`` (e.g. ``"relu"``). Defaults to None,
            i.e. no nonlinearity (the original behaviour). In that case the
            network is a composition of affine maps, so its gradient with
            respect to the input does not depend on the input.

    Returns:
        A ``tf.keras.Sequential`` model.
    """
    net = tf.keras.Sequential()
    for hidden_size in hidden_sizes:
        net.add(tf.keras.layers.Dense(hidden_size, activation=activation))
    net.add(tf.keras.layers.Dense(1))
    return net


def torch_make_mlp(input_size, hidden_sizes, activation=None):
    """Build a PyTorch MLP mapping ``(..., input_size)`` inputs to ``(..., 1)``.

    One ``nn.Linear`` layer is created per entry of `hidden_sizes`, each
    consuming the previous layer's width, followed by a ``Linear(prev, 1)``
    output layer.

    Args:
        input_size: number of input features.
        hidden_sizes: iterable of hidden-layer widths (may be empty, giving a
            single ``Linear(input_size, 1)``).
        activation: optional zero-argument factory returning an ``nn.Module``
            (e.g. ``nn.ReLU``). A fresh instance is inserted after every hidden
            Linear layer. Defaults to None, i.e. no nonlinearity (the original
            behaviour). The network is then purely affine, so its input
            gradient is constant.

    Returns:
        An ``nn.Sequential`` of the layers above.
    """
    layers = []
    prev_size = input_size
    for hidden_size in hidden_sizes:
        layers.append(nn.Linear(prev_size, hidden_size))
        if activation is not None:
            layers.append(activation())
        prev_size = hidden_size
    layers.append(nn.Linear(prev_size, 1))
    return nn.Sequential(*layers)

In [4]:
from typing import Callable

In [5]:
def tf_gradient(f, x):
    """Return the gradient of ``f(x, training=False)`` with respect to `x`.

    `x` is explicitly watched, so it may be a plain tensor rather than a
    ``tf.Variable``. For a non-scalar output, ``GradientTape.gradient``
    differentiates the sum of its elements.

    Args:
        f: callable accepting a ``training`` keyword (e.g. a Keras model).
        x: input tensor.

    Returns:
        A tensor shaped like `x`, or None if the output does not depend on `x`.
    """
    with tf.GradientTape() as tape:
        tape.watch(x)
        output = f(x, training=False)
    return tape.gradient(output, x)


def torch_gradient(f: Callable, x, create_graph: bool):
    """Return d(sum(f(x)))/dx.

    Summing the outputs first means that, when each output row depends only on
    the matching row of `x` (as for an MLP applied to a batch), row i of the
    result is the gradient of f at row i. This gives a batched gradient from a
    single backward pass.

    Adapted from:
    https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5

    Args:
        f: differentiable callable applied to `x`.
        x: tensor with ``requires_grad=True``.
        create_graph: if True, the returned gradient is itself part of the
            autograd graph (allows higher-order derivatives).
    """
    # TODO(eric.cousineau): Keep storage for dnet_dx?
    return torch.autograd.grad(f(x).sum(), x, create_graph=create_graph)[0]

In [6]:
def tf_fake_langevin(net, yhs, num_iter):
    """Take `num_iter` fixed-size gradient-ascent steps on `net`'s output.

    Each step is ``yhs <- yhs + 0.1 * d net(yhs) / d yhs``. Unlike Langevin
    dynamics there is no noise term. The loop is used below as a benchmark
    workload for repeated input-gradient evaluation.
    """
    for _ in range(num_iter):
        yhs = yhs + 0.1 * tf_gradient(net, yhs)
    return yhs


# Graph-compiled variant. `num_iter` is a Python int, so tracing unrolls the
# loop num_iter times. Non-tensor arguments (num_iter, the Keras model object)
# are part of the trace key, so a new net or a new num_iter triggers a retrace.
tf_fake_langevin_fast = tf.function(tf_fake_langevin)


def torch_fake_langevin(net, yhs, num_iter):
    """Autograd version of tf_fake_langevin: `num_iter` steps of
    ``yhs <- yhs + 0.1 * d sum(net(yhs)) / d yhs``.

    The input is detached, so nothing flows back to the caller's `yhs`.
    If the caller has grad mode enabled, the gradients are computed with
    ``create_graph=True``, so the result is differentiable through them.
    Under ``torch.no_grad()`` (as in torch_profile) they are not, although the
    returned tensor still has ``requires_grad=True``.
    """
    # Must be read before grad mode is forced on below.
    keep_graph = torch.is_grad_enabled()
    current = yhs.detach().requires_grad_(True)
    with torch.set_grad_enabled(True):
        for _ in range(num_iter):
            step = torch_gradient(net, current, keep_graph)
            current = current + 0.1 * step
    return current


class TorchFakeLangevin(nn.Module):
    """Module form of the fake-Langevin loop, so that it can be compiled with
    ``torch.jit.script`` (see make_torch_fake_langevin_fast).

    Performs the same update as torch_fake_langevin (yhs <- yhs + 0.1 * grad).
    However, ``torch.autograd.grad`` is always called without ``create_graph``,
    so the gradient terms are never themselves differentiable.
    """

    def __init__(self, net: nn.Module):
        super().__init__()
        # Assigned as a submodule, so it is scripted along with this module.
        self.net = net

    def _gradient(self, yhs: torch.Tensor) -> torch.Tensor:
        """Return d(sum(self.net(yhs)))/d(yhs), or zeros if the network output
        does not depend on `yhs`."""
        out = self.net(yhs).sum()
        # allow_unused=True yields None (instead of raising) when `yhs` does not
        # reach `out`. The None check below also refines TorchScript's
        # Optional[Tensor] to Tensor for the declared return type.
        dy, = torch.autograd.grad([out], [yhs], allow_unused=True)
        if dy is None:
            dy = torch.zeros_like(yhs)
        return dy

    def forward(self, yhs: torch.Tensor, num_iter: int) -> torch.Tensor:
        """Run `num_iter` steps of yhs <- yhs + 0.1 * d(sum(net(yhs)))/d(yhs).

        Must be called with grad mode enabled (the wrapper returned by
        make_torch_fake_langevin_fast does this). Otherwise net's output has
        no graph for autograd.grad to differentiate.
        """
        # Cut any link to the caller's graph and make the input a leaf that
        # gradients can be taken with respect to.
        yhs = yhs.detach().requires_grad_(True)
        for i in range(num_iter):
            de_dact = self._gradient(yhs)
            yhs = yhs + de_dact * 0.1
        return yhs


def make_torch_fake_langevin_fast(net):
    """Return a callable with torch_fake_langevin's ``(net, yhs, num_iter)``
    signature, backed by a TorchScript-compiled ``TorchFakeLangevin(net)``.

    The network is captured here, so the returned function ignores its own
    `net` argument. Grad mode is switched on around every call because
    TorchFakeLangevin differentiates the network output w.r.t. its input.
    """
    module = TorchFakeLangevin(net)
    with torch.set_grad_enabled(True):
        scripted = torch.jit.script(module)

    def fake_langevin(net, yhs, num_iter):
        # `net` is unused; it is accepted only so this matches the
        # (net, yhs, num_iter) call made by torch_profile.
        with torch.set_grad_enabled(True):
            return scripted(yhs, num_iter)

    return fake_langevin

In [7]:
# Benchmark configuration.
N = 3                       # yhs_init has N * L rows
L = 1
DimY = 2                    # features per row (also the MLP input size)
hidden_sizes = [256] * 2    # widths of the MLP's hidden layers
count = 10                  # timed calls per profile run
num_iter = 100              # gradient steps per fake-Langevin call

# Seed the global NumPy RNG so yhs_init is reproducible across runs. The same
# array is fed to both the TensorFlow and the PyTorch benchmarks.
np.random.seed(0)
yhs_init = np.asarray(np.random.rand(N * L, DimY), dtype=np.float32)

In [8]:
def tf_profile(fake_langevin, num_warmup=1):
    """Print and return the mean wall-clock seconds per call of `fake_langevin`.

    Builds a fresh MLP from the notebook-level `hidden_sizes`. It then calls
    ``fake_langevin(net, yhs, num_iter)`` `num_warmup` times untimed, which
    absorbs one-off costs such as tf.function tracing. Finally it times
    `count` further calls (`count`, `num_iter` and `yhs_init` are
    notebook-level settings).

    Args:
        fake_langevin: callable ``(net, yhs, num_iter) -> tensor``, e.g.
            tf_fake_langevin or tf_fake_langevin_fast.
        num_warmup: number of untimed calls before timing starts. The default
            of 1 matches the original behaviour.

    Returns:
        Mean seconds per timed call.

    Raises:
        ValueError: if `count` < 1 (nothing would be timed).
    """
    if count < 1:
        raise ValueError(f"count must be >= 1, got {count}")

    net = tf_make_mlp(hidden_sizes)
    yhs = tf.convert_to_tensor(yhs_init)
    # One forward pass so Keras creates the layer weights before timing.
    net(yhs)

    for _ in range(num_warmup):
        fake_langevin(net, yhs, num_iter)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t_start = time.perf_counter()
    for _ in range(count):
        result = fake_langevin(net, yhs, num_iter)
    # Copy the last result to host so queued device work has finished before
    # the clock stops. This is redundant while CUDA_LAUNCH_BLOCKING=1, but
    # needed otherwise.
    result.numpy()
    per_call = (time.perf_counter() - t_start) / count
    print(per_call)
    return per_call


@torch.no_grad()
def torch_profile(fake_langevin, is_constructor=False, num_warmup=1):
    """Print and return the mean wall-clock seconds per call of a PyTorch
    fake-Langevin implementation.

    Setup:
    - A fresh MLP (input size `DimY`, notebook-level `hidden_sizes`) is built,
      TorchScript-compiled and moved to the CUDA device along with `yhs_init`.
    - The whole function runs under ``torch.no_grad()``; implementations that
      need autograd re-enable it themselves.

    Timing:
    - `num_warmup` untimed calls are made first.
    - `count` further calls are timed.

    Args:
        fake_langevin: callable ``(net, yhs, num_iter) -> tensor``. When
            `is_constructor` is True, it is instead a factory that takes `net`
            and returns such a callable.
        is_constructor: whether to call ``fake_langevin(net)`` first to obtain
            the callable.
        num_warmup: number of untimed calls before timing starts. The default
            of 1 matches the original behaviour. TorchScript's JIT may need
            more than one warm-up call to reach its optimized graph.

    Returns:
        Mean seconds per timed call.

    Raises:
        ValueError: if `count` < 1 (nothing would be timed).
    """
    if count < 1:
        raise ValueError(f"count must be >= 1, got {count}")

    net = torch.jit.script(torch_make_mlp(DimY, hidden_sizes))
    device = torch.device("cuda")
    net.to(device)
    yhs = torch.from_numpy(yhs_init).to(device)

    if is_constructor:
        fake_langevin = fake_langevin(net)

    for _ in range(num_warmup):
        fake_langevin(net, yhs, num_iter)

    # CUDA kernels run asynchronously. Synchronize so the timer brackets
    # exactly the timed calls. This is redundant while CUDA_LAUNCH_BLOCKING=1,
    # but needed otherwise.
    torch.cuda.synchronize(device)
    t_start = time.perf_counter()
    for _ in range(count):
        fake_langevin(net, yhs, num_iter)
    torch.cuda.synchronize(device)
    per_call = (time.perf_counter() - t_start) / count
    print(per_call)
    return per_call

In [9]:
# TensorFlow: eager loop, then the tf.function-compiled loop.
# NOTE: the eager variant was observed to be ~2x slower here than in the ibc
# venv; the cause is not yet understood.
for tf_variant in (tf_fake_langevin, tf_fake_langevin_fast):
    tf_profile(tf_variant)

# PyTorch: plain autograd loop, then the TorchScript module.
torch_profile(torch_fake_langevin)
# The TorchScript variant is profiled twice. In the recorded output the first
# run (~0.17 s/call) is much slower than the second (~0.03 s/call), presumably
# because the JIT is still profiling/optimizing during the first run's calls.
# TODO: confirm.
for _ in range(2):
    torch_profile(make_torch_fake_langevin_fast, is_constructor=True)

0.14471290111541749
0.012119817733764648
0.14556553363800048
0.17220258712768555
0.030535435676574706
