In [None]:
import pandas as pd
import numpy as np
from ast import literal_eval
import base64
import array
import matplotlib.pyplot as plt
import datetime
import os
from sklearn.model_selection import StratifiedGroupKFold
from torch.utils.data import Dataset,DataLoader
from torchsampler import ImbalancedDatasetSampler
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import utils
import matplotlib.pyplot as plt
%matplotlib inline
import time
import copy
import random
from sklearn import metrics
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_curve,auc
from torchsummary import summary
from tableone import TableOne
import copy
from sklearn.model_selection import train_test_split
from lifelines.utils import concordance_index
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedKFold



# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without a usable CUDA device.
device='cuda' if torch.cuda.is_available() else 'cpu'

# Model training

In [None]:
class Swish(nn.Module):
    """Swish activation: multiplies the input element-wise by its own sigmoid."""

    def __init__(self):
        super().__init__()
        # Kept as a sub-module so the gate shows up in the module tree.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x * sigmoid(x)`` with the same shape as ``x``."""
        gate = self.sigmoid(x)
        return torch.mul(x, gate)

class SEBlock(nn.Module):
    """Squeeze-and-excitation gate for ``(batch, channels, length)`` feature maps.

    The length axis is averaged away, the per-channel means go through a
    two-layer MLP whose hidden width is ``in_channels * r``, and a sigmoid
    produces one gate per channel.  ``forward`` returns only the gates,
    shaped ``(batch, channels, 1)``; the caller multiplies them in.
    """

    def __init__(self, in_channels, r=4):
        super().__init__()
        hidden_width = in_channels * r

        self.squeeze = nn.AdaptiveAvgPool1d(1)
        self.excitation = nn.Sequential(
            nn.Linear(in_channels, hidden_width),
            Swish(),
            nn.Linear(hidden_width, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-channel gates in (0, 1), shaped ``(batch, channels, 1)``."""
        batch = x.size(0)
        pooled = self.squeeze(x).view(batch, -1)   # (batch, channels)
        gates = self.excitation(pooled)            # (batch, channels)
        return gates.view(batch, gates.size(1), 1)

class MBConv(nn.Module):
    """Inverted-bottleneck 1-D block: 1x1 expand -> depthwise conv -> SE gate -> 1x1 project.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the projection.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride applied in the 1x1 expansion convolution.
        se_scale: width multiplier forwarded to ``SEBlock``.
        p: survival probability for stochastic depth.  It is only honoured
            when the block has an identity shortcut (``stride == 1`` and
            ``in_channels == out_channels``); otherwise the block is always
            executed, because skipping it would return a tensor of the wrong
            shape.  Values are clamped into ``[0, 1]``.

    In training mode the block is skipped (input returned unchanged) with
    probability ``1 - p``.  In eval mode it is always executed.
    """
    expand = 6
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, se_scale=4, p=0.5):
        super().__init__()
        # The identity skip is only valid when the block preserves the shape.
        self.shortcut = (stride == 1) and (in_channels == out_channels)

        # Stochastic depth is gated on the shortcut (not merely on equal channel
        # counts): a strided block with equal channels must never be dropped.
        # Clamp because torch.bernoulli rejects probabilities outside [0, 1],
        # which float drift in a decreasing schedule can produce.
        survival = float(p) if self.shortcut else 1.0
        self.p = torch.tensor(min(max(survival, 0.0), 1.0)).float()

        hidden = in_channels * MBConv.expand

        # NOTE(review): PyTorch's BatchNorm `momentum` is the weight of the
        # *new* batch statistic, so 0.99 makes the running stats follow the
        # latest batch almost entirely.  If this value was ported from a
        # TensorFlow decay of 0.99, the PyTorch equivalent would be 0.01.
        # Left unchanged here because altering it changes training behaviour.
        self.residual = nn.Sequential(
            nn.Conv1d(in_channels, hidden, 1, stride=stride, padding=0, bias=False),
            nn.BatchNorm1d(hidden, momentum=0.99, eps=1e-3),
            Swish(),
            # depthwise: one filter per channel (groups == channels)
            nn.Conv1d(hidden, hidden, kernel_size=kernel_size,
                      stride=1, padding=kernel_size//2, bias=False, groups=hidden),
            nn.BatchNorm1d(hidden, momentum=0.99, eps=1e-3),
            Swish()
        )

        self.se = SEBlock(hidden, se_scale)

        self.project = nn.Sequential(
            nn.Conv1d(hidden, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
        )

    def forward(self, x):
        """Apply the block to ``x``; adds the identity skip when shapes allow."""
        # Stochastic depth: a Bernoulli draw of 0 skips the whole block.
        if self.training and not torch.bernoulli(self.p):
            return x

        x_shortcut = x
        x_residual = self.residual(x)
        x_se = self.se(x_residual)          # (batch, channels, 1) gates

        x = x_se * x_residual               # broadcast over the length axis
        x = self.project(x)

        if self.shortcut:
            x = x_shortcut + x

        return x

class SepConv(nn.Module):
    """Depthwise-separable 1-D block: depthwise conv -> SE gate -> 1x1 project.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the projection.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.  Previously this
            argument only influenced the shortcut flag and the convolution
            always used stride 1; it is now honoured.  ``stride=1`` behaves
            exactly as before.
        se_scale: width multiplier forwarded to ``SEBlock``.
        p: survival probability for stochastic depth.  Only honoured when the
            block has an identity shortcut (``stride == 1`` and
            ``in_channels == out_channels``); clamped into ``[0, 1]``.

    In training mode the block is skipped (input returned unchanged) with
    probability ``1 - p``.  In eval mode it is always executed.
    """
    expand = 1
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, se_scale=4, p=0.5):
        super().__init__()
        # The identity skip is only valid when the block preserves the shape.
        self.shortcut = (stride == 1) and (in_channels == out_channels)

        # Only droppable when skipping leaves the shape intact.  Clamp because
        # torch.bernoulli rejects probabilities outside [0, 1].
        survival = float(p) if self.shortcut else 1.0
        self.p = torch.tensor(min(max(survival, 0.0), 1.0)).float()

        hidden = in_channels * SepConv.expand

        # NOTE(review): PyTorch's BatchNorm `momentum` is the weight of the
        # *new* batch statistic, so 0.99 makes the running stats follow the
        # latest batch almost entirely.  If this value was ported from a
        # TensorFlow decay of 0.99, the PyTorch equivalent would be 0.01.
        # Left unchanged here because altering it changes training behaviour.
        self.residual = nn.Sequential(
            # depthwise: one filter per channel (groups == channels)
            nn.Conv1d(hidden, hidden, kernel_size=kernel_size,
                      stride=stride, padding=kernel_size//2, bias=False, groups=hidden),
            nn.BatchNorm1d(hidden, momentum=0.99, eps=1e-3),
            Swish()
        )

        self.se = SEBlock(hidden, se_scale)

        self.project = nn.Sequential(
            nn.Conv1d(hidden, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
        )

    def forward(self, x):
        """Apply the block to ``x``; adds the identity skip when shapes allow."""
        # Stochastic depth: a Bernoulli draw of 0 skips the whole block.
        if self.training and not torch.bernoulli(self.p):
            return x

        x_shortcut = x
        x_residual = self.residual(x)
        x_se = self.se(x_residual)          # (batch, channels, 1) gates

        x = x_se * x_residual               # broadcast over the length axis
        x = self.project(x)

        if self.shortcut:
            x = x_shortcut + x

        return x

class EfficientNet(nn.Module):
    """1-D EfficientNet-style classifier for ``(batch, channels, length)`` signals.

    Args:
        num_classes: size of the output layer.
        width_coef: multiplier applied to every stage's channel count
            (truncated with ``int``).
        depth_coef: multiplier applied to every stage's repeat count
            (truncated with ``int``).
        scale: ``scale_factor`` of the linear upsampling applied to the input
            along the length axis.
        dropout: dropout probability before the final linear layer.
        se_scale: width multiplier forwarded to the squeeze-excitation blocks.
        stochastic_depth: when True, blocks receive a per-block survival
            probability that starts at ``p`` and decreases linearly.
        p: starting survival probability used when ``stochastic_depth`` is True.
        in_channels: number of input signal channels (default 8, the value
            that used to be hard-coded in the stem convolution).

    ``forward`` returns a tuple ``(logits, probabilities)`` where the second
    element is the softmax of the first over the class axis.
    """
    def __init__(self, num_classes=10, width_coef=1., depth_coef=1., scale=1., dropout=0.2, se_scale=4, stochastic_depth=False, p=0.5, in_channels=8):
        super().__init__()
        channels = [32, 16, 24, 40, 80, 112, 192, 320, 1280]
        repeats = [1, 2, 2, 3, 3, 4, 1]
        strides = [1, 2, 2, 2, 1, 2, 1]
        kernel_size = [3, 3, 5, 3, 5, 5, 3]

        channels = [int(c * width_coef) for c in channels]
        repeats = [int(r * depth_coef) for r in repeats]

        # stochastic depth
        # NOTE(review): with the default p=0.5 the schedule runs from 0.5 down
        # to ~0 across the blocks (the step is computed from a fixed 1 - 0.5
        # span).  A conventional schedule would start at 1; confirm intent
        # before enabling stochastic_depth.
        if stochastic_depth:
            self.p = p
            # max(..., 1) avoids a ZeroDivisionError for a single-block net
            self.step = (1 - 0.5) / max(sum(repeats) - 1, 1)
        else:
            self.p = 1
            self.step = 0

        # efficient net
        self.upsample = nn.Upsample(scale_factor=scale, mode='linear', align_corners=False)

        # stem: strided conv + BN
        self.stage1 = nn.Sequential(
            nn.Conv1d(in_channels, channels[0], 3, stride=2, padding=1, bias=False),
            nn.BatchNorm1d(channels[0], momentum=0.99, eps=1e-3)
        )

        # stage2 uses SepConv, stages 3-8 use MBConv.  Registration order and
        # attribute names (stage2 .. stage8) are unchanged, so existing
        # state_dicts still load.
        stage_blocks = [SepConv] + [MBConv] * 6
        for i, block in enumerate(stage_blocks):
            stage = self._make_Block(block, repeats[i], channels[i], channels[i + 1],
                                     kernel_size[i], strides[i], se_scale)
            setattr(self, 'stage%d' % (i + 2), stage)

        # head: 1x1 conv to the final feature width
        self.stage9 = nn.Sequential(
            nn.Conv1d(channels[7], channels[8], 1, stride=1, bias=False),
            nn.BatchNorm1d(channels[8], momentum=0.99, eps=1e-3),
            Swish()
        )

        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(channels[8], num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``(logits, softmax(logits))`` for a batch of signals."""
        x = self.upsample(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)   # (batch, features)
        x = self.dropout(x)
        x = self.linear(x)
        xsoft = self.softmax(x)
        return x, xsoft


    def _make_Block(self, block, repeats, in_channels, out_channels, kernel_size, stride, se_scale):
        """Build one stage: ``repeats`` blocks, only the first using ``stride``.

        Each block is given the current survival probability ``self.p``, which
        is then decreased by ``self.step`` (so ``self.p`` is mutated across
        calls and stages must be built in network order).
        """
        strides = [stride] + [1] * (repeats - 1)
        layers = []
        for stride in strides:
            # Clamp: repeated float subtraction can leave self.p a hair
            # outside [0, 1], which torch.bernoulli rejects.
            survival = min(max(self.p, 0.0), 1.0)
            layers.append(block(in_channels, out_channels, kernel_size, stride, se_scale, survival))
            in_channels = out_channels
            self.p -= self.step

        return nn.Sequential(*layers)


def efficientnet_b0(num_classes=2):
    """Build the B0 configuration: width x1.0, depth x1.0, input scale 1.0, dropout 0.2."""
    settings = {'width_coef': 1.0, 'depth_coef': 1.0, 'scale': 1.0, 'dropout': 0.2, 'se_scale': 4}
    return EfficientNet(num_classes=num_classes, **settings)

def efficientnet_b1(num_classes=2):
    """Build the B1 configuration: width x1.0, depth x1.1, input scale 240/224, dropout 0.2."""
    settings = {'width_coef': 1.0, 'depth_coef': 1.1, 'scale': 240/224, 'dropout': 0.2, 'se_scale': 4}
    return EfficientNet(num_classes=num_classes, **settings)

def efficientnet_b2(num_classes=2):
    """Build the B2 configuration: width x1.1, depth x1.2, input scale 260/224, dropout 0.3."""
    settings = {'width_coef': 1.1, 'depth_coef': 1.2, 'scale': 260/224., 'dropout': 0.3, 'se_scale': 4}
    return EfficientNet(num_classes=num_classes, **settings)

def efficientnet_b3(num_classes=2):
    """Build the B3 configuration: width x1.2, depth x1.4, input scale 300/224, dropout 0.3."""
    settings = {'width_coef': 1.2, 'depth_coef': 1.4, 'scale': 300/224, 'dropout': 0.3, 'se_scale': 4}
    return EfficientNet(num_classes=num_classes, **settings)

def efficientnet_b4(num_classes=2):
    """Build the B4 configuration: width x1.4, depth x1.8, input scale 380/224, dropout 0.4."""
    settings = {'width_coef': 1.4, 'depth_coef': 1.8, 'scale': 380/224, 'dropout': 0.4, 'se_scale': 4}
    return EfficientNet(num_classes=num_classes, **settings)

def efficientnet_b5(num_classes=2):
    """Build the B5 configuration: width x1.6, depth x2.2, input scale 456/224, dropout 0.4."""
    settings = {'width_coef': 1.6, 'depth_coef': 2.2, 'scale': 456/224, 'dropout': 0.4, 'se_scale': 4}
    return EfficientNet(num_classes=num_classes, **settings)

def efficientnet_b6(num_classes=2):
    """Build the B6 configuration: width x1.8, depth x2.6, input scale 528/224, dropout 0.5."""
    settings = {'width_coef': 1.8, 'depth_coef': 2.6, 'scale': 528/224, 'dropout': 0.5, 'se_scale': 4}
    return EfficientNet(num_classes=num_classes, **settings)

def efficientnet_b7(num_classes=2):
    """Build the B7 configuration: width x2.0, depth x3.1, input scale 600/224, dropout 0.5."""
    settings = {'width_coef': 2.0, 'depth_coef': 3.1, 'scale': 600/224, 'dropout': 0.5, 'se_scale': 4}
    return EfficientNet(num_classes=num_classes, **settings)

In [None]:
# Prefix (directory, including any trailing separator) that is prepended
# verbatim to the waveform file name below.
filedir='' # enter file path here
# Waveform store used by the datasets inside `pipeline`, where it is indexed as
# wf_total[primary_key][:, start:stop].
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load waveform files from a trusted source.
wf_total=torch.load(filedir+'.pt') # enter the name of the waveform file here

In [None]:
# Window geometry shared by the datasets below: every sample handed to the
# model is a 1250-point window.  Evaluation cuts 4 consecutive windows per
# record; training draws a random window start in [0, 3750] (inclusive).
_SEGMENT_LENGTH = 1250
_SEGMENTS_PER_RECORD = 4
_MAX_RANDOM_START = 3750


class _RandomWindowDataset(Dataset):
    """Training dataset: one randomly positioned window per record per access."""

    def __init__(self, frame, label, waveforms):
        self.keys = frame['primary_key'].to_numpy()
        self.labels = frame[label].to_numpy()
        self.waveforms = waveforms

    def __getitem__(self, index):
        start = random.randint(0, _MAX_RANDOM_START)
        window = self.waveforms[self.keys[index]][:, start:start + _SEGMENT_LENGTH]
        return window, self.labels[index]

    def get_labels(self):
        """Per-record labels (presumably what ImbalancedDatasetSampler reads to balance classes)."""
        return np.array(self.labels)

    def __len__(self):
        return len(self.keys)


class _FixedWindowDataset(Dataset):
    """Evaluation dataset: 4 fixed, consecutive windows per record."""

    def __init__(self, frame, label, waveforms):
        self.keys = frame['primary_key'].to_numpy()
        self.labels = frame[label].to_numpy()
        self.waveforms = waveforms

    def __getitem__(self, index):
        # index enumerates (record, window) pairs, record-major
        record, segment = divmod(index, _SEGMENTS_PER_RECORD)
        start = segment * _SEGMENT_LENGTH
        window = self.waveforms[self.keys[record]][:, start:start + _SEGMENT_LENGTH]
        return window, self.labels[record]

    def __len__(self):
        return len(self.keys) * _SEGMENTS_PER_RECORD


def _safe_ratio(numerator, denominator, default):
    """Return ``numerator / denominator``, or ``default`` when the denominator is zero."""
    return numerator / denominator if denominator else default


def _confusion_counts(targets, prediction):
    """Return ``(TP, TN, FP, FN)`` for one batch of integer label/prediction tensors.

    Anything that is not TP, FN or FP is counted as TN, matching the original
    if/elif/else chain.
    """
    positive = targets == 1
    tp = int((positive & (prediction == 1)).sum())
    fn = int((positive & (prediction == 0)).sum())
    fp = int(((targets == 0) & (prediction == 1)).sum())
    tn = targets.numel() - tp - fn - fp
    return tp, tn, fp, fn


def _summary_metrics(tp, tn, fp, fn):
    """Return ``(sensitivity, specificity, precision, f1score)`` from confusion counts.

    Undefined ratios fall back to 0.5 (sensitivity, specificity) or 0
    (precision) instead of raising ZeroDivisionError.
    """
    sensitivity = _safe_ratio(tp, tp + fn, 0.5)
    specificity = _safe_ratio(tn, tn + fp, 0.5)
    precision = _safe_ratio(tp, tp + fp, 0)
    if precision != 0:
        f1score = (2 * precision * sensitivity) / (precision + sensitivity)
    else:
        f1score = 0
    return sensitivity, specificity, precision, f1score


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``.

    Returns ``(summed_batch_loss, (TP, TN, FP, FN))``.  Tensors are moved to
    the module-level ``device``.
    """
    model.train()
    running_loss = 0.0
    totals = (0, 0, 0, 0)
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.long().to(device)
        optimizer.zero_grad()
        outputs, _ = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        prediction = torch.max(outputs.data, 1)[1]
        batch_counts = _confusion_counts(targets, prediction)
        totals = tuple(t + c for t, c in zip(totals, batch_counts))
    return running_loss, totals


def _evaluate(model, loader, criterion):
    """Score ``model`` on ``loader`` without gradient tracking.

    Returns ``(summed_batch_loss, AUROC, average_precision, (TP, TN, FP, FN))``.
    The ranking metrics use the softmax probability of class 1.
    """
    model.eval()
    running_loss = 0.0
    totals = (0, 0, 0, 0)
    prediction_score = []
    targets_list = []
    with torch.no_grad():
        for inputs, targets in loader:
            targets_list.extend(targets.tolist())
            inputs = inputs.to(device)
            targets = targets.long().to(device)
            outputs, outputssoft = model(inputs)
            running_loss += criterion(outputs, targets).item()
            prediction = torch.max(outputs.data, 1)[1]
            batch_counts = _confusion_counts(targets, prediction)
            totals = tuple(t + c for t, c in zip(totals, batch_counts))
            prediction_score.extend(outputssoft[:, 1].tolist())

    fpr, tpr, _ = metrics.roc_curve(targets_list, prediction_score, pos_label=1)
    auroc = metrics.auc(fpr, tpr)
    ap = average_precision_score(targets_list, prediction_score)
    return running_loss, auroc, ap, totals


def pipeline(bs,lr,weight_decay,epoch_num,optimizer_type,label,trainvaltest_iter,themodel,modelname):
    """Train ``themodel`` on the ECG table and evaluate it every epoch.

    Steps:
      1. Read the data CSV.  Rows whose order code contains '체크업' (Korean
         for "checkup") and whose ``datediff_int`` is 0 form the test cohort;
         all other rows are split into train/validation using the first fold
         of a 5-fold StratifiedGroupKFold grouped by ``patient_id``.
      2. Write ``df_train.csv``, ``df_val.csv`` and ``df_test.csv`` under
         ``savepath``.
      3. For each epoch, run the phases listed in ``trainvaltest_iter`` in
         order.  The 'val' phase steps a ReduceLROnPlateau scheduler on the
         summed validation loss and drives early stopping (15 consecutive
         epochs without a new minimum).  The 'test' phase also saves the
         model's ``state_dict`` to a timestamped file.

    Args:
        bs: batch size for the train/val/test DataLoaders.
        lr: optimizer learning rate.
        weight_decay: optimizer weight decay.
        epoch_num: maximum number of epochs.
        optimizer_type: ``'Adam'`` or ``'RMSprop'``.
        label: name of the binary label column in the CSV.
        trainvaltest_iter: iterable of phase names drawn from
            ``'training'``, ``'val'``, ``'test'``.
        themodel: model already placed on ``device``; its forward must return
            ``(logits, softmax_probabilities)``.
        modelname: tag embedded in the saved checkpoint file names.

    Returns:
        None.  Results are printed and written to disk.

    Raises:
        ValueError: if ``optimizer_type`` is not one of the supported names.

    Reads the module-level ``device`` and ``wf_total``.
    """
    # Fail fast instead of hitting a NameError after the data has been processed.
    if optimizer_type not in ('Adam','RMSprop'):
        raise ValueError("optimizer_type must be 'Adam' or 'RMSprop', got %r" % (optimizer_type,))

    df_merged3=pd.read_csv('.csv') # enter the name of the data file here
    print('len df_merged3',len(df_merged3))

    # Vectorised substring match.  astype(str) makes missing order codes count
    # as "not a checkup" instead of crashing the row-by-row `in` test.
    is_checkup=df_merged3['처방코드(코드명)'].astype(str).str.contains('체크업',regex=False)
    df_test=df_merged3[is_checkup].reset_index(drop=True)
    df_trainval=df_merged3[~is_checkup].reset_index(drop=True)

    # Keep only test rows with datediff_int inside [0, 0]; written as a
    # two-sided window so the bounds are easy to widen.
    datediff=df_test['datediff_int']
    df_test=df_test[(datediff<=0)&(datediff>=-0)].reset_index(drop=True)
    print('test length',len(df_test))
    for cutoff in (100,400,1000):
        print('test >= %d:' % cutoff,int((df_test['CACS']>=cutoff).sum()))

    # Only the first of the five folds is used: ~80% train / ~20% validation,
    # with no patient appearing on both sides.
    skf=StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0)
    train_idx,val_idx=next(skf.split(df_trainval,df_trainval[label],df_trainval['patient_id']))
    df_train=df_trainval.iloc[train_idx].reset_index(drop=True)
    df_val=df_trainval.iloc[val_idx].reset_index(drop=True)

    print('len df_train',len(df_train))
    print('len df_val',len(df_val))
    print('len df_test',len(df_test))
    print('len df_test positive',np.sum(df_test[label]))

    savepath=''# enter savepath here

    df_train.to_csv(savepath+'df_train.csv')
    df_val.to_csv(savepath+'df_val.csv')
    df_test.to_csv(savepath+'df_test.csv')

    train_dataset=_RandomWindowDataset(df_train,label,wf_total)
    val_dataset=_FixedWindowDataset(df_val,label,wf_total)
    test_dataset=_FixedWindowDataset(df_test,label,wf_total)

    train_dataset_dataloader=DataLoader(train_dataset,batch_size=bs,sampler=ImbalancedDatasetSampler(train_dataset))
    val_dataset_dataloader=DataLoader(val_dataset,batch_size=bs)
    test_dataset_dataloader=DataLoader(test_dataset,batch_size=bs)

    model=themodel
    criterion=nn.CrossEntropyLoss().to(device)
    if optimizer_type=='Adam':
        optimizer=torch.optim.Adam(model.parameters(),lr=lr,weight_decay=weight_decay)
    else:  # 'RMSprop' (validated above)
        optimizer=torch.optim.RMSprop(model.parameters(),lr=lr, weight_decay=weight_decay,alpha=0.9, momentum=0.1)
    scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience=5,factor=0.5,mode='min')

    early_stop_patience=15
    epochs_without_improvement=0
    stop_requested=False

    # NOTE(review): the printed "loss" is the summed batch loss divided by a
    # fixed 100, not a per-batch mean; kept as-is so logs stay comparable.
    train_fmt="[%d] loss: %.3f  sensitivity: %.3f specificity: %.3f precision: %.3f f1score: %.3f"
    eval_fmt="[%d] loss: %.3f AUROC: %.3f AUPRC: %.3f sensitivity: %.3f specificity: %.3f precision: %.3f f1score: %.3f"

    for epoch in range(epoch_num):
        now1=datetime.datetime.now()

        for phase in trainvaltest_iter:
            if phase=='training':
                for param_group in optimizer.param_groups:
                    print(param_group['lr'])
                running_loss,counts=_train_one_epoch(model,train_dataset_dataloader,criterion,optimizer)
                sensitivity,specificity,precision,f1score=_summary_metrics(*counts)
                print(train_fmt % (epoch+1,running_loss/100,sensitivity,specificity,precision,f1score))

            elif phase=='val':
                running_loss,AUROC_val,AP,counts=_evaluate(model,val_dataset_dataloader,criterion)
                sensitivity,specificity,precision,f1score=_summary_metrics(*counts)
                print(eval_fmt % (epoch+1,running_loss/100,AUROC_val,AP,sensitivity,specificity,precision,f1score))
                scheduler.step(running_loss)

                # Early stopping on the summed validation loss.
                if epoch==0:
                    lowest_val_loss=running_loss
                elif running_loss<lowest_val_loss:
                    lowest_val_loss=running_loss
                    epochs_without_improvement=0
                else:
                    epochs_without_improvement+=1
                if epochs_without_improvement==early_stop_patience:
                    stop_requested=True

            elif phase=='test':
                running_loss,AUROC,AP,counts=_evaluate(model,test_dataset_dataloader,criterion)
                sensitivity,specificity,precision,f1score=_summary_metrics(*counts)
                print(eval_fmt % (epoch+1,running_loss/100,AUROC,AP,sensitivity,specificity,precision,f1score))

                # One checkpoint per epoch: <timestamp>_<label>_<modelname>_epoch_<epoch>.pt
                # (epoch is zero-based in the file name, one-based in the log line).
                stamp="{:%Y%m%d%H%M%S}".format(datetime.datetime.now())
                checkpoint_name=stamp+'_'+label+'_'+modelname+'_epoch_'+str(epoch)+'.pt'
                torch.save(model.state_dict(),savepath+checkpoint_name)

        now2=datetime.datetime.now()
        print(now2-now1)

        # The remaining phases of the triggering epoch still run (as before);
        # training stops before the next epoch starts.
        if stop_requested:
            break

In [None]:
# Hyper-parameter grid search: every combination trains a fresh model for up
# to 50 epochs with Adam on the "label0" target.
phases=['training','val','test']
experimentnum=0
# `efficientnet_b0` is the factory defined above (the previous spelling
# `efficientnetb0` was undefined and raised NameError).
for model in [efficientnet_b0]:
    for batchsize in [128,256,512,1024]:
        for learningrate in [0.001,0.01,0.1]:
            for weightdecay in [0,0.001,0.01]:
                print('===================================================')
                print(experimentnum,model,batchsize,learningrate,weightdecay)
                themodel=model().to(device)
                # Tag: "experimentnum<N>_<factory name>_"; __name__ replaces the
                # fragile parsing of the function's repr string.
                modelname='experimentnum'+str(experimentnum)+'_'+model.__name__+'_'
                pipeline(batchsize,learningrate,weightdecay,50,'Adam',"label0",phases,themodel,modelname)
                experimentnum+=1