# ATLAS Math Reasoning Demo

This notebook demonstrates how the ATLAS two-pass inference protocol improves accuracy on math problems.

## Overview

ATLAS uses a teacher model to guide a student model. The teacher makes two passes (diagnosis, then guidance), after which the student answers:
1. **Diagnostic Probing**: Teacher assesses student capability
2. **Adaptive Learning**: Teacher provides targeted guidance based on assessment
3. **Enhanced Response**: Student generates improved solution using guidance

The system is designed to improve performance without degrading it: harmful interventions receive zero reward.

## Setup and Installation

In [None]:
import os
import sys

# Environment bootstrap: on Google Colab, fetch the ATLAS repository and
# install dependencies; when running locally, only check that the `utils`
# helper package is reachable from the current directory.
if 'google.colab' in sys.modules:
    # Clone only once so that re-running the cell does not fail on an
    # already-existing target directory.
    if not os.path.exists('/content/ATLAS'):
        !git clone https://github.com/Arc-Computer/ATLAS.git /content/ATLAS
    
    # The `utils` package imported in the next cell lives under examples/.
    os.chdir('/content/ATLAS/examples')
    
    # Make examples/ importable regardless of the kernel's starting directory.
    if '/content/ATLAS/examples' not in sys.path:
        sys.path.append('/content/ATLAS/examples')
    
    # IPython shell escape (`!`); `-q` keeps pip's output short.
    !pip install -q transformers torch accelerate datasets matplotlib pandas numpy
    print("✓ Repository cloned and packages installed for Google Colab")
    print(f"✓ Working directory: {os.getcwd()}")
else:
    # Local run: `utils` must be importable from the current working directory.
    if not os.path.exists('utils'):
        print("⚠️  Warning: 'utils' directory not found. Please run from the examples/ directory")
    print("Using local environment")

In [None]:
import torch
# NOTE(review): AutoTokenizer, AutoModelForCausalLM, display and HTML are not
# referenced by the cells in this notebook; kept for interactive use.
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML
import warnings
import json
import random
import re
from typing import List, Dict, Any, Optional
# Silence all Python warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Project helpers from examples/utils. They require the setup cell to have run
# (Colab) or the kernel to be started from examples/ (local).
# NOTE(review): ATLASInference and extract_numerical_answer are not used below.
from utils.atlas_inference import ATLASInference, load_atlas_models
from utils.evaluation import calculate_metrics, extract_numerical_answer

print("Imports successful")

## Configuration

In [None]:
# --- Model identifiers (Hugging Face model IDs) -----------------------------
DEFAULT_STUDENT_MODEL = "Qwen/Qwen3-4B-Instruct-2507"
DEFAULT_TEACHER_THINKING = "Arc-Intelligence/ATLAS-8B-Thinking"
DEFAULT_TEACHER_INSTRUCT = "Arc-Intelligence/ATLAS-8B-Instruct"

# --- Per-stage length limits --------------------------------------------------
# NOTE(review): not referenced by the cells in this notebook; presumably the
# generation budgets for the probe, teaching and student stages.
PROBE_TOKEN_LIMIT, LEARNING_RESPONSE_LIMIT, STUDENT_RESPONSE_LIMIT = 150, 500, 1_000

# --- Capability-score cut-offs (capability scores are reported out of 5) ------
# NOTE(review): not referenced by the cells in this notebook.
CAPABILITY_HIGH_THRESHOLD, CAPABILITY_MEDIUM_THRESHOLD = 4, 2

# --- GPU memory guidance, in GB -----------------------------------------------
MIN_GPU_MEMORY_GB, RECOMMENDED_GPU_MEMORY_GB = 12, 16

# Number of dataset problems to load.
DEFAULT_NUM_SAMPLES = 20

print("Configuration loaded")

## Dataset Loading Functions

In [None]:
# Thousands separators between digit groups ("1,234" -> "1234"). The lookahead
# requires exactly three digits followed by a word boundary, so coordinate-like
# text such as "(3,4)" is left untouched.
_THOUSANDS_SEPARATOR_RE = re.compile(r"(?<=\d),(?=\d{3}\b)")
# Signed integer or decimal number; every match is a valid float literal.
_NUMBER_RE = re.compile(r"[-+]?\d*\.?\d+")


def parse_ground_truth_answer(ground_truth: Any) -> Optional[float]:
    """Return the last number in *ground_truth* as a float, or None.

    Thousands separators are removed first, so "1,234" yields 1234.0 rather
    than 234.0. Non-string values are converted with ``str()``. Falsy input
    (``None``, ``""``, ``0``) and text containing no digits both yield None.
    """
    if not ground_truth:
        return None
    text = _THOUSANDS_SEPARATOR_RE.sub("", str(ground_truth))
    numbers = _NUMBER_RE.findall(text)
    if not numbers:
        return None
    # _NUMBER_RE only matches valid float literals, so this cannot raise.
    return float(numbers[-1])


def load_atlas_teach_dataset(split: str = "train", num_samples: Optional[int] = 20,
                             seed: Optional[int] = None) -> List[Dict[str, Any]]:
    """Load problems from the Arc-Intelligence/Arc-ATLAS-Teach-v0 dataset.

    Downloads ``training/rl.jsonl`` from the Hugging Face Hub. Each JSONL
    record with a non-empty ``prompt`` becomes a problem dict with the keys
    ``problem``, ``solution`` (the record's ``ground_truth``), ``source``,
    ``problem_id``, ``student_level``, ``baseline_score``,
    ``with_teaching_score``, ``teaching`` and ``reward``. When the last number
    in the ground truth can be parsed, an ``answer`` float is added as well.

    Args:
        split: Ignored; the RL training file is always loaded. Kept for
            interface compatibility.
        num_samples: If truthy and smaller than the number of loaded problems,
            randomly sample this many. ``None`` or ``0`` keeps every problem.
        seed: Optional seed for reproducible sampling. When None, the global
            ``random`` module state is used.

    Returns:
        The list of problem dicts. Any exception (for example a missing
        ``huggingface_hub``, a download failure or malformed JSON) is printed,
        and the built-in ``get_sample_math_problems()`` list is returned instead.
    """
    print("Loading Arc-Intelligence/Arc-ATLAS-Teach-v0 dataset...")

    try:
        from huggingface_hub import hf_hub_download

        # Download the RL training file.
        file_path = hf_hub_download(
            repo_id="Arc-Intelligence/Arc-ATLAS-Teach-v0",
            filename="training/rl.jsonl",
            repo_type="dataset",
        )

        problems = []
        # Use explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) can fail on non-ASCII characters in the problem text.
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                item = json.loads(line)

                problem_text = item.get("prompt", "")
                if not problem_text:
                    continue
                ground_truth = item.get("ground_truth", "")

                problem_dict = {
                    "problem": problem_text,
                    "solution": ground_truth,
                    "source": "Arc-ATLAS-Teach-v0",
                    "problem_id": item.get("problem_id", ""),
                    "student_level": item.get("student_level", ""),
                    "baseline_score": item.get("baseline_score", 0),
                    "with_teaching_score": item.get("with_teaching_score", 0),
                    "teaching": item.get("teaching", ""),
                    "reward": item.get("reward", 0),
                }

                answer = parse_ground_truth_answer(ground_truth)
                if answer is not None:
                    problem_dict["answer"] = answer

                problems.append(problem_dict)

        # Sample if requested.
        if num_samples and len(problems) > num_samples:
            rng = random if seed is None else random.Random(seed)
            problems = rng.sample(problems, num_samples)

        print(f"Loaded {len(problems)} problems from Arc-ATLAS-Teach dataset")
        return problems

    except Exception as e:
        # Best-effort notebook loader: report the failure and continue with
        # the built-in samples rather than stopping the demo.
        print(f"Error loading Arc-ATLAS-Teach dataset: {e}")
        print("Falling back to sample problems...")
        return get_sample_math_problems()

def get_sample_math_problems() -> List[Dict[str, Any]]:
    """Return a small built-in set of math word problems.

    This is the fallback used when the Arc-ATLAS-Teach dataset cannot be
    loaded. Each entry has the keys ``problem``, ``answer``, ``solution`` and
    ``source`` (always ``"sample"``), in that order. A fresh list of fresh
    dicts is built on every call, so callers may mutate the result.
    """
    field_names = ("problem", "answer", "solution")
    rows = (
        (
            "Sarah has 24 apples. She gives 1/3 of them to her brother and 1/4 of the remaining apples to her sister. How many apples does Sarah have left?",
            12,
            "Sarah starts with 24 apples. She gives 1/3 to her brother: 24 × 1/3 = 8 apples. Remaining: 24 - 8 = 16 apples. She gives 1/4 of remaining to her sister: 16 × 1/4 = 4 apples. Final amount: 16 - 4 = 12 apples.",
        ),
        (
            "A train travels 120 miles in 2 hours. If it maintains the same speed, how far will it travel in 5 hours?",
            300,
            "Speed = Distance ÷ Time = 120 miles ÷ 2 hours = 60 miles per hour. Distance in 5 hours = Speed × Time = 60 mph × 5 hours = 300 miles.",
        ),
        (
            "The sum of two consecutive even numbers is 46. What are the two numbers?",
            "22 and 24",
            "Let the first even number be x. The next consecutive even number is x + 2. Sum: x + (x + 2) = 46. Solving: 2x + 2 = 46, so 2x = 44, and x = 22. The two numbers are 22 and 24.",
        ),
    )
    return [{**dict(zip(field_names, row)), "source": "sample"} for row in rows]

## Load Dataset

In [None]:
# Load the problem set, falling back to the built-in samples on failure.
print("Loading math problems...\n")

try:
    problems = load_atlas_teach_dataset(num_samples=DEFAULT_NUM_SAMPLES)
    print(f"Loaded {len(problems)} problems")
except Exception as e:
    print(f"Dataset loading failed: {e}")
    print("Using sample problems...")
    problems = get_sample_math_problems()

# The loader can return an empty list without raising (e.g. when no record has
# a non-empty prompt). The preview below indexes problems[0], so fall back here
# instead of crashing with IndexError.
if not problems:
    print("No problems loaded; using sample problems...")
    problems = get_sample_math_problems()

print("\nSample problem:")
print(f"Problem: {problems[0]['problem'][:200]}...")
print(f"Answer: {problems[0].get('answer', 'N/A')}")

## Load Models

In [None]:
# Choose the compute device and compare GPU memory against the configured
# minimum and recommended sizes.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # total_memory is in bytes; dividing by 1e9 gives decimal gigabytes.
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {gpu_memory:.1f} GB")

    if gpu_memory < MIN_GPU_MEMORY_GB:
        print(f"Warning: GPU memory ({gpu_memory:.1f} GB) below minimum {MIN_GPU_MEMORY_GB} GB")
        print("   Consider using 8-bit quantization or smaller models")
    elif gpu_memory < RECOMMENDED_GPU_MEMORY_GB:
        print(f"Note: GPU memory ({gpu_memory:.1f} GB) below recommended {RECOMMENDED_GPU_MEMORY_GB} GB")
else:
    print("No GPU detected. Using CPU (will be slower)")
    print("   For better performance, use Google Colab with GPU runtime")

In [None]:
# Load the student and teacher models and expose the protocol runner as `atlas`.
print("\nLoading models...")
print(f"Student: {DEFAULT_STUDENT_MODEL}")
print(f"Teacher: {DEFAULT_TEACHER_THINKING}\n")

try:
    # load_atlas_models returns a pair; only the first element is used here.
    # NOTE(review): presumably the second element is the instruct-teacher
    # runner — confirm against utils.atlas_inference.
    reasoning_atlas, _ = load_atlas_models(
        student_model_name=DEFAULT_STUDENT_MODEL,
        teacher_thinking_name=DEFAULT_TEACHER_THINKING,
        teacher_instruct_name=DEFAULT_TEACHER_INSTRUCT,
        device_map="auto",
        # Half precision only on GPU; CPU runs stay in float32.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    )

    atlas = reasoning_atlas
    print("✓ Models loaded successfully")
    print(f"✓ Using device: {device}")
except Exception as e:
    print(f"Error loading models: {e}")
    print("\nTroubleshooting:")
    print("1. Check internet connection")
    print("2. Verify HuggingFace access (run: huggingface-cli login)")
    print(f"3. Check GPU memory (need ~{MIN_GPU_MEMORY_GB}GB for both models)")
    # This notebook has no batch-size setting; suggest the knobs it does offer.
    print("4. Try smaller models, 8-bit quantization, or using CPU")
    # Re-raise so later cells do not run without a loaded `atlas`.
    raise

## Run ATLAS Protocol

In [None]:
# Run the full ATLAS protocol on a few problems. The count is kept small
# because each problem goes through the whole multi-pass protocol.
NUM_DEMO_PROBLEMS = 3

results = []
print("\nRunning ATLAS protocol on problems...\n")

demo_problems = problems[:NUM_DEMO_PROBLEMS]
for i, problem in enumerate(demo_problems):
    print(f"{'='*70}")
    print(f"Problem {i+1}/{len(demo_problems)}")
    print(f"{'='*70}")

    try:
        result = atlas.run_full_protocol(problem["problem"])

        # demo_problems is a prefix of `problems`, so `i` is also the index
        # into `problems`. Later cells can pair each result with its source
        # problem even when some problems fail and are skipped.
        result["problem_id"] = i
        result["ground_truth"] = problem.get("answer")
        result["problem_text"] = problem["problem"]

        # Show the actual diagnostic and teaching process.
        print(f"\nPROBLEM: {problem['problem'][:200]}...")

        print("\nDIAGNOSTIC ASSESSMENT:")
        print(f"Capability Score: {result['diagnostic']['capability_score']}/5")
        print(f"Strategy Selected: {result['learning']['strategy']}")

        # Record the result only once the fields read above are present. A
        # malformed result is then reported as an error, not analysed later.
        results.append(result)
        print("\nCompleted")

    except Exception as e:
        print(f"  Error: {e}")

print(f"\n{'='*70}")
print(f"Completed {len(results)} problems")

## Analyze Results

In [None]:
# Summarize baseline vs. guided accuracy over the problems that completed.
if results:
    # Pair each result with the problem it came from via the stored
    # problem_id. With problems[:len(results)], problems and results go out
    # of step as soon as an earlier problem raised and was skipped.
    evaluated_problems = [problems[result["problem_id"]] for result in results]
    metrics = calculate_metrics(
        evaluated_problems,
        results,
        task_type="math"
    )

    print("\nPerformance Summary:")
    print("=" * 50)
    print(f"Baseline Accuracy: {metrics.get('baseline_accuracy', 0):.1%}")
    print(f"With ATLAS:        {metrics.get('guided_accuracy', 0):.1%}")
    # The '+' format flag prints the sign for both directions ("+3.0", "-5.0")
    # instead of a hard-coded '+' that would render "+-5.0".
    print(f"Improvement:       {metrics.get('improvement_percentage', 0):+.1f}%")
    print(f"Improvements:      {metrics.get('improvements', 0)} problems")
    print(f"Degradations:      {metrics.get('degradations', 0)} problems")
    print(f"Non-degradation:   {metrics.get('non_degradation_rate', 0):.1%}")
    print("=" * 50)
else:
    print("No results to analyze")

## Example Comparisons

In [None]:
# Print baseline vs. guided answers side by side for each completed problem.
if results:
    for i, result in enumerate(results):
        learning = result['learning']
        # NOTE(review): the guidance text is read from 'learning_guidance', with
        # 'response' as a fallback; confirm which key run_full_protocol returns.
        guidance = learning.get('learning_guidance') or learning.get('response', '')

        print(f"\n{'='*80}")
        # Use the text stored on the result itself. `results` skips failed
        # problems, so problems[i] may be a different problem.
        print(f"Example {i+1}: {result['problem_text'][:100]}...")
        print('='*80)

        print("\nSTUDENT ALONE (Baseline):")
        print(result['baseline_response'])

        print(f"\n{'-'*80}")
        print("PASS 1: DIAGNOSTIC PROBING")
        print(f"Capability: {result['diagnostic']['capability_score']}/5 → {learning['strategy']} intervention")

        print(f"\n{'-'*80}")
        print("PASS 2: ADAPTIVE TEACHING")
        print(guidance[:500])

        print(f"\n{'-'*80}")
        print("STUDENT WITH ATLAS:")
        print(result['guided_response'])

## Interactive Testing

In [None]:
def test_custom_problem(problem_text: str):
    """Run the full ATLAS protocol on *problem_text* and print a comparison.

    Prints the student's unguided answer, the teacher's strategy and the first
    200 characters of its guidance, and the guided answer.

    Relies on the notebook-global ``atlas`` created in the model-loading cell.

    Returns:
        The result dict produced by ``atlas.run_full_protocol``.
    """
    print("\n" + "="*60)
    print("Testing custom problem")
    print("="*60)
    print(f"\nProblem: {problem_text}\n")

    result = atlas.run_full_protocol(problem_text)

    print("\nStudent Response (Alone):")
    print("-" * 40)
    print(result['baseline_response'])

    learning = result['learning']
    # The example-comparison cell reads guidance from 'learning_guidance', while
    # this preview read 'response'. Accept either key so the preview does not
    # fail with KeyError. NOTE(review): confirm which key the protocol returns.
    guidance = learning.get('learning_guidance') or learning.get('response', '')

    print("\nTeacher Guidance:")
    print("-" * 40)
    print(f"Strategy: {learning['strategy']}")
    print(f"Guidance: {guidance[:200]}...")

    print("\nStudent Response (With ATLAS):")
    print("-" * 40)
    print(result['guided_response'])

    return result

custom_problem = "A store offers a 20% discount on all items. If a jacket originally costs $80, how much will it cost after the discount?"
custom_result = test_custom_problem(custom_problem)

## Token Efficiency Analysis

## Save Results

In [None]:
# Persist metrics, per-problem results and the run configuration to a JSON
# file for later analysis. Skipped when no problem completed.
if results:
    output_file = "atlas_math_results.json"
    run_config = {
        "student_model": DEFAULT_STUDENT_MODEL,
        "teacher_model": DEFAULT_TEACHER_THINKING,
        "num_problems": len(results),
    }
    payload = {"metrics": metrics, "results": results, "config": run_config}

    with open(output_file, 'w') as handle:
        json.dump(payload, handle, indent=2)

    print(f"\nResults saved to {output_file}")