# Import necessary libraries

In [None]:
"""
Imports libraries for data handling, embedding generation, FAISS-based similarity search,
and neural network computations.
"""
import ast
import json

import faiss
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel


# Load BioBERT model


In [None]:
"""
Loads the BioBERT model and tokenizer for generating embeddings from biomedical texts.
The model used is 'dmis-lab/biobert-base-cased-v1.1'.
"""
print("Loading BioBERT model...")
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
# Use the GPU when one is available; otherwise keep the model on CPU rather
# than raising at load time on a host without CUDA.
biobert_model = AutoModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
print("BioBERT loaded successfully!")


# Generate embeddings and initialize FAISS


In [None]:
"""
Generates embeddings for a list of contexts and initializes a FAISS index for similarity search.

Args:
    contexts (list): List of context strings to generate embeddings for.
    batch_size (int, optional): Number of contexts to process in each batch. Defaults to 16.

Returns:
    tuple: FAISS index and a list of per-batch embedding arrays (one NumPy array per batch).
"""
def initialize_faiss_with_embeddings(contexts, batch_size=16):
    """Embed ``contexts`` with BioBERT and add the vectors to a new FAISS index.

    Each batch is tokenized (padded, truncated to 512 tokens), run through the
    module-level ``biobert_model`` with gradients disabled, and the hidden
    state at sequence position 0 (the [CLS] position) is used as the embedding
    of each text.

    Args:
        contexts (list): Context strings to embed. Vectors are added to the
            index in list order, so a FAISS id is a position in this list.
        batch_size (int, optional): Number of contexts embedded per forward
            pass. Defaults to 16.

    Returns:
        tuple: ``(index, embeddings)`` where ``index`` is a
        ``faiss.IndexFlatIP`` and ``embeddings`` is a list holding one
        float32 NumPy array per batch (empty when ``contexts`` is empty).
    """
    print("Generating embeddings and initializing FAISS index...")
    # Read the vector width and device from the loaded model instead of
    # hard-coding them, so inputs always land on the model's device.
    dimension = biobert_model.config.hidden_size
    device = biobert_model.device
    # Inner-product index. Vectors are added as produced (not L2-normalized),
    # so search scores are raw dot products.
    index = faiss.IndexFlatIP(dimension)

    embeddings = []
    for start in range(0, len(contexts), batch_size):
        batch_texts = contexts[start:start + batch_size]
        inputs = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)
        with torch.no_grad():
            outputs = biobert_model(**inputs)
        # The [:, 0, :] slice can be a strided view (it is when the model runs
        # on CPU); FAISS wants row-major float32, so make that explicit.
        batch_embeddings = np.ascontiguousarray(
            outputs.last_hidden_state[:, 0, :].cpu().numpy(), dtype=np.float32
        )
        index.add(batch_embeddings)
        embeddings.append(batch_embeddings)
        if device.type == "cuda":
            # Hand cached GPU blocks back between batches.
            torch.cuda.empty_cache()

    return index, embeddings


# Load datasets


In [None]:
"""
Loads the MedQA-USMLE, Medical Meadow, and PubMedQA datasets, and combines their contexts.

Returns:
    tuple: DataFrame for MedQA and a combined list of contexts.
"""
def load_datasets(
    medqa_path="NLP_Project/Preprocessed/medqa_usmle_preprocessed.csv",
    medical_meadow_path="NLP_Project/Preprocessed/medical_meadow_preprocessed.csv",
    pubmedqa_path="NLP_Project/Preprocessed/pubmedqa_preprocessed.json",
    medqa_limit=100,
):
    """Load the MedQA questions and build the combined retrieval corpus.

    The corpus is the ``output`` column of the Medical Meadow CSV followed by
    every string listed under ``CONTEXTS`` in each PubMedQA entry.

    Args:
        medqa_path (str, optional): CSV file with the MedQA questions.
        medical_meadow_path (str, optional): CSV file with an ``output`` column.
        pubmedqa_path (str, optional): JSON file holding an object whose
            values may carry a ``CONTEXTS`` list.
        medqa_limit (int or None, optional): Keep only the first
            ``medqa_limit`` MedQA rows; ``None`` keeps all rows.
            Defaults to 100 (a subset for testing).

    Returns:
        tuple: DataFrame for MedQA and a combined list of contexts.
    """
    medqa_df = pd.read_csv(medqa_path)
    if medqa_limit is not None:
        medqa_df = medqa_df.head(medqa_limit)

    medical_meadow_df = pd.read_csv(medical_meadow_path)
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(pubmedqa_path, "r", encoding="utf-8") as f:
        pubmedqa_data = json.load(f)

    # Empty CSV cells are read back as NaN (a float); drop them so every
    # context handed to the tokenizer / ``str.split`` really is a string.
    medical_meadow_contexts = (
        medical_meadow_df['output'].dropna().astype(str).tolist()
    )
    pubmedqa_contexts = [
        context
        for entry in pubmedqa_data.values()
        for context in entry.get('CONTEXTS', [])
    ]

    combined_contexts = medical_meadow_contexts + pubmedqa_contexts
    print(f"Total contexts: {len(combined_contexts)}")
    return medqa_df, combined_contexts


# Hybrid retrieval function


In [None]:
"""
Retrieves the top-k most relevant contexts for each option using FAISS.

Args:
    question (str): Question text.
    options (dict): Dictionary of options (e.g., {"A": "Option A text", ...}).
    faiss_index (faiss.Index): Prebuilt FAISS index for similarity search.
    combined_contexts (list): List of context strings.
    top_k (int, optional): Number of top contexts to retrieve. Defaults to 3.
    token_limit (int, optional): Maximum number of whitespace-separated words kept from each retrieved context. Defaults to 700.

Returns:
    dict: Retrieved contexts with similarity scores for each option.
"""
# Used Copilot for reference
def retrieve_contexts_with_faiss(question, options, faiss_index, combined_contexts, top_k=3, token_limit=700):
    """Retrieve up to ``top_k`` contexts per answer option via FAISS search.

    For every option the query ``"<question> <option text>"`` is embedded with
    the module-level ``tokenizer`` / ``biobert_model`` (hidden state at
    position 0) and searched against ``faiss_index``.

    Args:
        question (str): Question text.
        options (dict): Mapping of option key to option text
            (e.g., {"A": "Option A text", ...}).
        faiss_index (faiss.Index): Index whose ids are positions in
            ``combined_contexts``.
        combined_contexts (list): Context strings the index was built from.
        top_k (int, optional): Number of neighbours requested. Defaults to 3.
        token_limit (int, optional): Maximum number of whitespace-separated
            words kept from each retrieved context. Defaults to 700.

    Returns:
        dict: For each option key, a list of
        ``{"context": str, "similarity_score": score}`` dicts in the order
        FAISS returned them. The list is shorter than ``top_k`` (possibly
        empty) when the index holds fewer than ``top_k`` vectors.
    """
    # Send inputs to whatever device the model lives on.
    device = biobert_model.device
    retrieved_contexts = {}

    for option_key, option_text in options.items():
        query = f"{question} {option_text}"

        # Generate query embedding
        inputs = tokenizer(
            query,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)
        with torch.no_grad():
            cls_state = biobert_model(**inputs).last_hidden_state[:, 0, :]
        # FAISS wants a row-major float32 matrix.
        query_embedding = np.ascontiguousarray(cls_state.cpu().numpy(), dtype=np.float32)

        # FAISS search
        distances, indices = faiss_index.search(query_embedding, top_k)

        # Retrieve contexts with word limit and similarity score
        hits = []
        for idx, dist in zip(indices[0], distances[0]):
            if idx < 0:
                # FAISS pads missing neighbours with label -1; indexing the
                # list with it would silently return the *last* context.
                continue
            context = " ".join(combined_contexts[idx].split()[:token_limit])
            hits.append({"context": context, "similarity_score": dist})
        retrieved_contexts[option_key] = hits

    return retrieved_contexts


# Process questions


In [None]:
"""
Processes each question in the MedQA dataset and retrieves contexts for each option.

Args:
    medqa_df (pd.DataFrame): DataFrame containing MedQA questions and options.
    faiss_index (faiss.Index): Prebuilt FAISS index for similarity search.
    combined_contexts (list): List of context strings.

Returns:
    list: Results containing questions, options, retrieved contexts, and similarity scores.
"""
def process_questions(medqa_df, faiss_index, combined_contexts):
    """Retrieve the best context for options A-D of every MedQA question.

    Args:
        medqa_df (pd.DataFrame): Rows with a ``question`` column and an
            ``options`` column holding the string form of a dict literal
            (e.g. ``"{'A': '...', 'B': '...'}"``).
        faiss_index (faiss.Index): Prebuilt FAISS index for similarity search.
        combined_contexts (list): List of context strings.

    Returns:
        list: One dict per question with the keys ``question`` and, for each
        letter a-d, ``option_<x>``, ``context_<x>`` and ``similarity_<x>``.
        Only the first retrieved hit per option is kept; a placeholder
        context with score 0 is used when an option has no hit.

    Raises:
        ValueError, SyntaxError: If an ``options`` cell is not a valid
            Python literal.
    """
    results = []
    # Count rows by position so the progress message does not depend on the
    # DataFrame's index labels being consecutive integers.
    for position, (_, row) in enumerate(medqa_df.iterrows(), start=1):
        question = row['question']
        # literal_eval accepts only Python literals, so a crafted CSV cell
        # cannot execute code the way eval() would allow.
        options = ast.literal_eval(row['options'])
        print(f"Processing Question {position}: {question}")
        retrieved_contexts = retrieve_contexts_with_faiss(question, options, faiss_index, combined_contexts)

        result_row = {"question": question}
        for letter in "ABCD":
            suffix = letter.lower()
            # ``or`` also covers an option whose hit list came back empty.
            hits = retrieved_contexts.get(letter) or [
                {"context": "No relevant context found.", "similarity_score": 0}
            ]
            best = hits[0]
            result_row[f"option_{suffix}"] = options.get(letter, "")
            result_row[f"context_{suffix}"] = best["context"]
            result_row[f"similarity_{suffix}"] = best["similarity_score"]
        results.append(result_row)
    return results


# Save results


In [None]:
"""
Saves the retrieved results to a CSV file.

Args:
    results (list): List of results containing questions, options, contexts, and similarity scores.
    output_path (str, optional): File path to save the results. Defaults to 'NLP_Project/Preprocessed/HY_bio_faiss1.csv'.
"""
def save_results(results, output_path="NLP_Project/Preprocessed/HY_bio_faiss1.csv"):
    """Write ``results`` to ``output_path`` as CSV and report where it went.

    Args:
        results (list): Row dicts (questions, options, contexts, scores).
        output_path (str, optional): Destination CSV file.
            Defaults to 'NLP_Project/Preprocessed/HY_bio_faiss1.csv'.
    """
    frame = pd.DataFrame(results)
    # index=False: keep the positional row index out of the file.
    frame.to_csv(output_path, index=False)
    print(f"Results saved to: {output_path}")


# Main workflow

In [None]:
"""
Main script to load datasets, initialize FAISS index with BioBERT embeddings,
process questions, and save results.
"""
if __name__ == "__main__":
    # Step 1: read the MedQA questions and the combined retrieval corpus.
    medqa_df, combined_contexts = load_datasets()

    # Step 2: embed the corpus with BioBERT and build the FAISS index.
    # The per-batch embeddings returned alongside the index are not needed here.
    faiss_index, _unused_embeddings = initialize_faiss_with_embeddings(combined_contexts)

    # Step 3: retrieve the best context for each answer option of each question.
    results = process_questions(medqa_df, faiss_index, combined_contexts)

    # Step 4: write everything to the default CSV location.
    save_results(results)
