In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import eli5
from eli5.sklearn import PermutationImportance
from pdpbox import pdp, get_dataset, info_plots

import shap
from pathlib import Path
from modelling_pipeline import features
from joblib import dump, load
import matplotlib as mpl

plt.rcParams['figure.dpi'] = 150
%matplotlib inline
%config InlineBackend.figure_format='retina'

mpl.rcParams['font.family'] = 'Optima'
mpl.rcParams['text.usetex'] = 'False'

In [None]:
def correlatePredictions(model, val_X, val_y, name):
    """Scatter the model's predictions against the actual values and save the plot.

    Parameters
    ----------
    model : object
        Fitted estimator exposing ``predict``.
    val_X : array-like or DataFrame
        Features handed to ``model.predict``; predictions go on the x axis.
    val_y : array-like
        Actual target values, drawn on the y axis.
    name : str
        Used as the figure title and in the output path
        ``viz/<name>-correlation.png``.

    The figure is written to disk at 150 dpi and then shown. The ``viz``
    directory is not created here and must already exist.
    """
    preds = model.predict(val_X)
    fig, ax = plt.subplots(figsize=(8, 5))
    # x and y must be passed by keyword: recent seaborn releases accept only
    # ``data`` positionally and raise TypeError otherwise. The plot is bound to
    # ``ax`` explicitly instead of relying on the current-axes state.
    # ``cmap`` was dropped because it has no effect without a ``hue`` variable.
    sns.scatterplot(x=preds, y=val_y, ax=ax, s=7, edgecolor='k', lw=1)

    ax.set_xlabel('Predicted pay gap (%)', fontsize=18)
    ax.set_ylabel('Actual pay gap (%)', fontsize=18)
    ax.set_title(name, fontsize=24)
    # labelleft/labelright only apply to the y axis and were already at their
    # defaults, so a single call setting the label size on both axes suffices.
    ax.tick_params(axis='both', which='both', labelsize=14)
    fig.savefig('viz/{}-correlation.png'.format(name), dpi=150)
    plt.show()



def feature_importance(model, X, target_name):
    """Plot the model's built-in feature importances as a horizontal bar chart.

    Parameters
    ----------
    model : object
        Fitted estimator exposing ``feature_importances_``.
    X : pandas.DataFrame
        Frame whose columns label the importances; ``X.columns`` must line up
        with ``model.feature_importances_``.
    target_name : str
        Used in the output path ``viz/<target_name>-fi.png``.

    The figure is written to disk at 150 dpi and then shown. The ``viz``
    directory is not created here and must already exist.
    """
    feature_importances = pd.DataFrame(
        model.feature_importances_,
        index=X.columns,
        columns=['importance']).sort_values('importance', ascending=False)
    # Name the index explicitly so reset_index() always yields an 'index'
    # column, even when X.columns carries its own name.
    plot_data = feature_importances.rename_axis('index').reset_index()

    fig, ax = plt.subplots(figsize=(7, 10))
    # Do not pass ``fontsize`` here: seaborn forwards unknown keyword arguments
    # to the matplotlib bar patches, which have no such property and raise.
    sns.barplot(ax=ax, data=plot_data, x='importance', y='index', orient='h')
    # NOTE(review): the original fontsize=24 is assumed to have been meant for
    # the axis labels; confirm this is the intended target.
    ax.xaxis.label.set_fontsize(24)
    ax.yaxis.label.set_fontsize(24)
    ax.tick_params(axis='both', which='both', labelsize=14)
    fig.savefig('viz/{}-fi.png'.format(target_name), dpi=150)
    plt.show()


def permutation_importance(model, X, y, target_name):
    """Compute permutation importances, save them as a bar chart and return the eli5 table.

    Parameters
    ----------
    model : object
        Fitted estimator wrapped by ``eli5.sklearn.PermutationImportance``.
    X : pandas.DataFrame
        Features used for the permutation runs; column names label the bars.
    y : array-like
        Target values matching ``X``.
    target_name : str
        Used in the output path ``viz/<target_name>-perm-imp.png``.

    Returns
    -------
    The object produced by ``eli5.show_weights``. It renders as an HTML table
    only when it is the last expression of a notebook cell or is passed to
    ``display`` explicitly, so it is returned instead of being discarded.

    The ``viz`` directory is not created here and must already exist.
    """
    perm = PermutationImportance(model, random_state=1).fit(X, y)
    feature_names = X.columns.tolist()

    # eli5 renders HTML and draws nothing in matplotlib, so the importances are
    # plotted explicitly; otherwise savefig would write a blank image.
    ranking = pd.DataFrame({
        'feature': feature_names,
        'weight': perm.feature_importances_,
        'std': perm.feature_importances_std_,
    }).sort_values('weight', ascending=True)  # barh draws the last row on top

    fig, ax = plt.subplots(figsize=(7, 10))
    ax.barh(ranking['feature'], ranking['weight'], xerr=ranking['std'])
    ax.set_xlabel('Permutation importance (mean decrease in score)', fontsize=18)
    ax.tick_params(axis='both', which='both', labelsize=14)
    fig.savefig('viz/{}-perm-imp.png'.format(target_name), dpi=150)
    plt.show()

    return eli5.show_weights(perm, feature_names=feature_names)


def partial_dependence_plot(model, data, feature_names, feature_to_plot, target_name):
    """Draw and save the partial dependence plot of one feature.

    Parameters
    ----------
    model : object
        Fitted estimator passed to ``pdp.pdp_isolate``.
    data : pandas.DataFrame
        Dataset the partial dependence is computed on.
    feature_names : list
        Names of all model features (``model_features`` for pdpbox).
    feature_to_plot : str
        The single feature whose partial dependence is plotted.
    target_name : str
        Used in the output path ``viz/<target_name>-<feature_to_plot>-pdp.png``.

    The figure is written to disk at 150 dpi and then shown. The ``viz``
    directory is not created here and must already exist.
    """
    pdp_score = pdp.pdp_isolate(model=model, dataset=data, model_features=feature_names,
                                feature=feature_to_plot)
    pdp.pdp_plot(pdp_score, feature_to_plot)
    # Use the notebook's own pyplot import throughout rather than reaching it
    # through ``pdp.plt``, which is an incidental attribute of the pdpbox module.
    plt.title('Partial dependence plot of {}'.format(feature_to_plot), fontsize=24)
    plt.tick_params(axis='both', which='both', labelsize=18)
    # Clear the x label with an empty string; passing a list would be
    # stringified and shown as the literal text "[]".
    plt.xlabel('')
    plt.savefig('viz/{}-{}-pdp.png'.format(target_name, feature_to_plot), dpi=150)
    plt.show()


def plot_shap_values(model, X, target_name):
    """Draw and save a SHAP summary plot for a 5% random sample of ``X``.

    Parameters
    ----------
    model : object
        Fitted tree-based model supported by ``shap.TreeExplainer``.
    X : pandas.DataFrame
        Feature frame; a 5% random sample of its rows is explained.
    target_name : str
        Used in the output path ``viz/<target_name>-shap.png``.

    The sample is drawn without a fixed seed, so the plot differs between runs.
    The ``viz`` directory is not created here and must already exist.
    """
    explainer = shap.TreeExplainer(model)

    # Draw the sample once and reuse it: sampling separately for the SHAP
    # values and for the plotted feature values pairs each row's SHAP values
    # with the feature values of a different, unrelated row.
    sample = X.sample(frac=0.05)
    if sample.empty:
        # frac=0.05 rounds to zero rows for very small frames; use everything.
        sample = X

    shap_values = explainer.shap_values(sample)

    # show=False keeps the figure open so savefig captures it; with the default
    # show=True the figure is displayed and closed first and a blank image is
    # written. summary_plot has no ``fontsize`` parameter, so passing one
    # raises TypeError.
    shap.summary_plot(shap_values, sample, show=False)
    # NOTE(review): the original fontsize=14 is assumed to have been meant for
    # the tick labels; confirm this is the intended target.
    plt.gca().tick_params(axis='both', which='both', labelsize=14)
    plt.savefig('viz/{}-shap.png'.format(target_name), dpi=150)
    plt.show()


def explain_model(target_name, X, y):
    """Load the saved best model for ``target_name`` and produce every explanation plot.

    The model is read from ``models/<target_name>-best-model.joblib``. In order,
    this draws: the prediction/actual scatter, partial dependence plots for the
    lower- and top-quartile representation skew features (each on its own 5%
    random sample of ``X``), the permutation importances, the model's built-in
    feature importances and the SHAP summary.
    """
    model_path = 'models/{}-best-model.joblib'.format(target_name)
    fitted = load(model_path)

    correlatePredictions(fitted, X, y, target_name)

    # One partial dependence plot per quartile feature, each on a fresh sample.
    for quartile_feature in ('RepresentationInLowerQuartileSkew',
                             'RepresentationInTopQuartileSkew'):
        subset = X.sample(frac=0.05)
        partial_dependence_plot(fitted, subset, features, quartile_feature, target_name)

    permutation_importance(fitted, X, y, target_name)
    feature_importance(fitted, X, target_name)
    plot_shap_values(fitted, X, target_name)




In [None]:
# Make sure the figure output directory exists before any plot is saved.
output_dir = Path('viz')
output_dir.mkdir(parents=True, exist_ok=True)

# Load the hold-out data and keep only the model's feature columns.
df = pd.read_csv('data/holdout_data.csv')
X = df[features]

# Explain the saved best model for each pay-gap target in turn.
targets = ['DiffMeanHourlyPercent', 'DiffMedianHourlyPercent']
for target in targets:
    y = df[target].values
    explain_model(target, X, y)

