# Description

In this notebook, a five-fold grouped cross-validation was performed for **Model II** with data from the [German Mouse Clinic](https://www.mouseclinic.de/), where the groups consist of mice.

The class score outputs of Model I are used as input for the second model.

First, the mice are randomly split 4:1 into training and test mice. The training data is then randomly split 4:1 into training and validation mice in each fold.

In [None]:
# Load the autoreload extension (or reload it if it is already loaded)
%reload_ext autoreload
# autoreload mode 2: re-import changed modules before each cell is executed
%autoreload 2
# Show matplotlib figures inline in the notebook output
%matplotlib inline 

In [None]:
# `display` and `HTML` belong to the public `IPython.display` API; importing
# `display` from `IPython.core.display` is deprecated (since IPython 7.14).
from IPython.display import display, HTML

# Inject a CSS rule so that the notebook container uses the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

# Load libraries

In [None]:
import os

import pandas as pd 
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [20, 16]

import ABR_ThresholdFinder_NN.data_preparation as dataprep
from ABR_ThresholdFinder_NN.models import create_model_1, compile_model_1, create_model_2, compile_model_2

from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score

import warnings
warnings.filterwarnings("ignore")

# Define available GPUs: expose the devices 0-3 to CUDA and count them
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
visible_devices = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
# int() keeps the original validation: a non-numeric device id raises ValueError
count_gpus = len([int(device_id) for device_id in visible_devices])
print('%d GPUS' % count_gpus)

# Definitions

In [None]:
# Batch size handed to the data generators
batch_size = 128

# Cutoff for the hearing threshold
threshold_cutoff = 0.5

# Sample weight column (for the main label); starts out empty
sample_weight_col = []

# Potential sound pressure levels measured in dB: 0, 5, ..., 95
sound_levels = list(range(0, 100, 5))
print('potential sound pressure levels [dB]: ' + ', '.join(map(str, sound_levels)))
# Special sound level value that is passed to the data preparation step
aul_sound_level = 999

# Load data

In [None]:
# Path to the data files, for example '../data/GMC'
data_dir = ''
# Saving location for the models
save_dir = '../models_cross-validation/GMC/cv_network_2/'
# CSV file holding the output of the first model
data_file = '../models_cross-validation/GMC/model_1_predicted_curves.csv'

# Report whether the input file is present (a missing file is only reported here)
if not os.path.exists(data_file):
    print('ERROR: data file not found!')
else:
    print('data file: {}'.format(data_file))

In [None]:
"""Load the output data of the first model"""
data = pd.read_csv(data_file)
data.head(2)

# Prepare data

In [None]:
"""Prepare the data for the cross validation approach"""
data1, sample_weight_col = dataprep.prepare_data4training_2(data, sound_levels, aul_sound_level)
data1.head(2)

In [None]:
# Resolve the column names of the prepared frame: main labels, frequency
# labels and the data columns that are fed to Model II
main_label_col, freq_label_col, data_cols = dataprep.get_col_names4model_2_labels(data1)

# Print the first and last entry of every column group
for template, cols in (
    ('main label column [%s ... %s]', main_label_col),
    ('frequency label columns: [%s ... %s]', freq_label_col),
    ('data columns: [%s ... %s]', data_cols),
):
    print(template % (cols[0], cols[-1]))

print('sample weight column: %s' % sample_weight_col)

# Split data

In [None]:
# Distinguish training, validation and testing data - determined by previous model.
# Rows marked 'train' or 'valid' form the cross-validation pool, rows marked
# 'test' are held out.
mouse_group = data1['mouse_group']
train_indices = data1.index[mouse_group.isin(['train', 'valid'])]
test_indices = data1.index[mouse_group == 'test']

# Unique mouse ids behind each set of row labels
train_mice, test_mice = (
    data1.loc[row_labels, 'mouse_id'].unique()
    for row_labels in (train_indices, test_indices)
)

# K-fold cross validation

In [None]:
# Five-fold splitter that is applied to the array of training mouse ids.
# The shuffle is seeded so that the fold assignment is identical after
# "Restart Kernel -> Run All"; without `random_state` every run would produce
# different folds and the per-fold results could not be reproduced.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

## Utils

In [None]:
def get_model_name(k):
    """Return the file name under which the model of fold `k` is stored.

    Example: ``get_model_name(3)`` gives ``'model_2_3.h5'``.
    """
    return 'model_2_{}.h5'.format(k)

In [None]:
def get_model_weights_name(k):
    """Return the file name of the checkpoint written while training fold `k`.

    Example: ``get_model_weights_name(3)`` gives ``'model_2_3_weights.h5'``.
    """
    return 'model_2_{}_weights.h5'.format(k)

## Validation

In [None]:
"""Import data generator"""
from ABR_ThresholdFinder_NN.data_generator import DataGenerator
"""Import callbacks"""
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard

In [None]:
"""Perform cross validation"""
# Validation metrics per fold, keyed by the fold's model file name
VAL_RESULTS = {}
# 1-based fold counter; used for file names, plot titles and log output
fold_var = 1

# Number of data columns; passed to the data generators as `dim`
length_input = len(data_cols)

"""Define criterion for early stopping"""
# Metric that all three callbacks created below monitor
mon_var = 'val_main_prediction_loss'

# The splitter works on the array of unique mouse ids, so all rows of one mouse
# end up on the same side (training or validation) of a fold.
for train_index, val_index in kf.split(train_mice):
    
    print('fold %d' % fold_var)
    
    # Mouse ids of this fold; the indices from kf.split are positions in train_mice
    _train_mice = train_mice[train_index]
    _val_mice = train_mice[val_index]
    # Sanity check: the printed intersection is expected to be empty
    print('Overlap train/validation: %s' % np.intersect1d(_train_mice, _val_mice))
    
    # Rebind train_index / val_index: from positions in train_mice to the row
    # labels of data1 that belong to those mice.
    # NOTE(review): rows are selected by mouse_id over all of data1, not only over
    # train_indices - assumes a mouse never has both 'test' and 'train'/'valid' rows.
    train_index = data1.index[data1['mouse_id'].isin(_train_mice)].values
    val_index = data1.index[data1['mouse_id'].isin(_val_mice)].values
    
    print(train_index, val_index)
    print()
    
    # Both generators read from data1 and differ only in the row labels they serve.
    # NOTE(review): the validation generator is created with shuffle=True as well.
    train_data_generator = DataGenerator(list_IDs=train_index, 
                                         df=data1, 
                                         value_cols=data_cols, 
                                         main_label_col=main_label_col, 
                                         freq_label_col=freq_label_col, 
                                         sl_label_col=[],
                                         dim=length_input, 
                                         batch_size=batch_size, 
                                         shuffle=True,
                                         sample_weight_col=sample_weight_col)
    valid_data_generator = DataGenerator(list_IDs=val_index, 
                                         df=data1, 
                                         value_cols=data_cols, 
                                         main_label_col=main_label_col, 
                                         freq_label_col=freq_label_col, 
                                         sl_label_col=[],
                                         dim=length_input, 
                                         batch_size=batch_size, 
                                         shuffle=True,
                                         sample_weight_col=sample_weight_col)
    
    """Create Model II"""
    # A fresh model is built for every fold
    model = create_model_2(len(data_cols), len(main_label_col), len(freq_label_col))
    """Compile Model II"""
    # `parallel_model` is the object that is trained and evaluated below;
    # `loss` and `loss_weights` are not used any further in this cell
    parallel_model, loss, loss_weights = compile_model_2(model, count_gpus)

    """Create callbacks"""
    # The checkpoint file is only rewritten when mon_var improves (save_best_only=True).
    # NOTE(review): save_weights_only is not passed although the file name says
    # "_weights" - confirm what the checkpoint file actually contains.
    checkpoint = ModelCheckpoint(filepath=save_dir+get_model_weights_name(fold_var), 
                                 verbose=1, save_best_only=True, monitor=mon_var)
    # Early stopping uses a patience of 61 epochs, the learning-rate reduction one of 10
    early_stopper = EarlyStopping(monitor=mon_var, patience=61)
    reduce_lr = ReduceLROnPlateau(monitor=mon_var, patience=10)
    callbacks_list = [checkpoint, early_stopper, reduce_lr]
    
    """Fit the model"""
    # At most 200 epochs; batches are produced by 8 workers with multiprocessing enabled
    history = parallel_model.fit_generator(generator=train_data_generator,
                                           validation_data=valid_data_generator, 
                                           use_multiprocessing=True, 
                                           epochs=200,
                                           workers=8,
                                           shuffle=True,
                                           callbacks=callbacks_list)
    
    """Plot history"""
    """Summarize history for loss"""
    # Training vs. validation loss of the main prediction output, saved as PNG in save_dir
    plt.figure(figsize=(20, 14), dpi= 80, facecolor='w', edgecolor='k')
    plt.plot(history.history['main_prediction_loss'])
    plt.plot(history.history['val_main_prediction_loss'])
    plt.title(get_model_name(fold_var) + ' main prediction loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper right')
    plt.savefig(save_dir + "/model_2_" + str(fold_var) + "_loss_history_.png")
    plt.show()

    """Summarize history for accuracy"""
    # Training vs. validation categorical accuracy of the main prediction output
    plt.figure(figsize=(20, 14), dpi= 80, facecolor='w', edgecolor='k')
    plt.plot(history.history['main_prediction_categorical_accuracy'])
    plt.plot(history.history['val_main_prediction_categorical_accuracy'])
    plt.title(get_model_name(fold_var) + ' main prediction categorical accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='lower right')
    plt.savefig(save_dir + "/model_2_" + str(fold_var) + "_acc_history_.png")
    plt.show()
    
    """Load best model to evaluate the performance of the model"""
    # Restore the checkpointed state into parallel_model, then save `model`
    # (the object returned by create_model_2) under the fold's model file name.
    # NOTE(review): presumably model and parallel_model share their weights -
    # confirm in compile_model_2.
    parallel_model.load_weights(save_dir + get_model_weights_name(fold_var))
    model.save(save_dir + get_model_name(fold_var))
    
    # Evaluate on this fold's validation rows and label the values with the metric names
    results = parallel_model.evaluate(valid_data_generator)
    results = dict(zip(parallel_model.metrics_names,results))

    VAL_RESULTS[get_model_name(fold_var)] = results
    
    # Reset the Keras backend state before the next fold builds a new model
    tf.keras.backend.clear_session()
    
    fold_var+=1

## Validation results

In [None]:
# Build the validation table in one step: one record per fold, holding the
# model file name followed by that fold's metrics (insertion order is kept).
# A comprehension replaces the former `for idx, model in ...` loop, which grew
# the frame cell by cell and overwrote the Keras `model` object of the training
# cell with a string.
val_records = [
    {'model': model_name, **fold_metrics}
    for model_name, fold_metrics in VAL_RESULTS.items()
]
# Without any evaluated fold the table is still created with a `model` column.
df_results = pd.DataFrame(val_records) if val_records else pd.DataFrame(columns=['model'])
df_results

### Print results

In [None]:
# Mean of every loss over the folds of the validation table
for template, column in (
    ('loss: %.2f', 'loss'),
    (' main prediction loss: %.2f', 'main_prediction_loss'),
    (' frequency prediction loss: %.2f', 'frequency_prediction_loss'),
):
    print(template % df_results[column].mean())
print()
# Mean categorical accuracy of the main prediction over the folds
mean_val_accuracy = df_results['main_prediction_categorical_accuracy'].mean()
print('main prediction categorical accuracy: %.2f' % mean_val_accuracy)

## Test results

In [None]:
# Generator over the rows of the held-out test mice; it uses the same columns
# and batch size as the generators of the cross-validation loop.
test_data_generator = DataGenerator(list_IDs=test_indices, 
                                    df=data1, 
                                    value_cols=data_cols, 
                                    main_label_col=main_label_col, 
                                    freq_label_col=freq_label_col, 
                                    sl_label_col=[],
                                    dim=length_input, 
                                    batch_size=batch_size, 
                                    sample_weight_col=sample_weight_col)

# Test metrics per fold model, keyed by the fold's model file name
TST_RESULTS = {}
# Visit exactly as many fold models as the splitter `kf` produced. The range
# was hard-coded as range(1, 6) before, which silently goes out of sync when
# n_splits is changed.
for fold_var in range(1, kf.get_n_splits() + 1):
    
    print(fold_var)
    
    # Load model for testing from the checkpoint file written during training.
    # NOTE(review): this is the "_weights" file and it is read with load_model -
    # confirm that the checkpoint stores a full model and not weights only.
    parallel_model = tf.keras.models.load_model(save_dir + get_model_weights_name(fold_var))
   
    # Evaluate model on test data and label the values with the metric names
    results = parallel_model.evaluate(test_data_generator)
    results = dict(zip(parallel_model.metrics_names,results))
    
    TST_RESULTS[get_model_name(fold_var)] = results

In [None]:
# Build the test table in one step: one record per fold model, holding the
# model file name followed by its test metrics (insertion order is kept).
# A comprehension replaces the former `for idx, model in ...` loop, which grew
# the frame cell by cell and rebound the global name `model` to a string.
test_records = [
    {'model': model_name, **fold_metrics}
    for model_name, fold_metrics in TST_RESULTS.items()
]
# Without any evaluated model the table is still created with a `model` column.
df_results1 = pd.DataFrame(test_records) if test_records else pd.DataFrame(columns=['model'])
df_results1

### Print results

In [None]:
# Mean of every loss over the fold models of the test table
for template, column in (
    ('loss: %.2f', 'loss'),
    (' main prediction loss: %.2f', 'main_prediction_loss'),
    (' frequency prediction loss: %.2f', 'frequency_prediction_loss'),
):
    print(template % df_results1[column].mean())
print()
# Mean categorical accuracy of the main prediction over the fold models
mean_test_accuracy = df_results1['main_prediction_categorical_accuracy'].mean()
print('main prediction categorical accuracy: %.2f' % mean_test_accuracy)