# DeepMind-Inspired Cardiovascular Risk Prediction

**Hack4Health Project**

This notebook implements an uncertainty-aware framework for heart disease prediction. Alongside a standard XGBoost baseline, it uses:
1. **TabPFN (Foundation Model):** A Transformer pre-trained on synthetic datasets drawn from a prior over tabular data. It predicts on a new dataset in a single forward pass, with no hyperparameter tuning.
2. **Uncertainty Quantification (MC Dropout):** An approximate Bayesian estimate of how confident the model is in each individual prediction.
3. **Concept-Level Interpretability:** Per-feature SHAP attributions summed into clinical concepts (Vitals, Lifestyle, Demographics) for doctor-friendly explanations.


In [None]:
import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import shap
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, roc_curve

# Add src to path
sys.path.append(os.path.abspath('../src'))

try:
    from utils_data import load_and_preprocess_data, get_concept_map
    from utils_model import get_xgboost, get_tabpfn, UncertaintyModel
except ImportError as err:
    # Every later cell depends on these helpers, so stop here rather than fail further down
    raise ImportError(
        "Source modules not found. Run this notebook from the 'notebooks' directory "
        "and make sure '../src' exists."
    ) from err


In [None]:
# Load Data
# 'heart_processed.csv' is the primary dataset with detailed clinical features;
# 'cardio_base.csv' is an optional base dataset passed to load_and_preprocess_data.
DATA_PATH_PROC = '../Data/Heart Attack/heart_processed.csv'
DATA_PATH_BASE = '../Data/Cardiac Failure/cardio_base.csv'

# Set DATA_PATH_BASE to None to use only the high-fidelity processed dataset.
X, y, concept_map = load_and_preprocess_data(DATA_PATH_PROC, base_path=DATA_PATH_BASE)

print("Concepts Defined:", list(concept_map.keys()))

# Stratify so the train and test sets keep the same disease prevalence
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
print(f"Training Data Shape: {X_train.shape}")
print(f"Testing Data Shape: {X_test.shape}")


In [None]:
# Configuration
sns.set_style("whitegrid")
plt.rcParams['font.size'] = 12


## 1. Baseline Model: XGBoost
Gradient-boosted trees are a strong, widely used baseline for tabular data.

In [None]:
xgb_model = get_xgboost()
xgb_model.fit(X_train, y_train)

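y_pred_xgb = xgb_model.predict(X_test)
y_proba_xgb = xgb_model.predict_proba(X_test)[:, 1]  # positive-class probability
print("=== XGBoost Classification Report ===")
print(classification_report(y_test, y_pred_xgb))
print(f"ROC AUC: {roc_auc_score(y_test, y_proba_xgb):.3f}")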


## 2. Novel Model: TabPFN (Transformer)
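A Prior-Data Fitted Network: a Transformer that performs in-context learning. It treats the training set as context and predicts test labels in a single forward pass. TabPFN is designed for small datasets, where it often performs competitively without hyperparameter tuning. It caps the number of training rows and features, and those limits depend on the installed version, so a large merged dataset may need to be subsampled.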
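TabPFN's input limits matter here. The original release was designed for at most about 1,000 training rows and 100 features, and newer releases raise those limits. A merged dataset can exceed them. Below is a minimal sketch, assuming you cap the training set with a stratified subsample before calling `tabpfn_model.fit`. `subsample_for_tabpfn` is a hypothetical helper and is not part of `utils_model` or the `tabpfn` package. The 1,000-row default is an assumption; adjust it for your installed version.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def subsample_for_tabpfn(X, y, max_rows=1000, random_state=42):
    """Hypothetical helper: cap the TabPFN training set at `max_rows` rows.

    A stratified split keeps the class ratio of the subsample close to
    that of the full training set. Inputs that are already small enough
    are returned unchanged.
    """
    if len(X) <= max_rows:
        return X, y
    X_sub, _, y_sub, _ = train_test_split(
        X, y, train_size=max_rows, stratify=y, random_state=random_state
    )
    return X_sub, y_sub


# Demo on synthetic data standing in for a large merged training set
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(5000, 8))
y_demo = (rng.random(5000) < 0.3).astype(int)
X_small, y_small = subsample_for_tabpfn(X_demo, y_demo, max_rows=1000)
print(X_small.shape, round(y_demo.mean(), 3), round(y_small.mean(), 3))
```

In the notebook, you would apply this to `X_train, y_train` only. The test set stays untouched, so both models are still evaluated on the same patients.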

In [None]:
print("Fitting TabPFN (Pre-trained foundation model)...")
tabpfn_model = get_tabpfn()
tabpfn_model.fit(X_train, y_train)

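y_pred_tab = tabpfn_model.predict(X_test)
y_proba_tab = tabpfn_model.predict_proba(X_test)[:, 1]  # positive-class probability
print("=== TabPFN Classification Report ===")
print(classification_report(y_test, y_pred_tab))
print(f"ROC AUC: {roc_auc_score(y_test, y_proba_tab):.3f}")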
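
Whether TabPFN actually generalizes better than XGBoost on this data is an empirical question, and a single difference in test-set scores can be noise. One way to check is a paired bootstrap of the AUC difference over test patients. This is sketched here as an assumption, not as part of the original notebook. It needs each model's positive-class probabilities, for example `predict_proba(X_test)[:, 1]` for classifiers with the scikit-learn interface. `bootstrap_auc_diff` is a hypothetical helper.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def bootstrap_auc_diff(y_true, proba_a, proba_b, n_boot=1000, seed=0):
    """Paired bootstrap of AUC(a) - AUC(b) over the same test patients.

    Each resample draws patients with replacement and scores both models
    on the identical resample. Resamples containing a single class are
    skipped because AUC is undefined there. Returns the observed
    difference and a 95% percentile interval.
    """
    y_true = np.asarray(y_true)
    proba_a = np.asarray(proba_a)
    proba_b = np.asarray(proba_b)
    rng = np.random.default_rng(seed)
    n = len(y_true)
    diffs = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        if np.unique(y_true[idx]).size < 2:
            continue
        diffs.append(
            roc_auc_score(y_true[idx], proba_a[idx])
            - roc_auc_score(y_true[idx], proba_b[idx])
        )
    observed = roc_auc_score(y_true, proba_a) - roc_auc_score(y_true, proba_b)
    low, high = np.percentile(diffs, [2.5, 97.5])
    return observed, (low, high)


# Demo: model A separates the classes perfectly, model B is random noise
rng = np.random.default_rng(1)
y_demo = rng.integers(0, 2, size=200)
proba_good = y_demo * 0.8 + rng.random(200) * 0.1
proba_noise = rng.random(200)
diff, (low, high) = bootstrap_auc_diff(y_demo, proba_good, proba_noise, n_boot=300)
print(f"AUC difference: {diff:.3f}, 95% CI: ({low:.3f}, {high:.3f})")
```

With the real models, assuming both expose `predict_proba`, you would pass `y_test` and the two probability vectors. An interval that excludes zero suggests the gap is more than resampling noise.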


## 3. Uncertainty Quantification (MC Dropout)
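Medical AI requires knowing *when* the model is unsure. Monte Carlo (MC) Dropout keeps dropout active at prediction time and runs many stochastic forward passes. The mean of these outputs is the predicted risk, and their standard deviation approximates the model's (epistemic) uncertainty.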
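The mechanism can be illustrated without a deep-learning framework. Below is a minimal NumPy sketch for a one-hidden-layer network with random stand-in weights. It is not the implementation of `UncertaintyModel`, whose internals are not shown here. `mc_dropout_predict`, `p_drop`, and the weight names are assumptions.

```python
import numpy as np


def mc_dropout_predict(X, W1, b1, w2, b2, p_drop=0.2, n_samples=100, seed=0):
    """MC Dropout for a one-hidden-layer binary classifier (NumPy only).

    Dropout stays active at prediction time: each pass samples a fresh
    mask over the hidden units. Returns the mean and standard deviation
    of the predicted probability across passes, one value per patient.
    """
    rng = np.random.default_rng(seed)
    hidden = np.maximum(0.0, X @ W1 + b1)  # ReLU activations, shape (n, h)
    probs = np.empty((n_samples, X.shape[0]))
    for t in range(n_samples):
        # Inverted dropout: drop units with probability p_drop, rescale survivors
        keep = rng.random(hidden.shape) >= p_drop
        dropped = hidden * keep / (1.0 - p_drop)
        logits = dropped @ w2 + b2
        probs[t] = 1.0 / (1.0 + np.exp(-logits))
    return probs.mean(axis=0), probs.std(axis=0)


# Demo with random weights standing in for a trained network
rng = np.random.default_rng(1)
X_demo = rng.normal(size=(4, 5))
W1, b1 = rng.normal(size=(5, 16)), np.zeros(16)
w2, b2 = rng.normal(size=16), 0.0
mean_p, std_p = mc_dropout_predict(X_demo, W1, b1, w2, b2)
print(np.round(mean_p, 3), np.round(std_p, 3))
```

Setting `p_drop=0` removes the randomness: every pass is identical and the standard deviation is zero, which makes a quick sanity check.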

In [None]:
print("Training Uncertainty-Aware Network...")
unc_model = UncertaintyModel(epochs=500, lr=0.005)
unc_model.fit(X_train, y_train)

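mean_preds, std_preds = unc_model.predict_uncertainty(X_test, n_samples=100)
# Flatten in case the model returns column vectors of shape (n, 1)
mean_preds, std_preds = np.ravel(mean_preds), np.ravel(std_preds)

# Treat the 10% most uncertain predictions as the high-uncertainty zone
unc_threshold = np.quantile(std_preds, 0.9)

# Visualize
plt.figure(figsize=(10, 6))
sc = plt.scatter(mean_preds, std_preds, c=y_test, cmap='coolwarm', alpha=0.6, edgecolors='k')
plt.axhline(unc_threshold, color='red', linestyle='--', label='90th percentile of uncertainty')
plt.axvline(0.5, color='gray', linestyle=':', label='Decision threshold (0.5)')
plt.xlabel('Predicted Risk (Probability)')
plt.ylabel('Model Uncertainty (Std Dev)')
plt.title('Risk vs. Confidence (DeepMind-Style)')
plt.colorbar(sc, label='Actual Heart Disease (1=Yes)')
plt.text(0.5, std_preds.max() * 0.95, 'High Uncertainty Zone', ha='center', color='red')
plt.legend(loc='upper right')
plt.show()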
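
One way to act on the uncertainty estimate is selective prediction. The model decides only the cases it is confident about and refers the rest to a clinician. Below is a minimal sketch, assuming 1-D arrays of predicted risk and standard deviation like `mean_preds` and `std_preds` above. `triage_by_uncertainty`, the quantile cutoff, and the 0.5 decision threshold are illustrative choices and are not part of `UncertaintyModel`.

```python
import numpy as np


def triage_by_uncertainty(mean_preds, std_preds, y_true, quantile=0.9, threshold=0.5):
    """Split patients into 'auto' and 'refer' groups by predictive std.

    Patients whose std exceeds the given quantile are referred for
    clinician review. Returns the std cutoff, the fraction decided
    automatically (coverage), and accuracy within each group.
    """
    mean_preds = np.ravel(mean_preds)
    std_preds = np.ravel(std_preds)
    y_true = np.ravel(y_true)
    cutoff = np.quantile(std_preds, quantile)
    refer = std_preds > cutoff
    correct = (mean_preds >= threshold).astype(int) == y_true
    return {
        "cutoff": cutoff,
        "coverage": 1.0 - refer.mean(),
        "acc_auto": correct[~refer].mean() if (~refer).any() else np.nan,
        "acc_refer": correct[refer].mean() if refer.any() else np.nan,
    }


# Toy example: the two uncertain patients are also the two misclassified ones
mean_demo = np.array([0.9, 0.1, 0.8, 0.55, 0.45])
std_demo = np.array([0.01, 0.02, 0.03, 0.20, 0.25])
y_demo = np.array([1, 0, 1, 0, 1])
result = triage_by_uncertainty(mean_demo, std_demo, y_demo, quantile=0.6)
print(result)
```

In the notebook this would be `triage_by_uncertainty(mean_preds, std_preds, y_test)`. If the uncertainty estimate is informative, accuracy on the automatically decided group should exceed accuracy on the referred group.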


## 4. Concept-Based Interpretability
Instead of raw per-feature importance, we sum the absolute SHAP values of the features in each clinical concept. This gives one importance score per concept per patient.
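Summing absolute SHAP values measures how much a concept's features move the prediction in total, but it hides cancellation. Suppose one vital sign raises a patient's risk and another lowers it by the same amount: that patient gets a large magnitude but a net effect of zero. Because SHAP values are additive, the signed sum within a concept gives its net push on the model output. Below is a minimal sketch that computes both, assuming `shap_values` is a 2-D array (patients × features) and `concept_map` maps concept names to feature names as in the notebook. `aggregate_concepts` and the toy feature names are hypothetical.

```python
import numpy as np
import pandas as pd


def aggregate_concepts(shap_values, columns, concept_map):
    """Aggregate per-feature SHAP values into per-concept scores.

    Returns two DataFrames (rows = patients, columns = concepts):
    - magnitude: sum of |SHAP| within a concept (total activity)
    - net: signed sum of SHAP within a concept (net push on model output)
    Features missing from `columns` are ignored.
    """
    shap_df = pd.DataFrame(shap_values, columns=columns)
    magnitude, net = {}, {}
    for concept, features in concept_map.items():
        present = [f for f in features if f in shap_df.columns]
        if present:
            magnitude[concept] = shap_df[present].abs().sum(axis=1)
            net[concept] = shap_df[present].sum(axis=1)
    return pd.DataFrame(magnitude), pd.DataFrame(net)


# Toy SHAP matrix: 2 patients x 3 hypothetical features
toy_shap = np.array([[0.5, -0.5, 0.2],
                     [0.1, 0.3, -0.4]])
toy_cols = ["sys_bp", "dia_bp", "smoker"]
toy_map = {"Vitals": ["sys_bp", "dia_bp"], "Lifestyle": ["smoker", "alcohol"]}
magnitude, net = aggregate_concepts(toy_shap, toy_cols, toy_map)
print(magnitude, net, sep="\n\n")
```

Patient 0 shows the difference. Vitals has a magnitude of 1.0 but a net contribution of 0.0, because its two features cancel.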

In [None]:
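# Calculate SHAP values for the baseline (XGBoost); TreeExplainer explains the raw (log-odds) output
explainer = shap.TreeExplainer(xgb_model)
shap_values = explainer.shap_values(X_test)
# Some shap versions return one array per class for binary classifiers; keep the positive class
if isinstance(shap_values, list):
    shap_values = shap_values[1]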

# Aggregate by Concept
concept_importance = pd.DataFrame(index=X_test.index)
abs_shap = np.abs(shap_values)

for concept, features in concept_map.items():
    # Find column indices for this concept
    indices = [X_test.columns.get_loc(f) for f in features if f in X_test.columns]
    if indices:
        # Sum absolute SHAP values for these features
        concept_importance[concept] = abs_shap[:, indices].sum(axis=1)

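# Plot mean importance per concept (seaborn's ci= argument is deprecated, so plot the means directly)
mean_importance = concept_importance.mean().sort_values(ascending=False)
plt.figure(figsize=(10, 5))
plt.bar(mean_importance.index, mean_importance.values,
        color=sns.color_palette("viridis", len(mean_importance)))
plt.title("Drivers of Risk by Clinical Concept")
plt.ylabel("Mean |SHAP| (log-odds contribution)")
plt.show()

print("This plot shows which concept group (e.g. Vitals, Lifestyle, or Demographics) drives predicted risk on average across the test set.")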
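
The bar chart averages over the whole test set. For a single patient, a clinician-facing explanation can use signed concept contributions instead. Because SHAP values are additive, the explainer's base value plus the sum of all contributions recovers the model's log-odds output for that patient. Below is a minimal sketch, assuming contributions in log-odds units (the default for `TreeExplainer` on an XGBoost binary classifier). `explain_patient` and the example numbers are hypothetical.

```python
import numpy as np


def explain_patient(net_row, base_value):
    """Turn one patient's signed concept contributions into text lines.

    net_row: mapping of concept name -> signed SHAP sum (log-odds units).
    base_value: the explainer's expected output (log-odds units).
    Returns the explanation lines and the reconstructed log-odds.
    """
    lines = []
    # Largest effects first, regardless of direction
    for concept, value in sorted(net_row.items(), key=lambda kv: -abs(kv[1])):
        if value > 0:
            lines.append(f"{concept} raises risk by {value:.2f} log-odds")
        elif value < 0:
            lines.append(f"{concept} lowers risk by {-value:.2f} log-odds")
        else:
            lines.append(f"{concept} has no effect on risk")
    log_odds = base_value + sum(net_row.values())
    prob = 1.0 / (1.0 + np.exp(-log_odds))
    lines.append(f"Predicted risk: {prob:.1%}")
    return lines, log_odds


# Hypothetical contributions for one patient
lines, log_odds = explain_patient(
    {"Vitals": 0.8, "Lifestyle": -0.3, "Demographics": 0.5}, base_value=-1.0
)
print("\n".join(lines))
```

In practice, `net_row` would be a dict of per-concept signed SHAP sums for one row of `X_test`, and `base_value` would come from `explainer.expected_value`. Some shap versions return `expected_value` as an array rather than a scalar.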
