In [None]:
def modify_labels(labels, organs):
    """
    Change labels so that the targeted sections are numbered from 1.0.

    Two label maps are derived from the same input: one for the main decoder
    and one for the auxiliary decoder. In each map the targeted organs are
    numbered 1.0, 2.0, ... in the order they appear in organs['all']; every
    other value becomes 0.0.

    Args:
        labels: Tensor holding the original label values. It is not modified.
        organs: Mapping with the keys
            'all'  - iterable of organ names, defines the numbering order,
            'main' - container of organ names targeted by the main decoder,
            'aux'  - container of organ names targeted by the aux decoder,
            'dict' - organ name -> label value used in `labels`.

    Returns:
        Tuple (main_labels, aux_labels), both with the shape and dtype of `labels`.

    Raises:
        KeyError: If an organ listed in organs['all'] is missing from organs['dict'].
    """

    all_org = organs['all']
    main = organs['main']
    aux  = organs['aux']
    label_of = organs['dict']  # renamed locally so the builtin `dict` is not shadowed

    # Start from zeroed copies: anything that is not explicitly targeted stays 0.0.
    # clone() is kept so the outputs have the same type and metadata as before.
    main_labels = labels.clone().zero_()
    aux_labels = labels.clone().zero_()

    count_main = 1.0
    count_aux  = 1.0

    for organ in all_org:
        # The mask is always taken from the untouched input, so an index that was
        # just assigned can never be confused with an original label value.
        organ_mask = labels == label_of[organ]

        if organ in main:
            main_labels[organ_mask] = count_main
            count_main += 1.0

        if organ in aux:
            aux_labels[organ_mask] = count_aux
            count_aux += 1.0

    return main_labels, aux_labels


In [None]:
from pathlib import Path
import random
from shutil import copyfile
import os

def _pair_images_with_masks(image_dir, mask_dir):
    """Return [{"image": path, "mask": path}, ...] pairing the sorted contents of both directories."""
    images = sorted(image_dir.glob("*"))
    masks  = sorted(mask_dir.glob("*"))
    return [{"image": image_name, "mask": mask_name} for image_name, mask_name in zip(images, masks)]


def split_data(img_path, scale=1):
    """
    Read all images and divide into training, validation and test sets.
    Scale to test models on fewer data.

    Images (*_img.nii) and masks (*_mask.nii) are expected in the same folder.
    The split is 80% / 10% / remainder of the first int(num_images / scale)
    shuffled images. Files are copied into <img_path>/{train,val,test}_{images,masks}.

    Side effects:
        * '001001_img.nii' and '005057_img.nii' are deleted from img_path if present.
        * The global `random` generator is seeded with 2022.
        * If ANY of the six split directories already exists, nothing is copied and
          the existing directory contents are returned as-is. Remove the directories
          to force a fresh split.

    Args:
        img_path: Folder with the images and masks (str or Path).
        scale: Divisor applied to the dataset size; scale=2 uses half of the data.

    Returns:
        Tuple (train_files, val_files, test_files); each is a list of
        {"image": Path, "mask": Path} dictionaries.
    """

    img_path = Path(img_path)

    print("-" * 40)
    print("Splitting data into train-validate-test sets...")

    # Delete two images that do not have segmentation masks
    for file in ['001001_img.nii', '005057_img.nii']:
        orphan = img_path / file
        if orphan.exists():
            orphan.unlink()
        else:
            print("The file does not exist")

    # Read all files ending with _img.nii (image and mask are in the same folder).
    # glob() order depends on the filesystem, so sort first; otherwise the seeded
    # shuffle below would not give the same split on every machine.
    img_files   = sorted(img_path.glob("*_img.nii"))
    num_images  = len(img_files)

    # Create train, validation and test splits on the scaled-down subset
    subset_size = int(num_images / scale)
    train_split = int(0.8 * num_images / scale)
    val_split   = int(0.1 * num_images / scale)

    # Set the random seed for reproducibility
    random.seed(2022)

    # Shuffle the image files
    random.shuffle(img_files)

    # Split the dataset
    train_images    = img_files[:train_split]
    val_images      = img_files[train_split:(train_split + val_split)]
    test_images     = img_files[(train_split + val_split):subset_size]

    # Image / mask directory pair for every split
    split_dirs = [
        (img_path / "train_images", img_path / "train_masks"),
        (img_path / "val_images",   img_path / "val_masks"),
        (img_path / "test_images",  img_path / "test_masks"),
    ]
    all_dirs = [directory for pair in split_dirs for directory in pair]

    # Only populate the directories when none of them exists yet
    if not any(directory.exists() for directory in all_dirs):
        for directory in all_dirs:
            directory.mkdir(exist_ok = True, parents = True)

        # Copy the images and their corresponding segmentation masks to their respective directories
        for (image_dir, mask_dir), images in zip(split_dirs, [train_images, val_images, test_images]):
            for image in images:
                copyfile(image, image_dir / image.name)

                # The mask shares the image's name apart from the suffix
                mask = image.name.replace("_img.nii", "_mask.nii")
                copyfile(image.parent / mask, mask_dir / mask)

    # Put the images and masks of every split in a list of dictionaries
    train_files, val_files, test_files = (
        _pair_images_with_masks(image_dir, mask_dir) for image_dir, mask_dir in split_dirs
    )

    # Report what is actually on disk, which may come from an earlier run
    print('Images have been divided into train-validate-test sets.')
    print('Total number of images: ', num_images)
    print('Number of images train-validate-test: ', len(train_files), '-', len(val_files), '-', len(test_files))

    return train_files, val_files, test_files
    

In [None]:
import torch
import pickle
from monai.data             import DataLoader, Dataset, decollate_batch
from monai.metrics          import DiceMetric
from pathlib                import Path
from labels                 import modify_labels

def set_data(val_files, val_transforms, BATCH_SIZE):
    """
    Build the dataloader that feeds the evaluation loop.

    The files are wrapped in a Dataset with val_transforms and served in order
    (no shuffling) by four worker processes in batches of BATCH_SIZE.
    """

    # Free cached GPU memory before the loader is built
    torch.cuda.empty_cache()

    return DataLoader(
        dataset = Dataset(data = val_files, transform = val_transforms),
        batch_size = BATCH_SIZE,
        num_workers = 4,
        shuffle = False,
    )


def set_model_params():
    """
    Create the metric used to score the main task.

    Returns a DiceMetric that skips the background channel and reduces with "mean".
    """

    return DiceMetric(include_background=False, reduction="mean")


def save_results(MODEL_NAME, MODEL_PATH, main_metric_values):
    """
    Pickle the test metric next to the model.

    The file is named after the part of MODEL_NAME before its first dot,
    followed by "_test.pkl", and is written inside MODEL_PATH.
    """

    # Everything up to the first dot is the file prefix
    stem = MODEL_NAME.partition('.')[0]
    target = MODEL_PATH / f"{stem}_test.pkl"

    with open(target, "wb") as out_file:
        pickle.dump(main_metric_values, out_file)


def test_model_base(model, device, params, val_files, val_transforms, organs_dict, pred_main, label_main, model_name):
    """
    Evaluate a saved single-task model on the test dataset.

    The weights stored in models/<model_name>.pth are loaded into `model`, the
    mean Dice of the main task over all batches is printed and then pickled
    through save_results.

    Args:
        model: Network whose forward pass returns the main segmentation output.
            NOTE(review): the model is not moved here; presumably the caller has
            already placed it on `device` - confirm.
        device: Device the batches are moved to; also used to remap the checkpoint.
        params: Dict providing 'BATCH_SIZE'.
        val_files: Entries with an "image" and a "mask" key to evaluate.
        val_transforms: Transform applied to every entry.
        organs_dict: Organ description forwarded to modify_labels.
        pred_main: Callable applied to every decollated prediction.
        label_main: Callable applied to every decollated label.
        model_name: Checkpoint name without the ".pth" extension.
    """
    BATCH_SIZE = params['BATCH_SIZE']

    val_dl              = set_data(val_files, val_transforms, BATCH_SIZE)
    dice_metric_main    = set_model_params()

    # Model save path
    MODEL_PATH = Path("models")
    MODEL_NAME = model_name + ".pth"
    MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # evaluated on another one (e.g. a CPU-only machine)
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location = device))
    model.eval()

    print("-" * 40)
    print("Starting model testing...")

    # Disable gradient calculation
    with torch.inference_mode():
        # Loop through the test data
        for val_data in val_dl:
            # The last image axis is moved in front of the other two spatial axes for the model
            val_inputs, val_labels = val_data["image"].permute(0, 1, 4, 2, 3).to(device), val_data["mask"].to(device)
            val_main_labels, _     = modify_labels(val_labels, organs_dict)

            # Forward pass, then undo the axis permutation so outputs line up with the labels
            val_main_outputs = model(val_inputs)
            val_main_outputs = val_main_outputs.permute(0, 1, 3, 4, 2)

            # Post-process every sample of the batch individually
            val_main_outputs    = [pred_main(i) for i in decollate_batch(val_main_outputs)]
            val_main_labels     = [label_main(i) for i in decollate_batch(val_main_labels)]

            # Accumulate the dice metric for the current iteration
            dice_metric_main(y_pred = val_main_outputs, y = val_main_labels)

        # Compute the average metric value across all iterations
        main_metric = dice_metric_main.aggregate().item()

    print(
        f"\nMean dice for main task: {main_metric:.4f}"
        )

    save_results(MODEL_NAME, MODEL_PATH, main_metric)
    

                    

        
        
        

In [None]:
import torch
import pickle
from monai.data                 import DataLoader, Dataset, decollate_batch
from monai.metrics              import DiceMetric, MSEMetric
from monai.metrics.regression   import SSIMMetric
from pathlib                    import Path
from labels                     import modify_labels

def set_data(val_files, val_transforms, BATCH_SIZE):
    """
    Wrap the evaluation files in a dataloader.

    Batches of BATCH_SIZE are produced in file order (shuffle is off) by four
    worker processes; val_transforms is applied to every entry.
    """

    # Free cached GPU memory before any data is loaded
    torch.cuda.empty_cache()

    eval_dataset = Dataset(data = val_files, transform = val_transforms)
    loader_options = dict(batch_size = BATCH_SIZE, num_workers = 4, shuffle = False)

    return DataLoader(dataset = eval_dataset, **loader_options)


def set_model_params(TASK):
    """
    Create the evaluation metrics for the main and the auxiliary task.

    The main task is always scored with a DiceMetric (background excluded,
    "mean" reduction). The auxiliary task gets the same kind of DiceMetric
    when TASK is 'SEGMENT' and an MSEMetric for any other TASK value.

    Returns:
        Tuple (metric_main, metric_aux).
    """

    metric_main = DiceMetric(include_background=False, reduction="mean")
    metric_aux = (
        DiceMetric(include_background=False, reduction="mean")
        if TASK == 'SEGMENT'
        else MSEMetric()
    )

    return metric_main, metric_aux


def save_results(MODEL_NAME, MODEL_PATH, main_metric_values, aux_metric_values):
    """
    Pickle the test metrics of both tasks next to the model.

    Two files are written inside MODEL_PATH, both prefixed with the part of
    MODEL_NAME before its first dot: "<prefix>_main_test.pkl" and
    "<prefix>_aux_test.pkl".
    """

    stem = MODEL_NAME.partition('.')[0]

    # (file name ending, payload) for every result file
    outputs = (
        ("_main_test.pkl", main_metric_values),
        ("_aux_test.pkl", aux_metric_values),
    )
    for ending, payload in outputs:
        with open(MODEL_PATH / (stem + ending), "wb") as out_file:
            pickle.dump(payload, out_file)


def test_model(model, device, params, val_files, val_transforms, organs_dict, pred_main, label_main, pred_aux, label_aux, model_name):
    """
    Evaluate a saved two-headed model on the test dataset.

    The weights stored in models/<model_name>.pth are loaded into `model`. The
    main task is scored with Dice. The auxiliary task is scored against the
    auxiliary labels when params['TASK'] is 'SEGMENT', otherwise against the
    input image itself. Both aggregated values are printed and then pickled
    through save_results.

    Args:
        model: Network whose forward pass returns (main_output, aux_output).
            NOTE(review): the model is not moved here; presumably the caller has
            already placed it on `device` - confirm.
        device: Device the batches are moved to; also used to remap the checkpoint.
        params: Dict providing 'BATCH_SIZE' and 'TASK'.
        val_files: Entries with an "image" and a "mask" key to evaluate.
        val_transforms: Transform applied to every entry.
        organs_dict: Organ description forwarded to modify_labels.
        pred_main, label_main: Callables applied to every decollated main prediction / label.
        pred_aux, label_aux: Same for the auxiliary task; only used when TASK is 'SEGMENT'.
        model_name: Checkpoint name without the ".pth" extension.
    """
    BATCH_SIZE = params['BATCH_SIZE']
    TASK       = params['TASK']

    val_dl                  = set_data(val_files, val_transforms, BATCH_SIZE)
    metric_main, metric_aux = set_model_params(TASK)

    # Model save path
    MODEL_PATH = Path("models")
    MODEL_NAME = model_name + ".pth"
    MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # evaluated on another one (e.g. a CPU-only machine)
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location = device))
    model.eval()

    print("-" * 40)
    print("Starting model testing...")

    # Disable gradient calculation
    with torch.inference_mode():
        # Loop through the test data
        for val_data in val_dl:
            # The last image axis is moved in front of the other two spatial axes for the model
            val_inputs, val_labels = val_data["image"].permute(0, 1, 4, 2, 3).to(device), val_data["mask"].to(device)
            val_main_labels, val_aux_labels = modify_labels(val_labels, organs_dict)

            # Forward pass, then undo the axis permutation so outputs line up with the labels
            val_main_outputs, val_aux_outputs = model(val_inputs)
            val_main_outputs, val_aux_outputs = val_main_outputs.permute(0, 1, 3, 4, 2), val_aux_outputs.permute(0, 1, 3, 4, 2)

            # Post-process every sample of the batch individually
            val_main_outputs    = [pred_main(i) for i in decollate_batch(val_main_outputs)]
            val_main_labels     = [label_main(i) for i in decollate_batch(val_main_labels)]

            # Accumulate the dice metric for the current iteration
            metric_main(y_pred = val_main_outputs, y = val_main_labels)
            if TASK == 'SEGMENT':
                # Auxiliary segmentation: post-process and score against the aux labels
                val_aux_outputs     = [pred_aux(i) for i in decollate_batch(val_aux_outputs)]
                val_aux_labels      = [label_aux(i) for i in decollate_batch(val_aux_labels)]

                metric_aux(y_pred = val_aux_outputs, y = val_aux_labels)
            else:
                # Any other task: the aux output is compared with the input image itself
                metric_aux(y_pred = val_aux_outputs, y = val_inputs.permute(0, 1, 3, 4, 2))

        # Compute the average metric value across all iterations
        main_metric = metric_main.aggregate().item()
        aux_metric  = metric_aux.aggregate().item()

    print(
        f"\nMean dice for main task: {main_metric:.4f}"
        f"\nMean metric for aux task: {aux_metric:.4f}"
        )

    save_results(MODEL_NAME, MODEL_PATH, main_metric, aux_metric)
    

                    

        
        
        

In [None]:
import torch
import pickle
from monai.data             import DataLoader, Dataset, decollate_batch
from monai.losses           import DiceLoss
from monai.metrics          import DiceMetric
from pathlib                import Path
from labels                 import modify_labels


def set_data(train_files, train_transforms, val_files, val_transforms, BATCH_SIZE):
    """
    Create the dataloaders for the training and the validation set.

    Both loaders use batches of BATCH_SIZE and four worker processes. Only the
    training loader shuffles its data.

    Returns:
        Tuple (train_dl, val_dl).
    """

    # Free cached GPU memory before any data is loaded
    torch.cuda.empty_cache()

    train_dl = DataLoader(
        dataset = Dataset(data = train_files, transform = train_transforms),
        batch_size = BATCH_SIZE,
        shuffle = True,
        num_workers = 4,
    )
    val_dl = DataLoader(
        dataset = Dataset(data = val_files, transform = val_transforms),
        batch_size = BATCH_SIZE,
        num_workers = 4,
        shuffle = False,
    )

    return train_dl, val_dl


def set_model_params(model):
    """
    Create the loss, optimizer, validation metric and learning-rate scheduler.

    Returns:
        Tuple (loss_function, optimizer, dice_metric_main, scheduler).
    """

    # Dice loss on softmax outputs against one-hot encoded labels, background excluded
    loss_function = DiceLoss(to_onehot_y = True, softmax = True, include_background=False)

    # Dice score used for validation, background excluded
    dice_metric_main = DiceMetric(include_background=False, reduction="mean")

    # Adam with a learning rate of a quarter of 1e-3
    learning_rate = (1e-3)/4
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)

    # Cosine annealing from the initial rate down to 1e-6 with a period parameter of 60 steps
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 60, eta_min = 1e-6)

    return loss_function, optimizer, dice_metric_main, scheduler


def save_results(MODEL_NAME, MODEL_PATH, epoch_loss_values, main_metric_values):
    """
    Pickle the training losses and validation metrics next to the model.

    Two files are written inside MODEL_PATH, both prefixed with the part of
    MODEL_NAME before its first dot: "<prefix>_epoch_loss_train.pkl" and
    "<prefix>_validate.pkl".
    """

    stem = MODEL_NAME.partition('.')[0]

    # (file name ending, payload) for every result file
    outputs = (
        ("_epoch_loss_train.pkl", epoch_loss_values),
        ("_validate.pkl", main_metric_values),
    )
    for ending, payload in outputs:
        with open(MODEL_PATH / (stem + ending), "wb") as out_file:
            pickle.dump(payload, out_file)


def train_model_base(model, device, params, train_files, train_transforms, val_files, val_transforms, organs, pred_main, label_main, model_name):
    """
    Train the model on the training dataset and evaluate the validation dataset.

    Epochs are numbered 1..MAX_EPOCHS (inclusive). Every VAL_INTERVAL epochs the
    model is scored on the validation data; whenever the mean Dice improves, the
    state dict is saved to models/<model_name>.pth. At the end the per-epoch
    losses and the validation metrics are pickled through save_results.

    Args:
        model: Network whose forward pass returns the main segmentation output.
        device: Device every batch is moved to.
        params: Dict providing 'BATCH_SIZE', 'MAX_EPOCHS', 'VAL_INTERVAL' and 'PRINT_INTERVAL'.
        train_files, val_files: Entries with an "image" and a "mask" key.
        train_transforms, val_transforms: Transforms applied to the entries.
        organs: Organ description forwarded to modify_labels.
        pred_main: Callable applied to every decollated validation prediction.
        label_main: Callable applied to every decollated validation label.
        model_name: Checkpoint name without the ".pth" extension.
    """
    BATCH_SIZE      = params['BATCH_SIZE']
    MAX_EPOCHS      = params['MAX_EPOCHS']
    VAL_INTERVAL    = params['VAL_INTERVAL']
    PRINT_INTERVAL  = params['PRINT_INTERVAL']

    train_dl, val_dl = set_data(train_files, train_transforms, val_files, val_transforms, BATCH_SIZE)
    loss_function, optimizer, dice_metric_main, scheduler = set_model_params(model)

    # Create model directory
    MODEL_PATH = Path("models")
    MODEL_PATH.mkdir(parents=True, exist_ok=True)

    # Create model save path
    MODEL_NAME = model_name + ".pth"
    MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

    best_metric             = -1
    best_metric_epoch       = -1
    epoch_loss_values       = []
    main_metric_values      = []

    print("-" * 20)
    print("Starting model training...")

    # Inclusive upper bound: without the +1 the last announced epoch (and a
    # validation run falling on it) would never be executed
    for epoch in range(1, MAX_EPOCHS + 1):
        if epoch % PRINT_INTERVAL == 0:
            print("-" * 20)
            print(f"Epoch {epoch} / {MAX_EPOCHS}")

        # Put the model into training mode
        model.train()
        epoch_loss = 0
        step = 0

        for batch in train_dl:
            step = step + 1
            # Permute the image because of torch upsample; the labels keep their layout
            inputs = batch["image"].permute(0, 1, 4, 2, 3).to(device)
            labels = batch["mask"].to(device)

            main_labels, _ = modify_labels(labels, organs)

            # Forward pass
            main_seg = model(inputs)
            main_seg = main_seg.permute(0, 1, 3, 4, 2) # Permute back to BNHWD

            # Compute the loss
            loss = loss_function(main_seg, main_labels)

            # Zero the gradients
            optimizer.zero_grad()

            # Find the gradients of the loss w.r.t the model parameters
            loss.backward()

            # Update the parameters
            optimizer.step()

            # Add the loss to the epoch loss
            epoch_loss = epoch_loss + loss.item()

        # Compute the average loss of the epoch
        epoch_loss = epoch_loss / step
        epoch_loss_values.append(epoch_loss)

        if epoch % PRINT_INTERVAL == 0:
            # Print the average loss of the epoch
            print(f"\nEpoch {epoch} average dice loss for main task: {epoch_loss:.4f}")

        # Step the scheduler after every epoch
        scheduler.step()

        # Evaluate the model when the epoch is divisible by VAL_INTERVAL
        if epoch % VAL_INTERVAL == 0:
            print("-" * 40)
            print("Testing on validation data...")

            # Put the model into evaluation mode
            model.eval()
            # Disable gradient calculation
            with torch.inference_mode():
                # Loop through the validation data
                for val_data in val_dl:
                    val_inputs, val_labels = val_data["image"].permute(0, 1, 4, 2, 3).to(device), val_data["mask"].to(device)
                    val_main_labels, _     = modify_labels(val_labels, organs)

                    # Forward pass
                    val_main_outputs = model(val_inputs)
                    val_main_outputs = val_main_outputs.permute(0, 1, 3, 4, 2)

                    # Post-process every sample of the batch individually
                    val_main_outputs    = [pred_main(i) for i in decollate_batch(val_main_outputs)]
                    val_main_labels     = [label_main(i) for i in decollate_batch(val_main_labels)]

                    # Accumulate the dice metric for the current iteration
                    dice_metric_main(y_pred = val_main_outputs, y = val_main_labels)

                # Compute the average metric value across all iterations
                main_metric = dice_metric_main.aggregate().item()
                main_metric_values.append(main_metric)

                # Reset the metric for next validation run
                dice_metric_main.reset()

                # If the metric is better than the best seen so far, save the model
                if main_metric > best_metric:
                    best_metric = main_metric
                    best_metric_epoch = epoch
                    torch.save(model.state_dict(), MODEL_SAVE_PATH)
                    print("saved new best metric model")

                print(
                    f"\nCurrent epoch: {epoch} current mean dice for main task: {main_metric:.4f}"
                    f"\nBest mean dice for main task: {best_metric:.4f} at epoch: {best_metric_epoch}"
                    )

    # When training is complete:
    print(f"Done training! Best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}")

    save_results(MODEL_NAME, MODEL_PATH, epoch_loss_values, main_metric_values)
    

                    

        
        
        

In [None]:
import torch
import torch.nn as nn
import pickle
from monai.data                 import DataLoader, Dataset, decollate_batch
from monai.losses               import DiceLoss
from monai.losses.ssim_loss     import SSIMLoss
from monai.metrics              import DiceMetric, MSEMetric
from monai.metrics.regression   import SSIMMetric
from pathlib                    import Path
from labels                     import modify_labels


def set_data(train_files, train_transforms, val_files, val_transforms, BATCH_SIZE):
    """
    Build the training and validation dataloaders.

    Each file list is wrapped in a Dataset with its own transforms. Both
    loaders produce batches of BATCH_SIZE with four worker processes; the
    training loader shuffles, the validation loader keeps the file order.

    Returns:
        Tuple (train_dl, val_dl).
    """

    # Free cached GPU memory before any data is loaded
    torch.cuda.empty_cache()

    train_dataset = Dataset(data = train_files, transform = train_transforms)
    val_dataset   = Dataset(data = val_files, transform = val_transforms)

    return (
        DataLoader(dataset = train_dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 4),
        DataLoader(dataset = val_dataset, batch_size = BATCH_SIZE, num_workers = 4, shuffle = False),
    )


def set_model_params(model, TASK):
    """
    Create losses, metrics, optimizer and scheduler for the two-headed model.

    The main task always uses a Dice loss and a Dice metric. The auxiliary
    task uses the same pair when TASK is 'SEGMENT'; for any other TASK value
    it is trained with an L1 loss and scored with an MSE metric.

    Returns:
        Tuple (loss_main, loss_aux, metric_main, metric_aux, optimizer, scheduler).
    """

    def _dice_pair():
        # Dice loss on softmax outputs vs. one-hot labels plus the matching metric,
        # both excluding the background channel
        return (
            DiceLoss(to_onehot_y = True, softmax = True, include_background=False),
            DiceMetric(include_background=False, reduction="mean"),
        )

    loss_main, metric_main = _dice_pair()

    if TASK == 'SEGMENT':
        loss_aux, metric_aux = _dice_pair()
    else:
        loss_aux, metric_aux = torch.nn.L1Loss(), MSEMetric()

    # Adam with a learning rate of a quarter of 1e-3
    learning_rate = (1e-3)/4
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)

    # Cosine annealing from the initial rate down to 1e-6 with a period parameter of 60 steps
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 60, eta_min = 1e-6)

    return loss_main, loss_aux, metric_main, metric_aux, optimizer, scheduler


def save_results(MODEL_NAME, MODEL_PATH, epoch_loss_values, epoch_aux_loss_values, epoch_total_loss_values, main_metric_values, aux_metric_values):
    """
    Pickle all training losses and validation metrics next to the model.

    Five files are written inside MODEL_PATH, each prefixed with the part of
    MODEL_NAME before its first dot: the main, auxiliary and total epoch
    losses, followed by the main and auxiliary validation metrics.
    """

    stem = MODEL_NAME.partition('.')[0]

    # (file name ending, payload) for every result file, in writing order
    outputs = (
        ("_epoch_loss.pkl", epoch_loss_values),
        ("_epoch_aux_loss.pkl", epoch_aux_loss_values),
        ("_epoch_total_loss.pkl", epoch_total_loss_values),
        ("_main_validation.pkl", main_metric_values),
        ("_aux_validation.pkl", aux_metric_values),
    )
    for ending, payload in outputs:
        with open(MODEL_PATH / (stem + ending), "wb") as out_file:
            pickle.dump(payload, out_file)


def train_model(model, device, params, train_files, train_transforms, val_files, val_transforms, organs, pred_main, label_main, pred_aux, label_aux, model_name):
    """
    Train the two-headed model on the training dataset and evaluate the validation dataset.

    The total loss is 1.1 * main loss + 1.5 * auxiliary loss. The auxiliary
    head is trained and scored against the auxiliary labels when
    params['TASK'] is 'SEGMENT', otherwise against the input image itself.
    Epochs are numbered 1..MAX_EPOCHS (inclusive). Every VAL_INTERVAL epochs
    the model is validated; whenever the main-task mean Dice improves, the
    state dict is saved to models/<model_name>.pth. At the end all per-epoch
    losses and validation metrics are pickled through save_results.

    Args:
        model: Network whose forward pass returns (main_output, aux_output).
        device: Device every batch is moved to.
        params: Dict providing 'BATCH_SIZE', 'MAX_EPOCHS', 'VAL_INTERVAL',
            'PRINT_INTERVAL' and 'TASK'.
        train_files, val_files: Entries with an "image" and a "mask" key.
        train_transforms, val_transforms: Transforms applied to the entries.
        organs: Organ description forwarded to modify_labels.
        pred_main, label_main: Callables applied to every decollated main prediction / label.
        pred_aux, label_aux: Same for the auxiliary task; only used when TASK is 'SEGMENT'.
        model_name: Checkpoint name without the ".pth" extension.
    """
    BATCH_SIZE      = params['BATCH_SIZE']
    MAX_EPOCHS      = params['MAX_EPOCHS']
    VAL_INTERVAL    = params['VAL_INTERVAL']
    PRINT_INTERVAL  = params['PRINT_INTERVAL']
    TASK            = params['TASK']

    train_dl, val_dl = set_data(train_files, train_transforms, val_files, val_transforms, BATCH_SIZE)
    loss_main, loss_aux, metric_main, metric_aux, optimizer, scheduler = set_model_params(model, TASK)

    # Create model directory
    MODEL_PATH = Path("models")
    MODEL_PATH.mkdir(parents=True, exist_ok=True)

    # Create model save path
    MODEL_NAME = model_name + ".pth"
    MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

    best_metric             = -1
    best_metric_epoch       = -1
    epoch_loss_values       = []
    epoch_aux_loss_values   = []
    epoch_total_loss_values = []
    main_metric_values      = []
    aux_metric_values       = []

    # Loss weights
    main_weight = 1.1
    aux_weight  = 1.5

    print("-" * 20)
    print("Starting model training...")

    # Inclusive upper bound: without the +1 the last announced epoch (and a
    # validation run falling on it) would never be executed
    for epoch in range(1, MAX_EPOCHS + 1):
        if epoch % PRINT_INTERVAL == 0:
            print("-" * 20)
            print(f"Epoch {epoch} / {MAX_EPOCHS}")

        # Put the model into training mode
        model.train()
        epoch_loss = 0
        epoch_aux_loss = 0
        epoch_total_loss = 0
        step = 0

        for batch in train_dl:
            step = step + 1
            # Permute the image because of torch upsample; the labels keep their layout
            inputs = batch["image"].permute(0, 1, 4, 2, 3).to(device)
            labels = batch["mask"].to(device)

            main_labels, aux_labels = modify_labels(labels, organs)

            # Forward pass
            main_seg, aux_seg = model(inputs)
            main_seg, aux_seg = main_seg.permute(0, 1, 3, 4, 2), aux_seg.permute(0, 1, 3, 4, 2) # Permute back to BNHWD

            # Compute the loss functions
            main_seg_loss = loss_main(main_seg, main_labels)
            if TASK == 'SEGMENT':
                aux_seg_loss = loss_aux(aux_seg, aux_labels)
            else:
                # Any other task: the aux output is compared with the input image itself
                aux_seg_loss = loss_aux(aux_seg, inputs.permute(0, 1, 3, 4, 2))

            # Compute the total loss
            loss = main_weight * main_seg_loss + aux_weight * aux_seg_loss

            # Zero the gradients
            optimizer.zero_grad()

            # Find the gradients of the loss w.r.t the model parameters
            loss.backward()

            # Update the parameters
            optimizer.step()

            # Add the losses to the epoch losses
            epoch_loss = epoch_loss + main_seg_loss.item()
            epoch_aux_loss = epoch_aux_loss + aux_seg_loss.item()
            epoch_total_loss = epoch_total_loss + loss.item()

        # Compute the average loss of the epoch
        epoch_loss          = epoch_loss        / step
        epoch_aux_loss      = epoch_aux_loss    / step
        epoch_total_loss    = epoch_total_loss  / step
        epoch_loss_values.append(epoch_loss)
        epoch_total_loss_values.append(epoch_total_loss)
        epoch_aux_loss_values.append(epoch_aux_loss)

        if epoch % PRINT_INTERVAL == 0:
            # Print the average loss of the epoch
            print(f"\nEpoch {epoch} average loss for main task: {epoch_loss:.4f}")
            print(f"\nEpoch {epoch} average loss for aux task: {epoch_aux_loss:.4f}")
            print(f"\nEpoch {epoch} average total loss for both tasks: {epoch_total_loss:.4f}")

        # Step the scheduler after every epoch
        scheduler.step()

        # Evaluate the model when the epoch is divisible by VAL_INTERVAL
        if epoch % VAL_INTERVAL == 0:
            print("-" * 40)
            print("Testing on validation data...")

            # Put the model into evaluation mode
            model.eval()
            # Disable gradient calculation
            with torch.inference_mode():
                # Loop through the validation data
                for val_data in val_dl:
                    val_inputs = val_data["image"].permute(0, 1, 4, 2, 3).to(device)
                    val_labels = val_data["mask"].to(device)

                    val_main_labels, val_aux_labels = modify_labels(val_labels, organs)

                    # Forward pass
                    val_main_outputs, val_aux_outputs = model(val_inputs)
                    val_main_outputs, val_aux_outputs = val_main_outputs.permute(0, 1, 3, 4, 2), val_aux_outputs.permute(0, 1, 3, 4, 2)

                    # Post-process every sample of the batch individually
                    val_main_outputs    = [pred_main(i) for i in decollate_batch(val_main_outputs)]
                    val_main_labels     = [label_main(i) for i in decollate_batch(val_main_labels)]

                    # Accumulate the metrics for the current iteration
                    metric_main(y_pred = val_main_outputs, y = val_main_labels)
                    if TASK == 'SEGMENT':
                        # Auxiliary segmentation: post-process and score against the aux labels
                        val_aux_outputs     = [pred_aux(i) for i in decollate_batch(val_aux_outputs)]
                        val_aux_labels      = [label_aux(i) for i in decollate_batch(val_aux_labels)]

                        metric_aux(y_pred = val_aux_outputs, y = val_aux_labels)
                    else:
                        # The reference must be the validation image that produced the
                        # output, not the last training batch still bound to `inputs`
                        metric_aux(y_pred = val_aux_outputs, y = val_inputs.permute(0, 1, 3, 4, 2))

                # Compute the average metric value across all iterations
                main_metric = metric_main.aggregate().item()
                aux_metric  = metric_aux.aggregate().item()
                main_metric_values.append(main_metric)
                aux_metric_values.append(aux_metric)

                # Reset the metrics for next validation run
                metric_main.reset()
                metric_aux.reset()

                # If the main metric is better than the best seen so far, save the model
                if main_metric > best_metric:
                    best_metric = main_metric
                    best_metric_epoch = epoch
                    torch.save(model.state_dict(), MODEL_SAVE_PATH)
                    print("saved new best metric model")

                print(
                    f"\nCurrent epoch: {epoch} current mean dice for main task: {main_metric:.4f}"
                    f"\nBest mean dice for main task: {best_metric:.4f} at epoch: {best_metric_epoch}"
                    f"\nCurrent epoch: {epoch} current mean metric for aux task: {aux_metric:.4f}"
                    )

    # When training is complete:
    print(f"Done training! Best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}")

    save_results(MODEL_NAME, MODEL_PATH, epoch_loss_values, epoch_aux_loss_values, epoch_total_loss_values, main_metric_values, aux_metric_values)
    

                    

        
        
        

In [None]:
import numpy as np

from monai.transforms import (
    EnsureChannelFirstd, # Adjust or add the channel dimension of input data to ensure channel_first shape.
    CenterSpatialCropd,
    Compose,
    AsDiscrete,
    LoadImaged,
    ScaleIntensityd,
    Spacingd,
    SpatialPadd,
    RandAffined, 
    CropForegroundd, # Crop the foreground region of the input image based on the provided mask to help training and evaluation if the valid part is small in the whole medical image
    RandGaussianNoised, # Randomly add Gaussian noise to image.
    RandGaussianSmoothd, # Randomly smooth image with Gaussian filter.
    AdjustContrastd, # Adjust image contrast by gamma value.
)

def get_transforms(num_main_classes = 3, num_aux_classes = 4):
    """
    Build the MONAI transform pipelines for training, validation and post-processing.

    Args:
        num_main_classes: Number of one-hot channels (background included) used by the
            main-task post transforms. Defaults to 3 (2 prostate zones + background).
        num_aux_classes: Number of one-hot channels (background included) used by the
            auxiliary-task post transforms. Defaults to 4 (3 structures + background).
            Pass len(aux_organs) + 1 when the auxiliary decoder has a different
            number of output channels.

    Returns:
        A 6-tuple:
        (train_transforms, val_transforms,
         post_pred_transform_main, post_label_transform_main,
         post_pred_transform_aux, post_label_transform_aux)
    """
    print("-" * 40)
    print("Creating transformations...")

    # Create transforms for training
    train_transforms = Compose(
        [
            LoadImaged(keys = ["image", "mask"]),
            EnsureChannelFirstd(keys = ["image", "mask"]),
            ScaleIntensityd(keys = "image"),
            # NOTE(review): foreground cropping is only done for training; val_transforms
            # has no matching step - confirm this asymmetry is intended.
            CropForegroundd(keys = ["image", "mask"], source_key = "image"),
            Spacingd(
                keys = ["image", "mask"],
                pixdim = [0.75, 0.75, 2.5],
                mode = ("bilinear", "nearest"), # Interpolation mode for image and mask
            ),
            RandAffined(
                keys = ["image", "mask"],
                mode = ("bilinear", "nearest"),
                prob = 1.0,
                spatial_size = (256, 256, 40), # Output size of the image [height, width, depth]
                rotate_range = (np.pi / 36, np.pi / 36, np.pi / 36), # Rotation range
                scale_range = (0.1, 0.1, 0.1), # will do [-0.1, 0.1] scaling then add 1 so a scaling in the range [0.9, 1.1]
                padding_mode = "zeros", # This means that the image will be padded with zeros, some images are smaller than 256x256x40
            ),
            RandGaussianNoised(
                keys = "image",
                prob = 0.15,
                mean = 0.0,
                std = 0.1
            ),
            RandGaussianSmoothd(
                keys = "image",
                prob = 0.1,
                sigma_x = (0.5, 1.5),
                sigma_y = (0.5, 1.5),
                sigma_z = (0.5, 1.5)
            ),
            # NOTE(review): no `prob` is given here and val_transforms has no gamma step,
            # so training images always get gamma = 1.3 while validation images do not -
            # confirm this is intended.
            AdjustContrastd(
                keys = "image",
                gamma = 1.3
            )
        ]
    )

    # Create transforms for validation
    val_transforms = Compose(
        [
            LoadImaged(keys = ["image", "mask"]),
            EnsureChannelFirstd(keys = ["image", "mask"]),
            ScaleIntensityd(keys = "image"),
            Spacingd(
                keys = ["image", "mask"],
                pixdim = [0.75, 0.75, 2.5],
                mode = ("bilinear", "nearest"),
            ),
            # Since we are not doing data augmentation during validation,
            # we simply center crop the image and mask to the specified size of [256, 256, 40]
            CenterSpatialCropd(keys = ["image", "mask"], roi_size = (256, 256, 40)),
            SpatialPadd(keys = ["image", "mask"], spatial_size = (256, 256, 40)) # Some images are smaller than 256x256x40, so we pad them to this size
        ]
    )

    # Post transforms for the main prostate zones: classes + background
    post_pred_transform_main    = Compose([AsDiscrete(argmax = True, to_onehot = num_main_classes)])
    post_label_transform_main   = Compose([AsDiscrete(to_onehot = num_main_classes)])

    # Post transforms for the auxiliary structures: classes + background
    post_pred_transform_aux     = Compose([AsDiscrete(argmax = True, to_onehot = num_aux_classes)])
    post_label_transform_aux    = Compose([AsDiscrete(to_onehot = num_aux_classes)])

    print('Transforms have been defined.')

    return (
        train_transforms,
        val_transforms,
        post_pred_transform_main,
        post_label_transform_main,
        post_pred_transform_aux,
        post_label_transform_aux,
    )

In [None]:
import torch
from pathlib import Path
from monai.utils        import set_determinism  
from split_data         import split_data
from transforms         import get_transforms
from model              import ResidualAttention3DUnet, MTLResidualAttention3DUnet, MTLResidualAttentionRecon3DUnet
from train_model        import train_model
from test_model         import test_model
from train_model_base   import train_model_base
from test_model_base    import test_model_base

# Needed to build auxiliary post-transforms sized to each task's class count
from monai.transforms import AsDiscrete, Compose

# Choose whether to train and/or test model(s)
TRAIN           = 1
TEST            = 1

# Choose which models to train and/or test
BASE_CASE       = 1
AUX_SEGMENT_3   = 1
AUX_SEGMENT_6   = 1
AUX_RECONSTRUCT = 1

# Parameters
# NOTE(review): VAL_INT / PRINT_INT are presumably intervals consumed by the
# train/test functions - confirm their units there.
params = {
    'BATCH_SIZE':       2,
    'MAX_EPOCHS':       100,
    'VAL_INT':          10,
    'PRINT_INT':        10
}

# Set deterministic training for reproducibility
set_determinism(seed = 2056)

# Path to data
img_path = Path("../data")
# scale divides the size of every split, so scale=28 runs on a small subset of the data
train_files, val_files, test_files = split_data(img_path, scale=28)

# Create transforms for training
# pred_aux / label_aux returned here use the default auxiliary channel count of get_transforms()
train_transforms, val_transforms, pred_main, label_main, pred_aux, label_aux = get_transforms()

# Use CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define organ names in the segmentation task
all_organs =  ["Background", "Bladder", "Bone", "Obturator internus", "Transition zone", "Central gland", "Rectum", "Seminal vesicle", "Neurovascular bundle"]
organs = {
    'all': all_organs,
    'main': ["Transition zone", "Central gland"],
    'dict': {organ: idx for idx, organ in enumerate(all_organs)}
    }


def _build_aux_post_transforms(num_classes):
    """
    Return (pred_transform, label_transform) that one-hot encode into `num_classes` channels.

    The channel count must equal the number of output channels of the auxiliary
    decoder (auxiliary structures + background).
    """
    pred_transform  = Compose([AsDiscrete(argmax = True, to_onehot = num_classes)])
    label_transform = Compose([AsDiscrete(to_onehot = num_classes)])
    return pred_transform, label_transform


############# BASE CASE #############
if BASE_CASE:
    organs['aux']  = []
    params['TASK'] = 'BASE_CASE'
    model_name     = 'base_case'
    model  = ResidualAttention3DUnet(in_channels = 1, out_channels = len(organs['main'])+1, device=device).to(device) 
    
    if TRAIN:
        torch.cuda.empty_cache()
        train_model_base(model, device, params, train_files, train_transforms, val_files, val_transforms, organs, pred_main, label_main, model_name)
    if TEST:
        torch.cuda.empty_cache()
        test_model_base(model, device, params, test_files, val_transforms, organs, pred_main, label_main, model_name)


############# AUXILIARY TASK - SEGMENT 3 EXTRA STRUCTURES #############
if AUX_SEGMENT_3:
    organs['aux']  = ["Rectum", "Seminal vesicle", "Neurovascular bundle"]
    params['TASK'] = 'SEGMENT'
    model_name     = 'auxiliary_segment_3'
    aux_channels   = len(organs['aux'])+1
    # Size the auxiliary post-transforms to this task's decoder (3 structures + background)
    seg_pred_aux, seg_label_aux = _build_aux_post_transforms(aux_channels)
    model = MTLResidualAttention3DUnet(in_channels = 1, main_out_channels = len(organs['main'])+1, aux_out_channels = aux_channels, device=device).to(device) 
    
    if TRAIN:
        torch.cuda.empty_cache()
        train_model(model, device, params, train_files, train_transforms, val_files, val_transforms, organs, pred_main, label_main, seg_pred_aux, seg_label_aux, model_name)
    if TEST:
        torch.cuda.empty_cache()
        test_model(model, device, params, test_files, val_transforms, organs, pred_main, label_main, seg_pred_aux, seg_label_aux, model_name)
        
        
############# AUXILIARY TASK - SEGMENT 6 EXTRA STRUCTURES #############
if AUX_SEGMENT_6:
    organs['aux']  = ["Rectum", "Seminal vesicle", "Neurovascular bundle", "Bladder", "Bone", "Obturator internus"]
    params['TASK'] = 'SEGMENT'
    model_name     = 'auxiliary_segment_6'
    aux_channels   = len(organs['aux'])+1
    # The default auxiliary post-transforms are fixed at 4 channels, which does not
    # match the 7 output channels of this model, so build them for 6 structures + background
    seg_pred_aux, seg_label_aux = _build_aux_post_transforms(aux_channels)
    model = MTLResidualAttention3DUnet(in_channels = 1, main_out_channels = len(organs['main'])+1, aux_out_channels = aux_channels, device=device).to(device) 
    
    if TRAIN:
        torch.cuda.empty_cache()
        train_model(model, device, params, train_files, train_transforms, val_files, val_transforms, organs, pred_main, label_main, seg_pred_aux, seg_label_aux, model_name)
    if TEST:
        torch.cuda.empty_cache()
        test_model(model, device, params, test_files, val_transforms, organs, pred_main, label_main, seg_pred_aux, seg_label_aux, model_name)
    
    
############# AUXILIARY TASK - RECONSTRUCTION #############
if AUX_RECONSTRUCT:
    organs['aux']   = []
    params['TASK'] = 'RECONSTRUCT'
    model_name     = 'auxiliary_reconstruct'
    model = MTLResidualAttentionRecon3DUnet(in_channels = 1, out_channels = len(organs['main'])+1, device=device).to(device) 
    
    # NOTE(review): there are no auxiliary segmentation classes here; the default
    # pred_aux / label_aux are passed through unchanged - confirm how train_model /
    # test_model use them when TASK is 'RECONSTRUCT'.
    if TRAIN:
        torch.cuda.empty_cache()
        train_model(model, device, params, train_files, train_transforms, val_files, val_transforms, organs, pred_main, label_main, pred_aux, label_aux, model_name)
    if TEST:
        torch.cuda.empty_cache()
        test_model(model, device, params, test_files, val_transforms, organs, pred_main, label_main, pred_aux, label_aux, model_name)
    
