# Tutorial 3: Training a Transformer from Scratch

In this notebook, we'll build and train a complete GPT-style transformer model. You'll see:
- Complete transformer architecture assembly
- Training loop with backpropagation
- Adam optimizer in action
- Loss curves and learning dynamics
- Text generation from your trained model

## Prerequisites

Complete notebooks 01 (tensors) and 02 (attention) first; this notebook builds on the concepts they introduce.

In [None]:
import (
    "fmt"
    "math"
    "math/rand"
    "github.com/scttfrdmn/local-code-model"
)

## Step 1: Prepare Training Data

For this tutorial, we'll train on a tiny text corpus. In practice, you'd use much more data.

**What's happening:**
- Tokenize text into integer IDs
- Create training sequences (input + target pairs)
- Batch sequences for efficient training
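
To make the tokenization step concrete, here is a minimal standalone sketch of a character-level tokenizer using only the Go standard library. It is not the library's `Tokenizer`; the `charVocab` type and its methods are hypothetical names for illustration, and it assumes one ID per distinct character with no special tokens.

```go
package main

import (
	"fmt"
	"sort"
)

// charVocab is a hypothetical stand-in for a character-level tokenizer.
type charVocab struct {
	toID   map[rune]int
	toRune []rune
}

// buildCharVocab assigns one ID per distinct character, in sorted order
// so the mapping is deterministic across runs.
func buildCharVocab(corpus []string) charVocab {
	seen := map[rune]bool{}
	for _, text := range corpus {
		for _, r := range text {
			seen[r] = true
		}
	}
	runes := make([]rune, 0, len(seen))
	for r := range seen {
		runes = append(runes, r)
	}
	sort.Slice(runes, func(i, j int) bool { return runes[i] < runes[j] })

	v := charVocab{toID: make(map[rune]int, len(runes)), toRune: runes}
	for id, r := range runes {
		v.toID[r] = id
	}
	return v
}

// encode maps each character to its ID; unknown characters are skipped.
func (v charVocab) encode(text string) []int {
	ids := make([]int, 0, len(text))
	for _, r := range text {
		if id, ok := v.toID[r]; ok {
			ids = append(ids, id)
		}
	}
	return ids
}

// decode maps IDs back to characters; out-of-range IDs are skipped.
func (v charVocab) decode(ids []int) string {
	out := make([]rune, 0, len(ids))
	for _, id := range ids {
		if id >= 0 && id < len(v.toRune) {
			out = append(out, v.toRune[id])
		}
	}
	return string(out)
}

func main() {
	vocab := buildCharVocab([]string{"To be or not to be"})
	ids := vocab.encode("to be")
	fmt.Println(len(vocab.toRune), ids, vocab.decode(ids))
}
```

Sorting the characters before assigning IDs keeps the mapping stable between runs, which matters if you save a model and reload it later.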

In [None]:
// Simple training corpus
corpus := []string{
    "The quick brown fox jumps over the lazy dog.",
    "A journey of a thousand miles begins with a single step.",
    "To be or not to be, that is the question.",
    "All that glitters is not gold.",
}

// Build character-level tokenizer
tokenizer := main.NewTokenizer()
tokenizer.SetType("char")
if err := tokenizer.Train(corpus, 100); err != nil {
    panic(err)
}

fmt.Printf("Vocabulary size: %d\n", tokenizer.VocabSize())

// Encode corpus
var tokens []int
for _, text := range corpus {
    tokens = append(tokens, tokenizer.Encode(text)...)
}

fmt.Printf("Total tokens: %d\n", len(tokens))
fmt.Printf("Sample tokens: %v\n", tokens[:20])

## Step 2: Create Training Dataset

Split tokens into sequences for training. Each sequence has:
- **Input**: tokens[i:i+seqLen]
- **Target**: tokens[i+1:i+seqLen+1]

The model learns to predict the next token given previous tokens.
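
The input/target relationship above can be sketched in plain Go. This is an illustration only: `makePairs` is a hypothetical helper, and the stride of 1 between windows is an assumption (the library's `TextDataset` may slice the tokens differently).

```go
package main

import "fmt"

// makePairs returns every (input, target) window of length seqLen,
// where each target is its input shifted one position to the right.
// Hypothetical helper; the stride of 1 is an assumption.
func makePairs(tokens []int, seqLen int) (inputs, targets [][]int) {
	for i := 0; i+seqLen+1 <= len(tokens); i++ {
		inputs = append(inputs, tokens[i:i+seqLen])
		targets = append(targets, tokens[i+1:i+seqLen+1])
	}
	return inputs, targets
}

func main() {
	tokens := []int{10, 11, 12, 13, 14, 15}
	inputs, targets := makePairs(tokens, 4)
	for i := range inputs {
		fmt.Printf("input %v -> target %v\n", inputs[i], targets[i])
	}
}
```

With a stride of 1, a token stream of length `n` yields `n - seqLen` training pairs, because each window needs one extra token for its final target.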

In [None]:
seqLen := 16
batchSize := 2

// Create dataset
dataset := main.NewTextDataset(tokens, seqLen)

fmt.Printf("Dataset size: %d sequences\n", dataset.Size())
fmt.Printf("Number of batches: %d\n", dataset.Size()/batchSize)

// Look at one example
input, target := dataset.GetBatch(0, 1, seqLen)
fmt.Printf("\nExample training pair:\n")
fmt.Printf("Input:  %v\n", input.Data()[:seqLen])
fmt.Printf("Target: %v\n", target.Data()[:seqLen])
fmt.Printf("Decoded input: %q\n", tokenizer.Decode(input.AsInts()[:seqLen]))

## Step 3: Build the Transformer

Now let's create a tiny GPT model:
- 2 layers (very small!)
- 32 embedding dimensions
- 2 attention heads
- 16-token context window

This is on the order of 10K-20K parameters (the cell below prints the exact count), versus about 124M for GPT-2 small.
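
As a rough sanity check on model size, the sketch below counts parameters for an assumed GPT-2-style layout: learned token and position embeddings, biased Q/K/V/output projections, a two-layer feed-forward block, two layer norms per block, a final layer norm, and an untied output head without bias. The `gptShape` type, the `countParams` function, and the vocabulary size of 30 are all assumptions for illustration; the library's layout may differ, so treat `model.NumParameters()` as authoritative.

```go
package main

import "fmt"

// gptShape holds the sizes that determine the parameter count.
type gptShape struct {
	VocabSize, NumLayers, EmbedDim, FFDim, SeqLen int
}

// countParams estimates parameters for an assumed GPT-2-style layout.
// The real library may tie weights or omit biases, which changes the total.
func countParams(s gptShape) int {
	d := s.EmbedDim
	embeddings := s.VocabSize*d + s.SeqLen*d               // token + position tables
	attention := 4 * (d*d + d)                             // Q, K, V, output projections with bias
	feedForward := (d*s.FFDim + s.FFDim) + (s.FFDim*d + d) // expand, then project back
	layerNorms := 2 * (2 * d)                              // scale + shift, two norms per block
	perLayer := attention + feedForward + layerNorms
	finalNorm := 2 * d
	outputHead := d * s.VocabSize // untied, no bias
	return embeddings + s.NumLayers*perLayer + finalNorm + outputHead
}

func main() {
	// VocabSize 30 is an assumed value for a small character vocabulary.
	shape := gptShape{VocabSize: 30, NumLayers: 2, EmbedDim: 32, FFDim: 64, SeqLen: 16}
	fmt.Printf("estimated parameters: %d\n", countParams(shape))
}
```

Under these assumptions the total is 19,584, and most of it sits in the two transformer blocks rather than in the embeddings.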

In [None]:
// Model hyperparameters
config := main.TransformerConfig{
    VocabSize:    tokenizer.VocabSize(),
    NumLayers:    2,
    EmbedDim:     32,
    NumHeads:     2,
    FFDim:        64,  // 2x EmbedDim here; 4x is the more common choice
    SeqLen:       seqLen,
    DropoutRate:  0.1,
}

model := main.NewTransformer(config)

fmt.Printf("Model architecture:\n")
fmt.Printf("  Layers: %d\n", config.NumLayers)
fmt.Printf("  Embed dim: %d\n", config.EmbedDim)
fmt.Printf("  Num heads: %d\n", config.NumHeads)
fmt.Printf("  FF dim: %d\n", config.FFDim)
fmt.Printf("  Sequence length: %d\n", config.SeqLen)

// Count parameters
numParams := model.NumParameters()
fmt.Printf("\nTotal parameters: %d\n", numParams)

## Step 4: Initialize Optimizer

We'll use **Adam** (Adaptive Moment Estimation):
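- Maintains an exponential moving average of the gradients (momentum)
- Scales each parameter's step by a moving average of its squared gradients, so the effective learning rate adapts per parameter
- Typically converges faster than plain SGD on transformers, with less learning-rate tuning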

**Hyperparameters:**
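- Learning rate: 0.001 (held constant here; scheduling is listed under Next Steps)
- Beta1: 0.9 (decay rate for the gradient average, the momentum term)
- Beta2: 0.999 (decay rate for the squared-gradient average, the variance term)
- Epsilon: 1e-8 (added to the denominator for numerical stability)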
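
To show how these hyperparameters interact, here is a minimal sketch of the standard Adam update with bias correction, written against plain slices. The `adamState` type and its `update` method are hypothetical names; this is not the library's `AdamOptimizer`, which may organize its state differently.

```go
package main

import (
	"fmt"
	"math"
)

// adamState holds the running moment estimates for one parameter vector.
type adamState struct {
	m, v []float64 // first and second moment estimates
	step int
}

func newAdamState(n int) *adamState {
	return &adamState{m: make([]float64, n), v: make([]float64, n)}
}

// update applies one Adam step to params in place.
func (s *adamState) update(params, grads []float64, lr, beta1, beta2, eps float64) {
	s.step++
	// Bias correction compensates for the moments starting at zero.
	c1 := 1 - math.Pow(beta1, float64(s.step))
	c2 := 1 - math.Pow(beta2, float64(s.step))
	for i, g := range grads {
		s.m[i] = beta1*s.m[i] + (1-beta1)*g
		s.v[i] = beta2*s.v[i] + (1-beta2)*g*g
		mHat := s.m[i] / c1
		vHat := s.v[i] / c2
		params[i] -= lr * mHat / (math.Sqrt(vHat) + eps)
	}
}

func main() {
	params := []float64{1.0, -2.0}
	grads := []float64{0.5, -4.0} // made-up gradients of very different size
	state := newAdamState(len(params))
	state.update(params, grads, 0.001, 0.9, 0.999, 1e-8)
	fmt.Println(params)
}
```

On the first step both parameters move by almost exactly the learning rate, even though their gradients differ by a factor of eight. That normalization by the squared-gradient average is what "adapts the learning rate per parameter" means in practice.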

In [None]:
// Adam optimizer
learningRate := 0.001
optimizer := main.NewAdamOptimizer(learningRate, 0.9, 0.999, 1e-8)

fmt.Printf("Optimizer: Adam\n")
fmt.Printf("  Learning rate: %.4f\n", learningRate)
fmt.Printf("  Beta1 (momentum): %.2f\n", 0.9)
fmt.Printf("  Beta2 (variance): %.3f\n", 0.999)

## Step 5: Training Loop

The training loop:
1. **Forward pass**: Compute model predictions
2. **Loss**: Compare predictions to targets (cross-entropy)
3. **Backward pass**: Compute gradients via backpropagation
4. **Update**: Apply gradients using optimizer
5. **Repeat**: For a fixed number of epochs, or until the loss stops improving

We'll train for just 10 epochs on our tiny dataset to see the model learn.
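
Steps 2 and 3 are easier to follow with the math written out for a single position. The sketch below computes the cross-entropy loss and its gradient with respect to the logits, which for softmax followed by cross-entropy is `softmax(logits) - onehot(target)`. The function names are hypothetical, and this shows the standard formula rather than how `main.CrossEntropyLoss` and `model.Backward` divide the work between them.

```go
package main

import (
	"fmt"
	"math"
)

// softmax converts logits to probabilities, subtracting the maximum
// logit first for numerical stability.
func softmax(logits []float64) []float64 {
	maxLogit := logits[0]
	for _, z := range logits[1:] {
		if z > maxLogit {
			maxLogit = z
		}
	}
	probs := make([]float64, len(logits))
	sum := 0.0
	for i, z := range logits {
		probs[i] = math.Exp(z - maxLogit)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}
	return probs
}

// crossEntropy returns -log p[target] for one position, together with
// the gradient of that loss with respect to the logits.
func crossEntropy(logits []float64, target int) (float64, []float64) {
	probs := softmax(logits)
	loss := -math.Log(probs[target])
	grad := probs // softmax(logits) - onehot(target), computed in place
	grad[target] -= 1
	return loss, grad
}

func main() {
	loss, grad := crossEntropy([]float64{2.0, 1.0, 0.1}, 0)
	fmt.Printf("loss=%.4f grad=%.4f\n", loss, grad)
}
```

The gradient entries always sum to zero: the correct token's logit is pushed up by exactly as much as the others are pushed down in total.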

In [None]:
epochs := 10
numBatches := dataset.Size() / batchSize

fmt.Printf("Training for %d epochs...\n\n", epochs)

// Track loss history
var lossHistory []float64

for epoch := 0; epoch < epochs; epoch++ {
    epochLoss := 0.0
    
    for batch := 0; batch < numBatches; batch++ {
        // Get batch
        input, target := dataset.GetBatch(batch*batchSize, batchSize, seqLen)
        
        // Forward pass
        logits := model.Forward(input, true)  // training=true for dropout
        
        // Compute loss
        loss := main.CrossEntropyLoss(logits, target)
        epochLoss += loss
        
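        // Backward pass
        // Seed backprop with ones shaped like the logits. Note that the gradient of
        // cross-entropy w.r.t. the logits is softmax(logits) - onehot(target), not ones,
        // so this relies on Backward (or the loss) applying that step internally.
        gradLoss := main.Ones(logits.Shape()...)
        model.Backward(gradLoss)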
        
        // Update weights
        optimizer.Step(model.Parameters())
        
        // Zero gradients
        model.ZeroGrad()
    }
    
    avgLoss := epochLoss / float64(numBatches)
    lossHistory = append(lossHistory, avgLoss)
    
    fmt.Printf("Epoch %2d/%d, Loss: %.4f\n", epoch+1, epochs, avgLoss)
}

fmt.Printf("\nTraining complete!\n")

## Step 6: Visualize Training

Let's plot the loss curve to see how the model learned over time.

**What to look for:**
- Loss should decrease (model is learning)
- May fluctuate (small dataset, small batch size)
- Eventual plateau (model capacity reached)
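
A useful reference point when reading the curve is the loss of a model that guesses uniformly at random, which is `ln(vocabSize)`. The sketch below computes that baseline and converts a loss to perplexity. It assumes the loss is an average cross-entropy in nats (natural log); the function names and the vocabulary size of 30 are illustrative assumptions.

```go
package main

import (
	"fmt"
	"math"
)

// uniformBaselineLoss is the cross-entropy (in nats) of assigning equal
// probability to every token in a vocabulary of the given size.
func uniformBaselineLoss(vocabSize int) float64 {
	return math.Log(float64(vocabSize))
}

// perplexity converts an average cross-entropy in nats to perplexity,
// the effective number of equally likely choices per token.
func perplexity(loss float64) float64 {
	return math.Exp(loss)
}

func main() {
	vocabSize := 30 // assumed size of a small character vocabulary
	fmt.Printf("uniform-guess loss for %d tokens: %.4f\n", vocabSize, uniformBaselineLoss(vocabSize))
	fmt.Printf("perplexity at loss 2.0: %.2f\n", perplexity(2.0))
}
```

An untrained model should start near the uniform baseline. A first-epoch loss far above it usually points to a problem in initialization or in the loss computation.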

In [None]:
fmt.Printf("Loss curve:\n")
fmt.Printf("Epoch  Loss\n")
fmt.Printf("-----  ----\n")

for i, loss := range lossHistory {
    // Simple ASCII bar chart
    barLen := int(loss * 10)  // Scale for visibility
    if barLen > 50 {
        barLen = 50
    }
    bar := ""
    for j := 0; j < barLen; j++ {
        bar += "█"
    }
    fmt.Printf("%5d  %.4f %s\n", i+1, loss, bar)
}

// Summary statistics
initialLoss := lossHistory[0]
finalLoss := lossHistory[len(lossHistory)-1]
improvement := initialLoss - finalLoss
percentImprovement := (improvement / initialLoss) * 100

fmt.Printf("\nInitial loss: %.4f\n", initialLoss)
fmt.Printf("Final loss:   %.4f\n", finalLoss)
fmt.Printf("Improvement:  %.4f (%.1f%%)\n", improvement, percentImprovement)

## Step 7: Generate Text

Now let's use our trained model to generate text!

**How generation works:**
1. Start with a prompt (seed text)
2. Model predicts next token probability distribution
3. Sample from distribution (with temperature)
4. Append sampled token to input
5. Repeat

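**Temperature** divides the logits before the softmax, which controls randomness:
- Low (0.5): Sharper distribution; conservative, predictable output
- Medium (1.0): The model's own distribution, unchanged
- High (1.5): Flatter distribution; more varied, less coherent output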
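
The effect of temperature is easiest to see on a fixed set of logits. The sketch below is a standalone illustration using made-up logits for three tokens; `softmaxWithTemperature` is a hypothetical helper, not part of the library.

```go
package main

import (
	"fmt"
	"math"
)

// softmaxWithTemperature divides the logits by temperature and applies a
// numerically stable softmax. temperature must be greater than zero.
func softmaxWithTemperature(logits []float64, temperature float64) []float64 {
	scaled := make([]float64, len(logits))
	maxVal := math.Inf(-1)
	for i, z := range logits {
		scaled[i] = z / temperature
		if scaled[i] > maxVal {
			maxVal = scaled[i]
		}
	}
	sum := 0.0
	for i := range scaled {
		scaled[i] = math.Exp(scaled[i] - maxVal)
		sum += scaled[i]
	}
	for i := range scaled {
		scaled[i] /= sum
	}
	return scaled
}

func main() {
	logits := []float64{2.0, 1.0, 0.0} // made-up logits for three tokens
	for _, t := range []float64{0.5, 1.0, 1.5} {
		fmt.Printf("temperature %.1f: %.3f\n", t, softmaxWithTemperature(logits, t))
	}
}
```

Lower temperatures concentrate probability on the highest logit, while higher temperatures spread it toward the alternatives. A temperature of zero would divide by zero, so greedy decoding is usually implemented as a separate argmax path.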

In [None]:
// Generation function
func generate(model *main.Transformer, tokenizer *main.Tokenizer, prompt string, maxLen int, temperature float64) string {
    // Encode prompt
    tokens := tokenizer.Encode(prompt)
    
    // Generate tokens one by one
    for i := 0; i < maxLen; i++ {
        // Take last seqLen tokens as input
        start := 0
        if len(tokens) > seqLen {
            start = len(tokens) - seqLen
        }
        input := tokens[start:]
        
        // Convert to tensor
        inputTensor := main.NewTensor([]int{1, len(input)}, toFloat64Slice(input))
        
        // Forward pass (no dropout during inference)
        logits := model.Forward(inputTensor, false)
        
        // Get logits for last position
        lastLogits := logits.Slice(0, 1).Slice(len(input)-1, len(input))
        
        // Apply temperature
        scaledLogits := lastLogits.Scale(1.0 / temperature)
        
        // Softmax to get probabilities
        probs := scaledLogits.Softmax(1)
        
        // Sample next token
        nextToken := sample(probs)
        tokens = append(tokens, nextToken)
        
        // Stop if we hit end-of-sequence
        if nextToken == tokenizer.EosID() {
            break
        }
    }
    
    return tokenizer.Decode(tokens)
}

func toFloat64Slice(ints []int) []float64 {
    floats := make([]float64, len(ints))
    for i, v := range ints {
        floats[i] = float64(v)
    }
    return floats
}

func sample(probs *main.Tensor) int {
    // Sample from probability distribution
    r := rand.Float64()
    cumProb := 0.0
    probData := probs.Data()
    
    for i, p := range probData {
        cumProb += p
        if r < cumProb {
            return i
        }
    }
    return len(probData) - 1  // Fallback
}

fmt.Println("Generation functions defined!")

## Generate Some Text

Let's try a few different prompts and temperatures.

In [None]:
// Try different prompts
prompts := []string{
    "The ",
    "To be",
    "All that",
}

temperature := 0.8
maxLen := 30

fmt.Printf("Generating text (temp=%.1f):\n\n", temperature)

for _, prompt := range prompts {
    generated := generate(model, tokenizer, prompt, maxLen, temperature)
    fmt.Printf("Prompt: %q\n", prompt)
    fmt.Printf("Generated: %q\n\n", generated)
}

## Temperature Effect

Compare different temperature settings:

In [None]:
prompt := "The quick"
temperatures := []float64{0.3, 0.8, 1.5}

fmt.Printf("Prompt: %q\n\n", prompt)

for _, temp := range temperatures {
    generated := generate(model, tokenizer, prompt, 20, temp)
    fmt.Printf("Temperature %.1f: %q\n", temp, generated)
}

fmt.Printf("\nNotice how higher temperature produces more varied (and less coherent) text!\n")

## Key Takeaways

Congratulations! You've trained a transformer from scratch. Here's what we covered:

1. **Data preparation**: Tokenization and sequence creation
2. **Model architecture**: Complete transformer with attention and feed-forward layers
3. **Training loop**: Forward pass → loss → backward pass → update
4. **Optimization**: Adam optimizer with adaptive learning rates
5. **Generation**: Autoregressive sampling with temperature control

## What Makes This Work?

- **Attention**: Model learns which tokens to focus on
- **Self-supervision**: Next-token prediction turns any raw text into training examples, with no manual labels needed
- **Depth**: Multiple layers compose increasingly abstract representations
- **Scale**: More data and bigger models generally give better results

## Limitations of This Tiny Model

- **Tiny dataset**: Only a few sentences (real models: billions of tokens)
- **Small model**: ~10K-20K parameters (GPT-3: 175B parameters)
- **Short context**: 16 tokens (modern models: 8K-100K tokens)
- **Minimal regularization**: Only dropout (0.1); a real task would need more to avoid overfitting

## Next Steps

**To improve this model:**
1. Train on more data (download WikiText or similar)
2. Increase model size (more layers, larger embeddings)
3. Longer training (hundreds of epochs)
4. Add learning rate scheduling (warmup + decay)
5. Implement gradient clipping
6. Use BPE tokenization instead of character-level
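
Items 4 and 5 can be prototyped without touching the model. Below is a minimal sketch of a linear-warmup, cosine-decay learning rate schedule and of gradient clipping by global L2 norm, both operating on plain values and slices. The function names and step counts are hypothetical; wiring them into the optimizer depends on whether the library lets you change the learning rate between steps and read gradients directly, which this tutorial does not show.

```go
package main

import (
	"fmt"
	"math"
)

// warmupCosineLR ramps linearly up to peakLR over warmupSteps, then
// follows a cosine curve down to minLR at totalSteps.
func warmupCosineLR(step, warmupSteps, totalSteps int, peakLR, minLR float64) float64 {
	if warmupSteps > 0 && step < warmupSteps {
		return peakLR * float64(step+1) / float64(warmupSteps)
	}
	if step >= totalSteps {
		return minLR
	}
	progress := float64(step-warmupSteps) / float64(totalSteps-warmupSteps)
	return minLR + 0.5*(peakLR-minLR)*(1+math.Cos(math.Pi*progress))
}

// clipGradNorm rescales grads in place so their global L2 norm is at
// most maxNorm, and returns the norm measured before clipping.
func clipGradNorm(grads []float64, maxNorm float64) float64 {
	sumSq := 0.0
	for _, g := range grads {
		sumSq += g * g
	}
	norm := math.Sqrt(sumSq)
	if norm > maxNorm {
		scale := maxNorm / norm
		for i := range grads {
			grads[i] *= scale
		}
	}
	return norm
}

func main() {
	for _, step := range []int{0, 9, 10, 55, 100} {
		fmt.Printf("step %3d: lr=%.6f\n", step, warmupCosineLR(step, 10, 100, 0.001, 0))
	}
	grads := []float64{3, 4} // made-up gradient with L2 norm 5
	norm := clipGradNorm(grads, 1.0)
	fmt.Printf("norm before clipping: %.1f, clipped grads: %.2f\n", norm, grads)
}
```

Clipping by global norm preserves the direction of the gradient and only shortens it, which is why it is generally preferred over clamping each component independently.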

**To learn more:**
- Read the main codebase (see `../transformer.go`, `../train.go`)
- Study attention patterns with `cmd_visualize.go`
- Review backpropagation in `../docs/backpropagation.md`
- Explore training dynamics in `../docs/training-dynamics.md`

## Exercise

Try modifying the hyperparameters:
- Change `NumLayers` to 4
- Increase `EmbedDim` to 64
- Train for 20 epochs instead of 10

Does the model learn better? Does generation improve?

In [None]:
// TODO: Experiment with hyperparameters
// Try:
// - Different learning rates
// - More layers
// - Larger embeddings
// - Longer training
//
// Observe how each change affects:
// - Training loss curve
// - Generation quality
// - Training time