# Project 4: Clinical Chatbot - RAG-based Medical Q&A System

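**Objective**: Build a chatbot that answers clinical questions using retrieval-augmented generation (RAG) over a small, curated medical knowledge base. The notebook implements the retrieval stage end to end. The answer step is a simplified extractive placeholder for an LLM.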

**Tech Stack**: PySpark, Hugging Face Transformers, FAISS, Sentence Transformers

## Cell 1: Environment Setup & Library Installation

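```python
# Install required libraries (Databricks notebook magic)
%pip install transformers sentence-transformers faiss-cpu

# Restart the Python process so the newly installed packages can be imported
dbutils.library.restartPython()
```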

## Cell 2: Import Libraries

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, explode, lit, concat_ws, monotonically_increasing_id
from pyspark.sql.types import StringType, ArrayType

import pandas as pd
import numpy as np
from typing import List, Dict

# NLP and ML libraries
from sentence_transformers import SentenceTransformer
import faiss

# Databricks notebooks predefine `spark`; getOrCreate() returns that session
# there and starts a local session when the notebook runs elsewhere
spark = SparkSession.builder.getOrCreate()

print("✅ Libraries imported successfully")
```

## Cell 3: Create Medical Knowledge Base

```python
# Medical knowledge base
medical_knowledge_data = [
    {
        "doc_id": "DM_001",
        "category": "Endocrinology",
        "title": "Type 2 Diabetes Management",
        "content": """Type 2 diabetes mellitus is a chronic metabolic disorder characterized by insulin resistance 
        and relative insulin deficiency. First-line treatment includes lifestyle modifications such as diet and exercise. 
        Metformin is the recommended initial pharmacological therapy unless contraindicated. HbA1c target is generally 
        below 7% for most adults. Monitoring includes regular blood glucose testing, HbA1c every 3-6 months, and annual 
        screening for complications including retinopathy, nephropathy, and neuropathy."""
    },
    {
        "doc_id": "CV_001",
        "category": "Cardiology",
        "title": "Hypertension Guidelines",
        "content": """Hypertension is defined as blood pressure ≥130/80 mmHg. Initial management includes lifestyle 
        modifications: DASH diet, sodium restriction (<2300mg/day), weight loss, regular aerobic exercise, and 
        limited alcohol intake. For stage 1 hypertension (130-139/80-89), consider pharmacotherapy if 10-year ASCVD 
        risk ≥10%. First-line medications include ACE inhibitors, ARBs, calcium channel blockers, and thiazide diuretics. 
        Target BP is <130/80 for most adults."""
    },
    {
        "doc_id": "RESP_001",
        "category": "Pulmonology",
        "title": "Asthma Management",
        "content": """Asthma is a chronic inflammatory airway disease characterized by reversible airflow obstruction. 
        Classification includes intermittent and persistent (mild, moderate, severe). Step therapy approach: Step 1 - 
        SABA as needed; Step 2 - low-dose ICS; Step 3 - low-dose ICS + LABA or medium-dose ICS; Step 4 - medium-dose 
        ICS + LABA; Step 5 - high-dose ICS + LABA, consider biologics. Asthma action plan should be provided to all patients. 
        Spirometry recommended for diagnosis and monitoring."""
    },
    {
        "doc_id": "INF_001",
        "category": "Infectious Disease",
        "title": "Community-Acquired Pneumonia Treatment",
        "content": """Community-acquired pneumonia (CAP) treatment depends on severity and risk factors. Outpatient 
        treatment for healthy adults: macrolide (azithromycin) or doxycycline. For patients with comorbidities: 
        respiratory fluoroquinolone or beta-lactam plus macrolide. Hospitalized patients require beta-lactam plus 
        macrolide or respiratory fluoroquinolone. ICU patients need beta-lactam plus azithromycin or fluoroquinolone. 
        CURB-65 or PSI score helps determine admission need. Blood cultures and sputum culture recommended for 
        severe cases."""
    },
    {
        "doc_id": "ONCO_001",
        "category": "Oncology",
        "title": "Breast Cancer Screening",
        "content": """Breast cancer screening recommendations vary by risk. Average-risk women: annual mammography 
        starting age 40-45, continuing as long as life expectancy >10 years. Clinical breast exam every 1-3 years 
        for women in their 20s-30s, annually for age 40+. High-risk women (BRCA mutation, prior thoracic radiation): 
        annual mammography and MRI starting age 30 or 10 years after radiation. Risk assessment tools include 
        Gail model and Tyrer-Cuzick model. Genetic counseling recommended for strong family history."""
    },
    {
        "doc_id": "NEURO_001",
        "category": "Neurology",
        "title": "Migraine Management",
        "content": """Migraine headaches are characterized by recurrent episodes of moderate to severe headache, often 
        unilateral and pulsating. Acute treatment includes NSAIDs, triptans, or gepants for moderate to severe attacks. 
        Preventive therapy indicated if ≥4 headache days per month or disabling attacks. First-line preventives include 
        beta-blockers (propranolol, metoprolol), antiepileptics (topiramate, valproate), or CGRP antagonists. 
        Lifestyle modifications: regular sleep, hydration, stress management, trigger avoidance."""
    }
]

# Create DataFrame
knowledge_df = spark.createDataFrame(medical_knowledge_data)

print(f"✅ Loaded {knowledge_df.count()} medical knowledge documents")
knowledge_df.show(truncate=False)
```
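
Each `content` value is a triple-quoted string, so it keeps the line breaks and indentation of the surrounding code. The chunker splits on whitespace and hides this, but the raw text still shows up in `show()` output. Below is a minimal sketch of a cleanup and validation step that could run before `spark.createDataFrame`. It assumes records shaped like `medical_knowledge_data`. `clean_record`, `clean_records` and `REQUIRED_KEYS` are hypothetical names, not part of the original notebook.

```python
from typing import Dict, List

# Hypothetical schema for a knowledge-base record
REQUIRED_KEYS = ("doc_id", "category", "title", "content")


def clean_record(record: Dict[str, str]) -> Dict[str, str]:
    """Check required fields and collapse all whitespace runs to single spaces."""
    missing = [key for key in REQUIRED_KEYS if not record.get(key)]
    if missing:
        raise ValueError(f"record {record.get('doc_id')!r} is missing {missing}")
    # str.split() with no argument splits on any whitespace, including newlines
    return {key: " ".join(str(record[key]).split()) for key in REQUIRED_KEYS}


def clean_records(records: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Clean every record and reject duplicate doc_id values."""
    cleaned = [clean_record(record) for record in records]
    ids = [record["doc_id"] for record in cleaned]
    if len(ids) != len(set(ids)):
        raise ValueError("duplicate doc_id values in knowledge base")
    return cleaned


sample = [{
    "doc_id": "DEMO_001",
    "category": "Demo",
    "title": "Example",
    "content": """First line
        second line""",
}]
print(clean_records(sample)[0]["content"])  # First line second line
```

In the notebook, `spark.createDataFrame(clean_records(medical_knowledge_data))` would store single-line content.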

## Cell 4: Text Chunking

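```python
from pyspark.sql.functions import posexplode

def chunk_text(text: str, chunk_size: int = 150, chunk_overlap: int = 30) -> List[str]:
    """Split text into overlapping windows of `chunk_size` words.

    Sizes count whitespace-separated words, not model tokens. all-MiniLM-L6-v2
    truncates input beyond 256 word pieces, so the default is chosen to keep
    typical chunks under that limit.
    """
    if not text:
        return []  # Spark passes None to Python UDFs for null values
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size")

    words = text.split()
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(' '.join(words[start:start + chunk_size]))
        # Stop once a window reaches the end of the text; the next window
        # would lie entirely inside this one and add a redundant tail chunk
        if start + chunk_size >= len(words):
            break

    return chunks

# Register UDF
chunk_text_udf = udf(chunk_text, ArrayType(StringType()))

# Apply chunking
chunked_df = knowledge_df.withColumn("chunks", chunk_text_udf(col("content")))

# Explode chunks; posexplode also yields each chunk's position within its
# document, giving stable per-document IDs such as "DM_001_chunk_0"
exploded_df = chunked_df.select(
    col("doc_id"),
    col("category"),
    col("title"),
    posexplode(col("chunks")).alias("chunk_idx", "chunk_text")
).withColumn("chunk_id", concat_ws("_", col("doc_id"), lit("chunk"),
                                   col("chunk_idx").cast("string")))

print(f"✅ Created {exploded_df.count()} text chunks")
exploded_df.show(5, truncate=80)
```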
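
Every document in this knowledge base is under 100 words, so each one yields a single chunk. The windowing only starts to matter once longer guidelines are added. Below is a standalone sketch of a word-based sliding window with tiny sizes, so the overlap is visible. It stops as soon as a window reaches the end of the text. `sliding_window_chunks` is a hypothetical name.

```python
from typing import List


def sliding_window_chunks(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Word windows of chunk_size; each starts chunk_size - chunk_overlap words after the last."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size")
    words = text.split()
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):  # this window already reaches the end
            break
    return chunks


text = " ".join(f"w{i}" for i in range(10))  # "w0 w1 ... w9"
for chunk in sliding_window_chunks(text, chunk_size=5, chunk_overlap=2):
    print(chunk)
# w0 w1 w2 w3 w4
# w3 w4 w5 w6 w7
# w6 w7 w8 w9
```

Consecutive chunks share two words. Without the early `break`, the loop would also emit a fourth chunk, `w9`, which the third chunk already contains.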

## Cell 5: Generate Embeddings

```python
# Load embedding model (produces 384-dimensional sentence embeddings)
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

print("✅ Loaded embedding model")
print(f"   Embedding dimension: {embedding_model.get_sentence_embedding_dimension()}")

# Collect chunks to the driver; fine for a small knowledge base, but a large
# corpus should be embedded in a distributed way instead
chunks_pdf = exploded_df.toPandas()

# Generate unit-length embeddings so that L2 distance and cosine similarity
# produce the same ranking
print("🔄 Generating embeddings...")
embeddings = embedding_model.encode(
    chunks_pdf['chunk_text'].tolist(),
    show_progress_bar=True,
    normalize_embeddings=True,
)

chunks_pdf['embedding'] = embeddings.tolist()

print(f"✅ Generated {len(embeddings)} embeddings")
```
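
According to its model card, all-MiniLM-L6-v2 builds a sentence embedding in two steps. It averages the token embeddings from the transformer, ignoring padding positions via the attention mask, and then scales the result to unit length. The NumPy sketch below shows that pooling step, with made-up token vectors standing in for real transformer output. `mean_pool_and_normalize` is a hypothetical helper, not the library's code.

```python
import numpy as np


def mean_pool_and_normalize(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors over non-padding positions, then L2-normalize.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[:, :, None].astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)   # padding tokens contribute zero
    counts = np.clip(mask.sum(axis=1), 1e-9, None)   # number of real tokens per sentence
    pooled = summed / counts
    norms = np.clip(np.linalg.norm(pooled, axis=1, keepdims=True), 1e-12, None)
    return pooled / norms


# Toy batch: one sentence with two real tokens and one padding token, dim = 2
tokens = np.array([[[3.0, 0.0], [1.0, 2.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
emb = mean_pool_and_normalize(tokens, mask)
print(emb)  # mean of the real tokens is [2, 1]; normalized to about [0.894, 0.447]
```

Passing `normalize_embeddings=True` to `encode` applies the same final unit-length scaling. This keeps L2-based scores interpretable even if a model without a built-in normalization step is swapped in.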

## Cell 6: Build FAISS Index

```python
# Convert to a float32 numpy array, the dtype FAISS expects
embeddings_array = np.array(embeddings).astype('float32')

# Build an exact (brute-force) FAISS index; search() returns squared L2 distances
dimension = embeddings_array.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings_array)

print("✅ Built FAISS index")
print(f"   Vectors: {index.ntotal}")
print(f"   Dimension: {dimension}")
```
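
`IndexFlatL2` performs exact brute-force search. It compares the query with every stored vector and returns the k smallest squared L2 distances. The NumPy sketch below reproduces that computation; it is not FAISS's implementation. It also checks that, for unit-length vectors, squared distance d and cosine similarity satisfy cos = 1 - d/2. `flat_l2_search` is a hypothetical name.

```python
import numpy as np


def flat_l2_search(database: np.ndarray, queries: np.ndarray, k: int):
    """Exact k-nearest-neighbour search by squared L2 distance.

    Returns (distances, indices), each of shape (n_queries, k), sorted by
    increasing distance, like the output of a flat index's search().
    """
    # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2 for every query/vector pair
    sq_dists = (
        (queries ** 2).sum(axis=1, keepdims=True)
        - 2.0 * queries @ database.T
        + (database ** 2).sum(axis=1)
    )
    sq_dists = np.maximum(sq_dists, 0.0)  # clamp small negative rounding errors
    order = np.argsort(sq_dists, axis=1, kind="stable")[:, :k]
    return np.take_along_axis(sq_dists, order, axis=1), order


rng = np.random.default_rng(0)
db = rng.normal(size=(6, 4)).astype("float32")
db /= np.linalg.norm(db, axis=1, keepdims=True)  # unit-length rows, like normalized embeddings

query = db[2:3].copy()  # a query identical to stored vector 2
distances, indices = flat_l2_search(db, query, k=3)
print(indices[0])  # nearest neighbour first, so index 2 is listed first

# For unit vectors d = 2 - 2*cos, hence cos = 1 - d/2
cosines = db[indices[0]] @ query[0]
print(np.allclose(1 - distances[0] / 2, cosines, atol=1e-5))  # True
```

For a knowledge base this small, exact search is the natural choice. FAISS's approximate index types, such as `IndexIVFFlat` or `IndexHNSWFlat`, trade some recall for speed and only pay off on much larger collections.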

## Cell 7: Implement Retrieval

```python
class MedicalRetriever:
    """Retrieve the knowledge-base chunks most similar to a query"""

    def __init__(self, faiss_index, metadata_df, embedding_model):
        self.index = faiss_index
        self.metadata = metadata_df  # row i describes vector i in the index
        self.model = embedding_model

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Retrieve the top_k most similar chunks"""
        # Embed the query like the chunks: unit-length float32 vectors
        query_emb = self.model.encode([query], normalize_embeddings=True).astype('float32')

        # Never ask for more neighbours than the index holds
        top_k = min(top_k, self.index.ntotal)
        distances, indices = self.index.search(query_emb, top_k)

        results = []
        for idx, dist in zip(indices[0], distances[0]):
            if idx < 0:  # FAISS pads missing results with -1
                continue
            row = self.metadata.iloc[idx]
            results.append({
                'chunk_id': row['chunk_id'],
                'chunk_text': row['chunk_text'],
                'category': row['category'],
                'title': row['title'],
                # For unit vectors, squared L2 distance d = 2 - 2*cos,
                # so cosine similarity = 1 - d / 2
                'similarity': float(1 - dist / 2)
            })

        return results

# Initialize retriever
retriever = MedicalRetriever(index, chunks_pdf, embedding_model)

print("✅ Retriever initialized")
```
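
A retriever for clinical questions should be able to report that nothing relevant was found, instead of always returning its top hits. The sketch below post-processes one query's raw search output. It assumes unit-length embeddings and the squared-L2 distances that `IndexFlatL2` returns. `postprocess_hits` and the 0.3 threshold are hypothetical, and the threshold would need tuning on real queries.

```python
from typing import List, Tuple

import numpy as np


def postprocess_hits(distances: np.ndarray, indices: np.ndarray,
                     min_similarity: float = 0.3) -> List[Tuple[int, float]]:
    """Convert one row of flat-L2 search output into (row_index, cosine) pairs.

    Assumes unit-length vectors, so cosine = 1 - squared_distance / 2.
    Skips -1 padding entries and hits below the (hypothetical) threshold.
    """
    hits = []
    for idx, dist in zip(indices, distances):
        if idx < 0:
            continue
        similarity = 1.0 - float(dist) / 2.0
        if similarity >= min_similarity:
            hits.append((int(idx), similarity))
    return hits


# One query's results: two good hits, one weak hit, one padding slot
distances = np.array([0.5, 1.0, 1.5, np.finfo(np.float32).max], dtype="float32")
indices = np.array([4, 1, 0, -1])
print(postprocess_hits(distances, indices))  # [(4, 0.75), (1, 0.5)]
```

An empty list can then trigger a fallback reply instead of quoting a weakly related guideline.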

## Cell 8: Test Retrieval

```python
test_queries = [
    "What is the first-line treatment for type 2 diabetes?",
    "What blood pressure medications are recommended?",
    "How to manage asthma?",
    "What antibiotics for pneumonia?"
]

print("=" * 80)
print("RETRIEVAL TESTING")
print("=" * 80)

for query in test_queries:
    print(f"\n📝 Query: {query}")
    print("-" * 80)
    
    results = retriever.retrieve(query, top_k=2)
    
    for i, result in enumerate(results, 1):
        print(f"\nResult {i}:")
        print(f"  Category: {result['category']}")
        print(f"  Title: {result['title']}")
        print(f"  Similarity: {result['similarity']:.4f}")
        print(f"  Text: {result['chunk_text'][:200]}...")
```
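
The test loop depends on someone reading the printed results. Labelling each test query with the document it should retrieve makes retrieval quality measurable. The sketch below computes two standard metrics over ranked title lists:

- **hit rate@k:** the fraction of queries whose relevant document appears in the top k.
- **mean reciprocal rank:** the average of 1/rank of the relevant document, counting 0 when it is missed.

`hit_rate_and_mrr` and the labels and rankings below are hypothetical, for illustration only.

```python
from typing import Dict, List, Tuple


def hit_rate_and_mrr(ranked: Dict[str, List[str]], expected: Dict[str, str],
                     k: int) -> Tuple[float, float]:
    """Hit rate@k and MRR@k when each query has exactly one relevant item."""
    if not expected:
        raise ValueError("expected must contain at least one query")
    hits = 0
    reciprocal_rank_sum = 0.0
    for query, relevant in expected.items():
        top_k = ranked.get(query, [])[:k]
        if relevant in top_k:
            hits += 1
            reciprocal_rank_sum += 1.0 / (top_k.index(relevant) + 1)
    return hits / len(expected), reciprocal_rank_sum / len(expected)


# Hypothetical labels and rankings (not real retriever output)
expected = {
    "q_diabetes": "Type 2 Diabetes Management",
    "q_bp": "Hypertension Guidelines",
    "q_asthma": "Asthma Management",
}
ranked = {
    "q_diabetes": ["Type 2 Diabetes Management", "Hypertension Guidelines"],
    "q_bp": ["Type 2 Diabetes Management", "Hypertension Guidelines"],
    "q_asthma": ["Migraine Management", "Hypertension Guidelines"],
}
hit_rate, mrr = hit_rate_and_mrr(ranked, expected, k=2)
print(f"hit rate@2 = {hit_rate:.3f}, MRR@2 = {mrr:.3f}")  # hit rate@2 = 0.667, MRR@2 = 0.500
```

In the notebook, `ranked` could be built from the real retriever as `{q: [r['title'] for r in retriever.retrieve(q, top_k=2)] for q in test_queries}`, with `expected` keyed by the same query strings.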

## Cell 9: Build Simple Chatbot

```python
class ClinicalChatbot:
    """Simple RAG-based chatbot: retrieval plus an extractive answer"""

    def __init__(self, retriever):
        self.retriever = retriever
        self.history = []

    def ask(self, query: str, top_k: int = 2) -> Dict:
        """Answer a clinical question from retrieved knowledge-base chunks"""
        # Retrieve context
        docs = self.retriever.retrieve(query, top_k=top_k)

        # Concatenate the retrieved chunks with source labels; this is the
        # context an LLM would be conditioned on
        context_text = "\n\n".join(
            f"[{doc['category']} - {doc['title']}]\n{doc['chunk_text']}"
            for doc in docs
        )

        # Simplified answer: an excerpt of the top-ranked chunk.
        # In production, generate the answer with an LLM from context_text.
        if docs:
            excerpt = docs[0]['chunk_text']
            if len(excerpt) > 400:
                excerpt = excerpt[:400] + "..."
            answer = f"Based on {docs[0]['category']} guidelines:\n\n{excerpt}"
        else:
            answer = "No relevant guidance was found in the knowledge base."

        response = {
            'query': query,
            'context': context_text,
            'sources': [
                {
                    'category': doc['category'],
                    'title': doc['title'],
                    'similarity': doc['similarity']
                }
                for doc in docs
            ],
            'answer': answer
        }

        self.history.append(response)
        return response

# Initialize chatbot
chatbot = ClinicalChatbot(retriever)

print("✅ Clinical Chatbot ready")
```

## Cell 10: Interactive Demo

```python
demo_queries = [
    "What is metformin used for?",
    "Recommended blood pressure targets?",
    "When to start asthma prevention therapy?"
]

print("=" * 80)
print("CLINICAL CHATBOT DEMO")
print("=" * 80)

for query in demo_queries:
    print(f"\n{'='*80}")
    print(f"❓ Question: {query}")
    print('='*80)
    
    response = chatbot.ask(query)
    
    print(f"\n💡 Answer:\n{response['answer']}")
    
    print("\n📚 Sources:")
    for i, source in enumerate(response['sources'], 1):
        print(f"  {i}. {source['category']} - {source['title']} (Score: {source['similarity']:.3f})")
```

## Cell 11: Performance Metrics

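```python
import time

# Warm-up call so one-off initialisation is not counted
_ = retriever.retrieve("warm-up query", top_k=3)

# Benchmark end-to-end retrieval latency (query embedding + FAISS search);
# perf_counter is a high-resolution clock intended for short intervals
latencies = []

for _ in range(50):
    start = time.perf_counter()
    _ = retriever.retrieve("test query", top_k=3)
    latencies.append((time.perf_counter() - start) * 1000)

# With only 50 samples, P99 is interpolated from the two slowest runs
print("⏱️  Performance Metrics:")
print(f"  Mean latency: {np.mean(latencies):.2f} ms")
print(f"  P50 latency: {np.percentile(latencies, 50):.2f} ms")
print(f"  P95 latency: {np.percentile(latencies, 95):.2f} ms")
print(f"  P99 latency: {np.percentile(latencies, 99):.2f} ms")
```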
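
The same measurement can be wrapped in a reusable helper. It runs warm-up calls first and uses `time.perf_counter`. The sketch below assumes the timed work can be expressed as a zero-argument callable, and `benchmark` is a hypothetical name. With few samples the upper percentiles sit close to the maximum, so more iterations give a steadier tail estimate.

```python
import time
from typing import Callable, Dict

import numpy as np


def benchmark(fn: Callable[[], object], iterations: int = 50, warmup: int = 3) -> Dict[str, float]:
    """Time fn() repeatedly and return latency statistics in milliseconds."""
    for _ in range(warmup):  # exclude one-off costs such as lazy initialisation
        fn()
    latencies_ms = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        latencies_ms.append((time.perf_counter() - start) * 1000)
    arr = np.asarray(latencies_ms)
    return {
        "mean": float(arr.mean()),
        "p50": float(np.percentile(arr, 50)),
        "p95": float(np.percentile(arr, 95)),
        "p99": float(np.percentile(arr, 99)),
        "max": float(arr.max()),
    }


stats = benchmark(lambda: sum(range(10_000)), iterations=20)
print({name: round(value, 3) for name, value in stats.items()})
```

In the notebook this would be called as `benchmark(lambda: retriever.retrieve("test query", top_k=3))`.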

## Cell 12: Summary Statistics

```python
# Conversation history: one row per question asked in this session
history_df = pd.DataFrame([
    {
        'query': item['query'],
        'num_sources': len(item['sources']),
        'top_category': item['sources'][0]['category'],
        'top_similarity': item['sources'][0]['similarity']
    }
    for item in chatbot.history
])

print("📊 Chatbot Statistics:")
print(f"  Total queries: {len(chatbot.history)}")
print(f"  Knowledge base documents: {knowledge_df.count()}")
print(f"  Total chunks: {len(chunks_pdf)}")
print(f"  Embedding dimension: {dimension}")
print(f"  Average top-1 similarity: {history_df['top_similarity'].mean():.4f}")

print("\n📈 Top-1 Category Distribution:")
print(history_df['top_category'].value_counts())

print("\n" + "="*80)
print("✅ Project 4 Complete - Clinical Chatbot Functional")
print("="*80)
```