<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Grokking_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Grokking Demo Notebook

<b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

# Setup
(No need to read)

In [None]:
# Set to False to skip the training loop and load previously saved results from PTH_LOCATION instead.
TRAIN_MODEL = True

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEVELOPMENT_MODE = True
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

if IN_COLAB or IN_GITHUB:
    %pip install transformer_lens
    %pip install circuitsvis

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Colab (or any non-development run) uses the "colab" renderer; local development
# falls back to "notebook_connected".
use_colab_renderer = IN_COLAB or not DEVELOPMENT_MODE
pio.renderers.default = "colab" if use_colab_renderer else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

In [None]:
# Enlarge the axis-title and figure-title fonts on the built-in "plotly" template.
plotly_layout = pio.templates['plotly'].layout
for axis_layout in (plotly_layout.xaxis, plotly_layout.yaxis):
    axis_layout.title.font.size = 20
plotly_layout.title.font.size = 30

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a heatmap using a red/blue colour scale centred on zero.

    `xaxis` / `yaxis` are the axis labels, `renderer` is passed to the figure's
    `show`, and any extra keyword arguments are forwarded to `px.imshow`.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Draw `tensor` as a line plot.

    `xaxis` / `yaxis` are the axis labels, `renderer` is passed to the figure's
    `show`, and any extra keyword arguments are forwarded to `px.line`.
    """
    figure = px.line(
        utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x`.

    Both inputs are converted with `utils.to_numpy` first. `xaxis`, `yaxis` and
    `caxis` label the x axis, y axis and colour axis; `renderer` is passed to the
    figure's `show`; extra keyword arguments are forwarded to `px.scatter`.
    """
    x_values, y_values = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [None]:
# Relative path of the file that holds the saved model and training history
PTH_LOCATION = "workspace/_scratch/grokking_demo.pth"

# Make sure the parent directory exists (a no-op when it is already there)
Path(PTH_LOCATION).parent.mkdir(parents=True, exist_ok=True)

# Model Training

## Config

In [None]:
# Task size: inputs a and b each range over 0..p-1
p = 10
# Fraction of the p*p input pairs that go into the training split
frac_train = 0.72

# Optimizer config (AdamW learning rate, weight decay and betas)
lr = 0.001
wd = 1.0
betas = (0.9, 0.98)

# Training length and how often (in epochs) to snapshot the model
num_epochs = 10_000
checkpoint_every = 100

# Seed for the train/test shuffle
DATA_SEED = 598

## Define Task
* Define generalized Fibonacci: the label for an input pair (a, b) is the plain, non-modular sum a + b
* Define the dataset & labels

Input format:
|a|b|=|

In [None]:
# Enumerate all p*p (a, b) pairs as flat vectors: `a` varies slowest, `b` varies fastest,
# and every prompt ends with the token id p.
number_tokens = torch.arange(p)
a_vector = einops.repeat(number_tokens, "i -> (i j)", j=p)
b_vector = einops.repeat(number_tokens, "j -> (i j)", i=p)
equals_vector = einops.repeat(torch.tensor(p), " -> (i j)", i=p, j=p)

In [None]:
# One row per prompt: the three token columns side by side, moved to the compute device.
token_columns = [a_vector, b_vector, equals_vector]
dataset = torch.stack(token_columns, dim=1).to(device)

In [None]:
# Target for each prompt: the sum of its first two tokens (a + b).
labels = dataset[:, 0] + dataset[:, 1]

Convert this to a train + test set - a fraction `frac_train` of the pairs (72% with the config above) goes in the training set

In [None]:
# Shuffle all p*p example indices reproducibly, then cut them into a train and a test part.
torch.manual_seed(DATA_SEED)
indices = torch.randperm(p*p)
cutoff = int(p*p*frac_train)
train_indices, test_indices = indices[:cutoff], indices[cutoff:]

train_data, train_labels = dataset[train_indices], labels[train_indices]
test_data, test_labels = dataset[test_indices], labels[test_indices]

# Peek at the first few examples and the size of each split (train first, then test)
for split_data, split_labels in ((train_data, train_labels), (test_data, test_labels)):
    print(split_data[:5])
    print(split_labels[:5])
    print(split_data.shape)

## Define Model

In [None]:

# One-layer transformer for the addition task.
cfg = HookedTransformerConfig(
    # Architecture
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,
    # Vocabulary: p number tokens plus one extra input token; 2*p output classes
    d_vocab=p+1,
    d_vocab_out=2*p,
    # Every prompt has three tokens
    n_ctx=3,
    # Initialisation and placement
    init_weights=True,
    device=device,
    seed=999,
)

In [None]:
# Instantiate the transformer described by cfg.
model = HookedTransformer(cfg)

Disable the biases, as we don't need them for this task and it makes things easier to interpret.

In [None]:
# Freeze every parameter whose name contains "b_" so that it receives no gradient.
for param_name, parameter in model.named_parameters():
    if "b_" not in param_name:
        continue
    parameter.requires_grad = False


## Define Optimizer + Loss

In [None]:
# AdamW over the model's parameters, using the learning rate, weight decay and betas from the config cell.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=wd,
)

In [None]:
def loss_fn(logits, labels):
    """Mean cross-entropy of `logits` against integer `labels`, computed in float64.

    If `logits` is 3-dimensional, only the last index along dimension 1 is scored.
    Returns a scalar tensor: the negative mean log-probability assigned to the
    correct labels.
    """
    final_logits = logits[:, -1] if logits.ndim == 3 else logits
    # Do the softmax in double precision
    log_probs = final_logits.to(torch.float64).log_softmax(dim=-1)
    # Pick, for every example, the log-probability of its correct class
    label_column = labels[:, None]
    correct_log_probs = log_probs.gather(dim=-1, index=label_column)[:, 0]
    return -correct_log_probs.mean()
# Sanity check: evaluate the model on both splits and print the resulting losses.
train_logits, test_logits = model(train_data), model(test_data)
train_loss = loss_fn(train_logits, train_labels)
print(train_loss)
test_loss = loss_fn(test_logits, test_labels)
print(test_loss)

In [None]:
# Reference value: cross-entropy of a uniform guess over the 2*p output classes.
uniform_loss = np.log(2*p)
print("Uniform loss:")
print(uniform_loss)

## Actually Train

**Weird Decision:** Training the model with full batch training rather than stochastic gradient descent. We do this to make training smoother and reduce the number of slingshots.

In [None]:
# Per-epoch loss history, plus periodic snapshots of the model weights and the epochs they were taken at.
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
if TRAIN_MODEL:
    for epoch in tqdm.tqdm(range(num_epochs)):
        # Full-batch step: one forward/backward pass over the entire training split
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        train_losses.append(train_loss.item())

        # Apply the update, then clear gradients ready for the next epoch
        optimizer.step()
        optimizer.zero_grad()

        # Evaluate on the held-out split without tracking gradients
        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

        # Every `checkpoint_every` epochs, keep a deep copy of the weights (so later updates
        # do not alter the snapshot) and log progress. The recorded epoch index is 0-based.
        if ((epoch+1)%checkpoint_every)==0:
            checkpoint_epochs.append(epoch)
            model_checkpoints.append(copy.deepcopy(model.state_dict()))
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")

In [None]:
# Persist the run only when this session actually trained the model. Saving unconditionally
# would overwrite the cached file with untrained weights and empty loss/checkpoint lists
# whenever TRAIN_MODEL is False, and the following load cell would then read that clobbered file.
if TRAIN_MODEL:
    torch.save(
        {
            "model":model.state_dict(),
            "config": model.cfg,
            "checkpoints": model_checkpoints,
            "checkpoint_epochs": checkpoint_epochs,
            "test_losses": test_losses,
            "train_losses": train_losses,
            "train_indices": train_indices,
            "test_indices": test_indices,
        },
        PTH_LOCATION)

In [None]:
# When training is skipped, restore the weights, checkpoints, loss curves and split indices
# from the file written by a previous training run.
if not TRAIN_MODEL:
    # map_location lets a file saved on a GPU machine load on a CPU-only one.
    # weights_only=False is required because the file also stores the pickled model config
    # object, not just tensors. NOTE: this unpickles arbitrary Python objects, so only load
    # files that you produced yourself with this notebook.
    cached_data = torch.load(PTH_LOCATION, map_location=device, weights_only=False)
    model.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']
    train_indices = cached_data["train_indices"]
    test_indices = cached_data["test_indices"]

## Show Model Training Statistics, Check that it groks!

In [None]:
# Install the neel-plotly helper package and import its `line` function.
# NOTE: this import rebinds the name `line`, replacing the `line` helper defined earlier in this notebook.
%pip install git+https://github.com/neelnanda-io/neel-plotly.git
from neel_plotly.plot import line

In [None]:
# Plot every 100th epoch of both loss curves on a log-scaled y axis.
plot_stride = 100
line(
    [train_losses[::plot_stride], test_losses[::plot_stride]],
    x=np.arange(0, len(train_losses), plot_stride),
    xaxis="Epoch",
    yaxis="Loss",
    log_y=True,
    title="Training Curve for Fibonacci",
    line_labels=['train', 'test'],
    toggle_x=True,
    toggle_y=True,
)

# Analysing the Model

## Standard Things to Try

In [None]:
# Run the model on every (a, b) prompt, keeping both the logits and the cache returned by run_with_cache.
original_logits, cache = model.run_with_cache(dataset)
print(original_logits.numel())

Get key weight matrices:

In [None]:
# The model has a single transformer block; grab it once.
first_block = model.blocks[0]

# Embedding matrix without its last row (the extra, non-number input token)
W_E = model.embed.W_E[:-1]
print("W_E", W_E.shape)

# Number embeddings pushed through the attention value/output weights and into the MLP input weights
W_neur = W_E @ first_block.attn.W_V @ first_block.attn.W_O @ first_block.mlp.W_in
print("W_neur", W_neur.shape)

# MLP output weights composed with the unembedding
W_logit = first_block.mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

In [None]:
# Loss of the current model over the full dataset (all pairs, train and test together), as a Python float.
original_loss = loss_fn(original_logits, labels).item()
print("Original Loss:", original_loss)

In [None]:
input_number = 7
embedding_vector_np = W_E[input_number].detach().cpu().numpy() # this gets us the embedded vector for input_number

# first we try to plot directly the value of the entries of the embedded vector
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 6)) # Adjust figure size as needed
# Take the x positions from the vector itself rather than hard-coding the embedding width,
# so the plot keeps working if d_model is changed in the config.
component_positions = np.arange(embedding_vector_np.shape[0])
plt.bar(component_positions, embedding_vector_np)
plt.xlabel("Component Index (within Embedding Vector)")
plt.ylabel("Component Value")
plt.title(f"Embedding Vector for Input Token {input_number}")
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

In [None]:
# Heatmap of the embedding vectors of all input tokens at once.

# Transpose so that embedding dimensions run along the y-axis and tokens along the x-axis.
W_E_transposed = W_E.detach().cpu().numpy().T # Shape (d, N)

# N (number of tokens) and d (embedding width) come from the untransposed matrix
N_vocab, d_embed = W_E.shape[0], W_E.shape[1]

plt.figure(figsize=(8, 10)) # (width, height)
# aspect='auto' lets cells be non-square; interpolation='nearest' keeps pixels sharp.
im = plt.imshow(W_E_transposed, aspect='auto', interpolation='nearest', cmap='viridis')

# Labels and title
plt.xlabel("Input Token Index (k)")
plt.ylabel("Embedding Dimension Index")
plt.title("Heatmap of Embedding Vectors (W_E)")

# One tick per token when there are few tokens, otherwise at most 10 evenly spaced ticks
if N_vocab <= 20:
    token_ticks = np.arange(N_vocab)
    plt.xticks(ticks=token_ticks, labels=token_ticks)
else:
    plt.xticks(ticks=np.linspace(0, N_vocab-1, num=min(N_vocab, 10), dtype=int))

# Colorbar mapping colour to embedding value
plt.colorbar(im, label='Embedding Component Value')

plt.tight_layout()
plt.show()

It looks... kind of periodic? Let's now try to do PCA (or SVD) on this and see if we can learn anything.

In [None]:
# Step 1: centre each embedding dimension (column) on zero.
W_E_mean = W_E.mean(dim=0, keepdim=True)
W_E_centered = W_E - W_E_mean

# Step 2: thin SVD of the centred matrix.
# With K = min(N, d): U is (N, K), S is (K,), Vt is (K, d).
U, S, Vt = torch.linalg.svd(W_E_centered, full_matrices=False)

# PCA scores: each token's coordinates in the principal-component basis, shape (N, K).
pca_scores = U * S

# Principal directions are the rows of Vt, shape (K, d).
principal_components = Vt

# Share of the variance carried by each component.
n_tokens = W_E.shape[0]
explained_variance = S.square() / (n_tokens - 1)
total_variance = explained_variance.sum()
explained_variance_ratio = explained_variance / total_variance

# Score on PC1 plotted against the token index k.
pc1_scores_np = pca_scores[:, 0].detach().cpu().numpy()
fig, ax = plt.subplots(figsize=(10, 5))
ax.scatter(np.arange(n_tokens), pc1_scores_np, alpha=0.7, s=10)
ax.set_xlabel("Input Token Index (k)")
ax.set_ylabel("Score on PC1")
ax.set_title("PyTorch PCA: Score on First Principal Component vs. Input Token Index")
ax.grid(True, linestyle='--', alpha=0.6)
plt.show()



We observe that the score on the first principal component is approximately linear in the input token index, i.e., along this component the model learns an embedding given by embedded_value = w * value + b, with value being the input token index. w is actually negative, and b is ~2.

We now wish to investigate how important this first component, which is a linear function of the inputs, is in comparison with the other components. We do this by looking at how much variance is explained by each of the components.

In [None]:
# --- Explained variance from the singular values S ---

# Each component's variance is proportional to its squared singular value; dividing by N
# rather than N-1 makes no difference once the values are normalised into ratios.
explained_variance = S.square() / W_E.shape[0]
total_variance = explained_variance.sum()
explained_variance_ratio = explained_variance / total_variance

# NumPy copy for printing and plotting
explained_variance_ratio_np = explained_variance_ratio.detach().cpu().numpy()
n_components = len(explained_variance_ratio_np)

# --- Print the ratios (at most the first 10) with a running total ---
print(f"Explained Variance Ratio by Principal Component:")
num_components_to_print = min(n_components, 10)
cumulative_variance = 0.0
for pc_number, ratio in enumerate(explained_variance_ratio_np[:num_components_to_print], start=1):
    cumulative_variance += ratio
    print(f"  PC{pc_number}: {ratio:.4f} ({ratio*100:.2f}%) \t| Cumulative: {cumulative_variance:.4f} ({cumulative_variance*100:.2f}%)")

# Summarise the remainder when not every component was printed
if n_components > num_components_to_print:
    print(f"  ...")
    print(f"Total Cumulative Variance (all {len(explained_variance_ratio_np)} components): {explained_variance_ratio_np.sum():.4f} ({explained_variance_ratio_np.sum()*100:.2f}%)")


# --- Scree plot: bars for individual ratios, dashed line for the cumulative ratio ---
component_indices = np.arange(1, n_components + 1)
cumulative_variance_ratio_np = np.cumsum(explained_variance_ratio_np)

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(component_indices, explained_variance_ratio_np, alpha=0.7, align='center',
       label='Individual Explained Variance Ratio')
ax.plot(component_indices, cumulative_variance_ratio_np, marker='o', linestyle='--',
        label='Cumulative Explained Variance Ratio')

# Reference lines at 90% and 95%
ax.axhline(y=0.9, color='r', linestyle=':', linewidth=1, label='90% Threshold')
ax.axhline(y=0.95, color='g', linestyle=':', linewidth=1, label='95% Threshold')

ax.set_xlabel('Principal Component Index')
ax.set_ylabel('Explained Variance Ratio')
ax.set_title('Scree Plot - Explained Variance by Principal Component')
# Label every component only when there are few of them
if len(component_indices) <= 15:
    ax.set_xticks(component_indices)
ax.legend(loc='best')
ax.grid(True, linestyle='--', alpha=0.6)
ax.set_ylim(0, 1.1) # a little headroom above 1.0
plt.show()

We observe that the first component explains ~0.75 of the variance, and the second component ~0.25. The others are basically irrelevant. This does however mean that we also need to look at the second component to understand what is happening with the embedding. Let us do so now.

In [None]:
# Score on PC2 plotted against the token index k (same layout as the PC1 plot, second column of the scores)
pc2_scores_np = pca_scores[:, 1].detach().cpu().numpy()
fig, ax = plt.subplots(figsize=(10, 5))
ax.scatter(np.arange(W_E.shape[0]), pc2_scores_np, alpha=0.7, s=10)
ax.set_xlabel("Input Token Index (k)")
ax.set_ylabel("Score on PC2")
ax.set_title("PyTorch PCA: Score on Second Principal Component vs. Input Token Index")
ax.grid(True, linestyle='--', alpha=0.6)
plt.show()

This looks like a parabola. I have no clue why! Let's plot below the scores on the first component against the scores on the second component.



In [None]:
# Scatter the tokens in the plane spanned by the first two principal components,
# coloured by token index.
plt.figure(figsize=(8, 8))
# One colour value per plotted point
token_indices = np.arange(pca_scores.shape[0])
# The handle is deliberately NOT called `scatter`: that name belongs to the plotting helper
# defined earlier in this notebook, and rebinding it would break later calls to the helper.
pc_scatter = plt.scatter(
    pca_scores[:, 0].detach().cpu().numpy(),  # X-coordinates are scores on PC1
    pca_scores[:, 1].detach().cpu().numpy(),  # Y-coordinates are scores on PC2
    c=token_indices,                          # Color points based on the input token index
    cmap='viridis',                           # Colormap (e.g., 'viridis', 'plasma')
    alpha=0.8,
    s=50                                      # Increase point size for visibility
)
plt.xlabel("Score on PC1 (Linear Component)")
plt.ylabel("Score on PC2 (Parabolic Component)")
plt.title("Embeddings Projected onto First Two Principal Components")
plt.colorbar(pc_scatter, label='Input Token Index (k)')
plt.grid(True, linestyle='--', alpha=0.6)
plt.gca().set_aspect('equal', adjustable='box') # Try to keep scales comparable
plt.show()

An attempt at a conclusion regarding the model's learned embedding:

The model projects inputs (numbers 0-9) onto a low-dimensional (2D) manifold within the embedding space. This representation encodes the number's value primarily along one linear axis (PC1) and secondarily along an orthogonal parabolic axis (PC2). This specific geometric arrangement is likely learned to facilitate the downstream integer addition computation. (Leveraging linear order + perhaps magnitude/centrality).

Let's consider some potential next steps to further understand the algorithm the model is implementing:

Step 2: Attention Pattern Check

Goal: Check how the model combines a and b info at the = position.

Select Sample Inputs: Get various pairs (a, b).

Hook Attention Weights: For each head, get the attention weights from position 2 (query: =) to positions 0 (a), 1 (b), and 2 (=). Store these weights [w_to_a, w_to_b, w_to_=] for each head and input pair.

Visualize Average Patterns: For each head, average the weights w_to_a, w_to_b, w_to_= across all sample inputs. Plot these three average weights per head (e.g., using bar charts).

Analyze: Does position 2 attend significantly to both position 0 (a) and position 1 (b)? Are there clear differences between heads? Is attention to position 2 (=) itself low?


Expected Outcome: Confirm attention gathers info from a and b. Note basic patterns.

Step 3: MLP Analysis

Goal: Understand how the MLP computes a+b from the combined a and b representations.

Hook MLP Activations: Get the n-dimensional activation vector after the ReLU inside the MLP, specifically for the state calculated at position 2 (=). Do this for various input pairs (a, b).

Correlate Neurons to Task Variables: For each MLP neuron, calculate its activation across the sample pairs. Find the correlation between each neuron's activation and key variables: a, b, the target sum a+b, PC1 score of a, PC2 score of a, PC1 score of b, PC2 score of b.

Identify Key Neurons: Note which neurons correlate strongly with the target sum a+b. Note any correlating strongly with input features (like PC1/PC2 scores).

(Optional) Visualize Key Neuron Activations: Make 2D heatmaps (x-axis a, y-axis b) for a few key neurons (e.g., sum-correlated neurons). Does the activation map look like a+b?

Probe MLP State: Train a linear regression model to predict the scalar a+b using the n-dimensional MLP activation vector as input. Check probe accuracy. High accuracy implies MLP state linearly encodes the sum.

Formulate Hypothesis: Based on correlations and probes, hypothesize how the MLP uses the input features (represented by PC1/PC2) via its neurons to arrive at an internal state that encodes the sum a+b.

In [None]:
# Filled by the hook below: one tensor per forward pass
captured_mlp_activations = []

# Define the hook function
def store_mlp_post_activation_hook(
    activation: torch.Tensor,
    hook: 'HookPoint'
):
    """
    Forward hook that records the hooked activation at sequence position 2 (the '=' token).

    Args:
        activation: Tensor handed to the hook; it is indexed as [batch, position, feature].
        hook: The hook point that fired (unused).

    Returns:
        The activation, unchanged, so the forward pass continues normally.
    """
    target_activation = activation[:, 2, :].detach().cpu() # Shape: [batch_size, d_mlp]
    captured_mlp_activations.append(target_activation)

    # MUST return the activation to allow the forward pass to continue
    return activation

# Identify the exact hook point name
mlp_hook_point_name = "blocks.0.mlp.hook_post" # Output AFTER ReLU

# Start from an empty capture list so re-running the cell does not accumulate stale entries
captured_mlp_activations = []

print(f"Processing {len(dataset)} input pairs...")
# Push the whole dataset through in a single batched pass; gradients are not needed here.
with torch.no_grad():
    _ = model.run_with_hooks(
        dataset,
        fwd_hooks=[(mlp_hook_point_name, store_mlp_post_activation_hook)]
    )

if not captured_mlp_activations:
    raise RuntimeError(
        f"Hook '{mlp_hook_point_name}' did not capture any activation; check the hook point name."
    )

# Convert to numpy arrays for training the probe
X_probe = captured_mlp_activations[0].numpy()  # Shape: [num_samples, d_mlp]
target_sums = dataset[:, 0] + dataset[:, 1]    # a + b for every input pair
y_probe = target_sums.cpu().numpy()            # Shape: [num_samples]

# Per-sample view of the activations, one vector per input pair
activation_list = list(X_probe)
print(f"Collected {len(activation_list)} activation vectors.")

print("Probe input data shapes:")
print("X_probe:", X_probe.shape)
print("y_probe:", y_probe.shape)

We now have the post-ReLU MLP activations (taken at the `=` position) for each input pair. We will fit a linear probe on these activations to see whether they linearly encode the desired sum a+b.

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

# 1. Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
    X_probe, y_probe, test_size=0.3, random_state=42 # Adjust test_size if needed
)

# 2. Initialize and train the Linear Regression model
probe_model = LinearRegression()
probe_model.fit(X_train, y_train)

# 3. Make predictions on the test set
y_pred = probe_model.predict(X_test)

# 4. Evaluate the model
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"Linear Probe Performance:")
print(f"  Mean Squared Error (MSE): {mse:.4f}")
print(f"  R-squared (R2): {r2:.4f}")

# Optional: Basic baseline comparison (predicting the mean of training labels)
# Build the baseline as a float array: np.full_like(y_test, ...) would inherit the integer
# dtype of the labels and silently truncate the mean to a whole number.
baseline_pred = np.full(y_test.shape, y_train.mean(), dtype=float)
baseline_mse = mean_squared_error(y_test, baseline_pred)
baseline_r2 = r2_score(y_test, baseline_pred) # Will be close to 0 by definition

print(f"\nBaseline Performance (Predicting Mean):")
print(f"  Baseline MSE: {baseline_mse:.4f}")
print(f"  Baseline R2: {baseline_r2:.4f}")


# Optional: Visualize predictions vs actual
plt.figure(figsize=(6, 6))
plt.scatter(y_test, y_pred, alpha=0.7, label='Predictions')
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], '--', color='red', label='Perfect Prediction')
plt.xlabel("True Values (a+b)")
plt.ylabel("Predicted Values (a+b)")
plt.title("Linear Probe: True vs. Predicted Sums")
plt.legend()
plt.grid(True)
plt.show()

The linear regression we've done is almost perfectly accurate, which means that the post-ReLU activations contain a linear representation of the target sum a+b. This means that the bulk of the computation is being done by the MLP. Hence, we can likely find 'sum neurons', i.e., neurons that are mostly responsible for implementing the computation. We will do this by correlation analysis: what neurons' activations correlate more strongly with a+b?

In [None]:
from scipy.stats import pearsonr

# Work on NumPy arrays from here on
X_probe = np.asarray(X_probe)
y_probe = np.asarray(y_probe)

num_samples, d_mlp = X_probe.shape

# Pearson correlation between each neuron's activations (one column of X_probe) and the sum a+b.
# pearsonr returns (correlation_coefficient, p_value); only the coefficient is kept.
neuron_sum_correlations = np.zeros(d_mlp)
for neuron, neuron_activations in enumerate(X_probe.T):
    neuron_sum_correlations[neuron] = pearsonr(neuron_activations, y_probe)[0]

# A constant (zero-variance) neuron yields NaN; treat that as zero correlation
neuron_sum_correlations = np.nan_to_num(neuron_sum_correlations)

print(f"Calculated correlations for {d_mlp} neurons.")

# Rank neurons by |correlation|, strongest first, and report the top few
top_n = 20
indices_sorted_by_abs_corr = np.argsort(np.abs(neuron_sum_correlations))[::-1]

print(f"\nTop {top_n} neurons by absolute correlation with sum (a+b):")
for idx in indices_sorted_by_abs_corr[:top_n]:
    print(f"  Neuron {idx}: Correlation = {neuron_sum_correlations[idx]:.4f}")

In [None]:
import seaborn as sns
# --- Visualization 1: Histogram of Correlations ---
print("\nGenerating Histogram of Correlations...")
plt.figure(figsize=(10, 6))
sns.histplot(neuron_sum_correlations, bins=50, kde=True)
plt.title('Histogram of Neuron Activations\' Correlation with Sum (a+b)')
plt.xlabel('Pearson Correlation Coefficient')
plt.ylabel('Number of Neurons')
plt.grid(axis='y', alpha=0.5)
plt.show()

# --- Visualization 2: Sorted Correlation Plot (Stem Plot) ---
print("\nGenerating Sorted Correlation Plot...")
sorted_indices = np.argsort(neuron_sum_correlations)
sorted_correlations = neuron_sum_correlations[sorted_indices]
ranks = np.arange(d_mlp)

pos_mask = sorted_correlations >= 0
neg_mask = sorted_correlations < 0

plt.figure(figsize=(12, 7))
# Draw the non-negative correlations in blue, then the negative ones in red, with identical styling
for sign_mask, colour, legend_label in (
    (pos_mask, 'b', 'Positive Corr'),
    (neg_mask, 'r', 'Negative Corr'),
):
    markerline, stemlines, _baseline = plt.stem(
        ranks[sign_mask], sorted_correlations[sign_mask],
        linefmt=f'{colour}-', markerfmt=f'{colour}o', basefmt=' ', label=legend_label
    )
    plt.setp(stemlines, linewidth=1, alpha=0.7)
    plt.setp(markerline, markersize=3, alpha=0.7)

plt.title('Neuron Correlations with Sum (a+b), Sorted by Value')
plt.xlabel('Neuron Rank (Sorted by Correlation)')
plt.ylabel('Pearson Correlation Coefficient')
plt.ylim(-1.05, 1.05)
plt.grid(axis='y', alpha=0.5)
plt.legend()
plt.show()

# --- Visualization 3: Heatmap of Sorted W_L ---
# W_L is the neuron -> logit map laid out as (num_outputs, d_mlp). The notebook computes this
# product earlier as W_logit = W_out @ W_U (neurons along the first axis), so it is transposed
# here; previously W_L was never defined and this plot was always skipped.
W_L = W_logit.detach().cpu().numpy().T
# Number of output logits is read from the matrix instead of being hard-coded
num_outputs = W_L.shape[0]
if W_L.shape == (num_outputs, d_mlp):
    print("\nGenerating Heatmap of Sorted W_L...")
    # Sort columns (neurons) of W_L based on correlation
    neuron_corr_sorted_indices = np.argsort(neuron_sum_correlations)
    W_L_sorted_by_neuron_corr = W_L[:, neuron_corr_sorted_indices]

    plt.figure(figsize=(15, 8))
    sns.heatmap(W_L_sorted_by_neuron_corr, cmap='coolwarm', center=0) # 'coolwarm' is good for weights
    plt.title('Neuron-Logit Map (W_L) Columns Sorted by Neuron Correlation')
    plt.xlabel('Neurons (Sorted by Correlation with Sum a+b)')
    plt.ylabel(f'Output Logit (0 to {num_outputs-1})')
    plt.yticks(np.arange(num_outputs) + 0.5, labels=np.arange(num_outputs), rotation=0)
    plt.xticks([]) # Hide neuron indices as they are too dense
    plt.show()
else:
    # Reached only if the neuron axis of W_logit does not match the probed activations
    print("\nSkipping Heatmap of W_L: W_L matrix not found or incorrect shape.")
    print(f"Expected shape: ({num_outputs}, {d_mlp})")
    print(f"Actual shape: {W_L.shape}")


# --- Visualization 4: Heatmap of Sorted Neuron Activations ---
# Only plot when both probe arrays exist and have one row per sample
probe_data_ready = (
    'X_probe' in locals()
    and 'y_probe' in locals()
    and X_probe.shape[0] == y_probe.shape[0]
)
if not probe_data_ready:
    print("\nSkipping Heatmap of Activations: X_probe or y_probe not found or shapes mismatch.")
    if 'X_probe' in locals():
        print(f"X_probe shape: {X_probe.shape}")
    if 'y_probe' in locals():
        print(f"y_probe shape: {y_probe.shape}")
else:
    print("\nGenerating Heatmap of Sorted Neuron Activations...")
    # Columns: neurons ordered by their correlation with the sum
    neuron_corr_sorted_indices = np.argsort(neuron_sum_correlations)
    X_probe_neurons_sorted = X_probe[:, neuron_corr_sorted_indices]

    # Rows: input pairs ordered by the target sum
    input_sum_sorted_indices = np.argsort(y_probe)
    X_probe_fully_sorted = X_probe_neurons_sorted[input_sum_sorted_indices, :]

    plt.figure(figsize=(15, 8))
    # robust=True limits the colour range against outliers; the colorbar is omitted to reduce clutter
    sns.heatmap(X_probe_fully_sorted, cmap='viridis', robust=True, cbar=False)
    plt.title('MLP Activations (Sorted)')
    plt.xlabel('Neurons (Sorted by Correlation with Sum a+b)')
    plt.ylabel('Input Samples (Sorted by Sum a+b)')
    plt.xticks([]) # Hide neuron indices
    plt.yticks([]) # Hide sample indices
    plt.show()

We observe widespread strong positive correlation of neuron activations with the target sum. In fact, more than 130 neurons have a correlation exceeding 0.9 with the target sum. This shows that the MLP represents the sum in a highly distributed manner. There aren't localized, specialized 'sum' neurons, per se.

We will now visualize the activation patterns of a few top-correlated neurons as heatmaps across the grid of all possible (a, b) input pairs. This reveals the specific function each of these top neurons computes according to input pairs. Hopefully from this we can hypothesize how they combine the principal components we identified earlier to achieve their strong correlation.

In [None]:
equals_token_id = 10
N = 10
top_n_to_visualize = 3
top_neuron_indices = indices_sorted_by_abs_corr[:top_n_to_visualize]
mlp_hook_point_name = "blocks.0.mlp.hook_post" # Output AFTER ReLU

# --- Data Storage ---
# Create heatmaps initialized with NaN (or zero)
neuron_activation_heatmaps = {
    idx: np.full((N, N), np.nan) for idx in top_neuron_indices
}
captured_mlp_activation_storage = {} # Temporary storage during hook

# --- Hook Function ---
def capture_mlp_post_hook(activation: torch.Tensor, hook: 'HookPoint'):
    """Forward hook that records the activation at sequence position 2.

    The slice ``activation[:, 2, :]`` is detached, moved to the CPU and stored
    under the key ``'activation'`` of ``captured_mlp_activation_storage``.

    Args:
        activation: Tensor passed in by the hook point; indexed as
            ``[batch, position, feature]``.
        hook: The hook point object (unused).

    Returns:
        ``activation`` itself, unmodified, so the forward pass is unaffected.
    """
    position_two = activation.detach()[:, 2]
    captured_mlp_activation_storage['activation'] = position_two.cpu()
    return activation

# --- Generate Activations ---
# All N*N (a, b) pairs are evaluated in ONE batched forward pass instead of
# N*N single-sample passes. Rows are ordered with `a` as the slow index and
# `b` as the fast index, so reshaping a column to (N, N) yields grid[a, b].
print(f"Generating activations for {top_n_to_visualize} neurons across {N}x{N} grid...")
input_tokens = torch.tensor(
    [[a, b, equals_token_id] for a in range(N) for b in range(N)],
    dtype=torch.long,
).to(model.cfg.device)  # same device placement as the test set built later in the notebook

model.eval() # Ensure model is in evaluation mode
captured_mlp_activation_storage.clear() # Drop any stale capture from a previous run
with torch.no_grad(): # No need to track gradients
    _ = model.run_with_hooks(
        input_tokens,
        fwd_hooks=[(mlp_hook_point_name, capture_mlp_post_hook)]
    )

# Copy the captured activations of the selected neurons into their grids
if 'activation' in captured_mlp_activation_storage:
    # The hook stores a CPU tensor with one row per input pair
    grid_activations = captured_mlp_activation_storage['activation'].numpy()
    for neuron_idx in top_neuron_indices:
        neuron_activation_heatmaps[neuron_idx][:, :] = grid_activations[:, neuron_idx].reshape(N, N)
else:
    # Make the failure visible instead of silently plotting all-NaN grids
    print(f"Warning: hook point '{mlp_hook_point_name}' captured nothing; heatmaps remain NaN.")
print("Activation generation complete.")

# --- Plotting ---
# One panel per selected neuron. With a single panel, plt.subplots returns a
# bare Axes, so wrap it in a list to keep the indexing below uniform.
fig, axes = plt.subplots(1, top_n_to_visualize, figsize=(6 * top_n_to_visualize, 5))
axes = [axes] if top_n_to_visualize == 1 else axes

fig.suptitle("MLP Neuron Activations (Post-ReLU) at Position '=' vs. Inputs (a, b)")

# Ticks sit at cell centres (+0.5) and are labelled with the operand value
tick_positions = np.arange(N) + 0.5
tick_labels = np.arange(N)

for ax, neuron_idx in zip(axes, top_neuron_indices):
    heatmap_data = neuron_activation_heatmaps[neuron_idx]
    unfilled_cells = np.isnan(heatmap_data)  # cells never written stay uncoloured
    sns.heatmap(
        heatmap_data,
        mask=unfilled_cells,
        ax=ax,
        cmap="viridis",
        annot=True,
        fmt=".2f",
        square=True,
        cbar=True,
        linewidths=.5,
        linecolor='gray',
    )
    ax.set(
        title=f"Neuron {neuron_idx}\nCorr w/ Sum: {neuron_sum_correlations[neuron_idx]:.3f}",
        xlabel="Input b = F(n-1)",
        ylabel="Input a = F(n-2)",
    )
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels)
    # seaborn draws row 0 at the top; flip so that a = 0 sits at the bottom
    ax.invert_yaxis()

# Leave room at the top of the figure for the suptitle
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

The heatmaps show that the top-correlated neurons implement a function approximating ReLU(linear_function_of_sum). We observe a region of 0s (e.g., all values adding up to 6 or less for the middle neuron above), which is likely a result of the ReLU cutting off negative pre-activations, which basically creates a threshold below which the neurons do not fire. Further, we get roughly linear behavior with the sum of inputs above the threshold, as expected from the ReLU. The fact that the patterns are so similar across neurons again emphasizes that the computation implemented by the model is distributed.

We will now examine the pre-ReLU activations for these same neurons to test our hypothesis. If the core computation is indeed a linear function of the sum that is simply thresholded by ReLU, then the pre-ReLU heatmaps should reveal this underlying linear relationship more clearly across the entire input grid, without the large zeroed-out regions, thus confirming the thresholding mechanism.

In [None]:
# Hook point that exposes the MLP values before the non-linearity is applied
mlp_hook_point_name_pre = "blocks.0.mlp.hook_pre"

# --- Result containers ---
# One NaN-initialised (N, N) grid per selected neuron
neuron_pre_activation_heatmaps = dict(
    (idx, np.full((N, N), fill_value=np.nan)) for idx in top_neuron_indices
)
# Scratch dict that the pre-activation hook writes its captured tensor into
captured_mlp_pre_activation_storage = dict()

# --- Hook Function ---
def capture_mlp_pre_hook(activation: torch.Tensor, hook: 'HookPoint'):
    """Forward hook that records the pre-activation values at sequence position 2.

    The slice ``activation[:, 2, :]`` is detached, moved to the CPU and stored
    under the key ``'activation'`` of ``captured_mlp_pre_activation_storage``.

    Args:
        activation: Tensor passed in by the hook point; indexed as
            ``[batch, position, feature]``.
        hook: The hook point object (unused).

    Returns:
        ``activation`` itself, unmodified, so the forward pass is unaffected.
    """
    position_two = activation.select(1, 2)
    captured_mlp_pre_activation_storage['activation'] = position_two.detach().cpu()
    return activation

# --- Generate Activations ---
# All N*N (a, b) pairs are evaluated in ONE batched forward pass instead of
# N*N single-sample passes. Rows are ordered with `a` as the slow index and
# `b` as the fast index, so reshaping a column to (N, N) yields grid[a, b].
print(f"Generating PRE-ReLU activations for {len(top_neuron_indices)} neurons across {N}x{N} grid...")
input_tokens = torch.tensor(
    [[a, b, equals_token_id] for a in range(N) for b in range(N)],
    dtype=torch.long,
).to(model.cfg.device)  # same device placement as the test set built later in the notebook

model.eval()
captured_mlp_pre_activation_storage.clear() # Drop any stale capture from a previous run
with torch.no_grad():
    _ = model.run_with_hooks(
        input_tokens,
        fwd_hooks=[(mlp_hook_point_name_pre, capture_mlp_pre_hook)]
    )

# Copy the captured pre-activations of the selected neurons into their grids
if 'activation' in captured_mlp_pre_activation_storage:
    # The hook stores a CPU tensor with one row per input pair
    grid_pre_activations = captured_mlp_pre_activation_storage['activation'].numpy()
    for neuron_idx in top_neuron_indices:
        neuron_pre_activation_heatmaps[neuron_idx][:, :] = grid_pre_activations[:, neuron_idx].reshape(N, N)
else:
    # Make the failure visible instead of silently plotting all-NaN grids
    print(f"Warning: hook point '{mlp_hook_point_name_pre}' captured nothing; heatmaps remain NaN.")
print("PRE-ReLU activation generation complete.")

# --- Plotting ---
# One panel per selected neuron; a single panel comes back as a bare Axes,
# so wrap it in a list to keep the iteration below uniform.
n_panels = len(top_neuron_indices)
fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5))
if n_panels == 1:
    axes = [axes]

fig.suptitle("MLP Neuron Activations (PRE-ReLU) at Position '=' vs. Inputs (a, b)")

# Ticks sit at cell centres (+0.5) and are labelled with the operand value
tick_positions = np.arange(N) + 0.5
tick_labels = np.arange(N)

for ax, neuron_idx in zip(axes, top_neuron_indices):
    heatmap_data = neuron_pre_activation_heatmaps[neuron_idx]
    # Pre-activation values can take either sign, so use a diverging palette
    # with limits symmetric about zero (largest magnitude on either side).
    color_limit = np.nanmax(np.abs(heatmap_data))
    sns.heatmap(
        heatmap_data,
        mask=np.isnan(heatmap_data),
        ax=ax,
        cmap="coolwarm",
        vmin=-color_limit,
        vmax=color_limit,
        annot=True,
        fmt=".2f",
        square=True,
        cbar=True,
        linewidths=.5,
        linecolor='gray',
    )
    ax.set(
        title=f"Neuron {neuron_idx} (Pre-ReLU)",
        xlabel="Input b = F(n-1)",
        ylabel="Input a = F(n-2)",
    )
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels)
    # seaborn draws row 0 at the top; flip so that a = 0 sits at the bottom
    ax.invert_yaxis()

# Leave room at the top of the figure for the suptitle
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

We in fact observe a linear relation between the value of the sum and the pre-ReLU neuron value. However, note that the top neuron has a lot of zero values, whereas the other two go from negative to positive with similar magnitude. This indicates that they learn different linear functions (slopes, biases). However, part of this difference is thresholded away by the ReLU activation.

We now want to fit the pre-ReLU values to the PCA components to understand exactly what function thereof they are computing.

In [None]:
import pandas as pd

# --- Choose the neuron to model ---
# Pre-ReLU grids exist only for the neurons in `top_neuron_indices`. The
# preferred index is hard-coded, so it may not be among them (e.g. after the
# model has been retrained); in that case fall back to the first selected
# neuron instead of failing with a KeyError.
preferred_neuron_idx = 408
if preferred_neuron_idx in neuron_pre_activation_heatmaps:
    target_neuron_idx = preferred_neuron_idx
else:
    target_neuron_idx = top_neuron_indices[0]
    print(f"Neuron {preferred_neuron_idx} has no pre-ReLU heatmap; using neuron {target_neuron_idx} instead.")

feature_list = []
target_activation_list = []

# PCA scores as a float32 NumPy array; indexed below as [token, component]
pca_scores_np = pca_scores.detach().cpu().numpy().astype(np.float32)
pre_activation_target_neuron = neuron_pre_activation_heatmaps[target_neuron_idx] # (N, N) grid for this neuron

# --- Build one regression sample per (a, b) pair ---
for a in range(N):
    for b in range(N):
        if a + b >= N: # Only pairs whose sum is below N are used
            continue

        target_activation = pre_activation_target_neuron[a, b]
        if np.isnan(target_activation): # Skip grid cells that were never filled
            continue

        # Features: first two PCA scores of each operand's embedding
        feature_list.append([
            pca_scores_np[a, 0],  # PC1(a)
            pca_scores_np[a, 1],  # PC2(a)
            pca_scores_np[b, 0],  # PC1(b)
            pca_scores_np[b, 1],  # PC2(b)
        ])
        target_activation_list.append(target_activation)

# Convert lists to NumPy arrays
X_features_for_neuron_model = np.array(feature_list, dtype=np.float32)
y_target_neuron_pre_activation = np.array(target_activation_list, dtype=np.float32)

# Column names, in the same order as the feature vector above
feature_names = ['PC1(a)', 'PC2(a)', 'PC1(b)', 'PC2(b)']

print(f"Prepared data for modeling neuron {target_neuron_idx}:")
print(f"  X_features shape: {X_features_for_neuron_model.shape}")
print(f"  y_target shape: {y_target_neuron_pre_activation.shape}")

# --- Fit an ordinary least-squares model: pre-ReLU activation ~ PCA features ---
neuron_model = LinearRegression().fit(
    X_features_for_neuron_model, y_target_neuron_pre_activation
)

# --- Goodness of fit, measured on the same samples the model was fitted to ---
y_neuron_pred = neuron_model.predict(X_features_for_neuron_model)
r2_neuron_fit = r2_score(y_target_neuron_pre_activation, y_neuron_pred)

# --- Coefficients, labelled with the feature they multiply ---
coeffs = pd.Series(neuron_model.coef_, index=feature_names)

print(f"\nLinear Model Fit for Neuron {target_neuron_idx} (Pre-ReLU):")
print(f"  R-squared of fit: {r2_neuron_fit:.4f}")
print("  Learned coefficients:")
print(coeffs)
print(f"  Intercept: {neuron_model.intercept_:.4f}")

The excellent R² value confirms that the target neuron's pre-ReLU activation (the neuron selected by `target_neuron_idx` in the cell above) is accurately modeled as a linear combination of the primary (linear value) and secondary (parabolic centrality) features derived from the input embeddings. The near-identical coefficients for the a and b components demonstrate symmetric treatment of the inputs, while the dominance of the PC1 coefficients confirms the neuron primarily computes a function proportional to the sum a+b. The smaller but non-zero PC2 coefficients suggest this core computation is subtly modulated based on the inputs' centrality, likely for fine-tuning or boundary adjustments within the N=10 range.

Something confuses me: it seems that the model just directly learns to add, as the linear component of the PCA is the most significant. However, the fact that there's a non-negligible contribution from the second component of the PCA makes me doubt this. In fact, the first one only explains 73% of the variance! That's a lot, but far from everything.

This motivates performing an ablation study, in which we get rid of the second principal component and investigate how (if at all) this affects model performance.

In [None]:
# Move the embedding mean and the first two principal directions to the
# model's device so the ablation hook can combine them with activations.
W_E_mean_gpu = W_E_mean.to(model.cfg.device)
pc1_vec, pc2_vec = (
    principal_components[component].to(model.cfg.device) for component in (0, 1)
)

# Drop the leading axis of the mean so it can be used as a plain 1-D vector
W_E_mean_1d_gpu = W_E_mean_gpu.squeeze(0)

def ablate_pc2_hook(
    activation: torch.Tensor, # Shape [batch, seq_len, d_model]
    hook: 'HookPoint'
):
    """Replace the embeddings at positions 0 and 1 by their PC1-only reconstruction.

    For every batch element and for sequence positions 0 and 1 (tokens a and
    b), the embedding is centred with ``W_E_mean_1d_gpu``, reduced to its
    coefficient along ``pc1_vec``, rebuilt as ``coefficient * pc1_vec`` and
    re-centred by adding the mean back. All other positions are left as is.

    Note that only the component along ``pc1_vec`` survives: this removes the
    second principal component, but also every other direction of the centred
    embedding that is not along ``pc1_vec``.

    Args:
        activation: Hooked tensor of shape [batch, seq_len, d_model].
        hook: The hook point object (unused).

    Returns:
        A modified copy of ``activation``; the input tensor is not changed.
    """
    # Work on a copy so the caller's tensor is never modified in place
    ablated_activation = activation.clone()

    # Centre the operand embeddings; the 1-D mean broadcasts over batch and position
    centered = activation[:, :2, :] - W_E_mean_1d_gpu

    # Coefficient of each centred embedding along pc1_vec -> shape [batch, 2]
    pc1_scores = torch.matmul(centered, pc1_vec)

    # Rebuild from the PC1 coefficient alone and add the mean back
    ablated_activation[:, :2, :] = pc1_scores.unsqueeze(-1) * pc1_vec + W_E_mean_1d_gpu

    return ablated_activation

# --- How to run ---
embedding_hook_point = "hook_embed" # Or "hook_pos_embed", check model.hook_dict

# Smoke-test the hook on one explicit example. The tokens are built here
# rather than taken from an `input_tokens` variable left over from an earlier
# cell, and the run is wrapped in no_grad because no gradients are needed.
example_tokens = torch.tensor(
    [[N - 1, N - 1, equals_token_id]], dtype=torch.long
).to(model.cfg.device)
with torch.no_grad():
    logits = model.run_with_hooks(
        example_tokens,
        fwd_hooks=[(embedding_hook_point, ablate_pc2_hook)]
    )

# --- Exhaustive test set: every (a, b) pair with its sum as the label ---
# Pairs are ordered with `a` as the slow index and `b` as the fast index.
test_pairs = [[a, b, equals_token_id] for a in range(N) for b in range(N)]
test_labels_list = [a + b for a, b, _ in test_pairs]

# Convert to tensors on the model's device
test_input_tokens = torch.tensor(test_pairs, dtype=torch.long).to(model.cfg.device)
test_labels = torch.tensor(test_labels_list, dtype=torch.long).to(model.cfg.device)

print(f"Created test data with {test_input_tokens.shape[0]} samples.")

# --- Helper function to calculate accuracy ---
def calculate_accuracy(logits, labels):
    """Compute top-1 accuracy of the last-position prediction.

    Args:
        logits: Tensor of shape [batch, seq_len, d_vocab]. Only the last
            sequence position is used for the prediction.
        labels: Tensor of shape [batch] holding the expected token ids.

    Returns:
        A tuple ``(accuracy, correct_predictions, total_predictions)`` where
        ``accuracy`` is ``correct_predictions / total_predictions``. For an
        empty batch the accuracy is undefined, so ``(nan, 0, 0)`` is returned
        instead of raising ``ZeroDivisionError``.
    """
    total_predictions = labels.shape[0]
    if total_predictions == 0:
        return float("nan"), 0, 0

    # Predicted token = argmax over the vocabulary at the last position
    predicted_tokens = torch.argmax(logits[:, -1, :], dim=-1) # Shape: [batch]
    correct_predictions = (predicted_tokens == labels).sum().item()
    accuracy = correct_predictions / total_predictions
    return accuracy, correct_predictions, total_predictions

# --- 2 & 3. Run the model on the full test set, without and with the ablation hook ---
model.eval() # Evaluation mode
with torch.no_grad(): # Inference only, no gradients needed
    original_logits = model(test_input_tokens)
    ablated_logits = model.run_with_hooks(
        test_input_tokens,
        fwd_hooks=[(embedding_hook_point, ablate_pc2_hook)],
    )

# Score both sets of logits against the true sums
original_accuracy, orig_correct, orig_total = calculate_accuracy(original_logits, test_labels)
ablated_accuracy, ablated_correct, ablated_total = calculate_accuracy(ablated_logits, test_labels)

print("\nOriginal Model Performance:")
print(f"  Accuracy: {original_accuracy:.4f} ({orig_correct}/{orig_total})")

print("\nPC2 Ablated Model Performance:")
print(f"  Accuracy: {ablated_accuracy:.4f} ({ablated_correct}/{ablated_total})")

# --- 4. Compare: positive values mean the ablation hurt accuracy ---
accuracy_drop = original_accuracy - ablated_accuracy
print(f"\nAccuracy Drop due to PC2 Ablation: {accuracy_drop:.4f}")

Accuracy drops by a lot! We go from getting it right 100% of the time, to only 42%. So the second principal component is still very important, and whatever non-linearity it's doing actually matters. Let's now look specifically at which input pairs the model fails at when we ablate the PC2.

In [None]:
# --- Predicted token of the ablated run at the last sequence position ---
ablated_logits_last_pos = ablated_logits[:, -1, :] # Shape: [num_samples, d_vocab]
ablated_predicted_tokens = ablated_logits_last_pos.argmax(dim=-1) # Shape: [num_samples]

# --- Sample indices where the ablated prediction differs from the label ---
incorrect_mask = ablated_predicted_tokens.ne(test_labels)
incorrect_indices = incorrect_mask.nonzero(as_tuple=True)[0].cpu().numpy()

# --- NumPy views of operands, labels and predictions for easy lookup ---
original_a_values = test_input_tokens[:, 0].cpu().numpy()
original_b_values = test_input_tokens[:, 1].cpu().numpy()
true_sums = test_labels.cpu().numpy()
predicted_sums_ablated = ablated_predicted_tokens.cpu().numpy()

# One record per misclassified sample
failing_pairs = [
    {
        'a': original_a_values[idx],
        'b': original_b_values[idx],
        'True Sum (a+b)': true_sums[idx],
        'Predicted Sum (Ablated)': predicted_sums_ablated[idx],
    }
    for idx in incorrect_indices
]

# --- Report ---
if failing_pairs:
    print(f"Found {len(failing_pairs)} failures after PC2 ablation:")
    # Sorted by true sum, then operands, to make patterns easier to spot
    fail_df = (
        pd.DataFrame(failing_pairs)
        .sort_values(by=['True Sum (a+b)', 'a', 'b'])
        .reset_index(drop=True)
    )
    print(fail_df.to_string()) # to_string prints every row, without truncation
else:
    print("No failures found after PC2 ablation (unexpected based on previous accuracy).")

# Optional: binary grid of failures, 1 where the ablated model was wrong
failure_heatmap = np.zeros((N, N))
for record in failing_pairs:
    row, col = record['a'], record['b']
    failure_heatmap[row, col] = 1

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(
    failure_heatmap,
    ax=ax,
    cmap="Reds",
    annot=False,
    cbar=False,
    linewidths=.5,
    linecolor='gray',
)
ax.set_title("Failure Cases (Red) after PC2 Ablation")
ax.set_xlabel("Input b = F(n-1)")
ax.set_ylabel("Input a = F(n-2)")
# Ticks at cell centres, labelled with the operand value
ax.set_xticks(np.arange(N) + 0.5)
ax.set_xticklabels(np.arange(N))
ax.set_yticks(np.arange(N) + 0.5)
ax.set_yticklabels(np.arange(N))
# seaborn draws row 0 at the top; flip so that a = 0 sits at the bottom
ax.invert_yaxis()
plt.show()

This is interesting... There are definitely patterns, but they are not easy to interpret. Some thoughts:

1.   Symmetry: Confirmed. Reinforces that a and b are treated symmetrically.
2.   Failures involving 0: This is a huge clue. The model fails for (0, 1) through (0, 8) and (1, 0) through (8, 0). However, it succeeds for (0, 0), (0, 9), and (9, 0).
3.   Failures involving 9: The same observation holds for pairs involving 9. This suggests some symmetry in how the model treats small and big numbers.
4.   Central Failures: There's definitely a cluster of failures for pairs where both a and b are "mid-range" (roughly 2-7).

Given that the PC2 ablation dramatically impacts accuracy with a complex failure pattern—particularly affecting pairs involving edge tokens (0 and 9) differently and the central diagonal—it strongly suggests the issue lies not just in the MLP's internal calculation, but crucially in how the final output is derived from the MLP state. This points directly to the Unembedding layer (W_U), which performs this final mapping to output logits. Therefore, the next logical step is to analyze W_U to understand how it interprets the MLP's combined PC1/PC2 representation and why removing the PC2 component leads to these specific, non-uniform failures in decoding the sum.


In [None]:
# Hook points used for the skip-connection ablation
hook_point_resid_post = "blocks.0.hook_resid_post"  # value leaving block 0
hook_point_mlp_out = "blocks.0.hook_mlp_out"        # output of block 0's MLP

# --- Hand-off dictionary ---
# Written by the capture hook and read by the replace hook within one forward pass
captured_activations_for_run = dict()

# --- Hook Function to CAPTURE mlp_out ---
def capture_mlp_out_hook(activation, hook):
    """Store a detached reference to the hooked activation for a later hook.

    The tensor is saved under the key ``'mlp_out'`` of
    ``captured_activations_for_run``.

    Args:
        activation: Tensor passed in by the hook point.
        hook: The hook point object (unused).

    Returns:
        ``activation`` itself, so the forward pass is unaffected at this point.
    """
    captured_activations_for_run.update(mlp_out=activation.detach())
    return activation

# --- Hook Function to MODIFY resid_post ---
def replace_resid_post_hook(activation, hook):
    """Swap the hooked activation for the previously captured ``mlp_out``.

    Looks up ``'mlp_out'`` in ``captured_activations_for_run`` and returns it
    in place of ``activation`` so that the rest of the forward pass sees only
    the MLP output.

    Args:
        activation: Tensor passed in by the hook point.
        hook: The hook point object; its ``name`` is used in warning messages.

    Returns:
        The captured ``mlp_out`` tensor when it exists and has the same shape
        as ``activation``; otherwise ``activation`` unchanged, after printing
        a warning.
    """
    try:
        mlp_out_value = captured_activations_for_run['mlp_out']
    except KeyError:
        # Nothing was captured: leave the forward pass untouched
        print(f"Warning: Capture hook '{hook_point_mlp_out}' did not run before modify hook '{hook.name}'")
        return activation

    if activation.shape != mlp_out_value.shape:
        # Incompatible shapes: substituting would break downstream layers
        print(f"Shape mismatch warning: {hook.name} ({activation.shape}) vs captured mlp_out ({mlp_out_value.shape})")
        return activation

    # Returning a tensor from the hook replaces the hooked value downstream
    return mlp_out_value

# --- Run the Ablation ---
model.eval()
with torch.no_grad():
    # Start from an empty hand-off dict so a stale capture cannot be reused
    captured_activations_for_run.clear()

    # Both hooks are registered for the same forward pass. The replace hook
    # relies on the capture hook having fired first; if it has not, the
    # replace hook prints a warning and leaves the activation unchanged.
    print(f"Running model with capture hook at '{hook_point_mlp_out}' and modify hook at '{hook_point_resid_post}'")
    skip_ablated_logits = model.run_with_hooks(
        test_input_tokens,
        fwd_hooks=[
            (hook_point_mlp_out, capture_mlp_out_hook),
            (hook_point_resid_post, replace_resid_post_hook)
        ]
    )
    print("Model run with hooks completed.")

    # Baseline run without hooks. Kept inside no_grad, like every other
    # evaluation in this notebook, so that no autograd graph is built.
    original_logits = model(test_input_tokens)


# --- Calculate and Compare Accuracy ---
original_accuracy, orig_correct, orig_total = calculate_accuracy(original_logits, test_labels)
skip_ablated_accuracy, skip_ablated_correct, skip_ablated_total = calculate_accuracy(skip_ablated_logits, test_labels)

print(f"\nOriginal Model Accuracy: {original_accuracy:.4f} ({orig_correct}/{orig_total})")
print(f"Skip Conn Ablated Accuracy (MLP Only): {skip_ablated_accuracy:.4f} ({skip_ablated_correct}/{skip_ablated_total})")
print(f"Accuracy Drop due to Skip Ablation: {original_accuracy - skip_ablated_accuracy:.4f}")

# Release the reference to the captured tensor
captured_activations_for_run.clear()

No accuracy is lost by dropping the skip connection (the bit where you re-add the MLP input to its output before following to the output layer of the transformer). This indicates that we can effectively ignore it. Nanda had observed the same. Just to be sure, we are now going to do the same analysis but with the test loss rather than accuracy.

In [None]:
# Fresh, empty hand-off dict for this run (the hooks look the name up at call time)
captured_activations_for_run = dict()

# Capture hook first, replace hook second
skip_ablation_hooks = [
    (hook_point_mlp_out, capture_mlp_out_hook),
    (hook_point_resid_post, replace_resid_post_hook),
]

# --- Forward passes: MLP-only (skip connection ablated) and unmodified ---
model.eval()
with torch.no_grad():
    print(f"Running model with capture hook at '{hook_point_mlp_out}' and modify hook at '{hook_point_resid_post}'")
    skip_ablated_logits = model.run_with_hooks(test_input_tokens, fwd_hooks=skip_ablation_hooks)
    print("Model run with hooks completed.")

    # Reference logits from the unmodified model
    original_logits = model(test_input_tokens)

# --- Accuracy and loss for both runs ---
# The loss is evaluated on the logits of the last sequence position only.

# Unmodified model
original_pred_logits = original_logits[:, -1, :]
original_loss = loss_fn(original_pred_logits, test_labels)
original_accuracy, orig_correct, orig_total = calculate_accuracy(original_logits, test_labels)

# Model with the skip connection ablated
skip_ablated_pred_logits = skip_ablated_logits[:, -1, :]
skip_ablated_loss = loss_fn(skip_ablated_pred_logits, test_labels)
skip_ablated_accuracy, skip_ablated_correct, skip_ablated_total = calculate_accuracy(skip_ablated_logits, test_labels)

# Plain Python floats for formatting
original_loss_value = original_loss.item()
skip_ablated_loss_value = skip_ablated_loss.item()

# --- Report ---
print("\nOriginal Model Performance:")
print(f"  Accuracy: {original_accuracy:.4f} ({orig_correct}/{orig_total})")
print(f"  Loss:     {original_loss_value:.6f}")

print("\nSkip Conn Ablated Accuracy (MLP Only):")
print(f"  Accuracy: {skip_ablated_accuracy:.4f} ({skip_ablated_correct}/{skip_ablated_total})")
print(f"  Loss:     {skip_ablated_loss_value:.6f}")

print(f"\nAccuracy Drop due to Skip Ablation: {original_accuracy - skip_ablated_accuracy:.4f}")
print(f"Loss Change due to Skip Ablation:   {skip_ablated_loss_value - original_loss_value:.6f}")

# Release the reference to the captured tensor
captured_activations_for_run.clear()

Loss increases by a few orders of magnitude, but in absolute terms is still small. I'd say this verifies our findings.

We are now going to look at the difference in logits when comparing the full trained model and the model in which we have ablated away the second PCA component (the parabolic one) in the embedded representation. We hope this sheds light on what it is that this 2nd component is doing.

In [None]:
# --- Logits of the unmodified model and of the PC2-ablated model ---
pc2_ablation_hooks = [(embedding_hook_point, ablate_pc2_hook)]

model.eval()
with torch.no_grad():
    original_logits = model(test_input_tokens) # Shape: [num_samples, seq_len, d_vocab]
    ablated_logits = model.run_with_hooks(test_input_tokens, fwd_hooks=pc2_ablation_hooks)

# --- Element-wise difference (original minus ablated) ---
logit_diff = torch.sub(original_logits, ablated_logits) # Shape: [num_samples, seq_len, d_vocab]


def _last_position_as_numpy(batch_logits):
    """Return the last-position slice of ``batch_logits`` as a CPU NumPy array."""
    return batch_logits[:, -1, :].cpu().numpy()


# Only the last sequence position carries the prediction
logit_diff_pred_pos = _last_position_as_numpy(logit_diff) # Shape: [num_samples, d_vocab]
original_logits_pred_pos = _last_position_as_numpy(original_logits)
ablated_logits_pred_pos = _last_position_as_numpy(ablated_logits)

test_labels_np = test_labels.cpu().numpy()

# --- Per-sample summary table ---
original_a_values = test_input_tokens[:, 0].cpu().numpy()
original_b_values = test_input_tokens[:, 1].cpu().numpy()
# Token the ablated model predicts for each sample
ablated_predicted_tokens = ablated_logits_pred_pos.argmax(axis=-1)

# One record per sample. "LogitDiff" columns are original minus ablated,
# read off at the true label and at the ablated model's predicted label.
results = [
    {
        'a': original_a_values[i],
        'b': original_b_values[i],
        'True Sum': test_labels_np[i],
        'Ablated Pred': ablated_predicted_tokens[i],
        'LogitDiff (True)': logit_diff_pred_pos[i, test_labels_np[i]],
        'LogitDiff (AblatedPred)': logit_diff_pred_pos[i, ablated_predicted_tokens[i]],
        'OrigLogit (True)': original_logits_pred_pos[i, test_labels_np[i]],
        'AblatedLogit (True)': ablated_logits_pred_pos[i, test_labels_np[i]],
        'Failed': test_labels_np[i] != ablated_predicted_tokens[i],
    }
    for i in range(len(test_labels_np))
]

results_df = pd.DataFrame(results)

# --- How to read the table ---
print(
    "Logit Difference Analysis (Original - Ablated):",
    "Positive 'LogitDiff (True)' means PC2 boosted the correct logit.",
    "Negative 'LogitDiff (AblatedPred)' for failures means PC2 suppressed the wrongly predicted logit.",
    sep="\n",
)

# Summary statistics of the change in the correct label's logit
print("\nStatistics for Logit Difference of the CORRECT sum:")
print(results_df['LogitDiff (True)'].describe())

# Rows where the ablated model got the sum wrong, in a stable order
failure_report_columns = [
    'a', 'b', 'True Sum', 'Ablated Pred',
    'LogitDiff (True)', 'LogitDiff (AblatedPred)',
    'OrigLogit (True)', 'AblatedLogit (True)',
]
failures_df = results_df.loc[results_df['Failed']].sort_values(by=['True Sum', 'a', 'b'])
print(f"\nAnalysis of {len(failures_df)} Failure Cases:")
if len(failures_df) > 0:
    print(failures_df[failure_report_columns].to_string())

# --- Visualize LogitDiff for the Correct Label ---
# Fill the (a, b) grid in one indexed assignment. Reading whole rows with
# `.iloc[i]` would mix the integer, float and bool columns into one Series and
# only yields usable integer indices because of the resulting object dtype;
# taking the columns as integer arrays avoids relying on that.
logit_diff_heatmap = np.full((N, N), np.nan)
grid_rows = results_df['a'].to_numpy().astype(int)
grid_cols = results_df['b'].to_numpy().astype(int)
logit_diff_heatmap[grid_rows, grid_cols] = results_df['LogitDiff (True)'].to_numpy()

plt.figure(figsize=(7, 6))
# Symmetric limits so that zero maps to the middle of the diverging palette
max_abs_val = np.nanmax(np.abs(logit_diff_heatmap))
sns.heatmap(logit_diff_heatmap, cmap="coolwarm", center=0,
            annot=True, fmt=".2f", linewidths=.5, linecolor='gray',
            vmin=-max_abs_val, vmax=max_abs_val, # Center colorbar at 0
            mask=np.isnan(logit_diff_heatmap), square=True) # Unfilled cells stay uncoloured
plt.title("Logit Difference (Orig - Ablated) for CORRECT Sum")
plt.xlabel("Input b = F(n-1)")
plt.ylabel("Input a = F(n-2)")
# Ticks at cell centres, labelled with the operand value
plt.xticks(np.arange(N) + 0.5, np.arange(N))
plt.yticks(np.arange(N) + 0.5, np.arange(N))
# seaborn draws row 0 at the top; flip so that a = 0 sits at the bottom
plt.gca().invert_yaxis()
plt.show()

PC2 isn't just "fine-tuning". It's part of a complex encoding strategy interpreted by the Unembedding layer (W_U).

It acts as a strong identifier for "double-extreme" inputs ((0,0), (9,9)), massively boosting their correct logits.

It acts differently for "single-extreme" inputs ((0,k), (k,9)), where its net effect might even reduce the correct logit slightly, but likely plays a vital role in suppressing competitors.

It provides a moderate general boost for central inputs.

Let us now visualize by how much the ablated model is wrong, rather than just the fact that it is wrong.

In [None]:
# Signed error (ablated prediction minus true sum), stored both as a flat
# per-sample vector and on the (a, b) grid. Grid cells never written stay NaN.
num_samples = len(test_labels_np)
prediction_error_heatmap = np.full((N, N), np.nan)
prediction_errors = np.zeros(num_samples)

per_sample = zip(original_a_values, original_b_values, ablated_predicted_tokens, test_labels_np)
for i, (a_val, b_val, ablated_pred, true_label) in enumerate(per_sample):
    error = float(ablated_pred - true_label) # plain float for later statistics
    prediction_errors[i] = error
    prediction_error_heatmap[a_val, b_val] = error

# --- Visualize Prediction Error Heatmap ---
plt.figure(figsize=(7, 6))

# Largest absolute error, used as a symmetric colour limit around zero
max_abs_error = np.nanmax(np.abs(prediction_error_heatmap))
# Fall back to 1.0 when the limit is unusable: NaN (grid entirely unfilled) or
# 0 (no errors at all). A limit of 0 would make vmin == vmax and collapse the
# colour scale instead of drawing zero error in the neutral middle colour.
if not np.isfinite(max_abs_error) or max_abs_error == 0:
    max_abs_error = 1.0

# Use a diverging colormap since error can be positive or negative
cmap = "coolwarm" # Red for positive error (prediction > true), Blue for negative

sns.heatmap(prediction_error_heatmap, cmap=cmap, center=0,
            annot=True, fmt=".0f", # Show integer errors
            linewidths=.5, linecolor='lightgray',
            square=True,
            vmin=-max_abs_error, vmax=max_abs_error, # Center color scale around 0
            cbar_kws={'label': 'Ablated Prediction - True Sum'},
            mask=np.isnan(prediction_error_heatmap)) # Mask cells not in test data

plt.title("Prediction Error of Ablated Model (PC2 Removed)")
plt.xlabel("Input b = F(n-1)")
plt.ylabel("Input a = F(n-2)")
# Ticks at cell centres, labelled with the operand value
plt.xticks(np.arange(N) + 0.5, np.arange(N))
plt.yticks(np.arange(N) + 0.5, np.arange(N))
# seaborn draws row 0 at the top; flip so that a = 0 sits at the bottom
plt.gca().invert_yaxis()
plt.show()

# --- Optional: summary statistics of the signed prediction error ---
print("\nStatistics for Prediction Error (Ablated Pred - True Sum):")
# Keep only entries that are not NaN
valid_errors = prediction_errors[np.logical_not(np.isnan(prediction_errors))]
if valid_errors.size == 0:
    print("No valid error data found (check test set population).")
else:
    error_series = pd.Series(valid_errors)
    print(error_series.describe())
    n_incorrect = np.count_nonzero(valid_errors != 0)
    n_correct = np.count_nonzero(valid_errors == 0)
    print(f"\nNumber of incorrect predictions (error != 0): {n_incorrect}")
    print(f"Number of correct predictions (error == 0): {n_correct}")

Some observations: It's never wrong by more than 3, and is most often wrong by 1. It's off by 3 only at (0,0) and (9,9) -- it predicts 3 for (0,0) and 15 for (9,9). It's off by positive values when at least one of the inputs is 0, and off by negative values when at least one of the inputs is 9. The exceptions are (0,9) and (9,0), where it is correct -- suggesting that two wrongs make a right. Further, for low-ish values of a,b it underpredicts (off by negative), and for high-ish values of a,b it overpredicts (off by positive).

Can we conclude anything about whatever algorithm it's implementing based on this? Let's fit the predictions of the ablated model to a linear function of the sum.

In [None]:
print("Fitting ablated predictions to a linear function of the true sum...")

# --- Regression data ---
# scikit-learn expects a 2-D feature matrix, hence the single-column reshape
X_true_sum = test_labels_np.reshape(-1, 1)  # independent variable: true sum
y_ablated_pred = ablated_predicted_tokens   # dependent variable: ablated prediction

# --- Least-squares fit: AblatedPred = m * TrueSum + c ---
lin_reg = LinearRegression().fit(X_true_sum, y_ablated_pred)
learned_m = lin_reg.coef_[0]      # slope
learned_c = lin_reg.intercept_    # intercept
y_linear_fit_pred = lin_reg.predict(X_true_sum)

# --- Fit quality on the fitted samples ---
r2 = r2_score(y_ablated_pred, y_linear_fit_pred)
mse = mean_squared_error(y_ablated_pred, y_linear_fit_pred)

print("\n--- Linear Fit Results: AblatedPred = m * TrueSum + c ---")
print(f"Learned Slope (m): {learned_m:.4f}")
print(f"Learned Intercept (c): {learned_c:.4f}")
print(f"R-squared score: {r2:.4f}")
print(f"Mean Squared Error: {mse:.4f}")

# --- Compare against the hypothesised line m = 2/3, c = 3 ---
hypothesized_m, hypothesized_c = 2/3, 3
print("\nComparison to Hypothesis (m = 2/3, c = 3):")
print(f"Hypothesized m: {hypothesized_m:.4f}")
print(f"Hypothesized c: {hypothesized_c:.4f}")
print(f"Difference in m: {learned_m - hypothesized_m:.4f}")
print(f"Difference in c: {learned_c - hypothesized_c:.4f}")

# --- Visualize the Fit ---
plt.figure(figsize=(8, 6))
plt.scatter(test_labels_np, y_ablated_pred, alpha=0.5, label='Actual Ablated Predictions')
plt.plot(test_labels_np, y_linear_fit_pred, color='red', linewidth=2, label=f'Linear Fit (R²={r2:.3f})')
# Plot the hypothesized line
hypothesized_line = hypothesized_m * test_labels_np + hypothesized_c
plt.plot(test_labels_np, hypothesized_line, color='green', linestyle='--', linewidth=2, label='Hypothesized Line (m=2/3, c=3)')
# Plot the ideal line (y=x) for reference
plt.plot(test_labels_np, test_labels_np, color='gray', linestyle=':', linewidth=1, label='Ideal (y=x)')


plt.title('Linear Fit: Ablated Prediction vs. True Sum')
plt.xlabel('True Sum (a+b)')
plt.ylabel('Ablated Model Prediction')
plt.legend()
plt.grid(True, alpha=0.3)
# Ensure axes cover the full possible range (0 to max possible sum, e.g., 18)
max_sum = np.max(test_labels_np) if len(test_labels_np) > 0 else 18
max_pred = np.max(y_ablated_pred) if len(y_ablated_pred) > 0 else max_sum
plot_max = max(max_sum, max_pred)
plt.xlim(-0.5, plot_max + 0.5)
plt.ylim(-0.5, plot_max + 0.5)
plt.show()

Observation: the ablated model makes different predictions for some input pairs that differ but have the same sum, which is why there are multiple blue dots on the same vertical line in the plot above. This necessarily implies that we cannot perfectly fit the ablated model's predictions to a function of the sum alone; we need to consider a and b separately.

In [None]:
print("\nFitting ablated predictions to a linear function of a and b separately...")

# --- Prepare Data for Multiple Linear Regression ---
# X: one column per operand, so a and b each get their own weight.
X_inputs_ab = np.stack((original_a_values, original_b_values), axis=-1) # Shape: [num_samples, 2]

# y: Dependent variable (Ablated Prediction)
y_ablated_pred = ablated_predicted_tokens

# --- Fit Linear Regression Model ---
multi_lin_reg = LinearRegression()
multi_lin_reg.fit(X_inputs_ab, y_ablated_pred)

# --- Get Results ---
# Coefficients come back in the column order of X_inputs_ab: (w_a, w_b)
learned_wa, learned_wb = multi_lin_reg.coef_
# Intercept (c)
learned_c_multi = multi_lin_reg.intercept_
# Predictions from the fitted linear model
y_multilinear_fit_pred = multi_lin_reg.predict(X_inputs_ab)

# --- Evaluate the Fit ---
r2_multi = r2_score(y_ablated_pred, y_multilinear_fit_pred)
mse_multi = mean_squared_error(y_ablated_pred, y_multilinear_fit_pred)

# --- Print Results ---
print("\n--- Multiple Linear Fit Results: AblatedPred = wa*a + wb*b + c ---")
print(f"Learned Weight for a (wa): {learned_wa:.4f}")
print(f"Learned Weight for b (wb): {learned_wb:.4f}")
print(f"Learned Intercept (c): {learned_c_multi:.4f}")
print(f"R-squared score: {r2_multi:.4f}")
print(f"Mean Squared Error: {mse_multi:.4f}")

# --- Compare with previous single-variable fit ---
# 'r2', 'learned_m' and 'learned_c' come from the single-variable fit; the comparison
# is skipped when that fit has not been run in this kernel.
if 'r2' in locals():
    print(f"\nComparison to previous fit (AblatedPred = m*(a+b) + c):")
    print(f"  Previous R-squared: {r2:.4f}")
    print(f"  Improvement in R-squared: {r2_multi - r2:.4f}")
    print(f"  Previous m: {learned_m:.4f} (Avg of wa, wb = {(learned_wa + learned_wb)/2:.4f})")
    print(f"  Previous c: {learned_c:.4f} (Current c = {learned_c_multi:.4f})")

# --- Optional: Visualize Residuals (If R-squared isn't close to 1) ---
residuals = y_ablated_pred - y_multilinear_fit_pred
plt.figure(figsize=(8, 6))
plt.scatter(y_multilinear_fit_pred, residuals, alpha=0.5)
plt.hlines(0, xmin=min(y_multilinear_fit_pred), xmax=max(y_multilinear_fit_pred), colors='red', linestyles='--')
plt.title('Residual Plot for Multi-Linear Fit')
plt.xlabel('Fitted Values (wa*a + wb*b + c)')
plt.ylabel('Residuals (Actual - Fitted)')
plt.grid(True, alpha=0.3)
plt.show()

# --- Optional: Heatmap of Residuals ---
# Grids indexed [a, b]; cells with no sample stay NaN and are masked out of the heatmap.
residual_heatmap = np.full((N, N), np.nan)
fitted_values_heatmap = np.full((N, N), np.nan)
num_samples = len(y_ablated_pred)

for i in range(num_samples):
    a_val = original_a_values[i]
    b_val = original_b_values[i]
    residual_heatmap[a_val, b_val] = residuals[i]
    fitted_values_heatmap[a_val, b_val] = y_multilinear_fit_pred[i] # Store fitted value

plt.figure(figsize=(7, 6))
max_abs_resid = np.nanmax(np.abs(residual_heatmap))
# Fall back to a unit range when the grid is empty (NaN) or every residual is exactly 0;
# otherwise vmin == vmax and the colour scale is degenerate.
if np.isnan(max_abs_resid) or max_abs_resid == 0:
    max_abs_resid = 1.0

cmap_resid = "coolwarm"
sns.heatmap(residual_heatmap, cmap=cmap_resid, center=0,
            annot=True, fmt=".2f",
            linewidths=.5, linecolor='lightgray',
            square=True,
            vmin=-max_abs_resid, vmax=max_abs_resid,
            cbar_kws={'label': 'Residual (Actual Ablated - Fitted)'},
            mask=np.isnan(residual_heatmap))
plt.title("Residuals of Multi-Linear Fit (wa*a + wb*b + c)")
plt.xlabel("Input b")
plt.ylabel("Input a")
# Heatmap cells are centred at k + 0.5, so ticks are shifted to label cell centres.
plt.xticks(np.arange(N) + 0.5, np.arange(N))
plt.yticks(np.arange(N) + 0.5, np.arange(N))
plt.gca().invert_yaxis()
plt.show()

The fit to a and b individually is the same as the fit to a+b. Hence, the ablated model's prediction depends linearly on the inputs only via their sum. Given that the fit is not perfect (R^2 of ~0.93, not 1), there must be some non-linear component contributing to the model's output. We think this is introduced by the ReLU, as we see no other sources of non-linearity.

Hypothesis: note that the desired behavior is linear, i.e., just sum the inputs. The complete model implements this, albeit with a PC2 component which is non-linear. The ablated model that does not have the PC2 component fails to implement the linear behavior, and errs. Hence, it might be that the PC2 component's task is to cancel out the wrong non-linearity that the rest of the model appears to learn. We can test this by fitting the error in the ablated model's prediction to the value of the PC2 component for the same inputs.

In [None]:
from sklearn.preprocessing import PolynomialFeatures # For potential non-linear fit

# --- 1. Extract PC2 Features using EXISTING pca_scores ---
print("Extracting PC2 features from existing pca_scores...")
try:
    # Ensure pca_scores is available and is a tensor
    if 'pca_scores' not in locals() or not isinstance(pca_scores, torch.Tensor):
        raise NameError("Variable 'pca_scores' not found or is not a PyTorch tensor.")

    # Convert scores to numpy ONCE for use with sklearn and indexing
    pca_scores_np = pca_scores.detach().cpu().numpy()
    # PC2 values for all tokens in the vocabulary (index 1 is the second PC)
    if pca_scores_np.shape[1] < 2:
        raise ValueError(f"pca_scores_np has shape {pca_scores_np.shape}, but need at least 2 components to extract PC2.")
    pc2_values_all_tokens = pca_scores_np[:, 1]

    # Look up the PC2 score of input 'a' and input 'b' for each sample in the test set
    pc2_a_features = pc2_values_all_tokens[original_a_values]
    pc2_b_features = pc2_values_all_tokens[original_b_values]

    # Create the feature matrix for regression
    X_pc2_features = np.stack((pc2_a_features, pc2_b_features), axis=-1) # Shape: [num_samples, 2]
    pc2_extracted_successfully = True
    print("PC2 features extracted successfully.")

except NameError as e:
    print(f"Error: {e}")
    print("Please ensure 'pca_scores' (Tensor) is computed and available.")
    pc2_extracted_successfully = False
except IndexError as e:
    print(f"Error indexing PC2 values. Check token indices range (0-{N-1}) vs vocab size used in PCA: {e}")
    pc2_extracted_successfully = False
except Exception as e:
    print(f"An unexpected error occurred during PC2 feature extraction: {e}")
    pc2_extracted_successfully = False

# --- Proceed only if PC2 features were extracted successfully ---
if pc2_extracted_successfully:
    # Only print the features once extraction is known to have succeeded; printing them
    # unconditionally raised NameError whenever the extraction above had failed.
    print(pc2_b_features)
    print("\nFitting Ablated Model Error to PC2 Features...")

    # --- 2. Calculate the Error ---
    # True sum minus ablated prediction, i.e. the correction the ablated model would need.
    y_error = test_labels_np - ablated_predicted_tokens

    # --- 3. Perform Linear Regression ---
    error_reg = LinearRegression()
    error_reg.fit(X_pc2_features, y_error)

    # --- Get Results ---
    learned_w_pc2a, learned_w_pc2b = error_reg.coef_
    learned_c_error = error_reg.intercept_
    y_error_pred = error_reg.predict(X_pc2_features)

    # --- Evaluate Fit ---
    r2_error_fit = r2_score(y_error, y_error_pred)
    mse_error_fit = mean_squared_error(y_error, y_error_pred)

    print("\n--- Linear Fit Results: Error ≈ w_pc2a*PC2(a) + w_pc2b*PC2(b) + c ---")
    print(f"Learned Weight for PC2(a): {learned_w_pc2a:.4f}")
    print(f"Learned Weight for PC2(b): {learned_w_pc2b:.4f}")
    print(f"Learned Intercept: {learned_c_error:.4f}")
    print(f"R-squared score: {r2_error_fit:.4f}")
    print(f"Mean Squared Error: {mse_error_fit:.4f}")

    # --- Optional: Try Polynomial Regression (Degree 2) ---
    print("\n--- Trying Polynomial Fit (Degree 2) ---")
    poly = PolynomialFeatures(degree=2, include_bias=False)
    # Create interaction & quadratic features: [pc2a, pc2b, pc2a^2, pc2a*pc2b, pc2b^2]
    X_pc2_poly_features = poly.fit_transform(X_pc2_features)

    error_poly_reg = LinearRegression()
    error_poly_reg.fit(X_pc2_poly_features, y_error)
    y_error_poly_pred = error_poly_reg.predict(X_pc2_poly_features)
    r2_error_poly_fit = r2_score(y_error, y_error_poly_pred)
    mse_error_poly_fit = mean_squared_error(y_error, y_error_poly_pred)

    print(f"Polynomial Fit R-squared score: {r2_error_poly_fit:.4f}")
    print(f"Polynomial Fit Mean Squared Error: {mse_error_poly_fit:.4f}")
    print(f"Improvement in R-squared: {r2_error_poly_fit - r2_error_fit:.4f}")


    # --- Visualize Error vs Predicted Error ---
    plt.figure(figsize=(8, 6))
    plt.scatter(y_error, y_error_pred, alpha=0.5, label=f'Linear Fit (R²={r2_error_fit:.3f})')
    # Only plot polynomial fit if it's meaningfully better
    if r2_error_poly_fit > r2_error_fit + 0.01:
        plt.scatter(y_error, y_error_poly_pred, alpha=0.5, marker='x', label=f'Poly Fit (R²={r2_error_poly_fit:.3f})')
    # Range for an ideal y=x reference line; the line itself is currently disabled below.
    min_val = min(np.min(y_error), np.min(y_error_pred)) - 0.5 # Add buffer
    max_val = max(np.max(y_error), np.max(y_error_pred)) + 0.5 # Add buffer
    #plt.plot([min_val, max_val], [min_val, max_val], color='red', linestyle='--', label='Ideal Fit (y=x)')
    plt.title('Predicting Ablated Model Error using PC2 Features')
    plt.xlabel('Actual Error (True Sum - Ablated Pred)')
    plt.ylabel('Predicted Error (from PC2 features)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
else:
    print("Skipping error prediction analysis due to issues extracting PC2 features.")

In [None]:


# --- Configuration ---
# NOTE(review): the per-logit loop below visits only rows 0..d_vocab-1 of W_L. If the model's
# output vocabulary is larger than N, the remaining logits are not analysed -- confirm intended.
d_vocab = N # Or N+1 if your output includes a special token mapped within 0..N
d_mlp = model.cfg.d_mlp # Should be 512
num_key_neurons = 100 # Analyze the top 100 correlated neurons
key_neuron_indices = indices_sorted_by_abs_corr[:num_key_neurons]

# --- Step 1: Calculate W_L = W_U.T @ W_out.T ---

# Get the unembedding matrix; try the unembed module first, then the model-level attribute.
try:
    # Assume stored as [d_model, d_vocab]
    W_U_stored = model.unembed.W_U.detach().cpu()
except AttributeError:
    # Fallback accessor; same assumed layout [d_model, d_vocab]
    W_U_stored = model.W_U.detach().cpu()
W_U = W_U_stored.T # Transpose to get [d_vocab, d_model]

# Assume stored as [d_mlp, d_model]
W_out_stored = model.blocks[0].mlp.W_out.detach().cpu()
W_out = W_out_stored.T # Transpose to get [d_model, d_mlp]

# Calculate the effective neuron->logit map
print(f"Multiplying W_U shape {W_U.shape} with W_out shape {W_out.shape}")
# This should now be [d_vocab, d_model] @ [d_model, d_mlp]
W_L = W_U @ W_out # Shape: [d_vocab, d_mlp]
W_L_np = W_L.numpy()

print(f"Calculated W_L with shape: {W_L_np.shape}") # Should be (d_vocab, d_mlp), e.g., (10, 512)

# --- Step 2: Analyze Rows of W_L and Neuron Contributions ---

print(f"\nAnalyzing weights in W_L for top {num_key_neurons} sum-correlated neurons:")

analysis_results = []

for k in range(d_vocab): # Iterate through each output logit (0 to N-1)
    # Get the weights from W_L for this specific logit 'k'
    # applied to the key neurons
    weights_for_key_neurons = W_L_np[k, key_neuron_indices] # Shape: [num_key_neurons]

    # Hypothesis Check: fraction of key neurons that push this logit up
    positivity_ratio = np.mean(weights_for_key_neurons > 0)

    # Redundancy Check: Distribution statistics
    mean_weight = np.mean(weights_for_key_neurons)
    std_weight = np.std(weights_for_key_neurons)
    max_abs_weight = np.max(np.abs(weights_for_key_neurons))
    # Check sparsity: % of weights close to zero (e.g., < 1% of max abs weight)
    sparsity_threshold = 0.01 * max_abs_weight
    sparsity_ratio = np.mean(np.abs(weights_for_key_neurons) < sparsity_threshold)


    print(f"\n--- Logit for output token k={k} ---")
    print(f"  Positivity Ratio (Key Neurons): {positivity_ratio:.2f}")
    print(f"  Sparsity Ratio (<{sparsity_threshold:.4f}): {sparsity_ratio:.2f}")
    print(f"  Weight Stats (Key Neurons): Mean={mean_weight:.4f}, Std={std_weight:.4f}, MaxAbs={max_abs_weight:.4f}")

    analysis_results.append({
        'k': k,
        'positivity_ratio': positivity_ratio,
        'sparsity_ratio': sparsity_ratio,
        'mean_weight': mean_weight,
        'std_weight': std_weight,
        'max_abs_weight': max_abs_weight
    })

# Optional: Create a DataFrame for easier viewing
analysis_df = pd.DataFrame(analysis_results)
print("\n--- Summary Table ---")
print(analysis_df.round(3))

# Optional: Plot histogram for a specific k (e.g., k=5)
plt.figure(figsize=(8, 4))
k_to_plot = 5
plt.hist(W_L_np[k_to_plot, key_neuron_indices], bins=30, alpha=0.7)
plt.title(f"Weight Distribution in W_L[{k_to_plot}, :] for Top {num_key_neurons} Neurons")
plt.xlabel("Weight Value")
plt.ylabel("Frequency")
plt.grid(axis='y', alpha=0.5)
plt.show()

That didn't work! Let's try and fit the output of the ablated model to a quadratic function of the inputs.

In [None]:
print("\nFitting ablated predictions to quadratic functions of a and b...")

# Shared regression target: the ablated model's predicted token for every sample.
y_ablated_pred = ablated_predicted_tokens

# ---------------------------------------------------------------------------
# Fit 1: restricted quadratic  x*(a+b) + y*a^2 + z*b^2 + c  (no a*b cross term)
# ---------------------------------------------------------------------------
print("\n--- Fitting Requested Model: x*(a+b) + y*a^2 + z*b^2 + c ---")

# Feature columns, in order: a+b, a^2, b^2.
sum_ab = original_a_values + original_b_values
a_squared = original_a_values**2
b_squared = original_b_values**2
X_quadratic_requested = np.stack((sum_ab, a_squared, b_squared), axis=-1) # Shape: [num_samples, 3]

quad_reg_req = LinearRegression().fit(X_quadratic_requested, y_ablated_pred)

# Coefficients follow the feature-column order above.
learned_x, learned_y, learned_z = quad_reg_req.coef_
learned_c_req = quad_reg_req.intercept_
y_quad_req_pred = quad_reg_req.predict(X_quadratic_requested)

r2_quad_req = r2_score(y_ablated_pred, y_quad_req_pred)
mse_quad_req = mean_squared_error(y_ablated_pred, y_quad_req_pred)

print(f"Learned Coefficient for (a+b) (x): {learned_x:.4f}")
print(f"Learned Coefficient for a^2 (y):   {learned_y:.4f}")
print(f"Learned Coefficient for b^2 (z):   {learned_z:.4f}")
print(f"Learned Intercept (c):             {learned_c_req:.4f}")
print(f"R-squared score:                   {r2_quad_req:.4f}")
print(f"Mean Squared Error:                {mse_quad_req:.4f}")


# ---------------------------------------------------------------------------
# Fit 2: full degree-2 polynomial in (a, b), including the a*b cross term
# ---------------------------------------------------------------------------
print("\n--- Fitting General Quadratic Model (includes a*b term) ---")

# Expand [a, b] into all degree-1 and degree-2 terms; the constant term is left to
# LinearRegression's intercept (include_bias=False).
X_inputs_ab = np.stack((original_a_values, original_b_values), axis=-1) # Shape: [num_samples, 2]
poly = PolynomialFeatures(degree=2, include_bias=False)
X_quadratic_general = poly.fit_transform(X_inputs_ab)
# Ask the transformer for the term names rather than assuming the column order.
feature_names = poly.get_feature_names_out(['a', 'b'])

quad_reg_gen = LinearRegression().fit(X_quadratic_general, y_ablated_pred)

learned_coeffs_gen = quad_reg_gen.coef_
learned_c_gen = quad_reg_gen.intercept_
y_quad_gen_pred = quad_reg_gen.predict(X_quadratic_general)

r2_quad_gen = r2_score(y_ablated_pred, y_quad_gen_pred)
mse_quad_gen = mean_squared_error(y_ablated_pred, y_quad_gen_pred)

print("General Quadratic Model: Pred ≈ w_a*a + w_b*b + w_aa*a^2 + w_ab*a*b + w_bb*b^2 + c_gen")
print("Learned Coefficients:")
for name, coeff in zip(feature_names, learned_coeffs_gen):
    print(f"  Weight for {name}: {coeff:.4f}")
print(f"Learned Intercept (c_gen): {learned_c_gen:.4f}")
print(f"R-squared score:           {r2_quad_gen:.4f}")
print(f"Mean Squared Error:        {mse_quad_gen:.4f}")


# ---------------------------------------------------------------------------
# R-squared comparison across fits
# ---------------------------------------------------------------------------
print("\n--- R-squared Comparison ---")
# r2_multi comes from the multi-linear fit; it may be absent if that fit was not run.
if 'r2_multi' not in locals():
    print("Multi-linear fit results not available for comparison.")
else:
    print(f"Linear Fit (wa*a + wb*b + c):       R² = {r2_multi:.4f}")
print(f"Requested Quadratic Fit:          R² = {r2_quad_req:.4f}")
print(f"General Quadratic Fit (w/ a*b):   R² = {r2_quad_gen:.4f}")


# ---------------------------------------------------------------------------
# Residual diagnostics for whichever quadratic fit scored higher
# ---------------------------------------------------------------------------
best_r2_quad = max(r2_quad_req, r2_quad_gen)
print(f"\nVisualizing residuals for the best quadratic fit (R²={best_r2_quad:.4f})...")

# Ties go to the general fit.
use_general_fit = r2_quad_gen >= r2_quad_req
fitted_values = y_quad_gen_pred if use_general_fit else y_quad_req_pred
title_suffix = "General Quadratic Fit" if use_general_fit else "Requested Quadratic Fit"
residuals = y_ablated_pred - fitted_values

resid_fig, resid_ax = plt.subplots(figsize=(8, 6))
resid_ax.scatter(fitted_values, residuals, alpha=0.5)
resid_ax.hlines(0, xmin=min(fitted_values), xmax=max(fitted_values), colors='red', linestyles='--')
resid_ax.set_title(f'Residual Plot for {title_suffix}')
resid_ax.set_xlabel('Fitted Values')
resid_ax.set_ylabel('Residuals (Actual Ablated - Fitted)')
resid_ax.grid(True, alpha=0.3)
plt.show()

# Residuals laid out on the (a, b) grid; cells with no sample stay NaN and are masked.
residual_heatmap = np.full((N, N), np.nan)
num_samples = len(y_ablated_pred)
for i, (a_val, b_val) in enumerate(zip(original_a_values, original_b_values)):
    residual_heatmap[a_val, b_val] = residuals[i]

heat_fig, heat_ax = plt.subplots(figsize=(7, 6))

# Symmetric colour range around 0; fall back to 1.0 for an empty or all-zero grid.
max_abs_resid = np.nanmax(np.abs(residual_heatmap))
if np.isnan(max_abs_resid) or max_abs_resid == 0:
    max_abs_resid = 1.0

cmap_resid = "coolwarm"
sns.heatmap(
    residual_heatmap,
    cmap=cmap_resid,
    center=0,
    annot=True,
    fmt=".2f",
    linewidths=.5,
    linecolor='lightgray',
    square=True,
    vmin=-max_abs_resid,
    vmax=max_abs_resid,
    cbar_kws={'label': 'Residual (Actual Ablated - Fitted)'},
    mask=np.isnan(residual_heatmap),
    ax=heat_ax,
)
heat_ax.set_title(f"Residuals of {title_suffix}")
heat_ax.set_xlabel("Input b")
heat_ax.set_ylabel("Input a")
# Heatmap cells are centred at k + 0.5, so place each label on its cell centre.
heat_ax.set_xticks(np.arange(N) + 0.5)
heat_ax.set_xticklabels(np.arange(N))
heat_ax.set_yticks(np.arange(N) + 0.5)
heat_ax.set_yticklabels(np.arange(N))
heat_ax.invert_yaxis()
plt.show()



Unfortunately, our hypothesis that the quadratic PC2 term is simply a correction on top of the linear output was incorrect. Hence, its contribution must be more subtle. To try to understand it, we will now compare the activations of the neurons in the ablated and the full model, to see how removing PC2 affects their firing.

In [None]:
# --- Storage for Activations ---
activation_store = {}

def store_activation_hook(activation, hook, storage_key):
    """Forward hook that snapshots the activation at the final sequence position.

    Args:
        activation: Activation at the hook point; indexed as [:, -1, :].
        hook: Hook point object supplied by run_with_hooks (unused).
        storage_key: Key under which the snapshot is written to `activation_store`.

    Returns:
        The activation, unchanged, so the forward pass is not affected.
    """
    # Detached clone, so the stored snapshot is independent of the live activation tensor
    activation_store[storage_key] = activation[:, -1, :].detach().clone()
    return activation


def plot_mean_delta_by_correlation(mean_delta_sorted, bar_colors, legend_handles, title):
    """Bar-plot a per-neuron mean activation difference (full - ablated).

    Args:
        mean_delta_sorted: 1D array of per-neuron mean differences, already in plotting order.
        bar_colors: One colour per bar, in the same order as `mean_delta_sorted`.
        legend_handles: Legend handles explaining the bar colours.
        title: Title of the figure.
    """
    plt.figure(figsize=(15, 6))
    plt.bar(range(len(mean_delta_sorted)), mean_delta_sorted, color=bar_colors, width=1.0)
    plt.xlabel("Neurons (Sorted by Correlation with Sum a+b)")
    plt.ylabel("Mean Activation Diff (Full - Ablated)")
    plt.title(title)
    plt.xticks([]) # Too many neurons to label
    plt.grid(axis='y', alpha=0.5)
    plt.legend(handles=legend_handles)
    plt.show()


# --- Execute and Capture Activations ---
a_mlp_full = None
a_mlp_ablated = None
mlp_activation_hook_point = utils.get_act_name("mlp_post", 0)
print(mlp_activation_hook_point)
if mlp_activation_hook_point:
    model.eval() # Ensure model is in evaluation mode
    with torch.no_grad():
        # --- Run 1: Get Full Model Activations ---
        print("Running model normally to capture full MLP activations...")
        activation_store.clear() # Clear previous storage
        _ = model.run_with_hooks(
            test_input_tokens,
            fwd_hooks=[(mlp_activation_hook_point,
                        lambda act, hook: store_activation_hook(act, hook, 'full'))]
        )
        if 'full' in activation_store:
            a_mlp_full = activation_store['full'].cpu().numpy()
            print(f"Captured full activations, shape: {a_mlp_full.shape}")
        else:
            print("Failed to capture full activations.")

        # --- Run 2: Get Ablated Model Activations ---
        print("Running model with PC2 ablation to capture ablated MLP activations...")
        activation_store.clear() # Clear previous storage
        _ = model.run_with_hooks(
            test_input_tokens,
            fwd_hooks=[(embedding_hook_point, ablate_pc2_hook), # Ablate PC2 first
                       (mlp_activation_hook_point,
                        lambda act, hook: store_activation_hook(act, hook, 'ablated'))]
        )
        if 'ablated' in activation_store:
            a_mlp_ablated = activation_store['ablated'].cpu().numpy()
            print(f"Captured ablated activations, shape: {a_mlp_ablated.shape}")
        else:
            print("Failed to capture ablated activations.")


# --- Calculate Difference and Analyze ---
if a_mlp_full is not None and a_mlp_ablated is not None:
    if a_mlp_full.shape == a_mlp_ablated.shape:
        print("\nCalculating activation differences...")
        delta_a_mlp = a_mlp_full - a_mlp_ablated # Shape: [num_samples, d_mlp]

        # --- Analysis Focused on Failures ---
        failure_mask = results_df['Failed'].values
        num_failures = failure_mask.sum()

        if num_failures > 0:
            print(f"Analyzing {num_failures} failure cases...")
            delta_a_mlp_failures = delta_a_mlp[failure_mask] # Shape: [num_failures, d_mlp]

            # --- 1. Which neurons change the most on average during failure? ---
            mean_abs_delta_failures = np.mean(np.abs(delta_a_mlp_failures), axis=0)
            sorted_neuron_indices_by_delta = np.argsort(mean_abs_delta_failures)[::-1]

            print("\nTop 10 Neurons by Mean Absolute Activation Change During Failures:")
            for i in range(min(10, d_mlp)):
                idx = sorted_neuron_indices_by_delta[i]
                print(f"  Neuron {idx}: Mean Abs Delta = {mean_abs_delta_failures[idx]:.4f} (Corr={neuron_sum_correlations[idx]:.3f})")

            # --- 2. Visualize Average Delta for ALL Failures ---
            mean_delta_failures = np.mean(delta_a_mlp_failures, axis=0) # Shape: [d_mlp,]

            # Order neurons by their correlation with the sum, so that neurons with
            # similar correlation sit next to each other in every plot below.
            corr_sorted_indices = np.argsort(neuron_sum_correlations)
            mean_delta_failures_sorted = mean_delta_failures[corr_sorted_indices]
            correlations_sorted = neuron_sum_correlations[corr_sorted_indices]

            # Bar colour encodes the sign of each neuron's correlation; shared by all plots.
            colors = ['red' if c < 0 else 'blue' for c in correlations_sorted]
            from matplotlib.lines import Line2D
            legend_elements = [Line2D([0], [0], color='blue', lw=4, label='Positively Correlated Neurons'),
                               Line2D([0], [0], color='red', lw=4, label='Negatively Correlated Neurons')]

            plot_mean_delta_by_correlation(
                mean_delta_failures_sorted, colors, legend_elements,
                f"Mean MLP Activation Difference on Failure Cases ({num_failures} samples)",
            )

            # --- 3. Visualize Average Delta for Specific Failure Groups (Example: a=0 failures) ---
            a0_failure_mask = failure_mask & (results_df['a'] == 0)
            num_a0_failures = a0_failure_mask.sum()

            if num_a0_failures > 0:
                print(f"\nAnalyzing {num_a0_failures} failure cases where a=0...")
                delta_a_mlp_a0_failures = delta_a_mlp[a0_failure_mask]
                mean_delta_a0_failures = np.mean(delta_a_mlp_a0_failures, axis=0)
                mean_delta_a0_failures_sorted = mean_delta_a0_failures[corr_sorted_indices]

                plot_mean_delta_by_correlation(
                    mean_delta_a0_failures_sorted, colors, legend_elements,
                    f"Mean MLP Activation Difference on Failure Cases with a=0 ({num_a0_failures} samples)",
                )
            else:
                print("No failure cases found where a=0.")

            # --- 4. Visualize Average Delta for Specific Failure Groups (Example: a=9 failures) ---
            a9_failure_mask = failure_mask & (results_df['a'] == 9)
            num_a9_failures = a9_failure_mask.sum()

            if num_a9_failures > 0:
                print(f"\nAnalyzing {num_a9_failures} failure cases where a=9...")
                # Kept in a9-specific names so the a=0 results above are not overwritten.
                delta_a_mlp_a9_failures = delta_a_mlp[a9_failure_mask]
                mean_delta_a9_failures = np.mean(delta_a_mlp_a9_failures, axis=0)
                mean_delta_a9_failures_sorted = mean_delta_a9_failures[corr_sorted_indices]

                plot_mean_delta_by_correlation(
                    mean_delta_a9_failures_sorted, colors, legend_elements,
                    f"Mean MLP Activation Difference on Failure Cases with a=9 ({num_a9_failures} samples)",
                )
            else:
                print("No failure cases found where a=9.")

        else:
            print("No failure cases detected to analyze activation differences.")
    else:
        print("Error: Shape mismatch between full and ablated activations.")
        print(f"Full shape: {a_mlp_full.shape}, Ablated shape: {a_mlp_ablated.shape}")
else:
    print("Skipping activation difference analysis because activations were not captured.")


There are some patterns here: the changes to the activations seem to be grouped by correlation values, again pointing towards the 'groups of neurons' hypothesis. Let's perform a similar analysis for the pre-ReLU values.

In [None]:
# --- Configuration ---
# Hook point for the MLP input BEFORE the ReLU is applied.
mlp_pre_activation_hook_point = "blocks.0.mlp.hook_pre"
print(f"Using MLP pre-activation hook point: {mlp_pre_activation_hook_point}")

# --- Storage ---
# Maps a run label ('full' / 'ablated') to the captured pre-activations.
pre_activation_store = {}

# --- Hook Function ---
def store_pre_activation_hook(activation, hook, storage_key):
    """Forward hook that records the pre-activation at the final sequence position.

    Args:
        activation: Activation at the hook point; indexed as [:, -1, :].
        hook: Hook point object supplied by run_with_hooks (unused).
        storage_key: Key under which the snapshot is written to `pre_activation_store`.

    Returns:
        The activation, unchanged.
    """
    final_position = activation[:, -1, :]
    pre_activation_store[storage_key] = final_position.detach().clone()
    return activation

# --- Execute and Capture Activations ---
# Both stay None unless the corresponding run actually stores something.
z_mlp_full = None
z_mlp_ablated = None

model.eval()
with torch.no_grad():

    # --- Run 1: full model ---
    print(f"\nRunning model normally, hooking: {mlp_pre_activation_hook_point}")
    pre_activation_store.clear()
    # partial binds the storage key, leaving the (activation, hook) signature hooks expect.
    fwd_hooks_full = [
        (mlp_pre_activation_hook_point, partial(store_pre_activation_hook, storage_key='full')),
    ]
    _ = model.run_with_hooks(test_input_tokens, fwd_hooks=fwd_hooks_full)
    if 'full' not in pre_activation_store:
        print(f"FAILED to capture full pre-activations using hook: {mlp_pre_activation_hook_point}")
    else:
        z_mlp_full = pre_activation_store['full'].cpu().numpy()
        print(f"Captured full pre-activations, shape: {z_mlp_full.shape}")

    # --- Run 2: PC2-ablated model (only attempted when run 1 captured something) ---
    if z_mlp_full is not None:
        print(f"\nRunning model with PC2 ablation, hooking: {mlp_pre_activation_hook_point}")
        pre_activation_store.clear()
        fwd_hooks_ablated = [
            # Ablation hook is listed first, capture hook second.
            (embedding_hook_point, ablate_pc2_hook),
            (mlp_pre_activation_hook_point, partial(store_pre_activation_hook, storage_key='ablated')),
        ]
        _ = model.run_with_hooks(test_input_tokens, fwd_hooks=fwd_hooks_ablated)
        if 'ablated' not in pre_activation_store:
            print(f"FAILED to capture ablated pre-activations using hook: {mlp_pre_activation_hook_point}")
        else:
            z_mlp_ablated = pre_activation_store['ablated'].cpu().numpy()
            print(f"Captured ablated pre-activations, shape: {z_mlp_ablated.shape}")


# --- Calculate Difference and Analyze ---
if z_mlp_full is not None and z_mlp_ablated is not None and z_mlp_full.shape == z_mlp_ablated.shape:
    print("\nCalculating pre-activation differences (delta_z_mlp)...")
    delta_z_mlp = z_mlp_full - z_mlp_ablated

    failure_mask = results_df['Failed'].values
    num_failures = failure_mask.sum()

    if num_failures > 0:
        print(f"Analyzing {num_failures} failure cases...")
        delta_z_mlp_failures = delta_z_mlp[failure_mask]

        # --- 1. Top Changing Neurons ---
        mean_abs_delta_z_failures = np.mean(np.abs(delta_z_mlp_failures), axis=0)
        sorted_neuron_indices_by_delta_z = np.argsort(mean_abs_delta_z_failures)[::-1]
        print("\nTop 10 Neurons by Mean Absolute Pre-Activation Change During Failures:")
        for i in range(min(10, d_mlp)):
            idx = sorted_neuron_indices_by_delta_z[i]
            print(f"  Neuron {idx}: Mean Abs Delta_Z = {mean_abs_delta_z_failures[idx]:.4f} (Corr={neuron_sum_correlations[idx]:.3f})")

        # --- 2. Average Delta_Z for ALL Failures ---
        mean_delta_z_failures = np.mean(delta_z_mlp_failures, axis=0)
        corr_sorted_indices = np.argsort(neuron_sum_correlations)
        mean_delta_z_failures_sorted = mean_delta_z_failures[corr_sorted_indices]
        correlations_sorted = neuron_sum_correlations[corr_sorted_indices]
        plt.figure(figsize=(15, 6))
        colors = ['red' if c < 0 else 'blue' for c in correlations_sorted]
        plt.bar(range(d_mlp), mean_delta_z_failures_sorted, color=colors, width=1.0)
        plt.xlabel("Neurons (Sorted by Correlation with Sum a+b)")
        plt.ylabel("Mean Pre-Activation Diff (Full - Ablated)")
        plt.title(f"Mean MLP Pre-Activation Difference (delta_z_mlp) on Failure Cases ({num_failures} samples)")
        plt.xticks([])
        plt.grid(axis='y', alpha=0.5)
        from matplotlib.lines import Line2D
        legend_elements = [Line2D([0], [0], color='blue', lw=4, label='Positively Correlated Neurons'),
                           Line2D([0], [0], color='red', lw=4, label='Negatively Correlated Neurons')]
        plt.legend(handles=legend_elements)
        plt.show()

        # --- 3. Average Delta_Z for Specific Failure Groups (a=0) ---
        a0_failure_mask = failure_mask & (results_df['a'] == 0)
        num_a0_failures = a0_failure_mask.sum()
        if num_a0_failures > 0:
            print(f"\nAnalyzing {num_a0_failures} failure cases where a=0...")
            delta_z_mlp_a0_failures = delta_z_mlp[a0_failure_mask]
            mean_delta_z_a0_failures = np.mean(delta_z_mlp_a0_failures, axis=0)
            mean_delta_z_a0_failures_sorted = mean_delta_z_a0_failures[corr_sorted_indices]
            plt.figure(figsize=(15, 6))
            plt.bar(range(d_mlp), mean_delta_z_a0_failures_sorted, color=colors, width=1.0)
            plt.xlabel("Neurons (Sorted by Correlation with Sum a+b)")
            plt.ylabel("Mean Pre-Activation Diff (Full - Ablated)")
            plt.title(f"Mean MLP Pre-Activation Difference (delta_z_mlp) on Failure Cases with a=0 ({num_a0_failures} samples)")
            plt.xticks([])
            plt.grid(axis='y', alpha=0.5)
            plt.legend(handles=legend_elements)
            plt.show()
        else:
            print("No failure cases found where a=0.")

    else:
        print("No failure cases detected.")
else:
    print("\nAnalysis skipped: Could not capture both full and ablated pre-activations.")


### Looking at Activations

Helper variables:

In [None]:
# Attention paid from the final position to source position 0 ('a') and source
# position 1 ('b'); the two leading axes of the cached pattern are kept as-is
# (presumably batch and head — the plots below label the second axis "Head").
pattern_a = cache["pattern", 0, "attn"][:, :, -1, 0]
pattern_b = cache["pattern", 0, "attn"][:, :, -1, 1]
# MLP "post" and "pre" activations of block 0 at the final sequence position.
neuron_acts = cache["post", 0, "mlp"][:, -1, :]
neuron_pre_acts = cache["pre", 0, "mlp"][:, -1, :]

Get all shapes:

In [None]:
# List every entry held in the activation cache together with its tensor shape.
for param_name, param in cache.items():
    print(param_name, param.shape)

In [None]:
# Attention pattern averaged over dim 0 (the batch), restricted to the final query
# position, shown per head against the three source tokens.
imshow(cache["pattern", 0].mean(dim=0)[:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [None]:
# Attention pattern of a single example (batch index 5) at the final query position.
# Nothing is averaged here, so the title names the example instead of claiming an average.
imshow(cache["pattern", 0][5][:, -1, :], title="Attention Pattern per Head (Input 5)", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [None]:
# Peek at the first four entries of the dataset.
dataset[:4]

In [None]:
# Head 0's attention from the final position to source position 0 ('a'), with the flat
# batch axis reshaped to (p, p) — assumes the batch enumerates (a, b) pairs a-major; TODO confirm.
imshow(cache["pattern", 0][:, 0, -1, 0].reshape(p, p), title="Attention for Head 0 from a -> =", xaxis="b", yaxis="a")

In [None]:
# One facet per head: attention from the final position to source position 0 ('a'),
# with the flat batch axis unfolded into an (a, b) grid. The title refers to every
# head because facet_col=0 draws all of them, not just head 0.
imshow(
    einops.rearrange(cache["pattern", 0][:, :, -1, 0], "(a b) head -> head a b", a=p, b=p),
    title="Attention for each Head from a -> =", xaxis="b", yaxis="a", facet_col=0)

Plotting neuron activations

In [None]:
# Shape of the cached MLP "post" activations for block 0 (all sequence positions).
cache["post", 0, "mlp"].shape

In [None]:
# Activations of the first five neurons, with the flat batch axis unfolded into an
# (a, b) grid and one facet per neuron.
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

### Singular Value Decomposition

In [None]:
# Shape of the embedding matrix W_E (defined outside this section).
W_E.shape


In [None]:
# Reduced SVD of the embedding matrix. torch.linalg.svd replaces the deprecated
# torch.svd and returns V^H as its third output (torch.svd returned V), so the
# name Vh is accurate here. U and S are the same as before.
U, S, Vh = torch.linalg.svd(W_E, full_matrices=False)
line(S, title="Singular Values")
imshow(U, title="Principal Components on the Input")

In [None]:
# Control - random Gaussian matrix
U, S, Vh = torch.svd(torch.randn_like(W_E))
line(S, title="Singular Values Random")
imshow(U, title="Principal Components Random")

## Explaining Algorithm

### Analyse the Embedding - It's a Lookup Table!

In [None]:
# NOTE(review): torch.svd is documented as deprecated and returns V (not V^H) as its
# third output, so the name Vh is presumably a misnomer — confirm before relying on it.
U, S, Vh = torch.svd(W_E)
# Plot the first 8 columns of U, one line per component, across the input vocabulary.
line(U[:, :8].T, title="Principal Components of the embedding", xaxis="Input Vocabulary")

In [None]:
# Build a real Fourier basis over the p inputs: one constant row, then a
# (sin, cos) pair of rows for each frequency 1..p//2. Row order is therefore
# [Constant, Sin 1, Cos 1, Sin 2, Cos 2, ...], i.e. frequency f sits at rows 2f-1 and 2f.
fourier_basis = [torch.ones(p)]
fourier_basis_names = ["Constant"]
for freq in range(1, p//2+1):
    angles = torch.arange(p)*2 * torch.pi * freq / p
    fourier_basis.extend([torch.sin(angles), torch.cos(angles)])
    fourier_basis_names.extend([f"Sin {freq}", f"Cos {freq}"])
fourier_basis = torch.stack(fourier_basis, dim=0).to(device)
# Scale every row to unit norm.
row_norms = fourier_basis.norm(dim=-1, keepdim=True)
fourier_basis = fourier_basis/row_norms
imshow(fourier_basis, xaxis="Input", yaxis="Component", y=fourier_basis_names)

In [None]:
# Rows 0..7: Constant, then Sin/Cos of frequencies 1-3 and Sin 4.
line(fourier_basis[:8], xaxis="Input", line_labels=fourier_basis_names[:8], title="First 8 Fourier Components")
# Rows 25..28: Sin 13, Cos 13, Sin 14, Cos 14.
line(fourier_basis[25:29], xaxis="Input", line_labels=fourier_basis_names[25:29], title="Middle Fourier Components")

In [None]:
# Gram matrix of the basis rows; an identity matrix here means the rows are orthonormal.
imshow(fourier_basis @ fourier_basis.T, title="All Fourier Vectors are Orthogonal")

### Analyse the Embedding

In [None]:
# Project the embedding onto the Fourier basis: row k is W_E's component along basis vector k.
imshow(fourier_basis @ W_E, yaxis="Fourier Component", xaxis="Residual Stream", y=fourier_basis_names, title="Embedding in Fourier Basis")

In [None]:
# Norm (over the last axis) of each Fourier component of the embedding.
line((fourier_basis @ W_E).norm(dim=-1), xaxis="Fourier Component", x=fourier_basis_names, title="Norms of Embedding in Fourier Basis")

In [None]:
key_freqs = [17, 25, 32, 47]
# Rows of fourier_basis are ordered [Constant, Sin 1, Cos 1, Sin 2, Cos 2, ...], so
# frequency f lives at rows 2f-1 (sin) and 2f (cos). Deriving the indices keeps them in
# sync with key_freqs instead of maintaining a second hand-written list
# (for the frequencies above this yields [33, 34, 49, 50, 63, 64, 93, 94]).
key_freq_indices = [idx for key_freq in key_freqs for idx in (2 * key_freq - 1, 2 * key_freq)]
# Embedding expressed in the Fourier basis, restricted to the key sin/cos rows.
fourier_embed = fourier_basis @ W_E
key_fourier_embed = fourier_embed[key_freq_indices]
print("key_fourier_embed", key_fourier_embed.shape)
# Pairwise dot products between the embeddings of the key Fourier terms.
imshow(key_fourier_embed @ key_fourier_embed.T, title="Dot Product of embedding of key Fourier Terms")

### Key Frequencies

In [None]:
# Rows 34, 50, 64, 94 are 2*f for f in 17, 25, 32, 47, i.e. the cosine rows of the key
# frequencies. The line labels are the row indices, not the frequencies.
line(fourier_basis[[34, 50, 64, 94]], title="Cos of key freqs", line_labels=[34, 50, 64, 94])

In [None]:
# Average of the four key-frequency cosine rows (rows 2*f for f in 17, 25, 32, 47).
line(fourier_basis[[34, 50, 64, 94]].mean(0), title="Constructive Interference")

## Analyse Neurons

In [None]:
# Activations of the first five neurons as (a, b) grids, one facet per neuron.
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

In [None]:
# Activation of neuron 0 alone, unfolded from the flat batch axis into an (a, b) grid.
imshow(
    einops.rearrange(neuron_acts[:, 0], "(a b) -> a b", a=p, b=p),
    title="First neuron act", xaxis="b", yaxis="a",)

In [None]:
# Outer product of basis row 94 (= 2*47, the "Cos 47" row) with itself.
imshow(fourier_basis[94][None, :] * fourier_basis[94][:, None], title="Cos 47a * cos 47b")

In [None]:
# Outer product of the constant row (row 0, varying along the first axis) with the
# "Cos 47" row (row 94, varying along the second axis).
imshow(fourier_basis[94][None, :] * fourier_basis[0][:, None], title="Cos 47a * const")

In [None]:
# 2D Fourier transform of neuron 0's (a, b) activation grid: the basis is applied
# along the a axis on the left and along the b axis on the right.
# Title fixed: this is a Fourier *transform*, not a "Transformer".
imshow(fourier_basis @ neuron_acts[:, 0].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 0", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
# 2D Fourier transform of neuron 5's (a, b) activation grid.
# Title fixed: this is a Fourier *transform*, not a "Transformer".
imshow(fourier_basis @ neuron_acts[:, 5].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 5", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
# Baseline: the same 2D Fourier transform applied to Gaussian noise of the same shape.
# Title fixed: this is a Fourier *transform*, not a "Transformer".
imshow(fourier_basis @ torch.randn_like(neuron_acts[:, 0]).reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of RANDOM", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

### Neuron Clusters

In [None]:
# 2D Fourier transform of every neuron's (a, b) activation grid; the result is indexed
# (neuron, Fourier component along a, Fourier component along b).
fourier_neuron_acts = fourier_basis @ einops.rearrange(neuron_acts, "(a b) neuron -> neuron a b", a=p, b=p) @ fourier_basis.T
# Center these by removing the mean - doesn't matter!
# (entry [0, 0] is the Constant x Constant component; it is zeroed in place)
fourier_neuron_acts[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts.shape)

In [None]:
# For each frequency, accumulate the squared 2D-Fourier coefficients of every neuron
# over the 3x3 block of components {Constant, Sin f, Cos f} x {Constant, Sin f, Cos f},
# then divide by the neuron's total squared coefficient mass.
neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).to(device)
for freq in range(0, p//2):
    # Row `freq` corresponds to frequency freq+1, whose sin/cos rows are 2(freq+1)-1 and 2(freq+1).
    component_ids = [0, 2*(freq+1) - 1, 2*(freq+1)]
    for x in component_ids:
        for y in component_ids:
            neuron_freq_norm[freq] += fourier_neuron_acts[:, x, y].pow(2)
total_power = fourier_neuron_acts.pow(2).sum(dim=[-1, -2])
neuron_freq_norm = neuron_freq_norm / total_power[None, :]
imshow(neuron_freq_norm, xaxis="Neuron", yaxis="Freq", y=torch.arange(1, p//2+1), title="Neuron Frac Explained by Freq")

In [None]:
# For each neuron take the largest fraction explained by any single frequency, then
# sort the neurons by that value (ascending).
line(neuron_freq_norm.max(dim=0).values.sort().values, xaxis="Neuron", title="Max Neuron Frac Explained over Freqs")

## Read Off the Neuron-Logit Weights to Interpret

In [None]:
# Direct linear map from MLP neurons to output logits: the MLP output weights
# composed with the unembedding.
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

In [None]:
# Project the logit axis of W_logit onto each Fourier basis vector, then take the
# norm over neurons (dim 0) to get one value per Fourier component.
line((W_logit @ fourier_basis.T).norm(dim=0), x=fourier_basis_names, title="W_logit in the Fourier Basis")

In [None]:
# Boolean mask over neurons: True where frequency 17 (row 17-1 of neuron_freq_norm)
# accounts for more than 85% of the neuron's Fourier mass.
neurons_17 = neuron_freq_norm[17-1]>0.85
neurons_17.shape

In [None]:
# Number of neurons selected by the frequency-17 mask.
neurons_17.sum()

In [None]:
# Same Fourier-basis norm as for the full W_logit, restricted to the frequency-17 neurons.
line((W_logit[neurons_17] @ fourier_basis.T).norm(dim=0), x=fourier_basis_names, title="W_logit for freq 17 neurons in the Fourier Basis")

Study sin 17

In [None]:
freq = 17
# Rows of fourier_basis are the basis vectors, so projecting W_logit's output (logit)
# axis onto the basis needs fourier_basis.T — the same convention used by the
# "W_logit in the Fourier Basis" plots. Without the transpose, column k of the product
# is not the coefficient on basis vector k, so the "sin 17" column would be mislabelled.
W_logit_fourier = W_logit @ fourier_basis.T
# Column 2*freq-1 is the "Sin {freq}" component: one weight per neuron.
neurons_sin_17 = W_logit_fourier[:, 2*freq-1]
line(neurons_sin_17)

In [None]:
# Shape of the final-position MLP activations.
neuron_acts.shape

In [None]:
# Weight every neuron's activation by its sin-17 logit weight and sum over neurons,
# giving one value per input.
inputs_sin_17c = neuron_acts @ neurons_sin_17
# 2D Fourier transform of that per-input value laid out on the (a, b) grid.
imshow(fourier_basis @ inputs_sin_17c.reshape(p, p) @ fourier_basis.T, title="Fourier Heatmap over inputs for sin17c", x=fourier_basis_names, y=fourier_basis_names)

# Black Box Methods + Progress Measures

## Setup Code

Code to plot embedding freqs

In [None]:
def embed_to_cos_sin(fourier_embed):
    """Split Fourier coefficients into their odd-indexed and even-indexed halves.

    Index 0 (the constant term) is dropped. Indices 1, 3, 5, ... go to slot 0 and
    indices 2, 4, 6, ... go to slot 1 of a newly stacked axis.

    Args:
        fourier_embed: 1-D tensor of coefficients, or a tensor whose dim 1 holds
            the coefficients.

    Returns:
        For 1-D input, a tensor of shape (2, k). Otherwise the two halves are
        stacked along a new dim 1.
    """
    if fourier_embed.dim() != 1:
        odd_terms = fourier_embed[:, 1::2]
        even_terms = fourier_embed[:, 2::2]
        return torch.stack((odd_terms, even_terms), dim=1)
    odd_terms, even_terms = fourier_embed[1::2], fourier_embed[2::2]
    return torch.stack((odd_terms, even_terms))

from neel_plotly.plot import melt

def plot_embed_bars(
    fourier_embed,
    title="Norm of embedding of each Fourier Component",
    return_fig=False,
    **kwargs
):
    """Grouped bar chart of Fourier coefficients, split into sin and cos bars.

    Args:
        fourier_embed: coefficients in Fourier order, passed through embed_to_cos_sin.
        title: figure title.
        return_fig: if True return the figure, otherwise show it and return None.
        **kwargs: forwarded to px.bar.
    """
    trig_names = {0: "sin", 1: "cos"}
    # presumably melt yields a long-format frame with columns "0", "1" and "value" — verify
    long_df = melt(embed_to_cos_sin(fourier_embed))
    long_df["Trig"] = long_df["0"].map(lambda group: trig_names[group])
    bar_fig = px.bar(
        long_df,
        barmode="group",
        color="Trig",
        x="1",
        y="value",
        labels={"1": "$w_k$", "value": "Norm"},
        title=title,
        **kwargs
    )
    bar_fig.update_layout(dict(legend_title=""))

    if not return_fig:
        bar_fig.show()
        return None
    return bar_fig

Code to test a tensor of edited logits

In [None]:
def test_logits(logits, bias_correction=False, original_logits=None, mode="all"):
    # Calculates cross entropy loss of logits representing a batch of all p^2
    # possible inputs
    # Batch dimension is assumed to be first
    if logits.shape[1] == p * p:
        logits = logits.T
    if logits.shape == torch.Size([p * p, p + 1]):
        logits = logits[:, :-1]
    logits = logits.reshape(p * p, p)
    if bias_correction:
        # Applies bias correction - we correct for any missing bias terms,
        # independent of the input, by centering the new logits along the batch
        # dimension, and then adding the average original logits across all inputs
        logits = (
            einops.reduce(original_logits - logits, "batch ... -> ...", "mean") + logits
        )
    if mode == "train":
        return loss_fn(logits[train_indices], labels[train_indices])
    elif mode == "test":
        return loss_fn(logits[test_indices], labels[test_indices])
    elif mode == "all":
        return loss_fn(logits, labels)

Code to run a metric over every checkpoint

In [None]:
# Maps a metric name to its per-checkpoint values; filled in by get_metrics.
metric_cache = {}

In [None]:
def get_metrics(model, metric_cache, metric_fn, name, reset=False):
    """Evaluate ``metric_fn`` on every checkpoint and store the results under ``name``.

    Each state dict in the global ``model_checkpoints`` is loaded into ``model`` in
    turn and ``metric_fn(model)`` is called. The results are stacked into a tensor
    and written to ``metric_cache[name]``. Nothing is recomputed when a non-empty
    entry already exists, unless ``reset`` is True.

    Args:
        model: model whose weights are overwritten by each checkpoint in turn.
        metric_cache: dict mapping metric name to a tensor of per-checkpoint values.
        metric_fn: callable taking the model and returning a tensor or number.
        name: key under which the results are stored.
        reset: force recomputation even if a cached entry exists.

    Returns:
        None; the result is written into ``metric_cache``.
    """
    if not reset and len(metric_cache.get(name, ())) > 0:
        return

    # Collect into a local list so that a failure part-way through does not leave a
    # truncated entry in the cache (which a later call would mistake for a result).
    values = []
    try:
        for sd in tqdm.tqdm(model_checkpoints):
            model.reset_hooks()
            model.load_state_dict(sd)
            out = metric_fn(model)
            if isinstance(out, torch.Tensor):
                out = utils.to_numpy(out)
            values.append(out)
    finally:
        # Always leave the model on the final checkpoint, even if metric_fn raised.
        model.load_state_dict(model_checkpoints[-1])

    try:
        metric_cache[name] = torch.tensor(values)
    except Exception:
        # Fall back to stacking through numpy first when direct conversion fails.
        metric_cache[name] = torch.tensor(np.array(values))



## Defining Progress Measures

### Loss Curves

In [None]:
# Epochs at which add_lines draws dashed vertical markers on the training plots.
# NOTE(review): the names suggest these bound the memorization / circuit-formation /
# cleanup phases, and the values look hand-picked for one training run — confirm
# they still apply if the model is retrained.
memorization_end_epoch = 1500
circuit_formation_end_epoch = 13300
cleanup_end_epoch = 16600

In [None]:
def add_lines(figure):
    """Overlay the three phase-boundary epochs as dashed vertical lines.

    Args:
        figure: object exposing ``add_vline``; it is modified in place.

    Returns:
        The same ``figure``, to allow chaining.
    """
    boundary_epochs = (
        memorization_end_epoch,
        circuit_formation_end_epoch,
        cleanup_end_epoch,
    )
    for epoch in boundary_epochs:
        figure.add_vline(epoch, line_dash="dash", opacity=0.7)
    return figure

In [None]:
# Plot every 100th train/test loss on a log-scale y axis, then mark the phase boundaries.
fig = line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Training Curve for Modular Addition", line_labels=['train', 'test'], toggle_x=True, toggle_y=True, return_fig=True)
add_lines(fig)

### Logit Periodicity

In [None]:
# Keep only the final-position logits, then unfold the flat batch axis into an (a, b)
# grid so that all_logits is indexed (a, b, c).
all_logits = original_logits[:, -1, :]
print(all_logits.shape)
all_logits = einops.rearrange(all_logits, "(a b) c -> a b c", a=p, b=p)
print(all_logits.shape)

In [None]:
# Index grids for the three axes of the (a, b, c) logit cube. They do not depend on
# the frequency, so they are built once instead of on every loop iteration.
a = torch.arange(p)[:, None, None]
b = torch.arange(p)[None, :, None]
c = torch.arange(p)[None, None, :]
# coses[freq] holds cos(2*pi*freq/p * (a + b - c)) over the whole cube, scaled to unit
# norm so that a dot product with it acts as a projection coefficient.
coses = {}
for freq in key_freqs:
    print("Freq:", freq)
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)
    cube_predicted_logits /= cube_predicted_logits.norm()
    coses[freq] = cube_predicted_logits

In [None]:
# Approximate the logit cube by its projection onto the key-frequency cosine cubes.
approximated_logits = torch.zeros_like(all_logits)
for freq in key_freqs:
    print("Freq:", freq)
    # Projection coefficient onto the unit-norm cosine cube for this frequency.
    coeff = (all_logits * coses[freq]).sum()
    print("Coeff:", coeff)
    # coses[freq] has unit norm, so dividing by the logit norm alone gives the cosine similarity.
    cosine_sim = coeff / all_logits.norm()
    print("Cosine Sim:", cosine_sim)
    approximated_logits += coeff * coses[freq]
# What the key-frequency approximation fails to capture.
residual = all_logits - approximated_logits
print("Residual size:", residual.norm())
print("Residual fraction of norm:", residual.norm()/all_logits.norm())

In [None]:
# Baseline: cosine similarity between the logits and a random Gaussian cube of the same shape.
random_logit_cube = torch.randn_like(all_logits)
print((all_logits * random_logit_cube).sum()/random_logit_cube.norm()/all_logits.norm())

In [None]:
# Loss of the model's own logits over all inputs (mode defaults to "all").
test_logits(all_logits)

In [None]:
# Loss of the key-frequency approximation of the logits over all inputs.
test_logits(approximated_logits)

#### Look During Training

In [None]:
# Index grids for the (a, b, c) logit cube; frequency-independent, so built once
# instead of on every loop iteration.
a = torch.arange(p)[:, None, None]
b = torch.arange(p)[None, :, None]
c = torch.arange(p)[None, None, :]
# One unit-norm cube cos(2*pi*freq/p * (a + b - c)) per frequency 1..p//2,
# stacked along a leading frequency axis.
cos_cube = []
for freq in range(1, p//2 + 1):
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)
    cube_predicted_logits /= cube_predicted_logits.norm()
    cos_cube.append(cube_predicted_logits)
cos_cube = torch.stack(cos_cube, dim=0)
print(cos_cube.shape)

In [None]:
def get_cos_coeffs(model):
    """Projection coefficient of the model's logit cube onto each cosine cube.

    Args:
        model: callable returning logits for ``dataset``.

    Returns:
        1-D tensor with one coefficient per frequency in ``cos_cube``.
    """
    # Metrics only need values, never gradients. Disabling autograd stops PyTorch from
    # retaining the large (frequency, a, b, c) product for a backward pass that never runs.
    with torch.no_grad():
        logits = model(dataset)[:, -1]
        logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
        vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
    return vals


get_metrics(model, metric_cache, get_cos_coeffs, "cos_coeffs")
print(metric_cache["cos_coeffs"].shape)

In [None]:
# One line per frequency: its projection coefficient at every checkpoint epoch.
fig = line(metric_cache["cos_coeffs"].T, line_labels=[f"Freq {i}" for i in range(1, p//2+1)], title="Coefficients with Predicted Logits", xaxis="Epoch", x=checkpoint_epochs, yaxis="Coefficient", return_fig=True)
add_lines(fig)

In [None]:
def get_cos_sim(model):
    """Cosine similarity between the model's logit cube and each cosine cube.

    The cosine cubes have unit norm, so the projection coefficient only needs to be
    divided by the norm of the logits.

    Args:
        model: callable returning logits for ``dataset``.

    Returns:
        1-D tensor with one cosine similarity per frequency in ``cos_cube``.
    """
    # No gradients are needed for a metric. Disabling autograd avoids retaining the
    # large (frequency, a, b, c) intermediate and so lowers peak GPU memory.
    with torch.no_grad():
        logits = model(dataset)[:, -1]
        logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
        vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
        return vals / logits.norm()

get_metrics(model, metric_cache, get_cos_sim, "cos_sim") # You may need a big GPU. If you don't have one and can't work around this, raise an issue for help!
print(metric_cache["cos_sim"].shape)

fig = line(metric_cache["cos_sim"].T, line_labels=[f"Freq {i}" for i in range(1, p//2+1)], title="Cosine Sim with Predicted Logits", xaxis="Epoch", x=checkpoint_epochs, yaxis="Cosine Sim", return_fig=True)
add_lines(fig)

In [None]:
def get_residual_cos_sim(model):
    """Fraction of the logit norm left after removing every cosine-cube component.

    Despite the name, the returned value is ``residual.norm() / logits.norm()``,
    not a cosine similarity.

    Args:
        model: callable returning logits for ``dataset``.

    Returns:
        Scalar tensor: norm of the unexplained residual divided by the logit norm.
    """
    # No gradients are needed for a metric. Disabling autograd avoids retaining the
    # two large (frequency, a, b, c) intermediates built below.
    with torch.no_grad():
        logits = model(dataset)[:, -1]
        logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
        vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
        # Subtract the reconstruction from all frequencies at once.
        residual = logits - (vals[:, None, None, None] * cos_cube).sum(dim=0)
        return residual.norm() / logits.norm()

get_metrics(model, metric_cache, get_residual_cos_sim, "residual_cos_sim")
print(metric_cache["residual_cos_sim"].shape)

fig = line([metric_cache["cos_sim"][:, i] for i in range(p//2)]+[metric_cache["residual_cos_sim"]], line_labels=[f"Freq {i}" for i in range(1, p//2+1)]+["residual"], title="Cosine Sim with Predicted Logits + Residual", xaxis="Epoch", x=checkpoint_epochs, yaxis="Cosine Sim", return_fig=True)
add_lines(fig)

## Restricted Loss

In [None]:
# Shape of the final-position MLP activations used in this section.
neuron_acts.shape

In [None]:
# Unfold the flat batch axis into (a, b); clone so the in-place centering below does
# not modify neuron_acts itself.
neuron_acts_square = einops.rearrange(neuron_acts, "(a b) neur -> a b neur", a=p, b=p).clone()
# Center it
neuron_acts_square -= einops.reduce(neuron_acts_square, "a b neur -> 1 1 neur", "mean")
# Apply the Fourier basis along both the a and b axes.
neuron_acts_square_fourier = einsum("a b neur, fa a, fb b -> fa fb neur", neuron_acts_square, fourier_basis, fourier_basis)
# Norm over neurons for every (Fourier component of a, Fourier component of b) pair.
imshow(neuron_acts_square_fourier.norm(dim=-1), xaxis="Fourier Component b", yaxis="Fourier Component a", title="Norms of neuron activations by Fourier Component", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
# Re-run the model on the full dataset to refresh the logits, the activation cache and
# the final-position MLP activations (get_metrics reloads checkpoints into the model).
original_logits, cache = model.run_with_cache(dataset)
print(original_logits.numel())
neuron_acts = cache["post", 0, "mlp"][:, -1, :]

In [None]:
# Restricted activations: each neuron's mean plus its projections onto the unit-norm
# cos(w(a+b)) and sin(w(a+b)) directions of the key frequencies.
approx_neuron_acts = torch.zeros_like(neuron_acts) + neuron_acts.mean(dim=0)
a = torch.arange(p)[:, None]
b = torch.arange(p)[None, :]
for freq in key_freqs:
    phase = freq * 2 * torch.pi / p * (a + b)

    # Cosine direction, flattened to a column over the (a b) batch axis.
    cos_apb_vec = torch.cos(phase).to(device)
    cos_apb_vec = einops.rearrange(cos_apb_vec / cos_apb_vec.norm(), "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec

    # Sine direction, handled the same way.
    sin_apb_vec = torch.sin(phase).to(device)
    sin_apb_vec = einops.rearrange(sin_apb_vec / sin_apb_vec.norm(), "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
# Map the restricted activations straight to logits and score them on the test split.
restricted_logits = approx_neuron_acts @ W_logit
print(loss_fn(restricted_logits[test_indices], test_labels))

In [None]:
# NOTE(review): all_logits was rearranged to (a, b, c) earlier, whereas test_logits
# flattens to (p*p, p) before calling loss_fn — confirm loss_fn handles a 3-D input
# as intended here; this may be related to the bug mentioned below.
print(loss_fn(all_logits, labels)) # This bugged on models not fully trained

### Look During Training

In [None]:
def get_restricted_loss(model):
    """Test loss when MLP activations are restricted to the key-frequency directions.

    The final-position neuron activations are replaced by their mean plus their
    projections onto the unit-norm cos(w(a+b)) and sin(w(a+b)) directions for every
    frequency in ``key_freqs``. These are mapped to logits through W_out and W_U,
    shifted so that their batch mean matches the model's real logits, and scored on
    the test split.

    Args:
        model: model exposing ``run_with_cache``, ``blocks[0].mlp.W_out`` and ``unembed.W_U``.

    Returns:
        ``loss_fn`` evaluated on the restricted logits of the test inputs.
    """
    # Evaluation only: disabling autograd avoids building a graph over the whole dataset.
    with torch.no_grad():
        logits, cache = model.run_with_cache(dataset)
        logits = logits[:, -1, :]
        neuron_acts = cache["post", 0, "mlp"][:, -1, :]
        approx_neuron_acts = torch.zeros_like(neuron_acts)
        approx_neuron_acts += neuron_acts.mean(dim=0)
        a = torch.arange(p)[:, None]
        b = torch.arange(p)[None, :]
        for freq in key_freqs:
            # Same projection for the cosine and then the sine direction of this frequency.
            for trig_fn in (torch.cos, torch.sin):
                direction = trig_fn(freq * 2 * torch.pi / p * (a + b)).to(device)
                direction /= direction.norm()
                direction = einops.rearrange(direction, "a b -> (a b) 1")
                approx_neuron_acts += (neuron_acts * direction).sum(dim=0) * direction
        restricted_logits = approx_neuron_acts @ model.blocks[0].mlp.W_out @ model.unembed.W_U
        # Add bias term
        restricted_logits += logits.mean(dim=0, keepdim=True) - restricted_logits.mean(dim=0, keepdim=True)
        return loss_fn(restricted_logits[test_indices], test_labels)
get_restricted_loss(model)

In [None]:
# Recompute the restricted loss at every checkpoint (reset=True ignores any cached values).
get_metrics(model, metric_cache, get_restricted_loss, "restricted_loss", reset=True)
print(metric_cache["restricted_loss"].shape)

In [None]:
# Overlay the restricted loss on the train/test curves.
# NOTE(review): the shared x axis assumes one checkpoint per 100 epochs — confirm against checkpoint_epochs.
fig = line([train_losses[::100], test_losses[::100], metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Restricted Loss Curve", line_labels=['train', 'test', "restricted_loss"], toggle_x=True, toggle_y=True, return_fig=True)
add_lines(fig)

In [None]:
# Ratio of test loss to restricted loss, sampled every 100 epochs.
fig = line([torch.tensor(test_losses[::100])/metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Restricted Loss to Test Loss Ratio", toggle_x=True, toggle_y=True, return_fig=True)
# WARNING: bugged when cancelling training halfway through
add_lines(fig)

## Excluded Loss

In [None]:
# Excluded activations: remove from each neuron its projections onto the unit-norm
# cos(w(a+b)) and sin(w(a+b)) directions of the key frequencies. The mean is
# deliberately not part of what gets removed:
# approx_neuron_acts += neuron_acts.mean(dim=0)
approx_neuron_acts = torch.zeros_like(neuron_acts)
a = torch.arange(p)[:, None]
b = torch.arange(p)[None, :]
for freq in key_freqs:
    phase = freq * 2 * torch.pi / p * (a + b)

    # Cosine direction, flattened to a column over the (a b) batch axis.
    cos_apb_vec = torch.cos(phase).to(device)
    cos_apb_vec = einops.rearrange(cos_apb_vec / cos_apb_vec.norm(), "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec

    # Sine direction, handled the same way.
    sin_apb_vec = torch.sin(phase).to(device)
    sin_apb_vec = einops.rearrange(sin_apb_vec / sin_apb_vec.norm(), "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
# What is left once the key-frequency components are taken out, scored on the train split.
excluded_neuron_acts = neuron_acts - approx_neuron_acts
excluded_logits = excluded_neuron_acts @ W_logit
print(loss_fn(excluded_logits[train_indices], train_labels))

In [None]:
def get_excluded_loss(model):
    """Train loss when the key-frequency directions are removed from the MLP activations.

    The projections of the final-position neuron activations onto the unit-norm
    cos(w(a+b)) and sin(w(a+b)) directions (for every frequency in ``key_freqs``)
    are subtracted. The remainder is passed through W_out, added to the cached
    ``resid_mid`` stream, unembedded, and scored on the training split.

    Args:
        model: model exposing ``run_with_cache``, ``blocks[0].mlp.W_out`` and ``unembed.W_U``.

    Returns:
        ``loss_fn`` evaluated on the excluded logits of the training inputs.
    """
    # Evaluation only: disabling autograd avoids building a graph over the whole dataset.
    with torch.no_grad():
        # Only the cache is needed; the model's own logits are not used here.
        _, cache = model.run_with_cache(dataset)
        neuron_acts = cache["post", 0, "mlp"][:, -1, :]
        # The neuron mean is deliberately not part of what gets removed.
        approx_neuron_acts = torch.zeros_like(neuron_acts)
        a = torch.arange(p)[:, None]
        b = torch.arange(p)[None, :]
        for freq in key_freqs:
            # Same projection for the cosine and then the sine direction of this frequency.
            for trig_fn in (torch.cos, torch.sin):
                direction = trig_fn(freq * 2 * torch.pi / p * (a + b)).to(device)
                direction /= direction.norm()
                direction = einops.rearrange(direction, "a b -> (a b) 1")
                approx_neuron_acts += (neuron_acts * direction).sum(dim=0) * direction
        excluded_neuron_acts = neuron_acts - approx_neuron_acts
        residual_stream_final = excluded_neuron_acts @ model.blocks[0].mlp.W_out + cache["resid_mid", 0][:, -1, :]
        excluded_logits = residual_stream_final @ model.unembed.W_U
        return loss_fn(excluded_logits[train_indices], train_labels)
get_excluded_loss(model)

In [None]:
# Recompute the excluded loss at every checkpoint (reset=True ignores any cached values).
get_metrics(model, metric_cache, get_excluded_loss, "excluded_loss", reset=True)
print(metric_cache["excluded_loss"].shape)

In [None]:
# Train, test, excluded and restricted losses on one log-scale plot, with phase markers.
# NOTE(review): the shared x axis assumes one checkpoint per 100 epochs — confirm against checkpoint_epochs.
fig = line([train_losses[::100], test_losses[::100], metric_cache["excluded_loss"], metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Excluded and Restricted Loss Curve", line_labels=['train', 'test', "excluded_loss", "restricted_loss"], toggle_x=True, toggle_y=True, return_fig=True)

add_lines(fig)