# Pun Evaluation Testing with Multi-Model Verification

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week3/pun_evaluation_testing.ipynb)

This notebook builds on `pun_dataset_builder.ipynb` to:
1. **Validate** generated pun examples using cross-model verification
2. **Distill** into multiple unambiguous test formats
3. **Evaluate** across multiple models of varying sizes
4. **Analyze** results by model size and family

**Key Concepts Demonstrated:**
- **Cross-model validation:** Use different models to validate LLM-generated data
- **Multi-format evaluation:** Same concept tested in different task formats
- **Scaling analysis:** How does concept understanding vary with model size?

**Supports:** Anthropic Claude, OpenAI GPT, Google Gemini, Together.ai, and Groq APIs

## Prerequisites

Before running this notebook, run `pun_dataset_builder.ipynb` to generate `pun_dataset.json`.
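
The cells below read a small set of keys from `pun_dataset.json`. As a rough sketch of the shape they expect, the following builds one record per list. The key names come from how this notebook accesses the data; the joke, the literal sentences, and the variable name `example_dataset` are placeholders invented for illustration, and the builder notebook may store additional fields.

```python
import json

# Placeholder dataset: key names mirror what this notebook reads,
# the values are invented purely for illustration.
example_dataset = {
    "puns": [
        {
            "setup": "I used to be a banker,",
            "punchline": "but I lost interest.",
            "pun_word": "interest",
            "meaning1": "curiosity or enthusiasm",
            "meaning2": "money paid on a loan or deposit",
        }
    ],
    "pairs": [
        {
            "target_word": "interest",
            "pun_context": "I used to be a banker, but I lost interest.",
            "literal_context1": "She showed great interest in astronomy.",
            "literal_context2": "The loan carries five percent interest.",
        }
    ],
    # This notebook only counts the cloze entries, so their fields are not shown
    "cloze": [],
}

# Round-trip through JSON to confirm the structure is serializable
serialized = json.dumps(example_dataset, indent=2)
reloaded = json.loads(serialized)
print(sorted(reloaded.keys()))
```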

## References
- [Model-Written Evaluations](https://arxiv.org/abs/2212.09251) - Perez et al.
- [LLM-as-Judge](https://arxiv.org/abs/2306.05685) - Zheng et al.
- [LAMA: Language Models as Knowledge Bases?](https://arxiv.org/abs/1909.01066) - Petroni et al.

## Part 1: Setup & Data Loading

In [None]:
# Install dependencies - uncomment the providers you plan to use
!pip install -q pandas matplotlib seaborn

# Provider SDKs (install whichever you have API keys for)
!pip install -q anthropic        # Anthropic Claude
!pip install -q openai           # OpenAI GPT
!pip install -q google-generativeai  # Google Gemini (free tier!)
!pip install -q together         # Together.ai (multi-model)
!pip install -q groq             # Groq (free tier, fast!)

In [None]:
import os
import json
import random
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, asdict, field
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import time

# Set your API key(s) - use whichever provider(s) you have access to
# You can set these as environment variables or paste directly (less secure)

# Option 1: Anthropic Claude
# os.environ["ANTHROPIC_API_KEY"] = "your-key-here"

# Option 2: OpenAI
# os.environ["OPENAI_API_KEY"] = "your-key-here"

# Option 3: Google Gemini (FREE tier available!)
# os.environ["GOOGLE_API_KEY"] = "your-key-here"

# Option 4: Together.ai (good for multi-model testing)
# os.environ["TOGETHER_API_KEY"] = "your-key-here"

# Option 5: Groq (FREE tier, very fast!)
# os.environ["GROQ_API_KEY"] = "your-key-here"

## Unified LLM Interface

We create a wrapper that works with any of the supported providers. This lets you use whichever API you have access to.

In [None]:
class LLMClient:
    """Unified interface for multiple LLM providers."""
    
    def __init__(self, provider: str = "gemini", model: Optional[str] = None):
        """
        Initialize with a provider: 'anthropic', 'openai', 'gemini', 'together', or 'groq'
        
        Args:
            provider: The API provider to use
            model: Optional model override (uses sensible defaults if not specified)
        """
        self.provider = provider.lower()
        
        if self.provider == "anthropic":
            from anthropic import Anthropic
            self.client = Anthropic()
            self.model = model or "claude-sonnet-4-20250514"
            
        elif self.provider == "openai":
            from openai import OpenAI
            self.client = OpenAI()
            self.model = model or "gpt-4o"
            
        elif self.provider == "gemini":
            import google.generativeai as genai
            genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
            self.model = model or "gemini-1.5-flash"
            self.client = genai.GenerativeModel(self.model)
            
        elif self.provider == "together":
            import together
            self.client = together.Together()
            self.model = model or "meta-llama/Llama-3.1-70B-Instruct-Turbo"
            
        elif self.provider == "groq":
            from groq import Groq
            self.client = Groq()
            self.model = model or "llama-3.1-70b-versatile"
            
        else:
            raise ValueError(f"Unknown provider: {provider}. Choose from: anthropic, openai, gemini, together, groq")
    
    def generate(self, prompt: str, system: str = "", max_tokens: int = 1024, temperature: float = 0.1) -> str:
        """Generate a response from the LLM."""
        
        if self.provider == "anthropic":
            response = self.client.messages.create(
                model=self.model,
                max_tokens=max_tokens,
                temperature=temperature,
                system=system if system else "You are a helpful assistant.",
                messages=[{"role": "user", "content": prompt}]
            )
            return response.content[0].text
            
        elif self.provider == "openai":
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})
            
            response = self.client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                temperature=temperature,
                messages=messages
            )
            return response.choices[0].message.content
            
        elif self.provider == "gemini":
            full_prompt = f"{system}\n\n{prompt}" if system else prompt
            response = self.client.generate_content(
                full_prompt,
                generation_config={"max_output_tokens": max_tokens, "temperature": temperature}
            )
            return response.text
            
        elif self.provider == "together":
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})
            
            response = self.client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                temperature=temperature,
                messages=messages
            )
            return response.choices[0].message.content
            
        elif self.provider == "groq":
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})
            
            response = self.client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                temperature=temperature,
                messages=messages
            )
            return response.choices[0].message.content
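
Because every provider is reached through the same `generate()` signature, the rest of the notebook can be dry-run without any API key by swapping in an object with that method. The sketch below is one way to do this; `StubLLMClient` and its `canned_response` argument are hypothetical names, not part of any provider SDK.

```python
class StubLLMClient:
    """Hypothetical offline stand-in that mimics the LLMClient.generate() signature."""

    def __init__(self, canned_response: str = "Yes"):
        self.provider = "stub"
        self.model = "stub-model"
        self.canned_response = canned_response
        self.calls = []  # Prompts received so far, kept for inspection

    def generate(self, prompt: str, system: str = "", max_tokens: int = 1024,
                 temperature: float = 0.1) -> str:
        # Ignore the generation settings and always return the canned reply
        self.calls.append(prompt)
        return self.canned_response


stub = StubLLMClient(canned_response="4")
reply = stub.generate("What is 2+2? Answer with just the number.", max_tokens=10)
print(f"Stub reply: {reply}")
```

Passing such a stub wherever an `LLMClient` is expected exercises the parsing and bookkeeping code without spending quota.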

In [None]:
# ============================================================
# CHOOSE YOUR PROVIDER HERE!
# ============================================================
# Options: "anthropic", "openai", "gemini", "together", "groq"
#
# Recommendations:
#   - "gemini" : Free tier, good for validation (Part 2)
#   - "groq"   : Free tier, very fast, good for validation
#   - "together": Best for multi-model evaluation (Part 4)
# ============================================================

VALIDATION_PROVIDER = "gemini"  # Provider for cross-model validation
EVAL_PROVIDER = "together"      # Provider for multi-model evaluation

# Initialize the validation client
llm = LLMClient(VALIDATION_PROVIDER)
print(f"Validation provider: {VALIDATION_PROVIDER}")
print(f"Validation model: {llm.model}")

# Test the connection
test_response = llm.generate("What is 2+2? Answer with just the number.", max_tokens=10)
print(f"Test response: {test_response}")

In [None]:
# Load the dataset generated by pun_dataset_builder.ipynb
try:
    with open('pun_dataset.json', 'r') as f:
        dataset = json.load(f)
    print(f"Loaded pun_dataset.json successfully!")
except FileNotFoundError:
    print("ERROR: pun_dataset.json not found.")
    print("Please run pun_dataset_builder.ipynb first to generate the dataset.")
    dataset = {'puns': [], 'pairs': [], 'cloze': []}

# Display dataset statistics
print(f"\n=== Dataset Statistics ===")
print(f"Puns: {len(dataset.get('puns', []))} examples")
print(f"Literal/Pun Pairs: {len(dataset.get('pairs', []))} pairs")
print(f"Cloze Examples: {len(dataset.get('cloze', []))} examples")

# Show sample pun
if dataset.get('puns'):
    sample = dataset['puns'][0]
    print(f"\nSample pun:")
    print(f"  {sample['setup']} {sample['punchline']}")
    print(f"  Pun word: '{sample['pun_word']}' ({sample['meaning1']} / {sample['meaning2']})")

## Part 2: Cross-Model Validation

A key insight from evaluation research: **use a different model or radically different prompt to validate generated examples**. This helps avoid self-confirmation bias where the same model that generated examples also validates them.

We use the unified LLM interface so you can validate with any provider you have access to.

In [None]:
@dataclass
class ValidationResult:
    """Result of validating a pun example."""
    is_valid_pun: bool
    double_meaning_works: bool
    validator_explanation: str
    pun_word_correct: bool
    confidence: str  # "high", "medium", "low"
    
def validate_pun(pun: Dict, llm_client: LLMClient) -> ValidationResult:
    """
    Use a different LLM to validate a pun example.
    
    Key checks:
    1. Is this actually a pun?
    2. Does the claimed double meaning work?
    3. Is the identified pun word correct?
    """
    
    full_joke = f"{pun['setup']} {pun['punchline']}"
    
    validation_prompt = f"""Analyze this joke and answer the following questions:

Joke: "{full_joke}"
Claimed pun word: "{pun['pun_word']}"
Claimed meaning 1: "{pun['meaning1']}"
Claimed meaning 2: "{pun['meaning2']}"

Questions:
1. Is this actually a pun (a joke that plays on words with multiple meanings)? Answer YES or NO.
2. Does the word "{pun['pun_word']}" actually have both meanings claimed? Answer YES or NO.
3. Is "{pun['pun_word']}" the correct word that creates the pun, or is it a different word? Answer CORRECT or WRONG.
4. How confident are you in your assessment? Answer HIGH, MEDIUM, or LOW.
5. Briefly explain the pun in your own words (1-2 sentences).

Format your response exactly as:
IS_PUN: [YES/NO]
MEANINGS_WORK: [YES/NO]
WORD_CORRECT: [CORRECT/WRONG]
CONFIDENCE: [HIGH/MEDIUM/LOW]
EXPLANATION: [your explanation]"""
    
    response = llm_client.generate(validation_prompt, max_tokens=300)
    
    # Parse the structured response
    lines = response.strip().split('\n')
    result = {
        'is_pun': False,
        'meanings_work': False,
        'word_correct': False,
        'confidence': 'low',
        'explanation': ''
    }
    
    for line in lines:
        # Inspect only the value after the colon. The field name "WORD_CORRECT"
        # itself contains "CORRECT", so matching on the whole line would mark
        # every pun word as correct.
        key, sep, value = line.strip().partition(':')
        if not sep:
            continue
        key = key.strip().upper()
        value = value.strip()
        answer = value.upper().strip('[]* ')
        if key == 'IS_PUN':
            result['is_pun'] = answer.startswith('YES')
        elif key == 'MEANINGS_WORK':
            result['meanings_work'] = answer.startswith('YES')
        elif key == 'WORD_CORRECT':
            result['word_correct'] = answer.startswith('CORRECT')
        elif key == 'CONFIDENCE':
            conf = answer.lower()
            result['confidence'] = conf if conf in ['high', 'medium', 'low'] else 'medium'
        elif key == 'EXPLANATION':
            result['explanation'] = value
    
    return ValidationResult(
        is_valid_pun=result['is_pun'],
        double_meaning_works=result['meanings_work'],
        validator_explanation=result['explanation'],
        pun_word_correct=result['word_correct'],
        confidence=result['confidence']
    )
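
The parsing step can be checked without an API call by running it on a hand-written reply. The sketch below is a standalone variant for that purpose: the helper name `parse_validation_reply` and the sample reply are illustrative and not part of the notebook. It looks only at the text after the colon, since a field name such as `WORD_CORRECT` already contains the word being checked for.

```python
from typing import Dict


def parse_validation_reply(reply: str) -> Dict[str, object]:
    """Parse the 'KEY: value' lines requested by the validation prompt."""
    parsed: Dict[str, object] = {
        "is_pun": False,
        "meanings_work": False,
        "word_correct": False,
        "confidence": "low",
        "explanation": "",
    }
    for raw_line in reply.strip().splitlines():
        key, sep, value = raw_line.strip().partition(":")
        if not sep:
            continue  # Not a 'KEY: value' line
        key = key.strip().upper()
        value = value.strip()
        answer = value.upper().strip("[]* ")
        if key == "IS_PUN":
            parsed["is_pun"] = answer.startswith("YES")
        elif key == "MEANINGS_WORK":
            parsed["meanings_work"] = answer.startswith("YES")
        elif key == "WORD_CORRECT":
            parsed["word_correct"] = answer.startswith("CORRECT")
        elif key == "CONFIDENCE":
            conf = answer.lower()
            parsed["confidence"] = conf if conf in ("high", "medium", "low") else "medium"
        elif key == "EXPLANATION":
            parsed["explanation"] = value
    return parsed


# Hand-written reply in the format the prompt asks for
sample_reply = """IS_PUN: YES
MEANINGS_WORK: YES
WORD_CORRECT: WRONG
CONFIDENCE: HIGH
EXPLANATION: The joke plays on two senses of one word."""

parsed = parse_validation_reply(sample_reply)
print(parsed)
```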

In [None]:
# Validate all puns in the dataset
# Note: This makes API calls, so we add rate limiting

validated_puns = []
validation_stats = {'valid': 0, 'invalid': 0, 'flagged': 0}

# Rate limits vary by provider
RATE_LIMITS = {
    'gemini': 4.0,    # ~15 RPM on free tier, i.e. at least 60/15 = 4 s between calls
    'groq': 0.3,      # Fast, but has limits
    'together': 0.5,
    'anthropic': 0.2,
    'openai': 0.2,
}
sleep_time = RATE_LIMITS.get(VALIDATION_PROVIDER, 0.5)

print(f"Validating puns using {VALIDATION_PROVIDER}...\n")

for i, pun in enumerate(dataset.get('puns', [])):
    # Rate limiting
    if i > 0:
        time.sleep(sleep_time)
    
    try:
        result = validate_pun(pun, llm)
        
        # Determine if pun passes validation
        passes = result.is_valid_pun and result.double_meaning_works and result.pun_word_correct
        
        validated_pun = {
            **pun,
            'validation': asdict(result),
            'passes_validation': passes
        }
        validated_puns.append(validated_pun)
        
        status = "VALID" if passes else "FLAGGED"
        print(f"{i+1}. {status}: {pun['setup'][:40]}...")
        print(f"   Validator says: {result.validator_explanation[:80]}...")
        
        if passes:
            validation_stats['valid'] += 1
        else:
            validation_stats['flagged'] += 1
            
    except Exception as e:
        print(f"{i+1}. ERROR: {str(e)[:50]}")
        validation_stats['invalid'] += 1

print(f"\n=== Validation Summary ===")
print(f"Valid: {validation_stats['valid']}")
print(f"Flagged: {validation_stats['flagged']}")
print(f"Errors: {validation_stats['invalid']}")
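
The sleep times above are fixed numbers of seconds. If a provider states its limit in requests per minute, the pause can be derived from that figure instead. This is a minimal sketch under that assumption; the function name `min_interval_seconds` and the `safety_margin` parameter are hypothetical, and actual limits should be taken from the provider's own documentation.

```python
def min_interval_seconds(requests_per_minute: float, safety_margin: float = 1.1) -> float:
    """Smallest pause between calls that stays under a requests-per-minute limit.

    safety_margin > 1 leaves some headroom below the stated limit.
    """
    if requests_per_minute <= 0:
        raise ValueError("requests_per_minute must be positive")
    return 60.0 / requests_per_minute * safety_margin


# Example: a limit of 15 requests per minute
pause = min_interval_seconds(15)
print(f"Sleep at least {pause:.1f} s between calls")
```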

In [None]:
# Filter to only validated puns for downstream tasks
filtered_puns = [p for p in validated_puns if p['passes_validation']]

print(f"Keeping {len(filtered_puns)}/{len(validated_puns)} puns that passed validation")

# Show examples of filtered vs kept
if len(validated_puns) > len(filtered_puns):
    print("\nExamples of filtered-out puns:")
    for p in validated_puns:
        if not p['passes_validation']:
            print(f"  - {p['setup']} {p['punchline']}")
            print(f"    Reason: {p['validation']}")
            break

## Part 3: Distill to Multiple Test Formats

We convert validated puns into multiple evaluation formats. Each format tests a different aspect of pun understanding:

- **Format A (Binary Classification):** Can the model recognize a pun?
- **Format B (Forced Choice):** Can the model distinguish pun from literal usage?
- **Format C (Cloze Completion):** Can the model predict the pun word from context?
- **Format D (Word Identification):** Can the model identify which word creates the pun?

In [None]:
@dataclass
class BinaryClassificationExample:
    """Format A: Binary pun/not-pun classification."""
    text: str
    is_pun: bool
    source_type: str  # "pun" or "literal_control"
    pun_word: Optional[str] = None

def create_binary_classification_dataset(puns: List[Dict], pairs: List[Dict]) -> List[BinaryClassificationExample]:
    """
    Create binary classification examples.
    
    Includes:
    - Pun examples (positive class)
    - Matched literal sentences (negative class / controls)
    """
    examples = []
    
    # Add puns as positive examples
    for pun in puns:
        examples.append(BinaryClassificationExample(
            text=f"{pun['setup']} {pun['punchline']}",
            is_pun=True,
            source_type="pun",
            pun_word=pun['pun_word']
        ))
    
    # Add literal contexts as negative examples (from pairs)
    for pair in pairs:
        examples.append(BinaryClassificationExample(
            text=pair['literal_context1'],
            is_pun=False,
            source_type="literal_control",
            pun_word=pair['target_word']
        ))
        examples.append(BinaryClassificationExample(
            text=pair['literal_context2'],
            is_pun=False,
            source_type="literal_control",
            pun_word=pair['target_word']
        ))
    
    return examples

binary_dataset = create_binary_classification_dataset(
    filtered_puns if filtered_puns else dataset.get('puns', []),
    dataset.get('pairs', [])
)

print(f"Binary Classification Dataset: {len(binary_dataset)} examples")
print(f"  Puns: {sum(1 for ex in binary_dataset if ex.is_pun)}")
print(f"  Non-puns: {sum(1 for ex in binary_dataset if not ex.is_pun)}")

# Show examples
print("\nSample positive (pun):")
pos = next((ex for ex in binary_dataset if ex.is_pun), None)
if pos:
    print(f"  Text: {pos.text}")
    print(f"  Label: is_pun = True")

print("\nSample negative (literal):")
neg = next((ex for ex in binary_dataset if not ex.is_pun), None)
if neg:
    print(f"  Text: {neg.text}")
    print(f"  Label: is_pun = False")

In [None]:
@dataclass
class ForcedChoiceExample:
    """Format B: Forced choice between pun and literal sentence."""
    option_a: str
    option_b: str
    correct_answer: str  # "A" or "B"
    target_word: str
    pun_position: str  # "A" or "B" (for position bias analysis)

def create_forced_choice_dataset(pairs: List[Dict]) -> List[ForcedChoiceExample]:
    """
    Create forced choice (pairwise) examples.
    
    Uses matched pairs from the dataset (same word, pun vs literal usage).
    Randomizes A/B position to control for position bias.
    """
    examples = []
    
    for pair in pairs:
        # Randomly assign pun to A or B
        pun_first = random.random() < 0.5
        
        if pun_first:
            examples.append(ForcedChoiceExample(
                option_a=pair['pun_context'],
                option_b=pair['literal_context1'],
                correct_answer="A",
                target_word=pair['target_word'],
                pun_position="A"
            ))
        else:
            examples.append(ForcedChoiceExample(
                option_a=pair['literal_context1'],
                option_b=pair['pun_context'],
                correct_answer="B",
                target_word=pair['target_word'],
                pun_position="B"
            ))
    
    return examples

forced_choice_dataset = create_forced_choice_dataset(dataset.get('pairs', []))

print(f"Forced Choice Dataset: {len(forced_choice_dataset)} examples")

# Check position balance
pos_a = sum(1 for ex in forced_choice_dataset if ex.pun_position == "A")
print(f"  Pun in position A: {pos_a}")
print(f"  Pun in position B: {len(forced_choice_dataset) - pos_a}")

# Show example
if forced_choice_dataset:
    ex = forced_choice_dataset[0]
    print(f"\nSample forced choice:")
    print(f"  Which sentence contains a pun?")
    print(f"  A: {ex.option_a}")
    print(f"  B: {ex.option_b}")
    print(f"  Correct: {ex.correct_answer}")
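
Independent coin flips can leave a small dataset lopsided, for example seven puns in position A and three in position B. One alternative, shown here as a sketch rather than as the method used above, is to fix the A/B counts first and then shuffle. The helper name `balanced_positions` and its `seed` parameter are illustrative.

```python
import random
from typing import List


def balanced_positions(n: int, seed: int = 0) -> List[str]:
    """Return n labels split as evenly as possible between 'A' and 'B', in shuffled order."""
    labels = ["A"] * (n // 2) + ["B"] * (n - n // 2)
    rng = random.Random(seed)  # Local RNG, so the global random state is untouched
    rng.shuffle(labels)
    return labels


positions = balanced_positions(7)
print(f"A: {positions.count('A')}, B: {positions.count('B')}")
```

Each label can then decide whether the pun goes into `option_a` or `option_b` for the corresponding pair.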

In [None]:
@dataclass
class ClozeCompletionExample:
    """Format C: Cloze (fill-in-the-blank) completion."""
    prompt_template: str  # "This joke works because '___' can mean both {meaning1} and {meaning2}."
    full_prompt: str  # The complete prompt with context
    target_word: str
    foils: List[str]  # Distractor words
    context: str  # The original joke

def create_cloze_dataset(puns: List[Dict]) -> List[ClozeCompletionExample]:
    """
    Create cloze completion examples for LAMA-style probing.
    
    Format: "This joke works because '___' can mean both [meaning1] and [meaning2]."
    """
    examples = []
    
    for pun in puns:
        full_joke = f"{pun['setup']} {pun['punchline']}"
        
        # Create the cloze prompt
        template = f"This joke works because '___' can mean both {pun['meaning1']} and {pun['meaning2']}."
        full_prompt = f"Joke: \"{full_joke}\"\n\n{template}"
        
        # Generate simple foils (we could use LLM for better foils)
        # For now, use other pun words as foils
        other_words = [p['pun_word'] for p in puns if p['pun_word'] != pun['pun_word']]
        foils = random.sample(other_words, min(3, len(other_words))) if other_words else []
        
        examples.append(ClozeCompletionExample(
            prompt_template=template,
            full_prompt=full_prompt,
            target_word=pun['pun_word'],
            foils=foils,
            context=full_joke
        ))
    
    return examples

cloze_dataset = create_cloze_dataset(
    filtered_puns if filtered_puns else dataset.get('puns', [])
)

print(f"Cloze Completion Dataset: {len(cloze_dataset)} examples")

# Show example
if cloze_dataset:
    ex = cloze_dataset[0]
    print(f"\nSample cloze:")
    print(f"  {ex.full_prompt}")
    print(f"  Target: {ex.target_word}")
    print(f"  Foils: {ex.foils}")
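
The evaluator in Part 4 covers formats A, B, and D, so scoring for the cloze format is not shown in this notebook. One possible scoring rule, sketched here under the assumption that the model is asked to reply with a single word, is to compare the first word of the reply with `target_word` after normalization. The helper names `normalize_answer` and `score_cloze` are hypothetical.

```python
import string


def normalize_answer(text: str) -> str:
    """Lowercase and strip surrounding whitespace, quotes, and punctuation."""
    return text.strip().strip(string.punctuation + " ").lower()


def score_cloze(response: str, target_word: str) -> bool:
    """True if the first word of the response matches the target word."""
    words = response.strip().split()
    if not words:
        return False
    return normalize_answer(words[0]) == normalize_answer(target_word)


print(score_cloze("'Interest'.", "interest"))
print(score_cloze("bank", "interest"))
```

Exact matching is strict: an inflected form such as "interests" would count as wrong, which may or may not be the desired behavior.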

In [None]:
@dataclass
class WordIdentificationExample:
    """Format D: Identify which word creates the pun."""
    joke: str
    options: List[str]  # Multiple choice options
    correct_index: int  # Index of correct answer (0-based)
    correct_word: str

def create_word_identification_dataset(puns: List[Dict]) -> List[WordIdentificationExample]:
    """
    Create word identification examples.
    
    Task: Given a pun, identify which word creates the double meaning.
    """
    examples = []
    
    for pun in puns:
        full_joke = f"{pun['setup']} {pun['punchline']}"
        pun_word = pun['pun_word'].lower()
        
        # Extract words from the joke for options
        words = [w.strip('.,!?"') for w in full_joke.split()]
        words = [w for w in words if len(w) > 2]  # Filter short words
        
        # Make sure pun_word is in options
        if pun_word not in [w.lower() for w in words]:
            continue  # Skip if pun word not found
        
        # Select distractors (other words from the joke)
        distractors = [w for w in words if w.lower() != pun_word][:3]
        
        if len(distractors) < 3:
            continue  # Need at least 3 distractors
        
        # Build options list with pun word at random position
        options = distractors[:3]
        correct_index = random.randint(0, 3)
        options.insert(correct_index, pun['pun_word'])
        
        examples.append(WordIdentificationExample(
            joke=full_joke,
            options=options,
            correct_index=correct_index,
            correct_word=pun['pun_word']
        ))
    
    return examples

word_id_dataset = create_word_identification_dataset(
    filtered_puns if filtered_puns else dataset.get('puns', [])
)

print(f"Word Identification Dataset: {len(word_id_dataset)} examples")

# Show example
if word_id_dataset:
    ex = word_id_dataset[0]
    print(f"\nSample word identification:")
    print(f"  Joke: {ex.joke}")
    print(f"  Which word creates the pun?")
    for i, opt in enumerate(ex.options):
        marker = " <--" if i == ex.correct_index else ""
        print(f"    {chr(65+i)}: {opt}{marker}")

In [None]:
# Summary of all test formats
print("=== Test Format Summary ===")
print(f"\nFormat A (Binary Classification): {len(binary_dataset)} examples")
print(f"  Task: Is this a pun? Yes/No")
print(f"\nFormat B (Forced Choice): {len(forced_choice_dataset)} examples")
print(f"  Task: Which sentence contains a pun? A/B")
print(f"\nFormat C (Cloze Completion): {len(cloze_dataset)} examples")
print(f"  Task: Fill in the blank with the pun word")
print(f"\nFormat D (Word Identification): {len(word_id_dataset)} examples")
print(f"  Task: Which word creates the pun? A/B/C/D")

## Part 4: Multi-Model Testing

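Now we evaluate multiple models on three of the four test formats (A, B, and D); the evaluator below has no method for the cloze format (C). We test models across different:
- **Sizes:** from small (7B, 8B, 9B) to large (27B, 70B, 72B)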
- **Families:** Llama, Mistral, Qwen, Gemma

**Provider options for multi-model testing:**
- **Together.ai**: Best selection of models in one API
- **Groq**: Free, fast, but fewer models (Llama, Mistral, Gemma)
- **Mix providers**: Use Gemini + Groq + others for variety

In [None]:
# Model registry for different providers
# Each entry maps a short name -> (provider, model_id, family, size_b)

TOGETHER_MODELS = {
    "llama-8b": ("together", "meta-llama/Llama-3.1-8B-Instruct-Turbo", "llama", 8),
    "llama-70b": ("together", "meta-llama/Llama-3.1-70B-Instruct-Turbo", "llama", 70),
    "mistral-7b": ("together", "mistralai/Mistral-7B-Instruct-v0.3", "mistral", 7),
    "mixtral-8x7b": ("together", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistral", 47),  # ~47B total parameters (mixture of experts)
    "qwen-7b": ("together", "Qwen/Qwen2.5-7B-Instruct-Turbo", "qwen", 7),
    "qwen-72b": ("together", "Qwen/Qwen2.5-72B-Instruct-Turbo", "qwen", 72),
    "gemma-9b": ("together", "google/gemma-2-9b-it", "gemma", 9),
    "gemma-27b": ("together", "google/gemma-2-27b-it", "gemma", 27),
}

GROQ_MODELS = {
    "llama-8b": ("groq", "llama-3.1-8b-instant", "llama", 8),
    "llama-70b": ("groq", "llama-3.1-70b-versatile", "llama", 70),
    "mixtral-8x7b": ("groq", "mixtral-8x7b-32768", "mistral", 47),  # ~47B total parameters (mixture of experts)
    "gemma-9b": ("groq", "gemma2-9b-it", "gemma", 9),
}

GEMINI_MODELS = {
    "gemini-flash": ("gemini", "gemini-1.5-flash", "gemini", 0),  # Size unknown
    "gemini-pro": ("gemini", "gemini-1.5-pro", "gemini", 0),
}

# Select which model set to use based on your provider
if EVAL_PROVIDER == "together":
    EVAL_MODELS = TOGETHER_MODELS
elif EVAL_PROVIDER == "groq":
    EVAL_MODELS = GROQ_MODELS
elif EVAL_PROVIDER == "gemini":
    EVAL_MODELS = GEMINI_MODELS
else:
    # Default to groq (free)
    EVAL_MODELS = GROQ_MODELS

print(f"Evaluation provider: {EVAL_PROVIDER}")
print(f"Available models: {len(EVAL_MODELS)}")
for name, (provider, model_id, family, size) in EVAL_MODELS.items():
    size_str = f"{size}B" if size > 0 else "unknown"
    print(f"  {name}: {family} family, {size_str}")

In [None]:
class MultiModelEvaluator:
    """Evaluate multiple models using the unified LLM interface."""
    
    def __init__(self, models: Dict):
        """
        Args:
            models: Dict mapping name -> (provider, model_id, family, size_b)
        """
        self.models = models
        self.clients = {}  # Lazy initialization
        
    def get_client(self, model_name: str) -> LLMClient:
        """Get or create a client for the given model."""
        if model_name not in self.clients:
            provider, model_id, _, _ = self.models[model_name]
            self.clients[model_name] = LLMClient(provider, model=model_id)
        return self.clients[model_name]
    
    def evaluate_binary(self, model_name: str, examples: List[BinaryClassificationExample],
                       max_examples: int = 20) -> Dict:
        """Evaluate binary pun classification."""
        client = self.get_client(model_name)
        results = []
        correct = 0
        
        eval_examples = random.sample(examples, min(max_examples, len(examples)))
        
        for ex in eval_examples:
            prompt = f"""Does this text contain a pun (a joke based on a word with two meanings)?

Text: "{ex.text}"

Answer with just Yes or No."""
            
            try:
                response = client.generate(prompt, max_tokens=10)
                response_lower = response.strip().lower()
                predicted_pun = 'yes' in response_lower and 'no' not in response_lower[:3]
                
                is_correct = predicted_pun == ex.is_pun
                if is_correct:
                    correct += 1
                
                results.append({
                    'text': ex.text[:50],
                    'true_label': ex.is_pun,
                    'predicted': predicted_pun,
                    'correct': is_correct
                })
            except Exception as e:
                results.append({'error': str(e), 'correct': False})
            
            time.sleep(0.3)
        
        return {
            'model': model_name,
            'format': 'binary',
            'accuracy': correct / len(eval_examples) if eval_examples else 0,
            'correct': correct,
            'total': len(eval_examples),
            'results': results
        }
    
    def evaluate_forced_choice(self, model_name: str, examples: List[ForcedChoiceExample],
                               max_examples: int = 20) -> Dict:
        """Evaluate forced choice pun identification."""
        client = self.get_client(model_name)
        results = []
        correct = 0
        
        eval_examples = random.sample(examples, min(max_examples, len(examples)))
        
        for ex in eval_examples:
            prompt = f"""Which sentence contains a pun?

A: {ex.option_a}
B: {ex.option_b}

Answer with just A or B."""
            
            try:
                response = client.generate(prompt, max_tokens=10)
                response_upper = response.strip().upper()
                
                # Treat A/B as standalone tokens so that letters inside words
                # (such as the "A" in "ANSWER") are not read as a choice
                tokens = [t.strip('.,:;()[]*"\'') for t in response_upper.split()]
                choices = [t for t in tokens if t in ('A', 'B')]
                predicted = choices[0] if choices else None
                
                is_correct = predicted == ex.correct_answer
                if is_correct:
                    correct += 1
                
                results.append({
                    'correct_answer': ex.correct_answer,
                    'predicted': predicted,
                    'correct': is_correct
                })
            except Exception as e:
                results.append({'error': str(e), 'correct': False})
            
            time.sleep(0.3)
        
        return {
            'model': model_name,
            'format': 'forced_choice',
            'accuracy': correct / len(eval_examples) if eval_examples else 0,
            'correct': correct,
            'total': len(eval_examples),
            'results': results
        }
    
    def evaluate_word_identification(self, model_name: str, examples: List[WordIdentificationExample],
                                     max_examples: int = 20) -> Dict:
        """Evaluate pun word identification."""
        client = self.get_client(model_name)
        results = []
        correct = 0
        
        eval_examples = random.sample(examples, min(max_examples, len(examples)))
        
        for ex in eval_examples:
            options_str = "\n".join([f"{chr(65+i)}: {opt}" for i, opt in enumerate(ex.options)])
            
            prompt = f"""Which word creates the pun in this joke?

Joke: "{ex.joke}"

{options_str}

Answer with just the letter (A, B, C, or D)."""
            
            try:
                response = client.generate(prompt, max_tokens=10)
                response_upper = response.strip().upper()
                
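                # Take the first standalone option letter; a plain substring test
                # would match the "A" in words such as "ANSWER"
                tokens = [t.strip('.,:;()[]*"\'') for t in response_upper.split()]
                predicted_letter = next((t for t in tokens if t in ('A', 'B', 'C', 'D')), None)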
                
                if predicted_letter:
                    predicted_index = ord(predicted_letter) - ord('A')
                    is_correct = predicted_index == ex.correct_index
                else:
                    is_correct = False
                
                if is_correct:
                    correct += 1
                
                results.append({
                    'correct_word': ex.correct_word,
                    'predicted_letter': predicted_letter,
                    'correct': is_correct
                })
            except Exception as e:
                results.append({'error': str(e), 'correct': False})
            
            time.sleep(0.3)
        
        return {
            'model': model_name,
            'format': 'word_identification',
            'accuracy': correct / len(eval_examples) if eval_examples else 0,
            'correct': correct,
            'total': len(eval_examples),
            'results': results
        }
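
Models do not always reply with a bare letter, so the step that turns a reply into a choice deserves its own test. The sketch below pulls out the first standalone option letter from a free-form reply. The helper name `extract_choice` is illustrative, and it assumes option labels are single uppercase letters.

```python
from typing import Optional, Sequence


def extract_choice(response: str,
                   valid: Sequence[str] = ("A", "B", "C", "D")) -> Optional[str]:
    """Return the first standalone option letter in the response, or None."""
    for token in response.upper().split():
        # Drop punctuation that commonly wraps a letter, e.g. "(B)" or "C."
        token = token.strip('.,:;()[]*"\'')
        if token in valid:
            return token
    return None


print(extract_choice("Answer: C"))
print(extract_choice("The answer is B", valid=("A", "B")))
print(extract_choice("I am not sure"))
```

One known limitation: a reply that begins with the English article "A" (as in "A pun appears in B") is read as choice A, so a stricter prompt or a check for single-token replies may still be needed.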

In [None]:
# Run evaluation across models
# Note: Adjust which models to test based on your API access and budget

evaluator = MultiModelEvaluator(EVAL_MODELS)
all_results = []

# Select a subset of models for faster evaluation
# Modify this list based on what you want to test
MODELS_TO_TEST = list(EVAL_MODELS.keys())[:2]  # Start with first 2 models

print("Running multi-model evaluation...")
print(f"Models: {MODELS_TO_TEST}")
print(f"Formats: binary, forced_choice, word_identification")
print()

for model_name in MODELS_TO_TEST:
    print(f"\n=== Evaluating {model_name} ===")
    
    # Binary classification
    if binary_dataset:
        print(f"  Format A (Binary)...", end=" ")
        result = evaluator.evaluate_binary(model_name, binary_dataset, max_examples=10)
        all_results.append(result)
        print(f"Accuracy: {result['accuracy']:.1%}")
    
    # Forced choice
    if forced_choice_dataset:
        print(f"  Format B (Forced Choice)...", end=" ")
        result = evaluator.evaluate_forced_choice(model_name, forced_choice_dataset, max_examples=10)
        all_results.append(result)
        print(f"Accuracy: {result['accuracy']:.1%}")
    
    # Word identification
    if word_id_dataset:
        print(f"  Format D (Word ID)...", end=" ")
        result = evaluator.evaluate_word_identification(model_name, word_id_dataset, max_examples=10)
        all_results.append(result)
        print(f"Accuracy: {result['accuracy']:.1%}")

print("\n=== Evaluation Complete ===")
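
Each accuracy above is computed from at most `max_examples=10` items, so the point estimates are noisy. The sketch below is not part of the original pipeline; it uses a Wilson score interval to show how wide the uncertainty is at this sample size. The helper name `wilson_interval` and the 95% z-value of 1.96 are assumptions chosen for illustration.

```python
import math

def wilson_interval(correct: int, total: int, z: float = 1.96) -> tuple:
    """Approximate (low, high) confidence interval for a proportion."""
    if total == 0:
        return (0.0, 1.0)  # no data: the interval is uninformative
    p = correct / total
    denom = 1 + z**2 / total
    center = (p + z**2 / (2 * total)) / denom
    half = z * math.sqrt(p * (1 - p) / total + z**2 / (4 * total**2)) / denom
    return (max(0.0, center - half), min(1.0, center + half))

# Hypothetical outcome: 7 of 10 examples answered correctly
low, high = wilson_interval(7, 10)
print(f"7/10 correct -> 95% interval roughly [{low:.2f}, {high:.2f}]")
```

If two models' intervals overlap heavily, raising `max_examples` is probably more informative than comparing their point accuracies.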

## Part 5: Results Analysis

Now we analyze the results to understand:
1. How accuracy varies with model size
2. Differences between model families
3. Which puns are hardest
4. Whether different formats correlate

In [None]:
# Convert results to DataFrame for analysis
def get_model_info(model_name: str) -> Tuple[str, int]:
    """Get family and size for a model."""
    if model_name in EVAL_MODELS:
        _, _, family, size = EVAL_MODELS[model_name]
        return family, size
    return "unknown", 0

results_df = pd.DataFrame([
    {
        'model': r['model'],
        'format': r['format'],
        'accuracy': r['accuracy'],
        'correct': r['correct'],
        'total': r['total'],
        'family': get_model_info(r['model'])[0],
        'size_b': get_model_info(r['model'])[1]
    }
    for r in all_results
])

print("Results Summary:")
display(results_df)

In [None]:
# Plot: Accuracy vs Model Size
if len(results_df) > 0 and results_df['size_b'].max() > 0:
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Different colors for different formats
    formats = results_df['format'].unique()
    colors = plt.cm.Set2(range(len(formats)))
    
    for fmt, color in zip(formats, colors):
        fmt_data = results_df[results_df['format'] == fmt]
        ax.scatter(fmt_data['size_b'], fmt_data['accuracy'], 
                   c=[color], s=100, label=fmt, alpha=0.7)
    
    ax.set_xlabel('Model Size (B parameters)', fontsize=12)
    ax.set_ylabel('Accuracy', fontsize=12)
    ax.set_title('Pun Understanding vs Model Size', fontsize=14)
    ax.set_ylim(0, 1)
    ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Chance (binary)')
    ax.axhline(y=0.25, color='gray', linestyle=':', alpha=0.5, label='Chance (4-way)')
    # Build the legend last so the chance lines are included in it
    ax.legend()
    
    plt.tight_layout()
    plt.savefig('pun_accuracy_vs_size.png', dpi=150)
    plt.show()
else:
    print("No results with size info to plot. Run evaluation first or use a provider with size data.")
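
The chance lines in the plot are a reminder that raw accuracies are not directly comparable across formats: 75% on a binary task is less impressive than 75% on a 4-way task. A minimal sketch of one way to put formats on a common scale follows. The helper `chance_adjusted` and the `CHANCE_LEVELS` table are hypothetical; the binary and 4-way values match the reference lines in the plot, while the forced-choice value assumes two options per item.

```python
# Chance levels per format. 'forced_choice' = 0.5 is an assumption
# (two candidates per item); adjust it to match your dataset.
CHANCE_LEVELS = {
    'binary': 0.5,
    'forced_choice': 0.5,
    'word_identification': 0.25,
}

def chance_adjusted(accuracy: float, fmt: str) -> float:
    """Rescale accuracy so 0.0 means chance level and 1.0 means perfect."""
    chance = CHANCE_LEVELS[fmt]
    return (accuracy - chance) / (1 - chance)

# The same raw accuracy maps to different adjusted scores per format
for fmt in ('binary', 'word_identification'):
    print(fmt, round(chance_adjusted(0.75, fmt), 3))
```

Negative values indicate below-chance performance, which on small samples is usually noise rather than a systematic effect.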

In [None]:
# Plot: Accuracy by Model Family
if len(results_df) > 0:
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Pivot for grouped bar chart
    pivot_df = results_df.pivot_table(
        index='family', 
        columns='format', 
        values='accuracy',
        aggfunc='mean'
    )
    
    if len(pivot_df) > 0:
        pivot_df.plot(kind='bar', ax=ax, width=0.8)
        
        ax.set_xlabel('Model Family', fontsize=12)
        ax.set_ylabel('Accuracy', fontsize=12)
        ax.set_title('Pun Understanding by Model Family', fontsize=14)
        ax.set_ylim(0, 1)
        ax.legend(title='Test Format')
        ax.set_xticklabels(ax.get_xticklabels(), rotation=0)
        
        plt.tight_layout()
        plt.savefig('pun_accuracy_by_family.png', dpi=150)
        plt.show()
    else:
        print("Not enough data for family comparison.")
else:
    print("No results to plot yet.")

In [None]:
# Error Analysis: Which puns are hardest?
def analyze_errors(results: List[Dict]) -> Dict:
    """Identify patterns in model errors."""
    
    # Aggregate errors by example
    example_errors = defaultdict(int)
    example_totals = defaultdict(int)
    
    for result in results:
        if result['format'] == 'binary' and 'results' in result:
            for r in result['results']:
                if 'text' in r:
                    example_totals[r['text']] += 1
                    if not r.get('correct', False):
                        example_errors[r['text']] += 1
    
    # Calculate error rates
    error_rates = {
        text: example_errors[text] / example_totals[text]
        for text in example_totals
    }
    
    # Sort by error rate
    hardest = sorted(error_rates.items(), key=lambda x: -x[1])[:5]
    
    return {
        'hardest_examples': hardest,
        'avg_error_rate': sum(error_rates.values()) / len(error_rates) if error_rates else 0
    }

if all_results:
    error_analysis = analyze_errors(all_results)
    
    print("=== Error Analysis ===")
    print(f"\nAverage error rate: {error_analysis['avg_error_rate']:.1%}")
    print("\nHardest examples (highest error rate):")
    for text, rate in error_analysis['hardest_examples']:
        # Truncate long jokes so the trailing "..." is meaningful
        print(f"  {rate:.0%} error: {text[:60]}...")
else:
    print("Run evaluation first to analyze errors.")
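
`analyze_errors` only aggregates the binary format. The sketch below applies the same per-item aggregation to any format, assuming the per-item rows carry the fields written by `evaluate_word_identification` (`correct_word`, `correct`). The helper `error_rates_by_key` and the toy data are hypothetical and only illustrate the shape of the computation.

```python
from collections import defaultdict

def error_rates_by_key(all_results, fmt, key):
    """Per-item error rate for one format, grouped by the field `key`."""
    errors, totals = defaultdict(int), defaultdict(int)
    for result in all_results:
        if result.get('format') != fmt:
            continue
        for row in result.get('results', []):
            if key not in row:  # rows recorded after an API error lack the key
                continue
            totals[row[key]] += 1
            if not row.get('correct', False):
                errors[row[key]] += 1
    return {item: errors[item] / totals[item] for item in totals}

# Toy data in the same shape as the evaluator output (values are made up)
toy_results = [
    {'model': 'model-a', 'format': 'word_identification', 'results': [
        {'correct_word': 'bass', 'predicted_letter': 'A', 'correct': True},
        {'correct_word': 'flour', 'predicted_letter': 'C', 'correct': False},
    ]},
    {'model': 'model-b', 'format': 'word_identification', 'results': [
        {'correct_word': 'bass', 'predicted_letter': 'B', 'correct': False},
        {'correct_word': 'flour', 'predicted_letter': 'C', 'correct': False},
        {'error': 'timeout', 'correct': False},
    ]},
]

rates = error_rates_by_key(toy_results, 'word_identification', 'correct_word')
for word, rate in sorted(rates.items(), key=lambda kv: -kv[1]):
    print(f"{rate:.0%} error: {word}")
```

Grouping by `correct_word` merges different jokes that share the same pun word; a unique joke identifier would be a safer key if one is available.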

In [None]:
# Format correlation: Do models that do well on one format do well on others?
if len(results_df) > 0 and len(results_df['format'].unique()) > 1:
    format_pivot = results_df.pivot_table(
        index='model',
        columns='format',
        values='accuracy'
    )
    
    if len(format_pivot.columns) > 1:
        print("=== Format Correlation ===")
        print("\nAccuracy by model and format:")
        display(format_pivot)
        
        print("\nCorrelation between formats:")
        display(format_pivot.corr())
else:
    print("Need results from multiple formats to compute correlation.")
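
A caveat when reading the correlation table: with only two models, as in the default `MODELS_TO_TEST`, any pairwise Pearson correlation is either +1, -1, or undefined, regardless of the data. The sketch below demonstrates this with a hand-written Pearson function; the accuracy values are made up for illustration.

```python
import math

def pearson(xs, ys):
    """Pearson correlation; returns NaN when either series is constant."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sd_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    sd_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    if sd_x == 0 or sd_y == 0:
        return float('nan')
    return cov / (sd_x * sd_y)

# Hypothetical accuracies for two models on two formats
binary_acc = [0.9, 0.6]
word_id_acc = [0.5, 0.4]

print(pearson(binary_acc, word_id_acc))        # two points always lie on a line
print(pearson(binary_acc, word_id_acc[::-1]))  # reversed order flips the sign
```

Evaluating at least three or four models is needed before the correlation between formats says anything about the models themselves.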

## Part 6: Export Clean Dataset

Export the validated and filtered dataset along with all test formats for downstream use.

In [None]:
# Prepare export data
export_data = {
    'metadata': {
        'source': 'pun_evaluation_testing.ipynb',
        'validation_provider': VALIDATION_PROVIDER,
        'eval_provider': EVAL_PROVIDER,
        'validated_count': len(filtered_puns) if filtered_puns else len(dataset.get('puns', [])),
        'formats': ['binary', 'forced_choice', 'cloze', 'word_identification']
    },
    'validated_puns': filtered_puns if filtered_puns else dataset.get('puns', []),
    'test_formats': {
        'binary_classification': [asdict(ex) for ex in binary_dataset],
        'forced_choice': [asdict(ex) for ex in forced_choice_dataset],
        'cloze_completion': [asdict(ex) for ex in cloze_dataset],
        'word_identification': [asdict(ex) for ex in word_id_dataset]
    },
    'evaluation_results': all_results
}

# Save as JSON
with open('pun_evaluation_dataset.json', 'w') as f:
    json.dump(export_data, f, indent=2)
print("Saved pun_evaluation_dataset.json")

# Save results as CSV for easy analysis
if len(results_df) > 0:
    results_df.to_csv('pun_evaluation_results.csv', index=False)
    print("Saved pun_evaluation_results.csv")

In [None]:
# Also export individual format datasets as CSVs
if binary_dataset:
    binary_df = pd.DataFrame([asdict(ex) for ex in binary_dataset])
    binary_df.to_csv('pun_test_binary.csv', index=False)
    print(f"Saved pun_test_binary.csv ({len(binary_df)} examples)")

if forced_choice_dataset:
    fc_df = pd.DataFrame([asdict(ex) for ex in forced_choice_dataset])
    fc_df.to_csv('pun_test_forced_choice.csv', index=False)
    print(f"Saved pun_test_forced_choice.csv ({len(fc_df)} examples)")

if cloze_dataset:
    cloze_df = pd.DataFrame([asdict(ex) for ex in cloze_dataset])
    cloze_df.to_csv('pun_test_cloze.csv', index=False)
    print(f"Saved pun_test_cloze.csv ({len(cloze_df)} examples)")

if word_id_dataset:
    word_df = pd.DataFrame([{
        'joke': ex.joke,
        'option_a': ex.options[0] if len(ex.options) > 0 else '',
        'option_b': ex.options[1] if len(ex.options) > 1 else '',
        'option_c': ex.options[2] if len(ex.options) > 2 else '',
        'option_d': ex.options[3] if len(ex.options) > 3 else '',
        'correct_index': ex.correct_index,
        'correct_word': ex.correct_word
    } for ex in word_id_dataset])
    word_df.to_csv('pun_test_word_id.csv', index=False)
    print(f"Saved pun_test_word_id.csv ({len(word_df)} examples)")

In [None]:
# Download files (for Colab)
try:
    from google.colab import files
    
    # Zip all outputs
    !zip -r pun_evaluation_outputs.zip pun_evaluation_*.json pun_evaluation_*.csv pun_test_*.csv pun_accuracy_*.png 2>/dev/null || true
    files.download('pun_evaluation_outputs.zip')
except ImportError:
    print("Not in Colab - files saved to current directory.")

## Summary

In this notebook, we:

1. **Loaded** pun datasets from `pun_dataset_builder.ipynb`
2. **Validated** examples using cross-model verification (with the provider of your choice)
3. **Created multiple test formats:**
   - Binary classification (is this a pun?)
   - Forced choice (which is the pun?)
   - Cloze completion (fill in the pun word)
   - Word identification (which word creates the pun?)
4. **Evaluated** multiple models across different sizes and families
5. **Analyzed** results by model size and family, and identified the hardest examples

### Provider Recommendations

| Use Case | Recommended Provider | Why |
|----------|---------------------|-----|
| Validation (Part 2) | Gemini or Groq | Free tiers, capable models |
| Multi-model eval (Part 4) | Together.ai | Best model variety |
| Budget-friendly | Groq | Free tier, fast, good Llama models |
| Best quality | Anthropic/OpenAI | Frontier models |

### Connection to Course

This notebook demonstrates Week 3 concepts:
- **Model-Written Evals:** LLMs help create and validate evaluation data
- **LLM-as-Judge:** Cross-model validation reduces self-confirmation bias
- **LAMA-style probing:** Cloze format tests knowledge extraction
- **Scaling analysis:** How concept understanding varies with model size

### Next Steps

The exported datasets can be used in later weeks:
- **Week 4:** Visualize pun word representations in embedding space
- **Week 5:** Causal tracing to find where pun understanding happens
- **Week 6:** Train probes to detect pun processing in hidden states