# Implementing GPT-2 124M with JAX, FLAX, and NNX

## Introduction

In this Jupyter notebook, we will implement the GPT-2 124M (small) model, a transformer-based language model developed by OpenAI, using JAX and FLAX with the NNX API. Our goal is to:

- Build the GPT-2 model architecture from scratch using FLAX's NNX API.
- Load pre-trained weights from Hugging Face's GPT-2 124M model.
- Optimize the model with JAX's JIT compilation.
- Demonstrate its usage by computing logits and generating text from a sample prompt.

This notebook serves as both a practical guide and an educational resource, walking you through each step with clear code and explanations. We'll assume some familiarity with Python and machine learning, but we'll explain JAX, FLAX, and transformer concepts as needed.

Let's get started!

## Setup

First, we need to install the necessary libraries and import the required modules.

### Install Libraries

Run the following commands in your terminal or a notebook cell to install the dependencies. You may need to adjust the JAX installation based on your hardware (CPU, GPU, or TPU).


In [1]:
!pip install jax jaxlib flax transformers
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Collecting libtpu==0.0.10.* (from jax[tpu])
  Downloading libtpu-0.0.10.1-py3-none-manylinux_2_27_x86_64.whl.metadata (202 bytes)
Downloading libtpu-0.0.10.1-py3-none-manylinux_2_27_x86_64.whl (129.7 MB)
Installing collected packages: libtpu
  Attempting uninstall: libtpu
    Found existing installation: libtpu 0.0.7.1
    Uninstalling libtpu-0.0.7.1:
      Successfully uninstalled libtpu-0.0.7.1
Successfully installed libtpu-0.0.10.1


### Import Modules

Now, let's import the modules we'll use throughout the notebook.
- **JAX**: Provides high-performance numerical computing and automatic differentiation.
- **FLAX**: A neural network library built on JAX, with NNX being a modern API for building models.
- **Transformers**: Hugging Face's library for accessing pre-trained models and tokenizers.
- **NumPy and Torch**: Used for weight conversions between PyTorch (Hugging Face) and JAX.


In [2]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from transformers import GPT2Config, GPT2Tokenizer, GPT2Model as HF_GPT2Model
from transformers import GPT2LMHeadModel
import torch
import numpy as np

print(jax.devices())  # Confirm which accelerators JAX can see


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]



## Model Architecture

Here, we'll define the GPT-2 model architecture using FLAX's NNX API, ensuring it matches Hugging Face's implementation. GPT-2 is a decoder-only transformer with the following components:

- **Token and Positional Embeddings**: Convert input tokens and their positions into dense vectors.
- **Transformer Blocks**: A stack of layers, each containing multi-head self-attention and a feed-forward network, both with layer normalization and residual connections.
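- **Final Layer Normalization**: Normalizes the output of the last transformer block.
- **Language-Modeling Head**: A bias-free linear layer that projects the final hidden states to vocabulary logits. In GPT-2 its weights are tied to the token embeddings.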

We'll break this down into modular components.


### Define the Components

#### Multi-Head Attention

The multi-head attention mechanism lets the model focus on different parts of the input sequence, with each head attending in its own subspace. GPT-2 uses causal (masked) attention so that each token attends only to itself and earlier tokens.


**Notes**:
- `c_attn` projects the input to query (Q), key (K), and value (V) vectors in one go.
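- The causal mask (`jnp.tril`) lets each position attend only to itself and earlier positions; masked scores are set to a large negative value (`-1e9`) so they receive near-zero weight after the softmax.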
- Dropout is included for regularization, disabled during inference with `deterministic=True`.



In [3]:
class MultiHeadAttention(nnx.Module):
    def __init__(self, n_embd: int, n_head: int, attn_pdrop: float, resid_pdrop: float, *, rngs: nnx.Rngs):
        self.c_attn = nnx.Linear(n_embd, 3 * n_embd, rngs=rngs)  # Combined Q, K, V projection
        self.c_proj = nnx.Linear(n_embd, n_embd, rngs=rngs)      # Output projection
        self.attn_dropout = nnx.Dropout(attn_pdrop, rngs=rngs)
        self.resid_dropout = nnx.Dropout(resid_pdrop, rngs=rngs)
        self.n_head = n_head
        self.n_embd = n_embd

    def __call__(self, x: jnp.ndarray, rngs: nnx.Rngs, deterministic: bool = False):
        batch, seq_len, _ = x.shape
        head_dim = self.n_embd // self.n_head

        # Compute Q, K, V in one linear transformation and split
        qkv = self.c_attn(x)  # (batch, seq_len, 3 * n_embd)
        q, k, v = jnp.split(qkv, 3, axis=-1)  # Each (batch, seq_len, n_embd)

        # Reshape for multi-head attention
        q = q.reshape(batch, seq_len, self.n_head, head_dim).transpose(0, 2, 1, 3)
        k = k.reshape(batch, seq_len, self.n_head, head_dim).transpose(0, 2, 1, 3)
        v = v.reshape(batch, seq_len, self.n_head, head_dim).transpose(0, 2, 1, 3)

        # Attention scores with causal masking
        scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / jnp.sqrt(head_dim)
        mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        scores = jnp.where(mask, scores, -1e9)
        attn_weights = jax.nn.softmax(scores, axis=-1)
        attn_weights = self.attn_dropout(attn_weights, rngs=rngs, deterministic=deterministic)

        # Weighted sum of values
        out = jnp.matmul(attn_weights, v)
        out = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, self.n_embd)

        # Output projection
        out = self.c_proj(out)
        out = self.resid_dropout(out, rngs=rngs, deterministic=deterministic)
        return out
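To see the causal mask in isolation, the sketch below applies the same `jnp.tril` / `-1e9` scheme used in `MultiHeadAttention` to random scores for a single head (the sequence length of 4 is an arbitrary choice for illustration). Every row of the resulting weight matrix sums to 1, and all entries above the diagonal, which correspond to future tokens, are zero.

```python
import jax
import jax.numpy as jnp

seq_len = 4
# Random attention scores for one head, shape (seq_len, seq_len)
scores = jax.random.normal(jax.random.PRNGKey(0), (seq_len, seq_len))

# Same masking scheme as MultiHeadAttention: keep positions j <= i, push the rest to -1e9
mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
weights = jax.nn.softmax(jnp.where(mask, scores, -1e9), axis=-1)

print(jnp.round(weights, 3))
```

The first token can only attend to itself, so its row is `[1, 0, 0, 0]`; the last token spreads its attention over the whole prefix.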


#### Feed-Forward Network

The feed-forward network (FFN) applies two linear transformations with a GELU activation in between.


**Notes**:
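- `intermediate_size` is 4 * `n_embd` in GPT-2 (3072 for `n_embd=768`); the model falls back to this value because `config.n_inner` is `None` for the `gpt2` checkpoint.
- GELU adds the nonlinearity between the two projections, and dropout is applied after the output projection. `nnx.gelu` defaults to the tanh approximation of GELU, which is the variant GPT-2 uses.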


In [4]:
class FeedForward(nnx.Module):
    def __init__(self, n_embd: int, intermediate_size: int, resid_pdrop: float, *, rngs: nnx.Rngs):
        self.c_fc = nnx.Linear(n_embd, intermediate_size, rngs=rngs)  # Expansion
        self.c_proj = nnx.Linear(intermediate_size, n_embd, rngs=rngs)  # Projection back
        self.act = nnx.gelu
        self.dropout = nnx.Dropout(resid_pdrop, rngs=rngs)

    def __call__(self, x: jnp.ndarray, rngs: nnx.Rngs, deterministic: bool = False):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x, rngs=rngs, deterministic=deterministic)
        return x
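GPT-2 checkpoints on Hugging Face specify `activation_function="gelu_new"`, the tanh approximation of GELU. The sketch below checks numerically that `jax.nn.gelu` with its default `approximate=True` matches that formula, assuming (as in current Flax releases) that `nnx.gelu` re-exports `jax.nn.gelu`. The helper `gelu_tanh` is written here for illustration.

```python
import jax
import jax.numpy as jnp

def gelu_tanh(x):
    # Tanh approximation of GELU, the variant GPT-2 uses ("gelu_new" in Hugging Face configs)
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))

x = jnp.linspace(-5.0, 5.0, 101)
# jax.nn.gelu defaults to approximate=True, i.e. the tanh formula above
diff_tanh = jnp.max(jnp.abs(jax.nn.gelu(x) - gelu_tanh(x)))
# The exact erf-based GELU differs slightly from the tanh approximation
diff_exact = jnp.max(jnp.abs(jax.nn.gelu(x, approximate=False) - gelu_tanh(x)))
print(diff_tanh, diff_exact)
```

Using the exact variant would not break the model, but it would introduce small, avoidable deviations from the reference logits.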

#### Transformer Block

Each transformer block combines multi-head attention and a feed-forward network with residual connections and layer normalization.



**Notes**:
- Layer normalization is applied before attention and FFN (pre-LN transformer design).
- Residual connections help with gradient flow during training.


In [5]:
class TransformerBlock(nnx.Module):
    def __init__(self, n_embd: int, n_head: int, intermediate_size: int, attn_pdrop: float, resid_pdrop: float, *, rngs: nnx.Rngs):
        # GPT-2 uses layer_norm_epsilon=1e-5; nnx.LayerNorm defaults to 1e-6
        self.ln_1 = nnx.LayerNorm(n_embd, epsilon=1e-5, rngs=rngs)
        self.attn = MultiHeadAttention(n_embd, n_head, attn_pdrop, resid_pdrop, rngs=rngs)
        self.ln_2 = nnx.LayerNorm(n_embd, epsilon=1e-5, rngs=rngs)
        self.mlp = FeedForward(n_embd, intermediate_size, resid_pdrop, rngs=rngs)

    def __call__(self, x: jnp.ndarray, rngs: nnx.Rngs, deterministic: bool = False):
        residual = x
        x = self.ln_1(x)
        x = self.attn(x, rngs, deterministic)
        x = x + residual  # Residual connection
        residual = x
        x = self.ln_2(x)
        x = self.mlp(x, rngs, deterministic)
        x = x + residual  # Residual connection
        return x


#### GPT-2 Model



**Architecture Visualization** (Text-based):

```
Input IDs (batch, seq_len)
  ↓
[Token Embeddings] + [Positional Embeddings] → [Dropout]
  ↓
[Transformer Block 1] → [Transformer Block 2] → ... → [Transformer Block N]
  ↓
[Final LayerNorm]
  ↓
[LM Head (Linear, no bias)]
  ↓
Logits (batch, seq_len, vocab_size)
```

Where each `Transformer Block` is:

```
Input → [LayerNorm 1] → [Multi-Head Attention] → [Add Residual] → [LayerNorm 2] → [Feed-Forward] → [Add Residual] → Output
```

**Notes**:
- The model mirrors Hugging Face's `GPT2LMHeadModel`: after the final LayerNorm, `lm_head` projects the hidden states to logits of shape `(batch, seq_len, vocab_size)`.
- Config parameters (e.g., `n_embd`, `n_layer`) are sourced from `GPT2Config`.


In [6]:
class GPT2Model(nnx.Module):
    def __init__(self, config: GPT2Config, *, rngs: nnx.Rngs):
        self.wte = nnx.Embed(config.vocab_size, config.n_embd, rngs=rngs)
        self.wpe = nnx.Embed(config.n_positions, config.n_embd, rngs=rngs)
        self.drop = nnx.Dropout(config.embd_pdrop, rngs=rngs)
        self.h = [
            TransformerBlock(
                n_embd=config.n_embd,
                n_head=config.n_head,
                intermediate_size=config.n_inner or 4 * config.n_embd,
                attn_pdrop=config.attn_pdrop,
                resid_pdrop=config.resid_pdrop,
                rngs=rngs
            ) for _ in range(config.n_layer)
        ]
        self.ln_f = nnx.LayerNorm(config.n_embd, epsilon=1e-5, rngs=rngs)
        self.lm_head = nnx.Linear(config.n_embd, config.vocab_size, use_bias=False, rngs=rngs)

    def __call__(self, input_ids: jnp.ndarray, rngs: nnx.Rngs, deterministic: bool = False):

        batch_size, seq_len = input_ids.shape
        position_ids = jnp.arange(seq_len)  # Shape: (seq_len,)

        # Token embeddings
        token_emb = self.wte(input_ids)  # Shape: (batch_size, seq_len, n_embd)

        # Positional embeddings
        position_emb = self.wpe(position_ids)  # Shape: (seq_len, n_embd)

        # Ensure position_emb is broadcasted to match token_emb
        position_emb = position_emb[None, :, :]  # Add batch dimension: (1, seq_len, n_embd)

        # Combine embeddings
        x = token_emb + position_emb  # Shape: (batch_size, seq_len, n_embd)

        # Apply dropout
        x = self.drop(x, rngs=rngs, deterministic=deterministic)

        # Pass through transformer blocks
        for i, block in enumerate(self.h):
            x = block(x, rngs, deterministic)

        # Final LayerNorm
        x = self.ln_f(x)  # Shape should still be (batch_size, seq_len, n_embd)

        # Apply lm_head
        logits = self.lm_head(x)  # Shape: (batch_size, seq_len, vocab_size)

        return logits
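As a sanity check on the "124M" in the name, the parameter count can be derived by hand from the layer shapes above. The sketch hard-codes the `gpt2` configuration values (`vocab_size=50257`, `n_positions=1024`, `n_embd=768`, `n_layer=12`) instead of reading them from `GPT2Config`, and `linear` is a hypothetical helper that counts a weight matrix plus optional bias.

```python
# GPT-2 small hyperparameters (values of GPT2Config for the "gpt2" checkpoint)
vocab_size, n_positions, n_embd, n_layer = 50257, 1024, 768, 12
n_inner = 4 * n_embd

def linear(n_in: int, n_out: int, bias: bool = True) -> int:
    # Weight matrix plus optional bias vector
    return n_in * n_out + (n_out if bias else 0)

embeddings = vocab_size * n_embd + n_positions * n_embd   # wte + wpe
per_block = (
    2 * n_embd                      # ln_1 scale + bias
    + linear(n_embd, 3 * n_embd)    # attn.c_attn
    + linear(n_embd, n_embd)        # attn.c_proj
    + 2 * n_embd                    # ln_2 scale + bias
    + linear(n_embd, n_inner)       # mlp.c_fc
    + linear(n_inner, n_embd)       # mlp.c_proj
)
ln_f = 2 * n_embd

# Hugging Face ties lm_head to wte, so it adds no parameters there
tied_total = embeddings + n_layer * per_block + ln_f
# The NNX model above stores lm_head as a separate nnx.Linear
untied_total = tied_total + linear(n_embd, vocab_size, bias=False)

print(f"{tied_total:,}")    # 124,439,808
print(f"{untied_total:,}")  # 163,037,184
```

The gap between the two totals is exactly one copy of the `(vocab_size, n_embd)` embedding table: because `lm_head` is its own `nnx.Linear`, the NNX model holds the loaded head weights separately instead of sharing memory with `wte`.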

## Loading Pre-trained Weights

We'll load the pre-trained weights from Hugging Face's GPT-2 124M model and map them to our NNX model.

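### Load Hugging Face Model and Config

This cell loads the base `GPT2Model`, whose state dict has no `lm_head` and whose keys carry no `transformer.` prefix. The weight-mapping cell below reloads `GPT2LMHeadModel` instead and overwrites `hf_state_dict`.
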
In [7]:
# Load configuration and pre-trained model
config = GPT2Config.from_pretrained("gpt2")  # GPT-2 small (124M parameters)
hf_model = HF_GPT2Model.from_pretrained("gpt2")
hf_state_dict = hf_model.state_dict()  # PyTorch state dictionary


### Initialize NNX Model


In [8]:
# Initialize NNX model with RNGs
rngs = nnx.Rngs(0)  # Seed 0, arbitrary since weights are overwritten
model = GPT2Model(config, rngs=rngs)

### Map Weights

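Hugging Face's GPT-2 implements its attention and MLP projections (`c_attn`, `c_proj`, `c_fc`) with a `Conv1D` module whose weight is stored as `(in_features, out_features)`, the same layout FLAX's `nnx.Linear` uses for its kernel, so these weights can be copied without transposing. The exception is `lm_head`, a standard PyTorch `nn.Linear` whose weight is `(out_features, in_features)` = `(vocab_size, n_embd)` and must be transposed. Embedding tables are `(vocab_size, n_embd)` in both libraries. We'll also convert PyTorch tensors to JAX arrays.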



**Pitfalls to Watch**:
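- **Transposition**: Only `lm_head` (a PyTorch `nn.Linear`) needs transposing; the `Conv1D` weights (`c_attn`, `c_proj`, `c_fc`) are already `(in, out)`. A shape check cannot catch a wrong orientation for square matrices such as the attention `c_proj` (768 × 768), so a mistake there fails silently.
- **Naming**: Parameter names must align exactly with Hugging Face's state dict keys. `GPT2LMHeadModel` prefixes backbone keys with `transformer.` (e.g. `transformer.h.0.ln_1.weight`), while `GPT2Model` does not.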
- **Dropout**: Dropout layers have no weights, so they don’t need loading.
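The layout difference can be illustrated with plain NumPy arrays standing in for the real modules (no PyTorch or Hugging Face classes are used here; the shapes are arbitrary). A PyTorch `nn.Linear` computes `x @ W.T`, a Hugging Face `Conv1D` computes `x @ W`, and `nnx.Linear` computes `x @ kernel`, so only the `nn.Linear` weight needs a transpose. The last lines show why a square matrix in the wrong orientation still passes any shape check.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out = 4, 3
x = rng.standard_normal((2, n_in))

# PyTorch nn.Linear (like GPT-2's lm_head): weight is (out_features, in_features), y = x @ W.T
w_linear = rng.standard_normal((n_out, n_in))
y_linear = x @ w_linear.T

# Hugging Face Conv1D (c_attn, c_proj, c_fc): weight is (in_features, out_features), y = x @ W
w_conv1d = w_linear.T.copy()
y_conv1d = x @ w_conv1d

# nnx.Linear computes y = x @ kernel with kernel shaped (in_features, out_features)
kernel_from_linear = w_linear.T   # nn.Linear -> transpose
kernel_from_conv1d = w_conv1d     # Conv1D    -> copy as-is

assert np.allclose(x @ kernel_from_linear, y_linear)
assert np.allclose(x @ kernel_from_conv1d, y_conv1d)

# A square weight has the same shape in both orientations, but not the same output
w_square = rng.standard_normal((n_in, n_in))
print(w_square.shape == w_square.T.shape)           # True
print(np.allclose(x @ w_square, x @ w_square.T))    # False for a generic matrix
```

For square layers, the only reliable check is numerical: compare the converted model's outputs against the reference implementation on the same input.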

---

In [9]:
config = GPT2Config.from_pretrained("gpt2")
hf_model = GPT2LMHeadModel.from_pretrained("gpt2")  # Use LMHeadModel to ensure lm_head weights are included
hf_state_dict = hf_model.state_dict()

rngs = nnx.Rngs(0)
model = GPT2Model(config, rngs=rngs)

print("Hugging Face state dict keys:", list(hf_state_dict.keys()))

def update_model_state(model, hf_state_dict):
    state = nnx.state(model)

    # GPT2LMHeadModel prefixes all backbone keys with 'transformer.'
    transformer_prefix = 'transformer.'

    # Embeddings
    state.wte.embedding = nnx.Variable(jnp.array(hf_state_dict[f'{transformer_prefix}wte.weight'].cpu().numpy()))
    state.wpe.embedding = nnx.Variable(jnp.array(hf_state_dict[f'{transformer_prefix}wpe.weight'].cpu().numpy()))

    # Transformer blocks
    for i in range(config.n_layer):
        prefix = f'{transformer_prefix}h.{i}'
        state.h[i].ln_1.scale = nnx.Variable(jnp.array(hf_state_dict[f'{prefix}.ln_1.weight'].cpu().numpy()))
        state.h[i].ln_1.bias = nnx.Variable(jnp.array(hf_state_dict[f'{prefix}.ln_1.bias'].cpu().numpy()))
        state.h[i].ln_2.scale = nnx.Variable(jnp.array(hf_state_dict[f'{prefix}.ln_2.weight'].cpu().numpy()))
        state.h[i].ln_2.bias = nnx.Variable(jnp.array(hf_state_dict[f'{prefix}.ln_2.bias'].cpu().numpy()))

        # Attention weights (HF Conv1D already stores (in_features, out_features), so no transpose is expected)
        attn_weight = jnp.array(hf_state_dict[f'{prefix}.attn.c_attn.weight'].cpu().numpy())
        if attn_weight.shape != (config.n_embd, 3 * config.n_embd):
            attn_weight = attn_weight.T  # Transpose if needed to (n_embd, 3 * n_embd)
        assert attn_weight.shape == (config.n_embd, 3 * config.n_embd), f"Expected {(config.n_embd, 3 * config.n_embd)}, got {attn_weight.shape}"
        state.h[i].attn.c_attn.kernel = nnx.Variable(attn_weight)
        state.h[i].attn.c_attn.bias = nnx.Variable(jnp.array(hf_state_dict[f'{prefix}.attn.c_attn.bias'].cpu().numpy()))

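        # Attention output projection: square (n_embd, n_embd), so the shape check cannot detect orientation; Conv1D layout is copied as-is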
        c_proj_weight = jnp.array(hf_state_dict[f'{prefix}.attn.c_proj.weight'].cpu().numpy())
        if c_proj_weight.shape != (config.n_embd, config.n_embd):
            c_proj_weight = c_proj_weight.T
        state.h[i].attn.c_proj.kernel = nnx.Variable(c_proj_weight)
        state.h[i].attn.c_proj.bias = nnx.Variable(jnp.array(hf_state_dict[f'{prefix}.attn.c_proj.bias'].cpu().numpy()))

        # MLP weights
        c_fc_weight = jnp.array(hf_state_dict[f'{prefix}.mlp.c_fc.weight'].cpu().numpy())
        if c_fc_weight.shape != (config.n_embd, config.n_inner or 4 * config.n_embd):
            c_fc_weight = c_fc_weight.T
        state.h[i].mlp.c_fc.kernel = nnx.Variable(c_fc_weight)
        state.h[i].mlp.c_fc.bias = nnx.Variable(jnp.array(hf_state_dict[f'{prefix}.mlp.c_fc.bias'].cpu().numpy()))

        c_proj_weight = jnp.array(hf_state_dict[f'{prefix}.mlp.c_proj.weight'].cpu().numpy())
        if c_proj_weight.shape != (config.n_inner or 4 * config.n_embd, config.n_embd):
            c_proj_weight = c_proj_weight.T
        state.h[i].mlp.c_proj.kernel = nnx.Variable(c_proj_weight)
        state.h[i].mlp.c_proj.bias = nnx.Variable(jnp.array(hf_state_dict[f'{prefix}.mlp.c_proj.bias'].cpu().numpy()))

    # Final LayerNorm
    state.ln_f.scale = nnx.Variable(jnp.array(hf_state_dict[f'{transformer_prefix}ln_f.weight'].cpu().numpy()))
    state.ln_f.bias = nnx.Variable(jnp.array(hf_state_dict[f'{transformer_prefix}ln_f.bias'].cpu().numpy()))

    if 'lm_head.weight' in hf_state_dict:
        print("lm_head weight found and loading...")
        # PyTorch nn.Linear stores lm_head as (vocab_size, n_embd); nnx.Linear expects (n_embd, vocab_size)
        lm_head_weight = jnp.array(hf_state_dict['lm_head.weight'].cpu().numpy())
        print(f"Original lm_head weight shape: {lm_head_weight.shape}")

        if lm_head_weight.shape != (config.n_embd, config.vocab_size):
            if lm_head_weight.shape == (config.vocab_size, config.n_embd):
                lm_head_weight = lm_head_weight.T  # Transpose to (n_embd, vocab_size)
            else:
                print(f"Unexpected lm_head weight shape: {lm_head_weight.shape}")

        print(f"Final lm_head weight shape: {lm_head_weight.shape}")
        state.lm_head.kernel = nnx.Variable(lm_head_weight)
    else:
        # If lm_head.weight is not present, reuse wte.embedding (a copy, not shared memory)
        # wte.embedding is (vocab_size, n_embd); transpose to (n_embd, vocab_size)
        tied_weight = state.wte.embedding.value
        if tied_weight.shape == (config.vocab_size, config.n_embd):
            tied_weight = tied_weight.T
        state.lm_head.kernel = nnx.Variable(tied_weight)
        print("lm_head weight not found, tying to wte.embedding")
        print(f"Tied lm_head weight shape: {state.lm_head.kernel.value.shape}")

    nnx.update(model, state)

# Apply the update
update_model_state(model, hf_state_dict)


Hugging Face state dict keys: ['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_attn.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_proj.bias', 'transformer.h.2.ln_1.weight', 'transformer.h.2.ln_1.bias', 'transformer.h.2.attn.c_attn.weight', 'tr



## JIT Compilation

JAX's JIT (Just-In-Time) compilation optimizes the model's forward pass for faster execution by compiling the computation graph.

### Define and JIT the Forward Function



### Why Use JIT?

- **Performance**: JIT compiles the function into efficient machine code, leveraging XLA (Accelerated Linear Algebra) for hardware acceleration (e.g., GPU/TPU).
- **Optimization**: It fuses operations, reducing overhead and improving speed, especially for repeated calls.

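**Note**: `model` is passed through `static_argnums`, so `jax.jit` treats the Python object itself as a compile-time constant: its weights are captured as constants in the compiled program rather than passed in as inputs. This is acceptable for inference with fixed weights, but the cached program is reused for as long as the same object is passed, so later in-place weight updates may not be picked up. `nnx.jit`, which passes a Module's state as traced inputs, is the idiomatic NNX alternative. Independently of how the model is passed, `jit_forward` is recompiled for every new `input_ids` shape.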


In [10]:
def forward(model, input_ids, rngs, deterministic):
    return model(input_ids, rngs, deterministic)

# JIT compile the forward function
jit_forward = jax.jit(forward, static_argnums=(0,3))



In [26]:

from typing import Optional


def generate(
    model,
    tokenizer,
    prompt: str,
    max_length: int = 50,
    temperature: float = 0.3,
    top_k: int = 5,
    top_p: float = 0.95,
    repetition_penalty: float = 5.0,
    eos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    do_sample: bool = True,
    seed: int = 42,
) -> str:
    """
    Generate text using a Flax GPT-2 model with JAX.

    Args:
        model: The NNX GPT2Model defined above (returns logits)
        tokenizer: A tokenizer from the transformers library
        prompt: Initial text to condition generation
        max_length: Maximum length of generated sequence (including prompt)
        temperature: Controls randomness. Lower means more deterministic.
        top_k: Number of highest probability tokens to consider for sampling
        top_p: Cumulative probability threshold for nucleus sampling
        repetition_penalty: Penalty for repeating tokens (1.0 = no penalty)
        eos_token_id: Token ID that signals the end of generation
        pad_token_id: Token ID for padding
        do_sample: If False, use greedy decoding instead of sampling
        seed: Random seed for reproducibility

    Returns:
        Generated text as a string
    """
    # Set default values for special tokens if not provided
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    # Encode the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="np")

    # Create a JAX random key for sampling
    key = jax.random.PRNGKey(seed)

    # Move the prompt to the default device (a TPU core in this notebook)
    input_ids = jax.device_put(input_ids)

    # Compute the (penalized, temperature-scaled) logits for the next token
    def get_next_token_logits(current_input_ids, key):
        # Forward pass with dropout disabled; `key` fills the unused rngs argument
        logits = jit_forward(model, current_input_ids, key, True)

        # Get the logits for the last token
        next_token_logits = logits[:, -1, :]

        # Apply repetition penalty to every token already in the sequence:
        # positive logits are divided by the penalty and negative ones multiplied,
        # so both become less likely
        if repetition_penalty != 1.0:
            seen_tokens = jnp.unique(current_input_ids)
            seen_logits = next_token_logits[:, seen_tokens]
            penalized = jnp.where(
                seen_logits > 0,
                seen_logits / repetition_penalty,
                seen_logits * repetition_penalty,
            )
            next_token_logits = next_token_logits.at[:, seen_tokens].set(penalized)

        # Apply temperature scaling
        if temperature > 0 and temperature != 1.0:
            next_token_logits = next_token_logits / temperature

        return next_token_logits, key

    # Create the sampling function
    def sample_token(logits, key):
        # Apply top-k filtering
        if top_k > 0:
            top_k_logits, top_k_indices = jax.lax.top_k(logits, top_k)
            # Use a near-impossible logit value for tokens outside the top-k
            min_value = jnp.min(top_k_logits) - 1e10
            # Create a mapping from original indices to top-k logits or min_value
            logits = jnp.full_like(logits, min_value)
            logits = logits.at[top_k_indices].set(top_k_logits)

        # Apply top-p (nucleus) filtering
        if 0.0 < top_p < 1.0:
            # Sort logits in descending order
            sorted_logits = jnp.sort(logits, descending=True)
            # Calculate cumulative probabilities
            sorted_probs = jax.nn.softmax(sorted_logits)
            cumulative_probs = jnp.cumsum(sorted_probs)
            # Find the threshold where cumulative probability exceeds top_p
            sorted_indices_to_remove = cumulative_probs > top_p
            # Keep the first token that exceeds the threshold
            sorted_indices_to_remove = jnp.concatenate([
                jnp.zeros(1, dtype=jnp.bool_),
                sorted_indices_to_remove[:-1]
            ])
            # The smallest logit that survives the cut is the threshold
            threshold_logit = jnp.min(
                jnp.where(sorted_indices_to_remove, jnp.inf, sorted_logits)
            )
            # Drop every logit below the threshold
            logits = jnp.where(logits < threshold_logit, -jnp.inf, logits)

        # Sample from the filtered distribution
        if do_sample:
            key, subkey = jax.random.split(key)
            next_token = jax.random.categorical(subkey, logits, axis=-1)
        else:
            # Greedy decoding - take the argmax
            next_token = jnp.argmax(logits, axis=-1)

        return next_token, key

    # Generate tokens one by one
    generated = input_ids.copy()[0]
    for _ in range(max_length - len(input_ids[0])):
        # Get next token logits
        next_token_logits, key = get_next_token_logits(generated[None, :], key)

        # Sample the next token
        next_token, key = sample_token(next_token_logits[0], key)

        # Add the token to our generated sequence
        generated = np.append(generated, next_token)

        # Check if we've generated an EOS token
        if next_token == eos_token_id:
            break

    # Decode the generated ids back to text
    generated_text = tokenizer.decode(generated, skip_special_tokens=True)


    return generated_text
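The nucleus (top-p) step in `sample_token` keeps the smallest set of highest-probability tokens whose cumulative probability reaches `top_p` and masks the rest. Below is a standalone sketch of that filter (`top_p_filter` is a hypothetical helper name) on a four-token distribution, which makes the cutoff easy to verify by hand: with probabilities `[0.5, 0.3, 0.15, 0.05]` and `top_p=0.7`, the first two tokens already cover 0.8, so only they survive.

```python
import jax
import jax.numpy as jnp

def top_p_filter(logits: jnp.ndarray, top_p: float) -> jnp.ndarray:
    # Sort logits in descending order and accumulate their probabilities
    sorted_logits = jnp.sort(logits)[::-1]
    cumulative = jnp.cumsum(jax.nn.softmax(sorted_logits))
    # Drop a token if the tokens ranked above it already exceed top_p
    remove = jnp.concatenate([jnp.zeros(1, dtype=bool), cumulative[:-1] > top_p])
    # The smallest surviving logit is the cutoff
    threshold = jnp.min(jnp.where(remove, jnp.inf, sorted_logits))
    return jnp.where(logits < threshold, -jnp.inf, logits)

probs = jnp.array([0.5, 0.3, 0.15, 0.05])
filtered = top_p_filter(jnp.log(probs), top_p=0.7)
print(filtered)  # the last two entries are -inf
```

Taking the minimum over the *kept* logits gives the correct cutoff; taking it over the removed ones would instead select the globally smallest logit and filter out nothing.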

In [44]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
print(generate(model, tokenizer, "Hi my name is", max_length=35))

Hi my name is David. I'm a freelance writer and the owner of The Bookstore, where you can find books by authors like Robert Heinlein or Thomas Pynch
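The generation loop calls `jit_forward` with a sequence that grows by one token per step, and `jax.jit` compiles a new program for every new input shape. The toy function below (hypothetical, not part of the model) counts how often JAX traces it: the Python-side counter only increments during tracing, so it shows one compilation per new sequence length and none when a shape repeats.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def toy_forward(x):
    global trace_count
    trace_count += 1          # Python side effect: runs only when JAX traces the function
    return jnp.sum(x)

jit_toy = jax.jit(toy_forward)

for seq_len in range(1, 6):   # sequence grows by one token per step, like the generation loop
    jit_toy(jnp.ones((1, seq_len)))
print(trace_count)            # 5: one trace per new shape

for _ in range(3):            # repeating a shape reuses the cached compilation
    jit_toy(jnp.ones((1, 5)))
print(trace_count)            # still 5
```

For short prompts this overhead is tolerable, but for long generations a common remedy is to pad inputs to a fixed length (or use a key/value cache) so that the compiled program can be reused across steps.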
