In [21]:
from transformers import AutoTokenizer, AutoModel
import torch
import os
import pandas as pd
import faiss


torch.backends.cudnn.benchmark = True  # Improves performance when input size is constant

# Function to load documents from a directory containing Parquet files
def load_documents_from_parquet(directory, title_column="title", content_column="content"):
    documents = {}
    doc_id = 0
    for filename in os.listdir(directory):
        if filename.endswith(".parquet"):
            filepath = os.path.join(directory, filename)
            try:
                df = pd.read_parquet(filepath)
                if title_column not in df.columns or content_column not in df.columns:
                    raise ValueError(f"Columns '{title_column}' or '{content_column}' not found in Parquet file: {filename}")
                for index, row in df.iterrows():
                    # Combine title and content for each document
                    combined_text = f"{row[title_column]} {row[content_column]}"
                    documents[doc_id] = combined_text
                    doc_id += 1
            except Exception as e:
                print(f"Error reading Parquet file {filename}: {e}")
    return documents

class MedRAG:
    def __init__(self, llm_name, rag, retriever_name, corpus_dir, title_column="title", content_column="content"):
        self.llm_name = llm_name
        self.rag = rag
        self.retriever_name = retriever_name
        self.corpus_dir = corpus_dir
        self.title_column = title_column
        self.content_column = content_column

        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
        self.model = AutoModel.from_pretrained(llm_name)

        # Use GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Load the corpus from Parquet files
        self.corpus = load_documents_from_parquet(corpus_dir, title_column=self.title_column, content_column=self.content_column)
        if not self.corpus:
            raise ValueError(f"No documents found in directory: {corpus_dir} or columns '{self.title_column}'/'{self.content_column}' are missing")
        print(f"Loaded {len(self.corpus)} documents from {corpus_dir}")

        # Compute embeddings in batches and build FAISS index
        self.embedding_dim = self.model.config.hidden_size
        self.index = faiss.IndexFlatIP(self.embedding_dim)  # Index for cosine similarity (Inner Product)

        batch_size = 16  # Adjust batch size for optimal performance on RTX 4070
        self.corpus_embeddings = self.compute_embeddings_batched(list(self.corpus.values()), batch_size)
        self.index.add(self.corpus_embeddings.numpy())

    def compute_embeddings_batched(self, texts, batch_size):
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512)
            inputs = {key: value.to(self.device) for key, value in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
            embeddings = outputs.last_hidden_state[:, 0, :].cpu()
            all_embeddings.append(embeddings)
        return torch.cat(all_embeddings, dim=0)

    def answer(self, question, k):
        if self.rag:
            snippets, scores, doc_ids = self.retrieve(question, k)
        else:
            snippets, scores, doc_ids = [], [], []

        # Generate answer based on retrieved snippets
        selected_answer = self.select_answer(question, snippets)  # No options passed

        return selected_answer, snippets, scores, doc_ids

    def retrieve(self, question, k):
        # Compute the question embedding
        question_embedding = self.compute_embeddings_batched([question], batch_size=1)  # Use the batched function
        question_embedding = question_embedding[0].unsqueeze(0) #.numpy()

        # Search the FAISS index
        D, I = self.index.search(question_embedding.numpy(), k)

        # Retrieve snippets and scores
        doc_ids = I[0].tolist()
        snippets = [list(self.corpus.values())[i] for i in doc_ids]
        scores = D[0].tolist()

        return snippets, scores, doc_ids

    def select_answer(self, question, snippets):
        # Combine question and snippets for context
        context = f"Question: {question}\n"
        if snippets:
            context += "\nContext:\n" + "\n".join(snippets)
        context += "\nAnswer:"

        # Tokenize and encode the context
        inputs = self.tokenizer(context, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate the answer using the model
        with torch.no_grad():
            # Use model's generate method with appropriate parameters
            outputs = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=100,  # Adjust as needed
                do_sample=True,
                top_k=50,  # Adjust as needed
                top_p=0.95,  # Adjust as needed
            )

        # Decode the generated answer
        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        return answer

In [22]:
# --- Example Usage (Assuming data_RAG is set up) ---
question = "A lesion causing compression of the facial nerve at the stylomastoid foramen will cause ipsilateral"

# --- IMPORTANT:
# Ensure your 'data_RAG' directory exists and contains your Parquet files.
# ---

medrag = MedRAG(llm_name="facebook/bart-base",
                rag=True,
                retriever_name="FaissRetriever",  # Indicate we're using FAISS
                corpus_dir="data_RAG",
                title_column="title",
                content_column="content")

answer, snippets, scores, doc_ids = medrag.answer(question=question, k=3)

print(f"Selected Answer: {answer}")
print(f"Retrieved Snippets: {snippets}")
print(f"Retrieval Scores: {scores}")
print(f"Retrieved Doc IDs: {doc_ids}")

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Loaded 213527 documents from data_RAG


KeyboardInterrupt: 

In [8]:
import pickle
import faiss
import numpy as np

# ... (Your MedRAG class and other code) ...

# --- Saving Embeddings, FAISS Index, and Corpus ---

def save_medrag_data(medrag, save_dir):
    """Saves the MedRAG model's embeddings, FAISS index, and corpus mapping.

    Args:
        medrag: The MedRAG instance.
        save_dir: The directory to save the data to.
    """
    os.makedirs(save_dir, exist_ok=True)

    # 1. Save Embeddings (as a NumPy array)
    embeddings_file = os.path.join(save_dir, "embeddings.npy")
    np.save(embeddings_file, medrag.corpus_embeddings.numpy())
    print(f"Embeddings saved to {embeddings_file}")

    # 2. Save FAISS Index
    index_file = os.path.join(save_dir, "faiss.index")
    faiss.write_index(medrag.index, index_file)
    print(f"FAISS index saved to {index_file}")

    # 3. Save Corpus (using Pickle)
    corpus_file = os.path.join(save_dir, "corpus.pkl")
    with open(corpus_file, "wb") as f:
        pickle.dump(medrag.corpus, f)
    print(f"Corpus mapping saved to {corpus_file}")

    print(f"All MedRAG data saved to {save_dir}")

# --- Example of Saving ---

# Assuming you have initialized your MedRAG instance as 'medrag'
save_dir = "medrag_data_saved_embeddings_and_faiss_index"  # Choose a directory to save your data
save_medrag_data(medrag, save_dir)

Embeddings saved to medrag_data_saved_embeddings_and_faiss_index/embeddings.npy
FAISS index saved to medrag_data_saved_embeddings_and_faiss_index/faiss.index
Corpus mapping saved to medrag_data_saved_embeddings_and_faiss_index/corpus.pkl
All MedRAG data saved to medrag_data_saved_embeddings_and_faiss_index


In [9]:
# --- Example Usage (Assuming data_RAG is set up) ---
question = "I have a question about lung cancer. What are the risk factors?"

# --- IMPORTANT:
# Ensure your 'data_RAG' directory exists and contains your Parquet files.
# ---

answer, snippets, scores, doc_ids = medrag.answer(question=question, k=3)

print(f"Selected Answer: {answer}")
print(f"Retrieved Snippets: {snippets}")
print(f"Retrieval Scores: {scores}")
print(f"Retrieved Doc IDs: {doc_ids}")

Selected Answer: A
Retrieved Snippets: ['International predictive testing for NCLEX-RN. Rewards and risks. The United States still relies heavily on foreign nurses, and that means heavy reliance on them passing the state board examinations. Can we predict successful performance? Wells tells us the risks and rewards.', "Credit where credit's due. How to borrow money effectively. Neither a borrower or a lender be, says the old proverb. However, used effectively, credit can help you to make the most of your money - so long as you are careful!", 'Missouri leads the nation: and that needs to change. In Missouri, legislation to control tobacco use severely lags behind other states who have beaten us to the punch by passing such laws. It is time to make some changes, and you can help.']
Retrieval Scores: [217.67137145996094, 216.8662109375, 216.736083984375]
Retrieved Doc IDs: [91684, 130427, 174199]


In [None]:
# --- Example Usage (Assuming data_RAG is set up) ---
question = "My knees hurt when I run. What could be the cause?"

# --- IMPORTANT:
# Ensure your 'data_RAG' directory exists and contains your Parquet files.
# ---

answer, snippets, scores, doc_ids = medrag.answer(question=question, k=3)

print(f"Selected Answer: {answer}")
print(f"Retrieved Snippets: {snippets}")
print(f"Retrieval Scores: {scores}")
print(f"Retrieved Doc IDs: {doc_ids}")

Selected Answer: A
Retrieved Snippets: ['Missouri leads the nation: and that needs to change. In Missouri, legislation to control tobacco use severely lags behind other states who have beaten us to the punch by passing such laws. It is time to make some changes, and you can help.', "Boiling snow in north Romania. One day Andy Beazley was a busy dentist, the next he found himself setting off for a week in a Romanian children's hospital, loaded to the brim with blankets, bedding and toys. When he goes back in March, he hopes to take a couple of Portakabins and some drainage equipment. What did he discover while he was there and how healthy are Romanian teeth?", 'Res ipsa loquitur--putting the health care provider on the hot seat. Just when you thought you knew all of the fine points of malpractice, McMullen describes a new way of getting sued. Fortunately, she also suggests ways to avoid it.']
Retrieval Scores: [217.4948272705078, 217.46316528320312, 217.03416442871094]
Retrieved Doc IDs

In [20]:
# --- Example Usage (Assuming data_RAG is set up) ---
question = "I have a cold. What could be the cause?"

# --- IMPORTANT:
# Ensure your 'data_RAG' directory exists and contains your Parquet files.
# ---

answer, snippets, scores, doc_ids = medrag.answer(question=question, k=3)

print(f"Selected Answer: {answer}")
print(f"Retrieved Snippets: {snippets}")
print(f"Retrieval Scores: {scores}")
print(f"Retrieved Doc IDs: {doc_ids}")

Selected Answer: B
Retrieved Snippets: ["Boiling snow in north Romania. One day Andy Beazley was a busy dentist, the next he found himself setting off for a week in a Romanian children's hospital, loaded to the brim with blankets, bedding and toys. When he goes back in March, he hopes to take a couple of Portakabins and some drainage equipment. What did he discover while he was there and how healthy are Romanian teeth?", "Credit where credit's due. How to borrow money effectively. Neither a borrower or a lender be, says the old proverb. However, used effectively, credit can help you to make the most of your money - so long as you are careful!", 'Missouri leads the nation: and that needs to change. In Missouri, legislation to control tobacco use severely lags behind other states who have beaten us to the punch by passing such laws. It is time to make some changes, and you can help.']
Retrieval Scores: [220.4746551513672, 220.32826232910156, 219.60935974121094]
Retrieved Doc IDs: [134969