In [1]:
from transformer_lens.cautils.notebook import *

# Load GPT-2 small through TransformerLens with the weight-processing options
# below switched on (unembed centering, writing-weight centering, LayerNorm
# folding, and refactoring of the factored attention matrices).
# `HookedTransformer` is expected to come from the star import above.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
# model.set_use_split_qkv_input(True)

# Presumably clears the cell output produced while loading the model
# (helper is not defined in this notebook — TODO confirm it is IPython's clear_output).
clear_output()

# full_data = get_webtext()
# TOTAL_OWT_SAMPLES = 100
# SEQ_LEN = 20
# data = full_data[:TOTAL_OWT_SAMPLES]

from transformer_lens import FactoredMatrix

In [7]:
def get_effective_embedding(model: HookedTransformer) -> Dict[str, Float[Tensor, "d_vocab d_model"]]:
    """Build several candidate "effective embedding" matrices from block 0 of `model`.

    Each row of `model.W_E` is pushed through the first transformer block on its
    own: the block-0 `ln1`, then the summed `W_V @ W_O` map of every layer-0 head
    (no attention pattern and no attention biases are applied), a residual add,
    the block-0 `ln2`, and finally the block-0 MLP.

    Args:
        model: Model whose `W_E`, `W_U`, `W_V[0]`, `W_O[0]` and `blocks[0]`
            (`ln1`, `ln2`, `mlp`) are read.

    Returns:
        A dict mapping a human-readable label to a matrix with one row per token:
            - "W_U (or W_E raw, no MLPs)": the transposed unembedding `W_U.T`.
            - "W_E (including MLPs)": residual stream after attention plus the MLP output.
            - "W_E (only MLPs)": the MLP output alone.
    """
    W_E = model.W_E.clone()
    W_U = model.W_U.clone()
    # t.testing.assert_close(W_E[:10, :10], W_U[:10, :10].T)  NOT TRUE, because of the center unembed part!

    # Add a size-1 batch axis so the vocab axis plays the role of the sequence axis.
    embeds = W_E.unsqueeze(0)
    pre_attention = model.blocks[0].ln1(embeds)
    # Sum over heads of x @ W_V @ W_O: every token only "attends" to itself.
    post_attention = einops.einsum(
        pre_attention,
        model.W_V[0],
        model.W_O[0],
        "b s d_model, num_heads d_model d_head, num_heads d_head d_model_out -> b s d_model_out",
    )
    resid_mid = post_attention + embeds
    normalized_resid_mid = model.blocks[0].ln2(resid_mid)
    mlp_out = model.blocks[0].mlp(normalized_resid_mid)

    # Drop only the batch axis added above; a bare squeeze() would also remove
    # any other axis that happened to have size 1.
    W_EE = mlp_out.squeeze(0)
    W_EE_full = resid_mid.squeeze(0) + W_EE

    return {
        "W_U (or W_E raw, no MLPs)": W_U.T,
        # "W_E (raw, no MLPs)": W_E,
        "W_E (including MLPs)": W_EE_full,
        "W_E (only MLPs)": W_EE
    }

# Label -> matrix dict of candidate embeddings for the loaded model; consumed by the plotting cell below.
embeddings_dict = get_effective_embedding(model)

$$
W_E^Q W_Q W_K^T (W_E^K)^T
$$

In [10]:
def plot_random_sample(
    embeddings_dict: Dict[str, Float[Tensor, "d_vocab d_model"]],
    model: HookedTransformer = model,
    sample_size: int = 50,
    num_batches: int = 1,
    head: Tuple[int, int] = (10, 7)
):
    """Plot token-to-token attention patterns of one head for every (Q, K) embedding pairing.

    For each ordered pair of embedding matrices in `embeddings_dict` (one used
    on the query side, one on the key side), a random set of tokens is drawn
    and the matrix `E_Q[tokens] @ W_Q @ W_K.T @ E_K[tokens].T` is computed from
    row-normalised embeddings. A softmax over the key axis is applied (no
    `1/sqrt(d_head)` scaling and no biases are used), the result is averaged
    over `num_batches` independent draws, and the uniform baseline
    `1 / sample_size` is subtracted before plotting. A strong diagonal means
    tokens attend preferentially to themselves.

    Args:
        embeddings_dict: Label -> embedding matrix with one row per token.
        model: Model providing `W_Q`, `W_K` and `cfg.d_vocab`. Note the default
            is bound to the global `model` at definition time.
        sample_size: Number of distinct tokens drawn per batch.
        num_batches: Number of independent draws to average over.
        head: `(layer, head_index)` of the attention head to inspect.

    Raises:
        ValueError: If `num_batches < 1`, or `sample_size` is not in
            `[1, model.cfg.d_vocab]` (tokens are drawn without replacement).
    """
    d_vocab = model.cfg.d_vocab
    if num_batches < 1:
        raise ValueError(f"num_batches must be at least 1, got {num_batches}")
    if not 1 <= sample_size <= d_vocab:
        raise ValueError(f"sample_size must be between 1 and d_vocab={d_vocab}, got {sample_size}")

    layer, head_idx = head
    W_Q = model.W_Q[layer, head_idx]
    W_K = model.W_K[layer, head_idx]

    sorted_keys = sorted(embeddings_dict.keys())
    # Unit-normalise every token vector so pairings are compared on direction only.
    embeddings_dict_normalized = {k: v / v.norm(dim=-1, keepdim=True) for k, v in embeddings_dict.items()}
    q_and_k_labels = [(q_name, k_name) for q_name in sorted_keys for k_name in sorted_keys]

    patterns_for_each_batch = []
    for _ in range(num_batches):
        # Draw distinct tokens: a repeated token would put "same token" entries
        # off the diagonal and blur the pattern this plot is meant to show.
        sample_indices = t.randperm(d_vocab)[:sample_size]

        # Project only the sampled rows, once per embedding, instead of forming
        # the full-vocab projection again for every (Q, K) pairing.
        q_side = {name: embeddings_dict_normalized[name][sample_indices] @ W_Q for name in sorted_keys}
        k_side = {name: embeddings_dict_normalized[name][sample_indices] @ W_K for name in sorted_keys}

        # Alternative normalisations tried previously: subtract the row mean, or
        # divide by the overall std, instead of taking a softmax.
        batch_patterns = [
            (q_side[q_name] @ k_side[k_name].T).softmax(dim=-1)
            for (q_name, k_name) in q_and_k_labels
        ]
        patterns_for_each_batch.append(t.stack(batch_patterns, dim=0))

    mean_patterns = t.stack(patterns_for_each_batch, dim=0).mean(dim=0)

    imshow(
        mean_patterns - (1 / sample_size),
        facet_col=0,
        facet_col_wrap=len(embeddings_dict),
        facet_labels=[f"Q = {q_name}<br>K = {k_name}" for (q_name, k_name) in q_and_k_labels],
        title=f"Sample of diagonal patterns for different matrices: head {head}",
        labels={"x": "Key", "y": "Query"},
        height=1200, width=1200
    )

# Inspect head (3, 0): average 20 random draws of 50 tokens each.
plot_random_sample(
    embeddings_dict,
    head=(3, 0),
    sample_size=50,
    num_batches=20,
)