# Collection of functions for decoders

This is a collection of functions I have written as part of my work on mechanistic interpretability of decoder models (primarily gpt2-small) with TransformerLens. Some are implementations of methods; others visualize results more flexibly. 

In [3]:
# Some necessary imports
import circuitsvis as cv
import torch
import plotly.express as px
from functools import partial
import tqdm.auto as tqdm
import transformer_lens.utils as utils
import numpy as np

Visualizations:

In [6]:
def vis_attn_patterns(model, text, layers, compact=True):
    '''
    Display the attention patterns of the given layers for a piece of text.

    Parameters:
    model: The model; must provide to_str_tokens and run_with_cache
    text: The input whose attention patterns are shown
    layers: Iterable of layer indices to display, one widget per layer
    compact (bool): If True, render with circuitsvis' attention_patterns view,
        otherwise with its attention_heads view
    '''
    str_tokens = model.to_str_tokens(text)
    # Only the cache is needed; the logits of the forward pass are discarded.
    _, cache = model.run_with_cache(text, remove_batch_dim=True)

    # Both views take the same keyword arguments, so pick the renderer once.
    render = cv.attention.attention_patterns if compact else cv.attention.attention_heads

    for layer in layers:
        display(render(tokens=str_tokens, attention=cache["pattern", layer]))

def imshow_patching_result(model, patching_results, corrupted_prompt, corrupted_answer):
    '''
    Visualizes the logit differences caused by activation patching in a heat map. If the answer has more than one token, "patching_results" must be a list of results.

    Parameters:
    model: The model; only its to_str_tokens method is used, to label the x-axis
    patching_results: Either a single result tensor (rows are shown as "Layer",
        columns as "Position") or a list of such tensors, one per answer token
    corrupted_prompt (str): The corrupted prompt the results were computed on
    corrupted_answer (str): The corrupted answer; only used when
        "patching_results" is a list, where it is appended to the prompt to label
        the positions of the later results

    In the list case, results that are entirely zero are skipped.
    '''
    def _show(result, labels):
        # Move the tensor to the CPU before handing it to plotly: results are
        # created on the model's device, and a tensor living on an accelerator
        # cannot be converted to a NumPy array directly.
        px.imshow(result.detach().cpu(), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", x=labels, labels={"x": "Position", "y": "Layer"}, title="Patching Results").show()

    if isinstance(patching_results, list):
        if not patching_results:
            return
        num_results = len(patching_results)
        # Tokenization and labels do not depend on the loop index, so build them once.
        tokens = model.to_str_tokens(corrupted_prompt+corrupted_answer)
        all_labels = [f'{token}_{index}' for index, token in enumerate(tokens)]
        for i, result in enumerate(patching_results):
            if torch.all(result == 0):
                continue
            # Result i drops the last (num_results - i) labels, i.e. it keeps the
            # prompt plus the first i answer tokens (assuming one result per
            # answer token).
            _show(result, all_labels[:i - num_results])
    else:
        tokens = model.to_str_tokens(corrupted_prompt)
        labels = [f'{token}_{index}' for index, token in enumerate(tokens)]
        _show(patching_results, labels)
    

## Activation patching

We now want to demonstrate the activation patching technique on a small problem. We will first run a forward pass on the clean prompt and cache the activations. Then, we will run a forward pass on the corrupted prompt and, at each layer and position, replace the corrupted activations with the clean activations from the cache. To do this, we use the run_with_hooks method from TransformerLens together with a suitable hook function, activation_patching_hook. We define an activation_patching function that returns the results, as well as an activation_patching_mult function to use if the answer has more than a single token. In the latter case, we need to run the patching procedure multiple separate times, one for each token in the answer.

In [None]:
def activation_patching_hook(resid_pre, hook, position, clean_cache):
    '''
    Hook function that overwrites one position of an activation with its clean counterpart.

    Parameters:
    resid_pre: The activation being patched; it is modified in place
    hook: The hook point; its name is used as the key into the clean cache
    position: The index along the second dimension to overwrite
    clean_cache: Mapping from hook names to the activations of the clean run

    Returns:
    The same resid_pre object, with the chosen position replaced
    '''
    clean_source = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_source[:, position, :]
    return resid_pre

def activation_patching(model, clean_prompt, corrupted_prompt, clean_answer, corrupted_answer):
    '''
    Performs activation patching of the clean prompt onto the corrupted prompt. The prompts must have the same number of tokens.

    For every layer and every token position, the corrupted prompt is run once
    with the "resid_pre" activation of that layer replaced, at that position, by
    the cached activation of the clean run. The effect is measured as the logit
    difference (clean answer minus corrupted answer) at the last position.

    Parameters:
    model: The transformer lens model
    clean_prompt (str): The clean prompt we will patch from
    corrupted_prompt (str): The corrupted prompt we will patch onto
    clean_answer (str): The answer (or next prediction) of the clean prompt; must be a single token
    corrupted_answer (str): The answer (or next prediction) of the corrupted prompt; must be a single token

    Returns:
    patching_result (tensor[layers, positions]): The absolute value of
        (patched_diff - corrupted_diff) / (clean_diff - corrupted_diff) for each
        patched layer and position. All zeros if the clean and corrupted logit
        differences coincide (the normalization would divide by zero).
    patched_logits: The logits of the last patched run only (last layer, last
        position), or None if no patched run was performed.

    Raises:
    AssertionError: If the two prompts do not have the same number of tokens.
    '''
    # Tokenize once and validate before spending time on any forward pass.
    clean_tokens = model.to_tokens(clean_prompt)
    corrupted_tokens = model.to_tokens(corrupted_prompt)
    num_positions = len(clean_tokens[0])

    assert num_positions == len(corrupted_tokens[0]), "The prompts must have the same number of tokens."

    print("Clean answer:", clean_answer)
    print("Corrupted answer:", corrupted_answer)
    clean_index = model.to_single_token(clean_answer)
    corrupted_index = model.to_single_token(corrupted_answer)

    patched_logits = None
    # No gradients are needed for patching; disabling them avoids building an
    # autograd graph for each of the n_layers * num_positions forward passes.
    with torch.no_grad():
        clean_logits, clean_cache = model.run_with_cache(clean_prompt)
        corrupted_logits = model(corrupted_prompt)

        clean_diff = clean_logits[0, -1, clean_index] - clean_logits[0, -1, corrupted_index]
        corrupted_diff = corrupted_logits[0, -1, clean_index] - corrupted_logits[0, -1, corrupted_index]

        # The normalization is the same for every layer and position.
        denominator = clean_diff - corrupted_diff
        degenerate = bool(abs(denominator) < 1e-16)

        patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)
        for layer in tqdm.tqdm(range(model.cfg.n_layers)):
            hook_name = utils.get_act_name("resid_pre", layer)
            for position in range(num_positions):
                # We use a temporary hook with functools.partial to patch at each position
                temp_hook = partial(activation_patching_hook, position=position, clean_cache=clean_cache)
                # We then run the model with hooks as usual
                patched_logits = model.run_with_hooks(corrupted_tokens,
                                                      fwd_hooks=[(hook_name, temp_hook)])

                # Entries stay at their initial zero when normalization is impossible.
                if degenerate:
                    continue

                # We then calculate the logit difference and store it, normalized
                patched_diff = patched_logits[0, -1, clean_index] - patched_logits[0, -1, corrupted_index]
                patching_result[layer, position] = abs((patched_diff - corrupted_diff) / denominator)

    return patching_result, patched_logits

def activation_patching_mult(model, clean_prompt, corrupted_prompt, clean_answer, corrupted_answer):
    '''
    Performs activation patching for answers that span several tokens.

    The answers are split into string tokens (the first token returned by
    to_str_tokens is dropped). activation_patching is then called once per answer
    token; after each call the token just used is appended to the respective
    prompt, so the next call patches the prediction of the following token.
    The answers must have the same number of tokens.

    Returns:
    patching_result (list): One activation_patching result per answer token
    patched_logits (list): The logits returned by each activation_patching call
    '''
    clean_pieces = model.to_str_tokens(clean_answer)[1:]
    corrupted_pieces = model.to_str_tokens(corrupted_answer)[1:]
    print("Number of run throughs:", len(clean_pieces))

    all_results, all_logits = [], []
    for step in range(len(clean_pieces)):
        clean_piece = clean_pieces[step]
        corrupted_piece = corrupted_pieces[step]

        step_result, step_logits = activation_patching(
            model, clean_prompt, corrupted_prompt, clean_piece, corrupted_piece
        )
        all_results.append(step_result)
        all_logits.append(step_logits)

        # Extend both prompts with the token that was just predicted.
        clean_prompt += clean_piece
        corrupted_prompt += corrupted_piece

    return all_results, all_logits

As an example, we choose the prompts

Miscellaneous:

In [7]:
def get_prediction(model, logits, num_top=5):
    '''
    Prints the most likely next tokens and their probabilities.

    Parameters:
    model: The model; only its to_string method is used, to decode token indices
    logits: Either a single logits tensor, indexed as [0, -1, :] to get the
        logits of the last position, or a list of such tensors
    num_top (int): How many of the most probable tokens to print

    For a list, one line is printed per entry with the tokens right-aligned to a
    width of 15. For a single tensor, the tokens are printed on one line without
    padding and without a trailing newline. Nothing is returned.
    '''
    def _top_predictions(last_position_logits):
        # Pair each of the num_top most probable tokens with its probability.
        probabilities = last_position_logits.softmax(dim=-1)
        best_probs, best_ids = probabilities.topk(num_top)
        decoded = [model.to_string(token_id.item()) for token_id in best_ids]
        return list(zip(decoded, best_probs))

    if not isinstance(logits, list):
        for word, probability in _top_predictions(logits[0, -1, :]):
            print(f'{word}: {probability.item():.4f}, ', end='')
        return

    for entry in logits:
        for word, probability in _top_predictions(entry[0, -1, :]):
            print(f'{word:>15}: {probability.item():.4f}, ', end='')
        print()