In [1]:
import torch as t
import torch.nn.functional as F
import numpy as np
import math
from transformer_lens import HookedTransformer as ht
import transformer_lens.utils as utils

In [2]:
from einops import rearrange, reduce, einsum
import circuitsvis as cv

In [3]:
t.cuda.is_available()

True

In [4]:
t.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x749f03ece690>

In [5]:
device = t.device("cuda")

In [6]:
MAIN = __name__ == "__main__"

# Model

In [7]:
model = ht.from_pretrained("gpt2-small", device=device)

`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model gpt2-small into HookedTransformer


In [8]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

163049041

In [7]:
model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'original_architecture': 'GPT2LMHeadModel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_adjacent_pairs': False,
 'rotary_base': 10000,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': 

In [8]:
print(f"W_E: {model.W_E.shape}")
print(f"W_K: {model.W_K.shape}")
print(f"W_Q: {model.W_Q.shape}")
print(f"W_V: {model.W_V.shape}")
print(f"W_O: {model.W_O.shape}")

W_E: torch.Size([50257, 768])
W_K: torch.Size([12, 12, 768, 64])
W_Q: torch.Size([12, 12, 768, 64])
W_V: torch.Size([12, 12, 768, 64])
W_O: torch.Size([12, 12, 64, 768])


In [5]:
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

# Tokens

In [43]:
vocab_list = np.array(sorted(list(model.tokenizer.vocab.items()), key=lambda n: n[1]))
vocab_list[:10]

array([['!', '0'],
       ['"', '1'],
       ['#', '2'],
       ['$', '3'],
       ['%', '4'],
       ['&', '5'],
       ["'", '6'],
       ['(', '7'],
       [')', '8'],
       ['*', '9']], dtype='<U128')

In [51]:
t_seq = model.to_tokens("Exercise - how many words does your model guess correctly?")
vocab_list[t_seq.cpu().numpy()]

array([[['<|endoftext|>', '50256'],
        ['Ex', '3109'],
        ['ercise', '23697'],
        ['Ġ-', '532'],
        ['Ġhow', '703'],
        ['Ġmany', '867'],
        ['Ġwords', '2456'],
        ['Ġdoes', '857'],
        ['Ġyour', '534'],
        ['Ġmodel', '2746'],
        ['Ġguess', '4724'],
        ['Ġcorrectly', '9380'],
        ['?', '30']]], dtype='<U128')

In [48]:
model.to_str_tokens("Exercise - how many words does your model guess correctly?", prepend_bos=False)

['Ex',
 'ercise',
 ' -',
 ' how',
 ' many',
 ' words',
 ' does',
 ' your',
 ' model',
 ' guess',
 ' correctly',
 '?']

In [9]:
text = '''There are many types of senses. One of them is the sense of sight while the other is the sense of sound. These two senses are quite common. When you look at the tree, you are perceiving sense of sight.'''

logits: t.Tensor = model(text, return_type="logits")
print(logits.shape)
t_pred = logits.argmax(dim=-1).squeeze()
print(t_pred.shape)

torch.Size([1, 47, 50257])
torch.Size([47])


In [10]:
print(f"{sum(model.to_tokens(text).squeeze()[1:]==t_pred[:-1]).item()} out of {len(t_pred)} tokens are correctly predicted")

24 out of 47 tokens are correctly predicted


# Caching activation

In [11]:
model.to_tokens(text).shape

torch.Size([1, 47])

In [12]:
tok_text = model.to_tokens(text)

logits, cache = model.run_with_cache(tok_text)

In [13]:
cache

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re

In [14]:
print(cache['blocks.0.attn.hook_q'].shape)
print(cache['blocks.0.attn.hook_k'].shape)
print(cache['blocks.0.attn.hook_pattern'].shape)

torch.Size([1, 47, 12, 64])
torch.Size([1, 47, 12, 64])
torch.Size([1, 12, 47, 47])


In [15]:
layer0_pattern_from_cache = cache['blocks.0.attn.hook_pattern']

#dim: (batch, n_seq, n_heads, d_head)
q_from_cache = cache['blocks.0.attn.hook_q']
#dim: (batch, n_seq, n_heads, d_head)
k_from_cache = cache['blocks.0.attn.hook_k']

d_head = q_from_cache.shape[-1]
d_seq = q_from_cache.shape[1]

#layer0_pattern_from_q_and_k = t.matmul(rearrange(q_from_cache, 'b s h d -> b h s d'),
#                                       rearrange(k_from_cache, 'b s h d -> b h d s'))
layer0_pattern_from_q_and_k = einsum(q_from_cache, k_from_cache, 
                                    'b qs h d, b ks h d -> b h qs ks')/math.sqrt(d_head)

#same as t.arange(0, d_seq, device=device).reshape(-1,1)<t.arange(0, d_seq, device=device).reshape(1,-1)
mask = ~t.tril(t.ones(d_seq, d_seq, device=device)).bool()
layer0_pattern_from_q_and_k = layer0_pattern_from_q_and_k.masked_fill(mask, -1e6)
layer0_pattern_from_q_and_k = layer0_pattern_from_q_and_k.softmax(dim=-1)

In [16]:
(layer0_pattern_from_cache - layer0_pattern_from_q_and_k).abs().sum()

tensor(0., device='cuda:0')

In [18]:
cv.attention.attention_patterns(attention=cache['blocks.3.attn.hook_pattern'].squeeze(),
                               tokens=model.to_str_tokens(text))

Note the head 0 of layer 3 attend to repeated or synonym words.

In [22]:
layer0_pattern_from_cache.squeeze().shape

torch.Size([12, 47, 47])