# ATLAS Math Reasoning Demo

This notebook demonstrates how the ATLAS two-pass inference protocol improves accuracy on math problems.

## Overview

ATLAS uses a teacher model to guide student models through:
1. **Diagnostic Probing**: Teacher assesses student capability (~50 tokens)
2. **Adaptive Learning**: Teacher provides targeted guidance based on assessment
3. **Enhanced Response**: Student generates improved solution using guidance

Expected improvement: **15.7% accuracy gain** with near-zero degradation.
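
The three steps above can be sketched as plain control flow. This is a minimal illustration, not the `ATLASInference` implementation. `two_pass_protocol`, `toy_student`, and `toy_teacher` are hypothetical, the prompt wording is invented, and reading the score from the first digit 1-5 in the teacher's reply is an assumption. The token budgets default to the limits configured later in this notebook.

```python
import re
from dataclasses import dataclass
from typing import Callable

# Hypothetical interface: any function taking (prompt, max_new_tokens) and returning text.
Generate = Callable[[str, int], str]


@dataclass
class ProtocolOutput:
    baseline_response: str
    probe_response: str
    capability_score: int
    guidance: str
    guided_response: str


def two_pass_protocol(problem: str, student: Generate, teacher: Generate,
                      probe_limit: int = 50, guidance_limit: int = 200,
                      student_limit: int = 300) -> ProtocolOutput:
    # Baseline pass: the student answers alone, kept for comparison.
    baseline = student(problem, student_limit)

    # Step 1, diagnostic probing: a short student sketch that the teacher rates 1-5.
    # Assumption: the first digit 1-5 in the teacher's reply is the score; default to 1.
    probe = student(f"Briefly outline how you would solve:\n{problem}", probe_limit)
    verdict = teacher(f"Rate this approach from 1 to 5:\n{probe}", probe_limit)
    match = re.search(r"[1-5]", verdict)
    score = int(match.group()) if match else 1

    # Step 2, adaptive learning: guidance conditioned on the diagnosed score.
    guidance = teacher(f"Capability {score}/5. Give targeted guidance for:\n{problem}",
                       guidance_limit)

    # Step 3, enhanced response: the student answers again with the guidance in context.
    guided = student(f"{guidance}\n\nProblem:\n{problem}", student_limit)
    return ProtocolOutput(baseline, probe, score, guidance, guided)


# Toy stand-ins for real models, so the control flow runs without a GPU.
def toy_student(prompt: str, max_new_tokens: int) -> str:
    return "Speed is 60 mph, so 300 miles." if "Problem:" in prompt else "Find the speed."


def toy_teacher(prompt: str, max_new_tokens: int) -> str:
    return "Score: 3" if prompt.startswith("Rate") else "First compute miles per hour."


out = two_pass_protocol("A train travels 120 miles in 2 hours...", toy_student, toy_teacher)
print(out.capability_score, "|", out.guided_response)
```

Only the guided pass sees the teacher's guidance, so comparing `baseline_response` with `guided_response` isolates the effect of the intervention.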

## Setup and Installation

In [None]:
# Install required packages (for Google Colab)
import sys
if 'google.colab' in sys.modules:
    !pip install -q transformers torch accelerate datasets matplotlib pandas numpy
    print("Packages installed for Google Colab")
else:
    print("Using local environment")

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from IPython.display import display, HTML
import warnings
import json
import random
import re
from typing import List, Dict, Any, Optional
warnings.filterwarnings('ignore')

# Import ATLAS utilities
from utils.atlas_inference import ATLASInference, load_atlas_models
from utils.evaluation import calculate_metrics
from utils.visualization import plot_comparison, display_results_table, show_example_comparisons

print("Imports successful")

## Configuration

In [None]:
# Model Configuration
DEFAULT_STUDENT_MODEL = "Qwen/Qwen3-4B-Instruct-2507"
DEFAULT_TEACHER_THINKING = "Arc-Intelligence/ATLAS-8B-Thinking"
DEFAULT_TEACHER_INSTRUCT = "Arc-Intelligence/ATLAS-8B-Instruct"

# Token Limits
PROBE_TOKEN_LIMIT = 50  # Maximum tokens for diagnostic probing
LEARNING_RESPONSE_LIMIT = 200  # Maximum tokens for adaptive learning guidance
STUDENT_RESPONSE_LIMIT = 300  # Maximum tokens for student responses

# Capability Score Thresholds
CAPABILITY_HIGH_THRESHOLD = 4  # Scores 4-5: Light intervention
CAPABILITY_MEDIUM_THRESHOLD = 2  # Scores 2-3: Medium guidance
# Score 1: Heavy support

# Evaluation Settings
DEGRADATION_PENALTY_MULTIPLIER = 2.0
IMPROVEMENT_REWARD = 1.0
NO_CHANGE_REWARD = 0.0

# Memory Requirements
MIN_GPU_MEMORY_GB = 12
RECOMMENDED_GPU_MEMORY_GB = 16

# Dataset Settings
DEFAULT_NUM_SAMPLES = 20
DEFAULT_DATASET_SPLIT = "train"

print("Configuration loaded")
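
The capability thresholds above imply a three-tier mapping from the teacher's 1-5 score to an intervention level. A minimal sketch, assuming integer scores; `intervention_level` and the tier names are illustrative, not identifiers from `utils.atlas_inference`.

```python
CAPABILITY_HIGH_THRESHOLD = 4
CAPABILITY_MEDIUM_THRESHOLD = 2


def intervention_level(score: int) -> str:
    """Map a 1-5 capability score to a guidance tier (illustrative labels)."""
    if not 1 <= score <= 5:
        raise ValueError(f"capability score must be in 1..5, got {score}")
    if score >= CAPABILITY_HIGH_THRESHOLD:
        return "light"   # scores 4-5: light intervention
    if score >= CAPABILITY_MEDIUM_THRESHOLD:
        return "medium"  # scores 2-3: medium guidance
    return "heavy"       # score 1: heavy support


print([intervention_level(s) for s in range(1, 6)])
# ['heavy', 'medium', 'medium', 'light', 'light']
```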

## Dataset Loading Functions

In [None]:
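def load_atlas_teach_dataset(split: str = "train", num_samples: Optional[int] = 20) -> List[Dict[str, Any]]:
    """Load problems from the Arc-Intelligence/Arc-ATLAS-Teach-v0 dataset.

    Note: `split` is currently unused; the RL training file (training/rl.jsonl)
    is always loaded. Falls back to the built-in sample problems on any error.
    """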
    print("Loading Arc-Intelligence/Arc-ATLAS-Teach-v0 dataset...")
    
    try:
        from huggingface_hub import hf_hub_download
        
        # Download the RL training file
        file_path = hf_hub_download(
            repo_id="Arc-Intelligence/Arc-ATLAS-Teach-v0",
            filename="training/rl.jsonl",
            repo_type="dataset"
        )
        
        # Load the JSONL file
        problems = []
        with open(file_path, 'r') as f:
            for line in f:
                if line.strip():
                    item = json.loads(line)
                    
                    problem_text = item.get("prompt", "")
                    ground_truth = item.get("ground_truth", "")
                    
                    if problem_text:
                        problem_dict = {
                            "problem": problem_text,
                            "solution": ground_truth,
                            "source": "Arc-ATLAS-Teach-v0",
                            "problem_id": item.get("problem_id", ""),
                            "student_level": item.get("student_level", ""),
                            "baseline_score": item.get("baseline_score", 0),
                            "with_teaching_score": item.get("with_teaching_score", 0),
                            "teaching": item.get("teaching", ""),
                            "reward": item.get("reward", 0)
                        }
                        
                        # Extract numerical answer
                        if ground_truth:
                            numbers = re.findall(r"[-+]?\d*\.?\d+", str(ground_truth))
                            if numbers:
                                try:
                                    problem_dict["answer"] = float(numbers[-1])
                                except ValueError:
                                    pass  # Leave "answer" unset if the match is not a valid float
                        
                        problems.append(problem_dict)
        
        # Sample if requested
        if num_samples and len(problems) > num_samples:
            problems = random.sample(problems, num_samples)
        
        print(f"Loaded {len(problems)} problems from Arc-ATLAS-Teach dataset")
        return problems
        
    except Exception as e:
        print(f"Error loading Arc-ATLAS-Teach dataset: {e}")
        print("Falling back to sample problems...")
        return get_sample_math_problems()

def get_sample_math_problems() -> List[Dict[str, Any]]:
    """Fallback sample math problems."""
    return [
        {
            "problem": "Sarah has 24 apples. She gives 1/3 of them to her brother and 1/4 of the remaining apples to her sister. How many apples does Sarah have left?",
            "answer": 12,
            "solution": "Sarah starts with 24 apples. She gives 1/3 to her brother: 24 × 1/3 = 8 apples. Remaining: 24 - 8 = 16 apples. She gives 1/4 of remaining to her sister: 16 × 1/4 = 4 apples. Final amount: 16 - 4 = 12 apples.",
            "source": "sample"
        },
        {
            "problem": "A train travels 120 miles in 2 hours. If it maintains the same speed, how far will it travel in 5 hours?",
            "answer": 300,
            "solution": "Speed = Distance ÷ Time = 120 miles ÷ 2 hours = 60 miles per hour. Distance in 5 hours = Speed × Time = 60 mph × 5 hours = 300 miles.",
            "source": "sample"
        },
        {
            "problem": "The sum of two consecutive even numbers is 46. What are the two numbers?",
            "answer": "22 and 24",
            "solution": "Let the first even number be x. The next consecutive even number is x + 2. Sum: x + (x + 2) = 46. Solving: 2x + 2 = 46, so 2x = 44, and x = 22. The two numbers are 22 and 24.",
            "source": "sample"
        }
    ]
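
The loader keeps the last number in `ground_truth` as the answer. Grading model output needs the same extraction on the response side plus a tolerant numeric comparison. A minimal sketch, assuming answers are graded by their final number; `extract_final_number` and `answers_match` are hypothetical helpers, and stripping commas as thousands separators is an added assumption that the loader does not make.

```python
import math
import re
from typing import Optional, Union

# Same pattern the loader uses for ground-truth answers.
NUMBER_PATTERN = re.compile(r"[-+]?\d*\.?\d+")


def extract_final_number(text: str) -> Optional[float]:
    """Return the last number in `text`, or None if there is none.

    Assumption: commas are thousands separators and are removed first,
    so "1,200" parses as 1200 rather than 200.
    """
    numbers = NUMBER_PATTERN.findall(text.replace(",", ""))
    return float(numbers[-1]) if numbers else None


def answers_match(prediction: str, truth: Union[int, float, str],
                  rel_tol: float = 1e-6) -> bool:
    """Compare the last number in a model response to a reference answer."""
    predicted = extract_final_number(prediction)
    expected = truth if isinstance(truth, (int, float)) else extract_final_number(str(truth))
    if predicted is None or expected is None:
        return False
    return math.isclose(predicted, float(expected), rel_tol=rel_tol, abs_tol=1e-9)


print(answers_match("Final amount: 16 - 4 = 12 apples.", 12))
```

One consequence of grading by the final number: a multi-part answer such as the sample's `"22 and 24"` is effectively compared on `24` alone.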

## Load Dataset

In [None]:
# Load math problems
print("Loading math problems...\n")

try:
    # Try Arc-ATLAS-Teach dataset first
    problems = load_atlas_teach_dataset(num_samples=DEFAULT_NUM_SAMPLES)
    print(f"Loaded {len(problems)} problems")
except Exception as e:
    print(f"Dataset loading failed: {e}")
    print("Using sample problems...")
    problems = get_sample_math_problems()

# Display sample problems
print("\nSample problem:")
print(f"Problem: {problems[0]['problem'][:200]}...")
print(f"Answer: {problems[0].get('answer', 'N/A')}")
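
`random.sample` in the loader draws from the global random state, so each run evaluates a different subset of problems. A minimal sketch of reproducible sampling with a dedicated, seeded generator; `sample_problems` and the seed value are hypothetical additions, not part of the original loader.

```python
import random
from typing import Any, Dict, List, Optional


def sample_problems(problems: List[Dict[str, Any]], num_samples: Optional[int],
                    seed: int = 0) -> List[Dict[str, Any]]:
    """Return a reproducible subset; a local Random leaves global state untouched."""
    if not num_samples or len(problems) <= num_samples:
        return list(problems)
    rng = random.Random(seed)
    return rng.sample(problems, num_samples)


pool = [{"problem_id": i} for i in range(100)]
first = sample_problems(pool, 5, seed=42)
second = sample_problems(pool, 5, seed=42)
print(first == second)  # same seed, same subset
```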

## Load Models

In [None]:
# Check GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {gpu_memory:.1f} GB")
    
    if gpu_memory < MIN_GPU_MEMORY_GB:
        print(f"Warning: GPU memory ({gpu_memory:.1f} GB) is below the minimum {MIN_GPU_MEMORY_GB} GB")
        print("   Consider using 8-bit quantization or smaller models")
else:
    print("No GPU detected. Using CPU (will be slower)")
    print("   For better performance, use Google Colab with GPU runtime")

In [None]:
# Load models with ATLAS wrapper
print("\nLoading models...")
print(f"Student: {DEFAULT_STUDENT_MODEL}")
print(f"Teacher: {DEFAULT_TEACHER_THINKING}\n")

try:
    atlas, models = load_atlas_models(
        student_model_name=DEFAULT_STUDENT_MODEL,
        teacher_thinking_name=DEFAULT_TEACHER_THINKING,
        device=device,
        load_in_8bit=(gpu_memory < RECOMMENDED_GPU_MEMORY_GB) if device == "cuda" else False
    )
    print("Models loaded successfully")
except Exception as e:
    print(f"Error loading models: {e}")
    print("\nTroubleshooting:")
    print("1. Check internet connection")
    print("2. Verify HuggingFace access (some models require authentication)")
    print("3. Try with smaller models or enable 8-bit quantization")
    raise
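
The 8-bit switch above is driven by total GPU memory. A rough way to see why: weight memory is approximately parameter count times bytes per parameter. The sketch below estimates weights only (it ignores activations, the KV cache, and framework overhead) and assumes nominal sizes of 4B parameters for the student and 8B for the teacher, read off the model names; `weight_memory_gb` is a hypothetical helper.

```python
# Bytes per parameter for common weight formats.
BYTES_PER_PARAM = {"fp16": 2.0, "int8": 1.0}


def weight_memory_gb(num_params_billion: float, precision: str) -> float:
    """Approximate weight memory in GB (1 GB = 1e9 bytes, as in the GPU check above)."""
    return num_params_billion * BYTES_PER_PARAM[precision]


student_b, teacher_b = 4.0, 8.0  # nominal sizes from the model names (assumption)
for precision in ("fp16", "int8"):
    total = weight_memory_gb(student_b, precision) + weight_memory_gb(teacher_b, precision)
    print(f"{precision}: ~{total:.0f} GB of weights")
# fp16: ~24 GB of weights
# int8: ~12 GB of weights
```

Under these assumptions, real usage will be higher than either figure once activations and the KV cache are included.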

## Run ATLAS Protocol

In [None]:
# Run inference on all problems
results = []
print("\nRunning ATLAS protocol on problems...\n")

NUM_DEMO_PROBLEMS = 5  # Run on the first few problems to keep the demo short
for i, problem in enumerate(problems[:NUM_DEMO_PROBLEMS]):
    print(f"Problem {i+1}/{min(NUM_DEMO_PROBLEMS, len(problems))}...")
    
    try:
        # Run full ATLAS protocol
        result = atlas.run_full_protocol(
            problem["problem"],
            ground_truth=problem.get("answer"),
            max_student_tokens=STUDENT_RESPONSE_LIMIT
        )
        
        # Store results
        result["problem_id"] = i
        result["ground_truth"] = problem.get("answer")
        results.append(result)
        
        # Show improvement
        if result.get("improvement_category") == "improved":
            print(f"  Improved: {result.get('baseline_correct', False)} -> {result.get('guided_correct', True)}")
        elif result.get("improvement_category") == "degraded":
            print(f"  Degraded: {result.get('baseline_correct', True)} -> {result.get('guided_correct', False)}")
        else:
            print("  No change")
            
    except Exception as e:
        print(f"  Error: {e}")
        continue

print(f"\nCompleted {len(results)} problems")
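
Each result is bucketed by comparing baseline and guided correctness. A minimal sketch of that rule, assuming `baseline_correct` and `guided_correct` are booleans; the actual categorization lives in `utils.atlas_inference` and may differ, and the `"no_change"` label is a hypothetical name for the loop's else branch.

```python
def improvement_category(baseline_correct: bool, guided_correct: bool) -> str:
    """Classify one problem by how guidance changed the student's correctness."""
    if not baseline_correct and guided_correct:
        return "improved"
    if baseline_correct and not guided_correct:
        return "degraded"
    return "no_change"  # correct both times, or wrong both times


print(improvement_category(False, True))
```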

## Analyze Results

In [None]:
# Calculate metrics
if results:
    metrics = calculate_metrics(results)
    
    print("\nPerformance Summary:")
    print("=" * 50)
    print(f"Baseline Accuracy: {metrics['baseline_accuracy']:.1%}")
    print(f"With ATLAS:        {metrics['guided_accuracy']:.1%}")
    print(f"Improvement:       {metrics['improvement_rate']:.1%}")
    print(f"Degradation:       {metrics['degradation_rate']:.1%}")
    print(f"Non-degradation:   {metrics['non_degradation_rate']:.1%}")
    print("=" * 50)
    
    # Visualize results
    plot_comparison(results, metrics)
else:
    print("No results to analyze")
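
`calculate_metrics` comes from `utils.evaluation`, which is not shown here. The sketch below is one plausible reading of the printed fields, assuming each result carries boolean `baseline_correct` and `guided_correct`. It also shows how the `IMPROVEMENT_REWARD`, `NO_CHANGE_REWARD`, and `DEGRADATION_PENALTY_MULTIPLIER` constants from the configuration could combine into an asymmetric score. Both the formulas and the `net_reward` field are assumptions, and `summarize` is a hypothetical name.

```python
from typing import Any, Dict, List

IMPROVEMENT_REWARD = 1.0
NO_CHANGE_REWARD = 0.0
DEGRADATION_PENALTY_MULTIPLIER = 2.0


def summarize(results: List[Dict[str, Any]]) -> Dict[str, float]:
    n = len(results)
    if n == 0:
        raise ValueError("no results to summarize")
    baseline = [bool(r["baseline_correct"]) for r in results]
    guided = [bool(r["guided_correct"]) for r in results]
    improved = sum(not b and g for b, g in zip(baseline, guided))
    degraded = sum(b and not g for b, g in zip(baseline, guided))
    unchanged = n - improved - degraded
    # Assumed reward: each degradation costs twice what an improvement earns.
    reward = (improved * IMPROVEMENT_REWARD + unchanged * NO_CHANGE_REWARD
              - degraded * DEGRADATION_PENALTY_MULTIPLIER * IMPROVEMENT_REWARD)
    return {
        "baseline_accuracy": sum(baseline) / n,
        "guided_accuracy": sum(guided) / n,
        "improvement_rate": improved / n,
        "degradation_rate": degraded / n,
        "non_degradation_rate": 1 - degraded / n,
        "net_reward": reward / n,
    }


example = [
    {"baseline_correct": True, "guided_correct": True},
    {"baseline_correct": False, "guided_correct": True},
    {"baseline_correct": False, "guided_correct": False},
    {"baseline_correct": True, "guided_correct": False},
]
print(summarize(example))
```

The example shows why the penalty matters: one improvement and one degradation leave accuracy unchanged, but the net reward is negative.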

## Example Comparisons

In [None]:
# Show detailed comparisons for improved cases
if results:
    show_example_comparisons(results, problems, n_examples=3)

## Interactive Testing

In [None]:
# Test with your own problem
def test_custom_problem(problem_text: str):
    """Test ATLAS with a custom math problem."""
    print("\n" + "="*60)
    print("Testing custom problem")
    print("="*60)
    print(f"\nProblem: {problem_text}\n")
    
    result = atlas.run_full_protocol(problem_text, max_student_tokens=STUDENT_RESPONSE_LIMIT)
    
    print("\nStudent Response (Alone):")
    print("-" * 40)
    print(result['baseline_response'])
    
    print("\nTeacher Guidance:")
    print("-" * 40)
    print(f"Strategy: {result['learning']['strategy']}")
    print(f"Guidance: {result['learning']['response'][:200]}...")
    
    print("\nStudent Response (With ATLAS):")
    print("-" * 40)
    print(result['guided_response'])
    
    return result

# Example usage
custom_problem = "A store offers a 20% discount on all items. If a jacket originally costs $80, how much will it cost after the discount?"
custom_result = test_custom_problem(custom_problem)
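
For reference, a 20% discount leaves 80% of the original price, so the expected answer is 80 × 0.8 = 64. A tiny hypothetical helper, not part of ATLAS, that computes such reference values from whole-number percentages:

```python
def discounted_price(price: float, percent_off: int) -> float:
    """Price after a whole-number percentage discount."""
    if not 0 <= percent_off <= 100:
        raise ValueError("percent_off must be between 0 and 100")
    return price * (100 - percent_off) / 100


print(discounted_price(80, 20))  # 64.0
```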

## Token Efficiency Analysis

In [None]:
# Analyze token usage
if results:
    total_probe_tokens = sum(r.get('probe', {}).get('tokens_used', 0) for r in results)
    total_learning_tokens = sum(r.get('learning', {}).get('tokens_used', 0) for r in results)
    total_baseline_tokens = sum(r.get('baseline_tokens', 0) for r in results)
    total_guided_tokens = sum(r.get('guided_tokens', 0) for r in results)
    
    avg_probe = total_probe_tokens / len(results)
    avg_learning = total_learning_tokens / len(results)
    avg_overhead = avg_probe + avg_learning
    
    print("\nToken Efficiency:")
    print("=" * 50)
    print(f"Average probe tokens:     {avg_probe:.0f}")
    print(f"Average learning tokens:  {avg_learning:.0f}")
    print(f"Total teacher overhead:   {avg_overhead:.0f}")
    print(f"\nBaseline response avg:    {total_baseline_tokens/len(results):.0f}")
    print(f"Guided response avg:      {total_guided_tokens/len(results):.0f}")
    print("=" * 50)
    
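    avg_guided = total_guided_tokens / len(results)
    if avg_guided > 0:
        efficiency_ratio = avg_overhead / avg_guided
        print(f"\nTeacher overhead: {efficiency_ratio:.1%} of guided-response tokens "
              f"({metrics['improvement_rate']:.1%} of problems improved)")
    else:
        print("\nNo guided-token counts recorded; skipping efficiency ratio")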
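
The overhead ratio can be packaged as a reusable function that returns `None` rather than dividing by zero when there are no results or no guided token counts. A sketch over the result keys the cell above reads (`probe.tokens_used`, `learning.tokens_used`, `guided_tokens`); `teacher_overhead_ratio` is a hypothetical helper.

```python
from typing import Any, Dict, List, Optional


def teacher_overhead_ratio(results: List[Dict[str, Any]]) -> Optional[float]:
    """Average teacher tokens (probe + guidance) per average guided-response token."""
    if not results:
        return None
    n = len(results)
    overhead = sum(r.get("probe", {}).get("tokens_used", 0)
                   + r.get("learning", {}).get("tokens_used", 0) for r in results) / n
    guided = sum(r.get("guided_tokens", 0) for r in results) / n
    return overhead / guided if guided > 0 else None


sample_results = [
    {"probe": {"tokens_used": 40}, "learning": {"tokens_used": 160}, "guided_tokens": 300},
    {"probe": {"tokens_used": 60}, "learning": {"tokens_used": 140}, "guided_tokens": 500},
]
print(teacher_overhead_ratio(sample_results))
```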

## Save Results

In [None]:
# Save results for further analysis
if results:
    output_file = "atlas_math_results.json"
    with open(output_file, 'w') as f:
        json.dump({
            "metrics": metrics,
            "results": results,
            "config": {
                "student_model": DEFAULT_STUDENT_MODEL,
                "teacher_model": DEFAULT_TEACHER_THINKING,
                "num_problems": len(results)
            }
        }, f, indent=2)
    
    print(f"\nResults saved to {output_file}")
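
`json.dump` raises `TypeError` on values it cannot serialize, such as NumPy scalars or Python sets, which results returned by helper libraries may contain. A minimal `default` hook, a hypothetical helper rather than part of the ATLAS utilities, that converts such values instead of failing:

```python
import json
from typing import Any


def to_jsonable(obj: Any) -> Any:
    """Fallback for json.dump: convert common non-JSON types, else stringify."""
    if hasattr(obj, "tolist"):  # NumPy arrays/scalars and similar array types
        return obj.tolist()
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    return str(obj)  # last resort: keep a readable representation


payload = {"accuracy": 0.5, "problem_ids": {3, 1}, "note": complex(1, 2)}
print(json.dumps(payload, default=to_jsonable))
```

To use it in the save cell, pass `default=to_jsonable` alongside `indent=2` in the `json.dump` call.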

## Conclusion

This demo walked through how ATLAS aims to improve math problem solving:

1. **Diagnostic Probing**: Teacher assesses student capability in ~50 tokens
2. **Adaptive Learning**: Teacher tailors guidance to the diagnosed capability
3. **Performance Gains**: ~15.7% expected accuracy improvement with minimal overhead; the five-problem run above illustrates the protocol but is too small to confirm that figure

### Key Takeaways
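- ATLAS works with a range of student models (4B-8B parameters)
- Teacher overhead is small: ~50 probe tokens plus up to ~200 guidance tokens, about 250 tokens per problem
- A 97% non-degradation rate means guidance rarely turns a correct answer into a wrong one
- Drop-in replacement for existing inference pipelines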

### Next Steps
- Try with your own student models
- Test on different problem types
- Fine-tune teacher models for specific domains
- Deploy in production with the vLLM server

For more information, see the [main repository](https://github.com/Arc-Intelligence/RCL).