# Model Comparison: Base vs GRPO-Trained

This notebook compares the outputs of:
- **Base Model**: HuggingFaceTB/SmolLM-135M-Instruct (before GRPO training)
- **Trained Model**: After GRPO training on GSM8K math problems

We'll test on several GSM8K questions and see how the trained model improves at:
1. Following the format instruction (using `<answer>...</answer>` tags)
2. Solving math problems correctly

In [23]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from datasets import load_dataset
import re
from IPython.display import display, HTML
import random

## 1. Configuration

In [24]:
# Model configuration
# BASE_MODEL_ID = "HuggingFaceTB/SmolLM-135M-Instruct"
BASE_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
TRAINED_MODEL_PATH = "../grpo-math"  # Path to your trained LoRA adapter

# Generation settings
MAX_NEW_TOKENS = 512
TEMPERATURE = 0.7
TOP_P = 0.9

# Number of examples to test
NUM_EXAMPLES = 5

## 2. Load Models

In [25]:
# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

# Load base model
print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto"
)

# Load trained model (base model + LoRA adapter)
print("Loading trained model (base + LoRA adapter)...")
trained_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto"
)
trained_model = PeftModel.from_pretrained(trained_model, TRAINED_MODEL_PATH)

print("‚úì Models loaded successfully!")

Loading tokenizer...
Loading base model...
Loading trained model (base + LoRA adapter)...
‚úì Models loaded successfully!


## 3. Load Test Dataset

In [26]:
# Load GSM8K test set
print("Loading GSM8K test dataset...")
test_dataset = load_dataset('openai/gsm8k', 'main', split='test')

# Select random examples
random.seed(42)
test_indices = random.sample(range(len(test_dataset)), NUM_EXAMPLES)
test_examples = [test_dataset[i] for i in test_indices]

print(f"‚úì Selected {NUM_EXAMPLES} random test examples")

Loading GSM8K test dataset...
‚úì Selected 5 random test examples


## 4. Helper Functions

In [27]:
def format_prompt(question):
    """Format the question into a chat prompt."""
    prompt = (
        "<|im_start|>system\n"
        "You are a helpful logic assistant. You must output your final answer "
        "wrapped in <answer> tags. Example: <answer>42</answer>.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return prompt


def generate_response(model, prompt, max_new_tokens=MAX_NEW_TOKENS):
    """Generate a response from the model."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode and extract only the assistant's response
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    assistant_response = full_response.split("<|im_start|>assistant\n")[-1]
    assistant_response = assistant_response.split("<|im_end|>")[0].strip()
    
    return assistant_response


def extract_answer(text):
    """Extract the answer from <answer> tags if present, with fallback to other patterns."""
    # First try: <answer> tags (preferred format)
    pattern = r"<answer>(.*?)</answer>"
    match = re.search(pattern, text, flags=re.DOTALL)
    if match:
        return match.group(1).strip()
    
    # Fallback 1: Look for patterns like "**273 yards**" or "**273**" at the end
    fallback1 = re.search(r'\*\*(\d+(?:\.\d+)?(?:\s*\w+)?)\*\*', text)
    if fallback1:
        return fallback1.group(1).strip()
    
    # Fallback 2: Look for "answer is X" patterns
    fallback2 = re.search(r'(?:answer is|total is|result is|equals?)\s*:?\s*\*?\*?(\d+(?:\.\d+)?)', text, re.IGNORECASE)
    if fallback2:
        return fallback2.group(1).strip()
    
    # Fallback 3: Last number in the text (risky but better than nothing)
    numbers = re.findall(r'\b(\d+(?:\.\d+)?)\b', text)
    if numbers:
        return numbers[-1]
    
    return None

def extract_thinking(text):
    """Extract the thinking/reasoning before the answer tag."""
    # Get everything before <answer> tag
    answer_pattern = r"<answer>"
    parts = re.split(answer_pattern, text, maxsplit=1)
    
    if len(parts) > 0:
        thinking = parts[0].strip()
        # Truncate if too long (show first 300 chars)
        if len(thinking) > 300:
            return thinking[:300] + "..."
        return thinking if thinking else "No reasoning shown"
    return "No reasoning shown"

def get_expected_answer(answer_text):
    """Extract the numerical answer from GSM8K format (after ####)."""
    return answer_text.split("####")[-1].strip()


def check_correctness(predicted, expected):
    """Check if the predicted answer matches the expected answer."""
    if predicted is None:
        return False
    
    try:
        # Extract numbers from both (handle cases with units like "273 yards")
        pred_num = re.search(r'(\d+(?:\.\d+)?)', str(predicted))
        exp_num = re.search(r'(\d+(?:\.\d+)?)', str(expected))
        
        if pred_num and exp_num:
            pred_val = float(pred_num.group(1))
            exp_val = float(exp_num.group(1))
            return abs(pred_val - exp_val) < 1e-5
        return False
    except:
        return False


def display_comparison(example_num, question, expected_answer, base_output, trained_output):
    """Display a nicely formatted comparison."""
    base_extracted = extract_answer(base_output)
    trained_extracted = extract_answer(trained_output)

    base_thinking = extract_thinking(base_output)
    trained_thinking = extract_thinking(trained_output)
    
    base_correct = check_correctness(base_extracted, expected_answer)
    trained_correct = check_correctness(trained_extracted, expected_answer)
    
    # Check if proper format was used
    base_has_format = bool(re.search(r"<answer>(.*?)</answer>", base_output, flags=re.DOTALL))
    trained_has_format = bool(re.search(r"<answer>(.*?)</answer>", trained_output, flags=re.DOTALL))
    
    html = f"""
    <div style="border: 2px solid #333; padding: 20px; margin: 20px 0; border-radius: 10px; background-color: #f9f9f9;">
        <h2 style="color: #2c3e50;">Example {example_num}</h2>
        
        <div style="margin: 15px 0;">
            <h3 style="color: #34495e;">üìù Question:</h3>
            <p style="background-color: white; padding: 10px; border-radius: 5px; border-left: 4px solid #3498db;">
                {question}
            </p>
        </div>
        
        <div style="margin: 15px 0;">
            <h3 style="color: #27ae60;">‚úì Expected Answer:</h3>
            <p style="background-color: #d5f4e6; padding: 10px; border-radius: 5px; font-weight: bold;">
                {expected_answer}
            </p>
        </div>
        
        <div style="margin: 15px 0;">
            <h3 style="color: #e74c3c;">ü§ñ Base Model Output:</h3>
            <div style="background-color: #fff8dc; padding: 10px; border-radius: 5px; margin-bottom: 10px; border-left: 4px solid #ffa500;">
                <strong>üí≠ Thinking:</strong>
                <p style="margin: 5px 0; font-style: italic; white-space: pre-wrap;">{base_thinking}</p>
            </div>
            <p style="background-color: white; color: #000000; padding: 10px; border-radius: 5px; border-left: 4px solid #e74c3c; white-space: pre-wrap;">
                <strong>Full Response:</strong><br>{base_output}
            </p>
            <p style="margin-top: 5px;">
                <strong>Format:</strong> <span style="color: {'green' if base_has_format else 'orange'};">{'‚úì Correct' if base_has_format else '‚ö† Missing <answer> tags (fallback used)'}</span> | 
                <strong>Answer:</strong> <span style="color: {'green' if base_correct else 'red'};">{'‚úì Correct' if base_correct else '‚úó Incorrect'}</span>
                {f' (extracted: {base_extracted})' if base_extracted else ''}
            </p>
        </div>
        
        <div style="margin: 15px 0;">
            <h3 style="color: #9b59b6;">üéì Trained Model Output:</h3>
            <div style="background-color: #e6e6fa; padding: 10px; border-radius: 5px; margin-bottom: 10px; border-left: 4px solid #9b59b6;">
                <strong>üí≠ Thinking:</strong>
                <p style="margin: 5px 0; font-style: italic; white-space: pre-wrap;">{trained_thinking}</p>
            </div>
            <p style="background-color: white; color: #000000; padding: 10px; border-radius: 5px; border-left: 4px solid #9b59b6; white-space: pre-wrap;">
                <strong>Full Response:</strong><br>{trained_output}
            </p>
            <p style="margin-top: 5px;">
                <strong>Format:</strong> <span style="color: {'green' if trained_has_format else 'orange'};">{'‚úì Correct' if trained_has_format else '‚ö† Missing <answer> tags (fallback used)'}</span> | 
                <strong>Answer:</strong> <span style="color: {'green' if trained_correct else 'red'};">{'‚úì Correct' if trained_correct else '‚úó Incorrect'}</span>
                {f' (extracted: {trained_extracted})' if trained_extracted else ''}
            </p>
        </div>
        
        <div style="margin-top: 15px; padding: 10px; background-color: {'#d5f4e6' if (trained_has_format and not base_has_format) or (trained_correct and not base_correct) else '#fff3cd'}; border-radius: 5px;">
            <strong>Improvement:</strong> 
            {"‚úì Trained model shows improvement!" if (trained_has_format and not base_has_format) or (trained_correct and not base_correct) else "No significant improvement" if not (base_has_format or base_correct) else "Both models performed similarly"}
        </div>
    </div>
    """
    
    display(HTML(html))

## 5. Generate and Compare Outputs

In [28]:
print("Generating comparisons...\n")

for i, example in enumerate(test_examples, 1):
    question = example['question']
    expected_answer = get_expected_answer(example['answer'])
    
    # Format prompt
    prompt = format_prompt(question)
    
    # Generate from both models
    print(f"Generating example {i}/{NUM_EXAMPLES}...")
    base_output = generate_response(base_model, prompt)
    trained_output = generate_response(trained_model, prompt)
    
    # Display comparison
    display_comparison(i, question, expected_answer, base_output, trained_output)

print("\n‚úì All comparisons complete!")

Generating comparisons...

Generating example 1/5...


Generating example 2/5...


Generating example 3/5...


Generating example 4/5...


Generating example 5/5...



‚úì All comparisons complete!


## 6. Summary Statistics

In [29]:
# Calculate statistics across all examples
base_format_count = 0
trained_format_count = 0
base_correct_count = 0
trained_correct_count = 0

print("Calculating statistics...\n")

for example in test_examples:
    question = example['question']
    expected_answer = get_expected_answer(example['answer'])
    prompt = format_prompt(question)
    
    base_output = generate_response(base_model, prompt)
    trained_output = generate_response(trained_model, prompt)
    
    base_extracted = extract_answer(base_output)
    trained_extracted = extract_answer(trained_output)
    
    if base_extracted is not None:
        base_format_count += 1
    if trained_extracted is not None:
        trained_format_count += 1
        
    if check_correctness(base_extracted, expected_answer):
        base_correct_count += 1
    if check_correctness(trained_extracted, expected_answer):
        trained_correct_count += 1

# Display summary
summary_html = f"""
<div style="border: 3px solid #2c3e50; padding: 20px; margin: 20px 0; border-radius: 10px; background-color: #ecf0f1;">
    <h2 style="color: #2c3e50; text-align: center;">üìä Summary Statistics (n={NUM_EXAMPLES})</h2>
    
    <table style="width: 100%; margin: 20px 0; border-collapse: collapse;">
        <tr style="background-color: #34495e; color: white;">
            <th style="padding: 12px; text-align: left;">Metric</th>
            <th style="padding: 12px; text-align: center;">Base Model</th>
            <th style="padding: 12px; text-align: center;">Trained Model</th>
            <th style="padding: 12px; text-align: center;">Improvement</th>
        </tr>
        <tr style="background-color: white;">
            <td style="padding: 12px;"><strong>Format Compliance</strong></td>
            <td style="padding: 12px; text-align: center;">{base_format_count}/{NUM_EXAMPLES} ({base_format_count/NUM_EXAMPLES*100:.1f}%)</td>
            <td style="padding: 12px; text-align: center;">{trained_format_count}/{NUM_EXAMPLES} ({trained_format_count/NUM_EXAMPLES*100:.1f}%)</td>
            <td style="padding: 12px; text-align: center; color: {'green' if trained_format_count > base_format_count else 'red' if trained_format_count < base_format_count else 'gray'};">+{trained_format_count - base_format_count}</td>
        </tr>
        <tr style="background-color: #ecf0f1;">
            <td style="padding: 12px;"><strong>Correct Answers</strong></td>
            <td style="padding: 12px; text-align: center;">{base_correct_count}/{NUM_EXAMPLES} ({base_correct_count/NUM_EXAMPLES*100:.1f}%)</td>
            <td style="padding: 12px; text-align: center;">{trained_correct_count}/{NUM_EXAMPLES} ({trained_correct_count/NUM_EXAMPLES*100:.1f}%)</td>
            <td style="padding: 12px; text-align: center; color: {'green' if trained_correct_count > base_correct_count else 'red' if trained_correct_count < base_correct_count else 'gray'};">+{trained_correct_count - base_correct_count}</td>
        </tr>
    </table>
</div>
"""

display(HTML(summary_html))
print("‚úì Summary complete!")

Calculating statistics...



Metric,Base Model,Trained Model,Improvement
Format Compliance,5/5 (100.0%),5/5 (100.0%),0
Correct Answers,0/5 (0.0%),1/5 (20.0%),1


‚úì Summary complete!
