***Setup and Installation of Python Libraries for NLP and Retrieval***

In [6]:
# Run once at top of notebook
!pip install --upgrade pip
# core libs
!pip install pandas tqdm matplotlib seaborn
# embeddings + transformer models
!pip install sentence-transformers transformers accelerate
# FAISS (CPU build). If you have a GPU, a GPU-enabled FAISS build can be installed instead
!pip install faiss-cpu
# BM25
!pip install rank_bm25
# streamlit + ngrok for serving from Colab
!pip install streamlit pyngrok
# evaluation helpers
!pip install rouge_score sacrebleu




In [8]:
from google.colab import files
uploaded = files.upload()  # choose mimic-iv-ext-direct-1.0.0.zip from your machine
zip_path = list(uploaded.keys())[0]
print("Uploaded:", zip_path)


Saving mimic-iv-ext-direct-1.0.0.zip to mimic-iv-ext-direct-1.0.0 (1).zip
Uploaded: mimic-iv-ext-direct-1.0.0 (1).zip


In [9]:
import zipfile, os

zip_path = list(uploaded.keys())[0]
extract_path = "/content/mimic_ext/"
os.makedirs(extract_path, exist_ok=True)

with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_path)

print("Extracted to:", extract_path)


Extracted to: /content/mimic_ext/


In [10]:
import glob

# Recursively get all files (any extension) in the Finished folder
all_files = glob.glob(extract_path + "/mimic-iv-ext-direct-1.0.0/Finished/**/**", recursive=True)

# Filter only files (not directories)
all_files = [f for f in all_files if not os.path.isdir(f)]

print("Total files found:", len(all_files))
all_files[:5]  # Show first 5 files


Total files found: 1365


['/content/mimic_ext//mimic-iv-ext-direct-1.0.0/Finished/COPD/18591903-DS-16.json',
 '/content/mimic_ext//mimic-iv-ext-direct-1.0.0/Finished/COPD/15166831-DS-16.json',
 '/content/mimic_ext//mimic-iv-ext-direct-1.0.0/Finished/COPD/11655904-DS-23.json',
 '/content/mimic_ext//mimic-iv-ext-direct-1.0.0/Finished/COPD/14725771-DS-12.json',
 '/content/mimic_ext//mimic-iv-ext-direct-1.0.0/Finished/COPD/11482871-DS-15.json']
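The doubled slash in the paths above comes from concatenating `extract_path` (which already ends in `/`) with a string that starts with `/`. A minimal sketch of the same recursive listing with `pathlib`, which joins path parts safely; `list_files` is a hypothetical helper name, and the demo builds a throwaway directory tree rather than touching the real dataset:

```python
import tempfile
from pathlib import Path


def list_files(root, subdirs=("mimic-iv-ext-direct-1.0.0", "Finished")):
    """Return every regular file below root/<subdirs...>, sorted, as strings."""
    base = Path(root).joinpath(*subdirs)
    # rglob("*") walks every nesting level; is_file() drops directories
    return sorted(str(p) for p in base.rglob("*") if p.is_file())


# Demo on a temporary tree that mimics the dataset layout (not the real data)
with tempfile.TemporaryDirectory() as tmp:
    copd = Path(tmp, "mimic-iv-ext-direct-1.0.0", "Finished", "COPD")
    copd.mkdir(parents=True)
    for name in ("b.json", "a.json"):
        (copd / name).write_text("{}", encoding="utf-8")
    found = list_files(tmp)
    names = [Path(f).name for f in found]

print(names)  # ['a.json', 'b.json']
```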

In [11]:
import os
import json
import glob


def extract_text(obj):
    """Recursively join every string value in a parsed JSON object (keys are ignored)."""
    if isinstance(obj, dict):
        return " ".join(extract_text(v) for v in obj.values())
    if isinstance(obj, list):
        return " ".join(extract_text(v) for v in obj)
    if isinstance(obj, str):
        return obj
    return ""  # numbers, booleans and null contribute no text


# Get all files recursively
all_files = glob.glob(extract_path + "/mimic-iv-ext-direct-1.0.0/Finished/**/**", recursive=True)
all_files = [f for f in all_files if os.path.isfile(f)]

documents = []
failed = []  # (path, error) pairs, kept for inspection instead of being silently dropped

for path in all_files:
    try:
        ext = os.path.splitext(path)[1].lower()

        if ext == ".json":
            with open(path, 'r', encoding='utf-8') as f:
                text = extract_text(json.load(f))
        else:
            # For other file types (txt, csv, etc.)
            with open(path, 'r', encoding='utf-8', errors='ignore') as f:
                text = f.read()

        if len(text.strip()) > 20:  # ignore very short files
            documents.append({"source": path, "text": text})

    except Exception as e:
        failed.append((path, repr(e)))

print("Total documents extracted:", len(documents))
documents[:5]  # preview first 5


Total documents extracted: 1365


[{'source': '/content/mimic_ext//mimic-iv-ext-direct-1.0.0/Finished/COPD/18591903-DS-16.json',
  'text': '           None\n The patient reports that over the past 1.5 months, she \nhas experienced worsening shortness of breath and chest pressure with exertion, which feels similar to her previous COPD flare. when she was hospitalized for an exacerbation while still living. She has never required intubation. She reports that she does not regularly take medications for her COPD, but the day prior to admission, tried an albuterol inhaler for the first time, with little improvement in her symptoms. She also reports that she has experienced an increase in her baseline cough over the past 1.5 months, and more acutely over the past 2 weeks. The cough is worse while supine at night and productive of a non-bloody thick mucus like sputum; she reports some improvement in her cough when sitting up. She also describes dyspnea on exertion and decreased exercise tolerance.\n\nOf note, the patient repo
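The flattening helper above keeps only string values, and because it joins empty fragments with spaces it can produce runs of blanks (likely the source of the leading spaces in the preview). A small sketch of a variant that skips empty fragments and can optionally keep numeric values; `keep_numbers` is a hypothetical option added for illustration, and whether any DiReCT fields actually store numbers is not established here:

```python
import json


def extract_text(obj, keep_numbers=False):
    """Depth-first join of every string value in a parsed JSON object.

    Dict keys are ignored, as in the notebook's helper. Numbers are dropped
    unless keep_numbers=True (a hypothetical option added for illustration).
    """
    if isinstance(obj, dict):
        parts = [extract_text(v, keep_numbers) for v in obj.values()]
    elif isinstance(obj, list):
        parts = [extract_text(v, keep_numbers) for v in obj]
    elif isinstance(obj, str):
        return obj
    elif keep_numbers and isinstance(obj, (int, float)) and not isinstance(obj, bool):
        return str(obj)
    else:
        return ""  # None, booleans, and (by default) numbers
    # Skip empty fragments so dropped values do not leave doubled spaces
    return " ".join(p for p in parts if p)


record = json.loads('{"hpi": "worsening dyspnea", "vitals": {"spo2": 88}, "notes": ["cough", null]}')
print(extract_text(record))                     # worsening dyspnea cough
print(extract_text(record, keep_numbers=True))  # worsening dyspnea 88 cough
```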

In [12]:
!pip install sentence-transformers faiss-cpu




***Load Sentence Transformer Model and Initialize FAISS for Embedding-Based Search***

In [13]:
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss

model = SentenceTransformer('all-MiniLM-L6-v2')  # lightweight, works in Colab




***Encode Documents and Build FAISS Index for Similarity Search***

In [14]:
texts = [d["text"] for d in documents]
emb = model.encode(texts, convert_to_numpy=True, show_progress_bar=True)

dim = emb.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(emb)


Batches:   0%|          | 0/43 [00:00<?, ?it/s]
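`IndexFlatIP` ranks by raw inner product, which equals cosine similarity only when all vectors have unit length. all-MiniLM-L6-v2 includes a normalization module (the saved model folder later in this notebook contains `2_Normalize/`), so its embeddings are already unit length; with a model that lacks it, `normalize_embeddings=True` in `model.encode` or `faiss.normalize_L2(emb)` before `index.add` restores the cosine interpretation. The NumPy-only sketch below (no FAISS; random vectors stand in for embeddings) shows why:

```python
import numpy as np


def l2_normalize(x):
    """Scale each row (or a single vector) to unit L2 length."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.clip(norms, 1e-12, None)


def top_k_inner_product(vectors, query, k):
    """Exact top-k by dot product: the scoring rule of faiss.IndexFlatIP."""
    scores = vectors @ query
    order = np.argsort(-scores)[:k]
    return scores[order], order


rng = np.random.default_rng(0)
docs = rng.normal(size=(6, 4))
query = rng.normal(size=4)

# Reference ranking by cosine similarity
cosine = (docs @ query) / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query))
expected = np.argsort(-cosine)[:3]

# Inner product on unit vectors reproduces the cosine ranking
_, order_unit = top_k_inner_product(l2_normalize(docs), l2_normalize(query), 3)

# Raw inner product also depends on vector length; the normalized score does not
scaled = docs.copy()
scaled[0] *= 10.0
raw_ratio = (scaled[0] @ query) / (docs[0] @ query)
unit_delta = l2_normalize(scaled)[0] @ query - l2_normalize(docs)[0] @ query

print(list(order_unit) == list(expected), raw_ratio, unit_delta)
```

Because the index stores only the vectors it is given, normalization has to happen before `index.add` and again for every query vector.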

***Retrieve Top-K Relevant Documents Using FAISS and Sentence Embeddings***

In [15]:
def retrieve(query, top_k=5):
    q = model.encode([query], convert_to_numpy=True)
    scores, idx = index.search(q, top_k)

    out = []
    for score, i in zip(scores[0], idx[0]):
        out.append({
            "score": float(score),
            "source": documents[i]["source"],
            "text": documents[i]["text"]
        })
    return out
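all-MiniLM-L6-v2 truncates inputs longer than 256 word pieces (see `model.max_seq_length`), so each whole-note embedding above reflects only the opening part of a discharge summary. A common remedy is to embed overlapping chunks and map hits back to their parent document. A minimal sketch, assuming whitespace words as a rough stand-in for word pieces (clinical abbreviations can tokenize into more pieces, so the window size is a heuristic); `chunk_words` and `build_chunks` are hypothetical helpers:

```python
def chunk_words(text, window=150, overlap=50):
    """Split text into overlapping windows of at most `window` whitespace tokens."""
    if overlap >= window:
        raise ValueError("overlap must be smaller than window")
    words = text.split()
    step = window - overlap
    chunks = []
    for start in range(0, max(len(words), 1), step):
        piece = words[start:start + window]
        if piece:
            chunks.append(" ".join(piece))
        if start + window >= len(words):
            break  # the last window already reaches the end of the text
    return chunks


def build_chunks(documents, window=150, overlap=50):
    """Flatten documents into chunk records that remember their parent index."""
    records = []
    for doc_id, doc in enumerate(documents):
        for j, chunk in enumerate(chunk_words(doc["text"], window, overlap)):
            records.append({"doc_id": doc_id, "chunk_id": j, "text": chunk})
    return records


demo_docs = [{"source": "a.json", "text": " ".join(f"w{i}" for i in range(320))},
             {"source": "b.json", "text": "short note"}]
records = build_chunks(demo_docs)
print(len(records), [r["doc_id"] for r in records])  # 4 [0, 0, 0, 1]
```

The chunk texts would then be encoded and added to the index in place of whole notes. Since each file currently becomes exactly one document, deduplicating by source only starts to matter once several chunks of the same file can be returned for one query.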


In [16]:
!pip install transformers accelerate




***Load Instruction-Tuned Causal Language Model for Text Generation***

In [17]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

gen_model_name = "Qwen/Qwen2.5-1.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
generator = AutoModelForCausalLM.from_pretrained(
    gen_model_name,
    device_map="auto",
    torch_dtype=torch.float16
)


`torch_dtype` is deprecated! Use `dtype` instead!
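Qwen2.5-1.5B-Instruct is instruction-tuned on a chat format and its tokenizer ships a chat template, while the prompts in this notebook are passed as plain text, which bypasses that format. A sketch of building chat-style messages instead; `build_messages` is a hypothetical helper and the default system prompt is an illustrative assumption, not a validated clinical instruction:

```python
def build_messages(query, context_blocks, system_prompt=None):
    """Assemble a chat-format message list for an instruction-tuned model.

    The default system prompt is an illustrative assumption, not a
    validated clinical instruction.
    """
    if system_prompt is None:
        system_prompt = ("You are a clinical documentation assistant. Answer only from "
                         "the provided context and say when it is insufficient.")
    context = "\n\n".join(f"[Document {i}]: {block}"
                          for i, block in enumerate(context_blocks, start=1))
    user_turn = f"CLINICAL CONTEXT:\n{context}\n\nCLINICAL QUESTION: {query}"
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_turn},
    ]


messages = build_messages("Likely stroke type?",
                          ["Left facial droop and arm weakness.", "CT negative for hemorrhage."])
for m in messages:
    print(m["role"], "->", m["content"].splitlines()[0])
```

With the loaded tokenizer, `inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(generator.device)` produces model-ready tensors; `generator.generate(**inputs, max_new_tokens=400)` then generates, and decoding `output[0][inputs["input_ids"].shape[1]:]` returns only the answer.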


***Enhanced Clinical RAG System for Stroke Diagnosis and Assessment***

In [38]:
# =============================================================================
# ENHANCED GENERATION FUNCTION
# =============================================================================

def generate_answer(query, retrieved_docs, max_tokens=400):
    """Generate a concise assessment for `query` grounded in the retrieved documents."""

    # Extract context from retrieved documents
    context_parts = []
    for i, doc in enumerate(retrieved_docs):
        doc_text = doc["text"].strip()
        # Clean and truncate text properly
        if len(doc_text) > 800:
            # Try to truncate at sentence boundary
            trunc_point = doc_text[:800].rfind('.')
            if trunc_point > 400:  # Ensure meaningful content
                doc_text = doc_text[:trunc_point+1]
            else:
                doc_text = doc_text[:800] + "..."
        context_parts.append(f"[Document {i+1}]: {doc_text}")

    context = "\n\n".join(context_parts)

    # Improved prompt template
    prompt = f"""Based on the following clinical documentation, provide a concise medical assessment.

CLINICAL CONTEXT:
{context}

CLINICAL QUESTION: {query}

MEDICAL ASSESSMENT:"""

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(generator.device)

    # Generate with proper parameters
    output = generator.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=0.3,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2
    )

    # Decode only the newly generated tokens; slicing off the prompt prevents it
    # from leaking into the answer even if the model never writes the marker.
    prompt_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

    return answer

# =============================================================================
# ENHANCED RETRIEVAL FUNCTION
# =============================================================================

def retrieve(query, top_k=5):
    """Retrieve the top_k most similar documents, skipping repeated source files."""
    q = model.encode([query], convert_to_numpy=True)
    scores, idx = index.search(q, top_k * 3)  # Get more to filter duplicates

    out = []
    seen_sources = set()

    for score, i in zip(scores[0], idx[0]):
        source = documents[i]["source"]

        # Skip duplicates
        if source in seen_sources:
            continue
        seen_sources.add(source)

        out.append({
            "score": float(score),
            "source": source,
            "text": documents[i]["text"],
            "filename": os.path.basename(source)
        })

        if len(out) >= top_k:
            break

    return out

# =============================================================================
# ENHANCED RAG QUERY WITH PROPER OUTPUT FORMAT
# =============================================================================

def rag_query(query, top_k=5):
    """Retrieve supporting documents for `query`, then generate an answer from them."""
    retrieved = retrieve(query, top_k=top_k)
    answer = generate_answer(query, retrieved)

    return {
        "query": query,
        "retrieved": retrieved,
        "answer": answer
    }

# =============================================================================
# ENHANCED DISPLAY FUNCTION
# =============================================================================

def display_stroke_results(results, example_num, stroke_type):
    """Pretty-print the query, generated answer and retrieved documents for one test."""

    print(f"\n{'='*80}")
    print(f"🧠 STROKE DOMAIN - EXAMPLE {example_num}: {stroke_type}")
    print(f"{'='*80}")

    # Query information
    print(f"📋 CLINICAL QUERY:")
    print(f"   {results['query']}")

    # Answer section
    print(f"\n💡 CLINICAL ASSESSMENT:")
    print(f"{'-'*60}")
    if results['answer'] and len(results['answer'].strip()) > 50:
        print(results['answer'])
    else:
        print("⚠️  Answer generation issue detected - reviewing context processing...")
        print("Generated output:", results['answer'][:200] if results['answer'] else "Empty")
    print(f"{'-'*60}")

    # Retrieved documents with enhanced info
    print(f"\n📚 RETRIEVED CLINICAL DOCUMENTS ({len(results['retrieved'])}):")
    print(f"{'='*60}")

    for i, doc in enumerate(results['retrieved']):
        # Category = folder directly under "Finished/" (e.g. COPD)
        parts = doc['source'].split('/')
        category = parts[parts.index('Finished') + 1] if 'Finished' in parts else 'Unknown'

        print(f"\n📑 DOCUMENT {i+1}:")
        print(f"   ⭐ Relevance Score: {doc['score']:.3f}")
        print(f"   📁 Source: {doc['filename']}")
        print(f"   🏥 Category: {category}")

        # Clean and display text preview
        doc_text = doc['text'].strip()
        preview = doc_text[:400] + "..." if len(doc_text) > 400 else doc_text
        print(f"   📝 Clinical Findings: {preview}")
        print(f"   {'─'*50}")

# =============================================================================
# UPDATED STROKE DOMAIN TESTING
# =============================================================================

print("\n" + "🚀" * 20)
print("🧠 ENHANCED STROKE DOMAIN CLINICAL RAG TESTING")
print("🚀" * 20)

# =============================================================================
# Example 6: Ischemic Stroke Query
# =============================================================================

print(f"\n🎯 TEST 6/10: ISCHEMIC STROKE IDENTIFICATION")
out6 = rag_query("Patient with sudden weakness on one side, facial droop, and slurred speech - likely type of stroke?", top_k=5)
display_stroke_results(out6, 6, "ISCHEMIC STROKE")

# =============================================================================
# Example 7: Hemorrhagic Stroke Query
# =============================================================================

print(f"\n🎯 TEST 7/10: HEMORRHAGIC STROKE IDENTIFICATION")
out7 = rag_query("Patient with sudden severe headache, nausea, and vomiting - possible hemorrhagic event?", top_k=5)
display_stroke_results(out7, 7, "HEMORRHAGIC STROKE")

# =============================================================================
# Example 8: Transient Ischemic Attack (TIA) Query
# =============================================================================

print(f"\n🎯 TEST 8/10: TRANSIENT ISCHEMIC ATTACK (TIA)")
out8 = rag_query("Patient reports brief episode of vision loss and numbness in the arm, symptoms resolve within minutes - likely diagnosis?", top_k=5)
display_stroke_results(out8, 8, "TRANSIENT ISCHEMIC ATTACK (TIA)")

# =============================================================================
# Example 9: Stroke with Aphasia Query
# =============================================================================

print(f"\n🎯 TEST 9/10: STROKE WITH APHASIA")
out9 = rag_query("Patient with sudden difficulty speaking and understanding language, right-sided weakness - what type of stroke?", top_k=5)
display_stroke_results(out9, 9, "STROKE WITH APHASIA")

# =============================================================================
# Example 10: Stroke with Visual Field Deficit Query
# =============================================================================

print(f"\n🎯 TEST 10/10: STROKE WITH VISUAL FIELD DEFICIT")
out10 = rag_query("Patient complains of sudden loss of vision in left visual field, left-sided weakness - likely neurological condition?", top_k=5)
display_stroke_results(out10, 10, "STROKE WITH VISUAL FIELD DEFICIT")

# =============================================================================
# ADDITIONAL STROKE-RELATED QUERIES FOR COMPREHENSIVE TESTING
# =============================================================================

print(f"\n{'='*80}")
print("🔍 ADDITIONAL STROKE-RELATED CLINICAL QUERIES")
print(f"{'='*80}")

# Additional test cases
additional_stroke_queries = [
    {
        "query": "What is the time window for thrombolytic therapy in acute ischemic stroke?",
        "type": "STROKE MANAGEMENT",
        "top_k": 4
    },
    {
        "query": "Differentiate between anterior and posterior circulation stroke symptoms",
        "type": "STROKE LOCALIZATION",
        "top_k": 4
    },
    {
        "query": "Risk factors for hemorrhagic transformation after ischemic stroke",
        "type": "STROKE COMPLICATIONS",
        "top_k": 4
    }
]

for i, stroke_query in enumerate(additional_stroke_queries, 1):
    print(f"\n🎯 ADDITIONAL TEST {i}: {stroke_query['type']}")
    result = rag_query(stroke_query["query"], top_k=stroke_query["top_k"])
    display_stroke_results(result, f"10+{i}", stroke_query["type"])

# =============================================================================
# STROKE DOMAIN PERFORMANCE SUMMARY
# =============================================================================

print(f"\n{'='*80}")
print("📊 STROKE DOMAIN PERFORMANCE SUMMARY")
print(f"{'='*80}")

# Collect all stroke results
stroke_results = [out6, out7, out8, out9, out10]

# Calculate performance metrics
total_queries = len(stroke_results)
successful_answers = sum(1 for result in stroke_results if result['answer'] and len(result['answer'].strip()) > 100)
avg_retrieval_score = np.mean([max(doc['score'] for doc in result['retrieved']) for result in stroke_results if result['retrieved']])
avg_docs_retrieved = np.mean([len(result['retrieved']) for result in stroke_results])

print(f"📈 PERFORMANCE METRICS:")
print(f"   ✅ Total Queries: {total_queries}")
print(f"   ✅ Successful Answers: {successful_answers}/{total_queries} ({successful_answers/total_queries*100:.1f}%)")
print(f"   🔍 Average Retrieval Score: {avg_retrieval_score:.3f}")
print(f"   📚 Average Documents Retrieved: {avg_docs_retrieved:.1f}")

# Document source analysis
print(f"\n📁 DOCUMENT SOURCE ANALYSIS:")
source_categories = {}
for result in stroke_results:
    for doc in result['retrieved']:
        # Category = folder directly under "Finished/" (e.g. COPD)
        parts = doc['source'].split('/')
        category = parts[parts.index('Finished') + 1] if 'Finished' in parts else 'Unknown'
        source_categories[category] = source_categories.get(category, 0) + 1

for category, count in source_categories.items():
    print(f"   📂 {category}: {count} documents")

# Retrieval quality assessment
print(f"\n🎯 RETRIEVAL QUALITY ASSESSMENT:")
high_relevance = sum(1 for result in stroke_results for doc in result['retrieved'] if doc['score'] > 0.6)
total_docs = sum(len(result['retrieved']) for result in stroke_results)
print(f"   ⭐ High-relevance documents (>0.6): {high_relevance}/{total_docs} ({high_relevance/total_docs*100:.1f}%)")

# Clinical relevance evaluation
print(f"\n🏥 CLINICAL RELEVANCE EVALUATION:")
import re
stroke_keywords = ['stroke', 'ischemic', 'hemorrhagic', 'TIA', 'aphasia', 'weakness', 'facial droop', 'slurred speech']
keyword_hits = 0
total_keywords = len(stroke_keywords) * total_queries

for result in stroke_results:
    answer_text = result['answer'].lower() if result['answer'] else ""
    for keyword in stroke_keywords:
        # Lowercase both sides and match at a word start: 'TIA' now counts (the answer
        # text is lowercased), 'potential' does not, and plurals like 'strokes' still do
        if re.search(rf"\b{re.escape(keyword.lower())}", answer_text):
            keyword_hits += 1

print(f"   🔑 Clinical keyword coverage: {keyword_hits}/{total_keywords} ({keyword_hits/total_keywords*100:.1f}%)")

print(f"\n🎉 STROKE DOMAIN TESTING COMPLETED!")
print(f"   The Clinical RAG system has processed {total_queries} stroke-related queries")
print(f"   with an average retrieval relevance of {avg_retrieval_score:.3f}")


üöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄ
üß† ENHANCED STROKE DOMAIN CLINICAL RAG TESTING
üöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄüöÄ

üéØ TEST 6/10: ISCHEMIC STROKE IDENTIFICATION

üß† STROKE DOMAIN - EXAMPLE 6: ISCHEMIC STROKE
üìã CLINICAL QUERY:
   Patient with sudden weakness on one side, facial droop, and slurred speech ‚Äî likely type of stroke?

üí° CLINICAL ASSESSMENT:
------------------------------------------------------------
The patient described above exhibits symptoms consistent with ischemic stroke due to thrombosis within cerebral arteries supplying blood flow to brain tissue leading to neurological deficits such as slurring of speech, hemiplegia, and facial drooping. Given their age and pre-existing conditions including hypertension and prior cardiac surgery, there may be increased risk factors contributing to clot formation and subsequent infarction. Further investigation into underlying
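Both prompt builders in this notebook shorten each retrieved document at the last period before a character limit (800 with a 400-character floor in `generate_answer`, 600 with a 300-character floor in `generate_specific_response`) and fall back to a hard cut plus `...`. A sketch factoring that rule into one helper; `truncate_at_sentence`, `limit` and `min_keep` are hypothetical names:

```python
def truncate_at_sentence(text, limit=800, min_keep=400):
    """Cut text to roughly `limit` chars, preferring to end at the last '.' in range.

    If the last period falls at or before `min_keep`, fall back to a hard cut
    plus '...', so the result can be up to limit + 3 characters long.
    """
    text = text.strip()
    if len(text) <= limit:
        return text
    cut = text[:limit].rfind(".")
    if cut > min_keep:
        return text[:cut + 1]
    return text[:limit] + "..."


print(len(truncate_at_sentence("a" * 499 + "." + "b" * 600)))  # 500: cut right after the period
print(len(truncate_at_sentence("c" * 1000)))                    # 803: hard cut plus "..."
```

The rule counts characters, not tokens, and treats every period as a sentence end, so a decimal such as the "1.5 months" in the sample note, or an abbreviation like "Dr.", can become the cut point.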

***SIMPLE EVALUATION MODULE FOR CLINICAL RAG SYSTEM***

In [39]:
# =============================================================================
# SIMPLE EVALUATION MODULE (REQUIRED)
# =============================================================================

import numpy as np
from collections import defaultdict

def evaluate_retrieval_performance(test_queries, rag_interface):
    """Calculate precision, recall, F1 for retrieval"""
    results = []

    for query_info in test_queries:
        query = query_info["query"]
        expected_keywords = query_info["expected_keywords"]

        # Retrieve documents
        retrieved = rag_interface.enhanced_retrieve(query, top_k=5)

        # Calculate relevance (simple keyword matching)
        relevant_count = 0
        for doc in retrieved:
            doc_text = doc["text"].lower()
            if any(keyword in doc_text for keyword in expected_keywords):
                relevant_count += 1

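        precision = relevant_count / len(retrieved) if retrieved else 0
        # NOTE: without ground-truth relevance labels this "recall" is only a proxy:
        # it divides a document count by a keyword count. Whenever 5 documents are
        # retrieved for a query with 5 keywords it equals precision (and so does F1).
        recall = relevant_count / len(expected_keywords) if expected_keywords else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0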

        results.append({
            "query": query,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "retrieved": len(retrieved),
            "relevant": relevant_count
        })

    # Calculate averages
    avg_precision = np.mean([r["precision"] for r in results])
    avg_recall = np.mean([r["recall"] for r in results])
    avg_f1 = np.mean([r["f1"] for r in results])

    return {
        "detailed_results": results,
        "avg_precision": avg_precision,
        "avg_recall": avg_recall,
        "avg_f1": avg_f1
    }

def evaluate_generation_quality(test_queries, rag_interface):
    """Evaluate generation accuracy, coherence, relevance"""
    results = []

    for query_info in test_queries:
        query = query_info["query"]
        response_type = query_info.get("category", "assessment")

        # Generate response
        response, _, metrics, _ = rag_interface.process_query(query, response_type, 5)

        # Simple scoring (in real scenario, use human evaluation)
        accuracy_score = min(len(response.split()) / 100, 1.0)  # Length-based proxy
        coherence_score = 0.7  # Placeholder - would need actual evaluation

        # Relevance: check keyword presence
        expected_keywords = query_info["expected_keywords"]
        response_lower = response.lower()
        keyword_hits = sum(1 for kw in expected_keywords if kw in response_lower)
        relevance_score = keyword_hits / len(expected_keywords) if expected_keywords else 0.5

        results.append({
            "query": query,
            "accuracy": accuracy_score,
            "coherence": coherence_score,
            "relevance": relevance_score,
            "response_length": len(response)
        })

    # Calculate averages
    avg_accuracy = np.mean([r["accuracy"] for r in results])
    avg_coherence = np.mean([r["coherence"] for r in results])
    avg_relevance = np.mean([r["relevance"] for r in results])

    return {
        "detailed_results": results,
        "avg_accuracy": avg_accuracy,
        "avg_coherence": avg_coherence,
        "avg_relevance": avg_relevance
    }

# Test queries for evaluation
test_queries = [
    {
        "query": "What are symptoms of pneumonia?",
        "expected_keywords": ["cough", "fever", "shortness", "breath", "chest"],
        "category": "symptoms"
    },
    {
        "query": "Treatment for hypertension",
        "expected_keywords": ["medication", "blood", "pressure", "therapy", "treatment"],
        "category": "treatment"
    },
    {
        "query": "Diabetes medications",
        "expected_keywords": ["metformin", "insulin", "medication", "diabetes", "glucose"],
        "category": "medication"
    }
]

# Run evaluation
print("🔍 Running System Evaluation...")
retrieval_results = evaluate_retrieval_performance(test_queries, rag_interface)
generation_results = evaluate_generation_quality(test_queries, rag_interface)

print("\n📊 RETRIEVAL METRICS:")
print(f"  • Average Precision: {retrieval_results['avg_precision']:.3f}")
print(f"  • Average Recall: {retrieval_results['avg_recall']:.3f}")
print(f"  • Average F1 Score: {retrieval_results['avg_f1']:.3f}")

print("\n📊 GENERATION METRICS:")
print(f"  • Average Accuracy: {generation_results['avg_accuracy']:.3f}")
print(f"  • Average Coherence: {generation_results['avg_coherence']:.3f}")
print(f"  • Average Relevance: {generation_results['avg_relevance']:.3f}")

🔍 Running System Evaluation...

📊 RETRIEVAL METRICS:
  • Average Precision: 0.667
  • Average Recall: 0.667
  • Average F1 Score: 0.667

📊 GENERATION METRICS:
  • Average Accuracy: 0.093
  • Average Coherence: 0.700
  • Average Relevance: 0.200
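The three identical retrieval numbers are a consequence of how "recall" is computed in `evaluate_retrieval_performance`: it divides the number of keyword-matching documents by the number of keywords, so when five documents come back for a query with five keywords it equals precision, and the F1 of two equal values is that same value. Set-based recall needs a labeled set of relevant documents. A minimal sketch of precision@k and recall@k, assuming, as a hypothetical labeling scheme, that a document is relevant when its folder under `Finished/` matches a category attached to the query (the test queries above carry no such field yet):

```python
def relevant_by_category(doc_categories, target):
    """Indices of documents whose folder label equals the query's target label."""
    return {i for i, cat in enumerate(doc_categories) if cat == target}


def precision_recall_at_k(retrieved_ids, relevant_ids, k):
    """Set-based precision@k, recall@k and F1 for a single query."""
    top_k = list(retrieved_ids)[:k]
    relevant = set(relevant_ids)
    hits = sum(1 for doc_id in top_k if doc_id in relevant)
    precision = hits / k if k else 0.0  # divide by k, so short result lists are penalized
    recall = hits / len(relevant) if relevant else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Illustrative labels and ranking, not taken from the dataset
categories = ["COPD", "COPD", "Stroke", "Stroke", "Stroke", "Pneumonia"]
relevant = relevant_by_category(categories, "Stroke")  # {2, 3, 4}
ranked = [2, 0, 4, 5, 3]                                # hypothetical retrieval output

print(precision_recall_at_k(ranked, relevant, k=5))
print(precision_recall_at_k(ranked, relevant, k=2))
```

Dividing precision by `k` rather than by the number of results returned is a design choice: it penalizes a retriever whose `min_score` filter leaves fewer than `k` hits.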


***Save and Download the FAISS Index, Metadata, and Model Weights***

In [None]:
# Save your FAISS index to a file
faiss.write_index(index, "/content/clinical_faiss_index.index")

# Save documents metadata
import pickle
with open('/content/documents_metadata.pkl', 'wb') as f:
    pickle.dump(documents, f)

print("✅ FAISS index and documents metadata saved!")

✅ FAISS index and documents metadata saved!
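Reloading the saved artifacts is the other half of this step. `pickle.load` can execute arbitrary code from a crafted file, so for plain data such as a list of `{"source", "text"}` dicts, JSON plus a `.npy` file for the embeddings is a safer interchange format. A sketch under that assumption; `save_artifacts`, `load_artifacts` and the file names are hypothetical:

```python
import json
import os
import tempfile

import numpy as np


def save_artifacts(out_dir, documents, embeddings):
    """Write documents as JSON (instead of pickle) and embeddings as .npy."""
    with open(os.path.join(out_dir, "documents.json"), "w", encoding="utf-8") as f:
        json.dump(documents, f, ensure_ascii=False)
    np.save(os.path.join(out_dir, "embeddings.npy"), embeddings)


def load_artifacts(out_dir):
    """Inverse of save_artifacts: returns (documents, embeddings)."""
    with open(os.path.join(out_dir, "documents.json"), encoding="utf-8") as f:
        documents = json.load(f)
    embeddings = np.load(os.path.join(out_dir, "embeddings.npy"))
    return documents, embeddings


demo_docs = [{"source": "Finished/COPD/example.json", "text": "dyspnea on exertion"}]
demo_emb = np.ones((1, 384), dtype=np.float32)  # 384 = all-MiniLM-L6-v2 dimension

with tempfile.TemporaryDirectory() as tmp:
    save_artifacts(tmp, demo_docs, demo_emb)
    loaded_docs, loaded_emb = load_artifacts(tmp)

print(loaded_docs == demo_docs, loaded_emb.shape, loaded_emb.dtype)  # True (1, 384) float32
```

The index itself can then be restored with `faiss.read_index(...)` on the saved file, or rebuilt from the embeddings with `faiss.IndexFlatIP(loaded_emb.shape[1])` followed by `.add(loaded_emb)`.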


In [None]:
from google.colab import files
import zipfile

# Bundle the FAISS index and document metadata into one zip and download it (no model weights)
faiss.write_index(index, "/content/faiss.index")
import pickle
with open('/content/documents.pkl', 'wb') as f:
    pickle.dump(documents, f)

with zipfile.ZipFile('/content/all_files.zip', 'w') as zipf:
    zipf.write('/content/faiss.index')
    zipf.write('/content/documents.pkl')

files.download('/content/all_files.zip')
print("✅ All files downloaded!")

✅ All files downloaded!
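`ZipFile.write(path)` without an `arcname` stores each file under its full path minus the leading slash (here `content/faiss.index`), and a `ZipFile` opened without a `compression` argument stores files uncompressed. A sketch of a flat, deflated bundle; `bundle` is a hypothetical helper and the demo writes placeholder files instead of the real index:

```python
import os
import tempfile
import zipfile


def bundle(paths, zip_path):
    """Zip files flat (arcname=basename) with DEFLATE compression; return entry names."""
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for p in paths:
            zf.write(p, arcname=os.path.basename(p))
    with zipfile.ZipFile(zip_path) as zf:
        return zf.namelist()


with tempfile.TemporaryDirectory() as tmp:
    files_to_pack = []
    for name in ("faiss.index", "documents.pkl"):
        path = os.path.join(tmp, name)
        with open(path, "wb") as f:
            f.write(b"placeholder bytes")
        files_to_pack.append(path)
    names = bundle(files_to_pack, os.path.join(tmp, "all_files.zip"))

print(names)  # ['faiss.index', 'documents.pkl']
```

The variable is deliberately not called `files`, which would shadow `google.colab.files` used for the download.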


In [None]:
from google.colab import files
import zipfile

# Download ALL files including model weights
faiss.write_index(index, "/content/faiss.index")
import pickle
with open('/content/documents.pkl', 'wb') as f:
    pickle.dump(documents, f)

# Save the sentence-transformer weights and zip the model folder
model.save('/content/sentence_model')
!cd /content && zip -r sentence_model.zip sentence_model/

with zipfile.ZipFile('/content/all_files_with_weights.zip', 'w') as zipf:
    zipf.write('/content/faiss.index')
    zipf.write('/content/documents.pkl')
    zipf.write('/content/sentence_model.zip')  # zipped model folder created above

files.download('/content/all_files_with_weights.zip')
print("✅ All files + model weights downloaded!")

  adding: sentence_model/ (stored 0%)
  adding: sentence_model/tokenizer_config.json (deflated 73%)
  adding: sentence_model/sentence_bert_config.json (deflated 9%)
  adding: sentence_model/config_sentence_transformers.json (deflated 41%)
  adding: sentence_model/config.json (deflated 47%)
  adding: sentence_model/model.safetensors (deflated 9%)
  adding: sentence_model/tokenizer.json (deflated 71%)
  adding: sentence_model/2_Normalize/ (stored 0%)
  adding: sentence_model/modules.json (deflated 62%)
  adding: sentence_model/special_tokens_map.json (deflated 80%)
  adding: sentence_model/vocab.txt (deflated 53%)
  adding: sentence_model/1_Pooling/ (stored 0%)
  adding: sentence_model/1_Pooling/config.json (deflated 59%)
  adding: sentence_model/README.md (deflated 64%)


✅ All files + model weights downloaded!


In [None]:
from google.colab import files
import pickle
import numpy as np

print("🚀 Starting individual file downloads...")

# 1. Download FAISS Index
print("📊 Downloading FAISS index...")
faiss.write_index(index, "/content/clinical_faiss_index.index")
files.download('/content/clinical_faiss_index.index')

# 2. Download Documents Metadata
print("📝 Downloading documents metadata...")
with open('/content/clinical_documents.pkl', 'wb') as f:
    pickle.dump(documents, f)
files.download('/content/clinical_documents.pkl')

# 3. Download Embeddings
print("🔢 Downloading embeddings...")
np.save('/content/document_embeddings.npy', emb)
files.download('/content/document_embeddings.npy')

# 4. Download Model Weights (Sentence Transformer)
print("🤖 Downloading model weights...")
model.save('/content/sentence_transformer_model')
!cd /content && tar -czf sentence_model_weights.tar.gz sentence_transformer_model/
files.download('/content/sentence_model_weights.tar.gz')

# 5. Download a small run summary (sample queries, document count, embedding dimension)
print("🧪 Downloading test results...")
test_data = {
    'sample_queries': [
        "Patient with sudden weakness and facial droop",
        "Patient with severe headache and vomiting"
    ],
    'document_count': len(documents),
    'embedding_dimension': emb.shape[1]
}

import json
with open('/content/test_summary.json', 'w') as f:
    json.dump(test_data, f, indent=2)
files.download('/content/test_summary.json')

print("✅ All individual files downloaded successfully!")
print("📁 Files downloaded:")
print("   • clinical_faiss_index.index")
print("   • clinical_documents.pkl")
print("   • document_embeddings.npy")
print("   • sentence_model_weights.tar.gz")
print("   • test_summary.json")

🚀 Starting individual file downloads...
📊 Downloading FAISS index...
📝 Downloading documents metadata...
🔢 Downloading embeddings...
🤖 Downloading model weights...
🧪 Downloading test results...
✅ All individual files downloaded successfully!
📁 Files downloaded:
   • clinical_faiss_index.index
   • clinical_documents.pkl
   • document_embeddings.npy
   • sentence_model_weights.tar.gz
   • test_summary.json
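The `!cd /content && tar -czf ...` line relies on notebook shell magic. The standard-library `tarfile` module produces the same kind of `.tar.gz` archive from plain Python, which also works outside Colab. A sketch; `make_tar_gz` is a hypothetical helper and the demo packs a placeholder folder rather than the real model:

```python
import os
import tarfile
import tempfile


def make_tar_gz(src_dir, out_path):
    """Pure-Python analogue of `tar -czf out_path <basename of src_dir>/`; returns member names."""
    with tarfile.open(out_path, "w:gz") as tar:
        tar.add(src_dir, arcname=os.path.basename(src_dir.rstrip(os.sep)))
    with tarfile.open(out_path, "r:gz") as tar:
        return sorted(tar.getnames())


with tempfile.TemporaryDirectory() as tmp:
    model_dir = os.path.join(tmp, "sentence_transformer_model")
    os.makedirs(os.path.join(model_dir, "1_Pooling"))
    with open(os.path.join(model_dir, "1_Pooling", "config.json"), "w") as f:
        f.write("{}")
    names = make_tar_gz(model_dir, os.path.join(tmp, "sentence_model_weights.tar.gz"))

print(names)
```

After extracting the archive elsewhere, `SentenceTransformer("<extracted folder>")` loads the model from the local directory.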


***Gradio User Interface for the Clinical RAG System***

In [22]:
# Install Gradio
!pip install gradio==4.21.0

import gradio as gr
import torch
import numpy as np
import os
import pickle
import json
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt

# =============================================================================
# GRADIO INTERFACE FOR CLINICAL RAG SYSTEM
# =============================================================================

class ClinicalRAGInterface:
    def __init__(self):
        self.model = model
        self.index = index
        self.documents = documents
        self.tokenizer = tokenizer
        self.generator = generator
        self.query_history = []

    def enhanced_retrieve(self, query, top_k=5, min_score=0.4):
        """Enhanced retrieval with better filtering"""
        q = self.model.encode([query], convert_to_numpy=True)
        scores, idx = self.index.search(q, top_k * 5)  # Get more to filter

        out = []
        seen_sources = set()

        for score, i in zip(scores[0], idx[0]):
            if i < 0 or i >= len(self.documents):  # FAISS pads missing hits with -1
                continue

            source = self.documents[i]["source"]

            # Skip duplicates and low scores
            if source in seen_sources or score < min_score:
                continue
            seen_sources.add(source)

            # Category = folder directly under "Finished/" (e.g. COPD)
            parts = source.split('/')
            category = parts[parts.index('Finished') + 1] if 'Finished' in parts else "Unknown"

            # Clean text
            text = self.documents[i]["text"].strip()
            if len(text) > 1000:
                text = text[:1000] + "..."

            out.append({
                "score": float(score),
                "source": source,
                "filename": os.path.basename(source),
                "category": category,
                "text": text
            })

            if len(out) >= top_k:
                break

        return out

    def generate_specific_response(self, query, retrieved_docs, response_type="assessment"):
        """Generate different types of responses based on user needs"""

        # Prepare context
        context_parts = []
        for i, doc in enumerate(retrieved_docs):
            doc_text = doc["text"].strip()
            if len(doc_text) > 600:
                trunc_point = doc_text[:600].rfind('.')
                if trunc_point > 300:
                    doc_text = doc_text[:trunc_point+1]
                else:
                    doc_text = doc_text[:600] + "..."
            context_parts.append(f"[Document {i+1} - {doc['category']}]: {doc_text}")

        context = "\n\n".join(context_parts)

        # Define different prompts based on response type
        prompts = {
            "assessment": f"""Based on the following clinical documents, provide a concise medical assessment.

CLINICAL CONTEXT:
{context}

CLINICAL QUESTION: {query}

MEDICAL ASSESSMENT:""",

            "symptoms": f"""Extract ONLY symptoms mentioned in the clinical documents below.
- List each symptom as a bullet point
- Use only exact terms found in documents
- Do not include diagnosis or treatment

CLINICAL DOCUMENTS:
{context}

Query: {query}

SYMPTOMS (bullet points only):""",

            "treatment": f"""Extract ONLY treatments, medications, and procedures mentioned in the clinical documents below.
- List each treatment as a bullet point
- Include medication names and dosages if available
- Include procedures and interventions

CLINICAL DOCUMENTS:
{context}

Query: {query}

TREATMENTS (bullet points only):""",

            "diagnosis": f"""Extract potential diagnoses and clinical findings from the documents below.
- List each diagnosis possibility
- Include supporting clinical findings
- Mention confidence level if indicated

CLINICAL DOCUMENTS:
{context}

Query: {query}

POTENTIAL DIAGNOSES:""",

            "summary": f"""Create a comprehensive clinical summary from the documents below.
Include:
1. Key symptoms
2. Clinical findings
3. Potential diagnoses
4. Recommended treatments
5. Follow-up recommendations

CLINICAL DOCUMENTS:
{context}

Query: {query}

COMPREHENSIVE CLINICAL SUMMARY:"""
        }

        prompt = prompts.get(response_type, prompts["assessment"])

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.generator.device)

        output = self.generator.generate(
            **inputs,
            max_new_tokens=500,
            temperature=0.3,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.2
        )

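        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = output[0][inputs["input_ids"].shape[1]:]
        full_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)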

        # Extract the response based on prompt type
        markers = {
            "assessment": "MEDICAL ASSESSMENT:",
            "symptoms": "SYMPTOMS (bullet points only):",
            "treatment": "TREATMENTS (bullet points only):",
            "diagnosis": "POTENTIAL DIAGNOSES:",
            "summary": "COMPREHENSIVE CLINICAL SUMMARY:"
        }

        marker = markers.get(response_type, "MEDICAL ASSESSMENT:")
        if marker in full_output:
            response = full_output.split(marker)[-1].strip()
        else:
            response = full_output

        return response

    def process_query(self, query, response_type="assessment", top_k=5):
        """Main processing function for Gradio interface"""

        # Add to history
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.query_history.append({
            "timestamp": timestamp,
            "query": query,
            "response_type": response_type
        })

        # Retrieve documents
        retrieved_docs = self.enhanced_retrieve(query, top_k=top_k)

        # Generate response
        response = self.generate_specific_response(query, retrieved_docs, response_type)

        # Format retrieved documents for display
        docs_info = []
        for i, doc in enumerate(retrieved_docs):
            docs_info.append({
                "Rank": i+1,
                "Relevance": f"{doc['score']:.3f}",
                "Category": doc['category'],
                "Source": doc['filename'],
                "Preview": doc['text'][:200] + "..."
            })

        # Create metrics
        avg_score = np.mean([doc['score'] for doc in retrieved_docs]) if retrieved_docs else 0
        categories = {}
        for doc in retrieved_docs:
            cat = doc['category']
            categories[cat] = categories.get(cat, 0) + 1

        metrics = {
            "Query": query,
            "Response Type": response_type.upper(),
            "Documents Retrieved": len(retrieved_docs),
            "Average Relevance": f"{avg_score:.3f}",
            "Categories Retrieved": ", ".join([f"{k} ({v})" for k, v in categories.items()]),
            "Processing Time": timestamp
        }

        return response, docs_info, metrics

    def generate_report(self, query, response_type):
        """Generate a formatted clinical report"""
        response, docs_info, metrics = self.process_query(query, response_type)

        report = f"""
{'='*60}
🏥 CLINICAL RAG SYSTEM REPORT
{'='*60}

📋 QUERY: {query}
🎯 RESPONSE TYPE: {response_type.upper()}
⏰ TIMESTAMP: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

📊 METRICS:
{'─'*40}
• Documents Retrieved: {metrics['Documents Retrieved']}
• Average Relevance: {metrics['Average Relevance']}
• Categories: {metrics['Categories Retrieved']}
{'─'*40}

💡 CLINICAL RESPONSE:
{'─'*40}
{response}
{'─'*40}

📚 SOURCE DOCUMENTS:
{'─'*40}
"""

        for doc in docs_info:
            report += f"\n[{doc['Rank']}] ⭐{doc['Relevance']} | 🏷️ {doc['Category']}\n"
            report += f"   📁 {doc['Source']}\n"
            report += f"   📝 {doc['Preview']}\n"

        report += f"\n{'='*60}\n"

        return report

    def get_query_stats(self):
        """Get statistics about queries"""
        if not self.query_history:
            return "No queries processed yet."

        df = pd.DataFrame(self.query_history)
        stats = f"""
📈 QUERY STATISTICS
{'─'*40}
• Total Queries: {len(df)}
• Last Query: {df.iloc[-1]['timestamp']}
• Response Types: {df['response_type'].value_counts().to_dict()}
{'─'*40}
Recent Queries:
"""
        for i, row in df.tail(5).iterrows():
            stats += f"\n[{row['timestamp']}] {row['response_type'].upper()}: {row['query'][:50]}..."

        return stats

# =============================================================================
# INITIALIZE THE INTERFACE
# =============================================================================

rag_interface = ClinicalRAGInterface()

# =============================================================================
# GRADIO UI COMPONENTS
# =============================================================================

def create_gradio_interface():
    """Create the complete Gradio interface"""

    # Custom CSS for better styling
    css = """
    .gradio-container {
        max-width: 1200px !important;
    }
    .clinical-input {
        font-size: 16px !important;
        padding: 12px !important;
    }
    .output-box {
        border-radius: 10px;
        padding: 15px;
        background: #f5f7fa;
        border: 1px solid #e0e0e0;
    }
    .metric-box {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 15px;
        border-radius: 10px;
        margin: 10px 0;
    }
    .doc-box {
        background: #ffffff;
        border: 1px solid #e0e0e0;
        border-radius: 8px;
        padding: 10px;
        margin: 5px 0;
    }
    """

    # Theme
    theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="gray",
    ).set(
        body_background_fill="#f9fafb",
        button_primary_background_fill="#3b82f6",
        button_primary_background_fill_hover="#2563eb",
    )

    # Examples
    examples = [
        ["Patient with sudden headache, nausea, and vomiting", "assessment"],
        ["What are common symptoms of migraine?", "symptoms"],
        ["Treatment options for hypertension", "treatment"],
        ["Differential diagnosis for chest pain", "diagnosis"],
        ["Summary of stroke management guidelines", "summary"],
        ["Patient with fever and cough for 3 days", "assessment"],
        ["Medications for diabetes management", "treatment"],
        ["Symptoms of pneumonia in elderly patients", "symptoms"]
    ]

    # Main processing function
    def process_clinical_query(query, response_type, top_k_slider, show_details):
        """Process user query and return results"""

        response, docs_info, metrics = rag_interface.process_query(
            query, response_type, top_k_slider
        )

        # Format output based on user preference
        if show_details:
            # Detailed output with metrics
            output_text = f"""
📋 **CLINICAL QUERY:** {query}
🎯 **RESPONSE TYPE:** {response_type.upper()}
📊 **METRICS:**
   • Documents Retrieved: {metrics['Documents Retrieved']}
   • Average Relevance: {metrics['Average Relevance']}
   • Categories: {metrics['Categories Retrieved']}

💡 **RESPONSE:**
{response}

📚 **SOURCE DOCUMENTS:**
"""
            for doc in docs_info:
                output_text += f"""
🔍 **Document {doc['Rank']}** (⭐{doc['Relevance']} | 🏷️ {doc['Category']})
📁 {doc['Source']}
📝 {doc['Preview']}
"""
        else:
            # Simple output
            output_text = response

        # Create metrics display
        metrics_display = {
            "Query": metrics["Query"],
            "Response Type": metrics["Response Type"],
            "Documents Retrieved": metrics["Documents Retrieved"],
            "Average Relevance": metrics["Average Relevance"],
            "Processing Time": metrics["Processing Time"]
        }

        # Create documents dataframe for display
        if docs_info:
            docs_df = pd.DataFrame(docs_info)
        else:
            docs_df = pd.DataFrame({"Message": ["No documents retrieved"]})

        return output_text, metrics_display, docs_df

    # Generate report function
    def generate_full_report(query, response_type):
        report = rag_interface.generate_report(query, response_type)
        return report

    # Get statistics function
    def show_statistics():
        return rag_interface.get_query_stats()

    # Clear history function
    def clear_history():
        rag_interface.query_history = []
        return "Query history cleared!"

    # Demo function
    def run_demo():
        """Run a demo query"""
        demo_response, _, _ = process_clinical_query(
            "Patient with chest pain and shortness of breath",
            "assessment",
            5,
            True
        )
        return demo_response

    # Create the Gradio interface
    with gr.Blocks(theme=theme, css=css, title="🏥 Clinical RAG Assistant") as demo:
        gr.Markdown("""
        # 🏥 Clinical RAG Assistant
        ### Diagnostic Reasoning for Clinical Notes (DiReCT)
        *Powered by MIMIC-IV-Ext Direct Dataset & Qwen2.5-1.5B-Instruct*
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # Query Input Section
                gr.Markdown("## 📝 Clinical Query Input")

                query_input = gr.Textbox(
                    label="Enter your clinical query:",
                    placeholder="e.g., Patient with headache, nausea, and photophobia...",
                    lines=3,
                    elem_classes="clinical-input"
                )

                with gr.Row():
                    response_type = gr.Dropdown(
                        choices=["assessment", "symptoms", "treatment", "diagnosis", "summary"],
                        value="assessment",
                        label="Response Type",
                        info="What type of information do you need?"
                    )

                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of documents to retrieve",
                        info="More documents = more context, but slower"
                    )

                show_details = gr.Checkbox(
                    label="Show detailed analysis",
                    value=True,
                    info="Include source documents and metrics"
                )

                with gr.Row():
                    submit_btn = gr.Button("🔍 Analyze Query", variant="primary", size="lg")
                    demo_btn = gr.Button("🧪 Run Demo", variant="secondary")
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")

            with gr.Column(scale=1):
                # Quick Examples Section
                gr.Markdown("## 💡 Quick Examples")
                gr.Examples(
                    examples=examples,
                    inputs=[query_input, response_type],
                    label="Click to try examples"
                )

                # Stats Section
                gr.Markdown("## 📊 System Statistics")
                stats_output = gr.Textbox(label="Query Statistics", interactive=False)
                stats_btn = gr.Button("📈 Update Statistics")

        # Output Section
        gr.Markdown("## üìã Clinical Analysis Results")

        with gr.Tabs():
            with gr.TabItem("💡 Primary Response"):
                output_text = gr.Textbox(
                    label="Clinical Analysis",
                    lines=15,
                    elem_classes="output-box"
                )

            with gr.TabItem("📊 Metrics"):
                metrics_json = gr.JSON(
                    label="Retrieval Metrics",
                    elem_classes="output-box"
                )

            with gr.TabItem("📚 Source Documents"):
                docs_table = gr.Dataframe(
                    label="Retrieved Clinical Documents",
                    headers=["Rank", "Relevance", "Category", "Source", "Preview"],
                    elem_classes="output-box"
                )

            with gr.TabItem("📄 Full Report"):
                report_output = gr.Textbox(
                    label="Comprehensive Clinical Report",
                    lines=20,
                    elem_classes="output-box"
                )
                report_btn = gr.Button("📋 Generate Full Report", variant="primary")

        # Control Buttons
        with gr.Row():
            download_btn = gr.Button("📥 Download Report")
            reset_btn = gr.Button("🔄 Reset All")

        # Event Handlers
        submit_btn.click(
            fn=process_clinical_query,
            inputs=[query_input, response_type, top_k_slider, show_details],
            outputs=[output_text, metrics_json, docs_table]
        )

        demo_btn.click(
            fn=run_demo,
            outputs=output_text
        )

        clear_btn.click(
            fn=lambda: ["", "", {}, pd.DataFrame()],
            outputs=[query_input, output_text, metrics_json, docs_table]
        )

        report_btn.click(
            fn=generate_full_report,
            inputs=[query_input, response_type],
            outputs=report_output
        )

        stats_btn.click(
            fn=show_statistics,
            outputs=stats_output
        )

        reset_btn.click(
            fn=lambda: ["", "assessment", 5, True, "", {}, pd.DataFrame(), "", ""],
            outputs=[query_input, response_type, top_k_slider, show_details,
                    output_text, metrics_json, docs_table, report_output, stats_output]
        )

        # Footer
        # f-string so the indexed-document count is actually interpolated
        gr.Markdown(f"""
        ---
        ### 🔒 Ethical & Privacy Notice
        *This system uses de-identified clinical data from the MIMIC-IV-Ext Direct dataset.*
        *Responses are AI-generated and should be verified by healthcare professionals.*
        *Not for actual clinical decision-making without proper validation.*

        ### 🛠️ System Information
        - **Embedding Model:** all-MiniLM-L6-v2
        - **Generator Model:** Qwen2.5-1.5B-Instruct
        - **Vector Database:** FAISS
        - **Documents Indexed:** {len(documents)}
        """)

        return demo

# =============================================================================
# LAUNCH THE GRADIO INTERFACE
# =============================================================================

print("🚀 Creating Gradio interface...")
demo = create_gradio_interface()

# Launch in Colab
print("🎯 Launching Gradio interface...")
demo.launch(
    share=True,  # Creates public link
    debug=False,
    server_name="0.0.0.0",
    server_port=7860
)

# Alternative: Launch with localtunnel for better Colab access
print("\n📱 Alternative access methods:")
print("1. Click the gradio.app link above")
print("2. Or use this Colab localtunnel command:")
print("   !npx localtunnel --port 7860")

🚀 Creating Gradio interface...
🎯 Launching Gradio interface...
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
IMPORTANT: You are using gradio version 4.21.0, however version 4.44.1 is available, please upgrade.
--------
Running on public URL: https://892956093ecf4307bf.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)



📱 Alternative access methods:
1. Click the gradio.app link above
2. Or use this Colab localtunnel command:
   !npx localtunnel --port 7860
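The filtering step inside `enhanced_retrieve` (skip duplicate sources and hits below `min_score`) can be checked without loading FAISS or the embedding model. The sketch below is a minimal standalone version under two assumptions: higher scores mean more similar (as the `min_score=0.4` cut-off implies, e.g. inner product on normalized embeddings), and hits arrive best-first, as `index.search` returns them. It also skips the `-1` ids FAISS uses to pad results when fewer than k neighbours are found, and reads the category from the folder directly under `Finished/`. `filter_hits` and `category_from_path` are hypothetical names, and text truncation is left out.

```python
import os
from typing import Dict, List, Sequence


def category_from_path(source: str) -> str:
    """Return the folder directly under 'Finished/', or 'Unknown'."""
    parts = source.split('/')
    # Searching parts[:-2] ensures the folder after 'Finished' is not the filename
    if 'Finished' in parts[:-2]:
        return parts[parts.index('Finished') + 1]
    return "Unknown"


def filter_hits(scores: Sequence[float], ids: Sequence[int],
                documents: List[Dict[str, str]], top_k: int = 5,
                min_score: float = 0.4) -> List[Dict[str, object]]:
    """Keep the first (best) hit per source file, dropping weak and padded hits."""
    out: List[Dict[str, object]] = []
    seen_sources = set()
    for score, i in zip(scores, ids):
        # FAISS pads missing results with id -1; without the i < 0 check,
        # documents[-1] would silently return the last document
        if i < 0 or i >= len(documents):
            continue
        source = documents[i]["source"]
        if source in seen_sources or score < min_score:
            continue
        seen_sources.add(source)
        out.append({
            "score": float(score),
            "source": source,
            "filename": os.path.basename(source),
            "category": category_from_path(source),
        })
        if len(out) >= top_k:
            break
    return out


# Two entries from the same note, one weak hit and one padded slot
docs = [
    {"source": "/content/mimic_ext//mimic-iv-ext-direct-1.0.0/Finished/COPD/18591903-DS-16.json"},
    {"source": "/content/mimic_ext//mimic-iv-ext-direct-1.0.0/Finished/COPD/18591903-DS-16.json"},
    {"source": "/content/mimic_ext//mimic-iv-ext-direct-1.0.0/Finished/COPD/15166831-DS-16.json"},
]
hits = filter_hits([0.82, 0.80, 0.35, 0.0], [0, 1, 2, -1], docs)
print(hits)  # one hit: 18591903-DS-16.json, category "COPD"
```

If the index was built with a distance metric instead (e.g. L2, where smaller is better), the `score < min_score` test would need to be inverted.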

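Both generators build their prompt context the same way: each document is cut to a character budget, preferably at the last full stop (600/300 characters in `generate_specific_response`; 800/400 and only the first three documents in `generate_bullet_response`). A hypothetical shared helper, sketched here for the bullet-point variant (the assessment variant additionally labels each block with its category):

```python
from typing import Dict, List


def truncate_at_sentence(text: str, max_chars: int = 600, min_keep: int = 300) -> str:
    """Trim text to max_chars, cutting after the last '.' when that keeps enough text."""
    text = text.strip()
    if len(text) <= max_chars:
        return text
    cut = text[:max_chars].rfind('.')
    # Use the sentence boundary only if it retains more than min_keep characters
    if cut > min_keep:
        return text[:cut + 1]
    return text[:max_chars] + "..."


def build_context(docs: List[Dict[str, str]], max_docs: int = 3,
                  max_chars: int = 800, min_keep: int = 400) -> str:
    """Join the first max_docs documents as '[Document n]: ...' blocks."""
    return "\n\n".join(
        f"[Document {n}]: {truncate_at_sentence(doc['text'], max_chars, min_keep)}"
        for n, doc in enumerate(docs[:max_docs], start=1)
    )
```

Keeping the budget in one helper makes it easier to reason about how much of the tokenizer's `max_length=2048` window the context consumes.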

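Both interfaces store the query *timestamp* under the `Processing Time` key. If elapsed time is what that metric is meant to show, one option is to time the call with `time.perf_counter()`, a high-resolution clock intended for measuring durations. `timed` is a hypothetical helper, not part of the notebook:

```python
import time
from typing import Any, Callable, Tuple


def timed(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, float]:
    """Call fn(*args, **kwargs) and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start
```

For example, `(response, docs_info, metrics), elapsed = timed(rag_interface.process_query, query, "symptoms", 5)` followed by `metrics["Processing Time"] = f"{elapsed:.2f} s"` would report a duration instead of a timestamp (illustrative wiring only).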
In [24]:
# Install Gradio
!pip install gradio==4.21.0

import gradio as gr
import torch
import numpy as np
import os
import pickle
import json
from datetime import datetime
import pandas as pd
import re

# =============================================================================
# ENHANCED GRADIO INTERFACE FOR CLINICAL RAG SYSTEM
# =============================================================================

class EnhancedClinicalRAGInterface:
    def __init__(self):
        self.model = model
        self.index = index
        self.documents = documents
        self.tokenizer = tokenizer
        self.generator = generator
        self.query_history = []

    def enhanced_retrieve(self, query, top_k=5, min_score=0.4):
        """Enhanced retrieval with better filtering"""
        q = self.model.encode([query], convert_to_numpy=True)
        scores, idx = self.index.search(q, top_k * 5)  # Get more to filter

        out = []
        seen_sources = set()

        for score, i in zip(scores[0], idx[0]):
            # FAISS pads missing results with -1, which would otherwise index documents[-1]
            if i < 0 or i >= len(self.documents):
                continue

            source = self.documents[i]["source"]

            # Skip duplicates and low scores
            if source in seen_sources or score < min_score:
                continue
            seen_sources.add(source)

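            # Extract the disease category: the folder directly under "Finished/"
            # (e.g. ".../Finished/COPD/18591903-DS-16.json" -> "COPD")
            parts = source.split('/')
            if 'Finished' in parts[:-2]:
                category = parts[parts.index('Finished') + 1]
            else:
                category = "Unknown"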

            # Clean text
            text = self.documents[i]["text"].strip()
            if len(text) > 1000:
                text = text[:1000] + "..."

            out.append({
                "score": float(score),
                "source": source,
                "filename": os.path.basename(source),
                "category": category,
                "text": text
            })

            if len(out) >= top_k:
                break

        return out

    def format_as_bullets(self, text, response_type):
        """Format text as bullet points based on response type"""
        # Clean the text
        text = text.strip()

        # Prefer the model's own line structure, stripping any existing bullet
        # ("•", "-", "*") or numbering ("1.", "2)") so markers are not doubled
        lines = [re.sub(r'^\s*(?:[•\-\*]\s*|\d+[.)]\s+)', '', line).strip()
                 for line in text.split('\n')]
        items = [line for line in lines if line]

        # A single block of prose: split on sentence-ending punctuation followed by
        # whitespace, so decimals such as "0.5 mg" are not broken apart
        if len(items) == 1:
            items = [s.strip() for s in re.split(r'(?<=[.!?])\s+', items[0]) if s.strip()]

        bullet_points = [f"• {item}" for item in items]

        # Limit to reasonable number
        bullet_points = bullet_points[:15]

        # Add header based on response type
        headers = {
            "symptoms": "ü©∫ **SYMPTOMS FOUND:**",
            "treatment": "üíä **TREATMENTS & MEDICATIONS:**",
            "diagnosis": "üîç **POTENTIAL DIAGNOSES:**",
            "medication": "üíä **MEDICATIONS PRESCRIBED:**",
            "assessment": "üìã **CLINICAL ASSESSMENT:**",
            "summary": "üìÑ **COMPREHENSIVE SUMMARY:**"
        }

        header = headers.get(response_type, "📋 **CLINICAL FINDINGS:**")

        formatted_output = f"{header}\n\n"
        formatted_output += "\n".join(bullet_points)

        return formatted_output

    def generate_bullet_response(self, query, retrieved_docs, response_type="symptoms"):
        """Generate bullet point responses for different query types"""

        # Prepare context
        context_parts = []
        for i, doc in enumerate(retrieved_docs[:3]):  # Limit to 3 docs for context
            doc_text = doc["text"].strip()
            if len(doc_text) > 800:
                trunc_point = doc_text[:800].rfind('.')
                if trunc_point > 400:
                    doc_text = doc_text[:trunc_point+1]
                else:
                    doc_text = doc_text[:800] + "..."
            context_parts.append(f"[Document {i+1}]: {doc_text}")

        context = "\n\n".join(context_parts)

        # Enhanced prompts for bullet points
        prompts = {
            "symptoms": f"""EXTRACT ALL SYMPTOMS mentioned in these clinical documents.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. List ONLY symptoms found in the documents
2. Use bullet points (•) for each symptom
3. Be specific and include details like severity, location, duration
4. Do NOT include treatments or diagnoses
5. Group similar symptoms together

SYMPTOMS LIST:""",

            "treatment": f"""EXTRACT ALL TREATMENTS, MEDICATIONS, and PROCEDURES mentioned.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. List ALL treatments mentioned
2. Use bullet points (•) for each treatment
3. Include: medication names, dosages, frequencies, routes
4. Include procedures, therapies, interventions
5. Do NOT include symptoms or diagnoses

TREATMENTS LIST:""",

            "diagnosis": f"""EXTRACT ALL POTENTIAL DIAGNOSES and CLINICAL FINDINGS.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. List ALL possible diagnoses mentioned
2. Use bullet points (•) for each diagnosis
3. Include supporting clinical findings
4. Mention differential diagnoses if present
5. Do NOT include treatments or symptoms

DIAGNOSES LIST:""",

            "medication": f"""EXTRACT ONLY MEDICATIONS and PRESCRIPTIONS mentioned.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. List ALL medications mentioned
2. Use bullet points (•) for each medication
3. Format: Drug Name - Dosage - Frequency - Route
4. Include PRN (as needed) medications
5. Do NOT include symptoms or diagnoses

MEDICATIONS LIST:""",

            "assessment": f"""PROVIDE A CLINICAL ASSESSMENT based on the documents.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. Provide a structured assessment
2. Use bullet points (•) for each key point
3. Include: symptoms, findings, likely issues
4. Keep it concise and organized
5. Focus on the query

CLINICAL ASSESSMENT:"""
        }

        prompt = prompts.get(response_type, prompts["assessment"])

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.generator.device)

        output = self.generator.generate(
            **inputs,
            max_new_tokens=600,
            temperature=0.2,  # Lower temperature for more focused responses
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3
        )

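        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = output[0][inputs["input_ids"].shape[1]:]
        full_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)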

        # Extract the response part
        response_markers = [
            "SYMPTOMS LIST:",
            "TREATMENTS LIST:",
            "DIAGNOSES LIST:",
            "MEDICATIONS LIST:",
            "CLINICAL ASSESSMENT:"
        ]

        response_text = full_output
        for marker in response_markers:
            if marker in full_output:
                response_text = full_output.split(marker)[-1].strip()
                break

        # Format as bullet points
        formatted_response = self.format_as_bullets(response_text, response_type)

        return formatted_response

    def process_query(self, query, response_type="symptoms", top_k=5):
        """Main processing function"""

        # Add to history
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.query_history.append({
            "timestamp": timestamp,
            "query": query,
            "response_type": response_type
        })

        # Retrieve documents
        retrieved_docs = self.enhanced_retrieve(query, top_k=top_k)

        # Generate bullet point response
        response = self.generate_bullet_response(query, retrieved_docs, response_type)

        # Format retrieved documents
        docs_info = []
        for i, doc in enumerate(retrieved_docs):
            docs_info.append({
                "Rank": i+1,
                "Relevance": f"{doc['score']:.3f}",
                "Category": doc['category'],
                "Source": doc['filename'],
                "Preview": doc['text'][:150] + "..."
            })

        # Create metrics
        avg_score = np.mean([doc['score'] for doc in retrieved_docs]) if retrieved_docs else 0
        categories = {}
        for doc in retrieved_docs:
            cat = doc['category']
            categories[cat] = categories.get(cat, 0) + 1

        metrics = {
            "Query": query,
            "Response Type": response_type.upper(),
            "Documents Retrieved": len(retrieved_docs),
            "Average Relevance": f"{avg_score:.3f}",
            "Categories": ", ".join([f"{k} ({v})" for k, v in categories.items()]),
            "Response Format": "BULLET POINTS",
            "Processing Time": timestamp
        }

        return response, docs_info, metrics

# =============================================================================
# INITIALIZE THE INTERFACE
# =============================================================================

rag_interface = EnhancedClinicalRAGInterface()

# =============================================================================
# SIMPLIFIED GRADIO INTERFACE WITHOUT EXAMPLES TABLE
# =============================================================================

def create_simple_gradio_interface():
    """Create simplified interface without examples table"""

    # Custom CSS
    css = """
    .gradio-container {
        max-width: 1000px !important;
        margin: auto !important;
    }
    .clinical-input {
        font-size: 16px !important;
        padding: 15px !important;
        border-radius: 10px !important;
    }
    .output-box {
        border-radius: 10px;
        padding: 20px;
        background: #f8f9fa;
        border: 2px solid #e0e0e0;
        font-family: 'Segoe UI', sans-serif;
        line-height: 1.8;
    }
    .bullet-points {
        margin-left: 20px;
    }
    .bullet-points li {
        margin-bottom: 10px;
        padding-left: 10px;
    }
    .type-badge {
        display: inline-block;
        padding: 5px 15px;
        border-radius: 20px;
        font-weight: bold;
        margin: 5px;
        font-size: 14px;
    }
    .symptom-badge { background: #ff6b6b; color: white; }
    .treatment-badge { background: #4ecdc4; color: white; }
    .diagnosis-badge { background: #45b7d1; color: white; }
    .medication-badge { background: #96ceb4; color: white; }
    """

    # Theme
    theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="gray",
    ).set(
        body_background_fill="#f0f2f6",
        button_primary_background_fill="#3b82f6",
        button_primary_background_fill_hover="#2563eb",
    )

    # Response type descriptions
    response_descriptions = {
        "symptoms": "üìã Get bullet-point list of symptoms",
        "treatment": "üíä Get bullet-point list of treatments & medications",
        "diagnosis": "üîç Get bullet-point list of possible diagnoses",
        "medication": "üíä Get bullet-point list of medications only",
        "assessment": "ü©∫ Get bullet-point clinical assessment"
    }

    def process_query_with_bullets(query, response_type, top_k):
        """Process query and return bullet-point results"""

        response, docs_info, metrics = rag_interface.process_query(query, response_type, top_k)

        # Create badge based on response type
        badges = {
            "symptoms": "ü©∫ SYMPTOMS",
            "treatment": "üíä TREATMENTS",
            "diagnosis": "üîç DIAGNOSES",
            "medication": "üíä MEDICATIONS",
            "assessment": "ü©∫ ASSESSMENT"
        }

        badge = badges.get(response_type, "📋 RESULTS")

        # Convert newlines to <br> outside the f-string: a backslash inside an
        # f-string expression is a SyntaxError before Python 3.12
        response_html = response.replace('\n', '<br>')

        # Format the output
        formatted_output = f"""
<div style="font-family: 'Segoe UI', sans-serif; line-height: 1.8;">
    <div style="background: linear-gradient(135deg, #3b82f6, #1d4ed8); color: white; padding: 15px; border-radius: 10px 10px 0 0; margin-bottom: 20px;">
        <h3 style="margin: 0;">{badge} ANALYSIS</h3>
        <p style="margin: 5px 0 0 0; opacity: 0.9;">Query: <strong>{query}</strong></p>
    </div>

    <div style="background: white; padding: 20px; border-radius: 0 0 10px 10px; border: 1px solid #e0e0e0;">
        <div style="margin-bottom: 20px;">
            <strong>📊 System Metrics:</strong><br>
            • Documents Retrieved: {metrics['Documents Retrieved']}<br>
            • Average Relevance: {metrics['Average Relevance']}<br>
            • Response Type: {response_type.upper()}<br>
            • Processing Time: {metrics['Processing Time']}
        </div>

        <div style="background: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #3b82f6;">
            {response_html}
        </div>

        <div style="margin-top: 20px; padding-top: 15px; border-top: 1px solid #e0e0e0;">
            <strong>📚 Source Information:</strong><br>
            Retrieved from {len(docs_info)} clinical documents in categories: {metrics['Categories']}
        </div>
    </div>
</div>
"""

        # Create metrics for JSON
        metrics_display = {
            "query": query,
            "response_type": response_type,
            "documents_retrieved": metrics["Documents Retrieved"],
            "average_relevance": metrics["Average Relevance"],
            "categories": metrics["Categories"],
            "processing_time": metrics["Processing Time"]
        }

        # Create docs dataframe
        if docs_info:
            docs_df = pd.DataFrame(docs_info)
        else:
            docs_df = pd.DataFrame({"Message": ["No documents retrieved"]})

        return formatted_output, metrics_display, docs_df

    # Create the interface
    with gr.Blocks(theme=theme, css=css, title="🏥 Clinical RAG Assistant") as demo:
        gr.Markdown("""
        # 🏥 Clinical RAG Assistant
        ### Get Organized Bullet-Point Medical Information
        *Ask about symptoms, treatments, diagnoses, medications - get clear bullet-point responses*
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # Query Input
                gr.Markdown("### 📝 Enter Medical Query")
                query_input = gr.Textbox(
                    label="",
                    placeholder="Examples: 'symptoms of pneumonia', 'treatments for hypertension', 'medications for diabetes'...",
                    lines=3,
                    elem_classes="clinical-input"
                )

                # Response Type Selection
                gr.Markdown("### 🎯 Select Information Type")
                response_type = gr.Radio(
                    choices=list(response_descriptions.keys()),
                    value="symptoms",
                    label="",
                    info="Choose what type of information you need",
                    elem_id="response-type-radio"
                )

                # Display descriptions
                response_desc = gr.Markdown(
                    value=response_descriptions["symptoms"],
                    elem_id="response-desc"
                )

                # Update description when radio changes
                def update_desc(response_type):
                    return response_descriptions.get(response_type, "")

                response_type.change(
                    fn=update_desc,
                    inputs=response_type,
                    outputs=response_desc
                )

                # Settings
                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=2,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of clinical documents to analyze",
                        info="More documents = more comprehensive results"
                    )

                # Action Button
                submit_btn = gr.Button(
                    "🔍 Generate Bullet-Point Analysis",
                    variant="primary",
                    size="lg",
                    scale=1
                )

                # Clear Button
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary", scale=0)

            with gr.Column(scale=1):
                # Quick Query Suggestions
                gr.Markdown("### 💡 Quick Suggestions")

                # Quick query buttons
                with gr.Column():
                    symptoms_btn = gr.Button("🩺 Symptoms of Pneumonia", size="sm")
                    treatment_btn = gr.Button("💊 Treatments for Hypertension", size="sm")
                    diagnosis_btn = gr.Button("🔍 Diagnoses for Chest Pain", size="sm")
                    medication_btn = gr.Button("💊 Medications for Diabetes", size="sm")

                # System Info
                gr.Markdown("### 🛠️ System Information")
                gr.Markdown(f"""
                - **Documents Indexed:** {len(documents)}
                - **Embedding Model:** all-MiniLM-L6-v2
                - **LLM Model:** Qwen2.5-1.5B-Instruct
                - **Response Format:** Bullet Points
                """)

        # Output Section with Tabs
        gr.Markdown("### 📋 Analysis Results")

        with gr.Tabs():
            with gr.TabItem("💡 Clinical Analysis"):
                output_html = gr.HTML(
                    label="",
                    value="<div style='text-align: center; padding: 40px; color: #666;'>Enter a query above and click 'Generate Bullet-Point Analysis'</div>"
                )

            with gr.TabItem("📊 System Metrics"):
                metrics_json = gr.JSON(
                    label="Retrieval Metrics",
                    value={}
                )

            with gr.TabItem("📚 Source Documents"):
                docs_table = gr.Dataframe(
                    label="Retrieved Clinical Documents",
                    headers=["Rank", "Relevance", "Category", "Source", "Preview"],
                    value=pd.DataFrame({"Message": ["No analysis performed yet"]})
                )

        # Quick query button actions
        def set_quick_query(query, r_type):
            return query, r_type, 5

        symptoms_btn.click(
            fn=lambda: set_quick_query("What are the symptoms of pneumonia?", "symptoms"),
            outputs=[query_input, response_type, top_k_slider]
        )

        treatment_btn.click(
            fn=lambda: set_quick_query("What treatments are available for hypertension?", "treatment"),
            outputs=[query_input, response_type, top_k_slider]
        )

        diagnosis_btn.click(
            fn=lambda: set_quick_query("What are possible diagnoses for chest pain?", "diagnosis"),
            outputs=[query_input, response_type, top_k_slider]
        )

        medication_btn.click(
            fn=lambda: set_quick_query("What medications are used for diabetes?", "medication"),
            outputs=[query_input, response_type, top_k_slider]
        )

        # Main submit action
        submit_btn.click(
            fn=process_query_with_bullets,
            inputs=[query_input, response_type, top_k_slider],
            outputs=[output_html, metrics_json, docs_table]
        )

        # Clear action
        clear_btn.click(
            fn=lambda: ["", "symptoms", 5,
                       "<div style='text-align: center; padding: 40px; color: #666;'>Enter a query above and click 'Generate Bullet-Point Analysis'</div>",
                       {}, pd.DataFrame({"Message": ["No analysis performed yet"]})],
            outputs=[query_input, response_type, top_k_slider, output_html, metrics_json, docs_table]
        )

        # Footer
        gr.Markdown("""
        ---
        ### 🔒 Important Notice
        *This system provides AI-generated bullet-point summaries from clinical documents.*
        *For actual medical decisions, consult healthcare professionals.*
        *All data is de-identified for privacy protection.*

        **Response Types:**
        - **🩺 Symptoms**: Bullet-point list of symptoms with details
        - **💊 Treatments**: Bullet-point list of treatments & medications
        - **🔍 Diagnosis**: Bullet-point list of possible diagnoses
        - **💊 Medications**: Bullet-point list of medications only
        - **🩺 Assessment**: Bullet-point clinical assessment
        """)

    return demo

# =============================================================================
# LAUNCH THE INTERFACE
# =============================================================================

print("🚀 Creating Clinical RAG Interface with Bullet-Point Results...")
demo = create_simple_gradio_interface()

# Kill any existing gradio processes
import subprocess
import time

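print("🔄 Checking for existing processes...")
# Close Gradio apps already launched in this kernel; pkill below only reaches
# separate processes whose command line contains "gradio"
gr.close_all()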
subprocess.run(["pkill", "-f", "gradio"], capture_output=True)
subprocess.run(["pkill", "-f", "uvicorn"], capture_output=True)
time.sleep(2)

# Try different ports, remembering whether any launch succeeded
ports_to_try = [7860, 7861, 7862, 7863]
launched = False

for port in ports_to_try:
    print(f"\n🎯 Attempting to launch on port {port}...")
    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=port,
            share=True,
            quiet=False,
            debug=False
        )
        launched = True
        print(f"✅ Success! Interface running on port {port}")
        print("📱 Check the gradio.app link above")
        break
    except Exception as e:
        print(f"⚠️ Port {port} failed: {str(e)[:100]}...")
        continue

# If all ports fail, try with ngrok (recent ngrok versions may require an auth
# token, set via ngrok.set_auth_token)
if not launched:
    print("\n🔄 Trying ngrok tunnel...")
    !pip install pyngrok -q

    from pyngrok import ngrok

    # Kill existing ngrok
    ngrok.kill()

    # Create tunnel; connect() returns an NgrokTunnel whose .public_url is the URL string
    public_url = ngrok.connect(addr="7865", proto="http").public_url
    print(f"🌐 Public URL: {public_url}")

    # Launch on the tunnelled local port
    demo.launch(
        server_name="0.0.0.0",
        server_port=7865,
        share=False,
        quiet=True
    )

    print("\n✅ Interface ready!")
    print(f"📱 Open: {public_url}")

print("\n🎉 Ready to use! Enter medical queries and get organized bullet-point responses.")

🚀 Creating Clinical RAG Interface with Bullet-Point Results...
🔄 Checking for existing processes...
IMPORTANT: You are using gradio version 4.21.0, however version 4.44.1 is available, please upgrade.
--------

🎯 Attempting to launch on port 7860...
⚠️ Port 7860 failed: Cannot find empty port in range: 7860-7860. You can specify a different port by setting the GRADIO_S...

🎯 Attempting to launch on port 7861...
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://b03f27be833f84996f.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


✅ Success! Interface running on port 7861
📱 Check the gradio.app link above

🎉 Ready to use! Enter medical queries and get organized bullet-point responses.

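The retrieval step in the next cell over-fetches `top_k * 5` neighbours and then filters them. That filtering logic can be tried in isolation with the minimal sketch below. It does not call FAISS; it assumes search results shaped `(1, k)` like the output of a FAISS `index.search`, that higher scores mean closer matches (an inner-product index), and that absent neighbours are padded with id `-1`. The function `filter_hits` and the toy data are hypothetical.

```python
import numpy as np


def filter_hits(scores, ids, sources, top_k=5, min_score=0.4):
    """Keep the best hit per source file, dropping padded ids and weak matches.

    scores, ids: arrays shaped (1, k), as returned by a FAISS-style search.
    sources: list mapping a row id to its source file path.
    """
    kept = []
    seen = set()
    for score, i in zip(scores[0], ids[0]):
        # Padded results use id -1, so guard both ends of the valid range
        if i < 0 or i >= len(sources):
            continue
        # Skip weak matches and further entries from a file already kept
        if score < min_score or sources[i] in seen:
            continue
        seen.add(sources[i])
        kept.append((int(i), float(score)))
        if len(kept) >= top_k:
            break
    return kept


# Toy search result: ids 0 and 1 share a file, id 3 is weak, -1 is padding
sources = ["COPD/a.json", "COPD/a.json", "Stroke/b.json", "Asthma/c.json"]
scores = np.array([[0.91, 0.88, 0.75, 0.30, -1.0]], dtype=np.float32)
ids = np.array([[0, 1, 2, 3, -1]])

print(filter_hits(scores, ids, sources, top_k=5))
```

Keeping only the first entry per file means the `top_k` slots go to distinct files rather than several entries from the same one.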

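Splitting model output into sentences on `[.!?]+` also breaks at decimal points, which matters for medication dosages. The snippet below contrasts that pattern with one that only breaks after punctuation followed by whitespace, or at a newline. The sample text is made up for illustration.

```python
import re

text = "Metoprolol 2.5 mg daily. Aspirin 81 mg given!\n• Nitroglycerin 0.4 mg SL PRN"

# Naive split: breaks on every period, including decimal points inside dosages
naive = [s.strip() for s in re.split(r'[.!?]+', text) if s.strip()]

# Boundary-aware split: break only after ., ! or ? followed by whitespace,
# or at a newline, so "2.5 mg" stays intact
safer = [s.strip() for s in re.split(r'(?<=[.!?])\s+|\n+', text) if s.strip()]

print(naive)
print(safer)
```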
In [30]:
# Install Gradio
!pip install gradio==4.21.0

import gradio as gr
import torch
import numpy as np
import os
import pickle
import json
from datetime import datetime
import pandas as pd
import re

# =============================================================================
# ENHANCED GRADIO INTERFACE FOR CLINICAL RAG SYSTEM
# =============================================================================

class EnhancedClinicalRAGInterface:
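    def __init__(self):
        # Reuse the embedding model, FAISS index, document list, tokenizer and LLM
        # defined as globals in earlier cells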
        self.model = model
        self.index = index
        self.documents = documents
        self.tokenizer = tokenizer
        self.generator = generator
        self.query_history = []

    def enhanced_retrieve(self, query, top_k=5, min_score=0.4):
        """Enhanced retrieval with better filtering"""
        q = self.model.encode([query], convert_to_numpy=True)
        scores, idx = self.index.search(q, top_k * 5)  # Get more to filter

        out = []
        seen_sources = set()

        for score, i in zip(scores[0], idx[0]):
            if i < 0 or i >= len(self.documents):  # FAISS pads missing results with -1
                continue

            source = self.documents[i]["source"]

            # Skip duplicates and low scores
            if source in seen_sources or score < min_score:
                continue
            seen_sources.add(source)

            # Extract category from path
            try:
                parts = source.split('/')
                if len(parts) > 5:
                    category = parts[-3]
                else:
                    category = "Unknown"
            except:
                category = "Unknown"

            # Clean text
            text = self.documents[i]["text"].strip()
            if len(text) > 1000:
                text = text[:1000] + "..."

            out.append({
                "score": float(score),
                "source": source,
                "filename": os.path.basename(source),
                "category": category,
                "text": text
            })

            if len(out) >= top_k:
                break

        return out

    def format_as_bullets(self, text, response_type):
        """Format text as bullet points based on response type"""
        # Clean the text
        text = text.strip()

        # Split into sentences: break after ., ! or ? followed by whitespace, or at a
        # newline, so decimal dosages such as "2.5 mg" are not cut in half
        sentences = re.split(r'(?<=[.!?])\s+|\n+', text)

        bullet_points = []
        for sentence in sentences:
            # Drop bullet markers the model already emitted to avoid "• • ..."
            sentence = sentence.strip().lstrip('•- ')
            if len(sentence) > 10:  # Only meaningful sentences
                # Add bullet point
                bullet_points.append(f"• {sentence}")

        # If no bullets created, use the original text with bullets
        if not bullet_points:
            lines = text.split('\n')
            for line in lines:
                line = line.strip().lstrip('•- ')
                if line:
                    bullet_points.append(f"• {line}")

        # Limit to reasonable number
        bullet_points = bullet_points[:15]

        # Add header based on response type
        headers = {
            "symptoms": "🩺 **SYMPTOMS FOUND:**",
            "treatment": "💊 **TREATMENTS & MEDICATIONS:**",
            "diagnosis": "🔍 **POTENTIAL DIAGNOSES:**",
            "medication": "💊 **MEDICATIONS PRESCRIBED:**",
            "assessment": "📋 **CLINICAL ASSESSMENT:**",
            "summary": "📄 **COMPREHENSIVE SUMMARY:**"
        }

        header = headers.get(response_type, "📋 **CLINICAL FINDINGS:**")

        formatted_output = f"{header}\n\n"
        formatted_output += "\n".join(bullet_points)

        return formatted_output, bullet_points

    def extract_numbered_summary(self, bullet_points, response_type):
        """Re-list the first bullet points as a numbered summary"""
        if not bullet_points:
            return ""

        # Clean bullet points
        clean_points = []
        for point in bullet_points:
            # Remove only the leading bullet symbol; hyphens inside the text
            # (e.g. "X-ray", "Drug - Dosage") are kept
            point = point.lstrip('•- ').strip()
            if point and len(point) > 15:
                clean_points.append(point)

        # Take up to the first 7 points for the summary
        summary_points = clean_points[:7]

        # Create numbered summary based on response type
        summary_titles = {
            "symptoms": "📋 **KEY SYMPTOMS SUMMARY:**",
            "treatment": "💊 **KEY TREATMENTS SUMMARY:**",
            "diagnosis": "🔍 **KEY DIAGNOSES SUMMARY:**",
            "medication": "💊 **KEY MEDICATIONS SUMMARY:**",
            "assessment": "📋 **KEY FINDINGS SUMMARY:**"
        }

        title = summary_titles.get(response_type, "📋 **KEY POINTS SUMMARY:**")

        numbered_summary = f"\n\n{title}\n\n"
        for i, point in enumerate(summary_points, 1):
            numbered_summary += f"{i}) {point}\n"

        return numbered_summary

    def generate_bullet_response(self, query, retrieved_docs, response_type="symptoms"):
        """Generate bullet point responses for different query types"""

        # Prepare context
        context_parts = []
        for i, doc in enumerate(retrieved_docs[:3]):  # Limit to 3 docs for context
            doc_text = doc["text"].strip()
            if len(doc_text) > 800:
                trunc_point = doc_text[:800].rfind('.')
                if trunc_point > 400:
                    doc_text = doc_text[:trunc_point+1]
                else:
                    doc_text = doc_text[:800] + "..."
            context_parts.append(f"[Document {i+1}]: {doc_text}")

        context = "\n\n".join(context_parts)

        # Enhanced prompts for bullet points
        prompts = {
            "symptoms": f"""EXTRACT ALL SYMPTOMS mentioned in these clinical documents.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. List ONLY symptoms found in the documents
2. Use bullet points (•) for each symptom
3. Be specific and include details like severity, location, duration
4. Do NOT include treatments or diagnoses
5. Group similar symptoms together
6. At the end, provide a numbered summary of key symptoms

SYMPTOMS LIST:""",

            "treatment": f"""EXTRACT ALL TREATMENTS, MEDICATIONS, and PROCEDURES mentioned.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. List ALL treatments mentioned
2. Use bullet points (•) for each treatment
3. Include: medication names, dosages, frequencies, routes
4. Include procedures, therapies, interventions
5. Do NOT include symptoms or diagnoses
6. At the end, provide a numbered summary of key treatments

TREATMENTS LIST:""",

            "diagnosis": f"""EXTRACT ALL POTENTIAL DIAGNOSES and CLINICAL FINDINGS.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. List ALL possible diagnoses mentioned
2. Use bullet points (•) for each diagnosis
3. Include supporting clinical findings
4. Mention differential diagnoses if present
5. Do NOT include treatments or symptoms
6. At the end, provide a numbered summary of key diagnoses

DIAGNOSES LIST:""",

            "medication": f"""EXTRACT ONLY MEDICATIONS and PRESCRIPTIONS mentioned.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. List ALL medications mentioned
2. Use bullet points (•) for each medication
3. Format: Drug Name - Dosage - Frequency - Route
4. Include PRN (as needed) medications
5. Do NOT include symptoms or diagnoses
6. At the end, provide a numbered summary of key medications

MEDICATIONS LIST:""",

            "assessment": f"""PROVIDE A CLINICAL ASSESSMENT based on the documents.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. Provide a structured assessment
2. Use bullet points (•) for each key point
3. Include: symptoms, findings, likely issues
4. Keep it concise and organized
5. Focus on the query
6. At the end, provide a numbered summary of key findings

CLINICAL ASSESSMENT:"""
        }

        prompt = prompts.get(response_type, prompts["assessment"])

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.generator.device)

        output = self.generator.generate(
            **inputs,
            max_new_tokens=700,  # Increased for numbered summary
            temperature=0.2,  # Lower temperature for more focused responses
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3
        )

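        # Decode only the newly generated tokens so the prompt is not echoed back
        prompt_length = inputs["input_ids"].shape[1]
        full_output = self.tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)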

        # Extract the response part
        response_markers = [
            "SYMPTOMS LIST:",
            "TREATMENTS LIST:",
            "DIAGNOSES LIST:",
            "MEDICATIONS LIST:",
            "CLINICAL ASSESSMENT:"
        ]

        response_text = full_output
        for marker in response_markers:
            if marker in full_output:
                response_text = full_output.split(marker)[-1].strip()
                break

        # Format as bullet points and extract bullet list
        formatted_response, bullet_points = self.format_as_bullets(response_text, response_type)

        # Add numbered summary at the end
        numbered_summary = self.extract_numbered_summary(bullet_points, response_type)
        final_response = formatted_response + numbered_summary

        return final_response

    def process_query(self, query, response_type="symptoms", top_k=5):
        """Main processing function"""

        # Add to history
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.query_history.append({
            "timestamp": timestamp,
            "query": query,
            "response_type": response_type
        })

        # Retrieve documents
        retrieved_docs = self.enhanced_retrieve(query, top_k=top_k)

        # Generate bullet point response
        response = self.generate_bullet_response(query, retrieved_docs, response_type)

        # Format retrieved documents
        docs_info = []
        for i, doc in enumerate(retrieved_docs):
            docs_info.append({
                "Rank": i+1,
                "Relevance": f"{doc['score']:.3f}",
                "Category": doc['category'],
                "Source": doc['filename'],
                "Preview": doc['text'][:150] + "..."
            })

        # Create metrics
        avg_score = np.mean([doc['score'] for doc in retrieved_docs]) if retrieved_docs else 0
        categories = {}
        for doc in retrieved_docs:
            cat = doc['category']
            categories[cat] = categories.get(cat, 0) + 1

        metrics = {
            "Query": query,
            "Response Type": response_type.upper(),
            "Documents Retrieved": len(retrieved_docs),
            "Average Relevance": f"{avg_score:.3f}",
            "Categories": ", ".join([f"{k} ({v})" for k, v in categories.items()]),
            "Response Format": "BULLET POINTS + NUMBERED SUMMARY",
            "Processing Time": timestamp
        }

        return response, docs_info, metrics

# =============================================================================
# INITIALIZE THE INTERFACE
# =============================================================================

rag_interface = EnhancedClinicalRAGInterface()

# =============================================================================
# SIMPLIFIED GRADIO INTERFACE WITHOUT EXAMPLES TABLE
# =============================================================================

def create_simple_gradio_interface():
    """Create simplified interface without examples table"""

    # Custom CSS
    css = """
    .gradio-container {
        max-width: 1000px !important;
        margin: auto !important;
    }
    .clinical-input {
        font-size: 16px !important;
        padding: 15px !important;
        border-radius: 10px !important;
    }
    .output-box {
        border-radius: 10px;
        padding: 20px;
        background: #f8f9fa;
        border: 2px solid #e0e0e0;
        font-family: 'Segoe UI', sans-serif;
        line-height: 1.8;
    }
    .bullet-points {
        margin-left: 20px;
    }
    .bullet-points li {
        margin-bottom: 10px;
        padding-left: 10px;
    }
    .numbered-summary {
        margin-top: 20px;
        padding: 15px;
        background: #e8f4f8;
        border-radius: 8px;
        border-left: 4px solid #3498db;
    }
    .type-badge {
        display: inline-block;
        padding: 5px 15px;
        border-radius: 20px;
        font-weight: bold;
        margin: 5px;
        font-size: 14px;
    }
    .symptom-badge { background: #ff6b6b; color: white; }
    .treatment-badge { background: #4ecdc4; color: white; }
    .diagnosis-badge { background: #45b7d1; color: white; }
    .medication-badge { background: #96ceb4; color: white; }
    """

    # Theme
    theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="gray",
    ).set(
        body_background_fill="#f0f2f6",
        button_primary_background_fill="#3b82f6",
        button_primary_background_fill_hover="#2563eb",
    )

    # Response type descriptions
    response_descriptions = {
        "symptoms": "📋 Get bullet-point list of symptoms + numbered summary",
        "treatment": "💊 Get bullet-point list of treatments & medications + numbered summary",
        "diagnosis": "🔍 Get bullet-point list of possible diagnoses + numbered summary",
        "medication": "💊 Get bullet-point list of medications only + numbered summary",
        "assessment": "🩺 Get bullet-point clinical assessment + numbered summary"
    }

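    def process_query_with_bullets(query, response_type, top_k):
        """Process query and return bullet-point results"""

        # Guard against empty submissions, which would otherwise run retrieval on ""
        if not query or not query.strip():
            return ("<div style='text-align: center; padding: 40px; color: #666;'>Please enter a medical query above.</div>",
                    {}, pd.DataFrame({"Message": ["No analysis performed yet"]}))

        response, docs_info, metrics = rag_interface.process_query(query, response_type, top_k)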

        # Create badge based on response type
        badges = {
            "symptoms": "🩺 SYMPTOMS",
            "treatment": "💊 TREATMENTS",
            "diagnosis": "🔍 DIAGNOSES",
            "medication": "💊 MEDICATIONS",
            "assessment": "🩺 ASSESSMENT"
        }

        badge = badges.get(response_type, "📋 RESULTS")

        # gr.HTML does not render markdown, so escape the model text (clinical notes
        # often contain "<" and ">"), convert **bold** to <strong>, and keep line breaks
        import html
        formatted_response = html.escape(response)
        formatted_response = re.sub(r'\*\*(.+?)\*\*', r'<strong>\1</strong>', formatted_response)

        # Wrap the numbered summary (from the last "KEY ... SUMMARY:" title to the end)
        # in the styled .numbered-summary box, closing the div at the end
        summary_starts = list(re.finditer(r'\n\n(?=[^\n]*KEY [A-Z ]+SUMMARY:)', formatted_response))
        if summary_starts:
            cut = summary_starts[-1]
            formatted_response = (formatted_response[:cut.start()]
                                  + "<div class='numbered-summary'>"
                                  + formatted_response[cut.end():]
                                  + "</div>")

        formatted_response = formatted_response.replace('\n', '<br>')

        # Create final HTML output
        formatted_output = f"""
<div style="font-family: 'Segoe UI', sans-serif; line-height: 1.8;">
    <div style="background: linear-gradient(135deg, #3b82f6, #1d4ed8); color: white; padding: 15px; border-radius: 10px 10px 0 0; margin-bottom: 20px;">
        <h3 style="margin: 0;">{badge} ANALYSIS</h3>
        <p style="margin: 5px 0 0 0; opacity: 0.9;">Query: <strong>{query}</strong></p>
    </div>

    <div style="background: white; padding: 20px; border-radius: 0 0 10px 10px; border: 1px solid #e0e0e0;">
        <div style="margin-bottom: 20px; padding: 15px; background: #f8f9fa; border-radius: 8px;">
            <strong>📊 System Metrics:</strong><br>
            • Documents Analyzed: {metrics['Documents Retrieved']}<br>
            • Average Relevance: {metrics['Average Relevance']}<br>
            • Response Type: {response_type.upper()}<br>
            • Format: {metrics['Response Format']}<br>
            • Processed At: {metrics['Processing Time']}
        </div>

        <div style="margin-bottom: 20px;">
            {formatted_response}
        </div>

        <div style="margin-top: 20px; padding-top: 15px; border-top: 1px solid #e0e0e0; font-size: 14px; color: #666;">
            <strong>📚 Source Information:</strong><br>
            Retrieved from {len(docs_info)} clinical documents | Categories: {metrics['Categories']}
        </div>
    </div>
</div>
"""

        # Create metrics for JSON
        metrics_display = {
            "query": query,
            "response_type": response_type,
            "documents_retrieved": metrics["Documents Retrieved"],
            "average_relevance": metrics["Average Relevance"],
            "categories": metrics["Categories"],
            "processing_time": metrics["Processing Time"],
            "format": metrics["Response Format"]
        }

        # Create docs dataframe
        if docs_info:
            docs_df = pd.DataFrame(docs_info)
        else:
            docs_df = pd.DataFrame({"Message": ["No documents retrieved"]})

        return formatted_output, metrics_display, docs_df

    # Create the interface
    with gr.Blocks(theme=theme, css=css, title="🏥 Clinical RAG Assistant") as demo:
        gr.Markdown("""
        # 🏥 Clinical RAG Assistant
        ### Get Organized Bullet-Point Medical Information with Numbered Summary
        *Ask about symptoms, treatments, diagnoses, medications - get detailed bullet points + numbered summary*
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # Query Input
                gr.Markdown("### 📝 Enter Medical Query")
                query_input = gr.Textbox(
                    label="",
                    placeholder="Examples: 'symptoms of pneumonia', 'treatments for hypertension', 'medications for diabetes'...",
                    lines=3,
                    elem_classes="clinical-input"
                )

                # Response Type Selection
                gr.Markdown("### 🎯 Select Information Type")
                response_type = gr.Radio(
                    choices=list(response_descriptions.keys()),
                    value="symptoms",
                    label="",
                    info="Choose what type of information you need",
                    elem_id="response-type-radio"
                )

                # Display descriptions
                response_desc = gr.Markdown(
                    value=response_descriptions["symptoms"],
                    elem_id="response-desc"
                )

                # Update description when radio changes
                def update_desc(response_type):
                    return response_descriptions.get(response_type, "")

                response_type.change(
                    fn=update_desc,
                    inputs=response_type,
                    outputs=response_desc
                )

                # Settings
                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=2,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of clinical documents to analyze",
                        info="All retrieved documents appear under Source Documents; the top 3 are summarized by the LLM"
                    )

                # Action Button
                submit_btn = gr.Button(
                    "🔍 Generate Analysis with Numbered Summary",
                    variant="primary",
                    size="lg",
                    scale=1
                )

                # Clear Button
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary", scale=0)

            with gr.Column(scale=1):
                # Quick Query Suggestions
                gr.Markdown("### 💡 Quick Suggestions")

                # Quick query buttons
                with gr.Column():
                    symptoms_btn = gr.Button("🩺 Symptoms of Pneumonia", size="sm")
                    treatment_btn = gr.Button("💊 Treatments for Hypertension", size="sm")
                    diagnosis_btn = gr.Button("🔍 Diagnoses for Chest Pain", size="sm")
                    medication_btn = gr.Button("💊 Medications for Diabetes", size="sm")
                    assessment_btn = gr.Button("🩺 Stroke Assessment", size="sm")

                # System Info
                gr.Markdown("### 🛠️ System Information")
                gr.Markdown(f"""
                - **Documents Indexed:** {len(documents)}
                - **Embedding Model:** all-MiniLM-L6-v2
                - **LLM Model:** Qwen2.5-1.5B-Instruct
                - **Response Format:** Bullet Points + Numbered Summary
                """)

        # Output Section with Tabs
        gr.Markdown("### 📋 Analysis Results")

        with gr.Tabs():
            with gr.TabItem("💡 Clinical Analysis"):
                output_html = gr.HTML(
                    label="",
                    value="<div style='text-align: center; padding: 40px; color: #666;'>Enter a query above and click 'Generate Analysis with Numbered Summary'</div>"
                )

            with gr.TabItem("📊 System Metrics"):
                metrics_json = gr.JSON(
                    label="Retrieval Metrics",
                    value={}
                )

            with gr.TabItem("📚 Source Documents"):
                docs_table = gr.Dataframe(
                    label="Retrieved Clinical Documents",
                    headers=["Rank", "Relevance", "Category", "Source", "Preview"],
                    value=pd.DataFrame({"Message": ["No analysis performed yet"]})
                )

        # Quick query button actions
        def set_quick_query(query, r_type):
            return query, r_type, 5

        symptoms_btn.click(
            fn=lambda: set_quick_query("What are the symptoms of pneumonia?", "symptoms"),
            outputs=[query_input, response_type, top_k_slider]
        )

        treatment_btn.click(
            fn=lambda: set_quick_query("What treatments are available for hypertension?", "treatment"),
            outputs=[query_input, response_type, top_k_slider]
        )

        diagnosis_btn.click(
            fn=lambda: set_quick_query("What are possible diagnoses for chest pain?", "diagnosis"),
            outputs=[query_input, response_type, top_k_slider]
        )

        medication_btn.click(
            fn=lambda: set_quick_query("What medications are used for diabetes?", "medication"),
            outputs=[query_input, response_type, top_k_slider]
        )

        assessment_btn.click(
            fn=lambda: set_quick_query("Patient with stroke symptoms assessment", "assessment"),
            outputs=[query_input, response_type, top_k_slider]
        )

        # Main submit action
        submit_btn.click(
            fn=process_query_with_bullets,
            inputs=[query_input, response_type, top_k_slider],
            outputs=[output_html, metrics_json, docs_table]
        )

        # Clear action
        clear_btn.click(
            fn=lambda: ["", "symptoms", 5,
                       "<div style='text-align: center; padding: 40px; color: #666;'>Enter a query above and click 'Generate Analysis with Numbered Summary'</div>",
                       {}, pd.DataFrame({"Message": ["No analysis performed yet"]})],
            outputs=[query_input, response_type, top_k_slider, output_html, metrics_json, docs_table]
        )

        # Footer
        gr.Markdown("""
        ---
        ### 🔒 Important Notice
        *This system provides AI-generated bullet-point summaries from clinical documents.*
        *For actual medical decisions, consult healthcare professionals.*
        *All data is de-identified for privacy protection.*

        **Response Types (with Numbered Summary):**
        - **🩺 Symptoms**: Bullet-point list + numbered summary of key symptoms
        - **💊 Treatments**: Bullet-point list + numbered summary of key treatments
        - **🔍 Diagnosis**: Bullet-point list + numbered summary of key diagnoses
        - **💊 Medications**: Bullet-point list + numbered summary of key medications
        - **🩺 Assessment**: Bullet-point assessment + numbered summary of key findings

        **📝 Every response includes:** Detailed bullet points + Numbered summary at the end
        """)

    return demo

# =============================================================================
# LAUNCH THE INTERFACE
# =============================================================================

print("üöÄ Creating Clinical RAG Interface with Bullet-Point + Numbered Summary...")
demo = create_simple_gradio_interface()

# Kill any existing gradio processes
import subprocess
import time

print("üîÑ Checking for existing processes...")
subprocess.run(["pkill", "-f", "gradio"], capture_output=True)
subprocess.run(["pkill", "-f", "uvicorn"], capture_output=True)
time.sleep(2)

# Try different ports
ports_to_try = [7860, 7861, 7862, 7863]

for port in ports_to_try:
    print(f"\nüéØ Attempting to launch on port {port}...")
    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=port,
            share=True,
            quiet=False,
            debug=False
        )
        print(f"‚úÖ Success! Interface running on port {port}")
        print(f"üì± Check the gradio.app link above")
        break
    except Exception as e:
        print(f"‚ö†Ô∏è Port {port} failed: {str(e)[:100]}...")
        continue

# If all ports fail, try with ngrok
if 'demo' not in locals() or not hasattr(demo, 'server'):
    print("\nüîÑ Trying ngrok tunnel...")
    !pip install pyngrok -q

    from pyngrok import ngrok

    # Kill existing ngrok
    ngrok.kill()

    # Create tunnel
    public_url = ngrok.connect(addr="7865", proto="http")
    print(f"üåê Public URL: {public_url}")

    # Launch on local port
    demo.launch(
        server_name="0.0.0.0",
        server_port=7865,
        share=False,
        quiet=True
    )

    print(f"\n‚úÖ Interface ready!")
    print(f"üì± Open: {public_url}")

print("\nüéâ Ready to use! Enter medical queries and get detailed bullet points + numbered summary.")

🚀 Creating Clinical RAG Interface with Bullet-Point + Numbered Summary...
🔄 Checking for existing processes...
IMPORTANT: You are using gradio version 4.21.0, however version 4.44.1 is available, please upgrade.
--------

🎯 Attempting to launch on port 7860...
⚠️ Port 7860 failed: Cannot find empty port in range: 7860-7860. You can specify a different port by setting the GRADIO_S...

🎯 Attempting to launch on port 7861...
⚠️ Port 7861 failed: Cannot find empty port in range: 7861-7861. You can specify a different port by setting the GRADIO_S...

🎯 Attempting to launch on port 7862...
⚠️ Port 7862 failed: Cannot find empty port in range: 7862-7862. You can specify a different port by setting the GRADIO_S...

🎯 Attempting to launch on port 7863...
⚠️ Port 7863 failed: Cannot find empty port in range: 7863-7863. You can specify a different port by setting the GRADIO_S...

🔄 Trying ngrok tunnel...


ERROR:pyngrok.process.ngrok:t=2025-12-04T23:20:06+0000 lvl=eror msg="failed to reconnect session" obj=tunnels.session err="authentication failed: Usage of ngrok requires a verified account and authtoken.\n\nSign up for an account: https://dashboard.ngrok.com/signup\nInstall your authtoken: https://dashboard.ngrok.com/get-started/your-authtoken\r\n\r\nERR_NGROK_4018\r\n"
ERROR:pyngrok.process.ngrok:t=2025-12-04T23:20:06+0000 lvl=eror msg="session closing" obj=tunnels.session err="authentication failed: Usage of ngrok requires a verified account and authtoken.\n\nSign up for an account: https://dashboard.ngrok.com/signup\nInstall your authtoken: https://dashboard.ngrok.com/get-started/your-authtoken\r\n\r\nERR_NGROK_4018\r\n"
ERROR:pyngrok.process.ngrok:t=2025-12-04T23:20:06+0000 lvl=eror msg="terminating with error" obj=app err="authentication failed: Usage of ngrok requires a verified account and authtoken.\n\nSign up for an account: https://dashboard.ngrok.com/signup\nInstall your aut

PyngrokNgrokError: The ngrok process errored on start: authentication failed: Usage of ngrok requires a verified account and authtoken.\n\nSign up for an account: https://dashboard.ngrok.com/signup\nInstall your authtoken: https://dashboard.ngrok.com/get-started/your-authtoken\r\n\r\nERR_NGROK_4018\r\n.
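The port failures above most likely come from Gradio apps launched earlier in the same kernel: they run as threads of the notebook process and keep their ports bound. As a sketch, the standard-library `socket` module can probe which candidate port is actually free before calling `demo.launch()`. `find_free_port` is a hypothetical helper, not part of Gradio.

```python
import socket
from typing import Iterable, Optional


def find_free_port(candidates: Iterable[int], host: str = "0.0.0.0") -> Optional[int]:
    """Return the first port in `candidates` that can currently be bound on `host`.

    Hypothetical helper, not part of Gradio. It only takes a snapshot: another
    process may still grab the port before demo.launch() binds it.
    """
    for port in candidates:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind((host, port))
            except OSError:
                # Port already bound, e.g. by an app launched earlier in this kernel
                continue
        # The probe socket is closed on leaving the `with` block, freeing the port
        return port
    return None
```

Usage, assuming `demo` is the Blocks object built above: `port = find_free_port(range(7860, 7870))`, then `demo.launch(server_name="0.0.0.0", server_port=port, share=True)` when `port` is not `None`. Probing first avoids the try/except loop, but the race between probe and launch means the exception handling is still worth keeping.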
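The next cell's `enhanced_retrieve` post-filters raw FAISS hits: it drops duplicate sources and low scores and stops at `top_k`. One detail worth handling explicitly is that FAISS fills result slots it cannot populate with id `-1`, which Python indexing would otherwise treat as the last document. A standalone sketch of that filtering step, assuming higher scores mean more similar (e.g. inner product on normalized embeddings) and using plain NumPy arrays in place of `index.search` output; `filter_hits` is a hypothetical name.

```python
from typing import List, Sequence, Tuple

import numpy as np


def filter_hits(scores: np.ndarray, ids: np.ndarray, sources: Sequence[str],
                top_k: int = 5, min_score: float = 0.4) -> List[Tuple[int, float]]:
    """Keep at most top_k (doc_index, score) pairs from one query's FAISS results.

    Assumes higher score = more similar (e.g. inner-product index on normalized
    embeddings); for an L2 index the min_score test would have to be inverted.
    """
    hits: List[Tuple[int, float]] = []
    seen = set()
    for score, i in zip(scores, ids):
        i = int(i)
        # FAISS fills empty result slots with id -1; a bare upper-bound check
        # would let -1 through and silently pick the last document.
        if i < 0 or i >= len(sources):
            continue
        if score < min_score or sources[i] in seen:
            continue
        seen.add(sources[i])
        hits.append((i, float(score)))
        if len(hits) >= top_k:
            break
    return hits
```

In `enhanced_retrieve` this would be called as `filter_hits(scores[0], idx[0], [d["source"] for d in self.documents], top_k, min_score)`, with the caller over-fetching (`top_k * 5`, as the cell already does) so that enough hits survive de-duplication.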

In [26]:
# Install Gradio
!pip install gradio==4.21.0

import gradio as gr
import torch
import numpy as np
import os
import pickle
import json
from datetime import datetime
import pandas as pd
import re

# =============================================================================
# ENHANCED GRADIO INTERFACE FOR CLINICAL RAG SYSTEM
# =============================================================================

class EnhancedClinicalRAGInterface:
    def __init__(self):
        self.model = model
        self.index = index
        self.documents = documents
        self.tokenizer = tokenizer
        self.generator = generator
        self.query_history = []

    def enhanced_retrieve(self, query, top_k=5, min_score=0.4):
        """Enhanced retrieval with better filtering"""
        q = self.model.encode([query], convert_to_numpy=True)
        scores, idx = self.index.search(q, top_k * 5)  # Get more to filter

        out = []
        seen_sources = set()

        for score, i in zip(scores[0], idx[0]):
            if i < 0 or i >= len(self.documents):  # FAISS pads missing hits with -1
                continue

            source = self.documents[i]["source"]

            # Skip duplicates and low scores
            if source in seen_sources or score < min_score:
                continue
            seen_sources.add(source)

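            # Extract category: the folder right after "Finished" in the dataset path.
            # parts[-3] alone returns "Finished" for files stored one level deep
            # (e.g. Finished/COPD/<file>.json), so look up the anchor folder instead.
            parts = [p for p in source.split('/') if p]
            if "Finished" in parts and parts.index("Finished") + 1 < len(parts) - 1:
                category = parts[parts.index("Finished") + 1]
            elif len(parts) > 5:
                category = parts[-3]
            else:
                category = "Unknown"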

            # Clean text
            text = self.documents[i]["text"].strip()
            if len(text) > 1000:
                text = text[:1000] + "..."

            out.append({
                "score": float(score),
                "source": source,
                "filename": os.path.basename(source),
                "category": category,
                "text": text
            })

            if len(out) >= top_k:
                break

        return out

    def extract_unique_items(self, text, response_type):
        """Extract unique symptoms, treatments, diagnoses, medications from text"""
        text = text.lower().strip()

        # Define patterns for different response types
        patterns = {
            "symptoms": [
                r'(?:symptoms?|signs?|presents? with|complains? of|experiencing|has|had|feeling)\s*(?:[a-z\s,]*?)([\w\s]+?(?:pain|ache|discomfort|nausea|vomiting|headache|fever|cough|shortness|weakness|fatigue|dizziness))',
                r'(?:including|such as|like|e\.g\.)\s*([\w\s,]+?(?:pain|headache|nausea|vomiting|fever|cough))',
                r'(\b[\w\s]+?(?:pain|ache|headache|nausea|vomiting|fever|cough|shortness|weakness|fatigue))\b'
            ],
            "treatment": [
                r'(?:treated with|prescribed|medication|drug|therapy|treatment)\s*(?:[a-z\s,]*?)([\w\s]+?(?:mg|g|ml|tablet|capsule|injection|iv|oral|cream|ointment))',
                r'(\b[\w\s]+?(?:mg\s|g\s|ml\s|tablet|capsule|injection|iv|oral))\b',  # '\g' is an invalid regex escape (re.error)
                r'(?:including|such as|like)\s*([\w\s,]+?(?:mg|tablet|capsule|injection))'
            ],
            "medication": [
                r'(\b[\w\s]+?(?:\d+\s*mg|\d+\s*g|\d+\s*ml|tablet|capsule|injection|iv))\b',
                r'(?:prescribed|medication|drug)\s*([\w\s]+?(?:mg|g|ml|tablet|capsule))',
                r'(\b(?:aspirin|ibuprofen|paracetamol|lisinopril|amlodipine|metformin|insulin|warfarin|heparin)\b)'
            ],
            "diagnosis": [
                r'(?:diagnosis|diagnosed with|suspected|rule out|possible|likely)\s*([\w\s]+?(?:itis|osis|opathy|emia|oma|syndrome|disease|disorder))',
                r'(\b[\w\s]+?(?:pneumonia|hypertension|diabetes|stroke|migraine|asthma|copd|arthritis))\b',
                r'(?:including|such as|like|e\.g\.)\s*([\w\s,]+?(?:pneumonia|hypertension|diabetes|stroke))'
            ]
        }

        # Default extraction for assessment/summary
        if response_type not in patterns:
            # Extract key phrases
            sentences = re.split(r'[.!?]+', text)
            key_items = []
            for sentence in sentences[:10]:
                sentence = sentence.strip()
                if len(sentence) > 15:
                    # Extract important phrases (usually 2-4 words)
                    words = sentence.split()
                    if 2 <= len(words) <= 4:
                        phrase = ' '.join(words)
                        if phrase.lower() not in ['the patient', 'based on', 'clinical findings']:
                            key_items.append(phrase)
            return list(dict.fromkeys(key_items))[:8]  # de-duplicate, keep first-seen order

        # Extract items based on response type
        all_items = []
        for pattern in patterns.get(response_type, []):
            matches = re.findall(pattern, text, re.IGNORECASE)
            for match in matches:
                if isinstance(match, tuple):
                    match = match[0]
                item = match.strip()
                if item and len(item) > 3 and item.lower() not in ['the', 'and', 'with', 'for', 'of']:
                    all_items.append(item.title())

        # Remove duplicates and clean
        unique_items = []
        seen = set()
        for item in all_items:
            clean_item = re.sub(r'\s+', ' ', item).strip()
            if clean_item and clean_item not in seen:
                seen.add(clean_item)
                unique_items.append(clean_item)

        return unique_items[:10]

    def format_unique_items_grid(self, items, response_type):
        """Format unique items in a clean grid/row format"""
        if not items:
            return ""

        # Create title based on response type
        titles = {
            "symptoms": "ü©∫ **UNIQUE SYMPTOMS IDENTIFIED:**",
            "treatment": "üíä **UNIQUE TREATMENTS IDENTIFIED:**",
            "diagnosis": "üîç **UNIQUE DIAGNOSES IDENTIFIED:**",
            "medication": "üíä **UNIQUE MEDICATIONS IDENTIFIED:**",
            "assessment": "üìã **KEY FINDINGS IDENTIFIED:**",
            "summary": "üìÑ **KEY POINTS IDENTIFIED:**"
        }

        title = titles.get(response_type, "üìã **KEY ITEMS IDENTIFIED:**")

        # Create HTML grid
        grid_html = f"""
<div style="margin: 15px 0;">
    <div style="font-weight: bold; margin-bottom: 10px; font-size: 16px;">
        {title}
    </div>
    <div style="display: flex; flex-wrap: wrap; gap: 10px; margin-top: 10px;">
"""

        # Different colors for different types (looked up once, outside the loop)
        colors = {
            "symptoms": "#ff6b6b",
            "treatment": "#4ecdc4",
            "diagnosis": "#45b7d1",
            "medication": "#96ceb4",
            "assessment": "#feca57",
            "summary": "#a29bfe"
        }
        color = colors.get(response_type, "#3498db")

        for i, item in enumerate(items, 1):
            grid_html += f"""
        <div style="
            background: {color};
            color: white;
            padding: 8px 15px;
            border-radius: 20px;
            font-size: 14px;
            font-weight: 500;
            display: flex;
            align-items: center;
            gap: 5px;
        ">
            <span style="opacity: 0.8;">{i}.</span> {item}
        </div>
"""

        grid_html += """
    </div>
</div>
"""

        return grid_html

    def format_detailed_bullets(self, text, response_type):
        """Format text as detailed bullet points"""
        # Clean the text
        text = text.strip()

        # Split into sentences
        sentences = re.split(r'[.!?]+', text)

        bullet_points = []
        for sentence in sentences:
            sentence = sentence.strip()
            if len(sentence) > 15:  # Only meaningful sentences
                # Clean up sentence
                sentence = re.sub(r'\s+', ' ', sentence)
                # Add bullet point
                bullet_points.append(f"• {sentence}")

        # If no bullets created, use the original text
        if not bullet_points:
            lines = text.split('\n')
            for line in lines:
                line = line.strip()
                if line:
                    bullet_points.append(f"• {line}")

        # Limit to reasonable number
        bullet_points = bullet_points[:10]

        # Add header based on response type
        headers = {
            "symptoms": "üìã **DETAILED SYMPTOM DESCRIPTION:**",
            "treatment": "üíä **DETAILED TREATMENT INFORMATION:**",
            "diagnosis": "üîç **DETAILED DIAGNOSIS ANALYSIS:**",
            "medication": "üíä **DETAILED MEDICATION INFORMATION:**",
            "assessment": "ü©∫ **DETAILED CLINICAL ASSESSMENT:**"
        }

        header = headers.get(response_type, "üìã **DETAILED ANALYSIS:**")

        formatted_output = f"{header}\n\n"
        formatted_output += "\n".join(bullet_points)

        return formatted_output

    def generate_response(self, query, retrieved_docs, response_type="symptoms"):
        """Generate comprehensive response with unique items grid"""

        # Prepare context
        context_parts = []
        for i, doc in enumerate(retrieved_docs[:3]):  # Limit to 3 docs
            doc_text = doc["text"].strip()
            if len(doc_text) > 800:
                trunc_point = doc_text[:800].rfind('.')
                if trunc_point > 400:
                    doc_text = doc_text[:trunc_point+1]
                else:
                    doc_text = doc_text[:800] + "..."
            context_parts.append(f"[Document {i+1}]: {doc_text}")

        context = "\n\n".join(context_parts)

        # Enhanced prompts
        prompts = {
            "symptoms": f"""ANALYZE these clinical documents and extract ALL symptoms.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. Extract ALL symptoms mentioned
2. Provide detailed description of each symptom
3. Include: severity, location, duration, characteristics
4. Group similar symptoms
5. Focus on accuracy and completeness

SYMPTOMS ANALYSIS:""",

            "treatment": f"""ANALYZE these clinical documents and extract ALL treatments.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. Extract ALL treatments, medications, procedures
2. Include: names, dosages, frequencies, routes
3. Specify medication details clearly
4. Include procedures and interventions
5. Focus on accuracy and completeness

TREATMENTS ANALYSIS:""",

            "diagnosis": f"""ANALYZE these clinical documents for diagnoses.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. Extract ALL possible diagnoses mentioned
2. Include supporting evidence for each
3. Mention confidence levels if indicated
4. Include differential diagnoses
5. Focus on accuracy and completeness

DIAGNOSES ANALYSIS:""",

            "medication": f"""ANALYZE these clinical documents for medications.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. Extract ALL medications mentioned
2. Format: Name - Dosage - Frequency - Route
3. Include PRN medications
4. Specify administration details
5. Focus on accuracy and completeness

MEDICATIONS ANALYSIS:""",

            "assessment": f"""PROVIDE comprehensive clinical assessment.

CLINICAL DOCUMENTS:
{context}

QUERY: {query}

INSTRUCTIONS:
1. Provide thorough clinical assessment
2. Include: symptoms, findings, assessments
3. Structure analysis clearly
4. Focus on the query
5. Be comprehensive yet concise

CLINICAL ASSESSMENT:"""
        }

        prompt = prompts.get(response_type, prompts["assessment"])

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.generator.device)

        output = self.generator.generate(
            **inputs,
            max_new_tokens=800,
            temperature=0.2,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3
        )

        full_output = self.tokenizer.decode(output[0], skip_special_tokens=True)

        # Extract the response
        response_markers = [
            "SYMPTOMS ANALYSIS:",
            "TREATMENTS ANALYSIS:",
            "DIAGNOSES ANALYSIS:",
            "MEDICATIONS ANALYSIS:",
            "CLINICAL ASSESSMENT:"
        ]

        response_text = full_output
        for marker in response_markers:
            if marker in full_output:
                response_text = full_output.split(marker)[-1].strip()
                break

        # Extract unique items
        unique_items = self.extract_unique_items(response_text, response_type)

        # Format detailed bullets
        detailed_bullets = self.format_detailed_bullets(response_text, response_type)

        # Format unique items grid
        items_grid = self.format_unique_items_grid(unique_items, response_type)

        # Combine both
        final_response = items_grid + "\n\n" + detailed_bullets

        return final_response, unique_items

    def process_query(self, query, response_type="symptoms", top_k=5):
        """Main processing function"""

        # Add to history
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.query_history.append({
            "timestamp": timestamp,
            "query": query,
            "response_type": response_type
        })

        # Retrieve documents
        retrieved_docs = self.enhanced_retrieve(query, top_k=top_k)

        # Generate response
        response, unique_items = self.generate_response(query, retrieved_docs, response_type)

        # Format retrieved documents
        docs_info = []
        for i, doc in enumerate(retrieved_docs):
            docs_info.append({
                "Rank": i+1,
                "Relevance": f"{doc['score']:.3f}",
                "Category": doc['category'],
                "Source": doc['filename'],
                "Preview": doc['text'][:150] + "..."
            })

        # Create metrics
        avg_score = np.mean([doc['score'] for doc in retrieved_docs]) if retrieved_docs else 0
        categories = {}
        for doc in retrieved_docs:
            cat = doc['category']
            categories[cat] = categories.get(cat, 0) + 1

        metrics = {
            "Query": query,
            "Response Type": response_type.upper(),
            "Documents Retrieved": len(retrieved_docs),
            "Average Relevance": f"{avg_score:.3f}",
            "Categories": ", ".join([f"{k} ({v})" for k, v in categories.items()]),
            "Unique Items Found": len(unique_items),
            "Response Format": "UNIQUE ITEMS GRID + DETAILED ANALYSIS",
            "Processing Time": timestamp
        }

        return response, docs_info, metrics, unique_items

# =============================================================================
# INITIALIZE THE INTERFACE
# =============================================================================

rag_interface = EnhancedClinicalRAGInterface()

# =============================================================================
# SIMPLIFIED GRADIO INTERFACE
# =============================================================================

def create_simple_gradio_interface():
    """Create interface with unique items display"""

    # Custom CSS
    css = """
    .gradio-container {
        max-width: 1100px !important;
        margin: auto !important;
    }
    .clinical-input {
        font-size: 16px !important;
        padding: 15px !important;
        border-radius: 10px !important;
    }
    .output-box {
        border-radius: 10px;
        padding: 20px;
        background: #f8f9fa;
        border: 2px solid #e0e0e0;
        font-family: 'Segoe UI', sans-serif;
        line-height: 1.8;
    }
    .items-grid {
        display: flex;
        flex-wrap: wrap;
        gap: 10px;
        margin: 15px 0;
        padding: 15px;
        background: #ffffff;
        border-radius: 10px;
        border: 1px solid #e0e0e0;
    }
    .grid-item {
        background: linear-gradient(135deg, #667eea, #764ba2);
        color: white;
        padding: 8px 16px;
        border-radius: 20px;
        font-size: 14px;
        font-weight: 500;
        display: flex;
        align-items: center;
        transition: transform 0.2s;
    }
    .grid-item:hover {
        transform: translateY(-2px);
        box-shadow: 0 4px 12px rgba(0,0,0,0.1);
    }
    .grid-item-number {
        background: rgba(255,255,255,0.2);
        border-radius: 50%;
        width: 24px;
        height: 24px;
        display: flex;
        align-items: center;
        justify-content: center;
        margin-right: 8px;
        font-size: 12px;
    }
    .detailed-section {
        margin-top: 25px;
        padding: 20px;
        background: #f0f7ff;
        border-radius: 10px;
        border-left: 4px solid #3498db;
    }
    """

    # Theme
    theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="gray",
    ).set(
        body_background_fill="#f5f7fa",
        button_primary_background_fill="#3b82f6",
        button_primary_background_fill_hover="#2563eb",
    )

    # Response type descriptions
    response_descriptions = {
        "symptoms": "üîç Get unique symptoms list + detailed analysis",
        "treatment": "üíä Get unique treatments list + detailed analysis",
        "diagnosis": "ü©∫ Get unique diagnoses list + detailed analysis",
        "medication": "üíä Get unique medications list + detailed analysis",
        "assessment": "üìã Get comprehensive clinical assessment"
    }

    def create_items_grid_html(items, response_type):
        """Create HTML for items grid"""
        if not items:
            return "<div style='color: #666; padding: 20px; text-align: center;'>No unique items identified</div>"

        # Color mapping
        colors = {
            "symptoms": "linear-gradient(135deg, #ff6b6b, #ff8e8e)",
            "treatment": "linear-gradient(135deg, #4ecdc4, #6dd3ca)",
            "diagnosis": "linear-gradient(135deg, #45b7d1, #6bc5dd)",
            "medication": "linear-gradient(135deg, #96ceb4, #b0d8c4)",
            "assessment": "linear-gradient(135deg, #feca57, #fed67a)"
        }

        color = colors.get(response_type, "linear-gradient(135deg, #667eea, #764ba2)")

        grid_html = f"""
<div class="items-grid">
"""

        for i, item in enumerate(items, 1):
            grid_html += f"""
    <div class="grid-item" style="background: {color};">
        <div class="grid-item-number">{i}</div>
        {item}
    </div>
"""

        grid_html += """
</div>
"""
        return grid_html

    def process_query_with_unique_items(query, response_type, top_k):
        """Process query and return results with unique items grid"""

        response, docs_info, metrics, unique_items = rag_interface.process_query(query, response_type, top_k)

        # Create badge
        badges = {
            "symptoms": "ü©∫ SYMPTOMS",
            "treatment": "üíä TREATMENTS",
            "diagnosis": "üîç DIAGNOSES",
            "medication": "üíä MEDICATIONS",
            "assessment": "ü©∫ ASSESSMENT"
        }

        badge = badges.get(response_type, "üìã RESULTS")

        # Create items grid HTML
        items_grid_html = create_items_grid_html(unique_items, response_type)

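        # Format response: gr.HTML does not render Markdown, so convert **bold** to <strong> and newlines to <br>
        formatted_response = re.sub(r'\*\*(.+?)\*\*', r'<strong>\1</strong>', response)
        formatted_response = formatted_response.replace('\n', '<br>')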

        # Create final output
        formatted_output = f"""
<div style="font-family: 'Segoe UI', sans-serif; line-height: 1.8;">
    <div style="background: linear-gradient(135deg, #3b82f6, #1d4ed8); color: white; padding: 20px; border-radius: 10px 10px 0 0; margin-bottom: 20px;">
        <h3 style="margin: 0; font-size: 20px;">{badge} ANALYSIS</h3>
        <p style="margin: 8px 0 0 0; opacity: 0.9; font-size: 14px;">Query: <strong>{query}</strong></p>
    </div>

    <div style="background: white; padding: 25px; border-radius: 0 0 10px 10px; border: 1px solid #e0e0e0;">
        <div style="margin-bottom: 25px; padding: 18px; background: #f8f9fa; border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.05);">
            <div style="display: flex; flex-wrap: wrap; gap: 15px; margin-bottom: 15px;">
                <div style="flex: 1; min-width: 200px;">
                    <strong style="color: #3b82f6;">üìä Documents Analyzed:</strong><br>
                    <span style="font-size: 18px; font-weight: bold;">{metrics['Documents Retrieved']}</span>
                </div>
                <div style="flex: 1; min-width: 200px;">
                    <strong style="color: #3b82f6;">⭐ Average Relevance:</strong><br>
                    <span style="font-size: 18px; font-weight: bold;">{metrics['Average Relevance']}</span>
                </div>
                <div style="flex: 1; min-width: 200px;">
                    <strong style="color: #3b82f6;">üîç Unique Items:</strong><br>
                    <span style="font-size: 18px; font-weight: bold;">{metrics['Unique Items Found']}</span>
                </div>
            </div>
            <div style="color: #666; font-size: 14px;">
                <strong>Response Type:</strong> {response_type.upper()} |
                <strong>Categories:</strong> {metrics['Categories']} |
                <strong>Time:</strong> {metrics['Processing Time']}
            </div>
        </div>

        <div style="margin-bottom: 25px;">
            <h4 style="margin: 0 0 15px 0; color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 8px;">
                🎯 UNIQUE ITEMS IDENTIFIED
            </h4>
            {items_grid_html}
        </div>

        <div style="margin-bottom: 20px;">
            <h4 style="margin: 0 0 15px 0; color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 8px;">
                📝 DETAILED ANALYSIS
            </h4>
            <div class="detailed-section">
                {formatted_response}
            </div>
        </div>

        <div style="margin-top: 25px; padding-top: 15px; border-top: 1px solid #e0e0e0; font-size: 13px; color: #7f8c8d;">
            <strong>üìö Source Information:</strong> Retrieved from {len(docs_info)} clinical documents<br>
            <strong>üè∑Ô∏è Categories:</strong> {metrics['Categories']}
        </div>
    </div>
</div>
"""

        # Metrics for JSON
        metrics_display = {
            "query": query,
            "response_type": response_type,
            "documents_retrieved": metrics["Documents Retrieved"],
            "average_relevance": metrics["Average Relevance"],
            "unique_items_found": metrics["Unique Items Found"],
            "categories": metrics["Categories"],
            "processing_time": metrics["Processing Time"],
            "format": metrics["Response Format"]
        }

        # Docs dataframe
        if docs_info:
            docs_df = pd.DataFrame(docs_info)
        else:
            docs_df = pd.DataFrame({"Message": ["No documents retrieved"]})

        return formatted_output, metrics_display, docs_df

    # Create interface
    with gr.Blocks(theme=theme, css=css, title="üè• Clinical RAG Assistant") as demo:
        gr.Markdown("""
        # üè• Clinical RAG Assistant
        ### Get Unique Medical Items + Detailed Analysis
        *Ask about symptoms, treatments, diagnoses, medications - get unique items grid + detailed analysis*
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # Query Input
                gr.Markdown("### üìù Enter Medical Query")
                query_input = gr.Textbox(
                    label="",
                    placeholder="Examples: 'symptoms of pneumonia', 'treatments for hypertension', 'medications for diabetes'...",
                    lines=3,
                    elem_classes="clinical-input"
                )

                # Response Type Selection
                gr.Markdown("### üéØ Select Information Type")
                response_type = gr.Radio(
                    choices=list(response_descriptions.keys()),
                    value="symptoms",
                    label="",
                    info="What information do you need?",
                    elem_id="response-type-radio"
                )

                # Display descriptions
                response_desc = gr.Markdown(
                    value=response_descriptions["symptoms"],
                    elem_id="response-desc"
                )

                # Update description
                def update_desc(response_type):
                    return response_descriptions.get(response_type, "")

                response_type.change(
                    fn=update_desc,
                    inputs=response_type,
                    outputs=response_desc
                )

                # Settings
                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=2,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Clinical documents to retrieve",
                        info="Only the top 3 are passed to the generator; all appear under Source Documents"
                    )

                # Action Button
                submit_btn = gr.Button(
                    "🔍 Extract Unique Items + Analysis",
                    variant="primary",
                    size="lg"
                )

                # Clear Button
                clear_btn = gr.Button("üóëÔ∏è Clear", variant="secondary")

            with gr.Column(scale=1):
                # Quick Query Buttons
                gr.Markdown("### üí° Quick Examples")

                with gr.Column():
                    symptoms_btn = gr.Button("ü©∫ Pneumonia Symptoms", size="sm")
                    treatment_btn = gr.Button("üíä Hypertension Treatments", size="sm")
                    diagnosis_btn = gr.Button("üîç Chest Pain Diagnoses", size="sm")
                    medication_btn = gr.Button("üíä Diabetes Medications", size="sm")

                # System Info
                gr.Markdown("### üõ†Ô∏è System Information")
                gr.Markdown(f"""
                - **Documents Indexed:** {len(documents)}
                - **Response Format:** Unique Items Grid + Details
                - **Output:** Clean visualization + Detailed analysis
                """)

        # Output Section
        gr.Markdown("### 📋 Analysis Results")

        with gr.Tabs():
            with gr.TabItem("💡 Clinical Analysis"):
                output_html = gr.HTML(
                    label="",
                    value="<div style='text-align: center; padding: 40px; color: #666;'>Enter a query above and click 'Extract Unique Items + Analysis'</div>"
                )

            with gr.TabItem("📊 System Metrics"):
                metrics_json = gr.JSON(
                    label="Retrieval Metrics",
                    value={}
                )

            with gr.TabItem("📚 Source Documents"):
                docs_table = gr.Dataframe(
                    label="Retrieved Clinical Documents",
                    headers=["Rank", "Relevance", "Category", "Source", "Preview"],
                    value=pd.DataFrame({"Message": ["No analysis performed yet"]})
                )

        # Quick query actions
        def set_quick_query(query, r_type):
            return query, r_type, 5

        symptoms_btn.click(
            fn=lambda: set_quick_query("What are the symptoms of pneumonia?", "symptoms"),
            outputs=[query_input, response_type, top_k_slider]
        )

        treatment_btn.click(
            fn=lambda: set_quick_query("What treatments are available for hypertension?", "treatment"),
            outputs=[query_input, response_type, top_k_slider]
        )

        diagnosis_btn.click(
            fn=lambda: set_quick_query("What are possible diagnoses for chest pain?", "diagnosis"),
            outputs=[query_input, response_type, top_k_slider]
        )

        medication_btn.click(
            fn=lambda: set_quick_query("What medications are used for diabetes?", "medication"),
            outputs=[query_input, response_type, top_k_slider]
        )

        # Main action
        submit_btn.click(
            fn=process_query_with_unique_items,
            inputs=[query_input, response_type, top_k_slider],
            outputs=[output_html, metrics_json, docs_table]
        )

        # Clear action
        clear_btn.click(
            fn=lambda: ["", "symptoms", 5,
                       "<div style='text-align: center; padding: 40px; color: #666;'>Enter a query above and click 'Extract Unique Items + Analysis'</div>",
                       {}, pd.DataFrame({"Message": ["No analysis performed yet"]})],
            outputs=[query_input, response_type, top_k_slider, output_html, metrics_json, docs_table]
        )

        # Footer
        gr.Markdown("""
        ---
        ### 🔒 Important Notice
        *This system provides AI-generated analysis from clinical documents.*
        *For actual medical decisions, consult healthcare professionals.*
        *All data is de-identified for privacy protection.*

        **🎯 Response Format:**
        - **Unique Items Grid**: Clean, numbered list of identified items
        - **Detailed Analysis**: Comprehensive information for each item

        **📊 Each analysis includes:**
        1. Unique items identified (symptoms, treatments, etc.)
        2. Detailed description for each item
        3. Source document information
        4. System performance metrics
        """)

    return demo

# =============================================================================
# LAUNCH THE INTERFACE
# =============================================================================

print("🚀 Creating Clinical RAG Interface with Unique Items Grid...")
demo = create_simple_gradio_interface()

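# Free ports held by earlier launches. Gradio servers started from this notebook
# run as threads inside the kernel, so gr.close_all() is what actually releases
# them; pkill only catches stray external gradio/uvicorn processes.
import subprocess
import time

print("🔄 Checking for existing processes...")
gr.close_all()
subprocess.run(["pkill", "-f", "gradio"], capture_output=True)
subprocess.run(["pkill", "-f", "uvicorn"], capture_output=True)
time.sleep(2)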

# Try different ports and record whether any launch succeeded
ports_to_try = [7860, 7861, 7862, 7863, 7864]
launched = False

for port in ports_to_try:
    print(f"\n🎯 Attempting to launch on port {port}...")
    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=port,
            share=True,
            quiet=False,
            debug=False
        )
        launched = True
        print(f"✅ Success! Interface running on port {port}")
        print("📱 Check the gradio.app link above")
        break
    except Exception as e:
        print(f"⚠️ Port {port} failed: {str(e)[:100]}...")

# If all ports fail, fall back to an ngrok tunnel
# (pyngrok is installed in the setup cell at the top of the notebook)
if not launched:
    print("\n🔄 Trying ngrok tunnel...")

    from pyngrok import ngrok

    # Current ngrok releases require an auth token, e.g.:
    # ngrok.set_auth_token("<YOUR_NGROK_AUTHTOKEN>")
    ngrok.kill()

    # connect() returns an NgrokTunnel; the URL string is its public_url attribute
    tunnel = ngrok.connect(addr="7865", proto="http")
    public_url = tunnel.public_url
    print(f"🌐 Public URL: {public_url}")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7865,
        share=False,
        quiet=True
    )

    print("\n✅ Interface ready!")
    print(f"📱 Open: {public_url}")
print("\n🎉 Ready! Query medical topics and get unique items grid + detailed analysis.")

🚀 Creating Clinical RAG Interface with Unique Items Grid...
🔄 Checking for existing processes...
IMPORTANT: You are using gradio version 4.21.0, however version 4.44.1 is available, please upgrade.
--------

🎯 Attempting to launch on port 7860...
⚠️ Port 7860 failed: Cannot find empty port in range: 7860-7860. You can specify a different port by setting the GRADIO_S...

🎯 Attempting to launch on port 7861...
⚠️ Port 7861 failed: Cannot find empty port in range: 7861-7861. You can specify a different port by setting the GRADIO_S...

🎯 Attempting to launch on port 7862...
⚠️ Port 7862 failed: Cannot find empty port in range: 7862-7862. You can specify a different port by setting the GRADIO_S...

🎯 Attempting to launch on port 7863...
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://10c0e31154e8a4e51e.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrade

✅ Success! Interface running on port 7863
📱 Check the gradio.app link above

🎉 Ready! Query medical topics and get unique items grid + detailed analysis.
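The launch loop above discovers a busy port only by letting `demo.launch` raise, which is why the recorded output shows three failed attempts before port 7863. A minimal alternative, sketched here with the standard-library `socket` module, probes each candidate port before launching. `first_free_port` is a hypothetical helper, not part of Gradio. The probe is advisory only: another process can claim the port between the check and the launch, so the real `demo.launch` call should stay inside `try`/`except`.

```python
import socket
from typing import Iterable, Optional


def first_free_port(candidates: Iterable[int], host: str = "0.0.0.0") -> Optional[int]:
    """Return the first port in `candidates` that can currently be bound on `host`.

    Returns None if every candidate is taken. The result can go stale: another
    process may bind the port after this probe socket is closed.
    """
    for port in candidates:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind((host, port))
            except OSError:
                continue  # in use (or not permitted); try the next candidate
            return port  # the with-block closes the probe, releasing the port
    return None
```

In the notebook this would amount to `port = first_free_port(ports_to_try)` followed, when `port` is not `None`, by `demo.launch(server_name="0.0.0.0", server_port=port, share=True)`.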
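The four Quick Examples buttons above are wired with four near-identical `.click(...)` calls. If they were instead generated from a table in a loop, a bare `lambda` would capture the loop variables by reference, and Python's late binding of closures would make every button fill in the last query. A minimal sketch of a factory that fixes the values when each handler is created; `QUICK_QUERIES` and `make_quick_query_handler` are hypothetical names, and the query strings are copied from the notebook.

```python
from typing import Callable, List, Tuple

DEFAULT_TOP_K = 5  # slider value the notebook's quick-query buttons set

# (button label, query text, response type), copied from the Quick Examples buttons
QUICK_QUERIES: List[Tuple[str, str, str]] = [
    ("Pneumonia Symptoms", "What are the symptoms of pneumonia?", "symptoms"),
    ("Hypertension Treatments", "What treatments are available for hypertension?", "treatment"),
    ("Chest Pain Diagnoses", "What are possible diagnoses for chest pain?", "diagnosis"),
    ("Diabetes Medications", "What medications are used for diabetes?", "medication"),
]


def make_quick_query_handler(query: str, r_type: str) -> Callable[[], Tuple[str, str, int]]:
    """Return a zero-argument click handler with query and r_type fixed at creation time."""
    def handler() -> Tuple[str, str, int]:
        # One value per output component: query_input, response_type, top_k_slider
        return query, r_type, DEFAULT_TOP_K
    return handler


# Correct: each handler closes over the arguments of its own factory call
handlers = [make_quick_query_handler(q, t) for _, q, t in QUICK_QUERIES]

# Pitfall for comparison: these lambdas read q and t only when called, after the
# comprehension has finished, so all four return the last entry's values
late_bound = [lambda: (q, t, DEFAULT_TOP_K) for _, q, t in QUICK_QUERIES]
```

Inside the `gr.Blocks` context, each button would then be wired as `btn.click(fn=make_quick_query_handler(query, r_type), outputs=[query_input, response_type, top_k_slider])`.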
