In [1]:
# Complete setup: working dir, import paths, seeded RNG, entity pools, and the model.
import os

# Machine-specific locations. The defaults reproduce the original run; set
# EVAL_AGENT_DIR / BELIEF_TRACKING_REPO to run the notebook elsewhere.
EVAL_AGENT_DIR = os.environ.get('EVAL_AGENT_DIR', '/home/smallyan/eval_agent')
os.chdir(EVAL_AGENT_DIR)

import json
import random
import sys
import torch
from torch.utils.data import DataLoader

repo_path = os.environ.get('BELIEF_TRACKING_REPO', '/net/scratch2/smallyan/belief_tracking_eval')
sys.path.insert(0, repo_path)
# The notebook directory provides the `utils` module imported below.
sys.path.insert(0, os.path.join(repo_path, 'notebooks', 'causalToM_novis'))

random.seed(42)

from src.dataset import Sample, Dataset
from utils import error_detection, get_answer_lookback_payload


def _load_json(path):
    """Load a JSON file decoded as UTF-8.

    An explicit encoding keeps the result independent of the process locale
    (names in the entity files could be non-ASCII — not checked here).
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Load entities
data_path = os.path.join(repo_path, 'data', 'synthetic_entities')
all_characters = _load_json(os.path.join(data_path, 'characters.json'))
all_objects = _load_json(os.path.join(data_path, 'bottles.json'))
all_states = _load_json(os.path.join(data_path, 'drinks.json'))

# Load model
from nnsight import LanguageModel
model = LanguageModel("meta-llama/Llama-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.float16, dispatch=True)
print("Setup complete")

Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 122] Disk quota exceeded: '/net/projects/chai-lab/shared_models/hub/models--meta-llama--Llama-3.1-8B-Instruct/.no_exist/0e9e39f249a16976918f6564b8830bc894c89659/adapter_config.json'


Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 122] Disk quota exceeded: '/net/projects/chai-lab/shared_models/hub/models--meta-llama--Llama-3.1-8B-Instruct/.no_exist/0e9e39f249a16976918f6564b8830bc894c89659/adapter_config.json'


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setup complete


In [2]:
# GT2: generalization to new data instances. Re-seeding with a value different
# from the setup seed (42) draws different entity combinations; a new seed alone
# does not guarantee zero overlap with combinations used earlier.
banner = "=" * 60
print(banner)
print("GT2: Testing generalization to new data instances")
print(banner)

random.seed(999)  # Different from seed used in original experiments

n_samples = 20
new_dataset = get_answer_lookback_payload(all_characters, all_objects, all_states, n_samples)
new_dataloader = DataLoader(new_dataset, batch_size=1, shuffle=False)

# Show one example story from the freshly generated pool.
print(f"Created {len(new_dataset)} new samples with novel entity combinations")
print("\nExample story:")
for field_label, field_key in (
    ("Characters", "clean_characters"),
    ("Objects", "clean_objects"),
    ("States", "clean_states"),
):
    print(f"  {field_label}: {new_dataset[0][field_key]}")

GT2: Testing generalization to new data instances
Created 20 new samples with novel entity combinations

Example story:
  Characters: ['Sara', 'Jim']
  Objects: ['mug', 'drum']
  States: ['port', 'bourbon']


In [3]:
# Keep only the samples the model already answers correctly. With batch_size=1
# the indices from error_detection are assumed to equal dataset indices — confirm.
print("Finding valid samples on new data combinations...")
_, new_errors = error_detection(model, new_dataloader, is_remote=False)
new_valid = list(filter(lambda idx: idx not in new_errors, range(len(new_dataset))))
print(f"Found {len(new_valid)} valid samples out of {len(new_dataset)}")

Finding valid samples on new data combinations...


  0%|          | 0/20 [00:00<?, ?it/s]

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


  5%|▌         | 1/20 [00:01<00:23,  1.25s/it]

 10%|█         | 2/20 [00:02<00:18,  1.05s/it]

 15%|█▌        | 3/20 [00:03<00:17,  1.01s/it]

 20%|██        | 4/20 [00:04<00:15,  1.03it/s]

 25%|██▌       | 5/20 [00:04<00:14,  1.05it/s]

 30%|███       | 6/20 [00:05<00:13,  1.06it/s]

 35%|███▌      | 7/20 [00:06<00:12,  1.08it/s]

 40%|████      | 8/20 [00:07<00:11,  1.08it/s]

 45%|████▌     | 9/20 [00:08<00:10,  1.09it/s]

 50%|█████     | 10/20 [00:09<00:09,  1.08it/s]

 55%|█████▌    | 11/20 [00:10<00:08,  1.09it/s]

 60%|██████    | 12/20 [00:11<00:07,  1.09it/s]

 65%|██████▌   | 13/20 [00:12<00:06,  1.09it/s]

 70%|███████   | 14/20 [00:13<00:05,  1.09it/s]

 75%|███████▌  | 15/20 [00:14<00:04,  1.10it/s]

 80%|████████  | 16/20 [00:14<00:03,  1.10it/s]

 85%|████████▌ | 17/20 [00:16<00:02,  1.05it/s]

 90%|█████████ | 18/20 [00:16<00:01,  1.06it/s]

 95%|█████████▌| 19/20 [00:17<00:00,  1.07it/s]

100%|██████████| 20/20 [00:18<00:00,  1.08it/s]

100%|██████████| 20/20 [00:18<00:00,  1.06it/s]

Found 9 valid samples out of 20





In [4]:
# Interchange-intervention accuracy on the new samples (first 3 valid ones).
# For each layer, the clean run's last-position output of that decoder layer is
# overwritten with the counterfactual run's, and the top-1 next token is
# compared to the sample's target.
test_indices = new_valid[:3]
patch_layers = [0, 10, 20, 24, 28, 31]

print(f"\nRunning IIA on {len(test_indices)} new data samples")
print("=" * 60)


def _patched_top1(example, layer):
    """Return the lower-cased, stripped top-1 token after patching `layer` at the last position."""
    with torch.no_grad():
        with model.trace(example["counterfactual_prompt"]):
            donor = model.model.layers[layer].output[0][0, -1].save()

        with model.trace(example["clean_prompt"]):
            model.model.layers[layer].output[0][0, -1] = donor
            top1 = model.lm_head.output[0, -1].argmax(dim=-1).save()

        return model.tokenizer.decode([top1]).lower().strip()


gt2_accs = {}
for layer_idx in patch_layers:
    hits = 0
    for bi in test_indices:
        item = new_dataset[bi]
        target = item["target"]
        if _patched_top1(item, layer_idx) == target.lower().strip():
            hits += 1
        torch.cuda.empty_cache()

    n_tried = len(test_indices)
    acc = hits / n_tried if n_tried > 0 else 0.0
    print(f"Layer {layer_idx:2d}: IIA = {acc:.3f} ({hits}/{n_tried})")
    gt2_accs[layer_idx] = acc

print("=" * 60)
# max() over a dict keeps the first maximal key, i.e. the earliest layer on ties.
gt2_peak_layer = max(gt2_accs, key=gt2_accs.get)
gt2_peak_iia = gt2_accs[gt2_peak_layer]
print(f"Peak IIA: {gt2_peak_iia:.3f} at layer {gt2_peak_layer}")
print(f"GT2 PASS: {gt2_peak_iia >= 0.33}")


Running IIA on 3 new data samples


Layer  0: IIA = 0.000 (0/3)


Layer 10: IIA = 0.000 (0/3)


Layer 20: IIA = 0.000 (0/3)


Layer 24: IIA = 0.333 (1/3)


Layer 28: IIA = 1.000 (3/3)


Layer 31: IIA = 1.000 (3/3)
Peak IIA: 1.000 at layer 28
GT2 PASS: True


In [5]:
# GT3: method generalizability. The interchange-intervention recipe used for
# belief tracking is re-applied to a related but simpler task (factual object
# location tracking, no belief component) to see whether it still surfaces
# layer-specific behavior.
rule = "=" * 60
print(rule)
print("GT3: Method Generalizability Assessment")
print(rule)

gt3_overview = """
The paper proposes a NEW METHOD: Causal Abstraction with Interchange Interventions
for identifying layer-specific mechanisms in language models.

Key components:
1. Create counterfactual pairs (clean vs modified input)
2. Perform layer-wise activation patching
3. Measure Interchange Intervention Accuracy (IIA)
4. Identify layers where specific information is encoded

To test GT3, we apply this method to a DIFFERENT but related task:
- Original task: Belief tracking (Theory of Mind)
- New task: Simple object location tracking (no belief, just factual)

If the method generalizes, we should be able to identify layer-specific
patterns for factual information retrieval.
"""
print(gt3_overview)

# Prompt template for the factual location-tracking task.
# NOTE(review): the cells visible here build their prompts inline rather than
# formatting this template.
simple_template = (
    "Story: {char1} puts a {obj} in the {loc1}. Then {char1} moves the {obj} to the {loc2}.\n"
    "Question: Where is the {obj}?\n"
    "Answer:"
)

print("Creating simple object tracking dataset...")

GT3: Method Generalizability Assessment

The paper proposes a NEW METHOD: Causal Abstraction with Interchange Interventions
for identifying layer-specific mechanisms in language models.

Key components:
1. Create counterfactual pairs (clean vs modified input)
2. Perform layer-wise activation patching
3. Measure Interchange Intervention Accuracy (IIA)
4. Identify layers where specific information is encoded

To test GT3, we apply this method to a DIFFERENT but related task:
- Original task: Belief tracking (Theory of Mind)
- New task: Simple object location tracking (no belief, just factual)

If the method generalizes, we should be able to identify layer-specific
patterns for factual information retrieval.

Creating simple object tracking dataset...


In [6]:
# Entity pools for the factual location-tracking task: a different task from
# belief tracking, probed with the same interchange-intervention method.
locations = "kitchen bedroom bathroom garage garden basement attic office".split()
objects_simple = "ball book key phone wallet remote watch pen".split()
names = "Alice Bob Carol Dave Eve Frank".split()

def create_location_tracking_sample(names, objects, locations):
    """Build one clean/counterfactual pair using the plain "Answer:" prompt.

    Draws a name, an object, and two distinct locations ``(first, second)`` from
    the module-level ``random`` state, in that order. The clean story moves the
    object from ``first`` to ``second`` (answer ``second``); the counterfactual
    story swaps the two (answer ``first``). ``"target"`` is the answer expected
    after patching counterfactual activations into the clean run, i.e. ``first``.
    """
    person = random.choice(names)
    item = random.choice(objects)
    first, second = random.sample(locations, 2)

    def story(start, end):
        return (
            f"Story: {person} puts a {item} in the {start}. "
            f"Then {person} moves the {item} to the {end}.\n"
            f"Question: Where is the {item}?\nAnswer:"
        )

    return dict(
        clean_prompt=story(first, second),
        clean_ans=second,
        counterfactual_prompt=story(second, first),
        counterfactual_ans=first,
        target=first,
    )

# Draw a reproducible pool of 10 plain-format location-tracking samples.
random.seed(456)
location_samples = [
    create_location_tracking_sample(names, objects_simple, locations)
    for _draw in range(10)
]

print("Example location tracking sample:")
for caption, field in (("Clean", "clean_prompt"), ("Clean answer", "clean_ans"), ("CF answer", "counterfactual_ans")):
    print(f"{caption}: {location_samples[0][field]}")

Example location tracking sample:
Clean: Story: Frank puts a pen in the office. Then Frank moves the pen to the garage.
Question: Where is the pen?
Answer:
Clean answer: garage
CF answer: office


In [7]:
# Before patching, keep only samples where the model's top-1 next token already
# matches the expected answer on BOTH the clean and the counterfactual prompt.
print("Testing model on location tracking task...")
valid_location_samples = []

for idx, sample in enumerate(location_samples):
    with torch.no_grad():
        top1_text = {}
        for prompt_key in ("clean_prompt", "counterfactual_prompt"):
            with model.trace(sample[prompt_key]):
                top1 = model.lm_head.output[0, -1].argmax(dim=-1).save()
            top1_text[prompt_key] = model.tokenizer.decode([top1]).lower().strip()

        clean_ok = top1_text["clean_prompt"] == sample["clean_ans"].lower()
        cf_ok = top1_text["counterfactual_prompt"] == sample["counterfactual_ans"].lower()

        if clean_ok and cf_ok:
            valid_location_samples.append(idx)

        torch.cuda.empty_cache()

print(f"Valid samples: {len(valid_location_samples)}/{len(location_samples)}")

Testing model on location tracking task...


Valid samples: 0/10


In [8]:
# No sample passed, so look at the raw top-1 continuation for the first one.
print("Checking model outputs on location task...")
sample = location_samples[0]

with torch.no_grad():
    with model.trace(sample["clean_prompt"]):
        pred = model.lm_head.output[0, -1].argmax(dim=-1).save()
    pred_text = model.tokenizer.decode([pred])

for caption, value in (("Prompt", sample["clean_prompt"]), ("Expected", sample["clean_ans"])):
    print(f"{caption}: {value}")
print(f"Got: '{pred_text}'")

Checking model outputs on location task...


Prompt: Story: Frank puts a pen in the office. Then Frank moves the pen to the garage.
Question: Where is the pen?
Answer:
Expected: garage
Got: ' The'


In [9]:
# The plain "Answer:" prompt made the model start with " The" rather than a
# location, so v2 uses an instruction-style prompt that ends right before the
# location word, closer to the belief-tracking prompt format.

def create_location_tracking_sample_v2(names, objects, locations):
    """Build one clean/counterfactual pair with the instruction-style prompt.

    Sampling matches the plain version: a name, an object, then two distinct
    locations ``(here, there)`` drawn from the module-level ``random`` state in
    that order. The clean story ends in ``there``; the counterfactual swaps the
    locations and ends in ``here``, which is also the post-patching ``"target"``.
    Prompts end with "The <obj> is in the" so the next token should be the location.
    """
    who = random.choice(names)
    what = random.choice(objects)
    here, there = random.sample(locations, 2)

    def render(put_in, moved_to):
        header = "Instruction: Answer the question based on the story.\n\n"
        story = f"Story: {who} puts a {what} in the {put_in}. Then {who} moves the {what} to the {moved_to}.\n\n"
        question = f"Question: Where is the {what} now?\nAnswer: The {what} is in the"
        return header + story + question

    return {
        "clean_prompt": render(here, there),
        "clean_ans": there,
        "counterfactual_prompt": render(there, here),
        "counterfactual_ans": here,
        "target": here,
    }

# Reproducible pool of 10 instruction-format samples.
random.seed(789)
location_samples_v2 = [
    create_location_tracking_sample_v2(names, objects_simple, locations)
    for _draw in range(10)
]

# Spot-check the new format on the first sample before scanning the whole pool.
sample = location_samples_v2[0]
with torch.no_grad():
    with model.trace(sample["clean_prompt"]):
        pred = model.lm_head.output[0, -1].argmax(dim=-1).save()
    pred_text = model.tokenizer.decode([pred])

for caption, value in (("Prompt", sample["clean_prompt"]), ("Expected", sample["clean_ans"])):
    print(f"{caption}: {value}")
print(f"Got: '{pred_text}'")

Prompt: Instruction: Answer the question based on the story.

Story: Dave puts a ball in the office. Then Dave moves the ball to the bedroom.

Question: Where is the ball now?
Answer: The ball is in the
Expected: bedroom
Got: ' bedroom'


In [10]:
# The spot check passed; now keep every v2 sample whose top-1 next token matches
# the expected answer on both the clean and the counterfactual prompt.
print("Testing model on location tracking task v2...")
valid_location_samples = []

for idx, sample in enumerate(location_samples_v2):
    with torch.no_grad():
        top1_text = {}
        for prompt_key in ("clean_prompt", "counterfactual_prompt"):
            with model.trace(sample[prompt_key]):
                top1 = model.lm_head.output[0, -1].argmax(dim=-1).save()
            top1_text[prompt_key] = model.tokenizer.decode([top1]).lower().strip()

        clean_ok = top1_text["clean_prompt"] == sample["clean_ans"].lower()
        cf_ok = top1_text["counterfactual_prompt"] == sample["counterfactual_ans"].lower()

        if clean_ok and cf_ok:
            valid_location_samples.append(idx)

        torch.cuda.empty_cache()

print(f"Valid samples: {len(valid_location_samples)}/{len(location_samples_v2)}")

Testing model on location tracking task v2...


Valid samples: 10/10
