Qdrant

In [1]:
# Install Qdrant client
!pip install qdrant-client -q

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
import getpass

# üîë Secure API Key Input (invisible)
print("üîë Enter your Qdrant API Key (input will be hidden):")
qdrant_api_key = getpass.getpass("Qdrant API Key: ")

# Verify key format
if qdrant_api_key and len(qdrant_api_key) > 10:
    print("‚úÖ API Key captured securely")
else:
    print("‚ö†Ô∏è API Key seems invalid")

# Connect to your cluster
client = QdrantClient(
    url="https://215ec69e-fa22-4f38-bcf3-941e73901a68.us-east4-0.gcp.cloud.qdrant.io",
    api_key=qdrant_api_key
)

# Create collection
client.create_collection(
    collection_name="clinical_trials",
    vectors_config=VectorParams(size=384, distance=Distance.COSINE)
)

print("‚úÖ Collection 'clinical_trials' created successfully!")

# Verify
collections = client.get_collections()
print(f"\nüìä Collections: {collections}")


üîë Enter your Qdrant API Key (input will be hidden):
Qdrant API Key: ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑
‚úÖ API Key captured securely


UnexpectedResponse: Unexpected Response: 409 (Conflict)
Raw response content:
b'{"status":{"error":"Wrong input: Collection `clinical_trials` already exists!"},"time":0.038894767}'

In [1]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os

BASE = "/content/drive/.shortcut-targets-by-id/1-SiVJhXHTHtDYSrPmW_0VfuP7gSTePcj/data"

# Check if files exist
files_to_check = [
    "clinical_trials_all_full_embeddings.npy",
    "clinical_trials_all_full_chunk_map.json",
    "clinical_trials_all_full_faiss.index"
]

print("üìÅ Checking if files exist in Drive:\n")
for filename in files_to_check:
    filepath = f"{BASE}/{filename}"
    exists = os.path.exists(filepath)
    if exists:
        size_mb = os.path.getsize(filepath) / (1024*1024)
        print(f"‚úÖ {filename}: EXISTS ({size_mb:.1f} MB)")
    else:
        print(f"‚ùå {filename}: NOT FOUND")


üìÅ Checking if files exist in Drive:

‚úÖ clinical_trials_all_full_embeddings.npy: EXISTS (384.8 MB)
‚úÖ clinical_trials_all_full_chunk_map.json: EXISTS (230.6 MB)
‚úÖ clinical_trials_all_full_faiss.index: EXISTS (384.8 MB)


Load Data and Upload to Qdrant

In [6]:
import pandas as pd
import numpy as np
import json
from tqdm import tqdm

BASE = "/content/drive/.shortcut-targets-by-id/1-SiVJhXHTHtDYSrPmW_0VfuP7gSTePcj/data"

print("‚è≥ Loading embeddings and chunk map from Drive...")

# Load embeddings (384.8 MB)
embeddings = np.load(f"{BASE}/clinical_trials_all_full_embeddings.npy")
print(f"‚úÖ Loaded {len(embeddings)} embeddings (shape: {embeddings.shape})")

# Load chunk map (metadata - 230.6 MB)
with open(f"{BASE}/clinical_trials_all_full_chunk_map.json", "r") as f:
    chunk_map = json.load(f)
print(f"‚úÖ Loaded {len(chunk_map)} chunks of metadata")

# Verify sizes match
if len(embeddings) == len(chunk_map):
    print(f"‚úÖ Data verified: {len(embeddings)} vectors ready to upload")
else:
    print(f"‚ö†Ô∏è WARNING: Mismatch! {len(embeddings)} embeddings vs {len(chunk_map)} chunks")


‚è≥ Loading embeddings and chunk map from Drive...
‚úÖ Loaded 262660 embeddings (shape: (262660, 384))
‚úÖ Loaded 262660 chunks of metadata
‚úÖ Data verified: 262660 vectors ready to upload


In [7]:
from qdrant_client.models import PointStruct
import getpass

# Reconnect to Qdrant (in case session expired)
print("\nüîë Enter your Qdrant API Key again:")
qdrant_api_key = getpass.getpass("Qdrant API Key: ")

from qdrant_client import QdrantClient

client = QdrantClient(
    url="https://215ec69e-fa22-4f38-bcf3-941e73901a68.us-east4-0.gcp.cloud.qdrant.io",
    api_key=qdrant_api_key
)
print("‚úÖ Connected to Qdrant")

print("\n‚è≥ Uploading vectors to Qdrant...")
print("‚ö†Ô∏è This will take 5-10 minutes for 262K vectors. Please wait...\n")

# Batch upload (100 vectors at a time)
batch_size = 100
total_batches = (len(embeddings) + batch_size - 1) // batch_size

for i in tqdm(range(0, len(embeddings), batch_size), desc="Uploading", total=total_batches):
    batch_end = min(i + batch_size, len(embeddings))

    points = []
    for idx in range(i, batch_end):
        points.append(
            PointStruct(
                id=idx,
                vector=embeddings[idx].tolist(),
                payload={
                    "nct_id": chunk_map[idx]["nct_id"],
                    "title": chunk_map[idx]["title"],
                    "text": chunk_map[idx]["text"],
                    "status": chunk_map[idx]["status"]
                }
            )
        )

    # Upload batch
    client.upsert(
        collection_name="clinical_trials",
        points=points
    )

print(f"\n‚úÖ Successfully uploaded {len(embeddings)} vectors to Qdrant!")

# Verify upload
collection_info = client.get_collection("clinical_trials")
print(f"\nüìä Final Collection Stats:")
print(f"   ‚úÖ Total vectors: {collection_info.points_count:,}")
print(f"   ‚úÖ Vector dimension: {collection_info.config.params.vectors.size}")
print(f"   ‚úÖ Distance metric: {collection_info.config.params.vectors.distance}")



üîë Enter your Qdrant API Key again:
Qdrant API Key: ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑
‚úÖ Connected to Qdrant

‚è≥ Uploading vectors to Qdrant...
‚ö†Ô∏è This will take 5-10 minutes for 262K vectors. Please wait...



Uploading: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2627/2627 [13:20<00:00,  3.28it/s]


‚úÖ Successfully uploaded 262660 vectors to Qdrant!

üìä Final Collection Stats:
   ‚úÖ Total vectors: 262,660
   ‚úÖ Vector dimension: 384
   ‚úÖ Distance metric: Cosine





In [None]:
%%writefile update_qdrant_from_drive.py
"""
Read CSV files from Google Drive ‚Üí Generate embeddings ‚Üí Upload to Qdrant
No intermediate files needed!
"""

import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Distance, VectorParams
from tqdm import tqdm
import os

class QdrantDataPipeline:
    def __init__(self, qdrant_url, qdrant_api_key):
        self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        self.embed_model = SentenceTransformer("all-MiniLM-L6-v2")
        self.collection_name = "clinical_trials"

    def load_and_filter_csvs(self, drive_base_path, csv_names):
        """Load multiple CSVs from Drive and filter"""
        print("üìÇ Loading CSV files from Drive...")

        dfs = []
        for csv_name in csv_names:
            csv_path = f"{drive_base_path}/{csv_name}"
            print(f"   Loading {csv_name}...")
            df = pd.read_csv(csv_path)
            dfs.append(df)

        # Concatenate all
        df_all = pd.concat(dfs, ignore_index=True)
        print(f"‚úÖ Loaded {len(df_all)} total trials")

        # Filter bad statuses
        df_all["status"] = df_all["status"].astype(str).str.strip().str.title()
        bad_status = ["Terminated", "Withdrawn", "Suspended", "No Longer Available", "Unknown"]
        df_clean = df_all[~df_all["status"].isin(bad_status)].copy()

        print(f"‚úÖ After filtering: {len(df_clean)} trials")
        return df_clean

    def create_chunks(self, df_clean):
        """Create text chunks from DataFrame"""
        print("üìù Creating chunks...")

        chunks = []
        for idx, row in df_clean.iterrows():
            title = str(row.get("brief_title", "")).strip()
            summary = str(row.get("brief_summary", "")).strip()

            if len(summary) < 20:
                continue

            text = f"Title: {title}\nSummary: {summary}"

            chunks.append({
                "nct_id": row["nct_id"],
                "title": title,
                "text": text,
                "status": row["status"]
            })

        print(f"‚úÖ Created {len(chunks)} chunks")
        return chunks

    def generate_embeddings(self, chunks):
        """Generate embeddings for all chunks"""
        print("üß† Generating embeddings...")

        texts = [c["text"] for c in chunks]
        embeddings = self.embed_model.encode(
            texts,
            batch_size=64,
            show_progress_bar=True,
            convert_to_numpy=True
        )

        print(f"‚úÖ Generated {len(embeddings)} embeddings")
        return embeddings

    def upload_to_qdrant(self, embeddings, chunks, mode="refresh"):
        """Upload data to Qdrant"""

        if mode == "refresh":
            print("üóëÔ∏è Deleting old collection...")
            try:
                self.client.delete_collection(self.collection_name)
            except:
                pass

            print("üì¶ Creating fresh collection...")
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=384, distance=Distance.COSINE)
            )
            start_id = 0
        else:  # mode == "add"
            collection_info = self.client.get_collection(self.collection_name)
            start_id = collection_info.points_count
            print(f"üìä Adding to existing data, starting from ID: {start_id}")

        print(f"‚è≥ Uploading {len(embeddings)} vectors to Qdrant...")

        batch_size = 100
        for i in tqdm(range(0, len(embeddings), batch_size), desc="Uploading"):
            batch_end = min(i + batch_size, len(embeddings))

            points = []
            for idx in range(i, batch_end):
                points.append(PointStruct(
                    id=start_id + idx,
                    vector=embeddings[idx].tolist(),
                    payload=chunks[idx]
                ))

            self.client.upsert(
                collection_name=self.collection_name,
                points=points
            )

        # Verify
        final_count = self.client.get_collection(self.collection_name).points_count
        print(f"‚úÖ Upload complete! Total vectors in Qdrant: {final_count:,}")

    def run_pipeline_from_drive(self, drive_base_path, csv_names, mode="refresh"):
        """Complete pipeline: Drive CSVs ‚Üí Qdrant"""
        print("\nüöÄ Starting Qdrant Update Pipeline from Drive\n")

        # Step 1: Load and filter CSVs
        df_clean = self.load_and_filter_csvs(drive_base_path, csv_names)

        # Step 2: Create chunks
        chunks = self.create_chunks(df_clean)

        # Step 3: Generate embeddings
        embeddings = self.generate_embeddings(chunks)

        # Step 4: Upload to Qdrant
        self.upload_to_qdrant(embeddings, chunks, mode=mode)

        print("\n‚úÖ Pipeline complete! Your app will now use the updated data.")


# Usage Example
if __name__ == "__main__":
    import getpass

    # Configuration
    DRIVE_BASE = "/content/drive/.shortcut-targets-by-id/1-SiVJhXHTHtDYSrPmW_0VfuP7gSTePcj/data"

    CSV_FILES = [
        "clinical_trials_diabetes_full.csv",
        "clinical_trials_cancer_full.csv",
        "clinical_trials_alzheimer_full.csv",
        "clinical_trials_asthma_full.csv",
        "clinical_trials_cardiovascular_full.csv"
    ]

    QDRANT_URL = "https://215ec69e-fa22-4f38-bcf3-941e73901a68.us-east4-0.gcp.cloud.qdrant.io"

    # Get API key
    qdrant_key = getpass.getpass("üîë Enter Qdrant API Key: ")

    # Run pipeline
    pipeline = QdrantDataPipeline(QDRANT_URL, qdrant_key)
    pipeline.run_pipeline_from_drive(
        drive_base_path=DRIVE_BASE,
        csv_names=CSV_FILES,
        mode="refresh"  # Change to "add" to append instead of replace
    )


Update Code to Use Qdrant Instead of FAISS

In [2]:
%%writefile utils_qdrant.py
import json
import hashlib
from datetime import datetime
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

# --- Confidence score from distance ---

def calculate_confidence_score(distance: float, normalization_factor: float = 1.0) -> float:
    """Inverse L2 distance score in (0,1]; closer = higher confidence."""
    return normalization_factor / (normalization_factor + float(distance))


# --- Load Qdrant client + embedding model ---

def load_qdrant_and_model(qdrant_url: str, qdrant_api_key: str):
    """Loads Qdrant client and embedding model."""
    print("‚è≥ Connecting to Qdrant...")

    client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

    # Verify connection
    collection_info = client.get_collection("clinical_trials")
    print(f"‚úÖ Connected to Qdrant: {collection_info.points_count:,} vectors ready")

    # Load embedding model (same as before)
    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    print("‚úÖ Embedding model loaded")

    return client, embed_model


# --- Provenance logging (unchanged) ---

def log_provenance_step(agent_name: str, input_data, output_data, detail=None):
    """Creates a detailed log entry for a single agent step."""
    log_entry = {
        "timestamp": datetime.now().isoformat(),
        "agent": agent_name,
        "input": input_data,
        "output": output_data,
        "detail": detail or {},
        "model_version": "gemini-2.0-flash",
    }
    return log_entry


# --- Reproducibility hash (unchanged) ---

def generate_reproducibility_hash(conversation_history, corpus_version: str = "v1.0"):
    """Generates a deterministic session hash based on conversation history."""
    queries = [turn.get("query", "") for turn in conversation_history]
    raw = f"{corpus_version}|{'|'.join(queries)}"
    return hashlib.md5(raw.encode("utf-8")).hexdigest()


Overwriting utils_qdrant.py


In [3]:
import getpass
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

# Get API key
qdrant_api_key = getpass.getpass("üîë Qdrant API Key: ")

# Connect
qdrant_client = QdrantClient(
    url="https://215ec69e-fa22-4f38-bcf3-941e73901a68.us-east4-0.gcp.cloud.qdrant.io",
    api_key=qdrant_api_key
)

# Check what methods are available
print("Available search methods:")
print([m for m in dir(qdrant_client) if 'search' in m.lower() or 'query' in m.lower()])


üîë Qdrant API Key: ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑
Available search methods:
['_resolve_query', '_resolve_query_batch_request', '_resolve_query_request', '_scored_points_to_query_responses', 'query', 'query_batch', 'query_batch_points', 'query_points', 'query_points_groups', 'search_matrix_offsets', 'search_matrix_pairs']


In [4]:
import getpass
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

# Get API key
qdrant_api_key = getpass.getpass("üîë Qdrant API Key: ")

# Connect
qdrant_client = QdrantClient(
    url="https://215ec69e-fa22-4f38-bcf3-941e73901a68.us-east4-0.gcp.cloud.qdrant.io",
    api_key=qdrant_api_key
)

# Load model
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# Test search
test_query = "diabetes insulin therapy trials"
q_emb = embed_model.encode([test_query])[0]

# Use query_points (correct method)
results = qdrant_client.query_points(
    collection_name="clinical_trials",
    query=q_emb.tolist(),
    limit=3
)

print(f"\nüîç Test Query: '{test_query}'")
print(f"\nüìä Top 3 Results:\n")

for i, point in enumerate(results.points, 1):
    print(f"{i}. NCT ID: {point.payload['nct_id']}")
    print(f"   Score: {point.score:.3f}")
    print(f"   Title: {point.payload['title'][:80]}...")
    print()

print("‚úÖ Qdrant search working!")


üîë Qdrant API Key: ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



üîç Test Query: 'diabetes insulin therapy trials'

üìä Top 3 Results:

1. NCT ID: NCT00151697
   Score: 0.752
   Title: LANN-study: Lantus, Amaryl, Novorapid, Novomix Study...

2. NCT ID: NCT00151697
   Score: 0.752
   Title: LANN-study: Lantus, Amaryl, Novorapid, Novomix Study...

3. NCT ID: NCT02192424
   Score: 0.697
   Title: Early Intermittent Intensive Insulin Therapy as an Effective Treatment of Type 2...

‚úÖ Qdrant search working!


Fix RetrievalAgent to Use query_points

In [5]:
%%writefile retrieval_agent_qdrant.py
import numpy as np
from typing import Dict, Any
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

# Import scoring from your existing code
from utils_qdrant import calculate_confidence_score, log_provenance_step

# Try to import reranker
CrossEncoder = None
try:
    from sentence_transformers import CrossEncoder
except ImportError:
    pass


class RetrievalAgentQdrant:
    """Retrieval agent using Qdrant instead of FAISS."""

    def __init__(
        self,
        qdrant_client: QdrantClient,
        embed_model: SentenceTransformer,
        evidence_scorer,
        profile_agent=None
    ):
        self.client = qdrant_client
        self.embed_model = embed_model
        self.evidence_scorer = evidence_scorer
        self.profile_agent = profile_agent

        # Optional: Load reranker
        self.reranker = None
        if CrossEncoder:
            try:
                print("‚è≥ Loading Cross-Encoder reranker...")
                self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
                print("‚úÖ Reranker loaded")
            except Exception as e:
                print(f"‚ö†Ô∏è Reranker failed to load: {e}")

    def retrieve(self, parsed: Dict[str, Any], top_k: int = 5):
        """Retrieve trials from Qdrant."""
        FETCH_K = top_k * 3

        symptoms = parsed.get("symptoms") or []
        context = parsed.get("context") or ""
        query = parsed.get("user_question") or (" ".join(symptoms) + " " + context).strip()

        if not query:
            retrieval = {"query": "", "trials": [], "avg_confidence": 0.0}
            log = log_provenance_step("RetrievalAgentQdrant", parsed, retrieval, {"reason": "empty_query"})
            return retrieval, log

        # Query expansion (same as before)
        EXPANSIONS = {
            "insulin": "insulin OR insulin therapy OR insulin treatment OR insulin pump",
            "medication": "medication OR drug OR pharmaceutical OR pharmacological OR treatment",
            "diet": "diet OR dietary OR nutrition OR nutritional OR eating",
            "exercise": "exercise OR physical activity OR fitness OR activity",
            "chemo": "chemotherapy OR antineoplastic OR oncology",
            "cancer": "cancer OR tumor OR tumour OR malignancy OR oncology",
            "alzheim": "alzheimer OR dementia OR cognitive decline OR memory loss",
        }

        query_lower = query.lower()
        for term, expansion in EXPANSIONS.items():
            if term in query_lower:
                query = f"{query} {expansion}"
                break

        # 1. Generate query embedding
        q_emb = self.embed_model.encode([query])[0]

        # 2. Search Qdrant (FIXED: use query_points)
        search_results = self.client.query_points(
            collection_name="clinical_trials",
            query=q_emb.tolist(),
            limit=FETCH_K
        )

        # 3. Convert to candidate format
        initial_candidates = []
        for point in search_results.points:
            initial_candidates.append({
                "nct_id": point.payload["nct_id"],
                "title": point.payload.get("title", ""),
                "text": point.payload["text"],
                "status": point.payload.get("status", "Unknown Status"),
                "qdrant_score": point.score,  # Cosine similarity (higher = better)
            })

        final_trials = []

        # 4. Optional CrossEncoder reranking
        if self.reranker and initial_candidates:
            pairs = [[query, cand["text"]] for cand in initial_candidates]
            scores = self.reranker.predict(pairs)

            for i, cand in enumerate(initial_candidates):
                cand["rerank_score"] = float(scores[i])

            initial_candidates.sort(key=lambda x: x["rerank_score"], reverse=True)
            top_hits = initial_candidates[:top_k]

            for rank, item in enumerate(top_hits):
                logit = item["rerank_score"]
                base_conf = 1 / (1 + np.exp(-logit))

                scoring_result = self.evidence_scorer.calculate_weighted_score(
                    trial=item,
                    base_confidence=base_conf,
                    query=query,
                )

                final_trials.append({
                    "nct_id": item["nct_id"],
                    "title": item["title"],
                    "text": item["text"],
                    "status": item["status"],
                    "confidence": base_conf,
                    "weighted_score": scoring_result["weighted_score"],
                    "score_breakdown": scoring_result["breakdown"],
                    "rank": rank + 1,
                    "method": "qdrant_evidence_weighted",
                })
        else:
            # Qdrant-only path (no reranking)
            top_hits = initial_candidates[:top_k]
            for rank, item in enumerate(top_hits):
                # Qdrant uses cosine similarity (0-1, higher = better)
                base_conf = item["qdrant_score"]

                scoring_result = self.evidence_scorer.calculate_weighted_score(
                    trial=item,
                    base_confidence=base_conf,
                    query=query,
                )

                final_trials.append({
                    "nct_id": item["nct_id"],
                    "title": item["title"],
                    "text": item["text"],
                    "status": item["status"],
                    "confidence": base_conf,
                    "weighted_score": scoring_result["weighted_score"],
                    "score_breakdown": scoring_result["breakdown"],
                    "rank": rank + 1,
                    "method": "qdrant_evidence_weighted",
                })

        # Sort by weighted score
        final_trials.sort(key=lambda x: x["weighted_score"], reverse=True)
        for i, trial in enumerate(final_trials):
            trial["rank"] = i + 1

        confs = [t["weighted_score"] for t in final_trials]
        avg_conf = float(np.mean(confs)) if confs else 0.0

        retrieval = {
            "query": query,
            "trials": final_trials,
            "avg_confidence": avg_conf,
        }

        detail = {
            "top_k": top_k,
            "avg_confidence": avg_conf,
            "num_trials": len(final_trials),
            "method": "qdrant_reranked" if self.reranker else "qdrant_only",
        }

        log = log_provenance_step("RetrievalAgentQdrant", parsed, retrieval, detail)
        return retrieval, log


Overwriting retrieval_agent_qdrant.py


Update Main Bot Code to Use Qdrant

In [6]:
%%writefile run_bot_qdrant.py
"""
Updated HealthcareBot using Qdrant instead of FAISS
"""

import json
import re
import os
import sys
from typing import List, Dict, Any
import numpy as np
import requests
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

# Import utilities
from utils_qdrant import (
    load_qdrant_and_model,
    log_provenance_step,
    generate_reproducibility_hash,
)

from retrieval_agent_qdrant import RetrievalAgentQdrant

# CrossEncoder
CrossEncoder = None
try:
    from sentence_transformers import CrossEncoder
except:
    pass


# ============================================================
# PARSER
# ============================================================
class SymptomParser:
    def __init__(self, model):
        self.model = model

    def parse(self, text: str):
        """
        Enhanced parser for clinical trial search queries.
        Decides:
        - Are they searching for trials or just asking a question?
        - Which disease (diabetes, cancer, Alzheimer‚Äôs, asthma, cardiovascular) is implied?
        """
        prompt = (
            "You are a clinical trial search classifier for medical research.\n"
            "You support conditions including diabetes, cancer, Alzheimer's disease, asthma, and cardiovascular disease.\n\n"
            f"User Input: \"{text}\"\n\n"
            "Your tasks:\n"
            "1) Decide if the user is searching for clinical trials or just asking a general question.\n"
            "2) Detect which disease(s) they are talking about.\n"
            "3) Detect if the query is not about health or clinical trials (off_topic).\n\n"
            "Classification Rules:\n"
            "- If the query mentions or implies trials, studies, research, clinical experiments, etc. ‚Üí intent='trial_search'\n"
            "- If the user is mainly describing themselves (age, diagnosis, comorbidities, meds) ‚Üí intent='profile_info'\n"
            "- If they ask 'what is X', 'how does Y work', etc. without asking about trials ‚Üí intent='general_question'\n"
            "- Simple greetings (hi, hello, hey) ‚Üí intent='greeting'\n"
            "- If clearly not about health or clinical research ‚Üí intent='off_topic', is_disease_related=false\n\n"
            "You must detect disease_focus whenever possible:\n"
            "- diabetes: diabetes, blood sugar, glucose, insulin, HbA1c, metformin, GLP-1, SGLT2\n"
            "- cancer: cancer, tumor/tumour, chemotherapy, oncology, breast cancer, lung cancer, leukemia, lymphoma\n"
            "- alzheimers: Alzheimer's, dementia, memory loss, cognitive decline\n"
            "- asthma: asthma, wheezing, bronchodilator, inhaler\n"
            "- cardiovascular: heart failure, cardiovascular disease, hypertension, high blood pressure, angina,\n"
            "  myocardial infarction, coronary artery disease, stroke\n\n"
            "Return ONLY valid JSON with this exact format:\n"
            "{\n"
            "  \"intent\": \"trial_search\" | \"profile_info\" | \"general_question\" | \"greeting\" | \"off_topic\",\n"
            "  \"query_type\": \"trial_query\" | \"profile_statement\" | \"knowledge_seeking\" | \"greeting\",\n"
            "  \"search_keywords\": [\"keyword1\", \"keyword2\"],\n"
            "  \"is_disease_related\": true or false,\n"
            "  \"disease_focus\": [\"diabetes\", \"cancer\", \"alzheimers\", \"asthma\", \"cardiovascular\"],\n"
            "  \"user_question\": \"the question in plain English\",\n"
            "  \"trial_interest\": \"what type of trial they want (diet, medication, technology, surgery, etc.)\"\n"
            "}\n\n"
            "Examples:\n"
            "- 'What trials study liraglutide in diabetes?' ‚Üí intent='trial_search', query_type='trial_query',\n"
            "  is_disease_related=true, disease_focus=['diabetes'], search_keywords=['liraglutide']\n"
            "- 'My mom has breast cancer, are there trials?' ‚Üí intent='trial_search', disease_focus=['cancer']\n"
            "- 'I am 70 with memory loss and Alzheimer's' ‚Üí intent='profile_info', disease_focus=['alzheimers']\n"
            "- 'What is HbA1c?' ‚Üí intent='general_question', disease_focus=['diabetes']\n"
            "- 'What is the weather in Paris?' ‚Üí intent='off_topic', is_disease_related=false, disease_focus=[]\n"
        )

        try:
            res = self.model.generate_content(prompt)
            raw = (res.text or "").strip()
            match = re.search(r"\{.*\}", raw, re.DOTALL)
            if match:
                parsed = json.loads(match.group(0))
            else:
                parsed = json.loads(raw)
        except Exception:
            # Fallback: simple heuristic if model fails
            text_lower = text.lower()
            disease_focus = []
            if any(x in text_lower for x in ["diabetes", "insulin", "glucose", "hba1c", "metformin", "glp-1", "sglt2"]):
                disease_focus.append("diabetes")
            if any(x in text_lower for x in ["cancer", "tumor", "tumour", "chemo", "chemotherapy", "oncology"]):
                disease_focus.append("cancer")
            if any(x in text_lower for x in ["alzheimer", "dementia", "memory loss", "cognitive decline"]):
                disease_focus.append("alzheimers")
            if "asthma" in text_lower or "wheezing" in text_lower:
                disease_focus.append("asthma")
            if any(x in text_lower for x in ["heart failure", "cardiovascular", "hypertension",
                                             "high blood pressure", "angina", "myocardial", "coronary", "stroke"]):
                disease_focus.append("cardiovascular")

            if any(kw in text_lower for kw in ["trial", "study", "research", "clinical"]):
                intent = "trial_search"
                query_type = "trial_query"
            elif any(kw in text_lower for kw in ["hi", "hello", "hey"]):
                intent = "greeting"
                query_type = "greeting"
            else:
                intent = "general_question"
                query_type = "knowledge_seeking"

            parsed = {
                "intent": intent,
                "query_type": query_type,
                "search_keywords": [text] if intent == "trial_search" else [],
                "is_disease_related": bool(disease_focus),
                "disease_focus": disease_focus,
                "user_question": text,
                "trial_interest": "general",
            }

        # --- Heuristic correction layer on top of model output ---
        text_lower = text.lower()
        diseases = set(parsed.get("disease_focus") or [])

        if any(x in text_lower for x in ["diabetes", "insulin", "glucose", "hba1c", "metformin", "glp-1", "sglt2"]):
            diseases.add("diabetes")
        if any(x in text_lower for x in ["cancer", "tumor", "tumour", "chemo", "chemotherapy", "oncology"]):
            diseases.add("cancer")
        if any(x in text_lower for x in ["alzheimer", "alzheimers", "dementia", "memory loss", "cognitive decline"]):
            diseases.add("alzheimers")
        if "asthma" in text_lower or "wheezing" in text_lower or "inhaler" in text_lower:
            diseases.add("asthma")
        if any(x in text_lower for x in ["heart failure", "cardiovascular", "hypertension",
                                         "high blood pressure", "angina", "myocardial", "coronary", "stroke"]):
            diseases.add("cardiovascular")

        parsed["disease_focus"] = list(diseases)

        # Force trial_search if obvious trial keywords
        trial_keywords = [
            "trial", "study", "studies", "research",
            "clinical", "show me", "are there", "what trials"
        ]
        if any(kw in text_lower for kw in trial_keywords):
            parsed["intent"] = "trial_search"
            parsed["query_type"] = "trial_query"

        # If we detected diseases, ensure is_disease_related = True
        if diseases and parsed.get("intent") != "off_topic":
            parsed["is_disease_related"] = True
        elif "is_disease_related" not in parsed:
            parsed["is_disease_related"] = bool(diseases)

        log = log_provenance_step("SymptomParser", text, parsed)
        return parsed, log


# ============================================================
# PROFILE AGENT
# ============================================================
class ProfileAgent:
    def __init__(self, initial_profile: Dict[str, Any] = None):
        if initial_profile is None:
            initial_profile = {
                "user_id": "Patient",
                "conditions": [],          # could be filled later
                "extracted_conditions": [],  # dynamic memory
                "history": [],
            }
        self.profile = initial_profile

    def update_profile(self, turn_data: Dict[str, Any]):
        """
        Updates history and extracts persistent medical entities.
        """
        self.profile.setdefault("history", []).append(turn_data)
        self.profile.setdefault("extracted_conditions", [])

        parsed = turn_data.get("parsed", {})
        # optional: keep disease_focus as conditions
        diseases = parsed.get("disease_focus") or []
        if diseases:
            current = set(self.profile["extracted_conditions"])
            for d in diseases:
                current.add(d)
            self.profile["extracted_conditions"] = list(current)

        snapshot = {
            "user_id": self.profile.get("user_id", "Patient"),
            "known_conditions": self.profile.get("extracted_conditions", []),
            "num_turns": len(self.profile["history"]),
        }
        log = log_provenance_step("ProfileAgent", turn_data, {"profile_snapshot": snapshot})
        return log


# ============================================================
# EVIDENCE-WEIGHTED SCORER
# ============================================================
class EvidenceWeightedScorer:
    """
    Implements evidence-weighted scoring for clinical trials.
    Ranks trials based on multiple quality factors beyond semantic similarity.
    """

    def __init__(self):
        self.status_weights = {
            "Completed": 1.0,
            "Active, Not Recruiting": 0.9,
            "Recruiting": 0.85,
            "Enrolling By Invitation": 0.8,
            "Not Yet Recruiting": 0.6,
            "Terminated": 0.4,
            "Withdrawn": 0.3,
            "Suspended": 0.3,
            "Unknown Status": 0.5,
        }

        self.design_keywords = {
            "randomized controlled": 1.0,
            "double-blind": 0.95,
            "randomized": 0.9,
            "controlled": 0.85,
            "interventional": 0.8,
            "prospective": 0.75,
            "observational": 0.6,
            "retrospective": 0.5,
        }

    def calculate_weighted_score(
        self,
        trial: Dict[str, Any],
        base_confidence: float,
        query: str,
    ) -> Dict[str, Any]:
        """
        Calculate evidence-weighted score for a trial.
        """

        # Factor 1: Base semantic match (35%)
        match_score = base_confidence * 0.35

        # Factor 2: Trial status quality (25%)
        status = str(trial.get("status", "Unknown Status")).strip().title()
        status_score = self.status_weights.get(status, 0.5) * 0.25

        # Factor 3: Study design quality (20%)
        design_score = self._calculate_design_quality(trial) * 0.20

        # Factor 4: Keyword density (10%)
        keyword_score = self._calculate_keyword_density(trial, query) * 0.10

        # Factor 5: Metadata completeness (10%)
        completeness_score = self._calculate_completeness(trial) * 0.10

        weighted_score = (
            match_score +
            status_score +
            design_score +
            keyword_score +
            completeness_score
        )

        breakdown = {
            "base_confidence": base_confidence,
            "weighted_score": weighted_score,
            "factors": {
                "semantic_match": match_score,
                "trial_status": status_score,
                "study_design": design_score,
                "keyword_density": keyword_score,
                "completeness": completeness_score,
            },
        }

        return {
            "weighted_score": min(weighted_score, 1.0),
            "breakdown": breakdown,
        }

    def _calculate_design_quality(self, trial: Dict[str, Any]) -> float:
        text = f"{trial.get('title', '')} {trial.get('text', '')}".lower()
        max_score = 0.0
        for keyword, weight in self.design_keywords.items():
            if keyword in text:
                max_score = max(max_score, weight)
        return max_score if max_score > 0 else 0.6

    def _calculate_keyword_density(self, trial: Dict[str, Any], query: str) -> float:
        text = f"{trial.get('title', '')} {trial.get('text', '')}".lower()
        stopwords = {
            "the", "a", "an", "and", "or", "for", "with", "in", "on", "at", "to",
            "of", "is", "are", "what", "trials", "trial", "study", "studies", "clinical"
        }
        query_terms = [
            term for term in query.lower().split()
            if term not in stopwords and len(term) > 2
        ]
        if not query_terms:
            return 0.5
        matches = sum(1 for term in query_terms if term in text)
        density = matches / len(query_terms)
        return min(density, 1.0)

    def _calculate_completeness(self, trial: Dict[str, Any]) -> float:
        # Our chunk_map has "title" and "text"; treat longer text as more complete
        text = trial.get("text", "") or ""
        title = trial.get("title", "") or ""
        score = 0.0
        if len(title) > 10:
            score += 0.3
        if len(text) > 200:
            score += 0.7
        return min(score, 1.0)


# ============================================================
# PubMed Helper (NCT ‚Üí PubMed abstract)
# ============================================================
def fetch_pubmed_abstract_for_nct(nct_id: str):
    """
    Try to find a PubMed article linked to this NCT ID and return its abstract.
    Returns: {"pmid": str, "abstract": str} or None
    """
    try:
        esearch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
        params = {
            "db": "pubmed",
            "term": f"{nct_id}[si]",
            "retmode": "json",
            "retmax": 1,
        }
        r = requests.get(esearch_url, params=params, timeout=10)
        r.raise_for_status()
        data = r.json()
        idlist = data.get("esearchresult", {}).get("idlist", [])
        if not idlist:
            return None

        pmid = idlist[0]

        efetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
        params = {
            "db": "pubmed",
            "id": pmid,
            "rettype": "abstract",
            "retmode": "text",
        }
        r2 = requests.get(efetch_url, params=params, timeout=10)
        r2.raise_for_status()
        abstract_text = r2.text.strip()
        if not abstract_text:
            return None

        return {"pmid": pmid, "abstract": abstract_text}
    except Exception:
        return None


# ============================================================
# RETRIEVAL AGENT
# ============================================================
class RetrievalAgent:
    def __init__(self, embed_model, faiss_index, chunk_map, profile_agent: ProfileAgent = None):
        self.embed_model = embed_model
        self.index = faiss_index
        self.chunk_map = chunk_map
        self.profile_agent = profile_agent
        self.evidence_scorer = EvidenceWeightedScorer()

    def retrieve(self, parsed: Dict[str, Any], top_k: int = 5):
        FETCH_K = top_k * 3

        symptoms = parsed.get("symptoms") or []
        context = parsed.get("context") or ""
        query = parsed.get("user_question") or (" ".join(symptoms) + " " + context).strip()

        if not query:
            retrieval = {"query": "", "trials": [], "avg_confidence": 0.0}
            log = log_provenance_step("RetrievalAgent", parsed, retrieval, {"reason": "empty_query"})
            return retrieval, log

        # Simple expansions (still useful across diseases)
        EXPANSIONS = {
            "insulin": "insulin OR insulin therapy OR insulin treatment OR insulin pump",
            "medication": "medication OR drug OR pharmaceutical OR pharmacological OR treatment",
            "diet": "diet OR dietary OR nutrition OR nutritional OR eating",
            "exercise": "exercise OR physical activity OR fitness OR activity",
            "chemo": "chemotherapy OR antineoplastic OR oncology",
            "cancer": "cancer OR tumor OR tumour OR malignancy OR oncology",
            "alzheim": "alzheimer OR dementia OR cognitive decline OR memory loss",
        }

        query_lower = query.lower()
        for term, expansion in EXPANSIONS.items():
            if term in query_lower:
                query = f"{query} {expansion}"
                break

        # 1. FAISS retrieval
        q_emb = self.embed_model.encode([query])
        distances, indices = self.index.search(q_emb.astype("float32"), FETCH_K)

        initial_candidates = []
        for rank, idx in enumerate(indices[0]):
            if idx == -1:
                continue
            item = self.chunk_map[idx]
            dist = float(distances[0][rank])
            initial_candidates.append({
                "nct_id": item["nct_id"],
                "title": item.get("title", ""),
                "text": item["text"],
                "status": item.get("status", "Unknown Status"),
                "faiss_dist": dist,
            })

        final_trials = []
        confs = []

        # 2. Optional CrossEncoder reranking
        if reranker and initial_candidates:
            pairs = [[query, cand["text"]] for cand in initial_candidates]
            scores = reranker.predict(pairs)

            for i, cand in enumerate(initial_candidates):
                cand["rerank_score"] = float(scores[i])

            initial_candidates.sort(key=lambda x: x["rerank_score"], reverse=True)
            top_hits = initial_candidates[:top_k]

            for rank, item in enumerate(top_hits):
                logit = item["rerank_score"]
                base_conf = 1 / (1 + np.exp(-logit))

                scoring_result = self.evidence_scorer.calculate_weighted_score(
                    trial=item,
                    base_confidence=base_conf,
                    query=query,
                )

                final_trials.append({
                    "nct_id": item["nct_id"],
                    "title": item["title"],
                    "text": item["text"],
                    "status": item["status"],
                    "confidence": base_conf,
                    "weighted_score": scoring_result["weighted_score"],
                    "score_breakdown": scoring_result["breakdown"],
                    "rank": rank + 1,
                    "method": "evidence_weighted",
                })

            final_trials.sort(key=lambda x: x["weighted_score"], reverse=True)
            for i, trial in enumerate(final_trials):
                trial["rank"] = i + 1

            confs = [t["weighted_score"] for t in final_trials]

        else:
            # FAISS-only path
            top_hits = initial_candidates[:top_k]
            for rank, item in enumerate(top_hits):
                base_conf = calculate_confidence_score(item["faiss_dist"])
                scoring_result = self.evidence_scorer.calculate_weighted_score(
                    trial=item,
                    base_confidence=base_conf,
                    query=query,
                )
                final_trials.append({
                    "nct_id": item["nct_id"],
                    "title": item["title"],
                    "text": item["text"],
                    "status": item["status"],
                    "confidence": base_conf,
                    "weighted_score": scoring_result["weighted_score"],
                    "score_breakdown": scoring_result["breakdown"],
                    "rank": rank + 1,
                    "method": "evidence_weighted_faiss",
                })

            final_trials.sort(key=lambda x: x["weighted_score"], reverse=True)
            for i, trial in enumerate(final_trials):
                trial["rank"] = i + 1

            confs = [t["weighted_score"] for t in final_trials]

        avg_conf = float(np.mean(confs)) if confs else 0.0

        retrieval = {
            "query": query,
            "trials": final_trials,
            "avg_confidence": avg_conf,
        }

        detail = {
            "top_k": top_k,
            "avg_confidence": avg_conf,
            "num_trials": len(final_trials),
            "method": "reranked" if reranker else "faiss_only",
        }

        log = log_provenance_step("RetrievalAgent", parsed, retrieval, detail)
        return retrieval, log


# ============================================================
# DIAGNOSIS / ADVISOR
# ============================================================
class DiagnosisAdvisor:
    def __init__(self, model):
        self.model = model

    def _handle_general_question(self, parsed: Dict[str, Any], retrieved: Dict[str, Any]):
        """Handle general medical knowledge questions."""
        trials = retrieved.get("trials", [])
        user_question = parsed.get("user_question") or " ".join(parsed.get("symptoms", []))

        evidence_parts = []
        for t in trials[:3]:
            evidence_parts.append(f"Trial {t['nct_id']}: {t['text'][:400]}")
        evidence = "\n\n".join(evidence_parts) if evidence_parts else "No specific trials available."

        prompt = (
            "You are a medical research educator. Answer the user's question clearly using reliable medical knowledge.\n"
            "The clinical trial evidence below provides real-world context - mention it if helpful.\n\n"
            f"USER'S QUESTION: {user_question}\n\n"
            "CLINICAL TRIAL CONTEXT (for reference only):\n"
            f"{evidence}\n\n"
            "Instructions:\n"
            "- Answer the question directly in 3‚Äì5 sentences.\n"
            "- Be specific and educational.\n"
            "- Do NOT give diagnoses or treatment instructions.\n"
            "- End with: 'For personalized advice, please consult your healthcare provider.'\n"
        )

        try:
            res = self.model.generate_content(prompt)
            text = (res.text or "").strip()
            if not text or len(text) < 50:
                text = (
                    "I don't have enough information to answer this question accurately. "
                    "For personalized guidance, please consult your healthcare provider."
                )
            return text
        except Exception:
            return (
                "I'm unable to generate a detailed answer right now. "
                "For personalized guidance, please consult your healthcare provider."
            )

    def _handle_symptom_query(
        self,
        parsed: Dict[str, Any],
        retrieved: Dict[str, Any],
        profile: Dict[str, Any],
    ):
        """
        Generate response for clinical trial search queries with
        readable paragraph summaries and PubMed abstracts when available.
        """
        trials = retrieved.get("trials", [])
        if not trials:
            return "No relevant trials were found. Please try refining your query."

        formatted_trials = []
        for t in trials[:5]:
            title = t.get("title", "") or t["text"].split("\n")[0].replace("Title: ", "")
            status = t.get("status", "Unknown")
            weighted_score = t.get("weighted_score", 0.0)

            # Extract the ClinicalTrials.gov summary text
            raw_text = t.get("text", "")
            brief_summary = raw_text.split("Summary:", 1)[-1].strip() if "Summary:" in raw_text else raw_text

            if brief_summary:
                # Ask Gemini to turn the CT.gov summary into a short paragraph
                prompt = (
                    "Rewrite the following clinical trial description as a short, clear paragraph "
                    "about what the study is testing:\n\n"
                    f"{brief_summary}\n\n"
                    "Guidelines:\n"
                    "- Use 2‚Äì4 sentences.\n"
                    "- Plain English, minimal jargon.\n"
                    "- Include the purpose and the main type of participant.\n"
                )
                try:
                    res = self.model.generate_content(prompt)
                    brief_summary = res.text.strip() if res.text else brief_summary
                except Exception:
                    if len(brief_summary) > 600:
                        brief_summary = brief_summary[:600] + "..."
            else:
                brief_summary = "No summary available."

            # PubMed abstract lookup
            pubmed_block = ""
            pub = fetch_pubmed_abstract_for_nct(t["nct_id"])
            if pub:
                abs_text = pub["abstract"]
                max_len = 2000
                if len(abs_text) > max_len:
                    abs_text = abs_text[:max_len] + "..."
                pubmed_block = (
                    f"  PubMed abstract (PMID {pub['pmid']}):\n"
                    f"  {abs_text}\n\n"
                    f"  PubMed link: https://pubmed.ncbi.nlm.nih.gov/{pub['pmid']}/\n\n"
                )

            formatted_trials.append(
                f"**{t['nct_id']}** (Relevance: {weighted_score:.0%})\n"
                f"‚Ä¢ {title}\n"
                f"  Status: {status}\n\n"
                f"  {brief_summary}\n\n"
                f"{pubmed_block}"
            )

        trials_text = "\n\n".join(formatted_trials)
        num_trials = len(formatted_trials)

        response = (
            f"I found {num_trials} clinical trial{'s' if num_trials != 1 else ''} relevant to your request:\n\n"
            f"{trials_text}\n\n"
            "Summary: These trials explore potential treatments or management strategies for the condition you asked about. "
            "More details are available using the listed NCT IDs.\n\n"
            "To learn more or consider participation, visit clinicaltrials.gov and search by NCT ID. "
            "Always discuss clinical trial options with your healthcare provider."
        )

        return response

    def advise(self, parsed: Dict[str, Any], retrieved: Dict[str, Any], profile: Dict[str, Any]):
        trials = retrieved.get("trials", [])
        avg_conf = retrieved.get("avg_confidence", 0.0)
        query_type = parsed.get("query_type", "trial_query")
        is_disease_related = parsed.get("is_disease_related", True)

        draft = {
            "recommendation": "",
            "avg_confidence": avg_conf,
            "query_type": query_type,
        }

        if not is_disease_related:
            draft["recommendation"] = (
                "I‚Äôm specialized in clinical trials for medical conditions (for example diabetes, cancer, "
                "Alzheimer‚Äôs disease, asthma, and cardiovascular diseases). "
                "Your question does not appear to be about a health condition or clinical research. "
                "If you‚Äôd like, you can ask me about trials for a specific condition."
            )
            draft["confidence_veto"] = True
            log = log_provenance_step(
                "DiagnosisAdvisor",
                parsed,
                draft,
                {"veto": True, "reason": "off_topic"},
            )
            return draft, log

        if not trials or avg_conf < 0.05:
            draft["recommendation"] = (
                "Based on the trials I retrieved, I don‚Äôt have strong enough evidence to answer this question directly. "
                "Please consult your healthcare provider for personalized advice."
            )
            draft["confidence_veto"] = True
            log = log_provenance_step(
                "DiagnosisAdvisor",
                parsed,
                draft,
                {"veto": True, "reason": "low_confidence"},
            )
            return draft, log

        if query_type == "knowledge_seeking":
            draft["recommendation"] = self._handle_general_question(parsed, retrieved)
        else:
            draft["recommendation"] = self._handle_symptom_query(parsed, retrieved, profile)

        draft["confidence_veto"] = False
        log = log_provenance_step("DiagnosisAdvisor", parsed, draft)
        return draft, log


# ============================================================
# SAFETY FILTER
# ============================================================
class ActiveSafetyFilter:
    def __init__(self, model):
        self.model = model
        self.safety_cfg = {
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

    def verify(self, advice_text: str, trials: List[Dict[str, Any]]):
        # Skip safety check for list-type responses about trials
        if any(marker in advice_text for marker in ["NCT", "clinical trial", "clinicaltrials.gov"]):
            log = log_provenance_step(
                "ActiveSafetyFilter",
                {"advice": advice_text},
                {"final_text": advice_text, "status": "Pass (Trial Listing)"},
            )
            return advice_text, "Pass (Trial Listing)", log

        evidence_text = "\n".join(t["text"][:500] for t in trials[:3])

        audit_prompt = (
            "You are a Medical Safety Officer reviewing AI-generated advice.\n\n"
            "ADVICE:\n"
            f"{advice_text}\n\n"
            "EVIDENCE FROM CLINICAL TRIALS (for context):\n"
            f"{evidence_text}\n\n"
            "Check for safety issues:\n"
            "- If the advice suggests starting/stopping/changing medication without a doctor ‚Üí UNSAFE.\n"
            "- If it gives a diagnosis ‚Üí UNSAFE.\n"
            "- If it makes strong clinical claims not supported by evidence ‚Üí UNSAFE.\n"
            "- If it simply lists clinical trials with neutral wording and a recommendation to talk to a doctor ‚Üí SAFE.\n\n"
            "If the advice is acceptable, respond with exactly: SAFE\n"
            "If it is not acceptable, respond starting with: CORRECTED: <safer version>\n"
        )

        try:
            res = self.model.generate_content(audit_prompt, safety_settings=self.safety_cfg)
            txt = (res.text or "").strip()
            if txt.startswith("SAFE") or "SAFE" in txt:
                final_text = advice_text
                status = "Pass"
            else:
                final_text = f"‚ö†Ô∏è SAFETY REVISION:\n{txt}"
                status = "Revised"
        except Exception:
            if "NCT" in advice_text or "clinical trial" in advice_text.lower():
                final_text = advice_text
                status = "Pass (API Fallback)"
            else:
                final_text = "‚ö†Ô∏è Safety filter triggered. Please consult a doctor."
                status = "Revised (API Error)"

        log = log_provenance_step(
            "ActiveSafetyFilter",
            {"advice": advice_text},
            {"final_text": final_text, "status": status},
        )
        return final_text, status, log


# ============================================================
# HEALTHCAREBOT - Updated to use Qdrant
# ============================================================

class HealthcareBot:
    def __init__(self, qdrant_client, embed_model, gemini_model, initial_profile=None):
        self.parser = SymptomParser(gemini_model)
        self.profile_agent = ProfileAgent(initial_profile)
        self.evidence_scorer = EvidenceWeightedScorer()

        # NEW: Use Qdrant retrieval agent
        self.retrieval = RetrievalAgentQdrant(
            qdrant_client=qdrant_client,
            embed_model=embed_model,
            evidence_scorer=self.evidence_scorer,
            profile_agent=self.profile_agent
        )

        self.advisor = DiagnosisAdvisor(gemini_model)
        self.safety_filter = ActiveSafetyFilter(gemini_model)
        self.conversation_history = []
        self.provenance_log = []


    def chat(self, user_input: str) -> Dict[str, Any]:
        """Process user input through the pipeline."""

        # Parse intent
        parsed, parse_log = self.parser.parse(user_input)
        self.provenance_log.append(parse_log)

        # Update profile
        turn_data = {"query": user_input, "parsed": parsed}
        profile_log = self.profile_agent.update_profile(turn_data)
        self.provenance_log.append(profile_log)

        # Retrieve trials (now from Qdrant!)
        retrieved, retrieval_log = self.retrieval.retrieve(parsed, top_k=5)
        self.provenance_log.append(retrieval_log)

        # Generate response
        profile_snapshot = {
            "user_id": self.profile_agent.profile.get("user_id", "Patient"),
            "known_conditions": self.profile_agent.profile.get("extracted_conditions", []),
        }


        draft, advisor_log = self.advisor.advise(parsed, retrieved, profile_snapshot)
        self.provenance_log.append(advisor_log)

        # Safety filter - FIXED: verify() returns (text, status, log)
        advice_text = draft.get("recommendation", "") if isinstance(draft, dict) else str(draft)
        trials = retrieved.get("trials", [])

        final_response, safety_status, safety_log = self.safety_filter.verify(advice_text, trials)
        self.provenance_log.append(safety_log)

        # Save turn
        full_turn = {
            "query": user_input,
            "parsed": parsed,
            "retrieved": retrieved,
            "response": final_response,
            "timestamp": parse_log["timestamp"],
        }
        self.conversation_history.append(full_turn)

        return {
            "response": final_response,
            "avg_confidence": retrieved.get("avg_confidence", 0.0),
            "num_trials": len(retrieved.get("trials", [])),
            "provenance": self.provenance_log[-5:],
            "session_hash": generate_reproducibility_hash(self.conversation_history),
        }



def run_bot(user_input: str, qdrant_client, embed_model, gemini_model) -> Dict[str, Any]:
    """Convenience wrapper for single queries."""
    bot = HealthcareBot(qdrant_client, embed_model, gemini_model)
    return bot.chat(user_input)

Overwriting run_bot_qdrant.py


In [14]:
# # Quick check - what does verify expect?
# import inspect
# from run_bot_qdrant import ActiveSafetyFilter

# print(inspect.signature(ActiveSafetyFilter.verify))


(self, advice_text: str, trials: List[Dict[str, Any]])


In [23]:
# !sed -n '805,812p' /content/run_bot_qdrant.py


        # Safety filter - FIXED: verify() returns (text, status, log)
        advice_text = draft.get("recommendation", "") if isinstance(draft, dict) else str(draft)
        trials = retrieved.get("trials", [])

        final_response, safety_status, safety_log = self.safety_filter.verify(advice_text, trials)
        self.provenance_log.append(safety_log)

        # Save turn


In [24]:
# # Force reload the module
# import sys
# if 'run_bot_qdrant' in sys.modules:
#     del sys.modules['run_bot_qdrant']

# # Now import fresh
# from run_bot_qdrant import run_bot

# # Test again
# result = run_bot(
#     "What trials are studying insulin therapy for diabetes?",
#     qdrant_client,
#     embed_model,
#     gemini_model
# )


‚è≥ Loading Cross-Encoder reranker...
‚úÖ Reranker loaded


In [25]:
# import inspect
# from run_bot_qdrant import ActiveSafetyFilter

# # Show the FULL source code of verify
# print(inspect.getsource(ActiveSafetyFilter.verify))


    def verify(self, advice_text: str, trials: List[Dict[str, Any]]):
        # Skip safety check for list-type responses about trials
        if any(marker in advice_text for marker in ["NCT", "clinical trial", "clinicaltrials.gov"]):
            log = log_provenance_step(
                "ActiveSafetyFilter",
                {"advice": advice_text},
                {"final_text": advice_text, "status": "Pass (Trial Listing)"},
            )
            return advice_text, "Pass (Trial Listing)", log

        evidence_text = "\n".join(t["text"][:500] for t in trials[:3])

        audit_prompt = (
            "You are a Medical Safety Officer reviewing AI-generated advice.\n\n"
            "ADVICE:\n"
            f"{advice_text}\n\n"
            "EVIDENCE FROM CLINICAL TRIALS (for context):\n"
            f"{evidence_text}\n\n"
            "Check for safety issues:\n"
            "- If the advice suggests starting/stopping/changing medication without a doctor ‚Üí UNSAFE.\n"
           

In [26]:
# # Kill all cached modules
# import sys
# for key in list(sys.modules.keys()):
#     if any(x in key for x in ['run_bot', 'utils_qdrant', 'retrieval_agent']):
#         del sys.modules[key]

# print("‚úÖ Modules cleared")


‚úÖ Modules cleared


In [7]:
import os
import getpass
import google.generativeai as genai

# Setup keys
gemini_key = getpass.getpass("üîë Gemini API Key: ")
os.environ["GEMINI_API_KEY"] = gemini_key

qdrant_api_key = getpass.getpass("üîë Qdrant API Key: ")

# Setup clients
genai.configure(api_key=gemini_key)
gemini_model = genai.GenerativeModel("models/gemini-2.0-flash")

from utils_qdrant import load_qdrant_and_model
qdrant_url = "https://215ec69e-fa22-4f38-bcf3-941e73901a68.us-east4-0.gcp.cloud.qdrant.io"
qdrant_client, embed_model = load_qdrant_and_model(qdrant_url, qdrant_api_key)

# Import and test
from run_bot_qdrant import run_bot

print("\nü§ñ Testing bot...\n")
result = run_bot(
    "What trials are studying insulin therapy for diabetes?",
    qdrant_client,
    embed_model,
    gemini_model
)

print(result["response"])
print(f"\nüìä Found {result['num_trials']} trials")


üîë Gemini API Key: ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑
üîë Qdrant API Key: ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑
‚è≥ Connecting to Qdrant...
‚úÖ Connected to Qdrant: 262,660 vectors ready
‚úÖ Embedding model loaded

ü§ñ Testing bot...

‚è≥ Loading Cross-Encoder reranker...
‚úÖ Reranker loaded
I found 5 clinical trials relevant to your request:

**NCT00151697** (Relevance: 91%)
‚Ä¢ LANN-study: Lantus, Amaryl, Novorapid, Novomix Study
  Status: Completed

  This clinical trial aims to determine if a new combination of glimepiride and short-acting insulin can better control blood sugar and weight gain compared to standard insulin injections in people with diabetes whose current oral medications are not enough. The study will follow 150 participants with poorly controlled diabetes on oral medications for one year, comparing their glucose levels and weight changes under the new combination treatment, twice-daily mixed insulin injections, or once-daily basal insulin injections.



**NCT00922649** (Relevance: 88%)
‚Ä¢ P

Update Streamlit App

In [8]:
%%writefile app.py
"""
Streamlit UI for HealthcareBot with Qdrant backend
"""

import streamlit as st
import os
from typing import Dict, Any
import google.generativeai as genai

# Import Qdrant utilities and bot
from utils_qdrant import load_qdrant_and_model
from run_bot_qdrant import HealthcareBot

# Page config
st.set_page_config(
    page_title="Clinical Trials Search Assistant",
    page_icon="üè•",
    layout="wide"
)

# Title
st.title("üè• Clinical Trials Search Assistant")
st.markdown("**Powered by Qdrant + Gemini 2.0 Flash**")
st.markdown("Search across 262,000+ clinical trials for diabetes, cancer, Alzheimer's, asthma, and cardiovascular disease.")

# Sidebar for API keys
with st.sidebar:
    st.header("‚öôÔ∏è Configuration")

    # Gemini API Key
    gemini_key = st.text_input(
        "Gemini API Key",
        type="password",
        help="Enter your Gemini API key"
    )

    # Qdrant API Key
    qdrant_key = st.text_input(
        "Qdrant API Key",
        type="password",
        help="Enter your Qdrant API key"
    )

    # Qdrant URL (pre-filled)
    qdrant_url = st.text_input(
        "Qdrant Cluster URL",
        value="https://215ec69e-fa22-4f38-bcf3-941e73901a68.us-east4-0.gcp.cloud.qdrant.io",
        help="Your Qdrant cluster URL"
    )

    st.divider()

    # Info
    st.markdown("### üìä System Status")
    if gemini_key and qdrant_key:
        st.success("‚úÖ Keys configured")
    else:
        st.warning("‚ö†Ô∏è Enter API keys to start")

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []

if "bot" not in st.session_state:
    st.session_state.bot = None

if "qdrant_client" not in st.session_state:
    st.session_state.qdrant_client = None

if "embed_model" not in st.session_state:
    st.session_state.embed_model = None

# Initialize bot when keys are provided
if gemini_key and qdrant_key and st.session_state.bot is None:
    with st.spinner("üîÑ Initializing system..."):
        try:
            # Setup Gemini
            os.environ["GEMINI_API_KEY"] = gemini_key
            genai.configure(api_key=gemini_key)
            gemini_model = genai.GenerativeModel("models/gemini-2.0-flash")

            # Setup Qdrant
            qdrant_client, embed_model = load_qdrant_and_model(qdrant_url, qdrant_key)

            # Store in session
            st.session_state.qdrant_client = qdrant_client
            st.session_state.embed_model = embed_model

            # Initialize bot
            st.session_state.bot = HealthcareBot(
                qdrant_client=qdrant_client,
                embed_model=embed_model,
                gemini_model=gemini_model
            )

            st.success("‚úÖ System ready!")

        except Exception as e:
            st.error(f"‚ùå Initialization failed: {str(e)}")

# Display chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

        # Show metadata for assistant messages
        if message["role"] == "assistant" and "metadata" in message:
            with st.expander("üìä Details"):
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Trials Found", message["metadata"]["num_trials"])
                with col2:
                    st.metric("Confidence", f"{message['metadata']['avg_confidence']:.0%}")

# Chat input
if prompt := st.chat_input("Ask about clinical trials..."):

    # Check if bot is initialized
    if st.session_state.bot is None:
        st.error("‚ö†Ô∏è Please enter your API keys in the sidebar first!")
    else:
        # Add user message
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Get bot response
        with st.chat_message("assistant"):
            with st.spinner("üîç Searching clinical trials..."):
                try:
                    result = st.session_state.bot.chat(prompt)

                    response = result["response"]

                    # Display response
                    st.markdown(response)

                    # Show metadata
                    with st.expander("üìä Details"):
                        col1, col2, col3 = st.columns(3)
                        with col1:
                            st.metric("Trials Found", result["num_trials"])
                        with col2:
                            st.metric("Avg Confidence", f"{result['avg_confidence']:.0%}")
                        with col3:
                            st.metric("Session Hash", result["session_hash"][:8])

                    # Add to messages
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": response,
                        "metadata": {
                            "num_trials": result["num_trials"],
                            "avg_confidence": result["avg_confidence"]
                        }
                    })

                except Exception as e:
                    st.error(f"‚ùå Error: {str(e)}")
                    st.exception(e)

# Sidebar examples
with st.sidebar:
    st.divider()
    st.markdown("### üí° Example Queries")

    examples = [
        "What trials study insulin therapy for diabetes?",
        "Show me cancer immunotherapy trials",
        "Are there trials for Alzheimer's disease?",
        "What trials are recruiting for asthma?",
        "Find cardiovascular disease trials"
    ]

    for example in examples:
        if st.button(example, key=example):
            st.session_state.messages.append({"role": "user", "content": example})
            st.rerun()

# Footer
st.divider()
st.markdown("""
<div style='text-align: center; color: gray; font-size: 0.9em;'>
    üî¨ Powered by Qdrant Vector Database + Gemini 2.0 Flash<br>
    üìä Searching 262,660+ clinical trials across 5 disease areas
</div>
""", unsafe_allow_html=True)


Writing app.py


In [10]:
# Install Streamlit and pyngrok
!pip install -q streamlit pyngrok

print("‚úÖ Packages installed!")


[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m10.2/10.2 MB[0m [31m49.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m6.9/6.9 MB[0m [31m97.1 MB/s[0m eta [36m0:00:00[0m
[?25h‚úÖ Packages installed!


In [13]:
!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64
!mv cloudflared-linux-amd64 cloudflared
!chmod +x cloudflared

In [14]:
#AI LLM
!streamlit run app.py &>/dev/null&
!./cloudflared tunnel --url http://localhost:8501 --no-autoupdate

[90m2025-11-30T15:38:11Z[0m [32mINF[0m Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee, are subject to the Cloudflare Online Services Terms of Use (https://www.cloudflare.com/website-terms/), and Cloudflare reserves the right to investigate your use of Tunnels for violations of such terms. If you intend to use Tunnels in production you should use a pre-created named tunnel by following: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps
[90m2025-11-30T15:38:11Z[0m [32mINF[0m Requesting new quick Tunnel on trycloudflare.com...
[90m2025-11-30T15:38:16Z[0m [32mINF[0m +--------------------------------------------------------------------------------------------+
[90m2025-11-30T15:38:16Z[0m [32mINF[0m |  Your quick Tunnel has been created! Visit it at (it may take some time to be reachable):  |
[90m2025

Step 1: Import Libraries


In [None]:
!pip install -q requests pandas streamlit pyngrok faiss-cpu sentence-transformers numpy

import requests
import pandas as pd
import json
import hashlib
from datetime import datetime
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m10.2/10.2 MB[0m [31m48.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m23.6/23.6 MB[0m [31m55.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m6.9/6.9 MB[0m [31m76.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Secure KEY INPUT
import os
import getpass

# Securely Capture Key
# Input will be invisible. Paste key and press Enter.
key_input = getpass.getpass("üîë Enter Gemini API Key (Invisible Input): ")

if not key_input.startswith("AIza"):
    print("‚ö†Ô∏è Warning: Key might be invalid (usually starts with 'AIza').")
else:
    print("‚úÖ API Key captured securely in Environment Variable.")

# 2. Set as Environment Variable for the Session
os.environ["GEMINI_API_KEY"] = key_input

üîë Enter Gemini API Key (Invisible Input): ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑
‚úÖ API Key captured securely in Environment Variable.


In [None]:
%%writefile build_embeddings.py
import pandas as pd
import numpy as np
import faiss
import json
from sentence_transformers import SentenceTransformer

# === REAL PATH (from readlink) ===
BASE = "/content/drive/.shortcut-targets-by-id/1-SiVJhXHTHtDYSrPmW_0VfuP7gSTePcj/data"

# ---------------------------------------------
# Load Data
# ---------------------------------------------

df1 = pd.read_csv(f"{BASE}/clinical_trials_diabetes_full.csv")
df2 = pd.read_csv(f"{BASE}/clinical_trials_master_full.csv")
df3 = pd.read_csv(f"{BASE}/clinical_trials_alzheimer_full.csv")
df4 = pd.read_csv(f"{BASE}/clinical_trials_cancer_full.csv")
df5 = pd.read_csv(f"{BASE}/clinical_trials_asthma_full.csv")
df6 = pd.read_csv(f"{BASE}/clinical_trials_cardiovascular_full.csv")

df = pd.concat([df1, df2, df3, df4, df5], ignore_index=True)

df["status"] = df["status"].astype(str).str.strip().str.title()
bad_status = ["Terminated", "Withdrawn", "Suspended", "No Longer Available", "Unknown"]
df_clean = df[~df["status"].isin(bad_status)].copy()

# ---------------------------------------------
# Chunking
# ---------------------------------------------
chunks = []
chunk_map = []

for idx, row in df_clean.iterrows():
    title = str(row.get("brief_title", "")).strip()
    summary = str(row.get("brief_summary", "")).strip()

    if len(summary) < 20:
        continue

    text = f"Title: {title}\nSummary: {summary}"
    chunks.append(text)

    chunk_map.append({
        "nct_id": row["nct_id"],
        "title": title,
        "text": text,
        "status": row["status"]
    })

print(f"Created {len(chunks)} chunks.")

# ---------------------------------------------
# Embeddings
# ---------------------------------------------
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embed_model.encode(chunks, batch_size=64, show_progress_bar=True)

np.save(f"{BASE}/clinical_trials_all_full_embeddings.npy", embeddings)
print("Saved clinical_trials_all_full_embeddings.npy")

# ---------------------------------------------
# Save chunk map
# ---------------------------------------------
with open(f"{BASE}/clinical_trials_all_full_chunk_map.json", "w") as f:
    json.dump(chunk_map, f)

print("Saved clinical_trials_all_full_chunk_map.json")

# ---------------------------------------------
# Build & Save FAISS
# ---------------------------------------------
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(np.array(embeddings).astype("float32"))
faiss.write_index(index, f"{BASE}/clinical_trials_all_full_faiss.index")

print("Saved clinical_trials_all_full_faiss.index")
print("‚úÖ Embedding build COMPLETE.")


Writing build_embeddings.py


In [None]:
!python build_embeddings.py

2025-11-28 00:13:26.004412: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764288806.025774    2186 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764288806.031981    2186 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1764288806.047591    2186 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1764288806.047617    2186 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1764288806.047621    2186 computation_placer.cc:177] computation placer alr

In [None]:
%%writefile utils.py
import json
import hashlib
from datetime import datetime

import faiss
from sentence_transformers import SentenceTransformer

# --- Confidence score from distance ---

def calculate_confidence_score(distance: float, normalization_factor: float = 1.0) -> float:
    """Inverse L2 distance score in (0,1]; closer = higher confidence."""
    return normalization_factor / (normalization_factor + float(distance))


# --- Load pre-built index + chunk map ---

def load_data_and_index(chunk_map_path: str, faiss_path: str):
    """Loads pre-built chunks and FAISS index for quick startup."""
    print("‚è≥ Loading pre-built RAG index...")

    with open(chunk_map_path, "r") as f:
        chunk_map = json.load(f)

    index = faiss.read_index(faiss_path)

    embed_model = SentenceTransformer("all-MiniLM-L6-v2")

    print(f"‚úÖ RAG Index Ready: {index.ntotal} vectors loaded.")
    return embed_model, index, chunk_map


# --- Provenance logging ---

def log_provenance_step(agent_name: str, input_data, output_data, detail=None):
    """
    Creates a detailed log entry for a single agent step.
    """
    log_entry = {
        "timestamp": datetime.now().isoformat(),
        "agent": agent_name,
        "input": input_data,
        "output": output_data,
        "detail": detail or {},
        "model_version": "gemini-2.0-flash",
    }
    return log_entry


# --- Reproducibility hash ---

def generate_reproducibility_hash(conversation_history, corpus_version: str = "v1.0"):
    """
    Generates a deterministic session hash based on the conversation history.
    """
    queries = [turn.get("query", "") for turn in conversation_history]
    raw = f"{corpus_version}|{'|'.join(queries)}"
    return hashlib.md5(raw.encode("utf-8")).hexdigest()


Writing utils.py


In [None]:
import json
import re
import os
import sys
from typing import List, Dict, Any

import numpy as np
import requests
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

# --- Updated Import: Robust Cross-Encoder Initialization ---
CrossEncoder = None
try:
    from sentence_transformers import CrossEncoder
    print("‚úÖ sentence_transformers imported successfully.")
except ImportError:
    print("‚ö†Ô∏è sentence_transformers not found. Reranking will be disabled.")
except Exception as e:
    print(f"‚ö†Ô∏è Error importing CrossEncoder: {e}. Reranking disabled.")

from utils import (
    load_data_and_index,
    log_provenance_step,
    generate_reproducibility_hash,
    calculate_confidence_score,
)

# --- CONFIG (Gemini 2.0 Flash) ---
API_KEY = os.environ.get("GEMINI_API_KEY")

if not API_KEY:
    print("‚ùå ERROR: API Key not found. Please run the 'Secure Input' cell first.")
    sys.exit(1)

genai.configure(api_key=API_KEY)

gemini_model = genai.GenerativeModel("models/gemini-2.0-flash")

# --- ALL-DISEASE INDEX (diabetes + cancer + Alzheimer‚Äôs + asthma + cardiovascular) ---
BASE = "/content/drive/.shortcut-targets-by-id/1-SiVJhXHTHtDYSrPmW_0VfuP7gSTePcj/data"
CHUNK_PATH = f"{BASE}/clinical_trials_all_full_chunk_map.json"
FAISS_PATH = f"{BASE}/clinical_trials_all_full_faiss.index"

# Load embedding model, FAISS index, and chunk metadata
embed_model, faiss_index, chunk_map = load_data_and_index(CHUNK_PATH, FAISS_PATH)

# --- Reranker Initialization ---
reranker = None
if CrossEncoder:
    try:
        print("‚è≥ Loading Reranker Model (Cross-Encoder)...")
        reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        print("‚úÖ Reranker Loaded.")
    except Exception as e:
        print(f"‚ö†Ô∏è Reranker model download failed (using pure FAISS): {e}")


# ============================================================
# PARSER
# ============================================================
class SymptomParser:
    def __init__(self, model):
        self.model = model

    def parse(self, text: str):
        """
        Enhanced parser for clinical trial search queries.
        Decides:
        - Are they searching for trials or just asking a question?
        - Which disease (diabetes, cancer, Alzheimer‚Äôs, asthma, cardiovascular) is implied?
        """
        prompt = (
            "You are a clinical trial search classifier for medical research.\n"
            "You support conditions including diabetes, cancer, Alzheimer's disease, asthma, and cardiovascular disease.\n\n"
            f"User Input: \"{text}\"\n\n"
            "Your tasks:\n"
            "1) Decide if the user is searching for clinical trials or just asking a general question.\n"
            "2) Detect which disease(s) they are talking about.\n"
            "3) Detect if the query is not about health or clinical trials (off_topic).\n\n"
            "Classification Rules:\n"
            "- If the query mentions or implies trials, studies, research, clinical experiments, etc. ‚Üí intent='trial_search'\n"
            "- If the user is mainly describing themselves (age, diagnosis, comorbidities, meds) ‚Üí intent='profile_info'\n"
            "- If they ask 'what is X', 'how does Y work', etc. without asking about trials ‚Üí intent='general_question'\n"
            "- Simple greetings (hi, hello, hey) ‚Üí intent='greeting'\n"
            "- If clearly not about health or clinical research ‚Üí intent='off_topic', is_disease_related=false\n\n"
            "You must detect disease_focus whenever possible:\n"
            "- diabetes: diabetes, blood sugar, glucose, insulin, HbA1c, metformin, GLP-1, SGLT2\n"
            "- cancer: cancer, tumor/tumour, chemotherapy, oncology, breast cancer, lung cancer, leukemia, lymphoma\n"
            "- alzheimers: Alzheimer's, dementia, memory loss, cognitive decline\n"
            "- asthma: asthma, wheezing, bronchodilator, inhaler\n"
            "- cardiovascular: heart failure, cardiovascular disease, hypertension, high blood pressure, angina,\n"
            "  myocardial infarction, coronary artery disease, stroke\n\n"
            "Return ONLY valid JSON with this exact format:\n"
            "{\n"
            "  \"intent\": \"trial_search\" | \"profile_info\" | \"general_question\" | \"greeting\" | \"off_topic\",\n"
            "  \"query_type\": \"trial_query\" | \"profile_statement\" | \"knowledge_seeking\" | \"greeting\",\n"
            "  \"search_keywords\": [\"keyword1\", \"keyword2\"],\n"
            "  \"is_disease_related\": true or false,\n"
            "  \"disease_focus\": [\"diabetes\", \"cancer\", \"alzheimers\", \"asthma\", \"cardiovascular\"],\n"
            "  \"user_question\": \"the question in plain English\",\n"
            "  \"trial_interest\": \"what type of trial they want (diet, medication, technology, surgery, etc.)\"\n"
            "}\n\n"
            "Examples:\n"
            "- 'What trials study liraglutide in diabetes?' ‚Üí intent='trial_search', query_type='trial_query',\n"
            "  is_disease_related=true, disease_focus=['diabetes'], search_keywords=['liraglutide']\n"
            "- 'My mom has breast cancer, are there trials?' ‚Üí intent='trial_search', disease_focus=['cancer']\n"
            "- 'I am 70 with memory loss and Alzheimer's' ‚Üí intent='profile_info', disease_focus=['alzheimers']\n"
            "- 'What is HbA1c?' ‚Üí intent='general_question', disease_focus=['diabetes']\n"
            "- 'What is the weather in Paris?' ‚Üí intent='off_topic', is_disease_related=false, disease_focus=[]\n"
        )

        try:
            res = self.model.generate_content(prompt)
            raw = (res.text or "").strip()
            match = re.search(r"\{.*\}", raw, re.DOTALL)
            if match:
                parsed = json.loads(match.group(0))
            else:
                parsed = json.loads(raw)
        except Exception:
            # Fallback: simple heuristic if model fails
            text_lower = text.lower()
            disease_focus = []
            if any(x in text_lower for x in ["diabetes", "insulin", "glucose", "hba1c", "metformin", "glp-1", "sglt2"]):
                disease_focus.append("diabetes")
            if any(x in text_lower for x in ["cancer", "tumor", "tumour", "chemo", "chemotherapy", "oncology"]):
                disease_focus.append("cancer")
            if any(x in text_lower for x in ["alzheimer", "dementia", "memory loss", "cognitive decline"]):
                disease_focus.append("alzheimers")
            if "asthma" in text_lower or "wheezing" in text_lower:
                disease_focus.append("asthma")
            if any(x in text_lower for x in ["heart failure", "cardiovascular", "hypertension",
                                             "high blood pressure", "angina", "myocardial", "coronary", "stroke"]):
                disease_focus.append("cardiovascular")

            if any(kw in text_lower for kw in ["trial", "study", "research", "clinical"]):
                intent = "trial_search"
                query_type = "trial_query"
            elif any(kw in text_lower for kw in ["hi", "hello", "hey"]):
                intent = "greeting"
                query_type = "greeting"
            else:
                intent = "general_question"
                query_type = "knowledge_seeking"

            parsed = {
                "intent": intent,
                "query_type": query_type,
                "search_keywords": [text] if intent == "trial_search" else [],
                "is_disease_related": bool(disease_focus),
                "disease_focus": disease_focus,
                "user_question": text,
                "trial_interest": "general",
            }

        # --- Heuristic correction layer on top of model output ---
        text_lower = text.lower()
        diseases = set(parsed.get("disease_focus") or [])

        if any(x in text_lower for x in ["diabetes", "insulin", "glucose", "hba1c", "metformin", "glp-1", "sglt2"]):
            diseases.add("diabetes")
        if any(x in text_lower for x in ["cancer", "tumor", "tumour", "chemo", "chemotherapy", "oncology"]):
            diseases.add("cancer")
        if any(x in text_lower for x in ["alzheimer", "alzheimers", "dementia", "memory loss", "cognitive decline"]):
            diseases.add("alzheimers")
        if "asthma" in text_lower or "wheezing" in text_lower or "inhaler" in text_lower:
            diseases.add("asthma")
        if any(x in text_lower for x in ["heart failure", "cardiovascular", "hypertension",
                                         "high blood pressure", "angina", "myocardial", "coronary", "stroke"]):
            diseases.add("cardiovascular")

        parsed["disease_focus"] = list(diseases)

        # Force trial_search if obvious trial keywords
        trial_keywords = [
            "trial", "study", "studies", "research",
            "clinical", "show me", "are there", "what trials"
        ]
        if any(kw in text_lower for kw in trial_keywords):
            parsed["intent"] = "trial_search"
            parsed["query_type"] = "trial_query"

        # If we detected diseases, ensure is_disease_related = True
        if diseases and parsed.get("intent") != "off_topic":
            parsed["is_disease_related"] = True
        elif "is_disease_related" not in parsed:
            parsed["is_disease_related"] = bool(diseases)

        log = log_provenance_step("SymptomParser", text, parsed)
        return parsed, log


# ============================================================
# PROFILE AGENT
# ============================================================
class ProfileAgent:
    def __init__(self, initial_profile: Dict[str, Any] = None):
        if initial_profile is None:
            initial_profile = {
                "user_id": "Patient",
                "conditions": [],          # could be filled later
                "extracted_conditions": [],  # dynamic memory
                "history": [],
            }
        self.profile = initial_profile

    def update_profile(self, turn_data: Dict[str, Any]):
        """
        Updates history and extracts persistent medical entities.
        """
        self.profile.setdefault("history", []).append(turn_data)
        self.profile.setdefault("extracted_conditions", [])

        parsed = turn_data.get("parsed", {})
        # optional: keep disease_focus as conditions
        diseases = parsed.get("disease_focus") or []
        if diseases:
            current = set(self.profile["extracted_conditions"])
            for d in diseases:
                current.add(d)
            self.profile["extracted_conditions"] = list(current)

        snapshot = {
            "user_id": self.profile.get("user_id", "Patient"),
            "known_conditions": self.profile.get("extracted_conditions", []),
            "num_turns": len(self.profile["history"]),
        }
        log = log_provenance_step("ProfileAgent", turn_data, {"profile_snapshot": snapshot})
        return log


# ============================================================
# EVIDENCE-WEIGHTED SCORER
# ============================================================
class EvidenceWeightedScorer:
    """
    Implements evidence-weighted scoring for clinical trials.
    Ranks trials based on multiple quality factors beyond semantic similarity.
    """

    def __init__(self):
        self.status_weights = {
            "Completed": 1.0,
            "Active, Not Recruiting": 0.9,
            "Recruiting": 0.85,
            "Enrolling By Invitation": 0.8,
            "Not Yet Recruiting": 0.6,
            "Terminated": 0.4,
            "Withdrawn": 0.3,
            "Suspended": 0.3,
            "Unknown Status": 0.5,
        }

        self.design_keywords = {
            "randomized controlled": 1.0,
            "double-blind": 0.95,
            "randomized": 0.9,
            "controlled": 0.85,
            "interventional": 0.8,
            "prospective": 0.75,
            "observational": 0.6,
            "retrospective": 0.5,
        }

    def calculate_weighted_score(
        self,
        trial: Dict[str, Any],
        base_confidence: float,
        query: str,
    ) -> Dict[str, Any]:
        """
        Calculate evidence-weighted score for a trial.
        """

        # Factor 1: Base semantic match (35%)
        match_score = base_confidence * 0.35

        # Factor 2: Trial status quality (25%)
        status = str(trial.get("status", "Unknown Status")).strip().title()
        status_score = self.status_weights.get(status, 0.5) * 0.25

        # Factor 3: Study design quality (20%)
        design_score = self._calculate_design_quality(trial) * 0.20

        # Factor 4: Keyword density (10%)
        keyword_score = self._calculate_keyword_density(trial, query) * 0.10

        # Factor 5: Metadata completeness (10%)
        completeness_score = self._calculate_completeness(trial) * 0.10

        weighted_score = (
            match_score +
            status_score +
            design_score +
            keyword_score +
            completeness_score
        )

        breakdown = {
            "base_confidence": base_confidence,
            "weighted_score": weighted_score,
            "factors": {
                "semantic_match": match_score,
                "trial_status": status_score,
                "study_design": design_score,
                "keyword_density": keyword_score,
                "completeness": completeness_score,
            },
        }

        return {
            "weighted_score": min(weighted_score, 1.0),
            "breakdown": breakdown,
        }

    def _calculate_design_quality(self, trial: Dict[str, Any]) -> float:
        text = f"{trial.get('title', '')} {trial.get('text', '')}".lower()
        max_score = 0.0
        for keyword, weight in self.design_keywords.items():
            if keyword in text:
                max_score = max(max_score, weight)
        return max_score if max_score > 0 else 0.6

    def _calculate_keyword_density(self, trial: Dict[str, Any], query: str) -> float:
        text = f"{trial.get('title', '')} {trial.get('text', '')}".lower()
        stopwords = {
            "the", "a", "an", "and", "or", "for", "with", "in", "on", "at", "to",
            "of", "is", "are", "what", "trials", "trial", "study", "studies", "clinical"
        }
        query_terms = [
            term for term in query.lower().split()
            if term not in stopwords and len(term) > 2
        ]
        if not query_terms:
            return 0.5
        matches = sum(1 for term in query_terms if term in text)
        density = matches / len(query_terms)
        return min(density, 1.0)

    def _calculate_completeness(self, trial: Dict[str, Any]) -> float:
        # Our chunk_map has "title" and "text"; treat longer text as more complete
        text = trial.get("text", "") or ""
        title = trial.get("title", "") or ""
        score = 0.0
        if len(title) > 10:
            score += 0.3
        if len(text) > 200:
            score += 0.7
        return min(score, 1.0)


# ============================================================
# PubMed Helper (NCT ‚Üí PubMed abstract)
# ============================================================
def fetch_pubmed_abstract_for_nct(nct_id: str):
    """
    Try to find a PubMed article linked to this NCT ID and return its abstract.
    Returns: {"pmid": str, "abstract": str} or None
    """
    try:
        esearch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
        params = {
            "db": "pubmed",
            "term": f"{nct_id}[si]",
            "retmode": "json",
            "retmax": 1,
        }
        r = requests.get(esearch_url, params=params, timeout=10)
        r.raise_for_status()
        data = r.json()
        idlist = data.get("esearchresult", {}).get("idlist", [])
        if not idlist:
            return None

        pmid = idlist[0]

        efetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
        params = {
            "db": "pubmed",
            "id": pmid,
            "rettype": "abstract",
            "retmode": "text",
        }
        r2 = requests.get(efetch_url, params=params, timeout=10)
        r2.raise_for_status()
        abstract_text = r2.text.strip()
        if not abstract_text:
            return None

        return {"pmid": pmid, "abstract": abstract_text}
    except Exception:
        return None


# ============================================================
# RETRIEVAL AGENT
# ============================================================
class RetrievalAgent:
    def __init__(self, embed_model, faiss_index, chunk_map, profile_agent: ProfileAgent = None):
        self.embed_model = embed_model
        self.index = faiss_index
        self.chunk_map = chunk_map
        self.profile_agent = profile_agent
        self.evidence_scorer = EvidenceWeightedScorer()

    def retrieve(self, parsed: Dict[str, Any], top_k: int = 5):
        FETCH_K = top_k * 3

        symptoms = parsed.get("symptoms") or []
        context = parsed.get("context") or ""
        query = parsed.get("user_question") or (" ".join(symptoms) + " " + context).strip()

        if not query:
            retrieval = {"query": "", "trials": [], "avg_confidence": 0.0}
            log = log_provenance_step("RetrievalAgent", parsed, retrieval, {"reason": "empty_query"})
            return retrieval, log

        # Simple expansions (still useful across diseases)
        EXPANSIONS = {
            "insulin": "insulin OR insulin therapy OR insulin treatment OR insulin pump",
            "medication": "medication OR drug OR pharmaceutical OR pharmacological OR treatment",
            "diet": "diet OR dietary OR nutrition OR nutritional OR eating",
            "exercise": "exercise OR physical activity OR fitness OR activity",
            "chemo": "chemotherapy OR antineoplastic OR oncology",
            "cancer": "cancer OR tumor OR tumour OR malignancy OR oncology",
            "alzheim": "alzheimer OR dementia OR cognitive decline OR memory loss",
        }

        query_lower = query.lower()
        for term, expansion in EXPANSIONS.items():
            if term in query_lower:
                query = f"{query} {expansion}"
                break

        # 1. FAISS retrieval
        q_emb = self.embed_model.encode([query])
        distances, indices = self.index.search(q_emb.astype("float32"), FETCH_K)

        initial_candidates = []
        for rank, idx in enumerate(indices[0]):
            if idx == -1:
                continue
            item = self.chunk_map[idx]
            dist = float(distances[0][rank])
            initial_candidates.append({
                "nct_id": item["nct_id"],
                "title": item.get("title", ""),
                "text": item["text"],
                "status": item.get("status", "Unknown Status"),
                "faiss_dist": dist,
            })

        final_trials = []
        confs = []

        # 2. Optional CrossEncoder reranking
        if reranker and initial_candidates:
            pairs = [[query, cand["text"]] for cand in initial_candidates]
            scores = reranker.predict(pairs)

            for i, cand in enumerate(initial_candidates):
                cand["rerank_score"] = float(scores[i])

            initial_candidates.sort(key=lambda x: x["rerank_score"], reverse=True)
            top_hits = initial_candidates[:top_k]

            for rank, item in enumerate(top_hits):
                logit = item["rerank_score"]
                base_conf = 1 / (1 + np.exp(-logit))

                scoring_result = self.evidence_scorer.calculate_weighted_score(
                    trial=item,
                    base_confidence=base_conf,
                    query=query,
                )

                final_trials.append({
                    "nct_id": item["nct_id"],
                    "title": item["title"],
                    "text": item["text"],
                    "status": item["status"],
                    "confidence": base_conf,
                    "weighted_score": scoring_result["weighted_score"],
                    "score_breakdown": scoring_result["breakdown"],
                    "rank": rank + 1,
                    "method": "evidence_weighted",
                })

            final_trials.sort(key=lambda x: x["weighted_score"], reverse=True)
            for i, trial in enumerate(final_trials):
                trial["rank"] = i + 1

            confs = [t["weighted_score"] for t in final_trials]

        else:
            # FAISS-only path
            top_hits = initial_candidates[:top_k]
            for rank, item in enumerate(top_hits):
                base_conf = calculate_confidence_score(item["faiss_dist"])
                scoring_result = self.evidence_scorer.calculate_weighted_score(
                    trial=item,
                    base_confidence=base_conf,
                    query=query,
                )
                final_trials.append({
                    "nct_id": item["nct_id"],
                    "title": item["title"],
                    "text": item["text"],
                    "status": item["status"],
                    "confidence": base_conf,
                    "weighted_score": scoring_result["weighted_score"],
                    "score_breakdown": scoring_result["breakdown"],
                    "rank": rank + 1,
                    "method": "evidence_weighted_faiss",
                })

            final_trials.sort(key=lambda x: x["weighted_score"], reverse=True)
            for i, trial in enumerate(final_trials):
                trial["rank"] = i + 1

            confs = [t["weighted_score"] for t in final_trials]

        avg_conf = float(np.mean(confs)) if confs else 0.0

        retrieval = {
            "query": query,
            "trials": final_trials,
            "avg_confidence": avg_conf,
        }

        detail = {
            "top_k": top_k,
            "avg_confidence": avg_conf,
            "num_trials": len(final_trials),
            "method": "reranked" if reranker else "faiss_only",
        }

        log = log_provenance_step("RetrievalAgent", parsed, retrieval, detail)
        return retrieval, log


# ============================================================
# DIAGNOSIS / ADVISOR
# ============================================================
class DiagnosisAdvisor:
    def __init__(self, model):
        self.model = model

    def _handle_general_question(self, parsed: Dict[str, Any], retrieved: Dict[str, Any]):
        """Handle general medical knowledge questions."""
        trials = retrieved.get("trials", [])
        user_question = parsed.get("user_question") or " ".join(parsed.get("symptoms", []))

        evidence_parts = []
        for t in trials[:3]:
            evidence_parts.append(f"Trial {t['nct_id']}: {t['text'][:400]}")
        evidence = "\n\n".join(evidence_parts) if evidence_parts else "No specific trials available."

        prompt = (
            "You are a medical research educator. Answer the user's question clearly using reliable medical knowledge.\n"
            "The clinical trial evidence below provides real-world context - mention it if helpful.\n\n"
            f"USER'S QUESTION: {user_question}\n\n"
            "CLINICAL TRIAL CONTEXT (for reference only):\n"
            f"{evidence}\n\n"
            "Instructions:\n"
            "- Answer the question directly in 3‚Äì5 sentences.\n"
            "- Be specific and educational.\n"
            "- Do NOT give diagnoses or treatment instructions.\n"
            "- End with: 'For personalized advice, please consult your healthcare provider.'\n"
        )

        try:
            res = self.model.generate_content(prompt)
            text = (res.text or "").strip()
            if not text or len(text) < 50:
                text = (
                    "I don't have enough information to answer this question accurately. "
                    "For personalized guidance, please consult your healthcare provider."
                )
            return text
        except Exception:
            return (
                "I'm unable to generate a detailed answer right now. "
                "For personalized guidance, please consult your healthcare provider."
            )

    def _handle_symptom_query(
        self,
        parsed: Dict[str, Any],
        retrieved: Dict[str, Any],
        profile: Dict[str, Any],
    ):
        """
        Generate response for clinical trial search queries with
        readable paragraph summaries and PubMed abstracts when available.
        """
        trials = retrieved.get("trials", [])
        if not trials:
            return "No relevant trials were found. Please try refining your query."

        formatted_trials = []
        for t in trials[:5]:
            title = t.get("title", "") or t["text"].split("\n")[0].replace("Title: ", "")
            status = t.get("status", "Unknown")
            weighted_score = t.get("weighted_score", 0.0)

            # Extract the ClinicalTrials.gov summary text
            raw_text = t.get("text", "")
            brief_summary = raw_text.split("Summary:", 1)[-1].strip() if "Summary:" in raw_text else raw_text

            if brief_summary:
                # Ask Gemini to turn the CT.gov summary into a short paragraph
                prompt = (
                    "Rewrite the following clinical trial description as a short, clear paragraph "
                    "about what the study is testing:\n\n"
                    f"{brief_summary}\n\n"
                    "Guidelines:\n"
                    "- Use 2‚Äì4 sentences.\n"
                    "- Plain English, minimal jargon.\n"
                    "- Include the purpose and the main type of participant.\n"
                )
                try:
                    res = self.model.generate_content(prompt)
                    brief_summary = res.text.strip() if res.text else brief_summary
                except Exception:
                    if len(brief_summary) > 600:
                        brief_summary = brief_summary[:600] + "..."
            else:
                brief_summary = "No summary available."

            # PubMed abstract lookup
            pubmed_block = ""
            pub = fetch_pubmed_abstract_for_nct(t["nct_id"])
            if pub:
                abs_text = pub["abstract"]
                max_len = 2000
                if len(abs_text) > max_len:
                    abs_text = abs_text[:max_len] + "..."
                pubmed_block = (
                    f"  PubMed abstract (PMID {pub['pmid']}):\n"
                    f"  {abs_text}\n\n"
                    f"  PubMed link: https://pubmed.ncbi.nlm.nih.gov/{pub['pmid']}/\n\n"
                )

            formatted_trials.append(
                f"**{t['nct_id']}** (Relevance: {weighted_score:.0%})\n"
                f"‚Ä¢ {title}\n"
                f"  Status: {status}\n\n"
                f"  {brief_summary}\n\n"
                f"{pubmed_block}"
            )

        trials_text = "\n\n".join(formatted_trials)
        num_trials = len(formatted_trials)

        response = (
            f"I found {num_trials} clinical trial{'s' if num_trials != 1 else ''} relevant to your request:\n\n"
            f"{trials_text}\n\n"
            "Summary: These trials explore potential treatments or management strategies for the condition you asked about. "
            "More details are available using the listed NCT IDs.\n\n"
            "To learn more or consider participation, visit clinicaltrials.gov and search by NCT ID. "
            "Always discuss clinical trial options with your healthcare provider."
        )

        return response

    def advise(self, parsed: Dict[str, Any], retrieved: Dict[str, Any], profile: Dict[str, Any]):
        trials = retrieved.get("trials", [])
        avg_conf = retrieved.get("avg_confidence", 0.0)
        query_type = parsed.get("query_type", "trial_query")
        is_disease_related = parsed.get("is_disease_related", True)

        draft = {
            "recommendation": "",
            "avg_confidence": avg_conf,
            "query_type": query_type,
        }

        if not is_disease_related:
            draft["recommendation"] = (
                "I‚Äôm specialized in clinical trials for medical conditions (for example diabetes, cancer, "
                "Alzheimer‚Äôs disease, asthma, and cardiovascular diseases). "
                "Your question does not appear to be about a health condition or clinical research. "
                "If you‚Äôd like, you can ask me about trials for a specific condition."
            )
            draft["confidence_veto"] = True
            log = log_provenance_step(
                "DiagnosisAdvisor",
                parsed,
                draft,
                {"veto": True, "reason": "off_topic"},
            )
            return draft, log

        if not trials or avg_conf < 0.05:
            draft["recommendation"] = (
                "Based on the trials I retrieved, I don‚Äôt have strong enough evidence to answer this question directly. "
                "Please consult your healthcare provider for personalized advice."
            )
            draft["confidence_veto"] = True
            log = log_provenance_step(
                "DiagnosisAdvisor",
                parsed,
                draft,
                {"veto": True, "reason": "low_confidence"},
            )
            return draft, log

        if query_type == "knowledge_seeking":
            draft["recommendation"] = self._handle_general_question(parsed, retrieved)
        else:
            draft["recommendation"] = self._handle_symptom_query(parsed, retrieved, profile)

        draft["confidence_veto"] = False
        log = log_provenance_step("DiagnosisAdvisor", parsed, draft)
        return draft, log


# ============================================================
# SAFETY FILTER
# ============================================================
class ActiveSafetyFilter:
    def __init__(self, model):
        self.model = model
        self.safety_cfg = {
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

    def verify(self, advice_text: str, trials: List[Dict[str, Any]]):
        # Skip safety check for list-type responses about trials
        if any(marker in advice_text for marker in ["NCT", "clinical trial", "clinicaltrials.gov"]):
            log = log_provenance_step(
                "ActiveSafetyFilter",
                {"advice": advice_text},
                {"final_text": advice_text, "status": "Pass (Trial Listing)"},
            )
            return advice_text, "Pass (Trial Listing)", log

        evidence_text = "\n".join(t["text"][:500] for t in trials[:3])

        audit_prompt = (
            "You are a Medical Safety Officer reviewing AI-generated advice.\n\n"
            "ADVICE:\n"
            f"{advice_text}\n\n"
            "EVIDENCE FROM CLINICAL TRIALS (for context):\n"
            f"{evidence_text}\n\n"
            "Check for safety issues:\n"
            "- If the advice suggests starting/stopping/changing medication without a doctor ‚Üí UNSAFE.\n"
            "- If it gives a diagnosis ‚Üí UNSAFE.\n"
            "- If it makes strong clinical claims not supported by evidence ‚Üí UNSAFE.\n"
            "- If it simply lists clinical trials with neutral wording and a recommendation to talk to a doctor ‚Üí SAFE.\n\n"
            "If the advice is acceptable, respond with exactly: SAFE\n"
            "If it is not acceptable, respond starting with: CORRECTED: <safer version>\n"
        )

        try:
            res = self.model.generate_content(audit_prompt, safety_settings=self.safety_cfg)
            txt = (res.text or "").strip()
            if txt.startswith("SAFE") or "SAFE" in txt:
                final_text = advice_text
                status = "Pass"
            else:
                final_text = f"‚ö†Ô∏è SAFETY REVISION:\n{txt}"
                status = "Revised"
        except Exception:
            if "NCT" in advice_text or "clinical trial" in advice_text.lower():
                final_text = advice_text
                status = "Pass (API Fallback)"
            else:
                final_text = "‚ö†Ô∏è Safety filter triggered. Please consult a doctor."
                status = "Revised (API Error)"

        log = log_provenance_step(
            "ActiveSafetyFilter",
            {"advice": advice_text},
            {"final_text": final_text, "status": status},
        )
        return final_text, status, log


# ============================================================
# HEALTHCARE BOT (Orchestrator)
# ============================================================
class HealthcareBot:
    def __init__(self, gemini_model, embed_model, faiss_index, chunk_map, initial_profile=None):
        self.parser = SymptomParser(gemini_model)
        self.profile_agent = ProfileAgent(initial_profile)
        self.retriever = RetrievalAgent(embed_model, faiss_index, chunk_map, self.profile_agent)
        self.advisor = DiagnosisAdvisor(gemini_model)
        self.safety = ActiveSafetyFilter(gemini_model)

        self.history: List[Dict[str, Any]] = []
        self.provenance_chain: List[Dict[str, Any]] = []

    def _handle_simple_greeting(self, user_input: str):
        user_id = self.profile_agent.profile.get("user_id", "there")
        msg = (
            f"Hello {user_id}! I'm your **Clinical Trial Research Assistant**. üî¨\n\n"
            "I can help you explore clinical trials for conditions such as:\n"
            "- Diabetes\n"
            "- Cancer\n"
            "- Alzheimer's disease\n"
            "- Asthma\n"
            "- Cardiovascular disease\n\n"
            "I search a database of tens of thousands of real trials (e.g., from ClinicalTrials.gov).\n\n"
            "**You can ask things like:**\n"
            "- 'What trials are studying insulin therapy in diabetes?'\n"
            "- 'Are there breast cancer trials targeting HER2?'\n"
            "- 'Trials for memory loss and Alzheimer's?'\n"
            "- 'I'm 55 with type 2 diabetes, what trials can I join?'\n\n"
            "What condition and question would you like to explore?"
        )

        log = log_provenance_step("GreetingAgent", user_input, msg, {"type": "greeting"})
        self.provenance_chain.append(log)

        session_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])
        self.history.append({"query": user_input, "response_hash": session_hash})

        return {
            "recommendation": msg,
            "cited_trials": [],
            "safety_status": "Non-RAG",
            "session_hash": session_hash,
            "provenance_chain": self.provenance_chain,
        }

    def _handle_off_topic(self, user_input: str, parsed: Dict[str, Any]):
        msg = (
            "I‚Äôm specialized in clinical trials for medical conditions (for example diabetes, cancer, "
            "Alzheimer‚Äôs disease, asthma, and cardiovascular disease). "
            "Your question doesn‚Äôt seem to be about a health condition or clinical research. "
            "If you‚Äôd like, you can ask me to find trials for a specific condition."
        )
        log = log_provenance_step("OffTopicHandler", user_input, msg, {"type": "off_topic"})
        self.provenance_chain.append(log)
        session_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])

        return {
            "recommendation": msg,
            "cited_trials": [],
            "safety_status": "Off-topic",
            "session_hash": session_hash,
            "provenance_chain": self.provenance_chain,
        }

    def _handle_knowledge_question(self, user_input: str, parsed: Dict[str, Any]):
        user_question = parsed.get("user_question", user_input)
        prompt = (
            "You are a medical research educator. Answer this question clearly and accurately in 3‚Äì6 sentences.\n"
            "Do NOT give diagnoses or individual treatment instructions.\n"
            f"QUESTION: {user_question}\n"
            "End with: 'For personalized advice, please consult your healthcare provider.'\n"
        )
        try:
            res = self.advisor.model.generate_content(prompt)
            answer = (res.text or "").strip()
        except Exception:
            answer = "I'm unable to answer this right now. For personalized advice, please consult your healthcare provider."

        log = log_provenance_step("KnowledgeAgent", user_input, answer, {"type": "general_knowledge"})
        self.provenance_chain.append(log)
        session_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])

        return {
            "recommendation": answer,
            "cited_trials": [],
            "safety_status": "Knowledge-Based",
            "session_hash": session_hash,
            "provenance_chain": self.provenance_chain,
        }

    def _handle_generic_trial_query(self, user_input: str, parsed: Dict[str, Any]):
        """Handle very generic queries that need more specificity."""
        msg = (
            "That question is a bit broad. I have a large database of clinical trials across conditions like "
            "diabetes, cancer, Alzheimer's, asthma, and cardiovascular disease.\n\n"
            "To help you better, you can specify:\n\n"
            "**Example trial searches by condition**\n"
            "- Diabetes: 'trials testing new insulin pumps', 'GLP-1 diabetes trials'\n"
            "- Cancer: 'HER2-positive breast cancer trials', 'lung cancer immunotherapy trials'\n"
            "- Alzheimer‚Äôs: 'trials for early Alzheimer‚Äôs disease', 'memory loss drug trials'\n"
            "- Asthma: 'pediatric asthma trials', 'new inhaler trials'\n"
            "- Cardiovascular: 'heart failure trials', 'hypertension drug trials'\n\n"
            "**Or describe your situation:**\n"
            "- 'I have type 2 diabetes and obesity, what trials might fit me?'\n"
            "- 'My father has metastatic lung cancer, any trials?'\n\n"
            "What condition and type of trial are you most interested in?"
        )

        log = log_provenance_step("GenericQueryHandler", user_input, msg, {"type": "needs_refinement"})
        self.provenance_chain.append(log)

        session_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])
        self.history.append({"query": user_input, "response_hash": session_hash})

        return {
            "recommendation": msg,
            "cited_trials": [],
            "safety_status": "Refinement Needed",
            "session_hash": session_hash,
            "provenance_chain": self.provenance_chain,
        }

    def _handle_missing_disease(self, user_input: str, parsed: Dict[str, Any]):
        """
        Fallback A (your choice): If we can't detect any disease,
        ask the user to specify the condition explicitly.
        """
        msg = (
            "I can search clinical trials for conditions such as diabetes, cancer, Alzheimer's disease, "
            "asthma, and cardiovascular disease.\n\n"
            "I couldn‚Äôt clearly tell which condition you meant from your last message.\n\n"
            "Please tell me which condition you‚Äôre interested in and, if you‚Äôd like, what type of trial.\n"
            "For example:\n"
            "- 'Diabetes ‚Äì trials for new insulin therapies'\n"
            "- 'Breast cancer ‚Äì HER2 targeted trials'\n"
            "- 'Alzheimer‚Äôs ‚Äì early-stage drug trials'\n"
            "- 'Asthma ‚Äì trials for severe asthma in adults'\n"
        )
        log = log_provenance_step("MissingDiseaseHandler", user_input, msg, {"type": "missing_disease"})
        self.provenance_chain.append(log)

        session_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])
        self.history.append({"query": user_input, "response_hash": session_hash})

        return {
            "recommendation": msg,
            "cited_trials": [],
            "safety_status": "Clarification Needed",
            "session_hash": session_hash,
            "provenance_chain": self.provenance_chain,
        }

    def process_query(self, user_input: str):
        self.provenance_chain = []

        # 1. Parse
        parsed, parse_log = self.parser.parse(user_input)
        self.provenance_chain.append(parse_log)

        intent = (parsed.get("intent") or "trial_search").lower()
        query_type = parsed.get("query_type", "trial_query")
        is_disease_related = parsed.get("is_disease_related", True)
        disease_focus = parsed.get("disease_focus") or []

        # Greetings
        if intent == "greeting":
            return self._handle_simple_greeting(user_input)

        # Off-topic
        if intent == "off_topic" or not is_disease_related:
            return self._handle_off_topic(user_input, parsed)

        # Profile info
        if intent == "profile_info":
            msg = (
                "Thank you for sharing your information. I've noted your details conceptually. "
                "What type of clinical trials would you like to explore? "
                "For example: 'diabetes trials for new medications' or 'breast cancer trials for HER2-positive disease'."
            )
            log = log_provenance_step("ProfileAgent", user_input, msg, {"action": "profile_stored"})
            self.provenance_chain.append(log)

            session_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])
            return {
                "recommendation": msg,
                "cited_trials": [],
                "safety_status": "Profile Update",
                "session_hash": session_hash,
                "provenance_chain": self.provenance_chain,
            }

        # Pure education (no trial search)
        if intent == "general_question" and query_type == "knowledge_seeking":
            if "trial" not in user_input.lower() and "study" not in user_input.lower():
                return self._handle_knowledge_question(user_input, parsed)

        # Fallback A: no disease detected ‚Üí ask user to specify condition
        if not disease_focus:
            return self._handle_missing_disease(user_input, parsed)

        # DEFAULT: trial search
        retrieved, retrieve_log = self.retriever.retrieve(parsed)
        self.provenance_chain.append(retrieve_log)

        trials = retrieved.get("trials", [])
        avg_conf = retrieved.get("avg_confidence", 0.0)
        top_score = trials[0]["weighted_score"] if trials else 0.0

        # Generic query detection (very broad wording)
        generic_terms = ["new", "any", "some", "recent", "latest", "medications", "drugs", "treatments", "trials", "studies"]
        is_generic = sum(1 for term in generic_terms if term in user_input.lower()) >= 2

        if is_generic and (avg_conf < 0.50 or top_score < 0.55):
            return self._handle_generic_trial_query(user_input, parsed)

        # 3. Advisor
        draft_advice, advise_log = self.advisor.advise(parsed, retrieved, self.profile_agent.profile)
        self.provenance_chain.append(advise_log)

        trials = retrieved.get("trials", [])
        if draft_advice.get("confidence_veto", False) or not trials:
            final_text = draft_advice["recommendation"]
            safety_status = "Vetoed (Low Confidence)"
            evidence_list = []
        else:
            final_text, safety_status, safety_log = self.safety.verify(draft_advice["recommendation"], trials)
            self.provenance_chain.append(safety_log)
            evidence_list = trials

        nct_ids = [t["nct_id"] for t in evidence_list]
        session_hash = generate_reproducibility_hash(self.history + [{"query": user_input}])

        # Update profile/history
        turn_data = {
            "query": user_input,
            "parsed": parsed,
            "nct_ids": nct_ids,
            "safety_status": safety_status,
            "session_hash": session_hash,
        }
        profile_log = self.profile_agent.update_profile(turn_data)
        self.provenance_chain.append(profile_log)
        self.history.append({"query": user_input, "response_hash": session_hash})

        return {
            "recommendation": final_text,
            "cited_trials": nct_ids,
            "safety_status": safety_status,
            "session_hash": session_hash,
            "provenance_chain": self.provenance_chain,
        }


# ============================================================
# GLOBAL BOT INSTANCE + ENTRYPOINT
# ============================================================
default_profile = {
    "user_id": "Patient",
    "conditions": [],
    "extracted_conditions": [],
}

_bot = HealthcareBot(gemini_model, embed_model, faiss_index, chunk_map, initial_profile=default_profile)

def run_bot(user_input: str) -> Dict[str, Any]:
    return _bot.process_query(user_input)

‚úÖ sentence_transformers imported successfully.
‚è≥ Loading pre-built RAG index...
‚úÖ RAG Index Ready: 262660 vectors loaded.
‚è≥ Loading Reranker Model (Cross-Encoder)...
‚úÖ Reranker Loaded.


UI frontend application simple web interface

https://docs.streamlit.io/develop/tutorials/chat-and-llm-apps/build-conversational-apps

In [None]:
%%writefile app.py
import streamlit as st
import os
import importlib
import run_bot  # import module not function

# Force reload of run_bot.py so Streamlit uses updated code
importlib.reload(run_bot)

st.title("Clinical Trial Health Advisor ü§ñ")
st.caption("AI for Healthcare - Clinical Trials RAG")

# API Key validation
if "GEMINI_API_KEY" not in os.environ:
    st.error("‚ö†Ô∏è API Key missing! Please run the 'Secure Input' cell in the notebook first.")
    st.stop()

if "messages" not in st.session_state:
    st.session_state.messages = []

# Load previous chat history
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# Chat input
if user_input := st.chat_input("Describe your symptoms or ask about clinical trials..."):
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.spinner("Searching clinical trials..."):
        # Call updated run_bot function
        result = run_bot.run_bot(user_input)
        reply = result.get("recommendation", "No response available.")

    with st.chat_message("assistant"):
        st.markdown(reply)

    st.session_state.messages.append({"role": "assistant", "content": reply})


Overwriting app.py


In [None]:
!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64
!mv cloudflared-linux-amd64 cloudflared
!chmod +x cloudflared

In [None]:
#AI LLM
!streamlit run app.py &>/dev/null&
!./cloudflared tunnel --url http://localhost:8501 --no-autoupdate

[90m2025-11-28T00:40:08Z[0m [32mINF[0m Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee, are subject to the Cloudflare Online Services Terms of Use (https://www.cloudflare.com/website-terms/), and Cloudflare reserves the right to investigate your use of Tunnels for violations of such terms. If you intend to use Tunnels in production you should use a pre-created named tunnel by following: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps
[90m2025-11-28T00:40:08Z[0m [32mINF[0m Requesting new quick Tunnel on trycloudflare.com...
[90m2025-11-28T00:40:11Z[0m [32mINF[0m +--------------------------------------------------------------------------------------------+
[90m2025-11-28T00:40:11Z[0m [32mINF[0m |  Your quick Tunnel has been created! Visit it at (it may take some time to be reachable):  |
[90m2025

In [None]:
response = run_bot("What trials are studying insulin therapy?")
print(response["recommendation"])



I found 5 clinical trials relevant to your request:

**NCT04981808** (Relevance: 95%)
‚Ä¢ Diabetes teleMonitoring of Patients in Insulin Therapy
  Status: Completed

  The trial is an open-label randomized controlled trial. Patients with T2D on insulin therapy will be randomized to a telemonitoring group (intervention) and a usual care group (control). The telemonitoring group will use various devices at home. Hospital staff will monitor their data for a period of three months.

  PubMed abstract (PMID 36476605):
  1. Trials. 2022 Dec 7;23(1):985. doi: 10.1186/s13063-022-06921-6.

The Diabetes teleMonitoring of patients in insulin Therapy (DiaMonT) trial: 
study protocol for a randomized controlled trial.

Hangaard S(1)(2), Kronborg T(3)(4), Hejlesen O(4), Arad√≥ttir TB(5), Kaas A(5), 
Bengtsson H(5), Vestergaard P(4)(6)(7), Jensen MH(3)(4).

Author information:
(1)Steno Diabetes Center North Denmark, M√∏lleparkvej 4, 9000, Aalborg, Denmark. 
svh@hst.aau.dk.
(2)Department of Health Sci

In [None]:
response = run_bot("What trials are studying asthma prevention?")
print(response["recommendation"])



I found 5 clinical trials relevant to your request:

**NCT00214526** (Relevance: 93%)
‚Ä¢ Asthma Intervention Research (AIR) Trial
  Status: Completed

  The purpose of this study is to demonstrate the effectiveness and safety of the Alair System for the treatment of asthma.

This will be a multicenter, randomized controlled study comparing the effects of treatment with the Alair System to standard drug therapy. One-hundred and ten subjects will be randomized 1:1 to either the Alair Group (Medical management + Alair treatment),or Control Group (Medical management only).

  PubMed abstract (PMID 17392302):
  1. N Engl J Med. 2007 Mar 29;356(13):1327-37. doi: 10.1056/NEJMoa064707.

Asthma control during the year after bronchial thermoplasty.

Cox G(1), Thomson NC, Rubin AS, Niven RM, Corris PA, Siersted HC, Olivenstein R, 
Pavord ID, McCormack D, Chaudhuri R, Miller JD, Laviolette M; AIR Trial Study 
Group.

Author information:
(1)St. Joseph's Healthcare, McMaster University, Hamilton, O

In [None]:
response = run_bot("hi?")
print(response["recommendation"])



Hello Patient! I'm your **Clinical Trial Research Assistant**. üî¨

I can help you explore clinical trials for conditions such as:
- Diabetes
- Cancer
- Alzheimer's disease
- Asthma
- Cardiovascular disease

I search a database of tens of thousands of real trials (e.g., from ClinicalTrials.gov).

**You can ask things like:**
- 'What trials are studying insulin therapy in diabetes?'
- 'Are there breast cancer trials targeting HER2?'
- 'Trials for memory loss and Alzheimer's?'
- 'I'm 55 with type 2 diabetes, what trials can I join?'

What condition and question would you like to explore?
