We define the following:
- $G$, the computational graph of the model, that is, a set of nodes $n_G$ (heads in attention an MLP fx) and a set of the connections, ie computations, between them.
- $D$, the dataset. This contains all data we want to test the hypothesis on, ie. the relevant sentences.
- $h$, the hypothesis. This is a tuple $(G, I, c)$, where $I$ is the hypothesized important graph and $c:I\to G$ is an injective correspondence function.

In [2]:
import random
import torch
from typing import Union

from transformer_lens import HookedTransformer
from functools import partial
import transformer_lens.utils as utils

In [35]:
class Node:
    '''
    A class to represent the idea of a node in the neural network. 
    '''
    def __init__(self, layer: int, head: int, model):
        self.layer = layer
        self.head = head
        self.model = model

    def parents(self):
        parent_list = []
        for i in range(self.layer):
            for j in range(self.head):
                parent_list.append(Node(i, j, self.model))
        return parent_list

    def value_on(self, x: str) -> torch.tensor:
        logits, _ = self.model.run_with_cache(x)
        return logits[self.layer, self.head]
    
    def __call__(self):
        return self.layer, self.head
    
    def value_from_inputs(self, inputs, text):
        self.model.reset_hooks()
        self.model.cfg.use_attn_result = True

        def exchange_activations_attn(result, hook, inputs, layer, pos):
            print(result.shape)
            result[hook.name][...]
            for position in range(self.head):
                print(layer, "---", position)
                result[layer, position, :] = inputs[layer, position]
            return result
        
        tokens = self.model.to_tokens(text)
        print(inputs)
        for i in range(self.layer):
            for j in range(self.head):
                temp_hook = partial(exchange_activations_attn, inputs=inputs, layer=i, pos=j)
                self.model.add_hook(f"blocks.{i}.attn.hook_result", temp_hook)
        logits = self.model(tokens)
        self.model.reset_hooks()
        return [logits[0, -1, :]]

In [4]:
class Hypothesis:
    def __init__(self, G: list[tuple[int]], I: list[tuple[int]], c: callable, model):
        self.G = G
        self.I = I 
        self.c = c
        self.model = model

    def set_domain(self, D: list[str]):
        self.D = D

    #def set_model(self, model):
        #self.model = model

    def c_image(self, set: list[Node]) -> list[Node]:
        if isinstance(set, Node):
            return self.c(set)
    
        image = list(map(self.c, set))
        return image
    
    def c_preimage(self, set: Union[list[Node], Node]) -> list[Node]:
        preimage = []
        for node in self.I:
            if isinstance(set, Node):
                if self.c(node) == set:
                    preimage.append(node)
            else:
                if self.c(node) in set:
                    preimage.append(node)
        return preimage


In [5]:
# TODO What if no matches? approx?
def sample_agreeing_x(D: list[str], n_I: Node, ref_x: str) -> str:
    '''Returns random sample input that agrees on the specified node.'''
    D_agree = [x for x in D if n_I.value_on(ref_x) == n_I.value_on(x)]
    return random.choice(D_agree)

In [None]:
# TODO: In a hypothesis class?
def hypothesis_correspondence(n_I: Node) -> Node:
    '''
    Return the corresponding node in the main graph.

    Param:
    n_I (int): A node in the interpretation graph I

    Returns:
    n_G (int): The correspondin node in the main graph G, according to the hypothesis
    '''
    pass


def run_scrub(h: Hypothesis, D: list[str], n_I: Union[Node, tuple[int]], ref_x: str):
    '''
    Return the output after a scrub changing all activation of unimportant nodes.

    Param:
    c (callable): The hypothesis correspondence
    D (list[list]): A list of all possible inputs in the domain.
    n_I (tuple[int]): The node in question
    ref_x (list): The reference input
    '''
    model = h.model
    h.set_domain(D)

    if isinstance(n_I, tuple):
        n_I = Node(*n_I, model)

    # The corresponding main node
    n_G = h.c(n_I)

    if (n_G.layer, n_G.head) == (0,0):
        return model(ref_x)[0, 0, -1]
    
    inputs_G = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_vocab)
    
    # We pick a random sample to use as the exchange values for the unimportant nodes
    random_x = random.choice(D)

    # Get scrubbed activations for the inputs to n_G
    for parent_G in n_G.parents():
        # "important" parents
        if parent_G in h.c_image(n_I.parents()):
            parent_I = h.c_preimage(parent_G)

            # Sample new input that agrees on the interpretation node
            new_x = sample_agreeing_x(D, parent_I, ref_x)

            # Get the scrubbed activations
            inputs_G[*parent_G(), :] = run_scrub(h, D, parent_I, new_x)
        
        # "unimportant" parents
        else:
            # get activations on the random input value chosen
            inputs_G[*parent_G(), :] = parent_G.value_on(random_x)
    print(inputs_G, "hei")
    # run n_G given the computed input activation
    return n_G.value_from_inputs(inputs_G, ref_x)

In [7]:
# A function estimating the output from a scrubbed model
def estimate(h: Hypothesis, D: list[str]):
    _G, I, c = h.G, h.I, h.c
    outs = []
    for i in range(len(I)):
        x = random.choice(D)
        outs.append(run_scrub(h, D, I[i], x))
    return torch.mean(torch.tensor(outs), 1)

In [37]:
model = HookedTransformer.from_pretrained("gpt2-small")
print(model.run_with_cache("hello"))
print(len((utils.get_act_name("attn_out"))))

Loaded pretrained model gpt2-small into HookedTransformer
(tensor([[[ 7.5261, 11.1214,  7.8919,  ..., -3.1299, -3.3873,  8.5934],
         [12.4429,  7.8121,  3.6721,  ..., -0.9013,  1.8067,  8.4964]]],
       grad_fn=<ViewBackward0>), ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z

In [36]:
text = ["hello", "hei", "hallo", "hi"]
G = []
for i in range(model.cfg.n_layers):
    for j in range(model.cfg.n_heads):
        G.append((i,j))
I = [(1,0), (1,1)]
h = Hypothesis(G, I, lambda x: x, model)

estimate(h, text)

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

IndexError: too many indices for tensor of dimension 4