In [None]:
#imports
import sys
import pandas as pd
import numpy as np
import os
import time
import random
from sklearn.metrics import accuracy_score
import logging
from livelossplot import PlotLosses

import torch
import monai
from monai.data import DataLoader
from monai.transforms import (
    AddChanneld,
    CenterSpatialCropd,
    Compose,
    Resized,
    RandSpatialCropd,
    ScaleIntensityd,
    ToTensord,
    LoadImaged,
    Identityd,
)

# Route log records at level INFO and above to standard output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

In [None]:
# Hyperparameters selected during hyperparameter tuning.
# NOTE(review): the training and averaging cells read the optimizer choice from
# `opt` and compare `strategy` against LR-scheduler names ("step"/"exp").
# Previously "adam" was stored in `strategy` and `opt` was never defined
# (NameError); the values are now assigned to the variables actually read.
lr = 1e-4            # learning rate passed to the optimizer
opt = "adam"         # optimizer: "adam", "sgd", anything else -> RMSprop
strategy = "none"    # LR scheduler: "step", "exp", anything else -> constant LR
epoch = 84           # last epoch index; training iterates range(epoch + 1)

In [None]:
# Last epoch index; the training loop iterates range(max_epochs + 1).
max_epochs = epoch

In [None]:
# File-system locations used by this notebook (relative to the working directory).
MODEL_DIR = "./SEResNet/"  # per-epoch checkpoints are written here
path_train_data = "../../data/trainValid_DL.csv"  # ADNI training/validation table
filename_predictions_for_platt_scaling = os.path.join(MODEL_DIR, "predictions_for_platt_scaling.csv")

In [None]:
# Per-step batch size. The training cell accumulates gradients over
# 64 // BATCH_SIZE steps (virtual batch size 64), so BATCH_SIZE must be a
# divisor of 64; anything else would silently change the virtual batch size.
BATCH_SIZE = 2
assert 64 % BATCH_SIZE == 0, "BATCH_SIZE must be a divisor of 64"

In [None]:
# Read the ADNI training/validation table, indexed by the PTID column.
trainValidMerged = pd.read_csv(
    path_train_data,
    index_col="PTID",
)

In [None]:
# Data augmentation / preprocessing pipelines.
# Both pipelines load, add a channel axis, scale intensities and resize to 256^3.
# They differ only in how the 224^3 region is cut out: a random crop for
# training and a center crop for validation.
def _shared_preprocessing():
    """Return fresh instances of the preprocessing steps common to both pipelines."""
    return [
        LoadImaged(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(256, 256, 256)),
    ]


train_transforms = Compose(
    _shared_preprocessing()
    + [
        RandSpatialCropd(keys=["img"], roi_size=(224, 224, 224), random_size=False),
        ToTensord(keys=["img"]),
    ]
)

valid_transforms = Compose(
    _shared_preprocessing()
    + [
        CenterSpatialCropd(keys=["img"], roi_size=(224, 224, 224)),
        ToTensord(keys=["img"]),
    ]
)

In [None]:
# Helper that seeds every RNG used in this notebook, for reproducibility.
def set_seed(seed):
    """Seed PyTorch (CPU and all CUDA devices), NumPy and Python's `random`,
    and force deterministic cuDNN behaviour.

    Args:
        seed: integer seed applied to every random number generator.
    """
    # Deterministic cuDNN kernels; autotuning is disabled because it may pick
    # different algorithms from run to run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)

In [None]:
# Live plot that the training loop updates with the per-epoch loss and accuracy.
liveloss = PlotLosses()

In [None]:
# --- Build the training dataset / loader ------------------------------------
# One-hot encode the diagnosis with the first category dropped; for a binary DX
# column this leaves one 0/1 indicator per subject (assumes DX is binary — TODO
# confirm). dtype=int matters: recent pandas returns bool dummies, and the
# collated bool tensor is rejected by CrossEntropyLoss, which needs integer targets.
Y_train = pd.get_dummies(trainValidMerged.DX, drop_first=True, dtype=int).to_numpy().squeeze()
Y_train = Y_train.tolist()
trainDSNew = [{"img": img, "label": label} for img, label in zip(trainValidMerged.filename, Y_train)]
set_seed(123)
train_ds = monai.data.Dataset(data=trainDSNew, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())

# --- Training configuration --------------------------------------------------
# Mixed precision is disabled because of problems with monai models; the scaler
# is kept so AMP can be re-enabled through `use_amp` alone.
use_amp = False
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
# Gradients are accumulated over `batchsize_factor` mini-batches, giving a
# virtual batch size of 64.
batchsize_factor = 64 // BATCH_SIZE
set_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = monai.networks.nets.SEResNet152(num_classes=2, spatial_dims=3, in_channels=1)
loss_function = torch.nn.CrossEntropyLoss()
# Optimizer: "adam", "sgd", anything else falls back to RMSprop.
if opt == "adam":
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
elif opt == "sgd":
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
else:
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
# LR scheduler: "step" or "exp"; any other value keeps the learning rate constant.
if strategy == "step":
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
elif strategy == "exp":
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
else:
    scheduler = None
model = model.to(device)
# torch.save below fails if the checkpoint directory does not exist yet.
os.makedirs(MODEL_DIR, exist_ok=True)


def _apply_accumulated_gradients():
    """Take one (scaled) optimizer step on the accumulated gradients, then reset them."""
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


set_seed(123)
optimizer.zero_grad()
# Train epochs 0..max_epochs inclusive; every epoch's weights are saved so the
# last few checkpoints can be averaged afterwards.
for epoch in range(max_epochs + 1):
    start = time.time()
    model.train()
    logs = {}
    epoch_loss = 0
    step = 0
    # Predictions and labels of all training batches of this epoch.
    preds_epoch = []
    label_epoch = []
    for batch_data in train_loader:
        step += 1
        # Use the selected device so the loop also runs on CPU-only machines.
        inputs = batch_data["img"].to(device)
        labels = batch_data["label"].to(device)
        outputs = model(inputs)
        # Divide by the accumulation factor so the summed gradients correspond
        # to the mean loss over the virtual batch.
        loss = loss_function(outputs, labels) / batchsize_factor
        scaler.scale(loss).backward()
        # `step` is 1-based, so this fires after every full window of
        # `batchsize_factor` batches.
        if step % batchsize_factor == 0:
            _apply_accumulated_gradients()
        # Undo the accumulation scaling to log the per-batch loss.
        epoch_loss += loss.item() * batchsize_factor
        _, preds = torch.max(outputs, 1)
        preds_epoch.append(preds.detach().cpu().numpy())
        label_epoch.append(labels.detach().cpu().numpy())
    # Apply a trailing partial accumulation window instead of silently carrying
    # its gradients over into the next epoch.
    if step % batchsize_factor != 0:
        _apply_accumulated_gradients()
    if scheduler is not None:
        scheduler.step()
    # Sum of the per-batch training losses over the epoch.
    logs['log loss'] = epoch_loss
    # Training accuracy in percent over all batches of the epoch.
    logs['accuracy'] = accuracy_score(np.concatenate(label_epoch), np.concatenate(preds_epoch)) * 100
    liveloss.update(logs)
    liveloss.send()
    torch.save(model.state_dict(), MODEL_DIR + "model_" + str(opt) + "_" + str(lr) + "_" + str(strategy) + "_" + str(epoch) + "_final_model.pth")
    # Wall-clock seconds spent on this epoch.
    end = time.time()
    print(format(end - start))

In [None]:
# Polyak averaging: average the parameters of the last N_AVERAGED per-epoch
# checkpoints (epochs max_epochs - N_AVERAGED + 1 .. max_epochs) into one model.
N_AVERAGED = 5


def _checkpoint_path(epoch_idx, suffix=""):
    """Return the checkpoint path used by the training loop for `epoch_idx`."""
    return (MODEL_DIR + "model_" + str(opt) + "_" + str(lr) + "_" + str(strategy)
            + "_" + str(epoch_idx) + "_final_model" + suffix + ".pth")


final_model = monai.networks.nets.SEResNet152(num_classes=2, spatial_dims=3, in_channels=1)
dict_params_final_model = final_model.state_dict()
# Running sums of the floating-point tensors. Non-floating tensors (e.g. a
# BatchNorm num_batches_tracked counter, if present) are taken from the newest
# checkpoint rather than averaged as floats and truncated back to integers.
sums = {}
for a in range(N_AVERAGED):
    epoch_idx = max_epochs - a
    # map_location="cpu": checkpoints saved from a GPU also load on CPU-only
    # machines, and the averaging does not occupy GPU memory.
    params = torch.load(_checkpoint_path(epoch_idx), map_location="cpu")
    # Same strictness as loading the checkpoint into the architecture.
    if set(params) != set(dict_params_final_model):
        raise RuntimeError(f"checkpoint for epoch {epoch_idx} does not match the SEResNet152 state dict")
    for name, tensor in params.items():
        if not torch.is_floating_point(tensor):
            if a == 0:
                dict_params_final_model[name] = tensor
        elif a == 0:
            # Freshly loaded tensor, not referenced elsewhere: safe to accumulate in place.
            sums[name] = tensor
        else:
            sums[name] += tensor
for name, total in sums.items():
    dict_params_final_model[name] = total / N_AVERAGED
final_model.load_state_dict(dict_params_final_model)
# Save the averaged model next to the per-epoch checkpoints.
torch.save(final_model.state_dict(), _checkpoint_path(max_epochs, "_polyak_averaged"))