# ATLAS Online Optimization Demo

This tutorial walks you through the full ATLAS hybrid workflow: start from the **offline-trained ATLAS-8B-Thinking teacher** and then run the GEPA online optimization loop to specialize it for a tricky math reasoning task. You will:
1. Run the teacher–student protocol with the stock prompts and observe a failure case.
2. Launch GEPA reflective prompt evolution on a focused task dataset.
3. Compare the evolved prompts against the originals and validate the improvement.

The workflow mirrors the process documented in [reference/supercharging-rl-with-online-optimization.md](../reference/supercharging-rl-with-online-optimization.md) and [reference/atlas-sre-diagnosis.md](../reference/atlas-sre-diagnosis.md). The headline metrics, such as the **165% gain achieved in two hours of online optimization**, are reported in the project [README](../README.md), and the full methodology is covered in `docs/ATLAS-Technical-Report.pdf`.


## Prerequisites
- In Google Colab, select *Runtime ▸ Change runtime type* and choose **A100 GPU** before executing anything below.
- Install dependencies with the `pip` cell below (`gepa[full]`, `transformers`, `litellm`, quantization helpers, and tooling for `.env` loading).
- Provide provider credentials:
  - `GEMINI_API_KEY` for the reflection/evaluation model `gemini/gemini-2.5-flash` (set it with `os.environ['GEMINI_API_KEY'] = '...'` or in a `.env` file).
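  - Optional: a Hugging Face access token for authenticated, less rate-limited downloads. `huggingface_hub` reads it from `HF_TOKEN`, so set that variable (plus `HUGGINGFACE_TOKEN` if your own tooling expects that name).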
- This walkthrough assumes familiarity with the two-pass protocol defined in the ATLAS technical report; if you need a refresher, review `docs/ATLAS-Technical-Report.pdf` first.


In [None]:

import os
import sys
from pathlib import Path

PROJECT_ROOT = Path.cwd()
if PROJECT_ROOT.name == 'examples':
    PROJECT_ROOT = PROJECT_ROOT.parent
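os.environ['ATLAS_PROJECT_ROOT'] = str(PROJECT_ROOT)
# IPython expands `$ATLAS_PROJECT_ROOT` in `%%writefile` lines from the notebook
# namespace, not from os.environ, so expose it as a Python variable too.
ATLAS_PROJECT_ROOT = str(PROJECT_ROOT)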
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))
print(f'Using project root: {PROJECT_ROOT}')


In [None]:

!pip install -q -r $ATLAS_PROJECT_ROOT/requirements-optimize.txt transformers accelerate bitsandbytes litellm python-dotenv pandas datasets


In [None]:

import json
import textwrap
import torch
from pathlib import Path

from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer

load_dotenv()

torch.set_grad_enabled(False)

def load_chat_model(model_name: str, *, dtype: str | None = None):
    load_kwargs = dict(
        device_map='auto',
        trust_remote_code=True,
    )
    if dtype == 'float16':
        load_kwargs['torch_dtype'] = torch.float16
    elif dtype == 'bfloat16':
        load_kwargs['torch_dtype'] = torch.bfloat16
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    return model, tokenizer

def make_chat_generator(model, tokenizer, *, max_new_tokens: int, temperature: float):
    def _generate(prompts):
        single = isinstance(prompts, str)
        prompt_list = [prompts] if single else list(prompts)
        outputs = []
        for prompt in prompt_list:
            messages = [{"role": "user", "content": prompt}]
            inputs = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            generated = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            prompt_length = inputs["input_ids"].shape[-1]
            completion_tokens = generated[0][prompt_length:]
            decoded = tokenizer.decode(completion_tokens, skip_special_tokens=True).strip()
            outputs.append(decoded)
        return outputs[0] if single else outputs
    return _generate

teacher_model_name = 'Arc-Intelligence/ATLAS-8B-Thinking'
student_model_name = 'Qwen/Qwen3-4B-Instruct-2507-FP8'

teacher_model, teacher_tokenizer = load_chat_model(teacher_model_name, dtype='float16')
student_model, student_tokenizer = load_chat_model(student_model_name)

teacher_generate = make_chat_generator(teacher_model, teacher_tokenizer, max_new_tokens=512, temperature=0.1)
student_generate = make_chat_generator(student_model, student_tokenizer, max_new_tokens=512, temperature=0.0)

print('Models ready on devices:', teacher_model.device, student_model.device)


### Sanity-check the direct model loads
The snippet below is a minimal quick-start call: it sends the teacher a single chat turn so you can confirm the model loaded correctly and generates on the GPU.


In [None]:
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = teacher_tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors='pt',
)
inputs = {k: v.to(teacher_model.device) for k, v in inputs.items()}
outputs = teacher_model.generate(**inputs, max_new_tokens=40)
completion = outputs[0][inputs['input_ids'].shape[-1]:]
print(teacher_tokenizer.decode(completion, skip_special_tokens=True))


<Note>
To sanity-check the student, you can run the same snippet using `student_model` / `student_tokenizer`.
</Note>


In [None]:

if 'GEMINI_API_KEY' not in os.environ:
    print('⚠️  GEMINI_API_KEY not found. Set it before running the GEPA optimization cell.')
else:
    print('GEMINI_API_KEY detected; GEPA will use gemini/gemini-2.5-flash via LiteLLM.')


In [None]:

from trainers.extraction_utils import ATLASExtractionUtils
from trainers.prompt_adapter import ATLASGEPAAdapter

seed_prompts = {
    'teacher_adaptive_template': textwrap.dedent('''
        You are an expert teacher guiding a math student on a multi-step reasoning problem.
        Question: {question}

        Student approach: {approach}

        Provide a concise teaching plan that highlights the next quantitative steps.
        Wrap the guidance in <teaching> tags and point out specific arithmetic checks.
    ''').strip(),
    'student_diagnostic_template': textwrap.dedent('''
        You are preparing to solve a challenging math word problem.
        Problem: {question}

        Outline your plan step by step. Identify the quantities you will compute,
        any intermediate totals, and where mistakes are likely. Do not solve yet.
    ''').strip(),
    'student_with_teaching_template': textwrap.dedent('''
        Use the teacher's guidance to finish the problem.
        Question: {question}
        Teaching: {teaching}

        Apply each instruction, show the updated calculations, and place the final answer inside <solution> tags.
    ''').strip(),
}

fixed_prompts = {
    'student_baseline_template': textwrap.dedent('''
        Solve the following problem step by step. Show all work and provide the final answer.
        Question: {question}

        Solution:
    ''').strip()
}

generation_config = {
    'max_tokens': 512,
    'diagnostic_max_tokens': 320,
    'temperature': 0.2,
    'timeout': 600,
    'request_timeout': 600,
}

TASK_SAMPLES = [
    {
        'question': 'During a two hour production shift, a snack factory runs three packaging lines. Line A produces 120 bars per minute, line B produces 150 bars per minute, and line C produces 90 bars per minute. The crew takes four scheduled five minute breaks when all lines stop. Each shipping box holds exactly 24 bars. How many fully packed boxes leave the factory during the shift?',
        'ground_truth': '1500',
        'additional_context': {}
    },
    {
        'question': 'A community fitness center sells 36 Basic memberships at $45 per month and 24 Family memberships at $68 per month. To encourage upgrades, the Family plan is discounted by 25 percent for the first three months while Basic members pay full price. What is the total revenue during that first quarter?',
        'ground_truth': '8532',
        'additional_context': {}
    }
]

trace_path = PROJECT_ROOT / 'traces' / 'online_demo' / 'notebook_run.jsonl'
trace_path.parent.mkdir(parents=True, exist_ok=True)

adapter = ATLASGEPAAdapter(
    teacher_model=teacher_generate,
    student_model=student_generate,
    trace_storage_path=str(trace_path),
    all_prompts={**seed_prompts, **fixed_prompts},
    generation_config=generation_config,
)

print('Adapter initialized. Baseline prompts loaded.')


## 1. Establish the Baseline
Run the hybrid protocol on the first task with the stock prompts. The offline-trained teacher is strong, but with the generic seed prompts it often fails to cover the multi-step arithmetic in this factory scenario. This is the “before” state prior to online optimization.
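Before attributing a failure to the prompts, it is worth confirming the labels themselves. The arithmetic behind both `TASK_SAMPLES` ground truths fits in a few lines of plain Python (the function names are illustrative only, not part of ATLAS):

```python
# Worked check of the two ground-truth answers in TASK_SAMPLES.

def factory_boxes() -> int:
    shift_minutes = 2 * 60                            # two-hour shift
    break_minutes = 4 * 5                             # four five-minute breaks, all lines stopped
    running_minutes = shift_minutes - break_minutes   # 100 minutes of production
    bars_per_minute = 120 + 150 + 90                  # lines A + B + C = 360 bars/min
    total_bars = bars_per_minute * running_minutes    # 36,000 bars
    return total_bars // 24                           # only fully packed 24-bar boxes count


def fitness_quarter_revenue() -> int:
    months = 3
    basic = 36 * 45 * months                          # Basic members pay full price
    family_monthly = 24 * 68 * 75 // 100              # Family plan at 25% off: $1,224/month
    return basic + family_monthly * months


print(factory_boxes())            # 1500
print(fitness_quarter_revenue())  # 8532
```

The factory problem also shows where a weak plan goes wrong: ignoring the 20 minutes of breaks gives 360 × 120 / 24 = 1800 boxes instead of 1500.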


In [None]:

from IPython.display import Markdown, display

baseline_example = [TASK_SAMPLES[0]]
baseline_eval = adapter.evaluate(baseline_example, candidate=seed_prompts, capture_traces=True)
baseline_output = baseline_eval.outputs[0]
teacher_teaching = ATLASExtractionUtils.extract_teaching_content(baseline_output['teacher_response'])
baseline_answer = ATLASExtractionUtils.extract_solution(baseline_output['student_with_teaching'])
baseline_correct = ATLASExtractionUtils.check_correctness(baseline_answer, baseline_example[0]['ground_truth'])
baseline_student_only = ATLASExtractionUtils.extract_solution(baseline_output['student_baseline'])

display(Markdown(textwrap.dedent(f'''
### Baseline run
- **Ground truth:** {baseline_example[0]['ground_truth']}
- **Student without teaching:** `{baseline_student_only}`
- **Teacher guidance:**
```
{teacher_teaching}
```
- **Student with teaching:** `{baseline_answer}`
- **Correct after teaching?** {baseline_correct}
''')))

if baseline_correct:
    print('The baseline happened to succeed. If you see this, re-run the cell or adjust the task to reproduce the intended failure before optimization.')


## 2. Configure GEPA for Online Optimization
Create the focused dataset and optimization config. This mirrors the configuration template shown in `reference/atlas-sre-diagnosis.md`: it specifies the teacher/student pair, the LiteLLM-backed evaluator, and `reflection_goal` guidance for each component.
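The `data_source.columns` block in the config maps dataset columns onto the fields the optimizer reads. How `optimize_teaching.py` consumes it is not shown in this notebook, so the loader below is a hypothetical stand-in (`load_task_samples` is not an ATLAS function). It assumes the mapping reads *field: source column*, so `answer: ground_truth` fills `answer` from each row's `ground_truth`, and it fails fast on rows missing a mapped column, which is cheaper than discovering the problem after spending metric calls.

```python
import json
from pathlib import Path

# Hypothetical mirror of the config's `data_source.columns` block:
# keys are the fields the optimizer reads, values are the JSONL column names.
COLUMNS = {"question": "question", "answer": "ground_truth"}


def load_task_samples(path: Path, columns: dict = COLUMNS) -> list:
    """Read a JSONL task file and remap its columns, failing fast on bad rows."""
    samples = []
    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            row = json.loads(line)
            missing = [src for src in columns.values() if src not in row]
            if missing:
                raise KeyError(f"line {line_no}: missing column(s) {missing}")
            sample = {dst: row[src] for dst, src in columns.items()}
            sample["additional_context"] = row.get("additional_context", {})
            samples.append(sample)
    return samples
```

Once the next cell has written the file, `load_task_samples(PROJECT_ROOT / 'examples' / 'task_samples.jsonl')` should return two samples.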


In [None]:

%%writefile $ATLAS_PROJECT_ROOT/examples/task_samples.jsonl
{"question": "During a two hour production shift, a snack factory runs three packaging lines. Line A produces 120 bars per minute, line B produces 150 bars per minute, and line C produces 90 bars per minute. The crew takes four scheduled five minute breaks when all lines stop. Each shipping box holds exactly 24 bars. How many fully packed boxes leave the factory during the shift?", "ground_truth": "1500", "additional_context": {}}
{"question": "A community fitness center sells 36 Basic memberships at $45 per month and 24 Family memberships at $68 per month. To encourage upgrades, the Family plan is discounted by 25 percent for the first three months while Basic members pay full price. What is the total revenue during that first quarter?", "ground_truth": "8532", "additional_context": {}}


In [None]:

%%writefile $ATLAS_PROJECT_ROOT/configs/optimize/demo_config.yaml
max_examples: 2

student_model: Qwen/Qwen3-4B-Instruct-2507-FP8
teacher_model: Arc-Intelligence/ATLAS-8B-Thinking

reflection_lm: gemini/gemini-2.5-flash

trace_storage: traces/online_demo/gepa_traces.jsonl
output: examples/optimized_prompts_demo.json

generation_config:
  max_tokens: 512
  diagnostic_max_tokens: 320
  temperature: 0.4
  timeout: 600
  request_timeout: 600

data_source:
  type: file
  path: examples/task_samples.jsonl
  columns:
    question: question
    answer: ground_truth

seed_prompts:
  teacher_adaptive_template: |
    You are an expert teacher guiding a math student on a multi-step reasoning problem.
    Question: {question}

    Student approach: {approach}

    Provide a concise teaching plan that highlights the next quantitative steps.
    Wrap the guidance in <teaching> tags and point out specific arithmetic checks.
  student_diagnostic_template: |
    You are preparing to solve a challenging math word problem.
    Problem: {question}

    Outline your plan step by step. Identify the quantities you will compute,
    any intermediate totals, and where mistakes are likely. Do not solve yet.
  student_with_teaching_template: |
    Use the teacher's guidance to finish the problem.
    Question: {question}
    Teaching: {teaching}

    Apply each instruction, show the updated calculations, and place the final answer inside <solution> tags.

fixed_prompts:
  student_baseline_template: |
    Solve the following problem step by step. Show all work and provide the final answer.
    Question: {question}

    Solution:

evaluation:
  evaluation_model: gemini/gemini-2.5-flash
  metrics:
    - name: math_correctness
      description: Judge whether the student_with_teaching answer matches the ground truth integer.
      weight: 1.0
      criteria: |
        Score 1.0 if the final answer inside <solution> matches the ground truth.
        Otherwise score 0.0.

optimization_targets:
  teacher_adaptive_template:
    optimize: true
    reflection_goal: >
      Make the teaching highly specific to multi-step arithmetic so the student corrects calculation errors and returns the right integer answer.
  student_diagnostic_template:
    optimize: true
    reflection_goal: >
      Encourage the student to outline every numeric step and surface potential bottlenecks before solving.
  student_with_teaching_template:
    optimize: true
    reflection_goal: >
      Ensure the student applies the teaching instructions, shows updated calculations, and places the final answer inside <solution> tags.

gepa_config:
  reflection_minibatch_size: 2
  candidate_selection_strategy: pareto
  display_progress_bar: false

wandb:
  enabled: false


## 3. Run GEPA Reflective Prompt Evolution
Execute the optimization script with a small metric budget (`--max-metric-calls=10`). LiteLLM routes the reflection and evaluation calls to `gemini/gemini-2.5-flash`, so make sure `GEMINI_API_KEY` is set.
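To build intuition for what `--max-metric-calls` and `candidate_selection_strategy: pareto` control, here is a toy, self-contained sketch. It is **not** GEPA's implementation: the "mutation" is a stand-in for the LLM reflection call, scores come from a made-up keyword rule, and GEPA's own selection rule differs in details (this sketch samples uniformly from the front). It illustrates two ideas: every candidate-on-example evaluation spends one metric call from a fixed budget, and parents are drawn from candidates that are best on at least one example (a Pareto front over per-example scores) rather than only from the single best average.

```python
import random


def pareto_front(scores: dict) -> list:
    """Candidates that are best on at least one example and not dominated overall."""
    n = len(next(iter(scores.values())))
    best = [max(s[i] for s in scores.values()) for i in range(n)]
    # Keep every candidate that ties for the best score on some example.
    winners = [c for c, s in scores.items() if any(s[i] == best[i] for i in range(n))]

    def dominates(a: str, b: str) -> bool:
        # a dominates b if it is at least as good everywhere and strictly better somewhere.
        pairs = list(zip(scores[a], scores[b]))
        return all(x >= y for x, y in pairs) and any(x > y for x, y in pairs)

    return sorted(c for c in winners if not any(dominates(o, c) for o in winners if o != c))


def budgeted_search(seed, examples, evaluate, mutate, max_metric_calls, rng):
    """Toy loop: each (candidate, example) evaluation spends one metric call."""
    if not examples:
        raise ValueError("need at least one example to spend the budget on")
    scores = {}
    calls = 0
    candidate = seed
    # Stop before a full evaluation pass would exceed the budget.
    while calls + len(examples) <= max_metric_calls:
        scores[candidate] = [evaluate(candidate, ex) for ex in examples]
        calls += len(examples)
        parent = rng.choice(pareto_front(scores))  # uniform pick from the front
        candidate = mutate(parent)                 # stand-in for LLM reflection
    return scores, calls


# Toy setup: a "prompt" scores 1.0 on an example if it mentions the needed check.
rng = random.Random(0)
hints = ["Subtract break time.", "Apply the discount per month.", "Count only full boxes."]
examples = ["break", "discount"]
scores, calls = budgeted_search(
    seed="Solve step by step.",
    examples=examples,
    evaluate=lambda cand, ex: 1.0 if ex in cand else 0.0,
    mutate=lambda parent: f"{parent} {rng.choice(hints)}",
    max_metric_calls=10,
    rng=rng,
)
print(calls)  # 10: five candidates x two examples
print(pareto_front(scores))
```

With only two examples and ten calls, the search sees five candidates at most, which is why this demo budget is a smoke test rather than a benchmark.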


In [None]:

# Make sure the directories the optimization run writes into exist.
(PROJECT_ROOT / 'examples').mkdir(exist_ok=True)
(PROJECT_ROOT / 'results').mkdir(exist_ok=True)

!cd $ATLAS_PROJECT_ROOT && python optimize_teaching.py --config configs/optimize/demo_config.yaml --max-metric-calls=10


## 4. Inspect the Evolved Prompts
Load the optimized candidate and compare it with the seed prompts. GEPA’s reflective loop should add domain-specific checks while preserving the structure of the protocol.
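One structural property is easy to check mechanically: each evolved template should still expose the placeholders the adapter fills in (`{question}`, `{approach}`, `{teaching}`). The helpers below are a hypothetical convenience, not part of ATLAS, and they assume the templates use `str.format`-style fields as the seed prompts do.

```python
from string import Formatter


def placeholders(template: str) -> set:
    """Return the named str.format fields used in a template."""
    return {field for _, field, _, _ in Formatter().parse(template) if field}


def missing_placeholders(seed: dict, evolved: dict) -> dict:
    """Map each component to the placeholders its evolved version dropped."""
    report = {}
    for key, original in seed.items():
        dropped = placeholders(original) - placeholders(evolved.get(key, ""))
        if dropped:
            report[key] = dropped
    return report
```

In the notebook, `missing_placeholders(seed_prompts, best_candidate)` returning `{}` means every component kept its fields. A `ValueError` from `Formatter().parse` (for example on a stray `}`) is also a useful signal, since the same template would likely fail in `str.format`.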


In [None]:

from IPython.display import HTML
import html

optimized_path = PROJECT_ROOT / 'examples' / 'optimized_prompts_demo.json'
with open(optimized_path, 'r') as f:
    optimized_data = json.load(f)

best_candidate = optimized_data.get('best_candidate', {})
initial_score = optimized_data.get('initial_score')
best_score = optimized_data.get('best_score')

print(f'Initial score: {initial_score}')
print(f'Best score: {best_score}')

def render_cell(text: str) -> str:
    return f"<pre>{html.escape(text)}</pre>"

rows = []
for key, label in [
    ('teacher_adaptive_template', 'Teacher Adaptive Template'),
    ('student_diagnostic_template', 'Student Diagnostic Template'),
    ('student_with_teaching_template', 'Student With Teaching Template'),
]:
    original = seed_prompts.get(key, '')
    optimized = best_candidate.get(key, '')
    rows.append(f'<tr><th>{label}</th><td>{render_cell(original)}</td><td>{render_cell(optimized)}</td></tr>')

table_html = (
    "<table style='width:100%; table-layout:fixed;'>\n"
    "<thead><tr><th>Component</th><th>Original Prompt</th><th>Optimized Prompt</th></tr></thead>\n"
    f"<tbody>{''.join(rows)}</tbody></table>"
)

display(HTML(table_html))


## 5. Validate the Optimized Prompts
Re-run the protocol on the original problem using the evolved prompts. This is the “after” state; keep in mind that a single example is a spot check, not a reproduction of the online gains reported in the README and technical report.
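The correctness flag printed by the next cell comes from `ATLASExtractionUtils.check_correctness`, whose implementation is not reproduced in this notebook. As a rough illustration of what such a check has to handle, here is a hypothetical sketch (`extract_solution_text` and `answers_match` are made-up names) that takes the last `<solution>` block and compares its last number with the ground truth after stripping thousands separators, so `$8,532` and `8532` agree.

```python
import re
from typing import Optional

_SOLUTION_RE = re.compile(r"<solution>(.*?)</solution>", re.DOTALL | re.IGNORECASE)
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


def extract_solution_text(response: str) -> Optional[str]:
    """Return the contents of the last <solution>...</solution> block, if any."""
    matches = _SOLUTION_RE.findall(response)
    return matches[-1].strip() if matches else None


def answers_match(predicted: Optional[str], ground_truth: str) -> bool:
    """Compare the last number in the prediction with a numeric ground truth."""
    if not predicted:
        return False
    try:
        target = float(ground_truth.replace(",", ""))
    except ValueError:
        # Non-numeric ground truth: fall back to an exact string comparison.
        return predicted.strip() == ground_truth.strip()
    # Drop thousands separators so "1,500" parses as a single number.
    numbers = _NUMBER_RE.findall(predicted.replace(",", ""))
    return bool(numbers) and float(numbers[-1]) == target


print(answers_match(extract_solution_text("... <solution>$8,532</solution>"), "8532"))  # True
```

Taking the *last* tag and the *last* number is a design choice for models that draft and then revise; a stricter checker might require exactly one `<solution>` block.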


In [None]:

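# Fall back to the seed version of any component missing from the optimized file.
optimized_eval = adapter.evaluate(baseline_example, candidate={**seed_prompts, **best_candidate}, capture_traces=True)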
optimized_output = optimized_eval.outputs[0]
optimized_answer = ATLASExtractionUtils.extract_solution(optimized_output['student_with_teaching'])
optimized_correct = ATLASExtractionUtils.check_correctness(
    optimized_answer, baseline_example[0]['ground_truth']
)

display(Markdown(textwrap.dedent(f'''
### Optimized run
- **Student with teaching:** `{optimized_answer}`
- **Correct after teaching?** {optimized_correct}
''')))

if not optimized_correct:
    print('If the optimized prompt still misses, increase the metric budget or include additional task examples before re-running GEPA.')


## Where to Go Next
- Dive deeper into the hybrid architecture and reflective mutation process in `docs/ATLAS-Technical-Report.pdf` and [reference/supercharging-rl-with-online-optimization.md](../reference/supercharging-rl-with-online-optimization.md).
- For production hardening (logging, distributed traces, replay datasets), follow the guidance in `reference/atlas-sre-diagnosis.md`.
- To reproduce the benchmarked gains cited in the README (15.7% average accuracy improvement, 165% online improvement), expand the dataset and metric budget in this notebook and compare before/after scores on held-out validation sets.
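Scores can improve in two different senses, and they are easy to mix up when comparing before/after runs: a *relative* gain divides the change by the starting score (a 165% relative gain means the final score is 2.65× the initial one), while a *percentage-point* change is the raw difference. The helpers below are a hypothetical convenience, not part of ATLAS, and assume scores on a 0-1 scale; check what `initial_score` and `best_score` in the optimizer output actually contain.

```python
def relative_gain(initial: float, final: float) -> float:
    """Relative improvement in percent; undefined when the starting score is 0."""
    if initial == 0:
        raise ValueError("relative gain is undefined when the initial score is 0")
    return (final - initial) / initial * 100


def point_change(initial: float, final: float) -> float:
    """Absolute change in percentage points for scores on a 0-1 scale."""
    return (final - initial) * 100


print(round(relative_gain(0.2, 0.53), 1))  # 165.0
print(round(point_change(0.2, 0.53), 1))   # 33.0
```

On this two-example demo, a baseline score of `0.0` is plausible (the notebook is set up to start from a failure), in which case only the point change is meaningful.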
