In [None]:
import torch, time, os, shutil
import  util
import numpy as np
import pandas as pd
from torch import nn, optim
from torch.utils.data import DataLoader
from dataset import Dataset
from config import cfg
import matplotlib.pyplot as plt
from torch.optim import lr_scheduler

from tqdm import tqdm
#from loss.md_loss import MultiDomainLoss
import torch.nn.functional as F
import data_util as du
import random

import warnings
warnings.filterwarnings("ignore")

In [None]:
def train_epoch(model, optimizer, criterion, scheduler, train_dataloader, device=None):
    """Run one training pass over ``train_dataloader`` and return the mean batch loss.

    For each batch the inputs and targets are cast to float32 and moved to
    ``device``. The model output is passed through a sigmoid and scored by
    ``criterion``. The optimizer and the LR scheduler are then each stepped once.

    Args:
        model: network to train; it is switched to train mode.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable invoked as ``criterion(prediction, target)``.
        scheduler: LR scheduler. It is stepped once per *batch*, not per epoch,
            so e.g. a StepLR ``step_size`` is counted in batches.
        train_dataloader: iterable of ``(inputs, target)`` tensor pairs.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        float: arithmetic mean of the per-batch loss values.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    model.train()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    show_bar = False  # set True to get a tqdm progress bar
    batches = tqdm(train_dataloader) if show_bar else train_dataloader

    losses = []
    for inputs, target in batches:
        data = inputs.to(device=device, dtype=torch.float32)
        labelt = target.to(device=device, dtype=torch.float32)

        optimizer.zero_grad()
        # NOTE(review): if `criterion` expects raw logits (MONAI's FocalLoss applies
        # its own sigmoid), this applies sigmoid twice -- verify against get_loss().
        output = torch.sigmoid(model(data))
        loss = criterion(output, labelt)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-batch scheduling (see docstring)

    if show_bar:
        batches.close()
    if not losses:
        raise ValueError("train_dataloader yielded no batches; cannot compute a mean loss")
    return sum(losses) / len(losses)


def val_epoch(model, optimizer, criterion, scheduler, val_dataloader, device=None):
    """Evaluate ``model`` on ``val_dataloader`` and return the mean batch loss.

    Uses the same forward path as ``train_epoch``: float32 inputs, a sigmoid on
    the model output, and float32 targets. It runs under ``torch.no_grad()``,
    so no autograd graph is recorded and no gradients are touched.

    Args:
        model: network to evaluate; it is switched to eval mode.
        optimizer: unused; kept so the signature matches ``train_epoch``.
        criterion: loss callable invoked as ``criterion(prediction, target)``.
        scheduler: unused; kept so the signature matches ``train_epoch``.
        val_dataloader: iterable of ``(inputs, target)`` tensor pairs.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        float: arithmetic mean of the per-batch loss values.

    Raises:
        ValueError: if ``val_dataloader`` yields no batches.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    show_bar = False  # set True to get a tqdm progress bar
    batches = tqdm(val_dataloader) if show_bar else val_dataloader

    losses = []
    # Evaluation only: skip graph construction to save memory and time.
    with torch.no_grad():
        for inputs, target in batches:
            data = inputs.to(device=device, dtype=torch.float32)
            labelt = target.to(device=device, dtype=torch.float32)
            output = torch.sigmoid(model(data))
            losses.append(criterion(output, labelt).item())

    if show_bar:
        batches.close()
    if not losses:
        raise ValueError("val_dataloader yielded no batches; cannot compute a mean loss")
    return sum(losses) / len(losses)

In [None]:
def get_model(input_channel=4, layer=32, kernel_size=3):
    """Construct a fresh ``FR_Net`` from the project-local ``model`` module.

    The defaults reproduce the configuration that used to be hard-coded here.
    Pass other values to try different architectures without editing this cell.

    Args:
        input_channel: forwarded to ``FR_Net(input_channel=...)``
            (presumably the number of input signal channels -- verify).
        layer: forwarded to ``FR_Net(layer=...)``.
        kernel_size: forwarded to ``FR_Net(kernel_size=...)``.

    Returns:
        A new ``FR_Net`` instance.
    """
    from model import FR_Net  # project-local module, imported at call time

    return FR_Net(input_channel=input_channel, layer=layer, kernel_size=kernel_size)

def get_loss():
    """Build the training criterion: MONAI ``FocalLoss`` with ``to_onehot_y=True``.

    NOTE(review): MONAI documents FocalLoss input as raw logits. The epoch
    functions in this notebook apply a sigmoid before calling the criterion.
    Confirm this combination against the installed MONAI version.
    """
    from monai.losses import FocalLoss

    focal_loss = FocalLoss(to_onehot_y=True)
    return focal_loss



In [None]:
def _predict_labels(model, dataloader, device, threshold=0.5):
    """Return a flat float array of 0/1 predictions for every sample in ``dataloader``.

    Each batch is cast to float32, run through ``model``, squashed with a
    sigmoid and thresholded at ``threshold``. Batch results are flattened and
    concatenated in loader order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    batch_preds = []
    with torch.no_grad():  # inference only: no autograd graph needed
        for inputs, _target in dataloader:
            probs = torch.sigmoid(model(inputs.to(device=device, dtype=torch.float32)))
            batch_preds.append((probs >= threshold).to(probs.dtype).cpu().numpy().ravel())
    if not batch_preds:
        raise ValueError("dataloader yielded no batches; nothing to predict")
    # A single concatenate avoids the quadratic copying of repeated np.hstack.
    return np.concatenate(batch_preds)


def main():
    """Train FR_Net, then score its peak predictions on the test split.

    Reads the notebook globals ``c`` (config object), ``device`` and ``test_idx``.
    After ``c.max_epoch`` epochs of training, the test-set predictions are
    re-framed with ``du.deframe`` and converted to peaks by
    ``util.get_peak_from_label``. They are then compared with
    ``test_dataset.get_fqrs()`` via ``util.evaluate(..., thr=50)``.

    Returns:
        tuple: ``(Recall, Precision, F1_score)`` as returned by ``util.evaluate``.
    """
    model = get_model().to(device)
    start_epoch = 1

    dataset_kwargs = dict(seg_len=c.seg_len, fs=c.fs, test_idx=test_idx)
    train_dataset = Dataset(train=True, **dataset_kwargs)
    train_dataloader = DataLoader(train_dataset, batch_size=c.batch_size, shuffle=True, num_workers=0)
    # NOTE(review): the val and test datasets are built with identical arguments, so the
    # "validation" loss is computed on the test split -- confirm this is intended.
    val_dataset = Dataset(train=False, **dataset_kwargs)
    val_dataloader = DataLoader(val_dataset, batch_size=c.batch_size, num_workers=0)
    test_dataset = Dataset(train=False, **dataset_kwargs)
    test_dataloader = DataLoader(test_dataset, batch_size=c.batch_size, num_workers=0)

    optimizer = optim.Adam(model.parameters(), lr=c.lr)
    criterion = get_loss()
    # train_epoch steps this scheduler once per batch, so step_size counts batches.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.2)

    for epoch in range(start_epoch, c.max_epoch + 1):
        since = time.time()
        train_loss = train_epoch(model, optimizer, criterion, exp_lr_scheduler, train_dataloader)
        val_loss = val_epoch(model, optimizer, criterion, exp_lr_scheduler, val_dataloader)
        print(
            f"epoch {epoch}/{c.max_epoch}  train_loss={train_loss:.4f}  "
            f"val_loss={val_loss:.4f}  ({time.time() - since:.1f}s)"
        )

    pred_label = _predict_labels(model, test_dataloader, device)
    # Regroup the flat predictions into (n_segments, 1, seg_len) before de-framing.
    pred_label = pred_label.reshape((len(pred_label) // c.seg_len, 1, c.seg_len))
    pred_label = du.deframe(pred_label)

    pred_peaks = util.get_peak_from_label(pred_label, fs=c.fs)
    fqrs_rpeaks = test_dataset.get_fqrs()
    Recall, Precision, F1_score = util.evaluate([fqrs_rpeaks], [pred_peaks], fs=c.fs, thr=50)
    return Recall, Precision, F1_score


In [None]:
# Global read by main(): where the model and every batch live.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Seed every RNG in play (Python, NumPy, PyTorch on all devices) so that
# re-running the notebook is repeatable. Some CUDA kernels may still be
# nondeterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

c = cfg()
c.max_epoch = 1  # overrides the config's max_epoch for this run
test_idx = 0  # read as a global by main() and passed to every Dataset split
main()
