### Imports

In [None]:
import monai
from monai.data import ITKReader
from monai.data import DataLoader
from monai.data import decollate_batch
from monai.transforms import LoadImage, LoadImaged, Compose, RandFlipd, RandZoomd, ScaleIntensityd, Resized, EnsureType, EnsureTyped, Activations, AsDiscrete, Decollated, adaptor, RandRotated, ScaleIntensity, Resize, ConcatItemsd, ToTensord, SpatialCropd, CenterSpatialCropd, Rotated, EnsureChannelFirstd, MapTransform
from monai.metrics import ROCAUCMetric
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from monai.handlers import from_engine, ValidationHandler, StatsHandler, TensorBoardStatsHandler, CheckpointSaver, TensorBoardImageHandler, ClassificationSaver, CheckpointLoader
from monai.apps import get_logger
from monai.utils import ImageMetaKey as Key
from sklearn.preprocessing import MinMaxScaler

import matplotlib.pyplot as plt

from glob import glob

import pandas as pd

import numpy as np


import torch
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import nibabel

import ignite
from ignite.metrics import Accuracy
from ignite.engine import create_supervised_evaluator
from ignite.engine import create_supervised_trainer
from ignite.engine import Events

import logging

import sys

import optuna
from optuna.trial import TrialState

import mlflow
from optuna.integration.mlflow import MLflowCallback

import os

### Data loading

In [None]:
# Pseudo-IDs whose PET/CT volumes are read from the urinary-bladder crops
# instead of the prostate crops (see the path mapping in the next cell);
# 39 entries. Presumably these patients have an empty prostate crop — TODO confirm.
empty = [
    1017, 10251, 13362, 14642, 15967, 18516, 24283, 25964, 29866, 31592,
    32120, 32248, 43899, 44323, 46034, 48151, 50096, 50156, 54354, 55034,
    56388, 56890, 57041, 58325, 59224, 62591, 65364, 66028, 67565, 70158,
    70744, 71067, 75515, 83014, 83303, 87267, 90310, 90614, 95548,
]

In [None]:
df = pd.read_csv("/data/f18-psma-pet-ct-ml/data/labels.tsv", sep="\t")

_BLADDER_CROP_DIR = "/data/f18-psma-pet-ct-ml/cropped_nifti_urinary_bladder/"
_PROSTATE_CROP_DIR = "/data/f18-psma-pet-ct-ml/cropped_nifti_prostate/"


def _crop_path(pseudo_id, modality):
    """Build the cropped NIfTI path of one patient for ``modality`` ("pet" or "ct").

    IDs listed in ``empty`` point to the urinary-bladder crops, every other
    ID to the prostate crops; the ID is zero-padded to five digits.
    """
    crop_dir = _BLADDER_CROP_DIR if pseudo_id in empty else _PROSTATE_CROP_DIR
    return crop_dir + str(pseudo_id).zfill(5) + "_" + modality + ".nii.gz"


# Attach the PET and CT file paths (in that column order) to every row.
df = df.assign(
    pet=df["pseudo_id"].map(lambda pseudo_id: _crop_path(pseudo_id, "pet")),
    ct=df["pseudo_id"].map(lambda pseudo_id: _crop_path(pseudo_id, "ct")),
)
df.head()

In [None]:
# Min-max normalise the PSA value into a new "psa_norm" column.
# The scaler is fitted on the training split only, so validation PSA values
# do not leak into the normalisation statistics. Validation values can
# therefore fall slightly outside [0, 1]. MinMaxScaler ignores NaNs while
# fitting and keeps them as NaN in the output.
scaler = MinMaxScaler()
scaler.fit(df.loc[df["set"] == "train", ["psa"]])
df["psa_norm"] = scaler.transform(df[["psa"]]).ravel()

### Exclude problematic IDs

In [None]:
# Pseudo-IDs excluded from the experiment (59 entries). The reason for each
# exclusion is not recorded here — TODO document why these IDs are problematic.
problematic = [
    13019, 53135, 94420, 32841, 80544, 84704, 26023, 80297, 85350, 80857,
    55044, 18663, 20684, 87138, 97067, 76290, 96548, 40776, 21150, 37960,
    54052, 30443, 64579, 93143, 27689, 73064, 9404, 31111, 4433, 21589,
    42404, 29825, 52939, 45756, 8099, 93472, 72491, 59397, 75553, 24480,
    67496, 67384, 86676, 3543, 19369, 14932, 97053, 40931, 55904, 47830,
    96595, 88341, 14382, 39, 14579, 20481, 58596, 90461, 90747,
]

# Drop the excluded IDs and every row with label 3.
# NOTE(review): the meaning of label 3 is not stated here — TODO document.
# ``.copy()`` makes ``df`` an independent frame rather than a filtered slice.
# Later in-place edits such as ``df.loc[..., 'label'] = 1`` therefore do not
# trigger pandas' SettingWithCopyWarning.
df = df[~df["pseudo_id"].isin(problematic) & (df["label"] != 3)].copy()

### Label correction

In [None]:
# Preview the filtered frame before the label correction.
df.head()

In [None]:
# Rows whose label is 0 but whose alt_label is 1 are relabelled as 1;
# every other row keeps its label.
df.loc[df["label"].eq(0) & df["alt_label"].eq(1), "label"] = 1

In [None]:
# Drop every row that contains a NaN in *any* column, not only label/psa.
# NOTE(review): this also removes rows where only auxiliary columns such as
# alt_label are missing — confirm that is intended.
df = df.dropna()

### Create sets

In [None]:
# Turn the cleaned frame into lists of per-patient dicts. This is the record
# format consumed by the MONAI dictionary-based datasets/transforms.
complete_data = df.to_dict("records")
train_data, val_data = (
    df.loc[df["set"] == split].to_dict("records") for split in ("train", "val")
)
print(f"Complete: {len(complete_data)}\nTraining: {len(train_data)}\nValidation: {len(val_data)}")

### Define the transforms

In [None]:
class Repeatd(MapTransform):
    """Broadcast scalar entries of a data dict into constant tensors.

    For every key in ``keys`` that is present in the dict, the value is
    wrapped in a one-element float tensor and repeated with ``target_size``.
    For example, ``(1, 70, 70, 70)`` turns a scalar into a constant
    1-channel 70x70x70 volume, so tabular features can be concatenated with
    image channels.

    Args:
        keys: keys of the entries to broadcast.
        target_size: repeat counts passed to ``torch.Tensor.repeat``. The
            source tensor has shape ``(1,)``, so the result has shape
            ``target_size``.
        allow_missing_keys: skip keys absent from the dict instead of raising
            ``KeyError``. Defaults to ``True``, matching the previous
            hard-coded behaviour.
    """

    def __init__(
        self,
        keys,
        target_size,
        allow_missing_keys: bool = True,
    ) -> None:
        super().__init__(keys, allow_missing_keys=allow_missing_keys)
        self.target_size = target_size

    def __call__(self, data):
        d = dict(data)
        # key_iterator yields the requested keys present in ``d`` and honours
        # ``allow_missing_keys`` for absent ones.
        for key in self.key_iterator(d):
            d[key] = torch.Tensor([d[key]]).repeat(*self.target_size)
        return d

In [None]:
# Shared deterministic preprocessing applied to both modalities.
# The per-trial train/validation pipelines wrap this Compose and append the
# tabular channels and, for training, random augmentations.
transforms = Compose(
    [
        LoadImaged(keys=["ct","pet"]),  # load the volumes from the file paths stored under "ct"/"pet"
        EnsureChannelFirstd(keys=["ct","pet"]),  # move the channel axis to the front
        ScaleIntensityd(keys=["ct","pet"]),  # rescale each volume's intensities on its own min/max
        Resized(keys=["ct","pet"], spatial_size=(70, 70, 70)),  # common 70x70x70 grid
        #Repeatd(keys=["psa_norm", "px"], target_size=(1, 70, 70, 70)),
        EnsureTyped(keys=["ct","pet"]),  
        # 2-channel CT+PET stack. The per-trial pipelines later overwrite
        # "petct" with a 4-channel concatenation (CT, PET, psa_norm, px).
        ConcatItemsd(keys=["ct", "pet"], name="petct", dim=0),  
                                              
        ToTensord(keys=["petct", "ct", "pet"]),
    ]
)

In [None]:
# Post-processing chains: predictions get a softmax activation, and labels
# are one-hot encoded into 2 classes.
# NOTE(review): neither name is referenced in the visible code. They are
# possibly leftovers from an ROC-AUC evaluation (``ROCAUCMetric`` is
# imported); remove them if unused.
post_pred = Compose([EnsureType(), Activations(softmax=True)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

### Optuna

In [None]:
def prepare_batch(batch, device, non_blocking):
    """Split a collated batch dict into ``(inputs, targets)`` for ignite.

    Args:
        batch: collated dict holding the stacked ``"petct"`` input tensor and
            the ``"label"`` target tensor.
        device: device both tensors are moved to.
        non_blocking: forwarded to ``Tensor.to``, which allows asynchronous
            host-to-device copies when the loader uses pinned memory.

    Returns:
        Tuple ``(petct, label)`` on ``device``.
    """
    return (
        batch["petct"].to(device, non_blocking=non_blocking),
        batch["label"].to(device, non_blocking=non_blocking),
    )

In [None]:
# Channel widths for the "block_inplanes" parameter of the ResNet networks.
def get_inplanes():
    """Return a fresh list of per-stage ResNet widths: 64 doubled over four stages."""
    return [64 << stage for stage in range(4)]

In [None]:
# Define the four candidate networks.
# Each takes the 4-channel input assembled per trial (CT, PET, broadcast
# psa_norm, broadcast px) and emits 2 logits for the binary label.
# MONAI's ResNet defaults to num_classes=400, so the ResNets set it
# explicitly to match the DenseNets' out_channels=2.
DenseNet121 = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=4, out_channels=2)
DenseNet201 = monai.networks.nets.DenseNet201(spatial_dims=3, in_channels=4, out_channels=2)
ResNet34 = monai.networks.nets.ResNet(block="basic", layers=[3, 4, 6, 3], block_inplanes=get_inplanes(), spatial_dims=3, n_input_channels=4, num_classes=2)
ResNet50 = monai.networks.nets.ResNet(block="bottleneck", layers=[3, 4, 6, 3], block_inplanes=get_inplanes(), spatial_dims=3, n_input_channels=4, num_classes=2)

In [None]:
# Registry mapping the Optuna "model" choice to a network instance.
# The instances are built once, at module level.
models = dict(
    DenseNet121=DenseNet121,
    DenseNet201=DenseNet201,
    ResNet34=ResNet34,
    ResNet50=ResNet50,
)

In [None]:
# Optuna's objective function: defines the search space and trains one model per trial.

def _build_pipelines(prob_aug, min_zoom, max_zoom):
    """Return ``(train_transforms, val_transforms)`` for one trial.

    Both pipelines extend the shared ``transforms`` with the tabular channels.
    ``psa_norm`` and ``px`` are broadcast to ``(1, 70, 70, 70)``, then CT,
    PET, psa_norm and px are concatenated into the 4-channel ``"petct"``
    input. Only the training pipeline adds random rotation and zoom.
    """
    train_transforms = Compose([transforms,
        Repeatd(keys=["psa_norm", "px"], target_size=(1, 70, 70, 70)),
        # NOTE(review): "nearest" interpolation is applied to PET here. That
        # mode is usual for label maps rather than intensity images — confirm.
        RandRotated(keys=["ct","pet"], prob=0.8, range_x=[-0.2,0.2], range_y=[-0.1,0.1], mode=['bilinear', 'nearest']),
        RandZoomd(keys=["ct", "pet"], prob=prob_aug, min_zoom=min_zoom, max_zoom=max_zoom),
        EnsureTyped(keys=["ct","pet", "psa_norm", "px"]),
        ConcatItemsd(keys=["ct", "pet", "psa_norm", "px"], name="petct", dim=0),
    ])
    val_transforms = Compose([transforms,
        Repeatd(keys=["psa_norm", "px"], target_size=(1, 70, 70, 70)),
        EnsureTyped(keys=["ct","pet", "psa_norm", "px"]),
        ConcatItemsd(keys=["ct", "pet", "psa_norm", "px"], name="petct", dim=0),
    ])
    return train_transforms, val_transforms


def objective(trial):
    """Train one model for ``trial`` and return its final validation accuracy.

    Steps:
      * Sample the architecture, augmentation parameters and learning rate.
      * Train for ``num_epochs`` epochs with ignite.
      * After every epoch, report the validation accuracy so the study's
        pruner can stop the trial early (by raising ``optuna.TrialPruned``).
      * Return the validation accuracy after the last epoch, which is the
        value the study maximises.
    """
    import copy

    # Only DenseNet121 is searched at the moment; add further keys of
    # ``models`` to this list to widen the architecture search.
    model_name = trial.suggest_categorical("model", ["DenseNet121"])
    # Train a deep copy so weights learned in earlier trials never carry over.
    # Every trial of one architecture starts from the same initial weights.
    model = copy.deepcopy(models[model_name])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # ignite's trainer does not move the model, so it must be on ``device``
    # before the optimizer is created.
    model.to(device)

    num_epochs = 15
    batch_size = 16

    # Fixed at 1 and still recorded in the study, although CenterSpatialCropd
    # is currently disabled.
    trial.suggest_float("prob_CSC", 1, 1, step=1)
    prob_aug = trial.suggest_float("prob_aug", 0, 1, step=0.1)  # probability of RandZoomd
    minzoom = trial.suggest_float("minzoom", 0.5, 1.5, step=0.1)
    maxzoom = trial.suggest_float("maxzoom", 0.5, 1.5, step=0.1)
    # The zoom bounds are sampled independently, so order them before use
    # so that min_zoom <= max_zoom always holds.
    train_transforms, val_transforms = _build_pipelines(
        prob_aug, min(minzoom, maxzoom), max(minzoom, maxzoom)
    )

    pin_memory = torch.cuda.is_available()
    train_ds = monai.data.Dataset(data=train_data, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=pin_memory)
    val_ds = monai.data.Dataset(data=val_data, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=1, pin_memory=pin_memory)

    # Adam with a log-uniform learning rate (Adam performed better than SGD here).
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    trainer = create_supervised_trainer(model, optimizer, nn.CrossEntropyLoss(), device=device, prepare_batch=prepare_batch)
    evaluator = create_supervised_evaluator(model, metrics={"accuracy": Accuracy()}, device=device, prepare_batch=prepare_batch)
    train_evaluator = create_supervised_evaluator(model, metrics={"accuracy": Accuracy()}, device=device, prepare_batch=prepare_batch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(engine):
        """Evaluate after each epoch, print accuracies, and let Optuna prune."""
        evaluator.run(val_loader)
        validation_acc = evaluator.state.metrics["accuracy"]
        print("Epoch: {} Validation accuracy: {:.4f}".format(engine.state.epoch, validation_acc))

        # NOTE: train_loader applies the random augmentations, so this is the
        # accuracy on augmented training samples.
        train_evaluator.run(train_loader)
        training_acc = train_evaluator.state.metrics["accuracy"]
        print("Epoch: {} Training accuracy: {:.4f}".format(engine.state.epoch, training_acc))

        # Pruning decision is based on the validation accuracy.
        trial.report(validation_acc, engine.state.epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    trainer.run(train_loader, max_epochs=num_epochs)

    # The last EPOCH_COMPLETED handler already evaluated the final weights on
    # the validation set, so reuse that result instead of a second pass.
    return evaluator.state.metrics["accuracy"]
     

In [None]:
# Integration of MLflow - stores all experiment data in the "mlruns_trial_models" folder.
# ``metric_name`` is the name under which each trial's objective value is
# logged. Here that value is the validation accuracy returned by ``objective``.
mlflc = MLflowCallback(
    tracking_uri="mlruns_trial_models",
    metric_name="accuracy"
)

In [None]:
# Storage URL for the Optuna study, also read by the Optuna dashboard.
# The four slashes after "sqlite:" denote an absolute path to the SQLite file.
storage = "sqlite:////data/f18-psma-pet-ct-ml/code/Code_Marko/Master/Code/Optuna/Optuna_SQLite1_7c"

In [None]:
if __name__ == "__main__":
    # Mirror Optuna's log messages to stdout so they appear in the cell output.
    optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
    # Creates a new study. Without load_if_exists this raises if a study
    # named "trial_models_marko" already exists in ``storage``.
    study = optuna.create_study(direction="maximize", storage=storage, study_name="trial_models_marko", pruner=optuna.pruners.MedianPruner())
    # Record the loss function before optimizing, so the attribute is stored
    # even if optimize() is interrupted or a trial raises.
    study.set_user_attr("Loss_function", "CrossEntropyLoss")
    study.optimize(objective, n_trials=25, callbacks=[mlflc])

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))
    print(f"Sampler: {study.sampler.__class__.__name__}")
    print(f"Pruner: {study.pruner.__class__.__name__}")

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")