# Import necessary libraries

In [None]:

"""
Imports libraries for data processing, BM25 retrieval, BioBERT embedding generation,
and hybrid retrieval tasks.
"""
import ast
import heapq
import json
from pathlib import Path

import pandas as pd
import torch
from rank_bm25 import BM25Okapi
from transformers import AutoModel, AutoTokenizer


# Load BioBERT model


In [None]:
# Load the BioBERT model used to generate contextual embeddings for biomedical
# text ('dmis-lab/biobert-base-cased-v1.1'). The model is placed on the GPU when
# CUDA is available and falls back to the CPU otherwise, instead of crashing on
# machines without a GPU.
BIOBERT_MODEL_NAME = "dmis-lab/biobert-base-cased-v1.1"
BIOBERT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading BioBERT model...")
tokenizer = AutoTokenizer.from_pretrained(BIOBERT_MODEL_NAME)
biobert_model = AutoModel.from_pretrained(BIOBERT_MODEL_NAME).to(BIOBERT_DEVICE)
print("BioBERT loaded successfully!")


# Initialize BM25 and BioBERT embeddings


In [None]:
def initialize_bm25_and_embeddings(contexts, batch_size=16):
    """
    Build a BM25 index over the contexts and compute a BioBERT embedding for each one.

    Args:
        contexts (list): List of context strings to index with BM25 and embed with BioBERT.
        batch_size (int, optional): Number of contexts tokenized and encoded per forward
            pass; lower it if the device runs out of memory. Defaults to 16.

    Returns:
        tuple: ``(bm25, embeddings)`` where ``bm25`` is a ``BM25Okapi`` index over the
        whitespace-split contexts and ``embeddings`` is a CPU tensor with one row per
        context (the hidden state at position 0, i.e. the [CLS] token), in the same
        order as ``contexts``.

    Raises:
        ValueError: If ``contexts`` is empty (there is nothing to index or embed) or
            ``batch_size`` is not positive.
    """
    if not contexts:
        raise ValueError("contexts must contain at least one document")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    print("Initializing BM25...")
    bm25 = BM25Okapi([doc.split() for doc in contexts])
    print("BM25 initialized!")

    print("Generating BioBERT embeddings for contexts...")
    # Send inputs to whichever device the model lives on instead of hard-coding "cuda".
    device = next(biobert_model.parameters()).device
    embeddings = []
    for start in range(0, len(contexts), batch_size):
        batch_texts = contexts[start:start + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs = biobert_model(**inputs)
        # Keep the position-0 ([CLS]) vector and move it to the CPU so the full
        # embedding matrix does not have to stay resident on the GPU.
        embeddings.append(outputs.last_hidden_state[:, 0, :].cpu())
        torch.cuda.empty_cache()  # release cached GPU blocks between batches
    print("BioBERT embeddings generated!")
    return bm25, torch.cat(embeddings, dim=0)


# Load datasets


In [None]:
def load_datasets(
    medqa_path="NLP_Project/Preprocessed/medqa_usmle_preprocessed.csv",
    medical_meadow_path="NLP_Project/Preprocessed/medical_meadow_preprocessed.csv",
    pubmedqa_path="NLP_Project/Preprocessed/pubmedqa_preprocessed.json",
    n_questions=100,
):
    """
    Load MedQA-USMLE questions and build the combined Medical Meadow + PubMedQA context corpus.

    Args:
        medqa_path (str, optional): CSV file with the MedQA questions.
        medical_meadow_path (str, optional): CSV file whose ``output`` column holds contexts.
        pubmedqa_path (str, optional): JSON file mapping ids to entries with a ``CONTEXTS`` list.
        n_questions (int | None, optional): Keep only the first ``n_questions`` MedQA rows
            (a subset for quick testing). ``None`` keeps every row. Defaults to 100.

    Returns:
        tuple: DataFrame of MedQA questions and a list of context strings (Medical Meadow
        contexts first, then PubMedQA contexts). Missing (NaN) and blank contexts are
        dropped, because the downstream whitespace tokenization needs real strings.
    """
    def _usable(text):
        # pandas reads empty CSV cells as float NaN, which has no .split(); blank
        # strings contribute nothing to retrieval.
        return isinstance(text, str) and bool(text.strip())

    medqa_df = pd.read_csv(medqa_path)
    if n_questions is not None:
        medqa_df = medqa_df.head(n_questions)  # subset for testing
    medical_meadow_df = pd.read_csv(medical_meadow_path)
    with open(pubmedqa_path, "r", encoding="utf-8") as f:
        pubmedqa_data = json.load(f)

    medical_meadow_contexts = [text for text in medical_meadow_df['output'].tolist() if _usable(text)]
    pubmedqa_contexts = [
        context
        for entry in pubmedqa_data.values()
        for context in entry.get('CONTEXTS', [])
        if _usable(context)
    ]

    combined_contexts = medical_meadow_contexts + pubmedqa_contexts
    print(f"Total contexts: {len(combined_contexts)}")
    return medqa_df, combined_contexts


# Hybrid retrieval function


In [None]:
def hybrid_retrieve_contexts(question, options, bm25, context_embeddings, combined_contexts, top_k=3, token_limit=700, bm25_candidates=50):
    """
    Retrieve contexts for each answer option with BM25 and re-rank them with BioBERT.

    For every option the query is ``"{question} {option_text}"``. BM25 keeps the
    ``bm25_candidates`` highest-scoring contexts. Those candidates are re-ranked by cosine
    similarity between the query's BioBERT position-0 ([CLS]) embedding and the
    precomputed context embeddings. The ``top_k`` best re-ranked contexts are joined,
    best first, and the result is truncated to ``token_limit`` whitespace-separated tokens.

    Args:
        question (str): Question text.
        options (dict): Dictionary of options (e.g., {"A": "Option A text", ...}).
        bm25 (BM25Okapi): BM25 instance for term-based retrieval.
        context_embeddings (torch.Tensor): Precomputed BioBERT embeddings for all contexts,
            row ``i`` matching ``combined_contexts[i]``.
        combined_contexts (list): List of context strings.
        top_k (int, optional): Number of re-ranked contexts to keep per option; values below
            1 are treated as 1. Defaults to 3.
        token_limit (int, optional): Maximum number of whitespace tokens in the returned
            context string. Defaults to 700.
        bm25_candidates (int, optional): Size of the BM25 candidate pool passed to the
            BioBERT re-ranker. Defaults to 50.

    Returns:
        tuple: ``(retrieved_contexts, similarity_scores)``, two dicts keyed by option key.
        ``retrieved_contexts[key]`` is the joined, truncated context text, and
        ``similarity_scores[key]`` is the cosine similarity of the best-ranked context.
        Options with no candidate context (empty corpus) are omitted from both dicts.
    """
    retrieved_contexts = {}
    similarity_scores = {}
    # Send query inputs to whichever device the model lives on instead of hard-coding "cuda".
    device = next(biobert_model.parameters()).device

    for option_key, option_text in options.items():
        query = f"{question} {option_text}"

        # BM25 stage. heapq.nlargest(n, it, key) is equivalent to
        # sorted(it, key=key, reverse=True)[:n] but avoids sorting the whole corpus.
        bm25_scores = bm25.get_scores(query.split())
        candidate_indices = heapq.nlargest(bm25_candidates, range(len(bm25_scores)), key=bm25_scores.__getitem__)
        if not candidate_indices:
            continue

        # BioBERT stage: re-rank candidates by cosine similarity of the position-0 embeddings.
        candidate_embeddings = torch.stack([context_embeddings[i] for i in candidate_indices])
        inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        with torch.no_grad():
            query_embedding = biobert_model(**inputs).last_hidden_state[:, 0, :].cpu()
        cosine_similarities = torch.nn.functional.cosine_similarity(query_embedding, candidate_embeddings, dim=1)

        # Keep the top_k re-ranked contexts (best first), then cap the combined length.
        k = min(max(top_k, 1), len(candidate_indices))
        top_scores, top_positions = torch.topk(cosine_similarities, k)
        selected = [combined_contexts[candidate_indices[pos]] for pos in top_positions.tolist()]
        truncated_context = ' '.join(' '.join(selected).split()[:token_limit])

        retrieved_contexts[option_key] = truncated_context
        similarity_scores[option_key] = top_scores[0].item()

    return retrieved_contexts, similarity_scores


# Process questions


In [None]:
def process_questions(medqa_df, bm25, context_embeddings, combined_contexts):
    """
    Retrieve contexts for every question/option pair in the MedQA DataFrame.

    Args:
        medqa_df (pd.DataFrame): DataFrame with a ``question`` column and an ``options``
            column holding a dict of options or its string representation.
        bm25 (BM25Okapi): BM25 instance for term-based retrieval.
        context_embeddings (torch.Tensor): Precomputed BioBERT embeddings for all contexts.
        combined_contexts (list): List of context strings.

    Returns:
        list: One dict per question with keys ``question`` and, for each option letter
        A-D, ``option_x`` / ``context_x`` / ``similarity_x``. Missing options, contexts
        and scores fall back to ``""``, ``"No relevant context found."`` and ``"N/A"``.
    """
    results = []
    for idx, row in medqa_df.iterrows():
        question = row['question']
        raw_options = row['options']
        # The CSV stores options as the repr of a dict. ast.literal_eval only accepts
        # Python literals, unlike eval(), which would execute any code in the data file.
        options = ast.literal_eval(raw_options) if isinstance(raw_options, str) else raw_options
        print(f"Processing Question {idx+1}: {question}")
        retrieved_contexts, similarity_scores = hybrid_retrieve_contexts(question, options, bm25, context_embeddings, combined_contexts)

        # Build the row in the same column order as before: question, then
        # option/context/similarity for A, B, C, D.
        result_row = {"question": question}
        for key in ("A", "B", "C", "D"):
            suffix = key.lower()
            result_row[f"option_{suffix}"] = options.get(key, "")
            result_row[f"context_{suffix}"] = retrieved_contexts.get(key, "No relevant context found.")
            result_row[f"similarity_{suffix}"] = similarity_scores.get(key, "N/A")
        results.append(result_row)
    return results


# Save results


In [None]:
def save_results(results, output_path="NLP_Project/Retrieved/TRY/hy_bm_bio1.csv"):
    """
    Save the retrieval results to a CSV file, creating parent directories if needed.

    Args:
        results (list): List of result dicts (one per question), written as CSV rows.
        output_path (str, optional): File path to save the results.
            Defaults to 'NLP_Project/Retrieved/TRY/hy_bm_bio1.csv'.
    """
    # to_csv does not create missing directories; make sure the nested path exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(results).to_csv(output_path, index=False)
    print(f"Results saved to: {output_path}")


# Main workflow

In [None]:

# Main workflow: load the data, build both retrievers, retrieve contexts for
# every MedQA question, and write the results to CSV.
if __name__ == "__main__":
    # Step 1: MedQA questions plus the combined Medical Meadow / PubMedQA corpus.
    medqa_df, combined_contexts = load_datasets()

    # Step 2: sparse BM25 index and dense BioBERT embeddings over the same corpus.
    bm25, context_embeddings = initialize_bm25_and_embeddings(contexts=combined_contexts)

    # Step 3: hybrid retrieval for each question/option pair.
    results = process_questions(
        medqa_df=medqa_df,
        bm25=bm25,
        context_embeddings=context_embeddings,
        combined_contexts=combined_contexts,
    )

    # Step 4: write everything to the default output location.
    save_results(results=results)
