# Description

This notebook is used to train the **first convolutional neural network** of the automated ABR thresholder with data provided by [Ingham et al.](https://journals.plos.org/plosbiology/article?id=10.1371/journal.pbio.3000194).

The first model (**Model I**) is trained as a classifier to predict whether an ABR response is present or absent in a single stimulus curve (one frequency, one sound pressure level (SPL)). The labels required for Model I are derived from the original hearing thresholds under the assumption that all sub-threshold SPL curves represent non-hearing, while threshold and supra-threshold SPL curves represent hearing.

The input data consists of the measured ABR curves represented as time series and the SPL associated with each curve.

The prediction values of Model I are used as input data for the second neural network.

In [None]:
# Load the autoreload extension and reload edited modules (e.g. the project package
# ABR_ThresholdFinder_NN) automatically before each cell is executed.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline below the cells that create them.
%matplotlib inline 

In [None]:
# IPython.display is the public module for display/HTML; importing `display` from
# IPython.core.display is deprecated in recent IPython releases.
from IPython.display import display, HTML
# Stretch the notebook container to the full browser width for the large figures below.
display(HTML("<style>.container { width:100% !important; }</style>"))

# Load libraries

In [None]:
import os

import pandas as pd 
import numpy as np 
import seaborn as sns

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [20, 16]

import ABR_ThresholdFinder_NN.data_preparation as dataprep
from ABR_ThresholdFinder_NN.models import create_model_1, compile_model_1

# Set the available GPUs (overrides any value already present in the environment).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
if not visible_devices:
    count_gpus = 0
    print('no GPUs available')
else:
    # int() validates that each comma-separated entry is a device index; only the count is kept.
    count_gpus = len(list(map(int, visible_devices.split(','))))
    print('%d GPUs available' % count_gpus)

# Definitions

Define variables and methods to be used later.

In [None]:
# Batch size shared by all data generators.
batch_size = 32
print('batch size %d' % batch_size)

# Potential stimulus frequencies in Hz. The value 100 is not a frequency but the code
# for a broadband (click) stimulus.
poss_freq = [100, 6000, 12000, 18000, 24000, 30000]
freq_strings = []
for freq in poss_freq:
    if freq == 100:
        freq_strings.append('potential stimulus frequencies: ' + str(freq))
    else:
        freq_strings.append(str(freq) + 'Hz')
print(*freq_strings, sep=", ")

# Potential sound pressure levels in dB: 0, 5, ..., 95.
poss_thr = list(range(0, 100, 5))
thr_strings = ['potential sound pressure levels [dB]: %d' % thr if thr == 0 else '%d' % thr
               for thr in poss_thr]
print(*thr_strings, sep=", ")

# Load data

In [None]:
"""Set the path to the data files, for example '../data/ING/'"""
data_dir = ''
"""Define the name of the file containing the ABR curve dataset"""
data_file = os.path.join(data_dir, 'ING_abr_curves.csv')
if not os.path.exists(data_file):
    # Fail fast: every following cell depends on this file. Previously only an
    # "ERROR" line was printed and execution continued into read_csv.
    raise FileNotFoundError('data file not found: {}'.format(data_file))
print('data file: {}'.format(data_file))
"""Define the folder in which to store the trained models"""
save_dir = '../models/'
# Create the output folder up front; model checkpoints, figures, the saved model
# and the prediction CSV are all written into it by later cells.
os.makedirs(save_dir, exist_ok=True)

In [None]:
"""Load the data and add a column for (non-)hearing definition of an ABR curve"""
data = pd.read_csv(data_file, low_memory=True)
# A curve is labelled as hearing (1) when it was recorded at or above the threshold,
# otherwise as non-hearing (0). The vectorized comparison replaces a row-by-row .loc
# loop; NaN comparisons evaluate to False (-> 0) in both versions.
data['still_hears'] = (data['sound_level'] >= data['threshold']).astype(int)
display(data.head(10))

# Prepare data

In [None]:
"""Prepare the data to feed the neural network"""
# Delegates to the project helper. The resulting frame `data1` is what later cells split
# by mouse, standardize and feed to the data generators. The meaning of _ING_model=True
# is defined in ABR_ThresholdFinder_NN.data_preparation (not visible here).
data1 = dataprep.prepare_data4training_1(data, poss_freq, poss_thr, _ING_model=True)
display(data1.head(2))

In [None]:
"""
The input data has to fulfill certain requirements: 
- define ["still_hears"] as main_label_col
- only the first 1000 columns (time steps) are to be considered 
"""
label_and_data_cols = dataprep.get_col_names4model_1_labels(data1)
main_label_col, freq_label_col, sl_label_col, data_cols = label_and_data_cols
print('main label column %s' % main_label_col)
# Show only the first and last name of each (long) column list.
for caption, column_names in [('frequency label columns', freq_label_col),
                              ('sound level label columns', sl_label_col),
                              ('data columns', data_cols)]:
    print('%s: [%s ... %s]' % (caption, column_names[0], column_names[-1]))

# Split data into training, validation and test sets

## New split of the data into training, validation and test sets

In [None]:
# Fixed seed so that the mouse-level splits below are reproducible.
random_seed = 42
print('random seed: %d' % random_seed)

In [None]:
"""Split all data into train data and valid data (80:20)"""
first_split = dataprep.split_data(data1, random_seed)
train_mice, valid_mice, train_indices, valid_indices = first_split
for split_name, split_index in (('train', train_indices), ('valid', valid_indices)):
    print('%s index: %d' % (split_name, len(split_index)))

"""Split train data into train data 2 and test data (80:20)"""
# The test mice are carved out of the first training pool; train_mice2 are the remaining training mice.
train_data = data1.loc[train_indices]
second_split = dataprep.split_data(train_data, random_seed)
train_mice2, test_mice, train_indices2, test_indices = second_split
train_data2 = train_data.loc[train_indices2]

In [None]:
# Persist the mouse IDs of each split so that later sessions can reload identical splits.
# The 'train' file stores train_mice2, i.e. the training mice left after removing the test mice.
# A dedicated name (split_file) is used: reassigning `data_file` here made the CSV-loading
# cell read an .npy path when it was re-run afterwards.
for split_name, split_mice in (('train', train_mice2), ('valid', valid_mice), ('test', test_mice)):
    split_file = os.path.join(data_dir, 'ING_%s_mice.npy' % split_name)
    np.save(split_file, split_mice)

## Load the existing training, validation and test data splits

In [None]:
def _load_mouse_split(split_name):
    """Load the saved mouse IDs of one split ('train', 'valid' or 'test') from data_dir."""
    # NOTE(review): allow_pickle=True permits unpickling object arrays, which can run
    # arbitrary code; only load split files this notebook created.
    return np.load(os.path.join(data_dir, 'ING_%s_mice.npy' % split_name), allow_pickle=True)


train_mice = _load_mouse_split('train')
valid_mice = _load_mouse_split('valid')
test_mice = _load_mouse_split('test')

# Row labels of data1 belonging to the mice of each split.
mouse_ids = data1['mouse_id']
train_indices = data1.index[mouse_ids.isin(train_mice)]
valid_indices = data1.index[mouse_ids.isin(valid_mice)]
test_indices = data1.index[mouse_ids.isin(test_mice)]

for split_name, split_index in (('train', train_indices), ('valid', valid_indices), ('test', test_indices)):
    print('%s index: %d' % (split_name, len(split_index)))

# Standardize

In [None]:
# Standardize the time-series columns of data1. The training row labels are passed in,
# presumably so that the scaling statistics come from the training mice only
# (TODO confirm in dataprep.standardize).
data2 = dataprep.standardize(data1, train_indices, data_cols)
data2.head(2)

# Initialize the data generators

In [None]:
"""Import the data generator"""
import ABR_ThresholdFinder_NN.data_generator as datagenerate

In [None]:
"""Define the data generators (actually keras sequences)"""


def _make_data_generator(list_IDs):
    """Build a shuffling DataGenerator over the rows `list_IDs` of data2.

    All generators share the standardized frame, the label/data column layout,
    the input length (1000 time steps) and the batch size; only the selected rows differ.
    """
    return datagenerate.DataGenerator(list_IDs=list_IDs,
                                      df=data2,
                                      value_cols=data_cols,
                                      main_label_col=main_label_col,
                                      freq_label_col=freq_label_col,
                                      sl_label_col=sl_label_col,
                                      dim=1000,
                                      batch_size=batch_size,
                                      shuffle=True)


# Use train_indices, which the "load existing splits" cell recomputes from the saved
# training mice. train_indices2 only exists if the "new split" cell ran in this kernel;
# both presumably select the rows of the same training mice.
train_data_generator = _make_data_generator(train_indices)
valid_data_generator = _make_data_generator(valid_indices)
test_data_generator = _make_data_generator(test_indices)

# Model building

In [None]:
"""Create Model I"""
# The numbers of frequency and sound-level label columns are passed in. The prediction
# cells below name outputs 1 and 2 with exactly these columns, so they presumably size
# the two auxiliary output heads (TODO confirm in create_model_1).
model = create_model_1(len(freq_label_col), len(sl_label_col))
model.summary()

In [None]:
"""Compile Model I"""
# count_gpus comes from the GPU setup cell. The returned parallel_model is the one trained
# and evaluated below, while `model` is the one saved to disk. loss and loss_weights are
# returned by the helper but not used again in the cells shown here.
parallel_model, loss, loss_weights = compile_model_1(model, count_gpus)

# Training

In [None]:
"""Import callbacks"""
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard

In [None]:
"""Define metric for early stopping"""
# Validation loss of the main (hearing / non-hearing) output drives all three callbacks.
mon_var = 'val_main_prediction_loss'

"""Define callbacks"""
checkpointer = ModelCheckpoint(filepath=os.path.join(save_dir, 'INGtrained_model_1_weights.hdf5'),
                               verbose=1,
                               save_best_only=True,
                               monitor=mon_var)
early_stopper = EarlyStopping(monitor=mon_var, patience=7)
reduce_lr = ReduceLROnPlateau(monitor=mon_var, patience=2)

"""Training"""
# Model.fit accepts keras Sequences directly; Model.fit_generator is deprecated in tf.keras.
history = parallel_model.fit(train_data_generator,
                             validation_data=valid_data_generator,
                             use_multiprocessing=True,
                             epochs=30,
                             workers=32,
                             shuffle=True,
                             callbacks=[checkpointer, early_stopper, reduce_lr])

## Training history

In [None]:
def _plot_train_vs_valid(metric, title, ylabel, legend_loc, file_name):
    """Plot the training and validation curves of one Keras history metric and save the figure.

    metric: key in history.history; its validation counterpart is 'val_' + metric.
    file_name: name of the PNG written into save_dir.
    """
    fig, ax = plt.subplots(figsize=(20, 14), dpi=80, facecolor='w', edgecolor='k')
    ax.plot(history.history[metric])
    ax.plot(history.history['val_' + metric])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'validation'], loc=legend_loc)
    fig.savefig(os.path.join(save_dir, file_name))
    plt.show()


"""Summarize history for loss"""
_plot_train_vs_valid('main_prediction_loss', 'model loss', 'loss', 'upper right',
                     'INGtrained_model_1_loss_history.png')

"""Summarize history for accuracy"""
# The accuracy key is 'main_prediction_acc' or 'main_prediction_accuracy', depending on the
# Keras version and on how compile_model_1 names the metric.
acc_metric = ('main_prediction_acc' if 'main_prediction_acc' in history.history
              else 'main_prediction_accuracy')
_plot_train_vs_valid(acc_metric, 'model accuracy', 'accuracy', 'lower right',
                     'INGtrained_model_1_acc_history.png')

## Model evaluation on the test data set

In [None]:
# evaluate returns one score per entry of metrics_names (a list for this multi-output model).
test_scores = parallel_model.evaluate(test_data_generator)
results = {name: score for name, score in zip(parallel_model.metrics_names, test_scores)}

In [None]:
# One line per evaluated metric, in the order returned by evaluate.
for metric_name, metric_value in results.items():
    print('%s: %.2f' % (metric_name, metric_value))

## Save the non-parallel (non-GPU) model

In [None]:
# Save the template `model` rather than parallel_model, so that the file can be reloaded
# without the multi-GPU setup. It presumably shares its weights with the trained
# parallel_model (TODO confirm in compile_model_1).
model.save(save_dir + 'INGtrained_model_1.h5')

# Prediction

In [None]:
"""Load the model for predictions and validation"""
from tensorflow.keras.models import load_model
# NOTE(review): the name parallel_model is reused for the reloaded single model saved
# above; all following prediction cells use this reloaded model.
parallel_model = load_model(save_dir + 'INGtrained_model_1.h5')

In [None]:
# Keep only the rows of the test mice. reset_index gives a 0..n-1 index that lines up with
# the prediction frames built from numpy arrays in the next cell; the old index is kept as
# an 'index' column.
test_data = data2[data2.mouse_id.isin(test_mice)].reset_index()

In [None]:
"""Make predictions for test data"""
# A trailing axis is appended so the (rows, time steps) matrix becomes (rows, time steps, 1).
model_outputs = parallel_model.predict(test_data[data_cols].values[:, :, np.newaxis])
# Place outputs 0, 1 and 2 side by side; ignore_index renumbers the columns 0..k-1.
predictions = pd.concat([pd.DataFrame(model_outputs[i]) for i in range(3)], axis=1, ignore_index=True)

"""Define column names of predictions"""
# Outputs in order: hearing label, frequency classes, sound-level classes.
predictions.columns = main_label_col + freq_label_col + sl_label_col

display(predictions.head(2))

In [None]:
"""Extract predicted labels"""
"""still hearing label"""
# Rounding the predicted hearing probability gives the binary label; the raw value is kept too.
test_data['predicted_hears'] = round(predictions[main_label_col])
test_data['predicted_hears_exact'] = predictions[main_label_col]

"""frequency label"""
# The arg-max column name is reduced to its digits, which are read as the class position
# in poss_freq and mapped back to the frequency via inv_freq_dict.
test_data['predicted_freq_encoded'] = predictions.loc[:, freq_label_col].idxmax(axis=1)
# regex=True is required: since pandas 2.0 str.replace matches the pattern literally by
# default, which would leave the names unchanged and make astype('int') fail.
test_data['predicted_freq'] = test_data['predicted_freq_encoded'].str.replace(r'\D+', '', regex=True).astype('int')
freq_dict = dict(zip(poss_freq, list(range(len(poss_freq)))))
inv_freq_dict = {v: k for k, v in freq_dict.items()}
test_data['predicted_freq'] = test_data['predicted_freq'].map(inv_freq_dict)

"""sound level label"""
test_data['predicted_sl_encoded'] = predictions.loc[:, sl_label_col].idxmax(axis=1)
test_data['predicted_sl'] = test_data['predicted_sl_encoded'].str.replace(r'\D+', '', regex=True).astype('int')
sl_dict = dict(zip(poss_thr, list(range(len(poss_thr)))))
inv_sl_dict = {v: k for k, v in sl_dict.items()}
test_data['predicted_sl'] = test_data['predicted_sl'].map(inv_sl_dict)

display(test_data.head(2))

# Validation

## Accuracy

In [None]:
"""Accuracy predictions"""
test_data1 = test_data.copy()


def _match_percent(frame, true_col, pred_col):
    """Return the percentage of rows of `frame` whose `true_col` equals `pred_col`.

    The exact ratio is returned; the printed value is rounded only by the '%.2f' format.
    Previously the ratio was pre-rounded to 3 or 4 decimals, so some accuracies were
    effectively shown with a single decimal.
    """
    return 100 * (frame[true_col] == frame[pred_col]).mean()


freq_acc = _match_percent(test_data1, 'frequency', 'predicted_freq')
print('frequency accuracy: %.2f%%' % freq_acc)
print()

sl_acc = _match_percent(test_data1, 'sound_level', 'predicted_sl')
print('sound level accuracy: %.2f%%' % sl_acc)
print()

curve_acc = _match_percent(test_data1, 'still_hears', 'predicted_hears')
print('curve accuracy: %.2f%%' % curve_acc)
print()

print('frequency specific curve accuracy (count of test data)')
for f in test_data1['frequency'].unique():
    freq_subset = test_data1[test_data1['frequency'] == f]
    freq_spec_curve_acc = _match_percent(freq_subset, 'still_hears', 'predicted_hears')
    count_of_test_data = len(freq_subset)
    if f == 100:
        print(' click: %.2f%% (%d)' % (freq_spec_curve_acc, count_of_test_data))
    else:
        print(' frequency %d: %.2f%% (%d)' % (f, freq_spec_curve_acc, count_of_test_data))

## Plot of prediction probabilities of hearing and non-hearing curves

In [None]:
"""Visualize prediction probabilities of hearing and non-hearing curves"""
still_hears = [0, 1]
fig, ax = plt.subplots(figsize=(20, 14), dpi=80, facecolor='w', edgecolor='k')

"""Iterate through the two hearing possibilities"""
# One density curve of the predicted hearing probability per TRUE label (still_hears 0 / 1).
for still_hearing in still_hears:
    subset = test_data1[test_data1['still_hears'] == still_hearing]
    # fill=True replaces the deprecated `shade` argument (seaborn >= 0.11).
    sns.kdeplot(subset['predicted_hears_exact'], fill=True, gridsize=70, label=still_hearing, ax=ax)

"""Plot formatting"""
# The curves are grouped by the true label, so the legend is titled accordingly.
ax.legend(prop={'size': 16}, title='still hears (true label)')
ax.set_title('Density Plot hearing prediction values')
ax.set_xlabel('Probability')
ax.set_ylabel('Density')
plt.show()

# Store predictions to be fed into the second neural network

In [None]:
"""Create new dataframe for predicting"""
# data2 is the complete dataset; the name test_data2 is used so that test_data1 from above
# is not overwritten.
test_data2 = data2.copy()

"""Make predictions for all data"""
predictions_temp = parallel_model.predict(test_data2[data_cols].values[:, :, np.newaxis])
predictions = pd.concat([pd.DataFrame(predictions_temp[i]) for i in range(3)],
                        axis=1, ignore_index=True)
# Give the prediction frame the same row labels as test_data2. The column assignments in
# the next cell align on the index, and data2's index is not guaranteed to be 0..n-1
# (frames built from numpy arrays always are).
predictions.index = test_data2.index

"""Define column names of predictions"""
predictions.columns = main_label_col + freq_label_col + sl_label_col
predictions.head(2)

In [None]:
"""Extract predicted labels"""
"""still hearing label"""
# Rounding the predicted hearing probability gives the binary label; the raw value is kept too.
test_data2['predicted_hears'] = round(predictions[main_label_col])
test_data2['predicted_hears_exact'] = predictions[main_label_col]

"""frequency label"""
test_data2['predicted_freq_encoded'] = predictions.loc[:, freq_label_col].idxmax(axis=1)
# regex=True is required: since pandas 2.0 str.replace matches the pattern literally by
# default, which would leave the names unchanged and make astype('int') fail.
test_data2['predicted_freq'] = test_data2['predicted_freq_encoded'].str.replace(r'\D+', '', regex=True).astype('int')
# Class position -> frequency, rebuilt from poss_freq so this cell does not depend on
# freq_dict from an earlier cell.
inv_freq_dict = dict(enumerate(poss_freq))
test_data2['predicted_freq'] = test_data2['predicted_freq'].map(inv_freq_dict)

"""sound level label"""
test_data2['predicted_sl_encoded'] = predictions.loc[:, sl_label_col].idxmax(axis=1)
test_data2['predicted_sl'] = test_data2['predicted_sl_encoded'].str.replace(r'\D+', '', regex=True).astype('int')
# Class position -> sound level, rebuilt from poss_thr for the same reason.
inv_sl_dict = dict(enumerate(poss_thr))
test_data2['predicted_sl'] = test_data2['predicted_sl'].map(inv_sl_dict)

In [None]:
"""Write csv of exact predictions for external use"""
# Tag every curve with the split of its mouse; rows of mice in no split stay NaN.
test_data2.loc[test_data2['mouse_id'].isin(train_mice), 'mouse_group'] = 'train'
test_data2.loc[test_data2['mouse_id'].isin(valid_mice), 'mouse_group'] = 'valid'
test_data2.loc[test_data2['mouse_id'].isin(test_mice), 'mouse_group'] = 'test'

# The 'validated' column was removed since it does not exist in this dataset.
export_cols = ['mouse_id', 'mouse_group', 'frequency', 'threshold', 'sound_level',
               'predicted_hears_exact']
# Written under the name the "Plot predictions" section reads back; the previous name
# 'model_1_predicted_curves_21_04.csv' did not match that file.
test_data2[export_cols].to_csv(os.path.join(save_dir, 'INGtrained_model_1_predicted_curves.csv'),
                               index=False)

display(test_data2[['mouse_id', 'mouse_group', 'frequency', 'threshold', 'sound_level']].head())

# Plot predictions

In [None]:
# Path of the per-curve prediction CSV used by the plotting cells below.
data_file = os.path.join(save_dir, 'INGtrained_model_1_predicted_curves.csv')
# Displays True/False as the cell output; the next cell reads this file.
os.path.exists(data_file)

In [None]:
# Load the stored predictions for plotting. The plot cells below read the columns
# mouse_id, mouse_group, frequency, threshold, sound_level and predicted_hears_exact.
all_pred_curves = pd.read_csv(data_file)
all_pred_curves.head(2) 

In [None]:
import matplotlib.gridspec as gridspec

# One figure per mouse (first 10 mice): predicted hearing probability over sound level,
# one panel per stimulus frequency, with the true threshold drawn as dashed lines.
for mouse in all_pred_curves.mouse_id.unique()[:10]:

    mouse_curves = all_pred_curves[all_pred_curves.mouse_id == mouse]
    freqs = mouse_curves.frequency.unique()

    fig = plt.figure(constrained_layout=True, figsize=(15, 8))
    ncols = 3
    # Ceil division: int(n / ncols) floored, so 4-5 frequencies indexed a non-existent
    # second row and fewer than 3 frequencies produced a zero-row grid.
    nrows = max(1, -(-len(freqs) // ncols))
    spec = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)

    for idx, freq in enumerate(freqs):
        row, col = divmod(idx, ncols)
        ax = fig.add_subplot(spec[row, col])
        ax.set_title('Click' if freq == 100 else '%sHz' % freq)
        ax.set_xticks(mouse_curves.sound_level.unique())
        freq_curves = mouse_curves[mouse_curves.frequency == freq]
        freq_curves.plot(x='sound_level', y='predicted_hears_exact', ax=ax, legend=False)
        ax.vlines(x=freq_curves['threshold'], ymin=0, ymax=1., linestyles='dashed', color='lightgray')

    fig.suptitle('%s mouse id: %s' % (mouse_curves.mouse_group.unique()[0], mouse), fontsize=16)

In [None]:

noofmice = 10

# One figure per mouse group (train / valid / test): per frequency, the prediction curves
# of up to `noofmice` mice. Curves whose mouse belongs to no group (NaN) are skipped.
for group in all_pred_curves.mouse_group.dropna().unique():

    group_curves = all_pred_curves.loc[all_pred_curves.mouse_group == group]
    freqs = group_curves.frequency.unique()

    fig = plt.figure(constrained_layout=True, figsize=(15, 8))

    ncols = 3
    # Ceil division: int(n / ncols) floored, so 4-5 frequencies indexed a non-existent
    # second row and fewer than 3 frequencies produced a zero-row grid.
    nrows = max(1, -(-len(freqs) // ncols))
    spec = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)

    for idx, freq in enumerate(freqs):

        freq_curves = group_curves[group_curves.frequency == freq]

        row, col = divmod(idx, ncols)
        ax = fig.add_subplot(spec[row, col])
        ax.set_title('Click' if freq == 100 else '%sHz' % freq)
        ax.set_xticks(freq_curves.sound_level.unique())

        # One line per mouse, limited to the first `noofmice` mice measured at this frequency.
        for mouse in freq_curves.mouse_id.unique()[:noofmice]:
            freq_curves.loc[freq_curves.mouse_id == mouse].plot(x='sound_level', y='predicted_hears_exact',
                                                                ax=ax, legend=False)

    fig.suptitle('mouse group: %s (%d mice)' % (group, noofmice), fontsize=16)