# This notebook provides training, saving and evaluation for the vanilla RNN and CNN models
Models trained by this notebook can later be loaded in Task3 for ensembling.

In [None]:
# Enable IPython's autoreload extension in mode 2 (reload all imported modules before
# running a cell), so edits to the project's `src` package do not need a kernel restart.
%load_ext autoreload
%autoreload 2

from IPython.display import display, HTML
# Inject CSS that stretches the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))


## Choose the dataset

In [None]:
# Select the dataset used by every later cell. Switch by (un)commenting one of the
# two assignments below; the value must be a key of N_CLASSES_BY_DATASET.
# DATASET = "MITBIH"
DATASET = "PTBDB"

# Number of target classes for each supported dataset.
N_CLASSES_BY_DATASET = {"MITBIH": 5, "PTBDB": 2}

# Fail fast on a misspelled name. Silently falling back to the 2-class setup would
# load PTBDB data while later cells still branch on DATASET and write their output
# into a directory named after the misspelled value.
if DATASET not in N_CLASSES_BY_DATASET:
    raise ValueError(
        f"Unknown DATASET {DATASET!r}; expected one of {sorted(N_CLASSES_BY_DATASET)}"
    )

N_CLASSES = N_CLASSES_BY_DATASET[DATASET]
    

## Imports

In [None]:
import torch
import numpy as np

import skorch
from skorch.callbacks import LRScheduler, EarlyStopping, Checkpoint

from torch.optim.lr_scheduler import ReduceLROnPlateau

from copy import deepcopy

from src.data_loading import load_data_mitbih, load_data_ptbdb
from src.data_preprocessing import preprocess_x_pytorch, preprocess_y_pytorch
from src.metrics_utils import compute_metrics, compute_metrics_from_keras, skorch_f1_score, sklearn_f1_score
from src.cnn_models.cnn import CNN
from src.skorch_utils import get_neural_net_classifier, get_class_weights
from src.json_utils import serialize_tensors, save_file

# Seed NumPy's and PyTorch's global random number generators with the same fixed
# value so repeated runs of the notebook start from the same random state.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

## CNN Models Section

### Data Loading

In [None]:
# Pick the loader that matches the configured number of classes
# (5 -> MITBIH, anything else -> PTBDB) and fetch the train/test split.
load_dataset = load_data_mitbih if N_CLASSES == 5 else load_data_ptbdb
(x, y), (xtest, ytest) = load_dataset()

# Keep independent copies of the raw arrays: x/y/xtest/ytest are overwritten by the
# PyTorch preprocessing below, while the Keras RNN cells work on the *_orig versions.
x_orig, y_orig = deepcopy(x), deepcopy(y)
xtest_orig, ytest_orig = deepcopy(xtest), deepcopy(ytest)

print(x.shape)
print(np.unique(y))

# Sanity check: the train and test splits must contain exactly the same set of labels.
train_labels = np.unique(y)
test_labels = np.unique(ytest)
assert np.array_equal(train_labels, test_labels)



In [None]:
# Run the project's PyTorch preprocessing helpers on features and labels of both
# splits. This overwrites x/y/xtest/ytest; the raw data stays available in *_orig.
x = preprocess_x_pytorch(x)
xtest = preprocess_x_pytorch(xtest)
y = preprocess_y_pytorch(y)
ytest = preprocess_y_pytorch(ytest)


### Train CNN Model

In [None]:
# Best hyperparameters found for the vanilla CNN, per dataset. Only the number of
# fully connected features, the learning rate and the class weights of the loss
# differ between MITBIH and PTBDB; everything else is shared.
if DATASET == "MITBIH":
    cnn_fc_features = 64
    cnn_lr = 0.0002
    cnn_class_weights = torch.Tensor([1., 1., 1., 1., 1.])
else:
    cnn_fc_features = 256
    cnn_lr = 0.0008
    cnn_class_weights = torch.Tensor([1.7981, 0.6926])

# Key order is kept stable because the dict is serialized to params.json below.
params = {
    'module__strides': [2, 1],
    'module__n_filters': [64, 128],
    'module__kernel_sizes': [13, 7],
    'module__adaptive_average_len': 8,
    'module__fully_connected_features': cnn_fc_features,
    'module__residual': False,
    'lr': cnn_lr,
    'iterator_train__batch_size': 256,
    'criterion__weight': cnn_class_weights,
}

# Directory that receives both the saved hyperparameters and the checkpoint files.
cnn_dir = "CnnVanilla" + "_" + DATASET

# need to save these params to be able to load the model later
save_file(cnn_dir + "/params.json", serialize_tensors(params))

# callback for printing f1 score on validation set during fitting
macro_f1_cb = skorch.callbacks.EpochScoring(scoring=skorch_f1_score, lower_is_better=False)

# stop training once the validation f1 score has not improved for 25 epochs
early_stopping_cb = EarlyStopping(patience=25, monitor="skorch_f1_score", lower_is_better=False)

# Reference the PyTorch scheduler by its full path: the bare name `ReduceLROnPlateau`
# is rebound to the Keras callback of the same name by the RNN imports further down,
# which would hand the wrong class to skorch if this cell is re-run afterwards.
lr_scheduler_cb = LRScheduler(
    policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
    min_lr=0.000001,
    patience=2,
    verbose=True,
)

# callback for saving the best model according to validation f1 score
cp_cb = Checkpoint(dirname=cnn_dir, monitor="skorch_f1_score_best")

net = get_neural_net_classifier(
    module=CNN,
    n_classes=N_CLASSES,
    callbacks=[macro_f1_cb, lr_scheduler_cb, early_stopping_cb, cp_cb],
    params=params,
)
net.fit(x, y)


### Evaluate CNN Model

In [None]:
# Restore the checkpoint with the best validation f1 score before scoring the test set.
# Without this, `net` is evaluated as fit() left it (last epoch, which early stopping
# with patience=25 lets run past the best one), i.e. not the model that was saved for
# later ensembling. The RNN evaluation reloads its best checkpoint in the same way.
net.load_params(checkpoint=cp_cb)

y_proba = net.predict_proba(xtest)
compute_metrics(ytest, y_proba, name="Vanilla_CNN")


### Train RNN Model

In [None]:
from tensorflow import keras
from keras import optimizers, losses, activations, models
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, ReduceLROnPlateau
from keras.layers import Dense, Input, Dropout, LSTM, GRU, SimpleRNN
from src.rnn_models.rnn import get_rnn_model

In [None]:
import os

model = get_rnn_model(DATASET, "vanilla")

file_path = "RnnVanilla_"+DATASET+"/model.h5"
# Make sure the output directory exists before training starts: not every Keras
# version creates missing directories when ModelCheckpoint writes its first file.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# NOTE(review): all three callbacks monitor "val_acc"; that key is only reported when
# the model is compiled with the metric named "acc" -- confirm in get_rnn_model.
# Save the model to file_path only when validation accuracy improves.
checkpoint = ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
# Stop after 5 epochs without improvement in validation accuracy.
early = keras.callbacks.EarlyStopping(monitor="val_acc", mode="max", patience=5, verbose=1)
# Lower the learning rate after 3 epochs without improvement in validation accuracy.
redonplat = ReduceLROnPlateau(monitor="val_acc", mode="max", patience=3, verbose=2)
callbacks_list = [checkpoint, early, redonplat]

# Train on the raw (non-PyTorch-preprocessed) copies, holding out 10% for validation.
# epochs=1000 is only an upper bound; early stopping ends training in practice.
model.fit(x_orig, y_orig, epochs=1000, verbose=2, callbacks=callbacks_list, validation_split=0.1)


### Evaluate RNN Model

In [None]:
# Reload the weights written by ModelCheckpoint (saved only when val_acc improved),
# so the test metrics describe the saved model rather than the last training epoch.
model.load_weights(file_path)
# Predict on the raw test copy kept before the PyTorch preprocessing was applied.
y_proba = model.predict(xtest_orig)
compute_metrics_from_keras(ytest_orig, y_proba, name="Vanilla_RNN")
