In [1]:
from transformer_lens.cautils.notebook import *

model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
# model.set_use_split_qkv_input(True)

clear_output()

In [4]:
from transformer_lens import FactoredMatrix

W_E = model.W_E
W_U = model.W_U

In [5]:
# Calculate W_{EE} edit
batch_size = 1000
nrows = model.cfg.d_vocab
W_EE = t.zeros((nrows, model.cfg.d_model)).to(device)

for i in tqdm(range(0, nrows + batch_size, batch_size)):
    cur_range = t.tensor(range(i, min(i + batch_size, nrows)))
    if len(cur_range)>0:
        embeds = W_E[cur_range].unsqueeze(0)
        pre_attention = model.blocks[0].ln1(embeds)
        post_attention = einops.einsum(
            pre_attention, 
            model.W_V[0],
            model.W_O[0],
            "b s d_model, num_heads d_model d_head, num_heads d_head d_model_out -> b s d_model_out",
        )
        normalized_resid_mid = model.blocks[0].ln2(post_attention + embeds)
        resid_post = model.blocks[0].mlp(normalized_resid_mid)
        W_EE[cur_range.to(device)] = resid_post

  0%|          | 0/52 [00:00<?, ?it/s]

In [16]:
W_QK = FactoredMatrix(model.W_Q[10, 7], model.W_K[10, 7].T)

W_QK_full = W_E @ W_QK @ W_E.T
W_QK_full_eff = W_U.T @ W_QK @ W_EE.T

In [21]:
line(W_QK.S, height=400, width=600, title="W_QK")

In [24]:
line(W_QK_full.S, height=400, width=600, title="W_QK full (W_E on both sides)")

In [25]:
line(W_QK_full_eff.S, height=400, width=600, title="W_QK effective (W_U on query, W_EE on key)")

In [90]:
model.to_str_tokens(W_QK_full_eff.U[:, 0].abs().argmax())

[' Leilan']

In [54]:
model.to_str_tokens(W_QK_full_eff.Vh[:, 0].abs().argmax())

[' protocols']

# High cosine sim between names and the principal directions of $W_{QK}$ ?

In [65]:
W_U.T.shape

torch.Size([50257, 768])

In [66]:
W_QK.U.shape

torch.Size([768, 64])

In [67]:
squared_cos_sim.shape

torch.Size([50257])

In [89]:
W_U_normed = W_U / W_U.norm(dim=0)

squared_cos_sim = (W_U.T @ W_QK.U[:, 0]) ** 2
max_cos_sim_words = squared_cos_sim.topk(10).indices

model.to_str_tokens(max_cos_sim_words)

['ón', 'igh', 'í', 'ump', 'atican', 'agn', 'ivia', 'ndum', 'ocese', 'lev']

In [74]:
W_EE.shape

torch.Size([50257, 768])

In [76]:
W_EE_normed = W_EE / W_EE.norm(dim=-1, keepdim=True)

squared_cos_sim = (W_QK.Vh[:, 0] @ W_EE_normed.T) ** 2
max_cos_sim_words = squared_cos_sim.topk(10).indices

model.to_str_tokens(max_cos_sim_words)

[' Walker',
 ' Wade',
 ' House',
 ' trespass',
 ' Koh',
 ' Domestic',
 ' Simpson',
 'fare',
 'll',
 'iss']

In [77]:
W_E_normed = W_E / W_E.norm(dim=-1, keepdim=True)

squared_cos_sim = (W_QK.Vh[:, 0] @ W_E_normed.T) ** 2
max_cos_sim_words = squared_cos_sim.topk(10).indices

model.to_str_tokens(max_cos_sim_words)

['�',
 'ategory',
 ' constitu',
 'emale',
 'uilt',
 'lopp',
 'ÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ',
 '�',
 'ÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ',
 'ゴン']

In [93]:
W_QK.U.shape

torch.Size([768, 64])

In [106]:
name_cos_sims = []

for name in NAMES:

    unit_unembedding_vector = W_U[:, model.to_single_token(name)]
    unit_unembedding_vector /= unit_unembedding_vector.norm()

    variance_of_unembedding_explained_by_left_singular_space = einops.einsum(
        unit_unembedding_vector,
        W_QK.U,
        "d_model, d_model d_vocab -> d_vocab"
    ).pow(2).sum()

    name_cos_sims.append(variance_of_unembedding_explained_by_left_singular_space.item())

hist(np.array(name_cos_sims) * 12, template="simple_white", width=800)

In [105]:
generic_words = [
    "the", "be", "to", "of", "and", "a", "in", "that", "have", "I", "it", "for", "not", "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from", "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", 
    "all", "would", "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which", "go", "me", "when", "make", "can", "like", "time", "no", "just", "him", "know", "take", "people", "into", "year", "your", "good", "some", 
    "could", "them", "see", "other", "than", "then", "now", "look", "only", "come", "its", "over", "think", "also", "back", "after", "use", "two", "how", "our", "work", "first", "well", "way", "even", "new", "want", "because", "any", 
    "these", "give", "day", "most", "us"
]

generic_word_cos_sims = []

for word in generic_words:

    unit_unembedding_vector = W_U[:, model.to_single_token(word)]
    unit_unembedding_vector /= unit_unembedding_vector.norm()

    variance_of_unembedding_explained_by_left_singular_space = einops.einsum(
        unit_unembedding_vector,
        W_QK.U,
        "d_model, d_model d_vocab -> d_vocab"
    ).pow(2).sum()

    generic_word_cos_sims.append(variance_of_unembedding_explained_by_left_singular_space.item())

hist(np.array(generic_word_cos_sims) * 12, template="simple_white", width=800)