# 🚀 Training Pipeline: 3-Phase Training Methodology

This notebook implements the complete training pipeline for our SLM. Training happens in three phases:

1. **Phase A: Base Pretraining** - Learn language at 2k context
2. **Phase B: Context Extension** - Extend to 4k-5k with RoPE scaling
3. **Phase C: Domain Fine-Tuning** - Shape reasoning (no labels/answers)

Each phase builds on the previous checkpoint.

---
## 1. Setup & Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from pathlib import Path
import json
import math
import time
from dataclasses import dataclass

# Select the compute device: CUDA when a GPU is visible, otherwise the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
print(f"Training on: {device}")

if _has_cuda:
    # Report which GPU is in use and its total memory (decimal GB).
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"Memory: {_total_bytes / 1e9:.1f} GB")

In [None]:
# Paths (relative to the notebook's working directory)
DATA_DIR = Path("../data/pre1986_training_streams_v1_FINAL")
TOKENIZER_PATH = Path("../tokenizer/tokenizer.json")
CHECKPOINT_DIR = Path("../checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)  # checkpoints are written here during training

# Load the tokenizer used to encode all training text.
tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))
print(f"✓ Loaded tokenizer with vocab size: {tokenizer.get_vocab_size()}")

---
## 2. Dataset: Text Streaming

We tokenize each text stream once and cut it into fixed-length, non-overlapping chunks; the DataLoader shuffles the chunk order. Documents are separated by `<EOS>` tokens.

In [None]:
class TextDataset(Dataset):
    """
    Dataset that splits one tokenized text file into fixed-length chunks.

    The whole file is read and tokenized once in ``__init__``; the token IDs
    are kept in memory as a single 1-D ``torch.long`` tensor. Item ``idx`` is
    the non-overlapping window that starts at token ``idx * seq_len``. The
    dataset itself is deterministic; random ordering comes from the
    DataLoader (e.g. ``shuffle=True``).

    Args:
        file_path: UTF-8 text file to load.
        tokenizer: Object whose ``encode(text)`` result exposes an ``.ids``
            list of token IDs (e.g. a ``tokenizers.Tokenizer``).
        seq_len: Number of input tokens per example; must be positive.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """

    def __init__(self, file_path: Path, tokenizer: "Tokenizer", seq_len: int):
        if seq_len <= 0:
            raise ValueError(f"seq_len must be positive, got {seq_len}")
        self.seq_len = seq_len

        # Load and tokenize the entire file
        print(f"Loading {file_path.name}...")
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()

        print(f"Tokenizing ({len(text):,} chars)...")
        encoding = tokenizer.encode(text)
        self.tokens = torch.tensor(encoding.ids, dtype=torch.long)

        # Each example needs seq_len + 1 tokens (inputs plus one extra for the
        # shifted target), so only len - 1 tokens can start a full window.
        # Clamp at 0 so an empty file yields an empty dataset rather than a
        # negative length (which makes len() raise).
        self.n_chunks = max(0, (len(self.tokens) - 1) // seq_len)
        print(f"Created {self.n_chunks:,} chunks of length {seq_len}")

    def __len__(self):
        return self.n_chunks

    def __getitem__(self, idx):
        """Return ``(x, y)``: ``seq_len`` input tokens and the same window shifted by one."""
        # Accept negative indices like a normal sequence; reject anything else
        # out of range instead of silently returning a short/empty chunk.
        if idx < 0:
            idx += self.n_chunks
        if not 0 <= idx < self.n_chunks:
            raise IndexError(f"chunk index out of range for dataset of {self.n_chunks} chunks")

        start = idx * self.seq_len
        chunk = self.tokens[start:start + self.seq_len + 1]  # +1 for the target

        x = chunk[:-1]  # input tokens
        y = chunk[1:]   # target tokens (next token at every position)
        return x, y

# Smoke test: build a small 128-token dataset from the base stream and
# check the shapes of its first (input, target) example.
_base_stream_path = DATA_DIR / "base_stream.txt"
test_dataset = TextDataset(_base_stream_path, tokenizer, seq_len=128)
x, y = test_dataset[0]
print(f"\nSample batch: input shape={x.shape}, target shape={y.shape}")

---
## 3. Training Configuration

In [None]:
@dataclass
class TrainingConfig:
    """All training hyperparameters in one place.

    Values are validated in ``__post_init__`` so a bad config fails when it is
    built, not partway through a long run. For example, ``save_interval=0``
    would otherwise raise ZeroDivisionError at the first ``step % save_interval``
    check in the training loop.

    Raises:
        ValueError: If a count/interval field is not positive or
            ``warmup_steps`` is negative.
    """

    # Batch settings
    batch_size: int = 8
    gradient_accumulation_steps: int = 4  # Effective batch = batch_size * this (32 by default)

    # Learning rate
    learning_rate: float = 3e-4
    weight_decay: float = 0.1
    warmup_steps: int = 1000

    # Training duration
    max_steps: int = 50000
    eval_interval: int = 500
    save_interval: int = 2000

    # Context length (varies by phase)
    seq_len: int = 2048

    def __post_init__(self):
        # Fields used as divisors, loop counts, or sizes: zero/negative is never valid.
        for name in (
            "batch_size",
            "gradient_accumulation_steps",
            "max_steps",
            "eval_interval",
            "save_interval",
            "seq_len",
        ):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        if self.warmup_steps < 0:
            raise ValueError(f"warmup_steps must be >= 0, got {self.warmup_steps}")

train_config = TrainingConfig()
# Samples contributing to each optimizer update = micro-batch size x accumulation steps.
_effective_batch = train_config.batch_size * train_config.gradient_accumulation_steps
print(f"Effective batch size: {_effective_batch}")

---
## 4. Learning Rate Scheduler

We use linear warmup followed by cosine decay down to 10% of the peak learning rate — standard practice for transformer training.

In [None]:
def get_lr(step: int, config: "TrainingConfig", min_lr_ratio: float = 0.1) -> float:
    """
    Learning rate schedule: linear warmup then cosine decay.

    Args:
        step: Zero-based index of the optimizer step about to be taken.
        config: Provides ``learning_rate`` (peak LR), ``warmup_steps`` and
            ``max_steps``.
        min_lr_ratio: Floor of the cosine decay as a fraction of the peak LR
            (default 0.1, i.e. decay to 10% of peak).

    Returns:
        The LR for ``step``. It ramps linearly to the peak over
        ``warmup_steps`` steps. It then follows a half cosine down to
        ``min_lr_ratio * learning_rate`` at ``max_steps`` and stays there.
    """
    peak_lr = config.learning_rate
    min_lr = peak_lr * min_lr_ratio

    # Warmup phase. Use (step + 1) so step 0 already gets a non-zero LR.
    # With `step / warmup_steps`, the first optimizer step would use LR 0
    # and leave the weights unchanged.
    if step < config.warmup_steps:
        return peak_lr * (step + 1) / config.warmup_steps

    # Cosine decay phase. Guard the denominator so a config with
    # max_steps <= warmup_steps does not divide by zero.
    decay_steps = max(1, config.max_steps - config.warmup_steps)
    decay_ratio = min((step - config.warmup_steps) / decay_steps, 1.0)

    # coeff goes 1 -> 0 as decay_ratio goes 0 -> 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (peak_lr - min_lr)

import matplotlib.pyplot as plt

# Evaluate the schedule at every step of the run and plot it, to eyeball the
# warmup ramp and the cosine tail before starting a long job.
steps = range(train_config.max_steps)
lrs = list(map(lambda step_idx: get_lr(step_idx, train_config), steps))

plt.figure(figsize=(10, 4))
plt.plot(steps, lrs)
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.title('LR Schedule: Warmup + Cosine Decay')
plt.show()

---
## 5. Training Loop

The core training function - handles gradient accumulation, logging, and checkpointing.

In [None]:
def train_step(model, batch, optimizer, config, step):
    """
    Run forward + backward on one micro-batch, accumulating gradients.

    The loss is divided by ``config.gradient_accumulation_steps`` before
    ``backward()``. Summing gradients over that many micro-batches then gives
    the gradient of their mean loss. The caller is responsible for
    ``optimizer.zero_grad()`` and ``optimizer.step()``.

    Args:
        model: Module mapping input token IDs to logits; the last logits
            dimension is treated as the vocabulary axis.
        batch: ``(x, y)`` pair of token-ID tensors (inputs, next-token targets).
        optimizer: Not used here; kept for interface compatibility.
        config: Provides ``gradient_accumulation_steps``.
        step: Not used here; kept for interface compatibility.

    Returns:
        The unscaled loss of this micro-batch as a Python float.
    """
    x, y = batch
    x, y = x.to(device), y.to(device)

    # Forward pass
    logits = model(x)  # presumably [batch, seq, vocab] - TODO confirm against SLM

    # Flatten to [N, vocab] / [N] for cross-entropy. `reshape` (unlike `view`)
    # also works if the model returns non-contiguous logits.
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        y.reshape(-1),
    )

    # Read the unscaled value first, rather than dividing and multiplying
    # back (which can introduce float rounding in the reported loss).
    loss_value = loss.item()

    # Scale for gradient accumulation, then accumulate gradients.
    (loss / config.gradient_accumulation_steps).backward()

    return loss_value


def train(model, train_loader, config, checkpoint_name="model"):
    """
    Main training loop: AdamW with an LR schedule and gradient accumulation.

    Runs ``config.max_steps`` optimizer steps. Each step:

    - accumulates gradients over ``config.gradient_accumulation_steps``
      micro-batches, restarting ``train_loader`` whenever it is exhausted;
    - clips the global gradient norm to 1.0;
    - applies one AdamW update.

    The mean training loss per step is printed every ``config.eval_interval``
    steps. A checkpoint is written to ``CHECKPOINT_DIR`` every
    ``config.save_interval`` steps.

    Args:
        model: Module to train; moved to ``device`` and put in train mode.
        train_loader: Iterable of ``(x, y)`` batches; must yield at least one.
        config: ``TrainingConfig``-like object with the hyperparameters.
        checkpoint_name: Filename prefix for saved checkpoints.

    Returns:
        The trained model (the same object that was passed in).

    Raises:
        ValueError: If ``train_loader`` yields no batches at all.
    """
    model.to(device)
    model.train()

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.95)
    )

    # Training state
    step = 0
    running_loss = 0.0       # sum of per-step mean losses since the last log line
    steps_since_log = 0
    avg_loss = float("nan")  # most recently reported mean loss (NaN until first report)
    start_time = time.time()

    print(f"Starting training for {config.max_steps} steps...")
    print(f"Batch size: {config.batch_size}, Grad accum: {config.gradient_accumulation_steps}")
    print("-" * 60)

    data_iter = iter(train_loader)

    while step < config.max_steps:
        # Update learning rate
        lr = get_lr(step, config)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Gradient accumulation loop
        optimizer.zero_grad()
        accum_loss = 0.0

        for _ in range(config.gradient_accumulation_steps):
            try:
                batch = next(data_iter)
            except StopIteration:
                # Restart from the beginning once the loader is exhausted.
                data_iter = iter(train_loader)
                try:
                    batch = next(data_iter)
                except StopIteration:
                    raise ValueError("train_loader yielded no batches") from None

            accum_loss += train_step(model, batch, optimizer, config, step)

        # Clip gradients (prevents exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Update weights
        optimizer.step()

        # train_step returns one loss per micro-batch, so average them. Summing
        # would inflate the logged loss by a factor of gradient_accumulation_steps.
        running_loss += accum_loss / config.gradient_accumulation_steps
        steps_since_log += 1
        step += 1

        # Logging
        if step % config.eval_interval == 0:
            avg_loss = running_loss / steps_since_log
            elapsed = time.time() - start_time
            steps_per_sec = step / elapsed

            print(f"Step {step:6d} | Loss: {avg_loss:.4f} | LR: {lr:.2e} | {steps_per_sec:.1f} steps/s")
            running_loss = 0.0
            steps_since_log = 0

        # Save checkpoint ('loss' is the most recently reported mean loss)
        if step % config.save_interval == 0:
            save_path = CHECKPOINT_DIR / f"{checkpoint_name}_step{step}.pt"
            torch.save({
                'step': step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss
            }, save_path)
            print(f"  → Saved checkpoint: {save_path.name}")

    # Fold in steps taken after the last log line. This keeps the final
    # report defined (and current) even when max_steps < eval_interval.
    if steps_since_log:
        avg_loss = running_loss / steps_since_log

    print("-" * 60)
    print(f"Training complete! Final loss: {avg_loss:.4f}")

    return model

---
## 6. Phase A: Base Pretraining

Train on `base_stream.txt` at 2k context with full attention.

In [None]:
# The model class itself comes from the previous notebook (or a copy of its
# code); here we only load and echo its hyperparameters from the JSON config.
model_config_dict = json.loads(Path('../configs/model_config.json').read_text())

print("Model configuration:")
for key, value in model_config_dict.items():
    print(f"  {key}: {value}")

In [None]:
# Phase A configuration: base pretraining at 2k context, full attention.
phase_a_config = TrainingConfig(
    seq_len=2048,  # 2k context for Phase A
    max_steps=50000,
    warmup_steps=1000,
    learning_rate=3e-4,
    gradient_accumulation_steps=4,
    batch_size=8,
)

print(
    "Phase A: Base Pretraining",
    f"  Context length: {phase_a_config.seq_len}",
    "  Data: base_stream.txt",
    "  Attention: Full",
    sep="\n",
)

In [None]:
# NOTE: Running this cell trains Phase A for many hours. It is guarded by a
# flag rather than kept inside a string literal; a string is never parsed as
# code, so mistakes in it would only surface when someone "uncomments" it.
# Set RUN_PHASE_A = True to train.
RUN_PHASE_A = False

if RUN_PHASE_A:
    # Create dataset and dataloader
    train_dataset = TextDataset(
        DATA_DIR / "base_stream.txt",
        tokenizer,
        phase_a_config.seq_len
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=phase_a_config.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    # Create model (requires model.py saved from notebook 03)
    from model import SLM, ModelConfig

    model_config = ModelConfig(
        use_block_local=False,  # Full attention for Phase A
        max_seq_len=phase_a_config.seq_len
    )
    model = SLM(model_config)

    # Train!
    model = train(model, train_loader, phase_a_config, checkpoint_name="phase_a")
else:
    print("Phase A training code ready - set RUN_PHASE_A = True to run")

---
## 7. Phase B: Context Extension

Extend from 2k to 4k context (`seq_len=4096`; the overall target range is 4k–5k) using RoPE scaling and block-local attention.

In [None]:
# Phase B configuration: continue on the same data at 4k context. Sequences
# are twice as long, so the micro-batch is halved and accumulation doubled,
# keeping the effective batch at 32.
phase_b_config = TrainingConfig(
    seq_len=4096,  # Extended context
    max_steps=10000,
    warmup_steps=500,
    learning_rate=1e-4,  # Lower LR for fine-tuning
    gradient_accumulation_steps=8,
    batch_size=4,  # Smaller batch - longer sequences use more memory
)

print(
    "Phase B: Context Extension",
    f"  Context length: {phase_b_config.seq_len}",
    "  Data: base_stream.txt (same distribution)",
    "  Attention: Block-local (512 token blocks)",
    sep="\n",
)

In [None]:
def apply_rope_scaling(model, scale_factor: float = 2.0):
    """
    Scale RoPE frequencies for context extension (linear scaling).

    For each block in ``model.blocks`` that has ``block.attention.rotary``:

    - ``inv_freq`` is divided by ``scale_factor``;
    - ``_build_cache(model.config.max_seq_len)`` is called.

    Set ``model.config.max_seq_len`` to the new context length *before*
    calling this. Dividing the frequencies by ``s`` makes position ``p``
    produce the rotation angles that position ``p / s`` produced before. This
    keeps a context ``s`` times longer within the angle range already seen.

    Not idempotent: calling it twice scales by ``scale_factor ** 2``.

    Args:
        model: Model exposing ``blocks`` and ``config.max_seq_len``.
        scale_factor: Positive divisor for the frequencies, typically
            new_context_length / old_context_length.

    Raises:
        ValueError: If ``scale_factor`` is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor}")

    n_scaled = 0
    for block in model.blocks:
        rotary = getattr(block.attention, 'rotary', None)
        if rotary is None:
            continue
        # Scale the inverse frequencies
        rotary.inv_freq = rotary.inv_freq / scale_factor
        # Rebuild the cache for the new context length
        rotary._build_cache(model.config.max_seq_len)
        n_scaled += 1

    if n_scaled == 0:
        # The model layout did not match; make that loud instead of reporting success.
        print("WARNING: no `attention.rotary` modules found - RoPE scaling was NOT applied")
    else:
        print(f"Applied RoPE scaling with factor {scale_factor}")


# This would be applied when loading a Phase A checkpoint for Phase B
print("RoPE scaling function ready")

---
## 8. Phase C: Domain Fine-Tuning

Fine-tune on control systems, nuclear, and reliability texts.

**Key principle:** No labels, no answers, no chain-of-thought forcing. We're shaping *how* the model reasons, not *what* conclusions it reaches.

In [None]:
class MultiStreamDataset(Dataset):
    """
    Dataset over several fine-tuning text streams, concatenated end to end.

    Each file is read and tokenized once. The token tensors are concatenated
    in the given order and split into non-overlapping ``seq_len`` windows.
    Domains get mixed at training time by shuffling the DataLoader, not by
    this class. A window that spans the join between two files contains the
    tail of one stream and the head of the next.

    Attributes:
        tokens: 1-D ``torch.long`` tensor of all streams concatenated.
        all_tokens: Per-stream tensors (views into ``tokens``, not copies).

    Args:
        file_paths: UTF-8 text files to load; must contain at least one path.
        tokenizer: Object whose ``encode(text)`` result exposes ``.ids``.
        seq_len: Number of input tokens per example; must be positive.

    Raises:
        ValueError: If ``file_paths`` is empty or ``seq_len`` is not positive.
    """

    def __init__(self, file_paths: list, tokenizer: "Tokenizer", seq_len: int):
        if seq_len <= 0:
            raise ValueError(f"seq_len must be positive, got {seq_len}")
        file_paths = list(file_paths)
        if not file_paths:
            raise ValueError("file_paths must contain at least one file")
        self.seq_len = seq_len

        per_stream = []
        for path in file_paths:
            print(f"Loading {path.name}...")
            with open(path, 'r', encoding='utf-8') as f:
                text = f.read()

            encoding = tokenizer.encode(text)
            tokens = torch.tensor(encoding.ids, dtype=torch.long)
            per_stream.append(tokens)
            print(f"  → {len(tokens):,} tokens")

        # Concatenate all streams, then expose the per-stream pieces as views
        # into the concatenated buffer so the tokens are not held twice.
        lengths = [len(t) for t in per_stream]
        self.tokens = torch.cat(per_stream)
        del per_stream
        self.all_tokens = list(torch.split(self.tokens, lengths))

        # seq_len + 1 tokens per example (inputs + shifted target); clamp at 0
        # so empty input gives an empty dataset instead of a negative length.
        self.n_chunks = max(0, (len(self.tokens) - 1) // seq_len)
        print(f"\nTotal: {len(self.tokens):,} tokens, {self.n_chunks:,} chunks")

    def __len__(self):
        return self.n_chunks

    def __getitem__(self, idx):
        """Return ``(x, y)``: ``seq_len`` input tokens and the window shifted by one."""
        if idx < 0:
            idx += self.n_chunks
        if not 0 <= idx < self.n_chunks:
            raise IndexError(f"chunk index out of range for dataset of {self.n_chunks} chunks")

        start = idx * self.seq_len
        chunk = self.tokens[start:start + self.seq_len + 1]
        return chunk[:-1], chunk[1:]

# Fine-tuning inputs: one text stream per target domain.
finetune_files = [
    DATA_DIR / stream_name
    for stream_name in (
        "finetune_control.txt",
        "finetune_nuclear.txt",
        "finetune_reliability.txt",
    )
]

print("Fine-tuning streams:")
for stream_path in finetune_files:
    print(f"  - {stream_path.name}")

In [None]:
# Phase C configuration: gentle domain fine-tuning at the extended 4k context.
phase_c_config = TrainingConfig(
    seq_len=4096,  # Keep extended context
    max_steps=5000,
    warmup_steps=200,
    learning_rate=5e-5,  # Even lower LR for fine-tuning
    gradient_accumulation_steps=4,
    batch_size=4,
)

print(
    "Phase C: Domain Fine-Tuning",
    f"  Context length: {phase_c_config.seq_len}",
    "  Data: control + nuclear + reliability streams",
    "  Attention: Block-local",
    "\nRemember: NO labels, NO answers - just exposure to domain text",
    sep="\n",
)

---
## 9. Putting It All Together

Here's the complete training pipeline as a single function that runs all three phases in sequence.

In [None]:
def full_training_pipeline():
    """
    Complete 3-phase training pipeline: A (base), B (context extension), C (domain).

    Each phase keeps training the same in-memory model; ``train`` creates a
    fresh optimizer per phase. Batch size and sequence length for every phase
    come from ``phase_a_config`` / ``phase_b_config`` / ``phase_c_config``.
    The datasets and loaders therefore always match the configs passed to
    ``train``.

    This would take a long time to run - it's here as a reference.

    Returns:
        The model after Phase C.
    """
    # SLM / ModelConfig live in model.py (saved from notebook 03) and are not
    # imported anywhere at module level, so import them here.
    from model import SLM, ModelConfig

    # ============ PHASE A ============
    print("=" * 60)
    print("PHASE A: Base Pretraining")
    print("=" * 60)

    # Create model with full attention
    model_config = ModelConfig(use_block_local=False, max_seq_len=phase_a_config.seq_len)
    model = SLM(model_config)

    train_dataset = TextDataset(DATA_DIR / "base_stream.txt", tokenizer, phase_a_config.seq_len)
    train_loader = DataLoader(train_dataset, batch_size=phase_a_config.batch_size, shuffle=True)

    model = train(model, train_loader, phase_a_config, "phase_a")

    # ============ PHASE B ============
    print("\n" + "=" * 60)
    print("PHASE B: Context Extension")
    print("=" * 60)

    # Switch to block-local attention and extend context. max_seq_len must be
    # updated before apply_rope_scaling, which rebuilds the RoPE cache using it.
    # NOTE(review): assumes SLM reads `config.use_block_local` at forward time;
    # if it is only consulted in the constructor, this flag change has no effect.
    model.config.use_block_local = True
    model.config.max_seq_len = phase_b_config.seq_len
    apply_rope_scaling(model, scale_factor=phase_b_config.seq_len / phase_a_config.seq_len)

    # Drop the previous phase's token tensor before tokenizing again, so both
    # datasets are never resident at the same time.
    del train_loader, train_dataset
    train_dataset = TextDataset(DATA_DIR / "base_stream.txt", tokenizer, phase_b_config.seq_len)
    train_loader = DataLoader(train_dataset, batch_size=phase_b_config.batch_size, shuffle=True)

    model = train(model, train_loader, phase_b_config, "phase_b")

    # ============ PHASE C ============
    print("\n" + "=" * 60)
    print("PHASE C: Domain Fine-Tuning")
    print("=" * 60)

    del train_loader, train_dataset
    train_dataset = MultiStreamDataset(finetune_files, tokenizer, phase_c_config.seq_len)
    train_loader = DataLoader(train_dataset, batch_size=phase_c_config.batch_size, shuffle=True)

    model = train(model, train_loader, phase_c_config, "phase_c_final")

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print("=" * 60)

    return model

print("Full training pipeline defined")
print("Run full_training_pipeline() to start (this will take many hours)")

---
## Summary

We've set up the complete training infrastructure:

| Phase | Context | Attention | Data | LR |
|-------|---------|-----------|------|----|
| A | 2048 | Full | base_stream | 3e-4 |
| B | 4096 | Block-local | base_stream | 1e-4 |
| C | 4096 | Block-local | finetune_* | 5e-5 |

Each phase builds on the previous checkpoint, progressively extending context and shaping reasoning.

**Next:** In notebook 05, we'll evaluate the trained model with zero-shot prompts.