#### Imports

In [1]:
# Patching Analysis: Word Counting Task
# Investigating whether a model (Falcon3-7B-Instruct) maintains internal running-count representations

import pandas as pd
import numpy as np
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from nnsight import LanguageModel
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import random
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Seed every RNG this notebook imports (stdlib random, NumPy, PyTorch) with the
# same value so that random draws made later are repeatable after a kernel restart.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(42)

print("Libraries loaded successfully!")

Libraries loaded successfully!


#### Load dataset and benchmarking results

In [3]:
# Load the original dataset
def load_jsonl(file_path):
    """Read a JSON Lines file into a list of parsed records.

    Args:
        file_path: Path to a text file holding one JSON document per line.

    Returns:
        list: The decoded objects, in file order. Blank or whitespace-only
        lines are skipped.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            # A blank line (typically a trailing newline at end of file) holds
            # no record; passing "" to json.loads would raise.
            if not stripped:
                continue
            records.append(json.loads(stripped))
    return records

# Read the task examples (JSONL) and the benchmark results (CSV)
print("Loading datasets...")
dataset = load_jsonl('word_counting_dataset.jsonl')
results_df = pd.read_csv('evaluation_results_20250531_030418.csv')

# Report how many rows each source contributed
for label, table, unit in (
    ("Dataset", dataset, "examples"),
    ("Results", results_df, "evaluations"),
):
    print(f"{label} size: {len(table)} {unit}")

Loading datasets...
Dataset size: 5000 examples
Results size: 20000 evaluations


In [4]:
# Peek at one record from each source and list the models present in the results
first_example = dataset[0]
first_result = results_df.iloc[0]
model_names = results_df['model'].unique()

for heading, value in (
    ("\nFirst example from dataset:", first_example),
    ("\nFirst row from results:", first_result),
    ("\nUnique models in results:", model_names),
):
    print(heading)
    print(value)


First example from dataset:
{'type': 'sports', 'list': ['book', 'tennis', 'soccer', 'cycling', 'apple', 'boxing', 'golf', 'chair', 'peach', 'horse', 'tennis'], 'answer': 4}

First row from results:
example_id                                           0
type                                            sports
list_length                                         11
correct_answer                                       4
raw_response                                        3)
predicted_answer                                     3
correct                                          False
model               mistralai/Mistral-7B-Instruct-v0.3
Name: 0, dtype: object

Unique models in results:
['mistralai/Mistral-7B-Instruct-v0.3' 'meta-llama/Llama-3.2-3B-Instruct'
 'openai-community/gpt2-large' 'tiiuae/Falcon3-7B-Instruct']


#### Filter data for analysis

In [5]:
def create_prompt(example):
    """Build the word-counting prompt for one dataset example.

    Args:
        example: Mapping with a 'type' entry (category name) and a 'list'
            entry (sequence of words).

    Returns:
        str: A four-line prompt that ends with an open parenthesis, ready for
        the model to complete with the numeric answer.
    """
    words = " ".join(example['list'])

    instruction = (
        "Count the number of words in the following list that match the given type, "
        "and put the numerical answer in parentheses."
    )
    # The four leading spaces on the continuation lines are part of the prompt
    # text itself and must be kept exactly as-is.
    prompt_lines = [
        instruction,
        f"    Type: {example['type']}",
        f"    List: [{words}]",
        "    Answer: (",
    ]
    return "\n".join(prompt_lines)

In [6]:
def calculate_running_count(example, position):
    """Count how many words up to and including `position` belong to the target type.

    Args:
        example: Mapping with a 'type' entry (category name) and a 'list'
            entry (sequence of words).
        position: Zero-based index of the last list element to include.
            Values past the end of the list are clamped to the list length;
            negative values yield a count of 0.

    Returns:
        int: Number of words in ``example['list'][:position + 1]`` found in the
        hard-coded vocabulary for ``example['type']``. A type that has no
        vocabulary here always yields 0.
    """
    # Define category mappings (you may need to expand this)
    category_mappings = {
        'fruits': ['apple', 'banana', 'cherry', 'grape', 'mango', 'peach', 'watermelon', 'pineapple', 'strawberry'],
        'animals': ['cat', 'dog', 'horse', 'rabbit', 'snake', 'tiger', 'elephant', 'penguin', 'dolphin', 'eagle', 'mouse', 'bird', 'cow'],
        'cities': ['Paris', 'London', 'Tokyo', 'Berlin', 'Mumbai', 'Bangkok', 'Sydney', 'Rio', 'Toronto'],
        'sports': ['tennis', 'soccer', 'cycling', 'boxing', 'golf', 'basketball', 'baseball', 'hockey', 'volleyball', 'swimming'],
        'professions': ['doctor', 'teacher', 'lawyer', 'engineer', 'pilot', 'chef', 'nurse', 'architect', 'mechanic', 'firefighter']
    }

    # Lowercase the vocabulary as well as the list words. The list words were
    # already lowercased before lookup, but the 'cities' vocabulary is
    # capitalised, so without this no city could ever match.
    target_words = {word.lower() for word in category_mappings.get(example['type'], [])}

    # Clamp to [0, len(list)] so a negative position counts nothing rather
    # than slicing from the end of the list.
    upto = max(0, min(position + 1, len(example['list'])))

    return sum(1 for word in example['list'][:upto] if word.lower() in target_words)

In [7]:
# Keep only the Falcon3 rows; .copy() detaches the subset from results_df
is_falcon_row = results_df['model'] == 'tiiuae/Falcon3-7B-Instruct'
falcon_results = results_df[is_falcon_row].copy()

In [8]:
# Sanity check: number of result rows kept for Falcon
falcon_row_count = len(falcon_results)
print(f"Falcon results: {falcon_row_count} examples")

Falcon results: 5000 examples


In [9]:
# Accumulators for the filtering pass: the ids that pass, and their full records
filtered_indices, filtered_data = [], []

In [10]:
# Start from empty accumulators every time this cell runs. Without this,
# re-executing the cell would append a second copy of every selected example.
filtered_indices = []
filtered_data = []

# Select examples that Falcon answered correctly, whose list has at most
# 8 words and whose ground-truth answer is between 3 and 6.
for _row_label, row in falcon_results.iterrows():
    example_id = row['example_id']

    # A result row pointing past the end of the dataset has no example to join
    if example_id >= len(dataset):
        continue

    # Get corresponding example from dataset
    example = dataset[example_id]

    # Apply filters
    list_length = len(example['list'])
    target_count = example['answer']
    is_correct = row['correct']

    if list_length <= 8 and target_count in [3, 4, 5, 6] and is_correct:
        filtered_indices.append(example_id)
        filtered_data.append({
            'example_id': example_id,
            'example': example,
            'list_length': list_length,
            'target_count': target_count,
            'falcon_prediction': row['predicted_answer']
        })

In [11]:
# Report how many examples survived and how they split by ground-truth answer
print(f"\nFiltered dataset: {len(filtered_data)} examples")
print("Distribution by target count:")

count_dist = defaultdict(int)
for target in (item['target_count'] for item in filtered_data):
    count_dist[target] = count_dist[target] + 1

# Convert to a plain dict so the output shows only the counts
print(dict(count_dist))


Filtered dataset: 514 examples
Distribution by target count:
{5: 137, 4: 150, 3: 150, 6: 77}


#### Generate control pairs

In [12]:
def find_running_count_at_positions(example):
    """Map every list index of `example` to the running count at that index.

    Args:
        example: Mapping with a 'list' entry; it is passed through unchanged
            to `calculate_running_count`.

    Returns:
        dict: ``{index: calculate_running_count(example, index)}`` for each
        index of ``example['list']``, in ascending order.
    """
    return {
        index: calculate_running_count(example, index)
        for index, _word in enumerate(example['list'])
    }

In [13]:
def generate_control_pairs(filtered_data, pairs_per_condition=100):
    """Sample source/target example pairs for a 2x2 (count x position) control design.

    Each pair joins two different examples with one position chosen in each.
    The four conditions cross "running count equal or different" with
    "position equal or different". Only positions whose running count is
    greater than 0 are eligible.

    Args:
        filtered_data: List of dicts, each with at least 'example_id' and
            'example' (the latter holding a 'list'). NOTE: every item is
            mutated in place: a 'running_counts' dict is added to it.
        pairs_per_condition: Number of pairs to collect per condition.

    Returns:
        dict: Condition name -> list of pair dicts with keys 'source_example',
        'target_example', 'source_position', 'target_position',
        'source_count', 'target_count' and 'condition'. A condition can hold
        fewer than `pairs_per_condition` pairs if the attempt budget runs out;
        no error is raised in that case.

    Sampling uses the global `random` module, so results depend on its
    current state. Progress and a summary are printed.
    """
    
    # First, calculate running counts for all examples
    print("Calculating running counts for all examples...")
    for item in tqdm(filtered_data):
        item['running_counts'] = find_running_count_at_positions(item['example'])
    
    control_pairs = {
        'same_count_same_pos': [],      # Positive Control 1
        'same_count_diff_pos': [],      # Positive Control 2  
        'diff_count_same_pos': [],      # Negative Control 1
        'diff_count_diff_pos': []       # Negative Control 2
    }
    
    print(f"\nGenerating {pairs_per_condition} pairs for each control condition...")
    
    # Generate pairs for each condition
    max_attempts = len(filtered_data) * 10  # Prevent infinite loops
    
    for condition in control_pairs.keys():
        # The attempt budget is reset for every condition
        attempts = 0
        while len(control_pairs[condition]) < pairs_per_condition and attempts < max_attempts:
            attempts += 1
            
            # Randomly select two examples
            example1 = random.choice(filtered_data)
            example2 = random.choice(filtered_data)
            
            # Drawing the same example twice still uses up one attempt
            if example1['example_id'] == example2['example_id']:
                continue
                
            # Find valid position pairs based on condition
            valid_pairs = []
            
            # Enumerate every (pos1, pos2) combination. The order in which
            # valid_pairs is filled affects what random.choice picks below.
            for pos1 in range(len(example1['example']['list'])):
                for pos2 in range(len(example2['example']['list'])):
                    count1 = example1['running_counts'][pos1]
                    count2 = example2['running_counts'][pos2]
                    
                    if condition == 'same_count_same_pos':
                        if count1 == count2 and pos1 == pos2 and count1 > 0:
                            valid_pairs.append((pos1, pos2))
                    elif condition == 'same_count_diff_pos':
                        if count1 == count2 and pos1 != pos2 and count1 > 0:
                            valid_pairs.append((pos1, pos2))
                    elif condition == 'diff_count_same_pos':
                        if count1 != count2 and pos1 == pos2 and min(count1, count2) > 0:
                            valid_pairs.append((pos1, pos2))
                    elif condition == 'diff_count_diff_pos':
                        if count1 != count2 and pos1 != pos2 and min(count1, count2) > 0:
                            valid_pairs.append((pos1, pos2))
            
            # Keep one position pair per accepted example pair, chosen at random
            if valid_pairs:
                pos1, pos2 = random.choice(valid_pairs)
                pair = {
                    'source_example': example1,
                    'target_example': example2,
                    'source_position': pos1,
                    'target_position': pos2,
                    'source_count': example1['running_counts'][pos1],
                    'target_count': example2['running_counts'][pos2],
                    'condition': condition
                }
                control_pairs[condition].append(pair)
    
    # Print summary
    print("\nControl pairs generated:")
    for condition, pairs in control_pairs.items():
        print(f"{condition}: {len(pairs)} pairs")
        if len(pairs) > 0:
            # Show the first pair of each condition as a quick sanity check
            sample_pair = pairs[0]
            print(f"  Example: count {sample_pair['source_count']}@pos{sample_pair['source_position']} → count {sample_pair['target_count']}@pos{sample_pair['target_position']}")
    
    return control_pairs

In [14]:
# Build the control set: up to 100 source/target pairs per condition
control_pairs = generate_control_pairs(filtered_data=filtered_data, pairs_per_condition=100)

Calculating running counts for all examples...


100%|██████████████████████████████████████| 514/514 [00:00<00:00, 73216.92it/s]


Generating 100 pairs for each control condition...

Control pairs generated:
same_count_same_pos: 100 pairs
  Example: count 1@pos0 → count 1@pos0
same_count_diff_pos: 100 pairs
  Example: count 2@pos1 → count 2@pos4
diff_count_same_pos: 100 pairs
  Example: count 4@pos3 → count 2@pos3
diff_count_diff_pos: 100 pairs
  Example: count 3@pos2 → count 5@pos4





#### Load Falcon model with NNsight

In [32]:
print("Loading Falcon model with NNsight...")

# Wrap the checkpoint in an NNsight LanguageModel. No device arguments are
# passed, so placement is left to the library defaults.
# NOTE(review): confirm which device the weights end up on.
falcon_checkpoint = "tiiuae/Falcon3-7B-Instruct"
model = LanguageModel(falcon_checkpoint)

print("✓ Model loaded successfully!")

Loading Falcon model with NNsight...
✓ Model loaded successfully!


In [None]:
# Smoke-test input: the prompt built from the first filtered example
first_filtered_item = filtered_data[0]
test_example = first_filtered_item['example']
test_prompt = create_prompt(test_example)

print(f"\nTest prompt: {test_prompt}")

In [19]:
# Use the proper NNsight generate syntax: an empty tracing block simply runs
# generation once, which is enough to confirm the model executes end to end.
with model.generate(test_prompt, max_new_tokens=5, scan=False, validate=False) as tracer:
    pass

print("✓ Model test successful!")

# Bind both names up front. Previously they stayed undefined when none of the
# known layouts matched, so any later use would raise NameError.
layers = None
layer_access = "unknown"

# Inspect model architecture
print(f"\nModel architecture:")
try:
    # Try different possible layer access patterns for Falcon3
    if hasattr(model, 'model') and hasattr(model.model, 'layers'):
        layers = model.model.layers
        print(f"Number of layers: {len(layers)} (model.model.layers)")
        layer_access = "model.layers"
    elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
        layers = model.transformer.h
        print(f"Number of layers: {len(layers)} (model.transformer.h)")
        layer_access = "transformer.h"
    else:
        # Fallback: probe a few common attribute names for a sized container
        print("Exploring model structure...")
        for attr in ['model', 'transformer']:
            if hasattr(model, attr):
                submodel = getattr(model, attr)
                for subattr in ['layers', 'h', 'decoder']:
                    if hasattr(submodel, subattr):
                        layers = getattr(submodel, subattr)
                        if hasattr(layers, '__len__'):
                            print(f"Found {len(layers)} layers at {attr}.{subattr}")
                            layer_access = f"{attr}.{subattr}"
                            # Leaves only the inner loop; the outer loop still
                            # goes on to probe the next top-level attribute.
                            break
except Exception as e:
    print(f"Error inspecting architecture: {e}")
    # We'll detect this during the actual experiments
    layer_access = "unknown"

Loading Falcon model with NNsight...
✓ Model loaded successfully!

Test prompt: Count the number of words in the following list that match the given type, and put the numerical answer in parentheses.
    Type: fruits
    List: [watermelon grape apple peach chef pineapple]
    Answer: (


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.


✓ Model test successful!

Model architecture:
Number of layers: 28 (model.model.layers)


#### Record clean activations (baseline)

In [24]:
# ===================================================================
# CELL 5: Record Clean Activations (Baseline)
# ===================================================================

def record_clean_activations(model, examples, max_examples=50):
    """Run unpatched generations and capture the output of every decoder layer.

    Args:
        model: NNsight-wrapped language model exposing `model.model.layers`,
            `model.generator` and `model.generate(...)`.
        examples: List of dicts with 'example', 'example_id' and
            'target_count' entries (the shape of `filtered_data` items).
        max_examples: Only the first `max_examples` entries are processed.

    Returns:
        list[dict]: One record per successfully processed example, with keys
        'example_id', 'prompt', 'output' (the saved generator output),
        'activations' (layer index -> detached CPU tensor) and 'target_count'.
        An example that raises is reported with a print and skipped, so the
        result can be shorter than the input.
    """

    print(f"Recording clean activations for {min(max_examples, len(examples))} examples...")

    # Read the layer count from the model instead of hard-coding 28, so the
    # function also works for checkpoints with a different depth.
    num_layers = len(model.model.layers)

    clean_data = []

    for i, item in enumerate(tqdm(examples[:max_examples])):
        example = item['example']
        prompt = create_prompt(example)

        try:
            # Layer index -> saved proxy, filled inside the tracing context
            saved_activations = {}
            saved_output = None

            with model.generate(prompt, max_new_tokens=5, scan=False, validate=False) as tracer:
                for layer_idx in range(num_layers):
                    # Element 0 of the layer output is saved as the hidden
                    # states (a [1, 48, 3072] tensor in the recorded run).
                    layer_output = model.model.layers[layer_idx].output
                    saved_activations[layer_idx] = layer_output[0].save()

                saved_output = model.generator.output.save()

            # Now access the saved values (outside the context)
            final_activations = {}
            for layer_idx in range(num_layers):
                # A saved object may be a proxy with .value or already a tensor
                activation = saved_activations[layer_idx]
                if hasattr(activation, 'value'):
                    final_activations[layer_idx] = activation.value.detach().cpu()
                else:
                    final_activations[layer_idx] = activation.detach().cpu()

            # Handle output similarly
            if hasattr(saved_output, 'value'):
                final_output = saved_output.value
            else:
                final_output = saved_output

            clean_data.append({
                'example_id': item['example_id'],
                'prompt': prompt,
                'output': final_output,
                'activations': final_activations,
                'target_count': item['target_count']
            })

        except Exception as e:
            # Best effort: report the failure and carry on with the next example
            print(f"Error processing example {i}: {e}")
            continue

    print(f"✓ Recorded {len(clean_data)} clean activations")
    return clean_data

In [25]:
# Record clean (unpatched) activations for the first five filtered examples
print("Starting baseline activation recording...")
baseline_activations = record_clean_activations(model, filtered_data, max_examples=5)

if not baseline_activations:
    print("❌ No baseline activations recorded - check for errors above")
else:
    first_record = baseline_activations[0]
    print(f"✓ Baseline recording successful!")
    print(f"Activation shape example: {first_record['activations'][0].shape}")
    print(f"Number of layers recorded: {len(first_record['activations'])}")

Starting baseline activation recording...
Recording clean activations for 5 examples...


  0%|                                                     | 0/5 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.
 20%|█████████                                    | 1/5 [00:02<00:11,  2.99s/it]Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.
 40%|██████████████████                           | 2/5 [00:05<00:08,  2.97s/it]Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.
 60%|███████████████████████████                  | 3/5 [00:08<00:05,  2.56s/it]Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.
 80%|████████████████████████████████████         | 4/5 [00:11<00:02,  2.73s/it]Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.
100%|█████████████████████████████████████████████| 5/5 [00:13<00:00,  2.60s/it]

✓ Recorded 5 clean activations
✓ Baseline recording successful!
Activation shape example: torch.Size([1, 48, 3072])
Number of layers recorded: 28





In [26]:
# ===================================================================
# Inspect Baseline Outputs
# ===================================================================

import re  # imported once here instead of on every loop iteration

print("Examining model outputs from baseline recording:")
print("=" * 60)

# Digits immediately followed by ')': the model is asked to answer as "(N)"
answer_pattern = re.compile(r'(\d+)\)')

for i, data in enumerate(baseline_activations):
    print(f"\nExample {i+1}:")
    print(f"Example ID: {data['example_id']}")
    print(f"Target count: {data['target_count']}")

    # Show the prompt
    prompt = data['prompt']
    print(f"Prompt: {prompt}")

    # Decode and show the output
    output = data['output']
    if hasattr(output, 'cpu'):
        output = output.cpu()

    if torch.is_tensor(output):
        tokens = output.tolist()
        if isinstance(tokens[0], list):
            tokens = tokens[0]  # Remove batch dimension if present
    else:
        tokens = output

    # Decode the full sequence (the decoded text includes the prompt)
    full_response = model.tokenizer.decode(tokens, skip_special_tokens=True)
    print(f"Full response: {full_response}")

    # Try to extract the answer
    answer_match = answer_pattern.search(full_response)
    if answer_match:
        predicted_answer = int(answer_match.group(1))
        correct = predicted_answer == data['target_count']
        print(f"Predicted answer: {predicted_answer} {'✓' if correct else '✗'}")
    else:
        print("Could not extract numerical answer")

    print("-" * 40)

# Summary statistics
print(f"\nSummary:")
print(f"Total examples processed: {len(baseline_activations)}")
# Recording can yield zero examples; only index into the first one if it exists
if baseline_activations:
    print(f"Activation tensor shape: {baseline_activations[0]['activations'][0].shape}")
    print(f"Number of layers captured: {len(baseline_activations[0]['activations'])}")

Examining model outputs from baseline recording:

Example 1:
Example ID: 4
Target count: 5
Prompt: Count the number of words in the following list that match the given type, and put the numerical answer in parentheses.
    Type: fruits
    List: [watermelon grape apple peach chef pineapple]
    Answer: (
Full response: Count the number of words in the following list that match the given type, and put the numerical answer in parentheses.
    Type: fruits
    List: [watermelon grape apple peach chef pineapple]
    Answer: (5)

<
Predicted answer: 5 ✓
----------------------------------------

Example 2:
Example ID: 14
Target count: 5
Prompt: Count the number of words in the following list that match the given type, and put the numerical answer in parentheses.
    Type: professions
    List: [Sydney chef basketball teacher doctor lawyer engineer]
    Answer: (
Full response: Count the number of words in the following list that match the given type, and put the numerical answer in parenthes

#### Experiment 1: Which layers encode running count?

In [49]:
# # ===================================================================
# # CELL 6: Test Basic Activation Access (No Patching Yet)
# # ===================================================================

# def test_activation_access(model, example, layer_idx, position):
#     """Test if we can access activations at specific positions"""
    
#     prompt = create_prompt(example['example'])
#     print(f"Testing layer {layer_idx}, position {position}")
#     print(f"Prompt: {prompt[:100]}...")
    
#     with model.generate(prompt, max_new_tokens=5, scan=False, validate=False, pad_token_id=model.tokenizer.eos_token_id) as tracer:
#         # Try to access the layer output
#         layer_output = model.model.layers[layer_idx].output
        
#         # Try to access specific position
#         batch_output = layer_output[0]  # Batch dimension
#         position_output = batch_output[position]  # Position dimension
        
#         # Save the activation
#         saved_activation = position_output.save()
        
#         # Also save the full output for comparison
#         full_output = model.generator.output.save()
    
#     # Access the saved values
#     if hasattr(saved_activation, 'value'):
#         activation_value = saved_activation.value
#     else:
#         activation_value = saved_activation
        
#     if hasattr(full_output, 'value'):
#         output_value = full_output.value
#     else:
#         output_value = full_output
        
#     print(f"✓ Activation shape: {activation_value.shape}")
#     print(f"✓ Output shape: {output_value.shape}")
    
#     # Decode the output
#     if hasattr(output_value, 'cpu'):
#         tokens = output_value.cpu()
#     else:
#         tokens = output_value
        
#     if torch.is_tensor(tokens):
#         token_list = tokens.tolist()
#         if isinstance(token_list[0], list):
#             token_list = token_list[0]
#     else:
#         token_list = tokens
        
#     response = model.tokenizer.decode(token_list, skip_special_tokens=True)
#     print(f"Response: ...{response[-50:]}")
    
#     return activation_value, output_value

# # Test activation access on a few examples
# print("Testing basic activation access...")

# test_examples = filtered_data[:2]
# test_layer = 5
# test_position = 0

# for i, example in enumerate(test_examples):
#     print(f"\n--- Example {i+1} ---")
#     try:
#         activation, output = test_activation_access(model, example, test_layer, test_position)
#         print("✓ Success!")
#     except Exception as e:
#         print(f"✗ Error: {e}")
#         import traceback
#         traceback.print_exc()
#         break

# print("\nIf basic access works, we can try simple patching next...")

Testing basic activation access...

--- Example 1 ---
Testing layer 5, position 0
Prompt: Count the number of words in the following list that match the given type, and put the numerical ans...
✓ Activation shape: torch.Size([48, 3072])
✓ Output shape: torch.Size([1, 53])
Response: ...ape apple peach chef pineapple]
    Answer: (5)

<
✓ Success!

--- Example 2 ---
Testing layer 5, position 0
Prompt: Count the number of words in the following list that match the given type, and put the numerical ans...
✓ Activation shape: torch.Size([48, 3072])
✓ Output shape: torch.Size([1, 53])
Response: ...teacher doctor lawyer engineer]
    Answer: (5)

<
✓ Success!

If basic access works, we can try simple patching next...


In [54]:
# # ===================================================================
# # CELL 6: Working Activation Modification (FINALLY FIXED!)
# # ===================================================================

# def test_correct_modification(model, prompt, layer_idx, position):
#     """Test modification with correct tensor indexing"""
    
#     print(f"Testing modification at layer {layer_idx}, position {position}")
    
#     # Get clean output
#     with model.generate(prompt, max_new_tokens=5, scan=False, validate=False, pad_token_id=model.tokenizer.eos_token_id) as tracer:
#         clean_output = model.generator.output.save()
    
#     clean_answer = extract_clean_answer(model, clean_output)
#     print(f"Clean answer: {clean_answer}")
    
#     # Now modify with correct indexing
#     with model.generate(prompt, max_new_tokens=5, scan=False, validate=False, pad_token_id=model.tokenizer.eos_token_id) as tracer:
#         layer_output = model.model.layers[layer_idx].output
        
#         # layer_output is tuple, first element is tensor [1, 48, 3072]
#         hidden_states = layer_output[0]  # Get the tensor [1, 48, 3072]
        
#         # Now access [batch=0, position, hidden_dim]
#         target_activation = hidden_states[0, position, :]  # Shape: [3072]
        
#         # Zero it out
#         hidden_states[0, position, :] = torch.zeros_like(target_activation)
        
#         modified_output = model.generator.output.save()
    
#     modified_answer = extract_clean_answer(model, modified_output)
#     print(f"Modified answer: {modified_answer}")
    
#     return clean_answer, modified_answer

# def test_cross_example_patch(model, source_example, target_example, layer_idx, position):
#     """Test patching between examples with correct indexing"""
    
#     source_prompt = create_prompt(source_example['example'])
#     target_prompt = create_prompt(target_example['example'])
    
#     print(f"Cross-example patch: layer {layer_idx}, position {position}")
#     print(f"Source count: {source_example['target_count']}")
#     print(f"Target count: {target_example['target_count']}")
    
#     # Get clean target output
#     with model.generate(target_prompt, max_new_tokens=5, scan=False, validate=False, pad_token_id=model.tokenizer.eos_token_id) as tracer:
#         clean_output = model.generator.output.save()
    
#     clean_answer = extract_clean_answer(model, clean_output)
#     print(f"Clean target answer: {clean_answer}")
    
#     # Get source activation
#     with model.generate(source_prompt, max_new_tokens=5, scan=False, validate=False, pad_token_id=model.tokenizer.eos_token_id) as tracer:
#         source_layer = model.model.layers[layer_idx].output
#         source_hidden = source_layer[0]  # [1, seq_len, hidden_dim]
#         source_activation = source_hidden[0, position, :].save()  # [hidden_dim]
    
#     # Apply patch to target
#     with model.generate(target_prompt, max_new_tokens=5, scan=False, validate=False, pad_token_id=model.tokenizer.eos_token_id) as tracer:
#         target_layer = model.model.layers[layer_idx].output
#         target_hidden = target_layer[0]  # [1, seq_len, hidden_dim]
        
#         # Get the saved activation value
#         if hasattr(source_activation, 'value'):
#             activation_value = source_activation.value
#         else:
#             activation_value = source_activation
        
#         # Apply patch
#         target_hidden[0, position, :] = activation_value
        
#         patched_output = model.generator.output.save()
    
#     patched_answer = extract_clean_answer(model, patched_output)
#     print(f"Patched answer: {patched_answer}")
    
#     return clean_answer, patched_answer

# def extract_clean_answer(model, output):
#     """Extract answer from model output"""
#     if hasattr(output, 'value'):
#         tokens = output.value
#     else:
#         tokens = output
        
#     if hasattr(tokens, 'cpu'):
#         tokens = tokens.cpu()
    
#     if torch.is_tensor(tokens):
#         token_list = tokens.tolist()
#         if isinstance(token_list[0], list):
#             token_list = token_list[0]
#     else:
#         token_list = tokens
    
#     response = model.tokenizer.decode(token_list, skip_special_tokens=True)
    
#     import re
#     prompt_end = response.find("Answer: (")
#     if prompt_end != -1:
#         answer_part = response[prompt_end + len("Answer: ("):]
#         match = re.search(r'^(\d+)', answer_part)
#         if match:
#             return int(match.group(1))
    
#     return -1

# # Test within-context modification
# print("=== Testing Within-Context Modification ===")
# test_example = filtered_data[0]
# test_prompt = create_prompt(test_example['example'])

# clean_ans, modified_ans = test_correct_modification(model, test_prompt, 5, 20)

# print(f"Expected: {test_example['target_count']}")
# print(f"Clean: {clean_ans}")
# print(f"Modified: {modified_ans}")

# if modified_ans != clean_ans:
#     print("✓ Within-context modification works!")
    
#     # Test cross-example patching
#     print("\n=== Testing Cross-Example Patching ===")
#     source_example = filtered_data[0]
#     target_example = filtered_data[1]
    
#     try:
#         clean_target, patched_target = test_cross_example_patch(model, source_example, target_example, 5, 20)
        
#         if patched_target != clean_target:
#             print("✓ Cross-example patching works!")
#             print("🎉 Ready for full layer analysis!")
#         else:
#             print("? Patching didn't change answer - may need different position")
            
#     except Exception as e:
#         print(f"Cross-example patching failed: {e}")
        
# else:
#     print("? No change - try different position")

=== Testing Within-Context Modification ===
Testing modification at layer 5, position 20
Clean answer: 5
Modified answer: 5
Expected: 5
Clean: 5
Modified: 5
? No change - try different position


In [55]:
def single_layer_patch(model, source_example, target_example, source_pos, target_pos, layer_idx):
    """Copy one layer's activation at one position from a source run into a target run.

    Three generations are performed (5 new tokens each):
      1. the target prompt, untouched (clean baseline);
      2. the source prompt, saving layer `layer_idx`'s hidden state at `source_pos`;
      3. the target prompt again, with layer `layer_idx`'s hidden state at
         `target_pos` overwritten by the saved source vector.

    Args:
        model: NNsight-wrapped language model exposing `model.model.layers`,
            `model.generator`, `model.tokenizer` and `model.generate(...)`.
        source_example: Dict with an 'example' entry; activations are read from its prompt.
        target_example: Dict with an 'example' entry; its prompt is the one patched.
        source_pos: Index along the second axis of the source hidden-state tensor.
        target_pos: Index along the second axis of the target hidden-state tensor.
            NOTE(review): these look like token positions in the prompt, not
            word indices in example['list'] - confirm at the call site.
        layer_idx: Index into `model.model.layers`.

    Returns:
        dict: 'clean_output' and 'patched_output' (saved generator outputs)
        plus the 'source_prompt' and 'target_prompt' strings.
    """
    
    source_prompt = create_prompt(source_example['example'])
    target_prompt = create_prompt(target_example['example'])
    
    # Get clean target output
    with model.generate(target_prompt, max_new_tokens=5, scan=False, validate=False, pad_token_id=model.tokenizer.eos_token_id) as tracer:
        clean_output = model.generator.output.save()
    
    # Get source activation
    with model.generate(source_prompt, max_new_tokens=5, scan=False, validate=False, pad_token_id=model.tokenizer.eos_token_id) as tracer:
        source_layer = model.model.layers[layer_idx].output
        # assumes element 0 of the layer output is the hidden-state tensor - TODO confirm
        source_hidden = source_layer[0]  # [1, seq_len, hidden_dim]
        source_activation = source_hidden[0, source_pos, :].save()  # [hidden_dim]
    
    # Apply patch to target
    with model.generate(target_prompt, max_new_tokens=5, scan=False, validate=False, pad_token_id=model.tokenizer.eos_token_id) as tracer:
        target_layer = model.model.layers[layer_idx].output
        target_hidden = target_layer[0]  # [1, seq_len, hidden_dim]
        
        # Apply patch
        # The saved source activation may be a proxy with .value or a plain tensor
        if hasattr(source_activation, 'value'):
            activation_value = source_activation.value
        else:
            activation_value = source_activation
            
        # In-place overwrite of a single position; every other position is untouched
        target_hidden[0, target_pos, :] = activation_value
        
        patched_output = model.generator.output.save()
    
    return {
        'clean_output': clean_output,
        'patched_output': patched_output,
        'source_prompt': source_prompt,
        'target_prompt': target_prompt
    }

In [56]:
def extract_answer_from_output(model, output):
    """Decode a generation and return the integer that follows "Answer: (".

    Args:
        model: Object exposing `tokenizer.decode(token_ids, skip_special_tokens=True)`.
        output: Generated token ids: a saved proxy with a `.value` attribute,
            a tensor (with or without a leading batch dimension), or a plain
            list of ids.

    Returns:
        int: The run of digits found immediately after the first "Answer: ("
        in the decoded text, or -1 when the marker or the digits are missing
        (including when the output is empty).
    """
    import re  # local import: `re` is not among the notebook's top-level imports

    tokens = output.value if hasattr(output, 'value') else output

    if hasattr(tokens, 'cpu'):
        tokens = tokens.cpu()

    if torch.is_tensor(tokens):
        token_list = tokens.tolist()
        # Drop a leading batch dimension when there is one. The emptiness and
        # type checks keep an empty or 0-dim tensor from raising on
        # `token_list[0]`.
        if isinstance(token_list, list) and token_list and isinstance(token_list[0], list):
            token_list = token_list[0]
    else:
        token_list = tokens

    response = model.tokenizer.decode(token_list, skip_special_tokens=True)

    marker = "Answer: ("
    marker_at = response.find(marker)
    if marker_at == -1:
        return -1

    # Only digits directly after the marker count as the answer
    match = re.match(r'\d+', response[marker_at + len(marker):])
    return int(match.group(0)) if match else -1

In [57]:
def test_layer_effectiveness(model, control_pairs, test_layers, test_positions, max_pairs=5):
    """Test which layers and positions are most effective for patching"""
    
    print("Testing layer effectiveness...")
    
    # Use positive control pairs (same count, same position)
    test_pairs = control_pairs['same_count_same_pos'][:max_pairs]
    
    results = []
    
    for layer_idx in test_layers:
        print(f"\n--- Testing Layer {layer_idx} ---")
        
        for position in test_positions:
            layer_results = []
            
            for pair_idx, pair in enumerate(test_pairs):
                try:
                    result = single_layer_patch(
                        model,
                        pair['source_example'],
                        pair['target_example'], 
                        position,  # Use same position for both
                        position,
                        layer_idx
                    )
                    
                    # Extract answers
                    clean_answer = extract_answer_from_output(model, result['clean_output'])
                    patched_answer = extract_answer_from_output(model, result['patched_output'])
                    
                    target_correct_answer = pair['target_example']['target_count']
                    accuracy_preserved = (patched_answer == target_correct_answer)
                    
                    layer_results.append(accuracy_preserved)
                    
                except Exception as e:
                    print(f"      Error: {e}")
                    layer_results.append(False)
            
            if layer_results:
                success_rate = sum(layer_results) / len(layer_results)
                print(f"  Position {position}: {success_rate:.1%} accuracy preservation ({sum(layer_results)}/{len(layer_results)})")
                
                results.append({
                    'layer': layer_idx,
                    'position': position,
                    'success_rate': success_rate,
                    'successes': sum(layer_results),
                    'total': len(layer_results)
                })
    
    return results

In [None]:
# Probe a coarse grid of (layer, position) intervention points
print("Finding effective intervention points...")

# Layers sampled from early, middle and late depth
test_layers = [2, 5, 10, 15, 20, 25]

# Token positions to patch: every fifth position from 5 up to 30
test_positions = list(range(5, 31, 5))

# Only three pairs per grid point so the sweep stays quick
effectiveness_results = test_layer_effectiveness(
    model,
    control_pairs,
    test_layers=test_layers,
    test_positions=test_positions,
    max_pairs=3,
)

Finding effective intervention points...
Testing layer effectiveness...

--- Testing Layer 2 ---
  Position 5: 100.0% accuracy preservation (3/3)
  Position 10: 100.0% accuracy preservation (3/3)
  Position 15: 100.0% accuracy preservation (3/3)
  Position 20: 100.0% accuracy preservation (3/3)
  Position 25: 100.0% accuracy preservation (3/3)
  Position 30: 100.0% accuracy preservation (3/3)

--- Testing Layer 5 ---
  Position 5: 100.0% accuracy preservation (3/3)
  Position 10: 100.0% accuracy preservation (3/3)
  Position 15: 100.0% accuracy preservation (3/3)
  Position 20: 100.0% accuracy preservation (3/3)
  Position 25: 100.0% accuracy preservation (3/3)
  Position 30: 100.0% accuracy preservation (3/3)

--- Testing Layer 10 ---
  Position 5: 100.0% accuracy preservation (3/3)
  Position 10: 100.0% accuracy preservation (3/3)
  Position 15: 100.0% accuracy preservation (3/3)
  Position 20: 100.0% accuracy preservation (3/3)
  Position 25: 100.0% accuracy preservation (3/3)
  Pos

In [None]:
# Summarise the layer/position sweep stored in `effectiveness_results`
print("\n=== EFFECTIVENESS ANALYSIS ===")

if not effectiveness_results:
    # The sweep produced no entries at all
    print("\n".join([
        "No successful interventions found. May need to:",
        "- Try different layers",
        "- Use different positions",
        "- Check if the control pairs are appropriate",
    ]))
else:
    # Rank every (layer, position) point by its success rate, best first
    sorted_results = sorted(effectiveness_results, key=lambda x: x['success_rate'], reverse=True)

    print("Top 5 most effective intervention points:")
    for rank, entry in enumerate(sorted_results[:5], start=1):
        print(f"{rank}. Layer {entry['layer']}, Position {entry['position']}: {entry['success_rate']:.1%}")

    # Group the success rates by layer
    layer_scores = {}
    for entry in effectiveness_results:
        layer_scores.setdefault(entry['layer'], []).append(entry['success_rate'])

    # Mean success rate per layer, averaged over the tested positions
    layer_avg = {layer: sum(scores) / len(scores) for layer, scores in layer_scores.items()}

    print("\nBest layers (average across positions):")
    for layer in sorted(layer_avg):
        print(f"  Layer {layer}: {layer_avg[layer]:.1%}")

    # The three layers with the highest mean go forward to the full analysis
    best_layers = sorted(layer_avg, key=layer_avg.get, reverse=True)[:3]
    print(f"\n🎯 Recommended layers for full analysis: {best_layers}")

print("\n✅ Cell 6 prototype complete! Patching pipeline is working.")

#### Cell 7

In [None]:
# ===================================================================
# CELL 7: Test All 4 Controls (Full 2x2 Analysis)
# ===================================================================

def test_all_controls(model, control_pairs, test_layers, test_position, max_pairs=5):
    """Test all 4 control conditions to see the real differences"""
    
    print(f"Testing all 4 controls at position {test_position}...")
    
    results = {}
    
    control_names = {
        'same_count_same_pos': 'Positive Control 1 (same count, same position)',
        'same_count_diff_pos': 'Positive Control 2 (same count, different position)', 
        'diff_count_same_pos': 'Negative Control 1 (different count, same position)',
        'diff_count_diff_pos': 'Negative Control 2 (different count, different position)'
    }
    
    for control_type, pairs in control_pairs.items():
        if len(pairs) == 0:
            continue
            
        print(f"\n--- {control_names[control_type]} ---")
        test_pairs = pairs[:max_pairs]
        
        control_results = {}
        
        for layer_idx in test_layers:
            layer_successes = []
            
            for pair in test_pairs:
                try:
                    # Use the actual positions from the pair for different position tests
                    if 'diff_pos' in control_type:
                        src_pos = pair['source_position'] 
                        tgt_pos = pair['target_position']
                    else:
                        src_pos = test_position
                        tgt_pos = test_position
                    
                    result = single_layer_patch(
                        model,
                        pair['source_example'],
                        pair['target_example'],
                        src_pos,
                        tgt_pos, 
                        layer_idx
                    )
                    
                    clean_answer = extract_answer_from_output(model, result['clean_output'])
                    patched_answer = extract_answer_from_output(model, result['patched_output'])
                    
                    target_correct = pair['target_example']['target_count']
                    accuracy_preserved = (patched_answer == target_correct)
                    layer_successes.append(accuracy_preserved)
                    
                except Exception as e:
                    print(f"    Layer {layer_idx} error: {e}")
                    layer_successes.append(False)
            
            if layer_successes:
                success_rate = sum(layer_successes) / len(layer_successes)
                control_results[layer_idx] = success_rate
                print(f"  Layer {layer_idx}: {success_rate:.1%}")
        
        results[control_type] = control_results
    
    return results

# Run every control condition on a handful of layers
key_layers = [5, 10, 15, 20]  # Middle layers most likely to be important
test_position = 15  # Intended as a position present in most sequences (not verified here)

all_results = test_all_controls(
    model,
    control_pairs,
    test_layers=key_layers,
    test_position=test_position,
    max_pairs=3,
)

# Print the per-condition rates side by side, one group per layer
print("\n=== CONTROL COMPARISON ===")
for layer in key_layers:
    print(f"\nLayer {layer}:")
    for control_type, results in all_results.items():
        # A layer is missing when no pair was evaluated for it
        if layer not in results:
            continue
        rate = results[layer]
        print(f"  {control_type}: {rate:.1%}")

#### Visualize layer effects

In [None]:
# Line plot of the per-layer preservation rate, annotating the five highest values
fig, ax = plt.subplots(figsize=(12, 6))
layers = list(layer_effects.keys())
effects = list(layer_effects.values())

ax.plot(layers, effects, 'b-o', linewidth=2, markersize=6)
ax.set_xlabel('Layer Index')
ax.set_ylabel('Accuracy Preservation Rate')
ax.set_title('Layer Criticality for Count Representation\n(Higher = More Critical)')
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# Five layers with the highest rate, best first
top_layers = sorted(layer_effects.items(), key=lambda item: item[1], reverse=True)[:5]
label_box = dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.5)
for layer_idx, effect in top_layers:
    ax.annotate(
        f'L{layer_idx}: {effect:.3f}',
        xy=(layer_idx, effect),
        xytext=(5, 5),
        textcoords='offset points',
        fontsize=8,
        bbox=label_box,
    )

fig.tight_layout()
plt.show()

In [None]:
# List the annotated layers with their rank (1-based)
print("Top 5 most critical layers:")
for rank, (layer_idx, effect) in enumerate(top_layers, start=1):
    print(f"{rank}. Layer {layer_idx}: {effect:.3f} accuracy preservation")

In [None]:
# Keep the three best-ranked layers for the detailed 2x2 analysis
critical_layers = [entry[0] for entry in top_layers[:3]]
print(f"\nSelected critical layers for detailed analysis: {critical_layers}")

#### 2x2 control matrix analysis

In [None]:
def run_2x2_analysis(model, control_pairs, critical_layers, max_pairs=20):
    """Run the full 2x2 control matrix analysis.

    For every condition in ``control_pairs`` and every layer in
    ``critical_layers``, up to ``max_pairs`` pairs are patched with
    ``single_layer_patch`` using each pair's own ``source_position`` and
    ``target_position``. A pair counts as preserved when the answer parsed from
    the patched output equals ``pair['target_example']['target_count']``.

    Pairs that raise are counted as not preserved, and a per-layer error summary
    is printed so that failures are not mistaken for genuine patching effects.

    Args:
        model: Model handle passed through to ``single_layer_patch``.
        control_pairs: Mapping ``condition -> list of pair dicts``.
        critical_layers: Iterable of layer indices to patch.
        max_pairs: Maximum number of pairs evaluated per condition.

    Returns:
        ``{condition: {layer_idx: rate}}``; ``rate`` is ``nan`` when the
        condition has no pairs to evaluate.
    """
    print("Running 2x2 control matrix analysis...")

    results = {}

    for condition, pairs in control_pairs.items():
        print(f"\nTesting condition: {condition}")

        # The slice does not depend on the layer, so take it once per condition.
        test_pairs = pairs[:max_pairs]
        if len(test_pairs) == 0:
            print("  No pairs available for this condition; rates will be nan")

        condition_results = {}

        for layer_idx in critical_layers:
            print(f"  Layer {layer_idx}...")

            accuracy_scores = []
            error_messages = []

            for pair in tqdm(test_pairs, desc=f"Layer {layer_idx}"):
                try:
                    result = single_layer_patch(
                        model,
                        pair['source_example'],
                        pair['target_example'],
                        pair['source_position'],
                        pair['target_position'],
                        layer_idx
                    )

                    # Only the patched answer is scored against the target count.
                    patched_answer = extract_answer_from_output(model, result['patched_output'])
                    target_correct = pair['target_example']['target_count']
                    accuracy_scores.append(patched_answer == target_correct)

                except Exception as e:
                    # Keep the run going, but remember the failure for the report.
                    error_messages.append(str(e))
                    accuracy_scores.append(False)

            if error_messages:
                print(f"    {len(error_messages)}/{len(accuracy_scores)} pairs raised an error "
                      f"(counted as not preserved); first error: {error_messages[0]}")

            # Avoid np.mean on an empty list; report nan explicitly instead.
            if accuracy_scores:
                condition_results[layer_idx] = np.mean(accuracy_scores)
            else:
                condition_results[layer_idx] = np.float64('nan')
            print(f"    Accuracy preservation: {condition_results[layer_idx]:.3f}")

        results[condition] = condition_results

    return results

In [None]:
# Evaluate all four control conditions on the selected layers (15 pairs per condition)
matrix_results = run_2x2_analysis(
    model,
    control_pairs,
    critical_layers,
    max_pairs=15,
)

#### Visualize 2x2 results

In [None]:
# One annotated 2x2 heatmap (count match x position match) per critical layer
n_layers = len(critical_layers)
fig, axes = plt.subplots(1, n_layers, figsize=(5 * n_layers, 4))
if n_layers == 1:
    # With a single subplot `axes` is not a sequence; wrap it so it can be iterated
    axes = [axes]

# Rows: same / different count. Columns: same / different position.
condition_grid = [
    ['same_count_same_pos', 'same_count_diff_pos'],
    ['diff_count_same_pos', 'diff_count_diff_pos'],
]

for ax, layer_idx in zip(axes, critical_layers):
    # Preservation rates of the four conditions at this layer
    data_matrix = np.array([
        [matrix_results[name][layer_idx] for name in row]
        for row in condition_grid
    ])

    sns.heatmap(
        data_matrix,
        annot=True,
        fmt='.3f',
        cmap='RdYlGn',
        vmin=0, vmax=1,
        xticklabels=['Same Position', 'Different Position'],
        yticklabels=['Same Count', 'Different Count'],
        ax=ax,
    )

    ax.set_title(f'Layer {layer_idx}\nAccuracy Preservation')
    ax.set_xlabel('Position Match')
    ax.set_ylabel('Count Match')

plt.tight_layout()
plt.show()

In [None]:
# Banner for the summary table
print(
    "\n" + "="*60,
    "SUMMARY: 2x2 Control Matrix Results",
    "="*60,
    sep="\n",
)

In [None]:
# Conditions become columns, layers become rows; values rounded to three decimals
summary_df = pd.DataFrame(data=matrix_results).round(decimals=3)
print(summary_df)

#### Interpret results

In [None]:
print(
    "\n" + "="*60,
    "INTERPRETATION OF RESULTS",
    "="*60,
    sep="\n",
)

# Mean preservation rate of each condition, averaged over the analysed layers
avg_results = {
    name: np.mean(list(per_layer.values()))
    for name, per_layer in matrix_results.items()
}

print("Average accuracy preservation across layers:")
for condition, avg_acc in avg_results.items():
    print(f"  {condition}: {avg_acc:.3f}")

In [None]:
# Key comparisons
same_count_same_pos = avg_results['same_count_same_pos']
same_count_diff_pos = avg_results['same_count_diff_pos'] 
diff_count_same_pos = avg_results['diff_count_same_pos']
diff_count_diff_pos = avg_results['diff_count_diff_pos']

In [None]:
print("\nKey Findings:")

print("\n1. Are count representations position-independent?")
# The verdict hinges on whether the cross-position positive control exceeds 0.5
if same_count_diff_pos > 0.5:
    finding_lines = (
        f"   YES: Same count + different position works well ({same_count_diff_pos:.3f})",
        "   This suggests count representations are somewhat position-independent",
    )
else:
    finding_lines = (
        f"   NO: Same count + different position fails ({same_count_diff_pos:.3f})",
        "   This suggests count representations are position-specific",
    )
print("\n".join(finding_lines))

In [None]:
print("\n2. What matters more: count match or position match?")

# Drop in preservation when only the count (resp. only the position) stops matching
count_effect = same_count_same_pos - diff_count_same_pos
position_effect = same_count_same_pos - same_count_diff_pos

print(f"   Count match effect: {count_effect:.3f}")
print(f"   Position match effect: {position_effect:.3f}")

verdict = (
    "   Count matching is more important than position matching"
    if count_effect > position_effect
    else "   Position matching is more important than count matching"
)
print(verdict)

In [None]:
print("\n3. Which layers are most critical?")

# Rank the critical layers by their rate in the same-count / same-position condition
positive_control_rates = matrix_results['same_count_same_pos']
best_layer = max(critical_layers, key=positive_control_rates.__getitem__)

print(f"   Layer {best_layer} shows strongest effects")
print(f"   This suggests layer {best_layer} is most important for count representation")

In [None]:
print(f"\n4. Overall assessment:")

# A high same-count preservation rate alone is not evidence: it is also what
# happens when patching leaves the answer unchanged regardless of the source.
# Only grade the evidence when patching from a different-count source
# preserves the answer less often than patching from a same-count source.
count_contrast = same_count_same_pos - diff_count_same_pos

if not (count_contrast > 0):  # also catches nan averages
    print(f"   INCONCLUSIVE: different-count patches preserve accuracy as often as same-count patches")
    print(f"   (same count: {same_count_same_pos:.3f}, different count: {diff_count_same_pos:.3f})")
elif same_count_same_pos > 0.7:
    print(f"   STRONG evidence for internal count representations")
elif same_count_same_pos > 0.5:
    print(f"   MODERATE evidence for internal count representations") 
else:
    print(f"   WEAK evidence for internal count representations")

print(f"\n" + "="*60)
print("Analysis complete! 🎉")
print("="*60)