# In [None]:
"""
===============================================================================
RAG pipeline overview (unaltered logic, comments only)
-------------------------------------------------------------------------------
This script implements a simple Retrieval-Augmented Generation (RAG) workflow:
- Load a preprocessed medical QA dataset and its precomputed answer embeddings. 
- Normalize document embeddings and build a FAISS index (inner product as cosine).
- Load DPR question encoder/tokenizer to embed incoming questions for retrieval.
- Retrieve top-k relevant answer passages from FAISS using the question vector.
- Compose a prompt with the retrieved context and call a Groq model to generate
  an educational answer, returning the model's response.
Notes:
- The code below is kept identical in logic to the original and annotated using
  large triple-quoted comments and brief inline notes for clarity.
===============================================================================
"""

# med_qa_pipeline_groq_clean.py

"""
Imports:
- numpy/pandas for array and tabular handling
- faiss for fast nearest-neighbor search on dense vectors
- torch for DPR model execution and device selection (CPU/GPU)
- transformers to load DPRQuestionEncoder and tokenizer
- groq client for optional answer generation step
"""
import os

import faiss
import numpy as np
import pandas as pd
import torch
from groq import Groq
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# -------------------------------
# 1. Load data and prebuilt embeddings
# -------------------------------

"""
Step 1 — Data and embeddings:
- Reads a preprocessed dataset (expected to contain 'answer_clean')
- Loads precomputed document embeddings from disk
- L2-normalizes embeddings so that inner product ≈ cosine similarity
- Builds an in-memory FAISS IndexFlatIP for retrieval
"""
# Load the preprocessed dataset; retrieval later returns entries of `docs`
# selected by row position.
df = pd.read_csv("data/medquad_processed.csv")  # preprocessed dataset
docs = df["answer_clean"].astype(str).tolist()

# FAISS only accepts C-contiguous float32 matrices, so coerce whatever dtype
# was saved on disk before handing the array to the index.
encoded_docs = np.ascontiguousarray(
    np.load("embeddings/encoded_docs.npy"), dtype=np.float32
)

# L2-normalize each row so that inner product behaves like cosine similarity.
# Zero-norm rows are divided by 1 (left as zeros) instead of turning into NaN.
doc_norms = np.linalg.norm(encoded_docs, axis=1, keepdims=True)
doc_norms[doc_norms == 0] = 1.0
encoded_docs = encoded_docs / doc_norms

# FAISS ids are row positions, so they only map back to `docs` correctly when
# row i of the embedding matrix was computed from docs[i].
if encoded_docs.shape[0] != len(docs):
    print(
        f"[WARN] {encoded_docs.shape[0]} embeddings but {len(docs)} documents; "
        "retrieved ids may not line up with the dataset rows."
    )

dimension = encoded_docs.shape[1]
index = faiss.IndexFlatIP(dimension)
index.add(encoded_docs)

print("[INFO] FAISS index built successfully (in-memory).")

# -------------------------------
# 2. Load DPR encoders
# -------------------------------

"""
Step 2 — DPR question encoder:
- Selects device automatically (GPU if available, else CPU)
- Loads the DPR question encoder and tokenizer for query embedding
- These are used to embed incoming questions at inference time
"""
# Run the encoder on the GPU when one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("[INIT] Loading DPR encoders...")
question_encoder = DPRQuestionEncoder.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
).to(device)

# The tokenizer is loaded from the same checkpoint as the encoder.
# (The original call was missing its closing parenthesis, which made the
# whole file a syntax error.)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
)

print("[INFO] DPR encoders loaded successfully.")

# -------------------------------
# 3. Setup Groq client
# -------------------------------

"""
Step 3 — Groq client:
- Initializes the Groq client using the provided API key
- Used later to generate an educational answer from retrieved context
"""

# Read the API key from the environment: no `api_key` name is defined anywhere
# in this script, and keeping the secret out of the source avoids leaking it.
api_key = os.environ.get("GROQ_API_KEY")
if not api_key:
    raise RuntimeError(
        "GROQ_API_KEY environment variable is not set; "
        "export it before running this script."
    )
client = Groq(api_key=api_key)

# -------------------------------
# 4. Helper functions
# -------------------------------

"""
Helper: retrieve_context(question, top_k)
- Tokenizes and encodes the question with DPR
- L2-normalizes the query embedding
- Searches the FAISS index for top_k nearest documents
- Returns a single string by concatenating the retrieved answer passages
"""
def retrieve_context(question: str, top_k: int = 5) -> str:
    """Retrieve the passages most similar to a question from the FAISS index.

    The question is embedded with the module-level DPR question encoder,
    L2-normalized, and searched against the module-level ``index``.

    Args:
        question: Natural-language question to embed and search with.
        top_k: Maximum number of passages to retrieve. Values larger than the
            number of indexed vectors are clamped to the index size.

    Returns:
        The retrieved passages from ``docs`` joined by single spaces, in the
        order FAISS returned them. An empty string when ``top_k`` is not
        positive or the index is empty.
    """
    # Never ask for more neighbours than exist: FAISS pads missing results
    # with id -1, and docs[-1] would silently return the last document.
    k = min(int(top_k), index.ntotal)
    if k <= 0:
        return ""

    # Tokenize the question for the DPR encoder.
    inputs = question_tokenizer(
        question, return_tensors="pt", truncation=True, max_length=512
    ).to(device)

    # Encode the question (no gradients needed for inference).
    with torch.no_grad():
        q_emb = question_encoder(**inputs).pooler_output.cpu().numpy()

    # Normalize the query to match the normalized document vectors; a
    # zero-norm query is divided by 1 instead of producing NaN.
    q_norm = np.linalg.norm(q_emb, axis=1, keepdims=True)
    q_norm[q_norm == 0] = 1.0
    q_emb = np.ascontiguousarray(q_emb / q_norm, dtype=np.float32)

    # Only the ids are needed; the similarity scores are discarded.
    _, indices = index.search(q_emb, k)

    # Drop padding ids (-1) and any id that has no matching row in `docs`.
    retrieved_texts = [docs[i] for i in indices[0] if 0 <= i < len(docs)]

    # Concatenate retrieved texts into one retrieval context.
    return " ".join(retrieved_texts)

"""
Helper: generate_answer_groq(question, context)
- Crafts a prompt that instructs the model to answer using the retrieved context
- Calls Groq chat.completions with a small temperature and token cap
- Returns the model's text response; on exception, returns an error string
"""
def generate_answer_groq(
    question: str,
    context: str,
    *,
    model: str = "llama-3.1-8b-instant",
    temperature: float = 0.3,
    max_tokens: int = 300,
) -> str:
    """Generate an educational medical answer with a Groq chat model.

    Args:
        question: The user's question, inserted verbatim into the prompt.
        context: Retrieved passages the model is told to rely on primarily.
        model: Groq model identifier passed to ``chat.completions.create``.
        temperature: Sampling temperature for the completion.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The stripped text of the first completion choice. If the API call
        fails or yields no text, a string starting with
        ``"[ERROR calling Groq API]"`` is returned instead of raising.
    """
    prompt = f"""
You are a knowledgeable medical assistant designed for educational and informational purposes only.
Your task is to provide clear, factually accurate, and educational answers.
Follow these instructions carefully:
1. Use the provided context primarily to form your answer.
2. If the context does not fully answer the question, provide a brief, logical, and educational explanation using your general medical understanding.
3. Indicate which parts of the provided context were most relevant to your answer.
4. Do NOT give warnings like "I cannot provide medical advice" — instead, frame everything as educational information.
Context:
{context}
Question:
{question}
Answer (for educational purposes only):
""".strip()

    # Deliberately broad: this is the boundary to an external service, and
    # callers expect an error string rather than an exception.
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        content = response.choices[0].message.content
    except Exception as e:
        return f"[ERROR calling Groq API] {e}"

    # The message content can be absent; report that explicitly instead of
    # failing on None.strip().
    if content is None:
        return "[ERROR calling Groq API] empty response content"
    return content.strip()

# -------------------------------
# 5. Main QA function
# -------------------------------

"""
ask_question(question, top_k)
- Full pipeline entry:
  1) Retrieve top_k contexts using DPR + FAISS
  2) Generate an educational answer from the retrieved context via Groq
- Returns the final text answer
"""
def ask_question(question: str, top_k: int = 5):
    """Answer a question end to end: retrieve passages, then generate.

    Args:
        question: The user's question.
        top_k: Number of passages to request from ``retrieve_context``.

    Returns:
        Whatever ``generate_answer_groq`` returns for the question and the
        retrieved context.
    """
    return generate_answer_groq(
        question,
        retrieve_context(question, top_k=top_k),
    )

# -------------------------------
# 6. Interactive Run
# -------------------------------

"""
CLI loop:
- Repeatedly reads a question from stdin
- Calls ask_question and prints the result
- Type 'exit' to terminate the session
"""
if __name__ == "__main__":
    # Interactive loop: read a question, answer it, repeat until 'exit'.
    print("\n=== 🩺 Medical QA Assistant (Groq + DPR + FAISS) ===\n")
    print("Type your medical question below. Type 'exit' to quit.\n")

    while True:
        try:
            question = input("💬 Enter your question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C ends the session cleanly, not with a
            # traceback.
            print("\nExiting... Stay healthy! 🫶")
            break

        if question.lower() == "exit":
            print("Exiting... Stay healthy! 🫶")
            break

        # A blank line carries no question; skip it rather than spending a
        # retrieval and an API call on empty input.
        if not question:
            continue

        final_answer = ask_question(question)
        print("\n🩺 Question:", question)
        print("\n💬 Answer:", final_answer)
        print("\n" + "-" * 60 + "\n")
