In [None]:
# !pip install torch numpy tqdm wandb

In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


## Phase 1: Model Hyperparameters

In [None]:
# Global hyperparameters shared by the model-building cells below.
vocab_size   = 50257    # matches our tokenizer
block_size   = 128      # context window (you can grow later)
d_model      = 256      # model “width” (embedding size); must be divisible by n_heads
n_heads      = 4        # number of attention heads
n_layers     = 4        # number of transformer blocks
dropout_rate = 0.1      # small dropout to regularize

## 🚧 Phase 2: Core Components

We’ll implement each in PyTorch:

1. **Token + Positional Embeddings**
    - `nn.Embedding(vocab_size, d_model)`
    - `nn.Embedding(block_size, d_model)`
    - Sum them.
2. **Multi‑Head Self‑Attention**
    - Linear projections for Q, K, V: each `d_model → d_model`
    - Split into `n_heads` (i.e. reshape to `(batch, heads, seq, d_head)`)
    - Scaled dot‑product with causal mask
    - Concat heads → final projection
3. **Feed‑Forward Network (FFN)**
    - Two linear layers:
        - `d_model → 4·d_model`
        - Activation (GELU)
        - `4·d_model → d_model`
4. **Transformer Block**
    - Pre‑LayerNorm
    - Attention + residual
    - Pre‑LayerNorm
    - FFN + residual
5. **Stack of Blocks + Final LayerNorm**
    - Repeat `n_layers`
    
    - Final `LayerNorm(d_model)`
6. **Output Head (LM Head)**
    - Tie weights with token embedding: project final hidden state back to `vocab_size` logits.

1. Token + Positional Embeddings

In [None]:
class TokenPositionalEmbedding(nn.Module):
    """Learned token embedding plus learned absolute-position embedding, then dropout.

    Args:
        vocab_size: number of rows in the token embedding table.
        d_model: width of every embedding vector.
        block_size: number of rows in the position table, i.e. the longest
            sequence the module accepts.
        dropout_rate: dropout probability applied to the summed embeddings.
            Defaults to 0.1, the value that was previously hard-coded.
    """
    def __init__(self, vocab_size, d_model, block_size, dropout_rate=0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed   = nn.Embedding(block_size, d_model)
        self.dropout     = nn.Dropout(dropout_rate)  # GPT-2 uses dropout after embedding

    def forward(self, x):
        """Embed a batch of token IDs.

        Args:
            x: integer tensor of shape (B, T) with T <= block_size.

        Returns:
            Tensor of shape (B, T, d_model).

        Raises:
            AssertionError: if T exceeds the size of the position table.
        """
        B, T = x.size()  # (batch_size, sequence_length)
        # The position table only has block_size rows; fail with a readable message
        # instead of an opaque index error from the embedding lookup.
        assert T <= self.pos_embed.num_embeddings, (
            f"Sequence length T={T} exceeds block_size={self.pos_embed.num_embeddings}"
        )
        tok_emb = self.token_embed(x)                    # (B, T, d_model)
        pos_ids = torch.arange(T, device=x.device)       # (T,)
        pos_emb = self.pos_embed(pos_ids)[None, :, :]    # (1, T, d_model), broadcast over batch
        out = tok_emb + pos_emb                          # (B, T, d_model)
        return self.dropout(out)


Example Usage

In [None]:
# Smoke test for the embedding layer: feed random token IDs, inspect the output shape.
vocab_size, d_model, block_size = 50257, 256, 128

embed_layer = TokenPositionalEmbedding(
    vocab_size=vocab_size,
    d_model=d_model,
    block_size=block_size,
)

# A batch of 4 random sequences, each exactly block_size tokens long.
dummy_input = torch.randint(low=0, high=vocab_size, size=(4, block_size))
output = embed_layer(dummy_input)

print(output.shape)  # → torch.Size([4, 128, 256])

2. Multi‑Head Self‑Attention

In [None]:
import torch
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position sees only itself and earlier positions.

    Args:
        d_model: width of the input/output vectors; must be divisible by n_heads.
        n_heads: number of attention heads (each of width d_model // n_heads).
        dropout_rate: probability for the single dropout layer, which is applied
            both to the attention weights and to the projected output.
    """

    def __init__(self, d_model, n_heads, dropout_rate=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Query / key / value projections, then the projection applied to the
        # re-joined heads.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        x: (B, T, d_model)  batch of embeddings
        returns: (B, T, d_model) same shape, after self-attention
        """
        batch, seq_len, width = x.size()
        assert width == self.n_heads * self.d_head

        def split_heads(projected):
            # (B, T, d_model) -> (B, n_heads, T, d_head)
            return projected.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        # Scaled dot-product scores: (B, nh, T, dh) x (B, nh, dh, T) -> (B, nh, T, T)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_head ** 0.5)

        # True strictly above the diagonal == "key lies in the future of the query".
        # The mask is rebuilt on every call so that it always matches T.
        future = torch.ones((seq_len, seq_len), device=x.device, dtype=torch.bool).triu(diagonal=1)
        scores = scores.masked_fill(future[None, None, :, :], float('-inf'))

        # Normalise over keys, then regularise the attention weights.
        weights = self.dropout(torch.softmax(scores, dim=-1))

        # Weighted sum of values: (B, nh, T, T) x (B, nh, T, dh) -> (B, nh, T, dh)
        mixed = torch.matmul(weights, v)

        # Re-join the heads into one d_model-wide vector per position and project.
        merged = mixed.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.dropout(self.out_proj(merged))

In [None]:
# Toy example: a width of 12 split across 3 heads (so each head is 4 wide).
# The toy sizes live in their own `toy_`-prefixed names so that this demo does NOT
# overwrite the notebook-wide `d_model` / `n_heads` hyperparameters, which later
# cells read when they build the real model.
toy_d_model = 12
toy_n_heads = 3
toy_d_head = toy_d_model // toy_n_heads

# Create a dummy q_proj just to inspect its weight
toy_q_proj = nn.Linear(toy_d_model, toy_d_model)
toy_W = toy_q_proj.weight.detach()  # shape: (12, 12)

# Split the weight into 3 heads of size 4×12 each: consecutive groups of output rows
# belong to the same head, which is what the view(..., n_heads, d_head) in the
# attention layer relies on.
toy_heads = toy_W.view(toy_n_heads, toy_d_head, toy_d_model)

for i, W_i in enumerate(toy_heads):
    print(f"Head {i} weight shape: {W_i.shape}")
    # This W_i is exactly the 4×12 matrix that each head uses
    # This W_i is exactly the 4×12 matrix that each head uses


3. Feed‑Forward Network (FFN)

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4*d_model, apply GELU, narrow back, then dropout.

    Args:
        d_model: input and output width.
        dropout_rate: probability of the dropout applied to the output.
    """

    def __init__(self, d_model, dropout_rate=0.1):
        super().__init__()
        hidden_dim = 4 * d_model
        # Built in this order so the Sequential indices (0: expand, 2: contract)
        # stay the same.
        layers = [
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, d_model),
            nn.Dropout(dropout_rate),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP independently to every position; the shape is unchanged."""
        return self.net(x)

4. Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention and feed-forward, each with a residual.

    Args:
        d_model: width of the residual stream.
        n_heads: number of attention heads.
        block_size: accepted as part of the constructor signature; this class
            does not read it.
        dropout_rate: dropout probability passed to both sub-layers.
    """

    def __init__(self, d_model, n_heads, block_size, dropout_rate=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout_rate=dropout_rate)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, dropout_rate=dropout_rate)

    def forward(self, x):
        """x: (B, T, d_model) -> (B, T, d_model)."""
        # Each sub-layer sees a normalised copy; its output is added back onto the
        # un-normalised residual stream.
        after_attn = x + self.attn(self.ln1(x))
        after_ffn = after_attn + self.ffn(self.ln2(after_attn))
        return after_ffn

In [None]:
class GPT2Model(nn.Module):
    """Decoder-only GPT-2-style language model.

    Pipeline: token + positional embeddings -> n_layers transformer blocks ->
    final LayerNorm -> linear head whose weight is shared with the token embedding.

    Args:
        vocab_size: number of distinct token IDs.
        block_size: longest sequence the positional table supports.
        d_model: width of the residual stream.
        n_heads: attention heads per block.
        n_layers: number of transformer blocks.
        dropout: dropout probability used by every dropout layer in the model,
            including the one after the embeddings.
    """
    def __init__(self, vocab_size, block_size, d_model, n_heads, n_layers, dropout=0.1):
        super().__init__()
        # 1) Embeddings
        self.token_pos_embed = TokenPositionalEmbedding(vocab_size, d_model, block_size)
        # The embedding module is constructed without a dropout argument here, so
        # set its dropout layer explicitly; otherwise `dropout` would be ignored
        # for the embedding stage (e.g. dropout=0.0 would still drop activations).
        self.token_pos_embed.dropout = nn.Dropout(dropout)

        # 2) Stacked Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, block_size, dropout)
            for _ in range(n_layers)
        ])

        # 3) Final layer norm
        self.ln_f = nn.LayerNorm(d_model)

        # 4) LM head (tie weights to token embedding): both modules hold the very
        #    same Parameter object, so it is stored and updated only once.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.token_pos_embed.token_embed.weight

    def forward(self, idx):
        """
        idx: (B, T) token IDs
        returns:
          logits: (B, T, vocab_size)
        """
        # Embedding
        x = self.token_pos_embed(idx)  # (B, T, d_model)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)               # (B, T, d_model)

        # Final norm
        x = self.ln_f(x)              # (B, T, d_model)

        # LM head to vocab logits
        logits = self.lm_head(x)      # (B, T, vocab_size)
        return logits


## Phase 3: Training

In [None]:
# In Colab
!wget -O tiny_shakespeare.txt \
     https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt


In [None]:
# In a Colab code cell
import glob

# Read the entire file into one big string
with open("tiny_shakespeare.txt", "r", encoding="utf-8") as f:
    text_data = f.read()
print(f"Loaded {len(text_data):,} characters.")

### Tokenizer

In [None]:
# pip install tiktoken


In [None]:
import tiktoken

# Load GPT-2 tokenizer
enc = tiktoken.get_encoding("gpt2")

text = "Hello world! How are you?"

# Encode to token IDs
token_ids = enc.encode(text)
print(token_ids)

# Decode back
decoded_text = enc.decode(token_ids)
print(decoded_text)

# Special tokens are not included by default in GPT-2 encoding


In [None]:
print("Vocab size:", enc.n_vocab)  # 50257 = 50256 regular tokens + 1 special token
# The end-of-text marker is the special token "<|endoftext|>". encode() refuses
# special-token text unless it is explicitly allowed, and any other spelling
# (e.g. " endoftext") is tokenized as ordinary text rather than as ID 50256.
print("Tokens:", enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))  # [50256]


1. Prepare the training text

Tokenize your full text

In [None]:
tokens = enc.encode(text_data)
print(f"Total tokens: {len(tokens)}")

 1: Save or Cache the Tokenized Data (Optional but recommended)

In [None]:
import numpy as np

tokens_np = np.array(tokens, dtype=np.uint16)  # or uint32 if vocab is larger
np.save("tokens.npy", tokens_np)


In [None]:
tokens = np.load("tokens.npy")


2: Split into Train and Validation

In [None]:
split_ratio = 0.9  # 90% train, 10% validation
n = int(len(tokens) * split_ratio)
train_tokens = tokens[:n]
val_tokens = tokens[n:]
print(f"Train: {len(train_tokens)} tokens, Val: {len(val_tokens)} tokens")

3: Define a Dataset Class (for PyTorch)

In [None]:
import torch

class GPTDataset(torch.utils.data.Dataset):
    """Sliding-window next-token dataset over one long token stream.

    Item ``i`` is the pair ``(x, y)`` where ``x = data[i : i + block_size]`` and
    ``y`` is the same window shifted one token to the right, so ``y[t]`` is the
    token that follows ``x[t]``.

    Args:
        data: 1-D sequence of token IDs (list, NumPy array of any integer dtype,
            or tensor).
        block_size: number of tokens in every input/target window.

    Raises:
        ValueError: if block_size is not positive.
    """
    def __init__(self, data, block_size):
        import numpy as np  # local import keeps this cell runnable on its own

        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")

        if isinstance(data, torch.Tensor):
            self.data = data.detach().clone().to(torch.long)
        else:
            # Convert through an int64 NumPy copy first: the cached token file is
            # uint16, which torch.tensor() cannot ingest on older PyTorch releases.
            self.data = torch.from_numpy(np.array(data, dtype=np.int64))
        self.block_size = block_size

    def __len__(self):
        # A window needs block_size inputs plus one extra token for the last target.
        # Clamp at zero so a too-short stream yields an empty dataset instead of a
        # negative length (which makes len() raise).
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        n_items = len(self)
        if idx < 0:
            idx += n_items
        # Raising IndexError is what lets plain `for x, y in dataset` terminate;
        # slicing past the end would otherwise return ever-shorter tensors forever.
        if not 0 <= idx < n_items:
            raise IndexError(f"index {idx} out of range for dataset of length {n_items}")
        x = self.data[idx : idx + self.block_size]
        y = self.data[idx + 1 : idx + 1 + self.block_size]
        return x, y


4: Create Dataloaders

In [None]:
block_size = 128  # or 256, 512, depending on memory
batch_size = 32

train_dataset = GPTDataset(train_tokens, block_size)
val_dataset = GPTDataset(val_tokens, block_size)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)


5: Define the GPT Model

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

In [None]:
model = GPT2Model(
    vocab_size=vocab_size,
    block_size=128,
    d_model=256,
    n_heads=4,
    n_layers=4,
    dropout=0.1
).to(device)


 6: Set up the Loss and Optimizer

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()


7: Training Loop

In [None]:
# Minimal training loop: one optimizer step per batch, no validation.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for x, y in train_loader:
        # Inputs and targets must live on the same device as the model.
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()

        logits = model(x)  # (B, T, vocab_size)
        # CrossEntropyLoss wants (N, classes) vs (N,), so fold batch and time together.
        n_classes = logits.size(-1)
        loss = criterion(logits.view(-1, n_classes), y.view(-1))

        loss.backward()
        optimizer.step()

    # Reports the loss of the LAST batch of the epoch, not an epoch average.
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

In [None]:
# Training loop with progress bars, per-epoch validation and per-epoch checkpoints.
# It keeps using the existing `model` / `optimizer` objects, so running it after the
# minimal loop above continues training rather than starting from scratch.
from tqdm.auto import tqdm  # imported once, not on every epoch

num_epochs = 10

for epoch in range(num_epochs):
    # ── Training ────────────────────────────────────────────────────
    model.train()  # enable dropout
    total_train_loss = 0.0

    train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Train)")

    for x, y in train_loop:
        x = x.to(device)  # Move input to the same device as the model
        y = y.to(device)  # Move target to the same device as the model

        # Forward pass
        logits = model(x)  # (B, T, vocab_size)

        # CrossEntropyLoss expects (N, vocab_size) logits and (N,) targets,
        # so batch and sequence dimensions are flattened together.
        loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))

        # Backward pass and optimize
        optimizer.zero_grad()  # Clear previous gradients
        loss.backward()        # Compute gradients
        optimizer.step()       # Update model parameters

        total_train_loss += loss.item()
        train_loop.set_postfix(loss=loss.item())

    # Mean of the per-batch losses.
    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Training Loss: {avg_train_loss:.4f}")

    # ── Validation ──────────────────────────────────────────────────
    model.eval()  # disable dropout
    total_val_loss = 0.0
    with torch.no_grad():  # no gradients needed for evaluation
        val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Validation)")
        for x, y in val_loop:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            total_val_loss += loss.item()
            val_loop.set_postfix(loss=loss.item())
    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Validation Loss: {avg_val_loss:.4f}")

    # One checkpoint file per epoch (weights only).
    torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pt")

print("\nTraining finished.")

8: Save the Trained Model

In [None]:
torch.save(model.state_dict(), "babygpt2_model.pt")


Let's list the files in the current directory to check that `babygpt2_model.pt` (and the per-epoch checkpoints) were written.


In [None]:
!ls -lh

9: Inference (Text Generation)

In [None]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, block_size=None):
    """Autoregressively sample ``max_new_tokens`` tokens after the prompt ``idx``.

    Args:
        model: callable mapping (B, T) token IDs to (B, T, vocab_size) logits.
            It is switched to eval mode and left there.
        idx: (B, T) integer tensor holding the prompt.
        max_new_tokens: number of tokens to append.
        block_size: longest context the model accepts. When omitted it is read
            from ``model.token_pos_embed.pos_embed.num_embeddings`` if that
            attribute chain exists (also looking through a ``.module`` wrapper);
            if it cannot be determined, the full sequence is fed to the model.

    Returns:
        (B, T + max_new_tokens) tensor: the prompt followed by the sampled tokens.
    """
    model.eval()

    if block_size is None:
        inner = getattr(model, "module", model)  # look through DataParallel-style wrappers
        pos_table = getattr(getattr(inner, "token_pos_embed", None), "pos_embed", None)
        block_size = getattr(pos_table, "num_embeddings", None)

    for _ in range(max_new_tokens):
        # Feed at most block_size trailing tokens: the positional table has no
        # entries beyond that, so a longer input would be rejected by the model.
        context = idx if block_size is None else idx[:, -block_size:]
        logits = model(context)[:, -1, :]  # only the prediction for the next token
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
        idx = torch.cat((idx, next_token), dim=1)
    return idx


In [None]:
# Generation demo: prompt text -> token IDs -> sampled continuation -> text.
model.to(device)  # the model must sit on the same device as its input

prompt = "To be or not to be, that is the question:"

# Uses the tokenizer `enc` created earlier in the notebook.
token_ids = enc.encode(prompt)

# Shape (1, T): a batch holding the single prompt, created directly on `device`.
x = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)

# Sample 50 additional tokens after the prompt.
max_new_tokens = 50
generated_tokens = generate(model, x, max_new_tokens)

# decode() takes a plain list of ints, so pull the only row out of the batch.
generated_text = enc.decode(generated_tokens[0].tolist())

print(generated_text)

So, while the foundational coding work is largely complete, the training and evaluation phases are essential for having a practical, working language model. You've built the engine and the car, but it still needs to be driven and tuned to perform well!



---



---



---

# extras

## adding wandb logging, early stopping, resume from checkpoint, etc.

All those extras—Weights & Biases logging, early stopping, checkpointing, schedulers, gradient clipping—might look like “boilerplate,” but each serves a clear purpose in real-world model training:

This block of code is your enhanced training pipeline. You already have a basic loop that updates the model weights; this version adds the following to help you train more effectively and manage experiments:

**Experiment Tracking (wandb):**

*   Logs your hyperparameters, training/validation losses, learning rates, and other metrics in real time.
*   Makes it trivial to compare runs, visualize curves, and share results with collaborators.

**Checkpointing & Resume:**

*   Saves the “best” model so far (lowest validation loss) to disk.
*   If your notebook or Colab crashes—or you simply want to stop and come back—you can resume training exactly where you left off, without losing days of work.

**Early Stopping:**

*   Stops training when validation loss hasn’t improved for a set number of epochs (patience).
*   Prevents over-training (overfitting) and saves compute/energy by not wasting epochs once the model has plateaued.

**Learning Rate Scheduler (ReduceLROnPlateau):**

*   Automatically reduces the learning rate when validation loss stops improving.
*   Helps the optimizer make finer adjustments later in training, often leading to lower final loss.

**Gradient Clipping:**

*   Caps the gradient norm to avoid “exploding gradients” which can destabilize training, especially in deep or large-step models.

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import wandb

# ─── 1) CONFIGURATION ────────────────────────────────────────────────
# Everything a run depends on lives in one dict so wandb records it verbatim.
config = {
    "vocab_size":     vocab_size,    # e.g. 50257
    "block_size":     block_size,    # e.g. 128
    "d_model":        d_model,       # e.g. 256
    "n_heads":        n_heads,       # e.g. 4
    "n_layers":       n_layers,      # e.g. 4
    "dropout":        dropout_rate,  # e.g. 0.1
    "learning_rate":  3e-4,
    "batch_size":     32,
    "epochs":         20,
    "patience":       3,             # for early stopping
    "checkpoint_dir": "./checkpoints",
    "project_name":   "baby-gpt2",
}
os.makedirs(config["checkpoint_dir"], exist_ok=True)

# ─── 2) WANDB SETUP ─────────────────────────────────────────────────
wandb.init(project=config["project_name"], config=config)

# ─── 3) MODEL, OPTIMIZER, SCHEDULER, CRITERION ───────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2Model(
    vocab_size=config["vocab_size"],
    block_size=config["block_size"],
    d_model=config["d_model"],
    n_heads=config["n_heads"],
    n_layers=config["n_layers"],
    dropout=config["dropout"]
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=config["learning_rate"])
# Halve the LR after more than one epoch without validation improvement.
# `verbose` is not passed: it is deprecated in recent PyTorch, and the current
# LR is logged to wandb every epoch anyway.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=1
)
criterion = nn.CrossEntropyLoss()

# ─── 4) DATA LOADERS ─────────────────────────────────────────────────
# Reuse the sliding-window dataset and the train/val token split defined earlier
# in the notebook.
train_dataset = GPTDataset(train_tokens, config["block_size"])
val_dataset   = GPTDataset(val_tokens,   config["block_size"])
train_loader  = torch.utils.data.DataLoader(
    train_dataset, batch_size=config["batch_size"], shuffle=True
)
val_loader    = torch.utils.data.DataLoader(
    val_dataset, batch_size=config["batch_size"], shuffle=False
)

# ─── 5) OPTIONAL: RESUME FROM CHECKPOINT ─────────────────────────────
start_epoch = 0
best_val_loss = float("inf")
stop_counter = 0

latest_ckpt = os.path.join(config["checkpoint_dir"], "best.pt")
if os.path.isfile(latest_ckpt):
    checkpoint = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Checkpoints written before scheduler state was saved lack this key.
    if "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["best_val_loss"]
    print(f"Resumed from epoch {start_epoch}, best_val_loss={best_val_loss:.4f}")

# ─── 6) TRAINING + VALIDATION LOOP ─────────────────────────────────
for epoch in range(start_epoch, config["epochs"]):
    # — Training —
    model.train()
    train_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} [train]")
    for xb, yb in pbar:
        xb, yb = xb.to(device), yb.to(device)

        optimizer.zero_grad()
        logits = model(xb)  # (B, T, V)
        B, T, V = logits.size()
        # Flatten batch and time so the loss sees (N, V) logits and (N,) targets.
        loss = criterion(logits.view(B*T, V), yb.view(B*T))

        loss.backward()
        # Cap the global gradient norm before the update to avoid exploding steps.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        train_loss += loss.item()
        pbar.set_postfix(train_loss=loss.item())

    avg_train_loss = train_loss / len(train_loader)

    # — Validation —
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in tqdm(val_loader, desc=f"Epoch {epoch+1} [val]"):
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            B, T, V = logits.size()
            loss = criterion(logits.view(B*T, V), yb.view(B*T))
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)

    # — Scheduler step & best-checkpoint bookkeeping —
    scheduler.step(avg_val_loss)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        stop_counter = 0
        # Save best checkpoint (includes scheduler state so a resume keeps the LR schedule)
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "best_val_loss": best_val_loss,
        }, latest_ckpt)
        print(f"New best model at epoch {epoch+1}, val_loss={avg_val_loss:.4f}")
    else:
        stop_counter += 1
        print(f"No improvement for {stop_counter} epoch(s). Best={best_val_loss:.4f}")

    # — Logging to wandb —
    # Logged BEFORE the early-stopping check so the final epoch is recorded too.
    wandb.log({
        "epoch": epoch+1,
        "train_loss": avg_train_loss,
        "val_loss": avg_val_loss,
        "lr": optimizer.param_groups[0]["lr"],
    })

    # — Early stopping —
    if stop_counter >= config["patience"]:
        print("Early stopping triggered.")
        break

print("Training complete.")


## multi‑GPU training

Here’s the simplest way to get multi‑GPU training going in your notebook: PyTorch’s `nn.DataParallel`.

1. Wrap Your Model in DataParallel

In [None]:
# Build the model on the primary device, then replicate it if several GPUs exist.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = GPT2Model(vocab_size, block_size, d_model, n_heads, n_layers, dropout_rate)
model = model.to(device)

# multi‑GPU wrapper: only wrap when there are at least two GPUs
if torch.cuda.device_count() >= 2:
    model = nn.DataParallel(model)

2. Adjust Checkpointing & Resume

When saving or loading state dicts, remember that under DataParallel, your model’s weights live in model.module:

In [None]:
# Saving
# Always save the underlying model's state dict so the keys carry no "module."
# prefix and the checkpoint loads with or without DataParallel.
to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save({
    "epoch": epoch,
    "model_state_dict": to_save.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "best_val_loss": best_val_loss,
    # Add other states like scheduler state if needed
}, latest_ckpt)

# Loading
# map_location lets a checkpoint written on GPU load on whatever device is in use.
chk = torch.load(latest_ckpt, map_location=device)
# Decide by whether the model is actually wrapped, not by how many GPUs exist:
# an unwrapped model on a multi-GPU machine has no `.module` attribute.
target = model.module if isinstance(model, nn.DataParallel) else model
target.load_state_dict(chk["model_state_dict"])


3. No Changes Needed to DataLoader

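DataParallel automatically splits the input batch `xb` along the batch dimension, runs one model replica per GPU, and gathers the resulting logits back on the default device, where the loss is computed against the full `yb`. You do not need to change how you construct your DataLoader.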

4. Full Example Snippet

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm

# ─── 1) CONFIG ────────────────────────────────────────────────────────
config = {
    "vocab_size":     vocab_size,
    "block_size":     block_size,
    "d_model":        d_model,
    "n_heads":        n_heads,
    "n_layers":       n_layers,
    "dropout":        dropout_rate,
    "learning_rate":  3e-4,
    "batch_size":     32,
    "epochs":         20,
    "patience":       3,             # epochs without improvement before stopping
    "checkpoint_dir": "./checkpoints",
}
os.makedirs(config["checkpoint_dir"], exist_ok=True)

# ─── 2) MODEL + MULTI‑GPU WRAP ─────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2Model(
    vocab_size=config["vocab_size"],
    block_size=config["block_size"],
    d_model=config["d_model"],
    n_heads=config["n_heads"],
    n_layers=config["n_layers"],
    dropout=config["dropout"]
).to(device)

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

# ─── 3) OPTIMIZER, SCHEDULER, CRITERION ────────────────────────────────
optimizer = optim.AdamW(model.parameters(), lr=config["learning_rate"])
# Halve the LR after more than one epoch without validation improvement.
# `verbose` is not passed because it is deprecated in recent PyTorch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=1
)
criterion = nn.CrossEntropyLoss()

# ─── 4) DATA LOADERS ───────────────────────────────────────────────────
# Reuse the sliding-window dataset and the train/val token split defined earlier
# in the notebook.
train_dataset = GPTDataset(train_tokens, config["block_size"])
val_dataset   = GPTDataset(val_tokens,   config["block_size"])
train_loader  = torch.utils.data.DataLoader(
    train_dataset, batch_size=config["batch_size"], shuffle=True
)
val_loader    = torch.utils.data.DataLoader(
    val_dataset,   batch_size=config["batch_size"], shuffle=False
)

# ─── 5) OPTIONAL: RESUME FROM CHECKPOINT ──────────────────────────────
start_epoch = 0
best_val_loss = float("inf")
stop_counter = 0
ckpt_path = os.path.join(config["checkpoint_dir"], "best.pt")

if os.path.isfile(ckpt_path):
    chk = torch.load(ckpt_path, map_location=device)
    # If wrapped in DataParallel, the real model sits under .module
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(chk["model_state_dict"])
    optimizer.load_state_dict(chk["optimizer_state_dict"])
    # Checkpoints written before scheduler state was saved lack this key.
    if "scheduler_state_dict" in chk:
        scheduler.load_state_dict(chk["scheduler_state_dict"])
    start_epoch = chk["epoch"] + 1
    best_val_loss = chk["best_val_loss"]
    print(f"Resumed from epoch {start_epoch}, best_val_loss={best_val_loss:.4f}")

# ─── 6) TRAIN + VALIDATE LOOP ─────────────────────────────────────────
for epoch in range(start_epoch, config["epochs"]):
    # — Training —
    model.train()
    train_loss = 0.0
    for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch+1} [train]"):
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()

        logits = model(xb)                     # (B, T, V)
        B, T, V = logits.size()
        # Flatten batch and time so the loss sees (N, V) logits and (N,) targets.
        loss = criterion(logits.view(B*T, V), yb.view(B*T))
        loss.backward()

        # Cap the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        train_loss += loss.item()
    avg_train = train_loss / len(train_loader)

    # — Validation —
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in tqdm(val_loader, desc=f"Epoch {epoch+1} [val]"):
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            B, T, V = logits.size()
            loss = criterion(logits.view(B*T, V), yb.view(B*T))
            val_loss += loss.item()
    avg_val = val_loss / len(val_loader)

    # — Scheduler, Checkpointing —
    scheduler.step(avg_val)
    if avg_val < best_val_loss:
        best_val_loss = avg_val
        stop_counter = 0
        # Save best: store the unwrapped model so keys have no "module." prefix.
        to_save = model.module if isinstance(model, nn.DataParallel) else model
        torch.save({
            "epoch": epoch,
            "model_state_dict": to_save.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "best_val_loss": best_val_loss,
        }, ckpt_path)
        print(f"[Epoch {epoch+1}] New best val loss: {avg_val:.4f}")
    else:
        stop_counter += 1
        print(f"[Epoch {epoch+1}] No improvement ({stop_counter}/{config['patience']})")

    # Printed BEFORE the early-stopping check so the final epoch is reported too.
    print(f"Epoch {epoch+1} | Train: {avg_train:.4f} | Val: {avg_val:.4f}")

    # — Early Stopping —
    if stop_counter >= config["patience"]:
        print("Early stopping triggered.")
        break

print("Training complete.")


# ✅ Build the GPT class

In [None]:
import torch
import torch.nn as nn

class GPTConfig:
    """Configuration for the GPT model.

    Attributes:
        vocab_size: number of distinct token IDs.
        block_size: longest sequence the model accepts.
        d_model: width of the residual stream.
        n_heads: attention heads per transformer block.
        n_layers: number of transformer blocks.
        dropout: dropout probability used throughout the model.
    """
    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,
        dropout: float = 0.1
    ):
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.d_model    = d_model
        self.n_heads    = n_heads
        self.n_layers   = n_layers
        self.dropout    = dropout

class GPT(nn.Module):
    """
    GPT‐2–style model:
      • Token + Positional Embeddings
      • N Transformer blocks (pre-LN, causal self-attn, FFN)
      • Final LayerNorm
      • Tied LM Head
    """
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        # 1) Embeddings
        self.token_embed = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embed   = nn.Embedding(config.block_size, config.d_model)
        self.dropout     = nn.Dropout(config.dropout)

        # 2) Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=config.d_model,
                n_heads=config.n_heads,
                block_size=config.block_size,
                # TransformerBlock's keyword is `dropout_rate`; passing `dropout=`
                # raises TypeError.
                dropout_rate=config.dropout
            )
            for _ in range(config.n_layers)
        ])

        # 3) Final layer norm
        self.ln_f = nn.LayerNorm(config.d_model)

        # 4) Language‑model head (tied to token embeddings)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # tie weights: both modules hold the same Parameter object
        self.lm_head.weight = self.token_embed.weight

        # ensure everything is initialized properly
        self._init_weights()

    def _init_weights(self):
        """Re-initialise all sub-modules: N(0, 0.02) weights, zero biases, identity LayerNorm."""
        # GPT‑style initialization (following OpenAI GPT-2)
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                nn.init.zeros_(module.bias)
                nn.init.ones_(module.weight)

    def forward(self, idx: torch.LongTensor) -> torch.FloatTensor:
        """
        Input:
          idx: (B, T) token IDs
        Output:
          logits: (B, T, vocab_size)
        Raises:
          AssertionError if T exceeds config.block_size.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, \
            f"Sequence length {T} exceeds block_size {self.config.block_size}"

        # Embeddings
        token_embeddings = self.token_embed(idx)               # (B, T, d_model)
        positions = torch.arange(T, device=idx.device)        # (T,)
        pos_embeddings = self.pos_embed(positions)            # (T, d_model)
        x = token_embeddings + pos_embeddings.unsqueeze(0)    # (B, T, d_model)
        x = self.dropout(x)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)                                       # (B, T, d_model)

        # Final norm
        x = self.ln_f(x)                                       # (B, T, d_model)

        # Language modeling head
        logits = self.lm_head(x)                               # (B, T, vocab_size)
        return logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Greedily append ``max_new_tokens`` tokens to the prompt ``idx``.

        Args:
            idx: (B, T) token IDs used as the prompt.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: prompt followed by the argmax tokens.
        """
        # The limit lives on the config; nn.Module has no `block_size` attribute.
        block_size = self.config.block_size

        # Decode with dropout disabled so argmax is deterministic, then put the
        # module back into whatever mode the caller had it in.
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -block_size:]  # keep at most block_size trailing tokens
                logits = self(idx_cond)
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                idx = torch.cat([idx, next_token], dim=1)
        finally:
            self.train(was_training)
        return idx



How to instantiate and use

In [None]:
# 1) Create a config
cfg = GPTConfig(
    vocab_size=50257,
    block_size=128,
    d_model=256,
    n_heads=4,
    n_layers=4,
    # dropout=0.1
)

# 2) Instantiate the model
model = GPT(cfg).to(device)

# 3) Forward pass
input_ids = torch.randint(0, cfg.vocab_size, (2, 64), device=device)  # batch_size=2, seq_len=64
logits = model(input_ids)  # → (2, 64, cfg.vocab_size)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Re-defining TokenPositionalEmbedding
class TokenPositionalEmbedding(nn.Module):
    """
    Sum of a token embedding and a learned positional embedding, followed by dropout.

    Args:
        vocab_size:   number of rows in the token embedding table
        d_model:      embedding width
        block_size:   number of rows in the positional table (longest sequence accepted)
        dropout_rate: dropout probability applied to the summed embeddings
                      (defaults to 0.1, the value that used to be hard-coded)
    """
    def __init__(self, vocab_size, d_model, block_size, dropout_rate=0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed   = nn.Embedding(block_size, d_model)
        # Configurable like the other modules in this notebook, which all take dropout_rate
        self.dropout     = nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        x: (B, T) token IDs
        returns: (B, T, d_model) embeddings after dropout
        """
        B, T = x.size()  # (batch_size, sequence_length)
        # sanity check: the positional table only has block_size rows
        assert T <= self.pos_embed.num_embeddings, (
            f"Sequence length T={T} exceeds block_size={self.pos_embed.num_embeddings}"
        )
        tok_emb = self.token_embed(x)                    # (B, T, d_model)
        pos_ids = torch.arange(T, device=x.device)       # (T,)
        pos_emb = self.pos_embed(pos_ids)[None, :, :]    # (1, T, d_model), broadcast over batch
        out = tok_emb + pos_emb                          # (B, T, d_model)
        return self.dropout(out)

# Re-defining CausalSelfAttention
class CausalSelfAttention(nn.Module):
    """
    Multi-head self-attention in which each position attends only to itself
    and to earlier positions.

    Args:
        d_model:      model width; must be a multiple of n_heads
        n_heads:      number of attention heads
        dropout_rate: dropout probability used on the attention weights and on
                      the projected output
    """

    def __init__(self, d_model, n_heads, dropout_rate=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Separate query/key/value projections, then the projection that
        # maps the concatenated heads back to d_model.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # One dropout module shared by both dropout sites in forward()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        x: (B, T, d_model)  batch of embeddings
        returns: (B, T, d_model) same shape, after self-attention
        """
        batch, seq_len, width = x.size()
        assert width == self.n_heads * self.d_head

        def split_heads(t):
            # (B, T, d_model) -> (B, nh, T, dh)
            return t.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        queries = split_heads(self.q_proj(x))
        keys = split_heads(self.k_proj(x))
        values = split_heads(self.v_proj(x))

        # Scaled dot-product scores: (B, nh, T, dh) @ (B, nh, dh, T) -> (B, nh, T, T)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.d_head ** 0.5)

        # Lower-triangular "visible" mask sized to the current sequence; every
        # entry outside it (a future position) is set to -inf before softmax.
        visible = torch.ones((seq_len, seq_len), device=x.device, dtype=torch.bool).tril()
        scores = scores.masked_fill(visible.logical_not()[None, None, :, :], float('-inf'))

        weights = self.dropout(scores.softmax(dim=-1))

        # Weighted sum of values: (B, nh, T, T) @ (B, nh, T, dh) -> (B, nh, T, dh)
        mixed = torch.matmul(weights, values)

        # Merge the heads back into (B, T, d_model) and project
        merged = mixed.transpose(1, 2).reshape(batch, seq_len, width)
        return self.dropout(self.out_proj(merged))

# Re-defining FeedForward
class FeedForward(nn.Module):
    """
    Position-wise MLP: expand to 4 * d_model, apply GELU, project back to
    d_model, then apply dropout.
    """

    def __init__(self, d_model, dropout_rate=0.1):
        super().__init__()
        hidden_width = 4 * d_model
        layers = [
            nn.Linear(d_model, hidden_width),   # expansion
            nn.GELU(),
            nn.Linear(hidden_width, d_model),   # contraction
            nn.Dropout(dropout_rate),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of x; the shape is unchanged."""
        out = self.net(x)
        return out

# Re-defining TransformerBlock
class TransformerBlock(nn.Module):
    """
    Pre-LayerNorm transformer block: a causal self-attention sub-layer followed
    by a feed-forward sub-layer, each wrapped in a residual connection.

    `block_size` is accepted for call-site compatibility but is not used here.
    """

    def __init__(self, d_model, n_heads, block_size, dropout_rate=0.1):
        super().__init__()
        # Attention sub-layer and its input norm
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout_rate=dropout_rate)
        # Feed-forward sub-layer and its input norm
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, dropout_rate=dropout_rate)

    def forward(self, x):
        """x: (B, T, d_model) -> (B, T, d_model)."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        transformed = self.ffn(self.ln2(x))
        return x + transformed

# Re-defining GPTConfig
class GPTConfig:
    """
    Plain container for the GPT model's hyperparameters.

    Every constructor argument is stored unchanged on an attribute of the same
    name; no validation is performed.
    """
    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,
        dropout: float = 0.1
    ):
        # Sizes of the embedding tables
        self.vocab_size, self.block_size = vocab_size, block_size
        # Shape of the transformer stack
        self.d_model, self.n_heads, self.n_layers = d_model, n_heads, n_layers
        # Regularization
        self.dropout = dropout

# Re-defining GPT
class GPT(nn.Module):
    """
    GPT‐2–style model:
      • Token + Positional Embeddings
      • N Transformer blocks (pre-LN, causal self-attn, FFN)
      • Final LayerNorm
      • Tied LM Head

    Args:
        config: object exposing vocab_size, block_size, d_model, n_heads,
                n_layers and dropout (see GPTConfig).
    """
    def __init__(self, config: "GPTConfig"):
        super().__init__()
        self.config = config

        # 1) Embeddings
        self.token_embed = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embed   = nn.Embedding(config.block_size, config.d_model)
        self.dropout     = nn.Dropout(config.dropout)

        # 2) Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=config.d_model,
                n_heads=config.n_heads,
                block_size=config.block_size,
                dropout_rate=config.dropout
            )
            for _ in range(config.n_layers)
        ])

        # 3) Final layer norm
        self.ln_f = nn.LayerNorm(config.d_model)

        # 4) Language‑model head (tied to token embeddings): both modules share
        #    one Parameter, so updating either updates both.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_embed.weight

        # ensure everything is initialized properly
        self._init_weights()

    def _init_weights(self):
        """Initialize Linear/Embedding weights from N(0, 0.02); reset LayerNorm to identity."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                nn.init.zeros_(module.bias)
                nn.init.ones_(module.weight)

    def forward(self, idx: torch.LongTensor) -> torch.FloatTensor:
        """
        Input:
          idx: (B, T) token IDs, with T <= config.block_size
        Output:
          logits: (B, T, vocab_size)
        """
        B, T = idx.size()
        assert T <= self.config.block_size, \
            f"Sequence length {T} exceeds block_size {self.config.block_size}"

        token_embeddings = self.token_embed(idx)               # (B, T, d_model)
        positions = torch.arange(T, device=idx.device)        # (T,)
        pos_embeddings = self.pos_embed(positions)            # (T, d_model)
        x = token_embeddings + pos_embeddings.unsqueeze(0)    # (B, T, d_model)
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)                               # (B, T, vocab_size)
        return logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """
        Greedily append `max_new_tokens` tokens to `idx`.

        Input:
          idx: (B, T) token IDs used as the prompt
          max_new_tokens: number of decoding steps
        Output:
          (B, T + max_new_tokens) token IDs

        Decoding runs in eval mode so dropout cannot perturb the argmax; the
        module's previous train/eval mode is restored before returning.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Keep only the most recent block_size tokens; forward() rejects longer input
                idx_cond = idx[:, -self.config.block_size:]
                logits = self(idx_cond)
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                idx = torch.cat([idx, next_token], dim=1)
        finally:
            self.train(was_training)
        return idx


# --- Model Instantiation ---
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Architecture hyperparameters (these must match the saved model)
vocab_size, block_size = 50257, 128
d_model, n_heads, n_layers = 256, 4, 4
dropout_rate = 0.1

# Bundle the hyperparameters into a GPTConfig
cfg = GPTConfig(vocab_size, block_size, d_model, n_heads, n_layers, dropout_rate)

# Build the model, then place it on the chosen device
model = GPT(cfg)
model = model.to(device)

print("GPT model instantiated successfully.")
# Total element count over everything returned by model.parameters()
print(f"Model parameters: {sum(map(torch.numel, model.parameters())):,} trainable parameters")

To swap in your new `GPT` class for `GPT2Model`, you just need to:

1.  Import/define the `GPTConfig` and `GPT` classes in your notebook (instead of `GPT2Model`).
2.  Instantiate using `cfg = GPTConfig(...)` and `model = GPT(cfg)` instead of the old call.

Everything else stays the same—your training loop, data loaders, optimizer, etc., all work unchanged because `GPT` and `GPT2Model` have the same `forward(idx) → logits` interface.

If you’re using multi‑GPU wrapping:

In [None]:
# Wrap the model for multi-GPU execution only when at least two GPUs are visible.
if torch.cuda.device_count() >= 2:
    model = nn.DataParallel(model)


Here is the code cell to **Load the Checkpoint**.

This assumes your checkpoint file (`babygpt2_model_final_checkpoint.pt`, the name used in the loading cell below) contains the `model_state_dict`, `optimizer_state_dict`, and `epoch` keys. Adjust `checkpoint_path` if your file has a different name or location.

Make sure your `GPTConfig`, `GPT`, and `AdamW` (optimizer) are defined in the notebook before running this.

In [None]:
import torch
import torch.optim as optim  # the optimizer object has to exist before its saved state can be loaded

# Architecture hyperparameters. They must describe the same architecture
# as the checkpoint that will be loaded into `loaded_model`.
vocab_size, block_size = 50257, 128
d_model, n_heads, n_layers = 256, 4, 4
dropout_rate = 0.1

# Bundle them into a GPTConfig and build the model on the target device
cfg = GPTConfig(vocab_size, block_size, d_model, n_heads, n_layers, dropout_rate)
loaded_model = GPT(cfg)
loaded_model = loaded_model.to(device)

# Same optimizer type as when the checkpoint was saved. The learning rate
# given here is only a placeholder until the saved optimizer state is loaded.
optimizer = optim.AdamW(params=loaded_model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

# Resume bookkeeping; stays at 0 unless a checkpoint supplies a later epoch.
start_epoch = 0
# Other saved values (e.g. best_val_loss) would be initialized here too:
# best_val_loss = float('inf')


In [None]:
# Load the checkpoint
import torch.serialization

# Resolve the device before it is used as `map_location` below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Add GPTConfig to allowed globals for safe loading (the checkpoint stores a GPTConfig under 'config')
torch.serialization.add_safe_globals([GPTConfig])

# Define the path to your comprehensive checkpoint file
checkpoint_path = "babygpt2_model_final_checkpoint.pt"
# checkpoint_path = "babygpt2_model_final.pt"

# Tracks whether loading actually succeeded so the messages below stay truthful.
checkpoint_loaded = False

try:
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Load model state
    loaded_model.load_state_dict(checkpoint['model_state_dict'])

    # Load optimizer state
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Load other states if saved
    start_epoch = checkpoint['epoch'] + 1 # Start from the next epoch
    # If you saved best_val_loss:
    # best_val_loss = checkpoint['best_val_loss']
    # If you saved scheduler state:
    # scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    checkpoint_loaded = True
    print(f"Resuming training from epoch {start_epoch}.")
    # If you loaded best_val_loss:
    # print(f"Previous best validation loss: {best_val_loss:.4f}")

except FileNotFoundError:
    print(f"Error: Checkpoint file not found at {checkpoint_path}. Please check the path.")
    # Without a checkpoint, `start_epoch` keeps the value set in the setup cell.

except KeyError as e:
    print(f"Error loading checkpoint: Missing key {e}. Make sure the checkpoint dictionary structure matches.")

except Exception as e:
    print(f"An error occurred while loading the checkpoint: {e}")


# Make sure the model sits on the chosen device.
loaded_model.to(device)

if checkpoint_loaded:
    print(f"Model and optimizer loaded, ready to resume training on {device}.")
else:
    print(f"Checkpoint was NOT loaded; loaded_model still has its current weights (on {device}).")

# Now you can use the 'loaded_model' for inference or other tasks.
# Text generation needs the tokenizer ('enc'), which must already be defined in the notebook.
loaded_model.eval()  # turn dropout off for generation; the training loop calls .train() again
prompt = "Hello, how are you?"
token_ids = enc.encode(prompt)
input_tensor = torch.tensor([token_ids], dtype=torch.long).to(device)
generated_output = loaded_model.generate(input_tensor, max_new_tokens=100)
generated_text = enc.decode(generated_output[0].tolist())
print(generated_text)

In [None]:
# Rebind `model` to the checkpoint-restored network so that later cells which
# operate on `model` (e.g. the training loop) use it.
model = loaded_model

Here is the code cell for the **Training Loop**.

This loop will start from the `start_epoch` loaded from the checkpoint and continue for the specified number of `num_epochs`.

Make sure your `train_loader`, `val_loader`, and `criterion` are defined before running this.

In [None]:
import torch
from tqdm.notebook import tqdm # Assuming you want progress bars

# Requires model, optimizer, criterion, cfg, start_epoch, train_loader, val_loader and device
# to be defined by earlier cells (the loading cell and your data loading cells).
# NOTE: `start_epoch` is deliberately NOT reset here. The setup cell initializes it to 0 and
# the checkpoint-loading cell overwrites it; resetting it would discard the resume point.

num_epochs = 10 # Total number of epochs to train for, counting epochs already completed.
                # If the checkpoint gave start_epoch == 5, this trains epochs 5, 6, 7, 8, 9.

# --- Start or resume the training loop ---
for epoch in range(start_epoch, num_epochs):
    # --- Training ---
    model.train()  # Set the model to training mode
    total_train_loss = 0

    # Use tqdm for a progress bar for the training loader
    train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Train)")

    for x, y in train_loop:
        x = x.to(device)  # Move input to the same device as the model
        y = y.to(device)  # Move target to the same device as the model

        # Forward pass
        logits = model(x)  # (B, T, vocab_size)

        # Flatten to (B*T, vocab_size) and (B*T,) as CrossEntropyLoss expects
        loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))

        # Backward pass and optimize
        optimizer.zero_grad()  # Clear previous gradients
        loss.backward()        # Compute gradients
        optimizer.step()       # Update model parameters

        total_train_loss += loss.item()
        train_loop.set_postfix(loss=loss.item())

    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Training Loss: {avg_train_loss:.4f}")

    # --- Validation ---
    model.eval() # Set the model to evaluation mode
    total_val_loss = 0
    with torch.no_grad(): # Disable gradient calculation for validation
        val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Validation)")
        for x, y in val_loop:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            total_val_loss += loss.item()
            val_loop.set_postfix(loss=loss.item())

    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Validation Loss: {avg_val_loss:.4f}")

    # --- Checkpointing ---
    # 'epoch' is the zero-based index of the epoch just completed; the loading cells
    # resume from checkpoint['epoch'] + 1. 'loss' is the epoch's average training loss
    # (at this point the `loss` variable only holds the last validation batch's loss).
    epoch_checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_train_loss,
        'val_loss': avg_val_loss,
        # Save other metrics if you want to resume early stopping
        # 'best_val_loss': best_val_loss,
    }
    torch.save(epoch_checkpoint, f'gpt2_checkpoint_epoch_{epoch+1}.pt')

    # Also refresh the "latest" files every epoch, so an interrupted run still
    # leaves usable weights and a resumable checkpoint behind.
    torch.save(model.state_dict(), "babygpt2_model_final.pt")

    final_checkpoint = {
        **epoch_checkpoint,
        'config': cfg,
        # Optional:
        # 'scheduler_state_dict': scheduler.state_dict(), # Save the scheduler's state
    }
    torch.save(final_checkpoint, "babygpt2_model_final_checkpoint.pt")


print("\nTraining finished.")


In [None]:
# Save the final model after training is complete.
# 1) Just the model weights (enough for inference/evaluation):
torch.save(model.state_dict(), "babygpt2_model_final.pt")

# 2) A comprehensive checkpoint (enough to resume training):
final_checkpoint = {
    # Zero-based index of the last completed epoch. The loading cells resume from
    # checkpoint['epoch'] + 1, and the per-epoch checkpoints use the same convention,
    # so storing num_epochs itself would make a resumed run skip one epoch.
    'epoch': num_epochs - 1,
    'model_state_dict': model.state_dict(), # Save the model's weights and biases
    'optimizer_state_dict': optimizer.state_dict(), # Save the optimizer's state
    'loss': loss.item(), # Most recently computed batch loss (or store an average instead)
    'config': cfg, # Requires GPTConfig to be registered as a safe global when loading
    # Optional:
    # 'scheduler_state_dict': scheduler.state_dict(), # Save the scheduler's state
    # 'best_val_loss': best_val_loss, # Save the best validation loss
}
torch.save(final_checkpoint, "babygpt2_model_final_checkpoint.pt")

# there are 2 ways to save the model

1. `torch.save(model.state_dict(), "babygpt2_model_final.pt")`

2. `torch.save(final_checkpoint, "babygpt2_model_final_checkpoint.pt")`

### 1. Loading from `babygpt2_model_final.pt` (Model State Dictionary only)

This method is suitable if you only need the model's learned weights for inference or evaluation.

In [None]:
import torch

# Ensure GPTConfig and GPT classes are defined in your notebook, and that `cfg`
# describes the same architecture that was saved, e.g.:
# cfg = GPTConfig(vocab_size=50257, block_size=128, d_model=256, n_heads=4, n_layers=4, dropout=0.1)

# Resolve the device before it is used as `map_location` below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define the path to the saved model state dictionary file
model_state_path = "babygpt2_model_final.pt"

# Instantiate the model (you need the architecture defined first)
loaded_model_state_only = GPT(cfg)

# Tracks whether the weights were actually loaded so the final message stays truthful.
state_loaded = False

# Load the state dictionary into the instantiated model
try:
    state_dict = torch.load(model_state_path, map_location=device)
    loaded_model_state_only.load_state_dict(state_dict)
    state_loaded = True
    print(f"Model state dictionary loaded successfully from {model_state_path}")
except FileNotFoundError:
    print(f"Error: Model state file not found at {model_state_path}. Please check the path.")
except Exception as e:
    print(f"An error occurred while loading the model state dictionary: {e}")

# Move the model to the desired device and switch to eval mode so dropout is
# disabled during inference/evaluation.
loaded_model_state_only.to(device)
loaded_model_state_only.eval()

if state_loaded:
    print(f"Model loaded for inference/evaluation and moved to {device}.")
else:
    print(f"Model state was NOT loaded; loaded_model_state_only has freshly initialized weights (on {device}).")

# You can now use loaded_model_state_only for inference
# For example: loaded_model_state_only.generate(...)

In [None]:
# Rebind `model` to the network whose weights were loaded from the state-dict file.
model = loaded_model_state_only

### 2. Loading from `babygpt2_model_final_checkpoint.pt` (Comprehensive Checkpoint)

This method is suitable if you need to resume training, as it includes the optimizer state, epoch, and potentially other training-related information.

In [None]:
import torch
import torch.optim as optim # Needed to instantiate optimizer before loading state
import torch.serialization # Import serialization module

# Ensure GPTConfig and GPT classes are defined in your notebook, and that `cfg`
# describes the same architecture that was saved, e.g.:
# cfg = GPTConfig(vocab_size=50257, block_size=128, d_model=256, n_heads=4, n_layers=4, dropout=0.1)

# Resolve the device first: it is needed both to place the model and as `map_location`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the model and move it to the device BEFORE creating the optimizer and loading
# its state. The optimizer state is loaded onto whatever device the parameters are on
# at that moment; moving the model afterwards would leave that state behind.
loaded_model_checkpoint = GPT(cfg).to(device)

# Instantiate the optimizer (must be the same type as when saved)
# A placeholder learning rate is fine here; the saved optimizer state is loaded below
optimizer_checkpoint = optim.AdamW(loaded_model_checkpoint.parameters(), lr=3e-4)

# Define the path to the comprehensive checkpoint file
checkpoint_file_path = "babygpt2_model_final_checkpoint.pt"

# Initialize variables for resuming (will be overwritten if checkpoint exists)
start_epoch_checkpoint = 0
# best_val_loss_checkpoint = float('inf') # Uncomment if you saved this

# Add GPTConfig to allowed globals for safe loading (the checkpoint stores a GPTConfig under 'config')
torch.serialization.add_safe_globals([GPTConfig])

# Tracks whether loading actually succeeded so the final message stays truthful.
checkpoint_loaded = False

# Load the checkpoint dictionary
try:
    checkpoint = torch.load(checkpoint_file_path, map_location=device)

    # Load model state
    loaded_model_checkpoint.load_state_dict(checkpoint['model_state_dict'])

    # Load optimizer state
    optimizer_checkpoint.load_state_dict(checkpoint['optimizer_state_dict'])

    # Load other states if saved
    start_epoch_checkpoint = checkpoint['epoch'] + 1
    # if 'best_val_loss' in checkpoint:
    #     best_val_loss_checkpoint = checkpoint['best_val_loss']
    # if 'scheduler_state_dict' in checkpoint and scheduler_checkpoint is not None:
    #      scheduler_checkpoint.load_state_dict(checkpoint['scheduler_state_dict'])
    # if 'config' in checkpoint:
    #      loaded_cfg = checkpoint['config'] # Load the saved config

    checkpoint_loaded = True
    print(f"Checkpoint loaded successfully from {checkpoint_file_path}")
    print(f"Ready to resume training from epoch {start_epoch_checkpoint}.")

except FileNotFoundError:
    print(f"Error: Checkpoint file not found at {checkpoint_file_path}. Please check the path.")
    # If no checkpoint, training will start from epoch 0 with initialized model/optimizer

except KeyError as e:
    print(f"Error loading checkpoint: Missing key {e}. Make sure the checkpoint dictionary structure matches.")

except Exception as e:
    print(f"An error occurred while loading the checkpoint: {e}")

if checkpoint_loaded:
    print(f"Model and optimizer loaded from checkpoint and moved to {device}.")
else:
    print(f"Checkpoint was NOT loaded; model and optimizer are freshly initialized (on {device}).")

# You can now use loaded_model_checkpoint and optimizer_checkpoint to resume training
# The next training loop should start from start_epoch_checkpoint