# Description

This notebook evaluates the hearing thresholds predicted by neural networks (NN) or estimated by a sound level regression method (SLR) by comparing them with the ground truth, i.e. the manually assessed thresholds.

Different evaluation approaches are used to compare the machine-determined hearing thresholds with this ground truth.<br>
To this end, the two ABR (auditory brainstem response) threshold finding methods (NN and SLR) are tested on the [GMC](https://www.mouseclinic.de/) and [ING](https://journals.plos.org/plosbiology/article?id=10.1371/journal.pbio.3000194) data sets in a series of experiments:
1. **NN GMC-GMC**: NN trained on GMC data and tested on GMC data
2. NN GMC-ING: NN trained on GMC data and tested on ING data
3. NN ING-GMC: NN trained on ING data and tested on GMC data
4. **NN ING-ING**: NN trained on ING data and tested on ING data
5. **SLR GMC-GMC**: SLR calibrated on GMC data and tested on GMC data
6. SLR GMC-ING: SLR calibrated on GMC data and tested on ING data
7. SLR ING-GMC: SLR calibrated on ING data and tested on GMC data
8. **SLR ING-ING**: SLR calibrated on ING data and tested on ING data

Furthermore, evaluation curves are calculated and plotted to assess the quality of the thresholding.<br>
**Evaluation curves** allow the relative comparison of threshold finding methods without requiring absolute ground truth labels. To this end, four experiments are carried out, each comparing manual thresholds (blue, dotted lines), SLR estimations (<font color='red'>**red**</font>, dashed lines), NN predictions (<font color='green'>**green**</font>, dash-dotted lines), and an "always 50 dB" dummy method (<font color='grey'>**grey**</font>, solid lines):
1. **NN GMC-GMC, SLR GMC-GMC, GMC manual thresholds, dummy method**
2. NN GMC-ING, SLR GMC-ING, ING manual thresholds, dummy method
3. NN ING-GMC, SLR ING-GMC, GMC manual thresholds, dummy method
4. **NN ING-ING, SLR ING-ING, ING manual thresholds, dummy method**

The following files are used:

* Files with ABR curves and the manually assessed thresholds per stimulus:
    * GMC data: _GMC_abr_curves.csv_
    * ING data: _ING_abr_curves.csv_

* Files with thresholds predicted by neural networks: 
    * NN GMC-GMC: _../results/GMC_data_GMCtrained_NN_predictions.csv_
    * NN GMC-ING: _../results/ING_data_GMCtrained_NN_predictions.csv_
    * NN ING-GMC: _../results/GMC_data_INGtrained_NN_predictions.csv_
    * NN ING-ING: _../results/ING_data_INGtrained_NN_predictions.csv_
    
* Files with thresholds estimated by a sound level regression method: 
    * SLR GMC-GMC: _../results/GMC_data_GMCcalibrated_SLR_estimations.csv_
    * SLR GMC-ING: _../results/ING_data_GMCcalibrated_SLR_estimations.csv_
    * SLR ING-GMC: _../results/GMC_data_INGcalibrated_SLR_estimations.csv_
    * SLR ING-ING: _../results/ING_data_INGcalibrated_SLR_estimations.csv_
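
The naming scheme of the result files listed above is regular enough to be generated instead of typed by hand. The following is a minimal sketch, not part of the original notebook: the dictionary `result_files` and the helper tuples are hypothetical names, and `path2results` is assumed to have the same value as in the Definitions section.

```python
import itertools
import os

path2results = '../results/'  # assumed: same value as in the Definitions section

# method -> (verb used in the file name, suffix of the file name)
methods = {'NN': ('trained', 'predictions'), 'SLR': ('calibrated', 'estimations')}
datasets = ['GMC', 'ING']

# hypothetical helper mapping: experiment name -> path of the result file
result_files = {}
for method, fit_on, test_on in itertools.product(methods, datasets, datasets):
    verb, suffix = methods[method]
    name = '%s %s-%s' % (method, fit_on, test_on)
    filename = '%s_data_%s%s_%s_%s.csv' % (test_on, fit_on, verb, method, suffix)
    result_files[name] = os.path.join(path2results, filename)

for name, path in result_files.items():
    print('%-12s %s' % (name, path))
```

Note that the file name starts with the data set the method is tested on, while the experiment name starts with the data set it is trained or calibrated on.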

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Load libraries

In [None]:
import os
import warnings

import pandas as pd
import numpy as np
import seaborn as sns

import ABR_ThresholdFinder_NN.data_preparation as dataprep
import ABR_ThresholdFinder_NN.thresholder as abrthr

from ABR_ThresholdFinder_SLR.evaluations import plot_evaluation_curve_for_specific_stimulus

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.ticker as ticker
plt.rcParams['figure.figsize'] = [10, 8]

os.environ["CUDA_VISIBLE_DEVICES"]=""

from sklearn.metrics import confusion_matrix, precision_score, accuracy_score, recall_score, f1_score, r2_score
from sklearn.metrics import auc, average_precision_score, precision_recall_curve, plot_precision_recall_curve, roc_curve, roc_auc_score
from sklearn.metrics import classification_report

import itertools

import scipy.stats as sp

# Definitions

In [None]:
"""Set the path to the data files, for example '../data'"""
path2data = '../data/'
"""Set the path to result data"""
path2results = '../results/'
"""Name the time step columns"""
datacols = ['t' + str(i) for i in range(0, 1000)]

## Utils

In [None]:
def plot_standardised_error(_data, _manual_thr_col, _pred_thr_col, 
                            _xlabel, _ylabel, _title, _freq_specific=False, 
                            _hue=None, _fontsize=30, _figsize=(30,15)):
    """
    Plots the error of the automatically determined threshold values normalised to the 
    variance of the manually determined threshold values.
    
    Parameters
    ----------
        _data: pandas-data-frame
            It contains both manually assessed and automatically determined thresholds for each mouse identifier 
            and hearing stimulus. 
            
        _manual_thr_col: string
            The name of the column in which the manually assessed thresholds are stored.
            
        _pred_thr_col: string 
            The name of the column in which the automatically determined thresholds are stored.
            
        _freq_specific: boolean
            If true error is plotted for each stimulus separately.
    """
    
    
    df = _data.copy()
    for col in [_manual_thr_col, _pred_thr_col]:
        df[col] = df[col].replace(999, 100)
    
    df['thr_diff'] = abs(df[_pred_thr_col] - df[_manual_thr_col])
    var = np.var(df[_manual_thr_col])
    df['thr_diff'] = df['thr_diff']/var
    
    if _freq_specific:
        frequencies = df.frequency.unique()

        rows = 3
        cols = 2
    
        col = 0
        row = 0

        fig, axes = plt.subplots(ncols=cols, nrows=rows, #gridspec_kw={'wspace': 5},
                                 sharey=True, sharex=True, figsize=_figsize, constrained_layout=True)
        
        for idx, freq in enumerate(sorted(frequencies)):
            
            sns.histplot(data=df[df.frequency == freq], x="thr_diff", stat="percent", # "count",
                         palette=sns.color_palette("colorblind"), bins=35, ax=axes[row,col])
            
            axes[row,col].set_yticks([x for x in range(10,110,10)])
            axes[row,col].set_ylim(0,100)
            
            axes[row,col].set_xticks([x for x in np.arange(0,0.35,0.05)])
            axes[row,col].set_xlim(0,0.3)
            
            axes[row,col].set_xticklabels(["{:.2f}".format(xtick) for xtick in axes[row,col].get_xticks()], fontsize=_fontsize-10)
            axes[row,col].set_yticklabels(['%d%%' % ytick for ytick in axes[row,col].get_yticks()], fontsize=_fontsize-10)

            if col == 0:
                axes[row,col].set_ylabel(_ylabel, fontsize=_fontsize, labelpad=15)
            if row == rows-1:
                axes[row,col].set_xlabel(_xlabel, fontsize=_fontsize, labelpad=15)
            if freq == 100:
                axes[row,col].set_title('Click', fontsize=_fontsize, pad=15)
            else:
                axes[row,col].set_title('%d kHz' % (freq/1000), fontsize=_fontsize, pad=15)
    
            col += 1
            if col == cols:
                row += 1
                col = 0
        plt.suptitle(_title, fontsize=_fontsize)
    else:
        plt.figure(figsize=_figsize)
        ax = sns.histplot(data=df, x="thr_diff", stat="percent", # "count", 
                          palette=sns.color_palette("colorblind"), bins=35, hue=_hue)
        ax.set_yticks([x for x in range(0,110,10)])
        ax.set_ylim(0,100)
        
        ax.set_xticklabels(["{:.2f}".format(xtick) for xtick in ax.get_xticks()], fontsize=_fontsize-10)
        ax.set_yticklabels(['%d%%' % ytick for ytick in ax.get_yticks()], fontsize=_fontsize-10)
#         ax.set_yticklabels([int(ytick) for ytick in ax.get_yticks()], fontsize=_fontsize-10)
        
        ax.set_xlabel(_xlabel, fontsize=_fontsize)
        ax.set_ylabel(_ylabel, fontsize=_fontsize)
        ax.set_title(_title, fontsize=_fontsize, pad=15)
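
The quantity shown in these histograms can be reproduced on a few toy values. The sketch below mirrors the three steps of `plot_standardised_error` (map 999 to 100, take the absolute difference, divide by the variance of the manual thresholds). The function name `standardised_error` and the toy numbers are illustrative assumptions and do not come from the data sets.

```python
import numpy as np
import pandas as pd

def standardised_error(manual, predicted, sentinel=999, cap=100):
    """Absolute threshold difference divided by the variance of the manual thresholds."""
    manual = pd.Series(manual, dtype=float).replace(sentinel, cap)
    predicted = pd.Series(predicted, dtype=float).replace(sentinel, cap)
    # np.var on an array uses ddof=0, i.e. the population variance
    variance = np.var(manual.to_numpy())
    return (predicted - manual).abs() / variance

# toy values, not taken from the GMC or ING data
toy_manual = [20, 40, 60, 999]
toy_predicted = [25, 40, 50, 999]
errors = standardised_error(toy_manual, toy_predicted)
print(errors.round(4).tolist())
```

Because the denominator is a variance (in dB²) and the numerator a difference (in dB), the resulting values are small and not dimensionless, which is why the x-axis of the plots only spans 0 to 0.3.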

In [None]:
def plot_median(_data, _column, _title=None, _fontsize=30, _figsize=(30,30)):
    
    """
    Plots the medians of the automatically determined against the manually assessed thresholds.
    
    Parameters
    ----------
        _data: pandas-data-frame
            It contains both manually assessed and automatically determined thresholds for each mouse identifier 
            and hearing stimulus. 
            
        _column: string
            The name of the column in which the automatically determined thresholds are stored.
    """
    
    df = _data.copy()
    
    fig = plt.figure(figsize=_figsize)
    
    fontsize = _fontsize

    ax = sns.lineplot(data=df, x='threshold', y=_column, estimator=np.median, 
                      ci=None, marker='o', markersize=15, palette=sns.color_palette("colorblind"), legend='auto')    
        
    ax.set_yticks([x for x in range(0,110,5)])
    ax.set_ylim(0,100)
    ax.set_xticks([x for x in range(0,110,5)])
    ax.set_xlim(0,105)
        
    ax.set_xticklabels(ax.get_xticks(), fontsize=fontsize-10)
    ax.set_yticklabels(ax.get_yticks(), fontsize=fontsize-10)
        
    if _title:
        ax.set_title(_title, fontsize=fontsize+5, pad=20)
    ax.set_xlabel('Manual threshold', fontsize=fontsize, labelpad=15)
    ax.set_ylabel('NN predicted threshold' if 'nn' in _column else 'SLR estimated threshold', fontsize=fontsize, labelpad=15)
          
    bounds = df.groupby('threshold')[_column].quantile((0.25,0.75)).unstack()
    ax.fill_between(x=bounds.index, y1=bounds.iloc[:,0], y2=bounds.iloc[:,1], alpha=0.1)
    #ax.text(1, 95.5, 'median values and the interquartile range of the NN predicted thresholds', fontsize=fontsize-10)
    
    method = 'NN predicted' if 'nn' in _column else 'SLR estimated'
    fig.legend(['median values and the interquartile range of the %s thresholds' % method], loc='lower left', bbox_to_anchor= (0.065, 1.0), 
               borderaxespad=0, frameon=True, fontsize=fontsize-10)
    fig.tight_layout()
    
    plt.show()
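
The line and the shaded band drawn by `plot_median` are the median and the interquartile range of the automatically determined thresholds per manual threshold. A minimal sketch of that computation on made-up values (the frame `toy` and the table `summary` are illustrative only):

```python
import pandas as pd

# toy values, not taken from the GMC or ING data
toy = pd.DataFrame({
    'threshold':        [20, 20, 20, 20, 40, 40, 40, 40],
    'nn_predicted_thr': [15, 20, 25, 30, 35, 40, 40, 45],
})

grouped = toy.groupby('threshold')['nn_predicted_thr']
summary = pd.DataFrame({
    'median': grouped.median(),      # the plotted line
    'q1': grouped.quantile(0.25),    # lower edge of the shaded band
    'q3': grouped.quantile(0.75),    # upper edge of the shaded band
})
print(summary)
```

A method that agrees perfectly with the manual assessment would produce a median line on the diagonal and a band of zero width.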

In [None]:
def plot_median_stimulus_specific(_data, _column, _title=None,  _fontsize=30, _figsize=(30,15)):
    
    """
    Plots the medians of the automatically determined thresholds against the manually assessed thresholds 
    for each stimulus separately.
    
    Parameters
    ----------
        _data: pandas-data-frame
            It contains both manually assessed and automatically determined thresholds for each mouse identifier 
            and hearing stimulus. 
            
        _column: string
            The name of the column in which the automatically determined thresholds are stored.
    """
    
    df = _data
    
    cols = 2
    rows = 3

    row = 0
    col = 0
    
    fig, axes = plt.subplots(ncols=cols, nrows=rows, sharey=True, sharex=True, 
                             figsize=_figsize, constrained_layout=True)
    fontsize = _fontsize

    for freq in df.frequency.unique():
    
        axes[row, col].set_yticks([x for x in range(0,110,5)])
        axes[row, col].set_ylim(0,100)
        axes[row, col].set_xticks([x for x in range(0,110,5)])
        axes[row, col].set_xlim(0,105)
        
        axes[row, col].set_xticklabels(axes[row, col].get_xticks(), fontsize=fontsize-10)
        axes[row, col].set_yticklabels(axes[row, col].get_yticks(), fontsize=fontsize-10)
    
        if freq == 100:
            axes[row, col].set_title('Click', fontsize=fontsize+5, pad=20)
        else:
            axes[row, col].set_title('%dkHz' % int(freq/1000), fontsize=fontsize, pad=20)
        if row == rows-1:
            axes[row, col].set_xlabel('Manual threshold', fontsize=fontsize, labelpad=15)
        else:
            axes[row, col].set_xlabel('')
        if col == 0:
            axes[row, col].set_ylabel('NN predicted threshold' if 'nn' in _column else 'SLR estimated threshold', fontsize=fontsize)
            
        else:
            axes[row, col].set_ylabel('')
    
        # line plot
        sns.lineplot(data=df.loc[df.frequency==freq], 
                     x='threshold', y=_column, 
                     estimator=np.median, ci=None, marker='o', markersize=15, ax=axes[row, col], 
                     palette=sns.color_palette("colorblind"))
        
        bounds = df.loc[df.frequency==freq].groupby('threshold')[_column].quantile((0.25,0.75)).unstack()
        axes[row, col].fill_between(x=bounds.index, y1=bounds.iloc[:,0], y2=bounds.iloc[:,1], alpha=0.1)
        
        col += 1
        if col == cols:
            row += 1
            col = 0 
    if _title:
        fig.suptitle(_title, fontsize=fontsize)
    
    method = 'NN predicted' if 'nn' in _column else 'SLR estimated'
    fig.legend(['median values and the interquartile range of the %s thresholds' % method], loc='lower left', bbox_to_anchor= (0.035, 1.01), 
               borderaxespad=0, frameon=True, fontsize=fontsize)
    
    plt.show()

In [None]:
def plot_confusion_matrix(_data, _column, _title=None, _fontsize=30, _figsize=(30,15)): 
    
    """
    Plots the confusion matrix for automatically determined threshold values.
    
    Parameters
    ----------
        _data: pandas-data-frame
            It contains both manually assessed and automatically determined thresholds for each mouse identifier 
            and hearing stimulus. 
            
        _column: string
            The name of the column in which the automatically determined thresholds are stored.
    """
    
    predicted_thr = _column
    result_columns = ['mouse_id', 'frequency', 'threshold'] + [_column]
    
    result_test2 = _data[result_columns].drop_duplicates()
    
    confusion_mtx = confusion_matrix(result_test2['threshold'], result_test2[predicted_thr]) 
   
    merged_list = list(itertools.chain(*itertools.zip_longest(result_test2['threshold'].unique(), 
                                                              result_test2[predicted_thr].unique())))
    merged_list = [i for i in merged_list if i is not None]
    merged_list = sorted(set(merged_list))
    
    confusion_mtx = pd.DataFrame(confusion_mtx, index = merged_list, 
                                 columns = merged_list)
    for sl in range(0, 100, 5):
        if sl not in confusion_mtx.index:
            confusion_mtx[sl] = 0
            confusion_mtx.loc[sl] = 0
    confusion_mtx = confusion_mtx.sort_index(axis=1).sort_index()
    
    fig = plt.figure(figsize=_figsize)
    fontsize = _fontsize
    
    ax = sns.heatmap(confusion_mtx, annot=True, fmt="d", 
                     annot_kws={"size": fontsize-15}, cbar_kws={"pad": 0.02},
                     square=True, center=250, vmin=0, vmax=1000, cmap='Spectral_r') 
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top') 
    label = 'NN predicted threshold' if 'nn' in _column else 'SLR estimated threshold'
    
    ax.set_xticklabels(ax.get_xmajorticklabels(), fontsize=fontsize-10)
    ax.set_yticklabels(ax.get_ymajorticklabels(), fontsize=fontsize-10)
    
    ax.set_xlabel(label, fontsize=fontsize, labelpad=15)
    ax.set_ylabel('Manual threshold', fontsize=fontsize, labelpad=15)
    
    cbar = ax.collections[0].colorbar
    cbar.ax.set_yticklabels(cbar.ax.get_ymajorticklabels(), fontsize=fontsize-10)  
    
    if _title:
        ax.set_title(_title, fontsize=fontsize+5, pad=20)
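
For illustration, a confusion matrix on a fixed 5 dB grid can also be built with pandas alone. This is a sketch of an alternative to the `confusion_matrix` call used above, not the notebook's own method. It assumes thresholds lie on the grid from 0 to 95 dB; unlike the function above, any value outside that grid (such as 999) is dropped by the `reindex` call.

```python
import pandas as pd

# toy values, not taken from the GMC or ING data
toy = pd.DataFrame({
    'threshold':        [20, 20, 25, 30, 30],
    'nn_predicted_thr': [20, 25, 25, 30, 35],
})

levels = list(range(0, 100, 5))  # 0, 5, ..., 95 dB

# rows: manual thresholds, columns: automatically determined thresholds
matrix = (pd.crosstab(toy['threshold'], toy['nn_predicted_thr'])
            .reindex(index=levels, columns=levels, fill_value=0))
print(matrix.loc[15:35, 15:35])
```

Entries on the diagonal count exact agreement; entries right of the diagonal are cases where the automatic method returned a higher threshold than the manual assessment.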

In [None]:
def plot_confusion_matrix_stimulus_specific(_data, _column, _title=None, _fontsize=30, _figsize=(30,45)): 
    
    """
    Plots the confusion matrix for automatically determined thresholds for each stimulus separately.
    
    Parameters
    ----------
        _data: pandas-data-frame
            It contains both manually assessed and automatically determined thresholds for each mouse identifier 
            and hearing stimulus. 
            
        _column: string
            The name of the column in which the automatically determined thresholds are stored.
    """
    
    predicted_thr = _column
    result_columns = ['mouse_id', 'frequency', 'threshold'] + [_column]
    
    rows = 3
    cols = 2
    
    row = 0
    col = 0
    
    fig, axes = plt.subplots(ncols=cols, nrows=rows, #gridspec_kw={'wspace': 5},
                             sharey=False, sharex=False, figsize=_figsize, constrained_layout=True)
    fontsize = _fontsize
    
    for freq in _data.frequency.unique():
        
        result_test2 = _data[_data.frequency == freq][result_columns].drop_duplicates()
    
        confusion_mtx = confusion_matrix(result_test2['threshold'], result_test2[predicted_thr]) 
        merged_list = list(itertools.chain(*itertools.zip_longest(result_test2['threshold'].unique(), 
                                                              result_test2[predicted_thr].unique())))
        merged_list = [i for i in merged_list if i is not None]
        merged_list = sorted(set(merged_list))
    
        confusion_mtx = pd.DataFrame(confusion_mtx, index = merged_list, 
                                     columns = merged_list)
        for sl in range(0, 100, 5):
            if sl not in confusion_mtx.index:
                confusion_mtx[sl] = 0
                confusion_mtx.loc[sl] = 0
        confusion_mtx = confusion_mtx.sort_index(axis=1).sort_index()
        
        sns.heatmap(confusion_mtx, annot=True, fmt="d", 
                    annot_kws={"size": fontsize-15}, cbar_kws={"pad": 0.02, 'shrink': 0.9},
                    square=True, center=250, vmin=0, vmax=300, cmap='Spectral_r', ax=axes[row,col]) 
        axes[row,col].xaxis.tick_top()
        axes[row,col].xaxis.set_label_position('top') 
        
        axes[row,col].set_xticklabels(axes[row,col].get_xmajorticklabels(), fontsize=fontsize-10)
        axes[row,col].set_yticklabels(axes[row,col].get_ymajorticklabels(), fontsize=fontsize-10)
        
        label = 'NN predicted threshold' if 'nn' in _column else 'SLR estimated threshold'
        axes[row,col].set_xlabel(label, fontsize=fontsize, labelpad=15)
        if col == 0:
            axes[row,col].set_ylabel('Manual threshold', fontsize=fontsize, labelpad=15)
        if freq == 100:
            axes[row,col].set_title('Click', fontsize=fontsize, pad=15)
        else:
            axes[row,col].set_title('%d kHz' % (freq/1000), fontsize=fontsize, pad=15)
            
        cbar = axes[row,col].collections[0].colorbar
        cbar.ax.set_yticklabels(cbar.ax.get_ymajorticklabels(), fontsize=fontsize-10)  
        
        col += 1
        if col == cols:
            row += 1
            col = 0
    
    if _title:
        fig.suptitle(_title, fontsize=_fontsize)
    else:
        fig.suptitle('NN models' if 'nn' in _column else 'SLR method', fontsize=_fontsize-10)
       

In [None]:
def double_std(array):
    return np.std(array) * 2

def q1(x):
    return x.quantile(0.25)

def q3(x):
    return x.quantile(0.75)

def plot_threshold_stats(_data, _columns=['threshold', 'nn_predicted_thr'], 
                         _stat='mean', _fontsize=20):
    
    # map 999 to 100
    df = _data.copy()
    for col in _columns:
        df[col] = df[col].replace(999, 100)
    
    df1 = df[['frequency', _columns[0]]].copy()   
    df2 = df[['frequency', _columns[1]]].copy()
    
    if 'nn' in _columns[1]:
        colors = ['green', 'mediumblue']
    else:
        colors = ['red', 'mediumblue']
        
    df1['threshold_type']=['Click: manual' if df1.at[idx, 'frequency']==100 else '%dkHz: manual' % (df1.at[idx, 'frequency']/1000) for idx in df1.index]
    df2.rename(columns={_columns[1]: 'threshold'}, inplace=True)
    if 'nn' in _columns[1]:
        df2['threshold_type']=['Click: nn predicted' if df2.at[idx, 'frequency']==100 else '%dkHz: nn predicted' % (df2.at[idx, 'frequency']/1000) for idx in df2.index]
    else:
        df2['threshold_type']=['Click: slr estimated' if df2.at[idx, 'frequency']==100 else '%dkHz: slr estimated' % (df2.at[idx, 'frequency']/1000) for idx in df2.index]
        
    df = pd.concat([df1, df2]).reset_index(drop=True)
    df = df[df.columns.drop('frequency')]
    thresholds = df.groupby(['threshold_type']).agg(
        [np.mean, np.std, double_std, sp.sem, np.median, q1, q3])
        
    if 'nn' in _columns[1]:
        row_type = 'nn predicted'
    else:
        row_type = 'slr estimated'
    order = 0
    for freq in sorted(_data.frequency.unique(), reverse=True):
        if freq > 100:
            thresholds.at['%dkHz: %s' % ((freq/1000), row_type), 'order'] = order
            order+=1
            thresholds.at['%dkHz: manual' % (freq/1000), 'order'] = order
            order+=1             
    thresholds.at['Click: %s' % row_type, 'order'] = order
    order+=1
    thresholds.at['Click: manual', 'order'] = order
    thresholds.sort_values(by='order', inplace=True)        
        
    xerr = 'std'
    _title = 'Threshold Mean'
    if _stat == 'median':
        xerr = None
        _title = 'Threshold Median'
        
    ax = thresholds['threshold'].plot(kind='barh', y=_stat, legend=False, grid=True, zorder=3,
                                      color=colors, xerr=xerr, capsize=4, 
                                      fontsize=_fontsize)
    ax.grid(zorder=0, color='lightgray')
    ax.set_title(_title, pad=15)
    ax.title.set_size(_fontsize+5)
        
    if _stat == 'median':
        for idx, thr_type in enumerate(thresholds['threshold'].index): 
                
            plt.hlines(y=thr_type, 
                       xmin=thresholds['threshold'].at[thr_type, 'q1'], xmax=thresholds['threshold'].at[thr_type, 'q3'], 
                       colors='black', linewidths=1.5)
            plt.vlines(x=thresholds['threshold'].at[thr_type, 'q1'], ymin=idx-0.12, ymax=idx+0.12, colors='black', linewidths=1)
            plt.vlines(x=thresholds['threshold'].at[thr_type, 'q3'], ymin=idx-0.12, ymax=idx+0.12, colors='black', linewidths=1)
        
    ax.set_xlabel('dB', fontsize=_fontsize)
    ax.set_ylabel('')
        
    for key, spine in ax.spines.items():
        spine.set_visible(False)
    ax.tick_params(bottom=False, left=False)
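
The summary table behind `plot_threshold_stats` can be inspected without plotting. The sketch below uses pandas named aggregation on made-up values. The frame `toy` is illustrative, and `'std'` here is the sample standard deviation (ddof=1), which is the pandas default.

```python
import pandas as pd

# toy values, not taken from the GMC or ING data; frequency 100 encodes the click stimulus
toy = pd.DataFrame({
    'frequency': [100, 100, 100, 6000, 6000, 6000],
    'threshold': [20, 30, 40, 50, 60, 999],
})
toy['threshold'] = toy['threshold'].replace(999, 100)  # map 999 to 100, as in the functions above

stats = toy.groupby('frequency')['threshold'].agg(
    mean='mean',
    std='std',
    median='median',
    q1=lambda x: x.quantile(0.25),
    q3=lambda x: x.quantile(0.75),
)
print(stats)
```

The mean and standard deviation are sensitive to the value chosen to replace 999, whereas the median and the quartiles are not as long as fewer than a quarter of the values are affected.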

In [None]:
def plot_threshold_boxplots(_data, _columns=['threshold', 'nn_predicted_thr'], 
                            _figsize=(15, 10), _fontsize=20):
    
    # map 999 to 100
    df = _data.copy()
    for col in _columns:
        df[col] = df[col].replace(999, 100)
    
    df1 = df[['frequency', _columns[0]]].copy()   
    df2 = df[['frequency', _columns[1]]].copy()
    
    df1['threshold type']=['Click: manual' if df1.at[idx, 'frequency']==100 else '%dkHz: manual' % (df1.at[idx, 'frequency']/1000) for idx in df1.index]
    df2.rename(columns={_columns[1]: 'threshold'}, inplace=True)
    if 'nn' in _columns[1]:
        df2['threshold type']=['Click: nn predicted' if df2.at[idx, 'frequency']==100 else '%dkHz: nn predicted' % (df2.at[idx, 'frequency']/1000) for idx in df2.index]
    else:
        df2['threshold type']=['Click: slr estimated' if df2.at[idx, 'frequency']==100 else '%dkHz: slr estimated' % (df2.at[idx, 'frequency']/1000) for idx in df2.index]
        
    df = pd.concat([df1, df2]).reset_index(drop=True)
    df.rename(columns={'threshold': 'threshold (dB)'}, inplace=True)
    
    if 'nn' in _columns[1]:
        colors = ['mediumblue', 'green']
    else:
        colors = ['mediumblue', 'red']
        
    if 'nn' in _columns[1]:
        row_type = 'nn predicted'
    else:
        row_type = 'slr estimated'
    
    fig, ax = plt.subplots(figsize=_figsize)
    ax.set_ylabel('dB', fontsize=_fontsize, labelpad=15)
    
    # Remove top and right border
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    
    # Remove y-axis tick marks
    ax.yaxis.set_ticks_position('none')
    
    # Add major gridlines
    ax.grid(zorder=0, color='lightgray', linestyle='-')#, linewidth=0.25, alpha=0.5)
    ax.title.set_size(_fontsize+5)
    
    # Set threshold types as labels for the boxplot
    labels = ['Click: manual', 'Click: %s' % row_type]
    for freq in sorted(df.frequency.unique()):
        if freq > 100:
            labels.append('%dkHz: manual' % (freq/1000))
            labels.append('%dkHz: %s' % ((freq/1000), row_type))
            
    dataset = [df[df['threshold type'] == thr_type]['threshold (dB)'] for thr_type in labels]

    """
    We want to apply different properties to each threshold type, so we're going to plot one boxplot
    for each threshold type and set their properties individually
        positions: position of the boxplot in the plot area
        medianprops: dictionary of properties applied to median line
        whiskerprops: dictionary of properties applied to the whiskers
        capprops: dictionary of properties applied to the caps on the whiskers
        flierprops: dictionary of properties applied to outliers
    """
    
    medianprops = dict(linestyle='-', linewidth=3, color='k')
    
    for idx,ds in enumerate(dataset):
        ax.boxplot(ds, zorder=3,  
                   positions=[idx+1], labels=[labels[idx]], 
                   boxprops=dict(color=colors[idx%2], linewidth=3), 
                   medianprops=medianprops, 
                   whiskerprops=dict(color=colors[idx%2]), 
                   capprops=dict(color=colors[idx%2]), 
                   flierprops=dict(markeredgecolor=colors[idx%2]))
    
    plt.xticks(fontsize=_fontsize, rotation = 90)
    plt.yticks(fontsize=_fontsize)
    plt.show()

In [None]:
def plot_evaluation_curves(_dataset, _xlabel, _ylabel, _fontsize=30, _figsize=(20, 13)):
    
    # name columns containing the ABR wave time series data
    timeseries_columns = ['t%d' %i for i in range(1000)] 
    
    thresholds = ['threshold manual', 'threshold SLR', 'threshold NN', 50]
    fig = plt.figure(constrained_layout=True, figsize=_figsize)
    
    ncols = 3
    nrows = int(np.ceil(len(_dataset.frequency.unique())/ncols))
    col = 0
    row = 0
    spec = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)
    ax = {}
    
    legend_elements = []
    for idx, freq in enumerate(sorted(_dataset.frequency.unique())):
        if idx == 0:
            ax[idx] = fig.add_subplot(spec[row, col])
        else:
            ax[idx] = fig.add_subplot(spec[row, col], sharex=ax[idx-1], sharey=ax[idx-1])
        if freq == 100:
            ax[idx].set_title('Click', fontsize=_fontsize, y=1.01)#, fontweight='bold'
        else:
            ax[idx].set_title('%dkHz' % int(float(freq)/1000), fontsize=_fontsize, y=1.01)#, fontweight='bold'
    
        ax[idx].set_xticks([0.25,0.5,0.75,1.0])
        ax[idx].tick_params(axis='x', labelsize=_fontsize-5)
        ax[idx].tick_params(axis='y', labelsize=_fontsize-5)
        # ax[idx].grid(color='lightgray')
    
        # Label only the outer panels: y-label in the first column, x-label in the last row
        legend_elements = plot_evaluation_curve_for_specific_stimulus(_dataset, freq, thresholds,
                                                                      timeseries_columns,
                                                                      frequency = 'frequency',
                                                                      sound_level = 'sound_level', 
                                                                      fontsize=_fontsize,
                                                                      legend=False,
                                                                      xlabel=_xlabel if row == nrows-1 else None,
                                                                      ylabel=_ylabel if col == 0 else None,
                                                                      ax=ax[idx])
    
        col+=1
        if col == ncols:
            row+=1
            col=0
            
    leg = fig.legend(handles=legend_elements, loc='lower left', bbox_to_anchor= (0.061, 1.01), ncol=2, 
                     borderaxespad=0, frameon=True, fontsize=_fontsize, title_fontsize=_fontsize-5)
    leg._legend_box.align = "left"
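
The panel layout used in `plot_evaluation_curves` (and in the other stimulus-specific plots) follows a simple row-major rule. As a sketch, the hypothetical helper `grid_positions` below computes the same row and column indices without manual counters and shows which panels carry axis labels:

```python
import math

def grid_positions(n_panels, ncols=3):
    """Return the number of rows and the (row, col) position of each panel in row-major order."""
    nrows = math.ceil(n_panels / ncols)
    return nrows, [divmod(idx, ncols) for idx in range(n_panels)]

nrows, positions = grid_positions(6)
for idx, (row, col) in enumerate(positions):
    show_xlabel = row == nrows - 1   # only the last row gets an x-label
    show_ylabel = col == 0           # only the first column gets a y-label
    print(idx, (row, col), show_xlabel, show_ylabel)
```

Rounding the number of rows up with `math.ceil` keeps the layout valid when the number of stimuli is not a multiple of the number of columns.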

# Load GMC data

Load the ABR curves from the German Mouse Clinic and the lists of mice used to train, validate and test the NN/SLR models.<br>
The files can be found under the path specified by _path2data_:

* _GMC/GMC_abr_curves.csv_
* _GMC/GMC_train_mice.npy_
* _GMC/GMC_valid_mice.npy_
* _GMC/GMC_test_mice.npy_

## Load the manually assessed thresholds

In [None]:
"""Load the ABR curves from GMC"""
GMC_data = pd.read_csv(os.path.join(path2data, 'GMC', 'GMC_abr_curves.csv'), low_memory=False)

In [None]:
"""Load training, validation and test mouse ids"""
GMC_mice = {}
for group in ['train', 'valid', 'test']:
    GMC_mice[group] = np.load(os.path.join(path2data, 'GMC', 'GMC_'+group+'_mice.npy'))
    print('GMC %s mice: %d' % (group, len(GMC_mice[group])))
    
GMC_data['mouse_group'] = ['train' if mouse_id in GMC_mice['train'] else 'valid' if mouse_id in GMC_mice['valid'] else 'test' for mouse_id in GMC_data['mouse_id']]
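
The group assignment above checks list membership once per row. On large tables a vectorised variant may be faster. The following sketch uses `Series.isin` and `np.select` on made-up mouse IDs. `toy_mice` and `toy_data` are illustrative names, and, as in the cell above, every mouse that is in neither the training nor the validation list is labelled 'test'.

```python
import numpy as np
import pandas as pd

# toy mouse IDs, not taken from the GMC data
toy_mice = {'train': np.array([1, 2]), 'valid': np.array([3]), 'test': np.array([4])}
toy_data = pd.DataFrame({'mouse_id': [1, 3, 4, 2, 5]})

conditions = [
    toy_data['mouse_id'].isin(toy_mice['train']),
    toy_data['mouse_id'].isin(toy_mice['valid']),
]
# rows matching neither condition fall through to the default
toy_data['mouse_group'] = np.select(conditions, ['train', 'valid'], default='test')
print(toy_data)
```

Mouse 5 appears in none of the lists and still ends up in the 'test' group, so an explicit check against the test list may be worthwhile if unknown IDs can occur.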

## Load the NN predicted thresholds

In [None]:
"""Load predictions by neural networks trained with GMC data (GMCtrained_NN)"""
GMC_data_predictions1 = pd.read_csv(os.path.join(path2results, 'GMC_data_GMCtrained_NN_predictions.csv'))

In [None]:
"""Load predictions by neural networks trained with Ingham et al. data (INGtrained_NN)"""
GMC_data_predictions2 = pd.read_csv(os.path.join(path2results, 'GMC_data_INGtrained_NN_predictions.csv'))

## Load the SLR estimated thresholds

In [None]:
"""Load estimations made by SLR calibrated with GMC training data (GMCcalibrated_SLR)"""
GMC_data_estimations1 = pd.read_csv(os.path.join(path2results, 'GMC_data_GMCcalibrated_SLR_estimations.csv'))

In [None]:
"""Load estimations made by SLR calibrated with ING training data (INGcalibrated_SLR)"""
GMC_data_estimations2 = pd.read_csv(os.path.join(path2results, 'GMC_data_INGcalibrated_SLR_estimations.csv'))

# Load ING data

Load the ABR curves provided by Ingham et al. and the lists of mice used to train, validate and test the NN/SLR models.</br>
The files can be found under the path specified by _path2data_:
* _ING/ING_abr_curves.csv_
* _ING/ING_train_mice.npy_
* _ING/ING_valid_mice.npy_
* _ING/ING_test_mice.npy_

In [None]:
"""Load the ING ABR curves"""
ING_data = pd.read_csv(os.path.join(path2data, 'ING', 'ING_abr_curves.csv'), low_memory=False)

In [None]:
"""Load training, validation and test mouse IDs"""

# Call np.load with allow_pickle=True (required when a .npy file stores an object array)
ING_mice = {}
for _ in ['train', 'valid', 'test']:
    ING_mice[_] = np.load(os.path.join(path2data, 'ING', 'ING_'+_+'_mice.npy'), allow_pickle=True)
    print('ING %s mice: %d' % (_, len(ING_mice[_])))

## Load NN predicted thresholds

In [None]:
"""Load predictions by neural networks trained with GMC data (GMCtrained_NN)"""
ING_data_predictions1 = pd.read_csv(os.path.join(path2results, 'ING_data_GMCtrained_NN_predictions.csv'))

In [None]:
"""Load predictions by neural networks trained with ING data (INGtrained_NN)"""
ING_data_predictions2 = pd.read_csv(os.path.join(path2results, 'ING_data_INGtrained_NN_predictions.csv'))

## Load SLR estimated thresholds

In [None]:
"""Load estimations by SLR calibrated with GMC training data (GMCcalibrated_SLR)"""
ING_data_estimations1 = pd.read_csv(os.path.join(path2results, 'ING_data_GMCcalibrated_SLR_estimations.csv'))

In [None]:
"""Load estimations by SLR calibrated with ING training data (INGcalibrated_SLR)"""
ING_data_estimations2 = pd.read_csv(os.path.join(path2results, 'ING_data_INGcalibrated_SLR_estimations.csv'))

# Evaluation approaches 

+ threshold detection accuracy
+ standardised error plots - threshold detection absolute error standardised to the hearing threshold variance
+ relationship between manually and ML-detected thresholds using the median as aggregation method
+ confusion matrix plots
+ stimulus specific threshold averages
+ stimulus specific threshold medians
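
As a rough illustration of two of the approaches listed above, the following sketch computes a standardised error and a confusion matrix on made-up data. It assumes that "standardised" means the absolute error divided by the standard deviation of the manually assessed thresholds; the notebook's own `plot_standardised_error` and `plot_confusion_matrix` helpers may normalise or bin differently. The column names mirror those used in this notebook, and all values are invented.

```python
import pandas as pd

# Toy data: column names follow the notebook, values are made up for illustration.
toy = pd.DataFrame({
    'threshold':        [20, 20, 40, 60, 60, 80],   # manually assessed thresholds
    'nn_predicted_thr': [20, 25, 40, 55, 70, 80],   # hypothetical predictions
})

# Absolute threshold error per assessment.
abs_error = (toy['nn_predicted_thr'] - toy['threshold']).abs()

# Assumption: standardise by the standard deviation of the manual thresholds.
standardised_error = abs_error / toy['threshold'].std(ddof=0)

# Confusion matrix: rows are manual thresholds, columns are predicted thresholds.
confusion = pd.crosstab(toy['threshold'], toy['nn_predicted_thr'])
```

Entries on the diagonal of `confusion` count exact agreements; off-diagonal entries show how far and in which direction the predictions deviate.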

## Experiment 1

NNs trained with the GMC training data set and tested on the GMC test data set:
* NN GMC-GMC

In [None]:
GMC_test_data_predictions1 = GMC_data_predictions1[GMC_data_predictions1.mouse_id.isin(GMC_mice['test'])]
title = 'GMC trained NNs / GMC test data set'

In [None]:
"""Print overall mouse metrics"""
abrthr.print_overall_mouse_metrics(
    GMC_test_data_predictions1, _predicted_thr_col='nn_predicted_thr')

In [None]:
"""Plot of the threshold error normalised to the variance, overall and stimulus specific"""

xlabel = 'Threshold error normalised to the variance'
ylabel = 'Threshold assessments in percent'

plot_standardised_error(GMC_test_data_predictions1, 
                        'threshold', 'nn_predicted_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel)
plot_standardised_error(GMC_test_data_predictions1, 
                       'threshold', 'nn_predicted_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel,
                        _freq_specific=True, 
                        _figsize=(30,30))

In [None]:
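# Median NN prediction for curves with a manual threshold of 15.
# Uses the test data directly because `df` is only defined in the next cell;
# 999 values are therefore not remapped to 100 here.
np.median(GMC_test_data_predictions1[GMC_test_data_predictions1.threshold == 15]['nn_predicted_thr'])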

In [None]:
"""Plot of the medians of the NN predicted thresholds against the manually assessed thresholds, overall and stimulus specific"""

# map 999 to 100
df = GMC_test_data_predictions1.copy()
for col in ['threshold', 'nn_predicted_thr']:
    df[col] = df[col].replace(999, 100)

with sns.axes_style("whitegrid"):
    plot_median(df, 'nn_predicted_thr', _figsize=(21,20))
    plot_median_stimulus_specific(df, 'nn_predicted_thr', _figsize=(30,45))

In [None]:
"""Plot of the confusion matrix, overall and stimulus specific"""

plot_confusion_matrix(GMC_test_data_predictions1, 'nn_predicted_thr')
plot_confusion_matrix_stimulus_specific(GMC_test_data_predictions1, 'nn_predicted_thr', _figsize=(30,45), _title=' ')

In [None]:
"""Stimulus specific plots of the threshold mean as horizontal bars"""
plot_threshold_stats(GMC_test_data_predictions1)

"""Stimulus specific plots of the threshold median as horizontal bars"""
plot_threshold_stats(GMC_test_data_predictions1, _stat='median')

"""Plot of threshold value boxplots grouped by threshold type"""
plot_threshold_boxplots(GMC_test_data_predictions1)

## Experiment 2

NNs trained with the GMC training data set and tested on ING data:
* NN GMC-ING

In [None]:
title = 'GMC trained NNs / ING data set'

In [None]:
"""Print overall mouse metrics"""
abrthr.print_overall_mouse_metrics(ING_data_predictions1, 
                                  _predicted_thr_col='nn_predicted_thr')

In [None]:
"""Plot of the threshold error normalised to the variance, overall and stimulus specific"""

xlabel = 'Threshold error normalised to the variance'
ylabel = 'Threshold assessments in percent'

plot_standardised_error(ING_data_predictions1, 
                        'threshold', 'nn_predicted_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel)
plot_standardised_error(ING_data_predictions1, 
                        'threshold', 'nn_predicted_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel, 
                        _freq_specific=True, 
                        _figsize=(30,30))

In [None]:
"""Plot of the medians of the NN predicted thresholds against the manually assessed thresholds, overall and stimulus specific"""

# map 999 to 100
df = ING_data_predictions1.copy()
for col in ['threshold', 'nn_predicted_thr']:
    df[col] = df[col].replace(999, 100)

with sns.axes_style("whitegrid"):
    plot_median(df, 'nn_predicted_thr', _figsize=(21,20))
    plot_median_stimulus_specific(df, 'nn_predicted_thr', _figsize=(30,45))

In [None]:
"""Plot of the confusion matrix, overall and stimulus specific"""

plot_confusion_matrix(ING_data_predictions1, 'nn_predicted_thr')
plot_confusion_matrix_stimulus_specific(ING_data_predictions1, 'nn_predicted_thr', _figsize=(30,45), _title=' ')

In [None]:
"""Stimulus specific plots of the threshold mean as horizontal bars"""
plot_threshold_stats(ING_data_predictions1)

"""Stimulus specific plots of the threshold median as horizontal bars"""
plot_threshold_stats(ING_data_predictions1, _stat='median')

"""Plot of threshold value boxplots grouped by threshold type"""
plot_threshold_boxplots(ING_data_predictions1)

## Experiment 3

NNs trained with the ING training data set and tested on GMC data:
* NN ING-GMC

In [None]:
title = 'ING trained NNs / GMC data set'

In [None]:
"""Print overall mouse metrics"""
abrthr.print_overall_mouse_metrics(GMC_data_predictions2, 
                                  _predicted_thr_col='nn_predicted_thr')

In [None]:
"""Plot of the threshold error normalised to the variance, overall and stimulus specific"""

xlabel = 'Threshold error normalised to the variance'
ylabel = 'Threshold assessments in percent'

plot_standardised_error(GMC_data_predictions2, 
                        'threshold', 'nn_predicted_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel)
plot_standardised_error(GMC_data_predictions2, 
                        'threshold', 'nn_predicted_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel, 
                        _freq_specific=True, 
                        _figsize=(30,30))

In [None]:
"""Plot of the medians of the NN predicted thresholds against the manually assessed thresholds, overall and stimulus specific"""

# map 999 to 100
df = GMC_data_predictions2.copy()
for col in ['threshold', 'nn_predicted_thr']:
    df[col] = df[col].replace(999, 100)

with sns.axes_style("whitegrid"):
    plot_median(df, 'nn_predicted_thr', _figsize=(21,20))
    plot_median_stimulus_specific(df, 'nn_predicted_thr', _figsize=(30,45))

In [None]:
"""Plot of the confusion matrix, overall and stimulus specific"""

plot_confusion_matrix(GMC_data_predictions2, 'nn_predicted_thr')
plot_confusion_matrix_stimulus_specific(GMC_data_predictions2, 'nn_predicted_thr', _figsize=(30,45), _title=' ')

In [None]:
"""Stimulus specific plots of the threshold mean as horizontal bars"""
plot_threshold_stats(GMC_data_predictions2)

"""Stimulus specific plots of the threshold median as horizontal bars"""
plot_threshold_stats(GMC_data_predictions2, _stat='median')

"""Plot of threshold value boxplots grouped by threshold type"""
plot_threshold_boxplots(GMC_data_predictions2)

## Experiment 4

NNs trained with the ING training data set and tested on the ING test data set:
* NN ING-ING

In [None]:
ING_test_data_predictions2 = ING_data_predictions2[ING_data_predictions2.mouse_id.isin(ING_mice['test'])]
title = 'ING trained NNs / ING test data set'

In [None]:
"""Print overall mouse metrics"""
abrthr.print_overall_mouse_metrics(ING_test_data_predictions2, 
                                  _predicted_thr_col='nn_predicted_thr')

In [None]:
"""Plot of the threshold error normalised to the variance, overall and stimulus specific"""

xlabel = 'Threshold error normalised to the variance'
ylabel = 'Threshold assessments in percent'

plot_standardised_error(ING_test_data_predictions2, 
                        'threshold', 'nn_predicted_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel)
plot_standardised_error(ING_test_data_predictions2, 
                        'threshold', 'nn_predicted_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel, 
                        _freq_specific=True, 
                        _figsize=(30,30))

In [None]:
"""Plot of the medians of the NN predicted thresholds against the manually assessed thresholds, overall and stimulus specific"""

# map 999 to 100
df = ING_test_data_predictions2.copy()
for col in ['threshold', 'nn_predicted_thr']:
    df[col] = df[col].replace(999, 100)
    
with sns.axes_style("whitegrid"):
    plot_median(df, 'nn_predicted_thr', _figsize=(21,20))
    plot_median_stimulus_specific(df, 'nn_predicted_thr', _figsize=(30,45))

In [None]:
"""Plot of the confusion matrix, overall and stimulus specific"""

plot_confusion_matrix(ING_test_data_predictions2, 'nn_predicted_thr')
plot_confusion_matrix_stimulus_specific(ING_test_data_predictions2, 'nn_predicted_thr', _figsize=(30,45), _title=' ')

In [None]:
"""Stimulus specific plots of the threshold mean as horizontal bars"""
plot_threshold_stats(ING_test_data_predictions2)

"""Stimulus specific plots of the threshold median as horizontal bars"""
plot_threshold_stats(ING_test_data_predictions2, _stat='median')

"""Plot of threshold value boxplots grouped by threshold type"""
plot_threshold_boxplots(ING_test_data_predictions2)

## Experiment 5

SLR method calibrated with the GMC training data set and tested on GMC data:
* SLR GMC-GMC

In [None]:
title = 'GMC calibrated SLR method / GMC data set'

In [None]:
"""Print overall mouse metrics"""
abrthr.print_overall_mouse_metrics(GMC_data_estimations1, 
                                  _predicted_thr_col='slr_estimated_thr')

In [None]:
"""Plot of the threshold error normalised to the variance, overall and stimulus specific"""

xlabel = 'Threshold error normalised to the variance'
ylabel = 'Threshold assessments in percent'

plot_standardised_error(GMC_data_estimations1, 
                        'threshold', 'slr_estimated_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel)
plot_standardised_error(GMC_data_estimations1, 
                        'threshold', 'slr_estimated_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel, 
                        _freq_specific=True, 
                        _figsize=(30,30))

In [None]:
"""Plot of the medians of the SLR estimated thresholds against the manually assessed thresholds, overall and stimulus specific"""

# map 999 to 100
df = GMC_data_estimations1.copy()
for col in ['threshold', 'slr_estimated_thr']:
    df[col] = df[col].replace(999, 100)

with sns.axes_style("whitegrid"):
    plot_median(df, 'slr_estimated_thr', _figsize=(21,20))
    plot_median_stimulus_specific(df, 'slr_estimated_thr', _figsize=(30,45))

In [None]:
"""Plot of the confusion matrix, overall and stimulus specific"""

plot_confusion_matrix(GMC_data_estimations1, 'slr_estimated_thr')
plot_confusion_matrix_stimulus_specific(GMC_data_estimations1, 'slr_estimated_thr', _figsize=(30,45), _title=' ')

In [None]:
"""Stimulus specific plots of the threshold mean as horizontal bars"""
plot_threshold_stats(GMC_data_estimations1, _columns=['threshold', 'slr_estimated_thr'])

"""Stimulus specific plots of the threshold median as horizontal bars"""
plot_threshold_stats(GMC_data_estimations1, _stat='median', _columns=['threshold', 'slr_estimated_thr'])

"""Plot of threshold value boxplots grouped by threshold type"""
plot_threshold_boxplots(GMC_data_estimations1, _columns=['threshold', 'slr_estimated_thr'])

## Experiment 6

SLR method calibrated with the GMC training data set and tested on ING data:
* SLR GMC-ING

In [None]:
title = 'GMC calibrated SLR method / ING data set'

In [None]:
"""Print overall mouse metrics"""
abrthr.print_overall_mouse_metrics(ING_data_estimations1, 
                                  _predicted_thr_col='slr_estimated_thr')

In [None]:
"""Plot of the threshold error normalised to the variance, overall and stimulus specific"""

xlabel = 'Threshold error normalised to the variance'
ylabel = 'Threshold assessments in percent'

plot_standardised_error(ING_data_estimations1, 
                        'threshold', 'slr_estimated_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel)
plot_standardised_error(ING_data_estimations1, 
                        'threshold', 'slr_estimated_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel, 
                        _freq_specific=True, 
                        _figsize=(30,30))

In [None]:
"""Plot of the medians of the SLR estimated thresholds against the manually assessed thresholds, overall and stimulus specific"""

# map 999 to 100
df = ING_data_estimations1.copy()
for col in ['threshold', 'slr_estimated_thr']:
    df[col] = df[col].replace(999, 100)

with sns.axes_style("whitegrid"):
    plot_median(df, 'slr_estimated_thr', _figsize=(21,20))
    plot_median_stimulus_specific(df, 'slr_estimated_thr', _figsize=(30,45))

In [None]:
"""Plot of the confusion matrix, overall and stimulus specific"""

plot_confusion_matrix(ING_data_estimations1, 'slr_estimated_thr')
plot_confusion_matrix_stimulus_specific(ING_data_estimations1, 'slr_estimated_thr', _figsize=(30,45), _title=' ')

In [None]:
"""Stimulus specific plots of the threshold mean as horizontal bars"""
plot_threshold_stats(ING_data_estimations1, _columns=['threshold', 'slr_estimated_thr'])

"""Stimulus specific plots of the threshold median as horizontal bars"""
plot_threshold_stats(ING_data_estimations1, _stat='median', _columns=['threshold', 'slr_estimated_thr'])

"""Plot of threshold value boxplots grouped by threshold type"""
plot_threshold_boxplots(ING_data_estimations1, _columns=['threshold', 'slr_estimated_thr'])

## Experiment 7 

SLR method calibrated with the ING training data set and tested on GMC data:
* SLR ING-GMC

In [None]:
title = 'ING calibrated SLR method / GMC data set'

In [None]:
"""Print overall mouse metrics"""
abrthr.print_overall_mouse_metrics(GMC_data_estimations2, 
                                  _predicted_thr_col='slr_estimated_thr')

In [None]:
"""Plot of the threshold error normalised to the variance, overall and stimulus specific"""

xlabel = 'Threshold error normalised to the variance'
ylabel = 'Threshold assessments in percent'

plot_standardised_error(GMC_data_estimations2, 
                        'threshold', 'slr_estimated_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel)
plot_standardised_error(GMC_data_estimations2, 
                        'threshold', 'slr_estimated_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel, 
                        _freq_specific=True, 
                        _figsize=(30,30))

In [None]:
"""Plot of the medians of the SLR estimated thresholds against the manually assessed thresholds, overall and stimulus specific"""

# map 999 to 100
df = GMC_data_estimations2.copy()
for col in ['threshold', 'slr_estimated_thr']:
    df[col] = df[col].replace(999, 100)
    
with sns.axes_style("whitegrid"):
    plot_median(df, 'slr_estimated_thr', _figsize=(21,20))
    plot_median_stimulus_specific(df, 'slr_estimated_thr', _figsize=(30,45))

In [None]:
"""Plot of the confusion matrix overall and stimulus specific"""

plot_confusion_matrix(GMC_data_estimations2, 'slr_estimated_thr')
plot_confusion_matrix_stimulus_specific(GMC_data_estimations2, 'slr_estimated_thr', _figsize=(30,45), _title=' ')

In [None]:
"""Stimulus specific plots of the threshold mean as horizontal bars"""
plot_threshold_stats(GMC_data_estimations2, _columns=['threshold', 'slr_estimated_thr'])

"""Stimulus specific plots of the threshold median as horizontal bars"""
plot_threshold_stats(GMC_data_estimations2, _stat='median', _columns=['threshold', 'slr_estimated_thr'])

"""Plot of threshold value boxplots grouped by threshold type"""
plot_threshold_boxplots(GMC_data_estimations2, _columns=['threshold', 'slr_estimated_thr'])

## Experiment 8 

SLR method calibrated with the ING training data set and tested on ING data:
* SLR ING-ING

In [None]:
title = 'ING calibrated SLR method / ING data set'

In [None]:
"""Print overall mouse metrics"""
abrthr.print_overall_mouse_metrics(ING_data_estimations2, 
                                  _predicted_thr_col='slr_estimated_thr')

In [None]:
"""Plot of the threshold error normalised to the variance, overall and stimulus specific"""

xlabel = 'Threshold error normalised to the variance'
ylabel = 'Threshold assessments in percent'

plot_standardised_error(ING_data_estimations2, 
                        'threshold', 'slr_estimated_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel)
plot_standardised_error(ING_data_estimations2, 
                        'threshold', 'slr_estimated_thr', 
                        _title='', _xlabel=xlabel, _ylabel=ylabel, 
                        _freq_specific=True, 
                        _figsize=(30,30))

In [None]:
"""Plot of the medians of the SLR estimated thresholds against the manually assessed thresholds, overall and stimulus specific"""

# map 999 to 100
df = ING_data_estimations2.copy()
for col in ['threshold', 'slr_estimated_thr']:
    df[col] = df[col].replace(999, 100)

with sns.axes_style("whitegrid"):
    plot_median(df, 'slr_estimated_thr', _figsize=(21,20))
    plot_median_stimulus_specific(df, 'slr_estimated_thr', _figsize=(30,45))

In [None]:
"""Plot of the confusion matrix, overall and stimulus specific"""

plot_confusion_matrix(ING_data_estimations2, 'slr_estimated_thr')
plot_confusion_matrix_stimulus_specific(ING_data_estimations2, 'slr_estimated_thr', _figsize=(30,45), _title=' ')

In [None]:
"""Stimulus specific plots of the threshold mean as horizontal bars"""
plot_threshold_stats(ING_data_estimations2, _columns=['threshold', 'slr_estimated_thr'])

"""Stimulus specific plots of the threshold median as horizontal bars"""
plot_threshold_stats(ING_data_estimations2, _stat='median', _columns=['threshold', 'slr_estimated_thr'])

"""Plot of threshold value boxplots grouped by threshold type"""
plot_threshold_boxplots(ING_data_estimations2, _columns=['threshold', 'slr_estimated_thr'])

# Evaluation curves

Four threshold types are evaluated and compared using **evaluation curves**:

+ the thresholds predicted with neural networks ('threshold NN')
+ the thresholds estimated by a sound level regression method ('threshold SLR')
+ the human ground truth ('threshold manual')
+ a constant threshold ('50')

### Evaluation curves - brief description:

For each ABR curve, a threshold-normalized sound level is computed: threshold_normalized_sound_level = sound_level/threshold.
The ABR curves are then sorted by threshold_normalized_sound_level in increasing order.

Next, we compute the signal strength of the average over the N ABR curves with the smallest threshold_normalized_sound_level and plot it against N.

N and the strength of signal are both normalized to have a maximum of 1.

**If method A is better than method B, then the evaluation curve of method A is strictly smaller than the evaluation curve of method B.**
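
The steps of the brief description can be sketched as follows. This is a minimal illustration under stated assumptions, not the notebook's `plot_evaluation_curves` implementation: the function name `evaluation_curve` and its arguments are hypothetical, the signal strength is taken to be the temporal variance of the running average, and the curve is normalized by its maximum as stated in the description.

```python
import numpy as np

def evaluation_curve(curves, sound_levels, thresholds):
    """Return (normalized N, normalized signal strength) for one threshold type.

    curves: array of shape (n_curves, n_timepoints), one ABR curve per row.
    """
    curves = np.asarray(curves, dtype=float)
    # Threshold-normalized sound level of every curve.
    ratio = np.asarray(sound_levels, dtype=float) / np.asarray(thresholds, dtype=float)
    # Sort the curves by that ratio in increasing order.
    order = np.argsort(ratio, kind='stable')
    n = np.arange(1, len(curves) + 1)
    # Running average over the n curves with the smallest ratio.
    running_mean = np.cumsum(curves[order], axis=0) / n[:, None]
    # Assumption: signal strength = variance of the averaged curve over time.
    strength = running_mean.var(axis=1)
    # Normalize both axes to a maximum of 1.
    return n / n[-1], strength / strength.max()

# Toy example: two flat (sub-threshold) curves and one oscillating curve.
toy_curves = [[0, 0, 0, 0], [1, -1, 1, -1], [0, 0, 0, 0]]
toy_x, toy_y = evaluation_curve(toy_curves, sound_levels=[10, 60, 20],
                                thresholds=[50, 50, 50])
```

In the toy example the two flat curves have the smallest ratios, so `toy_y` stays at zero for the first two points and only rises once the oscillating curve enters the average.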

### Reasoning behind the evaluation curves:

Let us assume we know the ground-truth threshold.
Given this ground-truth threshold, we can take the sample average of all super-threshold curves as well as the sample average of all sub-threshold curves. The sample average of all super-threshold curves should give a temporal pattern, as the mice respond to a cue sound in a temporally coherent way. In contrast, averaging the sub-threshold ABR curves should give a constant signal because, without a perceived cue, the ABR curves are temporally incoherent.

If we now sort all ABR curves by their threshold-normalized sound level (= sound_level/threshold) in increasing order and compute the cumulative average, we should obtain an approximately constant signal until we start adding curves that are above the threshold. (The 'approximately' is due to the finite sample size.) If we use the ground-truth threshold for this sorting, the averaged curve will not deviate significantly from a constant signal before all sub-threshold curves have been added to the cumulative average. However, if we use a suboptimal thresholding, the averaged signal should already deviate from a constant signal before all sub-threshold curves are taken into account.

Based on this, we can construct evaluation curves that compare the quality of thresholding methods:
We plot the (normalized) temporal variance of the averaged signal (which is zero for a constant signal) against the fraction of all ABR curves contained in the cumulative average. For the ground-truth threshold, this curve should thus be approximately zero until 'Normalized N' equals the number of sub-threshold curves divided by the total number of ABR curves (= sub-threshold + super-threshold). After that, it should increase.
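
One way to write this construction down, using the notation $n/N$ and $S^2(n)/S^2(N)$ from the axis labels of the plots in this notebook, is sketched here. The symbols $x_{(i)}$, $\bar{x}_n$, $T$ and $\sigma^2$ are introduced only for illustration. Let $x_{(1)}, \dots, x_{(N)}$ be the ABR curves sorted by increasing sound_level/threshold, each sampled at $T$ time points:

$$
\bar{x}_n(t) = \frac{1}{n} \sum_{i=1}^{n} x_{(i)}(t),
\qquad
S^2(n) = \frac{1}{T} \sum_{t=1}^{T} \left( \bar{x}_n(t) - \frac{1}{T} \sum_{t'=1}^{T} \bar{x}_n(t') \right)^2 .
$$

The evaluation curve plots $S^2(n)/S^2(N)$ against $n/N$. Under the additional simplifying assumption that sub-threshold curves consist of independent noise with variance $\sigma^2$, the expected value of $S^2(n)$ is roughly $\sigma^2/n$ while only sub-threshold curves are averaged, which is why the curve is only approximately zero for a finite sample.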

For suboptimal thresholds, the curve should start to deviate from zero already at a smaller value of 'Normalized N'.
The more error-prone the thresholds are, the sooner the corresponding evaluation curve deviates from zero.
As we can see in the plots, all methods start deviating from zero quite early,
so none of them seems to be perfect.

Nevertheless, these curves allow us to judge the relative quality of different thresholding methods:
If method A is better than method B, then the evaluation curve of method A is strictly smaller than the evaluation curve of method B. The plots show that an assumed constant threshold gives the worst performance of all thresholds compared. (**Note** that the evaluation curves are invariant to which constant threshold is assumed: the curve looks the same for all constant thresholds.)
We can also see that the 'human ground truth' (called 'threshold manual' in the graph) actually performs worse than the ML-detected thresholds ('threshold NN', 'threshold SLR').
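
The invariance to the choice of constant threshold follows from the sorting step: dividing every sound level by the same positive constant does not change the order of the curves, so the cumulative average is built in the same sequence. A small check with made-up sound levels (the values and variable names are illustrative only):

```python
import numpy as np

# Made-up sound levels for five ABR curves.
sound_levels = np.array([40, 10, 85, 25, 60])

# Sort order under two different constant thresholds.
order_50 = np.argsort(sound_levels / 50, kind='stable')
order_20 = np.argsort(sound_levels / 20, kind='stable')

# Both constants rank the curves identically.
same_order = np.array_equal(order_50, order_20)
```

Because the order is identical, any quantity computed from the sorted sequence, including the evaluation curve, is identical as well.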

## Experiment 1

* NN GMC-GMC, SLR GMC-GMC, GMC manual thresholds, dummy method

In [None]:
exp1_dataset = pd.merge(left=GMC_data, 
                        right=GMC_data_predictions1[['mouse_id', 'frequency', 'threshold', 'nn_predicted_thr']], 
                        how='left',
                        on=['mouse_id', 'frequency', 'threshold'])
exp1_dataset = pd.merge(left=exp1_dataset, 
                        right=GMC_data_estimations1, how='left',
                        on=['mouse_id', 'frequency', 'threshold'])
exp1_dataset.rename(columns={'threshold': 'threshold manual', 
                             'slr_estimated_thr': 'threshold SLR', 
                             'nn_predicted_thr': 'threshold NN'}, inplace=True)

print('GMC_data rows:\t\t\t%d\nGMC_data_predictions1 rows:\t%d\nGMC_data_estimations1 rows:\t%d\nexp1_dataset rows:\t\t%d' % 
      (GMC_data.index.nunique(), GMC_data_predictions1.index.nunique(), GMC_data_estimations1.index.nunique(), exp1_dataset.index.nunique()))

display(exp1_dataset.head(5))
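
The row counts printed in the cell above are a guard against a left merge silently duplicating rows, which happens when the right-hand table contains repeated key combinations. As an optional, hedged alternative, `pd.merge` accepts a `validate` argument that raises an error in that case. The sketch below uses made-up toy tables; whether the real prediction files have unique keys per (`mouse_id`, `frequency`, `threshold`) is an assumption that would need to be checked.

```python
import pandas as pd

keys = ['mouse_id', 'frequency', 'threshold']

# Toy stand-ins for the ABR curve table and the prediction table (values invented).
toy_abr = pd.DataFrame({'mouse_id': [1, 1, 2], 'frequency': [100, 100, 100],
                        'threshold': [30, 30, 50], 'sound_level': [20, 40, 20]})
toy_preds = pd.DataFrame({'mouse_id': [1, 2], 'frequency': [100, 100],
                          'threshold': [30, 50], 'nn_predicted_thr': [35, 50]})

# many_to_one: many curve rows may share a key, but each key occurs once on the right.
merged = pd.merge(toy_abr, toy_preds, how='left', on=keys, validate='many_to_one')

# A duplicated key on the right-hand side is reported instead of multiplying rows.
duplicated_preds = pd.concat([toy_preds, toy_preds.iloc[[0]]], ignore_index=True)
try:
    pd.merge(toy_abr, duplicated_preds, how='left', on=keys, validate='many_to_one')
    duplicate_detected = False
except pd.errors.MergeError:
    duplicate_detected = True
```

With unique keys on the right, the merged table keeps exactly the number of rows of the left table.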

In [None]:
"""Plot of evaluation curves"""
plot_evaluation_curves(exp1_dataset, _xlabel=r'n/N', _ylabel=r'$S^2(n)/S^2(N)$')

In [None]:
#r'$\alpha^2$'
#import matplotlib
#matplotlib.rcParams['text.usetex'] = False
#plt.xlabel(r'$\alpha^2$')
#plt.ylabel(r'$S^2(n)/S^2(N)$')

## Experiment 2
   
* NN GMC-ING, SLR GMC-ING, ING manual thresholds, dummy method

In [None]:
exp2_dataset = pd.merge(left=ING_data, 
                        right=ING_data_predictions1[['mouse_id', 'frequency', 'threshold', 'nn_predicted_thr']], 
                        how='left',
                        on=['mouse_id', 'frequency', 'threshold'])
exp2_dataset = pd.merge(left=exp2_dataset, 
                        right=ING_data_estimations1, how='left',
                        on=['mouse_id', 'frequency', 'threshold'])
exp2_dataset.rename(columns={'threshold': 'threshold manual', 
                             'slr_estimated_thr': 'threshold SLR', 
                             'nn_predicted_thr': 'threshold NN'}, inplace=True)

print('ING_data rows:\t\t\t%d\nING_data_predictions1 rows:\t%d\nING_data_estimations1 rows:\t%d\nexp2_dataset rows:\t\t%d' % 
      (ING_data.index.nunique(), ING_data_predictions1.index.nunique(), ING_data_estimations1.index.nunique(), exp2_dataset.index.nunique()))

display(exp2_dataset.head(5))

In [None]:
"""Plot of evaluation curves"""
plot_evaluation_curves(exp2_dataset, _xlabel=r'n/N', _ylabel=r'$S^2(n)/S^2(N)$')

## Experiment 3

* NN ING-GMC, SLR ING-GMC, GMC manual thresholds, dummy method

In [None]:
exp3_dataset = pd.merge(left=GMC_data, 
                        right=GMC_data_predictions2[['mouse_id', 'frequency', 'threshold', 'nn_predicted_thr']], 
                        how='left',
                        on=['mouse_id', 'frequency', 'threshold'])
exp3_dataset = pd.merge(left=exp3_dataset, 
                        right=GMC_data_estimations2, how='left',
                        on=['mouse_id', 'frequency', 'threshold'])
exp3_dataset.rename(columns={'threshold': 'threshold manual', 
                             'slr_estimated_thr': 'threshold SLR', 
                             'nn_predicted_thr': 'threshold NN'}, inplace=True)

print('GMC_data rows:\t\t\t%d\nGMC_data_predictions2 rows:\t%d\nGMC_data_estimations2 rows:\t%d\nexp3_dataset rows:\t\t%d' % 
      (GMC_data.index.nunique(), GMC_data_predictions2.index.nunique(), GMC_data_estimations2.index.nunique(), exp3_dataset.index.nunique()))
print(len(GMC_data), len(GMC_data_predictions2), len(GMC_data_estimations2), len(exp3_dataset))
display(exp3_dataset.head(5))

In [None]:
"""Plot of evaluation curves"""
plot_evaluation_curves(exp3_dataset, _xlabel=r'n/N', _ylabel=r'$S^2(n)/S^2(N)$')

## Experiment 4

* NN ING-ING, SLR ING-ING, ING manual thresholds, dummy method

In [None]:
exp4_dataset = pd.merge(left=ING_data, 
                        right=ING_data_predictions2[['mouse_id', 'frequency', 'threshold', 'nn_predicted_thr']], 
                        how='left',
                        on=['mouse_id', 'frequency', 'threshold'])
exp4_dataset = pd.merge(left=exp4_dataset, 
                        right=ING_data_estimations2, how='left',
                        on=['mouse_id', 'frequency', 'threshold'])
exp4_dataset.rename(columns={'threshold': 'threshold manual', 
                             'slr_estimated_thr': 'threshold SLR', 
                             'nn_predicted_thr': 'threshold NN'}, inplace=True)

print('ING_data rows:\t\t\t%d\nING_data_predictions2 rows:\t%d\nING_data_estimations2 rows:\t%d\nexp4_dataset rows:\t\t%d' % 
      (ING_data.index.nunique(), ING_data_predictions2.index.nunique(), ING_data_estimations2.index.nunique(), exp4_dataset.index.nunique()))
print(len(ING_data), len(ING_data_predictions2), len(ING_data_estimations2), len(exp4_dataset))
display(exp4_dataset.head(5))

In [None]:
"""Plot of evaluation curves"""
plot_evaluation_curves(exp4_dataset, _xlabel=r'n/N', _ylabel=r'$S^2(n)/S^2(N)$')

---