In [None]:
!python --version

In [None]:
!pip --version

In [None]:
!pip install \
"streamlit>=1.28.0" \
"pandas>=2.2.0" \
"numpy>=1.26,<2.0" \
"scikit-learn>=1.4.0" \
"tensorflow>=2.17" \
"joblib>=1.3.0" \
"shap>=0.44.0" \
"matplotlib>=3.8.0" \
"seaborn>=0.13.0" \
"tqdm>=4.65.0" \
"protobuf>=4.25"

In [None]:
import tensorflow as tf
import pandas as pd
import streamlit as st

print(tf.__version__)
print(pd.__version__)

In [23]:
import pandas as pd
import numpy as np
from pathlib import Path

# Get project root (go up from scripts/models to project root)
PROJECT_ROOT = Path(__file__).parent.parent.parent if '__file__' in globals() else Path.cwd().parent.parent
df = pd.read_csv(PROJECT_ROOT / "data" / "gold" / "gold_dataset.csv")
leakage_features = [
    "subject_id",
    "hadm_id",
    
]
X = df.drop(columns=["y"] + leakage_features)
y = df["y"].astype(int)

In [24]:
from sklearn.model_selection import train_test_split

# =====================================================
# Leave-One-Subject-Out (LOSO) Split
# Split by subject_id to ensure all admissions for a patient
# stay together in the same dataset (train/val/test)
# =====================================================

# Get unique subject_ids
unique_subjects = df["subject_id"].unique()
print(f"Total unique subjects: {len(unique_subjects)}")

# First split: 70% train, 30% temp (which will be split into 15% val + 15% test)
# We need to stratify by the target distribution at subject level
# For each subject, get the majority class (or any class if mixed)
subject_labels = df.groupby("subject_id")["y"].agg(lambda x: 1 if x.sum() > 0 else 0).values

# Split subjects: 70% train, 30% temp
subjects_train, subjects_temp, labels_train, labels_temp = train_test_split(
    unique_subjects,
    subject_labels,
    test_size=0.3,
    stratify=subject_labels,
    random_state=42
)

# Split temp (30%) into validation (15%) and test (15%)
# Calculate the proportion: 15% of total = 50% of temp (since temp is 30% of total)
subjects_val, subjects_test, labels_val, labels_test = train_test_split(
    subjects_temp,
    labels_temp,
    test_size=0.5,  # 50% of temp = 15% of total
    stratify=labels_temp,
    random_state=42
)

print(f"\nSplit summary:")
print(f"Training subjects: {len(subjects_train)} ({len(subjects_train)/len(unique_subjects)*100:.1f}%)")
print(f"Validation subjects: {len(subjects_val)} ({len(subjects_val)/len(unique_subjects)*100:.1f}%)")
print(f"Test subjects: {len(subjects_test)} ({len(subjects_test)/len(unique_subjects)*100:.1f}%)")

# Get all rows for each subject group
train_mask = df["subject_id"].isin(subjects_train)
val_mask = df["subject_id"].isin(subjects_val)
test_mask = df["subject_id"].isin(subjects_test)

X_train = X[train_mask].copy()
X_val = X[val_mask].copy()
X_test = X[test_mask].copy()

y_train = y[train_mask].copy()
y_val = y[val_mask].copy()
y_test = y[test_mask].copy()

print(f"\nDataset sizes:")
print(f"Training: {len(X_train)} admissions")
print(f"Validation: {len(X_val)} admissions")
print(f"Test: {len(X_test)} admissions")

print(f"\nTarget distribution in training:")
print(y_train.value_counts())
print(f"\nTarget distribution in validation:")
print(y_val.value_counts())
print(f"\nTarget distribution in test:")
print(y_test.value_counts())


Total unique subjects: 96242

Split summary:
Training subjects: 67369 (70.0%)
Validation subjects: 14436 (15.0%)
Test subjects: 14437 (15.0%)

Dataset sizes:
Training: 110541 admissions
Validation: 23668 admissions
Test: 23811 admissions

Target distribution in training:
y
0    109262
1      1279
Name: count, dtype: int64

Target distribution in validation:
y
0    23367
1      301
Name: count, dtype: int64

Target distribution in test:
y
0    23556
1      255
Name: count, dtype: int64


In [25]:
# =====================================================
# Load Pre-trained Model
# =====================================================
import tensorflow as tf

# Load the trained model
model_path = PROJECT_ROOT / "scripts" / "models" / "cauti_ann_loso_model.keras"
model_path_h5 = PROJECT_ROOT / "scripts" / "models" / "cauti_ann_loso_model.h5"

def load_keras_model_safe(model_path):
    """Load Keras model with error handling."""
    model_path = str(Path(model_path).resolve())
    try:
        return tf.keras.models.load_model(model_path)
    except:
        # If that fails, try without compile
        model = tf.keras.models.load_model(model_path, compile=False)
        if not model._is_compiled:
            model.compile(
                optimizer='adam',
                loss='binary_crossentropy',
                metrics=['AUC', 'accuracy']
            )
        return model

# Try loading .keras format first, then .h5
if model_path.exists():
    print(f"Loading model from: {model_path}")
    model = load_keras_model_safe(model_path)
    print("✅ Model loaded successfully!")
elif model_path_h5.exists():
    print(f"Loading model from: {model_path_h5}")
    model = load_keras_model_safe(model_path_h5)
    print("✅ Model loaded successfully!")
else:
    raise FileNotFoundError(
        f"Model not found at:\n  - {model_path}\n  - {model_path_h5}\n\n"
        "Please train the model first or ensure the model file exists."
    )

print(f"Model input shape: {model.input_shape}")
print(f"Model output shape: {model.output_shape}")


Loading model from: C:\Users\Coditas\Desktop\Projects\Cauti\scripts\models\cauti_ann_loso_model.keras
✅ Model loaded successfully!
Model input shape: (None, 68)
Model output shape: (None, 1)


In [26]:
# Fill remaining NaNs with 0 (needed for SHAP background data)
X_train = X_train.fillna(0)
X_test = X_test.fillna(0)

# Feature Importance & Correlation Analysis Notebook

This notebook focuses on:
1. **SHAP Feature Importance** - Calculate global feature importance using SHAP values
2. **Correlation Matrices** - Analyze correlations between features and target, and between features


# SHAP Feature Importance Analysis (LOSO)

Using SHAP (SHapley Additive exPlanations) to calculate feature importance for the trained ANN model.

**Approach: Leave-One-Subject-Out (LOSO)**
- Calculates SHAP values for each subject separately in the test set
- Aggregates feature importance across all subjects
- Ensures proper evaluation on unseen subjects (respects LOSO split)

**Why SHAP?**
- Works very well with Dense ANN
- Handles non-linearity + feature interactions
- Accepted in research papers, audits, and healthcare ML
- Gives both global feature importance (%) and local (per-patient) explanations
- Preferred for healthcare/medical data due to interpretability requirements


In [31]:
# Install shap if not already installed
try:
    import shap
    print(f"✅ SHAP version: {shap.__version__}")
except ImportError:
    print("Installing SHAP...")
    import sys
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "shap>=0.43.0"])
    import shap
    print(f"✅ SHAP installed: {shap.__version__}")

# Also check for tqdm (for progress bars)
try:
    import tqdm
    print(f"✅ tqdm version: {tqdm.__version__}")
except ImportError:
    print("Installing tqdm...")
    import sys
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "tqdm>=4.65.0"])
    import tqdm
    print(f"✅ tqdm installed: {tqdm.__version__}")


✅ SHAP version: 0.49.1
✅ tqdm version: 4.66.5


In [None]:
# =====================================================
# Calculate SHAP Values for Feature Importance
# Using DeepExplainer for Dense Neural Networks
# LEAVE-ONE-SUBJECT-OUT (LOSO) Approach
# =====================================================
import shap
import pandas as pd
import numpy as np
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm

# Use a sample of training data as background (SHAP needs background data)
# Using a smaller sample for faster computation (100-500 samples is typical)
background_size = min(500, len(X_train))
print(f"Using {background_size} samples as background for SHAP...")

# Sample background data (stratified to maintain class distribution)
if len(y_train.unique()) > 1:
    # Stratified sampling
    from sklearn.model_selection import train_test_split
    _, X_background, _, _ = train_test_split(
        X_train, y_train,
        train_size=background_size,
        stratify=y_train,
        random_state=42
    )
else:
    # Random sampling if only one class
    X_background = X_train.sample(n=background_size, random_state=42)

print(f"Background data shape: {X_background.shape}")

# Initialize SHAP DeepExplainer for TensorFlow/Keras models
print("\nInitializing SHAP DeepExplainer...")
explainer = shap.DeepExplainer(model, X_background.values)

# =====================================================
# LOSO: Calculate SHAP values for each subject separately
# =====================================================
# Get test dataframe with subject_id (we need to merge back with subject_id)
test_df_with_ids = df[test_mask].copy()
test_df_with_ids = test_df_with_ids.reset_index(drop=True)

# Get unique subjects in test set
test_subjects = test_df_with_ids['subject_id'].unique()
print(f"\nTotal unique subjects in test set: {len(test_subjects)}")

# Option: Use all subjects or sample for faster computation
# For full LOSO, use all subjects. For faster computation, sample.
max_subjects_for_shap = 1000  # Use 1000 subjects for SHAP calculation
if len(test_subjects) > max_subjects_for_shap:
    print(f"Sampling {max_subjects_for_shap} subjects for SHAP calculation (out of {len(test_subjects)})...")
    # Stratified sampling by subject label (majority class per subject)
    subject_labels = test_df_with_ids.groupby('subject_id')['y'].agg(lambda x: 1 if x.sum() > 0 else 0)
    if len(subject_labels.unique()) > 1:
        from sklearn.model_selection import train_test_split
        subjects_sample, _, _, _ = train_test_split(
            test_subjects,
            subject_labels[test_subjects],
            train_size=max_subjects_for_shap,
            stratify=subject_labels[test_subjects],
            random_state=42
        )
    else:
        subjects_sample = np.random.choice(test_subjects, size=max_subjects_for_shap, replace=False)
    test_subjects = subjects_sample
    print(f"Using {len(test_subjects)} subjects for SHAP calculation")
else:
    print(f"Using all {len(test_subjects)} subjects for SHAP calculation")

# Store SHAP values for all subjects
all_shap_values = []

print(f"\nCalculating SHAP values using LOSO approach...")
print(f"Processing {len(test_subjects)} subjects...")

# Process each subject
for subject_id in tqdm(test_subjects, desc="Processing subjects"):
    # Get all admissions for this subject
    subject_mask = test_df_with_ids['subject_id'] == subject_id
    subject_admissions = test_df_with_ids[subject_mask]
    
    if len(subject_admissions) == 0:
        continue
    
    # Get feature data for this subject (without subject_id, hadm_id, y)
    subject_X = subject_admissions.drop(columns=['subject_id', 'hadm_id', 'y'] + leakage_features, errors='ignore')
    
    # Ensure columns match X_train order
    subject_X = subject_X[X.columns]
    
    # Fill any missing values
    subject_X = subject_X.fillna(0)
    
    # Calculate SHAP values for this subject's admissions
    try:
        subject_shap = explainer.shap_values(subject_X.values)
        
        # Handle output format (SHAP may return list for multi-output models)
        if isinstance(subject_shap, list):
            subject_shap = subject_shap[0]  # Take first output (binary classification)
        
        # Append to list
        all_shap_values.append(subject_shap)
        
    except Exception as e:
        print(f"Warning: Error calculating SHAP for subject {subject_id}: {e}")
        continue

# Concatenate all SHAP values from all subjects
if len(all_shap_values) > 0:
    shap_values = np.concatenate(all_shap_values, axis=0)
    print(f"\n✅ SHAP values calculated successfully!")
    print(f"   Total subjects processed: {len(all_shap_values)}")
    print(f"   Total admissions: {shap_values.shape[0]}")
    print(f"   SHAP values shape: {shap_values.shape}")
    print(f"   Features: {shap_values.shape[1]}")
else:
    raise ValueError("No SHAP values were calculated. Please check the data and model.")


Using 500 samples as background for SHAP...
Background data shape: (110041, 68)

Initializing SHAP DeepExplainer...

Total unique subjects in test set: 14437
Using all 14437 subjects for SHAP calculation

Calculating SHAP values using LOSO approach...
Processing 14437 subjects...


Processing subjects:  12%|█▏        | 1698/14437 [30:04<6:47:19,  1.92s/it] 



Processing subjects:  12%|█▏        | 1699/14437 [30:05<5:54:13,  1.67s/it]



Processing subjects:  12%|█▏        | 1700/14437 [30:05<4:38:05,  1.31s/it]



Processing subjects:  12%|█▏        | 1702/14437 [30:06<2:41:42,  1.31it/s]



Processing subjects:  12%|█▏        | 1704/14437 [30:06<1:37:54,  2.17it/s]



Processing subjects:  12%|█▏        | 1706/14437 [30:07<1:05:51,  3.22it/s]



Processing subjects:  12%|█▏        | 1708/14437 [30:07<48:11,  4.40it/s]  



Processing subjects:  12%|█▏        | 1710/14437 [30:07<38:40,  5.48it/s]



Processing subjects:  12%|█▏        | 1712/14437 [30:07<34:35,  6.13it/s]



Processing subjects:  12%|█▏        | 1714/14437 [30:08<32:17,  6.57it/s]



Processing subjects:  12%|█▏        | 1716/14437 [30:08<30:46,  6.89it/s]



Processing subjects:  12%|█▏        | 1718/14437 [30:08<30:48,  6.88it/s]



Processing subjects:  12%|█▏        | 1719/14437 [30:08<29:43,  7.13it/s]



Processing subjects:  12%|█▏        | 1720/14437 [30:09<37:15,  5.69it/s]



Processing subjects:  12%|█▏        | 1722/14437 [30:09<42:05,  5.03it/s]



Processing subjects:  12%|█▏        | 1724/14437 [30:10<44:54,  4.72it/s]



Processing subjects:  12%|█▏        | 1726/14437 [30:10<46:40,  4.54it/s]



Processing subjects:  12%|█▏        | 1729/14437 [30:10<36:03,  5.88it/s]



Processing subjects:  12%|█▏        | 1731/14437 [30:11<30:37,  6.91it/s]



Processing subjects:  12%|█▏        | 1733/14437 [30:11<27:30,  7.70it/s]



Processing subjects:  12%|█▏        | 1737/14437 [30:11<23:40,  8.94it/s]



Processing subjects:  12%|█▏        | 1740/14437 [30:12<21:58,  9.63it/s]



Processing subjects:  12%|█▏        | 1743/14437 [30:12<20:24, 10.36it/s]



Processing subjects:  12%|█▏        | 1745/14437 [30:12<18:33, 11.40it/s]



Processing subjects:  12%|█▏        | 1749/14437 [30:12<19:59, 10.58it/s]



Processing subjects:  12%|█▏        | 1753/14437 [30:13<18:59, 11.13it/s]



Processing subjects:  12%|█▏        | 1755/14437 [30:13<21:27,  9.85it/s]



Processing subjects:  12%|█▏        | 1757/14437 [30:13<23:50,  8.86it/s]



Processing subjects:  12%|█▏        | 1760/14437 [30:14<26:09,  8.08it/s]



Processing subjects:  12%|█▏        | 1762/14437 [30:14<28:02,  7.53it/s]



Processing subjects:  12%|█▏        | 1764/14437 [30:14<28:50,  7.32it/s]



Processing subjects:  12%|█▏        | 1766/14437 [30:15<28:26,  7.43it/s]



Processing subjects:  12%|█▏        | 1768/14437 [30:15<28:36,  7.38it/s]



Processing subjects:  12%|█▏        | 1771/14437 [30:15<24:16,  8.69it/s]



Processing subjects:  12%|█▏        | 1773/14437 [30:15<21:35,  9.78it/s]



Processing subjects:  12%|█▏        | 1776/14437 [30:16<24:56,  8.46it/s]



Processing subjects:  12%|█▏        | 1778/14437 [30:16<25:42,  8.21it/s]



Processing subjects:  12%|█▏        | 1782/14437 [30:16<22:54,  9.21it/s]



Processing subjects:  12%|█▏        | 1785/14437 [30:17<22:33,  9.35it/s]



Processing subjects:  12%|█▏        | 1788/14437 [30:17<21:24,  9.84it/s]



Processing subjects:  12%|█▏        | 1792/14437 [30:17<17:13, 12.23it/s]



Processing subjects:  12%|█▏        | 1794/14437 [30:17<18:20, 11.48it/s]



Processing subjects:  12%|█▏        | 1798/14437 [30:18<19:20, 10.89it/s]



Processing subjects:  12%|█▏        | 1800/14437 [30:18<19:57, 10.55it/s]



Processing subjects:  12%|█▏        | 1804/14437 [30:18<18:07, 11.61it/s]



Processing subjects:  13%|█▎        | 1806/14437 [30:19<17:56, 11.74it/s]



Processing subjects:  13%|█▎        | 1810/14437 [30:19<17:47, 11.82it/s]



Processing subjects:  13%|█▎        | 1812/14437 [30:19<17:45, 11.85it/s]



Processing subjects:  13%|█▎        | 1816/14437 [30:19<17:30, 12.02it/s]



Processing subjects:  13%|█▎        | 1818/14437 [30:20<18:11, 11.56it/s]



Processing subjects:  13%|█▎        | 1822/14437 [30:20<18:46, 11.20it/s]



Processing subjects:  13%|█▎        | 1824/14437 [30:20<18:31, 11.34it/s]



Processing subjects:  13%|█▎        | 1826/14437 [30:20<18:15, 11.51it/s]



Processing subjects:  13%|█▎        | 1830/14437 [30:21<19:26, 10.81it/s]



Processing subjects:  13%|█▎        | 1834/14437 [30:21<21:03,  9.97it/s]



Processing subjects:  13%|█▎        | 1841/14437 [30:21<12:24, 16.92it/s]



Processing subjects:  13%|█▎        | 1850/14437 [30:22<07:49, 26.83it/s]



Processing subjects:  13%|█▎        | 1858/14437 [30:22<07:14, 28.94it/s]



Processing subjects:  13%|█▎        | 1862/14437 [30:22<08:46, 23.89it/s]



Processing subjects:  13%|█▎        | 1868/14437 [30:23<14:26, 14.51it/s]



Processing subjects:  13%|█▎        | 1870/14437 [30:23<14:59, 13.98it/s]



Processing subjects:  13%|█▎        | 1874/14437 [30:23<16:14, 12.89it/s]



Processing subjects:  13%|█▎        | 1876/14437 [30:23<16:25, 12.74it/s]



Processing subjects:  13%|█▎        | 1880/14437 [30:24<16:16, 12.86it/s]



Processing subjects:  13%|█▎        | 1882/14437 [30:24<16:03, 13.04it/s]



Processing subjects:  13%|█▎        | 1886/14437 [30:24<16:55, 12.36it/s]



Processing subjects:  13%|█▎        | 1888/14437 [30:24<17:00, 12.30it/s]



Processing subjects:  13%|█▎        | 1892/14437 [30:25<17:40, 11.83it/s]



Processing subjects:  13%|█▎        | 1896/14437 [30:25<19:00, 11.00it/s]



Processing subjects:  13%|█▎        | 1898/14437 [30:25<19:12, 10.88it/s]



Processing subjects:  13%|█▎        | 1902/14437 [30:26<19:22, 10.78it/s]



Processing subjects:  13%|█▎        | 1904/14437 [30:26<19:28, 10.73it/s]



Processing subjects:  13%|█▎        | 1908/14437 [30:26<18:08, 11.51it/s]



Processing subjects:  13%|█▎        | 1910/14437 [30:26<17:49, 11.72it/s]



Processing subjects:  13%|█▎        | 1914/14437 [30:27<17:44, 11.76it/s]



Processing subjects:  13%|█▎        | 1918/14437 [30:27<18:28, 11.29it/s]



Processing subjects:  13%|█▎        | 1920/14437 [30:27<17:38, 11.83it/s]



Processing subjects:  13%|█▎        | 1924/14437 [30:27<17:42, 11.77it/s]



Processing subjects:  13%|█▎        | 1926/14437 [30:28<18:20, 11.37it/s]



Processing subjects:  13%|█▎        | 1930/14437 [30:28<17:32, 11.88it/s]



Processing subjects:  13%|█▎        | 1932/14437 [30:28<17:59, 11.59it/s]



Processing subjects:  13%|█▎        | 1936/14437 [30:29<18:03, 11.54it/s]



Processing subjects:  13%|█▎        | 1938/14437 [30:29<18:16, 11.40it/s]



Processing subjects:  13%|█▎        | 1940/14437 [30:29<18:55, 11.00it/s]



Processing subjects:  13%|█▎        | 1944/14437 [30:29<20:33, 10.13it/s]



Processing subjects:  13%|█▎        | 1946/14437 [30:30<19:59, 10.41it/s]



Processing subjects:  14%|█▎        | 1950/14437 [30:30<19:53, 10.46it/s]



Processing subjects:  14%|█▎        | 1952/14437 [30:30<20:41, 10.06it/s]



Processing subjects:  14%|█▎        | 1956/14437 [30:30<18:24, 11.31it/s]



Processing subjects:  14%|█▎        | 1960/14437 [30:31<17:48, 11.67it/s]



Processing subjects:  14%|█▎        | 1962/14437 [30:31<19:09, 10.85it/s]



Processing subjects:  14%|█▎        | 1966/14437 [30:31<18:35, 11.18it/s]



Processing subjects:  14%|█▎        | 1968/14437 [30:32<18:19, 11.34it/s]



Processing subjects:  14%|█▎        | 1970/14437 [30:32<17:45, 11.70it/s]



Processing subjects:  14%|█▎        | 1974/14437 [30:32<18:28, 11.24it/s]



Processing subjects:  14%|█▎        | 1978/14437 [30:32<18:04, 11.49it/s]



Processing subjects:  14%|█▎        | 1980/14437 [30:33<17:10, 12.09it/s]



Processing subjects:  14%|█▎        | 1984/14437 [30:33<15:39, 13.25it/s]



Processing subjects:  14%|█▍        | 1988/14437 [30:33<16:24, 12.65it/s]



Processing subjects:  14%|█▍        | 1994/14437 [30:34<14:39, 14.15it/s]



Processing subjects:  14%|█▍        | 1998/14437 [30:34<13:48, 15.01it/s]



Processing subjects:  14%|█▍        | 2000/14437 [30:34<14:22, 14.42it/s]



Processing subjects:  14%|█▍        | 2004/14437 [30:34<19:20, 10.72it/s]



Processing subjects:  14%|█▍        | 2006/14437 [30:35<18:11, 11.39it/s]



Processing subjects:  14%|█▍        | 2010/14437 [30:35<17:32, 11.80it/s]



Processing subjects:  14%|█▍        | 2014/14437 [30:35<16:52, 12.27it/s]



Processing subjects:  14%|█▍        | 2016/14437 [30:35<17:37, 11.75it/s]



Processing subjects:  14%|█▍        | 2020/14437 [30:36<17:47, 11.64it/s]



Processing subjects:  14%|█▍        | 2022/14437 [30:36<18:19, 11.30it/s]



Processing subjects:  14%|█▍        | 2027/14437 [30:36<13:43, 15.07it/s]



Processing subjects:  14%|█▍        | 2044/14437 [30:37<05:49, 35.42it/s]



Processing subjects:  14%|█▍        | 2054/14437 [30:37<05:01, 41.03it/s]



Processing subjects:  14%|█▍        | 2060/14437 [30:37<04:36, 44.77it/s]



Processing subjects:  14%|█▍        | 2065/14437 [30:37<06:00, 34.32it/s]



Processing subjects:  14%|█▍        | 2069/14437 [30:37<09:13, 22.33it/s]



Processing subjects:  14%|█▍        | 2079/14437 [30:38<06:52, 29.94it/s]



Processing subjects:  14%|█▍        | 2089/14437 [30:38<06:19, 32.50it/s]



Processing subjects:  14%|█▍        | 2093/14437 [30:38<09:12, 22.33it/s]



Processing subjects:  15%|█▍        | 2096/14437 [30:39<10:50, 18.97it/s]



Processing subjects:  15%|█▍        | 2099/14437 [30:39<12:38, 16.28it/s]



Processing subjects:  15%|█▍        | 2104/14437 [30:39<14:36, 14.07it/s]



Processing subjects:  15%|█▍        | 2108/14437 [30:40<16:23, 12.54it/s]



Processing subjects:  15%|█▍        | 2110/14437 [30:40<16:29, 12.45it/s]



Processing subjects:  15%|█▍        | 2114/14437 [30:40<15:11, 13.52it/s]



Processing subjects:  15%|█▍        | 2120/14437 [30:40<12:56, 15.86it/s]



Processing subjects:  15%|█▍        | 2124/14437 [30:41<12:54, 15.89it/s]



Processing subjects:  15%|█▍        | 2126/14437 [30:41<14:15, 14.38it/s]



Processing subjects:  15%|█▍        | 2130/14437 [30:41<14:29, 14.16it/s]



Processing subjects:  15%|█▍        | 2134/14437 [30:41<14:01, 14.62it/s]



Processing subjects:  15%|█▍        | 2138/14437 [30:42<13:26, 15.24it/s]



Processing subjects:  15%|█▍        | 2142/14437 [30:42<14:00, 14.63it/s]



Processing subjects:  15%|█▍        | 2144/14437 [30:42<14:54, 13.75it/s]



Processing subjects:  15%|█▍        | 2148/14437 [30:42<14:44, 13.90it/s]



Processing subjects:  15%|█▍        | 2152/14437 [30:43<14:05, 14.53it/s]



Processing subjects:  15%|█▍        | 2156/14437 [30:43<13:49, 14.81it/s]



Processing subjects:  15%|█▍        | 2160/14437 [30:43<13:54, 14.71it/s]



Processing subjects:  15%|█▍        | 2164/14437 [30:43<14:10, 14.43it/s]



Processing subjects:  15%|█▌        | 2166/14437 [30:44<16:45, 12.21it/s]



Processing subjects:  15%|█▌        | 2170/14437 [30:44<19:51, 10.29it/s]



Processing subjects:  15%|█▌        | 2174/14437 [30:44<16:32, 12.35it/s]



Processing subjects:  15%|█▌        | 2178/14437 [30:45<13:53, 14.71it/s]



Processing subjects:  15%|█▌        | 2182/14437 [30:45<13:24, 15.22it/s]



Processing subjects:  15%|█▌        | 2186/14437 [30:45<14:30, 14.07it/s]



Processing subjects:  15%|█▌        | 2188/14437 [30:45<16:30, 12.37it/s]



Processing subjects:  15%|█▌        | 2194/14437 [30:46<14:35, 13.98it/s]



Processing subjects:  15%|█▌        | 2198/14437 [30:46<14:20, 14.22it/s]



Processing subjects:  15%|█▌        | 2200/14437 [30:46<15:16, 13.35it/s]



Processing subjects:  15%|█▌        | 2204/14437 [30:47<16:03, 12.69it/s]



Processing subjects:  15%|█▌        | 2208/14437 [30:47<15:24, 13.22it/s]



Processing subjects:  15%|█▌        | 2210/14437 [30:47<15:59, 12.74it/s]



Processing subjects:  15%|█▌        | 2214/14437 [30:47<15:35, 13.07it/s]



Processing subjects:  15%|█▌        | 2218/14437 [30:48<16:03, 12.68it/s]



Processing subjects:  15%|█▌        | 2222/14437 [30:48<15:32, 13.10it/s]



Processing subjects:  15%|█▌        | 2224/14437 [30:48<14:58, 13.60it/s]



Processing subjects:  15%|█▌        | 2228/14437 [30:48<14:56, 13.62it/s]



Processing subjects:  15%|█▌        | 2232/14437 [30:49<14:30, 14.02it/s]



Processing subjects:  15%|█▌        | 2236/14437 [30:49<13:14, 15.35it/s]



Processing subjects:  16%|█▌        | 2238/14437 [30:49<14:42, 13.82it/s]



Processing subjects:  16%|█▌        | 2240/14437 [30:49<16:57, 11.99it/s]



Processing subjects:  16%|█▌        | 2317/14437 [31:58<2:14:27,  1.50it/s]



Processing subjects:  16%|█▌        | 2319/14437 [31:58<1:39:08,  2.04it/s]



Processing subjects:  16%|█▌        | 2321/14437 [31:59<1:03:38,  3.17it/s]



Processing subjects:  16%|█▌        | 2322/14437 [31:59<50:58,  3.96it/s]  



Processing subjects:  16%|█▌        | 2325/14437 [31:59<37:44,  5.35it/s]



Processing subjects:  16%|█▌        | 2327/14437 [31:59<34:27,  5.86it/s]



Processing subjects:  16%|█▌        | 2329/14437 [32:00<34:38,  5.83it/s]



Processing subjects:  16%|█▌        | 2330/14437 [32:00<1:09:28,  2.90it/s]



Processing subjects:  16%|█▌        | 2334/14437 [32:04<2:44:59,  1.22it/s]



Processing subjects:  16%|█▌        | 2338/14437 [32:07<2:27:59,  1.36it/s]



Processing subjects:  16%|█▌        | 2341/14437 [32:08<1:49:29,  1.84it/s]



Processing subjects:  16%|█▌        | 2343/14437 [32:09<1:37:37,  2.06it/s]



Processing subjects:  16%|█▌        | 2344/14437 [32:10<2:05:03,  1.61it/s]



Processing subjects:  16%|█▌        | 2346/14437 [32:13<3:24:14,  1.01s/it]



Processing subjects:  16%|█▋        | 2350/14437 [32:15<1:53:12,  1.78it/s]



Processing subjects:  16%|█▋        | 2351/14437 [32:15<1:51:43,  1.80it/s]



Processing subjects:  17%|█▋        | 2434/14437 [33:11<2:43:38,  1.22it/s]


KeyboardInterrupt: 

In [None]:
# =====================================================
# Calculate Global Feature Importance as Percentages
# Using mean absolute SHAP values (LOSO aggregated)
# =====================================================

# Calculate mean absolute SHAP value for each feature
# This represents the average impact of each feature on the model output
# Aggregated across all subjects using LOSO approach
mean_abs_shap = np.abs(shap_values).mean(axis=0)

# Convert to percentages
total_importance = mean_abs_shap.sum()
feature_importance_pct = (mean_abs_shap / total_importance) * 100

# Create DataFrame with feature names and importance percentages
feature_importance_df = pd.DataFrame({
    'feature': X.columns,
    'importance': mean_abs_shap,
    'importance_percentage': feature_importance_pct
})

# Sort by importance (descending)
feature_importance_df = feature_importance_df.sort_values('importance_percentage', ascending=False).reset_index(drop=True)

print("="*60)
print("GLOBAL FEATURE IMPORTANCE (LOSO Aggregated)")
print("="*60)
print(f"\nTop 20 Most Important Features:")
print(feature_importance_df.head(20).to_string(index=False))
print(f"\nTotal features: {len(feature_importance_df)}")
print(f"Total admissions analyzed: {shap_values.shape[0]}")
print(f"Sum of importance percentages: {feature_importance_df['importance_percentage'].sum():.2f}%")
print("="*60)


In [None]:
# =====================================================
# Save Feature Importance to CSV
# =====================================================
from pathlib import Path

# Save to data/analysis directory
output_dir = PROJECT_ROOT / "data" / "analysis"
output_dir.mkdir(parents=True, exist_ok=True)

output_path = output_dir / "feature_importance_shap.csv"

# Save with importance as percentage (rounded to 4 decimal places)
feature_importance_df.to_csv(output_path, index=False, float_format='%.4f')

print(f"✅ Feature importance saved to: {output_path}")
print(f"   Total features: {len(feature_importance_df)}")
print(f"   Columns: feature, importance, importance_percentage")


# Correlation Matrix Analysis

Plot correlation matrices for:
1. Features vs Target (y)
2. Features vs Features


In [None]:
# =====================================================
# Correlation Matrix: Features vs Target (y)
# =====================================================
import matplotlib.pyplot as plt
import seaborn as sns

# Calculate correlation between each feature and target
corr_with_target = X.corrwith(y).sort_values(ascending=False)

# Create a figure with two subplots
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Plot 1: Bar plot of correlations
ax1 = axes[0]
top_n = min(30, len(corr_with_target))  # Show top 30 features
top_corr = corr_with_target.head(top_n)
colors = ['red' if x < 0 else 'blue' for x in top_corr.values]
ax1.barh(range(len(top_corr)), top_corr.values, color=colors)
ax1.set_yticks(range(len(top_corr)))
ax1.set_yticklabels(top_corr.index, fontsize=8)
ax1.set_xlabel('Correlation with Target (y)', fontsize=12)
ax1.set_title(f'Top {top_n} Features: Correlation with CAUTI (y)', fontsize=14, fontweight='bold')
ax1.axvline(x=0, color='black', linestyle='--', linewidth=0.8)
ax1.grid(axis='x', alpha=0.3)
ax1.invert_yaxis()

# Plot 2: Heatmap of top correlations
ax2 = axes[1]
# Create a DataFrame for heatmap (single row)
corr_df = pd.DataFrame(top_corr.values.reshape(1, -1), 
                       columns=top_corr.index,
                       index=['CAUTI (y)'])
sns.heatmap(corr_df, annot=False, cmap='RdBu_r', center=0, 
            cbar_kws={'label': 'Correlation'}, ax=ax2, fmt='.2f')
ax2.set_title(f'Top {top_n} Features: Correlation Heatmap', fontsize=14, fontweight='bold')
ax2.set_xticklabels(top_corr.index, rotation=90, ha='right', fontsize=8)

plt.tight_layout()
plt.savefig(output_dir / 'correlation_features_vs_target.png', dpi=300, bbox_inches='tight')
print(f"✅ Correlation plot saved to: {output_dir / 'correlation_features_vs_target.png'}")
plt.show()

# Print summary statistics
print(f"\n=== Correlation with Target Summary ===")
print(f"Total features: {len(corr_with_target)}")
print(f"Positive correlations: {(corr_with_target > 0).sum()}")
print(f"Negative correlations: {(corr_with_target < 0).sum()}")
print(f"Strong positive (>0.3): {(corr_with_target > 0.3).sum()}")
print(f"Strong negative (<-0.3): {(corr_with_target < -0.3).sum()}")
print(f"\nTop 10 Positive Correlations:")
print(corr_with_target.head(10))
print(f"\nTop 10 Negative Correlations:")
print(corr_with_target.tail(10))


In [None]:
# =====================================================
# Correlation Matrix: Features vs Features
# =====================================================

# Calculate full correlation matrix for features
# Note: For large datasets, this can be memory-intensive
# We'll use a sample or focus on top features

# Option 1: Use all features (if dataset is manageable)
if X.shape[1] <= 100:
    print(f"Calculating full correlation matrix for {X.shape[1]} features...")
    feature_corr_matrix = X.corr()
    use_all_features = True
else:
    # Option 2: Use top features by importance or correlation with target
    print(f"Dataset has {X.shape[1]} features. Using top 50 features for correlation matrix...")
    top_features = corr_with_target.abs().nlargest(50).index.tolist()
    feature_corr_matrix = X[top_features].corr()
    use_all_features = False
    print(f"Selected {len(top_features)} top features for correlation analysis")

# Create correlation heatmap
plt.figure(figsize=(14, 12))
mask = np.triu(np.ones_like(feature_corr_matrix, dtype=bool), k=1)  # Mask upper triangle

sns.heatmap(feature_corr_matrix, 
            mask=mask,
            annot=False,  # Set to True if you want correlation values on plot
            cmap='coolwarm', 
            center=0,
            square=True,
            linewidths=0.5,
            cbar_kws={'label': 'Correlation Coefficient'},
            fmt='.2f')

plt.title('Feature-to-Feature Correlation Matrix', fontsize=16, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig(output_dir / 'correlation_features_vs_features.png', dpi=300, bbox_inches='tight')
print(f"✅ Feature correlation matrix saved to: {output_dir / 'correlation_features_vs_features.png'}")
plt.show()

# Find highly correlated feature pairs
print(f"\n=== Highly Correlated Feature Pairs ===")
# Get lower triangle (excluding diagonal)
corr_pairs = feature_corr_matrix.where(np.tril(np.ones(feature_corr_matrix.shape), k=-1).astype(bool))
corr_pairs = corr_pairs.stack().reset_index()
corr_pairs.columns = ['Feature_1', 'Feature_2', 'Correlation']
corr_pairs['Abs_Correlation'] = corr_pairs['Correlation'].abs()

# Filter high correlations (>= 0.7 or <= -0.7)
high_corr = corr_pairs[corr_pairs['Abs_Correlation'] >= 0.7].sort_values('Abs_Correlation', ascending=False)
print(f"\nFeature pairs with |correlation| >= 0.7: {len(high_corr)}")
if len(high_corr) > 0:
    print(high_corr.head(20).to_string(index=False))
    
    # Save high correlations to CSV
    high_corr_path = output_dir / 'high_correlated_feature_pairs.csv'
    high_corr.to_csv(high_corr_path, index=False)
    print(f"\n✅ High correlation pairs saved to: {high_corr_path}")
else:
    print("No highly correlated feature pairs found (|correlation| >= 0.7)")


In [None]:
# =====================================================
# Summary: Save correlation with target to CSV
# =====================================================

# Save correlation with target
corr_target_df = pd.DataFrame({
    'feature': corr_with_target.index,
    'correlation_with_target': corr_with_target.values
}).sort_values('correlation_with_target', ascending=False, key=abs)

corr_target_path = output_dir / 'correlation_with_target.csv'
corr_target_df.to_csv(corr_target_path, index=False, float_format='%.6f')
print(f"✅ Correlation with target saved to: {corr_target_path}")

print("\n" + "="*60)
print("SUMMARY")
print("="*60)
print(f"✅ Feature importance (SHAP) saved: {output_dir / 'feature_importance_shap.csv'}")
print(f"✅ Correlation with target saved: {corr_target_path}")
print(f"✅ Correlation plots saved:")
print(f"   - {output_dir / 'correlation_features_vs_target.png'}")
print(f"   - {output_dir / 'correlation_features_vs_features.png'}")
# Check if high_corr exists and has data
if 'high_corr' in globals() and len(high_corr) > 0:
    print(f"✅ High correlation pairs saved: {output_dir / 'high_correlated_feature_pairs.csv'}")
print("="*60)
