In [3]:
# Install dependencies
!pip install transformers accelerate langgraph faiss-cpu sentence-transformers

# Imports
from langgraph.graph import StateGraph
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, CrossEncoder
import pandas as pd
import numpy as np
import faiss
import torch
from typing import TypedDict

# Load Data
print("Loading MedMCQA dataset and FAISS index...")
df = pd.read_csv("medmcqa_dataframe.csv")
index = faiss.read_index("medmcqa_index.faiss")
embeddings = np.load("medmcqa_embeddings.npy")

# For FAISS vector search
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
# For re-ranking
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# LLM for response generation
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Retrieval with FAISS + CrossEncoder
def retrieve_with_reranking(query, k=10):
    query_vec = bi_encoder.encode([query])
    distances, indices = index.search(np.array(query_vec), k)

    # Get candidate questions
    candidate_rows = df.iloc[indices[0]]
    questions = candidate_rows["question"].tolist()
    pairs = [[query, q] for q in questions]

    # CrossEncoder reranking
    scores = cross_encoder.predict(pairs)
    best_idx = int(np.argmax(scores))
    best_confidence = float(scores[best_idx])  # CrossEncoder scores are already normalized [0-1]

    return indices[0][best_idx], best_confidence

# Response Formatter
def format_response(state):
    row = df.iloc[state["index"]]
    options = [row["opa"], row["opb"], row["opc"], row["opd"]]
    correct_idx = row["cop"]
    correct_option = ['A', 'B', 'C', 'D'][correct_idx]
    correct_text = options[correct_idx]

    explanation = str(row.get("exp", "")).strip()
    subject = str(row.get("subject_name", "")).strip()
    topic = str(row.get("topic_name", "")).strip()

    response = f"""Question:
{row['question']}

Options:
A. {options[0]}
B. {options[1]}
C. {options[2]}
D. {options[3]}

Correct Answer: {correct_option}. {correct_text}
"""
    if explanation:
        response += f"\nExplanation: {explanation}"
    if subject:
        response += f"\nSubject: {subject}"
    if topic:
        response += f"\nTopic: {topic}"

    return {"answer": response.strip()}

# LangGraph Nodes
def receive_query(state):
    return {"query": state["query"]}

def search_knowledge(state):
    query = state["query"]
    index, confidence = retrieve_with_reranking(query, k=10)
    return {"index": index, "confidence": confidence, "query": query}

def reject_if_not_confident(state):
    if state["confidence"] < 0.75:
        return {"answer": "I'm sorry, I couldn't find a confident answer in my dataset."}
    return state

# LangGraph Schema
class ChatState(TypedDict):
    query: str
    index: int
    confidence: float
    answer: str

chat_graph = StateGraph(ChatState)
chat_graph.add_node("Input", receive_query)
chat_graph.add_node("Search", search_knowledge)
chat_graph.add_node("Reject", reject_if_not_confident)
chat_graph.add_node("Respond", format_response)

chat_graph.set_entry_point("Input")
chat_graph.add_edge("Input", "Search")
chat_graph.add_edge("Search", "Reject")
chat_graph.add_edge("Reject", "Respond")

chatbot = chat_graph.compile()

# Interactive Mode
while True:
    user_input = input("Ask a medical question in the openlifescienceai/medmcqa dataset (or type 'exit' to quit): ")
    if user_input.lower() in ["exit", "quit"]:
        print("Chat Ended")
        break
    result = chatbot.invoke({"query": user_input})
    print("\nAnswer:\n")
    print(result["answer"])
    print("-" * 50)

Loading MedMCQA dataset and FAISS index...
Ask a medical question in the openlifescienceai/medmcqa dataset (or type 'exit' to quit): 	 The radiograph of a 32 year old patient is shown below. The patient is asymptomatic and the lesion revealed in the radiograph is an accidental finding. The most likely diagnosis is:

Answer:

Question:
The radiograph of a 32 year old patient is shown below. The patient is asymptomatic and the lesion revealed in the radiograph is an accidental finding. The most likely diagnosis is:

Options:
A. Stafne’s bone cavity
B. Radicular Cyst
C. Dentigerous cyst
D. Lateral periodontal cyst

Correct Answer: A. Stafne’s bone cavity

Explanation: Radiological signs:
The lesion presents as a chance radiographic finding. It is a round or an ovoid (<3 cm) uniform radiolucency with a well-defined, usually corticated, margin. Stafne’s bone cavity is non-expansile and is found below the mandibular canal just anterior to the angle of the mandible. The location of Stafne’s b