In [None]:
# # Comparing Neuron Activations Across Prompts with TransformerLens

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from transformer_lens import HookedTransformer

In [None]:
# ## Config

In [None]:
# Which pretrained model to load and which transformer block to inspect.
model_name = "gpt2-small"
layer_idx = 6

# Name of the hook point whose cached activations are compared. Per the
# TransformerLens docs, `mlp.hook_post` is the MLP hidden activation taken
# after the nonlinearity (not the MLP's output into the residual stream).
act_name = f"blocks.{layer_idx}.mlp.hook_post"

# Prefer the GPU when one is visible; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# The two prompts differ in a single word so that activation differences
# can be attributed to that substitution.
prompt_1, prompt_2 = (
    "The cat sat on the mat.",
    "The kitten sat on the mat.",
)

In [None]:
# ## Load Model

In [None]:
# Load the pretrained weights, move the model to the selected device, and
# switch it to inference mode.
model = HookedTransformer.from_pretrained(model_name)
model = model.to(device)
model = model.eval()

In [None]:
# ## Activation Extraction

In [None]:
def get_activations(prompt):
    """Run `prompt` through the model and return the activations at `act_name`.

    Args:
        prompt: Text passed to `model.to_tokens`.

    Returns:
        A pair `(activations, tokens)`: `activations` is the cached tensor for
        `act_name` with a size-1 leading dimension squeezed away, and `tokens`
        is the token tensor exactly as produced by `model.to_tokens` (moved to
        `device`).
    """
    tokens = model.to_tokens(prompt).to(device)
    # Only the forward values are needed, so skip autograd bookkeeping, and
    # cache just the one hook point that is read instead of every activation.
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens, names_filter=act_name)
    return cache[act_name].squeeze(0), tokens

In [None]:
# Extract activations and tokens for both prompts, in prompt order.
(act_1, tokens_1), (act_2, tokens_2) = map(
    get_activations, (prompt_1, prompt_2)
)

In [None]:
# Align lengths: crop both activation tensors to the shorter prompt so they
# can be compared position by position. This is a purely positional alignment;
# it does not match up tokens by meaning.
min_len = min(act_1.shape[0], act_2.shape[0])
act_1, act_2 = act_1[:min_len], act_2[:min_len]
# Flatten before slicing so this cell can be re-run safely: after the first
# run `tokens_1` is already 1-D, and indexing it with `[0, :min_len]` again
# would raise an IndexError.
tokens_1 = tokens_1.reshape(-1)[:min_len]

In [None]:
# ## MLP Transcoder

In [None]:
class MLPTranscoder(nn.Module):
    """Small feed-forward network mapping a vector to another of the same width.

    The network is Linear -> ReLU -> Linear, with every layer `dim` wide.

    Args:
        dim: Size of the last dimension of the input and of the output.
    """

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )

    # `forward` must live inside the class body: defined as a separate
    # top-level function it is never attached to the module, and calling the
    # module then raises NotImplementedError.
    def forward(self, x):
        """Apply the network to `x` and return the result."""
        return self.net(x)

In [None]:
# Seed right before construction so the randomly initialised weights (and
# therefore the training run that follows) are the same on every execution
# of this cell.
torch.manual_seed(0)
# Width is taken from the last dimension of the extracted activations.
transcoder = MLPTranscoder(act_1.size(-1)).to(device)

In [None]:
# ## Training

In [None]:
# Mean-squared error between the transcoder output and the target activations.
loss_fn = nn.MSELoss()
# Adam over all transcoder parameters.
optimizer = torch.optim.Adam(params=transcoder.parameters(), lr=1e-3)

In [None]:
# Fit the transcoder to map prompt-1 activations onto prompt-2 activations.
# Every step uses the full (aligned) activation tensors as a single batch.
n_steps, log_every = 1000, 100
for step in range(n_steps):
    optimizer.zero_grad()
    output = transcoder(act_1)
    loss = loss_fn(output, act_2)
    loss.backward()
    optimizer.step()
    # Only report progress on every `log_every`-th step.
    if step % log_every:
        continue
    print(f"Step {step}: Loss = {loss.item():.4f}")

In [None]:
# ## Evaluation

In [None]:
# Score the trained transcoder on the same activations it was fitted to.
with torch.no_grad():
    aligned = transcoder(act_1)
    # Cosine similarity along the last dimension, then averaged.
    per_row_cosine = F.cosine_similarity(aligned, act_2, dim=-1)
    cosine = per_row_cosine.mean().item()
    mse = F.mse_loss(aligned, act_2).item()

In [None]:
# Report the two summary metrics (`cosine` and `mse`) to four decimal places.
print(f"\nFinal Cosine Similarity: {cosine:.4f}")
print(f"Final MSE: {mse:.4f}")

In [None]:
# ## Neuron-Wise Difference (Bar Plot)

In [None]:
# Average each activation tensor over its first dimension, then take the
# absolute gap between the two averages.
with torch.no_grad():
    mean_act_1, mean_act_2 = (acts.mean(dim=0) for acts in (act_1, act_2))
    neuron_diff = torch.abs(mean_act_1 - mean_act_2).cpu()

In [None]:
# Keep the `topk` largest entries of `neuron_diff` together with their indices.
topk = 20
topk_vals, topk_idx = neuron_diff.topk(k=topk)

In [None]:
# Bar chart of the selected differences; each bar is labelled with the index
# it came from in `neuron_diff`.
fig, ax = plt.subplots(figsize=(12, 4))
bar_positions = range(topk)
ax.bar(bar_positions, topk_vals.numpy())
ax.set_xticks(bar_positions)
ax.set_xticklabels(topk_idx.numpy(), rotation=45)
ax.set_title(f"Top {topk} Differing Neurons in MLP Layer {layer_idx}")
ax.set_xlabel("Neuron Index")
ax.set_ylabel("Activation Difference (abs)")
fig.tight_layout()
plt.show()

In [None]:
# ## Token-wise Neuron Difference (Heatmap)

In [None]:
# String form of the prompt-1 tokens, used to label the heatmap columns.
token_labels = model.to_str_tokens(tokens_1)
# Element-wise absolute difference between the two activation tensors,
# converted to a NumPy array for plotting.
diff_matrix = torch.abs(act_1 - act_2).cpu().numpy()

In [None]:
# Heatmap of the transposed difference matrix: columns are labelled with the
# prompt-1 tokens, and row labels are suppressed.
fig, ax = plt.subplots(figsize=(14, 6))
sns.heatmap(
    diff_matrix.T,
    cmap="viridis",
    cbar=True,
    xticklabels=token_labels,
    yticklabels=False,
    ax=ax,
)
ax.set_title(f"Neuron-wise Differences over Tokens (Layer {layer_idx})")
ax.set_xlabel("Token")
ax.set_ylabel("Neuron Index")
plt.setp(ax.get_xticklabels(), rotation=45)
fig.tight_layout()
plt.show()