## This notebook generates the tables and figures for the arXiv preprint

In [None]:
import os
import sys
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
import matplotlib as mpl
# Figure style shared by every plot in this notebook.
# Hatch colour/width are set exactly once (white, 0.5 pt) so the effective
# value is unambiguous.
font = {'family' : 'Arial',
        'size'   : 6.5}
plt.rc('font', **font)
plt.rc('hatch', color='w', linewidth=0.5)
plt.rcParams['mathtext.fontset'] = 'stix'
# Thin axis frame and major tick marks to match the small font size.
plt.rcParams['axes.linewidth'] = 0.5
plt.rcParams['xtick.major.width'] = 0.5
plt.rcParams['ytick.major.width'] = 0.5
from visual_utils.utils_tabfig import *
sys.path.append('../')
from src_safety_evaluation.validation_utils.utils_eval_metrics import *
from src_safety_evaluation.validation_utils.utils_evaluation import optimize_threshold

# Relative data/result folders and the preprint's figure folder. Every path
# ends with '/' because file names are appended by plain string concatenation.
path_processed, path_prepared, path_result, path_raw = (
    '../ProcessedData/', '../PreparedData/', '../ResultData/', '../RawData/'
)
path_fig = '../../arXiv/Figures/'

## Result 1 Accurate detection of safety-critical events

In [None]:
# Result 1 methods: result keys in `models`, display names in `model_labels`
# (the 'highD_current' model is shown as GSSM), and one colour per method.
models = ['highD_current', 'ACT', 'TTC2D', 'TAdv', 'EI']
model_labels = ['GSSM', *models[1:]]
colors = cmap([0.15, 0.3, 0.45, 0.6, 0.75])

In [None]:
# Load every risk-evaluation result file. The file list is sorted because
# os.listdir returns names in arbitrary order, which would otherwise make the
# row order of the concatenated frame depend on the file system.
warning_files = sorted(f for f in os.listdir(path_result + 'Conflicts/Results/')
                       if f.startswith('RiskEval_') and f.endswith('.h5'))
conflict_warning = pd.concat([pd.read_hdf(path_result + 'Conflicts/Results/' + f, key='results')
                              for f in tqdm(warning_files, desc='Reading files')])
voted_events = pd.read_csv(path_result + 'Conflicts/Voted_conflicting_targets.csv').set_index('event_id')
# Drop rows with a negative target_id (presumably events without a valid
# voted target - confirm). .copy() makes the filtered frame independent, so
# adding the 'event' column below does not raise SettingWithCopyWarning.
voted_events = voted_events[voted_events['target_id'] >= 0].copy()
voted_events['event'] = [category[c] for c in voted_events['event_category'].values]
len(voted_events), voted_events['event'].value_counts()

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(4.5, 1.8), constrained_layout=True)

# Receiver operating characteristic curves. The shaded band covers y from 80
# to 100 and is the only artist labelled here ('Safety-critical').
ax_roc = axes[0]
ax_roc.set_title('Receiver operating characteristic\ncurve (ROC)', pad=5)
ax_roc.set_aspect('equal')
ax_roc.fill_between([0,80], 80, 100, fc=light_color('tab:red',0.05), ec=light_color('tab:red',0.35), lw=0.35, zorder=-10, label='Safety-critical')
event_curve(ax_roc, models, colors, conflict_warning, curve_type='roc')
ax_roc.set_xlim(0, 80)
ax_roc.set_ylim(20, 100)
ax_roc.tick_params(axis='both', which='both', pad=2, direction='in')
ax_roc.set_xticks([0, 20, 40, 60, 80])
ax_roc.set_yticks([20, 40, 60, 80, 100])

# Precision-recall curves; the shaded band covers x from 80 to 100.
ax_prc = axes[1]
ax_prc.set_title('Precision-recall curve\n(PRC)', pad=5)
ax_prc.set_aspect('equal')
ax_prc.fill_betweenx([20,100], 80, 100, fc=light_color('tab:red',0.05), ec=light_color('tab:red',0.35), lw=0.35, zorder=-10)
event_curve(ax_prc, models, colors, conflict_warning, curve_type='prc')
ax_prc.tick_params(axis='both', which='both', pad=2, direction='in')
ax_prc.set_xlim(20, 100)
ax_prc.set_ylim(20, 100)
ax_prc.set_xticks([40, 60, 80, 100])
ax_prc.set_yticks([20, 40, 60, 80, 100])

# Accuracy-timeliness curves (ATC); the shaded band covers y from 0.8 to 0.925.
ax_time = axes[2]
ax_time.set_title('Accuracy-timeliness curve\n(ATC)', pad=5)
ax_time.set_xlim(0.5, 4.5)
ax_time.set_ylim(0.425, 0.925)
# x span (4.0) / y span (0.5): gives a square plotting box like the two
# equal-aspect panels.
ax_time.set_aspect(4./0.5)
ax_time.fill_between([0.5,4.5], 0.8, 0.925, fc=light_color('tab:red',0.05), ec=light_color('tab:red',0.35), lw=0.35, zorder=-10)
event_curve(ax_time, models, colors, conflict_warning, curve_type='atc')
ax_time.tick_params(axis='both', which='both', pad=2, direction='in')

# Only the handles are taken from the ROC panel; the labels are given explicitly.
# They assume handles come back in drawing order: shaded band first, then one
# curve per entry of `models` from event_curve - confirm if event_curve changes.
handles = axes[0].get_legend_handles_labels()[0]
fig.legend(handles, ['Safety-critical']+model_labels, loc='lower center', bbox_to_anchor=(0.5, -0.05),
           ncol=len(models)+1, frameon=False, handlelength=2.5, handletextpad=0.4, columnspacing=1)

In [None]:
# Export Result 1 as a vector PDF into the preprint's figure folder.
fig.savefig(f'{path_fig}Result1.pdf', bbox_inches='tight', pad_inches=0.05, dpi=600)

In [None]:
# Result 1 table: one row of detection metrics per method, over all events.
metric_table = pd.DataFrame(columns=['model', 'auprc', 'aroc_80', 'aroc_90',
                                     'pprc_80', 'pprc_90', 'PTTI_star', 'mTTI_star'])
for model, model_label in zip(models, model_labels):
    model_rows = conflict_warning[conflict_warning['model'] == model]
    metrics = get_eval_metrics(model_rows,
                               thresholds={'roc': [0.80, 0.90], 'prc': [0.80, 0.90], 'tti': 1.5})
    metrics['model'] = model_label
    # Enlarge the table by one row, writing only the columns that were returned.
    metric_table.loc[len(metric_table), list(metrics)] = list(metrics.values())

value_columns = metric_table.columns[1:]
metric_table[value_columns] = metric_table[value_columns].astype(float)
metric_table = highlight(metric_table, max_cols=value_columns, involve_second=True)
latex_headers = {
    'model': 'Method',
    'auprc': 'AUPRC',
    'aroc_80': '$A_{80\\%}^\\mathrm{ROC}$',
    'aroc_90': '$A_{90\\%}^\\mathrm{ROC}$',
    'pprc_80': '$\\mathrm{Precision}_{80\\%}^\\mathrm{PRC}$',
    'pprc_90': '$\\mathrm{Precision}_{90\\%}^\\mathrm{PRC}$',
    'PTTI_star': '$P^*_{\\mathrm{TTI}\\geq1.5}$',
    'mTTI_star': '$m\\mathrm{TTI}^*$',
}
# Rename to LaTeX headers and list methods in the same order as in the figure.
metric_table = metric_table.rename(columns=latex_headers).set_index('Method').loc[model_labels]
metric_table

In [None]:
# LaTeX source of the Result 1 table for the manuscript.
latex_table = metric_table.to_latex()
print(latex_table)

## Result 2 Scalability

In [None]:
# Result 2 inputs: 'Warning_*.h5' files whose names contain 'mixed' or
# 'SafeBaseline'. The list is sorted because os.listdir order is arbitrary,
# which would make the concatenated row order system-dependent.
warning_files = sorted(
    f for f in os.listdir(path_result + 'Analyses/')
    if f.startswith('Warning_') and f.endswith('.h5') and ('mixed' in f or 'SafeBaseline' in f)
)
conflict_warning = pd.concat([pd.read_hdf(path_result + 'Analyses/' + f, key='results')
                              for f in tqdm(warning_files, desc='Reading files')])
event_meta = pd.read_csv(path_result + 'Analyses/EventMeta.csv')

In [None]:
# Result 2 figure; the drawing logic lives in draw_data_scalability
# (presumably from visual_utils.utils_tabfig).
fig, axes = draw_data_scalability(
    conflict_warning, event_meta,
)

In [None]:
# Export Result 2 as a vector PDF into the preprint's figure folder.
fig.savefig(f'{path_fig}Result2.pdf', bbox_inches='tight', pad_inches=0.05, dpi=600)

## Result 3 Context-awareness

In [None]:
# Result 3 inputs: all risk-evaluation result files, read in sorted order so the
# concatenated row order does not depend on os.listdir's arbitrary ordering.
warning_files = sorted(f for f in os.listdir(path_result + 'Conflicts/Results/')
                       if f.startswith('RiskEval_') and f.endswith('.h5'))
conflict_warning = pd.concat([pd.read_hdf(path_result + 'Conflicts/Results/' + f, key='results')
                              for f in tqdm(warning_files, desc='Reading files')])

In [None]:
# Result 3 figure; the drawing logic lives in draw_feature_scalability
# (presumably from visual_utils.utils_tabfig).
fig, axes = draw_feature_scalability(
    conflict_warning,
)

In [None]:
# Export Result 3 as a vector PDF into the preprint's figure folder.
fig.savefig(f'{path_fig}Result3.pdf', bbox_inches='tight', pad_inches=0.05, dpi=600)

## Result 4 Generalisability

In [None]:
# Inspect which model result sets are currently loaded.
conflict_warning.loc[:, 'model'].unique()

In [None]:
# Result 4 inputs: all risk-evaluation result files (sorted, because os.listdir
# order is arbitrary) plus per-event metadata.
warning_files = sorted(f for f in os.listdir(path_result + 'Conflicts/Results/')
                       if f.startswith('RiskEval_') and f.endswith('.h5'))
conflict_warning = pd.concat([pd.read_hdf(path_result + 'Conflicts/Results/' + f, key='results')
                              for f in tqdm(warning_files, desc='Reading files')])
event_meta = pd.read_csv(path_result + 'Analyses/EventMeta.csv')

# Methods compared in Result 4, as keys of the 'model' column.
models = ['SafeBaseline_current_environment_profiles',
          'ACT',
          'TTC2D']

In [None]:
# Result 4 figure; the drawing logic lives in draw_generalisability
# (presumably from visual_utils.utils_tabfig).
fig = draw_generalisability(
    conflict_warning, models, event_meta, voted_events,
)

In [None]:
# Export Result 4 as a vector PDF into the preprint's figure folder.
fig.savefig(f'{path_fig}Result4.pdf', bbox_inches='tight', pad_inches=0.05, dpi=600)

In [None]:
# Result 4 table: detection metrics per conflict type (events whose 'conflict'
# is 'none' are excluded).
event_meta = event_meta[event_meta['conflict']!='none']
model_labels = ['GSSM', 'ACT', 'TTC2D', 'TAdv', 'EI']

# Each entry: [raw 'conflict' categories in event_meta, row label in the table].
event_metric_list = []
event_metric_list.append([['leading'], 'Rear-end'])
event_metric_list.append([['adjacent_lane'], 'Adjacent lane'])
# NOTE(review): 'turning_into_parallel' used to appear twice in this list; the
# duplicate had no effect on isin and was removed. Confirm that no other
# category was intended in its place.
event_metric_list.append([['turning_into_parallel','turning_across_opposite','turning_into_opposite','intersection_crossing'],
                         'Crossing/turning'])
event_metric_list.append([['merging'], 'Merging'])
event_metric_list.append([['pedestrian','cyclist','animal'], 'With pedestrian/cyclist/animal'])

metric_table = []
for events, event_label in tqdm(event_metric_list):
    event_metrics = pd.DataFrame(columns=['Event', 'model','auprc','aroc_80','aroc_90','pprc_80','pprc_90','PTTI_star','mTTI_star'])
    # The event subset does not depend on the model, so select it once per type.
    event_ids = event_meta[event_meta['conflict'].isin(events)]['event_id'].values
    event_warning = conflict_warning[conflict_warning['event_id'].isin(event_ids)]
    # zip stops at the shorter list: only the first len(models) labels are used.
    for model, model_label in zip(models, model_labels):
        filtered_warning = event_warning[event_warning['model']==model]

        metrics = get_eval_metrics(filtered_warning, thresholds={'roc': [0.80, 0.90], 'prc':[0.80, 0.90],'tti':1.5})
        metrics['Event'] = event_label
        metrics['model'] = model_label
        event_metrics.loc[len(event_metrics), list(metrics.keys())] = list(metrics.values())
        # Report FP/FN at the PRC-based optimal threshold for the proposed model.
        # This check previously compared against a stale label ('S-CaET') that
        # never matched, so the report was silently skipped.
        if model_label=='GSSM':
            _, _, optimal_threshold = optimize_threshold(filtered_warning, 'GSSM', curve_type='PRC', return_stats=True)
            print(f"{event_label} FP:{optimal_threshold['FP']}, FN:{optimal_threshold['FN']}")
    event_metrics[event_metrics.columns[2:]] = event_metrics[event_metrics.columns[2:]].astype(float)
    event_metrics = highlight(event_metrics, max_cols=event_metrics.columns[2:], involve_second=True)
    # NOTE(review): counted from the rows of the last model in the loop; if the
    # models cover different events this may differ from the category size.
    event_metrics['Number of events'] = f"{filtered_warning['event_id'].nunique()}"
    metric_table.append(event_metrics)
metric_table = pd.concat(metric_table, axis=0)
metric_table = metric_table.rename(columns={'model':'Method', 'aroc_80':'$A_{80\\%}^\\mathrm{ROC}$', 'aroc_90':'$A_{90\\%}^\\mathrm{ROC}$',
                                            'pprc_80':'$\\mathrm{Precision}_{80\\%}^\\mathrm{PRC}$', 'pprc_90':'$\\mathrm{Precision}_{90\\%}^\\mathrm{PRC}$',
                                            'auprc':'$\\mathrm{AUPRC}$', 'PTTI_star':'$P^*_{\\mathrm{TTI}\\geq1.5}$', 'mTTI_star':'$m\\mathrm{TTI}^*$'})

In [None]:
# Event type, event count and method form a three-level row index for LaTeX.
metric_table = metric_table.set_index(keys=['Event', 'Number of events', 'Method'])
metric_table

In [None]:
# LaTeX source of the per-conflict-type table for the manuscript.
latex_table = metric_table.to_latex()
print(latex_table)

In [None]:
# Diagnostic: FP and FN counts of the first Result 4 model at its optimal
# threshold, computed separately for 'leading' and 'merging' conflicts.
events_list = [['leading'], ['merging']]
model = models[0]

for events in events_list:
    print(events)
    event_ids = event_meta.loc[event_meta['conflict'].isin(events), 'event_id'].values
    in_subset = conflict_warning['event_id'].isin(event_ids) & (conflict_warning['model'] == model)
    filtered_warning = conflict_warning[in_subset]
    _, _, optimal_threshold = optimize_threshold(filtered_warning, 'GSSM', return_stats=True)
    print(f"FP:{optimal_threshold['FP']}/({optimal_threshold['TP']+optimal_threshold['FP']}) FN:{optimal_threshold['FN']}")

## Result 5 Attribution

In [None]:
# Feature attributions of the SafeBaseline_current_environment_profiles model.
attribution = pd.read_hdf(
    path_result + 'FeatureAttribution/SafeBaseline_current_environment_profiles.h5'
).reset_index()
# Columns 4..24 (by position) are the attribution columns; `columns` keeps
# their names without the 3-character prefix.
eg_columns = list(attribution.columns[4:25])
columns = [name[3:] for name in eg_columns]
# Row-wise total, positive-only and negative-only attribution sums.
attribution['eg_sum'] = attribution[eg_columns].sum(axis=1)
positive_mask = attribution[eg_columns] > 0
attribution['positive_sum'] = attribution[eg_columns].mul(positive_mask.astype(int)).sum(axis=1)
negative_mask = attribution[eg_columns] < 0
attribution['negative_sum'] = attribution[eg_columns].mul(negative_mask.astype(int)).sum(axis=1)

# Risk-evaluation results of the same model, reduced to the rows at the
# threshold chosen by optimize_threshold and indexed by event.
conflict_warning = pd.read_hdf(path_result + 'Conflicts/Results/RiskEval_SafeBaseline_current_environment_profiles.h5', key='results')
_, _, optimal_threshold = optimize_threshold(conflict_warning, 'GSSM', return_stats=True)
conflict_warning = conflict_warning.loc[conflict_warning['threshold'] == optimal_threshold['threshold']]
conflict_warning = conflict_warning.set_index('event_id')
# Event metadata and environment descriptors, aligned to the evaluated events.
_, _, meta_events, environment = read_meta(path_processed, path_result)
event_ids = conflict_warning.index.values
meta_events, environment = meta_events.loc[event_ids], environment.loc[event_ids]

def _count_top3(values, largest):
    """Count, per column, the rows in which that column is a top-3 attribution.

    With ``largest=True`` only positive entries compete and the three largest
    win; otherwise only negative entries compete and the three smallest win.
    Ties are broken by column order, as ``Series.nlargest``/``nsmallest`` do
    with ``keep='first'``. Returns a float Series indexed by the columns.
    """
    values = values.astype(float)
    competing = values.where(values > 0) if largest else values.where(values < 0)
    # Non-competing entries are NaN and get a NaN rank, so they are never counted.
    ranks = competing.rank(axis=1, method='first', ascending=not largest)
    return (ranks <= 3).sum(axis=0).astype(float)


def get_rank(conflict_warning, attribution, optimal_threshold, type='both',
             columns=None, margin=3):
    """Count how often each attribution feature is among the top-3 contributors.

    Args:
        conflict_warning: per-event results indexed by event id, with columns
            'danger_start', 'danger_end', 'true_warning' and
            'num_true_non_warning'.
        attribution: per-time-step rows with 'event_id', 'time', 'intensity'
            and the attribution columns.
        optimal_threshold: mapping whose 'threshold' entry splits warnings
            (intensity > threshold) from non-warnings (intensity <= threshold).
        type: 'warning', 'non_warning' or 'both' - which statistics to compute.
            (Name kept for backward compatibility although it shadows the builtin.)
        columns: attribution columns to rank; defaults to the notebook-level
            ``eg_columns``.
        margin: non-warning rows must lie more than ``margin`` time units
            before the danger window starts (previously hard-coded to 3).

    Returns:
        (warning_statistics, non_warning_statistics): float Series of counts
        indexed by the attribution columns, or None for a type not requested.
        The counts are:
        - Warning: for events with true_warning > 0.5, rows inside the danger
          window whose intensity exceeds the threshold; counts how often a
          column is among the three largest positive attributions.
        - Non-warning: for events with num_true_non_warning > 0.5, rows well
          before the window whose intensity is at most the threshold; counts
          how often a column is among the three most negative attributions.
        An empty selection yields all-zero counts instead of raising.
    """
    if columns is None:
        columns = eg_columns  # attribution columns selected earlier in the notebook
    columns = list(columns)
    threshold = optimal_threshold['threshold']

    event_col = attribution['event_id']
    time = attribution['time']
    # Danger window of each row's event. The stored bounds are divided by 1000
    # to match attribution['time'] (presumably ms -> s; confirm). Events absent
    # from conflict_warning map to NaN, so every comparison below is False.
    start = event_col.map(conflict_warning['danger_start'] / 1000)
    end = event_col.map(conflict_warning['danger_end'] / 1000)
    warned_events = conflict_warning.index[conflict_warning['true_warning'] > 0.5]
    safe_events = conflict_warning.index[conflict_warning['num_true_non_warning'] > 0.5]

    non_warning_statistics = None
    if type == 'non_warning' or type == 'both':
        before_danger = (event_col.isin(safe_events) & (time < start - margin)
                         & (attribution['intensity'] <= threshold))
        non_warning_statistics = _count_top3(attribution.loc[before_danger, columns], largest=False)

    warning_statistics = None
    if type == 'warning' or type == 'both':
        in_danger = (event_col.isin(warned_events) & (time >= start) & (time <= end)
                     & (attribution['intensity'] > threshold))
        warning_statistics = _count_top3(attribution.loc[in_danger, columns], largest=True)

    return warning_statistics, non_warning_statistics

In [None]:
# Result 5 figure: two side-by-side panels, (a) conflict types and (b) environments.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7.05, 3.0),
                         constrained_layout=True, gridspec_kw=dict(wspace=0.1))
def settle_ax(ax):
    """Apply the minimal styling used for the Result 5 bar panels.

    Steps:
    - Hide the top and right spines and point all ticks inward.
    - Reduce the y axis to two unlabelled ticks at its limits.
    - If the last automatic x tick lies beyond the upper x limit, drop it.
      When the remaining last tick is within 2% of the limit, keep only the
      remaining ticks. Otherwise add an unlabelled tick exactly at the limit.
    """
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(axis='both', which='both', direction='in')
    y_low, y_high = ax.get_ylim()
    ax.set_yticks([y_low, y_high])
    ax.set_yticklabels([])

    x_limit = ax.get_xlim()[1]
    ticks = ax.get_xticks()
    if ticks[-1] <= x_limit:
        return
    kept_ticks = list(ticks[:-1])
    kept_labels = ax.get_xticklabels()[:-1]
    if x_limit - kept_ticks[-1] < x_limit * 0.02:
        ax.set_xticks(kept_ticks)
        ax.set_xticklabels(kept_labels)
    else:
        ax.set_xticks(kept_ticks + [x_limit])
        ax.set_xticklabels(kept_labels + [''])

def get_sorted(statistics, proportion=0.9, number=None):
    """Normalise per-feature counts to shares and return the leading features.

    Args:
        statistics: Series of non-negative per-feature counts.
        proportion: used when ``number`` is None. Keeps the smallest set of
            top features whose cumulative share reaches ``proportion``; the
            feature that crosses the threshold is included. (Previously it was
            dropped, so a single dominant feature produced an empty result.)
        number: if given, keep exactly the ``number`` largest shares.

    Returns:
        Series of shares sorted in descending order. It is empty when all
        counts are zero, since shares are undefined then.
    """
    total = statistics.sum()
    if total == 0:
        return statistics.iloc[:0].astype(float)
    shares = (statistics / total).sort_values(ascending=False)
    if number is not None:
        return shares.head(number)
    # Keep a feature while the share accumulated *before* it is still short of
    # the target, so the feature that reaches the target is kept too.
    accumulated_before = shares.cumsum().shift(fill_value=0)
    return shares[accumulated_before < proportion]

def plot_bars(statistics, ax):
    """Draw the six largest feature shares of ``statistics`` as horizontal bars.

    ``statistics`` holds per-feature counts indexed by attribution column names
    of the form '<prefix>_<feature>'. They are normalised by ``get_sorted``.

    Layout:
    - The largest share is drawn at the top.
    - Each bar is labelled with its feature name, with the prefix before the
      first '_' removed and a few abbreviations spelled out.
    - Labels are right-aligned at 99% of the upper x limit.
    """
    # Longer display names for abbreviated attribution columns.
    readable_names = {
        'eg_Sur lat speed': "eg_Surrounding object's lateral speed",
        'eg_Sur lon speed': "eg_Surrounding object's longitudinal speed",
        'eg_2D spacing direction': 'eg_Spacing direction',
    }
    stat2plot = get_sorted(statistics, proportion=0.9, number=6)
    positions = np.arange(len(stat2plot))
    # Reversed so the largest share ends up at the top of the barh chart.
    ax.barh(positions, stat2plot.values[::-1],
            color=cmap(np.linspace(0.65, 0.15, len(stat2plot))), alpha=0.75, lw=0.35)
    xmax = ax.get_xlim()[1]
    for pos, label in zip(positions, stat2plot.index[::-1]):
        label = readable_names.get(label, label)
        # Strip only the prefix; later underscores in a feature name are kept.
        ax.text(xmax*0.99, pos, label.split('_', 1)[-1], ha='right', va='center', color='k')
    settle_ax(ax)


# Result 5: each inset shows the features most often among the top-3
# attributions (see get_rank / plot_bars) for a subset of events.
# Panel (a): conflict types, split into danger (warning) and safe (non-warning).
ax_conflit = axes[0]
ax_conflit.set_title('(a) Top factors in leading to and avoiding lateral conflicts', pad=15)
remove_box(ax_conflit)
# Inset positions are [x0, y0, width, height] in axes-fraction coordinates.
ax_adj_danger = ax_conflit.inset_axes([0, 0.55, 0.45, 0.45])
ax_adj_danger.set_title('Danger in adjacent lane', pad=3)
# NOTE(review): this filters meta_events on display labels ('Adjacent lane'),
# unlike EventMeta.csv's raw categories used earlier - confirm read_meta's values.
filtered_warning = conflict_warning.loc[meta_events[(meta_events['conflict']=='Adjacent lane')].index.values]
warning_statistics, non_warning_statistics = get_rank(filtered_warning, attribution, optimal_threshold, type='both')
plot_bars(warning_statistics, ax_adj_danger)

# Non-warning statistics of the same adjacent-lane events.
ax_adj_safe = ax_conflit.inset_axes([0.55, 0.55, 0.45, 0.45])
ax_adj_safe.set_title('Safe in adjacent lane', pad=3)
plot_bars(non_warning_statistics, ax_adj_safe)

ax_cat_danger = ax_conflit.inset_axes([0, -0.05, 0.45, 0.45])
ax_cat_danger.set_title('Danger during crossing/turning', pad=3)
filtered_warning = conflict_warning.loc[meta_events[(meta_events['conflict']=='Crossing/turning')].index.values]
warning_statistics, non_warning_statistics = get_rank(filtered_warning, attribution, optimal_threshold, type='both')
plot_bars(warning_statistics, ax_cat_danger)

# Non-warning statistics of the same crossing/turning events.
ax_cat_safe = ax_conflit.inset_axes([0.55, -0.05, 0.45, 0.45])
ax_cat_safe.set_title('Safe during crossing/turning', pad=3)
plot_bars(non_warning_statistics, ax_cat_safe)

# Panel (b): warning statistics only, for events in adverse environments.
ax_environment = axes[1]
ax_environment.set_title('(b) Top factors in conflicts in adverse environments', pad=15)
remove_box(ax_environment)
ax_wea = ax_environment.inset_axes([0, 0.55, 0.45, 0.45])
ax_wea.set_title('Raining weather', pad=3)
filtered_warning = conflict_warning.loc[environment[(environment['weather']=='Raining')|
                                                    (environment['weather']=='Mist/Light Rain')].index.values]
warning_statistics, _ = get_rank(filtered_warning, attribution, optimal_threshold, type='warning')
plot_bars(warning_statistics, ax_wea)

ax_road = ax_environment.inset_axes([0.55, 0.55, 0.45, 0.45])
ax_road.set_title('Not dry road', pad=3)
filtered_warning = conflict_warning.loc[environment[(environment['surfaceCondition']!='Dry')].index.values]
warning_statistics, _ = get_rank(filtered_warning, attribution, optimal_threshold, type='warning')
plot_bars(warning_statistics, ax_road)

ax_light = ax_environment.inset_axes([0, -0.05, 0.45, 0.45])
ax_light.set_title('Not in daylight', pad=3)
filtered_warning = conflict_warning.loc[environment[(environment['lighting']!='Daylight')].index.values]
warning_statistics, _ = get_rank(filtered_warning, attribution, optimal_threshold, type='warning')
plot_bars(warning_statistics, ax_light)

ax_traffic = ax_environment.inset_axes([0.55, -0.05, 0.45, 0.45])
ax_traffic.set_title('LOS D Unstable flow', pad=3)
filtered_warning = conflict_warning.loc[environment[(environment['trafficDensity']=='Level-of-service D: Unstable flow - temporary restrictions substantially slow driver')].index.values]
warning_statistics, _ = get_rank(filtered_warning, attribution, optimal_threshold, type='warning')
plot_bars(warning_statistics, ax_traffic)

# x labels only on the bottom row of insets.
for ax in [ax_cat_danger, ax_cat_safe, ax_light, ax_traffic]:
    ax.set_xlabel('Frequency of being top 3 factors')

In [None]:
# Export Result 5 as a vector PDF into the preprint's figure folder.
fig.savefig(f'{path_fig}Result5.pdf', bbox_inches='tight', pad_inches=0.05, dpi=600)

## Figure A1: SHRP2 reconstruction error distributions

In [None]:
# Figure A1: per-event error between the '*_ekf' columns and the corresponding
# original columns of the SHRP2 bird's-eye trajectories, split into crashes,
# near-crashes and safe baselines.
meta_both = pd.read_csv(path_processed + 'SHRP2/metadata_birdseye.csv').set_index('event_id')
meta_both['event'] = [category[c] for c in meta_both['event_category'].values]

fig, axes = plt.subplots(1, 5, figsize=(7.05, 1.8), constrained_layout=True)

# Panel title -> [unit, *_ekf column, original column, histogram bin edges].
# 'Object displacement' is computed from x/y and only stores [unit, bin edges].
var_list = {'Subject speed': ['m/s', 'v_ekf','speed_comp', np.linspace(-0.008, 0.408, 30)],
            'Subject yaw rate': ['rad/s', 'omega_ekf','yaw_rate', np.linspace(-0.00004, 0.00204, 30)],
            'Subject acceleration': ['m/s$^2$', 'acc_ekf','acc_lon', np.linspace(-0.02, 1.02, 30)],
            'Object displacement': ['m', np.linspace(-0.04, 2.04, 30)],
            'Object speed': ['m/s', 'v_ekf','speed_comp', np.linspace(-0.02, 1.02, 30)]}
for event_type, color, text_pos in zip(['Crashes', 'Near-crashes', 'Safe baselines'], [cmap(0.4), cmap(0.65), cmap(0.15)], [0.95, 0.65, 0.35]):
    if event_type == 'Safe baselines':
        # Safe-baseline trajectories are stored in four numbered files.
        data_ego = pd.concat([pd.read_hdf(path_processed + f'SHRP2/SafeBaseline/Ego_birdseye_{i}.h5', key='data') for i in range(1, 5)])
        data_sur = pd.concat([pd.read_hdf(path_processed + f'SHRP2/SafeBaseline/Surrounding_birdseye_{i}.h5', key='data') for i in range(1, 5)])
    else:
        # One ego/surrounding file pair per event category of this event type.
        event_categories = meta_both[meta_both['event']==event_type]['event_category'].unique()
        data_ego = []
        data_sur = []
        for event_cat in event_categories:
            data_ego.append(pd.read_hdf(path_processed + f'SHRP2/{event_cat}/Ego_birdseye.h5', key='data'))
            data_sur.append(pd.read_hdf(path_processed + f'SHRP2/{event_cat}/Surrounding_birdseye.h5', key='data'))
        data_ego = pd.concat(data_ego)
        data_sur = pd.concat(data_sur)

    # Panels 0-2 use the subject (ego) data, panels 3-4 the surrounding object.
    data_list = [data_ego, data_ego, data_ego, data_sur, data_sur]
    for col, data, key, values in zip(range(5), data_list, var_list.keys(), var_list.values()):
        if key == 'Object displacement':
            # Mean displacement error per event
            error = ((data['x_ekf']-data['x'])**2 + (data['y_ekf']-data['y'])**2)**0.5
            error = error.to_frame(name='error')
            error['event_id'] = data['event_id']
            error = error.groupby('event_id')['error'].mean()
        else:
            # Root mean square error per event
            error = (data[values[1]]-data[values[2]])**2
            error = error.to_frame(name='squared error')
            error['event_id'] = data['event_id']
            error = error.groupby('event_id')['squared error'].mean()**0.5

        mean, std = error.mean(), error.std()
        axes[col].hist(error, bins=values[-1], density=True, color=color, alpha=0.6, lw=0, label=event_type)
        if key == 'Subject yaw rate':
            # Yaw-rate errors are small, so more decimals are shown.
            axes[col].text(0.6, text_pos, f'$\\mu={mean:.5f}$\n$\\sigma={std:.5f}$', ha='left', va='top', transform=axes[col].transAxes, color=color)
        else:
            axes[col].text(0.7, text_pos, f'$\\mu={mean:.2f}$\n$\\sigma={std:.2f}$', ha='left', va='top', transform=axes[col].transAxes, color=color)
        axes[col].set_title(f'{key}')
        # Strip only the leading 'Subject '/'Object ' word so multi-word
        # quantities stay whole ('Yaw rate RMSE', not 'Yaw RMSE').
        quantity = key.split(' ', 1)[1].capitalize()
        error_name = 'MAE' if key == 'Object displacement' else 'RMSE'
        axes[col].set_xlabel(f'{quantity} {error_name} ({values[0]})', labelpad=1)
        axes[col].set_yticks([])
axes[0].set_ylabel('Probability density')
# One histogram per event type was labelled on the first panel.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles[:3], labels[:3], loc='lower center', bbox_to_anchor=(0.5, -0.1),
           ncol=3, frameon=False, handlelength=1, handletextpad=0.5, columnspacing=1)

In [None]:
# Export Figure A1 as a vector PDF into the preprint's figure folder.
fig.savefig(f'{path_fig}SHRP2_error_distributions.pdf', bbox_inches='tight', pad_inches=0.05, dpi=600)