# In [None]:
import time
import torch
import torch.nn.functional as F
from typing import Callable

print("print cuda stuff")
# Asking for the current device's name raises on machines without CUDA, so
# only query it when a device is actually present.
if torch.cuda.is_available():
    print(True, torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print(False, "no CUDA device available")
print()

##############################################################
# Define the methods
@torch.no_grad()
def compute_runtime(n_iters: int, f: Callable, *args, **kwargs) -> float:
    """Time ``n_iters`` back-to-back calls of ``f(*args, **kwargs)``.

    Gradient tracking is disabled for the whole measurement. When CUDA is
    available, the device is synchronized before the clock starts and again
    before it stops, so work still queued on the GPU is included in the
    measurement. On a machine without CUDA the synchronization is skipped
    instead of aborting, so the same harness can time CPU code.

    Args:
        n_iters: Number of times ``f`` is called.
        f: Callable under test; its return value is discarded.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        Total wall-clock time of all calls, in seconds. The same duration is
        also printed in milliseconds.
    """
    # Synchronizing is only possible (and only needed) with a CUDA device.
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        f(*args, **kwargs)

    if use_cuda:
        # Wait for outstanding GPU work before reading the clock.
        torch.cuda.synchronize()

    duration = time.perf_counter() - start
    print(f"Computation took {duration * 1000:.2f} ms")

    return duration

def log_hat_1(conv_f: torch.Tensor, conv_i: torch.Tensor):
    """Softplus-based variant of the two log-gate terms.

    Args:
        conv_f: First gate pre-activation tensor.
        conv_i: Second gate pre-activation tensor, combined elementwise with
            ``conv_f``.

    Returns:
        A tuple ``(log_f_hat, log_i_hat)`` where
        ``log_f_hat = -softplus(conv_i - conv_f)`` and
        ``log_i_hat = -softplus(conv_f - conv_i)``.
    """
    # -softplus(-x) equals log(sigmoid(x)); each term negates the softplus of
    # the opposite difference.
    forget_term = torch.neg(F.softplus(conv_i - conv_f))
    input_term = torch.neg(F.softplus(conv_f - conv_i))
    return forget_term, input_term

def log_hat_2(conv_f: torch.Tensor, conv_i: torch.Tensor):
    """Sigmoid-then-log variant of the two log-gate terms.

    Args:
        conv_f: First gate pre-activation tensor.
        conv_i: Second gate pre-activation tensor, combined elementwise with
            ``conv_f``.

    Returns:
        A tuple ``(log_f_hat_2, log_i_hat_2)`` where
        ``log_f_hat_2 = log(sigmoid(conv_f - conv_i))`` and
        ``log_i_hat_2 = log(sigmoid(conv_i - conv_f))``.
    """
    forget_prob = F.sigmoid(conv_f - conv_i)
    # Mathematically this is 1 - forget_prob; it is computed with its own
    # sigmoid call rather than by subtraction.
    input_prob = F.sigmoid(conv_i - conv_f)
    return torch.log(forget_prob), torch.log(input_prob)

#############################################################################

# Generate random data on the GPU when one is present (otherwise on the CPU).
# The timing helper synchronizes the CUDA device, so the inputs have to live
# there for the measurement to cover GPU execution.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
B, C, H, W, L = 2, 10, 16, 16, 100
conv_f = torch.rand(B, C, H, W, L, device=device)
conv_i = torch.rand(B, C, H, W, L, device=device)


# Check if the operations are equivalent
log_f_hat, log_i_hat = log_hat_1(conv_f, conv_i)
log_f_hat_2, log_i_hat_2 = log_hat_2(conv_f, conv_i)
print(torch.allclose(log_f_hat, log_f_hat_2), torch.allclose(log_i_hat, log_i_hat_2))

# Runtime analysis: time the same number of calls for both variants
n_iters = 10_000
compute_runtime(n_iters, log_hat_1, conv_f, conv_i)
compute_runtime(n_iters, log_hat_2, conv_f, conv_i)