In [None]:
#@title ## 1. Setup and Imports

# Install necessary libraries (comment this line out if they are already installed)
!pip install -q flax optax tiktoken einops

import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn            # Flax neural network library
from flax.training import train_state # Helper for managing model state
import optax                       # Optimization library
import tiktoken                    # Tokenizer for text processing
import functools                   # For function manipulation (e.g., partial)
from einops import rearrange       # For tensor manipulation
import urllib.request              # To download data
import os                          # Operating system interactions
from tqdm.notebook import tqdm

# Select the device used for the rest of the notebook: the first GPU when one
# is available, otherwise the CPU.
#
# To use a Colab TPU instead, select TPU in the Colab Runtime settings and run
# the following before the rest of the notebook:
#     import jax.tools.colab_tpu
#     jax.tools.colab_tpu.setup_tpu()
#     device = jax.devices('tpu')[0]
#     print("Using TPU")
def _first_device_or_none(backend):
    """Return the first device of `backend`, or None if that backend is unavailable.

    `jax.devices(backend)` raises RuntimeError (instead of returning an empty
    list) when the requested backend cannot be initialized, so the lookup has
    to be guarded for the CPU fallback to be reachable.
    """
    try:
        backend_devices = jax.devices(backend)
    except RuntimeError:
        return None
    return backend_devices[0] if backend_devices else None


# To use GPU: In Colab, go to Runtime -> Change runtime type -> Hardware accelerator -> GPU
device = _first_device_or_none('gpu')
if device is not None:
    print("Using GPU")
else:
    # Fallback to CPU if no GPU found (and TPU not explicitly enabled)
    device = jax.devices('cpu')[0]
    print("GPU not found, using CPU")

print(f"Using device: {device}")

In [None]:
#@title ## 2. Data Preparation (In-Memory Focus)
# --- Fetch the corpus once, then load it into memory ---
# Sample text: "The Verdict" by Edith Wharton.
url = ("https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/"
       "main/ch02/01_main-chapter-code/the-verdict.txt")
file_path = "the-verdict.txt"

if os.path.exists(file_path):
    # A local copy is already present: skip the network round-trip.
    print(f"File {file_path} already exists.")
else:
    print(f"Downloading {file_path}...")
    with urllib.request.urlopen(url) as response:
        text_data = response.read().decode('utf-8')
    # Cache the decoded text next to the notebook so later runs stay offline.
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(text_data)
    print("Download complete.")

# Regardless of which branch ran above, the text used downstream is whatever
# is stored on disk.
print("Reading data from disk into RAM...")
with open(file_path, "r", encoding="utf-8") as file:
    text_data = file.read()
print(f"Loaded text data into RAM ({len(text_data)} characters)")
print("First 100 chars:", text_data[:100])

In [None]:
# --- Tokenization using tiktoken ---
# Turns the raw text loaded above (`text_data`) into an int32 JAX array of token IDs
# (`encoded_text_jax`) and records the tokenizer's vocabulary size (`vocab_size`),
# both of which later cells consume.

# Initialize the GPT-2 tokenizer
print("Initializing tokenizer...")
tokenizer = tiktoken.get_encoding("gpt2")
vocab_size = tokenizer.n_vocab # Vocabulary size reported by the tokenizer; used to size the model
print(f"Tokenizer vocabulary size: {vocab_size}")

# Convert the raw text data into a sequence of token IDs.
# The whole text is encoded in one call, so the full result is held in memory.
print("Encoding text data into token IDs (in RAM)...")
encoded_text = tokenizer.encode(text_data)

# Convert the token IDs into a JAX array (explicitly int32) for efficient processing.
# This array resides in RAM or accelerator (GPU/TPU) memory.
print("Converting token IDs to JAX array (in RAM/Device Memory)...")
encoded_text_jax = jnp.array(encoded_text, dtype=jnp.int32)
print(f"Encoded text stored as JAX array with shape: {encoded_text_jax.shape}")
# Show only the first 50 token IDs as a sanity check
print(encoded_text_jax[:50])

In [None]:
# --- Create Training and Validation Sets ---
# Split the encoded token stream positionally (no shuffling): the leading `train_ratio`
# fraction becomes the training set and the remainder becomes the validation set.
# NOTE(review): JAX arrays are immutable and slicing one is expected to produce a new
# array rather than a NumPy-style view, so these two arrays likely hold their own
# copies of the data — confirm if memory usage matters.
print("Splitting data into train/validation sets (in RAM/Device Memory)...")
train_ratio = 0.90
split_idx = int(train_ratio * len(encoded_text_jax))  # index of the first validation token
train_data = encoded_text_jax[:split_idx] # Tokens before split_idx: training data
val_data = encoded_text_jax[split_idx:]   # Tokens from split_idx onward: validation data

print(f"Training data shape: {train_data.shape}")
print(f"Validation data shape: {val_data.shape}")

# --- Data Loader / Batching Function ---
# Define a function to generate batches for training or evaluation.
# context_length: The number of tokens in each input sequence.
# batch_size: The number of sequences processed in parallel per batch.
def create_batches(data, batch_size, context_length, key):
    """
    Lazily yields shuffled (input, target) batches for next-token prediction.

    Every window of `context_length` tokens that still has a following token is a
    candidate sequence. The candidate start positions are shuffled once with `key`
    and then consumed `batch_size` at a time.

    Args:
        data: A 1-D JAX array containing the token IDs (e.g., train_data or val_data).
        batch_size: Number of sequences per batch.
        context_length: Length of each input sequence.
        key: JAX PRNG key for shuffling.

    Yields:
        tuple: A batch containing (x_batch, y_batch).
               x_batch: Input sequences (batch_size, context_length).
               y_batch: Target sequences (batch_size, context_length), i.e. the
                        inputs shifted by one token.

    Raises:
        ValueError: If `data` is too short to hold a single (input, target) window.
            Because this is a generator, the error surfaces on the first `next()`
            call, not when `create_batches` is called.

    Note: Operates entirely on the provided array in memory and performs no disk I/O.
          The trailing partial batch is dropped so every batch has the same shape,
          which avoids recompilation of jitted consumers.
    """
    data = jnp.asarray(data)

    # Total number of possible starting positions for sequences. The target window
    # reads one token past the input window, hence `- context_length` (not `+ 1`).
    num_sequences = len(data) - context_length
    if num_sequences <= 0:
        raise ValueError("Dataset is too small for the given context length.")

    # Generate a random permutation of starting indices for sequences
    idxs = jax.random.permutation(key, num_sequences)

    # Only full batches are produced (see Note in the docstring).
    num_batches = num_sequences // batch_size

    print(f"Creating batches: {num_batches} batches of size {batch_size}...")

    # Offsets 0..context_length-1 are shared by every window; adding them to a
    # column of start positions gives a (batch_size, context_length) index grid,
    # so each batch is built with one gather instead of 2 * batch_size slices.
    offsets = jnp.arange(context_length)
    for i in range(num_batches):
        batch_starts = idxs[i * batch_size : (i + 1) * batch_size]
        window_idxs = batch_starts[:, None] + offsets
        x_batch = data[window_idxs]
        # Targets are the same windows shifted right by one token
        y_batch = data[window_idxs + 1]
        yield x_batch, y_batch

In [None]:
# --- Model Configuration ---
# Hyperparameters shared by the model constructor, the batch generator and the
# training loop in later cells.
config = {
    "vocab_size": vocab_size,       # Size of the token vocabulary (taken from the tokenizer)
    "context_length": 128,      # Max sequence length for the model; also the length of each training window
    "emb_dim": 64,             # Dimension of token embeddings; must be divisible by n_heads
    "n_heads": 4,               # Number of attention heads
    "n_layers": 4,              # Number of transformer blocks (layers)
    "drop_rate": 0.1,           # Dropout rate for regularization
    "qkv_bias": False,          # Passed to the model as `use_bias`; the attention module applies it to Q/K/V and the output projection
    "batch_size": 32,          # Number of sequences per training batch
}

# --- Demonstrate Batch Generation ---
# Create a JAX PRNG key for reproducibility in batch generation
data_key = random.PRNGKey(0)
# Build a batch generator over the training data. Nothing is computed until the
# generator is advanced.
batch_generator = create_batches(train_data, config["batch_size"], config["context_length"], data_key)
# Pull only the first batch from the generator; the rest of it is never consumed here
x_example, y_example = next(batch_generator)

# Print shapes and the first few tokens of the first sequence. The target row
# should equal the input row shifted by one token.
print("\nExample Input Batch Shape:", x_example.shape)
print("Example Target Batch Shape:", y_example.shape)
print("Example Input Batch (first 5 tokens):", x_example[0, :5])
print("Example Target Batch (first 5 tokens):", y_example[0, :5])

In [None]:
#@title ## 3. Transformer Components (JAX/Flax)

class TokenAndPositionalEmbedding(nn.Module):
    """Combines token embeddings and learnable absolute positional embeddings."""
    vocab_size: int      # Number of unique tokens in the vocabulary
    embed_dim: int       # Dimension of the embedding vectors
    context_length: int  # Maximum sequence length the model handles

    def setup(self):
        # One learned vector per token ID
        self.tok_emb = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim)
        # Learnable absolute positional embeddings: one learned vector per position
        self.pos_emb = nn.Embed(num_embeddings=self.context_length, features=self.embed_dim)

    def __call__(self, x):
        """
        Forward pass for combining token and positional embeddings.

        Args:
            x: Input token IDs, shape (batch_size, seq_len)

        Returns:
            Combined embeddings (token + position), shape (batch_size, seq_len, embed_dim)

        Raises:
            ValueError: If seq_len exceeds `context_length`, i.e. there is no
                positional embedding for some of the positions.
        """
        seq_len = x.shape[1]
        # Shapes are static Python ints even under jit, so this check runs at
        # trace time. Without it, positions beyond the table would not raise
        # and would silently reuse an existing row.
        if seq_len > self.context_length:
            raise ValueError(
                f"Sequence length {seq_len} exceeds context_length {self.context_length}."
            )

        # Look up one vector per token: (batch_size, seq_len, embed_dim)
        token_embeddings = self.tok_emb(x)
        # Look up one vector per position 0..seq_len-1: (seq_len, embed_dim)
        positions = jnp.arange(seq_len)
        position_embeddings = self.pos_emb(positions)

        # Add token and positional embeddings; the positional term broadcasts
        # over the batch dimension.
        combined_embeddings = token_embeddings + position_embeddings
        return combined_embeddings

# --- Multi-Head Causal Self-Attention ---
class MultiHeadCausalSelfAttention(nn.Module):
    """
    Multi-head self-attention with causal masking, using separate Q, K, V projections.
    """
    embed_dim: int       # Total dimension of the embedding (e.g., 512)
    num_heads: int       # Number of attention heads (e.g., 8)
    use_bias: bool = False  # Whether to use bias in the Q/K/V and output projections
    dropout_rate: float = 0.1 # Dropout rate

    def setup(self):
        assert self.embed_dim % self.num_heads == 0, "Embed dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // self.num_heads # Dim per head (e.g., 64)

        # --- Separate Linear Projections for Q, K, V ---
        # Each projects from embed_dim to embed_dim
        self.q_proj = nn.Dense(features=self.embed_dim, use_bias=self.use_bias, name="query_proj")
        self.k_proj = nn.Dense(features=self.embed_dim, use_bias=self.use_bias, name="key_proj")
        self.v_proj = nn.Dense(features=self.embed_dim, use_bias=self.use_bias, name="value_proj")

        # Output projection after combining heads
        self.out_proj = nn.Dense(features=self.embed_dim, use_bias=self.use_bias, name="output_proj")

        # Dropout layers
        self.dropout_attn = nn.Dropout(rate=self.dropout_rate) # Dropout on attention weights
        self.dropout_out = nn.Dropout(rate=self.dropout_rate)  # Dropout on final output

    def __call__(self, x, deterministic: bool):
        """
        Forward pass for multi-head causal self-attention.

        Args:
            x: Input embeddings, shape (batch_size, seq_len, embed_dim)
            deterministic: If True, disables dropout (used during inference).

        Returns:
            Output context vectors after attention, shape (batch_size, seq_len, embed_dim)
        """
        batch_size, seq_len, _ = x.shape

        # --- Step 1: Project Input to Q, K, V Separately ---
        # Each projection keeps the shape (batch, seq_len, embed_dim)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # --- Step 2: Split into Heads ---
        # Reshape and transpose each of Q, K, V for multi-head processing
        # 'b s (h d)' means: batch, sequence, (num_heads * head_dim)
        # '-> b h s d' means: transform into (batch, num_heads, sequence, head_dim)
        q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
        k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
        v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)
        # Now q, k, v have shape: (batch, num_heads, seq_len, head_dim)

        # --- Step 3: Calculate Attention Scores ---
        # Scaled Dot-Product Attention: Q @ K^T / sqrt(head_dim)
        # MatMul: (b, h, s, d) @ (b, h, d, s) -> (b, h, s, s)
        attn_scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
        # Scale by sqrt(head_dim) so the score magnitude does not grow with the
        # head dimension. A Python float keeps the dtype of attn_scores unchanged.
        attn_scores = attn_scores / (self.head_dim ** 0.5)

        # --- Step 4: Apply Causal Mask ---
        # Prevent attending to future tokens
        mask = nn.make_causal_mask(jnp.ones((batch_size, seq_len)), dtype=jnp.bool_)
        # Set masked positions to -inf before softmax
        # Mask shape is broadcastable: (batch_size, 1, seq_len, seq_len)
        # Every query can always attend to itself, so no row is fully masked.
        attn_scores = jnp.where(mask, attn_scores, -jnp.inf)

        # --- Step 5: Calculate Attention Weights ---
        # Softmax converts scores to probabilities along the key sequence length dimension
        attn_weights = jax.nn.softmax(attn_scores, axis=-1) # Shape: (b, h, s, s)
        attn_weights = self.dropout_attn(attn_weights, deterministic=deterministic)

        # --- Step 6: Calculate Context Vectors ---
        # Weighted sum of Value vectors: Weights @ V
        # MatMul: (b, h, s, s) @ (b, h, s, d) -> (b, h, s, d)
        context_vec = jnp.matmul(attn_weights, v) # Shape: (batch, num_heads, seq_len, head_dim)

        # --- Step 7: Combine Heads ---
        # Rearrange back to merge head dimension into embedding dimension
        # 'b h s d' -> 'b s (h d)' which is (batch, seq_len, embed_dim)
        context_combined = rearrange(context_vec, 'b h s d -> b s (h d)')

        # --- Step 8: Final Output Projection ---
        # Apply final linear layer to mix head information
        output = self.out_proj(context_combined) # Shape: (batch, seq_len, embed_dim)
        output = self.dropout_out(output, deterministic=deterministic)

        return output
# --- Feed Forward Network (Position-wise) ---
class FeedForward(nn.Module):
    """A simple two-layer feed-forward network applied position-wise."""
    embed_dim: int       # Input and output dimension
    dropout_rate: float = 0.1 # Dropout rate

    @nn.compact
    def __call__(self, x, deterministic: bool):
        """
        Forward pass for the feed-forward network.

        Args:
            x: Input tensor, shape (batch_size, seq_len, embed_dim)
            deterministic: If True, disables dropout.

        Returns:
            Output tensor, shape (batch_size, seq_len, embed_dim)
        """
        # Typically expand to 4x embed_dim in the hidden layer
        hidden_dim = 4 * self.embed_dim

        # Expand -> non-linearity -> project back. Dense acts on the last axis
        # only, so every position is transformed independently.
        x = nn.Dense(features=hidden_dim, name="fc_expand")(x)       # (b, s, 4 * embed_dim)
        x = nn.gelu(x)
        x = nn.Dense(features=self.embed_dim, name="fc_project")(x)  # (b, s, embed_dim)

        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        return x

# --- Transformer Block ---
class TransformerBlock(nn.Module):
    """A single block of the Transformer architecture (Pre-LN variant)."""
    embed_dim: int       # Embedding dimension
    num_heads: int       # Number of attention heads
    use_bias: bool       # Whether to use bias in the attention projections
    dropout_rate: float  # Dropout rate

    @nn.compact
    def __call__(self, x, deterministic: bool):
        """
        Forward pass for a Transformer block using Pre-Layer Normalization.

        Each sub-layer normalizes its input first and adds its output back onto
        the un-normalized residual stream: x = x + sublayer(LayerNorm(x)).

        Args:
            x: Input tensor, shape (batch_size, seq_len, embed_dim)
            deterministic: If True, disables dropout.

        Returns:
            Output tensor, shape (batch_size, seq_len, embed_dim)
        """
        # --- Attention Sub-layer (Pre-LN) ---
        # LayerNorm: y = gamma * (x - mean(x)) / sqrt(variance(x) + epsilon) + beta
        attn_input = nn.LayerNorm(epsilon=1e-5, name="ln_attn")(x)
        attn_output = MultiHeadCausalSelfAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            use_bias=self.use_bias,
            dropout_rate=self.dropout_rate,
            name="attention",
        )(attn_input, deterministic=deterministic)
        x = x + attn_output # Residual connection

        # --- Feed Forward Sub-layer (Pre-LN) ---
        ffn_input = nn.LayerNorm(epsilon=1e-5, name="ln_ffn")(x)
        ffn_output = FeedForward(
            embed_dim=self.embed_dim,
            dropout_rate=self.dropout_rate,
            name="feed_forward",
        )(ffn_input, deterministic=deterministic)
        x = x + ffn_output # Residual connection

        return x

In [None]:
#@title ## 4. GPT Model Architecture (JAX/Flax)
class GPT(nn.Module):
    """Decoder-only Transformer language model: embeddings, a stack of blocks, final norm and a vocabulary head."""
    vocab_size: int      # Number of distinct token IDs (also the width of the output logits)
    embed_dim: int       # Width of the residual stream / embedding vectors
    context_length: int  # Longest sequence the positional embedding table covers
    num_heads: int       # Attention heads per Transformer block
    num_layers: int      # How many Transformer blocks are stacked
    use_bias: bool       # Forwarded to every Transformer block as `use_bias`
    dropout_rate: float  # Dropout probability used after the embedding and inside each block

    @nn.compact
    def __call__(self, idx, deterministic: bool):
        """
        Maps token indices to next-token logits.

        Args:
            idx: Input token indices, shape (batch_size, seq_len).
            deterministic: If True, disables dropout (used during inference).

        Returns:
            logits: Unnormalized scores over the vocabulary for every position,
                shape (batch_size, seq_len, vocab_size).
        """
        # Stage 1: token + position embeddings, (batch, seq_len) -> (batch, seq_len, embed_dim),
        # followed by dropout on the summed embeddings.
        embed = TokenAndPositionalEmbedding(
            vocab_size=self.vocab_size,
            embed_dim=self.embed_dim,
            context_length=self.context_length,
            name="embedding"
        )
        hidden = embed(idx)
        hidden = nn.Dropout(rate=self.dropout_rate)(hidden, deterministic=deterministic)

        # Stage 2: run the residual stream through each block in order; the
        # shape stays (batch, seq_len, embed_dim) throughout.
        for layer_idx in range(self.num_layers):
            block = TransformerBlock(
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                use_bias=self.use_bias,
                dropout_rate=self.dropout_rate,
                name=f"transformer_block_{layer_idx}"
            )
            hidden = block(hidden, deterministic=deterministic)

        # Stage 3: one last LayerNorm before the output head.
        normed = nn.LayerNorm(epsilon=1e-5, name="final_ln")(hidden)

        # Stage 4: bias-free linear head, (batch, seq_len, embed_dim) -> (batch, seq_len, vocab_size).
        # This is a separate Dense layer; it does not share weights with the token embedding.
        head = nn.Dense(features=self.vocab_size, use_bias=False, name="output_projection")
        return head(normed)

In [None]:
#@title ## 5. Pre-training Loop (Next Token Prediction)

# --- Loss Function (Cross-Entropy) ---
@jax.jit # JIT compile for speed
def cross_entropy_loss(logits, targets):
    """Calculates average cross-entropy loss for language modeling.

    Args:
        logits: Unnormalized scores, shape (batch, seq_len, vocab_size).
        targets: Integer token IDs in [0, vocab_size), shape (batch, seq_len).

    Returns:
        Scalar mean negative log-likelihood over all batch and sequence positions.
    """
    # Calculate log-probabilities (log-softmax is numerically stable)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability of each target token directly. This equals
    # summing against a one-hot encoding, without building a dense
    # (batch, seq_len, vocab_size) one-hot tensor.
    target_idx = jnp.asarray(targets)[..., None]                            # (batch, seq_len, 1)
    target_log_probs = jnp.take_along_axis(log_probs, target_idx, axis=-1)  # (batch, seq_len, 1)
    # Average loss over batch and sequence length
    return -jnp.mean(target_log_probs)

In [None]:
# --- Training Step ---
# JIT compile the training step for performance.
# `model_apply` and `learning_rate_fn` are static to prevent recompilation.
@functools.partial(jax.jit, static_argnames=['model_apply', 'learning_rate_fn'])
def train_step(state, batch, dropout_key, model_apply, learning_rate_fn):
    """Runs one optimization step and reports its loss and learning rate.

    Args:
        state: Train state holding `params`, `step` and `apply_gradients`.
        batch: Tuple (inputs, targets) of token-ID arrays.
        dropout_key: PRNG key supplied to the model's 'dropout' RNG stream.
        model_apply: The model's apply function (static under jit).
        learning_rate_fn: Callable mapping a step count to a learning rate
            (static under jit); only used for the reported metric.

    Returns:
        (new_state, metrics) where metrics has keys 'loss' and 'learning_rate'.
    """
    inputs, targets = batch

    def loss_fn(params):
        # Training-mode forward pass: dropout is active and draws from dropout_key.
        logits = model_apply(
            {'params': params},
            inputs,
            deterministic=False,
            rngs={'dropout': dropout_key},
        )
        return cross_entropy_loss(logits, targets)

    # Loss value and its gradient w.r.t. the parameters in one pass.
    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Applying the gradients returns a new state (updated params and optimizer
    # state, incremented step).
    updated_state = state.apply_gradients(grads=grads)

    # The learning rate is read at the step count *after* the update.
    metrics = {'loss': loss, 'learning_rate': learning_rate_fn(updated_state.step)}
    return updated_state, metrics

# --- Evaluation Step ---
# JIT compile the evaluation step. `model_apply` is static.
@functools.partial(jax.jit, static_argnames=['model_apply'])
def eval_step(state, batch, model_apply):
    """Computes the loss of one batch without updating any parameters.

    Args:
        state: Train state whose `params` are evaluated.
        batch: Tuple (inputs, targets) of token-ID arrays.
        model_apply: The model's apply function (static under jit).

    Returns:
        Dict with a single key 'loss'.
    """
    inputs, targets = batch
    variables = {'params': state.params}
    # Inference-mode forward pass: dropout is switched off, so no RNG is needed.
    logits = model_apply(variables, inputs, deterministic=True)
    return {'loss': cross_entropy_loss(logits, targets)}

In [None]:
# --- Optimizer and Train State ---
# Builds the optimizer, the model, its initial parameters and the `state` object
# that the training, evaluation and generation cells operate on.

# Configure the AdamW optimizer with a constant learning rate
learning_rate = 1e-4
tx = optax.adamw(learning_rate=learning_rate, weight_decay=0.1)

# Generate JAX PRNG keys for reproducible initialization.
# `params_key` seeds parameter initialization and `dropout_key_init` is the base key
# for per-step dropout keys in the training loop; `model_key` is not referenced in
# the code visible in this notebook.
model_key, params_key, dropout_key_init = random.split(random.PRNGKey(123), 3)

# Instantiate the GPT model from the shared config
model = GPT(
    vocab_size=config["vocab_size"],
    embed_dim=config["emb_dim"],
    context_length=config["context_length"],
    num_heads=config["n_heads"],
    num_layers=config["n_layers"],
    use_bias=config["qkv_bias"],
    dropout_rate=config["drop_rate"]
)

# Initialize model parameters using a dummy input and the params_key.
# The dummy input is a single sequence of full context length.
# `deterministic=True` disables dropout during initialization, so no dropout RNG is needed.
dummy_input = jnp.ones((1, config["context_length"]), dtype=jnp.int32)
params = model.init(params_key, dummy_input, deterministic=True)['params']

# Create the training state to bundle model apply function, parameters, and optimizer state.
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# Move the training state to the selected JAX device (`device` from the setup cell).
state = jax.device_put(state, device)

# Print the total number of model parameters (sum of element counts over all leaves).
param_count = sum(p.size for p in jax.tree_util.tree_leaves(state.params))
print(f"Model initialized with {param_count:,} parameters.")

In [None]:

# --- Training Loop ---
num_epochs = 1 # Number of full passes over the training data
eval_frequency = 1000 # How often to evaluate on validation data (in steps)
# NOTE(review): `step` below restarts at 0 every epoch, so if an epoch has fewer
# than `eval_frequency` batches the evaluation/logging branch never runs —
# confirm this value suits the dataset size.

# Main PRNG key for the training loop
train_key = random.PRNGKey(42)


def learning_rate_fn(step):
    """Constant learning-rate schedule reported in the training metrics."""
    return learning_rate


# `train_step` treats `learning_rate_fn` as a static jit argument, so the same
# function object must be passed on every call. A fresh lambda per step is a new
# static value each time, which defeats the jit cache and retraces every step.

print(f"Starting training for {num_epochs} epochs...")

for epoch in range(num_epochs):
    print(f"--- Epoch {epoch+1}/{num_epochs} ---")
    # Running sum / count of training losses since the last evaluation
    epoch_train_loss = 0.0
    num_train_batches = 0

    # Create a new batch generator for each epoch with a unique key for shuffling
    epoch_key, train_key = random.split(train_key)
    batch_generator = create_batches(train_data, config["batch_size"], config["context_length"], epoch_key)

    # Calculate total steps for the tqdm progress bar
    # Note: This calculation should match the logic inside create_batches
    num_sequences = len(train_data) - config["context_length"]
    total_steps_per_epoch = num_sequences // config["batch_size"] # Drops partial batch, matching create_batches

    # Wrap the batch generator with tqdm for a progress bar
    pbar = tqdm(enumerate(batch_generator),
                total=total_steps_per_epoch,
                desc=f"Epoch {epoch+1} Training")

    # Iterate over training batches for the current epoch
    for step, train_batch in pbar:
        # Move the current training batch to the target device
        train_batch = jax.device_put(train_batch, device)

        # Derive a unique dropout key for this step from the global step counter
        dropout_key_step = random.fold_in(dropout_key_init, state.step)

        # Perform a single training step
        state, train_metrics = train_step(state, train_batch, dropout_key_step, model.apply, learning_rate_fn)
        epoch_train_loss += train_metrics['loss']
        num_train_batches += 1

        # --- Logging and Evaluation Periodically ---
        if (step + 1) % eval_frequency == 0:
            # Calculate average training loss over the interval
            # Use jax.device_get to bring scalar loss values back from GPU/TPU if needed for calculation/display
            avg_train_loss = jax.device_get(epoch_train_loss / num_train_batches)

            # Evaluate performance on the validation set
            val_loss = 0.0
            num_val_batches = 0
            val_key, train_key = random.split(train_key)
            val_batch_generator = create_batches(val_data, config["batch_size"], config["context_length"], val_key)
            for val_batch in val_batch_generator:
                val_batch = jax.device_put(val_batch, device)
                eval_metrics = eval_step(state, val_batch, model.apply)
                val_loss += eval_metrics['loss']
                num_val_batches += 1

            # Calculate average validation loss (0.0 when the validation set yields no full batch)
            avg_val_loss = jax.device_get(val_loss / num_val_batches) if num_val_batches > 0 else 0.0

            # Update the progress bar postfix with the latest loss values
            pbar.set_postfix(TrainLoss=f"{avg_train_loss:.4f}", ValLoss=f"{avg_val_loss:.4f}")

            # Log performance metrics (print may interleave slightly with the tqdm bar)
            print(f"\n  Step: {state.step:>5} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {jax.device_get(train_metrics['learning_rate']):.6f}")

            # Reset training loss accumulator for the next evaluation interval
            epoch_train_loss = 0.0
            num_train_batches = 0

    # Ensure the progress bar finishes cleanly for the epoch
    pbar.close()
    print(f"--- Epoch {epoch+1} finished ---")

print("Training complete.")


In [None]:
#@title ## 6. Text Generation

# JIT compile for speed, static args prevent recompilation for config changes.
@functools.partial(jax.jit, static_argnames=['model_apply', 'max_new_tokens', 'context_length', 'temperature', 'top_k'])
def generate_text(state, prompt_ids, max_new_tokens, context_length, model_apply, temperature=1.0, top_k=None, key=random.PRNGKey(0)):
    """Generates text autoregressively from a prompt.

    Args:
        state: Object whose `params` attribute holds the model parameters.
        prompt_ids: Prompt token IDs, shape (batch, prompt_len).
        max_new_tokens: Number of tokens to append (static; the loop is unrolled
            at trace time, one forward pass per new token).
        context_length: Maximum number of trailing tokens fed to the model (static).
        model_apply: The model's apply function (static).
        temperature: Values <= 0 select greedy decoding; otherwise logits are
            divided by it before sampling (static).
        top_k: If given, sampling is restricted to the k highest-scoring tokens
            of each row (static).
        key: PRNG key used for sampling.

    Returns:
        Token IDs of shape (batch, prompt_len + max_new_tokens): the prompt
        followed by the generated tokens.
    """
    current_ids = prompt_ids

    for _ in range(max_new_tokens):
        # Truncate context to the model's maximum length
        context_ids = current_ids[:, -context_length:]

        # Get logits using the model in evaluation mode (deterministic=True)
        logits = model_apply({'params': state.params}, context_ids, deterministic=True)

        # Use only the logits for the last token to predict the next token
        last_token_logits = logits[:, -1, :] # Shape: (batch, vocab_size)

        # --- Sampling ---
        if temperature <= 0:
            # Greedy sampling: take the most probable token
            next_token_id = jnp.argmax(last_token_logits, axis=-1)
        else:
            # Temperature scaling adjusts randomness
            scaled_logits = last_token_logits / temperature

            # Optional Top-K sampling limits choices to the K most likely tokens
            if top_k is not None:
                top_logits, _ = jax.lax.top_k(scaled_logits, k=top_k)
                # Smallest logit still inside the top K, one per row. The
                # trailing axis is kept, giving shape (batch, 1), so it
                # broadcasts against (batch, vocab_size) for any batch size.
                min_top_logit = top_logits[:, -1:]
                # Mask logits below the top K threshold (ties with the threshold are kept)
                mask = scaled_logits < min_top_logit
                scaled_logits = jnp.where(mask, -jnp.inf, scaled_logits)

            # Sample from the adjusted probability distribution
            key, subkey = random.split(key)
            next_token_id = random.categorical(subkey, scaled_logits, axis=-1) # Shape: (batch,)

        # Append the sampled token ID to the sequence
        current_ids = jnp.concatenate([current_ids, next_token_id[:, None]], axis=1)

    return current_ids

In [None]:

# --- Generate Example Text ---
# Encodes a short prompt, asks the model held in `state` for 10 more tokens and prints
# the decoded result.
start_context = "Hello, I am"
start_ids = jnp.array(tokenizer.encode(start_context), dtype=jnp.int32)[None, :] # Add batch dim -> shape (1, prompt_len)
start_ids = jax.device_put(start_ids, device) # Move prompt to device

print(f"\nGenerating text starting with: '{start_context}'")

# Use the current `state` (trained, if the training cell was run before this one).
# A fixed PRNG key makes the sampled continuation reproducible.
generation_key = random.PRNGKey(567)
generated_ids = generate_text(
    state=state,
    prompt_ids=start_ids,
    max_new_tokens=10,
    context_length=config["context_length"],
    model_apply=model.apply,
    temperature=0.7, # Add some randomness
    top_k=50,        # Consider top 50 tokens
    key=generation_key
)

# Decode the generated token IDs (prompt + new tokens) back to text
generated_text = tokenizer.decode(generated_ids[0].tolist()) # Remove batch dim before decoding
print("\nGenerated Text:")
print(generated_text)


## 7. Conclusion

This notebook demonstrated the fundamentals of building and training a decoder-only Transformer for language modeling using JAX and Flax:

1.  Text data preparation and tokenization.
2.  Implementation of core Transformer components (Embeddings, Attention, LayerNorm, FFN).
3.  Construction of the full GPT model architecture.
4.  Implementation of a next-token prediction pre-training loop.
5.  Basic training on a single accelerator.
6.  Autoregressive text generation with the trained model.

This provides a foundation for understanding how such language models work.

Further steps to explore could include:
* Training for more epochs or using larger datasets.
* Implementing more sophisticated data loading pipelines.
* Experimenting with different hyperparameters (model size, learning rate, etc.).
* Adding techniques like learning rate scheduling or gradient clipping.