In [None]:
#imports
import sys
import pandas as pd
from sklearn.model_selection import StratifiedKFold
import numpy as np
import os
import time
import random
from sklearn.metrics import accuracy_score
import logging
from livelossplot import PlotLosses

import torch
import monai
from monai.data import DataLoader
from monai.transforms import (
    AddChanneld,
    CenterSpatialCropd,
    Compose,
    Resized,
    RandSpatialCropd,
    ScaleIntensityd,
    ToTensord,
    LoadImaged,
    Identityd,
)

#send log records of level INFO and above to standard output so that they show up in the notebook
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

In [None]:
#path definitions
#directory that receives the model checkpoints; the trailing slash matters because
#file names are appended to it by plain string concatenation
MODEL_DIR = "./SEResNet/"
#csv file with the merged training/validation subjects
path_train_data = "../../data/trainValid_DL.csv"
#csv file that collects the results of the hyperparameter tuning
filenameCSV = "./SEResNet/hyperparameter_tuning_results.csv"

In [None]:
#number of images per physical batch; must be a divisor of 64 because the
#training loops accumulate 64 // BATCH_SIZE batches into one virtual batch of 64
BATCH_SIZE = 2

In [None]:
#create the model directory if it does not exist yet
#exist_ok=True avoids the race between checking for the directory and creating it
os.makedirs(MODEL_DIR, exist_ok=True)

In [None]:
#load ADNI training dataset
#the table is indexed by the PTID column; the cells below read its DX column
#(used for stratification and as label) and its filename column (image paths)
trainValidMerged=pd.read_csv(path_train_data,index_col="PTID")

In [None]:
#define the preprocessing / augmentation pipelines
def _make_transforms(crop):
    """Build the image pipeline shared by training and validation.

    The steps are: load the image, add a channel dimension, rescale the
    intensities, resize to 256x256x256, apply ``crop`` and convert to a tensor.
    Every call creates fresh transform instances, so the two pipelines do not
    share any transform object.

    crop: the cropping transform applied after resizing.
    """
    # NOTE(review): AddChanneld is deprecated in newer MONAI releases - confirm the pinned version
    steps = [
        LoadImaged(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"],spatial_size=(256,256,256)),
        crop,
        ToTensord(keys=["img"]),
    ]
    return Compose(steps)

#training: random 224x224x224 crop as augmentation
train_transforms = _make_transforms(
    RandSpatialCropd(keys=["img"],roi_size=(224,224,224),random_size =False)
)

#validation: deterministic 224x224x224 center crop
valid_transforms = _make_transforms(
    CenterSpatialCropd(keys=["img"],roi_size=(224,224,224))
)


In [None]:
#define function to set seeds for reproducibility
def set_seed(seed):
    """Seed all random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices) with ``seed`` and switches cuDNN to
    deterministic mode with benchmarking disabled.

    seed: integer seed passed unchanged to every generator.
    """
    #seed the generators of Python, NumPy and PyTorch
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    #request deterministic cuDNN behaviour
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


In [None]:
#live plot of loss and accuracy; this single instance is updated in every epoch of
#both training cells below, so all runs are drawn into the same plot
liveloss = PlotLosses()

In [None]:
#grid search over learning rate, optimizer and learning rate schedule;
#every combination is trained with 5-fold stratified cross-validation, a checkpoint is
#written after every epoch and the validation results are appended to filenameCSV
#settings that are identical for every run are defined once, before the loops
#choose cuda as the device if it is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#disable mixed precision as there are some problems with monai models
use_amp = False
#define batchsize factor used for batch accumulation --> the virtual batch size is thus 64
batchsize_factor=64 // BATCH_SIZE
#define interval (in epochs) in which the validation should be performed
val_interval = 1
#set maximal number of epochs
max_epochs = 50
#iterate over all hyperparameters
for lr in [1e-1,1e-2,1e-3,1e-4,1e-5]:
    for opt in ["sgd","adam","rmsprop"]:
        for strategy in ["none","exp","step"]:
            #skip combinations that were already trained completely
            #(the checkpoint of the last fold and last epoch exists)
            file_path=MODEL_DIR+"model_"+str(opt)+"_"+str(lr)+"_"+str(strategy)+"_5_49.pth"
            if os.path.isfile(file_path):
                continue
            #define cross validation splits
            kf = StratifiedKFold(n_splits=5,shuffle=True,random_state=101)
            #iterate over all cross-validation splits, cvIt is the 1-based number of the split
            for cvIt, (trainCross, validCross) in enumerate(kf.split(trainValidMerged,trainValidMerged.DX), start=1):
                #extract training and validation data
                training=trainValidMerged.iloc[trainCross]
                valid=trainValidMerged.iloc[validCross]
                #extract diagnosis for training and validation data
                #get_dummies may return boolean columns (default in recent pandas versions), which
                #CrossEntropyLoss does not accept as class indices --> cast explicitly to int64
                # NOTE(review): assumes DX has exactly two classes (the model uses num_classes=2) - confirm
                Y_train=pd.get_dummies(training.DX,drop_first=True).to_numpy().astype(np.int64).squeeze().tolist()
                Y_valid=pd.get_dummies(valid.DX,drop_first=True).to_numpy().astype(np.int64).squeeze().tolist()
                #reformat training and validation datasets for pytorch
                trainDSNew = [{"img": img, "label": label} for img, label in zip(training.filename, Y_train)]
                validDSNew = [{"img": img, "label": label} for img, label in zip(valid.filename, Y_valid)]
                train_ds = monai.data.Dataset(data=trainDSNew, transform=train_transforms)
                train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
                valid_ds = monai.data.Dataset(data=validDSNew, transform=valid_transforms)
                valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=torch.cuda.is_available())
                #define gradient scaler (a pass-through while use_amp is False)
                scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
                #define list to store the training loss of every epoch
                epoch_loss_values = []
                #set seed so that the model initialisation is reproducible
                set_seed(123)
                #load DL model using monai
                model = monai.networks.nets.SEResNet152(num_classes=2,spatial_dims=3, in_channels=1)
                #define cross entropy as loss function
                loss_function = torch.nn.CrossEntropyLoss()
                #select optimizer
                if opt =="adam":
                    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                elif opt=="sgd":
                    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
                else:
                    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
                #select learning rate scheduler, None means constant learning rate
                if strategy=="step":
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
                elif strategy=="exp":
                    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
                else:
                    scheduler = None
                model=model.to(device)
                #set seed again so that the training itself is reproducible
                set_seed(123)
                #number of training batches, needed to detect the last batch of an epoch
                num_batches = len(train_loader)
                #iterate over epochs
                for epoch in range(max_epochs):
                    #store starting time of epoch
                    start = time.time()
                    #start with model in train mode
                    model.train()
                    logs = {}
                    epoch_loss = 0
                    epoch_loss_val = 0
                    #initialize lists to store predictions and labels of training and validation data
                    preds_epoch=[]
                    label_epoch=[]
                    preds_epoch_val=[]
                    label_epoch_val=[]
                    #iterate over training batches, step is the 1-based number of the batch
                    for step, batch_data in enumerate(train_loader, start=1):
                        #load images and labels for batch onto the selected device
                        #(works on cpu as well, unlike a hard-coded .cuda())
                        inputs=batch_data["img"].to(device)
                        labels=batch_data["label"].to(device)
                        #get model output
                        outputs = model(inputs)
                        #calculate loss, scaled so that the accumulated gradient is an average
                        loss = loss_function(outputs, labels)/batchsize_factor
                        scaler.scale(loss).backward()
                        #batch accumulation: update after every batchsize_factor batches and after
                        #the last batch, so that no gradients are dropped or carried into the next epoch
                        if step % batchsize_factor==0 or step==num_batches:
                            #update scaler and optimizer
                            scaler.step(optimizer)
                            scaler.update()
                            optimizer.zero_grad()
                        #calculate loss over epoch (sum of the unscaled batch losses)
                        epoch_loss += (loss.item()*batchsize_factor)
                        #extract prediction of model
                        _, preds = torch.max(outputs, 1)
                        #store prediction of model and label of subjects in the batch
                        preds_epoch.append(preds.cpu().detach().numpy())
                        label_epoch.append(labels.data.cpu().detach().numpy())
                    #increase step for scheduler once per epoch
                    if scheduler is not None:
                        scheduler.step()
                    #store epoch training loss
                    logs['log loss'] = epoch_loss
                    #calculate and store training accuracy
                    preds_epoch = [item for sublist in preds_epoch for item in sublist]
                    label_epoch = [item for sublist in label_epoch for item in sublist]
                    logs['accuracy']=accuracy_score(label_epoch,preds_epoch)*100
                    epoch_loss_values.append(epoch_loss)
                    #model validation; the accuracy stays nan for epochs without validation
                    val_accuracy = float("nan")
                    if (epoch + 1) % val_interval == 0:
                        #change model to evaluation mode
                        model.eval()
                        with torch.no_grad():
                            #iterate over validation data
                            for val_data in valid_loader:
                                #load images and labels of validation batch
                                inputs=val_data["img"].to(device)
                                labels=val_data["label"].to(device)
                                #calculate outputs of model
                                outputs = model(inputs)
                                #calculate loss
                                loss=loss_function(outputs, labels)
                                #calculate and store prediction and label for batch
                                _, preds= torch.max(outputs, 1)
                                preds_epoch_val.append(preds.cpu().detach().numpy())
                                label_epoch_val.append(labels.data.cpu().detach().numpy())
                                epoch_loss_val += loss.item()
                        #calculate and store epoch loss during validation and accuracy during validation
                        logs['val_log loss'] = epoch_loss_val
                        preds_epoch_val = [item for sublist in preds_epoch_val for item in sublist]
                        label_epoch_val = [item for sublist in label_epoch_val for item in sublist]
                        val_accuracy = accuracy_score(label_epoch_val,preds_epoch_val)*100
                        logs['val_accuracy']=val_accuracy
                    #show training and validation loss and accuracy in liveloss plot
                    liveloss.update(logs)
                    liveloss.send()
                    #save model parameters
                    torch.save(model.state_dict(),MODEL_DIR+"model_"+str(opt)+"_"+str(lr)+"_"+str(strategy)+"_"+str(cvIt)+"_"+str(epoch)+".pth")
                    #save hyperparameters and results of model
                    d = {'optimizer': [opt], 'LR': [lr],'strategy':[strategy],'CV':[cvIt], 'Epoch':[epoch], "Epoch-Accuracy":[val_accuracy],"Epoch-Loss":[epoch_loss_val]}
                    df = pd.DataFrame(data=d)
                    #append to the results table, the header is only written when the file is created
                    if os.path.isfile(filenameCSV):
                        df.to_csv(filenameCSV, mode='a', header=False)
                    else:
                        df.to_csv(filenameCSV, mode='w', header=True)
                    end = time.time()
                    #print time (in seconds) used to train one epoch
                    print(format(end-start))

In [None]:
#load table that includes all results received during the hyperparameter tuning
df_hyperparameters=pd.read_csv(filenameCSV)
#the training cells append to this file on every run, so a combination of optimizer, learning rate,
#strategy, cv split and epoch can occur several times; only the last (most recent) entry is kept
df_filtered=df_hyperparameters.drop_duplicates(subset=["optimizer","LR","strategy","CV","Epoch"],keep="last")

In [None]:
#calculate mean and sd of the accuracies and losses over the cv splits for all hyperparameter combinations
#the rows are collected in a list and converted to a data frame once; growing a data frame
#with pd.concat inside the loop copies all previous rows in every iteration
mean_rows=[]
for optimizer in df_filtered.optimizer.unique():
    for lr in df_filtered.LR.unique():
        for strategy in df_filtered.strategy.unique():
            for epoch in df_filtered.Epoch.unique():
                #results of all cv splits for this combination (may be empty, the statistics are nan then)
                df_combination=df_filtered[((df_filtered.optimizer==optimizer)&(df_filtered.strategy==strategy)&(df_filtered.LR==lr)&(df_filtered.Epoch==epoch))]
                mean_rows.append({'optimizer' : optimizer, 'lr' : lr,'strategy': strategy, 'epoch' : epoch,
                                  "Mean ACC":df_combination["Epoch-Accuracy"].mean(),
                                  "Mean Loss":df_combination["Epoch-Loss"].mean(),
                                  "sd ACC":df_combination["Epoch-Accuracy"].std(),
                                  "sd loss":df_combination["Epoch-Loss"].std()})
#the explicit column list keeps the columns available even if there are no rows
df_mean=pd.DataFrame(mean_rows, columns = ['optimizer', 'lr',"strategy", 'epoch',"Mean ACC","Mean Loss","sd ACC","sd loss"])

In [None]:
#identify hyperparameter combination that achieved the best mean CV accuracy
#drop=True discards the old index instead of inserting it as a column, which keeps the
#cell re-runnable (without it every run adds another index column and eventually fails)
df_mean=df_mean.reset_index(drop=True)
#position of the row with the highest mean accuracy (nan entries are skipped by idxmax)
max_idx=df_mean["Mean ACC"].astype(float).idxmax()
max_obj=df_mean.iloc[max_idx]
#hyperparameters and epoch of the best row
optimizer=max_obj["optimizer"]
lr=max_obj["lr"]
strategy=max_obj["strategy"]
epoch=max_obj["epoch"]

In [None]:
#additional training with the hyperparameters that performed best: a fresh model is trained
#for max_epochs=100 epochs (50 more than during the grid search) in every cross-validation split
#read the best hyperparameters from max_obj instead of relying on loop variables: opt would
#otherwise still hold the last optimizer of the grid search, and optimizer is rebound to the
#torch optimizer object further down
opt=max_obj["optimizer"]
lr=max_obj["lr"]
strategy=max_obj["strategy"]
#choose cuda as the device if it is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#disable mixed precision as there are some problems with monai models
use_amp = False
#define batchsize factor used for batch accumulation --> the virtual batch size is thus 64
batchsize_factor=64 // BATCH_SIZE
#define interval (in epochs) in which the validation should be performed
val_interval = 1
#set maximal number of epochs
max_epochs = 100
#define cross validation splits (same splits as during the grid search)
kf = StratifiedKFold(n_splits=5,shuffle=True,random_state=101)
#iterate over all cross-validation splits, cvIt is the 1-based number of the split
for cvIt, (trainCross, validCross) in enumerate(kf.split(trainValidMerged,trainValidMerged.DX), start=1):
    #extract training and validation data
    training=trainValidMerged.iloc[trainCross]
    valid=trainValidMerged.iloc[validCross]
    #extract diagnosis for training and validation data
    #get_dummies may return boolean columns (default in recent pandas versions), which
    #CrossEntropyLoss does not accept as class indices --> cast explicitly to int64
    # NOTE(review): assumes DX has exactly two classes (the model uses num_classes=2) - confirm
    Y_train=pd.get_dummies(training.DX,drop_first=True).to_numpy().astype(np.int64).squeeze().tolist()
    Y_valid=pd.get_dummies(valid.DX,drop_first=True).to_numpy().astype(np.int64).squeeze().tolist()
    #reformat training and validation datasets for pytorch
    trainDSNew = [{"img": img, "label": label} for img, label in zip(training.filename, Y_train)]
    validDSNew = [{"img": img, "label": label} for img, label in zip(valid.filename, Y_valid)]
    train_ds = monai.data.Dataset(data=trainDSNew, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
    valid_ds = monai.data.Dataset(data=validDSNew, transform=valid_transforms)
    valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=torch.cuda.is_available())
    #define gradient scaler (a pass-through while use_amp is False)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    #define list to store the training loss of every epoch
    epoch_loss_values = []
    #set seed so that the model initialisation is reproducible
    set_seed(123)
    #load DL model using monai
    model = monai.networks.nets.SEResNet152(num_classes=2,spatial_dims=3, in_channels=1)
    #define cross entropy as loss function
    loss_function = torch.nn.CrossEntropyLoss()
    #select optimizer
    if opt =="adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif opt=="sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
    #select learning rate scheduler, None means constant learning rate
    if strategy=="step":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    elif strategy=="exp":
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    else:
        scheduler = None
    model=model.to(device)
    #set seed again so that the training itself is reproducible
    set_seed(123)
    #number of training batches, needed to detect the last batch of an epoch
    num_batches = len(train_loader)
    #iterate over epochs
    for epoch in range(max_epochs):
        #store starting time of epoch
        start = time.time()
        #start with model in train mode
        model.train()
        logs = {}
        epoch_loss = 0
        epoch_loss_val = 0
        #initialize lists to store predictions and labels of training and validation data
        preds_epoch=[]
        label_epoch=[]
        preds_epoch_val=[]
        label_epoch_val=[]
        #iterate over training batches, step is the 1-based number of the batch
        for step, batch_data in enumerate(train_loader, start=1):
            #load images and labels for batch onto the selected device
            #(works on cpu as well, unlike a hard-coded .cuda())
            inputs=batch_data["img"].to(device)
            labels=batch_data["label"].to(device)
            #get model output
            outputs = model(inputs)
            #calculate loss, scaled so that the accumulated gradient is an average
            loss = loss_function(outputs, labels)/batchsize_factor
            scaler.scale(loss).backward()
            #batch accumulation: update after every batchsize_factor batches and after
            #the last batch, so that no gradients are dropped or carried into the next epoch
            if step % batchsize_factor==0 or step==num_batches:
                #update scaler and optimizer
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            #calculate loss over epoch (sum of the unscaled batch losses)
            epoch_loss += (loss.item()*batchsize_factor)
            #extract prediction of model
            _, preds = torch.max(outputs, 1)
            #store prediction of model and label of subjects in the batch
            preds_epoch.append(preds.cpu().detach().numpy())
            label_epoch.append(labels.data.cpu().detach().numpy())
        #increase step for scheduler once per epoch
        if scheduler is not None:
            scheduler.step()
        #store epoch training loss
        logs['log loss'] = epoch_loss
        #calculate and store training accuracy
        preds_epoch = [item for sublist in preds_epoch for item in sublist]
        label_epoch = [item for sublist in label_epoch for item in sublist]
        logs['accuracy']=accuracy_score(label_epoch,preds_epoch)*100
        epoch_loss_values.append(epoch_loss)
        #model validation; the accuracy stays nan for epochs without validation
        val_accuracy = float("nan")
        if (epoch + 1) % val_interval == 0:
            #change model to evaluation mode
            model.eval()
            with torch.no_grad():
                #iterate over validation data
                for val_data in valid_loader:
                    #load images and labels of validation batch
                    inputs=val_data["img"].to(device)
                    labels=val_data["label"].to(device)
                    #calculate outputs of model
                    outputs = model(inputs)
                    #calculate loss
                    loss=loss_function(outputs, labels)
                    #calculate and store prediction and label for batch
                    _, preds= torch.max(outputs, 1)
                    preds_epoch_val.append(preds.cpu().detach().numpy())
                    label_epoch_val.append(labels.data.cpu().detach().numpy())
                    epoch_loss_val += loss.item()
            #calculate and store epoch loss during validation and accuracy during validation
            logs['val_log loss'] = epoch_loss_val
            preds_epoch_val = [item for sublist in preds_epoch_val for item in sublist]
            label_epoch_val = [item for sublist in label_epoch_val for item in sublist]
            val_accuracy = accuracy_score(label_epoch_val,preds_epoch_val)*100
            logs['val_accuracy']=val_accuracy
        #show training and validation loss and accuracy in liveloss plot
        liveloss.update(logs)
        liveloss.send()
        #save model parameters
        torch.save(model.state_dict(),MODEL_DIR+"model_"+str(opt)+"_"+str(lr)+"_"+str(strategy)+"_"+str(cvIt)+"_"+str(epoch)+".pth")
        #save hyperparameters and results of model
        d = {'optimizer': [opt], 'LR': [lr],'strategy':[strategy],'CV':[cvIt], 'Epoch':[epoch], "Epoch-Accuracy":[val_accuracy],"Epoch-Loss":[epoch_loss_val]}
        df = pd.DataFrame(data=d)
        #append to the results table, the header is only written when the file is created
        if os.path.isfile(filenameCSV):
            df.to_csv(filenameCSV, mode='a', header=False)
        else:
            df.to_csv(filenameCSV, mode='w', header=True)
        end = time.time()
        #print time (in seconds) used to train one epoch
        print(format(end-start))

In [None]:
#reload the results table first: the additional training appended new rows to filenameCSV,
#and the table loaded before that training does not contain them
df_hyperparameters=pd.read_csv(filenameCSV)
#for all hyperparameter combinations, the last (most recent) entry is kept
df_filtered=df_hyperparameters.drop_duplicates(subset=["optimizer","LR","strategy","CV","Epoch"],keep="last")

#calculate mean and sd of the accuracies and losses over the cv splits for all hyperparameter combinations
#the rows are collected in a list and converted to a data frame once; growing a data frame
#with pd.concat inside the loop copies all previous rows in every iteration
mean_rows=[]
for optimizer in df_filtered.optimizer.unique():
    for lr in df_filtered.LR.unique():
        for strategy in df_filtered.strategy.unique():
            for epoch in df_filtered.Epoch.unique():
                #results of all cv splits for this combination (may be empty, the statistics are nan then)
                df_combination=df_filtered[((df_filtered.optimizer==optimizer)&(df_filtered.strategy==strategy)&(df_filtered.LR==lr)&(df_filtered.Epoch==epoch))]
                mean_rows.append({'optimizer' : optimizer, 'lr' : lr,'strategy': strategy, 'epoch' : epoch,
                                  "Mean ACC":df_combination["Epoch-Accuracy"].mean(),
                                  "Mean Loss":df_combination["Epoch-Loss"].mean(),
                                  "sd ACC":df_combination["Epoch-Accuracy"].std(),
                                  "sd loss":df_combination["Epoch-Loss"].std()})
#the explicit column list keeps the columns available even if there are no rows
df_mean=pd.DataFrame(mean_rows, columns = ['optimizer', 'lr',"strategy", 'epoch',"Mean ACC","Mean Loss","sd ACC","sd loss"])

In [None]:
#identify hyperparameter combination that achieved the best mean CV accuracy
#drop=True discards the old index instead of inserting it as a column, which keeps the
#cell re-runnable (without it every run adds another index column and eventually fails)
df_mean=df_mean.reset_index(drop=True)
#position of the row with the highest mean accuracy (nan entries are skipped by idxmax)
max_idx=df_mean["Mean ACC"].astype(float).idxmax()
max_obj=df_mean.iloc[max_idx]
#hyperparameters and epoch of the best row
optimizer=max_obj["optimizer"]
lr=max_obj["lr"]
strategy=max_obj["strategy"]
epoch=max_obj["epoch"]

In [None]:
#increase number of epochs by 10 %
#epoch is a 0-based index, hence the +1 / -1 around the rounding
#the value is derived from max_obj and not from the current value of epoch,
#so that running this cell again does not increase the number a second time
best_epoch=max_obj["epoch"]
epoch=int(round(best_epoch+1+best_epoch/10))-1

In [None]:
#calculate polyak models by averaging the parameters of the last 5 checkpoints
#(epoch-4 ... epoch) for each cv iteration
#weight of every checkpoint in the average (5 checkpoints --> 1/5)
beta = 0.2
#model that receives the averaged parameters
final_model=monai.networks.nets.SEResNet152(num_classes=2,spatial_dims=3, in_channels=1)
dict_params_final_model = final_model.state_dict()
for b in range(1,6):
    for a in range(0,5):
        #the checkpoints are plain state dicts, so they are loaded directly instead of building
        #a new network for each of them; map_location="cpu" allows checkpoints that were saved
        #from a gpu to be loaded on a machine without cuda
        PATH=MODEL_DIR+"model_"+str(optimizer)+"_"+str(lr)+"_"+str(strategy)+"_"+str(b)+"_"+str(epoch-a)+".pth"
        params = torch.load(PATH, map_location="cpu")
        #every entry of the final model must be present in the checkpoint
        missing=[name1 for name1 in dict_params_final_model if name1 not in params]
        if missing:
            raise RuntimeError("checkpoint "+PATH+" is missing parameters: "+", ".join(missing))
        for name1 in dict_params_final_model:
            if a==0:
                #first checkpoint of this cv iteration: start a new weighted sum
                dict_params_final_model[name1]=(beta*params[name1])
            else:
                dict_params_final_model[name1]=(dict_params_final_model[name1]+beta*params[name1])
    final_model.load_state_dict(dict_params_final_model)
    #save model
    # NOTE(review): this file name uses the order strategy_lr_optimizer, whereas the
    # checkpoints above use optimizer_lr_strategy - kept unchanged, confirm against the code that loads it
    torch.save(final_model.state_dict(),MODEL_DIR+"model_"+str(strategy)+"_"+str(lr)+"_"+str(optimizer)+"_"+str(b)+"_"+str(epoch)+"_polyak_averaged.pth")
