# 🏥 AI-Based Medical Diagnosis Assistant

## Complete ML Pipeline with Explainable AI (SHAP)

This notebook implements an end-to-end medical AI system that:
- Predicts disease risk probability from patient data
- Identifies top contributing factors (explainable AI)
- Handles class imbalance with SMOTE
- Compares Logistic Regression, Random Forest, and XGBoost models
- Uses SHAP for local and global explainability
- Prepares models for Streamlit deployment

**Date**: February 2026  
**Purpose**: Educational & Research Use Only

## Section 1: Import Required Libraries

Import all necessary packages for data processing, ML modeling, and explainability.

In [None]:
# Add parent directory to path to import custom modules
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

# Data processing
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.metrics import (
    confusion_matrix, precision_recall_curve, roc_auc_score, 
    roc_curve, auc, f1_score, classification_report
)

# Models
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
import xgboost as xgb

# Imbalance handling
from imblearn.over_sampling import SMOTE

# Explainability
import shap

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px

# Utilities
import warnings
warnings.filterwarnings('ignore')
import joblib
import json

# Confirm that every import above succeeded and report the installed
# XGBoost and SHAP versions for reproducibility.
# (The check-mark emoji was mis-encoded in the original string; restored here.)
print("✅ All libraries imported successfully!")
print(f"XGBoost version: {xgb.__version__}")
print(f"SHAP version: {shap.__version__}")

## Section 2: Load and Explore Medical Dataset

Load the medical dataset containing symptoms, vitals, and lab values. Perform exploratory data analysis.

In [None]:
# Project-local helpers (importable because the parent directory was put on
# sys.path in the imports cell)
from src.data_generator import generate_medical_dataset
from src.data_processor import MedicalDataProcessor, prepare_data

# Location of the cached dataset, relative to the notebook's working directory
data_path = Path("../data/medical_data_single_disease.csv")

# Reuse the cached CSV when it is present; otherwise synthesize a dataset
# and cache it for the next run
if not data_path.exists():
    print("Generating synthetic medical dataset...")
    df = generate_medical_dataset(n_samples=2000, imbalance_ratio=0.15)
    data_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(data_path, index=False)
    print(f"Dataset saved to {data_path}")
else:
    print(f"Loading dataset from {data_path}")
    df = pd.read_csv(data_path)

# Quick structural overview: dimensions, column names, and a sample of rows
print(f"\n{'=' * 60}")
print("DATASET OVERVIEW")
print("=" * 60)
print("\nShape:", df.shape)
print("Columns:", df.columns.tolist())
print("\nFirst few rows:")
print(df.head())

In [None]:
# Exploratory Data Analysis: summary statistics, class balance, data quality
print("\n" + "="*60)
print("DATA STATISTICS")
print("="*60)
print(df.describe())

print("\n" + "="*60)
print("TARGET VARIABLE DISTRIBUTION")
print("="*60)
target_counts = df['disease_risk'].value_counts()
print(target_counts)
print(f"\nDisease Prevalence: {df['disease_risk'].mean():.2%}")
# .loc makes the lookup explicitly label-based (class id 0 / class id 1)
print(f"Class Imbalance Ratio: {target_counts.loc[0]/target_counts.loc[1]:.2f}:1")

# Visualize class distribution.
# value_counts() orders rows by frequency, not by class id, so each slice
# label is derived from the class id it belongs to instead of assuming that
# the first row is the healthy class.
class_names = {0: 'Healthy', 1: 'Disease'}
slice_labels = [
    class_names.get(class_id, str(class_id)) for class_id in target_counts.index
]

fig = px.pie(
    values=target_counts.values,
    names=slice_labels,
    # color_discrete_map only takes effect when `color` is supplied
    color=slice_labels,
    title='Target Variable Distribution',
    color_discrete_map={'Healthy': '#2ca02c', 'Disease': '#d62728'}
)
fig.show()

# Check for missing values and how many distinct dtypes the frame contains
print(f"\n{'Missing Values:':30} {df.isnull().sum().sum()}")
print(f"{'Data Types:':30} {df.dtypes.nunique()} unique")

## Section 3: Data Preprocessing and Feature Engineering

Clean data, normalize features, and create engineered features from symptoms, vitals, and lab values.

In [None]:
# Section banner
print(f"\n{'=' * 60}")
print("DATA PREPROCESSING & FEATURE ENGINEERING")
print("=" * 60)

# Delegate splitting and preprocessing to the project helper; it hands back
# the train/test partitions together with a processor object that is reused
# later for single-patient predictions.
X_train, X_test, y_train, y_test, processor = prepare_data(
    df, target_column='disease_risk', test_size=0.2, random_state=42
)

# Feature count after engineering, and how many columns were added relative
# to the raw frame (all raw columns except the target)
print("\nFeatures after engineering:", X_train.shape[1])
print("New features added:", X_train.shape[1] - (len(df.columns) - 1))

# Full list of model input columns
print("\nFeature names:")
print(X_train.columns.tolist())

# Summary statistics of the training features
print("\nFeature Statistics (Training Set):")
print(X_train.describe())

## Section 4: Handle Class Imbalance with SMOTE

Apply SMOTE (Synthetic Minority Over-sampling Technique) to balance the dataset.

In [None]:
# SMOTE demonstration: compare class balance before and after oversampling
print(f"\n{'=' * 60}")
print("CLASS IMBALANCE HANDLING - SMOTE")
print("=" * 60)

# Class balance of the untouched training split
print("\nOriginal Training Set Distribution:")
for class_id, class_name in ((0, "Healthy"), (1, "Disease")):
    class_mask = y_train == class_id
    print(f"Class {class_id} ({class_name}): {np.sum(class_mask)} ({np.mean(class_mask):.1%})")

# Oversample the training split with SMOTE (operates on plain NumPy arrays)
smote = SMOTE(random_state=42, k_neighbors=5)
X_train_smote, y_train_smote = smote.fit_resample(X_train.values, y_train.values)

# Class balance of the resampled data
print("\nAfter SMOTE (for XGBoost training):")
for class_id, class_name in ((0, "Healthy"), (1, "Disease")):
    class_mask = y_train_smote == class_id
    print(f"Class {class_id} ({class_name}): {np.sum(class_mask)} ({np.mean(class_mask):.1%})")
print(f"\nTotal samples: {len(y_train_smote)} (increased from {len(y_train)})")

# Grouped bar chart of per-class counts, before vs after resampling
fig = go.Figure(
    data=[
        go.Bar(
            name='Before SMOTE',
            x=['Healthy', 'Disease'],
            y=[np.sum(y_train == class_id) for class_id in (0, 1)],
        ),
        go.Bar(
            name='After SMOTE',
            x=['Healthy', 'Disease'],
            y=[np.sum(y_train_smote == class_id) for class_id in (0, 1)],
        ),
    ]
)
fig.update_layout(
    title='Class Distribution: Before vs After SMOTE',
    barmode='group',
    height=400,
)
fig.show()

## Section 5: Train Logistic Regression Model (Baseline)

Build and train a Logistic Regression baseline model.

In [None]:
from src.models import LogisticRegressionModel

print(f"\n{'=' * 60}")
print("MODEL 1: LOGISTIC REGRESSION (BASELINE)")
print("=" * 60)

# Fit the baseline on the original (non-resampled) training split
lr_model = LogisticRegressionModel(max_iter=1000)
lr_model.train(X_train, y_train)

# Score the held-out split; the second return value is used below as the
# positive-class score for the ROC curve
y_pred_lr, y_proba_lr = lr_model.evaluate(X_test, y_test)

# Headline metrics are read from the wrapper's model_metrics mapping
print("\nModel Metrics:")
for metric_label, metric_key in (
    ("ROC-AUC", "roc_auc"),
    ("PR-AUC", "pr_auc"),
    ("F1-Score", "f1_score"),
):
    print(f"  {metric_label}: {lr_model.model_metrics[metric_key]:.4f}")

# ROC curve against the chance diagonal
fpr_lr, tpr_lr, _ = roc_curve(y_test, y_proba_lr)
fig_roc = go.Figure(
    data=[
        go.Scatter(
            x=fpr_lr,
            y=tpr_lr,
            mode='lines',
            name=f'Logistic Regression (AUC={lr_model.model_metrics["roc_auc"]:.3f})',
        ),
        go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Random Classifier'),
    ]
)
fig_roc.update_layout(
    title='ROC Curve - Logistic Regression',
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
)
fig_roc.show()

## Section 6: Train Random Forest Model

Implement and train a Random Forest classifier.

In [None]:
from src.models import RandomForestModel, get_feature_importance

print(f"\n{'=' * 60}")
print("MODEL 2: RANDOM FOREST")
print("=" * 60)

# Fit the forest on the original (non-resampled) training split
rf_model = RandomForestModel(n_estimators=100, max_depth=15)
rf_model.train(X_train, y_train)

# Score the held-out split; the second return value feeds the ROC curve below
y_pred_rf, y_proba_rf = rf_model.evaluate(X_test, y_test)

# Headline metrics are read from the wrapper's model_metrics mapping
print("\nModel Metrics:")
for metric_label, metric_key in (
    ("ROC-AUC", "roc_auc"),
    ("PR-AUC", "pr_auc"),
    ("F1-Score", "f1_score"),
):
    print(f"  {metric_label}: {rf_model.model_metrics[metric_key]:.4f}")

# Ten highest-ranked features according to the project helper
importance_df = get_feature_importance(rf_model, feature_names=X_train.columns, top_n=10)

# Horizontal bar chart of those importances
fig_imp = go.Figure(
    data=[
        go.Bar(
            x=importance_df['importance'],
            y=importance_df['feature'],
            orientation='h',
        )
    ]
)
fig_imp.update_layout(
    title='Top 10 Important Features - Random Forest',
    xaxis_title='Importance',
    yaxis_title='Feature',
    height=400,
)
fig_imp.show()

# ROC curve against the chance diagonal
fpr_rf, tpr_rf, _ = roc_curve(y_test, y_proba_rf)
fig_roc = go.Figure(
    data=[
        go.Scatter(
            x=fpr_rf,
            y=tpr_rf,
            mode='lines',
            name=f'Random Forest (AUC={rf_model.model_metrics["roc_auc"]:.3f})',
        ),
        go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Random Classifier'),
    ]
)
fig_roc.update_layout(
    title='ROC Curve - Random Forest',
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
)
fig_roc.show()

## Section 7: Train XGBoost Model with SMOTE

Build and train an XGBoost model with SMOTE balancing.

In [None]:
from src.models import XGBoostModel

print(f"\n{'=' * 60}")
print("MODEL 3: XGBOOST WITH SMOTE (BEST PERFORMANCE)")
print("=" * 60)

# Fit the booster; apply_smote=True asks the wrapper to rebalance the
# training data itself, so the original split is passed in
xgb_model = XGBoostModel(n_estimators=150, max_depth=5, learning_rate=0.1)
xgb_model.train(X_train, y_train, apply_smote=True)

# Score the held-out split; the second return value feeds the ROC curve below
y_pred_xgb, y_proba_xgb = xgb_model.evaluate(X_test, y_test)

# Headline metrics are read from the wrapper's model_metrics mapping
print("\nModel Metrics:")
for metric_label, metric_key in (
    ("ROC-AUC", "roc_auc"),
    ("PR-AUC", "pr_auc"),
    ("F1-Score", "f1_score"),
):
    print(f"  {metric_label}: {xgb_model.model_metrics[metric_key]:.4f}")

# Ten highest-ranked features according to the project helper
importance_df_xgb = get_feature_importance(xgb_model, feature_names=X_train.columns, top_n=10)

# Horizontal bar chart of those importances
fig_imp_xgb = go.Figure(
    data=[
        go.Bar(
            x=importance_df_xgb['importance'],
            y=importance_df_xgb['feature'],
            orientation='h',
            marker_color='#ff7f0e',
        )
    ]
)
fig_imp_xgb.update_layout(
    title='Top 10 Important Features - XGBoost',
    xaxis_title='Importance',
    yaxis_title='Feature',
    height=400,
)
fig_imp_xgb.show()

# ROC curve against the chance diagonal
fpr_xgb, tpr_xgb, _ = roc_curve(y_test, y_proba_xgb)
fig_roc_xgb = go.Figure(
    data=[
        go.Scatter(
            x=fpr_xgb,
            y=tpr_xgb,
            mode='lines',
            name=f'XGBoost (AUC={xgb_model.model_metrics["roc_auc"]:.3f})',
        ),
        go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Random Classifier'),
    ]
)
fig_roc_xgb.update_layout(
    title='ROC Curve - XGBoost',
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
)
fig_roc_xgb.show()

## Section 8: Model Evaluation & Precision-Recall Optimization

Calculate precision-recall curves and find optimal decision thresholds.

In [None]:
from src.data_processor import calculate_precision_recall_metrics

print("\n" + "="*60)
print("PRECISION-RECALL OPTIMIZATION")
print("="*60)

# Decision thresholds to tabulate, and the one reported as the recommendation
thresholds_to_test = [0.3, 0.4, 0.5, 0.6, 0.7]
recommended_threshold = 0.5

print("\nOptimal Threshold Analysis (XGBoost):")
print(f"{'Threshold':<12} {'Precision':<12} {'Recall':<12} {'F1-Score':<12}")
print("-" * 48)

# The helper returns per-threshold metric dicts, the PR-AUC, and the raw
# precision/recall curve arrays used for plotting below
pr_metrics, pr_auc, (precision, recall, pr_thresholds) = calculate_precision_recall_metrics(
    y_test, y_proba_xgb, thresholds=thresholds_to_test
)

for metric in pr_metrics:
    print(f"{metric['threshold']:<12.1f} {metric['precision']:<12.4f} {metric['recall']:<12.4f} {metric['f1']:<12.4f}")

# Precision-Recall Curve
fig_pr = go.Figure()
fig_pr.add_trace(go.Scatter(
    x=recall, y=precision, mode='lines', fill='tozeroy',
    name=f'XGBoost (PR-AUC={pr_auc:.3f})', line=dict(color='red')
))
fig_pr.update_layout(
    title='Precision-Recall Curve - XGBoost',
    xaxis_title='Recall (Sensitivity)',
    yaxis_title='Precision',
    xaxis=dict(range=[0, 1]),
    yaxis=dict(range=[0, 1])
)
fig_pr.show()

# Recommendation.
# Look the row up by its threshold value rather than by list position, so
# editing thresholds_to_test cannot silently report the wrong row.
recommended_metrics = next(
    (m for m in pr_metrics if np.isclose(m['threshold'], recommended_threshold)),
    None,
)
print(f"\n✅ Recommended Threshold: {recommended_threshold} (balanced precision-recall)")
if recommended_metrics is not None:
    print(
        f"   At threshold {recommended_threshold}: "
        f"Precision={recommended_metrics['precision']:.3f}, "
        f"Recall={recommended_metrics['recall']:.3f}"
    )
else:
    print(f"   Threshold {recommended_threshold} was not among the evaluated thresholds.")

## Section 9: Generate SHAP Explainability Values

Use SHAP to explain predictions and identify contributing factors.

In [None]:
from src.explainability import ModelExplainer

print("\n" + "="*60)
print("SHAP EXPLAINABILITY ANALYSIS")
print("="*60)

# Create explainer for XGBoost, using at most the first 100 training rows
# as the background sample
X_background = X_train.iloc[:min(100, len(X_train))]
explainer = ModelExplainer(xgb_model, X_background, feature_names=X_train.columns.tolist())
explainer.create_explainer(explainer_type='tree')

print("\nGenerating SHAP values for test set...")
shap_values = explainer.explain_dataset(X_test, max_samples=200)

# Get feature importance from SHAP
print("\nGlobal Feature Importance (from SHAP):")
shap_importance = explainer.get_feature_importance_from_shap(top_n=10)
print(shap_importance)

# Example: Explain a high-risk prediction
print("\n" + "="*60)
print("EXAMPLE: EXPLAINING HIGH-RISK PATIENT")
print("="*60)

# Find a correctly flagged high-risk patient: predicted probability above 0.7
# AND a true disease label. Both operands are converted to plain arrays so the
# mask is purely positional and the resulting indices are valid for .iloc.
high_risk_mask = (np.asarray(y_proba_xgb) > 0.7) & (np.asarray(y_test) == 1)
high_risk_idx = np.where(high_risk_mask)[0]
if len(high_risk_idx) > 0:
    idx = high_risk_idx[0]
    explanation = explainer.explain_prediction(X_test.iloc[idx:idx+1], top_features=5)

    print(f"\nPatient Index: {X_test.index[idx]}")
    print(f"Predicted Risk: {explanation['prediction_probability']:.1%}")
    print(f"True Label: {'Disease' if y_test.iloc[idx] == 1 else 'Healthy'}")

    print(f"\nTop Contributing Factors:")
    for i, factor in enumerate(explanation['top_contributing_factors'], 1):
        print(f"  {i}. {factor['feature']}")
        print(f"     Value: {factor['feature_value']:.2f}")
        print(f"     SHAP Value: {factor['shap_value']:.4f} ({factor['direction'].replace('_', ' ')})")
        print()
else:
    # Previously the cell printed only the banner and nothing else in this case
    print("\nNo test patient with predicted risk above 70% and a true disease label was found.")

## Section 10: Compare Model Performance

Create comparison metrics across all three models.

In [None]:
print("\n" + "="*60)
print("MODEL PERFORMANCE COMPARISON")
print("="*60)

# Display name -> trained wrapper, and table column -> model_metrics key.
# Driving the table from these two mappings keeps every model/metric pairing
# in one place.
compared_models = {
    'Logistic Regression': lr_model,
    'Random Forest': rf_model,
    'XGBoost + SMOTE': xgb_model,
}
metric_keys = {'ROC-AUC': 'roc_auc', 'PR-AUC': 'pr_auc', 'F1-Score': 'f1_score'}
metric_columns = list(metric_keys)

# Create comparison dataframe
comparison_data = {'Model': list(compared_models)}
for column_name, metrics_key in metric_keys.items():
    comparison_data[column_name] = [
        model.model_metrics[metrics_key] for model in compared_models.values()
    ]

df_comparison = pd.DataFrame(comparison_data)
print("\n" + df_comparison.to_string(index=False))

# Radar chart
fig_radar = go.Figure()

for idx, model_name in enumerate(df_comparison['Model']):
    fig_radar.add_trace(go.Scatterpolar(
        r=[df_comparison.loc[idx, column_name] for column_name in metric_columns],
        theta=metric_columns,
        fill='toself',
        name=model_name
    ))

# A fixed lower bound of 0.7 hides any metric below it (F1 / PR-AUC on an
# imbalanced problem frequently are). Keep 0.7 when every score fits, and
# otherwise widen the axis down to the next lower tenth.
lowest_score = float(df_comparison[metric_columns].min().min())
radial_floor = float(min(0.7, np.floor(lowest_score * 10) / 10))

fig_radar.update_layout(
    polar=dict(radialaxis=dict(visible=True, range=[radial_floor, 1])),
    title='Model Performance Comparison (Radar Chart)',
    height=600
)
fig_radar.show()

# Bar chart: one bar group per model, one bar per metric
fig_bar = go.Figure()
for column_name in metric_columns:
    fig_bar.add_trace(go.Bar(
        x=df_comparison['Model'], y=df_comparison[column_name], name=column_name
    ))
fig_bar.update_layout(barmode='group', title='Model Metrics Comparison', height=400)
fig_bar.show()

# Best model, ranked by ROC-AUC only
# (The trophy emoji was mis-encoded in the original string; restored here.)
best_idx = df_comparison['ROC-AUC'].idxmax()
print(f"\n🏆 Best Model: {df_comparison.loc[best_idx, 'Model']} (ROC-AUC: {df_comparison.loc[best_idx, 'ROC-AUC']:.4f})")

## Section 11: Create Prediction Function

Develop a prediction function with risk factors for user input.

In [None]:
def predict_disease_risk(patient_data):
    """
    Predict disease risk for a single patient.

    Uses objects created in earlier notebook cells: ``processor``,
    ``lr_model``, ``rf_model``, ``xgb_model`` and ``explainer``.

    Args:
        patient_data: Dict mapping raw feature names (symptoms, vitals, lab
            values) to numeric values for one patient.

    Returns:
        Dict with:
            - 'ensemble_probability': mean of the three model probabilities
            - 'risk_level': "LOW" (< 0.3), "MEDIUM" (< 0.7) or "HIGH"
            - 'individual_predictions': per-model positive-class probability
            - 'top_contributing_factors': top-5 factors from the SHAP explainer
        All probabilities are plain Python floats so the result can be
        serialized (e.g. with ``json.dumps``) without NumPy-specific handling.

    Raises:
        ValueError: If ``patient_data`` is empty.
    """
    if not patient_data:
        raise ValueError("patient_data must contain at least one feature value")

    # One-row frame: each dict key becomes a column
    patient_df = pd.DataFrame([patient_data])

    # Only interaction features are added here.
    # NOTE(review): the training features came from prepare_data(); confirm
    # that add_interaction_features() alone reproduces the same preprocessing
    # (any scaling, column order) the models were trained on.
    patient_features = processor.add_interaction_features(patient_df)

    # Positive-class probability from each model, converted from NumPy
    # scalars to built-in floats
    proba_lr = float(lr_model.predict_proba(patient_features)[0, 1])
    proba_rf = float(rf_model.predict_proba(patient_features)[0, 1])
    proba_xgb = float(xgb_model.predict_proba(patient_features)[0, 1])

    # Ensemble = unweighted average of the three models
    ensemble_prob = float(np.mean([proba_lr, proba_rf, proba_xgb]))

    # SHAP explanation comes from the explainer built for the XGBoost model
    explanation = explainer.explain_prediction(patient_features, top_features=5)

    # Bucket the ensemble probability into a coarse risk level
    if ensemble_prob < 0.3:
        risk_level = "LOW"
    elif ensemble_prob < 0.7:
        risk_level = "MEDIUM"
    else:
        risk_level = "HIGH"

    return {
        'ensemble_probability': ensemble_prob,
        'risk_level': risk_level,
        'individual_predictions': {
            'logistic_regression': proba_lr,
            'random_forest': proba_rf,
            'xgboost': proba_xgb
        },
        'top_contributing_factors': explanation['top_contributing_factors']
    }

# Worked example: run the prediction function on one hand-written patient
print(f"\n{'=' * 60}")
print("EXAMPLE: PREDICT FOR NEW PATIENT")
print("=" * 60)

# Raw feature values for the example patient, grouped by kind
example_patient = {
    # Symptoms
    'chest_pain_severity': 5,
    'shortness_of_breath': 35,
    'fatigue_level': 40,
    'dizziness': 25,
    'headache_frequency': 15,
    'nausea_level': 20,
    # Vitals
    'systolic_bp': 145,
    'diastolic_bp': 95,
    'heart_rate': 92,
    'body_temperature': 37.2,
    'respiratory_rate': 18,
    'oxygen_saturation': 96,
    # Lab values
    'cholesterol_total': 240,
    'ldl_cholesterol': 160,
    'hdl_cholesterol': 35,
    'triglycerides': 180,
    'glucose_fasting': 140,
    'hemoglobin_a1c': 7.2,
    'creatinine': 1.1,
    'white_blood_cells': 7.5,
}

result = predict_disease_risk(example_patient)

# Overall assessment
print("\nPatient Risk Assessment:")
print(f"  Ensemble Probability: {result['ensemble_probability']:.1%}")
print(f"  Risk Level: {result['risk_level']}")

# Per-model probabilities, in the same order as the original report
print("\nIndividual Model Predictions:")
for display_name, prediction_key in (
    ("Logistic Regression", "logistic_regression"),
    ("Random Forest", "random_forest"),
    ("XGBoost", "xgboost"),
):
    print(f"  {display_name}: {result['individual_predictions'][prediction_key]:.1%}")

# Ranked contributing factors returned with the prediction
print("\nTop Contributing Risk Factors:")
for i, factor in enumerate(result['top_contributing_factors'], 1):
    print(f"  {i}. {factor['feature']}: {factor['feature_value']:.2f} ({factor['direction']})")

## Section 12: Prepare Model for Streamlit Deployment

Save models and create helper functions for dashboard.

In [None]:
from src.models import save_model

print("\n" + "="*60)
print("PREPARING FOR DEPLOYMENT")
print("="*60)

# Create models directory
models_dir = Path("../models")
models_dir.mkdir(parents=True, exist_ok=True)

# One row per model: (config key / file-name stem, estimator type, wrapper).
# Both the saved file names and the config entries are derived from this
# table so they cannot drift apart.
deployed_models = (
    ('logistic_regression', 'LogisticRegression', lr_model),
    ('random_forest', 'RandomForestClassifier', rf_model),
    ('xgboost', 'XGBClassifier', xgb_model),
)

# Save models as <key>_model.pkl
print("\nSaving trained models...")
for model_key, _estimator_type, trained_model in deployed_models:
    save_model(trained_model, models_dir / f"{model_key}_model.pkl")

# (The check-mark emoji was mis-encoded in the original string; restored here.)
print("✅ Models saved successfully!")
print(f"Location: {models_dir.absolute()}")

# Create deployment configuration.
# Metrics are cast to built-in floats so json.dump can serialize them.
deployment_config = {
    'model_type': 'medical_diagnosis_ensemble',
    'version': '1.0.0',
    'date': pd.Timestamp.now().isoformat(),
    'models': {
        model_key: {
            'type': estimator_type,
            'roc_auc': float(trained_model.model_metrics['roc_auc']),
            'f1_score': float(trained_model.model_metrics['f1_score'])
        }
        for model_key, estimator_type, trained_model in deployed_models
    },
    'features': X_train.columns.tolist(),
    'ensemble_method': 'average',
    'recommended_threshold': 0.5
}

# Save config (explicit encoding so the file does not depend on the locale)
config_path = models_dir / "deployment_config.json"
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(deployment_config, f, indent=2)

print(f"Config saved to {config_path}")

print("\n" + "="*60)
print("NEXT STEPS FOR DEPLOYMENT")
print("="*60)
print("""
1. Run Streamlit Dashboard:
   streamlit run ../app.py

2. Access at:
   http://localhost:8501

3. Features Available:
   - Single Patient Prediction
   - Batch Predictions (CSV upload)
   - Model Comparison
   - Dataset Exploration
   - SHAP Explanations

4. For Production Deployment:
   - Use FastAPI or Flask wrapper
   - Containerize with Docker
   - Deploy on cloud (AWS, GCP, Azure)
   - Set up monitoring and logging
""")