In [None]:
import joblib
import shap
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
OSM_ID = 8269826
MAP_HEX_SIZE = 9
DATASET_TYPE = "inference"  # one of "training", "test" or "inference"
MODEL_COMMENT = "artificial_S5P_scaled_wind_shift"  # String or None
DATA_COMMENT = "artificial_S5P_scaled_wind_shift_Gliwice"  # String or None
SELECTED_YEARS = (2023, 2024)

assert DATASET_TYPE in {"training", "test", "inference"}, f"Unknown DATASET_TYPE: {DATASET_TYPE!r}"

years_str = "_".join(str(year) for year in SELECTED_YEARS)

# An optional comment is appended to the file stem as "_<comment>"; when the
# comment is None/empty the stem is used unchanged.
data_suffix = f"_{DATA_COMMENT}" if DATA_COMMENT else ""
model_suffix = f"_{MODEL_COMMENT}" if MODEL_COMMENT else ""
DATA_FILE = f"../data/NO2_{DATASET_TYPE}_dataset_osm_{OSM_ID}_hex_{MAP_HEX_SIZE}_year_{years_str}{data_suffix}.csv"
ML_MODEL = f"../data/random_forest_NO2_gios{model_suffix}.pkl"

# Model input columns; df is restricted to exactly these columns, in this order.
SELECTED_PARAMETERS = [
    "tree_cover",
    "grassland",
    "population_density",
    "low_vegetation",
    "medium_vegetation",
    "high_vegetation",
    "road",
    "residential_1",
    "residential_2",
    "residential_3",
    "residential_4",
    "non-residential_1",
    "non-residential_2",
    "non-residential_3",
    "non-residential_4",
    "temperature",
    "temperature_trend_3h",
    "temperature_trend_6h",
    "temperature_anomaly",
    "relative_humidity",
    "relative_humidity_trend_3h",
    "relative_humidity_trend_6h",
    "pressure",
    "pressure_trend_3h",
    "pressure_trend_6h",
    "precipitation",
    "wind_speed",
    "traffic_mean_count",
    "day_of_year_sin",
    "day_of_year_cos",
    "working_day",
]
TARGET = "no2_gios"
TREE_EXPLAINER_SAMPLE_SIZE = 2000  # rows used as SHAP background data
DATASET_SAMPLE_SIZE = 50000  # rows whose predictions are explained

# Order in which features are displayed in the SHAP bar chart.
PARAMETERS_ORDER = [
    "residential_4",
    "low_vegetation",
    "temperature_anomaly",
    "non-residential_3",
    "wind_speed",
    "day_of_year_cos",
    "relative_humidity",
    "residential_3",
    "temperature",
    "traffic_mean_count",
    "temperature_trend_3h",
    "precipitation",
    "working_day",
    "day_of_year_sin",
    "temperature_trend_6h",
    "high_vegetation",
    "relative_humidity_trend_6h",
    "pressure",
    "pressure_trend_6h",
    "relative_humidity_trend_3h",
    "tree_cover",
    "pressure_trend_3h",
    "road",
    "medium_vegetation",
    "residential_2",
    "non-residential_2",
    "grassland",
    "non-residential_4",
    "population_density",
    "residential_1",
    "non-residential_1",
]

# PARAMETERS_ORDER must list every selected feature exactly once; otherwise
# features would be silently dropped from (or duplicated in) the plots.
assert sorted(PARAMETERS_ORDER) == sorted(SELECTED_PARAMETERS), (
    "PARAMETERS_ORDER must be a permutation of SELECTED_PARAMETERS"
)

In [None]:
# Render all figure text in Palatino Linotype (sets rcParams["font.family"]).
mpl.rc("font", family="Palatino Linotype")

In [None]:
# Read only the model's feature columns (read_csv's usecols keeps file order,
# so .loc restores SELECTED_PARAMETERS order) and drop rows with any missing
# value. No in-place mutation, so re-running the cell is idempotent.
df = (
    pd.read_csv(DATA_FILE, usecols=SELECTED_PARAMETERS)
    .loc[:, SELECTED_PARAMETERS]
    .dropna()
)
df.head()

In [None]:
# Deserialize the fitted random-forest model. joblib.load unpickles the file,
# which can execute arbitrary code, so only load trusted model files.
rf_model = joblib.load(filename=ML_MODEL)

In [None]:
features = list(df)
importances = rf_model.feature_importances_

# feature_importances_ follows the column order the model was trained on. When
# the model recorded that order (feature_names_in_), require it to match df's
# columns; otherwise the importances would be attached to the wrong names.
model_features = getattr(rf_model, "feature_names_in_", None)
if model_features is not None:
    assert list(model_features) == features, (
        "Column order of df does not match the features the model was trained on"
    )

df_importances = pd.DataFrame({"importance": importances}, index=features)
df_importances.head()

In [None]:
# Rows whose predictions are explained. A fixed random_state makes the SHAP
# results reproducible, and capping n at len(df) prevents DataFrame.sample from
# raising when the dataset has fewer rows than DATASET_SAMPLE_SIZE.
SAMPLE_RANDOM_STATE = 42
X = df.sample(n=min(DATASET_SAMPLE_SIZE, len(df)), random_state=SAMPLE_RANDOM_STATE)

In [None]:
# Background data for the tree explainer. It is seeded for reproducibility and
# capped at len(df) so DataFrame.sample cannot raise on small datasets.
BACKGROUND_RANDOM_STATE = 42
background = df.sample(
    n=min(TREE_EXPLAINER_SAMPLE_SIZE, len(df)),
    random_state=BACKGROUND_RANDOM_STATE,
)
explainer = shap.TreeExplainer(rf_model, background)

# check_additivity=False disables SHAP's check that each row's SHAP values sum
# to the model output minus the expected value.
shapley_values = explainer.shap_values(X, check_additivity=False)

In [None]:
# Per-feature summaries over the explained rows: the signed mean of the SHAP
# values and the mean of their absolute values.
shap_array = np.asarray(shapley_values)
df_importances["mean_shapley"] = shap_array.mean(axis=0)
df_importances["mean_abs_shapley"] = np.abs(shap_array).mean(axis=0)
df_importances.head()

In [None]:
# Integer positions of the PARAMETERS_ORDER features within X's columns;
# Index.get_loc raises KeyError if a listed feature is missing from X.
idx = list(map(X.columns.get_loc, PARAMETERS_ORDER))

In [None]:
# Reorder the SHAP matrix and the feature frame into PARAMETERS_ORDER so that
# column j of both refers to the same feature.
shapley_values_reordered = shapley_values[:, idx]
X_reordered = X.loc[:, PARAMETERS_ORDER]

In [None]:
# Mean of the absolute SHAP values per feature (i.e. mean |SHAP|, not |mean|).
abs_mean_shap_values = np.mean(np.abs(shapley_values_reordered), axis=0)

# The previous fixed y-limit of 5 is kept as a floor (presumably so figures
# from different runs share a scale; confirm). It is raised when a bar would
# otherwise be clipped.
SHAP_BAR_YLIM = 5
y_max = max(SHAP_BAR_YLIM, float(abs_mean_shap_values.max()) * 1.05)

fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
ax.bar(PARAMETERS_ORDER, abs_mean_shap_values)
ax.tick_params(axis="x", labelrotation=90, labelsize=12)
ax.set_ylim(0, y_max)
ax.grid(True, axis="y")
ax.set_ylabel("Mean absolute SHAP value")
fig.tight_layout()

file_suffix = f"_{DATA_COMMENT}" if DATA_COMMENT else ""
file_name = f"shap_mean_abs_hex_{MAP_HEX_SIZE}{file_suffix}.png"

fig.savefig(file_name, bbox_inches="tight")
plt.show()
plt.close(fig)  # free the 300-dpi figure once it has been saved and shown

In [None]:
# `rest_idx` and `remaining_features` were used here without being defined
# anywhere in the notebook, so this cell failed on a fresh kernel. Define them
# explicitly: list in EXCLUDED_FEATURES any features that should be left out of
# this plot. By default, every feature is shown.
EXCLUDED_FEATURES = []  # TODO(review): confirm which features were meant to be excluded
remaining_features = [f for f in PARAMETERS_ORDER if f not in EXCLUDED_FEATURES]
rest_idx = [X.columns.get_loc(f) for f in remaining_features]

shap.summary_plot(
    shap_values=shapley_values[:, rest_idx],
    features=X.iloc[:, rest_idx],
    feature_names=remaining_features,
    max_display=len(rest_idx),
    plot_size=(14, 10),
    show=False,
)

# summary_plot draws on the current figure; grab it to lay out, save and close.
fig = plt.gcf()
fig.tight_layout()

remaining_suffix = f"_{DATA_COMMENT}" if DATA_COMMENT else ""
remaining_file_name = f"shap_remaining_features_hex_{MAP_HEX_SIZE}{remaining_suffix}.png"
fig.savefig(remaining_file_name, bbox_inches="tight")
plt.show()
plt.close(fig)

## Links

[SHAP (SHapley Additive exPlanations)](https://github.com/shap/shap)

[SHAP Values for Random Forest](https://medium.com/biased-algorithms/shap-values-for-random-forest-1150577563c9)

[Explaining model predictions with Shapley values - Random Forest](https://michaelallen1966.github.io/titanic/27_random_forest_shap.html)