In [1]:
#%pip install transformer_lens

In [2]:
from transformer_lens import HookedEncoderDecoder
import transformer_lens.utils as utils
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES

import torch

torch.set_grad_enabled(False)




<torch.autograd.grad_mode.set_grad_enabled at 0x7fd5365aeb50>

## Loading the Model in TransformerLens

Please download the model first: https://cloud.anja.re/s/Qpo8CZ6yRzDH7ZF

In [3]:
# !wget "https://cloud.anja.re/s/qckH8GQPyN6YK8w/download?path=%2F&files=DSI-large-TriviaQA.zip"
# !unzip "download?path=%2F&files=DSI-large-TriviaQA.zip"
checkpoint = "../DSI-large-TriviaQA"

OFFICIAL_MODEL_NAMES.append(checkpoint)

hf_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
device = utils.get_device()
model = HookedEncoderDecoder.from_pretrained(checkpoint, hf_model=hf_model, device=device)

tokenizer_t5 = AutoTokenizer.from_pretrained('google-t5/t5-large')


# Our model has a new token for each document id that we trained it on.

# token id of first document that was added
first_added_doc_id = len(tokenizer_t5)
# token id of the last document that was added
last_added_doc_id = len(tokenizer_t5) + (len(tokenizer) - len(tokenizer_t5))
del tokenizer_t5


If using T5 for interpretability research, keep in mind that T5 has some significant architectural differences to GPT. The major one is that T5 is an Encoder-Decoder modelAlso, it uses relative positional embeddings, different types of Attention (without bias) and LayerNorm


Loaded pretrained model ../DSI-large-TriviaQA into HookedTransformer


## Running a sample query

In [4]:
#wget "https://cloud.anja.re/s/qckH8GQPyN6YK8w/download?path=%2FGenIR-Data&files=TriviaQAData.zip"
#unzip "download?path=%2FGenIR-Data&files=TriviaQAData.zip"

import json
from torch.utils.data import Dataset, DataLoader
with open("../TriviaQAData/train_queries_trivia_qa.json", mode='r') as f:
  train_data = json.load(f)
with open("../TriviaQAData/val_queries_trivia_qa.json", mode='r') as f:
  val_data = json.load(f)
data = train_data + val_data

class QuestionsDataset(Dataset):
    def __init__(self, inputs, targets, ids):
        self.inputs = inputs
        self.targets = targets
        self.ids = ids

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx], self.ids[idx]

    def collate_fn(self, batch):
        input_texts, target_texts, ids = zip(*batch)
        return input_texts, target_texts, ids


queries = [entry['query'] for entry in data]
ground_truths = [entry['relevant_docs'] for entry in data]
query_ids = [entry['id'] for entry in data]
q_ids_to_qs = {id : query for id, query in zip(query_ids, queries)}
dataset = QuestionsDataset(queries, ground_truths, query_ids)
data_loader = DataLoader(dataset, batch_size=16, shuffle = False, collate_fn=dataset.collate_fn)

In [5]:
def extract_activated_neurons(hooks, layer_indices):
    result_dict = {}
    for layer_id in layer_indices:
        hook_post = hooks[f'decoder.{layer_id}.mlp.hook_post']
        activated_neuron_indices = (hook_post > 0).nonzero(as_tuple=False)
        result_dict[f'layer_{layer_id}'] = activated_neuron_indices
    return result_dict

def pad_relevant_docs(relevant_docs):
    relevant_docs = [[int(item) for item in sublist] for sublist in relevant_docs]
    max_len = 0
    for sublist in relevant_docs:
        if len(sublist) > max_len:
            max_len = len(sublist)
    padded_relevant_docs = []
    for sublist in relevant_docs:
        # Calculate how many padding elements are needed
        num_padding = max_len - len(sublist)
        # Create the padded sublist
        padded_sublist = sublist + [-1] * num_padding
        padded_relevant_docs.append(padded_sublist)
    return padded_relevant_docs

def extract_doc_ids_from_output(decoder_output):
    doc_out_ids = []
    for out in doc_out:
        doc_id = re.findall(r"@DOC_ID_([0-9]+)@", decoder_output)
        assert len(doc_id) <= 1
        doc_out_ids.append(doc_id[0] if len(doc_id) else '-1')

In [23]:
def make_mlp_hook_function(target_token_pos, target_neuron_index, new_activation_value):
    def modify_mlp_neuron_hook(
        activation_tensor: torch.Tensor, 
        hook
    ) -> torch.Tensor:
        """
        A hook function to modify a specific MLP neuron's activation.
        activation_tensor shape: [batch, position, n_mlp_neurons]
        """
        # print(f"Hook fired at {hook.name}. Original activation value at "
        #       f"pos {target_token_pos}, neuron {target_neuron_index}: "
        #       f"{activation_tensor[:, target_token_pos, target_neuron_index]}")
              # f"{activation_tensor[0, target_token_pos, target_neuron_index].item():.4f}")
    
        # Modify the specific neuron's activation in-place
        # We use [0] for batch dimension assuming a single prompt
        activation_tensor[:, target_token_pos, target_neuron_index] = new_activation_value
        # activation_tensor[0, target_token_pos, target_neuron_index] = new_activation_value
    
        # print(f"Modified activation to: "
        #       f"{activation_tensor[:, target_token_pos, target_neuron_index]}")
              # f"{activation_tensor[0, target_token_pos, target_neuron_index].item():.4f}")
    
        return activation_tensor # Always return the modified tensor

    return modify_mlp_neuron_hook

def run_model_with_activation_hook(model, prompt, mlp_hook_name, neuron_index, neuron_new_value):
    # mlp_hook_name = f"blocks.{target_layer}.mlp.hook_post"
    # Now, run with the hook
    modified_logits = model.run_with_hooks(
        prompt,
        fwd_hooks=[(mlp_hook_name, make_mlp_hook_function(0, neuron_index, neuron_new_value))]
    )
    hook_result = tokenizer.batch_decode(torch.argmax(modified_logits, dim=-1).squeeze(-1))

    logits = model(prompt)
    orig_result = tokenizer.batch_decode(torch.argmax(logits, dim=-1).squeeze(-1))
    
    print(f'original result:{orig_result}, and after using the hook:{hook_result}')
    correct_count = 0
    for i in range(len(orig_result)):
        if orig_result[i] == hook_result[i]:
            correct_count += 1
    print(f'Total correct answered:{correct_count}/ {len(orig_result)}')

def get_affected_prompts(model, queries_dict, mlp_hook_name, layer_index, neuron_index, neuron_new_value):
    layer_activated_neuron_inputs, layer_activated_neurons_correct_doc_ids = [], []
    for key in queries_dict:
        if neuron_index in queries_dict[key]['activated_neurons'][f'layer_{layer_id}']:
            layer_activated_neuron_inputs.append(results_copy[key]['input'])
            layer_activated_neurons_correct_doc_ids.append(results_copy[key]['correct_doc_id'])
    return run_model_with_activation_hook(model, layer_activated_neuron_inputs, mlp_hook_name, neuron_index, neuron_new_value)

## Examining the activations

The activations of each component in the transformer are stored in the `cache` object. It's basically a dict from which you choose which component to look at.

Here, we print all possible component keys for layer 0 in the decoder:

In [7]:
# for key in cache.keys():
#   if key.startswith('decoder.0.'):
#     print(key)

We choose to look at the output of the MLP in layer 19 of the decoder:

In [8]:
# cache['decoder.19.hook_mlp_out'], cache['decoder.19.hook_mlp_out'].shape

Take a look at where the MLP hooks are computed: https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/components/mlps/mlp.py

`hook_pre`: Before activation,
`hook_post`: After applying activation

### Shir's Part

In [9]:
def get_logits(model, prompt, mlp_hook_name, neuron_index):
    modified_logits = model.run_with_hooks(
        prompt,
        fwd_hooks=[(mlp_hook_name, make_mlp_hook_function(0, neuron_index, 0.0))]
    )
    logits = model(prompt)
    return logits, modified_logits

Checking change of logits after neuron zeroing

In [10]:
def change_by_neuron_zeroing(logits, modified_logits):
    l2_change = torch.linalg.vector_norm(logits - modified_logits)
    cosine_similarity = torch.nn.functional.cosine_similarity(logits, modified_logits)
    return l2_change, cosine_similarity
    

Checking the ratio between the max and second max logits (first and second prediction of doc ids)

In [21]:
def logit_top2_difference(logits, modified_logits):
    
    ## return the difference caused to the original (unmodified) top 2 logits by modifying the neurons
    
    top_2_logits, top_2_logits_indices = torch.topk(logits, k=2)
    top_2_logits = top_2_logits.squeeze(0)
    top_2_logits_indices = top_2_logits_indices.squeeze(0)
    modified_logits = modified_logits.squeeze(0)
    difference = top_2_logits - modified_logits[top_2_logits_indices]
    return difference

check layer output similarity to final logits

In [12]:
def similarity_between_layer_output_and_logits(logits, layer_output_hook): 
    # I don't think is makes sense to run the model again as we have a saved dictionary of all hooks per layer, 
    # so this function is pretty simple right now.
    return torch.nn.functional.cosine_similarity(logits, layer_output_hook)
    # project layer output into logits space
    # check if layer output is normalized
    # talk to Anja about what is the best way to measure the contribution to the final logits

Project into vocabulary dimension - copying Anja's code into here. I don't really understand how to use it though.

In [13]:
def process_resid(resid, model):
    decoder_resid = model.decoder_final_ln(resid)

    if model.cfg.tie_word_embeddings:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        decoder_resid *= model.cfg.d_model**-0.5

    return decoder_resid

def logit_lens_decoder(resid, model):
    decoder_resid = process_resid(resid, model=model)
    logits = model.unembed(decoder_resid)

    return logits

    # check what decoder_final_ln is and if we have it

In [14]:
with open('../activated_neurons_train_val_copy.json', 'r') as f:
        activated_neurons = json.load(f)

In [15]:
with open('../neuron_activation_stats_with_entries.json', 'r') as f:
    neuron_stats = json.load(f)

In [31]:
from collections import defaultdict

logit_changes = defaultdict(lambda: {})
for layer in neuron_stats['by_layer'].keys():
    layer_id = int(''.join(c for c in layer if c.isdigit()))
    layer_dict = neuron_stats['by_layer'][layer]
    for neuron in layer_dict.keys():
        l2_sum = 0
        cos_sim_sum = 0
        top_2_diff_sum = [0,0]
        print(f"modifying neuron {neuron}")
        ids = [id for id in layer_dict[neuron]['entries']]
        queries = [q_ids_to_qs[id] for id in ids]
        n = len(queries)
        batches = [(queries[i : i+32], ids[i : i+32]) for i in range(0, len(queries), 32)]
        for q_batch, ids_batch in batches:
            b_logits, b_modified_logits = get_logits(model, q_batch, f'decoder.{layer_id}.mlp.hook_post', int(neuron))
            for logits, modified_logits, id in zip(b_logits, b_modified_logits, ids_batch):
                l2, cos_sim = change_by_neuron_zeroing(logits, modified_logits)
                top2_diff = logit_top2_difference(logits, modified_logits)
                l2_sum += l2
                cos_sim_sum += cos_sim
                top_2_diff_sum[0] += top2_diff[0]
                top_2_diff_sum[1] += top2_diff[1]
                
        logit_changes[layer][neuron] = {"avg l2": l2_sum / n, "avg cosine similarity": cos_sim_sum / n, "avg top2_difference": [i/n for i in top_2_diff_sum]}    
                
with open("logit_changes_by_zeroing.json", "w") as f:
    json.dump(logit_changes, f)
                
        

modifying neuron 1
modifying neuron 513


KeyboardInterrupt: 

In [32]:
logit_changes["layer_17"]["1"]

{'avg l2': tensor(17.3075, device='cuda:0'),
 'avg cosine similarity': tensor([1.], device='cuda:0'),
 'avg top2_difference': [tensor(-0.0008, device='cuda:0'),
  tensor(-1.8029e-05, device='cuda:0')]}