# Install and Load Packages

In [5]:
!pip install transformers
!pip install datasets
!pip install faiss-gpu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [6]:
import json

import torch
import torch.nn as nn
import tqdm
import tqdm.notebook  # `import tqdm` alone does not load the notebook submodule used below
import faiss
import faiss.contrib.torch_utils  # lets faiss indexes be searched with torch tensors
import pandas as pd

from transformers import DPRQuestionEncoderTokenizerFast, DPRQuestionEncoder
from datasets import load_dataset, load_from_disk, Dataset
from google.colab import auth, drive
from google.cloud import bigquery

In [7]:
# Authenticate this Colab session with Google Cloud (presumably so the BigQuery
# client created below can run queries) and mount Google Drive.
auth.authenticate_user()
print('Authenticated')

drive.mount('/content/drive')
# Drive directory holding the faiss passage index and its idx -> doc_id JSON map.
data_path = '/content/drive/MyDrive/nlp/data/wiki_nq_train_passage_encodings/'

Authenticated
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Load Data

In [8]:
# Map from faiss result id to passage doc_id. JSON object keys are always strings,
# so lookups must use str(faiss_id).
with open(data_path + 'easy_wiki_nq_train_3_no_qg_greedy_passage_idx2docid.json', encoding='utf-8') as f:
    idx2docid = json.load(f)

In [9]:
# Number of passage ids covered by the map (presumably equal to the index size; TODO confirm).
len(idx2docid)

276634

In [10]:
# Load the prebuilt passage index from Drive, then copy it onto GPU 0 so it can be
# searched with the CUDA query embeddings produced by the question encoder.
# NOTE(review): consider asserting passage_index.ntotal against len(idx2docid) so a
# mismatched map/index pair fails here rather than mid-evaluation.
passage_index = faiss.read_index(data_path + 'easy_wiki_nq_train_3_no_qg_greedy_passage_index')
# Keep `res` referenced for as long as gpu_index is used; the GPU copy is built on it.
res = faiss.StandardGpuResources()
gpu_index = faiss.index_cpu_to_gpu(res, 0, passage_index)

In [11]:
# GCP project used for BigQuery jobs; the tables queried below also live in it.
project_id = 'calcium-vial-368801'
client = bigquery.Client(project=project_id)

In [12]:
# Train queries with their gold passage doc_id (one row per distinct combination of the
# selected columns).
# NOTE(review): only query_id and doc_id are used in this notebook; selecting
# title/text/questions pulls extra data from BigQuery — trim if nothing else needs them.
dt_train_queries = client.query('''
SELECT DISTINCT query_id, doc_id, title, text, questions
FROM `calcium-vial-368801.staging.nq_train_documents_3_qg_25_beam`
''').to_dataframe()

In [13]:
# query_id -> {'doc_id': gold doc_id}. orient='index' needs a unique index, so pandas
# raises if any query_id maps to more than one doc_id.
query2docid = (
    dt_train_queries
    .loc[:, ['query_id', 'doc_id']]
    .drop_duplicates()
    .set_index('query_id')
    .to_dict(orient='index')
)

In [14]:
# Get golden passages
query_text = client.query('''
SELECT DISTINCT query_id, text
FROM `calcium-vial-368801.beir_nq_train.train_query_lookup`
''').to_dataframe()

In [15]:
# Keep only the queries that have a gold doc_id in query2docid.
query_text_filter = query_text.loc[query_text['query_id'].isin(list(query2docid))]

In [16]:
# query_id -> {'text': question string}; this is the input consumed by MyDataset.
query2text = (
    query_text_filter
    .set_index('query_id')
    .to_dict(orient='index')
)

In [17]:
# Fast tokenizer for the DPR question encoder; keep the checkpoint in sync with the
# one passed to DPRQuestionEncoder.from_pretrained for `query_model`.
question_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/493 [00:00<?, ?B/s]

In [18]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(query_id, tokenized_question)`` pairs.

    Subclasses ``torch.utils.data.Dataset`` explicitly. In this notebook the bare name
    ``Dataset`` is Hugging Face ``datasets.Dataset``. That class defines a batched
    ``__getitems__`` hook, which newer ``DataLoader`` versions can pick up, and it would
    call this class's ``__getitem__`` with a list of indices.

    Args:
        query_dict: mapping ``query_id -> {'text': question_string, ...}``.
        q_tokenizer: tokenizer called on each question string.
        max_length: pad/truncate every question to this many tokens (default 20, the
            value previously hard-coded).
    """

    def __init__(self, query_dict, q_tokenizer, max_length=20):
        self.query_dict = query_dict
        self.q_tokenizer = q_tokenizer
        self.max_length = max_length
        # Fix an integer index -> query_id order (dict insertion order).
        self.ids = list(query_dict)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        """Return ``(query_id, encoding)`` for the query at position ``index``."""
        q_id = self.ids[index]
        encoding = self.q_tokenizer(
            self.query_dict[q_id]['text'],
            return_tensors='pt',
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
        )
        return q_id, encoding


def collate_fn(batch):
    """Collate ``MyDataset`` samples into ``(query_ids, token_tensors)``.

    Args:
        batch: list of ``(query_id, encoding)`` pairs. ``encoding`` maps 'input_ids',
            'attention_mask' and 'token_type_ids' to integer tensors of shape
            (1, num_query_tokens), as a tokenizer returns for one string with
            return_tensors='pt'.

    Returns:
        tuple: ``(query_ids, token_tensors)``.
            - ``query_ids`` is the list of ids in batch order.
            - ``token_tensors`` has shape (batch_size, num_query_tokens, 3).
            - The last axis of ``token_tensors`` holds input_ids, attention_mask and
              token_type_ids, in that order.
    """
    fields = ('input_ids', 'attention_mask', 'token_type_ids')
    # vstack -> (batch_size, num_query_tokens) per field; stack them on a new last axis.
    token_tensors = torch.stack(
        [torch.vstack([encoding[name] for _, encoding in batch]) for name in fields],
        dim=2,
    )
    query_ids = [query_id for query_id, _ in batch]
    return query_ids, token_tensors

BATCH_SIZE = 64

# Iterate queries in their original order (no shuffling) so evaluation is deterministic;
# collate_fn packs each batch into (query_ids, batch x tokens x 3 token tensor).
dataloader_train = torch.utils.data.DataLoader(
    dataset=MyDataset(query2text, question_tokenizer),
    collate_fn=collate_fn,
    shuffle=False,
    batch_size=BATCH_SIZE,
)

In [19]:
# DPR question encoder (same checkpoint as the tokenizer), moved to the GPU so its
# pooled outputs can be passed straight to the GPU faiss index.
query_model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base').to("cuda")

Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]

In [20]:
class QueryEncoder(nn.Module):
    """Encode tokenized questions and retrieve the top-k passage ids from an index.

    Args:
        q_encoder: question encoder (a DPRQuestionEncoder in this notebook). It is
            called with input_ids / attention_mask / token_type_ids and must return an
            object with a ``pooler_output`` tensor.
        passage_index: object with ``search(x, k) -> (distances, ids)``. Here it is the
            GPU faiss index, made torch-aware by ``faiss.contrib.torch_utils``.
        k: number of passages to retrieve per query.
    """

    def __init__(self, q_encoder, passage_index, k):
        super().__init__()
        self.q_encoder = q_encoder
        self.passage_index = passage_index
        self.k = k

    def forward(self, queries):
        """Return the ids of the top-k passages for each query.

        Args:
            queries: (batch_size, num_query_tokens, 3) integer tensor. Its last axis
                holds input_ids, attention_mask and token_type_ids, in that order.

        Returns:
            (batch_size, k) tensor of passage ids, as returned by ``passage_index.search``.
        """
        # search() returns integer ids, so no gradient can flow out of this method;
        # building an autograd graph for the encoder pass would only waste memory.
        with torch.no_grad():
            encoded = self.q_encoder(
                input_ids=queries[:, :, 0],
                attention_mask=queries[:, :, 1],
                token_type_ids=queries[:, :, 2],
            )
            # faiss reads the raw buffer, so hand it a contiguous query matrix.
            _distances, passage_idx = self.passage_index.search(
                encoded.pooler_output.contiguous(), self.k
            )
        return passage_idx

In [21]:
def check_retrieval(query_id, top_k_ids):
    """Return True if the gold passage of ``query_id`` is among the retrieved passages.

    Args:
        query_id: key into the module-level ``query2docid`` map
            (``query_id -> {'doc_id': gold_doc_id}``).
        top_k_ids: iterable of integer passage ids returned by the faiss search.

    Returns:
        bool: whether any retrieved id maps (via the module-level ``idx2docid``,
        whose keys are strings) to the gold doc_id.
    """
    true_passage = query2docid[query_id]['doc_id']
    # faiss pads result rows with -1 when it finds fewer than k neighbours; those are
    # not passages and have no entry in idx2docid.
    return any(
        idx2docid[str(passage_id)] == true_passage
        for passage_id in top_k_ids
        if passage_id != -1
    )

In [22]:
qEncoder = QueryEncoder(q_encoder=query_model, passage_index=gpu_index, k=10)

successes = 0
total = 0

# Pure inference: disable autograd so no graph is kept for the encoder pass.
with torch.no_grad():
    for ids, tokens in tqdm.notebook.tqdm(dataloader_train, total=len(dataloader_train)):
        top_k_passages = qEncoder(tokens.to("cuda"))
        # One device-to-host copy per batch instead of one per query row.
        for query_id, passage_ids in zip(ids, top_k_passages.tolist()):
            successes += check_retrieval(query_id, passage_ids)
            total += 1


  0%|          | 0/417 [00:00<?, ?it/s]

In [23]:
# Top-k retrieval accuracy (%): share of queries whose gold doc_id is among the k
# passages retrieved in the loop above. k is read from the encoder so the label
# cannot drift from the value actually used.
print(f"k = {qEncoder.k}\nAccuracy {successes / total * 100 if total else float('nan')}")

k = 10
Accuracy 76.73274761582938


In [24]:
qEncoder = QueryEncoder(q_encoder=query_model, passage_index=gpu_index, k=20)

successes = 0
total = 0

# Pure inference: disable autograd so no graph is kept for the encoder pass.
with torch.no_grad():
    for ids, tokens in tqdm.notebook.tqdm(dataloader_train, total=len(dataloader_train)):
        top_k_passages = qEncoder(tokens.to("cuda"))
        # One device-to-host copy per batch instead of one per query row.
        for query_id, passage_ids in zip(ids, top_k_passages.tolist()):
            successes += check_retrieval(query_id, passage_ids)
            total += 1

  0%|          | 0/417 [00:00<?, ?it/s]

In [25]:
# Top-k retrieval accuracy (%): share of queries whose gold doc_id is among the k
# passages retrieved in the loop above. k is read from the encoder so the label
# cannot drift from the value actually used.
print(f"k = {qEncoder.k}\nAccuracy {successes / total * 100 if total else float('nan')}")

k = 20
Accuracy 81.90658556731996


In [26]:
qEncoder = QueryEncoder(q_encoder=query_model, passage_index=gpu_index, k=50)

successes = 0
total = 0

# Pure inference: disable autograd so no graph is kept for the encoder pass.
with torch.no_grad():
    for ids, tokens in tqdm.notebook.tqdm(dataloader_train, total=len(dataloader_train)):
        top_k_passages = qEncoder(tokens.to("cuda"))
        # One device-to-host copy per batch instead of one per query row.
        for query_id, passage_ids in zip(ids, top_k_passages.tolist()):
            successes += check_retrieval(query_id, passage_ids)
            total += 1

  0%|          | 0/417 [00:00<?, ?it/s]

In [27]:
# Top-k retrieval accuracy (%): share of queries whose gold doc_id is among the k
# passages retrieved in the loop above. k is read from the encoder so the label
# cannot drift from the value actually used.
print(f"k = {qEncoder.k}\nAccuracy {successes / total * 100 if total else float('nan')}")

k = 50
Accuracy 86.9940677329729


In [28]:
qEncoder = QueryEncoder(q_encoder=query_model, passage_index=gpu_index, k=100)

successes = 0
total = 0

# Pure inference: disable autograd so no graph is kept for the encoder pass.
with torch.no_grad():
    for ids, tokens in tqdm.notebook.tqdm(dataloader_train, total=len(dataloader_train)):
        top_k_passages = qEncoder(tokens.to("cuda"))
        # One device-to-host copy per batch instead of one per query row.
        for query_id, passage_ids in zip(ids, top_k_passages.tolist()):
            successes += check_retrieval(query_id, passage_ids)
            total += 1

  0%|          | 0/417 [00:00<?, ?it/s]

In [29]:
# Top-k retrieval accuracy (%): share of queries whose gold doc_id is among the k
# passages retrieved in the loop above. k is read from the encoder so the label
# cannot drift from the value actually used.
print(f"k = {qEncoder.k}\nAccuracy {successes / total * 100 if total else float('nan')}")

k = 100
Accuracy 89.82503566869416
