<a href="https://colab.research.google.com/github/NanoPrompt/AI-Chatobots/blob/main/Agentic_Clinical_Trial_RAG_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import json
from typing import TypedDict, List, Annotated
from langgraph.graph import StateGraph, END, START
from langgraph.prebuilt import ToolNode
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI # Using OpenAI for demonstration, easily replaceable with Gemini
from langchain_core.prompts import ChatPromptTemplate

# --- 1. Define the Graph State ---
# The state is the central object passed between all nodes in the graph.
# 'messages' holds the conversation history and tool outputs.
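# 'analysis_report' is reserved for structured, tool-generated analysis results.
# NOTE: in this cell it is only initialised to an empty string; none of the node
# functions defined here write to it yet.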
def _append_messages(existing, incoming):
    """Reducer for the ``messages`` key: return the old entries followed by the new ones."""
    return existing + incoming


class ClinicalTrialState(TypedDict):
    """Shared state object handed from node to node in the clinical trial graph.

    Keys:
        messages: Running conversation history. The second ``Annotated``
            argument is the reducer used to merge a node's returned
            ``messages`` into the existing list (plain concatenation).
        analysis_report: Slot for structured analysis results.
        question: The user's original question.
    """
    # NOTE(review): annotated as List[dict], but this file places LangChain message
    # objects (e.g. HumanMessage) in the list -- consider tightening the annotation.
    messages: Annotated[List[dict], _append_messages]
    # Structured, tool-generated analysis output
    analysis_report: str
    # The question the pipeline was started with
    question: str

# --- 2. Define Specialized Tools (The "Tool Plane") ---

# Placeholder for a Vector Store Retriever (e.g., Chroma, Qdrant, or GCP Vector Search)
# In a real scenario, this would connect to a vector DB containing PolyNovo trial documents
# (protocols, patient data summaries, adverse event reports).
def get_retriever():
    """Build the mock retrieval function that stands in for a vector-store lookup.

    A production implementation would initialise an embedding model
    (e.g. text-embedding-004), connect to the clinical vector database and
    run a similarity search. This placeholder instead does case-insensitive
    keyword matching over three canned PolyNovo snippets.

    Returns:
        A callable ``retrieve(query: str) -> str``.
    """
    # Simulated PolyNovo corpus
    corpus = {
        "PolyNovo_PN1_Protocol": "The primary endpoint is 75% complete wound closure by Day 60. The trial is phase IIb.",
        "Patient_Cohort_Summary": "Initial patient cohort size was 50. Mean age 62. Primary cause of wound was burn injury (70%).",
        "Adverse_Events": "Three Grade 2 adverse events were recorded: localized infection (2), device migration (1). All resolved.",
    }

    # Ordered routing table: the first row with a keyword present in the query wins,
    # so protocol matches take priority over cohort, and cohort over adverse events.
    keyword_routes = (
        (("endpoint", "closure"), "PolyNovo_PN1_Protocol"),
        (("patient", "demographics"), "Patient_Cohort_Summary"),
        (("adverse", "safety"), "Adverse_Events"),
    )

    def retrieve(query: str) -> str:
        """Retrieves specific clinical trial data from the PolyNovo document corpus."""
        lowered = query.lower()
        for keywords, doc_key in keyword_routes:
            if any(word in lowered for word in keywords):
                return f"Retrieved Document: {corpus[doc_key]}"
        # Nothing matched any keyword group
        return "No specific, high-relevance document was found. The agent should rely on general knowledge or ask to refine the query."

    return retrieve

@tool
def retrieve_clinical_data(query: str) -> str:
    """
    Retrieves relevant clinical trial documents (e.g., protocols, patient summaries,
    safety data) from the PolyNovo vector database based on the query.
    """
    # NOTE: the docstring above is what the @tool decorator exposes to the LLM as the
    # tool description, so treat its wording as runtime text rather than a comment.
    # A fresh retriever is built on every call and applied to the query.
    lookup = get_retriever()
    return lookup(query)

@tool
def analyze_metrics(data_summary: str) -> str:
    """
    Performs specialized statistical analysis on structured clinical data summaries,
    such as calculating efficacy rates, confidence intervals, or performing survival analysis.

    Input must be a concise summary of the data extracted from documents.
    """
    # NOTE: the docstring above is what the @tool decorator exposes to the LLM as the
    # tool description, so treat its wording as runtime text rather than a comment.

    # Guard: the mock only "analyses" summaries that quote both the endpoint wording
    # and the cohort size; anything else gets the insufficient-data message.
    if "75% complete wound closure" not in data_summary or "50" not in data_summary:
        return "Analysis tool executed, but insufficient structured data was provided for a comprehensive statistical report."

    # Canned output standing in for a real statistical routine
    # (e.g. one built on pandas / scikit-learn / R).
    report = {
        "primary_endpoint_success_rate": "68%",
        "confidence_interval_95": "[54%, 82%]",
        "survival_median_days": "180 days (placeholder)",
        "key_finding": "The observed primary endpoint success rate (68%) was slightly below the 75% target.",
    }
    rendered = json.dumps(report, indent=2)
    return f"Specialized Analysis Results (JSON format):\n{rendered}"

# Registry of the tools exposed to the agent. This same list is handed to ToolNode
# here and to bind_tools() when the model is created, so the tools the LLM can
# request and the tools the graph can execute stay in sync.
tools = [retrieve_clinical_data, analyze_metrics]
# Prebuilt LangGraph node wrapping the tools; it is registered in the graph
# as the "call_tool" node.
tool_node = ToolNode(tools)


# --- 3. Define the LLM/Agent Node ---

# Initialize the chat model that drives the agent.
# NOTE: ChatOpenAI ("gpt-4o") is a stand-in used to demonstrate the LangChain/LangGraph
# call structure; the intended deployment target is Gemini.
# To switch, replace the constructor with the Gemini integration, for example:
# model = ChatGoogleGenerativeAI(model="gemini-2.5-flash-preview-09-2025")
# and keep the .bind_tools(tools) call -- without it the model is never told about
# the tools and the router would never see a tool call.
model = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tools)

def run_agent(state: ClinicalTrialState):
    """Agent node: ask the tool-bound LLM what to do next.

    The full message history (including any earlier tool outputs) is read from
    ``state["messages"]``, prefixed with a fixed system prompt describing the
    analyst role and the two available tools, and sent to the module-level
    ``model``. The reply either carries tool calls or is the final answer.

    Args:
        state: Current graph state; only ``messages`` is read.

    Returns:
        A partial state update ``{"messages": [response]}`` containing just the
        model's reply.
    """
    history = state["messages"]

    # Role and tool-usage instructions, sent ahead of the history on every invocation
    role_prompt = (
        "You are an expert Clinical Data Analyst for PolyNovo. Your goal is to accurately "
        "analyze and report on clinical trial data. You have access to specialized tools: "
        "'retrieve_clinical_data' for finding documents and 'analyze_metrics' for running "
        "complex statistical routines on the data you retrieve. "
        "Use the tools iteratively until you have enough information to generate a comprehensive, "
        "grounded final report. If a tool returns results, call the 'analyze_metrics' tool next "
        "if appropriate data is retrieved."
    )
    conversation = [SystemMessage(content=role_prompt)] + history

    # The model decides the next action: a tool call or the final response
    reply = model.invoke(conversation)
    return {"messages": [reply]}


# --- 4. Define Conditional Edges (Routing Logic) ---

def should_continue(state: ClinicalTrialState) -> str:
    """Router for the conditional edge leaving the agent node.

    Looks only at the most recent message in ``state["messages"]``:

    - not an ``AIMessage``       -> ``END`` (defensive; nothing is printed)
    - ``AIMessage`` with tool calls -> ``"call_tool"``
    - ``AIMessage`` without them    -> ``END`` (the reply is the final answer)

    Returns:
        ``"call_tool"`` or ``END``.
    """
    latest = state["messages"][-1]

    # Not expected in normal flow, but bail out quietly rather than fail
    if not isinstance(latest, AIMessage):
        return END

    wants_tool = bool(latest.tool_calls)
    if wants_tool:
        banner = "--- ROUTER: Tool Call Detected. Proceeding to Tool Execution. ---"
    else:
        banner = "--- ROUTER: Final Answer Generated. Ending Pipeline. ---"
    print(banner)
    return "call_tool" if wants_tool else END


# --- 5. Build the LangGraph Workflow ---

# Instantiate the graph builder, using ClinicalTrialState as the state schema
workflow = StateGraph(ClinicalTrialState)

# Add Nodes
# "agent" runs run_agent (the LLM step); "call_tool" runs the prebuilt tool node.
workflow.add_node("agent", run_agent)
workflow.add_node("call_tool", tool_node)

# Set the Entry Point (The flow always starts with the Agent)
workflow.add_edge(START, "agent")

# Set Conditional Edges
# After the agent runs, should_continue inspects the last message and returns either
# "call_tool" or END; the mapping below translates that return value into the next node.
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {
        "call_tool": "call_tool", # Agent requested a tool -> go to the tool node
        END: END                # No tool call -> end the graph
    }
)

# Set Normal Edges
# After a tool runs, the output is fed back to the agent for further reasoning.
# Together with the conditional edge above this forms the agent -> tool -> agent loop.
# NOTE(review): no explicit iteration cap is configured in this cell -- presumably the
# loop is bounded only by LangGraph's own recursion limit; confirm that is acceptable.
workflow.add_edge("call_tool", "agent")

# Compile the builder into the runnable `app` that is streamed in the example below
app = workflow.compile()


# --- 6. Example Execution ---

# The user's question seeds the run: it is stored under "question" and also sent
# as the first chat message.
initial_question = "What was the primary endpoint success rate for the PolyNovo PN1 clinical trial, considering the safety profile?"
initial_state = dict(
    question=initial_question,
    messages=[HumanMessage(content=initial_question)],
    analysis_report="",
)

print(f"\n--- STARTING PIPELINE for Query: {initial_question} ---\n")

# Stream the run step by step. Each yielded item is treated as a mapping whose first
# key names the node that produced it; the last item is kept for the final report.
final_state = None
for step in app.stream(initial_state):
    node_name = list(step)[0]
    print(f"Current Node Output: {node_name}")
    final_state = step

# The final report is the content of the agent's last message, provided the last
# streamed step actually came from the agent node.
agent_update = final_state.get('agent') if final_state else None
if not agent_update:
    print("\n--- PIPELINE COMPLETED WITHOUT FINAL AGENT MESSAGE ---")
else:
    final_message = agent_update['messages'][-1].content
    print("\n--- FINAL REPORT ---\n")
    print(final_message)

# You would also typically persist the conversation state using a checkpointer.
# For example (MemorySaver is not imported above; it lives in langgraph.checkpoint.memory):
#   app = workflow.compile(checkpointer=MemorySaver())