# Benchmarking Jax vs PyTorch 2 (with compile)

The aim is to investigate performance on a reasonable consumer GPU from VastAI, for the forward pass
of GPT2.

## Model Setup

In [1]:
from transformers import FlaxGPT2Model, GPT2Model, GPT2Config
import time
import jax
import torch

# Create the models
config = GPT2Config()
jax_model = FlaxGPT2Model(config)
pt_model = GPT2Model(config)

# Move models to CUDA & compile
# Note you have to install the gpu version of jax (`poetry add jax[cuda12_pip]`)
# jax_model = jax_model.cuda()
# pt_model = pt_model.cuda()
# torch.compile(pt_model)

In [5]:
jax_inputs = []
pt_inputs = []

sample_size: int = 10
for _ in range(sample_size):
    jax_input = jax.random.randint(
        jax.random.PRNGKey(0), (1, 1024), 0, config.vocab_size
    )
    jax_inputs.append(jax_input)

    pt_input = torch.randint(0, config.vocab_size, (1, 1024))
    pt_inputs.append(pt_input)

In [7]:
for model_name in ["jax", "pt"]:
    start = time.time()

    for i in range(sample_size):
        inputs = jax_inputs[i] if model_name == "jax" else pt_inputs[i]
        model = jax_model if model_name == "jax" else pt_model
        model(inputs)

    print(f"{model} took {time.time() - start:.2f}s")

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
) took 10.85s
GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attent