In [1]:
import torch
from pathlib import Path
from monai.utils        import set_determinism  
from split_data         import split_data
from transforms         import get_transforms
from model              import ResidualAttention3DUnet, MTLResidualAttention3DUnet, MTLResidualAttentionRecon3DUnet
from train_model        import train_model
from test_model         import test_model
from train_model_base   import train_model_base
from test_model_base    import test_model_base

In [2]:
# ---------------------------------------------------------------------------
# Run configuration: everything a reader might want to tune lives in this cell.
# ---------------------------------------------------------------------------

# Stage switches (truthy = run the stage, 0 = skip it)
TRAIN           = 1
TEST            = 1

# Per-variant switches: base case, auxiliary segmentation, auxiliary reconstruction
BASE_CASE       = 1
AUX_SEGMENT     = 1
AUX_RECONSTRUCT = 1

# Hyper-parameters handed to the train/test helpers as a single dict.
# Later cells add a 'TASK' entry to this same dict before calling the helpers.
params = dict(
    BATCH_SIZE=2,
    MAX_EPOCHS=2,
    VAL_INTERVAL=1,
    PRINT_INTERVAL=1,
)

# Seed MONAI's determinism helper *before* the data split so that everything
# below runs under a fixed seed.
set_determinism(seed=2056)

# Locate the data and split it into train / validation / test file lists
img_path = Path("..") / "data"
train_files, val_files, test_files = split_data(img_path, scale=28)

# Build the transform pipelines and the prediction/label post-processing
# objects for the main and auxiliary tasks
(
    train_transforms,
    val_transforms,
    pred_main,
    label_main,
    pred_aux,
    label_aux,
) = get_transforms()

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Organ names for the segmentation task. The position of a name in this list
# becomes its index in organs['dict'] (presumably the label value used in the
# segmentation masks; confirm against the data).
all_organs = [
    "Background",
    "Bladder",
    "Bone",
    "Obturator internus",
    "Transition zone",
    "Central gland",
    "Rectum",
    "Seminal vesicle",
    "Neurovascular bundle",
]
organs = {
    'all': all_organs,
    'main': ["Transition zone", "Central gland"],
    'aux': [],  # filled in by the auxiliary-task cells further down
    'dict': {name: position for position, name in enumerate(all_organs)},
}

----------------------------------------
Splitting data into train-validate-test sets...
The file does not exist
The file does not exist
Images have been divided into train-validate-test sets.
Total number of images:  585
Number of images train-validate-test:  16 - 2 - 2
----------------------------------------
Creating transformations...
Transforms have been defined.


## BASE CASE

In [None]:
# Base case: residual attention 3D U-Net trained on the main structures only.
# The network gets one output channel per entry in organs['main'] plus one
# extra channel (presumably background; confirm against the model/loss).
# The model is built unconditionally so that the test cell below always has a
# `model` object to work with, even when training is skipped.
model = ResidualAttention3DUnet(
    in_channels=1,
    out_channels=len(organs['main']) + 1,
    device=device,
).to(device)

# Honour both switches from the configuration cell: the global TRAIN switch and
# the per-variant BASE_CASE switch (previously BASE_CASE was never consulted).
if TRAIN and BASE_CASE:
    torch.cuda.empty_cache()  # drop cached GPU memory left over from earlier cells
    train_model_base(
        model, device, params,
        train_files, train_transforms,
        val_files, val_transforms,
        organs, pred_main, label_main,
    )

In [None]:
# Evaluate the base-case model on the held-out test files, reusing the
# validation transforms. Runs only when both the global TEST switch and the
# per-variant BASE_CASE switch are enabled (previously BASE_CASE was ignored).
if TEST and BASE_CASE:
    torch.cuda.empty_cache()  # drop cached GPU memory before evaluation
    test_model_base(
        model, device, params,
        test_files, val_transforms,
        organs, pred_main, label_main,
    )

## AUXILIARY - SEGMENT

In [7]:
# Auxiliary task = segmentation of extra structures alongside the main ones.
# The shared `organs` / `params` dicts are updated in place; the train and test
# helpers receive these same objects, so this set-up is kept unconditional.
organs['aux'] = ["Rectum", "Seminal vesicle", "Neurovascular bundle"]
params['TASK'] = 'SEGMENT'

# Multi-task network with separate output sizes for the two tasks: one channel
# per listed organ plus one extra channel each (presumably background; confirm).
model = MTLResidualAttention3DUnet(
    in_channels=1,
    main_out_channels=len(organs['main']) + 1,
    aux_out_channels=len(organs['aux']) + 1,
    device=device,
).to(device)

# Honour both the global TRAIN switch and the per-variant AUX_SEGMENT switch
# (previously AUX_SEGMENT was never consulted).
if TRAIN and AUX_SEGMENT:
    torch.cuda.empty_cache()  # drop cached GPU memory left over from earlier cells
    train_model(
        model, device, params,
        train_files, train_transforms,
        val_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

--------------------
Starting model training...
--------------------
Epoch 1 / 2
2023-04-25 17:08:54,091 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
  ret = func(*args, **kwargs)
  if storage.is_cuda:



Epoch 1 average loss for main task: 1.0000

Epoch 1 average loss for aux task: 1.0000

Epoch 1 average total loss for both tasks: 2.6000
----------------------------------------
Testing on validation data...


  ret = func(*args, **kwargs)
  if storage.is_cuda:


saved new best metric model

Current epoch: 1 current mean dice for main task: 0.0000
Best mean dice for main task: 0.0000 at epoch: 1
Current epoch: 1 current mean metric for aux task: 0.0000
Done training! Best mean dice: 0.0000 at epoch: 1


In [8]:
# Evaluate the auxiliary-segmentation model on the held-out test files.
# Relies on `model`, organs['aux'] and params['TASK'] as set by the training
# cell directly above, so run that cell first.
# Runs only when both TEST and AUX_SEGMENT are enabled (previously AUX_SEGMENT
# was ignored).
if TEST and AUX_SEGMENT:
    torch.cuda.empty_cache()  # drop cached GPU memory before evaluation
    test_model(
        model, device, params,
        test_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

----------------------------------------
Starting model testing...


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-25 17:16:41,198 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:



Mean dice for main task: 0.0000
Mean metric for aux task: 0.0000


## AUXILIARY - RECONSTRUCT

In [4]:
# Auxiliary task = reconstruction, so there are no auxiliary organs to segment.
# The shared `organs` / `params` dicts are updated in place; the train and test
# helpers receive these same objects, so this set-up is kept unconditional.
organs['aux'] = []
params['TASK'] = 'RECONSTRUCT'

# Multi-task network with a reconstruction branch. The segmentation output has
# one channel per main organ plus one extra channel (presumably background;
# confirm against the model/loss).
model = MTLResidualAttentionRecon3DUnet(
    in_channels=1,
    out_channels=len(organs['main']) + 1,
    device=device,
).to(device)

# Honour both the global TRAIN switch and the per-variant AUX_RECONSTRUCT
# switch (previously AUX_RECONSTRUCT was never consulted).
if TRAIN and AUX_RECONSTRUCT:
    torch.cuda.empty_cache()  # drop cached GPU memory left over from earlier cells
    train_model(
        model, device, params,
        train_files, train_transforms,
        val_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

--------------------
Starting model training...
--------------------
Epoch 1 / 2


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-04-25 20:04:50,226 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


  ret = func(*args, **kwargs)
  if storage.is_cuda:
  t = cls([], dtype=storage.dtype, device=storage.device)



Epoch 1 average loss for main task: 1.0000

Epoch 1 average loss for aux task: 0.1795

Epoch 1 average total loss for both tasks: 1.3693
----------------------------------------
Testing on validation data...


  ret = func(*args, **kwargs)
  if storage.is_cuda:


In [None]:
# Re-establish the reconstruction set-up so this cell can be run on its own:
# no auxiliary organs, TASK switched to reconstruction, and a freshly built
# model for the test helper to work with.
organs['aux'] = []
params['TASK'] = 'RECONSTRUCT'
model = MTLResidualAttentionRecon3DUnet(
    in_channels=1,
    out_channels=len(organs['main']) + 1,
    device=device,
).to(device)

# NOTE(review): the output recorded for this cell is a load_state_dict
# RuntimeError. The checkpoint that was loaded carries `*_main` / `*_aux`
# decoder keys and an encoder twice as wide (32 vs. 16 base channels) as this
# model, so it does not look like a checkpoint of this architecture.
# Presumably test_model reads a checkpoint path that is shared between tasks
# and the reconstruction training run above did not save one. Confirm in
# train_model / test_model before trusting any result from this cell.

# Runs only when both TEST and AUX_RECONSTRUCT are enabled (previously
# AUX_RECONSTRUCT was ignored).
if TEST and AUX_RECONSTRUCT:
    torch.cuda.empty_cache()  # drop cached GPU memory before evaluation
    test_model(
        model, device, params,
        test_files, val_transforms,
        organs, pred_main, label_main, pred_aux, label_aux,
    )

RuntimeError: Error(s) in loading state_dict for MTLResidualAttentionRecon3DUnet:
	Missing key(s) in state_dict: "attention_blocks.0.W_g.up_sample.weight", "attention_blocks.0.W_g_norm.weight", "attention_blocks.0.W_g_norm.bias", "attention_blocks.0.W_x.weight", "attention_blocks.0.W_x_norm.weight", "attention_blocks.0.W_x_norm.bias", "attention_blocks.0.phi.weight", "attention_blocks.0.final_norm.weight", "attention_blocks.0.final_norm.bias", "attention_blocks.1.W_g.up_sample.weight", "attention_blocks.1.W_g_norm.weight", "attention_blocks.1.W_g_norm.bias", "attention_blocks.1.W_x.weight", "attention_blocks.1.W_x_norm.weight", "attention_blocks.1.W_x_norm.bias", "attention_blocks.1.phi.weight", "attention_blocks.1.final_norm.weight", "attention_blocks.1.final_norm.bias", "attention_blocks.2.W_g.up_sample.weight", "attention_blocks.2.W_g_norm.weight", "attention_blocks.2.W_g_norm.bias", "attention_blocks.2.W_x.weight", "attention_blocks.2.W_x_norm.weight", "attention_blocks.2.W_x_norm.bias", "attention_blocks.2.phi.weight", "attention_blocks.2.final_norm.weight", "attention_blocks.2.final_norm.bias", "attention_blocks.3.W_g.up_sample.weight", "attention_blocks.3.W_g_norm.weight", "attention_blocks.3.W_g_norm.bias", "attention_blocks.3.W_x.weight", "attention_blocks.3.W_x_norm.weight", "attention_blocks.3.W_x_norm.bias", "attention_blocks.3.phi.weight", "attention_blocks.3.final_norm.weight", "attention_blocks.3.final_norm.bias", "upsamples.0.up_sample.weight", "upsamples.1.up_sample.weight", "upsamples.2.up_sample.weight", "upsamples.3.up_sample.weight", "up_conv.0.first_conv.weight", "up_conv.0.first_norm.weight", "up_conv.0.first_norm.bias", "up_conv.0.second_conv.weight", "up_conv.0.second_norm.weight", "up_conv.0.second_norm.bias", "up_conv.0.shortcut.weight", "up_conv.1.first_conv.weight", "up_conv.1.first_norm.weight", "up_conv.1.first_norm.bias", "up_conv.1.second_conv.weight", "up_conv.1.second_norm.weight", "up_conv.1.second_norm.bias", "up_conv.1.shortcut.weight", "up_conv.2.first_conv.weight", "up_conv.2.first_norm.weight", 
"up_conv.2.first_norm.bias", "up_conv.2.second_conv.weight", "up_conv.2.second_norm.weight", "up_conv.2.second_norm.bias", "up_conv.2.shortcut.weight", "up_conv.3.first_conv.weight", "up_conv.3.first_norm.weight", "up_conv.3.first_norm.bias", "up_conv.3.second_conv.weight", "up_conv.3.second_norm.weight", "up_conv.3.second_norm.bias", "up_conv.3.shortcut.weight", "final_conv.weight", "attention_blocks_recon.0.W_g.up_sample.weight", "attention_blocks_recon.0.W_g_norm.weight", "attention_blocks_recon.0.W_g_norm.bias", "attention_blocks_recon.0.W_x.weight", "attention_blocks_recon.0.W_x_norm.weight", "attention_blocks_recon.0.W_x_norm.bias", "attention_blocks_recon.0.phi.weight", "attention_blocks_recon.0.final_norm.weight", "attention_blocks_recon.0.final_norm.bias", "attention_blocks_recon.1.W_g.up_sample.weight", "attention_blocks_recon.1.W_g_norm.weight", "attention_blocks_recon.1.W_g_norm.bias", "attention_blocks_recon.1.W_x.weight", "attention_blocks_recon.1.W_x_norm.weight", "attention_blocks_recon.1.W_x_norm.bias", "attention_blocks_recon.1.phi.weight", "attention_blocks_recon.1.final_norm.weight", "attention_blocks_recon.1.final_norm.bias", "attention_blocks_recon.2.W_g.up_sample.weight", "attention_blocks_recon.2.W_g_norm.weight", "attention_blocks_recon.2.W_g_norm.bias", "attention_blocks_recon.2.W_x.weight", "attention_blocks_recon.2.W_x_norm.weight", "attention_blocks_recon.2.W_x_norm.bias", "attention_blocks_recon.2.phi.weight", "attention_blocks_recon.2.final_norm.weight", "attention_blocks_recon.2.final_norm.bias", "attention_blocks_recon.3.W_g.up_sample.weight", "attention_blocks_recon.3.W_g_norm.weight", "attention_blocks_recon.3.W_g_norm.bias", "attention_blocks_recon.3.W_x.weight", "attention_blocks_recon.3.W_x_norm.weight", "attention_blocks_recon.3.W_x_norm.bias", "attention_blocks_recon.3.phi.weight", "attention_blocks_recon.3.final_norm.weight", "attention_blocks_recon.3.final_norm.bias", "upsamples_recon.0.up_sample.weight", 
"upsamples_recon.1.up_sample.weight", "upsamples_recon.2.up_sample.weight", "upsamples_recon.3.up_sample.weight", "up_conv_recon.0.first_conv.weight", "up_conv_recon.0.first_norm.weight", "up_conv_recon.0.first_norm.bias", "up_conv_recon.0.second_conv.weight", "up_conv_recon.0.second_norm.weight", "up_conv_recon.0.second_norm.bias", "up_conv_recon.0.shortcut.weight", "up_conv_recon.1.first_conv.weight", "up_conv_recon.1.first_norm.weight", "up_conv_recon.1.first_norm.bias", "up_conv_recon.1.second_conv.weight", "up_conv_recon.1.second_norm.weight", "up_conv_recon.1.second_norm.bias", "up_conv_recon.1.shortcut.weight", "up_conv_recon.2.first_conv.weight", "up_conv_recon.2.first_norm.weight", "up_conv_recon.2.first_norm.bias", "up_conv_recon.2.second_conv.weight", "up_conv_recon.2.second_norm.weight", "up_conv_recon.2.second_norm.bias", "up_conv_recon.2.shortcut.weight", "up_conv_recon.3.first_conv.weight", "up_conv_recon.3.first_norm.weight", "up_conv_recon.3.first_norm.bias", "up_conv_recon.3.second_conv.weight", "up_conv_recon.3.second_norm.weight", "up_conv_recon.3.second_norm.bias", "up_conv_recon.3.shortcut.weight", "final_conv_recon.weight". 
	Unexpected key(s) in state_dict: "attention_blocks_main.0.W_g.up_sample.weight", "attention_blocks_main.0.W_g_norm.weight", "attention_blocks_main.0.W_g_norm.bias", "attention_blocks_main.0.W_x.weight", "attention_blocks_main.0.W_x_norm.weight", "attention_blocks_main.0.W_x_norm.bias", "attention_blocks_main.0.phi.weight", "attention_blocks_main.0.final_norm.weight", "attention_blocks_main.0.final_norm.bias", "attention_blocks_main.1.W_g.up_sample.weight", "attention_blocks_main.1.W_g_norm.weight", "attention_blocks_main.1.W_g_norm.bias", "attention_blocks_main.1.W_x.weight", "attention_blocks_main.1.W_x_norm.weight", "attention_blocks_main.1.W_x_norm.bias", "attention_blocks_main.1.phi.weight", "attention_blocks_main.1.final_norm.weight", "attention_blocks_main.1.final_norm.bias", "attention_blocks_main.2.W_g.up_sample.weight", "attention_blocks_main.2.W_g_norm.weight", "attention_blocks_main.2.W_g_norm.bias", "attention_blocks_main.2.W_x.weight", "attention_blocks_main.2.W_x_norm.weight", "attention_blocks_main.2.W_x_norm.bias", "attention_blocks_main.2.phi.weight", "attention_blocks_main.2.final_norm.weight", "attention_blocks_main.2.final_norm.bias", "attention_blocks_main.3.W_g.up_sample.weight", "attention_blocks_main.3.W_g_norm.weight", "attention_blocks_main.3.W_g_norm.bias", "attention_blocks_main.3.W_x.weight", "attention_blocks_main.3.W_x_norm.weight", "attention_blocks_main.3.W_x_norm.bias", "attention_blocks_main.3.phi.weight", "attention_blocks_main.3.final_norm.weight", "attention_blocks_main.3.final_norm.bias", "upsamples_main.0.up_sample.weight", "upsamples_main.1.up_sample.weight", "upsamples_main.2.up_sample.weight", "upsamples_main.3.up_sample.weight", "up_conv_main.0.first_conv.weight", "up_conv_main.0.first_norm.weight", "up_conv_main.0.first_norm.bias", "up_conv_main.0.second_conv.weight", "up_conv_main.0.second_norm.weight", "up_conv_main.0.second_norm.bias", "up_conv_main.0.shortcut.weight", "up_conv_main.1.first_conv.weight", 
"up_conv_main.1.first_norm.weight", "up_conv_main.1.first_norm.bias", "up_conv_main.1.second_conv.weight", "up_conv_main.1.second_norm.weight", "up_conv_main.1.second_norm.bias", "up_conv_main.1.shortcut.weight", "up_conv_main.2.first_conv.weight", "up_conv_main.2.first_norm.weight", "up_conv_main.2.first_norm.bias", "up_conv_main.2.second_conv.weight", "up_conv_main.2.second_norm.weight", "up_conv_main.2.second_norm.bias", "up_conv_main.2.shortcut.weight", "up_conv_main.3.first_conv.weight", "up_conv_main.3.first_norm.weight", "up_conv_main.3.first_norm.bias", "up_conv_main.3.second_conv.weight", "up_conv_main.3.second_norm.weight", "up_conv_main.3.second_norm.bias", "up_conv_main.3.shortcut.weight", "final_conv_main.weight", "attention_blocks_aux.0.W_g.up_sample.weight", "attention_blocks_aux.0.W_g_norm.weight", "attention_blocks_aux.0.W_g_norm.bias", "attention_blocks_aux.0.W_x.weight", "attention_blocks_aux.0.W_x_norm.weight", "attention_blocks_aux.0.W_x_norm.bias", "attention_blocks_aux.0.phi.weight", "attention_blocks_aux.0.final_norm.weight", "attention_blocks_aux.0.final_norm.bias", "attention_blocks_aux.1.W_g.up_sample.weight", "attention_blocks_aux.1.W_g_norm.weight", "attention_blocks_aux.1.W_g_norm.bias", "attention_blocks_aux.1.W_x.weight", "attention_blocks_aux.1.W_x_norm.weight", "attention_blocks_aux.1.W_x_norm.bias", "attention_blocks_aux.1.phi.weight", "attention_blocks_aux.1.final_norm.weight", "attention_blocks_aux.1.final_norm.bias", "attention_blocks_aux.2.W_g.up_sample.weight", "attention_blocks_aux.2.W_g_norm.weight", "attention_blocks_aux.2.W_g_norm.bias", "attention_blocks_aux.2.W_x.weight", "attention_blocks_aux.2.W_x_norm.weight", "attention_blocks_aux.2.W_x_norm.bias", "attention_blocks_aux.2.phi.weight", "attention_blocks_aux.2.final_norm.weight", "attention_blocks_aux.2.final_norm.bias", "attention_blocks_aux.3.W_g.up_sample.weight", "attention_blocks_aux.3.W_g_norm.weight", "attention_blocks_aux.3.W_g_norm.bias", 
"attention_blocks_aux.3.W_x.weight", "attention_blocks_aux.3.W_x_norm.weight", "attention_blocks_aux.3.W_x_norm.bias", "attention_blocks_aux.3.phi.weight", "attention_blocks_aux.3.final_norm.weight", "attention_blocks_aux.3.final_norm.bias", "upsamples_aux.0.up_sample.weight", "upsamples_aux.1.up_sample.weight", "upsamples_aux.2.up_sample.weight", "upsamples_aux.3.up_sample.weight", "up_conv_aux.0.first_conv.weight", "up_conv_aux.0.first_norm.weight", "up_conv_aux.0.first_norm.bias", "up_conv_aux.0.second_conv.weight", "up_conv_aux.0.second_norm.weight", "up_conv_aux.0.second_norm.bias", "up_conv_aux.0.shortcut.weight", "up_conv_aux.1.first_conv.weight", "up_conv_aux.1.first_norm.weight", "up_conv_aux.1.first_norm.bias", "up_conv_aux.1.second_conv.weight", "up_conv_aux.1.second_norm.weight", "up_conv_aux.1.second_norm.bias", "up_conv_aux.1.shortcut.weight", "up_conv_aux.2.first_conv.weight", "up_conv_aux.2.first_norm.weight", "up_conv_aux.2.first_norm.bias", "up_conv_aux.2.second_conv.weight", "up_conv_aux.2.second_norm.weight", "up_conv_aux.2.second_norm.bias", "up_conv_aux.2.shortcut.weight", "up_conv_aux.3.first_conv.weight", "up_conv_aux.3.first_norm.weight", "up_conv_aux.3.first_norm.bias", "up_conv_aux.3.second_conv.weight", "up_conv_aux.3.second_norm.weight", "up_conv_aux.3.second_norm.bias", "up_conv_aux.3.shortcut.weight", "final_conv_aux.weight". 
	size mismatch for down_conv.0.first_conv.weight: copying a param with shape torch.Size([32, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([16, 1, 3, 3, 3]).
	size mismatch for down_conv.0.first_norm.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for down_conv.0.first_norm.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for down_conv.0.second_conv.weight: copying a param with shape torch.Size([32, 32, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([16, 16, 3, 3, 3]).
	size mismatch for down_conv.0.second_norm.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for down_conv.0.second_norm.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for down_conv.0.shortcut.weight: copying a param with shape torch.Size([32, 1, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([16, 1, 1, 1, 1]).
	size mismatch for down_conv.1.first_conv.weight: copying a param with shape torch.Size([64, 32, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 16, 3, 3, 3]).
	size mismatch for down_conv.1.first_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for down_conv.1.first_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for down_conv.1.second_conv.weight: copying a param with shape torch.Size([64, 64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3, 3]).
	size mismatch for down_conv.1.second_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for down_conv.1.second_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for down_conv.1.shortcut.weight: copying a param with shape torch.Size([64, 32, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 16, 1, 1, 1]).
	size mismatch for down_conv.2.first_conv.weight: copying a param with shape torch.Size([128, 64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 32, 3, 3, 3]).
	size mismatch for down_conv.2.first_norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for down_conv.2.first_norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for down_conv.2.second_conv.weight: copying a param with shape torch.Size([128, 128, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3, 3]).
	size mismatch for down_conv.2.second_norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for down_conv.2.second_norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for down_conv.2.shortcut.weight: copying a param with shape torch.Size([128, 64, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 32, 1, 1, 1]).
	size mismatch for down_conv.3.first_conv.weight: copying a param with shape torch.Size([256, 128, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3, 3]).
	size mismatch for down_conv.3.first_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for down_conv.3.first_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for down_conv.3.second_conv.weight: copying a param with shape torch.Size([256, 256, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3, 3]).
	size mismatch for down_conv.3.second_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for down_conv.3.second_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for down_conv.3.shortcut.weight: copying a param with shape torch.Size([256, 128, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 64, 1, 1, 1]).
	size mismatch for bottleneck.first_conv.weight: copying a param with shape torch.Size([512, 256, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3, 3]).
	size mismatch for bottleneck.first_norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for bottleneck.first_norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for bottleneck.second_conv.weight: copying a param with shape torch.Size([512, 512, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3, 3]).
	size mismatch for bottleneck.second_norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for bottleneck.second_norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for bottleneck.shortcut.weight: copying a param with shape torch.Size([512, 256, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1, 1]).