<a href="https://colab.research.google.com/github/JulioEI/TBCG_Group12/blob/main/training_CNN_master.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import os
import numpy as np
import sys

In [None]:
!git clone https://github.com/JulioEI/TBCG_Group12.git
sys.path.insert(0,'/content/TBCG_Group12/code')
import utils as ut
import model_builders as mb
import bcg_auxiliary as bcg

Define General Parameters

In [13]:
# --- Signal and windowing ---
# Sampling frequency after downsampling (samples per second).
fs = 1250
# Window length in seconds.
window_seconds = 0.05
# Window overlapping (presumably a fraction of the window length -- confirm
# against ut.load_data_pipeline).
overlapping = 0.7

# --- Optimisation ---
# Batch size.
batch_size = 32
# Learning rate.
learning_rate = 1e-5
# Number of training epochs.
epochs = 300

Select Model Type (best results with prob_NANI)

In [15]:
# Model-type switches. The model-building cell checks them in the order
# binary -> Unet -> prob_NANI and builds the first one that is True.

# binary: predict only a binary label for each window (whether there is or is
# not a ripple).
binary = False
# Unet: architecture loosely based on U-Net.
Unet = False
# prob_NANI (default): predict a probability for each window, given by the
# number of points inside the window which are ripples.
prob_NANI = True

Download training and validation data from figshare

In [None]:
#TRAINING DATA
!wget  https://figshare.com/ndownloader/articles/16856182/versions/2 -O zip_Amigo2
!unzip /content/zip_Amigo2 -d /content/Amigo2
!rm zip_Amigo2

!wget  https://figshare.com/ndownloader/articles/16856137/versions/2 -O zip_Som2
!unzip /content/zip_Som2 -d /content/Som2
!rm zip_Som2

#VALIDATION DATA
!wget https://figshare.com/ndownloader/articles/14959449/versions/4 -O zip_Dlx1
!unzip /content/zip_Dlx1 -d /content/Dlx1
!rm zip_Dlx1

!wget https://figshare.com/ndownloader/articles/14960085/versions/1 -O zip_Thy7
!unzip /content/zip_Thy7 -d /content/Thy7
!rm zip_Thy7

Load training data and preprocess it

In [None]:
### LOAD TRAIN DATA and PREPROCESS IT (downsampling, zscore, create windows) ###
# Both training sessions are run through the pipeline with identical settings,
# so the keyword arguments are collected once and reused.
train_load_opts = dict(desired_fs=fs, window_seconds=window_seconds,
                       overlapping=overlapping, zscore=True, binary=binary)

datapath = "/content/Amigo2"
(data_Amigo2, ripples_tags_Amigo2, signal_Amigo2,
 x_train_Amigo2, y_train_Amigo2, indx_map_Amigo2) = ut.load_data_pipeline(
    datapath, **train_load_opts)

datapath = "/content/Som2"
(data_Som2, ripples_tags_Som2, signal_Som2,
 x_train_Som2, y_train_Som2, indx_map_Som2) = ut.load_data_pipeline(
    datapath, **train_load_opts)

### MERGE TRAINING DATA ###
# Stack the two sessions along the first axis (Amigo2 first, then Som2).
x_train = np.vstack([x_train_Amigo2, x_train_Som2])
if not binary:
    # Non-binary targets get an extra axis at position 1 before being stacked.
    y_train = np.vstack([np.expand_dims(y_session, axis=1)
                         for y_session in (y_train_Amigo2, y_train_Som2)])
else:
    y_train = np.vstack([y_train_Amigo2, y_train_Som2])
indx_map_train = np.vstack([indx_map_Amigo2, indx_map_Som2])

Load validation data and preprocess it

In [None]:
###############################################################################
#                             VALIDATION DATA                                 #
###############################################################################
### LOAD VALIDATION DATA ###
# Both validation sessions are run through the pipeline with identical
# settings, so the keyword arguments are collected once and reused.
validation_load_opts = dict(desired_fs=fs, window_seconds=window_seconds,
                            overlapping=overlapping, zscore=True, binary=binary)

datapath = "/content/Dlx1"
(data_Dlx1, ripples_tags_Dlx1, signal_Dlx1,
 x_validation_Dlx1, y_validation_Dlx1, indx_map_Dlx1) = ut.load_data_pipeline(
    datapath, **validation_load_opts)

datapath = "/content/Thy7"
(data_Thy7, ripples_tags_Thy7, signal_Thy7,
 x_validation_Thy7, y_validation_Thy7, indx_map_Thy7) = ut.load_data_pipeline(
    datapath, **validation_load_opts)

### MERGE VALIDATION DATA ###
# Stack the two sessions along the first axis (Dlx1 first, then Thy7).
x_validation = np.vstack([x_validation_Dlx1, x_validation_Thy7])
if not binary:
    # Non-binary targets get an extra axis at position 1 before being stacked.
    y_validation = np.vstack([np.expand_dims(y_session, axis=1)
                              for y_session in (y_validation_Dlx1, y_validation_Thy7)])
else:
    y_validation = np.vstack([y_validation_Dlx1, y_validation_Thy7])
indx_map_validation = np.vstack([indx_map_Dlx1, indx_map_Thy7])

Define the architecture and build the CNN model

In [None]:
###############################################################################
#                                CNN BUILDING                                 #
###############################################################################
# Builds the network selected by the model-type flags (the first True flag
# among binary, Unet, prob_NANI wins) and prepares the callback that stores
# its weights periodically during training.

from tensorflow.keras.callbacks import ModelCheckpoint

# Input of the network: int(fs*window_seconds) samples per window by n_ch
# (second dimension of data_Som2), with a trailing singleton axis.
n_ch = data_Som2.shape[1]
input_shape = (int(fs*window_seconds),n_ch,1)

if binary:
    # Uses the learning rate configured in the general-parameters cell.
    model = mb.model_builder_binary(filters_Conv1 = 32, filters_Conv2 = 16, filters_Conv3=8, filters_Conv4 = 16,
                      filters_Conv5 =16, units_Dense = 60, input_shape = input_shape,
                      learning_rate = learning_rate)
    checkpoint_path = "/content/training_cp/training_binary_v1/cp-{epoch:04d}.ckpt"
elif Unet:
    # The U-Net variant keeps its own learning rate (1e-4), which differs from
    # the notebook-wide default.
    model = mb.model_builder_Unet(filters_Conv1 = 10, filters_Conv2 = 15, filters_Conv3=15, filters_Conv4 = 15,
                      input_shape = input_shape, learning_rate = 1e-4)
    checkpoint_path = "/content/training_cp/training_unet_v1/cp-{epoch:04d}.ckpt"
elif prob_NANI:
    # Uses the learning rate configured in the general-parameters cell.
    model = mb.model_builder_prob(filters_Conv1 = 32, filters_Conv2 = 16, filters_Conv3=8, filters_Conv4 = 16,
                      filters_Conv5 =16, filters_Conv6=8, input_shape = input_shape,
                      learning_rate = learning_rate)
    checkpoint_path = "/content/training_cp/training_prob_vf/cp-{epoch:04d}.ckpt"
else:
    # Without this branch the cell would fail further down with a NameError on
    # checkpoint_path, which hides the real cause.
    raise ValueError("No model type selected: set one of binary, Unet or prob_NANI to True")

#create checkpoint save method
# The checkpoint folder is created together with any missing parent folder
# (/content/training_cp does not exist on a fresh runtime), and an existing
# folder is accepted so the cell can be re-run.
checkpoint_dir = os.path.dirname(checkpoint_path)
os.makedirs(checkpoint_dir, exist_ok=True)
# An integer save_freq is counted in batches: ceil(samples/batch_size) batches
# make one epoch, so weights are written every 25 epochs.
save_freq = int(25*np.ceil(x_train.shape[0]/batch_size))
cp_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=save_freq
 )

Train the CNN and save the final model

In [None]:
###############################################################################
#                                  TRAIN CNN                                  #
###############################################################################
# Fit the model on the merged training set, monitoring the merged validation
# set; cp_callback stores intermediate weights while training runs.
model.fit(x_train,y_train, shuffle = True, epochs = epochs, batch_size = batch_size, 
          callbacks=[cp_callback], validation_data = (x_validation, y_validation))
# checkpoint_path is a template containing "{epoch:04d}". Fill it in with the
# last epoch number so the final weights are not written to a file literally
# named "cp-{epoch:04d}.ckpt".
model.save_weights(checkpoint_path.format(epoch=epochs))
# Save model
# The output folder is created with exist_ok so re-running the cell does not
# fail, and the model is saved to that same absolute folder (a relative
# 'content/model/...' path would point somewhere else).
# NOTE: the file name is fixed to "model_prob_vf" whichever model type was built.
os.makedirs('/content/model', exist_ok=True)
model.save('/content/model/model_prob_vf.h5')

Re-evaluate the final model on the validation data

In [18]:
###############################################################################
#                            EVALUATE CNN OUTPUT                              #
###############################################################################
# Run the trained network on each validation session separately.
y_prediction_Dlx1 = model.predict(x_validation_Dlx1)
y_prediction_Thy7 = model.predict(x_validation_Thy7)

# Turn the network output of each session into ripple events; both sessions
# use the same thresholds.
event_detection_opts = dict(th_zero=5e-1, th_dur=0.01, verbose=False)
events_prediction_Dlx1 = ut.get_ripple_times_from_CNN_output(
    y_prediction_Dlx1, indx_map_Dlx1, **event_detection_opts)
events_prediction_Thy7 = ut.get_ripple_times_from_CNN_output(
    y_prediction_Thy7, indx_map_Thy7, **event_detection_opts)

# Score the predicted events against the tagged ripples of each session and
# print the three returned values (presumably precision, recall and F1, going
# by the variable names -- confirm in bcg.get_score).
P_Dlx1, R_Dlx1, F1_Dlx1 = bcg.get_score(
    ripples_tags_Dlx1, events_prediction_Dlx1, threshold=0.1)
print("Dlx1: ", P_Dlx1, R_Dlx1, F1_Dlx1)

P_Thy7, R_Thy7, F1_Thy7 = bcg.get_score(
    ripples_tags_Thy7, events_prediction_Thy7, threshold=0.1)
print("Thy7: ", P_Thy7, R_Thy7, F1_Thy7)

Dlx1:  0.7149532710280374 0.7251184834123223 0.7200000000000001
Thy7:  0.9854014598540146 0.25375939849624063 0.40358744394618834
