# 🧮 AI Mathematical Olympiad - Full Training on Colab

**Comprehensive Training Notebook for Mathematical Reasoning Model**

This notebook trains a transformer model on the full MATH training split (~7,500 problems) with:
- ✅ A model size matched to the available data
- ✅ Full dataset (not just the 500-problem MATH-500 subset)
- ✅ Real-time monitoring and sample generation
- ✅ Early stopping and checkpointing
- ✅ GPU acceleration (much faster than training on CPU)

---

## 📝 Instructions

1. **Runtime Setup**: Runtime → Change runtime type → GPU (T4)
2. **Run All**: Runtime → Run all
3. **Training Time**: roughly 2-4 hours on a free Colab T4 GPU
4. **Checkpoints**: Saved to `/content/checkpoints` by default, or to Google Drive if enabled (see Step 1)

---

## Step 0: Check GPU Availability

In [None]:
import torch

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"✅ GPU Available: {gpu_name}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ No GPU detected! Go to Runtime → Change runtime type → GPU")
    print("   Training on CPU will be much slower.")

## Step 1: Mount Google Drive (Optional - for saving checkpoints)

In [None]:
from google.colab import drive
import os

# Set to True to save checkpoints to Google Drive (they persist after the Colab runtime resets)
USE_GDRIVE = False

if USE_GDRIVE:
    drive.mount('/content/drive')
    CHECKPOINT_DIR = '/content/drive/MyDrive/math_model_checkpoints'
else:
    CHECKPOINT_DIR = '/content/checkpoints'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"✅ Checkpoints will be saved to: {CHECKPOINT_DIR}")

## Step 2: Install Dependencies and Clone Repository

In [None]:
# Install dependencies
!pip install -q datasets transformers tokenizers tqdm matplotlib

# Clone the repository
!git clone https://github.com/Alpyaman/AI-Mathematical-Olympiad.git
%cd AI-Mathematical-Olympiad

print("✅ Repository cloned and dependencies installed")

## Step 3: Configuration - Optimized for Full Dataset

In [None]:
# ============================
# TRAINING CONFIGURATION
# ============================

# Model Size (choose one): "tiny" (hidden 256, 6 layers), "small" (get_small_config()),
# or "medium" (hidden 768, 12 layers). The exact parameter count is printed in Step 5.
MODEL_SIZE = "small"

# Training Hyperparameters
BATCH_SIZE = 8           # Increase if you have more GPU memory
GRAD_ACCUM_STEPS = 4     # Effective batch size = 32
LEARNING_RATE = 1e-4     # Peak learning rate, kept conservative for stability
MAX_EPOCHS = 30          # Upper bound; early stopping may end training sooner
WARMUP_STEPS = 200       # Learning-rate warmup length (optimizer steps)
MAX_LENGTH = 1024        # Longer sequences for full solutions

# Early Stopping
PATIENCE = 5             # Stop if no improvement for N epochs
MIN_DELTA = 0.01         # Minimum improvement to count

# Monitoring
LOG_EVERY = 50           # Log metrics every N steps
SAMPLE_EVERY = 200       # Generate samples every N steps
SAVE_EVERY_EPOCH = 5     # Save checkpoint every N epochs

# Dataset
USE_FULL_DATASET = True  # True = ~7.5k examples, False = 500 examples
TRAIN_SPLIT = 0.85       # 85% train, 10% val, 5% test
VAL_SPLIT = 0.10
TEST_SPLIT = 0.05

print(f"""\n{'='*60}
TRAINING CONFIGURATION
{'='*60}
Model Size:        {MODEL_SIZE}
Effective Batch:   {BATCH_SIZE * GRAD_ACCUM_STEPS}
Learning Rate:     {LEARNING_RATE}
Max Epochs:        {MAX_EPOCHS}
Sequence Length:   {MAX_LENGTH}
Full Dataset:      {USE_FULL_DATASET}
{'='*60}\n""")
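To sanity-check the model size options above, the weight count of a decoder-only transformer can be estimated from its shape. The sketch below is a rough approximation, not the repository's code: it assumes GPT-style blocks (four `hidden × hidden` attention projections plus a two-matrix MLP), ignores biases and normalization layers, and uses a hypothetical `VOCAB_SIZE`, since the tokenizer's vocabulary size is not shown in this notebook. The shapes come from `get_config` in Step 5, whose printed parameter count remains the authoritative number.

```python
def estimate_decoder_params(hidden_size, num_layers, intermediate_size,
                            vocab_size, tie_embeddings=True):
    """Rough weight count for a GPT-style decoder (biases and norms ignored)."""
    attention = 4 * hidden_size * hidden_size   # Q, K, V and output projections
    mlp = 2 * hidden_size * intermediate_size   # up- and down-projection
    blocks = num_layers * (attention + mlp)
    embeddings = vocab_size * hidden_size       # token embedding matrix
    if not tie_embeddings:
        embeddings *= 2                         # separate output (LM head) matrix
    return blocks, blocks + embeddings


# Shapes from get_config(); VOCAB_SIZE is a hypothetical placeholder
VOCAB_SIZE = 32_000
for name, (h, layers, ffn) in {"tiny": (256, 6, 1024), "medium": (768, 12, 3072)}.items():
    blocks, total = estimate_decoder_params(h, layers, ffn, VOCAB_SIZE)
    print(f"{name:>6}: ~{blocks / 1e6:.1f}M in transformer blocks, ~{total / 1e6:.1f}M total")
```

Under these assumptions the `medium` shape holds roughly 85M weights in its transformer blocks, plus `vocab_size × 768` for the embeddings. A gated MLP (for example SwiGLU) would add a third `hidden × intermediate` matrix per layer.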

## Step 4: Load and Prepare Dataset

In [None]:
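import re
from datasets import load_dataset
from tqdm import tqdm
from src.data.data_schema import MathProblem, MathSolution, ReasoningStep, DifficultyLevel, ProblemType

# MATH stores difficulty as "Level 1" ... "Level 5"; MATH-500 stores a plain integer
LEVEL_MAP = {
    1: DifficultyLevel.EASY,
    2: DifficultyLevel.MEDIUM,
    3: DifficultyLevel.MEDIUM,
    4: DifficultyLevel.HARD,
    5: DifficultyLevel.OLYMPIAD,
}

# Subject names normalized to snake_case, e.g. "Counting & Probability" -> "counting_and_probability"
TYPE_MAP = {
    'algebra': ProblemType.ALGEBRA,
    'counting_and_probability': ProblemType.COMBINATORICS,
    'geometry': ProblemType.GEOMETRY,
    'intermediate_algebra': ProblemType.ALGEBRA,
    'number_theory': ProblemType.NUMBER_THEORY,
    'prealgebra': ProblemType.ALGEBRA,
    'precalculus': ProblemType.ALGEBRA,
}

def parse_level(raw):
    """Return the integer level from 3, "3" or "Level 3" (default: 2)."""
    match = re.search(r"\d+", str(raw))
    return int(match.group()) if match else 2

def normalize_subject(raw):
    """Map "Intermediate Algebra" or "intermediate_algebra" to "intermediate_algebra"."""
    return str(raw).strip().lower().replace("&", "and").replace(" ", "_")

def extract_boxed_answer(solution):
    """Return the contents of the last \\boxed{...} in a solution, matching nested braces."""
    start = solution.rfind("\\boxed{")
    if start == -1:
        return None
    begin = start + len("\\boxed{")
    depth = 1
    for pos in range(begin, len(solution)):
        if solution[pos] == "{":
            depth += 1
        elif solution[pos] == "}":
            depth -= 1
            if depth == 0:
                return solution[begin:pos].strip()
    return None

def convert_hf_to_schema(hf_dataset):
    """Convert a HuggingFace MATH-style dataset to our schema"""
    problems = []
    print("🔄 Converting dataset...")

    for i, item in enumerate(tqdm(hf_dataset)):
        difficulty = LEVEL_MAP.get(parse_level(item.get('level', 2)), DifficultyLevel.MEDIUM)

        # The full MATH dataset uses a 'type' column; MATH-500 uses 'subject'
        subject = normalize_subject(item.get('type') or item.get('subject') or 'algebra')
        prob_type = TYPE_MAP.get(subject, ProblemType.ALGEBRA)

        # Some MATH versions have no 'answer' column: fall back to the last \boxed{...}
        answer = item.get('answer') or extract_boxed_answer(item['solution']) or ""

        # Create solution (wrap in single step for now)
        sol = MathSolution(
            steps=[ReasoningStep(1, "Solution", item['solution'], None)],
            final_answer=answer,
            answer_type="exact",
            verification=None
        )

        prob = MathProblem(
            problem_id=f"MATH_{i}",
            problem_statement=item['problem'],
            solution=sol,
            difficulty=difficulty,
            problem_type=prob_type,
            topics=[subject],
            source="MATH",
            year=2024  # placeholder: MATH items carry no per-problem year
        )
        problems.append(prob)

    return problems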

# Load dataset
print(f"\n📚 Loading {'FULL' if USE_FULL_DATASET else 'MATH-500'} dataset...")

if USE_FULL_DATASET:
    # Load the full MATH training split (~7,500 problems)
    try:
        dataset_hf = load_dataset("lighteval/MATH", split="train")
    except Exception as err:
        print(f"   Could not load lighteval/MATH ({err}); trying alternative dataset...")
        dataset_hf = load_dataset("hendrycks/math", "all", split="train")
else:
    # Load MATH-500 (small subset for quick testing)
    dataset_hf = load_dataset("HuggingFaceH4/MATH-500", split="test")

print(f"   Loaded {len(dataset_hf)} problems")

# Convert to our schema
problems = convert_hf_to_schema(dataset_hf)

# Split dataset
from src.data.dataset import split_dataset
train_probs, val_probs, test_probs = split_dataset(
    problems, 
    TRAIN_SPLIT, 
    VAL_SPLIT, 
    TEST_SPLIT
)

print(f"\n✅ Dataset prepared:")
print(f"   Train:      {len(train_probs)} problems")
print(f"   Validation: {len(val_probs)} problems")
print(f"   Test:       {len(test_probs)} problems")
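`split_dataset` comes from the repository and its exact behavior is not shown here. For reproducible experiments the split should come from a seeded shuffle. A minimal stand-alone sketch of an 85/10/5 split, assuming floor rounding for train and validation with the remainder going to test (`seeded_split` and its `seed` parameter are hypothetical, not the repo's API):

```python
import random


def seeded_split(items, train_frac=0.85, val_frac=0.10, seed=42):
    """Shuffle a copy of `items` with a fixed seed and cut it into train/val/test."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)   # local RNG: global random state is untouched
    n_train = int(len(shuffled) * train_frac)
    n_val = int(len(shuffled) * val_frac)
    train_part = shuffled[:n_train]
    val_part = shuffled[n_train:n_train + n_val]
    test_part = shuffled[n_train + n_val:]  # remainder (~5%) goes to test
    return train_part, val_part, test_part


train_part, val_part, test_part = seeded_split(range(7500))
print(len(train_part), len(val_part), len(test_part))
```

Because the RNG is seeded locally, rerunning the notebook yields the same problems in each split, which keeps validation losses comparable across runs.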

## Step 5: Initialize Model and Tokenizer

In [None]:
from src.config.model_config import MathTransformerConfig, get_small_config
from src.model.decoder import MathTransformerDecoder
from src.tokenizer.math_tokenizer import MathTokenizer
from src.data.dataset import MathReasoningDataset, create_dataloaders

# Model configuration
def get_config(size="small"):
    """Get model configuration based on size"""
    if size == "tiny":
        return MathTransformerConfig(
            hidden_size=256,
            num_hidden_layers=6,
            num_attention_heads=8,
            intermediate_size=1024,
            max_position_embeddings=1024,
            max_sequence_length=1024,
            hidden_dropout=0.2,
            attention_dropout=0.1,
        )
    elif size == "small":
        return get_small_config()
    elif size == "medium":
        return MathTransformerConfig(
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            max_position_embeddings=2048,
            max_sequence_length=2048,
        )
    else:
        raise ValueError(f"Unknown model size: {size}")

# Initialize
print("\n🔧 Initializing model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
config = get_config(MODEL_SIZE)
model = MathTransformerDecoder(config).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n✅ Model initialized:")
print(f"   Size:              {MODEL_SIZE}")
print(f"   Total parameters:  {total_params:,}")
print(f"   Trainable params:  {trainable_params:,}")
print(f"   Device:            {device}")
print(f"   Hidden size:       {config.hidden_size}")
print(f"   Layers:            {config.num_hidden_layers}")
print(f"   Attention heads:   {config.num_attention_heads}")

# Data-to-parameter ratio (rough heuristic: the thresholds below are this notebook's rule of thumb)
ratio = len(train_probs) / (total_params / 1e6)
print(f"\n📊 Data-to-parameter ratio: {ratio:.1f} examples per million parameters")
if ratio < 10:
    print("   ⚠️ WARNING: Low ratio. Consider using a smaller model or more data.")
elif ratio < 50:
    print("   ℹ️ Acceptable ratio, but more data would help.")
else:
    print("   ✅ Ratio is above this notebook's rule-of-thumb threshold.")

# Initialize tokenizer and datasets
print("\n🔤 Initializing tokenizer and datasets...")
tokenizer = MathTokenizer()

train_ds = MathReasoningDataset(train_probs, tokenizer, max_length=MAX_LENGTH)
val_ds = MathReasoningDataset(val_probs, tokenizer, max_length=MAX_LENGTH)

train_loader, val_loader = create_dataloaders(
    train_ds, 
    val_ds, 
    batch_size=BATCH_SIZE,
    num_workers=2
)

print(f"✅ Dataloaders created:")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches:   {len(val_loader)}")
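The ratio above counts problems, while data budgets for language models are usually stated in tokens. For reference, the Chinchilla scaling study (Hoffmann et al., 2022) found compute-optimal training at roughly 20 tokens per parameter. The sketch below converts the notebook's quantities into tokens per parameter; the average token count and parameter count are hypothetical placeholders, so substitute the measured average sequence length and `total_params`.

```python
def tokens_per_parameter(num_examples, avg_tokens_per_example, num_params, epochs=1):
    """Training tokens seen per model parameter over `epochs` passes."""
    return num_examples * avg_tokens_per_example * epochs / num_params


# Hypothetical numbers: 6,375 training problems averaging 400 tokens, 85M parameters
per_epoch = tokens_per_parameter(6_375, 400, 85_000_000)
print(f"{per_epoch:.3f} tokens/param per epoch")

# Repeated epochs add exposures, not new data: 30 epochs still cover 6,375 distinct problems
over_run = tokens_per_parameter(6_375, 400, 85_000_000, epochs=30)
print(f"{over_run:.2f} tokens/param over 30 epochs")
```

Under these assumptions both figures sit far below ~20 tokens per parameter, which is why early stopping and the smaller `tiny` configuration matter when training from scratch on MATH.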

## Step 6: Training Setup (Optimizer, Scheduler, Early Stopping)

In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
import numpy as np

# Optimizer
optimizer = AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.01,
    betas=(0.9, 0.95)
)

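# Learning rate scheduler: warm up for WARMUP_STEPS optimizer steps, then cosine decay
total_steps = len(train_loader) * MAX_EPOCHS // GRAD_ACCUM_STEPS
warmup_pct = min(WARMUP_STEPS / total_steps, 0.5)  # OneCycleLR takes warmup as a fraction of total_steps
scheduler = OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    total_steps=total_steps,
    pct_start=warmup_pct,
    anneal_strategy='cos',
    div_factor=25.0,          # initial lr = max_lr / 25
    final_div_factor=1000.0,  # final lr = initial lr / 1000
    cycle_momentum=False,     # keep the AdamW betas fixed at (0.9, 0.95)
)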

# Early stopping tracker
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False
        
    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return self.should_stop

early_stopping = EarlyStopping(patience=PATIENCE, min_delta=MIN_DELTA)

# Training metrics tracker
history = {
    'train_loss': [],
    'val_loss': [],
    'learning_rate': [],
    'epoch': []
}

print("✅ Training setup complete:")
print(f"   Optimizer:     AdamW (lr={LEARNING_RATE})")
print(f"   Scheduler:     OneCycleLR with warmup")
print(f"   Total steps:   {total_steps:,}")
print(f"   Early stopping patience: {PATIENCE} epochs")
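To see what the learning rate does over a run without training anything, the sketch below re-implements the two-phase cosine schedule described in the PyTorch `OneCycleLR` documentation. The LR starts at `max_lr / div_factor`, rises to `max_lr` over the first `pct_start` fraction of steps, then anneals to `initial_lr / final_div_factor`. It is a stand-alone approximation for inspection, not a replacement for the scheduler. `one_cycle_lr` is a hypothetical helper, and `TOTAL_STEPS = 6000` with `pct_start=0.05` are illustrative values.

```python
import math


def _cos_anneal(start, end, pct):
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)


def one_cycle_lr(step, total_steps, max_lr, pct_start,
                 div_factor=25.0, final_div_factor=1000.0):
    """LR at 0-based optimizer `step` for a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor          # LR at step 0
    min_lr = initial_lr / final_div_factor    # LR at the last step
    warmup_end = pct_start * total_steps - 1  # last step of the warmup phase
    if step <= warmup_end:
        return _cos_anneal(initial_lr, max_lr, step / warmup_end)
    decay_pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    return _cos_anneal(max_lr, min_lr, decay_pct)


TOTAL_STEPS = 6000  # illustrative value, not this run's total_steps
for s in (0, 150, 299, 3000, TOTAL_STEPS - 1):
    print(f"step {s:>5}: lr = {one_cycle_lr(s, TOTAL_STEPS, 1e-4, 0.05):.2e}")
```

With `cycle_momentum` left at its default, PyTorch's `OneCycleLR` also cycles Adam's `beta1` between 0.85 and 0.95; that part is not modeled here.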

## Step 7: Sample Generation Function (Monitor Training Progress)

In [None]:
def generate_sample(model, tokenizer, prompt, max_length=200):
    """Generate a sample response to monitor training progress"""
    model.eval()
    with torch.no_grad():
        # Encode prompt
        encoded = tokenizer.encode(prompt, add_special_tokens=False)
        input_ids = torch.tensor([encoded['input_ids']]).to(device)
        
        # Generate
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_length,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id
        )
        
        # Decode
        output_text = tokenizer.decode(output_ids[0].cpu().tolist())
    
    model.train()
    return output_text

# Test problems for monitoring
TEST_PROMPTS = [
    "Problem: Solve for x: 2x + 5 = 13\n\nSolution:",
    "Problem: What is 7 × 8?\n\nSolution:",
    "Problem: If f(x) = 3x - 2, what is f(4)?\n\nSolution:"
]

print("✅ Sample generation function ready")
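`generate_sample` relies on the repository's `model.generate`, whose internals are not shown here. The sketch below illustrates what `temperature`, `top_k`, and `top_p` conventionally do to one step's logits, assuming they are applied in that order (a common convention). `filter_candidates` and `sample_next_token` are hypothetical helpers that work on plain Python lists rather than tensors.

```python
import math
import random


def filter_candidates(logits, temperature=0.8, top_k=50, top_p=0.9):
    """Apply temperature, top-k, then top-p (nucleus) filtering; return (token_ids, probs)."""
    scaled = [x / temperature for x in logits]   # <1 sharpens, >1 flattens
    # Top-k: keep the k highest-scoring token ids (top_k must be >= 1)
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    # Softmax over the survivors (subtract the max for numerical stability)
    peak = scaled[ranked[0]]
    weights = [math.exp(scaled[i] - peak) for i in ranked]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Top-p: smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for token_id, p in zip(ranked, probs):
        kept.append((token_id, p))
        cumulative += p
        if cumulative >= top_p:
            break
    norm = sum(p for _, p in kept)
    return [t for t, _ in kept], [p / norm for _, p in kept]


def sample_next_token(logits, rng=random, **kwargs):
    """Draw one token id from the filtered distribution."""
    ids, probs = filter_candidates(logits, **kwargs)
    return rng.choices(ids, weights=probs, k=1)[0]


print(filter_candidates([2.0, 1.0, 0.5, -1.0], temperature=1.0, top_k=3, top_p=0.8)[0])
```

Lower `temperature` or smaller `top_k`/`top_p` make samples more deterministic, which is often easier to read when monitoring early training.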

## Step 8: Main Training Loop

In [None]:
import time
from datetime import datetime

print(f"\n{'='*70}")
print(f"🚀 STARTING TRAINING")
print(f"{'='*70}")
print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Device: {device}")
print(f"{'='*70}\n")

best_val_loss = float('inf')
global_step = 0
start_time = time.time()

for epoch in range(MAX_EPOCHS):
    epoch_start = time.time()
    model.train()
    optimizer.zero_grad()  # drop gradients left over from an incomplete accumulation window
    total_loss = 0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{MAX_EPOCHS}")
    
    for step, batch in enumerate(pbar):
        # Move to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward pass
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs['loss'] / GRAD_ACCUM_STEPS
        
        # Backward pass
        loss.backward()
        total_loss += loss.item() * GRAD_ACCUM_STEPS
        
        # Update weights
        if (step + 1) % GRAD_ACCUM_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1
            
            # Update progress bar
            current_lr = scheduler.get_last_lr()[0]
            pbar.set_postfix({
                'loss': f"{loss.item() * GRAD_ACCUM_STEPS:.4f}",
                'lr': f"{current_lr:.2e}"
            })
            
            # Generate samples periodically
            if global_step % SAMPLE_EVERY == 0:
                print(f"\n\n{'='*70}")
                print(f"📝 SAMPLE GENERATION (Step {global_step})")
                print(f"{'='*70}")
                for i, prompt in enumerate(TEST_PROMPTS[:2]):
                    print(f"\nTest {i+1}: {prompt[:50]}...")
                    print("-" * 70)
                    sample = generate_sample(model, tokenizer, prompt, max_length=150)
                    print(sample[:300])
                    print("-" * 70)
                print()
    
    # Validation
    model.eval()
    val_loss = 0
    val_steps = 0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            val_loss += outputs['loss'].item()
            val_steps += 1
    
    # Calculate metrics
    avg_train_loss = total_loss / len(train_loader)
    avg_val_loss = val_loss / val_steps
    epoch_time = time.time() - epoch_start
    
    # Update history
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['learning_rate'].append(scheduler.get_last_lr()[0])
    history['epoch'].append(epoch + 1)
    
    # Print epoch summary
    print(f"\n{'='*70}")
    print(f"EPOCH {epoch+1}/{MAX_EPOCHS} COMPLETE")
    print(f"{'='*70}")
    print(f"Train Loss:      {avg_train_loss:.4f}")
    print(f"Val Loss:        {avg_val_loss:.4f}")
    print(f"Learning Rate:   {scheduler.get_last_lr()[0]:.2e}")
    print(f"Epoch Time:      {epoch_time/60:.2f} minutes")
    print(f"Total Time:      {(time.time()-start_time)/60:.2f} minutes")
    
    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_loss,
            'config': config,
        }, f"{CHECKPOINT_DIR}/best_model.pt")
        print(f"✅ New best model saved! (val_loss: {avg_val_loss:.4f})")
    
    # Periodic checkpoint
    if (epoch + 1) % SAVE_EVERY_EPOCH == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_loss,
            'config': config,
        }, f"{CHECKPOINT_DIR}/checkpoint_epoch_{epoch+1}.pt")
        print(f"💾 Checkpoint saved (epoch {epoch+1})")
    
    print(f"{'='*70}\n")
    
    # Early stopping check
    if early_stopping(avg_val_loss):
        print(f"\n⚠️ Early stopping triggered after {epoch+1} epochs")
        print(f"   No improvement for {PATIENCE} epochs")
        print(f"   Best val loss: {best_val_loss:.4f}")
        break

# Training complete
total_time = time.time() - start_time
print(f"\n{'='*70}")
print(f"🏁 TRAINING COMPLETE!")
print(f"{'='*70}")
print(f"Total epochs:     {epoch+1}")
print(f"Total time:       {total_time/3600:.2f} hours")
print(f"Best val loss:    {best_val_loss:.4f}")
print(f"Final train loss: {avg_train_loss:.4f}")
print(f"{'='*70}\n")
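The loop above calls `optimizer.step()` only when `(step + 1) % GRAD_ACCUM_STEPS == 0`, so if the number of batches per epoch is not a multiple of `GRAD_ACCUM_STEPS`, the final few batches of an epoch do not complete an accumulation window within that epoch. The pure-Python sketch below (the helper `accumulation_boundaries` is hypothetical) lists which batch indices trigger an update, with and without flushing the incomplete final window. Flushing gives `ceil(n / k)` updates per epoch instead of `n // k`, so the scheduler's `total_steps` would have to be computed the same way.

```python
import math


def accumulation_boundaries(num_batches, accum_steps, flush_last=False):
    """Return the 0-based batch indices after which an optimizer step happens."""
    steps = [i for i in range(num_batches) if (i + 1) % accum_steps == 0]
    if flush_last and num_batches % accum_steps != 0:
        steps.append(num_batches - 1)  # also update on the incomplete final window
    return steps


n, k = 10, 4
print(accumulation_boundaries(n, k))                   # [3, 7]
print(accumulation_boundaries(n, k, flush_last=True))  # [3, 7, 9]
print(n // k, math.ceil(n / k))                        # 2 3
```

If you flush, each batch loss is still divided by `GRAD_ACCUM_STEPS`, so the flushed update is scaled down in proportion to how many batches it actually contains.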

## Step 9: Visualize Training Progress

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curves
axes[0].plot(history['epoch'], history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['epoch'], history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Learning rate
axes[1].plot(history['epoch'], history['learning_rate'], marker='o', color='orange')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Learning Rate')
axes[1].set_title('Learning Rate Schedule')
axes[1].set_yscale('log')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{CHECKPOINT_DIR}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Training curves saved to {CHECKPOINT_DIR}/training_curves.png")

## Step 10: Load Best Model and Test

In [None]:
# Load best checkpoint
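print("\n📂 Loading best model...")
# weights_only=False: the checkpoint also pickles the config object (only load files you trust)
checkpoint = torch.load(f"{CHECKPOINT_DIR}/best_model.pt", map_location=device, weights_only=False)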
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✅ Best model loaded (val_loss: {checkpoint['val_loss']:.4f})")
print(f"   Saved after epoch {checkpoint['epoch']+1}")

# Test on various problems
test_problems = [
    "Problem: Solve for x: 3x + 7 = 22\n\nSolution:",
    "Problem: What is the square root of 144?\n\nSolution:",
    "Problem: If f(x) = 2x + 3, find f(5)\n\nSolution:",
    "Problem: Calculate 15 × 12\n\nSolution:",
    "Problem: Find the area of a circle with radius 5\n\nSolution:"
]

print(f"\n{'='*70}")
print("🧪 FINAL MODEL TESTING")
print(f"{'='*70}\n")

for i, problem in enumerate(test_problems):
    print(f"\n{'─'*70}")
    print(f"Test {i+1}: {problem.split('Solution:')[0].strip()}")
    print(f"{'─'*70}")
    
    output = generate_sample(model, tokenizer, problem, max_length=256)
    
    # Extract just the solution part
    if "Solution:" in output:
        solution = output.split("Solution:")[1].strip()
        print(solution[:400])  # Print first 400 chars
    else:
        print(output[:400])
    print()

print(f"{'='*70}")
print("✅ Testing complete!")
print(f"{'='*70}")
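Reading sample outputs works for spot checks, but a single number is easier to track across runs. A minimal sketch of exact-match accuracy, assuming final answers have already been extracted as strings: `normalize_answer` strips whitespace, dollar signs, and a trailing period before comparing. It is deliberately simple (it treats `0.5` and `\frac{1}{2}` as different answers). The references are the correct answers to the five `test_problems` above; the predictions are made-up examples, not outputs of this notebook.

```python
def normalize_answer(answer):
    """Canonicalize a short answer string for exact-match comparison."""
    text = answer.strip().replace("$", "").replace(" ", "")
    return text.rstrip(".")


def exact_match_accuracy(predictions, references):
    """Fraction of predictions that match their reference after normalization."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    hits = sum(
        pred is not None and normalize_answer(pred) == normalize_answer(ref)
        for pred, ref in zip(predictions, references)
    )
    return hits / len(references)


# Reference answers for the five test problems; predictions are hypothetical
references = ["5", "12", "13", "180", "25\\pi"]
predictions = ["5", " 12.", None, "$180$", "25 \\pi"]
print(exact_match_accuracy(predictions, references))
```

The same scoring applied to the held-out `test_probs` gives a more meaningful number than a handful of hand-written prompts.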

## Step 11: Interactive Demo - Try Your Own Problems!

In [None]:
def solve_math_problem(problem_text):
    """Solve a math problem using the trained model"""
    prompt = f"Problem: {problem_text}\n\nSolution:"
    output = generate_sample(model, tokenizer, prompt, max_length=400)
    
    # Extract solution
    if "Solution:" in output:
        solution = output.split("Solution:")[1]
        # Try to extract answer
        if "<answer>" in solution and "</answer>" in solution:
            answer = solution.split("<answer>")[1].split("</answer>")[0].strip()
            return solution, answer
        return solution, None
    return output, None

print("\n🎯 Interactive Math Problem Solver")
print("="*70)
print("Edit the 'custom_problem' variable below to try your own problem (use an empty string to skip):\n")

# Example usage (you can modify this)
custom_problem = "Find the value of x if 5x - 8 = 17"

if custom_problem:
    print(f"Problem: {custom_problem}\n")
    solution, answer = solve_math_problem(custom_problem)
    print("Solution:")
    print("-" * 70)
    print(solution[:500])
    if answer:
        print(f"\nFinal Answer: {answer}")
    print("-" * 70)
else:
    print("(No custom problem provided, skipping interactive demo)")

print("\n✅ Demo complete! Modify the 'custom_problem' variable above to try different problems.")

## Step 12: Save Final Model Metadata

In [None]:
import json

# Create metadata
metadata = {
    "model_size": MODEL_SIZE,
    "total_parameters": total_params,
    "training_examples": len(train_probs),
    "validation_examples": len(val_probs),
    "best_val_loss": best_val_loss,
    "final_epoch": epoch + 1,
    "total_training_time_hours": total_time / 3600,
    "hyperparameters": {
        "batch_size": BATCH_SIZE,
        "grad_accum_steps": GRAD_ACCUM_STEPS,
        "learning_rate": LEARNING_RATE,
        "max_length": MAX_LENGTH,
        "warmup_steps": WARMUP_STEPS,
    },
    "model_config": {
        "hidden_size": config.hidden_size,
        "num_layers": config.num_hidden_layers,
        "num_heads": config.num_attention_heads,
        "intermediate_size": config.intermediate_size,
    },
    "dataset": "MATH" if USE_FULL_DATASET else "MATH-500",
    "training_date": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
}

# Save metadata
with open(f"{CHECKPOINT_DIR}/model_metadata.json", 'w') as f:
    json.dump(metadata, f, indent=2)

print("\n📊 Model Metadata:")
print(json.dumps(metadata, indent=2))
print(f"\n✅ Metadata saved to {CHECKPOINT_DIR}/model_metadata.json")

## 🎉 Training Complete!

### What's Saved:
- ✅ `best_model.pt` - Checkpoint with the lowest validation loss
- ✅ `checkpoint_epoch_*.pt` - Periodic checkpoints (every `SAVE_EVERY_EPOCH` epochs)
- ✅ `training_curves.png` - Loss and LR visualization
- ✅ `model_metadata.json` - Model size, hyperparameters, and training summary

### Next Steps:
1. **Download the model**: From the Files panel (left sidebar)
2. **Use locally**: Load the checkpoint in your local environment
3. **Evaluate**: Test on the MATH test set for proper evaluation
4. **Fine-tune**: Continue training with more data or adjust hyperparameters

### Expected Results:
- With **~7,500 examples**: expect only basic mathematical reasoning from a model trained from scratch
- **Validation loss < 1.0**: a rough sign of learning (the exact value depends on the tokenizer)
- **Coherent outputs**: should generate well-formed, step-by-step mathematical text
- **Simple problems**: should solve some basic algebra correctly
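The validation-loss threshold is easier to interpret as perplexity. Assuming the reported loss is the mean per-token cross-entropy in nats (what PyTorch's `cross_entropy` returns by default), perplexity is its exponential:

$$
\mathrm{PPL} = \exp\left(\mathcal{L}_{\text{val}}\right), \qquad \mathcal{L}_{\text{val}} < 1.0 \;\Longrightarrow\; \mathrm{PPL} < e \approx 2.72
$$

So a loss below 1.0 means the model is, on average, about as uncertain as a choice among fewer than three equally likely tokens. Because perplexity is measured per token, values are only comparable between runs that use the same tokenizer.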

### If Results are Poor:
- ✅ Try training for more epochs
- ✅ Use a smaller model ("tiny" size)
- ✅ Lower the learning rate (e.g. 5e-5)
- ✅ Add data augmentation

---

**Need help?** Check the repository issues or documentation!