# <span style='color:orange'>Imports</span>

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC

%reload_ext autoreload
%autoreload 2

from data_boxscore.constants import features_minmax, features_no_scaling, features_perc, features_robust, features_standard
from data_boxscore.data import custom_query_df, load_dataframes
from data_boxscore.plots import create_spider_chart
from data_boxscore.xai import shap_query

# <span style='color:blue'>Loading data</span>

In [None]:
# Data source and optional filters. The filters are forwarded to `load_dataframes`
# and later to `shap_query`. NOTE(review): None presumably means "no filter" — confirm.
filename = 'data.csv'
team = league = season = None
# Column names passed to `load_dataframes(excluded_columns=...)`.
excl_cols = []

In [None]:
# Load the game-level frame `gdf` plus the train / validation / test / query splits.
# NOTE(review): the `features_*` keyword names suggest each list is routed to a
# different scaler inside `load_dataframes` — confirm in data_boxscore.data.
gdf, X_train, y_train, X_val, y_val, X_test, y_test, X_query, y_query = load_dataframes(
    filename=filename,
    features_standard=features_standard,
    features_minmax=features_minmax,
    features_robust=features_robust,
    features_perc=features_perc,
    features_no_scaling=features_no_scaling,
    use_ELO=True,
    k_elo=20,
    excluded_columns=excl_cols,
    team=team,
    season=season,
    league=league,
)
print(gdf.shape, X_train.shape, X_val.shape, X_test.shape)
# Mean of the `home_win` column (the base rate, if it is a 0/1 target).
gdf['home_win'].mean()

In [None]:
# Preview the first rows of the loaded frame (head() defaults to 5 rows).
gdf.head()

In [None]:
# Keep only the first `n_test_rows` test rows, presumably so the SHAP cells below run quickly.
# NOTE(review): this overwrites X_test / y_test, so every later cell that uses them
# (prediction, accuracy, both SHAP explainers) sees only these rows.
n_test_rows = 10
X_test = X_test[:n_test_rows]
y_test = y_test[:n_test_rows]
assert len(X_test) == len(y_test), "X_test / y_test length mismatch after truncation"
X_test.shape, y_test.shape

# <span style='color:green'>Models</span>

In [None]:
# RBF-kernel SVC with probability estimates enabled. The local SHAP explainer needs
# `predict_proba`. With probability=True, sklearn fits the probabilities using an internal
# shuffled CV, so fix `random_state` to make them reproducible across runs.
# sklearn documents that `predict` and `predict_proba` can disagree for SVC, so the
# label-level and probability-level SHAP explanations in this notebook may not match.
model = SVC(kernel='rbf', probability=True, C=50, gamma=0.0005, random_state=0)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
# accuracy_score takes (y_true, y_pred). Report the sample count, because X_test may
# have been cut down to a handful of rows before this cell.
print(f"Model accuracy = {accuracy_score(y_test, y_pred)} (n = {len(y_test)} test samples)")

# <span style='color:lightgreen'>XAI</span>

### Local

In [None]:
# creating SV object with all X_train in the background
shap_prob_explainer = shap.Explainer(model.predict_proba, X_train)
# Fitting on X_test
sv_prob = shap_prob_explainer(X_test)

In [None]:
# 0-based position (into X_test) of the row whose prediction is explained.
ind = 9

# Valid positions run from 0 to X_test.shape[0] - 1.
print(f"Index max in test set : {X_test.shape[0] - 1}")
assert 0 <= ind < X_test.shape[0], f"ind={ind} is out of range for {X_test.shape[0]} test rows"
# The last axis of sv_prob indexes predict_proba's columns; 1 selects model.classes_[1]
# (presumably the home-win class — confirm).
shap.plots.waterfall(sv_prob[ind, :, 1])

### Global

In [None]:
# creating SV object with all X_train in the background
shap_pred_explainer = shap.Explainer(model.predict, X_train)
# Fitting on X_test
sv_pred = shap_pred_explainer(X_test)

In [None]:
# Beeswarm of label-level SHAP values over the test rows (top 13 features).
# show=False defers rendering (recent shap versions also return the Axes), so add a
# title and render explicitly instead of leaking the Axes repr as cell output.
shap.plots.beeswarm(sv_pred, max_display=13, show=False)
plt.title("Global SHAP summary (model.predict)")
plt.show()

### Query

In [None]:
# Dimensions of the query feature matrix (rows, columns).
np.shape(X_query)

In [None]:
# SHAP explanations for the query rows. This uses the probability-level explainer and the
# same team / season / league filters passed to `load_dataframes`.
# NOTE(review): presumably one row per queried game — confirm in data_boxscore.xai.
shap_query_df = shap_query(
    df=gdf,
    X_query=X_query,
    shap_explainer=shap_prob_explainer,
    team=team,
    season=season,
    league=league,
)
shap_query_df.head()

In [None]:
# Boxplot of per-feature SHAP values across the query rows. Features whose
# |mean SHAP| is at or below `threshold` are hidden.
threshold = 0.0001

fig, ax = plt.subplots(1, 1, figsize=(10, 6))
# Identifier columns are not SHAP values; every remaining column holds one feature's SHAP value.
info_cols = ['game_id', 'home_team', 'away_team'] if 'game_id' in shap_query_df.columns else ['home_team', 'away_team']
team_expl_cols = shap_query_df.drop(columns=info_cols).columns
shap_wide = shap_query_df[team_expl_cols]

# Reshape wide (row x feature) to long (one row per row/feature pair). Row-major
# flattening matches the (index, column) order produced by MultiIndex.from_product.
long_index = pd.MultiIndex.from_product([shap_wide.index, shap_wide.columns], names=('gameday', 'feature'))
t_df = pd.DataFrame({'SHAP_value': shap_wide.to_numpy().reshape(-1)}, index=long_index).reset_index()

# Aggregate only the SHAP column. Calling .mean() on the whole grouped frame would also
# average 'gameday', which raises in pandas >= 2.0 when that column is non-numeric.
mean_shap = t_df.groupby('feature')['SHAP_value'].mean().rename('Mean_SHAP_value')
feat_mask = (mean_shap.abs() > threshold).rename('Threshold_filter')
t_df = (
    t_df
    .merge(feat_mask, left_on='feature', right_index=True)
    .merge(mean_shap, left_on='feature', right_index=True)
    .sort_values(by='Mean_SHAP_value', ascending=False)
)

meanpointprops = dict(marker='D', markeredgecolor='black', markerfacecolor='firebrick')
sns.boxplot(data=t_df[t_df['Threshold_filter']], x='SHAP_value', y='feature',
            hue='Mean_SHAP_value', palette='flare', legend=False, ax=ax,
            showmeans=True, meanprops=meanpointprops)
# Reference line at SHAP = 0 (no contribution).
ax.axvline(x=0, color='red', linestyle='dashed', linewidth=2)
ax.set_title(f"SHAP values per feature (|mean| > {threshold})")
plt.show()
plt.close(fig)

In [None]:
# Per-feature mean and sample standard deviation (ddof=1) of SHAP values across the
# query rows. This covers every feature, not only those shown in the thresholded boxplot.
spider_plot_df = (
    t_df.groupby('feature')['SHAP_value']
    .agg(['mean', 'std'])
    .rename(columns={'mean': 'SHAP_value'})
)
spider_plot_df

In [None]:
# Radial-axis bounds for the spider chart, snapped to a grid of
# 1 / grid_ticks_per_unit (= 0.02) SHAP units.
# - min_val is rounded down to the grid.
# - max_val is the first grid point strictly above the maximum mean SHAP value.
grid_ticks_per_unit = 50
# Work in integer grid steps so the graduation count is exact, rather than recovering
# it from floats via int(value * 50), which truncates any round-off error.
lower_step = int(np.floor(spider_plot_df['SHAP_value'].min() * grid_ticks_per_unit))
upper_step = int(np.floor(spider_plot_df['SHAP_value'].max() * grid_ticks_per_unit)) + 1
min_val = lower_step / grid_ticks_per_unit
max_val = upper_step / grid_ticks_per_unit
graduation_level = upper_step - lower_step

min_val, max_val, graduation_level


In [None]:
# Sanity-check the spider-chart radial-axis bounds before plotting.
assert min_val < max_val, (min_val, max_val)
assert graduation_level >= 1, graduation_level

In [None]:
# Colour each spider-chart category by feature family: names starting with 'team' are
# blue, all others orange.
axis_label_colors = ['blue' if feature_name.startswith('team') else 'orange'
                     for feature_name in spider_plot_df.index]

# Spider chart of mean SHAP per feature on the radial grid computed earlier.
# NOTE(review): std_devs / highlight_level semantics live in data_boxscore.plots — confirm.
create_spider_chart(
    categories=spider_plot_df.index,
    values=spider_plot_df['SHAP_value'],
    std_devs=spider_plot_df['std'],
    color='green',
    min_value=min_val,
    max_value=max_val,
    graduation_levels=graduation_level,
    highlight_level=0,
    category_colors=axis_label_colors,
);