In [None]:
import pandas as pd
import numpy as np
import os
import datetime
import matplotlib.pyplot as plt
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from torchsampler import ImbalancedDatasetSampler
from sklearn import metrics
from sklearn.metrics import average_precision_score
from efficientnet_architecture import efficientnet_b0 as efficientnetb0 #Ensure the "efficientnet_architecture.py" file is in the same directory as this notebook

# Run on the GPU when one is visible; fall back to CPU so the notebook still
# executes (slowly) on machines without CUDA instead of crashing on .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def znorm(tensor, dim=-1, eps=1e-10):
    """Z-score normalise ``tensor`` along ``dim``.

    Subtracts the mean and divides by the standard deviation, both computed
    along ``dim`` with ``keepdim=True`` so they broadcast back over ``tensor``.
    ``torch.std``'s default (unbiased) estimator is used.

    Args:
        tensor: floating-point tensor to normalise.
        dim: dimension to normalise over (default: the last one).
        eps: added to the standard deviation so constant slices map to 0
            instead of producing a division by zero.

    Returns:
        A tensor with the same shape as ``tensor``.
    """
    avg = tensor.mean(dim=dim, keepdim=True)
    std = tensor.std(dim=dim, keepdim=True)
    return (tensor - avg) / (std + eps)

In [None]:
# Per-split tabular metadata (including the label column used for training).
# Row i is presumably the record for waveform i in the matching wf_*.pt file;
# pipeline() pairs them positionally — TODO confirm against the export script.
df_train, df_val, df_test = (
    pd.read_csv(csv_path)
    for csv_path in ('df_train.csv', 'df_val.csv', 'df_test.csv')
)

In [None]:
# Raw waveform tensors for each split, loaded in train/val/test order.
# NOTE(review): torch.load unpickles the file, so only load .pt files from a
# trusted source (recent torch versions also accept weights_only=True).
wf_train, wf_val, wf_test = (
    torch.load(pt_path)
    for pt_path in ('wf_train.pt', 'wf_val.pt', 'wf_test.pt')
)

In [None]:
# Apply znorm (last-dimension z-score) to each split, replacing the raw tensors.
wf_train, wf_val, wf_test = map(znorm, (wf_train, wf_val, wf_test))

In [None]:
class _WaveformDataset(Dataset):
    """Pairs ``waveforms[i]`` with the ``label`` value in row ``i`` of ``df``.

    Both sides are indexed positionally, so a DataFrame whose index is not
    0..n-1 cannot silently misalign labels and waveforms.
    """

    def __init__(self, df, waveforms, label):
        self.waveforms = waveforms
        self.labels = df[label].to_numpy()
        # Fail fast instead of hitting an IndexError mid-epoch.
        if len(waveforms) < len(self.labels):
            raise ValueError(
                f"waveform tensor has {len(waveforms)} rows but the DataFrame has {len(self.labels)}"
            )

    def __getitem__(self, index):
        return self.waveforms[index], self.labels[index]

    def get_labels(self):
        # Read by torchsampler.ImbalancedDatasetSampler to compute class weights.
        return self.labels

    def __len__(self):
        return len(self.labels)


def _confusion_counts(targets, predictions):
    """Return ``(TP, TN, FP, FN)`` for one batch, treating class 1 as positive.

    Any (target, prediction) pair that is not TP/FN/FP is counted as TN.
    """
    target_pos = targets == 1
    target_neg = targets == 0
    pred_pos = predictions == 1
    pred_neg = predictions == 0
    tp = int((target_pos & pred_pos).sum())
    fn = int((target_pos & pred_neg).sum())
    fp = int((target_neg & pred_pos).sum())
    tn = targets.numel() - tp - fn - fp
    return tp, tn, fp, fn


def _binary_metrics(tp, tn, fp, fn):
    """Sensitivity, specificity, precision and F1 from confusion counts.

    Undefined ratios fall back to 0.5 (sensitivity, specificity) or 0
    (precision, F1).
    """
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.5
    specificity = tn / (tn + fp) if (tn + fp) else 0.5
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    # precision > 0 implies tp > 0, hence sensitivity > 0: no division by zero.
    f1score = (2 * precision * sensitivity) / (precision + sensitivity) if precision else 0.0
    return {'sensitivity': sensitivity, 'specificity': specificity,
            'precision': precision, 'f1': f1score}


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    n_batches = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.long().to(device)
        optimizer.zero_grad()
        outputs, _ = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    return total_loss / max(n_batches, 1)


def _evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns a dict with the mean per-batch ``loss``, ``auroc``, ``auprc`` and
    argmax-based confusion counts ``tp``/``tn``/``fp``/``fn``.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    all_targets = []
    all_scores = []
    counts = [0, 0, 0, 0]  # TP, TN, FP, FN
    with torch.no_grad():
        for inputs, targets in loader:
            all_targets.extend(targets.tolist())
            inputs = inputs.to(device)
            targets = targets.long().to(device)
            outputs, outputssoft = model(inputs)
            total_loss += criterion(outputs, targets).item()
            n_batches += 1
            predictions = outputs.argmax(dim=1)
            counts = [c + b for c, b in zip(counts, _confusion_counts(targets, predictions))]
            # Column 1 of the model's second output is the positive-class score.
            all_scores.extend(outputssoft[:, 1].tolist())
    fpr, tpr, _ = metrics.roc_curve(all_targets, all_scores, pos_label=1)
    tp, tn, fp, fn = counts
    return {
        'loss': total_loss / max(n_batches, 1),
        'auroc': metrics.auc(fpr, tpr),
        'auprc': average_precision_score(all_targets, all_scores),
        'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn,
    }


def pipeline(df_train, df_val, df_test, bs, lr, weight_decay, epoch_num, optimizer_type, label,
             trainvaltest_iter, themodel, modelname, savepath, early_stop_patience=15):
    """Train ``themodel`` and evaluate it every epoch on the requested phases.

    For each epoch the phases in ``trainvaltest_iter`` run in order:

    * ``'training'``: prints the current learning rate(s), runs one pass over
      the training set drawn through ``ImbalancedDatasetSampler``, and prints
      the mean per-batch loss.
    * ``'val'``: prints loss / AUROC / AUPRC, steps ``ReduceLROnPlateau`` on the
      validation loss, and updates early stopping.
    * ``'test'``: prints loss / AUROC / AUPRC and argmax-based sensitivity,
      specificity, precision and F1. It then saves ``model.state_dict()`` to
      ``<savepath><modelname>_<bs>_<lr>/<timestamp>_<label>_<modelname>_epoch_<epoch>.pt``.

    Training stops after ``early_stop_patience`` consecutive validation epochs
    without a new lowest loss. The remaining phases of that epoch still run.

    Reads the notebook globals ``wf_train``, ``wf_val``, ``wf_test`` (waveforms
    aligned row-by-row with the DataFrames) and ``device``.

    NOTE(review): the test split is evaluated and checkpointed every epoch.
    Choose the final checkpoint from validation metrics only, to avoid
    test-set leakage.

    Args:
        df_train, df_val, df_test: per-split DataFrames containing ``label``.
        bs: batch size for every loader.
        lr, weight_decay: optimiser hyper-parameters.
        epoch_num: maximum number of epochs.
        optimizer_type: ``'Adam'`` or ``'RMSprop'``.
        label: name of the target column.
        trainvaltest_iter: iterable of phase names (see above).
        themodel: module whose call returns ``(logits, scores)``. The first
            element feeds ``CrossEntropyLoss``; column 1 of the second is the
            positive-class score (presumably softmax probabilities).
        modelname: used in the run directory and checkpoint names.
        savepath: prefix for the run directory (should end with a separator).
        early_stop_patience: validation epochs without improvement before stopping.

    Raises:
        ValueError: if ``optimizer_type`` is not supported.
    """
    run_dir = savepath + modelname + '_' + str(bs) + '_' + str(lr)
    # makedirs also creates a missing parent directory, which the old
    # os.mkdir + bare except silently skipped (torch.save then failed later).
    os.makedirs(run_dir, exist_ok=True)
    savepath = run_dir + '/'

    train_dataset = _WaveformDataset(df_train, wf_train, label)
    val_dataset = _WaveformDataset(df_val, wf_val, label)
    test_dataset = _WaveformDataset(df_test, wf_test, label)

    train_loader = DataLoader(train_dataset, batch_size=bs,
                              sampler=ImbalancedDatasetSampler(train_dataset))
    val_loader = DataLoader(val_dataset, batch_size=bs)
    test_loader = DataLoader(test_dataset, batch_size=bs)

    model = themodel
    criterion = nn.CrossEntropyLoss().to(device)
    if optimizer_type == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer_type == 'RMSprop':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay,
                                        alpha=0.9, momentum=0.1)
    else:
        raise ValueError(f"unsupported optimizer_type {optimizer_type!r}; expected 'Adam' or 'RMSprop'")
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5, mode='min')

    lowest_val_loss = None
    epochs_without_improvement = 0
    stop_early = False

    for epoch in range(epoch_num):
        epoch_start = datetime.datetime.now()
        for phase in trainvaltest_iter:
            if phase == 'training':
                for param_group in optimizer.param_groups:
                    print(param_group['lr'])
                train_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
                print("[%d] loss: %.3f  " % (epoch + 1, train_loss))

            elif phase == 'val':
                val_result = _evaluate(model, val_loader, criterion)
                print("[%d] loss: %.3f AUROC: %.3f AUPRC: %.3f "
                      % (epoch + 1, val_result['loss'], val_result['auroc'], val_result['auprc']))
                scheduler.step(val_result['loss'])
                if lowest_val_loss is None or val_result['loss'] < lowest_val_loss:
                    lowest_val_loss = val_result['loss']
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1
                    if epochs_without_improvement >= early_stop_patience:
                        stop_early = True

            elif phase == 'test':
                test_result = _evaluate(model, test_loader, criterion)
                scores = _binary_metrics(test_result['tp'], test_result['tn'],
                                         test_result['fp'], test_result['fn'])
                print("[%d] loss: %.3f AUROC: %.3f AUPRC: %.3f "
                      % (epoch + 1, test_result['loss'], test_result['auroc'], test_result['auprc']))
                print("sensitivity: %.3f specificity: %.3f precision: %.3f F1: %.3f"
                      % (scores['sensitivity'], scores['specificity'],
                         scores['precision'], scores['f1']))

                timestamp = "{:%Y%m%d%H%M%S}".format(datetime.datetime.now())
                checkpoint_name = (timestamp + '_' + label + '_' + modelname
                                   + '_' + 'epoch_' + str(epoch) + '.pt')
                torch.save(model.state_dict(), savepath + checkpoint_name)

        print(datetime.datetime.now() - epoch_start)
        if stop_early:
            break

In [None]:
# Grid search: each (architecture, batch size, learning rate, weight decay)
# combination is trained from a freshly initialised model.
save_path = 'testsavepath/'
phases = ['training', 'val', 'test']
EPOCHS = 300
OPTIMIZER = 'Adam'
LABEL_COLUMN = "label"
SEED = 0  # re-seeded per configuration so every run is reproducible on its own

# Only architectures imported in the first cell are listed. The former b1/b2
# branches referenced names that were never imported (a latent NameError).
# Add entries here together with their imports to sweep more variants.
model_factories = {'b0': efficientnetb0}

for modelname in ['b0']:
    if modelname not in model_factories:
        raise ValueError(f"no model factory registered for {modelname!r}")
    build_model = model_factories[modelname]
    for batchsize in [200, 400, 800]:
        for learningrate in [0.1, 0.01, 0.001]:
            for weightdecay in [0, 0.001]:
                # Seeds weight initialisation and the imbalanced sampler's draws.
                torch.manual_seed(SEED)
                model = build_model().to(device)

                print('===================================================')
                print(modelname, batchsize, learningrate, weightdecay)

                if torch.cuda.device_count() > 1:
                    print(f"using all {torch.cuda.device_count()} GPUs")
                    model = nn.DataParallel(model)

                pipeline(df_train, df_val, df_test, batchsize, learningrate, weightdecay,
                         EPOCHS, OPTIMIZER, LABEL_COLUMN, phases, model, modelname, save_path)
