<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Grokking_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Grokking Demo Notebook

<b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

# Setup
(No need to read)

In [None]:
# Set to False to skip training and reload the cached model from PTH_LOCATION instead.
TRAIN_MODEL: bool = True

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEVELOPMENT_MODE = True
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

if IN_COLAB or IN_GITHUB:
    %pip install transformer_lens
    %pip install circuitsvis

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Colab (or any non-development run) uses the "colab" renderer; local
# notebooks / VSCode use "notebook_connected".
pio.renderers.default = "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

In [None]:
# Enlarge axis titles (20pt) and figure titles (30pt) in the default plotly template.
for axis_name in ("xaxis", "yaxis"):
    getattr(pio.templates['plotly'].layout, axis_name).title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


# Use the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a plotly heatmap on a red-blue scale centred at 0.

    `xaxis` / `yaxis` become the axis titles; extra kwargs go to `px.imshow`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Plot `tensor` as a plotly line chart; extra kwargs go to `px.line`."""
    data = utils.to_numpy(tensor)
    figure = px.line(data, labels=dict(x=xaxis, y=yaxis), **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Plotly scatter of `y` against `x`.

    `xaxis`, `yaxis` and `caxis` are used as the labels of the x, y and colour
    dimensions; extra kwargs go to `px.scatter`.
    """
    axis_labels = dict(x=xaxis, y=yaxis, color=caxis)
    px.scatter(
        x=utils.to_numpy(x),
        y=utils.to_numpy(y),
        labels=axis_labels,
        **kwargs,
    ).show(renderer)

In [None]:
# Where the trained model is cached, relative to the current working directory.
PTH_LOCATION = "workspace/_scratch/grokking_demo.pth"

# Make sure the parent directory exists (no-op if it already does).
Path(PTH_LOCATION).parent.mkdir(parents=True, exist_ok=True)

# Model Training

## Config

In [None]:
# Task size: operands are the integers 0..p-1.
p = 10
# Fraction of all p*p (a, b) pairs used for training.
frac_train = 0.72

# Optimizer config (AdamW)
lr = 1e-3
wd = 1.0
betas = (0.9, 0.98)

# Training length and how often to snapshot the weights.
num_epochs = 10_000
checkpoint_every = 100

# Seed for the train/test shuffle.
DATA_SEED = 598

## Define Task
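* Define the task: one step of a generalized Fibonacci sequence. Given two consecutive terms a and b (each in 0..p-1), predict the next term a + b (plain addition, no modulus).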
* Define the dataset & labels

Input format:
|a|b|=|

In [None]:
# Enumerate every (a, b) pair in row-major order: a varies slowest, b fastest.
a_vector = torch.arange(p).repeat_interleave(p)
b_vector = torch.arange(p).repeat(p)
# The '=' separator is token id p, one past the largest operand.
equals_vector = torch.full((p * p,), p)

In [None]:
# One row per prompt |a|b|=|, shape (p*p, 3), placed on the model's device.
dataset = torch.stack((a_vector, b_vector, equals_vector), dim=-1).to(device)

In [None]:
# Target for each prompt is a + b (sum of the two operand columns).
labels = dataset[:, :2].sum(dim=1)

Convert this to a train + test set, with a fraction `frac_train` (72%) of all p² pairs in the training set and the rest in the test set.

In [None]:
# Shuffle all p*p pair indices once (seeded), then cut at frac_train.
torch.manual_seed(DATA_SEED)
indices = torch.randperm(p * p)
cutoff = int(p * p * frac_train)
train_indices, test_indices = indices[:cutoff], indices[cutoff:]

train_data, train_labels = dataset[train_indices], labels[train_indices]
test_data, test_labels = dataset[test_indices], labels[test_indices]

# Peek at the first few examples of each split and their sizes.
for preview in (train_data[:5], train_labels[:5], train_data.shape,
                test_data[:5], test_labels[:5], test_data.shape):
    print(preview)

## Define Model

In [None]:

cfg = HookedTransformerConfig(
    # Architecture: one attention + ReLU-MLP layer, no LayerNorm.
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,
    # Vocabulary: tokens 0..p-1 are operands, token p is '='.
    # Outputs cover 2p classes, enough for every sum 0..2p-2.
    d_vocab=p + 1,
    d_vocab_out=2 * p,
    # Prompts are always |a|b|=| (3 positions).
    n_ctx=3,
    # Initialisation and placement.
    init_weights=True,
    device=device,
    seed=999,
)

In [None]:
# Fresh, randomly initialised transformer built from `cfg`.
model = HookedTransformer(cfg)

Disable the biases, as we don't need them for this task and it makes things easier to interpret.

In [None]:
# Freeze every parameter whose name contains "b_" (the bias vectors) so the
# optimizer never updates them.
for param_name, param_tensor in model.named_parameters():
    if "b_" not in param_name:
        continue
    param_tensor.requires_grad = False


## Define Optimizer + Loss

In [None]:
# AdamW with strong weight decay (wd from the config cell).
optimizer = optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=wd,
)

In [None]:
def loss_fn(logits, labels):
    """Mean cross-entropy of `labels` under `logits`, computed in float64.

    If `logits` is 3-D, only the final position along dim 1 is scored.
    `labels` holds one integer class index per row.
    Returns a 0-dim float64 tensor.
    """
    final_logits = logits[:, -1] if logits.dim() == 3 else logits
    log_probs = final_logits.double().log_softmax(dim=-1)
    picked = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return -picked.mean()
# Sanity check: losses of the model before training (when cells run in order).
train_logits = model(train_data)
test_logits = model(test_data)
train_loss = loss_fn(train_logits, train_labels)
test_loss = loss_fn(test_logits, test_labels)
print(train_loss)
print(test_loss)

In [None]:
# Reference point: predicting uniformly over all 2p output classes gives loss log(2p).
print("Uniform loss:", np.log(2 * p), sep="\n")

## Actually Train

**Weird Decision:** We train with full-batch gradient descent (the entire training set in every step) rather than with stochastic mini-batches. We do this to make training smoother and reduce the number of slingshots.

In [None]:
# Training history: one loss per epoch for each split, plus periodic weight snapshots.
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
if TRAIN_MODEL:
    for epoch in tqdm.tqdm(range(num_epochs)):
        # Full-batch step: the whole training set goes through the model every epoch.
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        train_losses.append(train_loss.item())

        optimizer.step()
        # Clear gradients after stepping so the next epoch's backward starts from zero.
        optimizer.zero_grad()

        # Evaluate on the held-out pairs without recording gradients.
        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

        # `epoch` is 0-based, so snapshots land on epochs checkpoint_every-1,
        # 2*checkpoint_every-1, ...
        if ((epoch+1)%checkpoint_every)==0:
            checkpoint_epochs.append(epoch)
            # deepcopy: state_dict() returns the live parameter tensors, which
            # later optimizer steps would otherwise keep updating.
            model_checkpoints.append(copy.deepcopy(model.state_dict()))
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")

In [None]:
# Cache the trained model and its training history so a later session can set
# TRAIN_MODEL = False and reload them instead of retraining.
# Only save when training actually ran. Otherwise this would overwrite a good
# cache with the freshly initialised (untrained) model and empty history lists
# just before the next cell tries to load it.
if TRAIN_MODEL:
    torch.save(
        {
            "model": model.state_dict(),
            "config": model.cfg,
            "checkpoints": model_checkpoints,
            "checkpoint_epochs": checkpoint_epochs,
            "test_losses": test_losses,
            "train_losses": train_losses,
            "train_indices": train_indices,
            "test_indices": test_indices,
        },
        PTH_LOCATION,
    )

In [None]:
# Reload the cached model and training history instead of retraining.
if not TRAIN_MODEL:
    # The cache is written by this notebook, so it is trusted. It contains a
    # HookedTransformerConfig object, which the weights-only unpickler
    # (torch.load's default since PyTorch 2.6) refuses to load. Full unpickling
    # is therefore requested explicitly; never do this for untrusted files.
    # map_location lets a cache saved on GPU load on a CPU-only machine.
    cached_data = torch.load(PTH_LOCATION, map_location=device, weights_only=False)
    model.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']
    train_indices = cached_data["train_indices"]
    test_indices = cached_data["test_indices"]

## Show Model Training Statistics, Check that it groks!

In [None]:
# The training-curve cell below calls `line` with `line_labels`, `toggle_x` and
# `toggle_y`. The local `line` helper would just forward these to `px.line`;
# presumably they are options of neel_plotly's `line`.
# NOTE: this import rebinds `line`, shadowing the `line` helper defined in the
# plotting-helpers cell for the rest of the notebook.
# NOTE(review): the install is unpinned (tracks the repository's default branch);
# consider pinning a commit for reproducibility.
%pip install git+https://github.com/neelnanda-io/neel-plotly.git
from neel_plotly.plot import line

In [None]:
# Plot every 100th epoch of both loss curves on a log scale.
line(
    [train_losses[::100], test_losses[::100]],
    x=np.arange(0, len(train_losses), 100),
    xaxis="Epoch",
    yaxis="Loss",
    log_y=True,
    title="Training Curve for Fibonacci",
    line_labels=['train', 'test'],
    toggle_x=True,
    toggle_y=True,
)

# Analysing the Model

## Standard Things to Try

In [None]:
# Forward pass over every prompt, keeping all hooked activations in `cache`.
original_logits, cache = model.run_with_cache(dataset)
# Total number of logit entries produced.
print(original_logits.numel())

Get key weight matrices:

In [None]:
# Embeddings of the operand tokens only: the last row (the '=' token) is dropped.
W_E = model.embed.W_E[:-1]
print("W_E", W_E.shape)

# Embedding -> attention value/output (OV) -> MLP input, broadcast over heads.
W_neur = (
    W_E
    @ model.blocks[0].attn.W_V
    @ model.blocks[0].attn.W_O
    @ model.blocks[0].mlp.W_in
)
print("W_neur", W_neur.shape)

# MLP output -> unembedding: each MLP neuron's direct effect on the logits.
W_logit = (
    model.blocks[0].mlp.W_out
    @ model.unembed.W_U
)
print("W_logit", W_logit.shape)

In [None]:
# Mean loss over the whole dataset (train and test pairs together).
original_loss = float(loss_fn(original_logits, labels))
print("Original Loss:", original_loss)

In [None]:
input_number = 7
# W_E holds only the operand rows (0..p-1), so input_number must index one of them.
assert 0 <= input_number < W_E.shape[0], f"input_number must be in [0, {W_E.shape[0]})"
# The embedding vector for input_number, as a numpy array.
embedding_vector_np = W_E[input_number].detach().cpu().numpy()

# First, plot the raw value of each component of that embedding vector.
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 6)) # Adjust figure size as needed
# Size the x-axis from the vector itself rather than hard-coding d_model=128,
# so the plot keeps working if the config's d_model changes.
plt.bar(np.arange(embedding_vector_np.shape[0]), embedding_vector_np)
plt.xlabel("Component Index (within Embedding Vector)")
plt.ylabel("Component Value")
plt.title(f"Embedding Vector Components for Input Token {input_number}")
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

In [None]:
# Heatmap of every operand token's embedding: one column per input token and
# one row per embedding dimension, hence the transpose to shape (d, N).
W_E_transposed = W_E.detach().cpu().numpy().T

# N = number of operand tokens (rows of W_E), d = embedding width (columns).
N_vocab, d_embed = W_E.shape[0], W_E.shape[1]

plt.figure(figsize=(8, 10))
# aspect='auto' lets cells stretch to fill the axes; interpolation='nearest'
# keeps each cell a crisp block instead of blurring neighbouring values.
im = plt.imshow(W_E_transposed, aspect='auto', interpolation='nearest', cmap='viridis')

plt.xlabel("Input Token Index (k)")
plt.ylabel("Embedding Dimension Index")
plt.title("Heatmap of Embedding Vectors (W_E)")

# Many tokens: thin the ticks to at most 10. Few tokens: label every one.
if N_vocab > 20:
    plt.xticks(ticks=np.linspace(0, N_vocab-1, num=min(N_vocab, 10), dtype=int))
else:
    plt.xticks(ticks=np.arange(N_vocab), labels=np.arange(N_vocab))

# Colour scale legend, tight layout, then render.
plt.colorbar(im, label='Embedding Component Value')
plt.tight_layout()
plt.show()

It looks... kind of periodic? Let's now try to do PCA (or SVD) on this and see if we can learn anything.

In [None]:
# PCA of the operand embeddings, done as an SVD of the mean-centred matrix.
W_E_mean = W_E.mean(dim=0, keepdim=True)  # per-dimension mean over the N tokens, shape (1, d)
W_E_centered = W_E - W_E_mean  # broadcast: subtract the mean from every row

# Thin SVD: U is (N, K), S is (K,), Vt is (K, d), with K = min(N, d).
U, S, Vt = torch.linalg.svd(W_E_centered, full_matrices=False)

# Each token's coordinates in the principal-component basis:
# the left singular vectors scaled column-wise by the singular values.
pca_scores = U * S  # (N, K)

# Principal directions: one d-dimensional row per component.
principal_components = Vt  # (K, d)

# Sample variance along each component (N-1 denominator) and its share of the total.
explained_variance = S.pow(2) / (W_E.shape[0] - 1)
total_variance = explained_variance.sum()
explained_variance_ratio = explained_variance / total_variance

# Score on PC1 against the token index k.
plt.figure(figsize=(10, 5))
plt.scatter(
    np.arange(W_E.shape[0]),
    pca_scores[:, 0].detach().cpu().numpy(),
    s=10,
    alpha=0.7,
)
plt.title("PyTorch PCA: Score on First Principal Component vs. Input Token Index")
plt.xlabel("Input Token Index (k)")
plt.ylabel("Score on PC1")
plt.grid(True, linestyle='--', alpha=0.6)
plt.show()



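We observe that the score on the first principal component is linear in the input token index $k$: $\text{PC1}(k) \approx w\,k + b$. Here $w$ comes out negative and $b$ is ~2. Two caveats apply. First, the sign of a principal component is arbitrary: flipping a left/right singular-vector pair leaves the SVD unchanged, so the negative slope carries no meaning by itself. Second, the scores come from mean-centred embeddings, so they average to zero over the tokens. The intercept $b$ is therefore fixed by the slope ($b \approx -4.5\,w$ for $k = 0..9$) rather than being a separately learned quantity.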

We now wish to investigate how important this first component, which is a linear function of the inputs, is in comparison with the other components. We do this by looking at how much variance is explained by each of the components.

In [None]:
# --- Explained variance from the SVD singular values S ---
# Each component's variance is proportional to its squared singular value.
# Dividing by N here (the previous cell used N-1) cancels out in the ratio.
explained_variance = S.square() / W_E.shape[0]
total_variance = explained_variance.sum()
explained_variance_ratio = explained_variance / total_variance
explained_variance_ratio_np = explained_variance_ratio.detach().cpu().numpy()

# --- Per-component share and running total for (at most) the first 10 components ---
print("Explained Variance Ratio by Principal Component:")
num_components_to_print = min(len(explained_variance_ratio_np), 10)
cumulative_variance = 0.0
for i, ratio in enumerate(explained_variance_ratio_np[:num_components_to_print]):
    cumulative_variance += ratio
    print(f"  PC{i+1}: {ratio:.4f} ({ratio*100:.2f}%) \t| Cumulative: {cumulative_variance:.4f} ({cumulative_variance*100:.2f}%)")

# Summarise the remainder when not every component was listed.
if num_components_to_print < len(explained_variance_ratio_np):
    print("  ...")
    print(f"Total Cumulative Variance (all {len(explained_variance_ratio_np)} components): {explained_variance_ratio_np.sum():.4f} ({explained_variance_ratio_np.sum()*100:.2f}%)")


# --- Scree plot: individual shares as bars, running total as a dashed line ---
component_indices = np.arange(1, len(explained_variance_ratio_np) + 1)
cumulative_variance_ratio_np = np.cumsum(explained_variance_ratio_np)

plt.figure(figsize=(10, 6))
plt.bar(component_indices, explained_variance_ratio_np, alpha=0.7, align='center',
        label='Individual Explained Variance Ratio')
plt.plot(component_indices, cumulative_variance_ratio_np, marker='o', linestyle='--',
         label='Cumulative Explained Variance Ratio')

# Reference lines at 90% and 95% of the variance.
plt.axhline(y=0.9, color='r', linestyle=':', linewidth=1, label='90% Threshold')
plt.axhline(y=0.95, color='g', linestyle=':', linewidth=1, label='95% Threshold')

plt.xlabel('Principal Component Index')
plt.ylabel('Explained Variance Ratio')
plt.title('Scree Plot - Explained Variance by Principal Component')
# Label each component individually only when there are few of them.
if len(component_indices) <= 15:
    plt.xticks(ticks=component_indices)
plt.legend(loc='best')
plt.grid(True, linestyle='--', alpha=0.6)
plt.ylim(0, 1.1)  # a little headroom above 100%
plt.show()

We observe that the first component explains ~0.75 of the variance, and the second component ~0.25. The others are basically irrelevant. This does however mean that we also need to look at the second component to understand what is happening with the embedding. Let us do so now.

In [None]:
# Score on the second principal component (column 1 of pca_scores) against
# the token index k; same layout as the PC1 plot.
plt.figure(figsize=(10, 5))
plt.scatter(
    np.arange(W_E.shape[0]),
    pca_scores[:, 1].detach().cpu().numpy(),
    s=10,
    alpha=0.7,
)
plt.title("PyTorch PCA: Score on Second Principal Component vs. Input Token Index")
plt.xlabel("Input Token Index (k)")
plt.ylabel("Score on PC2")
plt.grid(True, linestyle='--', alpha=0.6)
plt.show()

In [None]:
# Plot score on PC3 vs token index k
# Same layout as the PC1/PC2 plots, now for the third principal component
# (column 2 of pca_scores). The y-label previously said "PC2" by mistake.
plt.figure(figsize=(10, 5))
plt.scatter(np.arange(W_E.shape[0]), pca_scores[:, 2].detach().cpu().numpy(), alpha=0.7, s=10)
plt.xlabel("Input Token Index (k)")
plt.ylabel("Score on PC3")
plt.title("PyTorch PCA: Score on Third Principal Component vs. Input Token Index")
plt.grid(True, linestyle='--', alpha=0.6)
plt.show()

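The score on PC2 (plotted two cells above) looks like a parabola in $k$. One plausible explanation: when points lie along a smoothly curved one-dimensional path, PCA typically puts the overall linear trend on PC1 and a quadratic "arch" on PC2 (the so-called horseshoe or arch effect). If so, PC2 reflects the curvature of the embedding path rather than a separate feature. Below, we plot the scores on the first component against the scores on the second component.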



In [None]:
# Each operand token's embedding projected onto the first two principal
# components, coloured by its token index.
plt.figure(figsize=(8, 8))
token_indices = np.arange(p)
# Named `pc_scatter` (not `scatter`) so the `scatter` plotting helper defined
# earlier in the notebook is not shadowed by this PathCollection.
pc_scatter = plt.scatter(
    pca_scores[:, 0].detach().cpu().numpy(),  # X-coordinates are scores on PC1
    pca_scores[:, 1].detach().cpu().numpy(),  # Y-coordinates are scores on PC2
    c=token_indices,                          # Colour points by input token index (0 to p-1)
    cmap='viridis',
    alpha=0.8,
    s=50                                      # Larger points for visibility
)
plt.xlabel("Score on PC1 (Linear Component)")
plt.ylabel("Score on PC2 (Parabolic Component)")
plt.title("Embeddings Projected onto First Two Principal Components")
plt.colorbar(pc_scatter, label='Input Token Index (k)')
plt.grid(True, linestyle='--', alpha=0.6)
plt.gca().set_aspect('equal', adjustable='box') # Keep both axes on the same scale
plt.show()

An attempt at a conclusion regarding the model's learned embedding:

The model projects inputs (numbers 0-9) onto a low-dimensional (2D) manifold within the embedding space. This representation encodes the number's value primarily along one linear axis (PC1) and secondarily along an orthogonal parabolic axis (PC2). This specific geometric arrangement is likely learned to facilitate the downstream integer addition computation. (Leveraging linear order + perhaps magnitude/centrality).

Let's consider some potential next steps to further understand the algorithm the model is implementing:

Step 2: Attention Pattern Check

Goal: Check how the model combines a and b info at the = position.

Select Sample Inputs: Get various pairs (a, b).

Hook Attention Weights: For each head, get the attention weights from position 2 (query: =) to positions 0 (a), 1 (b), and 2 (=). Store these weights [w_to_a, w_to_b, w_to_=] for each head and input pair.

Visualize Average Patterns: For each head, average the weights w_to_a, w_to_b, w_to_= across all sample inputs. Plot these three average weights per head (e.g., using bar charts).

Analyze: Does position 2 attend significantly to both position 0 (a) and position 1 (b)? Are there clear differences between heads? Is attention to position 2 (=) itself low?


Expected Outcome: Confirm attention gathers info from a and b. Note basic patterns.

Step 3: MLP Analysis

Goal: Understand how the MLP computes a+b from the combined a and b representations.

Hook MLP Activations: Get the n-dimensional activation vector after the ReLU inside the MLP, specifically for the state calculated at position 2 (=). Do this for various input pairs (a, b).

Correlate Neurons to Task Variables: For each MLP neuron, calculate its activation across the sample pairs. Find the correlation between each neuron's activation and key variables: a, b, the target sum a+b, PC1 score of a, PC2 score of a, PC1 score of b, PC2 score of b.

Identify Key Neurons: Note which neurons correlate strongly with the target sum a+b. Note any correlating strongly with input features (like PC1/PC2 scores).

(Optional) Visualize Key Neuron Activations: Make 2D heatmaps (x-axis a, y-axis b) for a few key neurons (e.g., sum-correlated neurons). Does the activation map look like a+b?

Probe MLP State: Train a linear regression model to predict the scalar a+b using the n-dimensional MLP activation vector as input. Check probe accuracy. High accuracy implies MLP state linearly encodes the sum.

Formulate Hypothesis: Based on correlations and probes, hypothesize how the MLP uses the input features (represented by PC1/PC2) via its neurons to arrive at an internal state that encodes the sum a+b.

In [None]:
# Post-ReLU MLP activations captured by the hook below (one entry per forward pass).
captured_mlp_activations = []

# Define the hook function
def store_mlp_post_activation_hook(
    activation: torch.Tensor,
    hook: 'HookPoint'
):
    """
    Hook function to capture the MLP activation output (after ReLU).

    Keeps only position 2 (the '=' token in every |a|b|=| prompt), stores it
    on the CPU as a [batch_size, d_mlp] tensor in `captured_mlp_activations`,
    and returns `activation` unchanged so the forward pass continues.
    """
    # activation shape is likely [batch_size, sequence_length, d_mlp]
    # We want the activation at the '=' token, which is at index 2
    target_activation = activation[:, 2, :].detach().cpu() # Shape: [batch_size, d_mlp]
    captured_mlp_activations.append(target_activation)

    # MUST return the activation to allow the forward pass to continue
    return activation

# Identify the exact hook point name
mlp_hook_point_name = "blocks.0.mlp.hook_post" # Output AFTER ReLU

# Clear previous captures before running
captured_mlp_activations = []

print(f"Processing {len(dataset)} input pairs...")
# One batched forward pass over all prompts replaces p*p single-prompt runs.
# Rows stay in `dataset` order, so they line up with a+b below.
with torch.no_grad():
    model.run_with_hooks(
        dataset,
        fwd_hooks=[(mlp_hook_point_name, store_mlp_post_activation_hook)]
    )

if not captured_mlp_activations:
    raise RuntimeError(f"Hook '{mlp_hook_point_name}' did not capture any activations")

# Probe inputs and targets.
X_probe = captured_mlp_activations[0].numpy()            # Shape: [num_samples, d_mlp]
y_probe = (dataset[:, 0] + dataset[:, 1]).cpu().numpy()  # Shape: [num_samples], target a+b

# Per-sample views kept for compatibility with code that used the old loop's lists.
activation_list = list(X_probe)
target_sums = y_probe.tolist()

print(f"Collected {len(activation_list)} activation vectors.")

print("Probe input data shapes:")
print("X_probe:", X_probe.shape)
print("y_probe:", y_probe.shape)

We now have the MLP's post-ReLU activations at the '=' position for each input pair. Next, we fit a linear regression (a linear probe) from these activations to a+b, to test whether the sum can be read out linearly from them.

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

# 1. Hold out 30% of the (activation, a+b) pairs to evaluate the probe.
X_train, X_test, y_train, y_test = train_test_split(
    X_probe, y_probe, test_size=0.3, random_state=42 # Adjust test_size if needed
)

# 2. Fit an ordinary-least-squares linear probe: activations -> a+b.
probe_model = LinearRegression()
probe_model.fit(X_train, y_train)

# 3. Predict on the held-out pairs.
y_pred = probe_model.predict(X_test)

# 4. Evaluate the probe.
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print("Linear Probe Performance:")
print(f"  Mean Squared Error (MSE): {mse:.4f}")
print(f"  R-squared (R2): {r2:.4f}")

# Baseline: always predict the training-set mean.
# np.full_like(y_test, ...) would inherit y_test's integer dtype and silently
# truncate the mean, so build a float array explicitly.
baseline_pred = np.full(y_test.shape, y_train.mean(), dtype=np.float64)
baseline_mse = mean_squared_error(y_test, baseline_pred)
baseline_r2 = r2_score(y_test, baseline_pred) # At most 0 on held-out data, typically close to it

print("\nBaseline Performance (Predicting Mean):")
print(f"  Baseline MSE: {baseline_mse:.4f}")
print(f"  Baseline R2: {baseline_r2:.4f}")


# Predicted vs. true sums; points on the dashed diagonal are perfect predictions.
plt.figure(figsize=(6, 6))
plt.scatter(y_test, y_pred, alpha=0.7, label='Predictions')
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], '--', color='red', label='Perfect Prediction')
plt.xlabel("True Values (a+b)")
plt.ylabel("Predicted Values (a+b)")
plt.title("Linear Probe: True vs. Predicted Sums")
plt.legend()
plt.grid(True)
plt.show()

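The linear probe is almost perfectly accurate, so the post-ReLU MLP activations contain a linear representation of the target sum a+b. On its own, this does not prove that the MLP does the bulk of the computation. The embedding's main component is linear in the token value (PC1 above), so the sum may already be linearly decodable from the residual stream that feeds into the MLP. Running the same probe on the MLP's input (e.g. `blocks.0.hook_resid_mid` at the '=' position) would distinguish the two cases. Either way, it is natural to look for 'sum neurons', i.e., neurons whose activations track a+b. We do this with a correlation analysis: which neurons' activations correlate most strongly with a+b?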

In [None]:
# pearsonr is no longer called in this cell (see the vectorized version below);
# the import is kept so any other use of the name keeps working.
from scipy.stats import pearsonr

# Ensure inputs are numpy arrays
X_probe = np.asarray(X_probe)
y_probe = np.asarray(y_probe)

num_samples, d_mlp = X_probe.shape

# Pearson correlation of every neuron's activation with a+b, in one vectorized pass:
#   r_i = sum(dx_i * dy) / sqrt(sum(dx_i^2) * sum(dy^2)), with dx, dy mean-centred.
# A neuron that is constant over the dataset (e.g. a ReLU neuron that never fires)
# has an undefined correlation. scipy's pearsonr warns and returns NaN for each
# such neuron, so these are set to 0 explicitly instead.
X_centered = X_probe.astype(np.float64) - X_probe.mean(axis=0, dtype=np.float64)
y_centered = y_probe.astype(np.float64) - y_probe.mean(dtype=np.float64)
is_constant_neuron = np.ptp(X_probe, axis=0) == 0
denominators = np.sqrt((X_centered ** 2).sum(axis=0) * (y_centered ** 2).sum())
with np.errstate(divide="ignore", invalid="ignore"):
    neuron_sum_correlations = (X_centered.T @ y_centered) / denominators
neuron_sum_correlations[is_constant_neuron] = 0.0
# Also covers the degenerate case of a constant target.
neuron_sum_correlations = np.nan_to_num(neuron_sum_correlations)

print(f"Calculated correlations for {d_mlp} neurons.")
# The 'neuron_sum_correlations' array now holds the correlation of each neuron with a+b.

# Neuron indices ordered by absolute correlation, largest first.
indices_sorted_by_abs_corr = np.argsort(np.abs(neuron_sum_correlations))[::-1]

# Never list more neurons than the MLP has.
top_n = min(100, d_mlp)
print(f"\nTop {top_n} neurons by absolute correlation with sum (a+b):")
for idx in indices_sorted_by_abs_corr[:top_n]:
    print(f"  Neuron {idx}: Correlation = {neuron_sum_correlations[idx]:.4f}")

In [None]:
import seaborn as sns

# --- Visualization 1: Histogram of Correlations ---
# Distribution of per-neuron correlations with a+b, with a KDE overlay.
print("\nGenerating Histogram of Correlations...")
plt.figure(figsize=(10, 6))
sns.histplot(neuron_sum_correlations, kde=True, bins=50)
plt.title("Histogram of Neuron Activations' Correlation with Sum (a+b)")
plt.xlabel('Pearson Correlation Coefficient')
plt.ylabel('Number of Neurons')
plt.grid(axis='y', alpha=0.5)
plt.show()

# --- Visualization 2: Sorted Correlation Plot (Stem Plot) ---
# Every neuron's correlation with a+b, sorted ascending; positive in blue, negative in red.
print("\nGenerating Sorted Correlation Plot...")
sorted_indices = np.argsort(neuron_sum_correlations)
sorted_correlations = neuron_sum_correlations[sorted_indices]
ranks = np.arange(d_mlp)

pos_mask = sorted_correlations >= 0
neg_mask = sorted_correlations < 0

plt.figure(figsize=(12, 7))
# plt.stem computes its baseline from the min/max x position, which raises on an
# empty array. Draw each sign only if at least one neuron has it.
if pos_mask.any():
    markerline_pos, stemlines_pos, baseline_pos = plt.stem(
        ranks[pos_mask], sorted_correlations[pos_mask],
        linefmt='b-', markerfmt='bo', basefmt=' ', label='Positive Corr'
    )
    plt.setp(stemlines_pos, linewidth=1, alpha=0.7)
    plt.setp(markerline_pos, markersize=3, alpha=0.7)

if neg_mask.any():
    markerline_neg, stemlines_neg, baseline_neg = plt.stem(
        ranks[neg_mask], sorted_correlations[neg_mask],
        linefmt='r-', markerfmt='ro', basefmt=' ', label='Negative Corr'
    )
    plt.setp(stemlines_neg, linewidth=1, alpha=0.7)
    plt.setp(markerline_neg, markersize=3, alpha=0.7)

plt.title('Neuron Correlations with Sum (a+b), Sorted by Value')
plt.xlabel('Neuron Rank (Sorted by Correlation)')
plt.ylabel('Pearson Correlation Coefficient')
plt.ylim(-1.05, 1.05)
plt.grid(axis='y', alpha=0.5)
plt.legend()
plt.show()

# Number of output logits, taken from the config (2 * p here) instead of hard-coding 20.
num_outputs = model.cfg.d_vocab_out
# --- Visualization 3: Heatmap of Sorted W_L ---
# W_L is the neuron -> logit map with shape (num_outputs, d_mlp). If W_L has not
# been defined, derive it by transposing W_logit = W_out @ W_U (shape
# (d_mlp, d_vocab_out)) from the key-weight-matrices cell. Without this
# fallback, the heatmap below is always skipped.
if 'W_L' not in locals() and 'W_logit' in locals():
    W_L = W_logit.T
# Seaborn needs a CPU numpy array, not a (possibly CUDA, grad-tracking) tensor.
if 'W_L' in locals() and torch.is_tensor(W_L):
    W_L = W_L.detach().cpu().numpy()
if 'W_L' in locals() and W_L.shape == (num_outputs, d_mlp):
    print("\nGenerating Heatmap of Sorted W_L...")
    # Sort columns (neurons) of W_L based on correlation
    neuron_corr_sorted_indices = np.argsort(neuron_sum_correlations)
    W_L_sorted_by_neuron_corr = W_L[:, neuron_corr_sorted_indices]

    plt.figure(figsize=(15, 8))
    sns.heatmap(W_L_sorted_by_neuron_corr, cmap='coolwarm', center=0) # 'coolwarm' is good for weights
    plt.title('Neuron-Logit Map (W_L) Columns Sorted by Neuron Correlation')
    plt.xlabel('Neurons (Sorted by Correlation with Sum a+b)')
    plt.ylabel(f'Output Logit (0 to {num_outputs-1})')
    plt.yticks(np.arange(num_outputs) + 0.5, labels=np.arange(num_outputs), rotation=0)
    plt.xticks([]) # Hide neuron indices as they are too dense
    plt.show()
else:
    print("\nSkipping Heatmap of W_L: W_L matrix not found or incorrect shape.")
    print(f"Expected shape: ({num_outputs}, {d_mlp})")
    if 'W_L' in locals():
        print(f"Actual shape: {W_L.shape}")


# --- Visualization 4: Heatmap of Sorted Neuron Activations ---
# Needs X_probe ([num_samples, d_mlp]) and y_probe ([num_samples]) with matching row counts.
if not ('X_probe' in locals() and 'y_probe' in locals() and X_probe.shape[0] == y_probe.shape[0]):
    print("\nSkipping Heatmap of Activations: X_probe or y_probe not found or shapes mismatch.")
    if 'X_probe' in locals():
        print(f"X_probe shape: {X_probe.shape}")
    if 'y_probe' in locals():
        print(f"y_probe shape: {y_probe.shape}")
else:
    print("\nGenerating Heatmap of Sorted Neuron Activations...")
    # Columns: neurons from most negative to most positive correlation with a+b.
    neuron_corr_sorted_indices = np.argsort(neuron_sum_correlations)
    X_probe_neurons_sorted = X_probe[:, neuron_corr_sorted_indices]
    # Rows: input pairs in increasing order of their target sum a+b.
    input_sum_sorted_indices = np.argsort(y_probe)
    X_probe_fully_sorted = X_probe_neurons_sorted[input_sum_sorted_indices, :]

    plt.figure(figsize=(15, 8))
    # robust=True sets the colour range from percentiles so outliers don't wash it out;
    # cbar=False because absolute activation scales vary between neurons.
    sns.heatmap(X_probe_fully_sorted, cmap='viridis', robust=True, cbar=False)
    plt.title('MLP Activations (Sorted)')
    plt.xlabel('Neurons (Sorted by Correlation with Sum a+b)')
    plt.ylabel('Input Samples (Sorted by Sum a+b)')
    plt.xticks([])  # too many neurons to label
    plt.yticks([])  # too many samples to label
    plt.show()

We observe widespread strong positive correlation of neuron activations with the target sum. In fact, more than 130 neurons have a correlation exceeding 0.9 with the target sum. This shows that the MLP represents the sum in a highly distributed manner. There aren't localized, specialized 'sum' neurons, per se.

We will now visualize the activation patterns of a few top-correlated neurons as heatmaps across the grid of all possible (a, b) input pairs. This reveals the specific function each of these top neurons computes according to input pairs. Hopefully from this we can hypothesize how they combine the principal components we identified earlier to achieve their strong correlation.

In [None]:
# Tokens 0..p-1 are operands and token p is '=' (d_vocab = p + 1). Derive the
# '=' id and the operand range from `p` instead of hard-coding 10.
equals_token_id = p
N = p
top_n_to_visualize = 3
top_neuron_indices = indices_sorted_by_abs_corr[:top_n_to_visualize]
mlp_hook_point_name = "blocks.0.mlp.hook_post" # Output AFTER ReLU

# --- Data Storage ---
# One N x N heatmap per selected neuron (rows = a, columns = b), NaN until filled.
neuron_activation_heatmaps = {
    idx: np.full((N, N), np.nan) for idx in top_neuron_indices
}
captured_mlp_activation_storage = {} # Temporary storage during hook

# --- Hook Function ---
def capture_mlp_post_hook(activation: torch.Tensor, hook: 'HookPoint'):
    """Capture the post-ReLU MLP activations at position 2 (the '=' token).

    Stores a CPU tensor of shape [batch, d_mlp] under the 'activation' key of
    `captured_mlp_activation_storage` and returns `activation` unchanged.
    """
    captured_mlp_activation_storage['activation'] = activation[:, 2, :].detach().cpu()
    return activation

# --- Generate Activations ---
print(f"Generating activations for {top_n_to_visualize} neurons across {N}x{N} grid...")
# Build all N*N prompts |a|b|=| in row-major order (a varies slowest), so the
# captured rows reshape directly into an (a, b) grid.
grid_a = torch.arange(N).repeat_interleave(N)
grid_b = torch.arange(N).repeat(N)
grid_tokens = torch.stack(
    [grid_a, grid_b, torch.full_like(grid_a, equals_token_id)], dim=1
).to(device)

model.eval() # Ensure model is in evaluation mode
captured_mlp_activation_storage.clear()
with torch.no_grad(): # No need to track gradients
    # A single batched forward pass replaces N*N single-prompt runs.
    _ = model.run_with_hooks(
        grid_tokens,
        fwd_hooks=[(mlp_hook_point_name, capture_mlp_post_hook)]
    )

if 'activation' in captured_mlp_activation_storage:
    # [N*N, d_mlp] -> [N, N, d_mlp], indexed as [a, b, neuron].
    grid_activations = captured_mlp_activation_storage['activation'].reshape(N, N, -1)
    for neuron_idx in top_neuron_indices:
        neuron_activation_heatmaps[neuron_idx] = grid_activations[:, :, neuron_idx].numpy().astype(np.float64)
print("Activation generation complete.")

# --- Plotting ---
fig, axes = plt.subplots(1, top_n_to_visualize, figsize=(6 * top_n_to_visualize, 5))
if top_n_to_visualize == 1: # Handle case of single subplot
    axes = [axes]

fig.suptitle("MLP Neuron Activations (Post-ReLU) at Position '=' vs. Inputs (a, b)")

for i, neuron_idx in enumerate(top_neuron_indices):
    ax = axes[i]
    heatmap_data = neuron_activation_heatmaps[neuron_idx]
    sns.heatmap(heatmap_data, annot=True, fmt=".2f", cmap="viridis", ax=ax,
                linewidths=.5, linecolor='gray', cbar=True, square=True,
                # Mask cells with NaN so they don't show color
                mask=np.isnan(heatmap_data))
    ax.set_title(f"Neuron {neuron_idx}\nCorr w/ Sum: {neuron_sum_correlations[neuron_idx]:.3f}")
    ax.set_xlabel("Input b = F(n-1)")
    ax.set_ylabel("Input a = F(n-2)")
    # Set ticks to match input values
    ax.set_xticks(np.arange(N) + 0.5)
    ax.set_yticks(np.arange(N) + 0.5)
    ax.set_xticklabels(np.arange(N))
    ax.set_yticklabels(np.arange(N))
    ax.invert_yaxis() # Convention often puts (0,0) at top-left for matrices

plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap
plt.show()

The heatmaps show that the top-correlated neurons implement a function approximating ReLU(linear_function_of_sum). We observe a region of 0s (e.g., all pairs whose values add up to 6 or less for the middle neuron above). This is likely the ReLU cutting off negative pre-activations, which effectively creates a threshold below which the neurons do not fire. Above the threshold, the activation grows roughly linearly with the sum of the inputs, as expected from a ReLU. The fact that the patterns are so similar across neurons again emphasizes that the computation implemented by the model is distributed.

We will now examine the pre-ReLU activations for these same neurons to test our hypothesis. If the core computation is indeed a linear function of the sum that is simply thresholded by ReLU, then the pre-ReLU heatmaps should reveal this underlying linear relationship more clearly across the entire input grid, without the large zeroed-out regions, thus confirming the thresholding mechanism.

In [None]:
mlp_hook_point_name_pre = "blocks.0.mlp.hook_pre" # Input BEFORE ReLU

# --- Data Storage ---
# One N x N heatmap per neuron, indexed [a, b]; NaN marks cells that were never filled.
neuron_pre_activation_heatmaps = {
    idx: np.full((N, N), np.nan) for idx in top_neuron_indices
}
captured_mlp_pre_activation_storage = {} # Temporary storage during hook

# --- Hook Function ---
def capture_mlp_pre_hook(activation: torch.Tensor, hook: 'HookPoint'):
    """Capture the MLP pre-activation (input to the ReLU) at position 2 for every prompt.

    Stores a detached CPU tensor of shape [batch, d_mlp] under the 'activation' key and
    returns the activation unchanged, so the forward pass is not affected.
    """
    captured_mlp_pre_activation_storage['activation'] = activation[:, 2, :].detach().cpu()
    return activation

# --- Generate Activations ---
# One batched forward pass over the whole grid; row a*N + b is the prompt [a, b, '='].
input_tokens = torch.tensor(
    [[a, b, equals_token_id] for a in range(N) for b in range(N)],
    dtype=torch.long,
    device=model.cfg.device,
)

print(f"Generating PRE-ReLU activations for {len(top_neuron_indices)} neurons across {N}x{N} grid...")
model.eval()
with torch.no_grad():
    captured_mlp_pre_activation_storage.clear()
    model.run_with_hooks(
        input_tokens,
        fwd_hooks=[(mlp_hook_point_name_pre, capture_mlp_pre_hook)]
    )

if 'activation' in captured_mlp_pre_activation_storage:
    # [N*N, d_mlp] -> [N, N, d_mlp] so that grid_pre_activations[a, b] belongs to prompt (a, b).
    grid_pre_activations = captured_mlp_pre_activation_storage['activation'].float().reshape(N, N, -1)
    for neuron_idx in top_neuron_indices:
        neuron_pre_activation_heatmaps[neuron_idx][:, :] = grid_pre_activations[:, :, int(neuron_idx)].numpy()
print("PRE-ReLU activation generation complete.")

# --- Plotting ---
n_neurons = len(top_neuron_indices)
fig, axes = plt.subplots(1, n_neurons, figsize=(6 * n_neurons, 5), squeeze=False)
axes = axes[0]

fig.suptitle("MLP Neuron Activations (PRE-ReLU) at Position '=' vs. Inputs (a, b)")

for ax, neuron_idx in zip(axes, top_neuron_indices):
    # Pre-ReLU values can be positive or negative, so use a diverging colormap centred at 0.
    heatmap_data = neuron_pre_activation_heatmaps[neuron_idx]
    max_abs_val = np.nanmax(np.abs(heatmap_data)) # Symmetric colour range around 0
    if not np.isfinite(max_abs_val) or max_abs_val == 0:
        max_abs_val = 1.0  # all-zero / all-NaN heatmap: avoid a collapsed [0, 0] range
    sns.heatmap(heatmap_data, annot=True, fmt=".2f", cmap="coolwarm", ax=ax,
                linewidths=.5, linecolor='gray', cbar=True, square=True,
                mask=np.isnan(heatmap_data),
                vmin=-max_abs_val, vmax=max_abs_val) # Center colormap around 0
    ax.set_title(f"Neuron {neuron_idx} (Pre-ReLU)")
    ax.set_xlabel("Input b = F(n-1)")
    ax.set_ylabel("Input a = F(n-2)")
    ax.set_xticks(np.arange(N) + 0.5)
    ax.set_yticks(np.arange(N) + 0.5)
    ax.set_xticklabels(np.arange(N))
    ax.set_yticklabels(np.arange(N))
    # seaborn draws row 0 at the top; invert so a = 0 sits at the bottom.
    ax.invert_yaxis()

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

We do in fact observe a linear relation between the value of the sum and the pre-ReLU neuron value. However, note that the top neuron has a lot of zero values, whereas the other two go from negative to positive with similar magnitude. This indicates that they learn different linear functions (slopes, biases). However, part of this difference is thresholded away by the ReLU activation.

We now want to fit the pre-ReLU values to the PCA components to understand exactly what function thereof they are computing.

In [None]:
import pandas as pd

# --- Choose the neuron to model ---
# Pre-ReLU heatmaps only exist for the neurons in `top_neuron_indices` (previous cells), and
# neuron indices change between training runs, so a hard-coded index can be missing.
# Use the preferred index when it is available, otherwise the first neuron in
# `top_neuron_indices`.
preferred_neuron_idx = 408
if preferred_neuron_idx in neuron_pre_activation_heatmaps:
    target_neuron_idx = preferred_neuron_idx
else:
    target_neuron_idx = next(iter(neuron_pre_activation_heatmaps))
    print(f"Neuron {preferred_neuron_idx} has no pre-ReLU heatmap in this run; "
          f"modelling neuron {target_neuron_idx} instead.")

feature_list = []
target_activation_list = []

# NumPy copy of the PCA scores; row t holds the scores of token t (columns 0/1 used as PC1/PC2).
pca_scores_np = pca_scores.detach().cpu().numpy().astype(np.float32)
pre_activation_target_neuron = neuron_pre_activation_heatmaps[target_neuron_idx] # [a, b] heatmap

for a in range(N):
    for b in range(N):
        # NOTE(review): only pairs with a + b < N are used here, whereas the heatmaps above and
        # the test set below cover all N*N pairs — confirm this restriction is intended.
        if a + b < N:
            target_activation = pre_activation_target_neuron[a, b]

            # Skip cells the heatmap never filled (NaN)
            if not np.isnan(target_activation):
                # PC1/PC2 scores of each operand's token embedding
                pc1_a = pca_scores_np[a, 0]
                pc2_a = pca_scores_np[a, 1]
                pc1_b = pca_scores_np[b, 0]
                pc2_b = pca_scores_np[b, 1]

                features = [pc1_a, pc2_a, pc1_b, pc2_b]
                feature_list.append(features)
                target_activation_list.append(target_activation)

# Convert lists to NumPy arrays
X_features_for_neuron_model = np.array(feature_list, dtype=np.float32)
y_target_neuron_pre_activation = np.array(target_activation_list, dtype=np.float32)

# Feature names, in the same order as the columns of X_features_for_neuron_model
feature_names = ['PC1(a)', 'PC2(a)', 'PC1(b)', 'PC2(b)']

print(f"Prepared data for modeling neuron {target_neuron_idx}:")
print(f"  X_features shape: {X_features_for_neuron_model.shape}")
print(f"  y_target shape: {y_target_neuron_pre_activation.shape}")

# --- Fit the Linear Regression Model ---
neuron_model = LinearRegression()
neuron_model.fit(X_features_for_neuron_model, y_target_neuron_pre_activation)

# --- Evaluate the fit (in-sample R²) ---
y_neuron_pred = neuron_model.predict(X_features_for_neuron_model)
r2_neuron_fit = r2_score(y_target_neuron_pre_activation, y_neuron_pred)

print(f"\nLinear Model Fit for Neuron {target_neuron_idx} (Pre-ReLU):")
print(f"  R-squared of fit: {r2_neuron_fit:.4f}")

# --- Interpret the model ---
print("  Learned coefficients:")
coeffs = pd.Series(neuron_model.coef_, index=feature_names)
print(coeffs)
print(f"  Intercept: {neuron_model.intercept_:.4f}")

The excellent R² value confirms that the target neuron's (`target_neuron_idx` in the cell above) pre-ReLU activation is accurately modeled as a linear combination of the primary (linear value) and secondary (parabolic centrality) features derived from the input embeddings. The near-identical coefficients for the a and b components demonstrate symmetric treatment of the inputs. The dominance of the PC1 coefficients confirms that the neuron primarily computes a function proportional to the sum a+b. The smaller but non-zero PC2 coefficients suggest this core computation is subtly modulated by the inputs' centrality, likely for fine-tuning or boundary adjustments within the N=10 range.

Something confuses me: it seems that the model just directly learns to add, as the linear component of the PCA is the most significant. However, the fact that there's a non-negligible contribution from the second component of the PCA makes me doubt this. In fact, the first one only explains 73% of the variance! That's a lot, but far from everything.

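This motivates performing an ablation study, in which we get rid of the second principal component and investigate how (if at all) this affects model performance. Note that the hook below does this by projecting each operand embedding onto PC1 alone (around the embedding mean). It therefore removes PC2 together with all higher-order components, not PC2 in isolation.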

In [None]:
W_E_mean_gpu = W_E_mean.to(model.cfg.device) # Ensure mean is on correct device
pc1_vec = principal_components[0].to(model.cfg.device)
pc2_vec = principal_components[1].to(model.cfg.device)

# --- Squeeze the mean vector to be explicitly 1D ---
# NOTE(review): assumes W_E_mean is [1, d_model]; squeeze(0) then gives [d_model].
W_E_mean_1d_gpu = W_E_mean_gpu.squeeze(0)

def ablate_pc2_hook(
    activation: torch.Tensor, # Shape [batch, seq_len, d_model]
    hook: 'HookPoint',
    keep_only_pc1: bool = True,
):
    """Ablate PC2 from the embeddings of the two operand tokens (positions 0 and 1).

    keep_only_pc1=True (default, the original behaviour): each operand embedding is replaced
    by its projection onto PC1 around the embedding mean,
        mean + <emb - mean, pc1> * pc1,
    which removes PC2 *and every other direction orthogonal to PC1*, not PC2 alone.

    keep_only_pc1=False: only the PC2 component is removed,
        emb - <emb - mean, pc2> * pc2,
    leaving PC1 and all higher components intact. Pass it as a hook with e.g.
    functools.partial(ablate_pc2_hook, keep_only_pc1=False).

    The projections assume pc1_vec / pc2_vec are unit-norm principal directions.
    Positions other than 0 and 1 are left untouched. Returns a modified copy.
    """
    ablated_activation = activation.clone()

    operand_emb = activation[:, :2, :]                 # [batch, 2, d_model]: tokens a and b
    emb_centered = operand_emb - W_E_mean_1d_gpu        # broadcast over batch and position

    if keep_only_pc1:
        score_pc1 = emb_centered @ pc1_vec              # [batch, 2]
        ablated_operands = score_pc1.unsqueeze(-1) * pc1_vec + W_E_mean_1d_gpu
    else:
        score_pc2 = emb_centered @ pc2_vec              # [batch, 2]
        ablated_operands = operand_emb - score_pc2.unsqueeze(-1) * pc2_vec

    ablated_activation[:, :2, :] = ablated_operands
    return ablated_activation

# Hook point whose output (the token embeddings) ablate_pc2_hook modifies.
embedding_hook_point = "hook_embed"

# --- Test set: every (a, b) pair with a, b in 0..N-1, labelled with a + b ---
# Row a*N + b holds the prompt [a, b, '=']. (A former one-off demo run here used a stale
# `input_tokens` left over from an earlier cell and discarded its result; it was removed.)
test_pairs = [[a, b, equals_token_id] for a in range(N) for b in range(N)]
test_labels_list = [a + b for a in range(N) for b in range(N)]

# Convert to tensors
test_input_tokens = torch.tensor(test_pairs, dtype=torch.long).to(model.cfg.device)
test_labels = torch.tensor(test_labels_list, dtype=torch.long).to(model.cfg.device)

print(f"Created test data with {test_input_tokens.shape[0]} samples.")

# --- Accuracy helper ---
def calculate_accuracy(logits, labels):
    """Score next-token predictions taken from the final sequence position.

    Args:
        logits: tensor indexed as [batch, seq_len, d_vocab].
        labels: tensor of target token ids, one per batch element.

    Returns:
        (accuracy, n_correct, n_total), where accuracy = n_correct / n_total.
    """
    final_position_guess = logits[:, -1, :].argmax(dim=-1)
    n_total = labels.shape[0]
    n_correct = (final_position_guess == labels).sum().item()
    return n_correct / n_total, n_correct, n_total

# --- 2 & 3. Score the model with and without the PC2-ablation hook on the same test set ---
model.eval()
with torch.no_grad():  # inference only
    original_logits = model(test_input_tokens)
    ablated_logits = model.run_with_hooks(
        test_input_tokens,
        fwd_hooks=[(embedding_hook_point, ablate_pc2_hook)],
    )

original_accuracy, orig_correct, orig_total = calculate_accuracy(original_logits, test_labels)
ablated_accuracy, ablated_correct, ablated_total = calculate_accuracy(ablated_logits, test_labels)

print(f"\nOriginal Model Performance:")
print(f"  Accuracy: {original_accuracy:.4f} ({orig_correct}/{orig_total})")

print(f"\nPC2 Ablated Model Performance:")
print(f"  Accuracy: {ablated_accuracy:.4f} ({ablated_correct}/{ablated_total})")

# --- 4. Compare ---
accuracy_drop = original_accuracy - ablated_accuracy
print(f"\nAccuracy Drop due to PC2 Ablation: {accuracy_drop:.4f}")

Accuracy drops by a lot! We go from getting it right 100% of the time to only 42%. So the information outside PC1 — in particular PC2, the largest remaining component — is still very important, and whatever non-linearity it encodes actually matters. Let's now look specifically at which input pairs the model fails on when we ablate PC2.

In [None]:
# --- Which test pairs does the PC2-ablated model get wrong? ---
ablated_logits_last_pos = ablated_logits[:, -1, :]                 # [num_samples, d_vocab]
ablated_predicted_tokens = ablated_logits_last_pos.argmax(dim=-1)   # [num_samples]

incorrect_mask = ablated_predicted_tokens != test_labels
incorrect_indices = incorrect_mask.nonzero(as_tuple=True)[0].cpu().numpy()

# Host-side copies used to map sample indices back to (a, b) pairs.
original_a_values = test_input_tokens[:, 0].cpu().numpy()
original_b_values = test_input_tokens[:, 1].cpu().numpy()
true_sums = test_labels.cpu().numpy()
predicted_sums_ablated = ablated_predicted_tokens.cpu().numpy()

failing_pairs = [
    {
        'a': original_a_values[idx],
        'b': original_b_values[idx],
        'True Sum (a+b)': true_sums[idx],
        'Predicted Sum (Ablated)': predicted_sums_ablated[idx],
    }
    for idx in incorrect_indices
]

# --- Report the failures, ordered by true sum to make patterns easier to spot ---
if failing_pairs:
    print(f"Found {len(failing_pairs)} failures after PC2 ablation:")
    fail_df = (
        pd.DataFrame(failing_pairs)
        .sort_values(by=['True Sum (a+b)', 'a', 'b'])
        .reset_index(drop=True)
    )
    print(fail_df.to_string())  # full table, no truncation
else:
    print("No failures found after PC2 ablation (unexpected based on previous accuracy).")

# --- Failure map on the (a, b) grid: 1 where the ablated model is wrong ---
failure_heatmap = np.zeros((N, N))
failure_heatmap[original_a_values[incorrect_indices], original_b_values[incorrect_indices]] = 1

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(failure_heatmap, cmap="Reds", linewidths=.5, linecolor='gray', cbar=False, annot=False, ax=ax)
ax.set_title("Failure Cases (Red) after PC2 Ablation")
ax.set_xlabel("Input b = F(n-1)")
ax.set_ylabel("Input a = F(n-2)")
ax.set_xticks(np.arange(N) + 0.5)
ax.set_xticklabels(np.arange(N))
ax.set_yticks(np.arange(N) + 0.5)
ax.set_yticklabels(np.arange(N))
ax.invert_yaxis()
plt.show()

This is interesting... There's definitely patterns, but they are not easy to interpret. Some thoughts:

1.   Symmetry: Confirmed. Reinforces that a and b are treated symmetrically.
2.   Failures involving 0: This is a huge clue. The model fails for (0, 1) through (0, 8) and (1, 0) through (8, 0). However, it succeeds for (0, 0), (0, 9), and (9, 0).
3.   Failures involving 9: the same observation holds as for 0. This suggests some symmetry in how the model treats small and big numbers.
4.   Central failures: there's definitely a cluster of failures for pairs where both a and b are "mid-range" (roughly 2-7).

Given that the PC2 ablation dramatically impacts accuracy with a complex failure pattern—particularly affecting pairs involving edge tokens (0 and 9) differently and the central diagonal—it strongly suggests the issue lies not just in the MLP's internal calculation, but crucially in how the final output is derived from the MLP state. This points directly to the Unembedding layer (W_U), which performs this final mapping to output logits. Therefore, the next logical step is to analyze W_U to understand how it interprets the MLP's combined PC1/PC2 representation and why removing the PC2 component leads to these specific, non-uniform failures in decoding the sum.


In [None]:
hook_point_mlp_out = "blocks.0.hook_mlp_out"       # Output of the MLP layer
hook_point_resid_post = "blocks.0.hook_resid_post"  # Residual stream after the block (resid_mid + mlp_out)

# --- Shared dictionary for activation passing ---
# Module-level dict: the capture hook writes it and the replace hook reads it during one forward pass.
captured_activations_for_run = {}

# --- Hook Function to CAPTURE mlp_out ---
def capture_mlp_out_hook(activation, hook):
    """Stash the detached MLP output for use by a later hook in the same forward pass.

    Returns the activation unchanged, so this hook does not alter the forward pass.
    """
    captured_activations_for_run['mlp_out'] = activation.detach()
    return activation

# --- Hook Function to MODIFY resid_post ---
def replace_resid_post_hook(activation, hook):
    """Replace resid_post with the captured mlp_out, removing the skip connection.

    In a standard (non-parallel) TransformerLens block resid_post = resid_mid + mlp_out,
    where resid_mid = resid_pre + attn_out is the stream entering the MLP. Returning mlp_out
    alone therefore drops everything the skip connection carried (embeddings and attention
    output). Falls back to the unmodified activation, with a warning, if nothing was
    captured or the shapes disagree.
    """
    if 'mlp_out' not in captured_activations_for_run:
         print(f"Warning: Capture hook '{hook_point_mlp_out}' did not run before modify hook '{hook.name}'")
         return activation # Safety fallback

    mlp_out_value = captured_activations_for_run['mlp_out']

    if activation.shape == mlp_out_value.shape:
        return mlp_out_value
    else:
        print(f"Shape mismatch warning: {hook.name} ({activation.shape}) vs captured mlp_out ({mlp_out_value.shape})")
        return activation # Fallback

# --- Run the Ablation and the baseline ---
model.eval()
with torch.no_grad():
    # Clear the shared dictionary before the run
    captured_activations_for_run.clear()

    # Hooks fire in forward-pass order: hook_mlp_out runs before hook_resid_post of the same
    # block, so the capture is always available to the replace hook.
    print(f"Running model with capture hook at '{hook_point_mlp_out}' and modify hook at '{hook_point_resid_post}'")
    skip_ablated_logits = model.run_with_hooks(
        test_input_tokens,
        fwd_hooks=[
            (hook_point_mlp_out, capture_mlp_out_hook),
            (hook_point_resid_post, replace_resid_post_hook)
        ]
    )
    print("Model run with hooks completed.")

    # Baseline without hooks; inside no_grad so no autograd graph is built.
    original_logits = model(test_input_tokens)

# --- Calculate and Compare Accuracy ---
original_accuracy, orig_correct, orig_total = calculate_accuracy(original_logits, test_labels)
skip_ablated_accuracy, skip_ablated_correct, skip_ablated_total = calculate_accuracy(skip_ablated_logits, test_labels)

print(f"\nOriginal Model Accuracy: {original_accuracy:.4f} ({orig_correct}/{orig_total})")
print(f"Skip Conn Ablated Accuracy (MLP Only): {skip_ablated_accuracy:.4f} ({skip_ablated_correct}/{skip_ablated_total})")
print(f"Accuracy Drop due to Skip Ablation: {original_accuracy - skip_ablated_accuracy:.4f}")

# Drop the captured tensor so it is not kept alive between cells
captured_activations_for_run.clear()

No accuracy is lost by dropping the skip connection (the bit where you re-add the MLP input to its output before feeding the result to the output layer of the transformer). This indicates that we can effectively ignore it. Nanda had observed the same. Just to be sure, we will now repeat the analysis using the test loss rather than accuracy.

In [None]:
# Fresh shared dict for the capture/replace hooks defined in the previous cell
# (the hooks look the name up globally at call time).
captured_activations_for_run = {}

model.eval()
with torch.no_grad():
    captured_activations_for_run.clear()

    print(f"Running model with capture hook at '{hook_point_mlp_out}' and modify hook at '{hook_point_resid_post}'")
    skip_ablated_logits = model.run_with_hooks(
        test_input_tokens,
        fwd_hooks=[
            (hook_point_mlp_out, capture_mlp_out_hook),
            (hook_point_resid_post, replace_resid_post_hook),
        ],
    )
    print("Model run with hooks completed.")

    # Unhooked baseline on the same inputs
    original_logits = model(test_input_tokens)

# --- Accuracy and loss, both scored at the final ('=') position ---
original_accuracy, orig_correct, orig_total = calculate_accuracy(original_logits, test_labels)
skip_ablated_accuracy, skip_ablated_correct, skip_ablated_total = calculate_accuracy(skip_ablated_logits, test_labels)

original_pred_logits = original_logits[:, -1, :]
skip_ablated_pred_logits = skip_ablated_logits[:, -1, :]
original_loss = loss_fn(original_pred_logits, test_labels)
skip_ablated_loss = loss_fn(skip_ablated_pred_logits, test_labels)


def _report_performance(header, accuracy, n_correct, n_total, loss_tensor):
    """Print one accuracy/loss summary section."""
    print(header)
    print(f"  Accuracy: {accuracy:.4f} ({n_correct}/{n_total})")
    print(f"  Loss:     {loss_tensor.item():.6f}")


_report_performance("\nOriginal Model Performance:",
                    original_accuracy, orig_correct, orig_total, original_loss)
_report_performance("\nSkip Conn Ablated Accuracy (MLP Only):",
                    skip_ablated_accuracy, skip_ablated_correct, skip_ablated_total, skip_ablated_loss)

print(f"\nAccuracy Drop due to Skip Ablation: {original_accuracy - skip_ablated_accuracy:.4f}")
print(f"Loss Change due to Skip Ablation:   {skip_ablated_loss.item() - original_loss.item():.6f}")

# Release the captured tensor
captured_activations_for_run.clear()

Loss increases by a few orders of magnitude, but in absolute terms is still small. I'd say this verifies our findings.

We are now going to look at the difference in logits when comparing the full trained model and the model in which we have ablated away the second PCA component (the parabolic one) in the embedded representation. We hope this sheds light on what it is that this 2nd component is doing.

In [None]:
# --- Get Original and Ablated Logits ---
model.eval()
with torch.no_grad():
    original_logits = model(test_input_tokens) # [num_samples, seq_len, d_vocab]
    # PC2 ablated from the operand embeddings
    ablated_logits = model.run_with_hooks(
        test_input_tokens,
        fwd_hooks=[(embedding_hook_point, ablate_pc2_hook)]
    )

# --- Logit Difference (Original - Ablated) ---
logit_diff = original_logits - ablated_logits # [num_samples, seq_len, d_vocab]

# Keep only the prediction position (last token)
logit_diff_pred_pos = logit_diff[:, -1, :].cpu().numpy()        # [num_samples, d_vocab]
original_logits_pred_pos = original_logits[:, -1, :].cpu().numpy()
ablated_logits_pred_pos = ablated_logits[:, -1, :].cpu().numpy()

test_labels_np = test_labels.cpu().numpy()

original_a_values = test_input_tokens[:, 0].cpu().numpy()
original_b_values = test_input_tokens[:, 1].cpu().numpy()
ablated_predicted_tokens = np.argmax(ablated_logits_pred_pos, axis=-1)

# --- Per-sample table (vectorized: pick each sample's true / predicted column) ---
sample_idx = np.arange(len(test_labels_np))
results_df = pd.DataFrame({
    'a': original_a_values,
    'b': original_b_values,
    'True Sum': test_labels_np,
    'Ablated Pred': ablated_predicted_tokens,
    'LogitDiff (True)': logit_diff_pred_pos[sample_idx, test_labels_np],
    'LogitDiff (AblatedPred)': logit_diff_pred_pos[sample_idx, ablated_predicted_tokens],
    'OrigLogit (True)': original_logits_pred_pos[sample_idx, test_labels_np],
    'AblatedLogit (True)': ablated_logits_pred_pos[sample_idx, test_labels_np],
    'Failed': test_labels_np != ablated_predicted_tokens,
})

# --- Display Key Information ---
print("Logit Difference Analysis (Original - Ablated):")
print("Positive 'LogitDiff (True)' means PC2 boosted the correct logit.")
print("Negative 'LogitDiff (AblatedPred)' for failures means PC2 suppressed the wrongly predicted logit.")

print("\nStatistics for Logit Difference of the CORRECT sum:")
print(results_df['LogitDiff (True)'].describe())

failures_df = results_df[results_df['Failed']].sort_values(by=['True Sum', 'a', 'b'])
print(f"\nAnalysis of {len(failures_df)} Failure Cases:")
if not failures_df.empty:
    print(failures_df[['a', 'b', 'True Sum', 'Ablated Pred', 'LogitDiff (True)', 'LogitDiff (AblatedPred)', 'OrigLogit (True)', 'AblatedLogit (True)']].to_string())

# --- Visualize LogitDiff for the Correct Label ---
# Integer-array indexing with the token arrays; avoids reading (a, b) back out of
# mixed-dtype DataFrame rows, where they could be upcast to floats.
logit_diff_heatmap = np.full((N, N), np.nan)
logit_diff_heatmap[original_a_values, original_b_values] = results_df['LogitDiff (True)'].to_numpy()

plt.figure(figsize=(7, 6))
max_abs_val = np.nanmax(np.abs(logit_diff_heatmap))
if not np.isfinite(max_abs_val) or max_abs_val == 0:
    max_abs_val = 1.0  # avoid a collapsed [0, 0] colour range
sns.heatmap(logit_diff_heatmap, cmap="coolwarm", center=0,
            annot=True, fmt=".2f", linewidths=.5, linecolor='gray',
            vmin=-max_abs_val, vmax=max_abs_val, # Center colorbar at 0
            mask=np.isnan(logit_diff_heatmap), square=True)
plt.title("Logit Difference (Orig - Ablated) for CORRECT Sum")
plt.xlabel("Input b = F(n-1)")
plt.ylabel("Input a = F(n-2)")
plt.xticks(np.arange(N) + 0.5, np.arange(N))
plt.yticks(np.arange(N) + 0.5, np.arange(N))
plt.gca().invert_yaxis()
plt.show()

PC2 isn't just "fine-tuning". It's part of a complex encoding strategy interpreted by the Unembedding layer (W_U).

It acts as a strong identifier for "double-extreme" inputs ((0,0), (9,9)), massively boosting their correct logits.

It acts differently for "single-extreme" inputs ((0,k), (k,9)), where its net effect might even reduce the correct logit slightly, but likely plays a vital role in suppressing competitors.

It provides a moderate general boost for central inputs.

Let us now visualize by how much the ablated model is wrong, rather than just the fact that it is wrong.

In [None]:
# Signed error of the PC2-ablated model per test sample (ablated prediction - true sum),
# laid out on the (a, b) grid. Inputs come from the logit-difference cell above.
num_samples = len(test_labels_np)
prediction_errors = (np.asarray(ablated_predicted_tokens) - test_labels_np).astype(float)

prediction_error_heatmap = np.full((N, N), np.nan)
prediction_error_heatmap[original_a_values, original_b_values] = prediction_errors

# --- Visualize Prediction Error Heatmap ---
plt.figure(figsize=(7, 6))

# Symmetric colour range around 0. With no errors at all (max = 0) or no data (NaN) the
# range would collapse, so fall back to a unit range.
max_abs_error = np.nanmax(np.abs(prediction_error_heatmap))
if np.isnan(max_abs_error) or max_abs_error == 0:
    max_abs_error = 1.0

# Diverging colormap: red for positive error (prediction > true), blue for negative
cmap = "coolwarm"

sns.heatmap(prediction_error_heatmap, cmap=cmap, center=0,
            annot=True, fmt=".0f", # Show integer errors
            linewidths=.5, linecolor='lightgray',
            square=True,
            vmin=-max_abs_error, vmax=max_abs_error, # Center color scale around 0
            cbar_kws={'label': 'Ablated Prediction - True Sum'},
            mask=np.isnan(prediction_error_heatmap)) # Mask cells not in test data

plt.title("Prediction Error of Ablated Model (PC2 Removed)")
plt.xlabel("Input b = F(n-1)")
plt.ylabel("Input a = F(n-2)")
plt.xticks(np.arange(N) + 0.5, np.arange(N))
plt.yticks(np.arange(N) + 0.5, np.arange(N))
plt.gca().invert_yaxis()
plt.show()

# --- Error Statistics ---
print("\nStatistics for Prediction Error (Ablated Pred - True Sum):")
valid_errors = prediction_errors[~np.isnan(prediction_errors)]
if len(valid_errors) > 0:
    error_series = pd.Series(valid_errors)
    print(error_series.describe())
    print(f"\nNumber of incorrect predictions (error != 0): {np.sum(valid_errors != 0)}")
    print(f"Number of correct predictions (error == 0): {np.sum(valid_errors == 0)}")
else:
    print("No valid error data found (check test set population).")

Some observations: It's never wrong by more than 3, and is most often wrong by 1. It's off by 3 only at (0, 0) and (9, 9): it predicts 3 for (0, 0) and 15 for (9, 9). It's off by positive values when at least one of the inputs is 0, and off by negative values when at least one of the inputs is 9. The exceptions are (0, 9) and (9, 0), where it is correct — suggesting that the two opposite errors cancel out ("two wrongs make a right"). Further, for low-ish values of a,b it underpredicts (off by negative), and for high-ish values of a,b it overpredicts (off by positive).

Can we conclude anything about whatever algorithm it's implementing based on this? Let's fit the predictions of the ablated model to a linear function of the sum.

In [None]:
print("Fitting ablated predictions to a linear function of the true sum...")

# Regress the ablated model's predicted token on the true sum:  AblatedPred ≈ m * TrueSum + c.
# scikit-learn wants a 2D design matrix, hence the single-column view of the labels.
X_true_sum = test_labels_np[:, np.newaxis]
y_ablated_pred = ablated_predicted_tokens

lin_reg = LinearRegression().fit(X_true_sum, y_ablated_pred)
(learned_m,) = lin_reg.coef_
learned_c = lin_reg.intercept_
y_linear_fit_pred = lin_reg.predict(X_true_sum)

# Goodness of fit of the line to the ablated predictions
r2 = r2_score(y_ablated_pred, y_linear_fit_pred)
mse = mean_squared_error(y_ablated_pred, y_linear_fit_pred)

print("\n--- Linear Fit Results: AblatedPred = m * TrueSum + c ---")
print(f"Learned Slope (m): {learned_m:.4f}")
print(f"Learned Intercept (c): {learned_c:.4f}")
print(f"R-squared score: {r2:.4f}")
print(f"Mean Squared Error: {mse:.4f}")

# Hypothesis from the error heatmap: 0+0 -> 3 and 9+9 -> 15, i.e. slope 2/3, intercept 3
hypothesized_m = 2/3
hypothesized_c = 3
print("\nComparison to Hypothesis (m = 2/3, c = 3):")
print(f"Hypothesized m: {hypothesized_m:.4f}")
print(f"Hypothesized c: {hypothesized_c:.4f}")
print(f"Difference in m: {learned_m - hypothesized_m:.4f}")
print(f"Difference in c: {learned_c - hypothesized_c:.4f}")

# --- Visualize the Fit ---
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(test_labels_np, y_ablated_pred, alpha=0.5, label='Actual Ablated Predictions')
ax.plot(test_labels_np, y_linear_fit_pred, color='red', linewidth=2, label=f'Linear Fit (R²={r2:.3f})')
hypothesized_line = hypothesized_m * test_labels_np + hypothesized_c
ax.plot(test_labels_np, hypothesized_line, color='green', linestyle='--', linewidth=2, label='Hypothesized Line (m=2/3, c=3)')
# Reference: a perfect adder would lie on y = x
ax.plot(test_labels_np, test_labels_np, color='gray', linestyle=':', linewidth=1, label='Ideal (y=x)')

ax.set_title('Linear Fit: Ablated Prediction vs. True Sum')
ax.set_xlabel('True Sum (a+b)')
ax.set_ylabel('Ablated Model Prediction')
ax.legend()
ax.grid(True, alpha=0.3)

# Square axes covering every true sum and every prediction (18 is the max sum for N = 10)
max_sum = np.max(test_labels_np) if len(test_labels_np) > 0 else 18
max_pred = np.max(y_ablated_pred) if len(y_ablated_pred) > 0 else max_sum
plot_max = max(max_sum, max_pred)
ax.set_xlim(-0.5, plot_max + 0.5)
ax.set_ylim(-0.5, plot_max + 0.5)
plt.show()

Observation: the ablated model makes different predictions for some combinations of inputs which are different but have the same sum, which is why there are multiple blue dots on the same vertical line in the plot above. This necessarily implies that we cannot perfectly fit the ablated model's predictions to a function of the sum, we need to consider a and b separately.

In [None]:
print("\nFitting ablated predictions to a linear function of a and b separately...")

# --- Prepare Data for Multiple Linear Regression ---
# X: independent variables (a, b) as the two columns of one array
X_inputs_ab = np.stack((original_a_values, original_b_values), axis=-1) # Shape: [num_samples, 2]

# y: dependent variable (token predicted by the PC2-ablated model)
y_ablated_pred = ablated_predicted_tokens

# --- Fit Linear Regression Model ---
multi_lin_reg = LinearRegression()
multi_lin_reg.fit(X_inputs_ab, y_ablated_pred)

# --- Get Results ---
learned_wa, learned_wb = multi_lin_reg.coef_  # weights for a and b
learned_c_multi = multi_lin_reg.intercept_
y_multilinear_fit_pred = multi_lin_reg.predict(X_inputs_ab)

# --- Evaluate the Fit ---
r2_multi = r2_score(y_ablated_pred, y_multilinear_fit_pred)
mse_multi = mean_squared_error(y_ablated_pred, y_multilinear_fit_pred)

# --- Print Results ---
print("\n--- Multiple Linear Fit Results: AblatedPred = wa*a + wb*b + c ---")
print(f"Learned Weight for a (wa): {learned_wa:.4f}")
print(f"Learned Weight for b (wb): {learned_wb:.4f}")
print(f"Learned Intercept (c): {learned_c_multi:.4f}")
print(f"R-squared score: {r2_multi:.4f}")
print(f"Mean Squared Error: {mse_multi:.4f}")

# --- Compare with previous single-variable fit ---
# Only possible if the single-variable fit cell (defining r2, learned_m, learned_c) has run.
if 'r2' in locals():
    print("\nComparison to previous fit (AblatedPred = m*(a+b) + c):")
    print(f"  Previous R-squared: {r2:.4f}")
    print(f"  Improvement in R-squared: {r2_multi - r2:.4f}")
    print(f"  Previous m: {learned_m:.4f} (Avg of wa, wb = {(learned_wa + learned_wb)/2:.4f})")
    print(f"  Previous c: {learned_c:.4f} (Current c = {learned_c_multi:.4f})")

# --- Residuals vs fitted values ---
residuals = y_ablated_pred - y_multilinear_fit_pred
plt.figure(figsize=(8, 6))
plt.scatter(y_multilinear_fit_pred, residuals, alpha=0.5)
plt.hlines(0, xmin=min(y_multilinear_fit_pred), xmax=max(y_multilinear_fit_pred), colors='red', linestyles='--')
plt.title('Residual Plot for Multi-Linear Fit')
plt.xlabel('Fitted Values (wa*a + wb*b + c)')
plt.ylabel('Residuals (Actual - Fitted)')
plt.grid(True, alpha=0.3)
plt.show()

# --- Heatmap of Residuals over the (a, b) grid ---
# Cells with no test sample stay NaN and are masked out of the plot.
residual_heatmap = np.full((N, N), np.nan)
fitted_values_heatmap = np.full((N, N), np.nan)
num_samples = len(y_ablated_pred)
a_idx = np.asarray(original_a_values, dtype=int)
b_idx = np.asarray(original_b_values, dtype=int)
residual_heatmap[a_idx, b_idx] = residuals
fitted_values_heatmap[a_idx, b_idx] = y_multilinear_fit_pred

plt.figure(figsize=(7, 6))
max_abs_resid = np.nanmax(np.abs(residual_heatmap))
# Fall back to a unit colour range when there are no residuals or all are zero;
# a symmetric range of vmin = vmax = 0 is degenerate.
if np.isnan(max_abs_resid) or max_abs_resid == 0:
    max_abs_resid = 1.0

cmap_resid = "coolwarm"
sns.heatmap(residual_heatmap, cmap=cmap_resid, center=0,
            annot=True, fmt=".2f",
            linewidths=.5, linecolor='lightgray',
            square=True,
            vmin=-max_abs_resid, vmax=max_abs_resid,
            cbar_kws={'label': 'Residual (Actual Ablated - Fitted)'},
            mask=np.isnan(residual_heatmap))
plt.title("Residuals of Multi-Linear Fit (wa*a + wb*b + c)")
plt.xlabel("Input b")
plt.ylabel("Input a")
plt.xticks(np.arange(N) + 0.5, np.arange(N))
plt.yticks(np.arange(N) + 0.5, np.arange(N))
plt.gca().invert_yaxis()  # put a=0 at the bottom
plt.show()

The fit to a and b individually gives the same result as the fit to a+b. Hence, the linear part of the ablated model's prediction depends on the inputs only through their sum. Since the fit is not perfect ($R^2 \approx 0.93$, not 1), there must be some non-linear component contributing to the model's output. We think this is introduced by the ReLU, as we see no other source of non-linearity.

Hypothesis: note that the desired behavior is linear, i.e., just sum the inputs. The complete model implements this, albeit with a PC2 component which is non-linear. The ablated model that does not have the PC2 component fails to implement the linear behavior, and errs. Hence, it might be that the PC2 component's task is to cancel out the wrong non-linearity that the rest of the model appears to learn. We can test this by fitting the error in the ablated model's prediction to the value of the PC2 component for the same inputs.

In [None]:
from sklearn.preprocessing import PolynomialFeatures # For potential non-linear fit

# Test whether PC2 cancels the ablated model's error: regress the error
# (true sum - ablated prediction) on PC2(a) and PC2(b), linearly and with degree-2 terms.

# --- 1. Extract PC2 Features using EXISTING pca_scores ---
print("Extracting PC2 features from existing pca_scores...")
try:
    # Ensure pca_scores is available and is a tensor
    if 'pca_scores' not in locals() or not isinstance(pca_scores, torch.Tensor):
        raise NameError("Variable 'pca_scores' not found or is not a PyTorch tensor.")

    # Convert scores to numpy ONCE for use with sklearn and indexing
    pca_scores_np = pca_scores.detach().cpu().numpy()
    # PC2 values for all tokens in the vocabulary (column 1 is the second PC)
    if pca_scores_np.shape[1] < 2:
        raise ValueError(f"pca_scores_np has shape {pca_scores_np.shape}, but need at least 2 components to extract PC2.")
    pc2_values_all_tokens = pca_scores_np[:, 1]

    # Get PC2 value for input 'a' and 'b' for each sample in the test set
    pc2_a_features = pc2_values_all_tokens[original_a_values]
    pc2_b_features = pc2_values_all_tokens[original_b_values]

    # Create the feature matrix for regression
    X_pc2_features = np.stack((pc2_a_features, pc2_b_features), axis=-1) # Shape: [num_samples, 2]
    pc2_extracted_successfully = True
    print("PC2 features extracted successfully.")

except NameError as e:
    print(f"Error: {e}")
    print("Please ensure 'pca_scores' (Tensor) is computed and available.")
    pc2_extracted_successfully = False
except IndexError as e:
    print(f"Error indexing PC2 values. Check token indices range (0-{N-1}) vs vocab size used in PCA: {e}")
    pc2_extracted_successfully = False
except Exception as e:
    print(f"An unexpected error occurred during PC2 feature extraction: {e}")
    pc2_extracted_successfully = False

# --- Proceed only if PC2 features were extracted successfully ---
if pc2_extracted_successfully:
    print("\nFitting Ablated Model Error to PC2 Features...")

    # --- 2. Calculate the Error ---
    # Correction needed to turn the ablated prediction into the true sum
    y_error = test_labels_np - ablated_predicted_tokens

    # --- 3. Perform Linear Regression ---
    error_reg = LinearRegression()
    error_reg.fit(X_pc2_features, y_error)

    # --- Get Results ---
    learned_w_pc2a, learned_w_pc2b = error_reg.coef_
    learned_c_error = error_reg.intercept_
    y_error_pred = error_reg.predict(X_pc2_features)

    # --- Evaluate Fit ---
    r2_error_fit = r2_score(y_error, y_error_pred)
    mse_error_fit = mean_squared_error(y_error, y_error_pred)

    print("\n--- Linear Fit Results: Error ≈ w_pc2a*PC2(a) + w_pc2b*PC2(b) + c ---")
    print(f"Learned Weight for PC2(a): {learned_w_pc2a:.4f}")
    print(f"Learned Weight for PC2(b): {learned_w_pc2b:.4f}")
    print(f"Learned Intercept: {learned_c_error:.4f}")
    print(f"R-squared score: {r2_error_fit:.4f}")
    print(f"Mean Squared Error: {mse_error_fit:.4f}")

    # --- Polynomial Regression (Degree 2) ---
    print("\n--- Trying Polynomial Fit (Degree 2) ---")
    poly = PolynomialFeatures(degree=2, include_bias=False)
    # Interaction & quadratic features: [pc2a, pc2b, pc2a^2, pc2a*pc2b, pc2b^2]
    X_pc2_poly_features = poly.fit_transform(X_pc2_features)

    error_poly_reg = LinearRegression()
    error_poly_reg.fit(X_pc2_poly_features, y_error)
    y_error_poly_pred = error_poly_reg.predict(X_pc2_poly_features)
    r2_error_poly_fit = r2_score(y_error, y_error_poly_pred)
    mse_error_poly_fit = mean_squared_error(y_error, y_error_poly_pred)

    print(f"Polynomial Fit R-squared score: {r2_error_poly_fit:.4f}")
    print(f"Polynomial Fit Mean Squared Error: {mse_error_poly_fit:.4f}")
    print(f"Improvement in R-squared: {r2_error_poly_fit - r2_error_fit:.4f}")

    # --- Visualize Error vs Predicted Error ---
    plt.figure(figsize=(8, 6))
    plt.scatter(y_error, y_error_pred, alpha=0.5, label=f'Linear Fit (R²={r2_error_fit:.3f})')
    # Only plot polynomial fit if it's meaningfully better
    if r2_error_poly_fit > r2_error_fit + 0.01:
        plt.scatter(y_error, y_error_poly_pred, alpha=0.5, marker='x', label=f'Poly Fit (R²={r2_error_poly_fit:.3f})')
    # Ideal line y=x: a perfect error predictor would put every point on it
    min_val = min(np.min(y_error), np.min(y_error_pred)) - 0.5 # Add buffer
    max_val = max(np.max(y_error), np.max(y_error_pred)) + 0.5 # Add buffer
    plt.plot([min_val, max_val], [min_val, max_val], color='red', linestyle='--', label='Ideal Fit (y=x)')
    plt.title('Predicting Ablated Model Error using PC2 Features')
    plt.xlabel('Actual Error (True Sum - Ablated Pred)')
    plt.ylabel('Predicted Error (from PC2 features)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
else:
    print("Skipping error prediction analysis due to issues extracting PC2 features.")

In [None]:


# Analyse how the MLP neurons most correlated with a+b write to each output logit,
# via the effective neuron -> logit map W_L (unembedding composed with W_out).

# --- Configuration ---
d_vocab = N # Number of output logits analysed (tokens 0..N-1)
d_mlp = model.cfg.d_mlp
num_key_neurons = 100 # Analyze the top 100 correlated neurons
# NOTE(review): indices_sorted_by_abs_corr comes from an earlier cell; presumably neuron
# indices sorted by |correlation with a+b|, most correlated first -- confirm.
key_neuron_indices = indices_sorted_by_abs_corr[:num_key_neurons]

# --- Step 1: Calculate W_L = W_U.T @ W_out.T ---
# W_U is assumed stored as [d_model, d_vocab] and W_out as [d_mlp, d_model];
# both are transposed so the product maps neurons to logits.
try:
    W_U_stored = model.unembed.W_U.detach().cpu()
except AttributeError:
    # Fallback: unembedding exposed directly on the model object
    W_U_stored = model.W_U.detach().cpu()
W_U = W_U_stored.T # [d_vocab, d_model]

W_out_stored = model.blocks[0].mlp.W_out.detach().cpu()
W_out = W_out_stored.T # [d_model, d_mlp]

print(f"Multiplying W_U shape {W_U.shape} with W_out shape {W_out.shape}")
W_L = W_U @ W_out # [d_vocab, d_mlp]
W_L_np = W_L.numpy()
print(f"Calculated W_L with shape: {W_L_np.shape}") # Should be (d_vocab, d_mlp)

# --- Step 2: Analyze Rows of W_L and Neuron Contributions ---
print(f"\nAnalyzing weights in W_L for top {num_key_neurons} sum-correlated neurons:")

analysis_results = []

for k in range(d_vocab): # Iterate through each output logit (0 to N-1)
    # Weights from the key neurons into logit k
    weights_for_key_neurons = W_L_np[k, key_neuron_indices] # Shape: [num_key_neurons]

    # Hypothesis Check: fraction of key neurons that push logit k up
    positivity_ratio = np.mean(weights_for_key_neurons > 0)

    # Redundancy Check: Distribution statistics
    mean_weight = np.mean(weights_for_key_neurons)
    std_weight = np.std(weights_for_key_neurons)
    max_abs_weight = np.max(np.abs(weights_for_key_neurons))
    # Sparsity: fraction of weights below 1% of the largest absolute weight
    sparsity_threshold = 0.01 * max_abs_weight
    sparsity_ratio = np.mean(np.abs(weights_for_key_neurons) < sparsity_threshold)

    print(f"\n--- Logit for output token k={k} ---")
    print(f"  Positivity Ratio (Key Neurons): {positivity_ratio:.2f}")
    print(f"  Sparsity Ratio (<{sparsity_threshold:.4f}): {sparsity_ratio:.2f}")
    print(f"  Weight Stats (Key Neurons): Mean={mean_weight:.4f}, Std={std_weight:.4f}, MaxAbs={max_abs_weight:.4f}")

    analysis_results.append({
        'k': k,
        'positivity_ratio': positivity_ratio,
        'sparsity_ratio': sparsity_ratio,
        'mean_weight': mean_weight,
        'std_weight': std_weight,
        'max_abs_weight': max_abs_weight
    })

# Summary table (one row per output logit)
analysis_df = pd.DataFrame(analysis_results)
print("\n--- Summary Table ---")
print(analysis_df.round(3))

# Histogram of key-neuron weights for one logit (clamped to a valid row)
k_to_plot = min(5, d_vocab - 1)
plt.figure(figsize=(8, 4))
plt.hist(W_L_np[k_to_plot, key_neuron_indices], bins=30, alpha=0.7)
plt.title(f"Weight Distribution in W_L[{k_to_plot}, :] for Top {num_key_neurons} Neurons")
plt.xlabel("Weight Value")
plt.ylabel("Frequency")
plt.grid(axis='y', alpha=0.5)
plt.show()

That didn't work! Let's try to fit the output of the ablated model to a quadratic function of the inputs instead.

In [None]:
# Imported here as well so this cell does not depend on the PC2 error-fit cell having run.
from sklearn.preprocessing import PolynomialFeatures

print("\nFitting ablated predictions to quadratic functions of a and b...")

# --- Prepare Data ---
# y: dependent variable (token predicted by the PC2-ablated model)
y_ablated_pred = ablated_predicted_tokens

# --- Fit 1: Requested Quadratic Model: x*(a+b) + y*a^2 + z*b^2 + c ---
print("\n--- Fitting Requested Model: x*(a+b) + y*a^2 + z*b^2 + c ---")

# Construct Features: [a+b, a^2, b^2]
sum_ab = original_a_values + original_b_values
a_squared = original_a_values**2
b_squared = original_b_values**2
X_quadratic_requested = np.stack((sum_ab, a_squared, b_squared), axis=-1) # Shape: [num_samples, 3]

quad_reg_req = LinearRegression()
quad_reg_req.fit(X_quadratic_requested, y_ablated_pred)

learned_x, learned_y, learned_z = quad_reg_req.coef_
learned_c_req = quad_reg_req.intercept_
y_quad_req_pred = quad_reg_req.predict(X_quadratic_requested)

r2_quad_req = r2_score(y_ablated_pred, y_quad_req_pred)
mse_quad_req = mean_squared_error(y_ablated_pred, y_quad_req_pred)

print(f"Learned Coefficient for (a+b) (x): {learned_x:.4f}")
print(f"Learned Coefficient for a^2 (y):   {learned_y:.4f}")
print(f"Learned Coefficient for b^2 (z):   {learned_z:.4f}")
print(f"Learned Intercept (c):             {learned_c_req:.4f}")
print(f"R-squared score:                   {r2_quad_req:.4f}")
print(f"Mean Squared Error:                {mse_quad_req:.4f}")


# --- Fit 2: General Quadratic Model using PolynomialFeatures ---
print("\n--- Fitting General Quadratic Model (includes a*b term) ---")

# Degree-2 features of [a, b]: a, b, a^2, a*b, b^2 (LinearRegression supplies the bias)
poly = PolynomialFeatures(degree=2, include_bias=False)
X_inputs_ab = np.stack((original_a_values, original_b_values), axis=-1) # Shape: [num_samples, 2]
X_quadratic_general = poly.fit_transform(X_inputs_ab)
# Feature names label the coefficients below in the order sklearn produced them
feature_names = poly.get_feature_names_out(['a', 'b'])

quad_reg_gen = LinearRegression()
quad_reg_gen.fit(X_quadratic_general, y_ablated_pred)

learned_coeffs_gen = quad_reg_gen.coef_
learned_c_gen = quad_reg_gen.intercept_
y_quad_gen_pred = quad_reg_gen.predict(X_quadratic_general)

r2_quad_gen = r2_score(y_ablated_pred, y_quad_gen_pred)
mse_quad_gen = mean_squared_error(y_ablated_pred, y_quad_gen_pred)

print("General Quadratic Model: Pred ≈ w_a*a + w_b*b + w_aa*a^2 + w_ab*a*b + w_bb*b^2 + c_gen")
print("Learned Coefficients:")
for name, coeff in zip(feature_names, learned_coeffs_gen):
    print(f"  Weight for {name}: {coeff:.4f}")
print(f"Learned Intercept (c_gen): {learned_c_gen:.4f}")
print(f"R-squared score:           {r2_quad_gen:.4f}")
print(f"Mean Squared Error:        {mse_quad_gen:.4f}")


# --- Compare Fits ---
print("\n--- R-squared Comparison ---")
if 'r2_multi' in locals(): # Set by the multi-linear fit cell, if it has run
    print(f"Linear Fit (wa*a + wb*b + c):       R² = {r2_multi:.4f}")
else:
    print("Multi-linear fit results not available for comparison.")
print(f"Requested Quadratic Fit:          R² = {r2_quad_req:.4f}")
print(f"General Quadratic Fit (w/ a*b):   R² = {r2_quad_gen:.4f}")


# --- Visualize Residuals for the Best Quadratic Fit (ties go to the general fit) ---
best_r2_quad = max(r2_quad_req, r2_quad_gen)
print(f"\nVisualizing residuals for the best quadratic fit (R²={best_r2_quad:.4f})...")
if r2_quad_gen >= r2_quad_req:
    residuals = y_ablated_pred - y_quad_gen_pred
    fitted_values = y_quad_gen_pred
    title_suffix = "General Quadratic Fit"
else:
    residuals = y_ablated_pred - y_quad_req_pred
    fitted_values = y_quad_req_pred
    title_suffix = "Requested Quadratic Fit"

plt.figure(figsize=(8, 6))
plt.scatter(fitted_values, residuals, alpha=0.5)
plt.hlines(0, xmin=min(fitted_values), xmax=max(fitted_values), colors='red', linestyles='--')
plt.title(f'Residual Plot for {title_suffix}')
plt.xlabel('Fitted Values')
plt.ylabel('Residuals (Actual Ablated - Fitted)')
plt.grid(True, alpha=0.3)
plt.show()

# Heatmap of residuals over the (a, b) grid; cells with no test sample stay NaN and are masked.
residual_heatmap = np.full((N, N), np.nan)
num_samples = len(y_ablated_pred)
a_idx = np.asarray(original_a_values, dtype=int)
b_idx = np.asarray(original_b_values, dtype=int)
residual_heatmap[a_idx, b_idx] = residuals

plt.figure(figsize=(7, 6))
max_abs_resid = np.nanmax(np.abs(residual_heatmap))
# Fall back to a unit colour range when there are no residuals or all are zero
if np.isnan(max_abs_resid) or max_abs_resid == 0:
    max_abs_resid = 1.0

cmap_resid = "coolwarm"
sns.heatmap(residual_heatmap, cmap=cmap_resid, center=0,
            annot=True, fmt=".2f",
            linewidths=.5, linecolor='lightgray',
            square=True,
            vmin=-max_abs_resid, vmax=max_abs_resid,
            cbar_kws={'label': 'Residual (Actual Ablated - Fitted)'},
            mask=np.isnan(residual_heatmap))
plt.title(f"Residuals of {title_suffix}")
plt.xlabel("Input b")
plt.ylabel("Input a")
plt.xticks(np.arange(N) + 0.5, np.arange(N))
plt.yticks(np.arange(N) + 0.5, np.arange(N))
plt.gca().invert_yaxis()  # put a=0 at the bottom
plt.show()



Unfortunately, our hypothesis that the quadratic PC2 term is simply a correction on top of the linear output was incorrect, so its contribution must be more subtle. To try to understand it, we will now compare the neuron activations of the ablated and the full model, to see how removing PC2 affects their firing.

In [None]:
# --- Storage for Activations ---
# storage key (e.g. 'full' / 'ablated') -> captured final-position activations
activation_store = dict()

def store_activation_hook(activation, hook, storage_key):
    """Forward hook: record the last-position slice of `activation`, then pass it through.

    `activation[:, -1, :]` is detached and cloned before being saved in the
    module-level `activation_store` under `storage_key`. The saved copy therefore
    carries no autograd graph and does not alias the live tensor. `hook` is unused.
    The original tensor is returned unchanged, so the forward pass is not modified.
    """
    last_position = activation[:, -1, :]
    activation_store[storage_key] = last_position.detach().clone()
    return activation

# --- Execute and Capture Activations ---
# Two forward passes over `test_input_tokens`: the unmodified model, then the model
# with `ablate_pc2_hook` applied at `embedding_hook_point`. Each pass stores the
# final-position block-0 MLP post-activations via `store_activation_hook`.
a_mlp_full = None
a_mlp_ablated = None
mlp_activation_hook_point = utils.get_act_name("mlp_post", 0)
print(mlp_activation_hook_point)
if mlp_activation_hook_point:
    model.eval() # Ensure model is in evaluation mode
    with torch.no_grad():
        # --- Run 1: Get Full Model Activations ---
        print("Running model normally to capture full MLP activations...")
        activation_store.clear() # Clear previous storage
        _ = model.run_with_hooks(
            test_input_tokens,
            fwd_hooks=[(mlp_activation_hook_point,
                        lambda act, hook: store_activation_hook(act, hook, 'full'))]
        )
        if 'full' in activation_store:
            a_mlp_full = activation_store['full'].cpu().numpy()
            print(f"Captured full activations, shape: {a_mlp_full.shape}")
        else:
            print("Failed to capture full activations.")

        # --- Run 2: Get Ablated Model Activations ---
        print("Running model with PC2 ablation to capture ablated MLP activations...")
        activation_store.clear() # Clear previous storage
        _ = model.run_with_hooks(
            test_input_tokens,
            fwd_hooks=[(embedding_hook_point, ablate_pc2_hook), # Ablate PC2 first
                       (mlp_activation_hook_point,
                        lambda act, hook: store_activation_hook(act, hook, 'ablated'))]
        )
        if 'ablated' in activation_store:
            a_mlp_ablated = activation_store['ablated'].cpu().numpy()
            print(f"Captured ablated activations, shape: {a_mlp_ablated.shape}")
        else:
            print("Failed to capture ablated activations.")


def _plot_mean_delta_by_correlation(mean_delta, sort_idx, bar_colors, legend_handles, title):
    """Bar-plot per-neuron mean activation differences, neurons reordered by `sort_idx`.

    `sort_idx` orders neurons by their correlation with a+b; `bar_colors` must be
    given in that same order.
    """
    plt.figure(figsize=(15, 6))
    plt.bar(range(len(sort_idx)), mean_delta[sort_idx], color=bar_colors, width=1.0)
    plt.xlabel("Neurons (Sorted by Correlation with Sum a+b)")
    plt.ylabel("Mean Activation Diff (Full - Ablated)")
    plt.title(title)
    plt.xticks([]) # Too many neurons to label
    plt.grid(axis='y', alpha=0.5)
    plt.legend(handles=legend_handles)
    plt.show()


def _analyze_failure_group(a_value, delta, failure_mask, a_column, sort_idx, bar_colors, legend_handles):
    """Plot the mean activation difference over failure cases whose input a equals `a_value`.

    Returns the per-neuron mean difference, or None if there are no such failures.
    """
    group_mask = failure_mask & (a_column == a_value)
    num_group_failures = int(group_mask.sum())
    if num_group_failures == 0:
        print(f"No failure cases found where a={a_value}.")
        return None
    print(f"\nAnalyzing {num_group_failures} failure cases where a={a_value}...")
    mean_delta_group = np.mean(delta[group_mask], axis=0)
    _plot_mean_delta_by_correlation(
        mean_delta_group, sort_idx, bar_colors, legend_handles,
        f"Mean MLP Activation Difference on Failure Cases with a={a_value} ({num_group_failures} samples)")
    return mean_delta_group


# --- Calculate Difference and Analyze ---
if a_mlp_full is None or a_mlp_ablated is None:
    print("Skipping activation difference analysis because activations were not captured.")
elif a_mlp_full.shape != a_mlp_ablated.shape:
    print("Error: Shape mismatch between full and ablated activations.")
    print(f"Full shape: {a_mlp_full.shape}, Ablated shape: {a_mlp_ablated.shape}")
else:
    print("\nCalculating activation differences...")
    delta_a_mlp = a_mlp_full - a_mlp_ablated # Shape: [num_samples, d_mlp]

    # --- Analysis Focused on Failures ---
    # NOTE(review): assumes results_df rows are in the same order as the samples in
    # test_input_tokens -- confirm.
    failure_mask = np.asarray(results_df['Failed'].values, dtype=bool)
    num_failures = int(failure_mask.sum())

    if num_failures == 0:
        print("No failure cases detected to analyze activation differences.")
    else:
        print(f"Analyzing {num_failures} failure cases...")
        delta_a_mlp_failures = delta_a_mlp[failure_mask] # Shape: [num_failures, d_mlp]

        # --- 1. Which neurons change the most on average during failure? ---
        mean_abs_delta_failures = np.mean(np.abs(delta_a_mlp_failures), axis=0)
        sorted_neuron_indices_by_delta = np.argsort(mean_abs_delta_failures)[::-1]

        print("\nTop 10 Neurons by Mean Absolute Activation Change During Failures:")
        for idx in sorted_neuron_indices_by_delta[:10]:
            print(f"  Neuron {idx}: Mean Abs Delta = {mean_abs_delta_failures[idx]:.4f} (Corr={neuron_sum_correlations[idx]:.3f})")

        # --- 2. Average delta over ALL failures, neurons grouped by correlation sign ---
        mean_delta_failures = np.mean(delta_a_mlp_failures, axis=0) # Shape: [d_mlp,]
        corr_sorted_indices = np.argsort(neuron_sum_correlations)
        mean_delta_failures_sorted = mean_delta_failures[corr_sorted_indices]
        correlations_sorted = neuron_sum_correlations[corr_sorted_indices]
        colors = ['red' if c < 0 else 'blue' for c in correlations_sorted]

        from matplotlib.lines import Line2D
        legend_elements = [Line2D([0], [0], color='blue', lw=4, label='Positively Correlated Neurons'),
                           Line2D([0], [0], color='red', lw=4, label='Negatively Correlated Neurons')]
        _plot_mean_delta_by_correlation(
            mean_delta_failures, corr_sorted_indices, colors, legend_elements,
            f"Mean MLP Activation Difference on Failure Cases ({num_failures} samples)")

        # --- 3/4. Failures at the edges of the input range (a=0 and a=9) ---
        a_column = results_df['a'].to_numpy()
        mean_delta_a0_failures = _analyze_failure_group(
            0, delta_a_mlp, failure_mask, a_column, corr_sorted_indices, colors, legend_elements)
        mean_delta_a9_failures = _analyze_failure_group(
            9, delta_a_mlp, failure_mask, a_column, corr_sorted_indices, colors, legend_elements)


There are some patterns here: the changes to the activations seem to be grouped by correlation value, again pointing towards the 'groups of neurons' hypothesis. Let's perform a similar analysis for the pre-ReLU values.

In [None]:
# --- Configuration ---
# Hook on the block-0 MLP input, i.e. the values BEFORE the ReLU.
mlp_pre_activation_hook_point = "blocks.0.mlp.hook_pre"
print(f"Using MLP pre-activation hook point: {mlp_pre_activation_hook_point}")

# --- Storage ---
# storage key ('full' / 'ablated') -> captured final-position pre-activations
pre_activation_store = dict()

# --- Hook Function ---
def store_pre_activation_hook(activation, hook, storage_key):
    """Save a detached copy of the last-position pre-activations; pass `activation` through.

    Writes `activation[:, -1, :]` (detached and cloned) into the module-level
    `pre_activation_store` under `storage_key`. `hook` is unused. The input tensor
    is returned unchanged, so the forward pass is not modified.
    """
    snapshot = activation[:, -1, :].detach().clone()
    pre_activation_store[storage_key] = snapshot
    return activation

# --- Execute and Capture Activations ---
z_mlp_full = None
z_mlp_ablated = None

model.eval()
with torch.no_grad():

    # --- Run 1: Get Full Model Pre-Activations ---
    print(f"\nRunning model normally, hooking: {mlp_pre_activation_hook_point}")
    pre_activation_store.clear()
    fwd_hooks_full = [(
        mlp_pre_activation_hook_point,
        partial(store_pre_activation_hook, storage_key='full') # Use partial
    )]
    _ = model.run_with_hooks(test_input_tokens, fwd_hooks=fwd_hooks_full)
    if 'full' in pre_activation_store:
        z_mlp_full = pre_activation_store['full'].cpu().numpy()
        print(f"Captured full pre-activations, shape: {z_mlp_full.shape}")
    else:
        print(f"FAILED to capture full pre-activations using hook: {mlp_pre_activation_hook_point}")

    # --- Run 2: Get Ablated Model Pre-Activations ---
    if z_mlp_full is not None: # Proceed only if first run succeeded
        print(f"\nRunning model with PC2 ablation, hooking: {mlp_pre_activation_hook_point}")
        pre_activation_store.clear()
        fwd_hooks_ablated = [
            (embedding_hook_point, ablate_pc2_hook), # Ablate PC2
            (mlp_pre_activation_hook_point, partial(store_pre_activation_hook, storage_key='ablated')) # Store pre-activation
        ]
        _ = model.run_with_hooks(test_input_tokens, fwd_hooks=fwd_hooks_ablated)
        if 'ablated' in pre_activation_store:
            z_mlp_ablated = pre_activation_store['ablated'].cpu().numpy()
            print(f"Captured ablated pre-activations, shape: {z_mlp_ablated.shape}")
        else:
            print(f"FAILED to capture ablated pre-activations using hook: {mlp_pre_activation_hook_point}")


# --- Calculate Difference and Analyze ---
if z_mlp_full is not None and z_mlp_ablated is not None and z_mlp_full.shape == z_mlp_ablated.shape:
    print("\nCalculating pre-activation differences (delta_z_mlp)...")
    delta_z_mlp = z_mlp_full - z_mlp_ablated

    failure_mask = results_df['Failed'].values
    num_failures = failure_mask.sum()

    if num_failures > 0:
        print(f"Analyzing {num_failures} failure cases...")
        delta_z_mlp_failures = delta_z_mlp[failure_mask]

        # --- 1. Top Changing Neurons ---
        mean_abs_delta_z_failures = np.mean(np.abs(delta_z_mlp_failures), axis=0)
        sorted_neuron_indices_by_delta_z = np.argsort(mean_abs_delta_z_failures)[::-1]
        print("\nTop 10 Neurons by Mean Absolute Pre-Activation Change During Failures:")
        for i in range(min(10, d_mlp)):
            idx = sorted_neuron_indices_by_delta_z[i]
            print(f"  Neuron {idx}: Mean Abs Delta_Z = {mean_abs_delta_z_failures[idx]:.4f} (Corr={neuron_sum_correlations[idx]:.3f})")

        # --- 2. Average Delta_Z for ALL Failures ---
        mean_delta_z_failures = np.mean(delta_z_mlp_failures, axis=0)
        corr_sorted_indices = np.argsort(neuron_sum_correlations)
        mean_delta_z_failures_sorted = mean_delta_z_failures[corr_sorted_indices]
        correlations_sorted = neuron_sum_correlations[corr_sorted_indices]
        plt.figure(figsize=(15, 6))
        colors = ['red' if c < 0 else 'blue' for c in correlations_sorted]
        plt.bar(range(d_mlp), mean_delta_z_failures_sorted, color=colors, width=1.0)
        plt.xlabel("Neurons (Sorted by Correlation with Sum a+b)")
        plt.ylabel("Mean Pre-Activation Diff (Full - Ablated)")
        plt.title(f"Mean MLP Pre-Activation Difference (delta_z_mlp) on Failure Cases ({num_failures} samples)")
        plt.xticks([])
        plt.grid(axis='y', alpha=0.5)
        from matplotlib.lines import Line2D
        legend_elements = [Line2D([0], [0], color='blue', lw=4, label='Positively Correlated Neurons'),
                           Line2D([0], [0], color='red', lw=4, label='Negatively Correlated Neurons')]
        plt.legend(handles=legend_elements)
        plt.show()

        # --- 3. Average Delta_Z for Specific Failure Groups (a=0) ---
        a0_failure_mask = failure_mask & (results_df['a'] == 0)
        num_a0_failures = a0_failure_mask.sum()
        if num_a0_failures > 0:
            print(f"\nAnalyzing {num_a0_failures} failure cases where a=0...")
            delta_z_mlp_a0_failures = delta_z_mlp[a0_failure_mask]
            mean_delta_z_a0_failures = np.mean(delta_z_mlp_a0_failures, axis=0)
            mean_delta_z_a0_failures_sorted = mean_delta_z_a0_failures[corr_sorted_indices]
            plt.figure(figsize=(15, 6))
            plt.bar(range(d_mlp), mean_delta_z_a0_failures_sorted, color=colors, width=1.0)
            plt.xlabel("Neurons (Sorted by Correlation with Sum a+b)")
            plt.ylabel("Mean Pre-Activation Diff (Full - Ablated)")
            plt.title(f"Mean MLP Pre-Activation Difference (delta_z_mlp) on Failure Cases with a=0 ({num_a0_failures} samples)")
            plt.xticks([])
            plt.grid(axis='y', alpha=0.5)
            plt.legend(handles=legend_elements)
            plt.show()
        else:
            print("No failure cases found where a=0.")

    else:
        print("No failure cases detected.")
else:
    print("\nAnalysis skipped: Could not capture both full and ablated pre-activations.")


In the ablated model, are the pre-ReLU values still linear w.r.t. the sum, as is the case in the full model? We've established that linearity does not hold for the full model output, but we haven't yet checked the pre-ReLU stage. If the ablated pre-ReLU values are linear, then the non-linearity must come from the ReLU; otherwise, it arises earlier. Let's do this fit and plot it below.

In [None]:
print("\nFitting PRE-ReLU ablated activations (z_mlp_ablated) vs. True Sum...")

# --- Check if data is available ---
if 'z_mlp_ablated' not in locals() or z_mlp_ablated is None:
    print("Error: 'z_mlp_ablated' not found. Please run the previous step first.")
else:
    # --- Prepare Data for Linear Regression ---
    # X: Independent variable (True Sum), reshaped to a single-feature column
    X_true_sum = test_labels_np.reshape(-1, 1)
    num_samples = X_true_sum.shape[0]

    # Per-neuron fit results; NaN marks neurons whose activation is (near-)constant
    neuron_r2_scores = np.zeros(d_mlp)
    neuron_slopes = np.zeros(d_mlp)
    neuron_intercepts = np.zeros(d_mlp)

    # --- Fit Linear Model for Each Neuron ---
    print(f"Fitting linear model for {d_mlp} neurons...")
    for i in range(d_mlp):
        # y: Dependent variable (Pre-ReLU activation of neuron i)
        y_neuron_z_ablated = z_mlp_ablated[:, i]

        # R² is undefined for a constant target - record NaN and skip
        if np.std(y_neuron_z_ablated) < 1e-9:
            neuron_r2_scores[i] = np.nan
            neuron_slopes[i] = np.nan
            neuron_intercepts[i] = np.nan
            continue

        # Fit model: z_i ≈ m_i * (a+b) + c_i
        lin_reg_neuron = LinearRegression()
        lin_reg_neuron.fit(X_true_sum, y_neuron_z_ablated)

        # Store results (R² is in-sample)
        neuron_r2_scores[i] = lin_reg_neuron.score(X_true_sum, y_neuron_z_ablated)
        neuron_slopes[i] = lin_reg_neuron.coef_[0]
        neuron_intercepts[i] = lin_reg_neuron.intercept_

    print("Fitting complete.")

    # --- Analyze R-squared Distribution ---
    valid_r2_scores = neuron_r2_scores[~np.isnan(neuron_r2_scores)]

    if len(valid_r2_scores) > 0:
        print("\n--- Analysis of R-squared (z_mlp_ablated vs True Sum) ---")
        print(f"Mean R²:   {np.mean(valid_r2_scores):.4f}")
        print(f"Median R²: {np.median(valid_r2_scores):.4f}")
        print(f"Min R²:    {np.min(valid_r2_scores):.4f}")
        print(f"Max R²:    {np.max(valid_r2_scores):.4f}")
        for r2_threshold in (0.90, 0.95, 0.99):
            print(f"Neurons with R² > {r2_threshold:.2f}: {np.sum(valid_r2_scores > r2_threshold)} / {len(valid_r2_scores)}")

        # --- Visualize R-squared Distribution ---
        plt.figure(figsize=(10, 6))
        sns.histplot(valid_r2_scores, bins=50, kde=False)
        plt.title("Distribution of R²: Linear Fit of Pre-ReLU Ablated Activations vs. True Sum")
        plt.xlabel("R-squared Score per Neuron")
        plt.ylabel("Number of Neurons")
        plt.grid(axis='y', alpha=0.5)
        plt.show()
    else:
        # Previously this case printed nothing; report it like the sibling cells do.
        print("Could not calculate R-squared for any neurons pre-ReLU.")


A significant number of neurons are very strongly correlated with the sum! I am not sure, though, how this number compares to the same number in the full model. Nonetheless, let's do the same analysis for the post-ReLU values. Is the ReLU killing this linear relation?

In [None]:
print("\nFitting POST-ReLU ablated activations (a_mlp_ablated) vs. True Sum...")

# --- Check if data is available ---
if 'a_mlp_ablated' not in locals() or a_mlp_ablated is None:
    print("Error: 'a_mlp_ablated' not found. Please ensure post-ReLU ablated activations were captured.")
elif 'neuron_r2_scores' not in locals() or neuron_r2_scores is None:
    print("Error: 'neuron_r2_scores' (from pre-ReLU fit) not found. Please run the previous step first.")
else:
    # --- Prepare Data ---
    X_true_sum = test_labels_np.reshape(-1, 1)
    num_samples = X_true_sum.shape[0]

    # Per-neuron POST-ReLU fit results; NaN marks (near-)constant activations
    neuron_r2_scores_post = np.zeros(d_mlp)
    neuron_slopes_post = np.zeros(d_mlp)
    neuron_intercepts_post = np.zeros(d_mlp)

    # --- Fit Linear Model for Each Neuron (POST-ReLU) ---
    print(f"Fitting linear model for {d_mlp} neurons (post-ReLU)...")
    for i in range(d_mlp):
        # y: Dependent variable (POST-ReLU activation of neuron i)
        y_neuron_a_ablated = a_mlp_ablated[:, i]

        # Constant activation (e.g., always zero post-ReLU): R² undefined
        if np.std(y_neuron_a_ablated) < 1e-9:
            neuron_r2_scores_post[i] = np.nan
            neuron_slopes_post[i] = np.nan
            neuron_intercepts_post[i] = np.nan
            continue

        # Fit model: a_i ≈ m'_i * (a+b) + c'_i
        lin_reg_neuron_post = LinearRegression()
        lin_reg_neuron_post.fit(X_true_sum, y_neuron_a_ablated)

        # Store results (R² is in-sample)
        neuron_r2_scores_post[i] = lin_reg_neuron_post.score(X_true_sum, y_neuron_a_ablated)
        neuron_slopes_post[i] = lin_reg_neuron_post.coef_[0]
        neuron_intercepts_post[i] = lin_reg_neuron_post.intercept_

    print("Fitting complete.")

    # --- Analyze POST-ReLU R-squared Distribution ---
    valid_r2_scores_post = neuron_r2_scores_post[~np.isnan(neuron_r2_scores_post)]

    if len(valid_r2_scores_post) > 0:
        print("\n--- Analysis of R-squared (POST-ReLU a_mlp_ablated vs True Sum) ---")
        print(f"Mean R²:   {np.mean(valid_r2_scores_post):.4f}")
        print(f"Median R²: {np.median(valid_r2_scores_post):.4f}")
        print(f"Min R²:    {np.min(valid_r2_scores_post):.4f}")
        print(f"Max R²:    {np.max(valid_r2_scores_post):.4f}")
        for r2_threshold in (0.90, 0.95, 0.99):
            print(f"Neurons with R² > {r2_threshold:.2f}: {np.sum(valid_r2_scores_post > r2_threshold)} / {len(valid_r2_scores_post)}")

        # --- Visualize POST-ReLU R-squared Distribution ---
        plt.figure(figsize=(10, 6))
        sns.histplot(valid_r2_scores_post, bins=50, kde=False)
        plt.title("Distribution of R²: Linear Fit of POST-ReLU Ablated Activations vs. True Sum")
        plt.xlabel("R-squared Score per Neuron (Post-ReLU)")
        plt.ylabel("Number of Neurons")
        plt.grid(axis='y', alpha=0.5)
        plt.show()

        # --- Compare Pre-ReLU vs Post-ReLU R-squared ---
        # Pre - Post; NaN wherever either fit was skipped
        r2_diff = neuron_r2_scores - neuron_r2_scores_post
        valid_r2_diff = r2_diff[~np.isnan(r2_diff)]

        # np.min / np.max raise on an empty array, so guard the summary explicitly.
        if len(valid_r2_diff) > 0:
            print("\n--- Analysis of R-squared Change (Pre-ReLU R² - Post-ReLU R²) ---")
            print(f"Mean Difference:   {np.mean(valid_r2_diff):.4f} (Positive means ReLU reduced linearity)")
            print(f"Median Difference: {np.median(valid_r2_diff):.4f}")
            print(f"Min Difference:    {np.min(valid_r2_diff):.4f}") # Negative if the post fit is better
            print(f"Max Difference:    {np.max(valid_r2_diff):.4f}")

            # Visualize the distribution of the DIFFERENCE
            plt.figure(figsize=(10, 6))
            sns.histplot(valid_r2_diff, bins=50, kde=False)
            plt.title("Distribution of R² Change (Pre-ReLU Fit R² minus Post-ReLU Fit R²)")
            plt.xlabel("Difference in R-squared (Positive = ReLU decreased linearity)")
            plt.ylabel("Number of Neurons")
            plt.grid(axis='y', alpha=0.5)
            plt.axvline(0, color='red', linestyle='--', linewidth=1) # Mark zero difference
            plt.show()
        else:
            print("\nNo neurons have valid pre- and post-ReLU fits; skipping R² change analysis.")

        # Focus on neurons that were highly linear PRE-ReLU
        highly_linear_pre_mask = neuron_r2_scores > 0.99
        if np.any(highly_linear_pre_mask):
            r2_diff_highly_linear = r2_diff[highly_linear_pre_mask & ~np.isnan(r2_diff)]
            if len(r2_diff_highly_linear) > 0:
                print(f"\n--- R² Change for Neurons with Pre-ReLU R² > 0.99 ({len(r2_diff_highly_linear)} neurons) ---")
                print(f"Mean Difference:   {np.mean(r2_diff_highly_linear):.4f}")
                print(f"Median Difference: {np.median(r2_diff_highly_linear):.4f}")
            else:
                print("\nNo valid difference data for neurons with Pre-ReLU R² > 0.99 (check for NaNs in post-fit).")
        else:
            print("\nNo neurons found with Pre-ReLU R² > 0.99 based on previous results.")

    else:
        print("Could not calculate R-squared for any neurons post-ReLU.")


Ok, so the post-ReLU activations are much less strongly correlated with the sum. In fact, no neuron now has R² > 0.9 for a linear fit to the sum. This could simply be because the ReLU zeroes out the negative part, so the correlation is good only over the positive part, which results in an overall poor fit. Does the same thing happen in the full model, though? Does the goodness of a linear fit to the sum change significantly between pre- and post-ReLU?

In [None]:
print("\nFitting FULL model PRE- and POST-ReLU activations vs. True Sum...")

# Both full-model activation matrices must already exist from earlier cells.
if 'z_mlp_full' not in locals() or z_mlp_full is None:
    print("Error: 'z_mlp_full' (full model pre-ReLU activations) not found.")
elif 'a_mlp_full' not in locals() or a_mlp_full is None:
    print("Error: 'a_mlp_full' (full model post-ReLU activations) not found.")
else:
    # Single-feature design matrix: the true sum a+b for every sample.
    X_true_sum = test_labels_np.reshape(-1, 1)
    num_samples = X_true_sum.shape[0]

    def _r2_vs_true_sum(activation_column):
        """In-sample R² of a linear fit of one neuron's activations on the true sum.

        Returns NaN when the activations are (near-)constant, since R² is undefined there.
        """
        if np.std(activation_column) > 1e-9:
            fitted = LinearRegression().fit(X_true_sum, activation_column)
            return fitted.score(X_true_sum, activation_column)
        return np.nan

    print(f"Fitting linear models for {d_mlp} neurons (full model)...")
    neuron_r2_scores_pre_full = np.array([_r2_vs_true_sum(z_mlp_full[:, n]) for n in range(d_mlp)], dtype=float)
    neuron_r2_scores_post_full = np.array([_r2_vs_true_sum(a_mlp_full[:, n]) for n in range(d_mlp)], dtype=float)
    print("Fitting complete.")

    # Drop neurons whose fit was undefined before summarising.
    valid_r2_pre_full = neuron_r2_scores_pre_full[~np.isnan(neuron_r2_scores_pre_full)]
    valid_r2_post_full = neuron_r2_scores_post_full[~np.isnan(neuron_r2_scores_post_full)]

    if len(valid_r2_pre_full) == 0 or len(valid_r2_post_full) == 0:
        print("Could not calculate R-squared for enough neurons in the full model.")
    else:
        summary_sections = (
            ("\n--- Analysis of R² (FULL Model PRE-ReLU vs True Sum) ---", valid_r2_pre_full),
            ("\n--- Analysis of R² (FULL Model POST-ReLU vs True Sum) ---", valid_r2_post_full),
        )
        for heading, scores in summary_sections:
            print(heading)
            print(f"Mean R²:   {np.mean(scores):.4f}")
            print(f"Median R²: {np.median(scores):.4f}")
            print(f"Neurons with R² > 0.99: {np.sum(scores > 0.99)} / {len(scores)}")

        # Side-by-side histograms of pre- and post-ReLU linearity.
        fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
        panel_specs = (
            (valid_r2_pre_full,
             "FULL Model: Pre-ReLU Activation Linearity (R² vs Sum)",
             "R-squared Score per Neuron (Pre-ReLU)"),
            (valid_r2_post_full,
             "FULL Model: Post-ReLU Activation Linearity (R² vs Sum)",
             "R-squared Score per Neuron (Post-ReLU)"),
        )
        for panel_ax, (scores, panel_title, panel_xlabel) in zip(axes, panel_specs):
            sns.histplot(scores, bins=50, kde=False, ax=panel_ax)
            panel_ax.set_title(panel_title)
            panel_ax.set_xlabel(panel_xlabel)
            panel_ax.grid(axis='y', alpha=0.5)
        axes[0].set_ylabel("Number of Neurons")
        plt.suptitle("Comparison of Activation Linearity Before and After ReLU (FULL Model)")
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

        # How much linearity does the ReLU remove, neuron by neuron? (Pre - Post)
        r2_diff_full = neuron_r2_scores_pre_full - neuron_r2_scores_post_full
        valid_r2_diff_full = r2_diff_full[~np.isnan(r2_diff_full)]

        print("\n--- Analysis of R² Change (Pre-ReLU R² - Post-ReLU R²) for FULL Model ---")
        print(f"Mean Difference:   {np.mean(valid_r2_diff_full):.4f} (Positive means ReLU reduced linearity)")
        print(f"Median Difference: {np.median(valid_r2_diff_full):.4f}")

        plt.figure(figsize=(10, 6))
        sns.histplot(valid_r2_diff_full, bins=50, kde=False)
        plt.title("FULL Model: Distribution of R² Change (Pre-ReLU Fit R² minus Post-ReLU Fit R²)")
        plt.xlabel("Difference in R-squared (Positive = ReLU decreased linearity)")
        plt.ylabel("Number of Neurons")
        plt.grid(axis='y', alpha=0.5)
        plt.axvline(0, color='red', linestyle='--', linewidth=1)
        plt.show()

        # Optional side-by-side with the ablated-model results from the previous cell.
        ablated_results_available = 'neuron_r2_scores' in locals() and 'neuron_r2_scores_post' in locals()
        if not ablated_results_available:
            print("\nAblated model R-squared results not available for direct comparison.")
        else:
            print("\n--- Comparison to Ablated Model Linearity Drop ---")
            r2_diff_ablated = neuron_r2_scores - neuron_r2_scores_post
            valid_r2_diff_ablated = r2_diff_ablated[~np.isnan(r2_diff_ablated)]
            print(f"Mean R² Drop (Full Model):    {np.mean(valid_r2_diff_full):.4f}")
            print(f"Mean R² Drop (Ablated Model): {np.mean(valid_r2_diff_ablated):.4f}")
            print(f"Median R² Drop (Full Model):    {np.median(valid_r2_diff_full):.4f}")
            print(f"Median R² Drop (Ablated Model): {np.median(valid_r2_diff_ablated):.4f}")


Ok, so basically the same thing happens in the full model. Hence, this does not seem to be the answer. I suspect it is indeed because the ReLU zeroes out parts of the output, which means that you do not get a good correlation overall. Where do we go from here?

Let's check whether the pre-ReLU change introduced by PC2 is a simple function (e.g., linear or quadratic) of PC2 itself. To do this, we take the difference in pre-ReLU values between the full and ablated models (`delta_z_mlp`) and regress it, neuron by neuron, on the PC2 features of the inputs.

In [None]:
print("\nFitting delta_z_mlp (pre-ReLU change) vs. PC2 Features...")

# Both inputs are produced by earlier cells; report which one is missing.
if 'delta_z_mlp' not in locals() or delta_z_mlp is None:
    print("Error: 'delta_z_mlp' not found. Please calculate it first.")
elif 'X_pc2_features' not in locals() or X_pc2_features is None:
    print("Error: 'X_pc2_features' not found. Please extract PC2 features first.")
else:
    num_samples, num_neurons = delta_z_mlp.shape
    assert num_neurons == d_mlp

    # Degree-2 expansion without a bias column. Per the fit comments, the inputs are
    # [PC2(a), PC2(b)], which expands to 5 columns: a, b, a^2, a*b, b^2.
    poly = PolynomialFeatures(degree=2, include_bias=False)
    X_pc2_poly_features = poly.fit_transform(X_pc2_features)

    def _fit_r2(features, target):
        """In-sample R² of an ordinary least-squares fit of `target` on `features`."""
        return LinearRegression().fit(features, target).score(features, target)

    # NaN stays in place for neurons whose change is (near-)constant (R² undefined).
    neuron_r2_linear_fit = np.full(d_mlp, np.nan)
    neuron_r2_poly_fit = np.full(d_mlp, np.nan)

    print(f"Fitting models for {d_mlp} neurons...")
    for neuron in range(d_mlp):
        neuron_delta = delta_z_mlp[:, neuron]
        if np.std(neuron_delta) < 1e-9:
            continue
        # Linear: delta_z ≈ w_a*PC2(a) + w_b*PC2(b) + c
        neuron_r2_linear_fit[neuron] = _fit_r2(X_pc2_features, neuron_delta)
        # Quadratic: delta_z ≈ quadratic_func(PC2(a), PC2(b)) + c
        neuron_r2_poly_fit[neuron] = _fit_r2(X_pc2_poly_features, neuron_delta)

    print("Fitting complete.")

    valid_r2_linear = neuron_r2_linear_fit[~np.isnan(neuron_r2_linear_fit)]
    valid_r2_poly = neuron_r2_poly_fit[~np.isnan(neuron_r2_poly_fit)]

    if len(valid_r2_linear) == 0 or len(valid_r2_poly) == 0:
        print("Could not calculate R-squared for any neurons.")
    else:
        print("\n--- Analysis of R² (Fit delta_z_mlp vs PC2 Features) ---")

        fit_summaries = (
            ("\n--- Linear Fit ---", valid_r2_linear),
            ("\n--- Polynomial Fit (Degree 2) ---", valid_r2_poly),
        )
        for heading, scores in fit_summaries:
            print(heading)
            print(f"Mean R²:   {np.mean(scores):.4f}")
            print(f"Median R²: {np.median(scores):.4f}")
            print(f"Max R²:    {np.max(scores):.4f}")
            print(f"Neurons with R² > 0.90: {np.sum(scores > 0.90)} / {len(scores)}")
            print(f"Neurons with R² > 0.99: {np.sum(scores > 0.99)} / {len(scores)}")

        # Side-by-side histograms: linear vs quadratic goodness of fit.
        fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
        panel_specs = (
            (valid_r2_linear, "R²: Linear Fit of delta_z_mlp vs [PC2(a), PC2(b)]"),
            (valid_r2_poly, "R²: Quadratic Fit of delta_z_mlp vs [PC2(a), PC2(b)]"),
        )
        for panel_ax, (scores, panel_title) in zip(axes, panel_specs):
            sns.histplot(scores, bins=50, kde=False, ax=panel_ax)
            panel_ax.set_title(panel_title)
            panel_ax.set_xlabel("R-squared Score per Neuron")
            panel_ax.grid(axis='y', alpha=0.5)
        axes[0].set_ylabel("Number of Neurons")
        plt.suptitle("How well do PC2 features predict the pre-ReLU change (delta_z_mlp)?")
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()


So a linear fit is already very good; a quadratic fit improves it only slightly. There are a few neurons with an R² of roughly 0. I suspect they might be neurons that do not contribute to the computation anyway (but I don't know that for a fact).

Ok, so this indicates that the change in pre-ReLU values is a simple function of PC2. What does this tell us? Can we infer anything about the algorithm from here? PC2 is somehow directly being used to change the pre-ReLU values so that the 'right' neurons fire. But what are these 'right' neurons?

Could we perhaps check which neurons fire in the full model but no longer fire in the ablated model, for one of the cases where the ablated model fails by a lot (e.g., input (0,0), where its output is off by 3)? And perhaps try to understand how the effective W_L matrix makes use of these neurons? Perhaps this can give us a clue about what's happening.

More broadly, I think we still don't have a good picture of how many of these neurons are actually relevant. We have observed that tens of neurons have strong correlations with the sum, but we don't know if/how they are used by the model, as we don't know how W_L affects them. Any one of the following could hold:

1. They are treated roughly equally, i.e., similar weights in W_L, indicating that the algorithm is truly distributed
2. Only one/a few are not 'zeroed out', indicating that despite the fact that the transformer learns a distributed representation in the MLP, this is not actually used
3. The weights are severely input-dependent, indicating that some neurons are used for some input pairs and other neurons for other pairs
4. ...

Let's indeed check for neurons that change firing state in ablated model for the (0,0) input case.

In [None]:
# --- Ensure W_L is NumPy for consistency in this block ---
# W_L may be a torch tensor or array-like; everything below uses NumPy indexing.
if isinstance(W_L, torch.Tensor):
    print("Converting W_L tensor to NumPy array...")
    W_L_np = W_L.detach().cpu().numpy()
else:
    W_L_np = np.asarray(W_L) # Ensure it's a NumPy array if not a tensor

# Logits compared throughout: the correct answer for (0,0) and the answer the
# PC2-ablated model produces instead (3, as observed earlier in the notebook).
# The printed header below is derived from these so the two cannot disagree.
correct_logit_idx = 0
ablated_logit_idx = 3


def _report_flip_weights(indices, turned_on):
    """Print W_L weights on the correct/ablated logits for one group of flipping neurons.

    indices:   NumPy array of neuron indices in the group.
    turned_on: True for OFF -> ON neurons (active only in the full model), False for
               ON -> OFF neurons (active only in the ablated model).
    Reads the notebook globals W_L_np, correct_logit_idx and ablated_logit_idx.
    """
    if len(indices) == 0:
        print("  None")
        return

    w_correct = W_L_np[correct_logit_idx, indices]
    w_ablated = W_L_np[ablated_logit_idx, indices]

    if len(indices) == 1:
        print(f"  Weight for Neuron {indices[0]}:")
        print(f"    W_L[{correct_logit_idx}]: {w_correct.item():.4f}")
        print(f"    W_L[{ablated_logit_idx}]: {w_ablated.item():.4f}")
        return

    # Which weight sign favours the correct answer? A neuron that switches ON in the full
    # model helps when it raises the correct logit / lowers the ablated one. A neuron that
    # switches OFF helps when the opposite holds, because removing its output then does that.
    if turned_on:
        correct_op, ablated_op = ">", "<"
        frac_correct, frac_ablated = np.mean(w_correct > 0), np.mean(w_ablated < 0)
    else:
        correct_op, ablated_op = "<", ">"
        frac_correct, frac_ablated = np.mean(w_correct < 0), np.mean(w_ablated > 0)

    print(f"  Mean W_L weight for Logit {correct_logit_idx}: {np.mean(w_correct):.4f} (Std: {np.std(w_correct):.4f})")
    print(f"  Mean W_L weight for Logit {ablated_logit_idx}: {np.mean(w_ablated):.4f} (Std: {np.std(w_ablated):.4f})")
    print(f"  Proportion with W_L[{correct_logit_idx}] {correct_op} 0 : {frac_correct:.2f}")
    print(f"  Proportion with W_L[{ablated_logit_idx}] {ablated_op} 0 : {frac_ablated:.2f}")


def _flip_logit_deltas(indices, delta_a):
    """Return (dLogit[correct], dLogit[ablated]) = sum over `indices` of W_L[logit, n] * delta_a[n]."""
    if len(indices) == 0:
        return 0, 0
    return (
        np.sum(W_L_np[correct_logit_idx, indices] * delta_a[indices]),
        np.sum(W_L_np[ablated_logit_idx, indices] * delta_a[indices]),
    )


# --- Find index for input (0,0) ---
try:
    idx_00 = np.where((original_a_values == 0) & (original_b_values == 0))[0][0]
    print(f"Found sample index for input (0,0): {idx_00}")
except IndexError:
    print("Error: Input pair (0,0) not found in the test set data used for activations.")
    idx_00 = None

if idx_00 is not None:
    # --- Get pre-ReLU activations for (0,0) ---
    z_full_00 = z_mlp_full[idx_00]
    z_ablated_00 = z_mlp_ablated[idx_00]

    # --- Identify Flipping Neurons (active means pre-ReLU value > 0) ---
    ablated_active_mask = z_ablated_00 > 0
    full_active_mask = z_full_00 > 0

    off_to_on_mask = (~ablated_active_mask) & full_active_mask   # fires only in full model
    on_to_off_mask = ablated_active_mask & (~full_active_mask)   # fires only in ablated model

    off_to_on_indices = np.where(off_to_on_mask)[0]
    on_to_off_indices = np.where(on_to_off_mask)[0]

    num_off_to_on = len(off_to_on_indices)
    num_on_to_off = len(on_to_off_indices)
    print(f"\nAnalysis for input (0,0) [True={correct_logit_idx}, Ablated={ablated_logit_idx}]:")
    print(f"  Neurons flipped OFF -> ON : {num_off_to_on}")
    print(f"  Neurons flipped ON -> OFF : {num_on_to_off}")

    # --- Analyze W_L weights for flipping neurons ---
    print("\n-- Neurons Turning ON (OFF -> ON) --")
    _report_flip_weights(off_to_on_indices, turned_on=True)

    print("\n-- Neurons Turning OFF (ON -> OFF) --")
    _report_flip_weights(on_to_off_indices, turned_on=False)

    # --- Calculate total change in logit contribution from flipping ---
    a_full_00 = np.maximum(0, z_full_00)
    a_ablated_00 = np.maximum(0, z_ablated_00)
    delta_a_00 = a_full_00 - a_ablated_00 # full minus ablated, post-ReLU

    delta_logit_0_oto, delta_logit_3_oto = _flip_logit_deltas(off_to_on_indices, delta_a_00)
    delta_logit_0_otf, delta_logit_3_otf = _flip_logit_deltas(on_to_off_indices, delta_a_00)

    print("\n--- Approx. Logit Change from Flipping Neurons ---")
    print("  (Change = Sum[ W_L[logit, neuron] * (a_full - a_ablated) ] for flipping neurons)")
    print(f"  From OFF->ON : dLogit[{correct_logit_idx}]={delta_logit_0_oto:.3f}, dLogit[{ablated_logit_idx}]={delta_logit_3_oto:.3f}")
    print(f"  From ON->OFF : dLogit[{correct_logit_idx}]={delta_logit_0_otf:.3f}, dLogit[{ablated_logit_idx}]={delta_logit_3_otf:.3f}")
    print(f"  Total Flip Contribution: dLogit[{correct_logit_idx}]={delta_logit_0_oto + delta_logit_0_otf:.3f}, dLogit[{ablated_logit_idx}]={delta_logit_3_oto + delta_logit_3_otf:.3f}")


Let's try to interpret the results. We are looking at input (0,0). The expected output is 0, which the full model correctly outputs. The ablated model outputs 3. We are looking at which neurons change their activation status after ablating: either they fire in the full model but not in the ablated one, or vice versa. We find 25 such neurons, 24 of which go from firing to not firing. The one that turns on does not change the relevant logits (i.e., for 0 and for 3) at all. The others do! Mostly by massively increasing the logit for 0, and decreasing it for 3. Wait, shouldn't the signs be the other way around? Note that the change is computed as full − ablated. So a positive change for logit 0 and a negative change for logit 3 means these neurons push the full model towards 0 and away from 3. That is consistent with the full model predicting 0 and the ablated model predicting 3.

Also, perhaps it makes sense to look at the logit change for all possible outputs, not just 0 and 3? Perhaps this gives a clearer picture of the algorithm being implemented. Let's do so. Below we use the activation change of every neuron, not only the flipping ones.

In [None]:
# Full logit-change profile for input (0,0).
# It shows how the ablation-induced change in post-ReLU activations moves every output
# logit through the neuron-logit map W_L. All neurons are included, not just the
# flipping ones. This is an approximation: only the change routed through W_L is
# counted (delta_logits = W_L @ (a_full - a_ablated)).
# `locals().get` avoids a NameError when idx_00 was never defined, so we reach the
# "missing data" message instead.
if (locals().get('idx_00') is not None and 'z_mlp_full' in locals()
        and 'z_mlp_ablated' in locals() and 'W_L_np' in locals()):
    from matplotlib.lines import Line2D

    print("\nCalculating full logit change profile for input (0,0)...")

    # Logits to highlight. Reuse the indices from the flip analysis above when they
    # exist; otherwise fall back to 0 (correct answer) and 3 (ablated prediction).
    highlight_correct_idx = locals().get('correct_logit_idx', 0)
    highlight_ablated_idx = locals().get('ablated_logit_idx', 3)

    # --- Calculate Post-ReLU activation change for (0,0) ---
    z_full_00 = z_mlp_full[idx_00]
    z_ablated_00 = z_mlp_ablated[idx_00]
    a_full_00 = np.maximum(0, z_full_00)
    a_ablated_00 = np.maximum(0, z_ablated_00)
    delta_a_00 = a_full_00 - a_ablated_00  # Shape: [d_mlp,]

    # --- Calculate induced change for each logit ---
    # Matrix multiplication: (num_outputs, d_mlp) @ (d_mlp,) -> (num_outputs,)
    delta_logits_00 = W_L_np @ delta_a_00

    # --- Visualize the Logit Change Profile ---
    plt.figure(figsize=(12, 6))
    output_indices = np.arange(num_outputs)
    colors = ['green' if i == highlight_correct_idx else 'red' if i == highlight_ablated_idx else 'gray'
              for i in output_indices]
    plt.bar(output_indices, delta_logits_00, color=colors)

    plt.xlabel("Output Logit Index (k)")
    plt.ylabel("Approx. Change in Logit (Full - Ablated)")
    plt.title("Approx. Logit Change Profile for Input (0,0) due to PC2")
    plt.xticks(output_indices)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    # Bar charts have no automatic per-color legend, so build one by hand
    legend_elements = [Line2D([0], [0], color='green', lw=4, label=f'Logit {highlight_correct_idx} (Correct)'),
                       Line2D([0], [0], color='red', lw=4, label=f'Logit {highlight_ablated_idx} (Ablated Pred)'),
                       Line2D([0], [0], color='gray', lw=4, label='Other Logits')]
    plt.legend(handles=legend_elements)
    plt.axhline(0, color='black', linewidth=0.5)  # Zero line
    plt.show()

    print("\nLogit Change Values for (0,0):")
    for k, delta in enumerate(delta_logits_00):
        print(f"  Logit {k}: {delta:+.3f}")

else:
    print("Skipping full logit change analysis due to missing data.")


This is fascinating! The difference in logits looks like a cosine of period roughly k/2! This structure cannot be coincidental.

Let's now also look at what the logits are for the full model and the ablated model. This tells us what we're subtracting a cosine from, and what we obtain in result.

In [None]:
# Overlay the full-model and ablated-model logits for input (0,0).
# This shows what the delta profile from the previous cell is being subtracted from.
if not ('original_logits_pred_pos' in locals()
        and 'ablated_logits_pred_pos' in locals()
        and 'delta_logits_00' in locals()
        and idx_00 is not None):
    print("Error: Missing necessary logit data or index for (0,0).")
    print("Please ensure the initial model runs and the previous delta_logits calculation were successful.")
else:
    # Logit vectors (one entry per output class) for the (0,0) sample
    logits_full_00 = original_logits_pred_pos[idx_00]
    logits_ablated_00 = ablated_logits_pred_pos[idx_00]

    # Optional sanity check: delta_logits_00 should approximately match
    # logits_full_00 - logits_ablated_00.
    # print("Sanity Check: Max difference between calculated delta and (full - ablated):",
    #       np.max(np.abs(delta_logits_00 - (logits_full_00 - logits_ablated_00))))

    plt.figure(figsize=(14, 7))
    output_indices = np.arange(num_outputs)

    # Ablated first, then full, so the full-model line is drawn on top
    for series_values, series_marker, series_color, series_label in (
        (logits_ablated_00, 'o', 'orange', 'Ablated Logits (Input 0,0)'),
        (logits_full_00, 's', 'blue', 'Full Model Logits (Input 0,0)'),
    ):
        plt.plot(output_indices, series_values, marker=series_marker, linewidth=4,
                 color=series_color, label=series_label)

    # Reference lines at the correct answer (0) and the ablated model's prediction (3)
    for line_x, line_color, line_label in (
        (0, 'green', 'Correct Logit (0)'),
        (3, 'red', 'Ablated Prediction (3)'),
    ):
        plt.axvline(line_x, color=line_color, linestyle='-.', linewidth=1, label=line_label)

    plt.gca().set(xlabel="Output Logit Index (k)",
                  ylabel="Logit Value",
                  title="Logit Comparison for Input (0,0)")
    plt.xticks(output_indices)
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.axhline(0, color='black', linewidth=0.5)  # Zero line
    plt.show()


This is less obviously a clear function shape, but let us try to describe it.

I'd say the shape is still roughly similar for both models, and some semblance of periodicity is still present — although maybe the periodic part is just me squinting too hard.

They both have a dip at logit 9.

They both go to zero at logit 19.

The full model starts with its highest peak at 0, goes down until 5, has another brief peak at 7, goes down to 9, rises again from there to peak at 15, and goes down to 18... This looks like maybe an attenuated sinusoid? Maybe a sum of a few of them?

The ablated model also looks like something along these lines, but it's starting at roughly 0 instead of having its peak there...

Intriguing! Not sure where to go from here. Perhaps we can identify the top W_L * neuron_activation contributions for the full model, and plot these contributions to the logits separately, again for the (0,0) case? I am hypothesizing that what we are seeing is a sum of sinusoids, weighted so that they peak at the right places. By removing PC2 we kill some of them, resulting in a poorer approximation.

In [None]:
# Per-neuron decomposition of the full-model logits for input (0,0).
# Neuron i contributes W_L[k, i] * a_full_00[i] to logit k. We rank neurons by their
# total absolute contribution and plot the profiles of the strongest few.
if 'a_full_00' not in locals() or a_full_00 is None:
    print("Error: 'a_full_00' (full model activations for (0,0)) not found.")
elif 'W_L_np' not in locals() or W_L_np is None:
    print("Error: 'W_L_np' (NumPy version of W_L) not found.")
else:
    print("\nCalculating neuron contributions to logits for input (0,0)...")

    # Broadcasting scales column i of W_L by a_full_00[i]:
    # (num_outputs, d_mlp) * (d_mlp,) -> (num_outputs, d_mlp)
    neuron_logit_contributions_00 = W_L_np * a_full_00
    # Optional sanity check: summing over neurons (axis=1) should approximate logits_full_00.

    # Rank neurons, strongest first, by summed |contribution| over all logits -> shape [d_mlp,]
    total_abs_contribution = np.abs(neuron_logit_contributions_00).sum(axis=0)
    top_contributing_indices = np.flip(np.argsort(total_abs_contribution))

    num_top_to_plot = 5  # how many of the strongest neurons to report and plot

    print(f"\nTop {num_top_to_plot} contributing neurons for input (0,0):")
    for idx in top_contributing_indices[:num_top_to_plot]:
        print(f"  Neuron {idx}: Total Abs Contribution = {total_abs_contribution[idx]:.4f}")
        if 'neuron_sum_correlations' in locals():
            print(f"     (Correlation w/ sum: {neuron_sum_correlations[idx]:.3f})")

    # One dashed line per top neuron: its contribution to every output logit
    plt.figure(figsize=(14, 7))
    output_indices = np.arange(num_outputs)
    for neuron_idx in top_contributing_indices[:num_top_to_plot]:
        plt.plot(output_indices, neuron_logit_contributions_00[:, neuron_idx],
                 marker='.', linestyle='--',
                 label=f'Neuron {neuron_idx} Contribution (AbsSum={total_abs_contribution[neuron_idx]:.2f})')

    plt.gca().set(xlabel="Output Logit Index (k)",
                  ylabel="Contribution to Logit Value (W_L[k, i] * a_full_00[i])",
                  title=f"Contribution Profiles of Top {num_top_to_plot} Neurons for Input (0,0)")
    plt.xticks(output_indices)
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.axhline(0, color='black', linewidth=0.5)  # Zero line
    plt.tight_layout(rect=[0, 0, 0.85, 1])  # leave room on the right for the legend
    plt.show()

OK, this does not quite support my hypothesis. I looked at the top 20 neurons (set `num_top_to_plot = 20` in the cell above; it defaults to 5). 19 of them have a shape that looks quite similar to the output logits — the shape we described earlier, which seems kind of sinusoidal if you squint but... not quite! Then there's one neuron in this top 20 that seems to just be screwing things up: it peaks at 5 rather than 0, and starts negative at 0. That's neuron 398, I believe. The top 2 neurons, 357 and 184, are significantly 'stronger' than the rest. Are these maybe also (among) the top 2 neurons for other inputs?

This seems indeed very distributed. I wonder the following:

1. Do we see the same shape for other inputs? If we do, then we can maybe hypothesize what the function is that it is learning and try to fit it? Although it kind of looks like a non-trivial function...
2. What do the separate neuron contributions look like for the ablated model?

To investigate question 1, let's plot the logits for input (9,9) next to those for (0,0).

In [None]:
# Compare the full-model logit profile for (0,0) with that of another input pair,
# to check whether the logit "shape" is shared across inputs.
target_a, target_b = 9, 9
target_sum = target_a + target_b

# Look the sample up explicitly rather than relying on an IndexError.
# Unexpected errors should surface instead of being hidden by a blanket `except Exception`.
target_matches = np.flatnonzero((original_a_values == target_a) & (original_b_values == target_b))
if target_matches.size == 0:
    print(f"Error: Input pair ({target_a},{target_b}) not found in the test set data used.")
elif 'logits_full_00' not in locals():
    print("Error: 'logits_full_00' not found. Run the full vs. ablated logit comparison cell first.")
else:
    idx_target = target_matches[0]
    print(f"Found sample index for input ({target_a},{target_b}): {idx_target}")

    logits_full_target = original_logits_pred_pos[idx_target]

    # --- Create the Comparison Plot ---
    plt.figure(figsize=(14, 7))
    output_indices = np.arange(num_outputs)

    plt.plot(output_indices, logits_full_00, marker='s', linestyle='-', color='blue',
             label='Logits for (0,0) -> Target 0')
    plt.plot(output_indices, logits_full_target, marker='o', linestyle='--', color='green',
             label=f'Logits for ({target_a},{target_b}) -> Target {target_sum}')

    # Mark each input's correct answer in the matching color
    plt.axvline(0, color='blue', linestyle='-.', linewidth=1, label='Target for (0,0)')
    plt.axvline(target_sum, color='green', linestyle='-.', linewidth=1,
                label=f'Target for ({target_a},{target_b})')

    plt.xlabel("Output Logit Index (k)")
    plt.ylabel("Logit Value (Full Model)")
    plt.title("Logit Shape Comparison for Different Inputs")
    plt.xticks(output_indices)
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.axhline(0, color='black', linewidth=0.5)  # Zero line
    plt.show()


It's hard to tell whether it's the same shape. It seems like a non-trivial function, and I wouldn't know what to fit it to. I am still confused about why the difference in logits looks like a cosine, while the logits themselves don't obviously map to a sinusoid. Or could it be that each W_L * activation term is already a complex sum of sinusoids, with this being learnt somewhere earlier in the process? What is the easiest way to check this hypothesis? Maybe fit the logits of the full model to a sum of sines/cosines, i.e. a Fourier basis? Let's perform Fourier analysis on the full model logits for input (0,0) to try to verify this.

In [None]:
# Fourier analysis of the full-model logits for input (0,0).
# Which output-space frequencies dominate the logit profile, and how well do a few of
# them reconstruct it?
if 'logits_full_00' not in locals() or logits_full_00 is None:
    print("Error: 'logits_full_00' not found.")
else:
    print("\nPerforming Fourier Analysis on Full Model Logits for Input (0,0)...")

    # --- 1. Compute DFT ---
    # rfft: real input -> non-negative frequencies only (num_outputs // 2 + 1 coefficients)
    fourier_coeffs = np.fft.rfft(logits_full_00)
    # Frequencies in cycles per logit index; same length as fourier_coeffs
    frequencies = np.fft.rfftfreq(num_outputs)
    fourier_magnitudes = np.abs(fourier_coeffs)

    # --- 2. Rank frequencies by magnitude ---
    print("\nFourier Coefficient Magnitudes:")
    sorted_indices = np.argsort(fourier_magnitudes)[::-1]  # strongest first
    num_top_freqs_to_show = 10
    print(f"  Top {num_top_freqs_to_show} Frequencies by Magnitude:")
    for i in range(min(num_top_freqs_to_show, len(frequencies))):
        idx = sorted_indices[i]
        print(f"    Freq Index {idx} (Freq ≈ {frequencies[idx]:.3f}): Magnitude = {fourier_magnitudes[idx]:.4f}")

    # --- 3. Visualize magnitude spectrum ---
    plt.figure(figsize=(12, 6))
    plt.bar(range(len(fourier_magnitudes)), fourier_magnitudes, width=0.9)
    plt.xlabel("Frequency Index (k)")
    plt.ylabel("Magnitude of Fourier Coefficient")
    plt.title("Fourier Transform Magnitudes for Full Model Logits [Input (0,0)]")
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

    # --- 4. Reconstruction from DC + top-k non-DC frequencies ---
    k_to_check = 6  # number of non-DC frequencies kept for the reconstruction
    print(f"\nChecking variance explained by DC + top {k_to_check} non-DC frequencies...")

    # Remove DC from the ranking *before* taking the top k. That way exactly k_to_check
    # non-DC frequencies are kept regardless of where DC ranks. Slicing
    # sorted_indices[:k+1] and then adding DC would keep k+1 non-DC terms whenever DC is
    # not among the strongest.
    top_k_indices = sorted_indices[sorted_indices != 0][:k_to_check]
    indices_to_keep = np.union1d([0], top_k_indices)

    filtered_coeffs = np.zeros_like(fourier_coeffs)
    filtered_coeffs[indices_to_keep] = fourier_coeffs[indices_to_keep]

    # n=num_outputs is needed so irfft restores the original (possibly odd) length
    reconstructed_logits = np.fft.irfft(filtered_coeffs, n=num_outputs)

    # Fraction of variance explained (R-squared-like measure)
    variance_original = np.var(logits_full_00)
    variance_residual = np.var(logits_full_00 - reconstructed_logits)
    if variance_original > 1e-9:  # Avoid division by zero
        variance_explained = 1 - (variance_residual / variance_original)
        print(f"  Variance explained by DC + top {k_to_check} non-DC frequencies: {variance_explained:.4f}")
    else:
        print("  Cannot calculate variance explained (original variance is zero).")

    # Plot original vs. reconstruction
    plt.figure(figsize=(12, 6))
    output_indices = np.arange(num_outputs)
    plt.plot(output_indices, logits_full_00, marker='o', linestyle='-', label='Original Logits')
    plt.plot(output_indices, reconstructed_logits, marker='x', linestyle='--',
             label=f'Reconstructed (DC + Top {k_to_check} Freqs)')
    plt.title(f"Logit Reconstruction using DC + Top {k_to_check} Fourier Frequencies [Input (0,0)]")
    plt.xlabel("Output Logit Index (k)")
    plt.ylabel("Logit Value")
    plt.xticks(output_indices)
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()

The reconstruction using the top 3 frequencies already looks quite similar! However, it would actually produce the wrong output: the logit for 0 goes down slightly, and those for 1 and 2 both come up slightly, ending above it. The dip at 9 is also less pronounced. Adding more frequencies progressively improves the reconstruction, and with 6 we get the correct output. Let's see if we obtain similar results for other inputs, for example (4,4).

In [None]:
# Repeat the Fourier analysis of the full-model logits for a second input pair.
target_a, target_b = 4, 4
target_sum = target_a + target_b
input_label = f"({target_a},{target_b})"  # used in all prints/titles so they match the input

target_matches = np.flatnonzero((original_a_values == target_a) & (original_b_values == target_b))
assert target_matches.size > 0, f"Input pair {input_label} not found in the test set data used."
idx_target = target_matches[0]
print(f"Found sample index for input {input_label}: {idx_target}")

logits_full_target = original_logits_pred_pos[idx_target]

print(f"\nPerforming Fourier Analysis on Full Model Logits for Input {input_label}...")

# --- 1. Compute DFT ---
# rfft: real input -> non-negative frequencies only (num_outputs // 2 + 1 coefficients)
fourier_coeffs = np.fft.rfft(logits_full_target)
frequencies = np.fft.rfftfreq(num_outputs)  # cycles per logit index; same length as fourier_coeffs
fourier_magnitudes = np.abs(fourier_coeffs)

# --- 2. Rank frequencies by magnitude ---
print("\nFourier Coefficient Magnitudes:")
sorted_indices = np.argsort(fourier_magnitudes)[::-1]  # strongest first
num_top_freqs_to_show = 10
print(f"  Top {num_top_freqs_to_show} Frequencies by Magnitude:")
for i in range(min(num_top_freqs_to_show, len(frequencies))):
    idx = sorted_indices[i]
    print(f"    Freq Index {idx} (Freq ≈ {frequencies[idx]:.3f}): Magnitude = {fourier_magnitudes[idx]:.4f}")

# --- 3. Visualize magnitude spectrum ---
plt.figure(figsize=(12, 6))
plt.bar(range(len(fourier_magnitudes)), fourier_magnitudes, width=0.9)
plt.xlabel("Frequency Index (k)")
plt.ylabel("Magnitude of Fourier Coefficient")
plt.title(f"Fourier Transform Magnitudes for Full Model Logits [Input {input_label}]")
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

# --- 4. Reconstruction from DC + top-k non-DC frequencies ---
k_to_check = 6  # number of non-DC frequencies kept for the reconstruction
print(f"\nChecking variance explained by DC + top {k_to_check} non-DC frequencies...")

# Remove DC from the ranking *before* taking the top k, so exactly k_to_check non-DC
# frequencies are kept regardless of where DC ranks.
top_k_indices = sorted_indices[sorted_indices != 0][:k_to_check]
indices_to_keep = np.union1d([0], top_k_indices)

filtered_coeffs = np.zeros_like(fourier_coeffs)
filtered_coeffs[indices_to_keep] = fourier_coeffs[indices_to_keep]

# n=num_outputs is needed so irfft restores the original (possibly odd) length
reconstructed_logits = np.fft.irfft(filtered_coeffs, n=num_outputs)

# Fraction of variance explained (R-squared-like measure)
variance_original = np.var(logits_full_target)
variance_residual = np.var(logits_full_target - reconstructed_logits)
if variance_original > 1e-9:  # Avoid division by zero
    variance_explained = 1 - (variance_residual / variance_original)
    print(f"  Variance explained by DC + top {k_to_check} non-DC frequencies: {variance_explained:.4f}")
else:
    print("  Cannot calculate variance explained (original variance is zero).")

# Plot original vs. reconstruction
plt.figure(figsize=(12, 6))
output_indices = np.arange(num_outputs)
plt.plot(output_indices, logits_full_target, marker='o', linestyle='-', label='Original Logits')
plt.plot(output_indices, reconstructed_logits, marker='x', linestyle='--',
         label=f'Reconstructed (DC + Top {k_to_check} Freqs)')
plt.title(f"Logit Reconstruction using DC + Top {k_to_check} Fourier Frequencies [Input {input_label}]")
plt.xlabel("Output Logit Index (k)")
plt.ylabel("Logit Value")
plt.xticks(output_indices)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.show()

We hypothesize that the network implements a Fourier-like algorithm where the final logits are synthesized by summing trigonometric basis functions (sines/cosines) weighted by MLP neuron activations. To test the readout part of this hypothesis, we will analyze the effective neuron-logit weight matrix, W_L. If the hypothesis is correct, the columns of W_L (each corresponding to a single neuron's influence on all output logits) should themselves resemble discrete sine or cosine waves of specific frequencies. Therefore, we will compute the Discrete Fourier Transform (DFT) for each column (neuron) of W_L. We expect to find that for many neurons, especially those relevant to the computation, the power in their Fourier spectrum is concentrated at just one or a few low-frequency indices, indicating they represent specific Fourier basis components.


In [None]:
# Fourier analysis of the columns of W_L.
# If the readout synthesizes logits from a few trigonometric basis functions, each
# neuron's column (its weights onto all output logits) should have a sparse DFT.
if 'W_L_np' not in locals() or W_L_np is None:
    print("Error: 'W_L_np' (NumPy version of W_L) not found.")
else:
    # Neuron ordering for the heatmap. Missing correlations only disable sorting.
    # They must not skip the analysis, and must not leak a fake
    # `neuron_sum_correlations` into the namespace for later cells to trust.
    heatmap_is_sorted = 'neuron_sum_correlations' in locals()
    if heatmap_is_sorted:
        heatmap_sort_keys = neuron_sum_correlations
    else:
        print("Warning: 'neuron_sum_correlations' not found. Heatmap columns will not be sorted.")
        heatmap_sort_keys = np.arange(d_mlp)

    print("\nPerforming Fourier Analysis on columns (neurons) of W_L...")

    # --- DFT of every column at once (axis 0 = output logits) ---
    # Result shape: [W_L_np.shape[0] // 2 + 1, d_mlp]
    print(f"Calculating DFT for {d_mlp} neurons...")
    neuron_fourier_mags = np.abs(np.fft.rfft(W_L_np, axis=0))
    num_freqs = neuron_fourier_mags.shape[0]
    print("DFT calculation complete.")

    # --- Normalize magnitudes per neuron (max = 1) ---
    # Makes peaks comparable across neurons with different overall weight norms
    max_mags_per_neuron = np.max(neuron_fourier_mags, axis=0, keepdims=True)
    max_mags_per_neuron[max_mags_per_neuron < 1e-9] = 1.0  # all-zero columns: avoid 0/0
    neuron_fourier_mags_normalized = neuron_fourier_mags / max_mags_per_neuron

    # --- Heatmap (columns ordered by heatmap_sort_keys) ---
    corr_sorted_indices = np.argsort(heatmap_sort_keys)
    heatmap_data = neuron_fourier_mags_normalized[:, corr_sorted_indices]

    print("\nGenerating heatmap of normalized Fourier magnitudes...")
    plt.figure(figsize=(16, 8))
    # Alternative without per-neuron normalization:
    # sns.heatmap(neuron_fourier_mags[:, corr_sorted_indices], cmap='viridis', robust=True, cbar=True)
    sns.heatmap(heatmap_data, cmap='viridis', cbar=True,
                cbar_kws={'label': 'Normalized Magnitude (Max=1 per Neuron)'})
    plt.title("Normalized Fourier Magnitudes of W_L Columns (Neurons)")
    plt.xlabel("Neurons (Sorted by Correlation with Sum a+b)" if heatmap_is_sorted else "Neurons (Unsorted)")
    plt.ylabel("Frequency Index (k)")
    # Centre tick labels on heatmap rows; label with the frequency index k
    plt.yticks(np.arange(num_freqs) + 0.5, labels=np.arange(num_freqs), rotation=0)
    plt.xticks([])  # Hide dense neuron labels
    plt.show()

    # --- DFT bar plots for example neurons (top 2 contributors, if computed earlier) ---
    example_neuron_indices = []
    if 'top_contributing_indices' in locals():
        example_neuron_indices.extend(top_contributing_indices[:2].tolist())
    example_neuron_indices = sorted(set(example_neuron_indices))

    if example_neuron_indices:
        print(f"\nPlotting DFT magnitudes for example neurons: {example_neuron_indices}")
        n_examples = len(example_neuron_indices)
        # squeeze=False always returns a 2-D array of Axes, even for a single subplot
        fig, axes = plt.subplots(1, n_examples, figsize=(5 * n_examples, 4),
                                 sharey=True, squeeze=False)

        for i, (ax, neuron_idx) in enumerate(zip(axes[0], example_neuron_indices)):
            mags = neuron_fourier_mags[:, neuron_idx]  # un-normalized, to keep the scale
            ax.bar(range(num_freqs), mags, width=0.9)
            ax.set_title(f"Neuron {neuron_idx} W_L DFT Mags")
            ax.set_xlabel("Freq Index (k)")
            if i == 0:
                ax.set_ylabel("Magnitude")
            ax.grid(axis='y', linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.show()


In the heat map, frequency indices 1, 2 and 3 (roughly in that order) have the highest normalized magnitude for basically all neurons studied. Frequency 0 (the DC component) is largely irrelevant. Frequency indices greater than 3 are progressively less important.

Moving to the bar plots for two example neurons, the important frequencies are similar. 0 does not matter. 1 and 2 are the biggest (although 2 is biggest for 184 and 1 for 357).

Let us now see if there is frequency-specific structure in how the model represents inputs or reads out outputs. This is similar to Figure 3 of Nanda et al., "Progress measures for grokking via mechanistic interpretability" (2023).

So we will analyze the frequency components of the model's key weight matrices: the embedding matrix (W_E) and the effective neuron-logit map (W_L = W_U @ W_out). For W_E, we compute the Discrete Fourier Transform (DFT) along the input token dimension (0-9) for each embedding feature and calculate the norm of these Fourier coefficients across the embedding dimension for each frequency. This reveals if certain frequencies are preferentially used to represent the input tokens. For W_L, we compute the DFT along the output logit dimension (0-18) for each neuron's weights and calculate the norm across the neuron dimension for each frequency. This reveals if the readout mechanism is structured to preferentially operate on specific output frequencies. Observing sparsity (only a few frequencies with high magnitude) in either plot would suggest the use of a Fourier-like basis, complementing our PCA findings.



In [None]:
# Frequency content of the embedding (input side) and of the neuron-logit map
# (readout side), analogous to Figure 3 of Nanda et al. (2023).

# --- Ensure matrices are NumPy ---
W_E_np = W_E.detach().cpu().numpy() if isinstance(W_E, torch.Tensor) else np.asarray(W_E)
W_L_np = W_L.detach().cpu().numpy() if isinstance(W_L, torch.Tensor) else np.asarray(W_L)

# --- Analysis 1: Fourier Components of Embedding Matrix (W_E) ---
# The DFT must run over the N number tokens only. This assumes token id == digit value,
# so the number tokens are the first N rows. Any extra vocabulary rows (e.g. a separator
# token, if the model has one) are excluded, since they would distort the periodic DFT.
# When W_E has exactly N rows this slice is a no-op.
W_E_num_np = W_E_np[:N]
if W_E_np.shape[0] != N:
    print(f"Note: W_E has {W_E_np.shape[0]} rows; using the first N={N} (number tokens) for the DFT.")
print(f"\nAnalyzing Fourier components of W_E ({W_E_num_np.shape})...")

# DFT along the token axis (axis 0) for each embedding feature -> [N // 2 + 1, d_model]
w_e_fft_coeffs = np.fft.rfft(W_E_num_np, axis=0)
w_e_fft_mags = np.abs(w_e_fft_coeffs)

# L2 norm across the embedding dimension (d_model) for each frequency -> [N // 2 + 1,]
w_e_freq_norms = np.linalg.norm(w_e_fft_mags, axis=1)
input_frequencies = np.fft.rfftfreq(N)  # same length as w_e_freq_norms

print("W_E analysis complete.")

# --- Analysis 2: Fourier Components of Neuron-Logit Map (W_L) ---
print(f"\nAnalyzing Fourier components of W_L ({W_L_np.shape})...")

# DFT along the logit axis (axis 0) for each neuron's weights -> [num_output_freqs, d_mlp]
w_l_fft_coeffs = np.fft.rfft(W_L_np, axis=0)
w_l_fft_mags = np.abs(w_l_fft_coeffs)

# L2 norm across the neuron dimension (d_mlp) for each frequency -> [num_output_freqs,]
w_l_freq_norms = np.linalg.norm(w_l_fft_mags, axis=1)
output_frequencies = np.fft.rfftfreq(num_outputs)  # Frequencies for output size num_outputs

print("W_L analysis complete.")

# --- Plotting (Similar to Nanda et al. Figure 3) ---
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot for W_E
axes[0].bar(range(len(w_e_freq_norms)), w_e_freq_norms, width=0.9)
axes[0].set_title("Norm of Fourier Components of Embedding Matrix (W_E)")
axes[0].set_xlabel(f"Input Frequency Index (k), N={N}")
axes[0].set_ylabel("Norm across d_model")
axes[0].grid(axis='y', linestyle='--', alpha=0.7)

# Plot for W_L
axes[1].bar(range(len(w_l_freq_norms)), w_l_freq_norms, width=0.9)
axes[1].set_title("Norm of Fourier Components of Neuron-Logit Map (W_L)")
axes[1].set_xlabel(f"Output Frequency Index (k), P={num_outputs}")
axes[1].set_ylabel("Norm across d_mlp")
axes[1].grid(axis='y', linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()


We've seen that the influence of PC2 on the pre-ReLU activations (delta_z_mlp) is quite simple (linear/quadratic in PC2 features), but the baseline pre-ReLU activations without PC2 (z_mlp_ablated) were only imperfectly linear with the target sum. To get a clearer picture of the computation before the ReLU non-linearity in the full model, we will now attempt to directly model the complete pre-ReLU activations (z_mlp_full) as a function of both PC1 and PC2 features extracted from the input tokens a and b. Specifically, we'll fit a degree-2 polynomial model (z_mlp_full ≈ Poly2(PC1(a), PC1(b), PC2(a), PC2(b))) for each neuron. If this polynomial fit yields high R-squared values across most neurons, it suggests the MLP's computation before the ReLU is relatively straightforward and mathematically describable, isolating the main complexity to the ReLU step and the W_L readout.


In [None]:
print("\nFitting FULL model pre-ReLU activations (z_mlp_full) vs. Polynomial(PC1, PC2 features)...")

# Inputs come from earlier cells; report which one is missing instead of raising a NameError.
pcs_ready = (
    'pca_scores_np' in locals()
    and pca_scores_np is not None
    and pca_scores_np.shape[1] >= 2
)
z_full_ready = 'z_mlp_full' in locals() and z_mlp_full is not None

if not pcs_ready:
    print("Error: 'pca_scores_np' with at least 2 components not found.")
elif not z_full_ready:
    print("Error: 'z_mlp_full' not found. Please ensure full pre-ReLU activations were captured.")
else:
    # Per-token PC scores: column 0 is PC1, column 1 is PC2.
    pc1_values_all_tokens, pc2_values_all_tokens = pca_scores_np[:, 0], pca_scores_np[:, 1]

    # Per-sample features: the PC1 / PC2 score of operand a and of operand b.
    p1a, p1b = pc1_values_all_tokens[original_a_values], pc1_values_all_tokens[original_b_values]
    p2a, p2b = pc2_values_all_tokens[original_a_values], pc2_values_all_tokens[original_b_values]

    # Base design matrix with columns in the order (p1a, p1b, p2a, p2b) -> [num_samples, 4].
    # Individual operand features are kept (rather than pre-combined sums) for generality.
    X_features = np.stack((p1a, p1b, p2a, p2b), axis=-1)

    # Degree-2 expansion: a constant '1' column, the 4 linear terms, squares and pairwise products.
    poly = PolynomialFeatures(degree=2, include_bias=True)
    X_poly_features = poly.fit_transform(X_features)

    def poly2_r2_for(target):
        """In-sample R² of `target ~ X_poly_features`; NaN when `target` is (near-)constant."""
        if np.std(target) < 1e-9:
            return np.nan
        # No separate intercept: the PolynomialFeatures bias column already provides it.
        reg = LinearRegression(fit_intercept=False)
        reg.fit(X_poly_features, target)
        return reg.score(X_poly_features, target)

    print(f"Fitting polynomial model for {d_mlp} neurons...")
    # One R² per neuron; neuron k's target is its full pre-ReLU activation column.
    neuron_r2_scores_z_full_fit = np.array(
        [poly2_r2_for(z_mlp_full[:, neuron]) for neuron in range(d_mlp)], dtype=float
    )
    print("Fitting complete.")

    # Summaries only over neurons that could be fitted.
    valid_r2_scores = neuron_r2_scores_z_full_fit[~np.isnan(neuron_r2_scores_z_full_fit)]
    n_valid = len(valid_r2_scores)

    if n_valid == 0:
        print("Could not calculate R-squared for any neurons.")
    else:
        print("\n--- Analysis of R-squared (Fit z_mlp_full vs Poly2(PC1, PC2 features)) ---")
        # Labels are left-justified to a common width of 11 characters so the values line up.
        for label, stat in (("Mean R²:", np.mean), ("Median R²:", np.median),
                            ("Min R²:", np.min), ("Max R²:", np.max)):
            print(f"{label:<11}{stat(valid_r2_scores):.4f}")
        for cutoff in (0.95, 0.99, 0.999):
            print(f"Neurons with R² > {cutoff}: {np.sum(valid_r2_scores > cutoff)} / {n_valid}")

        # Histogram of per-neuron R².
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.histplot(valid_r2_scores, bins=50, kde=False, ax=ax)
        ax.set_title("R² Distribution: Fit of Pre-ReLU Activations (z_full) vs. Poly2(PC1, PC2 Features)")
        ax.set_xlabel("R-squared Score per Neuron")
        ax.set_ylabel("Number of Neurons")
        ax.grid(axis='y', alpha=0.5)
        plt.show()


**What This Tells Us (given R² values close to 1 for nearly all neurons):**

1.  **MLP Pre-ReLU Computation is Simple:** The core transformation performed by the MLP *before* the ReLU activation function (`z_mlp ≈ Poly2(PC1(a), PC1(b), PC2(a), PC2(b)) + Bias`) is **extremely well-described** by a degree-2 polynomial function of the PC1 and PC2 features derived from the inputs `a` and `b`. Within the accuracy of this fit, there are no significant hidden complexities or dependencies on other features happening *before* the ReLU.
2.  **PC1 and PC2 Suffice:** The top 2 Principal Components of the embeddings contain virtually all the information the MLP uses to calculate its pre-activations. The remaining embedding dimensions or PCs contribute negligible information to this stage.
3.  **Complexity Localized to ReLU and Readout:** This pinpoints exactly where the remaining complexity or "magic" must lie:
    *   **ReLU:** The piecewise linear `max(0, ...)` operation acting on these smooth polynomial pre-activations (`z_mlp_full`) is the primary source of non-linearity that makes the *post-ReLU* activations (`a_mlp_full`) and the final output (`k_ablated`) harder to model directly.
    *   **`W_L` Readout:** The structure of the `W_L` matrix (which we found uses a distributed, mixed low-frequency Fourier basis) determines how these post-ReLU activations `a_mlp_full` are interpreted and summed to form the final logits.

**Updated Concrete Algorithm Description:**

We can now state the algorithm with much higher confidence and specificity:

1.  **Embedding & Feature Extraction:**
    *   Map inputs `a, b` to embeddings `W_E(a), W_E(b)`.
    *   Extract the scalar features `PC1(a), PC1(b), PC2(a), PC2(b)` representing the projections onto the first two principal components (linear and quadratic features).

2.  **MLP Pre-Activation (Polynomial Computation):**
    *   For *each neuron `i`* in the MLP layer:
        `z_mlp_full[i] = Degree2Poly_i(PC1(a), PC1(b), PC2(a), PC2(b)) + Bias_i`
    *   Where `Degree2Poly_i` is a specific, learned quadratic polynomial function (with coefficients determined by the MLP's `W_in` weights for neuron `i`) involving terms up to degree 2 of the four input features (e.g., `PC1(a)`, `PC1(b)`, `PC2(a)`, `PC2(b)`, `PC1(a)^2`, `PC1(a)*PC1(b)`, `PC1(a)*PC2(a)`, etc.).
    *   **(This step is now known with high confidence).**

3.  **MLP Activation (ReLU):**
    *   `a_mlp_full[i] = max(0, z_mlp_full[i])`

4.  **Unembedding/Readout (Distributed Fourier Synthesis):**
    *   `Logits[k] = Sum_i ( W_L[k, i] * a_mlp_full[i] )`
    *   Where the columns `W_L[:, i]` represent combinations of low-frequency (k=1, 2, 3...) Fourier basis vectors. The dot product synthesizes the final logit shape.

Let's now look at what the coefficients of this polynomial actually are. Are they roughly the same for most/all neurons? Exactly *how* does the MLP make use of these linear and quadratic features?

In [None]:
print("\nAnalyzing coefficients of the Poly2 fit for z_mlp_full...")

# --- Check if data is available ---
# Each input comes from an earlier cell; report which one is missing instead of raising NameError.
if 'X_poly_features' not in locals() or X_poly_features is None:
    print("Error: 'X_poly_features' not found.")
elif 'z_mlp_full' not in locals() or z_mlp_full is None:
    print("Error: 'z_mlp_full' not found.")
elif 'poly' not in locals() or poly is None:
    print("Error: 'poly' (PolynomialFeatures object) not found.")
elif 'neuron_sum_correlations' not in locals() or neuron_sum_correlations is None:
    print("Error: 'neuron_sum_correlations' not found.")
else:
    num_samples, num_poly_feats = X_poly_features.shape
    num_samples_z, num_neurons = z_mlp_full.shape
    assert num_neurons == d_mlp
    assert num_samples == num_samples_z

    # --- Get Feature Names ---
    # Must match the column order used when the base feature matrix was stacked: (p1a, p1b, p2a, p2b).
    base_feature_names = ['p1a', 'p1b', 'p2a', 'p2b']
    poly_feature_names = poly.get_feature_names_out(base_feature_names)
    print(f"Polynomial features ({num_poly_feats}): {poly_feature_names}")

    # --- Fit Models and Extract Coefficients ---
    # Row i of learned_poly_coeffs belongs to neuron i (original neuron id). Rows of neurons with
    # (near-)constant activation are NaN and flagged False in neuron_fit_successful.
    print(f"Fitting models and extracting coefficients for {d_mlp} neurons...")
    learned_poly_coeffs = np.zeros((d_mlp, num_poly_feats))
    neuron_fit_successful = np.zeros(d_mlp, dtype=bool)

    for i in range(d_mlp):
        y_neuron_z_full = z_mlp_full[:, i]
        if np.std(y_neuron_z_full) < 1e-9:
            learned_poly_coeffs[i, :] = np.nan  # No variance: nothing to fit
            continue

        # fit_intercept=False: the PolynomialFeatures bias column ('1') plays the intercept role.
        poly_reg_z_full = LinearRegression(fit_intercept=False)
        poly_reg_z_full.fit(X_poly_features, y_neuron_z_full)
        learned_poly_coeffs[i, :] = poly_reg_z_full.coef_
        neuron_fit_successful[i] = True

    print("Coefficient extraction complete.")

    # Keep only neurons with a successful fit. valid_coeffs and valid_correlations are aligned:
    # row j of both refers to the same neuron. (Assumes neuron_sum_correlations has length d_mlp.)
    valid_coeffs = learned_poly_coeffs[neuron_fit_successful]
    valid_correlations = neuron_sum_correlations[neuron_fit_successful]
    num_valid_neurons = valid_coeffs.shape[0]
    print(f"Analyzing {num_valid_neurons} neurons with successful fits.")

    # --- Overall Analysis: Mean and Std of Coefficients ---
    mean_coeffs = np.nanmean(valid_coeffs, axis=0)
    std_coeffs = np.nanstd(valid_coeffs, axis=0)

    print("\n--- Overall Mean Polynomial Coefficients (across all valid neurons) ---")
    coeffs_df_overall = pd.DataFrame({
        'Feature': poly_feature_names,
        'MeanCoeff': mean_coeffs,
        'StdCoeff': std_coeffs,
        'MeanAbsCoeff': np.nanmean(np.abs(valid_coeffs), axis=0)  # Average magnitude, sign-agnostic
    }).sort_values('MeanAbsCoeff', ascending=False)
    print(coeffs_df_overall.to_string())

    # --- Visualize Overall Mean Coefficients ---
    plt.figure(figsize=(12, 7))
    indices = np.arange(num_poly_feats)
    plt.bar(indices, mean_coeffs, yerr=std_coeffs, capsize=5, alpha=0.7, ecolor='gray')
    plt.axhline(0, color='black', linewidth=0.5)
    plt.xticks(indices, poly_feature_names, rotation=45, ha='right')
    plt.ylabel("Coefficient Value")
    plt.title("Mean Polynomial Coefficients (+/- std dev) across All Neurons")
    plt.tight_layout()
    plt.show()

    # --- Group Analysis: Based on Correlation ---
    print("\n--- Mean Coefficients by Neuron Correlation Group ---")
    # Correlation groups (adjust thresholds as needed). The masks are computed on
    # valid_correlations, so they are ALREADY aligned with valid_coeffs and must not be
    # re-indexed with neuron_fit_successful (that raises IndexError once any fit is skipped).
    # Boundaries mirror each other: |corr| == 0.9 belongs to "High", |corr| == 0.5 to "Mid".
    groups = {
        "High Pos (>=0.9)": valid_correlations >= 0.9,
        "Mid Pos (0.5-0.9)": (valid_correlations >= 0.5) & (valid_correlations < 0.9),
        # "Low Mag (<0.5)": np.abs(valid_correlations) < 0.5,  # Optional extra group
        "Mid Neg (-0.9 - -0.5)": (valid_correlations <= -0.5) & (valid_correlations > -0.9),
        "High Neg (<=-0.9)": valid_correlations <= -0.9,
    }

    group_coeffs_mean = {}
    for name, mask in groups.items():
        if np.sum(mask) > 0:
            group_coeffs_mean[name] = np.nanmean(valid_coeffs[mask], axis=0)
        else:
            group_coeffs_mean[name] = np.full(num_poly_feats, np.nan)  # Empty group -> NaN column

    # Columns = groups, rows = polynomial features
    group_coeffs_df = pd.DataFrame(group_coeffs_mean, index=poly_feature_names)
    print(group_coeffs_df.to_string(float_format="%.4f"))

    # --- Visualize Group Coefficients (Heatmap) ---
    # Transpose so groups are rows and polynomial features are columns (matches the axis labels).
    plt.figure(figsize=(10, 8))
    sns.heatmap(group_coeffs_df.T, annot=True, fmt=".3f", cmap="coolwarm", center=0,
                linewidths=.5, linecolor='lightgray')
    plt.title("Mean Polynomial Coefficients for Different Neuron Correlation Groups")
    plt.xlabel("Polynomial Feature")
    plt.ylabel("Neuron Correlation Group")
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

Dominance of Linear PC1: The linear terms for PC1 (p1a, p1b) have the largest mean absolute coefficients (after the bias '1'). This confirms PC1 (related to the sum a+b) is a primary driver.

Symmetry (a vs b): The coefficients for p1a and p1b are extremely similar (both overall and within groups). Likewise for p2a vs p2b, p1a^2 vs p1b^2, p2a^2 vs p2b^2, etc. This strongly suggests the MLP treats a and b symmetrically, consistent with the addition task. The computation depends primarily on combinations like PC1(a)+PC1(b) and PC2(a)+PC2(b), not the individual values in an asymmetric way.

PC1 Sign Flipping: This is a major finding from the group analysis. Neurons with positive correlation (High Pos, Mid Pos) have large negative coefficients for the linear PC1 terms (p1a, p1b). Neurons with negative correlation (Mid Neg) have large positive coefficients for these terms. This seems counter-intuitive at first glance, but likely interacts with the bias and other terms. It clearly shows opposing strategies for neurons that correlate positively vs negatively with the sum.

PC2 Linear Term Consistency: The coefficients for the linear PC2 terms (p2a, p2b) are consistently negative across the groups where they are significant. This suggests that, on average, PC2 contributes a linear suppressing effect to the pre-activation, regardless of the neuron's correlation group.

Interaction Terms: The PC1 cross-term (p1a p1b) has a noticeable negative coefficient overall. The PC1×PC2 cross-terms (p1a p2a, p1a p2b, etc.) have smaller average magnitudes but do show sign differences between the positive and negative correlation groups, suggesting that PC2's influence might be partially exerted through interaction with the PC1 signal.

Quadratic Terms Minor (on average): Pure quadratic terms (p1a^2, p2a^2, etc.) have the smallest coefficients on average, suggesting they play a less dominant role than linear or interaction terms in the average recipe.


The standard deviations are large compared to the means, so presumably the learned representations are not the same for all neurons. Can we isolate this further? For example, we could look only at the high-correlation group, or even at just one or two high-correlation neurons.

In [None]:
# Requires the coefficient-analysis cell above (it defines learned_poly_coeffs,
# neuron_fit_successful and poly_feature_names) plus neuron_sum_correlations.
if ('learned_poly_coeffs' not in locals() or 'neuron_fit_successful' not in locals()
        or 'poly_feature_names' not in locals() or 'neuron_sum_correlations' not in locals()):
    print("Error: Polynomial coefficients not found. Run the coefficient analysis cell first.")
else:
    # Filter out neurons where fit failed (NaNs)
    valid_coeffs = learned_poly_coeffs[neuron_fit_successful]
    valid_correlations = neuron_sum_correlations[neuron_fit_successful]
    # valid_indices[j] is the original neuron id of row j of valid_coeffs
    valid_indices = np.where(neuron_fit_successful)[0]
    num_valid_neurons = valid_coeffs.shape[0]

    # --- 1. Analyze High Positive Correlation Group (>=0.9) ---
    high_pos_mask_valid = valid_correlations >= 0.9
    num_high_pos = np.sum(high_pos_mask_valid)

    if num_high_pos > 0:
        coeffs_high_pos = valid_coeffs[high_pos_mask_valid]
        mean_coeffs_high_pos = np.mean(coeffs_high_pos, axis=0)
        std_coeffs_high_pos = np.std(coeffs_high_pos, axis=0)

        print(f"\n--- Mean Polynomial Coefficients for High Positive Correlation Neurons (>=0.9, N={num_high_pos}) ---")
        coeffs_df_high_pos = pd.DataFrame({
            'Feature': poly_feature_names,
            'MeanCoeff': mean_coeffs_high_pos,
            'StdCoeff': std_coeffs_high_pos,
            'MeanAbsCoeff': np.mean(np.abs(coeffs_high_pos), axis=0)
        }).sort_values('MeanAbsCoeff', ascending=False)
        print(coeffs_df_high_pos.to_string())

        # Visualize
        plt.figure(figsize=(12, 7))
        indices = np.arange(len(poly_feature_names))
        plt.bar(indices, mean_coeffs_high_pos, yerr=std_coeffs_high_pos, capsize=5, alpha=0.7, ecolor='gray')
        plt.axhline(0, color='black', linewidth=0.5)
        plt.xticks(indices, poly_feature_names, rotation=45, ha='right')
        plt.ylabel("Coefficient Value")
        plt.title(f"Mean Polynomial Coefficients (+/- std dev) for High Positive Neurons (N={num_high_pos})")
        plt.tight_layout()
        plt.show()
    else:
        print("\nNo neurons found in the High Positive Correlation group (>=0.9).")

    # --- 2. Analyze Specific Neurons (e.g., 357, 184) ---
    # Prefer the top contributors found earlier; otherwise fall back to the two neurons with
    # the largest |correlation| with the sum.
    neuron_ids_to_check = []
    if 'top_contributing_indices' in locals() and len(top_contributing_indices) >= 2:
        neuron_ids_to_check = top_contributing_indices[:2]  # Use the previously found top 2
    else:
        print("Warning: top_contributing_indices not found, using highest correlation neurons instead.")
        if num_valid_neurons >= 2:
            sorted_corr_indices_valid = np.argsort(np.abs(valid_correlations))[::-1]
            neuron_ids_to_check = valid_indices[sorted_corr_indices_valid[:2]]  # Map back to original ids

    if len(neuron_ids_to_check) > 0:
        print(f"\n--- Coefficients for Specific Neurons: {neuron_ids_to_check} ---")
        specific_coeffs_dict = {}
        for neuron_idx in neuron_ids_to_check:
            # Position of this neuron inside the valid subset (empty if its fit was skipped)
            valid_idx_location = np.where(valid_indices == neuron_idx)[0]
            if len(valid_idx_location) > 0:
                valid_idx = valid_idx_location[0]
                # valid_idx indexes the VALID subset, so read from valid_coeffs
                # (learned_poly_coeffs is indexed by original neuron id instead).
                specific_coeffs_dict[f'Neuron {neuron_idx}'] = valid_coeffs[valid_idx]
            else:
                print(f"Warning: Neuron {neuron_idx} did not have a successful fit or is not in valid indices.")

        if specific_coeffs_dict:
            specific_coeffs_df = pd.DataFrame(specific_coeffs_dict, index=poly_feature_names)
            print(specific_coeffs_df.to_string(float_format="%.4f"))

            # Visualize comparison
            specific_coeffs_df.plot(kind='bar', figsize=(14, 7), alpha=0.7)
            plt.title("Polynomial Coefficients for Specific Neurons")
            plt.ylabel("Coefficient Value")
            plt.xticks(rotation=45, ha='right')
            plt.axhline(0, color='black', linewidth=0.5)
            plt.tight_layout()
            plt.show()


Given that the quadratic terms in our previous degree-2 polynomial fit had relatively small average coefficients, we now test if the pre-ReLU activations (z_mlp_full) can be accurately modelled using only a linear combination of the PC1 and PC2 features from inputs a and b. We will fit z_mlp_full ≈ w1*PC1(a) + w2*PC1(b) + w3*PC2(a) + w4*PC2(b) + Bias for each neuron. Comparing the R-squared values from this linear fit to the previous quadratic fit will reveal the specific contribution of the quadratic terms to the pre-activation computation.


In [None]:
print("\nFitting z_mlp_full vs. LINEAR PC1 & PC2 Features...")

# --- Check if data is available ---
if 'X_features' not in locals() or X_features is None or X_features.shape[1] != 4:
    print("Error: 'X_features' with 4 columns [p1a, p1b, p2a, p2b] not found.")
elif 'z_mlp_full' not in locals() or z_mlp_full is None:
    print("Error: 'z_mlp_full' not found.")
else:
    num_samples, num_neurons = z_mlp_full.shape
    assert num_neurons == d_mlp

    # Per-neuron R² of the linear model; NaN for neurons with (near-)constant activation
    neuron_r2_linear_only_fit = np.zeros(d_mlp)

    # --- Fit Models for Each Neuron ---
    print(f"Fitting linear model for {d_mlp} neurons...")
    for i in range(d_mlp):
        y_neuron_z_full = z_mlp_full[:, i]
        if np.std(y_neuron_z_full) < 1e-9:
            neuron_r2_linear_only_fit[i] = np.nan
            continue

        # Fit model: z_full_i ≈ w1*p1a + w2*p1b + w3*p2a + w4*p2b + Bias
        # X_features has no constant column, so the intercept is fitted separately here.
        lin_reg_z_full = LinearRegression(fit_intercept=True)
        lin_reg_z_full.fit(X_features, y_neuron_z_full)
        neuron_r2_linear_only_fit[i] = lin_reg_z_full.score(X_features, y_neuron_z_full)

    print("Fitting complete.")

    # --- Analysis ---
    linear_fit_ok = ~np.isnan(neuron_r2_linear_only_fit)
    valid_r2_scores = neuron_r2_linear_only_fit[linear_fit_ok]

    if len(valid_r2_scores) > 0:
        print("\n--- Analysis of R-squared (Fit z_mlp_full vs Linear(PC1, PC2 features)) ---")
        print(f"Mean R²:   {np.mean(valid_r2_scores):.4f}")
        print(f"Median R²: {np.median(valid_r2_scores):.4f}")
        print(f"Min R²:    {np.min(valid_r2_scores):.4f}")
        print(f"Max R²:    {np.max(valid_r2_scores):.4f}")
        print(f"Neurons with R² > 0.95: {np.sum(valid_r2_scores > 0.95)} / {len(valid_r2_scores)}")
        print(f"Neurons with R² > 0.99: {np.sum(valid_r2_scores > 0.99)} / {len(valid_r2_scores)}")

        # Compare to Poly2 fit. The Poly2 R² array is only touched AFTER confirming it exists,
        # and both arrays are restricted to neurons where BOTH fits produced an R², so the
        # difference is taken neuron-by-neuron.
        if 'neuron_r2_scores_z_full_fit' in locals() and neuron_r2_scores_z_full_fit is not None:
            both_ok = linear_fit_ok & ~np.isnan(neuron_r2_scores_z_full_fit)
            valid_r2_poly_prev = neuron_r2_scores_z_full_fit[both_ok]
            r2_drop = valid_r2_poly_prev - neuron_r2_linear_only_fit[both_ok]
            if r2_drop.size > 0:
                avg_r2_drop = np.mean(r2_drop)
                median_r2_drop = np.median(r2_drop)
                print("\n--- Comparison to Poly2 Fit ---")
                print(f"Mean R² drop by removing quadratics:   {avg_r2_drop:.5f}")
                print(f"Median R² drop by removing quadratics: {median_r2_drop:.5f}")
            else:
                print("\nCannot compare to Poly2 fit R-squared (no neurons fitted by both models).")
        else:
            print("\nCannot compare to Poly2 fit R-squared (previous results not found).")

        # Visualize
        plt.figure(figsize=(10, 6))
        sns.histplot(valid_r2_scores, bins=50, kde=False)
        plt.title("R² Distribution: Fit of Pre-ReLU (z_full) vs. Linear(PC1, PC2 Features)")
        plt.xlabel("R-squared Score per Neuron")
        plt.ylabel("Number of Neurons")
        plt.grid(axis='y', alpha=0.5)
        plt.show()
    else:
        print("Could not calculate R-squared for any neurons.")


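OK, it gets significantly worse, so a purely linear analysis is not enough. However, maybe we can enforce that the coefficients be the same for a and b, and simplify that way. (Note: the fit below actually imposes something stronger than a↔b symmetry. It uses only functions of the sums PC1(a)+PC1(b) and PC2(a)+PC2(b), so, for example, PC1(a)² + PC1(b)² and PC1(a)·PC1(b) are forced into a fixed 1:2 ratio.)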

In [None]:
print("\nFitting z_mlp_full vs. Symmetric Quadratic Features derived from PC1 & PC2...")

# --- Check if data is available ---
if 'p1a' not in locals() or 'p1b' not in locals() or 'p2a' not in locals() or 'p2b' not in locals():
    print("Error: Base PC features (p1a, p1b, p2a, p2b) not found.")
elif 'z_mlp_full' not in locals() or z_mlp_full is None:
    print("Error: 'z_mlp_full' not found.")
else:
    # --- 1. Create Symmetric Quadratic Features ---
    # NOTE: these features depend on a and b ONLY through the sums p1a+p1b and p2a+p2b, which is
    # stronger than a<->b symmetry. E.g. p1_sum**2 = (p1a**2 + p1b**2) + 2*p1a*p1b fixes the ratio
    # between the squared and cross terms, and p1_sum*p2_sum weights same-token (p1a*p2a) and
    # cross-token (p1a*p2b) interactions equally. A general symmetric quadratic would give those
    # terms separate coefficients.
    p1_sum = p1a + p1b
    p2_sum = p2a + p2b
    p1_sum_sq = p1_sum**2
    p2_sum_sq = p2_sum**2
    p1_p2_interaction = p1_sum * p2_sum

    # Feature matrix: [p1_sum, p2_sum, p1_sum^2, p2_sum^2, p1_sum*p2_sum]
    X_symm_quad_features = np.stack(
        (p1_sum, p2_sum, p1_sum_sq, p2_sum_sq, p1_p2_interaction),
        axis=-1
    )  # Shape: [num_samples, 5]
    symm_quad_feature_names = ['p1_sum', 'p2_sum', 'p1_sum_sq', 'p2_sum_sq', 'p1_sum*p2_sum']

    num_samples, num_neurons = z_mlp_full.shape
    assert num_neurons == d_mlp

    # Per-neuron results; neurons with (near-)constant activation keep NaN and are flagged False.
    neuron_r2_symm_quad_fit = np.zeros(d_mlp)
    neuron_coeffs_symm_quad = np.full((d_mlp, len(symm_quad_feature_names)), np.nan)
    neuron_intercepts_symm_quad = np.full(d_mlp, np.nan)
    neuron_fit_successful_symm_quad = np.zeros(d_mlp, dtype=bool)

    # --- Fit Models for Each Neuron ---
    print(f"Fitting symmetric quadratic model for {d_mlp} neurons...")
    for i in range(d_mlp):
        y_neuron_z_full = z_mlp_full[:, i]
        if np.std(y_neuron_z_full) < 1e-9:
            neuron_r2_symm_quad_fit[i] = np.nan
            continue

        # Fit model: z_full_i ≈ w1*p1_sum + w2*p2_sum + w3*p1_sum^2 + w4*p2_sum^2 + w5*(p1s*p2s) + Bias
        symm_quad_reg_z_full = LinearRegression(fit_intercept=True)
        symm_quad_reg_z_full.fit(X_symm_quad_features, y_neuron_z_full)
        neuron_r2_symm_quad_fit[i] = symm_quad_reg_z_full.score(X_symm_quad_features, y_neuron_z_full)
        neuron_coeffs_symm_quad[i, :] = symm_quad_reg_z_full.coef_
        neuron_intercepts_symm_quad[i] = symm_quad_reg_z_full.intercept_
        neuron_fit_successful_symm_quad[i] = True

    print("Fitting complete.")

    # --- Analysis ---
    valid_r2_symm_quad = neuron_r2_symm_quad_fit[neuron_fit_successful_symm_quad]

    if len(valid_r2_symm_quad) > 0:
        print("\n--- Analysis of R-squared (Fit z_mlp_full vs Symmetric Quadratic Features) ---")
        print(f"Mean R²:   {np.mean(valid_r2_symm_quad):.4f}")
        print(f"Median R²: {np.median(valid_r2_symm_quad):.4f}")
        print(f"Min R²:    {np.min(valid_r2_symm_quad):.4f}")
        print(f"Max R²:    {np.max(valid_r2_symm_quad):.4f}")
        print(f"Neurons with R² > 0.95: {np.sum(valid_r2_symm_quad > 0.95)} / {len(valid_r2_symm_quad)}")
        print(f"Neurons with R² > 0.99: {np.sum(valid_r2_symm_quad > 0.99)} / {len(valid_r2_symm_quad)}")
        print(f"Neurons with R² > 0.999: {np.sum(valid_r2_symm_quad > 0.999)} / {len(valid_r2_symm_quad)}")

        # Compare to the full Poly2 fit. The Poly2 R² array is only touched AFTER confirming it
        # exists, and both arrays are restricted to neurons where BOTH fits produced an R².
        if 'neuron_r2_scores_z_full_fit' in locals() and neuron_r2_scores_z_full_fit is not None:
            both_ok = neuron_fit_successful_symm_quad & ~np.isnan(neuron_r2_scores_z_full_fit)
            valid_r2_poly_prev = neuron_r2_scores_z_full_fit[both_ok]
            r2_symm_aligned = neuron_r2_symm_quad_fit[both_ok]

            if len(valid_r2_poly_prev) == 0:
                print("\nCannot compare to Full Poly2 fit R-squared (no neurons fitted by both models).")
            else:
                avg_r2_diff = np.mean(valid_r2_poly_prev - r2_symm_aligned)
                median_r2_diff = np.median(valid_r2_poly_prev - r2_symm_aligned)
                print("\n--- Comparison to Full Poly2 Fit ---")
                print(f"Mean R² Difference (Full Poly2 - Symmetric Quad):   {avg_r2_diff:.6f}")
                print(f"Median R² Difference (Full Poly2 - Symmetric Quad): {median_r2_diff:.6f}")
                # Expect this difference to be very small if the simplification holds

                # Visualize R^2 comparison: one point per neuron; points below y=x lose fit quality
                plt.figure(figsize=(8, 6))
                plt.scatter(valid_r2_poly_prev, r2_symm_aligned, alpha=0.3)
                min_both = min(np.min(valid_r2_poly_prev), np.min(r2_symm_aligned))
                plt.plot([min_both, 1], [min_both, 1], color='red', linestyle='--', label='y=x')
                plt.xlabel("R² using Full Poly2 (14 features + bias)")
                plt.ylabel("R² using Symmetric Quad (5 features + bias)")
                plt.title("R² Comparison: Full vs Symmetric Quadratic Features")
                plt.grid(True, alpha=0.3)
                plt.legend()
                # Small margin so points at exactly R² = 1 remain visible
                plt.xlim(min_both - 0.001, 1.001)
                plt.ylim(min_both - 0.001, 1.001)
                plt.show()
        else:
            print("\nCannot compare to Full Poly2 fit R-squared (previous results not found).")
    else:
        print("Could not calculate R-squared for any neurons.")


The fit gets worse, but not by much: the vast majority of neurons still have R² above 0.95. I would tentatively interpret this as "the model learns that the function it has to implement is symmetric, but does not get all of the way there". I am not sure where to go from here, so let's recap what we do know:

1. The embedding is quadratic. There is one linear term and one quadratic term. It's also highly symmetric, effectively the same for a and b.
2. Both of those features of the embedding are used by the MLP going into its final layer. Further, the input to the ReLU is a quadratic function of these two features.
3. This function is again very symmetric, with the coefficients for the a and b parts of the embedding being almost exactly the same. Restricting the fit to functions of the sums PC1(a)+PC1(b) and PC2(a)+PC2(b) (a constraint even stronger than a↔b symmetry) drops fit quality only slightly.
4. ReLU happens. We understand the mechanics of this.
5. Then we get to the unembedding matrix. This is a lot less clear to me. It seems to add neuron activations in a non-trivial way. I am not sure what it does, exactly.

Plan: Identify the top ~20-30 active neurons (a_mlp[i] > threshold) for input (0,0) and for input (4,4). Then, plot the DFT magnitudes of their corresponding W_L[:, i] columns side-by-side or overlaid. Do the active neurons for (0,0) tend to have W_L columns rich in the specific frequencies needed for the target=0 shape? Do the active neurons for (4,4) have W_L columns rich in the frequencies needed for the target=8 shape? This compares the "ingredients" (W_L structure) selected by the network (a_mlp activation) for different tasks.

We hypothesize that the network produces the correct logit shape (peaking at a+b) by activating a specific set of MLP neurons (a_mlp). Each active neuron i then contributes its corresponding "output basis vector" (W_L[:, i]) scaled by its activation. Our Fourier analysis suggests these W_L columns represent combinations of low-frequency components. To test this further, we will now identify the neurons most highly activated (a_mlp post-ReLU) for two different target sums: 0 (input (0,0)) and 8 (input (4,4)). We will then compute the average Discrete Fourier Transform (DFT) magnitude spectrum of the W_L columns corresponding to the top active neurons for each case. If the hypothesis holds, we expect the average spectrum for the neurons active for input (0,0) to potentially differ (in the balance of frequencies 1, 2, 3...) from the average spectrum for neurons active for input (4,4), reflecting the different frequency mixes needed to synthesize peaks at 0 versus 8.



In [None]:
# --- Configuration ---
k_top_neurons = 30  # Maximum number of most-active neurons to consider per input
activation_threshold = 0.0  # A neuron counts as "active" only if its post-ReLU activation exceeds this
input_pair_1 = (0, 0)
input_pair_2 = (4, 4)
target_sum_1 = input_pair_1[0] + input_pair_1[1]
target_sum_2 = input_pair_2[0] + input_pair_2[1]


def find_sample_index(a_values, b_values, pair):
    """Return the index of the first sample whose inputs (a, b) equal `pair`, or None if absent."""
    matches = np.flatnonzero((np.asarray(a_values) == pair[0]) & (np.asarray(b_values) == pair[1]))
    return int(matches[0]) if matches.size > 0 else None


def top_active_neurons(activations, k, threshold=0.0):
    """Return up to `k` neuron indices with the largest activations, in descending order.

    Neurons whose activation is not strictly above `threshold` are excluded, so an input that
    activates fewer than `k` neurons yields fewer than `k` indices. Since the logits are
    sum_i W_L[:, i] * a_mlp[i], a neuron with zero activation contributes nothing and would
    only dilute the averaged spectra.
    """
    activations = np.asarray(activations)
    order = np.argsort(activations)[::-1]
    return order[activations[order] > threshold][:k]


def mean_rfft_magnitudes(neuron_indices, W_L_matrix):
    """Average |rFFT| of the columns W_L_matrix[:, i] for i in `neuron_indices`.

    Each column is treated as a 1-D signal over the output classes. The result has length
    W_L_matrix.shape[0] // 2 + 1 and is all zeros when `neuron_indices` is empty.
    (Named differently from the next cell's `get_avg_dft_mags`, which returns a tuple, so
    running the cells in any order never shadows one with the other.)
    """
    W_L_matrix = np.asarray(W_L_matrix)
    neuron_indices = np.asarray(neuron_indices, dtype=int)
    if neuron_indices.size == 0:
        return np.zeros(W_L_matrix.shape[0] // 2 + 1)
    # rfft along axis 0 transforms every selected column independently -> [num_freqs, num_neurons]
    mags = np.abs(np.fft.rfft(W_L_matrix[:, neuron_indices], axis=0))
    return mags.mean(axis=1)


# --- Check if data is available ---
data_missing = False
if 'a_mlp_full' not in locals() or a_mlp_full is None:
    print("Error: 'a_mlp_full' not found.")
    data_missing = True
if 'W_L_np' not in locals() or W_L_np is None:
    print("Error: 'W_L_np' not found.")
    data_missing = True
if 'original_a_values' not in locals() or 'original_b_values' not in locals():
    print("Error: Input value arrays not found.")
    data_missing = True

if not data_missing:
    # --- Find indices for target input pairs ---
    idx_1 = find_sample_index(original_a_values, original_b_values, input_pair_1)
    idx_2 = find_sample_index(original_a_values, original_b_values, input_pair_2)
    for pair, idx in ((input_pair_1, idx_1), (input_pair_2, idx_2)):
        if idx is None:
            print(f"Error: Input pair {pair} not found in the test set data.")
        else:
            print(f"Found index for input {pair}: {idx}")

    if idx_1 is not None and idx_2 is not None:
        # --- Get activation vectors ---
        a_full_1 = a_mlp_full[idx_1]  # Post-ReLU activations for input 1
        a_full_2 = a_mlp_full[idx_2]  # Post-ReLU activations for input 2

        # --- Identify top active neurons (inactive neurons are excluded) ---
        top_indices_1 = top_active_neurons(a_full_1, k_top_neurons, activation_threshold)
        top_indices_2 = top_active_neurons(a_full_2, k_top_neurons, activation_threshold)
        for pair, top_idx in ((input_pair_1, top_indices_1), (input_pair_2, top_indices_2)):
            print(f"\nTop {len(top_idx)} active neuron indices (max {k_top_neurons}) for {pair}: {top_idx[:5]}...")
            if len(top_idx) < k_top_neurons:
                print(f"  Note: only {len(top_idx)} neurons have activation > {activation_threshold} for {pair}.")

        # --- Calculate average spectra ---
        print("\nCalculating average W_L DFT spectra for the top active neurons...")
        avg_spectrum_1 = mean_rfft_magnitudes(top_indices_1, W_L_np)
        avg_spectrum_2 = mean_rfft_magnitudes(top_indices_2, W_L_np)
        num_freqs = len(avg_spectrum_1)

        # --- Analyze Top Frequencies in Average Spectra ---
        sorted_freq_indices_1 = np.argsort(avg_spectrum_1)[::-1]
        sorted_freq_indices_2 = np.argsort(avg_spectrum_2)[::-1]
        for pair, target, spectrum, ranked in (
            (input_pair_1, target_sum_1, avg_spectrum_1, sorted_freq_indices_1),
            (input_pair_2, target_sum_2, avg_spectrum_2, sorted_freq_indices_2),
        ):
            print(f"\nAnalysis for Input {pair} (Target {target}):")
            for rank, f_idx in enumerate(ranked[:5], start=1):
                print(f"  Rank {rank}: Freq Index {f_idx} (Mag: {spectrum[f_idx]:.4f})")

        # --- Visualize the average spectra ---
        plt.figure(figsize=(12, 6))
        bar_width = 0.35
        indices = np.arange(num_freqs)

        plt.bar(indices - bar_width/2, avg_spectrum_1, bar_width,
                label=f'Avg Spectrum for Input {input_pair_1} (N={len(top_indices_1)} active)')
        plt.bar(indices + bar_width/2, avg_spectrum_2, bar_width,
                label=f'Avg Spectrum for Input {input_pair_2} (N={len(top_indices_2)} active)')

        plt.xlabel("Frequency Index (k)")
        plt.ylabel("Average Magnitude of Fourier Coefficient")
        plt.title(f"Average W_L Column DFT Magnitudes for Top Active Neurons (max {k_top_neurons})")
        plt.xticks(indices)  # Label all frequency indices
        plt.legend()
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.show()

    else:
        print("\nSkipping analysis as one or both input pairs were not found.")
else:
    print("\nSkipping analysis due to missing activation or W_L data.")


In [None]:
print("\nAnalyzing W_L Column DFTs for Neurons Flipping State for Input (0,0)...")

# --- Helper: DFT magnitude spectra of selected W_L columns ---
def get_avg_dft_mags(neuron_indices, W_L_matrix, num_outputs_fft):
    """Return (mean spectrum, per-neuron spectra) for the W_L columns of `neuron_indices`.

    Each column W_L_matrix[:, i] is transformed with np.fft.rfft and its magnitude taken.
    The first element is the average of those magnitude spectra across neurons; the second
    stacks them, one row per neuron. With no neurons, returns
    (zeros of length num_outputs_fft // 2 + 1, []).
    """
    if not len(neuron_indices):
        return np.zeros(num_outputs_fft // 2 + 1), []

    per_neuron = np.array([np.abs(np.fft.rfft(W_L_matrix[:, col])) for col in neuron_indices])
    return np.mean(per_neuron, axis=0), per_neuron

# --- Calculate average spectra ---
avg_spectrum_oto, _ = get_avg_dft_mags(off_to_on_indices, W_L_np, num_outputs)
avg_spectrum_otf, all_mags_otf = get_avg_dft_mags(on_to_off_indices, W_L_np, num_outputs)
num_freqs = len(avg_spectrum_oto)


def _print_top_freqs(avg_spectrum, sorted_freq_indices, top_n=5):
    """Print the first `top_n` frequency indices of `sorted_freq_indices` with their magnitudes."""
    for rank, f_idx in enumerate(sorted_freq_indices[:top_n], start=1):
        print(f"  Rank {rank}: Freq Index {f_idx} (Mag: {avg_spectrum[f_idx]:.4f})")


# --- Visualize the average spectra ---
bar_width = 0.35
indices = np.arange(num_freqs)
fig, ax = plt.subplots(figsize=(12, 6))

# The two groups are drawn side by side (offset by half a bar) at each frequency index.
if len(off_to_on_indices) > 0:
    ax.bar(indices - bar_width/2, avg_spectrum_oto, bar_width, label=f'Avg Spectrum OFF->ON (N={len(off_to_on_indices)})', color='cyan', alpha=0.8)
if len(on_to_off_indices) > 0:
    ax.bar(indices + bar_width/2, avg_spectrum_otf, bar_width, label=f'Avg Spectrum ON->OFF (N={len(on_to_off_indices)})', color='magenta', alpha=0.8)

ax.set_xlabel("Frequency Index (k)")
ax.set_ylabel("Average Magnitude of Fourier Coefficient")
ax.set_title("Average W_L Column DFT Magnitudes for Flipping Neurons [Input (0,0)]")
ax.set_xticks(indices)
# With no flipping neurons nothing is labelled, and legend() would only emit a warning.
if len(off_to_on_indices) > 0 or len(on_to_off_indices) > 0:
    ax.legend()
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

# --- Analyze Top Frequencies ---
print("\n--- Top Frequencies in Average Spectra for Flipping Neurons ---")
if len(off_to_on_indices) > 0:
    print(f"OFF -> ON (N={len(off_to_on_indices)}):")
    sorted_freq_indices_oto = np.argsort(avg_spectrum_oto)[::-1]  # largest magnitude first
    _print_top_freqs(avg_spectrum_oto, sorted_freq_indices_oto)
else:
    print("OFF -> ON: None")

if len(on_to_off_indices) > 0:
    print(f"\nON -> OFF (N={len(on_to_off_indices)}):")
    sorted_freq_indices_otf = np.argsort(avg_spectrum_otf)[::-1]  # largest magnitude first
    _print_top_freqs(avg_spectrum_otf, sorted_freq_indices_otf)
else:
    print("ON -> OFF: None")


The Fourier analysis hasn't been as illuminating as we had hoped: many frequencies (6+) are required to get correct predictions, which is a lot given the simplicity of the problem. Let's turn back to the polynomial structure we observed earlier.

Having established that the pre-ReLU activations (z_mlp) are extremely well modelled by a degree-2 polynomial of the input PC1 and PC2 features (R² > 0.99 for all neurons), we will now test if this understanding is sufficient to reconstruct the model's final output. We will simulate the rest of the forward pass: first, we calculate the predicted z_mlp for all inputs using the learned polynomial coefficients for each neuron. Second, we apply the standard ReLU activation function to these predicted pre-activations. Third, we multiply the resulting post-ReLU activations by the effective unembedding matrix W_L. Finally, we compare these simulated logits to the actual logits produced by the full model. A close match would validate our complete algorithmic description: Logits ≈ W_L @ ReLU(Poly2(PC1, PC2)).


In [None]:
print("\nSimulating Logits using Poly2(PC1, PC2) fit for pre-ReLU activations...")


def simulate_logits_from_poly_z(poly_features, poly_coeffs, fit_mask, W_L):
    """Simulate logits as W_L @ ReLU(Poly2 prediction of z_mlp).

    Args:
        poly_features: [num_samples, num_poly_feats] design matrix (bias column included).
        poly_coeffs: [d_mlp, num_poly_feats] per-neuron coefficients. Rows of neurons
            whose fit failed are ignored, so they may hold NaN.
        fit_mask: boolean array of length d_mlp, True where the fit succeeded.
        W_L: [num_outputs, d_mlp] effective unembedding.

    Returns:
        (z_pred, a_pred, logits) with shapes [num_samples, d_mlp], [num_samples, d_mlp]
        and [num_samples, num_outputs]. Neurons whose fit failed get z_pred = 0 (and
        therefore a_pred = 0) in every sample.
    """
    poly_features = np.asarray(poly_features)
    poly_coeffs = np.asarray(poly_coeffs)
    fit_mask = np.asarray(fit_mask, dtype=bool)

    # Keep ALL d_mlp columns and fill only the fitted ones, so the result lines up
    # with the columns of W_L even when some fits failed.
    z_pred = np.zeros((poly_features.shape[0], poly_coeffs.shape[0]))
    # (num_samples, num_poly_feats) @ (num_poly_feats, n_fitted) -> (num_samples, n_fitted)
    z_pred[:, fit_mask] = poly_features @ poly_coeffs[fit_mask].T

    a_pred = np.maximum(0, z_pred)
    # Logits[sample, k] = Sum_i W_L[k, i] * a_pred[sample, i]
    return z_pred, a_pred, a_pred @ np.asarray(W_L).T


# --- Check if data is available ---
if 'learned_poly_coeffs' not in locals() or learned_poly_coeffs is None:
    print("Error: Learned polynomial coefficients ('learned_poly_coeffs') not found.")
elif 'X_poly_features' not in locals() or X_poly_features is None:
    print("Error: Polynomial features ('X_poly_features') not found.")
elif 'W_L_np' not in locals() or W_L_np is None:
    print("Error: 'W_L_np' not found.")
elif 'original_logits_pred_pos' not in locals() or original_logits_pred_pos is None:
    print("Error: Actual original logits ('original_logits_pred_pos') not found.")
elif 'neuron_fit_successful' not in locals() or neuron_fit_successful is None:
    print("Error: Per-neuron fit mask ('neuron_fit_successful') not found.")
else:
    # --- 1-3. Poly2 prediction of z_mlp -> ReLU -> W_L ---
    fit_indices = np.where(neuron_fit_successful)[0]
    print(f"Predicting z_mlp for {len(fit_indices)} successfully fitted neurons...")

    z_mlp_pred, a_mlp_pred, logits_pred = simulate_logits_from_poly_z(
        X_poly_features, learned_poly_coeffs, neuron_fit_successful, W_L_np)
    print("Predicted z_mlp shape:", z_mlp_pred.shape)
    print("Applied ReLU. Predicted a_mlp shape:", a_mlp_pred.shape)
    print("Calculated predicted logits. Shape:", logits_pred.shape)

    # --- 4. Compare Predicted vs Actual Logits ---
    print("\n--- Comparing Predicted Logits vs Actual Logits ---")

    # Overall Metrics
    mse_logits = mean_squared_error(original_logits_pred_pos.flatten(), logits_pred.flatten())
    # R-squared pooled over every (sample, logit) entry, not averaged per output column.
    mean_actual_logits = np.mean(original_logits_pred_pos)
    total_variance = np.sum((original_logits_pred_pos - mean_actual_logits)**2)
    residual_variance = np.sum((original_logits_pred_pos - logits_pred)**2)
    r2_logits = 1 - (residual_variance / total_variance) if total_variance > 1e-9 else np.nan

    print(f"Overall Mean Squared Error between predicted and actual logits: {mse_logits:.4f}")
    print(f"Overall R-squared between predicted and actual logits:        {r2_logits:.4f}")

    # Per-Sample Comparison (Example: input (0,0))
    if 'idx_00' in locals() and idx_00 is not None:
        logits_pred_00 = logits_pred[idx_00]
        logits_actual_00 = original_logits_pred_pos[idx_00]

        plt.figure(figsize=(12, 6))
        output_indices = np.arange(num_outputs)
        plt.plot(output_indices, logits_actual_00, marker='o', linestyle='-', color='blue', label='Actual Logits (0,0)')
        plt.plot(output_indices, logits_pred_00, marker='x', linestyle='--', color='red', label=f'Predicted Logits (R²={r2_logits:.3f})')
        plt.title("Comparison of Actual vs Predicted Logits for Input (0,0)")
        plt.xlabel("Output Logit Index (k)")
        plt.ylabel("Logit Value")
        plt.xticks(output_indices)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.show()

    # Scatter plot of predicted vs actual (for all logits), with a y=x reference line
    plt.figure(figsize=(8, 8))
    plt.scatter(original_logits_pred_pos.flatten(), logits_pred.flatten(), alpha=0.1)
    min_val = min(np.min(original_logits_pred_pos), np.min(logits_pred)) - 1
    max_val = max(np.max(original_logits_pred_pos), np.max(logits_pred)) + 1
    plt.plot([min_val, max_val], [min_val, max_val], color='red', linestyle='--', label='y=x')
    plt.xlabel("Actual Logit Value")
    plt.ylabel("Predicted Logit Value (from Poly2 -> ReLU -> W_L)")
    plt.title(f"Predicted vs Actual Logits (Overall R² = {r2_logits:.4f})")
    plt.grid(True, alpha=0.3)
    plt.axis('equal') # Ensure axes have same scale
    plt.legend()
    plt.show()


This confirms that the model uses the 2 PCs, and uses them quadratically: substituting our polynomial fit for the pre-ReLU activations and running the rest of the model (ReLU + multiplication by W_L) leaves performance effectively unchanged. Let us now check whether the post-ReLU activations are also a quadratic function of the 2 PCs. We expect not, because the ReLU is piecewise, which is hard to approximate with a single quadratic.

In [None]:
print("\nFitting FULL model POST-ReLU activations (a_mlp_full) vs. Poly2(PC1, PC2 features)...")

# --- Check if data is available ---
if 'X_poly_features' not in locals() or X_poly_features is None:
    print("Error: 'X_poly_features' not found.")
elif 'a_mlp_full' not in locals() or a_mlp_full is None:
    print("Error: 'a_mlp_full' not found.")
else:
    num_samples, num_poly_feats = X_poly_features.shape
    num_samples_a, num_neurons = a_mlp_full.shape
    assert num_neurons == d_mlp
    assert num_samples == num_samples_a

    # Per-neuron R² (NaN for skipped neurons) and a mask of neurons that were fitted.
    neuron_r2_scores_a_full_fit = np.zeros(d_mlp)
    neuron_fit_successful_a = np.zeros(d_mlp, dtype=bool)

    # --- Fit Models for Each Neuron ---
    print(f"Fitting polynomial model for {d_mlp} neurons...")
    for i in range(d_mlp):
        # Target variable: full POST-ReLU activation for neuron i
        y_neuron_a_full = a_mlp_full[:, i]

        # R² is undefined for a (near-)constant target, e.g. a neuron that is always zero.
        if np.std(y_neuron_a_full) < 1e-9:
            neuron_r2_scores_a_full_fit[i] = np.nan
            continue

        # Fit model: a_full_i ≈ Poly2(PC1, PC2 features)
        poly_reg_a_full = LinearRegression(fit_intercept=False) # Bias included in X_poly_features
        poly_reg_a_full.fit(X_poly_features, y_neuron_a_full)
        neuron_r2_scores_a_full_fit[i] = poly_reg_a_full.score(X_poly_features, y_neuron_a_full)
        neuron_fit_successful_a[i] = True

    print("Fitting complete.")

    # --- Analysis ---
    valid_r2_scores_a = neuron_r2_scores_a_full_fit[neuron_fit_successful_a]

    if len(valid_r2_scores_a) > 0:
        print("\n--- Analysis of R-squared (Fit a_mlp_full vs Poly2(PC1, PC2 features)) ---")
        print(f"Mean R²:   {np.mean(valid_r2_scores_a):.4f}")
        print(f"Median R²: {np.median(valid_r2_scores_a):.4f}")
        print(f"Min R²:    {np.min(valid_r2_scores_a):.4f}")
        print(f"Max R²:    {np.max(valid_r2_scores_a):.4f}")
        print(f"Neurons with R² > 0.95: {np.sum(valid_r2_scores_a > 0.95)} / {len(valid_r2_scores_a)}")
        print(f"Neurons with R² > 0.99: {np.sum(valid_r2_scores_a > 0.99)} / {len(valid_r2_scores_a)}")

        # --- Visualize R-squared Distribution ---
        plt.figure(figsize=(10, 6))
        sns.histplot(valid_r2_scores_a, bins=50, kde=False)
        plt.title("R² Distribution: Fit of Post-ReLU Activations (a_full) vs. Poly2(PC1, PC2 Features)")
        plt.xlabel("R-squared Score per Neuron")
        plt.ylabel("Number of Neurons")
        plt.grid(axis='y', alpha=0.5)
        plt.show()

        # --- Compare Pre vs Post R-squared using Poly2 fit ---
        # The pre-ReLU scores come from an earlier cell, so they are only read after
        # confirming they exist (reading them first would raise NameError here).
        if 'neuron_r2_scores_z_full_fit' in locals():
            valid_r2_pre_prev = np.asarray(neuron_r2_scores_z_full_fit)[neuron_fit_successful_a] # Align masks
            r2_drop = valid_r2_pre_prev - valid_r2_scores_a
            # NaN-aware in case a pre-ReLU score is NaN (presumably the earlier cell marks
            # skipped neurons that way — TODO confirm).
            avg_r2_drop_poly = np.nanmean(r2_drop)
            median_r2_drop_poly = np.nanmedian(r2_drop)
            print(f"\n--- Comparison to Pre-ReLU Poly2 Fit ---")
            print(f"Mean R² drop (Poly2(z_full) - Poly2(a_full)):   {avg_r2_drop_poly:.4f}")
            print(f"Median R² drop (Poly2(z_full) - Poly2(a_full)): {median_r2_drop_poly:.4f}")

            # Visualize the drop against the y=x (no change) line
            plt.figure(figsize=(8, 6))
            plt.scatter(valid_r2_pre_prev, valid_r2_scores_a, alpha=0.3)
            min_both = min(np.nanmin(valid_r2_pre_prev), np.min(valid_r2_scores_a))
            max_both = max(np.nanmax(valid_r2_pre_prev), np.max(valid_r2_scores_a))
            plt.plot([min_both, max_both], [min_both, max_both], color='red', linestyle='--', label='y=x')
            plt.xlabel("R² for Pre-ReLU (z_full) vs Poly2(PCs)")
            plt.ylabel("R² for Post-ReLU (a_full) vs Poly2(PCs)")
            plt.title("Impact of ReLU on Poly2 Fit Accuracy")
            plt.grid(True, alpha=0.3)
            plt.legend()
            plt.xlim(min_both - 0.01, max_both + 0.01)
            plt.ylim(min_both - 0.01, max_both + 0.01)
            plt.show()
        else:
            print("\nCannot compare to Pre-ReLU Poly2 fit R-squared (previous results not found).")
    else:
        print("Could not calculate R-squared for any neurons.")


It actually does not get much worse! The vast majority of neurons still have a fit with R² of around 0.95, which is somewhat unexpected. Next, let's feed this fit into W_L and compare the resulting simulated logits with the actual model logits. We now expect this to work reasonably well, and if it does, we can focus entirely on the unembedding.

In [None]:
print("\nSimulating Logits using Poly2(PCs) fit for POST-ReLU activations (a_mlp_full)...")

# --- Check if data is available ---
if 'a_mlp_full' not in locals() or 'X_poly_features' not in locals() or 'W_L_np' not in locals() or 'original_logits_pred_pos' not in locals() or 'd_mlp' not in locals():
    print("Error: Missing necessary data. Please run previous steps.")
else:
    # --- 1. Fit a_mlp_full and get coefficients ---
    # Re-fits every neuron (same procedure as the previous cell) so the coefficients are available here.
    print(f"Fitting polynomial model for {d_mlp} neurons to get coefficients for a_mlp_full...")
    learned_poly_coeffs_a = np.full((d_mlp, X_poly_features.shape[1]), np.nan)  # NaN rows = not fitted
    neuron_fit_successful_a = np.zeros(d_mlp, dtype=bool)

    for i in range(d_mlp):
        y_neuron_a_full = a_mlp_full[:, i]
        if np.std(y_neuron_a_full) < 1e-9:
            continue  # (Near-)constant neuron: reproduced by its mean below instead of a fit.
        poly_reg_a_full = LinearRegression(fit_intercept=False)
        poly_reg_a_full.fit(X_poly_features, y_neuron_a_full)
        learned_poly_coeffs_a[i, :] = poly_reg_a_full.coef_
        neuron_fit_successful_a[i] = True

    valid_a_fit_indices = np.where(neuron_fit_successful_a)[0]
    print(f"Predicting a_mlp for {len(valid_a_fit_indices)} successfully fitted neurons...")

    # --- 2. Predict a_mlp using polynomial fits ---
    # (num_samples, num_poly_feats) @ (num_poly_feats, n_fitted) -> (num_samples, n_fitted)
    a_mlp_pred_from_poly = X_poly_features @ learned_poly_coeffs_a[valid_a_fit_indices, :].T
    # Start every neuron at its mean activation: for the skipped (near-)constant neurons this
    # reproduces them exactly, whereas zeros would drop a constant non-zero neuron's logit
    # contribution. The sample count comes from X_poly_features, not a `num_samples` left over
    # from another cell.
    a_mlp_pred_full_dim = np.empty((X_poly_features.shape[0], d_mlp))
    a_mlp_pred_full_dim[:] = np.mean(np.asarray(a_mlp_full), axis=0)
    a_mlp_pred_full_dim[:, valid_a_fit_indices] = a_mlp_pred_from_poly

    print("Predicted a_mlp shape:", a_mlp_pred_full_dim.shape)

    # --- 3. Apply W_L ---
    # Logits[sample, k] = Sum_i W_L[k, i] * a_mlp_pred[sample, i]
    logits_pred_from_poly_a = a_mlp_pred_full_dim @ W_L_np.T
    print("Calculated predicted logits (from poly fit of a_mlp). Shape:", logits_pred_from_poly_a.shape)

    # --- 4. Compare Predicted vs Actual Logits ---
    print("\n--- Comparing Logits [Actual vs. Poly2(a_mlp) Sim vs. ReLU(Poly2(z_mlp)) Sim] ---")

    # Overall Metrics for Poly2(a_mlp) simulation (R² pooled over all entries)
    mse_logits_poly_a = mean_squared_error(original_logits_pred_pos.flatten(), logits_pred_from_poly_a.flatten())
    mean_actual_logits = np.mean(original_logits_pred_pos)
    total_variance = np.sum((original_logits_pred_pos - mean_actual_logits)**2)
    residual_variance_poly_a = np.sum((original_logits_pred_pos - logits_pred_from_poly_a)**2)
    r2_logits_poly_a = 1 - (residual_variance_poly_a / total_variance) if total_variance > 1e-9 else np.nan

    print(f"Simulation using Poly2(a_mlp) fit:")
    print(f"  Overall Mean Squared Error: {mse_logits_poly_a:.4f}")
    print(f"  Overall R-squared:          {r2_logits_poly_a:.4f}")

    # Compare to previous simulation using ReLU(Poly2(z_mlp))
    if 'logits_pred' in locals() and logits_pred is not None:
        mse_logits_relu = mean_squared_error(original_logits_pred_pos.flatten(), logits_pred.flatten())
        r2_logits_relu = 1 - (np.sum((original_logits_pred_pos - logits_pred)**2) / total_variance) if total_variance > 1e-9 else np.nan
        print(f"\nSimulation using ReLU(Poly2(z_mlp)) fit (Previous):")
        print(f"  Overall Mean Squared Error: {mse_logits_relu:.4f}")
        print(f"  Overall R-squared:          {r2_logits_relu:.4f}")
        print(f"\nDifference in R² (ReLU Sim - Poly(a_mlp) Sim): {r2_logits_relu - r2_logits_poly_a:.6f}")
    else:
        print("\nPrevious simulation results ('logits_pred') not available for comparison.")


# --- Per-Sample Comparison (Example: input (0,0)) ---
if ('idx_00' in locals() and idx_00 is not None
        and 'original_logits_pred_pos' in locals()
        and 'logits_pred_from_poly_a' in locals()):  # result of the simulation above

    logits_pred_poly_a_00 = logits_pred_from_poly_a[idx_00]
    logits_actual_00 = original_logits_pred_pos[idx_00]

    # The ReLU(Poly2(z_mlp)) simulation comes from an earlier cell; only use it when it
    # is a NumPy array that can be indexed at idx_00.
    logits_pred_relu_00 = None
    if 'logits_pred' in locals() and isinstance(logits_pred, np.ndarray):
        if idx_00 < logits_pred.shape[0]:
            logits_pred_relu_00 = logits_pred[idx_00]
        else:
            print(f"Warning: idx_00 ({idx_00}) is out of bounds for logits_pred shape {logits_pred.shape}")

    plt.figure(figsize=(12, 6))
    output_indices = np.arange(num_outputs)
    plt.plot(output_indices, logits_actual_00, marker='o', linestyle='-', color='blue', label='Actual Logits (0,0)')
    plt.plot(output_indices, logits_pred_poly_a_00, marker='s', linestyle=':', color='green', alpha=0.8, label=f'Pred from Poly(a_mlp) (R²={r2_logits_poly_a:.3f})')

    # Plot the previous simulation only if available
    if logits_pred_relu_00 is not None and 'r2_logits_relu' in locals():
        plt.plot(output_indices, logits_pred_relu_00, marker='x', linestyle='--', color='red', alpha=0.8, label=f'Pred from ReLU(Poly(z_mlp)) (R²={r2_logits_relu:.3f})')
    elif 'r2_logits_relu' in locals():
        # Legend-only entry: report the R² even though the per-sample curve is unavailable.
        plt.plot([], [], marker='x', linestyle='--', color='red', alpha=0.8, label=f'Pred from ReLU(Poly(z_mlp)) (R²={r2_logits_relu:.3f})')

    plt.title("Comparison of Actual vs Predicted Logits for Input (0,0)")
    plt.xlabel("Output Logit Index (k)")
    plt.ylabel("Logit Value")
    plt.xticks(output_indices)
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()

# --- Scatter plot of predicted vs actual (for Poly2(a_mlp) fit) ---
if 'logits_pred_from_poly_a' in locals() and 'original_logits_pred_pos' in locals() and 'r2_logits_poly_a' in locals():
    plt.figure(figsize=(8, 8))
    plt.scatter(original_logits_pred_pos.flatten(), logits_pred_from_poly_a.flatten(), alpha=0.1)
    min_val = min(np.min(original_logits_pred_pos), np.min(logits_pred_from_poly_a)) - 1
    max_val = max(np.max(original_logits_pred_pos), np.max(logits_pred_from_poly_a)) + 1
    plt.plot([min_val, max_val], [min_val, max_val], color='red', linestyle='--', label='y=x')
    plt.xlabel("Actual Logit Value")
    plt.ylabel("Predicted Logit Value (from Poly2(a_mlp) -> W_L)")
    plt.title(f"Predicted vs Actual Logits using Poly(a_mlp) (Overall R² = {r2_logits_poly_a:.4f})")
    plt.grid(True, alpha=0.3)
    plt.axis('equal')
    plt.legend()
    plt.show()
else:
    print("Skipping plots due to missing data.")


The shape of the function is effectively identical! It follows the same shape, deviating from the original only slightly more than the previous simulation did. The relative ordering of the logits also appears to be preserved, so we suspect this would result in the same predictions.

It appears that the main mystery lies in the unembedding. How can we best analyze it? Can we apply PCA, as we did for the embedding?

To further understand the structure embedded within the unembedding matrix W_L ([num_outputs, d_mlp]), we will now analyze its columns using Principal Component Analysis (PCA). Each column W_L[:, i] represents the pattern of influence neuron i has across all output logits (0 to num_outputs-1). By treating these d_mlp columns as our data points, PCA will identify the principal axes of variation, revealing the most common "shapes" or "basis vectors" that these neuron influence patterns are composed of. Based on our Fourier analysis, we hypothesize that the first few principal components will correspond to low-frequency sine and cosine waves, confirming that these periodic patterns are the dominant structural elements learned within W_L.


In [None]:
from sklearn.metrics.pairwise import cosine_similarity # For comparing vectors

# --- Assume previous steps ran and these variables exist ---
# W_L_np: Numpy array [num_outputs, d_mlp] (NumPy version of W_L matrix)
# num_outputs: Number of output logits (e.g., 19)
# d_mlp: Dimension of MLP
# --- End assumptions ---

# --- Check if data is available ---
if 'W_L_np' not in locals() or W_L_np is None:
    print("Error: 'W_L_np' not found.")
else:
    print(f"\nPerforming PCA on the COLUMNS of W_L (shape: {W_L_np.shape})...")
    print("Treating each neuron's output weight vector as a sample.")

    # --- 1. Prepare Data for PCA ---
    # Each neuron's column is one sample; PCA expects samples as rows, so transpose W_L.
    X_pca = W_L_np.T # Shape: [d_mlp, num_outputs]

    # --- 2. Center the Data ---
    mean_vector = np.mean(X_pca, axis=0) # Mean neuron output vector
    X_centered = X_pca - mean_vector

    # --- 3. Perform SVD ---
    # full_matrices=False: U [d_mlp, K], S [K], Vt [K, num_outputs], K = min(d_mlp, num_outputs)
    try:
        U, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
        print(f"SVD complete. Vt shape: {Vt.shape}")
        num_components = Vt.shape[0] # Number of components found
    except Exception as e:
        print(f"SVD failed: {e}")
        Vt = None

    if Vt is not None:
        # --- 4. Principal Components & Variance Explained ---
        principal_components = Vt # Rows are the PCs: [num_components, num_outputs]

        explained_variance = S**2 / np.sum(S**2)
        cumulative_variance = np.cumsum(explained_variance)

        print(f"\nVariance Explained by Principal Components:")
        for i in range(min(10, num_components)): # Show top 10
            print(f"  PC {i+1}: {explained_variance[i]:.4f} (Cumulative: {cumulative_variance[i]:.4f})")

        # --- 5. Scree Plot (Explained Variance) ---
        component_numbers = range(1, num_components + 1)
        plt.figure(figsize=(10, 6))
        plt.bar(component_numbers, explained_variance, alpha=0.7, align='center', label='Individual explained variance')
        plt.step(component_numbers, cumulative_variance, where='mid', label='Cumulative explained variance')
        plt.ylabel('Explained variance ratio')
        plt.xlabel('Principal component index')
        plt.title('PCA Explained Variance for W_L Columns')
        plt.xticks(component_numbers)
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.ylim(0, 1.1)
        # Mark how many components are needed to reach 95% cumulative variance (if ever reached).
        reached_95 = np.flatnonzero(cumulative_variance >= 0.95)
        if reached_95.size > 0:
            idx_95 = reached_95[0]
            plt.axvline(idx_95 + 1, color='red', linestyle=':', label=f'{(idx_95+1)} Components for 95% Var')
        plt.legend(loc='best')
        plt.show()

        # --- 6. Visualize Top Principal Components vs Fourier Basis ---
        num_pcs_to_plot = min(4, num_components) # PC1..PC4, fewer if the SVD returned fewer
        output_indices_k = np.arange(num_outputs)

        # squeeze=False keeps `axes` 2-D even for a single subplot, so indexing is uniform.
        fig, axes = plt.subplots(num_pcs_to_plot, 1, figsize=(12, 3 * num_pcs_to_plot), sharex=True, squeeze=False)
        axes = axes[:, 0]

        P = num_outputs # Period for Fourier basis

        for i, ax in enumerate(axes):
            pc_vector = principal_components[i, :]

            ax.plot(output_indices_k, pc_vector, marker='o', linestyle='-', label=f'PC {i+1} (Var: {explained_variance[i]:.3f})')

            # Compare PC1/2 with frequency 1, PC3/4 with frequency 2, etc.
            freq_to_compare = (i // 2) + 1
            cos_basis = np.cos(2 * np.pi * freq_to_compare * output_indices_k / P)
            sin_basis = np.sin(2 * np.pi * freq_to_compare * output_indices_k / P)

            sim_cos = cosine_similarity(pc_vector.reshape(1, -1), cos_basis.reshape(1, -1))[0, 0]
            sim_sin = cosine_similarity(pc_vector.reshape(1, -1), sin_basis.reshape(1, -1))[0, 0]

            # Overlay the more similar basis vector, scaled by projecting the PC onto it.
            if abs(sim_cos) > abs(sim_sin):
                scale = np.dot(pc_vector, cos_basis) / np.dot(cos_basis, cos_basis)
                ax.plot(output_indices_k, scale * cos_basis, marker='.', linestyle=':', color='red', label=f'Scaled Cos(k={freq_to_compare}), Sim={sim_cos:.2f}')
            else:
                scale = np.dot(pc_vector, sin_basis) / np.dot(sin_basis, sin_basis)
                ax.plot(output_indices_k, scale * sin_basis, marker='.', linestyle=':', color='green', label=f'Scaled Sin(k={freq_to_compare}), Sim={sim_sin:.2f}')

            ax.set_ylabel("Component Value")
            ax.set_title(f"Principal Component {i+1} of W_L Columns")
            ax.grid(True, linestyle='--', alpha=0.6)
            ax.axhline(0, color='black', linewidth=0.5)
            ax.legend()

        axes[-1].set_xlabel("Output Logit Index (k)")
        fig.tight_layout()
        plt.show()


Ultra Low-Dimensional Structure: The fact that just TWO principal components explain 98.3% of the variance in the W_L columns is remarkable. It means that despite having 512 neurons, the patterns they use to influence the 20 output logits are overwhelmingly constrained to a simple 2D subspace. Up to the shared mean vector (PCA was run on mean-centered columns), all 512 neuron output vectors (W_L[:, i]) are, to a close approximation, just different linear combinations of these two primary basis vectors (PC1 and PC2 of W_L).


Simplification: This drastically simplifies the unembedding step conceptually. Instead of thinking about 512 complex columns, we know the readout is dominated by combining just two fundamental output shapes (likely sin(k=1) and cos(k=1)).

Revisiting the Whole Algorithm Fit:
Fitting the entire model output (the final logits) as a function of the input PC1/PC2 features now makes sense. We already fit z_mlp and a_mlp to Poly2(InputPCs). Since the W_L structure is also ultra-simple (dominated by 2 components, likely related to output frequency k=1), perhaps the entire chain InputPCs -> Poly2 -> ReLU -> W_L (dominated by 2 PCs) can be effectively approximated by a relatively simple function mapping input PCs to output logits.

Hypothesis: The final Logits[k] for a given k can be approximated by a function of the input features PC1(a), PC1(b), PC2(a), PC2(b), potentially a polynomial, leveraging the low-dimensional structure we found everywhere.

Before checking this hypothesis, we'll check in on the two principal components of the unembedding matrix.


In [None]:
# --- Visualize Top 2 Principal Components ---
# Relies on `principal_components` and `explained_variance` from the W_L PCA cell;
# guard so a failed/skipped PCA cell gives a message instead of a NameError.
if 'principal_components' not in locals() or 'explained_variance' not in locals():
    print("Error: W_L PCA results ('principal_components') not found. Run the PCA cell first.")
else:
    num_pcs_to_plot = min(2, principal_components.shape[0])
    output_indices_k = np.arange(num_outputs)

    print(f"\nVisualizing the shapes of PC1 and PC2 of W_L columns:")

    # squeeze=False keeps `axes` 2-D even for a single subplot, so indexing is uniform.
    fig, axes = plt.subplots(num_pcs_to_plot, 1, figsize=(12, 3 * num_pcs_to_plot), sharex=True, squeeze=False)
    axes = axes[:, 0]

    for i, ax in enumerate(axes):
        pc_vector = principal_components[i, :] # PC{i+1}

        ax.plot(output_indices_k, pc_vector, marker='o', linestyle='-', label=f'PC {i+1} (Var: {explained_variance[i]:.3f})')

        ax.set_ylabel("Component Value")
        ax.set_title(f"Principal Component {i+1} of W_L Columns (Neuron Output Patterns)")
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.axhline(0, color='black', linewidth=0.5)
        ax.legend()

    axes[-1].set_xlabel("Output Logit Index (k)")
    axes[-1].set_xticks(output_indices_k) # Show all logit indices if not too many
    fig.tight_layout()
    plt.show()


In [None]:
print("\nFitting FINAL LOGITS vs. Polynomial(PC1, PC2 features)...")

# --- Check if data is available ---
if 'X_poly_features' not in locals() or X_poly_features is None:
    print("Error: 'X_poly_features' not found.")
elif 'original_logits_pred_pos' not in locals() or original_logits_pred_pos is None:
    print("Error: 'original_logits_pred_pos' not found.")
else:
    num_samples, num_poly_feats = X_poly_features.shape
    num_samples_l, num_logits = original_logits_pred_pos.shape
    assert num_logits == num_outputs
    assert num_samples == num_samples_l

    # Per-logit R² (NaN for skipped logits) and a mask of logits that were fitted.
    logit_r2_scores_fit = np.zeros(num_outputs)
    logit_fit_successful = np.zeros(num_outputs, dtype=bool)

    # --- Fit Models for Each Logit ---
    print(f"Fitting polynomial model for {num_outputs} output logits...")
    for k in range(num_outputs):
        y_logit_k = original_logits_pred_pos[:, k]

        # R² is undefined for a (near-)constant target.
        if np.std(y_logit_k) < 1e-9:
            logit_r2_scores_fit[k] = np.nan
            continue

        # Fit model: Logit[k] ≈ Poly2(p1a, p1b, p2a, p2b)
        poly_reg_logit_k = LinearRegression(fit_intercept=False) # Bias included in X_poly_features
        poly_reg_logit_k.fit(X_poly_features, y_logit_k)
        logit_r2_scores_fit[k] = poly_reg_logit_k.score(X_poly_features, y_logit_k)
        logit_fit_successful[k] = True

    print("Fitting complete.")

    # --- Analysis ---
    valid_r2_scores = logit_r2_scores_fit[logit_fit_successful]

    if len(valid_r2_scores) > 0:
        # Every statistic uses only the fitted logits, so a skipped (NaN) logit cannot
        # turn the reported max (or anything else) into NaN.
        print("\n--- Analysis of R-squared (Fit Logit[k] vs Poly2(PC1, PC2 features)) ---")
        print(f"Mean R² across logits:   {np.mean(valid_r2_scores):.4f}")
        print(f"Median R² across logits: {np.median(valid_r2_scores):.4f}")
        print(f"Min R² across logits:    {np.min(valid_r2_scores):.4f}")
        print(f"Max R² across logits:    {np.max(valid_r2_scores):.4f}")
        print(f"Logits with R² > 0.95: {np.sum(valid_r2_scores > 0.95)} / {len(valid_r2_scores)}")
        print(f"Logits with R² > 0.99: {np.sum(valid_r2_scores > 0.99)} / {len(valid_r2_scores)}")

        # --- Visualize R-squared per Logit ---
        plt.figure(figsize=(12, 6))
        output_indices = np.arange(num_outputs)
        plt.bar(output_indices[logit_fit_successful], valid_r2_scores, width=0.9)
        plt.title("R² Score per Output Logit (Fit Logit[k] vs. Poly2(PC1, PC2 Features))")
        plt.xlabel("Output Logit Index (k)")
        plt.ylabel("R-squared Score")
        plt.xticks(output_indices)
        plt.ylim(min(0.9, np.min(valid_r2_scores) - 0.01), 1.01) # Zoom in on high values
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.show()
    else:
        print("Could not calculate R-squared for any logits.")


Interpretation:

Quadratic Sufficiency: A degree-2 polynomial function of the input PC1/PC2 features is sufficient to approximate the final output logits with very high accuracy (R² > 0.96 for all logits, mean ≈ 0.975).

Effective Algorithm IS Simple: While the mechanistic steps involve ReLU and a Fourier-structured W_L, the net functional mapping from the dominant input features to the final logit values behaves as if it were a relatively simple quadratic function. The complexities introduced by ReLU and the specific W_L structure largely cancel out or combine in a way that is well-approximated polynomially.

Next, we find the logit k with the highest R² (≈ 0.9906) in the results above and inspect its polynomial coefficients. These show the "recipe" the model uses to compute the score for that particular candidate sum.


In [None]:
print("\nAnalyzing coefficients for the logit with the highest R-squared fit...")

# --- Locate the best-fitting logit ---
# Logits whose fit was skipped carry a NaN R², so only finite scores are ranked.
valid_logit_indices = np.where(~np.isnan(logit_r2_scores_fit))[0]
if valid_logit_indices.size == 0:
    print("Error: No valid R-squared scores found for any logit.")
else:
    scores_of_valid = logit_r2_scores_fit[valid_logit_indices]
    best_logit_k_idx_in_valid = np.argmax(scores_of_valid)
    # Map the position inside the valid subset back to the real logit index k
    best_logit_k = valid_logit_indices[best_logit_k_idx_in_valid]
    best_r2 = logit_r2_scores_fit[best_logit_k]
    print(f"Best fit found for Logit {best_logit_k} with R² = {best_r2:.4f}")

    # --- Refit only this logit so its coefficients are available ---
    # Logit[k_best] ≈ Poly2(p1a, p1b, p2a, p2b); the bias column lives in X_poly_features.
    y_best_logit = original_logits_pred_pos[:, best_logit_k]
    poly_reg_best_logit = LinearRegression(fit_intercept=False)
    poly_reg_best_logit.fit(X_poly_features, y_best_logit)

    # --- Tabulate the coefficients, largest magnitude first ---
    coeffs_best_logit = poly_reg_best_logit.coef_
    poly_feature_names = poly.get_feature_names_out(['p1a', 'p1b', 'p2a', 'p2b'])

    print(f"\n--- Coefficients for Logit {best_logit_k} (R²={best_r2:.4f}) ---")
    coeffs_df_best = pd.DataFrame(
        {'Feature': poly_feature_names, 'Coefficient': coeffs_best_logit}
    )
    coeffs_df_best['AbsCoefficient'] = coeffs_df_best['Coefficient'].abs()
    ranked_coeffs = coeffs_df_best.sort_values('AbsCoefficient', ascending=False)
    print(ranked_coeffs.to_string(index=False, float_format="%.4f"))


We found that the final logits are extremely well predicted by a degree-2 polynomial of the input PC features. To simplify the interpretation of the specific polynomial learned for the best-fitting logit (k=2, R²≈0.9906), we will leverage the observed symmetry between a and b features. We refit Logit[2] using only symmetric quadratic features: p1_sum = PC1(a)+PC1(b), p2_sum = PC2(a)+PC2(b), p1_sum^2, p2_sum^2, and p1_sum * p2_sum. Analyzing the coefficients of this simpler 5-feature model will make the core computational relationship between the combined input features and the logit value clearer, while also quantifying any small loss in accuracy due to enforcing perfect symmetry.


In [None]:
# Refit the best logit using only the 5 symmetric quadratic features plus an
# intercept, and compare its R² with the full Poly2 fit.

# --- Check that everything this cell needs has been computed ---
if 'X_symm_quad_features' not in locals() or X_symm_quad_features is None:
    print("Error: Symmetric quadratic features ('X_symm_quad_features') not found.")
elif 'symm_quad_feature_names' not in locals() or symm_quad_feature_names is None:
    print("Error: Symmetric quadratic feature names ('symm_quad_feature_names') not found.")
elif 'original_logits_pred_pos' not in locals() or original_logits_pred_pos is None:
    print("Error: Actual original logits ('original_logits_pred_pos') not found.")
elif 'best_logit_k' not in locals() or best_logit_k is None:
    print("Error: Index of best logit ('best_logit_k') not found.")
else:
    # Printed only after the checks, because best_logit_k may not exist before them.
    print(f"\nFitting Logit {best_logit_k} vs. SYMMETRIC Quadratic Features...")

    # --- Prepare Target Data ---
    y_best_logit = original_logits_pred_pos[:, best_logit_k]

    # --- Fit Symmetric Model ---
    # Fit model: Logit[k_best] ≈ SymmQuadPoly(p1s, p2s) + Bias
    symm_quad_reg_logit_k = LinearRegression(fit_intercept=True)
    symm_quad_reg_logit_k.fit(X_symm_quad_features, y_best_logit)

    # --- Get Results ---
    coeffs_symm_quad = symm_quad_reg_logit_k.coef_
    intercept_symm_quad = symm_quad_reg_logit_k.intercept_
    r2_symm_quad = symm_quad_reg_logit_k.score(X_symm_quad_features, y_best_logit)

    print(f"\n--- Coefficients for Logit {best_logit_k} using Symmetric Features (R²={r2_symm_quad:.4f}) ---")
    print(f"Intercept: {intercept_symm_quad:.4f}")
    coeffs_df_symm = pd.DataFrame({
        'Feature': symm_quad_feature_names,
        'Coefficient': coeffs_symm_quad
    })
    # Add absolute value for sorting
    coeffs_df_symm['AbsCoefficient'] = np.abs(coeffs_df_symm['Coefficient'])
    print(coeffs_df_symm.sort_values('AbsCoefficient', ascending=False).to_string(index=False, float_format="%.4f"))

    # --- Compare R² ---
    if 'best_r2' in locals():
        r2_drop = best_r2 - r2_symm_quad
        print(f"\nComparison to Full Poly2 Fit R² ({best_r2:.4f}):")
        print(f"  R² drop from enforcing symmetry: {r2_drop:.6f}")
    else:
        print("\nCannot compare R² (previous best R² value not found).")


**Project Summary: Reverse-Engineering Addition in a Small Transformer**

**Goal:** Understand the algorithm learned by a simple 1-layer Transformer trained to compute `c = a + b`, where `a` and `b` are single digits (0-9) and the output `c` ranges from 0-18. This was initially conceived as a simplified Fibonacci step.

**1. Input Representation (Embeddings `W_E`):**

*   **Method:** Principal Component Analysis (PCA) on the learned token embeddings (0-9).
*   **Findings:** The embeddings exist predominantly on a 2D manifold.
    *   **PC1 (~73% variance):** Correlates linearly with the token value (`i`).
    *   **PC2 (~25% variance, total ~99.4%):** Correlates quadratically/parabolically with the token value (`i^2`).
    *   **Symmetry:** The embedding structure treats `a` and `b` sources symmetrically.
*   **Frequency View:** Fourier analysis of `W_E` confirmed this, showing dominant power at frequency indices k=1 (corresponding to linear PC1) and k=2 (corresponding to quadratic PC2) across the input token dimension (0-9).
*   **Interpretation:** The model *learns* to represent input numbers primarily using their linear value and a quadratic/boundary-sensitive feature.

**2. The Necessity of PC2 (Ablation Studies):**

*   **Method:** Ablated (removed) the PC2 component from the embeddings before feeding them into the model.
*   **Findings:** Model accuracy plummeted from ~100% to ~42%. Failures were systematic:
    *   Inputs involving 0 (e.g., `(0,x)`) caused over-prediction.
    *   Inputs involving 9 (e.g., `(9,x)`) caused under-prediction.
    *   Diagonal inputs `(x,x)` often failed.
    *   Extreme asymmetry `(0,9)` and `(9,0)` were predicted *correctly*.
    *   Ablating the MLP skip connection had negligible impact.
*   **Interpretation:** PC2 is causally crucial for correct computation, specifically for handling boundary conditions. The linear PC1 feature alone is insufficient for the network (with ReLU and `W_L`) to generalize addition correctly.

**3. MLP Computation (Pre-ReLU `z_mlp`):**

*   **Method:** Fitted the pre-ReLU activations of the full model (`z_mlp_full`) to functions of the input PC features.
*   **Findings:**
    *   `z_mlp_full` is **extremely well-explained (mean R² ≈ 0.998 for all 512 neurons)** by a **degree-2 polynomial** of the four input features: `PC1(a), PC1(b), PC2(a), PC2(b)`.
    *   The fit using only *linear* terms of these features was significantly worse (mean R² ≈ 0.75), confirming quadratic terms are essential pre-ReLU.
    *   The learned polynomials exhibited strong **symmetry** between `a` and `b` features. A fit using only 5 symmetric quadratic features (like `PC1(a)+PC1(b)`, `(PC1(a)+PC1(b))^2`, etc.) still achieved high accuracy (mean R² ≈ 0.98).
    *   The *change* induced by PC2 (`delta_z_mlp = z_full - z_ablated`) was well-fit by a simple linear/quadratic function of `PC2(a), PC2(b)`.
*   **Interpretation:** The MLP's core computation *before* the non-linearity is surprisingly simple and mathematically characterizable: it computes a specific, symmetric quadratic polynomial of the dominant input features (PC1, PC2) for each neuron.

**4. MLP Activation & ReLU Impact:**

*   **Method:** Applied ReLU (`a_mlp = max(0, z_mlp)`). Fitted `a_mlp` vs inputs. Compared pre/post ReLU linearity. Analyzed neuron state flips caused by PC2.
*   **Findings:**
    *   Fitting `a_mlp_full` with `Poly2(PCs)` still yields a good fit (mean R² ≈ 0.975), but noticeably worse than the near-perfect pre-ReLU fit.
    *   ReLU significantly reduces the direct linear correlation between activations and the sum `a+b` (both for the full and ablated models, with a similar R² drop).
    *   Analysis of neuron flips (e.g., for input (0,0)) showed PC2 causes specific neurons (those suppressing logit 0 and boosting logit 3) to turn OFF, directly implementing the needed correction via `W_L`.
*   **Interpretation:** ReLU is the main source of complex non-linearity. PC2 corrects errors by strategically shifting `z_mlp` relative to the ReLU threshold, thereby controlling which neurons contribute to the final sum via `W_L`. Approximating `a_mlp` itself as `Poly2(PCs)` is functionally quite effective, although mechanistically less precise than `ReLU(Poly2(z_mlp))`.

**5. Unembedding Readout (`W_L` & Logits):**

*   **Method:** Fourier analysis of final logits and `W_L` columns. PCA on `W_L` columns. Simulation of final output.
*   **Findings:**
    *   Final logits (e.g., for inputs (0,0), (4,4)) are dominated by **low-frequency Fourier components** (indices 1, 2, 3 most prominent, up to ~6 relevant).
    *   The columns of `W_L` are also dominated by these same low frequencies, but typically as **combinations/mixes**, not pure sinusoids.
    *   PCA on `W_L` columns revealed an **ultra-low dimensional structure**: the top 2 PCs capture 98.3% of the variance and likely correspond to output frequency k=1 (sin/cos).
    *   Simulating the final output using `Logits ≈ W_L @ ReLU(Poly2(PCs))` achieved very high accuracy (Overall R² ≈ 0.988), closely matching the real model. Simulating using `Logits ≈ W_L @ Poly2(a_mlp)` was slightly less accurate (R² ≈ 0.966).
*   **Interpretation:** `W_L` acts as a readout mechanism structured around a low-dimensional, low-frequency Fourier basis (represented distributively across neurons). It decodes the post-ReLU activation pattern (`a_mlp`) to synthesize the final logit shape.

**Synthesized Algorithm:**

The network learns to compute `a+b` not directly, but by:
1.  Encoding inputs `a, b` using features dominated by PC1 (linear value) and PC2 (quadratic value).
2.  Computing pre-ReLU activations `z_mlp` via a highly accurate, symmetric **degree-2 polynomial function** of these PC features (`z_mlp ≈ Poly2(PCs)`). PC2 terms are essential here.
3.  Applying **ReLU** (`a_mlp = ReLU(z_mlp)`), which selects active neurons based on the polynomial result (PC2 critically influences this near the threshold for boundary inputs).
4.  Performing a linear readout **`Logits = W_L @ a_mlp`**, where `W_L` uses a low-dimensional basis strongly related to **low output frequencies (k=1, 2, 3...)** to combine the activations and synthesize a logit vector peaking at the correct sum `k=a+b`.

Functionally, the entire process is well-approximated by `Logit[k] ≈ Poly2'_k(PCs)`, where `Poly2'_k` is a specific quadratic function learned for each output logit index `k`.

**Significance:** This investigation fully characterized the learned algorithm, revealing why a quadratic computation (using PC2) emerged as necessary for this network to solve linear addition (boundary correction). It connected geometric (PCA), frequency (Fourier), and functional (polynomial fitting) perspectives, and pinpointed the role of each network component (Embedding, MLP pre/post ReLU, Unembedding).

**Next Step: Actual Fibonacci Sequence Prediction**

*   **Why Interesting:** Moves beyond simple addition to include sequence processing. The model must learn *which* inputs to add (`F(n-1)` and `F(n-2)`).
*   **Questions:** How does the attention mechanism learn to find the correct previous terms? Does it use positional encodings heavily? Is the core addition circuit (MLP/W_L) the same as found here, or does it adapt? How does the network handle the initial terms of the sequence? Analyzing attention patterns and comparing the MLP/W_L structure to the current findings would be key.

This provides a compelling narrative arc from a simple problem setup through detailed investigation to a nuanced algorithmic understanding and clear future directions.

While our analysis of the MLP pre-activations (z_mlp) revealed strong symmetry with respect to inputs a and b, we haven't explicitly verified the role of the attention layer preceding it. To confirm whether attention treats the a and b inputs symmetrically (e.g., by averaging their representations) or introduces asymmetry that the MLP later compensates for, we will analyze the attention patterns. Specifically, we will hook the attention probabilities and examine the attention paid by the query at the final sequence position (where the prediction is made) to the keys corresponding to input a (position 0) and input b (position 1). We will visualize these attention weights across all input pairs (a, b) for each head and calculate the average weights to quantify the attention distribution.


In [None]:
# Take the head count from the model config rather than hard-coding it, so the
# extraction, table and heatmaps below cover every head of layer 0.
n_heads = model.cfg.n_heads

# --- Prepare Inputs (all N*N pairs) ---
a_inputs = np.arange(N)
b_inputs = np.arange(N)
# Rows are (a, b) with a varying slowest, i.e. row index = a * N + b
input_pairs = np.array(np.meshgrid(a_inputs, b_inputs)).T.reshape(-1, 2) # Shape [N*N, 2]
num_samples = input_pairs.shape[0]

# Create input tokens: [a, b, =] for each pair
equals_column = torch.full((num_samples, 1), int(equals_token_id), dtype=torch.long)
test_input_tokens_all = torch.cat(
    [torch.as_tensor(input_pairs, dtype=torch.long), equals_column], dim=1
)

print(f"Created input tensor shape: {test_input_tokens_all.shape}")

# --- Hook Setup ---
# Hook point for attention probabilities/pattern AFTER softmax
attn_pattern_hook_point = utils.get_act_name("pattern", 0) # For layer 0
print(f"Using attention pattern hook point: {attn_pattern_hook_point}")

# Storage for attention patterns [batch, n_heads, query_pos, key_pos]
captured_attn_patterns = None

def store_attn_pattern_hook(attn_pattern, hook):
    """Forward hook: keep a detached CPU copy of the pattern and pass it through unchanged."""
    global captured_attn_patterns
    captured_attn_patterns = attn_pattern.detach().cpu()
    return attn_pattern

# --- Run Model with Hook ---
print("Running model to capture attention patterns...")
model.eval()
with torch.no_grad():
    try:
        _ = model.run_with_hooks(
            test_input_tokens_all,
            fwd_hooks=[(attn_pattern_hook_point, store_attn_pattern_hook)]
        )
    except Exception as e:
        # Best-effort: report and skip the analysis below instead of aborting the notebook
        print(f"ERROR during model run with hook: {e}")
        print(f"Please verify the hook point name '{attn_pattern_hook_point}' is correct for attention patterns.")
        captured_attn_patterns = None

# --- Analyze Captured Patterns ---
if captured_attn_patterns is not None:
    print(f"Captured attention pattern shape: {captured_attn_patterns.shape}") # Expected [N*N, n_heads, 3, 3]

    # Attention FROM the final position (query_pos=2) TO each key position.
    # Shape: [batch=N*N, n_heads]
    attn_to_pos0_tensor = captured_attn_patterns[:, :, 2, 0]
    attn_to_pos1_tensor = captured_attn_patterns[:, :, 2, 1]
    attn_to_pos2_tensor = captured_attn_patterns[:, :, 2, 2] # Attention to '=' token itself

    attn_to_pos0 = attn_to_pos0_tensor.numpy()
    attn_to_pos1 = attn_to_pos1_tensor.numpy()
    attn_to_pos2 = attn_to_pos2_tensor.numpy()

    # Scatter the per-sample values onto an (a, b) grid:
    # attn_heatmaps_posX[head_idx, a, b] = attention weight for input pair (a, b)
    a_idx = input_pairs[:, 0]
    b_idx = input_pairs[:, 1]
    attn_heatmaps_pos0 = np.zeros((n_heads, N, N))
    attn_heatmaps_pos1 = np.zeros((n_heads, N, N))
    attn_heatmaps_pos2 = np.zeros((n_heads, N, N)) # For checking attention to '='
    attn_heatmaps_pos0[:, a_idx, b_idx] = attn_to_pos0.T
    attn_heatmaps_pos1[:, a_idx, b_idx] = attn_to_pos1.T
    attn_heatmaps_pos2[:, a_idx, b_idx] = attn_to_pos2.T

    # --- Calculate Averages ---
    avg_attn_to_pos0 = np.mean(attn_to_pos0, axis=0) # Average over batch -> shape [n_heads,]
    avg_attn_to_pos1 = np.mean(attn_to_pos1, axis=0)
    avg_attn_to_pos2 = np.mean(attn_to_pos2, axis=0)

    print("\n--- Average Attention Weights (Query='=', Key=Position) ---")
    avg_df = pd.DataFrame({
        'Head': np.arange(n_heads),
        'Avg Attn to Pos 0 (a)': avg_attn_to_pos0,
        'Avg Attn to Pos 1 (b)': avg_attn_to_pos1,
        'Avg Attn to Pos 2 (=)': avg_attn_to_pos2,
        'Sum Avg Attn': avg_attn_to_pos0 + avg_attn_to_pos1 + avg_attn_to_pos2
    })
    print(avg_df.to_string(index=False, float_format="%.4f"))

    # --- Visualize Heatmaps ---
    # Two panels per head (attention to 'a', attention to 'b'), two heads per row -> 4 columns
    num_rows = (n_heads + 1) // 2
    fig, axes = plt.subplots(num_rows, 4, figsize=(18, 5 * num_rows), squeeze=False)
    axes_flat = axes.flatten()

    for head_idx in range(n_heads):
        # Plot Attn(= -> a)
        ax = axes_flat[head_idx*2]
        sns.heatmap(attn_heatmaps_pos0[head_idx], annot=False, fmt=".2f", cmap="viridis", ax=ax,
                    linewidths=.5, linecolor='lightgray', cbar=True, square=True, vmin=0, vmax=1)
        ax.set_title(f"Head {head_idx}: Attn Query='=' -> Key='a' (Pos 0)")
        ax.set_xlabel("Input b")
        ax.set_ylabel("Input a")
        ax.invert_yaxis()

        # Plot Attn(= -> b)
        ax = axes_flat[head_idx*2 + 1]
        sns.heatmap(attn_heatmaps_pos1[head_idx], annot=False, fmt=".2f", cmap="viridis", ax=ax,
                    linewidths=.5, linecolor='lightgray', cbar=True, square=True, vmin=0, vmax=1)
        ax.set_title(f"Head {head_idx}: Attn Query='=' -> Key='b' (Pos 1)")
        ax.set_xlabel("Input b")
        ax.set_ylabel("Input a")
        ax.invert_yaxis()

    # Hide any unused axes
    for i in range(n_heads * 2, len(axes_flat)):
        axes_flat[i].set_visible(False)

    fig.suptitle("Attention Patterns from Final Position ('=') Query", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.96])
    plt.show()

else:
    print("\nAttention pattern analysis skipped due to error during model run.")


All heads split their attention almost equally between 'a' and 'b'. The metric here is the post-softmax attention weight from the '=' query position to each key position, averaged over all 100 (a, b) input pairs (see the table above). For every head, the averages for 'a' and 'b' differ by only tenths of a percent. This was expected, as we found the transformer treats a and b very symmetrically.

Two heads effectively ignore the '=' token, giving it less than 1% of their attention. The third gives it roughly 2%, and the fourth ~12%. Each heatmap cell (a, b) shows, for that input pair, the attention weight from '=' to 'a' (left panel of each pair) or to 'b' (right panel). The plots show that the attention varies with the input combination, with a different but clearly structured pattern for each head. **TODO:** explain what these per-head patterns mean.

### Looking at Activations

Helper variables:

In [None]:
# Layer-0 attention weights from the final query position to the 'a' (pos 0) and
# 'b' (pos 1) keys, for every input and head.
pattern_a = cache["pattern", 0, "attn"][:, :, -1, 0]
pattern_b = cache["pattern", 0, "attn"][:, :, -1, 1]
# Layer-0 MLP activations at the final position, after and before the nonlinearity.
neuron_acts = cache["post", 0, "mlp"][:, -1, :]
neuron_pre_acts = cache["pre", 0, "mlp"][:, -1, :]

Get all shapes:

In [None]:
# List every cached activation name together with its tensor shape.
for param_name, param in cache.items():
    print(param_name, param.shape)

In [None]:
# Batch-averaged layer-0 attention from the final position to each source position, one row per head.
imshow(cache["pattern", 0].mean(dim=0)[:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [None]:
# Layer-0 attention from the final position for ONE input (batch index 5), one row per head.
# This is not an average, so the title names the input index.
example_idx = 5
imshow(cache["pattern", 0][example_idx][:, -1, :], title=f"Attention Pattern per Head for Input #{example_idx}", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [None]:
# Inspect the first few input rows of the dataset.
dataset[:4]

In [None]:
# Head 0's attention from the final position to 'a', laid out as a p x p grid.
# NOTE(review): reshape(p, p) assumes the batch enumerates (a, b) with a varying slowest — confirm against dataset.
imshow(cache["pattern", 0][:, 0, -1, 0].reshape(p, p), title="Attention for Head 0 from a -> =", xaxis="b", yaxis="a")

In [None]:
# Attention from the final position to 'a' as a p x p (a, b) grid, one facet per head.
# The title covers every head, because facet_col=0 plots all of them.
imshow(
    einops.rearrange(cache["pattern", 0][:, :, -1, 0], "(a b) head -> head a b", a=p, b=p),
    title="Attention from a -> = for Each Head", xaxis="b", yaxis="a", facet_col=0)

Plotting neuron activations

In [None]:
# Shape of the cached layer-0 MLP post-activation tensor.
cache["post", 0, "mlp"].shape

In [None]:
# Activations of the first 5 MLP neurons as p x p (a, b) grids, one facet per neuron.
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

### Singular Value Decomposition

In [None]:
# Shape of the embedding matrix analysed below.
W_E.shape


In [None]:
# SVD of the embedding. The singular values show how many directions W_E uses;
# U's columns are the principal directions over the input vocabulary.
# torch.linalg.svd replaces the deprecated torch.svd and returns (U, S, Vh) with
# Vh = V^H, matching these names (torch.svd returns V itself). full_matrices=False
# gives the same reduced factors as torch.svd's default some=True.
U, S, Vh = torch.linalg.svd(W_E, full_matrices=False)
line(S, title="Singular Values")
imshow(U, title="Principal Components on the Input")

In [None]:
# Control - random Gaussian matrix
U, S, Vh = torch.svd(torch.randn_like(W_E))
line(S, title="Singular Values Random")
imshow(U, title="Principal Components Random")

## Explaining Algorithm

### Analyse the Embedding - It's a Lookup Table!

In [None]:
# Plot the top 8 principal directions of the embedding over the input vocabulary.
# torch.linalg.svd (non-deprecated) returns (U, S, Vh) with Vh = V^H.
U, S, Vh = torch.linalg.svd(W_E, full_matrices=False)
line(U[:, :8].T, title="Principal Components of the embedding", xaxis="Input Vocabulary")

In [None]:
# Orthonormal real Fourier basis over the p input values, one component per row:
# [Constant, Sin 1, Cos 1, ..., Sin p//2, Cos p//2], so (for odd p) Sin f is row
# 2f-1 and Cos f is row 2f. For even p the Nyquist sine sin(pi * x) is identically
# zero and cannot be normalised (0/0 -> NaN), so it is omitted; for odd p the
# basis is exactly the list above.
positions = torch.arange(p)
fourier_basis = [torch.ones(p)]
fourier_basis_names = ["Constant"]
for freq in range(1, p//2+1):
    angles = positions * 2 * torch.pi * freq / p
    is_nyquist = (2 * freq == p)
    if not is_nyquist:
        fourier_basis.append(torch.sin(angles))
        fourier_basis_names.append(f"Sin {freq}")
    fourier_basis.append(torch.cos(angles))
    fourier_basis_names.append(f"Cos {freq}")
fourier_basis = torch.stack(fourier_basis, dim=0).to(device)
fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)
imshow(fourier_basis, xaxis="Input", yaxis="Component", y=fourier_basis_names)

In [None]:
# Show the lowest-frequency basis vectors, and a few taken from the middle of the basis.
line(fourier_basis[:8], xaxis="Input", line_labels=fourier_basis_names[:8], title="First 8 Fourier Components")
line(fourier_basis[25:29], xaxis="Input", line_labels=fourier_basis_names[25:29], title="Middle Fourier Components")

In [None]:
# Gram matrix of the row-normalised basis; an orthonormal basis gives the identity.
imshow(fourier_basis @ fourier_basis.T, title="All Fourier Vectors are Orthogonal")

### Analyse the Embedding

In [None]:
# Embedding expressed in the Fourier basis: rows are Fourier components, columns residual-stream dims.
# NOTE(review): requires W_E to have exactly p rows (one per input value) — confirm how W_E was sliced.
imshow(fourier_basis @ W_E, yaxis="Fourier Component", xaxis="Residual Stream", y=fourier_basis_names, title="Embedding in Fourier Basis")

In [None]:
# Norm of the embedding along each Fourier component (taken over the residual-stream dim).
line((fourier_basis @ W_E).norm(dim=-1), xaxis="Fourier Component", x=fourier_basis_names, title="Norms of Embedding in Fourier Basis")

In [None]:
# Key frequencies of this trained model (hard-coded; presumably read off the
# embedding-norm plot above — re-check after retraining).
key_freqs = [17, 25, 32, 47]
# Row indices of (Sin f, Cos f) in fourier_basis, derived from key_freqs so the
# two lists cannot drift apart: [33, 34, 49, 50, 63, 64, 93, 94].
key_freq_indices = [idx for freq in key_freqs for idx in (2 * freq - 1, 2 * freq)]
fourier_embed = fourier_basis @ W_E
key_fourier_embed = fourier_embed[key_freq_indices]
print("key_fourier_embed", key_fourier_embed.shape)
imshow(key_fourier_embed @ key_fourier_embed.T, title="Dot Product of embedding of key Fourier Terms")

### Key Frequencies

In [None]:
# Cosine basis vectors of the key frequencies (Cos f is row 2f of fourier_basis),
# derived from key_freqs instead of hard-coded indices.
key_cos_indices = [2 * freq for freq in key_freqs]
line(fourier_basis[key_cos_indices], title="Cos of key freqs", line_labels=[fourier_basis_names[i] for i in key_cos_indices])

In [None]:
# Average of the key-frequency cosines (Cos f is row 2f of fourier_basis): the
# terms agree only near input 0, so their mean peaks there.
key_cos_indices = [2 * freq for freq in key_freqs]
line(fourier_basis[key_cos_indices].mean(0), title="Constructive Interference")

## Analyse Neurons

In [None]:
# Activations of the first 5 MLP neurons as p x p (a, b) grids, one facet per neuron.
# NOTE(review): identical to an earlier cell; kept here so this section reads on its own.
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

In [None]:
# Activation of neuron 0 over all inputs, as a p x p (a, b) grid.
imshow(
    einops.rearrange(neuron_acts[:, 0], "(a b) -> a b", a=p, b=p),
    title="First neuron act", xaxis="b", yaxis="a",)

In [None]:
# Outer product of basis row 94 (Cos 47) with itself: the 2D pattern cos(w a) * cos(w b) for w = 47.
imshow(fourier_basis[94][None, :] * fourier_basis[94][:, None], title="Cos 47a * cos 47b")

In [None]:
# Outer product of the Constant row (varies along rows, via [:, None]) with Cos 47 (varies along columns, via [None, :]).
# NOTE(review): other plots put a on the rows; here the rows are the constant axis, so "Cos 47a" may really be Cos 47b — confirm.
imshow(fourier_basis[94][None, :] * fourier_basis[0][:, None], title="Cos 47a * const")

In [None]:
# 2D Fourier transform of neuron 0's (a, b) activation grid: rows index the Fourier component over a, columns over b.
imshow(fourier_basis @ neuron_acts[:, 0].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 0", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
# 2D Fourier transform of neuron 5's (a, b) activation grid.
imshow(fourier_basis @ neuron_acts[:, 5].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 5", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
# Control: 2D Fourier transform of Gaussian noise with the same shape as one neuron's activations.
imshow(fourier_basis @ torch.randn_like(neuron_acts[:, 0]).reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of RANDOM", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

### Neuron Clusters

In [None]:
# 2D Fourier transform of every neuron's (a, b) activation grid:
# result[neuron, i, j] = sum_{a,b} fourier_basis[i, a] * acts[neuron, a, b] * fourier_basis[j, b]
fourier_neuron_acts = fourier_basis @ einops.rearrange(neuron_acts, "(a b) neuron -> neuron a b", a=p, b=p) @ fourier_basis.T
# Center these by removing the mean - doesn't matter!
# (Entry (0, 0) is the Constant x Constant coefficient, proportional to each neuron's mean.)
fourier_neuron_acts[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts.shape)

In [None]:
# For each frequency f = freq + 1: the fraction of each neuron's 2D Fourier energy
# lying in the terms built from {Constant, Sin f, Cos f} along both a and b.
neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).to(device)
for freq in range(0, p//2):
    freq_terms = [0, 2*(freq+1) - 1, 2*(freq+1)]
    for x in freq_terms:
        for y in freq_terms:
            neuron_freq_norm[freq] += fourier_neuron_acts[:, x, y]**2
# A neuron whose activation is constant over all inputs (e.g. a dead ReLU) has zero
# total energy once the (0, 0) term is zeroed. Clamp the denominator so such a neuron
# reports 0 instead of 0/0 = NaN.
total_energy = fourier_neuron_acts.pow(2).sum(dim=[-1, -2])
neuron_freq_norm = neuron_freq_norm / total_energy.clamp_min(1e-12)[None, :]
imshow(neuron_freq_norm, xaxis="Neuron", yaxis="Freq", y=torch.arange(1, p//2+1), title="Neuron Frac Explained by Freq")

In [None]:
# For each neuron, the largest fraction explained by any single frequency, sorted ascending.
line(neuron_freq_norm.max(dim=0).values.sort().values, xaxis="Neuron", title="Max Neuron Frac Explained over Freqs")

## Read Off the Neuron-Logit Weights to Interpret

In [None]:
# Direct neuron -> logit weights: layer-0 MLP output weights composed with the unembedding.
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

In [None]:
# Coefficients of each neuron's logit weights on every Fourier component, with the norm taken over neurons.
line((W_logit @ fourier_basis.T).norm(dim=0), x=fourier_basis_names, title="W_logit in the Fourier Basis")

In [None]:
# Boolean mask of neurons for which frequency 17 (row 17-1) explains more than 85% of the Fourier energy.
neurons_17 = neuron_freq_norm[17-1]>0.85
neurons_17.shape

In [None]:
# Number of frequency-17 neurons.
neurons_17.sum()

In [None]:
# Fourier spectrum of the logit weights, restricted to the frequency-17 neurons.
line((W_logit[neurons_17] @ fourier_basis.T).norm(dim=0), x=fourier_basis_names, title="W_logit for freq 17 neurons in the Fourier Basis")

Study sin 17

In [None]:
# Project each neuron's logit weights onto the Fourier basis over output values,
# as in the "W_logit in the Fourier Basis" plot: column j is the coefficient on
# fourier_basis[j]. The basis must be transposed for this projection.
freq = 17
W_logit_fourier = W_logit @ fourier_basis.T
# Row 2*freq - 1 of fourier_basis is "Sin {freq}"
neurons_sin_17 = W_logit_fourier[:, 2*freq-1]
line(neurons_sin_17)

In [None]:
# Final-position MLP activations: [batch, d_mlp] (batch presumably all p*p inputs).
neuron_acts.shape

In [None]:
# Per-input contribution of all neurons along the neurons_sin_17 weight vector, then its 2D Fourier transform over (a, b).
inputs_sin_17c = neuron_acts @ neurons_sin_17
imshow(fourier_basis @ inputs_sin_17c.reshape(p, p) @ fourier_basis.T, title="Fourier Heatmap over inputs for sin17c", x=fourier_basis_names, y=fourier_basis_names)

# Black Box Methods + Progress Measures

## Setup Code

Code to plot embedding freqs

In [None]:
def embed_to_cos_sin(fourier_embed):
    """Split Fourier-basis coefficients into their sin and cos parts.

    Follows the basis ordering [Constant, Sin 1, Cos 1, Sin 2, Cos 2, ...]:
    odd positions are sines, even positions after 0 are cosines, and the
    constant term is dropped. Despite the name, sines come first.

    A 1D input of length 2K+1 yields a [2, K] tensor (row 0 = sin, row 1 = cos).
    A 2D input is split along its second axis and yields [rows, 2, K].
    """
    if fourier_embed.ndim != 1:
        return torch.stack([fourier_embed[:, 1::2], fourier_embed[:, 2::2]], dim=1)
    sin_part = fourier_embed[1::2]
    cos_part = fourier_embed[2::2]
    return torch.stack([sin_part, cos_part])

from neel_plotly.plot import melt

def plot_embed_bars(
    fourier_embed,
    title="Norm of embedding of each Fourier Component",
    return_fig=False,
    **kwargs
):
    """Grouped bar chart of the sin/cos parts of a Fourier-basis vector.

    ``fourier_embed`` is split with embed_to_cos_sin, melted to long format and
    drawn with px.bar as one (sin, cos) bar pair per frequency. Extra keyword
    arguments are forwarded to px.bar. Returns the figure if ``return_fig`` is
    true; otherwise shows it and returns None.
    """
    long_df = melt(embed_to_cos_sin(fourier_embed))
    # Column "0" indexes the stacked sin/cos axis (0 -> sin, 1 -> cos); column "1" is used as x.
    trig_names = {0: "sin", 1: "cos"}
    long_df["Trig"] = long_df["0"].map(lambda idx: trig_names[idx])
    figure = px.bar(
        long_df,
        barmode="group",
        color="Trig",
        x="1",
        y="value",
        labels={"1": "$w_k$", "value": "Norm"},
        title=title,
        **kwargs
    )
    figure.update_layout(dict(legend_title=""))
    if not return_fig:
        figure.show()
        return None
    return figure

Code to test a tensor of edited logits

In [None]:
def test_logits(logits, bias_correction=False, original_logits=None, mode="all"):
    # Calculates cross entropy loss of logits representing a batch of all p^2
    # possible inputs
    # Batch dimension is assumed to be first
    if logits.shape[1] == p * p:
        logits = logits.T
    if logits.shape == torch.Size([p * p, p + 1]):
        logits = logits[:, :-1]
    logits = logits.reshape(p * p, p)
    if bias_correction:
        # Applies bias correction - we correct for any missing bias terms,
        # independent of the input, by centering the new logits along the batch
        # dimension, and then adding the average original logits across all inputs
        logits = (
            einops.reduce(original_logits - logits, "batch ... -> ...", "mean") + logits
        )
    if mode == "train":
        return loss_fn(logits[train_indices], labels[train_indices])
    elif mode == "test":
        return loss_fn(logits[test_indices], labels[test_indices])
    elif mode == "all":
        return loss_fn(logits, labels)

Code to run a metric over every checkpoint

In [None]:
# Per-metric results over all checkpoints, keyed by metric name; filled lazily by get_metrics.
metric_cache = {}

In [None]:
def get_metrics(model, metric_cache, metric_fn, name, reset=False):
    """Evaluate ``metric_fn`` on every saved checkpoint and cache the results.

    For each state dict in the global ``model_checkpoints``, the model's hooks
    are reset, the checkpoint is loaded and ``metric_fn(model)`` is recorded
    (tensors are converted to numpy). The results are stored in
    ``metric_cache[name]`` as a tensor whose first dimension indexes
    checkpoints. Nothing is recomputed if a non-empty entry already exists,
    unless ``reset`` is True.

    The model is always left loaded with the final checkpoint, even if
    ``metric_fn`` raises part-way through. In that case ``metric_cache`` is
    not modified.
    """
    if not (reset or (name not in metric_cache) or (len(metric_cache[name]) == 0)):
        return
    results = []
    try:
        for sd in tqdm.tqdm(model_checkpoints):
            model.reset_hooks()
            model.load_state_dict(sd)
            out = metric_fn(model)
            if isinstance(out, torch.Tensor):
                out = utils.to_numpy(out)
            results.append(out)
    finally:
        # Restore the final (trained) weights regardless of whether the sweep succeeded
        model.load_state_dict(model_checkpoints[-1])
    try:
        metric_cache[name] = torch.tensor(results)
    except (TypeError, ValueError, RuntimeError):
        # Fall back to stacking through numpy for results torch.tensor cannot convert directly
        metric_cache[name] = torch.tensor(np.array(results))



## Defining Progress Measures

### Loss Curves

In [None]:
# Epoch boundaries of the training phases (memorisation, circuit formation, cleanup),
# drawn as dashed lines by add_lines. NOTE(review): hard-coded for this training run.
memorization_end_epoch = 1500
circuit_formation_end_epoch = 13300
cleanup_end_epoch = 16600

In [None]:
def add_lines(figure):
    """Draw dashed vertical markers at the three training-phase boundaries.

    The epochs are read from the module-level constants memorization_end_epoch,
    circuit_formation_end_epoch and cleanup_end_epoch at call time. The figure
    is modified in place and also returned, so calls can be chained.
    """
    phase_boundaries = (
        memorization_end_epoch,
        circuit_formation_end_epoch,
        cleanup_end_epoch,
    )
    for epoch in phase_boundaries:
        figure.add_vline(epoch, line_dash="dash", opacity=0.7)
    return figure

In [None]:
# Train/test loss every 100 epochs on a log scale, with the phase boundaries overlaid.
fig = line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Training Curve for Modular Addition", line_labels=['train', 'test'], toggle_x=True, toggle_y=True, return_fig=True)
add_lines(fig)

### Logit Periodicity

In [None]:
# Final-position logits for all inputs, rearranged from [(a b), c] into an [a, b, c] cube.
all_logits = original_logits[:, -1, :]
print(all_logits.shape)
all_logits = einops.rearrange(all_logits, "(a b) c -> a b c", a=p, b=p)
print(all_logits.shape)

In [None]:
# Unit-norm "predicted logit" cubes cos(w (a + b - c)) over all (a, b, c), one per key frequency.
# The broadcast index grids (and their combination a + b - c) do not depend on the frequency.
a = torch.arange(p)[:, None, None]
b = torch.arange(p)[None, :, None]
c = torch.arange(p)[None, None, :]
offsets = a + b - c
coses = {}
for freq in key_freqs:
    print("Freq:", freq)
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * offsets).to(device)
    cube_predicted_logits /= cube_predicted_logits.norm()
    coses[freq] = cube_predicted_logits

In [None]:
# Approximate the logit cube by its projection onto the key-frequency cos cubes,
# reporting each frequency's coefficient and cosine similarity, and the leftover residual.
approximated_logits = torch.zeros_like(all_logits)
for freq in key_freqs:
    print("Freq:", freq)
    # coses[freq] has unit norm, so this inner product is the projection coefficient
    coeff = (all_logits * coses[freq]).sum()
    print("Coeff:", coeff)
    cosine_sim = coeff / all_logits.norm()
    print("Cosine Sim:", cosine_sim)
    approximated_logits += coeff * coses[freq]
residual = all_logits - approximated_logits
print("Residual size:", residual.norm())
print("Residual fraction of norm:", residual.norm()/all_logits.norm())

In [None]:
# Baseline: cosine similarity between the logits and a random Gaussian cube (should be close to 0).
random_logit_cube = torch.randn_like(all_logits)
print((all_logits * random_logit_cube).sum()/random_logit_cube.norm()/all_logits.norm())

In [None]:
# Loss on all inputs using the model's actual logits.
test_logits(all_logits)

In [None]:
# Loss on all inputs when the logits are replaced by their key-frequency approximation.
test_logits(approximated_logits)

#### Look During Training

In [None]:
# One unit-norm cube cos(w (a + b - c)) for every frequency w = 1 .. p//2,
# stacked along a new leading frequency dimension.
# The broadcast index grids are frequency-independent, so they are built once.
a = torch.arange(p)[:, None, None]
b = torch.arange(p)[None, :, None]
c = torch.arange(p)[None, None, :]
offsets = a + b - c
cos_cube = []
for freq in range(1, p//2 + 1):
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * offsets).to(device)
    cube_predicted_logits /= cube_predicted_logits.norm()
    cos_cube.append(cube_predicted_logits)
cos_cube = torch.stack(cos_cube, dim=0)
print(cos_cube.shape)

In [None]:
def get_cos_coeffs(model):
    """Project the model's final-position logits onto each cube in ``cos_cube``.

    Returns a 1D tensor with one projection coefficient per frequency (the
    leading dimension of the global ``cos_cube``).
    """
    # Inference only: without no_grad, every checkpoint evaluated by get_metrics
    # would build an autograd graph over the whole dataset.
    with torch.no_grad():
        logits = model(dataset)[:, -1]
        logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
        # Contract over (a, b, c) without materialising an [n_freqs, p, p, p] product
        vals = torch.einsum("fabc,abc->f", cos_cube, logits)
    return vals


get_metrics(model, metric_cache, get_cos_coeffs, "cos_coeffs")
print(metric_cache["cos_coeffs"].shape)

In [None]:
fig = line(metric_cache["cos_coeffs"].T, line_labels=[f"Freq {i}" for i in range(1, p//2+1)], title="Coefficients with Predicted Logits", xaxis="Epoch", x=checkpoint_epochs, yaxis="Coefficient", return_fig=True)
add_lines(fig)

In [None]:
def get_cos_sim(model):
    """Cosine similarity between the model's logit cube and each frequency cube in cos_cube.

    Every cos_cube entry has unit norm, so the inner product divided by the logit
    cube's norm is a cosine similarity. Returns a 1-D tensor with one value per
    frequency.
    """
    logit_cube = einops.rearrange(model(dataset)[:, -1], "(a b) c -> a b c", a=p, b=p)
    inner_products = (cos_cube * logit_cube.unsqueeze(0)).sum(dim=(1, 2, 3))
    return inner_products / logit_cube.norm()

# Per-frequency cosine similarity of the logits, stored in metric_cache["cos_sim"]
# (presumably one row per checkpoint, given x=checkpoint_epochs in the plot below).
get_metrics(model, metric_cache, get_cos_sim, "cos_sim") # You may need a big GPU. If you don't have one and can't work around this, raise an issue for help!
print(metric_cache["cos_sim"].shape)

# One line per frequency 1..p//2 over training epochs.
fig = line(metric_cache["cos_sim"].T, line_labels=[f"Freq {i}" for i in range(1, p//2+1)], title="Cosine Sim with Predicted Logits", xaxis="Epoch", x=checkpoint_epochs, yaxis="Cosine Sim", return_fig=True)
add_lines(fig)

In [None]:
def get_residual_cos_sim(model):
    """Fraction of the logit cube's norm left unexplained by the cos_cube frequencies.

    Despite the name, this is not a cosine similarity. It projects the logit cube
    onto every frequency cube in cos_cube and subtracts that reconstruction. It
    then returns ||residual|| / ||logits|| as a 0-d tensor.
    """
    logit_cube = einops.rearrange(model(dataset)[:, -1], "(a b) c -> a b c", a=p, b=p)
    coeffs = (cos_cube * logit_cube.unsqueeze(0)).sum(dim=(1, 2, 3))
    # Coefficient-weighted sum of the unit-norm frequency cubes; this is the
    # projection onto their span provided the cubes are mutually orthogonal
    # (presumably true for distinct frequencies 1..p//2 — confirm for even p).
    reconstruction = (coeffs.view(-1, 1, 1, 1) * cos_cube).sum(dim=0)
    return (logit_cube - reconstruction).norm() / logit_cube.norm()

# Relative norm of the logits not explained by the frequency cubes, stored in
# metric_cache["residual_cos_sim"].
get_metrics(model, metric_cache, get_residual_cos_sim, "residual_cos_sim")
print(metric_cache["residual_cos_sim"].shape)

# All per-frequency cosine-sim curves plus the residual fraction on one figure.
fig = line([metric_cache["cos_sim"][:, i] for i in range(p//2)]+[metric_cache["residual_cos_sim"]], line_labels=[f"Freq {i}" for i in range(1, p//2+1)]+["residual"], title="Cosine Sim with Predicted Logits + Residual", xaxis="Epoch", x=checkpoint_epochs, yaxis="Cosine Sim", return_fig=True)
add_lines(fig)

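## Restricted Loss

The *restricted loss* is the test loss when each MLP neuron's final-position activation is replaced by two parts: its mean over the dataset, plus its components along the key-frequency directions $\cos(\omega_k(a+b))$ and $\sin(\omega_k(a+b))$, with $\omega_k = 2\pi k/p$ for each $k$ in `key_freqs`.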

In [None]:
# NOTE(review): neuron_acts is reassigned a few cells further down; this cell (and
# the next) rely on a value computed earlier in the notebook.
neuron_acts.shape

In [None]:
# Regroup neuron activations into a (p, p, neur) grid indexed by the inputs (a, b).
# clone() so the in-place centering below cannot write through to neuron_acts
# (rearrange may return a view).
neuron_acts_square = einops.rearrange(neuron_acts, "(a b) neur -> a b neur", a=p, b=p).clone()
# Center it: subtract each neuron's mean over all (a, b) pairs
neuron_acts_square -= einops.reduce(neuron_acts_square, "a b neur -> 1 1 neur", "mean")
# 2-D Fourier transform over (a, b) for every neuron, using fourier_basis (defined outside this section)
neuron_acts_square_fourier = einsum("a b neur, fa a, fb b -> fa fb neur", neuron_acts_square, fourier_basis, fourier_basis)
# Heatmap of the norm across neurons for each (fa, fb) pair of Fourier components
imshow(neuron_acts_square_fourier.norm(dim=-1), xaxis="Fourier Component b", yaxis="Fourier Component a", title="Norms of neuron activations by Fourier Component", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
# Run the model on the full dataset, keeping the activation cache.
original_logits, cache = model.run_with_cache(dataset)
print(original_logits.numel())
# Layer-0 MLP post-activations at the final sequence position
neuron_acts = cache["post", 0, "mlp"][:, -1, :]

In [None]:
# Restricted logits: keep only each neuron's mean activation plus its components along
# the unit-norm key-frequency directions cos/sin(freq * 2*pi/p * (a + b)). Map the
# result to logits through W_logit and measure the loss on the test split.
approx_neuron_acts = torch.zeros_like(neuron_acts)
approx_neuron_acts += neuron_acts.mean(dim=0)
a = torch.arange(p)[:, None]
b = torch.arange(p)[None, :]
for freq in key_freqs:
    cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)
    cos_apb_vec /= cos_apb_vec.norm()
    cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec
    sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)
    sin_apb_vec /= sin_apb_vec.norm()
    sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
restricted_logits = approx_neuron_acts @ W_logit
# Add bias term, as get_restricted_loss does: shift the restricted logits so their mean
# over inputs equals the mean of the model's real final-position logits. This keeps the
# number printed here consistent with the restricted-loss curve plotted later.
restricted_logits += original_logits[:, -1, :].mean(dim=0, keepdim=True) - restricted_logits.mean(dim=0, keepdim=True)
print(loss_fn(restricted_logits[test_indices], test_labels))

In [None]:
# all_logits is the (p, p, p) logit cube used above (it is updated in place with the
# (p, p, p) cubes in `coses`). loss_fn is otherwise always given 2-D (batch, vocab)
# logits, so flatten the (a, b) axes back into one batch axis first. Passing the
# cube directly was presumably the cause of the failure previously noted here.
# Assumes labels follow the same (a b) row order that the "(a b) c -> a b c"
# rearranges elsewhere rely on.
print(loss_fn(einops.rearrange(all_logits, "a b c -> (a b) c"), labels))

### Look During Training

In [None]:
def get_restricted_loss(model):
    """Test loss when the MLP neurons are restricted to the key frequencies.

    Each neuron's final-position activation is replaced by its mean over the dataset
    plus its projections onto the unit-norm cos/sin(freq * 2*pi/p * (a + b))
    directions, for every freq in key_freqs. The result is mapped through W_out and
    W_U, then shifted so its mean over inputs matches the real final-position
    logits (the "bias term"). It is scored with loss_fn on the test split.
    """
    logits, cache = model.run_with_cache(dataset)
    final_logits = logits[:, -1, :]
    acts = cache["post", 0, "mlp"][:, -1, :]

    # Start from the per-neuron mean, broadcast over every input pair.
    approx_acts = torch.zeros_like(acts) + acts.mean(dim=0)
    rows = torch.arange(p)[:, None]
    cols = torch.arange(p)[None, :]
    for freq in key_freqs:
        phase = freq * 2 * torch.pi / p * (rows + cols)
        # Add the cosine component first, then the sine component.
        for wave in (torch.cos, torch.sin):
            direction = wave(phase).to(device)
            direction = direction / direction.norm()
            direction = einops.rearrange(direction, "a b -> (a b) 1")
            approx_acts += (acts * direction).sum(dim=0) * direction

    restricted_logits = approx_acts @ model.blocks[0].mlp.W_out @ model.unembed.W_U
    # Add bias term: align the mean over inputs with the true final-position logits.
    mean_shift = final_logits.mean(dim=0, keepdim=True) - restricted_logits.mean(dim=0, keepdim=True)
    restricted_logits = restricted_logits + mean_shift
    return loss_fn(restricted_logits[test_indices], test_labels)
# Sanity-check the metric on the currently loaded model before sweeping checkpoints.
get_restricted_loss(model)

In [None]:
# Compute the restricted loss via get_metrics and store it in
# metric_cache["restricted_loss"] (reset=True is passed through to get_metrics).
get_metrics(model, metric_cache, get_restricted_loss, "restricted_loss", reset=True)
print(metric_cache["restricted_loss"].shape)

In [None]:
# Train / test / restricted loss on a log scale, sampled every 100 epochs.
# NOTE(review): assumes metric_cache["restricted_loss"] has one entry per 100 epochs,
# so that it lines up with train_losses[::100] — confirm against the checkpoint schedule.
fig = line([train_losses[::100], test_losses[::100], metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Restricted Loss Curve", line_labels=['train', 'test', "restricted_loss"], toggle_x=True, toggle_y=True, return_fig=True)
add_lines(fig)

In [None]:
# Plot test_loss / restricted_loss (sampled every 100 epochs) over training.
# The title and y-axis label name the ratio in the order it is actually computed.
fig = line([torch.tensor(test_losses[::100])/metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Test Loss / Restricted Loss", log_y=True, title="Test Loss to Restricted Loss Ratio", toggle_x=True, toggle_y=True, return_fig=True)
# WARNING: bugged when cancelling training half way through
# (presumably test_losses[::100] and metric_cache["restricted_loss"] then differ in length).
add_lines(fig)

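## Excluded Loss

The *excluded loss* is the training loss after *removing* exactly those key-frequency $\cos/\sin(\omega_k(a+b))$ components from the neuron activations, while keeping everything else.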

In [None]:
# Excluded logits: subtract each neuron's components along the unit-norm key-frequency
# directions cos/sin(freq * 2*pi/p * (a + b)), then measure the loss on the train split.
# The per-neuron mean is not added to approx_neuron_acts, so it stays in the
# excluded activations.
approx_neuron_acts = torch.zeros_like(neuron_acts)
a = torch.arange(p)[:, None]
b = torch.arange(p)[None, :]
for freq in key_freqs:
    cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)
    cos_apb_vec /= cos_apb_vec.norm()
    cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec
    sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)
    sin_apb_vec /= sin_apb_vec.norm()
    sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
excluded_neuron_acts = neuron_acts - approx_neuron_acts
# As in get_excluded_loss, add the rest of the residual stream (resid_mid at the final
# position) before unembedding, so the printed value matches the plotted metric.
residual_stream_final = excluded_neuron_acts @ model.blocks[0].mlp.W_out + cache["resid_mid", 0][:, -1, :]
excluded_logits = residual_stream_final @ model.unembed.W_U
print(loss_fn(excluded_logits[train_indices], train_labels))

In [None]:
def get_excluded_loss(model):
    """Training loss after removing the key-frequency components of the MLP neurons.

    Each neuron's final-position activation has its projections onto the unit-norm
    cos/sin(freq * 2*pi/p * (a + b)) directions subtracted, for every freq in
    key_freqs. The per-neuron mean is kept. The remaining activations go through
    W_out and are added to resid_mid at the final position. The result is
    unembedded with W_U and scored with loss_fn on the train split.
    """
    # Only the activation cache is needed; the model's own logits are not used.
    _, cache = model.run_with_cache(dataset)
    acts = cache["post", 0, "mlp"][:, -1, :]

    key_freq_part = torch.zeros_like(acts)
    rows = torch.arange(p)[:, None]
    cols = torch.arange(p)[None, :]
    for freq in key_freqs:
        phase = freq * 2 * torch.pi / p * (rows + cols)
        # Accumulate the cosine component first, then the sine component.
        for wave in (torch.cos, torch.sin):
            direction = wave(phase).to(device)
            direction = direction / direction.norm()
            direction = einops.rearrange(direction, "a b -> (a b) 1")
            key_freq_part += (acts * direction).sum(dim=0) * direction

    excluded_acts = acts - key_freq_part
    residual_stream_final = excluded_acts @ model.blocks[0].mlp.W_out + cache["resid_mid", 0][:, -1, :]
    excluded_logits = residual_stream_final @ model.unembed.W_U
    return loss_fn(excluded_logits[train_indices], train_labels)
# Sanity-check the metric on the currently loaded model before sweeping checkpoints.
get_excluded_loss(model)

In [None]:
# Compute the excluded loss via get_metrics and store it in
# metric_cache["excluded_loss"] (reset=True is passed through to get_metrics).
get_metrics(model, metric_cache, get_excluded_loss, "excluded_loss", reset=True)
print(metric_cache["excluded_loss"].shape)

In [None]:
# Train / test / excluded / restricted loss on one log-scale figure, sampled every 100 epochs.
# NOTE(review): assumes both metric_cache series have one entry per 100 epochs.
fig = line([train_losses[::100], test_losses[::100], metric_cache["excluded_loss"], metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Excluded and Restricted Loss Curve", line_labels=['train', 'test', "excluded_loss", "restricted_loss"], toggle_x=True, toggle_y=True, return_fig=True)

add_lines(fig)