# Heart Failure Clinical Records: Modeling & Survival Analysis

**Goal:** Predict mortality and identify clinical risk factors using robust, interpretable models.


In [1]:
# 1. Imports & Setup
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    roc_curve, auc, precision_recall_curve, average_precision_score,
    classification_report, confusion_matrix, roc_auc_score
)
from imblearn.over_sampling import SMOTE
import warnings
# Apply the matplotlib style sheet first: a style sheet installs its own colour
# cycle, so calling it after sns.set_palette would discard the husl palette.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
# NOTE(review): this hides every warning for the rest of the notebook, including
# convergence warnings from the model fits below - consider narrowing the filter.
warnings.filterwarnings('ignore')


## 2. Load Feature-Engineered Data


In [None]:
# Read the heart-failure records from the working directory. The cells below
# expect a 'DEATH_EVENT' target column alongside the clinical feature columns.
df = pd.read_csv('heart_failure.csv')
# Preview the first rows (rendered as this cell's output).
df.head()


## 3. Prepare Features and Handle Imbalance
- **Stratified split**
- **SMOTE oversampling**
- From the EDA we know that the two classes are imbalanced

In [None]:


# Separate the binary outcome from the predictors.
# NOTE(review): X keeps every non-target column, including `time`, which later
# cells treat as follow-up duration in days - confirm it is legitimately known
# at prediction time before relying on it as a feature.
y = df['DEATH_EVENT']
X = df.drop(columns=['DEATH_EVENT'])

# Stratified 75/25 hold-out split so both parts keep the outcome ratio.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.25,
    stratify=y,
    random_state=42,
)
print("Class balance (train):", y_train.value_counts(), sep="\n")

# SMOTE (Synthetic Minority Over-sampling Technique) creates synthetic
# minority-class rows by interpolating between existing minority-class rows.
# It is fitted on the training split only; the test split is left untouched.
smote = SMOTE(random_state=42)
X_res, y_res = smote.fit_resample(X_train, y_train)
print("Class balance after SMOTE:", pd.Series(y_res).value_counts(), sep="\n")


## 4. Logistic Regression Assumption Checks
### (VIF, linearity, Cook's distance, perfect separation)


In [None]:
# Variance inflation factors for the (SMOTE-resampled) training predictors.
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.tools.tools import add_constant

# Add an intercept column so each VIF is computed against a model with a constant.
X_vif = add_constant(X_res)
design_matrix = X_vif.values
vif_scores = [
    variance_inflation_factor(design_matrix, col_idx)
    for col_idx in range(design_matrix.shape[1])
]
vif_df = pd.DataFrame({'feature': X_vif.columns, 'VIF': vif_scores})
display(vif_df)


In [None]:
# Linearity of logit (visual check on the full dataset `df`, not only the
# training split). For each continuous predictor the observed death rate is
# plotted per quantile bin; a roughly monotone trend supports a linear term.
for col in ['age', 'ejection_fraction', 'serum_creatinine', 'platelets', 'time']:
    # Bin into (up to) 10 quantile groups without writing a temporary column
    # into `df`: the cell stays idempotent and cannot leak a 'bin' column into
    # the feature matrix if it is interrupted or re-run out of order.
    quantile_bins = pd.qcut(df[col], 10, duplicates='drop')
    grouped = df.groupby(quantile_bins)['DEATH_EVENT'].mean()
    plt.figure()
    grouped.plot(marker='o')
    plt.title(f"Linearity check: Mean DEATH_EVENT by {col} decile")
    plt.ylabel("Mean DEATH_EVENT")
    plt.xlabel(col)
    plt.xticks(rotation=30)
    plt.show()


In [None]:
# Cook's distance (on original data, not oversampled)
import statsmodels.api as sm

# Fit an unregularised logit on all rows and pull per-observation influence.
design = add_constant(X)
logit_mod = sm.Logit(y, design).fit(disp=0)
influence = logit_mod.get_influence()
cooks = influence.cooks_distance[0]  # first element holds the distances

n_obs = len(cooks)
obs_positions = np.arange(n_obs)

plt.figure(figsize=(10,3))
plt.stem(obs_positions, cooks, markerfmt=",")
plt.title("Cook's Distance for Influential Observations (Logistic Regression)")
plt.xlabel("Observation")
plt.ylabel("Cook's Distance")
plt.show()

# Count the observations above the 4/n threshold used in the message below.
n_flagged = (cooks > 4 / n_obs).sum()
print(f"Observations with Cook's Distance > 4/n: {n_flagged} (out of {n_obs})")


In [None]:
# Perfect separation check
# Only low-cardinality (binary / categorical) predictors are tabulated. Grouping
# by a continuous column produces one group per distinct value - often a single
# patient - which floods the output and says nothing about separation.
MAX_LEVELS = 10  # columns with more distinct values are treated as continuous
for col in X.columns:
    if df[col].nunique() > MAX_LEVELS:
        continue
    death_rate = df.groupby(col)['DEATH_EVENT'].mean()
    print(f"{col}:")
    print(death_rate)
    # A level whose death rate is exactly 0 or 1 perfectly predicts the outcome.
    if death_rate.isin([0, 1]).any():
        print(f"  -> possible perfect separation on '{col}'")


## 5. Hyperparameter Tuning: Logistic Regression & Random Forest


In [None]:
# Logistic Regression: Tune C (inverse regularization), solver
from imblearn.pipeline import Pipeline as ImbPipeline

lr_params = {
    'C': [0.01, 0.1, 1, 10, 100],
    'solver': ['liblinear', 'lbfgs'],
    'class_weight': ['balanced']
}
lr = LogisticRegression(max_iter=1000, random_state=42)

# Oversample inside the pipeline so SMOTE is re-fitted on the training part of
# every CV fold. Running the search on data resampled up front lets synthetic
# points interpolated from validation-fold rows leak into training folds,
# which inflates the cross-validated ROC AUC.
lr_pipe = ImbPipeline([
    ('smote', SMOTE(random_state=42)),
    ('clf', lr),
])
# Route the grid to the classifier step of the pipeline.
lr_grid = {f"clf__{name}": values for name, values in lr_params.items()}

lr_gs = GridSearchCV(lr_pipe, lr_grid, cv=5, scoring='roc_auc', n_jobs=-1)
# Fit on the un-resampled training split; the pipeline handles the oversampling.
# The refitted best_estimator_ is a pipeline exposing predict / predict_proba.
lr_gs.fit(X_train, y_train)
print("Best Logistic Regression Params:", lr_gs.best_params_)
print("Best CV ROC AUC: %.3f" % lr_gs.best_score_)


In [25]:
# Random Forest: Tune n_estimators, max_depth, min_samples_split
# NOTE(review): the search runs 5-fold CV on X_res / y_res, i.e. on data that
# was SMOTE-resampled before splitting into folds, so the CV ROC AUC reported
# here is likely optimistic. Moving SMOTE into an imblearn Pipeline would fix
# that, but best_estimator_ would then be a Pipeline, which the
# feature-importance and SHAP cells further down do not expect.
rf_params = dict(
    n_estimators=[50, 100, 200],
    max_depth=[None, 4, 8, 12],
    min_samples_split=[2, 5, 10],
    class_weight=['balanced'],
)
rf = RandomForestClassifier(random_state=42)
rf_gs = GridSearchCV(
    estimator=rf,
    param_grid=rf_params,
    scoring='roc_auc',
    cv=5,
    n_jobs=-1,
)
rf_gs.fit(X_res, y_res)
print("Best Random Forest Params:", rf_gs.best_params_)
print("Best CV ROC AUC: %.3f" % rf_gs.best_score_)


Best Random Forest Params: {'class_weight': 'balanced', 'max_depth': None, 'min_samples_split': 5, 'n_estimators': 50}
Best CV ROC AUC: 0.953


## 6. Evaluation: Test Set Results


In [None]:
# Score the held-out test split with the refitted best estimator of each search.
best_lr, best_rf = lr_gs.best_estimator_, rf_gs.best_estimator_

# Column 1 of predict_proba is the probability of the positive class.
probs_lr = best_lr.predict_proba(X_test)[:, 1]
probs_rf = best_rf.predict_proba(X_test)[:, 1]
preds_lr = best_lr.predict(X_test)
preds_rf = best_rf.predict(X_test)

test_auc_lr = roc_auc_score(y_test, probs_lr)
test_auc_rf = roc_auc_score(y_test, probs_rf)
print("Best Logistic Regression ROC AUC: %.3f" % test_auc_lr)
print("Best Random Forest ROC AUC: %.3f" % test_auc_rf)


In [None]:
# ROC and Precision-Recall curves for both models on the test split.
model_probs = {'LR': probs_lr, 'RF': probs_rf}

# --- ROC ---
fig_roc, ax_roc = plt.subplots(figsize=(10,5))
for model_name, scores in model_probs.items():
    fpr, tpr, _ = roc_curve(y_test, scores)
    ax_roc.plot(fpr, tpr, label=f"{model_name} (AUC={auc(fpr, tpr):.2f})")
ax_roc.plot([0,1],[0,1],'k--')  # chance-level diagonal
ax_roc.set_title('ROC Curve')
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.legend()
plt.show()

# --- Precision-Recall ---
fig_pr, ax_pr = plt.subplots(figsize=(10,5))
for model_name, scores in model_probs.items():
    precision, recall, _ = precision_recall_curve(y_test, scores)
    avg_prec = average_precision_score(y_test, scores)
    ax_pr.plot(recall, precision, label=f"{model_name} (AP={avg_prec:.2f})")
ax_pr.set_title('Precision-Recall Curve')
ax_pr.set_xlabel('Recall')
ax_pr.set_ylabel('Precision')
ax_pr.legend()
plt.show()


In [None]:
# Confusion matrix and classification report (Random Forest)
from sklearn.metrics import ConfusionMatrixDisplay

cm = confusion_matrix(y_test, preds_rf)
cm_display = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=['Survived','Died'],
)
cm_display.plot(cmap='coolwarm')
plt.title("Random Forest: Confusion Matrix")
plt.show()

# Per-class precision / recall / F1 on the test split.
rf_report = classification_report(y_test, preds_rf)
print("Random Forest Classification Report:\n", rf_report)


In [None]:
# Random Forest Feature Importance
# Label the importances with the training columns. The previous `features`
# name was never defined in this notebook, so the cell raised NameError on a
# fresh kernel.
importances = pd.Series(best_rf.feature_importances_, index=X_train.columns)
importances.sort_values().plot.barh(figsize=(8,5))
plt.title("Random Forest Feature Importances")
plt.xlabel("Importance")
plt.show()


## 7. (Optional) SHAP: Model Explainability


In [None]:
# SHAP is optional: skip the explainability plot when the package is missing.
try:
    import shap
except ImportError:
    print("SHAP not installed (pip install shap to enable explainability plots).")
else:
    explainer = shap.TreeExplainer(best_rf)
    shap_values = explainer.shap_values(X_test)
    # Select the SHAP values of the positive class. Older SHAP releases return
    # a list with one (n_samples, n_features) array per class; newer releases
    # return a single (n_samples, n_features, n_classes) array, where indexing
    # with [1] would pick the second sample rather than the second class.
    if isinstance(shap_values, list):
        shap_positive = shap_values[1]
    else:
        shap_array = np.asarray(shap_values)
        shap_positive = shap_array[:, :, 1] if shap_array.ndim == 3 else shap_array
    shap.summary_plot(shap_positive, X_test)


## 8. Subgroup Survival Analysis (Kaplan-Meier by sex, diabetes, high BP)


In [None]:
from lifelines import KaplanMeierFitter

# One figure per grouping variable, with one Kaplan-Meier curve per level.
for col in ['sex','diabetes','high_blood_pressure']:
    fig, ax = plt.subplots()
    for val in sorted(df[col].unique()):
        in_group = df[col] == val
        durations = df.loc[in_group, 'time']
        events = df.loc[in_group, 'DEATH_EVENT']
        kmf = KaplanMeierFitter()
        kmf.fit(durations, event_observed=events, label=f"{col}={val}")
        kmf.plot_survival_function(ax=ax)
    ax.set_title(f"Survival Curve by {col.title()}")
    ax.set_xlabel("Days")
    ax.set_ylabel("Survival Probability")
    ax.legend()
    plt.show()


## 9. Clinical Insights & Next Steps

- Imbalance addressed with SMOTE and class weights.
- Hyperparameter tuning improved model ROC AUC.
- Key predictors: age, ejection fraction, serum creatinine, comorbidity.
- High-risk subgroups: age > 70, high serum creatinine, multiple comorbidities.
- Next steps: external validation, deeper explainability, and more advanced models if needed.
