## Description of this notebook
- This notebook assumes access to my WandB API key and Dashiell's AWS bucket.
- This is a 5-state (10 total), 3-layer, 4-heads-per-layer `HookedTransformer` with layernorm and MLPs.
- WandB run this pertains to (private): https://wandb.ai/wu-cindyx/devinterp-automata/runs/89pbzt28?nw=nwuserwucindyx.

This contains:
- Activation distribution PCA code and plots (here we just tried resid_mid and mlp_post, but all others are testable)
- Attention patterns for all heads in all layers on random inputs
    - These are done for all forms and also at end of the training process: we have 4 forms and a final model. model_0 is a comparison for very early on in training.
- Attention patterns for all heads in all layers on hand-crafted examples for sequences of all 0s, all 1s, all 0s with a 1 injected at a single position, and the inverse of the latter (all 1s with a 0 injected at a single position)
- Positional embedding and embedding patching, including dot product of columns of W_pos with itself
- OV circuit analysis, including eigenvectors to search for copying circuits

In [None]:
# Detect whether the notebook is running inside Google Colab: the
# `google.colab` module is only importable there.
try:
    import google.colab # type: ignore
    IN_COLAB = True
except ImportError:
    # Only a failed import means "not in Colab"; anything else (e.g. a
    # KeyboardInterrupt) should propagate instead of being swallowed.
    IN_COLAB = False

import os, sys

if IN_COLAB:
    # Code to download the necessary files (e.g. solutions, test funcs)
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")

         # Install packages
        %pip install einops
        %pip install jaxtyping
        %pip install transformer_lens
        %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
        %pip install s3fs
        %pip install omegaconf
        %pip install git+https://github.com/CindyXWu/devinterp-automata.git
        %pip install torch-ema

        !curl -o /content/main.zip https://codeload.github.com/CindyXWu/devinterp-automata/zip/refs/heads/main
        !unzip -o /content/main.zip -d /content/

        sys.path.append("/content/devinterp-automata/")
        os.remove("/content/main.zip")

        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

    CHAPTER = r"chapter1_transformers"
    CHAPTER_DIR = r"./" if CHAPTER in os.listdir() else os.getcwd().split(CHAPTER)[0]
    EXERCISES_DIR = CHAPTER_DIR + f"{CHAPTER}/exercises"
    sys.path.append(EXERCISES_DIR)

In [None]:
from dotenv import load_dotenv
import plotly.express as px
from typing import List, Union, Optional, Dict, Tuple
from jaxtyping import Int, Float

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path
import numpy as np
import pandas as pd
import einops
import re
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.utils import to_numpy

import circuitsvis as cv
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# For Dashiell's groups code
from copy import deepcopy
from functools import reduce
from itertools import product
import math
import numpy as np
from operator import mul
import torch

torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAIN = __name__ == "__main__"

import wandb
from pathlib import Path
import os
import yaml
import s3fs
from omegaconf import OmegaConf

from di_automata.config_setup import *
from di_automata.constructors import (
    construct_model,
    create_dataloader_hf,
)
from di_automata.tasks.data_utils import take_n
import plotly.io as pio

# AWS
load_dotenv()
AWS_KEY, AWS_SECRET = os.getenv("AWS_KEY"), os.getenv("AWS_SECRET")
s3 = s3fs.S3FileSystem(key=AWS_KEY, secret=AWS_SECRET)

In [None]:
from di_automata.interp_utils import (
    imshow_attention,
    line,
    scatter,
    imshow,
    reorder_list_in_plotly_way,
    get_pca,
    get_vars,
    plot_tensor_heatmap,
    get_activations,
    LN_hook_names,
    get_ln_fit,
    cos_sim_with_MLP_weights,
    avg_squared_cos_sim,
    hook_fn_display_attn_patterns,
    hook_fn_patch_qk,
)

from di_automata.tasks.dashiell_groups import (
    DihedralElement,
    DihedralIrrep, 
    ProductDihedralIrrep,
    dihedral_conjugacy_classes, 
    generate_subgroup,
    actions_to_labels,
    get_all_bits,
    dihedral_fourier,
    get_fourier_spectrum,
    analyse_power,
)

In [None]:
group = DihedralElement.full_group(5)

In [None]:
# Map each (first, second) pair to a single integer label 0-9:
# pairs with second component 0 get labels 0-4, pairs with second component 1
# get labels 5-9. Insertion order matches the hand-written table it replaces.
translation = {
    (first, second): first + 5 * second
    for second in range(2)
    for first in range(5)
}

In [None]:
def imshow_attention(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a red/blue heatmap whose colour scale is centred on 0.

    Defining this here replaces the `imshow_attention` imported from
    `di_automata.interp_utils` earlier in the notebook.

    Args:
        tensor: values to plot; converted with `utils.to_numpy`.
        renderer: passed to the figure's `.show()`.
        xaxis: label for the x axis.
        yaxis: label for the y axis.
        **kwargs: forwarded to `px.imshow`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Draw `tensor` as a line plot.

    Defining this here replaces the `line` imported from
    `di_automata.interp_utils` earlier in the notebook.

    Args:
        tensor: values to plot; converted with `utils.to_numpy`.
        renderer: passed to the figure's `.show()`.
        xaxis: label for the x axis.
        yaxis: label for the y axis.
        **kwargs: forwarded to `px.line`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    fig = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Draw a scatter plot of `y` against `x`.

    Defining this here replaces the `scatter` imported from
    `di_automata.interp_utils` earlier in the notebook.

    Args:
        x: x coordinates; converted with `utils.to_numpy`.
        y: y coordinates; converted with `utils.to_numpy`.
        xaxis: label for the x axis.
        yaxis: label for the y axis.
        caxis: label for the colour axis.
        renderer: passed to the figure's `.show()`.
        **kwargs: forwarded to `px.scatter`.
    """
    x_np = utils.to_numpy(x)
    y_np = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(x=x_np, y=y_np, labels=axis_labels, **kwargs)
    fig.show(renderer)


In [None]:
DI_ROOT = Path("/content/devinterp-automata-main/") if IN_COLAB else Path("../")
config_file_path = DI_ROOT / f"scripts/configs/slt_config.yaml"
slt_config = OmegaConf.load(config_file_path)

with open(DI_ROOT / f"scripts/configs/task_config/{slt_config.dataset_type}.yaml", 'r') as file:
    task_config = yaml.safe_load(file)

In [None]:
OmegaConf.set_struct(slt_config, False) # Allow new configuration values to be added
# Because we are in Colab and not VSCode, here is where you want to edit your config values
slt_config["task_config"] = task_config
slt_config["lr"] = 0.0005
slt_config["num_training_iter"] = 100000
slt_config["n_layers"] = 3

# Convert OmegaConf object to MainConfig Pydantic model for dynamic type validation - NECESSARY DO NOT SKIP
pydantic_config = PostRunSLTConfig(**slt_config)
# Convert back to OmegaConf object for compatibility with existing code
slt_config = OmegaConf.create(pydantic_config.model_dump())

print(task_config["dataset_type"])

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Run path and name for easy referral later
run_path = f"{slt_config.entity_name}/{slt_config.wandb_project_name}"
run_name = slt_config.run_name
print(run_name)

In [None]:
# Get run information
api = wandb.Api(timeout=3000)
run_list = api.runs(
    path=run_path,
    filters={
        "display_name": run_name,
        "state": "finished",
        },
    order="created_at", # Default descending order so backwards in time
)
assert run_list, f"Specified run {run_name} does not exist"
run_api = run_list[slt_config.run_idx]

# `history` is handled both as a method and as a plain attribute. Check which
# one it is explicitly rather than catching every exception from the call, so
# that a genuine API/network failure is reported as such instead of surfacing
# later as a confusing "not subscriptable" error.
history_attr = run_api.history
history = history_attr() if callable(history_attr) else history_attr

loss_history = history["Train Loss"]
accuracy_history = history["Train Acc"]
steps = history["_step"]
# Timestamp stored in the run config; used below to build artifact/file names.
time = run_api.config["time"]

In [None]:
def get_config() -> MainConfig:
    """Download the run's `config` artifact and load its `config.yaml`.

    The config is fetched from the manually logged artifact because WandB's
    automatic per-run config does not log enums correctly.

    Uses the notebook-level `api`, `run_path`, `run_name` and `time`.

    Returns:
        Whatever `OmegaConf.load` produces for the downloaded `config.yaml`.
    """
    config_artifact = api.artifact(f"{run_path}/config:{run_name}_{time}")
    download_dir = Path(config_artifact.download())
    return OmegaConf.load(download_dir / "config.yaml")

In [None]:
config = get_config()

# Set total number of unique samples seen (n). If this is not done it will break LLC estimator.
slt_config.rlct_config.sgld_kwargs.num_samples = slt_config.rlct_config.num_samples = config.rlct_config.sgld_kwargs.num_samples
slt_config.nano_gpt_config = config.nano_gpt_config

In [None]:
def restore_state_single_cp(cp_idx: int) -> dict:
    """Restore model state from a single checkpoint.
    Used in _load_logits_states() and _calculate_rlct().

    Args:
        idx_cp: index of checkpoint.

    Returns:
        model state dictionary.
    """
    idx = cp_idx * config.rlct_config.ed_config.eval_frequency * slt_config.skip_cps
    print(f"Getting checkpoint {idx}")
    print(config.model_save_method)
    match config.model_save_method:
        case "wandb":
            artifact = api.artifact(f"{run_path}/states:idx{idx}_{run_name}_{time}")
            data_dir = artifact.download()
            state_path = Path(data_dir) / f"states_{idx}.torch"
            states = torch.load(state_path)
        case "aws":
            with s3.open(f'{config.aws_bucket}/{run_name}_{time}/states_{idx}.pth', mode='rb') as file:
                states = torch.load(file, map_location=device)
    return states["model"]


def load_logits_single_cp(cp_idx: int) -> None:
    """Load just a single cp. 
    This function is designed to be called in multithreading and is called by the above function.
    """
    idx = cp_idx * config.rlct_config.ed_config.eval_frequency * slt_config.skip_cps

    try:
        match config.model_save_method:
            case "wandb":
                artifact = api.artifact(f"{run_path}/logits:logits_cp_{idx}_{run_name}_{time}")
                data_dir = artifact.download()
                logit_path = Path(data_dir) / f"logits_cp_{idx}.torch"
                return torch.load(logit_path)
            case "aws":
                with s3.open(f'{config.aws_bucket}/{run_name}_{time}/logits_cp_{idx}.pth', mode='rb') as file:
                    return torch.load(file)
                
    except Exception as e:
        print(f"Error fetching logits at step {idx}: {e}")

In [None]:
current_directory = Path().absolute()
logits_file_path = current_directory.parent / f"di_automata/logits_{run_name}_{time}"
print(logits_file_path)

In [None]:
ed_loader = create_dataloader_hf(config, deterministic=True) # Make sure deterministic to see same data

### Functions to display attention

In [None]:
def display_layer_heads(att, batch_idx=0, layer=0):
    """For generic inputs, display attention for particular index in batch.

    Reads the notebook-level `inputs` and `labels` tensors for the token
    strings and for the printed input/label rows.

    Args:
        att: attention patterns indexed as [batch, head, ...]; the slice for
            `batch_idx` is passed to CircuitsVis.
        batch_idx: which element of the batch to display.
        layer: layer number used only in the head names ("L{layer}H{i}").
            The default 0 reproduces the previous fixed "L0H{i}" names.
    """
    sample_att = att[batch_idx, ...]
    # Take the head count from the tensor instead of assuming four heads.
    n_heads = sample_att.shape[0]
    display(cv.attention.attention_patterns(
        tokens=list_of_strings(inputs[batch_idx, ...]),
        attention=sample_att,
        attention_head_names=[f"L{layer}H{i}" for i in range(n_heads)],
    ))
    # 0 is toggle action
    # 1 is drive action
    print(inputs[batch_idx, ...])
    print(labels[batch_idx, ...])


def list_of_strings(tensor):
    """Convert a tensor into a (nested) list of the string form of its elements.

    The tensor is detached and moved to the CPU first, so tensors that live on
    an accelerator or that require grad are accepted as well.

    Args:
        tensor: a torch tensor of any shape.

    Returns:
        A list (nested to match the tensor's shape) of strings.
    """
    return tensor.detach().cpu().numpy().astype(str).tolist()


def display_layer_heads_batch(att: torch.Tensor, cache: ActivationCache, toks: list[str]):
    """Call `cv.attention.from_cache` on `cache` for batch indices 0-9.

    TODO: refactor

    NOTE(review): `att` is accepted but never used in the body.
    NOTE(review): `format_sequence` and `dataset` are not defined in the
    visible part of this notebook; unless a star import provides them, the
    `batch_labels` callback will raise NameError when evaluated — confirm.
    NOTE(review): the value returned by `cv.attention.from_cache` is discarded
    and this function returns None.

    Args:
        att: unused.
        cache: activation cache handed to CircuitsVis.
        toks: tokens handed to CircuitsVis.
    """
    cv.attention.from_cache(
      cache = cache,
      tokens = toks,
      # Hard-coded to the first ten batch elements.
      batch_idx = list(range(10)),
      attention_type = "info-weighted",
      radioitems = True,
      return_mode = "view",
      batch_labels = lambda batch_idx, str_tok_list: format_sequence(str_tok_list, dataset.str_tok_labels[batch_idx]),
      mode = "small",
    )

# Get checkpoints

In [None]:
cp_idxs = [20, 220, 520, 855, 1150, 1500]

In [None]:
def _model_from_checkpoint(cp_idx):
    """Build a fresh model and load the weights saved at checkpoint `cp_idx`.

    Returns:
        (state, model): the restored state dict and the model it was loaded into.
    """
    state = restore_state_single_cp(cp_idx)
    model, _ = construct_model(config)
    model.load_state_dict(state)
    return state, model


# `cp_idxs` is the single source of the checkpoint indices, in this order:
# pre-form, form 1, form 2, form 3, form 4, end of training.
cp_idx_0, cp_idx_1, cp_idx_2, cp_idx_3, cp_idx_4, cp_idx_5 = cp_idxs

# Pre-form
state_0, model_0 = _model_from_checkpoint(cp_idx_0)

# Form 1
state_1, model_1 = _model_from_checkpoint(cp_idx_1)

# Form 2
state_2, model_2 = _model_from_checkpoint(cp_idx_2)

# Form 3
state_3, model_3 = _model_from_checkpoint(cp_idx_3)

# Form 4
state_4, model_4 = _model_from_checkpoint(cp_idx_4)

# End
state_5, model_5 = _model_from_checkpoint(cp_idx_5)

Inspect model architecture

In [None]:
model_5

# Activation distribution PCA

In [None]:
from plotnine import ggplot, aes, geom_histogram, geom_point, facet_wrap, scale_x_log10, ggtitle
import itertools
from IPython.display import display
import einops
import polars as pl

In [None]:
data_all_bits = torch.asarray(get_all_bits(16))

In [None]:
actions_to_labels(torch.tensor([0,1,1,0,1,0,0,0,1,1,1]), translation=translation)

In [None]:
all_labels = torch.stack([actions_to_labels(tensor, translation=translation) for tensor in data_all_bits], dim=1)
labels = all_labels[15, :]
label_df = pl.DataFrame(labels.detach().cpu().numpy(), schema=['label'])

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available")
elif torch.backends.mps.is_available():
    print("CUDA not available, but MPS is available.")
    device = torch.device("mps")
else:
    print("CUDA and MPS not available. Using CPU.")
    device = torch.device("cpu")

In [None]:
# Run every input sequence through one checkpoint in batches of 512 and keep
# the resid_mid and MLP post-activation caches of blocks 1 and 2 on the CPU.
splits = torch.split(data_all_bits, 512, dim=0)
resid_mids = []
resid2_mids = []
mlp_posts = []
mlp2_posts = []

# Move and run the SAME model. Previously model_5 was moved to the device
# while model_3 was the one actually run; model_3 is kept here so the
# activations analysed below are unchanged.
# NOTE(review): confirm whether form 3 (model_3) or the final model (model_5)
# is the intended subject of this PCA analysis.
pca_model = model_3
pca_model.to(device)
for batch in tqdm(splits):
    with torch.no_grad():
        logits_all, cache_all = pca_model.run_with_cache(batch.to(device))
        resid_mids.append(cache_all["resid_mid", 1].detach().cpu())
        resid2_mids.append(cache_all["resid_mid", 2].detach().cpu())
        mlp_posts.append(cache_all["post", 1, "mlp"].detach().cpu())
        mlp2_posts.append(cache_all["post", 2, "mlp"].detach().cpu())

Length of all the lists above is 128, since that's how many batches of unique samples we have (aka, 2^16/512)

In [None]:
len(resid_mids)

Dimensions of each element of list will be batch size, sequence length, resid dim/mlp dim

In [None]:
resid_mids[0].shape

In [None]:
resid_mids[0].shape
mlp_posts[0].shape

In [None]:
concat_resids = torch.concatenate(resid_mids, dim=0)
concat_resids2 = torch.concatenate(resid2_mids, dim=0)
concat_mlp_posts = torch.concatenate(mlp_posts, dim=0)
concat_mlp2_posts = torch.concatenate(mlp2_posts, dim=0)

Now that we have concatenated the list elements together we get rid of the batch dimension and end up with all 2^16 examples.

In [None]:
concat_mlp_posts.shape

We will focus on position 16 (for no reason at all other than it's the last position!).

In [None]:
# Get position 16 in sequence (final token)
resid_pos16 = concat_resids[:, 15, :]
resid2_pos16 = concat_resids2[:, 15, :]
mlp_pos16 = concat_mlp_posts[:, 15, :]
mlp2_pos16 = concat_mlp2_posts[:, 15, :]

We got rid of the second dimension, the sequence dimension, by indexing number 15. Nice.

In [None]:
resid_pos16.shape

Now we can also get rid of samples by averaging.

In [None]:
mlp_pos16.mean(dim=0).shape

Now for fun part: we want to get the PCA and explained variance of the components of each of these activations at a particular position. Let's start with the MLP post from layer 3.

In [None]:
U_mlp, S_mlp, V_mlp = get_pca(mlp2_pos16)
get_vars(S_mlp)

This is incredibly un-sparse! It basically wants all the available PCA directions for 100% explained variance. Now let's plot the components against each other.

In [None]:
mean = mlp2_pos16.mean(dim=0)
nonzero = torch.nonzero(mean).squeeze()
mlp2_reduced = mlp2_pos16[:, nonzero] @ V_mlp
mlp2_posts_df = pl.DataFrame(mlp2_reduced.detach().cpu().numpy(), schema=[str(i) for i in range(mlp2_reduced.shape[1])])
mlp2_posts_df = pl.concat([label_df, mlp2_posts_df], how='horizontal')

combs = itertools.combinations(range(5), r=2)
for c in combs:
    plot = ggplot(mlp2_posts_df, aes(x=str(c[0]), y=str(c[1]), color='factor(label)')) + geom_point() + facet_wrap('~label')
    display(plot)

In [None]:
#ggplot(df.filter(pl.col('variable').is_in(['0', '1', '2', '3'])), aes(x='value', fill='factor(label)')) + geom_histogram() + facet_wrap('~variable')
# mlp2_posts_df.filter(pl.col('0') < -100)

In [None]:
U_resid, S_resid, V_resid = get_pca(resid_pos16)
get_vars(S_resid)

In [None]:
mean = resid_pos16.mean(dim=0)
nonzero = torch.nonzero(mean).squeeze()
resid_reduced = resid_pos16[:, nonzero] @ V_resid
resid_df = pl.DataFrame(resid_reduced.detach().cpu().numpy(), schema=[str(i) for i in range(resid_reduced.shape[1])])
resid_df = pl.concat([label_df, resid_df], how='horizontal')

combs = itertools.combinations(range(5), r=2)
for c in combs:
    plot = ggplot(resid_df, aes(x=str(c[0]), y=str(c[1]), color='factor(label)')) + geom_point() + facet_wrap('~label')
    display(plot)

In [None]:
U_resid2, S_resid2, V_resid2 = get_pca(resid2_pos16)
get_vars(S_resid2)

In [None]:
mean = resid2_pos16.mean(dim=0)
nonzero = torch.nonzero(mean).squeeze()
resid2_reduced = resid2_pos16[:, nonzero] @ V_resid2
resid2_df = pl.DataFrame(resid2_reduced.detach().cpu().numpy(), schema=[str(i) for i in range(resid2_reduced.shape[1])])
resid2_df = pl.concat([label_df, resid2_df], how='horizontal')

combs = itertools.combinations(range(5), r=2)
for c in combs:
    plot = ggplot(resid2_df, aes(x=str(c[0]), y=str(c[1]), color='factor(label)')) + geom_point() + facet_wrap('~label')
    display(plot)

# Trial inputs

In [None]:
# Hand-crafted length-25 input sequences and their labels.
all1 = torch.ones((25), dtype=torch.int32)
all0 = torch.zeros((25), dtype=torch.int32)
# Label for all ones: running count of ones so far, modulo 5.
all1_label = (torch.cumsum(all1, dim=0)) % 5
# Label for all zeros: alternates 5, 0, 5, 0, ... starting (and ending) with 5.
all0_label = torch.tensor([5,0]*12+[5], dtype=torch.int32)
print(all1_label)
print(all0_label)

# All zeros except a single one at one position
all_zero_except1 = deepcopy(all0)
all_zero_except1[8] = 1
all_zero_except1_label = actions_to_labels(all_zero_except1, dtype="int", translation=translation)
print(all_zero_except1_label)


# All ones except zeros at two positions (0 and 10)
all_one_except1 = deepcopy(all1)
all_one_except1[0] = 0
all_one_except1[10] = 0
all_one_except1_label = actions_to_labels(all_one_except1, dtype="int", translation=translation)
print(all_one_except1_label)

In [None]:
alternating = torch.tensor([1,0]*12+[1], dtype=torch.int32)
alternating_label = actions_to_labels(alternating, dtype="int", translation=translation)
alternating_label

### Early in training

In [None]:
l_all1_0, cache_all1_0 = model_0.run_with_cache(all1)
l_all0_0, cache_all0_0 = model_0.run_with_cache(all0)
pred_all1_0 = torch.argmax(l_all1_0, dim=-1).squeeze().cpu()
pred_all0_0 = torch.argmax(l_all0_0, dim=-1).squeeze().cpu()
print(pred_all0_0.shape)
print("all one labels", all1_label)
print("predicted all ones", pred_all1_0)
print("all zero labels", all0_label)
print("predicted all zeros", pred_all0_0)

l_all_zero_except1_0, c_zero_except1_0 = model_0.run_with_cache(all_zero_except1)
l_all_one_except1_0 , c_all_one_except1_0 = model_0.run_with_cache(all_one_except1)
pred_all_zero_except1_0 = torch.argmax(l_all_zero_except1_0, dim=-1).squeeze().cpu()
pred_all_one_except1_0 = torch.argmax(l_all_one_except1_0, dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_0)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_0)

l_alt_0, cache_alt_0 = model_0.run_with_cache(alternating)
pred_alt_0 = torch.argmax(l_alt_0, dim=-1).squeeze().cpu()
print("alternating labels", alternating_label)
print("predicted alternating", pred_alt_0)

### Form 1

In [None]:
l_all1_1, cache_all1_1 = model_1.run_with_cache(all1)
l_all0_1, cache_all0_1 = model_1.run_with_cache(all0)
pred_all1_1 = torch.argmax(l_all1_1, dim=-1).squeeze().cpu()
pred_all0_1 = torch.argmax(l_all0_1, dim=-1).squeeze().cpu()
print(pred_all0_1.shape)
print("all one labels", all1_label)
print("predicted all ones", pred_all1_1)
print("all zero labels", all0_label)
print("predicted all zeros", pred_all0_1)

l_all_zero_except1_1, c_zero_except1_1 = model_1.run_with_cache(all_zero_except1)
l_all_one_except1_1 , c_all_one_except1_1 = model_1.run_with_cache(all_one_except1)
pred_all_zero_except1_1 = torch.argmax(l_all_zero_except1_1, dim=-1).squeeze().cpu()
pred_all_one_except1_1 = torch.argmax(l_all_one_except1_1, dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_1)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_1)

l_alt_1, cache_alt_1 = model_1.run_with_cache(alternating)
pred_alt_1 = torch.argmax(l_alt_1, dim=-1).squeeze().cpu()
print("alternating labels", alternating_label)
print("predicted alternating", pred_alt_1)

### Form 2

In [None]:
# Run the form 2 checkpoint on the hand-crafted inputs and compare its argmax
# predictions with the expected labels.
l_all1_2, cache_all1_2 = model_2.run_with_cache(all1)
l_all0_2, cache_all0_2 = model_2.run_with_cache(all0)
pred_all1_2 = torch.argmax(l_all1_2, dim=-1).squeeze().cpu()
pred_all0_2 = torch.argmax(l_all0_2, dim=-1).squeeze().cpu()
print(pred_all0_2.shape)
print("all one labels", all1_label)
print("predicted all ones", pred_all1_2)
print("all zero labels", all0_label)
print("predicted all zeros", pred_all0_2)

l_all_zero_except1_2, c_zero_except1_2 = model_2.run_with_cache(all_zero_except1)
l_all_one_except1_2 , c_all_one_except1_2 = model_2.run_with_cache(all_one_except1)
pred_all_zero_except1_2 = torch.argmax(l_all_zero_except1_2, dim=-1).squeeze().cpu()
pred_all_one_except1_2 = torch.argmax(l_all_one_except1_2, dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_2)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_2)

# Use model_2 here as well: this line previously ran model_1, so the
# "form 2" alternating results were actually form 1's.
l_alt_2, cache_alt_2 = model_2.run_with_cache(alternating)
pred_alt_2 = torch.argmax(l_alt_2, dim=-1).squeeze().cpu()
print("alternating labels", alternating_label)
print("predicted alternating", pred_alt_2)

### Form 3

In [None]:
l_all1_3, cache_all1_3 = model_3.run_with_cache(all1)
l_all0_3, cache_all0_3 = model_3.run_with_cache(all0)
pred_all1_3 = torch.argmax(l_all1_3, dim=-1).squeeze().cpu()
pred_all0_3 = torch.argmax(l_all0_3, dim=-1).squeeze().cpu()
print(pred_all0_3.shape)
print("all one labels", all1_label)
print("predicted all 1s", pred_all1_3)
print("all zero labels", all0_label)
print("predicted all 0s", pred_all0_3)

l_all_zero_except1_3, c_zero_except1_3 = model_3.run_with_cache(all_zero_except1)
l_all_one_except1_3 , c_all_one_except1_3 = model_3.run_with_cache(all_one_except1)
pred_all_zero_except1_3 = torch.argmax(l_all_zero_except1_3, dim=-1).squeeze().cpu()
pred_all_one_except1_3 = torch.argmax(l_all_one_except1_3, dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_3)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_3)

### Form 4

In [None]:
l_all1_4, cache_all1_4 = model_4.run_with_cache(all1)
l_all0_4, cache_all0_4 = model_4.run_with_cache(all0)
pred_all1_4 = torch.argmax(l_all1_4, dim=-1).squeeze().cpu()
pred_all0_4 = torch.argmax(l_all0_4, dim=-1).squeeze().cpu()
print(pred_all0_4.shape)
print("all one labels", all1_label)
print("predicted all 1s", pred_all1_4)
print("all zero labels", all0_label)
print("predicted all 0s", pred_all0_4)

l_all_zero_except1_4, c_zero_except1_4 = model_4.run_with_cache(all_zero_except1)
l_all_one_except1_4 , c_all_one_except1_4 = model_4.run_with_cache(all_one_except1)
pred_all_zero_except1_4 = torch.argmax(l_all_zero_except1_4, dim=-1).squeeze().cpu()
pred_all_one_except1_4 = torch.argmax(l_all_one_except1_4, dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_4)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_4)

### End of training

In [None]:
l_all1_5, cache_all1_5 = model_5.run_with_cache(all1)
l_all0_5, cache_all0_5 = model_5.run_with_cache(all0)
pred_all1_5 = torch.argmax(l_all1_5, dim=-1).squeeze().cpu()
pred_all0_5 = torch.argmax(l_all0_5, dim=-1).squeeze().cpu()
print(pred_all0_5.shape)
print("all one labels", all1_label)
print("predicted all 1s", pred_all1_5)
print("all zero labels", all0_label)
print("predicted all 0s", pred_all0_5)

l_all_zero_except1_5, c_all_zero_except1_5 = model_5.run_with_cache(all_zero_except1)
l_all_one_except1_5 , c_all_one_except1_5 = model_5.run_with_cache(all_one_except1)
pred_all_zero_except1_5 = torch.argmax(l_all_zero_except1_5, dim=-1).squeeze().cpu()
pred_all_one_except1_5 = torch.argmax(l_all_one_except1_5, dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_5)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_5)

l_alt_5, cache_alt_5 = model_5.run_with_cache(alternating)
pred_alt_5 = torch.argmax(l_alt_5, dim=-1).squeeze().cpu()
print("alternating labels", alternating_label)
print("predicted alternating", pred_alt_5)

## First layer attention

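This is the 5-state, 3-layer, 4-heads-per-layer transformer described at the top of this notebook. (This note previously read "2-layer, 2-head transformer trained to 80% accuracy", which appears to be left over from an earlier model.)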

In [None]:
# Pass data through
for data in take_n(ed_loader, 1):
    inputs = data["input_ids"]
    labels = data["label_ids"]
    break

logits_0, cache_0 = model_0.run_with_cache(inputs)
logits_1, cache_1 = model_1.run_with_cache(inputs)
logits_2, cache_2 = model_2.run_with_cache(inputs)
logits_3, cache_3 = model_3.run_with_cache(inputs)
logits_4, cache_4 = model_4.run_with_cache(inputs)
logits_5, cache_5 = model_5.run_with_cache(inputs)

In [None]:
IDX = 8

In [None]:
att_0_0 = cache_0["pattern", 0, "attn"]
display_layer_heads(att_0_0, batch_idx=IDX+15)

In [None]:
att_1_0 = cache_1["pattern", 0, "attn"]
display_layer_heads(att_1_0, batch_idx=IDX)
display_layer_heads(att_1_0, batch_idx=IDX+1)

In [None]:
att_2_0 = cache_2["pattern", 0, "attn"]
display_layer_heads(att_2_0, batch_idx=IDX)
display_layer_heads(att_2_0, batch_idx=IDX+1)

In [None]:
att_3_0 = cache_3["pattern", 0, "attn"]
display_layer_heads(att_3_0, batch_idx=IDX)
display_layer_heads(att_3_0, batch_idx=IDX+1)

In [None]:
att_4_0 = cache_4["pattern", 0, "attn"]
display_layer_heads(att_4_0, batch_idx=IDX)
display_layer_heads(att_4_0, batch_idx=IDX+1)

In [None]:
att_5_0 = cache_5["pattern", 0, "attn"]
display_layer_heads(att_5_0, batch_idx=IDX)
display_layer_heads(att_5_0, batch_idx=IDX+1)

## Second layer attention

In [None]:
att_0_1 = cache_0["pattern", 1, "attn"]
display_layer_heads(att_0_1, batch_idx=IDX)

In [None]:
att_1_1 = cache_1["pattern", 1, "attn"]
display_layer_heads(att_1_1, batch_idx=IDX)

In [None]:
att_2_1 = cache_2["pattern", 1, "attn"]
display_layer_heads(att_2_1, batch_idx=IDX)
display_layer_heads(att_2_1, batch_idx=IDX+1)

In [None]:
att_3_1 = cache_3["pattern", 1, "attn"]
display_layer_heads(att_3_1, batch_idx=IDX)
display_layer_heads(att_3_1, batch_idx=IDX+1)

In [None]:
att_4_1 = cache_4["pattern", 1, "attn"]
display_layer_heads(att_4_1, batch_idx=IDX)
display_layer_heads(att_4_1, batch_idx=IDX+1)

In [None]:
att_5_1 = cache_5["pattern", 1, "attn"]
display_layer_heads(att_5_1, batch_idx=IDX)

## Third layer attention

In [None]:
att_0_3 = cache_0["pattern", 2, "attn"]
display_layer_heads(att_0_3, batch_idx=IDX)

In [None]:
att_1_3 = cache_1["pattern", 2, "attn"]
display_layer_heads(att_1_3, batch_idx=IDX)

In [None]:
att_2_3 = cache_2["pattern", 2, "attn"]
display_layer_heads(att_2_3, batch_idx=IDX)

In [None]:
att_3_3 = cache_3["pattern", 2, "attn"]
display_layer_heads(att_3_3, batch_idx=IDX)

In [None]:
att_4_3 = cache_4["pattern", 2, "attn"]
display_layer_heads(att_4_3, batch_idx=IDX)

In [None]:
att_5_3 = cache_5["pattern", 2, "attn"]
display_layer_heads(att_5_3, batch_idx=IDX)

## Trial input attention

### All 1s
Let's do all 1s as the input for the final model, and observe all 3 layers' attention.

In [None]:
att_5_0_all1 = cache_all1_5["pattern", 0, "attn"]
for head in range(4):
    imshow_attention(att_5_0_all1[0,head,...])

In [None]:
att_5_1_all1 = cache_all1_5["pattern", 1, "attn"]
for head in range(4):
    imshow_attention(att_5_1_all1[0,head,...])

In [None]:
att_5_2_all1 = cache_all1_5["pattern", 2, "attn"]
for head in range(4):
    imshow_attention(att_5_2_all1[0,head,...])

### All 0s

In [None]:
att_5_0_all0 = cache_all0_5["pattern", 0, "attn"]
for head in range(4):
    imshow_attention(att_5_0_all0[0,head,...])

In [None]:
att_5_1_all0 = cache_all0_5["pattern", 1, "attn"]
for head in range(4):
    imshow_attention(att_5_1_all0[0,head,...])

In [None]:
att_5_2_all0 = cache_all0_5["pattern", 2, "attn"]
for head in range(4):
    imshow_attention(att_5_2_all0[0,head,...])

### Alternating

In [None]:
att_5_0_alt = cache_alt_5["pattern", 0, "attn"]
for head in range(4):
    imshow_attention(att_5_0_alt[0,head,...])

In [None]:
att_5_1_alt = cache_alt_5["pattern", 1, "attn"]
for head in range(4):
    imshow_attention(att_5_1_alt[0,head,...])

In [None]:
att_5_2_alt = cache_alt_5["pattern", 2, "attn"]
for head in range(4):
    imshow_attention(att_5_2_alt[0,head,...])

## Form 1

In [None]:
att_1_0_all1 = cache_all1_1["pattern", 0, "attn"]
for head in range(4):
    imshow_attention(att_1_0_all1[0,head,...])

### All 1s except 1

In [None]:
att_5_0_all_one_except1 = c_all_one_except1_5["pattern", 0, "attn"]
for head in range(4):
    imshow_attention(att_5_0_all_one_except1[0,head,...])

In [None]:
att_5_1_all_one_except1 = c_all_one_except1_5["pattern", 1, "attn"]
for head in range(4):
    imshow_attention(att_5_1_all_one_except1[0,head,...])

In [None]:
att_5_2_all_one_except1 = c_all_one_except1_5["pattern", 2, "attn"]
for head in range(4):
    imshow_attention(att_5_2_all_one_except1[0,head,...])

# Activation patching/ablation

In [None]:
from jaxtyping import Float
import transformer_lens.utils as utils

def ablate_hook(
    value: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
) -> Float[torch.Tensor, "batch pos d_model"]:
    """Zero-ablate an activation in place and return it.

    In this notebook it is attached to "hook_pos_embed", which removes the
    positional embedding's contribution.

    Args:
        value: the activation at the hook point; overwritten with zeros.
            The shape annotation assumes [batch, pos, d_model] — TODO confirm
            for any other hook point this is attached to.
        hook: the hook point being run (unused).

    Returns:
        The same tensor object, now all zeros.
    """
    # Ellipsis indexing zeroes every element whatever the tensor's rank.
    value[...] = 0.0
    return value

In [None]:
w_pos_hook_name = utils.get_act_name("resid_post", 2)

In [None]:
cache_1.keys()

In [None]:
logits = model_5.run_with_hooks(
    all1,
    fwd_hooks=[(
        "hook_pos_embed",
        ablate_hook,
    )]
)
print(torch.argmax(logits, dim=-1))
model_5.reset_hooks()

In [None]:
imshow_attention(model_5.W_pos)

In [None]:
imshow_attention(model_5.W_pos @ model_5.W_pos.T)

In [None]:
imshow_attention(model_0.W_pos @ model_0.W_pos.T)

In [None]:
imshow_attention(model_1.W_pos @ model_1.W_pos.T)

In [None]:
imshow_attention(model_2.W_pos @ model_2.W_pos.T)

In [None]:
imshow_attention(model_3.W_pos @ model_3.W_pos.T)

In [None]:
imshow_attention(model_4.W_pos @ model_4.W_pos.T)

In [None]:
logits = model_0.run_with_hooks(
    alternating,
    fwd_hooks=[(
    "hook_pos_embed",
    ablate_hook,
    )]
)
print(torch.argmax(logits, dim=-1))
model_0.reset_hooks()

print(torch.argmax(l_all0_0, dim=-1))

In [None]:
model_5.QK[0,...][0].shape

In [None]:
softmax = nn.Softmax(dim=-1)
for layer_num in range(3):
    for head_num in range(4):
        imshow_attention(model_5.W_pos @ model_5.QK.AB[layer_num, ...][head_num].T)
        imshow_attention(softmax(model_5.QK.AB[layer_num, ...][head_num]))

In [None]:
# model_5.W_V.shape
# for layer in range(3):
#     for head in range(4):
#         imshow_attention(model_5.W_V[layer, ...][head].T)

# OV circuit analysis

In [None]:
print(cache_1["scale"].shape)
# Layernorm scale, [batch, pos, 1]

In [None]:
# W_V: [nlayers nheads dmodel dhead] @ W_O: [nlayers nheads dhead dmodel]
# NOTE(review): these weights come from model_0, while the layernorm scales
# used with them in the following cells come from cache_1 (a model_1 run) —
# confirm which checkpoint this analysis is meant to describe.
W_OV = model_0.W_V @ model_0.W_O # [nlayers nheads dmodel dmodel]
W_E = model_0.W_E # [vocab_in dmodel]
W_U = model_0.W_U # [dmodel vocab_out]
print(W_E.shape)

In [None]:
scale_final = cache_1["scale"][:, :, 0].mean()
scale_0 = cache_1["scale", 0, "ln1"].mean()
scale_1 = cache_1["scale", 1, "ln1"].mean()

In [None]:
print(W_OV[1].shape)
print(W_OV[0].shape)
print(W_OV.shape)

In [None]:
# ! Get direct path
W_E_OV_direct = (W_E / scale_final) @ W_U
print(f"Direct {W_E_OV_direct.shape}") # [vocab_out vocab_out]

# ! Get full OV matrix for path through just layer 0
W_E_OV_0 = (W_E / scale_0) @ W_OV[0]
W_OV_0_full = (W_E_OV_0 / scale_final) @ W_U # [n_head vocab_in vocab_out]
print(f"Layer 0 {W_OV_0_full.shape}")

# ! Get full OV matrix for path through just layer 1
W_E_OV_1 = (W_E / scale_1) @ W_OV[1]
W_OV_1_full = (W_E_OV_1 / scale_final) @ W_U # [n_head vocab_in vocab_out]
print(f"Layer 1 {W_OV_1_full.shape}")

# ! Get full OV matrix for path through heads in layer 0 and 1
W_E_OV_01 = einops.einsum(
    (W_E_OV_0 / scale_1), W_OV[1],
    "head0 vocab_in d_model_in, head1 d_model_in d_model_out -> head0 head1 vocab_in d_model_out",
)
W_OV_01_full = (W_E_OV_01 / scale_final) @ W_U # [head0 head1 vocab_in vocab_out]
print(f"Layers 0, 1 {W_OV_01_full.shape}")

In [None]:
print(W_E_OV_direct[None, None].shape)
print(W_OV_0_full[:, None].shape)
print(W_OV_1_full[None].shape)
print(W_OV_01_full.shape)

cat_1 = torch.cat([W_E_OV_direct[None, None], W_OV_0_full[:, None]]) # [head0 1 vocab_in vocab_out]
cat_2 = torch.cat([W_OV_1_full[None], W_OV_01_full])  # [head0 head1 vocab_in vocab_out]
print(cat_1.shape, cat_2.shape)

W_OV_full_all = torch.cat([
    cat_1,
    cat_2,
], dim=1) # [head0 head1 vocab_in vocab_out]
print(W_OV_full_all.shape)
print(W_OV_full_all.transpose(0, 1).flatten(0, 1).shape)

In [None]:
tokens = [str(i) for i in range(10)]
components_0 = ["W<sub>E</sub>"] + [f"0.{i}" for i in range(4)]
components_1 = ["W<sub>U</sub>"] + [f"1.{i}" for i in range(4)]

# Using dict.fromkeys() prevents repeats
facet_labels = [" ➔ ".join(list(dict.fromkeys(["W<sub>E</sub>", c0, c1, "W<sub>U</sub>"]))) for c1 in components_1 for c0 in components_0]
imshow(
    W_OV_full_all.transpose(0, 1).flatten(0, 1), # .softmax(dim=-1),
    facet_col = 0,
    facet_col_wrap = 5,
    facet_labels = facet_labels,
    title = f"Full virtual OV circuits",
    x = tokens,
    y = tokens,
    labels = {"x": "Source", "y": "Dest"},
    height = 1200,
    width = 1200,
    # text = text,
)

In [None]:
model_5.to('cpu')
OV_circuit_all_heads = model_5.OV
OV_circuit_all_heads_eigenvalues = OV_circuit_all_heads.eigenvalues

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a red/blue heatmap whose colour scale is centred on 0.

    From this cell onwards this definition replaces the `imshow` imported from
    `di_automata.interp_utils` earlier in the notebook.

    Args:
        tensor: values to plot; converted with `utils.to_numpy`.
        renderer: passed to the figure's `.show()`.
        xaxis: label for the x axis.
        yaxis: label for the y axis.
        **kwargs: forwarded to `px.imshow`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    fig.show(renderer)

OV_copying_score = OV_circuit_all_heads_eigenvalues.sum(dim=-1).real / OV_circuit_all_heads_eigenvalues.abs().sum(dim=-1)
imshow(utils.to_numpy(OV_copying_score), xaxis="Head", yaxis="Layer", title="OV Copying Score for each head in TfLens model", zmax=1.0, zmin=-1.0)