# SOFAI (Slow and Fast AI) Sampling Strategy in Mellea

This notebook demonstrates the **SOFAI sampling strategy** in Mellea, specifically applied to the **Code Debugging** domain. SOFAI implements a dual-solver approach inspired by cognitive psychology's System 1 (fast) and System 2 (slow) thinking.

## What is SOFAI?

SOFAI (Slow and Fast AI) is a sampling strategy that uses two LLM solvers:

1. **S1 Solver (Fast Model)**: A smaller, faster model that iteratively attempts to solve the problem with feedback-based repair
2. **S2 Solver (Slow Model)**: A larger, more capable model that is called once when S1 fails to produce a valid solution

This approach balances **cost efficiency** (using cheaper models when possible) with **quality** (escalating to more powerful models when needed).

## SOFAI Flow Diagram

```
┌─────────────────────────────────────────────────────────┐
│                    PHASE 1: S1 Solver                   │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  ┌─────────────┐    ┌──────────────┐    ┌───────────┐  │
│  │  Generate   │───►│   Validate   │───►│  Success? │  │
│  │  Solution   │    │  vs. Reqs    │    │           │  │
│  └─────────────┘    └──────────────┘    └─────┬─────┘  │
│         ▲                                     │        │
│         │              Yes ◄──────────────────┤        │
│         │                                     │ No     │
│  ┌──────┴──────┐                              ▼        │
│  │   Repair    │◄───────── Budget left? ◄────────      │
│  │  Feedback   │               │                       │
│  └─────────────┘               │ No                    │
│                                ▼                       │
└────────────────────────────────┼───────────────────────┘
                                 │
┌────────────────────────────────┼───────────────────────┐
│                    PHASE 2: S2 Solver                   │
├────────────────────────────────┼───────────────────────┤
│                                ▼                       │
│  ┌─────────────────────────────────────────────────┐   │
│  │  Prepare context based on s2_solver_mode:       │   │
│  │  • fresh_start: Original prompt only            │   │
│  │  • continue_chat: Original + S1 history         │   │
│  │  • best_attempt: Best S1 result + feedback      │   │
│  └─────────────────────────────────────────────────┘   │
│                                │                       │
│                                ▼                       │
│  ┌─────────────┐    ┌──────────────┐    ┌───────────┐  │
│  │  Generate   │───►│   Validate   │───►│  Return   │  │
│  │  Solution   │    │  vs. Reqs    │    │  Result   │  │
│  └─────────────┘    └──────────────┘    └───────────┘  │
│                                                         │
└─────────────────────────────────────────────────────────┘
```
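
The flow in the diagram can be approximated in plain Python. The sketch below is a simplified illustration, not Mellea's implementation: `sofai_loop`, `Solver`, `Validator`, and the toy solvers are hypothetical names introduced here, and the S2 call uses a `fresh_start`-style prompt.

```python
from typing import Callable, Optional

# Hypothetical signatures for illustration; these are not Mellea types.
Solver = Callable[[str], str]               # prompt -> candidate solution
Validator = Callable[[str], Optional[str]]  # candidate -> error message, or None if valid


def sofai_loop(prompt: str, s1: Solver, s2: Solver, validate: Validator,
               loop_budget: int = 3) -> tuple[str, str, int]:
    """Return (solution, solver_label, total_attempts)."""
    current_prompt = prompt
    for attempt in range(1, loop_budget + 1):
        candidate = s1(current_prompt)
        error = validate(candidate)
        if error is None:
            return candidate, "S1", attempt
        # Repair step: feed the validator's message back to S1.
        current_prompt = f"{prompt}\n\nPrevious attempt failed: {error}\nPlease fix it."
    # S1 budget exhausted: call S2 exactly once with the original prompt.
    return s2(prompt), "S2", loop_budget + 1


# Toy stand-ins: S1 is always wrong, S2 is right.
def toy_s1(prompt: str) -> str:
    return "41"


def toy_s2(prompt: str) -> str:
    return "42"


def toy_validate(candidate: str) -> Optional[str]:
    return None if candidate == "42" else f"expected 42, got {candidate}"


solution, solver, attempts = sofai_loop("What is 6 * 7?", toy_s1, toy_s2, toy_validate)
print(solution, solver, attempts)  # 42 S2 4
```

With `loop_budget=3`, S1 is tried three times and S2 produces the fourth and final attempt, which matches how the attempt indices are interpreted later in this notebook.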

## Setup

First, install and start Ollama, then install Mellea from source by cloning the repository, and finally pull the two models used in the examples.

In [None]:
# Install Ollama (for Colab)
!curl -fsSL https://ollama.com/install.sh | sh > /dev/null
!nohup ollama serve >/dev/null 2>&1 &

# Wait for Ollama to start
import time
time.sleep(3)

In [None]:
# Install Mellea from source (clone and install)
!git clone https://github.com/generative-computing/mellea.git /tmp/mellea
!cd /tmp/mellea && pip install -e ".[all]" -q

In [None]:
# Pull the required models
!ollama pull phi:2.7b
!ollama pull llama3.2:3b

In [None]:
import logging

import mellea
from mellea.backends.ollama import OllamaModelBackend
from mellea.core import FancyLogger
from mellea.stdlib.context import ChatContext
from mellea.stdlib.requirements import ValidationResult, req
from mellea.stdlib.sampling import SOFAISamplingStrategy

# Set logging level to see SOFAI's progress
FancyLogger.get_logger().setLevel(logging.INFO)

## SOFAISamplingStrategy Parameters

Here are the key parameters for configuring SOFAI:

| Parameter | Type | Description |
|-----------|------|-------------|
| `s1_solver_backend` | `Backend` | **Required**. Backend for the fast S1 solver (e.g., smaller model) |
| `s2_solver_backend` | `Backend` | **Required**. Backend for the slow S2 solver (e.g., larger model) |
| `s2_solver_mode` | `str` | How S2 receives context. Options: `"fresh_start"`, `"continue_chat"`, `"best_attempt"` |
| `loop_budget` | `int` | Maximum attempts for S1 before escalating to S2 (default: 3) |
| `judge_backend` | `Backend` | Optional third backend for LLM-as-Judge validation |
| `feedback_strategy` | `str` | Feedback detail level: `"simple"`, `"first_error"`, `"all_errors"` |

### S2 Solver Modes Explained

| Mode | Description | Best For |
|------|-------------|----------|
| `fresh_start` | S2 gets only the original prompt (clean slate) | Independent problem solving |
| `continue_chat` | S2 gets original prompt + entire S1 conversation history | Learning from S1's attempts |
| `best_attempt` | S2 gets original prompt + best S1 attempt + feedback summary | Focused improvement on best solution |
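
To make the difference between the three modes concrete, here is a minimal sketch of what S2 might receive in each case. The function `build_s2_messages` and the `(role, text)` message shape are assumptions made for illustration; Mellea's internal context handling may differ.

```python
def build_s2_messages(mode, original_prompt, s1_history,
                      best_attempt=None, feedback=None):
    """Return the list of (role, text) messages handed to S2 (hypothetical shape)."""
    if mode == "fresh_start":
        # Clean slate: S2 never sees S1's attempts.
        return [("user", original_prompt)]
    if mode == "continue_chat":
        # Original prompt followed by the whole S1 conversation.
        return [("user", original_prompt), *s1_history]
    if mode == "best_attempt":
        # Original prompt plus only the best S1 attempt and its feedback.
        summary = (
            f"{original_prompt}\n\nBest attempt so far:\n{best_attempt}\n\n"
            f"Validator feedback:\n{feedback}"
        )
        return [("user", summary)]
    raise ValueError(f"Unknown s2_solver_mode: {mode!r}")


history = [("assistant", "attempt 1"), ("user", "Test 1 FAILED")]
for mode in ("fresh_start", "continue_chat", "best_attempt"):
    messages = build_s2_messages(mode, "Fix the bug.", history,
                                 best_attempt="attempt 1", feedback="Test 1 FAILED")
    print(mode, len(messages))
```

The trade-off is context size versus information: `continue_chat` grows with every S1 attempt, while `best_attempt` stays bounded.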

### Feedback Strategies (for LLM-as-Judge)

| Strategy | Description |
|----------|-------------|
| `simple` | Binary yes/no validation, no detailed feedback |
| `first_error` | Reports only the first mistake found with detailed feedback |
| `all_errors` | Comprehensive feedback about all mistakes found |
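
The three levels differ in how much information flows back to the S1 solver. The sketch below only illustrates that difference on a list of already-known errors; `format_feedback` is a hypothetical helper, and this is not how Mellea's judge produces its feedback.

```python
def format_feedback(errors: list[str], strategy: str) -> str:
    """Render validation errors at the detail level named by `strategy`."""
    if not errors:
        return "Valid."
    if strategy == "simple":
        # Binary verdict only.
        return "Invalid."
    if strategy == "first_error":
        return f"Invalid. First problem: {errors[0]}"
    if strategy == "all_errors":
        return "Invalid. Problems:\n" + "\n".join(f"- {e}" for e in errors)
    raise ValueError(f"Unknown feedback_strategy: {strategy!r}")


errors = ["15 maps to 'Fizz' instead of 'FizzBuzz'", "30 maps to 'Fizz' instead of 'FizzBuzz'"]
for strategy in ("simple", "first_error", "all_errors"):
    print(f"[{strategy}]\n{format_feedback(errors, strategy)}\n")
```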

---

# Code Debugging Domain

Now let's apply SOFAI to **code debugging**: fixing buggy Python code. This is an excellent domain for SOFAI because:

1. We can **programmatically validate** code by executing it and checking outputs
2. Smaller models can often fix simple bugs, but complex logic errors may need larger models
3. Feedback can be very specific (e.g., "Expected 5, got 4 for input [1,2,3,4,5]")

## Example 1: Two Sum with Custom Validator

A classic coding problem with a subtle bug. We'll use a **custom validation function** that executes the code and checks test cases.

In [None]:
# Buggy code to fix
buggy_two_sum = '''
def two_sum(nums, target):
    """Return indices of two numbers that add up to target."""
    seen = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in seen:
            return [i, seen[complement]]  # BUG: indices are swapped!
        seen[num] = i
    return []
'''

# Test cases for validation
two_sum_tests = [
    {"input": {"nums": [2, 7, 11, 15], "target": 9}, "expected": [0, 1]},
    {"input": {"nums": [3, 2, 4], "target": 6}, "expected": [1, 2]},
    {"input": {"nums": [3, 3], "target": 6}, "expected": [0, 1]},
]

print("Buggy code:")
print(buggy_two_sum)

In [None]:
def extract_python_code(response: str) -> str:
    """Extract Python code from LLM response (handles markdown code blocks)."""
    response = response.strip()
    
    # Try to extract from markdown code blocks
    if "```python" in response:
        start = response.find("```python") + len("```python")
        end = response.find("```", start)
        if end > start:
            return response[start:end].strip()
    elif "```" in response:
        start = response.find("```") + 3
        end = response.find("```", start)
        if end > start:
            return response[start:end].strip()
    
    # If no code blocks, try to find def statement
    if "def " in response:
        lines = response.split("\n")
        code_lines = []
        in_function = False
        for line in lines:
            if line.strip().startswith("def "):
                in_function = True
            if in_function:
                code_lines.append(line)
        return "\n".join(code_lines)
    
    return response

In [None]:
def create_code_validator(test_cases: list, function_name: str):
    """Create a validator that tests the fixed code against test cases."""
    
    def validate_code(ctx) -> ValidationResult:
        output = ctx.last_output()
        if output is None:
            return ValidationResult(False, reason="No output found.")
        
        # Extract code from response
        fixed_code = extract_python_code(str(output.value))
        if not fixed_code or "def " not in fixed_code:
            return ValidationResult(
                False, 
                reason=f"Could not find valid Python function. Please provide the complete fixed function starting with 'def {function_name}'."
            )
        
        # Execute and test
        errors = []
        try:
            namespace = {}
            exec(fixed_code, namespace)  # noqa: S102
            
            if function_name not in namespace:
                return ValidationResult(
                    False, 
                    reason=f"Function '{function_name}' not found in code. Please provide the complete function."
                )
            
            func = namespace[function_name]
            
            for i, test in enumerate(test_cases):
                try:
                    result = func(**test["input"])
                    expected = test["expected"]
                    
                    if result != expected:
                        errors.append(
                            f"Test {i+1} FAILED: {function_name}({test['input']}) "
                            f"returned {result}, expected {expected}"
                        )
                except Exception as e:
                    errors.append(f"Test {i+1} raised exception: {type(e).__name__}: {e}")
                    
        except SyntaxError as e:
            return ValidationResult(False, reason=f"Syntax error in code: {e}")
        except Exception as e:
            return ValidationResult(False, reason=f"Error executing code: {type(e).__name__}: {e}")
        
        if errors:
            return ValidationResult(False, reason=" | ".join(errors))
        
        return ValidationResult(True, reason="All test cases passed!")
    
    return validate_code
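
To see the kind of feedback this validator feeds back to the solver, the core test loop can be run on its own against the buggy function. The sketch below is a stripped-down variant with no Mellea dependency; `run_tests` is a hypothetical helper, and the code and test cases are copied from the cells above.

```python
def run_tests(code: str, function_name: str, test_cases: list[dict]) -> list[str]:
    """Execute `code` and return one message per failing test case."""
    namespace: dict = {}
    exec(code, namespace)  # noqa: S102 - only run code you trust, or sandbox it
    func = namespace[function_name]
    failures = []
    for i, test in enumerate(test_cases, start=1):
        result = func(**test["input"])
        if result != test["expected"]:
            failures.append(
                f"Test {i} FAILED: returned {result}, expected {test['expected']}"
            )
    return failures


buggy = '''
def two_sum(nums, target):
    seen = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in seen:
            return [i, seen[complement]]  # indices are swapped
        seen[num] = i
    return []
'''

tests = [
    {"input": {"nums": [2, 7, 11, 15], "target": 9}, "expected": [0, 1]},
    {"input": {"nums": [3, 2, 4], "target": 6}, "expected": [1, 2]},
    {"input": {"nums": [3, 3], "target": 6}, "expected": [0, 1]},
]

failures = run_tests(buggy, "two_sum", tests)
for message in failures:
    print(message)
# Test 1 FAILED: returned [1, 0], expected [0, 1]
# Test 2 FAILED: returned [2, 1], expected [1, 2]
# Test 3 FAILED: returned [1, 0], expected [0, 1]
```

Every message shows the reversed pair, which is exactly the hint a repair prompt needs.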

In [None]:
# Set up backends
s1_solver = OllamaModelBackend(model_id="phi:2.7b")
s2_solver = OllamaModelBackend(model_id="llama3.2:3b")

# Create SOFAI strategy with custom validator (no judge_backend)
sofai_strategy = SOFAISamplingStrategy(
    s1_solver_backend=s1_solver,
    s2_solver_backend=s2_solver,
    s2_solver_mode="best_attempt",
    loop_budget=3,
)

In [None]:
# Create the debugging prompt
debug_prompt_two_sum = f"""
The following Python function has a bug. Fix it.

```python
{buggy_two_sum}
```

Test Cases:
- two_sum([2, 7, 11, 15], 9) should return [0, 1] (indices of 2 and 7)
- two_sum([3, 2, 4], 6) should return [1, 2] (indices of 2 and 4)
- two_sum([3, 3], 6) should return [0, 1]

Provide the complete fixed function.
"""

# Create requirement with custom validator
two_sum_requirement = [
    req(
        description="The fixed code must pass all test cases with correct output.",
        validation_fn=create_code_validator(two_sum_tests, "two_sum")
    )
]

In [None]:
# Run SOFAI debugging
print("=" * 60)
print("SOFAI Code Debugging: Two Sum (Custom Validator)")
print("=" * 60)

m = mellea.MelleaSession(backend=s1_solver, ctx=ChatContext())

result = m.instruct(
    debug_prompt_two_sum,
    requirements=two_sum_requirement,
    strategy=sofai_strategy,
    return_sampling_results=True,
    model_options={"temperature": 0.2},
)

print(f"\n{'='*60}")
print(f"SUCCESS: {result.success}")
print(f"Total attempts: {len(result.sample_generations)}")
print(f"{'='*60}")

In [None]:
# Display detailed results for Example 1
for i, (gen, val_list) in enumerate(zip(result.sample_generations, result.sample_validations)):
    solver_name = "S1 Solver" if i < sofai_strategy.loop_budget else "S2 Solver"
    status = "✓ PASS" if all(v[1].as_bool() for v in val_list) else "✗ FAIL"
    
    print(f"\n--- Attempt {i + 1} ({solver_name}) [{status}] ---")
    # Convert once so slicing works even when the value is not a plain string
    text = str(gen.value)
    print(f"Output:\n{text[:500]}..." if len(text) > 500 else f"Output:\n{text}")
    
    for req_obj, val_result in val_list:
        print(f"Validation: {val_result.reason}")

---

## Example 2: FizzBuzz with LLM-as-Judge (No Custom Validator)

This example demonstrates using **LLM-as-Judge** for validation instead of a custom validation function. This is useful when:
- You don't have programmatic test cases
- The correctness criterion is subjective or complex to encode
- You want quick prototyping without writing validators

We configure SOFAI with a `judge_backend` and set the `feedback_strategy` to control how detailed the feedback is.

In [None]:
# Buggy FizzBuzz implementation
buggy_fizzbuzz = '''
def fizzbuzz(n):
    """Return a list of FizzBuzz results from 1 to n.
    
    Rules:
    - Return 'Fizz' for multiples of 3
    - Return 'Buzz' for multiples of 5
    - Return 'FizzBuzz' for multiples of both 3 and 5
    - Return the number as string otherwise
    """
    result = []
    for i in range(1, n + 1):
        if i % 3 == 0:
            result.append("Fizz")
        elif i % 5 == 0:
            result.append("Buzz")
        # BUG: Missing FizzBuzz case for multiples of both 3 and 5!
        else:
            result.append(str(i))
    return result
'''

print("Buggy FizzBuzz code:")
print(buggy_fizzbuzz)

# Show what the buggy code produces
exec(buggy_fizzbuzz)  # noqa: S102
print("\nBuggy output for fizzbuzz(15):")
print(fizzbuzz(15))
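
The bug arises because `i % 3 == 0` matches 15 before any combined check can run. For comparison, a conventional fix tests the combined case first. This is a reference sketch under the name `fizzbuzz_reference` (chosen here to avoid shadowing `fizzbuzz`); the solvers may produce a differently shaped but equally valid fix.

```python
def fizzbuzz_reference(n):
    """Reference FizzBuzz: the combined case is checked before the single cases."""
    result = []
    for i in range(1, n + 1):
        if i % 15 == 0:  # divisible by both 3 and 5
            result.append("FizzBuzz")
        elif i % 3 == 0:
            result.append("Fizz")
        elif i % 5 == 0:
            result.append("Buzz")
        else:
            result.append(str(i))
    return result


print(fizzbuzz_reference(15)[-1])  # FizzBuzz
```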

In [None]:
# Set up SOFAI with LLM-as-Judge
# The judge gets its own backend instance; here it reuses the same model as S2
judge_backend = OllamaModelBackend(model_id="llama3.2:3b")

# Create SOFAI strategy WITH judge_backend and feedback_strategy
sofai_strategy_llm_judge = SOFAISamplingStrategy(
    s1_solver_backend=s1_solver,
    s2_solver_backend=s2_solver,
    s2_solver_mode="fresh_start",
    loop_budget=3,
    # LLM-as-Judge configuration:
    judge_backend=judge_backend,  # Separate backend used for validation
    feedback_strategy="all_errors",  # Options: "simple", "first_error", "all_errors"
)

print("SOFAI configured with LLM-as-Judge:")
print("  • S1 Solver: phi:2.7b")
print("  • S2 Solver: llama3.2:3b")
print("  • Judge: llama3.2:3b")
print("  • Feedback Strategy: all_errors")

In [None]:
# Create the debugging prompt for FizzBuzz
debug_prompt_fizzbuzz = f"""
The following Python function has a bug in its logic. Fix it.

```python
{buggy_fizzbuzz}
```

The function should:
1. Return 'FizzBuzz' for numbers divisible by BOTH 3 and 5 (like 15, 30)
2. Return 'Fizz' for numbers divisible by 3 only
3. Return 'Buzz' for numbers divisible by 5 only
4. Return the number as a string otherwise

For example, fizzbuzz(15) should produce:
['1', '2', 'Fizz', '4', 'Buzz', 'Fizz', '7', '8', 'Fizz', 'Buzz', '11', 'Fizz', '13', '14', 'FizzBuzz']

Note that position 15 should be 'FizzBuzz', not 'Fizz'.

Provide the complete fixed function.
"""

# Create requirement WITHOUT validation_fn - will use LLM-as-Judge
fizzbuzz_requirement = [
    req(
        description="The fixed code must correctly handle the FizzBuzz logic: return 'FizzBuzz' for multiples of both 3 and 5, 'Fizz' for multiples of 3 only, 'Buzz' for multiples of 5 only, and the number as string otherwise."
    )
]

print("Requirement (no validation_fn, will use LLM-as-Judge):")
print(f"  {fizzbuzz_requirement[0].description}")

In [None]:
# Run SOFAI debugging with LLM-as-Judge
print("=" * 60)
print("SOFAI Code Debugging: FizzBuzz (LLM-as-Judge)")
print("=" * 60)

m2 = mellea.MelleaSession(backend=s1_solver, ctx=ChatContext())

result2 = m2.instruct(
    debug_prompt_fizzbuzz,
    requirements=fizzbuzz_requirement,
    strategy=sofai_strategy_llm_judge,
    return_sampling_results=True,
    model_options={"temperature": 0.2},
)

print(f"\n{'='*60}")
print(f"SUCCESS: {result2.success}")
print(f"Total attempts: {len(result2.sample_generations)}")
print(f"{'='*60}")

In [None]:
# Display detailed results showing LLM-as-Judge feedback
print("\n" + "=" * 60)
print("LLM-as-Judge Validation Details")
print("=" * 60)

for i, (gen, val_list) in enumerate(zip(result2.sample_generations, result2.sample_validations)):
    solver_name = "S1 Solver" if i < sofai_strategy_llm_judge.loop_budget else "S2 Solver"
    status = "✓ PASS" if all(v[1].as_bool() for v in val_list) else "✗ FAIL"
    
    print(f"\n{'─'*60}")
    print(f"Attempt {i + 1} ({solver_name}) [{status}]")
    print(f"{'─'*60}")
    
    # Show the generated code
    code = extract_python_code(str(gen.value))
    print(f"\nGenerated Code:")
    print(code[:400] + "..." if len(code) > 400 else code)
    
    # Show LLM-as-Judge feedback
    print(f"\nLLM-as-Judge Feedback:")
    for req_obj, val_result in val_list:
        print(f"  Valid: {val_result.as_bool()}")
        print(f"  Reason: {val_result.reason}")

In [None]:
# Verify the final result by running the fixed code
if result2.success:
    final_code = extract_python_code(str(result2.result.value))
    print("\n" + "=" * 60)
    print("FINAL FIXED CODE")
    print("=" * 60)
    print(final_code)
    
    # Actually run the fixed code to verify
    print("\n" + "=" * 60)
    print("VERIFICATION: Running fizzbuzz(15)")
    print("=" * 60)
    try:
        namespace = {}
        exec(final_code, namespace)  # noqa: S102
        if "fizzbuzz" in namespace:
            output = namespace["fizzbuzz"](15)
            print(f"Output: {output}")
            expected = ['1', '2', 'Fizz', '4', 'Buzz', 'Fizz', '7', '8', 'Fizz', 'Buzz', '11', 'Fizz', '13', '14', 'FizzBuzz']
            print(f"Expected: {expected}")
            print(f"Match: {'✓ YES' if output == expected else '✗ NO'}")
    except Exception as e:
        print(f"Error running code: {e}")

---

## Comparison: Custom Validator vs LLM-as-Judge

| Aspect | Custom Validator | LLM-as-Judge |
|--------|-----------------|---------------|
| **Setup** | Requires writing test code | Just write a description |
| **Accuracy** | Deterministic; as reliable as the test cases | Can be inconsistent |
| **Feedback** | Specific error messages | Natural language explanation |
| **Speed** | Fast | Slower (LLM call) |
| **Cost** | No LLM calls | Extra LLM calls |
| **Best For** | Algorithmic problems | Subjective quality, code style |

---

## Understanding SOFAI Behavior

Let's analyze both runs to see how SOFAI behaved.

In [None]:
def analyze_sofai_run(result, strategy, name):
    """Analyze and display SOFAI execution details."""
    print(f"\n{'='*60}")
    print(f"Analysis: {name}")
    print(f"{'='*60}")
    
    total_attempts = len(result.sample_generations)
    s1_attempts = min(strategy.loop_budget, total_attempts)
    s2_used = total_attempts > s1_attempts
    
    print(f"\nConfiguration:")
    print(f"  • S1 Loop Budget: {strategy.loop_budget}")
    print(f"  • S2 Mode: {strategy.s2_solver_mode}")
    print(f"  • Judge Backend: {'Yes' if strategy.judge_backend else 'No (custom validator)'}")
    if strategy.judge_backend:
        print(f"  • Feedback Strategy: {strategy.feedback_strategy}")
    
    print(f"\nExecution:")
    print(f"  • S1 Attempts: {s1_attempts}")
    print(f"  • S2 Used: {'Yes' if s2_used else 'No'}")
    print(f"  • Final Result: {'✓ SUCCESS' if result.success else '✗ FAILED'}")
    
    print(f"\nAttempt Details:")
    for i, val_list in enumerate(result.sample_validations):
        solver = "S1" if i < s1_attempts else "S2"
        passed = sum(1 for _, v in val_list if v.as_bool())
        total = len(val_list)
        status = "✓" if passed == total else "✗"
        print(f"  {i+1}. [{solver}] {status} Passed {passed}/{total} requirements")

# Analyze both runs
analyze_sofai_run(result, sofai_strategy, "Two Sum (Custom Validator)")
analyze_sofai_run(result2, sofai_strategy_llm_judge, "FizzBuzz (LLM-as-Judge)")

---

## Key Takeaways

### When to Use SOFAI

SOFAI is ideal when:
1. **Cost matters**: You want to use cheaper/faster models when possible
2. **Quality matters**: You need a fallback to more capable models for hard cases
3. **Feedback is informative**: Validators can provide specific error messages
4. **Iterative improvement is possible**: The problem allows incremental fixes

### SOFAI vs RejectionSampling

| Aspect | RejectionSampling | SOFAI |
|--------|------------------|-------|
| Models | Single model | Two models (fast + slow) |
| Feedback | Simple retry | Targeted repair messages |
| Cost | Every attempt uses the same model | Often lower (cheap model first, expensive model only on escalation) |
| Complexity | Simple | More sophisticated |
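
The cost row can be made concrete with a back-of-the-envelope model. The sketch below assumes each attempt succeeds independently with a fixed probability and ignores validation cost; both function names and all numbers are hypothetical, chosen only to illustrate the arithmetic.

```python
def expected_sofai_cost(p_s1: float, loop_budget: int,
                        cost_s1: float, cost_s2: float) -> float:
    """Expected cost when S1 gets `loop_budget` tries, then S2 is called once."""
    # Attempt i (0-based) happens only if all earlier attempts failed.
    expected_s1_calls = sum((1 - p_s1) ** i for i in range(loop_budget))
    p_escalate = (1 - p_s1) ** loop_budget
    return cost_s1 * expected_s1_calls + cost_s2 * p_escalate


def expected_single_model_cost(p: float, loop_budget: int, cost: float) -> float:
    """Expected cost of retrying one model up to `loop_budget` times."""
    return cost * sum((1 - p) ** i for i in range(loop_budget))


# Hypothetical numbers: the small model costs 1 unit and succeeds half the time;
# the large model costs 10 units and succeeds 80% of the time.
sofai = expected_sofai_cost(p_s1=0.5, loop_budget=3, cost_s1=1.0, cost_s2=10.0)
single = expected_single_model_cost(p=0.8, loop_budget=3, cost=10.0)
print(f"SOFAI: {sofai:.2f}, large model only: {single:.2f}")
# SOFAI: 3.00, large model only: 12.40
```

Under these assumptions SOFAI is cheaper, but the advantage shrinks as the small model's success rate drops, since more runs end with an S2 call on top of the wasted S1 attempts.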

### Best Practices

1. **Use custom validators when possible**: Programmatic tests are faster and more reliable
2. **Use LLM-as-Judge for subjective criteria**: Code style, documentation quality, etc.
3. **Choose `feedback_strategy` wisely**:
   - `simple`: Fast, but no guidance for fixing
   - `first_error`: Good balance of speed and guidance
   - `all_errors`: Most informative, but slower
4. **Choose S2 mode based on problem type**:
   - `best_attempt`: When S1's attempts contain useful partial solutions
   - `continue_chat`: When the conversation context helps
   - `fresh_start`: When S1's attempts might confuse S2

---

## Further Reading

- [Mellea Documentation](https://mellea.ai/)
- [Mellea Tutorial](https://github.com/generative-computing/mellea/blob/main/docs/tutorial.md)
- [SOFAI Source Code](https://github.com/generative-computing/mellea/blob/main/mellea/stdlib/sampling/sofai.py)