# Theory of Mind Benchmark Suite

This notebook evaluates model performance across **four ToM benchmarks** to establish baselines for function vector research.

## Benchmarks

| Benchmark | Source | Format | Focus |
|-----------|--------|--------|-------|
| **ToMi** | Facebook Research | Short stories + questions | First/second-order false belief |
| **FANToM** | Allen AI | Multi-party conversations | Information asymmetry in dialogue |
| **SimpleToM** | Allen AI / HuggingFace | 2-sentence narratives | Explicit + applied ToM |
| **ToMBench** | Tsinghua/Chen et al. | Diverse social scenarios | 8 ToM tasks, 31 ATOMS abilities |

## Models Evaluated
- **Target model**: gpt-oss-20b (local)
- **Positive control**: Claude (via OpenRouter API)

## Contents
1. [Setup](#1-setup)
2. [Model Backends](#2-model-backends)
3. [Evaluation Functions](#3-evaluation-functions)
4. [ToMi Evaluation](#4-tomi)
5. [FANToM Evaluation](#5-fantom)
6. [SimpleToM Evaluation](#6-simpletom)
7. [ToMBench Evaluation](#7-tombench)
8. [Cross-Benchmark Comparison](#8-comparison)
9. [Save Results](#9-save)

---
## 1. Setup <a name="1-setup"></a>

In [1]:
# Environment configuration
import warnings
warnings.filterwarnings('ignore')

import os
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import transformers
transformers.logging.set_verbosity_error()

print("✓ Environment configured")

✓ Environment configured


In [2]:
# Core imports
import json
import re
import time
import gc
import subprocess
from abc import ABC, abstractmethod
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field
from tqdm import tqdm

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from rich.table import Table
from rich.console import Console
from rich import box

# Device setup
device = torch.device(
    "mps" if torch.backends.mps.is_available() 
    else "cuda" if torch.cuda.is_available() 
    else "cpu"
)
print(f"Using device: {device}")

console = Console()

Using device: cuda


In [3]:
# if MODELS_TO_EVALUATE:
#     del MODELS_TO_EVALUATE['local']
# gc.collect()
# torch.cuda.empty_cache()

# del all_tomi_results, all_fantom_results, all_simpletom_results, all_tombench_results
# gc.collect()

# del tom_dataset, no_tom_dataset, fantom_dataset, tombench_dataset
# gc.collect()

gc.collect()
if torch.cuda.is_available():
    # CUDA-only calls are guarded so this cell also runs on MPS/CPU
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    print(f"GPU Memory available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB total")
    print(f"GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

GPU Memory available: 47.53 GB total
GPU Memory allocated: 0.00 GB


In [4]:
# Install openai for OpenRouter API
try:
    import openai
    print("✓ openai package available")
except ImportError:
    print("Installing openai package...")
    !pip install openai -q
    import openai
    print("✓ openai package installed")

✓ openai package available


---
## 2. Model Backends <a name="2-model-backends"></a>

Abstracted model interface supporting both local HuggingFace models and API-based models (via OpenRouter).

In [5]:
class ModelBackend(ABC):
    """Abstract base class for model backends."""
    
    @property
    @abstractmethod
    def name(self) -> str:
        """Model identifier for logging."""
        pass
    
    @abstractmethod
    def generate(self, prompt: str, max_tokens: int = 500) -> Tuple[str, float, int, int]:
        """
        Generate response from model.
        
        Returns: (response_text, elapsed_time, input_tokens, output_tokens)
        """
        pass


class LocalModelBackend(ModelBackend):
    """Backend for local HuggingFace models."""
    
    def __init__(self, model_id: str, device: torch.device):
        self.model_id = model_id
        self.device = device
        
        print(f"Loading model: {model_id}")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
            local_files_only=True,
            low_cpu_mem_usage=True,
        )
        
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
        print("✓ Model loaded!")
    
    @property
    def name(self) -> str:
        return self.model_id.split('/')[-1]
    
    def format_prompt(self, user_message: str) -> str:
        """Apply the chat template to format prompts properly for the model."""
        messages = [{"role": "user", "content": user_message}]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    
    def generate(self, prompt: str, max_tokens: int = 500) -> Tuple[str, float, int, int]:
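        # Apply chat template to the prompt
        formatted_prompt = self.format_prompt(prompt)
        # The chat template already inserts the model's special tokens, so don't add them again
        inputs = self.tokenizer(
            formatted_prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.device)
        input_len = inputs["input_ids"].shape[1]
        
        # Some tokenizers define no pad token; fall back to EOS
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id
        
        start_time = time.time()
        with torch.no_grad():
            output = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_tokens,
                do_sample=False,
                pad_token_id=pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )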
        elapsed = time.time() - start_time
        
        response = self.tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
        output_len = output.shape[1] - input_len
        
        return response, elapsed, input_len, output_len


class OpenRouterBackend(ModelBackend):
    """Backend for OpenRouter API (Claude and other models)."""
    
    def __init__(self, model: str = "anthropic/claude-sonnet-4", api_key: Optional[str] = None):
        self.model = model
        self.client = openai.OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key or os.environ.get("OPENROUTER_API_KEY"),
        )
        print(f"✓ OpenRouter backend initialized: {model}")
    
    @property
    def name(self) -> str:
        return self.model.split("/")[-1]
    
    def generate(self, prompt: str, max_tokens: int = 500) -> Tuple[str, float, int, int]:
        start_time = time.time()
        
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                messages=[
                    {"role": "user", "content": prompt}
                ]
            )
            elapsed = time.time() - start_time
            
            text = response.choices[0].message.content if response.choices else ""
            input_tokens = response.usage.prompt_tokens if response.usage else 0
            output_tokens = response.usage.completion_tokens if response.usage else 0
            
            return text, elapsed, input_tokens, output_tokens
            
        except Exception as e:
            print(f"API Error: {e}")
            return f"ERROR: {e}", time.time() - start_time, 0, 0


print("✓ Model backend classes defined")

✓ Model backend classes defined
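Because `evaluate_tomi` only ever calls `backend.generate(...)`, any object with the same `name`/`generate` shape can stand in for a model. The sketch below is a hypothetical `ScriptedBackend` (not part of the original notebook) that replays canned responses, which makes it possible to dry-run prompt formatting, answer extraction, and scoring without loading gpt-oss-20b or spending OpenRouter credits. It duck-types the interface instead of subclassing `ModelBackend`, and its token counts are whitespace word counts, not real tokenizer counts.

```python
import time
from typing import List, Tuple


class ScriptedBackend:
    """Hypothetical offline backend that replays canned responses in order.

    Mirrors the ModelBackend interface: a `name` property and
    `generate(prompt, max_tokens)` returning
    (response_text, elapsed_time, input_tokens, output_tokens).
    """

    def __init__(self, responses: List[str], name: str = "scripted"):
        self._responses = list(responses)
        self._name = name
        self._calls = 0

    @property
    def name(self) -> str:
        return self._name

    def generate(self, prompt: str, max_tokens: int = 500) -> Tuple[str, float, int, int]:
        start = time.time()
        # Cycle through the canned responses so any number of calls works;
        # max_tokens is accepted for interface compatibility but ignored
        text = self._responses[self._calls % len(self._responses)]
        self._calls += 1
        # Whitespace word counts stand in for real token counts
        return text, time.time() - start, len(prompt.split()), len(text.split())


dry_run_backend = ScriptedBackend(["<think>Sally left.</think><answer>red_box</answer>"])
text, elapsed, n_in, n_out = dry_run_backend.generate("Where will Sally look for the ball?")
print(text, n_in, n_out)
```

For a dry run, passing `dry_run_backend` to `evaluate_tomi` with a small `max_examples` would exercise the full scoring path; since the replies are fixed, the resulting accuracy says nothing about any model.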


In [6]:
# ============================================================
# CONFIGURATION: Choose which models to evaluate
# ============================================================

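# Set your OpenRouter API key via the OPENROUTER_API_KEY environment variable
# (e.g. `export OPENROUTER_API_KEY=...` before launching Jupyter). Do not hardcode
# keys in the notebook: they leak whenever it is shared or committed.
if not os.environ.get("OPENROUTER_API_KEY"):
    print("⚠ OPENROUTER_API_KEY is not set; the Claude backend will fail to initialize")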

MODELS_TO_EVALUATE = {}

# Local model (loaded only if LOCAL_MODEL_ID exists on disk)
LOCAL_MODEL_ID = "../gpt-oss-20b/"
LOAD_LOCAL_MODEL = Path(LOCAL_MODEL_ID).exists()

if LOAD_LOCAL_MODEL:
    print("Loading local model...")
    MODELS_TO_EVALUATE['local'] = LocalModelBackend(LOCAL_MODEL_ID, device)
else:
    print(f"⚠ Local model not found at {LOCAL_MODEL_ID}")

# Claude via OpenRouter as positive control
LOAD_CLAUDE = True  # Set to False to skip Claude evaluation

if LOAD_CLAUDE:
    try:
        # Available Claude models on OpenRouter:
        # - anthropic/claude-sonnet-4 (Claude Sonnet 4)
        # - anthropic/claude-opus-4 (Claude Opus 4)
        # - anthropic/claude-3.5-sonnet (Claude 3.5 Sonnet)
        # - anthropic/claude-3-opus (Claude 3 Opus)
        MODELS_TO_EVALUATE['claude'] = OpenRouterBackend(model="anthropic/claude-sonnet-4")
    except Exception as e:
        print(f"⚠ Could not initialize OpenRouter backend: {e}")
        print("  Set OPENROUTER_API_KEY environment variable or pass api_key parameter")

print(f"\n✓ Models to evaluate: {list(MODELS_TO_EVALUATE.keys())}")

Loading local model...
Loading model: ../gpt-oss-20b/


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading tokenizer...
✓ Model loaded!
✓ OpenRouter backend initialized: anthropic/claude-sonnet-4

✓ Models to evaluate: ['local', 'claude']


---
## 3. Evaluation Functions <a name="3-evaluation-functions"></a>

Shared utilities for all benchmarks.

In [7]:
@dataclass
class EvalResult:
    """Single evaluation result."""
    idx: int
    prompt: str
    correct_answer: str
    model_response: str
    extracted_answer: str
    is_correct: bool
    time_sec: float
    input_tokens: int
    output_tokens: int
    metadata: Optional[Dict[str, Any]] = None
    
    def to_dict(self):
        return {
            'idx': self.idx,
            'prompt': self.prompt,
            'correct_answer': self.correct_answer,
            'model_response': self.model_response,
            'extracted_answer': self.extracted_answer,
            'is_correct': self.is_correct,
            'time_sec': self.time_sec,
            'input_tokens': self.input_tokens,
            'output_tokens': self.output_tokens,
            'metadata': self.metadata or {}
        }


def extract_answer_tags(response: str) -> str:
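    """
    Extract answer from <answer> tags, skipping placeholders like '??'.
    Falls back to location pattern or first word.
    
    Note: Section 4 redefines this function for ToMi; that version returns the
    LAST non-placeholder tag and only falls back on text after </think>.
    """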
    # Find all <answer> tags
    matches = re.findall(r'<answer>\s*([^<]+)\s*</answer>', response, re.IGNORECASE)
    
    # Return first non-placeholder answer
    for match in matches:
        content = match.strip().lower()
        if content and content not in ('??', '???', '?'):
            return content
    
    # Fallback: location pattern (word_word)
    match = re.search(r'\b(\w+_\w+)\b', response)
    if match:
        return match.group(1).lower()
    
    # Last resort: first word
    return response.strip().lower().split()[0] if response.strip() else ""


def extract_answer_choice(response: str, choices: List[str]) -> str:
    """
    Extract answer for multiple choice questions.
    Looks for choice letters (A, B, C, D) or exact choice text.
    """
    response_lower = response.lower().strip()
    
    # Check for letter answers like "A", "(A)", "A.", "A:" (a bare letter also counts)
    letter_match = re.search(r'^\s*\(?([a-d])\)?(?:[.:\s]|$)', response_lower)
    if letter_match:
        letter = letter_match.group(1).upper()
        idx = ord(letter) - ord('A')
        if idx < len(choices):
            return choices[idx]
    
    # Check if response starts with a choice
    for choice in choices:
        if response_lower.startswith(choice.lower()):
            return choice
    
    # Check if any choice appears in response
    for choice in choices:
        if choice.lower() in response_lower:
            return choice
    
    return response_lower.split()[0] if response_lower else ""


def extract_answer_binary(response: str) -> str:
    """
    Extract a yes/no answer; true/false are mapped to yes/no.
    """
    response_lower = response.lower().strip()
    
    # Prefer a whole-word answer at the start ("not" counts as no)
    lead = re.match(r'\W*(yes|no|not|true|false)\b', response_lower)
    if lead:
        return 'yes' if lead.group(1) in ('yes', 'true') else 'no'
    
    # Otherwise look near the start, matching whole words only (so "eyes" is not "yes")
    head = response_lower[:20]
    if re.search(r'\b(yes|true)\b', head):
        return 'yes'
    if re.search(r'\bfalse\b', head):
        return 'no'
    
    return response_lower.split()[0] if response_lower else ""


print("✓ Evaluation functions defined")

✓ Evaluation functions defined
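The two versions of `extract_answer_tags` differ in which `<answer>` tag wins: the version above returns the first non-placeholder tag, while the ToMi redefinition in Section 4 returns the last. The snippet below applies the same regex to a made-up response that echoes the format placeholder and names an intermediate location inside its reasoning, to show how that choice changes the extracted answer.

```python
import re

# Same pattern as extract_answer_tags: capture everything between the tags
ANSWER_RE = re.compile(r'<answer>\s*([^<]+)\s*</answer>', re.IGNORECASE)
PLACEHOLDERS = {'?', '??', '???'}

# Made-up model output: a placeholder, an intermediate guess, then the final answer
response = (
    "Format: <answer>??</answer>\n"
    "<think>Initially <answer>green_basket</answer>, but Sally saw the move.</think>\n"
    "<answer> Red_Box </answer>"
)

matches = ANSWER_RE.findall(response)
print(matches)

# First non-placeholder tag (Section 3 behaviour)
first_real = next(m.strip().lower() for m in matches if m.strip() not in PLACEHOLDERS)
# Last non-placeholder tag (Section 4 ToMi behaviour)
last_real = next(m.strip().lower() for m in reversed(matches) if m.strip() not in PLACEHOLDERS)
print(first_real, last_real)
```

Note that the greedy `[^<]+` keeps trailing whitespace inside the capture, which is why every candidate is `.strip()`-ed before comparison.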


In [8]:
def show_results_summary(results: List[EvalResult], title: str):
    """Display summary statistics for evaluation results."""
    if not results:
        print(f"No results for {title}")
        return
    
    n = len(results)
    correct = sum(1 for r in results if r.is_correct)
    accuracy = correct / n
    avg_time = sum(r.time_sec for r in results) / n
    total_time = sum(r.time_sec for r in results)
    
    table = Table(title=title, box=box.ROUNDED)
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="green")
    
    table.add_row("Accuracy", f"{accuracy:.1%}")
    table.add_row("Correct / Total", f"{correct}/{n}")
    table.add_row("Total Time", f"{total_time:.1f}s")
    table.add_row("Avg Time/Example", f"{avg_time:.2f}s")
    
    console.print(table)
    return {'accuracy': accuracy, 'correct': correct, 'total': n, 'time': total_time}


def save_results(
    results: List[EvalResult], 
    benchmark_name: str, 
    model_name: str,
    output_dir: str = "eval_outputs",
    save_json: bool = True,
    save_csv: bool = True
) -> Dict[str, str]:
    """
    Save evaluation results to JSON and/or CSV files.
    
    Args:
        results: List of EvalResult objects
        benchmark_name: Name of the benchmark (e.g., "tomi_tom", "fantom_belief")
        model_name: Name of the model being evaluated
        output_dir: Directory to save outputs
        save_json: Whether to save JSON format
        save_csv: Whether to save CSV format
    
    Returns:
        Dict with paths to saved files
    """
    import csv
    
    if not results:
        print(f"No results to save for {benchmark_name}")
        return {}
    
    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    # Generate filename with timestamp
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    base_name = f"{benchmark_name}_{model_name}_{timestamp}"
    
    saved_files = {}
    
    # Convert results to list of dicts
    results_data = [r.to_dict() for r in results]
    
    # Add summary stats
    accuracy = sum(1 for r in results if r.is_correct) / len(results)
    summary = {
        'benchmark': benchmark_name,
        'model': model_name,
        'timestamp': timestamp,
        'total_examples': len(results),
        'correct': sum(1 for r in results if r.is_correct),
        'accuracy': accuracy,
        'total_time_sec': sum(r.time_sec for r in results)
    }
    
    # Save JSON (includes full data)
    if save_json:
        json_path = output_path / f"{base_name}.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump({
                'summary': summary,
                'results': results_data
            }, f, indent=2, ensure_ascii=False)
        saved_files['json'] = str(json_path)
        print(f"  üíæ Saved JSON: {json_path}")
    
    # Save CSV (flattened for spreadsheet viewing)
    if save_csv:
        csv_path = output_path / f"{base_name}.csv"
        
        # Flatten the results for CSV
        fieldnames = [
            'idx', 'is_correct', 'correct_answer', 'extracted_answer',
            'model_response', 'prompt', 'time_sec', 'input_tokens', 'output_tokens'
        ]
        
        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='ignore')
            writer.writeheader()
            for r in results_data:
                # CSV rows omit the nested metadata dict (the JSON output keeps it)
                row = {k: v for k, v in r.items() if k != 'metadata'}
                writer.writerow(row)
        
        saved_files['csv'] = str(csv_path)
        print(f"  üíæ Saved CSV:  {csv_path}")
    
    return saved_files


def show_sample_results(
    results: List[EvalResult], 
    title: str, 
    n_correct: int = 3, 
    n_incorrect: int = 5,
    prompt_max_len: int = 200,
    response_max_len: int = 150
):
    """
    Display sample results for debugging, showing both correct and incorrect examples.
    
    Args:
        results: List of evaluation results
        title: Title for the display
        n_correct: Number of correct examples to show
        n_incorrect: Number of incorrect examples to show
        prompt_max_len: Max characters to show from prompt
        response_max_len: Max characters to show from response
    """
    if not results:
        print(f"No results for {title}")
        return
    
    correct_results = [r for r in results if r.is_correct]
    incorrect_results = [r for r in results if not r.is_correct]
    
    def truncate(text: str, max_len: int) -> str:
        text = text.replace('\n', ' ↵ ')
        if len(text) > max_len:
            return text[:max_len] + "..."
        return text
    
    # Show incorrect examples first (more important for debugging)
    if incorrect_results:
        table = Table(
            title=f"❌ {title} - INCORRECT Examples ({len(incorrect_results)} total)",
            box=box.ROUNDED,
            show_lines=True,
            width=120
        )
        table.add_column("#", style="dim", width=4)
        table.add_column("Expected", style="green", width=15)
        table.add_column("Extracted", style="red", width=15)
        table.add_column("Response", style="yellow", width=40)
        table.add_column("Prompt (end)", style="dim", width=40)
        
        for r in incorrect_results[:n_incorrect]:
            # Show end of prompt (usually contains the question)
            prompt_end = r.prompt[-prompt_max_len:] if len(r.prompt) > prompt_max_len else r.prompt
            table.add_row(
                str(r.idx),
                r.correct_answer[:15],
                r.extracted_answer[:15] if r.extracted_answer else "(empty)",
                truncate(r.model_response, response_max_len),
                truncate(prompt_end, prompt_max_len)
            )
        
        console.print(table)
    else:
        print(f"✓ {title}: No incorrect examples!")
    
    # Show a few correct examples for comparison
    if correct_results and n_correct > 0:
        table = Table(
            title=f"✓ {title} - CORRECT Examples (sample of {len(correct_results)} total)",
            box=box.SIMPLE,
            show_lines=True,
            width=120
        )
        table.add_column("#", style="dim", width=4)
        table.add_column("Expected", style="green", width=15)
        table.add_column("Extracted", style="green", width=15)
        table.add_column("Response", style="cyan", width=40)
        
        for r in correct_results[:n_correct]:
            table.add_row(
                str(r.idx),
                r.correct_answer[:15],
                r.extracted_answer[:15],
                truncate(r.model_response, response_max_len)
            )
        
        console.print(table)


def analyze_failure_patterns(results: List[EvalResult], title: str):
    """Analyze common failure patterns in results."""
    incorrect = [r for r in results if not r.is_correct]
    
    if not incorrect:
        print(f"✓ {title}: No failures to analyze!")
        return
    
    print(f"\n📊 {title} - Failure Analysis ({len(incorrect)} failures)")
    print("-" * 50)
    
    # Check for empty extractions
    empty_extractions = sum(1 for r in incorrect if not r.extracted_answer or r.extracted_answer in ('', '?', '??'))
    if empty_extractions:
        print(f"  • Empty/invalid extractions: {empty_extractions}")
    
    # Check for partial matches (one contains the other)
    partial_matches = sum(1 for r in incorrect 
                         if r.extracted_answer and r.correct_answer.lower() in r.extracted_answer.lower())
    if partial_matches:
        print(f"  • Partial matches (answer in extraction): {partial_matches}")
    
    reverse_partial = sum(1 for r in incorrect 
                         if r.extracted_answer and r.extracted_answer.lower() in r.correct_answer.lower())
    if reverse_partial:
        print(f"  • Partial matches (extraction in answer): {reverse_partial}")
    
    # Check for responses that contain correct answer but extraction failed
    answer_in_response = sum(1 for r in incorrect 
                            if r.correct_answer.lower() in r.model_response.lower())
    if answer_in_response:
        print(f"  • Correct answer in response but extraction failed: {answer_in_response}")
    
    # Show unique extracted answers for failures
    extracted_counts = {}
    for r in incorrect:
        ext = r.extracted_answer if r.extracted_answer else "(empty)"
        extracted_counts[ext] = extracted_counts.get(ext, 0) + 1
    
    if extracted_counts:
        print(f"\n  Top incorrect extractions:")
        for ext, count in sorted(extracted_counts.items(), key=lambda x: -x[1])[:5]:
            print(f"    '{ext}': {count}x")


def compute_accuracy(results: List[EvalResult]) -> float:
    if not results:
        return 0.0
    return sum(1 for r in results if r.is_correct) / len(results)


print("✓ Display, analysis, and save functions defined")

✓ Display, analysis, and save functions defined
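Each JSON file written by `save_results` stores a `summary` dict alongside the per-example `results`. A minimal sketch for reading one back (for example, when building the cross-benchmark comparison) and checking that the stored accuracy matches the rows. `load_saved_accuracy` is a hypothetical helper, and the toy file written here uses made-up values containing only the fields the helper reads.

```python
import json
import tempfile
from pathlib import Path


def load_saved_accuracy(json_path: Path) -> dict:
    """Read a save_results()-style JSON file and recompute accuracy from its rows."""
    with open(json_path, encoding='utf-8') as f:
        payload = json.load(f)
    rows = payload['results']
    recomputed = sum(1 for r in rows if r['is_correct']) / len(rows) if rows else 0.0
    return {
        'benchmark': payload['summary']['benchmark'],
        'model': payload['summary']['model'],
        'stored_accuracy': payload['summary']['accuracy'],
        'recomputed_accuracy': recomputed,
    }


# Toy file with the same top-level layout ('summary' + 'results'); values are made up
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "tomi_tom_demo.json"
    toy = {
        'summary': {'benchmark': 'tomi_tom', 'model': 'demo', 'accuracy': 0.5},
        'results': [{'idx': 0, 'is_correct': True}, {'idx': 1, 'is_correct': False}],
    }
    path.write_text(json.dumps(toy), encoding='utf-8')
    info = load_saved_accuracy(path)  # read before the temp dir is deleted

print(info)
```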


---
## 4. ToMi Evaluation <a name="4-tomi"></a>

**ToMi** (Theory of Mind Inventory) tests first-order and second-order false belief understanding through procedurally generated short stories.

- **ToM condition**: Questions require tracking a character's (false) belief
- **No-ToM condition**: Questions only require tracking what actually happened
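The difference between the two conditions comes down to whether a character witnessed the last move of an object. The toy tracker below illustrates that rule only: its event format is hypothetical (not the ToMi JSONL schema), and it simplifies ToMi's stories; for example, a character who enters later does not observe objects already in place.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical event format: (kind, args), where kind is "enter", "exit", or "move"
Event = Tuple[str, tuple]


def track_beliefs(events: List[Event]) -> Tuple[Dict[str, str], Dict[str, Dict[str, str]]]:
    """Return (true object locations, each agent's believed object locations).

    An agent updates its belief about an object only for moves it witnesses,
    i.e. moves that happen while it is present.
    """
    present: Set[str] = set()
    reality: Dict[str, str] = {}
    beliefs: Dict[str, Dict[str, str]] = {}
    for kind, args in events:
        if kind == "enter":
            (agent,) = args
            present.add(agent)
            beliefs.setdefault(agent, {})
        elif kind == "exit":
            (agent,) = args
            present.discard(agent)
        elif kind == "move":
            obj, loc = args
            reality[obj] = loc
            for agent in present:  # only witnesses update their belief
                beliefs[agent][obj] = loc
    return reality, beliefs


story = [
    ("enter", ("sally",)), ("enter", ("anne",)),
    ("move", ("ball", "green_basket")),
    ("exit", ("sally",)),  # Sally leaves before the second move
    ("move", ("ball", "red_box")),
]
reality, beliefs = track_beliefs(story)
print(reality["ball"])           # where the ball actually is
print(beliefs["sally"]["ball"])  # where Sally thinks it is (false belief)
print(beliefs["anne"]["ball"])   # Anne witnessed the move (true belief)
```

A question about Sally's belief is the ToM case, where belief and reality diverge; a question about Anne, who stayed, behaves like the true-belief control, where tracking what actually happened is enough.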

In [12]:
class ToMiDataset:
    """Dataset loader for ToMi benchmark."""
    
    def __init__(self, jsonl_path: str, size: int = None):
        self.jsonl_path = Path(jsonl_path)
        self.data = []
        
        with open(self.jsonl_path, 'r') as f:
            for i, line in enumerate(f):
                if size is not None and i >= size:
                    break
                self.data.append(json.loads(line))
        
        self.size = len(self.data)
    
    def __len__(self):
        return self.size
    
    def __getitem__(self, idx):
        return self.data[idx]
    
    def __repr__(self):
        return f"ToMiDataset(n={self.size})"


# Improved ToMi prompt - emphasizes perspective-taking and cleaner output format
TOMI_SYSTEM_PROMPT = """You are answering a question about what a CHARACTER BELIEVES, not what is actually true.

CRITICAL: The question asks where a character THINKS something is located.
- Characters ONLY know about events they WITNESSED
- If a character LEFT before an object was moved, they still believe it's in the ORIGINAL location
- Track what each character SAW, not what actually happened

First reason briefly in <think> tags (2-3 sentences max).
Then give your final answer in <answer> tags with ONLY the location name.

IMPORTANT: 
- Use <answer> tags exactly ONCE with your final answer
- Do NOT put ?? or placeholders in the answer tags
- Format: <answer>location_name</answer>

Example: <answer>blue_pantry</answer>"""


def normalize_tomi_answer(answer: str) -> str:
    """Normalize ToMi answer by converting underscores to spaces and lowercasing."""
    return answer.lower().replace('_', ' ').strip()


def extract_answer_tags(response: str) -> str:
    """
    Extract answer from <answer> tags, skipping placeholders like '??'.
    Falls back to location pattern or first word.
    """
    # Find all <answer> tags
    matches = re.findall(r'<answer>\s*([^<]+)\s*</answer>', response, re.IGNORECASE)
    
    # Return LAST non-placeholder answer (model often puts real answer at end)
    for match in reversed(matches):
        content = match.strip().lower()
        if content and content not in ('??', '???', '?', '', 'blank', '...'):
            return content
    
    # Fallback: look for location pattern ONLY after </think> if present
    think_end = response.lower().find('</think>')
    search_text = response[think_end:] if think_end >= 0 else response
    
    # Look for location pattern (word_word)
    match = re.search(r'\b([a-z]+_[a-z]+)\b', search_text.lower())
    if match:
        return match.group(1)
    
    # Last resort: first word after </think>
    if think_end >= 0:
        after_think = response[think_end + len('</think>'):].strip()
        first_word = after_think.split()[0] if after_think.split() else ""
        return first_word.lower()
    
    return response.strip().lower().split()[0] if response.strip() else ""


def evaluate_tomi(
    dataset: ToMiDataset, 
    backend: ModelBackend,
    desc: str = "ToMi",
    max_examples: int = None
) -> List[EvalResult]:
    """Evaluate model on ToMi dataset."""
    results = []
    n = min(len(dataset), max_examples) if max_examples else len(dataset)
    
    for i in tqdm(range(n), desc=desc):
        example = dataset[i]
        prompt = f"{TOMI_SYSTEM_PROMPT}\n\nStory:\n{example['prompt']}"
        
        response, elapsed, in_tok, out_tok = backend.generate(prompt)
        extracted = extract_answer_tags(response)
        
        # Normalize both answers to handle underscore vs space differences
        is_correct = normalize_tomi_answer(extracted) == normalize_tomi_answer(example['answer'])
        
        results.append(EvalResult(
            idx=i,
            prompt=prompt,
            correct_answer=example['answer'],
            model_response=response,
            extracted_answer=extracted,
            is_correct=is_correct,
            time_sec=elapsed,
            input_tokens=in_tok,
            output_tokens=out_tok,
            metadata={'story_type': example.get('story_type'), 'question_type': example.get('question_type')}
        ))
    
    return results


print("✓ ToMi evaluation functions defined")

✓ ToMi evaluation functions defined


In [13]:
# Configure ToMi paths - adjust as needed
TOMI_DIR = Path('../tom_benchmarks/tomi/tomi_pairs')

# Check if data exists
if TOMI_DIR.exists():
    print(f"✓ ToMi data found at {TOMI_DIR}")
    !ls -la {TOMI_DIR}/*.jsonl | head -10
else:
    print(f"⚠ ToMi data not found at {TOMI_DIR}")
    print("  Run the ToMi extractor script first, or adjust TOMI_DIR path")

✓ ToMi data found at ../tom_benchmarks/tomi/tomi_pairs
-rw-r--r-- 1 root root 1599533 Jan 22 14:41 ../tom_benchmarks/tomi/tomi_pairs/all_no_tom.jsonl
-rw-r--r-- 1 root root  471563 Jan 22 14:41 ../tom_benchmarks/tomi/tomi_pairs/all_tom.jsonl
-rw-r--r-- 1 root root  468215 Jan 22 14:41 ../tom_benchmarks/tomi/tomi_pairs/first_order_0_no_tom.jsonl
-rw-r--r-- 1 root root  452339 Jan 22 14:41 ../tom_benchmarks/tomi/tomi_pairs/first_order_0_no_tom_prompts.jsonl
-rw-r--r-- 1 root root  285728 Jan 22 14:41 ../tom_benchmarks/tomi/tomi_pairs/first_order_1_no_tom.jsonl
-rw-r--r-- 1 root root  275842 Jan 22 14:41 ../tom_benchmarks/tomi/tomi_pairs/first_order_1_no_tom_prompts.jsonl
-rw-r--r-- 1 root root  180769 Jan 22 14:41 ../tom_benchmarks/tomi/tomi_pairs/first_order_1_tom.jsonl
-rw-r--r-- 1 root root  174779 Jan 22 14:41 ../tom_benchmarks/tomi/tomi_pairs/first_order_1_tom_prompts.jsonl
-rw-r--r-- 1 root root  362752 Jan 22 14:41 ../tom_benchmarks/tomi/tomi_pairs/second_order_0_no_tom.jsonl
-r

In [14]:
# Run ToMi evaluation
N_TOMI = 100  # Examples per condition

all_tomi_results = {}  # model_name -> {tom: [...], no_tom: [...]}

if TOMI_DIR.exists():
    # Load datasets
    tom_file = TOMI_DIR / 'first_order_1_tom_prompts.jsonl'
    no_tom_file = TOMI_DIR / 'first_order_1_no_tom_prompts.jsonl'
    
    if tom_file.exists() and no_tom_file.exists():
        tom_dataset = ToMiDataset(tom_file, size=N_TOMI)
        no_tom_dataset = ToMiDataset(no_tom_file, size=N_TOMI)
        
        for model_name, backend in MODELS_TO_EVALUATE.items():
            print(f"\n{'='*60}")
            print(f"Evaluating ToMi with {model_name.upper()} ({N_TOMI} examples per condition)")
            print(f"{'='*60}\n")
            
            all_tomi_results[model_name] = {}
            all_tomi_results[model_name]['tom'] = evaluate_tomi(tom_dataset, backend, f"ToMi-ToM ({model_name})", N_TOMI)
            all_tomi_results[model_name]['no_tom'] = evaluate_tomi(no_tom_dataset, backend, f"ToMi-NoToM ({model_name})", N_TOMI)
            
            print(f"\n{model_name.upper()} ToMi RESULTS:")
            tom_stats = show_results_summary(all_tomi_results[model_name]['tom'], f"ToMi: ToM (false belief) - {model_name}")
            print()
            no_tom_stats = show_results_summary(all_tomi_results[model_name]['no_tom'], f"ToMi: No-ToM (true belief) - {model_name}")
            
            if tom_stats and no_tom_stats:
                gap = no_tom_stats['accuracy'] - tom_stats['accuracy']
                print(f"\nüìä ToM Gap: {gap:+.1%} (expected: No-ToM > ToM)")
            
            # Save results to files
            print(f"\nüìÅ Saving results...")
            save_results(all_tomi_results[model_name]['tom'], "tomi_tom", model_name)
            save_results(all_tomi_results[model_name]['no_tom'], "tomi_no_tom", model_name)
            
            # Show sample results for debugging
            print(f"\n{'─'*60}")
            print(f"SAMPLE RESULTS FOR DEBUGGING - {model_name.upper()}")
            print(f"{'─'*60}")
            show_sample_results(all_tomi_results[model_name]['tom'], f"ToMi-ToM ({model_name})")
            analyze_failure_patterns(all_tomi_results[model_name]['tom'], f"ToMi-ToM ({model_name})")
            print()
            show_sample_results(all_tomi_results[model_name]['no_tom'], f"ToMi-NoToM ({model_name})")
            analyze_failure_patterns(all_tomi_results[model_name]['no_tom'], f"ToMi-NoToM ({model_name})")
    else:
        print(f"⚠ ToMi prompt files not found. Available files:")
        !ls {TOMI_DIR}
else:
    print("⏭ Skipping ToMi (data not found)")


Evaluating ToMi with LOCAL (100 examples per condition)



ToMi-ToM (local): 100%|██████████| 100/100 [32:57<00:00, 19.78s/it]
ToMi-NoToM (local):  37%|███▋      | 37/100 [12:06<20:37, 19.65s/it]


KeyboardInterrupt: 
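With `N_TOMI = 100` examples per condition, each accuracy carries substantial sampling noise, which matters when reading the ToM gap. As a hedged sketch (the Wilson score interval is a standard binomial confidence interval; the original notebook does not compute it), the helper below returns an approximate 95% interval for `correct` out of `total`. At 50/100 correct it spans roughly 40% to 60%, so small gaps between conditions should be read with caution.

```python
import math
from typing import Tuple


def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]:
    """Wilson score interval for a binomial proportion (z = 1.96 gives ~95%)."""
    if total == 0:
        return (0.0, 0.0)
    p = correct / total
    denom = 1 + z**2 / total
    center = (p + z**2 / (2 * total)) / denom
    half = z * math.sqrt(p * (1 - p) / total + z**2 / (4 * total**2)) / denom
    return (max(0.0, center - half), min(1.0, center + half))


ci_low, ci_high = wilson_interval(50, 100)
print(f"50/100 correct -> 95% CI [{ci_low:.3f}, {ci_high:.3f}]")
```

The per-condition counts are available from `show_results_summary(...)['correct']` and `['total']`.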

---
## 5. FANToM Evaluation <a name="5-fantom"></a>

**FANToM** (False-belief ANd Theory of Mind) tests ToM in multi-party **conversations** with information asymmetry.

Characters join/leave conversations, creating natural false beliefs about shared information.

Question types:
- **BeliefQ**: What does character X believe?
- **AnswerabilityQ**: Can character X answer question Y?
- **InfoAccessQ**: Who has access to information Z?
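Information asymmetry arises because a character only has access to what was said while they were in the conversation. A toy sketch of that rule, assuming a hypothetical utterance format of (speaker, text, characters present) and using keyword matching as a crude stand-in for FANToM's annotated information pieces; it is not FANToM's data schema or scoring method.

```python
from typing import List, Set, Tuple

# Hypothetical format: (speaker, text, set of characters present when it was said)
Utterance = Tuple[str, str, Set[str]]


def who_knows(conversation: List[Utterance], keyword: str) -> Set[str]:
    """Characters present for at least one utterance that mentions `keyword`."""
    informed: Set[str] = set()
    for speaker, text, present in conversation:
        if keyword.lower() in text.lower():
            informed |= present | {speaker}
    return informed


conversation = [
    ("Ana", "I'm adopting a dog next week!", {"Ana", "Ben", "Cy"}),
    ("Cy", "I have to go, see you later.", {"Ana", "Ben", "Cy"}),
    ("Ana", "The dog is a greyhound named Pixel.", {"Ana", "Ben"}),  # Cy has left
]

# InfoAccess-style question: who heard about the breed?
print(sorted(who_knows(conversation, "greyhound")))
# Answerability-style question: can Cy say what breed Ana's dog is?
print("Cy" in who_knows(conversation, "greyhound"))
```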

In [None]:
# Download FANToM if not present
# FANToM data is downloaded via their dataset_loader.py script
import sys
FANTOM_DIR = Path('fantom')
FANTOM_DATA_DIR = FANTOM_DIR / 'task' / 'data' / 'fantom'

if not (FANTOM_DATA_DIR / 'fantom_v1.json').exists():
    print("Downloading FANToM benchmark data...")
    # Add the task directory to path and use their loader
    sys.path.insert(0, str(FANTOM_DIR / 'task'))
    try:
        from dataset_loader import load as load_fantom
        fantom_df = load_fantom()  # This downloads and returns DataFrame
        print(f"✓ FANToM downloaded: {len(fantom_df)} conversations")
    except Exception as e:
        print(f"⚠ Could not download FANToM: {e}")
        print("  Try manually: cd fantom/task && python dataset_loader.py")
    finally:
        sys.path.pop(0)
else:
    print(f"✓ FANToM data exists at {FANTOM_DATA_DIR}")

In [None]:
class FANToMDataset:
    """Dataset loader for FANToM benchmark."""
    
    def __init__(self, json_path: str, size: int = None, use_short_context: bool = True):
        """
        Args:
            json_path: Path to fantom_v1.json
            size: Max items to load (None for all)
            use_short_context: Use short_context (True) or full_context (False)
        """
        self.json_path = Path(json_path)
        self.use_short_context = use_short_context
        
        # Load as DataFrame then convert to list of dicts
        import pandas as pd
        df = pd.read_json(self.json_path)
        
        self.data = []
        for _, row in df.iterrows():
            self.data.append(row.to_dict())
            if size and len(self.data) >= size:
                break
        
        self.size = len(self.data)
    
    def __len__(self):
        return self.size
    
    def __getitem__(self, idx):
        return self.data[idx]


FANTOM_SYSTEM_PROMPT = """You are answering questions about a conversation. Pay attention to who was present when information was shared.

Answer the question directly and concisely."""


import random


def format_fantom_prompt(item: dict, question_type: str = 'belief', use_short_context: bool = True) -> Tuple[str, str, List[str]]:
    """
    Format a FANToM item into a binary-choice prompt.
    
    Args:
        item: Data item from FANToM dataset
        question_type: 'belief' for the first of beliefQAs, 'fact' for factQA
        use_short_context: Use short_context (True) or full_context (False)
    
    Returns: (prompt, correct_letter, choices) or (None, None, None) if no usable question.
    Option order is shuffled deterministically (seeded on the question text),
    so the correct letter is not always 'A'.
    """
    # Get conversation context
    context_key = 'short_context' if use_short_context else 'full_context'
    conversation = item.get(context_key, '')
    
    if question_type == 'belief':
        # beliefQAs is a list of question dicts; take the first one
        belief_qas = item.get('beliefQAs')
        q_data = belief_qas[0] if isinstance(belief_qas, list) and belief_qas else None
    elif question_type == 'fact':
        # factQA is a single dict
        q_data = item.get('factQA')
    else:
        return None, None, None
    
    # Missing fields can come back from pandas as NaN rather than an empty container
    if not isinstance(q_data, dict):
        return None, None, None
    
    question = q_data.get('question', '')
    correct = q_data.get('correct_answer', '')
    wrong = q_data.get('wrong_answer', '')
    if not question or not correct or not wrong:
        return None, None, None  # a binary choice needs both options
    
    # A fixed 'A' key would let a model with a letter-position bias score well
    # without reading the conversation, so put the correct option first or
    # second with equal probability, reproducibly per question.
    if random.Random(question).random() < 0.5:
        choices, correct_letter = [correct, wrong], 'A'
    else:
        choices, correct_letter = [wrong, correct], 'B'
    
    prompt = f"{FANTOM_SYSTEM_PROMPT}\n\nConversation:\n{conversation}\n\n"
    prompt += f"Question: {question}\n\n"
    prompt += f"A. {choices[0]}\n"
    prompt += f"B. {choices[1]}\n"
    prompt += "\nAnswer with the letter (A or B):"
    
    return prompt, correct_letter, choices


def extract_fantom_answer(response: str) -> str:
    """Extract 'A' or 'B' from a response; return '' if there is no clear choice."""
    text = response.strip()
    
    # Letter at the very start as its own token: "A", "B.", "(A)", "B) ..."
    # The word boundary keeps names such as "Alice" or "Bob" from matching.
    match = re.match(r'\(?([AB])\b', text)
    if match:
        return match.group(1)
    
    # Explicit statements: "Answer: A", "The answer is (B)", "Final answer: A"
    match = re.search(r'ANSWER\s*(?:IS|:)\s*\(?([AB])\b', text.upper())
    if match:
        return match.group(1)
    
    # A parenthesized letter anywhere: "... so (A)"
    match = re.search(r'\(([AB])\)', text)
    if match:
        return match.group(1)
    
    # No unambiguous choice: leave empty (scored as incorrect) instead of
    # guessing from the first 'A' or 'B' character, which matches inside words
    return ""


def evaluate_fantom(
    dataset: FANToMDataset,
    backend: ModelBackend,
    question_type: str = 'belief',
    max_examples: int = None
) -> List[EvalResult]:
    """Evaluate model on FANToM dataset."""
    results = []
    n = min(len(dataset), max_examples) if max_examples else len(dataset)
    
    for i in tqdm(range(n), desc=f"FANToM-{question_type}"):
        item = dataset[i]
        prompt, correct, choices = format_fantom_prompt(item, question_type, dataset.use_short_context)
        
        if prompt is None:
            continue
        
        response, elapsed, in_tok, out_tok = backend.generate(prompt, max_tokens=50)
        
        extracted = extract_fantom_answer(response)
        is_correct = extracted == correct
        
        results.append(EvalResult(
            idx=i,
            prompt=prompt,
            correct_answer=correct,
            model_response=response,
            extracted_answer=extracted,
            is_correct=is_correct,
            time_sec=elapsed,
            input_tokens=in_tok,
            output_tokens=out_tok,
            metadata={'question_type': question_type}
        ))
    
    return results


print("✓ FANToM evaluation functions defined")

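With only two options, the answer key itself is worth checking for position bias: if every gold answer is 'A', a model that always outputs 'A' scores 100% without reading the conversation. A small sketch (the helper name is hypothetical) that computes the accuracy of each constant-letter baseline from a list of gold letters:

```python
from collections import Counter
from typing import Dict, List


def constant_guess_accuracy(gold_letters: List[str]) -> Dict[str, float]:
    """Accuracy of a baseline that outputs the same letter for every item.

    Returns one entry per letter that appears in the answer key. For a
    balanced two-way key every entry should sit near 0.5.
    """
    n = len(gold_letters)
    if n == 0:
        return {}
    counts = Counter(gold_letters)
    return {letter: counts[letter] / n for letter in sorted(counts)}


# Hypothetical answer keys for 30 binary items
fixed_order = ['A'] * 30          # gold answer always listed first
balanced = ['A', 'B'] * 15        # gold answer split evenly between positions
print(constant_guess_accuracy(fixed_order))  # {'A': 1.0}
print(constant_guess_accuracy(balanced))     # {'A': 0.5, 'B': 0.5}
```

On real runs, `[r.correct_answer for r in results]` gives the gold letters; a baseline far above chance means the key, not the model, may be driving the score.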
In [None]:
# Run FANToM evaluation
N_FANTOM = 30

all_fantom_results = {}  # model_name -> {belief: [...], fact: [...]}

# FANToM data file path
fantom_data_file = FANTOM_DATA_DIR / 'fantom_v1.json'

if fantom_data_file.exists():
    print(f"Loading FANToM from {fantom_data_file}")
    fantom_dataset = FANToMDataset(fantom_data_file, size=N_FANTOM)
    print(f"Loaded {len(fantom_dataset)} conversations")
    
    for model_name, backend in MODELS_TO_EVALUATE.items():
        print(f"\n{'='*60}")
        print(f"Evaluating FANToM with {model_name.upper()}")
        print(f"{'='*60}\n")
        
        all_fantom_results[model_name] = {}
        all_fantom_results[model_name]['belief'] = evaluate_fantom(fantom_dataset, backend, 'belief', N_FANTOM)
        all_fantom_results[model_name]['fact'] = evaluate_fantom(fantom_dataset, backend, 'fact', N_FANTOM)
        
        print(f"\n{model_name.upper()} FANToM RESULTS:")
        show_results_summary(all_fantom_results[model_name]['belief'], f"FANToM: Belief Questions - {model_name}")
        print()
        show_results_summary(all_fantom_results[model_name]['fact'], f"FANToM: Fact Questions - {model_name}")
        
        # Save results to files
        print(f"\n📁 Saving results...")
        save_results(all_fantom_results[model_name]['belief'], "fantom_belief", model_name)
        save_results(all_fantom_results[model_name]['fact'], "fantom_fact", model_name)
        
        # Show sample results for debugging
        print(f"\n{'─'*60}")
        print(f"SAMPLE RESULTS FOR DEBUGGING - {model_name.upper()}")
        print(f"{'─'*60}")
        show_sample_results(all_fantom_results[model_name]['belief'], f"FANToM-Belief ({model_name})")
        analyze_failure_patterns(all_fantom_results[model_name]['belief'], f"FANToM-Belief ({model_name})")
        print()
        show_sample_results(all_fantom_results[model_name]['fact'], f"FANToM-Fact ({model_name})")
        analyze_failure_patterns(all_fantom_results[model_name]['fact'], f"FANToM-Fact ({model_name})")
else:
    print(f"⏭ Skipping FANToM (data not found at {fantom_data_file})")
    print("  Run the download cell above first, or manually: cd fantom/task && python dataset_loader.py")

---
## 6. SimpleToM Evaluation <a name="6-simpletom"></a>

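**SimpleToM** provides minimal, controlled ToM evaluation through brief 2-sentence narratives.

- **Explicit ToM** (`mental-state-qa`): direct questions about a character's awareness, e.g. "Is X aware that...?"
- **Applied ToM** (`behavior-qa`): questions that require using that awareness to predict behavior, e.g. "What will X do next?"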

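The explicit/applied split is most informative when both question types are compared on the same stories: a model may correctly say a character is unaware of something and still predict that character's behavior as if they knew it. A minimal sketch of that conditional accuracy; the helper is hypothetical, and deriving a shared story key from the dataset's `id` field is an assumption left to the caller.

```python
from typing import Dict, Optional


def applied_given_explicit(explicit: Dict[str, bool], applied: Dict[str, bool]) -> Optional[float]:
    """Applied-ToM accuracy restricted to stories whose explicit (mental-state)
    question was answered correctly. Returns None when no such pairs exist.

    Both dicts map a story key -> is_correct (hypothetical pairing key).
    """
    paired = [story for story, ok in explicit.items() if ok and story in applied]
    if not paired:
        return None
    return sum(applied[story] for story in paired) / len(paired)


# Toy results: the model is right about awareness in s1, s2, s4 ...
explicit = {"s1": True, "s2": True, "s3": False, "s4": True}
# ... but uses that awareness correctly for behavior only in s1
applied = {"s1": True, "s2": False, "s3": True, "s4": False}
print(applied_given_explicit(explicit, applied))  # 0.3333333333333333
```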
In [16]:
# Load SimpleToM from HuggingFace
# SimpleToM has multiple configs: 'mental-state-qa', 'behavior-qa', 'judgment-qa', 'story-data'
try:
    from datasets import load_dataset
    
    # Load the mental-state-qa config (core ToM task)
    simpletom_mental = load_dataset("allenai/SimpleToM", "mental-state-qa", split="test")
    print(f"✓ SimpleToM mental-state-qa loaded: {len(simpletom_mental)} examples")
    print(f"  Columns: {simpletom_mental.column_names}")
    
    # Optionally load behavior-qa (applied ToM)
    simpletom_behavior = load_dataset("allenai/SimpleToM", "behavior-qa", split="test")
    print(f"✓ SimpleToM behavior-qa loaded: {len(simpletom_behavior)} examples")
    
except Exception as e:
    print(f"⚠ Could not load SimpleToM: {e}")
    simpletom_mental = None
    simpletom_behavior = None

✓ SimpleToM mental-state-qa loaded: 1147 examples
  Columns: ['id', 'story', 'question', 'scenario_name', 'choices', 'answerKey']
✓ SimpleToM behavior-qa loaded: 1147 examples


In [20]:
SIMPLETOM_SYSTEM_PROMPT = """Answer the question about the story. Choose the best answer from the options provided.

You may reason through the problem, but you MUST end your response with ONLY the letter of your final answer on the last line.

Format your response like this:
[Your reasoning here if needed]

Final Answer: [Letter]

Answer with ONLY the letter (A or B) at the end."""


def format_simpletom_prompt(example: dict) -> Tuple[str, str, List[str]]:
    """
    Format a SimpleToM example into a prompt.
    
    Args:
        example: Dataset example with 'story', 'question', 'choices', 'answerKey'
    
    Returns: (prompt, correct_answer_letter, choice_texts)
    """
    story = example['story']
    question = example['question']
    choices = example['choices']
    correct_key = example['answerKey']
    
    # choices is a dict with 'text' and 'label' lists
    choice_texts = choices['text']
    choice_labels = choices['label']
    
    prompt = f"{SIMPLETOM_SYSTEM_PROMPT}\n\nStory: {story}\n\nQuestion: {question}\n\n"
    for label, text in zip(choice_labels, choice_texts):
        prompt += f"{label}. {text}\n"
    prompt += "\nAnswer:"
    
    return prompt, correct_key, choice_texts


def extract_simpletom_answer(response: str, valid_labels: Optional[List[str]] = None) -> str:
    """
    Extract the answer letter from a response.
    Optimized for reasoning models that provide explanations before answering:
    searches from the END of the response backwards.
    Bare letters are matched case-sensitively, so the article "a" (or the last
    letter of a word like "idea") is never read as an answer.
    Returns "" if no valid label is found.
    """
    labels = valid_labels or ['A', 'B']
    label_class = '[' + ''.join(labels) + ']'
    response_clean = response.strip()
    
    # Strategy 1: the last character is a bare label, e.g. "...\nB" or "... so: B"
    if response_clean and response_clean[-1] in labels:
        if len(response_clean) == 1 or not response_clean[-2].isalnum():
            return response_clean[-1]
    
    # Strategy 2: check the last 5 lines, newest first
    for line in reversed(response_clean.split('\n')[-5:]):
        line_stripped = line.strip()
        
        # The line is just the letter, possibly bolded, bracketed or punctuated ("**B**", "[A]")
        bare = line_stripped.strip('*.:()[] ')
        if bare in labels:
            return bare
        
        # "Final Answer: A", "Answer: A", "the answer is A" (keyword matched case-insensitively)
        match = re.search(rf'ANSWER\s*(?:IS|:)\s*\**\(?({label_class})\b', line_stripped.upper())
        if match:
            return match.group(1)
        
        # The line ends with a standalone label, e.g. "So I choose B"
        words = line_stripped.split()
        if words and words[-1].strip('*.:()[]') in labels:
            return words[-1].strip('*.:()[]')
    
    # Strategy 3: LAST standalone label anywhere. The lookahead skips option
    # quotes such as "A.", and searching the original case skips the article "a".
    matches = list(re.finditer(rf'(?<!\w)({label_class})(?![\w.])', response_clean))
    if matches:
        return matches[-1].group(1)
    
    # Fallback: no valid label found
    return ""


def evaluate_simpletom(
    dataset,
    backend: ModelBackend,
    max_examples: int,
    desc: str
) -> List[EvalResult]:
    """Evaluate model on SimpleToM dataset."""
    results = []
    
    for i in tqdm(range(min(len(dataset), max_examples)), desc=desc):
        example = dataset[i]
        prompt, correct, choices = format_simpletom_prompt(example)
        
        response, elapsed, in_tok, out_tok = backend.generate(prompt, max_tokens=600)
        
        extracted = extract_simpletom_answer(response)
        is_correct = extracted == correct
        
        results.append(EvalResult(
            idx=i,
            prompt=prompt,
            correct_answer=correct,
            model_response=response,
            extracted_answer=extracted,
            is_correct=is_correct,
            time_sec=elapsed,
            input_tokens=in_tok,
            output_tokens=out_tok,
            metadata={
                'scenario': example.get('scenario_name', ''),
                'id': example.get('id', '')
            }
        ))
    
    return results


print("✓ SimpleToM evaluation functions defined")

✓ SimpleToM evaluation functions defined


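Because the SimpleToM prompt asks for a `Final Answer:` line, a stricter alternative to the heuristic fallbacks in `extract_simpletom_answer` is to accept only that marker and take the last one. Self-corrections are then honored, and unparseable responses can be counted separately instead of guessed. A minimal sketch under that assumption (`last_final_answer` is a hypothetical helper):

```python
import re
from typing import Optional

# Accepts "Final Answer: B", "final answer:b", "FINAL ANSWER: **A**"
FINAL_ANSWER_RE = re.compile(r"final\s*answer\s*:\s*\**\s*([AB])\b", re.IGNORECASE)


def last_final_answer(response: str) -> Optional[str]:
    """Letter from the LAST 'Final Answer:' marker, or None if there is none.

    Taking the last match honors self-corrections; returning None (instead of
    guessing) lets unparseable responses be reported as a separate category.
    """
    matches = FINAL_ANSWER_RE.findall(response)
    return matches[-1].upper() if matches else None


draft = "Final Answer: A\nWait, Mary never saw the label.\nFinal Answer: B"
print(last_final_answer(draft))                    # B
print(last_final_answer("Mary is unaware, so B"))  # None
```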
In [None]:
# Run SimpleToM evaluation
N_SIMPLETOM = 100

all_simpletom_results = {}  # model_name -> {mental_state: [...], behavior: [...]}

if simpletom_mental is not None:
    for model_name, backend in MODELS_TO_EVALUATE.items():
        print(f"\n{'='*60}")
        print(f"Evaluating SimpleToM with {model_name.upper()} ({N_SIMPLETOM} examples)")
        print(f"{'='*60}\n")
        
        all_simpletom_results[model_name] = {}
        
        # Mental-state QA (core ToM: "Is X aware that...?")
        all_simpletom_results[model_name]['mental_state'] = evaluate_simpletom(
            simpletom_mental, 
            backend,
            N_SIMPLETOM,
            f"SimpleToM-MentalState ({model_name})"
        )
        
        # Behavior QA (applied ToM: "What will X do next?")
        if simpletom_behavior is not None:
            all_simpletom_results[model_name]['behavior'] = evaluate_simpletom(
                simpletom_behavior,
                backend,
                N_SIMPLETOM,
                f"SimpleToM-Behavior ({model_name})"
            )
        
        print(f"\n{model_name.upper()} SimpleToM RESULTS:")
        show_results_summary(all_simpletom_results[model_name].get('mental_state', []), f"SimpleToM: Mental State QA - {model_name}")
        if 'behavior' in all_simpletom_results[model_name]:
            print()
            show_results_summary(all_simpletom_results[model_name]['behavior'], f"SimpleToM: Behavior QA - {model_name}")
        
        # Save results to files
        print(f"\n📁 Saving results...")
        save_results(all_simpletom_results[model_name].get('mental_state', []), "simpletom_mental_state", model_name)
        if 'behavior' in all_simpletom_results[model_name]:
            save_results(all_simpletom_results[model_name]['behavior'], "simpletom_behavior", model_name)
        
        # Show sample results for debugging
        print(f"\n{'─'*60}")
        print(f"SAMPLE RESULTS FOR DEBUGGING - {model_name.upper()}")
        print(f"{'─'*60}")
        show_sample_results(all_simpletom_results[model_name].get('mental_state', []), f"SimpleToM-MentalState ({model_name})")
        analyze_failure_patterns(all_simpletom_results[model_name].get('mental_state', []), f"SimpleToM-MentalState ({model_name})")
        
        if 'behavior' in all_simpletom_results[model_name]:
            print()
            show_sample_results(all_simpletom_results[model_name]['behavior'], f"SimpleToM-Behavior ({model_name})")
            analyze_failure_patterns(all_simpletom_results[model_name]['behavior'], f"SimpleToM-Behavior ({model_name})")
else:
    print("⏭ Skipping SimpleToM (could not load dataset)")


Evaluating SimpleToM with LOCAL (30 examples)

SimpleToM-MentalState (local): 100%|██████████| 30/30 [05:02<00:00, 10.09s/it]
SimpleToM-Behavior (local): 100%|██████████| 30/30 [06:55<00:00, 13.87s/it]

LOCAL SimpleToM RESULTS:

📁 Saving results...
  💾 Saved JSON: eval_outputs/simpletom_mental_state_local_20260122_225455.json
  💾 Saved CSV:  eval_outputs/simpletom_mental_state_local_20260122_225455.csv
  💾 Saved JSON: eval_outputs/simpletom_behavior_local_20260122_225455.json
  💾 Saved CSV:  eval_outputs/simpletom_behavior_local_20260122_225455.csv

────────────────────────────────────────────────────────────
SAMPLE RESULTS FOR DEBUGGING - LOCAL
────────────────────────────────────────────────────────────
✓ SimpleToM-MentalState (local): No incorrect examples!

✓ SimpleToM-MentalState (local): No failures to analyze!


---
## 7. ToMBench Evaluation <a name="7-tombench"></a>

**ToMBench** (Chen et al., 2024) is a comprehensive bilingual ToM benchmark with 2,860 testing samples.

### 8 ToM Tasks:
1. **Unexpected Outcome Test**: Infer mental states when a character's actual emotion differs from the expected one
2. **Scalar Implicature Task**: "Some" implies "not all"
3. **Persuasion Story Task**: Choose effective persuasion strategies
4. **False Belief Task**: Distinguish own beliefs from others' false beliefs
5. **Ambiguous Story Task**: Understand mental states in uncertain situations
6. **Hinting Test**: Infer mental states from indirect hints
7. **Strange Story Task**: Complex social communications (lies, irony, jokes)
8. **Faux-pas Recognition Test**: Recognize social faux pas

### 6 ATOMS Ability Categories (31 specific abilities):
- Emotion, Desire, Intention, Knowledge, Belief, Non-Literal Communication

In [None]:
# ToMBench setup
TOMBENCH_DIR = Path('tombench')

# Clone if not present
if not TOMBENCH_DIR.exists() or not (TOMBENCH_DIR / 'data').exists():
    print("Downloading ToMBench benchmark...")
    !rm -rf {TOMBENCH_DIR}
    !git clone --depth 1 https://github.com/zhchen18/ToMBench.git {TOMBENCH_DIR}
    print("✓ ToMBench downloaded")
else:
    print(f"✓ ToMBench already exists at {TOMBENCH_DIR}")

# List available task files
if (TOMBENCH_DIR / 'data').exists():
    print("\nAvailable ToMBench tasks:")
    for f in sorted((TOMBENCH_DIR / 'data').glob('*.jsonl')):
        print(f"  - {f.stem}")

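ToMBench records use bilingual field names: the notebook reads the answer from a key made of the Chinese word for "answer" (答案), a newline, and `ANSWER`, and the ability from 能力 ("ability") plus `ABILITY`. An exact-key lookup silently falls back to its default if the key differs even slightly, for example after an encoding round-trip. A more forgiving lookup, sketched here as a hypothetical helper, matches on the English part of the key only:

```python
from typing import Any, Dict


def get_bilingual_field(item: Dict[str, Any], english_name: str, default: Any = "") -> Any:
    """Look up a field whose key is either 'ENGLISH' or '<Chinese>\\nENGLISH'.

    Compares only the last line of each key, case-insensitively, so the lookup
    does not depend on the exact Chinese prefix or how it was encoded.
    """
    target = english_name.strip().upper()
    for key, value in item.items():
        if key.split("\n")[-1].strip().upper() == target:
            return value
    return default


# Made-up toy item following the key layout used in this notebook
item = {"STORY": "A toy story.", "答案\nANSWER": "C", "能力\nABILITY": "Belief"}
print(get_bilingual_field(item, "ANSWER"))   # C
print(get_bilingual_field(item, "ABILITY"))  # Belief
```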
In [None]:
class ToMBenchDataset:
    """
    Dataset loader for ToMBench benchmark.
    
    ToMBench has multiple JSONL files, one per task type.
    Each item has bilingual content (Chinese + English).
    """
    
    def __init__(self, data_dir: str, tasks: List[str] = None, size_per_task: int = None):
        """
        Args:
            data_dir: Path to tombench/data directory
            tasks: List of task names to load (None = all tasks)
            size_per_task: Max items per task (None = all)
        """
        self.data_dir = Path(data_dir)
        self.data = []
        self.task_counts = {}
        
        # Get all task files
        task_files = sorted(self.data_dir.glob('*.jsonl'))
        
        for task_file in task_files:
            task_name = task_file.stem
            
            # Filter by task list if provided
            if tasks and task_name not in tasks:
                continue
            
            count = 0
            with open(task_file, 'r', encoding='utf-8') as f:
                for line in f:
                    if size_per_task and count >= size_per_task:
                        break
                    if not line.strip():
                        continue  # skip blank lines, e.g. a trailing newline at end of file
                    item = json.loads(line)
                    item['_task'] = task_name  # Add task name for tracking
                    self.data.append(item)
                    count += 1
            
            self.task_counts[task_name] = count
        
        self.size = len(self.data)
    
    def __len__(self):
        return self.size
    
    def __getitem__(self, idx):
        return self.data[idx]
    
    def __repr__(self):
        return f"ToMBenchDataset(n={self.size}, tasks={len(self.task_counts)})"


# Strict-format prompt: aims to prevent common failure modes (option text, explanations, placeholder output)
TOMBENCH_SYSTEM_PROMPT = """You will read a story and answer a multiple choice question.

CRITICAL INSTRUCTIONS:
- Output ONLY a single letter: A, B, C, or D
- Do NOT output the option text
- Do NOT explain your reasoning
- Do NOT output placeholder text like [...] or ...
- Your entire response must be exactly one character

Example correct response: B"""


def format_tombench_prompt(item: dict) -> Tuple[str, str, List[str]]:
    """
    Format a ToMBench item into a prompt.
    
    Returns: (prompt, correct_answer_letter, choices)
    """
    # Use English fields
    story = item.get('STORY', '')
    question = item.get('QUESTION', '')
    
    # Get options
    options = []
    for letter in ['A', 'B', 'C', 'D']:
        opt = item.get(f'OPTION-{letter}', '')
        if opt:
            options.append((letter, opt))
    
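    # Get answer - the key is bilingual ("答案" is Chinese for "answer")
    answer_key = '答案\nANSWER'
    # Normalize so stray whitespace or lowercase in the data cannot break the comparison
    correct = str(item.get(answer_key, item.get('ANSWER', ''))).strip().upper()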
    
    if not story or not question or not options:
        return None, None, None
    
    prompt = f"{TOMBENCH_SYSTEM_PROMPT}\n\nStory: {story}\n\nQuestion: {question}\n\n"
    for letter, text in options:
        prompt += f"{letter}. {text}\n"
    prompt += "\nYour answer (single letter only):"
    
    return prompt, correct, [opt for _, opt in options]


def extract_tombench_answer(response: str) -> str:
    """Extract A, B, C, or D from response."""
    # Clean common noise patterns first
    response_clean = response.strip()
    
    # Remove common placeholder patterns
    for noise in ['[??]', '[blank]', '[...]', '...?', '...', '**', '??']:
        response_clean = response_clean.replace(noise, '').strip()
    
    response_upper = response_clean.upper()
    
    # Check for standalone letter at start (most reliable)
    if response_upper and response_upper[0] in 'ABCD':
        # Make sure it's not part of a word
        if len(response_upper) == 1 or not response_upper[1].isalpha():
            return response_upper[0]
    
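    # Check for pattern like "(A)" or "Answer: A" (re is imported in the setup cell)
    patterns = [
        r'^\s*([ABCD])\s*[.\):\s]',  # Letter at start with punctuation
        r'\(([ABCD])\)',              # (A), (B), etc.
        r'ANSWER[:\s]+([ABCD])\b',    # "Answer: A" (\b rejects "Answer: Because")
        r'IS\s+([ABCD])\b',           # "is A"
    ]
    for pattern in patterns:
        match = re.search(pattern, response_upper)
        if match:
            return match.group(1)
    
    # Last resort: first standalone capital A-D in the original-case text.
    # Scanning for any A-D character would pick letters out of words like "BECAUSE".
    match = re.search(r'(?<![A-Za-z])([ABCD])(?![A-Za-z])', response_clean)
    if match:
        return match.group(1)
    
    return ""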


def evaluate_tombench(
    dataset: ToMBenchDataset,
    backend: ModelBackend,
    max_examples: int = None,
    desc: str = "ToMBench"
) -> List[EvalResult]:
    """Evaluate model on ToMBench dataset."""
    results = []
    n = min(len(dataset), max_examples) if max_examples else len(dataset)
    
    for i in tqdm(range(n), desc=desc):
        item = dataset[i]
        prompt, correct, choices = format_tombench_prompt(item)
        
        if prompt is None:
            continue
        
        response, elapsed, in_tok, out_tok = backend.generate(prompt, max_tokens=50)
        
        extracted = extract_tombench_answer(response)
        is_correct = extracted == correct
        
        results.append(EvalResult(
            idx=i,
            prompt=prompt,
            correct_answer=correct,
            model_response=response,
            extracted_answer=extracted,
            is_correct=is_correct,
            time_sec=elapsed,
            input_tokens=in_tok,
            output_tokens=out_tok,
            metadata={
                'task': item.get('_task', 'unknown'),
                'ability': item.get('能力\nABILITY', item.get('ABILITY', 'unknown'))  # "能力" = "ability"
            }
        ))
    
    return results


def analyze_tombench_by_task(results: List[EvalResult]) -> Dict[str, Dict]:
    """Break down ToMBench results by task type."""
    task_results = {}
    
    for r in results:
        task = r.metadata.get('task', 'unknown')
        if task not in task_results:
            task_results[task] = {'correct': 0, 'total': 0}
        task_results[task]['total'] += 1
        if r.is_correct:
            task_results[task]['correct'] += 1
    
    for task in task_results:
        t = task_results[task]
        t['accuracy'] = t['correct'] / t['total'] if t['total'] > 0 else 0.0
    
    return task_results


def analyze_tombench_by_ability(results: List[EvalResult]) -> Dict[str, Dict]:
    """Break down ToMBench results by ATOMS ability category."""
    ability_results = {}
    
    for r in results:
        ability = r.metadata.get('ability', 'unknown')
        if ability not in ability_results:
            ability_results[ability] = {'correct': 0, 'total': 0}
        ability_results[ability]['total'] += 1
        if r.is_correct:
            ability_results[ability]['correct'] += 1
    
    for ability in ability_results:
        a = ability_results[ability]
        a['accuracy'] = a['correct'] / a['total'] if a['total'] > 0 else 0.0
    
    return ability_results


print("✓ ToMBench evaluation functions defined")

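`analyze_tombench_by_task` returns per-task counts as dicts with `'correct'`, `'total'` and `'accuracy'`. When tasks contribute different numbers of items, the pooled (micro) accuracy and the unweighted per-task (macro) average can diverge. A small sketch of both, written against that dict shape (the function names are hypothetical):

```python
from typing import Dict


def micro_accuracy(breakdown: Dict[str, Dict]) -> float:
    """Pooled accuracy over all items (larger tasks weigh more)."""
    total = sum(s["total"] for s in breakdown.values())
    return sum(s["correct"] for s in breakdown.values()) / total if total else 0.0


def macro_accuracy(breakdown: Dict[str, Dict]) -> float:
    """Unweighted mean of per-task accuracies (each task counts equally)."""
    accs = [s["correct"] / s["total"] for s in breakdown.values() if s["total"] > 0]
    return sum(accs) / len(accs) if accs else 0.0


# Toy counts for two tasks of very different size
breakdown = {
    "task_a": {"correct": 9, "total": 10},
    "task_b": {"correct": 1, "total": 2},
}
print(round(micro_accuracy(breakdown), 3))  # 0.833
print(round(macro_accuracy(breakdown), 3))  # 0.7
```

With the same number of items per task (as when `size_per_task` caps every file), the two coincide; they only differ when some task files run short.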
In [None]:
# Run ToMBench evaluation
N_TOMBENCH_PER_TASK = 5  # Examples per task (20 tasks = ~100 total)

all_tombench_results = {}  # model_name -> [EvalResult, ...]

tombench_data_dir = TOMBENCH_DIR / 'data'

if tombench_data_dir.exists() and list(tombench_data_dir.glob('*.jsonl')):
    print(f"Loading ToMBench from {tombench_data_dir}")
    
    # Load dataset (sample from each task for efficiency)
    tombench_dataset = ToMBenchDataset(tombench_data_dir, size_per_task=N_TOMBENCH_PER_TASK)
    print(f"Loaded {len(tombench_dataset)} examples from {len(tombench_dataset.task_counts)} tasks")
    print(f"Tasks: {list(tombench_dataset.task_counts.keys())[:5]}...")
    
    for model_name, backend in MODELS_TO_EVALUATE.items():
        print(f"\n{'='*60}")
        print(f"Evaluating ToMBench with {model_name.upper()} ({len(tombench_dataset)} examples)")
        print(f"{'='*60}\n")
        
        all_tombench_results[model_name] = evaluate_tombench(
            tombench_dataset, 
            backend, 
            max_examples=len(tombench_dataset),
            desc=f"ToMBench ({model_name})"
        )
        
        print(f"\n{model_name.upper()} ToMBench RESULTS:")
        show_results_summary(all_tombench_results[model_name], f"ToMBench Overall - {model_name}")
        
        # Show breakdown by task
        task_breakdown = analyze_tombench_by_task(all_tombench_results[model_name])
        if task_breakdown:
            print(f"\n📊 Breakdown by Task ({model_name}):")
            for task, stats in sorted(task_breakdown.items()):
                print(f"   {task}: {stats['accuracy']:.1%} ({stats['correct']}/{stats['total']})")
        
        # Save results to files
        print(f"\n📁 Saving results...")
        save_results(all_tombench_results[model_name], "tombench", model_name)
        
        # Show sample results for debugging
        print(f"\n{'─'*60}")
        print(f"SAMPLE RESULTS FOR DEBUGGING - {model_name.upper()}")
        print(f"{'─'*60}")
        show_sample_results(all_tombench_results[model_name], f"ToMBench ({model_name})")
        analyze_failure_patterns(all_tombench_results[model_name], f"ToMBench ({model_name})")
else:
    print("⏭ Skipping ToMBench (data not found)")
    print("  Run the download cell above, or manually clone:")
    print("  git clone https://github.com/zhchen18/ToMBench.git tombench")

---
## 8. Cross-Benchmark Comparison <a name="8-comparison"></a>

In [None]:
# Compile all results for cross-model comparison
comparison_data = []

for model_name in MODELS_TO_EVALUATE.keys():
    # ToMi
    if model_name in all_tomi_results:
        if all_tomi_results[model_name].get('tom'):
            comparison_data.append((model_name, 'ToMi', 'ToM (false belief)', 
                                    compute_accuracy(all_tomi_results[model_name]['tom']), 
                                    len(all_tomi_results[model_name]['tom'])))
        if all_tomi_results[model_name].get('no_tom'):
            comparison_data.append((model_name, 'ToMi', 'No-ToM (true belief)', 
                                    compute_accuracy(all_tomi_results[model_name]['no_tom']), 
                                    len(all_tomi_results[model_name]['no_tom'])))
    
    # FANToM
    if model_name in all_fantom_results:
        if all_fantom_results[model_name].get('belief'):
            comparison_data.append((model_name, 'FANToM', 'Belief Questions', 
                                    compute_accuracy(all_fantom_results[model_name]['belief']), 
                                    len(all_fantom_results[model_name]['belief'])))
        if all_fantom_results[model_name].get('fact'):
            comparison_data.append((model_name, 'FANToM', 'Fact Questions', 
                                    compute_accuracy(all_fantom_results[model_name]['fact']), 
                                    len(all_fantom_results[model_name]['fact'])))
    
    # SimpleToM (mental-state and behavior QA)
    if model_name in all_simpletom_results:
        if all_simpletom_results[model_name].get('mental_state'):
            comparison_data.append((model_name, 'SimpleToM', 'Mental State QA', 
                                    compute_accuracy(all_simpletom_results[model_name]['mental_state']), 
                                    len(all_simpletom_results[model_name]['mental_state'])))
        if all_simpletom_results[model_name].get('behavior'):
            comparison_data.append((model_name, 'SimpleToM', 'Behavior QA', 
                                    compute_accuracy(all_simpletom_results[model_name]['behavior']), 
                                    len(all_simpletom_results[model_name]['behavior'])))
    
    # ToMBench
    if model_name in all_tombench_results:
        comparison_data.append((model_name, 'ToMBench', 'Overall', 
                                compute_accuracy(all_tombench_results[model_name]), 
                                len(all_tombench_results[model_name])))

# Display comparison table
if comparison_data:
    print("\n" + "="*80)
    print("CROSS-BENCHMARK COMPARISON: ALL MODELS")
    print("="*80 + "\n")
    
    table = Table(title="ToM Benchmark Results", box=box.ROUNDED)
    table.add_column("Model", style="magenta")
    table.add_column("Benchmark", style="cyan")
    table.add_column("Condition", style="white")
    table.add_column("Accuracy", style="green")
    table.add_column("N", style="dim")
    
    for model, benchmark, condition, acc, n in comparison_data:
        table.add_row(model, benchmark, condition, f"{acc:.1%}", str(n))
    
    console.print(table)
    
    # Model comparison summary
    print("\n📊 Key Observations:")
    print("   - Claude serves as the positive control (expected: high accuracy)")
    print("   - ToM conditions (ToMi false belief, FANToM belief) should score lower than their controls (true belief, fact)")
    print("   - The gap between models indicates relative ToM capability (noisy at small N)")
    print("   - Consistent patterns across benchmarks suggest a robust ToM deficit/capability")
else:
    print("No results to compare - run evaluations first.")

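Several rows in this table rest on 30 or fewer items, so even a perfect score leaves wide uncertainty. The Wilson score interval is one standard way to attach a confidence interval to an accuracy; a sketch that is not part of the notebook's helpers:

```python
import math
from typing import Tuple


def wilson_interval(correct: int, n: int, z: float = 1.96) -> Tuple[float, float]:
    """Wilson score interval for a binomial proportion (z=1.96 gives ~95%)."""
    if n == 0:
        return (0.0, 1.0)  # no data: the accuracy could be anything
    p = correct / n
    denom = 1 + z ** 2 / n
    center = (p + z ** 2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z ** 2 / (4 * n ** 2)) / denom
    return (max(0.0, center - half), min(1.0, center + half))


print([round(x, 3) for x in wilson_interval(30, 30)])  # [0.886, 1.0]
```

For 30/30 correct the 95% interval runs from roughly 88.6% to 100%, so a gap of a few points between models at this N is not by itself evidence of a capability difference.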
In [None]:
# Create side-by-side comparison table for easier analysis
if len(MODELS_TO_EVALUATE) > 1 and comparison_data:
    print("\n" + "="*80)
    print("MODEL HEAD-TO-HEAD COMPARISON")
    print("="*80 + "\n")
    
    # Group by benchmark/condition
    from collections import defaultdict
    grouped = defaultdict(dict)
    for model, benchmark, condition, acc, n in comparison_data:
        key = f"{benchmark}: {condition}"
        grouped[key][model] = (acc, n)
    
    # Display
    table = Table(title="Head-to-Head Accuracy", box=box.ROUNDED)
    table.add_column("Benchmark / Condition", style="cyan")
    for model_name in MODELS_TO_EVALUATE.keys():
        table.add_column(model_name.upper(), style="green")
    table.add_column("Δ", style="yellow")  # second model minus first; filled only when exactly 2 models
    
    for key in sorted(grouped.keys()):
        row = [key]
        accs = []
        for model_name in MODELS_TO_EVALUATE.keys():
            if model_name in grouped[key]:
                acc, n = grouped[key][model_name]
                row.append(f"{acc:.1%} (n={n})")
                accs.append(acc)
            else:
                row.append("-")
                accs.append(None)
        
        # Calculate difference if exactly 2 models
        if len(accs) == 2 and all(a is not None for a in accs):
            diff = accs[1] - accs[0]  # second model minus first, in MODELS_TO_EVALUATE order
            row.append(f"{diff:+.1%}")
        else:
            row.append("-")
        
        table.add_row(*row)
    
    console.print(table)
elif comparison_data:
    print("(Head-to-head comparison requires 2+ models)")
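The Δ column is a raw point difference. Because both models answer the same benchmark items, a paired test is a natural way to ask whether a given Δ exceeds item-level noise. The sketch below implements an exact McNemar test on per-item correctness. It assumes you can extract aligned boolean correctness lists for the two models from their result objects (how correctness is read off those objects is not shown here), and `mcnemar_exact` is a hypothetical helper.

```python
import math
from typing import Sequence


def mcnemar_exact(correct_a: Sequence[bool], correct_b: Sequence[bool]) -> float:
    """Two-sided exact McNemar p-value for paired per-item correctness.

    Assumes correct_a[i] and correct_b[i] refer to the same benchmark item.
    Only discordant items (one model right, the other wrong) carry information.
    """
    if len(correct_a) != len(correct_b):
        raise ValueError("paired test needs the same items for both models")
    b = sum(1 for x, y in zip(correct_a, correct_b) if x and not y)  # only A right
    c = sum(1 for x, y in zip(correct_a, correct_b) if y and not x)  # only B right
    n = b + c
    if n == 0:
        return 1.0  # no disagreements, so no evidence of a difference
    k = min(b, c)
    # P(X <= k) for X ~ Binomial(n, 0.5), doubled for a two-sided test
    tail = sum(math.comb(n, i) for i in range(k + 1)) / 2**n
    return min(1.0, 2 * tail)


# Toy example with 50 shared items (synthetic, not real results)
local = [True] * 30 + [False] * 20
claude = [True] * 28 + [False] * 2 + [True] * 12 + [False] * 8
p = mcnemar_exact(local, claude)
print(f"McNemar p = {p:.4f}")
```

A paired test uses the fact that both models saw identical items, which usually makes it more sensitive than comparing the two accuracies as if they came from independent samples.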

---
## 9. Save Results <a name="9-save"></a>

In [None]:
# Compile all results for export
all_results = {
    'models': list(MODELS_TO_EVALUATE.keys()),
    'timestamp': datetime.now().isoformat(),
    'benchmarks': {}
}

for model_name in MODELS_TO_EVALUATE.keys():
    all_results['benchmarks'][model_name] = {}
    
    # ToMi
    if model_name in all_tomi_results:
        all_results['benchmarks'][model_name]['tomi'] = {
            'tom': {
                'accuracy': compute_accuracy(all_tomi_results[model_name].get('tom', [])),
                'n': len(all_tomi_results[model_name].get('tom', [])),
                'results': [r.to_dict() for r in all_tomi_results[model_name].get('tom', [])]
            },
            'no_tom': {
                'accuracy': compute_accuracy(all_tomi_results[model_name].get('no_tom', [])),
                'n': len(all_tomi_results[model_name].get('no_tom', [])),
                'results': [r.to_dict() for r in all_tomi_results[model_name].get('no_tom', [])]
            }
        }
    
    # FANToM
    if model_name in all_fantom_results:
        all_results['benchmarks'][model_name]['fantom'] = {
            'belief': {
                'accuracy': compute_accuracy(all_fantom_results[model_name].get('belief', [])),
                'n': len(all_fantom_results[model_name].get('belief', [])),
                'results': [r.to_dict() for r in all_fantom_results[model_name].get('belief', [])]
            },
            'fact': {
                'accuracy': compute_accuracy(all_fantom_results[model_name].get('fact', [])),
                'n': len(all_fantom_results[model_name].get('fact', [])),
                'results': [r.to_dict() for r in all_fantom_results[model_name].get('fact', [])]
            }
        }
    
    # SimpleToM: explicit ToM (mental_state) and, if present, applied ToM (behavior)
    if model_name in all_simpletom_results:
        all_results['benchmarks'][model_name]['simpletom'] = {
            'mental_state': {
                'accuracy': compute_accuracy(all_simpletom_results[model_name].get('mental_state', [])),
                'n': len(all_simpletom_results[model_name].get('mental_state', [])),
                'results': [r.to_dict() for r in all_simpletom_results[model_name].get('mental_state', [])]
            }
        }
        if 'behavior' in all_simpletom_results[model_name]:
            all_results['benchmarks'][model_name]['simpletom']['behavior'] = {
                'accuracy': compute_accuracy(all_simpletom_results[model_name].get('behavior', [])),
                'n': len(all_simpletom_results[model_name].get('behavior', [])),
                'results': [r.to_dict() for r in all_simpletom_results[model_name].get('behavior', [])]
            }
    
    # ToMBench
    if model_name in all_tombench_results:
        results = all_tombench_results[model_name]
        task_breakdown = analyze_tombench_by_task(results)
        ability_breakdown = analyze_tombench_by_ability(results)
        
        all_results['benchmarks'][model_name]['tombench'] = {
            'overall': {
                'accuracy': compute_accuracy(results),
                'n': len(results),
            },
            'by_task': task_breakdown,
            'by_ability': ability_breakdown,
            'results': [r.to_dict() for r in results]
        }

# Save
outfile = f"tom_benchmark_results_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
with open(outfile, 'w') as f:
    json.dump(all_results, f, indent=2)

print(f"✓ Results saved to {outfile}")
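To reuse a saved run without re-evaluating, the JSON can be flattened back into the same `(model, benchmark, condition, accuracy, n)` rows used by the comparison table. This is a sketch that assumes the layout written by the cell above (`benchmarks → model → benchmark → condition`, each condition carrying `accuracy` and `n`); `flatten_results` and `load_comparison_rows` are hypothetical helpers.

```python
import json
from pathlib import Path
from typing import List, Tuple

Row = Tuple[str, str, str, float, int]


def flatten_results(data: dict) -> List[Row]:
    """Turn a saved results dict into (model, benchmark, condition, accuracy, n) rows.

    Entries without both 'accuracy' and 'n' (e.g. ToMBench's 'by_task',
    'by_ability', and 'results') are skipped.
    """
    rows: List[Row] = []
    for model, benches in data["benchmarks"].items():
        for bench, conditions in benches.items():
            for cond, summary in conditions.items():
                if isinstance(summary, dict) and "accuracy" in summary and "n" in summary:
                    rows.append((model, bench, cond, summary["accuracy"], summary["n"]))
    return rows


def load_comparison_rows(path: str) -> List[Row]:
    """Read a tom_benchmark_results_*.json file and flatten it."""
    return flatten_results(json.loads(Path(path).read_text()))


# Usage (assumes the file written above exists):
# rows = load_comparison_rows(outfile)
```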

---
## Notes

### Expected Patterns
- **ToMi**: No-ToM accuracy should be higher than ToM accuracy
- **FANToM**: Fact accuracy should be higher than Belief accuracy
- **SimpleToM**: Applied ToM (behavior prediction) may be harder than explicit ToM (mental-state questions)
- **ToMBench**: Claude should substantially outperform smaller models such as gpt-oss-20b across all 8 tasks
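The within-benchmark orderings above can be checked directly against the `all_results` dict built in the save cell. The sketch below is a hypothetical helper; it compares point accuracies only (no significance test), treats ties as "not holding", and leaves ToMBench out because its expectation is a cross-model comparison.

```python
from typing import Dict, Optional

# (benchmark, condition expected to score higher, condition expected to score lower)
EXPECTED_ORDERINGS = [
    ("tomi", "no_tom", "tom"),
    ("fantom", "fact", "belief"),
    ("simpletom", "mental_state", "behavior"),  # explicit vs applied ToM
]


def check_expected_patterns(all_results: dict) -> Dict[str, Dict[str, Optional[bool]]]:
    """For each model, report whether each expected accuracy ordering holds.

    True/False compare point accuracies strictly; None means one of the two
    conditions is missing for that model.
    """
    report: Dict[str, Dict[str, Optional[bool]]] = {}
    for model, benches in all_results["benchmarks"].items():
        report[model] = {}
        for bench, higher, lower in EXPECTED_ORDERINGS:
            conds = benches.get(bench, {})
            if higher in conds and lower in conds:
                report[model][bench] = conds[higher]["accuracy"] > conds[lower]["accuracy"]
            else:
                report[model][bench] = None
    return report
```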

### Positive Control Interpretation
- Claude Opus 4.5 provides a practical reference ceiling for expected performance, not a strict upper bound
- Large gaps between Claude and the target model indicate areas for improvement
- If Claude also struggles on specific tasks, those may be genuinely difficult ToM problems
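The interpretation rules above can be made mechanical by labeling each benchmark/condition from the `grouped` mapping built in the head-to-head cell (`"Benchmark: Condition" → {model_name: (accuracy, n)}`). In this sketch `classify_conditions` is hypothetical, the model keys `"local"` and `"claude"` stand in for whatever keys `MODELS_TO_EVALUATE` actually uses, and the 10-point gap and 60% "hard" thresholds are illustrative, not calibrated.

```python
from typing import Dict, Tuple


def classify_conditions(
    grouped: Dict[str, Dict[str, Tuple[float, int]]],
    target: str,
    control: str,
    gap_margin: float = 0.10,
    hard_threshold: float = 0.60,
) -> Dict[str, str]:
    """Label each benchmark/condition by how the target compares with the positive control."""
    labels: Dict[str, str] = {}
    for key, by_model in grouped.items():
        if target not in by_model or control not in by_model:
            labels[key] = "missing"
            continue
        target_acc, control_acc = by_model[target][0], by_model[control][0]
        if control_acc < hard_threshold:
            labels[key] = "hard for control"  # possibly a genuinely difficult ToM problem
        elif control_acc - target_acc > gap_margin:
            labels[key] = "target lags"  # candidate area for improvement
        else:
            labels[key] = "comparable"
    return labels


# Hypothetical model keys; replace with the keys of MODELS_TO_EVALUATE
# labels = classify_conditions(grouped, target="local", control="claude")
```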

### Next Steps for Function Vector Research
1. Use ToMi ToM/No-ToM pairs to extract function vectors
2. Test steering on all four benchmarks
3. Check if ToM function vectors generalize across benchmarks
4. Compare first-order vs second-order ToM vectors
5. Use ToMBench's ATOMS ability breakdown to identify specific ToM components
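Steps 1 to 4 share one mechanic: derive a direction from contrastive prompt pairs, add it to hidden states, and compare directions across sources. The sketch below uses a simple difference-of-means vector as a stand-in; it is not necessarily the function-vector extraction this project will use. The arrays are synthetic; in practice they might be per-example hidden states at one layer (e.g. obtained by calling a `transformers` model with `output_hidden_states=True`), which is an assumption about the pipeline.

```python
import numpy as np


def mean_difference_vector(tom_acts: np.ndarray, no_tom_acts: np.ndarray) -> np.ndarray:
    """Difference of mean hidden states between ToM and No-ToM prompts.

    Both arrays have shape (num_examples, hidden_dim), e.g. the activation
    at the final prompt token for each story at a single layer.
    """
    return tom_acts.mean(axis=0) - no_tom_acts.mean(axis=0)


def steer(hidden: np.ndarray, vector: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    """Add a scaled steering vector to hidden states (broadcasts over leading dims)."""
    return hidden + alpha * vector


def cosine(u: np.ndarray, v: np.ndarray) -> float:
    """Cosine similarity, e.g. between vectors from different benchmarks or ToM orders."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


# Toy demo with synthetic activations (not real model data)
rng = np.random.default_rng(0)
direction = rng.normal(size=16)
tom = rng.normal(size=(32, 16)) + direction
no_tom = rng.normal(size=(32, 16))
v_tomi = mean_difference_vector(tom, no_tom)
print(f"cosine to planted direction: {cosine(v_tomi, direction):.2f}")
```

For step 3, the same `cosine` call could compare a vector extracted from ToMi pairs with one extracted from FANToM or SimpleToM contrasts; for step 4, first-order and second-order ToMi subsets could be contrasted the same way.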

### References
- ToMi: Le et al. (2019) - github.com/facebookresearch/ToMi
- FANToM: Kim et al. (2023) - github.com/skywalker023/fantom
- SimpleToM: Gu et al. (2024) - huggingface.co/datasets/allenai/SimpleToM
- ToMBench: Chen et al. (2024) - github.com/zhchen18/ToMBench - arXiv:2402.15052