# 06. Training the Model

Now that we have our model architecture defined, we need to train it. This notebook covers the training pipeline, including:

1.  **Data Loading**: Creating a PyTorch `DataLoader`.
2.  **Optimization**: Setting up the optimizer (AdamW) and learning rate scheduler.
3.  **Training Loop**: The core loop that updates model weights.
4.  **Checkpointing**: Saving and loading model states.

We will use a synthetic dataset of random token IDs for demonstration purposes; the same pipeline applies to real tokenized text.

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import math
import os
import sys

# Try to import the actual model from src
sys.path.append(os.path.abspath('..'))

try:
    from src.model import DecoderLM, ModelConfig
    print("Successfully imported DecoderLM from src.model")
except ImportError:
    print("src.model not found. Using dummy class for demonstration.")
    from dataclasses import dataclass
    @dataclass
    class ModelConfig:
        n_embd: int = 128
        n_head: int = 4
        n_layer: int = 2
        n_positions: int = 128
        vocab_size: int = 1000
        dropout: float = 0.1
        bias: bool = True

    class DecoderLM(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Define parameters so the optimizer has something to optimize
            self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
            self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
            self.config = config
            
        def forward(self, x, targets=None):
            emb = self.token_emb(x)
            logits = self.lm_head(emb)
            loss = None
            if targets is not None:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            return logits, loss

Successfully imported DecoderLM from src.model


## 1. Data Loading

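We need a `Dataset` class that returns an input window and its target window, here as a dict with the keys `input_ids` and `labels`. For causal language modeling, the target is the input shifted by one token, so the model at position `t` is trained to predict token `t+1`.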

In [5]:
class TextDataset(Dataset):
    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        # Input:  tokens [idx, idx + seq_len)          (half-open range)
        # Target: tokens [idx + 1, idx + seq_len + 1)  (same window, shifted by one)
        chunk = self.data[idx : idx + self.seq_len + 1]
        x = chunk[:-1]
        y = chunk[1:]
        return {"input_ids": x, "labels": y}

# Create dummy data
vocab_size = 1000
data_size = 10000
seq_len = 32
dummy_data = torch.randint(0, vocab_size, (data_size,))

dataset = TextDataset(dummy_data, seq_len)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
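
To make the one-token shift concrete, here is a minimal sketch that uses a plain Python list in place of the token tensor. The names `demo_tokens`, `demo_seq_len`, and `make_pair` are hypothetical and exist only for illustration; the slicing is meant to mirror `TextDataset.__getitem__`.

```python
# Illustrative only: a plain list stands in for the token tensor.
demo_tokens = [10, 11, 12, 13, 14, 15, 16, 17]
demo_seq_len = 4

def make_pair(data, idx, seq_len):
    """Return (input, target) windows using the same slicing as TextDataset."""
    chunk = data[idx : idx + seq_len + 1]  # seq_len + 1 consecutive tokens
    return chunk[:-1], chunk[1:]           # target is the input shifted by one

x, y = make_pair(demo_tokens, 0, demo_seq_len)
print(x)  # [10, 11, 12, 13]
print(y)  # [11, 12, 13, 14]

# Number of valid start indices, matching TextDataset.__len__
num_windows = len(demo_tokens) - demo_seq_len
print(num_windows)  # 4
```

The last valid start index is `num_windows - 1`, whose chunk ends exactly at the final token, which is why `__len__` subtracts `seq_len` from the data length.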

## 2. Optimization

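We use the **AdamW** optimizer, which is standard for Transformers. Production runs typically pair it with a learning rate schedule that warms up and then follows a cosine decay; to keep this demonstration short, the cell below uses a simple `StepLR` schedule instead.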

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DecoderLM(ModelConfig()).to(device)

# AdamW updates every tensor returned by model.parameters()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

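# Simple scheduler for demonstration (StepLR): multiply the LR by 0.95 every 100 scheduler steps.
# In production, combine a warmup phase with cosine decay,
# e.g. LinearLR followed by CosineAnnealingLR, chained with SequentialLR.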
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)
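
The warmup-plus-cosine shape mentioned above can be sketched as a plain function that returns a multiplier for the base learning rate. This is an illustrative sketch, not the schedule used in this notebook: the function name `warmup_cosine_multiplier` and the parameters `warmup_steps`, `total_steps`, and `min_ratio` are assumptions made for the example.

```python
import math

def warmup_cosine_multiplier(step, warmup_steps, total_steps, min_ratio=0.1):
    """Multiplier for the base learning rate at a given step (hypothetical helper)."""
    if step < warmup_steps:
        # Linear warmup from 1/warmup_steps up to 1.0
        return (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1]
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    # Cosine goes from 1.0 (start of decay) down to 0.0 (end of decay)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    # Decay from 1.0 down to min_ratio instead of all the way to zero
    return min_ratio + (1.0 - min_ratio) * cosine

# Print the multiplier at a few steps of a 100-step run with 10 warmup steps
for step in (0, 9, 10, 55, 100):
    print(step, round(warmup_cosine_multiplier(step, 10, 100), 4))
```

With PyTorch, a function like this could be wrapped as `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda s: warmup_cosine_multiplier(s, warmup_steps, total_steps))`, with `scheduler.step()` called once after each `optimizer.step()`.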

## 3. Training Loop

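The loop below makes one pass over the data. Each step runs a forward pass, backpropagates the loss, clips the global gradient norm to 1.0, and applies an optimizer update.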

In [7]:
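def train_epoch(model, dataloader, optimizer, device, scheduler=None):
    """Run one pass over the dataloader and return the mean batch loss."""
    model.train()
    total_loss = 0.0

    pbar = tqdm(dataloader, desc="Training")
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        targets = batch['labels'].to(device)

        # Forward pass
        logits, loss = model(input_ids, targets)

        # Backward pass (clear stale gradients first)
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping: rescale so the global gradient norm is at most 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Update weights, then advance the LR schedule if one was given
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    return total_loss / len(dataloader)

# Run one epoch at a constant learning rate.
# Pass scheduler=scheduler to apply the StepLR schedule defined above.
avg_loss = train_epoch(model, dataloader, optimizer, device)
print(f"Average Loss: {avg_loss:.4f}")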

Training:   0%|          | 0/312 [00:00<?, ?it/s]

Average Loss: 6.6549
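
As a sanity check on the reported number, a model that assigns equal probability to every token has a cross-entropy of `ln(vocab_size)`. The sketch below compares the loss above against that baseline and converts it to perplexity. It assumes the loss is a natural-log cross-entropy, as `F.cross_entropy` computes; the names `demo_vocab_size` and `reported_loss` are hypothetical and simply copy values from this notebook.

```python
import math

demo_vocab_size = 1000   # same vocabulary size as the dummy data
reported_loss = 6.6549   # average loss printed above

# Cross-entropy of a uniform guess over the vocabulary: ln(1000) is about 6.9078
uniform_loss = math.log(demo_vocab_size)

# Perplexity is the exponential of the mean cross-entropy
perplexity = math.exp(reported_loss)

print(f"Uniform baseline loss: {uniform_loss:.4f}")
print(f"Perplexity at reported loss: {perplexity:.1f}")
print(f"Below uniform baseline: {reported_loss < uniform_loss}")
```

A loss near `ln(vocab_size)` after one epoch is what one would expect on random tokens, since there is little real structure to learn.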


## 4. Checkpointing

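Save checkpoints periodically so that training can resume after an interruption. A checkpoint should hold the optimizer state as well as the model weights, because AdamW keeps per-parameter moment estimates that would otherwise be reset on resume.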

In [8]:
def save_checkpoint(model, optimizer, epoch, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)
    print(f"Checkpoint saved to {path}")

save_checkpoint(model, optimizer, 0, "checkpoint.pt")

Checkpoint saved to checkpoint.pt
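
Loading is the mirror image of saving. The following is a minimal sketch of a `load_checkpoint` helper, which is not part of the original notebook; it assumes the dictionary keys written by `save_checkpoint` above. A tiny `nn.Linear` stands in for the real model, and the `demo_*` and `restored_*` names are hypothetical, chosen so the round trip can run on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn

def load_checkpoint(model, optimizer, path, device="cpu"):
    """Restore model and optimizer state in place; return the saved epoch."""
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]

# Round trip with a stand-in model (assumption: any nn.Module works the same way)
demo_model = nn.Linear(4, 2)
demo_opt = torch.optim.AdamW(demo_model.parameters(), lr=3e-4)
demo_path = os.path.join(tempfile.mkdtemp(), "demo_checkpoint.pt")
torch.save({
    "epoch": 3,
    "model_state_dict": demo_model.state_dict(),
    "optimizer_state_dict": demo_opt.state_dict(),
}, demo_path)

# A freshly initialized model has different weights until the checkpoint is loaded
restored_model = nn.Linear(4, 2)
restored_opt = torch.optim.AdamW(restored_model.parameters(), lr=3e-4)
start_epoch = load_checkpoint(restored_model, restored_opt, demo_path) + 1
print(f"Resuming from epoch {start_epoch}")
```

If a learning rate scheduler is in use, its `state_dict()` would need to be saved and restored the same way; otherwise the schedule restarts from step zero on resume.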
