In [1]:
import sys
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
import functools
from tqdm import tqdm
from IPython.display import display
from transformer_lens.hook_points import HookPoint
from transformer_lens import (
    utils,
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
import circuitsvis as cv

from optim_hunter.plotly_utils import imshow, hist, plot_comp_scores, plot_logit_attribution, plot_loss_difference, line
from optim_hunter.utils import prepare_prompt, slice_dataset, prepare_prompt_from_tokens, pad_numeric_tokens
from optim_hunter.sklearn_regressors import linear_regression, knn_regression, random_forest, baseline_average, baseline_last, baseline_random
from optim_hunter.datasets import get_dataset_friedman_2
from optim_hunter.data_model import create_comparison_data
from optim_hunter.model_utils import get_numerical_tokens, generate_linreg_tokens, run_and_cache_model_linreg_tokens_batched, run_and_cache_model_linreg_tokens
from optim_hunter.llama_model import load_llama_model
import logging
from typing import List, Tuple


# --- Logging: DEBUG-level root configuration plus a module logger. ---
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# --- Autograd: the notebook only runs forward passes, so gradients are disabled globally. ---
t.set_grad_enabled(False)

# --- Device: first CUDA GPU when available, otherwise the CPU. ---
if t.cuda.is_available():
    device = t.device("cuda:0")
else:
    device = t.device("cpu")

# True when this code is executed as the top-level script / kernel module.
MAIN = (__name__ == "__main__")

In [None]:
# Load the model used by every later cell via `load_llama_model()` (optim_hunter.llama_model).
# NOTE(review): dtype/device placement are decided inside the loader — confirm they match `device` above.
model = load_llama_model()

In [None]:
# Default experiment settings: the `get_dataset_friedman_2` loader, 25 in-context examples, seed 1.
dataset, seq_len, random_int = get_dataset_friedman_2, 25, 1
def get_prompt(seq_len, random_int, print_prompt=True, dataset_fn=None):
    """Build a tokenized regression prompt from in-context (x, y) examples plus a query.

    Args:
        seq_len: Number of examples kept by ``slice_dataset``.
        random_int: Seed forwarded to the dataset loader.
        print_prompt: If True, print the decoded prompt (batch element 0).
        dataset_fn: Optional loader called as ``dataset_fn(random_int)`` and
            returning ``(x_train, y_train, x_test, y_test)``. When None (the
            default), the notebook-level ``dataset`` global is read at call time,
            exactly as before.

    Returns:
        The tokens produced by ``prepare_prompt_from_tokens`` (indexed as a batch).
    """
    loader = dataset if dataset_fn is None else dataset_fn
    x_train, y_train, x_test, y_test = loader(random_int)
    x_train, y_train, x_test, y_test = slice_dataset(
        x_train, y_train, x_test, y_test, seq_len
    )
    x_train_tokens, y_train_tokens, x_test_tokens = pad_numeric_tokens(
        model, x_train, y_train, x_test
    )
    tokenized_prompt = prepare_prompt_from_tokens(
        model, x_train_tokens, y_train_tokens, x_test_tokens,
        prepend_bos=True, prepend_inst=True,
    )

    # Decoding is only needed for display, so skip it when not printing.
    if print_prompt:
        print(model.to_string(tokenized_prompt[0]))
    return tokenized_prompt
# Build the default prompt for the settings above and print it.
tokenized_prompt = get_prompt(seq_len, random_int, print_prompt=True)

In [4]:
import plotly.express as px

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Plot a tensor as a line chart with plotly express and show it.

    NOTE: this redefines the ``line`` imported from ``optim_hunter.plotly_utils``.

    ``px.line`` treats a bare array as wide-form data, whose axes are named
    "index" and "value" (not "x"/"y"). The axis titles are therefore also mapped
    onto those names; otherwise they would be silently ignored.

    Args:
        tensor: Tensor/array to plot (converted with ``utils.to_numpy``).
        renderer: Optional plotly renderer passed to ``Figure.show``.
        xaxis, yaxis: Axis titles; empty strings keep plotly's defaults.
        **kwargs: Forwarded to ``px.line``.
    """
    labels = {"x": xaxis, "y": yaxis}
    if xaxis:
        labels["index"] = xaxis
    if yaxis:
        labels["value"] = yaxis
    px.line(utils.to_numpy(tensor), labels=labels, **kwargs).show(renderer)

What is the model's loss on the first numerical token following each "Output:"?

In [None]:
def plot_linreg_output_loss(tokenized_prompt, search_start=24):
    """Collect the model's loss on the first number token after each "Output" marker.

    ``model.loss_fn(..., per_token=True)`` returns one value per *prediction*:
    entry ``i`` is the loss of predicting token ``i + 1`` from positions ``<= i``.
    The loss of the token that sits at position ``p`` is therefore entry ``p - 1``.
    The previous version read entry ``p``, which scored the token after the number.

    Args:
        tokenized_prompt: Token ids from ``get_prompt``.
        search_start: First position scanned for "Output" markers (skips the
            instruction header; 24 is the previously hard-coded value).

    Returns:
        (output_losses, output_tokens): for each "Output" marker, the
        batch-averaged loss of its first number token and that token's string.
    """
    input_tokens = tokenized_prompt.to(model.cfg.device)
    logits = model(input_tokens)
    per_token_loss = model.loss_fn(logits, input_tokens, per_token=True)
    loss_by_position = einops.reduce(per_token_loss, "batch position -> position", "mean")

    # String tokens for the full sequence (batch element 0)
    str_tokens = model.to_str_tokens(input_tokens[0])

    output_indices = [
        i for i, token in enumerate(str_tokens[search_start:], start=search_start)
        if token == "Output"
    ]

    output_losses = []
    output_tokens = []
    for idx in output_indices:
        # Skip the "Output", ":" and " " tokens: the number starts 3 positions later
        current_pos = idx + 3
        if current_pos < len(str_tokens):
            # The token at current_pos is predicted from position current_pos - 1
            output_losses.append(loss_by_position[current_pos - 1].item())
            output_tokens.append(str_tokens[current_pos])

    return output_losses, output_tokens

output_losses, output_tokens = plot_linreg_output_loss(tokenized_prompt)

# Pass x/y explicitly: px.line on a bare list builds wide-form data whose axes are
# named "index"/"value", so {"x": ..., "y": ...} labels would be silently ignored.
fig = px.line(
    x=list(range(len(output_losses))),
    y=output_losses,
    labels={"x": "Token Position", "y": "Loss"},
    title="Loss for First Output Number Token"
)

# Label each point with the token whose loss it shows
fig.update_layout(
    xaxis=dict(
        tickmode='array',
        ticktext=output_tokens,
        tickvals=list(range(len(output_tokens))),
        tickangle=45
    )
)

Do the same thing, using a prompt with only 25 in-context samples.

In [None]:
seq_len = 25
random_int = 1
tokenized_prompt = get_prompt(seq_len, random_int, print_prompt=False)
output_losses, output_tokens = plot_linreg_output_loss(tokenized_prompt)

# Pass x/y explicitly: px.line on a bare list builds wide-form data whose axes are
# named "index"/"value", so {"x": ..., "y": ...} labels would be silently ignored.
fig = px.line(
    x=list(range(len(output_losses))),
    y=output_losses,
    labels={"x": "Token Position", "y": "Loss"},
    title="Loss for First Output Number Token"
)

# Label each point with the token whose loss it shows
fig.update_layout(
    xaxis=dict(
        tickmode='array',
        ticktext=output_tokens,
        tickvals=list(range(len(output_tokens))),
        tickangle=45
    )
)

Remove the outliers (on a longer, 51-example prompt) to see whether the loss decreases across examples.

In [None]:
# Use a longer prompt of 51 examples for the outlier-filtered plot.
seq_len, random_int = 51, 1
tokenized_prompt = get_prompt(seq_len, random_int, print_prompt=False)
output_losses, output_tokens = plot_linreg_output_loss(tokenized_prompt)

# Remove outliers using more aggressive IQR method
import numpy as np

def remove_outliers(data, tokens, iqr_multiplier=0.5):
    """Drop points outside ``[Q1 - k*IQR, Q3 + k*IQR]`` together with their paired tokens.

    ``k = iqr_multiplier``. The default 0.5 is deliberately stricter than Tukey's
    usual 1.5.

    Args:
        data: Sequence of numeric values.
        tokens: Sequence of labels, one per entry in ``data``.
        iqr_multiplier: Fence width in units of the interquartile range.

    Returns:
        (kept_values, kept_tokens) as Python lists, in the original order.
        Empty input returns ``([], [])`` instead of raising from ``np.percentile``.
    """
    values = np.asarray(data)
    if values.size == 0:
        return [], []

    q1, q3 = np.percentile(values, [25, 75])
    iqr = q3 - q1
    lower_bound = q1 - iqr_multiplier * iqr
    upper_bound = q3 + iqr_multiplier * iqr

    mask = (values >= lower_bound) & (values <= upper_bound)
    # zip() pairs each token with its keep-flag (and avoids a loop variable named `t`,
    # which reads like the notebook's torch alias).
    kept_tokens = [tok for tok, keep in zip(tokens, mask) if keep]
    return values[mask].tolist(), kept_tokens

# Filter out outliers
filtered_losses, filtered_tokens = remove_outliers(output_losses, output_tokens)

# Pass x/y explicitly: px.line on a bare list builds wide-form data whose axes are
# named "index"/"value", so {"x": ..., "y": ...} labels would be silently ignored.
fig = px.line(
    x=list(range(len(filtered_losses))),
    y=filtered_losses,
    labels={"x": "Token Position", "y": "Loss"},
    title="Loss for First Output Number Token (Outliers Removed)"
)

# Label each point with the token whose loss it shows
fig.update_layout(
    xaxis=dict(
        tickmode='array',
        ticktext=filtered_tokens,
        tickvals=list(range(len(filtered_tokens))),
        tickangle=45
    )
)

Confirm that the positions of the first numerical token after each "Output" and "Feature" marker are the same across seeds.

In [None]:
def check_output_token_positions(model, dataset, seq_len, num_seeds=5):
    """Report, for each seed, where the first number token after every "Output" sits.

    NOTE(review): ``dataset`` is accepted but not used here — prompts come from
    ``get_prompt``, which reads the notebook-level ``dataset`` global.

    Args:
        model: Model used for tokenization and device placement.
        dataset: Unused (see note).
        seq_len: Number of in-context examples per prompt.
        num_seeds: Seeds ``0 .. num_seeds - 1`` are checked.

    Returns:
        One list of positions per seed.
    """
    def _positions_for(seed):
        str_toks = model.to_str_tokens(
            get_prompt(seq_len, seed, print_prompt=False).to(model.cfg.device)[0]
        )
        limit = len(str_toks) - 1
        # Scan from position 24 on; the number follows "Output", ":", " " (3 tokens later).
        return [
            pos + 3
            for pos in range(24, len(str_toks))
            if str_toks[pos] == "Output" and pos + 3 < limit
        ]

    positions = [_positions_for(seed) for seed in range(num_seeds)]

    print("\nPositions of first number token after 'Output:' for each seed:")
    for seed, pos in enumerate(positions):
        print(f"Seed {seed}: {pos}")

    all_same = all(other == positions[0] for other in positions[1:])
    print(f"\nAll positions consistent across seeds: {all_same}")

    return positions

# Check Output positions for a short and a longer prompt.
# (The loop rebinds the global `seq_len`; it ends at 50.)
seq_lens = [25, 50]
for seq_len in seq_lens:
    print(f"\nTesting with sequence length: {seq_len}")
    positions = check_output_token_positions(model, dataset, seq_len=seq_len)

In [None]:

def check_token_positions_multi(model, dataset, seq_len, num_seeds=5, print="true"):
    """Check that Output/Feature number-token positions agree across seeds.

    The parameter named ``print`` shadows the builtin. The previous body then
    called ``print(...)``, so every call raised ``TypeError: 'str' object is not
    callable``. The parameter is kept for backward compatibility and is now
    honoured as a verbosity flag; the report is written with ``builtins.print``.

    NOTE(review): ``dataset`` is accepted but unused — ``get_prompt`` reads the
    notebook-level ``dataset`` global.

    Args:
        model: Model used for tokenization and device placement.
        dataset: Unused (see note).
        seq_len: Number of in-context examples per prompt.
        num_seeds: Seeds ``0 .. num_seeds - 1`` are checked.
        print: Verbosity flag. Falsy values and the strings "false", "0", "no"
            or "" (case-insensitive) suppress the report; the default "true" prints.

    Returns:
        (output_positions, feature_positions): one list of positions per seed.
    """
    import builtins

    if isinstance(print, str):
        verbose = print.strip().lower() not in ("false", "0", "no", "")
    else:
        verbose = bool(print)
    echo = builtins.print

    output_positions = []
    feature_positions = []
    first_seed_tokens = None  # str tokens of seed 0, reused for the example printout

    for seed in range(num_seeds):
        tokenized_prompt = get_prompt(seq_len, seed, print_prompt=False)
        input_tokens = tokenized_prompt.to(model.cfg.device)
        str_tokens = model.to_str_tokens(input_tokens[0])
        if first_seed_tokens is None:
            first_seed_tokens = str_tokens
        limit = len(str_tokens) - 1

        seed_output_positions = []
        seed_feature_positions = []
        # Scan from position 24 on (past the instruction header)
        for i in range(24, len(str_tokens)):
            token = str_tokens[i]
            if token == "Output" and i + 3 < limit:
                seed_output_positions.append(i + 3)   # skip "Output: "
            elif token == "Feature" and i + 5 < limit:
                seed_feature_positions.append(i + 5)  # skip "Features n: "

        output_positions.append(seed_output_positions)
        feature_positions.append(seed_feature_positions)

    outputs_same = all(pos == output_positions[0] for pos in output_positions[1:])
    features_same = all(pos == feature_positions[0] for pos in feature_positions[1:])

    if verbose:
        echo("\nPositions of first number token after 'Output:' for each seed:")
        for seed, pos in enumerate(output_positions):
            echo(f"Seed {seed}: {pos}")

        echo("\nPositions of first number token after 'Feature:' for each seed:")
        for seed, pos in enumerate(feature_positions):
            echo(f"Seed {seed}: {pos}")

        echo(f"\nAll Output positions consistent across seeds: {outputs_same}")
        echo(f"All Feature positions consistent across seeds: {features_same}")

        if first_seed_tokens is not None:
            echo("\nExample tokens at these positions (first seed):")
            echo("After Output:", [first_seed_tokens[pos] for pos in output_positions[0]])
            echo("After Features:", [first_seed_tokens[pos] for pos in feature_positions[0]])

    return output_positions, feature_positions

# Check Output and Feature positions for a short and a longer prompt.
# (The loop rebinds the global `seq_len`; it ends at 50.)
seq_lens = [25, 50]
for seq_len in seq_lens:
    print(f"\nTesting with sequence length: {seq_len}")
    output_pos, feature_pos = check_token_positions_multi(model, dataset, seq_len=seq_len)

In [None]:
# Re-select the dataset loader; `get_prompt` reads this global `dataset` at call time.
dataset = get_dataset_friedman_2

def check_token_positions(model, dataset, seq_len, seed=0, print_info=True):
    """
    Locate the first number token after each "Output" and "Feature" marker in one prompt.

    Args:
        model: The transformer model (used for tokenization and device placement)
        dataset: Accepted for signature compatibility. NOTE(review): unused — the
            prompt comes from `get_prompt`, which reads the notebook-level `dataset` global
        seq_len: Number of in-context examples in the prompt
        seed: Seed passed to `get_prompt` (default: 0)
        print_info: When True, print the positions and the tokens found there

    Returns:
        tuple: (output_positions, feature_positions), lists of token indices
    """
    str_tokens = model.to_str_tokens(
        get_prompt(seq_len, seed, print_prompt=False).to(model.cfg.device)[0]
    )
    last_valid = len(str_tokens) - 1

    # Marker -> offset to its first number token ("Output: " = 3, "Features n: " = 5).
    # Positions before 24 (the instruction header) are not scanned.
    skip = {"Output": 3, "Feature": 5}
    hits = {"Output": [], "Feature": []}
    for idx in range(24, len(str_tokens)):
        tok = str_tokens[idx]
        if tok in skip and idx + skip[tok] < last_valid:
            hits[tok].append(idx + skip[tok])
    output_positions, feature_positions = hits["Output"], hits["Feature"]

    if print_info:
        print("\nPositions of first number token after 'Output:':")
        print(f"Positions: {output_positions}")

        print("\nPositions of first number token after 'Feature:':")
        print(f"Positions: {feature_positions}")

        print("\nExample tokens at these positions:")
        print("After Output:", [str_tokens[p] for p in output_positions])
        print("After Features:", [str_tokens[p] for p in feature_positions])

    return output_positions, feature_positions

# Single-prompt position check at the default length (25 examples, seed 0).
seq_len = 25
output_pos, feature_pos = check_token_positions(model, dataset, seq_len, seed=0)


Let's look at the attention heads:
- Check for induction heads.
- Check for heads which attend to the numbers after "Features:".
- Check for heads which attend to the numbers after "Output:".
- Check for heads which attend only to the previous "Features:".

In [None]:
# Attention-head analysis. The score tensors (allocated further down) live on the model's device, to avoid slow transfers between the GPU and CPU.
from plotly.subplots import make_subplots
import plotly.express as px

seq_len = 25
dataset = get_dataset_friedman_2
random_int = 1
tokenized_prompt = get_prompt(seq_len, random_int, print_prompt=False)
# Locate positions in the SAME prompt the hooks will run on. check_token_positions
# defaults to seed=0, which describes a different prompt than seed `random_int`.
output_pos, feature_pos = check_token_positions(
    model, dataset, seq_len, seed=random_int, print_info=False
)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a 2-D tensor as a heatmap on a red/blue scale centred at zero.

    NOTE: shadows the ``imshow`` imported from ``optim_hunter.plotly_utils``.
    Extra keyword arguments are passed straight to ``px.imshow``.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    figure.show(renderer)

def imshow_multi(tensor, num_examples, xaxis="", yaxis="", title="Example Scores by Head", renderer=None):
    """
    Display per-example score maps in a grid of heatmaps, each with its own color scale.

    The previous version copied traces out of ``px.imshow`` figures. Those traces
    refer to a shared ``coloraxis`` whose settings (RdBu scale, range) live in the
    px figure's layout and were never copied. As a result, every subplot ended up
    on one default color axis, and the px axis orientation and labels were lost.
    Building ``go.Heatmap`` traces directly gives each subplot its own symmetric
    RdBu scale and its own colorbar.

    Args:
        tensor: Scores where ``tensor[..., i]`` is the 2-D map for example ``i``
            (rows are drawn along ``yaxis``, columns along ``xaxis``).
        num_examples: Number of examples (last-axis slices) to plot.
        xaxis: Label for the x-axis.
        yaxis: Label for the y-axis.
        title: Title for the entire plot.
        renderer: Optional renderer for displaying the plot.
    """
    import plotly.graph_objects as go

    if num_examples <= 0:
        return

    # Near-square grid
    num_cols = max(1, int(num_examples ** 0.5))
    num_rows = (num_examples + num_cols - 1) // num_cols  # ceiling division

    fig = make_subplots(
        rows=num_rows,
        cols=num_cols,
        subplot_titles=[f"Example {i}" for i in range(num_examples)],
    )

    for i in range(num_examples):
        row, col = divmod(i, num_cols)
        row, col = row + 1, col + 1
        example = utils.to_numpy(tensor[..., i])

        # Symmetric range keeps 0 at the white midpoint of RdBu for this example.
        max_abs = float(np.abs(example).max()) if example.size else 0.0
        if max_abs == 0.0:
            max_abs = 1.0  # avoid a degenerate zmin == zmax range

        # Place this subplot's colorbar just right of its own axes.
        subplot = fig.get_subplot(row, col)
        x_dom, y_dom = subplot.xaxis.domain, subplot.yaxis.domain

        fig.add_trace(
            go.Heatmap(
                z=example,
                colorscale="RdBu",
                zmin=-max_abs,
                zmax=max_abs,
                colorbar=dict(
                    x=x_dom[1] + 0.005,
                    y=(y_dom[0] + y_dom[1]) / 2,
                    len=y_dom[1] - y_dom[0],
                    thickness=8,
                ),
            ),
            row=row,
            col=col,
        )

    # Match px.imshow's orientation (first row, e.g. layer 0, at the top) and label axes.
    fig.update_yaxes(autorange="reversed", title_text=yaxis)
    fig.update_xaxes(title_text=xaxis)
    fig.update_layout(title_text=title, height=300 * num_rows, width=300 * num_cols)

    fig.show(renderer)

# Global score buffers that the hooks below fill in place: one row per layer, one column per head.
# The per-example buffers add a trailing axis with one slot per "Output" position found above.
induction_score_store = t.zeros(model.cfg.n_layers, model.cfg.n_heads, device=model.cfg.device)
per_example_score_store = t.zeros(model.cfg.n_layers, model.cfg.n_heads, len(output_pos), device=model.cfg.device)
per_example_accumulated_score_store = t.zeros(model.cfg.n_layers, model.cfg.n_heads, len(output_pos), device=model.cfg.device)
all_examples_score_store = t.zeros(model.cfg.n_layers, model.cfg.n_heads, device=model.cfg.device)
all_example_accumulated_score_store = t.zeros(model.cfg.n_layers, model.cfg.n_heads, device=model.cfg.device)

def induction_score_hook(
    pattern: Float[t.Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
):
    """Store each head's mean attention along the stripe ``seq_len - 1`` tokens back.

    Writes one value per head into row ``hook.layer()`` of the global
    ``induction_score_store``.

    NOTE(review): the offset comes from the global ``seq_len``, which in this
    notebook is the number of in-context examples given to ``get_prompt``, not
    the token period of a repeated sequence. Confirm that this stripe is the
    intended induction measure for these prompts.
    """
    # Entry k of this diagonal is pattern[..., k + seq_len - 1, k]: attention from each
    # destination to the source seq_len - 1 positions earlier.
    stripe = pattern.diagonal(offset=1 - seq_len, dim1=-2, dim2=-1)
    induction_score_store[hook.layer(), :] = einops.reduce(
        stripe, "batch head_index position -> head_index", "mean"
    )

def accumulated_attention_hook(
    pattern: Float[t.Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
    output_positions,
    feature_positions
):
    """
    Hook measuring mean attention from each output position to all earlier output and feature positions.

    For output ``i`` at position ``p``, the sources are every entry of
    ``output_positions`` and ``feature_positions`` strictly before ``p``. The
    per-head mean goes to ``per_example_accumulated_score_store[layer, :, i]``.
    The average over outputs goes to ``all_example_accumulated_score_store[layer]``.
    Outputs with no earlier positions are skipped, leaving their slot untouched.
    If no output qualifies, the layer-level store is left untouched too.

    Args:
        pattern: Attention pattern tensor with shape [batch, head_index, dest_pos, source_pos];
            only batch element 0 is read
        hook: HookPoint object containing layer information
        output_positions: List of positions of first number after "Output:"
        feature_positions: List of positions of first numbers after "Features:"
    """
    layer = hook.layer()
    scores = []

    for i, output_pos in enumerate(output_positions):
        relevant_positions = [pos for pos in output_positions if pos < output_pos] + \
                             [pos for pos in feature_positions if pos < output_pos]
        if not relevant_positions:
            continue

        # [head_index, source_pos] for this destination, then select the relevant sources
        # in a separate step so the result is unambiguously [head_index, n_relevant].
        dest_attention = pattern[0, :, output_pos]
        accumulated_attention = dest_attention[:, relevant_positions].mean(dim=-1)  # [head_index]

        per_example_accumulated_score_store[layer, :, i] = accumulated_attention
        scores.append(accumulated_attention)

    if scores:
        all_example_accumulated_score_store[layer, :] = t.stack(scores).mean(dim=0)

def all_example_hook(
    pattern: Float[t.Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
    output_positions,
    feature_positions,
    num_features=3,
):
    """
    Hook measuring attention from each output position to the feature positions just before it.

    For output ``i`` at position ``p``, the sources are the last ``num_features``
    entries of ``feature_positions`` that come before ``p``. The per-head mean goes
    to ``per_example_score_store[layer, :, i]``, and the average over outputs goes
    to ``all_examples_score_store[layer]``.

    Outputs with no earlier feature positions are skipped; previously the mean over
    an empty selection wrote NaN. If none qualify, the layer-level store is left
    unchanged; previously ``t.stack([])`` raised.

    NOTE(review): ``num_features=3`` keeps the previously hard-coded value. Confirm
    that it matches the number of "Feature" markers per example in the prompt.

    Args:
        pattern: Attention pattern tensor with shape [batch, head_index, dest_pos, source_pos];
            only batch element 0 is read
        hook: HookPoint object containing layer information
        output_positions: List of positions of first number after "Output:"
        feature_positions: List of positions of first numbers after "Features:"
        num_features: How many preceding feature positions to average over
    """
    layer = hook.layer()
    scores = []

    for i, output_pos in enumerate(output_positions):
        preceding = [pos for pos in feature_positions if pos < output_pos]
        # Guard num_features <= 0: preceding[-0:] would select the whole list.
        relevant_feature_pos = preceding[-num_features:] if num_features > 0 else []
        if not relevant_feature_pos:
            continue

        # [head_index, source_pos] for this destination -> [head_index, n_relevant] -> [head_index]
        dest_attention = pattern[0, :, output_pos]
        feature_attention = dest_attention[:, relevant_feature_pos].mean(dim=-1)

        per_example_score_store[layer, :, i] = feature_attention
        scores.append(feature_attention)

    if scores:
        all_examples_score_store[layer, :] = t.stack(scores).mean(dim=0)

# Hook-name filter: true only for attention-pattern activations.
pattern_hook_names_filter = lambda name: name.endswith("pattern")

# All three hooks watch every attention pattern. The two position-based hooks get this
# prompt's Output/Feature positions bound in advance. Only their side effects matter,
# so logits are not computed.
model.run_with_hooks(
    tokenized_prompt,
    return_type=None,
    fwd_hooks=[
        (pattern_hook_names_filter, hook_fn)
        for hook_fn in (
            induction_score_hook,
            functools.partial(all_example_hook, output_positions=output_pos, feature_positions=feature_pos),
            functools.partial(accumulated_attention_hook, output_positions=output_pos, feature_positions=feature_pos),
        )
    ],
)

# Layer x head summaries
imshow(induction_score_store, xaxis="Head", yaxis="Layer", title="Induction Score by Head")
imshow(all_examples_score_store, xaxis="Head", yaxis="Layer", title="Per Example Score Avg by Head")
imshow(all_example_accumulated_score_store, xaxis="Head", yaxis="Layer", title="Accumulated Example Score Avg by Head")

# Per-example grids
imshow_multi(
    per_example_score_store, per_example_score_store.shape[-1],
    xaxis="Head", yaxis="Layer", title="Example Scores by Head",
)
imshow_multi(
    per_example_accumulated_score_store, per_example_accumulated_score_store.shape[-1],
    xaxis="Head", yaxis="Layer", title="Accumulated Example by Head",
)

# Per-example individual maps
for i in range(per_example_score_store.shape[-1]):
    imshow(per_example_score_store[..., i], xaxis="Head", yaxis="Layer",
           title=f"Example {i} Score by Head")

for i in range(per_example_accumulated_score_store.shape[-1]):
    imshow(per_example_accumulated_score_store[..., i], xaxis="Head", yaxis="Layer",
           title=f"Accumulated Example {i} Score by Head")


In [None]:
import torch as t
import numpy as np
import functools
from tqdm import tqdm

# Define the number of seeds
num_seeds = 100

# Seed-0 positions only size the per-example accumulators; each seed recomputes its own below.
output_pos, feature_pos = check_token_positions(model, dataset, seq_len, print_info=False)

# Initialize accumulators for scores
induction_score_accumulator = t.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)
all_examples_score_accumulator = t.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)
all_example_accumulated_score_accumulator = t.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)
per_example_score_store_accumulator = t.zeros((model.cfg.n_layers, model.cfg.n_heads, len(output_pos)), device=model.cfg.device)
per_example_accumulated_score_store_accumulator = t.zeros((model.cfg.n_layers, model.cfg.n_heads, len(output_pos)), device=model.cfg.device)

# Loop over seeds
for seed in tqdm(range(num_seeds), desc="Running seeds"):
    tokenized_prompt = get_prompt(seq_len, seed, print_prompt=False)
    # Positions must come from the same seed as the prompt. The default seed=0 would
    # pair every prompt with seed 0's positions.
    output_pos, feature_pos = check_token_positions(model, dataset, seq_len, seed=seed, print_info=False)
    assert len(output_pos) == per_example_score_store_accumulator.shape[-1], (
        f"seed {seed}: found {len(output_pos)} Output positions, but the per-example "
        f"accumulators hold {per_example_score_store_accumulator.shape[-1]}"
    )

    # The hooks write into shared global stores. Clear them so that a slot a hook skips
    # for this seed cannot carry the previous seed's value into the accumulators.
    induction_score_store.zero_()
    all_examples_score_store.zero_()
    all_example_accumulated_score_store.zero_()
    per_example_score_store.zero_()
    per_example_accumulated_score_store.zero_()

    # Run the model with hooks (logits are not needed)
    model.run_with_hooks(
        tokenized_prompt,
        return_type=None,
        fwd_hooks=[
            (
                pattern_hook_names_filter,
                induction_score_hook
            ),
            (
                pattern_hook_names_filter,
                functools.partial(
                    all_example_hook,
                    output_positions=output_pos,  # positions for the current seed
                    feature_positions=feature_pos
                )
            ),
            (
                pattern_hook_names_filter,
                functools.partial(
                    accumulated_attention_hook,
                    output_positions=output_pos,  # positions for the current seed
                    feature_positions=feature_pos
                )
            )
        ]
    )

    # Accumulate scores
    induction_score_accumulator += induction_score_store
    all_examples_score_accumulator += all_examples_score_store
    all_example_accumulated_score_accumulator += all_example_accumulated_score_store
    per_example_accumulated_score_store_accumulator += per_example_accumulated_score_store
    per_example_score_store_accumulator += per_example_score_store

# Average the scores
induction_score_avg = induction_score_accumulator / num_seeds
all_examples_score_avg = all_examples_score_accumulator / num_seeds
all_example_accumulated_score_avg = all_example_accumulated_score_accumulator / num_seeds
per_example_accumulated_score_store_avg = per_example_accumulated_score_store_accumulator / num_seeds
per_example_score_store_avg = per_example_score_store_accumulator / num_seeds

# Plot the averaged scores
imshow(induction_score_avg, xaxis="Head", yaxis="Layer", title="Average Induction Score by Head")
imshow(all_examples_score_avg, xaxis="Head", yaxis="Layer", title="Average Per Example Score by Head")
imshow(all_example_accumulated_score_avg, xaxis="Head", yaxis="Layer", title="Average Accumulated Example Score by Head")
imshow_multi(per_example_score_store_avg, per_example_score_store_avg.shape[-1], xaxis="Head", yaxis="Layer", title="Example Scores by Head")
imshow_multi(per_example_accumulated_score_store_avg, per_example_accumulated_score_store_avg.shape[-1], xaxis="Head", yaxis="Layer", title="Accumulated Example by Head")

# Per-example feature-attention maps. This loop previously re-plotted the accumulated store.
for i in range(per_example_score_store_avg.shape[-1]):
    imshow(per_example_score_store_avg[..., i], xaxis="Head", yaxis="Layer", title=f"Example {i} Score by Head")

for i in range(per_example_accumulated_score_store_avg.shape[-1]):
    imshow(per_example_accumulated_score_store_avg[..., i], xaxis="Head", yaxis="Layer", title=f"Accumulated Example {i} Score by Head")

In [None]:
import plotly.graph_objects as go

def visualize_mlp_layers(model, input_tokens, num_last_layers=3):
    """
    Visualize MLP input activations (``blocks.{L}.hook_mlp_in``) for the last few layers.

    Draws one heatmap per layer (rows: token positions, columns: hidden dimensions).
    Each uses a symmetric blue-white-red scale, so zero is white.

    Args:
        model: HookedTransformer to run. ``set_use_hook_mlp_in(True)`` is switched on
            (and left on) so that the ``hook_mlp_in`` points fire.
        input_tokens: Token ids. ``squeeze(0)`` drops the batch axis, so a batch of 1
            is assumed.
        num_last_layers: Number of final layers to show. Values above ``n_layers`` are
            clamped; values <= 0 plot nothing.
    """
    n_layers = model.cfg.n_layers
    num_last_layers = min(num_last_layers, n_layers)
    if num_last_layers <= 0:
        return
    layers = list(range(n_layers - num_last_layers, n_layers))

    mlp_activations = {}

    def mlp_hook(act, hook, layer_num):
        mlp_activations[f'layer_{layer_num}'] = act.detach()
        return act

    # Only hook the layers that will be plotted
    hooks = [
        (f'blocks.{layer_num}.hook_mlp_in', functools.partial(mlp_hook, layer_num=layer_num))
        for layer_num in layers
    ]

    model.set_use_hook_mlp_in(True)
    # Only the hook side effects are needed, so skip computing logits
    model.run_with_hooks(input_tokens, return_type=None, fwd_hooks=hooks)

    # Custom colorscale with white at zero
    custom_colorscale = [
        [0, 'blue'],        # Negative values
        [0.5, 'white'],     # Zero
        [1, 'red']          # Positive values
    ]

    fig = make_subplots(
        rows=num_last_layers,
        cols=1,
        subplot_titles=[f'MLP Input Activations - Layer {layer_num}' for layer_num in layers]
    )

    for idx, layer_num in enumerate(layers):
        layer_key = f'layer_{layer_num}'
        if layer_key not in mlp_activations:
            continue

        # .float() first: numpy has no bfloat16, so .numpy() on a bf16 tensor raises.
        acts_np = mlp_activations[layer_key].squeeze(0).float().cpu().numpy()
        max_abs_val = float(np.abs(acts_np).max()) if acts_np.size else 0.0

        # Give each subplot its own colorbar beside its row instead of stacking them all in one spot.
        y_dom = fig.get_subplot(idx + 1, 1).yaxis.domain

        heatmap = go.Heatmap(
            z=acts_np,
            colorscale=custom_colorscale,
            zmid=0,
            zmin=-max_abs_val,  # Symmetric color scale
            zmax=max_abs_val,
            showscale=True,
            colorbar=dict(
                title='Activation Value',
                y=(y_dom[0] + y_dom[1]) / 2,
                len=y_dom[1] - y_dom[0],
            )
        )
        fig.add_trace(heatmap, row=idx + 1, col=1)

    fig.update_layout(
        height=400 * num_last_layers,
        width=1000,
        title_text="MLP Layer Activations",
        showlegend=False
    )

    for i in range(num_last_layers):
        fig.update_xaxes(title_text="Hidden Dimension", row=i + 1, col=1)
        fig.update_yaxes(title_text="Token Position", row=i + 1, col=1)

    fig.show()

# A short, unrelated arithmetic prompt to contrast with the regression prompt.
custom_prompt = """Let's solve a math problem step by step:
Input: Calculate 2 + 2
Step 1: We start with the number 2
Step 2: We add another 2
Output: The result is 4"""
custom_tokens = model.to_tokens(custom_prompt)

# NOTE(review): `tokenized_prompt` is whatever the previous cells left behind. In a
# top-to-bottom run, that is the last seed of the multi-seed loop.
print("Visualizing original prompt:")
visualize_mlp_layers(model, tokenized_prompt, num_last_layers=3)

print("\nVisualizing custom prompt:")
visualize_mlp_layers(model, custom_tokens, num_last_layers=3)

In [None]:
import plotly.graph_objects as go

def analyze_significant_activations(model, input_tokens, num_last_layers=3, std_threshold=23, max_print=None):
    """
    List and plot MLP-input activations more than ``std_threshold`` standard deviations above the layer mean.

    For each of the last ``num_last_layers`` layers, the mean and std are taken over
    all (token, hidden-dim) entries of ``blocks.{L}.hook_mlp_in`` for batch element 0.

    NOTE(review): the default ``std_threshold=23`` is kept for backward compatibility,
    but the earlier docstring described "> 2 standard deviations". 23 looks like a
    typo for 2 — confirm. (The call in this notebook passes 2 explicitly.)

    Args:
        model: HookedTransformer. ``set_use_hook_mlp_in(True)`` is switched on (and left on).
        input_tokens: Token ids. ``squeeze(0)`` drops the batch axis, so batch size 1 is assumed.
        num_last_layers: Number of last layers to analyze
        std_threshold: Number of standard deviations above mean to consider significant
        max_print: Optional cap on printed activations per layer (None prints all).
            Over every hidden dimension, a low threshold can match thousands of entries.
    """
    n_layers = model.cfg.n_layers
    first_layer = max(n_layers - num_last_layers, 0)
    mlp_activations = {}

    def mlp_hook(act, hook, layer_num):
        mlp_activations[f'layer_{layer_num}'] = act.detach()
        return act

    # Only hook the layers that are analyzed
    hooks = [
        (f'blocks.{layer_num}.hook_mlp_in', functools.partial(mlp_hook, layer_num=layer_num))
        for layer_num in range(first_layer, n_layers)
    ]

    model.set_use_hook_mlp_in(True)
    # Only the hook side effects are needed, so skip computing logits
    model.run_with_hooks(input_tokens, return_type=None, fwd_hooks=hooks)

    for layer_num in range(first_layer, n_layers):
        layer_key = f'layer_{layer_num}'
        if layer_key not in mlp_activations:
            continue

        # float32: more precise statistics, and numpy (used for the heatmap) has no bfloat16
        acts = mlp_activations[layer_key].squeeze(0).float()

        mean_act = acts.mean().item()
        std_act = acts.std().item()
        threshold = mean_act + (std_threshold * std_act)

        significant_mask = acts > threshold
        if not significant_mask.any():
            continue

        print(f"\nLayer {layer_num} Significant Activations:")
        print(f"Mean: {mean_act:.4f}, Std: {std_act:.4f}, Threshold: {threshold:.4f}")

        # One device->host transfer for all hits instead of one .item() per hit.
        # Both nonzero() and boolean-mask indexing enumerate entries in row-major order.
        positions = t.nonzero(significant_mask).tolist()
        values = acts[significant_mask].tolist()
        shown = len(positions) if max_print is None else min(max_print, len(positions))

        token_strings = {}  # decode each token position only once
        for (token_idx, hidden_idx), value in zip(positions[:shown], values[:shown]):
            std_above_mean = (value - mean_act) / std_act

            if token_idx not in token_strings:
                token = input_tokens[0, token_idx].item()
                try:
                    token_strings[token_idx] = model.tokenizer.decode([token])
                except Exception:
                    token_strings[token_idx] = f"Token ID: {token}"
            token_str = token_strings[token_idx]

            print(f"Token: {token_str} (pos {token_idx}), "
                  f"Hidden dim: {hidden_idx}, "
                  f"Value: {value:.4f} "
                  f"({std_above_mean:.2f} σ above mean)")

        if shown < len(positions):
            print(f"... {len(positions) - shown} more significant activations not shown")

        # Heatmap showing only the significant activations (others become gaps)
        significant_acts = acts.clone()
        significant_acts[~significant_mask] = float('nan')

        fig = go.Figure()
        fig.add_trace(go.Heatmap(
            z=significant_acts.cpu().numpy(),
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title='Activation Value'),
            hoverongaps=False
        ))
        fig.update_layout(
            title=f"Layer {layer_num} Significant Activations (>{std_threshold}σ above mean)",
            xaxis_title="Hidden Dimension",
            yaxis_title="Token Position",
            width=1000,
            height=400
        )
        fig.show()

# Report MLP-input activations more than 2σ above each layer's mean, for the last 3 layers.
analyze_significant_activations(
    model,
    tokenized_prompt,
    num_last_layers=3,
    std_threshold=2,
)

In [None]:
# Colour scale for signed activations: blue for negative, white at zero, red for positive.
_ZERO_WHITE_COLORSCALE = [
    [0, 'blue'],
    [0.5, 'white'],
    [1, 'red'],
]


def _select_positions(acts, positions, label):
    """Return the rows of a 2-D CPU activation matrix at `positions`, always as 2-D.

    Re-raises IndexError with context. A common cause is positions computed for a
    different (longer) prompt than the tokens that were actually run.
    """
    if isinstance(positions, t.Tensor):
        positions = positions.cpu()  # `acts` lives on the CPU
    try:
        selected = acts[positions, :]
    except IndexError as exc:
        raise IndexError(
            f"{label} positions {positions} are out of range for a sequence of "
            f"length {acts.shape[0]}; were they computed for these input tokens?"
        ) from exc
    # A single integer position yields a 1-D row; heatmaps need a 2-D `z`.
    return selected.unsqueeze(0) if selected.dim() == 1 else selected


def _symmetric_heatmap(values, colorbar_y):
    """Build a heatmap trace of a 2-D CPU float tensor, colour range symmetric about zero."""
    max_abs_val = float(values.abs().max()) if values.numel() else 0.0
    if max_abs_val == 0.0:
        # All-zero (or empty) slice: avoid a degenerate zmin == zmax == 0 range.
        max_abs_val = 1.0
    return go.Heatmap(
        z=values.numpy(),
        colorscale=_ZERO_WHITE_COLORSCALE,
        zmid=0,  # zero maps to the middle (white) of the colour scale
        zmin=-max_abs_val,  # symmetric range so equal magnitudes get equal intensity
        zmax=max_abs_val,
        showscale=True,
        colorbar=dict(title='Activation Value', y=colorbar_y, len=0.35),
    )


def _print_activation_stats(label, values):
    """Print mean, std, max, min and the exact-zero fraction of an activation slice."""
    print(f"\n{label} Token Statistics:")
    if values.numel() == 0:
        print("No positions selected.")
        return
    print(f"Mean activation: {values.mean().item():.4f}")
    print(f"Std deviation: {values.std().item():.4f}")
    print(f"Max activation: {values.max().item():.4f}")
    print(f"Min activation: {values.min().item():.4f}")
    print(f"Sparsity: {(values == 0).float().mean().item() * 100:.2f}%")


def analyze_mlp_for_specific_tokens(model, input_tokens, output_pos, feature_pos, num_last_layers=10):
    """
    Plot and summarise MLP outputs at selected token positions for the last layers.

    For each of the last `num_last_layers` layers, the `blocks.{i}.hook_mlp_out`
    activation is captured. The rows at `output_pos` and `feature_pos` are shown
    as two heatmaps on a zero-centred blue/white/red colour scale, with zero in
    white. Summary statistics for each slice are then printed.

    Args:
        model: TransformerLens model to run with hooks.
        input_tokens: token ids passed to `model.run_with_hooks`. Only the first
            batch element is analysed (assumes the hook output is
            (batch, seq, hidden) — TODO confirm for other callers).
        output_pos: index/indices into the sequence axis for the output-number tokens.
        feature_pos: index/indices into the sequence axis for the feature-number tokens.
        num_last_layers: how many final layers to analyse.

    Raises:
        IndexError: if `output_pos` or `feature_pos` is out of range for the
            sequence in `input_tokens`.

    Side effects:
        Calls `model.set_use_hook_mlp_in(True)`, a setting that persists on `model`.
    """
    n_layers = model.cfg.n_layers
    analysed_layers = range(max(n_layers - num_last_layers, 0), n_layers)
    mlp_activations = {}

    def make_hook(layer_num):
        # Factory so each hook binds its own layer_num (a loop closure would
        # otherwise see only the final value).
        def mlp_hook(act, hook):
            mlp_activations[f'layer_{layer_num}'] = act.detach()
            return act
        return mlp_hook

    # Hook only the analysed layers; activations from earlier layers would be discarded.
    hooks = [
        (f'blocks.{layer_num}.hook_mlp_out', make_hook(layer_num))
        for layer_num in analysed_layers
    ]

    # NOTE(review): `hook_mlp_out` does not appear to need this flag. It is kept
    # because the setting persists on `model` and other cells may rely on it.
    model.set_use_hook_mlp_in(True)
    model.run_with_hooks(input_tokens, fwd_hooks=hooks)

    for layer_num in analysed_layers:
        layer_key = f'layer_{layer_num}'
        if layer_key not in mlp_activations:
            continue

        # Take the first batch element as float32 on the CPU. `.numpy()` cannot
        # convert bfloat16, and CPU indexing reports bad positions as an IndexError
        # (an out-of-range tensor index on CUDA can trigger a device-side assert).
        acts = mlp_activations[layer_key][0].float().cpu()
        output_acts = _select_positions(acts, output_pos, "Output")
        feature_acts = _select_positions(acts, feature_pos, "Feature")

        fig = make_subplots(
            rows=2, cols=1,
            subplot_titles=(
                f'Layer {layer_num} - Output Number Tokens',
                f'Layer {layer_num} - Feature Number Tokens'
            )
        )
        fig.add_trace(_symmetric_heatmap(output_acts, colorbar_y=0.85), row=1, col=1)
        fig.add_trace(_symmetric_heatmap(feature_acts, colorbar_y=0.35), row=2, col=1)

        fig.update_layout(
            height=800,
            width=1000,
            title_text=f"MLP Activations for Output and Feature Tokens - Layer {layer_num}"
        )
        for row in (1, 2):
            fig.update_xaxes(title_text="Hidden Dimension", row=row, col=1)
            fig.update_yaxes(title_text="Token Index", row=row, col=1)

        fig.show()

        print(f"\nLayer {layer_num} Statistics:")
        _print_activation_stats("Output", output_acts)
        _print_activation_stats("Feature", feature_acts)

# Get the positions and run the analysis
output_pos, feature_pos = check_token_positions(model, dataset, seq_len, print_info=False)
# NOTE(review): `output_pos`/`feature_pos` are presumably positions within the prompt
# that `check_token_positions` builds from `dataset` and `seq_len`. The tokens analysed
# below, however, are those of "HI MOM", which is a different and likely much shorter
# sequence. Indexing its activations with the prompt's positions will either fail with
# an out-of-range index or select unrelated tokens. This looks like a debugging
# leftover; the tokens of the prompt the positions were computed for should probably
# be passed instead — confirm.
tok = model.to_tokens("HI MOM")
analyze_mlp_for_specific_tokens(model, tok, output_pos, feature_pos, num_last_layers=10)