In [3]:
import numpy as np
import pandas as pd
import pickle
from darts.models import RegressionModel
from darts.explainability.shap_explainer import ShapExplainer
from pprint import pprint

In [4]:
# Name of the target being explained; it also selects the sub-directory
# under weights/ from which the saved artifacts are read.
TARGET_BM = 'Weight'
WEIGHTS_DIR = f'weights/{TARGET_BM}/'

In [5]:
# Conversion factor used to express a number of months as a number of weeks.
WEEKS_PER_MONTH = 4.2
# Horizons for 1..5 months, each rounded up to a whole number of weeks and
# stored as plain Python ints.
horizons = np.ceil(WEEKS_PER_MONTH * np.arange(1, 6)).astype(int).tolist()

In [6]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the object has been read, even if unpickling raises.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# NOTE(review): unpickling can execute arbitrary code; only point WEIGHTS_DIR
# at artifacts that come from a trusted source.
model = RegressionModel.load(WEIGHTS_DIR + "model.pkl")
preprocess_pipeline = _load_pickle(WEIGHTS_DIR + 'preprocessor.pkl')
scaler = _load_pickle(WEIGHTS_DIR + 'scaler.pkl')
target = _load_pickle(WEIGHTS_DIR + 'target.pkl')
past_cov = _load_pickle(WEIGHTS_DIR + 'past_cov.pkl')

# Explainability

In [7]:
# Build the explainer from the loaded model, using the loaded target and
# past-covariate series as its background data.
shap_explainer = ShapExplainer(
    model,
    background_series=target,
    background_past_covariates=past_cov,
)

In [8]:
def get_shap_values(shap_explainer, target, past_cov, horizons):
    """Summarise per-feature explanation values for each requested horizon.

    The explainer is run once for all ``horizons``. For every horizon the last
    row of its explanation frame is taken, features that belong to the target
    itself are removed, and the individual lag columns of each remaining
    feature are collapsed into a single column holding their mean.

    Parameters
    ----------
    shap_explainer
        Object whose ``explain(target, past_cov, horizons=...)`` returns a
        result exposing ``get_feature_values``, ``get_explanation`` and
        ``available_components``.
    target, past_cov
        Passed through unchanged to ``shap_explainer.explain``.
    horizons
        Non-empty sequence of horizons; the first entry is used to discover
        the feature names.

    Returns
    -------
    pandas.DataFrame
        One row per horizon (index) and one column per feature. Columns whose
        name contains ``'lag-'`` are replaced by one column per feature named
        by everything before the final ``'_'``. The column axis is named
        ``'Horizon'``.
    """
    explainability_res = shap_explainer.explain(target, past_cov, horizons=horizons)

    # Features whose name starts with one of the target components are the
    # target's own features; they are excluded from the summary.
    feature_names = explainability_res.get_feature_values(horizons[0]).components.to_list()
    target_prefixes = tuple(explainability_res.available_components)
    target_features = {name for name in feature_names if name.startswith(target_prefixes)}

    # Collect the last explanation row of every horizon and join them in one
    # concat call. This also yields a DataFrame when there is only a single
    # horizon, so no Series special case is needed.
    last_rows = [
        explainability_res.get_explanation(horizon).pd_dataframe().iloc[-1]
        for horizon in horizons
    ]
    importances_df = pd.concat(last_rows, axis=1)
    importances_df.columns = list(horizons)

    # Drop the target features while the frame is still feature-indexed, then
    # flip to horizons-as-rows. Filtering on the index keeps the values numeric
    # instead of degrading them to object dtype via reset_index + transpose.
    keep_rows = ~importances_df.index.isin(target_features)
    importances_df = importances_df.loc[keep_rows].T

    # Group lag columns by their exact prefix (text before the last '_').
    # Exact grouping matters: a startswith() match would fold e.g. the lags of
    # 'hr_max' into the mean of 'hr'. A dict keeps first-seen order, so the
    # output column order is deterministic.
    lag_groups = {}
    for col in importances_df.columns:
        if 'lag-' in col:
            lag_groups.setdefault(col.rsplit('_', 1)[0], []).append(col)

    lag_means = {
        prefix: importances_df[cols].mean(axis=1)
        for prefix, cols in lag_groups.items()
    }
    lag_columns = [col for cols in lag_groups.values() for col in cols]

    importances_df = importances_df.drop(columns=lag_columns)
    for prefix, mean_values in lag_means.items():
        importances_df[prefix] = mean_values

    importances_df.columns.name = 'Horizon'
    return importances_df

In [9]:
%%capture
importances_df = get_shap_values(shap_explainer, target[0], past_cov[0], horizons)

In [10]:
# Print the importances as a nested {feature: {horizon: value}} mapping.
pprint(importances_df.to_dict(orient='dict'))

{'Age_statcov_target_Weight': {5: -0.00012739568994709258,
                               9: -0.0002609015675332967,
                               13: -0.0004384671265787639,
                               17: -0.0006194160851065907,
                               21: -0.0007631538104770167},
 'Gender_statcov_target_Weight': {5: 0.00026210597425254295,
                                  9: 0.00048445266438664386,
                                  13: 0.0008854091982356168,
                                  17: 0.0013320266967638734,
                                  21: 0.0019220073241742093},
 'avg_calories_per_workout_pastcov': {5: -4.896645400709332e-06,
                                      9: -7.637095518209281e-07,
                                      13: 3.2405538409354913e-05,
                                      17: -6.15254245521039e-06,
                                      21: 3.4052826842570186e-05},
 'avg_cardio_workouts_pastcov': {5: 9.608622604632504e-05,
            