In [10]:
#@title ## Complete Transformer Code (Gradient Accumulation + Manual Mixed Precision)

# -----------------------------------------------------------------------------
# Section 1: Setup and Imports
# -----------------------------------------------------------------------------
print("--- Section 1: Setup and Imports ---")
# Install necessary libraries
!pip install -q flax optax tiktoken einops

--- Section 1: Setup and Imports ---


In [11]:
import jax
import flax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from flax.training import train_state
import optax
import tiktoken
import functools
from einops import rearrange
import urllib.request
import os
from tqdm.notebook import tqdm
import numpy as np

# Determine device
try:
    if jax.devices('gpu'):
        device = jax.devices('gpu')[0]
        print("Using GPU")
    else:
        device = jax.devices('cpu')[0]
        print("No GPU found, using CPU")
except Exception:
    if jax.devices('gpu'):
        device = jax.devices('gpu')[0]
        print("Using GPU")
    else:
        device = jax.devices('cpu')[0]
        print("Using CPU")
print(f"Using device: {device}")

Using GPU
Using device: cuda:0


In [12]:
# -----------------------------------------------------------------------------
# Section 2: Data Preparation
# -----------------------------------------------------------------------------
print("\n--- Section 2: Data Preparation ---")
url = ("https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/"
       "main/ch02/01_main-chapter-code/the-verdict.txt")
file_path = "the-verdict.txt"
if not os.path.exists(file_path):
    print(f"Downloading {file_path}...")
    with urllib.request.urlopen(url) as response:
        text_data = response.read().decode('utf-8')
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(text_data)
    print("Download complete.")
else:
    print(f"File {file_path} already exists.")

print("Reading data from disk into RAM...")
with open(file_path, "r", encoding="utf-8") as file:
    text_data = file.read()
print(f"Loaded text data into RAM ({len(text_data)} characters)")

print("Initializing tokenizer...")
tokenizer = tiktoken.get_encoding("gpt2")
vocab_size = tokenizer.n_vocab
print(f"Tokenizer vocabulary size: {vocab_size}")

print("Encoding text data into token IDs (in RAM)...")
encoded_text = tokenizer.encode(text_data)
print("Converting token IDs to JAX array...")
encoded_text_jax = jnp.array(encoded_text, dtype=jnp.int32)
print(f"Encoded text stored as JAX array with shape: {encoded_text_jax.shape}")

print("Splitting data into train/validation sets...")
train_ratio = 0.90
split_idx = int(train_ratio * len(encoded_text_jax))
train_data = encoded_text_jax[:split_idx]
val_data = encoded_text_jax[split_idx:]
print(f"Training data shape: {train_data.shape}")
print(f"Validation data shape: {val_data.shape}")

def create_batches(data, batch_size, context_length, key):
    num_sequences = len(data) - context_length
    if num_sequences <= 0:
        raise ValueError("Dataset is too small for the given context length.")
    idxs = jax.random.permutation(key, num_sequences)
    num_batches = num_sequences // batch_size
    for i in range(num_batches):
        batch_idxs = idxs[i * batch_size:(i + 1) * batch_size]
        x_batch = jnp.stack([data[idx : idx + context_length] for idx in batch_idxs])
        y_batch = jnp.stack([data[idx + 1 : idx + context_length + 1] for idx in batch_idxs])
        yield x_batch, y_batch



--- Section 2: Data Preparation ---
File the-verdict.txt already exists.
Reading data from disk into RAM...
Loaded text data into RAM (20479 characters)
Initializing tokenizer...
Tokenizer vocabulary size: 50257
Encoding text data into token IDs (in RAM)...
Converting token IDs to JAX array...
Encoded text stored as JAX array with shape: (5145,)
Splitting data into train/validation sets...
Training data shape: (4630,)
Validation data shape: (515,)


In [13]:
# -----------------------------------------------------------------------------
# Section 3: Model Configuration
# -----------------------------------------------------------------------------
print("\n--- Section 3: Model Configuration ---")
config = {
    "vocab_size": vocab_size,
    "context_length": 128,
    "emb_dim": 128,
    "n_heads": 8,
    "n_layers": 6,
    "qkv_bias": False,
    "batch_size": 16,
    "gradient_accumulation_steps": 8,
    "compute_dtype": jnp.bfloat16,
    "param_dtype": jnp.float32,
    "output_head_dtype": jnp.float32,
}
effective_batch_size = config["batch_size"] * config["gradient_accumulation_steps"]
print(f"Micro-batch size: {config['batch_size']}")
print(f"Gradient Accumulation Steps: {config['gradient_accumulation_steps']}")
print(f"Effective Batch Size: {effective_batch_size}")
print(f"Compute dtype: {config['compute_dtype']}")
print(f"Parameter dtype: {config['param_dtype']}")
print(f"Output Head dtype: {config['output_head_dtype']}")

compute_dtype = config["compute_dtype"]
param_dtype = config["param_dtype"]
output_head_dtype = config["output_head_dtype"]

data_key = random.PRNGKey(0)
batch_generator_demo = create_batches(train_data, config["batch_size"], config["context_length"], data_key)
x_example, y_example = next(batch_generator_demo)
print("\nExample Input Batch Shape:", x_example.shape)
print("Example Target Batch Shape:", y_example.shape)


--- Section 3: Model Configuration ---
Micro-batch size: 16
Gradient Accumulation Steps: 8
Effective Batch Size: 128
Compute dtype: <class 'jax.numpy.bfloat16'>
Parameter dtype: <class 'jax.numpy.float32'>
Output Head dtype: <class 'jax.numpy.float32'>

Example Input Batch Shape: (16, 128)
Example Target Batch Shape: (16, 128)


In [14]:
# -----------------------------------------------------------------------------
# Section 4: Transformer Components
# -----------------------------------------------------------------------------
print("\n--- Section 4: Transformer Components ---")

class TokenAndPositionalEmbedding(nn.Module):
    vocab_size: int
    embed_dim: int
    context_length: int
    compute_dtype: jnp.dtype = compute_dtype
    param_dtype: jnp.dtype = param_dtype

    def setup(self):
        self.tok_emb = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim, param_dtype=self.param_dtype)
        self.pos_emb = nn.Embed(num_embeddings=self.context_length, features=self.embed_dim, param_dtype=self.param_dtype)

    def __call__(self, x):
        seq_len = x.shape[1]
        token_embeddings = self.tok_emb(x)
        positions = jnp.arange(seq_len)
        position_embeddings = self.pos_emb(positions)
        combined_embeddings = token_embeddings + position_embeddings
        return combined_embeddings.astype(self.compute_dtype)

class MultiHeadCausalSelfAttention(nn.Module):
    embed_dim: int
    num_heads: int
    use_bias: bool = False
    compute_dtype: jnp.dtype = compute_dtype
    param_dtype: jnp.dtype = param_dtype

    def setup(self):
        assert self.embed_dim % self.num_heads == 0, "Embed dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // self.num_heads
        self.q_proj = nn.Dense(features=self.embed_dim, use_bias=self.use_bias, name="query_proj", dtype=self.compute_dtype, param_dtype=self.param_dtype)
        self.k_proj = nn.Dense(features=self.embed_dim, use_bias=self.use_bias, name="key_proj", dtype=self.compute_dtype, param_dtype=self.param_dtype)
        self.v_proj = nn.Dense(features=self.embed_dim, use_bias=self.use_bias, name="value_proj", dtype=self.compute_dtype, param_dtype=self.param_dtype)
        self.out_proj = nn.Dense(features=self.embed_dim, use_bias=self.use_bias, name="output_proj", dtype=self.compute_dtype, param_dtype=self.param_dtype)

    def __call__(self, x):
        batch_size, seq_len, _ = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
        k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
        v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)

        attn_scores = jnp.matmul(q, jnp.transpose(k, (0, 1, 3, 2)))
        scale_factor = jnp.sqrt(self.head_dim).astype(self.compute_dtype)
        attn_scores = attn_scores / scale_factor

        mask = nn.make_causal_mask(jnp.ones((batch_size, seq_len)), dtype=jnp.bool_)
        attn_scores = jnp.where(mask, attn_scores, -jnp.inf)

        attn_weights = jax.nn.softmax(attn_scores.astype(jnp.float32), axis=-1)
        attn_weights = attn_weights.astype(self.compute_dtype)

        context_vec = jnp.matmul(attn_weights, v)
        context_combined = rearrange(context_vec, 'b h s d -> b s (h d)')
        output = self.out_proj(context_combined)
        return output

class FeedForward(nn.Module):
    embed_dim: int
    compute_dtype: jnp.dtype = compute_dtype
    param_dtype: jnp.dtype = param_dtype

    @nn.compact
    def __call__(self, x):
        hidden_dim = 4 * self.embed_dim
        x = nn.Dense(features=hidden_dim, dtype=self.compute_dtype, param_dtype=self.param_dtype)(x)
        x = nn.gelu(x)
        x = nn.Dense(features=self.embed_dim, dtype=self.compute_dtype, param_dtype=self.param_dtype)(x)
        return x

@nn.remat # Keep remat for memory saving
class TransformerBlock(nn.Module):
    embed_dim: int
    num_heads: int
    use_bias: bool
    compute_dtype: jnp.dtype = compute_dtype
    param_dtype: jnp.dtype = param_dtype


    @nn.compact
    def __call__(self, x):
        attn_input = nn.LayerNorm(epsilon=1e-5, param_dtype=self.param_dtype)(x)
        attn_output = MultiHeadCausalSelfAttention(
            embed_dim=self.embed_dim, num_heads=self.num_heads, use_bias=self.use_bias,
            compute_dtype=self.compute_dtype, param_dtype=self.param_dtype
        )(attn_input)
        x = x + attn_output

        ffn_input = nn.LayerNorm(epsilon=1e-5, param_dtype=self.param_dtype)(x)
        ffn_output = FeedForward(
            embed_dim=self.embed_dim,
            compute_dtype=self.compute_dtype, param_dtype=self.param_dtype
        )(ffn_input)
        x = x + ffn_output
        return x


--- Section 4: Transformer Components ---


In [15]:
# -----------------------------------------------------------------------------
# Section 5: GPT Model Architecture
# -----------------------------------------------------------------------------
print("\n--- Section 5: GPT Model Architecture (Manual MP + Remat) ---")

class GPT(nn.Module):
    """A Generative Pre-trained Transformer model."""
    vocab_size: int = config["vocab_size"]        # Size of the vocabulary.
    embed_dim: int = config["emb_dim"]            # Dimension of the token embeddings.
    context_length: int = config["context_length"]  # Maximum sequence length.
    num_heads: int = config["n_heads"]            # Number of attention heads.
    num_layers: int = config["n_layers"]          # Number of transformer blocks.
    use_bias: bool = config["qkv_bias"]           # Whether to use bias in QKV projections.
    compute_dtype: jnp.dtype = config["compute_dtype"]  # Data type for computation.
    param_dtype: jnp.dtype = config["param_dtype"]    # Data type for parameters.
    output_head_dtype: jnp.dtype = config["output_head_dtype"] # Data type for the final output layer.

    @nn.compact
    def __call__(self, idx):
        """
        Forward pass of the GPT model.

        Args:
            idx: Input tensor of token indices, shape (batch_size, sequence_length).

        Returns:
            Logits tensor, shape (batch_size, sequence_length, vocab_size).
        """
        # Generate token and positional embeddings.
        x = TokenAndPositionalEmbedding(
            vocab_size=self.vocab_size, embed_dim=self.embed_dim, context_length=self.context_length,
            compute_dtype=self.compute_dtype, param_dtype=self.param_dtype, name="embedding"
        )(idx)

        # Pass the embeddings through the transformer blocks.
        for i in range(self.num_layers):
            # Apply a single transformer block.
            x = TransformerBlock(
                embed_dim=self.embed_dim, num_heads=self.num_heads, use_bias=self.use_bias,
                compute_dtype=self.compute_dtype,
                param_dtype=self.param_dtype, name=f"transformer_block_{i}"
            )(x)

        # Apply final layer normalization.
        x = nn.LayerNorm(epsilon=1e-5, name="final_ln", param_dtype=self.param_dtype)(x)

        # Project the final hidden states to vocabulary size to get logits.
        logits = nn.Dense(
            features=self.vocab_size, use_bias=False, name="output_projection",
            dtype=self.compute_dtype, param_dtype=self.param_dtype
        )(x)
        # Cast logits to the specified output data type.
        logits = logits.astype(self.output_head_dtype)
        return logits


--- Section 5: GPT Model Architecture (Manual MP + Remat) ---


In [16]:
# -----------------------------------------------------------------------------
# Section 6: Training Setup
# -----------------------------------------------------------------------------
print("\n--- Section 6: Training Setup (Manual MP + Accum, No Dropout) ---")

class TrainStateWithAccum(train_state.TrainState):
    accum_grads: flax.core.FrozenDict


learning_rate = 1e-4
tx = optax.adamw(learning_rate=learning_rate, weight_decay=0.1)


model_key, params_key = random.split(random.PRNGKey(123), 2)
model = GPT()
dummy_input = jnp.ones((1, config["context_length"]), dtype=jnp.int32)
params = model.init(params_key, dummy_input)['params']


print("Verifying parameter dtypes (should be float32):")
jax.tree_util.tree_map(lambda p: print(f" - Param shape: {p.shape}, dtype: {p.dtype}"), params)

zero_grads = jax.tree_util.tree_map(
    lambda p: jnp.zeros_like(p, dtype=config["param_dtype"]), params
)

state = TrainStateWithAccum.create(
    apply_fn=model.apply, params=params, tx=tx, accum_grads=zero_grads
)
state = jax.device_put(state, device)

param_count = sum(p.size for p in jax.tree_util.tree_leaves(state.params))
print(f"\nModel initialized with {param_count:,} parameters.")
print(f"Accumulated gradients dtype: {jax.tree_util.tree_leaves(state.accum_grads)[0].dtype}")

@functools.partial(jax.jit)
def cross_entropy_loss(logits, targets):
    one_hot_targets = jax.nn.one_hot(targets, num_classes=logits.shape[-1])
    log_softmax_logits = jax.nn.log_softmax(logits.astype(config["output_head_dtype"]), axis=-1)
    loss_per_position = -jnp.sum(one_hot_targets * log_softmax_logits, axis=-1)
    return jnp.mean(loss_per_position)


@functools.partial(jax.jit, static_argnames=['model_apply'])
def compute_grads_step(state, batch, model_apply):
    x, y = batch
    def compute_loss(params):

        logits = model_apply({'params': params}, x)
        loss = cross_entropy_loss(logits, y)
        return loss
    grad_fn = jax.value_and_grad(compute_loss)
    loss, grads = grad_fn(state.params)
    metrics = {'loss': loss}
    return grads, metrics

@functools.partial(jax.jit, static_argnames=['learning_rate_fn'])
def apply_grads_step(state, accumulated_grads, learning_rate_fn):
    avg_grads = jax.tree_util.tree_map(lambda g: g / config["gradient_accumulation_steps"], accumulated_grads)
    new_state = state.apply_gradients(grads=avg_grads)
    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
    new_state = new_state.replace(accum_grads=zero_grads)
    lr = learning_rate_fn(state.step)
    metrics = {'learning_rate': lr}
    return new_state, metrics

@functools.partial(jax.jit, static_argnames=['model_apply'])
def eval_step(state, batch, model_apply):
    x, y = batch
    logits = model_apply({'params': state.params}, x)
    loss = cross_entropy_loss(logits, y)
    return {'loss': loss}


--- Section 6: Training Setup (Manual MP + Accum, No Dropout) ---
Verifying parameter dtypes (should be float32):
 - Param shape: (128, 128), dtype: float32
 - Param shape: (50257, 128), dtype: float32
 - Param shape: (128,), dtype: float32
 - Param shape: (128,), dtype: float32
 - Param shape: (128, 50257), dtype: float32
 - Param shape: (512,), dtype: float32
 - Param shape: (128, 512), dtype: float32
 - Param shape: (128,), dtype: float32
 - Param shape: (512, 128), dtype: float32
 - Param shape: (128,), dtype: float32
 - Param shape: (128,), dtype: float32
 - Param shape: (128,), dtype: float32
 - Param shape: (128,), dtype: float32
 - Param shape: (128, 128), dtype: float32
 - Param shape: (128, 128), dtype: float32
 - Param shape: (128, 128), dtype: float32
 - Param shape: (128, 128), dtype: float32
 - Param shape: (512,), dtype: float32
 - Param shape: (128, 512), dtype: float32
 - Param shape: (128,), dtype: float32
 - Param shape: (512, 128), dtype: float32
 - Param shape: (1

In [17]:
# -----------------------------------------------------------------------------
# Section 7: Training Loop
# -----------------------------------------------------------------------------
print("\n--- Section 7: Training Loop (Manual MP + Accum + Remat) ---")

num_epochs = 1
eval_frequency = 200
accumulation_steps = config["gradient_accumulation_steps"]

train_key = random.PRNGKey(42)

print(f"Starting training for {num_epochs} epochs...")
print(f"Compute dtype: {config['compute_dtype']}, Param dtype: {config['param_dtype']}")
print(f"Gradient accumulation steps: {accumulation_steps}")
print(f"Evaluation frequency (optimizer steps): {eval_frequency}")
print("Gradient Checkpointing (Rematerialization) is ENABLED for Transformer Blocks.")
print("Dropout is DISABLED.")

global_step_counter = 0
optimizer_step_counter = 0

for epoch in range(num_epochs):
    print(f"\n-- Epoch {epoch+1}/{num_epochs} --")
    epoch_key, train_key = random.split(train_key)
    batch_generator = create_batches(
        train_data, config["batch_size"], config["context_length"], epoch_key
    )
    num_sequences = len(train_data) - config["context_length"]
    total_micro_batches_per_epoch = num_sequences // config["batch_size"]

    pbar = tqdm(
        enumerate(batch_generator), total=total_micro_batches_per_epoch,
        desc=f"Epoch {epoch+1} Training (Micro-batches)"
    )

    interval_loss = 0.0
    num_loss_samples = 0

    for step, train_batch in pbar:
        global_step_counter += 1
        train_batch = jax.device_put(train_batch, device)
        grads, compute_metrics = compute_grads_step(
            state, train_batch, model.apply
        )

        state = state.replace(
            accum_grads=jax.tree_util.tree_map(lambda acc, new: acc + new, state.accum_grads, grads)
        )

        interval_loss += compute_metrics['loss']
        num_loss_samples += 1

        if global_step_counter % accumulation_steps == 0:
            optimizer_step_counter += 1

            new_state, apply_metrics = apply_grads_step(
                state, state.accum_grads, lambda step: learning_rate
            )
            state = new_state

            if optimizer_step_counter % eval_frequency == 0:
                avg_train_loss = jax.device_get(interval_loss / num_loss_samples) if num_loss_samples > 0 else 0.0
                interval_loss, num_loss_samples = 0.0, 0

                val_loss = 0.0
                num_val_batches = 0
                val_key, train_key = random.split(train_key)
                val_batch_generator = create_batches(
                    val_data, config["batch_size"], config["context_length"], val_key
                )
                for val_batch in val_batch_generator:
                    val_batch = jax.device_put(val_batch, device)
                    eval_metrics = eval_step(state, val_batch, model.apply)
                    val_loss += eval_metrics['loss']
                    num_val_batches += 1

                avg_val_loss = jax.device_get(val_loss / num_val_batches) if num_val_batches > 0 else 0.0

                pbar.set_postfix(OptStep=optimizer_step_counter, TrainLoss=f"{avg_train_loss:.4f}", ValLoss=f"{avg_val_loss:.4f}", LR=f"{apply_metrics['learning_rate']:.1e}")

    pbar.close()
    print(f"-- Epoch {epoch+1} finished --")

print("\nTraining complete.")


--- Section 7: Training Loop (Manual MP + Accum + Remat, No Dropout) ---
Starting training for 1 epochs...
Compute dtype: <class 'jax.numpy.bfloat16'>, Param dtype: <class 'jax.numpy.float32'>
Gradient accumulation steps: 8
Evaluation frequency (optimizer steps): 200
Gradient Checkpointing (Rematerialization) is ENABLED for Transformer Blocks.
Dropout is DISABLED.

-- Epoch 1/1 --


Epoch 1 Training (Micro-batches):   0%|          | 0/281 [00:00<?, ?it/s]

-- Epoch 1 finished --

Training complete.
