## This notebook generates the tables and figures

In [None]:
import os
import sys
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
import matplotlib as mpl
# Global matplotlib styling for print figures: 6.5 pt Arial text (7 pt axes titles),
# white 1.5-wide hatching, 0.5-wide axes/tick lines and STIX glyphs for math text.
font = {'family': 'Arial',
        'size': 6.5}
plt.rc('font', **font)
mpl.rcParams.update({
    'hatch.color': 'w',
    'hatch.linewidth': 1.5,
    'axes.titlesize': 7,
    'axes.labelsize': 6.5,
    'xtick.labelsize': 6.5,
    'ytick.labelsize': 6.5,
    'legend.fontsize': 6.5,
    'mathtext.fontset': 'stix',
    'axes.linewidth': 0.5,
    'xtick.major.width': 0.5,
    'ytick.major.width': 0.5,
})
from visual_utils.utils_tabfig import *
sys.path.append('../')
from src_safety_evaluation.validation_utils.utils_eval_metrics import *
from src_safety_evaluation.validation_utils.utils_evaluation import optimize_threshold

# Figure widths in inches (used as matplotlib figsize) for single- and double-column layouts.
single_column_width = 3.46
double_column_width = 7.05

# Data and result directories, relative to this notebook. Figures are written to path_fig.
path_processed = '../ProcessedData/'
path_prepared = '../PreparedData/'
path_result = '../ResultData/'
path_raw = '../RawData/'
path_fig = 'Figures/'
# exist_ok=True makes this a no-op when the folder already exists, with no check-then-create race.
os.makedirs(path_fig, exist_ok=True)

def savefig(fig, name, path_fig=path_fig, fmt='pdf'):
    """Write *fig* to ``<path_fig>/<name>.<fmt>`` with tight bounding box.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save.
    name : str
        File name without extension.
    path_fig : str, optional
        Output directory, defaulting to the module-level ``path_fig``. A trailing
        path separator is optional.
    fmt : str, optional
        File extension, ``'pdf'`` by default. matplotlib infers the output format
        from it.
    """
    # os.path.join keeps this correct whether or not path_fig ends with a separator.
    out_file = os.path.join(path_fig, f'{name}.{fmt}')
    fig.savefig(out_file, dpi=600, bbox_inches='tight', pad_inches=0.05)

def number_subfig(ax, label, x, y):
    """Place a bold panel letter *label* on *ax* at axes-fraction position (x, y).

    The text is anchored by its top-left corner. Coordinates are relative to the
    axes (transAxes), so values outside [0, 1] put the letter outside the axes area.
    """
    letter_style = dict(fontsize=8, fontweight='bold', va='top', ha='left')
    ax.text(x, y, label, transform=ax.transAxes, **letter_style)

## Figure 1 Accident Statistics

In [None]:
# Figure 1: two outer axes act only as containers for inset axes.
# Left column: (a) global road-crash deaths over the years above (b) UNECE accidents by location.
# Right column (c): NL and US road-crash types, 2021-2023.
# Inset bounds are [x0, y0, width, height] in the parent axes' fractional coordinates.
fig, axes = plt.subplots(1,2,figsize=(double_column_width*1.2,3),gridspec_kw={'hspace':0.5,'wspace':0.2,'width_ratios':[1.,1.1]})
ax_total = axes[0].inset_axes([0., 0.72, 0.95, 0.365])
global_deaths_over_years(ax_total, path_raw)
ax_total.set_title('Global deaths in road crashes over years')
ax_unece = axes[0].inset_axes([0., 0., 0.95, 0.45])
unece_accident_reduction(ax_unece, path_raw)
ax_unece.set_title('Accidents by location in UNECE countries', y=1.05)

# NL and US panels share the same unit x-range and a y-range centred on 0.
ax_nl = axes[1].inset_axes([0., 0.56, 0.75, 0.54], xlim=(0, 1), ylim=(-0.5, 0.5))
location_type_nl(ax_nl, path_raw)
ax_nl.set_title('Road crash types 2021-2023 (NL)', y=0.97)
ax_us = axes[1].inset_axes([0., -0.08, 0.75, 0.54], xlim=(0, 1), ylim=(-0.5, 0.5))
location_type_us(ax_us, path_raw)
ax_us.set_title('Road crash types 2021-2023 (US)', y=0.97)

# Presumably hides the frame/ticks of the container axes (remove_box is a project helper) - confirm.
remove_box(axes[0])
remove_box(axes[1])

# Panel letters; offsets are in each target axes' fractional coordinates.
number_subfig(ax_total, 'a', -0.1, 1.25)
number_subfig(ax_unece, 'b', -0.1, 1.21)
number_subfig(axes[1], 'c', -0.03, 1.17)

In [None]:
# Export Figure 1 to the figure directory as Fig1_AccidentStatistics.pdf.
savefig(fig, 'Fig1_AccidentStatistics')

## Figure 2 Effectiveness, Scalability, Context-awareness

In [None]:
# Figure 2 on a 42x28 GridSpec:
#   (a) ROC / PRC / ATC effectiveness curves, top-left;
#   (b) and (c) data-scalability panels, on the right;
#   (d) lateral-interaction evaluation, middle-left;
#   (e) seven feature-scalability panels, along the bottom.
fig = plt.figure(figsize=(double_column_width*1.23, 5.8))
gs = fig.add_gridspec(42, 28, hspace=1.3, wspace=1.5)

# Effectiveness plots
# Model identifiers as stored in the results' 'model' column, and their display labels.
# Table 2 below reuses both lists.
models = ['highD_current', 'ACT', 'TTC2D', 'TAdv', 'EI']
model_labels = ['GSSM', 'ACT', 'TTC2D', 'TAdv', 'EI']
colors = cmap([0.15, 0.3, 0.45, 0.6, 0.75])

# Load and stack every RiskEval_*.h5 result file under Conflicts/Results/.
warning_files = os.listdir(path_result + 'Conflicts/Results/')
warning_files = [f for f in warning_files if f.startswith('RiskEval_') and f.endswith('.h5')]
conflict_warning = pd.concat([pd.read_hdf(path_result+'Conflicts/Results/'+f, key='results') for f in tqdm(warning_files, desc='Reading files')])
# NOTE(review): voted_events is not used in this cell; Extended Data Fig. 4 re-reads it itself.
voted_events = pd.read_csv(path_result + 'Conflicts/Voted_conflicting_targets.csv').set_index('event_id')
voted_events = voted_events[voted_events['target_id']>=0]
voted_events['event'] = [category[c] for c in voted_events['event_category'].values]

axroc, axprc, axatc = fig.add_subplot(gs[2:12,:6]), fig.add_subplot(gs[2:12,6:12]), fig.add_subplot(gs[2:12,12:18])
axroc, axprc, axatc = draw_effectiveness(axroc, axprc, axatc, models, colors, conflict_warning)
# Shared legend above panel (a). One label is supplied per handle: 'Safety-critical'
# followed by the models, presumably matching the drawing order in draw_effectiveness - confirm.
handles, legends = axroc.get_legend_handles_labels()
fig.legend(handles, ['Safety-critical']+model_labels, loc='lower center', bbox_to_anchor=(0.36, 0.875),
           ncol=len(models)+1, frameon=False, handlelength=2.5, handletextpad=0.4, columnspacing=1)

# Data scalability
# Warning_*.h5 files from Analyses/, restricted to the 'mixed' and 'SafeBaseline' runs.
# This temporarily replaces conflict_warning. Table 3 reuses event_meta.
warning_files = os.listdir(path_result + 'Analyses/')
warning_files = [f for f in warning_files if f.startswith('Warning_') and f.endswith('.h5')]
warning_files = [f for f in warning_files if 'mixed' in f or 'SafeBaseline' in f]
conflict_warning = pd.concat([pd.read_hdf(path_result+'Analyses/'+f, key='results') for f in tqdm(warning_files, desc='Reading files')])
event_meta = pd.read_csv(path_result + 'Analyses/EventMeta.csv')
ds_axes = [fig.add_subplot(gs[:12,19:28]), fig.add_subplot(gs[16:27,19:28]), fig.add_subplot(gs[17:27,:18])]
ds_axes = draw_data_scalability(conflict_warning, event_meta, axes=ds_axes)
ds_axes[0].set_title('Increasing crossings in ArgoverseHV', pad=5)
# Mathtext legend labels for the five metrics plotted in panels (b) and (c).
labels = ['$A_{80\\%}^\\mathrm{ROC}$', '$A_{90\\%}^\\mathrm{ROC}$', '$\\mathrm{Precision}_{80\\%}^\\mathrm{PRC}$', 
          '$\\mathrm{Precision}_{90\\%}^\\mathrm{PRC}$', '$\\mathrm{AUPRC}$']
handles, _ = ds_axes[0].get_legend_handles_labels()
ds_axes[0].legend(handles, labels, loc='center right', bbox_to_anchor=(1.04, 0.48), ncol=3, fontsize=6,
                  frameon=False, handlelength=2.5, handletextpad=0.4, columnspacing=1, labelspacing=0.)
ds_axes[1].set_title('Increasing lane changes in highD', pad=2)
ds_axes[1].set_ylabel('Metric value', labelpad=1)
handles, _ = ds_axes[1].get_legend_handles_labels()
ds_axes[1].legend(handles, labels, loc='lower right', bbox_to_anchor=(1.04, -0.05), ncol=3, fontsize=6,
                  frameon=False, handlelength=2.5, handletextpad=0.4, columnspacing=1, labelspacing=0.)
ds_axes[2].set_title('Evaluation on lateral interactions', pad=5)

# Feature scalability
# Re-reads the RiskEval results. Tables 2 and 3 below rely on conflict_warning holding this data.
warning_files = os.listdir(path_result + 'Conflicts/Results/')
warning_files = [f for f in warning_files if f.startswith('RiskEval_') and f.endswith('.h5')]
conflict_warning = pd.concat([pd.read_hdf(path_result+'Conflicts/Results/'+f, key='results') for f in tqdm(warning_files, desc='Reading files')])
fs_axes = [fig.add_subplot(gs[32:42,4*i:4*(i+1)]) for i in range(7)]
fs_axes = draw_feature_scalability(conflict_warning, axes=fs_axes)
for ax in fs_axes[1:]:
    ax.tick_params(axis='y', which='major', pad=2) # type: ignore
# The first two legend handles expose .patches (presumably bar containers). Their patches
# are merged into one six-entry legend for the feature-set variants.
handles, labels = fs_axes[0].get_legend_handles_labels()
fig.legend(handles[0].patches+handles[1].patches,
           ['S-C', 'S-CE', 'S-CET', 'S-Ca', 'S-CaE', 'S-CaET'],
           loc='lower center', bbox_to_anchor=(0.5, 0.065),
           ncol=6, frameon=False, handlelength=1.8, handletextpad=0.5, columnspacing=1.5)

number_subfig(axroc, 'a', -0.32, 1.38)
number_subfig(ds_axes[0], 'b', -0.1, 1.12)
number_subfig(ds_axes[1], 'c', -0.1, 1.1)
number_subfig(ds_axes[2], 'd', -0.04, 1.2)
number_subfig(fs_axes[0], 'e', -0.2, 1.2)

In [None]:
# Export Figure 2 (the file is named Result123.pdf).
savefig(fig, 'Result123')

# Table 2 Effectiveness

In [None]:
# Table 2: effectiveness metrics per model, on the RiskEval results left in
# conflict_warning by the Figure 2 cell. Uses models/model_labels defined there.
metric_table = pd.DataFrame(columns=['model','auprc','aroc_80','aroc_90','pprc_80','pprc_90','PTTI_star','mTTI_star'])
for model, model_label in zip(models, model_labels):
    metrics = get_eval_metrics(conflict_warning[conflict_warning['model']==model],
                               thresholds={'roc': [0.80, 0.90], 'prc':[0.80, 0.90],'tti':1.5}, with_CI=True)
    metrics['model'] = model_label
    # Append one row. Metric keys not yet in the frame (e.g. TTI_star_q25) become new columns.
    metric_table.loc[len(metric_table), list(metrics.keys())] = list(metrics.values())
metric_table[metric_table.columns[1:]] = metric_table[metric_table.columns[1:]].astype(float)
# mTTI* and the columns describing its spread; the last four are folded into one string below.
mtti_star = ['mTTI_star', 'TTI_star_q25', 'TTI_star_q75', 'TTI_star_lowCI', 'TTI_star_upCI']
metric_table['mTTI_star'] = metric_table['mTTI_star'].round(2)
# columns[1:-4] presumably spans auprc..mTTI_star, excluding the four TTI_star_* columns
# appended last. highlight() is a project helper (star-imported) - confirm column order.
metric_table = highlight(metric_table, max_cols=metric_table.columns[1:-4], involve_second=True)
# Render 'mTTI* [Q1--Q3]; lowCI--upCI' as a single LaTeX-ready string.
metric_table['mTTI_star'] = (metric_table['mTTI_star'] + ' [' + 
                             metric_table['TTI_star_q25'].apply(lambda x: f"{x:.2f}") + '--' + 
                             metric_table['TTI_star_q75'].apply(lambda x: f"{x:.2f}") + ']; ' +
                             metric_table['TTI_star_lowCI'].apply(lambda x: f"{x:.2f}") + '--' +
                             metric_table['TTI_star_upCI'].apply(lambda x: f"{x:.2f}"))
# Trim a trailing '0' before ' [' and before '}'.
# NOTE(review): this plain substring replace would also turn e.g. '10 [' into '1 ['.
metric_table['mTTI_star'] = [metric_table['mTTI_star'].iloc[i].replace('0 [',' [').replace('0}','}') for i in range(len(metric_table))]
metric_table = metric_table.drop(columns=mtti_star[1:])
metric_table = metric_table.rename(columns={'model':'Method', 'auprc':'AUPRC',
                             'aroc_80':'$A_{80\\%}^\\mathrm{ROC}$', 'aroc_90':'$A_{90\\%}^\\mathrm{ROC}$',
                             'pprc_80':'$\\mathrm{Precision}_{80\\%}^\\mathrm{PRC}$', 'pprc_90':'$\\mathrm{Precision}_{90\\%}^\\mathrm{PRC}$',
                             'PTTI_star':'$P^*_{\\mathrm{TTI}\\geq1.5}$','mTTI_star':'$m\\mathrm{TTI}^*$ [$Q1$--$Q3$]; $99\\%CI$'})
# Order rows as in model_labels.
metric_table = metric_table.set_index('Method').loc[model_labels]
metric_table

In [None]:
# Print Table 2 as LaTeX tabular source.
print(metric_table.to_latex())

# Table 3 One GSSM for all interactions

In [None]:
# Table 3: one GSSM versus the baseline indicators, evaluated separately for each
# conflict (interaction) type. Uses the RiskEval results left in conflict_warning and the
# EventMeta.csv table (event_meta) loaded in the Figure 2 cell.
event_meta = event_meta[event_meta['conflict']!='none']
models = ['SafeBaseline_current_environment_profiles', 'ACT', 'TTC2D', 'TAdv', 'EI']
model_labels = ['GSSM', 'ACT', 'TTC2D', 'TAdv', 'EI']

# Each entry: [conflict labels as in EventMeta.csv, row label shown in the table].
event_metric_list = []
event_metric_list.append([['leading'], 'Rear-end'])
event_metric_list.append([['adjacent_lane'], 'Adjacent lane'])
# NOTE(review): 'turning_into_parallel' is listed once. Confirm that no other
# crossing/turning category from EventMeta.csv is missing from this group.
event_metric_list.append([['turning_into_parallel','turning_across_opposite','turning_into_opposite','intersection_crossing'],
                         'Crossing/turning'])
event_metric_list.append([['merging'], 'Merging'])
event_metric_list.append([['pedestrian','cyclist','animal'], 'With pedestrian/cyclist/animal'])

# mTTI* and the columns describing its spread; the last four are folded into one string below.
mtti_star = ['mTTI_star', 'TTI_star_q25', 'TTI_star_q75', 'TTI_star_lowCI', 'TTI_star_upCI']

metric_table = []
for events, event_label in tqdm(event_metric_list):
    # The event selection is identical for every model, so compute it once per event type.
    event_ids = event_meta[event_meta['conflict'].isin(events)]['event_id'].values
    event_warning = conflict_warning[conflict_warning['event_id'].isin(event_ids)]

    event_metrics = pd.DataFrame(columns=['Event', 'model','auprc','aroc_80','aroc_90','pprc_80','pprc_90','PTTI_star','mTTI_star'])
    for model, model_label in zip(models, model_labels):
        filtered_warning = event_warning[event_warning['model']==model]

        metrics = get_eval_metrics(filtered_warning, thresholds={'roc': [0.80, 0.90], 'prc':[0.80, 0.90],'tti':1.5}, with_CI=True)
        metrics['Event'] = event_label
        metrics['model'] = model_label
        # Append one row. Metric keys not yet in the frame (e.g. TTI_star_q25) become new columns.
        event_metrics.loc[len(event_metrics), list(metrics.keys())] = list(metrics.values())

    # Every column after 'Event' and 'model' is numeric.
    event_metrics[event_metrics.columns[2:]] = event_metrics[event_metrics.columns[2:]].astype(float)
    event_metrics['mTTI_star'] = event_metrics['mTTI_star'].round(2)
    # columns[2:-4] presumably spans auprc..mTTI_star, excluding the four TTI_star_* columns
    # appended last. highlight() is a project helper - confirm column order.
    event_metrics = highlight(event_metrics, max_cols=event_metrics.columns[2:-4], involve_second=True)
    # Render 'mTTI* [Q1--Q3]; lowCI--upCI' as a single LaTeX-ready string.
    event_metrics['mTTI_star'] = (event_metrics['mTTI_star'] + ' [' +
                                  event_metrics['TTI_star_q25'].apply(lambda x: f"{x:.2f}") + '--' +
                                  event_metrics['TTI_star_q75'].apply(lambda x: f"{x:.2f}") + ']; ' +
                                  event_metrics['TTI_star_lowCI'].apply(lambda x: f"{x:.2f}") + '--' +
                                  event_metrics['TTI_star_upCI'].apply(lambda x: f"{x:.2f}"))
    # Trim a trailing '0' before ' [' and before '}'.
    # NOTE(review): this plain substring replace would also turn e.g. '10 [' into '1 ['.
    event_metrics['mTTI_star'] = [s.replace('0 [',' [').replace('0}','}') for s in event_metrics['mTTI_star']]
    event_metrics = event_metrics.drop(columns=mtti_star[1:])
    # NOTE(review): this counts the events present in the last model's rows (EI). It equals
    # the per-type event count only if every model was scored on the same events.
    event_metrics['Number of events'] = f"{filtered_warning['event_id'].nunique()}"
    metric_table.append(event_metrics)
metric_table = pd.concat(metric_table, axis=0)
metric_table = metric_table.rename(columns={'model':'Method', 'aroc_80':'$A_{80\\%}^\\mathrm{ROC}$', 'aroc_90':'$A_{90\\%}^\\mathrm{ROC}$',
                                            'pprc_80':'$\\mathrm{Precision}_{80\\%}^\\mathrm{PRC}$', 'pprc_90':'$\\mathrm{Precision}_{90\\%}^\\mathrm{PRC}$',
                                            'auprc':'$\\mathrm{AUPRC}$', 'PTTI_star':'$P^*_{\\mathrm{TTI}\\geq1.5}$', 'mTTI_star':'$m\\mathrm{TTI}^*$ [$Q1$--$Q3$]; $99\\%CI$'})

metric_table = metric_table.set_index(['Event','Number of events','Method'])
# Missing metrics appear as the string 'nan' at this point (presumably from highlight's
# string formatting); show them as N/A in the table.
metric_table.replace('nan', 'N/A', inplace=True)
metric_table

In [None]:
# Print Table 3 as LaTeX tabular source.
print(metric_table.to_latex())

In [None]:
# For rear-end ('leading') and merging conflicts, print the GSSM's false positives out of
# all raised warnings (TP + FP), and its false negatives, at the threshold chosen by
# optimize_threshold.
events_list = [['leading'], ['merging']]
model = models[0]

for events in events_list:
    print(events)
    event_ids = event_meta[event_meta['conflict'].isin(events)]['event_id'].values
    # One combined mask selects the same rows, in the same order, as filtering twice.
    keep_rows = conflict_warning['event_id'].isin(event_ids) & (conflict_warning['model'] == model)
    filtered_warning = conflict_warning[keep_rows]
    _, _, optimal_threshold = optimize_threshold(filtered_warning, 'GSSM', return_stats=True)
    n_fp, n_tp, n_fn = (optimal_threshold[k] for k in ('FP', 'TP', 'FN'))
    print(f"FP:{n_fp}/({n_tp + n_fp}) FN:{n_fn}")

# Figure 3 Attribution

In [None]:
# Figure 3: feature attributions of the SafeBaseline_current_environment_profiles GSSM.
# Panel (a): lateral conflicts. Panel (b): adverse environments.
attribution = pd.read_hdf(path_result + 'FeatureAttribution/SafeBaseline_current_environment_profiles.h5').reset_index()
# Columns 4..24 hold per-feature attribution values. `columns` strips a 3-character
# prefix from each name (presumably 'eg_' - confirm against the file).
# NOTE(review): `columns` is not used in this cell.
eg_columns = [var for var in attribution.columns[4:25]]
columns = [var[3:] for var in attribution.columns[4:25]]
# Per-row totals: all attributions, only the positive ones, and only the negative ones.
attribution['eg_sum'] = attribution[eg_columns].sum(axis=1)
positive_mask = (attribution[eg_columns]>0)
attribution['positive_sum'] = (attribution[eg_columns]*positive_mask.astype(int)).sum(axis=1)
negative_mask = (attribution[eg_columns]<0)
attribution['negative_sum'] = (attribution[eg_columns]*negative_mask.astype(int)).sum(axis=1)

# Keep only the result rows evaluated at the threshold selected by optimize_threshold,
# indexed by event.
conflict_warning = pd.read_hdf(path_result + 'Conflicts/Results/RiskEval_SafeBaseline_current_environment_profiles.h5', key='results')
_, _, optimal_threshold = optimize_threshold(conflict_warning, 'GSSM', return_stats=True)
conflict_warning = conflict_warning[conflict_warning['threshold']==optimal_threshold['threshold']].set_index('event_id')
# Restrict event metadata and environment descriptors to the evaluated events.
_, _, meta_events, environment = read_meta(path_processed, path_result)
event_ids = conflict_warning.index.values
meta_events = meta_events.loc[event_ids]
environment = environment.loc[event_ids]

# Each outer axes hosts a 2x2 grid of inset bar charts; inset bounds are
# [x0, y0, width, height] in the outer axes' fractional coordinates.
fig, axes = plt.subplots(1, 2, figsize=(double_column_width, 2.8), constrained_layout=True, gridspec_kw={'wspace':0.07})
ax_conflit = axes[0]
ax_conflit.set_title('Top factors in leading to and avoiding lateral conflicts', pad=15)
remove_box(ax_conflit)
ax_adj_danger = ax_conflit.inset_axes([0, 0.55, 0.46, 0.45])
ax_adj_danger.set_title('Danger in adjacent lane', pad=3)
# type='both' returns rankings for warned and non-warned events, shown as danger/safe pairs.
filtered_warning = conflict_warning.loc[meta_events[(meta_events['conflict']=='Adjacent lane')].index.values]
warning_statistics, non_warning_statistics = get_rank(filtered_warning, attribution, eg_columns, optimal_threshold, type='both')
plot_bars(warning_statistics, ax_adj_danger)

ax_adj_safe = ax_conflit.inset_axes([0.54, 0.55, 0.46, 0.45])
ax_adj_safe.set_title('Safe in adjacent lane', pad=3)
plot_bars(non_warning_statistics, ax_adj_safe)

ax_cat_danger = ax_conflit.inset_axes([0, -0.05, 0.46, 0.45])
ax_cat_danger.set_title('Danger during crossing/turning', pad=3)
filtered_warning = conflict_warning.loc[meta_events[(meta_events['conflict']=='Crossing/turning')].index.values]
warning_statistics, non_warning_statistics = get_rank(filtered_warning, attribution, eg_columns, optimal_threshold, type='both')
plot_bars(warning_statistics, ax_cat_danger)

ax_cat_safe = ax_conflit.inset_axes([0.54, -0.05, 0.46, 0.45])
ax_cat_safe.set_title('Safe during crossing/turning', pad=3)
plot_bars(non_warning_statistics, ax_cat_safe)

# Panel (b): rankings for warned events only, in four environment subsets.
# NOTE(review): the '!=' filters below also select rows whose attribute is missing or
# 'Unknown'-like; confirm that is intended.
ax_environment = axes[1]
ax_environment.set_title('Top factors in conflicts in adverse environments', pad=15)
remove_box(ax_environment)
ax_wea = ax_environment.inset_axes([0, 0.55, 0.46, 0.45])
ax_wea.set_title('Raining weather', pad=3)
filtered_warning = conflict_warning.loc[environment[(environment['weather']=='Raining')|
                                                    (environment['weather']=='Mist/Light Rain')].index.values]
warning_statistics, _ = get_rank(filtered_warning, attribution, eg_columns, optimal_threshold, type='warning')
plot_bars(warning_statistics, ax_wea)

ax_road = ax_environment.inset_axes([0.54, 0.55, 0.46, 0.45])
ax_road.set_title('Not dry road', pad=3)
filtered_warning = conflict_warning.loc[environment[(environment['surfaceCondition']!='Dry')].index.values]
warning_statistics, _ = get_rank(filtered_warning, attribution, eg_columns, optimal_threshold, type='warning')
plot_bars(warning_statistics, ax_road)

ax_light = ax_environment.inset_axes([0, -0.05, 0.46, 0.45])
ax_light.set_title('Not in daylight', pad=3)
filtered_warning = conflict_warning.loc[environment[(environment['lighting']!='Daylight')].index.values]
warning_statistics, _ = get_rank(filtered_warning, attribution, eg_columns, optimal_threshold, type='warning')
plot_bars(warning_statistics, ax_light)

# Only level-of-service D is selected for the 'Unstable traffic flow' subset.
ax_traffic = ax_environment.inset_axes([0.54, -0.05, 0.46, 0.45])
ax_traffic.set_title('Unstable traffic flow', pad=3)
filtered_warning = conflict_warning.loc[environment[(environment['trafficDensity']=='Level-of-service D: Unstable flow - temporary restrictions substantially slow driver')].index.values]
warning_statistics, _ = get_rank(filtered_warning, attribution, eg_columns, optimal_threshold, type='warning')
plot_bars(warning_statistics, ax_traffic)

# Only the bottom-row insets get an x-axis label.
for ax in [ax_cat_danger, ax_cat_safe, ax_light, ax_traffic]:
    ax.set_xlabel('Frequency of being top 3 factors')

number_subfig(ax_conflit, 'a', -0.005, 1.15)
number_subfig(ax_environment, 'b', -0.005, 1.15)

In [None]:
# Export Figure 3 (the file is named Result5.pdf).
savefig(fig, 'Result5')

# Extended Data Figure 2 SHRP2 statistics

In [None]:
# Extended Data Fig. 2: SHRP2 event statistics, drawn on five axes by draw_SHRP2.
fig, axes = draw_SHRP2(path_processed, path_result, figsize=(double_column_width, 3.8))

# (title, vertical title position) for axes 0..4.
shrp2_titles = [('Event counts', 1.05),
                ('Event types', 1.05),
                ('Weather and road surface conditions', 1.0),
                ('Lighting conditions', 1.0),
                ('Traffic conditions', 1.0)]
for idx, (title_text, title_y) in enumerate(shrp2_titles):
    axes[idx].set_title(title_text, y=title_y)

# (panel letter, x, y) for axes 0..4, in each axes' fractional coordinates.
shrp2_letters = [('a', -0.03, 1.15),
                 ('b', -0.1, 1.15),
                 ('c', -0.36, 1.14),
                 ('d', -0.05, 1.14),
                 ('e', -0.1, 1.14)]
for idx, (letter, x_pos, y_pos) in enumerate(shrp2_letters):
    number_subfig(axes[idx], letter, x_pos, y_pos)

In [None]:
# Export Extended Data Fig. 2 as SHRP2_event_counts.pdf.
savefig(fig, 'SHRP2_event_counts')

# Extended Data Figure 3 Danger and safe period

In [None]:
# Extended Data Fig. 3: the danger/safe period illustration, drawn entirely by draw_periods.
fig, axes = draw_periods(figsize=(double_column_width, 1.8))

In [None]:
# Export Extended Data Fig. 3 as Safe_danger_period.pdf.
savefig(fig, 'Safe_danger_period')

# Extended Data Figure 4 One GSSM for all interactions

In [None]:
# Extended Data Fig. 4: one GSSM evaluated across all interaction types, compared with
# the ACT and TTC2D baselines.
# os.listdir order is unspecified; sorting makes the concatenated frame reproducible.
warning_files = sorted(os.listdir(path_result + 'Conflicts/Results/'))
warning_files = [f for f in warning_files if f.startswith('RiskEval_') and f.endswith('.h5')]
conflict_warning = pd.concat([pd.read_hdf(path_result+'Conflicts/Results/'+f, key='results') for f in tqdm(warning_files, desc='Reading files')])
event_meta = pd.read_csv(path_result + 'Analyses/EventMeta.csv')
# Drop rows with a negative target_id (presumably events without an identified
# conflicting target), then map each event category to its display group.
voted_events = pd.read_csv(path_result + 'Conflicts/Voted_conflicting_targets.csv').set_index('event_id')
voted_events = voted_events[voted_events['target_id']>=0]
voted_events['event'] = [category[c] for c in voted_events['event_category'].values]

models = ['SafeBaseline_current_environment_profiles',
          'ACT',
          'TTC2D']

fig, axes = draw_generalisability(conflict_warning, models, event_meta, voted_events, figsize=(double_column_width*0.94, 2.8))
axes[0].set_title('Event type distribution', pad=0)
axes[1].set_title('Receiver operating characteristic (ROC) curves', pad=11.5)
axes[2].set_title('Precision-recall (PRC) curves', pad=5)
axes[3].set_title('Accuracy-timeliness (ATC) curves', pad=5)
number_subfig(axes[0], 'a', -0.05, 1.04)
number_subfig(axes[1], 'b', -0.04, 1.35)

In [None]:
# Export Extended Data Fig. 4 (the file is named Result4.pdf).
savefig(fig, 'Result4')

# Supplementary Figure 1 SHRP2 Reconstruction error distribution

In [None]:
# Supplementary Fig. 1: per-event reconstruction error of the SHRP2 bird's-eye data.
# Each *_ekf column is compared with its counterpart, separately for crashes,
# near-crashes and safe baselines.
meta_both = pd.read_csv(path_processed + 'SHRP2/metadata_birdseye.csv').set_index('event_id')
meta_both['event'] = [category[c] for c in meta_both['event_category'].values]

fig, axes = plt.subplots(1, 5, figsize=(double_column_width, 1.5), constrained_layout=True)

# Panel title -> [unit, *_ekf column, column it is compared with, histogram bin edges].
# 'Object displacement' uses the x/y columns directly, so it stores only unit and bins.
var_list = {'Subject speed': ['m/s', 'v_ekf','speed_comp', np.linspace(-0.008, 0.408, 30)],
            'Subject yaw rate': ['rad/s', 'omega_ekf','yaw_rate', np.linspace(-0.00004, 0.00204, 30)],
            'Subject acceleration': ['m/s$^2$', 'acc_ekf','acc_lon', np.linspace(-0.02, 1.02, 30)],
            'Object displacement': ['m', np.linspace(-0.04, 2.04, 30)],
            'Object speed': ['m/s', 'v_ekf','speed_comp', np.linspace(-0.02, 1.02, 30)]}
for event_type, color, text_pos in zip(['Crashes', 'Near-crashes', 'Safe baselines'], [cmap(0.4), cmap(0.65), cmap(0.15)], [0.95, 0.65, 0.35]):
    if event_type == 'Safe baselines':
        # Safe-baseline trajectories are split over four numbered files.
        data_ego = pd.concat([pd.read_hdf(path_processed + f'SHRP2/SafeBaseline/Ego_birdseye_{i}.h5', key='data') for i in range(1, 5)])
        data_sur = pd.concat([pd.read_hdf(path_processed + f'SHRP2/SafeBaseline/Surrounding_birdseye_{i}.h5', key='data') for i in range(1, 5)])
    else:
        # Crashes and near-crashes are stored in one folder per event category.
        event_categories = meta_both[meta_both['event']==event_type]['event_category'].unique()
        data_ego = pd.concat([pd.read_hdf(path_processed + f'SHRP2/{event_cat}/Ego_birdseye.h5', key='data')
                              for event_cat in event_categories])
        data_sur = pd.concat([pd.read_hdf(path_processed + f'SHRP2/{event_cat}/Surrounding_birdseye.h5', key='data')
                              for event_cat in event_categories])

    # The first three panels use the subject (ego) data; the last two use surrounding objects.
    data_list = [data_ego, data_ego, data_ego, data_sur, data_sur]
    for col, ((key, values), data) in enumerate(zip(var_list.items(), data_list)):
        ax = axes[col]
        if key == 'Object displacement':
            # Mean displacement error
            error = ((data['x_ekf']-data['x'])**2 + (data['y_ekf']-data['y'])**2)**0.5
            error = error.to_frame(name='error')
            error['event_id'] = data['event_id']
            error = error.groupby('event_id')['error'].mean()
        else:
            # Root mean square error
            error = (data[values[1]]-data[values[2]])**2
            error = error.to_frame(name='squared error')
            error['event_id'] = data['event_id']
            error = error.groupby('event_id')['squared error'].mean()**0.5

        mean, std = error.mean(), error.std()
        ax.hist(error, bins=values[-1], density=True, color=color, alpha=0.6, lw=0, label=event_type)
        # Yaw-rate errors are tiny (bins span ~0.002 rad/s), so show more decimals.
        if key == 'Subject yaw rate':
            ax.text(0.6, text_pos, f'$\\mu={mean:.5f}$\n$\\sigma={std:.5f}$', ha='left', va='top', transform=ax.transAxes, color=color)
        else:
            ax.text(0.7, text_pos, f'$\\mu={mean:.2f}$\n$\\sigma={std:.2f}$', ha='left', va='top', transform=ax.transAxes, color=color)
        ax.set_title(f'{key}', pad=5)
        # Drop the 'Subject'/'Object' prefix but keep multi-word quantities whole
        # ('Yaw rate RMSE', not 'Yaw RMSE').
        quantity = key.split(' ', 1)[1].capitalize()
        error_name = 'MAE' if key == 'Object displacement' else 'RMSE'
        ax.set_xlabel(f"{quantity} {error_name} ({values[0]})", labelpad=1)
        ax.set_yticks([])
axes[0].set_ylabel('Probability density')
# One legend entry per event type, taken from the first panel.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles[:3], labels[:3], loc='lower center', bbox_to_anchor=(0.5, -0.09),
           ncol=3, frameon=False, handlelength=1, handletextpad=0.5, columnspacing=1)

In [None]:
# Export Supplementary Fig. 1 as SHRP2_error_distributions.pdf.
savefig(fig, 'SHRP2_error_distributions')