In [1]:
import os
# Set before torch / tensorflow are imported (below) so these settings are in
# effect when those libraries initialize CUDA and logging.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is a CUDA debugging flag that makes every
# kernel launch synchronous. It slows down all GPU work in this process and can
# distort the relative timings measured later in this notebook; consider
# removing it for benchmarking and synchronizing explicitly when timing instead.
os.environ.update(
    CUDA_LAUNCH_BLOCKING="1",
    # Expose only the physical GPU with index 1 to this process.
    CUDA_VISIBLE_DEVICES="1",
    # Hide TensorFlow's C++ INFO and WARNING log messages.
    TF_CPP_MIN_LOG_LEVEL="2",
)

import time

import numpy as np

import torch
from torch import nn
import tensorflow.compat.v2 as tf

In [2]:
# Let TensorFlow grow GPU memory on demand instead of reserving the whole
# device up front, so PyTorch can allocate on the same GPU afterwards.
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)

# Keep tf.function-wrapped code executing as compiled graphs (not eagerly).
tf.config.run_functions_eagerly(False)

In [3]:
def tf_make_mlp(hidden_sizes, activation=None):
    """Build a Keras MLP that maps each input vector to a single scalar.

    Args:
        hidden_sizes: Widths of the hidden ``Dense`` layers, in order.
        activation: Optional activation applied after every hidden layer
            (anything ``tf.keras.layers.Dense`` accepts, e.g. ``"relu"``).
            The default ``None`` keeps every layer linear, in which case the
            whole network is an affine map and its input gradient is constant.

    Returns:
        A ``tf.keras.Sequential``. No input shape is declared, so the layers
        are built on the first call.
    """
    net = tf.keras.Sequential()
    for hidden_size in hidden_sizes:
        net.add(tf.keras.layers.Dense(hidden_size, activation=activation))
    # The scalar output head is always linear.
    net.add(tf.keras.layers.Dense(1))
    return net


def torch_make_mlp(input_size, hidden_sizes, activation=None):
    """Build a PyTorch MLP mapping ``input_size`` features to a single scalar.

    Args:
        input_size: Size of the last (feature) dimension of the input.
        hidden_sizes: Widths of the hidden ``nn.Linear`` layers, in order.
        activation: Optional zero-argument callable returning an ``nn.Module``
            (e.g. ``nn.ReLU``). A fresh instance is inserted after every hidden
            layer. The default ``None`` keeps the network purely linear, which
            makes it an affine map with a constant input gradient.

    Returns:
        An ``nn.Sequential`` whose final layer is ``nn.Linear(..., 1)``.
    """
    layers = []
    prev_size = input_size
    for hidden_size in hidden_sizes:
        layers.append(nn.Linear(prev_size, hidden_size))
        if activation is not None:
            layers.append(activation())
        prev_size = hidden_size
    # The scalar output head is always linear.
    layers.append(nn.Linear(prev_size, 1))
    return nn.Sequential(*layers)

In [4]:
def tf_gradient(f, x):
    """Return the gradient of ``f(x, training=False)`` with respect to ``x``.

    ``x`` is watched explicitly, so it may be a plain tensor rather than a
    ``tf.Variable``. Per the ``tf.GradientTape`` documentation, a non-scalar
    output is summed before differentiating.
    """
    with tf.GradientTape() as tape:
        tape.watch(x)
        output = f(x, training=False)
    return tape.gradient(output, x)


def torch_gradient(f: nn.Module, x, create_graph: bool):
    """Return the gradient of ``f(x).sum()`` with respect to ``x``.

    Summing the outputs first means one backward pass yields each batch
    element's own input gradient, as long as ``f`` treats batch elements
    independently.

    Args:
        f: Module to differentiate.
        x: Input tensor. It must require grad.
        create_graph: Forwarded to ``torch.autograd.grad``. When True, the
            returned gradient is itself differentiable.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # Adapted from:
    # https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5
    # TODO(eric.cousineau): Keep storage for dnet_dx?
    total = f(x).sum()
    (grad_x,) = torch.autograd.grad(total, x, create_graph=create_graph)
    return grad_x

In [5]:
def tf_fake_langevin(net, yhs, num_iter, step_size=0.1):
    """Take ``num_iter`` gradient steps on ``yhs`` uphill on ``net``'s output.

    Each step is ``yhs <- yhs + step_size * d net(yhs) / d yhs``. No noise
    term is added (presumably why this is a "fake" Langevin sampler).

    Args:
        net: Keras model called as ``net(yhs, training=False)``.
        yhs: Initial samples (tensor).
        num_iter: Number of update steps.
        step_size: Multiplier applied to the gradient at every step.
            Defaults to 0.1.

    Returns:
        The updated samples.
    """
    for _ in range(num_iter):
        de_dact = tf_gradient(net, yhs)
        yhs = yhs + de_dact * step_size
    return yhs


def torch_fake_langevin(net, yhs, num_iter, step_size=0.1):
    """Take ``num_iter`` gradient steps on ``yhs`` uphill on ``net``'s output.

    Each step is ``yhs <- yhs + step_size * d net(yhs).sum() / d yhs``.
    No noise term is added.

    The input is detached first, so gradients never flow back into the
    caller's tensor. If autograd is enabled at call time, the gradient steps
    are recorded (``create_graph=True``) so the result can be backpropagated
    to ``net``'s parameters. Otherwise (e.g. under ``torch.no_grad()``) the
    result is returned detached.

    Args:
        net: Module mapping samples to scalars.
        yhs: Initial samples (tensor).
        num_iter: Number of update steps.
        step_size: Multiplier applied to the gradient at every step.
            Defaults to 0.1.

    Returns:
        The updated samples.
    """
    create_graph = torch.is_grad_enabled()
    yhs = yhs.detach().requires_grad_(True)
    # Input gradients need autograd even when the caller has disabled it.
    with torch.set_grad_enabled(True):
        for _ in range(num_iter):
            de_dact = torch_gradient(net, yhs, create_graph)
            yhs = yhs + step_size * de_dact
    if not create_graph:
        # Without create_graph the steps are not differentiable anyway. Drop
        # the chain of addition nodes so no_grad callers get a plain tensor.
        yhs = yhs.detach()
    return yhs

In [6]:
# Benchmark configuration.
hidden_sizes = [256, 256]  # widths of the MLP hidden layers
count = 10                 # number of timed repetitions per benchmark
num_iter = 100             # gradient steps per fake-Langevin call

# Initial samples with shape (N, L, DimY); the last axis is the MLP input size.
N, L, DimY = 32, 8, 2
np.random.seed(0)
yhs_init = np.random.rand(N, L, DimY).astype(np.float32)


def tf_profile(use_jit):
    """Time ``tf_fake_langevin`` on a fresh TF MLP and print seconds per call.

    Args:
        use_jit: If True, wrap the Langevin loop in ``tf.function`` (graph
            mode; since ``num_iter`` is a Python int, the Python loop is
            unrolled at trace time). Otherwise run it eagerly.

    Reads the module-level ``yhs_init``, ``hidden_sizes``, ``count`` and
    ``num_iter``. One untimed warm-up call (which includes any tracing) runs
    before ``count`` timed calls.
    """
    net = tf_make_mlp(hidden_sizes)

    yhs = tf.convert_to_tensor(yhs_init)
    # Build the layers now; the input size is inferred on the first call.
    net(yhs)

    fake_langevin = tf_fake_langevin
    if use_jit:
        fake_langevin = tf.function(fake_langevin)

    # Warm-up. .numpy() copies to the host, which blocks until the device has
    # finished, so no pending work leaks into the timed region.
    fake_langevin(net, yhs, num_iter).numpy()

    t_start = time.perf_counter()
    result = None
    for _ in range(count):
        result = fake_langevin(net, yhs, num_iter)
    # GPU ops can still be running asynchronously; wait for the last result
    # before reading the clock.
    if result is not None:
        result.numpy()
    dt = time.perf_counter() - t_start
    print(dt / count)


@torch.no_grad()
def torch_profile(device=None):
    """Time ``torch_fake_langevin`` on a fresh PyTorch MLP; print seconds per call.

    Runs under ``torch.no_grad()``, so ``torch_fake_langevin`` does not record
    a differentiable graph through its steps (``create_graph=False``).

    Args:
        device: Device to run on. Defaults to CUDA when available, else CPU.

    Reads the module-level ``yhs_init``, ``DimY``, ``hidden_sizes``, ``count``
    and ``num_iter`` (the same settings the TF benchmark reads). One untimed
    warm-up call runs before ``count`` timed calls.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # NOTE: the model is run eagerly; unlike tf_profile(True), no
    # torch.jit / compiled variant is benchmarked here.
    net = torch_make_mlp(DimY, hidden_sizes).to(device)
    yhs = torch.from_numpy(yhs_init).to(device)

    def sync():
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    # Warm-up (CUDA context / library initialization, allocator caching).
    torch_fake_langevin(net, yhs, num_iter)
    sync()

    t_start = time.perf_counter()
    for _ in range(count):
        torch_fake_langevin(net, yhs, num_iter)
    sync()
    dt = time.perf_counter() - t_start
    print(dt / count)

# TF with tf.function (graph mode), TF eager, then PyTorch eager.
for use_jit in (True, False):
    tf_profile(use_jit)
torch_profile()

0.013439536094665527
0.28806533813476565
0.2407933473587036
