In [1]:
from VQVAE import VQVAE
import time
import jax
import jax.numpy as jnp
import torch
import equinox as eqx
# import wandb
from dvae import DiscreteVAE

dvae = DiscreteVAE(
            channels=80,
            normalization=None,
            positional_dims=1,
            num_tokens=1024,
            codebook_dim=512,
            hidden_dim=512,
            num_resnet_blocks=3,
            kernel_size=3,
            num_layers=2,
            use_transposed_convs=False,
        )

dvae.cuda()

def to_cuda(x: torch.Tensor) -> torch.Tensor:
    if x is None:
        return None
    if torch.is_tensor(x):
        x = x.contiguous()
        if torch.cuda.is_available():
            x = x.cuda(non_blocking=True)
    return x

jax_dvae = VQVAE(jax.random.key(1))
jax_dvae = eqx.filter_jit(jax_dvae)

In [2]:
import time
import jax
import jax.numpy as jnp
import torch
import csv

# Assuming jax_dvae and dvae are already defined

# Define a range of input sizes
input_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

# JIT-compile jax_dvae if not already compiled
jax_dvae = eqx.filter_jit(jax_dvae)

# Open a CSV file to store the results
with open('execution_times.csv', mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Input Size', 'JAX Execution Time (s)', 'PyTorch Execution Time (s)'])

    for size in input_sizes:
        # JAX
        batch_jax = jnp.ones((size, 80, 1024)).to_device(jax.devices()[0])
        
        # Warm-up run (to avoid JIT compilation time in the measurement)
        _ = jax.vmap(jax_dvae)(batch_jax)
        
        # Measure execution time
        start_jax = time.perf_counter()
        val_jax = jax.vmap(jax_dvae)(batch_jax)
        end_jax = time.perf_counter()
        jax_time = end_jax - start_jax

        # PyTorch
        batch_torch = torch.ones((size, 80, 1024)).cuda()
        
        # Warm-up run (to avoid CUDA initialization time in the measurement)
        with torch.no_grad():
            dvae.eval()
            _ = dvae(batch_torch)
        
        # Measure execution time
        start_torch = time.perf_counter()
        with torch.no_grad():
            dvae.eval()
            val_torch = dvae(batch_torch)
        end_torch = time.perf_counter()
        torch_time = end_torch - start_torch

        # Write the results to the CSV file
        writer.writerow([size, jax_time, torch_time])

        print(f"Input Size: {size}, JAX Time: {jax_time:.6f} s, PyTorch Time: {torch_time:.6f} s")

print("Done! Results saved to execution_times.csv")

Input Size: 1, JAX Time: 0.005542 s, PyTorch Time: 0.004477 s
Input Size: 2, JAX Time: 0.005798 s, PyTorch Time: 0.003943 s
Input Size: 4, JAX Time: 0.006465 s, PyTorch Time: 0.003793 s
Input Size: 8, JAX Time: 0.007729 s, PyTorch Time: 0.003713 s
Input Size: 16, JAX Time: 0.009797 s, PyTorch Time: 0.003790 s
Input Size: 32, JAX Time: 0.013586 s, PyTorch Time: 0.003755 s
Input Size: 64, JAX Time: 0.021880 s, PyTorch Time: 0.003728 s
Input Size: 128, JAX Time: 0.038111 s, PyTorch Time: 0.003786 s
Input Size: 256, JAX Time: 0.075985 s, PyTorch Time: 0.003788 s
Input Size: 512, JAX Time: 0.147908 s, PyTorch Time: 0.003854 s
Input Size: 1024, JAX Time: 0.293042 s, PyTorch Time: 0.697063 s
Done! Results saved to execution_times.csv
