In [None]:
# Install TransformerLens into this kernel's environment (%pip targets the running kernel).
# NOTE(review): unpinned — consider pinning the version this notebook was run with, since
# the HookedEncoderDecoder API used below may differ between releases.
%pip install transformer_lens



In [None]:
from transformer_lens import HookedEncoderDecoder
import transformer_lens.utils as utils
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES

import torch

# Inference only: turn autograd off globally so forward passes don't build graphs.
# The trailing ";" suppresses the context-manager repr IPython would otherwise display.
torch.set_grad_enabled(False);


<torch.autograd.grad_mode.set_grad_enabled at 0x7c4ee2b94b10>

## Loading the Model in TransformerLens

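Download and unzip the model checkpoint first, so that it is available in the local `DSI-large-TriviaQA` directory: https://cloud.anja.re/s/Qpo8CZ6yRzDH7ZF (alternatively, uncomment the `wget`/`unzip` lines in the next cell).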

In [None]:
#!wget "https://cloud.anja.re/s/qckH8GQPyN6YK8w/download?path=%2F&files=DSI-large-TriviaQA.zip"
#!unzip "download?path=%2F&files=DSI-large-TriviaQA.zip"
checkpoint = "DSI-large-TriviaQA"

# Register the local checkpoint name so HookedEncoderDecoder.from_pretrained accepts it
# (presumably it validates names against OFFICIAL_MODEL_NAMES — confirm). Guarded so that
# re-running this cell does not append the same name again.
if checkpoint not in OFFICIAL_MODEL_NAMES:
    OFFICIAL_MODEL_NAMES.append(checkpoint)

hf_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
device = utils.get_device()
model = HookedEncoderDecoder.from_pretrained(checkpoint, hf_model=hf_model, device=device)

# The base T5 tokenizer is only needed to find where the added vocabulary starts.
tokenizer_t5 = AutoTokenizer.from_pretrained('google-t5/t5-large')


# Our model has a new token for each document id that we trained it on.
# Assumes the added tokens occupy the contiguous id range
# range(len(tokenizer_t5), len(tokenizer)) — TODO confirm there are no gaps.

# token id of first document that was added
first_added_doc_id = len(tokenizer_t5)
# token id of the last document that was added (inclusive); len(tokenizer) itself is
# one past the end of the added range.
last_added_doc_id = len(tokenizer) - 1
del tokenizer_t5


If using T5 for interpretability research, keep in mind that T5 has some significant architectural differences to GPT. The major one is that T5 is an Encoder-Decoder modelAlso, it uses relative positional embeddings, different types of Attention (without bias) and LayerNorm


Loaded pretrained model DSI-large-TriviaQA into HookedTransformer


## Data Prep

In [None]:
#!wget "https://cloud.anja.re/s/qckH8GQPyN6YK8w/download?path=%2FGenIR-Data&files=TriviaQAData.zip"
#!unzip "download?path=%2FGenIR-Data&files=TriviaQAData.zip"

# Minimum number of queries a document needs in order to be kept.
THRESHOLD = 4

import json
from torch.utils.data import Dataset, DataLoader


def load_query_split(path):
    """Load one TriviaQA query split (a JSON document) from ``path``."""
    with open(path, mode='r', encoding='utf-8') as f:
        return json.load(f)


test_data = load_query_split("TriviaQAData/test_queries_trivia_qa.json")
val_data = load_query_split("TriviaQAData/val_queries_trivia_qa.json")
train_data = load_query_split("TriviaQAData/train_queries_trivia_qa.json")

data = test_data + val_data + train_data
q_id_to_entry = {entry['id'] : entry for entry in data}

# there seem to be no duplicated queries so no need to filter them

# doc id -> every query entry that lists this doc among its relevant docs
doc_id_to_entries = {}
for entry in data:
    for doc_id in entry['relevant_docs']:
        doc_id_to_entries.setdefault(doc_id, []).append(entry)

# keep only docs that have at least THRESHOLD queries
doc_id_to_entries = {doc_id : entries for doc_id, entries in doc_id_to_entries.items() if len(entries) >= THRESHOLD}

# Collect the entries that survive the filtering. dict.fromkeys de-duplicates while keeping
# first-seen order; iterating a plain set instead gives an order that (for str ids, which
# are hash-randomized) changes between kernel restarts, making the batches irreproducible.
ordered_filtered_ids = dict.fromkeys(entry["id"] for ent_list in doc_id_to_entries.values() for entry in ent_list)
filtered_entries_ids = set(ordered_filtered_ids)
filtered_entries = [q_id_to_entry[q_id] for q_id in ordered_filtered_ids]
print(len(filtered_entries))

queries = [entry['query'] for entry in filtered_entries]
ground_truths = [entry['relevant_docs'] for entry in filtered_entries]
query_ids = [entry['id'] for entry in filtered_entries]


54801


In [None]:
# Total number of queries across the test, validation and training splits.
total_query_count = len(data)
print(total_query_count)

77582


### Dataset Class

In [None]:
class QuestionsDataset(Dataset):
    """Map-style dataset over three parallel sequences.

    Item ``i`` is ``(inputs[i], targets[i], ids[i])`` — in this notebook the query text,
    its relevant doc ids, and its query id.
    """

    def __init__(self, inputs, targets, ids):
        # Fail fast on misaligned sequences; otherwise the mismatch only shows up later as
        # an IndexError in the middle of a DataLoader iteration.
        if not (len(inputs) == len(targets) == len(ids)):
            raise ValueError(
                "inputs, targets and ids must have the same length, got "
                f"{len(inputs)}, {len(targets)} and {len(ids)}"
            )
        self.inputs = inputs
        self.targets = targets
        self.ids = ids

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx], self.ids[idx]

    def collate_fn(self, batch):
        """Transpose a list of ``(input, target, id)`` triples into three tuples.

        Keeps the raw Python objects instead of going through DataLoader's default collation.
        """
        input_texts, target_texts, ids = zip(*batch)
        return input_texts, target_texts, ids


In [None]:
# Iterate over the filtered queries in a fixed order, 32 queries per batch.
batch_size = 32
dataset = QuestionsDataset(inputs=queries, targets=ground_truths, ids=query_ids)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=dataset.collate_fn,
)

In [None]:
# NOTE(review): disabled top-K evaluation pass. It is the only place that defines
# doc_id_to_valid_queries, queries_to_correct_predicted_ids,
# ids_with_more_than_threshold_correct_queries and correct_queries_ids, which the
# "saver cell" further down tries to save.
# NOTE(review): if re-enabled, np.array_split(doc_ids, batch_size) assumes every batch holds
# exactly batch_size queries and that exactly K digit strings are decoded per query. The final
# DataLoader batch is smaller (54801 % 32 = 17 queries), so predictions would be misaligned
# with ids — split by len(input_texts) (or work on the topk index tensor directly) instead.
# The padded batch is also fed to the model without its attention mask.
# NOTE(review): the second commented block iterates `training_data` and appends to
# `correct_queries`, neither of which is defined in this notebook.
# import numpy as np

# K = 5

# decoder_input = torch.tensor([[0]])
# doc_id_to_valid_queries = {}
# queries_to_correct_predicted_ids = {}
# for input_texts, target_texts, ids in data_loader:
#   input_tokens = tokenizer(input_texts, return_tensors='pt', padding=True)['input_ids']
#   logits = model.forward(input_tokens, decoder_input)
#   res = tokenizer.decode(torch.topk(logits, K, dim=-1)[1].flatten())
#   doc_ids = [s for s in res.replace('@','_').split(sep='_') if s.isdigit()]
#   doc_ids = np.array_split(doc_ids, batch_size)
#   # correct_queries += [(id, query, predicted, truth) for id,query,predicted,truth in zip(ids, input_texts, doc_ids, target_texts) if truth in predicted]
#   for q_id, relevant_docs, predicted_list in zip(ids, target_texts, doc_ids):
#       for i,doc_id in enumerate(predicted_list):
#           if doc_id in relevant_docs:
#               doc_id_to_valid_queries.setdefault(doc_id, [])
#               doc_id_to_valid_queries[doc_id].append(q_id)
#               queries_to_correct_predicted_ids.setdefault(q_id, {})
#               queries_to_correct_predicted_ids[q_id][i+1] = doc_id

# ids_with_more_than_threshold_correct_queries = {key : value for key, value in doc_id_to_valid_queries.items() if len(value) >= THRESHOLD}
# correct_queries_ids = set([q_id for key, query_list in ids_with_more_than_threshold_correct_queries.items() for q_id in query_list])

# print(len(correct_queries_ids))






# for entry in training_data:
#   id, query, relevant_docs = entry
#   input_tokens = tokenizer(query, return_tensors='pt')['input_ids']
#   decoder_input = torch.tensor([[0]])

#   logits, cache = model.run_with_cache(input_tokens, decoder_input, remove_batch_dim=True)
#   res = tokenizer.decode(torch.argmax(logits, dim=-1)[0][0])
#   if res in relevant_docs:
#     correct_queries.append(entry)

In [None]:
# Query ids the model ranked correctly at position 1. This file must already exist;
# the saver cell below writes it with torch.save, so despite the ".json" suffix it is a
# pickle — torch.load unpickles it, so only load files you trust.
correct_at_first_path = "correct_queries_at_first_place.json"
correct_queries_at_first_place = torch.load(correct_at_first_path)
print(len(correct_queries_at_first_place))

44004


In [None]:
# Cache decoder MLP activations (hook_pre, hook_post, hook_mlp_out) for every query the
# model ranked correctly at position 1, writing one file per decoder layer.
import os

# NOTE: this rebinds `queries`, which the Data Prep cell also defines; re-running the
# earlier DataLoader cell after this one would silently pick up these queries instead.
ids = list(correct_queries_at_first_place)
queries = [q_id_to_entry[q_id]['query'] for q_id in ids]
truths = [q_id_to_entry[q_id]['relevant_docs'] for q_id in ids]

# Build the dataset once, outside the layer loop, and name the per-batch ids `batch_ids`.
# Rebuilding it per layer from an `ids` that the batch loop had rebound to the last batch's
# id tuple is what raised "IndexError: tuple index out of range" from the second layer on.
correct_dataset = QuestionsDataset(queries, truths, ids)
dl = DataLoader(correct_dataset, batch_size=batch_size, shuffle=False, collate_fn=correct_dataset.collate_fn)

ACTIVATION_DIR = "layer_activations"
os.makedirs(ACTIVATION_DIR, exist_ok=True)

for layer in range(18, 24):
    print(f"layer {layer}")
    pre_name = f"decoder.{layer}.mlp.hook_pre"
    post_name = f"decoder.{layer}.mlp.hook_post"
    out_name = f"decoder.{layer}.hook_mlp_out"
    hook_names = [pre_name, post_name, out_name]

    layer_query_to_activations = {}
    for input_texts, target_texts, batch_ids in dl:
        encoded = tokenizer(list(input_texts), return_tensors='pt', padding=True)
        input_tokens = encoded['input_ids'].to(device)
        # Without the mask the encoder also attends to padding, so a short query's
        # activations would depend on the longest query in its batch.
        attention_mask = encoded['attention_mask'].to(device)
        # One decoder token with id 0 (the start token used throughout) per query in the batch.
        decoder_input = torch.zeros((input_tokens.shape[0], 1), dtype=torch.long, device=device)
        _, cache = model.run_with_cache(
            input_tokens,
            decoder_input,
            one_zero_attention_mask=attention_mask,
            names_filter=hook_names,
        )
        for pre_act_row, post_act_row, out_act_row, q_id in zip(cache[pre_name], cache[post_name], cache[out_name], batch_ids):
            layer_query_to_activations[q_id] = {"pre" : pre_act_row.squeeze(0).cpu(), "post" : post_act_row.squeeze(0).cpu(), "out" : out_act_row.squeeze(0).cpu()}
        del cache
    # Written with torch.save (pickle) despite the ".json" suffix; the path is kept
    # unchanged for whatever loads these files.
    torch.save(layer_query_to_activations, f"{ACTIVATION_DIR}/layer_{layer}.json")
    layer_query_to_activations.clear()




layer 18




layer 19


IndexError: tuple index out of range

In [None]:
# saver cell
# Persist intermediate results with torch.save (pickle, despite the ".json" suffix).
# Most of these names are only created by the commented-out evaluation cell, and
# `correct_doc_ids_at_first_place` is not defined in any cell above, so missing names are
# reported and skipped instead of raising NameError on a fresh kernel.
_SAVE_TARGETS = [
    ("doc_id_to_valid_queries", "doc_id_to_valid_queries.json"),
    ("queries_to_correct_predicted_ids", "queries_to_correct_predicted_ids.json"),
    ("ids_with_more_than_threshold_correct_queries", "ids_with_more_than_threshold_correct_queries.json"),
    ("correct_queries_ids", "correct_queries_ids.json"),
    ("correct_queries_at_first_place", "correct_queries_at_first_place.json"),
    ("correct_doc_ids_at_first_place", "correct_doc_ids_at_first_place"),
]
for var_name, path in _SAVE_TARGETS:
    if var_name in globals():
        torch.save(globals()[var_name], path)
    else:
        print(f"skipping {var_name!r}: not defined in this session")


In [None]:
# cached = torch.load("cached_mlp_from_correct_queries.json")
# cached_with_query_id = {q_id : {key : cached_layer.squeeze(1)[index] for key,cached_layer in cached.items()} for index, q_id in enumerate(ids)}
# queries_predicted = {l[0] : l[2] for l in correct_queries}
# queries_predicted  # disabled: only defined by the commented-out line above, so evaluating it raises NameError on a fresh kernel
# TODO: what do we need to save for each doc id and query
# dict: doc-id -> num-of-valid-queries
# doc-id -> (activations, valid-queries)
# TODO: make sure all docs were indexed



In [None]:
# Run one example query through the model, caching every hook's activation.
query = "test query"

encoded_query = tokenizer(query, return_tensors='pt')
input_tokens = encoded_query['input_ids']
# A single decoder token with id 0, the same start token used in the batched cells.
decoder_input = torch.tensor([[0]])

# remove_batch_dim=True drops the size-1 batch axis from the cached activations only;
# the returned logits still carry it.
logits, cache = model.run_with_cache(input_tokens, decoder_input, remove_batch_dim=True)

In [None]:
# Raw decoder logits for the test query — presumably (batch=1, decoder_pos=1, d_vocab); confirm with logits.shape.
logits

In [None]:
# Greedy prediction: the highest-scoring token id at every position, plus the decoded
# token for the first query's first (and only) decoder position.
predicted_ids = logits.argmax(dim=-1)
predicted_ids, tokenizer.decode(predicted_ids[0][0])

## Examining the activations

The activations of every hook point in the transformer are stored in the `cache` object. It behaves like a dict keyed by hook name: index it with a name to get that component's activations.

Here, we print all hook names recorded for layer 0 of the decoder:

In [None]:
# Print the name of every hook point cached for decoder layer 0.
layer0_prefix = 'decoder.0.'
for hook_name in cache.keys():
    if not hook_name.startswith(layer0_prefix):
        continue
    print(hook_name)

We choose to look at the output of the MLP in layer 19 of the decoder:

In [None]:
# MLP output of decoder layer 19, together with its shape.
mlp_out_layer19 = cache['decoder.19.hook_mlp_out']
mlp_out_layer19, mlp_out_layer19.shape

See where the MLP hooks are computed: https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/components/mlps/mlp.py

- `hook_pre`: the MLP's hidden activations before the activation function is applied
- `hook_post`: the MLP's hidden activations after the activation function is applied