<div>
<img src='../../img/WSP_red.png' style='height: 95px; float: left' alt='WSP Logo'/>
<img src='../../img/austroads.png' style='height: 115px; float: right' alt='Client Logo'/>
</div>
<center><h2>AAM6201 Development of Machine-Learning Decision-Support tools for Pavement Asset Management<br>Case Study 1: Project Identification</h2></center>


In [None]:
# magic command to autoreload changes in src
%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import shap

from src.visualization import plot_metric
from src import util

from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch, Rectangle


# Load models and data

In [None]:
from itertools import product
import pickle

from data import DATA_DIR

# Output/input roots used by later cells. `model_dir` stays the root of all trained models.
report_dir = DATA_DIR.parent / 'reports' / 'raw_results'
model_dir = DATA_DIR.parent / 'models' / 'trained'

# Load each jurisdiction's trained models and the label columns they predict into `model_dict`,
# keyed by short jurisdiction name ('MRWA' -> 'WA', 'NZTA' -> 'NZ', 'NSW', 'VIC').
SUFFIX = 'even_split'
juris = ['MRWA', 'NSW', 'NZTA', 'VIC']
prefixes = ['final'] * len(juris)
suffixes = [SUFFIX] * len(juris)

model_dict = {
    juri.replace('TA', '').replace('MR', ''): {
        'models': {
            'XGB': None,
            'LR': None
        },
        'prediction_columns': None
    } for juri in juris
}

for train_name, prefix, suffix in zip(juris, prefixes, suffixes):
    juri_name = train_name.replace('TA', '').replace('MR', '')
    run_tag = f'{train_name.lower()}_{prefix}_{suffix}'
    # Per-jurisdiction directory. Use a separate name so the module-level `model_dir` is not
    # silently overwritten with the last jurisdiction's directory.
    juri_model_dir = model_dir / train_name / f'{run_tag}_dir'
    for model_type in ['XGB', 'LR']:
        if train_name == 'MRWA' and model_type == 'LR':
            # LR is skipped for MRWA (presumably no LR model was trained for it; confirm).
            continue
        # pickle.load on locally produced, trusted artifacts only.
        with open(juri_model_dir / f'train_{model_type}_timehorizon_{run_tag}.pkl', 'rb') as f:
            model_dict[juri_name]['models'][model_type] = pickle.load(f)
    with open(juri_model_dir / f'train_labels_columns_{run_tag}.pkl', 'rb') as f:
        model_dict[juri_name]['prediction_columns'] = pickle.load(f)

In [None]:
from src import util
from data import DATA_DIR

# Flattened train/valid data for MRWA (WA) and NZTA (NZ).
# Label files are read with a two-level column header (header=[0, 1]); later cells index them
# with 2-tuples taken from the models' prediction columns.
train_flattened_mrwa_labels = util.load_data(DATA_DIR / 'processed' / 'MRWA' / 'mrwa_final' / 'train_flattened_labels_mrwa_final.csv', header=[0, 1])
train_flattened_nzta_labels = util.load_data(DATA_DIR / 'processed' / 'NZTA' / 'nzta_final' / 'train_flattened_labels_nzta_final.csv', header=[0, 1])
valid_flattened_mrwa_labels = util.load_data(DATA_DIR / 'processed' / 'MRWA' / 'mrwa_final' / 'valid_flattened_labels_mrwa_final.csv', header=[0, 1])
valid_flattened_nzta_labels = util.load_data(DATA_DIR / 'processed' / 'NZTA' / 'nzta_final' / 'valid_flattened_labels_nzta_final.csv', header=[0, 1])

# Feature tables (the `_no_offset` variant of the flattened data).
train_flattened_mrwa = util.load_data(DATA_DIR / 'processed' / 'MRWA' / 'mrwa_final' / 'train_flattened_data_mrwa_final_no_offset.csv')
train_flattened_nzta = util.load_data(DATA_DIR / 'processed' / 'NZTA' / 'nzta_final' / 'train_flattened_data_nzta_final_no_offset.csv')
valid_flattened_mrwa = util.load_data(DATA_DIR / 'processed' / 'MRWA' / 'mrwa_final' / 'valid_flattened_data_mrwa_final_no_offset.csv')
valid_flattened_nzta = util.load_data(DATA_DIR / 'processed' / 'NZTA' / 'nzta_final' / 'valid_flattened_data_nzta_final_no_offset.csv')

In [None]:
# Train/valid features and two-level-header labels for NSW and VIC (stored under a different
# directory layout from MRWA/NZTA).
# NOTE(review): train labels come from `labels_all.csv` while valid labels come from
# `valid_labels_all.csv`; confirm the train label file name is intended.
train_flattened_nsw = util.load_data(DATA_DIR / 'processed' / 'NSW' / 'final' / 'train_all.csv')
train_flattened_nsw_labels = util.load_data(DATA_DIR / 'processed' / 'NSW' / 'final' / 'labels_all.csv', header=[0, 1])
train_flattened_vic = util.load_data(DATA_DIR / 'processed' / 'VIC' / 'final' / 'train_all.csv')
train_flattened_vic_labels = util.load_data(DATA_DIR / 'processed' / 'VIC' / 'final' / 'labels_all.csv', header=[0, 1])

valid_flattened_nsw = util.load_data(DATA_DIR / 'processed' / 'NSW' / 'final' / 'valid_all.csv')
valid_flattened_nsw_labels = util.load_data(DATA_DIR / 'processed' / 'NSW' / 'final' / 'valid_labels_all.csv', header=[0, 1])
valid_flattened_vic = util.load_data(DATA_DIR / 'processed' / 'VIC' / 'final' / 'valid_all.csv')
valid_flattened_vic_labels = util.load_data(DATA_DIR / 'processed' / 'VIC' / 'final' / 'valid_labels_all.csv', header=[0, 1])

In [None]:
# Long treatment-time label -> short axis/title label.
year_map_dict = dict(zip(
    [
        'Treatment within 1 year',
        'Treatment between 1 to 3 years',
        'Treatment between 3 to 5 years',
        'Treatment between 5 to 10 years',
    ],
    ['Year 1', 'Year 2 - 3', 'Year 4 - 5', 'Year 6 - 10'],
))

# Long treatment-time label -> mid-point of its window in years (usable as a sort key).
year_order_dict = dict(zip(year_map_dict, [0.5, 2, 4, 7.5]))

# Treatment type -> display order.
treatment_type_order = {
    treatment: rank
    for rank, treatment in enumerate([
        'Resurfacing_SS',
        'Resurfacing_AC',
        'Major Patching',
        'Rehabilitation',
        'Retexturing',
        'Regulation',
    ])
}

In [None]:
# Figures and the SHAP cache for this run live under reports/figures/feature_inspection/<SUFFIX>;
# the directory (and any missing parents) is created on first run.
save_fig_dir = report_dir.parent.joinpath('figures', 'feature_inspection', SUFFIX)
if not save_fig_dir.exists():
    save_fig_dir.mkdir(parents=True)

# SHAP Explanations

In [None]:
# Load SHAP's JavaScript into the notebook so SHAP's interactive (JS-based) plots can render.
shap.initjs()

In [None]:
from typing import List, Tuple
from scipy.stats import spearmanr

def get_magnitude_direction(shap_values_lst: List[float], feature_values_lst: List[float]) -> Tuple[float, float]:
    """Summarise one feature's SHAP values as a signed magnitude and a correlation strength.

    Args:
        shap_values_lst: SHAP values of one feature, one per sample.
        feature_values_lst: That feature's values for the same samples (a pandas Series or any
            1-D array-like of the same length).

    Returns:
        ``(signed_magnitude, corr_magnitude)``:
        - ``signed_magnitude`` is ``mean(|SHAP|)``, multiplied by +1 if the Spearman correlation
          between SHAP values and feature values is positive, else -1 (a zero or NaN correlation
          counts as -1).
        - ``corr_magnitude`` is ``|rho|`` when the correlation is significant (p < 0.05), else 0.
        If the feature never contributes (``max |SHAP| < 1e-5``), returns ``(0, 0)``.

    Raises:
        Any exception from the correlation computation, after printing diagnostics. NumPy
        floating-point errors are promoted to exceptions via ``np.errstate(all='raise')``.
    """
    # The feature has (numerically) zero contribution.
    if np.max(np.abs(shap_values_lst)) < 1e-5:
        return 0, 0

    try:
        with np.errstate(all='raise'):
            corr_coeff, pval = spearmanr(shap_values_lst, feature_values_lst)
            direction = 1 if corr_coeff > 0 else -1
            return_val = np.mean(np.abs(shap_values_lst)) * direction
            # Only report a correlation strength when it is statistically significant.
            corr_coeff_magnitude = np.abs(corr_coeff) if pval < 0.05 else 0
    except Exception:
        # The input may be a plain list/array without `.name`. Reading it unguarded here would
        # raise AttributeError and hide the real error.
        print('Feature name: ', getattr(feature_values_lst, 'name', '<unnamed>'))
        print('Sample of feature values: ', np.random.choice(feature_values_lst, size=10))
        print('Shap values max abs: ', np.max(np.abs(shap_values_lst)))
        print('Sample of shap values: ', np.random.choice(shap_values_lst, size=10))
        raise

    return return_val, corr_coeff_magnitude

def normalize_uniform(arr: np.ndarray):
    """Min-max scale ``arr`` to [0, 1].

    A constant input (max == min) maps to all zeros rather than NaN from a 0/0 division.
    Array-likes that support arithmetic (e.g. pandas Series) keep their type.
    """
    lowest = np.min(arr)
    span = np.max(arr) - lowest
    if span == 0:
        # Multiply instead of allocating so shape and container type are preserved.
        return (arr - lowest) * 0.0
    return (arr - lowest) / span

## Compute SHAP values

In [None]:
from tqdm.notebook import tqdm
from matplotlib.colors import LinearSegmentedColormap 

# Compute per-model SHAP values for each (treatment time, treatment type) output of every
# jurisdiction's XGB models and cache them in `save_fig_dir / save_name`. Entries already in the
# cache are reused, so an interrupted run resumes where it stopped.
# Cache layout: raw_shap_result[juri][prediction column][estimator index][model index] -> SHAP values.
juris = ['NZ', 'NSW', 'WA']
# Aligned index-by-index with `juris`; zip() would silently drop any extra entries.
datasets = [train_flattened_nzta, train_flattened_nsw, train_flattened_mrwa]
labels = [train_flattened_nzta_labels, train_flattened_nsw_labels, train_flattened_mrwa_labels]
# Jurisdictions whose cached SHAP values are recomputed even if present.
force_redo = []

# A cache name containing 'pos_background' uses only rows whose label is 1 as the SHAP background.
save_name = 'raw_shap_results_pos_background.pkl'
if (save_fig_dir / save_name).exists():
    with open(save_fig_dir / save_name, 'rb') as f:
        raw_shap_result = pickle.load(f)  # trusted, locally produced cache
else:
    raw_shap_result = {juri: {} for juri in juris}

row_dict = {t: i for i, t in enumerate(['Treatment within 1 year', 'Treatment between 1 to 3 years', 'Treatment between 3 to 5 years', 'Treatment between 5 to 10 years'])}
col_dict = {t: i for i, t in enumerate(['Resurfacing_SS', 'Resurfacing_AC', 'Major Patching', 'Rehabilitation'])}

try:
    for juri, eval_features, eval_labels in tqdm(zip(juris, datasets, labels), desc='juri', total=len(juris)):
        juri_cache = raw_shap_result.setdefault(juri, {})
        prediction_columns = model_dict[juri]['prediction_columns']
        for col in tqdm(prediction_columns, desc='Prediction columns'):
            if (col[1] not in col_dict) or (col[0] not in row_dict):
                continue
            col_cache = juri_cache.setdefault(col, {})
            # Position of the per-output estimator that predicts this column.
            estimator_idx = np.argwhere(prediction_columns == col).flatten()
            if estimator_idx.size != 1:
                continue
            estimator_cache = col_cache.setdefault(estimator_idx[0], {})
            for model_idx, model in enumerate(model_dict[juri]['models']['XGB']):
                if estimator_cache.get(model_idx) is not None and juri not in force_redo:
                    continue  # already cached
                background = (
                    eval_features[eval_labels[col] == 1] if 'pos_background' in save_name else
                    eval_features
                )
                explainer = shap.TreeExplainer(
                    model.estimators_[estimator_idx[0]],
                    data=background,
                    model_output='probability',
                    feature_perturbation='interventional',
                )
                # Each feature's contribution to the predicted positive-class probability,
                # relative to the background data.
                estimator_cache[model_idx] = explainer.shap_values(eval_features)
            # Checkpoint after every prediction column.
            with open(save_fig_dir / save_name, 'wb') as f:
                pickle.dump(raw_shap_result, f)
finally:
    # Persist whatever was computed, on success, error, or KeyboardInterrupt. Any exception
    # still propagates after the write.
    with open(save_fig_dir / save_name, 'wb') as f:
        pickle.dump(raw_shap_result, f)

## Load computed SHAP values

In [None]:
# Reload the cached SHAP values: the positive-label background variant, or the all-rows variant.
use_pos_background = True
with open(
    save_fig_dir / f"raw_shap_results{'_pos_background' if use_pos_background else '_all'}.pkl",
    'rb',
) as shap_cache_file:
    raw_shap_result = pickle.load(shap_cache_file)

## Beeswarm

In [None]:
from tqdm.notebook import tqdm
from src.visualization.shap_beeswarm import shap_summary
from matplotlib.ticker import FixedLocator 
from matplotlib.colors import LinearSegmentedColormap 

# One 4x4 grid of beeswarm plots per jurisdiction: rows = treatment time, columns = treatment type.
# Each panel shows the SHAP values (averaged over the ensemble of XGB models) for every feature,
# coloured by feature value. SHAP values are read from the cache loaded above.
juris = ['NZ', 'NSW', 'WA']
datasets = [train_flattened_nzta, train_flattened_nsw, train_flattened_mrwa] 
labels = [train_flattened_nzta_labels, train_flattened_nsw_labels, train_flattened_mrwa_labels]
colors = ['red', 'orange','blue']

force_redo = []

for juri, c, eval_features, eval_labels in tqdm(zip(juris, colors, datasets, labels), desc='juri', total=len(juris)):
    # Grid position lookups: treatment time -> row, treatment type -> column.
    row_dict = {t: i for i, t in enumerate(['Treatment within 1 year', 'Treatment between 1 to 3 years', 'Treatment between 3 to 5 years', 'Treatment between 5 to 10 years'])}
    col_dict = {t: i for i, t in enumerate(['Resurfacing_SS', 'Resurfacing_AC', 'Major Patching', 'Rehabilitation'])}

    fig = plt.figure(figsize=(36, 36))

    for col_idx, col in enumerate(tqdm(model_dict[juri]['prediction_columns'], desc='Prediction columns')):
        # Skip outputs that have no slot in the grid.
        if (col[1] not in col_dict) or (col[0] not in row_dict):
            continue
        ax = plt.subplot(4, 4, row_dict[col[0]] * 4 + col_dict[col[1]] + 1)
        plt.sca(ax)
        # find model corresponding to prediction column
        estimator_idx = np.argwhere(model_dict[juri]['prediction_columns'] == col).flatten()
        if estimator_idx.size != 1:
            continue
        # calculate shap values for each model
        shap_values_lst = []
        for model_idx, model in enumerate(model_dict[juri]['models']['XGB']):
            shap_values = raw_shap_result[juri][col][estimator_idx[0]][model_idx]
            shap_values_lst.append(shap_values)

        # get mean shap values for each feature
        shap_values_arr = np.array(shap_values_lst).mean(axis=0) # shape: (len(eval_features), len(features))

        # Feature names are cleaned of the '_df0' / '|idx=0' suffixes for display. The
        # comprehension's `col` is local to the comprehension, so the outer `col` used in the
        # title below is unaffected.
        shap_summary(
            shap_values=shap_values_arr,
            features=eval_features.rename(columns={
                col: col.replace('_df0', '').replace('|idx=0', '') for col in eval_features.columns
            }) if juri != 'VIC' else eval_features,
            sort=False,
            show=False,
            color_bar_label='Feature Value',
        )

        plt.title(
            f'{year_map_dict[col[0]]} - {col[1]}',
            fontsize=18
        )

    inner_dir = save_fig_dir / ('beeswarm_joined' + ('_pos_background' if use_pos_background else ''))
    if not inner_dir.exists(): inner_dir.mkdir()
    plt.suptitle(
        f'{juri}\nDistribution of SHAP values colored by normalized feature values'\
        + ('' if SUFFIX != 'balanced_sampled' else f'\nSampling performed to correct for class imbalance'),
        fontsize=25
    )
    
    # Leave the top 2% of the figure for the suptitle, then save and free the figure.
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.savefig(inner_dir / (f'shap_{juri}_beeswarm' + ('_pos_background' if use_pos_background else '') + '.jpg'))
    plt.close()

## Spearman SHAP correlation

In [None]:
from tqdm.notebook import tqdm
from matplotlib.ticker import FixedLocator
from matplotlib.colors import LinearSegmentedColormap 

# One 4x4 grid of bar charts per jurisdiction (rows = treatment time, columns = treatment type).
# Each bar is the significant Spearman |rho| between a feature's ensemble-mean SHAP values and its
# values (0 when p >= 0.05), as returned by get_magnitude_direction.
juris = ['NZ', 'NSW', 'WA']
datasets = [train_flattened_nzta, train_flattened_nsw, train_flattened_mrwa] 
labels = [train_flattened_nzta_labels, train_flattened_nsw_labels, train_flattened_mrwa_labels]
colors = ['red', 'orange', 'blue']
forced_redo = []

for juri, c, eval_features, eval_labels in tqdm(zip(juris, colors, datasets, labels), desc='juri', total=len(juris)):
    row_dict = {t: i for i, t in enumerate(['Treatment within 1 year', 'Treatment between 1 to 3 years', 'Treatment between 3 to 5 years', 'Treatment between 5 to 10 years'])}
    col_dict = {t: i for i, t in enumerate(['Resurfacing_SS', 'Resurfacing_AC', 'Major Patching', 'Rehabilitation'])}

    fig = plt.figure(figsize=(36, 36))
    for col_idx, col in enumerate(tqdm(model_dict[juri]['prediction_columns'], desc='Prediction columns')):
        if (col[1] not in col_dict) or (col[0] not in row_dict):
            continue
        ax = plt.subplot(4, 4, row_dict[col[0]] * 4 + col_dict[col[1]] + 1)
        col_name = f"{col[0].replace('Treatment ', '')} - {col[1].replace('Resurfacing_', '')}" 
        plt.sca(ax)
        # find model corresponding to prediction column
        estimator_idx = np.argwhere(model_dict[juri]['prediction_columns'] == col).flatten()
        if estimator_idx.size != 1:
            continue
        # calculate shap values for each model
        shap_values_lst = []
        for model_idx, model in enumerate(model_dict[juri]['models']['XGB']):
            shap_values = raw_shap_result[juri][col][estimator_idx[0]][model_idx]
            shap_values_lst.append(shap_values)
        # get mean shap values for each feature
        shap_values_arr = np.array(shap_values_lst).mean(axis=0) # shape: (len(eval_features), len(features))
        # Per-feature (signed mean |SHAP|, significant |rho|) pairs, unzipped into two tuples.
        shap_values_all, corr_coef_mag_all = zip(
            *[get_magnitude_direction(shap_values_arr[:, i], eval_features.iloc[:, i]) for i in range(eval_features.shape[1])]
        )
        # make linear colormap from white - juri color
        cmap = LinearSegmentedColormap.from_list('', colors=['white', c])
        # plot bar plot
        ax.bar(
            x=np.arange(len(eval_features.columns)), 
            height=corr_coef_mag_all,
            # Full jurisdiction colour above the 0.5 threshold, half-tone otherwise. The
            # comprehension's `c` is local to it and does not clobber the outer colour `c`.
            color=cmap([1 if c > 0.5 else 0.5 for c in corr_coef_mag_all])
        )
        ax.grid()
        ax.set_ylabel('|Coefficient Magnitude|')
        ax.set_title(f'{year_map_dict[col[0]]} - {col[1]}')
        ax.xaxis.set_major_locator(FixedLocator(np.arange(len(eval_features.columns))))
        ax.xaxis.set_tick_params(direction='out')
        ax.xaxis.set_ticks_position('bottom')
        ax.set_xticklabels(
            [col.replace("_df0", "").replace("|idx=0", "") for col in eval_features.columns] if juri != 'VIC' else eval_features.columns, 
            rotation=45, ha='right'
        )
        # 0.5 is also the "Strong correlation" cut-off used by get_flags_from_relationships
        # (which uses >= 0.5, while the bar colouring above uses > 0.5).
        threshold_line = ax.axhline(y=0.5, label='Threshold', c='r', linestyle='--')
        ax.legend(handles=[threshold_line])

    fig.suptitle(
        f'{juri} - Spearman correlation magnitude between\nSHAP values and feature values'\
        + ('' if SUFFIX != 'balanced_sampled' else f'\nSampling performed to correct for class imbalance'),
        fontsize=25,
    )

    fig.tight_layout(rect=[0, 0, 1, 0.98])
    inner_dir = save_fig_dir / 'spearman_only'
    if not inner_dir.exists(): inner_dir.mkdir()
    plt.savefig(inner_dir / f'corr_{juri}_new.jpg')
    plt.close()

# SHAP effect direction check against known relationships

### SHAP agreement with known relationships

In [None]:
def get_flags_from_relationships(known_r: float, shap_r: float, corr_r: float, data_r: float):
    """Label how one modelled feature effect relates to the SME expectation and to the data.

    Args:
        known_r: SME-expected effect; only its sign matters, and 0 means "no expectation".
        shap_r: Signed SHAP effect of the feature (its sign is the model's direction).
        corr_r: Spearman |rho| between SHAP values and feature values.
        data_r: Direction observed in the data; 0 means "no relationship".

    Returns:
        ``(consistency_flag, correlation_flag, data_flag)`` strings.
    """
    model_sign = np.sign(shap_r)

    def _compare(reference, agree, disagree, absent):
        # A zero reference means there is nothing to compare the model's direction against.
        if reference == 0:
            return absent
        return agree if np.sign(reference) == model_sign else disagree

    consistency_flag = _compare(known_r, 'Consistent with SME', 'Inconsistent with SME', 'No expected relationship')
    strength = 'Strong ' if corr_r >= 0.5 else 'Weak '
    correlation_flag = strength + 'correlation'
    data_flag = _compare(data_r, 'Consistent with data', 'Inconsistent with data', 'No expected relationship from data')
    return consistency_flag, correlation_flag, data_flag

from tqdm import tqdm
known_relationships = pd.read_csv(DATA_DIR.parent / 'references' / 'known_relationships.csv')

# For each (jurisdiction, treatment time, treatment type, feature) that has an SME expectation in
# known_relationships.csv, compare the model's SHAP direction with the SME direction and with the
# direction observed in the data (Spearman correlation between feature and label). Each comparison
# is appended to `result` as one tuple.
juris = ['NZ', 'NSW', 'WA']
datasets = [train_flattened_nzta, train_flattened_nsw, train_flattened_mrwa] 
labels = [train_flattened_nzta_labels, train_flattened_nsw_labels, train_flattened_mrwa_labels]
colors = ['red', 'orange','blue']
result = []
total = 0

time_set = {'Treatment between 3 to 5 years', 'Treatment between 1 to 3 years', 'Treatment within 1 year', 'Treatment between 5 to 10 years'}
treat_set = {'Resurfacing_SS', 'Resurfacing_AC', 'Rehabilitation'}

for juri, c, eval_features, eval_labels in tqdm(zip(juris, colors, datasets, labels), desc='juri', total=len(juris)):
    for col_idx, col in enumerate(tqdm(model_dict[juri]['prediction_columns'], desc='Prediction columns')):
        # find model corresponding to prediction column
        estimator_idx = np.argwhere(model_dict[juri]['prediction_columns'] == col).flatten()
        if (estimator_idx.size != 1) or (col[0] not in time_set) or (col[1] not in treat_set):
            continue
        shap_values_lst = []
        for model_idx, model in enumerate(model_dict[juri]['models']['XGB']):
            shap_values = raw_shap_result[juri][col][estimator_idx[0]][model_idx]
            shap_values_lst.append(shap_values)

        # get mean shap values for each feature
        shap_values_arr = np.array(shap_values_lst).mean(axis=0) # shape: (len(eval_features), len(features))
        # shap_values_all holds the SIGNED mean |SHAP| (sign = model direction).
        shap_values_all, corr_coef_mag_all = zip(
            *[get_magnitude_direction(shap_values_arr[:, i], eval_features.iloc[:, i]) for i in range(eval_features.shape[1])]
        )

        for i, feature in enumerate(eval_features):
            # A known-relationship category applies to every feature whose lower-cased name contains it.
            known_rs = known_relationships.loc[
                known_relationships['category'].apply(lambda f: f in feature.lower()),
                col[1].lower()
            ]
            if len(known_rs) > 0:
                assert known_rs.nunique() == 1
                known_r = known_rs.iloc[0]
                shap_r = shap_values_all[i]
                corr_r = corr_coef_mag_all[i]

                data_r, p_value = spearmanr(eval_features.iloc[:, i], eval_labels[col])
                # spearmanr returns NaN when either input is constant; int(NaN) would raise, so
                # treat an undefined correlation as "no relationship in the data".
                if np.isnan(data_r) or np.isnan(p_value):
                    data_r = 0
                else:
                    data_r = int(np.sign(data_r) * (p_value < 0.05))

                con_flag, corr_flag, data_flag = get_flags_from_relationships(known_r=known_r, shap_r=shap_r, corr_r=corr_r, data_r=data_r)
                result.append((juri, col[0], col[1], feature, shap_r, corr_r, con_flag, corr_flag, data_flag, known_r, int(np.sign(shap_r)), data_r))

In [None]:
# One row per (jurisdiction, time, treatment, feature) comparison collected above.
# NOTE: 'Mean Abs SHAP' holds the signed value returned by get_magnitude_direction.
result_df = pd.DataFrame.from_records(
    result,
    columns=[
        'juri', 'Treatment Time', 'Treatment Type', 'Feature',
        'Mean Abs SHAP', 'Correlation Magnitude',
        'Consistency Flag', 'Correlation Flag', 'Consistency Data Flag',
        'Known effect', 'Model effect', 'Data effect',
    ],
)

In [None]:
# Among strongly correlated effects, cross-check disagreements with the SME against the data:
# - wrong_explained_simple: the model is inconsistent with the SME, and the data's (non-zero)
#   direction also differs from the SME's (presumably the model follows the data, assuming
#   effects are coded -1/0/+1; confirm).
# - correct_unexplained_simple: the model is consistent with the SME, but the data's (non-zero)
#   direction differs from the SME's.
wrong_explained_simple = result_df[
    ((result_df['Consistency Flag'] == 'Inconsistent with SME') & (result_df['Correlation Flag'] == 'Strong correlation')) &\
    ((result_df['Known effect'] != result_df['Data effect']) & (result_df['Data effect'] != 0))
]
correct_unexplained_simple = result_df[
    ((result_df['Consistency Flag'] == 'Consistent with SME') & (result_df['Correlation Flag'] == 'Strong correlation')) &\
    ((result_df['Known effect'] != result_df['Data effect']) & (result_df['Data effect'] != 0))
]

print("Out of strongly correlated effects...")
print("Number of relationships which is inconsistent with SME: {}".format(((result_df['Consistency Flag'] == 'Inconsistent with SME') & (result_df['Correlation Flag'] == 'Strong correlation')).sum()))
print("Number of relationship which is inconsistent with SME but consistent with data: {}".format(len(wrong_explained_simple)))
print("Number of relationship which is consistent with SME but inconsistent with data: {}".format(len(correct_unexplained_simple)))

### Plot heatmap for each jurisdiction–year pair

In [None]:
# colors
aliceblue = np.array([0.9411764705882353, 0.9725490196078431, 1.0, 1.0])
red = np.array([167 / 255, 35 / 255, 38 / 255, 1])
pale_red = np.array([167 / 255, 35 / 255, 38 / 255, 0.1])
green = np.array([125 / 255, 148 / 255, 52 / 255, 1])
pale_green = np.array([125 / 255, 148 / 255, 52 / 255, 0.1])

# collector variables
f_to_i = {
    'Weak correlation and Inconsistent with SME': 0,
    'Strong correlation and Inconsistent with SME': 1,
    'Weak correlation and Consistent with SME': 2,
    'Strong correlation and Consistent with SME': 3,
    'Strong correlation and No expected relationship': 4,
    'Weak correlation and No expected relationship': 5,
}

c_array = np.array([pale_red, red, pale_green, green, aliceblue, aliceblue])


def plot_heatmap_effect_compar(df: pd.DataFrame, ax: plt.Axes):
    """Draw a feature x treatment-type heatmap comparing modelled SHAP effects with SME expectations.

    Cell colour encodes the combined correlation/consistency flag (index via ``f_to_i``, colour via
    ``c_array``). Cell text is the SME-expected direction (black ▲/▼). Where the SME has no
    expectation but the SHAP/feature correlation is strong, the model's direction is drawn in red
    instead. A cell is hatched when its effect is strong, the SME expectation is non-zero, and the
    data's (non-zero) direction differs from that expectation.

    Args:
        df: Rows of ``result_df`` for one jurisdiction and one treatment time, with at most one row
            per (Feature, Treatment Type) pair (``DataFrame.pivot`` raises otherwise). Not modified.
        ax: Axes to draw on. All drawing targets this axes, not pyplot's current axes.
    """
    df = df.copy()  # helper columns below must not leak into the caller's frame
    juri = df['juri'].iloc[0]
    year_type = df['Treatment Time'].iloc[0]
    df['Flag'] = df['Correlation Flag'] + ' and ' + df['Consistency Flag']
    # Map flags to colour indices before pivoting, so the pivot is numeric without relying on
    # DataFrame.replace's (deprecated) object -> int downcasting.
    df['Flag index'] = df['Flag'].map(f_to_i)
    df['Special consideration'] = (
        (df['Known effect'] != df['Data effect']) & (df['Data effect'] != 0)
        & (df['Known effect'] != 0) & (df['Correlation Flag'] == 'Strong correlation')
    )

    def _pivot(values: str) -> pd.DataFrame:
        # rows: features, columns: treatment types (pivot sorts both, so all pivots align)
        return df.pivot(index='Feature', columns='Treatment Type', values=values)

    strong_no_expectation = f_to_i['Strong correlation and No expected relationship']
    weak_no_expectation = f_to_i['Weak correlation and No expected relationship']
    # Missing (feature, treatment type) pairs or unknown flags are drawn as "no expected relationship".
    flag_mat = _pivot('Flag index').fillna(weak_no_expectation).astype(int)
    flag_idx = flag_mat.to_numpy()
    known_effect = _pivot('Known effect').to_numpy(dtype=float)
    model_effect = _pivot('Model effect').to_numpy(dtype=float)
    special = _pivot('Special consideration').eq(True).to_numpy()

    # SME-expected direction glyphs (drawn in black).
    known_glyphs = np.where(known_effect > 0, '▲', np.where(known_effect < 0, '▼', ''))
    # Model direction glyphs (drawn in red), only for strong effects with no SME expectation.
    show_model = flag_idx == strong_no_expectation
    model_glyphs = np.where(show_model & (model_effect > 0), '▲',
                            np.where(show_model & (model_effect < 0), '▼', ''))

    discrete_cmap = ListedColormap([pale_red, red, pale_green, green, 'aliceblue', 'aliceblue'], name='discrete_cmap') # color correspond to flag index

    # plot
    ax.imshow(c_array[flag_idx], aspect='auto')
    n_rows, n_cols = flag_idx.shape
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))

    # Labels for major ticks
    ax.set_xticklabels(list(flag_mat.columns))
    ax.set_yticklabels(list(flag_mat.index))

    # Minor ticks on cell borders, drawn with zero length; they only anchor the white grid lines.
    ax.set_xticks(np.arange(-.5, n_cols, 1), minor=True)
    ax.set_yticks(np.arange(-.5, n_rows, 1), minor=True)
    ax.tick_params(axis='both', which='minor', length=0)
    ax.grid(which='minor', color='white', linestyle='-', linewidth=1)
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

    ax.set_title(f'{juri} | {year_type}')
    ax.set_xlabel(None)
    ax.set_ylabel(None)

    # Two legends: the icon legend replaces the axes' legend, so the colour legend is re-added
    # as an artist afterwards.
    color_legend = ax.legend(
        title='Color Legend',
        handles=[Patch(color=discrete_cmap(i)) for i in range(5)],
        labels=['Weak correlation and Inconsistent with SME',
            'Strong correlation and Inconsistent with SME',
            'Weak correlation and Consistent with SME',
            'Strong correlation and Consistent with SME',
            'No expected relationship',
        ],
        bbox_to_anchor=(0.5, -0.1),
    )
    # Empty scatters exist only to provide legend handles.
    up_arrow = ax.scatter([], [], c='black', marker=u'$\u25b2$', s=150, label='SME expects Positive correlation')
    low_arrow = ax.scatter([], [], c='black', marker=u'$\u25bc$', s=150, label='SME expects Negative correlation')
    up_arrow_model = ax.scatter([], [], c='red', marker=u'$\u25b2$', s=150, label='Model expects Positive correlation')
    low_arrow_model = ax.scatter([], [], c='red', marker=u'$\u25bc$', s=150, label='Model expects Negative correlation')
    cross = Patch(facecolor='white', fill=False, hatch='xxxx', edgecolor='black', label='SME inconsistent with data')
    ax.legend(handles=[up_arrow, low_arrow, up_arrow_model, low_arrow_model, cross], title='Icon Legend', bbox_to_anchor=(1, -0.1))
    ax.add_artist(color_legend)

    # Hatch the cells flagged for special consideration.
    for row, column in zip(*np.nonzero(special)):
        ax.add_patch(Rectangle((column - 0.5, row - 0.5), 1, 1, fill=False, hatch='x', edgecolor='black'))

    # Red model glyph where one applies, otherwise the black SME glyph (possibly empty).
    for (row, column), known_glyph in np.ndenumerate(known_glyphs):
        model_glyph = model_glyphs[row, column]
        if model_glyph:
            ax.text(column, row, model_glyph, color='red', ha='center', va='center')
        else:
            ax.text(column, row, known_glyph, color='black', ha='center', va='center')

In [None]:
# Combined figure: a 4x3 grid of effect-comparison heatmaps, with rows = treatment time and
# columns = jurisdiction.
fig = plt.figure(figsize=(24, 38))
juris = ['NZ', 'WA', 'NSW']

for i_juri, juri in enumerate(juris):
    # process result for each juri
    # Short time labels and cleaned feature names for display.
    juri_result = result_df[result_df['juri'] == juri].copy()
    juri_result.loc[:, 'Treatment Time'] = juri_result['Treatment Time'].replace(year_map_dict)
    juri_result.loc[:, 'Feature'] = juri_result['Feature'].apply(lambda x: x.replace('|idx=0', '').replace('_df0', ''))

    for j_year_type, year_type in enumerate(['Year 1', 'Year 2 - 3', 'Year 4 - 5', 'Year 6 - 10']):
        df = juri_result[juri_result['Treatment Time'] == year_type].copy()
        if len(df) == 0:
            continue

        ax = plt.subplot(4, 3, 1 + (i_juri + j_year_type * 3))
        plot_heatmap_effect_compar(df, ax)

plt.suptitle(
    'Comparison between modelled univariate feature effects and expert (SME) expectations'\
   + ('' if SUFFIX != 'balanced_sampled' else f'\nSampling performed to correct for class imbalance'),
    fontsize=16
)
# Reserve the bottom 5% for the legends and the top 2% for the suptitle.
plt.tight_layout(rect=[0, 0.05, 1, 0.98], h_pad=7)

inner_dir = save_fig_dir / 'heatmap_feature_effects'
if not inner_dir.exists():
    inner_dir.mkdir()
plt.savefig(inner_dir / f'for_each_time{"_pos_background" if use_pos_background else "_all_background"}_new.jpg')
plt.show()

In [None]:
# get plot for each jurisdiction independently
juris = ['NZ', 'WA', 'NSW']

for i_juri, juri in enumerate(juris):
    # process result for each juri
    juri_result = result_df[result_df['juri'] == juri].copy()
    juri_result.loc[:, 'Treatment Time'] = juri_result['Treatment Time'].replace(year_map_dict)
    juri_result.loc[:, 'Feature'] = juri_result['Feature'].apply(lambda x: x.replace('|idx=0', '').replace('_df0', ''))
    fig = plt.figure(figsize=(16, 16))

    for j_year_type, year_type in enumerate(['Year 1', 'Year 2 - 3', 'Year 4 - 5', 'Year 6 - 10']):
        df = juri_result[juri_result['Treatment Time'] == year_type].copy()
        if len(df) == 0:
            continue

        ax = plt.subplot(2, 2, 1 + j_year_type)
        plot_heatmap_effect_compar(df, ax)

    plt.suptitle(
        f'Comparison between modelled univariate feature effects and expert (SME) expectations - Jurisdiction: {juri}'\
        + ('' if SUFFIX != 'balanced_sampled' else f'\nSampling performed to correct for class imbalance'),
        fontsize=16
    )
    plt.tight_layout(rect=[0, 0.05, 1, 0.98], h_pad=7)

    inner_dir = save_fig_dir / 'heatmap_feature_effects'
    if not inner_dir.exists():
        inner_dir.mkdir()
    plt.savefig(inner_dir / f'{juri}{"_pos_background" if use_pos_background else "_all_background"}_new.png', dpi=300)
    plt.close()

Example jurisdiction–year pair

In [None]:
# process result for each juri
# Single example panel (NZ, treatment within 1-3 years), saved as heatmap_example.png.
juri = 'NZ'
year_type = 'Treatment between 1 to 3 years'
juri_result = result_df[(result_df['juri'] == juri) & (result_df['Treatment Time'] == year_type)].copy()
juri_result.loc[:, 'Treatment Time'] = juri_result['Treatment Time'].replace(year_map_dict)
juri_result.loc[:, 'Feature'] = juri_result['Feature'].apply(lambda x: x.replace('|idx=0', '').replace('_df0', ''))
fig = plt.figure(figsize=(10, 12))
ax = plt.subplot(1, 1, 1)
plot_heatmap_effect_compar(juri_result, ax)
plt.suptitle(
    f'Comparison between modelled univariate feature effects and expert (SME) expectations'\
  + ('' if SUFFIX != 'balanced_sampled' else f'\nSampling performed to correct for class imbalance')
)
plt.tight_layout(rect=[0, 0.05, 1, 0.98], h_pad=7)

inner_dir = save_fig_dir / 'heatmap_feature_effects'
if not inner_dir.exists():
    inner_dir.mkdir()
plt.savefig(inner_dir / f'heatmap_example.png', dpi=300)
plt.show()
plt.close()

### Plot heatmap of SHAP effect strength for each jurisdiction–year pair

In [None]:
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Patch, Rectangle
from itertools import product

# 4x3 grid (rows = treatment time, columns = jurisdiction) of feature-importance heatmaps. Each
# treatment-type column is min-max normalised independently.
fig = plt.figure(figsize=(23, 36))
juris = ['NZ', 'WA', 'NSW']

# white -> jurisdiction colour, matching the colours used in the other figures
continous_cmap = {
    'NZ': LinearSegmentedColormap.from_list(name='austroads', colors=['#FFFFFF', 'tab:red']),
    'WA': LinearSegmentedColormap.from_list(name='austroads', colors=['#FFFFFF', 'tab:blue']),
    'NSW': LinearSegmentedColormap.from_list(name='austroads', colors=['#FFFFFF', 'tab:orange']),
}

for i_juri, juri in enumerate(juris):
    # process result for each juri
    juri_result = result_df[result_df['juri'] == juri].drop(columns=['juri']).copy()
    juri_result.loc[:, 'Treatment Time'] = juri_result['Treatment Time'].replace(year_map_dict)
    juri_result.loc[:, 'Feature'] = juri_result['Feature'].apply(lambda x: x.replace('|idx=0', '').replace('_df0', ''))

    for j_year_type, year_type in enumerate(['Year 1', 'Year 2 - 3', 'Year 4 - 5', 'Year 6 - 10']):
        df = juri_result[juri_result['Treatment Time'] == year_type].copy()
        if len(df) == 0:
            continue

        # 'Mean Abs SHAP' carries the effect direction as its sign (see get_magnitude_direction).
        # Importance is about magnitude, so take |.|. Otherwise a strong negative effect would be
        # normalised to "low".
        value_mat = df.pivot(index='Feature', columns='Treatment Type', values='Mean Abs SHAP').abs()
        col_min = value_mat.min(axis=0)
        col_range = value_mat.max(axis=0) - col_min
        # A constant column (range 0) maps to 0 instead of NaN from 0/0.
        value_mat = (value_mat - col_min).div(col_range.where(col_range > 0, 1), axis=1)

        # plot
        ax = plt.subplot(4, 3, 1 + (i_juri + j_year_type * 3))
        sns.heatmap(
            value_mat,
            cmap=continous_cmap[juri],
            linecolor='white', linewidth=1,
            fmt='',
            cbar=True,
            ax=ax,
        )
        cbar = ax.collections[0].colorbar
        cbar.set_ticks([0.1, 0.8])
        cbar.set_ticklabels(['low', 'high'])
        cbar.set_label('Normalised feature effect')
        ax.set_title(f'{juri} | {year_type}')
        ax.set_xlabel(None)
        ax.set_ylabel(None)

plt.suptitle(
    'Normalised feature importances'\
   + ('' if SUFFIX != 'balanced_sampled' else f'\nSampling performed to correct for class imbalance'),
    fontsize=25,
)
plt.tight_layout(rect=[0, 0.05, 1, 0.98], h_pad=7)

inner_dir = save_fig_dir / 'heatmap_feature_effects'
if not inner_dir.exists():
    inner_dir.mkdir()
plt.savefig(inner_dir / f'for_each_time_strength_effect{"_pos_background" if use_pos_background else "_all_background"}_new.jpg')
plt.show()

### Plot aggregated heatmap over years

In [None]:
def merge_years(years):
    """Collapse time-horizon labels into a compact, comma-separated string.

    Each label is a single year (``'1'``) or an inclusive range (``'2 - 3'``,
    ``'11 - 30'``). Labels whose ranges touch or overlap are merged, so
    ``['1', '2 - 3', '6 - 10']`` becomes ``'1 - 3, 6 - 10'``.

    Args:
        years: iterable of labels of the form ``'N'`` or ``'N - M'``. Any order is
            accepted; labels are sorted numerically here (string sorting would put
            ``'11 - 30'`` before ``'2 - 3'``).

    Returns:
        str: merged ranges joined with ``', '``; ``''`` when ``years`` is empty.
    """
    def _bounds(label):
        # Parse 'N' or 'N - M' into integer (start, end); works for multi-digit years.
        parts = [int(part) for part in str(label).split('-')]
        return parts[0], parts[-1]

    merged = []
    for start, end in sorted(_bounds(y) for y in years):
        if merged and start <= merged[-1][1] + 1:
            # Contiguous with (or overlapping) the previous range: extend it.
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return ', '.join(f'{s} - {e}' if e > s else f'{s}' for s, e in merged)

        
def process_flag_group(group: pd.DataFrame) -> pd.Series:
    """Summarise the per-time-horizon flags of one (treatment type, feature) pair.

    Rules, applied in order:
      1. At least one 'Strong correlation and Inconsistent with SME' row -> 'Inconsistent with SME',
         or 'Inconsistent with SME but explained with data' when none of those strong rows has a
         data effect equal to the SME (known) effect. 'Treatment Time' lists the merged horizons
         of the strong inconsistent rows only.
      2. Any 'No expected relationship' row -> every row must be 'No expected relationship'.
         'Model effect' is set to the strong, non-zero model effect when those all agree.
      3. Every row weakly correlated -> 'Weak correlation with feature'.
      4. Otherwise (at least one strong consistent row) -> 'Consistent with SME', or
         'Consistent with SME but unsupported with data' when a strong consistent row's data
         effect differs from the SME effect.

    Args:
        group: rows for one pair, with columns 'Known effect', 'Correlation Flag',
            'Consistency Flag', 'Data effect', 'Model effect' and 'Treatment Time'
            (labels assumed to look like 'Year 2 - 3').

    Returns:
        pd.Series with keys 'Treatment Time', 'Flag', 'Percent wrong', 'Known effect',
        'Explained' and 'Model effect'. 'Percent wrong' and 'Explained' are always NaN.

    Raises:
        AssertionError: if the group mixes SME expectations, or mixes
            'No expected relationship' with other flags.
    """
    assert group['Known effect'].nunique() == 1
    ret = {'Treatment Time': np.nan, 'Flag': None, 'Percent wrong': np.nan,
           'Known effect': group['Known effect'].iloc[0], 'Explained': np.nan, 'Model effect': 0}
    flags = group['Correlation Flag'] + ' and ' + group['Consistency Flag']
    # Correlation strength is irrelevant when SMEs expect no relationship.
    flags = flags.replace('Weak correlation and No expected relationship', 'No expected relationship')
    flags = flags.replace('Strong correlation and No expected relationship', 'No expected relationship')
    seen_flags = set(flags)

    strong_wrong = flags == 'Strong correlation and Inconsistent with SME'
    if strong_wrong.any():
        # Report only the horizons whose inconsistency is backed by a strong correlation,
        # matching the figure's "correlation > threshold" convention.
        wrong_times = group.loc[strong_wrong, 'Treatment Time'].str.replace(r'^Year\s*', '', regex=True)
        ret['Treatment Time'] = 'Year ' + merge_years(sorted(wrong_times.to_list()))
        # Explained when no strong inconsistency has data agreeing with the SME expectation.
        if (strong_wrong & (group['Known effect'] == group['Data effect'])).any():
            ret['Flag'] = 'Inconsistent with SME'
        else:
            ret['Flag'] = 'Inconsistent with SME but explained with data'
    elif 'No expected relationship' in seen_flags:
        assert len(seen_flags) == 1, "There cannot be both 'No expected relationship' flag and another flag!"
        ret['Flag'] = 'No expected relationship'
        strong_effects = group.loc[
            (group['Correlation Flag'] == 'Strong correlation') & (group['Model effect'] != 0), 'Model effect'
        ].dropna().unique()
        if len(strong_effects) == 1:
            # Use the agreed strong effect itself: a max over the whole group could pick up
            # a weak row of the opposite sign.
            ret['Model effect'] = strong_effects[0]
    elif all(flag.startswith('Weak correlation') for flag in seen_flags):
        ret['Flag'] = 'Weak correlation with feature'
    else:
        # At least one strong consistent row (weak inconsistent rows may also be present).
        strong_right = flags == 'Strong correlation and Consistent with SME'
        if (strong_right & (group['Known effect'] != group['Data effect'])).any():
            ret['Flag'] = 'Consistent with SME but unsupported with data'
        else:
            ret['Flag'] = 'Consistent with SME'
    return pd.Series(ret)

In [None]:
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch, Rectangle

fig, axs = plt.subplots(1, 3, figsize=(27, 11))
# Jurisdictions shown in this figure. A dedicated name avoids silently overwriting the
# notebook-level `juris` list that other cells iterate over.
summary_juris = ['NZ', 'WA', 'NSW']

# Integer code for every flag produced by `process_flag_group`; the codes index `discrete_cmap`.
flag_to_index = {
    'Weak correlation with feature': 0,
    'Inconsistent with SME': 1,
    'Inconsistent with SME but explained with data': 2,
    'Consistent with SME': 3,
    'Consistent with SME but unsupported with data': 4,
    'No expected relationship': 5,
}
# One colour per flag code above (weak, inconsistent x2, consistent x2, no expectation).
discrete_cmap = ListedColormap(['gainsboro', red, red, green, green, 'aliceblue'], name='discrete_cmap')

for i, juri in enumerate(summary_juris):
    ax = axs[i]

    # Collapse every (treatment type, feature) pair of this jurisdiction into a single flag.
    juri_result = result_df[result_df['juri'] == juri].drop(columns=['juri']).copy()
    juri_result.loc[:, 'Treatment Time'] = juri_result['Treatment Time'].replace(year_map_dict)
    juri_result.loc[:, 'Feature'] = juri_result['Feature'].apply(lambda x: x.replace('|idx=0', '').replace('_df0', ''))
    juri_result: pd.DataFrame = juri_result.groupby(['Treatment Type', 'Feature']).apply(process_flag_group)
    juri_result = juri_result.reset_index()

    # Feature x treatment-type matrices of the summarised fields (identical index/columns).
    flag_mat = juri_result.pivot(index='Feature', columns='Treatment Type', values='Flag')
    time_mat = juri_result.pivot(index='Feature', columns='Treatment Type', values='Treatment Time')
    effect_mat = juri_result.pivot(index='Feature', columns='Treatment Type', values='Known effect')
    model_mat = juri_result.pivot(index='Feature', columns='Treatment Type', values='Model effect')

    # Cell text: time horizons of inconsistent effects, then the SME-expected direction arrow.
    # Object dtype so the string concatenation below never hits fixed-width numpy strings.
    is_inconsistent = (flag_mat == 'Inconsistent with SME') | (flag_mat == 'Inconsistent with SME but explained with data')
    annot_mat = np.where(is_inconsistent, time_mat.astype(str) + '\n', '').astype(object)
    annot_mat[np.where(effect_mat < 0)] += '▼'
    annot_mat[np.where(effect_mat > 0)] += '▲'

    # Numeric flag codes. vmin/vmax pin the colour range to all codes so each code keeps its
    # colour even when a jurisdiction lacks some flags (otherwise seaborn rescales the
    # colormap to the codes present and colours shift between panels).
    flag_mat = flag_mat.replace(flag_to_index).astype(float)
    sns.heatmap(
        flag_mat,
        cmap=discrete_cmap,
        vmin=-0.5,
        vmax=len(flag_to_index) - 0.5,
        annot=annot_mat,
        annot_kws={'color': 'black'},
        linecolor='white', linewidth=1,
        fmt='',
        cbar=False,
        ax=ax,
    )
    ax.set_title(juri)

    # Hatch cells where the SME expectation and the data disagree (codes 2 and 4).
    cross_rows, cross_cols = np.where((flag_mat == 2) | (flag_mat == 4))
    for row_pos, col_pos in zip(cross_rows, cross_cols):
        ax.add_patch(Rectangle((col_pos, row_pos), 1, 1, fill=False, hatch='x', edgecolor='black'))

    if i == 1:
        # Shared legends are drawn once, under the middle panel.
        color_legend = ax.legend(
            handles=[Patch(color=discrete_cmap(code)) for code in [0, 1, 3, 5]],
            labels=['Weak correlation with feature', 'Inconsistent with SME', 'Consistent with SME', 'No expected relationship'],
            loc='lower center',
            bbox_to_anchor=(0.2, -0.2)
        )
        up_arrow = ax.scatter([], [], c='black', marker='$\u25b2$', s=150, label='SME expects Positive correlation')
        low_arrow = ax.scatter([], [], c='black', marker='$\u25bc$', s=150, label='SME expects Negative correlation')
        cross = Patch(facecolor='white', fill=False, hatch='xxxx', edgecolor='black', label='SME inconsistent with data')
        up_arrow_model = ax.scatter([], [], c='red', marker='$\u25b2$', s=150, label='Model expects Positive correlation')
        low_arrow_model = ax.scatter([], [], c='red', marker='$\u25bc$', s=150, label='Model expects Negative correlation')
        ax.legend(handles=[up_arrow, low_arrow, up_arrow_model, low_arrow_model, cross], title='Icon Legend', bbox_to_anchor=(1, -0.1))
        # A second legend() call replaces the first one, so re-attach the colour legend.
        ax.add_artist(color_legend)

    # Red arrows: model-implied direction where SMEs expect no relationship (code 5).
    no_expected = (flag_mat == 5).to_numpy()
    no_known_annot = np.full(flag_mat.shape, '', dtype=object)
    no_known_annot[no_expected & (model_mat > 0).to_numpy()] = '▲'
    no_known_annot[no_expected & (model_mat < 0).to_numpy()] = '▼'
    for (row_pos, col_pos), arrow in np.ndenumerate(no_known_annot):
        if arrow:
            ax.text(col_pos + 0.5, row_pos + 0.5, arrow, color='red', ha='center', va='center')

plt.suptitle(
    'Color coded tables of feature effects. Effects considered only if correlation coefficient $> 0.5$\nInconsistent effects (red) annotated with time horizons where inconsistency occurs.'
    + ('' if SUFFIX != 'balanced_sampled' else '\nSampling performed to correct for class imbalance'),
)
plt.tight_layout()

inner_dir = save_fig_dir / 'heatmap_feature_effects'
inner_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(inner_dir / f'summarised_over_time{"_pos_background" if use_pos_background else "_all_background"}_new.jpg')
plt.show()

### Plot consistency between data and SME separately

In [None]:
from sklearn.metrics import confusion_matrix
from src.visualization import plot_metric

plt.figure(figsize=(9, 4))
# Only strongly correlated effects are compared.
strong_df = result_df[result_df['Correlation Flag'].str.startswith('Strong')].copy()

# Contingency table: SME-consistency flag (rows) vs data-consistency flag (columns).
# crosstab fills combinations that never occur with 0; a value_counts + pivot leaves NaN
# there, and the int cast below would then raise.
known_vs_data = pd.crosstab(strong_df['Consistency Flag'], strong_df['Consistency Data Flag'])
known_vs_data.columns.name = None
known_vs_data.index.name = None

annot = known_vs_data.astype(int)
cmap = ListedColormap([red, green, aliceblue])  # code 0 -> red, 1 -> green, 2 -> aliceblue

# Colour each column (data-consistency flag) rather than each count: first column -> code 1,
# second -> code 0, third -> code 2. Assumes exactly three data-consistency categories in
# sorted order — TODO confirm against result_df.
colors = known_vs_data.copy()
colors.iloc[:, 0] = 1
colors.iloc[:, 1] = 0
colors.iloc[:, 2] = 2

ax = sns.heatmap(
    colors, annot=annot, cmap=cmap, vmin=-0.5, vmax=2.5,
    linecolor='white', linewidth=3, annot_kws={'fontsize': 20}, cbar=False, fmt='0d',
)
ax.tick_params(axis='x', rotation=0)
ax.tick_params(axis='y', rotation=0)
ax.set_title(
    "Summary of agreement between subject-matter experts (SME) and models"
    + ('' if SUFFIX != 'balanced_sampled' else '\nSampling performed to correct for class imbalance')
)
plt.tight_layout()

inner_dir = save_fig_dir / 'heatmap_feature_effects'
inner_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(inner_dir / 'stat_summarised_comparison_effects_new.png', dpi=300)
plt.show()

### Distribution of features against labels, annotated with known relationships

Boxplots of feature distributions (outliers not shown)

In [None]:
from src.visualization.shap_beeswarm import shap_summary

input_dict = {
    'WA': {'data': train_flattened_mrwa, 'labels': train_flattened_mrwa_labels},
    'NZ': {'data': train_flattened_nzta, 'labels': train_flattened_nzta_labels},
    'NSW': {'data': train_flattened_nsw, 'labels': train_flattened_nsw_labels},
}

# Strongly correlated effects whose model direction is flagged inconsistent with the SME.
wrong: pd.DataFrame = result_df[result_df['Correlation Flag'].str.startswith('Strong') & result_df['Consistency Flag'].str.startswith('Inconsistent')]

# Six panels per row; the grid grows with the number of effects (a fixed 6 x 6 grid raises
# ValueError beyond 36 panels). For up to 36 effects the layout is unchanged.
num_col = 6
num_row = max(6, int(np.ceil(len(wrong) / num_col)))
fig = plt.figure(figsize=(num_col * 5, num_row * 5))

for i, (_, row) in enumerate(wrong.iterrows()):
    eval_features = input_dict[row['juri']]['data']
    eval_labels = input_dict[row['juri']]['labels']
    col = (row['Treatment Time'], row['Treatment Type'])
    feature_name = row["Feature"].replace("_df0", "").replace("|idx=0", "")
    name_parts = feature_name.split("_")

    # Split the feature values by whether the treatment label is positive.
    with_labels = eval_labels[col] == 1
    plot_df = pd.DataFrame(eval_features.loc[:, row['Feature']].copy())
    plot_df['Has treatment'] = list(map({True: 'Yes', False: 'No'}.__getitem__, with_labels))

    cmap = {'Yes': green, 'No': red}
    ax = plt.subplot(num_row, num_col, i + 1)
    if plot_df[row['Feature']].nunique() > 2:
        sns.boxplot(data=plot_df, x='Has treatment', y=row['Feature'], order=['Yes', 'No'], showfliers=False, palette=cmap, ax=ax)
    else:
        # Binary feature: bar height is the mean of the feature within each treatment group.
        plot_df = plot_df.groupby(['Has treatment'])[row['Feature']].mean().reset_index()
        sns.barplot(data=plot_df, x='Has treatment', y=row['Feature'], order=['Yes', 'No'], palette=cmap, ax=ax)

    # Title: the SME expectation, and whether the data sides with the SME or the model.
    direction = 'increase' if row['Known effect'] > 0 else 'decrease'
    if row['Data effect'] == row['Known effect']:
        data_verdict = 'supports subject-matter experts'
    elif row['Data effect'] == row['Model effect']:
        data_verdict = 'supports models'
    else:
        data_verdict = 'has no expectation'
    if feature_name.startswith("Pavement Type") or feature_name.startswith("Surface Material"):
        subject = f'{name_parts[0]} being {name_parts[1]}'
    else:
        subject = f'higher {name_parts[0]}'
    expect_str = (
        f'SME expects {subject} to {direction}\n{row["Treatment Type"]} treatment odds. '
        f'\nModel predictions inconsistent.\nData {data_verdict}'
    )
    ax.set_title(f'{row["juri"]} | {year_map_dict[row["Treatment Time"]]}\n{expect_str}')

    # Red background when the data agrees with the SME, green when it disagrees.
    shade = red if row["Known effect"] == row["Data effect"] else green
    ax.fill_between(x=ax.get_xlim(), y1=ax.get_ylim()[0], y2=ax.get_ylim()[1], color=shade, alpha=0.2, zorder=-1)

    # Replace numeric y ticks with qualitative marks at 10% / 90% of the axis height.
    axis_to_data = ax.transAxes + ax.transData.inverted()
    ax.set_yticks((axis_to_data.transform((0, 0.1))[1], axis_to_data.transform((0, 0.9))[1]))
    ax.set_yticklabels(['low', 'high'])

    if eval_features[row['Feature']].nunique() > 2:
        ax.set_ylabel(feature_name)
    else:
        ax.set_ylabel(f'Fraction with {name_parts[0]} being {name_parts[1]}')
        plt.setp(ax.patches, linewidth=1, edgecolor='black')
    ax.set_xlabel(f'Treatment: {row["Treatment Type"]}')
    ax.tick_params(axis='y', rotation=90)

plt.suptitle(
    "Distribution of features against treatments for effects inconsistent with SME"
    + ('' if SUFFIX != 'balanced_sampled' else '\nSampling performed to correct for class imbalance'),
    fontsize=18
)
plt.tight_layout(rect=[0, 0, 1, 0.98], w_pad=5.5)

inner_dir = save_fig_dir / 'heatmap_feature_effects'
inner_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(inner_dir / 'explain_inconsistent_boxplot_new.jpg')
plt.show()

Boxplots for each individual jurisdiction

In [None]:
from src.visualization.shap_beeswarm import shap_summary

input_dict = {
    'WA': {'data': train_flattened_mrwa, 'labels': train_flattened_mrwa_labels},
    'NZ': {'data': train_flattened_nzta, 'labels': train_flattened_nzta_labels},
    'NSW': {'data': train_flattened_nsw, 'labels': train_flattened_nsw_labels},
}

# Strongly correlated effects whose model direction is flagged inconsistent with the SME.
wrong: pd.DataFrame = result_df[result_df['Correlation Flag'].str.startswith('Strong') & result_df['Consistency Flag'].str.startswith('Inconsistent')]

for juri in wrong['juri'].unique():
    df = wrong[wrong['juri'] == juri]
    # Smallest column count in 2..9 giving a near-square grid; fall back to ceil(sqrt(n))
    # when no candidate qualifies (e.g. a single panel or very many panels), where
    # min() over the empty list used to raise ValueError.
    candidates = [k for k in range(2, 10) if abs(k - (len(df) // k)) < 2]
    num_col = min(candidates) if candidates else max(1, int(np.ceil(np.sqrt(len(df)))))
    num_row = int(np.ceil(len(df) / num_col))
    fig = plt.figure(figsize=(num_col * 5, num_row * 5))

    for i, (_, row) in enumerate(df.iterrows()):
        eval_features = input_dict[row['juri']]['data']
        eval_labels = input_dict[row['juri']]['labels']
        col = (row['Treatment Time'], row['Treatment Type'])
        feature_name = row["Feature"].replace("_df0", "").replace("|idx=0", "")
        name_parts = feature_name.split("_")

        # Split the feature values by whether the treatment label is positive.
        with_labels = eval_labels[col] == 1
        plot_df = pd.DataFrame(eval_features.loc[:, row['Feature']].copy())
        plot_df['Has treatment'] = list(map({True: 'Yes', False: 'No'}.__getitem__, with_labels))

        cmap = {'Yes': green, 'No': red}
        ax = plt.subplot(num_row, num_col, i + 1)
        if plot_df[row['Feature']].nunique() > 2:
            sns.boxplot(data=plot_df, x='Has treatment', y=row['Feature'], order=['Yes', 'No'], showfliers=False, palette=cmap, ax=ax)
        else:
            # Binary feature: bar height is the mean of the feature within each treatment group.
            plot_df = plot_df.groupby(['Has treatment'])[row['Feature']].mean().reset_index()
            sns.barplot(data=plot_df, x='Has treatment', y=row['Feature'], order=['Yes', 'No'], palette=cmap, ax=ax)

        # Title: the SME expectation, and whether the data sides with the SME or the model.
        direction = 'increase' if row['Known effect'] > 0 else 'decrease'
        if row['Data effect'] == row['Known effect']:
            data_verdict = 'supports subject-matter experts'
        elif row['Data effect'] == row['Model effect']:
            data_verdict = 'supports models'
        else:
            data_verdict = 'has no expectation'
        if feature_name.startswith("Pavement Type") or feature_name.startswith("Surface Material"):
            subject = f'{name_parts[0]} being {name_parts[1]}'
        else:
            subject = f'higher {name_parts[0]}'
        expect_str = (
            f'SME expects {subject} to {direction}\n{row["Treatment Type"]} treatment odds. '
            f'\nModel predictions inconsistent.\nData {data_verdict}'
        )
        ax.set_title(f'{row["juri"]} | {year_map_dict[row["Treatment Time"]]}\n{expect_str}')

        # Red background when the data agrees with the SME, green when it disagrees.
        shade = red if row["Known effect"] == row["Data effect"] else green
        ax.fill_between(x=ax.get_xlim(), y1=ax.get_ylim()[0], y2=ax.get_ylim()[1], color=shade, alpha=0.2, zorder=-1)

        # Replace numeric y ticks with qualitative marks at 10% / 90% of the axis height.
        axis_to_data = ax.transAxes + ax.transData.inverted()
        ax.set_yticks((axis_to_data.transform((0, 0.1))[1], axis_to_data.transform((0, 0.9))[1]))
        ax.set_yticklabels(['low', 'high'])

        if eval_features[row['Feature']].nunique() > 2:
            ax.set_ylabel(feature_name)
        else:
            ax.set_ylabel(f'Fraction with {name_parts[0]} being {name_parts[1]}')
            plt.setp(ax.patches, linewidth=1, edgecolor='black')
        ax.set_xlabel(f'Treatment: {row["Treatment Type"]}')
        ax.tick_params(axis='y', rotation=90)

    plt.suptitle(
        "Distribution of features against treatments for effects inconsistent with SME - Jurisdiction: {}".format(juri)
        + ('' if SUFFIX != 'balanced_sampled' else '\nSampling performed to correct for class imbalance')
    )
    plt.tight_layout(rect=[0, 0, 1, 0.98], w_pad=5.5)

    inner_dir = save_fig_dir / 'heatmap_feature_effects'
    inner_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(inner_dir / f'{juri}_explain_boxplot_new.png', dpi=300)
    # Saved only, not displayed; close to free the figure inside the loop.
    plt.close(fig)

Example boxplot

In [None]:
input_dict = {
    'WA': {'data': train_flattened_mrwa, 'labels': train_flattened_mrwa_labels},
    'NZ': {'data': train_flattened_nzta, 'labels': train_flattened_nzta_labels},
    'NSW': {'data': train_flattened_nsw, 'labels': train_flattened_nsw_labels},
}

# Strongly correlated effects whose model direction is flagged inconsistent with the SME.
wrong: pd.DataFrame = result_df[result_df['Correlation Flag'].str.startswith('Strong') & result_df['Consistency Flag'].str.startswith('Inconsistent')]

# The single effect to illustrate.
juri = 'NZ'
year_type = 'Treatment between 1 to 3 years'
feature = 'HeavyIndex_df0|idx=0'
treatment_type = 'Resurfacing_SS'

matches = wrong[(wrong['juri'] == juri) & (wrong['Treatment Time'] == year_type) & (wrong['Treatment Type'] == treatment_type) & (wrong['Feature'] == feature)]
assert len(matches) > 0, f'No strong inconsistent effect for {juri} | {year_type} | {treatment_type} | {feature}'
row = matches.iloc[0]
fig = plt.figure(figsize=(5, 5))

eval_features = input_dict[row['juri']]['data']
eval_labels = input_dict[row['juri']]['labels']
col = (row['Treatment Time'], row['Treatment Type'])
feature_name = row["Feature"].replace("_df0", "").replace("|idx=0", "")
name_parts = feature_name.split("_")

# Split the feature values by whether the treatment label is positive.
with_labels = eval_labels[col] == 1
plot_df = pd.DataFrame(eval_features.loc[:, row['Feature']].copy())
plot_df['Has treatment'] = list(map({True: 'Yes', False: 'No'}.__getitem__, with_labels))

ax = plt.subplot(1, 1, 1)
cmap = {'Yes': green, 'No': red}

if plot_df[row['Feature']].nunique() > 2:
    sns.boxplot(data=plot_df, x='Has treatment', y=row['Feature'], order=['Yes', 'No'], showfliers=False, palette=cmap, ax=ax)
else:
    # Binary feature: bar height is the mean of the feature within each treatment group.
    plot_df = plot_df.groupby(['Has treatment'])[row['Feature']].mean().reset_index()
    sns.barplot(data=plot_df, x='Has treatment', y=row['Feature'], order=['Yes', 'No'], palette=cmap, ax=ax)

# Title: the SME expectation, and whether the data sides with the SME or the model.
direction = 'increase' if row['Known effect'] > 0 else 'decrease'
if row['Data effect'] == row['Known effect']:
    data_verdict = 'supports subject-matter experts'
elif row['Data effect'] == row['Model effect']:
    data_verdict = 'supports models'
else:
    data_verdict = 'has no expectation'
if feature_name.startswith("Pavement Type") or feature_name.startswith("Surface Material"):
    # Space before the treatment type was previously missing ("increaseResurfacing_SS").
    expect_str = f'SME expects {name_parts[0]} being {name_parts[1]}\nto {direction} {row["Treatment Type"]} treatment odds.'
else:
    expect_str = f'SME expects higher {name_parts[0]} to {direction}\n{row["Treatment Type"]} treatment odds. '
expect_str += f'\nModel predictions inconsistent.\nData {data_verdict}'
ax.set_title(f'{row["juri"]} | {year_map_dict[row["Treatment Time"]]}\n{expect_str}')

# Red background when the data agrees with the SME, green when it disagrees.
shade = red if row["Known effect"] == row["Data effect"] else green
ax.fill_between(x=ax.get_xlim(), y1=ax.get_ylim()[0], y2=ax.get_ylim()[1], color=shade, alpha=0.2, zorder=-1)

# Replace numeric y ticks with qualitative marks at 10% / 90% of the axis height.
axis_to_data = ax.transAxes + ax.transData.inverted()
ax.set_yticks((axis_to_data.transform((0, 0.1))[1], axis_to_data.transform((0, 0.9))[1]))
ax.set_yticklabels(['low', 'high'])

if eval_features[row['Feature']].nunique() > 2:
    ax.set_ylabel(feature_name)
else:
    ax.set_ylabel(f'Fraction with {name_parts[0]} being {name_parts[1]}')
    plt.setp(ax.patches, linewidth=1, edgecolor='black')
ax.set_xlabel(f'Treatment: {row["Treatment Type"]}')
ax.tick_params(axis='y', rotation=90)

plt.tight_layout(rect=[0, 0, 1, 1], w_pad=5.5)

inner_dir = save_fig_dir / 'heatmap_feature_effects'
inner_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(inner_dir / 'boxplot_explain_sample.png', bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
from src.visualization.shap_beeswarm import shap_summary

input_dict = {
    'WA': {'data': train_flattened_mrwa, 'labels': train_flattened_mrwa_labels},
    'NZ': {'data': train_flattened_nzta, 'labels': train_flattened_nzta_labels},
    'NSW': {'data': train_flattened_nsw, 'labels': train_flattened_nsw_labels},
}

# Strong effects where the model is inconsistent with the SME and the data-consistency
# flag is also 'Inconsistent...'.
wrong_and_inconsistent_data = result_df[
    (result_df['Consistency Flag'].str.startswith('Inconsistent with SME')) &
    (result_df['Correlation Flag'].str.startswith('Strong')) &
    (result_df['Consistency Data Flag'].str.startswith('Inconsistent'))
]
# Two panels per row; at least 5 rows (the original layout), more when there are >10
# effects instead of raising ValueError from plt.subplot.
num_col = 2
num_row = max(5, int(np.ceil(len(wrong_and_inconsistent_data) / num_col)))
fig = plt.figure(figsize=(num_col * 5, num_row * 5))

for i, (_, row) in enumerate(wrong_and_inconsistent_data.iterrows()):
    eval_features = input_dict[row['juri']]['data']
    eval_labels = input_dict[row['juri']]['labels']
    col = (row['Treatment Time'], row['Treatment Type'])
    feature_name = row["Feature"].replace("_df0", "").replace("|idx=0", "")
    name_parts = feature_name.split("_")

    # Split the feature values by whether the treatment label is positive.
    with_labels = eval_labels[col] == 1
    plot_df = pd.DataFrame(eval_features.loc[:, row['Feature']].copy())
    plot_df['Has treatment'] = list(map({True: 'Yes', False: 'No'}.__getitem__, with_labels))

    cmap = {'Yes': green, 'No': red}
    ax = plt.subplot(num_row, num_col, i + 1)
    if plot_df[row['Feature']].nunique() > 2:
        sns.boxplot(data=plot_df, x='Has treatment', y=row['Feature'], order=['Yes', 'No'], showfliers=False, palette=cmap, ax=ax)
    else:
        # Binary feature: bar height is the mean of the feature within each treatment group.
        plot_df = plot_df.groupby(['Has treatment'])[row['Feature']].mean().reset_index()
        sns.barplot(data=plot_df, x='Has treatment', y=row['Feature'], order=['Yes', 'No'], palette=cmap, ax=ax)

    direction = 'increase' if row['Known effect'] > 0 else 'decrease'
    observed = 'opposite' if row["Data effect"] != row['Known effect'] else 'the same'
    if feature_name.startswith("Pavement Type") or feature_name.startswith("Surface Material"):
        expect_str = f'Expect {name_parts[0]} being {name_parts[1]}\nto {direction} treatment odds.\nObserve {observed}'
    else:
        expect_str = f'Expect higher {name_parts[0]} to {direction} treatment odds.\nObserve {observed}'
    ax.set_title(f'{row["juri"]} | {year_map_dict[row["Treatment Time"]]} - {row["Treatment Type"]}\n{expect_str}')

    # Replace numeric y ticks with qualitative marks at 10% / 90% of the axis height.
    axis_to_data = ax.transAxes + ax.transData.inverted()
    ax.set_yticks((axis_to_data.transform((0, 0.1))[1], axis_to_data.transform((0, 0.9))[1]))
    ax.set_yticklabels(['low', 'high'])

    # Background uses the same red/green palette as the other boxplot cells.
    shade = red if row["Known effect"] == row["Data effect"] else green
    ax.fill_between(x=ax.get_xlim(), y1=ax.get_ylim()[0], y2=ax.get_ylim()[1], color=shade, alpha=0.2, zorder=-1)

    if eval_features[row['Feature']].nunique() > 2:
        ax.set_ylabel(feature_name)
    else:
        ax.set_ylabel(f'Fraction with {name_parts[0]} being {name_parts[1]}')
        plt.setp(ax.patches, linewidth=1, edgecolor='black')
    ax.set_xlabel('Has treatment')
    ax.tick_params(axis='y', rotation=90)

plt.suptitle(
    "Distribution of features against treatments for effects inconsistent with SME and inconsistent with data"
    + ('' if SUFFIX != 'balanced_sampled' else '\nSampling performed to correct for class imbalance')
)
plt.tight_layout(rect=[0, 0, 1, 0.98], w_pad=5.5)

inner_dir = save_fig_dir / 'heatmap_feature_effects'
inner_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(inner_dir / 'unexplain_inconsistent.png', dpi=300)
plt.close(fig)

# Partial dependence plots (PDP)

In [None]:
from sklearn.inspection import partial_dependence
from matplotlib.patches import Patch
import matplotlib.gridspec as gridspec

juris = ['NSW', 'VIC', 'NZ', 'WA']
datasets = [train_flattened_nsw, train_flattened_vic, train_flattened_nzta, train_flattened_mrwa]
colors = ['orange', 'green', 'red', 'blue']
# Quantile range of each feature within which PDP curves are plotted.
cutoff = (0.05, 0.95)

# Cached partial_dependence results, nested as juri -> feature -> estimator index -> model index.
pdp_cache_path = save_fig_dir / 'raw_results_pdp_all.pkl'
if pdp_cache_path.exists():
    # NOTE: pickle.load can execute arbitrary code; only load a cache this notebook wrote itself.
    with open(pdp_cache_path, 'rb') as f:
        raw_pdp_result = pickle.load(f)
else:
    raw_pdp_result = {}
# A cache written by a run over fewer jurisdictions would otherwise raise KeyError
# at raw_pdp_result[juri] in the plotting loop.
for _juri in juris:
    raw_pdp_result.setdefault(_juri, {})
for juri, c, eval_features in zip(juris, colors, datasets):

    for feature_idx, feature in enumerate(eval_features.columns): 
        if feature_idx <= 9:
            continue
        if feature not in raw_pdp_result[juri]:
            raw_pdp_result[juri][feature] = {}

        fig = plt.figure(figsize=(24, 60))
        gs = gridspec.GridSpec(5, 4, figure=fig) # 5 rows for 5 times, 4 cols for 4 treatments

        # calculate quantile cut off
        if eval_features[feature].nunique() > 2:
            low, high = np.quantile(eval_features[feature], cutoff)
        else:
            low, high = sorted(eval_features[feature].unique()) 
        if low == high:
            low, high = eval_features[feature].min(), eval_features[feature].max()
 
        row_dict = {t: i for i, t in enumerate(['within 1 year', 'between 1 to 3 years', 'between 3 to 5 years', 'between 5 to 10 years', 'between 10 to 30 years'])}
        col_dict = {t: i for i, t in enumerate(['Resurfacing_SS', 'Resurfacing_AC', 'Major Patching', 'Rehabilitation'])}

        for col_idx, col in enumerate(model_dict[juri]['prediction_columns']):
            if col[1] not in col_dict:
                continue
            # for testing:
            if col != ('Treatment within 1 year', 'Resurfacing_SS'):
                continue

            nested_gs = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[row_dict[col[0].replace('Treatment ', '')], col_dict[col[1]]], hspace=0.1)
            ax = fig.add_subplot(nested_gs[:2, :])
            frequency_ax = fig.add_subplot(nested_gs[2, :], sharex=ax)
            col_name = f"{col[0].replace('Treatment ', '')} - {col[1].replace('Resurfacing_', '')}" 
            ax.set_title(col_name)

            estimator_idx = np.argwhere(model_dict[juri]['prediction_columns'] == col).flatten()
            if estimator_idx.size != 1:
                continue
            all_means, true_means, prev_values = [], [], None
            if estimator_idx[0] not in raw_pdp_result[juri][feature]:
                raw_pdp_result[juri][feature][estimator_idx[0]] = {}
            for model_idx, model in enumerate(model_dict[juri]['models']['XGB']):
                if (model_idx not in raw_pdp_result[juri][feature][estimator_idx[0]]) or (raw_pdp_result[juri][feature][estimator_idx[0]][model_idx] is None):
                    results = partial_dependence(
                        model.estimators_[estimator_idx[0]],
                        X=eval_features,
                        features=[feature],
                        kind='average'
                    )
                    raw_pdp_result[juri][feature][estimator_idx[0]][model_idx] = results
                    with open(save_fig_dir / 'raw_results_pdp_all.pkl', 'wb') as f:
                        pickle.dump(raw_pdp_result, f)
                else:
                    results = raw_pdp_result[juri][feature][estimator_idx[0]][model_idx]

                # filter by percentile
                plot_x, plot_y = results['values'][0], results['average'][0]
                plot_idx = np.where((plot_x >= low) & (plot_x <= high))

                ax.plot(plot_x[plot_idx], plot_y[plot_idx], label=juri, color=c, alpha=0.2)
                all_means.append(plot_y[plot_idx])
                prev_values = plot_x[plot_idx] if prev_values is None else (None if (prev_values != plot_x[plot_idx]).any() else prev_values)
                assert prev_values is not None

                # get true predictions
                true_means.append(np.mean(model.estimators_[estimator_idx[0]].predict(np.array(eval_features)), axis=0))

            ax.plot(prev_values, np.array(all_means).mean(axis=0), label=juri, color=c, alpha=1)
            ax.axhline(np.array(true_means).mean(axis=0), xmin=prev_values.min(), xmax=prev_values.max(), label='Actual %', color='red', alpha=1)
            ax.tick_params(axis='y', labelcolor=c)

            # plot frequency
            if eval_features[feature].nunique() > 2:
                sns.kdeplot(data=eval_features, x=feature, clip=(low, high), shade=True, ax=frequency_ax, color=c)
            else:
                count_0 = (eval_features[feature] == low).sum()
                count_1 = (eval_features[feature] == high).sum()
                sns.barplot(x=[low, high], y=[count_0, count_1], ax=frequency_ax, color=c)
            frequency_ax.invert_yaxis()

        fig.supylabel('% of positive labels')
        fig.suptitle(
            f'Partial Dependency Plot - {juri}\n{feature.replace("_df0|idx=0", "")}',
            fontsize=25
        )
        fig.tight_layout(rect=[0.03, 0.03, 1, 0.95])
        inner_dir = save_fig_dir / 'constrained_pdp'
        if not inner_dir.exists(): inner_dir.mkdir()
        plt.savefig(inner_dir / f'pdp_{juri}_{feature.replace("_df0|idx=0", "")}.jpg')