# ATLAS Code Generation Demo

This notebook demonstrates how ATLAS can improve code generation and explanation quality.

## Overview

ATLAS enhances code generation through:
1. **Diagnostic Assessment**: Teacher evaluates student's coding capability
2. **Adaptive Guidance**: Tailored instruction based on coding skill level
3. **Enhanced Generation**: Student produces better code with teacher support

Expected improvements: better code quality, clearer explanations, and more complete solutions.
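
The three stages above can be sketched as a plain control flow. This is only an illustration, not the ATLAS implementation: `probe_fn`, `guide_fn`, and `generate_fn` are hypothetical stand-ins for the teacher and student model calls, and the `score` key is an assumption. The other dictionary keys mirror the ones this notebook reads from the protocol results.

```python
from typing import Callable, Dict


def run_protocol_sketch(
    prompt: str,
    probe_fn: Callable[[str], int],
    guide_fn: Callable[[str, int], str],
    generate_fn: Callable[[str], str],
) -> Dict[str, object]:
    """Illustrative three-stage flow; every callable is a hypothetical stand-in."""
    # Stage 1: the teacher scores the student's capability on this prompt.
    score = probe_fn(prompt)
    # Stage 2: the teacher writes guidance tailored to that score.
    guidance = guide_fn(prompt, score)
    # Stage 3: the student answers twice, without and with the guidance.
    baseline = generate_fn(prompt)
    guided = generate_fn(f"{prompt}\n\nTeacher guidance: {guidance}")
    return {
        "probe": {"score": score},  # "score" is an assumed key name
        "learning": {"response": guidance},
        "baseline_response": baseline,
        "guided_response": guided,
    }


# Toy stand-ins so the sketch runs without loading any model.
demo = run_protocol_sketch(
    "Write a function that reverses a string.",
    probe_fn=lambda prompt: 3,
    guide_fn=lambda prompt, score: "Remember to handle the empty string.",
    generate_fn=lambda prompt: f"<response to a {len(prompt)}-character prompt>",
)
print(demo["learning"]["response"])
```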

## Setup and Installation

In [None]:
# Install required packages (for Google Colab)
import sys
if 'google.colab' in sys.modules:
    !pip install -q transformers torch accelerate datasets matplotlib pandas numpy
    print("Packages installed for Google Colab")
else:
    print("Using local environment")

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from IPython.display import display, HTML, Markdown
import warnings
import json
import random
import re
from typing import List, Dict, Any, Optional
warnings.filterwarnings('ignore')

# Import ATLAS utilities
from utils.atlas_inference import ATLASInference, load_atlas_models
from utils.evaluation import evaluate_code_quality, calculate_code_metrics
from utils.visualization import plot_code_metrics, display_code_comparison

print("Imports successful")

## Configuration

In [None]:
# Model Configuration
DEFAULT_STUDENT_MODEL = "Qwen/Qwen3-4B-Instruct-2507"
DEFAULT_TEACHER_THINKING = "Arc-Intelligence/ATLAS-8B-Thinking"
DEFAULT_TEACHER_INSTRUCT = "Arc-Intelligence/ATLAS-8B-Instruct"

# Token Limits
PROBE_TOKEN_LIMIT = 50  # Maximum tokens for diagnostic probing
LEARNING_RESPONSE_LIMIT = 200  # Maximum tokens for adaptive learning guidance
STUDENT_RESPONSE_LIMIT = 500  # Higher limit for code generation

# Capability Score Thresholds
CAPABILITY_HIGH_THRESHOLD = 4  # Scores 4-5: Light intervention
CAPABILITY_MEDIUM_THRESHOLD = 2  # Scores 2-3: Medium guidance

# Memory Requirements
MIN_GPU_MEMORY_GB = 12
RECOMMENDED_GPU_MEMORY_GB = 16

# Dataset Settings
DEFAULT_NUM_SAMPLES = 15

print("Configuration loaded")
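
One way the two capability thresholds could map a diagnostic score to an intervention level is sketched below. The function name, the assumed 1-5 score range, and the "heavy" label for scores below the medium threshold are assumptions; the configuration above only names the light and medium bands.

```python
# Repeated from the configuration cell so the sketch runs on its own.
CAPABILITY_HIGH_THRESHOLD = 4
CAPABILITY_MEDIUM_THRESHOLD = 2


def select_intervention(score: int) -> str:
    """Map a capability score (assumed range 1-5) to an intervention level."""
    if score >= CAPABILITY_HIGH_THRESHOLD:
        return "light"
    if score >= CAPABILITY_MEDIUM_THRESHOLD:
        return "medium"
    # Assumption: anything below the medium threshold gets the most support.
    return "heavy"


for score in range(1, 6):
    print(score, select_intervention(score))
```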

## Dataset Loading Functions

In [None]:
def load_code_problems(num_samples: Optional[int] = DEFAULT_NUM_SAMPLES) -> List[Dict[str, Any]]:
    """Load coding problems for code generation demo."""
    print("Loading coding problems...")
    
    try:
        from datasets import load_dataset
        # Try loading from HumanEval dataset
        dataset = load_dataset("openai_humaneval", split="test")
        
        problems = []
        for item in dataset:
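            problem_dict = {
                "problem": item.get("prompt", ""),
                # HumanEval has no separate docstring field; the docstring is embedded in the prompt
                "expected_behavior": "",
                "test_cases": item.get("test", ""),
                # Function body only; prepend the prompt to obtain a complete function
                "canonical_solution": item.get("canonical_solution", ""),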
                "source": "HumanEval"
            }
            problems.append(problem_dict)
        
        # Sample if requested
        if num_samples and len(problems) > num_samples:
            problems = random.sample(problems, num_samples)
        
        print(f"Loaded {len(problems)} coding problems")
        return problems
        
    except Exception as e:
        print(f"Error loading coding dataset: {e}")
        print("Using sample coding problems...")
        return get_sample_code_problems()

def get_sample_code_problems() -> List[Dict[str, Any]]:
    """Fallback sample coding problems."""
    return [
        {
            "problem": "Write a function that returns the factorial of a non-negative integer n.",
            "expected_behavior": "factorial(5) should return 120, factorial(0) should return 1",
            "canonical_solution": "def factorial(n):\n    if n <= 1:\n        return 1\n    return n * factorial(n - 1)",
            "source": "sample",
            "difficulty": "easy"
        },
        {
            "problem": "Write a function that checks if a string is a palindrome (reads the same forwards and backwards).",
            "expected_behavior": "is_palindrome('racecar') should return True, is_palindrome('hello') should return False",
            "canonical_solution": "def is_palindrome(s):\n    s = s.lower().replace(' ', '')\n    return s == s[::-1]",
            "source": "sample",
            "difficulty": "easy"
        },
        {
            "problem": "Write a function that finds the longest common subsequence of two strings.",
            "expected_behavior": "lcs('ABCDGH', 'AEDFHR') should return 'ADH', lcs('AGGTAB', 'GXTXAYB') should return 'GTAB'",
            "canonical_solution": "def lcs(X, Y):\n    m, n = len(X), len(Y)\n    L = [[0] * (n + 1) for _ in range(m + 1)]\n    \n    for i in range(m + 1):\n        for j in range(n + 1):\n            if i == 0 or j == 0:\n                L[i][j] = 0\n            elif X[i-1] == Y[j-1]:\n                L[i][j] = L[i-1][j-1] + 1\n            else:\n                L[i][j] = max(L[i-1][j], L[i][j-1])\n    \n    # Reconstruct LCS\n    i, j = m, n\n    lcs_str = []\n    while i > 0 and j > 0:\n        if X[i-1] == Y[j-1]:\n            lcs_str.append(X[i-1])\n            i -= 1\n            j -= 1\n        elif L[i-1][j] > L[i][j-1]:\n            i -= 1\n        else:\n            j -= 1\n    \n    return ''.join(reversed(lcs_str))",
            "source": "sample",
            "difficulty": "hard"
        }
    ]
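
Before using the fallback problems as references, it can help to confirm that each `canonical_solution` actually satisfies its `expected_behavior`. The helper below is a minimal sketch with a hypothetical name (`check_solution`); it executes the source string with `exec`, so it should only be used on trusted code such as these hand-written samples.

```python
from typing import Any, Callable, Dict, List, Tuple


def check_solution(source: str, func_name: str, cases: List[Tuple[tuple, Any]]) -> bool:
    """Execute trusted source code and compare func_name(*args) to each expected value."""
    namespace: Dict[str, Any] = {}
    exec(source, namespace)  # Only run trusted code this way.
    func: Callable = namespace[func_name]
    return all(func(*args) == expected for args, expected in cases)


# Two of the fallback solutions, copied from get_sample_code_problems().
factorial_src = (
    "def factorial(n):\n"
    "    if n <= 1:\n"
    "        return 1\n"
    "    return n * factorial(n - 1)"
)
palindrome_src = (
    "def is_palindrome(s):\n"
    "    s = s.lower().replace(' ', '')\n"
    "    return s == s[::-1]"
)

print(check_solution(factorial_src, "factorial", [((5,), 120), ((0,), 1)]))
print(check_solution(palindrome_src, "is_palindrome", [(("racecar",), True), (("hello",), False)]))
```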

## Load Dataset

In [None]:
# Load coding problems. load_code_problems() prints its own progress messages
# and falls back to the built-in sample problems if the dataset cannot be loaded.
problems = load_code_problems(num_samples=DEFAULT_NUM_SAMPLES)

# Display sample problem
print("\nSample problem:")
print(f"Problem: {problems[0]['problem'][:200]}...")
print(f"Expected: {problems[0].get('expected_behavior', 'N/A')[:100]}...")

## Load Models

In [None]:
# Check GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {gpu_memory:.1f} GB")
    
    if gpu_memory < MIN_GPU_MEMORY_GB:
        print(f"Warning: GPU memory ({gpu_memory:.1f} GB) is below the minimum of {MIN_GPU_MEMORY_GB} GB")
        print("   Consider using 8-bit quantization or smaller models")
else:
    print("No GPU detected. Using CPU (will be slower)")
    print("   For better performance, use Google Colab with GPU runtime")

In [None]:
# Load models with ATLAS wrapper
print("\nLoading models...")
print(f"Student: {DEFAULT_STUDENT_MODEL}")
print(f"Teacher: {DEFAULT_TEACHER_INSTRUCT}\n")

try:
    # For code generation, use ATLAS-8B-Instruct
    atlas, models = load_atlas_models(
        student_model_name=DEFAULT_STUDENT_MODEL,
        teacher_thinking_name=DEFAULT_TEACHER_INSTRUCT,  # Using Instruct for code
        device=device,
        load_in_8bit=(gpu_memory < RECOMMENDED_GPU_MEMORY_GB) if device == "cuda" else False
    )
    print("Models loaded successfully")
except Exception as e:
    print(f"Error loading models: {e}")
    print("\nTroubleshooting:")
    print("1. Check internet connection")
    print("2. Verify HuggingFace access")
    print("3. Try with smaller models or enable 8-bit quantization")
    raise

## Run ATLAS Protocol

In [None]:
# Run inference on coding problems
results = []
print("\nRunning ATLAS protocol on coding problems...\n")

for i, problem in enumerate(problems[:3]):  # Run on first 3 for demo
    print(f"Problem {i+1}/{min(3, len(problems))}...")
    
    try:
        # Format problem for code generation
        prompt = f"{problem['problem']}\n\nExpected behavior: {problem.get('expected_behavior', '')}\n\nPlease provide a Python implementation."
        
        # Run full ATLAS protocol
        result = atlas.run_full_protocol(
            prompt,
            ground_truth=problem.get('canonical_solution'),
            max_student_tokens=STUDENT_RESPONSE_LIMIT
        )
        
        # Store results
        result["problem_id"] = i
        result["problem_text"] = problem['problem']
        result["canonical_solution"] = problem.get('canonical_solution')
        results.append(result)
        
        # Show strategy used
        strategy = result.get('learning', {}).get('strategy', 'Unknown')
        print(f"  Teaching strategy: {strategy}")
        print("  Completed")

    except Exception as e:
        print(f"  Error: {e}")

print(f"\nCompleted {len(results)} problems")
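
The prompt built in the loop above always includes an "Expected behavior" line, even when a problem has no such text. A small helper can skip empty sections. This is a sketch with a hypothetical name (`build_code_prompt`), not part of the ATLAS utilities.

```python
from typing import Any, Dict


def build_code_prompt(problem: Dict[str, Any]) -> str:
    """Assemble the generation prompt, omitting the expected-behavior line when it is empty."""
    parts = [problem["problem"].strip()]
    expected = (problem.get("expected_behavior") or "").strip()
    if expected:
        parts.append(f"Expected behavior: {expected}")
    parts.append("Please provide a Python implementation.")
    return "\n\n".join(parts)


print(build_code_prompt({
    "problem": "Write a function that returns the factorial of n.",
    "expected_behavior": "factorial(5) should return 120",
}))
```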

## Display Code Comparisons

In [None]:
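# Show code generation comparisons.
# Note: this local definition shadows the display_code_comparison imported from utils.visualization.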
def display_code_comparison(result):
    """Display side-by-side comparison of baseline vs guided code."""
    print("\n" + "="*80)
    print(f"Problem: {result['problem_text'][:150]}...")
    print("="*80)
    
    print("\nStudent Code (Alone):")
    print("-" * 40)
    baseline_code = result.get('baseline_response', '')
    print(baseline_code[:500] + ("..." if len(baseline_code) > 500 else ""))
    
    # Use an f-string so a missing (None) strategy does not raise a TypeError
    print(f"\nTeaching Strategy: {result.get('learning', {}).get('strategy', 'Unknown')}")
    
    print("\nStudent Code (With ATLAS):")
    print("-" * 40)
    guided_code = result.get('guided_response', '')
    print(guided_code[:500] + ("..." if len(guided_code) > 500 else ""))
    
    if result.get('canonical_solution'):
        print("\nReference Solution:")
        print("-" * 40)
        print(result['canonical_solution'][:500])

# Display first result in detail
if results:
    display_code_comparison(results[0])

## Code Quality Analysis

In [None]:
# Analyze code quality improvements
def analyze_code_quality(results):
    """Analyze improvements in code quality metrics."""
    metrics = {
        "syntax_valid": {"baseline": 0, "guided": 0},
        "has_function": {"baseline": 0, "guided": 0},
        "has_docstring": {"baseline": 0, "guided": 0},
        "has_comments": {"baseline": 0, "guided": 0},
        "avg_length": {"baseline": [], "guided": []}
    }
    
    for result in results:
        baseline = result.get('baseline_response', '')
        guided = result.get('guided_response', '')
        
        # Rough proxy for syntax validity: a 'def' and a colon are present (the code is not parsed)
        if 'def ' in baseline and ':' in baseline:
            metrics["syntax_valid"]["baseline"] += 1
        if 'def ' in guided and ':' in guided:
            metrics["syntax_valid"]["guided"] += 1
        
        # Check for function definition
        if 'def ' in baseline:
            metrics["has_function"]["baseline"] += 1
        if 'def ' in guided:
            metrics["has_function"]["guided"] += 1
        
        # Check for docstrings
        if '"""' in baseline or "'''" in baseline:
            metrics["has_docstring"]["baseline"] += 1
        if '"""' in guided or "'''" in guided:
            metrics["has_docstring"]["guided"] += 1
        
        # Check for comments (note: '#' also matches Markdown headings in a response)
        if '#' in baseline:
            metrics["has_comments"]["baseline"] += 1
        if '#' in guided:
            metrics["has_comments"]["guided"] += 1
        
        # Track length
        metrics["avg_length"]["baseline"].append(len(baseline))
        metrics["avg_length"]["guided"].append(len(guided))
    
    # Calculate percentages
    n = len(results)
    if n > 0:
        print("\nCode Quality Metrics:")
        print("=" * 50)
        print(f"Valid Syntax:   Baseline: {metrics['syntax_valid']['baseline']/n:.0%} | With ATLAS: {metrics['syntax_valid']['guided']/n:.0%}")
        print(f"Has Function:   Baseline: {metrics['has_function']['baseline']/n:.0%} | With ATLAS: {metrics['has_function']['guided']/n:.0%}")
        print(f"Has Docstring:  Baseline: {metrics['has_docstring']['baseline']/n:.0%} | With ATLAS: {metrics['has_docstring']['guided']/n:.0%}")
        print(f"Has Comments:   Baseline: {metrics['has_comments']['baseline']/n:.0%} | With ATLAS: {metrics['has_comments']['guided']/n:.0%}")
        
        avg_baseline = np.mean(metrics["avg_length"]["baseline"])
        avg_guided = np.mean(metrics["avg_length"]["guided"])
        print(f"\nAvg Length:     Baseline: {avg_baseline:.0f} chars | With ATLAS: {avg_guided:.0f} chars")
        print("=" * 50)
    
    return metrics

if results:
    code_metrics = analyze_code_quality(results)
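
The substring checks above accept any text that contains `def ` and a colon, including code that does not parse. A stricter variant can parse the response with the standard-library `ast` module. The sketch below assumes that model responses may wrap code in Markdown fences, and the helper names (`extract_code`, `code_checks`) are hypothetical.

```python
import ast
import re
from typing import Dict

FENCE = "`" * 3  # Markdown code fence, built this way to keep the example readable
FENCE_RE = re.compile(FENCE + r"(?:python|py)?[ \t]*\n(.*?)" + FENCE, re.DOTALL | re.IGNORECASE)


def extract_code(response: str) -> str:
    """Return the contents of fenced code blocks, or the whole response if there are none."""
    blocks = FENCE_RE.findall(response)
    return "\n".join(blocks) if blocks else response


def code_checks(response: str) -> Dict[str, bool]:
    """Parse the response and report syntax validity, function presence, and docstrings."""
    code = extract_code(response)
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        return {"syntax_valid": False, "has_function": False, "has_docstring": False}
    funcs = [
        node for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]
    return {
        "syntax_valid": bool(code.strip()),  # an empty response is not counted as valid
        "has_function": bool(funcs),
        "has_docstring": any(ast.get_docstring(func) for func in funcs),
    }


good = (
    "Here is the code:\n" + FENCE + "python\n"
    "def double(x):\n"
    '    """Return twice x."""\n'
    "    return 2 * x\n" + FENCE
)
bad = "def broken(:\n    return 1"  # passes the substring heuristic but does not parse

print(code_checks(good))
print(code_checks(bad))
```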

## Interactive Testing

In [None]:
# Test with custom coding problem
def test_custom_code_problem(problem_text: str):
    """Test ATLAS with a custom coding problem."""
    print("\n" + "="*60)
    print("Testing custom coding problem")
    print("="*60)
    print(f"\nProblem: {problem_text}\n")
    
    result = atlas.run_full_protocol(
        problem_text,
        max_student_tokens=STUDENT_RESPONSE_LIMIT
    )
    
    print("\nStudent Code (Alone):")
    print("-" * 40)
    print(result['baseline_response'])
    
    print("\nTeacher Guidance:")
    print("-" * 40)
    print(f"Strategy: {result['learning']['strategy']}")
    print(f"Guidance preview: {result['learning']['response'][:200]}...")
    
    print("\nStudent Code (With ATLAS):")
    print("-" * 40)
    print(result['guided_response'])
    
    return result

# Example usage
custom_problem = """Write a Python function that takes a list of integers and returns 
a new list containing only the unique elements while preserving the original order.
For example: unique_ordered([1, 2, 2, 3, 1, 4]) should return [1, 2, 3, 4]"""

custom_result = test_custom_code_problem(custom_problem)
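
For comparison with the generated answers, one possible reference solution to the custom problem is shown below. It is an illustrative implementation, not output from either model, and it relies on `dict` preserving insertion order.

```python
from typing import List


def unique_ordered(items: List[int]) -> List[int]:
    """Return the unique elements of items in first-seen order."""
    # dict keys are unique and keep insertion order (guaranteed since Python 3.7).
    return list(dict.fromkeys(items))


print(unique_ordered([1, 2, 2, 3, 1, 4]))  # [1, 2, 3, 4]
```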

## Teaching Strategy Analysis

In [None]:
# Analyze teaching strategies used
if results:
    strategies = [r.get('learning', {}).get('strategy', 'Unknown') for r in results]
    strategy_counts = pd.Series(strategies).value_counts()
    
    print("\nTeaching Strategies Used:")
    print("=" * 40)
    for strategy, count in strategy_counts.items():
        print(f"{strategy}: {count} ({count/len(results):.0%})")
    
    # Plot strategy distribution
    if len(strategy_counts) > 0:
        plt.figure(figsize=(10, 6))
        plt.pie(strategy_counts.values, labels=strategy_counts.index, autopct='%1.0f%%')
        plt.title('Distribution of Teaching Strategies')
        plt.show()

## Example: Complex Problem

In [None]:
# Test with a more complex problem
complex_problem = """Write a Python class called 'TaskQueue' that implements a priority queue with the following features:
1. add_task(task_name, priority) - adds a task with given priority (higher number = higher priority)
2. get_next_task() - returns and removes the highest priority task
3. peek() - returns the highest priority task without removing it
4. is_empty() - returns True if queue is empty
5. size() - returns the number of tasks in the queue

The implementation should handle ties in priority by maintaining FIFO order."""

print("Testing with complex problem...\n")
complex_result = test_custom_code_problem(complex_problem)
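
A reference implementation helps when judging the generated `TaskQueue` answers. The sketch below is one possible solution built on the standard-library `heapq` module, not model output. Returning `None` from `get_next_task()` and `peek()` on an empty queue is an assumption, since the problem statement does not specify that case.

```python
import heapq
import itertools
from typing import List, Optional, Tuple


class TaskQueue:
    """Priority queue where a higher number means higher priority and ties are FIFO."""

    def __init__(self) -> None:
        # Entries are (-priority, insertion_index, task_name). Negating the priority
        # turns heapq's min-heap into a max-heap; the index breaks ties in FIFO order.
        self._heap: List[Tuple[int, int, str]] = []
        self._counter = itertools.count()

    def add_task(self, task_name: str, priority: int) -> None:
        heapq.heappush(self._heap, (-priority, next(self._counter), task_name))

    def get_next_task(self) -> Optional[str]:
        # Assumption: an empty queue yields None instead of raising.
        if not self._heap:
            return None
        return heapq.heappop(self._heap)[2]

    def peek(self) -> Optional[str]:
        return self._heap[0][2] if self._heap else None

    def is_empty(self) -> bool:
        return not self._heap

    def size(self) -> int:
        return len(self._heap)


queue = TaskQueue()
queue.add_task("write docs", 1)
queue.add_task("fix bug", 5)
queue.add_task("review PR", 5)
print(queue.get_next_task())  # fix bug (added before "review PR" at the same priority)
```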

## Token Efficiency

In [None]:
# Analyze token usage for code generation
if results:
    total_probe_tokens = sum(r.get('probe', {}).get('tokens_used', 0) for r in results)
    total_learning_tokens = sum(r.get('learning', {}).get('tokens_used', 0) for r in results)
    # Whitespace-separated word counts, used as a rough proxy for response length
    total_baseline_words = sum(len(r.get('baseline_response', '').split()) for r in results)
    total_guided_words = sum(len(r.get('guided_response', '').split()) for r in results)
    
    avg_probe = total_probe_tokens / len(results)
    avg_learning = total_learning_tokens / len(results)
    avg_overhead = avg_probe + avg_learning
    
    print("\nToken Efficiency for Code Generation:")
    print("=" * 50)
    print(f"Average probe tokens:     {avg_probe:.0f}")
    print(f"Average learning tokens:  {avg_learning:.0f}")
    print(f"Total teacher overhead:   {avg_overhead:.0f}")
    print(f"\nBaseline code avg words:  {total_baseline_words/len(results):.0f}")
    print(f"Guided code avg words:    {total_guided_words/len(results):.0f}")
    print("=" * 50)
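
The response lengths above are whitespace word counts, which are not directly comparable to the teacher's token counts. A helper along the following lines could count real tokens when a tokenizer is available. The name `count_tokens` is hypothetical, and it assumes only that the tokenizer exposes an `encode(text)` method returning a list of ids, as Hugging Face tokenizers do. Depending on the tokenizer, that list may include special tokens.

```python
from typing import List, Optional, Protocol


class SupportsEncode(Protocol):
    """Anything with an encode(text) method that returns token ids."""

    def encode(self, text: str) -> List[int]: ...


def count_tokens(text: str, tokenizer: Optional[SupportsEncode] = None) -> int:
    """Count tokens with the given tokenizer, or fall back to whitespace words."""
    if tokenizer is None:
        return len(text.split())
    return len(tokenizer.encode(text))


class CharTokenizer:
    """Toy tokenizer for demonstration only: one token per character."""

    def encode(self, text: str) -> List[int]:
        return [ord(ch) for ch in text]


sample = "def add(a, b): return a + b"
print(count_tokens(sample))                   # whitespace word count
print(count_tokens(sample, CharTokenizer()))  # character count via the toy tokenizer
```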

## Save Results

In [None]:
# Save results for further analysis
if results:
    output_file = "atlas_code_results.json"
    with open(output_file, 'w') as f:
        # Prepare serializable results
        save_results = []
        for r in results:
            save_results.append({
                "problem_id": r.get("problem_id"),
                "problem_text": r.get("problem_text"),
                "baseline_response": r.get("baseline_response"),
                "guided_response": r.get("guided_response"),
                "teaching_strategy": r.get("learning", {}).get("strategy"),
                "canonical_solution": r.get("canonical_solution")
            })
        
        json.dump({
            "results": save_results,
            "config": {
                "student_model": DEFAULT_STUDENT_MODEL,
                "teacher_model": DEFAULT_TEACHER_INSTRUCT,
                "num_problems": len(results)
            }
        }, f, indent=2)
    
    print(f"\nResults saved to {output_file}")
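
The saved file can be reloaded later for further analysis. The sketch below assumes the JSON layout written by the cell above (a top-level `results` list whose entries carry a `teaching_strategy` key); the function name is hypothetical.

```python
import json
from collections import Counter
from typing import Dict


def summarize_saved_results(path: str) -> Dict[str, int]:
    """Count how often each teaching strategy appears in a saved results file."""
    with open(path, encoding="utf-8") as f:
        payload = json.load(f)
    strategies = [r.get("teaching_strategy") or "Unknown" for r in payload["results"]]
    return dict(Counter(strategies))
```

For example, `summarize_saved_results("atlas_code_results.json")` returns a dictionary that maps each strategy name to the number of problems where it was used.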

## Conclusion

This demo showed how ATLAS improves code generation through:

1. **Diagnostic Assessment**: Teacher evaluates student's coding capability
2. **Adaptive Guidance**: Tailored instruction based on skill level
3. **Quality Improvements**: Better structure, documentation, and correctness

### Key Benefits for Code Generation
- More complete and correct implementations
- Better code structure and organization
- Improved documentation (docstrings, comments)
- Handling of edge cases and error conditions

### Next Steps
- Test with domain-specific coding tasks
- Fine-tune teacher models for specific languages/frameworks
- Integrate with code review workflows
- Deploy as coding assistant API

For more information, see the [main repository](https://github.com/Arc-Intelligence/RCL).