In [7]:
import random
from typing import Union

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import transformers
from itertools import chain
import pandas as pd
from torch.optim import lr_scheduler


In [8]:
class QADataset(Dataset):
    """Map-style dataset over a question/answer DataFrame.

    Item ``i`` is a dict with keys ``'title'``, ``'body'`` and ``'answers'`` holding the
    ``QuestionTitle``, ``QuestionBody`` and ``Answer`` values of the row at position ``i``.

    Args:
        qa_data: DataFrame containing at least the three columns above.

    Raises:
        AssertionError: if one of the required columns is missing.
    """
    def __init__(self, qa_data: pd.DataFrame):
        assert 'QuestionTitle' in qa_data.columns, "DataFrame must contain 'QuestionTitle' column"
        assert 'QuestionBody' in qa_data.columns, "DataFrame must contain 'QuestionBody' column"
        assert 'Answer' in qa_data.columns, "DataFrame must contain 'Answer' column"

        self.qa_data = qa_data

    def __len__(self) -> int:
        """Number of rows in the underlying DataFrame."""
        return len(self.qa_data)

    def __getitem__(self, index: int) -> dict:
        """Return the title/body/answer of the row at position ``index``."""
        # Materialise the positional row once; the original built it three times.
        row = self.qa_data.iloc[index]
        return {
            'title': row['QuestionTitle'],
            'body': row['QuestionBody'],
            'answers': row['Answer'],
        }
        
class TrainValidatePipeline:
    """Train and evaluate a dual-encoder (question encoder + answer encoder) retriever.

    Training uses in-batch negatives: for a batch of N question/answer pairs the N x N
    dot-product score matrix is built and answer i is the positive for question i.
    Validation reports the same loss plus Recall@5 and MRR computed from the top-5 list.

    Args:
        q_model: question encoder; called as ``q_model(**batch)`` and read through
            ``.last_hidden_state`` (the first-token vector is used as the embedding).
        a_model: answer encoder with the same calling convention.
        tokenizer: tokenizer callable accepting ``text``/``text_pair`` and ``return_tensors='pt'``.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        device: device that tokenized inputs and labels are moved to.
    """
    def __init__(self, q_model, a_model, tokenizer, optimizer, scheduler, device='cpu'):
        self.q_model = q_model
        self.a_model = a_model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device

    def tokenize_qa_batch(self, q_titles, q_bodies, answers, max_length=64):
        """Tokenize (title, body) question pairs and answers, padded to the longest item.

        Returns:
            (q_batch, a_batch): dicts of tensors already moved to ``self.device``.
        """
        q_batch = self.tokenizer(text = q_titles, text_pair = q_bodies, padding ="longest", max_length = max_length, truncation = True, return_tensors ="pt")
        a_batch = self.tokenizer(text = answers, padding= "longest", max_length = max_length, truncation = True, return_tensors = "pt")

        q_batch = {k: v.to(self.device) for k, v in q_batch.items()}
        a_batch = {k: v.to(self.device) for k, v in a_batch.items()}

        return q_batch, a_batch

    def get_class_output(self, model, batch):
        """Run ``model`` on ``batch`` and return the first-token vector of the last hidden state."""
        output = model(**batch)
        output = output.last_hidden_state
        # assumes last_hidden_state is (batch, seq, hidden); position 0 is the [CLS]-style token
        return output[:,0,:]

    def inbatch_negative_sampling(self, Q, P):
        """Return the score matrix ``Q @ P^T``: row i scores question i against every passage."""
        S = (Q @ P.transpose(0,1)).to(self.device)

        return S

    def contrastive_loss_criterion(self, S, labels=None):
        """Mean negative log-likelihood of the correct passage under a row-wise softmax of ``S``.

        Args:
            S: (num_questions, num_passages) score matrix.
            labels: index of the positive passage for each row. Defaults to the diagonal
                (question i pairs with passage i), which is the in-batch-negative setup.
        """
        # First Calculate the log softmax as per the paper's definition
        log_probs = F.log_softmax(S, dim = 1)
        if labels is None:
            labels = torch.arange(len(S), device=self.device)

        loss = F.nll_loss(log_probs, labels.to(self.device))

        return loss

    def get_topk_indices(self, Q, P, k=None):
        """Return ``(indices, scores)`` of the top-k passages per question, best first.

        ``k`` defaults to the number of questions (``len(S)``). It is capped at the number of
        passages, so asking for more results than exist returns all of them instead of raising.
        """
        S = self.inbatch_negative_sampling(Q, P)
        if k is None:
            k = len(S)
        # validate() asks for k=5, but a short final batch can hold fewer than 5 passages;
        # torch.topk raises if k exceeds the size of the last dimension.
        k = min(k, S.size(-1))

        scores, indices = torch.topk(S, k)

        return indices, scores

    def select_by_indices(self, indices, passages):
        """Map each row of ``indices`` to the corresponding items of ``passages``."""
        return [[passages[idx] for idx in index] for index in indices]

    def embed_passages(self, passages: 'list[str]', max_length=512):
        """Embed answer passages with ``a_model`` (eval mode, no gradients)."""
        return self.__embed_text(passages, self.a_model, self.tokenizer, max_length, as_pair=False)

    def embed_questions(self, titles, bodies, max_length=512):
        """Embed (title, body) question pairs with ``q_model`` (eval mode, no gradients)."""
        return self.__embed_text((titles, bodies), self.q_model, self.tokenizer, max_length, as_pair=True)

    def __embed_text(self, texts, model, tokenizer, max_length=512, as_pair=False):
        """Tokenize ``texts`` (padded to ``max_length``) and return first-token embeddings.

        Puts ``model`` into eval mode. When ``as_pair`` is true, ``texts`` is a
        ``(texts, text_pairs)`` tuple.
        """
        model.eval()
        with torch.no_grad():
            if as_pair:
                encoded_batch = tokenizer(
                    text=texts[0], text_pair=texts[1],
                    max_length=max_length, truncation=True,
                    padding='max_length', return_tensors='pt'
                )
            else:
                encoded_batch = tokenizer(
                    text=texts,
                    max_length=max_length, truncation=True,
                    padding='max_length', return_tensors='pt'
                )
            encoded_batch = {k: v.to(self.device) for k, v in encoded_batch.items()}
            outputs = model(**encoded_batch)
            return outputs.last_hidden_state[:, 0, :]

    @staticmethod
    def _mean_or_nan(total, count):
        """Return ``total / count``, or NaN when ``count`` is 0 (e.g. an empty DataLoader)."""
        return total / count if count else float('nan')

    def train(self, train_loader, valid_loader, epochs):
        """Train both encoders for ``epochs`` epochs, validating after each one.

        Returns:
            Per-epoch lists: (training loss, validation loss, validation Recall@5, validation MRR).
        """
        training_loss = []
        validation_loss = []
        validation_recall = []
        validation_mrr = []

        for epoch in range(epochs):
            self.q_model.train()
            self.a_model.train()
            total_train_loss = 0

            for train_batch in tqdm(train_loader):
                q_titles = train_batch['title']
                q_bodies = train_batch['body']
                answers = train_batch['answers']

                # Tokenize and embed the batch data
                q_batch, a_batch = self.tokenize_qa_batch(q_titles, q_bodies, answers)
                q_out = self.get_class_output(self.q_model, q_batch)
                a_out = self.get_class_output(self.a_model, a_batch)

                S = self.inbatch_negative_sampling(q_out, a_out)
                loss = self.contrastive_loss_criterion(S)
                total_train_loss += loss.item()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Scheduler is stepped per epoch, not per batch.
            self.scheduler.step()
            average_train_loss = self._mean_or_nan(total_train_loss, len(train_loader))
            print(f"Epoch: {epoch+1} | Average training loss per sample: {average_train_loss}")
            training_loss.append(average_train_loss)

            # Perform validation
            avg_valid_loss, recall_k, mrr = self.validate(valid_loader)
            validation_loss.append(avg_valid_loss)
            validation_recall.append(recall_k)
            validation_mrr.append(mrr)

        return training_loss, validation_loss, validation_recall, validation_mrr

    def validate(self, valid_loader):
        """Compute validation loss, Recall@5 and MRR using in-batch candidates.

        Each batch's passages are the only candidates for that batch's questions, so the
        true passage for question i is passage i of the same batch.
        """
        self.q_model.eval()
        self.a_model.eval()
        total_valid_loss = 0
        all_retrieved_indices = []
        all_true_indices = []

        with torch.no_grad():
            for valid_batch in tqdm(valid_loader):
                q_titles = valid_batch['title']
                q_bodies = valid_batch['body']
                answers = valid_batch['answers']

                # Embed questions and answers
                Q = self.embed_questions(titles=q_titles, bodies=q_bodies, max_length=512)
                P = self.embed_passages(passages=answers, max_length=512)

                S = self.inbatch_negative_sampling(Q, P)
                loss = self.contrastive_loss_criterion(S)
                total_valid_loss += loss.item()

                indices, _ = self.get_topk_indices(Q, P, k=5)
                true_indices = list(range(len(Q)))
                all_retrieved_indices.extend(indices.cpu().tolist())
                all_true_indices.extend(true_indices)

        average_valid_loss = self._mean_or_nan(total_valid_loss, len(valid_loader))
        recall_k = self.recall_at_k(all_retrieved_indices, all_true_indices, k=5)
        mrr = self.mean_reciprocal_rank(all_retrieved_indices, all_true_indices)

        print(f"Validation | Average loss per sample: {average_valid_loss}")
        print(f"Validation | Recall@k: {recall_k}")
        print(f"Validation | MRR: {mrr}")

        return average_valid_loss, recall_k, mrr

    def recall_at_k(self, retrieved_indices, true_indices, k):
        """Fraction of queries whose true index is among their first ``k`` retrieved indices.

        Returns 0.0 when there are no queries.
        """
        total = len(true_indices)
        if total == 0:
            return 0.0
        hit = 0
        for true, retrieved in zip(true_indices, retrieved_indices):
            if true in set(retrieved[:k]):
                hit += 1

        return hit / total

    def mean_reciprocal_rank(self, retrieved_indices, true_indices):
        """Mean of ``1 / rank`` of each true index in its retrieved list (0 when absent).

        Retrieved lists must support ``.index`` (Python lists). Returns 0.0 when there are
        no queries.
        """
        total = len(true_indices)
        if total == 0:
            return 0.0
        hit = 0
        for true, retrieved in zip(true_indices, retrieved_indices):
            try:
                rank = 1 + retrieved.index(true)
                hit += 1 / rank
            except ValueError:
                # True index not retrieved at all: contributes 0.
                continue
        return hit / total
    
    
    
def load_models_and_tokenizer(q_name, a_name, t_name, device='cpu'):
    """Load the question encoder, answer encoder and tokenizer from pretrained checkpoints.

    Args:
        q_name: checkpoint for the question encoder.
        a_name: checkpoint for the answer encoder.
        t_name: checkpoint for the tokenizer.
        device: device both encoders are moved to.

    Returns:
        ``(q_enc, a_enc, tokenizer)``.
    """
    def _pretrained_encoder(checkpoint):
        # Each call builds an independent model instance on `device`.
        return transformers.AutoModel.from_pretrained(checkpoint).to(device)

    question_encoder = _pretrained_encoder(q_name)
    answer_encoder = _pretrained_encoder(a_name)
    return question_encoder, answer_encoder, transformers.AutoTokenizer.from_pretrained(t_name)


In [None]:
# Pick the compute device: CUDA when torch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# For Reproducibility: seed Python's `random` module and torch with the same value.
for _seed_fn in (random.seed, torch.manual_seed):
    _seed_fn(2024)
del _seed_fn

In [10]:
# Hyperparameters
bsize = 64        # batch size for both the training and validation loaders
n_epoch = 10      # number of training epochs
lr = 5e-5         # Adam learning rate
name = 'google/electra-small-discriminator'  # checkpoint for both encoders and the tokenizer
step_size = 8     # StepLR period, in scheduler steps
gamma = 0.8       # StepLR multiplicative decay factor

# Load File: every split is read from the relative 'qa/' directory
qa_data = {
    split: pd.read_csv(csv_path)
    for split, csv_path in (
        ('train', 'qa/train.csv'),
        ('valid', 'qa/validation.csv'),
        ('answers', 'qa/answers.csv'),
        ('test', 'qa/test.csv'),
    )
}

train_dataset, valid_dataset = (QADataset(qa_data[split]) for split in ('train', 'valid'))

# Create DataLoader: only the training split is shuffled
train_loader = DataLoader(train_dataset, batch_size=bsize, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=bsize, shuffle=False)

# Both encoders and the tokenizer come from the same checkpoint; a single Adam optimizer
# updates the parameters of both encoders together.
q_enc, a_enc, tokenizer = load_models_and_tokenizer(
    q_name=name, a_name=name, t_name=name, device=device,
)
optimizer = torch.optim.Adam(
    chain(q_enc.parameters(), a_enc.parameters()),
    lr=lr,
)
scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

pipeline = TrainValidatePipeline(q_enc, a_enc, tokenizer, optimizer, scheduler, device)

# Per-epoch histories: train loss, validation loss, validation Recall@k, validation MRR
t_l, v_l, v_r, v_mrr = pipeline.train(train_loader, valid_loader, n_epoch)

  return self.fget.__get__(instance, owner)()
  0%|          | 0/169 [00:00<?, ?it/s]

In [None]:
# Answer corpus used at inference time (absolute Kaggle input path).
answers = pd.read_csv('/kaggle/input/assignment4/data/qa/answers.csv')


def load_model(model_path, model_name, device = 'cpu'):
    """Build ``model_name`` from its pretrained checkpoint on ``device``, then replace its
    weights with the state dict stored at ``model_path``.
    """
    encoder = transformers.AutoModel.from_pretrained(model_name)
    encoder = encoder.to(device)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
    # only point this at checkpoints from a trusted source.
    weights = torch.load(model_path, map_location=device)
    encoder.load_state_dict(weights)
    return encoder


def load_tokenizer(tokenizer_name):
    """Return the pretrained tokenizer published under ``tokenizer_name``."""
    return transformers.AutoTokenizer.from_pretrained(tokenizer_name)


# Restore the fine-tuned answer encoder first, then the question encoder.
load_a, load_q = (
    load_model(checkpoint_path, 'google/electra-small-discriminator', device)
    for checkpoint_path in (
        "/kaggle/input/bert-model-odqa/a_encoder_model.bin",
        "/kaggle/input/bert-model-odqa/q_encoder_model.bin",
    )
)
tokenizer = load_tokenizer('google/electra-small-discriminator')


In [None]:
from torch.quantization import get_default_qconfig, prepare

# NOTE(review): `prepare` only inserts observers (the first step of eager-mode static
# quantization). No calibration pass or `convert` call follows in this cell, so the
# encoders stay in floating point and the observers run on every forward pass.
# Confirm whether real quantization is intended here.
qconfig = get_default_qconfig('fbgemm')
for _enc in (load_a, load_q):
    _enc.eval()
    _enc.qconfig = qconfig

# Observers are attached in place: question encoder first, then answer encoder.
for _enc in (load_q, load_a):
    prepare(_enc, inplace=True)
del _enc

load_a, load_q = nn.DataParallel(load_a), nn.DataParallel(load_q)

In [None]:
class InferencePipeline:
    """Retrieve the best-matching stored answer for a question with a dual-encoder model.

    Args:
        q_model: question encoder; called as ``q_model(**encoded)`` and read through
            ``.last_hidden_state`` (the first-token vector is the embedding).
        a_model: answer encoder with the same calling convention.
        tokenizer: tokenizer callable returning a dict of tensors (``return_tensors='pt'``).
        answer_loader: iterable yielding batches (lists) of answer strings to embed.
        answer: sequence returned from by :meth:`inference`; position i must correspond to
            the i-th passage yielded by ``answer_loader``.
        device: device that encoded inputs (and the cached embeddings) live on.
    """
    def __init__(self, q_model, a_model, tokenizer,answer_loader, answer, device = "cpu"):
        self.q_model = q_model
        self.a_model = a_model
        self.device = device
        self.tokenizer = tokenizer
        self.answer_loader = answer_loader
        self.answer = answer
        # Cache of passage embeddings, filled lazily by embed_passage().
        self.output_embedding = None

    def embed_passage(self, batch_size = 64, max_length = 256):
        """Embed every passage from ``answer_loader`` and cache the result.

        Args:
            batch_size: unused; kept for backward compatibility. Batching is decided by
                ``answer_loader``.
            max_length: tokenizer truncation/padding length.

        Returns:
            Tensor with one row per passage (first-token vectors), on ``self.device``.

        Raises:
            ValueError: if ``answer_loader`` yields no passages.
        """
        self.a_model.eval()
        batch_embeddings = []
        # The forward pass is inside no_grad as well: nothing is backpropagated here,
        # and tracking autograd would only retain activations.
        with torch.no_grad():
            for answers in self.answer_loader:
                encoded_batch = self.tokenizer(
                    text = answers,
                    max_length = max_length,
                    truncation = True,
                    padding="max_length",
                    return_tensors = 'pt'
                )
                encoded_batch = {k: v.to(self.device) for k, v in encoded_batch.items()}
                outputs = self.a_model(**encoded_batch)
                batch_embeddings.append(outputs.last_hidden_state[:, 0, :].to(self.device))
        if not batch_embeddings:
            raise ValueError("answer_loader yielded no passages to embed")
        # Concatenate instead of filling a fixed-size buffer, so any corpus size and any
        # hidden size work and no all-zero rows can be retrieved.
        output_embedding = torch.cat(batch_embeddings, dim=0)
        self.output_embedding = output_embedding
        return output_embedding

    def embed_question(self, title, body, max_length = 100):
        """Embed (title, body) question pair(s) as first-token vectors of ``q_model``.

        Returns:
            Tensor with one row per question. Assumes ``last_hidden_state`` is
            (batch, seq, hidden).
        """
        self.q_model.eval()
        with torch.no_grad():
            encoded_batch = self.tokenizer(
                text=title, text_pair=body,
                max_length=max_length, truncation=True,
                padding='max_length', return_tensors='pt'
            )
            encoded_batch = {k: v.to(self.device) for k, v in encoded_batch.items()}
            outputs = self.q_model(**encoded_batch)
        return outputs.last_hidden_state[:, 0, :]

    def inbatch_negative_sampling(self, Q, P):
        """Return the score matrix ``Q @ P^T``: row i scores question i against every passage."""
        S = (Q @ P.transpose(0,1)).to(self.device)
        return S

    def get_topk_indices(self, Q, P, k=None):
        """Return ``(indices, scores)`` of the top-k passages per question, best first.

        ``k`` defaults to ``len(S)`` (number of questions) and is capped at the number of
        passages so ``torch.topk`` never receives an out-of-range k.
        """
        S = self.inbatch_negative_sampling(Q, P)
        if k is None:
            k = len(S)
        k = min(k, S.size(-1))
        scores, indices = torch.topk(S, k)

        return indices, scores

    def inference(self, title, body):
        """Return the best-scoring stored answer for the given question(s).

        A single question returns a single answer, as before. Several titles/bodies return a
        list with one answer per question. Passage embeddings are computed once and cached.
        """
        Q = self.embed_question(title, body)
        if self.output_embedding is None:
            P = self.embed_passage()
        else:
            P = self.output_embedding
        idx, _scores = self.get_topk_indices(Q, P, k = 1)
        # Index with plain ints: a multi-element tensor cannot index a Python list.
        best = [self.answer[i] for i in idx[:, 0].tolist()]
        return best[0] if len(best) == 1 else best
    
    

In [None]:
class AnswerDataset(Dataset):
    """Map-style dataset over the first ``limit`` answers of a sequence.

    Args:
        answer: sliceable sequence of answers (a list of str in this notebook).
        limit: number of leading answers to keep. The default of 350 matches the previously
            hard-coded cap; pass ``None`` to keep every answer.
    """
    def __init__(self, answer, limit=350):
        # Slicing with limit=None keeps everything; it always produces a copy.
        self.answer = answer[:limit]

    def __len__(self):
        """Number of retained answers."""
        return len(self.answer)

    def __getitem__(self, index):
        """Return the answer at position ``index``."""
        return self.answer[index]

In [None]:
batch_size = 64
# Normalise answers: NaN -> '', every character outside [a-zA-Z0-9.!,] -> space, then
# collapse runs of whitespace. Raw strings: '\s' in a normal literal is an invalid escape
# sequence (DeprecationWarning, SyntaxWarning on Python >= 3.12).
answers['Answer'] = (
    answers['Answer']
    .fillna('')
    .str.replace(r'[^a-zA-Z0-9.!,]', ' ', regex=True)
    .replace(r'\s+', ' ', regex=True)
)
# Full cleaned text: this is what inference returns to the user.
answer_full = answers['Answer'].tolist()
# Truncated text (first 250 characters) is what gets embedded.
# NOTE(review): `answers['Answer']` is overwritten in place here, so re-running this cell
# without re-loading `answers` makes `answer_full` hold the truncated text.
answers['Answer'] = answers['Answer'].str[:250]
answer = answers['Answer'].tolist()

answerDataset = AnswerDataset(answer)
answer_loader = DataLoader(answerDataset, batch_size = batch_size, shuffle = False, num_workers = 4, pin_memory = True)
testpipeline = InferencePipeline(load_q, load_a, tokenizer,answer_loader, answer_full, device)
title = ["Season Fried Chicken"]
body = ["I've been trying to season fried chicken but I don't really know where to start."]
testpipeline.inference(title, body)

In [None]:
from pymongo import MongoClient
from dotenv import load_dotenv

def get_database():
    """Placeholder for a database-connection helper: not implemented yet, returns ``None``."""
    return None