In [1]:
#%pip install transformer_lens

In [2]:
from transformer_lens import HookedEncoderDecoder
import transformer_lens.utils as utils
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES

import torch

torch.set_grad_enabled(False)




<torch.autograd.grad_mode.set_grad_enabled at 0x280e2362090>

## Loading the Model in TransformerLens

Please download the model first: https://cloud.anja.re/s/Qpo8CZ6yRzDH7ZF

In [3]:
# !wget "https://cloud.anja.re/s/qckH8GQPyN6YK8w/download?path=%2F&files=DSI-large-TriviaQA.zip"
# !unzip "download?path=%2F&files=DSI-large-TriviaQA.zip"
checkpoint = "DSI-large-TriviaQA"

OFFICIAL_MODEL_NAMES.append(checkpoint)

hf_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
device = utils.get_device()
model = HookedEncoderDecoder.from_pretrained(checkpoint, hf_model=hf_model, device=device)

tokenizer_t5 = AutoTokenizer.from_pretrained('google-t5/t5-large')


# Our model has a new token for each document id that we trained it on.

# token id of first document that was added
first_added_doc_id = len(tokenizer_t5)
# token id of the last document that was added
last_added_doc_id = len(tokenizer_t5) + (len(tokenizer) - len(tokenizer_t5))
del tokenizer_t5


If using T5 for interpretability research, keep in mind that T5 has some significant architectural differences to GPT. The major one is that T5 is an Encoder-Decoder modelAlso, it uses relative positional embeddings, different types of Attention (without bias) and LayerNorm


Loaded pretrained model DSI-large-TriviaQA into HookedTransformer


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

## Running a sample query

In [4]:
#wget "https://cloud.anja.re/s/qckH8GQPyN6YK8w/download?path=%2FGenIR-Data&files=TriviaQAData.zip"
#unzip "download?path=%2FGenIR-Data&files=TriviaQAData.zip"

import json
from torch.utils.data import Dataset, DataLoader
with open("TriviaQAData/test_queries_trivia_qa.json", mode='r') as f:
  test_data = json.load(f)

class QuestionsDataset(Dataset):
    def __init__(self, inputs, targets, ids):
        self.inputs = inputs
        self.targets = targets
        self.ids = ids

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx], self.ids[idx]

    def collate_fn(self, batch):
        input_texts, target_texts, ids = zip(*batch)
        return input_texts, target_texts, ids


queries = [entry['query'] for entry in test_data]
ground_truths = [entry['relevant_docs'] for entry in test_data]
query_ids = [entry['id'] for entry in test_data]
dataset = QuestionsDataset(queries, ground_truths, query_ids)
data_loader = DataLoader(dataset, batch_size=16, shuffle = False, collate_fn=dataset.collate_fn)

In [5]:
correct_queries = []
decoder_input = torch.tensor([[0]])
for input_texts, target_texts, ids in data_loader:
  input_tokens = tokenizer(input_texts, return_tensors='pt', padding=True)['input_ids']
  logits = model.forward(input_tokens, decoder_input)
  # print(torch.argmax(logits, dim=-1).squeeze(-1))
  res = tokenizer.decode(torch.argmax(logits, dim=-1).squeeze(-1))
  doc_ids = [s for s in res.replace('@','_').split(sep='_') if s.isdigit()]
  correct_queries += [(id, query, predicted, truth) for id,query,predicted,truth in zip(ids, input_texts, doc_ids, target_texts) if predicted in truth]


# for entry in training_data:
#   id, query, relevant_docs = entry
#   input_tokens = tokenizer(query, return_tensors='pt')['input_ids']
#   decoder_input = torch.tensor([[0]])

#   logits, cache = model.run_with_cache(input_tokens, decoder_input, remove_batch_dim=True)
#   res = tokenizer.decode(torch.argmax(logits, dim=-1)[0][0])
#   if res in relevant_docs:
#     correct_queries.append(entry)



In [17]:
ids = [entry[0] for entry in correct_queries]
queries = [entry[1] for entry in correct_queries]
truths = [entry[3] for entry in correct_queries]

correct_dataset = QuestionsDataset(queries, truths, ids)
dl = DataLoader(correct_dataset, batch_size=16, shuffle = False, collate_fn=dataset.collate_fn)

decoder_input = torch.tensor([[0]])
cached_mlps = {}
for input_texts, target_texts, ids in dl:
  input_tokens = tokenizer(input_texts, return_tensors='pt', padding=True)['input_ids']
  _, cache = model.run_with_cache(input_tokens, decoder_input)
  for layer in range(18, 24):
      cached_mlps[f"layer_{layer}"] = torch.cat((cached_mlps.setdefault(f"layer_{layer}", torch.Tensor()), cache[f"decoder.{layer}.mlp.hook_post"]), dim = 0)

torch.save(cached_mlps,"cached_mlp_from_correct_queries.json")




In [42]:
# cached = torch.load("cached_mlp_from_correct_queries.json")
# cached_with_query_id = {q_id : {key : cached_layer.squeeze(1)[index] for key,cached_layer in cached.items()} for index, q_id in enumerate(ids)} 
# queries_predicted = {l[0] : l[2] for l in correct_queries}
queries_predicted
# TODO: what do we need to save for each doc id and query
# dict: doc-id -> num-of-valid-queries
# doc-id -> (activations, valid-queries)
# TODO: make sure all docs were indexed



{'QTest0': '21871',
 'QTest2': '70062',
 'QTest4': '52288',
 'QTest5': '38019',
 'QTest6': '8466',
 'QTest7': '73330',
 'QTest8': '9181',
 'QTest9': '70053',
 'QTest10': '52421',
 'QTest11': '13600',
 'QTest12': '60198',
 'QTest14': '46780',
 'QTest15': '50855',
 'QTest17': '68189',
 'QTest18': '11358',
 'QTest20': '42034',
 'QTest21': '66108',
 'QTest22': '4579',
 'QTest23': '38155',
 'QTest24': '59353',
 'QTest25': '36612',
 'QTest27': '52886',
 'QTest28': '52588',
 'QTest29': '34722',
 'QTest30': '23458',
 'QTest31': '7944',
 'QTest32': '23003',
 'QTest34': '11482',
 'QTest35': '56292',
 'QTest38': '45635',
 'QTest39': '29055',
 'QTest41': '57923',
 'QTest42': '59708',
 'QTest45': '39093',
 'QTest46': '45883',
 'QTest47': '59708',
 'QTest48': '51159',
 'QTest49': '21815',
 'QTest50': '56658',
 'QTest51': '55325',
 'QTest52': '3839',
 'QTest55': '4181',
 'QTest56': '29055',
 'QTest57': '34767',
 'QTest58': '67193',
 'QTest59': '59878',
 'QTest60': '18097',
 'QTest61': '68811',
 'QTes

In [None]:
query = "test query"

input_tokens = tokenizer(query, return_tensors='pt')['input_ids']
decoder_input = torch.tensor([[0]])

logits, cache = model.run_with_cache(input_tokens, decoder_input, remove_batch_dim=True)

In [None]:
logits

In [None]:
# Prediction from the logits
torch.argmax(logits, dim=-1), tokenizer.decode(torch.argmax(logits, dim=-1)[0][0])

## Examining the activations

The activations of each component in the transformer are stored in the `cache` object. It's basically a dict from which you choose which component to look at.

Here, we print all possible component keys for layer 0 in the decoder:

In [None]:
for key in cache.keys():
  if key.startswith('decoder.0.'):
    print(key)

We choose to look at the output of the MLP in layer 19 of the decoder:

In [None]:
cache['decoder.19.hook_mlp_out'], cache['decoder.19.hook_mlp_out'].shape

Take a look at where the MLP hooks are computed: https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/components/mlps/mlp.py

`hook_pre`: Before activation,
`hook_post`: After applying activation