In [1]:
import torch
from pathlib import Path
from monai.utils        import set_determinism  
from split_data         import split_data
from transforms         import get_transforms
from model              import ResidualAttention3DUnet, MTLResidualAttention3DUnet, MTLResidualAttentionRecon3DUnet
from train_model        import train_model
from test_model         import test_model
from train_model_base   import train_model_base
from test_model_base    import test_model_base

In [2]:
# --- Run switches (1 = enabled, 0 = skipped) ---
TRAIN = 1
TEST = 1

# --- Model-variant selection switches ---
BASE_CASE = 1
AUX_SEGMENT = 1
AUX_RECONSTRUCT = 1

# --- Settings dictionary handed to the train/test helpers ---
params = dict(
    BATCH_SIZE=2,
    MAX_EPOCHS=2,
    VAL_INTERVAL=1,
    PRINT_INTERVAL=1,
)

# Fix the random seed through MONAI so runs are reproducible
set_determinism(seed=2056)

# Locate the data and split it into train / validation / test file lists
# NOTE(review): the meaning of scale=28 is defined inside split_data — confirm
img_path = Path("../data")
train_files, val_files, test_files = split_data(img_path, scale=28)

# Build the transform pipelines and the prediction/label post-processing objects
(
    train_transforms,
    val_transforms,
    pred_main,
    label_main,
    pred_aux,
    label_aux,
) = get_transforms()

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Organ names for the segmentation task; list position is used as the index in organs['dict']
all_organs = [
    "Background",
    "Bladder",
    "Bone",
    "Obturator internus",
    "Transition zone",
    "Central gland",
    "Rectum",
    "Seminal vesicle",
    "Neurovascular bundle",
]
organs = {
    'all': all_organs,
    'main': ["Transition zone", "Central gland"],
    # Filled in later by the auxiliary-task cells
    'aux': [],
    # Name -> position lookup over all_organs
    'dict': dict(zip(all_organs, range(len(all_organs)))),
}

----------------------------------------
Splitting data into train-validate-test sets...
The file does not exist
The file does not exist
Images have been divided into train-validate-test sets.
Total number of images:  585
Number of images train-validate-test:  16 - 2 - 2
----------------------------------------
Creating transformations...
Transforms have been defined.


## BASE CASE

In [3]:
# Base-case model: one output channel per main organ plus one extra channel
# (presumably background — confirm against the label transforms).
# The model is always built so that later cells can reference `model`.
model = ResidualAttention3DUnet(
    in_channels=1,
    out_channels=len(organs['main']) + 1,
    device=device,
).to(device)

# Train only when both the global TRAIN switch and the BASE_CASE switch are on.
if TRAIN and BASE_CASE:
    torch.cuda.empty_cache()
    train_model_base(
        model, device, params,
        train_files, train_transforms,
        val_files, val_transforms,
        organs, pred_main, label_main,
    )

--------------------
Starting model training...
--------------------
Epoch 1 / 2


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-25 23:46:02,766 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:
  t = cls([], dtype=storage.dtype, device=storage.device)



Epoch 1 average dice loss for main task: 1.0000
----------------------------------------
Testing on validation data...


  ret = func(*args, **kwargs)
  if storage.is_cuda:


saved new best metric model

Current epoch: 1 current mean dice for main task: 0.0000
Best mean dice for main task: 0.0000 at epoch: 1
Done training! Best mean dice: 0.0000 at epoch: 1


In [4]:
# Evaluate the base-case `model` from the previous cell on the test files.
# Runs only when both the global TEST switch and the BASE_CASE switch are on.
if TEST and BASE_CASE:
    torch.cuda.empty_cache()
    test_model_base(
        model, device, params,
        test_files, val_transforms,
        organs, pred_main, label_main,
    )

----------------------------------------
Starting model testing...


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-25 23:47:36,116 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:



Mean dice for main task: 0.0000


## AUXILIARY - SEGMENT

In [5]:
# Configure the auxiliary segmentation task: extra organs and the task tag
# that is passed to the train/test helpers through `params`.
organs['aux'] = ["Rectum", "Seminal vesicle", "Neurovascular bundle"]
params['TASK'] = 'SEGMENT'

# Multi-task model with separate output widths for the main and auxiliary heads;
# each gets one channel per organ plus one extra channel (presumably background — confirm).
# The model is always built so that the following test cell can reference `model`.
model = MTLResidualAttention3DUnet(
    in_channels=1,
    main_out_channels=len(organs['main']) + 1,
    aux_out_channels=len(organs['aux']) + 1,
    device=device,
).to(device)

# Train only when both the global TRAIN switch and the AUX_SEGMENT switch are on.
if TRAIN and AUX_SEGMENT:
    torch.cuda.empty_cache()
    train_model(
        model, device, params,
        train_files, train_transforms,
        val_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

--------------------
Starting model training...
--------------------
Epoch 1 / 2


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-25 23:48:06,830 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:



Epoch 1 average loss for main task: 1.0000

Epoch 1 average loss for aux task: 1.0000

Epoch 1 average total loss for both tasks: 2.6000
----------------------------------------
Testing on validation data...


  ret = func(*args, **kwargs)
  if storage.is_cuda:


saved new best metric model

Current epoch: 1 current mean dice for main task: 0.0000
Best mean dice for main task: 0.0000 at epoch: 1
Current epoch: 1 current mean metric for aux task: 0.0000
Done training! Best mean dice: 0.0000 at epoch: 1


In [6]:
# Evaluate the auxiliary-segmentation `model` from the previous cell on the test files.
# Relies on organs['aux'] and params['TASK'] as set by that cell.
# Runs only when both the global TEST switch and the AUX_SEGMENT switch are on.
if TEST and AUX_SEGMENT:
    torch.cuda.empty_cache()
    test_model(
        model, device, params,
        test_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

----------------------------------------
Starting model testing...


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-25 23:52:41,426 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:



Mean dice for main task: 0.0000
Mean metric for aux task: 0.0000


## AUXILIARY - RECONSTRUCT

In [7]:
# Configure the auxiliary reconstruction task: no auxiliary organs, and the
# task tag that is passed to the train/test helpers through `params`.
organs['aux'] = []
params['TASK'] = 'RECONSTRUCT'

# Reconstruction multi-task model; segmentation output has one channel per main
# organ plus one extra channel (presumably background — confirm).
# The model is always built so that the following test cell can reference `model`.
model = MTLResidualAttentionRecon3DUnet(
    in_channels=1,
    out_channels=len(organs['main']) + 1,
    device=device,
).to(device)

# Train only when both the global TRAIN switch and the AUX_RECONSTRUCT switch are on.
if TRAIN and AUX_RECONSTRUCT:
    torch.cuda.empty_cache()
    train_model(
        model, device, params,
        train_files, train_transforms,
        val_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

--------------------
Starting model training...
--------------------
Epoch 1 / 2


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-25 23:53:53,978 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:



Epoch 1 average loss for main task: 1.0000

Epoch 1 average loss for aux task: 0.2280

Epoch 1 average total loss for both tasks: 1.4420
----------------------------------------
Testing on validation data...


  ret = func(*args, **kwargs)
  if storage.is_cuda:


saved new best metric model

Current epoch: 1 current mean dice for main task: 0.0000
Best mean dice for main task: 0.0000 at epoch: 1
Current epoch: 1 current mean metric for aux task: 0.1334
Done training! Best mean dice: 0.0000 at epoch: 1


In [8]:
# Re-assert the reconstruction-task settings so this cell is safe to re-run.
organs['aux'] = []
params['TASK'] = 'RECONSTRUCT'

# Deliberately reuse the `model` object from the preceding training cell, as the
# other test cells do. Building a new MTLResidualAttentionRecon3DUnet here would
# hand test_model a freshly initialised network instead of the trained one.
# NOTE(review): if test_model reloads a checkpoint from disk this makes no
# difference — confirm against test_model.
# Runs only when both the global TEST switch and the AUX_RECONSTRUCT switch are on.
if TEST and AUX_RECONSTRUCT:
    torch.cuda.empty_cache()
    test_model(
        model, device, params,
        test_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

----------------------------------------
Starting model testing...


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-25 23:56:12,664 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:



Mean dice for main task: 0.0000
Mean metric for aux task: 0.2286
