In [None]:
import os
# Set before the library imports below so the variable is already in the environment
# when they are loaded.
# NOTE(review): presumably this disables `accelerate`'s rich traceback handler — confirm.
os.environ["ACCELERATE_DISABLE_RICH"] = "1"
from typeguard import typechecked
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
from IPython import get_ipython
# Programmatic equivalent of the `%load_ext autoreload` / `%autoreload 2` magics.
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

import torch as t
import torch
import einops
import itertools
import plotly.express as px
import numpy as np
from datasets import load_dataset
from functools import partial
from tqdm import tqdm
from jaxtyping import Float, Int, jaxtyped
from typing import Union, List, Dict, Tuple, Callable, Optional
from torch import Tensor
import gc
import transformer_lens
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens import utils
from transformer_lens.utils import to_numpy
# Turn autograd off globally; none of the cells visible here compute gradients.
t.set_grad_enabled(False)

In [None]:
from transformer_lens.cautils.notebook import *

In [None]:
def to_tensor(tensor):
    """Return `tensor` as a torch tensor by round-tripping it through NumPy.

    The input is converted with `to_numpy` and the resulting array is wrapped
    with `t.from_numpy`.
    """
    as_array = to_numpy(tensor)
    return t.from_numpy(as_array)

def imshow_old(tensor, **kwargs):
    """Display `tensor` as a plotly heatmap with a colour range symmetric about zero.

    Unless the caller supplies them, `zmin` / `zmax` default to minus / plus the
    largest absolute entry of `tensor`, and `color_continuous_scale` defaults to
    "RdBu". Every keyword argument is forwarded to `px.imshow`. The figure is
    shown immediately; nothing is returned.
    """
    values = to_tensor(tensor)
    bound = values.abs().max().item()

    # Only fill in what the caller did not pass explicitly.
    fallbacks = {"zmin": -bound, "zmax": bound, "color_continuous_scale": "RdBu"}
    for option, fallback in fallbacks.items():
        kwargs.setdefault(option, fallback)

    figure = px.imshow(to_numpy(values), **kwargs)
    figure.show()

In [None]:
# Name handed to `from_pretrained`; the commented-out line is the alternative model
# that the per-model lookup tables further down also have entries for.
MODEL_NAME = "gpt2-small"
# MODEL_NAME = "solu-10l"
model = transformer_lens.HookedTransformer.from_pretrained(MODEL_NAME)
# NOTE(review): IOIDataset and NAMES are not used in any of the cells visible here.
from transformer_lens.hackathon.ioi_dataset import IOIDataset, NAMES

In [None]:
# Configure which optional hook points the model exposes.
# NOTE(review): presumably this switches off the per-head attention-result hook and
# switches on separate per-head q/k/v input hooks — confirm against the TransformerLens docs.
model.set_use_attn_result(False)
model.set_use_split_qkv_input(True)

In [None]:
# (layer, head) of the head analysed below, looked up by the loaded model's config name.
# Any model name not listed here raises KeyError.
LAYER_IDX, HEAD_IDX = {
    "SoLU_10L1280W_C4_Code": (9, 18), # (9, 18) is somewhat cheaty
    "gpt2": (10, 7),
}[model.cfg.model_name]


# Unembedding matrix, plus the query and key weights of the selected head.
W_U = model.W_U
W_Q_negative = model.W_Q[LAYER_IDX, HEAD_IDX]
W_K_negative = model.W_K[LAYER_IDX, HEAD_IDX]

# Token embedding matrix.
W_E = model.W_E

# ! Open question: what is a good approximation of GPT2-small's embedding?
# Candidate treatments of layer-0 attention:
#   - lock attn to 1 at current position
#   - lock attn to average
#   - don't include attention

In [None]:
from transformer_lens import FactoredMatrix

# Factored form of the selected head's QK circuit: the left factor is built from W_U.T
# (query side), the right factor from W_E.T (key side).
full_QK_circuit = FactoredMatrix(W_U.T @ W_Q_negative, W_K_negative.T @ W_E.T)

# Only a 250 x 250 sub-block is materialised, using the same random token ids for rows
# and columns. The draw is unseeded, so the plot differs between runs.
indices = t.randint(0, model.cfg.d_vocab, (250,))
full_QK_circuit_sample = full_QK_circuit.A[indices, :] @ full_QK_circuit.B[:, indices]

# Subtract each row's mean, i.e. centre over the key tokens for every query token.
full_QK_circuit_sample_centered = full_QK_circuit_sample - full_QK_circuit_sample.mean(dim=1, keepdim=True)

imshow(
    full_QK_circuit_sample_centered,
    labels={"x": "Source / key token (embedding)", "y": "Destination / query token (unembedding)"},
    title="Full QK circuit for negative name mover head",
    width=700,
)

In [None]:
def lock_attn(
    attn_patterns: Float[t.Tensor, "batch head_idx dest_pos src_pos"],
    hook: HookPoint,
    ablate: bool = False,
) -> Float[t.Tensor, "batch head_idx dest_pos src_pos"]:
    """Forward hook that overwrites the layer-0 attention pattern.

    Args:
        attn_patterns: attention pattern produced at the hooked point. Only its
            shape, device and dtype are used; its values are discarded.
        hook: the hook point this function is attached to; must belong to layer 0.
        ablate: if False (default) every position attends only to itself
            (identity pattern); if True the returned pattern is all zeros.

    Returns:
        The replacement pattern, with the same shape, device and dtype as
        `attn_patterns`.

    Raises:
        AssertionError: if `attn_patterns` is not a 4-D float tensor, if the hook
            is not on layer 0, or if the pattern is not square in its last two axes.
    """
    assert isinstance(attn_patterns, Float[t.Tensor, "batch head_idx dest_pos src_pos"])
    assert hook.layer() == 0

    batch, n_heads, seq_len, src_len = attn_patterns.shape
    # An identity pattern is only meaningful when queries and keys span the same positions.
    assert src_len == seq_len, f"Expected a square attention pattern, got dest_pos={seq_len}, src_pos={src_len}"

    if ablate:
        return t.zeros_like(attn_patterns)

    # Build the identity directly on the right device / dtype instead of on the CPU
    # in float32 followed by a host-to-device copy on every forward pass.
    identity = t.eye(seq_len, device=attn_patterns.device, dtype=attn_patterns.dtype)
    return einops.repeat(
        identity,
        "dest src -> batch head_idx dest src",
        batch=batch,
        head_idx=n_heads,
    ).clone()

def fwd_pass_lock_attn0_to_self(
    model: HookedTransformer,
    input: Union[str, List[str], Int[t.Tensor, "batch seq_pos"]],
    ablate: bool = False,
) -> Float[t.Tensor, ""]:
    """Run `model` on `input` with layer 0's attention pattern replaced and return the loss.

    All hooks currently registered on `model` are removed first. The layer-0
    "pattern" activation is then overwritten by `lock_attn` for this single
    forward pass.

    Args:
        model: the hooked transformer to run.
        input: a prompt, a list of prompts, or a batch of token ids.
        ablate: forwarded to `lock_attn`; False locks attention to the current
            position, True zeroes the pattern.

    Returns:
        Whatever `model.run_with_hooks(..., return_type="loss")` returns.
        NOTE(review): annotated as a scalar loss tensor, not logits — confirm
        against the TransformerLens docs.
    """
    model.reset_hooks()

    pattern_hook = (utils.get_act_name("pattern", 0), partial(lock_attn, ablate=ablate))
    return model.run_with_hooks(
        input,
        return_type="loss",
        fwd_hooks=[pattern_hook],
    )

In [None]:
# Fetch the `stas/openwebtext-10k` dataset and keep just the text of every training row.
raw_dataset = load_dataset("stas/openwebtext-10k")
train_dataset = raw_dataset["train"]
dataset = [example["text"] for example in train_dataset]

In [None]:
# Compare the loss on the first six documents under three layer-0 attention settings:
# locked to the current position, zeroed out, and left untouched.
for i, s in enumerate(dataset[:6]):
    loss_hooked = fwd_pass_lock_attn0_to_self(model, s, ablate=False)
    print(f"Loss with attn locked to self: {loss_hooked:.2f}")

    loss_hooked_0 = fwd_pass_lock_attn0_to_self(model, s, ablate=True)
    print(f"Loss with attn locked to zero: {loss_hooked_0:.2f}")

    loss_orig = model(s, return_type="loss")
    print(f"Loss with attn free: {loss_orig:.2f}\n")

In [None]:
if "gpt" in model.cfg.model_name: # sigh, tied embeddings
    # Sanity check: feed token ids 0..999 as one sequence with the positional embedding
    # zeroed and layer-0 attention locked to the current position, then read the
    # residual stream entering block 1.

    def remove_pos_embed(z, hook):
        """Hook that replaces the positional embedding with zeros."""
        return 0.0 * z

    model.reset_hooks()
    model.add_hook(
        name="hook_pos_embed",
        hook=remove_pos_embed,
        level=1, # ???
    )
    model.add_hook(
        name=utils.get_act_name("pattern", 0),
        hook=lock_attn,
    )
    try:
        logits, cache = model.run_with_cache(
            torch.arange(1000).to(device).unsqueeze(0),
            names_filter=lambda name: name=="blocks.1.hook_resid_pre",
            return_type="logits",
        )
    finally:
        # Hooks registered through `add_hook` are not guaranteed to be removed by
        # `run_with_cache`; clear them explicitly so later forward passes in this
        # notebook are not silently run with layer-0 attention still locked.
        model.reset_hooks()

    W_EE_test = cache["blocks.1.hook_resid_pre"].squeeze(0)
    W_EE_prefix = W_EE_test[:1000]

    # NOTE(review): W_EE_test has exactly 1000 rows, so this compares the tensor with
    # itself and can never fail. It was presumably meant to compare against an
    # independently computed effective embedding — TODO replace with the real reference.
    assert torch.allclose(
        W_EE_prefix,
        W_EE_test,
        atol=1e-4,
        rtol=1e-4,
    )

In [None]:
def get_EE_QK_circuit(
    layer_idx,
    head_idx,
    random_seeds: Optional[int] = 5,
    num_samples: Optional[int] = 500,
    bags_of_words: Optional[List[List[int]]] = None, # each List is a List of unique tokens
    mean_version: bool = True,
    show_plot: bool = False,
    W_E_query_side: Optional[t.Tensor] = None,
    W_E_key_side: Optional[t.Tensor] = None,
):
    """Average a head's token-to-token QK scores over several token samples.

    Exactly one sampling mode must be chosen:

    * random mode: pass `random_seeds` and `num_samples` (and leave
      `bags_of_words` as None). `random_seeds` is the number of independent,
      unseeded draws of `num_samples` token ids.
    * bag mode: pass `bags_of_words` and set both `random_seeds` and
      `num_samples` to None. Every bag must have the same length.

    Args:
        layer_idx, head_idx: which head's W_Q / W_K (read from the global `model`) to use.
        random_seeds: number of random draws (random mode only).
        num_samples: number of tokens per draw (random mode only).
        bags_of_words: list of equally sized lists of token ids (bag mode only).
        mean_version: if True, each sample's scores are centred per query row
            before averaging; if False, a softmax over keys is applied instead.
        show_plot: if True, the averaged matrix is displayed with `imshow_old`.
        W_E_query_side: embedding matrix used on the query side (required).
        W_E_key_side: embedding matrix used on the key side (required).

    Returns:
        A CPU tensor of shape (num_samples, num_samples): rows are query tokens,
        columns are key tokens, averaged over all samples.

    Raises:
        AssertionError: on an invalid combination of sampling arguments, on bags
            of unequal length, or when either embedding matrix is missing.
    """
    assert (random_seeds is None and num_samples is None) != (bags_of_words is None), (random_seeds is None, num_samples is None, bags_of_words is None, "Must specify either random_seeds and num_samples or bag_of_words_version")

    if bags_of_words is not None:
        assert len(bags_of_words) > 0, "bags_of_words must contain at least one bag"
        random_seeds = len(bags_of_words) # eh not quite random seeds but whatever
        num_samples = len(bags_of_words[0])
        # The previous check wrapped each comparison in a one-element list, which is
        # always truthy, so bags of unequal length were never rejected.
        assert all(len(bag_of_words) == num_samples for bag_of_words in bags_of_words), "Must have same number of words in each bag of words"
    else:
        assert random_seeds is not None and num_samples is not None, "Must specify both random_seeds and num_samples when bags_of_words is None"
        assert random_seeds > 0, "random_seeds must be positive"

    W_Q_head = model.W_Q[layer_idx, head_idx]
    W_K_head = model.W_K[layer_idx, head_idx]

    assert W_E_query_side is not None
    assert W_E_key_side is not None
    # Scale every embedding row by the square root of its variance (no mean subtraction).
    W_E_Q_normed = W_E_query_side / W_E_query_side.var(dim=-1, keepdim=True).pow(0.5)
    W_E_K_normed = W_E_key_side / W_E_key_side.var(dim=-1, keepdim=True).pow(0.5)

    EE_QK_circuit = FactoredMatrix(W_E_Q_normed @ W_Q_head, W_K_head.T @ W_E_K_normed.T)
    EE_QK_circuit_result = t.zeros((num_samples, num_samples))

    for random_seed in range(random_seeds):
        if bags_of_words is None:
            indices = t.randint(0, model.cfg.d_vocab, (num_samples,))
        else:
            indices = t.tensor(bags_of_words[random_seed])

        # assert False, "TODO: add Q and K and V biases???"
        EE_QK_circuit_sample = einops.einsum(
            EE_QK_circuit.A[indices, :],
            EE_QK_circuit.B[:, indices],
            "num_query_samples d_head, d_head num_key_samples -> num_query_samples num_key_samples"
        )

        if mean_version:
            # we're going to take a softmax so the constant factor is arbitrary
            # and it's a good idea to centre all these results so adding them up is reasonable
            EE_QK_mean = EE_QK_circuit_sample.mean(dim=-1, keepdim=True)
            EE_QK_circuit_sample_centered = EE_QK_circuit_sample - EE_QK_mean
            EE_QK_circuit_result += EE_QK_circuit_sample_centered.cpu()

        else:
            EE_QK_softmax = t.nn.functional.softmax(EE_QK_circuit_sample, dim=-1)
            EE_QK_circuit_result += EE_QK_softmax.cpu()

    EE_QK_circuit_result /= random_seeds

    if show_plot:
        imshow_old(
            EE_QK_circuit_result,
            labels={"x": "Source/Key Token (embedding)", "y": "Destination/Query Token (unembedding)"},
            title=f"EE QK circuit for head {layer_idx}.{head_idx}",
            width=700,
        )

    return EE_QK_circuit_result

In [None]:
# Per-model lists of (layer, head) pairs, selected by the loaded model's config name.
# Any model name not listed here raises KeyError.
# NOTE(review): presumably these are the heads identified as name movers on the IOI task — confirm.
NAME_MOVERS = {
    "gpt2": [(9, 9), (10, 0), (9, 6)],
    "SoLU_10L1280W_C4_Code": [(7, 12), (5, 4), (8, 3)],
}[model.cfg.model_name]

# The first entry for each model is the (LAYER_IDX, HEAD_IDX) head chosen earlier.
NEGATIVE_NAME_MOVERS = {
    "gpt2": [(LAYER_IDX, HEAD_IDX), (11, 10)],
    "SoLU_10L1280W_C4_Code": [(LAYER_IDX, HEAD_IDX), (9, 15)], # second one on this one IOI prompt only...
}[model.cfg.model_name]

In [None]:
# Prep some bags of words...
# OVERLY LONG because it really helps to have the bags of words the same length
#
# Each bag is the first INNER_LEN distinct token ids of one document, in order of first
# appearance. Documents with fewer than INNER_LEN distinct tokens are skipped.

bags_of_words = []

OUTER_LEN = 50
INNER_LEN = 100

idx = -1
while len(bags_of_words) < OUTER_LEN:
    idx += 1
    if idx >= len(dataset):
        # Fail with an explanation instead of an opaque out-of-range lookup.
        raise IndexError(
            f"Ran out of documents after collecting {len(bags_of_words)} of {OUTER_LEN} "
            f"bags with {INNER_LEN} unique tokens each"
        )
    cur_tokens = model.tokenizer.encode(dataset[idx])
    # dict.fromkeys drops repeats while preserving first-occurrence order.
    cur_bag = list(itertools.islice(dict.fromkeys(cur_tokens), INNER_LEN))

    if len(cur_bag) == INNER_LEN:
        bags_of_words.append(cur_bag)

In [None]:
def get_effective_embedding(model: HookedTransformer) -> Dict[str, Float[Tensor, "d_vocab d_model"]]:
    """Compute several candidate "effective embedding" matrices for `model`.

    Every token embedding is pushed through block 0 as if each token attended
    only to itself: the attention output is ln1(W_E) @ W_V @ W_O summed over
    heads, followed by ln2 and the block-0 MLP. No positional embedding is
    added, and only the W_V / W_O weights are applied here.
    NOTE(review): any attention biases are not included — confirm this is intended.

    Returns:
        A dict with three matrices, keyed by a human-readable label:
        * "W_U (or W_E, no MLPs)": the transposed unembedding matrix.
        * "W_E (including MLPs)": residual stream after block 0
          (embedding + attention output + MLP output).
        * "W_E (only MLPs)": the block-0 MLP output alone.
    """
    W_E = model.W_E
    # Cloned because this matrix is handed back to the caller.
    W_U = model.W_U.clone()
    # t.testing.assert_close(W_E[:10, :10], W_U[:10, :10].T)  NOT TRUE, because of the center unembed part!

    # Treat the vocabulary as one long sequence in a batch of size 1.
    resid_pre = W_E.unsqueeze(0)
    pre_attention = model.blocks[0].ln1(resid_pre)
    attn_out = einops.einsum(
        pre_attention,
        model.W_V[0],
        model.W_O[0],
        "b s d_model, num_heads d_model d_head, num_heads d_head d_model_out -> b s d_model_out",
    )
    resid_mid = attn_out + resid_pre
    normalized_resid_mid = model.blocks[0].ln2(resid_mid)
    mlp_out = model.blocks[0].mlp(normalized_resid_mid)

    # Drop only the dummy batch axis.
    W_EE = mlp_out.squeeze(0)
    W_EE_full = resid_mid.squeeze(0) + W_EE

    return {
        "W_U (or W_E, no MLPs)": W_U.T,
        # "W_E (raw, no MLPs)": W_E,
        "W_E (including MLPs)": W_EE_full,
        "W_E (only MLPs)": W_EE
    }

embeddings_dict = get_effective_embedding(model)

In [None]:
# Getting just diag patterns for a single head

from transformer_lens import FactoredMatrix

# Head whose attention pattern is compared across every (query, key) embedding choice.
LAYER = 3
HEAD = 0

all_results = []
# Sorted so that the facet order (and the labels list built alongside it) is deterministic.
embeddings_dict_keys = sorted(embeddings_dict.keys())
labels = []

# Every ordered pair of embedding variants: first is used on the query side, second on the key side.
for q_side_matrix, k_side_matrix in tqdm(list(itertools.product(embeddings_dict_keys, embeddings_dict_keys))):
    labels.append(f"Q = {q_side_matrix}<br>K = {k_side_matrix}")

    results = []
    for idx in range(OUTER_LEN):
        # One bag at a time (a length-1 slice), in bag mode with softmax rather than centring.
        softmaxed_attn = get_EE_QK_circuit(
            LAYER,
            HEAD,
            show_plot=False,
            num_samples=None,
            random_seeds=None,
            bags_of_words=bags_of_words[idx: idx+1],
            mean_version=False,
            W_E_query_side=embeddings_dict[q_side_matrix],
            W_E_key_side=embeddings_dict[k_side_matrix],
        )
        results.append(softmaxed_attn)
    
    # Mean pattern over all bags for this (query, key) embedding pair.
    all_results.append(sum(results) / len(results))

    t.cuda.empty_cache()

# From here on `all_results` is a single stacked tensor, one slice per embedding pair.
all_results = t.stack(all_results) # .reshape((3, 3, INNER_LEN, INNER_LEN))

In [None]:
# One facet per (query, key) embedding pair, laid out in a square grid.
# The head in the title is taken from LAYER / HEAD so it cannot drift from the data plotted.
# NOTE(review): the "(duplicate token head)" description is specific to head 3.0 — update it
# if LAYER / HEAD are changed.
imshow(
    all_results,
    facet_col=0,
    facet_col_wrap=len(embeddings_dict),
    facet_labels=labels,
    title=f"Sample of diagonal patterns for different matrices: head {LAYER}.{HEAD} (duplicate token head)",
    labels={"x": "Key", "y": "Query"},
    height=900, width=900
)

In [None]:
# Score every head by the mean diagonal of its softmaxed bag-of-words pattern, i.e. how
# much a query token attends to the key that is the same token. The key side here is the
# effective embedding that includes the block-0 MLP output.
# The grid is sized from the model config rather than a hard-coded 12 x 12, so the cell
# also works for models with a different number of layers or heads.
scores = t.zeros(model.cfg.n_layers, model.cfg.n_heads).float().to(device)

for layer, head in tqdm(list(itertools.product(range(model.cfg.n_layers), range(model.cfg.n_heads)))):
    results = []
    for idx in range(OUTER_LEN):
        softmaxed_attn = get_EE_QK_circuit(
            layer,
            head,
            show_plot=False,
            num_samples=None,
            random_seeds=None,
            bags_of_words=bags_of_words[idx:idx+1],
            mean_version=False,
            W_E_query_side=embeddings_dict["W_U (or W_E, no MLPs)"],
            W_E_key_side=embeddings_dict["W_E (including MLPs)"],  # "W_E (only MLPs)"
        )
        results.append(softmaxed_attn.diag().mean())

    # Average the per-bag diagonal means for this head.
    results = sum(results) / len(results)

    scores[layer, head] = results

imshow(scores, width=750, labels={"x": "Head", "y": "Layer"}, title="Prediction-attn scores for bag of words (including MLPs in embedding)")

In [None]:
# Score every head by the mean diagonal of its softmaxed bag-of-words pattern, i.e. how
# much a query token attends to the key that is the same token. The key side here is the
# block-0 MLP output alone.
# The grid is sized from the model config rather than a hard-coded 12 x 12, so the cell
# also works for models with a different number of layers or heads.
scores = t.zeros(model.cfg.n_layers, model.cfg.n_heads).float().to(device)

for layer, head in tqdm(list(itertools.product(range(model.cfg.n_layers), range(model.cfg.n_heads)))):
    results = []
    for idx in range(OUTER_LEN):
        softmaxed_attn = get_EE_QK_circuit(
            layer,
            head,
            show_plot=False,
            num_samples=None,
            random_seeds=None,
            bags_of_words=bags_of_words[idx:idx+1],
            mean_version=False,
            W_E_query_side=embeddings_dict["W_U (or W_E, no MLPs)"],
            W_E_key_side=embeddings_dict["W_E (only MLPs)"],  #
        )
        results.append(softmaxed_attn.diag().mean())

    # Average the per-bag diagonal means for this head.
    results = sum(results) / len(results)

    scores[layer, head] = results

imshow(scores, width=750, labels={"x": "Head", "y": "Layer"}, title="Prediction-attn scores for bag of words (only MLPs in embedding)")