In [None]:
############################################
# Team : RAGrats      
# Team Members : Ali Asgar Padaria, Param Patel, Meet Zalavadiya
#                    
# Code Description : This file contains the code for the Baseline 2 Model - specifically the retriever part.
#                    Here we create embeddings of all the contexts, store them in a vector store, and extract the top 5
#                    relevant contexts for each question in the validation set. These are stored and used in the second
#                    part of the baseline, baseline_2_generation, which generates explanations and classifies labels.
#                    
# NLP Concepts Usage: Tokenization, Embeddings
#                       
# System : GCP Server L4 GPU
#############################################

In [None]:
# Import Necessary Libraries
from transformers import AutoTokenizer, AutoModel
import torch
import faiss
import numpy as np
from collections import defaultdict
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
import json
from datasets import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from datasets import load_from_disk
import os
import matplotlib.pyplot as plt
from datasets import concatenate_datasets
from sentence_transformers import SentenceTransformer

In [None]:
# Pick the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# NLP Concept : SentenceTransformer, Embeddings
# Load the sentence encoder and move it onto the selected device in one step.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2").to(device)

# Keep a handle on the encoder's own tokenizer.
tokenizer = model.tokenizer

In [29]:
# Load the previously saved train and validation splits from disk.
train_dataset, val_dataset = (
    load_from_disk("../files/train_dataset"),
    load_from_disk("../files/val_dataset"),
)

In [None]:
# Get the unique passages from the train and val datasets.
# Each item stores its passages under item["context"]["contexts"]; a set removes duplicates.
# The results are sorted because set iteration order for strings changes between Python
# processes (hash randomization). A stable order keeps row i of the FAISS index built from
# `all_contexts` pointing at the same passage after a kernel restart.
train_contexts = sorted({context for item in train_dataset for context in item["context"]["contexts"]})
val_contexts = sorted({context for item in val_dataset for context in item["context"]["contexts"]})

# Merge train and val contexts into one deduplicated, deterministically ordered list.
all_contexts = sorted(set(train_contexts) | set(val_contexts))

In [31]:
# Embed every unique passage; row i of the result corresponds to all_contexts[i].
context_embeddings = model.encode(
    all_contexts,
    batch_size=16,
)

In [None]:
# Build a flat L2 index sized to the embedding width (last axis of the embedding matrix).
index = faiss.IndexFlatL2(context_embeddings.shape[-1])

# Add the passage vectors in the order they were embedded.
index.add(context_embeddings)

# Store all embeddings in a FAISS index file.
faiss.write_index(index, "../files/baseline_2_all_faiss.index")

### Evaluate Retriever
Check the retrieval accuracy of this model on the complete dataset (train + validation).

In [33]:
# Stack the train split and the validation split into one dataset for evaluation.
full_dataset = concatenate_datasets(
    [train_dataset, val_dataset],
)

In [35]:
def retrieval_evaluation(dataset, index, contexts, k=5):
    """Measure what fraction of gold passages the retriever recovers over ``dataset``.

    For each item the question is embedded with the notebook-level ``model`` and the
    index is asked for as many neighbours as the item has distinct gold passages.
    The score is the total number of gold passages found divided by the total
    number of gold passages.

    Args:
        dataset: Iterable of items exposing ``item["question"]`` and
            ``item["context"]["contexts"]`` (the gold passages).
        index: Object whose ``search(queries, n)`` returns ``(distances, ids)``.
        contexts: Sequence mapping an id returned by ``index`` to its passage text.
            It must be the same sequence, in the same order, that was embedded
            into ``index``.
        k: Unused. The search depth is each item's number of gold passages; the
            parameter is kept so existing calls keep working.

    Returns:
        float: Matched gold passages / expected gold passages, or ``0.0`` when the
        dataset contains no gold passages at all.
    """
    total_matched = 0
    total_expected = 0

    for item in tqdm(dataset):
        gold_contexts = set(item["context"]["contexts"])
        if not gold_contexts:
            # Nothing to find for this item, and a search depth of 0 is not a valid request.
            continue

        query_embedding = model.encode([item["question"]])[0]
        _, neighbour_ids = index.search(query_embedding.reshape(1, -1), len(gold_contexts))

        # FAISS pads the id row with -1 when fewer vectors are available than requested.
        # Drop those ids so they do not wrap around to contexts[-1].
        retrieved_contexts = {contexts[i] for i in neighbour_ids[0] if i >= 0}

        # Count how many gold passages appear among the retrieved passages.
        total_matched += len(gold_contexts & retrieved_contexts)
        total_expected += len(gold_contexts)

    if total_expected == 0:
        # Empty dataset (or no gold passages anywhere): avoid dividing by zero.
        return 0.0
    return total_matched / total_expected

In [36]:
# Score the retriever on train + validation, resolving index ids against the
# same passage list the index was built from.
retrieval_accuracy = retrieval_evaluation(
    dataset=full_dataset,
    index=index,
    contexts=all_contexts,
)
print(f"Retrieval Accuracy: {retrieval_accuracy:.4f}")

100%|██████████| 20000/20000 [13:50<00:00, 24.08it/s]

Retrieval Accuracy: 0.5950





Store Retrieved Context Pairs for Validation Set

In [None]:
# Display the validation split's features and row count.
val_dataset

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 2000
})

In [None]:
def retrieve_for_validation_set(dataset, index, contexts, k=5):
    """Retrieve the ``k`` nearest passages for every question in ``dataset``.

    Each question is embedded with the notebook-level ``model`` and looked up in
    ``index``; the returned ids are resolved to passage text through ``contexts``.

    Args:
        dataset: Iterable of items exposing ``item["question"]``.
        index: Object whose ``search(queries, n)`` returns ``(distances, ids)``.
        contexts: Sequence mapping an id returned by ``index`` to its passage text.
            It must be the same sequence, in the same order, that was embedded
            into ``index``.
        k: Number of passages to retrieve per question (default 5).

    Returns:
        list[dict]: One ``{"question": ..., "retrieved_contexts": [...]}`` entry
        per item, in dataset order, with passages in the order the index returned them.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    retrieved_pairs = []
    for item in tqdm(dataset):
        query = item["question"]
        query_embedding = model.encode([query])[0]

        # Honour the caller's k instead of a hard-coded depth.
        _, neighbour_ids = index.search(query_embedding.reshape(1, -1), k)

        # FAISS pads with -1 when it holds fewer than k vectors.
        # Skip those ids so they do not wrap around to contexts[-1].
        retrieved_contexts = [contexts[i] for i in neighbour_ids[0] if i >= 0]

        # Store the question and its retrieved contexts in a dict object.
        retrieved_pairs.append({
            "question": query,
            "retrieved_contexts": retrieved_contexts
        })
    return retrieved_pairs


In [None]:
# Retrieve the relevant passages for every question in the validation set.
# The FAISS index was built from `all_contexts`, so the ids it returns must be
# resolved against that same list. `val_contexts` has different positions and
# fewer entries, which would yield wrong passages or an IndexError.
retreived_pairs = retrieve_for_validation_set(val_dataset, index, all_contexts)

100%|██████████| 2000/2000 [00:21<00:00, 94.87it/s]


In [None]:
# Persist the validation questions together with their retrieved contexts as JSON.
with open("../files/val_retrieved_pairs_base_2.json", "w") as out_file:
    out_file.write(json.dumps(retreived_pairs))
