In [None]:
#imports
import sys
import pandas as pd
import numpy as np
import os
import random
import logging
from livelossplot import PlotLosses
import pickle
import torch
import monai
import time
from monai.data import DataLoader
from monai.transforms import (
    AddChanneld,
    CenterSpatialCropd,
    Compose,
    Resized,
    RandSpatialCropd,
    ScaleIntensityd,
    ToTensord,
    LoadImaged,
    Identityd,
)
from sklearn.linear_model import LogisticRegression
from monai.utils import InterpolateMode
import nibabel as nib
import lime.lime_tabular
from skimage.segmentation import slic
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
torch.backends.cudnn.benchmark = True
from sklearn.model_selection import KFold
from torchmetrics import MeanMetric


In [None]:
#training hyperparameters
#optimizer name: "adam" or "sgd"; the optimizer selection below falls back to RMSprop for any other value
opt = "adam"
#learning rate handed to the optimizer
lr = 0.01
#number of scans per mini-batch used by both data loaders
BATCH_SIZE = 12

In [None]:
#definitions of paths
#root folder for all outputs of this notebook
RESULTS_DIR = "./DenseNet_pretrained/"
#folder for the per-epoch checkpoints; the trailing slash is kept because the
#checkpoint file name is later appended to MODEL_DIR by string concatenation
MODEL_DIR = os.path.join(RESULTS_DIR, "pretraining_model/")
#CSV files listing the training and validation samples
path_train_data = "../../data/LDM_DL_train.csv"
path_valid_data = "../../data/LDM_DL_valid.csv"
#CSV file collecting the validation loss of every epoch
filenameCSV = os.path.join(MODEL_DIR, "LDM_Results_DenseNet.csv")

In [None]:
#create the results directory if it does not exist yet
#exist_ok=True replaces the separate os.path.exists check, which could race with
#another process creating the same directory between the check and the creation
os.makedirs(RESULTS_DIR, exist_ok=True)

In [None]:
#create the model directory if it does not exist yet
#exist_ok=True replaces the separate os.path.exists check, which could race with
#another process creating the same directory between the check and the creation
os.makedirs(MODEL_DIR, exist_ok=True)

In [None]:
#read training and validation dataset, both indexed by the participant_id column
train, valid = (
    pd.read_csv(csv_path, index_col="participant_id")
    for csv_path in (path_train_data, path_valid_data)
)

In [None]:
#load data augmentations
def _build_transforms(crop):
    """Return the preprocessing pipeline shared by training and validation.

    The steps are, in order: load the image, add a channel axis, scale the
    intensities, resize to 256x256x256, apply ``crop`` and convert to a tensor.

    Args:
        crop: dictionary transform that crops the "img" entry to its final size.
    """
    keys = ["img"]
    steps = [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        ScaleIntensityd(keys=keys),
        Resized(keys=keys, spatial_size=(256, 256, 256)),
        crop,
        ToTensord(keys=keys),
    ]
    return Compose(steps)


#training: random 224^3 crop of fixed size as augmentation
train_transforms = _build_transforms(
    RandSpatialCropd(keys=["img"], roi_size=(224, 224, 224), random_size=False)
)
#validation: deterministic 224^3 center crop
valid_transforms = _build_transforms(
    CenterSpatialCropd(keys=["img"], roi_size=(224, 224, 224))
)

In [None]:
#define function to set seeds for reproducibility
def set_seed(seed):
    """Seed all random number generators used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device)
    and switches cuDNN to deterministic mode with the benchmark auto-tuner off.

    Args:
        seed: integer seed passed unchanged to every generator.
    """
    #seed the Python, NumPy and PyTorch generators
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    #deterministic cuDNN kernels, no benchmark-based algorithm selection
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
#live plot of the losses; it receives the per-epoch logs dict inside the training loop
liveloss = PlotLosses()

In [None]:
#running mean used to average the batch losses of one epoch
metric = MeanMetric()
#normalize output value by mean and std of training dataset
#NOTE: the age columns are overwritten, so re-running this cell without reloading
#the CSV files would normalize the already normalized values a second time
sdt_age_train = train["age"].std()
mean_age_train = train["age"].mean()
#assign the whole column instead of writing through .loc[:, "age"]: the in-place
#.loc write is deprecated in pandas >= 2.1 when the column dtype cannot hold the
#new float values (e.g. if age was parsed as integer -- not visible from here)
train["age"] = (train["age"] - mean_age_train) / sdt_age_train
#the validation set is scaled with the training statistics
valid["age"] = (valid["age"] - mean_age_train) / sdt_age_train

In [None]:
#transform training and validation datasets to pytorch format
#each sample is a dict holding the image file name ("img") and the normalized age ("age")
trainDSNew = [{"img": img, "age": age} for img, age in zip(train.filename, train.age)]
validDSNew = [{"img": img, "age": age} for img, age in zip(valid.filename, valid.age)]
set_seed(123)
train_ds = monai.data.Dataset(data=trainDSNew, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
valid_ds = monai.data.Dataset(data=validDSNew, transform=valid_transforms)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=torch.cuda.is_available())
#enable mixed precision to increase batch size
#CUDA AMP only works on a GPU, so it is switched off on CPU-only machines
#instead of relying on the GradScaler disabling itself with a warning
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
#define batchsize factor used for batch accumulation --> the virtual batch size is thus 120
#clamped to at least 1 so that a BATCH_SIZE above 120 cannot yield a factor of 0,
#which would raise a ZeroDivisionError in the training loop
batchsize_factor = max(1, 120 // BATCH_SIZE)
#define maximum number of epochs
max_epochs = 50
#set seed for reproducibility
set_seed(123)
#choose cuda as the device if it is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#load DL model for regression using monai
model = monai.networks.nets.densenet121(spatial_dims=3, in_channels=1, out_channels=1)
#move the model to the device once, before the optimizer is created
model = model.to(device)
#define MSELoss as regression loss function
loss_function = torch.nn.MSELoss()
#select optimizer used for training
#any value of opt other than "adam" or "sgd" falls back to RMSprop
if opt == "adam":
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
elif opt == "sgd":
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
else:
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
#move the model to the selected device once instead of on every batch
model = model.to(device)
#autocast is only used on CUDA, so the loop also runs on CPU-only machines
amp_enabled = bool(use_amp) and device.type == "cuda"
#number of training batches, needed to flush the last (possibly partial) accumulation group
n_train_batches = len(train_loader)
#iterate over epochs
for epoch in range(max_epochs):
    #initalize starting time of epoch
    start = time.time()
    #start with model in train mode
    model.train()
    logs = {}
    #start every epoch without stale gradients
    optimizer.zero_grad(set_to_none=True)
    #iterate over training batches (step counts from 1)
    for step, batch_data in enumerate(train_loader, start=1):
        #load input scans and normalized age of batch; age gets a trailing axis
        #so that it matches the (batch, 1) model output
        inputs = batch_data["img"].to(device)
        age = batch_data["age"].to(device).float()[:, None]
        #use mixed precision
        with torch.cuda.amp.autocast(enabled=amp_enabled):
            #calculate model predictions
            outputs = model(inputs)
            #calculate loss
            loss = loss_function(outputs, age)
        #record the unscaled loss of EVERY batch, not only of the batches on
        #which the optimizer steps
        metric.update(loss.item())
        #divide by the accumulation factor so the summed gradients are an average
        scaler.scale(loss / batchsize_factor).backward()
        #batch accumulation: step after every batchsize_factor batches and on the
        #last batch, so no gradients are carried over into the next epoch
        if step % batchsize_factor == 0 or step == n_train_batches:
            #unscale before clipping so the threshold applies to the true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            #update scaler and optimizer
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
    #compute MSE loss for training dataset
    epoch_loss = metric.compute()
    metric.reset()
    logs['log loss'] = epoch_loss.item()
    #model validation
    #change model to evaluation model
    model.eval()
    with torch.no_grad():
        #iterate over batches in validation loader
        for batch_data in valid_loader:
            #load input scans and normalized age for batch
            inputs = batch_data["img"].to(device)
            age = batch_data["age"].to(device).float()[:, None]
            #use mixed precision
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                #calculate model predictions
                outputs = model(inputs)
                #calculate loss
                loss = loss_function(outputs, age)
            #compute MSE loss for validation dataset
            metric.update(loss.item())
        epoch_loss_val = metric.compute()
        logs['val_log loss'] = epoch_loss_val.item()
        metric.reset()
    liveloss.update(logs)
    liveloss.send()
    #save model (one checkpoint per epoch)
    torch.save(model.state_dict(), MODEL_DIR+"model_"+str(opt)+"_"+str(lr)+"_"+str(epoch)+".pth")
    #save model performance
    #store the loss as a plain float; a tensor would be written as "tensor(...)"
    d = {'optimizer': [opt], 'LR': [lr], 'Epoch': [epoch], "Epoch-Loss": [epoch_loss_val.item()]}
    df = pd.DataFrame(data=d)
    #append to an existing results file, otherwise create it with a header row
    if os.path.isfile(filenameCSV):
        df.to_csv(filenameCSV, mode='a', header=False)
    else:
        df.to_csv(filenameCSV, mode='w', header=True)
    end = time.time()
    #report the wall-clock duration of the epoch
    logging.info("epoch %d finished in %.1f s", epoch, end - start)