In [6]:
import os
import random
from itertools import chain
from typing import Optional
from typing import Tuple, List, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from pymongo import MongoClient
from torch import Tensor
from torch.optim import lr_scheduler, Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer, BatchEncoding
from transformers.modeling_outputs import BaseModelOutput


In [8]:
class QADataset(Dataset):
    def __init__(self, qa_data: pd.DataFrame):
        assert 'QuestionTitle' in qa_data.columns, "DataFrame must contain 'QuestionTitle' column"
        assert 'QuestionBody' in qa_data.columns, "DataFrame must contain 'QuestionBody' column"
        assert 'Answer' in qa_data.columns, "DataFrame must contain 'Answer' column"
        
        self.qa_data = qa_data
    
    def __len__(self) -> int:
        return len(self.qa_data['QuestionTitle'])
    
    def __getitem__(self, index: int) -> dict:
        # Access the data directly using iloc, which is more memory-efficient
        return {
            'title': self.qa_data.iloc[index]['QuestionTitle'],
            'body': self.qa_data.iloc[index]['QuestionBody'],
            'answers': self.qa_data.iloc[index]['Answer']
        }
        
class TrainValidatePipeline:
    def __init__(self, q_model: AutoModel, a_model: AutoModel, tokenizer: AutoTokenizer, optimizer: Optimizer, scheduler: Optional[_LRScheduler], device: torch.device ='cpu'):
        self.q_model = q_model
        self.a_model = a_model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device

    def tokenize_qa_batch(self, q_titles: List[str], q_bodies: List[str], answers: List[str], max_length: int = 64) -> Tuple[BatchEncoding, BatchEncoding]:

        q_batch = self.tokenizer(text = q_titles, text_pair = q_bodies, padding ="longest", max_length = max_length, truncation = True, return_tensors ="pt")
        a_batch = self.tokenizer(text = answers, padding= "longest", max_length = max_length, truncation = True, return_tensors = "pt")

        q_batch = {k: v.to(self.device) for k, v in q_batch.items()}
        a_batch = {k: v.to(self.device) for k, v in a_batch.items()}

        return q_batch, a_batch

    def get_class_output(self, model: AutoModel, batch: BatchEncoding) -> BaseModelOutput:
        output = model(**batch)
        output = output.last_hidden_state
        return output[:,0,:]

    def inbatch_negative_sampling(self, Q: Tensor, P: Tensor) -> Tensor:
        # Q: Tensor of shape: N question titles + bodies, E embedding dimension
        # P: Tensor of shape: N passages, E embedding dimension

        S = (Q @ P.transpose(0,1)).to(self.device)

        return S


    def contrastive_loss_criterion(self, S: Tensor, labels: Tensor = None) -> Tensor:
        # First Calculate the log softmax as per the paper's definition
        softmax_scores = F.log_softmax(S, dim = 1)
        if labels == None:
            labels = torch.arange(len(S)).to(self.device)

        loss = F.nll_loss(softmax_scores, labels.to(self.device))

        return loss

    def get_topk_indices(self, Q: Tensor, P: Tensor, k: int = None) -> Tuple[Tensor, Tensor]:
        S = self.inbatch_negative_sampling(Q, P)
        if k == None:
            k = len(S)
        
        scores, indices = torch.topk(S, k)

        return indices, scores

    def select_by_indices(self, indices: Tensor, passages: List[str]) -> List[List[str]]:
        return [[passages[idx] for idx in index] for index in indices]

    def embed_passages(self, passages: List[str], max_length: int = 512) -> BaseModelOutput:
        return self.__embed_text(passages, self.a_model, self.tokenizer, max_length, as_pair=False)

    def embed_questions(self, titles: List[str], bodies: List[str], max_length: int = 512) -> BaseModelOutput:
        return self.__embed_text((titles, bodies), self.q_model, self.tokenizer, max_length, as_pair=True)

    def __embed_text(self, texts: Union[List[str], Tuple[List[str], List[str]]], model : AutoModel, tokenizer : AutoTokenizer, max_length : int = 512, as_pair : bool = False) -> BaseModelOutput:
        model.eval()
        with torch.no_grad():
            if as_pair:
                encoded_batch = tokenizer(
                    text=texts[0], text_pair=texts[1],
                    max_length=max_length, truncation=True,
                    padding='max_length', return_tensors='pt'
                )
            else:
                encoded_batch = tokenizer(
                    text=texts,
                    max_length=max_length, truncation=True,
                    padding='max_length', return_tensors='pt'
                )
            encoded_batch = {k: v.to(self.device) for k, v in encoded_batch.items()}
            outputs = model(**encoded_batch)
            return outputs.last_hidden_state[:, 0, :]

    def train(self, train_loader: DataLoader, valid_loader : DataLoader, epochs : int) ->  Tuple[List[float], List[float], List[float], List[float]]:
        training_loss = []
        validation_loss = []
        validation_recall = []
        validation_mrr = []

        for epoch in range(epochs):
            self.q_model.train()
            self.a_model.train()
            total_train_loss = 0

            for train_batch in tqdm(train_loader):
                q_titles = train_batch['title']
                q_bodies = train_batch['body']
                answers = train_batch['answers']

                # Tokenize and embed the batch data
                q_batch, a_batch = self.tokenize_qa_batch(q_titles, q_bodies, answers)
                q_out = self.get_class_output(self.q_model, q_batch)
                a_out = self.get_class_output(self.a_model, a_batch)

                S = self.inbatch_negative_sampling(q_out, a_out)
                loss = self.contrastive_loss_criterion(S)
                total_train_loss += loss.item()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            
            self.scheduler.step()
            average_train_loss = total_train_loss / len(train_loader)
            print(f"Epoch: {epoch+1} | Average training loss per sample: {average_train_loss}")
            training_loss.append(average_train_loss)

            # Perform validation
            avg_valid_loss, recall_k, mrr = self.validate(valid_loader)
            validation_loss.append(avg_valid_loss)
            validation_recall.append(recall_k)
            validation_mrr.append(mrr)

        return training_loss, validation_loss, validation_recall, validation_mrr

    def validate(self, valid_loader : DataLoader) -> Tuple[float, float, float]:
        self.q_model.eval()
        self.a_model.eval()
        total_valid_loss = 0
        all_retrieved_indices = []
        all_true_indices = []
        
        with torch.no_grad():
            for valid_batch in tqdm(valid_loader):
                q_titles = valid_batch['title']
                q_bodies = valid_batch['body']
                answers = valid_batch['answers']

                # Embed questions and answers
                Q = self.embed_questions(titles=q_titles, bodies=q_bodies, max_length=512)
                P = self.embed_passages(passages=answers, max_length=512)

                S = self.inbatch_negative_sampling(Q, P)
                loss = self.contrastive_loss_criterion(S)
                total_valid_loss += loss.item()

                indices, _ = self.get_topk_indices(Q, P, k=5)
                true_indices = list(range(len(Q)))
                all_retrieved_indices.extend(indices.cpu().tolist())
                all_true_indices.extend(true_indices)

        average_valid_loss = total_valid_loss / len(valid_loader)
        recall_k = self.recall_at_k(all_retrieved_indices, all_true_indices, k=5)
        mrr = self.mean_reciprocal_rank(all_retrieved_indices, all_true_indices)

        print(f"Validation | Average loss per sample: {average_valid_loss}")
        print(f"Validation | Recall@k: {recall_k}")
        print(f"Validation | MRR: {mrr}")

        return average_valid_loss, recall_k, mrr
    
    def recall_at_k(self, retrieved_indices: List[List[int]], true_indices: List[int], k : int) -> float:
        hit = 0
        for true,retrieved in zip(true_indices, retrieved_indices):
            top_k_set = set(retrieved[:k])
            if true in top_k_set:
                hit += 1
        total = len(true_indices)

        return hit / total

    def mean_reciprocal_rank(self, retrieved_indices: List[List[int]], true_indices: List[int]) -> float:
        hit = 0
        for true,retrieved in zip(true_indices, retrieved_indices):
            try:
                rank = 1 + retrieved.index(true)
                hit += 1 / rank
            except ValueError:
                continue
        total = len(true_indices)
        return hit / total
    
    
    
def load_models_and_tokenizer(q_name: str, a_name: str, t_name: str, device: torch.device ='cpu') -> Tuple[AutoModel, AutoModel, AutoTokenizer]:
    q_enc = AutoModel.from_pretrained(q_name).to(device)
    a_enc = AutoModel.from_pretrained(a_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(t_name)
    
    return q_enc, a_enc, tokenizer

def enableMultiGPU(model: AutoModel, multi_gpu: bool):
    if multi_gpu:
        model = nn.DataParallel(model)

        import os
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        
    return model


In [None]:
# For Reproducibility
random.seed(2024)
torch.manual_seed(2024)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
bsize = 128
n_epoch = 20
lr = 5e-5
name = 'google/electra-small-discriminator'
step_size = 8
gamma = 0.8
multi_gpu = False

In [None]:
# Load File
qa_data = dict(
    train = pd.read_csv('qa/train.csv'),
    valid = pd.read_csv('qa/validation.csv'),
    answers = pd.read_csv('qa/answers.csv'),
    test = pd.read_csv('qa/test.csv'),
)

In [None]:
# Create DataLoader
train_dataset = QADataset(qa_data['train'])
valid_dataset = QADataset(qa_data['valid'])                      

train_loader = DataLoader(train_dataset, batch_size = bsize, shuffle = True)
valid_loader = DataLoader(valid_dataset, batch_size = bsize, shuffle = False)

In [None]:
# Initialize Training and Validation Pipeline
q_enc, a_enc, tokenizer = load_models_and_tokenizer(q_name = name, a_name = name, t_name = name, device = device)
q_enc = enableMultiGPU(q_enc, multi_gpu)
a_enc = enableMultiGPU(a_enc, multi_gpu)


optimizer = torch.optim.Adam(chain(q_enc.parameters(), a_enc.parameters()), lr= lr)
scheduler = lr_scheduler.StepLR(optimizer,step_size = step_size, gamma = gamma)

pipeline = TrainValidatePipeline(q_enc, a_enc, tokenizer, optimizer, scheduler, device)

In [None]:
# Perform Training and Validation at the same time
t_l, v_l, v_r, v_mrr = pipeline.train(train_loader, valid_loader, n_epoch)

In [None]:
# Plot the graph:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (10, 4))
ax1.plot(t_l, label = "Training Loss")  
ax1.plot(v_l, label = "Validation Loss")
ax1.set_title('Loss vs. Epoch')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()

ax2.plot(v_r, label = "Validation Recall")
ax2.plot(v_mrr, label = "Validaion Mean Reciprocal Rank")
ax2.set_title('Validation Results')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Result')
ax2.legend()

plt.show()

In [None]:
# Save Model:
q_enc = q_enc.to('cpu')
a_enc = a_enc.to('cpu')
torch.save(q_enc.state_dict(), "model/q_encoder_model.bin")
torch.save(a_enc.state_dict(), "model/a_encoder_model.bin")