<a href="https://colab.research.google.com/github/Yshen-11/DUKE_XAI/blob/main/Assignment9_No_Position_Experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# No Position Experiment


---
# Introduction

This is the accompanying notebook to my [real-time research](https://www.youtube.com/watch?v=yo4QvDn-vsU) video. It trains a model with no positional embeddings to predict the previous token, and makes a start at analysing what's going on there!

EDIT: The loss spikes were due to the learning-rate multiplier being `max(step/100, 1.0)` instead of `min(step/100, 1.0)`! Thanks to MadHatter for catching that.
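To see why the `min`/`max` mix-up matters: `torch.optim.lr_scheduler.LambdaLR` multiplies the base learning rate by whatever the lambda returns at each step. Here is a minimal sketch of the two multipliers (the function names are hypothetical), using the notebook's base learning rate of `1e-4` and 100 warmup steps.

```python
def fixed_warmup(step: int, warmup_steps: int = 100) -> float:
    """Corrected multiplier: ramps linearly from 0 to 1, then stays at 1."""
    return min(step / warmup_steps, 1.0)


def buggy_warmup(step: int, warmup_steps: int = 100) -> float:
    """Buggy multiplier: stuck at 1 during warmup, then grows without bound."""
    return max(step / warmup_steps, 1.0)


base_lr = 1e-4
for step in [0, 50, 100, 1000, 4000]:
    print(f"step {step:>4}: fixed lr = {base_lr * fixed_warmup(step):.1e}, "
          f"buggy lr = {base_lr * buggy_warmup(step):.1e}")
```

With `max`, the effective learning rate at step 4000 is 40 times the intended peak of `1e-4`.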


# Step 1: Setup
Here, we set up the environment and import the necessary libraries for building and analyzing the Transformer model.


In [1]:
# NBVAL_IGNORE_OUTPUT
import os

# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab

    IN_COLAB = True
    print("Running as a Colab notebook")
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")

if IN_COLAB or IN_GITHUB:
    %pip install einops
    %pip install transformer_lens==1.15.0

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git

from transformer_lens import HookedTransformer, HookedTransformerConfig
import torch
import numpy as np
import plotly.express as px
import plotly.io as pio

pio.renderers.default = "colab"
import tqdm.auto as tqdm
import einops
from transformer_lens.utils import to_numpy

device = "cuda" if torch.cuda.is_available() else "cpu"

Running as a Colab notebook

Some plotting code. Wrappers around Plotly, not important to understand.

In [2]:
def line(tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    labels = {"y": yaxis, "x": xaxis}
    fig = px.line(tensor, labels=labels, **kwargs)
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label
    fig.show()


def imshow(tensor, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    plot_kwargs = {
        "color_continuous_scale": "RdBu",
        "color_continuous_midpoint": 0.0,
        "labels": {"x": xaxis, "y": yaxis},
    }
    plot_kwargs.update(kwargs)
    px.imshow(tensor, **plot_kwargs).show()

# Step 2: Model Initialization
Next, we define the Transformer model architecture and explicitly deactivate positional embeddings.

## Defining the Transformer Model
We initialize a simple Transformer model with two layers and a single attention head. The model uses the following configuration:

- **Number of Layers**: 2
- **Attention Heads**: 1 per layer (head dimension 64)
- **Model Dimension**: 64
- **MLP Hidden Dimension**: 256
- **Vocabulary Size**: 300
- **Context Length**: 50 (maximum sequence length)
- **Activation Function**: ReLU
- **Normalization**: Layer Normalization (LN)

We then deactivate the positional embeddings by zeroing `W_pos` and disabling its gradient, so the model receives no explicit position information.
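A quick way to see what zeroing `W_pos` does: the residual stream entering the first block is `W_E[token] + W_pos[position]`, so once `W_pos` is all zeros, two positions holding the same token start out with identical vectors. The sketch below uses random NumPy matrices as stand-ins for the trained weights and assumes the shapes from the configuration above.

```python
import numpy as np

rng = np.random.default_rng(0)
d_vocab, n_ctx, d_model = 300, 50, 64

# Stand-ins for the learned matrices; W_pos is zeroed as deactivate_position does
W_E = rng.normal(size=(d_vocab, d_model))
W_pos = np.zeros((n_ctx, d_model))

tokens = np.array([0, 7, 42, 7, 7])             # BOS followed by a repeated token
resid_pre = W_E[tokens] + W_pos[: len(tokens)]  # [pos, d_model]

# Positions 1, 3 and 4 hold token 7, so their inputs are indistinguishable
print(np.allclose(resid_pre[1], resid_pre[3]), np.allclose(resid_pre[1], resid_pre[4]))
```

Any positional information the trained model uses therefore has to be computed inside the network rather than read from the embeddings.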


In [3]:
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=64,
    n_heads=1,
    d_mlp=256,
    d_vocab=300,
    n_ctx=50,
    act_fn="relu",
    normalization_type="LN",
    device=device,
)
model = HookedTransformer(cfg)

def deactivate_position(model):
    # Zero out the positional embeddings and disable gradient updates
    model.pos_embed.W_pos.data[:] = 0.0
    model.pos_embed.W_pos.requires_grad = False


deactivate_position(model)

### Validation:
- Printing the model shows its module structure (two transformer blocks, each with two LayerNorms, an attention layer and an MLP); the dimensions themselves are stored in `model.cfg`.
- The positional embedding matrix (`W_pos`) should contain only zeros, confirming it has been successfully deactivated.


In [4]:
# Print the model to verify configuration
print(model)

# Check positional embeddings
print("Positional embeddings:", model.pos_embed.W_pos)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

# Step 3: Data Generation

##  Data Generator
The data generator creates batches of random token sequences based on the following specifications:

- **Vocabulary Size**: 300 tokens.
- **Sequence Length**: 50 tokens (as defined in the model configuration).
- **Batch Size**: Number of sequences generated per iteration.
- **Special Token**: The first token in each sequence is set to `0`, representing a beginning-of-sequence (BOS) token.

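The generator calls `torch.manual_seed` with a fixed seed (123 by default), so every new generator reproduces the same sequence of batches. Note that this also resets PyTorch's global random state.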


- **Random Token Sequences**:
  - Tokens are sampled uniformly from 1 to `d_vocab - 1` (1 to 299 here), since the upper bound of `torch.randint` is exclusive.
  - The `BOS` token (value `0`) marks the start of each sequence and never appears elsewhere.
- **Generator**:
  - The generator is an infinite loop (`while True`) that yields a new batch each time `next()` is called on it.
- **Example Output**:
  - Prints a batch of random token sequences to validate the generator's functionality.
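The bound on the sampled tokens is easy to get wrong: like Python's `range`, `torch.randint(low, high, ...)` excludes `high`. Below is a hypothetical NumPy port of the generator (`make_numpy_data_generator` is not part of the notebook); NumPy's `Generator.integers` has the same exclusive upper bound by default.

```python
import numpy as np


def make_numpy_data_generator(d_vocab, n_ctx, batch_size, seed=123, incl_bos_token=True):
    """Hypothetical NumPy port of make_data_generator, for illustration only."""
    rng = np.random.default_rng(seed)  # seeded generator: same batches on every run
    while True:
        # integers(1, d_vocab) draws from 1..d_vocab-1; the upper bound is exclusive
        x = rng.integers(1, d_vocab, size=(batch_size, n_ctx))
        if incl_bos_token:
            x[:, 0] = 0  # reserve token 0 as the BOS token
        yield x


gen = make_numpy_data_generator(d_vocab=300, n_ctx=50, batch_size=256)
batch = next(gen)
print(batch.shape, batch[:, 1:].min(), batch[:, 1:].max())
```

NumPy's random stream differs from PyTorch's, so these tokens will not match the tensor printed by the cell below; only the shape and value ranges carry over.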

In [5]:
def make_data_generator(cfg, batch_size, seed=123, incl_bos_token=True):
    torch.manual_seed(seed)
    while True:
        # Generate random tokens between 1 and vocabulary size
        x = torch.randint(1, cfg.d_vocab, (batch_size, cfg.n_ctx))
        if incl_bos_token:
            # Set the first token to the BOS token
            x[:, 0] = 0
        yield x

# Create a data generator with batch size of 2
data_generator = make_data_generator(cfg, batch_size=2)
# Generate a single batch and print it
print(next(data_generator))


tensor([[  0,  93,  34, 155, 274, 116, 114, 248,  68,   3, 298,  83, 194,  20,
           8, 133,  32,  66,  62,  73, 210, 273,  46, 243, 104, 232, 161, 125,
         123, 251,   7,   4, 115, 127,  21,   1,  89, 142,   6,  15, 298, 251,
          88, 229, 108, 114,  23,  88,   3, 265],
        [  0, 118,  46, 274, 105, 268, 131,  35,  19,  58, 226, 278,  27,  25,
         276, 180, 164,   4,  95,  27,  74, 201, 105,  65,  80, 185,  44, 258,
         105,  60,  58,  47, 126,  60, 294, 253, 258, 136,  29, 101, 258,  77,
          80, 180, 159, 169, 122, 117,  27, 194]])



## Defining the Loss Function
To train the model, we need a loss function that measures how well it predicts the **previous** token in the sequence (rather than the next token, as in ordinary language modeling).

The loss function computes the negative log probability of the correct previous token under the model's predictions. It operates as follows:
1. **Shift Tokens**:
   - For predictions (`logits`), position 0 is dropped, because there is no earlier token for it to predict.
   - For true tokens, the last token is dropped, because no later position is asked to predict it.
   - After the shift, the prediction at position `i` is scored against the token at position `i - 1`.

2. **Log Probabilities**:
   - Converts model predictions to log probabilities using `log_softmax`.

3. **Correct Token Probabilities**:
   - Gathers the log probabilities corresponding to the true tokens.

4. **Output**:
   - If `per_token=True`, returns the loss for each token.
   - Otherwise, computes the mean loss across all tokens.
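The alignment is the subtle part of this loss. A minimal NumPy sketch of the same computation (the function name `previous_token_loss` is hypothetical) makes the pairing between positions and targets explicit.

```python
import numpy as np


def previous_token_loss(logits, tokens, per_token=False):
    """NumPy sketch of loss_fn: the output at position i is scored on token i - 1."""
    logits = logits[:, 1:]    # drop position 0: there is no earlier token to predict
    targets = tokens[:, :-1]  # drop the last token: no later position predicts it
    # Numerically stable log_softmax over the vocabulary axis
    shifted = logits - logits.max(-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(-1, keepdims=True))
    # Pick out the log probability of each target token
    correct = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return -correct if per_token else -correct.mean()


tokens = np.arange(5)[None, :]                     # [[0, 1, 2, 3, 4]]
logits = np.zeros((1, 5, 10))
logits[0, np.arange(1, 5), tokens[0, :-1]] = 10.0  # position i puts its mass on token i - 1
print(previous_token_loss(logits, tokens, per_token=True))
```

If the logits instead favoured the token at the *same* position (a next-token-style alignment), every entry of the loss would be large, which is a quick way to catch an off-by-one in the shift.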


In [6]:
def loss_fn(logits, tokens, per_token=False):
    # logit shape: [batch, pos, vocab]
    # token shape: [batch, pos]
    logits = logits[:, 1:]
    tokens = tokens[:, :-1]

    # Compute log probabilities
    log_probs = logits.log_softmax(-1)

    # Gather log probabilities of correct tokens
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]

    # Return token-level or batch-level loss
    if per_token:
        return -correct_log_probs
    else:
        return -correct_log_probs.mean()

# Test the loss function with example logits and tokens
test_tokens = torch.arange(5)[None, :]
test_logits = torch.randn(1, 5, 10)
test_logits[:, 1, 0] = 10.0
test_logits[:, 2, 1] = 10.0
test_logits[:, 3, 2] = 10.0
test_logits[:, 4, 3] = 10.0
print("Token-level Loss:", loss_fn(test_logits, test_tokens, per_token=True))
print("Batch-level Loss:", loss_fn(test_logits, test_tokens, per_token=False))

Token-level Loss: tensor([[0.0004, 0.0003, 0.0031, 0.0005]])
Batch-level Loss: tensor(0.0011)


# Step 4: Training the Model
In this section, we train the Transformer model without positional embeddings. The key steps are as follows:

1. **Optimizer Setup**:
   - We use the AdamW optimizer with weight decay for better generalization.
   - A learning rate scheduler gradually increases the learning rate at the start of training.

2. **Training Loop**:
   - For each batch of token sequences:
     - Compute the model's predictions (`logits`).
     - Calculate the loss using the defined loss function.
     - Backpropagate the loss to compute gradients.
     - Clip gradients to avoid exploding gradients.
     - Update model parameters using the optimizer.
   - Track the loss for analysis and visualization.

3. **Loss Visualization**:
   - Plot the loss curve to monitor training progress.
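For reference, here is a sketch of a single AdamW update for one parameter array, following the decoupled-weight-decay formulation documented for `torch.optim.AdamW`: the weight is shrunk by `lr * weight_decay` directly, separately from the Adam step. The function is hypothetical; it uses the hyperparameters from the training cell (`lr=1e-4`, `betas=(0.9, 0.95)`, `weight_decay=0.1`) and assumes PyTorch's default `eps=1e-8`.

```python
import numpy as np


def adamw_step(param, grad, m, v, t, lr=1e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1):
    """One AdamW update (t counts from 1). Returns the new (param, m, v)."""
    beta1, beta2 = betas
    param = param * (1 - lr * weight_decay)   # decoupled weight decay
    m = beta1 * m + (1 - beta1) * grad        # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad**2     # second-moment estimate
    m_hat = m / (1 - beta1**t)                # bias corrections for the zero init
    v_hat = v / (1 - beta2**t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = np.array([1.0, 0.0]), np.zeros(2), np.zeros(2)
p, m, v = adamw_step(p, grad=np.array([0.0, 2.0]), m=m, v=v, t=1)
print(p)
```

On the first step the Adam part moves a weight by roughly `lr` against its gradient, whatever the gradient's scale, while a weight with zero gradient is only shrunk by the weight-decay term.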


## Optimizer and Training Loop
- **Batch Size**: 256 sequences per batch (passed to `make_data_generator`).
- **Training Steps**: 4000 iterations. The code calls them `num_epochs`, but each iteration draws a new random batch, so no data is ever revisited.
- **Optimizer**:
  - **AdamW**: Adam with decoupled weight decay (`lr=1e-4`, `betas=(0.9, 0.95)`, `weight_decay=0.1`).
  - **Learning Rate Scheduler**: Linearly warms up the learning rate from 0 to `lr` over the first 100 iterations (`min(i / 100, 1.0)`), then holds it constant.

- **Training Loop**:
  - Each iteration involves:
    - **Forward Pass**: Compute logits (predictions) for the input tokens.
    - **Loss Computation**: Measure how well the predictions match the target (previous) tokens.
    - **Backward Pass**: Compute gradients of the loss with respect to model parameters.
    - **Gradient Clipping**: Rescale the gradients so their combined norm is at most `max_grad_norm = 1.0`.
    - **Parameter Update**: Adjust model parameters using the optimizer, then advance the scheduler.

- **Loss Tracking**:
  - The loss at each iteration is stored in the `losses` list.
  - The loss curve is plotted to visualize training progress.
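The clipping step uses the *global* gradient norm: `torch.nn.utils.clip_grad_norm_` treats all parameter gradients as one long vector and, if its L2 norm exceeds `max_norm`, rescales every gradient by the same factor. A NumPy sketch of that rule (the helper name is hypothetical; the small `1e-6` term mirrors the one PyTorch adds for numerical safety):

```python
import numpy as np


def clip_grad_norm(grads, max_norm):
    """Scale a list of gradient arrays in place so their global L2 norm is <= max_norm."""
    total_norm = np.sqrt(sum(float((g**2).sum()) for g in grads))
    clip_coef = min(max_norm / (total_norm + 1e-6), 1.0)  # never scale gradients up
    for g in grads:
        g *= clip_coef
    return total_norm  # like PyTorch, report the norm *before* clipping


grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]  # global norm = 5
norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum((g**2).sum() for g in grads))
print(norm_before, norm_after)
```

Because every gradient is multiplied by the same factor, clipping preserves the direction of the overall update and only limits its size.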

In [7]:
# Training parameters
batch_size = 256
num_epochs = 4000
lr = 1e-4
betas = (0.9, 0.95)
max_grad_norm = 1.0
wd = 0.1

# Define optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))


## Model Training

In [8]:
# Training loop
losses = []
data_loader = make_data_generator(cfg, batch_size)
for epoch in tqdm.tqdm(range(num_epochs)):
    # Generate a batch of tokens
    tokens = next(data_loader).to(device)

    # Forward pass
    logits = model(tokens)

    # Compute loss
    loss = loss_fn(logits, tokens)

    # Backpropagation
    loss.backward()

    # Gradient clipping
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    # Optimizer step and scheduler step
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

    # Track loss
    losses.append(loss.item())

    # Print loss every 100 epochs
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item()}")

# Plot the loss curve
px.line(losses, labels={"x": "Epoch", "y": "Loss"}).show()


Epoch 0: Loss = 6.007701873779297
Epoch 100: Loss = 5.755710124969482
Epoch 200: Loss = 5.565812110900879
Epoch 300: Loss = 5.4445576667785645
Epoch 400: Loss = 5.315097332000732
Epoch 500: Loss = 5.179347991943359
Epoch 600: Loss = 4.983320713043213
Epoch 700: Loss = 4.717402458190918
Epoch 800: Loss = 4.423213958740234
Epoch 900: Loss = 4.063094615936279
Epoch 1000: Loss = 3.670921802520752
Epoch 1100: Loss = 3.2450222969055176
Epoch 1200: Loss = 2.8303120136260986
Epoch 1300: Loss = 2.4301490783691406
Epoch 1400: Loss = 2.0606563091278076
Epoch 1500: Loss = 1.7411576509475708
Epoch 1600: Loss = 1.432838797569275
Epoch 1700: Loss = 1.166140079498291
Epoch 1800: Loss = 0.9399619698524475
Epoch 1900: Loss = 0.7393752932548523
Epoch 2000: Loss = 0.5784555673599243
Epoch 2100: Loss = 0.4442668557167053
Epoch 2200: Loss = 0.33904004096984863
Epoch 2300: Loss = 0.26044589281082153
Epoch 2400: Loss = 0.1964339017868042
Epoch 2500: Loss = 0.14744994044303894
Epoch 2600: Loss = 0.108612053096

In [9]:
# torch.save(model.state_dict(), "no_pos_experiment_state_dict_v0.pth")

# Step 5: Analyzing Results
This section focuses on understanding the trained model's internal behavior. We aim to explore the following:

1. **Attention Patterns**:
   - Visualize the attention weights across different layers to understand how the model processes input sequences.

2. **Component Contributions**:
   - Evaluate how specific components of the model (e.g., embeddings, attention layers) contribute to the final logits (predictions).

3. **Interpretability Tools**:
   - Use cached activations to analyze and manipulate model outputs.


### Verifying Positional Embeddings
In this experiment, we explicitly deactivate the positional embeddings in the model. To confirm this, we compute the norm of the positional embedding matrix `W_pos` (`Tensor.norm()` with no arguments returns the Frobenius norm, i.e. the L2 norm over all entries).

- **Expected Output**: The norm should be zero (`tensor(0.)`), indicating that all positional embeddings have been set to zero and are not being updated during training.

This step ensures the experiment is correctly focused on a model without positional encoding.


In [10]:
# Verify positional embeddings
model.pos_embed.W_pos.norm()

tensor(0.)

## Look at attention patterns

- **Attention Patterns**:
  - **Query and Key**: For each query position, the attention head computes a softmax-normalized weighting over the key positions at or before it, based on the dot products of their queries and keys.
  - **Mean Attention**: By averaging over the batch and the (single) attention head, we visualize the typical attention pattern for each layer.

- **Visualization**:
  - `Layer 0` and `Layer 1` attention patterns are displayed as heatmaps.
  - The x-axis is the key position and the y-axis is the query position.
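The averaging is a mean over the first two axes of the cached pattern, which has shape `[batch, head, query_pos, key_pos]`. As a sanity check on what the heatmaps can show, the sketch below builds random causal attention patterns in NumPy (random scores stand in for the real ones) and confirms that the averaged pattern is still lower-triangular with rows summing to 1.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_heads, n_ctx = 8, 1, 50

# Random scores with a causal mask: queries may only attend to keys at or before them
scores = rng.normal(size=(batch, n_heads, n_ctx, n_ctx))
causal_mask = np.tril(np.ones((n_ctx, n_ctx), dtype=bool))
scores = np.where(causal_mask, scores, -np.inf)
pattern = np.exp(scores - scores.max(-1, keepdims=True))
pattern /= pattern.sum(-1, keepdims=True)  # softmax over key positions

mean_pattern = pattern.mean(axis=(0, 1))   # average over batch and heads -> [query, key]
print(mean_pattern.shape)
```

Row 0 is always exactly 1 at key 0, since the first query has only one position to attend to.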

In [11]:
# Generate a large batch of tokens for analysis
big_data_loader = make_data_generator(cfg, 4000)
big_tokens = next(big_data_loader).to(device)

# Forward pass with caching of activations
logits, cache = model.run_with_cache(big_tokens)

print("Loss:", loss_fn(logits, big_tokens).item())

Loss: 0.0050029149278998375


In [12]:
print(cache)

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_sc

In [13]:
cache["blocks.0.attn.hook_pattern"].shape

torch.Size([4000, 1, 50, 50])

In [14]:
batch_index = 0
tokens = big_tokens[batch_index]

# Visualize attention patterns for Layer 0
imshow(
    to_numpy(cache["attn", 0].mean([0, 1])),
    title="Layer 0 Attention Pattern",
    yaxis="Query",
    xaxis="Key",
    height=500,
    width=500,
)
imshow(
    to_numpy(cache["attn", 1].mean([0, 1])),
    title="Layer 1 Attention Pattern",
    yaxis="Query",
    xaxis="Key",
    height=500,
    width=500,
)

## Look at how different bits of the model directly contribute to the logits
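- **Residual Components**:
  - Tracks how each part of the model (the embedding, and each attention and MLP layer) writes to the residual stream, and how much of that directly reaches the logits.

- **Centering**:
  - Each residual component is centered by subtracting its mean across the `d_model` dimension, mirroring the centering step of the final LayerNorm.

- **Logit Contributions**:
  - The final LayerNorm's weight `ln_final.w` is folded into the unembedding `W_U` (giving `fold_W_U`), and each centered component is divided by the final LayerNorm scale cached for that position, so the components' contributions add up to the final logits (apart from bias terms).

- **Visualization**:
  - For one example sequence, plots each component's direct contribution to the logit of the correct (previous) token at every position, providing insight into their relative importance.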
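This decomposition works because LayerNorm is linear once its scale is fixed: centering distributes over a sum of components, and dividing by the cached scale is a per-position constant. The per-component contributions therefore add up exactly to the final logits, apart from the LayerNorm bias and unembedding bias, which do not depend on any component. A NumPy check of that identity with random stand-in weights (all names below are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
n_components, d_model, d_vocab, eps = 5, 64, 300, 1e-5

components = rng.normal(size=(n_components, d_model))      # e.g. embed, attn_0, mlp_0, ...
w, b = rng.normal(size=d_model), rng.normal(size=d_model)  # final LayerNorm parameters
W_U, b_U = rng.normal(size=(d_model, d_vocab)), rng.normal(size=d_vocab)

# Full forward pass through the final LayerNorm and the unembedding
resid = components.sum(0)
centered = resid - resid.mean()
scale = np.sqrt((centered**2).mean() + eps)
logits = (centered / scale * w + b) @ W_U + b_U

# Per-component direct contributions, all divided by the same scale
fold_W_U = w[:, None] * W_U
centered_components = components - components.mean(-1, keepdims=True)
logit_components = centered_components @ fold_W_U / scale  # [n_components, d_vocab]

# The contributions sum to the logits, minus the component-independent bias terms
print(np.allclose(logit_components.sum(0), logits - b @ W_U - b_U))
```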

In [15]:
# Extract key components from the cache
resid_components = [
    cache["embed"],         # Embedding layer
    cache["attn_out", 0],   # Attention output from Layer 0
    cache["mlp_out", 0],    # MLP output from Layer 0
    cache["attn_out", 1],   # Attention output from Layer 1
    cache["mlp_out", 1],    # MLP output from Layer 1
]
# Component labels for visualization
labels = ["Embedding", "Attention Layer 0", "MLP Layer 0", "Attention Layer 1", "MLP Layer 1"]

# Stack residual components for analysis
resid_stack = torch.stack(resid_components, 0)

# Center each component across d_model, mirroring the centering step of LayerNorm
resid_stack = resid_stack - resid_stack.mean(-1, keepdim=True)

print(resid_stack.shape)

torch.Size([5, 4000, 50, 64])


In [16]:
# Compute contributions to logits
fold_W_U = model.ln_final.w[:, None] * model.unembed.W_U
logit_components = resid_stack[:, batch_index] @ fold_W_U / cache["scale"][batch_index]
print(logit_components.shape)

torch.Size([5, 50, 300])


In [17]:
# Visualize contributions
logit_components = logit_components - logit_components.mean(-1, keepdim=True)
line(
    logit_components[:, torch.arange(1, model.cfg.n_ctx).to(device), tokens[:-1]].T,
    line_labels=labels,
)

## Folding in LayerNorm

This section folds LayerNorm's learned parameters into the surrounding weights, which simplifies the subsequent analysis and improves interpretability.

**Key Steps:**
1. **Pre-Normalization Configuration**:
   - A new Transformer configuration is created with `normalization_type="LNPre"`, a LayerNorm that only centers and rescales and has no learned weight or bias; those parameters are folded into the following layers instead.
2. **Weight Processing**:
   - The state dictionary of the original model is loaded into the new analysis model.
   - `fold_ln=True` folds each LayerNorm's weight and bias into the layer that reads from it (attention, MLP, or unembedding).
   - `center_writing_weights=True` centers the weights that write to the residual stream, and `center_unembed=True` centers the unembedding. Neither changes the model's predictions, because the next LayerNorm (or the softmax) discards those means anyway, but both make the weights easier to interpret.
3. **Positional Embeddings Deactivation**:
   - Positional embeddings are zeroed and frozen in the new analysis model, for consistency with the original model.
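The identity behind folding can be checked directly: a LayerNorm with weight `w` and bias `b`, followed by a linear map `W` with bias `b_W`, equals a parameter-free LayerNorm (the `LNPre` form) followed by a map with weight `w[:, None] * W` and bias `b @ W + b_W`. A NumPy sketch with random stand-in weights:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_out, eps = 64, 256, 1e-5
x = rng.normal(size=(10, d_model))                                   # 10 residual-stream vectors
w, b = rng.normal(size=d_model), rng.normal(size=d_model)            # LayerNorm weight and bias
W, b_W = rng.normal(size=(d_model, d_out)), rng.normal(size=d_out)   # next linear layer


def ln_pre(x):
    """LayerNorm without learned parameters: center, then divide by the RMS."""
    x = x - x.mean(-1, keepdims=True)
    return x / np.sqrt((x**2).mean(-1, keepdims=True) + eps)


original = (ln_pre(x) * w + b) @ W + b_W               # LN with weight/bias, then W
folded = ln_pre(x) @ (w[:, None] * W) + (b @ W + b_W)  # LNPre, then folded weight and bias
print(np.allclose(original, folded))
```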

In [18]:
# Define a new configuration for the analysis model
analysis_cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=64,
    n_heads=1,
    d_mlp=256,
    d_vocab=300,
    n_ctx=50,
    act_fn="relu",
    normalization_type="LNPre",  # LayerNorm without learned weight/bias (they are folded)
    init_weights=False,  # Skip random initialization; weights are loaded below
)
# Initialize a new model for analysis
analysis_model = HookedTransformer(analysis_cfg)

# Load the state dictionary from the original model and fold LayerNorm
state_dict = model.state_dict()
analysis_model.load_and_process_state_dict(
    state_dict,
    fold_ln=True, # Fold LayerNorm into weights
    center_writing_weights=True, # Center weights that write to the residual stream
    center_unembed=True # Center the unembedding weights
)

# Deactivate positional embeddings for consistency
deactivate_position(analysis_model)

In [19]:
# Output the folded model structure (for validation)
print(analysis_model)

# Check if positional embeddings remain deactivated
print("Positional embeddings norm:", analysis_model.pos_embed.W_pos.norm())

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hoo

In [20]:
# analysis_model()

# Step 6: Understanding Model Behavior

## Understand Attention in Layer 0

In Layer 0, the attention head reads only the (LayerNormed) token embeddings, since the positional embeddings are zero, and moves information between positions using the **Query-Key-Value** mechanism.

**Key Points**:
1. **Query-Key Scores**:
   - The model computes a score between the current position's query and the key at each position up to and including it, as a scaled dot product.
   - Higher scores indicate stronger relevance.

2. **Attention Weights**:
   - These scores are converted into weights via a softmax operation, emphasizing the most relevant tokens.

3. **Value Aggregation**:
   - Attention weights are used to combine information from the value vectors of other tokens.

**Objective**:
Visualize and interpret the attention patterns in Layer 0 to understand how the model distributes focus among tokens.
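The three steps map onto a few lines of array code. Below is a NumPy sketch of one causal attention head in the standard form, with scores divided by `sqrt(d_head)`; it omits biases and LayerNorm, and the random weights are stand-ins. Because the inputs are just token embeddings, `W_E @ W_Q @ W_K.T @ W_E.T` gives the unscaled score between every pair of tokens, the same kind of product the next cell plots from the trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)
d_vocab, d_model, d_head = 300, 64, 64
W_E = rng.normal(size=(d_vocab, d_model)) / np.sqrt(d_model)
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_head)) / np.sqrt(d_model) for _ in range(3))
W_O = rng.normal(size=(d_head, d_model)) / np.sqrt(d_head)

tokens = np.array([0, 5, 17, 5, 99])
x = W_E[tokens]                        # [pos, d_model]; no positional term
q, k, v = x @ W_Q, x @ W_K, x @ W_V    # [pos, d_head]

# 1. Query-key scores, scaled by sqrt(d_head), with future keys masked out
n = len(tokens)
scores = q @ k.T / np.sqrt(d_head)
scores[np.triu_indices(n, k=1)] = -np.inf
# 2. Softmax over key positions turns scores into attention weights
pattern = np.exp(scores - scores.max(-1, keepdims=True))
pattern /= pattern.sum(-1, keepdims=True)
# 3. Each query position takes a weighted average of the value vectors
out = pattern @ v @ W_O                # [pos, d_model]

# With no positional embeddings, scores depend only on token identities
QK_full = W_E @ W_Q @ W_K.T @ W_E.T    # [query token, key token]
print(np.isclose(scores[3, 1] * np.sqrt(d_head), QK_full[5, 5]))
```

Positions 1 and 3 hold the same token, so their queries are identical, yet their outputs can still differ because position 3 averages over more keys. In this sketch, the causal mask is the only thing that distinguishes positions.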



In [21]:
# Visualize Query-Key interactions for Layer 0
QK = model.W_E @ model.W_Q[0, 0] @ model.W_K[0, 0].T @ model.W_E.T
imshow(QK, yaxis="Query", xaxis="Key",title="Layer 0: Query-Key Interactions")

In [22]:
# Visualize Output-Value mapping for Layer 0
OV = model.W_E @ model.W_V[0, 0] @ model.W_O[0, 0] @ model.W_in[0]
imshow(OV, yaxis="Input Vocab", xaxis="Neuron",title="Layer 0: Output-Value Mapping")

In [23]:
line(OV[:, torch.randint(0, 256, (5,))])

## Understand MLP in Layer 0

The MLP (Multi-Layer Perceptron) in Layer 0 processes each position's residual stream after attention. It consists of:
1. **Feedforward Layers**:
   - A linear map to 256 hidden neurons, a ReLU, and a linear map back to the residual stream.
2. **Neuron Activations**:
   - Each hidden neuron's post-ReLU activation can be inspected individually; some neurons may correspond to interpretable features.

**Objective**:
Analyze the behavior of the MLP layer by visualizing:
1. Post-activation values (the neuron activations after the ReLU).
2. Activation rates (for each position and neuron, the fraction of inputs on which the neuron is active).
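Concretely, the activation-rate heatmap takes the boolean mask `post > 0` for every example, position and neuron, and averages it over the batch. A NumPy sketch of the same reductions on random stand-in activations (not the trained model's):

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_ctx, d_mlp = 200, 50, 256

# Stand-in for cache["post", 0]: ReLU of random pre-activations, shape [batch, pos, neuron]
pre = rng.normal(size=(batch, n_ctx, d_mlp))
post = np.maximum(pre, 0.0)

mean_act = post.mean(0)                       # [pos, neuron]: average activation
fire_rate = (post > 0).astype(float).mean(0)  # [pos, neuron]: fraction of inputs with post > 0
print(fire_rate.shape, round(float(fire_rate.mean()), 2))
```

With symmetric random pre-activations each neuron fires about half the time; in the trained model, rates near 0 or 1 at particular positions are the interesting cases.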


In [24]:
# Visualize post-activation values for Layer 0 on a single example
batch_index = 0
imshow(cache["post", 0][batch_index], yaxis="Pos", xaxis="Neuron", title=f"Layer 0: MLP Activations (Example {batch_index})")

# Visualize mean activations across the batch
imshow(cache["post", 0].mean(0), yaxis="Pos", xaxis="Neuron", title="Layer 0: Mean MLP Activations")

# Visualize activation rates (fraction of examples where each neuron is active)
imshow((cache["post", 0] > 0).float().mean(0), yaxis="Pos", xaxis="Neuron", title="Layer 0: MLP Activation Rates")


## Understand Attention in Layer 1

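Attention in Layer 1 builds on Layer 0: its queries, keys and values are computed from a residual stream that already contains the outputs of Layer 0's attention and MLP, not just the token embeddings.

**Objective**:
1. Visualize attention patterns to understand how Layer 1 complements Layer 0.
2. Observe differences in Query-Key interactions and Output-Value mappings between layers. The products below use the folded `analysis_model` and trace only the direct path from the token embeddings, so they ignore anything Layer 0 adds to the residual stream.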


In [25]:
# Visualize Query-Key interactions for Layer 1
QK_layer1 = analysis_model.W_E @ analysis_model.W_Q[1, 0] @ analysis_model.W_K[1, 0].T @ analysis_model.W_E.T
imshow(QK_layer1, yaxis="Query", xaxis="Key", title="Attention Layer 1: Query-Key Matrix")


In [26]:
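# Visualize Output-Value mapping for Layer 1, read into the Layer 1 MLP neurons
# (Layer 1 attention writes after the Layer 0 MLP, so only W_in[1] can read its output)
OV_layer1 = analysis_model.W_E @ analysis_model.W_V[1, 0] @ analysis_model.W_O[1, 0] @ analysis_model.W_in[1]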
imshow(OV_layer1, yaxis="Input Vocab", xaxis="Neuron", title="Attention Layer 1: Output-Value Matrix")



In [27]:
line(OV_layer1[:, torch.randint(0, 256, (5,))], title="Attention Layer 1: Random Neuron Outputs")


## Understand MLP in Layer 1

The MLP in Layer 1 has the same structure as Layer 0's, but it reads from a residual stream that includes the outputs of every earlier component (the embedding, Layer 0's attention and MLP, and Layer 1's attention).

**Objective**:
1. Analyze post-activation states for neurons in Layer 1.
2. Compare activation behaviors between Layer 0 and Layer 1.


In [28]:
# Visualize post-activation values for Layer 1 on a single example
imshow(cache["post", 1][batch_index], yaxis="Pos", xaxis="Neuron", title=f"Layer 1: MLP Activations (Example {batch_index})")

# Visualize mean activations across the batch
imshow(cache["post", 1].mean(0), yaxis="Pos", xaxis="Neuron", title="Layer 1: Mean MLP Activations")

# Visualize activation rates (fraction of examples where each neuron is active)
imshow((cache["post", 1] > 0).float().mean(0), yaxis="Pos", xaxis="Neuron", title="Layer 1: MLP Activation Rates")


# Step 7: Experimentation: Activation Replacement

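In this experiment, we replace the activation at a given hook with its average over the 4000-sequence analysis batch (computed separately for each position) and measure the loss on a fresh batch. This allows us to:
1. Identify critical components (layers and their sub-components) by observing their impact on model performance.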
2. Measure the sensitivity of the model to specific activations.

**Steps**:
1. Replace activations in different layers/components.
2. Measure the change in model loss.
3. Visualize the results to identify the most important components.
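This kind of replacement is commonly called *mean ablation*: the activation at one hook is overwritten by its average over a reference batch (the `mean(0)` in the code keeps the position axis) and broadcast across the new batch. A NumPy sketch of just that replacement step, on hypothetical `[batch, pos, d_model]` activations (the reference batch is shrunk to 400 to keep it light):

```python
import numpy as np

rng = np.random.default_rng(0)
ref_acts = rng.normal(size=(400, 50, 64))  # activations cached on a reference batch
new_acts = rng.normal(size=(32, 50, 64))   # activations on a fresh batch

average_act = ref_acts.mean(0)             # [pos, d_model]: per-position mean
# Overwrite every example with the same per-position mean, mirroring the
# in-place `activation[:] = einops.repeat(...)` in the hook
new_acts[:] = np.broadcast_to(average_act, new_acts.shape)
print(new_acts.shape, np.allclose(new_acts[0], new_acts[31]))
```

After the replacement every example carries the same value at that hook, so whatever the component encoded about the specific input is removed while its average contribution is kept; the rise in loss over the baseline roughly measures how much the model relied on that input-specific information.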


In [29]:
# Baseline loss
new_token_batch = next(big_data_loader).to(device)
baseline_loss = loss_fn(model(new_token_batch), new_token_batch).item()
print("Baseline loss:", baseline_loss)

Baseline loss: 0.005062513053417206


In [30]:
# Replace activations with their averaged values and compute loss
hook_list = list(model.hook_dict.keys())
losses = []
loss_labels = []
for hook_name in hook_list:
    if (
        hook_name in cache
        and hook_name != "hook_pos_embed"
        and "result" not in hook_name
    ):
        average_act = cache[hook_name].mean(0)

        def replacing_with_average_act(activation, hook):
            activation[:] = einops.repeat(
                average_act, "... -> batch ...", batch=new_token_batch.size(0)
            )
            return activation

        logits = model.run_with_hooks(
            new_token_batch, fwd_hooks=[(hook_name, replacing_with_average_act)]
        )
        loss = loss_fn(logits, new_token_batch)
        print(hook_name, loss.item())
        losses.append(loss.item())
        loss_labels.append(hook_name)

hook_embed 11.68921947479248
blocks.0.ln1.hook_scale 0.6906909346580505
blocks.0.ln1.hook_normalized 2.4798977375030518
blocks.0.ln2.hook_scale 0.011786145158112049
blocks.0.ln2.hook_normalized 10.283241271972656
blocks.0.attn.hook_k 0.018330011516809464
blocks.0.attn.hook_q 1.8176791667938232
blocks.0.attn.hook_v 0.30361440777778625
blocks.0.attn.hook_z 2.50703501701355
blocks.0.attn.hook_attn_scores 1.722840666770935
blocks.0.attn.hook_pattern 1.7351224422454834
blocks.0.mlp.hook_pre 10.283241271972656
blocks.0.mlp.hook_post 10.302469253540039
blocks.0.hook_attn_out 2.507035255432129
blocks.0.hook_mlp_out 10.302469253540039
blocks.0.hook_resid_pre 11.68921947479248
blocks.0.hook_resid_mid 10.975273132324219
blocks.0.hook_resid_post 10.1483154296875
blocks.1.ln1.hook_scale 0.01404147781431675
blocks.1.ln1.hook_normalized 9.083661079406738
blocks.1.ln2.hook_scale 0.005110857542604208
blocks.1.ln2.hook_normalized 4.403139114379883
blocks.1.attn.hook_k 0.002950200578197837
blocks.1.attn.

In [31]:
# Visualize impact of averaged activations
line(losses, hover_name=loss_labels)

In [32]:
cache.cache_dict.keys()

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'ln_final.hook_scale', 'ln_final.