# Max activating examples for 10.7 (by norm projection to logits)

Want to see where head 10.7 is most useful!

In [1]:
from transformer_lens.cautils.notebook import *

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [17]:
from transformer_lens.rs.callum.max_activating_exploration import print_best_outputs

In [2]:
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
model.set_use_attn_result(True)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
data = get_webtext(seed=5)

Found cached dataset openwebtext-10k (/home/ubuntu/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
LAYER_IDX, HEAD_IDX = (10, 7)
W_U = model.W_U.clone()
HEAD_HOOK_NAME = utils.get_act_name("result", LAYER_IDX)

NUM_PROMPTS = 100
BATCH_SIZE = 10

In [5]:
def hook_to_ablate_head(head_output: Float[Tensor, "batch seq_len head_idx d_head"], hook: HookPoint, head = (LAYER_IDX, HEAD_IDX)):
    assert head[0] == hook.layer()
    assert "result" in hook.name
    head_output[:, :, head[1], :] = 0
    return head_output

## How does 10.7 affect the logits when it's ablated?

We can see cross-entropy loss increases by 0.01 on average when this head is ablated. That might seem like not a lot, but it's actually not far off distribution to other late-stage heads.

In [6]:
str_token_list = []
loss_list = []
ablated_loss_list = []

for i in tqdm(range(NUM_PROMPTS)):
    # new_str = data[BATCH_SIZE * i: BATCH_SIZE * (i + 1)]
    new_str = data[i]
    new_str_tokens = model.to_str_tokens(new_str)
    tokens = model.to_tokens(new_str)
    # tokens = t.stack(tokens).to(device)
    loss = model(tokens, return_type="loss", loss_per_token=True)
    ablated_loss = model.run_with_hooks(tokens, return_type="loss", loss_per_token=True, fwd_hooks=[(HEAD_HOOK_NAME, hook_to_ablate_head)])
    loss_list.append(loss)
    ablated_loss_list.append(ablated_loss)
    str_token_list.append(new_str_tokens)


all_loss = t.cat(loss_list, dim=-1).squeeze()
all_ablated_loss = t.cat(ablated_loss_list, dim=-1).squeeze()

hist(
    all_ablated_loss - all_loss,
    title="Difference in loss after ablating (positive ⇒ loss increases)",
    labels={"x": "Difference in cross-entropy loss"},
    template="simple_white",
    add_mean_line=True,
    width=1000,
    nbins=200
)

  0%|          | 0/100 [00:00<?, ?it/s]

In [32]:
total_num_tokens = sum(len(i) for i in str_token_list)
top_half_pct = int(total_num_tokens * 0.005)


def find_best_improvements(str_token_list, loss_list, ablated_loss_list, k = 15, print_table = False):

    best_loss_decrease = []
    best_text = []
    best_indices = []

    for i, (stl, ll, all) in enumerate(zip(str_token_list, loss_list, ablated_loss_list)):

        loss_diff = (all - ll).squeeze()
        k_actual = min(k, loss_diff.shape[0])
        max_loss_decrease = loss_diff.topk(k_actual, largest=True)
        
        for value, index in zip(max_loss_decrease.values, max_loss_decrease.indices):
            text = stl[max(0, index - 15): index + 2]
            # ! Why `:idx+2` ? Because loss_diff[idx] is large, meaning we failed to predict the `idx+1`-th element, so this should be the last one in our list. We're highlighting the thing we predicted wrong.
            if text:
                text[-1] = f"[bold dark_orange u]{repr(text[-1])}[/]"
                text = "".join(text)
                if "�" not in text:
                    best_loss_decrease.append(value)
                    best_text.append(text + "\n\n")
                    best_indices.append((i, index.item()))

    table = Table("CE-Loss Decrease", "Prompt", title="Prompts & Answers:")

    best_k_indices = []
    best_k_loss_decrease = []
    for loss, text, idx in sorted(zip(best_loss_decrease, best_text, best_indices), key=lambda x: x[0], reverse=True)[:k]:
        table.add_row(f"{loss:.3f}", text)
        best_k_indices.append(idx)
        best_k_loss_decrease.append(loss.item())

    if print_table: rprint(table)

    return best_k_indices, best_k_loss_decrease

best_k_indices, best_k_loss_decrease = find_best_improvements(str_token_list, loss_list, ablated_loss_list, k=top_half_pct)

In [36]:
print_best_outputs(
    best_k_indices,
    best_k_loss_decrease,
    hook = (HEAD_HOOK_NAME, hook_to_ablate_head),
    model = model,
    data = data,
    n = 10,
    random = True
)

In [20]:
k

print_best_outputs(
    best_k_indices,
    best_k_loss_decrease,
    hook = (HEAD_HOOK_NAME, hook_to_ablate_head),
    model = model,
    data = data,
    n = 10,
    random = True
)

NameError: name 'k' is not defined