# Install and Load Packages

In [1]:
!pip install transformers
!pip install datasets
!pip install faiss-gpu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 4.1 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 16.6 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 82.1 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 transformers-4.25.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.7.1-py3-none-any.whl (451 kB)
[K     |████████████████████████████████| 451 kB 4.

In [2]:
import gc
import json

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import tqdm
import tqdm.notebook
import faiss
import faiss.contrib.torch_utils

from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast
from datasets import Dataset
from google.colab import auth, drive
from google.cloud import bigquery

In [3]:
# Authenticate this Colab session with Google Cloud (needed for BigQuery below).
auth.authenticate_user()
print('Authenticated')

# Mount Google Drive; the FAISS indexes and idx2docid maps are read from and
# written to folders under data_path.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)
data_path = drive_mount_point + '/MyDrive/nlp/data/'

Authenticated
Mounted at /content/drive


# Load Data

In [4]:
# Google Cloud project that owns the BigQuery datasets queried below.
project_id = 'calcium-vial-368801'
client = bigquery.Client(project_id)  # first positional parameter is `project`

In [5]:
# Sampled Wikipedia passages (title, text and generated `questions`) that will be
# appended to the NQ passage indexes. The table lives in the same GCP project as
# `client`, so reuse project_id instead of repeating the literal id.
wiki_passages = client.query(f'''
SELECT doc_id, title, text, questions
FROM `{project_id}.prod_datasets.test_corpus_3_rand_sample_reduced`
''').to_dataframe()

In [6]:
# Passages without generated questions have NULL `questions`. Replace them with ""
# so the concatenation below does not turn the whole passage into NaN.
# Assign the result back: a chained `.questions.fillna(..., inplace=True)` is not
# guaranteed to write through to the DataFrame (pandas Copy-on-Write).
wiki_passages['questions'] = wiki_passages['questions'].fillna('')

# Text fed to the DPR context encoder: title, passage body and generated
# questions, separated by the literal "[SEP]" marker.
wiki_passages['passage_append'] = (
    wiki_passages['title'] + ' [SEP] '
    + wiki_passages['text'] + ' [SEP] '
    + wiki_passages['questions']
)

In [7]:
# Drop exact duplicate rows, then renumber rows 0..n-1. drop_duplicates leaves
# gaps in the index, and later label-based lookups (wiki_passages.doc_id[k]) must
# agree with the positional order in which the DataLoader (shuffle=False) appends
# embeddings to FAISS.
wiki_passages = wiki_passages.drop_duplicates().reset_index(drop=True)

# Data Loader

In [8]:
# Tokenizer for the passage (context) side of DPR. Loading it prints a
# class-mismatch warning because the checkpoint names DPRQuestionEncoderTokenizer;
# presumably both DPR tokenizers share the same BERT vocabulary. Verify if the
# tokenization looks off.
ctx_checkpoint = "facebook/dpr-ctx_encoder-single-nq-base"
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(ctx_checkpoint)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/492 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizerFast'.


In [9]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over pre-tokenized passages.

    Every string in ``dataframe['passage_append']`` is tokenized once, up front,
    and padded/truncated to ``max_length`` tokens. Items come back in DataFrame
    row order, so a non-shuffling DataLoader visits passages in that order.

    This subclasses ``torch.utils.data.Dataset``. It used to inherit the
    HuggingFace ``datasets.Dataset`` by accident: that class is an Arrow-table
    wrapper whose ``__init__`` was never called here. Newer releases also add a
    ``__getitems__`` batch hook, which recent PyTorch DataLoaders call with a
    *list* of indices that ``__getitem__`` below cannot handle (verify against
    the installed versions).
    """

    def __init__(self, dataframe, p_tokenizer, max_length=512):
        """
        Args:
            dataframe: DataFrame with a ``passage_append`` text column.
            p_tokenizer: tokenizer callable (here a fast DPR tokenizer), called
                once on the whole column.
            max_length: pad/truncate length in tokens (default 512, as before).
        """
        self.dataframe = dataframe
        self.p_tokenizer = p_tokenizer

        self.p_embed = p_tokenizer(
            self.dataframe['passage_append'].tolist(),
            return_tensors='pt',
            truncation=True,
            max_length=max_length,
            padding='max_length',
        )

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        # Integer-indexing the tokenizer output yields one passage's encoding.
        # With a fast tokenizer this is an object exposing .ids, .attention_mask
        # and .type_ids, which collate_fn reads.
        return self.p_embed[index]


def collate_fn(batch):
    """Stack per-passage tokenizer encodings into one int64 tensor.

    Args:
        batch: sequence of encodings exposing ``ids``, ``attention_mask`` and
            ``type_ids`` lists of equal length (guaranteed here by
            ``padding='max_length'`` in MyDataset).

    Returns:
        LongTensor of shape ``(len(batch), 3, seq_len)``. Channel 0 holds the
        input ids, channel 1 the attention mask and channel 2 the token type
        ids, which is the layout the passage-encoder cell slices apart.
    """
    # torch.tensor(..., dtype=torch.long) replaces the legacy torch.LongTensor(list)
    # constructor; the result is the same int64 tensor.
    return torch.tensor(
        [[sample.ids, sample.attention_mask, sample.type_ids] for sample in batch],
        dtype=torch.long,
    )

BATCH_SIZE = 20

# shuffle=False is load-bearing: passages are appended to the FAISS indexes in
# this order, and the idx2docid maps built later assume FAISS row order matches
# DataFrame row order.
passage_dataset = MyDataset(wiki_passages, ctx_tokenizer)
dataloader_train = torch.utils.data.DataLoader(
    passage_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

In [10]:
import gc

# Drop the notebook's handle on the tokenizer. Note: the MyDataset instance inside
# dataloader_train still holds it as `p_tokenizer`, so this does not actually
# free the tokenizer object.
del ctx_tokenizer

gc.collect()
# Return PyTorch's cached-but-unused CUDA blocks to the driver.
torch.cuda.empty_cache()

# Model

In [11]:
# Pretrained DPR context (passage) encoder, moved to the GPU for batch encoding.
# Fine-tuning discussion: https://discuss.huggingface.co/t/finetuning-dpr-on-custom-dataset/4170
ctx_model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
ctx_model = ctx_model.to("cuda")

Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]

In [12]:
# Collect unreferenced Python objects, then hand PyTorch's cached-but-unused CUDA
# blocks back to the driver before the FAISS indexes are copied onto the GPU.
gc.collect()
torch.cuda.empty_cache()

In [13]:
# One FAISS GPU resource object; the next cell reuses `res` for the QG-25 index.
res = faiss.StandardGpuResources()

# Existing NQ passage index (no generated questions). index_cpu_to_gpu makes a GPU
# copy on device 0; the CPU original stays in no_qg_index.
no_qg_index_path = (
    data_path + 'nq_train_passage_encodings/nq_train_3_no_qg_beam_passage_index'
)
no_qg_index = faiss.read_index(no_qg_index_path)
gpu_index_no_qg = faiss.index_cpu_to_gpu(res, 0, no_qg_index)

In [14]:
# Existing NQ passage index built with 25 generated questions per passage,
# copied onto GPU 0 using the shared `res`.
qg_25_index_path = (
    data_path + 'nq_train_passage_encodings/nq_train_3_qg_25_beam_passage_index'
)
qg_25_index = faiss.read_index(qg_25_index_path)
gpu_index_qg_25 = faiss.index_cpu_to_gpu(res, 0, qg_25_index)

In [15]:
class PassageEncoder(nn.Module):
    """Encode tokenized passages and append the embeddings to two FAISS indexes.

    Each forward pass adds the same pooled embeddings to both ``index1`` and
    ``index2``, so the two indexes grow in lock-step and new row ``k`` of each
    refers to the same passage.
    """

    def __init__(self, p_encoder, index1, index2):
        """
        Args:
            p_encoder: passage encoder (here a DPRContextEncoder) whose output
                has a ``pooler_output`` attribute.
            index1, index2: indexes with an ``add`` method accepting the
                embedding tensor (FAISS GPU indexes in this notebook).
        """
        super().__init__()
        self.p_encoder = p_encoder
        self.index1 = index1
        self.index2 = index2

    # Pure inference: without no_grad every batch would build an autograd graph,
    # wasting GPU memory and time.
    @torch.no_grad()
    def forward(self, passage):
        """Encode one batch and append it to both indexes.

        Args:
            passage: LongTensor as built by collate_fn, i.e. (batch, 3, seq_len)
                with input ids, attention mask and token type ids stacked on dim 1.

        Returns:
            The pooled embeddings that were added. The original returned None;
            callers that ignore the result are unaffected.
        """
        encoded = self.p_encoder(
            input_ids=passage[:, 0, :],
            attention_mask=passage[:, 1, :],
            token_type_ids=passage[:, 2, :],
        ).pooler_output.contiguous()  # contiguous memory for the index add

        self.index1.add(encoded)
        self.index2.add(encoded)
        return encoded


In [16]:
# Guard against re-running this cell: a second pass would append every wiki
# passage again and break the FAISS row -> doc_id mapping built afterwards.
# The CPU indexes still hold the pre-append counts.
if (gpu_index_no_qg.ntotal != no_qg_index.ntotal
        or gpu_index_qg_25.ntotal != qg_25_index.ntotal):
    raise RuntimeError(
        'GPU indexes already contain appended passages; '
        're-run the index-loading cells before encoding again.'
    )

pEncoder = PassageEncoder(ctx_model, gpu_index_no_qg, gpu_index_qg_25)

with torch.no_grad():  # inference only: no autograd graph per batch
    for batch in tqdm.notebook.tqdm(dataloader_train, total=len(dataloader_train)):
        pEncoder(batch.to("cuda"))
        # Return PyTorch's cached blocks each step. FAISS GPU resources allocate
        # outside PyTorch's caching allocator, so presumably this leaves room for
        # the growing indexes.
        torch.cuda.empty_cache()

  0%|          | 0/9597 [00:00<?, ?it/s]

In [17]:
# Persist the extended no-QG index; write_index is handed a CPU copy of the GPU index.
no_qg_out_path = (
    data_path
    + 'wiki_nq_train_passage_encodings/wiki_nq_train_3_no_qg_beam_passage_index'
)
faiss.write_index(faiss.index_gpu_to_cpu(gpu_index_no_qg), no_qg_out_path)
print(gpu_index_no_qg.ntotal)

218565


In [18]:
# Persist the extended QG-25 index; write_index is handed a CPU copy of the GPU index.
qg_25_out_path = (
    data_path
    + 'wiki_nq_train_passage_encodings/wiki_nq_train_3_qg_25_beam_passage_index'
)
faiss.write_index(faiss.index_gpu_to_cpu(gpu_index_qg_25), qg_25_out_path)
print(gpu_index_qg_25.ntotal)

218565


In [19]:
# Make an index with the passages and queries: extend each original NQ idx2docid
# map with the wiki passages just appended to the matching FAISS index. The
# DataLoader walks wiki_passages in row order (shuffle=False), so FAISS row
# `n_old + j` holds the embedding of wiki_passages.iloc[j].


def extend_idx2docid(idx2docid, new_doc_ids, ntotal):
    """Return a copy of ``idx2docid`` with entries for the newly appended FAISS rows.

    Args:
        idx2docid: dict mapping stringified FAISS row number -> doc id for the
            rows that existed before the wiki passages were added (keys assumed
            to be "0" .. str(len - 1); TODO confirm against the producer).
        new_doc_ids: doc ids of the appended passages, in the order they were added.
        ntotal: current number of vectors in the FAISS index.

    Returns:
        A new dict (the input is not modified) with the old entries followed by
        ``str(len(idx2docid) + j) -> new_doc_ids[j]``.

    Raises:
        ValueError: if the old map plus the new ids do not account for exactly
            ``ntotal`` rows, i.e. the row -> doc id mapping would be wrong.
    """
    # Capture the offset once. The previous loop re-read len() of the dict it
    # was growing, so the offset stayed 0 and every new row mapped to the first
    # wiki passage.
    offset = len(idx2docid)
    if offset + len(new_doc_ids) != ntotal:
        raise ValueError(
            f'idx2docid has {offset} entries and there are {len(new_doc_ids)} new '
            f'doc ids, but the FAISS index holds {ntotal} vectors'
        )
    extended = dict(idx2docid)
    extended.update({str(offset + j): doc_id for j, doc_id in enumerate(new_doc_ids)})
    return extended


# .tolist() gives plain Python scalars (JSON-serializable) in positional order,
# independent of the DataFrame's index labels.
wiki_doc_ids = wiki_passages['doc_id'].tolist()

for idx2docid_name, faiss_index in [
    ('nq_train_3_no_qg_beam_passage_idx2docid.json', gpu_index_no_qg),
    ('nq_train_3_qg_25_beam_passage_idx2docid.json', gpu_index_qg_25),
]:
    with open(data_path + 'nq_train_passage_encodings/' + idx2docid_name) as f:
        old_index = json.load(f)
    # old_index ends up holding the extended map, kept for any later cell reading it.
    old_index = extend_idx2docid(old_index, wiki_doc_ids, faiss_index.ntotal)
    with open(data_path + 'wiki_nq_train_passage_encodings/wiki_' + idx2docid_name, 'w') as f:
        json.dump(old_index, f)