Project Phase 1: Stepwise API Exploration

Step 1: Import Libraries


In [None]:
!pip install -q requests pandas streamlit pyngrok faiss-cpu sentence-transformers numpy

import requests
import pandas as pd
import json
import hashlib
from datetime import datetime
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np

In [None]:
# Mount Google Drive at /content/drive; the data paths used by the later cells
# all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Secure KEY INPUT
import os
import getpass

# Securely capture the key.
# getpass hides the input: paste the key and press Enter.
# Leading/trailing whitespace (a common copy-paste artefact) is stripped so it
# cannot end up inside the stored key.
key_input = getpass.getpass("üîë Enter Gemini API Key (Invisible Input): ").strip()

# Store the key as an environment variable for this session *before* reporting
# success, so the confirmation below is only printed once it is actually true.
os.environ["GEMINI_API_KEY"] = key_input

# Heuristic sanity check on the key prefix; a mismatch only warns, the key is
# still stored as entered.
if key_input.startswith("AIza"):
    print("‚úÖ API Key captured securely in Environment Variable.")
else:
    print("‚ö†Ô∏è Warning: Key might be invalid (usually starts with 'AIza').")

In [None]:
%%writefile build_embeddings.py
import pandas as pd
import numpy as np
import faiss
import json
from sentence_transformers import SentenceTransformer

# === REAL PATH (from readlink) ===
BASE = "/content/drive/MyDrive/LLM_Based_GenAI_Sem1/data/"

# Summaries shorter than this many characters are skipped (not embedded).
MIN_SUMMARY_CHARS = 20


def _clean_text(value) -> str:
    """Return *value* as a stripped string, mapping missing cells to "".

    pandas represents empty CSV cells as NaN; calling ``str()`` on those would
    produce the literal text "nan", which would then be embedded and stored
    as if it were a real title.
    """
    if value is None:
        return ""
    if not isinstance(value, str) and pd.isna(value):
        return ""
    return str(value).strip()


# ---------------------------------------------
# Load Data
# ---------------------------------------------
df = pd.read_csv(f"{BASE}/clinical_trials_diabetes_full.csv")

# Normalise the status column (string, trimmed, Title Case) and drop the rows
# whose status is in the exclusion list below.
df["status"] = df["status"].astype(str).str.strip().str.title()
bad_status = ["Terminated", "Withdrawn", "Suspended", "No Longer Available", "Unknown"]
df_clean = df[~df["status"].isin(bad_status)].copy()

# ---------------------------------------------
# Chunking
# ---------------------------------------------
# chunks[i] is the text that gets embedded; chunk_map[i] is its metadata.
# Both lists are appended to together so a FAISS row id indexes chunk_map.
chunks = []
chunk_map = []

for idx, row in df_clean.iterrows():
    title = _clean_text(row.get("brief_title", ""))
    summary = _clean_text(row.get("brief_summary", ""))

    if len(summary) < MIN_SUMMARY_CHARS:
        continue

    text = f"Title: {title}\nSummary: {summary}"
    chunks.append(text)

    chunk_map.append({
        "nct_id": row["nct_id"],
        "title": title,
        "text": text,
        "status": row["status"]
    })

print(f"Created {len(chunks)} chunks.")

if not chunks:
    # Without this guard the script would fail later on embeddings.shape[1]
    # with a far less helpful error.
    raise SystemExit("No chunks were created; nothing to embed.")

# ---------------------------------------------
# Embeddings
# ---------------------------------------------
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embed_model.encode(chunks, batch_size=64, show_progress_bar=True)

np.save(f"{BASE}/clinical_trials_diabetes_full_embeddings.npy", embeddings)
print("Saved clinical_trials_diabetes_full_embeddings.npy")

# ---------------------------------------------
# Save chunk map
# ---------------------------------------------
with open(f"{BASE}/clinical_trials_diabetes_full_chunk_map.json", "w", encoding="utf-8") as f:
    json.dump(chunk_map, f)

print("Saved clinical_trials_diabetes_full_chunk_map.json")

# ---------------------------------------------
# Build & Save FAISS
# ---------------------------------------------
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
# The vectors are converted to float32 before being added to the index.
index.add(np.asarray(embeddings, dtype="float32"))
faiss.write_index(index, f"{BASE}/clinical_trials_diabetes_full_faiss.index")

print("Saved clinical_trials_diabetes_full_faiss.index")
print("‚úÖ Embedding build COMPLETE.")


In [None]:
!python build_embeddings.py

In [None]:
%%writefile utils.py
import json
import hashlib
from datetime import datetime

import faiss
from sentence_transformers import SentenceTransformer

# --- Confidence score from distance ---

def calculate_confidence_score(distance: float, normalization_factor: float = 1.0) -> float:
    """Turn a distance into a confidence value; smaller distances score higher.

    The result is ``normalization_factor / (normalization_factor + distance)``.
    For a non-negative distance and a positive factor it lies in (0, 1].
    """
    dist = float(distance)
    denominator = normalization_factor + dist
    return normalization_factor / denominator


# --- Load pre-built index + chunk map ---

def load_data_and_index(chunk_map_path: str, faiss_path: str):
    """Load the saved chunk map, the FAISS index and the embedding model.

    Returns the tuple ``(embed_model, index, chunk_map)``.
    """
    print("‚è≥ Loading pre-built RAG index...")

    # Chunk metadata: parsed from the JSON file at chunk_map_path.
    with open(chunk_map_path, "r") as handle:
        metadata = json.loads(handle.read())

    vector_index = faiss.read_index(faiss_path)
    encoder = SentenceTransformer("all-MiniLM-L6-v2")

    print(f"‚úÖ RAG Index Ready: {vector_index.ntotal} vectors loaded.")
    return encoder, vector_index, metadata


# --- Provenance logging ---

def log_provenance_step(agent_name: str, input_data, output_data, detail=None):
    """
    Build the provenance record for one agent step.

    The record holds the current local timestamp (ISO format), the agent name,
    the given input and output, the optional detail (an empty dict when
    omitted or falsy) and a fixed model-version tag.
    """
    extra = detail if detail else {}
    return dict(
        timestamp=datetime.now().isoformat(),
        agent=agent_name,
        input=input_data,
        output=output_data,
        detail=extra,
        model_version="gemini-2.0-flash",
    )


# --- Reproducibility hash ---

def generate_reproducibility_hash(conversation_history, corpus_version: str = "v1.0"):
    """
    Return a deterministic hex digest for a conversation.

    The digest is the MD5 of the corpus version followed by the "query" value
    of every turn (missing values count as ""), all separated by "|".
    """
    joined_queries = "|".join(turn.get("query", "") for turn in conversation_history)
    payload = "{}|{}".format(corpus_version, joined_queries)

    digest = hashlib.md5()
    digest.update(payload.encode("utf-8"))
    return digest.hexdigest()


In [None]:
%%writefile run_bot.py
import json
import re
import os
import sys
from typing import List, Dict, Any

import numpy as np
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

# --- Updated Import: Robust Cross-Encoder Initialization ---
# CrossEncoder stays None when the import below fails; the reranker setup
# later in this file checks it for truthiness before loading a model.
CrossEncoder = None
try:
    from sentence_transformers import CrossEncoder
    print("‚úÖ sentence_transformers imported successfully.")
except ImportError:
    # The package (or this name in it) could not be imported.
    print("‚ö†Ô∏è sentence_transformers not found. Reranking will be disabled.")
except Exception as e:
    # Any other error raised while importing; reported but not fatal.
    print(f"‚ö†Ô∏è Error importing CrossEncoder: {e}. Reranking disabled.")

from utils import (
    load_data_and_index,
    log_provenance_step,
    generate_reproducibility_hash,
    calculate_confidence_score,
)


# --- NEW CONFIG (SECURE & 2.0 MODEL) ---
# The key is read from the environment variable set by the key-input cell.
API_KEY = os.environ.get("GEMINI_API_KEY")

# A missing or empty key aborts the script with exit status 1.
if not API_KEY:
    print("‚ùå ERROR: API Key not found. Please run the 'Secure Input' cell first.")
    sys.exit(1)

genai.configure(api_key=API_KEY)

# Using the Experimental 2.0 Flash endpoint
gemini_model = genai.GenerativeModel("models/gemini-2.0-flash")

# Artefacts written by build_embeddings.py (same Drive data directory).
CHUNK_PATH = "/content/drive/MyDrive/LLM_Based_GenAI_Sem1/data/clinical_trials_diabetes_full_chunk_map.json"
FAISS_PATH = "/content/drive/MyDrive/LLM_Based_GenAI_Sem1/data/clinical_trials_diabetes_full_faiss.index"

# Load embedding model, FAISS index, and chunk metadata
embed_model, faiss_index, chunk_map = load_data_and_index(CHUNK_PATH, FAISS_PATH)

# --- NEW: Reranker Initialization ---
# reranker stays None when CrossEncoder could not be imported or the model
# fails to load; RetrievalAgent.retrieve then uses the FAISS order as-is.
reranker = None
if CrossEncoder:
    try:
        print("‚è≥ Loading Reranker Model (Cross-Encoder)...")
        # High precision reranker
        reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        print("‚úÖ Reranker Loaded.")
    except Exception as e:
        # Best effort: report the failure and continue without a reranker.
        print(f"‚ö†Ô∏è Reranker model download failed (using pure FAISS): {e}")


# --- NEW PARSER (UPDATED) ---

class SymptomParser:
    """Classifies a user message into an intent with the help of the LLM.

    The LLM is asked to return a JSON object. Whenever the call or the JSON
    decoding fails, a keyword-based fallback classification is used instead.
    """

    # Phrases that force intent='trial_search'. The same list is applied on
    # the LLM path and on the fallback path, so an input is classified the
    # same way whether or not the model call succeeded.
    TRIAL_KEYWORDS = (
        'trial', 'study', 'studies', 'research', 'clinical',
        'show me', 'are there', 'what trials',
    )

    def __init__(self, model):
        # model: object exposing generate_content(prompt) whose result has .text
        self.model = model

    def _mentions_trials(self, text: str) -> bool:
        """Return True when *text* contains any TRIAL_KEYWORDS phrase (case-insensitive)."""
        text_lower = text.lower()
        return any(kw in text_lower for kw in self.TRIAL_KEYWORDS)

    def _fallback_parse(self, text: str):
        """Build a classification from keywords only (used when the LLM path fails)."""
        if self._mentions_trials(text):
            return {
                "intent": "trial_search",
                "query_type": "trial_query",
                "search_keywords": [text],
                "is_diabetes_related": True,
                "user_question": text,
                "trial_interest": "general"
            }
        return {
            "intent": "general_question",
            "query_type": "knowledge_seeking",
            "search_keywords": [],
            "is_diabetes_related": True,
            "user_question": text,
            "trial_interest": None
        }

    def parse(self, text: str):
        """
        Enhanced parser for clinical trial search queries

        Returns ``(parsed, log)``: the classification dict and the provenance
        entry produced by ``log_provenance_step``.
        """
        prompt = (
            "You are a clinical trial search classifier for diabetes research.\n"
            "Your job is to determine if the user wants to SEARCH for trials or just learn about diabetes.\n\n"

            f'User Input: "{text}"\n\n'

            "Classification Rules:\n"
            "1. If query contains 'trial', 'study', 'research', 'clinical', or asks 'what trials', 'show me trials', 'are there trials' ‚Üí intent='trial_search'\n"
            "2. If user states personal info like age, conditions, medications ‚Üí intent='profile_info'\n"
            "3. If asking general education questions like 'what is X?', 'how does Y work?' (WITHOUT asking about trials) ‚Üí intent='general_question'\n"
            "4. Simple greetings ‚Üí intent='greeting'\n"
            "5. Not about diabetes ‚Üí intent='off_topic'\n\n"

            "Return ONLY valid JSON:\n"
            "{\n"
            '  "intent": "trial_search" | "profile_info" | "general_question" | "greeting" | "off_topic",\n'
            '  "query_type": "trial_query" | "profile_statement" | "knowledge_seeking" | "greeting",\n'
            '  "search_keywords": ["keyword1", "keyword2"],  // Extract main search terms\n'
            '  "is_diabetes_related": true/false,\n'
            '  "user_question": "the question in plain English",\n'
            '  "trial_interest": "what type of trial they want (diet, medication, technology, etc.)"\n'
            "}\n\n"

            "Examples:\n"
            '- "What trials study liraglutide?" ‚Üí intent="trial_search", search_keywords=["liraglutide"]\n'
            '- "I\'m 55 with diabetes" ‚Üí intent="profile_info"\n'
            '- "What is HbA1c?" ‚Üí intent="general_question"\n'
        )

        try:
            res = self.model.generate_content(prompt)
            raw = (res.text or "").strip()
            # Take the outermost {...} span so any text around the JSON is ignored.
            match = re.search(r"\{.*\}", raw, re.DOTALL)
            parsed = json.loads(match.group(0) if match else raw)
            if not isinstance(parsed, dict):
                # Valid JSON that is not an object (e.g. a list) is unusable here.
                raise ValueError("classifier did not return a JSON object")
        except Exception:
            # Deliberate best effort: any API or decoding failure degrades to
            # the keyword-only classification instead of failing the turn.
            parsed = self._fallback_parse(text)
        else:
            # Force trial_search if keywords present
            if self._mentions_trials(text):
                parsed["intent"] = "trial_search"
                parsed["query_type"] = "trial_query"

        log = log_provenance_step("SymptomParser", text, parsed)
        return parsed, log





# --- NEW PROFILE AGENT (STATEFUL) ---
class ProfileAgent:
    """Keeps the per-session user profile: turn history plus remembered conditions."""

    def __init__(self, initial_profile: Dict[str, Any] = None):
        # A fresh default profile is built per instance when none is supplied;
        # a supplied dict is used as-is (not copied).
        if initial_profile is None:
            initial_profile = {
                "user_id": "Patient",
                "conditions": ["diabetes"], # Default context
                "extracted_conditions": [], # Dynamic memory
                "history": [],
            }
        self.profile = initial_profile

    def update_profile(self, turn_data: Dict[str, Any]):
        """
        Updates history and extracts persistent medical entities.

        Appends *turn_data* to the profile history, merges the strings found
        under ``turn_data["parsed"]["symptoms"]`` (lower-cased, longer than 3
        characters) into ``extracted_conditions`` and returns the provenance
        log entry for this step.
        """
        self.profile.setdefault("history", []).append(turn_data)
        known = self.profile.setdefault("extracted_conditions", [])

        # Heuristic: Add new symptoms found in this turn to the persistent profile
        parsed = turn_data.get("parsed") or {}
        new_symptoms = parsed.get("symptoms") or []

        if new_symptoms:
            # dict.fromkeys de-duplicates while keeping first-seen order, so
            # the stored list (and the snapshot logged below) is stable across
            # runs; a set would reorder entries arbitrarily.
            merged = dict.fromkeys(known)
            for sym in new_symptoms:
                # Skip non-strings and very short tokens (noise).
                if isinstance(sym, str) and len(sym) > 3:
                    merged[sym.lower()] = None
            self.profile["extracted_conditions"] = list(merged)

        snapshot = {
            "user_id": self.profile.get("user_id", "Patient"),
            "known_conditions": self.profile.get("extracted_conditions", []),
            "num_turns": len(self.profile["history"]),
        }
        log = log_provenance_step("ProfileAgent", turn_data, {"profile_snapshot": snapshot})
        return log


# --- NEW RETRIEVAL AGENT (RERANKING) ---
class RetrievalAgent:
    """Retrieves trial chunks from the FAISS index and optionally reranks them."""

    def __init__(self, embed_model, faiss_index, chunk_map, profile_agent: ProfileAgent = None):
        self.profile_agent = profile_agent
        self.chunk_map = chunk_map
        self.index = faiss_index
        self.embed_model = embed_model

    def retrieve(self, parsed: Dict[str, Any], top_k: int = 5):
        """Search the index for *parsed* and return ``(retrieval, log)``.

        ``retrieval`` holds the (possibly expanded) query, up to *top_k* trial
        dicts and their mean confidence. When the module-level ``reranker`` is
        available the candidates are re-ordered by its score; otherwise the
        FAISS order is kept.
        """
        # Over-fetch (3x) so the reranker has more candidates to choose from.
        fetch_k = top_k * 3

        symptom_terms = parsed.get("symptoms") or []
        extra_context = parsed.get("context") or ""
        # user_question is preferred; symptoms + context are the fallback.
        query = parsed.get("user_question") or (" ".join(symptom_terms) + " " + extra_context).strip()

        if not query:
            empty_result = {"query": "", "trials": [], "avg_confidence": 0.0}
            return empty_result, log_provenance_step(
                "RetrievalAgent", parsed, empty_result, {"reason": "empty_query"}
            )

        expansions = {
            "insulin": "insulin OR insulin therapy OR insulin treatment OR insulin pump",
            "medication": "medication OR drug OR pharmaceutical OR pharmacological OR treatment",
            "diet": "diet OR dietary OR nutrition OR nutritional OR eating",
            "exercise": "exercise OR physical activity OR fitness OR activity",
            "new": "medication OR drug OR pharmacological OR treatment OR therapy OR intervention",
        }

        # Append the expansion of the first matching term only (substring match).
        lowered = query.lower()
        first_expansion = next(
            (extra for term, extra in expansions.items() if term in lowered),
            None,
        )
        if first_expansion is not None:
            query = f"{query} {first_expansion}"

        # 1. Dense retrieval through FAISS
        query_vec = self.embed_model.encode([query])
        distances, indices = self.index.search(query_vec.astype("float32"), fetch_k)

        candidates = []
        for dist, idx in zip(distances[0], indices[0]):
            if idx == -1:
                # -1 entries carry no result and are skipped.
                continue
            entry = self.chunk_map[idx]
            candidates.append({
                "nct_id": entry["nct_id"],
                "text": entry["text"],
                "status": entry["status"],
                "faiss_dist": float(dist),
            })

        # 2. Optional cross-encoder reranking; pick the matching confidence formula
        if reranker and candidates:
            scores = reranker.predict([[query, cand["text"]] for cand in candidates])
            for cand, score in zip(candidates, scores):
                cand["rerank_score"] = float(score)
            candidates.sort(key=lambda cand: cand["rerank_score"], reverse=True)
            method = "reranked"

            def confidence_of(cand):
                # Sigmoid of the reranker score.
                return 1 / (1 + np.exp(-cand["rerank_score"]))
        else:
            method = "faiss_only"

            def confidence_of(cand):
                return calculate_confidence_score(cand["faiss_dist"])

        final_trials = []
        confs = []
        for position, cand in enumerate(candidates[:top_k], start=1):
            conf = confidence_of(cand)
            confs.append(conf)
            final_trials.append({
                "nct_id": cand["nct_id"],
                "title": cand["text"].split("\n")[0].replace("Title: ", ""),
                "text": cand["text"],
                "status": cand["status"],
                "confidence": conf,
                "rank": position,
                "method": method,
            })

        avg_conf = float(np.mean(confs)) if confs else 0.0

        retrieval = {
            "query": query,
            "trials": final_trials,
            "avg_confidence": avg_conf,
        }
        detail = {
            "top_k": top_k,
            "avg_confidence": avg_conf,
            "num_trials": len(final_trials),
            "method": "reranked" if reranker else "faiss_only",
        }
        return retrieval, log_provenance_step("RetrievalAgent", parsed, retrieval, detail)


# --- NEW DIAGNOSIS ADVISOR (CONTEXT AWARE) ---
class DiagnosisAdvisor:
    """Turns retrieved trials into the user-facing answer text via the LLM."""

    def __init__(self, model):
        self.model = model

    def _handle_general_question(self, parsed: Dict[str, Any], retrieved: Dict[str, Any]):
        """Answer an educational question, passing up to three trials as context."""
        top_trials = retrieved.get("trials", [])[:3]
        user_question = parsed.get("user_question") or " ".join(parsed.get("symptoms", []))

        # Each trial contributes its id and the first 400 characters of its text.
        evidence = "\n\n".join(
            f"Trial {t['nct_id']}: {t['text'][:400]}" for t in top_trials
        ) or "No specific trials available."

        prompt = (
            "You are a diabetes health educator. Answer the user's question clearly using your medical knowledge.\n"
            "The clinical trial evidence below provides real-world context - mention it if relevant.\n\n"
            f"USER'S QUESTION: {user_question}\n\n"
            "CLINICAL TRIAL CONTEXT (for reference):\n"
            f"{evidence}\n\n"
            "Instructions:\n"
            "- Answer the question directly in 3-5 sentences\n"
            "- Be specific and educational\n"
            "- If trials mention relevant findings, cite them briefly\n"
            "- End with: 'For personalized advice, please consult your healthcare provider.'\n"
        )

        try:
            reply = (self.model.generate_content(prompt).text or "").strip()
        except Exception:
            return "Unable to generate an answer at this time. Please try rephrasing your question."

        # Replies shorter than 50 characters (including empty ones) are replaced.
        if len(reply) < 50:
            return "I don't have enough information to answer this question accurately. Please consult your healthcare provider."
        return reply

    def _handle_symptom_query(self, parsed: Dict[str, Any], retrieved: Dict[str, Any], profile: Dict[str, Any]):
        """Summarise up to five retrieved trials; falls back to a plain listing."""
        trials = retrieved.get("trials", [])
        user_input = parsed.get("user_question", "")

        def describe(trial):
            # The first line of the chunk text is the title line.
            heading = trial["text"].partition("\n")[0].replace("Title: ", "")
            return (
                f"**{trial['nct_id']}** (Confidence: {trial['confidence']:.0%})\n"
                f"   {heading}\n"
                f"   Status: {trial.get('status', 'Unknown')}"
            )

        trials_text = "\n\n".join(describe(t) for t in trials[:5])

        prompt = (
            "You are a clinical trial research assistant.\n"
            "The user is searching for diabetes clinical trials.\n\n"
            f"USER'S SEARCH: {user_input}\n\n"
            "RELEVANT CLINICAL TRIALS FROM DATABASE:\n"
            f"{trials_text}\n\n"
            "Instructions:\n"
            "1. Start with: 'I found [N] relevant diabetes clinical trials:'\n"
            "2. Briefly describe what each trial studies (1 sentence per trial)\n"
            "3. Use the NCT ID in your descriptions\n"
            "4. End with: 'To learn more about any trial, visit clinicaltrials.gov and search for the NCT ID. Discuss with your healthcare provider before participating.'\n"
        )

        try:
            reply = (self.model.generate_content(prompt).text or "").strip()
        except Exception:
            return f"I found {len(trials)} trials:\n\n{trials_text}\n\nPlease consult your healthcare provider."

        # Too-short replies are replaced by the raw listing.
        if len(reply) < 50:
            return f"I found {len(trials)} trials in our database. Here are the details:\n\n{trials_text}\n\nPlease consult your healthcare provider."
        return reply

    def advise(self, parsed: Dict[str, Any], retrieved: Dict[str, Any], profile: Dict[str, Any]):
        """Produce ``(draft, log)`` for a turn.

        The draft is vetoed (fixed message, ``confidence_veto=True``) when the
        query is flagged as not diabetes related, or when there are no trials
        or the average confidence is below 0.05. Otherwise the text comes from
        one of the two handlers, selected by ``query_type``.
        """
        trials = retrieved.get("trials", [])
        avg_conf = retrieved.get("avg_confidence", 0.0)
        query_type = parsed.get("query_type", "symptom_matching")

        draft = {
            "recommendation": "",
            "avg_confidence": avg_conf,
            "query_type": query_type,
        }

        veto_reason = None
        veto_message = ""
        if not parsed.get("is_diabetes_related", True):
            veto_reason = "off_topic"
            veto_message = (
                "I'm specialized in diabetes-related clinical trials. Your query appears to be "
                "about symptoms or conditions not directly related to diabetes. "
                "If you have diabetes-related questions or symptoms (like high blood sugar, "
                "insulin management, complications, etc.), I'd be happy to help! "
                "Otherwise, please consult your healthcare provider for your current symptoms."
            )
        elif not trials or avg_conf < 0.05:
            veto_reason = "low_confidence"
            veto_message = (
                "EVIDENCE IS INSUFFICIENT TO ANSWER THIS QUESTION DIRECTLY based on the "
                "retrieved clinical trials. Please consult your healthcare provider."
            )

        if veto_reason is not None:
            draft["recommendation"] = veto_message
            draft["confidence_veto"] = True
            veto_log = log_provenance_step(
                "DiagnosisAdvisor", parsed, draft, {"veto": True, "reason": veto_reason}
            )
            return draft, veto_log

        # Route to the handler matching the query type.
        if query_type == "knowledge_seeking":
            answer = self._handle_general_question(parsed, retrieved)
        else:
            answer = self._handle_symptom_query(parsed, retrieved, profile)
        draft["recommendation"] = answer
        draft["confidence_veto"] = False

        return draft, log_provenance_step("DiagnosisAdvisor", parsed, draft)


# --- NEW SAFETY FILTER (UNCHANGED BUT RE-DECLARED) ---

class ActiveSafetyFilter:
    """Final gate that has the LLM audit generated advice before it is shown."""

    def __init__(self, model):
        self.model = model
        # Safety settings passed along with the audit request.
        self.safety_cfg = {
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

    def verify(self, advice_text: str, trials: List[Dict[str, Any]]):
        """Audit *advice_text* and return ``(final_text, status, log)``.

        Text that looks like a trial listing is passed through without an LLM
        call. Otherwise the model is asked to answer either ``SAFE`` or
        ``CORRECTED: <safer version>``; only a reply that begins with the
        standalone word ``SAFE`` counts as a pass.
        """
        # NEW: Skip safety check for trial listing responses
        # Safety filter should only check MEDICAL ADVICE, not trial summaries

        # Detect if this is just listing trials (safe by design)
        if any(marker in advice_text for marker in ["NCT", "clinical trial", "Please discuss these findings"]):
            # This is a trial listing, not medical advice - safe by default
            log = log_provenance_step(
                "ActiveSafetyFilter",
                {"advice": advice_text},
                {"final_text": advice_text, "status": "Pass (Trial Listing)"},
            )
            return advice_text, "Pass (Trial Listing)", log

        # Otherwise, run full safety check for actual medical advice
        evidence_text = "\n".join(t["text"][:500] for t in trials[:3])  # Limit length

        audit_prompt = (
            "You are a Medical Safety Officer reviewing AI-generated advice.\n\n"
            "ADVICE:\n"
            f"{advice_text}\n\n"
            "EVIDENCE FROM CLINICAL TRIALS:\n"
            f"{evidence_text}\n\n"
            "Check for safety issues:\n"
            "- If the advice suggests stopping or changing medication without a doctor ‚Üí UNSAFE.\n"
            "- If it gives a diagnosis ‚Üí UNSAFE.\n"
            "- If it makes claims not supported by the evidence ‚Üí UNSAFE.\n"
            "- If it just lists clinical trials with disclaimers ‚Üí SAFE.\n\n"
            'If the advice is acceptable, respond with exactly: SAFE\n'
            'If it is not acceptable, respond starting with: CORRECTED: <safer version>\n'
        )

        try:
            res = self.model.generate_content(audit_prompt, safety_settings=self.safety_cfg)
            txt = (res.text or "").strip()
            # The verdict must START with the standalone word SAFE (leading
            # punctuation/markdown tolerated). A plain substring test would
            # also accept "UNSAFE" and any correction that merely mentions
            # the word, letting flagged advice through unrevised.
            if re.match(r"\W*SAFE\b", txt):
                final_text = advice_text
                status = "Pass"
            else:
                final_text = f"‚ö†Ô∏è SAFETY REVISION:\n{txt}"
                status = "Revised"
        except Exception:
            # Fallback: If safety API fails, check if it's trial listing
            if "NCT" in advice_text or "clinical trial" in advice_text.lower():
                final_text = advice_text  # Trial listings are safe
                status = "Pass (API Fallback)"
            else:
                final_text = "‚ö†Ô∏è Safety filter triggered. Please consult a doctor."
                status = "Revised (API Error)"

        log = log_provenance_step(
            "ActiveSafetyFilter",
            {"advice": advice_text},
            {"final_text": final_text, "status": status},
        )
        return final_text, status, log




# ============================================================
# HEALTHCARE BOT (Orchestrator)
# ============================================================

# --- NEW BOT (ORCHESTRATOR) ---
class HealthcareBot:
    def __init__(self, gemini_model, embed_model, faiss_index, chunk_map, initial_profile=None):
        """Wire up the agents and start with empty history and provenance lists."""
        # Per-session state
        self.history: List[Dict[str, Any]] = []
        self.provenance_chain: List[Dict[str, Any]] = []

        # LLM-backed agents all share the same Gemini model object.
        self.parser = SymptomParser(gemini_model)
        self.advisor = DiagnosisAdvisor(gemini_model)
        self.safety = ActiveSafetyFilter(gemini_model)

        # The retriever receives the profile agent alongside the index data.
        self.profile_agent = ProfileAgent(initial_profile)
        self.retriever = RetrievalAgent(
            embed_model, faiss_index, chunk_map, self.profile_agent
        )


    def _handle_simple_greeting(self, user_input: str):
        """Reply to a greeting with the fixed intro text; no retrieval is run.

        Logs the step, records the turn in ``self.history`` and returns the
        response dict.
        """
        name = self.profile_agent.profile.get("user_id", "there")
        msg = (
            f"Hello {name}! I'm your **Clinical Trial Research Assistant** for diabetes. üî¨\n\n"
            "I can help you find relevant diabetes clinical trials from a database of **22,000+ studies**.\n\n"
            "**Try asking:**\n"
            "- 'What trials are studying insulin therapy?'\n"
            "- 'Show me trials about low-carb diets'\n"
            "- 'Are there trials testing new medications?'\n"
            "- 'I'm 55 with type 2 diabetes, what trials can I join?'\n\n"
            "I search real trial data from ClinicalTrials.gov. How can I help you explore diabetes research today?"
        )

        self.provenance_chain.append(
            log_provenance_step("GreetingAgent", user_input, msg, {"type": "greeting"})
        )

        # The hash covers all earlier turns plus this one; the turn is then stored.
        turn = {"query": user_input}
        session_hash = generate_reproducibility_hash([*self.history, turn])
        self.history.append({**turn, "response_hash": session_hash})

        response = {
            "recommendation": msg,
            "cited_trials": [],
            "safety_status": "Non-RAG",
            "session_hash": session_hash,
            "provenance_chain": self.provenance_chain,
        }
        return response




    def _handle_off_topic(self, user_input: str, parsed: Dict[str, Any]):
        """Return the fixed redirect message for an off-topic query.

        ``parsed`` is accepted but not used here.
        """
        msg = (
            "I'm specialized in diabetes-related clinical trials. Your query appears to be "
            "about symptoms or conditions not directly related to diabetes. "
            "If you have diabetes-related questions, I'd be happy to help!"
        )

        self.provenance_chain.append(
            log_provenance_step("OffTopicHandler", user_input, msg, {"type": "off_topic"})
        )

        # Hash over the stored history plus this query (the turn itself is not stored here).
        turns = self.history + [{"query": user_input}]
        response = {
            "recommendation": msg,
            "cited_trials": [],
            "safety_status": "Off-topic",
            "session_hash": generate_reproducibility_hash(turns),
            "provenance_chain": self.provenance_chain,
        }
        return response

    def _handle_knowledge_question(self, user_input: str, parsed: Dict[str, Any]):
        """Answer a general education question directly with the LLM (no retrieval).

        Uses ``parsed["user_question"]`` when present, otherwise the raw input.
        If the model call fails or returns no text, a fixed fallback sentence
        is returned instead. Returns the response dict.
        """
        fallback_answer = "Unable to answer at this time."
        user_question = parsed.get("user_question", user_input)
        prompt = (
            "You are a certified diabetes educator. Answer this question clearly and accurately.\n"
            f"QUESTION: {user_question}\n"
        )
        try:
            res = self.advisor.model.generate_content(prompt)
            # An empty reply would otherwise surface as a blank recommendation.
            answer = (res.text or "").strip() or fallback_answer
        except Exception:
            # Deliberate best effort. `except Exception` (rather than a bare
            # `except:`) lets KeyboardInterrupt/SystemExit propagate.
            answer = fallback_answer

        log = log_provenance_step("KnowledgeAgent", user_input, answer, {"type": "general_knowledge"})
        self.provenance_chain.append(log)
        # NOTE(review): unlike _handle_simple_greeting, this turn is not
        # appended to self.history - confirm whether that is intentional.
        session_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])

        return {
            "recommendation": answer,
            "cited_trials": [],
            "safety_status": "Knowledge-Based",
            "session_hash": session_hash,
            "provenance_chain": self.provenance_chain,
        }


    def process_query(self, user_input: str):
        self.provenance_chain = []

        # 1. Parse
        parsed, parse_log = self.parser.parse(user_input)
        self.provenance_chain.append(parse_log)

        intent = (parsed.get("intent") or "trial_search").lower()
        query_type = parsed.get("query_type", "trial_query")
        is_diabetes_related = parsed.get("is_diabetes_related", True)

        # Handle greetings
        if intent == "greeting":
            return self._handle_simple_greeting(user_input)

        # Handle off-topic
        if intent == "off_topic" or not is_diabetes_related:
            return self._handle_off_topic(user_input, parsed)

        # Handle profile info (store but don't search yet)
        if intent == "profile_info":
            # Extract profile
            # TODO: Implement profile extraction
            msg = (
                "Thank you for sharing your information. I've noted your details. "
                "What type of clinical trials would you like to explore? "
                "For example: 'Show me trials about diet management' or 'What trials test new medications?'"
            )
            log = log_provenance_step("ProfileAgent", user_input, msg, {"action": "profile_stored"})
            self.provenance_chain.append(log)

            session_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])
            return {
                "recommendation": msg,
                "cited_trials": [],
                "safety_status": "Profile Update",
                "session_hash": session_hash,
                "provenance_chain": self.provenance_chain,
            }

        # Handle ONLY pure education questions (NO retrieval)
        # Must be general_question AND not asking about trials
        if intent == "general_question" and query_type == "knowledge_seeking":
            # Double-check not asking about trials
            if "trial" not in user_input.lower() and "study" not in user_input.lower():
                return self._handle_knowledge_question(user_input, parsed)

        # üî¥ DEFAULT: TRIAL SEARCH (WITH RETRIEVAL)
        # This catches:
        # - intent="trial_search"
        # - Anything diabetes-related we're unsure about
        # - Better to search and find nothing than miss relevant trials

        retrieved, retrieve_log = self.retriever.retrieve(parsed)


        # # TEMP DEBUG - Remove after testing
        # print(f"\n=== DEBUG INFO ===")
        # print(f"Query: {user_input}")
        # print(f"Retrieved {len(retrieved.get('trials', []))} trials")
        # print(f"Avg confidence: {retrieved.get('avg_confidence', 0):.3f}")
        # if retrieved.get('trials'):
        #     print(f"Top trial: {retrieved['trials'][0]['nct_id']}")
        #     print(f"Top confidence: {retrieved['trials'][0]['confidence']:.3f}")
        # print(f"===================\n")


        # # TEMP DEBUG
        # if user_input.lower() == "are there trials testing new medications?":
        #     print(f"\n=== DEBUG: New Medications Query ===")
        #     print(f"Retrieved {len(retrieved.get('trials', []))} trials")
        #     print(f"Avg confidence: {retrieved.get('avg_confidence', 0):.4f}")
        #     if retrieved.get('trials'):
        #         for i, t in enumerate(retrieved['trials'][:3]):
        #             print(f"Trial {i+1}: {t['nct_id']} | Conf: {t['confidence']:.4f}")
        #     print("====================================\n")


        self.provenance_chain.append(retrieve_log)

        # Check if query is too generic (low confidence + generic keywords)
        generic_terms = ["new", "any", "some", "recent", "latest"]
        is_generic = any(term in user_input.lower() for term in generic_terms)
        avg_conf = retrieved.get("avg_confidence", 0.0)

        if is_generic and avg_conf < 0.15:
            return self._handle_generic_trial_query(user_input, parsed)


        # 3. Advisor
        draft_advice, advise_log = self.advisor.advise(parsed, retrieved, self.profile_agent.profile)
        self.provenance_chain.append(advise_log)

        trials = retrieved.get("trials", [])
        if draft_advice.get("confidence_veto", False) or not trials:
            final_text = draft_advice["recommendation"]
            safety_status = "Vetoed (Low Confidence)"
            evidence_list = []
        else:
            # 4. Safety
            final_text, safety_status, safety_log = self.safety.verify(draft_advice["recommendation"], trials)
            self.provenance_chain.append(safety_log)
            evidence_list = trials

        nct_ids = [t["nct_id"] for t in evidence_list]
        session_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])

        # 5. Update profile/history
        turn_data = {
            "query": user_input,
            "parsed": parsed,
            "nct_ids": nct_ids,
            "safety_status": safety_status,
            "session_hash": session_hash,
        }
        profile_log = self.profile_agent.update_profile(turn_data)
        self.provenance_chain.append(profile_log)
        self.history.append({"query": user_input, "response_hash": session_hash})

        return {
            "recommendation": final_text,
            "cited_trials": nct_ids,
            "safety_status": safety_status,
            "session_hash": session_hash,
            "provenance_chain": self.provenance_chain,
        }



    def _handle_generic_trial_query(self, user_input: str, parsed: Dict[str, Any]):
        """Reply to an overly broad trial query with a request for more detail.

        The reply is a fixed menu of example queries; nothing retrieved is cited.
        The step is appended to the provenance chain and the turn is recorded
        in the session history.

        Args:
            user_input: Raw text entered by the user.
            parsed: Parsed form of the query. Accepted so the signature matches
                the other ``_handle_*`` methods; it is not read here.

        Returns:
            A response dict with the keys ``recommendation``, ``cited_trials``
            (always empty), ``safety_status``, ``session_hash`` and
            ``provenance_chain``.
        """
        # Each section is a heading followed by bullet lines; sections are
        # separated by one blank line in the final message.
        intro = (
            "I found that question a bit broad. I have 22,000+ diabetes trials in my database. "
            "To help you better, could you specify:"
        )
        drug_section = "\n".join([
            "**Drug/Medication Trials:**",
            "- Specific drugs: 'trials testing metformin', 'liraglutide trials'",
            "- Drug classes: 'GLP-1 trials', 'SGLT2 inhibitor trials'",
            "- Insulin: 'insulin pump trials', 'insulin therapy trials'",
        ])
        lifestyle_section = "\n".join([
            "**Lifestyle Trials:**",
            "- Diet: 'low-carb diet trials', 'Mediterranean diet trials'",
            "- Exercise: 'physical activity trials', 'exercise trials'",
        ])
        technology_section = "\n".join([
            "**Technology Trials:**",
            "- Monitoring: 'CGM trials', 'glucose monitoring trials'",
            "- Apps: 'diabetes app trials', 'digital health trials'",
        ])
        situation_section = "\n".join([
            "**Or describe your situation:**",
            "- 'I'm 55 with type 2 diabetes, what trials can I join?'",
            "- 'Trials for managing high blood sugar'",
        ])
        closing = "What would you like to explore?"

        msg = "\n\n".join([
            intro,
            drug_section,
            lifestyle_section,
            technology_section,
            situation_section,
            closing,
        ])

        # Record this handler's output in the provenance trail.
        self.provenance_chain.append(
            log_provenance_step("GenericQueryHandler", user_input, msg, {"type": "needs_refinement"})
        )

        # The hash is computed over the history plus the current query, and
        # only afterwards is the turn added to the history itself.
        turn_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])
        self.history.append({"query": user_input, "response_hash": turn_hash})

        response = {
            "recommendation": msg,
            "cited_trials": [],
            "safety_status": "Refinement Needed",
            "session_hash": turn_hash,
            "provenance_chain": self.provenance_chain,
        }
        return response




# ============================================================
# GLOBAL BOT INSTANCE + ENTRYPOINT
# ============================================================

# Starting profile handed to the bot when it is constructed.
default_profile = dict(
    user_id="Patient",
    conditions=["diabetes"],
    extracted_conditions=[],
)

# One module-level bot instance, created when this module is executed and
# reused by every call to run_bot().
_bot = HealthcareBot(
    gemini_model,
    embed_model,
    faiss_index,
    chunk_map,
    initial_profile=default_profile,
)

def run_bot(user_input: str) -> Dict[str, Any]:
    """Pass one user message to the module-level bot and return its response dict."""
    response = _bot.process_query(user_input)
    return response

UI frontend: a simple Streamlit web interface for the chatbot

https://docs.streamlit.io/develop/tutorials/chat-and-llm-apps/build-conversational-apps

In [None]:
%%writefile app.py
import os

import streamlit as st

# Streamlit chat front end for the clinical-trial bot.
# The script is re-executed on every interaction, so the conversation is kept
# in st.session_state and replayed at the top of each run.

if "GEMINI_API_KEY" not in os.environ:
    st.error("‚ö†Ô∏è API Key missing! Please run the 'Secure Input' cell in the notebook first.")
    # Halt this run here: without st.stop() the script would carry on, import
    # the bot and render the chat even though the key is absent.
    st.stop()

# Imported only after the key check so a missing key surfaces as the message
# above (presumably the bot module needs the key when it loads — TODO confirm).
from run_bot import run_bot

st.title("Clinical Trial Health Advisor ü§ñ")
st.caption("AI for Healthcare - Clinical Trials RAG")

# First run of a session: start with an empty transcript.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the transcript so earlier turns stay visible after a rerun.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

if user_input := st.chat_input("Describe your symptoms..."):
    # Store and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Only the recommendation text of the bot's response is displayed.
    with st.spinner("Searching clinical trials..."):
        result = run_bot(user_input)
        reply = result["recommendation"]

    with st.chat_message("assistant"):
        st.markdown(reply)

    st.session_state.messages.append({"role": "assistant", "content": reply})

In [None]:
!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64
!mv cloudflared-linux-amd64 cloudflared
!chmod +x cloudflared

In [None]:
# Start the Streamlit app in the background, then expose port 8501 through a Cloudflare tunnel
!streamlit run app.py &>/dev/null&
!./cloudflared tunnel --url http://localhost:8501 --no-autoupdate

In [None]:
# Run this and share output
import pandas as pd

# === REAL PATH (from readlink) ===
BASE = "/content/drive/MyDrive/LLM_Based_GenAI_Sem1/data/"
df = pd.read_csv(f"{BASE}/clinical_trials_diabetes_full.csv")

# Overall size and schema of the dataset.
row_count = len(df)
column_names = list(df.columns)

print("=== DATASET INSPECTION ===")
print(f"Total trials: {row_count}")
print(f"\nColumns: {column_names}")

# First ten titles, printed as a pandas Series.
print("\nSample titles:")
print(df['brief_title'].head(10))

# First 500 characters of the first two summaries.
print("\nSample summaries:")
for position in (0, 1):
    print(df['brief_summary'].iloc[position][:500])


In [None]:
import pandas as pd

# Check whether the NCT IDs cited by the chatbot actually exist in the dataset.
BASE = "/content/drive/MyDrive/LLM_Based_GenAI_Sem1/data/"
df = pd.read_csv(f"{BASE}/clinical_trials_diabetes_full.csv")

# NCT IDs from your chatbot results
test_nct_ids = [
    "NCT00115973",  # Insulin pump
    "NCT01489644",  # Metformin
    "NCT05136287",  # GLP-1 semaglutide
    "NCT02478190",  # High blood sugar management
]

print("=== VERIFICATION: Are these NCT IDs in your database? ===\n")

for nct in test_nct_ids:
    match = df[df['nct_id'] == nct]

    if match.empty:
        print(f"‚ùå {nct} NOT FOUND in database (HALLUCINATION!)")
        print()
        continue

    first_row = match.iloc[0]
    # An empty CSV field is loaded as NaN (a float), which cannot be sliced;
    # treat anything that is not a string as an empty summary.
    summary = first_row['brief_summary']
    summary_text = summary if isinstance(summary, str) else ""

    print(f"‚úÖ {nct} FOUND in database")
    print(f"   Title: {first_row['brief_title']}")
    print(f"   Summary (first 100 chars): {summary_text[:100]}...")
    print()


In [None]:
# Compare the chatbot's description of one trial with the dataset record.
# Relies on `df` already being loaded by an earlier cell.
nct = "NCT00115973"
match = df[df['nct_id'] == nct]

if match.empty:
    # Report the miss explicitly instead of printing nothing at all.
    print(f"{nct} not found in dataset; nothing to compare.")
else:
    actual_title = match.iloc[0]['brief_title']
    # An empty CSV field is loaded as NaN (a float), which has no .lower()
    # and cannot be sliced; fall back to an empty string.
    raw_summary = match.iloc[0]['brief_summary']
    actual_summary = raw_summary if isinstance(raw_summary, str) else ""

    print("=== CHATBOT vs REALITY ===")
    print("\nChatbot said:")
    print("'studied the treatment of type 2 diabetes with an insulin infusion pump'")

    print("\nActual trial title:")
    print(actual_title)

    print("\nActual summary:")
    print(actual_summary[:300])

    # Keyword spot-check of the summary against the chatbot's claim.
    print("\n=== IS CHATBOT DESCRIPTION ACCURATE? ===")
    summary_lower = actual_summary.lower()
    print(f"Mentions 'type 2 diabetes': {'type 2' in summary_lower or 't2d' in summary_lower}")
    print(f"Mentions 'insulin': {'insulin' in summary_lower}")
    print(f"Mentions 'pump' or 'infusion': {'pump' in summary_lower or 'infusion' in summary_lower}")


In [None]:
# Quick check
import pandas as pd

df = pd.read_csv("/content/drive/MyDrive/LLM_Based_GenAI_Sem1/data/clinical_trials_diabetes_full.csv")

# Your chatbot's results
test_ids = ["NCT00115973", "NCT01489644", "NCT05136287", "NCT02478190"]

# Array of every nct_id in the dataset, fetched once for all lookups below.
known_ids = df['nct_id'].values

print("Quick Verification:")
for nct in test_ids:
    exists = nct in known_ids
    if exists:
        verdict = '‚úÖ FOUND'
    else:
        verdict = '‚ùå NOT FOUND (HALLUCINATION!)'
    print(f"{nct}: {verdict}")
