## Setup

In [38]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

import os
import copy

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
import umap

import einops
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd, itertools
from tqdm.auto import tqdm

from transformer_lens import HookedTransformer, HookedTransformerConfig, utils

# Configure plotly to use static rendering if widgets fail
import plotly.io as pio
pio.renderers.default = "notebook"


In [39]:
float_formatter = "{:.5f}".format
np.set_printoptions(formatter={'float_kind':float_formatter})

## Model

In [None]:
# ---------- constants ----------
LIST_LEN = 2 # [d1, d2]
SEQ_LEN = LIST_LEN * 2 + 1 # [d1, d2, SEP, o1, o2]

N_DIGITS = 100
DIGITS = list(range(N_DIGITS)) # 100 digits from 0 to 99
PAD = N_DIGITS # special padding token
SEP = N_DIGITS + 1 # special seperator token for the model to think about the input (+1 to avoid confusion with the last digit)
VOCAB = len(DIGITS) + 2  # + the special tokens

# For backward compatibility with older versions
USE_PAD = False # whether to use the PAD token in the input sequences (or just SEP)
if not USE_PAD:
    VOCAB -= 1  # -1 for the PAD token

D_MODEL = 8
N_HEAD = 1 # 1
N_LAYER = 3 # 2
USE_LN = False # use layer norm in model
USE_BIAS = False # use bias in model
FREEZE_WV = True # no value matrix in attn 
FREEZE_WO = True # no output matrix in attn (i.e. attn head can only copy inputs to outputs)
WEIGHT_DECAY = 0.01 # default 0.01

TRAIN_SPLIT = 0.8 # 80% train, 20% test

# model name for loading
MODEL_NAME = 'v2_3layer_100dig_16d'
MODEL_PATH = "models/" + MODEL_NAME + ".pt"

DEV = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
device = DEV
torch.manual_seed(0)

<torch._C.Generator at 0x739c9ef0c6d0>

In [41]:
# ---------- mask ----------
# attention mask for [d1, d2, SEP, o1, o2] looks like this (query rows are horizontal, key columns are vertical):
# -    d1    d2    SEP    o1    o2   (keys)
# d1  -inf  -inf   -inf  -inf  -inf
# d2   0    -inf   -inf  -inf  -inf
# SEP  0      0    -inf  -inf  -inf
# o1  -inf  -inf    0    -inf   -inf
# o2  -inf  -inf    0      0    -inf
# (queries)

mask_bias = torch.triu(torch.ones(SEQ_LEN, SEQ_LEN) * float("-inf")) # upper triangular bias mask (lead_diag & above = -inf, rest = 0)
mask_bias[0, 0] = 0. # don't want a full row of -inf! otherwise we get nan erros & training breaks
mask_bias[LIST_LEN+1:, :LIST_LEN] = float("-inf") # stop output tokens from attending to input tokens
mask_bias = mask_bias.unsqueeze(0).unsqueeze(0) # (1,1,T,T) broadcastable across batch and heads

print(mask_bias.cpu()[0][0])

tensor([[0., -inf, -inf, -inf, -inf],
        [0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [-inf, -inf, 0., -inf, -inf],
        [-inf, -inf, 0., 0., -inf]])


In [42]:
# ---------- data ----------
# Create all possible combinations of digits
all_data = list(itertools.product(DIGITS, repeat=LIST_LEN))
n_data = len(all_data)
all_data = torch.tensor(all_data, dtype=torch.int64)

# Create sequences of the form [d1, d2, SEP, d1, d2]
all_targets = torch.full((n_data, SEQ_LEN), SEP)
all_targets[:, :LIST_LEN] = all_data
all_targets[:, LIST_LEN+1:] = all_data

# Create input sequences of the form [d1, d2, SEP, PAD, PAD]
all_inputs = all_targets.clone()
all_inputs[:, LIST_LEN+1:] = PAD if USE_PAD else SEP # for backward compat

# Shuffle the dataset (inputs and targets together)
perm = torch.randperm(n_data)
all_inputs = all_inputs[perm]
all_targets = all_targets[perm]

train_ds = TensorDataset(all_inputs[:int(TRAIN_SPLIT*n_data)], all_targets[:int(TRAIN_SPLIT*n_data)])  # 80% for training
val_ds = TensorDataset(all_inputs[int(TRAIN_SPLIT*n_data):], all_targets[int(TRAIN_SPLIT*n_data):])  # 20% for validation
train_batch_size = min(128, len(train_ds))  # Use a batch size of 128 or less if dataset is smaller
val_batch_size = min(256, len(val_ds))  # Use a batch size of 256 or less if dataset is smaller
train_dl = DataLoader(train_ds, train_batch_size, shuffle=True, drop_last=True)
val_dl = DataLoader(val_ds, val_batch_size, drop_last=False)

print("Input:", train_ds[0][0])  # Example input: [d1, d2, SEP, SEP, SEP]
print("Target:", train_ds[0][1]) # Example target: [d1, d2, SEP, d1, d2]
len(train_ds), len(val_ds)  # Should be 80% for train and 20% for validation

Input: tensor([ 60,  44, 101, 101, 101])
Target: tensor([ 60,  44, 101,  60,  44])


(8000, 2000)

In [43]:
# ---------- config helper ----------
def attach_custom_mask(model):
    def _mask(scores, hook=None):
        # scores: (batch, heads, Q, K)
        return scores + mask_bias.to(scores.device)
    
    # register the same mask hook on every layer
    for block in model.blocks:
        block.attn.hook_attn_scores.add_perma_hook(_mask, dir="fwd")


def strip_bias(m):
    for mod in m.modules():
        if hasattr(mod, "bias") and mod.bias is not None:
            mod.bias.requires_grad_(False)
            torch.nn.init.zeros_(mod.bias)
            print(mod)

    # remove biases from attention layers
    attn_biases = ['b_Q', 'b_K', 'b_V', 'b_O']
    for block in m.blocks:
        for b in attn_biases:
            mod = getattr(block.attn, b, None)
            if mod is not None:
                mod.requires_grad_(False)
                torch.nn.init.zeros_(mod)

    # remove unembed bias
    if hasattr(m, "unembed") and m.b_U is not None:
        m.unembed.b_U.requires_grad_(False)
        torch.nn.init.zeros_(m.unembed.b_U)

def set_WV_identity_and_freeze(model, d_model):
    with torch.no_grad():
        # Create a stack of identity-like matrices for W_V
        # Each matrix is of shape (d_model, d_head)
        # We take the first d_head columns of the d_model x d_model identity matrix
        identity_slice = torch.eye(d_model, model.cfg.d_head)
        # Repeat for each head
        W_V_identity = identity_slice.unsqueeze(0).repeat(model.cfg.n_heads, 1, 1)
        
        for block in model.blocks:
            block.attn.W_V.copy_(W_V_identity)
            block.attn.W_V.requires_grad = False

def set_WO_identity_and_freeze(model, d_model):
    with torch.no_grad():
        # Create a stack of identity-like matrices for W_O
        # Each matrix is of shape (d_head, d_model)
        # We take the first d_head rows of the d_model x d_model identity matrix
        identity_slice = torch.eye(model.cfg.d_head, d_model)
        # Repeat for each head
        W_O_identity = identity_slice.unsqueeze(0).repeat(model.cfg.n_heads, 1, 1)

        for block in model.blocks:
            block.attn.W_O.copy_(W_O_identity)
            block.attn.W_O.requires_grad = False


def make_model(n_layers=N_LAYER, n_heads=N_HEAD, d_model=D_MODEL, ln=USE_LN, use_bias=USE_BIAS, freeze_wv=FREEZE_WV, freeze_wo=FREEZE_WO):
    cfg = HookedTransformerConfig(
        n_layers = n_layers,
        n_heads = n_heads,
        d_model = d_model,
        d_head = d_model//n_heads,
        n_ctx=SEQ_LEN,
        d_vocab=VOCAB,
        attn_only=True, # no MLP!
        normalization_type=("LN" if ln else None),
    )
    model = HookedTransformer(cfg).to(DEV)
    if freeze_wv:
        set_WV_identity_and_freeze(model, d_model)
    if freeze_wo:
        set_WO_identity_and_freeze(model, d_model)
    if not use_bias:
        strip_bias(model)
    
    attach_custom_mask(model)
    return model

In [44]:
# ----- Model saving / loading helpers ------
def save_model(model, path = MODEL_PATH):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def load_model(path = MODEL_PATH, device = DEV):
    print("Loading model from", path)
    model = make_model()
    model.load_state_dict(
        torch.load(path, map_location=device)
    )  # map weights to target device
    model.eval()
    return model

In [45]:
# ---------- utilities ----------
def accuracy(m):
    m.eval()
    hits = tots = 0
    with torch.no_grad():
        for inputs, targets in val_dl:
            logits = m(inputs.to(DEV))[:, LIST_LEN+1:]  # (batch, 2, vocab)
            preds = logits.argmax(-1)
            hits += (preds == targets[:, LIST_LEN+1:].to(DEV)).sum().item()
            tots += preds.numel()
    return hits / tots


In [46]:
# Check train set
train_ds[:5]

(tensor([[ 60,  44, 101, 101, 101],
         [ 28,  90, 101, 101, 101],
         [ 93,  99, 101, 101, 101],
         [ 19,  17, 101, 101, 101],
         [ 49,  19, 101, 101, 101]]),
 tensor([[ 60,  44, 101,  60,  44],
         [ 28,  90, 101,  28,  90],
         [ 93,  99, 101,  93,  99],
         [ 19,  17, 101,  19,  17],
         [ 49,  19, 101,  49,  19]]))



Bias not needed

In [47]:
# LOAD existing model
if os.path.exists(MODEL_PATH):
    model = load_model(MODEL_PATH, device=DEV)
else:
    raise FileNotFoundError(f"Model file {MODEL_PATH} does not exist. Please train the model first.")

# from torchinfo import summary
# summary(model) 

Loading model from models/NoPAD_3layer_100dig_8d.pt
Moving model to device:  cuda


In [None]:
# --- Model Parameters Overview ---

print("--- Overview of TRAINABLE Model Parameters ---")
total_params = 0
trainable_params = 0

# Use a formatted string for better alignment
print(f"{'Parameter Name':<40} | {'Shape':<20} | {'Trainable':<10}")
print("-" * 80)

for name, param in model.named_parameters():
    shape_str = str(tuple(param.shape))
    is_trainable = "Yes" if param.requires_grad else "No"
    total_params += param.numel()

    if not param.requires_grad:
        continue
    # Print only trainable parameters
    print(f"{name:<40} | {shape_str:<20} | {is_trainable:<10}")
    trainable_params += param.numel()

print("-" * 80)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")
print("-" * 80)

--- Overview of Model Parameters ---
Parameter Name                           | Shape                | Trainable 
--------------------------------------------------------------------------------
embed.W_E                                | (101, 8)             | Yes       
pos_embed.W_pos                          | (5, 8)               | Yes       
blocks.0.attn.W_Q                        | (1, 8, 8)            | Yes       
blocks.0.attn.W_K                        | (1, 8, 8)            | Yes       
blocks.1.attn.W_Q                        | (1, 8, 8)            | Yes       
blocks.1.attn.W_K                        | (1, 8, 8)            | Yes       
blocks.2.attn.W_Q                        | (1, 8, 8)            | Yes       
blocks.2.attn.W_K                        | (1, 8, 8)            | Yes       
unembed.W_U                              | (8, 101)             | Yes       
--------------------------------------------------------------------------------
Total parameters: 2621
Trainabl

In [None]:
# # Helper variables
# # W_O_l0, W_O_l1 = model.W_O # frozen as I
# W_K_l0, W_K_l1 = model.W_K
# W_Q_l0, W_Q_l1 = model.W_Q
# # W_V_l0, W_V_l1 = model.W_V # frozen as I
# W_pos = model.W_pos
# W_E = model.W_E[:-1] # exclude SEP embedding (last token in W_E vocab)
# final_pos_resid_initial = model.W_E[-1] + W_pos[LIST_LEN+1]  # W_pos[2] is the SEP token position, which is at index LIST_LEN+1
# W_U = model.W_U[:, :-1] # exclude SEP

# # print('W_O  ', tuple(W_O_l0.shape))
# print('W_K  ', tuple(W_K_l0.shape))
# print('W_Q  ', tuple(W_Q_l0.shape))
# # print('W_V  ', tuple(W_V_l0.shape))
# print('W_pos', tuple(W_pos.shape))
# print('W_E  ', tuple(W_E.shape))
# print('W_U  ', tuple(W_U.shape)) # (d_model, vocab-1) - excludes SEP token

### Model attention

We confirm below that the model does not leak attention onto the first two tokens, which are the inputs to the task. The model should only attend to the first two tokens when predicting the third token, and not attend to them at all when predicting the fourth and fifth tokens.

In [None]:
# --- Using Plotly for visualization ---

def check_attention(m, dataloader, eps=1e-3):
    for inputs, _ in dataloader:
        with torch.no_grad():
            _, cache = m.run_with_cache(inputs.to(DEV))
        for l in range(m.cfg.n_layers):
            pat = cache["pattern", l][:, 0]  # (batch, Q, K)
            leak = pat[:, LIST_LEN+1:, :LIST_LEN].sum(dim=-1)  # mass on forbidden keys
            if (leak > eps).any():
                raise ValueError(f"❌ Layer {l}: output tokens attend to x₁/x₂ by >{eps:.0e}")
    print("✅ no attention leakage onto x₁/x₂")


sample = val_ds[0][0] # Example input sequence
print(f"Sample sequence: {sample.cpu().numpy()}")  # Print the sample sequence for reference
_, cache = model.run_with_cache(sample.unsqueeze(0).to(DEV))

# --- Create Plotly visualization ---
token_labels = [f'd{i+1}' for i in range(LIST_LEN)] + ['SEP'] + [f'o{i+1}' for i in range(LIST_LEN)]
subplot_titles = [f"Layer {l} Attention Pattern" for l in range(model.cfg.n_layers)]

fig = make_subplots(
    rows=1, 
    cols=model.cfg.n_layers, 
    subplot_titles=subplot_titles,
    horizontal_spacing=0.08 # Add spacing between plots
)

for l in range(model.cfg.n_layers):
    pat = cache["pattern", l][0, 0].cpu().numpy()
    
    fig.add_trace(
        go.Heatmap(
            z=pat,
            x=token_labels,
            y=token_labels,
            colorscale="Viridis",
            zmin=0,
            zmax=1,
            showscale=(l == model.cfg.n_layers - 1) # Show colorbar only for the last plot
        ),
        row=1, col=l+1
    )

fig.update_layout(
    title_text="Attention Patterns for a Sample Sequence",
    width=1200,
    height=450,
    template="plotly_white"
)

# Apply settings to all axes
fig.update_xaxes(title_text="Key Position")
fig.update_yaxes(title_text="Query Position", autorange='reversed')

fig.show()

check_attention(model, val_dl)

Sample sequence: [ 80  52 100 100 100]


✅ no attention leakage onto x₁/x₂


In [None]:
# --- Mean Attention Patterns ---

all_pats = [[] for _ in range(model.cfg.n_layers)]
for inputs, _ in val_dl:
    with torch.no_grad():
        _, cache = model.run_with_cache(inputs.to(DEV))
    for l in range(model.cfg.n_layers):
        pat = cache["pattern", l][:, 0]  # (batch, Q, K)
        all_pats[l].append(pat)
all_pats = [torch.cat(pats, dim=0) for pats in all_pats]

for l, pats in enumerate(all_pats):
    identical = torch.allclose(pats, pats[0].expand_as(pats))
    print(f"Layer {l}: all attention patterns identical? {'✅' if identical else '❌'}")

with torch.no_grad():
    avg_pats = [
        torch.zeros(SEQ_LEN, SEQ_LEN, device=DEV) for _ in range(model.cfg.n_layers)
    ]
    n = 0
    for inputs, _ in val_dl:
        _, cache = model.run_with_cache(inputs.to(DEV))
        for l in range(model.cfg.n_layers):
            avg_pats[l] += cache["pattern", l][:, 0].sum(0)
        n += inputs.shape[0]
    avg_pats = [p / n for p in avg_pats]

# --- Visualize Average Attention Patterns ---
token_labels = [f'd{i+1}' for i in range(LIST_LEN)] + ['SEP'] + [f'o{i+1}' for i in range(LIST_LEN)]
subplot_titles = [f"Layer {l} Average Attention" for l in range(model.cfg.n_layers)]

fig = make_subplots(
    rows=1, 
    cols=model.cfg.n_layers, 
    subplot_titles=subplot_titles,
    horizontal_spacing=0.08
)

for l in range(model.cfg.n_layers):
    avg_pat_np = avg_pats[l].cpu().numpy()
    
    fig.add_trace(
        go.Heatmap(
            z=avg_pat_np,
            x=token_labels,
            y=token_labels,
            colorscale="Viridis",
            zmin=0,
            zmax=1,
            showscale=(l == model.cfg.n_layers - 1) # Show colorbar only for the last plot
        ),
        row=1, col=l+1
    )

fig.update_layout(
    title_text="Average Attention Patterns Across Validation Set",
    width=1200,
    height=450,
    template="plotly_white"
)
fig.update_xaxes(title_text="Key Position")
fig.update_yaxes(title_text="Query Position", autorange='reversed')
fig.show()


# Create a deep copy of the model to avoid modifying the original
model_with_avg_attn = copy.deepcopy(model)

def mk_hook(avg):
    logits = (avg + 1e-12).log()  # log-prob so softmax≈avg, ε avoids -∞

    def f(scores, hook):
        return logits.unsqueeze(0).unsqueeze(0).expand_as(scores)

    return f

for l in range(model_with_avg_attn.cfg.n_layers):
    model_with_avg_attn.blocks[l].attn.hook_attn_scores.add_hook(
        mk_hook(avg_pats[l]), dir="fwd"
    )

print("Accuracy with avg-attn:", accuracy(model_with_avg_attn))

Layer 0: all attention patterns identical? ❌
Layer 1: all attention patterns identical? ❌
Layer 2: all attention patterns identical? ❌


Accuracy with avg-attn: 0.9915


The attention patterns are not the same across inputs. However, we can replace the attention scores with their average and still get almost perfect performance.

## Interp

In [None]:
# --- Setup ---
head_index_to_ablate = 0 # fixed

# Define the loss function
loss_fn = torch.nn.CrossEntropyLoss()

# Check loss on validation set
val_inputs = val_ds.tensors[0].to(DEV)
val_targets = val_ds.tensors[1].to(DEV)
sample_idx = 0  # Use the xth sample in the validation set for comparing predictions
sample_list = val_inputs[sample_idx].cpu().numpy()

# --- Calculate Original Loss on last 2 digits ---
with torch.no_grad():
    original_logits, cache = model.run_with_cache(val_inputs, return_type="logits")
    output_logits = original_logits[:, LIST_LEN+1:] # Slice to get logits for the last two positions
    output_targets = val_targets[:, LIST_LEN+1:] # Slice to get the target tokens
    
    original_loss = loss_fn(output_logits.reshape(-1, VOCAB), output_targets.reshape(-1)) # Calculate the loss
    # Calculate accuracy
    original_predictions = original_logits.argmax(dim=-1) 
    original_output_predictions = original_predictions[:, LIST_LEN+1:]
    original_accuracy = (original_output_predictions == output_targets).float().mean()

print(f"Original loss: {original_loss.item()}")
print(f"Original accuracy: {original_accuracy.item()}")
print(f"Sample sequence {sample_idx}: {sample_list}")

Original loss: 0.007969769649207592
Original accuracy: 0.9977500438690186
Sample sequence 0: [ 80  52 100 100 100]


### Positional encoding

In [None]:
# --- Positional Encoding Ablation ---

print("--- Positional Encoding Ablation Results ---")
print(f"Original Loss: {original_loss.item():.4f}, Original Accuracy: {original_accuracy.item():.4f}")
print(f"Original prediction (sample {sample_idx}): {original_predictions[sample_idx].cpu().numpy()}")
print("-" * 60)

# Hook to subtract positional encodings from the residual stream
def ablate_pos_encoding_hook(resid, hook):
    # resid shape: [batch, seq_pos, d_model]
    # W_pos shape: [seq_pos, d_model]
    # We subtract the positional embeddings from the residual stream.
    # W_pos is automatically broadcast across the batch dimension.
    result = resid - model.W_pos
    # Restore some positions to their original values
    idx = 2
    result[:, idx] = resid[:, idx]
    return result

# Define the ablation experiments: a description and the hook point
ablation_points = [
    ("Before Layer 0", "blocks.0.hook_resid_pre"),
    ("After Layer 0", "blocks.0.hook_resid_post"), # i.e. pre layer 1
    ("After Layer 1", "blocks.1.hook_resid_post"),
    ("After Layer 2", "blocks.2.hook_resid_post"), # Before Unembed
]

# --- Perform Ablation for Each Case ---
for description, hook_name in ablation_points:
    with torch.no_grad():
        # Run the model with the ablation hook
        ablated_logits = model.run_with_hooks(
            val_inputs,
            return_type="logits",
            fwd_hooks=[(hook_name, ablate_pos_encoding_hook)]
        )
    
    # Calculate ablated loss on the output tokens
    output_logits_ablated = ablated_logits[:, LIST_LEN+1:]
    ablated_loss = loss_fn(output_logits_ablated.reshape(-1, VOCAB), output_targets.reshape(-1))
    
    # Calculate ablated accuracy on the output tokens
    ablated_predictions = ablated_logits.argmax(dim=-1)
    ablated_output_predictions = ablated_predictions[:, LIST_LEN+1:]
    ablated_accuracy = (ablated_output_predictions == output_targets).float().mean()
    
    print(f"Ablating Positional Encodings {description}:")
    print(f"  Ablated Loss: {ablated_loss.item():.4f}")
    print(f"  Ablated Accuracy: {ablated_accuracy.item():.4f}")
    print(f"  Ablated prediction (sample {sample_idx}):  {ablated_predictions[sample_idx].cpu().numpy()}")
    print("-" * 60)

--- Positional Encoding Ablation Results ---
Original Loss: 0.0080, Original Accuracy: 0.9978
Original prediction (sample 0): [80 80 80 80 52]
------------------------------------------------------------
Ablating Positional Encodings Before Layer 0:
  Ablated Loss: 14.6122
  Ablated Accuracy: 0.4140
  Ablated prediction (sample 0):  [80 80 80 80 80]
------------------------------------------------------------
Ablating Positional Encodings After Layer 0:
  Ablated Loss: 12.6352
  Ablated Accuracy: 0.5775
  Ablated prediction (sample 0):  [80 80 80 80 80]
------------------------------------------------------------
Ablating Positional Encodings After Layer 1:
  Ablated Loss: 13.2708
  Ablated Accuracy: 0.6825
  Ablated prediction (sample 0):  [80 80 80 80 81]
------------------------------------------------------------
Ablating Positional Encodings After Layer 2:
  Ablated Loss: 0.7545
  Ablated Accuracy: 0.8378
  Ablated prediction (sample 0):  [80 80 80 80 49]
-------------------------

### Residual stream

In [None]:
def ablate_skip_connection(layer_to_ablate):
    """
    Ablates the skip connection over the attention block for a specific layer.
    """
    
    # This dictionary will store the input residual stream for the layer
    captured_resid_pre = {}

    # Hook to capture the input to the attention block
    def capture_resid_pre_hook(resid, hook):
        captured_resid_pre['value'] = resid
        return resid

    # Hook to ablate the skip connection by subtracting the captured input
    def ablate_skip_hook(resid, hook):
        # resid here is resid_pre + attn_out
        # We subtract resid_pre to leave only attn_out
        result = resid - captured_resid_pre['value']
        # result[:, 3:] = resid[:, 3:]
        return result

    resid_pre_hook_name = f"blocks.{layer_to_ablate}.hook_resid_pre"
    hook_attn_out_name = f"blocks.{layer_to_ablate}.hook_attn_out"
    resid_post_hook_name = f"blocks.{layer_to_ablate}.hook_resid_post"

    with torch.no_grad():
        ablated_logits = model.run_with_hooks(
            val_inputs,
            return_type="logits",
            fwd_hooks=[
                (resid_pre_hook_name, capture_resid_pre_hook),
                (resid_post_hook_name, ablate_skip_hook)
            ]
        )
    return ablated_logits

print(f"--- Attention Skip Connection Ablation Results ---")
print(f"Sample sequence: {val_inputs[sample_idx].cpu().numpy()}") # last sample in validation set
print(f"Original Loss: {original_loss.item():.4f}")
print(f"Original Accuracy: {original_accuracy.item():.4f}")
print(f"Original predictions: {original_predictions[sample_idx].cpu().numpy()}")
print("-" * 50)

# --- Perform Ablation for Each Layer ---
for l in range(N_LAYER):
    ablated_logits = ablate_skip_connection(l)
    output_logits_ablated = ablated_logits[:, LIST_LEN+1:]
    ablated_loss = loss_fn(output_logits_ablated.reshape(-1, VOCAB), output_targets.reshape(-1))
    ablated_predictions = ablated_logits.argmax(dim=-1)
    
    # --- Calculate Ablated Accuracy ---
    ablated_output_predictions = ablated_predictions[:, LIST_LEN+1:]
    ablated_accuracy = (ablated_output_predictions == output_targets).float().mean()

    print(f"Ablating Skip Connection at Layer {l}:")
    print(f"  Ablated Loss: {ablated_loss.item():.4f}")
    print(f"  Loss Increase: {(ablated_loss - original_loss).item():.4f}")
    print(f"  Ablated Accuracy: {ablated_accuracy.item():.4f}")
    print(f"  Ablated predictions: {ablated_predictions[sample_idx].cpu().numpy()}")
    print("-" * 50)

--- Attention Skip Connection Ablation Results ---
Sample sequence: [ 80  52 100 100 100]
Original Loss: 0.0080
Original Accuracy: 0.9978
Original predictions: [80 80 80 80 52]
--------------------------------------------------
Ablating Skip Connection at Layer 0:
  Ablated Loss: 7.7675
  Loss Increase: 7.7595
  Ablated Accuracy: 0.3583
  Ablated predictions: [80 80 80 80 81]
--------------------------------------------------
Ablating Skip Connection at Layer 1:
  Ablated Loss: 28.3601
  Loss Increase: 28.3522
  Ablated Accuracy: 0.5650
  Ablated predictions: [80 80 80 80 80]
--------------------------------------------------
Ablating Skip Connection at Layer 2:
  Ablated Loss: 1.4052
  Loss Increase: 1.3973
  Ablated Accuracy: 0.7338
  Ablated predictions: [80 80 80 80 49]
--------------------------------------------------


In [None]:
# --- Ablate Skip Connections for Layer 0 and 1 Simultaneously ---

# This dictionary will store the input residual streams for each layer
captured_resid_pres = {}

# Hook to capture the input to an attention block
def capture_resid_pre_hook(resid, hook):
    layer_idx = hook.layer()
    captured_resid_pres[layer_idx] = resid
    return resid

# Hook to ablate the skip connection by subtracting the captured input
def ablate_skip_hook(resid, hook):
    layer_idx = hook.layer()
    # Subtract the captured resid_pre for the corresponding layer
    result = resid - captured_resid_pres[layer_idx]
    idx = -2
    # result[:, -2:] = resid[:, -2:] # keep some tokens intact (last 2)
    return result

# Define the hooks for both layers
fwd_hooks = []
for l in range(model.cfg.n_layers):
    resid_pre_hook_name = f"blocks.{l}.hook_resid_pre"
    resid_post_hook_name = f"blocks.{l}.hook_resid_post"
    fwd_hooks.extend([
        (resid_pre_hook_name, capture_resid_pre_hook),
        (resid_post_hook_name, ablate_skip_hook)
    ])

# Run the model with both skip connections ablated
with torch.no_grad():
    ablated_logits = model.run_with_hooks(
        val_inputs,
        return_type="logits",
        fwd_hooks=fwd_hooks
    )
    output_logits_ablated = ablated_logits[:, LIST_LEN+1:]
    ablated_loss = loss_fn(output_logits_ablated.reshape(-1, VOCAB), output_targets.reshape(-1))
    ablated_predictions = ablated_logits.argmax(dim=-1)
    
    # --- Calculate Ablated Accuracy ---
    ablated_output_predictions = ablated_predictions[:, LIST_LEN+1:]
    ablated_accuracy = (ablated_output_predictions == output_targets).float().mean()


print(f"--- Ablating All Skip Connections ---")
print(f"Validation set size: {len(val_inputs)} samples")
print("-" * 50)
print(f"{'Metric':<12} | {'Original':<10} | {'Ablated':<10}")
print("-" * 50)
print(f"{'Loss':<12} | {original_loss.item():<10.4f} | {ablated_loss.item():<10.4f}")
print(f"{'Accuracy':<12} | {original_accuracy.item():<10.4f} | {ablated_accuracy.item():<10.4f}")
print("-" * 50)
print(f"Example from {sample_idx}th validation sample:")
print(f"  Sample sequence:      {val_inputs[sample_idx].cpu().numpy()}")
print(f"  Original predictions: {original_predictions[sample_idx].cpu().numpy()}")
print(f"  Ablated predictions:  {ablated_predictions[sample_idx].cpu().numpy()}")

--- Ablating Both Skip Connections (Layers 0 & 1) ---
Validation set size: 2000 samples
--------------------------------------------------
Metric       | Original   | Ablated   
--------------------------------------------------
Loss         | 0.0080     | 18.4558   
Accuracy     | 0.9978     | 0.5445    
--------------------------------------------------
Example from 0th validation sample:
  Sample sequence:      [ 80  52 100 100 100]
  Original predictions: [80 80 80 80 52]
  Ablated predictions:  [80 80 80 80 80]


In [None]:
# --- Mean Ablation of Skip Connections ---

# 1. Cache the 'resid_pre' activations for each layer across the validation set
resid_pre_hook_names = [f"blocks.{l}.hook_resid_pre" for l in range(model.cfg.n_layers)]
with torch.no_grad():
    _, cache = model.run_with_cache(val_inputs, names_filter=lambda name: name in resid_pre_hook_names)

# 2. Calculate the mean of these activations
mean_resid_pres = {}
for l in range(model.cfg.n_layers):
    mean_resid_pres[l] = cache[resid_pre_hook_names[l]].mean(dim=(0, 1))

# --- Define hooks for mean ablation ---
captured_resid_pre = {}

def capture_resid_pre_hook(resid, hook):
    """Saves the current resid_pre to be subtracted in the next hook."""
    captured_resid_pre[hook.layer()] = resid
    return resid

def mean_ablate_skip_hook(resid, hook):
    """Replaces the skip connection with its mean value."""
    layer_idx = hook.layer()
    # resid_post = resid_pre + block_output
    # We want: mean_resid_pre + block_output
    # So we calculate: resid_post - resid_pre + mean_resid_pre
    return resid - captured_resid_pre[layer_idx] + mean_resid_pres[layer_idx]

# --- Function to calculate loss and accuracy ---
loss_fn = torch.nn.CrossEntropyLoss()

def calculate_metrics(logits, targets):
    """Calculates loss and accuracy for the output tokens."""
    output_logits = logits[:, LIST_LEN+1:]
    output_targets = targets[:, LIST_LEN+1:]
    
    loss = loss_fn(output_logits.reshape(-1, VOCAB), output_targets.reshape(-1)).item()
    
    predictions = output_logits.argmax(dim=-1)
    accuracy = (predictions == output_targets).float().mean().item()
    
    return loss, accuracy

# --- Evaluate metrics for each ablation case ---

print("--- Skip Connection Mean Ablation Metrics ---")

# Original metrics
with torch.no_grad():
    original_logits = model(val_inputs)
    og_loss, original_acc = calculate_metrics(original_logits, val_targets)
print(f"Original -> Loss: {og_loss:.4f}, Accuracy: {original_acc:.2%}")
print("-" * 50)

# Ablate each layer individually
for l in range(model.cfg.n_layers):
    fwd_hooks = [
        (f"blocks.{l}.hook_resid_pre", capture_resid_pre_hook),
        (f"blocks.{l}.hook_resid_post", mean_ablate_skip_hook)
    ]
    with torch.no_grad():
        ablated_logits = model.run_with_hooks(val_inputs, fwd_hooks=fwd_hooks)
        ablated_loss, ablated_acc = calculate_metrics(ablated_logits, val_targets)
    print(f"Ablating Layer {l} Skip -> Loss: {ablated_loss:.4f}, Accuracy: {ablated_acc:.2%}")

# Ablate all layers simultaneously
fwd_hooks = []
for l in range(model.cfg.n_layers):
    fwd_hooks.extend([
        (f"blocks.{l}.hook_resid_pre", capture_resid_pre_hook),
        (f"blocks.{l}.hook_resid_post", mean_ablate_skip_hook)
    ])
with torch.no_grad():
    ablated_logits = model.run_with_hooks(val_inputs, fwd_hooks=fwd_hooks)
    ablated_loss, ablated_acc = calculate_metrics(ablated_logits, val_targets)
print(f"Ablating All Skips -> Loss: {ablated_loss:.4f}, Accuracy: {ablated_acc:.2%}")
print("-" * 50)

--- Skip Connection Mean Ablation Metrics ---
Original -> Loss: 0.0080, Accuracy: 99.78%
--------------------------------------------------
Ablating Layer 0 Skip -> Loss: 5.2257, Accuracy: 47.18%
Ablating Layer 1 Skip -> Loss: 23.0065, Accuracy: 64.63%
Ablating Layer 2 Skip -> Loss: 0.3617, Accuracy: 94.15%
Ablating Both Skips -> Loss: 23.9267, Accuracy: 32.70%
--------------------------------------------------


### W_E and W_U

In [None]:
# https://umap-learn.readthedocs.io/en/latest/parameters.html
N_DIM_VIS = 2  # <-- CHANGE THIS VALUE to 2 or 3 to switch visualizations
umap_n_neighbors = min(5, VOCAB - 1)  # Use smaller n_neighbors for small dataset (max VOCAB-1 = 10)
umap_min_dist = 0.1  # Spread points out more
umap_metric = 'euclidean' # default: euclidean 

def visualize_w_e_and_w_u(model):
    """
    Extracts W_E and W_U, applies UMAP to get 2D or 3D projections,
    and creates side-by-side interactive plots based on N_DIM_VIS.
    """
    if N_DIM_VIS not in [2, 3]:
        raise ValueError("N_DIM_VIS must be set to 2 or 3.")
    

    print(f"\nStarting {N_DIM_VIS}D UMAP visualization  for W_E and W_U... \n(n_neighbors={umap_n_neighbors}, min_dist={umap_min_dist}, metric={umap_metric})")
    model.eval()

    w_e = model.embed.W_E.detach().cpu().numpy()
    w_u = model.unembed.W_U.T.detach().cpu().numpy()
    
    reducer = umap.UMAP(
        n_neighbors=umap_n_neighbors,
        min_dist=umap_min_dist,
        n_components=N_DIM_VIS,
        random_state=42, # UMAP is stochastic, so we set a seed for reproducibility
        metric=umap_metric,  # Use Euclidean distance for UMAP
        # verbose=True,
    )

    w_e_proj = reducer.fit_transform(w_e)
    w_u_proj = reducer.fit_transform(w_u)
    labels = [str(d) for d in DIGITS] + ['SEP']

    # --- Find common axis ranges across both projections ---
    all_proj = np.vstack([w_e_proj, w_u_proj])
    min_vals = all_proj.min(axis=0)
    max_vals = all_proj.max(axis=0)
    
    # Add a 10% margin for better visualization
    margin = (max_vals - min_vals) * 0.1
    ranges = [(min_v - m, max_v + m) for min_v, max_v, m in zip(min_vals, max_vals, margin)]

    if N_DIM_VIS == 3:
        fig = make_subplots(
            rows=1, cols=2,
            specs=[[{'type': 'scene'}, {'type': 'scene'}]],
            subplot_titles=('3D UMAP of W_E (Embeddings)', '3D UMAP of W_U (Unembeddings)')
        )
        fig.add_trace(go.Scatter3d(
            x=w_e_proj[:, 0], y=w_e_proj[:, 1], z=w_e_proj[:, 2],
            mode='markers+text', text=labels, textfont=dict(size=10, color='black'),
            marker=dict(size=5, color=list(range(VOCAB)), colorscale='viridis'),
            hoverinfo='text',
            hovertext=[f'Token: {l}<br>x: {x:.2f}, y: {y:.2f}, z: {z:.2f}' for l, x, y, z in zip(labels, w_e_proj[:, 0], w_e_proj[:, 1], w_e_proj[:, 2])],
            showlegend=False
        ), row=1, col=1)
        fig.add_trace(go.Scatter3d(
            x=w_u_proj[:, 0], y=w_u_proj[:, 1], z=w_u_proj[:, 2],
            mode='markers+text', text=labels, textfont=dict(size=10, color='black'),
            marker=dict(
                size=5, color=list(range(VOCAB)), colorscale='viridis', showscale=True,
                colorbar=dict(title="Token ID", tickvals=list(range(VOCAB)), ticktext=labels)
            ),
            hoverinfo='text',
            hovertext=[f'Token: {l}<br>x: {x:.2f}, y: {y:.2f}, z: {z:.2f}' for l, x, y, z in zip(labels, w_u_proj[:, 0], w_u_proj[:, 1], w_u_proj[:, 2])]
        ), row=1, col=2)
        fig.update_layout(title_text='3D UMAP Projections', height=700, width=1400)
        # Apply the same axis ranges to both 3D scenes
        fig.update_scenes(
            xaxis_title_text='Dim 1', yaxis_title_text='Dim 2', zaxis_title_text='Dim 3',
            xaxis_range=ranges[0], yaxis_range=ranges[1], zaxis_range=ranges[2]
        )
    else:  # N_DIM_VIS == 2
        fig = make_subplots(
            rows=1, cols=2,
            subplot_titles=('2D UMAP of W_E (Embeddings)', '2D UMAP of W_U (Unembeddings)')
        )
        fig.add_trace(go.Scatter(
            x=w_e_proj[:, 0], y=w_e_proj[:, 1],
            mode='markers+text', text=labels, textposition='top center',
            marker=dict(size=10, color=list(range(VOCAB)), colorscale='viridis'),
            hoverinfo='text',
            hovertext=[f'Token: {l}<br>x: {x:.3f}, y: {y:.3f}' for l, x, y in zip(labels, w_e_proj[:, 0], w_e_proj[:, 1])],
            showlegend=False
        ), row=1, col=1)
        fig.add_trace(go.Scatter(
            x=w_u_proj[:, 0], y=w_u_proj[:, 1],
            mode='markers+text', text=labels, textposition='top center',
            marker=dict(
                size=10, color=list(range(VOCAB)), colorscale='viridis', showscale=True,
                colorbar=dict(title="Token ID", tickvals=list(range(VOCAB)), ticktext=labels)
            ),
            hoverinfo='text',
            hovertext=[f'Token: {l}<br>x: {x:.3f}, y: {y:.3f}' for l, x, y in zip(labels, w_u_proj[:, 0], w_u_proj[:, 1])]
        ), row=1, col=2)
        fig.update_layout(title_text='2D UMAP Projections', height=600, width=1200, template='plotly_white')
        # Apply the same axis ranges to both 2D plots
        fig.update_xaxes(title_text="UMAP Dim 1", range=ranges[0])
        fig.update_yaxes(title_text="UMAP Dim 2", range=ranges[1])

    fig.show()

visualize_w_e_and_w_u(model)



Starting 2D UMAP visualization  for W_E and W_U... 
(n_neighbors=5, min_dist=0.1, metric=euclidean)



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [None]:
from sklearn.metrics.pairwise import cosine_similarity

def calculate_8d_pairwise_angles(model):
    """
    Calculates the pairwise angles (in degrees) between all token vectors
    in the original 8D embedding space.
    """
    print(f"Calculating pairwise angles between all vectors in the original {D_MODEL}D space...")
    model.eval()

    # Get original 8D embeddings
    w_e = model.embed.W_E.detach().cpu().numpy()
    w_u = model.unembed.W_U.T.detach().cpu().numpy()
    labels = [str(d) for d in DIGITS] + ['SEP']

    # --- Calculate for W_E ---
    # cosine_similarity(X, Y) computes (X @ Y.T) / (norm(X) * norm(Y))
    # The result is the cosine of the angle between vectors.
    cos_sim_e = cosine_similarity(w_e)
    # Clip values to be within [-1, 1] to avoid domain errors with arccos due to floating point inaccuracies
    cos_sim_e = np.clip(cos_sim_e, -1.0, 1.0)
    # The angle is the arccosine of the similarity
    angles_e_rad = np.arccos(cos_sim_e)
    angles_e_deg = np.rad2deg(angles_e_rad)
    df_e = pd.DataFrame(angles_e_deg, index=labels, columns=labels)

    # --- Calculate for W_U ---
    cos_sim_u = cosine_similarity(w_u)
    cos_sim_u = np.clip(cos_sim_u, -1.0, 1.0)
    angles_u_rad = np.arccos(cos_sim_u)
    angles_u_deg = np.rad2deg(angles_u_rad)
    df_u = pd.DataFrame(angles_u_deg, index=labels, columns=labels)

    print(f"\n--- Pairwise Angles (Degrees) for W_E (Embeddings) in {D_MODEL}D ---")
    # Use a specific float format for better readability
    print(df_e.to_markdown(floatfmt=".1f"))

    print(f"\n--- Pairwise Angles (Degrees) for W_U (Unembeddings) in {D_MODEL}D ---")
    print(df_u.to_markdown(floatfmt=".1f"))

    return df_e, df_u

# Calculate and display the pairwise angles from the 8D space
df_e_angles_8d, df_u_angles_8d = calculate_8d_pairwise_angles(model)

Calculating pairwise angles between all vectors in the original 8D space...

--- Pairwise Angles (Degrees) for W_E (Embeddings) in 8D ---
|     |     0 |     1 |     2 |     3 |     4 |     5 |     6 |     7 |     8 |     9 |    10 |    11 |    12 |    13 |    14 |    15 |    16 |    17 |    18 |    19 |    20 |    21 |    22 |    23 |    24 |    25 |    26 |    27 |    28 |    29 |    30 |    31 |    32 |    33 |    34 |    35 |    36 |    37 |    38 |    39 |    40 |    41 |    42 |    43 |    44 |    45 |    46 |    47 |    48 |    49 |    50 |    51 |    52 |    53 |    54 |    55 |    56 |    57 |    58 |    59 |    60 |    61 |    62 |    63 |    64 |    65 |    66 |    67 |    68 |    69 |    70 |    71 |    72 |    73 |    74 |    75 |    76 |    77 |    78 |    79 |    80 |    81 |    82 |    83 |    84 |    85 |    86 |    87 |    88 |    89 |    90 |    91 |    92 |    93 |    94 |    95 |    96 |    97 |    98 |    99 |   SEP |
|:----|------:|------:|------:|------:|------:

In [None]:
def analyze_spacing_invariant(df_angles, name):
    """
    Analyzes vector spacing in a permutation-invariant way by checking
    the angles to the two nearest neighbors for each digit vector.
    """
    # We only care about the digit embeddings for this analysis
    digit_angles_df = df_angles.iloc[:len(DIGITS), :len(DIGITS)].copy()
    
    all_neighbor_angles = []
    
    print(f"\n--- Permutation-Invariant Spacing Analysis for {name} ---")
    # print("This method finds the angles to the two nearest neighbors for each digit vector.")

    for i in range(len(DIGITS)):
        if i in []:
            print(f"Skipping digit {DIGITS[i]} (index {i}) for analysis.")
            continue
        # Get angles from digit i to all other digits
        angles_from_i = digit_angles_df.iloc[i].drop(digit_angles_df.columns[i])
        # Sort to find the two smallest angles
        sorted_angles = angles_from_i.sort_values()
        # The two nearest neighbors
        all_neighbor_angles.extend(sorted_angles.iloc[:2].values)
        # print(f"Digit {DIGITS[i]} has nearest neighbours: {sorted_angles.index[0]} ({sorted_angles.iloc[0]:.2f}°) & {sorted_angles.index[1]} ({sorted_angles.iloc[1]:.2f}°)")

    neighbor_angles = np.array(all_neighbor_angles)
    
    print(f"Mean neighbor angle: {neighbor_angles.mean():.2f}°")
    print(f"Std Dev of neighbor angles: {neighbor_angles.std():.2f}°")

analyze_spacing_invariant(df_e_angles_8d, "W_E (Embeddings)")
analyze_spacing_invariant(df_u_angles_8d, "W_U (Unembeddings)")


--- Permutation-Invariant Spacing Analysis for W_E (Embeddings) ---
Mean neighbor angle: 50.23°
Std Dev of neighbor angles: 3.51°

--- Permutation-Invariant Spacing Analysis for W_U (Unembeddings) ---
Mean neighbor angle: 51.72°
Std Dev of neighbor angles: 3.51°


### Attention

In [None]:
# try setting specific attention positions to zero
layer_to_ablate = 0 # output digits do nothijg in layer 0
print(f"Layer {layer_to_ablate} Ablation")
# Define which specific attention position you want to zero out
key_pos, query_pos = 1,2 # (top left = [0,0]. query is the row, key is the column)

def specific_attention_ablation_hook(
    pattern, # Shape: [batch, head_index, query_pos, key_pos]
    hook
):    
    # Set specific attention weight to 0
    with torch.no_grad():
        # This print statement will only run once if the validation set is processed as a single batch
        print("--- Attention Pattern Change (Head 0) ---")
        print(f'BEFORE Ablation:\n{pattern[0, head_index_to_ablate, :, :].cpu().numpy()}')

        # pattern[:, head_index_to_ablate, :2, :] = 0.0
        # pattern[:, head_index_to_ablate, 3:, :] = 0.0
        # pattern[:, head_index_to_ablate, 3, :] = 0.0
        pattern[:, head_index_to_ablate, query_pos, key_pos] = 0.0
        # pattern[:, head_index_to_ablate, 2, :2] = 0.5
        
        print(f'AFTER Ablation:\n{pattern[0, head_index_to_ablate, :, :].cpu().numpy()}')
        print("-" * 45)
    
    return pattern

# Get the attention pattern hook name
attn_pattern_hook_name = utils.get_act_name("pattern", layer_to_ablate)

# --- Calculate Ablated Loss on last 2 digits ---
with torch.no_grad():
    ablated_logits = model.run_with_hooks(
        val_inputs,
        return_type="logits",  # Get logits instead of loss
        fwd_hooks=[(attn_pattern_hook_name, specific_attention_ablation_hook)],
    )
    # Slice to get logits for the last two positions
    output_logits_ablated = ablated_logits[:, LIST_LEN+1:]
    # Calculate the loss
    ablated_loss = loss_fn(
        output_logits_ablated.reshape(-1, VOCAB), output_targets.reshape(-1)
    )

    # Calculate accuracy
    ablated_predictions = ablated_logits.argmax(dim=-1)
    ablated_output_predictions = ablated_predictions[:, LIST_LEN+1:]
    ablated_accuracy = (ablated_output_predictions == output_targets).float().mean()

print("\n--- Performance Metrics (on last 2 digits) ---")
print(f"{'':<12} | {'Original':<10} | {'Ablated':<10}")
print("-" * 45)
print(f"{'Loss:':<12} | {original_loss.item():<10.3f} | {ablated_loss.item():<10.3f}")
print(f"{'Accuracy:':<12} | {original_accuracy.item():<10.3f} | {ablated_accuracy.item():<10.3f}")
print("-" * 45)

# Get the predicted tokens from the ablated logits
ablated_predictions = ablated_logits.argmax(dim=-1)

print(f"\n--- Prediction Comparison (Sample {sample_idx}) ---")
print(f"Original sequence:   {val_inputs[sample_idx].cpu().numpy()}")
print(f"Original prediction: {original_predictions[sample_idx].cpu().numpy()}")
print(f"Ablated prediction:  {ablated_predictions[sample_idx].cpu().numpy()}")
print("-" * 45)

Layer 0 Ablation
--- Attention Pattern Change (Head 0) ---
BEFORE Ablation:
[[1.00000 0.00000 0.00000 0.00000 0.00000]
 [1.00000 0.00000 0.00000 0.00000 0.00000]
 [0.40103 0.59897 0.00000 0.00000 0.00000]
 [0.00000 0.00000 1.00000 0.00000 0.00000]
 [0.00000 0.00000 0.90175 0.09825 0.00000]]
AFTER Ablation:
[[1.00000 0.00000 0.00000 0.00000 0.00000]
 [1.00000 0.00000 0.00000 0.00000 0.00000]
 [0.40103 0.00000 0.00000 0.00000 0.00000]
 [0.00000 0.00000 1.00000 0.00000 0.00000]
 [0.00000 0.00000 0.90175 0.09825 0.00000]]
---------------------------------------------

--- Performance Metrics (on last 2 digits) ---
             | Original   | Ablated   
---------------------------------------------
Loss:        | 0.008      | 16.598    
Accuracy:    | 0.998      | 0.497     
---------------------------------------------

--- Prediction Comparison (Sample 0) ---
Original sequence:   [ 80  52 100 100 100]
Original prediction: [80 80 80 80 52]
Ablated prediction:  [80 80 80 80 81]
------------

In [None]:
# --- Analyze Failure Cases ---
# Find indices where the ablated prediction is incorrect
is_incorrect = (ablated_output_predictions != output_targets).any(dim=1)
error_indices = torch.where(is_incorrect)[0]

print(f"\n--- Analysis of {len(error_indices)} Failure Cases ---")
if len(error_indices) > 0:
    # Limit the number of printed examples for readability
    n_examples_to_show = min(10, len(error_indices))
    print(f"Showing the first {n_examples_to_show} incorrect predictions:")
    
    for i in range(n_examples_to_show):
        idx = error_indices[i]
        full_sequence = val_inputs[idx].cpu().numpy()
        input_digits = full_sequence[:LIST_LEN]
        correct_output = val_targets[idx, LIST_LEN+1:].cpu().numpy()
        predicted_output = ablated_predictions[idx, LIST_LEN+1:].cpu().numpy()
        
        print(f"\nExample {i+1} (Index: {idx}):")
        print(f"  Input Digits:     {input_digits}")
        print(f"  Correct Output:   {correct_output}")
        print(f"  Predicted Output: {predicted_output} <--- ERROR")

    # TEST
    bad_preds = ablated_predictions[error_indices, LIST_LEN+1:].cpu().numpy()
    bad_preds_2 = []
    c = 0
    for p in bad_preds:
        if p[0] == p[1]:
            c+=1
        else:
            bad_preds_2.append(p)
    print(f'{c} duped')
else:
    print("No incorrect predictions found after ablation.")


--- Analysis of 9 Failure Cases ---
Showing the first 9 incorrect predictions:

Example 1 (Index: 76):
  Input Digits:     [39 95]
  Correct Output:   [39 95]
  Predicted Output: [93 95] <--- ERROR

Example 2 (Index: 383):
  Input Digits:     [11 26]
  Correct Output:   [11 26]
  Predicted Output: [74 26] <--- ERROR

Example 3 (Index: 424):
  Input Digits:     [88 26]
  Correct Output:   [88 26]
  Predicted Output: [88 87] <--- ERROR

Example 4 (Index: 585):
  Input Digits:     [78 95]
  Correct Output:   [78 95]
  Predicted Output: [32 95] <--- ERROR

Example 5 (Index: 638):
  Input Digits:     [63 95]
  Correct Output:   [63 95]
  Predicted Output: [ 2 95] <--- ERROR

Example 6 (Index: 1031):
  Input Digits:     [87 49]
  Correct Output:   [87 49]
  Predicted Output: [87 87] <--- ERROR

Example 7 (Index: 1168):
  Input Digits:     [95 30]
  Correct Output:   [95 30]
  Predicted Output: [30 30] <--- ERROR

Example 8 (Index: 1475):
  Input Digits:     [10 55]
  Correct Output:   [10 5

In [None]:
len(bad_preds_2)

7