In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import sklearn
from dmgpred.train import get_pipeline
from dmgpred.utils.loading import load_data
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    f1_score,
    matthews_corrcoef,
)
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

# Load SHAP's JavaScript into the notebook so interactive plots can render.
shap.initjs()

# Have scikit-learn transformers return pandas DataFrames; later cells rely on
# `.columns` of the transformed data.
sklearn.set_config(transform_output="pandas")
# Seed NumPy's global random number generator.
np.random.seed(0)

In [None]:
# Load the processed data set from the project data directory.
data = load_data(data_dir="../data/", processed=True)

# Features as-is; labels shifted down by one so class indices start at 0
# (later cells print them back as `class + 1`).
X, y = data["X_train"], data["y_train"] - 1

In [None]:
# Preview the first rows of the training features.
X.head()

In [None]:
# Hold out 20% of the labelled data for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=0,
)

In [None]:
# Build the project pipeline around an XGBoost classifier with 500 estimators and fit it
# on the training split.
pipe = get_pipeline(
    X_train,
    clf=XGBClassifier(n_estimators=500),
)
pipe.fit(X_train, y_train)

In [None]:
# Predict the held-out split and report the micro-averaged F1 score.
y_pred = pipe.predict(X_test)
f1_score(y_true=y_test, y_pred=y_pred, average="micro")

In [None]:
# Matthews correlation coefficient of the held-out predictions.
matthews_corrcoef(y_true=y_test, y_pred=y_pred)

In [None]:
# Pull the fitted preprocessing step and the fitted classifier out of the pipeline.
preprocessor, clf = pipe.named_steps["preprocessor"], pipe.named_steps["clf"]

In [None]:
# Apply the already-fitted preprocessor to both splits (train first, then test).
X_train_preprocessed, X_test_preprocessed = (
    preprocessor.transform(split) for split in (X_train, X_test)
)
# Column names of the transformed features, reused for labelling plots.
feature_names = X_test_preprocessed.columns

## Feature Importances

In [None]:
from xgboost import plot_importance

bst = clf.get_booster()
# Metric used to rank the features; "weight" and "cover" are the alternatives.
# Kept in one variable so the plotted metric and the title cannot drift apart.
importance_type = "gain"
plot_importance(
    bst,
    importance_type=importance_type,
    show_values=False,
    title=f"Feature importance ({importance_type})",
)

In [None]:
# Classifier importances labelled by feature name, sorted ascending so the largest
# bar ends up at the top of the horizontal bar chart.
feature_imp = pd.Series(clf.feature_importances_, index=feature_names).sort_values()

feature_imp.plot.barh(figsize=(10, 10))
plt.suptitle("Feature Importances (Gain)")
plt.tight_layout()
# Save before showing the figure.
plt.savefig("../output/feature_importance.svg")
plt.show()

## Confusion Matrices

In [None]:
# Draw on an explicitly sized Axes. Without `ax`, from_estimator creates its own
# default-sized figure, so a preceding plt.figure(figsize=...) would be ignored and
# left behind as an empty figure.
fig, ax = plt.subplots(figsize=(12, 12))
ConfusionMatrixDisplay.from_estimator(
    pipe,
    X_test,
    y_test,
    cmap="Blues",
    display_labels=["Grade 1", "Grade 2", "Grade 3"],
    ax=ax,
)
# Grid lines would cut through the matrix cells.
ax.grid(False)
fig.tight_layout()
fig.savefig("../output/confusion_matrix.svg")

Overall, we can see that class 1 (i.e. damage grade 2) is the most difficult class to predict. It is often confused with damage grade 3 and sometimes with damage grade 1.



## SHAP Values

In [None]:
# Tree-based SHAP explainer for the fitted classifier.
explainer = shap.TreeExplainer(model=clf)
# Calling the explainer yields an Explanation object for the preprocessed test split;
# later cells index it as [sample, feature, class].
explanation = explainer(X_test_preprocessed)

In [None]:
# Reuse the values already computed for `explanation` instead of running the
# explainer a second time on the same data. The array is indexed as
# [sample, feature, class] by the cells below.
shap_values = explanation.values

In [None]:
plt.figure(figsize=(20, 12))
shap.summary_plot(
    # One (n_samples, n_features) array per class for the multi-class bar plot.
    [shap_values[:, :, class_ind] for class_ind in range(shap_values.shape[-1])],
    X_test_preprocessed,
    plot_type="bar",
    # summary_plot resizes the current figure itself (plot_size="auto" by default),
    # which would override the figsize above; pass the intended size explicitly.
    plot_size=(20, 12),
    show=False,
)
plt.suptitle("SHAP Summary Plot")
plt.savefig("../output/shap_summary_plot.svg")
plt.show()

In [None]:
# Waterfall of the per-feature contributions for the first test sample, class index 0.
shap.plots.waterfall(explanation[0, :, 0])

In [None]:
# Class index whose SHAP values are visualised in the following cells.
cls = 0
# Force plot for the first test sample.
shap.plots.force(
    base_value=explainer.expected_value[cls],
    shap_values=shap_values[0, :, cls],
    features=X_test_preprocessed.iloc[0],
)

In [None]:
# Stacked force plot over every 100th test sample for class `cls`.
shap.plots.force(
    base_value=explainer.expected_value[cls],
    shap_values=shap_values[::100, :, cls],
    feature_names=feature_names,
)

In [None]:
# Summary plot of the SHAP values of class `cls` over the whole test split.
shap.summary_plot(shap_values[..., cls], features=X_test_preprocessed)

In [None]:
# Layered violin plot of the SHAP values of class `cls`.
shap.plots.violin(
    shap_values[..., cls],
    features=X_test_preprocessed,
    feature_names=feature_names,
    plot_type="layered_violin",
)

## Wrong Predictions

In [None]:
y_pred = pipe.predict(X_test)

# Preprocessed test features with the true and predicted class attached as extra columns.
tmp = X_test_preprocessed.assign(target=y_test, pred=y_pred)
# Keep only the rows where the prediction disagrees with the label.
misclassified = tmp.query("target != pred").copy()
# Same rows, feature columns only (input for the SHAP explainer).
X_mis = misclassified.drop(columns=["target", "pred"])
misclassified.head()

In [None]:
explainer_mis = shap.TreeExplainer(clf)
# Run the explainer once and derive the raw array from the Explanation object,
# instead of computing the SHAP values twice on the same data.
explanation_mis = explainer_mis(X_mis)
# Indexed as [sample, feature, class]; rows are in the same order as X_mis.
shap_values_mis = explanation_mis.values

### Visualizing a single prediction

In [None]:
# Row position (within `misclassified`) of the sample to inspect.
idx = 325
# Select the column first, then the row: `misclassified.iloc[idx]` builds a row Series
# across all columns, which upcasts the class to float when any feature is float and
# makes it unusable as an array index. Cast to a plain int for indexing and printing.
pred_cls = int(misclassified["pred"].iloc[idx])
true_cls = int(misclassified["target"].iloc[idx])
print(f"Predicted class: {pred_cls + 1} (actual: {true_cls + 1})")
# Contributions towards the (wrongly) predicted class.
shap.plots.force(
    explainer_mis.expected_value[pred_cls],
    shap_values_mis[idx, :, pred_cls],
    X_mis.iloc[idx],
)

In [None]:
# Waterfall of the contributions towards the TRUE class of the selected misclassified
# sample, limited to 10 displayed rows.
shap.plots.waterfall(explanation_mis[idx, :, true_cls], max_display=10, show=False)
plt.suptitle(
    f"Explanation of misclassified sample (pred: {pred_cls +1}, true: {true_cls + 1})"
)
plt.tight_layout()
# Save before showing the figure.
plt.savefig("../output/shap_waterfall.svg")
plt.show()

In [None]:
# Decision plot for the selected misclassified sample, explaining the predicted class.
shap.decision_plot(
    base_value=explainer.expected_value[pred_cls],
    shap_values=shap_values_mis[idx, :, pred_cls],
    features=X_mis.iloc[idx],
    show=False,
)
plt.suptitle(
    f"Explanation of misclassified sample (pred: {pred_cls + 1}, true: {true_cls + 1})"
)
plt.tight_layout()
# Save before showing the figure.
plt.savefig("../output/shap_decision_plot.svg")
plt.show()

In [None]:
# Decision plot for a subsample of the misclassified samples whose true class is `cls`.
cls = 2
every_nth = 50

# Row positions of the class-`cls` samples within `misclassified`. `shap_values_mis`
# was computed on X_mis (same row order), so these positions, not positions within
# the class-filtered frame, are what must index the SHAP array.
cls_positions = np.flatnonzero((misclassified["target"] == cls).to_numpy())
X_mis_cls = X_mis.iloc[cls_positions]
print(len(X_mis_cls))
# Thin out to every `every_nth` sample to keep the plot readable.
ind = cls_positions[::every_nth]

shap.decision_plot(
    explainer_mis.expected_value[cls],
    shap_values_mis[ind, :, cls],
    X_mis.iloc[ind],
    show=False,
)
plt.suptitle(f"Decision plot for subset of grade {cls + 1} misclassified samples")
plt.tight_layout()
plt.savefig(f"../output/shap_decision_plot_grade{cls+1}.svg")
plt.show()

## Analyzing Features

In [None]:
# List the preprocessed feature names available for the per-feature plots below.
X_test_preprocessed.columns

In [None]:
feature = "age"
cls = 0
# Boolean mask over the misclassified samples keeping those with age below 100
# (presumably to drop outliers; TODO confirm).
filter_ = X_mis[feature].lt(100).to_numpy()
shap.plots.scatter(
    explanation_mis[filter_, feature, cls],
    alpha=0.5,
    hist=True,
    show=False,
)
plt.suptitle(f"SHAP values for feature '{feature}' for class {cls + 1}")
plt.tight_layout()
# Save before showing the figure.
plt.savefig(f"../output/shap_scatter_{feature}.svg")
plt.show()

In [None]:
features = ["geo_level_1_id", "geo_level_2_id", "geo_level_3_id"]
cls = 2
# One panel per feature, side by side.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
for ax, feature in zip(axs, features):
    # Every 5th test sample only, to thin out the scatter.
    shap.plots.scatter(explanation[::5, feature, cls], alpha=0.5, ax=ax, show=False)
plt.suptitle(f"SHAP values for damage grade {cls + 1}")
plt.tight_layout()
# Save before showing the figure.
plt.savefig("../output/shap_scatter.svg")
plt.show()