<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Grokking_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Grokking Demo Notebook

<b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

# Setup
(No need to read)

In [None]:
# Flag meant to toggle training on or off.
# NOTE(review): no cell visible in this notebook reads this flag, so training
# currently runs regardless of its value -- confirm whether it should gate the
# training cell.
TRAIN_MODEL = True

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEVELOPMENT_MODE = True
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

if IN_COLAB or IN_GITHUB:
    %pip install transformer_lens
    %pip install circuitsvis

In [None]:
# Plotly renders through different back ends depending on the front end:
# "colab" inside Colab (or whenever development mode is off), otherwise
# "notebook_connected" for VSCode / classic notebooks.
import plotly.io as pio

pio.renderers.default = (
    "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
)
print(f"Using renderer: {pio.renderers.default}")

In [None]:
# Enlarge the figure-title and axis-title fonts on the stock "plotly" template.
_plotly_layout = pio.templates['plotly'].layout
_plotly_layout.xaxis.title.font.size = 20
_plotly_layout.yaxis.title.font.size = 20
_plotly_layout.title.font.size = 30

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a heatmap on a red/blue scale centred at zero.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Plotly renderer name forwarded to `Figure.show`.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `plotly.express.imshow`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    heatmap = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    heatmap.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a line chart.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Plotly renderer name forwarded to `Figure.show`.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `plotly.express.line`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    chart = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    chart.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Display a scatter plot of `y` against `x`.

    Args:
        x: Horizontal coordinates; converted with `utils.to_numpy`.
        y: Vertical coordinates; converted with `utils.to_numpy`.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Plotly renderer name forwarded to `Figure.show`.
        **kwargs: Forwarded unchanged to `plotly.express.scatter`.
    """
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    chart = px.scatter(y=y, x=x, labels=axis_labels, **kwargs)
    chart.show(renderer)

In [None]:
# Relative path of the local checkpoint file.
PTH_LOCATION = "workspace/_scratch/grokking_demo.pth"

# Make sure the checkpoint's parent folder exists (no error if it already does).
Path(PTH_LOCATION).parent.mkdir(parents=True, exist_ok=True)

# Model Training

## Config

In [None]:
# Data config
frac_train = 0.72  # fraction of the samples assigned to the training split
data_seed = 598    # seed used when sampling k and shuffling the split

# Optimizer config
lr, wd = 1e-4, 0.1
betas = (0.9, 0.98)

# Schedule config
num_epochs, checkpoint_every = 10000, 100

## Define Task
We want the model to learn to compute $\sin(kx)$ given $k$. Hence, the input is $k$ and the desired output is a discretized version of $\sin(kx)$, with $x$ sampled at 100 evenly spaced points in $[0, 2\pi]$.

We use 1000 samples for $k$, drawn uniformly at random from the $[0.1, 10.]$ interval.

In [None]:
num_samples, num_x_points = 1000, 100
k_min, k_max = 0.1, 10.0

# Model inputs: k sampled uniformly from [k_min, k_max) with a seeded generator.
k_vector = k_min + (k_max - k_min) * torch.rand(num_samples, generator=torch.manual_seed(data_seed))
dataset = k_vector[:, None].to(device)  # Shape: [num_samples, 1]

# Labels: sin(k * x) evaluated on an evenly spaced grid over [0, 2*pi].
x_points = torch.linspace(0, 2 * torch.pi, num_x_points, device=device)  # Shape: [num_x_points]
# Broadcasting (num_samples, 1) * (num_x_points,) -> (num_samples, num_x_points)
# evaluates every curve in one call.
labels = torch.sin(dataset * x_points)

# --- Split into Training and Test Sets ---
# Re-seed so the permutation does not depend on how much randomness was consumed above.
torch.manual_seed(data_seed)
indices = torch.randperm(num_samples)
cutoff = int(num_samples * frac_train)
train_indices, test_indices = indices[:cutoff], indices[cutoff:]

train_data, train_labels = dataset[train_indices], labels[train_indices]
test_data, test_labels = dataset[test_indices], labels[test_indices]

## Define Model

The standard hooked transformer config from TransformerLens is set up for discrete tokens. We are however considering continuous inputs and outputs in this regression task. Hence, in order to keep all the useful features of HookedTransformer, we will make a new class that inherits from it and adjust the input and output layers to allow for continuous inputs and outputs.

In [None]:
class SineTransformer(HookedTransformer):
    """HookedTransformer adapted to a continuous regression task.

    The token embedding and unembedding created by the parent constructor are
    removed and replaced by two linear layers: one lifts a scalar `k` into the
    residual stream, the other maps each residual position back to a scalar.
    """

    def __init__(self, cfg):
        """Build the parent transformer from `cfg`, then swap its input/output layers."""
        super().__init__(cfg)

        # The parent's embed/unembed modules operate on discrete tokens and are
        # not used by this class's forward pass, so drop them.
        del self.embed
        del self.unembed

        # Scalar k -> d_model vector.
        self.k_input_projection = nn.Linear(1, self.cfg.d_model)

        # d_model vector -> one float per sequence position.
        self.output_projection = nn.Linear(self.cfg.d_model, 1)

    def forward(self, k_values):
        """Predict one output value per position for each input `k`.

        Args:
            k_values: Tensor of shape [batch_size, 1].

        Returns:
            Tensor of shape [batch_size, n_ctx].
        """
        # [batch_size, d_model]
        k_embedding = self.k_input_projection(k_values)

        # Learned positional embeddings, one row per output point: [n_ctx, d_model].
        positional_embeddings = self.pos_embed.W_pos[: self.cfg.n_ctx, :]

        # Broadcast the k embedding across all positions: [batch_size, n_ctx, d_model].
        residual_stream = k_embedding[:, None, :] + positional_embeddings

        # self.blocks is iterated block by block rather than called directly.
        for block in self.blocks:
            residual_stream = block(residual_stream)

        # [batch_size, n_ctx, 1] -> [batch_size, n_ctx]
        return self.output_projection(residual_stream).squeeze(-1)

# Configuration for the sine-regression transformer.
cfg = HookedTransformerConfig(
    # Architecture
    n_layers=1,
    d_model=128,
    n_heads=4,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type="LN",  # LayerNorm is generally helpful. You can set to None.
    # One sequence position per output point of the curve.
    n_ctx=num_x_points,
    # SineTransformer does not use a vocabulary, but the config object
    # requires these fields, so they are set to a dummy value.
    d_vocab=10,
    d_vocab_out=10,
    # Initialisation / placement
    init_weights=True,
    seed=999,
    device=device,
)

# Build the regression model from the config and move it to the target device.
model = SineTransformer(cfg).to(device)

Disable the biases, as we don't need them for this task and it makes things easier to interpret.

In [None]:
# Freeze every parameter whose name contains "b_" so the optimizer never updates it.
for name, param in model.named_parameters():
    if "b_" not in name:
        continue
    param.requires_grad_(False)


## Define Optimizer + Loss

In [None]:
# AdamW over all model parameters, using the hyperparameters from the config cell.
optimizer = optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=wd,
)

In [None]:
# Mean-squared error between predicted and target curves.
loss_fn = nn.MSELoss()

# Baseline losses of the untrained model. These passes are evaluation-only
# (no backward call follows), so run them without building an autograd graph.
with torch.no_grad():
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)

    test_logits = model(test_data)
    test_loss = loss_fn(test_logits, test_labels)

print(f"Initial Training Loss: {train_loss.item()}")
print(f"Initial Test Loss: {test_loss.item()}")

## Actually Train

**Weird Decision:** Training the model with full batch training rather than stochastic gradient descent. We do this to make training smoother and reduce the number of slingshots.

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# --- Optimizer settings for this run (these replace the earlier lr / wd values) ---
lr, wd = 1e-3, 0.01
betas = (0.9, 0.98)

# Fresh optimizer built with the settings above.
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)

# Cosine schedule: anneals the learning rate from `lr` towards 0 over
# `new_num_epochs` scheduler steps.
new_num_epochs = 20000
scheduler = CosineAnnealingLR(optimizer, T_max=new_num_epochs)

# Loss histories for this run, one entry per epoch.
new_train_losses, new_test_losses = [], []

for epoch in tqdm.tqdm(range(new_num_epochs)):
    # Full-batch forward pass and loss.
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)

    # Gradient step, followed by one scheduler step per epoch.
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    scheduler.step()

    new_train_losses.append(train_loss.item())

    # Held-out loss, computed without autograd tracking.
    with torch.inference_mode():
        test_logits = model(test_data)
        test_loss = loss_fn(test_logits, test_labels)
    new_test_losses.append(test_loss.item())

    if (epoch + 1) % checkpoint_every != 0:
        continue
    # Periodic progress report, including the scheduler's current learning rate.
    current_lr = scheduler.get_last_lr()[0]
    print(f"Epoch {epoch} | Test Loss {test_loss.item():.6f} | LR {current_lr:.6f}")

Save the trained model to Google Drive so we don't have to re-train every time Google Colab reconnects.

In [None]:
import os

# Choose where the checkpoint is written. google.colab only exists inside
# Colab, so importing it unconditionally would crash a local Jupyter run;
# outside Colab the local PTH_LOCATION defined earlier is used instead.
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)

    # Create a folder in Drive if it doesn't exist
    DRIVE_SAVE_DIR = "/content/drive/My Drive/Colab Notebooks/SineModel"
    os.makedirs(DRIVE_SAVE_DIR, exist_ok=True)

    # Full path of the checkpoint file on Drive
    PERSISTENT_PTH_LOCATION = os.path.join(DRIVE_SAVE_DIR, "sine_model_v1.pth")
else:
    PERSISTENT_PTH_LOCATION = PTH_LOCATION

# Persist the weights together with what is needed to reproduce the run:
# the config object, both loss histories, and the train/test split indices.
torch.save(
    {
        "model": model.state_dict(),
        "config": model.cfg,
        "test_losses": new_test_losses,
        "train_losses": new_train_losses,
        "train_indices": train_indices,
        "test_indices": test_indices,
    },
    PERSISTENT_PTH_LOCATION
)

print(f"Model saved successfully to: {PERSISTENT_PTH_LOCATION}")

If already trained, load the model from Google Drive:

In [None]:
import torch

# --- Architecture of the saved model ---
# These must match the model that produced the checkpoint. They now feed the
# config below directly, so the two cannot drift apart (N_LAYERS previously
# said 2 while the config used 1 layer).
D_MODEL = 128
N_HEADS = 4
D_MLP = D_MODEL * 4
N_LAYERS = 1
N_X_POINTS = num_x_points

# Instantiate a new, untrained model with the same architecture
cfg = HookedTransformerConfig(
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_model=D_MODEL,
    d_head=D_MODEL // N_HEADS,
    d_mlp=D_MLP,
    act_fn="relu",
    normalization_type="LN", # LayerNorm is generally helpful. You can set to None.
    n_ctx=N_X_POINTS,
    # d_vocab and d_vocab_out are not used by SineTransformer,
    # but the config object requires them, so they get a dummy value.
    d_vocab=10,
    d_vocab_out=10,
    init_weights=True,
    device=device,
    seed=999,
)

model = SineTransformer(cfg).to(device)

# Locate the checkpoint: Google Drive inside Colab, the local path otherwise.
if IN_COLAB:
    # Mount Drive here as well, so this cell also works in a fresh session
    # where the save cell was not run. Mounting an already-mounted drive is a no-op.
    from google.colab import drive
    drive.mount('/content/drive')
    PERSISTENT_PTH_LOCATION = "/content/drive/My Drive/Colab Notebooks/SineModel/sine_model_v1.pth"
else:
    PERSISTENT_PTH_LOCATION = PTH_LOCATION

# Load the entire dictionary from the file.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# SECURITY NOTE: weights_only=False unpickles arbitrary objects (needed here
# because the checkpoint stores the config object); only load files you trust.
saved_data = torch.load(PERSISTENT_PTH_LOCATION, map_location=device, weights_only=False)

# Load the model's weights from the dictionary
model.load_state_dict(saved_data['model'])

# Switch to evaluation mode for the analysis that follows
model.eval()

print("Model weights loaded successfully!")

# Loss histories stored alongside the weights
test_losses = saved_data['test_losses']
train_losses = saved_data['train_losses']
print(f"Final saved test loss: {test_losses[-1]}")

## Show Model Training Statistics, Check that it groks!

In [None]:
# Install neel-plotly from GitHub and import its `line` plotting helper.
# NOTE: this import rebinds the name `line`, replacing the `line` helper
# defined in the "Plotting helper functions" cell for all later cells.
%pip install git+https://github.com/neelnanda-io/neel-plotly.git
from neel_plotly.plot import line

In [None]:
# Plot every 100th recorded train/test loss on a log-scaled y axis.
# The title previously read "Fibonacci", but this notebook trains on the
# sine-regression task.
line(
    [train_losses[::100], test_losses[::100]],
    x=np.arange(0, len(train_losses), 100),
    xaxis="Epoch",
    yaxis="Loss",
    log_y=True,
    title="Training Curve for Sine Regression",
    line_labels=['train', 'test'],
    toggle_x=True,
    toggle_y=True,
)

# Analysing the Model

Get key weight matrices:

In [None]:
# Walk through the embedding stage by hand for a few example k values.
# Requires the trained `model` and `device` from the cells above.
k_to_inspect = torch.tensor([[1.0], [5.0], [10.0]], device=device)

# Step 1: lift each scalar k into the residual-stream width.
k_embedding = model.k_input_projection(k_to_inspect)
print("Shape of k_embedding for 3 sample k's:", k_embedding.shape)

# Step 2: learned positional embeddings, one row per output point.
positional_embedding = model.pos_embed.W_pos[:num_x_points, :]
print("\nShape of positional_embedding:", positional_embedding.shape) # Expected: [100, 128]

# Step 3: repeat the k embedding across positions and add the positional term.
# The [num_x_points, d_model] matrix broadcasts against the
# [3, num_x_points, d_model] tensor.
k_broadcast = k_embedding[:, None, :].expand(-1, num_x_points, -1)
initial_residual_stream = k_broadcast + positional_embedding

print("\nShape of the full initial embedding:", initial_residual_stream.shape)
print("This is the 'effective embedding' that the first transformer block sees.")


In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt


# --- 1. Evenly spaced k values spanning the training range ---
k_for_pca = torch.linspace(k_min, k_max, 1000, device=device).unsqueeze(-1)

# --- 2. Embed them with the model's input projection (no gradients needed) ---
model.eval()
with torch.no_grad():
    k_embeddings = model.k_input_projection(k_for_pca)

# sklearn works on CPU numpy arrays.
k_embeddings_np = k_embeddings.cpu().numpy()

# --- 3. PCA with all components retained ---
pca = PCA()
pca.fit(k_embeddings_np)

# --- 4. Per-component and cumulative explained variance, side by side ---
explained_variance = pca.explained_variance_ratio_
cumulative_variance = np.cumsum(explained_variance)

variance_fig, (scree_ax, cumulative_ax) = plt.subplots(1, 2, figsize=(10, 5))

scree_ax.bar(range(1, len(explained_variance) + 1), explained_variance)
scree_ax.set_xlabel("Principal Component")
scree_ax.set_ylabel("Explained Variance Ratio")
scree_ax.set_title("Scree Plot")

cumulative_ax.plot(
    range(1, len(cumulative_variance) + 1),
    cumulative_variance,
    marker='o',
    linestyle='--',
)
cumulative_ax.set_xlabel("Number of Components")
cumulative_ax.set_ylabel("Cumulative Explained Variance")
cumulative_ax.set_title("Cumulative Variance Plot")
cumulative_ax.grid(True)
cumulative_ax.set_ylim(0, 1.1)
cumulative_ax.axhline(y=0.99, color='r', linestyle=':', label='99% Variance')
cumulative_ax.legend()

variance_fig.tight_layout()
plt.show()

# --- Interpretation ---
# argmax over the boolean mask gives the first index at or above the
# threshold; +1 turns that index into a component count.
num_components_for_99_variance = np.argmax(cumulative_variance >= 0.99) + 1
print(f"Number of components to explain 99% of variance: {num_components_for_99_variance}")


All of the variance is explained by a single component of the 128-dimensional vector. We now investigate what this component looks like.

In [None]:
# --- 1. Coordinates of every k embedding along the first principal component ---
# pca.transform returns one column per component; column 0 is PC1.
projected_data_pc1 = pca.transform(k_embeddings_np)[:, 0]

# --- 2. The matching k values as a flat CPU numpy array ---
k_values_for_plot = k_for_pca.cpu().numpy().flatten()

# --- 3. PC1 coordinate as a function of k ---
pc1_fig, pc1_ax = plt.subplots(figsize=(8, 6))
pc1_ax.plot(k_values_for_plot, projected_data_pc1)
pc1_ax.set_xlabel("Input k value")
pc1_ax.set_ylabel("Projection onto Principal Component 1")
pc1_ax.set_title("Learned Representation of k (Projected onto PC1)")
pc1_ax.grid(True)
plt.show()


Now, we want to look at the positional embedding: this gives us an idea of what the model thinks 'space' looks like. We do PCA again:

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import torch.nn.functional as F

# --- 1. Get the positional embedding matrix ---
# W_pos holds one d_model-sized row per position, e.g. shape [100, 128].
w_pos = model.pos_embed.W_pos.detach().cpu().numpy()

# --- 2. Perform and Plot PCA ---
pca = PCA()
pca.fit(w_pos)

explained_variance = pca.explained_variance_ratio_

plt.figure(figsize=(6, 4))
# Plot every component PCA returned. The count is taken from the result
# instead of being hard-coded to 100, so the bar positions always match the
# data even if the number of positions or d_model changes.
plt.bar(range(1, len(explained_variance) + 1), explained_variance)
plt.xlabel("Principal Component")
plt.ylabel("Explained Variance Ratio")
plt.title("PCA of Positional Embeddings (W_pos)")
plt.show()

# --- 3. Compute and Plot the Cosine Similarity Matrix ---
# Entry (i, j) is the cosine similarity between the embedding of position i
# and the embedding of position j (computed pairwise via broadcasting).
w_pos_tensor = torch.from_numpy(w_pos)
cosine_sim = F.cosine_similarity(w_pos_tensor.unsqueeze(1), w_pos_tensor.unsqueeze(0), dim=-1)

plt.figure(figsize=(7, 6))
plt.imshow(cosine_sim.numpy(), cmap='viridis')
plt.colorbar(label="Cosine Similarity")
plt.xlabel("Position j")
plt.ylabel("Position i")
plt.title("Cosine Similarity of Positional Embeddings")
plt.show()

Observations from the plots above:

1.   The positional embedding is high-dimensional. There are roughly 30 components that each explain >1% of the variance. The largest one explains only ~12%. The decay in explained variance seems to be roughly exponential.
2.   There is some brightness along the diagonal in the cosine similarity heat map. This indicates that point $i$ is more similar to point $i + 1$ than it is to point $i + 20$. Could be interpreted as the model understanding that this is a continuous curve. However, this effect does not appear super strong.

Let's now look into what the top few principal components actually are.



In [None]:
# --- 1. Coordinates of every position in the PCA basis fitted on W_pos ---
projected_data = pca.transform(w_pos)

# --- 2. x axis: the position index of each row of W_pos ---
positions = np.arange(w_pos.shape[0])

# --- 3. One stacked panel per leading component ---
num_components_to_plot = 3
fig, axes = plt.subplots(num_components_to_plot, 1, figsize=(10, 8), sharex=True)
fig.suptitle("Visualization of the First 3 Principal Components of Positional Embeddings", fontsize=16)

for i, ax in enumerate(axes):
    # Column i holds the projection onto component i + 1.
    projection_values = projected_data[:, i]

    ax.plot(positions, projection_values)
    ax.set_ylabel("Projection Value")
    ax.set_title(f"Principal Component {i + 1}")
    ax.grid(True)

# Only the bottom panel carries the shared x label.
axes[-1].set_xlabel("Position Index (x)")
plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Leave room for the suptitle
plt.show()


The model does not rely on a simple, human-interpretable geometric basis (like Fourier or Polynomial) for its positional embeddings. Instead, it learns a high-dimensional, apparently unstructured mapping. This implies the core computation for synthesizing the sine wave must occur in the transformer's processing blocks.

Let's now investigate one of the transformer processing blocks, namely the attention heads. We'll start by visualizing the attention patterns.

In [None]:
# --- Hook function to cache the attention pattern ---
# NOTE(review): nothing visible in this notebook registers this hook; the code
# that follows reads the pattern from the cache returned by run_with_cache.
def cache_attention_pattern(pattern, hook):
    """Store `pattern` in the hook's `ctx` dict under the key 'pattern'.

    Returns None.
    """
    hook.ctx['pattern'] = pattern

# --- Visualization code ---
k_to_visualize = torch.tensor([[2.5]], device=device)  # Example k
layer_to_inspect = 0  # First layer - There's only one layer
head_to_inspect = 0   # First head

pattern_hook_name = f"blocks.{layer_to_inspect}.attn.hook_pattern"

# Forward pass that caches only the attention-pattern activation.
model.eval()
with torch.no_grad():
    logits, cache = model.run_with_cache(
        k_to_visualize,
        names_filter=pattern_hook_name,
    )

# Select batch element 0 and the chosen head, then move the pattern to the CPU.
attention_pattern = cache[pattern_hook_name][0, head_to_inspect].cpu()

# Heatmap: rows are query positions, columns are key positions.
attn_fig, attn_ax = plt.subplots(figsize=(7, 6))
attn_ax.imshow(attention_pattern, cmap='viridis')
attn_ax.set_title(f"Attention Pattern for L{layer_to_inspect}H{head_to_inspect} with k={k_to_visualize.item():.1f}")
attn_ax.set_xlabel("Key Position (j)")
attn_ax.set_ylabel("Query Position (i)")
plt.show()


One observation we make is that at least for k=3.0, none of the heads pay any attention to keys above 50. This could be because the heads have correctly figured out that all information required to approximate the function is contained in the first half of the domain.

In [None]:
layer_to_patch = 0
hook_name = f"blocks.{layer_to_patch}.hook_attn_out"

# k for the run that supplies the activation (odd) and k for the run that
# receives it (even).
k_source, k_dest = 3.0, 4.0

# --- Step 2: Run the source forward pass and cache the activation ---
model.eval()
with torch.no_grad():
    # Only the cache is of interest here; the model output is discarded.
    _, source_cache = model.run_with_cache(
        torch.tensor([[k_source]], device=device),
        names_filter=hook_name,
    )
    # Attention output of the source run, to be patched into the destination run.
    activation_to_patch = source_cache[hook_name]

print(f"Cached activation from k={k_source} with shape: {activation_to_patch.shape}")


# --- Step 3: Define a patching hook ---
def patch_activation_hook(activation, hook):
    """Replace `activation` with the cached source activation.

    The previous version returned `activation_to_patch[0]`, which drops the
    batch dimension and only worked because the following residual addition
    broadcast it back. This version returns a tensor with exactly the shape
    of the activation it replaces.

    Args:
        activation: The activation produced at the hook point in the current run.
        hook: The hook point (unused).

    Returns:
        The first cached source example, expanded along the batch dimension
        to match `activation`.
    """
    # [:1] keeps the batch axis (same example the old [0] selected);
    # expand_as repeats it across the destination batch without copying.
    return activation_to_patch[:1].expand_as(activation)


# --- Step 4: Run the destination pass with the patching hook ---
model.eval()
with torch.no_grad():
    # fwd_hooks are attached for this call only.
    patched_output = model.run_with_hooks(
        torch.tensor([[k_dest]], device=device),
        fwd_hooks=[(hook_name, patch_activation_hook)],
    )

print(f"Ran model with k={k_dest} but patched in attention output from k={k_source}")


# --- Step 5: Visualize the results ---
# Unpatched model outputs for the source and destination k, for comparison.
with torch.no_grad():
    clean_output_k3 = model(torch.tensor([[k_source]], device=device))
    clean_output_k4 = model(torch.tensor([[k_dest]], device=device))

patch_fig, (source_ax, dest_ax, patched_ax) = plt.subplots(3, 1, figsize=(12, 8))

# Panel 1: unpatched output for the source k.
source_ax.plot(clean_output_k3[0].cpu().numpy(), label="Clean output for k=3.0 (Ground Truth)")
source_ax.set_title("Expected Anti-Symmetry")

# Panel 2: unpatched output for the destination k.
dest_ax.plot(clean_output_k4[0].cpu().numpy(), label="Clean output for k=4.0", color='green')
dest_ax.set_title("Expected Symmetry")

# Panel 3: destination run with the source attention output patched in.
patched_ax.plot(patched_output[0].cpu().numpy(), label="Patched Output", color='purple')
patched_ax.set_title("Patched Result: sin(3x) structure + sin(4x) symmetry rule")
patched_ax.set_xlabel("Position (x)")

# Shared decorations: dashed marker at position 50, legend and grid.
for panel in (source_ax, dest_ax, patched_ax):
    panel.axvline(x=50, color='r', linestyle='--', label='x = pi')
    panel.legend()
    panel.grid(True)

patch_fig.tight_layout()
plt.show()


Let's look at only the output of the MLP, ignoring the skip connection. Do we lose much accuracy?

In [None]:
# --- 2. Hook point names and the dictionary shared between the two hooks ---
layer_to_ablate = 0
hook_point_mlp_out, hook_point_resid_post = (
    f"blocks.{layer_to_ablate}.{suffix}"
    for suffix in ("hook_mlp_out", "hook_resid_post")
)

# Scratch space: the capture hook writes here, the replace hook reads from it.
captured_activations_for_run = dict()

# --- 3. Define Hook Functions (from your code) ---

def capture_mlp_out_hook(activation, hook):
    """Record the MLP output in the shared dictionary and pass it through unchanged.

    Args:
        activation: The activation at the MLP-output hook point.
        hook: The hook point (unused).

    Returns:
        `activation` itself, so the forward pass is unaffected at this point.
    """
    captured_activations_for_run.update(mlp_out=activation)
    return activation

def replace_resid_post_hook(activation, hook):
    """Swap the post-block residual stream for the previously captured MLP output.

    Args:
        activation: The resid_post activation, i.e. (resid_pre + attn_out) + mlp_out.
        hook: The hook point; its name is used in the warning message.

    Returns:
        The captured MLP output, or `activation` unchanged (after printing a
        warning) when nothing has been captured yet.
    """
    try:
        mlp_out_value = captured_activations_for_run['mlp_out']
    except KeyError:
        # Safety fallback: the capture hook has not run, so leave the stream alone.
        print(f"Warning: Capture hook did not run before modify hook '{hook.name}'")
        return activation

    # Everything downstream now sees only mlp_out instead of the full residual.
    return mlp_out_value

# --- 4. Run the Ablation ---
print("Running model, ablating to only the MLP output...")

# Start from an empty scratch dictionary so a stale capture cannot be reused.
captured_activations_for_run.clear()

# Run with BOTH hooks active: the first records mlp_out, the second replaces
# resid_post with it. The run is wrapped in no_grad because the result is
# converted to numpy further down; without it the output would require grad
# and `.numpy()` would raise a RuntimeError.
with torch.no_grad():
    ablated_predictions = model.run_with_hooks(
        test_data,
        fwd_hooks=[
            (hook_point_mlp_out, capture_mlp_out_hook),
            (hook_point_resid_post, replace_resid_post_hook)
        ]
    )
print("Model run with hooks completed.")


# --- 5. Calculate and Compare Loss ---
loss_fn = nn.MSELoss()

with torch.no_grad():
    original_predictions = model(test_data)
    original_loss = loss_fn(original_predictions, test_labels)
    ablated_loss = loss_fn(ablated_predictions, test_labels)

print("\n--- RESULTS ---")
print(f"Original Model Test Loss:          {original_loss.item():.6f}")
print(f"Ablated Model Loss (MLP Out Only): {ablated_loss.item():.6f}")


# --- 6. Visualize the Difference ---
# Compare the three curves for a single test example.
idx_to_plot = 0
original_curve = original_predictions[idx_to_plot].cpu().numpy()
ablated_curve = ablated_predictions[idx_to_plot].cpu().numpy()
ground_truth_curve = test_labels[idx_to_plot].cpu().numpy()

plt.figure(figsize=(12, 6))
plt.plot(ground_truth_curve, label="Ground Truth", color='black', linestyle='--')
plt.plot(original_curve, label=f"Original Model Output (Loss: {original_loss:.4f})", color='blue')
plt.plot(ablated_curve, label=f"Ablated Model Output (Loss: {ablated_loss:.4f})", color='red', alpha=0.7)
plt.title("Effect of Ablating to MLP Output Only (Your Hooking Method)")
plt.xlabel("Position (x)")
plt.ylabel("Model Output")
plt.legend()
plt.grid(True)
plt.show()

# Drop the reference to the captured activation so it does not outlive this cell.
captured_activations_for_run.clear()
