# Description

This notebook detects auditory brainstem response (ABR) hearing thresholds using neural networks (NN) trained on ABR data from the [German Mouse Clinic](https://www.mouseclinic.de/).<br>
The threshold detection is done on ABR hearing curves from the [German Mouse Clinic](https://www.mouseclinic.de/) (GMC data) as well as on ABR hearing curves provided by [Ingham et al.](https://journals.plos.org/plosbiology/article?id=10.1371/journal.pbio.3000194) (ING data).

Training data set: 
* all ABR hearing curves in _GMC_abr_curves.csv_ measured for the mouse IDs found in _../data/GMC/train_mice.npy_.

Neural network (NN) models:
* first NN trained on GMC data, used for the curve-specific predictions: _../models/GMCtrained_model_1.h5_
* second NN trained on GMC data, used for the threshold predictions: _../models/GMCtrained_model_2.h5_
   
Data sets for threshold detection:
* GMC data: _GMC_abr_curves.csv_
* ING data: _ING_abr_curves.csv_

In [None]:
# Automatically reload edited project modules (e.g. ABR_ThresholdFinder_NN) before each cell runs
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline below the code cells
%matplotlib inline 

In [None]:
# Stretch the notebook container to the full browser width so the wide figures fit.
# `display`/`HTML` come from the public `IPython.display` module; importing `display`
# from `IPython.core.display` is deprecated in recent IPython versions.
# NOTE: the `.container` CSS class exists in the classic Notebook UI only (not JupyterLab).
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Load libraries

In [None]:
import os

import pandas as pd
import numpy as np
import seaborn as sns
# Project modules: data preparation helpers and the NN-based ABR threshold finder.
# NOTE(review): `dataprep`, `np`, `gridspec` and `ticker` appear unused in this notebook.
import ABR_ThresholdFinder_NN.data_preparation as dataprep
import ABR_ThresholdFinder_NN.thresholder as abrthr

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.ticker as ticker
# Default figure size for every plot in this notebook
plt.rcParams['figure.figsize'] = [20, 16]

# Expose only GPUs 1-4 to TensorFlow. This only has an effect if it is set before
# TensorFlow initialises CUDA.
# NOTE(review): ABR_ThresholdFinder_NN.thresholder is imported above and may already
# import TensorFlow - confirm the GPU selection still applies, or move this line first.
os.environ["CUDA_VISIBLE_DEVICES"]="1,2,3,4"

from tensorflow.keras.models import load_model

# Custom activation function from the project; presumably required to deserialize
# the saved GMC models - confirm.
from ABR_ThresholdFinder_NN.swish_activation_function import swish

# Definitions

In [None]:
"""Set the path to the model files"""
path2models = '../models'
"""Set the path to data files, for example '../data'"""
path2data = ''

"""Set the time step columns"""
datacols = ['t' + str(i) for i in range(0, 1000)]

# Utils

In [None]:
def plot_data_infos(_data, _validated=True, _ylim=6000):
    """Print dataset size summaries and plot count overviews of ABR curves.

    Terminology used in the printed and plotted output:
      * curve          - one (mouse_id, frequency, sound_level) combination
      * classification - one (mouse_id, frequency) combination with its manual threshold

    Parameters
    ----------
    _data : pandas.DataFrame
        ABR curves with at least the columns 'mouse_id', 'frequency',
        'sound_level' and 'threshold'; also 'validated' when `_validated` is True.
        The frame is copied and never modified.
    _validated : bool, default True
        If True, add a first row of plots broken down by the 'validated' column
        (3x2 grid, 20x20 inches); otherwise draw a 2x2 grid using the default
        figure size from rcParams.
    _ylim : int, default 6000
        Upper y-limit of the "classifications per stimulus" plot. Only applied
        when `_validated` is True.
    """
    temp = _data.copy()

    # Mice, classification and curve counts
    print('Number of mice: %d' % temp.mouse_id.nunique())
    print('Number of classifications: %d' % temp[['mouse_id', 'frequency']].drop_duplicates().shape[0])
    print('Number of individual curves: %d' % temp[['mouse_id', 'frequency', 'sound_level']].drop_duplicates().shape[0])

    # Human-readable stimulus labels, ordered by stimulus frequency. Frequencies
    # without a label keep their raw value instead of silently becoming NaN and
    # disappearing from every plot.
    stimulus_labels = {100: 'Click', 6000: '6 kHz', 12000: '12 kHz',
                       18000: '18 kHz', 24000: '24 kHz', 30000: '30 kHz'}
    stimulus_order = [stimulus_labels.get(freq, freq)
                      for freq in sorted(temp['frequency'].dropna().unique())]
    temp['frequency'] = temp['frequency'].map(lambda freq: stimulus_labels.get(freq, freq))

    # One row per classification, so classification plots do not count every curve.
    classifications = temp[['mouse_id', 'frequency', 'threshold']].drop_duplicates()

    colorblind = sns.color_palette('colorblind')
    # The 'colorblind' palette has only 10 colours, so with more sound levels or
    # thresholds the hue colours would repeat; use one sequential colour per level.
    level_palette = sns.color_palette('viridis', n_colors=max(temp['sound_level'].nunique(), 1))
    thr_palette = sns.color_palette('viridis', n_colors=max(classifications['threshold'].nunique(), 1))

    nrows = 3 if _validated else 2
    # figsize=None falls back to rcParams['figure.figsize']
    fig, axes = plt.subplots(nrows, 2, figsize=(20, 20) if _validated else None)

    # Curves and classifications by stimulus and validation status
    if _validated:
        ax = sns.countplot(x='frequency', hue='validated', data=temp, order=stimulus_order,
                           palette=colorblind, ax=axes[0, 0])
        ax.legend(title='Validation')
        ax.set(xlabel='Stimulus', ylabel='Count', title='Available curves per stimulus (validation status)')
        ax = sns.countplot(x='frequency', hue='validated', order=stimulus_order,
                           data=temp[['mouse_id', 'frequency', 'threshold', 'validated']].drop_duplicates(),
                           palette=colorblind, ax=axes[0, 1])
        ax.legend(title='Validation')
        ax.set(xlabel='Stimulus', ylabel='Count',
               title='Available classifications per stimulus (validation status)')

    # The last two rows are identical for both layouts
    ax_curves, ax_class = axes[-2]
    ax_level, ax_thr = axes[-1]

    # Curves by stimulus (counts rows of the data, i.e. curves)
    sns.countplot(x='frequency', data=temp, order=stimulus_order, palette=colorblind, ax=ax_curves)
    ax_curves.set(xlabel='Stimulus', ylabel='Count', title='Available curves per stimulus')

    # Classifications by stimulus
    sns.countplot(x='frequency', data=classifications, order=stimulus_order, palette=colorblind, ax=ax_class)
    ax_class.set(xlabel='Stimulus', ylabel='Count', title='Available classifications per stimulus')
    if _validated:
        ax_class.set_ylim(0, _ylim)

    # Curves by stimulus and sound level
    sns.countplot(x='frequency', hue='sound_level', data=temp, order=stimulus_order,
                  palette=level_palette, ax=ax_level)
    ax_level.legend(title='Sound level', loc='center left', bbox_to_anchor=(1, 0.5), fontsize='small')
    ax_level.set(xlabel='Stimulus', ylabel='Count', title='Available curves per stimulus (sound level specific)')

    # Classifications by stimulus and manual threshold: count each classification
    # once, not once per curve as before.
    sns.countplot(x='frequency', hue='threshold', data=classifications, order=stimulus_order,
                  palette=thr_palette, ax=ax_thr)
    ax_thr.legend(title='Manual threshold', loc='center left', bbox_to_anchor=(1, 0.5), fontsize='small')
    ax_thr.set(xlabel='Stimulus', ylabel='Count', title='Available classifications per stimulus (manual thresholds)')

    fig.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.20, wspace=0.25)
    plt.show()

# Load GMC models

In [None]:
# First GMC-trained NN, used below for the curve-specific predictions.
# `swish` is passed explicitly so the custom activation resolves even if importing
# ABR_ThresholdFinder_NN.swish_activation_function does not register it globally.
GMC_model1 = load_model(os.path.join(path2models, 'GMCtrained_model_1.h5'),
                        custom_objects={'swish': swish})
GMC_model1.summary()

In [None]:
# Second GMC-trained NN, used below for the threshold predictions.
# `swish` is passed explicitly so the custom activation resolves even if importing
# ABR_ThresholdFinder_NN.swish_activation_function does not register it globally.
GMC_model2 = load_model(os.path.join(path2models, 'GMCtrained_model_2.h5'),
                        custom_objects={'swish': swish})
GMC_model2.summary()

# ABR threshold detection on GMC data

## Load GMC data

In [None]:
# Load all GMC ABR curves (low_memory=False avoids mixed-dtype chunk inference)
GMC_data = pd.read_csv(
    os.path.join(path2data, 'GMC', 'GMC_abr_curves.csv'),
    low_memory=False,
)

In [None]:
"""Checking for duplicates"""
GMC_data[GMC_data.duplicated()]

In [None]:
"""Checking for multiple thresholds for the same mouse ID and frequency"""
mouse_ids = GMC_data[GMC_data.columns.drop('threshold')][GMC_data[GMC_data.columns.drop('threshold')].duplicated()].mouse_id.unique()
mouse_ids

## Plot data infos

In [None]:
# GMC data contains a 'validated' column, so include the validation overview row
plot_data_infos(_data=GMC_data, _validated=True)

## Make curve specific predictions on GMC data

In [None]:
# Curve-specific predictions of the first GMC-trained NN on the GMC curves
GMC_data1 = abrthr.make_curve_specific_predictions(
    GMC_data[['mouse_id', 'frequency', 'sound_level', 'threshold'] + datacols],
    GMC_model1,
    _ING_model=False,
)
GMC_data1.head()

## Make threshold predictions on GMC data

In [None]:
# Threshold predictions of the second GMC-trained NN, computed from the curve-specific predictions
GMC_data2 = abrthr.make_threshold_predictions(GMC_data1,
                                              GMC_model2)
GMC_data2.head(5)

## Save predictions

In [None]:
# Summary of the GMC predictions, one line per statistic
print('\n'.join([
    'GMC data',
    ' Predicted thresholds: %s' % sorted(GMC_data2.predicted_thr.unique()),
    ' Number of curves: %d' % GMC_data2.index.nunique(),
    ' Number of mice: %d' % GMC_data2.mouse_id.nunique(),
]))

In [None]:
# Rename the NN prediction column and, if present, move 'mouse_group' to the last position.
# Reassign instead of renaming in place so the frame shown by earlier cells is not mutated.
GMC_data2 = GMC_data2.rename(columns={'predicted_thr': 'nn_predicted_thr'})
if 'mouse_group' in GMC_data2.columns:
    GMC_data2 = GMC_data2[[col for col in GMC_data2.columns if col != 'mouse_group'] + ['mouse_group']]
GMC_data2.head()

In [None]:
# Save the GMC predictions; create the report directory first so a fresh checkout does not fail
os.makedirs('../reports', exist_ok=True)
GMC_data2.to_csv('../reports/GMC_data_GMCtrained_NN_predictions.csv', index=False)

# ABR threshold detection on ING data

## Load ING data

In [None]:
# Load all ING ABR curves (low_memory=False avoids mixed-dtype chunk inference)
ING_data = pd.read_csv(
    os.path.join(path2data, 'ING', 'ING_abr_curves.csv'),
    low_memory=False,
)

In [None]:
# Check for duplicates: show every row that exactly repeats an earlier row
ING_data.loc[ING_data.duplicated(keep='first')]

In [None]:
# Check for multiple thresholds for a specific mouse ID and frequency.
# Group directly by (mouse_id, frequency). Comparing all other columns, curve
# samples included, would also flag exact duplicate rows and would miss
# conflicting thresholds that come with different curves.
thresholds_per_stimulus = ING_data.groupby(['mouse_id', 'frequency'])['threshold'].nunique()
mouse_ids = (thresholds_per_stimulus[thresholds_per_stimulus > 1]
             .index.get_level_values('mouse_id').unique().to_numpy())
mouse_ids

## Plot data infos

In [None]:
# Plot the ING overview without the validation row
plot_data_infos(_data=ING_data, _validated=False)

## Make curve specific predictions on ING data

In [None]:
# Curve-specific predictions of the first GMC-trained NN on the ING curves
# (_ING_model stays False because the model was trained on GMC data)
ING_data1 = abrthr.make_curve_specific_predictions(
    ING_data[['mouse_id', 'frequency', 'sound_level', 'threshold'] + datacols],
    GMC_model1,
    _ING_model=False,
)
ING_data1.head()

## Make threshold predictions on ING data

In [None]:
# Threshold predictions of the second GMC-trained NN, computed from the curve-specific predictions
ING_data2 = abrthr.make_threshold_predictions(ING_data1,
                                              GMC_model2)
ING_data2.head(5)

## Save predictions

In [None]:
# Summary of the ING predictions, one line per statistic
print('\n'.join([
    'ING data',
    ' Predicted thresholds: %s' % sorted(ING_data2.predicted_thr.unique()),
    ' Number of curves: %d' % ING_data2.index.nunique(),
    ' Number of mice: %d' % ING_data2.mouse_id.nunique(),
]))

In [None]:
# Rename the NN prediction column and, if present, move 'mouse_group' to the last position.
# Reassign instead of renaming in place so the frame shown by earlier cells is not mutated.
ING_data2 = ING_data2.rename(columns={'predicted_thr': 'nn_predicted_thr'})
if 'mouse_group' in ING_data2.columns:
    ING_data2 = ING_data2[[col for col in ING_data2.columns if col != 'mouse_group'] + ['mouse_group']]
ING_data2.head()

In [None]:
# Save the ING predictions; create the report directory first so a fresh checkout does not fail
os.makedirs('../reports', exist_ok=True)
ING_data2.to_csv('../reports/ING_data_GMCtrained_NN_predictions.csv', index=False)

---