In [None]:
import os
import sys
import plotly.express as px
import plotly.subplots as sp
import torch
from pathlib import Path
import numpy as np
import einops
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from matplotlib import pyplot as plt
import plotly_utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv
import functools
# True when this code is executed as the top-level module (`__name__` is "__main__") rather than imported.
MAIN = __name__ == "__main__"

In [None]:
# Set KMP_DUPLICATE_LIB_OK for this process before any model code runs.
# NOTE(review): presumably a workaround for the duplicate-OpenMP-runtime abort; confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# This notebook only analyses a trained model, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Everything runs on CPU. The MPS alternative is kept for reference:
# device_name = "mps" if torch.backends.mps.is_available() else "cpu"
device_name = "cpu"
device = torch.device(device_name)
print(f"{device=}")


In [None]:
# Architecture of the 2-layer attention-only model whose weights are loaded below.
cfg = HookedTransformerConfig(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir="causal",
    attn_only=True, # defaults to False
    tokenizer_name="EleutherAI/gpt-neox-20b",
    seed=398,
    use_attn_result=True,
    normalization_type=None, # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer",
    # Pin the model to the device chosen for this notebook. Without this the config
    # picks its own default device, which can differ from `device` used for the
    # checkpoint (`map_location`) below.
    device=str(device),
)

# Local checkpoint file (despite the name, this is a file path, not a directory).
weights_dir = "attn_only_2L_half.pth"
if not Path(weights_dir).exists():
    # Download the checkpoint once; later runs reuse the local copy.
    url = "https://drive.google.com/uc?id=1vcZLJnJoYKQs-2KOjkd6LvHZrkSdoxhu"
    output = str(weights_dir)
    gdown.download(url, output)

model = HookedTransformer(cfg)
# NOTE(review): torch.load unpickles the downloaded file. If the installed torch
# supports it, consider passing weights_only=True, since only a state dict is expected.
pretrained_weights = torch.load(weights_dir, map_location=device)
model.load_state_dict(pretrained_weights)

In [None]:
# Module-level result buffers with one entry per (layer, head). The induction-score
# hooks defined in this notebook write into them by item assignment, so no `global`
# declaration is needed (at module level it would be a no-op anyway).
# They live on the model's device to avoid moving data between devices inside the hooks.
induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)
rank_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)


In [None]:
def generate_repeated_tokens(
    model: HookedTransformer, seq_len: int, batch: int = 1
) -> torch.Tensor:
    '''
    Builds `batch` token sequences of the form [BOS, r_1..r_seq_len, r_1..r_seq_len],
    where r_1..r_seq_len are token ids drawn uniformly from [0, model.cfg.d_vocab).

    Args:
        model: provides the BOS id (model.tokenizer.bos_token_id) and the vocabulary size
        seq_len: number of random tokens in each half
        batch: number of independent sequences

    Returns:
        rep_tokens: Int[torch.Tensor, "batch full_seq_len"], with full_seq_len = 1 + 2*seq_len
    '''
    bos_column = torch.full((batch, 1), model.tokenizer.bos_token_id, dtype=torch.long)
    random_half = torch.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=torch.long)
    # The same random half is appended twice, so the second half repeats the first exactly.
    return torch.cat((bos_column, random_half, random_half), dim=1)

def run_and_cache_model_repeated_tokens(model: HookedTransformer, seq_len: int, batch: int = 1) -> Tuple[torch.Tensor, torch.Tensor, ActivationCache]:
    '''
    Generates repeated random token sequences with `generate_repeated_tokens`, runs the
    model on them with `model.run_with_cache`, and returns tokens, logits and cache.

    Args:
        model: the model to run
        seq_len: number of random tokens in each half of the sequence
        batch: number of sequences to generate

    Returns (in this order):
        rep_tokens: [batch, 1+2*seq_len]
        rep_logits: [batch, 1+2*seq_len, d_vocab]
        rep_cache: The cache of the model run on rep_tokens
    '''
    rep_tokens = generate_repeated_tokens(model, seq_len, batch=batch)
    rep_logits, rep_cache = model.run_with_cache(rep_tokens)
    return rep_tokens, rep_logits, rep_cache


In [None]:
def create_induction_mask_rep_tokens(seq_len):
    """
    Build the boolean "induction" mask for a sequence laid out as
    [BOS, t_1..t_seq_len, t_1..t_seq_len] (the layout produced by
    generate_repeated_tokens). Only the layout matters, not the token values.

    mask[i, j] is True exactly when 1 <= j < i and the token at position j-1 equals
    the token at position i. That is, source position j sits right after an earlier
    occurrence of the token at destination position i. For this layout the True
    entries are (i, i - seq_len + 1) for i in seq_len+1 .. 2*seq_len.

    Args:
    seq_len: Number of tokens in one half of the repeated sequence.

    Returns:
    A 2D bool tensor of shape (2*seq_len + 1, 2*seq_len + 1), indexed [dest_pos, source_pos].
    """
    total_len = 2 * seq_len + 1

    # Stand-in token ids: 0 for BOS, then 1..seq_len twice. Equal ids <=> equal tokens.
    token_ids = torch.cat([torch.arange(seq_len + 1), torch.arange(1, seq_len + 1)])

    # same_token[i, k] is True when positions i and k hold the same token.
    same_token = token_ids.unsqueeze(1) == token_ids.unsqueeze(0)

    mask = torch.zeros((total_len, total_len), dtype=torch.bool)
    # Shift by one column: source j is compared against the token at j-1 (column 0 stays False).
    mask[:, 1:] = same_token[:, :-1]
    # Keep only strictly earlier source positions (j < i); this also clears row 0.
    return torch.tril(mask, diagonal=-1)


def induction_score_hook_inefficient(
    pattern: torch.Tensor,
    hook: HookPoint,
    seq_len: int = 0,
    threshold: float = 3
):
    '''
    Calculates a per-head induction score and the rank of the batch-averaged attention
    pattern, one head at a time, and stores them in the [layer, head] position of the
    module-level `induction_score_store` and `rank_store` tensors.

    The score is computed on the standardised (z-scored) pattern of each head:
    mean z-score of the induction positions minus
    (mean z-score of the whole pattern + threshold * std of the z-scored pattern).

    Args:
        pattern: Float[torch.Tensor, "batch head_index dest_pos source_pos"]
        hook: hook point the function is attached to; `hook.layer()` selects the row to write
        seq_len: half-length of the repeated sequence. The mask built from it has shape
            (2*seq_len+1, 2*seq_len+1) and must match the last two dims of `pattern`
        threshold: number of standard deviations the induction positions must exceed

    Returns nothing, so the activation passes through unchanged.
    '''
    induction_mask = create_induction_mask_rep_tokens(seq_len)
    # Average over the batch only. No squeeze here: squeezing would drop the head axis
    # for a single-head model and the loop below would then iterate over positions.
    attention_pattern = pattern.mean(dim=0)
    layer = hook.layer()

    for head in range(attention_pattern.size(0)):
        curr_attention_pattern = attention_pattern[head]
        pattern_mean = curr_attention_pattern.mean()
        pattern_std = curr_attention_pattern.std()

        norm_attention_pattern = (curr_attention_pattern - pattern_mean) / pattern_std
        norm_induction = (curr_attention_pattern[induction_mask] - pattern_mean) / pattern_std
        # Check if the mean induction token attn value is significantly larger than the mean activations
        induction_score_store[layer, head] = norm_induction.mean() - (
            norm_attention_pattern.mean() + threshold * norm_attention_pattern.std()
        )
        rank_store[layer, head] = torch.linalg.matrix_rank(curr_attention_pattern)

def induction_score_hook(
    pattern: torch.Tensor,
    hook: HookPoint,
    seq_len: int = 0,
    threshold: float = 3
):
    '''
    Calculates, for all heads at once, an induction score and the rank of the
    batch-averaged attention pattern, and stores them in the `hook.layer()` row of the
    module-level `induction_score_store` and `rank_store` tensors.

    The score of a head is: mean attention on the induction positions minus
    (mean attention of the whole pattern + threshold * std of the whole pattern).

    Args:
        pattern: Float[torch.Tensor, "batch head_index dest_pos source_pos"]
        hook: hook point the function is attached to; `hook.layer()` selects the row to write
        seq_len: half-length of the repeated sequence. The mask built from it has shape
            (2*seq_len+1, 2*seq_len+1) and must match the last two dims of `pattern`
        threshold: number of standard deviations the induction positions must exceed

    Returns nothing, so the activation passes through unchanged.
    '''
    # Average over the batch only. No squeeze here: squeezing would drop the head axis
    # for a single-head model and break the per-head reductions below.
    attention_pattern = pattern.mean(dim=0)  # [head, dest_pos, source_pos]
    induction_mask = create_induction_mask_rep_tokens(seq_len)  # [dest_pos, source_pos]
    layer = hook.layer()

    # Applying the 2D mask to the last two dims gives [head, n_induction_positions].
    mean_induction_activations = attention_pattern[:, induction_mask].mean(dim=1)
    mean_activations = attention_pattern.mean(dim=[1, 2])
    std_activations = attention_pattern.std(dim=[1, 2])
    induction_score_store[layer, :] = mean_induction_activations - (
        mean_activations + threshold * std_activations
    )

    # matrix_rank accepts a batch of matrices, so every head is handled in one call.
    rank_store[layer, :] = torch.linalg.matrix_rank(attention_pattern)

In [None]:
# Fresh [n_layers, n_heads] buffers for the attention-only model; the hook below fills them in.
induction_score_store = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=model.cfg.device)
rank_store = torch.zeros_like(induction_score_store)

seq_len, batch = 50, 10
rep_tokens_10 = generate_repeated_tokens(model, seq_len, batch)


def pattern_hook_names_filter(name):
    # Select every hook point whose name ends with "pattern".
    return name.endswith("pattern")


induction_score_hook_function = functools.partial(
    induction_score_hook_inefficient, seq_len=seq_len, threshold=3
)

# Run with hooks (this is where `induction_score_store` and `rank_store` get written).
model.run_with_hooks(
    rep_tokens_10,
    return_type=None,  # the logits are not needed, so skip computing them
    fwd_hooks=[(pattern_hook_names_filter, induction_score_hook_function)],
)

# Heatmap of the induction score of each head in each layer.
plotly_utils.imshow(
    induction_score_store,
    labels=dict(x="Head", y="Layer"),
    title="Induction Score by Head",
    text_auto=".2f",
    width=900,
    height=400,
)

# Heatmap of the rank of each head's batch-averaged attention pattern.
plotly_utils.imshow(
    rank_store,
    labels=dict(x="Head", y="Layer"),
    title="Rank of average activation by Head",
    text_auto=".2f",
    width=900,
    height=400,
)

In [None]:
def visualize_pattern_hook(
    pattern: torch.Tensor,
    hook: HookPoint,
    model = None,
    tokens = None
):
    """
    Prints the layer of the hook point and displays the batch-averaged attention
    patterns of all its heads with circuitsvis.

    pattern: Float[Tensor, "batch head_index dest_pos source_pos"]
    model: used to turn token ids into string labels (`model.to_str_tokens`)
    tokens: batch of token ids; only `tokens[0]` is used for the labels, while the
        displayed attention is the mean over the whole batch

    Returns nothing, so the activation passes through unchanged.
    """
    print("Layer: ", hook.layer())
    token_labels = model.to_str_tokens(tokens[0])
    batch_averaged_pattern = pattern.mean(0)
    rendered = cv.attention.attention_patterns(
        tokens=token_labels,
        attention=batch_averaged_pattern,
    )
    display(rendered)

In [None]:
# on gpt2-small
# Load the "gpt2-small" checkpoint via HookedTransformer.from_pretrained, then print the
# object's type and its config as a quick sanity check.
gpt2_small = HookedTransformer.from_pretrained("gpt2-small")
print(type(gpt2_small))
print(f"{gpt2_small.cfg=}")

In [None]:
# Fresh [n_layers, n_heads] buffers sized for GPT-2 small; the hook below fills them in.
induction_score_store = torch.zeros(gpt2_small.cfg.n_layers, gpt2_small.cfg.n_heads, device=gpt2_small.cfg.device)
rank_store = torch.zeros_like(induction_score_store)

seq_len, batch = 50, 10
rep_tokens_10 = generate_repeated_tokens(gpt2_small, seq_len, batch)


def pattern_hook_names_filter(name):
    # Select every hook point whose name ends with "pattern".
    return name.endswith("pattern")


induction_score_hook_function = functools.partial(
    induction_score_hook, seq_len=seq_len, threshold=3
)

# Run with hooks (this is where `induction_score_store` and `rank_store` get written).
gpt2_small.run_with_hooks(
    rep_tokens_10,
    return_type=None,  # the logits are not needed, so skip computing them
    fwd_hooks=[(pattern_hook_names_filter, induction_score_hook_function)],
)

# Heatmap of the induction score of each head in each layer.
plotly_utils.imshow(
    induction_score_store,
    labels=dict(x="Head", y="Layer"),
    title="Induction Score by Head",
    text_auto=".2f",
    width=900,
    height=400,
)

# Heatmap of the rank of each head's batch-averaged attention pattern.
plotly_utils.imshow(
    rank_store,
    labels=dict(x="Head", y="Layer"),
    title="Rank of average activation by Head",
    text_auto=".2f",
    width=900,
    height=400,
)

In [None]:
# Show the attention patterns of GPT-2 small for layers 5, 6 and 7 only.
visualize_pattern_hook_function = functools.partial(
    visualize_pattern_hook, model=gpt2_small, tokens=rep_tokens_10
)


def pattern_hook_names_filter_of_interest(name):
    # True only for the attention-pattern hook points of layers 5, 6 and 7.
    return any(name == utils.get_act_name("pattern", l) for l in (5, 6, 7))


# Run with hooks. This pass only renders the patterns: the hook returns nothing and
# does not write to the score stores.
gpt2_small.run_with_hooks(
    rep_tokens_10,
    return_type=None,  # the logits are not needed, so skip computing them
    fwd_hooks=[(pattern_hook_names_filter_of_interest, visualize_pattern_hook_function)],
)

In [None]:
# Logit Attribution


In [None]:
def logit_attribution(
    embed: torch.Tensor,
    l1_results: torch.Tensor,
    l2_results: torch.Tensor,
    w_u: torch.Tensor,
    tokens: torch.Tensor 
):
    '''
    Splits the logit of each correct next token into the contributions of the direct
    path (embeddings) and of every attention head of the two layers.

    Inputs:
        embed: the embeddings of the tokens (i.e. token + position embeddings) Float[Tensor, "seq d_model"]
        l1_results: the outputs of the attention heads of the first layer (with head as one of the dimensions) Float[Tensor, "seq nheads d_model"]
        l2_results: the outputs of the attention heads of the second layer (with head as one of the dimensions) Float[Tensor, "seq nheads d_model"]
        w_u: the unembedding matrix Float[Tensor, "d_model d_vocab"]
        tokens: the token ids of the sequence Int[Tensor, "seq"]

    Returns:
        Tensor of shape (seq_len-1, n_components) -> Float[Tensor, "seq-1 n_components"]
        represents the concatenation (along dim=-1) of logit attributions from:
            the direct path (seq-1, 1)
            first-layer heads (seq-1, n_heads)
            second-layer heads (seq-1, n_heads)
        so n_components = 1 + 2*n_heads. Row k is the attribution to the logit of
        tokens[k+1] computed from position k.
    '''
    # The last position has no "next token" inside the sequence, so it is dropped.
    n_predictions = tokens.shape[0] - 1

    # Column k is the unembedding direction of the token that actually follows position k.
    w_u_correct_tokens = w_u[:, tokens[1:]]  # [d_model, seq-1]

    # Each position is dotted only with its own target direction (index `s` is shared by
    # both operands), instead of building a full seq x seq product and taking its diagonal.
    direct_path_logits = torch.einsum(
        "sm,ms->s", embed[:n_predictions], w_u_correct_tokens
    ).unsqueeze(dim=1)
    layer_1_logits = torch.einsum("shm,ms->sh", l1_results[:n_predictions], w_u_correct_tokens)
    layer_2_logits = torch.einsum("shm,ms->sh", l2_results[:n_predictions], w_u_correct_tokens)

    return torch.concat([direct_path_logits, layer_1_logits, layer_2_logits], dim=1)

In [None]:
# Example prompt for the logit-attribution check on the 2-layer attention-only model.
input_text = "We think that powerful, significantly superhuman machine intelligence is more likely than not to be created this century. If current machine learning techniques were scaled up to this level, we think they would by default produce systems that are deceptive or manipulative, and that no solid plans are known for how to avoid this."
# remove_batch_dim=True is passed so that the cached activations can be used without a
# batch axis; `logits` is still indexed with a leading batch index in the cell below.
logits, cache = model.run_with_cache(input_text, remove_batch_dim=True)
# String form and id form of the prompt's tokens (`tokens` is indexed as tokens[0] below).
str_tokens = model.to_str_tokens(input_text)
tokens = model.to_tokens(input_text)

In [None]:
with torch.inference_mode():
    # Inputs to the attribution: embeddings plus the per-head outputs of both layers.
    embed, l1_results, l2_results = cache["embed"], cache["result", 0], cache["result", 1]
    logit_attr = logit_attribution(embed, l1_results, l2_results, model.W_U, tokens[0])

    # For each position k (except the last), pick the logit the model assigned to the
    # token that actually comes next, i.e. tokens[0, k+1].
    correct_token_logits = logits[0, :-1].gather(dim=-1, index=tokens[0, 1:].unsqueeze(-1)).squeeze(-1)

    # The attributions of all components must add up to the real logit of the correct token.
    torch.testing.assert_close(logit_attr.sum(1), correct_token_logits, atol=1e-3, rtol=0)
    print("Tests passed!")

    plotly_utils.plot_logit_attribution(model, logit_attr, tokens)

In [None]:
# Run it through GPT-2 (layers 5 and 6) -> Actually, the computations above are for attention-only models; the non-linearities in GPT-2 make
# this tricky, or rather, we would need to rewrite the logit attribution for GPT-2!
# gpt2_logits, gpt2_cache = gpt2_small.run_with_cache(input_text, remove_batch_dim=True)
# gpt2_str_tokens = gpt2_small.to_str_tokens(input_text)
# gpt2_tokens = gpt2_small.to_tokens(input_text)

# with torch.inference_mode():
#     gpt2_embed = cache["embed"]
#     gpt2_l5_results = cache["result", 5]
#     gpt2_l6_results = cache["result", 6]
#     gpt2_logit_attr = logit_attribution(gpt2_embed, gpt2_l5_results, gpt2_l6_results, gpt2_small.W_U, gpt2_tokens[0])
#     # Uses fancy indexing to get a len(tokens[0])-1 length tensor, where the kth entry is the predicted logit for the correct k+1th token
#     gpt2_correct_token_logits = gpt2_logits[0, torch.arange(len(tokens[0]) - 1), tokens[0, 1:]]
#     torch.testing.assert_close(gpt2_logit_attr.sum(1), gpt2_correct_token_logits, atol=1e-3, rtol=0)
#     print("Tests passed!")

#     plotly_utils.plot_logit_attribution(gpt2_small, gpt2_logit_attr, tokens)

In [None]:
# Logit attribution on one repeated random sequence, computed separately for each half.
seq_len = 50
rep_tokens = generate_repeated_tokens(model, seq_len, batch=1)
rep_logits, rep_cache = model.run_with_cache(rep_tokens, remove_batch_dim=True)

with torch.inference_mode():
    embed, l1_results, l2_results = rep_cache["embed"], rep_cache["result", 0], rep_cache["result", 1]

    # Positions 0..seq_len: BOS plus the first copy of the random tokens.
    first_half_tokens = rep_tokens[0, : seq_len + 1]
    first_half_logit_attr = logit_attribution(
        embed[: seq_len + 1],
        l1_results[: seq_len + 1],
        l2_results[: seq_len + 1],
        model.W_U,
        first_half_tokens,
    )

    # Positions seq_len..2*seq_len: last token of the first copy plus the whole second copy.
    second_half_tokens = rep_tokens[0, seq_len:]
    second_half_logit_attr = logit_attribution(
        embed[seq_len:],
        l1_results[seq_len:],
        l2_results[seq_len:],
        model.W_U,
        second_half_tokens,
    )

# Each half yields seq_len predictions, split over the direct path and every head of both layers.
assert first_half_logit_attr.shape == (seq_len, 2 * model.cfg.n_heads + 1)
assert second_half_logit_attr.shape == (seq_len, 2 * model.cfg.n_heads + 1)

plotly_utils.plot_logit_attribution(
    model, first_half_logit_attr, first_half_tokens, "Logit attribution (first half of repeated sequence)"
)
plotly_utils.plot_logit_attribution(
    model, second_half_logit_attr, second_half_tokens, "Logit attribution (second half of repeated sequence)"
)

In [None]:
# Ablations
def head_ablation_hook(
    v: torch.Tensor,
    hook: HookPoint,
    head_index_to_ablate: int
):
    """
    Zero-ablates one head: overwrites, in place, the slice of `v` that belongs to
    `head_index_to_ablate` with zeros and returns the same tensor.

    v: Float[Tensor, "batch seq n_heads d_head"]
    out -> Float[Tensor, "batch seq n_heads d_head"]
    """
    # Every batch element and position, the chosen head, every d_head component.
    ablated_slice = (slice(None), slice(None), head_index_to_ablate, slice(None))
    v[ablated_slice] = 0.0
    return v

def cross_entropy_loss(logits, tokens):
    '''
    Computes the mean cross entropy between logits (the model's prediction) and tokens (the true values).

    The prediction at position k is scored against tokens[:, k+1]; the prediction at
    the last position has no target and is ignored.

    Args:
        logits: Float[Tensor, "batch seq d_vocab"]
        tokens: Int[Tensor, "batch seq"]; may live on a different device than `logits`

    Returns:
        A scalar tensor: the negative log-probability of the correct next token,
        averaged over batch and positions.

    (optional, you can just use return_type="loss" instead.)
    '''
    log_probs = torch.log_softmax(logits, dim=-1)
    # torch.gather needs the index on the same device as its input, while callers may
    # pass CPU tokens alongside logits that live on the model's device.
    next_tokens = tokens[:, 1:, None].to(log_probs.device)
    pred_log_probs = torch.gather(log_probs[:, :-1], -1, next_tokens)[..., 0]
    return -pred_log_probs.mean()


def get_ablation_scores(
    model: HookedTransformer, 
    tokens: torch.Tensor,
):
    '''
    Measures how much the cross entropy loss on `tokens` increases when a single head
    is ablated, for every head of every layer. A head is ablated by attaching
    `head_ablation_hook` to the "v" activation of its layer.

    tokens: Int[Tensor, "batch seq"]
    out: -> Float[Tensor, "n_layers n_heads"], where entry [layer, head] is
        (loss with that head ablated) - (loss of the unmodified model)
    '''
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    ablation_scores = torch.zeros((n_layers, n_heads), device=model.cfg.device)

    # Baseline: loss of the clean model, with any previously attached hooks removed.
    model.reset_hooks()
    loss_no_ablation = cross_entropy_loss(model(tokens, return_type="logits"), tokens)

    for layer in tqdm(range(n_layers)):
        v_hook_name = utils.get_act_name("v", layer)
        for head in range(n_heads):
            # Bind the head index so the hook has the (activation, hook) signature.
            ablate_this_head = functools.partial(head_ablation_hook, head_index_to_ablate=head)
            ablated_logits = model.run_with_hooks(tokens, fwd_hooks=[(v_hook_name, ablate_this_head)])
            # Zero means "ablating this head did not change the loss".
            ablation_scores[layer, head] = cross_entropy_loss(ablated_logits, tokens) - loss_no_ablation

    return ablation_scores

In [None]:
# Ablate every head in turn on the repeated-token sequence and plot the change in loss.
ablation_scores = get_ablation_scores(model, rep_tokens)
plotly_utils.imshow(
    ablation_scores,
    labels=dict(x="Head", y="Layer"),
    title="Ablation score by Head",
    text_auto=".2f",
    width=900,
    height=400,
)

In [None]:
# Solution from Callum McDougall's files
# from callummcdougall_chapters.chapter1_transformers.exercises.part2_intro_to_mech_interp import tests
# tests.test_get_ablation_scores(ablation_scores, model, rep_tokens)

In [None]:
# Bonus: not implemented (got bored),
# but performance should improve when you ablate everything except the previous-token head and the two induction heads,
# as they seem to influence the output the most!