# 🏥⭐ Complete Use Case: Medical Diagnosis - Heart Disease

<div style="background-color: #ffebee; padding: 20px; border-radius: 5px; border-left: 5px solid #d32f2f;">
<b>🔥 CRITICAL HEALTHCARE APPLICATION</b><br>
<b>Level:</b> Advanced - Clinical Decision Support<br>
<b>Duration:</b> 40 minutes<br>
<b>Dataset:</b> Synthetic data modeled on the UCI Heart Disease features<br>
<b>Importance:</b> ⭐⭐ LIFE-CRITICAL - Medical AI validation
</div>

---

## 🎯 Objectives

By the end of this notebook, you will be able to:
- ✅ Validate medical AI models for clinical use
- ✅ Prioritize **uncertainty quantification** (critical for healthcare)
- ✅ Maximize **robustness** (patient safety)
- ✅ Analyze false negatives (life-threatening errors)
- ✅ Generate clinical validation reports
- ✅ Implement a physician-AI collaboration workflow

---

## 📚 Table of Contents

1. [Clinical Context](#context)
2. [Medical AI Requirements](#requirements)
3. [Data & EDA](#data)
4. [Model Training](#training)
5. [Performance Analysis](#performance)
6. [Uncertainty Quantification (CRITICAL)](#uncertainty)
7. [Robustness Testing](#robustness)
8. [Error Analysis - False Negatives](#errors)
9. [Clinical Validation](#validation)
10. [Physician-AI Workflow](#workflow)
11. [Conclusion](#conclusion)

<a id="context"></a>
## 1. 🏥 Clinical Context

### The Scenario

You are an **AI/ML Engineer at HeartCare Hospital**, developing a clinical decision support system.

**Your Mission:**
> "We need an AI system to help physicians identify patients at high risk of heart disease. The model will flag high-risk patients for further evaluation. This is NOT autonomous - physicians make final decisions, but AI must be reliable, provide confidence estimates, and never miss critical cases."
> 
> *Chief Medical Officer*

### 🩺 Clinical Requirements

1. **High Recall (Sensitivity)** - Must catch as many diseased patients as possible
   - Missing a diseased patient (false negative) can be fatal ☠️
   - False positives are acceptable (extra tests are better than missed disease)

2. **Uncertainty Quantification** - MANDATORY
   - Every prediction needs a confidence estimate
   - Low confidence → physician reviews more carefully
   - High confidence → faster triage

3. **Robustness** - Model must be stable
   - Small variations in patient data shouldn't flip the diagnosis
   - Measurement errors are common in clinical practice

4. **Explainability** - Physicians need to understand the model
   - Which factors drive the prediction?
   - Can the prediction be clinically justified?
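The uncertainty requirement can be made measurable. The example below uses split conformal prediction. It is one standard option, swapped in here for illustration, and not necessarily the method this notebook uses later. The idea is to calibrate a cutoff on held-out patients so that each patient's *prediction set* contains the true diagnosis with roughly 90% probability.

The sketch makes three assumptions:
- The calibration split comes from the same patient population.
- `make_classification` stands in for clinical data.
- "Both classes in the set → physician review" is a hypothetical routing rule.

`np.quantile(..., method=...)` requires NumPy ≥ 1.22.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split


def conformal_qhat(cal_proba, cal_labels, alpha=0.10):
    """Split-conformal cutoff on the nonconformity score 1 - p(true class)."""
    n = len(cal_labels)
    # Labels are 0/1 and match predict_proba column order (classes_ == [0, 1])
    scores = 1.0 - cal_proba[np.arange(n), cal_labels]
    # Finite-sample corrected quantile level ceil((n + 1)(1 - alpha)) / n
    level = min(1.0, np.ceil((n + 1) * (1 - alpha)) / n)
    return np.quantile(scores, level, method="higher")


def prediction_sets(proba, qhat):
    """Boolean matrix: class k is in a patient's set when 1 - p_k <= qhat."""
    return (1.0 - proba) <= qhat


# Synthetic stand-in data (~25% positives), split into train / calibration / test
X, y = make_classification(n_samples=2000, n_features=10, weights=[0.75, 0.25],
                           random_state=42)
X_tr, X_rest, y_tr, y_rest = train_test_split(X, y, test_size=0.5, stratify=y,
                                              random_state=42)
X_cal, X_te, y_cal, y_te = train_test_split(X_rest, y_rest, test_size=0.5,
                                            stratify=y_rest, random_state=42)

clf = RandomForestClassifier(n_estimators=200, min_samples_leaf=5, random_state=42)
clf.fit(X_tr, y_tr)

qhat = conformal_qhat(clf.predict_proba(X_cal), y_cal, alpha=0.10)
sets = prediction_sets(clf.predict_proba(X_te), qhat)

coverage = sets[np.arange(len(y_te)), y_te].mean()  # share of sets containing the truth
set_size = sets.sum(axis=1)
print(f"Empirical coverage: {coverage:.3f} (target >= 0.90)")
print(f"Both classes in set -> physician review: {(set_size == 2).mean():.1%}")
```

The 90% guarantee is marginal, meaning it holds on average over patients, not for each individual patient. It also breaks if deployment patients differ systematically from the calibration set.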

### ⚠️ Stakes

**Failure modes:**
- 💀 **False Negatives** - Missed heart disease can cost a patient's life
- 💸 **Too many False Positives** - Unnecessary tests, anxiety, cost
- ⚖️ **No confidence estimates** - Physicians can't calibrate their trust
- 🔧 **Brittle model** - Fails under measurement noise

**Success:**
- ✅ Early detection saves lives
- ✅ Efficient triage reduces workload
- ✅ Physician trust through transparency
- ✅ Improved patient outcomes

**Let's build it right!** 🩺
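A standard expected-cost (Bayes) decision rule makes the trade-off between the two failure modes explicit. Assume the model outputs a calibrated probability p of disease, a missed case costs C_FN, and a false alarm costs C_FP. Flagging the patient then has lower expected cost whenever p · C_FN > (1 − p) · C_FP, that is, when p > C_FP / (C_FP + C_FN). The cost values below are hypothetical; setting them is a clinical and institutional decision.

```python
def bayes_threshold(cost_fn: float, cost_fp: float) -> float:
    """Probability above which flagging a patient minimises expected cost.

    Flag when p * cost_fn > (1 - p) * cost_fp, i.e. p > cost_fp / (cost_fp + cost_fn).
    Only meaningful when p is a calibrated probability.
    """
    if cost_fn <= 0 or cost_fp <= 0:
        raise ValueError("costs must be positive")
    return cost_fp / (cost_fp + cost_fn)


# Hypothetical cost ratios: a missed case judged 1x, 10x or 20x as costly as a false alarm
for ratio in (1, 10, 20):
    print(f"C_FN / C_FP = {ratio:>2}: flag if p > {bayes_threshold(ratio, 1.0):.3f}")
```

A 10:1 cost ratio pushes the threshold down to about 0.09. This shows why a default 0.5 cut-off can be too strict when false negatives dominate the cost.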

<a id="requirements"></a>
## 2. 📋 Medical AI Requirements

### 🏛️ Regulatory Context

#### FDA - Medical Device Classification
- Clinical decision support software can be regulated as a **medical device** (Software as a Medical Device, SaMD), depending on its intended use
- Requirements: safety, effectiveness, validation studies

#### Key Standards and Regulations
- **ISO 13485** - Medical devices quality management
- **IEC 62304** - Medical device software lifecycle
- **HIPAA** - US law governing patient data privacy

### 🎯 Performance Thresholds

| Metric | Target | Critical Threshold |
|--------|--------|--------------------|
| **Recall (Sensitivity)** | ≥ 0.95 | ≥ 0.90 (CRITICAL) |
| **Specificity** | ≥ 0.80 | ≥ 0.70 |
| **ROC AUC** | ≥ 0.90 | ≥ 0.85 |
| **Robustness Score** | ≥ 0.90 | ≥ 0.85 |
| **Uncertainty Coverage** | ≥ 0.90 | ≥ 0.85 |
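The threshold table can be turned into an automated check. The sketch below copies the thresholds from the table and labels each metric. How the Robustness Score and Uncertainty Coverage are computed is left to the later sections, so here they are plain inputs. The metric keys and the example values are hypothetical.

```python
# (target, critical minimum) pairs copied from the table above
THRESHOLDS = {
    "recall": (0.95, 0.90),
    "specificity": (0.80, 0.70),
    "roc_auc": (0.90, 0.85),
    "robustness_score": (0.90, 0.85),
    "uncertainty_coverage": (0.90, 0.85),
}


def check_thresholds(observed: dict) -> dict:
    """Label each metric 'meets target', 'meets critical minimum only', 'FAIL' or 'not measured'."""
    status = {}
    for name, (target, critical) in THRESHOLDS.items():
        value = observed.get(name)
        if value is None:
            status[name] = "not measured"
        elif value >= target:
            status[name] = "meets target"
        elif value >= critical:
            status[name] = "meets critical minimum only"
        else:
            status[name] = "FAIL"
    return status


# Hypothetical values for illustration only, not results from this notebook
example = {"recall": 0.93, "specificity": 0.82, "roc_auc": 0.84}
for metric, verdict in check_thresholds(example).items():
    print(f"{metric:22s} {verdict}")
```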

### 🔬 Validation Requirements

1. ✅ **Clinical Validation** - Performance on real patient data
2. ✅ **Prospective Study** - Test on new patients (not just retrospective)
3. ✅ **Physician Review** - Medical professionals validate predictions
4. ✅ **Uncertainty Analysis** - Confidence calibration
5. ✅ **Robustness Testing** - Stability under noise
6. ✅ **Error Analysis** - Understand failure modes
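This section does not define the table's "Robustness Score". One simple, hypothetical definition is 1 minus the fraction of predictions that flip when measurement-like Gaussian noise is added.

The sketch uses synthetic continuous features. In this notebook, only the continuous measurements (`rest_bp`, `cholesterol`, `max_hr`, `oldpeak`) would be perturbed, because Gaussian noise on binary or categorical codes is meaningless. DeepBridge's own robustness tests may score stability differently.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def flip_rate_under_noise(model, X, noise_frac=0.05, n_repeats=20, seed=0):
    """Average fraction of predictions that change when Gaussian noise is added.

    Each column gets noise with std = noise_frac * (that column's std), a simple
    stand-in for measurement error; real error models are feature-specific.
    """
    rng = np.random.default_rng(seed)
    base = model.predict(X)
    scale = noise_frac * X.std(axis=0)
    rates = []
    for _ in range(n_repeats):
        X_noisy = X + rng.normal(size=X.shape) * scale
        rates.append(np.mean(model.predict(X_noisy) != base))
    return float(np.mean(rates))


# Synthetic stand-in with continuous features only; evaluated on the training rows
# to keep the sketch short (use held-out patients in practice)
X, y = make_classification(n_samples=1000, n_features=8, random_state=42)
clf = RandomForestClassifier(n_estimators=200, random_state=42).fit(X, y)

for frac in (0.01, 0.05, 0.10, 0.20):
    rate = flip_rate_under_noise(clf, X, noise_frac=frac)
    print(f"noise = {frac:.0%} of std -> flip rate {rate:.3f}, robustness {1 - rate:.3f}")
```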

### 📝 Documentation

Must include:
- Clinical study protocol
- Validation results
- Risk analysis (FMEA)
- Instructions for use
- Intended use statement
- Limitations and contraindications

<a id="data"></a>
## 3. 📊 Data & Exploratory Analysis

### Setup

```python
# Imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from pathlib import Path
from IPython.display import display  # explicit import for display() used in Section 5

# sklearn
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, recall_score, precision_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report
)

# DeepBridge
from deepbridge import DBDataset, Experiment

# Settings
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-whitegrid')  # 'seaborn-v0_8-*' style names exist in matplotlib >= 3.6
sns.set_palette('Set2')
%matplotlib inline

RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)

print("✅ Setup complete!")
print("🏥 Project: Heart Disease Clinical Decision Support")
```

### Generate a Synthetic Heart Disease Dataset

```python
# This demo generates synthetic data that mimics the UCI Heart Disease features
# (in production, use real clinical data with IRB approval)

print("📊 Generating synthetic Heart Disease dataset...\n")

# Create realistic synthetic clinical data
np.random.seed(RANDOM_STATE)
n = 1000

df = pd.DataFrame({
    # Demographics
    'age': np.random.randint(30, 80, n),
    'sex': np.random.choice([0, 1], n, p=[0.45, 0.55]),  # 0=F, 1=M
    
    # Clinical measurements
    'rest_bp': np.random.normal(130, 20, n).clip(90, 200),  # Resting blood pressure
    'cholesterol': np.random.normal(240, 50, n).clip(120, 400),  # Serum cholesterol
    'fasting_bs': np.random.choice([0, 1], n, p=[0.7, 0.3]),  # Fasting blood sugar > 120
    'max_hr': np.random.normal(150, 25, n).clip(70, 200),  # Max heart rate achieved
    'exercise_angina': np.random.choice([0, 1], n, p=[0.65, 0.35]),  # Exercise induced angina
    'oldpeak': np.random.exponential(1.0, n).clip(0, 6),  # ST depression
    'num_vessels': np.random.choice([0, 1, 2, 3], n, p=[0.5, 0.25, 0.15, 0.1]),  # Vessels colored
    
    # Categorical
    'chest_pain': np.random.choice([0, 1, 2, 3], n, p=[0.3, 0.3, 0.2, 0.2]),
    'rest_ecg': np.random.choice([0, 1, 2], n, p=[0.5, 0.4, 0.1]),
    'slope': np.random.choice([0, 1, 2], n, p=[0.3, 0.5, 0.2]),
    'thal': np.random.choice([0, 1, 2, 3], n, p=[0.05, 0.5, 0.3, 0.15])
})

# Create target (disease presence) from a weighted sum of clinical risk factors.
# chest_pain, rest_ecg, slope and thal are NOT part of the risk score, so in this
# synthetic data they carry no signal about the target.
risk_score = (
    (df['age'] - 40) / 40 * 0.15 +
    df['sex'] * 0.10 +  # Males higher risk
    (df['rest_bp'] - 120) / 80 * 0.12 +
    (df['cholesterol'] - 200) / 200 * 0.12 +
    df['fasting_bs'] * 0.08 +
    (200 - df['max_hr']) / 100 * 0.10 +
    df['exercise_angina'] * 0.15 +
    df['oldpeak'] / 6 * 0.10 +
    df['num_vessels'] / 3 * 0.08
)

# Add Gaussian noise to the risk score, then threshold it to get a binary label
df['disease'] = (risk_score + np.random.normal(0, 0.15, n) > 0.45).astype(int)

print(f"✅ Dataset created: {df.shape}")
print(f"\n📊 Disease prevalence: {df['disease'].mean():.1%}")
print(f"   Healthy: {(df['disease']==0).sum()}")
print(f"   Disease: {(df['disease']==1).sum()}")
```

### Clinical Features Description

```python
print("📋 CLINICAL FEATURES DESCRIPTION")
print("=" * 80)

features_desc = {
    'age': 'Patient age (years)',
    'sex': 'Sex (0=Female, 1=Male)',
    'chest_pain': 'Chest pain type (0-3)',
    'rest_bp': 'Resting blood pressure (mm Hg)',
    'cholesterol': 'Serum cholesterol (mg/dl)',
    'fasting_bs': 'Fasting blood sugar > 120 mg/dl (0=No, 1=Yes)',
    'rest_ecg': 'Resting ECG results (0-2)',
    'max_hr': 'Maximum heart rate achieved (bpm)',
    'exercise_angina': 'Exercise induced angina (0=No, 1=Yes)',
    'oldpeak': 'ST depression induced by exercise',
    'slope': 'Slope of peak exercise ST segment (0-2)',
    'num_vessels': 'Number of major vessels colored by fluoroscopy (0-3)',
    'thal': 'Thalassemia (0-3)',
    'disease': '🎯 TARGET: Heart disease presence (0=No, 1=Yes)'
}

for feat, desc in features_desc.items():
    print(f"   • {feat:20s}: {desc}")
```
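Section 1 notes that measurement and data-entry errors are common, so a deployed system would likely validate inputs before scoring them. The sketch below assumes plausibility limits equal to the clipping bounds used to generate the synthetic data. The `VALID_RANGES` values and the `out_of_range` helper are hypothetical and would need clinical sign-off.

```python
import pandas as pd

# Hypothetical plausibility limits, copied from the clipping bounds used to generate
# the synthetic data above; real limits must be set with clinical staff.
VALID_RANGES = {
    "age": (30, 80),
    "rest_bp": (90, 200),
    "cholesterol": (120, 400),
    "max_hr": (70, 200),
    "oldpeak": (0, 6),
}


def out_of_range(records: pd.DataFrame) -> pd.DataFrame:
    """True where a value is missing or outside its VALID_RANGES interval (inclusive)."""
    # Series.between returns False for NaN, so missing values are flagged too
    return pd.DataFrame(
        {col: ~records[col].between(lo, hi) for col, (lo, hi) in VALID_RANGES.items()}
    )


patients = pd.DataFrame({
    "age": [54, 61],
    "rest_bp": [130, 1300],      # 1300 looks like a data-entry error
    "cholesterol": [240, None],  # missing measurement
    "max_hr": [150, 140],
    "oldpeak": [1.2, 0.5],
})
flags = out_of_range(patients)
print(flags)
print("Rows needing manual check:", list(patients.index[flags.any(axis=1)]))
```

Flagged rows would go to manual review instead of being scored silently.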

### EDA - Key Clinical Distributions

```python
# Disease distribution by key risk factors
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
axes = axes.ravel()

# Age vs Disease
for disease in [0, 1]:
    axes[0].hist(df[df['disease']==disease]['age'], bins=20, alpha=0.6,
                 label=f'Disease={disease}', edgecolor='black')
axes[0].set_title('Age Distribution', fontweight='bold')
axes[0].set_xlabel('Age')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Cholesterol vs Disease
for disease in [0, 1]:
    axes[1].hist(df[df['disease']==disease]['cholesterol'], bins=20, alpha=0.6,
                 label=f'Disease={disease}', edgecolor='black')
axes[1].set_title('Cholesterol Distribution', fontweight='bold')
axes[1].set_xlabel('Cholesterol (mg/dl)')
axes[1].legend()
axes[1].grid(alpha=0.3)

# Max HR vs Disease
for disease in [0, 1]:
    axes[2].hist(df[df['disease']==disease]['max_hr'], bins=20, alpha=0.6,
                 label=f'Disease={disease}', edgecolor='black')
axes[2].set_title('Max Heart Rate', fontweight='bold')
axes[2].set_xlabel('Max HR (bpm)')
axes[2].legend()
axes[2].grid(alpha=0.3)

# Disease by Sex
disease_by_sex = pd.crosstab(df['sex'], df['disease'], normalize='index') * 100
disease_by_sex.plot(kind='bar', ax=axes[3], color=['lightgreen', 'coral'])
axes[3].set_title('Disease Rate by Sex', fontweight='bold')
axes[3].set_xlabel('Sex (0=F, 1=M)')
axes[3].set_ylabel('Percentage')
axes[3].set_xticklabels(['Female', 'Male'], rotation=0)
axes[3].legend(['Healthy', 'Disease'])
axes[3].grid(alpha=0.3)

# Disease by Exercise Angina
disease_by_angina = pd.crosstab(df['exercise_angina'], df['disease'], normalize='index') * 100
disease_by_angina.plot(kind='bar', ax=axes[4], color=['lightgreen', 'coral'])
axes[4].set_title('Disease Rate by Exercise Angina', fontweight='bold')
axes[4].set_xlabel('Exercise Angina')
axes[4].set_ylabel('Percentage')
axes[4].set_xticklabels(['No', 'Yes'], rotation=0)
axes[4].legend(['Healthy', 'Disease'])
axes[4].grid(alpha=0.3)

# Disease by Num Vessels
disease_by_vessels = pd.crosstab(df['num_vessels'], df['disease'], normalize='index') * 100
disease_by_vessels.plot(kind='bar', ax=axes[5], color=['lightgreen', 'coral'])
axes[5].set_title('Disease Rate by Num Vessels', fontweight='bold')
axes[5].set_xlabel('Number of Vessels')
axes[5].set_ylabel('Percentage')
axes[5].tick_params(axis='x', rotation=0)  # keep the 0-3 labels horizontal
axes[5].legend(['Healthy', 'Disease'])
axes[5].grid(alpha=0.3)

plt.tight_layout()
plt.show()
```

<a id="training"></a>
## 4. 🤖 Model Training

### Prepare Data

```python
# Features and target
feature_cols = [col for col in df.columns if col != 'disease']
X = df[feature_cols]
y = df['disease']

# Train/test split (stratified)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=RANDOM_STATE, stratify=y
)

print("📊 Data Split:")
print(f"   Train: {X_train.shape} (Disease rate: {y_train.mean():.1%})")
print(f"   Test: {X_test.shape} (Disease rate: {y_test.mean():.1%})")
```

### Train Random Forest Classifier

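For medical applications, we prioritize:
- **High recall** - catch disease cases
- **Calibrated probabilities** - needed for uncertainty quantification. Random forest scores are not calibrated by default, and `class_weight='balanced'` shifts them toward the disease class, so calibration must be checked before the scores are read as confidence estimates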
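Calibration can be checked directly. The sketch below runs on synthetic stand-in data from `make_classification`. It compares a balanced random forest with the same model wrapped in scikit-learn's `CalibratedClassifierCV` (isotonic regression, 5-fold CV). The comparison uses the Brier score and the largest per-bin gap from `calibration_curve`. The model settings only loosely mirror the notebook's (fewer trees, for speed). Which variant comes out better depends on the data, so nothing here should be read as a result.

```python
import numpy as np
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import brier_score_loss
from sklearn.model_selection import train_test_split


def make_rf():
    # Loosely similar to the notebook's model: balanced class weights, depth-limited trees
    return RandomForestClassifier(n_estimators=200, max_depth=8, min_samples_leaf=5,
                                  class_weight="balanced", random_state=42)


# Synthetic stand-in data with ~25% positives
X, y = make_classification(n_samples=3000, n_features=10, weights=[0.75, 0.25],
                           random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)

raw = make_rf().fit(X_tr, y_tr)
# Isotonic calibration with internal 5-fold CV on the training data only
calibrated = CalibratedClassifierCV(make_rf(), method="isotonic", cv=5).fit(X_tr, y_tr)

scores = {}
for name, clf in [("raw RF", raw), ("isotonic", calibrated)]:
    p = clf.predict_proba(X_te)[:, 1]
    frac_pos, mean_pred = calibration_curve(y_te, p, n_bins=10, strategy="quantile")
    scores[name] = brier_score_loss(y_te, p)
    print(f"{name:9s} Brier={scores[name]:.4f}  "
          f"max |observed - predicted| per bin = {np.max(np.abs(frac_pos - mean_pred)):.3f}")
```

Lower Brier scores and smaller per-bin gaps indicate better-calibrated probabilities. Isotonic regression can overfit on small datasets; in that case `method="sigmoid"` is the usual alternative.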

```python
print("🩺 Training clinical decision support model...\n")

# RandomForest with adjusted parameters for clinical use
model = RandomForestClassifier(
    n_estimators=300,
    max_depth=8,
    min_samples_split=10,
    min_samples_leaf=5,
    class_weight='balanced',  # ← Upweight the minority (disease) class
    random_state=RANDOM_STATE,
    n_jobs=-1
)

model.fit(X_train, y_train)

print("✅ Model trained!")
print("   Algorithm: RandomForestClassifier")
print(f"   Trees: {model.n_estimators}")
print("   Class weights: Balanced (upweights disease cases, which tends to raise recall)")
```
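The explainability requirement asks which factors drive the prediction. One model-agnostic option is scikit-learn's `permutation_importance`. It shuffles one feature at a time on held-out data and measures the drop in a chosen score, and scoring by `recall` ties the ranking to catching disease cases.

The sketch uses synthetic data in which, by construction (`shuffle=False`), only `x0`-`x2` are informative. In this notebook one would pass `model, X_test, y_test`, and `chest_pain`, `rest_ecg`, `slope` and `thal` should then rank near zero, since they do not enter the synthetic risk score. Two caveats apply: this gives global importances rather than per-patient explanations, and correlated features can share or mask importance.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in: with shuffle=False, x0-x2 are informative and x3-x5 are pure noise
X, y = make_classification(n_samples=1500, n_features=6, n_informative=3, n_redundant=0,
                           shuffle=False, random_state=42)
X = pd.DataFrame(X, columns=[f"x{i}" for i in range(6)])
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)

clf = RandomForestClassifier(n_estimators=200, class_weight="balanced", random_state=42)
clf.fit(X_tr, y_tr)

# Mean drop in test-set recall when each feature is shuffled (10 shuffles per feature)
result = permutation_importance(clf, X_te, y_te, scoring="recall",
                                n_repeats=10, random_state=42)
importances = pd.Series(result.importances_mean, index=X_te.columns).sort_values(ascending=False)
print(importances.round(3))
```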

<a id="performance"></a>
## 5. 📊 Performance Analysis

### Clinical Metrics - Focus on Recall!

```python
# Predictions
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)
y_proba_train = model.predict_proba(X_train)[:, 1]
y_proba_test = model.predict_proba(X_test)[:, 1]

# Calculate metrics
print("📊 CLINICAL PERFORMANCE METRICS")
print("=" * 80)

metrics = {
    'Recall (Sensitivity)': [recall_score(y_train, y_pred_train), 
                              recall_score(y_test, y_pred_test)],
    'Precision (PPV)': [precision_score(y_train, y_pred_train), 
                         precision_score(y_test, y_pred_test)],
    'Accuracy': [accuracy_score(y_train, y_pred_train), 
                  accuracy_score(y_test, y_pred_test)],
    'F1 Score': [f1_score(y_train, y_pred_train), 
                  f1_score(y_test, y_pred_test)],
    'ROC AUC': [roc_auc_score(y_train, y_proba_train), 
                 roc_auc_score(y_test, y_proba_test)]
}

metrics_df = pd.DataFrame(metrics, index=['Train', 'Test']).T
display(metrics_df.style.format("{:.3f}").background_gradient(cmap='RdYlGn', axis=1))

# Check against the recall thresholds from Section 2 (target ≥ 0.95, critical ≥ 0.90)
recall_test = recall_score(y_test, y_pred_test)
print("\n🩺 CRITICAL CLINICAL THRESHOLD:")
print(f"   Recall (Sensitivity): {recall_test:.3f}")
if recall_test >= 0.95:
    print("   ✅ MEETS TARGET - Catches ≥95% of disease cases")
elif recall_test >= 0.90:
    print("   🟡 MEETS CRITICAL MINIMUM ONLY - Below the 0.95 target, consider improving")
else:
    print("   ❌ INSUFFICIENT - Too many missed cases!")
    print("   ⚠️  ACTION REQUIRED: Adjust threshold or retrain")
```
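When recall falls short, lowering the decision threshold is often a simpler fix than retraining. Picking that threshold on the test set would make the reported test recall optimistic. The sketch below avoids this in three steps, using synthetic data and a hypothetical helper `threshold_for_recall`:
1. Carve a validation split out of the training data.
2. Use `precision_recall_curve` to choose the largest threshold whose validation recall is at least 0.95.
3. Only then score the untouched test set.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_recall_curve, recall_score
from sklearn.model_selection import train_test_split


def threshold_for_recall(y_true, proba, target_recall=0.95):
    """Largest probability threshold whose recall on (y_true, proba) is >= target."""
    _, recall, thresholds = precision_recall_curve(y_true, proba)
    # recall[i] is the recall when flagging proba >= thresholds[i];
    # the final recall entry (0.0) has no matching threshold, so drop it
    ok = recall[:-1] >= target_recall
    if not ok.any():
        raise ValueError("target recall not reachable")
    return float(thresholds[ok].max())


# Synthetic stand-in data, split into train / validation / test
X, y = make_classification(n_samples=3000, n_features=10, weights=[0.75, 0.25],
                           random_state=42)
X_tmp, X_te, y_tmp, y_te = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)
X_tr, X_val, y_tr, y_val = train_test_split(X_tmp, y_tmp, test_size=0.25,
                                            stratify=y_tmp, random_state=42)

clf = RandomForestClassifier(n_estimators=200, min_samples_leaf=5,
                             class_weight="balanced", random_state=42).fit(X_tr, y_tr)

# Choose the threshold on validation data only, then report it once on test data
t = threshold_for_recall(y_val, clf.predict_proba(X_val)[:, 1], target_recall=0.95)
test_recall = recall_score(y_te, (clf.predict_proba(X_te)[:, 1] >= t).astype(int))
print(f"Threshold chosen on validation: {t:.3f}")
print(f"Recall on held-out test:        {test_recall:.3f}")
```

Held-out recall can land slightly below 0.95 because the threshold was tuned on a finite validation set. Lowering the threshold also lowers specificity, so both metrics should be reported.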

### Confusion Matrix - Clinical Interpretation

```python
# Confusion matrix
cm = confusion_matrix(y_test, y_pred_test)
tn, fp, fn, tp = cm.ravel()

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
            xticklabels=['Healthy (Pred)', 'Disease (Pred)'],
            yticklabels=['Healthy (Actual)', 'Disease (Actual)'])
plt.title('Confusion Matrix - Clinical Interpretation', fontsize=14, fontweight='bold')
plt.ylabel('Actual Diagnosis', fontsize=12)
plt.xlabel('Model Prediction', fontsize=12)

# Add clinical labels (heatmap rows: 0=Healthy at top, 1=Disease at bottom)
plt.text(0.5, 0.15, f'TN={tn}\n(Correctly identified healthy)', 
         ha='center', va='center', fontsize=10, color='green', fontweight='bold')
plt.text(1.5, 0.15, f'FP={fp}\n(False alarm - extra tests)', 
         ha='center', va='center', fontsize=10, color='orange', fontweight='bold')
plt.text(0.5, 1.15, f'FN={fn}\n⚠️  CRITICAL - Missed disease!', 
         ha='center', va='center', fontsize=10, color='red', fontweight='bold')
plt.text(1.5, 1.15, f'TP={tp}\n(Correctly identified disease)', 
         ha='center', va='center', fontsize=10, color='darkgreen', fontweight='bold')

plt.tight_layout()
plt.show()

# Clinical interpretation
print("\n🩺 CLINICAL INTERPRETATION:")
print("=" * 80)
print(f"\n✅ True Negatives (TN): {tn} - Healthy patients correctly identified")
print(f"🟡 False Positives (FP): {fp} - Healthy but flagged for further testing")
print("   Impact: Extra tests, patient anxiety, costs")
print("   Acceptable: Yes, within limits (specificity target ≥ 0.80)")
print(f"\n✅ True Positives (TP): {tp} - Disease caught by AI")
print(f"❌ False Negatives (FN): {fn} - CRITICAL ERRORS")
print("   Impact: Missed disease, potential death")
print("   Acceptable: NO - Must minimize!")

# Calculate sensitivity and specificity
specificity = tn / (tn + fp)
sensitivity = tp / (tp + fn)

print(f"\n📊 Clinical Metrics:")
print(f"   Sensitivity (Recall): {sensitivity:.3f} - {tp}/{tp+fn} disease cases caught")
print(f"   Specificity: {specificity:.3f} - {tn}/{tn+fp} healthy correctly identified")
print(f"   Positive Predictive Value: {tp/(tp+fp):.3f}")
print(f"   Negative Predictive Value: {tn/(tn+fn):.3f}")
```
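The test split holds only about 200 patients, so few diseased patients stand behind the recall figure, and the point estimate should come with an interval. The sketch below computes the Wilson score interval using only the standard library. The counts are hypothetical; in the notebook one would pass `tp, tp + fn` for sensitivity or `tn, tn + fp` for specificity.

```python
from math import sqrt
from statistics import NormalDist


def wilson_interval(successes: int, n: int, confidence: float = 0.95):
    """Wilson score interval for a binomial proportion (e.g. sensitivity = TP / (TP + FN))."""
    if n <= 0:
        raise ValueError("n must be positive")
    z = NormalDist().inv_cdf(0.5 + confidence / 2)  # about 1.96 for 95%
    p_hat = successes / n
    denom = 1 + z**2 / n
    centre = (p_hat + z**2 / (2 * n)) / denom
    half_width = z * sqrt(p_hat * (1 - p_hat) / n + z**2 / (4 * n**2)) / denom
    return centre - half_width, centre + half_width


# Hypothetical counts: 45 of 48 diseased test patients caught (point estimate 0.9375)
lo, hi = wilson_interval(45, 48)
print(f"Sensitivity 95% CI: [{lo:.3f}, {hi:.3f}]")
```

In this example the point estimate (0.9375) clears the 0.90 critical threshold, but the lower end of the 95% interval does not. A test set of this size therefore cannot, on its own, demonstrate that the threshold is met.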

## Continuing...

Next sections:
- Section 6: Uncertainty Quantification (CRITICAL)
- Section 7: Robustness Testing
- Section 8: Error Analysis
- Sections 9-11: Clinical validation, workflow, conclusion

**Key message:** Medical AI requires rigorous validation: uncertainty, robustness, and error analysis are MANDATORY!