# Benchmarking Jax vs PyTorch 2 (with compile)

The aim is to investigate performance on a reasonable consumer GPU from VastAI, for the forward pass
of GPT2.

In [1]:
# Hacky pip install for cuda on jax (won't install with Poetry)
%pip uninstall jax -y
%pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Install CUDA Compat
%sudo apt-get install -y cuda-compat-12-1

Found existing installation: jax 0.4.19
Uninstalling jax-0.4.19:
  Successfully uninstalled jax-0.4.19
Note: you may need to restart the kernel to use updated packages.
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda12_pip]
  Using cached jax-0.4.19-py3-none-any.whl.metadata (23 kB)
Using cached jax-0.4.19-py3-none-any.whl (1.7 MB)
Installing collected packages: jax
Successfully installed jax-0.4.19
Note: you may need to restart the kernel to use updated packages.


## Model Setup

In [11]:
from transformers import FlaxGPT2Model, GPT2Model, GPT2Config
import time
import jax
import torch

# Create the models
config = GPT2Config()
jax_model = FlaxGPT2Model(config)
pt_model = GPT2Model(config)

# Move models to CUDA & compile
print(jax.devices())
pt_model = pt_model.cuda()
torch.compile(pt_model)

jax_forward = jax.jit(jax_model.__call__)

[cuda(id=0)]


In [12]:
jax_inputs = []
pt_inputs = []

sample_size: int = 10
for _ in range(sample_size):
    jax_input = jax.random.randint(
        jax.random.PRNGKey(0), (1, 1024), 0, config.vocab_size
    )
    jax_inputs.append(jax_input)

    pt_input = torch.randint(0, config.vocab_size, (1, 1024)).cuda()
    pt_inputs.append(pt_input)

In [13]:
for model_name in ["jax", "pt"]:
    start = time.time()

    for i in range(sample_size):
        inputs = jax_inputs[i] if model_name == "jax" else pt_inputs[i]
        model = jax_forward if model_name == "jax" else pt_model
        model(inputs)

    print(f"{model_name} took {time.time() - start:.2f}s")

jax took 6.09s
pt took 0.15s
