## Prepare Embeddings Before Training
This notebook pre-computes the embeddings of our queries and their expansion terms.
The result is saved as a pickled dictionary that maps each expanded query to a tensor of its embeddings.

The queries and their candidate expansion terms were generated with a fine-tuned BERT, so they are the same for every metric. Computing the embeddings from a single metric's file is therefore enough, and the resulting pickle can be reused for all metrics.
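To confirm that assumption before reusing the pickle, a minimal sketch like the following compares the `Q_name`/`candidate` pairs across the per-metric CSV files. The helper names `expansion_keys` and `share_expansions` are hypothetical, and only the precision file path appears in this notebook; the paths of the other metric files are not shown here.

```python
import csv
from typing import Iterable, Set


def expansion_keys(lines: Iterable[str]) -> Set[str]:
    """Collect the 'query candidate' keys the notebook uses as dictionary keys."""
    reader = csv.DictReader(lines)
    return {f"{row['Q_name']} {row['candidate']}" for row in reader}


def share_expansions(files: Iterable[Iterable[str]]) -> bool:
    """True if every file yields the same set of keys (vacuously true for 0 or 1 files)."""
    distinct = {frozenset(expansion_keys(f)) for f in files}
    return len(distinct) <= 1
```

Any iterable of CSV lines works as input, so open file handles for each metric's CSV can be passed directly. Comparing sets ignores row order and repeated rows, which matches how the embedding dictionary is keyed.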

In [1]:
import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, BertForMaskedLM

from utils import set_seed, save_pickle, get_query_and_candidates_embeddings

set_seed()

In [2]:
model_checkpoint = "./models/bert/Saved_model_epochs_5/"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model: BertForMaskedLM = BertForMaskedLM.from_pretrained(model_checkpoint)
# Keep only the BERT encoder; the masked-LM head is not needed to compute embeddings.
bert_model = model.bert
del model
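`BertForMaskedLM` wraps a plain `BertModel` encoder (`.bert`) and a masked-LM prediction head (`.cls`), so keeping `.bert` retains exactly the part that produces contextual token vectors. The sketch below shows this on a tiny, randomly initialised config, an assumption made only so it runs offline; the notebook loads fine-tuned weights instead. As far as we know, `BertForMaskedLM` builds its encoder without a pooling layer, so `pooler_output` is unavailable and embeddings must come from `last_hidden_state`. How `get_query_and_candidates_embeddings` pools those vectors is not shown in this notebook.

```python
import torch
from transformers import BertConfig, BertForMaskedLM, BertModel

# Tiny config so the sketch runs without downloading weights (illustrative values).
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
)
mlm_model = BertForMaskedLM(config)
# from_pretrained already returns a model in eval mode; a directly constructed
# one starts in train mode, so switch dropout off explicitly.
mlm_model.eval()

# .bert is the bare encoder; the masked-LM head lives in .cls and is dropped.
encoder = mlm_model.bert

with torch.no_grad():
    input_ids = torch.tensor([[2, 15, 27, 3]])  # arbitrary token ids < vocab_size
    outputs = encoder(input_ids=input_ids)

# One hidden_size-dimensional vector per input token.
print(outputs.last_hidden_state.shape)  # torch.Size([1, 4, 32])
```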

In [3]:
device = torch.device('cpu')
bert_model = bert_model.to(device)
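The notebook pins the encoder to the CPU, where the embedding loop took about a minute and a half. A common PyTorch pattern, sketched here as an alternative rather than what this notebook does, is to use a GPU when one is visible and to run inference under `torch.no_grad()`. The `torch.nn.Linear` layer is only a stand-in for `bert_model`, and whether `get_query_and_candidates_embeddings` already disables gradients is not visible here.

```python
import torch

# Use a GPU if PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

layer = torch.nn.Linear(4, 2).to(device)  # stand-in for bert_model
x = torch.ones(3, 4, device=device)

# no_grad skips building the autograd graph, saving memory when only
# forward-pass outputs (such as embeddings) are needed.
with torch.no_grad():
    out = layer(x)

# Bring results back to the CPU before pickling them.
out_cpu = out.cpu()
print(out.requires_grad)  # False
```

Moving results back with `.cpu()` before pickling is the safer choice, since a file holding CUDA tensors is awkward to load on a machine without a GPU.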


In [4]:
df = pd.read_csv('./data/precision/expanded_and_labeled_queries_with_pubmed_precision_n1000_200candidates.csv')
df.head()

| Unnamed: 0 | index | Q_name | candidate | new_q | base_reward | new_reward | label |
|---|---|---|---|---|---|---|---|
| 0 | 0 | coronavirus origin | asian | coronavirus origin asian | 0.007 | 0.019231 | 1 |
| 1 | 0 | coronavirus origin | natural | coronavirus origin natural | 0.007 | 0.131579 | 1 |
| 2 | 0 | coronavirus origin | disease | coronavirus origin disease | 0.007 | 0.086 | 1 |
| 3 | 0 | coronavirus origin | unknown | coronavirus origin unknown | 0.007 | 0.136612 | 1 |
| 4 | 0 | coronavirus origin | common | coronavirus origin common | 0.007 | 0.105263 | 1 |


In [5]:

query_to_torch = {}
for _, row in tqdm(df.iterrows(), total=len(df), desc="create embeddings"):
    query = row['Q_name']
    candidate = row['candidate']
    key = query + ' ' + candidate
    # Embed each (query, candidate) pair only once, even if it appears in several rows.
    if key not in query_to_torch:
        query_to_torch[key] = get_query_and_candidates_embeddings(
            [query], [candidate], tokenizer=tokenizer, bert_model=bert_model, device=device
        ).squeeze()


create embeddings: 100%|██████████| 17941/17941 [01:26<00:00, 208.60it/s]
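The loop over all 17,941 rows embeds one pair per call and relies on the `key not in query_to_torch` check to skip repeats. Because `get_query_and_candidates_embeddings` takes lists of queries and candidates, it may also accept several pairs at once, but that is an assumption this notebook does not confirm. The sketch below, with a hypothetical `embed_unique_pairs` helper and a caller-supplied `embed_batch` function, first collects the distinct keys and then embeds them in fixed-size batches.

```python
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple


def embed_unique_pairs(
    pairs: Iterable[Tuple[str, str]],
    embed_batch: Callable[[List[str], List[str]], Sequence[Any]],
    batch_size: int = 32,
) -> Dict[str, Any]:
    """Compute one embedding per distinct 'query candidate' key, in batches."""
    # Collect distinct keys first, keeping the first (query, candidate) seen for each.
    unique: Dict[str, Tuple[str, str]] = {}
    for query, candidate in pairs:
        unique.setdefault(f"{query} {candidate}", (query, candidate))

    keys = list(unique)
    result: Dict[str, Any] = {}
    for start in range(0, len(keys), batch_size):
        chunk = keys[start:start + batch_size]
        queries = [unique[k][0] for k in chunk]
        candidates = [unique[k][1] for k in chunk]
        # embed_batch must return one embedding per pair, in the same order.
        for key, emb in zip(chunk, embed_batch(queries, candidates)):
            result[key] = emb
    return result
```

In the notebook, `embed_batch` would be a hypothetical wrapper around the existing helper that splits its output per pair. Like the original loop, this keys on `query + ' ' + candidate`, so two different splits of the same text (for example `"a b" + "c"` and `"a" + "b c"`) share one entry.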


In [6]:
save_pickle(query_to_torch, "./data/expanded_queries_embeddings.pkl")
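Later notebooks can read the file back with the standard `pickle` module, assuming `save_pickle` in `utils` writes with `pickle` (its implementation is not shown here). The values are PyTorch tensors, so `torch` must be importable when loading. `load_embeddings` and `lookup` are hypothetical helper names.

```python
import pickle
from typing import Any, Dict


def load_embeddings(path: str) -> Dict[str, Any]:
    """Load the expanded-query -> embedding dictionary from a pickle file.

    Assumes the file was written with the standard pickle module.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


def lookup(embeddings: Dict[str, Any], query: str, candidate: str) -> Any:
    """Fetch one embedding using the notebook's key format: query + ' ' + candidate."""
    return embeddings[query + " " + candidate]
```

For example, `lookup(load_embeddings("./data/expanded_queries_embeddings.pkl"), "coronavirus origin", "asian")` should return the tensor stored for the expanded query `coronavirus origin asian`.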