In [None]:
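# Install the required libraries before running this notebook, for example:
# %pip install pandas numpy scikit-learn seaborn python-dotenv sentence-transformers openai groq chromadb langchain-community tavily-python langgraph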

In [None]:
import pandas as pd
import numpy as np
import json
import re
from typing import List, Dict, Any, Tuple
from sentence_transformers import SentenceTransformer
from openai import OpenAI
import time
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import seaborn as sns
from dotenv import load_dotenv
import openai
import os
from langchain_community.retrievers import TavilySearchAPIRetriever

In [None]:
# Load environment variables via python-dotenv before reading the API keys.
load_dotenv()

# Credentials for the Groq LLM API and the Tavily search API (None when unset).
groq_api_key = os.environ.get("GROQ_API_KEY")
tavily_api_key = os.environ.get("TAVILY_API_KEY")

In [None]:
# Read the Q&A dataset and keep a reproducible 500-row sample, re-indexed from 0.
df_qa = (
    pd.read_csv('qna_dataset.csv')
    .sample(n=500, random_state=0)
    .reset_index(drop=True)
)
df_qa.head()

In [None]:
# Merge question, answer and question type into a single text field per row;
# this field is what gets indexed as the document text later on.
question_text = "Question : " + df_qa['Question'].astype('str') + '. '
answer_text = " Answer : " + df_qa['Answer'].astype('str') + '. '
qtype_text = "Type : " + df_qa['qtype'].astype('str') + '. '

df_qa['combined_column'] = question_text + answer_text + qtype_text
df_qa.head()

In [None]:
# Read the device-manual dataset and keep a reproducible 500-row sample, re-indexed from 0.
df_md = (
    pd.read_csv('medical_device_manuals_dataset.csv')
    .sample(n=500, random_state=0)
    .reset_index(drop=True)
)
df_md.head()

In [None]:
# Concatenate the descriptive manual fields into a single text per device.
name_text = "Device Name: " + df_md['Device_Name'].astype('str') + '. '
model_text = "Model: " + df_md['Model_Number'].astype('str') + '. '
maker_text = "Manufacturer: " + df_md['Manufacturer'].astype('str') + '. '
indications_text = "Indications: " + df_md['Indications_for_Use'].astype('str') + '. '
# Missing contraindications are written out as the word 'None'.
contraindications_text = (
    "Contraindications: " + df_md['Contraindications'].fillna('None').astype('str')
)

df_md['combined_column'] = (
    name_text + model_text + maker_text + indications_text + contraindications_text
)
df_md.head()

In [None]:
# Normalise the combined text: fill gaps with "", force plain str, trim outer whitespace.
combined_text = df_md['combined_column'].fillna("")
df_md['combined_column'] = combined_text.astype(str).str.strip()

In [None]:
import chromadb

# Chroma client backed by the ./chroma_db directory.
client = chromadb.PersistentClient(
    path="./chroma_db",
)

In [None]:
# Open the Q&A collection, creating it if it does not exist yet.
collection1 = client.get_or_create_collection(
    name="medical_qna",
)

In [None]:
# Index every Q&A row: combined text -> document, full row -> metadata,
# row position -> id.
# `upsert` is used instead of `add` so this cell can be re-run against the
# persistent store: rows whose id already exists are overwritten with the
# current data rather than being skipped or rejected as duplicates.
collection1.upsert(
    documents=df_qa['combined_column'].tolist(),
    metadatas=df_qa.to_dict(orient='records'),
    ids=df_qa.index.astype(str).tolist(),
)
print("\n Medical Q&A collection created and data added.")

In [None]:
# Open the device-manual collection, creating it if it does not exist yet.
collection2 = client.get_or_create_collection(
    name="medical_device_manual",
)

In [None]:
# Index every device-manual row: combined text -> document, full row -> metadata,
# row position -> id.
# `upsert` is used instead of `add` so this cell can be re-run against the
# persistent store: rows whose id already exists are overwritten with the
# current data rather than being skipped or rejected as duplicates.
collection2.upsert(
    documents=df_md['combined_column'].tolist(),
    metadatas=df_md.to_dict(orient='records'),
    ids=df_md.index.astype(str).tolist(),
)
print("\n Medical Device Manual collection created and data added.")

In [None]:
# Smoke-test the device-manual collection with one free-text question (top 5 hits).
query = "What are devices used in surgery"
results = collection2.query(
    query_texts=[query],
    n_results=5,
)
print("\nQuery Results:", results)

In [None]:
from langchain_community.retrievers import TavilySearchAPIRetriever
# Web retriever backed by the Tavily search API, configured with k=4.
retriever = TavilySearchAPIRetriever(
    k=4,
    api_key=os.environ.get("TAVILY_API_KEY"),
)

# Smoke-test the retriever; the cell displays the returned documents.
query = "What is the speciality of Momento"
results = retriever.invoke(query)
results

In [None]:
from groq import Groq
import os
from dotenv import load_dotenv

load_dotenv()

# NOTE(review): this rebinds `client`, which referred to the Chroma client until
# now. collection1/collection2 were created before this point, so a top-to-bottom
# run works, but re-running the collection cells after this one would call Chroma
# methods on the Groq client.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# One-off sanity check of the chat completion endpoint.
response = client.chat.completions.create(
    messages=[{"role": "user", "content": "What is the speciality of Dunkirk?"}],
    model="llama-3.3-70b-versatile",
)

print(response.choices[0].message.content)

def call_llm(state):
    """Send the prompt stored under ``state["promt"]`` to the LLM.

    The reply is written to ``state["response"]`` and the same state mapping
    is returned.
    """
    prompt_text = state["promt"]
    state["response"] = get_llm_response(prompt_text)
    return state


In [None]:
def get_llm_response(prompt: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Send a single-turn user prompt to the chat completion API and return the reply text.

    Args:
        prompt: Text sent as the only (user) message of the conversation.
        model: Model identifier passed to the API. Defaults to the model this
            notebook uses everywhere, so existing one-argument calls are unchanged.

    Returns:
        The content of the first returned choice, or ``""`` when the API
        returns no text content.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    # `content` can be None; callers call str methods (e.g. .strip()) on the
    # result straight away, so normalise a missing reply to an empty string.
    content = completion.choices[0].message.content
    return content if content is not None else ""


In [None]:
from typing import TypedDict, Dict
from langgraph.graph import StateGraph, START, END

# 1. STATE SCHEMA (must be defined BEFORE node functions)

class GraphState(TypedDict):
    """State schema passed to ``StateGraph``; every node reads and writes these keys."""

    # User question; read by the router, the retrieval nodes and the prompt builder.
    query: str
    # Retrieved text, written by the retrieval / web-search nodes.
    context: str
    # Final LLM prompt, written by build_prompt and read by call_llm.
    # NOTE: the key is spelled "promt" everywhere in this notebook; keep it consistent.
    promt: str
    # LLM answer, written by call_llm.
    response: str
    # Routing label written by router and returned by route_decision.
    source: str
    # Relevance verdict written by check_relevance and read by relevance_decision.
    is_relevant: str
    # Retry counter consulted by relevance_decision.
    iteration_count: int


# 2. NODE FUNCTIONS

def retrieve_context_qna(state: Dict) -> Dict:
    """Fetch the top three Q&A documents for the query and store them as context.

    Reads ``state["query"]``, queries ``collection1`` and writes the returned
    documents, newline-joined, to ``state["context"]``. Returns the same state.
    """
    print("\nRetrieving context for Q&A...")

    hits = collection1.query(query_texts=[state["query"]], n_results=3)
    # "documents" holds one list per query text; exactly one query was sent.
    documents = hits.get("documents", [[]])[0]
    state["context"] = "\n".join(str(doc) for doc in documents if doc is not None)
    return state


def retrieve_context_md(state: Dict) -> Dict:
    """Fetch the top three device-manual documents for the query and store them as context.

    Reads ``state["query"]``, queries ``collection2`` and writes the returned
    documents, newline-joined, to ``state["context"]``. Returns the same state.
    """
    print("\nRetrieving context for Medical Device Manual...")

    hits = collection2.query(query_texts=[state["query"]], n_results=3)
    # "documents" holds one list per query text; exactly one query was sent.
    documents = hits.get("documents", [[]])[0]
    state["context"] = "\n".join(str(doc) for doc in documents if doc is not None)
    return state


def tavily_web_search(state: Dict) -> Dict:
    """Run the query through the Tavily retriever and store the page texts as context.

    Reads ``state["query"]`` and writes the ``page_content`` of every returned
    document, newline-joined, to ``state["context"]``. Returns the same state.
    """
    print("\nPerforming Tavily web search...")

    documents = retriever.invoke(state["query"])
    page_texts = [str(document.page_content) for document in documents]
    state["context"] = "\n".join(page_texts)
    return state


def router(state: Dict) -> Dict:
    """Ask the LLM which retrieval path fits the query and record its answer.

    The reply is trimmed, lower-cased and stripped of periods, printed, and
    stored under ``state["source"]``. Returns the same state.
    """
    user_query = state["query"]

    decision_prompt = f"""
You are a routing agent. Based on the user query, choose exactly one:
- retrieve qna
- retrieve device
- web search

Query: {user_query}

Respond with exactly one of:
retrieve qna
retrieve device
web search
"""

    raw_reply = get_llm_response(decision_prompt)
    decision = raw_reply.strip().lower().replace(".", "")
    print("\nRouter decision:", decision)

    state["source"] = decision
    return state


def route_decision(state: Dict) -> str:
    """Return the edge label used by the conditional edge that leaves the router.

    The value stored under ``state["source"]`` is an LLM reply, so it is not
    guaranteed to be exactly one of the three labels in the edge map. The reply
    is normalised here: an exact label is returned as is, a reply that merely
    contains a label (extra quotes, punctuation, surrounding words) is mapped
    to that label, and anything else falls back to ``"web search"`` instead of
    producing a label the edge map does not know.

    Args:
        state: Graph state; ``state["source"]`` holds the router's reply.

    Returns:
        One of ``"retrieve qna"``, ``"retrieve device"`` or ``"web search"``.
    """
    known_routes = ("retrieve qna", "retrieve device", "web search")
    reply = str(state.get("source") or "").strip().lower()

    if reply in known_routes:
        return reply

    # Tolerate decorated replies such as "'retrieve device'" or "I choose web search".
    for route in known_routes:
        if route in reply:
            return route

    print(f"\nUnrecognised router decision {reply!r}; falling back to web search.")
    return "web search"


def check_relevance(state: Dict) -> Dict:
    """Ask the LLM whether the retrieved context is relevant to the query.

    Writes the verdict to ``state["is_relevant"]`` as exactly ``"Yes"`` or
    ``"No"`` (the only labels the conditional edge map accepts) and increments
    ``state["iteration_count"]`` by one for every completed check.

    The counter is advanced here because this function is a graph node: the
    state a node returns is what LangGraph stores. Changes made to the state
    inside a conditional-edge routing function are not written back, so a
    counter maintained only there never grows.

    Args:
        state: Graph state; reads ``query``, ``context`` and ``iteration_count``.

    Returns:
        The same state with ``is_relevant`` and ``iteration_count`` updated.
    """
    print("\nChecking relevance of retrieved context...")
    query = state["query"]
    context = state["context"]

    relevance_prompt = f"""
Check if the context is relevant to the query. Respond with only 'Yes' or 'No'.

Context:
{context}

Query:
{query}
"""

    raw_decision = get_llm_response(relevance_prompt).strip()
    print("Relevance decision:", raw_decision)

    # Fold variants such as "yes", "Yes." or "No, because ..." onto the two
    # labels the edge map knows; anything that is not an affirmative is "No".
    state["is_relevant"] = "Yes" if raw_decision.lower().startswith("yes") else "No"
    state["iteration_count"] = state.get("iteration_count", 0) + 1
    return state


def relevance_decision(state: Dict) -> str:
    """Return the edge label (``"Yes"`` or ``"No"``) for the relevance check.

    This function is used as the routing function of a conditional edge, so
    it only reads the state: LangGraph does not write back changes made to the
    state inside a routing function.

    * When ``state["iteration_count"]`` has reached 3, ``"Yes"`` is returned
      regardless of the verdict so the retry loop ends.
      NOTE(review): the counter has to be incremented by a graph node for this
      cap to take effect - confirm that one does.
    * Otherwise the stored verdict is normalised: any reply starting with
      "yes" (case-insensitive, surrounding whitespace ignored) maps to
      ``"Yes"``, everything else to ``"No"``. This keeps replies such as
      "yes" or "Yes." from producing a label the edge map does not know.

    Args:
        state: Graph state; reads ``is_relevant`` and ``iteration_count``.

    Returns:
        Exactly ``"Yes"`` or ``"No"``.
    """
    max_checks = 3

    if state.get("iteration_count", 0) >= max_checks:
        print("\nMax iterations reached. Forcing Yes.")
        return "Yes"

    verdict = str(state.get("is_relevant") or "").strip().lower()
    return "Yes" if verdict.startswith("yes") else "No"


def build_prompt(state: Dict) -> Dict:
    """Fill the answer-generation prompt from the state's query and context.

    The finished prompt is written to ``state["promt"]`` and the same state
    mapping is returned.
    """
    template = """
You are a medical assistant AI. Use ONLY the context below to answer the user's question.

Context:
{context}

User Query:
{query}

Provide a clear, concise answer.
"""

    state["promt"] = template.format(query=state["query"], context=state["context"])
    return state


def call_llm(state: Dict) -> Dict:
    """Generate the answer: send ``state["promt"]`` to the LLM and keep the reply.

    The reply is stored under ``state["response"]``; the same state mapping is
    returned.
    """
    prompt_text = state["promt"]
    state["response"] = get_llm_response(prompt_text)
    return state


# 3. BUILD GRAPH

workflow = StateGraph(GraphState)

# Nodes: each one receives the state and returns the updated state.
# relevance_decision is deliberately NOT registered as a node: it returns an
# edge label (str) rather than a state update and is only used as the routing
# function of the conditional edge that leaves "Check_Relevance". Registering
# it would leave an unreachable node in the graph.
workflow.add_node("router", router)
workflow.add_node("Retrieve_QnA", retrieve_context_qna)
workflow.add_node("Retrieve_Device", retrieve_context_md)
workflow.add_node("Web_Search", tavily_web_search)
workflow.add_node("Check_Relevance", check_relevance)
workflow.add_node("Augment", build_prompt)
workflow.add_node("Generate", call_llm)

# Entry point.
workflow.add_edge(START, "router")

# The router's label selects exactly one retrieval node.
workflow.add_conditional_edges(
    "router",
    route_decision,
    {
        "retrieve qna": "Retrieve_QnA",
        "retrieve device": "Retrieve_Device",
        "web search": "Web_Search",
    }
)

# Every retrieval path is followed by the relevance check.
workflow.add_edge("Retrieve_QnA", "Check_Relevance")
workflow.add_edge("Retrieve_Device", "Check_Relevance")
workflow.add_edge("Web_Search", "Check_Relevance")

# Relevant context goes on to prompt building; otherwise retry via web search.
workflow.add_conditional_edges(
    "Check_Relevance",
    relevance_decision,
    {
        "Yes": "Augment",
        "No": "Web_Search",
    }
)

workflow.add_edge("Augment", "Generate")
workflow.add_edge("Generate", END)

agentic_rag = workflow.compile()


In [None]:
def build_prompt(state):
    """Build the final LLM prompt from the user query and the retrieved context.

    Reads ``state["query"]`` and ``state["context"]``, stores the prompt under
    ``state["promt"]`` and returns the same state mapping.
    """
    user_query, retrieved_context = state["query"], state["context"]

    state["promt"] = f"""
You are a medical assistant AI. Use ONLY the context below to answer the user's question.

Context:
{retrieved_context}

User Query:
{user_query}

Provide a clear, concise answer.
"""
    return state


In [None]:
from typing import TypedDict, Literal
from langgraph.graph import StateGraph, START, END

class GraphState(TypedDict):
    """State schema passed to ``StateGraph``; every node reads and writes these keys."""

    # User question; read by the router, the retrieval nodes and the prompt builder.
    query: str
    # Retrieved text, written by the retrieval / web-search nodes.
    context: str
    # Final LLM prompt, written by build_prompt and read by call_llm.
    # NOTE: the key is spelled "promt" everywhere in this notebook; keep it consistent.
    promt: str
    # LLM answer, written by call_llm.
    response: str
    # Routing label written by router and returned by route_decision.
    source: str
    # Relevance verdict written by check_relevance and read by relevance_decision.
    is_relevant: str
    # Retry counter consulted by relevance_decision.
    iteration_count: int

# Define input_state BEFORE using it
# Initial graph state: only the question is filled in; every other field
# starts empty and the retry counter starts at zero.
input_state = dict(
    query="What is the treatment for cancer?",
    context="",
    promt="",
    response="",
    source="",
    is_relevant="",
    iteration_count=0,
)

workflow = StateGraph(GraphState)

# Register nodes.
# relevance_decision is deliberately NOT registered as a node: it returns an
# edge label (str) rather than a state update and is only used as the routing
# function of the conditional edge that leaves "Check_Relevance". Registering
# it would leave an unreachable node in the graph.
workflow.add_node("router", router)
workflow.add_node("Retrieve_QnA", retrieve_context_qna)
workflow.add_node("Retrieve_Device", retrieve_context_md)
workflow.add_node("Web_Search", tavily_web_search)
workflow.add_node("Check_Relevance", check_relevance)
workflow.add_node("Augment", build_prompt)
workflow.add_node("Generate", call_llm)

# Start → Router
workflow.add_edge(START, "router")

# Router → Retrieval Nodes
workflow.add_conditional_edges(
    "router",
    route_decision,
    {
        "retrieve qna": "Retrieve_QnA",
        "retrieve device": "Retrieve_Device",
        "web search": "Web_Search",
    }
)

# Retrieval → Relevance Check
workflow.add_edge("Retrieve_QnA", "Check_Relevance")
workflow.add_edge("Retrieve_Device", "Check_Relevance")
workflow.add_edge("Web_Search", "Check_Relevance")

# Relevance Check → Augment or Retry
workflow.add_conditional_edges(
    "Check_Relevance",
    relevance_decision,
    {
        "Yes": "Augment",
        "No": "Web_Search",
    }
)

# Augment → Generate → END
workflow.add_edge("Augment", "Generate")
workflow.add_edge("Generate", END)

print("Sample state shape before run:")
print(input_state.keys())

# Compile the graph
agentic_rag = workflow.compile()


In [None]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid diagram (PNG bytes) and show it inline.
png_bytes = agentic_rag.get_graph().draw_mermaid_png()
display(Image(data=png_bytes))


In [None]:
# State handed to the graph run: the question plus empty placeholders for
# every field the nodes fill in, and a retry counter starting at zero.
input_state = dict(
    query="What is the treatment for cancer?",
    context="",
    promt="",
    response="",
    source="",
    is_relevant="",
    iteration_count=0,
)

In [None]:
from pprint import pprint

# Run the graph and report each node as it finishes. The answer is captured
# explicitly while streaming instead of being read from the loop variable after
# the loops end, which raised NameError when nothing was streamed.
final_response = None
for step in agentic_rag.stream(input_state):
    for node_name, state_value in step.items():
        print(f"Finished running: {node_name}")
        # "response" stays empty until the generation node has filled it in.
        if isinstance(state_value, dict) and state_value.get("response"):
            final_response = state_value["response"]

if final_response is None:
    print("The graph finished without producing a response.")
else:
    pprint(final_response)
