# Understanding the Training Process

This notebook walks through the complete training process of Mini-BERT.

## What You'll Learn:
1. Masked Language Modeling (MLM) objective
2. The complete training loop
3. Loss computation and gradient flow
4. Monitoring training progress
5. Common training problems and solutions
6. Batch processing and gradient accumulation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('..')

from model import MiniBERT
from tokenizer import WordPieceTokenizer
from mlm import mask_tokens, mlm_cross_entropy
from optimizer import AdamW
from gradients import MiniBERTGradients

np.random.seed(42)
# Set style for better visualizations - handle version compatibility
try:
    plt.style.use('seaborn-v0_8-darkgrid')
except OSError:
    try:
        plt.style.use('seaborn-darkgrid') 
    except OSError:
        plt.style.use('default')

## Part 1: Understanding MLM Objective

BERT learns through Masked Language Modeling - predicting masked words in context.

In [None]:
# Load tokenizer
tokenizer = WordPieceTokenizer()
tokenizer.load_model('../tokenizer_8k.pkl')

# Example training data
training_texts = [
    "The cat sat on the mat because it was comfortable.",
    "Machine learning models require large amounts of data.",
    "The quick brown fox jumps over the lazy dog.",
    "BERT uses transformer architecture for language understanding."
]

print("Training Examples:")
for i, text in enumerate(training_texts):
    print(f"{i+1}. {text}")

# Demonstrate MLM masking
def demonstrate_mlm_masking(text, mask_prob=0.15):
    """Show how MLM masking works"""
    print(f"\nOriginal: {text}")
    
    # Tokenize
    token_ids = tokenizer.encode(text)
    
    # Create inverse vocab mapping for display
    inv_vocab = {v: k for k, v in tokenizer.vocab.items()}
    tokens = [inv_vocab.get(tid, f'ID_{tid}') for tid in token_ids]
    
    # Apply masking - use correct signature: (ids, vocab_size, mask_id, p_mask)
    masked_ids, target_ids, mask_positions = mask_tokens(
        np.array([token_ids]),
        vocab_size=len(tokenizer.vocab),
        mask_id=tokenizer.vocab['[MASK]'],
        p_mask=mask_prob
    )
    
    # Decode masked version (rough approximation)
    masked_tokens = [inv_vocab.get(tid, f'ID_{tid}') for tid in masked_ids[0]]
    print(f"Masked:   {' '.join(masked_tokens)}")
    
    # Show targets
    if len(mask_positions) > 0 and len(target_ids) > 0:
        # Handle different target_ids formats
        if isinstance(target_ids, np.ndarray) and target_ids.shape == masked_ids.shape:
            # Full matrix format with sentinel values
            sentinel = -100
            actual_targets = []
            masked_positions = []
            for pos in range(len(target_ids[0])):
                if target_ids[0][pos] != sentinel:
                    actual_targets.append(target_ids[0][pos])
                    masked_positions.append(pos)
            target_tokens = [inv_vocab.get(tid, f'ID_{tid}') for tid in actual_targets]
        else:
            # List format
            target_tokens = [inv_vocab.get(tid, f'ID_{tid}') for tid in target_ids[0]]
            masked_positions = list(mask_positions[0])
        
        print(f"Targets:  {target_tokens}")
        print(f"Positions: {masked_positions}")
    
    return masked_ids, target_ids, mask_positions

# Demonstrate on first example
masked_ids, target_ids, mask_positions = demonstrate_mlm_masking(training_texts[0])

print("\nMLM Strategy:")
print("• 80% of masked tokens → [MASK]")
print("• 10% of masked tokens → random word")
print("• 10% of masked tokens → unchanged")
print("• Model learns to predict the original token")

## Part 2: Single Training Step

Let's walk through one complete training step.
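
The loss-to-gradient relationship that a training step relies on can be checked numerically. The sketch below is a standalone NumPy illustration, not the notebook's `mlm` module: `softmax_xent` and the toy sizes are made up for the example. It compares the analytic gradient of cross-entropy with respect to the logits, `softmax(logits) - one_hot(target)`, against central finite differences.

```python
import numpy as np

def softmax_xent(logits, target):
    """Cross-entropy of one logit vector against an integer target, plus its gradient."""
    shifted = logits - np.max(logits)          # subtract max for numerical stability
    probs = np.exp(shifted) / np.sum(np.exp(shifted))
    loss = -np.log(probs[target])
    grad = probs.copy()
    grad[target] -= 1.0                        # d(loss)/d(logits) = softmax - one_hot
    return loss, grad

rng = np.random.default_rng(0)
logits = rng.normal(size=6)                    # toy vocabulary of 6 tokens
target = 2

loss, analytic_grad = softmax_xent(logits, target)

# Central finite differences, one logit at a time
eps = 1e-5
numeric_grad = np.zeros_like(logits)
for i in range(len(logits)):
    bump = np.zeros_like(logits)
    bump[i] = eps
    loss_plus, _ = softmax_xent(logits + bump, target)
    loss_minus, _ = softmax_xent(logits - bump, target)
    numeric_grad[i] = (loss_plus - loss_minus) / (2 * eps)

max_abs_diff = np.max(np.abs(analytic_grad - numeric_grad))
print(f"max |analytic - numeric| = {max_abs_diff:.2e}")
```

The same kind of check is a cheap way to validate a hand-written backward pass before trusting a loss curve.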

In [None]:
def single_training_step(model, optimizer, grad_computer, 
                        input_ids, target_ids, mask_positions):
    """
    Perform one training step and return detailed information.
    """
    step_info = {}
    
    # 1. Forward pass
    print("Step 1: Forward Pass")
    logits, cache = model.forward(input_ids)
    step_info['logits_shape'] = logits.shape
    print(f"  Input shape: {input_ids.shape}")
    print(f"  Output logits shape: {logits.shape}")
    
    # 2. Compute loss and gradients
    print("\nStep 2: Loss Computation")
    
    # Extract masked positions and targets
    if isinstance(target_ids, np.ndarray) and target_ids.shape == input_ids.shape:
        # Full matrix format with sentinel values
        sentinel = -100
        actual_positions = []
        actual_targets = []
        for pos in range(len(target_ids[0])):
            if target_ids[0][pos] != sentinel:
                actual_positions.append(pos)
                actual_targets.append(target_ids[0][pos])
        
        if actual_positions:
            # Compute cross-entropy loss
            loss = 0.0
            for pos, target in zip(actual_positions, actual_targets):
                logit_at_pos = logits[0, pos]  # [vocab_size]
                # Softmax cross-entropy
                exp_logits = np.exp(logit_at_pos - np.max(logit_at_pos))
                probs = exp_logits / np.sum(exp_logits)
                loss -= np.log(probs[target] + 1e-10)
            loss /= len(actual_positions)
            
            step_info['loss'] = loss
            step_info['num_masked_tokens'] = len(actual_positions)
            print(f"  Loss: {loss:.4f}")
            print(f"  Masked positions: {len(actual_positions)}")
            
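            # Gradient of the mean cross-entropy w.r.t. the logits: (softmax - one_hot) / N.
            # Shown for illustration only; it is not backpropagated in this demo.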
            grad_logits = np.zeros_like(logits)
            for pos, target in zip(actual_positions, actual_targets):
                logit_at_pos = logits[0, pos]
                exp_logits = np.exp(logit_at_pos - np.max(logit_at_pos))
                probs = exp_logits / np.sum(exp_logits)
                
                # Gradient of cross-entropy
                grad_logits[0, pos] = probs.copy()
                grad_logits[0, pos, target] -= 1.0
                grad_logits[0, pos] /= len(actual_positions)
        else:
            print("  No masked tokens - skipping loss computation")
            return step_info
    else:
        print("  target_ids is not in the full-matrix (-100 sentinel) format - skipping loss computation")
        return step_info
    
    # 3. Backward pass (simplified)
    print("\nStep 3: Backward Pass")
    grad_computer.zero_gradients()
    # A real implementation would backpropagate grad_logits through the model here.
    # This demo only shows the structure of the step.
    
    # Simulated gradient statistics (random placeholders, for monitoring output only)
    step_info['grad_norm'] = np.random.uniform(0.1, 1.0)
    step_info['grad_mean'] = np.random.uniform(-0.01, 0.01)
    step_info['grad_std'] = np.random.uniform(0.01, 0.1)
    
    print(f"  Gradient norm: {step_info['grad_norm']:.6f}")
    print(f"  Gradient mean: {step_info['grad_mean']:.6f}")
    
    # 4. Optimizer step (simplified)
    print("\nStep 4: Parameter Update")
    # Create dummy gradients for demonstration
    dummy_grads = {}
    for param_name, param in model.params.items():
        dummy_grads[param_name] = np.random.normal(0, 0.01, param.shape)
    
    optimizer_stats = optimizer.step(model.params, dummy_grads)
    print(f"  Parameters updated with AdamW")
    print(f"  Learning rate: {optimizer.learning_rate}")
    
    return step_info

# Initialize components
model = MiniBERT()
optimizer = AdamW(learning_rate=0.001)
grad_computer = MiniBERTGradients(model)

# Run one training step
print("=" * 60)
print("SINGLE TRAINING STEP DEMONSTRATION")
print("=" * 60)

if len(mask_positions) > 0 or (isinstance(target_ids, np.ndarray) and np.any(target_ids != -100)):
    step_info = single_training_step(model, optimizer, grad_computer,
                                   masked_ids, target_ids, mask_positions)
    
    print("\n" + "=" * 40)
    print("STEP SUMMARY:")
    for key, value in step_info.items():
        print(f"  {key}: {value}")
else:
    print("No masked tokens in this example. Try again or increase mask probability.")

## Part 3: Mini Training Loop

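Let's run the full pipeline over multiple examples and record the loss and gradient norm at each step. Note that the parameter updates in this loop use random placeholder gradients rather than true backpropagated ones, so the resulting curves illustrate the monitoring workflow, not real learning.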

In [None]:
def mini_training_loop(texts, num_epochs=3, mask_prob=0.15):
    """
    Run a mini training loop on the provided texts.
    """
    # Initialize
    model = MiniBERT()
    optimizer = AdamW(learning_rate=0.0001)  # Smaller LR for stability
    grad_computer = MiniBERTGradients(model)
    
    # Training history
    history = {
        'losses': [],
        'grad_norms': [],
        'steps': [],
        'epochs': []
    }
    
    step = 0
    
    print(f"Starting mini training loop...")
    print(f"Texts: {len(texts)}, Epochs: {num_epochs}")
    print("-" * 50)
    
    for epoch in range(num_epochs):
        epoch_losses = []
        
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 30)
        
        for text_idx, text in enumerate(texts):
            # Tokenize and mask
            token_ids = tokenizer.encode(text)
            masked_ids, target_ids, mask_positions = mask_tokens(
                np.array([token_ids]),
                vocab_size=len(tokenizer.vocab),
                mask_id=tokenizer.vocab['[MASK]'],
                p_mask=mask_prob
            )
            
            # Check if we have masked tokens
            has_masked = False
            actual_positions = []
            actual_targets = []
            
            if isinstance(target_ids, np.ndarray) and target_ids.shape == masked_ids.shape:
                sentinel = -100
                for pos in range(len(target_ids[0])):
                    if target_ids[0][pos] != sentinel:
                        actual_positions.append(pos)
                        actual_targets.append(target_ids[0][pos])
                has_masked = len(actual_positions) > 0
            
            # Skip if no masks
            if not has_masked:
                continue
            
            # Forward pass
            logits, cache = model.forward(masked_ids)
            
            # Compute loss
            loss = 0.0
            for pos, target in zip(actual_positions, actual_targets):
                logit_at_pos = logits[0, pos]
                exp_logits = np.exp(logit_at_pos - np.max(logit_at_pos))
                probs = exp_logits / np.sum(exp_logits)
                loss -= np.log(probs[target] + 1e-10)
            loss /= len(actual_positions)
            
            # Simulate gradient computation and updates
            # In a real implementation, you would use actual gradients
            dummy_grads = {}
            grad_norm = 0.0
            
            for param_name, param in model.params.items():
                grad = np.random.normal(0, 0.01, param.shape)
                dummy_grads[param_name] = grad
                grad_norm += np.sum(grad ** 2)
            
            grad_norm = np.sqrt(grad_norm)
            
            # Update parameters
            optimizer_stats = optimizer.step(model.params, dummy_grads)
            
            # Record history
            history['losses'].append(loss)
            history['grad_norms'].append(grad_norm)
            history['steps'].append(step)
            history['epochs'].append(epoch)
            
            epoch_losses.append(loss)
            step += 1
            
            # Print progress
            print(f"  Text {text_idx+1}: Loss = {loss:.4f}, Grad norm = {grad_norm:.6f}")
        
        # Epoch summary
        if epoch_losses:
            avg_loss = np.mean(epoch_losses)
            print(f"  Average epoch loss: {avg_loss:.4f}")
    
    return model, history

# Run mini training
trained_model, training_history = mini_training_loop(training_texts, num_epochs=2)

print("\nTraining completed!")

## Part 4: Training Progress Visualization

In [None]:
# Visualize training progress
if training_history['losses']:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss curve
    axes[0, 0].plot(training_history['steps'], training_history['losses'], 'b-', linewidth=2)
    axes[0, 0].set_xlabel('Training Step')
    axes[0, 0].set_ylabel('MLM Loss')
    axes[0, 0].set_title('Training Loss Over Time')
    axes[0, 0].grid(True, alpha=0.3)
    
    # Gradient norms
    axes[0, 1].semilogy(training_history['steps'], training_history['grad_norms'], 'r-', linewidth=2)
    axes[0, 1].set_xlabel('Training Step')
    axes[0, 1].set_ylabel('Gradient Norm (log scale)')
    axes[0, 1].set_title('Gradient Norms')
    axes[0, 1].grid(True, alpha=0.3)
    
    # Loss distribution by epoch
    unique_epochs = sorted(set(training_history['epochs']))
    epoch_colors = ['blue', 'red', 'green', 'orange', 'purple']
    
    for i, epoch in enumerate(unique_epochs):
        epoch_losses = [loss for loss, ep in zip(training_history['losses'], 
                                                training_history['epochs']) if ep == epoch]
        if epoch_losses:
            axes[1, 0].hist(epoch_losses, bins=10, alpha=0.7, 
                          color=epoch_colors[i % len(epoch_colors)], 
                          label=f'Epoch {epoch}')
    
    axes[1, 0].set_xlabel('Loss Value')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].set_title('Loss Distribution by Epoch')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Training statistics
    stats_text = f"""
    Training Statistics:
    
    Total Steps: {len(training_history['steps'])}
    Final Loss: {training_history['losses'][-1]:.4f}
    Best Loss: {min(training_history['losses']):.4f}
    Avg Gradient Norm: {np.mean(training_history['grad_norms']):.6f}
    
    Loss Improvement:
    First: {training_history['losses'][0]:.4f}
    Last:  {training_history['losses'][-1]:.4f}
    Change: {training_history['losses'][-1] - training_history['losses'][0]:.4f}
    """
    
    axes[1, 1].text(0.1, 0.9, stats_text, transform=axes[1, 1].transAxes, 
                   fontsize=10, verticalalignment='top', fontfamily='monospace')
    axes[1, 1].set_title('Training Summary')
    axes[1, 1].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Print analysis
    print("Training Analysis:")
    if training_history['losses'][-1] < training_history['losses'][0]:
        print("✓ Loss decreased - model is learning!")
    else:
        print("⚠ Loss increased - may need tuning")
    
    if np.mean(training_history['grad_norms']) > 1.0:
        print("⚠ Large gradients - consider gradient clipping")
    else:
        print("✓ Gradient norms look healthy")
else:
    print("No training data recorded - check masking probability")

## Part 5: Batch Processing

Real training uses batches of examples for efficiency.
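
When a full batch does not fit in memory, the same update can be built from smaller micro-batches by summing their gradients before taking an optimizer step; this is gradient accumulation. The sketch below uses a toy least-squares model rather than Mini-BERT (all names and sizes are illustrative) and checks that accumulating micro-batch gradients, each weighted by its share of the examples, reproduces the full-batch gradient.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y)**2) with respect to w."""
    residual = X @ w - y
    return X.T @ residual / len(y)

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))      # 8 examples, 3 features (toy data)
y = rng.normal(size=8)
w = np.zeros(3)

# Full-batch gradient
full_grad = mse_grad(w, X, y)

# Same gradient accumulated over 4 micro-batches of 2 examples each
micro_batch_size = 2
accumulated = np.zeros_like(w)
for start in range(0, len(y), micro_batch_size):
    Xb = X[start:start + micro_batch_size]
    yb = y[start:start + micro_batch_size]
    # Weight each micro-batch mean gradient by its fraction of the full batch
    accumulated += mse_grad(w, Xb, yb) * (len(yb) / len(y))

print(np.allclose(full_grad, accumulated))

# A single parameter update then uses the accumulated gradient
learning_rate = 0.1
w = w - learning_rate * accumulated
```

For an MLM loss averaged over masked tokens, the natural weight would be each micro-batch's share of masked tokens rather than of sequences.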

In [None]:
def create_batch(texts, tokenizer, max_length=32):
    """
    Create a batch of tokenized and padded sequences.
    """
    batch_input_ids = []
    batch_attention_mask = []
    
    pad_token_id = tokenizer.vocab.get('[PAD]', 0)
    
    for text in texts:
        # Tokenize
        token_ids = tokenizer.encode(text)
        
        # Truncate if too long
        if len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
        
        # Create attention mask (1 for real tokens, 0 for padding)
        attention_mask = [1] * len(token_ids)
        
        # Pad to max_length
        while len(token_ids) < max_length:
            token_ids.append(pad_token_id)
            attention_mask.append(0)
        
        batch_input_ids.append(token_ids)
        batch_attention_mask.append(attention_mask)
    
    return np.array(batch_input_ids), np.array(batch_attention_mask)

# Create a batch
batch_texts = training_texts
batch_input_ids, batch_attention_mask = create_batch(batch_texts, tokenizer)

print(f"Batch shape: {batch_input_ids.shape}")
print(f"Attention mask shape: {batch_attention_mask.shape}")

# Visualize batch
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Input IDs
im1 = ax1.imshow(batch_input_ids, cmap='viridis', aspect='auto')
ax1.set_title('Batch Input IDs')
ax1.set_xlabel('Sequence Position')
ax1.set_ylabel('Batch Index')
ax1.set_yticks(range(len(batch_texts)))
ax1.set_yticklabels([f'Text {i+1}' for i in range(len(batch_texts))])
plt.colorbar(im1, ax=ax1)

# Attention mask
im2 = ax2.imshow(batch_attention_mask, cmap='RdYlBu', aspect='auto')
ax2.set_title('Attention Mask (1=real, 0=padding)')
ax2.set_xlabel('Sequence Position')
ax2.set_ylabel('Batch Index')
ax2.set_yticks(range(len(batch_texts)))
ax2.set_yticklabels([f'Text {i+1}' for i in range(len(batch_texts))])
plt.colorbar(im2, ax=ax2)

plt.tight_layout()
plt.show()

print("\nBatch Processing Benefits:")
print("• GPU/vectorization efficiency")
print("• More stable gradients")
print("• Better gradient estimates")
print("• Faster training overall")

## Part 6: Training Monitoring

Key metrics to watch during training.

In [None]:
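def compute_training_metrics(model, tokenizer, texts, sample_size=100):
    """
    Compute MLM loss, perplexity, and accuracy at randomly masked positions.

    Masking is random, so results vary between calls. `sample_size` is
    currently unused: every text in `texts` is evaluated.
    """
    metrics = {}
    
    # Accumulators for per-token losses and prediction counts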
    eval_losses = []
    total_tokens = 0
    correct_predictions = 0
    
    for text in texts:
        # Tokenize and mask
        token_ids = tokenizer.encode(text)
        masked_ids, target_ids, mask_positions = mask_tokens(
            np.array([token_ids]),
            vocab_size=len(tokenizer.vocab),
            mask_id=tokenizer.vocab['[MASK]'],
            p_mask=0.15
        )
        
        # Extract actual masked positions and targets
        actual_positions = []
        actual_targets = []
        
        if isinstance(target_ids, np.ndarray) and target_ids.shape == masked_ids.shape:
            sentinel = -100
            for pos in range(len(target_ids[0])):
                if target_ids[0][pos] != sentinel:
                    actual_positions.append(pos)
                    actual_targets.append(target_ids[0][pos])
        
        if not actual_positions:
            continue
        
        # Forward pass
        logits, _ = model.forward(masked_ids)
        
        # Compute predictions at masked positions
        for pos, target_id in zip(actual_positions, actual_targets):
            predicted_id = np.argmax(logits[0, pos])
            
            if predicted_id == target_id:
                correct_predictions += 1
            
            total_tokens += 1
            
            # Compute cross-entropy loss for this token
            logit_at_pos = logits[0, pos]
            exp_logits = np.exp(logit_at_pos - np.max(logit_at_pos))
            probs = exp_logits / np.sum(exp_logits)
            token_loss = -np.log(probs[target_id] + 1e-10)
            eval_losses.append(token_loss)
    
    # Compute metrics
    if eval_losses:
        metrics['avg_loss'] = np.mean(eval_losses)
        metrics['perplexity'] = np.exp(metrics['avg_loss'])
    else:
        metrics['avg_loss'] = float('inf')
        metrics['perplexity'] = float('inf')
    
    if total_tokens > 0:
        metrics['accuracy'] = correct_predictions / total_tokens
    else:
        metrics['accuracy'] = 0.0
    
    metrics['total_masked_tokens'] = total_tokens
    
    return metrics

# Compute metrics for untrained vs trained model
print("Computing training metrics...")

# Untrained model
untrained_model = MiniBERT()
untrained_metrics = compute_training_metrics(untrained_model, tokenizer, training_texts)

# Trained model (if we have one)
if 'trained_model' in locals():
    trained_metrics = compute_training_metrics(trained_model, tokenizer, training_texts)
else:
    trained_metrics = untrained_metrics

# Display comparison
print("\n" + "=" * 60)
print("TRAINING METRICS COMPARISON")
print("=" * 60)
print(f"{'Metric':<20} {'Untrained':<15} {'Trained':<15} {'Improvement':<15}")
print("-" * 60)

for metric in ['avg_loss', 'perplexity', 'accuracy']:
    untrained_val = untrained_metrics.get(metric, 0)
    trained_val = trained_metrics.get(metric, 0)
    
    if metric == 'accuracy':
        improvement = trained_val - untrained_val
        improvement_str = f"+{improvement:.3f}" if improvement >= 0 else f"{improvement:.3f}"
    else:
        if untrained_val > 0:
            improvement = (untrained_val - trained_val) / untrained_val * 100
            improvement_str = f"{improvement:.1f}%"
        else:
            improvement_str = "N/A"
    
    print(f"{metric:<20} {untrained_val:<15.3f} {trained_val:<15.3f} {improvement_str:<15}")

print(f"\nMasked tokens evaluated: {trained_metrics.get('total_masked_tokens', 0)}")

# Visualize metrics
if trained_metrics.get('total_masked_tokens', 0) > 0:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Accuracy comparison
    accuracies = [untrained_metrics['accuracy'], trained_metrics['accuracy']]
    labels = ['Untrained', 'Trained']
    colors = ['lightcoral', 'lightblue']
    
    bars = ax1.bar(labels, accuracies, color=colors)
    ax1.set_ylabel('Accuracy')
    ax1.set_title('MLM Accuracy')
    ax1.set_ylim(0, max(max(accuracies) * 1.2, 0.05))  # keep a visible axis when both accuracies are 0
    
    # Add value labels
    for bar, acc in zip(bars, accuracies):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{acc:.3f}', ha='center', va='bottom')
    
    # Perplexity comparison
    perplexities = [untrained_metrics['perplexity'], trained_metrics['perplexity']]
    bars = ax2.bar(labels, perplexities, color=colors)
    ax2.set_ylabel('Perplexity')
    ax2.set_title('Perplexity (lower is better)')
    ax2.set_ylim(0, max(perplexities) * 1.2)
    
    # Add value labels
    for bar, perp in zip(bars, perplexities):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + max(perplexities) * 0.02,
                f'{perp:.1f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()

print("\nTraining Success Indicators:")
print("✓ Loss decreases over time")
print("✓ Accuracy increases")
print("✓ Perplexity decreases")
print("✓ Gradients remain stable")
print("✓ No NaN or inf values")

## Summary: Key Training Concepts

### **1. MLM Objective**
- Mask 15% of tokens randomly
- Predict original tokens from context
- Forces model to learn bidirectional representations
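
The masking rule can be sketched in a few lines of NumPy. This is an illustrative stand-in, not the `mask_tokens` function imported from `mlm`: the name `simple_mask_tokens`, the `-100` ignore value for unselected positions, and the toy `mask_id` and `vocab_size` are assumptions. It selects about 15% of positions, then replaces 80% of those with `[MASK]`, 10% with a random token, and leaves 10% unchanged.

```python
import numpy as np

def simple_mask_tokens(ids, vocab_size, mask_id, p_mask=0.15, rng=None):
    """Return (masked_ids, targets); targets hold -100 where no prediction is needed."""
    rng = np.random.default_rng() if rng is None else rng
    ids = np.asarray(ids)
    masked_ids = ids.copy()
    targets = np.full_like(ids, -100)

    # Choose which positions the model must predict
    selected = rng.random(ids.shape) < p_mask
    targets[selected] = ids[selected]

    # Split the selected positions 80 / 10 / 10
    roll = rng.random(ids.shape)
    use_mask = selected & (roll < 0.8)
    use_random = selected & (roll >= 0.8) & (roll < 0.9)
    # Remaining selected positions (roll >= 0.9) keep their original token

    masked_ids[use_mask] = mask_id
    masked_ids[use_random] = rng.integers(0, vocab_size, size=int(use_random.sum()))
    return masked_ids, targets

rng = np.random.default_rng(0)
ids = rng.integers(5, 100, size=(2, 20))   # toy batch: 2 sequences of 20 token IDs
masked_ids, targets = simple_mask_tokens(ids, vocab_size=100, mask_id=4, rng=rng)

print("selected positions:", int((targets != -100).sum()))
```

A fuller implementation would typically also exclude special tokens such as `[CLS]`, `[SEP]`, and `[PAD]` from selection; that is omitted here for brevity.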

### **2. Training Loop**
```python
for batch in dataloader:
    # Forward pass
    logits = model(masked_inputs)
    
    # Loss computation
    loss = cross_entropy(logits[mask_positions], targets)
    
    # Backward pass
    gradients = backward(loss)
    
    # Parameter update
    optimizer.step(gradients)
```
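
A runnable toy version of this loop, with real gradients, might look like the following. It swaps Mini-BERT for a single softmax layer and AdamW for plain gradient descent, and uses random features in place of contextual representations, so it illustrates only the forward, loss, backward, and update cycle. All names and sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
num_examples, hidden, vocab = 64, 16, 10

# Toy stand-ins: one feature vector per masked position, one target token each
features = rng.normal(size=(num_examples, hidden))
targets = rng.integers(0, vocab, size=num_examples)

W = rng.normal(scale=0.01, size=(hidden, vocab))   # the only trainable parameter
learning_rate = 0.2
batch_size = 16
losses = []

for epoch in range(20):
    for start in range(0, num_examples, batch_size):
        x = features[start:start + batch_size]
        t = targets[start:start + batch_size]

        # Forward pass
        logits = x @ W
        logits -= logits.max(axis=1, keepdims=True)        # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)

        # Loss computation: mean cross-entropy over the batch
        loss = -np.mean(np.log(probs[np.arange(len(t)), t]))

        # Backward pass: d(loss)/d(logits) = (probs - one_hot) / batch size
        grad_logits = probs.copy()
        grad_logits[np.arange(len(t)), t] -= 1.0
        grad_logits /= len(t)
        grad_W = x.T @ grad_logits

        # Parameter update (plain gradient descent)
        W -= learning_rate * grad_W
        losses.append(loss)

print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}")
```

Because the targets are random, the model here can only memorize the toy data; the point is that the loss falls once the update uses the true gradient.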

### **3. Key Monitoring Metrics**
- **Loss**: Should decrease over time
- **Accuracy**: Percentage of correctly predicted masks
- **Perplexity**: exp(average loss); roughly the number of equally likely tokens the model is choosing among (lower is better)
- **Gradient norms**: Should be stable, not exploding/vanishing
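
A useful reference point when reading these metrics: a model that spreads probability uniformly over a vocabulary of size V has loss ln(V) and perplexity V. The sketch below computes loss, perplexity, and accuracy from logits and targets, then checks that baseline. The function name, the `-100` ignore value, and the vocabulary size of 8192 are assumptions for illustration; the actual tokenizer's vocabulary size may differ.

```python
import numpy as np

def masked_lm_metrics(logits, targets, ignore_index=-100):
    """Loss, perplexity, and accuracy over positions where targets != ignore_index."""
    keep = targets != ignore_index
    picked_logits = logits[keep]                 # [num_masked, vocab]
    picked_targets = targets[keep]               # [num_masked]

    # Log-softmax, computed stably
    shifted = picked_logits - picked_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    loss = -log_probs[np.arange(len(picked_targets)), picked_targets].mean()
    accuracy = (picked_logits.argmax(axis=1) == picked_targets).mean()
    return {"loss": loss, "perplexity": np.exp(loss), "accuracy": accuracy}

vocab_size = 8192                                   # assumed size, for illustration
targets = np.array([[-100, 17, -100, 256, -100]])   # two masked positions
uniform_logits = np.zeros((1, 5, vocab_size))       # every token equally likely

metrics = masked_lm_metrics(uniform_logits, targets)
print(f"loss = {metrics['loss']:.3f}, ln(V) = {np.log(vocab_size):.3f}")
print(f"perplexity = {metrics['perplexity']:.1f}")
```

An untrained model whose loss sits far above ln(V) is worth investigating, since it is doing worse than a uniform guess.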

### **4. Training Best Practices**
- Use appropriate learning rates (1e-4 to 1e-3)
- Monitor gradient norms
- Use learning rate scheduling
- Implement gradient clipping if needed
- Validate on held-out data
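
Two of these practices, gradient clipping and learning rate scheduling, are short enough to sketch directly. The helpers below are illustrative and not part of the notebook's `optimizer` module; the warmup-then-linear-decay shape is one common choice among several, and the step counts are arbitrary.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients so their combined L2 norm is at most max_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))
    scale = min(1.0, max_norm / (global_norm + 1e-12))
    return {name: g * scale for name, g in grads.items()}, global_norm

def warmup_linear_decay(step, peak_lr, warmup_steps, total_steps):
    """Linear warmup to peak_lr, then linear decay toward zero at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / max(total_steps - warmup_steps, 1)

# Toy gradients with a global norm of 5 (a 3-4-5 triangle)
grads = {"W": np.array([3.0, 0.0]), "b": np.array([4.0])}
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped.values()))
print(f"norm before: {norm_before:.2f}, after: {norm_after:.2f}")

schedule = [warmup_linear_decay(s, peak_lr=1e-3, warmup_steps=10, total_steps=100)
            for s in range(100)]
print(f"lr at step 0: {schedule[0]:.1e}, peak: {max(schedule):.1e}, last: {schedule[-1]:.1e}")
```

Clipping by the global norm keeps the direction of the update unchanged and only shrinks its length, which is why it is usually preferred over clipping each value independently.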

### **5. Common Issues**
- **Loss not decreasing**: LR too high/low, poor initialization
- **Gradient explosion**: Clip gradients, reduce LR
- **Gradient vanishing**: Check residual connections, LR
- **NaN values**: Usually gradient explosion, reduce LR
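
Besides exploding gradients, NaNs can also come from overflow inside the softmax itself, which is why the loss code in this notebook subtracts the maximum logit before exponentiating. A small sketch of the failure, the fix, and an early finite-value check (function names are illustrative):

```python
import numpy as np

def naive_softmax(logits):
    exp = np.exp(logits)                 # overflows to inf for large logits
    return exp / exp.sum()

def stable_softmax(logits):
    exp = np.exp(logits - logits.max())  # largest exponent is exp(0) = 1
    return exp / exp.sum()

large_logits = np.array([1000.0, 999.0, 998.0])

# Silence the overflow / invalid-value warnings for the demonstration
with np.errstate(over="ignore", invalid="ignore"):
    naive = naive_softmax(large_logits)
stable = stable_softmax(large_logits)

print("naive has NaN:   ", bool(np.isnan(naive).any()))
print("stable is finite:", bool(np.isfinite(stable).all()))

def check_finite(name, array):
    """Raise early, with the tensor's name, instead of letting NaN spread."""
    if not np.all(np.isfinite(array)):
        raise FloatingPointError(f"non-finite values in {name}")

check_finite("stable softmax", stable)
```

Calling a check like `check_finite` on the loss and on each gradient after the backward pass makes it easier to find the first step at which values went bad.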

## Exercises

1. **Learning Rate Sensitivity**: Try different learning rates (1e-2, 1e-3, 1e-4, 1e-5). How does training change?

2. **Masking Probability**: Experiment with different mask probabilities (10%, 15%, 25%). What's optimal?

3. **Batch Size Effects**: Compare training with different batch sizes. How does it affect convergence?

4. **Training Diagnostics**: Implement additional metrics like token-level accuracy for different word types.

In [None]:
# Space for your experiments
