In [1]:
from generator import Generator
import time
import jax
import jax.numpy as jnp
import torch
import equinox as eqx
# import wandb
from hifigan_decoder import HifiDecoder

torch_hifigan = HifiDecoder(decoder_input_dim=80).waveform_decoder

torch_hifigan.cuda()

def to_cuda(x: torch.Tensor) -> torch.Tensor:
    if x is None:
        return None
    if torch.is_tensor(x):
        x = x.contiguous()
        if torch.cuda.is_available():
            x = x.cuda(non_blocking=True)
    return x

jax_hifigan = Generator(80, 1, key=jax.random.key(1))

In [None]:
import time
import jax
import jax.numpy as jnp
import torch
import csv

# Assuming jax_dvae and dvae are already defined

# Define a range of input sizes
input_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 512+256]

# JIT-compile jax_dvae if not already compiled
jax_hifigan = eqx.filter_jit(jax_hifigan)

# Open a CSV file to store the results
with open('execution_times.csv', mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Input Size', 'JAX Execution Time (s)', 'PyTorch Execution Time (s)'])

    for size in input_sizes:
        # JAX
        batch_jax = jnp.ones((size, 80, 1024)).to_device(jax.devices()[0])
        
        # Warm-up run (to avoid JIT compilation time in the measurement)
        for _ in range(0, 10):
            _ = eqx.filter_vmap(jax_hifigan)(batch_jax)
        
        # Measure execution time
        start_jax = time.perf_counter()
        val_jax = eqx.filter_vmap(jax_hifigan)(batch_jax)
        end_jax = time.perf_counter()
        jax_time = end_jax - start_jax

        # PyTorch
        batch_torch = torch.ones((size, 80, 1024)).cuda()
        for _ in range(0, 10):
            # Warm-up run (to avoid CUDA initialization time in the measurement)
            with torch.no_grad():
                torch_hifigan.eval()
                _ = torch_hifigan(batch_torch)
        
        # Measure execution time
        start_torch = time.perf_counter()
        with torch.no_grad():
            torch_hifigan.eval()
            val_torch = torch_hifigan(batch_torch)
        end_torch = time.perf_counter()
        torch_time = end_torch - start_torch

        # Write the results to the CSV file
        writer.writerow([size, jax_time, torch_time])

        print(f"Input Size: {size}, JAX Time: {jax_time:.6f} s, PyTorch Time: {torch_time:.6f} s")

print("Done! Results saved to execution_times.csv")

[CudaDevice(id=0)]
Input Size: 1, JAX Time: 0.042994 s, PyTorch Time: 0.019998 s
[CudaDevice(id=0)]
