# Setup

In [1]:
from transformer_lens.cautils.notebook import *

clear_output()

In [2]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    # refactor_factored_attn_matrices=True,
)
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)
clear_output()

In [3]:
def parse_str(s: str):
    doubles = "“”"
    singles = "‘’"
    for char in doubles: s = s.replace(char, '"')
    for char in singles: s = s.replace(char, "'")
    return s

def parse_str_tok_for_printing(s: str):
    s = s.replace("\n", "\\n")
    return s

In [4]:
BATCH_SIZE = 50
SEQ_LEN = 200 # 1024

DATA_STR = get_webtext(seed=6)[:BATCH_SIZE]
DATA_STR = [parse_str(s) for s in DATA_STR]

DATA_TOKS = model.to_tokens(DATA_STR)
DATA_STR_TOKS = model.to_str_tokens(DATA_STR)

if SEQ_LEN < 1024:
    DATA_TOKS = DATA_TOKS[:, :SEQ_LEN]
    DATA_STR_TOKS = [str_toks[:SEQ_LEN] for str_toks in DATA_STR_TOKS]

DATA_STR_TOKS_PARSED = [[parse_str_tok_for_printing(tok) for tok in toks] for toks in DATA_STR_TOKS]

clear_output()

print(DATA_TOKS.shape, "\n")

print(DATA_STR_TOKS[0])

torch.Size([50, 200]) 

['<|endoftext|>', 'Oh', ' boy', ' was', ' this', ' damn', ' hard', ' to', ' crack', '.', '\n', '\n', 'Ok', ',', ' I', ' believe', ' before', ' it', ' was', ' established', ' before', ' that', ' A', 'perture', ' Science', ' headquarters', ' are', ' in', ' Cleveland', ',', ' OH', '.', '\n', '\n', 'Source', ':', ' HL', '2', 'EP', '2', '\n', '\n', 'Though', ',', ' this', ' has', ' been', ' found', '.', '\n', '\n', 'Source', ':', ' Portal', ' 2', '\n', '\n', 'It', ' can', ' be', ' assumed', ' that', ' the', ' En', 'rich', 'ment', ' Center', ' is', ' there', ',', ' as', ' there', ' is', ' an', ' underground', ' mine', ' in', ' the', ' game', '.', '\n', '\n', 'Now', ' here', "'s", ' the', ' catch', ':', ' Not', ' only', ' are', ' there', ' no', ' salt', ' mines', ' in', ' the', ' Upper', ' Peninsula', ',', ' the', ' surface', ' as', ' seen', ' at', ' the', ' end', ' of', ' Portal', ' 2', ' is', ' flat', ' with', ' wheat', ' growing', '.', ' The', ' UP', ' is', ' very',

# Data

Here's where I gather data for the other visualisations.

In [5]:
from dataclasses import dataclass, field
from typing import Any
Head = Tuple[int, int]

class HeadResults:
    data: Dict[Head, Tensor]
    def __init__(self, data=None):
        if data is None: # ! bad practice to have default arguments be dicts
            data = {}
        self.data = data

    def __getitem__(self, layer_and_head) -> Tensor:
        return self.data[layer_and_head].clone()
    
    def __setitem__(self, layer_and_head, value):
        self.data[layer_and_head] = value.clone()

@dataclass(frozen=False)
class LogitResults:
    zero_patched: HeadResults = HeadResults()
    mean_patched: HeadResults = HeadResults()
    zero_direct: HeadResults = HeadResults()
    mean_direct: HeadResults = HeadResults()

@dataclass(frozen=False)
class ModelResults:
    logits_orig: Tensor = t.empty(0)
    loss_orig: Tensor = t.empty(0)
    result: HeadResults = HeadResults()
    result_mean: HeadResults = HeadResults()
    pattern: HeadResults = HeadResults()
    direct_effect: HeadResults = HeadResults()
    direct_effect_mean: HeadResults = HeadResults()
    scale: Tensor = t.empty(0)
    logits: LogitResults = LogitResults()
    loss: LogitResults = LogitResults()

    def clear(self):
        # Empties all intermediate results which we don't need
        self.result = HeadResults()
        self.result_mean = HeadResults()

In [6]:
def get_data_dict(
    model: HookedTransformer,
    toks: Int[Tensor, "batch seq"],
    negative_heads: List[Tuple[int, int]],
    use_cuda: bool = False,
):
    model.reset_hooks(including_permanent=True)
    t.cuda.empty_cache()

    device = str(model.cfg.device)
    if use_cuda: model = model.cuda()
    else: model = model.cpu()

    model_results = ModelResults()

    # Cache the head results and attention patterns, and final ln scale

    def cache_head_result(result: Float[Tensor, "batch seq n_heads d_model"], hook: HookPoint, head: int):
        model_results.result[hook.layer(), head] = result[:, :, head]
    
    def cache_head_pattern(pattern: Float[Tensor, "batch n_heads seq_Q seq_K"], hook: HookPoint, head: int):
        model_results.pattern[hook.layer(), head] = pattern[:, head]
    
    def cache_scale(scale: Float[Tensor, "batch seq 1"], hook: HookPoint):
        model_results.scale = scale

    for layer, head in negative_heads:
        model.add_hook(utils.get_act_name("result", layer), partial(cache_head_result, head=head))
        model.add_hook(utils.get_act_name("pattern", layer), partial(cache_head_pattern, head=head))
    model.add_hook(utils.get_act_name("scale"), cache_scale)

    # Run the forward pass, to cache all values (and get logits)

    model_results.logits_orig, model_results.loss_orig = model(toks, return_type="both", loss_per_token=True)

    # Calculate the thing we'll be subbing in for mean ablation

    for layer, head in negative_heads:
        model_results.result_mean[layer, head] = einops.reduce(
            model_results.result[layer, head], 
            "batch seq d_model -> d_model", "mean"
        )

    # Now, use "result" to get the thing we'll eventually be adding to logits (i.e. scale it and map it through W_U)

    for layer, head in negative_heads:

        # TODO - is it more reasonable to patch in at the final value of residual stream instead of directly changing logits?
        model_results.direct_effect[layer, head] = einops.einsum(
            model_results.result[layer, head] / model_results.scale,
            model.W_U,
            "batch seq d_model, d_model d_vocab -> batch seq d_vocab"
        )
        model_results.direct_effect_mean[layer, head] = einops.reduce(
            model_results.direct_effect[layer, head],
            "batch seq d_vocab -> d_vocab",
            "mean"
        )

    # Two new forward passes: one with mean ablation, one with zero ablation. We only store logits from these

    def patch_head_result(
        result: Float[Tensor, "batch seq n_heads d_model"],
        hook: HookPoint,
        head: int,
        ablation_values: Optional[HeadResults] = None,
    ):
        if ablation_values is None:
            result[:, :, head] = t.zeros_like(result[:, :, head])
        else:
            result[:, :, head] = ablation_values[hook.layer(), head]
        return result

    for layer, head in negative_heads:
        model.add_hook(utils.get_act_name("result", layer), partial(patch_head_result, head=head))
        model_results.logits.zero_patched[layer, head] = model(toks, return_type="logits")
        model.add_hook(utils.get_act_name("result", layer), partial(patch_head_result, head=head, ablation_values=model_results.result_mean))
        model_results.logits.mean_patched[layer, head] = model(toks, return_type="logits")
    
    model_results.clear()

    # Now, the direct effects

    for layer, head in negative_heads:
        # Get the change in logits from removing the direct effect of the head
        model_results.logits.zero_direct[layer, head] = model_results.logits_orig - model_results.direct_effect[layer, head]
        # Get the change in logits from removing the direct effect of the head, and replacing with the mean effect
        model_results.logits.mean_direct[layer, head] = model_results.logits.zero_direct[layer, head] + model_results.direct_effect_mean[layer, head]

    # Calculate the loss for all of these
    for k in ["zero_patched", "mean_patched", "zero_direct", "mean_direct"]:
        setattr(model_results.loss, k, HeadResults({
            (layer, head): model.loss_fn(getattr(model_results.logits, k)[layer, head], toks, per_token=True)
            for layer, head in negative_heads
        }))

    model = model.to(device)
    return model_results

In [7]:
MODEL_RESULTS = get_data_dict(model, DATA_TOKS, negative_heads = [(10, 7), (11, 10)])

Moving model to device:  cpu
Moving model to device:  cuda


# Activations

Here's where I get the activation plots, where each value actually shows the effect on logits of ablating.

We show `(ablated loss) - (original loss)`, so blue (positivity) indicates that this head is helpful, because loss goes up when it gets ablated. 

In [8]:
assert MODEL_RESULTS.loss_orig.shape == MODEL_RESULTS.loss.mean_patched[(10, 7)].shape == (BATCH_SIZE, SEQ_LEN - 1)

loss_diffs = t.stack([
    t.stack(list(MODEL_RESULTS.loss.mean_patched.data.values())),
    t.stack(list(MODEL_RESULTS.loss.zero_patched.data.values())),
    t.stack(list(MODEL_RESULTS.loss.mean_direct.data.values())),
    t.stack(list(MODEL_RESULTS.loss.zero_direct.data.values())),
]) - MODEL_RESULTS.loss_orig
loss_diffs_padded = t.concat([loss_diffs, t.zeros((4, 2, BATCH_SIZE, 1))], dim=-1)
loss_diffs_padded = list(einops.rearrange(
    loss_diffs_padded, "loss_type head batch seq -> batch seq loss_type head"
).unbind(0))

html = cv.activations.text_neuron_activations(
    tokens = DATA_STR_TOKS_PARSED,
    activations = loss_diffs_padded,
    first_dimension_name = "loss_type",
    first_dimension_labels = ["mean, patched", "zero, patched", "mean, direct", "zero, direct"],
    second_dimension_name = "head",
    second_dimension_labels = ["10.7", "11.10"],
)

with open("test.html", "w") as file:
    file.write(str(html))

I also want to be able to print out what the biggest ones are.

Do the top 5 results here hold up to sanity checks, i.e. do they look like copy suppression?

In [9]:
batch_idx = 36

def to_string(toks):
    s = model.to_string(toks)
    s = s.replace("\n", "\\n")
    return s

cv.logits.token_log_probs(
    DATA_TOKS[batch_idx].cpu(),
    MODEL_RESULTS.logits_orig[batch_idx].log_softmax(-1),
    to_string = to_string
)

In [13]:
MODEL_RESULTS.logits_orig[batch_idx].log_softmax(-1).shape

torch.Size([200, 50257])

In [10]:
batch_idx = 36

def to_string(toks):
    s = model.to_string(toks)
    s = s.replace("\n", "\\n")
    return s

cv.logits.token_log_probs(
    DATA_TOKS[batch_idx].cpu(),
    MODEL_RESULTS.direct_effect[10, 7][batch_idx].log_softmax(-1),
    to_string = to_string,
    top_k = 10,
    negative = True,
)

TypeError: token_log_probs() got an unexpected keyword argument 'negative'

In [None]:
def topk_of_Nd_tensor(tensor: Float[Tensor, "rows cols"], k: int):
    '''
    Helper function: does same as tensor.topk(k).indices, but works over 2D tensors.
    Returns a list of indices, i.e. shape [k, tensor.ndim].

    Example: if tensor is 2D array of values for each head in each layer, this will
    return a list of heads.
    '''
    i = t.topk(tensor.flatten(), k).indices
    return np.array(np.unravel_index(utils.to_numpy(i), tensor.shape)).T.tolist()


loss_diffs_mean_direct_107 = loss_diffs[2, 0]

most_useful_positions = topk_of_Nd_tensor(loss_diffs_mean_direct_107, 5)

for batch_idx, seq_pos in most_useful_positions:
    print("\n".join([
        f"Batch = {batch_idx}",
        f"Seq pos = {seq_pos}",
        f"Loss increase from ablation = {loss_diffs_mean_direct_107[batch_idx, seq_pos]}",
        f"Text = {''.join(DATA_STR_TOKS_PARSED[batch_idx][seq_pos-10: seq_pos+1])}",
        ""
    ]))

# Attention Patterns

In [None]:
batch_idx = 36

cv.attention.attention_heads(
    attention = MODEL_RESULTS.pattern[10, 7][[batch_idx]], # (batch=1, seqQ, seqK)
    tokens = DATA_STR_TOKS_PARSED[batch_idx], # list of length seqQ
    attention_head_names = ["10.7"],
)