In [1]:
# Setup and imports
import os
os.chdir('/home/smallyan/eval_agent')

import json
import random
import sys
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Set the repository path and put it (plus the notebook helper directory)
# at the front of sys.path so the project-local imports below resolve.
repo_path = '/net/scratch2/smallyan/belief_tracking_eval'
sys.path.insert(0, repo_path)
sys.path.insert(0, os.path.join(repo_path, 'notebooks', 'causalToM_novis'))

# Check CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed every RNG this notebook has in scope, not just Python's `random`,
# so a Restart & Run All reproduces the same samples and results.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Import dataset utilities.
# These must stay below the sys.path edits above: `src` and `utils` are
# project-local modules, not installed packages.
from src.dataset import Sample, Dataset
from utils import error_detection, get_answer_lookback_payload

# Load synthetic entities (three JSON files under data/synthetic_entities).
# NOTE(review): the mapping bottles -> objects and drinks -> states is taken
# from the variable names used here; confirm against the dataset helpers.
data_path = os.path.join(repo_path, 'data', 'synthetic_entities')
with open(os.path.join(data_path, 'characters.json'), 'r') as f:
    all_characters = json.load(f)
with open(os.path.join(data_path, 'bottles.json'), 'r') as f:
    all_objects = json.load(f)
with open(os.path.join(data_path, 'drinks.json'), 'r') as f:
    all_states = json.load(f)

print(f"Loaded {len(all_characters)} characters, {len(all_objects)} objects, {len(all_states)} states")

Using device: cuda


Loaded 103 characters, 21 objects, 23 states


In [2]:
# Load Llama-3.1-8B-Instruct for GT1 evaluation
from nnsight import LanguageModel

model_name = "meta-llama/Llama-3.1-8B-Instruct"
print("Loading Llama-3.1-8B-Instruct...")

# Loader options gathered in one place: automatic device placement,
# half-precision weights, and dispatch=True as passed to LanguageModel.
load_options = dict(
    device_map="auto",
    torch_dtype=torch.float16,
    dispatch=True,
)
model = LanguageModel(model_name, **load_options)

print(f"Model loaded: {model_name}")
print(f"Number of layers: {model.config.num_hidden_layers}")

Loading Llama-3.1-8B-Instruct...


Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 122] Disk quota exceeded: '/net/projects/chai-lab/shared_models/hub/models--meta-llama--Llama-3.1-8B-Instruct/.no_exist/0e9e39f249a16976918f6564b8830bc894c89659/adapter_config.json'


Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 122] Disk quota exceeded: '/net/projects/chai-lab/shared_models/hub/models--meta-llama--Llama-3.1-8B-Instruct/.no_exist/0e9e39f249a16976918f6564b8830bc894c89659/adapter_config.json'


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded: meta-llama/Llama-3.1-8B-Instruct
Number of layers: 32


In [3]:
# Build the counterfactual dataset for the answer lookback payload test
n_samples = 10
batch_size = 1

# The helper takes the three entity pools positionally, then the sample count.
entity_pools = (all_characters, all_objects, all_states)
dataset_payload = get_answer_lookback_payload(*entity_pools, n_samples)

# shuffle=False keeps batch order fixed, so batch indices are stable across
# passes over the loader.
dataloader_payload = DataLoader(
    dataset_payload,
    shuffle=False,
    batch_size=batch_size,
)
print(f"Created dataset with {len(dataset_payload)} samples")

# Detect errors (samples where the model doesn't answer correctly); the
# first return value of error_detection is not needed here.
print("\nDetecting errors...")
_, errors = error_detection(model, dataloader_payload, is_remote=False)

# Whatever is not flagged as an error is usable for the IIA experiment.
valid_samples = len(dataset_payload) - len(errors)
print(f"Valid samples for IIA: {valid_samples} ({len(errors)} errors)")

Created dataset with 10 samples

Detecting errors...


  0%|          | 0/10 [00:00<?, ?it/s]

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


 10%|█         | 1/10 [00:01<00:11,  1.24s/it]

 20%|██        | 2/10 [00:02<00:08,  1.06s/it]

 30%|███       | 3/10 [00:03<00:07,  1.13s/it]

 40%|████      | 4/10 [00:04<00:06,  1.05s/it]

 50%|█████     | 5/10 [00:05<00:05,  1.03s/it]

 60%|██████    | 6/10 [00:06<00:04,  1.11s/it]

 70%|███████   | 7/10 [00:07<00:03,  1.13s/it]

 80%|████████  | 8/10 [00:08<00:02,  1.07s/it]

 90%|█████████ | 9/10 [00:09<00:01,  1.03s/it]

100%|██████████| 10/10 [00:10<00:00,  1.00it/s]

100%|██████████| 10/10 [00:10<00:00,  1.06s/it]

Valid samples for IIA: 3 (7 errors)





In [4]:
# The 10-sample run left too few valid samples, so regenerate a larger
# dataset. This overwrites dataset_payload / dataloader_payload / errors.
n_samples = 30

# Same helper call as before: three entity pools, then the sample count.
entity_pools = (all_characters, all_objects, all_states)
dataset_payload = get_answer_lookback_payload(*entity_pools, n_samples)

# One sample per batch, fixed order.
dataloader_payload = DataLoader(
    dataset_payload,
    shuffle=False,
    batch_size=1,
)
print(f"Created dataset with {len(dataset_payload)} samples")

# Detect errors again on the larger dataset
print("\nDetecting errors...")
_, errors = error_detection(model, dataloader_payload, is_remote=False)

# Samples not flagged as errors are the ones used for IIA.
valid_samples = len(dataset_payload) - len(errors)
print(f"Valid samples for IIA: {valid_samples} ({len(errors)} errors)")

Created dataset with 30 samples

Detecting errors...


  0%|          | 0/30 [00:00<?, ?it/s]

  3%|▎         | 1/30 [00:01<00:55,  1.93s/it]

  7%|▋         | 2/30 [00:03<00:53,  1.92s/it]

 10%|█         | 3/30 [00:05<00:52,  1.94s/it]

 13%|█▎        | 4/30 [00:07<00:50,  1.93s/it]

 17%|█▋        | 5/30 [00:09<00:48,  1.93s/it]

 20%|██        | 6/30 [00:11<00:46,  1.92s/it]

 23%|██▎       | 7/30 [00:13<00:46,  2.03s/it]

 27%|██▋       | 8/30 [00:15<00:43,  2.00s/it]

 30%|███       | 9/30 [00:17<00:41,  1.98s/it]

 33%|███▎      | 10/30 [00:19<00:39,  1.95s/it]

 37%|███▋      | 11/30 [00:20<00:32,  1.70s/it]

 40%|████      | 12/30 [00:21<00:26,  1.46s/it]

 43%|████▎     | 13/30 [00:22<00:22,  1.30s/it]

 47%|████▋     | 14/30 [00:23<00:19,  1.19s/it]

 50%|█████     | 15/30 [00:24<00:17,  1.14s/it]

 53%|█████▎    | 16/30 [00:25<00:15,  1.08s/it]

 57%|█████▋    | 17/30 [00:26<00:13,  1.04s/it]

 60%|██████    | 18/30 [00:27<00:12,  1.02s/it]

 63%|██████▎   | 19/30 [00:28<00:10,  1.01it/s]

 67%|██████▋   | 20/30 [00:29<00:09,  1.03it/s]

 70%|███████   | 21/30 [00:30<00:08,  1.04it/s]

 73%|███████▎  | 22/30 [00:31<00:07,  1.03it/s]

 77%|███████▋  | 23/30 [00:32<00:06,  1.05it/s]

 80%|████████  | 24/30 [00:33<00:06,  1.01s/it]

 83%|████████▎ | 25/30 [00:34<00:04,  1.01it/s]

 87%|████████▋ | 26/30 [00:35<00:03,  1.03it/s]

 90%|█████████ | 27/30 [00:35<00:02,  1.05it/s]

 93%|█████████▎| 28/30 [00:36<00:01,  1.05it/s]

 97%|█████████▋| 29/30 [00:37<00:00,  1.06it/s]

100%|██████████| 30/30 [00:38<00:00,  1.07it/s]

100%|██████████| 30/30 [00:38<00:00,  1.29s/it]

Valid samples for IIA: 9 (21 errors)





In [5]:
# Run IIA experiment at key layers
# For 32-layer model, we scale the layer findings from 80-layer model
# 70B has 80 layers with effect at ~70% (layer 56+)
# So for 32 layers, expect effects around layer 22+

accs_answer_lookback_payload = {}
patch_layers = [0, 8, 16, 20, 22, 24, 26, 28, 30, 31]

# Filter out the error samples ONCE, before the layer sweep. Which batches are
# valid does not depend on the patched layer, so re-iterating the DataLoader
# and re-checking `errors` for every layer was loop-invariant work.
# The loader was built with shuffle=False, so the batch index `bi` lines up
# with the indices recorded in `errors`.
valid_batches = [
    batch
    for bi, batch in enumerate(dataloader_payload)
    if bi not in errors
]

print("Testing IIA across layers for answer lookback payload...")
print("=" * 60)

for layer_idx in patch_layers:
    correct, total = 0, 0
    for batch in valid_batches:
        # batch_size is 1, so element [0] is the only sample in the batch.
        counterfactual_prompt = batch["counterfactual_prompt"][0]
        clean_prompt = batch["clean_prompt"][0]
        target = batch["target"][0]

        with torch.no_grad():
            # Get the counterfactual layer output (hidden states are first element of tuple).
            # [0, -1] selects batch item 0 at the last index of the second axis
            # (assumed to be the final token position - TODO confirm shape).
            with model.trace(counterfactual_prompt):
                counterfactual_layer_out = model.model.layers[layer_idx].output[0][0, -1].save()

            # Patch into clean and get prediction: overwrite the same slot in
            # the clean run, then take the argmax over lm_head's output there.
            with model.trace(clean_prompt):
                model.model.layers[layer_idx].output[0][0, -1] = counterfactual_layer_out
                pred = model.lm_head.output[0, -1].argmax(dim=-1).save()

            # The intervention counts as successful when the decoded token
            # equals the target (case- and whitespace-insensitive).
            pred_text = model.tokenizer.decode([pred]).lower().strip()
            if pred_text == target.lower().strip():
                correct += 1
            total += 1

            # Release the saved prediction and cached GPU blocks between samples.
            del pred
            torch.cuda.empty_cache()

    # Guard against an empty valid set so the sweep never divides by zero.
    acc = round(correct / total, 3) if total > 0 else 0.0
    print(f"Layer {layer_idx:2d}: IIA = {acc:.3f} ({correct}/{total})")
    accs_answer_lookback_payload[layer_idx] = acc

print("=" * 60)
# max() returns the first layer with the highest accuracy when there are ties.
peak_layer = max(accs_answer_lookback_payload, key=accs_answer_lookback_payload.get)
peak_iia = accs_answer_lookback_payload[peak_layer]
print(f"Peak IIA: {peak_iia:.3f} at layer {peak_layer}")

Testing IIA across layers for answer lookback payload...


Layer  0: IIA = 0.000 (0/9)


Layer  8: IIA = 0.000 (0/9)


Layer 16: IIA = 0.000 (0/9)


Layer 20: IIA = 0.000 (0/9)


Layer 22: IIA = 0.000 (0/9)


Layer 24: IIA = 0.222 (2/9)


Layer 26: IIA = 0.556 (5/9)


Layer 28: IIA = 0.889 (8/9)
