Initial training of SilvaNet.

It saves the trained weights, which are then used by the **transfer** scripts.

In [None]:
import datetime
import os
import sys

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import BeatDataset
from network import SilvaNet

In [None]:
# Run on the CUDA device when PyTorch reports one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# TRAINING SETTINGS
BATCH_SIZE: int = 64  # samples per mini-batch fed to the DataLoader
EPOCHS: int = 10  # number of full passes over the training loader

In [None]:
# EXPERIMENT SETTINGS
MODEL = SilvaNet  # network class instantiated further down
NAME_EXPERIMENT = 'SilvaNet'  # results are written under <DATAROOT>/results/<NAME_EXPERIMENT>
# Dataset folders under <DATAROOT>/dataset; each one is read from its training split.
DATASET = [
    'MIT_BIH_LongTerm',
    'MIT_BIH_NormalSinus',
]
# Forwarded to BeatDataset as N=...; presumably a sample cap — TODO confirm in dataset.py.
N = 240_000

DATAROOT = './data'

In [None]:
# LOAD DATASETS
# This is a training script, so the partition is always the training split.
PARTITION = 'train'
train_subdirs = [f'{name}/{PARTITION}' for name in DATASET]
dataset_train = BeatDataset(f'{DATAROOT}/dataset', train_subdirs, N=N)
# Shuffle so every epoch visits the beats in a different order.
loader_train = DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Instantiate the configured network and move its parameters to the selected device
# (nn.Module.to returns the module itself, so the assignment keeps the same object).
model = MODEL().to(device)

In [None]:
# Class-weighted cross-entropy: class 0 gets weight 0.06 and class 1 gets 0.94.
# Presumably this counters class imbalance in the beat labels — TODO confirm frequencies.
class_weights = torch.tensor([0.06, 0.94], device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [None]:
#%%
# TRAIN
# Optimise `model` on `loader_train` with Adadelta and record every 5th batch
# loss in `losses_tr` (plotted afterwards). Start/end timestamps are printed.
model.train()
LR = 1e-2

losses_tr = []

print(datetime.datetime.now())
n_batches = len(loader_train)
for epoch in range(EPOCHS):
    # Step-decay schedule: every 50 epochs a fresh Adadelta optimizer is built
    # with the current LR, then LR is divided by 10 for the next rebuild.
    # Building a new optimizer also discards Adadelta's accumulated state.
    # NOTE(review): with EPOCHS = 10 this only fires at epoch 0, so the whole
    # run uses lr=1e-2.
    if epoch % 50 == 0:
        optimizer = torch.optim.Adadelta(model.parameters(), lr=LR)
        LR /= 10

    # One pass over the training set.
    for j, batch in enumerate(loader_train):
        optimizer.zero_grad()

        data = batch["data"].to(device)
        target = batch["target"].to(device)
        output = model(data)

        loss = criterion(output, target)  # compute loss
        loss.backward()  # backpropagate
        optimizer.step()  # update weights
        loss_tr = loss.item()

        # Keep a subsample (every 5th batch of each epoch) for the loss curve.
        if j % 5 == 0:
            losses_tr.append(loss_tr)

        # In-place progress line; use j + 1 so the final batch reports 100%,
        # and flush because stdout may be buffered without a newline.
        sys.stdout.write(
            '\r Epoch {} of {}  [{:.2f}%] - loss TR: {:.4f}'.format(
                epoch + 1, EPOCHS, 100 * (j + 1) / n_batches, loss_tr
            )
        )
        sys.stdout.flush()
    # Terminate the carriage-return line so each epoch's final status is kept
    # and the closing timestamp is not printed on the same line.
    sys.stdout.write('\n')
print(datetime.datetime.now())

In [None]:
# Output folder for this experiment's artifacts (loss plot and weights).
result_dir = f'{DATAROOT}/results/{NAME_EXPERIMENT}'
# exist_ok=True so re-running the notebook does not crash with FileExistsError.
# Any existing losses.png / weights.pth in this folder will be overwritten.
os.makedirs(result_dir, exist_ok=True)

In [None]:
import matplotlib.pyplot as plt

# Plot the training-loss curve and save it next to the weights.
# losses_tr holds one single-batch loss for every 5th batch of each epoch, so the
# x-axis counts logged points rather than raw optimizer iterations.
fig, ax = plt.subplots()
ax.plot(losses_tr)
ax.set_xlabel('Logged step (every 5th batch)')
ax.set_ylabel('Loss')
fig.savefig(f'{result_dir}/losses.png')

In [None]:
# Persist the trained parameters (state_dict only, not the module object);
# per the notebook header, the transfer scripts start from these weights.
weights_path = f'{result_dir}/weights.pth'
torch.save(model.state_dict(), weights_path)