### Imports

In [None]:
import monai
from monai.data import ITKReader
from monai.data import DataLoader
from monai.data import decollate_batch
from monai.transforms import LoadImage, LoadImaged, Compose, ScaleIntensityd, RandFlipd, RandZoomd, Resized, EnsureType, EnsureTyped, Activations, AsDiscrete, Decollated, adaptor, RandRotated, ScaleIntensity, Resize, ConcatItemsd, ToTensord, SpatialCropd, CenterSpatialCropd, Rotated, EnsureChannelFirstd, MapTransform
from monai.metrics import ROCAUCMetric
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from monai.handlers import from_engine, ValidationHandler, StatsHandler, TensorBoardStatsHandler, CheckpointSaver, TensorBoardImageHandler, ClassificationSaver, CheckpointLoader
from monai.apps import get_logger
from monai.utils import ImageMetaKey as Key
from sklearn.preprocessing import MinMaxScaler

import matplotlib.pyplot as plt

from glob import glob

import pandas as pd

import numpy as np

import torch
from torch.utils.tensorboard import SummaryWriter
import nibabel

import ignite
from ignite.metrics import Accuracy

import logging

import sys

### Data loading

In [None]:
# Pseudo-IDs (39 entries) that get special treatment when the image paths are
# built: their scans are read from the urinary-bladder crop directory.
# NOTE(review): presumably the prostate crop is empty for these IDs - confirm.
empty = [
    1017, 10251, 13362, 14642, 15967, 18516, 24283, 25964, 29866, 31592,
    32120, 32248, 43899, 44323, 46034, 48151, 50096, 50156, 54354, 55034,
    56388, 56890, 57041, 58325, 59224, 62591, 65364, 66028, 67565, 70158,
    70744, 71067, 75515, 83014, 83303, 87267, 90310, 90614, 95548,
]

In [None]:
# Load the tab-separated label table.
df = pd.read_csv("/data/f18-psma-pet-ct-ml/data/labels.tsv", sep="\t")


def _nifti_path(pseudo_id, suffix):
    """Return the cropped NIfTI path for ``pseudo_id`` ending in ``suffix``.

    IDs listed in ``empty`` resolve to the urinary-bladder crop directory,
    every other ID to the prostate crop directory. The ID is zero-padded to
    five characters.
    """
    if pseudo_id in empty:
        folder = "/data/f18-psma-pet-ct-ml/cropped_nifti_urinary_bladder/"
    else:
        folder = "/data/f18-psma-pet-ct-ml/cropped_nifti_prostate/"
    return folder + str(pseudo_id).zfill(5) + suffix


# One path column per modality, derived from the pseudo ID.
df = df.assign(pet=df["pseudo_id"].map(lambda pid: _nifti_path(pid, "_pet.nii.gz")))
df = df.assign(ct=df["pseudo_id"].map(lambda pid: _nifti_path(pid, "_ct.nii.gz")))
df.head()

In [None]:
# Min-max scale the "psa" column and store the result as "psa_norm".
# NOTE(review): the scaler is fitted on every row currently in df, i.e. on
# all values of the "set" column and before any rows are filtered out -
# confirm this is intended.
scaler = MinMaxScaler()
scaler.fit(df[["psa"]])
psa_normalized = scaler.transform(df[["psa"]])
df["psa_norm"] = psa_normalized

### Sort out some IDs

In [None]:
# Pseudo-IDs that are removed from the table.
problematic = [
    13019, 53135, 94420, 32841, 80544, 84704, 26023, 80297, 85350, 80857,
    55044, 18663, 20684, 87138, 97067, 76290, 96548, 40776, 21150, 37960,
    54052, 30443, 64579, 93143, 27689, 73064, 9404, 31111, 4433, 21589,
    42404, 29825, 52939, 45756, 8099, 93472, 72491, 59397, 75553, 24480,
    67496, 67384, 86676, 3543, 19369, 14932, 97053, 40931, 55904, 47830,
    96595, 88341, 14382, 39, 14579, 20481, 58596, 90461, 90747,
]

# Keep a row only if its ID is not listed above and its label is not 3.
_keep_row = ~df["pseudo_id"].isin(problematic) & (df["label"] != 3)
df = df[_keep_row]

### Label correction

In [None]:
# Preview the first rows of the filtered table (notebook display only).
df.head()

In [None]:
# Rows labelled 0 whose alternative label is 1 are relabelled as 1.
_relabel = df["label"].eq(0) & df["alt_label"].eq(1)
df.loc[_relabel, "label"] = 1

In [None]:
# Discard every row that has a missing value in any column.
df = df.dropna(axis=0, how="any")

### Create sets

In [None]:
def _records_for_split(split_name):
    """Return the rows of ``df`` whose "set" column equals ``split_name``.

    The rows are returned as a list with one dict per row.
    """
    return df[df["set"] == split_name].to_dict("records")


# One list of per-row dicts for the whole table and for each split.
complete_data = df.to_dict("records")
train_data = _records_for_split("train")
val_data = _records_for_split("val")
# Debug alternative: restrict each split to a single row.
#train_data = df[df["set"] == "train"].iloc[0:1].to_dict('records')
#val_data = df[df["set"] == "val"].iloc[1:2].to_dict('records')
print(f"Complete: {len(complete_data)}\nTraining: {len(train_data)}\nValidation: {len(val_data)}")

### Defining the transforms

In [None]:
class Repeatd(MapTransform):
    """Dictionary transform that tiles selected entries into tensors.

    Every entry of the input whose key is listed in ``keys`` is wrapped into
    a one-element ``torch.Tensor`` and repeated according to ``target_size``.
    Keys that are absent from the input are ignored.
    """

    def __init__(self, keys, target_size) -> None:
        """Remember which keys to expand and the repeat counts to use.

        Args:
            keys: keys of the entries to expand.
            target_size: repeat counts, unpacked into ``Tensor.repeat``.
        """
        super().__init__(keys, allow_missing_keys=True)
        self.target_size = target_size

    def __call__(self, data):
        """Return a new dict in which the selected entries are expanded."""
        repeats = self.target_size
        return {
            name: torch.Tensor([value]).repeat(*repeats) if name in self.keys else value
            for name, value in dict(data).items()
        }

In [None]:
# Key groups used by the training pipeline.
_train_image_keys = ["ct", "pet"]
_train_scalar_keys = ["psa_norm", "px"]
_train_all_keys = _train_image_keys + _train_scalar_keys

# Training pipeline: load and rescale both volumes, expand the scalar inputs
# to volume-sized channels, randomly augment the images, then stack all four
# entries along dim 0 into "petct".
train_transforms = Compose(
    [
        LoadImaged(keys=_train_image_keys),
        EnsureChannelFirstd(keys=_train_image_keys),
        ScaleIntensityd(keys=_train_image_keys),
        Resized(keys=_train_image_keys, spatial_size=(70, 70, 70)),
        Repeatd(keys=_train_scalar_keys, target_size=(1, 70, 70, 70)),
        #CenterSpatialCropd(keys=["ct", "pet"], roi_size = (30, 40, 30)),
        RandZoomd(keys=_train_image_keys, prob=0.7, min_zoom=0.5, max_zoom=1.5),
        # NOTE(review): with a per-key mode list "pet" is presumably resampled
        # with 'nearest' - confirm that is wanted for an intensity image.
        RandRotated(
            keys=_train_image_keys,
            prob=0.8,
            range_x=[-0.2, 0.2],
            range_y=[-0.1, 0.1],
            mode=['bilinear', 'nearest'],
        ),
        EnsureTyped(keys=_train_all_keys),
        ConcatItemsd(keys=_train_all_keys, name="petct", dim=0),
        ToTensord(keys=["petct", "ct", "pet"]),
    ]
)

In [None]:
# Key groups used by the validation pipeline.
_val_image_keys = ["ct", "pet"]
_val_scalar_keys = ["psa_norm", "px"]
_val_all_keys = _val_image_keys + _val_scalar_keys

# Validation pipeline: load and rescale both volumes, expand the scalar
# inputs to volume-sized channels and stack all four entries along dim 0
# into "petct". No random augmentation is applied here.
val_transforms = Compose(
    [
        LoadImaged(keys=_val_image_keys),
        EnsureChannelFirstd(keys=_val_image_keys),
        ScaleIntensityd(keys=_val_image_keys),
        Resized(keys=_val_image_keys, spatial_size=(70, 70, 70)),
        Repeatd(keys=_val_scalar_keys, target_size=(1, 70, 70, 70)),
        #CenterSpatialCropd(keys=["ct", "pet"], roi_size = (30, 40, 30)),
        EnsureTyped(keys=_val_all_keys),
        ConcatItemsd(keys=_val_all_keys, name="petct", dim=0),
        ToTensord(keys=["petct", "ct", "pet"]),
    ]
)

### Create data loaders

In [None]:
# Number of samples per batch, used by every data loader.
batchsize: int = 16

In [None]:
# Settings shared by all three loaders; memory is pinned only when CUDA is
# available.
_loader_options = {
    "batch_size": batchsize,
    "num_workers": 1,
    "pin_memory": torch.cuda.is_available(),
}

# Whole table, run through the validation transforms.
complete_ds = monai.data.Dataset(data=complete_data, transform=val_transforms)
complete_loader = DataLoader(complete_ds, **_loader_options)

# Training split: training transforms, reshuffled by the loader.
train_ds = monai.data.Dataset(data=train_data, transform=train_transforms)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_options)

# Validation split.
val_ds = monai.data.Dataset(data=val_data, transform=val_transforms)
val_loader = DataLoader(val_ds, **_loader_options)

### Create model

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# 3-D DenseNet121 taking 4 input channels and producing 2 outputs.
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=4, out_channels=2)
model = model.to(device)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
auc_metric = ROCAUCMetric()

### Use SupervisedTrainer

In [None]:
# Route INFO-level (and higher) log records to stdout, then obtain the
# "train_log" logger that the stats handlers refer to by name.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
get_logger("train_log")

In [None]:
def prepare_batch(batch, device, non_blocking=False, **kwargs):
    """Extract the network input and target from a collated batch.

    Args:
        batch: mapping holding the stacked input under "petct" and the
            targets under "label".
        device: device both tensors are moved to.
        non_blocking: forwarded to ``Tensor.to`` so that an asynchronous
            copy requested by the caller is honoured.
        **kwargs: any further keyword arguments are forwarded to
            ``Tensor.to`` for both tensors.

    Returns:
        Tuple ``(inputs, labels)`` located on ``device``.
    """
    inputs = batch["petct"].to(device, non_blocking=non_blocking, **kwargs)
    labels = batch["label"].to(device, non_blocking=non_blocking, **kwargs)
    return inputs, labels

In [None]:
def get_pids(batch):
    """Map the batch's "pseudo_id" values to the filename metadata key.

    Returns:
        A dict with the single key ``Key.FILENAME_OR_OBJ`` whose value is
        whatever ``from_engine(["pseudo_id"])`` extracts from ``batch``.
    """
    extract_ids = from_engine(["pseudo_id"])
    return {Key.FILENAME_OR_OBJ: extract_ids(batch)}


def output_for_csv(output):
    """Build one flat tensor per sample for the prediction file.

    For every (prediction, label) pair taken from ``output`` the result
    holds the prediction values, followed by the index of the largest
    prediction value, followed by the label.

    Args:
        output: engine output from which "pred" and "label" are extracted
            with ``from_engine``.

    Returns:
        List with one 1-D tensor per sample.
    """
    preds, labels = from_engine(["pred", "label"])(output)
    rows = []
    for pred, label in zip(preds, labels):
        predicted_class = pred.argmax().unsqueeze(0)
        # Place the label on the prediction's own device rather than relying
        # on a module-level device variable.
        label_value = torch.Tensor([label]).to(pred.device)
        rows.append(torch.concat([pred, predicted_class, label_value]))
    return rows

##### Create handlers + Trainer and Evaluator

In [None]:
# Directory receiving the validation TensorBoard logs and checkpoints.
_val_run_dir = "/data/f18-psma-pet-ct-ml/runs_prostate_marko_model7c"


def _discard_output(x):
    """Output transform that ignores the engine output and returns None."""
    return None


# Handlers attached to the evaluator: console stats, TensorBoard stats,
# a checkpoint saver keyed on the validation metric, and a tab-separated
# prediction file.
val_handlers = [
    StatsHandler(name="train_log", output_transform=_discard_output),
    TensorBoardStatsHandler(log_dir=_val_run_dir, output_transform=_discard_output),
    CheckpointSaver(save_dir=_val_run_dir, save_dict={"net": model}, save_key_metric=True),
    ClassificationSaver(
        output_dir="/data/f18-psma-pet-ct-ml/code/Code_Marko/Master/Files",
        filename="predictions_model7c.csv",
        delimiter="\t",
        overwrite=True,
        output_transform=output_for_csv,
        batch_transform=get_pids,
    ),
]

In [None]:
# Accuracy computed from the engine's "pred" and "label" outputs; it serves
# as the key validation metric.
_val_accuracy = Accuracy(output_transform=from_engine(["pred", "label"]))

# Evaluator running the model over the validation loader. AMP is switched on
# when the installed torch version tuple is at least (1, 6).
evaluator = SupervisedEvaluator(
    device=device,
    val_data_loader=val_loader,
    network=model,
    prepare_batch=prepare_batch,
    key_val_metric={"val_acc": _val_accuracy},
    val_handlers=val_handlers,
    amp=monai.utils.get_torch_version_tuple() >= (1, 6),
)

In [None]:
# Directory receiving the training TensorBoard logs and checkpoints.
_train_run_dir = "/data/f18-psma-pet-ct-ml/runs_prostate_marko_model7c"

# Extracts the "loss" entry of the engine output for the stats handlers.
_first_loss = from_engine(["loss"], first=True)

# Handlers attached to the trainer: validation after every epoch, loss
# logging to console and TensorBoard, and a checkpoint of network and
# optimizer after every epoch.
train_handlers = [
    ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
    StatsHandler(name="train_log", tag_name="train_loss", output_transform=_first_loss),
    TensorBoardStatsHandler(log_dir=_train_run_dir, tag_name="train_loss", output_transform=_first_loss),
    CheckpointSaver(
        save_dir=_train_run_dir,
        save_dict={"net": model, "opt": optimizer},
        save_interval=1,
        epoch_level=True,
    ),
]

In [None]:
# Accuracy computed from the engine's "pred" and "label" outputs; it serves
# as the key training metric.
_train_accuracy = Accuracy(output_transform=from_engine(["pred", "label"]))

# Trainer: 15 epochs over the training loader, with AMP switched off.
trainer = SupervisedTrainer(
    device=device,
    max_epochs=15,
    train_data_loader=train_loader,
    network=model,
    optimizer=optimizer,
    loss_function=loss_function,
    prepare_batch=prepare_batch,
    key_train_metric={"train_acc": _train_accuracy},
    train_handlers=train_handlers,
    amp=False,
)

### Training and evaluation

In [None]:
# Start training; the attached ValidationHandler triggers the evaluator
# after every epoch.
trainer.run()

In [None]:
#evaluator.get_validation_stats()

In [None]:
#best_epoch = evaluator.state.best_metric_epoch
#print(best_epoch)

### Prediction

In [None]:
#handler = CheckpointLoader(f"/data/runs_prostate_marko_model2/checkpoint_epoch={best_epoch}.pt", load_dict={"net": model, "opt": optimizer})
#handler(trainer)

In [None]:
#df = pd.read_csv("/data/f18-psma-pet-ct-ml/code/Code_Marko/...",sep="\t")
#df.head()

In [None]:
#model.eval()
#for batch in iter(complete_loader):
#    IDs = batch["pseudo_id"]
#    Preds = model(batch["petct"].to(device)).argmax(dim=1)
#    for ID, Pred in zip(IDs, Preds):
#        df.loc[df.pseudo_id == ID.item(), 'model 2'] = Pred.item()
#        print(ID, Pred)
#model.train()

In [None]:
#df.to_csv(path_or_buf="/data/f18-psma-pet-ct-ml/code/Code_Marko/...", sep="\t", index=False)

### Tensorboard

In [None]:
#%load_ext tensorboard
#%tensorboard --logdir=/data/runs_prostate_marko_model1 --port=12345