<div>
<img src='../../img/WSP_red.png' style='height: 95px; float: left' alt='WSP Logo'/>
<img src='../../img/austroads.png' style='height: 115px; float: right' alt='Client Logo'/>
</div>
<center><h2>AAM6201 Development of Machine-Learning Decision-Support Tools for Pavement Asset Management<br>Case Study 1: Project Identification</h2></center>


In [None]:
# magic command to autoreload changes in src
%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pickle

from src import util

# Load data and models

In [None]:
from data import DATA_DIR
import pickle

# Root folders for the saved evaluation artefacts (confusion matrices) and the trained models.
report_dir = DATA_DIR.parent / 'reports' / 'raw_results'
model_dir = DATA_DIR.parent / 'models' / 'trained'
# Train/validation split variant; it is embedded in every artefact folder and file name below.
SUFFIX = 'even_split'
suffix = SUFFIX

# Notebook jurisdiction key -> (artefact folder name, lower-case file-name stem).
# Insertion order (WA, NZ, NSW, VIC) is the order of the original hand-written path dictionaries.
JURI_ARTEFACT_NAMES = {
    'WA': ('MRWA', 'mrwa'),
    'NZ': ('NZTA', 'nzta'),
    'NSW': ('NSW', 'nsw'),
    'VIC': ('VIC', 'vic'),
}


def artefact_path(root, juri, file_prefix):
    """Return ``root/<FOLDER>/<stem>_final_<suffix>_dir/<file_prefix>_<stem>_final_<suffix>.pkl`` for ``juri``."""
    folder, stem = JURI_ARTEFACT_NAMES[juri]
    return root / folder / f'{stem}_final_{suffix}_dir' / f'{file_prefix}_{stem}_final_{suffix}.pkl'


paths_dict = {
    'conf_mat': {juri: artefact_path(report_dir, juri, 'valid_XGB_rawconfmat') for juri in JURI_ARTEFACT_NAMES},
    'conf_mat_dummy': {juri: artefact_path(report_dir, juri, 'valid_dummy_rawconfmat') for juri in JURI_ARTEFACT_NAMES},
    'prediction_columns': {juri: artefact_path(model_dir, juri, 'train_labels_columns') for juri in JURI_ARTEFACT_NAMES},
}


model_dict = {
    juri: {
        'models': {
            'XGB': None
        },
        'prediction_columns': None
    } for juri in ['NSW', 'VIC', 'WA', 'NZ']
}

# Only these three jurisdictions have trained models loaded here, so model_dict['VIC'] keeps its None entries.
juris = ['NSW', 'MRWA', 'NZTA']
prefixes = ['final'] * 3
suffixes = [suffix] * 3

# NOTE: pickle.load can execute arbitrary code -- only point these paths at artefacts produced by this project.
for train_name, prefix, train_suffix in zip(juris, prefixes, suffixes):
    juri_name = train_name.replace('TA', '').replace('MR', '')  # turn MRWA -> WA and NZTA -> NZ
    stem = f'{train_name.lower()}_{prefix}_{train_suffix}'
    # Use distinct loop names so the global `model_dir` / `suffix` defined above are not clobbered.
    juri_model_dir = model_dir / train_name / f'{stem}_dir'
    for model_type in ['XGB']:
        with open(juri_model_dir / f'train_{model_type}_timehorizon_{stem}.pkl', 'rb') as f:
            model_dict[juri_name]['models'][model_type] = pickle.load(f)
    with open(juri_model_dir / f'train_labels_columns_{stem}.pkl', 'rb') as f:
        model_dict[juri_name]['prediction_columns'] = pickle.load(f)

# result_dict[juri][val_type] holds the unpickled artefact for every entry of paths_dict.
result_dict = {}
for juri in ['WA', 'NSW', 'NZ', 'VIC']:
    result_dict[juri] = {}
    for val_type in ['conf_mat', 'prediction_columns', 'conf_mat_dummy']:
        with open(paths_dict[val_type][juri], 'rb') as f:
            result_dict[juri][val_type] = pickle.load(f)

In [None]:
from src import util

# Processed MRWA / NZTA "final" datasets.
mrwa_final_dir = DATA_DIR / 'processed' / 'MRWA' / 'mrwa_final'
nzta_final_dir = DATA_DIR / 'processed' / 'NZTA' / 'nzta_final'

# Label tables, read with a two-level header (header=[0, 1], presumably forwarded to pandas.read_csv).
train_flattened_mrwa_labels = util.load_data(mrwa_final_dir / 'train_flattened_labels_mrwa_final.csv', header=[0, 1])
train_flattened_nzta_labels = util.load_data(nzta_final_dir / 'train_flattened_labels_nzta_final.csv', header=[0, 1])
valid_flattened_mrwa_labels = util.load_data(mrwa_final_dir / 'valid_flattened_labels_mrwa_final.csv', header=[0, 1])
valid_flattened_nzta_labels = util.load_data(nzta_final_dir / 'valid_flattened_labels_nzta_final.csv', header=[0, 1])

# Flattened feature tables (the `_no_offset` files).
train_flattened_mrwa = util.load_data(mrwa_final_dir / 'train_flattened_data_mrwa_final_no_offset.csv')
train_flattened_nzta = util.load_data(nzta_final_dir / 'train_flattened_data_nzta_final_no_offset.csv')
valid_flattened_mrwa = util.load_data(mrwa_final_dir / 'valid_flattened_data_mrwa_final_no_offset.csv')
valid_flattened_nzta = util.load_data(nzta_final_dir / 'valid_flattened_data_nzta_final_no_offset.csv')

In [None]:
# NSW and VIC keep their final datasets in a plain `final` folder with different file names.
nsw_final_dir = DATA_DIR / 'processed' / 'NSW' / 'final'
vic_final_dir = DATA_DIR / 'processed' / 'VIC' / 'final'

train_flattened_nsw = util.load_data(nsw_final_dir / 'train_all.csv')
train_flattened_nsw_labels = util.load_data(nsw_final_dir / 'labels_all.csv', header=[0, 1])
train_flattened_vic = util.load_data(vic_final_dir / 'train_all.csv')
train_flattened_vic_labels = util.load_data(vic_final_dir / 'labels_all.csv', header=[0, 1])

valid_flattened_nsw = util.load_data(nsw_final_dir / 'valid_all.csv')
valid_flattened_nsw_labels = util.load_data(nsw_final_dir / 'valid_labels_all.csv', header=[0, 1])
valid_flattened_vic = util.load_data(vic_final_dir / 'valid_all.csv')
valid_flattened_vic_labels = util.load_data(vic_final_dir / 'valid_labels_all.csv', header=[0, 1])

In [None]:
# Time-horizon labels in chronological order; these strings are matched against the label columns' first level.
TIME_HORIZONS = [
    'Treatment within 1 year',
    'Treatment between 1 to 3 years',
    'Treatment between 3 to 5 years',
    'Treatment between 5 to 10 years',
    'Treatment between 10 to 30 years',
]
# Treatment types; this order sets the panel order in the by-metric figures.
TREATMENT_TYPES = ['Resurfacing_SS', 'Resurfacing_AC', 'Major Patching', 'Rehabilitation', 'Retexturing', 'Regulation']

# Short x-axis labels. The plots skip the 10-30 year horizon, so only the first four horizons get one.
year_map_dict = dict(zip(TIME_HORIZONS[:4], ['Year 1', 'Year 2 - 3', 'Year 4 - 5', 'Year 6 - 10']))

# Chronological rank of each horizon (0 = soonest).
year_order_dict = {horizon: rank for rank, horizon in enumerate(TIME_HORIZONS[:4])}
treatment_type_order = {treatment: rank for rank, treatment in enumerate(TREATMENT_TYPES)}
treatment_time_order = {horizon: rank for rank, horizon in enumerate(TIME_HORIZONS)}

# Jurisdiction plotting order and colours.
juri_order = {juri: rank for rank, juri in enumerate(['WA', 'NSW', 'VIC', 'NZ'])}
juri_colors = dict(zip(juri_order, ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']))

# Background shade used for each treatment's panels.
treatment_type_colors = dict(zip(
    TREATMENT_TYPES,
    ['tab:purple', 'tab:brown', 'tab:gray', 'tab:olive', 'tab:cyan', 'tab:pink'],
))

In [None]:
# Output folder shared by all cross-jurisdiction figures; created (with parents) on first run.
save_fig_dir = report_dir.parent / 'figures' / 'shared'
if not save_fig_dir.exists():
    save_fig_dir.mkdir(parents=True)

# Correlation plots
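Plot the absolute pairwise correlation between features within each jurisdiction's training dataset (lower triangle only, since the matrix is symmetric).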

In [None]:
def plot_cross_correlation(corr, title, figsize=(16, 8), heatmap_dict=None, colormap=None):
    """Draw the lower triangle of a correlation matrix as an annotated heatmap on a new figure.

    Args:
        corr: square correlation matrix (e.g. the result of ``DataFrame.corr()``).
        title: title for the axes.
        figsize: size of the new figure, in inches.
        heatmap_dict: extra keyword arguments forwarded to ``sns.heatmap`` (e.g. ``vmin``/``vmax``).
            Defaults to no extra arguments.
        colormap: colormap for the cells; defaults to a blue-red diverging seaborn palette.

    Returns:
        ``(fig, ax)`` of the newly created figure.
    """
    # None instead of a shared mutable ``{}`` default.
    if heatmap_dict is None:
        heatmap_dict = {}
    fig, ax = plt.subplots(figsize=figsize)
    if colormap is None:
        colormap = sns.diverging_palette(220, 10, as_cmap=True)
    # Hide the diagonal and upper triangle: a correlation matrix is symmetric, so they add nothing.
    upper_mask = np.triu(np.ones(np.shape(corr), dtype=bool))
    # Draw on the axes created above rather than whatever pyplot considers "current";
    # a caller-supplied 'ax' in heatmap_dict still takes precedence, as before.
    heatmap_kwargs = {'ax': ax, **heatmap_dict}
    sns.heatmap(corr, cmap=colormap, linewidths=.5, annot=True, fmt=".2f", mask=upper_mask, **heatmap_kwargs)
    ax.set_title(title)
    return fig, ax

# Absolute pairwise feature correlation (pandas' default method) of each training set.
corr_mrwa, corr_nzta, corr_vic, corr_nsw = (
    features.corr().abs()
    for features in (train_flattened_mrwa, train_flattened_nzta, train_flattened_vic, train_flattened_nsw)
)

In [None]:
# Shared colour scale: every heatmap spans |corr| in [0, 1] on a light-to-red palette.
vmin, vmax = 0, 1
cmap = sns.light_palette('red', as_cmap=True)


def _draw_corr_heatmap(name, corr, figsize=(16, 8)):
    """Heatmap of one dataset's |correlation| matrix, titled '<NAME> Correlation'."""
    return plot_cross_correlation(
        corr,
        f'{name.upper()} Correlation',
        figsize=figsize,
        heatmap_dict={'vmin': vmin, 'vmax': vmax},
        colormap=cmap,
    )


fig_nzta, ax_nzta = _draw_corr_heatmap('nzta', corr_nzta)
fig_mrwa, ax_mrwa = _draw_corr_heatmap('mrwa', corr_mrwa)
fig_vic, ax_vic = _draw_corr_heatmap('vic', corr_vic, figsize=(22, 8))  # VIC gets a wider canvas
fig_nsw, ax_nsw = _draw_corr_heatmap('nsw', corr_nsw)

# Tighten each layout and save it as <name>_corr.jpeg in the shared figure folder.
for corr_name, corr_fig in (('nzta', fig_nzta), ('mrwa', fig_mrwa), ('vic', fig_vic), ('nsw', fig_nsw)):
    corr_fig.tight_layout()
    corr_fig.savefig(save_fig_dir / f'{corr_name}_corr.jpeg')

# All results

## Results grouped by treatment

In [None]:
import src.visualization.plot_metric as plot_metric
import matplotlib.patches as mpatches
import warnings

from collections import OrderedDict
from typing import List, Dict
from pathlib import Path

# NOTE(review): this silences *every* warning for the rest of the kernel session, not only in this
# cell. Presumably it targets the numpy RuntimeWarnings that the metric code below can emit
# (0/0 in precision/recall, np.nanmax over all-NaN values) -- consider narrowing it with a
# `category=RuntimeWarning` filter or a `warnings.catch_warnings()` context.
warnings.filterwarnings('ignore')

def plot_metric_over_time_all_datasets(
        result_dict: dict,
    ):
    """Save one figure per treatment comparing Precision, Recall and F-Score across jurisdictions.

    Each treatment found in any jurisdiction's ``prediction_columns`` gets a figure with three
    stacked panels (Precision, Recall, F-Score), saved as
    ``save_fig_dir / f'joined_results_{treatment.lower()}.jpg'``. Within a panel, for every time
    horizon of that treatment (the 10-30 year horizon is skipped):

    * one violin per jurisdiction shows the spread of the XGB metric along the leading axis of
      ``result_dict[juri]['conf_mat']`` (presumably repeated runs -- TODO confirm), and
    * one red box marker shows the best chance level: the max, over the dummy strategies in
      ``result_dict[juri]['conf_mat_dummy']``, of that strategy's mean metric.

    Args:
        result_dict: ``{juri: {'conf_mat': ..., 'conf_mat_dummy': {strategy: ...},
            'prediction_columns': pd.MultiIndex}}``. Confusion matrices convert to arrays indexed
            ``[run, label_column, true, predicted]``, with class 1 as the positive class. The
            MultiIndex levels are (time horizon, treatment).

    Reads the notebook globals ``juri_order``, ``juri_colors``, ``treatment_time_order``,
    ``treatment_type_colors``, ``year_map_dict``, ``SUFFIX`` and ``save_fig_dir``.
    """
    horizon_to_skip = 'Treatment between 10 to 30 years'
    uniq_juris = sorted(result_dict.keys(), key=lambda x: juri_order[x])
    metric_names = ['Precision', 'Recall', 'F-Score']

    def _binary_scores(conf_mats):
        """Per-run precision/recall/F-score of the positive class for stacked 2x2 matrices."""
        precision = conf_mats[:, 1, 1] / conf_mats[:, :, 1].sum(axis=1)
        recall = conf_mats[:, 1, 1] / conf_mats[:, 1, :].sum(axis=1)
        return {'Precision': precision, 'Recall': recall, 'F-Score': 2 / (1 / precision + 1 / recall)}

    # Union over jurisdictions of the time horizons predicted for each treatment.
    treatments_time_dict: Dict[str, set] = {}
    for juri in uniq_juris:
        prediction_columns: pd.MultiIndex = result_dict[juri]['prediction_columns']
        for treatment in set(prediction_columns.get_level_values(1)):
            times = set(prediction_columns[prediction_columns.get_level_values(1) == treatment].get_level_values(0))
            times.discard(horizon_to_skip)
            treatments_time_dict.setdefault(treatment, set()).update(times)

    # Convert the confusion matrices to arrays once, instead of once per label column.
    conf_mats_by_juri = {juri: np.array(result_dict[juri]['conf_mat']) for juri in uniq_juris}
    dummy_conf_mats_by_juri = {
        juri: [np.array(conf_mat) for conf_mat in result_dict[juri]['conf_mat_dummy'].values()]
        for juri in uniq_juris
    }
    split_label = " ".join(map(str.capitalize, SUFFIX.split("_")))

    for treatment, times in treatments_time_dict.items():
        if not times:
            continue  # only the skipped 10-30 year horizon exists for this treatment: nothing to draw
        fig, axs = plt.subplots(nrows=len(metric_names), ncols=1, figsize=(12, 4 * len(metric_names)))
        axs = axs.ravel()
        x = np.arange(len(times))
        width = 0.1
        uniq_types = sorted(times, key=treatment_time_order.__getitem__)

        metric_dict = {}
        dummy_metric_dict = {}
        for juri in uniq_juris:
            # NaN placeholders keep every jurisdiction aligned on the shared time axis even when it
            # has no column for some (horizon, treatment) pair.
            metric_dict[juri] = {
                metric: OrderedDict((time_horizon, [np.nan, np.nan]) for time_horizon in uniq_types)
                for metric in metric_names
            }
            dummy_metric_dict[juri] = {
                metric: OrderedDict((time_horizon, np.nan) for time_horizon in uniq_types)
                for metric in metric_names
            }

            for i, (time_type, inner_treatment) in enumerate(result_dict[juri]['prediction_columns']):
                if inner_treatment != treatment or time_type == horizon_to_skip:
                    continue
                model_scores = _binary_scores(conf_mats_by_juri[juri][:, i, :, :])
                for metric in metric_names:
                    metric_dict[juri][metric][time_type] = model_scores[metric]
                # Chance level: running max over the mean score of each dummy strategy.
                for dummy_conf_mats in dummy_conf_mats_by_juri[juri]:
                    dummy_scores = _binary_scores(dummy_conf_mats[:, i, :, :])
                    for metric in metric_names:
                        dummy_metric_dict[juri][metric][time_type] = np.nanmax(
                            [np.nanmean(dummy_scores[metric]), dummy_metric_dict[juri][metric][time_type]]
                        )

        handles = []
        for i, juri in enumerate(uniq_juris):
            color = juri_colors[juri]
            positions = x + i * width
            violin_parts = [
                axs[j].violinplot(list(metric_dict[juri][metric_name].values()),
                                  positions=positions, widths=width, showmeans=True)
                for j, metric_name in enumerate(metric_names)
            ]
            # Colour every violin part with the jurisdiction's colour.
            for parts in violin_parts:
                for key, collection in parts.items():
                    if key == 'bodies':
                        for body in collection:
                            body.set_facecolor(color)
                            body.set_alpha(0.3)
                            body.set_edgecolor(color)
                    else:
                        collection.set_edgecolor(color)
            handles.append(mpatches.Patch(color=violin_parts[0]["bodies"][0].get_facecolor().flatten()))
            # Chance-level markers: one single-value "box" per time horizon.
            for j, metric_name in enumerate(metric_names):
                bp = axs[j].boxplot([[val] for val in dummy_metric_dict[juri][metric_name].values()],
                                    positions=positions, widths=width)
                for element in ['boxes', 'whiskers', 'medians', 'caps']:
                    plt.setp(bp[element], color='red')

        for ax, metric_name in zip(axs, metric_names):
            ax.set_title(metric_name, loc='left')
            ax.set_xticks(x + width * (len(uniq_juris) - 1) / 2)
            ax.set_xticklabels([year_map_dict[time_horizon] for time_horizon in uniq_types])
            juri_legend = ax.legend(handles, uniq_juris, bbox_to_anchor=(1, 1), loc="upper left", title='juri')
            # Shade the upper half of the panel without letting the shading widen the x-range.
            curr_xlim = ax.get_xlim()
            ax.fill_between(x=[-0.5, x.max() + 1], y1=0.5, y2=1, color=treatment_type_colors[treatment],
                            alpha=0.3, zorder=-1, label='Better than 0.5')
            ax.set_xlim(curr_xlim)
            ax.plot([], [], 'r', label='Best chance level')  # legend proxy for the red chance markers
            ax.legend(bbox_to_anchor=(1, 0), loc="lower left", title='Context')
            ax.add_artist(juri_legend)  # keep the jurisdiction legend alongside the context legend
            ax.set_ylim((0, 1))
            ax.grid(True)

        # Only the class-balanced sampling variant corrects for class imbalance (same wording as
        # plot_specific_metric_all_datasets).
        if SUFFIX == 'balanced_sampled':
            heading = 'Performance on validation set of short sections - Sampling performed to correct for class imbalance'
        else:
            heading = 'Performance on validation set of short sections'
        fig.suptitle(f'{heading}\n{treatment} - {split_label}', horizontalalignment='center', verticalalignment='top')
        fig.tight_layout()
        fig.savefig(save_fig_dir / f'joined_results_{treatment.lower()}.jpg')
        plt.close(fig)

In [None]:
# Per-treatment figures go into their own sub-folder, keyed by the split variant.
save_fig_dir = report_dir.parent / 'figures' / 'shared' / f'joined_results_by_treatment_{SUFFIX}'
if not save_fig_dir.exists():
    # parents=True (as in the by-metric cell) so this also works when figures/shared does not exist yet.
    save_fig_dir.mkdir(parents=True)
plot_metric_over_time_all_datasets(result_dict)

## Results grouped by metric

In [None]:
from collections import OrderedDict
import matplotlib.patches as mpatches

def plot_specific_metric_all_datasets(
        result_dict: dict,
        rare_class: bool=False,
        metric: str='F-Score',
        anonymised: bool=False
    ):
    """Plot one metric for a group of treatments across jurisdictions and save it as a PNG.

    One panel per treatment (ordered by ``treatment_type_order``). Within a panel, for every time
    horizon (the 10-30 year horizon is skipped), a violin per jurisdiction shows the spread of the
    XGB metric along the leading axis of ``result_dict[juri]['conf_mat']`` (presumably repeated
    runs -- TODO confirm). A red box marker shows the best chance level: the max, over dummy
    strategies, of that strategy's mean metric.

    Args:
        result_dict: ``{juri: {'conf_mat': ..., 'conf_mat_dummy': {strategy: ...},
            'prediction_columns': pd.MultiIndex}}``. Confusion matrices convert to arrays indexed
            ``[run, label_column, true, predicted]``, with class 1 as the positive class.
        rare_class: if False, plot only Resurfacing_SS, Resurfacing_AC and Rehabilitation;
            if True, plot every other treatment.
        metric: one of ``'Precision'``, ``'Recall'`` or ``'F-Score'``.
        anonymised: label jurisdictions ``'Jurisdiction 1..N'`` in the legend and add ``_anon``
            to the file name.

    Raises:
        NotImplementedError: if ``metric`` is not one of the supported names.

    Reads the notebook globals ``juri_order``, ``juri_colors``, ``treatment_type_order``,
    ``treatment_time_order``, ``treatment_type_colors``, ``year_map_dict``, ``SUFFIX`` and
    ``save_fig_dir``. The figure is saved to ``save_fig_dir`` and then shown.
    """
    # Validate up front so an unsupported metric fails before any figure is created.
    if metric not in ('Precision', 'Recall', 'F-Score'):
        raise NotImplementedError(f"{metric} not implemented!")

    common_treatments = ['Resurfacing_SS', 'Resurfacing_AC', 'Rehabilitation']
    horizon_to_skip = 'Treatment between 10 to 30 years'
    uniq_juris = sorted(result_dict.keys(), key=lambda x: juri_order[x])

    def _binary_scores(conf_mats):
        """Per-run precision/recall/F-score of the positive class for stacked 2x2 matrices."""
        precision = conf_mats[:, 1, 1] / conf_mats[:, :, 1].sum(axis=1)
        recall = conf_mats[:, 1, 1] / conf_mats[:, 1, :].sum(axis=1)
        return {'Precision': precision, 'Recall': recall, 'F-Score': 2 / (1 / precision + 1 / recall)}

    # Union over jurisdictions of the time horizons predicted for each selected treatment.
    treatments_time_dict: Dict[str, set] = {}
    for juri in uniq_juris:
        prediction_columns: pd.MultiIndex = result_dict[juri]['prediction_columns']
        for treatment in set(prediction_columns.get_level_values(1)):
            keep = (treatment not in common_treatments) if rare_class else (treatment in common_treatments)
            if not keep:
                continue
            times = set(prediction_columns[prediction_columns.get_level_values(1) == treatment].get_level_values(0))
            times.discard(horizon_to_skip)
            treatments_time_dict.setdefault(treatment, set()).update(times)

    ordered_treatments = sorted(treatments_time_dict.items(), key=lambda tup: treatment_type_order[tup[0]])
    # At least the original 3 panels (12x12 inches); grow instead of raising IndexError when there are more treatments.
    nrows = max(3, len(ordered_treatments))
    fig, axs = plt.subplots(nrows=nrows, ncols=1, figsize=(12, 4 * nrows))
    axs = axs.ravel()

    # Convert the confusion matrices to arrays once, instead of once per label column.
    conf_mats_by_juri = {juri: np.array(result_dict[juri]['conf_mat']) for juri in uniq_juris}
    dummy_conf_mats_by_juri = {
        juri: [np.array(conf_mat) for conf_mat in result_dict[juri]['conf_mat_dummy'].values()]
        for juri in uniq_juris
    }
    if not anonymised:
        legend_labels = uniq_juris
    else:
        legend_labels = ['Jurisdiction {}'.format(num + 1) for num in range(len(uniq_juris))]

    for treatment_idx, (treatment, times) in enumerate(ordered_treatments):
        if not times:
            continue  # only the skipped 10-30 year horizon exists for this treatment: nothing to draw
        ax = axs[treatment_idx]
        x = np.arange(len(times))
        width = 0.1
        uniq_types = sorted(times, key=treatment_time_order.__getitem__)

        metric_dict = {}
        dummy_metric_dict = {}
        for juri in uniq_juris:
            # NaN placeholders keep every jurisdiction aligned on the shared time axis.
            metric_dict[juri] = OrderedDict((time_horizon, [np.nan, np.nan]) for time_horizon in uniq_types)
            dummy_metric_dict[juri] = OrderedDict((time_horizon, np.nan) for time_horizon in uniq_types)

            for i, (time_type, inner_treatment) in enumerate(result_dict[juri]['prediction_columns']):
                if inner_treatment != treatment or time_type == horizon_to_skip:
                    continue
                metric_dict[juri][time_type] = _binary_scores(conf_mats_by_juri[juri][:, i, :, :])[metric]
                # Chance level: running max over the mean score of each dummy strategy.
                for dummy_conf_mats in dummy_conf_mats_by_juri[juri]:
                    dummy_score = np.nanmean(_binary_scores(dummy_conf_mats[:, i, :, :])[metric])
                    dummy_metric_dict[juri][time_type] = np.nanmax([dummy_score, dummy_metric_dict[juri][time_type]])

        handles = []
        for i, juri in enumerate(uniq_juris):
            color = juri_colors[juri]
            positions = x + i * width
            parts = ax.violinplot(list(metric_dict[juri].values()), positions=positions, widths=width, showmeans=True)
            # Colour every violin part with the jurisdiction's colour.
            for key, collection in parts.items():
                if key == 'bodies':
                    for body in collection:
                        body.set_facecolor(color)
                        body.set_alpha(0.3)
                        body.set_edgecolor(color)
                else:
                    collection.set_edgecolor(color)
            handles.append(mpatches.Patch(color=parts["bodies"][0].get_facecolor().flatten()))
            # Chance-level markers: one single-value "box" per time horizon.
            bp = ax.boxplot([[val] for val in dummy_metric_dict[juri].values()], positions=positions, widths=width)
            for element in ['boxes', 'whiskers', 'medians', 'caps']:
                plt.setp(bp[element], color='red')

        ax.set_ylabel(metric)
        ax.set_title(treatment, loc='left')
        ax.set_xticks(x + width * (len(uniq_juris) - 1) / 2)
        ax.set_xticklabels([year_map_dict[time_horizon] for time_horizon in uniq_types])
        juri_legend = ax.legend(handles, legend_labels, bbox_to_anchor=(1, 1), loc="upper left", title='Jurisdiction')
        # Shade the upper half of the panel without letting the shading widen the x-range.
        curr_xlim = ax.get_xlim()
        ax.fill_between(x=[-0.5, x.max() + 1], y1=0.5, y2=1, color=treatment_type_colors[treatment],
                        alpha=0.3, zorder=-1, label='Better than 0.5')
        ax.set_xlim(curr_xlim)
        ax.set_ylim((0, 1))
        ax.plot([], [], 'r', label='Best chance level')  # legend proxy for the red chance markers
        ax.legend(bbox_to_anchor=(1, 0), loc="lower left", title='Context')
        ax.add_artist(juri_legend)  # keep the jurisdiction legend alongside the context legend
        ax.grid(True)

    if SUFFIX == 'balanced_sampled':
        fig.suptitle(f'{metric} performance on validation set of short sections\nSampling performed to correct for class imbalance', horizontalalignment='center', verticalalignment='top')
    else:
        fig.suptitle(f'{metric} performance on validation set of short sections', horizontalalignment='center', verticalalignment='top')
    fig.tight_layout()
    save_path = f'joined_results{"_rare" if rare_class else ""}_{metric}{"_anon" if anonymised else ""}.png'
    fig.savefig(save_fig_dir / save_path, dpi=200)
    plt.show()
    plt.close(fig)

In [None]:
# Per-metric figures go into their own sub-folder, keyed by the split variant.
save_fig_dir = report_dir.parent / 'figures' / 'shared' / f'joined_results_by_metric_{SUFFIX}'
if not save_fig_dir.exists():
    save_fig_dir.mkdir(parents=True)

# One anonymised figure per metric for the common (non-rare) treatments.
for metric in ('Recall', 'Precision', 'F-Score'):
    plot_specific_metric_all_datasets(result_dict, rare_class=False, metric=metric, anonymised=True)

# Redo per-jurisdiction performance plots

In [None]:
from src.visualization.visualize import plot_metric_by_treatment_type

# The redone figures are written under the current `save_fig_dir`.
inner_dir = save_fig_dir / 'redo_valid'
if not inner_dir.exists():
    inner_dir.mkdir()

def plot_metric_each_dataset_redo(
        result_dict: dict,
    ):
    """Re-draw each jurisdiction's per-treatment validation figure from its saved confusion matrices.

    For every jurisdiction in ``result_dict``, an empty DataFrame carrying its
    ``prediction_columns`` and its ``conf_mat`` are passed to ``plot_metric_by_treatment_type``.
    The figure is saved as ``inner_dir / f'{juri}_valid_redo_{suffix}.jpg'``.
    """
    for juri, juri_results in result_dict.items():
        split_label = " ".join(map(str.capitalize, suffix.split("_")))
        column_template = pd.DataFrame(columns=juri_results['prediction_columns'])
        plot_metric_by_treatment_type(
            column_template,
            juri_results['conf_mat'],
            suptitle=f'Performance on validation set of short sections\n{juri} - {split_label}',
            save_path=inner_dir / f'{juri}_valid_redo_{suffix}.jpg',
        )

plot_metric_each_dataset_redo(result_dict)