In [None]:
############################################
# Team : RAGrats
# Team Members : Ali Asgar Padaria, Param Patel, Meet Zalavadiya
# 
# Code Description : This file holds code for generating the vector store and for pre-computing the contexts returned by the retriever, so the models can load them directly.
#                    It uses a SentenceTransformer model for embedding the text and FAISS for vector storage.
#                    
# NLP Concepts Usage: Tokenization, Embeddings
#
# System : GCP Server L4 GPU
#############################################

In [20]:
# Imports
from transformers import AutoTokenizer, AutoModel
import torch
import faiss
import numpy as np
from collections import defaultdict
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM
import json
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from datasets import concatenate_datasets


In [None]:
# NLP Concept: Embeddings -- all-MiniLM-L6-v2 sentence encoder, placed on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2").to(device)

# Expose the tokenizer and the wrapped transformer module so the notebook can
# tokenize, pool and normalize batches itself.
tokenizer = model.tokenizer
core_model = model._first_module().auto_model

In [5]:
# Load the pre-split train / validation datasets from the shared ../files directory
train_dataset, val_dataset = map(load_from_disk, ("../files/train_dataset", "../files/val_dataset"))

In [38]:
# Mean Pooling Function
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence axis, ignoring padded positions.

    Args:
        model_output: transformer output exposing ``last_hidden_state`` of shape
            (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) mask as returned by the tokenizer
            (non-zero for real tokens, 0 for padding).

    Returns:
        (batch, hidden) tensor holding the masked mean of each sequence.
    """
    hidden_states = model_output.last_hidden_state
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = (hidden_states * weights).sum(dim=1)
    # The clamp keeps an all-padding row from dividing by zero
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts

# Encode Function
def encode_texts(texts, batch_size=8, max_length=512):
    """Embed ``texts`` with the sentence encoder and return L2-normalized vectors.

    Texts are tokenized in batches of ``batch_size``. Each batch is padded to its
    longest text and truncated to ``max_length`` tokens, then run through
    ``core_model`` and mean-pooled over non-padding tokens. Finally each vector
    is unit-normalized, so L2 distance in FAISS ranks like cosine similarity.

    Relies on the notebook globals ``tokenizer``, ``core_model`` and ``device``.

    Args:
        texts: sliceable sequence of strings accepted by the tokenizer.
        batch_size: number of texts encoded per forward pass.
        max_length: truncation length in tokens (default 512, the previously
            hard-coded value).

    Returns:
        ``np.ndarray`` with one row per input text. When ``texts`` is empty, an
        empty ``(0, hidden_size)`` float32 array is returned instead of raising.
    """
    if len(texts) == 0:
        # np.vstack([]) raises ValueError; return a correctly shaped empty array.
        return np.empty((0, core_model.config.hidden_size), dtype=np.float32)

    # fp16 autocast only pays off on CUDA; on the CPU fallback it is slow or
    # unsupported depending on the torch version, so run in full precision there.
    use_autocast = device.type == "cuda"

    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]
        encoded_input = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        ).to(device)

        with torch.no_grad():
            with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_autocast):
                model_output = core_model(**encoded_input)
                embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

        # Normalize embeddings to unit length
        embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True)
        all_embeddings.append(embeddings.cpu().numpy())

    if device.type == "cuda":
        # Release cached blocks once at the end. Doing it after every batch
        # forced the allocator to re-reserve memory on each iteration.
        torch.cuda.empty_cache()

    return np.vstack(all_embeddings)

In [None]:
# Get Passages from Train and Val Datasets
# dict.fromkeys de-duplicates while keeping first-seen order. A plain set's
# iteration order for str depends on per-process hash randomization, which
# would make row i of the FAISS index map to a different passage on every
# kernel restart.
train_contexts = list(dict.fromkeys(
    context for item in train_dataset for context in item["context"]["contexts"]
))
val_contexts = list(dict.fromkeys(
    context for item in val_dataset for context in item["context"]["contexts"]
))

# merge train and val contexts (unique, deterministic order: train first, then new val passages)
all_contexts = list(dict.fromkeys(train_contexts + val_contexts))

In [None]:
# Encode every unique passage into a unit-normalized embedding matrix (one row per passage)
context_embeddings = encode_texts(texts=all_contexts, batch_size=16)

100%|██████████| 3950/3950 [01:23<00:00, 47.37it/s]


In [19]:
# Build FAISS Index
# Exact (brute-force) L2 search. Because the embeddings are unit-normalized,
# L2 ranking is equivalent to cosine-similarity ranking.
index = faiss.IndexFlatL2(context_embeddings.shape[1])
index.add(context_embeddings)
faiss.write_index(index, "../files/faiss.index")

# Save Contexts -- row i of the index corresponds to all_contexts[i]
with open("../files/contexts.json", "w") as f:
    f.write(json.dumps(all_contexts))

### Evaluate Retriever

In [23]:
# Concatenate Datasets
# Stack train and validation rows into a single dataset for retriever evaluation
full_dataset = concatenate_datasets([train_dataset, val_dataset])

In [53]:
def retrieval_evaluation(dataset, index, contexts, k=5, max_samples=20000):
    """Measure how many gold passages the FAISS retriever recovers.

    Each question is searched with a depth equal to its own number of gold
    contexts. The score is the micro-averaged recall at that depth: the total
    number of gold passages retrieved divided by the total number of gold
    passages.

    Args:
        dataset: HF ``Dataset`` whose rows have ``"question"`` and
            ``["context"]["contexts"]`` fields.
        index: FAISS index whose row ``i`` corresponds to ``contexts[i]``.
        contexts: list of passage strings aligned with ``index``.
        k: unused. The search depth is always ``len(gold_contexts)``; the
            parameter is kept so existing calls keep working.
        max_samples: evaluate at most this many rows from the start of
            ``dataset`` (previously hard-coded to 20000, which crashed on
            smaller datasets). ``None`` evaluates every row.

    Returns:
        Float in [0, 1]. Returns 0.0 if no evaluated row has any gold context.
    """
    n_rows = len(dataset) if max_samples is None else min(max_samples, len(dataset))
    total_matched = 0
    total_expected = 0

    for item in tqdm(dataset.select(range(n_rows))):
        gold_contexts = set(item["context"]["contexts"])
        if not gold_contexts:
            # Nothing to recover, and FAISS rejects a zero-depth search.
            continue

        query_embedding = encode_texts([item["question"]])[0]
        _, neighbor_ids = index.search(query_embedding.reshape(1, -1), len(gold_contexts))

        # FAISS pads with -1 when the index holds fewer vectors than requested.
        # Skip those so contexts[-1] is never counted by accident.
        retrieved_contexts = {contexts[i] for i in neighbor_ids[0] if i >= 0}

        # Count the gold passages that appear among the retrieved ones
        total_matched += len(gold_contexts & retrieved_contexts)
        total_expected += len(gold_contexts)

    return total_matched / total_expected if total_expected else 0.0

In [54]:
# Score the retriever on full_dataset (train + validation rows)
retrieval_accuracy = retrieval_evaluation(dataset=full_dataset, index=index, contexts=all_contexts)
print(f"Retrieval Accuracy: {retrieval_accuracy:.4f}")

100%|██████████| 20000/20000 [19:29<00:00, 17.10it/s]

Retrieval Accuracy: 0.5941





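The retriever scores 0.594. On average, about 59% of each question's gold contexts appear among the top-k retrieved passages, where k is that question's number of gold contexts. For example, a question with 5 gold contexts would have roughly 3 of them recovered.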

### Store Retrieved Context Pairs for the Validation Set

In [55]:
# Quick look at the validation split's features and row count before retrieval
val_dataset

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 2000
})

In [None]:
# use the FAISS index to retrieve the contexts for the validation set and store them as a list of dicts

def retrieve_for_validation_set(dataset, index, contexts, k=5):
    """Retrieve the top-``k`` passages for every question in ``dataset``.

    Args:
        dataset: iterable of rows with a ``"question"`` field.
        index: FAISS index whose row ``i`` corresponds to ``contexts[i]``.
        contexts: list of passage strings aligned with ``index``.
        k: number of passages to retrieve per question. Previously this was
            ignored and 5 was always used; the default keeps that behavior.

    Returns:
        List of ``{"question": str, "retrieved_contexts": [str, ...]}`` dicts,
        one per row, in dataset order. Passages keep the order FAISS returns
        them in (closest first for a flat L2 index).
    """
    retrieved_pairs = []
    for item in tqdm(dataset):
        query = item["question"]
        query_embedding = encode_texts([query])[0]

        _, neighbor_ids = index.search(query_embedding.reshape(1, -1), k)

        # FAISS returns -1 for missing neighbors when the index holds fewer than k vectors
        retrieved_contexts = [contexts[i] for i in neighbor_ids[0] if i >= 0]

        # store question and retrieved contexts in a dict object
        retrieved_pairs.append({
            "question": query,
            "retrieved_contexts": retrieved_contexts
        })
    return retrieved_pairs


In [59]:
# Retrieve 5 passages for every validation question
retreived_pairs = retrieve_for_validation_set(dataset=val_dataset, index=index, contexts=all_contexts)

100%|██████████| 2000/2000 [01:55<00:00, 17.25it/s]


In [None]:
# Persist the question / retrieved-passages pairs so later notebooks can load
# them directly instead of re-running retrieval on the validation set.
with open("../files/val_retrieved_pairs_base_1.json", "w") as f:
    f.write(json.dumps(retreived_pairs))
