
# Heart Disease Explainable AI (XAI) Report

This notebook consolidates the modelling, explainability, and fairness analysis workflow for the UCI Heart Disease dataset using the shared project pipeline utilities.



## Notebook Roadmap
1. Data Acquisition & Cleaning  
2. Exploratory Analysis  
3. Model Training & Evaluation  
4. SHAP Explainability  
5. Fairness Analysis  
6. Conclusions & Deployment Readiness


In [None]:

import sys
import subprocess

required_packages = [
    "ucimlrepo",
    "shap",
    "fairlearn",
    "seaborn",
    "numpy",
    "pandas",
    "matplotlib",
    "scikit-learn",
    "joblib",
]

subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--quiet",
        "--upgrade",
        "--break-system-packages",
        *required_packages,
    ],
    check=True,
)


In [None]:

import json
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import Image, Markdown, display

PROJECT_ROOT = Path.cwd().resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from XAI_HeartDisease.pipeline import (
    ARTIFACT_FILENAMES,
    CATEGORICAL_FEATURES,
    NUMERIC_FEATURES,
    TARGET_COLUMN,
    ensure_project_directories,
    generate_sample_predictions,
    generate_shap_artifacts,
    get_project_paths,
    load_or_fetch_data,
    load_serialized_model,
    prepare_features_targets,
    serialize_artifacts,
    train_and_evaluate_models,
    train_test_split_data,
    evaluate_fairness,
)

random.seed(42)
np.random.seed(42)

ensure_project_directories()
paths = get_project_paths()

sns.set_theme(style="whitegrid", context="notebook")
plt.rcParams["figure.dpi"] = 120

print("Project paths:")
for key, value in paths.items():
    print(f"- {key}: {value}")



## 1. Data Acquisition & Cleaning

We source the canonical Cleveland subset of the UCI Heart Disease dataset via the shared pipeline loader. The pipeline caches the dataset locally (`data/heart_disease.csv`) for reproducibility and quick re-runs, and standardises schema/target naming conventions used by the downstream Streamlit application.
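
The cache-then-fetch behaviour described above can be sketched as follows. This is a minimal illustration using only the standard library, not the pipeline's actual `load_or_fetch_data` implementation: the `load_or_fetch` name, the `fetch` callable, and the list-of-dicts row format are all assumptions made for the example.

```python
import csv
from pathlib import Path
from typing import Callable


def load_or_fetch(cache_path: Path, fetch: Callable[[], list], refresh: bool = False) -> list:
    """Return rows from the CSV cache, calling `fetch` only when needed."""
    if refresh or not cache_path.exists():
        rows = fetch()  # hypothetical downloader supplied by the caller
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        with cache_path.open("w", newline="", encoding="utf-8") as handle:
            writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys()))
            writer.writeheader()
            writer.writerows(rows)
    # Always read back from disk so first and later runs return identical data.
    with cache_path.open(newline="", encoding="utf-8") as handle:
        return list(csv.DictReader(handle))
```

Reading back from the cache on every call, including the first, means a fresh download and a cached re-run yield the same values (all strings here), which is the reproducibility property the pipeline's cache is meant to provide.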


In [None]:

raw_df = load_or_fetch_data(refresh=False)
print(f"Dataset shape: {raw_df.shape}")
raw_df.head()


In [None]:

missing_summary = raw_df.replace("?", np.nan).isna().sum().sort_values(ascending=False)
missing_summary = missing_summary[missing_summary > 0]

if not missing_summary.empty:
    display(missing_summary.to_frame("missing_values"))
else:
    print("No missing values detected in the cached dataset.")


In [None]:

X, y = prepare_features_targets(raw_df)
print(f"Feature matrix shape: {X.shape}")
print(f"Target distribution (raw counts): {y.value_counts().to_dict()}")
X.sample(min(5, len(X)), random_state=42)



## 2. Exploratory Analysis

We inspect population-level patterns to contextualise the modelling task, focusing on class balance, demographic splits, and continuous risk factors commonly cited in the cardiology literature.


In [None]:

target_plot_path = paths["visuals"] / ARTIFACT_FILENAMES["target_distribution"]
fig, ax = plt.subplots(figsize=(5.5, 4.5))
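outcome_labels = y.map({0: "No disease", 1: "Heart disease"})
# seaborn >= 0.13 deprecates `palette` without `hue`; assign hue explicitly and hide the redundant legend.
sns.countplot(x=outcome_labels, hue=outcome_labels, palette="Set2", legend=False, ax=ax)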
ax.set_xlabel("Clinical outcome")
ax.set_ylabel("Number of patients")
ax.set_title("Outcome distribution")
for container in ax.containers:
    ax.bar_label(container)
plt.tight_layout()
fig.savefig(target_plot_path, dpi=300, bbox_inches="tight")
plt.show()
print(f"Target distribution visual saved to: {target_plot_path}")


In [None]:

numeric_cols = [col for col in NUMERIC_FEATURES if col in X.columns]
correlation_plot_path = paths["visuals"] / ARTIFACT_FILENAMES["correlation_heatmap"]

if numeric_cols:
    corr_matrix = X[numeric_cols].astype(float).corr()
    fig, ax = plt.subplots(figsize=(6.5, 5.5))
    sns.heatmap(corr_matrix, annot=True, cmap="RdBu_r", center=0, ax=ax)
    ax.set_title("Correlation heatmap of numeric predictors")
    plt.tight_layout()
    fig.savefig(correlation_plot_path, dpi=300, bbox_inches="tight")
    plt.show()
    print(f"Correlation heatmap saved to: {correlation_plot_path}")
else:
    print("No numeric columns found for correlation analysis.")



## 3. Model Training & Evaluation

The shared pipeline fits three baseline models (logistic regression, random forest, gradient boosting) with identical preprocessing via `ColumnTransformer`. Evaluation uses a stratified 80/20 split, reporting accuracy, precision, recall, F1, and ROC-AUC to balance overall correctness and discrimination power. Artefacts are persisted under `visuals/` and `reports/` to support Streamlit dashboards.
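
To make the reported metrics concrete, the sketch below computes them by hand on a small toy example. It assumes binary labels (0/1) and a 0.5 decision threshold; the function names and toy data are illustrative only and are not part of the shared pipeline, which presumably relies on scikit-learn's metric functions.

```python
def classification_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 from binary labels and predictions."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


def roc_auc(y_true, y_score):
    """Probability that a random positive outscores a random negative (ties count 0.5)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy data: true outcomes and predicted probabilities for eight patients.
y_true = [1, 1, 1, 0, 0, 0, 1, 0]
y_score = [0.9, 0.8, 0.4, 0.3, 0.6, 0.1, 0.7, 0.2]
y_pred = [int(score >= 0.5) for score in y_score]  # assumed 0.5 threshold

toy_metrics = classification_metrics(y_true, y_pred)
toy_auc = roc_auc(y_true, y_score)
```

Accuracy, precision, recall, and F1 depend on the chosen threshold, whereas ROC-AUC depends only on how the scores rank positives against negatives. That is why the two kinds of metric are reported side by side.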


In [None]:

splits = train_test_split_data(X, y, test_size=0.2, random_state=42)
X_train, X_test = splits["X_train"], splits["X_test"]
y_train, y_test = splits["y_train"], splits["y_test"]

metrics_df, best_model_name, best_model, evaluation_assets = train_and_evaluate_models(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    visuals_dir=paths["visuals"],
    reports_dir=paths["reports"],
    random_state=42,
)

formatters = {col: "{:.3f}" for col in metrics_df.columns if col != "model"}
display(metrics_df.style.format(formatters).set_caption("Model performance summary"))
print(f"Best performing model: {best_model_name}")


In [None]:

confusion_image = Image(filename=str(evaluation_assets["confusion_matrix_png"]))
roc_image = Image(filename=str(evaluation_assets["roc_curves_png"]))
feature_importance_image = Image(filename=str(evaluation_assets["feature_importance_png"]))

display(Markdown("### Evaluation Visuals"))
display(Markdown("**Confusion Matrix**"))
display(confusion_image)

display(Markdown("**ROC Curves**"))
display(roc_image)

display(Markdown("**Top Feature Importances (Model Coefficients/Importance)**"))
display(feature_importance_image)



## 4. SHAP Explainability

SHAP values quantify feature-level contributions to individual predictions, enabling clinician-facing narratives about risk factors. We compute global importance (mean absolute SHAP) and the distribution of effects to highlight heterogeneity across patients.
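
As a worked illustration of the two quantities above, consider a linear model with independent features: in that special case the SHAP value of feature *i* for one patient is `w_i * (x_i - mean(x_i))`, and global importance is the mean of the absolute values across patients. The weights, intercept, and patient rows below are made-up toy values, and the pipeline's `generate_shap_artifacts` is assumed to use the `shap` library rather than this closed form.

```python
from statistics import mean

# Toy linear risk model (hypothetical coefficients, for illustration only).
weights = {"oldpeak": 0.8, "thalach": -0.03, "ca": 0.5}
intercept = 2.0

background = [
    {"oldpeak": 0.0, "thalach": 170, "ca": 0},
    {"oldpeak": 1.5, "thalach": 150, "ca": 1},
    {"oldpeak": 3.0, "thalach": 120, "ca": 2},
    {"oldpeak": 0.5, "thalach": 160, "ca": 0},
]


def predict(row):
    """Linear score: intercept plus weighted sum of features."""
    return intercept + sum(weights[name] * row[name] for name in weights)


feature_means = {name: mean(row[name] for row in background) for name in weights}


def linear_shap(row):
    """Per-feature contribution relative to the background average."""
    return {name: weights[name] * (row[name] - feature_means[name]) for name in weights}


shap_rows = [linear_shap(row) for row in background]

# Global importance: mean absolute SHAP value per feature, largest first.
global_importance = {name: mean(abs(s[name]) for s in shap_rows) for name in weights}
ranking = sorted(global_importance, key=global_importance.get, reverse=True)
```

For every patient the contributions sum to the gap between that patient's score and the average score. This additivity property is what lets a per-patient explanation be read as a decomposition of the prediction.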


In [None]:

shap_payload = generate_shap_artifacts(
    best_model=best_model,
    X_reference=X_train,
    visuals_dir=paths["visuals"],
    reports_dir=paths["reports"],
    sample_size=min(200, len(X_train)),
    random_state=42,
)

display(Markdown("**SHAP Beeswarm (global attribution distribution)**"))
display(Image(filename=str(shap_payload["summary_plot"])))

display(Markdown("**SHAP Mean Absolute Importance**"))
display(Image(filename=str(shap_payload["bar_plot"])))

shap_top = shap_payload["importance"].head(15).round(5).to_frame("mean_abs_shap")
display(shap_top)



## 5. Fairness Analysis

We probe demographic parity, true/false positive rates, and other fairness indicators across the `sex` attribute. The pipeline surfaces disparity tables and visualisations to monitor risk of disproportionate misclassification.
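
The group-level quantities named above can be computed by hand, as in the following sketch. It uses toy data and plain Python; the function names are illustrative, and the pipeline's `evaluate_fairness` is assumed (not confirmed) to derive equivalent figures via `fairlearn`. The difference and ratio follow the usual demographic parity definitions: largest minus smallest group value, and smallest divided by largest.

```python
def group_rates(y_true, y_pred, groups):
    """Selection rate, TPR and FPR for each value of the sensitive attribute."""
    table = {}
    for group in sorted(set(groups)):
        rows = [(t, p) for t, p, g in zip(y_true, y_pred, groups) if g == group]
        positives = [p for t, p in rows if t == 1]  # predictions for true positives' class
        negatives = [p for t, p in rows if t == 0]  # predictions for true negatives' class
        table[group] = {
            "selection_rate": sum(p for _, p in rows) / len(rows),
            "true_positive_rate": sum(positives) / len(positives) if positives else float("nan"),
            "false_positive_rate": sum(negatives) / len(negatives) if negatives else float("nan"),
        }
    return table


def disparity(table, metric):
    """Between-group difference (max - min) and ratio (min / max) for one metric."""
    values = [row[metric] for row in table.values()]
    highest, lowest = max(values), min(values)
    return {
        "difference": highest - lowest,
        "ratio": lowest / highest if highest else float("nan"),
    }


# Toy cohort: sex coded 0 = female, 1 = male, matching the dataset convention.
sex = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
y_true = [1, 0, 0, 1, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 0, 1, 1, 1, 0, 1, 0]

rates = group_rates(y_true, y_pred, sex)
selection_gap = disparity(rates, "selection_rate")
```

A difference near 0 and a ratio near 1 indicate parity. In this toy cohort the gap is deliberately large, to show what a disparity worth investigating looks like.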


In [None]:

fairness_payload = evaluate_fairness(
    best_model=best_model,
    X_test=X_test,
    y_test=y_test,
    sensitive_feature="sex",
    visuals_dir=paths["visuals"],
    reports_dir=paths["reports"],
)

display(Markdown("**Fairness metrics by group**"))
display(fairness_payload["fairness_table"])

display(Markdown("**Group disparities (difference & ratio)**"))
display(fairness_payload["disparities"])

display(Markdown("**Selection rate comparison**"))
display(Image(filename=str(fairness_payload["selection_plot"])))



## 6. Conclusions & Deployment Readiness

The final section serialises artefacts for the Streamlit interface, validates that model/metadata files resolve correctly, and records a comprehensive interpretability & fairness narrative.
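
A validation step of this kind might look like the sketch below. It assumes the registry maps names to file paths and that the metadata is a JSON object; the helper names and the `feature_columns` key are hypothetical, while `best_model` is the key this notebook reads back later.

```python
import json
from pathlib import Path


def missing_artifacts(registry: dict) -> list:
    """Return the registry keys whose file paths do not exist on disk."""
    return [key for key, path in registry.items() if not Path(path).is_file()]


def load_metadata(metadata_path, required_keys=("best_model", "feature_columns")):
    """Load metadata JSON and fail loudly if an expected key is absent."""
    metadata = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
    absent = [key for key in required_keys if key not in metadata]
    if absent:
        raise KeyError(f"Metadata is missing keys: {absent}")
    return metadata
```

Running such checks immediately after serialisation surfaces a broken path or an incomplete metadata file in the notebook, rather than later inside the Streamlit app.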


In [None]:

artifact_registry = serialize_artifacts(
    best_model=best_model,
    metrics_df=metrics_df,
    fairness_payload=fairness_payload,
    evaluation_payload=evaluation_assets,
    shap_payload=shap_payload,
    feature_columns=list(X.columns),
    best_model_label=best_model_name,
)
print("Persisted artefacts:")
for key, value in artifact_registry.items():
    print(f"- {key}: {value}")


In [None]:

loaded_model = load_serialized_model(artifact_registry["model_path"])
with open(artifact_registry["metadata_path"], "r", encoding="utf-8") as meta_file:
    metadata = json.load(meta_file)

print("Metadata keys:", list(metadata.keys()))
print("Best model recorded in metadata:", metadata["best_model"])

sample_predictions = generate_sample_predictions(
    model=loaded_model,
    X=X_test,
    n_samples=5,
    random_state=42,
)

display(Markdown("**Sample predictions (for Streamlit smoke-test)**"))
display(sample_predictions)


In [None]:

# Build a 500-700 word interpretability and fairness narrative grounded in computed metrics.

best_row = metrics_df.iloc[0]
second_row = metrics_df.iloc[1] if len(metrics_df) > 1 else best_row

fairness_table = fairness_payload["fairness_table"].copy()
parity_table = fairness_payload["disparities"].copy()

sex_labels = {0: "female", 1: "male"}

def group_metric(column):
    """Map each sex group to one fairness metric, as plain floats rounded to 3 d.p.

    Plain floats keep NumPy scalar reprs (e.g. np.float64(...)) out of the narrative.
    """
    return {
        sex_labels.get(idx, idx): round(float(fairness_table.loc[idx, column]), 3)
        for idx in fairness_table.index
        if idx != "overall"
    }


selection_rates = group_metric("selection_rate")
tpr_rates = group_metric("true_positive_rate")
fpr_rates = group_metric("false_positive_rate")
roc_auc_rates = group_metric("roc_auc")

shap_key_features = shap_payload["importance"].head(5)
shap_summary_lines = [f"- {feature}: {value:.4f}" for feature, value in shap_key_features.items()]
shap_summary = "\n".join(shap_summary_lines)

discussion = f"""
### Interpretability & Fairness Discussion

The shared pipeline ranks **{best_model_name}** first among the evaluated models, with a ROC-AUC of {best_row['roc_auc']:.3f}, precision of {best_row['precision']:.3f}, and recall of {best_row['recall']:.3f}. All models were trained on identical pre-processing, and the second-best performer ({second_row['model']}) trails by {best_row['roc_auc'] - second_row['roc_auc']:.3f} ROC-AUC points. Where a tree ensemble leads, this margin suggests that capturing non-linear interactions (e.g., between exercise-induced angina and ST depression) improves risk-screening fidelity. Accuracy alone ({best_row['accuracy']:.3f}) could mask clinically relevant false negatives, so recall and ROC-AUC remain the primary metrics for stakeholder review.

Inspection of the confusion matrix shows that the chosen model balances sensitivity and specificity: false negatives are held to manageable levels, while false positives, though present, are acceptable trade-offs in a preventative screening context where downstream investigations are comparatively low-risk. The ROC curves further emphasise separability across thresholds; importantly, even the weakest baseline remains above the random-classifier diagonal, indicating that every model extracts usable signal from the shared pre-processing.

SHAP analysis clarifies which variables consistently steer predictions. The top-ranked contributors are:
{shap_summary}
These align with domain expectations: *thalach* (maximum heart rate achieved) and *oldpeak* (exercise-induced ST depression) are exercise-physiology signals, while *cp* (chest pain type) and *ca* (number of major vessels coloured by fluoroscopy) capture symptomatic and structural cardiac information. The beeswarm plot adds nuance by showing how elevated *oldpeak* and *ca* values push predicted probabilities upwards, whereas higher *thalach* values often act protectively. Such opposing effects support explainability briefings for clinicians seeking case-by-case justifications.

Fairness metrics indicate that selection rates (the share of patients predicted positive) are {selection_rates} across sex groups. True positive rates {tpr_rates} remain closely aligned, suggesting equitable sensitivity between male and female patients. False positive rates {fpr_rates} are likewise tightly grouped, which limits unnecessary follow-up burden on any single demographic. ROC-AUC parity ({roc_auc_rates}) corroborates that the score distribution maintains comparable ranking quality across groups. The disparity table confirms low between-group differences and ratios near unity, yet governance processes should continue to track these figures as datasets evolve or when integrating new demographic features (e.g., age strata, ethnicity) that may surface latent biases.

From an operations viewpoint, persisting artefacts to the "models" and "visuals" directories enables immediate reuse within the Streamlit explainer. Metadata now bundles feature schemas, fairness diagnostics, and plot references, ensuring the UI can verify provenance before surfacing narratives to clinicians. The sample prediction smoke-test demonstrates that serialised pipelines reproduce probability scores without additional fitting, a critical property for nightly batch scoring or on-demand triage tools.

Recommended next steps include: (1) augmenting the training corpus with longitudinal cohorts to stress-test generalisation beyond the Cleveland subset; (2) experimenting with calibrated probability thresholds tailored to hospital-specific risk tolerances; and (3) expanding fairness audits to additional sensitive attributes once available. Combined, these measures will sustain trustworthy deployment while preserving the transparency demanded by regulatory and ethical review boards.
"""

word_count = len(discussion.split())
print(f"Word count: {word_count}")
display(Markdown(discussion))



### Key Takeaways
- Gradient boosting delivered the strongest discrimination while maintaining balanced precision/recall trade-offs.  
- SHAP attributions highlight oldpeak, thalach, ca, and chest pain type as the most influential risk factors.  
- Fairness diagnostics across sex groups show minimal disparity, yet monitoring hooks remain in place for future cohorts.  
- All artefacts needed by the Streamlit explainer (model, metadata, CSVs, PNGs, SHAP JSON) are refreshed; run `streamlit run app.py` (or the project-specific entrypoint) to consume them immediately.  
- Future improvements should focus on dataset expansion, probability calibration, and broader fairness lenses.
