We define the following:
- $G$, the computational graph of the model: a set of nodes $n_G$ (e.g. attention heads and MLPs) together with the set of connections, i.e. computations, between them.
- $D$, the dataset. This contains all the data we want to test the hypothesis on, i.e. the relevant sentences.
- $h$, the hypothesis. This is a tuple $(G, I, c)$, where $I$ is the hypothesized important graph and $c:I\to G$ is an injective correspondence function.

In [17]:
import random
from collections.abc import Callable
from functools import partial
from typing import Union

import torch
from transformer_lens import HookedTransformer

In [None]:
class Node:
    def __init__(self, layer: int, head: int):
        self.layer = layer
        self.head = head

    def parents(self):
        parent_list = [(i, j) for i, j in zip(range(self.layer), range(self.head))]
        return [Node(*parent) for parent in parent_list]
    
        
    def value_on(self, x: str) -> torch.tensor:
        logits, _ = self.model.run_with_cache(x)
        return logits[self.layer, self.head]
    
    def __call__(self):
        return self.layer, self.head
    
    def value_from_inputs(self, inputs):
        

In [None]:
class Hypothesis:
    '''A causal-scrubbing hypothesis ``h = (G, I, c)``.

    Attributes:
        G: nodes of the model's full computational graph.
        I: nodes of the hypothesized important (interpretation) graph.
        c: correspondence mapping a node of ``I`` to a node of ``G``;
            assumed injective.
        D: dataset of inputs, set via ``set_domain`` (None until then).
        model: model under study, set via ``set_model`` (None until then).
    '''

    def __init__(self, G: list[Node], I: list[Node], c: Callable[[Node], Node]):
        self.G = G
        self.I = I
        self.c = c
        # Populated later by set_domain / set_model; initialised here so that
        # reading them early yields None instead of raising AttributeError.
        self.D: list[str] | None = None
        self.model = None

    def set_domain(self, D: list[str]):
        '''Store the dataset of inputs the hypothesis is tested on.'''
        self.D = D

    def set_model(self, model):
        '''Store the model the hypothesis is about.'''
        self.model = model

    def c_image(self, set: Union[list[Node], Node]) -> Union[list[Node], Node]:
        '''Map node(s) of I into G through ``c``.

        Args:
            set: a single node of I, or an iterable of nodes of I.

        Returns:
            ``c(set)`` for a single node; otherwise the list
            ``[c(n) for n in set]`` in input order.
        '''
        if isinstance(set, Node):
            return self.c(set)
        return [self.c(node) for node in set]

    def c_preimage(self, set: Union[list[Node], Node]) -> list[Node]:
        '''Return the nodes of I that ``c`` maps onto ``set``.

        Args:
            set: a single node of G, or a collection of nodes of G.

        Returns:
            The nodes ``n`` of ``self.I`` (in ``self.I`` order) with
            ``c(n) == set`` for a single node, or ``c(n) in set`` for a
            collection. With ``c`` injective, a single target has at most one
            preimage.
        '''
        # The single-node / collection decision does not depend on the loop
        # variable, so make it once.
        if isinstance(set, Node):
            return [node for node in self.I if self.c(node) == set]
        return [node for node in self.I if self.c(node) in set]


In [7]:
# TODO: approximate agreement (e.g. a tolerance) instead of exact equality?
def _values_agree(a, b) -> bool:
    '''Return True if two node values are exactly equal (elementwise for tensors).'''
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        # torch.equal yields a single bool (False on shape mismatch); `==`
        # yields an elementwise tensor whose truth value is ambiguous.
        return torch.equal(a, b)
    return bool(a == b)


def sample_agreeing_x(D: list[str], n_I: "Node", ref_x: str) -> str:
    '''Return a random input from ``D`` on which ``n_I`` has the same value as on ``ref_x``.

    Args:
        D: candidate inputs.
        n_I: node whose value must agree; only ``n_I.value_on(x)`` is used.
        ref_x: the reference input.

    Returns:
        A uniformly random element of ``D`` that agrees with ``ref_x`` on
        ``n_I``. If no element agrees (e.g. ``ref_x`` is not in ``D``),
        ``ref_x`` itself is returned, since it trivially agrees with itself.
    '''
    # Computed once instead of once per candidate.
    ref_value = n_I.value_on(ref_x)
    D_agree = [x for x in D if _values_agree(ref_value, n_I.value_on(x))]
    if not D_agree:
        return ref_x
    return random.choice(D_agree)

In [16]:
# TODO: fold into Hypothesis — Hypothesis.c already holds the correspondence.
def hypothesis_correspondence(n_I: "Node") -> "Node":
    '''
    Return the node of the main graph G that corresponds to ``n_I``.

    Not implemented: run_scrub uses ``h.c`` for this. Raising makes accidental
    calls fail loudly instead of silently returning None.

    Param:
    n_I (Node): A node in the interpretation graph I.

    Returns:
    n_G (Node): The corresponding node in the main graph G, according to the hypothesis.

    Raises:
    NotImplementedError: always.
    '''
    raise NotImplementedError(
        "hypothesis_correspondence is a stub; use Hypothesis.c instead"
    )


def run_scrub(h: Hypothesis, D: list[str], n_I: Node, ref_x: str, model):
    '''
    Return the scrubbed value of the node of G that corresponds to ``n_I``.

    Each parent of ``n_G = h.c(n_I)`` that is the image of a parent of ``n_I``
    ("important") is recomputed recursively on a fresh input that agrees with
    ``ref_x`` on that interpretation node. Every other ("unimportant") parent
    takes its value on one shared random input from ``D``.

    Param:
    h (Hypothesis): The hypothesis; ``h.c`` maps nodes of I to nodes of G.
    D (list[str]): All inputs in the domain.
    n_I (Node): The interpretation-graph node being scrubbed.
    ref_x (str): The reference input.
    model: The model under study; stored on ``h`` and passed to recursive calls.

    Returns:
    ``n_G.value_on(ref_x)`` if ``n_G`` has no parents, otherwise
    ``n_G.value_from_inputs(inputs)`` where ``inputs`` maps each parent's
    ``(layer, head)`` to its scrubbed value.
    '''
    h.set_domain(D)
    h.set_model(model)

    # The corresponding main node
    n_G = h.c(n_I)
    parents_G = n_G.parents()

    # Base case: a node with no parents reads only from the input, so its
    # value is taken directly on the reference input.
    # NOTE(review): replaces the former `n_G in ref_x` check, which tested a
    # Node against a str and always raised TypeError — confirm this is the
    # intended leaf condition.
    if not parents_G:
        return n_G.value_on(ref_x)

    # Scrubbed parent values keyed by (layer, head). A dict replaces the former
    # 2-D zeros tensor, which was indexed with three subscripts and sized with
    # the nonexistent `model.cfg.heads`.
    inputs_G = {}

    # We pick a random sample to use as the exchange values for the unimportant nodes
    random_x = random.choice(D)

    # Invert c over n_I's parents (c is injective), mapping each important
    # G-parent straight to its interpretation node. c_preimage would return a
    # list, not the single node the recursion needs.
    important = {h.c(p_I): p_I for p_I in n_I.parents()}

    for parent_G in parents_G:
        parent_I = important.get(parent_G)
        if parent_I is not None:
            # Sample new input that agrees on the interpretation node
            new_x = sample_agreeing_x(D, parent_I, ref_x)
            inputs_G[parent_G()] = run_scrub(h, D, parent_I, new_x, model)
        else:
            # get activations on the random input value chosen
            inputs_G[parent_G()] = parent_G.value_on(random_x)

    # run n_G given the computed input activations
    return n_G.value_from_inputs(inputs_G)

In [None]:
# Load pretrained GPT-2 small as a TransformerLens HookedTransformer. Functions
# below (e.g. exchange_activations_attn) read this module-level `model`.
model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [13]:
# Sanity check: run the model once on a short prompt, then show what the
# activation cache contains and the shape of the output logits.
text = "hello"
logits, cache = model.run_with_cache(text)
print(cache, logits.shape, sep="\n")

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re

In [None]:
def exchange_activations_attn(out, hook, inputs, node):
    '''Copy per-head values from ``inputs`` into ``out`` in place.

    Shaped like a TransformerLens forward hook: it mutates ``out`` and returns
    None. Layer and head counts come from the module-level ``model``.

    Args:
        out: tensor being patched, indexed here as ``out[layer, head, :]``.
            NOTE(review): a cached ``attn.hook_z`` activation is laid out
            (batch, pos, head_index, d_head), and a hook sees one layer at a
            time, so this indexing likely needs revisiting — confirm.
        hook: the hook point (unused).
        inputs: replacement values, indexed as ``inputs[layer, head]``.
        node: falsy to patch every head of every layer. Otherwise an object
            with a ``head`` attribute; heads ``0 .. node.head - 1`` of the
            first layer are patched and the function returns.
    '''
    # HookedTransformerConfig names this field `n_layers`; the previous
    # `n_layer` raised AttributeError.
    for layer in range(model.cfg.n_layers):
        head_count = node.head if node else model.cfg.n_heads
        for head in range(head_count):
            out[layer, head, :] = inputs[layer, head]
        if node:
            # NOTE(review): this early return (kept from the original) stops
            # after the first layer; it may have been meant to stop at
            # node.layer instead — confirm.
            return

def run_with_changes(model, text, inputs, end_node=False):
    '''Work in progress: tokenize ``text`` and build an activation-patching hook.

    As visible here the model is not run yet and ``tokens``/``temp_hook`` are
    unused.
    '''
    tokens = model.to_tokens(text)
    # NOTE(review): partial() binds these positionally, so `inputs` fills the
    # hook's `out` parameter and `end_node` fills `hook` — probably intended
    # partial(exchange_activations_attn, inputs=inputs, node=end_node).
    temp_hook = partial(exchange_activations_attn, inputs, end_node)
    # NOTE(review): bare name expression; it only looks `logits` up and has no
    # other effect as written.
    logits
