In [3]:
import torch
from pathlib import Path
from monai.utils        import set_determinism  
from split_data         import split_data
from transforms         import get_transforms
from model              import ResidualAttention3DUnet, MTLResidualAttention3DUnet, MTLResidualAttentionRecon3DUnet
from train_model        import train_model
from test_model         import test_model
from train_model_base   import train_model_base
from test_model_base    import test_model_base

# Stage switches: any truthy value enables the stage
TRAIN           = 1
TEST            = 1

# Model-variant switches, one per experiment section of this notebook
BASE_CASE       = 1
AUX_SEGMENT     = 1
AUX_RECONSTRUCT = 1

# Settings handed to the train/test helpers.
# VAL_INT / PRINT_INT are presumably validation and logging intervals
# (in epochs) -- TODO confirm against train_model / train_model_base.
params = dict(
    BATCH_SIZE = 2,
    MAX_EPOCHS = 100,
    VAL_INT    = 10,
    PRINT_INT  = 10,
)

# Seed MONAI's determinism helper so repeated runs are reproducible
set_determinism(seed=2056)

# Locate the dataset and partition it into train / validation / test file lists
img_path = Path("..") / "data"
train_files, val_files, test_files = split_data(img_path, scale=28)

# Build the training/validation transforms together with the prediction/label
# handlers for the main and auxiliary tasks (presumably post-processing
# transforms -- TODO confirm in transforms.get_transforms)
(
    train_transforms,
    val_transforms,
    pred_main,
    label_main,
    pred_aux,
    label_aux,
) = get_transforms()

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Names of every structure in the segmentation task, in label order
all_organs = [
    "Background",
    "Bladder",
    "Bone",
    "Obturator internus",
    "Transition zone",
    "Central gland",
    "Rectum",
    "Seminal vesicle",
    "Neurovascular bundle",
]

# Task description shared by all experiments:
#   'all'  - the full list of structure names
#   'main' - structures of the main segmentation task
#   'aux'  - structures of the auxiliary task (assigned by the auxiliary cells)
#   'dict' - maps each structure name to its position in all_organs
organs = {
    'all':  all_organs,
    'main': ["Transition zone", "Central gland"],
    'aux':  [],
    'dict': dict(zip(all_organs, range(len(all_organs)))),
}

----------------------------------------
Splitting data into train-validate-test sets...
The file does not exist
The file does not exist
Images have been divided into train-validate-test sets.
Total number of images:  585
Number of images train-validate-test:  16 - 2 - 2
----------------------------------------
Creating transformations...
Transforms have been defined.


## BASE CASE

In [None]:
# Baseline single-task network. One channel is added on top of the main organs
# (presumably for the background class -- TODO confirm against the label encoding).
model = ResidualAttention3DUnet(in_channels=1, out_channels=len(organs['main']) + 1).to(device)

# Train only when both the global TRAIN switch and this variant's BASE_CASE
# switch are enabled; previously BASE_CASE was defined but never consulted.
if TRAIN and BASE_CASE:
    torch.cuda.empty_cache()  # release cached GPU memory before a long run
    train_model_base(
        model, device, params,
        train_files, train_transforms,
        val_files, val_transforms,
        organs, pred_main, label_main,
    )

In [None]:
# Evaluate the baseline model on the held-out test files. Runs only when both
# the global TEST switch and this variant's BASE_CASE switch are enabled.
# Uses whatever `model` the preceding cell left in the namespace.
if TEST and BASE_CASE:
    torch.cuda.empty_cache()  # release cached GPU memory before evaluation
    test_model_base(
        model, device, params,
        test_files, val_transforms,
        organs, pred_main, label_main,
    )

## AUXILIARY - SEGMENT

In [None]:
# Multi-task variant whose auxiliary head segments additional structures.
# The shared `organs` / `params` dicts are updated in place because the
# train/test helpers read the auxiliary organ list and task name from them.
organs['aux']  = ["Rectum", "Seminal vesicle", "Neurovascular bundle"]
params['TASK'] = 'SEGMENT'

# Each head gets one channel more than its organ list (presumably for the
# background class -- TODO confirm against the label encoding).
model = MTLResidualAttention3DUnet(
    in_channels=1,
    main_out_channels=len(organs['main']) + 1,
    aux_out_channels=len(organs['aux']) + 1,
).to(device)

# Train only when both the global TRAIN switch and this variant's AUX_SEGMENT
# switch are enabled; previously AUX_SEGMENT was defined but never consulted.
if TRAIN and AUX_SEGMENT:
    torch.cuda.empty_cache()  # release cached GPU memory before a long run
    train_model(
        model, device, params,
        train_files, train_transforms,
        val_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

In [None]:
# Evaluate the auxiliary-segmentation model on the held-out test files. Runs
# only when both the global TEST switch and the AUX_SEGMENT switch are enabled.
# Relies on `model`, organs['aux'] and params['TASK'] as set by the preceding cell.
if TEST and AUX_SEGMENT:
    torch.cuda.empty_cache()  # release cached GPU memory before evaluation
    test_model(
        model, device, params,
        test_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

## AUXILIARY - RECONSTRUCT

In [None]:
# Multi-task variant whose auxiliary head performs reconstruction instead of
# segmentation, so the auxiliary organ list is cleared and the task renamed.
organs['aux']  = []
params['TASK'] = 'RECONSTRUCT'

# The main head gets one channel more than the main organ list (presumably for
# the background class -- TODO confirm against the label encoding).
model = MTLResidualAttentionRecon3DUnet(
    in_channels=1,
    main_out_channels=len(organs['main']) + 1,
).to(device)

# Train only when both the global TRAIN switch and this variant's
# AUX_RECONSTRUCT switch are enabled; previously the switch was never consulted.
if TRAIN and AUX_RECONSTRUCT:
    torch.cuda.empty_cache()  # release cached GPU memory before a long run
    train_model(
        model, device, params,
        train_files, train_transforms,
        val_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

In [None]:
# Evaluate the auxiliary-reconstruction model on the held-out test files. Runs
# only when both the global TEST switch and the AUX_RECONSTRUCT switch are enabled.
# Relies on `model`, organs['aux'] and params['TASK'] as set by the preceding cell.
if TEST and AUX_RECONSTRUCT:
    torch.cuda.empty_cache()  # release cached GPU memory before evaluation
    test_model(
        model, device, params,
        test_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )