# Data Preparation

This notebook prepares data for subsequent analysis and modeling.

**Inputs (loaded):**
- `TriviaQAData/train_queries_trivia_qa.json`
- `TriviaQAData/val_queries_trivia_qa.json`

**Outputs (saved):**
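- MLP activations (`hook_pre`, `hook_post`, `hook_mlp_out`) of decoder layers 17–23, saved one file per layer, for every query whose top-1 prediction is a relevant document (restricted to doc-ids with at least `THRESHOLD` correctly retrieved queries).
- Helper dictionaries that index queries and doc-ids by useful properties (e.g. correctly retrieved queries per doc-id, rank of each correct prediction), for use in later notebooks.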

**High-level steps:**
1. Prepare the query dataset (keeping doc-ids that have at least `THRESHOLD` queries).
2. Find the queries the model retrieves correctly.
3. Re-run the model on the correctly retrieved queries only, saving their MLP activations.


In [None]:
import json
import os

import numpy as np
import torch
import transformer_lens.utils as utils
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformer_lens import HookedEncoderDecoder
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer

In [None]:
# Inference only: disable autograd globally so no computation graphs are recorded.
torch.set_grad_enabled(mode=False)

## Loading the Model in TransformerLens

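Download the model checkpoint first and place it in a local directory named `DSI-large-TriviaQA`, so that `from_pretrained('DSI-large-TriviaQA')` below can find it: https://cloud.anja.re/s/qckH8GQPyN6YK8w#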

In [None]:
# Load the local DSI checkpoint into TransformerLens (HF weights are loaded first
# and handed to HookedEncoderDecoder).
checkpoint = 'DSI-large-TriviaQA'
# Register the custom name so TransformerLens' from_pretrained accepts it.
# Guarded so re-running this cell does not append duplicates.
if checkpoint not in OFFICIAL_MODEL_NAMES:
    OFFICIAL_MODEL_NAMES.append(checkpoint)
hf_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
device = utils.get_device()
model = HookedEncoderDecoder.from_pretrained(checkpoint, hf_model=hf_model, device=device)

# Token-id range of the tokens the checkpoint adds on top of the base t5-large
# vocabulary (presumably the doc-id tokens — confirm against the checkpoint).
# NOTE: last_added_doc_id equals len(tokenizer), i.e. it is an EXCLUSIVE upper
# bound (one past the last added token id), suitable for range(first, last).
tokenizer_t5 = AutoTokenizer.from_pretrained('google-t5/t5-large')
first_added_doc_id = len(tokenizer_t5)
last_added_doc_id = len(tokenizer)
del tokenizer_t5

## Data Prep

In [None]:
# Minimum number of distinct queries a doc-id needs in order to be kept; reused
# later when filtering doc-ids by correctly retrieved queries.
THRESHOLD = 4

# Pool both splits into one list of query entries.
with open('TriviaQAData/val_queries_trivia_qa.json', mode='r', encoding='utf-8') as f:
    val_data = json.load(f)
with open('TriviaQAData/train_queries_trivia_qa.json', mode='r', encoding='utf-8') as f:
    train_data = json.load(f)
data = val_data + train_data

# Entries are dicts with at least 'id', 'query' and 'relevant_docs' keys.
# If an id occurs in both splits, the later (train) entry wins here.
q_id_to_entry = {entry['id']: entry for entry in data}

# Group query entries under every doc-id they list as relevant.
doc_id_to_entries = {}
for entry in data:
    for doc_id in entry['relevant_docs']:
        doc_id_to_entries.setdefault(doc_id, []).append(entry)

# Keep doc-ids backed by at least THRESHOLD *distinct* queries, so an id present
# in both splits (or a doc listed twice by one query) is counted only once.
doc_id_to_entries = {
    doc_id: entries
    for doc_id, entries in doc_id_to_entries.items()
    if len({entry['id'] for entry in entries}) >= THRESHOLD
}
filtered_entries_ids = {entry['id'] for entries in doc_id_to_entries.values() for entry in entries}

# Iterate in sorted order (by string form, so both int and str ids work): set
# iteration order of string ids varies with PYTHONHASHSEED, which made the
# query order — and therefore batching — irreproducible between runs.
filtered_entries = [q_id_to_entry[q_id] for q_id in sorted(filtered_entries_ids, key=str)]
queries = [entry['query'] for entry in filtered_entries]
ground_truths = [entry['relevant_docs'] for entry in filtered_entries]
query_ids = [entry['id'] for entry in filtered_entries]

### Dataset Class

In [None]:
class QuestionsDataset(Dataset):
    """Map-style dataset over three parallel sequences: query texts, targets, and query ids.

    Item ``i`` is the tuple ``(inputs[i], targets[i], ids[i])``. ``collate_fn``
    keeps the raw Python objects (no tensor conversion), so batches can be fed
    straight to a tokenizer.
    """

    def __init__(self, inputs, targets, ids):
        # Index i of each sequence describes the same query.
        self.inputs, self.targets, self.ids = inputs, targets, ids

    def __len__(self):
        # The dataset size is taken from ``inputs``; the other sequences are assumed to match.
        return len(self.inputs)

    def __getitem__(self, idx):
        text = self.inputs[idx]
        target = self.targets[idx]
        query_id = self.ids[idx]
        return text, target, query_id

    def collate_fn(self, batch):
        """Transpose a list of ``(text, target, id)`` items into three parallel tuples."""
        transposed = zip(*batch)
        texts, targets, query_ids = transposed
        return texts, targets, query_ids

In [None]:
batch_size = 32
# Fixed (unshuffled) order; the dataset's collate_fn yields raw texts/targets/ids
# rather than tensors.
dataset = QuestionsDataset(inputs=queries, targets=ground_truths, ids=query_ids)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=dataset.collate_fn,
)

### Finding Correct Queries

In [None]:
# Number of top-ranked first-step predictions inspected per query.
K = 5


def _doc_ids_from_token(token_id):
    """Return the digit-only pieces of one decoded token.

    Same parsing as before ('@' and '_' treated as separators, keep pieces that
    are all digits), but applied per token so each piece keeps its query and rank.
    Assumes doc-id tokens decode to digits delimited by '@'/'_' — TODO confirm.
    """
    text = tokenizer.decode([token_id])
    return [piece for piece in text.replace('@', '_').split(sep='_') if piece.isdigit()]


doc_id_to_valid_queries = {}          # doc-id -> query ids that retrieved it in the top-K
queries_to_correct_predicted_ids = {}  # query id -> {rank (1-based): correct doc-id}
for input_texts, target_texts, ids in data_loader:
    encoded = tokenizer(list(input_texts), return_tensors='pt', padding=True)
    input_tokens = encoded['input_ids'].to(device)
    # Pass the padding mask so shorter queries do not attend to pad tokens;
    # without it a query's prediction depended on the other queries in its batch.
    attention_mask = encoded['attention_mask'].to(device)
    # One decoder start token (id 0) per query -> logits for the first decoded token.
    decoder_input = torch.zeros((len(input_texts), 1), dtype=torch.long, device=device)
    logits = model.forward(input_tokens, decoder_input, one_zero_attention_mask=attention_mask)
    # (batch, K) token ids, best first, for the last (only) decoder position.
    top_k_token_ids = torch.topk(logits[:, -1, :], K, dim=-1).indices.tolist()
    # Walk predictions per query instead of splitting one flattened list into
    # `batch_size` chunks, which misaligned queries in the final, smaller batch
    # and whenever a top-K token did not decode to a doc-id.
    for q_id, relevant_docs, ranked_token_ids in zip(ids, target_texts, top_k_token_ids):
        for rank, token_id in enumerate(ranked_token_ids, start=1):
            for doc_id in _doc_ids_from_token(token_id):
                if doc_id in relevant_docs:
                    doc_id_to_valid_queries.setdefault(doc_id, []).append(q_id)
                    queries_to_correct_predicted_ids.setdefault(q_id, {})[rank] = doc_id

# NOTE: despite its name this keeps doc-ids with AT LEAST (>=) THRESHOLD correctly
# retrieved queries; the name is kept because it is saved under that file name.
ids_with_more_than_threshold_correct_queries = {
    doc_id: q_ids
    for doc_id, q_ids in doc_id_to_valid_queries.items()
    if len(q_ids) >= THRESHOLD
}
correct_queries_ids = {
    q_id
    for q_ids in ids_with_more_than_threshold_correct_queries.values()
    for q_id in q_ids
}


In [None]:
# Persist the intermediate lookups for later notebooks.
# NOTE(review): torch.save writes PyTorch's pickle-based format, not JSON, despite
# the '.json' names — read these back with torch.load. Names kept for consumers.
for file_name, obj in (
    ('doc_id_to_valid_queries.json', doc_id_to_valid_queries),
    ('queries_to_correct_predicted_ids.json', queries_to_correct_predicted_ids),
    ('ids_with_more_than_threshold_correct_queries.json', ids_with_more_than_threshold_correct_queries),
    ('correct_queries_ids.json', correct_queries_ids),
):
    torch.save(obj, file_name)

In [None]:
# Queries whose rank-1 prediction is a relevant doc, limited to queries that
# belong to a doc-id passing the THRESHOLD filter of correct retrievals.
correct_queries_at_first_place = {
    q_id
    for q_id, rank_to_doc in queries_to_correct_predicted_ids.items()
    if q_id in correct_queries_ids and 1 in rank_to_doc
}
print(len(correct_queries_at_first_place))
torch.save(correct_queries_at_first_place, 'correct_queries_at_first_place.json')

### Generating Activation Dictionaries by Layer

In [None]:
# Decoder layers whose MLP activations are saved (range end is exclusive: 17..23).
LAYERS = range(17, 24)
ACTIVATIONS_DIR = 'layer_activations_no_test'
# torch.save does not create missing directories.
os.makedirs(ACTIVATIONS_DIR, exist_ok=True)

# Use the device chosen at load time instead of hard-coding CUDA.
model.to(device)

# The query set is identical for every layer, so build it once. Sorted (by string
# form) for reproducible batching; distinct names avoid overwriting the `queries`
# list from the data-prep cell.
correct_ids = sorted(correct_queries_at_first_place, key=str)
correct_dataset = QuestionsDataset(
    [q_id_to_entry[q_id]['query'] for q_id in correct_ids],
    [q_id_to_entry[q_id]['relevant_docs'] for q_id in correct_ids],
    correct_ids,
)
correct_dl = DataLoader(
    correct_dataset, batch_size=batch_size, shuffle=False, collate_fn=correct_dataset.collate_fn
)

for layer in LAYERS:
    print(f'layer {layer}')
    pre_name = f'decoder.{layer}.mlp.hook_pre'
    post_name = f'decoder.{layer}.mlp.hook_post'
    out_name = f'decoder.{layer}.hook_mlp_out'
    layer_query_to_activations = {}
    for input_texts, target_texts, ids in correct_dl:
        encoded = tokenizer(list(input_texts), return_tensors='pt', padding=True)
        input_tokens = encoded['input_ids'].to(device)
        # Padding mask so pad tokens do not leak into the activations of shorter queries.
        attention_mask = encoded['attention_mask'].to(device)
        # One decoder start token (id 0) per query.
        decoder_input = torch.zeros((len(input_texts), 1), dtype=torch.long, device=device)
        _, cache = model.run_with_cache(
            input_tokens,
            decoder_input,
            one_zero_attention_mask=attention_mask,
            names_filter=[pre_name, post_name, out_name],
        )
        for pre_act_row, post_act_row, out_act_row, q_id in zip(cache[pre_name], cache[post_name], cache[out_name], ids):
            # Each row has a leading decoder-position axis of size 1 (one decoder
            # token), squeezed away. copy=True yields a compact standalone tensor
            # even when already on CPU, so torch.save does not serialize the whole
            # batch storage behind each view.
            layer_query_to_activations[q_id] = {
                'pre': pre_act_row.squeeze(0).to('cpu', copy=True),
                'post': post_act_row.squeeze(0).to('cpu', copy=True),
                'out': out_act_row.squeeze(0).to('cpu', copy=True),
            }
        del cache
    # NOTE: torch.save format (pickle-based), despite the '.json' suffix.
    torch.save(layer_query_to_activations, os.path.join(ACTIVATIONS_DIR, f'layer_{layer}.json'))
    layer_query_to_activations.clear()