# Description

This notebook is used to train **Model II** based on data provided by the [German Mouse Clinic](https://www.mouseclinic.de/). 

Model II is trained as a classifier for every stimulus to predict the hearing threshold, using the respective class-score outputs of Model I as input and the original hearing thresholds as labels. As a secondary output, it also predicts the stimulus frequency.

**A hearing threshold is defined as the lowest sound level at which the mouse can still hear something**.

In [None]:
# Reload edited modules automatically before each cell runs, and render matplotlib figures inline
%reload_ext autoreload
%autoreload 2
%matplotlib inline 

In [None]:
# IPython.display is the public home of display/HTML; IPython.core.display is deprecated for these
from IPython.display import display, HTML
# Stretch the notebook container to the full browser width
display(HTML("<style>.container { width:100% !important; }</style>"))

# Load libraries

In [None]:
import os

import pandas as pd 
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [20, 16]

import ABR_ThresholdFinder_NN.data_preparation as dataprep
from ABR_ThresholdFinder_NN.models import create_model_1, compile_model_1, create_model_2, compile_model_2

from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score

import warnings
warnings.filterwarnings("ignore")

# Expose GPUs 0-3 and count them for multi-GPU compilation.
# NOTE(review): this runs after `import tensorflow`; it only takes effect if TensorFlow has not
# initialised its GPUs yet in this kernel — TODO confirm / consider setting it before the import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
visible_gpu_ids = [int(gpu_id) for gpu_id in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]
count_gpus = len(visible_gpu_ids)
print('%d GPUS' % count_gpus)

# Definitions

In [1]:
"""Set the batch size"""
batch_size = 128

# Scores above this cutoff count as positive when decoding the predicted threshold
threshold_cutoff = 0.5

# Placeholder only; overwritten by the data-preparation step
sample_weight_col = []

# Candidate sound pressure levels in dB (0 to 95 in steps of 5)
sound_levels = list(range(0, 100, 5))
print('potential sound pressure levels [dB]: ' + ', '.join(str(level) for level in sound_levels))
# NOTE(review): sentinel sound level combined with `sound_levels` when decoding thresholds
# (presumably "above upper limit") — TODO confirm meaning
aul_sound_level = 999

potential sound pressure levels [dB]: 0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95


# Load data

In [None]:
"""Set the path to the data files"""
# Keep the trailing slash: later cells build paths as save_dir + '<file name>'
save_dir = '../models/'
data_file = os.path.join(save_dir, 'GMCtrained_model_1_predicted_curves.csv')
if not os.path.exists(data_file):
    # Fail fast with the offending path instead of only printing and crashing later in read_csv
    raise FileNotFoundError('data file not found: {}'.format(data_file))
print('data file: {}'.format(data_file))

In [None]:
"""Load data from file, which is the output of Model I"""
# Per-curve class scores produced by Model I
data = pd.read_csv(filepath_or_buffer=data_file)
data.head(n=2)

# Prepare data

In [None]:
"""Prepare the data to feed the second convolutional neural network"""
prepared = dataprep.prepare_data4training_2(data, sound_levels, aul_sound_level)
# The helper returns the prepared frame together with the sample-weight column setting
data1, sample_weight_col = prepared
data1.head(2)

In [None]:
main_label_col, freq_label_col, data_cols = dataprep.get_col_names4model_2_labels(data1)
# The column lists are long, so only their first and last entries are shown
for template, cols in (('main label column [%s ... %s]', main_label_col),
                       ('frequency label columns: [%s ... %s]', freq_label_col),
                       ('data columns: [%s ... %s]', data_cols)):
    print(template % (cols[0], cols[-1]))

print('sample weight column: %s' % sample_weight_col)

# Split data

In [None]:
"""Distinguish training, validation and testing data - determined by previous model"""
"""Get indices"""
def _split_rows(group):
    """Return the index labels of the rows of data1 whose 'mouse_group' equals `group`."""
    return data1.index[data1['mouse_group'] == group]


train_indices, valid_indices, test_indices = (_split_rows(group) for group in ('train', 'valid', 'test'))
# Mouse ID of every row in each split (an ID may repeat if a mouse has several rows)
train_mice, valid_mice, test_mice = (data1.loc[rows, 'mouse_id']
                                     for rows in (train_indices, valid_indices, test_indices))

# Initialize data generators

In [None]:
"""Import the data generator"""
import ABR_ThresholdFinder_NN.data_generator as datagenerate

In [None]:
length_input = len(data_cols)


def _make_generator(row_ids, **extra_kwargs):
    """Build a DataGenerator over the rows `row_ids` of data1 with the shared Model II settings."""
    return datagenerate.DataGenerator(list_IDs=row_ids,
                                      df=data1,
                                      value_cols=data_cols,
                                      main_label_col=main_label_col,
                                      freq_label_col=freq_label_col,
                                      sl_label_col=[],
                                      dim=length_input,
                                      batch_size=batch_size,
                                      sample_weight_col=sample_weight_col,
                                      **extra_kwargs)


train_data_generator = _make_generator(train_indices, shuffle=True)
valid_data_generator = _make_generator(valid_indices, shuffle=True)
# No `shuffle` argument here: the test generator keeps DataGenerator's default
test_data_generator = _make_generator(test_indices)

# Model building

In [None]:
"""Create Model II"""
n_inputs = len(data_cols)
n_threshold_classes = len(main_label_col)
n_frequency_classes = len(freq_label_col)
model = create_model_2(n_inputs, n_threshold_classes, n_frequency_classes)
model.summary()

In [None]:
"""Compile Model II"""
# presumably returns the (multi-GPU) training model plus the loss configuration it was compiled with
compiled = compile_model_2(model, count_gpus)
parallel_model, loss, loss_weights = compiled

# Training

In [None]:
"""Import callbacks"""
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard

In [None]:
"""Define criterion for early stopping"""
mon_var = 'val_main_prediction_loss'

# Callbacks: keep the best checkpoint, stop early, and lower the learning rate on plateaus
checkpoint = ModelCheckpoint(filepath=save_dir + 'GMCtrained_model_2_weights.hdf5',
                             verbose=1, save_best_only=True, monitor=mon_var)
# restore_best_weights: otherwise the in-memory model that is evaluated and saved afterwards is the
# last-epoch model, up to `patience` epochs past the best validation loss (restored when training
# stops early)
early_stopper = EarlyStopping(monitor=mon_var, patience=61, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor=mon_var, patience=10)
callbacks_list = [checkpoint, early_stopper, reduce_lr]

# Fit Model II. Model.fit accepts generators / Sequences directly; fit_generator is deprecated.
history = parallel_model.fit(train_data_generator,
                             validation_data=valid_data_generator,
                             use_multiprocessing=True,
                             epochs=200,
                             workers=8,
                             shuffle=True,
                             callbacks=callbacks_list)

## Training history

In [None]:
"""Plot history"""
"""Summarize history for loss"""
history_dir = os.path.join(save_dir, 'GMCtraining_history')
# savefig does not create missing directories
os.makedirs(history_dir, exist_ok=True)


def _plot_history(metric, ylabel, title, legend_loc, filename):
    """Plot train vs. validation curves of `metric` from `history`, save to history_dir and show."""
    fig, ax = plt.subplots(figsize=(20, 14), dpi=80, facecolor='w', edgecolor='k')
    ax.plot(history.history[metric])
    ax.plot(history.history['val_' + metric])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'validation'], loc=legend_loc)
    fig.savefig(os.path.join(history_dir, filename))
    plt.show()


_plot_history('main_prediction_loss', 'loss', 'model loss', 'upper right',
              'GMCtrained_model_2_loss_history.png')
_plot_history('main_prediction_categorical_accuracy', 'accuracy', 'model accuracy', 'lower right',
              'GMCtrained_model_2_acc_history.png')

## Model evaluation on test data set

In [None]:
# evaluate() yields one value per entry of metrics_names, in the same order
test_scores = parallel_model.evaluate(test_data_generator)
results = {name: value for name, value in zip(parallel_model.metrics_names, test_scores)}

In [None]:
for metric_name, metric_value in results.items():
    print('%s: %.2f' % (metric_name, metric_value))

## Save non-GPU model

In [None]:
# Persist the single-device `model` (not the compiled training wrapper)
model_2_path = save_dir + 'GMCtrained_model_2.h5'
model.save(model_2_path)

# Prediction

In [None]:
"""Load model for predictions and validation"""
from tensorflow.keras.models import load_model
# The reloaded model replaces the training model under the same name for the cells below
model_2_path = save_dir + 'GMCtrained_model_2.h5'
parallel_model = load_model(model_2_path)

In [None]:
# reset_index() renumbers rows 0..n-1 (old labels kept in an 'index' column) so they line up
# with the default-indexed prediction frame built below
test_data = data1.loc[data1['mouse_id'].isin(test_mice)].reset_index()

In [None]:
"""Make predictions for the test dataset"""
# Add a trailing channel axis to the 2-D input matrix: (n_samples, n_inputs, 1)
model_inputs = test_data[data_cols].values[:, :, np.newaxis]
raw_outputs = parallel_model.predict(model_inputs)
# Output 0: threshold class scores, output 1: frequency class scores
predictions = pd.concat([pd.DataFrame(raw_outputs[0]), pd.DataFrame(raw_outputs[1])],
                        axis=1, ignore_index=True)
predictions.columns = main_label_col + freq_label_col
predictions.head(2)

In [None]:
# For every test sample, pick the main-label column with the smallest score that still exceeds
# the cutoff, together with that score.
main_scores = predictions[main_label_col]
# Mask scores at or below the cutoff so they can never be selected
scores_above_cutoff = main_scores.where(main_scores > threshold_cutoff)
has_candidate = scores_above_cutoff.notna().any(axis=1)
candidates = scores_above_cutoff[has_candidate]
# idxmin returns the first column holding the minimum (same tie rule as before). Rows with no
# score above the cutoff end up as NaN instead of raising ValueError from min() of an empty set.
thr_estimations = pd.DataFrame({
    'pred_thr': candidates.idxmin(axis=1),
    'pred_score': candidates.min(axis=1),
}).reindex(predictions.index)

thr_estimations

In [None]:
test_data1 = test_data.copy()

# Based on the defined cutoff, determine the threshold from the prediction values
main_scores = predictions[main_label_col]
above_cutoff = main_scores.gt(threshold_cutoff)

# Encoded label = right-most main-label column whose score exceeds the cutoff. Reversing the
# columns makes idxmax find the last True; rows with no score above the cutoff get NaN
# (a row-wise apply with `.index.values[-1]` raised IndexError for them).
last_above_cutoff = above_cutoff.iloc[:, ::-1].astype(int).idxmax(axis=1)
test_data1['predicted_thr_encoded'] = last_above_cutoff.where(above_cutoff.any(axis=1))

# Decode: the digits of the label name are a class index into
# [aul_sound_level] + sound_levels sorted in descending order
_temp = [aul_sound_level] + sorted(sound_levels, reverse=True)
thresh_dict = dict(zip(_temp, list(range(len(_temp)))))
inv_thresh_dict = {v: k for k, v in thresh_dict.items()}
# regex=True is required: since pandas 2.0 str.replace treats the pattern literally by default
thr_class_index = pd.to_numeric(test_data1['predicted_thr_encoded'].str.replace(r'\D+', '', regex=True))
test_data1['predicted_thr'] = thr_class_index.map(inv_thresh_dict)

# Decode and save frequency prediction (highest-scoring frequency column)
test_data1['predicted_freq_encoded'] = predictions.loc[:, freq_label_col].idxmax(axis=1)
test_data1['predicted_freq'] = (test_data1['predicted_freq_encoded']
                                .str.replace(r'\D+', '', regex=True)
                                .astype('int'))
# NOTE(review): the class-index -> frequency mapping is rebuilt from the frequencies present in the
# test split only; this assumes every frequency class also occurs there — TODO confirm
_temp = sorted(test_data1['frequency'].unique())
freq_dict = dict(zip(_temp, list(range(len(_temp)))))
inv_freq_dict = {v: k for k, v in freq_dict.items()}
test_data1['predicted_freq'] = test_data1['predicted_freq'].map(inv_freq_dict)
test_data1.head(2)

In [None]:
# Look up the estimated score for each row label of test_data1, in test_data1's row order
test_data1['prediction_score'] = thr_estimations.loc[test_data1.index, 'pred_score'].values
test_data1

# Validation

## Accuracy

In [None]:
"""Output overall mouse metrics"""

"""Select only test mice"""
result_test = test_data1.loc[test_data1['mouse_id'].isin(test_mice)].copy()
result_test = result_test.sort_values(by='frequency')

# Signed distance between manual and predicted threshold
result_test['thr_dist'] = result_test.threshold - result_test.predicted_thr


def _percent(mask):
    """Return the share of True values in a boolean Series in percent, rounded to 0.1 %.

    Returns NaN for an empty Series instead of dividing by zero.
    """
    if len(mask) == 0:
        return float('nan')
    return 100 * round(mask.sum() / len(mask), 3)


def _mean_abs_error(rows):
    """Return the mean absolute manual-vs-predicted threshold difference of `rows`, rounded to 3 decimals."""
    return round(np.mean(abs(rows.threshold - rows.predicted_thr)), 3)


def _stimulus_name(freq):
    """Return the label used in per-stimulus lines; frequency 100 is reported as 'Click'."""
    return 'Click' if freq == 100 else 'Frequency %d' % freq


frequencies = result_test.frequency.unique()
freq_masks = {freq: result_test.frequency == freq for freq in frequencies}
is_exact = result_test.threshold == result_test.predicted_thr

# Frequency accuracy
freq_acc = _percent(result_test['frequency'] == result_test['predicted_freq'])
print('Frequency accuracy: %.2f%%' % freq_acc)
print()

# Exact accuracy of the main label prediction
thr_acc = _percent(is_exact)
print('threshold accuracy: %.2f%%' % thr_acc)
for freq in frequencies:
    print('   %s: %.2f%%' % (_stimulus_name(freq), _percent(is_exact[freq_masks[freq]])))

# Error of the main label prediction
all_me = _mean_abs_error(result_test)
print('Mean error of all classifications: %.2f' % all_me)

# Error of the wrong classifications only (in brackets: error over all rows of that stimulus)
f_mae = _mean_abs_error(result_test[~is_exact])
print("Absolute mean error of false classifications: %.2f" % f_mae)
for freq in frequencies:
    freq_f_mae = _mean_abs_error(result_test[freq_masks[freq] & ~is_exact])
    freq_mae = _mean_abs_error(result_test[freq_masks[freq]])
    print('   %s: %.2f (%.2f)' % (_stimulus_name(freq), freq_f_mae, freq_mae))

# Accuracy with 5 dB and 10 dB buffers; each headline reports its own tolerance's proportion
prop_mice_5db = _percent(abs(result_test.thr_dist) <= 5)
prop_mice_10db = _percent(abs(result_test.thr_dist) <= 10)
for max_dev_db, proportion in ((5, prop_mice_5db), (10, prop_mice_10db)):
    within = abs(result_test.thr_dist) <= max_dev_db
    print('Proportion of mice with deviations of maximum %ddB: %.2f%%' % (max_dev_db, proportion))
    for freq in frequencies:
        freq_count = freq_masks[freq].sum()
        print('   %s: %.2f%% (%d)' % (_stimulus_name(freq), _percent(within[freq_masks[freq]]), freq_count))

## Plot threshold deviations

In [None]:
from sklearn.metrics import confusion_matrix
import itertools
import seaborn as sns

In [None]:
"""Confusion matrix for all frequencies"""
# Rows without a manual or predicted threshold cannot be placed in the matrix
result_test2 = result_test.dropna(subset=['threshold', 'predicted_thr'])
index = [int(x) for x in sorted(result_test2['threshold'].unique())]
print('threshold', index, len(index))
print()
columns = [int(x) for x in sorted(result_test2['predicted_thr'].unique())]
print('predicted threshold', columns, len(columns))
print()
# Both axes use the union of manual and predicted thresholds; passing it as `labels` guarantees
# the matrix rows/columns are in exactly this order
merged_list = sorted(set(result_test2['threshold'].unique()) | set(result_test2['predicted_thr'].unique()))
print(merged_list)
confusion_mtx = pd.DataFrame(
    confusion_matrix(result_test2['threshold'], result_test2['predicted_thr'], labels=merged_list),
    index=merged_list,
    columns=merged_list)

plt.rcParams.update({'font.size': 14})
fig, ax = plt.subplots(figsize=(20, 16))
# NOTE(review): center=210 is a hard-coded count fixing the colormap midpoint — TODO confirm it
# still suits the size of the test set
sns.heatmap(confusion_mtx, annot=True, fmt="d", cmap='Spectral_r', center=210, ax=ax)
ax.set(xlabel='Predicted threshold', ylabel='Manual threshold')