In [1]:
from transformer_lens.cautils.notebook import *

# GPT-2 small, with the weight-processing options below applied at load time:
# LayerNorm folded into adjacent weights, writing/unembedding weights centered,
# and the factored attention matrices refactored.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

# Clear this cell's output (e.g. the model-loading messages).
clear_output()

In [2]:
# Sequences per batch: used for the OpenWebText DataLoader and as the number
# of IOI prompts generated later in the notebook.
BATCH_SIZE: int = 20

In [3]:
def get_webtext(seed: int = 420) -> DataLoader:
    """Build a shuffled DataLoader over tokenized OpenWebText.

    Loads the ``stas/openwebtext-10k`` dataset, tokenizes it with
    ``model.tokenizer`` and concatenates it into sequences of at most
    ``model.cfg.n_ctx`` tokens (with a BOS token added), then wraps the result
    in a DataLoader that yields batches of ``BATCH_SIZE`` sequences.

    Args:
        seed: Seed for the shuffle order, so the batches drawn are reproducible
            across runs.

    Returns:
        A DataLoader whose batches are dicts with a ``"tokens"`` entry.
    """
    import torch

    raw_dataset = load_dataset("stas/openwebtext-10k", split="train")

    tokenized_dataset = tokenize_and_concatenate(
        raw_dataset,
        model.tokenizer,
        streaming=False,
        max_length=model.cfg.n_ctx,
        column_name="text",
        add_bos_token=True,
        num_proc=4,
    )

    # Seed the shuffle explicitly. Without a generator, shuffle=True draws from
    # torch's global RNG and the `seed` argument would have no effect.
    generator = torch.Generator().manual_seed(seed)

    return DataLoader(
        tokenized_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        generator=generator,
    )


my_dataloader = get_webtext(0)

Found cached dataset openwebtext-10k (/home/ubuntu/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)
Loading cached processed dataset at /home/ubuntu/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b/cache-92b08acc5409395d_*_of_00004.arrow


In [10]:
# Pull a single batch from the OpenWebText loader and keep only the token ids.
first_batch_dict = next(iter(my_dataloader))
first_batch = first_batch_dict['tokens']

print(first_batch.shape)

torch.Size([20, 1024])


In [11]:
# IOI prompts plus (presumably) their ABC-corrupted counterparts, with
# activation caches and a noising metric. The prompt count matches the
# webtext batch size.
(
    ioi_dataset,
    abc_dataset,
    ioi_cache,
    abc_cache,
    ioi_metric_noising,
) = generate_data_and_caches(
    N=BATCH_SIZE,
    model=model,
    seed=420,
    prepend_bos=True,
    verbose=False,
)

In [21]:
# vague idea - cut down on computation by only looking at the (eventual) most likely 20 toks? But probably this doesn't cut down on the time taken by run_with_cache which is still the main one

def _logit_lens_entropy(resid: Tensor, W_U: Tensor) -> Tensor:
    """Entropy of softmax(resid @ W_U), i.e. of the direct-unembedding prediction.

    `resid` is (batch, seq, d_model) and `W_U` is (d_model, d_vocab); the result
    is (batch, seq).
    """
    logits = einops.einsum(resid, W_U, "batch seq d_model, d_model d_vocab -> batch seq d_vocab")
    return Categorical(logits=logits).entropy()


@t.no_grad()  # W_U is a parameter: avoid building an autograd graph over d_vocab-sized logits
def entropy_measure(
    model: HookedTransformer,
    batch: Int[Tensor, "batch_size seq_len"], # same datatype as the batches we get when we iterate through my dataloader
):
    """Logit-lens entropy of the residual stream, and per-component entropy changes.

    Every residual-stream snapshot is divided by `cache["scale"]` (presumably
    the final LayerNorm scale; the same scale is used at every layer). It is
    then projected through `model.W_U`, and the entropy of the softmax over the
    vocabulary is taken.

    Returns:
        resid_entropies: (2 * n_layers + 1, batch, seq). Interleaved
            [resid_pre_0, resid_mid_0, resid_pre_1, resid_mid_1, ...], with a
            last entry holding the entropy of the model's actual output logits.
            Consecutive differences therefore alternate attention (even index)
            and MLP (odd index) contributions.
        entropy_diffs: (n_layers, n_heads + 1, batch, seq).
            - For head h: entropy(resid_pre + head output) - entropy(resid_pre).
              The head output is z @ W_O only; b_O is not included.
            - The last index is the MLP: entropy(resid_mid + mlp_out) -
              entropy(resid_mid).
            Positive values mean the component increases entropy.

    Both outputs are allocated with `t.zeros`' default device.
    """
    batch_size, seq_len = batch.shape
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads

    logits, cache = model.run_with_cache(
        batch,
        names_filter=lambda name: any(name.endswith(x) for x in ["resid_pre", "resid_mid", "z", "scale", "mlp_out"]),
    )

    W_U = model.W_U
    scale = cache["scale"]

    resid_entropies = t.zeros(2 * n_layers + 1, batch_size, seq_len)
    entropy_diffs = t.zeros(n_layers, n_heads + 1, batch_size, seq_len)

    # The context manager closes the bar even if the loop raises (e.g. CUDA OOM).
    with tqdm(total=n_layers * (n_heads + 1)) as progress_bar:
        for layer in range(n_layers):
            resid_pre = cache["resid_pre", layer] / scale
            # Must read resid_mid (after attention). Reading resid_pre here makes
            # every attention-layer diff zero and gives the MLP the wrong baseline.
            resid_mid = cache["resid_mid", layer] / scale

            resid_pre_entropy = _logit_lens_entropy(resid_pre, W_U)
            resid_mid_entropy = _logit_lens_entropy(resid_mid, W_U)

            resid_entropies[2 * layer] = resid_pre_entropy
            resid_entropies[2 * layer + 1] = resid_mid_entropy

            for head in range(n_heads):
                head_contribution = einops.einsum(
                    cache["z", layer][:, :, head],
                    model.W_O[layer, head],
                    "batch seq d_head, d_head d_model -> batch seq d_model",
                ) / scale
                new_entropy = _logit_lens_entropy(resid_pre + head_contribution, W_U)
                entropy_diffs[layer, head] = new_entropy - resid_pre_entropy

                progress_bar.update(1)
                t.cuda.empty_cache()

            mlp_contribution = cache["mlp_out", layer] / scale
            new_entropy = _logit_lens_entropy(resid_mid + mlp_contribution, W_U)
            entropy_diffs[layer, -1] = new_entropy - resid_mid_entropy

            progress_bar.update(1)
            t.cuda.empty_cache()

    # Final entry uses the model's own logits rather than the manual projection.
    resid_entropies[-1] = Categorical(logits=logits).entropy()

    return resid_entropies, entropy_diffs

In [22]:
# Logit-lens entropies and per-head / per-MLP entropy changes on the webtext batch.
layer_entropies, entropy_diffs = entropy_measure(model=model, batch=first_batch)

  0%|          | 0/156 [00:00<?, ?it/s]

<tqdm.auto.tqdm at 0x7f8c6810b460>

In [46]:
# Average over batch and sequence position.
entropy_diffs_mean = einops.reduce(entropy_diffs, "layers heads batch seq -> layers heads", reduction="mean")
layer_entropies_mean = einops.reduce(layer_entropies, "layers batch seq -> layers", reduction="mean")

# layer_entropies interleaves [resid_pre_0, resid_mid_0, resid_pre_1, ..., final], so
# consecutive differences alternate attention (even index) / MLP (odd index) contributions.
layer_entropies_diffs_mean = layer_entropies_mean[1:] - layer_entropies_mean[:-1]
layer_entropies_attn_diffs_mean = layer_entropies_diffs_mean[::2].tolist()
layer_entropies_mlp_diffs_mean = layer_entropies_diffs_mean[1::2].tolist()

# Diffs are (new - old): positive = the component increases entropy.
line(
    [layer_entropies_attn_diffs_mean, layer_entropies_mlp_diffs_mean],
    width=600,
    title="Change in mean entropy from each layer (trace 0 = attention, trace 1 = MLP)",
    labels={"x": "Layer", "y": "Entropy diff"},
)
imshow(entropy_diffs_mean, width=600, title="Change in entropy from each head (last column = MLP)", border=True, labels={"x": "Head", "y": "Layer"})

head_diffs_mean = entropy_diffs_mean[:, :-1]
zmax = head_diffs_mean.abs().max().item()
imshow(head_diffs_mean, width=600, title="Change in entropy from each head (MLP column removed)", border=True, zmin=-zmax, zmax=zmax, labels={"x": "Head", "y": "Layer"})

# Keep only positive diffs (heads that increase entropy); decreases are clipped to 0.
entropy_increases = head_diffs_mean.clamp(min=0)
zmax = entropy_increases.max().item()
imshow(entropy_increases, width=600, title="Only showing entropy increases", border=True, zmin=-zmax, zmax=zmax, labels={"x": "Head", "y": "Layer"})

In [35]:
# Same measurement, on the IOI prompts instead of webtext.
layer_entropies_ioi, entropy_diffs_ioi = entropy_measure(model=model, batch=ioi_dataset.toks)

  0%|          | 0/156 [00:00<?, ?it/s]

<tqdm.auto.tqdm at 0x7f8b9c6d6760>

In [36]:
# Average over batch and sequence position (IOI prompts).
entropy_diffs_mean = einops.reduce(entropy_diffs_ioi, "layers heads batch seq -> layers heads", reduction="mean")
layer_entropies_mean = einops.reduce(layer_entropies_ioi, "layers batch seq -> layers", reduction="mean")

# Interleaved [resid_pre_0, resid_mid_0, ..., final]: even diffs = attention, odd = MLP.
layer_entropies_diffs_mean = layer_entropies_mean[1:] - layer_entropies_mean[:-1]
layer_entropies_attn_diffs_mean = layer_entropies_diffs_mean[::2].tolist()
layer_entropies_mlp_diffs_mean = layer_entropies_diffs_mean[1::2].tolist()

# Diffs are (new - old): positive = the component increases entropy.
line(
    [layer_entropies_attn_diffs_mean, layer_entropies_mlp_diffs_mean],
    width=600,
    title="IOI: change in mean entropy from each layer (trace 0 = attention, trace 1 = MLP)",
    labels={"x": "Layer", "y": "Entropy diff"},
)
imshow(entropy_diffs_mean, width=600, title="IOI: change in entropy from each head (last column = MLP)", border=True, labels={"x": "Head", "y": "Layer"})

# Analysis so far

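Entropy drops sharply between layer 0 and layer 1, presumably because MLP0 acts as an extended embedding. At the very start (the raw embedding) the prediction is only slightly better than uniform: the tied embeddings still do something roughly bigram-like. That is why the initial entropy is just below 10 nats; a uniform distribution over GPT-2's 50,257 tokens would have entropy ln 50257 ≈ 10.8. But it drops a lot after the first attention layer.

**Caveat:** in the version of `entropy_measure` that produced these results, `resid_mid` was read from `cache["resid_pre", layer]`. The attention-vs-MLP split in the line plot, and the MLP column of the heatmap, should therefore be re-checked before drawing conclusions.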

Why does it drop by so much, and why do the attention heads play such a large role in making this drop (rather than the MLPs playing this role)?