In [1]:
import os
import sys
import plotly.express as px
import plotly.subplots as sp
import torch
from pathlib import Path
import numpy as np
import einops
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from matplotlib import pyplot as plt
import plotly_utils

from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv
MAIN = __name__ == "__main__"



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.environ['KMP_DUPLICATE_LIB_OK']='True'
torch.set_grad_enabled(False)
#device_name = "mps" if torch.backends.mps.is_available() else "cpu"
device_name = "cpu"
device = torch.device(device_name)
print(f"{device=}")


device=device(type='cpu')


In [3]:
cfg = HookedTransformerConfig(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir="causal",
    attn_only=True, # defaults to False
    tokenizer_name="EleutherAI/gpt-neox-20b", 
    seed=398,
    use_attn_result=True,
    normalization_type=None, # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer"
)

In [4]:
weights_dir = "attn_only_2L_half.pth"
if not os.path.abspath(weights_dir).exists():
    url = "https://drive.google.com/uc?id=1vcZLJnJoYKQs-2KOjkd6LvHZrkSdoxhu"
    output = str(weights_dir)
    gdown.download(url, output)

AttributeError: 'str' object has no attribute 'exists'

In [None]:
model = HookedTransformer(cfg)
pretrained_weights = torch.load(weights_dir, map_location=device)
model.load_state_dict(pretrained_weights)

In [None]:
input_text = "We think that powerful HookedTransformer, significantly superhuman machine intelligence is more likely than not to be created this century by HookedTransformer. If current HookedTransformer machine learning techniques were scaled up to this level with HookedTransformer, we think they would by default produce HookedTransformer systems that are deceptive or manipulative HookedTransformer, and that no solid plans are known for HookedTransformer on how to avoid this. HookedTransformer."
#input_text = "Try inputting different text, and see how stable your results are. Do you always get the same classifications for heads"
#input_text = "Again, you are strongly recommended to read the corresponding section of the glossary, before continuing (or this LessWrong post). In brief, however, the induction circuit consists of a previous token head in layer 0 and an induction head in layer 1, where the induction head learns to attend to the token immediately after copies of the current token via K-Composition with the previous token"


In [None]:
logits, cache = model.run_with_cache(input_text, remove_batch_dim=True)

In [None]:
layer_n = 0
attention_pattern_0 = cache["pattern", layer_n]
tokens = model.to_str_tokens(input_text)
print(f"Layer {layer_n} Head Attention Patterns:")
display(cv.attention.attention_heads(
    tokens=tokens,
    attention=attention_pattern_0
))

In [None]:
layer_n = 1
attention_pattern_0 = cache["pattern", layer_n]
tokens = model.to_str_tokens(input_text)
print(f"Layer {layer_n} Head Attention Patterns:")
display(cv.attention.attention_heads(
    tokens=tokens,
    attention=attention_pattern_0
))

In [None]:
def current_attn_detector(cache: ActivationCache, batch_n = 0, threshold=3) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be current-token heads
    '''
    current_attn_heads = []
    for layer in range(model.cfg.n_layers):
        attention_pattern = cache["pattern", layer]
        if cache.has_batch_dim:
            if batch_n is not None:
                attention_pattern = attention_pattern[batch_n].squeeze()
            else:
                # batch_n = None === average activations across batch.
                attention_pattern = attention_pattern.mean(dim=0)
        for head in range(model.cfg.n_heads):
            curr_attention_pattern = attention_pattern[head].squeeze()
            # Get the diagonal values (self-attention values)
            diag = torch.diagonal(curr_attention_pattern)
            
            norm_attention_pattern = (curr_attention_pattern - curr_attention_pattern.mean()) / curr_attention_pattern.std()
            
            norm_diag = (diag - curr_attention_pattern.mean()) / curr_attention_pattern.std()
            # Check if the mean diagonal value is significantly larger than the mean off-diagonal value
            if norm_diag.mean() > norm_attention_pattern.mean() + threshold * norm_attention_pattern.std():
                current_attn_heads.append(f"{layer}.{head}")
    return current_attn_heads

def prev_attn_detector(cache: ActivationCache, batch_n = 0, threshold=3) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be prev-token heads
    '''
    current_attn_heads = []
    for layer in range(model.cfg.n_layers):
        attention_pattern = cache["pattern", layer]
        if cache.has_batch_dim:
            if batch_n is not None:
                attention_pattern = attention_pattern[batch_n].squeeze()
            else:
                # batch_n = None === average activations across batch.
                attention_pattern = attention_pattern.mean(dim=0)

        for head in range(model.cfg.n_heads):
            curr_attention_pattern = attention_pattern[head].squeeze()
            # Get the first subdiagonal values (prev-token-attention values)
            diag = torch.diag(curr_attention_pattern, diagonal=-1)
            
            norm_attention_pattern = (curr_attention_pattern - curr_attention_pattern.mean()) / curr_attention_pattern.std()
            
            norm_diag = (diag - curr_attention_pattern.mean()) / curr_attention_pattern.std()
            # Check if the mean first token attn value is significantly larger than the mean off-diagonal value
            if norm_diag.mean() > norm_attention_pattern.mean() + threshold * norm_attention_pattern.std():
                current_attn_heads.append(f"{layer}.{head}")
    return current_attn_heads


def first_attn_detector(cache: ActivationCache, batch_n = 0, threshold=3) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be first-token heads
    '''
    current_attn_heads = []
    for layer in range(model.cfg.n_layers):
        attention_pattern = cache["pattern", layer]
        if cache.has_batch_dim:
            if batch_n is not None:
                attention_pattern = attention_pattern[batch_n].squeeze()
            else:
                # batch_n = None === average activations across batch.
                attention_pattern = attention_pattern.mean(dim=0)

        for head in range(model.cfg.n_heads):
            curr_attention_pattern = attention_pattern[head].squeeze()
            # Get the first column (first token attention values)
            first_token_attn = curr_attention_pattern[:, 0]
            
            norm_attention_pattern = (curr_attention_pattern - curr_attention_pattern.mean()) / curr_attention_pattern.std()
            
            norm_diag = (first_token_attn - curr_attention_pattern.mean()) / curr_attention_pattern.std()
            # Check if the mean diagonal value is significantly larger than the mean off-diagonal value
            if norm_diag.mean() > norm_attention_pattern.mean() + threshold * norm_attention_pattern.std():
                current_attn_heads.append(f"{layer}.{head}")
    return current_attn_heads

def create_induction_mask(tokens):
    """
    Create a mask where a value 1 at position (i, j) where i > j and 
    the value at index i is exactly equal to the value at index j.

    Args:
    tokens: A 1D tensor of tokens.

    Returns:
    A 2D mask tensor.
    """
    # Get the length of the sequence
    seq_len = tokens.size()[0]

    # Initialize an empty mask
    mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)

    # Compare each token to all previous tokens - 1, for induction, skip the first token
    for i in range(1, seq_len):
        mask[i, 1:i] = (tokens[i] == tokens[:i-1])

    return mask

def induction_head_detector(cache: ActivationCache, tokens, batch_n = 0, threshold=3) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be prev-token heads
    '''
    # except for the BOS token at index 0, 
    # for the second half of the sequence, index repeated_seq_len + i is exactly equal to index i
    # Induction heads will have high activation at position (repeated_seq_len + 1, i + 1) as it would be
    # attending to the next token of the last time this token was seen!
    # now, alternatively, we can get the activations of the last time this token was seen, without asserting
    # equality, maybe that can work for the previous non-repeated examples too?

    induction_mask = create_induction_mask(tokens.squeeze())
    current_attn_heads = []
    for layer in range(model.cfg.n_layers):
        attention_pattern = cache["pattern", layer]
        if cache.has_batch_dim:
            if batch_n is not None:
                attention_pattern = attention_pattern[batch_n].squeeze()
            else:
                # batch_n = None === average activations across batch.
                attention_pattern = attention_pattern.mean(dim=0)

        for head in range(model.cfg.n_heads):
            curr_attention_pattern = attention_pattern[head].squeeze()
            induction_activations = curr_attention_pattern[induction_mask]
            
            norm_attention_pattern = (curr_attention_pattern - curr_attention_pattern.mean()) / curr_attention_pattern.std()
            norm_induction = (induction_activations - curr_attention_pattern.mean()) / curr_attention_pattern.std()
            # Check if the mean induction token attn value is significantly larger than the mean activations
            if norm_induction.mean() > norm_attention_pattern.mean() + threshold * norm_attention_pattern.std():
                current_attn_heads.append(f"{layer}.{head}")
    return current_attn_heads

In [None]:
print("Heads attending to current token  = ", ", ".join(current_attn_detector(cache, threshold=3)))
print("Heads attending to previous token = ", ", ".join(prev_attn_detector(cache, threshold=3)))
print("Heads attending to first token    = ", ", ".join(first_attn_detector(cache, threshold=4)))
print("Induction heads: ", ", ".join(induction_head_detector(cache, model.to_tokens(input_text), threshold=2)))

In [None]:
def generate_repeated_tokens(
    model: HookedTransformer, seq_len: int, batch: int = 1
) -> torch.Tensor:
    '''
    Generates a sequence of repeated random tokens

    Int[torch.Tensor, "batch full_seq_len"]
    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
    '''
    prefix = (torch.ones(batch, 1) * model.tokenizer.bos_token_id).long()
    first_half = torch.randint(low = 0, high = model.cfg.d_vocab, size = torch.Size([batch, seq_len]), dtype=torch.long)
    repeated_tokens = torch.concat([prefix, first_half, first_half], dim=1)
    return repeated_tokens

def run_and_cache_model_repeated_tokens(model: HookedTransformer, seq_len: int, batch: int = 1) -> Tuple[torch.Tensor, torch.Tensor, ActivationCache]:
    '''
    Generates a sequence of repeated random tokens, and runs the model on it, returning logits, tokens and cache

    Should use the `generate_repeated_tokens` function above

    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
        rep_logits: [batch, 1+2*seq_len, d_vocab]
        rep_cache: The cache of the model run on rep_tokens
    '''
    rep_tokens = generate_repeated_tokens(model, seq_len, batch=batch)
    rep_logits, rep_cache = model.run_with_cache(rep_tokens)
    print(rep_tokens.size(), rep_logits.size(), type(rep_cache))
    return rep_tokens, rep_logits, rep_cache


In [None]:
def get_log_probs(logits, tokens):
    """
    """
    print(logits.size(), tokens.size())
    log_probs = torch.log_softmax(logits, dim=-1)    
    log_probs = log_probs[:,:-1].gather(-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    print(log_probs.size())
    return log_probs

def plot_losses(log_probs, seq_len):
    """
    """
    # Calculate the mean log probabilities for the first and second halves of the sequence
    first_half = log_probs[:seq_len]
    second_half = log_probs[seq_len:]

    # Plot the difference
    plt.plot(first_half, label="first half")
    plt.plot(second_half, label="second_half")
    plt.ylabel('Log Probability of correct prediction')
    plt.xlabel('Token position')
    plt.legend()
    # plt.title(f'Loss Difference for "{rep_str}"')
    plt.grid()
    plt.show()

In [None]:
seq_len = 20
batch = 1
(rep_tokens, rep_logits, rep_cache) = run_and_cache_model_repeated_tokens(model, seq_len, batch)
rep_str = model.to_str_tokens(rep_tokens)
rep_cache.remove_batch_dim()
model.reset_hooks()
log_probs = get_log_probs(rep_logits, rep_tokens).squeeze()

print(f"Performance on the first half: {log_probs[:seq_len].mean():.3f}")
print(f"Performance on the second half: {log_probs[seq_len:].mean():.3f}")
plotly_utils.plot_loss_difference(log_probs, rep_str, seq_len)
plot_losses(log_probs, seq_len)
print(rep_str)

In [None]:
layer_n = 0
print(f"Layer {layer_n} attention heads")
display(cv.attention.attention_heads(
    tokens=rep_str,
    attention=rep_cache["pattern", layer_n]
))

In [None]:
layer_n = 1
print(f"Layer {layer_n} attention heads")
display(cv.attention.attention_heads(
    tokens=rep_str,
    attention=rep_cache["pattern", layer_n]
))

In [None]:
print("Heads attending to current token  = ", ", ".join(current_attn_detector(rep_cache, threshold=3)))
print("Heads attending to previous token = ", ", ".join(prev_attn_detector(rep_cache, threshold=3)))
print("Heads attending to first token    = ", ", ".join(first_attn_detector(rep_cache, threshold=3)))
print("Induction heads: ", ", ".join(induction_head_detector(rep_cache, rep_tokens, threshold=2)))

Heads attending to current token  =  0.1, 0.9, 0.11, 1.6, 1.7
Heads attending to previous token =  0.0, 0.4, 0.7, 0.9
Heads attending to first token    =  0.2, 0.3, 0.6, 1.0, 1.1, 1.2, 1.3, 1.4, 1.6, 1.8, 1.9, 1.10, 1.11
Induction heads:  1.4, 1.10

In [None]:
# create 1000 examples and find common heads!
batch = 1000
seq_len = 51

prefix = (torch.ones(batch, 1) * model.tokenizer.bos_token_id).long()
random_tokens = torch.randint(low = 0, high = model.cfg.d_vocab, size = torch.Size([batch, seq_len-1]), dtype=torch.long)
batch_tokens = torch.concat([prefix, random_tokens, random_tokens[:, 15:20], random_tokens[:, 5:10]], dim=1)
batch_logits, batch_cache = model.run_with_cache(batch_tokens)
print(batch_tokens.size(), batch_logits.size(), type(batch_cache))


In [None]:
curr_token_heads = []
prev_token_heads = []
first_token_heads = []
induction_heads = []
for batch_n in range(batch):
    curr_token_heads.append(current_attn_detector(batch_cache, batch_n=batch_n, threshold=3))
    prev_token_heads.append(prev_attn_detector(batch_cache, batch_n=batch_n, threshold=3))
    first_token_heads.append(first_attn_detector(batch_cache, batch_n=batch_n, threshold=3))
    induction_heads.append(induction_head_detector(batch_cache, batch_tokens[batch_n].squeeze(), batch_n=batch_n, threshold=2))

In [None]:
import itertools
import collections
import pandas as pd

def plot_histogram(data, title):
    """
    Plot a histogram of counts across a 2D array of string values.

    Args:
    data: A 2D list of strings.

    Returns:
    None
    """
    # Flatten the 2D list into a 1D list
    flattened_data = list(itertools.chain.from_iterable(data))

    # Count the occurrence of each string
    counter = collections.Counter(flattened_data)

    # Convert the counter to a DataFrame for plotting
    df = pd.DataFrame.from_dict(counter, orient='index').reset_index()
    df = df.rename(columns={'index':'Head', 0:'Count'})

    # Create the plot
    fig = px.histogram(df, x='Head', y='Count', title=title)
    fig.show()


In [None]:
plot_histogram(curr_token_heads, "Current Token Attention Heads")
plot_histogram(prev_token_heads, "Previous Token Attention Heads")
plot_histogram(first_token_heads, "First Token Attention Heads")
plot_histogram(induction_heads, "Induction Heads")

In [None]:
def plot_attention_heads(attention_scores, layer_n):
    """
    Plot the attention scores of all heads.

    Args:
    attention_scores: A tensor of shape [n_heads, seq_len, seq_len]
    representing the attention scores for each head.
    """
    n_heads, seq_len, _ = attention_scores.shape

    # Create a subplot with 4 rows and 3 columns (for 12 heads)
    fig = sp.make_subplots(rows=4, cols=3, subplot_titles=[f'Head {i}' for i in range(n_heads)])

    for i in range(n_heads):
        # Compute the row and column indices for the subplot
        row = i // 3 + 1
        col = i % 3 + 1

        # Plot the attention scores for this head
        img = px.imshow(attention_scores[i], color_continuous_scale='viridis', binary_string=True)

        fig.add_trace(
            img.data[0],
            row=row,
            col=col
        )

    fig.update_layout(height=800, width=800, title_text=f"Attention Scores for Each Head at layer {layer_n}")
    fig.show()

In [None]:
layer_n = 0
plot_attention_heads(batch_cache["pattern", layer_n].mean(dim=0), layer_n)

In [None]:
layer_n = 1
plot_attention_heads(batch_cache["pattern", layer_n].mean(dim=0), layer_n)

In [None]:
prefix_idx = (torch.ones(1, 1) * model.tokenizer.bos_token_id).long()
random_token_idx = torch.randint(low = 1, high = model.cfg.d_vocab, size = torch.Size([1, seq_len-1]), dtype=torch.long)
batch_token_idx = torch.concat([prefix_idx, random_token_idx, random_token_idx[:, 15:20], random_token_idx[:, 5:10]], dim=1)
print("Heads attending to current token  = ", ", ".join(current_attn_detector(batch_cache, batch_n = None, threshold=3)))
print("Heads attending to previous token = ", ", ".join(prev_attn_detector(batch_cache, batch_n = None, threshold=3)))
print("Heads attending to first token    = ", ", ".join(first_attn_detector(batch_cache, batch_n = None, threshold=4)))
print("Induction heads: ", ", ".join(induction_head_detector(batch_cache, batch_token_idx, batch_n = None, threshold=2)))
