In [1]:
import torch
#import argparse
from pathlib import Path
from monai.utils    import set_determinism  
from split_data     import split_data
from transforms     import get_transforms
from model          import MTLResidualAttention3DUnet
from train_model    import train_model
from test_model     import test_model

import warnings
warnings.simplefilter("ignore", UserWarning)


In [2]:

TEST = 0

# Set deterministic training for reproducibility
set_determinism(seed = 2056)

img_path = Path("../data")
train_files, val_files, test_files = split_data(img_path)

# Create transforms for training
train_transforms, val_transforms, pred_main, label_main, pred_aux, label_aux = get_transforms()

# Create an index dictionary for the organs
all_organs =  ["Background", "Bladder", "Bone", "Obturator internus", "Transition zone", "Central gland", "Rectum", "Seminal vesicle", "Neurovascular bundle"]
organs = {
    'all': all_organs,
    'main': ["Transition zone", "Central gland"],
    'aux':  ["Rectum", "Seminal vesicle", "Neurovascular bundle"],
    'dict': {organ: idx for idx, organ in enumerate(all_organs)}
    }

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = MTLResidualAttention3DUnet(in_channels = 1, main_out_channels = 3, aux_out_channels = 4).to(device) #Main: 2 structures + background, Aux: 3 structures + background

train_model(model, device, train_files, train_transforms, val_files, val_transforms, organs, pred_main, label_main, pred_aux, label_aux)
if TEST:
    test_model(model, device, test_files, val_transforms, organs, pred_main, label_main, pred_aux, label_aux)

----------------------------------------
Splitting data into train-validate-test sets...
468
58
59
Images have been divided into train-validate-test sets.
Total number of images:  585
Number of images train-validate-test:  468 - 58 - 59
----------------------------------------
Creating transformations...
Transforms have been defined.
--------------------
Starting model training...


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 13:51:29,699 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
2023-04-24 13:51:29,716 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 13:51:31,835 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 13:51:33,249 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 13:51:34,195 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 13:51:34,573 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


KeyboardInterrupt: 

pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 13:51:57,704 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
2023-04-24 13:51:57,704 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 13:51:59,281 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 13:51:59,593 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
  ret = func(*args, **kwargs)
  if storage.is_cuda:


2023-04-24 13:52:00,820 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


In [None]:
import torch
import pickle
from monai.data             import DataLoader, Dataset, decollate_batch
from monai.losses           import DiceLoss
from monai.metrics          import DiceMetric
from pathlib                import Path
from labels                 import modify_labels

BATCH_SIZE      = 2

def set_data(val_files, val_transforms):
    torch.cuda.empty_cache()
    val_ds = Dataset(data = val_files, transform = val_transforms)
    val_dl = DataLoader(dataset = val_ds, batch_size = BATCH_SIZE, num_workers = 4, shuffle = False)
    
    return val_dl


def set_model_params():
    # Input image has eight anatomical structures of planning interest
    dice_metric_main    = DiceMetric(include_background=False, reduction="mean")# Collect the loss and metric values for every iteration
    dice_metric_aux     = DiceMetric(include_background=False, reduction="mean")
    
    return dice_metric_main, dice_metric_aux


def save_results(MODEL_NAME, MODEL_PATH, main_metric_values, aux_metric_values):
    # Save epoch loss and metric values based on the model name
    pref = f"{MODEL_NAME.split('.')[0]}"
    with open(MODEL_PATH/f"{pref}_main_metric_values_test.pkl", "wb") as f:
        pickle.dump(main_metric_values, f)
    with open(MODEL_PATH/f"{pref}_aux_metric_values_test.pkl", "wb") as f:
        pickle.dump(aux_metric_values, f)
        

In [None]:
def test_model(model, device, val_files, val_transforms, organs_dict, pred_main, label_main, pred_aux, label_aux):
    val_dl                              = set_data(val_files, val_transforms)
    dice_metric_main, dice_metric_aux   = set_model_params()
    
    # Model save path
    MODEL_PATH = Path("models")
    MODEL_NAME = "pelvic_segmentation_model.pth"
    MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME
    
    model.load_state_dict(torch.load(MODEL_SAVE_PATH))
    model.eval()

    print("-" * 40)
    print("Starting model testing...")

    # Disable gradient calculation
    with torch.inference_mode():
        # Loop through the validation data
        for val_data in val_dl:
            val_inputs, val_labels = val_data["image"].permute(0, 1, 4, 2, 3).to(device), val_data["mask"].to(device)
            val_main_labels, val_aux_labels = modify_labels(val_labels, organs_dict)

            """# Forward pass
            val_main_outputs, val_aux_outputs = model(val_inputs)
            val_main_outputs, val_aux_outputs = val_main_outputs.permute(0, 1, 3, 4, 2), val_aux_outputs.permute(0, 1, 3, 4, 2)

            # Transform main outputs and labels to calculate inference loss
            val_main_outputs    = [pred_main(i) for i in decollate_batch(val_main_outputs)]
            val_main_labels     = [label_main(i) for i in decollate_batch(val_main_labels)]

            # Transform aux outputs and labels to calculate inference loss
            val_aux_outputs     = [pred_aux(i) for i in decollate_batch(val_aux_outputs)]
            val_aux_labels      = [label_aux(i) for i in decollate_batch(val_aux_labels)]
            
            # Compute dice metric for current iteration
            dice_metric_main(y_pred = val_main_outputs, y = val_main_labels)
            dice_metric_aux(y_pred = val_aux_outputs, y = val_aux_labels)"""
            
            break
        """"  
        main_metric_values      = []
        main_metric = dice_metric_main.aggregate().item()
        print('main',main_metric)
        print(main_metric_values.append(main_metric))        
        
        # Reset the metric for next validation run
        dice_metric_main.reset()
        
    print(
        f"\nMean dice for main task: {main_metric:.4f}"
        f"\nMean dice for aux task: {aux_metric:.4f}"
        )
    
    save_results(MODEL_NAME, MODEL_PATH, main_metric, aux_metric)"""

In [None]:
test_model(model, device, test_files, val_transforms, organs_dict, pred_main, label_main, pred_aux, label_aux)

----------------------------------------
Starting model testing...


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 12:51:44,312 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
2023-04-24 12:51:44,312 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
2023-04-24 12:51:44,313 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 12:51:45,377 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 12:51:45,811 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:
pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-24 12:51:47,030 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:
  ret = func(*args, **kwargs)
  if storage.is_cuda:


main 0.017180070281028748
None


In [None]:
MODEL_PATH = Path("models")
MODEL_NAME = "pelvic_segmentation_model.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME
save_results(MODEL_NAME, MODEL_PATH, 0.01, 0.02)