# Computational essay - Decoders and TransformerLens 

## Table of contents

1. [Introduction](#introduction)
    - [Imports and setup](#imports-and-setup)
    - [Helper functions](#helper-function)
2. [Applications](#applications)
    - [Preliminary exploration](#preliminary-exploration)
    - [Attention patterns](#attention-patterns)
    - [Induction heads](#induction-heads)
    - [Activation patching](#activation-patching)
    - [Logit Lens](#logit-lens)
    - [Causal scrubbing](#causal-scrubbing)

## Introduction

This computational essay on decoders and the TransformerLens library corresponds to the similarly titled sections of the report. The run-through uses simple examples and a small model to demonstrate the methods and the kind of results one may expect. Because the model is small, its predictions are less accurate than those of larger models, and the effects of our interventions are correspondingly smaller. Moreover, the assumption that these methods generalize to large-scale models rests on the universality claim. As this is a postulate, we cannot be sure that the assumption holds. The notebook does, however, show simple applications of some core methods of decoder interpretability research, which are useful at smaller scales.

The structure of this section of the essay is as follows: 
- First, we do all necessary setup, including imports, loading our model (in our case GPT2-small) and defining some helper functions.
- Then, we go through the methods outlined in the corresponding paper. We explain our methods and choices, illustrate these methods with simple examples and suggest some relevant uses for these in later work. 

### Imports and setup

In [15]:
# Some necessary imports
import circuitsvis as cv
import torch
import plotly.express as px
import matplotlib.pyplot as plt
from functools import partial
import tqdm.auto as tqdm
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
import numpy as np
import einops

In [16]:
# Loading the model used in this notebook
model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


We also disable automatic gradient computation, as gradients are not needed for our work and computing them is costly.

In [17]:
torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

Disabled automatic differentiation


### Helper function

We define some helper functions for the visualizations. In the first function we use the CircuitsVis library with the "attention_patterns" and "attention_heads" methods to define a function which can visualize the attention patterns of chosen layers in the model.

The second function visualizes the results of activation patching as a heat map of the normalized logit difference relative to the corrupted prompt (see the section on activation patching).

In [18]:
def vis_attn_patterns(model, text, layers, compact=True):
    ''' 
    Visualize attention patterns for a chosen number of layers.
    '''
    str_tokens = model.to_str_tokens(text)
    logits, cache = model.run_with_cache(text, remove_batch_dim=True)

    if compact:
        for layer in layers:
            print("Attention pattern for layer", layer)
            attention_pattern = cache["pattern", layer]
            display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))
    
    else:
        for layer in layers:
            print("Attention pattern for layer", layer)
            attention_pattern = cache["pattern", layer]
            display(cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern))

def imshow_patching_result(model, patching_results, corrupted_prompt, corrupted_answer):
    '''
    Visualizes the logit differences caused by activation patching in a heat map.
    If the answer has more than one token, "patching_results" must be a list of results.
    For a single-token answer it may be the result tensor itself or the
    (patching_result, patched_logits) tuple returned by activation_patching.
    '''
    imshow_kwargs = dict(color_continuous_midpoint=0.0, color_continuous_scale="RdBu",
                         labels={"x": "Position", "y": "Layer"}, title="Patching Results")
    if isinstance(patching_results, list):
        len_ans = len(patching_results)
        tokens = model.to_str_tokens(corrupted_prompt+corrupted_answer)
        for i in range(len_ans):
            # Keep only the labels of the tokens that are part of the prompt in run i
            labels = [f'{token}_{index}' for index, token in enumerate(tokens)][:-len_ans+i]
            if not torch.all(patching_results[i] == 0):
                px.imshow(patching_results[i].detach(), x=labels, **imshow_kwargs).show()
    else:
        # Accept the tuple returned by activation_patching as well as a bare tensor
        result = patching_results[0] if isinstance(patching_results, tuple) else patching_results
        tokens = model.to_str_tokens(corrupted_prompt)
        labels = [f'{token}_{index}' for index, token in enumerate(tokens)]
        px.imshow(result.detach(), x=labels, **imshow_kwargs).show()
    

## Applications

### Preliminary exploration

We first want to check whether the chosen model is suitable for the behavior we are interested in, that is, whether it can perform the task at hand. The TransformerLens library has some convenient methods for exactly this purpose. We choose a simple prompt that we will also use in the following sections: "The capital city of France is called", with the corresponding answer " Paris". Note the space in front of "Paris": the space is part of the token, so that the answer is tokenized as the single token " Paris". Inconsistent handling of spaces and tokenization is a common source of errors in implementations of interpretability research.

In [19]:
prompt = "The capital city of France is called"
answer = " Paris"

# We print the string tokens to examine the tokenization
print(model.to_str_tokens(prompt))
print(model.to_str_tokens(answer))

['<|endoftext|>', 'The', ' capital', ' city', ' of', ' France', ' is', ' called']
['<|endoftext|>', ' Paris']
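
The effect of the leading space can be illustrated without the model. The following is a minimal sketch with a hypothetical toy vocabulary and a greedy longest-match tokenizer; GPT-2 itself uses byte-pair encoding, so this only stands in for the idea that " Paris" and "Paris" are different vocabulary entries with different ids.

```python
# Toy vocabulary (hypothetical): " Paris" and "Paris" are distinct entries.
TOY_VOCAB = {" Paris": 0, "Paris": 1, " ": 2, "P": 3, "a": 4, "r": 5, "i": 6, "s": 7}


def toy_tokenize(text, vocab=TOY_VOCAB):
    """Greedy longest-match tokenization over a fixed toy vocabulary."""
    tokens = []
    i = 0
    while i < len(text):
        # Try the longest remaining candidate first
        for j in range(len(text), i, -1):
            piece = text[i:j]
            if piece in vocab:
                tokens.append(piece)
                i = j
                break
        else:
            raise ValueError(f"Cannot tokenize {text[i]!r}")
    return tokens


with_space = toy_tokenize(" Paris")
without_space = toy_tokenize("Paris")
print(with_space, TOY_VOCAB[with_space[0]])
print(without_space, TOY_VOCAB[without_space[0]])
```

Both strings are single tokens here, but with different ids, so comparing the logit of one against a prediction of the other would silently measure the wrong thing.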


We now look at the model's top predicted tokens. Note that the following function does not return anything; it prints its results and is primarily intended for exploratory analysis.

In [20]:
utils.test_prompt(prompt, answer, model)

Tokenized prompt: ['<|endoftext|>', 'The', ' capital', ' city', ' of', ' France', ' is', ' called']
Tokenized answer: [' Paris']


Top 0th token. Logit: 14.32 Prob:  8.26% Token: | Paris|
Top 1th token. Logit: 13.79 Prob:  4.86% Token: | Marse|
Top 2th token. Logit: 13.78 Prob:  4.81% Token: | the|
Top 3th token. Logit: 13.71 Prob:  4.49% Token: | "|
Top 4th token. Logit: 12.99 Prob:  2.17% Token: | La|
Top 5th token. Logit: 12.68 Prob:  1.61% Token: | '|
Top 6th token. Logit: 12.66 Prob:  1.57% Token: | Mont|
Top 7th token. Logit: 12.60 Prob:  1.48% Token: | St|
Top 8th token. Logit: 12.59 Prob:  1.46% Token: | Saint|
Top 9th token. Logit: 12.30 Prob:  1.09% Token: | V|


We see that even though our chosen model is small, it still predicts the correct next token. However, the probabilities are low, indicating that the model is not confident in its prediction. A larger model would likely improve on this, but for efficiency's sake, we continue with the smaller model.
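
To relate the logits and probabilities in the output above: under a softmax, the ratio of two probabilities equals the exponential of the difference of their logits. The short check below uses the two top entries from the printed output (rounded to two decimals there, so the agreement is only approximate).

```python
import math

# Logits and probabilities reported above for the two top tokens
logit_paris, prob_paris = 14.32, 0.0826
logit_marse, prob_marse = 13.79, 0.0486

# Under a softmax, p_a / p_b = exp(logit_a - logit_b)
ratio_from_logits = math.exp(logit_paris - logit_marse)
ratio_from_probs = prob_paris / prob_marse

print(round(ratio_from_logits, 2), round(ratio_from_probs, 2))
```

Both ratios come out at about 1.70, so a logit gap of roughly 0.5 means the model considers " Paris" less than twice as likely as its runner-up.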

### Attention patterns

We now want to visualize the attention patterns of our model. We use the same prompt as above, but this time use the visualization tools in the CircuitsVis library. The number of layers and heads can be read from the model's configuration (model.cfg.n_layers and model.cfg.n_heads).

We first perform a single forward pass through the model, caching the intermediate results. Then we retrieve the attention patterns of a layer from the cache using the key "pattern". Lastly, we visualize the results.

In [21]:
# We decide to show the first layer (index 0)

model_tokens = model.to_tokens(prompt)
model_logits, model_cache = model.run_with_cache(model_tokens, remove_batch_dim=True)

attention_patterns = model_cache["pattern", 0, "attn"]
str_tokens = model.to_str_tokens(prompt)

print("Layer 0 - Head Attention Patterns")
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_patterns)
#plt.savefig("attention_L0.png")
#plt.show()

Layer 0 - Head Attention Patterns


From these patterns, we can see a few things. First, we see that for many of the heads (e.g. 0, 2 and 9) the first column is the most strongly colored. This indicates that the tokens all attend to the first token. This token is a beginning-of-sequence (BOS) token and thus does not inherently contain any information about the prompt. It can therefore act as an available place to store information.

Secondly, we see that some heads (e.g. 1 and, for the most part, 5) do very little: the tokens mainly attend to themselves, which shows up as a clear diagonal in the figure.

Lastly, some of the heads (e.g. 8 and 10) have more interesting patterns. While the first column still dominates to some extent, they also put significant weight on other token pairs. This could provide an interesting starting point for further investigation.
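
The visual impressions above (first-column dominance versus a strong diagonal) can also be put into numbers. The following is a minimal sketch in plain Python; the function name `summarize_pattern` and the two 3-token patterns are hypothetical and only illustrate the two extremes.

```python
def summarize_pattern(pattern):
    """Return (mean attention to position 0, mean attention to self) for one head.

    `pattern` is a square list of lists where pattern[q][k] is the attention
    that query position q pays to key position k. Row 0 is skipped, since the
    first token can only attend to itself.
    """
    rows = range(1, len(pattern))
    first_col = sum(pattern[q][0] for q in rows) / len(rows)
    diagonal = sum(pattern[q][q] for q in rows) / len(rows)
    return first_col, diagonal


# Hypothetical 3-token patterns for illustration (each row sums to 1)
bos_head = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.1, 0.1]]
self_head = [[1.0, 0.0, 0.0], [0.1, 0.9, 0.0], [0.0, 0.1, 0.9]]

bos_summary = summarize_pattern(bos_head)
self_summary = summarize_pattern(self_head)
print(bos_summary, self_summary)
```

Applied to each head of a cached layer, such a summary would give a quick way to sort heads into "attends to BOS", "attends to self" and "other" before inspecting the interesting ones by eye.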

### Induction heads

We now turn to induction heads, one of the simplest circuits found in smaller models. These are attention heads responsible for detecting repetition in sequences: when the current token has occurred earlier in the sequence, they attend to the token that followed that earlier occurrence. By visualizing the attention patterns, we can clearly see these induction heads.

The following approach is based on the one found in Neel Nanda's colab notebook "Transformer Lens Main Demo Notebook" found [here](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Main_Demo.ipynb). We first choose a repeating sequence. To compare with the previous example, we take the prompt together with its answer, "The capital city of France is called Paris.", and repeat it twice.

In [22]:
repeated_prompt = (prompt+answer+".")*2
repeated_tokens = model.to_tokens(repeated_prompt)

sequence_len = len(model.to_str_tokens(repeated_prompt))//2

Then, we define an induction score that measures how much each attention head acts as an induction head. For each head, we average the attention weights along the diagonal that lies sequence_len - 1 positions below the main diagonal, that is, the attention each token pays to the token following its counterpart in the previous repetition. To access the intermediate values of the model, we use the hook functionality provided by TransformerLens. By defining a hook function and later passing it to the "run_with_hooks" method, we can calculate the average attention scores during the forward pass and save them to the induction_score_store tensor.

In [23]:
induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

def induction_score_hook(activation_pattern, hook):
    # We take the diagonal lying sequence_len-1 below the main one: entries (i, i-(sequence_len-1))
    induction_stripe = activation_pattern.diagonal(dim1=-2, dim2=-1, offset=1-sequence_len)
    # We get the average score for each head
    induction_score = einops.reduce(induction_stripe, "batch head_index position -> head_index", "mean")
    # We store the results
    induction_score_store[hook.layer(), :] = induction_score
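
To make the indexing of the offset diagonal concrete, here is a minimal sketch in plain Python on a synthetic attention pattern. The names `induction_score`, `induction_head` and `self_head` are hypothetical, and the toy sequence is a BOS token followed by a three-token sequence repeated twice.

```python
def induction_score(pattern, rep_len):
    """Mean attention on the diagonal lying (rep_len - 1) below the main one.

    pattern[q][k] is the attention from query position q to key position k.
    The entries used are (q, q - (rep_len - 1)), which is what a diagonal
    with offset 1 - rep_len selects.
    """
    shift = rep_len - 1
    stripe = [pattern[q][q - shift] for q in range(shift, len(pattern))]
    return sum(stripe) / len(stripe)


# Hypothetical sequence: BOS, A, B, C, A, B, C  (rep_len = 3, 7 positions)
n, rep_len = 7, 3

# An idealized induction head: in the second repetition, each token attends to
# the token after its previous occurrence; elsewhere it attends to itself.
induction_head = [[0.0] * n for _ in range(n)]
for q in range(n):
    if q > rep_len:
        induction_head[q][q - (rep_len - 1)] = 1.0
    else:
        induction_head[q][q] = 1.0

# A head that only ever attends to the current token
self_head = [[1.0 if q == k else 0.0 for k in range(n)] for q in range(n)]

induction_head_score = induction_score(induction_head, rep_len)
self_head_score = induction_score(self_head, rep_len)
print(induction_head_score, self_head_score)
```

In this sketch the idealized induction head scores 0.6 rather than 1, because the stripe also covers two positions in the first repetition where there is nothing to copy from. The same holds for the score defined above, so the values should be read relative to each other rather than as fractions of a perfect score.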

We now write to the "induction_score_store" tensor by running a forward pass of the model on the given prompt with the hook function attached. The model has preset hook points at all important places in the network (for instance the attention patterns of each layer or the residual stream). Each hook point has a name, and we select all hook points whose name ends in "pattern". When we call the "run_with_hooks" method, the model calls our hook function with the current activation at each selected hook point.

In [24]:
# We make a boolean filter on activation names
pattern_hook_names_filter = lambda name: name.endswith("pattern")

# Run with hooks (this is where we write to the `induction_score_store` tensor)
model.run_with_hooks(repeated_tokens,
    return_type=None, # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        induction_score_hook
    )]
)
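
The mechanism of named hook points and a name filter can be sketched with a toy class. This is not how TransformerLens is implemented; `ToyHookedModel` is a hypothetical stand-in whose "layers" are plain arithmetic, and its hooks receive the hook point's name directly instead of a hook object.

```python
class ToyHookedModel:
    """A toy stand-in for a model with named hook points (not TransformerLens)."""

    def __init__(self, n_layers):
        self.n_layers = n_layers

    def _hook_point(self, value, name, fwd_hooks):
        """Call every hook whose filter matches `name`; a hook may replace the value."""
        for name_filter, hook_fn in fwd_hooks:
            if name_filter(name):
                new_value = hook_fn(value, name)
                if new_value is not None:
                    value = new_value
        return value

    def run_with_hooks(self, x, fwd_hooks):
        for layer in range(self.n_layers):
            # Two dummy computations per layer, each followed by a named hook point
            x = self._hook_point(x * 2, f"blocks.{layer}.attn.hook_pattern", fwd_hooks)
            x = self._hook_point(x + 1, f"blocks.{layer}.hook_resid_post", fwd_hooks)
        return x


seen = []


def record_hook(value, name):
    # Read-only hook: it returns None, so the value passes through unchanged
    seen.append((name, value))


toy = ToyHookedModel(n_layers=2)
output = toy.run_with_hooks(
    1, fwd_hooks=[(lambda name: name.endswith("pattern"), record_hook)]
)
print(seen, output)
```

The hook that records induction scores is of the read-only kind shown here, while the patching hook used later returns a modified activation, which then replaces the original one for the rest of the forward pass.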

Having run a forward pass on the model with the hook function, we can now plot the induction scores for each head in each layer.

In [25]:
fig = px.imshow(induction_score_store, color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":"Head", "y":"Layer"}, title="Induction Score by Head", text_auto=".2f")
fig.show()

We see that some of the heads have noticeably higher induction scores than others. To see how this shows up in the attention patterns, we visualize the attention patterns for layers 5 and 7 (0-indexed).

In [26]:
vis_attn_patterns(model, repeated_prompt, [5, 7])

Attention pattern for layer 5


Attention pattern for layer 7


We can see that the attention heads that correspond to the stronger induction scores have distinct attention patterns. They all have a clearly visible diagonal, offset from the main diagonal by roughly the length of one repetition. Hence, these attention heads are active when a sequence repeats itself. We would expect to see the same if we reran the code with another repeated sequence.

### Activation patching

We now want to demonstrate the activation patching technique on a small problem. We first run a forward pass on the clean prompt and cache the activations. Then, we run forward passes on the corrupted prompt and, for each layer and position in turn, replace the corrupted activation of the residual stream with the clean activation from the cache. To do this, we use the run_with_hooks method from TransformerLens and define a suitable hook function, activation_patching_hook. The activation_patching function returns the results, and the activation_patching_mult function handles answers with more than a single token. In the latter case, we need to run the whole procedure once for each token in the answer.

In [27]:
def activation_patching_hook(resid_pre, hook, position, clean_cache):
    clean_activation = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_activation[:, position, :]
    return resid_pre

def activation_patching(model, clean_prompt, corrupted_prompt, clean_answer, corrupted_answer):
    '''
    Performs activation patching of the clean prompt onto the corrupted prompt. The prompts must have the same number of tokens.

    Parameters:
    model: The transformer lens model
    clean_prompt (str): The clean prompt we will patch from
    corrupted_prompt (str): The corrupted prompt we will patch onto
    clean_answer (str): The answer (or next prediction) of the clean prompt
    corrupted_answer (str): The answer (or next prediction) of the corrupted prompt

    Returns: 
    patching_result (tensor[layers, positions]): The normalized logit difference after patching at each layer and position
    patched_logits (tensor[batch, positions, vocab]): The logits from the last patched run
    '''
    
    clean_logits, clean_cache = model.run_with_cache(clean_prompt)
    corrupted_logits = model(corrupted_prompt)
    print("Clean answer:",clean_answer)
    print("Corrupted answer:", corrupted_answer)
    clean_index = model.to_single_token(clean_answer)
    corrupted_index = model.to_single_token(corrupted_answer)

    clean_diff = clean_logits[0, -1, clean_index] - clean_logits[0, -1, corrupted_index]
    corrupted_diff = corrupted_logits[0, -1, clean_index] - corrupted_logits[0, -1, corrupted_index]

    clean_tokens = model.to_tokens(clean_prompt)
    corrupted_tokens= model.to_tokens(corrupted_prompt)
    num_positions = len(clean_tokens[0])

    assert len(clean_tokens[0]) == len(corrupted_tokens[0]), "The prompts must have the same number of tokens."

   
    patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)
    for layer in tqdm.tqdm(range(model.cfg.n_layers)):
        for position in range(num_positions):
            # We use a temporary hook with functools.partial to patch at each position
            temp_hook = partial(activation_patching_hook, position=position, clean_cache=clean_cache)
            # We then run the model with hooks as usual
            patched_logits = model.run_with_hooks(corrupted_tokens, 
                                                  fwd_hooks=[(utils.get_act_name("resid_pre", layer), temp_hook)])
            
            # We then calculate the logit difference
            patched_diff = (patched_logits[0, -1, clean_index] - patched_logits[0, -1, corrupted_index]).detach()
            # We then store the result in the patching_result tensor, normalizing it
            if abs(clean_diff-corrupted_diff) < 1e-16:
                patching_result[layer, position] = 0
            else:
                patching_result[layer, position] = abs((patched_diff - corrupted_diff) / (clean_diff - corrupted_diff))
    print(patched_logits.shape)
    return patching_result, patched_logits

def activation_patching_mult(model, clean_prompt, corrupted_prompt, clean_answer, corrupted_answer):
    ''' 
    Performs activation patching on prompts with multi-word answers by using separate run-throughs.
    The answers must have the same number of tokens
    '''
    patching_result = []
    patched_logits = []
    clean_answers_tokens = model.to_str_tokens(clean_answer)[1:]
    corrupted_answers_tokens = model.to_str_tokens(corrupted_answer)[1:]
    print("Number of run throughs:", len(clean_answers_tokens))
    for i in range(len(clean_answers_tokens)):
        p_result, p_logits = activation_patching(model, clean_prompt, corrupted_prompt, 
                                                 clean_answers_tokens[0], corrupted_answers_tokens[0])
        patching_result.append(p_result)
        patched_logits.append(p_logits)
        clean_prompt += clean_answers_tokens[0]
        clean_answers_tokens = clean_answers_tokens[1:]
        corrupted_prompt += corrupted_answers_tokens[0]
        corrupted_answers_tokens = corrupted_answers_tokens[1:]

    return patching_result, patched_logits
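
The normalization used for the patching result can be looked at in isolation. The sketch below uses hypothetical logit differences and a hypothetical function name. Unlike the code above, it keeps the sign of the result instead of taking the absolute value, so a patch that pushes the model further towards the corrupted answer would show up as a negative number.

```python
def normalized_patching_effect(clean_diff, corrupted_diff, patched_diff, eps=1e-16):
    """Rescale a patched logit difference so that 0 = corrupted run and 1 = clean run."""
    denominator = clean_diff - corrupted_diff
    if abs(denominator) < eps:
        # The clean and corrupted runs are indistinguishable, so nothing can be restored
        return 0.0
    return (patched_diff - corrupted_diff) / denominator


# Hypothetical logit differences (logit of clean answer minus logit of corrupted answer)
clean_diff, corrupted_diff = 3.0, -2.0

no_effect = normalized_patching_effect(clean_diff, corrupted_diff, patched_diff=-2.0)
full_restore = normalized_patching_effect(clean_diff, corrupted_diff, patched_diff=3.0)
partial = normalized_patching_effect(clean_diff, corrupted_diff, patched_diff=0.5)
print(no_effect, full_restore, partial)
```

A value near 1 at some layer and position thus means that the activation patched there carries most of the information that distinguishes the clean prompt from the corrupted one.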

Following from the previous examples, our clean prompt is the same as earlier. We choose the corrupted prompt very similarly, with only one key difference, and make sure that they have the same number of tokens. 

In [28]:
clean_prompt = "The capital city of France is called"
clean_answer = " Paris"
corrupted_prompt = "The capital city of Italy is called"
corrupted_answer = " Rome"

# To determine the function used, we print the tokens
print(model.to_str_tokens([clean_prompt, clean_answer]))
print(model.to_str_tokens([corrupted_prompt, corrupted_answer]))

[['<|endoftext|>', 'The', ' capital', ' city', ' of', ' France', ' is', ' called'], ['<|endoftext|>', ' Paris']]
[['<|endoftext|>', 'The', ' capital', ' city', ' of', ' Italy', ' is', ' called'], ['<|endoftext|>', ' Rome']]


We see that the answers to both prompts are a single token. Hence, we use the function corresponding to single-token answers.

We then check if the model can, in fact, predict the correct answers to the prompt. We use the same function as in the preliminary section.

In [29]:
utils.test_prompt(clean_prompt, clean_answer, model)
utils.test_prompt(corrupted_prompt, corrupted_answer, model)

Tokenized prompt: ['<|endoftext|>', 'The', ' capital', ' city', ' of', ' France', ' is', ' called']
Tokenized answer: [' Paris']


Top 0th token. Logit: 14.32 Prob:  8.26% Token: | Paris|
Top 1th token. Logit: 13.79 Prob:  4.86% Token: | Marse|
Top 2th token. Logit: 13.78 Prob:  4.81% Token: | the|
Top 3th token. Logit: 13.71 Prob:  4.49% Token: | "|
Top 4th token. Logit: 12.99 Prob:  2.17% Token: | La|
Top 5th token. Logit: 12.68 Prob:  1.61% Token: | '|
Top 6th token. Logit: 12.66 Prob:  1.57% Token: | Mont|
Top 7th token. Logit: 12.60 Prob:  1.48% Token: | St|
Top 8th token. Logit: 12.59 Prob:  1.46% Token: | Saint|
Top 9th token. Logit: 12.30 Prob:  1.09% Token: | V|


Tokenized prompt: ['<|endoftext|>', 'The', ' capital', ' city', ' of', ' Italy', ' is', ' called']
Tokenized answer: [' Rome']


Top 0th token. Logit: 14.02 Prob:  6.08% Token: | Rome|
Top 1th token. Logit: 13.98 Prob:  5.81% Token: | the|
Top 2th token. Logit: 13.42 Prob:  3.32% Token: | Milan|
Top 3th token. Logit: 13.28 Prob:  2.89% Token: | "|
Top 4th token. Logit: 12.80 Prob:  1.79% Token: | Florence|
Top 5th token. Logit: 12.59 Prob:  1.45% Token: | Naples|
Top 6th token. Logit: 12.56 Prob:  1.40% Token: | '|
Top 7th token. Logit: 12.47 Prob:  1.28% Token: | T|
Top 8th token. Logit: 12.45 Prob:  1.26% Token: | Pal|
Top 9th token. Logit: 12.43 Prob:  1.24% Token: | St|


As the predictions are correct for both prompts, we can continue in our example. We do, however, note that the probabilities are still quite low for both answers, which could influence the results somewhat.

Continuing, we then perform activation patching on the two prompts using the functions defined earlier.

In [30]:
patching_results = activation_patching(model, clean_prompt, corrupted_prompt, clean_answer, corrupted_answer)
imshow_patching_result(model, patching_results, corrupted_prompt, corrupted_answer)

Clean answer:  Paris
Corrupted answer:  Rome


torch.Size([1, 8, 50257])


We can clearly see that the normalized logit difference depends on which layer and token position we patch. We also see that around layers 9 and 10 (0-indexed), the information from the country token is moved to the last token without affecting the intermediate tokens.

### Logit Lens

In this section we demonstrate the logit lens, a technique for seeing how a prediction develops through the layers of the model. TransformerLens has functionality for accessing the accumulated residual stream at each layer, which we will use to compute intermediate logits. We apply the final layer norm to these values, since the model normally applies it after the last layer, before unembedding. Lastly, we will compute the logit difference and find the top predicted token for each position at each layer.

We use the same repeated prompt as earlier, but this time we split off the last token as the answer to predict. We also run the model again so that the cache holds the non-hooked activations for this input.

In [31]:
repeated_tokens = model.to_tokens(repeated_prompt)
rep_answer_token = repeated_tokens[0][-1]
repeated_tokens = repeated_tokens[0][:-1]
print(repeated_tokens.shape)
print(rep_answer_token)

logits, cache = model.run_with_cache(repeated_tokens)
print(logits.shape)

torch.Size([18])
tensor(13)
torch.Size([1, 18, 50257])


We will now access the accumulated residual stack and apply a layer norm. Also, we must unembed the vectors by multiplying by the unembedding matrix. 

In [32]:
acc_resid_stack = cache.accumulated_resid()
scaled_resid_stack = cache.apply_ln_to_stack(acc_resid_stack)
print(scaled_resid_stack.shape)

unembedding_matrix = model.W_U
logit_lens_final_logits = einops.einsum(scaled_resid_stack, unembedding_matrix,
                                        "n_layer ... pos n_dim, n_dim vocab -> n_layer pos vocab")
print(logit_lens_final_logits.shape)

torch.Size([13, 1, 18, 768])
torch.Size([13, 18, 50257])
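
The steps above (accumulate the residual stream, normalize, unembed, take the top token) can be followed by hand on a toy example. The sketch below is plain Python with hypothetical numbers: a three-dimensional residual stream, a three-token vocabulary, an identity unembedding matrix and a layer norm without learned parameters.

```python
import math


def layer_norm(vec, eps=1e-5):
    """Centre and rescale a vector (layer norm without learned scale and bias)."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def unembed(vec, W_U):
    """Multiply a residual vector [d_model] by W_U [d_model][vocab]."""
    return [sum(vec[d] * W_U[d][t] for d in range(len(vec))) for t in range(len(W_U[0]))]


# Hypothetical toy setup: d_model = 3 and a vocabulary of three tokens
vocab = [" the", " Paris", " Rome"]
W_U = [[1.0, 0.0, 0.0],
       [0.0, 1.0, 0.0],
       [0.0, 0.0, 1.0]]
embedding = [1.0, 0.2, 0.1]
layer_outputs = [[0.0, 0.5, 0.3], [-0.5, 1.0, 0.0]]  # what each layer adds

# Accumulated residual stream: the embedding, then the state after each layer
residual = list(embedding)
accumulated = [list(residual)]
for out in layer_outputs:
    residual = [r + o for r, o in zip(residual, out)]
    accumulated.append(list(residual))

# Logit lens: normalize and unembed every intermediate state, keep the top token
top_tokens = []
for resid in accumulated:
    logits = unembed(layer_norm(resid), W_U)
    top_tokens.append(vocab[logits.index(max(logits))])
print(top_tokens)
```

As in the shapes printed above, the stack has one more entry than there are layers: the state before the first layer plus the state after each layer. In this toy run the top token only switches to " Paris" after the second layer's contribution has been added.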


We will now compute the logit difference between the final prediction logits and the intermediate logits, and convert the intermediate logits to their corresponding top tokens. Then, we will plot the development of the prediction.

In [33]:
logits_final = []
inter_tokens = []
indices_all = []

for layer in range(model.cfg.n_layers):
    # Converting to probabilities
    probs = logit_lens_final_logits[layer].softmax(dim=-1)
    top_prob, indices = probs.topk(1)
    
    indices_all.append(indices)
    logits_final.append([logit_lens_final_logits[layer, i, index].item() 
                         for i, index in enumerate(indices)])
    top_tokens = [model.to_string(index.item()) for index in indices]
    inter_tokens.append(top_tokens)

logit_diff_results = torch.zeros(model.cfg.n_layers, len(repeated_tokens))
answer_tokens = repeated_tokens.tolist()[1:] + [rep_answer_token]
answers_str_tokens = [f'{model.to_string(answer)}_{ids}' for ids, answer in enumerate(answer_tokens)]

for layer in range(model.cfg.n_layers):
    for position in range(len(logit_lens_final_logits[0])):
        # Final logit of the true next token minus the intermediate logit of the layer's top token
        logit_diff_results[layer][position] = (
            logits[0, position, answer_tokens[position]]
            - logit_lens_final_logits[layer, position, indices_all[layer][position][0]]
        )
logits_final = np.array(logits_final)
answers_str_tokens = np.array(answers_str_tokens)

fig = px.imshow(logits_final[:,1:], labels=dict(x='Tokens', y='Layers', color='Logits'), 
                x=answers_str_tokens[1:], aspect='auto')
fig.update_traces(text=[tokens[1:] for tokens in inter_tokens], texttemplate='%{text}')
fig.update_layout(title_text="Development of predicted next tokens", title_x=0.5)


We can see that although the model is quite small, it gives sensible predictions in most cases. We also see that it performs significantly better on the second repetition of the sequence than on the first. The predictions get much better after the fifth and seventh layer, consistent with the induction heads we found in an earlier section.

### Causal scrubbing

In this last section, we will briefly go through our incomplete attempt at implementing causal scrubbing. The implementation follows the pseudocode provided by Redwood Research, found [here](https://www.lesswrong.com/posts/JvZhhzycHu2Yd57RN/causal-scrubbing-a-method-for-rigorously-testing). 

We define the following:
- $G$, the computational graph of the model, that is, a set of nodes $n_G$ (e.g. attention heads and MLPs) and a set of connections, i.e. computations, between them.
- $D$, the dataset. This contains all data we want to test the hypothesis on, i.e. the relevant sentences.
- $h$, the hypothesis. This is a tuple $(G, I, c)$, where $I$ is the hypothesized important graph and $c:I\to G$ is an injective correspondence function.
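
The requirements on the correspondence function can be checked mechanically. The following is a minimal sketch, assuming that nodes are represented as (layer, head) tuples; the function name `check_correspondence` and the graph sizes are hypothetical.

```python
def check_correspondence(G, I, c):
    """Check that c maps every node of I to a distinct node of G."""
    images = [c(node) for node in I]
    maps_into_G = all(image in G for image in images)
    is_injective = len(set(images)) == len(images)
    return maps_into_G and is_injective


# Nodes are (layer, head) tuples; a hypothetical graph with 2 layers of 3 heads
G = [(layer, head) for layer in range(2) for head in range(3)]
I = [(1, 0), (1, 1)]

identity_ok = check_correspondence(G, I, lambda node: node)
collapsed_ok = check_correspondence(G, I, lambda node: (0, 0))
print(identity_ok, collapsed_ok)
```

The identity map passes, while a map that sends both nodes of $I$ to the same node of $G$ fails the injectivity check.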

First, we need some additional imports:

In [34]:
import random
from typing import Union

We then construct a Node class, representing the idea of a single node in our computational graph.

In [35]:
class Node:
    '''
    A class to represent the idea of a node in the neural network. 
    '''
    def __init__(self, layer: int, head: int, model):
        self.layer = layer
        self.head = head
        self.model = model

    def parents(self):
        parent_list = []
        for i in range(self.layer):
            for j in range(self.head):
                parent_list.append(Node(i, j, self.model))
        return parent_list

    def value_on(self, x: str) -> torch.Tensor:
        logits, _ = self.model.run_with_cache(x)
        return logits[self.layer, self.head]
    
    def __call__(self):
        return self.layer, self.head
    
    def value_from_inputs(self, inputs, text):
        self.model.reset_hooks()
        self.model.cfg.use_attn_result = True

        def exchange_activations_attn(result, hook, inputs, layer, pos):
            print(result.shape)
            result[hook.name][...]
            for position in range(self.head):
                print(layer, "---", position)
                result[layer, position, :] = inputs[layer, position]
            return result
        
        tokens = self.model.to_tokens(text)
        print(inputs)
        for i in range(self.layer):
            for j in range(self.head):
                temp_hook = partial(exchange_activations_attn, inputs=inputs, layer=i, pos=j)
                self.model.add_hook(f"blocks.{i}.attn.hook_result", temp_hook)
        logits = self.model(tokens)
        self.model.reset_hooks()
        return [logits[0, -1, :]]
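
One thing to watch for with a node class is that membership tests such as `parent_G in h.c_image(...)` rely on equality, and a plain Python class compares by identity unless it defines `__eq__`. The sketch below shows one possible alternative, a hypothetical frozen dataclass `HeadNode` that compares by value. It also assumes that every head in a strictly earlier layer counts as a parent, which differs from the Node class above, where the inner loop runs over range(self.head).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HeadNode:
    """Hypothetical value-comparable node: equal if layer and head match."""
    layer: int
    head: int

    def parents(self, n_heads):
        """All heads in strictly earlier layers (an assumption of this sketch)."""
        return [HeadNode(l, h) for l in range(self.layer) for h in range(n_heads)]


node = HeadNode(layer=1, head=0)
parents = node.parents(n_heads=2)

# A freshly constructed node is found in the list because equality is by value
is_member = HeadNode(0, 1) in parents
print(parents, is_member)
```

Because the dataclass is frozen, its instances are also hashable, so they could be used as dictionary keys or set members when storing scrubbed activations per node.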

We also need a class for the hypothesis we want to test:

In [36]:
class Hypothesis:
    def __init__(self, G: list[tuple[int]], I: list[tuple[int]], c: callable, model):
        self.G = G
        self.I = I 
        self.c = c
        self.model = model

    def set_domain(self, D: list[str]):
        self.D = D

    #def set_model(self, model):
        #self.model = model

    def c_image(self, set: list[Node]) -> list[Node]:
        if isinstance(set, Node):
            return self.c(set)
    
        image = list(map(self.c, set))
        return image
    
    def c_preimage(self, set: Union[list[Node], Node]) -> list[Node]:
        preimage = []
        for node in self.I:
            if isinstance(set, Node):
                if self.c(node) == set:
                    preimage.append(node)
            else:
                if self.c(node) in set:
                    preimage.append(node)
        return preimage


We then want a function giving us a random sample that agrees with the reference input on the value of the given node. Note that we have no guarantee that such a sample exists. Accepting approximate agreement could be a good solution, but this is not included in our work.

In [37]:
def sample_agreeing_x(D: list[str], n_I: Node, ref_x: str) -> str:
    '''Returns random sample input that agrees on the specified node.'''
    D_agree = [x for x in D if n_I.value_on(ref_x) == n_I.value_on(x)]
    return random.choice(D_agree)
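
The approximate variant mentioned above could look roughly as follows. This is a sketch under assumptions, not part of our implementation: node values are plain lists of floats, `value_on` is passed in as a function, and the names `agrees` and `sample_agreeing` as well as the tolerance are hypothetical.

```python
import random


def agrees(a, b, tol=1e-3):
    """True if two activation vectors match elementwise within `tol`."""
    return len(a) == len(b) and all(abs(x - y) <= tol for x, y in zip(a, b))


def sample_agreeing(D, value_on, ref_x, tol=1e-3, rng=random):
    """Pick a random x in D whose node value agrees with that of ref_x."""
    ref_value = value_on(ref_x)
    candidates = [x for x in D if agrees(value_on(x), ref_value, tol)]
    # ref_x always agrees with itself, so fall back to it if D has no match
    return rng.choice(candidates) if candidates else ref_x


# Hypothetical node value that depends only on the first letter of the input
def toy_value(text):
    return [float(ord(text[0]))]


D = ["hello", "hei", "hallo", "hi"]
sampled = sample_agreeing(D, toy_value, ref_x="hey")
fallback = sample_agreeing(["abc"], toy_value, ref_x="hey")
print(sampled, fallback)
```

Falling back to the reference input keeps the procedure from failing on an empty candidate list, at the cost of not resampling that node at all.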

The core of the method is the scrubbing of the graph, implemented as follows.

In [38]:
def run_scrub(h: Hypothesis, D: list[str], n_I: Union[Node, tuple[int]], ref_x: str):
    '''
    Return the output after a scrub changing all activation of unimportant nodes.

    Param:
    h (Hypothesis): The hypothesis, including the correspondence c and the model
    D (list[str]): A list of all possible inputs in the domain.
    n_I (Node or tuple[int]): The node in question
    ref_x (str): The reference input
    '''
    model = h.model
    h.set_domain(D)

    if isinstance(n_I, tuple):
        n_I = Node(*n_I, model)

    # The corresponding main node
    n_G = h.c(n_I)

    if (n_G.layer, n_G.head) == (0,0):
        return model(ref_x)[0, 0, -1]
    
    inputs_G = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_vocab)
    
    # We pick a random sample to use as the exchange values for the unimportant nodes
    random_x = random.choice(D)

    # Get scrubbed activations for the inputs to n_G
    for parent_G in n_G.parents():
        # "important" parents
        if parent_G in h.c_image(n_I.parents()):
            parent_I = h.c_preimage(parent_G)

            # Sample new input that agrees on the interpretation node
            new_x = sample_agreeing_x(D, parent_I, ref_x)

            # Get the scrubbed activations
            inputs_G[*parent_G(), :] = run_scrub(h, D, parent_I, new_x)
        
        # "unimportant" parents
        else:
            # get activations on the random input value chosen
            inputs_G[*parent_G(), :] = parent_G.value_on(random_x)
    print(inputs_G, "hei")
    # run n_G given the computed input activation
    return n_G.value_from_inputs(inputs_G, ref_x)

Lastly, we need a function to estimate the performance of the hypothesis. 

In [39]:
# A function estimating the output from a scrubbed model
def estimate(h: Hypothesis, D: list[str]):
    _G, I, c = h.G, h.I, h.c
    outs = []
    for i in range(len(I)):
        x = random.choice(D)
        outs.append(run_scrub(h, D, I[i], x))
    return torch.mean(torch.tensor(outs), 1)

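What follows is an example of what a hypothesis and a run-through of the method might look like. Note that, in its current state, the implementation fails with the error shown at the end of the output. The error appears to originate in the hook inside value_from_inputs, which indexes the activation tensor with the string hook.name.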

In [40]:
text = ["hello", "hei", "hallo", "hi"]
G = []
for i in range(model.cfg.n_layers):
    for j in range(model.cfg.n_heads):
        G.append((i,j))
I = [(1,0), (1,1)]
h = Hypothesis(G, I, lambda x: x, model)

estimate(h, text)

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

IndexError: too many indices for tensor of dimension 4