# TinyLM Training on Google Colab

## This notebook trains a TinyLM model on ARC-AGI data using Google Colab's GPU resources.

### Setup Instructions:
1. Upload this notebook to Google Colab
2. Enable GPU runtime: Runtime → Change runtime type → GPU (T4/V100)
3. Run all cells in order

In [8]:
# Check GPU availability and install requirements
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. Please enable GPU runtime in Colab!")

PyTorch version: 2.2.2
CUDA available: False


In [None]:
# Clone the repository from GitHub
!git clone https://github.com/CalebTalley2024/ARC-AGI-2.git
%cd ARC-AGI-2
# Checkout specific branch for consistency
!git checkout vedant

In [None]:
# Install exact package versions for reproducibility
!pip install --quiet \
    numpy==1.24.4 \
    matplotlib==3.7.5 \
    pandas==2.0.3 \
    scipy==1.10.1 \
    scikit-learn==1.3.2 \
    torch==2.2.2 \
    torchvision==0.17.2 \
    torchaudio==2.2.2 \
    transformers==4.46.3 \
    huggingface-hub==0.36.0 \
    tokenizers==0.20.3 \
    safetensors==0.5.3 \
    seaborn==0.13.2 \
    plotly==6.4.0 \
    tqdm==4.67.1 \
    pyyaml==6.0.3 \
    requests==2.32.4 \
    packaging==25.0 \
    jsonschema==4.23.0 \
    fastjsonschema==2.21.2 \
    jinja2==3.1.6 \
    markupsafe==2.1.5 \
    urllib3==2.2.3 \
    certifi==2025.10.5 \
    charset-normalizer==3.4.4 \
    idna==3.11 \
    python-dateutil==2.9.0.post0 \
    pytz==2025.2 \
    tzdata==2025.2 \
    six==1.17.0 \
    setuptools==75.3.2

# Install the package in development mode
!pip install -e .

In [None]:
# Import necessary libraries
import sys
import os
from pathlib import Path

# Add project root to Python path (go up one level from notebooks)
project_root = Path.cwd()
sys.path.append(str(project_root))

# Import the training function and centralized constants
from arc.models.train import train
from arc.models.tiny_lm import TinyLMConfig
from arc.utils.constants import (
    MODEL_CONFIGS, 
    TRAINING_CONFIGS, 
    get_matched_configs,
    estimate_model_parameters
)

print("Successfully imported training modules and centralized constants")
print(f"Project root: {project_root}")
print(f"Available model sizes: {list(MODEL_CONFIGS.keys())}")
print(f"Available training profiles: {list(TRAINING_CONFIGS.keys())}")

Successfully imported training modules and centralized constants
Project root: /Users/vedanttibrewal/Documents/USC/lectures/sem_3/CSCI-544/project/ARC-AGI-2
Available model sizes: ['tiny', 'small', 'medium', 'large']
Available training profiles: ['debug', 'small_gpu', 'medium_gpu', 'large_gpu']


In [5]:
# Check data directory structure
data_dir = project_root / "data"
print("Data directory contents:")
if data_dir.exists():
    for item in data_dir.iterdir():
        print(f"  {item.name}")
        if item.is_dir():
            for subitem in item.iterdir():
                print(f"    {subitem.name}")
                if subitem.name == "arc" and subitem.is_dir():
                    for arcitem in subitem.iterdir():
                        print(f"      {arcitem.name}")
else:
    print("  WARNING: Data directory not found!")

# Check if training data exists
training_path = data_dir / "raw" / "arc" / "training"
eval_path = data_dir / "raw" / "arc" / "evaluation"

print(f"\nTraining data exists: {training_path.exists()}")
print(f"Evaluation data exists: {eval_path.exists()}")

if training_path.exists():
    training_files = list(training_path.glob("*.json"))
    print(f"Number of training files: {len(training_files)}")
    
if eval_path.exists():
    eval_files = list(eval_path.glob("*.json"))
    print(f"Number of evaluation files: {len(eval_files)}")

Data directory contents:
  processed
    index.jsonl
    .gitkeep
    eval_tasks.json
    dev_tasks.json
  raw
    arc
      training
      .gitkeep
      evaluation.txt
      training.txt
      evaluation

Training data exists: True
Evaluation data exists: True
Number of training files: 1000
Number of evaluation files: 120


## Training Configuration

Configure the training parameters. Adjust these based on your needs and available GPU memory.

In [None]:
# GPU-aware dynamic configuration selection
def select_optimal_configs(gpu_memory_gb):
    """Select optimal model and training configs based on GPU memory."""
    
    print(f"GPU Memory: {gpu_memory_gb:.1f} GB")
    
    # Select configs based on GPU memory
    if gpu_memory_gb < 4:
        model_size = 'tiny'
        training_profile = 'debug'
        print("Low GPU memory detected - using minimal config for testing")
    elif gpu_memory_gb < 8:
        model_size = 'tiny'
        training_profile = 'small_gpu'
        print("Small GPU detected - using tiny model with memory optimization")
    elif gpu_memory_gb < 16:
        model_size = 'small'
        training_profile = 'medium_gpu'
        print("Medium GPU detected - using small model")
    else:
        model_size = 'medium'
        training_profile = 'large_gpu'
        print("Large GPU detected - using medium model for best performance")
    

    model_config = "small"
    training_config = "small_gpu"

    # Get matched configurations
    model_config, training_config = get_matched_configs(model_size, training_profile)
    
    # Display configuration info
    param_count = estimate_model_parameters(model_config)
    effective_batch_size = training_config['batch_size'] * training_config['grad_accumulation_steps']
    
    print(f"\nSelected Configuration:")
    print(f"  Model: {model_size} ({param_count/1e6:.1f}M parameters)")
    print(f"  Training profile: {training_profile}")
    print(f"  Batch size: {training_config['batch_size']} (effective: {effective_batch_size})")
    print(f"  Max sequence length: {training_config['max_sequence_length']}")
    print(f"  Gradient accumulation: {training_config['grad_accumulation_steps']} steps")
    
    return model_size, training_profile, model_config, training_config

# Check GPU and select optimal configuration
if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    selected_model_size, selected_training_profile, model_config, training_config = select_optimal_configs(gpu_memory)
else:
    print("No GPU detected! Using CPU-friendly minimal config")
    selected_model_size = 'tiny'
    selected_training_profile = 'debug'
    model_config, training_config = get_matched_configs(selected_model_size, selected_training_profile)

# Training configuration for this session
TRAINING_CONFIG = {
    # Data paths
    "data_path": str(data_dir / "raw" / "arc" / "training"),
    "model_dir": "./models/tinylm_checkpoints",
    
    # Use centralized configurations
    "model_config": model_config,
    "training_config": training_config,
    
    # Store selected profiles for use in training
    "selected_model_size": selected_model_size,
    "selected_training_profile": selected_training_profile,
}

print(f"\nFinal Training Configuration:")
print(f"  Model size: {selected_model_size}")
print(f"  Training profile: {selected_training_profile}")
print(f"  Data path: {TRAINING_CONFIG['data_path']}")
print(f"  Model dir: {TRAINING_CONFIG['model_dir']}")
print(f"  Steps: {training_config['steps']:,}")
print(f"  Learning rate: {training_config['learning_rate']}")
print(f"  Weight decay: {training_config['weight_decay']}")
print(f"  Use AMP: {training_config['use_amp']}")

No GPU detected! Using CPU-friendly minimal config

Final Training Configuration:
  Data path: /Users/vedanttibrewal/Documents/USC/lectures/sem_3/CSCI-544/project/ARC-AGI-2/data/raw/arc/training
  Model dir: ./models/tinylm_checkpoints
  Steps: 100
  Learning rate: 0.0003
  Weight decay: 0.01
  Use AMP: True


In [None]:
# Create output directory
output_dir = Path(TRAINING_CONFIG["model_dir"])
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Created output directory: {output_dir}")

# Test model creation with centralized config
print("\nTesting model creation with selected configuration...")
try:
    from arc.utils.constants import VOCAB_SIZE
    print(f"Vocabulary size: {VOCAB_SIZE}")
    
    # Create model with selected configuration
    model_cfg = TRAINING_CONFIG["model_config"]
    test_config = TinyLMConfig(**model_cfg)
    print(f"Model config created successfully")
    print(f"  Architecture: {test_config.d_model}D, {test_config.n_layers} layers, {test_config.n_heads} heads")
    print(f"  Parameters: ~{estimate_model_parameters(model_cfg)/1e6:.1f}M")
    print(f"  Max sequence length: {test_config.max_len}")
    
    # Display memory usage estimate
    param_size_mb = estimate_model_parameters(model_cfg) * 4 / 1e6  # 4 bytes per float32
    training_cfg = TRAINING_CONFIG["training_config"]
    batch_memory_mb = (training_cfg['batch_size'] * training_cfg['max_sequence_length'] * 
                      model_cfg['d_model'] * 4) / 1e6
    
    print(f"\nMemory Estimates:")
    print(f"  Model size: ~{param_size_mb:.0f} MB")
    print(f"  Batch memory: ~{batch_memory_mb:.0f} MB")
    print(f"  Total training memory: ~{param_size_mb + batch_memory_mb*3:.0f} MB (estimated)")
    
except Exception as e:
    print(f"Error creating model: {e}")
    print("This might indicate missing implementations in serialize module")
    import traceback
    traceback.print_exc()

In [6]:
# Test data loading before training
import sys
sys.path.append(str(project_root))

from arc.io.loader import load_tasks
from arc.models.train import ArcPairsDataset

print("Testing data loading...")
data_path = str(data_dir / "raw" / "arc" / "training")
print(f"Loading tasks from: {data_path}")

try:
    # Load a few tasks to test
    import glob
    json_files = glob.glob(f"{data_path}/*.json")
    print(f"Found {len(json_files)} JSON files")
    
    # Test loading just first few tasks
    test_files = json_files[:3]
    tasks = []
    for file in test_files:
        from arc.io.loader import load_task
        task = load_task(file)
        tasks.append(task)
        print(f"Loaded task from {file}: {len(task['train'])} training examples")
    
    # Test dataset creation
    print("\nTesting dataset creation...")
    ds = ArcPairsDataset(tasks, max_len=512)  # Use smaller max_len for testing
    print(f"Dataset created successfully with {len(ds)} examples")
    
    if len(ds) > 0:
        # Test getting first example
        x, y = ds[0]
        print(f"First example - Input shape: {x.shape}, Target shape: {y.shape}")
        print(f"Input range: [{x.min()}, {x.max()}], Target range: [{y.min()}, {y.max()}]")
    
except Exception as e:
    print(f"Error in data loading: {e}")
    import traceback
    traceback.print_exc()

Testing data loading...
Loading tasks from: /Users/vedanttibrewal/Documents/USC/lectures/sem_3/CSCI-544/project/ARC-AGI-2/data/raw/arc/training
Found 1000 JSON files
Loaded task from /Users/vedanttibrewal/Documents/USC/lectures/sem_3/CSCI-544/project/ARC-AGI-2/data/raw/arc/training/a85d4709.json: 4 training examples
Loaded task from /Users/vedanttibrewal/Documents/USC/lectures/sem_3/CSCI-544/project/ARC-AGI-2/data/raw/arc/training/c8cbb738.json: 3 training examples
Loaded task from /Users/vedanttibrewal/Documents/USC/lectures/sem_3/CSCI-544/project/ARC-AGI-2/data/raw/arc/training/8e1813be.json: 3 training examples

Testing dataset creation...
Dataset created successfully with 10 examples
First example - Input shape: torch.Size([106]), Target shape: torch.Size([106])
Input range: [1, 82], Target range: [2, 82]


In [None]:
# Check for existing checkpoints and resume capability
import torch
from pathlib import Path

def check_checkpoint_info(model_dir):
    """Check what checkpoints are available for resuming."""
    model_path = Path(model_dir)
    
    # Check for best.pt
    best_checkpoint = model_path / "best.pt"
    final_checkpoint = model_path / "final.pt"
    
    print("Checkpoint Status:")
    print("=" * 50)
    
    if best_checkpoint.exists():
        try:
            checkpoint = torch.load(best_checkpoint, map_location='cpu')
            print(f"✓ Found best.pt checkpoint:")
            print(f"  - Best loss: {checkpoint.get('loss', 'N/A'):.4f}")
            print(f"  - Last step: {checkpoint.get('step', 'N/A')}")
            print(f"  - Learning rate: {checkpoint.get('lr', 'N/A')}")
            
            if 'training_config' in checkpoint:
                train_cfg = checkpoint['training_config']
                print(f"  - Total steps: {train_cfg.get('steps', 'N/A')}")
                print(f"  - Batch size: {train_cfg.get('batch_size', 'N/A')}")
                steps_remaining = train_cfg.get('steps', 0) - checkpoint.get('step', 0) - 1
                print(f"  - Steps remaining: {max(0, steps_remaining)}")
            
            # Check if training was completed
            if checkpoint.get('training_completed', False):
                print("  - Status: Training completed")
            else:
                print("  - Status: Training can be resumed")
                
        except Exception as e:
            print(f"✗ Found best.pt but couldn't load: {e}")
    else:
        print("✗ No best.pt checkpoint found")
    
    if final_checkpoint.exists():
        try:
            checkpoint = torch.load(final_checkpoint, map_location='cpu')
            print(f"\n✓ Found final.pt checkpoint:")
            if checkpoint.get('training_completed', False):
                print("  - Training completed successfully")
            else:
                print("  - Training may have been interrupted")
        except Exception as e:
            print(f"\n✗ Found final.pt but couldn't load: {e}")
    
    # Check for regular checkpoints
    checkpoint_files = list(model_path.glob("ckpt_*.pt"))
    if checkpoint_files:
        print(f"\n✓ Found {len(checkpoint_files)} regular checkpoints:")
        for ckpt_file in sorted(checkpoint_files)[-3:]:  # Show last 3
            print(f"  - {ckpt_file.name}")
    
    print("=" * 50)

# Check current checkpoint status
checkpoint_dir = TRAINING_CONFIG["model_dir"]
check_checkpoint_info(checkpoint_dir)

## Checkpoint Resuming

This notebook now supports automatic resuming from the `best.pt` checkpoint file. Here's how it works:

### Automatic Resuming Features:
- **Model State**: Automatically loads the best model weights
- **Optimizer State**: Restores AdamW optimizer momentum and parameters  
- **Training Progress**: Resumes from the exact step where training stopped
- **Best Loss Tracking**: Continues tracking the best validation loss
- **Mixed Precision**: Restores AMP scaler state for consistent training

### Checkpoint Contents:
Each checkpoint now contains:
- `model`: Model state dictionary
- `cfg`: Model configuration  
- `loss`: Best loss achieved
- `step`: Current training step
- `optimizer`: Optimizer state (momentum, etc.)
- `scaler`: Mixed precision scaler state
- `training_config`: Training hyperparameters
- `model_config`: Model architecture settings

### Usage:
- Set `RESUME_TRAINING = True` to enable resuming (default)
- Set `RESUME_TRAINING = False` to start from scratch
- The system automatically detects and loads `best.pt` if available
- Use the checkpoint management utilities below for advanced operations

## Start Training

**Note:** The current implementation has placeholder functions for data loading. This will train on empty data but demonstrates the training loop. You'll need to implement proper data loading for actual training.

In [None]:
# Start training with centralized configurations
import time

print("Starting TinyLM training with centralized configuration...")
print(f"Model will be saved to: {TRAINING_CONFIG['model_dir']}")
print("="*60)

# Display final configuration summary
model_cfg = TRAINING_CONFIG["model_config"]
training_cfg = TRAINING_CONFIG["training_config"]
selected_model_size = TRAINING_CONFIG["selected_model_size"]
selected_training_profile = TRAINING_CONFIG["selected_training_profile"]

print("CONFIGURATION SUMMARY:")
print(f"  Model size: {selected_model_size}")
print(f"  Training profile: {selected_training_profile}")
print(f"  Model: {model_cfg['d_model']}D, {model_cfg['n_layers']}L, {model_cfg['n_heads']}H")
print(f"  Parameters: ~{estimate_model_parameters(model_cfg)/1e6:.1f}M")
print(f"  Training steps: {training_cfg['steps']:,}")
print(f"  Effective batch size: {training_cfg['batch_size'] * training_cfg['grad_accumulation_steps']}")
print(f"  Learning rate: {training_cfg['learning_rate']}")
print(f"  Sequence length: {training_cfg['max_sequence_length']}")
print("="*60)

start_time = time.time()

try:
    # Call the training function with dynamically selected configurations
    train(
        model_dir=TRAINING_CONFIG["model_dir"],
        data_path=TRAINING_CONFIG["data_path"],
        # Pass individual parameters
        steps=training_cfg["steps"],
        bs=training_cfg["batch_size"],
        lr=training_cfg["learning_rate"],
        d_model=model_cfg["d_model"],
        # Use the dynamically selected profiles from previous cell
        model_size=selected_model_size,
        training_profile=selected_training_profile,
        resume_from_checkpoint=True, # save checkpoints in ARC-AGI-2/models/tinylm_checkpoints/
    )
    
    end_time = time.time()
    training_time = end_time - start_time
    
    print("="*60)
    print(f"Training completed successfully!")
    print(f"  Configuration used: {selected_model_size} model with {selected_training_profile} profile")
    print(f"  Training time: {training_time/60:.1f} minutes")
    print(f"  Models saved to: {TRAINING_CONFIG['model_dir']}")
    
except Exception as e:
    print(f"Training failed with error: {e}")
    import traceback
    traceback.print_exc()

Starting TinyLM training with centralized configuration...
Model will be saved to: ./models/tinylm_checkpoints
CONFIGURATION SUMMARY:
  Model: 256D, 4L, 4H
  Parameters: ~3.7M
  Training steps: 100
  Effective batch size: 8
  Learning rate: 0.0003
  Sequence length: 2048
Training failed with error: {'steps': 100, 'batch_size': 4, 'learning_rate': 0.0003, 'max_sequence_length': 2048, 'betas': (0.9, 0.95), 'weight_decay': 0.01, 'grad_clip_norm': 1.0, 'serialization_mode': 'row', 'pad_token_id': 0, 'ignore_index': 0, 'save_every': 50, 'eval_every': 1000, 'use_amp': True, 'grad_accumulation_steps': 2} is not in list


Traceback (most recent call last):
  File "/var/folders/3w/z9ryb6gj05n_x375lp0r_81r0000gn/T/ipykernel_3819/3394006310.py", line 37, in <module>
    list(TRAINING_CONFIGS.values()).index(training_cfg)
ValueError: {'steps': 100, 'batch_size': 4, 'learning_rate': 0.0003, 'max_sequence_length': 2048, 'betas': (0.9, 0.95), 'weight_decay': 0.01, 'grad_clip_norm': 1.0, 'serialization_mode': 'row', 'pad_token_id': 0, 'ignore_index': 0, 'save_every': 50, 'eval_every': 1000, 'use_amp': True, 'grad_accumulation_steps': 2} is not in list


In [None]:
# Checkpoint Management Utilities # Hopefully not needed often
from pathlib import Path
import torch
import shutil
from datetime import datetime

def backup_checkpoint(model_dir, backup_name=None):
    """Create a backup of the current best.pt checkpoint."""
    model_path = Path(model_dir)
    best_checkpoint = model_path / "best.pt"
    
    if not best_checkpoint.exists():
        print("No best.pt checkpoint found to backup")
        return False
    
    if backup_name is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_name = f"best_backup_{timestamp}.pt"
    
    backup_path = model_path / backup_name
    shutil.copy2(best_checkpoint, backup_path)
    print(f"Checkpoint backed up to: {backup_path}")
    return True

def clear_checkpoints(model_dir, keep_best=True, confirm=True):
    """Clear all checkpoints. Optionally keep best.pt."""
    model_path = Path(model_dir)
    
    if confirm:
        response = input("Are you sure you want to clear checkpoints? (yes/no): ")
        if response.lower() != 'yes':
            print("Operation cancelled")
            return
    
    # Find all checkpoint files
    checkpoint_files = list(model_path.glob("*.pt"))
    
    cleared_count = 0
    for checkpoint_file in checkpoint_files:
        if keep_best and checkpoint_file.name == "best.pt":
            continue
        
        checkpoint_file.unlink()
        print(f"Removed: {checkpoint_file.name}")
        cleared_count += 1
    
    print(f"Cleared {cleared_count} checkpoint files")

def reset_training_from_scratch(model_dir):
    """Reset training to start completely from scratch."""
    print("Resetting training to start from scratch...")
    
    # Backup current best checkpoint if it exists
    if backup_checkpoint(model_dir):
        print("Current best checkpoint has been backed up")
    
    # Clear all checkpoints
    clear_checkpoints(model_dir, keep_best=False, confirm=False)
    
    print("Training reset complete. Next training run will start from scratch.")

# Checkpoint management options
print("Checkpoint Management Options:")
print("1. Check checkpoint status (run the cell above)")  
print("2. Backup current best.pt: backup_checkpoint(TRAINING_CONFIG['model_dir'])")  
print("3. Clear old checkpoints: clear_checkpoints(TRAINING_CONFIG['model_dir'])")
print("4. Reset to start from scratch: reset_training_from_scratch(TRAINING_CONFIG['model_dir'])")
print("\nUncomment and run the desired operation below:")

# Example usage (uncomment to use):
# backup_checkpoint(TRAINING_CONFIG['model_dir'])
# clear_checkpoints(TRAINING_CONFIG['model_dir'], keep_best=True)  
# reset_training_from_scratch(TRAINING_CONFIG['model_dir'])

In [None]:
# Check training results
import os

model_dir = Path(TRAINING_CONFIG["model_dir"])
if model_dir.exists():
    print("Training output files:")
    for file in sorted(model_dir.iterdir()):
        if file.is_file():
            size_mb = file.stat().st_size / (1024 * 1024)
            print(f"  {file.name}: {size_mb:.1f} MB")
    
    # Check if best model exists
    best_model = model_dir / "best.pt"
    if best_model.exists():
        print(f"\nBest model saved: {best_model}")
        # Load and display best model info
        try:
            import torch
            checkpoint = torch.load(best_model, map_location='cpu')
            if 'loss' in checkpoint:
                print(f"   Best loss: {checkpoint['loss']:.4f}")
            if 'cfg' in checkpoint:
                cfg = checkpoint['cfg']
                print(f"   Model config: {cfg}")
        except Exception as e:
            print(f"   Could not load model info: {e}")
else:
    print("No training output found!")

## Model Testing (Optional)

Test the trained model with a simple forward pass to ensure it's working correctly.

In [None]:
# Load and test the best model
model_path = model_dir / "best.pt"

if model_path.exists():
    print("Testing the trained model...")
    
    try:
        # Load the model
        checkpoint = torch.load(model_path, map_location='cpu')
        
        # Recreate the model using centralized config
        from arc.models.tiny_lm import TinyLM, TinyLMConfig
        from arc.utils.constants import VOCAB_SIZE
        
        # Use the same config that was used for training
        cfg = TinyLMConfig(**checkpoint['cfg'])
        model = TinyLM(cfg)
        model.load_state_dict(checkpoint['model'])
        
        # Move to GPU if available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = model.to(device)
        model.eval()
        
        print(f"Model loaded successfully on {device}")
        
        # Display model info
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        
        print(f"Model Statistics:")
        print(f"  Total parameters: {total_params:,}")
        print(f"  Trainable parameters: {trainable_params:,}")
        print(f"  Model size: ~{total_params * 4 / 1e6:.1f} MB")
        print(f"  Architecture: {cfg.d_model}D x {cfg.n_layers}L x {cfg.n_heads}H")
        
        # Test with dummy input using the trained sequence length
        batch_size = 4
        seq_len = min(64, cfg.max_len)  # Use shorter sequence for testing
        dummy_input = torch.randint(0, VOCAB_SIZE, (batch_size, seq_len)).to(device)
        
        with torch.no_grad():
            output = model(dummy_input)
            print(f"\nForward Pass Test:")
            print(f"  Input shape: {dummy_input.shape}")
            print(f"  Output shape: {output.shape}")
            print(f"  Output range: [{output.min():.3f}, {output.max():.3f}]")
            print(f"  Output mean: {output.mean():.3f}")
            print(f"  Model is working correctly!")
            
    except Exception as e:
        print(f"Model testing failed: {e}")
        import traceback
        traceback.print_exc()
else:
    print("No trained model found to test")

## Download Trained Models

Download the trained models to your local machine or save to Google Drive.

In [None]:
# Option 1: Download files directly (in Colab)
from google.colab import files
import zipfile

if model_dir.exists():
    # Create a zip file of all models
    zip_path = "tinylm_models.zip"
    
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for file in model_dir.iterdir():
            if file.is_file():
                zipf.write(file, file.name)
    
    print(f"Created zip file: {zip_path}")
    print("Downloading...")
    files.download(zip_path)
else:
    print("No models to download")

In [None]:
# Option 2: Save to Google Drive (uncomment to use)
# from google.colab import drive
# drive.mount('/content/drive')
# 
# # Copy models to Google Drive
# import shutil
# drive_path = "/content/drive/MyDrive/TinyLM_Models"
# if model_dir.exists():
#     shutil.copytree(model_dir, drive_path, dirs_exist_ok=True)
#     print(f"Models saved to Google Drive: {drive_path}")

print("Training notebook complete!")
print("\n" + "="*60)
print("TRAINING SUMMARY:")

# Display final configuration that was used
final_model_cfg = TRAINING_CONFIG["model_config"]
final_training_cfg = TRAINING_CONFIG["training_config"]

print(f"Model: {final_model_cfg['d_model']}D x {final_model_cfg['n_layers']}L x {final_model_cfg['n_heads']}H")
print(f"Parameters: ~{estimate_model_parameters(final_model_cfg)/1e6:.1f}M")
print(f"Training steps: {final_training_cfg['steps']:,}")
print(f"Batch size: {final_training_cfg['batch_size']} (effective: {final_training_cfg['batch_size'] * final_training_cfg['grad_accumulation_steps']})")
print(f"Sequence length: {final_training_cfg['max_sequence_length']}")
print(f"Learning rate: {final_training_cfg['learning_rate']}")
print(f"Models saved to: {TRAINING_CONFIG['model_dir']}")

print(f"\nGPU Optimizations Used:")
print(f"  - Gradient accumulation: {final_training_cfg['grad_accumulation_steps']} steps")
print(f"  - Mixed precision: {'Yes' if final_training_cfg['use_amp'] else 'No'}")
print(f"  - Gradient clipping: {final_training_cfg['grad_clip_norm']}")

print(f"\nNext Steps:")
print("   - Centralized configuration system implemented")
print("   - GPU-aware automatic config selection")
print("   - Implement proper data loading (arc.io.load_task)")
print("   - Implement tokenization (arc.serialize.pack_example)")
print("   - Run evaluation on the trained model")
print("   - Experiment with different hyperparameters")
print("   - Try larger models if GPU memory allows")

print("="*60)