<a href="https://colab.research.google.com/github/JasonGross/guarantees-based-mechanistic-interpretability/blob/main/notebooks_jason/max_of_2_grokking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Exploring Grokking in a Very Simple Max of 2 Model

## Introduction

## Setup

In [75]:
# Install the project at a pinned commit so the notebook is reproducible.
# `%pip` (rather than `! pip`) installs into the environment of the running kernel.
%pip install git+https://github.com/JasonGross/guarantees-based-mechanistic-interpretability.git@e489d823133de86822ddd5634c7364507c9ac998

Collecting git+https://github.com/JasonGross/guarantees-based-mechanistic-interpretability.git@e489d823133de86822ddd5634c7364507c9ac998
  Cloning https://github.com/JasonGross/guarantees-based-mechanistic-interpretability.git (to revision e489d823133de86822ddd5634c7364507c9ac998) to /tmp/pip-req-build-_bhvzl4c
  Running command git clone --filter=blob:none --quiet https://github.com/JasonGross/guarantees-based-mechanistic-interpretability.git /tmp/pip-req-build-_bhvzl4c
  Running command git rev-parse -q --verify 'sha^e489d823133de86822ddd5634c7364507c9ac998'
  Running command git fetch -q https://github.com/JasonGross/guarantees-based-mechanistic-interpretability.git e489d823133de86822ddd5634c7364507c9ac998
  Running command git checkout -q e489d823133de86822ddd5634c7364507c9ac998
  Resolved https://github.com/JasonGross/guarantees-based-mechanistic-interpretability.git to commit e489d823133de86822ddd5634c7364507c9ac998
  Installing build dependencies ... [?25l[?25hdone
  Getting re

In [76]:
# Log in to Weights & Biases with an anonymous account
# (presumably so that saved runs/checkpoints can be fetched without credentials -- TODO confirm)
! wandb login --anonymously

[34m[1mwandb[0m: Currently logged in as: [33manony-moose-397685823686906646[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [77]:
from tqdm import tqdm
import math
import os
import imageio
from gbmi.exp_max_of_n.train import (
    FullDatasetCfg,
    MaxOfN,
    train_or_load_model,
)
from gbmi.model import Config, RunData
from transformer_lens import HookedTransformerConfig, HookedTransformer
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
import wandb
from jaxtyping import Float
from torch import Tensor
from typing import (
    Tuple,
    Dict,
    Optional,
    Any,
    List,
)

# Client for the wandb public API (not referenced in the cells visible in this part of the notebook)
api = wandb.Api()

## Introduction

Consider a 1 layer attention-only transformer with no normalization trained on inputs of the form $(a, b, =)$ (for $0 \le a, b < 64$) to predict $\max(a, b)$.  Inputs are one-hot encoded.

The training dataset is all sequences of the form $(a, a\pm i, =)$ for $i\in \{0, 1, 2, 17\}$.  17 is chosen to be a medium-size number coprime to 2 (and hence to the vocabulary size 64).

The `=` token is encoded as token $-1$, i.e. the last row of the embedding matrix, which has `vocab + 1` rows.

### Model configuration

In [78]:
seq_len = 2 # training data setup code only works for sequence length 2
vocab = 64 #@param {type:"number"}
d_head = 32 #@param {type:"number"}
d_model = 32 #@param {type:"number"}
model_seed = 613947648 #@param {type:"number"}
seed = 123 #@param {type:"number"}
force_adjacent = (0, 1, 2, 17) #@param
lr = 0.001 #@param {type:"number"}
betas = (0.9, 0.98) #@param
weight_decay = 1.0 #@param {type:"number"}
optimizer = "AdamW" #@param ["AdamW", "Adam"]
deterministic = True #@param {type:"boolean"}
# list out the number here explicitly so that it matches with what is saved in wandb
training_ratio = 0.099609375 #@param {type:"number"}
# Fraction of the vocab**seq_len possible inputs selected by force_adjacent:
# `vocab` pairs (a, a) when offset 0 is included, plus 2 * (vocab - i) ordered
# pairs (a, a + i) / (a + i, a) for every non-zero offset i.
expected_training_ratio = (
    (vocab if 0 in force_adjacent else 0)
    + 2 * sum(vocab - i for i in force_adjacent if i)
) / vocab**seq_len
if abs(training_ratio - expected_training_ratio) > 1e-5:
    # Actually emit the message: a bare string expression inside an `if` is
    # silently discarded, even in a notebook cell.
    print(
        f"training_ratio should probably be float.fromhex('{expected_training_ratio.hex()}') ({expected_training_ratio})"
    )
# One batch holds the entire training set
batch_size = int(round(training_ratio * vocab ** seq_len))
epochs_to_train_for = 3000 #@param {type:"number"}
include_biases = False #@param {type:"boolean"}
# Full experiment configuration: model architecture, datasets, optimizer and
# training/logging schedule.  It is passed to train_or_load_model below.
cfg = Config(
    experiment=MaxOfN(
        # 1-layer, 1-head, attention-only transformer without normalization
        model_config=HookedTransformerConfig(
            act_fn=None,
            attn_only=True,
            d_head=d_head,
            d_mlp=None,
            d_model=d_model,
            # one extra input token on top of the `vocab` numbers
            d_vocab=vocab + 1,
            d_vocab_out=vocab,
            default_prepend_bos=True,
            # pin to CPU when a deterministic run is requested; otherwise leave unset
            device="cpu" if deterministic else None,
            dtype=torch.float32,
            # the seq_len numbers plus one extra position
            n_ctx=seq_len + 1,
            n_heads=1,
            n_layers=1,
            normalization_type=None,
            seed=model_seed,
        ),
        zero_biases=not include_biases,
        use_log1p=True,
        use_end_of_sequence=True,
        # use the notebook-level value so it cannot drift from n_ctx / batch_size
        seq_len=seq_len,
        train_dataset_cfg=FullDatasetCfg(
            force_adjacent=force_adjacent,
            training_ratio=training_ratio,
        ),
        # NOTE(review): the test dataset config is identical to the training one
        test_dataset_cfg=FullDatasetCfg(
            force_adjacent=force_adjacent,
            training_ratio=training_ratio,
        ),
        optimizer_kwargs=dict(lr=lr, betas=betas, weight_decay=weight_decay),
        optimizer=optimizer,
    ),
    deterministic=deterministic,
    seed=seed,
    batch_size=batch_size,
    train_for=(epochs_to_train_for, "epochs"),
    log_every_n_steps=10,
    validate_every=(10, "epochs"),
    checkpoint_every=(10, "epochs"),
)


### Model Training / Loading

In [79]:
# Load the model or train it, as selected by the form field below
force = "load" #@param ["load", "train", "allow either"]
# The form value "allow either" is handed to train_or_load_model as force=None
force = None if force == "allow either" else force
runtime, model = train_or_load_model(cfg, force=force)

INFO: Seed set to 123
INFO:lightning.fabric.utilities.seed:Seed set to 123
[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [80]:
# Fetch the saved model versions and materialize them as a list
model_versions = runtime.model_versions(cfg, max_count=3000, step=1)
assert model_versions is not None
models = list(model_versions)

  0%|          | 0/301 [00:00<?, ?it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  0%|          | 1/301 [00:00<02:34,  1.94it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  1%|          | 2/301 [00:01<02:34,  1.94it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  1%|          | 3/301 [00:01<02:45,  1.80it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  1%|▏         | 4/301 [00:02<02:49,  1.75it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  2%|▏         | 5/301 [00:02<02:50,  1.73it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  2%|▏         | 6/301 [00:03<02:43,  1.80it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  2%|▏         | 7/301 [00:03<02:40,  1.84it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  3%|▎         | 8/301 [00:04<02:36,  1.87it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  3%|▎         | 9/301 [00:04<02:40,  1.82it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  3%|▎         | 10/301 [00:05<02:57,  1.64it/s]

## Basic Interpretation

The model works by attending to the largest element and copying that element.  Let's validate this with some basic plots.

In [85]:
#@title interpretation functions
@torch.no_grad()
def compute_QK(model: HookedTransformer = model) -> dict:
    """Pre-softmax attention score from the last token/position to each input token.

    The query is the last row of ``W_E`` plus the last row of ``W_pos``.  Keys
    are every other row of ``W_E`` plus the mean of the non-final rows of
    ``W_pos``.  The score of the query attending to itself is subtracted, so
    the returned values are relative to that self-attention score.

    Args:
        model: transformer whose ``W_E``, ``W_pos``, ``W_Q`` and ``W_K`` are
            read; only layer 0, head 0 is used.

    Returns:
        A dict with the 1-D score tensor under ``"data"`` (one entry per
        non-final row of ``W_E``) plus ``"title"``, ``"xaxis"`` and ``"yaxis"``
        strings for plotting.
    """
    W_E, W_pos, W_Q, W_K = (
        model.W_E,
        model.W_pos,
        model.W_Q,
        model.W_K,
    )
    last_token = W_E[-1] + W_pos[-1]
    # Shared prefix of both products below, evaluated left to right
    query_key = last_token @ W_Q[0, 0] @ W_K[0, 0].T
    QK = query_key @ (W_E[:-1] + W_pos[:-1].mean(dim=0, keepdim=True)).T
    # `last_token` is 1-D, so no transpose is needed (`.T` on a non-2-D tensor
    # is deprecated in PyTorch)
    QK_last = query_key @ last_token
    return {
        "data": QK - QK_last,
        "title": "Attention Score<br>QK[p] := (W<sub>E</sub>[-1] + W<sub>pos</sub>[-1]) @ W<sub>Q</sub> @ W<sub>K</sub><sup>T</sup> @ (W<sub>E</sub> + W<sub>pos</sub>[p])<sup>T</sup><br>QK[:-1,:-1].mean(dim=0) - QK[-1, -1]",
        "xaxis": "input token",
        "yaxis": "attention score pre-softmax",
    }


@torch.no_grad()
def compute_OV(model: HookedTransformer = model, centered: bool = True) -> dict:
    """Map each input token through the layer-0, head-0 OV circuit to output logits.

    Args:
        model: transformer whose ``W_E``, ``W_pos``, ``W_V``, ``W_O`` and
            ``W_U`` are read.
        centered: when true, each row has its own diagonal entry subtracted,
            so the diagonal of the returned matrix is zero.

    Returns:
        A dict with axis labels, the (input token x output logit) matrix under
        ``"data"`` and a matching ``"title"``.
    """
    mean_pos = model.W_pos[:-1].mean(dim=0)
    OV = (model.W_E[:-1] + mean_pos) @ model.W_V[0, 0] @ model.W_O[0, 0] @ model.W_U
    if centered:
        data = OV - OV.diag()[:, None]
        title = "Attention Computation (centered)<br>OV := (W<sub>E</sub>[:-1] + W<sub>pos</sub>[:-1].mean(dim=0)) @ W<sub>V</sub> @ W<sub>O</sub> @ W<sub>U</sub><br>OV - OV.diag()[:, None]"
    else:
        data = OV
        title = "Attention Computation: (W<sub>E</sub>[:-1] + W<sub>pos</sub>[:-1].mean(dim=0)) @ W<sub>V</sub> @ W<sub>O</sub> @ W<sub>U</sub>"
    return {
        "xaxis": "output logit token",
        "yaxis": "input token",
        "data": data,
        "title": title,
    }

@torch.no_grad()
def compute_QK_by_position(model: HookedTransformer = model) -> dict:
    """Positional part of the attention score from the last token/position.

    Uses the same query as ``(W_E[-1] + W_pos[-1]) @ W_Q``, but the keys are
    only the non-final rows of ``W_pos`` with their mean removed.

    Returns:
        A dict whose ``"data"`` maps the series name ``"QK"`` to a 1-D tensor
        with one score per non-final position, plus plot title and axis labels.
    """
    query_key = (
        (model.W_E[-1] + model.W_pos[-1]) @ model.W_Q[0, 0] @ model.W_K[0, 0].T
    )
    non_final_pos = model.W_pos[:-1]
    pos_deviation = non_final_pos - non_final_pos.mean(dim=0)
    return {
        "data": {"QK": query_key @ pos_deviation.T},
        "title": "Positional Contribution to Attention Score<br>(W<sub>E</sub>[-1] + W<sub>pos</sub>[-1]) @ W<sub>Q</sub> @ W<sub>K</sub><sup>T</sup> @ (W<sub>pos</sub>[:-1] - W<sub>pos</sub>[:-1].mean(dim=0))<sup>T</sup>",
        "xaxis": "position",
        "yaxis": "attention score pre-softmax",
    }

@torch.no_grad()
def compute_irrelevant(
    model: HookedTransformer = model, include_equals_OV: bool = False
) -> dict:
    """Collect the logit contributions that the notebook labels "irrelevant".

    The series, in insertion order, are:

    * the last token/position embedding sent straight through ``W_U``;
    * optionally (``include_equals_OV``) that same embedding sent through the
      layer-0, head-0 OV circuit and ``W_U``;
    * for each non-final position ``i``, the deviation of ``W_pos[i]`` from the
      mean non-final positional embedding, sent through the OV circuit and
      ``W_U``.

    Returns:
        A dict with ``"data"`` (series label -> 1-D logit tensor), ``"title"``
        and axis labels.
    """
    W_E, W_pos, W_U = model.W_E, model.W_pos, model.W_U
    W_V, W_O = model.W_V[0, 0], model.W_O[0, 0]
    last_token = W_E[-1] + W_pos[-1]
    mean_pos = W_pos[:-1].mean(dim=0)

    data = {
        "(W<sub>E</sub>[-1]+W<sub>pos</sub>[-1]) @ W<sub>U</sub>": last_token @ W_U,
    }
    if include_equals_OV:
        data[
            "(W<sub>E</sub>[-1]+W<sub>pos</sub>[-1]) @ W<sub>V</sub> @ W<sub>O</sub> @ W<sub>U</sub>"
        ] = (last_token @ W_V @ W_O @ W_U)
    for i in range(W_pos.shape[0] - 1):
        label = f"(W<sub>pos</sub>[{i}] - W<sub>pos</sub>[:-1].mean(dim=0)) @ W<sub>V</sub> @ W<sub>O</sub> @ W<sub>U</sub>"
        data[label] = (W_pos[i] - mean_pos) @ W_V @ W_O @ W_U

    return {
        "data": data,
        "title": "Irrelevant Contributions to logits",
        "xaxis": "output logit token",
        "yaxis": "logit value",
    }

In [86]:
#@title display basic interpretation

@torch.no_grad()
def display_basic_interpretation(
    model: HookedTransformer = model,
    include_uncentered: bool = False,
    legend_at_bottom: bool = False,
    include_equals_OV: bool = False,
):
    """Show the basic interpretation plots for ``model``.

    In order: the attention-score line plot, the OV heatmap(s), the positional
    attention-score scatter plot, and the scatter plot of "irrelevant" logit
    contributions.

    Args:
        model: the transformer to analyse.
        include_uncentered: also show the uncentered OV heatmap (drawn before
            the centered one, with the colour scale midpoint fixed at 0).
        legend_at_bottom: lay the legend of the last plot out horizontally
            below the axes.
        include_equals_OV: forwarded to ``compute_irrelevant``.
    """

    def series_labels(spec: dict) -> dict:
        # Axis labels for px.line / px.scatter built from a compute_* result
        return {"index": spec["xaxis"], "variable": "", "value": spec["yaxis"]}

    attention = compute_QK(model)
    px.line(
        {"QK": attention["data"]},
        title=attention["title"],
        labels=series_labels(attention),
    ).show()

    # (centered?, extra imshow kwargs) for every heatmap, in display order
    heatmap_variants = [(True, {})]
    if include_uncentered:
        heatmap_variants.insert(0, (False, {"color_continuous_midpoint": 0}))
    for centered, extra_kwargs in heatmap_variants:
        ov = compute_OV(model, centered=centered)
        px.imshow(
            ov["data"],
            title=ov["title"],
            color_continuous_scale="Picnic_r",
            labels={"x": ov["xaxis"], "y": ov["yaxis"]},
            **extra_kwargs,
        ).show()

    positional = compute_QK_by_position(model)
    px.scatter(
        positional["data"],
        title=positional["title"],
        labels=series_labels(positional),
    ).show()

    leftovers = compute_irrelevant(model, include_equals_OV=include_equals_OV)
    fig = px.scatter(
        leftovers["data"],
        title=leftovers["title"],
        labels=series_labels(leftovers),
    )
    if legend_at_bottom:
        fig.update_layout(
            legend={
                "orientation": "h",
                "yanchor": "bottom",
                "y": -0.5,
                "xanchor": "center",
                "x": 0.5,
            }
        )
    fig.show()

In [87]:
# Show the interpretation plots for the loaded model, using the default options
display_basic_interpretation(model)

## Plotting the Training

Let's plot this analysis, along with the loss and accuracy, across training.

In [84]:
#@title precompute loss and accuracy lists
def group_metrics_by_epoch(runtime: RunData) -> Dict[str, Dict[int, Any]]:
    """Regroup the run's training metric records by metric name and epoch.

    Args:
        runtime: run data whose ``train_metrics`` is iterated; a falsy value
            (e.g. ``None``) is treated as "no records".

    Returns:
        ``{metric name: {epoch: value}}``.  The ``"epoch"`` and ``"step"``
        entries of each record are not treated as metrics.  Tensor values are
        converted with ``.item()``.  If several records share an epoch, the
        last one wins.
    """
    result: Dict[str, Dict[int, Any]] = {}
    for metric in runtime.train_metrics or []:
        epoch = metric["epoch"]
        for k, v in metric.items():
            if k not in ("epoch", "step"):
                result.setdefault(k, {})[epoch] = (
                    v.item() if isinstance(v, torch.Tensor) else v
                )
    return result

# Training metrics of the loaded run, grouped as {metric name: {epoch: value}}
metrics = group_metrics_by_epoch(runtime)

def get_epochs_and_metric(
    metric_name: str,
    epoch: Optional[int],
    metrics: Dict[str, Dict[int, float]] = metrics,
) -> Tuple[List[int], List[Any]]:
    """Return one metric as parallel, epoch-sorted lists.

    Args:
        metric_name: key into ``metrics``.
        epoch: inclusive upper bound on the epochs returned; ``None`` means no
            bound.
        metrics: mapping ``{metric name: {epoch: value}}``.

    Returns:
        ``(epochs, values)`` where ``epochs`` is sorted ascending and
        ``values[j]`` is the metric value recorded at ``epochs[j]``.
    """
    series = metrics[metric_name]
    if epoch is None:
        selected = sorted(series)
    else:
        selected = sorted(e for e in series if e <= epoch)
    return selected, [series[e] for e in selected]


In [74]:
#@title plot
# Three stacked panels: row 1 = attention, row 2 = losses, row 3 = accuracies
fig = make_subplots(
    cols=1,
    rows=3,
    subplot_titles=("Attention Plot", "Loss Plot", "Accuracy Plot"),
)

# Animation frames and the slider steps that select them
frames, slider_steps = [], []

# Largest value seen so far in each panel (used for the y-axis ranges)
max_abs_value_attention = max_value_losses = max_value_accuracies = 0
# History of those maxima, one entry per processed model version
all_max_abs_value_attention, all_max_value_losses, all_max_value_accuracies = (
    [],
    [],
    [],
)

# Subplot row of each animated trace, in trace-index order:
# 0 = attention, 1-2 = training/test loss, 3-4 = training/test accuracy.
_TRACE_ROWS = (1, 2, 2, 3, 3)


def _make_animation_traces(attention_scores, metric_series):
    """Build a fresh set of the five animated traces.

    Args:
        attention_scores: list of pre-softmax attention scores, one per input token.
        metric_series: four ``(name, epochs, values)`` triples ordered as
            training loss, test loss, training accuracy, test accuracy.

    Returns:
        A list of ``go.Scatter`` traces ordered like ``_TRACE_ROWS``.
    """
    traces = [
        go.Scatter(
            x=list(range(len(attention_scores))),
            y=attention_scores,
            mode="lines",
            name="(E+P)<sub>-1</sub>QK<sup>T</sup>(E+P)<sup>T</sup>",
        )
    ]
    traces.extend(
        go.Scatter(x=epochs, y=values, mode="lines", name=name)
        for name, epochs, values in metric_series
    )
    return traces


# Build one animation frame (and one slider step) per saved model version
with torch.no_grad():
    for i, (_version, old_data, _artifact) in enumerate(models):
        assert old_data is not None
        old_runtime, old_model = old_data
        epoch = old_runtime.epoch
        W_E, W_pos, W_Q, W_K = (
            old_model.W_E,
            old_model.W_pos,
            old_model.W_Q[0, 0, :, :],
            old_model.W_K[0, 0, :, :],
        )
        # Attention score from the last token/position to every input token.
        # Keys use the mean positional embedding *vector* of the non-final
        # positions, as in compute_QK.  (Indexing W_pos[:-1, 0] would instead
        # add a single scalar taken from embedding dimension 0.)
        overlap = (
            (W_E[-1] + W_pos[-1])
            @ W_Q
            @ W_K.T
            @ (W_E[:-1, :] + W_pos[:-1].mean(dim=0)).T
        )
        # Plain Python floats for plotting
        overlap_values = overlap.tolist()

        # Update the max_abs_value for the attention plot
        current_max_attention = torch.max(torch.abs(overlap)).item()
        max_abs_value_attention = max(max_abs_value_attention, current_max_attention)

        # Metric histories up to and including this checkpoint's epoch
        training_losses_epochs, training_losses = get_epochs_and_metric("loss", epoch)
        training_accuracies_epochs, training_accuracies = get_epochs_and_metric(
            "acc", epoch
        )
        test_losses_epochs, test_losses = get_epochs_and_metric(
            "periodic_test_loss", epoch
        )
        test_accuracies_epochs, test_accuracies = get_epochs_and_metric(
            "periodic_test_acc", epoch
        )
        metric_series = [
            ("Training Loss", training_losses_epochs, training_losses),
            ("Test Loss", test_losses_epochs, test_losses),
            ("Training Accuracy", training_accuracies_epochs, training_accuracies),
            ("Test Accuracy", test_accuracies_epochs, test_accuracies),
        ]

        # Update the max_value for the loss and accuracy plots; `default=0`
        # keeps this working when a metric has no value yet at this epoch
        max_value_losses = max(
            max(training_losses, default=0), max(test_losses, default=0)
        )
        max_value_accuracies = max(
            max(training_accuracies, default=0), max(test_accuracies, default=0)
        )

        # Update the max values for all plots
        all_max_abs_value_attention.append(max_abs_value_attention)
        all_max_value_losses.append(max_value_losses)
        all_max_value_accuracies.append(max_value_accuracies)

        # The first checkpoint also provides the figure's initial traces
        if i == 0:
            for trace, row in zip(
                _make_animation_traces(overlap_values, metric_series), _TRACE_ROWS
            ):
                fig.add_trace(trace, row=row, col=1)

        # Create a frame combining all plots
        frame = go.Frame(
            data=_make_animation_traces(overlap_values, metric_series),
            name=str(epoch),
            # One index per trace in the figure (there are exactly five)
            traces=list(range(len(_TRACE_ROWS))),
            layout=go.Layout(
                yaxis={
                    "range": [-max_abs_value_attention, max_abs_value_attention]
                },  # Attention plot
                yaxis2={"range": [0, max_value_losses]},  # Loss plot
                yaxis3={"range": [0, max_value_accuracies]},  # Accuracy plot
            ),
        )
        frames.append(frame)

        # Add a step to the slider
        slider_step = dict(
            method="animate",
            args=[
                [str(epoch)],
                {
                    "frame": {"duration": 0, "redraw": True},
                    "mode": "immediate",
                    "transition": {"duration": 0},
                },
            ],
            label=str(epoch),
        )
        slider_steps.append(slider_step)

# Attach the per-checkpoint frames built above
fig.frames = frames

# "Play" steps through all frames starting from the current one
play_button = {
    "label": "Play",
    "method": "animate",
    "args": [
        None,
        {
            "frame": {"duration": 500, "redraw": True},
            "fromcurrent": True,
            "transition": {"duration": 300, "easing": "linear"},
            "mode": "immediate",
            "repeat": True,
        },
    ],
}
# "Pause" animates to the frame list [None] with zero duration
pause_button = {
    "label": "Pause",
    "method": "animate",
    "args": [
        [None],
        {
            "frame": {"duration": 0, "redraw": False},
            "mode": "immediate",
            "transition": {"duration": 0},
        },
    ],
}

# Titles, animation controls and the epoch slider
fig.update_layout(
    title="Model Analysis: Attention, Losses, and Accuracies Over Epochs",
    xaxis_title="Input Token",
    xaxis2_title="Epoch",
    xaxis3_title="Epoch",
    updatemenus=[
        {
            "type": "buttons",
            "showactive": False,
            "buttons": [play_button, pause_button],
        }
    ],
    sliders=[{"steps": slider_steps, "active": 0}],
)

# Render the animated figure
fig.show()