Transfer learning from the pre-trained SilvaNet network: retrain only its fully connected block on the WCS_ComfTech_movement training partition

In [None]:
# Mount Google Drive at /gdrive so the project folder stored there can be copied into this Colab runtime.
from google.colab import drive
drive.mount('/gdrive')

In [None]:
# Copy the project folder from the mounted Drive into the current working directory (-a preserves attributes).
!cp -a /gdrive/MyDrive/DL_beat_detection .

In [None]:
import os

In [None]:
# Enter the project's data folder so the archive extraction below runs inside it.
# NOTE: this path is relative to the current directory, so re-running this cell after the later
# os.chdir back to the project root fails; restart the kernel instead of re-running it.
os.chdir(os.path.join('DL_beat_detection', 'data'))

In [None]:
# Extract the MIT-BIH Arrhythmia archive into the current (data) folder.
# The Long-Term and Normal Sinus archives are left disabled below.
#!tar xfz MIT_BIH_LongTerm.tar.gz 
#!tar xfz MIT_BIH_NormalSinus.tar.gz 
!tar xfz MIT_BIH_Arrhythmia.tar.gz 

In [None]:
# Step back out of data/ to the project root, presumably where the project modules imported below
# (network, dataset) live — confirm.
os.chdir(os.pardir)

In [None]:
import os
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import pandas as pd
import sys

from network import SilvaNet

from dataset import BeatDataset

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Display the device selected above, to confirm whether the GPU is being used.
device

In [None]:
#TRAINING SETTINGS
BATCH_SIZE = 64
EPOCHS = 10

# Seed PyTorch's global RNG so that the shuffled DataLoader order (and any other torch randomness)
# is reproducible across Restart & Run All. GPU kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

In [None]:
#EXPERIMENT SETTINGS
# Root folder holding both the datasets (DATAROOT/dataset) and the saved results (DATAROOT/results).
DATAROOT = './data'

# Network class to fine-tune, and the name of the results sub-folder for this run.
MODEL = SilvaNet
NAME_EXPERIMENT = 'transfer_SilvaNet_ComfTech_Movement'

# Dataset folder(s) under DATAROOT/dataset and which partition of them to train on.
DATASET = ['WCS_ComfTech_movement']
PARTITION = 'train'
# NOTE(review): forwarded to BeatDataset as its N argument; the meaning (e.g. a sample count)
# is defined in dataset.py — confirm and document there.
N = 14886


In [None]:
#LOAD DATASETS
# One "<dataset>/<partition>" entry per dataset folder listed in DATASET.
subsets = [name + '/' + PARTITION for name in DATASET]
dataset = BeatDataset(f'{DATAROOT}/dataset', subsets, N=N)

# Shuffled mini-batches used by the training loop.
loader_train = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
# Load the weights of the previously trained SilvaNet run and move the network to `device`.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved from a GPU session
# also loads on a CPU-only runtime (and vice versa).
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(f'{DATAROOT}/results/SilvaNet/weights.pth', map_location=device)

model = MODEL()
model.load_state_dict(state_dict)
model = model.to(device)

In [None]:
# Per-class loss weights for the two output classes. The second class is weighted far more heavily,
# presumably to offset class imbalance — TODO confirm against the class frequencies in the data.
class_weights = torch.Tensor([0.06, 0.94]).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [None]:
# Freeze every parameter tensor except the last N_TRAINABLE_TENSORS, so only those get trained.
# The intent is to retrain only the fully connected block. That relies on SilvaNet registering that
# block's weights/biases last — NOTE(review): confirm the count of 6 against network.py.
N_TRAINABLE_TENSORS = 6
all_params = list(model.parameters())
for frozen in all_params[:-N_TRAINABLE_TENSORS]:
    frozen.requires_grad = False

In [None]:
#TRAIN
# NOTE(review): train() also switches any Dropout/BatchNorm layers in the frozen part to training
# mode (BatchNorm running statistics would keep updating) — confirm this is intended for SilvaNet.
model.train()
LR = 1e-2

# Only the unfrozen tensors are handed to the optimizer; the frozen ones never receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]

losses_tr = []

for epoch in range(EPOCHS):
    # (Re)create the optimizer every 50 epochs, dividing the learning rate by 10 each time.
    # With EPOCHS < 50 this happens only once, at lr=1e-2.
    if epoch % 50 == 0:
        optimizer = torch.optim.Adadelta(trainable_params, lr=LR)
        LR /= 10

    #for each batch in the dataset
    for j, batch in enumerate(loader_train):
        optimizer.zero_grad()

        data = batch["data"].to(device)
        target = batch["target"].to(device)
        output = model(data)

        loss = criterion(output, target) #compute loss
        loss.backward() #backward
        optimizer.step() #update weights
        loss_tr = loss.item()

        # Record the loss, and refresh the progress line, every 5 batches.
        if j % 5 == 0:
            losses_tr.append(loss_tr)

            # Print status in place; (j + 1) counts the batch that was just processed.
            sys.stdout.write('\r Epoch {} of {}  [{:.2f}%] - loss TR: {:.4f}'.format(epoch+1, EPOCHS, 100*(j+1)/len(loader_train), loss_tr))
            sys.stdout.flush()

# Terminate the carriage-return progress line so later output starts on a fresh line.
sys.stdout.write('\n')

In [None]:
# One output folder per experiment under DATAROOT/results.
result_dir = f'{DATAROOT}/results/{NAME_EXPERIMENT}'
# exist_ok=True keeps Restart & Run All working when the folder survives from an earlier run.
# Previously this raised FileExistsError after training had already finished.
# Files saved into it afterwards overwrite that earlier run's outputs.
os.makedirs(result_dir, exist_ok=True)

In [None]:
import matplotlib.pyplot as plt

# Training-loss curve. The training loop appends one value every 5 batches and its batch counter
# restarts each epoch, so the x axis counts logged points, not raw iterations.
fig, ax = plt.subplots()
ax.plot(losses_tr)
ax.set_xlabel('Logged step (every 5 batches)')
ax.set_ylabel('Loss')
ax.set_title(f'Training loss - {NAME_EXPERIMENT}')
fig.savefig(f'{result_dir}/losses.png')

In [None]:
# Persist the fine-tuned weights into this experiment's results folder.
weights_path = f'{result_dir}/weights.pth'
torch.save(model.state_dict(), weights_path)