# Benchmarking JAX vs PyTorch 2 (with `torch.compile`)

The aim is to compare the performance of the GPT-2 forward pass in JAX and in PyTorch 2 on a reasonable
consumer-grade GPU from Vast.ai.

In [2]:
# Hacky pip install for cuda on jax (won't install with Poetry)
%pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

## Model Setup

In [3]:
from transformers import FlaxGPT2Model, GPT2Model, GPT2Config
import time
import jax
import torch

# Create GPT-2 models in both frameworks from the default GPT2Config. They are built
# from a config (not `from_pretrained`), so the weights are randomly initialised; this
# is fine for timing, but the two models' outputs are not comparable.
config = GPT2Config()
jax_model = FlaxGPT2Model(config)
pt_model = GPT2Model(config)

# PyTorch needs an explicit device move. eval() disables dropout, giving an
# inference-mode forward pass.
pt_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pt_model = pt_model.to(pt_device).eval()
# torch.compile returns a new compiled module rather than compiling in place, so the
# result must be kept. Compilation itself happens lazily on the first forward call.
pt_model = torch.compile(pt_model)

# JAX has no `.cuda()`: arrays live on its default backend, which is the GPU when a
# CUDA-enabled jaxlib is installed. Show the devices to confirm a GPU is visible.
jax.devices()

ImportError: 
FlaxGPT2Model requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.


In [None]:
# Number of timed forward passes per framework, and tokens per sequence.
sample_size: int = 10
seq_len: int = 1024  # must not exceed config.n_positions (1024 in the default GPT2Config)

# Seeded generator so the benchmark inputs are reproducible across runs.
generator = torch.Generator().manual_seed(0)
# Put the PyTorch inputs on whatever device the model's weights live on. Going through
# parameters() works for both a plain module and a torch.compile-wrapped one.
pt_input_device = next(pt_model.parameters()).device

jax_inputs = []
pt_inputs = []
for _ in range(sample_size):
    # Draw one batch of token ids and feed the *same* ids to both frameworks, so the
    # comparison is not skewed by different inputs. (Reusing a fixed PRNGKey per
    # iteration, as before, made every JAX input identical.)
    token_ids = torch.randint(0, config.vocab_size, (1, seq_len), generator=generator)
    jax_inputs.append(jax.numpy.asarray(token_ids.numpy(), dtype=jax.numpy.int32))
    pt_inputs.append(token_ids.to(pt_input_device))

In [None]:
def time_forward_passes(forward, inputs, wait_for, n_warmup: int = 1) -> float:
    """Return the mean wall-clock seconds per call of ``forward`` over ``inputs``.

    Args:
        forward: Callable that takes one input batch and returns the model output.
        inputs: Non-empty sequence of input batches. Each batch is timed once.
        wait_for: Called on every output. It must block until the computation that
            produced the output has finished. JAX and CUDA dispatch asynchronously,
            so without it the timer would stop before the GPU work is done.
        n_warmup: Number of leading inputs run once, untimed, beforehand. This keeps
            one-off tracing/compilation cost out of the measurement.

    Raises:
        ValueError: If ``inputs`` is empty.
    """
    if not inputs:
        raise ValueError("inputs must be non-empty")
    for warmup_input in inputs[:n_warmup]:
        wait_for(forward(warmup_input))
    start = time.perf_counter()
    for model_input in inputs:
        wait_for(forward(model_input))
    return (time.perf_counter() - start) / len(inputs)


def _wait_for_torch(_output) -> None:
    """Block until queued CUDA kernels have finished. This is a no-op without CUDA."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


# jax.jit compiles the whole forward pass into one XLA program (the JAX analogue of
# torch.compile). Params are passed as an argument rather than closed over, so they are
# not baked into the compiled program as constants.
jax_forward = jax.jit(
    lambda params, input_ids: jax_model(input_ids, params=params).last_hidden_state
)
seconds_per_pass = {
    "jax": time_forward_passes(
        lambda input_ids: jax_forward(jax_model.params, input_ids),
        jax_inputs,
        wait_for=lambda output: output.block_until_ready(),
    ),
}

# no_grad skips building the autograd graph. Warm-up runs inside it too, because
# torch.compile may recompile when grad mode changes.
with torch.no_grad():
    seconds_per_pass["pt"] = time_forward_passes(pt_model, pt_inputs, wait_for=_wait_for_torch)

n_passes = {"jax": len(jax_inputs), "pt": len(pt_inputs)}
for model_name, seconds in seconds_per_pass.items():
    print(
        f"{model_name} took {seconds * 1e3:.2f} ms per forward pass "
        f"(mean of {n_passes[model_name]} passes)"
    )