In [14]:
#! /usr/bin/env python3
# https://github.com/huggingface/transformers/blob/main/examples/research_projects/rag-end2end-retriever/finetune_rag.py
import numpy as np
from sys import getsizeof
import os.path
from transformers import T5Tokenizer
from icecream import ic
import time
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizerFast
import pandas as pd
import random
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from pytorch_lightning import Trainer
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers

In [4]:
class QueryDataset(Dataset):
    """Torch dataset of queries paired with their relevant paragraph ids.

    Rows come from a CSV that has ``query`` and ``outline`` columns plus one
    ``text_<variant>``, ``paragraph_span_<variant>`` and
    ``paragraph_id_<variant>`` column per text variant; ``text_column`` names
    the variant to use.

    :param path_to_file: path of the CSV file.
    :param text_column: variant suffix selecting the text/span/id columns.
    :param query_encoder: callable taking the cleaned DataFrame and returning
        it with ``input_ids`` and ``attention_mask`` columns added
        (e.g. ``DPR.encode_queries``).
    """

    def __init__(self, path_to_file: str, text_column: str, query_encoder):
        self.df = self.get_df(path_to_file, text_column, query_encoder)

    def get_df(self, path, text_column=None, query_encoder=None):
        """Load, filter and (optionally) encode the query table.

        Drops rows whose ``outline`` is empty or missing and rows whose text is
        a lone newline. The result has columns ``query``, ``outline``,
        ``text``, ``paragraph_span``, ``paragraph_id`` (a list of id strings)
        plus whatever ``query_encoder`` adds, and a fresh 0..n-1 index.

        :param path: CSV path.
        :param text_column: variant suffix (required despite the None default).
        :param query_encoder: encoder callable; skipped when None.
        :return: the cleaned DataFrame.
        """
        def parse_ids(ids_str):
            # Parse a stringified list such as "['a', 'b']" into ['a', 'b'].
            # Missing values and an empty list "[]" both yield [].
            if not isinstance(ids_str, str):
                return []
            inner = ids_str.strip()[1:-1].strip()
            if not inner:
                return []
            # Each element is quoted; drop one quote character on each side.
            return [token[1:-1] for token in inner.split(", ")]

        columns = ["query",
                   "outline",
                   "text",
                   "paragraph_span",
                   "paragraph_id"]
        df = pd.read_csv(path)
        df = df[["query",
                 "outline",
                 "text_" + text_column,
                 "paragraph_span_" + text_column,
                 "paragraph_id_" + text_column]]
        # read_csv turns empty fields into NaN, on which len() raises; treat
        # NaN as an empty outline instead.
        has_outline = df["outline"].fillna("").astype(str).str.len() > 0
        # .copy() so the column assignments below write to a real frame, not a
        # view of the filtered slice.
        df = df.loc[has_outline].copy()
        df["text"] = df["text_" + text_column]
        df["paragraph_span"] = df["paragraph_span_" + text_column]
        df["paragraph_id"] = df["paragraph_id_" + text_column].apply(parse_ids)
        # A fresh 0..n-1 index keeps labels and positions identical, so encoders
        # that build columns positionally stay row-aligned.
        df = df.loc[df["text"] != "\n", columns].reset_index(drop=True)
        if query_encoder is not None:
            df = query_encoder(df)
        return df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        """Return one row as a dict with keys query, ids, input_ids, attention_mask."""
        row = self.df.iloc[item]
        query = row["query"]
        ids = row["paragraph_id"]
        input_ids = row["input_ids"]
        attention_mask = row["attention_mask"]
        return {"query": query, "ids": ids, "input_ids": input_ids, "attention_mask": attention_mask}


class CorpusDataset(pd.DataFrame):
    """Passage corpus read from a CSV file, exposed directly as a DataFrame.

    Rows whose ``text`` value is exactly a single newline character are
    removed; the remaining rows keep their original CSV row index.
    """

    def __init__(self, path_to_file: str):
        super().__init__(self.get_df(path_to_file))

    def get_df(self, path):
        """Read ``path`` and return it without the lone-newline ``text`` rows."""
        corpus = pd.read_csv(path)
        keep = corpus["text"].ne("\n")
        return corpus.loc[keep]

In [29]:
class DPR(nn.Module):
    """
    Implementation of the DPR module :
    Encode all documents (contexts), and query with different BERT encoders.
    Similarity measure with dot product.

    Contexts and queries are encoded by separate pretrained DPR encoders, each
    followed by a small MLP head that projects to ``dense_size`` dimensions.

    :param query_model_name: checkpoint for the question tokenizer/encoder.
    :param context_model_name: checkpoint for the context tokenizer/encoder.
    :param contexts: DataFrame of passages with a ``text`` column (tokenized
        here) and an ``id`` column (used by ``forward`` to build labels). It is
        modified in place: ``input_ids``/``attention_mask`` columns are added
        at construction and a ``dense`` column on every ``forward`` call.
    :param dense_size: output size of both projection heads.
    """

    def __init__(self,
                 query_model_name: str,
                 context_model_name: str,
                 contexts,
                 dense_size=64):
        super().__init__()
        self.query_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained(query_model_name)
        self.context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(context_model_name)
        self.query_model = DPRQuestionEncoder.from_pretrained(query_model_name)
        self.context_model = DPRContextEncoder.from_pretrained(context_model_name)
        self.contexts = contexts
        self.encode_contexts()

        # NOTE(review): not used by any method of this class yet.
        self.log_softmax = nn.LogSoftmax(dim=1)
        # 768 presumably matches the encoders' pooler_output size — confirm when
        # switching to a checkpoint with a different hidden/projection size.
        self.contexts_to_dense = nn.Sequential(nn.Linear(768, dense_size * 2),
                                               nn.ReLU(),
                                               nn.Linear(dense_size * 2, dense_size),
                                               nn.GELU())
        self.query_to_dense = nn.Sequential(nn.Linear(768, dense_size * 2),
                                            nn.ReLU(),
                                            nn.Linear(dense_size * 2, dense_size),
                                            nn.GELU())

    @staticmethod
    def _tokenize_column(tokenizer, frame, column):
        """Tokenize ``frame[column]`` row by row and add the encoding columns in place.

        :return: ``frame`` with ``input_ids`` and ``attention_mask`` columns.
        """
        encodings = frame[column].apply(
            lambda text: tokenizer(text, truncation=True, padding=True, return_tensors="pt"))
        # Series.map keeps frame's own index, so assignment stays row-aligned
        # even when the index is not 0..n-1 (e.g. after filtering or sampling).
        frame["input_ids"] = encodings.map(lambda enc: enc["input_ids"])
        frame["attention_mask"] = encodings.map(lambda enc: enc["attention_mask"])
        return frame

    def encode_contexts(self, contexts: pd.DataFrame = None):
        """Tokenize the ``text`` column of ``contexts`` (default ``self.contexts``) in place.

        :return: the same DataFrame with ``input_ids``/``attention_mask`` added.
        """
        target = self.contexts if contexts is None else contexts
        return self._tokenize_column(self.context_tokenizer, target, "text")

    def decode_contexts(self, contexts_encodings: list):
        """Decode each token-id sequence back to text with the context tokenizer."""
        contexts = [self.context_tokenizer.decode(c) for c in contexts_encodings]
        return contexts

    def get_dense_contexts(self, contexts=None, return_tensor=False):
        """Embed every context and project it with ``contexts_to_dense``.

        :param contexts: DataFrame with ``input_ids``/``attention_mask`` columns
            (as added by ``encode_contexts``); defaults to ``self.contexts``.
        :param return_tensor: if True, return the embeddings stacked into one
            tensor via ``context_to_tensor`` instead of a Series.
        :return: Series (indexed like ``contexts``) of squeezed projected
            embeddings, or the stacked tensor when ``return_tensor`` is True.
        """
        frame = self.contexts if contexts is None else contexts
        # Access row fields by label: positional x[0]/x[1] on a labelled row
        # is deprecated in pandas.
        encoder_outputs = frame[["input_ids", "attention_mask"]].apply(
            lambda row: self.context_model(input_ids=row["input_ids"],
                                           attention_mask=row["attention_mask"]),
            axis=1)
        dense = encoder_outputs.apply(
            lambda output: self.contexts_to_dense(output["pooler_output"]).squeeze())
        if return_tensor:
            return self.context_to_tensor(dense)
        return dense

    def encode_queries(self, queries, column: str = "text"):
        """Tokenize ``queries[column]`` with the question tokenizer, in place.

        NOTE(review): the default reads ``text``; QueryDataset fills ``text``
        from its ``text_<variant>`` column while the raw query string lives in
        ``query`` — confirm which one the question encoder should see.

        :return: ``queries`` with ``input_ids``/``attention_mask`` added.
        """
        return self._tokenize_column(self.query_tokenizer, queries, column)

    def get_dense_query(self, query):
        """Encode one query dict (``input_ids``/``attention_mask``) and project it.

        :return: the squeezed output of ``query_to_dense``.
        """
        dense_query = self.query_model(input_ids=query["input_ids"],
                                       attention_mask=query["attention_mask"])
        dense_query = self.query_to_dense(dense_query["pooler_output"]).squeeze()
        return dense_query

    def dot_product(self, q_vector, p_vector):
        """Matmul of ``q_vector`` (with a new axis 1) against ``p_vector`` transposed
        over its last two dims. NOTE(review): not called inside this class."""
        q_vector = q_vector.unsqueeze(1)
        sim = torch.matmul(q_vector, torch.transpose(p_vector, -2, -1))
        return sim

    def context_to_tensor(self, contextx_dense):
        """Stack per-context embeddings into one detached tensor (one row per context).

        Unlike a round-trip through numpy, this works for tensors on any device.
        An empty input gives an empty 1-D tensor.
        """
        rows = [row.detach() for row in contextx_dense]
        if not rows:
            return torch.tensor([])
        return torch.stack(rows)

    def forward_step(self, query_dense, contexts_dense_tensor, k=10):
        """Select the ``k`` contexts most similar (dot product) to a query.

        :param query_dense: query embedding, expected 1-D (as produced by
            ``get_dense_query``) — TODO confirm for batched inputs.
        :param contexts_dense_tensor: 2-D tensor, one row per context.
        :param k: number of contexts to select; clamped to [0, n_contexts].
        :return: bool tensor of shape (n_contexts,), True for the top-k rows.
        """
        # squeeze(-1) (not squeeze()) keeps a 1-D result even with one context.
        similarity = (contexts_dense_tensor @ query_dense.unsqueeze(-1)).squeeze(-1)
        k = max(0, min(int(k), similarity.shape[0]))
        top_k = torch.zeros_like(similarity, dtype=torch.bool)
        top_k[similarity.topk(k).indices] = True
        return top_k

    def forward(self, query_batch):
        """Score all contexts against each query in ``query_batch``.

        Re-embeds every context (stored in ``self.contexts["dense"]``). For each
        query it selects as many top-scoring contexts as the query has relevant
        ids, and labels the contexts whose ``id`` is among those ids.

        :param query_batch: iterable of dicts with ``input_ids``,
            ``attention_mask`` and ``ids`` keys (as returned by QueryDataset).
        :return: list of ``(pred, ground_truth)`` bool tensors, one pair per
            query, each with one entry per context.
        """
        self.contexts["dense"] = self.get_dense_contexts(self.contexts)
        contexts_dense_tensor = self.context_to_tensor(self.contexts["dense"])
        results = []
        for query in query_batch:
            query_dense = self.get_dense_query(query)
            pred = self.forward_step(query_dense, contexts_dense_tensor, k=len(query["ids"]))
            ground_truth = torch.tensor(self.contexts["id"].isin(query["ids"]).tolist())
            results.append((pred, ground_truth))
        return results


In [25]:
# Load the fold-1 training corpus, then keep a small, reproducible random
# sample of passages for quick experiments.
CORPUS_SAMPLE_SIZE = 60  # passages kept; raise for a fuller run
SAMPLE_SEED = 42  # fixes which passages are drawn across kernel restarts

corpus_train_full = CorpusDataset(path_to_file="../../../data_pre_processed/fold-1/corpus_train.csv")
# reset_index gives the sample a 0..n-1 index so labels equal positions.
corpus_train = (
    corpus_train_full
    .sample(CORPUS_SAMPLE_SIZE, random_state=SAMPLE_SEED)
    .reset_index(drop=True)
)

# TODO: build the validation (fold-2) and test (fold-3) splits. The draft below
# uses a `df_get` helper that is not defined in the cells shown, and its paths
# use "../../" while fold-1 above uses "../../../" — fix both before enabling.
# df_val = df_get("../../data_pre_processed/fold-2/articles_train.csv",
#                 text_column=text_column)
# corpus_val = pd.read_csv("../../data_pre_processed/fold-2/corpus_train.csv").sample(25)
#
# df_test = df_get("../../data_pre_processed/fold-3/articles_train.csv",
#                  text_column=text_column)
# corpus_test = pd.read_csv("../../data_pre_processed/fold-3/corpus_train.csv").sample(25)

In [26]:
# Retriever built from the single-NQ DPR question/context checkpoints; the
# sampled corpus is tokenized during construction.
dpr = DPR(
    query_model_name="facebook/dpr-question_encoder-single-nq-base",
    context_model_name="facebook/dpr-ctx_encoder-single-nq-base",
    contexts=corpus_train,
)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizerFast'.
Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.weight', 'question_encoder.bert_model.pooler.dense.bias']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequ

In [27]:
# Text variant suffix: selects the text_/paragraph_span_/paragraph_id_ columns.
text_column = "w/o_heading_first_sentence_by_paragraph"

df_train = QueryDataset(
    "../../../data_pre_processed/fold-1/articles_train.csv",
    text_column,
    dpr.encode_queries,
)

In [28]:
# Batches are plain lists of the dataset's row dicts (no tensor collation).
train_dataset = DataLoader(df_train, batch_size=2, shuffle=True, collate_fn=list)

In [23]:
# Smoke test: push only the first shuffled batch through the model.
for batch_no, query_batch in zip(range(1), train_dataset):
    dpr.forward(query_batch)

tensor(0.6500)
tensor(0.7000)


In [12]:
# `gpus=` is deprecated (it triggered the rank_zero_warn below) and removed in
# Lightning 2.0; accelerator/devices is the supported way to request one GPU.
trainer = Trainer(accelerator="gpu", devices=1)

  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
# FIXME: raises "TypeError: Unwrapping the module did not yield a
# `LightningModule`" — Trainer.fit needs a pl.LightningModule, but DPR
# subclasses nn.Module (a wrapper would presumably also need training_step and
# configure_optimizers).
trainer.fit(dpr, train_dataset)

TypeError: Unwrapping the module did not yield a `LightningModule`, got <class '__main__.DPR'> instead.