# Explore Candidate ToM Benchmarks

This notebook explores several Theory of Mind (ToM) benchmarks for evaluating language models' ability to reason about agents' mental states, such as beliefs and intentions.

**Benchmarks covered:**
- **ToMi** - Theory of Mind Inventory with first/second-order belief tracking
- **FANToM** - A benchmark for stress-testing machine ToM in conversations
- **SimpleToM** - Allen AI's simplified ToM evaluation dataset
- **ToMBench** - Comprehensive bilingual ToM benchmark covering 8 tasks and 31 abilities

## Setup

Install required dependencies for loading and evaluating datasets.

In [None]:
! pip install evaluate datasets

In [18]:
import os

# Ensure the working directory is the notebook's directory so that relative
# paths like 'tomi/tomi_pairs/...' resolve correctly.
# NOTE: this absolute path is machine-specific; change it to match your checkout.
NOTEBOOK_DIR = '/root/ARENA_3.0/capstone/Agent-Robustness-Via-ToM/tom_benchmarks'
os.chdir(NOTEBOOK_DIR)
print(f"Working directory: {os.getcwd()}")

Working directory: /root/ARENA_3.0/capstone/Agent-Robustness-Via-ToM/tom_benchmarks


---

## ToMi (Theory of Mind Inventory)

ToMi is a benchmark from Facebook Research that tests models on false-belief understanding through story-based question answering.

**Key features:**
- **First-order beliefs**: "Where does X think the object is?"
- **Second-order beliefs**: "Where does X think Y thinks the object is?"
- **ToM vs No-ToM**: Questions that require mental state reasoning vs. simple memory retrieval
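
The difference between first- and second-order questions can be made concrete with a toy belief tracker. This is a minimal sketch, not ToMi's actual story generator: the event tuples and the `track_beliefs` helper are assumptions made for illustration. The rule it encodes is that an agent updates its belief only when it is in the room while the object is placed or moved, and agent X updates its model of agent Y only when both are present.

```python
from itertools import permutations


def track_beliefs(events):
    """Replay (kind, agent, container) events and return the true location
    plus first- and second-order beliefs (hypothetical event format)."""
    present = set()   # agents currently in the room
    first = {}        # first[x]: where x thinks the object is
    second = {}       # second[(x, y)]: where x thinks y thinks the object is
    location = None   # true location of the object
    for kind, agent, container in events:
        if kind == "enter":
            present.add(agent)
        elif kind == "exit":
            present.discard(agent)
        elif kind in ("place", "move"):
            location = container
            # Only agents in the room observe the change.
            for x in present:
                first[x] = location
            for x, y in permutations(present, 2):
                second[(x, y)] = location
    return location, first, second


# Core events of a false-belief story (distractor sentences omitted).
story = [
    ("enter", "Isabella", None),
    ("enter", "Olivia", None),
    ("place", None, "blue_pantry"),
    ("exit", "Isabella", None),
    ("move", "Olivia", "red_drawer"),
]
location, first, second = track_beliefs(story)
print(location)                        # red_drawer
print(first["Isabella"])               # blue_pantry
print(second[("Olivia", "Isabella")])  # blue_pantry
```

Isabella left before the move, so her first-order belief is stale, and Olivia (who saw her leave) attributes that stale belief to her. A No-ToM question about Olivia would simply return the true location.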

**Paper:** [Revisiting the Evaluation of Theory of Mind through Question Answering](https://arxiv.org/abs/1909.01871)

### Install ToMi and Dependencies

In [None]:
! git clone https://github.com/facebookresearch/ToMi.git tomi
! rm -rf tomi/.git
! pip install tqdm pandas matplotlib seaborn numpy jupyterlab openpyxl transformers datasets evaluate wandb

In [None]:
! unzip tomi/tomi_balanced_story_types.zip -d tomi/

### Extract Question-Answer Pairs

The extractor script processes the raw story traces into structured JSONL files, organized by:
- Order (first-order vs second-order belief questions)
- Story type (0 or 1)
- ToM requirement (whether mental state reasoning is needed)
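
The three attributes above appear to be encoded in the output file names used later in this notebook (for example `first_order_1_no_tom.jsonl`). The sketch below recovers them from a path. The helper name `parse_tomi_filename` and the pattern `<order>_order_<story type>_<tom|no_tom>` are assumptions inferred from those file names, not from the extractor's documentation.

```python
import re
from pathlib import Path

# Assumed naming pattern, e.g. "second_order_0_no_tom".
_TOMI_NAME = re.compile(r"^(first|second)_order_(\d+)_(no_tom|tom)$")


def parse_tomi_filename(path):
    """Return order, story type id, and ToM flag, or None if the name does not match."""
    match = _TOMI_NAME.match(Path(path).stem)
    if match is None:
        return None
    order, story_type_id, tom = match.groups()
    return {
        "order": order,
        "story_type_id": int(story_type_id),
        "requires_tom": tom == "tom",
    }


print(parse_tomi_filename("tomi/tomi_pairs/first_order_1_no_tom.jsonl"))
print(parse_tomi_filename("tomi/tomi_pairs/all_tom.jsonl"))  # None
```

Files that do not follow the pattern, such as aggregate or `_prompts` files, return `None` and can be skipped.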

In [None]:
! python tomi/tomi_pair_extractor.py --data_dir ./tomi/tomi_balanced_story_types --output_dir ./tomi/tomi_pairs --split test

### Format Data for Model Inference

Convert the extracted JSONL files into a prompt format suitable for LLM evaluation.

In [19]:
import json
from pathlib import Path


def format_for_inference(jsonl_path: str, output_path: str = None):
    """Convert ToMi JSONL to inference-ready format.
    
    Args:
        jsonl_path: Path to input JSONL file with story/question pairs
        output_path: Optional output path (defaults to input path with '_prompts' suffix)
    
    Returns:
        List of formatted examples with prompts and metadata
    """
    examples = []
    with open(jsonl_path) as f:
        for line in f:
            ex = json.loads(line)
            
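            # Clean up story: strip a leading line number such as "1 " from each line.
            # Slicing with [:1] avoids an IndexError on empty lines.
            story_lines = ex['story'].split('\n')
            clean_story = '\n'.join(
                line.split(' ', 1)[1] if line[:1].isdigit() and ' ' in line else line
                for line in story_lines
            )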
            
            # Format prompt
            prompt = f"""Story:
{clean_story}

Question: {ex['question']}
Answer:"""
            
            examples.append({
                'prompt': prompt,
                'answer': ex['answer'],
                'question_type': ex['question_type'],
                'story_type': ex['story_type'],
                'requires_tom': ex['requires_tom'],
            })
    
    # Save
    out_path = output_path or jsonl_path.replace('.jsonl', '_prompts.jsonl')
    with open(out_path, 'w') as f:
        for ex in examples:
            f.write(json.dumps(ex) + '\n')
    
    print(f"Saved {len(examples)} prompts to {out_path}")
    return examples


def evaluate_response(response: str, correct_answer: str) -> bool:
    """Check if model response contains the correct answer.
    
    Uses case-insensitive substring matching.
    """
    response_lower = response.lower().strip()
    answer_lower = correct_answer.lower()
    return answer_lower in response_lower
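
Substring matching is lenient in one direction and strict in another: a response that lists several containers counts as correct, while "blue pantry" does not match `blue_pantry`. The following sketch normalizes underscores before matching and reports accuracy separately for ToM and No-ToM questions. The helper names and the sample responses are hypothetical, and this is one possible scoring scheme rather than the benchmark's official metric.

```python
from collections import defaultdict


def normalize(text):
    """Lowercase, turn underscores into spaces, and collapse whitespace."""
    return " ".join(text.lower().replace("_", " ").split())


def is_correct(response, answer):
    return normalize(answer) in normalize(response)


def accuracy_by_tom(examples, responses):
    """Return {requires_tom: accuracy} for paired examples and responses."""
    hits = defaultdict(int)
    totals = defaultdict(int)
    for ex, response in zip(examples, responses):
        key = ex["requires_tom"]
        totals[key] += 1
        hits[key] += is_correct(response, ex["answer"])
    return {key: hits[key] / totals[key] for key in totals}


examples = [
    {"answer": "blue_pantry", "requires_tom": True},
    {"answer": "red_basket", "requires_tom": True},
    {"answer": "blue_crate", "requires_tom": False},
]
# Made-up model outputs, for illustration only.
responses = ["The blue pantry.", "blue_pantry", "She will look in the blue_crate"]
print(accuracy_by_tom(examples, responses))  # {True: 0.5, False: 1.0}
```

Splitting by `requires_tom` matters because a model can score well overall by tracking the object's true location while still failing the false-belief questions.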

### Generate Prompt Files

Process all ToMi data splits into inference-ready prompts.

In [None]:
# First-order belief questions
format_for_inference('tomi/tomi_pairs/first_order_0_no_tom.jsonl')
format_for_inference('tomi/tomi_pairs/first_order_1_no_tom.jsonl')
format_for_inference('tomi/tomi_pairs/first_order_1_tom.jsonl')

# Second-order belief questions
format_for_inference('tomi/tomi_pairs/second_order_0_tom.jsonl')
format_for_inference('tomi/tomi_pairs/second_order_0_no_tom.jsonl')
format_for_inference('tomi/tomi_pairs/second_order_1_tom.jsonl')
format_for_inference('tomi/tomi_pairs/second_order_1_no_tom.jsonl')

### Sample Data from ToMi

View representative examples from the ToMi dataset across different question types.

In [20]:
import json
import pandas as pd
from IPython.display import display, HTML

def load_tomi_samples(filepath: str, n: int = 2) -> list[dict]:
    """Load n samples from a ToMi JSONL file."""
    samples = []
    with open(filepath) as f:
        for i, line in enumerate(f):
            if i >= n:
                break
            data = json.loads(line)
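            # Clean up story for display: strip leading line numbers.
            # Slicing with [:1] avoids an IndexError on empty lines.
            story_lines = data['story'].split('\n')
            clean_story = '\n'.join(
                line.split(' ', 1)[1] if line[:1].isdigit() and ' ' in line else line
                for line in story_lines
            )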
            samples.append({
                'Story': clean_story,
                'Question': data['question'],
                'Answer': data['answer'],
                'Story Type': data['story_type'],
                'Requires ToM': data['requires_tom']
            })
    return samples

def display_tomi_table(title: str, samples: list[dict]):
    """Display samples in a formatted table."""
    print(f"\n{'='*80}")
    print(f"  {title}")
    print(f"{'='*80}")
    df = pd.DataFrame(samples)
    df['Story'] = df['Story'].apply(lambda x: x.replace('\n', '<br>'))
    display(HTML(df.to_html(escape=False, index=False)))

In [21]:
# First-order ToM (false belief) - requires reasoning about what someone THINKS
first_order_tom = load_tomi_samples('tomi/tomi_pairs/first_order_1_tom.jsonl', n=2)
display_tomi_table("First-Order ToM (False Belief)", first_order_tom)

# First-order No-ToM (true belief) - simple memory/tracking
first_order_no_tom = load_tomi_samples('tomi/tomi_pairs/first_order_0_no_tom.jsonl', n=2)
display_tomi_table("First-Order No-ToM (True Belief)", first_order_no_tom)

# Second-order ToM - "What does X think Y thinks?"
second_order_tom = load_tomi_samples('tomi/tomi_pairs/second_order_0_tom.jsonl', n=2)
display_tomi_table("Second-Order ToM", second_order_tom)


  First-Order ToM (False Belief)


Story,Question,Answer,Story Type,Requires ToM
Isabella entered the den. Olivia entered the den. Isabella dislikes the pumpkin The broccoli is in the blue_pantry. Isabella exited the den. Olivia moved the broccoli to the red_drawer. Abigail entered the garden. Isabella entered the garden.,Where will Isabella look for the broccoli?,blue_pantry,false_belief,True
Amelia entered the pantry. Ethan loves the pumpkin Liam entered the pantry. The turnip is in the red_basket. Ethan entered the closet. Liam exited the pantry. Amelia moved the turnip to the blue_pantry.,Where will Liam look for the turnip?,red_basket,false_belief,True



  First-Order No-ToM (True Belief)


Story,Question,Answer,Story Type,Requires ToM
Aria entered the front_yard. Aiden entered the front_yard. The grapefruit is in the green_bucket. Aria moved the grapefruit to the blue_container. Aiden exited the front_yard. Noah entered the playroom.,Where will Aria look for the grapefruit?,blue_container,true_belief,False
Olivia entered the closet. Isla dislikes the hall Aria entered the closet. Isla entered the closet. The orange is in the green_drawer. Aria exited the closet. Olivia moved the orange to the blue_crate. Isla dislikes the cucumber Isla exited the closet. Isla entered the hall.,Where will Olivia look for the orange?,blue_crate,true_belief,False



  Second-Order ToM


Story,Question,Answer,Story Type,Requires ToM
Isabella entered the den. Olivia entered the den. Isabella dislikes the pumpkin The broccoli is in the blue_pantry. Isabella exited the den. Olivia moved the broccoli to the red_drawer. Abigail entered the garden. Isabella entered the garden.,Where does Olivia think that Isabella searches for the broccoli?,blue_pantry,false_belief,True
Amelia entered the pantry. Ethan loves the pumpkin Liam entered the pantry. The turnip is in the red_basket. Ethan entered the closet. Liam exited the pantry. Amelia moved the turnip to the blue_pantry.,Where does Amelia think that Liam searches for the turnip?,red_basket,false_belief,True


### ToMi Dataset Statistics

In [None]:
def count_jsonl_lines(filepath: str) -> int:
    with open(filepath) as f:
        return sum(1 for _ in f)

tomi_dir = Path('tomi/tomi_pairs')
stats = []
for jsonl_file in sorted(tomi_dir.glob('*.jsonl')):
    if '_prompts' not in jsonl_file.name and jsonl_file.name not in ['all_tom.jsonl', 'all_no_tom.jsonl']:
        count = count_jsonl_lines(jsonl_file)
        name = jsonl_file.stem
        order = 'First' if 'first' in name else 'Second'
        requires_tom = 'No' if 'no_tom' in name else 'Yes'
        stats.append({
            'File': jsonl_file.name,
            'Order': order,
            'Requires ToM': requires_tom,
            'Count': count
        })

stats_df = pd.DataFrame(stats)
print("ToMi Dataset Statistics:\n")
display(stats_df)
print(f"\nTotal examples: {stats_df['Count'].sum()}")

---

## FANToM

FANToM is a benchmark that stress-tests ToM capabilities through naturalistic multi-party conversations.

**Key features:**
- Realistic conversational scenarios with multiple speakers
- Information asymmetry between conversation participants
- Tests whether models can track who knows what

**Paper:** [FANToM: A Benchmark for Stress-Testing Machine Theory of Mind in Interactions](https://arxiv.org/abs/2310.15421)

### Load FANToM Dataset

FANToM data is automatically downloaded from Google Cloud Storage on first use.

In [10]:
# Add fantom to path and load the dataset
import sys
sys.path.insert(0, 'fantom')

from task.dataset_loader import load
fantom_df = load()
print(f"Loaded FANToM dataset with {len(fantom_df)} conversation sets")

Already built at data/fantom. version 1.0
Loaded FANToM dataset with 870 conversation sets


### Sample Data from FANToM

View a sample conversation with its associated questions. FANToM tests multiple question types:
- **Belief Questions**: What does a character believe about a fact?
- **Answerability Questions**: Can a character answer a given question?
- **Info Accessibility Questions**: Does a character have access to certain information?
- **Fact Questions**: Basic factual recall from the conversation
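
The information asymmetry behind these question types comes from characters leaving and rejoining a conversation. The sketch below illustrates the idea on a few turns abbreviated from this section's sample conversation. The `(speaker, present, text)` tuple format and the `who_heard` helper are hypothetical and are not part of FANToM's code, and keyword matching is only a crude stand-in for "a piece of information".

```python
def who_heard(turns, keyword):
    """Return the characters present for any turn whose text mentions `keyword`."""
    heard = set()
    for _speaker, present, text in turns:
        if keyword.lower() in text.lower():
            heard |= set(present)
    return heard


everyone = {"Gianna", "Sara", "Javier"}
turns = [
    ("Gianna", everyone, "I need to excuse myself. Talk to you later!"),
    ("Javier", {"Sara", "Javier"}, "Did you try training Snowflake?"),
    ("Sara", {"Sara", "Javier"}, "I did manage to teach her a few commands and tricks."),
    ("Gianna", everyone, "Hey guys, I'm back."),
]
informed = who_heard(turns, "Snowflake")
print(sorted(informed))             # ['Javier', 'Sara']
print(sorted(everyone - informed))  # ['Gianna']
```

In this toy setup, an info accessibility question about the training discussion would list Sara and Javier, and a belief question about Gianna would have to account for her absence.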

In [11]:
# Display a sample conversation set from FANToM
sample_set = fantom_df.iloc[0]

print("="*80)
print("  SAMPLE CONVERSATION (Short Context)")
print("="*80)
print(sample_set['short_context'])

print("\n" + "="*80)
print("  QUESTIONS FOR THIS CONVERSATION")
print("="*80)

# Fact Question
print("\n--- Fact Question ---")
print(f"Q: {sample_set['factQA']['question']}")
print(f"A: {sample_set['factQA']['correct_answer']}")

# Belief Questions
print("\n--- Belief Questions (ToM) ---")
for i, bq in enumerate(sample_set['beliefQAs'][:2]):  # Show first 2
    print(f"\nQ{i+1}: {bq['question']}")
    print(f"Correct: {bq['correct_answer']}")
    print(f"Wrong (false belief): {bq['wrong_answer']}")
    print(f"ToM Type: {bq['tom_type']}")

# Answerability Questions
print("\n--- Answerability Questions ---")
print(f"List Q: {sample_set['answerabilityQA_list']['question']}")
print(f"Can answer: {sample_set['answerabilityQA_list']['correct_answer']}")
print(f"Cannot answer: {sample_set['answerabilityQA_list']['wrong_answer']}")

# Info Accessibility Questions  
print("\n--- Info Accessibility Questions ---")
print(f"List Q: {sample_set['infoAccessibilityQA_list']['question']}")
print(f"Has access: {sample_set['infoAccessibilityQA_list']['correct_answer']}")
print(f"No access: {sample_set['infoAccessibilityQA_list']['wrong_answer']}")

  SAMPLE CONVERSATION (Short Context)
Gianna: Guys, I've really enjoyed sharing our pet stories, but I need to excuse myself. I need to change clothes for a meeting later. Talk to you later!
Sara: Sure thing, Gianna. Take care!
Javier: Catch you later, Gianna.
Sara: So Javier, have you ever tried training Bruno?
Javier: Yes, I did actually. It was a challenge at times, but rewarding nevertheless. How about you? Did you try training Snowflake?
Sara: Oh gosh, trying to train a cat is a whole different ball game. But I did manage to teach her a few commands and tricks. She was quite an intelligent little furball.
Gianna: Hey guys, I'm back, couldn't miss out on more pet stories. Speaking of teaching and training pets, it is amazing how that further strengthens the bond between us and our pets, right?
Sara: Absolutely, Gianna! The fact that they trust us enough to learn from us is really special.
Javier: I can't agree more. I believe that's one of the ways Bruno conveyed his love and trust

### FANToM Dataset Statistics

In [12]:
# Count questions per conversation set
total_belief_qs = sum(len(row['beliefQAs']) for _, row in fantom_df.iterrows())
total_answerability_binary = sum(len(row['answerabilityQAs_binary']) for _, row in fantom_df.iterrows())
total_info_binary = sum(len(row['infoAccessibilityQAs_binary']) for _, row in fantom_df.iterrows())

fantom_stats = [
    {'Question Type': 'Conversation Sets', 'Count': len(fantom_df)},
    {'Question Type': 'Fact Questions', 'Count': len(fantom_df)},  # 1 per set
    {'Question Type': 'Belief Questions (ToM)', 'Count': total_belief_qs},
    {'Question Type': 'Answerability List Questions', 'Count': len(fantom_df)},  # 1 per set
    {'Question Type': 'Answerability Binary Questions', 'Count': total_answerability_binary},
    {'Question Type': 'Info Accessibility List Questions', 'Count': len(fantom_df)},  # 1 per set
    {'Question Type': 'Info Accessibility Binary Questions', 'Count': total_info_binary},
]

fantom_stats_df = pd.DataFrame(fantom_stats)
print("FANToM Dataset Statistics:\n")
display(fantom_stats_df)

total_questions = sum(s['Count'] for s in fantom_stats[1:])  # Exclude conversation sets count
print(f"\nTotal questions: {total_questions}")

FANToM Dataset Statistics:



| Question Type | Count |
|---|---|
| Conversation Sets | 870 |
| Fact Questions | 870 |
| Belief Questions (ToM) | 1540 |
| Answerability List Questions | 870 |
| Answerability Binary Questions | 3571 |
| Info Accessibility List Questions | 870 |
| Info Accessibility Binary Questions | 3571 |



Total questions: 11292


---

## SimpleToM

SimpleToM from Allen AI provides a clean, focused evaluation of ToM capabilities across three question types.

**Question types:**
- **Mental-state QA**: Questions about characters' beliefs and knowledge
- **Behavior QA**: Predicting actions based on mental states
- **Judgment QA**: Evaluating decisions given information asymmetry

**Paper:** [SimpleToM: Exposing the Gap between Explicit ToM Inference and Implicit ToM Application in LLMs](https://arxiv.org/abs/2410.13648)

### Load SimpleToM Dataset

In [14]:
from datasets import load_dataset

In [15]:
simple_tom_mental = load_dataset("allenai/SimpleToM", 'mental-state-qa')
simple_tom_behavior = load_dataset("allenai/SimpleToM", 'behavior-qa')
simple_tom_judgment = load_dataset("allenai/SimpleToM", 'judgment-qa')

### Explore Dataset Structure

In [16]:
print("Mental-state QA:", simple_tom_mental)
print("\nBehavior QA:", simple_tom_behavior)
print("\nJudgment QA:", simple_tom_judgment)

Mental-state QA: DatasetDict({
    test: Dataset({
        features: ['id', 'story', 'question', 'scenario_name', 'choices', 'answerKey'],
        num_rows: 1147
    })
})

Behavior QA: DatasetDict({
    test: Dataset({
        features: ['id', 'story', 'question', 'scenario_name', 'choices', 'answerKey'],
        num_rows: 1147
    })
})

Judgment QA: DatasetDict({
    test: Dataset({
        features: ['id', 'story', 'question', 'scenario_name', 'choices', 'answerKey'],
        num_rows: 1147
    })
})


### Sample Data from SimpleToM

View representative examples from each question type in SimpleToM.

In [17]:
def display_simpletom_samples(dataset, title: str, n: int = 2):
    """Display n samples from a SimpleToM dataset split."""
    print(f"\n{'='*80}")
    print(f"  {title}")
    print(f"{'='*80}")
    
    samples = []
    for i in range(min(n, len(dataset['test']))):
        ex = dataset['test'][i]
        choices_text = ' | '.join([f"({label}) {text}" for label, text in 
                                   zip(ex['choices']['label'], ex['choices']['text'])])
        samples.append({
            'Story': ex['story'],
            'Question': ex['question'],
            'Choices': choices_text,
            'Answer': ex['answerKey'],
            'Scenario': ex['scenario_name']
        })
    
    df = pd.DataFrame(samples)
    display(HTML(df.to_html(escape=False, index=False)))

# Mental-state QA - Tests awareness/knowledge attribution
display_simpletom_samples(simple_tom_mental, "Mental-State QA (Awareness Questions)", n=3)

# Behavior QA - Tests action prediction based on mental states  
display_simpletom_samples(simple_tom_behavior, "Behavior QA (Action Prediction)", n=3)

# Judgment QA - Tests evaluation of behavior given information asymmetry
display_simpletom_samples(simple_tom_judgment, "Judgment QA (Behavior Evaluation)", n=3)


  Mental-State QA (Awareness Questions)


Story,Question,Choices,Answer,Scenario
The bag of potato chips has moldy chips in it. Mary picks up the bag in the supermarket and walks to the cashier.,"Is Mary likely to be aware that ""The bag of potato chips has moldy chips in it.""?",(A) Yes | (B) No,B,food_item_in_grocery_store
The bag of potato chips has a dead cockroach in it. Mary picks up the bag in the supermarket and walks to the cashier.,"Is Mary likely to be aware that ""The bag of potato chips has a dead cockroach in it.""?",(A) Yes | (B) No,B,food_item_in_grocery_store
The sealed bag of potato chips has several fried rats in it. Mary picks up the bag in the supermarket and walks to the cashier.,"Is Mary likely to be aware that ""The sealed bag of potato chips has several fried rats in it.""?",(A) No | (B) Yes,A,food_item_in_grocery_store



  Behavior QA (Action Prediction)


Story,Question,Choices,Answer,Scenario
The bag of potato chips has moldy chips in it. Mary picks up the bag in the supermarket and walks to the cashier.,What will Mary likely do next?,(A) pay for the chips | (B) report the moldy chips,A,food_item_in_grocery_store
The bag of potato chips has a dead cockroach in it. Mary picks up the bag in the supermarket and walks to the cashier.,What will Mary likely do next?,(A) report the dead cockroach | (B) pay for the chips,B,food_item_in_grocery_store
The sealed bag of potato chips has several fried rats in it. Mary picks up the bag in the supermarket and walks to the cashier.,What will Mary likely do next?,(A) report the fried rats | (B) pay for the chips,B,food_item_in_grocery_store



  Judgment QA (Behavior Evaluation)


Story,Question,Choices,Answer,Scenario
The bag of potato chips has moldy chips in it. Mary picks up the bag in the supermarket and walks to the cashier.,"Next, Mary ""pay for the chips"" and not ""report the moldy chips"". The behavior of Mary can be best described as:",(A) Reasonable | (B) Not reasonable,A,food_item_in_grocery_store
The bag of potato chips has a dead cockroach in it. Mary picks up the bag in the supermarket and walks to the cashier.,"Next, Mary ""pay for the chips"" and not ""report the dead cockroach"". The behavior of Mary can be best described as:",(A) Not reasonable | (B) Reasonable,B,food_item_in_grocery_store
The sealed bag of potato chips has several fried rats in it. Mary picks up the bag in the supermarket and walks to the cashier.,"Next, Mary ""pay for the chips"" and not ""report the fried rats"". The behavior of Mary can be best described as:",(A) Reasonable | (B) Not reasonable,A,food_item_in_grocery_store
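
The samples above show that the option order varies between rows (for instance `(A) No | (B) Yes` in one row and `(A) Yes | (B) No` in another), so scoring has to compare against `answerKey` rather than a fixed "Yes"/"No" string. A minimal sketch of prompt construction and scoring follows. It relies only on the fields shown in the dataset structure above; the prompt wording and the helper names are assumptions, and the scorer assumes the model's response starts with the option letter.

```python
def build_prompt(example):
    """Render a SimpleToM-style example as a multiple-choice prompt."""
    options = "\n".join(
        f"({label}) {text}"
        for label, text in zip(example["choices"]["label"], example["choices"]["text"])
    )
    return (
        f"Story: {example['story']}\n\n"
        f"Question: {example['question']}\n"
        f"{options}\n"
        "Answer:"
    )


def score(prediction, example):
    """Compare the leading option letter of the prediction with answerKey."""
    letter = prediction.strip().lstrip("(")[:1].upper()
    return letter == example["answerKey"]


# Example copied from the mental-state samples shown above.
example = {
    "story": "The sealed bag of potato chips has several fried rats in it. "
             "Mary picks up the bag in the supermarket and walks to the cashier.",
    "question": 'Is Mary likely to be aware that "The sealed bag of potato chips '
                'has several fried rats in it."?',
    "choices": {"label": ["A", "B"], "text": ["No", "Yes"]},
    "answerKey": "A",
}
print(build_prompt(example))
print(score("A", example))        # True
print(score("(B) Yes", example))  # False
```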


### SimpleToM Dataset Statistics

In [None]:
simpletom_stats = []
for name, dataset in [('Mental-State QA', simple_tom_mental), 
                       ('Behavior QA', simple_tom_behavior),
                       ('Judgment QA', simple_tom_judgment)]:
    for split in dataset.keys():
        simpletom_stats.append({
            'Question Type': name,
            'Split': split,
            'Count': len(dataset[split])
        })

simpletom_stats_df = pd.DataFrame(simpletom_stats)
print("SimpleToM Dataset Statistics:\n")
display(simpletom_stats_df)
print(f"\nTotal examples: {simpletom_stats_df['Count'].sum()}")

---

## ToMBench

ToMBench is a comprehensive bilingual (Chinese/English) Theory of Mind benchmark containing 2,860 testing samples across diverse real-world social scenarios.

**Key features:**
- **8 ToM Tasks**: False Belief, Strange Story, Faux-pas Recognition, Hinting, Scalar Implicature, Persuasion, Ambiguous Story, Unexpected Outcome
- **31 ToM Abilities**: Based on the ATOMS framework covering Emotion, Desire, Intention, Knowledge, Belief, and Non-Literal Communication
- **Bilingual**: Full Chinese and English versions of all questions

**Paper:** [ToMBench: Benchmarking Theory of Mind in Large Language Models](https://arxiv.org/abs/2402.15052)

### Load ToMBench Dataset

In [None]:
# Load all ToMBench JSONL files
tombench_dir = Path('tombench/data')
tombench_data = {}

for jsonl_file in sorted(tombench_dir.glob('*.jsonl')):
    task_name = jsonl_file.stem
    examples = []
    with open(jsonl_file) as f:
        for line in f:
            examples.append(json.loads(line))
    tombench_data[task_name] = examples
    
print(f"Loaded {len(tombench_data)} ToMBench task files:")
for task, examples in tombench_data.items():
    print(f"  - {task}: {len(examples)} examples")

### Sample Data from ToMBench

ToMBench includes both Chinese and English versions. Here are examples from different task types (showing English versions).

In [None]:
def display_tombench_sample(task_name: str, examples: list, n: int = 2):
    """Display n samples from a ToMBench task (English version)."""
    print(f"\n{'='*80}")
    print(f"  {task_name}")
    print(f"{'='*80}")
    
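    # Some keys contain a newline. Look them up outside the f-strings, because a
    # backslash inside an f-string expression is a SyntaxError before Python 3.12.
    # Both the real-newline and the escaped spelling are tried.
    ability_keys = ('能力\nABILITY', '能力\\nABILITY')
    answer_keys = ('答案\nANSWER', '答案\\nANSWER')

    for i, ex in enumerate(examples[:n]):
        ability = next((ex[k] for k in ability_keys if k in ex), 'N/A')
        answer = next((ex[k] for k in answer_keys if k in ex), 'N/A')
        print(f"\n--- Example {i+1} ---")
        print(f"Ability: {ability}")
        print(f"\nStory: {ex.get('STORY', 'N/A')}")
        print(f"\nQuestion: {ex.get('QUESTION', 'N/A')}")
        
        # Display options (handle NaN values)
        options = []
        for opt in ['OPTION-A', 'OPTION-B', 'OPTION-C', 'OPTION-D']:
            val = ex.get(opt)
            if val and str(val) != 'nan' and pd.notna(val):
                options.append(f"{opt[-1]}) {val}")
        print(f"Options: {' | '.join(options)}")
        print(f"Answer: {answer}")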

# Show samples from key task types
display_tombench_sample("False Belief Task", tombench_data.get('False Belief Task', []), n=2)
display_tombench_sample("Strange Story Task (Irony/Sarcasm)", tombench_data.get('Strange Story Task', []), n=2)
display_tombench_sample("Faux-pas Recognition Test", tombench_data.get('Faux-pas Recognition Test', []), n=1)
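
For evaluation, the same fields can be turned into a multiple-choice prompt. The sketch below uses the English keys referenced in the display code above (`STORY`, `QUESTION`, `OPTION-A` to `OPTION-D`) and skips options that are missing or NaN. The helper name, the prompt layout, and the sample record are all invented for illustration; the record is not taken from ToMBench.

```python
import math


def build_tombench_prompt(example):
    """Build a multiple-choice prompt, skipping missing or NaN options."""
    lines = [
        f"Story: {example.get('STORY', '')}",
        "",
        f"Question: {example.get('QUESTION', '')}",
    ]
    for letter in "ABCD":
        value = example.get(f"OPTION-{letter}")
        if value is None:
            continue
        if isinstance(value, float) and math.isnan(value):
            continue
        if str(value).strip().lower() in ("", "nan"):
            continue
        lines.append(f"{letter}. {value}")
    lines.append("Answer:")
    return "\n".join(lines)


# Invented record with only two usable options.
record = {
    "STORY": "Ann puts her keys in the drawer and leaves. Ben moves them to the shelf.",
    "QUESTION": "Where will Ann look for her keys?",
    "OPTION-A": "The drawer",
    "OPTION-B": "The shelf",
    "OPTION-C": float("nan"),
}
print(build_tombench_prompt(record))
```

Handling two-option and four-option items in one function keeps the prompt format consistent across tasks.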

### ToMBench Dataset Statistics

In [None]:
tombench_stats = []
for task_name, examples in sorted(tombench_data.items()):
    tombench_stats.append({
        'Task': task_name,
        'Count': len(examples)
    })

tombench_stats_df = pd.DataFrame(tombench_stats)
print("ToMBench Dataset Statistics:\n")
display(tombench_stats_df)
print(f"\nTotal examples: {tombench_stats_df['Count'].sum()}")