# ATLAS Training V2 - Enhanced Spatial Transformation Expert

This notebook provides an enhanced training pipeline for ATLAS with improved learning dynamics and stability.

## 1. Setup and Installation

In [None]:
# Clone repository if needed
import os
if not os.path.exists('/content/AutomataNexus_Olympus_AGI2'):
    !git clone https://github.com/AutomataControls/AutomataNexus_Olympus_AGI2.git /content/AutomataNexus_Olympus_AGI2
else:
    print("Repository already exists")

In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio
!pip install -q tqdm matplotlib numpy scikit-learn
!pip install -q einops

## 2. Data Download

In [None]:
# Download ARC dataset
import os
data_dir = '/content/AutomataNexus_Olympus_AGI2/data'
os.makedirs(data_dir, exist_ok=True)

# Download training data
if not os.path.exists(f'{data_dir}/arc-agi_training_challenges.json'):
    !wget -q https://raw.githubusercontent.com/fchollet/ARC-AGI/master/data/training/arc-agi_training_challenges.json -P {data_dir}/
    print("✅ Downloaded training challenges")
else:
    print("Training data already exists")

# Download evaluation data
if not os.path.exists(f'{data_dir}/arc-agi_evaluation_challenges.json'):
    !wget -q https://raw.githubusercontent.com/fchollet/ARC-AGI/master/data/evaluation/arc-agi_evaluation_challenges.json -P {data_dir}/
    print("✅ Downloaded evaluation challenges")
else:
    print("Evaluation data already exists")

## 3. Run ATLAS V2 Training

In [None]:
# Change to project directory
# The training cells in this section invoke the script through a path that is
# relative to this directory, so this cell must run before them.
%cd /content/AutomataNexus_Olympus_AGI2

In [None]:
# Report the available accelerator and start from a clean memory state
import gc
import torch

# Describe the GPU (name and total memory in GB) or warn when none is present
if not torch.cuda.is_available():
    print("⚠️ No GPU available, training will be slow!")
else:
    print(f"🔧 GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Drop unreachable Python objects first, then release cached GPU blocks
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Option 1: Quick test run (2 stages, 10 epochs each)
# Launches the training script with --test_mode. The script path is relative,
# so the working directory must already be the project root (see the %cd cell).
# NOTE(review): Options 1 and 2 look like alternatives, but "Run All" executes
# both cells in sequence — run only the one you need.
!python scripts/training/train_atlas_specialized2.py --test_mode

In [None]:
# Option 2: Full training (all 8 stages)
# Launches the same training script without --test_mode. The script path is
# relative, so the working directory must already be the project root.
# NOTE(review): presumably this writes the files under results/atlas_v2 that
# the later cells read — confirm against the training script.
!python scripts/training/train_atlas_specialized2.py

## 4. Monitor Training Progress

In [None]:
# Load and display training history
# `os` is imported here as well so the cell can be run on its own (for example
# after a runtime restart, while only monitoring an existing checkpoint).
import os

import matplotlib.pyplot as plt
import torch

# Load checkpoint
checkpoint_path = '/content/AutomataNexus_Olympus_AGI2/results/atlas_v2/atlas_v2_checkpoint.pt'
if os.path.exists(checkpoint_path):
    # map_location='cpu' lets the checkpoint be inspected without a GPU
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    history = checkpoint['history']

    # One row of three panels: loss, exact-match rate, learning rate
    fig, (loss_ax, exact_ax, lr_ax) = plt.subplots(1, 3, figsize=(15, 5))

    loss_ax.plot(history['train_loss'], label='Train Loss', alpha=0.7)
    loss_ax.plot(history['val_loss'], label='Val Loss', alpha=0.7)
    loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Loss History')
    loss_ax.legend()
    loss_ax.grid(True)

    exact_ax.plot(history['train_exact'], label='Train Exact', alpha=0.7)
    exact_ax.plot(history['val_exact'], label='Val Exact', alpha=0.7)
    exact_ax.set(xlabel='Epoch', ylabel='Exact Match %', title='Exact Match History')
    exact_ax.legend()
    exact_ax.grid(True)

    lr_ax.plot(history['learning_rates'], alpha=0.7)
    lr_ax.set(xlabel='Epoch', ylabel='Learning Rate', title='Learning Rate Schedule')
    # Log scale keeps small learning rates readable next to large ones
    lr_ax.set_yscale('log')
    lr_ax.grid(True)

    fig.tight_layout()
    plt.show()

    print("\n📊 Current Status:")
    print(f"   Epoch: {checkpoint['epoch']}")
    print(f"   Stage: {checkpoint['stage']}")
    # 'best_exact' may be absent from the checkpoint; fall back to 0
    print(f"   Best Exact Match: {checkpoint.get('best_exact', 0):.2f}%")
else:
    print("No checkpoint found yet")

## 5. Test the Model

In [None]:
# Test on sample patterns
# `os` and `torch` are imported here as well so the cell can be run on its own
# (for example after a runtime restart, with an already-trained model on disk).
import os
import sys

import torch
import torch.nn.functional as F

# Make the project sources importable. The membership check keeps sys.path
# from accumulating duplicate entries when the cell is re-run.
src_dir = '/content/AutomataNexus_Olympus_AGI2/src'
if src_dir not in sys.path:
    sys.path.append(src_dir)
from models.atlas_model import EnhancedAtlasNet

# Load best model
best_model_path = '/content/AutomataNexus_Olympus_AGI2/results/atlas_v2/atlas_v2_best.pt'
if os.path.exists(best_model_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = EnhancedAtlasNet(max_grid_size=30, hidden_dim=256).to(device)

    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Test on a simple rotation pattern; unsqueeze(0) adds the batch dimension
    test_input = torch.tensor([
        [1, 0, 0],
        [0, 2, 0],
        [0, 0, 3]
    ]).unsqueeze(0).to(device)

    # Expected output (90-degree rotation)
    expected = torch.tensor([
        [0, 0, 1],
        [0, 2, 0],
        [3, 0, 0]
    ])

    # Convert to one-hot: (1, 3, 3) -> (1, 3, 3, 10) -> (1, 10, 3, 3)
    input_oh = F.one_hot(test_input, num_classes=10).permute(0, 3, 1, 2).float()

    # Predict
    # NOTE(review): assumes the model returns a dict whose 'predicted_output'
    # holds per-class scores along dim 1 — confirm against EnhancedAtlasNet.
    with torch.no_grad():
        output = model(input_oh, mode='inference')
        pred = output['predicted_output'].argmax(dim=1).squeeze()

    print("Input:")
    print(test_input.squeeze().cpu().numpy())
    print("\nPredicted:")
    print(pred.cpu().numpy())
    print("\nExpected:")
    print(expected.numpy())

    # Check if prediction matches expected.
    # torch.equal also requires identical shapes, so a prediction with a
    # different grid size is reported as incorrect.
    if torch.equal(pred.cpu(), expected):
        print("\n✅ Correct prediction!")
    else:
        print("\n❌ Incorrect prediction")
else:
    print("No trained model found yet")

## 6. Download Trained Model

In [None]:
# Zip and download results
import os
from datetime import datetime

results_dir = '/content/AutomataNexus_Olympus_AGI2/results/atlas_v2'
if os.path.exists(results_dir):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    zip_name = f'atlas_v2_results_{timestamp}.zip'
    
    !cd /content/AutomataNexus_Olympus_AGI2 && zip -r {zip_name} results/atlas_v2/
    
    print(f"✅ Created {zip_name}")
    print("Download using the file browser or:")
    print(f"from google.colab import files")
    print(f"files.download('/content/AutomataNexus_Olympus_AGI2/{zip_name}')")
else:
    print("No results to download yet")