# Griffin Model Training: Comparative Training Experiments

This notebook demonstrates how to train and compare the three model architectures:
- **Griffin**: Hybrid recurrence + attention model
- **Hawk**: Pure recurrent model
- **Local Attention**: Pure attention model

We'll train these models on both MQAR and Chomsky hierarchy datasets and track their learning curves.

In [None]:
# Import necessary libraries
import sys
import os
import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime
import json

# Add project root to path.
# The parent of the current working directory is treated as the project root.
project_root = Path('.').absolute().parent
# Insert at the FRONT of sys.path so the project's own top-level packages
# (`datasets`, `models`, `training`, `evaluation`) take precedence over any
# same-named distribution installed in the environment. The membership check
# keeps sys.path from growing every time this cell is re-run.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Import project modules
from models.griffin.griffin_model import GriffinModel
from models.hawk.hawk_model import HawkModel
from models.local_attention.attention_model import LocalAttentionModel
from datasets.mqar import MQARDataset
from datasets.chomsky import ChomskyDataset
from training.trainer import Trainer
from evaluation.evaluator import ModelEvaluator

# Figure appearance: apply the matplotlib style sheet first, then the seaborn
# palette (the order is kept as-is because both touch the colour cycle).
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Pick the compute device: CUDA when PyTorch reports it as available, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# On a GPU, also report the name and total memory of device index 0.
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1. Configuration Setup

Let's load the model configurations and set up training parameters.

In [None]:
# Load configurations
# All YAML configuration files are looked up in <project_root>/config.
config_dir = project_root / 'config'

def load_config(config_name):
    """Parse ``<config_dir>/<config_name>.yaml`` and return its content.

    Args:
        config_name: File stem of the configuration (without ``.yaml``).

    Returns:
        Whatever ``yaml.safe_load`` produces for the file.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    config_path = config_dir / f'{config_name}.yaml'
    # Explicit encoding: the platform default is not UTF-8 on every system.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

# Parse one config per architecture, plus the shared training settings
griffin_config = load_config('griffin_config')
hawk_config = load_config('hawk_config')
attention_config = load_config('attention_config')
training_config = load_config('training_config')

print("Configurations loaded successfully!")

# Echo the training hyperparameters, one per line
print(
    f"\nTraining configuration:",
    f"  Epochs: {training_config['num_epochs']}",
    f"  Batch size: {training_config['batch_size']}",
    f"  Learning rate: {training_config['learning_rate']}",
    f"  Weight decay: {training_config['weight_decay']}",
    sep="\n",
)

## 2. Dataset Preparation

Create the training, validation, and test splits (and their dataloaders) for our experiments.

In [None]:
# Create datasets with notebook-friendly sizes
# The two factory helpers used below are not brought in by the notebook's
# import cell (it only imports the dataset classes), so calling them raised
# NameError. Import them here.
# NOTE(review): assumed to be defined next to MQARDataset / ChomskyDataset —
# confirm the module paths.
from datasets.mqar import create_mqar_datasets
from datasets.chomsky import create_chomsky_datasets

print("Creating datasets...")

# MQAR dataset: sequence length and vocabulary size follow the Griffin config
train_mqar, val_mqar, test_mqar = create_mqar_datasets(
    train_size=2000,  # Reduced for notebook
    val_size=400,
    test_size=400,
    seq_len=griffin_config['max_seq_len'],
    vocab_size=griffin_config['vocab_size'],
    num_kv_pairs=8,
    num_queries=2
)

# Parentheses dataset
train_paren, val_paren, test_paren = create_chomsky_datasets(
    dataset_type="parentheses",
    train_size=2000,
    val_size=400,
    test_size=400,
    max_length=128
)

# Create dataloaders (only the training split is shuffled)
batch_size = training_config['batch_size']

# Per-dataset registry consumed by the training and evaluation cells:
# 'train' / 'val' / 'test' dataloaders plus the vocabulary size used to build
# the models.
# NOTE(review): the two dataset types expose their vocabulary size differently
# (method call vs. attribute) — confirm both are the intended accessors.
datasets = {
    'MQAR': {
        'train': train_mqar.create_dataloader(batch_size=batch_size, shuffle=True),
        'val': val_mqar.create_dataloader(batch_size=batch_size, shuffle=False),
        'test': test_mqar.create_dataloader(batch_size=batch_size, shuffle=False),
        'vocab_size': train_mqar.get_vocab_size()
    },
    'Parentheses': {
        'train': train_paren.create_dataloader(batch_size=batch_size, shuffle=True),
        'val': val_paren.create_dataloader(batch_size=batch_size, shuffle=False),
        'test': test_paren.create_dataloader(batch_size=batch_size, shuffle=False),
        'vocab_size': train_paren.total_vocab_size
    }
}

print("\nDatasets created:")
for name, data in datasets.items():
    print(f"  {name}:")
    print(f"    Train batches: {len(data['train'])}")
    print(f"    Val batches: {len(data['val'])}")
    print(f"    Vocab size: {data['vocab_size']}")

## 3. Model Initialization

Create and initialize all three models.

In [None]:
def create_models(vocab_size):
    """Build one fresh instance of each architecture for a vocabulary size.

    Every model receives ``vocab_size`` directly; its remaining constructor
    arguments are read from that architecture's loaded configuration.

    Returns:
        Dict with the keys 'Griffin', 'Hawk' and 'Local Attention' mapping to
        the corresponding model instance.
    """

    def _pick(config, *keys):
        # Constructor kwargs: the named entries of one model's config.
        return {key: config[key] for key in keys}

    return {
        # Griffin model
        'Griffin': GriffinModel(
            vocab_size=vocab_size,
            **_pick(griffin_config, 'd_model', 'num_layers', 'num_heads',
                    'max_seq_len', 'local_window', 'mixing_alpha'),
        ),
        # Hawk model
        'Hawk': HawkModel(
            vocab_size=vocab_size,
            **_pick(hawk_config, 'd_model', 'num_layers', 'max_seq_len'),
        ),
        # Local Attention model
        'Local Attention': LocalAttentionModel(
            vocab_size=vocab_size,
            **_pick(attention_config, 'd_model', 'num_layers', 'num_heads',
                    'max_seq_len', 'local_window'),
        ),
    }

# The vocabulary size differs per dataset, so models are built later, per dataset
print("Model creation function ready!")

## 4. Training Function

Define a helper function to train models and track metrics.

In [None]:
def train_model_experiment(model, model_name, train_loader, val_loader,
                          dataset_name, num_epochs=5):
    """Train a single model and return its per-epoch training history.

    Args:
        model: Model instance handed to the trainer.
        model_name: Display name; also part of the trainer's experiment name.
        train_loader: Dataloader passed to ``trainer.train_epoch``.
        val_loader: Dataloader passed to ``trainer.validate``.
        dataset_name: Display name; also part of the trainer's experiment name.
        num_epochs: Number of train/validate rounds to run.

    Returns:
        Tuple ``(history, trainer)``. ``history`` maps 'train_loss',
        'val_loss', 'train_perplexity', 'val_perplexity' and 'epochs' to lists
        with one entry per completed epoch (epochs are numbered from 1);
        ``trainer`` is the trainer object that ran the training.
    """

    print(f"Training {model_name} on {dataset_name}...")

    # Create trainer.
    # Uses the `Trainer` class imported at the top of the notebook; the name
    # used here before (`BaseTrainer`) was never imported and raised NameError.
    # NOTE(review): confirm that `Trainer` accepts these keyword arguments.
    trainer = Trainer(
        model=model,
        device=device,
        learning_rate=training_config['learning_rate'],
        weight_decay=training_config['weight_decay'],
        warmup_steps=100,  # Reduced for notebook
        save_dir=project_root / 'notebooks' / 'training_outputs',
        experiment_name=f"{model_name}_{dataset_name}",
        mixed_precision=training_config.get('mixed_precision', False)
    )

    # Training history: parallel lists, one entry per epoch
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_perplexity': [],
        'val_perplexity': [],
        'epochs': []
    }

    # Training loop
    for epoch in range(num_epochs):
        print(f"  Epoch {epoch + 1}/{num_epochs}")

        # Train for one epoch
        train_metrics = trainer.train_epoch(train_loader)

        # Validate
        val_metrics = trainer.validate(val_loader)

        # Store metrics (both metric dicts are read via 'loss' / 'perplexity')
        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['train_perplexity'].append(train_metrics['perplexity'])
        history['val_perplexity'].append(val_metrics['perplexity'])
        history['epochs'].append(epoch + 1)

        print(f"    Train Loss: {train_metrics['loss']:.4f}, "
              f"Val Loss: {val_metrics['loss']:.4f}, "
              f"Val PPL: {val_metrics['perplexity']:.4f}")

    return history, trainer

print("Training function defined!")

## 5. Training Experiments

Now let's train all models on both datasets and compare their learning curves.

In [None]:
# Run every (dataset, architecture) combination and collect the outcomes:
#   all_results[dataset][model]    -> training history, or None if training failed
#   trained_models[dataset][model] -> trained model (successful runs only)
all_results, trained_models = {}, {}

# Reduced epochs for notebook demonstration
num_epochs = 3

for dataset_name, dataset_info in datasets.items():
    print(f"\n{'='*60}")
    print(f"TRAINING ON {dataset_name} DATASET")
    print(f"{'='*60}")

    # Fresh models sized for this dataset's vocabulary
    models = create_models(dataset_info['vocab_size'])

    dataset_results, dataset_models = {}, {}

    for model_name, model in models.items():
        # A failure in one architecture must not abort the others, so the
        # error is reported and the run is marked with None.
        try:
            history, trainer = train_model_experiment(
                model=model,
                model_name=model_name,
                train_loader=dataset_info['train'],
                val_loader=dataset_info['val'],
                dataset_name=dataset_name,
                num_epochs=num_epochs,
            )
            dataset_results[model_name] = history
            dataset_models[model_name] = trainer.model
        except Exception as e:
            print(f"Error training {model_name}: {e}")
            dataset_results[model_name] = None

    all_results[dataset_name] = dataset_results
    trained_models[dataset_name] = dataset_models

print(f"\n{'='*60}")
print("TRAINING COMPLETED!")
print(f"{'='*60}")

## 6. Learning Curves Visualization

Let's visualize the learning curves for all models and datasets.

In [None]:
# Plot learning curves
# One row per dataset: training loss on the left, validation loss on the right.
# The grid is sized from the results instead of being fixed at two rows, and
# squeeze=False keeps `axes` two-dimensional even for a single row.
num_rows = max(len(all_results), 1)
fig, axes = plt.subplots(num_rows, 2, figsize=(16, 6 * num_rows), squeeze=False)
fig.suptitle('Learning Curves: Training Progress Comparison', fontsize=16, fontweight='bold')

colors = ['blue', 'red', 'green']
model_names = ['Griffin', 'Hawk', 'Local Attention']

for dataset_idx, (dataset_name, results) in enumerate(all_results.items()):
    ax_train, ax_val = axes[dataset_idx]
    ax_train.set_title(f'{dataset_name}: Training Loss')
    ax_val.set_title(f'{dataset_name}: Validation Loss')

    for model_idx, (model_name, history) in enumerate(results.items()):
        # None marks a model whose training failed; nothing to draw for it
        if history is None:
            continue

        # Colour is assigned by the model's position, cycling if needed
        color = colors[model_idx % len(colors)]

        # Plot training loss
        ax_train.plot(history['epochs'], history['train_loss'],
                      color=color, label=model_name, marker='o', linewidth=2)

        # Plot validation loss
        ax_val.plot(history['epochs'], history['val_loss'],
                    color=color, label=model_name, marker='s', linewidth=2)

    for ax in (ax_train, ax_val):
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        # Only draw a legend when at least one curve was plotted; an empty
        # legend just triggers a matplotlib warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Perplexity Comparison

Let's also look at perplexity curves to understand model performance better.

In [None]:
# Plot perplexity curves
# One panel per dataset. The panel count follows the results instead of being
# fixed at two, and squeeze=False keeps `axes` two-dimensional for any count.
num_panels = max(len(all_results), 1)
fig, axes = plt.subplots(1, num_panels, figsize=(8 * num_panels, 6), squeeze=False)
fig.suptitle('Validation Perplexity: Model Performance Comparison', fontsize=16, fontweight='bold')

# Same colour list as the learning-curve figure; defined here as well so this
# cell does not depend on that cell having been run first.
colors = ['blue', 'red', 'green']

for dataset_idx, (dataset_name, results) in enumerate(all_results.items()):
    ax = axes[0, dataset_idx]
    ax.set_title(f'{dataset_name}: Validation Perplexity')

    for model_idx, (model_name, history) in enumerate(results.items()):
        # None marks a model whose training failed; nothing to draw for it
        if history is None:
            continue
        color = colors[model_idx % len(colors)]
        ax.plot(history['epochs'], history['val_perplexity'],
                color=color, label=model_name, marker='o', linewidth=2)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Perplexity')
    # Skip the legend on an empty panel to avoid a matplotlib warning
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')  # Log scale for better visualization

plt.tight_layout()
plt.show()

## 8. Final Performance Summary

Let's create a comprehensive summary of the final performance.

In [None]:
# Create performance summary
# Imported unconditionally so `pd` is bound even when no run produced results;
# the test-evaluation cell further down uses `pd` as well.
import pandas as pd

# One row per (dataset, model) pair that finished at least one epoch
summary_data = []

for dataset_name, results in all_results.items():
    for model_name, history in results.items():
        # Skip failed runs (None) and runs without any recorded epoch
        if history is None or len(history['val_loss']) == 0:
            continue

        final_train_loss = history['train_loss'][-1]
        final_val_loss = history['val_loss'][-1]
        final_val_ppl = history['val_perplexity'][-1]

        summary_data.append({
            'Dataset': dataset_name,
            'Model': model_name,
            'Final Train Loss': final_train_loss,
            'Final Val Loss': final_val_loss,
            'Final Val Perplexity': final_val_ppl,
            # Heuristic label: validation loss within 20% of training loss
            'Convergence': 'Good' if final_val_loss < final_train_loss * 1.2 else 'Overfitting'
        })

if summary_data:
    df_summary = pd.DataFrame(summary_data)

    print("FINAL PERFORMANCE SUMMARY")
    print("=" * 60)
    print(df_summary.to_string(index=False))

    # Create summary visualization
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Final validation loss comparison
    sns.barplot(data=df_summary, x='Dataset', y='Final Val Loss', hue='Model', ax=axes[0])
    axes[0].set_title('Final Validation Loss Comparison')
    axes[0].tick_params(axis='x', rotation=45)

    # Final perplexity comparison
    sns.barplot(data=df_summary, x='Dataset', y='Final Val Perplexity', hue='Model', ax=axes[1])
    axes[1].set_title('Final Validation Perplexity Comparison')
    axes[1].tick_params(axis='x', rotation=45)
    axes[1].set_yscale('log')

    plt.tight_layout()
    plt.show()

    # Best performers: per dataset, the model with the lowest final validation loss
    print("\nBEST PERFORMERS:")
    print("=" * 30)
    for dataset in df_summary['Dataset'].unique():
        dataset_df = df_summary[df_summary['Dataset'] == dataset]
        best_model = dataset_df.loc[dataset_df['Final Val Loss'].idxmin(), 'Model']
        best_loss = dataset_df['Final Val Loss'].min()
        print(f"{dataset}: {best_model} (Loss: {best_loss:.4f})")
else:
    print("No training results available for summary.")

## 9. Model Evaluation on Test Set

Now let's evaluate our trained models on the test sets.

In [None]:
# Test set evaluation
# `pd` is needed below to tabulate the results. Elsewhere in the notebook it is
# only imported inside a conditional branch of the summary cell, so import it
# here to keep this cell from failing with NameError.
import pandas as pd

print("EVALUATING ON TEST SETS")
print("=" * 40)

evaluator = ModelEvaluator(device=device)

# One row per successfully evaluated (dataset, model) pair
test_results = []

for dataset_name, dataset_info in datasets.items():
    print(f"\nEvaluating on {dataset_name} test set...")

    # Datasets without any trained model contribute no rows
    for model_name, model in trained_models.get(dataset_name, {}).items():
        # A failure for one model is reported and must not stop the others
        try:
            # Evaluate model
            metrics = evaluator.evaluate_model(
                model=model,
                dataloader=dataset_info['test'],
                model_name=model_name,
                dataset_name=dataset_name
            )

            test_results.append({
                'Dataset': dataset_name,
                'Model': model_name,
                'Test Loss': metrics.loss,
                'Test Perplexity': metrics.perplexity,
                'Inference Time': metrics.inference_time,
                'Memory Usage': metrics.memory_usage
            })

            print(f"  {model_name}: Loss={metrics.loss:.4f}, PPL={metrics.perplexity:.4f}")

        except Exception as e:
            print(f"  Error evaluating {model_name}: {e}")

if test_results:
    df_test = pd.DataFrame(test_results)

    print("\nTEST SET RESULTS:")
    print(df_test.to_string(index=False))

    # Test results visualization
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.barplot(data=df_test, x='Dataset', y='Test Loss', hue='Model', ax=ax)
    ax.set_title('Test Set Performance Comparison')
    ax.tick_params(axis='x', rotation=45)
    plt.tight_layout()
    plt.show()
else:
    print("No test results available.")

## 10. Save Results

Let's save our training results for future analysis.

In [None]:
# Save results
results_dir = project_root / 'notebooks' / 'results'
# parents=True also creates any missing intermediate directory
results_dir.mkdir(parents=True, exist_ok=True)

# One timestamp shared by every file written in this run
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

def _to_builtin(value):
    """``json.dump`` fallback for scalar wrappers that expose ``.item()``.

    Only called for objects the encoder cannot serialize natively. Objects
    with an ``item()`` method (as numpy scalars and single-element tensors
    have) are unwrapped; anything else is rejected as before.
    NOTE(review): the metric types returned by the trainer are not visible
    here — confirm this covers them.
    """
    item = getattr(value, 'item', None)
    if callable(item):
        return item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

# Save training history (failed runs are stored as null)
results_file = results_dir / f'training_results_{timestamp}.json'
with open(results_file, 'w', encoding='utf-8') as f:
    json.dump(all_results, f, indent=2, default=_to_builtin)

print(f"Training results saved to: {results_file}")

# Save summary tables; each DataFrame only exists when its source list is non-empty
if summary_data:
    summary_file = results_dir / f'performance_summary_{timestamp}.csv'
    df_summary.to_csv(summary_file, index=False)
    print(f"Performance summary saved to: {summary_file}")

if test_results:
    test_file = results_dir / f'test_results_{timestamp}.csv'
    df_test.to_csv(test_file, index=False)
    print(f"Test results saved to: {test_file}")

print("\nAll results saved successfully!")

## 11. Conclusions

This notebook has demonstrated:

### Training Process:
1. **Setup**: Configured models and datasets for comparative training
2. **Training**: Trained all three architectures on both MQAR and Chomsky datasets
3. **Monitoring**: Tracked training and validation metrics across epochs
4. **Evaluation**: Assessed final performance on test sets

### Key Observations:
- **Convergence Speed**: How quickly each model learns the tasks
- **Generalization**: Performance gap between training and validation
- **Task Specificity**: Which architectures excel on which types of problems
- **Efficiency**: Inference time and memory usage recorded during test-set evaluation

### Next Steps:
1. **Longer Training**: Train for more epochs to see full convergence
2. **Hyperparameter Tuning**: Optimize learning rates and architectures
3. **More Datasets**: Test on additional sequence modeling tasks
4. **Analysis**: Deep dive into what each model learns differently

This comparative study provides insight into the trade-offs between different sequence modeling approaches and shows whether Griffin's hybrid design offers an advantage in practice.