# Theory of Mind (ToM) Cognitive Action Experiment

This notebook demonstrates how to use Brije's cognitive action detection system to study Theory of Mind reasoning in language models.

## What This Experiment Does

Instead of just asking "Can the model answer ToM questions correctly?", we ask:
- **Which cognitive processes activate during ToM reasoning?**
- **How do cognitive patterns differ between ToM vs. non-ToM tasks?**
- **What is the "cognitive fingerprint" of mental state attribution?**

## Experiment Components

1. **Task Suite**: 105 tasks drawn from five classic ToM paradigms (False Belief, Unexpected Contents, Appearance-Reality, Second-Order Belief, Affective ToM)
2. **Real-Time Tracking**: Monitor cognitive actions token-by-token
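3. **Comparative Analysis**: Test vs. Control conditions (each ToM scenario is paired with a control scenario that describes the same facts without requiring mental-state attribution)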
4. **Multi-Agent Dialogues**: Watch two AIs reason about mental states
5. **Visualizations**: Heatmaps, networks, and layer activation profiles

## Setup

In [1]:
import sys
from pathlib import Path

# Add experiment modules to path
sys.path.insert(0, str(Path.cwd().parent / 'src' / 'experiments'))
sys.path.insert(0, str(Path.cwd().parent / 'src' / 'probes'))

from tom_tasks import ToMTaskGenerator, ToMTaskType, TaskDifficulty
from tom_inference import ToMInferenceEngine
from tom_dialogue import ToMDialogueEngine
from tom_analysis import ToMAnalyzer

import random
import json
from IPython.display import display, HTML

print("✓ Imports successful")

AMD GPU detected - configuring ROCm environment variables
  HSA_OVERRIDE_GFX_VERSION: 11.0.0
  PYTORCH_ROCM_ARCH: gfx1100
  TORCH_USE_HIP_DSA: 1
  HIP_LAUNCH_BLOCKING: 1




✓ Imports successful


## Part 1: Explore the ToM Task Suite

Let's look at the different types of ToM tasks we've generated.

In [2]:
# Load tasks
task_path = Path.cwd().parent / "data" / "tom_tasks" / "tom_task_suite.json"

generator = ToMTaskGenerator()
tasks = generator.load_tasks(task_path)

print(f"Loaded {len(tasks)} ToM tasks")
print("\nTask type distribution:")
for task_type in ToMTaskType:
    count = len([t for t in tasks if t.task_type == task_type])
    print(f"  {task_type.value}: {count}")

Loaded 105 ToM tasks

Task type distribution:
  false_belief: 30
  unexpected_contents: 20
  appearance_reality: 20
  second_order_belief: 15
  affective_tom: 20


### Example: False Belief Task (Sally-Anne)

In [3]:
# Show a false belief task
false_belief_task = next(t for t in tasks if t.task_type == ToMTaskType.FALSE_BELIEF)

print("="*80)
print("FALSE BELIEF TASK EXAMPLE")
print("="*80)
print(f"\nTask ID: {false_belief_task.task_id}")
print(f"Difficulty: {false_belief_task.difficulty.value}")
print(f"\nScenario (requires ToM):")
print(false_belief_task.scenario)
print(f"\nControl Scenario (no ToM):")
print(false_belief_task.control_scenario)
print(f"\nQuestion: {false_belief_task.question}")
print(f"Correct Answer: {false_belief_task.correct_answer}")
print(f"\nWhy this requires ToM:")
print(false_belief_task.tom_explanation)
print(f"\nExpected Cognitive Actions:")
print(", ".join(false_belief_task.expected_cognitive_actions))

FALSE BELIEF TASK EXAMPLE

Task ID: false_belief_0001
Difficulty: easy

Scenario (requires ToM):
Emma and Sally are in the living room. Emma puts a book in the box and then leaves to go outside. While Emma is gone, Sally takes the book out of the box and puts it in the drawer. Emma comes back inside.

Control Scenario (no ToM):
Emma and Sally are in the living room. A book is first in the box. Then the book is moved to the drawer. Emma was not present when the book was moved.

Question: Where will Emma look for the book?
Correct Answer: in the box

Why this requires ToM:
Requires tracking that Emma has a false belief about the book's location. The reasoner must distinguish between reality (item in drawer) and Emma's belief (thinks item is still in box).

Expected Cognitive Actions:
perspective_taking, hypothesis_generation, metacognitive_monitoring, distinguishing, updating_beliefs
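The explanation above hinges on keeping reality (where the book is) separate from Emma's belief (where she last saw it). A small state tracker makes that distinction concrete. This is a minimal sketch, not part of `tom_tasks`: the `World` class and its methods are hypothetical, and the sketch assumes an agent's belief changes only when the agent witnesses a move.

```python
from dataclasses import dataclass, field


@dataclass
class World:
    """Hypothetical tracker: real object location vs. each agent's belief."""
    location: str
    present: set[str]
    beliefs: dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Everyone present at the start sees where the object is placed.
        for agent in self.present:
            self.beliefs[agent] = self.location

    def leave(self, agent: str) -> None:
        self.present.discard(agent)

    def enter(self, agent: str) -> None:
        # Coming back does not reveal the move; the old belief is kept.
        self.present.add(agent)

    def move(self, new_location: str) -> None:
        self.location = new_location
        # Only agents who witness the move update their belief.
        for agent in self.present:
            self.beliefs[agent] = new_location


# Sally-Anne sequence from the task above
world = World(location="box", present={"Emma", "Sally"})
world.leave("Emma")
world.move("drawer")
world.enter("Emma")
print(world.location)         # drawer (reality)
print(world.beliefs["Emma"])  # box (Emma's false belief)
```

The correct answer to "Where will Emma look?" is read from `world.beliefs["Emma"]`, not from `world.location`; the control scenario only asks the reasoner to track `world.location`.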


### Example: Second-Order Belief Task

This is the most complex task type: it requires recursive reasoning about what one person thinks another person thinks.

In [4]:
second_order_task = next(t for t in tasks if t.task_type == ToMTaskType.SECOND_ORDER_BELIEF)

print("="*80)
print("SECOND-ORDER BELIEF TASK EXAMPLE")
print("="*80)
print(f"\nScenario:")
print(second_order_task.scenario)
print(f"\nQuestion: {second_order_task.question}")
print(f"Correct Answer: {second_order_task.correct_answer}")
print(f"\nWhy this is hard:")
print(second_order_task.tom_explanation)

SECOND-ORDER BELIEF TASK EXAMPLE

Scenario:
Liam and Oliver are together in the garden. Liam tells Oliver that the book is in the garden. Oliver leaves. While Oliver is gone, Tom moves the book to the bedroom and tells Liam about it. Liam sees the book in the new location. Oliver doesn't know the book was moved. Liam doesn't know that Oliver doesn't know about the move.

Question: Where does Liam think that Oliver will look for the book?
Correct Answer: in the garden

Why this is hard:
Requires recursive reasoning: Liam thinks that Oliver thinks the book is in garden. Must track multiple nested belief states and distinguish who knows what.
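The nested belief can be sketched by recording who learned about each placement of the book. This is a deliberately simplified model, and every name in it is hypothetical: it assumes Liam credits Oliver only with placements that both of them learned about, which yields the task's expected answer but ignores subtler readings of the scenario text.

```python
from typing import NamedTuple


class Placement(NamedTuple):
    location: str
    informed: frozenset[str]  # agents who saw or were told about this placement


def first_order(history: list[Placement], agent: str) -> str:
    """Where `agent` thinks the object is: the last placement they learned about."""
    for placement in reversed(history):
        if agent in placement.informed:
            return placement.location
    raise ValueError(f"{agent} never learned where the object is")


def second_order(history: list[Placement], a: str, b: str) -> str:
    """Where `a` thinks `b` thinks the object is.

    Simplifying assumption: `a` credits `b` only with placements
    that both `a` and `b` learned about.
    """
    for placement in reversed(history):
        if a in placement.informed and b in placement.informed:
            return placement.location
    raise ValueError(f"{a} and {b} share no information about the object")


history = [
    Placement("garden", frozenset({"Liam", "Oliver"})),  # Liam tells Oliver
    Placement("bedroom", frozenset({"Tom", "Liam"})),    # Tom moves it, tells Liam
]
print(first_order(history, "Liam"))             # bedroom
print(second_order(history, "Liam", "Oliver"))  # garden
```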


## Part 2: Single Task Analysis with Real-Time Tracking

Let's analyze a single task and watch cognitive actions activate in real time.

In [5]:
# Initialize ToM inference engine
engine = ToMInferenceEngine(
    probes_base_dir=Path.cwd().parent / "data" / "probes_binary",
    model_name="google/gemma-3-4b-it",
    verbose=True
)

Initializing Theory of Mind Inference Engine...
Detected compute device: CUDA device: AMD Radeon RX 7700 XT
Initializing StreamingProbeInferenceEngine...
  Probes base dir: /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary
  Model: google/gemma-3-4b-it
  Device: cuda
  Layer range: 21-30 (10 layers)

Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_21/probe_abstracting.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_21/probe_accepting.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_21/probe_analogical_thinking.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_21/probe_analyzing.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_21/probe_applying.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_21/probe_attentional_deploymen



Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_25/probe_concretizing.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_25/probe_connecting.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_25/probe_convergent_thinking.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_25/probe_counterfactual_reasoning.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_25/probe_creating.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_25/probe_distinguishing.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_25/probe_divergent_thinking.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/brije/data/probes_binary/layer_25/probe_emotion_characterizing.pth
Loaded probe from /home/koalacrown/Desktop/Code/Projects/



Detected vision-language model. Loading text-only...


Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.68s/it]



✓ Initialized with 450 probes across 10 layers

✓ ToM Inference Engine ready
  Tracking 9 ToM-specific cognitive actions
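The log reports 450 probes across layers 21-30, i.e. 45 binary probes per layer, one per cognitive action. The probe architecture is not shown in this notebook, so the following is only a sketch under stated assumptions: each probe is treated as a logistic-regression head (weight vector plus bias) applied to one hidden state, and an action counts as active if its probability exceeds `threshold` at any probed layer. `probe_probability` and `active_actions` are hypothetical helpers, and the weights are random toy values.

```python
import math
import random


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def probe_probability(hidden: list[float], weights: list[float], bias: float) -> float:
    """Assumed binary linear probe: P(action active | hidden state)."""
    logit = sum(h * w for h, w in zip(hidden, weights)) + bias
    return sigmoid(logit)


def active_actions(
    hidden_by_layer: dict[int, list[float]],
    probes: dict[tuple[int, str], tuple[list[float], float]],
    threshold: float = 0.1,
) -> dict[str, float]:
    """Return {action: max probability over layers} for actions above threshold."""
    best: dict[str, float] = {}
    for (layer, action), (weights, bias) in probes.items():
        p = probe_probability(hidden_by_layer[layer], weights, bias)
        best[action] = max(best.get(action, 0.0), p)
    return {a: p for a, p in best.items() if p > threshold}


layers = range(21, 31)  # the 10 probed layers from the log
actions = ["perspective_taking", "updating_beliefs", "distinguishing"]
dim = 8                 # toy hidden size, far smaller than the real model's
rng = random.Random(0)
probes = {(l, a): ([rng.gauss(0, 1) for _ in range(dim)], -2.0) for l in layers for a in actions}
hidden = {l: [rng.gauss(0, 1) for _ in range(dim)] for l in layers}

print(len(probes))  # 30 = 10 layers x 3 actions
print(active_actions(hidden, probes, threshold=0.1))
```

Taking the maximum over layers is one possible aggregation; averaging across layers or requiring agreement from several layers would make detection stricter.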


In [6]:
# Pick a random false-belief task (no seed is set, so the choice differs between runs)
sample_task = random.choice([t for t in tasks if t.task_type == ToMTaskType.FALSE_BELIEF])

print("Analyzing task:")
print(f"Scenario: {sample_task.scenario}")
print(f"Question: {sample_task.question}\n")

# Analyze with real-time display
signature = engine.analyze_task(
    sample_task,
    threshold=0.1,
    show_realtime=True  # Watch activations in real-time!
)

Analyzing task:
Scenario: Ethan and Sophia are in the garden. Ethan puts a ball in the box and then leaves to go outside. While Ethan is gone, Sophia takes the ball out of the box and puts it in the trunk. Ethan comes back inside.
Question: Where will Ethan look for the ball?


                       GENERATING WITH COGNITIVE TRACKING                       

Prompt: You are reasoning about a theory of mind scenario. Think step-by-step about what different people kn...



Generated Response:
--------------------------------------------------------------------------------
:

Answer: Ethan will look in the box.

Reasoning: Ethan doesn't know that Sophia moved the ball. He believes that the ball is still in the box because he last saw it there. This demonstrates an understanding of what someone else might believe, even if that belief is incorrect.

--------------------------------------------------------------------------------


Now analyzing cognitive actions during generation...


     

In [7]:
# Examine results
print("\n" + "="*80)
print("ANALYSIS RESULTS")
print("="*80)

print(f"\nToM Specificity Score: {signature.tom_specificity_score:.3f}")
print(f"Expected Action Coverage: {signature.expected_action_coverage:.1%}")

print(f"\nExpected Actions: {', '.join(signature.expected_actions)}")
print(f"Detected Expected: {', '.join(signature.detected_expected_actions)}")
print(f"Unexpected Actions: {', '.join(signature.unexpected_actions)}")

print(f"\nTop 10 Differential Activations (Test - Control):")
sorted_diff = sorted(signature.differential_actions.items(), key=lambda x: x[1], reverse=True)
for i, (action, diff) in enumerate(sorted_diff[:10], 1):
    marker = "✓" if action in signature.expected_actions else " "
    print(f"{marker} {i:2d}. {action:30s} {diff:+.4f}")

if signature.critical_tokens:
    print(f"\nCritical Reasoning Moments (tokens with high ToM activation):")
    for pos, token, actions in signature.critical_tokens[:5]:
        print(f"  Token {pos:3d}: '{token:15s}' -> {', '.join(actions)}")


ANALYSIS RESULTS

ToM Specificity Score: 0.667
Expected Action Coverage: 80.0%

Expected Actions: metacognitive_monitoring, updating_beliefs, hypothesis_generation, distinguishing, perspective_taking
Detected Expected: hypothesis_generation, updating_beliefs, metacognitive_monitoring, perspective_taking
Unexpected Actions: cognition_awareness, questioning, counterfactual_reasoning, convergent_thinking, understanding, emotion_perception, emotion_understanding

Top 10 Differential Activations (Test - Control):
   1. cognition_awareness            +1.0000
✓  2. metacognitive_monitoring       +1.0000
✓  3. updating_beliefs               +1.0000
✓  4. perspective_taking             +1.0000
   5. understanding                  +1.0000
   6. emotion_perception             +1.0000
✓  7. distinguishing                 +0.0000
   8. analogical_thinking            +0.0000
   9. zooming_out                    +0.0000
  10. self_questioning               +0.0000

Critical Reasoning Moments (tokens
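The coverage figure above is the fraction of expected actions that were detected: 4 of 5, or 80%. The differentials are all +1 or 0, which is consistent with binary (detected / not detected) comparisons between the test and control scenarios, but `ToMInferenceEngine`'s exact definitions are not shown here. The helpers below are a hypothetical sketch of one plausible reading, with "unexpected" taken to mean detected but not expected.

```python
def differential(test: set[str], control: set[str], all_actions: set[str]) -> dict[str, int]:
    """+1 if detected only in the test scenario, -1 if only in the control, else 0."""
    return {a: int(a in test) - int(a in control) for a in all_actions}


def expected_coverage(detected: set[str], expected: set[str]) -> float:
    """Fraction of the task's expected actions that were detected."""
    return len(detected & expected) / len(expected) if expected else 0.0


# Expected actions from the run above; detected set abbreviated to two unexpected actions
expected = {"metacognitive_monitoring", "updating_beliefs", "hypothesis_generation",
            "distinguishing", "perspective_taking"}
detected = {"hypothesis_generation", "updating_beliefs", "metacognitive_monitoring",
            "perspective_taking", "cognition_awareness", "questioning"}
print(f"{expected_coverage(detected, expected):.1%}")  # 80.0%
print(sorted(detected - expected))                    # ['cognition_awareness', 'questioning']

# Toy differential: an action active in both conditions scores 0
diff = differential(
    test={"perspective_taking", "hypothesis_generation"},
    control={"hypothesis_generation"},
    all_actions={"perspective_taking", "hypothesis_generation", "distinguishing"},
)
print(diff)
```

Under this reading, an expected action can be detected yet score 0 if it also fires on the control scenario, which may explain why `hypothesis_generation` appears under "Detected Expected" but not among the top differentials.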

## Part 3: Multi-Agent ToM Dialogue

Watch two AI agents discuss a ToM scenario while we track the reasoner's cognitive processes.

In [None]:
# Initialize dialogue engine
dialogue_engine = ToMDialogueEngine(
    probes_base_dir=Path.cwd().parent / "data" / "probes_binary",
    model_name="google/gemma-3-4b-it",
    verbose=True
)

In [None]:
# Run a dialogue session
dialogue_task = random.choice([t for t in tasks if t.task_type == ToMTaskType.UNEXPECTED_CONTENTS])

session = dialogue_engine.run_dialogue_session(
    dialogue_task,
    threshold=0.1,
    show_realtime=True
)
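The internals of `run_dialogue_session` are not shown in this notebook. The general pattern described above, two agents alternating turns with cognitive tracking attached only to the reasoner, can be sketched with stub agents. `Turn`, `run_dialogue`, and the lambda agents are hypothetical; in the real engine each agent would presumably be a model call.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Turn:
    speaker: str
    text: str
    tracked: bool  # True if this turn's generation would be run through the probes


def run_dialogue(
    scenario: str,
    questioner: Callable[[list[Turn]], str],
    reasoner: Callable[[list[Turn]], str],
    n_rounds: int = 2,
) -> list[Turn]:
    history = [Turn("narrator", scenario, tracked=False)]
    for _ in range(n_rounds):
        history.append(Turn("questioner", questioner(history), tracked=False))
        # Only the reasoner's turns are tracked for cognitive actions.
        history.append(Turn("reasoner", reasoner(history), tracked=True))
    return history


transcript = run_dialogue(
    "A candy box actually contains pencils.",
    questioner=lambda h: "What will someone who has not opened the box expect inside?",
    reasoner=lambda h: "Candy, because they can only see the label.",
)
for turn in transcript:
    marker = "*" if turn.tracked else " "
    print(f"{marker} [{turn.speaker}] {turn.text}")
```

Tracking only the reasoner keeps the questioner's text from contaminating the activation statistics, although the reasoner still conditions on the full history.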

## Part 4: Batch Experiment Analysis

Analyze multiple tasks to find patterns across ToM task types.

In [None]:
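# Select a diverse, stratified sample of tasks
sample_size = 10  # Adjust based on available time/compute

# Split the budget evenly across task types (at least one task per type)
per_type = max(1, sample_size // len(ToMTaskType))

diverse_sample = []
for task_type in ToMTaskType:
    type_tasks = [t for t in tasks if t.task_type == task_type]
    diverse_sample.extend(random.sample(type_tasks, min(per_type, len(type_tasks))))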

print(f"Running experiment on {len(diverse_sample)} diverse tasks...")
print("This may take a few minutes...\n")

In [None]:
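# Run batch analysis (create the output directory first in case it does not exist yet)
save_path = Path.cwd().parent / "output" / "tom_experiments" / "batch_results.json"
save_path.parent.mkdir(parents=True, exist_ok=True)

result = engine.analyze_task_suite(
    diverse_sample,
    threshold=0.1,
    save_path=save_path
)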
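Part 6 reads per-type averages from `result.by_task_type` (keys `avg_specificity` and `avg_coverage`). A minimal sketch of how such averages could be computed, assuming each task's signature is reduced to a `(task_type, specificity, coverage)` record. `aggregate_by_type` is hypothetical, and the numbers are placeholders, not experiment results.

```python
from collections import defaultdict
from statistics import mean


def aggregate_by_type(records: list[tuple[str, float, float]]) -> dict[str, dict[str, float]]:
    """Group (task_type, specificity, coverage) records and average each metric."""
    grouped: dict[str, list[tuple[float, float]]] = defaultdict(list)
    for task_type, specificity, coverage in records:
        grouped[task_type].append((specificity, coverage))
    return {
        task_type: {
            "avg_specificity": mean(s for s, _ in rows),
            "avg_coverage": mean(c for _, c in rows),
        }
        for task_type, rows in grouped.items()
    }


# Placeholder values for illustration only
records = [
    ("false_belief", 0.6, 0.8),
    ("false_belief", 0.4, 0.6),
    ("affective_tom", 0.2, 0.5),
]
for task_type, stats in sorted(aggregate_by_type(records).items()):
    print(task_type, stats)
```

With only two tasks per type in this sample, these averages will be noisy; expanding to all 105 tasks gives 15 to 30 tasks per type.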

## Part 5: Comprehensive Visualizations

Generate all analysis visualizations.

In [None]:
# Initialize analyzer
analyzer = ToMAnalyzer(
    output_dir=Path.cwd().parent / "output" / "tom_experiments" / "visualizations"
)

# Generate comprehensive report
analyzer.create_comprehensive_report(result, diverse_sample)

In [None]:
# Display visualizations inline
from IPython.display import Image

viz_dir = Path.cwd().parent / "output" / "tom_experiments" / "visualizations"

print("Cognitive Action Heatmap by Task Type:")
display(Image(filename=str(viz_dir / '01_action_by_tasktype_heatmap.png')))

In [None]:
print("Top ToM-Specific Cognitive Actions:")
display(Image(filename=str(viz_dir / '02_differential_activations.png')))

In [None]:
print("Layer Activation Profile:")
display(Image(filename=str(viz_dir / '03_layer_preferences.png')))
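Part 6 takes the peak layer to be the key with the largest value in `result.tom_layer_preferences`. How that profile is built is not shown here. One plausible definition, used in this sketch as an assumption, is the mean probe activation of the ToM-relevant actions at each probed layer. `layer_profile` is hypothetical and the activations are placeholders.

```python
from statistics import mean


def layer_profile(activations: dict[int, dict[str, float]], tom_actions: set[str]) -> dict[int, float]:
    """Mean activation of the ToM actions at each layer (missing actions count as 0)."""
    return {
        layer: mean(scores.get(a, 0.0) for a in tom_actions)
        for layer, scores in activations.items()
    }


tom_actions = {"perspective_taking", "updating_beliefs"}
# Placeholder activations for three of the probed layers (21-30)
activations = {
    21: {"perspective_taking": 0.2, "updating_beliefs": 0.1},
    25: {"perspective_taking": 0.7, "updating_beliefs": 0.5},
    30: {"perspective_taking": 0.4, "updating_beliefs": 0.6},
}
profile = layer_profile(activations, tom_actions)
peak_layer = max(profile, key=profile.get)  # same argmax rule as Part 6
print(profile)
print(peak_layer)  # 25
```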

In [None]:
print("ToM Action Co-occurrence Network:")
display(Image(filename=str(viz_dir / '07_action_network.png')))
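The co-occurrence network links cognitive actions that are detected in the same task. A minimal sketch of the edge weights, assuming an edge's weight is the number of tasks in which both actions were detected. `ToMAnalyzer` may weight or filter edges differently, `cooccurrence` is hypothetical, and the detection sets are placeholders.

```python
from collections import Counter
from itertools import combinations


def cooccurrence(detections: list[set[str]]) -> Counter:
    """Count, for each unordered pair of actions, how many tasks detected both."""
    edges: Counter = Counter()
    for actions in detections:
        # sorted() gives each pair a canonical order, so (a, b) and (b, a) merge.
        for pair in combinations(sorted(actions), 2):
            edges[pair] += 1
    return edges


# Placeholder detections, one set per task
detections = [
    {"perspective_taking", "updating_beliefs", "metacognitive_monitoring"},
    {"perspective_taking", "updating_beliefs"},
    {"emotion_perception", "perspective_taking"},
]
edges = cooccurrence(detections)
for (a, b), weight in edges.most_common(3):
    print(f"{a} -- {b}: {weight}")
```

Raw counts favor actions that fire on almost every task. A normalized weight such as Jaccard overlap, or pointwise mutual information, would highlight pairs that co-occur more often than chance.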

## Part 6: Key Findings & Interpretations

Let's summarize what we learned about ToM reasoning in this model.

In [None]:
print("="*80)
print("KEY FINDINGS")
print("="*80)

print(f"\n1. OVERALL ToM COGNITIVE SIGNATURE")
print(f"   Average ToM Specificity: {result.avg_tom_specificity:.3f}")
print(f"   Average Expected Coverage: {result.avg_expected_coverage:.1%}")
# The 0.1 cutoff is a heuristic, not a calibrated significance threshold
print(f"   Interpretation: {'Strong' if result.avg_tom_specificity > 0.1 else 'Weak'} ToM-specific cognitive signature")

print(f"\n2. TASK TYPE DIFFERENCES")
for task_type, stats in sorted(result.by_task_type.items()):
    print(f"   {task_type}:")
    print(f"     Specificity: {stats['avg_specificity']:.3f}")
    print(f"     Coverage: {stats['avg_coverage']:.1%}")

print(f"\n3. TOP ToM-SPECIFIC COGNITIVE ACTIONS")
for i, (action, diff) in enumerate(result.tom_action_rankings[:5], 1):
    print(f"   {i}. {action}: {diff:+.4f}")

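print(f"\n4. LAYER PREFERENCES")
peak_layer = max(result.tom_layer_preferences.items(), key=lambda x: x[1])[0]
print(f"   Peak ToM activation at Layer {peak_layer}")
# Probes only cover layers 21-30, so "earlier"/"later" is relative to that range, not the whole network
print(f"   Interpretation: ToM activation peaks in the {'earlier' if peak_layer < 26 else 'later'} half of the probed layers (21-30)")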

print("\n" + "="*80)

## Next Steps

1. **Expand Sample Size**: Run on all 105 tasks for comprehensive results
2. **Compare Models**: Test different Gemma 3 model sizes (e.g., 1B vs. 4B vs. 12B vs. 27B)
3. **Intervention Studies**: Ablate specific layers and measure ToM degradation
4. **Fine-Tuning**: Can we improve ToM by training on cognitive signatures?
5. **Cross-Task Transfer**: Do ToM cognitive patterns transfer to new scenarios?

## Citation

If you use this ToM experiment framework in your research, please cite:
- **Brije**: https://github.com/yourusername/brije
- Classic ToM papers: Wimmer & Perner (1983), Baron-Cohen et al. (1985)