# Qwen 2.5 3B Counting Benchmark + Mediation Analysis

This notebook is designed to run on Google Colab with a T4 GPU. It:
1. Benchmarks Qwen 2.5 3B on the counting task
2. Performs activation patching mediation analysis to identify which layers causally mediate count information

## Setup: Install Dependencies

In [None]:
# Install the libraries used below: transformers (model/tokenizer), accelerate
# (required for device_map="auto" when loading on GPU), torch, and tqdm
# (progress bar in the benchmark loop). -q keeps pip output quiet.
!pip install -q transformers accelerate torch tqdm

## Clone Repository (the dataset itself is loaded later, in the Run Benchmark cell)

In [None]:
# Clone the project and make it the working directory; every later relative
# path ('data/...', 'results/...', 'scripts') assumes the repo root as cwd.
# NOTE(review): on a re-run the clone target already exists, so `git clone`
# will fail, and `%cd` will look for a nested directory from inside the repo.
# Skip this cell when re-running.
!git clone https://github.com/Larsen-Daniel/llm-counting-analysis.git
%cd llm-counting-analysis

## Check GPU Availability

In [None]:
import torch

# Report whether a CUDA GPU is visible. If one is, print its name and total
# memory; otherwise warn that inference will fall back to the (slow) CPU.
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if not cuda_ok:
    print("WARNING: No GPU detected. Model will run on CPU (very slow).")
    print("Make sure you've enabled GPU in Runtime > Change runtime type")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## Benchmark Code

In [None]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict
from tqdm import tqdm
import re

def load_dataset(file_path: str) -> List[Dict]:
    """Load examples from a JSONL file (one JSON object per line).

    Blank or whitespace-only lines, such as an extra empty line at the end of
    the file, are skipped instead of being passed to ``json.loads`` (which
    would raise on an empty string). The file is decoded as UTF-8 explicitly,
    so the result does not depend on the platform's default locale encoding.

    Args:
        file_path: Path to the ``.jsonl`` file.

    Returns:
        The decoded objects, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def extract_answer(text: str) -> int | None:
    """
    Extract numerical answer from model output.
    Try multiple patterns in order of preference:
    1. Last number in parentheses: (N)
    2. First number on first line
    """
    # First try to find number in parentheses
    matches = re.findall(r'\((\d+)\)', text)
    if matches:
        return int(matches[-1])  # Return the last match

    # Fall back to first number in the output
    first_line = text.strip().split('\n')[0]
    number_match = re.search(r'\d+', first_line)
    if number_match:
        return int(number_match.group(0))

    return None  # No number found

def _load_model_and_tokenizer(model_name: str, device: str):
    """Load a causal LM plus tokenizer in eval mode, with the model on ``device``.

    Any CUDA device string ("cuda", "cuda:0", ...) loads float16 weights;
    anything else loads float32. Only the plain string "cuda" uses
    ``device_map="auto"`` (accelerate places the weights). Every other device
    is loaded without a device map and then moved explicitly, so the model
    always ends up where the inputs are sent.
    """
    on_cuda = str(device).startswith("cuda")
    device_map = "auto" if device == "cuda" else None

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if on_cuda else torch.float32,
        device_map=device_map,
        low_cpu_mem_usage=True,
    )
    if device_map is None:
        # Without a device map the weights stay on CPU; move them to `device`.
        model = model.to(device)

    model.eval()
    return model, tokenizer


def _rate(count: int, total: int) -> float:
    """Return ``count / total``, or 0.0 when ``total`` is 0 (empty evaluation)."""
    return count / total if total else 0.0


def benchmark_qwen(
    model_name: str,
    dataset: List[Dict],
    max_examples: int | None = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    max_new_tokens: int = 10,
) -> Dict:
    """
    Benchmark a Qwen (or any HF causal LM) model on the counting task.

    Each example's ``prompt`` is fed to the model as raw text (no chat
    template). The model decodes greedily, and the answer is parsed from the
    newly generated tokens with ``extract_answer``.

    Args:
        model_name: HuggingFace model name.
        dataset: List of examples; each needs ``"prompt"`` and ``"answer"`` keys.
        max_examples: Evaluate only the first ``max_examples`` examples.
            ``None`` (or 0) evaluates the whole dataset.
        device: Device to run on, e.g. "cuda", "cuda:0" or "cpu". The default
            is chosen once, when the function is defined.
        max_new_tokens: Generation budget per example.

    Returns:
        Dictionary with counters ``correct``, ``incorrect`` (== ``numerical_errors``),
        ``parse_errors``, ``api_errors`` (never incremented here; presumably kept
        for schema parity with other benchmarks), the per-example ``predictions``
        list, and the rates ``accuracy``, ``parse_error_rate``,
        ``numerical_error_rate`` and ``api_error_rate``. All rates are 0.0 when
        no examples are evaluated.
    """
    print(f"\n{'='*80}")
    print(f"Benchmarking: {model_name}")
    print(f"Device: {device}")
    print(f"{'='*80}\n")

    print("Loading model and tokenizer...")
    model, tokenizer = _load_model_and_tokenizer(model_name, device)

    if max_examples:
        dataset = dataset[:max_examples]

    results = {
        "model_id": model_name,
        "total_examples": len(dataset),
        "correct": 0,
        "incorrect": 0,
        "parse_errors": 0,
        "numerical_errors": 0,
        "api_errors": 0,
        "predictions": []
    }

    print(f"Evaluating on {len(dataset)} examples...")

    for example in tqdm(dataset):
        prompt = example["prompt"]
        true_answer = example["answer"]

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                # Pass the mask explicitly: Qwen's pad and EOS tokens can coincide,
                # and generate() cannot infer the mask reliably in that case.
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        generated_text = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )

        predicted_answer = extract_answer(generated_text)

        if predicted_answer is None:
            results["parse_errors"] += 1
            is_correct = False
        else:
            is_correct = predicted_answer == true_answer
            if is_correct:
                results["correct"] += 1
            else:
                results["numerical_errors"] += 1
                results["incorrect"] += 1

        results["predictions"].append({
            "prompt": prompt,
            "true_answer": true_answer,
            "predicted_answer": predicted_answer,
            "generated_text": generated_text,
            "correct": is_correct
        })

    total = results["total_examples"]
    results["accuracy"] = _rate(results["correct"], total)
    results["parse_error_rate"] = _rate(results["parse_errors"], total)
    results["numerical_error_rate"] = _rate(results["numerical_errors"], total)
    results["api_error_rate"] = _rate(results["api_errors"], total)

    print(f"\nResults for {model_name}:")
    print(f"  Accuracy: {results['accuracy']:.2%} ({results['correct']}/{results['total_examples']})")
    print(f"  Parse Errors: {results['parse_error_rate']:.2%}")
    print(f"  Numerical Errors: {results['numerical_error_rate']:.2%}")

    return results

## Run Benchmark

In [None]:
# Load dataset
dataset = load_dataset('data/counting_dataset.jsonl')
print(f"Loaded {len(dataset)} examples")

# Auto-detect device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Run benchmark on the first 1000 examples
results = benchmark_qwen(
    model_name="Qwen/Qwen2.5-3B-Instruct",
    dataset=dataset,
    max_examples=1000,
    device=device
)

# Save results. results/ may not exist in a fresh clone; create it up front
# so the save doesn't fail after the (slow) benchmark has already finished.
import os
os.makedirs('results', exist_ok=True)
with open('results/qwen_benchmark_results.json', 'w') as f:
    json.dump(results, f, indent=2)

print("\nResults saved to results/qwen_benchmark_results.json")

## Download Results

In [None]:
# Colab-only: `google.colab` is not importable outside Colab. Downloads the
# JSON written by the Run Benchmark cell to your local machine.
from google.colab import files
files.download('results/qwen_benchmark_results.json')

## Load Model for Mediation Analysis

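Run this cell to load the model and tokenizer into the notebook's global scope. It is required before the mediation analysis below even if you already ran the benchmark, because `benchmark_qwen` loads its own local copy of the model and does not keep it after returning.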

In [None]:
# Load model and tokenizer into globals (`model`, `tokenizer`, `device`) for
# the mediation analysis cell below.
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

model_name = "Qwen/Qwen2.5-3B-Instruct"
print(f"Loading {model_name}...")

on_gpu = device == "cuda"
# Half precision + automatic placement on GPU; full precision on CPU.
load_kwargs = dict(
    torch_dtype=torch.float16 if on_gpu else torch.float32,
    device_map="auto" if on_gpu else None,
    low_cpu_mem_usage=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

if not on_gpu:
    # No device map on CPU, so move the weights explicitly.
    model = model.to(device)

model.eval()
print("Model loaded successfully!")
print(f"Model device: {next(model.parameters()).device}")

In [None]:
# Pull the latest repo changes (e.g. edits to the `mediation_utils` module under
# scripts/). The module is reloaded further down in this cell so the new code
# takes effect without restarting the runtime.
!git pull origin main

# Make the repo's scripts/ directory importable. The path is relative to the
# repo root we %cd'd into during setup. Guard the append so re-running this
# cell doesn't keep adding duplicate sys.path entries.
import sys
if 'scripts' not in sys.path:
    sys.path.append('scripts')

import importlib
import mediation_utils
# A plain `import` returns the cached module on re-runs; reload so changes
# fetched by `git pull` are actually used.
importlib.reload(mediation_utils)

from mediation_utils import run_mediation_analysis
import random
import torch  # imported here too so this cell does not rely on earlier cells for it

# Auto-detect device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Generate simple synthetic pairs (no need to use dataset)
def generate_simple_pairs(n_pairs=500):
    """Generate simple 4-word lists with ~2 matching items (1-3 range)."""
    CATEGORIES = {
        "fruit": ["apple", "banana", "cherry", "grape", "orange"],
        "animal": ["dog", "cat", "bird", "fish", "horse"],
        "tool": ["hammer", "wrench", "saw", "drill", "pliers"],
    }
    NOISE = ["bowl", "window", "door", "cloud", "mountain", "table", "chair", "lamp"]
    
    pairs = []
    categories = list(CATEGORIES.keys())
    
    for i in range(n_pairs):
        category = categories[i % len(categories)]
        cat_words = CATEGORIES[category]
        
        # Randomly choose 1-3 matching items (averaging ~2)
        num_matching = random.randint(1, 3)
        matching = random.sample(cat_words, num_matching)
        non_matching = random.sample(NOISE, 4 - num_matching)
        
        # Create base list with the matching items
        base_list = matching + non_matching
        random.shuffle(base_list)
        
        # Low count: first word is noise
        low_list = base_list.copy()
        if low_list[0] in cat_words:
            # Swap first word with a noise word
            low_list[0] = random.choice([w for w in NOISE if w not in low_list])
        low_answer = sum(1 for w in low_list if w in cat_words)
        
        # High count: first word is matching
        high_list = base_list.copy()
        if high_list[0] not in cat_words:
            # Swap first word with a category word
            high_list[0] = random.choice([w for w in cat_words if w not in high_list])
        high_answer = sum(1 for w in high_list if w in cat_words)
        
        # Only keep pairs where they differ by exactly 1
        if high_answer - low_answer != 1:
            continue
        
        prompt_template = """Count how many words in the list below match the given type.

Type: {category}
List: {word_list}

IMPORTANT: Respond with ONLY a number in parentheses. Nothing else.

Example:
Type: fruit
List: apple door banana cloud
Answer: (2)

Do NOT write "2 (2)" or "The answer is (2)" or any other text.
ONLY write: (N)

Answer: """
        
        pair_low = {
            'prompt': prompt_template.format(category=category, word_list=' '.join(low_list)),
            'word_list': low_list,
            'category': category,
            'answer': low_answer
        }
        
        pair_high = {
            'prompt': prompt_template.format(category=category, word_list=' '.join(high_list)),
            'word_list': high_list,
            'category': category,
            'answer': high_answer
        }
        
        pairs.append((pair_low, pair_high))
    
    return pairs

# Generate 500 candidate pairs (will be filtered to 20 perfect ones)
print("\nGenerating 500 simple minimal pairs...")
pairs = generate_simple_pairs(n_pairs=500)
print(f"Generated {len(pairs)} valid pairs")

print(f"\nRunning mediation analysis (will filter to 20 perfect examples)...")
mediation_results = run_mediation_analysis(model, tokenizer, pairs, device=device)

# Save results. results/ may not exist in a fresh clone; create it so the
# save doesn't fail after the (slow) mediation run has already finished.
import json
import os
os.makedirs('results', exist_ok=True)
with open('results/mediation_results_qwen3b.json', 'w') as f:
    json.dump(mediation_results, f, indent=2)

# Display results: layers ranked by mean effect, highest first. int() handles
# layer keys given as strings or ints.
print("\n" + "="*80)
print("TOP 5 LAYERS BY MEAN EFFECT")
print("="*80)
effects = [(int(i), data['mean_effect']) for i, data in mediation_results['layer_effects'].items()]
effects.sort(key=lambda x: x[1], reverse=True)
for layer, effect in effects[:5]:
    print(f"  Layer {layer}: {effect:.3f}")

print("\n" + "="*80)
print("SUMMARY")
print("="*80)
print(f"Baseline accuracy: {mediation_results['baseline_accuracy']:.1%}")
print(f"Number of pairs tested: {mediation_results['n_pairs']}")

# Colab-only: download the saved mediation results to the local machine.
from google.colab import files
files.download('results/mediation_results_qwen3b.json')