In [1]:
import torch
from pathlib import Path
from monai.utils        import set_determinism  
from split_data         import split_data
from transforms         import get_transforms
from model              import ResidualAttention3DUnet, MTLResidualAttention3DUnet, MTLResidualAttentionRecon3DUnet
from train_model        import train_model
from test_model         import test_model
from train_model_base   import train_model_base
from test_model_base    import test_model_base

# ---------------------------------------------------------------------------
# Run configuration (1 = enabled, 0 = disabled)
# ---------------------------------------------------------------------------
# Phases to execute for every selected model variant
TRAIN           = 0
TEST            = 1

# Model variants to run
BASE_CASE       = 1
AUX_SEGMENT_3   = 0
AUX_SEGMENT_6   = 0
AUX_RECONSTRUCT = 0

# Settings dictionary handed to the train/test routines
params = dict(
    BATCH_SIZE=2,
    MAX_EPOCHS=100,
    VAL_INT=10,
    PRINT_INT=10,
)

# Seed everything up front so repeated runs are reproducible
set_determinism(seed=2056)

# Locate the data directory and split it into train / validation / test files
img_path = Path("..") / "data"
train_files, val_files, test_files = split_data(img_path)

# Transforms for training and validation, plus the prediction/label
# transforms for the main and the auxiliary task
(
    train_transforms,
    val_transforms,
    pred_main,
    label_main,
    pred_aux,
    label_aux,
) = get_transforms()

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Names of every structure in the segmentation task; a structure's position
# in this list is the index stored for it in organs['dict']
all_organs = [
    "Background",
    "Bladder",
    "Bone",
    "Obturator internus",
    "Transition zone",
    "Central gland",
    "Rectum",
    "Seminal vesicle",
    "Neurovascular bundle",
]
organs = {
    'all': all_organs,
    'main': ["Transition zone", "Central gland"],
    'dict': dict(zip(all_organs, range(len(all_organs)))),
}

KeyboardInterrupt: 

In [None]:
############# EXPERIMENTS #############
# Every enabled variant goes through the same steps: configure the shared
# `organs` / `params` dictionaries, build the network, then optionally train
# and/or test it. Only the data in EXPERIMENTS differs between variants, so the
# steps are written once instead of being copy-pasted per variant.

def _build_model(task, n_aux_organs):
    """Instantiate the network for one experiment and move it to `device`.

    Args:
        task: Task tag of the experiment: 'BASE_CASE', 'SEGMENT' or 'RECONSTRUCT'.
        n_aux_organs: Number of auxiliary structures; only used for 'SEGMENT',
            where the auxiliary head gets `n_aux_organs + 1` output channels.

    Returns:
        The freshly constructed model, already moved to `device`.

    Raises:
        ValueError: If `task` is not one of the three known tags.
    """
    # One output channel per main structure plus one extra
    # (presumably the background class - TODO confirm).
    n_main_out = len(organs['main']) + 1

    if task == 'BASE_CASE':
        net = ResidualAttention3DUnet(
            in_channels=1, out_channels=n_main_out, device=device)
    elif task == 'SEGMENT':
        net = MTLResidualAttention3DUnet(
            in_channels=1, main_out_channels=n_main_out,
            aux_out_channels=n_aux_organs + 1, device=device)
    elif task == 'RECONSTRUCT':
        net = MTLResidualAttentionRecon3DUnet(
            in_channels=1, out_channels=n_main_out, device=device)
    else:
        raise ValueError(f"Unknown task: {task!r}")
    return net.to(device)


def _train_and_test(net, model_name, multitask):
    """Train and/or test `net`, depending on the TRAIN and TEST flags.

    Args:
        net: Model to train/test (already on `device`).
        model_name: Name passed through to the train/test routines.
        multitask: If true, use `train_model` / `test_model` and additionally
            pass `pred_aux` and `label_aux`; otherwise use the `*_base`
            routines, which take no auxiliary transforms.
    """
    if multitask:
        train_fn, test_fn, aux_args = train_model, test_model, (pred_aux, label_aux)
    else:
        train_fn, test_fn, aux_args = train_model_base, test_model_base, ()

    if TRAIN:
        train_fn(net, device, params, train_files, train_transforms,
                 val_files, val_transforms, organs, pred_main, label_main,
                 *aux_args, model_name)
    if TEST:
        # NOTE(review): the saved output of this cell shows load_state_dict
        # failing for 'base_case' (checkpoint keys carry a `_main` suffix and
        # twice the channel width of the current model). Presumably the stored
        # checkpoint was produced by a different architecture - retrain or
        # point to a matching checkpoint before testing.
        test_fn(net, device, params, test_files, val_transforms,
                organs, pred_main, label_main,
                *aux_args, model_name)


# (enabled flag, model_name, params['TASK'], organs['aux'])
EXPERIMENTS = [
    # Base case: main segmentation only
    (BASE_CASE, 'base_case', 'BASE_CASE', []),
    # Auxiliary task: segment 3 extra structures
    (AUX_SEGMENT_3, 'auxiliary_segment_3', 'SEGMENT',
     ["Rectum", "Seminal vesicle", "Neurovascular bundle"]),
    # Auxiliary task: segment 6 extra structures
    (AUX_SEGMENT_6, 'auxiliary_segment_6', 'SEGMENT',
     ["Rectum", "Seminal vesicle", "Neurovascular bundle", "Bladder", "Bone", "Obturator internus"]),
    # Auxiliary task: reconstruction
    (AUX_RECONSTRUCT, 'auxiliary_reconstruct', 'RECONSTRUCT', []),
]

# Iterate over enabled experiments only, so `model_name` and `model` end up
# referring to the last variant that actually ran.
for _enabled, model_name, _task, _aux_organs in [spec for spec in EXPERIMENTS if spec[0]]:
    # The train/test routines read the experiment setup from these shared dicts.
    organs['aux']  = list(_aux_organs)
    params['TASK'] = _task

    # Bind `model` before training/testing so it can be inspected if a run fails.
    model = _build_model(_task, len(_aux_organs))
    _train_and_test(model, model_name, multitask=(_task != 'BASE_CASE'))
    


RuntimeError: Error(s) in loading state_dict for ResidualAttention3DUnet:
	Missing key(s) in state_dict: "attention_blocks.0.W_g.up_sample.weight", "attention_blocks.0.W_g_norm.weight", "attention_blocks.0.W_g_norm.bias", "attention_blocks.0.W_x.weight", "attention_blocks.0.W_x_norm.weight", "attention_blocks.0.W_x_norm.bias", "attention_blocks.0.phi.weight", "attention_blocks.0.final_norm.weight", "attention_blocks.0.final_norm.bias", "attention_blocks.1.W_g.up_sample.weight", "attention_blocks.1.W_g_norm.weight", "attention_blocks.1.W_g_norm.bias", "attention_blocks.1.W_x.weight", "attention_blocks.1.W_x_norm.weight", "attention_blocks.1.W_x_norm.bias", "attention_blocks.1.phi.weight", "attention_blocks.1.final_norm.weight", "attention_blocks.1.final_norm.bias", "attention_blocks.2.W_g.up_sample.weight", "attention_blocks.2.W_g_norm.weight", "attention_blocks.2.W_g_norm.bias", "attention_blocks.2.W_x.weight", "attention_blocks.2.W_x_norm.weight", "attention_blocks.2.W_x_norm.bias", "attention_blocks.2.phi.weight", "attention_blocks.2.final_norm.weight", "attention_blocks.2.final_norm.bias", "attention_blocks.3.W_g.up_sample.weight", "attention_blocks.3.W_g_norm.weight", "attention_blocks.3.W_g_norm.bias", "attention_blocks.3.W_x.weight", "attention_blocks.3.W_x_norm.weight", "attention_blocks.3.W_x_norm.bias", "attention_blocks.3.phi.weight", "attention_blocks.3.final_norm.weight", "attention_blocks.3.final_norm.bias", "upsamples.0.up_sample.weight", "upsamples.1.up_sample.weight", "upsamples.2.up_sample.weight", "upsamples.3.up_sample.weight", "up_conv.0.first_conv.weight", "up_conv.0.first_norm.weight", "up_conv.0.first_norm.bias", "up_conv.0.second_conv.weight", "up_conv.0.second_norm.weight", "up_conv.0.second_norm.bias", "up_conv.0.shortcut.weight", "up_conv.1.first_conv.weight", "up_conv.1.first_norm.weight", "up_conv.1.first_norm.bias", "up_conv.1.second_conv.weight", "up_conv.1.second_norm.weight", "up_conv.1.second_norm.bias", "up_conv.1.shortcut.weight", "up_conv.2.first_conv.weight", "up_conv.2.first_norm.weight", 
"up_conv.2.first_norm.bias", "up_conv.2.second_conv.weight", "up_conv.2.second_norm.weight", "up_conv.2.second_norm.bias", "up_conv.2.shortcut.weight", "up_conv.3.first_conv.weight", "up_conv.3.first_norm.weight", "up_conv.3.first_norm.bias", "up_conv.3.second_conv.weight", "up_conv.3.second_norm.weight", "up_conv.3.second_norm.bias", "up_conv.3.shortcut.weight", "final_conv.weight". 
	Unexpected key(s) in state_dict: "attention_blocks_main.0.W_g.up_sample.weight", "attention_blocks_main.0.W_g_norm.weight", "attention_blocks_main.0.W_g_norm.bias", "attention_blocks_main.0.W_x.weight", "attention_blocks_main.0.W_x_norm.weight", "attention_blocks_main.0.W_x_norm.bias", "attention_blocks_main.0.phi.weight", "attention_blocks_main.0.final_norm.weight", "attention_blocks_main.0.final_norm.bias", "attention_blocks_main.1.W_g.up_sample.weight", "attention_blocks_main.1.W_g_norm.weight", "attention_blocks_main.1.W_g_norm.bias", "attention_blocks_main.1.W_x.weight", "attention_blocks_main.1.W_x_norm.weight", "attention_blocks_main.1.W_x_norm.bias", "attention_blocks_main.1.phi.weight", "attention_blocks_main.1.final_norm.weight", "attention_blocks_main.1.final_norm.bias", "attention_blocks_main.2.W_g.up_sample.weight", "attention_blocks_main.2.W_g_norm.weight", "attention_blocks_main.2.W_g_norm.bias", "attention_blocks_main.2.W_x.weight", "attention_blocks_main.2.W_x_norm.weight", "attention_blocks_main.2.W_x_norm.bias", "attention_blocks_main.2.phi.weight", "attention_blocks_main.2.final_norm.weight", "attention_blocks_main.2.final_norm.bias", "attention_blocks_main.3.W_g.up_sample.weight", "attention_blocks_main.3.W_g_norm.weight", "attention_blocks_main.3.W_g_norm.bias", "attention_blocks_main.3.W_x.weight", "attention_blocks_main.3.W_x_norm.weight", "attention_blocks_main.3.W_x_norm.bias", "attention_blocks_main.3.phi.weight", "attention_blocks_main.3.final_norm.weight", "attention_blocks_main.3.final_norm.bias", "upsamples_main.0.up_sample.weight", "upsamples_main.1.up_sample.weight", "upsamples_main.2.up_sample.weight", "upsamples_main.3.up_sample.weight", "up_conv_main.0.first_conv.weight", "up_conv_main.0.first_norm.weight", "up_conv_main.0.first_norm.bias", "up_conv_main.0.second_conv.weight", "up_conv_main.0.second_norm.weight", "up_conv_main.0.second_norm.bias", "up_conv_main.0.shortcut.weight", "up_conv_main.1.first_conv.weight", 
"up_conv_main.1.first_norm.weight", "up_conv_main.1.first_norm.bias", "up_conv_main.1.second_conv.weight", "up_conv_main.1.second_norm.weight", "up_conv_main.1.second_norm.bias", "up_conv_main.1.shortcut.weight", "up_conv_main.2.first_conv.weight", "up_conv_main.2.first_norm.weight", "up_conv_main.2.first_norm.bias", "up_conv_main.2.second_conv.weight", "up_conv_main.2.second_norm.weight", "up_conv_main.2.second_norm.bias", "up_conv_main.2.shortcut.weight", "up_conv_main.3.first_conv.weight", "up_conv_main.3.first_norm.weight", "up_conv_main.3.first_norm.bias", "up_conv_main.3.second_conv.weight", "up_conv_main.3.second_norm.weight", "up_conv_main.3.second_norm.bias", "up_conv_main.3.shortcut.weight", "final_conv_main.weight". 
	size mismatch for down_conv.0.first_conv.weight: copying a param with shape torch.Size([32, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([16, 1, 3, 3, 3]).
	size mismatch for down_conv.0.first_norm.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for down_conv.0.first_norm.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for down_conv.0.second_conv.weight: copying a param with shape torch.Size([32, 32, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([16, 16, 3, 3, 3]).
	size mismatch for down_conv.0.second_norm.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for down_conv.0.second_norm.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for down_conv.0.shortcut.weight: copying a param with shape torch.Size([32, 1, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([16, 1, 1, 1, 1]).
	size mismatch for down_conv.1.first_conv.weight: copying a param with shape torch.Size([64, 32, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 16, 3, 3, 3]).
	size mismatch for down_conv.1.first_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for down_conv.1.first_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for down_conv.1.second_conv.weight: copying a param with shape torch.Size([64, 64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3, 3]).
	size mismatch for down_conv.1.second_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for down_conv.1.second_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for down_conv.1.shortcut.weight: copying a param with shape torch.Size([64, 32, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 16, 1, 1, 1]).
	size mismatch for down_conv.2.first_conv.weight: copying a param with shape torch.Size([128, 64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 32, 3, 3, 3]).
	size mismatch for down_conv.2.first_norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for down_conv.2.first_norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for down_conv.2.second_conv.weight: copying a param with shape torch.Size([128, 128, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3, 3]).
	size mismatch for down_conv.2.second_norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for down_conv.2.second_norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for down_conv.2.shortcut.weight: copying a param with shape torch.Size([128, 64, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 32, 1, 1, 1]).
	size mismatch for down_conv.3.first_conv.weight: copying a param with shape torch.Size([256, 128, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3, 3]).
	size mismatch for down_conv.3.first_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for down_conv.3.first_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for down_conv.3.second_conv.weight: copying a param with shape torch.Size([256, 256, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3, 3]).
	size mismatch for down_conv.3.second_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for down_conv.3.second_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for down_conv.3.shortcut.weight: copying a param with shape torch.Size([256, 128, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 64, 1, 1, 1]).
	size mismatch for bottleneck.first_conv.weight: copying a param with shape torch.Size([512, 256, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3, 3]).
	size mismatch for bottleneck.first_norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for bottleneck.first_norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for bottleneck.second_conv.weight: copying a param with shape torch.Size([512, 512, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3, 3]).
	size mismatch for bottleneck.second_norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for bottleneck.second_norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for bottleneck.shortcut.weight: copying a param with shape torch.Size([512, 256, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1, 1]).