In [1]:
import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv

# from plotly_utils import plot_comp_scores, plot_logit_attribution, plot_loss_difference

# Autograd is switched off globally: nothing in this notebook needs gradients,
# and skipping them saves computation time.
t.set_grad_enabled(False)

# Use the GPU when CUDA is available, otherwise the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")


  from .autonotebook import tqdm as notebook_tqdm


# Toy Attention-Only 2 Layer Model

- No layer norm, no MLPs, and no biases.
- Positional embeddings are added to the residual stream when computing the query and key vectors, rather than being added to the token embeddings (this is the `shortformer` positional embedding type used in the config below).

In [2]:
cfg = HookedTransformerConfig(
    # Architecture: two attention-only layers (attn_only defaults to False).
    n_layers=2,
    attn_only=True,
    # Widths: 12 heads x 64 dims per head = 768 = d_model.
    d_model=768,
    n_heads=12,
    d_head=64,
    # Context length, vocabulary size and the tokenizer to pair with them.
    n_ctx=2048,
    d_vocab=50278,
    tokenizer_name="EleutherAI/gpt-neox-20b",
    # Attention behaviour.
    attention_dir="causal",
    positional_embedding_type="shortformer",
    use_attn_result=True,
    # No normalization (the default is "LN", i.e. layernorm with weights & biases).
    normalization_type=None,
    seed=398,
)

Download weights for this model

In [3]:
# Fetch the pretrained checkpoint for this model from the Hugging Face Hub.
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
from huggingface_hub import hf_hub_download

# Hub repository and the checkpoint file inside it.
REPO_ID = "callummcdougall/attn_only_2L_half"
FILENAME = "attn_only_2L_half.pth"

# The returned path is passed to t.load in the next code cell.
weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Create model and load weights

In [4]:
# Build the model from `cfg`, then load the downloaded checkpoint into it.
model = HookedTransformer(cfg)
# NOTE(review): t.load unpickles the file, so only load checkpoints from a trusted source.
# map_location=device maps the loaded tensors onto the device chosen in the setup cell.
pretrained_weights = t.load(weights_path, map_location=device)
# Kept as the last expression so the notebook displays the load_state_dict result.
model.load_state_dict(pretrained_weights)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


<All keys matched successfully>

# Visualize attention patterns

Visualize the attention patterns of both attention layers on the following text:

In [5]:
# Prompt used for every attention visualization below. The adjacent literals
# concatenate into one single-line string.
text = (
    "We think that powerful, significantly superhuman machine intelligence is "
    "more likely than not to be created this century. "
    "If current machine learning techniques were scaled up to this level, "
    "we think they would by default produce systems that are deceptive or manipulative, "
    "and that no solid plans are known for how to avoid this."
)

# Run the model once on the prompt and keep its cached activations;
# remove_batch_dim=True drops the batch dimension from the cached tensors.
logits, cache = model.run_with_cache(text, remove_batch_dim=True)

## Layer 0 attention

In [7]:
# Attention patterns of layer 0, taken from the cached forward pass.
layer0_attention_patterns = cache["pattern", 0]
print(layer0_attention_patterns.shape)  # (n_heads, seqQ, seqK)

# Render every head of the layer, labelled L0H0, L0H1, ...
display(
    cv.attention.attention_heads(
        tokens=model.to_str_tokens(text),
        attention=layer0_attention_patterns,
        attention_head_names=[f"L0H{head}" for head in range(cfg.n_heads)],
    )
)

torch.Size([12, 62, 62])


## Layer 1 attention

In [8]:
# Attention patterns of layer 1, taken from the cached forward pass.
layer1_attention_patterns = cache["pattern", 1]
print(layer1_attention_patterns.shape)  # (n_heads, seqQ, seqK)

# Render every head of the layer, labelled L1H0, L1H1, ...
display(
    cv.attention.attention_heads(
        tokens=model.to_str_tokens(text),
        attention=layer1_attention_patterns,
        attention_head_names=[f"L1H{head}" for head in range(cfg.n_heads)],
    )
)

torch.Size([12, 62, 62])


We see three kinds of attention heads:
1. Current-token heads - heads that attend mainly to the token at the current (same) position, e.g. L0H7.
2. Previous-token heads - heads that attend mainly to the token at the previous position, e.g. L1H6.
3. First-token heads - heads that attend mainly to the token at the first position, e.g. L1H4.

# Detect different types of attention heads 