## This notebook provides transfer learning functionality for the CNN Models

Specifically, we need the `params.json` of the best PTBDB model. We first pretrain that architecture on MITBIH and then fine-tune it on PTBDB.

In [None]:
%load_ext autoreload
%autoreload 2

from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))


## Imports

In [None]:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

import numpy as np

import skorch
from skorch.callbacks import LRScheduler, EarlyStopping, Checkpoint, Freezer

from copy import deepcopy

from src.data_loading import load_data_mitbih, load_data_ptbdb
from src.data_preprocessing import preprocess_x_pytorch, preprocess_y_pytorch
from src.metrics_utils import compute_metrics, skorch_f1_score, sklearn_f1_score
from src.cnn_models.cnn import CNN
from src.skorch_utils import get_neural_net_classifier, get_class_weights
from src.json_utils import serialize_tensors, deserialize_tensors, save_file, read_file
from src.constants import DEVICE

# Seed the global PyTorch and NumPy random number generators with the same
# fixed value so that repeated runs draw the same random numbers.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)


## Data Loading

In [None]:
# PTBDB: both the training split and the test split are kept.
ptbdb_train_split, ptbdb_test_split = load_data_ptbdb()
x_ptbdb, y_ptbdb = ptbdb_train_split
xtest_ptbdb, ytest_ptbdb = ptbdb_test_split

# MITBIH: only the first (training) split is kept; the second one is discarded.
mitbih_train_split, (_, _) = load_data_mitbih()
x_mitbih, y_mitbih = mitbih_train_split

# Sanity output plus a check that both PTBDB splits contain the same label values.
ptbdb_train_labels = np.unique(y_ptbdb)
print(x_ptbdb.shape)
print(ptbdb_train_labels)
assert np.array_equal(ptbdb_train_labels, np.unique(ytest_ptbdb))


### Data Preprocessing

In [None]:
# Convert inputs and labels with the project's PyTorch preprocessing helpers.
# NOTE: every variable is overwritten with its preprocessed version, so running
# this cell a second time feeds already-preprocessed data through the helpers
# again -- re-run the data loading cell first.
x_ptbdb, xtest_ptbdb = map(preprocess_x_pytorch, (x_ptbdb, xtest_ptbdb))
y_ptbdb, ytest_ptbdb = map(preprocess_y_pytorch, (y_ptbdb, ytest_ptbdb))

x_mitbih = preprocess_x_pytorch(x_mitbih)
y_mitbih = preprocess_y_pytorch(y_mitbih)


## Pretrain best PTBDB Architecture on MITBIH dataset

### Define callbacks for training

In [None]:
# Scoring callback: reports the score computed by `skorch_f1_score` during
# fitting (on the validation set, per the original author's note). Higher is better.
macro_f1_cb = skorch.callbacks.EpochScoring(
    scoring=skorch_f1_score,
    lower_is_better=False,
)

# Training-control callbacks.
# Early stopping watches the "skorch_f1_score" history entry (higher is better)
# with a patience of 25 epochs.
early_stopping_cb = EarlyStopping(
    patience=25,
    monitor="skorch_f1_score",
    lower_is_better=False,
)
# ReduceLROnPlateau scheduling with a patience of 2 and a learning-rate floor of 1e-6.
lr_scheduler_cb = LRScheduler(
    policy=ReduceLROnPlateau,
    min_lr=1e-6,
    patience=2,
    verbose=True,
)


In [None]:
# Read the stored hyper-parameters from the PTBDB model directory.
PTBDB_PARAMS_DIR = "CnnResidual_PTBDB"
params = deserialize_tensors(read_file(PTBDB_PARAMS_DIR + "/params.json"))

# Pretraining uses five classes, so the criterion weight is set to a uniform
# five-entry vector.
params["criterion__weight"] = torch.Tensor([1., 1., 1., 1., 1.])

pretraining_callbacks = [macro_f1_cb, lr_scheduler_cb, early_stopping_cb]
net = get_neural_net_classifier(
    module=CNN,
    n_classes=5,
    callbacks=pretraining_callbacks,
    params=params,
)
net.fit(x_mitbih, y_mitbih)


## Replace classification layer

In [None]:
# inspired from the skorch docs https://github.com/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb
class PretrainedModel(nn.Module):
    """Wrap a copy of a pretrained network whose `linear2` layer is replaced.

    The given model is deep-copied, so the original object is left untouched.
    On the copy, the attribute `linear2` is set to a freshly initialised
    `nn.Linear(fully_connected_features, n_classes)`.
    """

    def __init__(self, n_classes, pretrained_model, fully_connected_features):
        """
        Args:
            n_classes: number of output units of the new `linear2` layer.
            pretrained_model: module to copy; its `linear2` attribute is overwritten on the copy.
            fully_connected_features: number of input features of the new `linear2` layer.
        """
        super().__init__()
        backbone = deepcopy(pretrained_model)
        # overwrite the final layer with a linear layer that maps to only `n_classes` outputs
        backbone.linear2 = nn.Linear(
            in_features=fully_connected_features,
            out_features=n_classes,
        )
        self.model = backbone

    def forward(self, x):
        """Run `x` through the wrapped copy and return its output unchanged."""
        output = self.model(x)
        return output
    

## Fine-tune CNN model and evaluate

In [None]:
# Hyper-parameters carried over unchanged from the pretraining configuration.
finetuning_param_keys = ["lr", "iterator_train__batch_size", "module__fully_connected_features"]

finetuning_params = {
    **{key: params[key] for key in finetuning_param_keys},
    # fine-tuning uses two classes, so the criterion weight is a uniform two-entry vector
    "criterion__weight": torch.Tensor([1., 1.]),
    # the module fitted in the pretraining cell; PretrainedModel deep-copies it
    "module__pretrained_model": net.module_,
}


## Try 1: Retrain everything without freezing

In [None]:
# Fine-tune every layer on PTBDB; no freezing callback is attached.
no_freeze_callbacks = [macro_f1_cb, lr_scheduler_cb, early_stopping_cb]
pretrained_net = get_neural_net_classifier(
    module=PretrainedModel,
    n_classes=2,
    callbacks=no_freeze_callbacks,
    params=finetuning_params,
)
pretrained_net.fit(x_ptbdb, y_ptbdb)

# Evaluate on the PTBDB test split.
y_proba = pretrained_net.predict_proba(xtest_ptbdb)

print("-------------------------\n\n")
compute_metrics(ytest_ptbdb, y_proba, name="Transfer learning CNN - No Freeze")


## Try 2: Freeze everything but the last 2 fully connected layers

In [None]:
def _is_frozen_param(param_name):
    """Return True for every parameter whose name does not start with 'model.linear'."""
    return not param_name.startswith('model.linear')


# Freeze all parameters selected by the predicate; only the 'model.linear*' ones stay trainable.
freezer = Freezer(_is_frozen_param)

freeze_callbacks = [macro_f1_cb, lr_scheduler_cb, early_stopping_cb, freezer]
pretrained_net = get_neural_net_classifier(
    module=PretrainedModel,
    n_classes=2,
    callbacks=freeze_callbacks,
    params=finetuning_params,
)
pretrained_net.fit(x_ptbdb, y_ptbdb)

# Evaluate on the PTBDB test split.
y_proba = pretrained_net.predict_proba(xtest_ptbdb)

print("-------------------------\n\n")
compute_metrics(ytest_ptbdb, y_proba, name="Transfer learning CNN - With Freeze")


## Try 3: First retrain the fully connected layers while keeping the rest frozen, then unfreeze everything

In [None]:
# Continue from the network of the previous cell, whose 'model.linear*' layers
# were fine-tuned while the Freezer callback kept all other parameters frozen.
#
# The Freezer only switches gradients off; nothing switches them back on, and
# the deepcopy inside PretrainedModel carries the requires_grad=False flags
# along. Gradients are therefore re-enabled explicitly on a private copy, so
# that this stage really trains every layer and the Try 2 network itself is
# left untouched.
partially_tuned_module = deepcopy(pretrained_net.module_)
for parameter in partially_tuned_module.parameters():
    parameter.requires_grad_(True)

# Use a separate parameter dict instead of overwriting the shared
# `finetuning_params`, so re-running an earlier cell still starts from the
# MITBIH-pretrained model.
#
# Note: the module passed here is itself a PretrainedModel. The fresh `linear2`
# created by PretrainedModel.__init__ is therefore attached to that wrapper,
# whose forward() only calls its inner `model`; the two-class head that was
# fine-tuned in the previous cell keeps producing the output.
unfreeze_params = dict(finetuning_params)
unfreeze_params["module__pretrained_model"] = partially_tuned_module

all_unfrozen = get_neural_net_classifier(
    module=PretrainedModel,
    n_classes=2,
    callbacks=[macro_f1_cb, lr_scheduler_cb, early_stopping_cb],
    params=unfreeze_params,
)
all_unfrozen.fit(x_ptbdb, y_ptbdb)

# Evaluate on the PTBDB test split.
y_proba = all_unfrozen.predict_proba(xtest_ptbdb)

print("-------------------------\n\n")
compute_metrics(ytest_ptbdb, y_proba, name="Transfer learning CNN - Freeze top layers then unfreeze")