# ToMi Function Vector Extraction

This notebook extracts function vectors for Theory of Mind (ToM) reasoning using the ToMi dataset and NNSight.

**Approach:**
1. Load ToM-required vs no-ToM-required examples (same question format, different reasoning demands)
2. Run both through a model, extracting activations at a target layer
3. Compute: `function_vector = mean(tom_activations) - mean(no_tom_activations)`
4. Validate by steering model behavior with the function vector
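
Step 3 can be checked by hand on a toy case. The sketch below uses plain Python lists and made-up 3-dimensional "activations" rather than real model outputs; `mean_vector` and `difference_of_means` are illustrative names, not part of the notebook.

```python
from statistics import fmean

def mean_vector(rows):
    """Column-wise mean of a list of equal-length vectors."""
    return [fmean(col) for col in zip(*rows)]

def difference_of_means(pos_rows, neg_rows):
    """mean(pos_rows) - mean(neg_rows), computed per dimension."""
    pos_mean = mean_vector(pos_rows)
    neg_mean = mean_vector(neg_rows)
    return [p - n for p, n in zip(pos_mean, neg_mean)]

# Toy activations (made-up numbers, one row per example).
tom_acts = [[1.0, 2.0, 0.0], [3.0, 2.0, 0.0]]     # mean: [2.0, 2.0, 0.0]
no_tom_acts = [[1.0, 0.0, 0.0], [1.0, 2.0, 0.0]]  # mean: [1.0, 1.0, 0.0]

toy_vector = difference_of_means(tom_acts, no_tom_acts)
print(toy_vector)  # [1.0, 1.0, 0.0]
```

Dimensions where the two conditions agree on average (the third one here) contribute nothing to the vector.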

## Setup

In [1]:
import json
import re
import torch
from pathlib import Path
from tqdm import tqdm
from nnsight import LanguageModel

In [2]:
# Paths - adjust if needed
DATA_DIR = Path('tomi/tomi_pairs')
TOM_FILE = DATA_DIR / 'first_order_1_tom.jsonl'
NO_TOM_FILE = DATA_DIR / 'first_order_1_no_tom.jsonl'

# Verify files exist
assert TOM_FILE.exists(), f"Missing: {TOM_FILE}"
assert NO_TOM_FILE.exists(), f"Missing: {NO_TOM_FILE}"
print(f"ToM file: {TOM_FILE}")
print(f"No-ToM file: {NO_TOM_FILE}")

ToM file: tomi/tomi_pairs/first_order_1_tom.jsonl
No-ToM file: tomi/tomi_pairs/first_order_1_no_tom.jsonl


## Prompting Setup

In [3]:
SYSTEM_PROMPT = """You are answering questions about a story. Answer with ONLY the location name in <answer> tags. Example: <answer>blue_pantry</answer>"""

def clean_story(story: str) -> str:
    """Remove leading line numbers (e.g. '1 Isabella entered ...') and blank lines."""
    return '\n'.join(
        # Strip only a leading counter; a line without one is kept unchanged
        re.sub(r'^\d+\s+', '', line.strip())
        for line in story.split('\n') if line.strip()
    )

def make_prompt(example: dict) -> str:
    """Format single example as user prompt."""
    story = clean_story(example['story'])
    return f"""Story:
{story}

Question: {example['question']}
Answer:"""

def load_examples(jsonl_path: str, limit: int = None) -> list:
    """Load examples and add formatted prompts."""
    examples = []
    with open(jsonl_path) as f:
        for i, line in enumerate(f):
            if limit is not None and i >= limit:
                break
            ex = json.loads(line)
            ex['prompt'] = make_prompt(ex)
            examples.append(ex)
    return examples
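
A quick way to sanity-check the prompt format is to run it on a hand-written story. The sketch below is self-contained: `strip_line_numbers` and `format_prompt` are illustrative stand-ins for the helpers above, and the numbered-line input (`1 Isabella entered the den.`) is an assumption about the raw story format, inferred from the "remove line numbers" docstring.

```python
import re

def strip_line_numbers(story: str) -> str:
    """Drop a leading 'N ' counter from each non-blank line (illustrative helper)."""
    return "\n".join(
        re.sub(r"^\d+\s+", "", line.strip())
        for line in story.split("\n")
        if line.strip()
    )

def format_prompt(story: str, question: str) -> str:
    """Story / Question / Answer layout, mirroring the notebook's prompt."""
    return f"Story:\n{strip_line_numbers(story)}\n\nQuestion: {question}\nAnswer:"

# Hand-written input in the assumed raw format (not read from the dataset files).
raw_story = (
    "1 Isabella entered the den.\n"
    "2 Olivia entered the den.\n"
    "3 The broccoli is in the blue_pantry."
)
prompt = format_prompt(raw_story, "Where will Isabella look for the broccoli?")
print(prompt)
```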

## Scoring Functions

In [None]:
def extract_answer(response: str) -> str:
    """Extract answer from <answer> tags."""
    match = re.search(r'<answer>\s*([^<]+)\s*</answer>', response, re.IGNORECASE)
    if match:
        return match.group(1).strip().lower()
    # Fallback: look for location pattern (word_word)
    match = re.search(r'\b(\w+_\w+)\b', response)
    return match.group(1).lower() if match else response.strip().lower()

def score(response: str, correct: str) -> bool:
    """Check if extracted answer matches correct answer."""
    extracted = extract_answer(response)
    return extracted == correct.lower()

def score_batch(responses: list, examples: list) -> dict:
    """Score a batch of responses."""
    correct = sum(score(r, ex['answer']) for r, ex in zip(responses, examples))
    return {
        'accuracy': correct / len(examples),
        'correct': correct,
        'total': len(examples),
    }
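
The fallback branch deserves a quick check, because it returns the first `word_word` token it sees. The sketch below repeats `extract_answer` so it runs on its own; the response strings are made up for illustration.

```python
import re

def extract_answer(response: str) -> str:
    """Copy of the extractor above, repeated so this block is self-contained."""
    match = re.search(r'<answer>\s*([^<]+)\s*</answer>', response, re.IGNORECASE)
    if match:
        return match.group(1).strip().lower()
    match = re.search(r'\b(\w+_\w+)\b', response)
    return match.group(1).lower() if match else response.strip().lower()

responses = [
    "<answer>blue_pantry</answer>",                      # tagged answer
    "<ANSWER> Red_Drawer </ANSWER>",                     # tag case and padding are normalized
    "She will look in the blue_pantry.",                 # no tags: fallback pattern
    "It moved from the blue_pantry to the red_drawer.",  # no tags, two locations
]
results = {r: extract_answer(r) for r in responses}
for response, answer in results.items():
    print(f"{answer!r:15} <- {response}")
```

In the last case the fallback picks `blue_pantry` simply because it appears first, so an untagged response that mentions both locations can be scored as correct or incorrect by accident. Tracking how often the fallback fires would show whether this matters in practice.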

## Load Data

In [4]:
# Load examples (use limit for faster iteration during development)
N_EXAMPLES = 100  # Set to None for all examples

tom_examples = load_examples(TOM_FILE, limit=N_EXAMPLES)
no_tom_examples = load_examples(NO_TOM_FILE, limit=N_EXAMPLES)

print(f"Loaded {len(tom_examples)} ToM examples")
print(f"Loaded {len(no_tom_examples)} No-ToM examples")

Loaded 100 ToM examples
Loaded 100 No-ToM examples


In [5]:
# Inspect a sample
print("=== ToM Example ===")
print(tom_examples[0]['prompt'])
print(f"\nCorrect answer: {tom_examples[0]['answer']}")
print(f"Story type: {tom_examples[0]['story_type']}")

=== ToM Example ===
Story:
Isabella entered the den.
Olivia entered the den.
Isabella dislikes the pumpkin
The broccoli is in the blue_pantry.
Isabella exited the den.
Olivia moved the broccoli to the red_drawer.
Abigail entered the garden.
Isabella entered the garden.

Question: Where will Isabella look for the broccoli?
Answer:

Correct answer: blue_pantry
Story type: false_belief


In [6]:
print("=== No-ToM Example ===")
print(no_tom_examples[0]['prompt'])
print(f"\nCorrect answer: {no_tom_examples[0]['answer']}")
print(f"Story type: {no_tom_examples[0]['story_type']}")

=== No-ToM Example ===
Story:
Aria entered the front_yard.
Aiden entered the front_yard.
The grapefruit is in the green_bucket.
Aria moved the grapefruit to the blue_container.
Aiden exited the front_yard.
Noah entered the playroom.

Question: Where will Aiden look for the grapefruit?
Answer:

Correct answer: blue_container
Story type: true_belief


## Load Model

In [None]:
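# Choose model - smaller models for faster iteration
# Note: Llama 3.1 is a gated repo. Accept the license on the model page and
# authenticate (e.g. `huggingface-cli login`) first; otherwise loading fails
# with a 401 "gated repo" OSError.
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"  # or "mistralai/Mistral-7B-Instruct-v0.2"

model = LanguageModel(MODEL_NAME, device_map="auto")
print(f"Loaded {MODEL_NAME}")
print(f"Number of layers: {len(model.model.layers)}")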

OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct.
401 Client Error. (Request ID: Root=1-69711665-0c1f3982505dc2dc5eea9e04;7b126a48-666e-4096-991a-a6f2d402b720)

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.1-8B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in.

In [None]:
# Target layer for activation extraction: start from the middle of the stack and tune from there
TARGET_LAYER = len(model.model.layers) // 2
print(f"Target layer: {TARGET_LAYER}")

## Helper: Format with Chat Template

In [None]:
def format_chat_prompt(user_prompt: str, system_prompt: str = SYSTEM_PROMPT) -> str:
    """Format prompt using model's chat template."""
    # Hugging Face tokenizers always expose apply_chat_template; check that a template is actually set
    if getattr(model.tokenizer, 'chat_template', None):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        return model.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        return f"{system_prompt}\n\n{user_prompt}"

# Test it
test_formatted = format_chat_prompt(tom_examples[0]['prompt'])
print(test_formatted[:500] + "...")

## Extract Activations

Run each example through the model and save the target layer's output at the last prompt token, i.e. the position from which the first answer token is predicted.

In [None]:
def extract_activations(examples: list, target_layer: int = TARGET_LAYER):
    """Extract activations and generate responses for a list of examples."""
    responses = []
    activations = []
    
    for ex in tqdm(examples, desc="Extracting"):
        full_prompt = format_chat_prompt(ex['prompt'])
        
        with model.generate(full_prompt, max_new_tokens=20) as gen:
            # Save activation at last token position before generation
            act = model.model.layers[target_layer].output[0][:, -1, :].save()
        
        # Decode response
        response_tokens = gen.output[0][len(model.tokenizer.encode(full_prompt)):]
        response_text = model.tokenizer.decode(response_tokens, skip_special_tokens=True)
        
        responses.append(response_text)
        activations.append(act.value.detach().cpu())
    
    return responses, activations
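
The decoding step slices off the prompt by token count, which only works if re-encoding the prompt yields exactly the tokens the model consumed (a duplicated beginning-of-sequence token, for example, would shift the slice by one). A defensive variant is sketched below on made-up token ids; `split_generated` is a hypothetical helper, not an NNSight or tokenizer function.

```python
def split_generated(output_ids, prompt_ids):
    """Return only the newly generated ids, verifying the output starts with the prompt."""
    n = len(prompt_ids)
    if list(output_ids[:n]) != list(prompt_ids):
        raise ValueError("output does not start with the re-encoded prompt tokens")
    return list(output_ids[n:])

# Made-up token ids for illustration.
prompt_ids = [1, 15, 27, 99]
output_ids = [1, 15, 27, 99, 42, 43, 2]

generated = split_generated(output_ids, prompt_ids)
print(generated)  # [42, 43, 2]
```

If the check raises, the prompt was tokenized differently in the two places, and the decoded response would have been missing tokens or would have included part of the prompt.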

In [None]:
# Extract from ToM examples
print("Extracting ToM activations...")
tom_responses, tom_activations = extract_activations(tom_examples)

In [None]:
# Extract from No-ToM examples
print("Extracting No-ToM activations...")
no_tom_responses, no_tom_activations = extract_activations(no_tom_examples)

## Baseline Accuracy

In [None]:
tom_scores = score_batch(tom_responses, tom_examples)
no_tom_scores = score_batch(no_tom_responses, no_tom_examples)

print(f"ToM accuracy:    {tom_scores['accuracy']:.1%} ({tom_scores['correct']}/{tom_scores['total']})")
print(f"No-ToM accuracy: {no_tom_scores['accuracy']:.1%} ({no_tom_scores['correct']}/{no_tom_scores['total']})")
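
With around 100 examples per condition, accuracy differences of a few points can be noise. One rough way to see this is a Wilson score interval on each accuracy; the sketch below uses hypothetical counts, not results from this notebook.

```python
import math

def wilson_interval(correct: int, total: int, z: float = 1.96):
    """Wilson score interval for a binomial proportion (z=1.96 is roughly 95%)."""
    if total <= 0:
        raise ValueError("total must be positive")
    p = correct / total
    denom = 1 + z**2 / total
    center = (p + z**2 / (2 * total)) / denom
    half = z * math.sqrt(p * (1 - p) / total + z**2 / (4 * total**2)) / denom
    return center - half, center + half

# Hypothetical counts for illustration only.
lo, hi = wilson_interval(50, 100)
print(f"{lo:.3f} - {hi:.3f}")  # 0.404 - 0.596
```

An interval this wide suggests that comparing baseline and steered accuracy needs more than a handful of examples before a difference is meaningful.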

In [None]:
# Inspect some responses
print("=== Sample ToM Responses ===")
for i in range(3):
    print(f"\nQ: {tom_examples[i]['question']}")
    print(f"Model: {tom_responses[i]}")
    print(f"Correct: {tom_examples[i]['answer']}")
    print(f"Score: {score(tom_responses[i], tom_examples[i]['answer'])}")

## Compute Function Vector

In [None]:
# Stack activations
tom_acts_tensor = torch.stack(tom_activations).squeeze(1)  # [N, hidden_dim]
no_tom_acts_tensor = torch.stack(no_tom_activations).squeeze(1)

print(f"ToM activations shape: {tom_acts_tensor.shape}")
print(f"No-ToM activations shape: {no_tom_acts_tensor.shape}")

In [None]:
# Compute function vector
tom_mean = tom_acts_tensor.mean(dim=0)
no_tom_mean = no_tom_acts_tensor.mean(dim=0)

function_vector = tom_mean - no_tom_mean

print(f"Function vector shape: {function_vector.shape}")
print(f"Function vector norm: {function_vector.norm():.4f}")

In [None]:
# Save for later use
torch.save({
    'function_vector': function_vector,
    'tom_mean': tom_mean,
    'no_tom_mean': no_tom_mean,
    'target_layer': TARGET_LAYER,
    'model_name': MODEL_NAME,
    'n_tom_examples': len(tom_examples),
    'n_no_tom_examples': len(no_tom_examples),
}, 'tom_function_vector.pt')

print("Saved to tom_function_vector.pt")

## Steering Experiment

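Test whether adding the function vector changes ToM performance. The comparison below reuses examples from the extraction set, so it is a quick sanity check rather than a held-out evaluation.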
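
A held-out evaluation needs examples that did not contribute to the function vector. One option, sketched here on placeholder dicts, is to split the loaded examples before extraction; `split_examples`, `test_fraction`, and `seed` are illustrative names and defaults, not part of the notebook.

```python
import random

def split_examples(examples, test_fraction=0.2, seed=0):
    """Shuffle a copy of `examples` deterministically and return (train, test)."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, int(len(shuffled) * test_fraction))
    return shuffled[n_test:], shuffled[:n_test]

# Placeholder examples; in the notebook these would be the loaded ToM / no-ToM dicts.
toy_examples = [{"id": i} for i in range(10)]
train, test = split_examples(toy_examples)
print(len(train), len(test))  # 8 2
```

The train portion would feed activation extraction, and steering would be scored only on the test portion.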

In [None]:
def generate_with_steering(prompt: str, steering_vector: torch.Tensor, 
                           layer: int, scale: float = 1.0):
    """Generate with function vector added at target layer."""
    full_prompt = format_chat_prompt(prompt)
    steering_vector = steering_vector.to(model.device)
    
    with model.generate(full_prompt, max_new_tokens=20) as gen:
        # Add steering vector to layer output
        model.model.layers[layer].output[0][:, -1, :] += scale * steering_vector
    
    response_tokens = gen.output[0][len(model.tokenizer.encode(full_prompt)):]
    return model.tokenizer.decode(response_tokens, skip_special_tokens=True)

In [None]:
# Test on a few examples
SCALE = 1.0  # Adjust this to control steering strength

print("=== Steering Comparison ===")
for i in range(5):
    ex = tom_examples[i]
    
    # Without steering (use cached response)
    baseline = tom_responses[i]
    
    # With steering
    steered = generate_with_steering(ex['prompt'], function_vector, TARGET_LAYER, scale=SCALE)
    
    print(f"\nQ: {ex['question']}")
    print(f"Correct: {ex['answer']}")
    print(f"Baseline: {baseline} ({'✓' if score(baseline, ex['answer']) else '✗'})")
    print(f"Steered:  {steered} ({'✓' if score(steered, ex['answer']) else '✗'})")

## Next Steps

1. **Tune hyperparameters**: Try different layers, steering scales
2. **Cross-validation**: Use train split for extraction, test for validation
3. **Generalization**: Test function vector on FANToM and SimpleToM
4. **Analysis**: Compare first-order vs second-order vectors (cosine similarity)
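
For the analysis step, cosine similarity between two function vectors can be sketched as follows. The vectors here are made-up 2-dimensional lists; with saved tensors, `torch.nn.functional.cosine_similarity(a, b, dim=0)` computes the same quantity.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)

# Made-up stand-ins for a first-order and a second-order function vector.
first_order = [1.0, 0.0]
second_order = [1.0, 1.0]

similarity = cosine_similarity(first_order, second_order)
print(f"{similarity:.4f}")  # 0.7071
```

Values near 1 would suggest the two vectors point in a similar direction, while values near 0 would suggest largely unrelated directions.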