### [Blog Post](https://storm-jacket-4ec.notion.site/1dbd34e9912680ed826ccd1d5a991792?pvs=74)

In [1]:
# !pip install -qU "google-genai==1.7.0" "chromadb==0.6.3"
!pip uninstall -qqy jupyterlab kfp
!pip install -qU \
    google-genai==1.7.0 \
    chromadb==0.6.3 \
    langchain \
    tenacity \
    rapidfuzz \
    rouge-score \
    scikit-learn

[0m  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m144.7/144.7 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m611.1/611.1 kB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m45.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m70.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━

In [2]:
!pip install -U -q "google-genai==1.7.0"
!pip install rouge-score -q
!pip install langgraph
!pip install gradio==5.25.2

Collecting langgraph
  Downloading langgraph-0.3.31-py3-none-any.whl.metadata (7.9 kB)
Collecting langgraph-checkpoint<3.0.0,>=2.0.10 (from langgraph)
  Downloading langgraph_checkpoint-2.0.24-py3-none-any.whl.metadata (4.6 kB)
Collecting langgraph-prebuilt<0.2,>=0.1.8 (from langgraph)
  Downloading langgraph_prebuilt-0.1.8-py3-none-any.whl.metadata (5.0 kB)
Collecting langgraph-sdk<0.2.0,>=0.1.42 (from langgraph)
  Downloading langgraph_sdk-0.1.61-py3-none-any.whl.metadata (1.8 kB)
Collecting ormsgpack<2.0.0,>=1.8.0 (from langgraph-checkpoint<3.0.0,>=2.0.10->langgraph)
  Downloading ormsgpack-1.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.5/43.5 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Downloading langgraph-0.3.31-py3-none-any.whl (145 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.2/145.2 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lang

In [4]:
CHROMA_DB_PATH = '/kaggle/working/chromadb_2_sample_v11'

In [5]:

import pandas as pd
from google import genai
from google.genai import types
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import time
from tqdm import tqdm
import warnings
from langgraph.graph import StateGraph, END
from typing import TypedDict, List
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_exception_type
from google.api_core.exceptions import ResourceExhausted, InternalServerError, ServiceUnavailable, DeadlineExceeded
from kaggle_secrets import UserSecretsClient
import sqlite3
import uuid
import os
import glob
import json
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.errors import InvalidCollectionException
from rouge_score import rouge_scorer
from rapidfuzz import fuzz
import gradio as gr
import shutil
import io
import sys
import logging
from contextlib import redirect_stdout

# Suppress warnings
warnings.filterwarnings("ignore")
tqdm.pandas()

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
logger = logging.getLogger(__name__)
log_buffer = io.StringIO()
stream_handler = logging.StreamHandler(log_buffer)
stream_handler.setLevel(logging.INFO)
logger.addHandler(stream_handler)
log_file_handler = logging.FileHandler("/kaggle/working/pipeline.log")
log_file_handler.setLevel(logging.INFO)
logger.addHandler(log_file_handler)

# Configuration: API Key Setup
try:
    GOOGLE_API_KEY = UserSecretsClient().get_secret("Google_API_Key")
except ImportError:
    GOOGLE_API_KEY = "YOUR_API_KEY_HERE"  # Replace with your API key
try:
    client = genai.Client(api_key=GOOGLE_API_KEY)
    client.models.list()
    print("✅ API client initialized successfully.")
    logger.info("✅ API client initialized successfully.")
except Exception as e:
    print(f"Failed to initialize Google API client: {e}")
    logger.error(f"Failed to initialize Google API client: {e}")
    raise

# Initialize ROUGE scorer
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

# LangGraph State
class PipelineState(TypedDict):
    raw_text: List[str]
    summaries: List[str]
    summary_ids: List[str]
    criteria_prompts: List[dict]
    chunk_ids: List[str]
    eligibility_results: List[dict]
    embedding_function: EmbeddingFunction
    evaluation_results: List[dict]
    trial_names: List[str]

# Summarization Function
def summarize_text(state: PipelineState) -> PipelineState:
    print("Starting summarization step...")
    logger.info("Starting summarization step...")
    sys.stdout.flush()
    
    print("Loading dataset...")
    logger.info("Loading dataset...")
    try:
        df = pd.read_csv('/kaggle/input/mimic-iv-2/mimic_iv_summarization_test_dataset_shortened_edited.csv')
        df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)
        print(f"Dataset loaded: {len(df_train)} training, {len(df_test)} test samples.")
        logger.info(f"Dataset loaded: {len(df_train)} training, {len(df_test)} test samples.")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        logger.error(f"Error loading dataset: {e}")
        raise
    
    for i, text in enumerate(state['raw_text']):
        print(f"Input {i+1} (length: {len(text)}): {text[:100]}...")
        logger.info(f"Input {i+1} (length: {len(text)}): {text[:100]}...")
        if not text.strip() or text.count("___") > 10:
            print(f"Warning: Input {i+1} may contain excessive placeholders or be empty.")
            logger.warning(f"Input {i+1} may contain excessive placeholders or be empty.")
    
    input_data = {
        'examples': 
            df_train[['text', 'summary']]
            .rename(columns={'text': 'textInput', 'summary': 'output'})
            .to_dict(orient='records')
    }

    print(f"input data length: {len(input_data)}")
    print("Prepared input data for fine-tuning.")
    logger.info("Prepared input data for fine-tuning.")
    
    model_id = None
    queued_model = None
    try:
        for m in reversed(client.tunings.list()):
            if m.name.startswith('tunedModels/summarization-model'):
                if m.state.name == 'JOB_STATE_SUCCEEDED':
                    # model_id = m.name
                    # print(f'Found existing tuned model to reuse: {model_id}')
                    # logger.info(f'Found existing tuned model to reuse: {model_id}')
                    model_id = 'tunedModels/summarization-model-9zeddx1lwavq'
                    break
                elif m.state.name == 'JOB_STATE_RUNNING' and not queued_model:
                    queued_model = m.name
    except Exception as e:
        print(f"Error checking existing tuned models: {e}")
        logger.error(f"Error checking existing tuned models: {e}")
    
    if queued_model and not model_id:
        model_id = queued_model
        print(f'Found queued model, still waiting: {model_id}')
        logger.info(f'Found queued model, still waiting: {model_id}')
        try:
            while not (tuned_model := client.tunings.get(name=model_id)).has_ended:
                print(f'Tuning state: {tuned_model.state}')
                logger.info(f'Tuning state: {tuned_model.state}')
                time.sleep(60)
            if not tuned_model.has_succeeded and tuned_model.error:
                print(f"Error during tuning: {tuned_model.error}")
                logger.error(f"Error during tuning: {tuned_model.error}")
                model_id = None
        except Exception as e:
            print(f"Error waiting for queued model: {e}")
            logger.error(f"Error waiting for queued model: {e}")
            model_id = None
    
    if not model_id:
        try:
            print("Starting fine-tuning...")
            logger.info("Starting fine-tuning...")
            tuning_op = client.tunings.tune(
                base_model="models/gemini-1.5-flash-001-tuning",
                training_dataset=input_data,
                config=types.CreateTuningJobConfig(
                    tuned_model_display_name="Summarization model",
                    batch_size=16,
                    epoch_count=2,
                ),
            )
            model_id = tuning_op.name
            print(f'Started tuning job: {model_id}')
            logger.info(f'Started tuning job: {model_id}')
            while not (tuned_model := client.tunings.get(name=model_id)).has_ended:
                print(f'Tuning state: {tuned_model.state}')
                logger.info(f'Tuning state: {tuned_model.state}')
                time.sleep(60)
            if not tuned_model.has_succeeded and tuned_model.error:
                print(f"Error during tuning: {tuned_model.error}")
                logger.error(f"Error during tuning: {tuned_model.error}")
                raise Exception(f"Tuning failed: {tuned_model.error}")
            print("✅ Fine-tuning completed successfully.")
            logger.info("✅ Fine-tuning completed successfully.")
        except Exception as e:
            print(f"Failed to fine-tune model: {e}")
            logger.error(f"Failed to fine-tune model: {e}")
            model_id = "gemini-1.5-flash-001"
            print(f"Using pre-trained model: {model_id}")
            logger.info(f"Using pre-trained model: {model_id}")
    
    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(15),
        retry=retry_if_exception_type((ResourceExhausted, InternalServerError, ServiceUnavailable, DeadlineExceeded))
    )
    def generate_summary(text: str) -> str:
        print(f"Generating summary for text (length: {len(text)})...")
        logger.info(f"Generating summary for text (length: {len(text)})...")
        prompt = f"""
        Summarize the following clinical discharge note in a concise and detailed manner. Include key information such as primary diagnoses, major treatments, procedures, and outcomes. Exclude placeholders (e.g., '___') and focus on clinical content. The summary should be no longer than 500 words and written in a professional medical tone.

        Clinical Note:
        {text}
        """

        print(f'model_id: {model_id}')
        try:
            response = client.models.generate_content(
                model=model_id,
                contents=prompt,
                config=types.GenerateContentConfig()
            )
            rc = response.candidates[0]
            print(f"API Finish Reason: {rc.finish_reason.name}")
            logger.info(f"API Finish Reason: {rc.finish_reason.name}")
            if rc.finish_reason.name != "STOP":
                print(f"Warning: Incomplete response (reason: {rc.finish_reason.name})")
                logger.warning(f"Incomplete response (reason: {rc.finish_reason.name})")
            summary = rc.content.parts[0].text
            if not summary.strip() or len(summary) < 50:
                print(f"Warning: Summary is empty or too short (length: {len(summary)})")
                logger.warning(f"Summary is empty or too short (length: {len(summary)})")
                return "(error: empty or invalid summary)"
            print(f"Summary generated (length: {len(summary)}): {summary[:100]}...")
            logger.info(f"Summary generated (length: {len(summary)}): {summary[:100]}...")
            return summary
        except Exception as e:
            print(f"Error generating summary: {e}")
            logger.error(f"Error generating summary: {e}")
            return f"(error: {str(e)})"
    
    state['summaries'] = []
    for text in state['raw_text']:
        summary = generate_summary(text)
        if "(error" in summary.lower():
            print(f"Skipping invalid summary: {summary[:100]}...")
            logger.warning(f"Skipping invalid summary: {summary[:100]}...")
        else:
            state['summaries'].append(summary)
    
    print(f"✅ Generated {len(state['summaries'])} valid summaries.")
    logger.info(f"✅ Generated {len(state['summaries'])} valid summaries.")
    print("Sample summaries:", [s[:100] for s in state['summaries'][:2]])
    logger.info("Sample summaries: %s", [s[:100] for s in state['summaries'][:2]])
    if not state['summaries']:
        print("❌ No valid summaries generated. Pipeline will continue with empty summaries.")
        logger.error("❌ No valid summaries generated. Pipeline will continue with empty summaries.")
    
    return state

# Database Storage Function
def store_in_database(state: PipelineState) -> PipelineState:
    print("Starting database storage step...")
    logger.info("Starting database storage step...")
    sys.stdout.flush()
    
    db_path = "/kaggle/working/clinical_data.db"
    
    try:
        with sqlite3.connect(db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS summaries (
                    id TEXT PRIMARY KEY,
                    original_text TEXT,
                    summary TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.commit()
        print(f"✅ Database initialized at {db_path}.")
        logger.info(f"✅ Database initialized at {db_path}.")
    except Exception as e:
        print(f"Error initializing database: {e}")
        logger.error(f"Error initializing database: {e}")
        raise
    
    summary_ids = []
    for text, summary in zip(state['raw_text'], state['summaries'] + [""] * (len(state['raw_text']) - len(state['summaries']))):
        if "(error" in summary.lower() or not summary.strip():
            print(f"Skipping storage for invalid summary: {summary[:100]}...")
            logger.warning(f"Skipping storage for invalid summary: {summary[:100]}...")
            continue
        record_id = str(uuid.uuid4())
        try:
            with sqlite3.connect(db_path) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "INSERT INTO summaries (id, original_text, summary) VALUES (?, ?, ?)",
                    (record_id, text, summary)
                )
                conn.commit()
            print(f"Stored summary (ID: {record_id}, summary length: {len(summary)}).")
            logger.info(f"Stored summary (ID: {record_id}, summary length: {len(summary)}).")
            summary_ids.append(record_id)
        except Exception as e:
            print(f"Error storing summary (ID: {record_id}): {e}")
            logger.error(f"Error storing summary (ID: {record_id}): {e}")
            raise
    
    state['summary_ids'] = summary_ids
    print(f"✅ Stored {len(summary_ids)} summaries in database.")
    logger.info(f"✅ Stored {len(summary_ids)} summaries in database.")
    return state

# Criteria Processing Function
def process_criteria(state: PipelineState) -> PipelineState:
    print("Starting criteria processing step...")
    logger.info("Starting criteria processing step...")
    sys.stdout.flush()
    
    csv_folder = "/kaggle/input/clinical-trial-sample"
    output_dir = "/kaggle/working/Criteria_based_Prompts"
    os.makedirs(output_dir, exist_ok=True)
    
    csv_files = glob.glob(os.path.join(csv_folder, "*.csv"))
    print(f"Found {len(csv_files)} CSV files in {csv_folder}.")
    logger.info(f"Found {len(csv_files)} CSV files in {csv_folder}.")
    
    criteria_prompts = []
    trial_names = []
    
    for file_path in csv_files:
        try:
            df = pd.read_csv(file_path)
            if 'Criteria_name' not in df.columns or 'Content' not in df.columns:
                print(f"❌ Skipping {file_path}: Missing 'Criteria_name' or 'Content' columns.")
                logger.error(f"❌ Skipping {file_path}: Missing 'Criteria_name' or 'Content' columns.")
                continue
            df = df[['Criteria_name', 'Content']]
            base_name = os.path.splitext(os.path.basename(file_path))[0]
            trial_names.append(base_name)
            json_filename = os.path.join("/kaggle/working", f"{base_name}.json")
            
            df.to_json(json_filename, orient='records', lines=False)
            print(f"Converted {file_path} to {json_filename}.")
            logger.info(f"Converted {file_path} to {json_filename}.")
            
            with open(json_filename, "r") as f:
                data = json.load(f)
            
            cot_prompt = '''{
  "instructions": "You are a medical data analyst specialized in creating prompts for LLMs to process clinical trial criteria. Given a document containing two sections, 'Criteria' and 'Content', review the document and create an LLM prompt for each criterion by following the steps below.",
  "global_instructions": {
    "processing_steps": [
      "1. Think through each step wisely.",
      "2. Go through each criterion one-by-one.",
      "3. Return the criterion name as 'Criteria_name'.",
      "4. Return the context of the criterion specified in the associated 'Content' as 'Text' and add clarification/interpretation.",
      "5. The 'Text' should not be more than 2 sentences.",
      "6. Specify less than 10 steps for chain of thoughts (step-by-step reasoning) for an LLM to help process each criterion.",
      "7. The created prompt must emphasize on the fact that the model **is not allowed to infer** a condition based on the **symptoms or medication**, and all the evidences extracted **must have an actual match** in the **clinical note**.",
      "8. The created prompt must emphasize on the fact that the model must only extract the conditions pertinent to the patients only (not their family).",
      "9. If the criterion is condition or medication related, bring some examples for more clarity as 'extractableLexicons'. Return an empty list '[]' if the criteria is not condition or medication related.",
      "10. Ensure the number of arrays match the total number of criteria.",
      "11. Use the global output format for all criteria evaluation results."
    ],
    "output_format": {
      "format_instructions": [
        "1. You must always return **a valid JSON array**.",
        "2. You must return the **eligibility** of the patient as 'Yes' or 'No'.",
        "3. You must return the **evidences** of the criterion met if available. If the patient doesn't have any evidences to be eligible, return None.",
        "4. You should return an array for each matched evidence.",
        "5. Your output must look like the following: {\"summary_id\": \"ID\", \"Eligibility\": \"Yes\" or \"No\", \"Evidence\": \"The relevant evidence or None\"}"
      ]
    },
    "final_output_structure": [
      {
        "Criteria_name": "name of the criteria",
        "Text": "the context of the criterion and interpretation/clarification",
        "chainOfThought": [
          "Step 1: the relevant step",
          "Step 2: the relevant step",
          "Step 3: the relevant step",
          "..."
        ],
        "extractableLexicons": "2-4 examples, if the criterion is medication or condition related",
        "Output": [
          "1. You must always return **a valid JSON array**.",
          "2. You must return the **eligibility** of the patient as 'Yes' or 'No'.",
          "3. You must return the **evidences** of the criterion met if available. If the patient doesn't have any evidences to be eligible, return None.",
          "4. You should return an array for each matched evidence.",
          "5. Your output must look like the following: {\"summary_id\": \"ID\", \"Eligibility\": \"Yes\" or \"No\", \"Evidence\": \"The relevant evidence or None\"}"
        ]
      }
    ]
  },
  "Example_Input": [
    {
      "Criteria_name": "Inclusion_1", 
      "Content": "CKD (eGFR of less than 60 mL/min/1.73 m²)"
    },
    {
      "Criteria_name": "Exclusion_1", 
      "Content": "Known HF hospitalization history"
    }
  ],
  "Example_Output": [
    {
      "Criteria_name": "Inclusion_1",
      "Text": "Clinical diagnosis of CKD or an estimated glomerular filtration rate (eGFR) of less than 60 mL/min/1.73 m².",
      "chainOfThought": [
        "Step 1: Check if the patient has a documented clinical diagnosis of CKD in their medical records.",
        "Step 2: If any related lexicon is specified in 'extractableLexicons', use as example. But you should search for other synonyms as well.",
        "Step 3: You are **not allowed to infer** a condition based on the **symptoms or medication**.",
        "Step 4: You must only extract the conditions **pertinent to the patients** only (not their family).",
        "Step 5: Only extract the evidences that have an actual match in the text.",
        "Step 6: Verify if the patient has any eGFR value recorded.",
        "Step 7: Check if the patient's eGFR is less than 60 mL/min/1.73 m².",
        "Step 8: If any of the conditions in steps 1 or 6 are met, set 'eligibility' as 'Yes', set 'No' otherwise.",
        "Step 9: Return the eligibility evidences as 'Evidence'."
      ],
      "extractableLexicons": [
        "estimated glomerular filtration rate < 60 mL/min/1.73 m²",
        "eGFR < 60 mL/min/1.73 m²",
        "eGFR = 55 mL/min/1.73 m²",
        "chronic kidney disease",
        "CKD"
      ],
      "Output": [
        "1. You must always return **a valid JSON array**.",
        "2. You must return the **eligibility** of the patient as 'Yes' or 'No'.",
        "3. You must return the **evidences** of the criterion met if available. If the patient doesn't have any evidences to be eligible, return None.",
        "4. You should return an array for each matched evidence.",
        "5. Your output must look like the following: {\"summary_id\": \"ID\", \"Eligibility\": \"Yes\" or \"No\", \"Evidence\": \"The relevant evidence or None\"}"
      ]
    },
    { 
      "Criteria_name": "Exclusion_1",
      "Text": "History of known Heart Failure hospitalization",
      "chainOfThought": [
        "Step 1: Check the patient's hospitalization history for any admissions related to heart failure.",
        "Step 2: If any related lexicon is specified in 'extractableLexicons', use as example. But you should search for other synonyms as well.",
        "Step 3: Look for discharge diagnoses that include heart failure or related terms.",
        "Step 4: You are **not allowed to infer** a condition based on the **symptoms or medications**.",
        "Step 5: You must only extract the conditions **pertinent to the patients** only (not their family).",
        "Step 6: Only extract the evidences that have an actual match in text.",
        "Step 7: If any hospitalization for HF is documented, this exclusion criterion is met and you should return 'Yes' as 'eligibility'. Return 'No', otherwise.",
        "Step 8: If any hospitalization for HF is documented return the evidence as 'Evidence'. Return None, otherwise."
      ],
      "extractableLexicons": [
        "heart failure hospitalization",
        "admitted for CHF",
        "hospitalized due to heart failure",
        "inpatient stay for cardiac failure"
      ],
      "Output": [
        "1. You must always return **a valid JSON array**.",
        "2. You must return the **eligibility** of the patient as 'Yes' or 'No'.",
        "3. You must return the **evidences** of the criterion met if available. If the patient doesn't have any evidences to be eligible, return None.",
        "4. You should return an array for each matched evidence.",
        "5. Your output must look like the following: {\"summary_id\": \"ID\", \"Eligibility\": \"Yes\" or \"No\", \"Evidence\": \"The relevant evidence or None\"}"
      ]
    }
  ],
  "response_format_instructions": "When generating LLM prompts for clinical trial criteria, use the following structure for each criterion. Each prompt should follow the global instructions while providing specific guidance for evaluating that particular criterion. Output should consistently reference the global output format and avoid redundancy.",
  "expected_model_response": {
    "format": {
      "summary_id": "Patient summary ID",
      "Eligibility": "Yes or No",
      "Evidence": "Relevant text from clinical note or None"
    }
  }
}'''
            
            prompt = cot_prompt + f"\nYou are given a JSON-formatted clinical trial criteria description below. PASSAGE:\n{json.dumps(data)}\n"
            
            try:
                response = client.models.generate_content(
                    model="gemini-1.5-flash",
                    contents=prompt,
                    config=types.GenerateContentConfig()
                )
                response_text = response.candidates[0].content.parts[0].text
                response_text = response_text.replace("```json", "").replace("```", "").strip()
                parsed = json.loads(response_text)
                
                subfolder = os.path.join(output_dir, base_name)
                os.makedirs(subfolder, exist_ok=True)
                
                for item in parsed:
                    item['Trial_Name'] = base_name
                    filename = f"{item['Criteria_name']}.json"
                    filepath = os.path.join(subfolder, filename)
                    with open(filepath, "w", encoding="utf-8") as f:
                        json.dump(item, f, ensure_ascii=False, indent=2)
                    criteria_prompts.append(item)
                
                print(f"✅ Processed {base_name}: {len(parsed)} prompts written to {subfolder}.")
                logger.info(f"✅ Processed {base_name}: {len(parsed)} prompts written to {subfolder}.")
            except Exception as e:
                print(f"❌ Error generating prompts for {json_filename}: {e}")
                logger.error(f"❌ Error generating prompts for {json_filename}: {e}")
                continue
        
        except Exception as e:
            print(f"❌ Error processing {file_path}: {e}")
            logger.error(f"❌ Error processing {file_path}: {e}")
            continue
    
    state['criteria_prompts'] = criteria_prompts
    state['trial_names'] = list(set(trial_names))
    print(f"✅ Generated {len(criteria_prompts)} criteria prompts.")
    logger.info(f"✅ Generated {len(criteria_prompts)} criteria prompts.")
    print(f"✅ Found {len(state['trial_names'])} trial names: {state['trial_names']}")
    logger.info(f"✅ Found {len(state['trial_names'])} trial names: {state['trial_names']}")
    return state

# Chunking and Embedding Function
def chunk_and_embed(state: PipelineState) -> PipelineState:
    print("Starting chunking and embedding step...")
    logger.info("Starting chunking and embedding step...")
    sys.stdout.flush()
    
    # Initialize ChromaDB client
    chromadb_path = CHROMA_DB_PATH
    # import os
    # import shutil
    
    # if os.path.exists(chromadb_path):
    #     try:
    #         shutil.rmtree(chromadb_path)
    #         print(f"✅ Successfully removed directory: {chromadb_path}")
    #     except OSError as e:
    #         print(f"❌ Error removing directory {chromadb_path}: {e}")
    # else:
    #     print(f"⚠️ Directory does not exist: {chromadb_path}")

    
    os.makedirs(chromadb_path, exist_ok=True)
    try:
        chroma_client = chromadb.PersistentClient(path=chromadb_path)
        print(f"✅ Persistent ChromaDB client initialized at {chromadb_path}.")
        logger.info(f"✅ Persistent ChromaDB client initialized at {chromadb_path}.")
    except Exception as e:
        print(f"❌ Error initializing persistent ChromaDB client: {e}")
        logger.error(f"❌ Error initializing persistent ChromaDB client: {e}")
        raise
    
    class GeminiEmbeddingFunction(EmbeddingFunction):
        def __init__(self):
            self.document_mode = True
        
        @retry(
            wait=wait_random_exponential(min=1, max=10),
            stop=stop_after_attempt(5),
            retry=retry_if_exception_type((ResourceExhausted, InternalServerError, ServiceUnavailable, DeadlineExceeded))
        )
        def __call__(self, input: Documents) -> Embeddings:
            task = "retrieval_document" if self.document_mode else "retrieval_query"
            try:
                response = client.models.embed_content(
                    model="models/text-embedding-004",
                    contents=input,
                    config=types.EmbedContentConfig(task_type=task),
                )
                embeddings = [e.values for e in response.embeddings]
                print(f"Generated {len(embeddings)} embeddings with dimension {len(embeddings[0])}")
                logger.info(f"Generated {len(embeddings)} embeddings with dimension {len(embeddings[0])}")
                return embeddings
            except Exception as e:
                print(f"Retrying due to error: {e}")
                logger.error(f"Retrying due to error: {e}")
                raise
    
    embed_fn = GeminiEmbeddingFunction()
    embed_fn.document_mode = True
    state['embedding_function'] = embed_fn
    
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=450,
        chunk_overlap=50,
        separators=["\n\n", "\n**", ".", " "]
    )
    
    all_chunks = []
    chunk_ids = []
    for idx, summary in enumerate(state['summaries']):
        if "(error" in summary.lower() or not summary.strip():
            print(f"⚠️ Skipping summary {idx+1} due to error or empty: {summary[:100]}...")
            logger.warning(f"⚠️ Skipping summary {idx+1} due to error or empty: {summary[:100]}...")
            continue
        chunks = splitter.split_text(summary)
        if not chunks:
            print(f"⚠️ No chunks generated for summary {idx+1}")
            logger.warning(f"⚠️ No chunks generated for summary {idx+1}")
            continue
        all_chunks.extend(chunks)
        chunk_ids.extend([f"{state['summary_ids'][idx]}_{i}" for i in range(len(chunks))])
    
    print(f"Generated {len(all_chunks)} chunks from {len(state['summaries'])} summaries.")
    logger.info(f"Generated {len(all_chunks)} chunks from {len(state['summaries'])} summaries.")
    print(f"Sample chunks: {all_chunks[:2]}")
    logger.info(f"Sample chunks: {all_chunks[:2]}")
    print(f"Sample chunk IDs: {chunk_ids[:2]}")
    logger.info(f"Sample chunk IDs: {chunk_ids[:2]}")

    
    # Ensure collection exists
    collection = None
    try:
        collection = chroma_client.get_collection(name="mimic_iv_summary_chunks_trial")
        print("✅ Collection 'mimic_iv_summary_chunks_trial' already exists.")
        logger.info("✅ Collection 'mimic_iv_summary_chunks_trial' already exists.")
    except (ValueError, InvalidCollectionException) as e:
        print(f"Collection does not exist: {e}. Creating new collection 'mimic_iv_summary_chunks_trial'...")
        logger.info(f"Collection does not exist: {e}. Creating new collection 'mimic_iv_summary_chunks_trial'...")
        try:
            collection = chroma_client.create_collection(
                name="mimic_iv_summary_chunks_trial",
                embedding_function=embed_fn,
                metadata={"hnsw:space": "cosine"}
            )
            print("✅ Collection created successfully.")
            logger.info("✅ Collection created successfully.")
        except Exception as create_e:
            print(f"❌ Failed to create collection: {create_e}")
            logger.error(f"❌ Failed to create collection: {create_e}")
            raise
    
    # Store chunks if available
    if all_chunks:
        try:
            collection.add(documents=all_chunks, ids=chunk_ids)
            print(f"✅ Stored {collection.count()} chunks in ChromaDB collection 'mimic_iv_summary_chunks_trial'.")
            logger.info(f"✅ Stored {collection.count()} chunks in ChromaDB collection 'mimic_iv_summary_chunks_trial'.")
        except Exception as e:
            print(f"❌ Error adding chunks to ChromaDB: {e}")
            logger.error(f"❌ Error adding chunks to ChromaDB: {e}")
            raise
    else:
        print("⚠️ No chunks to store. Collection created but empty.")
        logger.warning("⚠️ No chunks to store. Collection created but empty.")
    
    # Verify collection exists
    try:
        collection = chroma_client.get_collection(name="mimic_iv_summary_chunks_trial")
        print(f"✅ Verified collection exists with {collection.count()} chunks.")
        logger.info(f"✅ Verified collection exists with {collection.count()} chunks.")
        
    except (ValueError, InvalidCollectionException) as e:
        print(f"❌ Collection verification failed: {e}")
        logger.error(f"❌ Collection verification failed: {e}")
        raise
    
    state['chunk_ids'] = chunk_ids
    print(f"✅ Stored {len(chunk_ids)} chunk IDs.")
    logger.info(f"✅ Stored {len(chunk_ids)} chunk IDs.")
    return state

# Eligibility Checking Function
def evaluate_eligibility(state: PipelineState) -> PipelineState:
    print("Starting eligibility checking step...")
    logger.info("Starting eligibility checking step...")
    sys.stdout.flush()
    
    try:
        chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
        try:
            collection = chroma_client.get_collection(
                name="mimic_iv_summary_chunks_trial",
                embedding_function=state['embedding_function']
            )
            print(f"✅ ChromaDB collection loaded with {collection.count()} chunks.")
            logger.info(f"✅ ChromaDB collection loaded with {collection.count()} chunks.")
        except (ValueError, InvalidCollectionException) as e:
            print(f"❌ Collection 'mimic_iv_summary_chunks_trial' does not exist: {e}")
            logger.error(f"❌ Collection 'mimic_iv_summary_chunks_trial' does not exist: {e}")
            raise
    except Exception as e:
        print(f"❌ Error initializing persistent ChromaDB client: {e}")
        logger.error(f"❌ Error initializing persistent ChromaDB client: {e}")
        raise
    
    eligibility_results = []
    
    # Skip if no chunks or summaries
    if not state['chunk_ids'] or not state['summaries']:
        print("⚠️ No chunks or valid summaries available. Skipping eligibility check.")
        logger.warning("⚠️ No chunks or valid summaries available. Skipping eligibility check.")
        for sid in state['summary_ids']:
            for criterion in state['criteria_prompts']:
                eligibility_results.append({
                    "summary_id": sid,
                    "criteria_name": criterion['Criteria_name'],
                    "trial_name": criterion.get('Trial_Name', 'Unknown'),
                    "eligibility": "No",
                    "evidence": "No chunks available"
                })
        state['eligibility_results'] = eligibility_results
        print(f"✅ Generated {len(eligibility_results)} placeholder eligibility results.")
        logger.info(f"✅ Generated {len(eligibility_results)} placeholder eligibility results.")
        return state
    
    for criterion in state['criteria_prompts']:
        criteria_name = criterion['Criteria_name']
        criteria_text = criterion['Text']
        cot_steps = criterion['chainOfThought']
        lexicons = criterion['extractableLexicons']
        trial_name = criterion.get('Trial_Name', 'Unknown')
        
        state['embedding_function'].document_mode = False
        try:
            results = collection.query(
                query_texts=[criteria_text.replace("\n", " ")],
                n_results=5,
            )
            relevant_chunks = results['documents'][0]
            chunk_ids = results['ids'][0]
            print(f"Retrieved {len(relevant_chunks)} chunks for criterion '{criteria_name}'.")
            logger.info(f"Retrieved {len(relevant_chunks)} chunks for criterion '{criteria_name}'.")
        except Exception as e:
            print(f"❌ Error querying ChromaDB for '{criteria_name}': {e}")
            logger.error(f"❌ Error querying ChromaDB for '{criteria_name}': {e}")
            for sid in state['summary_ids']:
                eligibility_results.append({
                    "summary_id": sid,
                    "criteria_name": criteria_name,
                    "trial_name": trial_name,
                    "eligibility": "No",
                    "evidence": f"Error: {str(e)}"
                })
            continue
        
        context = "\n".join([f"Chunk {i+1} (ID: {cid}): {chunk}" for i, (chunk, cid) in enumerate(zip(relevant_chunks, chunk_ids))])
        if not context:
            context = "No relevant chunks found."
        
        prompt = f"""
        You are a medical data analyst evaluating patient eligibility for a clinical trial based on the following criterion and retrieved clinical note chunks. Follow the chain-of-thought steps precisely and adhere to the output format.

        Criterion:
        - Name: {criteria_name}
        - Description: {criteria_text}
        - Extractable Lexicons: {json.dumps(lexicons)}
        - Chain of Thought: {json.dumps(cot_steps)}

        Retrieved Chunks:
        {context}

        Instructions:
        - Do NOT infer conditions based on symptoms or medications.
        - Only extract conditions pertinent to the patient (not their family).
        - Evidence must have an exact match in the clinical note.
        - Follow the chain-of-thought steps to determine eligibility.
        - Return a JSON array with one object per patient summary, matching the summary IDs: {json.dumps(state['summary_ids'])}.
        - Output format: {{"summary_id": "ID", "Eligibility": "Yes" or "No", "Evidence": "Relevant text or None"}}

        Output a JSON array of eligibility results for each patient summary.
        """
        
        @retry(
            wait=wait_random_exponential(min=1, max=10),
            stop=stop_after_attempt(5),
            retry=retry_if_exception_type((ResourceExhausted, InternalServerError, ServiceUnavailable, DeadlineExceeded))
        )
        def generate_eligibility():
            try:
                response = client.models.generate_content(
                    model="gemini-1.5-flash",
                    contents=prompt,
                    config=types.GenerateContentConfig()
                )
                response_text = response.candidates[0].content.parts[0].text
                response_text = response_text.replace("```json", "").replace("```", "").strip()
                results = json.loads(response_text)
                return results
            except Exception as e:
                print(f"Retrying due to error for '{criteria_name}': {e}")
                logger.error(f"Retrying due to error for '{criteria_name}': {e}")
                raise
        
        try:
            results = generate_eligibility()
            result_map = {r["summary_id"]: r for r in results if "summary_id" in r}
            for sid in state['summary_ids']:
                if sid in result_map:
                    eligibility_results.append({
                        "summary_id": sid,
                        "criteria_name": criteria_name,
                        "trial_name": trial_name,
                        "eligibility": result_map[sid]["Eligibility"],
                        "evidence": result_map[sid]["Evidence"]
                    })
                else:
                    print(f"⚠️ Missing result for summary ID {sid} in criterion '{criteria_name}'.")
                    logger.warning(f"⚠️ Missing result for summary ID {sid} in criterion '{criteria_name}'.")
                    eligibility_results.append({
                        "summary_id": sid,
                        "criteria_name": criteria_name,
                        "trial_name": trial_name,
                        "eligibility": "No",
                        "evidence": "Missing API response"
                    })
            print(f"✅ Evaluated eligibility for '{criteria_name}': {len(result_map)} results.")
            logger.info(f"✅ Evaluated eligibility for '{criteria_name}': {len(result_map)} results.")
            time.sleep(5)
        except Exception as e:
            print(f"❌ Failed to evaluate eligibility for '{criteria_name}' after retries: {e}")
            logger.error(f"❌ Failed to evaluate eligibility for '{criteria_name}' after retries: {e}")
            for sid in state['summary_ids']:
                eligibility_results.append({
                    "summary_id": sid,
                    "criteria_name": criteria_name,
                    "trial_name": trial_name,
                    "eligibility": "No",
                    "evidence": f"Error: {str(e)}"
                })
        finally:
            state['embedding_function'].document_mode = True
    
    state['eligibility_results'] = eligibility_results
    print(f"✅ Generated {len(eligibility_results)} eligibility results.")
    logger.info(f"✅ Generated {len(eligibility_results)} eligibility results.")
    return state

# Evaluation Function
def evaluate_results(state: PipelineState) -> PipelineState:
    print("Starting evaluation step...")
    logger.info("Starting evaluation step...")
    sys.stdout.flush()
    
    try:
        chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
        collection = chroma_client.get_collection(
            name="mimic_iv_summary_chunks_trial",
            embedding_function=state['embedding_function']
        )
        print(f"✅ ChromaDB collection loaded with {collection.count()} chunks.")
        logger.info(f"✅ ChromaDB collection loaded with {collection.count()} chunks.")
    except (ValueError, InvalidCollectionException) as e:
        print(f"❌ Collection 'mimic_iv_summary_chunks_trial' does not exist: {e}")
        logger.error(f"❌ Collection 'mimic_iv_summary_chunks_trial' does not exist: {e}")
        raise
    except Exception as e:
        print(f"❌ Error loading ChromaDB collection: {e}")
        logger.error(f"❌ Error loading ChromaDB collection: {e}")
        raise
    
    evaluation_results = []
    summary_map = dict(zip(state['summary_ids'], state['summaries']))
    
    criteria_groups = {}
    for result in state['eligibility_results']:
        crit_name = result['criteria_name']
        if crit_name not in criteria_groups:
            criteria_groups[crit_name] = []
        criteria_groups[crit_name].append(result)
    
    for crit_name, results in criteria_groups.items():
        print(f"Processing criterion '{crit_name}'...")
        logger.info(f"Processing criterion '{crit_name}'...")
        eval_results = []
        trial_name = results[0]['trial_name']
        
        criterion = next((c for c in state['criteria_prompts'] if c['Criteria_name'] == crit_name), None)
        if not criterion:
            print(f"❌ Criterion '{crit_name}' not found in criteria_prompts.")
            logger.error(f"❌ Criterion '{crit_name}' not found in criteria_prompts.")
            continue
        
        query = criterion['Text']
        query_oneline = query.replace("\n", " ")
        
        for result in results:
            patient_id = result['summary_id']
            eligibility = result['eligibility']
            raw_evidence = result['evidence']
            patient_summary = summary_map.get(patient_id, "Summary not found")
            
            if isinstance(raw_evidence, list):
                evidence_list = [e.strip() for e in raw_evidence if isinstance(e, str)]
            elif isinstance(raw_evidence, str) and raw_evidence != "None":
                evidence_list = [raw_evidence.strip()]
            else:
                evidence_list = []
            
            if eligibility == "No" and not evidence_list:
                eval_results.append({
                    "Trial_Name": trial_name,
                    "Criteria_Name": crit_name,
                    "Patient_id": patient_id,
                    "Eligibility": eligibility,
                    "Evidence": None,
                    "MatchedChunk": None,
                    "Patient_Summary": patient_summary,
                    "ROUGE-1": None,
                    "ROUGE-2": None,
                    "ROUGE-L": None,
                    "fuzz_MatchScore": None,
                    "CosineSimilarity": None
                })
                continue
            
            state['embedding_function'].document_mode = False
            matched_chunks = []
            for ev in evidence_list:
                try:
                    query_result = collection.query(query_texts=[ev], n_results=1)
                    top_chunk = query_result["documents"][0][0] if query_result["documents"] else ""
                    matched_chunks.append(top_chunk)
                except Exception as e:
                    print(f"⚠️ Reverse RAG failed for evidence '{ev}' in '{crit_name}': {e}")
                    logger.warning(f"⚠️ Reverse RAG failed for evidence '{ev}' in '{crit_name}': {e}")
                    matched_chunks.append("")
            state['embedding_function'].document_mode = True
            
            for ev, chunk in zip(evidence_list, matched_chunks):
                try:
                    state['embedding_function'].document_mode = False
                    evidence_embedding = state['embedding_function']([ev])[0]
                    state['embedding_function'].document_mode = True
                    chunk_embedding = state['embedding_function']([chunk])[0]
                    cos_sim = cosine_similarity([evidence_embedding], [chunk_embedding])[0][0]
                except Exception as e:
                    print(f"⚠️ Cosine similarity failed for '{crit_name}': {e}")
                    logger.warning(f"⚠️ Cosine similarity failed for '{crit_name}': {e}")
                    cos_sim = np.nan
                
                scores = scorer.score(ev, chunk)
                eval_results.append({
                    "Trial_Name": trial_name,
                    "Criteria_Name": crit_name,
                    "Patient_id": patient_id,
                    "Eligibility": eligibility,
                    "Evidence": ev,
                    "MatchedChunk": chunk,
                    "Patient_Summary": patient_summary,
                    "ROUGE-1": round(scores["rouge1"].fmeasure * 100, 2),
                    "ROUGE-2": round(scores["rouge2"].fmeasure * 100, 2),
                    "ROUGE-L": round(scores["rougeL"].fmeasure * 100, 2),
                    "fuzz_MatchScore": round(fuzz.partial_ratio(ev, chunk), 2),
                    "CosineSimilarity": round(cos_sim * 100, 2)
                })
        
        output_dir = "/kaggle/working/evaluation_results"
        os.makedirs(output_dir, exist_ok=True)
        df = pd.DataFrame(eval_results)
        csv_path = os.path.join(output_dir, f"{crit_name}_output_reverse_RAG_test.csv")
        df.to_csv(csv_path, index=False)
        print(f"✅ Saved: {csv_path}")
        logger.info(f"✅ Saved: {csv_path}")
        
        evaluation_results.extend(eval_results)
    
    consolidated_csv_path = "/kaggle/working/evaluation_results/all_results.csv"
    consolidated_df = pd.DataFrame(evaluation_results)
    consolidated_df.to_csv(consolidated_csv_path, index=False)
    print(f"✅ Saved consolidated results: {consolidated_csv_path}")
    logger.info(f"✅ Saved consolidated results: {consolidated_csv_path}")
    
    state['evaluation_results'] = evaluation_results
    print(f"✅ Generated {len(evaluation_results)} evaluation results.")
    logger.info(f"✅ Generated {len(evaluation_results)} evaluation results.")
    return state

# Build LangGraph Workflow
def build_workflow() -> StateGraph:
    print("Building LangGraph workflow...")
    logger.info("Building LangGraph workflow...")
    workflow = StateGraph(PipelineState)
    workflow.add_node("summarization", summarize_text)
    workflow.add_node("database", store_in_database)
    workflow.add_node("criteria", process_criteria)
    workflow.add_node("chunking_embedding", chunk_and_embed)
    workflow.add_node("eligibility", evaluate_eligibility)
    workflow.add_node("evaluation", evaluate_results)
    workflow.set_entry_point("summarization")
    workflow.add_edge("summarization", "database")
    workflow.add_edge("database", "criteria")
    workflow.add_edge("criteria", "chunking_embedding")
    workflow.add_edge("chunking_embedding", "eligibility")
    workflow.add_edge("eligibility", "evaluation")
    workflow.add_edge("evaluation", END)
    print("✅ Workflow built with all nodes.")
    logger.info("✅ Workflow built with all nodes.")
    return workflow.compile()


def run_pipeline(clinical_note, csv_file):
    try:
        raw_texts = []
        if clinical_note and clinical_note.strip():
            raw_texts.append(clinical_note.strip())
        if csv_file is not None:
            try:
                df = pd.read_csv(csv_file.name)
                if 'text' in df.columns:
                    raw_texts.extend(df['text'].dropna().tolist())
                else:
                    yield {
                        "status": "Error: CSV must contain a 'text' column.",
                        "log_output": ""
                    }
                    return
            except Exception as e:
                yield {
                    "status": f"Error reading CSV: {str(e)}",
                    "log_output": ""
                }
                return
        
        if not raw_texts:
            yield {
                "status": "Error: Please provide a clinical note or upload a valid CSV.",
                "log_output": ""
            }
            return
        
        log_buffer.truncate(0)
        log_buffer.seek(0)
        with redirect_stdout(log_buffer):
            workflow = build_workflow()
            initial_state = {"raw_text": raw_texts}
            
            steps = [
                ("Summarizing...", summarize_text),
                ("Storing in database...", store_in_database),
                ("Processing criteria...", process_criteria),
                ("Chunking and embedding...", chunk_and_embed),
                ("Evaluating eligibility...", evaluate_eligibility),
                ("Evaluating results...", evaluate_results)
            ]
            
            state = initial_state
            for step_name, step_func in steps:
                yield {
                    "status": step_name,
                    "log_output": log_buffer.getvalue(),
                    "summary_output": None,
                    "eligibility_output": None,
                    "evaluation_output": None,
                    "db_output": None,
                    "chroma_output": None,
                    "consolidated_csv": None
                }
                state = step_func(state)
                sys.stdout.flush()
            
            result = state
            
            summary_df = pd.DataFrame({
                "Summary_ID": result['summary_ids'],
                "Original_Text": [t[:200] + "..." for t in result['raw_text'][:len(result['summary_ids'])]],
                "Summary": [s[:200] + "..." for s in result['summaries'] + [""] * (len(result['summary_ids']) - len(result['summaries']))]
            })
            
            eligibility_df = pd.DataFrame(result['eligibility_results'])
            evaluation_df = pd.DataFrame(result['evaluation_results'])
            
            db_output = []
            try:
                with sqlite3.connect("/kaggle/working/clinical_data.db") as conn:
                    cursor = conn.cursor()
                    cursor.execute("SELECT id, original_text, summary FROM summaries")
                    rows = cursor.fetchall()
                    db_output = [{"ID": r[0], "Original_Text": r[1][:100] + "...", "Summary": r[2][:100] + "..."} for r in rows]
            except Exception as e:
                db_output = [{"ID": "Error", "Original_Text": str(e), "Summary": ""}]
            
            chroma_output = []
            try:
                chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
                collection = chroma_client.get_collection(name="mimic_iv_summary_chunks_trial")
                count = collection.count()
                results = collection.peek(limit=5)
                chroma_output = [{"Chunk_ID": cid, "Document": doc[:100] + "..."} for cid, doc in zip(results['ids'], results['documents'])]
                chroma_output.insert(0, {"Chunk_ID": "Total Count", "Document": str(count)})
            except Exception as e:
                chroma_output = [{"Chunk_ID": "Error", "Document": str(e)}]
            
            consolidated_csv = "/kaggle/working/evaluation_results/all_results.csv"
            status = f"Pipeline completed: {len(result['summaries'])} summaries, {len(result['chunk_ids'])} chunks, {len(result['eligibility_results'])} eligibility results, {len(result['evaluation_results'])} evaluation results."
            
            yield {
                "status": status,
                "summary_output": summary_df,
                "eligibility_output": eligibility_df,
                "evaluation_output": evaluation_df,
                "db_output": pd.DataFrame(db_output),
                "chroma_output": pd.DataFrame(chroma_output),
                "consolidated_csv": consolidated_csv,
                "log_output": log_buffer.getvalue()
            }
    
    except Exception as e:
        error_msg = f"Error running pipeline: {str(e)}"
        yield {
            "status": error_msg,
            "log_output": log_buffer.getvalue() + "\n" + error_msg,
            "summary_output": None,
            "eligibility_output": None,
            "evaluation_output": None,
            "db_output": None,
            "chroma_output": None,
            "consolidated_csv": None
        }

# # Gradio UI Functions
# def run_pipeline(clinical_note, csv_file):
#     try:
#         raw_texts = []
        # if clinical_note and clinical_note.strip():
        #     raw_texts.append(clinical_note.strip())
        # if csv_file is not None:
        #     try:
        #         df = pd.read_csv(csv_file.name)
        #         if 'text' in df.columns:
        #             raw_texts.extend(df['text'].dropna().tolist())
        #         else:
        #             yield {"status": "Error: CSV must contain a 'text' column.", "logs": ""}
        #             return
        #     except Exception as e:
        #         yield {"status": f"Error reading CSV: {str(e)}", "logs": ""}
        #         return
        
        # if not raw_texts:
        #     yield {"status": "Error: Please provide a clinical note or upload a valid CSV.", "logs": ""}
        #     return
        
        # log_buffer.truncate(0)
        # log_buffer.seek(0)
        # with redirect_stdout(log_buffer):
        #     workflow = build_workflow()
        #     initial_state = {"raw_text": raw_texts}
            
        #     # Simulate pipeline steps with progress updates
        #     steps = [
        #         ("Summarizing...", summarize_text),
        #         ("Storing in database...", store_in_database),
        #         ("Processing criteria...", process_criteria),
        #         ("Chunking and embedding...", chunk_and_embed),
        #         ("Evaluating eligibility...", evaluate_eligibility),
        #         ("Evaluating results...", evaluate_results)
        #     ]
            
        #     state = initial_state
        #     for step_name, step_func in steps:
        #         yield {
            #         "status": step_name,
            #         "logs": log_buffer.getvalue()
            #     }
            #     if step_func != summarize_text:  # Apply step_func directly for non-summarization steps
            #         state = step_func(state)
            #     else:
            #         state = step_func(state)  # summarization step
            #     sys.stdout.flush()
            
            # result = state  # Final state after all steps
            
            # # Format results for UI
            # summary_df = pd.DataFrame({
            #     "Summary_ID": result['summary_ids'],
            #     "Original_Text": [t[:200] + "..." for t in result['raw_text'][:len(result['summary_ids'])]],
            #     "Summary": [s[:200] + "..." for s in result['summaries'] + [""] * (len(result['summary_ids']) - len(result['summaries']))]
            # })
            
            # eligibility_df = pd.DataFrame(result['eligibility_results'])
            # evaluation_df = pd.DataFrame(result['evaluation_results'])
            
            # # Database contents
            # db_output = []
            # try:
            #     with sqlite3.connect("/kaggle/working/clinical_data.db") as conn:
            #         cursor = conn.cursor()
            #         cursor.execute("SELECT id, original_text, summary FROM summaries")
            #         rows = cursor.fetchall()
            #         db_output = [{"ID": r[0], "Original_Text": r[1][:100] + "...", "Summary": r[2][:100] + "..."} for r in rows]
            # except Exception as e:
            #     db_output = [{"ID": "Error", "Original_Text": str(e), "Summary": ""}]
            
            # # ChromaDB contents
            # chroma_output = []
            # try:
            #     chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
            #     collection = chroma_client.get_collection(name="mimic_iv_summary_chunks_trial")
            #     count = collection.count()
            #     results = collection.peek(limit=5)
            #     chroma_output = [{"Chunk_ID": cid, "Document": doc[:100] + "..."} for cid, doc in zip(results['ids'], results['documents'])]
            #     chroma_output.insert(0, {"Chunk_ID": "Total Count", "Document": str(count)})
            # except Exception as e:
            #     chroma_output = [{"Chunk_ID": "Error", "Document": str(e)}]
            
            # consolidated_csv = "/kaggle/working/evaluation_results/all_results.csv"
            # status = f"Pipeline completed: {len(result['summaries'])} summaries, {len(result['chunk_ids'])} chunks, {len(result['eligibility_results'])} eligibility results, {len(result['evaluation_results'])} evaluation results."
            
    #         yield {
    #             "status": status,
    #             "summary_df": summary_df,
    #             "eligibility_df": eligibility_df,
    #             "evaluation_df": evaluation_df,
    #             "db_df": pd.DataFrame(db_output),
    #             "chroma_df": pd.DataFrame(chroma_output),
    #             "consolidated_csv": consolidated_csv,
    #             "logs": log_buffer.getvalue()
    #         }
    
    # except Exception as e:
    #     error_msg = f"Error running pipeline: {str(e)}"
    #     yield {
    #         "status": error_msg,
    #         "logs": log_buffer.getvalue() + "\n" + error_msg
    #     }

def clear_outputs():
    try:
        shutil.rmtree("/kaggle/working/chromadb", ignore_errors=True)
        shutil.rmtree("/kaggle/working/evaluation_results", ignore_errors=True)
        shutil.rmtree("/kaggle/working/Criteria_based_Prompts", ignore_errors=True)
        if os.path.exists("/kaggle/working/clinical_data.db"):
            os.remove("/kaggle/working/clinical_data.db")
        if os.path.exists("/kaggle/working/pipeline.log"):
            os.remove("/kaggle/working/pipeline.log")
        return "Output directories, database, and log file cleared successfully!"
    except Exception as e:
        return f"Error clearing outputs: {str(e)}"

# Main execution (preserved as original)
if __name__ == "__main__":
    print("Loading test data...")
    logger.info("Loading test data...")
    try:
        df = pd.read_csv('/kaggle/input/mimic-iv-2/mimic_iv_summarization_test_dataset_shortened_edited.csv', nrows=10)
        df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)
        print(f"Test data loaded: {len(df_test)} samples.")
        logger.info(f"Test data loaded: {len(df_test)} samples.")
    except Exception as e:
        print(f"Error loading test data: {e}")
        logger.error(f"Error loading test data: {e}")
        raise
    
    print("Running LangGraph pipeline...")
    logger.info("Running LangGraph pipeline...")
    workflow = build_workflow()
    initial_state = {"raw_text": df_test['text'].tolist()}
    result = workflow.invoke(initial_state)
    
    print("\nPipeline Result:")
    logger.info("\nPipeline Result:")
    print(f"Input Texts: {len(result['raw_text'])}")
    logger.info(f"Input Texts: {len(result['raw_text'])}")
    print(f"Summaries: {len(result['summaries'])}")
    logger.info(f"Summaries: {len(result['summaries'])}")
    print(f"Summary IDs: {len(result['summary_ids'])}")
    logger.info(f"Summary IDs: {len(result['summary_ids'])}")
    print(f"Criteria Prompts: {len(result['criteria_prompts'])}")
    logger.info(f"Criteria Prompts: {len(result['criteria_prompts'])}")
    print(f"Chunk IDs: {len(result['chunk_ids'])}")
    logger.info(f"Chunk IDs: {len(result['chunk_ids'])}")
    print(f"Eligibility Results: {len(result['eligibility_results'])}")
    logger.info(f"Eligibility Results: {len(result['eligibility_results'])}")
    print(f"Evaluation Results: {len(result['evaluation_results'])}")
    logger.info(f"Evaluation Results: {len(result['evaluation_results'])}")

    
    for i, (text, summary, sid) in enumerate(zip(result['raw_text'], result['summaries'] + [""] * (len(result['raw_text']) - len(result['summaries'])), result['summary_ids'])):
        print(f"\nSample {i+1}:")
        logger.info(f"\nSample {i+1}:")
        print(f"Input (first 100 chars): {text[:100]}...")
        logger.info(f"Input (first 100 chars): {text[:100]}...")
        print(f"Summary (first 200 chars): {summary[:200]}...")
        logger.info(f"Summary (first 200 chars): {summary[:200]}...")
        print(f"Summary ID: {sid}")
        logger.info(f"Summary ID: {sid}")
        if "(error" in summary.lower():
            print(f"Error detected in summary: {summary}")
            logger.error(f"Error detected in summary: {summary}")
    
    print("\nSample Criteria Prompts:")
    logger.info("\nSample Criteria Prompts:")
    for i, prompt in enumerate(result['criteria_prompts'][:2]):
        print(f"Prompt {i+1}:")
        logger.info(f"Prompt {i+1}:")
        print(f"Trial Name: {prompt.get('Trial_Name', 'Unknown')}")
        logger.info(f"Trial Name: {prompt.get('Trial_Name', 'Unknown')}")
        print(f"Criteria Name: {prompt['Criteria_name']}")
        logger.info(f"Criteria Name: {prompt['Criteria_name']}")
        print(f"Text: {prompt['Text']}")
        logger.info(f"Text: {prompt['Text']}")
        print(f"Extractable Lexicons: {prompt['extractableLexicons']}")
        logger.info(f"Extractable Lexicons: {prompt['extractableLexicons']}")
    
    print("\nSample Chunks:")
    logger.info("\nSample Chunks:")
    for i, cid in enumerate(result['chunk_ids'][:2]):
        print(f"Chunk ID: {cid}")
        logger.info(f"Chunk ID: {cid}")
    
    print("\nSample Eligibility Results:")
    logger.info("\nSample Eligibility Results:")
    for i, res in enumerate(result['eligibility_results'][:4]):
        print(f"Result {i+1}:")
        logger.info(f"Result {i+1}:")
        print(f"Trial Name: {res['trial_name']}")
        logger.info(f"Trial Name: {res['trial_name']}")
        print(f"Summary ID: {res['summary_id']}")
        logger.info(f"Summary ID: {res['summary_id']}")
        print(f"Criteria Name: {res['criteria_name']}")
        logger.info(f"Criteria Name: {res['criteria_name']}")
        print(f"Eligibility: {res['eligibility']}")
        logger.info(f"Eligibility: {res['eligibility']}")
        print(f"Evidence: {res['evidence']}")
        logger.info(f"Evidence: {res['evidence']}")
    
    print("\nSample Evaluation Results:")
    logger.info("\nSample Evaluation Results:")
    for i, res in enumerate(result['evaluation_results'][:4]):
        print(f"Result {i+1}:")
        logger.info(f"Result {i+1}:")
        print(f"Trial Name: {res['Trial_Name']}")
        logger.info(f"Trial Name: {res['Trial_Name']}")
        print(f"Patient ID: {res['Patient_id']}")
        logger.info(f"Patient ID: {res['Patient_id']}")
        print(f"Criteria Name: {res['Criteria_Name']}")
        logger.info(f"Criteria Name: {res['Criteria_Name']}")
        print(f"Eligibility: {res['Eligibility']}")
        logger.info(f"Eligibility: {res['Eligibility']}")
        print(f"Evidence: {res['Evidence']}")
        logger.info(f"Evidence: {res['Evidence']}")
        print(f"ROUGE-1: {res['ROUGE-1']}")
        logger.info(f"ROUGE-1: {res['ROUGE-1']}")
        print(f"fuzz_MatchScore: {res['fuzz_MatchScore']}")
        logger.info(f"fuzz_MatchScore: {res['fuzz_MatchScore']}")
    
    print("\nVerifying database contents...")
    logger.info("\nVerifying database contents...")
    try:
        with sqlite3.connect("/kaggle/working/clinical_data.db") as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT id, original_text, summary FROM summaries")
            rows = cursor.fetchall()
            print(f"Found {len(rows)} records in database.")
            logger.info(f"Found {len(rows)} records in database.")
            for row in rows[:2]:
                print(f"Record ID: {row[0]}, Original Text (first 50 chars): {row[1][:50]}..., Summary (first 50 chars): {row[2][:50]}...")
                logger.info(f"Record ID: {row[0]}, Original Text (first 50 chars): {row[1][:50]}..., Summary (first 50 chars): {row[2][:50]}...")
    except Exception as e:
        print(f"Error verifying database: {e}")
        logger.error(f"Error verifying database: {e}")
    
    print("\nVerifying ChromaDB contents...")
    logger.info("\nVerifying ChromaDB contents...")
    try:
        chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
        collection = chroma_client.get_collection(name="mimic_iv_summary_chunks_trial")
        count = collection.count()
        print(f"Found {count} chunks in ChromaDB collection 'mimic_iv_summary_chunks_trial'.")
        logger.info(f"Found {count} chunks in ChromaDB collection 'mimic_iv_summary_chunks_trial'.")
        results = collection.peek(limit=2)
        for i, (doc, cid) in enumerate(zip(results['documents'], results['ids'])):
            print(f"Chunk {i+1} (ID: {cid}): {doc[:100]}...")
            logger.info(f"Chunk {i+1} (ID: {cid}): {doc[:100]}...")
    except Exception as e:
        print(f"Error verifying ChromaDB: {e}")
        logger.error(f"Error verifying ChromaDB: {e}")
    
    print("\nVerifying CSV outputs...")
    logger.info("\nVerifying CSV outputs...")
    try:
        csv_files = glob.glob("/kaggle/working/evaluation_results/*_output_reverse_RAG_test.csv")
        print(f"Found {len(csv_files)} CSV files in /kaggle/working/evaluation_results/")
        logger.info(f"Found {len(csv_files)} CSV files in /kaggle/working/evaluation_results/")
        for csv_file in csv_files[:2]:
            df = pd.read_csv(csv_file)
            print(f"\nCSV: {csv_file}")
            logger.info(f"\nCSV: {csv_file}")
            print(df.head(2).to_string())
            logger.info(df.head(2).to_string())
    except Exception as e:
        print(f"Error verifying CSV outputs: {e}")
        logger.error(f"Error verifying CSV outputs: {e}")

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Clinical Trial Pipeline")
    with gr.Row():
        with gr.Column():
            clinical_note = gr.Textbox(label="Clinical Note", placeholder="Enter a clinical discharge note...")
            csv_file = gr.File(label="Upload CSV (with 'text' column)")
            with gr.Row():
                run_button = gr.Button("Run Pipeline")
                clear_button = gr.Button("Clear Outputs")
        with gr.Column():
            status = gr.Textbox(label="Status", value="Ready to run pipeline.")
    
    with gr.Tabs():
        with gr.Tab("Summaries"):
            summary_output = gr.DataFrame(label="Summaries")
        with gr.Tab("Eligibility Results"):
            eligibility_output = gr.DataFrame(label="Eligibility Results")
        with gr.Tab("Evaluation Results"):
            evaluation_output = gr.DataFrame(label="Evaluation Results")
        with gr.Tab("Database Contents"):
            db_output = gr.DataFrame(label="Database Contents")
        with gr.Tab("ChromaDB Contents"):
            chroma_output = gr.DataFrame(label="ChromaDB Contents")
        with gr.Tab("Logs"):
            log_output = gr.Textbox(label="Logs", lines=20)
    
    consolidated_csv = gr.File(label="Download Consolidated Results")


    def update_outputs(clinical_note, csv_file):
        # Default values for all output components
        default_output = {
            "status": "Initializing...",
            "summary_output": None,
            "eligibility_output": None,
            "evaluation_output": None,
            "db_output": None,
            "chroma_output": None,
            "consolidated_csv": None,
            "log_output": ""
        }
        
        for output in run_pipeline(clinical_note, csv_file):
            # Ensure the output dictionary only contains expected keys
            current_output = default_output.copy()
            current_output.update({k: v for k, v in output.items() if k in default_output})
            
            # Yield a tuple with exactly 8 values, one for each output component
            yield (
                current_output["status"],
                current_output["summary_output"],
                current_output["eligibility_output"],
                current_output["evaluation_output"],
                current_output["db_output"],
                current_output["chroma_output"],
                current_output["consolidated_csv"],
                current_output["log_output"]
            )

    # def update_outputs(clinical_note, csv_file):
    #     # Default output dictionary with all required keys
    #     default_output = {
    #         "status": "Initializing...",
    #         "summary_output": None,
    #         "eligibility_output": None,
    #         "evaluation_output": None,
    #         "db_output": None,
    #         "chroma_output": None,
    #         "consolidated_csv": None,
    #         "log_output": ""
    #     }
        
    #     for output in run_pipeline(clinical_note, csv_file):
    #         # Merge the yielded output with the default output to ensure all keys are present
    #         current_output = default_output.copy()
    #         current_output.update(output)
    #         yield current_output
    
    # def update_outputs(clinical_note, csv_file):
    #     for output in run_pipeline(clinical_note, csv_file):
    #         yield {
    #             "status": output.get("status", ""),
    #             "summary_output": output.get("summary_df", None),
    #             "eligibility_output": output.get("eligibility_df", None),
    #             "evaluation_output": output.get("evaluation_df", None),
    #             "db_output": output.get("db_df", None),
    #             "chroma_output": output.get("chroma_df", None),
    #             "consolidated_csv": output.get("consolidated_csv", None),
    #             "log_output": output.get("logs", "")
    #         }
    
    run_button.click(
        fn=update_outputs,
        inputs=[clinical_note, csv_file],
        outputs=[
            status,
            summary_output,
            eligibility_output,
            evaluation_output,
            db_output,
            chroma_output,
            consolidated_csv,
            log_output
        ]
    )
    clear_button.click(
        fn=clear_outputs,
        inputs=None,
        outputs=status
    )

# Launch Gradio in Kaggle
try:
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        inline=True,
        quiet=False
    )
except Exception as e:
    print(f"Error launching Gradio: {e}")
    print("Please ensure the notebook is public and try running the cell again.")
    print("Alternatively, check the Kaggle output cell below for the Gradio URL.")
    logger.error(f"Error launching Gradio: {e}")

✅ API client initialized successfully.
Loading test data...
Test data loaded: 2 samples.
Running LangGraph pipeline...
Building LangGraph workflow...
✅ Workflow built with all nodes.
Starting summarization step...
Loading dataset...
Dataset loaded: 799 training, 200 test samples.
Input 1 (length: 7898):  
Name:  ___                   Unit No:   ___
 
Admission Date:  ___              Discharge Date:   ...
Input 2 (length: 9293):  
Name:  ___                    Unit No:   ___
 
Admission Date:  ___              Discharge Date:  ...
input data length: 1
Prepared input data for fine-tuning.
Generating summary for text (length: 7898)...
model_id: tunedModels/summarization-model-9zeddx1lwavq
API Finish Reason: STOP
Summary generated (length: 3345): ## Discharge Summary for [Patient Name]

**Admission:** [Date]
**Discharge:** [Date]
**DOB:** [Date]...
Generating summary for text (length: 9293)...
model_id: tunedModels/summarization-model-9zeddx1lwavq
API Finish Reason: STOP
Summary generated

/root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [00:00<00:00, 103MiB/s] 


In [9]:
CHROMA_DB_PATH = '/kaggle/working/chromadb_2_sample_v15'