# Grading Methodology Comparison Analysis

This notebook demonstrates how to compare different grading methodologies (reward functions for GRPO training) and analyze how their scores relate to model behavior. It scores synthetic completions, so no training run is required.

## Overview

We'll explore:
1. **Different grading methods** - Format, accuracy, and rubric-based rewards
2. **Statistical comparison** - How methods differ in score distributions
3. **Correlation analysis** - Which methods agree/disagree
4. **Behavior effects** - How grading influences model outputs
5. **Visualization** - Plots to understand the differences

In [None]:
# Setup
import sys
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / "src"))

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from grading_registry import create_standard_comparator, get_all_grading_methods
from utils import load_gsm8k_dataset, load_openrubrics_dataset

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

print("Setup complete!")

## 1. Available Grading Methods

Let's see what grading methods are available:

In [None]:
methods = get_all_grading_methods()

print("Available Grading Methods:")
print("=" * 80)

for name, metadata in methods.items():
    print(f"\n{name}")
    print(f"  Description: {metadata['description']}")
    print(f"  Score range: {metadata['score_range']}")
    print(f"  Requires ground truth: {metadata['requires_ground_truth']}")
    print(f"  Requires rubric: {metadata['requires_rubric']}")

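The registry's implementations aren't shown in this notebook. As a rough mental model, here is a minimal sketch of what a format reward and an accuracy reward *might* compute for the `<reasoning>...</reasoning><answer>...</answer>` layout used by the synthetic completions in Section 2. The names `format_reward_sketch`, `accuracy_reward_sketch`, and `extract_answer`, the exact regexes, and the binary 0/1 scoring are all assumptions, not the registry's actual code; check the printed `score_range` values for the real ranges.

```python
import re
from typing import List, Optional

# Hypothetical patterns for the <reasoning>...</reasoning><answer>...</answer> layout.
FORMAT_PATTERN = re.compile(r"^<reasoning>.*?</reasoning>\s*<answer>.*?</answer>$", re.DOTALL)
ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def format_reward_sketch(completions: List[str]) -> List[float]:
    """1.0 if a completion is exactly a reasoning block followed by an answer block, else 0.0."""
    return [1.0 if FORMAT_PATTERN.match(c.strip()) else 0.0 for c in completions]


def extract_answer(completion: str) -> Optional[str]:
    """Return the text inside the first <answer> tag, or None if there is no tag."""
    match = ANSWER_PATTERN.search(completion)
    return match.group(1).strip() if match else None


def accuracy_reward_sketch(completions: List[str], answers: List[str]) -> List[float]:
    """1.0 if the extracted answer string equals the reference answer, else 0.0."""
    scores = []
    for completion, answer in zip(completions, answers):
        predicted = extract_answer(completion)
        scores.append(1.0 if predicted is not None and predicted == str(answer).strip() else 0.0)
    return scores
```

A real grader would likely also normalize numbers (for example, treating "1,000" and "1000" as equal) before comparing, and might give partial credit instead of 0/1.
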
## 2. Load and Prepare Dataset

We'll use a 200-example subset of the GSM8K math dataset (training split) for this analysis:

In [None]:
# Load dataset
try:
    dataset = load_gsm8k_dataset(split="train", max_examples=200, seed=42)
    print(f"Loaded {len(dataset)} examples from GSM8K")
    
    # Show a sample
    print("\nSample question:")
    print(dataset[0]['question'])
    print(f"\nAnswer: {dataset[0]['answer']}")
    
except FileNotFoundError as e:
    print(f"Error: {e}")
    print("\nPlease download the GSM8K dataset from:")
    print("https://www.kaggle.com/datasets/thedevastator/grade-school-math-8k-q-a")
    dataset = None

In [None]:
# Generate synthetic model completions with varying quality
import random
random.seed(42)

if dataset:
    for item in dataset:
        # Simulate different quality levels
        quality = random.choice(['good', 'good', 'partial', 'poor'])  # 50% good, 25% partial, 25% poor
        
        if quality == 'good':
            # Proper format + correct answer
            item['completion'] = (
                f"<reasoning>Let me solve this step by step. "
                f"After careful calculation, the answer is {item['answer']}.</reasoning>"
                f"<answer>{item['answer']}</answer>"
            )
        elif quality == 'partial':
            # Proper format + wrong answer (nonzero offset, so it never matches the truth)
            offset = random.choice([d for d in range(-10, 11) if d != 0])
            wrong_answer = str(int(item['answer']) + offset)
            item['completion'] = (
                f"<reasoning>Working through this problem...</reasoning>"
                f"<answer>{wrong_answer}</answer>"
            )
        else:
            # Poor format + random guess (occasionally matches the answer by chance)
            item['completion'] = f"I think the answer is {random.randint(0, 100)}"
    
    print("Generated synthetic completions")
    print("\nExample completion:")
    print(dataset[0]['completion'])

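Before running the comparison, it helps to know what to expect. The arithmetic below works out the expected means and correlation for the 50/25/25 quality mix, assuming both rewards are binary (format passes for "good" and "partial", accuracy passes only for "good") and ignoring the small chance that a "partial" or "poor" completion happens to contain the correct number. With only 200 samples, the observed values in Sections 4 and 5 should land near these numbers, not exactly on them.

```python
import math

# Mixture produced by random.choice(['good', 'good', 'partial', 'poor'])
weights = {"good": 0.5, "partial": 0.25, "poor": 0.25}

# Assumption: both rewards are 0/1 and chance matches are ignored.
format_pass = {"good": 1.0, "partial": 1.0, "poor": 0.0}
accuracy_pass = {"good": 1.0, "partial": 0.0, "poor": 0.0}

mean_f = sum(weights[q] * format_pass[q] for q in weights)
mean_a = sum(weights[q] * accuracy_pass[q] for q in weights)
mean_fa = sum(weights[q] * format_pass[q] * accuracy_pass[q] for q in weights)

# For 0/1 variables Var = p(1 - p), and Pearson correlation is the phi coefficient.
cov = mean_fa - mean_f * mean_a
phi = cov / math.sqrt(mean_f * (1 - mean_f) * mean_a * (1 - mean_a))

print(f"expected format mean:   {mean_f:.3f}")  # 0.750
print(f"expected accuracy mean: {mean_a:.3f}")  # 0.500
print(f"expected correlation:   {phi:.3f}")     # 0.577
```

The correlation works out to exactly 1/sqrt(3): the two rewards agree on "good" and "poor" completions and disagree only on "partial" ones.
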
## 3. Run Comparison

Now let's score the completions with the format and accuracy rewards and compare them:

In [None]:
if dataset:
    # Prepare data
    prompts = [item['question'] for item in dataset]
    completions = [item['completion'] for item in dataset]
    answers = [item['answer'] for item in dataset]
    
    # Create comparator
    comparator = create_standard_comparator()
    
    # Run comparison
    results = comparator.compare(
        prompts=prompts,
        completions=completions,
        method_names=['format_reward', 'accuracy_reward'],
        answers=answers
    )
    
    print("Comparison complete!")
    print(f"Compared {len(results.methods)} methods on {len(dataset)} examples")

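Conceptually, a comparison like `comparator.compare(...)` runs every reward function over the same batch and keeps one score array per method. The sketch below illustrates that shape. `compare_sketch` and the keyword-argument reward signature are hypothetical (similar in spirit to how TRL's `GRPOTrainer` passes dataset columns to reward functions as keyword arguments), not this project's `compare` API.

```python
from typing import Callable, Dict, List, Sequence

import numpy as np

# Hypothetical reward signature: keyword arguments in, one float per completion out.
RewardFn = Callable[..., List[float]]


def compare_sketch(
    reward_fns: Dict[str, RewardFn],
    completions: Sequence[str],
    **kwargs,
) -> Dict[str, np.ndarray]:
    """Score every completion with every reward function and keep one array per method."""
    scores = {}
    for name, fn in reward_fns.items():
        values = fn(completions=completions, **kwargs)
        if len(values) != len(completions):
            raise ValueError(f"{name} returned {len(values)} scores for {len(completions)} completions")
        scores[name] = np.asarray(values, dtype=float)
    return scores


# Usage with two toy rewards that accept (and ignore) extra columns such as `answers`
toy_fns = {
    "has_answer_tag": lambda completions, **_: [float("<answer>" in c) for c in completions],
    "longer_than_10_chars": lambda completions, **_: [float(len(c) > 10) for c in completions],
}
scores = compare_sketch(toy_fns, completions=["<answer>4</answer>", "no tags"], answers=["4", "4"])
print(scores)
```

Keeping scores aligned by index is what makes the later statistics, correlations, and disagreement searches possible: entry `i` of every array refers to the same completion.
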
## 4. Statistical Analysis

Let's examine the statistical properties of each grading method:

In [None]:
if dataset and results:
    print("Statistical Summary:")
    print("=" * 80)
    
    for method in results.methods:
        stats = results.statistics[method]
        print(f"\n{method.upper()}:")
        print(f"  Mean:     {stats['mean']:.3f}")
        print(f"  Median:   {stats['median']:.3f}")
        print(f"  Std Dev:  {stats['std']:.3f}")
        print(f"  Min:      {stats['min']:.3f}")
        print(f"  Max:      {stats['max']:.3f}")
        print(f"  Q1 (25%): {stats['q25']:.3f}")
        print(f"  Q3 (75%): {stats['q75']:.3f}")

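For reference, these are the standard NumPy definitions of the printed statistics. `summarize_scores` is a hypothetical helper; whether the comparator uses population (`ddof=0`) or sample (`ddof=1`) standard deviation, and which percentile interpolation it applies, are assumptions here. For a 0/1 reward, the mean is simply the pass rate.

```python
import numpy as np


def summarize_scores(scores) -> dict:
    """Summary statistics with the same keys as the printed table."""
    x = np.asarray(scores, dtype=float)
    return {
        "mean": float(np.mean(x)),
        "median": float(np.median(x)),
        "std": float(np.std(x)),  # population std (ddof=0), NumPy's default
        "min": float(np.min(x)),
        "max": float(np.max(x)),
        "q25": float(np.percentile(x, 25)),  # linear interpolation (NumPy default)
        "q75": float(np.percentile(x, 75)),
    }


summary = summarize_scores([0.0, 0.5, 1.0, 1.0])
print(summary)
```
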
## 5. Correlation Analysis

How do the methods correlate with each other?

In [None]:
if dataset and results:
    print("Correlations Between Methods:")
    print("=" * 80)
    
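    # Group coefficients by method pair so the output doesn't depend on dict
    # ordering, and avoid shadowing the `methods` dict from Section 1.
    grouped = {}
    for key, corr in results.correlations.items():
        for kind in ('pearson', 'spearman'):
            if kind in key:
                pair_label = key.replace(f'_{kind}', '').replace('_vs_', ' vs ')
                grouped.setdefault(pair_label, {})[kind] = corr
    
    for pair_label, coeffs in grouped.items():
        print(f"\n{pair_label}:")
        if 'pearson' in coeffs:
            print(f"  Pearson correlation:  {coeffs['pearson']:.3f}")
        if 'spearman' in coeffs:
            print(f"  Spearman correlation: {coeffs['spearman']:.3f}")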

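Pearson measures linear agreement between the raw scores; Spearman measures agreement between their rankings. A NumPy-only sketch of both, assuming average ranks for ties (the usual convention), shows how they relate. `average_ranks`, `pearson`, and `spearman` are hypothetical helpers, not the comparator's code.

```python
import numpy as np


def average_ranks(values) -> np.ndarray:
    """Rank values 1..n, giving tied values the mean of their ranks."""
    x = np.asarray(values, dtype=float)
    order = np.argsort(x, kind="stable")
    ranks = np.empty(len(x), dtype=float)
    ranks[order] = np.arange(1, len(x) + 1)
    # Tied values sit next to each other in sorted order, so averaging works per value
    for value in np.unique(x):
        tied = x == value
        ranks[tied] = ranks[tied].mean()
    return ranks


def pearson(a, b) -> float:
    """Linear correlation; NumPy returns nan (with a warning) if either input is constant."""
    return float(np.corrcoef(np.asarray(a, dtype=float), np.asarray(b, dtype=float))[0, 1])


def spearman(a, b) -> float:
    """Rank correlation: the Pearson correlation of the average ranks."""
    return pearson(average_ranks(a), average_ranks(b))
```

If both rewards are binary, the average ranks are an increasing linear function of the 0/1 scores, so Spearman equals Pearson (the phi coefficient). A visible gap between the two printed values suggests at least one method produces graded scores.
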
## 6. Visualizations

### 6.1 Score Distributions

In [None]:
if dataset and results:
    comparator.plot_distributions()

### 6.2 Correlation Heatmap

In [None]:
if dataset and results and len(results.methods) > 1:
    comparator.plot_correlation_heatmap()

### 6.3 Pairwise Scatter Plot

In [None]:
if dataset and results and len(results.methods) > 1:
    comparator.plot_pairwise_scatter('format_reward', 'accuracy_reward')

## 7. Find Disagreements

Let's find examples where the two methods disagree most:

In [None]:
if dataset and results and len(results.methods) > 1:
    disagreements = comparator.find_disagreements(
        'format_reward', 'accuracy_reward', top_k=5, normalize=True
    )
    
    print("Top 5 Disagreements:")
    print("=" * 80)
    
    for i, ex in enumerate(disagreements, 1):
        print(f"\n{i}. Difference: {ex['difference']:.3f}")
        print(f"   Format reward score: {ex['format_reward_score']:.3f}")
        print(f"   Accuracy reward score: {ex['accuracy_reward_score']:.3f}")
        print(f"\n   Question: {ex['prompt'][:150]}...")
        print(f"\n   Completion: {ex['completion'][:200]}...")
        print("-" * 80)

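The comparator's `normalize=True` behavior isn't documented in this notebook. One plausible approach, sketched below as an assumption, is to min-max scale each method's scores to [0, 1] so methods with different score ranges become comparable, then rank examples by the absolute gap. `min_max` and `top_disagreements` are hypothetical helpers.

```python
import numpy as np


def min_max(x) -> np.ndarray:
    """Rescale to [0, 1]; a constant array maps to all zeros."""
    x = np.asarray(x, dtype=float)
    span = x.max() - x.min()
    return np.zeros_like(x) if span == 0 else (x - x.min()) / span


def top_disagreements(scores_a, scores_b, top_k=5):
    """(index, gap) pairs for the top_k largest |a - b| gaps after min-max normalization."""
    diff = np.abs(min_max(scores_a) - min_max(scores_b))
    order = np.argsort(-diff, kind="stable")  # largest gap first; ties keep input order
    return [(int(i), float(diff[i])) for i in order[:top_k]]


# Usage: a 0/1 format score vs. an accuracy score on a 0..2 scale
print(top_disagreements([1.0, 1.0, 0.0, 1.0], [2.0, 0.0, 0.0, 2.0], top_k=2))
```

If both rewards are binary, every disagreement has a gap of exactly 1.0, so the "top 5" are simply the first five mismatched examples rather than the most extreme ones.
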
## 8. Behavior Analysis

How do completion characteristics (length and word count) vary across score ranges?

In [None]:
if dataset and results:
    print("Behavior Analysis by Score Range:")
    print("=" * 80)
    
    for method in results.methods:
        print(f"\n{method.upper()}:")
        behavior = comparator.analyze_behavior_effects(method)
        
        for group_name, analysis in behavior.items():
            print(f"\n  {group_name.upper()}:")
            print(f"    Count: {analysis['count']}")
            print(f"    Avg completion length: {analysis['avg_length']:.1f} chars")
            print(f"    Avg word count: {analysis['avg_word_count']:.1f} words")

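A minimal sketch of this kind of grouping, assuming three buckets split at fixed score thresholds. The function `behavior_by_score_bucket`, the bucket names, the 0.33/0.67 edges, and whitespace-based word counting are assumptions, not necessarily what `analyze_behavior_effects` does.

```python
import numpy as np


def behavior_by_score_bucket(scores, completions, edges=(0.33, 0.67)):
    """Group completions into low/medium/high score buckets and report length stats."""
    scores = np.asarray(scores, dtype=float)
    # 0: below edges[0], 1: [edges[0], edges[1]), 2: edges[1] and above
    bucket_ids = np.digitize(scores, edges)
    report = {}
    for bucket_id, name in enumerate(["low", "medium", "high"]):
        members = [c for c, b in zip(completions, bucket_ids) if b == bucket_id]
        if not members:
            continue  # skip empty buckets instead of averaging nothing
        report[name] = {
            "count": len(members),
            "avg_length": float(np.mean([len(c) for c in members])),
            "avg_word_count": float(np.mean([len(c.split()) for c in members])),
        }
    return report


print(behavior_by_score_bucket([0.0, 1.0, 1.0], ["a b", "one two three", "x"]))
```

Keep in mind that completion length here is set by the synthetic generator: "good" completions carry the longest reasoning text by construction, so length differences between score groups reflect the generator rather than learned model behavior.
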
## 9. Save Results

Save the analysis results for future reference:

In [None]:
if dataset and results:
    output_dir = Path.cwd().parent / "results" / "notebook_analysis"
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Save results
    results.save(str(output_dir / "comparison_results.json"))
    
    # Generate report
    comparator.generate_report(str(output_dir / "comparison_report.md"))
    
    # Save plots
    comparator.plot_distributions(str(output_dir / "distributions.png"))
    if len(results.methods) > 1:
        comparator.plot_correlation_heatmap(str(output_dir / "correlations.png"))
        comparator.plot_pairwise_scatter(
            'format_reward', 'accuracy_reward',
            str(output_dir / "scatter.png")
        )
    
    print(f"Results saved to {output_dir}/")
    print("\nSaved files:")
    for f in output_dir.iterdir():
        print(f"  - {f.name}")

## 10. Custom Analysis

Use this cell to run your own custom analyses:

In [None]:
# Your custom analysis here
# For example, you could:
# - Compare additional methods
# - Analyze different datasets
# - Create custom visualizations
# - Test your own grading functions

pass
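
As a starting point for testing your own grading function, here is a hypothetical reward that scores how much reasoning a completion contains. The name `reasoning_length_reward`, the `target_words` parameter, and the batch signature are illustrative assumptions; how to register a new method with `grading_registry` isn't shown in this notebook.

```python
import re
from typing import List

REASONING_PATTERN = re.compile(r"<reasoning>(.*?)</reasoning>", re.DOTALL)


def reasoning_length_reward(completions: List[str], target_words: int = 20, **kwargs) -> List[float]:
    """Hypothetical reward: fraction of target_words found in the <reasoning> block, capped at 1.0."""
    scores = []
    for completion in completions:
        match = REASONING_PATTERN.search(completion)
        n_words = len(match.group(1).split()) if match else 0  # no tag -> zero words
        scores.append(min(n_words / target_words, 1.0))
    return scores


print(reasoning_length_reward(["<reasoning>a b c d</reasoning><answer>1</answer>", "no tags"], target_words=8))
```

You could score the `completions` list from Section 3 with it and check how its scores correlate with `accuracy_reward`; on this synthetic data it will mostly track completion quality, since only "good" completions have long reasoning.
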