# Lesson 4: Scaling TTRL on GSM8K

## 🏋️ Benchmarking the Experiment
We've seen TTRL work on single examples, but is it actually reliable?
To find out, we need a **benchmark**.

**The Dataset: GSM8K (Grade School Math 8K)**
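GSM8K is one of the most widely used benchmarks for multi-step arithmetic reasoning in LLMs. It contains about 8,500 grade-school math word problems written by humans (7,473 for training, 1,319 for testing). Every reference solution ends with the final numeric answer after a `####` marker.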

**Our Experiment**:
We will run a head-to-head comparison:
1.  **Baseline**: Standard Mistral (one answer, temperature 0).
2.  **TTRL Agent**: Mistral + Best-of-3 sampling + self-verification.

We expect the TTRL Agent to score higher, since it gets three attempts and can reject answers it judges wrong; the price is extra compute.
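Stripped of the LLM calls, the TTRL agent's selection rule is ordinary best-of-N: sample N candidates, score each with a verifier, keep the highest-scoring one. The sketch below shows only that rule. `best_of_n` is a hypothetical helper name, and the generator and verifier are toy stand-ins for the Mistral calls used later in this notebook.

```python
from typing import Callable


def best_of_n(
    generate: Callable[[], str],
    verify: Callable[[str], float],
    n: int = 3,
) -> str:
    """Sample n candidates and return the first one with the highest verifier score."""
    best_score = float("-inf")
    best_answer = ""
    for _ in range(n):
        candidate = generate()      # in the notebook: a temperature-0.8 Mistral call
        score = verify(candidate)   # in the notebook: a temperature-0 self-check
        if score > best_score:      # strict ">" keeps the earliest candidate among ties
            best_score, best_answer = score, candidate
    return best_answer


# Toy stand-ins (hypothetical): three fixed proposals and a perfect verifier for "42".
candidates = iter(["41", "42", "40"])
print(best_of_n(lambda: next(candidates), lambda a: 1.0 if a == "42" else 0.0))  # 42
```

The rule is only as good as its verifier. With a perfect verifier, best-of-N is right whenever any one of the N samples is right; a noisy self-check can also pick a wrong answer or reject a right one.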

In [None]:
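import re

from datasets import load_dataset
from rich.console import Console
from rich.table import Table

try:
    import ollama
except ImportError as err:
    # Fail fast: every later cell calls ollama.chat()
    raise ImportError("Missing the Ollama client: run `pip install ollama`") from err

console = Console()
# Requires a running Ollama server with this model pulled (`ollama pull mistral:7b`)
MODEL_NAME = "mistral:7b"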

### 📦 Step 1: Loading the Data
We use the Hugging Face `datasets` library to pull GSM8K.
Its `test` split (1,319 problems) is the standard evaluation set, so that is what we use.
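Each GSM8K record has a `question` and an `answer` field. The reference answer is a worked solution with calculator annotations of the form `<<expression=result>>`, followed by `####` and the final number. The sketch below parses a hand-written record in that format (illustrative, not an actual dataset entry); `split_gsm8k_answer` is a hypothetical helper, not part of the `datasets` library, and it assumes the `####` marker is present.

```python
import re

# Hand-written record mimicking GSM8K's format (not copied from the dataset).
record = {
    "question": "A baker makes 3 trays of 12 cookies and sells 20. How many are left?",
    "answer": "She makes 3 * 12 = <<3*12=36>>36 cookies.\n"
              "She has 36 - 20 = <<36-20=16>>16 left.\n"
              "#### 16",
}


def split_gsm8k_answer(answer: str) -> tuple[str, str]:
    """Split a GSM8K-style solution into (readable rationale, final answer)."""
    rationale, _, final = answer.rpartition("####")
    # Remove the <<...>> calculator annotations so the rationale reads naturally.
    rationale = re.sub(r"<<[^>]*>>", "", rationale).strip()
    # Drop thousands separators so the answer can be compared numerically.
    return rationale, final.strip().replace(",", "")


rationale, final = split_gsm8k_answer(record["answer"])
print(final)      # 16
print(rationale)
```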

In [None]:
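from typing import Optional

console.print("[yellow]Loading GSM8K...[/yellow]")
# GSM8K is hosted on the Hugging Face Hub as "openai/gsm8k"
dataset = load_dataset("openai/gsm8k", "main", split="test")

# For this tutorial, we select 5 examples for speed.
# A real evaluation needs 100+ (the full test split has 1,319).
examples = dataset.select(range(5))

def extract_answer(text: str) -> Optional[str]:
    """Extracts the final answer after '####' in a GSM8K reference solution."""
    # Allow a minus sign, thousands separators and decimals, then drop the commas
    match = re.search(r"####\s*(-?[\d,]*\.?\d+)", text)
    return match.group(1).replace(",", "") if match else None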
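Why 100+ problems? Accuracy on n problems is a binomial estimate, and with n = 5 its uncertainty is huge. The sketch below computes a 95% Wilson score interval (a standard binomial confidence interval) for a couple of hypothetical scores; `wilson_interval` is an illustrative helper, not part of any library used in this notebook.

```python
import math


def wilson_interval(correct: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """Wilson score interval for accuracy correct/n (z=1.96 gives ~95% coverage)."""
    if n <= 0:
        raise ValueError("n must be positive")
    p = correct / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = (z / denom) * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2))
    return center - half, center + half


# Hypothetical scores: the same 80% accuracy measured on 5 vs. 100 problems.
for correct, n in [(4, 5), (80, 100)]:
    lo, hi = wilson_interval(correct, n)
    print(f"{correct}/{n}: [{lo:.2f}, {hi:.2f}]")
```

Four correct out of five is consistent with a true accuracy anywhere from roughly 38% to 96%, so a one-problem gap between Baseline and TTRL on a 5-example run says very little.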

### ⚙️ Step 2: The Evaluation Engine
We define a function `solve_problem` that can toggle between modes.

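*   **`method='greedy'`**: The usual default. One deterministic answer: fast and cheap, but a single reasoning slip goes uncaught.
*   **`method='ttrl'`**: Our agentic loop. Several sampled attempts, each checked by the model itself: slower and more expensive, but with a chance to catch mistakes.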
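To put "slower and more expensive" in numbers, count model calls. Greedy makes one call per problem; the TTRL loop in the next cell makes one generation call and one verification call per attempt, so 2N calls for Best-of-N. The helper below is a hypothetical illustration of that arithmetic over the full 1,319-problem test split.

```python
def calls_per_problem(method: str, n: int = 3) -> int:
    """Number of model calls one problem costs under each method."""
    if method == "greedy":
        return 1        # one deterministic generation
    if method == "ttrl":
        return 2 * n    # n generations + n verification calls
    raise ValueError(f"unknown method: {method!r}")


num_problems = 1319  # size of the GSM8K test split
for method in ("greedy", "ttrl"):
    print(method, calls_per_problem(method) * num_problems)  # greedy 1319 / ttrl 7914
```

Calls are only a proxy for cost: verification replies are short, but each verification prompt re-reads the question plus the full candidate solution.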

In [None]:
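def solve_problem(prompt: str, method: str = "greedy", n: int = 3) -> str:
    """Solves using either Greedy (baseline) or TTRL (Best-of-N with self-verification)"""
    messages = [{"role": "user", "content": prompt}]

    if method == "greedy":
        # Baseline: a single zero-shot answer, decoded deterministically (temperature 0)
        response = ollama.chat(model=MODEL_NAME, messages=messages, options={"temperature": 0})
        return response["message"]["content"]

    if method == "ttrl":
        # Best-of-N (N=3 by default) with self-verification
        best_score = -1.0
        best_ans = ""

        for _ in range(n):
            # 1. Generate a proposal (high temperature so the N attempts differ)
            cand = ollama.chat(model=MODEL_NAME, messages=messages, options={"temperature": 0.8})
            content = cand["message"]["content"]

            # 2. Verify the proposal (deterministic self-check)
            check_prompt = (
                f"Question: {prompt}\nAnswer: {content}\n"
                "Is this answer correct? Reply with a single digit: 1 for Yes, 0 for No."
            )
            check = ollama.chat(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": check_prompt}],
                options={"temperature": 0},
            )
            # Require the reply to *start* with "1": the naive `"1" in reply` also
            # fires on digits inside an explanation, e.g. "No, it should be 10".
            verdict = check["message"]["content"].strip()
            score = 1.0 if re.match(r"1\b", verdict) else 0.0

            # Keep the first proposal with the highest verifier score
            if score > best_score:
                best_score = score
                best_ans = content

        return best_ans

    raise ValueError(f"Unknown method: {method!r} (expected 'greedy' or 'ttrl')")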

### 🔬 Step 3: Running the Experiment
We iterate through the problems, compare each method's final number with the reference answer, and record wins and losses.

In [None]:
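from typing import Optional

def extract_prediction(text: str) -> Optional[str]:
    """Treats the last number in a model response as its final answer (a common heuristic)."""
    numbers = re.findall(r"-?\d[\d,]*(?:\.\d+)?", text)
    return numbers[-1].replace(",", "") if numbers else None

def is_correct(pred: Optional[str], truth: Optional[str]) -> bool:
    """Numeric comparison, so '18', '18.0' and '18.00' all match."""
    if pred is None or truth is None:
        return False
    return abs(float(pred) - float(truth)) < 1e-6

table = Table(title="GSM8K Results")
table.add_column("Problem", style="dim", width=30)
table.add_column("Baseline", justify="center")
table.add_column("TTRL", justify="center")
table.add_column("Truth", justify="center")

for ex in examples:
    q = ex["question"]
    truth = extract_answer(ex["answer"])

    # Run Baseline. A substring test like `truth in base_raw` is unreliable:
    # "8" is "in" a response whose final answer is 18.
    base_raw = solve_problem(q, method="greedy")
    base_correct = is_correct(extract_prediction(base_raw), truth)

    # Run TTRL
    ttrl_raw = solve_problem(q, method="ttrl")
    ttrl_correct = is_correct(extract_prediction(ttrl_raw), truth)

    table.add_row(
        q[:30] + "...",
        "✅" if base_correct else "❌",
        "✅" if ttrl_correct else "❌",
        str(truth),
    )

console.print(table)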
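The table shows per-problem results, but a head-to-head comparison is easiest to read as a paired tally: on how many problems did only one method succeed? The sketch below runs on hypothetical outcomes; to apply it to a real run, collect `(base_correct, ttrl_correct)` for each problem inside the loop above. `paired_tally` is an illustrative helper name.

```python
from collections import Counter

LABELS = {
    (True, True): "both correct",
    (False, True): "TTRL only",
    (True, False): "baseline only",
    (False, False): "both wrong",
}


def paired_tally(results: list[tuple[bool, bool]]) -> Counter:
    """Count (baseline_correct, ttrl_correct) outcome pairs across problems."""
    return Counter(LABELS[pair] for pair in results)


# Hypothetical outcomes for illustration only, not results from this notebook.
demo = [(True, True), (False, True), (False, True), (True, False), (False, False)]
tally = paired_tally(demo)
n = len(demo)
base_acc = (tally["both correct"] + tally["baseline only"]) / n
ttrl_acc = (tally["both correct"] + tally["TTRL only"]) / n
net_wins = tally["TTRL only"] - tally["baseline only"]
print(f"Baseline {base_acc:.0%} | TTRL {ttrl_acc:.0%} | net TTRL wins: {net_wins}")
```

Problems that both methods solve, or both miss, say nothing about which method is better; the evidence sits in the discordant pairs ("TTRL only" versus "baseline only").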