In [1]:
#! /usr/bin/env python3
# https://github.com/huggingface/transformers/blob/main/examples/research_projects/rag-end2end-retriever/finetune_rag.py
import numpy as np
import psutil
from sys import getsizeof
import os
from transformers import T5Tokenizer
from icecream import ic
import time
from transformers import (DPRContextEncoder,
                          DPRQuestionEncoder,
                          DPRContextEncoderTokenizerFast,
                          DPRQuestionEncoderTokenizerFast,
                          AutoTokenizer,
                          AutoModel)
import pandas as pd
import logging
import random
import copy
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from pytorch_lightning import Trainer
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers


In [2]:

# Process-wide settings for the rest of the notebook.
# HuggingFace tokenizers read TOKENIZERS_PARALLELISM; "false" switches off
# their internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(level=logging.INFO)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:

class CorpusDataset(pd.DataFrame):
    """Corpus table read from a CSV file, minus rows whose text is a lone newline.

    The instance is itself a DataFrame holding every column of the file; the
    original row index of the surviving rows is kept as-is.
    """

    def __init__(self, path_to_file: str):
        # Load and filter first, then hand the resulting frame to DataFrame.
        cleaned = self.get_df(path_to_file)
        super().__init__(cleaned)

    def get_df(self, path):
        """Read the CSV at ``path`` and drop rows whose "text" equals "\\n"."""
        raw = pd.read_csv(path)
        is_real_text = raw["text"] != "\n"
        return raw.loc[is_real_text]


class QueryDataset(Dataset):
    """Map-style dataset yielding one query with a positive and some negative documents.

    Each item is a dict with three entries:
      * "query":    {"query": <query string>, "text": <selected text column>}
      * "positive": {"doc_id": ..., "text": ...} for one of the row's paragraph ids
      * "negative": a list of ``nb_irrelevant`` such dicts drawn at random from
        the corpus, never one of the row's positives and never repeated.

    ``corpus`` must behave like a DataFrame with "id" and "text" columns.

    NOTE(review): ``count_doc`` is mutated by ``__getitem__``. When the dataset
    is read through DataLoader worker processes each worker holds its own
    copy, so the rotation over positives is per worker — confirm this is
    acceptable.
    """

    # Random draws attempted per negative before falling back to an explicit
    # scan of the corpus for eligible rows.
    _MAX_REJECTION_DRAWS = 100

    def __init__(self, path_to_file: str, text_column: str, corpus, nb_irrelevant=1):
        self.df = self.get_df(path_to_file, text_column)
        self.corpus = corpus
        self.nb_irrelevant = nb_irrelevant
        # Per-row cursor into that row's list of positive ids (see __getitem__).
        self.count_doc = [0] * len(self.df)

    def get_df(self, path, text_column=None):
        """Load the query CSV and normalise it to four columns.

        :param path: CSV file with "query", "outline", "paragraphs_id" and a
            "text_<text_column>" column.
        :param text_column: suffix selecting which "text_*" column becomes
            "text". Required in practice: the default ``None`` cannot be
            concatenated to the "text_" prefix and raises ``TypeError``.
        :return: DataFrame with columns query, outline, text, paragraphs_id,
            where paragraphs_id holds lists of id strings. Rows with an empty
            or missing outline, or whose text is a lone newline, are dropped.
        """
        def parse_ids(ids_str):
            # "['a', 'b']" -> ["a", "b"]: strip the brackets, split on ", ",
            # then strip the quote character on each side of every token.
            return [token[1:-1] for token in ids_str[1:-1].split(", ")]

        source_text_column = "text_" + text_column
        df = pd.read_csv(path)
        # Select and rename in one go instead of adding a column to a filtered
        # slice, which triggers pandas' SettingWithCopyWarning.
        df = df[["query", "outline", source_text_column, "paragraphs_id"]]
        df = df.rename(columns={source_text_column: "text"})
        # Empty CSV cells are read as NaN, on which len() would raise; treat
        # them as an empty outline and drop the row.
        has_outline = df["outline"].map(
            lambda value: not pd.isna(value) and len(value) > 0
        ).astype(bool)
        df = df.loc[has_outline]
        df = df.assign(paragraphs_id=df["paragraphs_id"].apply(parse_ids))
        df = df.loc[df["text"] != "\n"]
        return df[["query", "outline", "text", "paragraphs_id"]]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        """Build the sample for row position ``item`` and advance its positive cursor."""
        row = self.df.iloc[item]
        docs_id = row["paragraphs_id"]
        cursor = self.count_doc[item]
        element = {
            "query": {"query": row["query"], "text": row["text"]},
            "positive": self.get_positive(docs_id[cursor]),
            "negative": self.get_negative(docs_id),
        }
        # Successive reads of the same row cycle through its positive ids.
        self.count_doc[item] = (cursor + 1) % len(docs_id)
        return element

    def get_positive(self, doc_id):
        """Return {"doc_id", "text"} for the corpus row whose id is ``doc_id``.

        :raises ValueError: if the corpus does not hold exactly one such row.
        """
        matches = self.corpus.loc[self.corpus["id"] == doc_id, "text"]
        if len(matches) != 1:
            raise ValueError(
                f"expected exactly one corpus row with id {doc_id!r}, found {len(matches)}"
            )
        return {"doc_id": doc_id, "text": matches.item()}
        # positive["input_ids"] = row["input_ids"].values[0]
        # positive["attention_mask"] = row["attention_mask"].values[0]

    def get_negative(self, docs_id_positive):
        """Draw ``nb_irrelevant`` distinct corpus documents outside ``docs_id_positive``.

        :param docs_id_positive: ids that must not be returned.
        :return: list of {"doc_id", "text"} dicts.
        :raises ValueError: if the corpus does not contain enough eligible
            documents (previously this case looped forever).
        """
        excluded = set(docs_id_positive)
        negatives_documents = []
        for _ in range(self.nb_irrelevant):
            row_negative = self._draw_negative_row(excluded)
            # Never return the same negative twice for one sample.
            excluded.add(row_negative["id"])
            negatives_documents.append({"doc_id": row_negative["id"],
                                        "text": row_negative["text"]})
        return negatives_documents

    def _draw_negative_row(self, excluded):
        """Return one random corpus row whose id is not in ``excluded``."""
        corpus_size = len(self.corpus)
        if corpus_size == 0:
            raise ValueError("cannot sample a negative document from an empty corpus")
        # Fast path: plain rejection sampling, same draws as before.
        for _ in range(self._MAX_REJECTION_DRAWS):
            candidate = self.corpus.iloc[np.random.randint(0, corpus_size)]
            if candidate["id"] not in excluded:
                return candidate
        # Slow path: eligible rows are rare or absent, so list them explicitly
        # and either pick one uniformly or fail instead of spinning forever.
        eligible = np.flatnonzero(~self.corpus["id"].isin(list(excluded)).to_numpy())
        if eligible.size == 0:
            raise ValueError(
                "not enough documents in the corpus to sample "
                f"{self.nb_irrelevant} negative(s) outside the excluded ids"
            )
        return self.corpus.iloc[eligible[np.random.randint(0, eligible.size)]]



In [None]:

class DPR(pl.LightningModule):
    """
    Implementation of the DPR module :
    Encode all documents (contexts), and query with different BERT encoders.
    Similarity measure with dot product.

    Batches are plain lists of samples (the dataloaders use an identity
    collate function). Each sample is a dict with a "query" dict, a
    "positive" dict and a list of "negative" dicts, each carrying a "text"
    entry.
    """

    def __init__(self,
                 query_model_name: str,
                 context_model_name: str,
                 train_val_test : tuple,
                 batch_size=2,
                 num_workers=5,
                 learning_rate=1e-6):
        """
        :param query_model_name: checkpoint name for the question tokenizer and encoder.
        :param context_model_name: checkpoint name for the context tokenizer and encoder.
        :param train_val_test: (train, validation, test) datasets, in that order.
        :param batch_size: number of samples per dataloader batch.
        :param num_workers: dataloader workers; -1 means one per CPU reported by psutil.
        :param learning_rate: learning rate handed to Adam.
        """
        super().__init__()
        logging.info("\n\nWARNING about [query_tokenizer] :")
        self.query_tokenizer = AutoTokenizer.from_pretrained(query_model_name)
        logging.info("\nWARNING about [context_tokenizer]")
        self.context_tokenizer = AutoTokenizer.from_pretrained(context_model_name)
        logging.info("\nWARNING about [query_model]")
        self.query_model = DPRQuestionEncoder.from_pretrained(query_model_name)
        logging.info("\nWARNING about [context_model]")
        self.context_model = DPRContextEncoder.from_pretrained(context_model_name)
        logging.info("\n\n")

        self.learning_rate = learning_rate
        self.batch_size = batch_size

        self.train_ds, self.val_ds, self.test_ds = train_val_test

        self.loss_fn = self.get_loss_fn()

        if num_workers == -1:
            logging.info(f"Number of CPUs available : {psutil.cpu_count()}.")
            self.num_workers = psutil.cpu_count()
        else:
            self.num_workers = num_workers

    def encode_queries(self, query: dict):
        """Tokenize ``query["text"]`` and store the tensors back into ``query``.

        The dict is modified in place ("input_ids" and "attention_mask" are
        added) and also returned.
        """
        query_encoding = self.query_tokenizer(query["text"],
                                              truncation=True,
                                              max_length=512,
                                              padding="max_length",
                                              return_tensors="pt")
        query["input_ids"] = query_encoding["input_ids"]
        query["attention_mask"] = query_encoding["attention_mask"]
        return query

    def encode_contexts(self, contexts):
        """Tokenize one context dict or a list of them, in place.

        Every processed dict gains "input_ids" and "attention_mask". Any
        argument that is neither a list nor a dict is returned unchanged.
        """
        def process(context):
            contexts_encoding = self.context_tokenizer(context["text"],
                                                       truncation=True,
                                                       max_length=512,
                                                       padding="max_length",
                                                       return_tensors='pt')
            context["input_ids"] = contexts_encoding["input_ids"]
            context["attention_mask"] = contexts_encoding["attention_mask"]
            return context

        if isinstance(contexts, list):
            for i in range(len(contexts)):
                contexts[i] = process(contexts[i])
        elif isinstance(contexts, dict):
            contexts = process(contexts)
        return contexts

    def decode_contexts(self, contexts_encodings: list):
        """Decode each token-id sequence back to text with the context tokenizer."""
        contexts = [self.context_tokenizer.decode(c) for c in contexts_encodings]
        return contexts

    def get_dense_query(self, query):
        """Return the question encoder's "pooler_output" for one query dict."""
        query = self.encode_queries(query)
        dense_query = self.query_model(input_ids=query["input_ids"].to(self.device),
                                       attention_mask=query["attention_mask"].to(self.device))["pooler_output"]
        return dense_query

    def get_dense_contexts(self, contexts):
        """Return the concatenated context-encoder "pooler_output" of the given contexts.

        Accepts a list of context dicts or a single dict. A lone dict is
        wrapped in a list first; iterating over it directly would walk its
        keys instead of treating it as one context.
        """
        if isinstance(contexts, dict):
            contexts = [contexts]
        contexts = self.encode_contexts(contexts)
        dense_embeddings = []
        for context in contexts:
            embedding = self.context_model(input_ids=context["input_ids"].to(self.device),
                                           attention_mask=context["attention_mask"].to(self.device))
            dense_embeddings.append(embedding["pooler_output"])
        return torch.cat(dense_embeddings)

    def dot_product(self, query, contexts):
        """Dot product between the squeezed query embedding and every context embedding."""
        sim = query.squeeze().matmul(contexts.T)
        return sim.squeeze()

    @staticmethod
    def _identity_collate(samples):
        """Hand the list of samples to the step functions untouched."""
        return samples

    def _make_dataloader(self, dataset, shuffle):
        """Build a DataLoader over a deep copy of ``dataset`` with the shared settings."""
        return DataLoader(copy.deepcopy(dataset),
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=shuffle,
                          collate_fn=self._identity_collate)

    def train_dataloader(self):
        return self._make_dataloader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_ds, shuffle=False)

    def get_loss_fn(self, own=True):
        """negative log likelihood from DPR paper.

        With ``own=True`` the returned function takes the similarity scores of
        one query, with the positive document at index 0, and returns
        ``-log(exp(s[0]) / sum(exp(s)))``. It is computed as
        ``logsumexp(s) - s[0]``, which is the same quantity but does not
        overflow to inf/NaN when the raw dot products are large.

        With ``own=False`` the scores are passed as-is to ``nn.NLLLoss`` with
        target class 0.
        """
        if own:
            def loss_fn(similarity):
                # reshape(-1) also covers the 0-d tensor dot_product returns
                # when there is a single context.
                scores = similarity.reshape(-1)
                return torch.logsumexp(scores, dim=0) - scores[0]
            return loss_fn
        else:
            nllloss = nn.NLLLoss()
            loss_fn = lambda prediction: nllloss(prediction, torch.tensor(0).to(self.device))
            return loss_fn

    def configure_optimizers(self):
        """Adam over all parameters (both encoders) with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        """
        Sum the loss over every sample of the batch.

        :param batch: list of samples, each with "query", "positive" and "negative" entries.
        :param batch_idx: index of the batch (unused).
        :return: the summed loss for the batch.
        """
        self.context_model.train()
        self.query_model.train()
        loss = 0
        for data in batch:
            query_dense = self.get_dense_query(data["query"])
            # Positive first: the loss treats index 0 as the relevant document.
            contexts_dense = self.get_dense_contexts([data["positive"], *data["negative"]])
            similarity = self.dot_product(query_dense, contexts_dense)
            loss += self.loss_fn(similarity)

        self.log('train/loss_step', loss.item(), on_step=True, batch_size=self.batch_size)
        self.log('train/loss_epoch', loss.item(), on_step=False, on_epoch=True, batch_size=self.batch_size)
        # Kept only for the end-of-epoch log line; detached so the attribute
        # does not hold on to the autograd graph.
        self.train_loss = loss.detach()
        return loss

    def training_epoch_end(self, outputs):
        logging.info(f'Finishing  epoch {str(self.current_epoch).rjust(5)} - loss : {str(self.train_loss).rjust(15)}')

    def validation_step(self, batch, batch_idx):
        """Same computation as ``training_step``, in eval mode and without gradients."""
        self.context_model.eval()
        self.query_model.eval()
        loss = 0
        with torch.no_grad():
            for data in batch:
                query_dense = self.get_dense_query(data["query"])
                contexts_dense = self.get_dense_contexts([data["positive"], *data["negative"]])
                similarity = self.dot_product(query_dense, contexts_dense)
                loss += self.loss_fn(similarity)

            self.log('Val/loss_step', loss.item(), on_step=True, batch_size=self.batch_size)
            self.log('Val/loss_epoch', loss.item(), on_step=False, on_epoch=True, batch_size=self.batch_size)
            self.val_loss = loss

    def validation_epoch_end(self, outputs):
        logging.info(f'Validation epoch {str(self.current_epoch).rjust(5)} - loss : {str(self.val_loss).rjust(15)}')

    def test_step(self, batch, batch_idx):
        """Print the similarity scores of every sample and return placeholder outputs.

        :return: three one-element lists (predictions, actuals, outlines).
            They are lists because ``test_epoch_end`` concatenates them; the
            previous bare strings made that concatenation raise ``TypeError``.
            TODO: replace the placeholder identifiers with real values.
        """
        self.context_model.eval()
        self.query_model.eval()
        for data in batch:
            query_dense = self.get_dense_query(data["query"])
            contexts_dense = self.get_dense_contexts([data["positive"], *data["negative"]])
            similarity = self.dot_product(query_dense, contexts_dense)
            print(similarity)
        return ["0"+str(batch_idx)], ["1"+str(batch_idx)], ["2"+str(batch_idx)]

    def test_epoch_end(self, outputs):
        """Flatten the per-step lists into three lists stored on the module."""
        self.test_predictions = [value for output in outputs for value in output[0]]
        self.test_actuals = [value for output in outputs for value in output[1]]
        self.test_outlines = [value for output in outputs for value in output[2]]

    def predict(self, trainer, corpus):
        """Run the test loop, encode ``corpus`` and return the collected test outputs.

        :param trainer: Trainer used to run ``trainer.test`` on this module.
        :param corpus: contexts handed to ``encode_contexts``; the result is
            kept in ``self.contexts_encoded``.
            NOTE(review): ``encode_contexts`` only tokenizes a dict or a list
            of dicts and returns anything else unchanged — confirm what
            callers pass here.
        :return: (test_predictions, test_actuals, test_outlines).
        """
        trainer.test(self)
        # The corpus argument used to be dropped, which made this call fail
        # for lack of its required argument.
        self.contexts_encoded = self.encode_contexts(corpus)
        return self.test_predictions, self.test_actuals, self.test_outlines

In [None]:
# Wall-clock reference for the elapsed-time log line below.
start_time = time.time()
logging.info("make model : dpr()")
# Suffix of the "text_<...>" CSV column that the query datasets read as text.
text_column = "w/o_heading_first_sentence_by_paragraph"

elapsed_seconds = time.time() - start_time
whole_minutes, leftover_seconds = divmod(elapsed_seconds, 60)
logging.info(f"{' ' * 35}↪ elapsed time : {int(whole_minutes)}min {leftover_seconds:.2f}s.")
logging.info("Get corpus")

In [None]:
# One corpus per split, each read from its own fold directory:
# fold 1 -> train, fold 2 -> validation, fold 3 -> test.
# A previous data root, kept for reference:
#   /users/iris/rserrano/data-set_pre_processed/fold-<n>/corpus_train.csv
corpus_csv_by_fold = "../../../data-subset_pre_processed/fold-{}/corpus_train.csv"

corpus_train = CorpusDataset(path_to_file=corpus_csv_by_fold.format(1))
corpus_val = CorpusDataset(path_to_file=corpus_csv_by_fold.format(2))
corpus_test = CorpusDataset(path_to_file=corpus_csv_by_fold.format(3))

In [None]:
# Query datasets for the three splits, each paired with the corpus of the
# same fold. Only the training split asks for two negatives per query; the
# other two use the QueryDataset default.
# A previous data root, kept for reference:
#   /users/iris/rserrano/data-set_pre_processed/fold-<n>/articles_train_all_ids.csv
queries_csv_by_fold = "../../../data-subset_pre_processed/fold-{}/articles_train_all_ids.csv"

ds_train = QueryDataset(path_to_file=queries_csv_by_fold.format(1),
                        text_column=text_column,
                        corpus=corpus_train,
                        nb_irrelevant=2)
ds_val = QueryDataset(path_to_file=queries_csv_by_fold.format(2),
                      text_column=text_column,
                      corpus=corpus_val)
ds_test = QueryDataset(path_to_file=queries_csv_by_fold.format(3),
                       text_column=text_column,
                       corpus=corpus_test)

In [None]:
# Checkpointing follows the epoch-level validation loss logged by the model
# under "Val/loss_epoch" (lower is better).
checkpoint_callback = ModelCheckpoint(
    monitor="Val/loss_epoch",
    mode="min",
    save_top_k=2,
    every_n_epochs=2,
)
# TensorBoard event files are written under ./dpr/checkpoints/dpr_retriever.
logger = pl_loggers.TensorBoardLogger("./dpr/checkpoints", name="dpr_retriever")

In [None]:
# Trainer settings gathered in one place so they are easy to scan and tweak.
# NOTE(review): accelerator="gpu" with gpus=-1 has no CPU fallback, unlike the
# `device` selection made earlier in the notebook — confirm that is intended.
trainer_options = {
    "logger": logger,
    "callbacks": [checkpoint_callback],
    "accelerator": "gpu",
    "gpus": -1,
    "strategy": 'dp',
    "precision": 32,
    "max_epochs": 100,
    "log_every_n_steps": 1,
}
trainer = Trainer(**trainer_options)

In [None]:
# Build the retriever from the pretrained single-NQ DPR checkpoints and the
# three datasets, then run the test loop. No fit call is made in this cell.
dpr = DPR(
    query_model_name="facebook/dpr-question_encoder-single-nq-base",
    context_model_name="facebook/dpr-ctx_encoder-single-nq-base",
    train_val_test=(ds_train, ds_val, ds_test),
)
trainer.test(model=dpr)