# Explore Decoupled Intelligence

This notebook investigates the **Structural Separation** hypothesis:
- Models contain distinct circuits for prompt categorization and response generation
- Categorization circuits are invariant to specific numeric inputs

Notes:
- This Colab runs on A100 GPU compute  
- Store your HuggingFace API token in the Colab HF_TOKEN secret
- Request HF access to the gated model via https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
- The detailed proposal/hypothesis is here https://docs.google.com/document/d/1x7n2iy1_LZXZNLQpxCzF84lZ8BEG6ZT3KWXC59erhJA
- The code base is here https://github.com/PhilipQuirke/LlmPromptCategorization

## Step 0: Import Libraries

Because of library version mismatches in Colab, you will need to 1) run the code, 2) restart the session (when prompted), and 3) run the code again.

In [None]:
!pip install -q transformer-lens accelerate bitsandbytes

In [None]:
try:
    import google.colab
    from google.colab import userdata

    # --- Remove Colab-preinstalled ABI landmines ---
    #get_ipython().run_line_magic( "pip", "uninstall -y numpy pyarrow pandas scipy scikit-learn" )
except ImportError:
    # Not running in Colab
    pass

In [None]:
import os
import sys
import platform
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
import pandas as pd
import seaborn as sns
from transformers import GPTJForCausalLM, AutoTokenizer
import transformers
from transformer_lens import HookedTransformer, patching
from sklearn.decomposition import PCA
from huggingface_hub import login
import re
import pickle
from tqdm import tqdm


## Step 1: Config


In [None]:
# CategorizationGeneration (singleton) config class
class CG:
    # Model we are testing
    MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct" # Gated Model. Need HF_TOKEN secret. Request access via https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct.

    # Default layer we expect the model to use as the 'Categorization Layer'
    MODEL_LAYER = 16

    # Name of task set. Valid options are "maths", "linguistic" and "format"
    TASK_SET_NAME = "linguistic"

    # Number of tasks (recomputed in Step 3E). The maths task set has 6 tasks: Max, Min, Avg, Sum, Diff, Prod
    NUM_TASKS = 6

    # Number of examples per task (recomputed in Step 3E). For maths, we run each task 6 times, using a different pair of numbers each time
    NUMBER_EXAMPLES = 6

    # Maximum new tokens for model to generate. Answer will likely be in the first few words of these tokens.
    MAX_NEW_TOKENS = 30

    # With the Singular Vector-Based interpretability technique, we consider this many vector directions per neuron/MLP
    SVD_K_DIRECTIONS = 35

    def in_colab():
        return "google.colab" in sys.modules

    def svd_save_path(this):
        return f"{this.MODEL_NAME}_SVD_top{this.SVD_K_DIRECTIONS}.pkl"

## Step 2: Load open-source model

In [None]:
# Retrieve the token from Colab Secrets
hf_token = userdata.get('HF_TOKEN')

# Log into Hugging Face Hub
login(hf_token)

In [None]:
transformers.utils.logging.set_verbosity_error()

In [None]:
#model = HookedTransformer.from_pretrained(
model = HookedTransformer.from_pretrained_no_processing( # Preferred with reduced precision
    model_name=CG.MODEL_NAME,
    device="cuda" if torch.cuda.is_available() else "cpu",
    dtype="float16",              # A100 handles float16 well
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True)
    #tokenizer_pad_token=None,
    #n_devices=1 )                  # Change to >1 if sharded across GPUs

## Step 3: Task Data

Define a class to store test data.
Create test data for the 3 task sets "maths", "linguistic" and "format".
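
Later cells compute `NUMBER_EXAMPLES` as `len(test_data) // len(tasks)` and slice the similarity matrix in per-task blocks, which only works if every task has the same number of examples and each task's examples are contiguous. A minimal sketch of a check for this is below; `check_task_blocks` is a hypothetical helper name, not part of the code base, and the toy list only mimics the `"task"` field of the real entries.

```python
from collections import Counter
from itertools import groupby


def check_task_blocks(test_data, tasks):
    """Return the examples-per-task count if test_data forms equal, contiguous task blocks."""
    counts = Counter(d["task"] for d in test_data)

    # Every task needs at least one example, and every example needs a known task
    missing = [t for t in tasks if counts[t] == 0]
    unknown = [t for t in counts if t not in tasks]
    if missing or unknown:
        raise ValueError(f"missing tasks: {missing}, unknown tasks: {unknown}")

    # All tasks must have the same number of examples
    if len(set(counts.values())) != 1:
        raise ValueError(f"unequal examples per task: {dict(counts)}")

    # Each task must appear as a single contiguous run
    block_order = [task for task, _ in groupby(d["task"] for d in test_data)]
    if len(block_order) != len(set(block_order)):
        raise ValueError(f"examples are interleaved across tasks: {block_order}")

    return counts[tasks[0]]


# Toy data (hypothetical): two tasks with two examples each
toy = [{"task": "min"}, {"task": "min"}, {"task": "max"}, {"task": "max"}]
print(check_task_blocks(toy, ["min", "max"]))
```

Raising on failure rather than returning a flag keeps a malformed task set from silently producing a misleading similarity profile.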

In [None]:
# Container class for one task set's test data
class test_data_class:
    # List of tasks we are testing
    tasks = []

    # Set of synonyms for each task
    synonyms = {}

    # Test prompts, each with its task, ground-truth answers ("gt") and inputs ("x", "y")
    test_data = []

    # Maximum length of an answer in characters
    max_answer_chars = 6

### Step 3A: Define Maths Tasks
We use six core maths tasks. The few-shot prefixes share identical phrasing up to the final task-identifying word.
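
Since each maths entry stores its operands (`x`, `y`) separately from the prompt text and the ground truth, the three can drift apart when entries are edited by hand. The sketch below recomputes the expected answer and compares it against the entry. The helper names `expected_answer` and `find_inconsistent` are hypothetical, the toy entries are illustrative, and `diff` is assumed to mean the absolute difference, as in the few-shot examples.

```python
import re


def expected_answer(task, x, y):
    """Recompute the expected answer for a maths task from its two integer operands."""
    ops = {
        "min": min(x, y),
        "max": max(x, y),
        "avg": (x + y) / 2,
        "sum": x + y,
        "diff": abs(x - y),   # assumption: absolute difference
        "prod": x * y,
    }
    return ops[task]


def find_inconsistent(test_data):
    """Return prompts whose question text, x/y fields and gt do not agree."""
    bad = []
    for d in test_data:
        # Only inspect the final question, not the few-shot prefix
        question = d["prompt"].rsplit("Q:", 1)[-1]
        numbers = re.findall(r"\d+", question)
        want = expected_answer(d["task"], int(d["x"]), int(d["y"]))
        gt_ok = any(float(g) == float(want) for g in d["gt"])
        if numbers != [d["x"], d["y"]] or not gt_ok:
            bad.append(d["prompt"])
    return bad


prefix = "Q: What is 5 and 3 prod? A: 15\nQ:"
toy = [
    {"prompt": f"{prefix} Given 14 & 3, what is the multiplication?", "task": "prod", "gt": ["42"], "x": "14", "y": "3"},
    # Deliberately inconsistent toy entry: the question says 5 but x says 15
    {"prompt": f"{prefix} Given 5 and 9 what is the product?", "task": "prod", "gt": ["45"], "x": "15", "y": "9"},
]
print(find_inconsistent(toy))
```

This check only applies to the maths task set, where `x` and `y` are integers.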

In [None]:
maths_data_config = test_data_class()

maths_data_config.tasks = ['min', 'max', 'avg', 'sum', 'diff', 'prod']

maths_data_config.synonyms = {
    'min': ['min', 'minimum', 'least', 'smaller', 'lesser', 'smallest'],
    'max': ['max', 'maximum', 'largest', 'biggest', 'larger'],
    'avg': ['avg', 'average', 'mean', 'median'],
    'sum': ['sum', 'add', 'plus', 'total', 'addition', 'aggregate'],
    'diff': ['diff', 'difference', 'minus', 'subtract', 'subtraction'],
    'prod': ['prod', 'product', 'multiply', 'multiplication', 'times']
}

min_prefix = "Q: What is 5 and 7 min? A: 5\nQ: What is 4 and 3 min? A: 3\nQ:"
max_prefix = "Q: What is 5 and 7 max? A: 7\nQ: What is 4 and 3 max? A: 4\nQ:"
avg_prefix = "Q: What is 5 and 9 avg? A: 7\nQ: What is 1 and 9 avg? A: 5\nQ:"
sum_prefix = "Q: What is 5 and 5 sum? A: 10\nQ: What is 2 and 2 sum? A: 4\nQ:"
diff_prefix = "Q: What is 5 and 7 diff? A: 2\nQ: What is 2 and 8 diff? A: 6\nQ:"
prod_prefix = "Q: What is 5 and 3 prod? A: 15\nQ: What is 2 and 6 prod? A: 12\nQ:"
maths_data_config.test_data = [
    {"prompt": f"{min_prefix} Given 21 and 39 what is the minimum?", "task": "min", "gt": ["21"], "x":"21", "y":"39"},
    {"prompt": f"{min_prefix} Given 11 & 23 what is the smallest?", "task": "min", "gt": ["11"], "x":"11", "y":"23"},
    {"prompt": f"{min_prefix} Given 65 and 49 what is the minimum?", "task": "min", "gt": ["49"], "x":"65", "y":"49"},
    {"prompt": f"{min_prefix} Given 32 and 11 which is lesser?", "task": "min", "gt": ["11"], "x":"32", "y":"11"},
    {"prompt": f"{min_prefix} Given 19 or 12 which is smallest?", "task": "min", "gt": ["12"], "x":"19", "y":"12"},
    {"prompt": f"{min_prefix} Given 17 and 23 which is minimum?", "task": "min", "gt": ["17"], "x":"17", "y":"23"},

    {"prompt": f"{max_prefix} Given 13 and 3 what is the maximum?", "task": "max", "gt": ["13"], "x":"13", "y":"3"},
    {"prompt": f"{max_prefix} Given 15 & 13 which is the largest?", "task": "max", "gt": ["15"], "x":"15", "y":"13"},
    {"prompt": f"{max_prefix} Given 22 and 36 what is the biggest?", "task": "max", "gt": ["36"], "x":"22", "y":"36"},
    {"prompt": f"{max_prefix} Given 48 or 32 what is the largest?", "task": "max", "gt": ["48"], "x":"48", "y":"32"},
    {"prompt": f"{max_prefix} Given 19 and 12 what is the maximum?", "task": "max", "gt": ["19"], "x":"19", "y":"12"},
    {"prompt": f"{max_prefix} Given 18 and 12 what is the largest?", "task": "max", "gt": ["18"], "x":"18", "y":"12"},

    {"prompt": f"{avg_prefix} Given 25 and 9 what is the average?", "task": "avg", "gt": ["17"], "x":"25", "y":"9"},
    {"prompt": f"{avg_prefix} Given 14 & 4 what is the avg?", "task": "avg", "gt": ["9"], "x":"14", "y":"4"},
    {"prompt": f"{avg_prefix} Given 11 and 47 what is the mean?", "task": "avg", "gt": ["29"], "x":"11", "y":"47"},
    {"prompt": f"{avg_prefix} Given 54 and 12 what is the average?", "task": "avg", "gt": ["33"], "x":"54", "y":"12"},
    {"prompt": f"{avg_prefix} Given 9 & 13 what is the mean?", "task": "avg", "gt": ["11"], "x":"9", "y":"13"},
    {"prompt": f"{avg_prefix} Given 8 and 22 what is the average?", "task": "avg", "gt": ["15"], "x":"8", "y":"22"},

    {"prompt": f"{sum_prefix} Given 25 and 9 what is the sum?", "task": "sum", "gt": ["34"], "x":"25", "y":"9"},
    {"prompt": f"{sum_prefix} Given 14 & 3 what is the total?", "task": "sum", "gt": ["17"], "x":"14", "y":"3"},
    {"prompt": f"{sum_prefix} Given 12 and 47 what is the total?", "task": "sum", "gt": ["59"], "x":"12", "y":"47"},
    {"prompt": f"{sum_prefix} Given 55 and 12 what is the aggregate?", "task": "sum", "gt": ["67"], "x":"55", "y":"12"},
    {"prompt": f"{sum_prefix} Given 9 and 13 what is the aggregate?", "task": "sum", "gt": ["22"], "x":"9", "y":"13"},
    {"prompt": f"{sum_prefix} Given 8 and 22 what is the sum?", "task": "sum", "gt": ["30"], "x":"8", "y":"22"},

    {"prompt": f"{diff_prefix} Given 15 and 9 what is the difference?", "task": "diff", "gt": ["6"], "x":"15", "y":"9"},
    {"prompt": f"{diff_prefix} Given 14 & 3 what is the diff?", "task": "diff", "gt": ["11"], "x":"14", "y":"3"},
    {"prompt": f"{diff_prefix} Given 12 and 40 return the delta?", "task": "diff", "gt": ["28"], "x":"12", "y":"40"},
    {"prompt": f"{diff_prefix} Given 55 and 12 what is the delta?", "task": "diff", "gt": ["43"], "x":"55", "y":"12"},
    {"prompt": f"{diff_prefix} Given 19 & 13 what is the difference?", "task": "diff", "gt": ["6"], "x":"19", "y":"13"},
    {"prompt": f"{diff_prefix} Given 8 and 22 what is the diff?", "task": "diff", "gt": ["14"], "x":"8", "y":"22"},

    {"prompt": f"{prod_prefix} Given 5 and 9 what is the product?", "task": "prod", "gt": ["45"], "x":"5", "y":"9"},
    {"prompt": f"{prod_prefix} Given 14 & 3, what is the multiplication?", "task": "prod", "gt": ["42"], "x":"14", "y":"3"},
    {"prompt": f"{prod_prefix} Given 12 and 40, what is the product?", "task": "prod", "gt": ["480"], "x":"12", "y":"40"},
    {"prompt": f"{prod_prefix} Given 55 by 3, what is the total multiply?", "task": "prod", "gt": ["165"], "x":"55", "y":"3"},
    {"prompt": f"{prod_prefix} Given 19 and 3, what is the multiply?", "task": "prod", "gt": ["57"], "x":"19", "y":"3"},
    {"prompt": f"{prod_prefix} Given 8 and 22 what is the total product?", "task": "prod", "gt": ["176"], "x":"8", "y":"22"},
]

maths_data_config.max_answer_chars = 5

### Step 3B: Define Linguistic Tasks

In [None]:
linguistic_data_config = test_data_class()

linguistic_data_config.tasks = ['synonym', 'antonym', 'definition', 'plural', 'part_of_speech']

linguistic_data_config.synonyms = {
    'synonym': ['synonym', 'alternative'],
    'antonym': ['antonym', 'opposite'],
    'definition': ['definition', 'meaning'],
    'plural': ['plural'],
    'part_of_speech': ['part_of_speech'],
}

# No trailing space after the final "Q:", because each prompt below adds its own leading space
lang_prefix = "Q: Word 'dog' synonym? A: canine\nQ: Word 'hot' antonym? A: cold\nQ:"
linguistic_data_config.test_data = [
    # --- Synonyms ---
    {"prompt": f"{lang_prefix} Word 'ocean' synonym?", "task": "synonym", "gt": ["sea", "marine", "deep", "main"], "x":"ocean", "y":""},
    {"prompt": f"{lang_prefix} Word 'quick' synonym?", "task": "synonym", "gt": ["fast", "speedy", "rapid", "swift", "hasty"], "x":"quick", "y":""},
    {"prompt": f"{lang_prefix} Word 'happy' synonym?", "task": "synonym", "gt": ["glad", "joyful", "cheerful", "content", "joyous"], "x":"happy", "y":""},
    {"prompt": f"{lang_prefix} Word 'small' synonym?", "task": "synonym", "gt": ["little", "tiny", "miniature", "slight", "petite"], "x":"small", "y":""},
    {"prompt": f"{lang_prefix} Word 'start' synonym?", "task": "synonym", "gt": ["begin", "commence", "initiate", "launch", "open"], "x":"start", "y":""},

    # --- Antonyms ---
    {"prompt": f"{lang_prefix} Word 'ocean' antonym?", "task": "antonym", "gt": ["land", "shore", "coast"], "x":"ocean", "y":""},
    {"prompt": f"{lang_prefix} Word 'light' antonym?", "task": "antonym", "gt": ["dark", "heavy", "darkness", "dim"], "x":"light", "y":""},
    {"prompt": f"{lang_prefix} Word 'loud' antonym?", "task": "antonym", "gt": ["quiet", "soft", "silent", "faint", "muted"], "x":"loud", "y":""},
    {"prompt": f"{lang_prefix} Word 'empty' antonym?", "task": "antonym", "gt": ["full", "occupied", "packed", "filled"], "x":"empty", "y":""},
    {"prompt": f"{lang_prefix} Word 'wrong' antonym?", "task": "antonym", "gt": ["right", "correct", "proper", "true"], "x":"wrong", "y":""},

    # --- Definitions (Hypernyms/Categories) ---
    {"prompt": f"{lang_prefix} Word 'emerald' definition?", "task": "definition", "gt": ["gemstone", "jewel", "gem", "stone", "mineral"], "x":"emerald", "y":""},
    {"prompt": f"{lang_prefix} Word 'violin' definition?", "task": "definition", "gt": ["instrument", "musical instrument", "fiddle", "strings"], "x":"violin", "y":""},
    {"prompt": f"{lang_prefix} Word 'eagle' definition?", "task": "definition", "gt": ["bird", "raptor", "bird of prey", "animal"], "x":"eagle", "y":""},
    {"prompt": f"{lang_prefix} Word 'pentagon' definition?", "task": "definition", "gt": ["shape", "polygon", "5 sided", "five-sided"], "x":"pentagon", "y":""},
    {"prompt": f"{lang_prefix} Word 'microscope' definition?", "task": "definition", "gt": ["tool", "instrument", "optical", "device"], "x":"microscope", "y":""},

    # --- Plurals ---
    {"prompt": f"{lang_prefix} Word 'ocean' plural?", "task": "plural", "gt": ["oceans", "seas"], "x":"ocean", "y":""},
    {"prompt": f"{lang_prefix} Word 'mouse' plural?", "task": "plural", "gt": ["mice", "mouses"], "x":"mouse", "y":""},
    {"prompt": f"{lang_prefix} Word 'leaf' plural?", "task": "plural", "gt": ["leaves"], "x":"leaf", "y":""},
    {"prompt": f"{lang_prefix} Word 'tooth' plural?", "task": "plural", "gt": ["teeth"], "x":"tooth", "y":""},
    {"prompt": f"{lang_prefix} Word 'person' plural?", "task": "plural", "gt": ["people", "persons"], "x":"person", "y":""},

    # --- Part of Speech ---
    {"prompt": f"{lang_prefix} Word 'ocean' part_of_speech?", "task": "part_of_speech", "gt": ["noun"], "x":"ocean", "y":""},
    {"prompt": f"{lang_prefix} Word 'think' part_of_speech?", "task": "part_of_speech", "gt": ["verb", "action"], "x":"think", "y":""},
    {"prompt": f"{lang_prefix} Word 'blue' part_of_speech?", "task": "part_of_speech", "gt": ["adjective", "descriptor"], "x":"blue", "y":""},
    {"prompt": f"{lang_prefix} Word 'quickly' part_of_speech?", "task": "part_of_speech", "gt": ["adverb"], "x":"quickly", "y":""},
    {"prompt": f"{lang_prefix} Word 'under' part_of_speech?", "task": "part_of_speech", "gt": ["preposition"], "x":"under", "y":""},
]

linguistic_data_config.max_answer_chars = 18

### Step 3D: Define Formatting Tasks

In [None]:
format_data_config = test_data_class()

format_data_config.tasks = ['JSON', 'XML', 'CSV', 'YAML', 'TABLE']

format_data_config.synonyms = {
    'JSON': ['JSON', 'javascript object', 'curly braces'],
    'XML': ['XML', 'markup', 'tags'],
    'CSV': ['CSV', 'comma separated', 'spreadsheet format'],
    'YAML': ['YAML', 'nested list', 'key-value pairs'],
    'TABLE': ['Markdown table', 'table format', 'grid']
}

# No trailing space after the final "Data:", because each prompt below adds its own leading space
format_prefix = "Data: Name: John, Age: 30\nFormat: JSON\nOutput: {\"name\": \"John\", \"age\": 30}\n\nData:"
format_data_config.test_data = [

    {"prompt": f"{format_prefix} Name: Alice, ID: 602\nFormat: JSON", "task": "JSON", "gt": ["{\""], "x": "Alice", "y": "602"},
    {"prompt": f"{format_prefix} Name: Bob, ID: 123\nFormat: JSON", "task": "JSON", "gt": ["{\""], "x": "Bob", "y": "123"},
    {"prompt": f"{format_prefix} Name: Charlie, ID: 456\nFormat: JSON", "task": "JSON", "gt": ["{\""], "x": "Charlie", "y": "456"},
    {"prompt": f"{format_prefix} Name: Diana, ID: 789\nFormat: JSON", "task": "JSON", "gt": ["{\""], "x": "Diana", "y": "789"},
    {"prompt": f"{format_prefix} Name: Edward, ID: 101\nFormat: JSON", "task": "JSON", "gt": ["{\""], "x": "Edward", "y": "101"},

    {"prompt": f"{format_prefix} Name: Fiona, ID: 202\nFormat: XML", "task": "XML", "gt": ["<"], "x": "Fiona", "y": "202"},
    {"prompt": f"{format_prefix} Name: George, ID: 303\nFormat: XML", "task": "XML", "gt": ["<"], "x": "George", "y": "303"},
    {"prompt": f"{format_prefix} Name: Hannah, ID: 404\nFormat: XML", "task": "XML", "gt": ["<"], "x": "Hannah", "y": "404"},
    {"prompt": f"{format_prefix} Name: Ian, ID: 505\nFormat: XML", "task": "XML", "gt": ["<"], "x": "Ian", "y": "505"},
    {"prompt": f"{format_prefix} Name: Julia, ID: 606\nFormat: XML", "task": "XML", "gt": ["<"], "x": "Julia", "y": "606"},

    {"prompt": f"{format_prefix} Name: Kevin, ID: 707\nFormat: CSV", "task": "CSV", "gt": ["Kevin,"], "x": "Kevin", "y": "707"},
    {"prompt": f"{format_prefix} Name: Laura, ID: 808\nFormat: CSV", "task": "CSV", "gt": ["Laura,"], "x": "Laura", "y": "808"},
    {"prompt": f"{format_prefix} Name: Mike, ID: 909\nFormat: CSV", "task": "CSV", "gt": ["Mike,"], "x": "Mike", "y": "909"},
    {"prompt": f"{format_prefix} Name: Nora, ID: 111\nFormat: CSV", "task": "CSV", "gt": ["Nora,"], "x": "Nora", "y": "111"},
    {"prompt": f"{format_prefix} Name: Oscar, ID: 222\nFormat: CSV", "task": "CSV", "gt": ["Oscar,"], "x": "Oscar", "y": "222"},

    {"prompt": f"{format_prefix} Name: Peter, ID: 333\nFormat: YAML", "task": "YAML", "gt": ["name:"], "x": "Peter", "y": "333"},
    {"prompt": f"{format_prefix} Name: Queenie, ID: 444\nFormat: YAML", "task": "YAML", "gt": ["name:"], "x": "Queenie", "y": "444"},
    {"prompt": f"{format_prefix} Name: Richard, ID: 555\nFormat: YAML", "task": "YAML", "gt": ["name:"], "x": "Richard", "y": "555"},
    {"prompt": f"{format_prefix} Name: Sarah, ID: 666\nFormat: YAML", "task": "YAML", "gt": ["name:"], "x": "Sarah", "y": "666"},
    {"prompt": f"{format_prefix} Name: Tom, ID: 777\nFormat: YAML", "task": "YAML", "gt": ["name:"], "x": "Tom", "y": "777"},

    {"prompt": f"{format_prefix} Name: Ursula, ID: 888\nFormat: Markdown table", "task": "TABLE", "gt": ["|"], "x": "Ursula", "y": "888"},
    {"prompt": f"{format_prefix} Name: Victor, ID: 999\nFormat: Markdown table", "task": "TABLE", "gt": ["|"], "x": "Victor", "y": "999"},
    {"prompt": f"{format_prefix} Name: Wendy, ID: 121\nFormat: Markdown table", "task": "TABLE", "gt": ["|"], "x": "Wendy", "y": "121"},
    {"prompt": f"{format_prefix} Name: Xander, ID: 232\nFormat: Markdown table", "task": "TABLE", "gt": ["|"], "x": "Xander", "y": "232"},
    {"prompt": f"{format_prefix} Name: Yvonne, ID: 343\nFormat: Markdown table", "task": "TABLE", "gt": ["|"], "x": "Yvonne", "y": "343"},
]

format_data_config.max_answer_chars = 6

### Step 3E: Select Task Set

In [None]:
if CG.TASK_SET_NAME == "maths":
    tasks = maths_data_config.tasks
    synonyms = maths_data_config.synonyms
    test_data = maths_data_config.test_data
    max_answer_chars = maths_data_config.max_answer_chars

elif CG.TASK_SET_NAME == "linguistic":
    tasks = linguistic_data_config.tasks
    synonyms = linguistic_data_config.synonyms
    test_data = linguistic_data_config.test_data
    max_answer_chars = linguistic_data_config.max_answer_chars

elif CG.TASK_SET_NAME == "format":
    tasks = format_data_config.tasks
    synonyms = format_data_config.synonyms
    test_data = format_data_config.test_data
    max_answer_chars = format_data_config.max_answer_chars

else:
    raise ValueError(f"Invalid task set name: {CG.TASK_SET_NAME}")


In [None]:
# Update config class
CG.NUM_TASKS = len(tasks)
CG.NUMBER_EXAMPLES = len(test_data) // len(tasks)

# Check every test item belongs to a known task
for d in test_data:
    assert d['task'] in tasks

In [None]:
# Generate the prompt list using the updated list-based GT
all_prompts = []
metadata = []

for item in test_data:
    all_prompts.append(item['prompt'])
    gt_display = "/".join(item['gt'])
    metadata.append({
        "task": item['task'],
        "pair": f"({item['x']},{item['y']})",
        "gt": gt_display
    })

print(all_prompts[0:2])
print(metadata[0:2])

## Step 4: Check Baseline Accuracy
If the model can't answer the above prompts correctly, then it may not have categorization or generation circuits for the task concepts, which would make the investigation uninformative.

Generation in this step does not set `do_sample=False`, so model answers to the same question can vary between runs. This test may therefore overreport False instances.

In [None]:
def _is_last_number_close(last_number: str, ground_truth: str) -> bool:
    try:
        return abs(float(last_number) - float(ground_truth)) < 0.001
    except (ValueError, TypeError):
        return False

def is_ground_truth_correct(answer: str, ground_truth_list: list) -> bool:
    """
    Returns True if ANY of the ground_truth strings appear in the answer.
    """
    # Remove trailing whitespace and punctuation from model answer
    answer_clean = answer.strip().rstrip('.!**')
    answer_no_comma = " " + answer_clean.replace(",", " ") + " "

    for gt in ground_truth_list:
        # Standardize the current GT for comparison
        gt_clean = str(gt).strip()

        # Check various common formatting patterns
        found = (
            gt_clean == answer_clean or
            f"**{gt_clean}**" in answer or
            f"boxed{{{gt_clean}}}" in answer or
            f" {gt_clean}\n" in answer_no_comma or
            f" {gt_clean} " in answer_no_comma or
            answer_clean.startswith(gt_clean)
        )

        # For numeric strings, apply the float closeness check
        if not found:
            numbers = re.findall(r'-?[\d,]+', answer_clean)
            numbers_clean = [num.replace(',', '') for num in numbers]
            if numbers_clean and _is_last_number_close(numbers_clean[-1], gt_clean):
                found = True

        if found:
            return True

    return False

In [None]:
assert( is_ground_truth_correct("A: sea\nQ: etc", ["sea", "marine", "deep", "main"]))
assert( is_ground_truth_correct("sea\nQ: etc", ["sea", "marine", "deep", "main"]))

In [None]:
def check_baseline_accuracy(data):
    print("Checking baseline accuracy:")

    results = []
    for i, d in enumerate(data):
        the_prompt = d["prompt"] + " A: "

        # Generate output
        output = model.generate(the_prompt, max_new_tokens=CG.MAX_NEW_TOKENS, stop_at_eos=True, verbose=False)
        the_output = output.replace(the_prompt, "").strip()[:max_answer_chars]

        is_correct = is_ground_truth_correct(the_output, d["gt"])
        results.append({"prompt": the_prompt[-35:], "output": the_output[:30], "correct": bool(is_correct)})

    accuracy_df = pd.DataFrame(results)
    correct_answers_count = accuracy_df['correct'].sum()
    total_answers_count = len(accuracy_df)

    print(f"Accuracy: {correct_answers_count} of {total_answers_count} = {correct_answers_count / total_answers_count * 100:.2f}%")

    pd.set_option('display.width', 100)
    print(accuracy_df[['prompt', 'output', 'correct']])


check_baseline_accuracy(test_data)

## Step 5: Layer-wise Separation Profile

Here we estimate the layer at which categorization occurs. In models such as GPT-NeoX or GPT-J, this is typically around the middle layers (e.g., layers 8–16 of 28).

- Intra-task Similarity (Blue Line): This represents the "Stability" of the categorization. According to the Structural Separation hypothesis, this should rise sharply and stay high once the model has recognized the "intent" (e.g., "summing"), regardless of the numbers provided.

- Inter-task Similarity (Red Line): This represents the "Ambiguity" between tasks. Ideally, this should remain low. If this line rises alongside the blue line, the model is seeing a general "mathematical intent" but failing to distinguish "sum" from "product."

- The Gating Point: We look for the layer where the gap between the Blue line and the Red line is largest. The code marks this layer as "Peak Separation" and stores it in `CG.MODEL_LAYER`.

The lines diverge significantly in early layers, suggesting the categorization circuits may be simple (perhaps detection of two numbers combined with keyword/synonym detection).

- TODO: Retry with no numbers in prompt. Does graph change?
- TODO: Retry with no task noun in prompt. Does graph become random?
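
The two similarity measures can be illustrated on synthetic data before running them on model activations. The sketch below uses NumPy rather than the notebook's PyTorch code, with random vectors standing in for residual-stream activations. The function name `separation_gap` is hypothetical, and it assumes at least two examples per task and rows grouped by task.

```python
import numpy as np


def separation_gap(vectors, num_tasks, num_examples):
    """Mean intra-task and inter-task cosine similarity of centroid-subtracted vectors.

    `vectors` has shape [num_tasks * num_examples, d_model], grouped by task.
    """
    centered = vectors - vectors.mean(axis=0)
    unit = centered / np.linalg.norm(centered, axis=1, keepdims=True)
    sim = unit @ unit.T

    task_ids = np.repeat(np.arange(num_tasks), num_examples)
    same_task = task_ids[:, None] == task_ids[None, :]
    not_self = ~np.eye(len(task_ids), dtype=bool)

    intra = sim[same_task & not_self].mean()   # same task, excluding the diagonal
    inter = sim[~same_task].mean()             # different tasks
    return float(intra), float(inter)


rng = np.random.default_rng(0)
num_tasks, num_examples, d_model = 3, 4, 16

# Synthetic stand-in: one random direction per task plus small per-example noise
task_dirs = rng.normal(size=(num_tasks, d_model))
noise = 0.1 * rng.normal(size=(num_tasks * num_examples, d_model))
vectors = np.repeat(task_dirs, num_examples, axis=0) + noise

intra, inter = separation_gap(vectors, num_tasks, num_examples)
print(f"intra={intra:.3f} inter={inter:.3f} gap={intra - inter:.3f}")
```

With equal block sizes, the pooled intra-task mean used here equals the mean of per-task block means.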

In [None]:
def visualize_layerwise_separation_profile():

    # Initialize storage for metrics
    layer_indices = range(model.cfg.n_layers)
    intra_task_sims = []
    inter_task_sims = []

    print("Analyzing layer-wise separation profile...")

    # 1. Loop through all layers in the model
    for layer_idx in layer_indices:
        layer_activations = []

        # Extract activations for all prompts at the current layer
        for prompt in all_prompts:
            with torch.no_grad():
                # Note: run_with_cache is re-run for every layer for simplicity; only resid_post hooks are cached.
                # Cast to float32 so the CPU similarity maths below does not run in float16.
                _, cache = model.run_with_cache(prompt, names_filter=lambda name: name.endswith("resid_post"))
                vec = cache["resid_post", layer_idx][0, -1, :].detach().float().cpu()
                layer_activations.append(vec)

        # Convert list to tensor: [NUM_TASKS * NUMBER_EXAMPLES, d_model]
        layer_tensor = torch.stack(layer_activations)

        # Calculate Centroid-Subtracted (Task-Specific) Vectors for this layer
        layer_centroid = layer_tensor.mean(dim=0)
        layer_specific = layer_tensor - layer_centroid

        # Normalize for cosine similarity calculation
        norm_layer = F.normalize(layer_specific, p=2, dim=1)

        # Calculate the full similarity matrix [N, N] for this layer, where N = NUM_TASKS * NUMBER_EXAMPLES
        sim_matrix = torch.mm(norm_layer, norm_layer.t())

        # 2. Calculate Intra-task similarity
        # How similar are different number pairs within the same task block (diagonal N x N blocks)?
        # Handle the case where there's only one example per task
        if CG.NUMBER_EXAMPLES == 1:
            # If only one example per task, intra-task similarity is trivially 1 (self-similarity)
            # or undefined. For plotting purposes, we can assume perfect consistency.
            avg_intra_sim = 1.0
        else:
            intra_sim_vals = []
            for t in range(CG.NUM_TASKS):
                start_idx = t * CG.NUMBER_EXAMPLES
                end_idx = start_idx + CG.NUMBER_EXAMPLES
                # Extract the N x N sub-matrix for this task
                block = sim_matrix[start_idx:end_idx, start_idx:end_idx]
                # Get values excluding the self-similarity diagonal (which is always 1.0)
                mask = ~torch.eye(CG.NUMBER_EXAMPLES, dtype=bool)
                intra_sim_vals.append(block[mask].mean())

            avg_intra_sim = torch.stack(intra_sim_vals).mean().item()
        intra_task_sims.append(avg_intra_sim)

        # 3. Calculate Inter-task similarity
        # How similar are different tasks to each other (off-diagonal regions)?
        all_pairs_mask = torch.ones_like(sim_matrix, dtype=bool)
        for t in range(CG.NUM_TASKS):
            start_idx = t * CG.NUMBER_EXAMPLES
            end_idx = start_idx + CG.NUMBER_EXAMPLES
            # Mask out the diagonal intra-task blocks
            all_pairs_mask[start_idx:end_idx, start_idx:end_idx] = False

        avg_inter_sim = sim_matrix[all_pairs_mask].mean().item()
        inter_task_sims.append(avg_inter_sim)

    # 4. Visualization
    plt.figure(figsize=(12, 6))
    plt.plot(layer_indices, intra_task_sims, label='Intra-task Similarity (Consistency)', marker='o', color='blue')
    plt.plot(layer_indices, inter_task_sims, label='Inter-task Similarity (Confusion)', marker='x', color='red')

    # Identify the "Categorization Layer" (Maximum Gap)
    gap = [intra - inter for intra, inter in zip(intra_task_sims, inter_task_sims)]
    best_layer = gap.index(max(gap))
    plt.axvline(x=best_layer, linestyle='--', color='green', alpha=0.5, label=f'Peak Separation (Layer {best_layer})')

    plt.title(f"Layer-wise Task Separation Profile ({CG.MODEL_NAME})")
    plt.xlabel("Layer Index")
    plt.ylabel("Average Cosine Similarity")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    #print(f"Analysis Complete. The highest separation between tasks was observed at Layer {best_layer}.")
    return best_layer


CG.MODEL_LAYER = visualize_layerwise_separation_profile()


## Step 6: Extract Residual Stream Activations
To isolate the "Categorization Layer", we extract the residual stream activations (`resid_post`) at the final token position of each prompt, at layer `CG.MODEL_LAYER`. The final position comes after the task word, so the categorization should be available there.

In [None]:
model_prompt_act = []
model_answers = []

for prompt in all_prompts:
    with torch.no_grad():
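        # Extract activations for the categorization layer analysis using run_with_cache.
        # Only resid_post hooks are cached, and the vector is cast to float32 for the CPU maths in later steps.
        logits_for_activations, cache = model.run_with_cache(prompt, names_filter=lambda name: name.endswith("resid_post"))
        vec = cache["resid_post", CG.MODEL_LAYER][0, -1, :].detach().float().cpu()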
        model_prompt_act.append(vec)

        # Generate a sequence of tokens for the model's answer
        input_ids = model.tokenizer.encode(prompt, return_tensors='pt').to(model.cfg.device)
        # Generate up to 10 new tokens for the answer. Using do_sample=False for deterministic output.
        generated_output_ids = model.generate(
            input_ids,
            max_new_tokens=10, # Allow up to 10 new tokens to cover ~35 characters
            do_sample=False,   # For deterministic answers for mathematical tasks
            temperature=0.0    # Set temperature to 0.0 for greedy decoding
            # Removed pad_token_id as it's not accepted by HookedTransformer.generate() for this model
        )

        # Decode only the newly generated part of the output
        generated_answer_tokens = generated_output_ids[0, len(input_ids[0]):]
        predicted_answer = model.tokenizer.decode(generated_answer_tokens, skip_special_tokens=True).strip()
        model_answers.append(predicted_answer)

model_prompt_tensor = torch.stack(model_prompt_act)

In [None]:
print(model_answers[0:5])

## Step 7: Disentangling Categorization from Data

We subtract the average prompt "template" (the mean activation over all prompts) to find the task-specific vectors.
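
A small worked example shows why the subtraction matters. The vectors below are hypothetical 4-dimensional "activations" made of a large shared template component plus a small task component. They are illustrative values, not model outputs.

```python
import numpy as np


def cosine(a, b):
    """Cosine similarity between two 1-D vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


# Shared template component dominates every activation
template = np.array([10.0, 0.0, 0.0, 0.0])

# Three tasks, each adding a small component along its own axis
acts = np.stack([
    template + np.array([0.0, 1.0, 0.0, 0.0]),
    template + np.array([0.0, 0.0, 1.0, 0.0]),
    template + np.array([0.0, 0.0, 0.0, 1.0]),
])

raw_sim = cosine(acts[0], acts[1])

# Subtract the centroid, as in the cell below
centered = acts - acts.mean(axis=0)
centered_sim = cosine(centered[0], centered[1])

print(f"raw={raw_sim:.3f} centered={centered_sim:.3f}")
```

Before subtraction the shared component makes different tasks look almost identical. After subtraction only the task-specific part remains, so different tasks become clearly dissimilar.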

In [None]:
# Calculate global mean (centroid) to remove template bias
global_centroid = model_prompt_tensor.mean(dim=0)
task_specific_vectors = model_prompt_tensor - global_centroid

## Step 8: Task Cosine Similarity Heatmap

We visualize this disentanglement using a similarity heatmap of all prompts (NUM_TASKS × NUMBER_EXAMPLES).

Intra-Task Consistency: Each NUMBER_EXAMPLES × NUMBER_EXAMPLES block on the heatmap diagonal represents one task (e.g., a 6x6 block of all "sum" prompts in the maths set). It shows how similar "sum (25,9)" is to "sum (8,22)". High similarity here is consistent with the categorization circuit ignoring input noise such as the specific numbers.

Inter-Task Orthogonality: The dark regions between blocks represent the separation between tasks. Different activations for different tasks => clear categorization between tasks.

In [None]:
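# Later code will show that interesting neurons at layers 14 and 16 are important in categorization.
# So we choose layer 16. Other layers give similar results.
# Note: task_specific_vectors were extracted in Step 6 at the layer chosen in Step 5,
# so re-run Steps 6 and 7 after this assignment for the heatmap to reflect layer 16.
CG.MODEL_LAYER = 16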


def visualize_similarity_heatmap():

    # Normalize for cosine similarity
    norm_vecs = F.normalize(task_specific_vectors, p=2, dim=1)
    sim_matrix = torch.mm(norm_vecs, norm_vecs.t()).numpy()

    plt.figure(figsize=(12, 10))
    labels = [f"{m['task']} {m['pair']}" for m in metadata]
    sns.heatmap(sim_matrix, xticklabels=labels, yticklabels=labels, cmap="viridis", annot=False)
    plt.title(f"Cosine Similarity Matrix: Stability of Task Categorization (Layer {CG.MODEL_LAYER}) {CG.MODEL_NAME}")
    plt.xlabel("Prompt (Task + Number Pair)")
    plt.ylabel("Prompt (Task + Number Pair)")
    plt.show()


visualize_similarity_heatmap()

## Step 9: PCA (2D) Projection Visualization

We project the NUM_TASKS × NUMBER_EXAMPLES vectors into 2D space to see the "Task Clusters".

- If the "Structural Separation" hypothesis is true, these task clusters should be geometrically distant in the PCA plot. This is useful but weak evidence.

- Scale Coordination: You can observe if the clusters are roughly the same distance from the center, which would support the idea that the model uses a unified activation scale for all tasks.
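
The "Scale Coordination" check can be made numeric by measuring how far each task's cluster centre lies from the origin of the PCA plane. The sketch below is one way to do that under stated assumptions: it uses a plain SVD in NumPy instead of scikit-learn's `PCA` (the two agree up to the sign of each component), synthetic vectors instead of `task_specific_vectors`, and the hypothetical helper names `pca_2d` and `cluster_radii`.

```python
import numpy as np


def pca_2d(vectors):
    """Project rows onto their top two principal components (via SVD)."""
    centered = vectors - vectors.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


def cluster_radii(points_2d, labels):
    """Distance from the PCA origin to each task's cluster centre."""
    labels = np.asarray(labels)
    return {
        str(task): float(np.linalg.norm(points_2d[labels == task].mean(axis=0)))
        for task in np.unique(labels)
    }


rng = np.random.default_rng(0)

# Synthetic stand-in: 3 tasks x 4 examples in 16 dimensions
task_dirs = rng.normal(size=(3, 16))
vectors = np.repeat(task_dirs, 4, axis=0) + 0.1 * rng.normal(size=(12, 16))
labels = np.repeat(["sum", "prod", "diff"], 4)

points = pca_2d(vectors)
radii = cluster_radii(points, labels)
print(radii)
```

Similar radii across tasks would be in line with a shared activation scale. With only two components this remains weak evidence, since variance outside the plane is ignored.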

In [None]:
def visualize_pca_projection():
    pca = PCA(n_components=2)
    pca_results = pca.fit_transform(task_specific_vectors.numpy())

    plt.figure(figsize=(10, 8))
    colors = sns.color_palette("hls", len(tasks))
    task_colors = {task: colors[i] for i, task in enumerate(tasks)}

    # Create a scatter plot for each task group separately to ensure correct legend and colors
    for task_name in tasks:
        # Get indices for current task
        task_indices = [i for i, m in enumerate(metadata) if m['task'] == task_name]

        # Plot points for this task
        plt.scatter(
            pca_results[task_indices, 0],
            pca_results[task_indices, 1],
            color=task_colors[task_name],
            label=task_name, # Each task gets one legend entry
            s=50 # default size for scatter points
        )

    plt.legend(title="Tasks", bbox_to_anchor=(1.05, 1), loc='upper left') # Move legend outside to prevent overlap
    plt.title(f"PCA: Task Categorization Clusters (Invariant to Numeric Inputs) {CG.MODEL_NAME}")
    plt.xlabel("Principal Component 1")
    plt.ylabel("Principal Component 2")
    plt.grid(True, alpha=0.3)
    plt.tight_layout() # Adjust layout to prevent labels/legend from being cut off
    plt.show()


visualize_pca_projection()

## Step 10. Synonym Logit Lens: Emergence of Task Intent Across Layers

Instead of comparing residual stream vectors (which contain numbers, syntax, and intent), we should project those vectors into Vocabulary Space. If the model has categorized a prompt as "Sum," its internal state should be "thinking" about addition-related tokens long before it actually generates the answer.
- The Logic: If the model understands the category, then the residual stream at the final prompt token should have a high projection onto the entire set of synonyms for that task (e.g., for SUM: "add", "plus", "total", "sum").
- The Experiment:
  - For each task, there is a set of synonyms (e.g., $S_{sum} = \{add, plus, sum, total\}$).
  - At each layer $L$, take the residual stream $x_L$ at the final token position.
  - Apply the model's final Unembed layer to $x_L$ to get logits.
  - Calculate a Category Score: The average logit value for all tokens in $S_{sum}$.
- Success Metric: Look for the layer where the Category Score for "Sum" spikes significantly higher than other categories, regardless of whether the prompt used the word "add" or "total."

Why this is a "Closer" approach to the Paper's logic:
- Mid-Computation Detection: Like the "Models Know" paper, this doesn't wait for the final output. It checks if the model has "internally decided" on the math operator at intermediate steps.
- Projection over Proximity: Raw vectors are messy because they represent everything at once. By using the W_U (unembedding) matrix, you are using the model's own "dictionary" to filter out numeric noise and find the specific bits of the vector that represent the Categorization Intent.
- Probing Generalization: If the score for the "Sum" category is high across all 6 prompts (even those using synonyms like "plus"), it supports the existence of a shared Categorization Circuit that maps different input synonyms to a single "task".
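
The Category Score described above can be illustrated on a tiny hand-built example. This is a sketch under stated assumptions: the unembedding matrix `W_U_toy`, the token ids in `toy_ids`, and the helper `category_scores` are all hypothetical and only mirror the shape of the real computation.

```python
import numpy as np

def category_scores(resid_vec, W_U, category_token_ids):
    """Average logit of each category's synonym tokens for one residual vector."""
    logits = resid_vec @ W_U  # [d_model] @ [d_model, d_vocab] -> [d_vocab]
    return {cat: float(logits[ids].mean()) for cat, ids in category_token_ids.items()}

# Toy unembedding: tokens 0-1 read dimension 0, tokens 2-3 read dimension 1
W_U_toy = np.zeros((4, 6))
W_U_toy[0, 0], W_U_toy[0, 1] = 1.0, 0.8
W_U_toy[1, 2], W_U_toy[1, 3] = 1.0, 0.9
W_U_toy[2, 4], W_U_toy[3, 5] = 1.0, 1.0

toy_ids = {"sum": [0, 1], "prod": [2, 3]}    # hypothetical synonym token ids
resid_toy = np.array([2.0, 0.5, 0.1, 0.0])   # residual vector leaning toward "sum"

scores = category_scores(resid_toy, W_U_toy, toy_ids)
print(scores)
```

Here the "sum" score is higher because the residual vector has its largest component along the dimension that the "sum" tokens read from.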

In [None]:
# Convert synonyms to token IDs
task_token_ids = {}
for task, syns in synonyms.items():
    ids = []
    for s in syns:
        # Get IDs for the word, usually with and without a leading space
        ids.extend(model.to_tokens(s, prepend_bos=False)[0].tolist())
        ids.extend(model.to_tokens(" " + s, prepend_bos=False)[0].tolist())
    task_token_ids[task] = list(set(ids))

def run_logit_lens_probe():
    results = []

    for i, prompt_data in enumerate(test_data):
        prompt = prompt_data['prompt']
        task = prompt_data['task']

        # Run model and cache residual stream
        # resid_post is the state of the stream after the layer's Attention and MLP
        logits, cache = model.run_with_cache(prompt, names_filter=lambda name: "resid_post" in name)

        for layer_idx in range(model.cfg.n_layers):
            # Extract final token activation [batch, pos, d_model] -> [d_model]
            resid_vec = cache["resid_post", layer_idx][0, -1, :]

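            # Apply Logit Lens: Project residual stream directly to vocabulary
            # We use the model's unembedding matrix (W_U)
            # Note: this skips the final normalization (model.ln_final), so scores
            # from different layers are not on a normalized scale
            layer_logits = resid_vec @ model.W_U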

            # Calculate scores for ALL tasks to see which one is dominant
            for t_name, t_ids in task_token_ids.items():
                # Average logit of all synonyms in this category
                avg_logit = layer_logits[t_ids].mean().item()

                results.append({
                    "prompt_idx": i,
                    "task": task, # Actual ground truth task
                    "probed_category": t_name, # The category we are measuring
                    "layer": layer_idx,
                    "score": avg_logit
                })

    return pd.DataFrame(results)

# Execute Probe
df_logit_lens = run_logit_lens_probe()

In [None]:
# Visualize: When does the "Correct" category emerge?
plt.figure(figsize=(12, 6))
for task_name in tasks:
    # Get prompts belonging to this task
    subset = df_logit_lens[(df_logit_lens['task'] == task_name) &
                          (df_logit_lens['probed_category'] == task_name)]

    # Average across the 6 examples per task
    avg_scores = subset.groupby('layer')['score'].mean()
    plt.plot(avg_scores.index, avg_scores.values, label=f"Intent: {task_name}", marker='o')

plt.title("Synonym Logit Lens: Emergence of Task Intent Across Layers")
plt.xlabel("Layer Index")
plt.ylabel("Avg Logit Score of Task Synonyms")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

**1. Reaching a "Winner-Takes-All" Consensus**
The hypothesis suggests that output selection is a "winner-takes-all" process where the activation of a Categorization Circuit acts as a gate.

- **The Late-Layer Spike:** The late layers represent the final refinement stage. An exponential rise in logits in these layers indicates that the model is rapidly shifting its internal state from a "superposition" of possibilities to a singular, high-confidence prediction.

- **Finalization on the Last Word:** The rise toward the end of the model suggests that these final layers are responsible for converting that high-level "intent" into specific token probabilities.


**2. The Logic of the Exponential Curve**
The steep late-layer rise reflects how the residual stream changes as it approaches the Unembedding Head:

- **Magnitude and Confidence:** As the model processes information through the final blocks, it increases the vector magnitude in the "task-specific direction" of the residual stream. The logit score used here is a linear projection of the residual stream onto W_U, so the curve tracks the growth of that component directly. The softmax applied to the final logits then turns a linear increase in a logit into an exponential increase in that token's relative probability.

- **Decoupling Confirmation:** As the logit scores for synonyms (e.g., "sum," "add," "total") all rise together, it suggests the model has mapped the prompt into a "pure task direction" or an abstract category space that is independent of the specific phrasing used in the input.
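
The relationship between logits and softmax probabilities can be checked numerically. The sketch below uses only the standard library and made-up logits; it shows that each additional unit of logit multiplies the odds against a competing token by a factor of e.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

base = [0.0, 0.0, 0.0, 0.0]
for boost in (0.0, 1.0, 2.0, 3.0):
    logits = [base[0] + boost] + base[1:]
    p = softmax(logits)
    odds = p[0] / p[1]  # odds of the boosted token against one competitor
    print(f"boost={boost:.0f}  p(target)={p[0]:.3f}  odds={odds:.2f}")
```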


**3. Identifying the "Categorization Layer"**
If each task categorization circuit is independent, there is no reason to expect them all to complete their calculations at the same layer.

While the logits peak at layer 31, the categorization result is clear at the earliest layer where the ground-truth task's score separates from the other tasks' scores for every task. For example, layer 25 shows a clear gap between the correct task's logits and the other tasks' logits.
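
Picking the layer by eye could be complemented with a simple rule. The following is a sketch, not the notebook's method: `first_separating_layer` and its `margin` parameter are hypothetical, and the scores are made up for illustration.

```python
import numpy as np

def first_separating_layer(scores, true_idx, margin=0.0):
    """
    scores: [n_layers, n_categories] category scores for the prompts of one task.
    Returns the first layer from which the true category beats every other
    category by more than `margin` at that layer and all later layers, else None.
    """
    scores = np.asarray(scores, dtype=float)
    others = np.delete(scores, true_idx, axis=1)
    gap = scores[:, true_idx] - others.max(axis=1)  # [n_layers]
    separated = gap > margin
    for layer in range(len(separated)):
        if separated[layer:].all():
            return layer
    return None

toy_scores = [
    [0.1, 0.2, 0.1],
    [0.3, 0.2, 0.4],
    [0.9, 0.3, 0.2],
    [2.0, 0.4, 0.3],
]
print(first_separating_layer(toy_scores, true_idx=0, margin=0.5))
```

Taking the maximum of this value over all tasks would give one candidate for a shared "Categorization Layer"; the right margin is a judgment call.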

In [None]:
# At what layer do categories first dominate?
best_probe_layer = 25 # Manually adjusted based on the graph above
final_scores = df_logit_lens[df_logit_lens['layer'] == best_probe_layer]

# Calculate a pivot table to see if 'Sum' prompts have the highest 'Sum' score
summary = final_scores.groupby(['task', 'probed_category'])['score'].mean().unstack()
print(f"\nCategory Confidence Matrix at Layer {best_probe_layer}:")
print(summary)

df_logit_lens = None

## Step 11. Gradient-Based Circuit Localization

To find the specific "Circuit" responsible for this categorization, we use a gradient-style attribution. The code below computes each Attention Head's direct contribution to the "Category Score" defined above (head output projected through W_O and W_U). Ignoring the final normalization, this equals Gradient × Activation along the direct path to the unembedding, a cheap stand-in for Integrated Gradients.

**The Experiment:** Define your target "Categorization Signal" (e.g., the Logit Lens score for the "Sum" synonym set). Attribute this signal to the output of every Attention Head at the final token position.

**Localize:** Identify "Intent Heads" that have high attribution to the correct category across all 6 variations of the prompt phrasing.

**Why this works:** This bypasses "similarity noise." Even if two "Sum" prompts look different in PCA, they might both be relying on the same Attention Head to "gate" the addition logic.

TODO: Does this really cover MLP neurons? (The code below attributes to Attention Heads only.)
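
For a purely linear readout, the per-head projection and Gradient × Activation give the same number. The sketch below checks this on random toy tensors; all shapes and the `syn_ids` list are hypothetical, and the final normalization of a real model is left out.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, d_head, d_model, d_vocab = 3, 4, 6, 10
z = rng.standard_normal((n_heads, d_head))             # toy per-head outputs at one position
W_O = rng.standard_normal((n_heads, d_head, d_model))  # toy output projections
W_U = rng.standard_normal((d_model, d_vocab))          # toy unembedding
syn_ids = [2, 5, 7]                                     # hypothetical synonym token ids

# Readout direction in d_model space: mean unembedding column of the synonyms
readout = W_U[:, syn_ids].mean(axis=1)

# Direct attribution per head: (z_h @ W_O_h) @ readout
direct = np.array([(z[h] @ W_O[h]) @ readout for h in range(n_heads)])

# Gradient x activation: d(score)/dz_h = W_O_h @ readout, multiplied by z_h and summed
grad = np.einsum("hdm,m->hd", W_O, readout)
grad_x_act = (grad * z).sum(axis=1)

# The per-head values add up to the score of the summed attention output
attn_out = sum(z[h] @ W_O[h] for h in range(n_heads))
score = float(attn_out @ readout)
print(direct, grad_x_act, score)
```

True Integrated Gradients would additionally integrate gradients along a path from a baseline input, which matters once nonlinearities sit between the head and the readout.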

In [None]:
def get_gradient_based_head_attributions(target_task_name):
    # 1. Prepare Target Synonym IDs
    target_syn_ids = task_token_ids[target_task_name]

    # 2. Extract Prompts for the Target Task
    target_prompts = [d['prompt'] for d in test_data if d['task'] == target_task_name]

    # Storage for head attributions: [Layer, Head]
    total_head_attribs = torch.zeros((model.cfg.n_layers, model.cfg.n_heads))

    for prompt in target_prompts:
        # Cache only the per-head attention outputs (hook_z) from a clean run
        _, cache = model.run_with_cache(prompt, names_filter=lambda name: name.endswith("hook_z"))

        for layer in range(model.cfg.n_layers):
            # hook_z is the output of the attention heads before being mixed by W_O
            # shape: [batch, pos, head, d_head]
            head_outputs = cache["z", layer]

            # Direct logit attribution: project each head's output at the final
            # position through W_O and W_U. Ignoring the final normalization, this
            # equals Grad * Activation for the synonym logits along the direct path
            # (a common approximation of Integrated Gradients).
            for head in range(model.cfg.n_heads):
                z = head_outputs[0, -1, head, :] # [d_head]
                head_contribution_to_resid = z @ model.W_O[layer, head] # [d_model]

                # Projection into synonym logit space
                head_synonym_logits = head_contribution_to_resid @ model.W_U[:, target_syn_ids]

                # Attribution Score = Average logit boost provided by this head
                total_head_attribs[layer, head] += head_synonym_logits.mean().item()

    # Average attributions across the variations of the prompt phrasing
    avg_head_attribs = total_head_attribs / len(target_prompts)
    return avg_head_attribs

In [None]:
def visualize_gradient_based_data(target_task_name, sum_attribs):
    plt.figure(figsize=(7, 4))
    sns.heatmap(sum_attribs.numpy(), cmap="RdBu_r", center=0)
    plt.title(f"Attention Head Attribution to '{target_task_name}' Category Signal")
    plt.xlabel("Head Index")
    plt.ylabel("Layer Index")
    plt.show()

def list_gradient_based_data(target_task_name, sum_attribs):
    # Identify top contributing heads
    top_values, top_indices = torch.topk(sum_attribs.flatten(), 5)
    print(f"Top 5 'Intent Heads' for {target_task_name} Categorization:")
    for val, idx in zip(top_values, top_indices):
        layer = idx.item() // model.cfg.n_heads
        head = idx.item() % model.cfg.n_heads
        print(f"Layer {layer}, Head {head}: Attribution Score {val:.4f}")

In [None]:
sum_attribs = {}

for task in tasks:
    sum_attribs[task] = get_gradient_based_head_attributions(task)
    visualize_gradient_based_data(task, sum_attribs[task])

### Step 11A. Gradient-Based Circuit Localization - Head overlap per task

Show the overlap between the attention heads important to the various task categorization circuits.

The top heads for the various tasks are not randomly distributed. Instead, they are heavily clustered in a few heads.
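
One way to quantify this overlap is the Jaccard index between the top-head sets of each pair of tasks. The sketch below uses made-up (layer, head) pairs, not results from this notebook; `jaccard` and `toy_top_heads` are hypothetical names.

```python
from itertools import combinations

def jaccard(a, b):
    """Size of the intersection divided by size of the union of two collections."""
    a, b = set(a), set(b)
    union = a | b
    return len(a & b) / len(union) if union else 0.0

# Hypothetical top heads per task as (layer, head) pairs
toy_top_heads = {
    "sum":  [(3, 1), (2, 0), (1, 4)],
    "prod": [(3, 1), (2, 0), (0, 2)],
    "max":  [(3, 1), (1, 7), (0, 5)],
}

overlap = {(t1, t2): jaccard(toy_top_heads[t1], toy_top_heads[t2])
           for t1, t2 in combinations(toy_top_heads, 2)}
for pair, value in overlap.items():
    print(pair, round(value, 2))
```

A value near 1 means two tasks rely on nearly the same heads, while a value near 0 means their top heads are disjoint.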


In [None]:
# 1. Extract the top 5 heads and their scores for each task
summary_data = []
for task_name, attrib_tensor in sum_attribs.items():
    # Find top 5 values and their flattened indices
    top_values, top_indices = torch.topk(attrib_tensor.flatten(), 5)

    for val, idx in zip(top_values, top_indices):
        layer = idx.item() // model.cfg.n_heads
        head = idx.item() % model.cfg.n_heads
        summary_data.append({
            "Head": f"L{layer}.H{head}",
            "Task": task_name,
            "Score": val.item()
        })

# 2. Create a DataFrame and pivot it into a grid
df_sum = pd.DataFrame(summary_data)
grid_df = df_sum.pivot(index="Head", columns="Task", values="Score")

# 3. Order columns and rows
ordered_tasks = ['min', 'max', 'avg', 'diff', 'sum', 'prod']
grid_df = grid_df[[t for t in ordered_tasks if t in grid_df.columns]]

# Sort rows by the highest value in that row (descending)
grid_df['row_max'] = grid_df.max(axis=1)
grid_df = grid_df.sort_values(by='row_max', ascending=False).drop(columns='row_max')

# 4. Define Styling Function
def apply_color_scale(val):
    if pd.isna(val):
        return ""
    if val > 0.5:
        return 'color: red; font-weight: bold;'
    elif val > 0.25:
        return 'color: orange; font-weight: bold;'
    elif val > 0.15:
        return 'color: #D4AF37; font-weight: bold;' # Dark Yellow/Gold
    else:
        return 'color: black;'

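# 5. Apply style and display
# Styler.applymap was renamed to Styler.map in pandas 2.1; use whichever exists
styler = grid_df.style
style_cells = styler.map if hasattr(styler, "map") else styler.applymap
styled_grid = style_cells(apply_color_scale).format(lambda x: f"{x:.4f}" if pd.notna(x) else "-")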

print("Polysemantic Intent Heads Attribution Grid (Sorted by Row Max):")
display(styled_grid)


# TODO: Flip graph on its side for easier reading

## Step 12. Categorization Entropy Calculation
This measures the model's "certainty" in task selection before it begins generating the first token of the answer.

It calculates the entropy of a softmax over the activation norms of the identified gating heads:
- Low entropy = Strong winner-takes-all (Predicts Correctness)
- High entropy = Interference/Confusion (Predicts Hallucination)
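
The two regimes above can be seen on made-up activation values. This is a standard-library sketch; `softmax_entropy` is a hypothetical helper that mirrors the softmax-then-entropy computation, not the notebook's function.

```python
import math

def softmax_entropy(values):
    """Shannon entropy (in nats) of the softmax distribution over `values`."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0)

uniform = softmax_entropy([1.0, 1.0, 1.0, 1.0])  # all heads equally active
peaked = softmax_entropy([10.0, 1.0, 1.0, 1.0])  # one head dominates
print(f"uniform: {uniform:.4f}  peaked: {peaked:.4f}")
```

With four heads the entropy is bounded above by ln 4 (about 1.386), so values are only comparable between runs that use the same number of gating heads.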

In [None]:
# Dynamically determine the top 4 gating heads based on their maximum attribution score across tasks
def determine_top_gating_heads():
    # First, parse the 'Head' string into layer and head integers
    df_sum['Layer'] = df_sum['Head'].apply(lambda x: int(x.split('L')[1].split('.H')[0]))
    df_sum['Head_idx'] = df_sum['Head'].apply(lambda x: int(x.split('.H')[1]))

    # Group by actual (Layer, Head_idx) and get the maximum score for each unique head
    head_max_scores = df_sum.groupby(['Layer', 'Head_idx'])['Score'].max().reset_index()

    # Sort by score in descending order and select the top 4
    top_gating_heads_df = head_max_scores.sort_values(by='Score', ascending=False).head(4)

    # Convert to the desired list of tuples format
    top_gating_heads = [(int(row['Layer']), int(row['Head_idx'])) for index, row in top_gating_heads_df.iterrows()]

    print(f"Top few gating heads: {top_gating_heads}")
    print()

    return top_gating_heads

In [None]:
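def calculate_categorization_entropy(prompt, layer_idx, gating_head_indices):
    # Note: layer_idx is currently unused; each gating head carries its own layer index
    # Cache only hook_z (a bare "z" substring filter would also match hook_normalized)
    _, cache = model.run_with_cache(prompt, names_filter=lambda name: name.endswith("hook_z"))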

    # Extract magnitudes of identified gating heads at final token
    activations = []
    for layer, head in gating_head_indices:
        act = cache["z", layer][0, -1, head, :].norm().item()
        activations.append(act)

    # Softmax to get probability distribution
    acts_tensor = torch.tensor(activations)
    probs = F.softmax(acts_tensor, dim=0)

    # Entropy calculation
    entropy = -torch.sum(probs * torch.log(probs + 1e-9)).item()

    return entropy

### Step 12A. Failure Mode Investigation (Nonsense Prompts)
Use this code to test if high entropy in the categorization layer predicts "Generation Collapse."

In [None]:
nonsense_prompts = [
    "Q: Given the numbers 21 and 14, calculate the plural. A: ",
    "Q: Given the inputs True and False, count the spaces. A: ",
    "Q: Given the word 'COMPUTER', calculate the average. A: "
]

top_gating_heads = determine_top_gating_heads()

print("Failure Mode Analysis:")
for prompt in nonsense_prompts:
    entropy = calculate_categorization_entropy(prompt, 16, top_gating_heads)

    # Generate response
    input_ids = model.tokenizer.encode(prompt, return_tensors='pt').to(model.cfg.device)
    output = model.generate(input_ids, max_new_tokens=20, verbose=False)
    decoded = model.tokenizer.decode(output[0][len(input_ids[0]):]).strip()

    print(f"Prompt: {prompt}")
    print(f"-> Gating Entropy: {entropy:.4f}")
    print(f"-> Model Output: {decoded}")
    print("-" * 30)

## Step 13. Causal Intervention: Activation Patching
Activation patching is the "Gold Standard" causal test. We take the "Sum" intent from a source prompt and patch it into a "Product" target prompt to see if we can force the model to add instead of multiply.
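
The idea of patching can be shown on a toy two-stage function before applying it to the real model. Everything in this sketch is hypothetical (`toy_forward`, the task codes, the hidden-state layout); it only illustrates caching an activation from one run and overwriting it in another.

```python
import numpy as np

def toy_forward(task_code, a, b, patch=None):
    """
    Toy two-stage 'model': stage 1 writes a hidden state [task_code, a, b];
    stage 2 reads the task slot and applies the matching operation.
    `patch` optionally overwrites the task slot (the intervention).
    """
    hidden = np.array([task_code, a, b], dtype=float)  # stage 1: "categorization" + operands
    cache = hidden.copy()                               # what a clean run would cache
    if patch is not None:
        hidden[0] = patch                               # activation patching on the task slot
    out = hidden[1] + hidden[2] if hidden[0] == 0 else hidden[1] * hidden[2]
    return float(out), cache

SUM, PROD = 0.0, 1.0
_, source_cache_toy = toy_forward(SUM, 5, 3)            # source run: cache the "sum" intent
clean, _ = toy_forward(PROD, 5, 3)                      # target run without a patch
patched, _ = toy_forward(PROD, 5, 3, patch=source_cache_toy[0])
print(clean, patched)
```

In the toy, the operands are untouched and only the task slot changes, which is the behavior the hypothesis predicts for the real Categorization Layer.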

In [None]:
# Define source and target
# Both prompts end at "A: " so that their final-token positions correspond
source_prompt = "Q: Given 5 and 3 what is sum? A: "
target_prompt = "Q: Given 5 and 3 what is product? A: " # We want to steer this to '8'

# Identify the 'Categorization Subspace'
# For Llama-3-8B, we'll focus on the residual stream at the final prompt token
patch_hook_name = f"blocks.{CG.MODEL_LAYER}.hook_resid_post"

def patch_categorization_manifold(target_activations, hook):
    # We replace the target's internal intent with the source's intent
    target_activations[0, -1, :] = source_cache[hook.name][0, -1, :]
    return target_activations

# 1. Run source and cache only the activation we patch from
_, source_cache = model.run_with_cache(source_prompt, names_filter=patch_hook_name)

# 2. Run target with the patch at the identified 'Categorization Layer' (CG.MODEL_LAYER)
patched_logits = model.run_with_hooks(
    target_prompt,
    fwd_hooks=[(patch_hook_name, patch_categorization_manifold)]
)

# 3. Check if the top logit is now '8' instead of '15'
predicted_token = model.tokenizer.decode(patched_logits[0, -1, :].argmax().item())
print(f"Steered Prediction (Target was 'prod'): {predicted_token}")

## Step 20: Analyze L31H14.

Is Layer 31 Head 14 just a "mathematical intent" head, or does it (polysemantically) embed separate data for each task?

Use the "Singular Vector-Based Interpretability" technique explained in https://arxiv.org/abs/2511.20273 with code in
https://github.com/Exploration-Lab/Beyond-Components to measure the alignment of the head's singular vectors with the various tasks.




In [None]:
def get_single_head_svd(model, layer, head, top_k):
    """
    Performs SVD on the OV (Value-Output) circuit for a specific head.
    """
    # 1. Isolate the specific Value and Output weights
    # model.W_V shape: [n_layers, n_heads, d_model, d_head] -> W_V_h: [d_model, d_head]
    # model.W_O shape: [n_layers, n_heads, d_head, d_model] -> W_O_h: [d_head, d_model]
    W_V_h = model.W_V[layer, head]
    W_O_h = model.W_O[layer, head]

    # 2. Compute the OV matrix for this head: W_V_h @ W_O_h
    # Shape: [d_model, d_model]
    W_OV_h = W_V_h @ W_O_h

    # 3. SVD decomposition (Convert to float32 for stability)
    U, S, Vh = torch.linalg.svd(W_OV_h.to(torch.float32), full_matrices=False)

    # 4. Return the top_k singular vectors and values
    return {
        "U": U[:, :top_k].detach().cpu(),
        "S": S[:top_k].detach().cpu(),
        "Vh": Vh[:top_k, :].detach().cpu()
    }

In [None]:
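def get_single_mlp_svd(model, layer, top_k):
    # Linear approximation of the MLP: ignores the gate and the nonlinearity
    W_in = model.W_in[layer]    # Up projection, [d_model, d_mlp]
    W_out = model.W_out[layer]  # Down projection, [d_mlp, d_model]
    W_MLP = W_in @ W_out        # [d_model, d_model]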

    U_m, S_m, Vh_m = torch.linalg.svd(W_MLP.to(torch.float32), full_matrices=False)

    return {
        "U": U_m[:, :top_k].detach().cpu(),
        "S": S_m[:top_k].detach().cpu(),
        "Vh": Vh_m[:top_k, :].detach().cpu()
    }

In [None]:
# Save SVD data to disk (Optional)
def save_svd_results(full_model_svd):
  save_path = CG.svd_save_path()
  with open(save_path, 'wb') as f:
      pickle.dump({
          "metadata": {
              "model": CG.MODEL_NAME,
              "top_k": CG.SVD_K_DIRECTIONS,
              "d_model": model.cfg.d_model
          },
          "layers": full_model_svd
      }, f)

  print(f"SVD results saved to {save_path}")

In [None]:
# Load the SVD results (Optional)
def load_svd_results():
  load_path = CG.svd_save_path()
  with open(load_path, "rb") as f:
      full_model_svd = pickle.load(f)
  return full_model_svd

## Step 21: Scree Plot

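The scree plot shows how quickly the head's singular values decay, i.e., whether the first few singular vectors capture most of the OV matrix's energy. Because W_OV is the product of a [d_model, d_head] matrix and a [d_head, d_model] matrix, at most d_head of its singular values are non-zero.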

If you see a few very large singular values followed by a sharp drop, it confirms the head is "low-rank." This means it isn't doing random calculations but is focused on a small number of specific, independent functions.
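
The "low-rank" reading can be made quantitative with the cumulative share of squared singular values. The sketch below uses random toy weights with hypothetical sizes, not the model's matrices, and shows that an OV product has at most d_head non-zero singular values.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head = 32, 4
W_V_toy = rng.standard_normal((d_model, d_head))
W_O_toy = rng.standard_normal((d_head, d_model))
W_OV_toy = W_V_toy @ W_O_toy                     # [d_model, d_model], rank <= d_head

S = np.linalg.svd(W_OV_toy, compute_uv=False)    # singular values, descending
energy = np.cumsum(S ** 2) / np.sum(S ** 2)      # cumulative share of squared singular values
n_nonzero = int((S > 1e-8 * S[0]).sum())

print("numerically non-zero singular values:", n_nonzero)
print("cumulative energy of the top 4:", float(energy[3]))
```

For a real head, the interesting question is how few singular values are needed to reach, say, 90% of the energy within the d_head that are available.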

In [None]:
# --- Example Usage for Layer 31, Head 14 ---
target_layer = 31
target_head = 14

print(f"Computing SVD for Layer {target_layer}, Head {target_head}")
single_head_svd = get_single_head_svd(model, target_layer, target_head, CG.SVD_K_DIRECTIONS)

# Visualization: Scree Plot (Singular Values)
# This shows how many "dimensions" of information this head actually uses.
plt.figure(figsize=(12, 4))
plt.plot(single_head_svd["S"][:50].numpy(), marker='o', linestyle='--')
plt.title(f"Scree Plot: Singular Values for Layer {target_layer} Head {target_head}")
plt.xlabel("Singular Vector Index")
plt.ylabel("Singular Value (Importance)")
plt.grid(True, alpha=0.3)
plt.show()

## Step 22: Heatmap of Singular Vectors to Task Directions
We will check whether the head's singular directions (the U vectors) align with the 'Task Intent' directions we found in the residual stream. With the row-vector convention used here (x @ W_OV), the columns of U span the directions the head reads from its input, while the rows of Vh span the directions it writes to its output.

- **Task Specialization:** If a column (e.g., SV 2) has a high positive score for "Sum" and near-zero for everything else, you have found the specific "Sum Vector" inside that head.
- **Polysemantic Overlap:** If a single vector (like SV 0) has high scores for multiple tasks (e.g., "Max" and "Min"), it suggests that this specific sub-component represents a higher-level concept like "Comparison Intent" rather than a specific arithmetic operation.
- **Independent Subspaces:** According to the "Beyond Components" theory, this heatmap should reveal that different tasks are "superposed" in the same head but are mathematically independent because they align with different orthogonal singular vectors.
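
The heatmap values are cosine similarities between task directions and singular vectors. A minimal sketch with hand-picked toy vectors shows what "specialized" and "shared" rows look like; `cosine_to_directions` and the `toy_*` arrays are hypothetical.

```python
import numpy as np

def cosine_to_directions(task_vectors, directions):
    """
    task_vectors: [n_tasks, d_model]; directions: [d_model, k], one direction per column.
    Returns [n_tasks, k] cosine similarities.
    """
    t = task_vectors / np.linalg.norm(task_vectors, axis=1, keepdims=True)
    d = directions / np.linalg.norm(directions, axis=0, keepdims=True)
    return t @ d

toy_dirs = np.eye(4)[:, :2]                  # two orthonormal directions e0 and e1
toy_tasks = np.array([[3.0, 0.0, 0.0, 0.0],  # aligned with direction 0 only
                      [1.0, 1.0, 0.0, 0.0]]) # split evenly between both directions
sims = cosine_to_directions(toy_tasks, toy_dirs)
print(sims.round(3))
```

The first row is the "Task Specialization" pattern (one large entry), and the second row is a task spread across two singular vectors.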

In [None]:
# 1. Calculate Task Centroids (Averaged Task Directions)
task_centroids = {}
for task_name in tasks:
    indices = [i for i, m in enumerate(metadata) if m['task'] == task_name]
    # Average the centroid-subtracted vectors for this task
    task_centroids[task_name] = task_specific_vectors[indices].mean(dim=0)

# Convert dictionary to a tensor [num_tasks, d_model]
task_centroid_tensor = torch.stack([task_centroids[t] for t in tasks])

# 2. Compute Cosine Similarity between Singular Vectors and Task Centroids
U_vectors = single_head_svd["U"][:, :CG.SVD_K_DIRECTIONS] # Using the single head results

norm_U = F.normalize(U_vectors.to(torch.float32), p=2, dim=0)
norm_tasks = F.normalize(task_centroid_tensor.to(torch.float32), p=2, dim=1)

# Resulting similarity matrix: [tasks, singular_vectors]
# The @ operator performs matrix multiplication
sim_matrix = (norm_tasks @ norm_U).numpy()

# 3. Visualization: Task-to-Vector Heatmap
plt.figure(figsize=(22, 6))
sns.heatmap(sim_matrix, annot=True, fmt=".2f", cmap="RdBu_r", center=0,
            xticklabels=[f"{i}" for i in range(CG.SVD_K_DIRECTIONS)],
            yticklabels=tasks)
plt.title(f"Cosine Similarity: L{target_layer} H{target_head} Singular Vectors vs. Task Intents")
plt.xlabel("Singular Vector (Decomposed Sub-components of Head)")
plt.ylabel("Mathematical Task Intent")
plt.show()