# Training and Evaluation
This notebook presents a minimal example of how to set up, train, and evaluate the deep learning architecture presented in:
*"A Flexible Deep Learning Architecture for Temporal Sleep Stage Classification using Accelerometry and Photoplethysmography".* 

The notebook follows the structure of the work by Chambon et al. 2018 $^{1}$.

Requirements:
1. Install the repository.
	- pip install -e .
2. Install dependencies.
	- dependencies
3. Prepare the signal modalities and annotations as h5 files.
	- minimal_example/to_h5.py

###### $^{1}$ S. Chambon, V. Thorey, P. J. Arnal, E. Mignot, A. Gramfort, "A deep learning architecture to detect events in EEG signals during sleep," IEEE 28th International Workshop on Machine Learning for Signal Processing (MLSP), 2018. [[Paper](https://arxiv.org/abs/1807.05981)|[Github](https://github.com/Dreem-Organization/dosed)]



In [1]:
import os
import numpy as np

# Data paths below are relative to the current working directory (check it with print(os.getcwd())).
# Change this path to the root of your local copy of the repository.
os.chdir('C:\\Users\\mads_\\OneDrive - Danmarks Tekniske Universitet\\Dokumenter\\python\\MasterAlgorithm')
from datasets import get_train_validation_test

seed = 2022

In [2]:
# Directory holding the h5 files, relative to the current working directory.
# os.path.join avoids the invalid escape sequences ('\M', '\d', '\h') of a backslash string literal.
data_directory = os.path.join('..', 'MasterAlgorithm', 'data', 'h5')
training_set, evaluation_set, test_set = get_train_validation_test(data_directory,
                                                                   percent_test=40,
                                                                   percent_validation=20,
                                                                   seed=seed)

# Number of records in each split
print('training set: {}'.format(len(training_set)))
print('evaluation set: {}'.format(len(evaluation_set)))
print('test set: {}'.format(len(test_set)))

training set: 16
evaluation set: 4
test set: 12
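The printed counts (16, 4 and 12 out of 32 records) are consistent with first taking 40% of the records for testing and then 20% of the *remaining* records for validation. The sketch below reproduces that arithmetic with a hypothetical `split_records` helper and made-up file names; the split rule is an assumption inferred from the counts, not the implementation of `get_train_validation_test`.

```python
import random

def split_records(records, percent_test, percent_validation, seed):
    """Hypothetical record-level split: test records first, then validation from the remainder."""
    shuffled = sorted(records)                 # fixed starting order so the seed alone determines the split
    random.Random(seed).shuffle(shuffled)
    n_test = int(len(shuffled) * percent_test / 100)
    test, rest = shuffled[:n_test], shuffled[n_test:]
    n_validation = int(len(rest) * percent_validation / 100)
    validation, training = rest[:n_validation], rest[n_validation:]
    return training, validation, test

records = [f'record_{i:02d}.h5' for i in range(32)]   # 32 hypothetical file names
training, validation, test = split_records(records, percent_test=40, percent_validation=20, seed=2022)
print(len(training), len(validation), len(test))      # 16 4 12
```

Splitting by record rather than by window keeps all windows of one night in a single split, so no subject's data leaks between training and evaluation.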


# Dataset preparation

## Signal preprocessing
- "h5_path": string - path within the h5 file to the signal modality.
- "channel_idx": list of integers - indexes of the channels to extract from the signal modality.
- "preprocessing": list of dicts - preprocessing steps. A signal modality can have an arbitrary number of preprocessing steps.
- "batch_normalization": dict - normalization operations applied during training.
- "transformations": dict - transformation (augmentation) operations applied to the signal input during training.
- "add": boolean - whether the signal channels are summed after the initial preprocessing (to reduce dimensionality).
- "fs_post": float - sampling frequency of the preprocessed signal.
- "dimensions": list of integers - dimensions of the preprocessed signal.

Our algorithm assumes that both ACC and PPG have a sampling frequency of 32 Hz. If they do not, a resampling step should be added as the first preprocessing operation. 
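To illustrate what such a resampling step does, here is a minimal FFT-based resampler in NumPy. The function name `fft_resample` is hypothetical, and FFT resampling is a swapped-in technique, not the repository's preprocessing operation; a polyphase filter such as `scipy.signal.resample_poly` is a common alternative for long recordings.

```python
import numpy as np

def fft_resample(x, fs_in, fs_out):
    """Resample a 1-D signal from fs_in to fs_out by truncating or zero-padding its spectrum."""
    x = np.asarray(x, dtype=float)
    n_in = len(x)
    n_out = int(round(n_in * fs_out / fs_in))
    spectrum = np.fft.rfft(x)
    new_spectrum = np.zeros(n_out // 2 + 1, dtype=complex)
    k = min(len(spectrum), len(new_spectrum))
    new_spectrum[:k] = spectrum[:k]            # content above the new Nyquist frequency is discarded
    # irfft divides by n_out instead of n_in, so rescale to keep the original amplitude
    return np.fft.irfft(new_spectrum, n_out) * (n_out / n_in)

# A 1 Hz sine recorded at 64 Hz for 10 s, brought down to 32 Hz
t_in = np.arange(640) / 64
x_64hz = np.sin(2 * np.pi * t_in)
x_32hz = fft_resample(x_64hz, fs_in=64, fs_out=32)
print(x_32hz.shape)  # (320,)
```

Because the FFT treats the signal as periodic, the sketch can produce small artifacts at the record edges when the signal does not wrap around smoothly; for a whole night of data these edges are negligible.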

In [3]:
signals_format = {
    'ACC_merge': {
        'h5_path': 'acc_signal',
        'channel_idx': [0, 1, 2],
        'preprocessing': [
            {
                'type': 'median',
                'args': {
                    'window_size': 30
                }
            },
            {
                'type': 'iqr_normalization_adaptive',
                'args': {
                    'median_window': 30001,
                    'iqr_window': 30001
                }
            },
            {
                'type': 'clip_by_iqr',
                'args': {
                    'threshold': 20
                }
            },
            {
                'type': 'cal_psd',
                'args': {
                    'window': 10 * 32, 
                    'noverlap': 8 * 32, 
                    'nfft': int(2 ** np.ceil(np.log2(10 * 32))),
                    'f_min': 0,
                    'f_max': 12,
                    'f_sub': 3
                }
            }

        ],
        'batch_normalization': {},
        'transformations': {
            #'image_translation': {},
            #'time_mask': {},
            #'freq_mask': {},
        },
        'add': True,
        'fs_post': 1,
        'dimensions': [int(2 ** np.ceil(np.log2(10 * 32)) / 32 * (12 - 0)) // 3, 1]
    },
    'ppg_signal': {
        'h5_path': 'ppg_signal',
        'channel_idx': [0],
        'preprocessing': [
            {
                'type': 'zscore',
                'args': {}
            },
            {
                'type': 'change_PPG_direction',
                'args': {}
            },
            {
                'type': 'iqr_normalization_adaptive',
                'args': {
                    'median_window': 301,
                    'iqr_window': 301
                }
            },
            {
                'type': 'clip_by_iqr',
                'args': {
                    'threshold': 20
                }
            },
            {
                'type': 'cal_psd',
                'args': {
                    'window': 10 * 32,
                    'noverlap': 8 * 32,
                    'nfft': int(2 ** np.ceil(np.log2(10 * 32))),
                    'f_min': 0.1,
                    'f_max': 4.1,
                    'f_sub': 1
                }
            }
        ],
        'batch_normalization': {},
        'transformations': {
            #'image_translation': {},
            #'time_mask': {},
            #'freq_mask': {},
        },
        'add': False,
        'fs_post': 1,
        'dimensions': [int(round(2 ** np.ceil(np.log2(10 * 32)) / 32 * (4.1 - 0.1))) // 1, 1]  # round(): 16 * (4.1 - 0.1) is 63.99999999999999 in floating point
    }
}
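The `dimensions` entries follow from the `cal_psd` arguments: with a sampling rate of 32 Hz and `nfft = 512` (the next power of two above the 320-sample window), the spectrum has `nfft / fs = 16` bins per Hz. ACC therefore gives 16 * 12 / 3 = 64 bins and PPG gives 16 * 4 / 1 = 64 bins, matching the 64-wide frequency axis of the batch shape printed in the dataset section. The helper `n_psd_bins` is hypothetical. It rounds before truncating because `4.1 - 0.1` evaluates to `3.9999999999999996` in floating point, so a bare `int()` yields 63 instead of 64.

```python
import numpy as np

def n_psd_bins(fs, window_samples, f_min, f_max, f_sub):
    """Hypothetical helper: number of PSD frequency bins between f_min and f_max."""
    nfft = int(2 ** np.ceil(np.log2(window_samples)))  # next power of two, as in signals_format
    bins_per_hz = nfft / fs                            # inverse of the frequency resolution fs / nfft
    # round() first so floating-point error cannot drop a bin; then divide by f_sub as in 'dimensions'
    return int(round(bins_per_hz * (f_max - f_min))) // f_sub

acc_bins = n_psd_bins(fs=32, window_samples=10 * 32, f_min=0, f_max=12, f_sub=3)
ppg_bins = n_psd_bins(fs=32, window_samples=10 * 32, f_min=0.1, f_max=4.1, f_sub=1)
print(acc_bins, ppg_bins)  # 64 64
```

Matching bin counts matter here because the two modalities are stacked as channels of one `[T, F, C]` input.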

## Event format
Each event is defined by a name, its path within the h5 file, and a probability. The probability is used during batch generation to balance the events.


In [14]:
events = ['wake', 'light', 'deep', 'rem']
# One entry per sleep stage, each with the same sampling probability.
events_format = [
    {
        'name': event,
        'h5_path': event,
        'probability': 1 / len(events)
    }
    for event in events
]
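A minimal sketch of how such probabilities can balance batches, assuming a two-step scheme: draw a sleep stage according to the `probability` values in `events_format`, then place a window so that a randomly chosen occurrence of that stage falls inside it. Both functions and the onset times are hypothetical and only illustrate the idea; this is not the `BalancedDatasetGenerator` implementation.

```python
import numpy as np

events = ['wake', 'light', 'deep', 'rem']
events_format = [{'name': e, 'h5_path': e, 'probability': 1 / len(events)} for e in events]

def sample_event_classes(events_format, n, rng):
    """Draw n event names with the probabilities given in events_format."""
    names = [e['name'] for e in events_format]
    p = np.array([e['probability'] for e in events_format], dtype=float)
    return rng.choice(names, size=n, p=p / p.sum())  # renormalise to guard against rounding

def sample_window_start(event_onsets, window, record_length, rng):
    """Pick a window start (in seconds) such that a randomly chosen event onset lies inside the window."""
    onset = int(rng.choice(event_onsets))
    low = max(0, onset - window + 1)           # earliest start whose window still contains the onset
    high = min(onset, record_length - window)  # latest start that contains the onset and fits in the record
    return int(rng.integers(low, high + 1))

rng = np.random.default_rng(2022)
print(sample_event_classes(events_format, 2, rng))
print(sample_window_start([10_000, 20_000], window=7680, record_length=28_800, rng=rng))
```

With equal probabilities, rare stages such as deep sleep are drawn as often as light sleep, which counteracts the class imbalance of overnight recordings.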

## Dataset class
The dataset class acts as a batch generator at training time. It handles the preprocessing of the signal modalities, which are loaded from their specified h5 paths.

- "records": list of strings - record filenames.
- "h5_directory": string - directory containing the h5 data files.
- "signals_format": dict of dicts - h5 path and preprocessing steps for each signal modality (see above).
- "window": integer - temporal window segment size (in seconds).
- "number_of_channels": integer - number of signal modality inputs.
- "events_format": list of dicts - format of the events to model.
- "prediction_resolution": integer - model output resolution (in seconds).
- "overlap": float - overlap between consecutive window segments. Not used during balanced sampling.
- "minimum_overlap": float - when signals are segmented, events may be cut off at the window borders. The minimum overlap is the required duration of a cut-off event, relative to the window size.
- "batch_size": integer - batch size.
- "mode": string - "train" or "inference".
- "cache_data": boolean - whether to cache the preprocessing using Joblib.
- "n_jobs": integer - number of parallel preprocessing jobs; at most the number of cores available on the local system.
- "seed": integer - random seed.
- "use_mask": boolean - whether to apply a mask. The mask must be defined as an event in the h5 file.
- "load_signal_in_RAM": boolean - whether to load all preprocessed data into RAM during training (faster, but requires more memory).
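To see how `window`, `overlap` and `prediction_resolution` interact, here is a sketch of sliding-window segmentation. It assumes consecutive windows advance by `window * (1 - overlap)` seconds and that a trailing partial window is dropped (the repository may pad it instead); `window_starts` is a hypothetical helper and the 8-hour record is a made-up example. With `window = 30 * 2 ** 8 = 7680` s and `prediction_resolution = 30` s, each window carries 7680 // 30 = 256 labels, which matches the event batch shape `(2, 256, 4)` printed below.

```python
def window_starts(record_length, window, overlap):
    """Start times (s) of the windows that fit entirely inside the record."""
    step = int(window * (1 - overlap))
    return list(range(0, record_length - window + 1, step))

window = 30 * 2 ** 8            # 7680 s, as in dataset_params
prediction_resolution = 30      # one label per 30 s sleep epoch
record_length = 8 * 60 * 60     # hypothetical 8-hour recording (28800 s)

starts = window_starts(record_length, window, overlap=0.25)
labels_per_window = window // prediction_resolution
print(starts)             # [0, 5760, 11520, 17280]
print(labels_per_window)  # 256
```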

In [15]:
from datasets import DatasetGenerator, BalancedDatasetGenerator

dataset_params = {
    "h5_directory": data_directory, 
    "signals_format": signals_format,
    "window": 30 * 2 ** 8, 
    "number_of_channels": len(signals_format), 
    "events_format": events_format,
    "prediction_resolution": 30,
    "overlap": 0.25,
    "minimum_overlap": 0.1,
    "batch_size": 2,
    "cache_data": True,
    "n_jobs": 4,
    "use_mask": True,
    "load_signal_in_RAM": True
}

ds_train = DatasetGenerator(records=training_set, mode="train", **dataset_params)
ds_evaluate = DatasetGenerator(records=evaluation_set, mode="inference", **dataset_params)
ds_test = DatasetGenerator(records=test_set, mode="inference", **dataset_params)

signal, event = ds_train[0]  # first training batch: (signals, event labels)
print(signal.shape)
print(event.shape)



(2, 7680, 64, 2)
(2, 256, 4)


# Model creation

- "input_shape": list of integers: [T, F, C] - temporal, spatial (frequency) and channel dimensions - inferred from signals_format.
- "num_classes": integer - number of classes - inferred from events_format.
- "num_outputs": integer - number of model outputs (timesteps) - inferred from the window size and prediction_resolution.
- "depth": integer - number of encoder and decoder layers in the network (M).
- "init_filter_num": integer - number of filters in the first encoder layer.
- "filter_increment_factor": float - number of filters in layer n = number of filters in layer n-1 * filter_increment_factor.
- "max_pool_size": tuple of integers - max-pooling size used in each layer.
- "kernel_size": tuple of integers - convolution kernel size used in each layer.
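Two quick checks on these parameters. Assuming the filter count compounds as described (layer n has `init_filter_num * filter_increment_factor ** n` filters, rounded to an integer), `init_filter_num = 16` with a factor of 2^(1/3) doubles the filter count every three layers. The standard Conv2D parameter count, `kernel_h * kernel_w * in_channels * filters + filters`, reproduces the 1552 parameters of the first convolution in the summary below (16 x 3 kernel, 2 input channels, 16 filters). Both helpers are hypothetical, and the rounding rule is an assumption about `ResUNet`.

```python
def filters_per_layer(init_filter_num, factor, depth):
    """Filter count for each encoder layer, compounding by `factor` (rounding rule assumed)."""
    filters, f = [], float(init_filter_num)
    for _ in range(depth):
        filters.append(int(round(f)))  # round only when reporting, keep compounding in float
        f *= factor
    return filters

def conv2d_params(kernel_size, in_channels, filters, use_bias=True):
    """Trainable weights of a 2-D convolution: one kernel per (input, output) channel pair plus biases."""
    kh, kw = kernel_size
    return kh * kw * in_channels * filters + (filters if use_bias else 0)

print(filters_per_layer(16, 2 ** (1 / 3), depth=9))       # [16, 20, 25, 32, 40, 51, 64, 81, 102]
print(conv2d_params((16, 3), in_channels=2, filters=16))  # 1552
```

A fractional increment factor grows the network more gently than the usual doubling per layer, which keeps the parameter count manageable at a depth of 9.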


In [16]:
from models import ResUNet

# model creation
model_params = {
    'input_shape': [ds_train.fsTime * ds_train.window, ds_train.nSpace, ds_train.nChannels], 
    'num_classes': len(events),
    'num_outputs': ds_train.window // ds_train.prediction_resolution,
    'depth': 9,
    'init_filter_num': 16,
    'filter_increment_factor': 2 ** (1 / 3),
    'max_pool_size': (2, 2),
    'kernel_size': (16, 3)
}

resunet = ResUNet(**model_params)
resunet.summary() # print summary.

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 7680, 64, 2) 0                                            
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 8192, 64, 2)  0           input_4[0][0]                    
__________________________________________________________________________________________________
conv2d_105 (Conv2D)             (None, 8192, 64, 16) 1552        zero_padding2d_3[0][0]           
__________________________________________________________________________________________________
batch_normalization_139 (BatchN (None, 8192, 64, 16) 64          conv2d_105[0][0]                 
____________________________________________________________________________________________

# Training session

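The compile call below passes `alpha=0.35` and `gamma=2` to `loss_functions['weighted_loss']`. These argument names match the parameters of the focal loss, which down-weights timesteps the model already classifies confidently. Assuming `weighted_loss` is a focal-loss variant, a minimal NumPy version of the standard multi-class focal loss looks like this; it illustrates the idea and is not the code in `functions`.

```python
import numpy as np

def focal_loss(y_true, y_pred, alpha=0.35, gamma=2.0, eps=1e-7):
    """Mean multi-class focal loss. y_true: one-hot labels, y_pred: class probabilities."""
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1.0 - eps)  # avoid log(0)
    p_t = np.sum(np.asarray(y_true) * y_pred, axis=-1)                  # probability of the true class
    return float(np.mean(-alpha * (1.0 - p_t) ** gamma * np.log(p_t)))

y_true = np.array([[1, 0, 0, 0],               # wake
                   [0, 0, 1, 0]])              # deep
y_pred = np.array([[0.9, 0.05, 0.03, 0.02],    # confident and correct
                   [0.3, 0.3, 0.3, 0.1]])      # uncertain
print(focal_loss(y_true, y_pred))
```

With `gamma = 0` and `alpha = 1` the expression reduces to ordinary cross-entropy; increasing `gamma` shrinks the contribution of the confident first row far more than that of the uncertain second row.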
In [17]:
from tensorflow.keras.optimizers import Adam
from functions import loss_functions, metrics


resunet.compile(loss=loss_functions['weighted_loss'](alpha=0.35, gamma=2),
                run_eagerly=False,
                optimizer=Adam(learning_rate=1e-4, epsilon=1e-8),
                metrics=[metrics['cohens_kappa'](num_classes=len(events))])

history = resunet.fit(ds_train,
                      epochs=40,  # the log below is from a 40-epoch run, interrupted manually during epoch 29
                      verbose=1,
                      initial_epoch=0,
                      validation_data=ds_evaluate)


Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
 1/20 [>.............................] - ETA: 7s - loss: 0.1186 - cohens_kappa: 0.8061

KeyboardInterrupt: 
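The `cohens_kappa` value in the log measures agreement between predicted and true sleep stages, corrected for chance: kappa = (p_o - p_e) / (1 - p_e), where p_o is the observed agreement and p_e is the agreement expected from the class frequencies alone. A minimal NumPy sketch on integer stage labels follows; how the repository's metric aggregates over batches and timesteps is not covered, so treat it as an illustration of the formula only.

```python
import numpy as np

def cohens_kappa(y_true, y_pred, num_classes):
    """Cohen's kappa from integer class labels."""
    confusion = np.zeros((num_classes, num_classes), dtype=float)
    np.add.at(confusion, (y_true, y_pred), 1)  # rows: true class, columns: predicted class
    total = confusion.sum()
    p_o = np.trace(confusion) / total          # observed agreement
    p_e = np.sum(confusion.sum(axis=0) * confusion.sum(axis=1)) / total ** 2  # chance agreement
    # p_e == 1 only when both label sets contain the same single class, i.e. perfect agreement
    return 1.0 if p_e == 1 else float((p_o - p_e) / (1 - p_e))

print(cohens_kappa(np.array([0, 0, 1, 1]), np.array([0, 1, 1, 1]), num_classes=4))  # 0.5
```

Here 3 of 4 labels agree (p_o = 0.75), but half of that agreement is expected by chance (p_e = 0.5), so kappa is 0.5 rather than the raw accuracy of 0.75.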

# Test model