In [8]:
# ------------------------------------------define logging and working directory
from ProjectRoot import change_wd_to_project_root
change_wd_to_project_root()
from src.utils.Notebook_imports import *
from src.utils.Tensorflow_helper import choose_gpu_by_id
from tensorflow.keras.utils import plot_model
from tensorflow.python.client import device_lib
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
import cv2
# ------------------------------------------define GPU id/s to use
GPU_IDS = '0,1'
GPUS = choose_gpu_by_id(GPU_IDS)
print(GPUS)
# ------------------------------------------jupyter magic config
%matplotlib inline
%reload_ext autoreload
%autoreload 2
# ------------------------------------------ import helpers
from src.utils.Utils_io import Console_and_file_logger, init_config
from src.visualization.Visualize import show_2D_or_3D
from src.data.Dataset import get_img_msk_files_from_split_dir, load_acdc_files, get_train_data_from_df, get_trainings_files
from src.data.Generators import DataGenerator, CycleMotionDataGenerator
from src.utils.KerasCallbacks import get_callbacks
import src.utils.Loss_and_metrics as metr
import src.models.SpatialTransformer as st
from src.models.SpatialTransformer import create_affine_cycle_transformer_model
from src.models.ModelUtils import load_pretrained_model
# ------------------------------------------path and project params
ARCHITECTURE = '3D' # 2D
DATASET = 'GCN'  # 'acdc' # or 'gcn' or different versions such as gcn_01/02
FOLD = 0 # CV fold 0-3
EXP_NAME = 'ax_sax/temp/' # Define an experiment name, could have subfolder conventions
EXPERIMENT = '{}/{}'.format(ARCHITECTURE, EXP_NAME) # Uniform path names, separation of concerns
timestemp = str(datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")) # ad a timestep to each project to make repeated experiments unique

# Our generator expects the following fix data structure (could be changed in src/data/Generators)
# any-path/
#    - AX_3D(anyname_img.nrrd and anyname_msk.nrrd)
#    - AX_to_SAX_3D
#    - SAX_3D
#    - SAX_to_AX_3D
DATA_PATH_AX = '/mnt/ssd/data/gcn/ax_sax_from_flo/AX_3D/' # path to AX 3D files
DATA_PATH_AX2SAX = '/mnt/ssd/data/gcn/ax_sax_from_flo/SAX_3D/' # path to transformed AX 3D files (target of AX)

DATA_PATH_SAX = '/mnt/ssd/data/gcn/ax_sax_from_flo/SAX_3D/' # path to SAX 3D files
DATA_PATH_SAX2AX = '/mnt/ssd/data/gcn/ax_sax_from_flo/AX_3D/' # path to transformed SAX 3D files (target of SAX)

DF_PATH = '/mnt/ssd/data/gcn/gcn_05_2020_ax_sax_86/folds.csv' # path to folds dataframe

MODEL_PATH = os.path.join('models', EXPERIMENT, timestemp)
TENSORBOARD_LOG_DIR = os.path.join('reports/tensorboard_logs', EXPERIMENT,timestemp)
CONFIG_PATH = os.path.join('reports/configs/',EXPERIMENT,timestemp)
HISTORY_PATH = os.path.join('reports/history/',EXPERIMENT,timestemp)

# ------------------------------------------static model, loss and generator hyperparameters
DIM = [80, 112, 112] # network input params for spacing of 3, (z,y,x)
DEPTH = 4 # number of down-/upsampling blocks
FILTERS = 16 # initial number of filters, will be doubled after each downsampling block
SPACING = [3, 3, 3] # if resample, resample to this spacing, (z,y,x)
M_POOL = [2, 2, 2]# size of max-pooling used for downsampling and upsampling
F_SIZE = [3, 3, 3] # conv filter size
IMG_CHANNELS = 1 # Currently our model needs that image channel
MASK_VALUES = [1, 2, 3]  #channel order: Background, RV, MYO, LV
MASK_CLASSES = len(MASK_VALUES) # no of labels
BORDER_MODE = cv2.BORDER_REFLECT_101 # border mode for the data generation
IMG_INTERPOLATION = cv2.INTER_LINEAR # image interpolation in the genarator
MSK_INTERPOLATION = cv2.INTER_NEAREST # mask interpolation in the generator
AUGMENT = False # Not implemented for the AX2SAX case
SHUFFLE = True
AUGMENT_GRID = False # Not implemented for the AX2SAX case
RESAMPLE = True
SCALER = 'MinMax' # MinMax Standard or Robust

AX_LOSS_WEIGHT = 10.0 # weighting factor of the ax2sax loss
WEIGHT_MSE_INPLANE = True # turn inplane weighting on/off
MASK_SMALLER_THAN_THRESHOLD = 0.001 # define the threshold for masking the ax2sax/sax2ax MSE loss, areas with smaller values, will be masked out

SAX_LOSS_WEIGHT = 10.0 # weighting factor of the sax2ax loss
CYCLE_LOSS = True # turn this loss on/off

FOCUS_LOSS_WEIGHT = 1.0 # weighting of the focus loss
FOCUS_LOSS = True # turn this loss on/off
USE_SAX2AX_PROB = False # apply the focus loss on AX2SAX_mask predictions, or on AX2SAX2AX_mask (back-transformed) predictions
MIN_UNET_PROBABILITY = 0.9 # threshold to count only prediction greater than this value for the focus loss

# ------------------------------------------individual training params
GENERATOR_WORKER = 2 # number of parallel workers in our generator. if not set, use batchsize, numbers greater than batchsize does not make any sense
SEED = 42 # define a seed for the generator shuffle
BATCHSIZE = 2 # 32, 64, 24, 16, 1 for 3spacing 3,3,3 use: 2
INITIAL_EPOCH = 0 # change this to continue training
EPOCHS = 300 # define a maximum numbers of epochs
EPOCHS_BETWEEN_CHECKPOINTS = 5
MONITOR_FUNCTION = 'val_loss'
MONITOR_MODE = 'min'
SAVE_MODEL_FUNCTION = 'val_loss'
SAVE_MODEL_MODE = 'min'
MODEL_PATIENCE = 20
BN_FIRST = False # decide if batch normalisation between conv and activation or afterwards
BATCH_NORMALISATION = True # apply BN or not
USE_UPSAMPLE = True # otherwise use transpose for upsampling
PAD = 'same' # padding strategy of the conv layers
KERNEL_INIT = 'he_normal' # conv weight initialisation
OPTIMIZER = 'adam' # Adam, Adagrad, RMSprop, Adadelta,  # https://keras.io/optimizers/
ACTIVATION = 'elu' # tf.keras.layers.LeakyReLU(), relu or any other non linear activation function
LEARNING_RATE = 1e-4 # start with a huge lr to converge fast
DECAY_FACTOR = 0.3 # Define a learning rate decay for the ReduceLROnPlateau callback
MIN_LR = 1e-10 # minimal lr, smaller lr does not improve the model
DROPOUT_min = 0.3 # lower dropout at the shallow layers
DROPOUT_max = 0.5 # higher dropout at the deep layers

# ------------------------------------------these metrics and loss function are meant if you continue training of the U-Net
metrics = [
    metr.dice_coef_labels,
    metr.dice_coef_myo,
    metr.dice_coef_lv,
    metr.dice_coef_rv
]
LOSS_FUNCTION = metr.bce_dice_loss

# Create a logger instance with the following setup: info or debug to console and file and error logs to a separate file
# Define a config for param injection,
# save a serialized version to load the experiment for prediction/evaluation, 
# make sure all paths exist
Console_and_file_logger(EXPERIMENT, logging.INFO)
config = init_config(config=locals(), save=True)
print(config)

2020-12-18 10:45:20,358 INFO -------------------- Start --------------------
2020-12-18 10:45:20,359 INFO Working directory: /mnt/ssd/git/3d-mri-domain-adaption.
2020-12-18 10:45:20,359 INFO Log file: ./logs/3D/ax_sax/temp/.log
2020-12-18 10:45:20,359 INFO Log level for console: INFO


search for root_dir and set working directory
Working directory set to: /mnt/ssd/git/3d-mri-domain-adaption
['/gpu:0', '/gpu:1']
{'GPU_IDS': '0,1', 'GPUS': ['/gpu:0', '/gpu:1'], 'ARCHITECTURE': '3D', 'DATASET': 'GCN', 'FOLD': 0, 'EXP_NAME': 'ax_sax/temp/', 'EXPERIMENT': '3D/ax_sax/temp/', 'DATA_PATH_AX': '/mnt/ssd/data/gcn/ax_sax_from_flo/AX_3D/', 'DATA_PATH_AX2SAX': '/mnt/ssd/data/gcn/ax_sax_from_flo/SAX_3D/', 'DATA_PATH_SAX': '/mnt/ssd/data/gcn/ax_sax_from_flo/SAX_3D/', 'DATA_PATH_SAX2AX': '/mnt/ssd/data/gcn/ax_sax_from_flo/AX_3D/', 'DF_PATH': '/mnt/ssd/data/gcn/gcn_05_2020_ax_sax_86/folds.csv', 'MODEL_PATH': 'models/3D/ax_sax/temp/2020-12-18_10_45', 'TENSORBOARD_LOG_DIR': 'reports/tensorboard_logs/3D/ax_sax/temp/2020-12-18_10_45', 'CONFIG_PATH': 'reports/configs/3D/ax_sax/temp/2020-12-18_10_45', 'HISTORY_PATH': 'reports/history/3D/ax_sax/temp/2020-12-18_10_45', 'DIM': [80, 112, 112], 'DEPTH': 4, 'FILTERS': 16, 'SPACING': [3, 3, 3], 'M_POOL': [2, 2, 2], 'F_SIZE': [3, 3, 3], 'IMG_CHAN

# Check Tensorflow setup and available GPUs

In [9]:
from tensorflow.keras.mixed_precision import experimental as mixed_precision
logging.info('Is built with tensorflow: {}'.format(tf.test.is_built_with_cuda()))
logging.info('Visible devices:\n{}'.format(tf.config.list_physical_devices()))
logging.info('Local devices: \n {}'.format(device_lib.list_local_devices()))
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)
logging.info('Compute dtype: %s' % policy.compute_dtype)
logging.info('Variable dtype: %s' % policy.variable_dtype)

2020-12-18 10:45:21,098 INFO Is built with tensorflow: True
2020-12-18 10:45:21,099 INFO Visible devices:
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU'), PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU'), PhysicalDevice(name='/physical_device:XLA_GPU:1', device_type='XLA_GPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]
2020-12-18 10:45:21,104 INFO Local devices: 
 [name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 10918796336304155601
, name: "/device:XLA_CPU:0"
device_type: "XLA_CPU"
memory_limit: 17179869184
locality {
}
incarnation: 17341752124371417084
physical_device_desc: "device: XLA_CPU device"
, name: "/device:XLA_GPU:0"
device_type: "XLA_GPU"
memory_limit: 17179869184
locality {
}
incarnation: 2124121724075268104
physical_device_desc

# Load trainings and validation files for the choosen fold

In [10]:
# Load AX volumes
x_train_ax, y_train_ax, x_val_ax, y_val_ax =  get_trainings_files(data_path=DATA_PATH_AX,path_to_folds_df=DF_PATH, fold=FOLD)
logging.info('AX train CMR: {}, AX train masks: {}'.format(len(x_train_ax), len(y_train_ax)))
logging.info('AX val CMR: {}, AX val masks: {}'.format(len(x_val_ax), len(y_val_ax)))

# load AX2SAX volumes
x_train_ax2sax, y_train_ax2sax, x_val_ax2sax, y_val_ax2sax =  get_trainings_files(data_path=DATA_PATH_AX2SAX,path_to_folds_df=DF_PATH, fold=FOLD)
logging.info('AX2SAX train CMR: {}, AX2SAX train masks: {}'.format(len(x_train_ax2sax), len(y_train_ax2sax)))
logging.info('AX2SAX val CMR: {}, AX2SAX val masks: {}'.format(len(x_val_ax2sax), len(y_val_ax2sax)))

# Load SAX volumes
x_train_sax, y_train_sax, x_val_sax, y_val_sax =  get_trainings_files(data_path=DATA_PATH_SAX,path_to_folds_df=DF_PATH, fold=FOLD)
logging.info('SAX train CMR: {}, SAX train masks: {}'.format(len(x_train_sax), len(y_train_sax)))
logging.info('SAX val CMR: {}, SAX val masks: {}'.format(len(x_val_sax), len(y_val_sax)))

# load SAX2AX volumes
x_train_sax2ax, y_train_sax2ax, x_val_sax2ax, y_val_sax2ax =  get_trainings_files(data_path=DATA_PATH_SAX2AX,path_to_folds_df=DF_PATH, fold=FOLD)
logging.info('SAX2AX train CMR: {}, SAX2AX train masks: {}'.format(len(x_train_sax2ax), len(y_train_sax2ax)))
logging.info('SAX2AX val CMR: {}, SAX2AX val masks: {}'.format(len(x_val_sax2ax), len(y_val_sax2ax)))

2020-12-18 10:45:21,838 INFO Found 162 images/masks in /mnt/ssd/data/gcn/ax_sax_from_flo/AX_3D/
2020-12-18 10:45:21,838 INFO Patients train: 64
2020-12-18 10:45:21,844 INFO Selected 120 of 162 files with 64 of 86 patients for training fold 0
2020-12-18 10:45:21,845 INFO AX train CMR: 120, AX train masks: 120
2020-12-18 10:45:21,845 INFO AX val CMR: 42, AX val masks: 42
2020-12-18 10:45:21,850 INFO Found 162 images/masks in /mnt/ssd/data/gcn/ax_sax_from_flo/SAX_3D/
2020-12-18 10:45:21,850 INFO Patients train: 64
2020-12-18 10:45:21,856 INFO Selected 120 of 162 files with 64 of 86 patients for training fold 0
2020-12-18 10:45:21,857 INFO AX2SAX train CMR: 120, AX2SAX train masks: 120
2020-12-18 10:45:21,857 INFO AX2SAX val CMR: 42, AX2SAX val masks: 42
2020-12-18 10:45:21,862 INFO Found 162 images/masks in /mnt/ssd/data/gcn/ax_sax_from_flo/SAX_3D/
2020-12-18 10:45:21,863 INFO Patients train: 64
2020-12-18 10:45:21,872 INFO Selected 120 of 162 files with 64 of 86 patients for training fol

In [11]:
# filter files by name, debugging purpose
#x_val_ax = [x for x in x_val_ax if '4A4PVCYL_2006' in x]
#x_val_sax = [x for x in x_val_sax if '4A4PVCYL_2006' in x]
#y_val_ax = [x for x in y_val_ax if '4A4PVCYL_2006' in x]

In [12]:
# create two generators, one for the training files, one for the validation files
batch_generator = CycleMotionDataGenerator(x=x_train_ax, y=x_train_ax2sax, x2=x_train_sax, y2=x_train_sax2ax, config=config)
valid_config = config.copy()
valid_config['AUGMENT_GRID'] = False
valid_config['AUGMENT'] = False
valid_generator = CycleMotionDataGenerator(x=x_val_ax, y=x_val_ax2sax, x2=x_val_sax, y2=x_val_sax2ax, config=valid_config)

2020-12-18 10:45:22,856 INFO Create DataGenerator
2020-12-18 10:45:22,858 INFO generator in debug mode = False
2020-12-18 10:45:22,858 INFO Datagenerator created with: 
 shape: [80, 112, 112]
 spacing: [3, 3, 3]
 batchsize: 2
 Scaler: MinMax
 Images: 120 
 Augment_grid: False 
 Thread workers: 2
2020-12-18 10:45:22,858 INFO No augmentation
2020-12-18 10:45:22,859 INFO Create DataGenerator
2020-12-18 10:45:22,860 INFO generator in debug mode = False
2020-12-18 10:45:22,860 INFO Datagenerator created with: 
 shape: [80, 112, 112]
 spacing: [3, 3, 3]
 batchsize: 2
 Scaler: MinMax
 Images: 42 
 Augment_grid: False 
 Thread workers: 2
2020-12-18 10:45:22,860 INFO No augmentation


In [13]:
# Select batch generator output
x = ''
y = ''
@interact
def select_batch(batch = (0,len(valid_generator), 1)):
    global x, y, x2, y2
    input_ , output_ = valid_generator.__getitem__(batch)
    x = input_[0]
    y = output_[0]
    x2 = input_[1]
    y2 = output_[1]
    logging.info('input elements: {}'.format(len(input_)))
    logging.info('output elements: {}'.format(len(output_)))
    logging.info(x.shape)
    logging.info(y.shape)
    logging.info(x2.shape)
    logging.info(y2.shape)

interactive(children=(IntSlider(value=10, description='batch', max=21), Output()), _dom_classes=('widget-inter…

In [14]:
@interact
def select_image_in_batch(im = (0,x.shape[0]- 1, 1),slice_by=(1,6)):
    
    # define a different logging level to make the generator steps visible
    #logging.getLogger().setLevel(logging.DEBUG)
    temp_dir = 'reports/figures/temp/'
    ensure_dir(temp_dir)

    logging.info('AX: {}'.format(x[im].shape))
    show_2D_or_3D(x[im][...,0][::slice_by])
    plt.savefig(os.path.join(temp_dir,'ax.pdf'))
    plt.show()
    logging.info('AXtoSAX: {}'.format(y[im].shape))
    show_2D_or_3D(y[im][...,0][::slice_by])
    plt.savefig(os.path.join(temp_dir,'ax2sax.pdf'))
    plt.show()
    logging.info('SAX: {}'.format(x2[im].shape))
    show_2D_or_3D(x2[im][...,0][::slice_by])
    plt.savefig(os.path.join(temp_dir,'sax.pdf'))
    plt.show()
    logging.info('SAXtoAX: {}'.format(y2[im].shape))
    show_2D_or_3D(y2[im][...,0][::slice_by])
    plt.savefig(os.path.join(temp_dir,'sax2ax.pdf'))
    plt.show()
    

interactive(children=(IntSlider(value=0, description='im', max=1), IntSlider(value=3, description='slice_by', …

In [8]:
# load a pretrained 2D unet
"""
load past config for model training 
"""
if 'config_chooser' in locals():
    config_file  = config_chooser.selected
else:
    #config_file = '/mnt/ssd/git/3d-mri-domain-adaption/reports/configs/2D/gcn_and_acdc_excl_ax/config.json' # config for TMI paper
    config_file = '/mnt/ssd/git/cardio/reports/configs/2D/gcn_05_2020_sax_excl_ax_patients/2020-11-20_17_24/config.json' # retrained with downsampling

# load config with all params into global namespace
with open(config_file, encoding='utf-8') as data_file:
    config_temp = json.loads(data_file.read())
logging.info('Load model from Experiment: {}'.format(config_temp['EXPERIMENT']))

if 'strategy' not in globals():
    # distribute the training with the mirrored data paradigm across multiple gpus if available, if not use gpu 0
    strategy = tf.distribute.MirroredStrategy(devices=config.get('GPUS', ["/gpu:0"]))
with strategy.scope():
    globals()['unet'] = load_pretrained_model(config_temp, metrics, comp=False)

2020-12-17 18:21:07,438 INFO Load model from Experiment: 2D/gcn_05_2020_sax_excl_ax_patients
2020-12-17 18:21:07,440 INFO load model with keras api
2020-12-17 18:21:09,791 INFO Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements `get_config`and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.
2020-12-17 18:21:09,792 INFO Keras API failed, use json repr. load model from: models/2D/gcn_05_2020_sax_excl_ax_patients/2020-11-20_17_24/model.json .
2020-12-17 18:21:09,793 INFO loading model description
2020-12-17 18:21:10,720 INFO loading model weights
2020-12-17 18:21:10,884 INFO model models/2D/gcn_05_2020_sax_excl_ax_patients/2020-11-20_17_24/model.json loaded


In [9]:
if 'strategy' not in globals():
    # distribute the training with the mirrored data paradigm across multiple gpus if available, if not use gpu 0
    strategy = tf.distribute.MirroredStrategy(devices=config.get('GPUS', ["/gpu:0"]))
# inject the pre-trained unet if given, otherwise build the model without the pretrained unet
with strategy.scope():
    model = st.create_affine_cycle_transformer_model(config=config, unet=locals().get('unet', None))

2020-12-17 18:21:13,224 INFO unet given, use it to max probability
2020-12-17 18:21:32,288 INFO adding ax2sax MSE loss with a weighting of 10.0
2020-12-17 18:21:32,288 INFO adding cycle loss with a weighting of 10.0
2020-12-17 18:21:32,289 INFO adding focus loss on mask_prob with a weighting of 1.0


In [10]:
model.summary(line_length=150)
#plot_model(model, to_file='reports/figures/temp_graph.pdf',show_shapes=True)

Model: "affine_cycle_transformer"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             [(None, 80, 112, 112, 1)]        0                                                                   
______________________________________________________________________________________________________________________________________________________
conv_encoder (ConvEncoder)                       ((None, 5, 7, 7, 256), [(None, 8 3537424           input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
global_average_pooling3d (GlobalAveragePooling3D (None, 256)

In [11]:
@interact
def select_image_in_batch(im = (0,x.shape[0]- 1, 1),mask_smaller_than='0.001', slice_by=(1,6)):
    global m
    import numpy as np
    temp = x[im]
    sax = x2[im]
    temp_ = y[im]
    
    mask = temp_ >float(mask_smaller_than)
    # define a different logging level to make the generator steps visible
    logging.getLogger().setLevel(logging.INFO)
    logging.info('prediction on: {}'.format(temp.shape))
    show_2D_or_3D(temp[::slice_by])
    plt.show()
    pred, inv_pred, ax2sax_mod, prob, ax_msk,m, m_mod = model.predict(x = [np.expand_dims(temp,axis=0),np.expand_dims(sax,axis=0)])                     
    logging.info('rotated by the model: {}'.format(pred[0].shape))
    show_2D_or_3D(pred[0][::slice_by], mask[::slice_by])
    plt.show()
    logging.info('inverse rotation on SAX: {}'.format(inv_pred[0].shape))
    show_2D_or_3D(inv_pred[0][::slice_by])
    plt.show()
    logging.info('predicted mask: {}'.format(inv_pred[0].shape))
    show_2D_or_3D(prob[0][::slice_by])
    plt.show()
    logging.info('predicted mask in ax: {}'.format(ax_msk[0].shape))
    show_2D_or_3D(ax_msk[0][::slice_by])
    plt.show()
    
    # calculate the loss mask from target AX2SAX image
    mask = temp_ >float(mask_smaller_than)
    logging.info('masked by GT: {}'.format(mask.shape))
    masked = pred[0] * mask
    show_2D_or_3D(masked[::slice_by], mask[::slice_by])
    plt.show()
    logging.info('target (AX2SAX): {}'.format(temp_.shape))
    show_2D_or_3D(temp_[::slice_by])
    plt.show()
    logging.info('Created MSE mask by thresholding the target (AX2SAX) with {}: {}'.format(mask_smaller_than,temp_.shape))
    show_2D_or_3D(mask[::slice_by])
    plt.show()

    try:
        from tensorflow.keras.metrics import MSE as mse
        logging.info('MSE: {}'.format(mse(pred[0], temp_).numpy().mean()))
        logging.info('prob loss: {}'.format(metr.max_volume_loss(min_probabillity=0.5)(temp_[tf.newaxis,...],prob).numpy().mean()))
        print(np.reshape(m[0],(3,4)))
        print(np.reshape(m_mod[0],(3,4)))
    except Exception as e:
        pass

interactive(children=(IntSlider(value=0, description='im', max=1), Text(value='0.001', description='mask_small…

In [12]:
# train one model
initial_epoch = 0
logging.info('Fit model, start trainings process')
# fit model with trainingsgenerator
results = model.fit(
    x=batch_generator,
    validation_data=valid_generator,
    validation_steps=len(valid_generator),
    epochs=200,
    callbacks = get_callbacks(config, valid_generator),
    steps_per_epoch = len(batch_generator),
    initial_epoch=initial_epoch,
    max_queue_size=20,
    workers=2,
    use_multiprocessing=True,
    verbose=1)

2020-12-17 18:22:26,158 INFO Fit model, start trainings process


Epoch 1/200
Epoch 00001: val_loss improved from inf to 23.18034, saving model to models/3D/ax_sax/train_on_ax_sax/fold0/2020-12-17_18_21/model.h5
Epoch 2/200
Epoch 00002: val_loss improved from 23.18034 to 22.82158, saving model to models/3D/ax_sax/train_on_ax_sax/fold0/2020-12-17_18_21/model.h5
Epoch 3/200
Epoch 00003: val_loss improved from 22.82158 to 22.17648, saving model to models/3D/ax_sax/train_on_ax_sax/fold0/2020-12-17_18_21/model.h5
Epoch 4/200
Epoch 00004: val_loss improved from 22.17648 to 21.27011, saving model to models/3D/ax_sax/train_on_ax_sax/fold0/2020-12-17_18_21/model.h5
Epoch 5/200
Epoch 00005: val_loss improved from 21.27011 to 21.19098, saving model to models/3D/ax_sax/train_on_ax_sax/fold0/2020-12-17_18_21/model.h5
Epoch 6/200
Epoch 00006: val_loss improved from 21.19098 to 20.61299, saving model to models/3D/ax_sax/train_on_ax_sax/fold0/2020-12-17_18_21/model.h5
Epoch 7/200
Epoch 00007: val_loss improved from 20.61299 to 20.30437, saving model to models/3D/ax_

In [None]:
# if, for any reason, you want to save the latest model, use this cell
#tf.keras.models.save_model(model,filepath=config['MODEL_PATH'],overwrite=True,include_optimizer=False,save_format='tf')

In [None]:
config['MODEL_PATH']

In [13]:
# Fast tests of a trained model, the "real" predictions will be done in src/notebooks/Predict

"""
load past config for model training 
"""
if 'strategy' not in locals():
    # distribute the training with the mirrored data paradigm across multiple gpus if available, if not use gpu 0
    strategy = tf.distribute.MirroredStrategy(devices=config.get('GPUS', ["/gpu:0"]))

# round the crop and pad values instead of ceil
#config_file = 'reports/configs/3D/ax_sax/unetwithdownsamplingaugmentation_new_data/2020-12-03_18_20/config.json' # Fold 0
#config_file = 'reports/configs/3D/ax_sax/unetwithdownsamplingaugmentation_new_data/2020-12-03_22_02/config.json' # Fold 1
#config_file = 'reports/configs/3D/ax_sax/unetwithdownsamplingaugmentation_new_data/2020-12-04_16_56/config.json' # Fold 2
#config_file = 'reports/configs/3D/ax_sax/unetwithdownsamplingaugmentation_new_data/2020-12-07_12_36/config.json' # Fold 3

config_file = 'reports/configs/3D/ax_sax/train_on_ax_sax/fold0/2020-12-17_11_44/config.json' # Fold 0



# load a pre-trained ax2sax model, create the graph and load the weights separately, due to own loss functions, this is easier
with open(config_file, encoding='utf-8') as data_file:
    config_temp = json.loads(data_file.read())
config_temp['LOSS_FUNCTION'] = config['LOSS_FUNCTION']
logging.info('Load model from Experiment: {}'.format(config_temp['EXPERIMENT']))

with strategy.scope():
    globals()['model'] = st.create_affine_cycle_transformer_model(config=config_temp, metrics=metrics, unet=locals().get('unet', None))
    model.load_weights(os.path.join(config_temp['MODEL_PATH'],'model.h5'))
    logging.info('loaded model weights as h5 file')

2020-12-07 14:30:29,323 INFO Load model from Experiment: 3D/ax_sax/unetwithdownsamplingaugmentation_new_data
2020-12-07 14:30:30,068 INFO unet given, use it to max probability
2020-12-07 14:30:48,204 INFO adding ax2sax MSE loss with a weighting of 10.0
2020-12-07 14:30:48,205 INFO adding cycle loss with a weighting of 10.0
2020-12-07 14:30:48,206 INFO adding focus loss on mask_prob with a weighting of 1.0
2020-12-07 14:30:48,484 INFO loaded model weights as h5 file


# Fast predictions with all files of the generator

In [10]:
# predict, visualise the transformation of AX train files
import numpy as np
cfg = config.copy()
cfg['BATCHSIZE'] = 10
cfg['AUGMENT_GRID'] = False
valid_generator = CycleMotionDataGenerator(x_train_ax, x_train_sax, cfg)
input_, output_ = valid_generator.__getitem__(0)
x_ = input_[0]
x2_ = input_[1]
y_ = output_[0]
y2_ = output_[1]
@interact
def select_image_in_batch(im = (0,x_.shape[0]- 1, 1), slice_by=(1,6)):
    global m
    temp = x_[im]
    temp_ = y_[im]
    sax = x2_[im]
    # define a different logging level to make the generator steps visible
    logging.getLogger().setLevel(logging.INFO)
    logging.info('prediction on:')
    show_2D_or_3D(temp[::slice_by])
    plt.show()
    pred, inv_pred, ax2sax_mod, pred_mask, ax2sax_msk,m_first, m = model.predict(x=[np.expand_dims(temp,axis=0),np.expand_dims(sax,axis=0)])

    logging.info('rotated by the model')
    show_2D_or_3D(pred[0][::slice_by])
    plt.show()
    logging.info('modified rotation of the model')
    show_2D_or_3D(ax2sax_mod[0][::slice_by])
    plt.show()
    logging.info('predicted mask:')
    show_2D_or_3D(pred_mask[0][::slice_by])
    plt.show()
    logging.info('target (SAX):')
    show_2D_or_3D(temp_[::slice_by])
    plt.show()
    logging.info('inverted rotation on SAX')
    show_2D_or_3D(inv_pred[0][::slice_by])
    plt.show()
    try:
        print(np.reshape(m_first[0],(3,4)))
        print(np.reshape(m[0],(3,4)))
    except Exception as e:
        pass

2020-12-03 20:07:27,707 INFO Create DataGenerator
2020-12-03 20:07:27,708 INFO Datagenerator created with: 
 shape: [80, 112, 112]
 spacing: [3, 3, 3]
 batchsize: 10
 Scaler: MinMax
 Images: 120 
 Augment_grid: False 
 Thread workers: 2
2020-12-03 20:07:27,709 INFO No augmentation


interactive(children=(IntSlider(value=4, description='im', max=9), IntSlider(value=3, description='slice_by', …

# Predictions on the heldout test split

In [11]:
cfg = config.copy()
cfg['BATCHSIZE'] = len(x_val_ax)
v_generator = CycleMotionDataGenerator(x_val_ax, x_val_sax, cfg)
input_, output_ = v_generator.__getitem__(0)
x_ = input_[0]
x2_ = input_[1]
y_ = output_[0]
y2_ = output_[1]
@interact
def select_image_in_batch(im = (0,x_.shape[0]- 1, 1), slice_by=(1,6)):
    global m
    temp = x_[im]
    temp_ = y_[im]
    sax = x2_[im]
    # define a different logging level to make the generator steps visible
    logging.getLogger().setLevel(logging.INFO)
    logging.info('prediction on:')
    show_2D_or_3D(temp[::slice_by])
    plt.show()
    
    pred, inv_pred, ax2sax_mod, pred_mask, ax_mask, m_first, m = model.predict(x=[np.expand_dims(temp,axis=0),np.expand_dims(sax,axis=0)])
    logging.info('rotated by the model')
    show_2D_or_3D(pred[0][::slice_by])
    plt.show()
    logging.info('modified rotation')
    show_2D_or_3D(ax2sax_mod[0][::slice_by])
    plt.show()
    logging.info('predicted mask')
    show_2D_or_3D(pred_mask[0][::slice_by])
    plt.show()
    logging.info('predicted in AX')
    show_2D_or_3D(ax_mask[0][::slice_by])
    plt.show()
    logging.info('target (SAX):')
    show_2D_or_3D(temp_[::slice_by])
    plt.show()
    logging.info('inverted rotation on SAX')
    show_2D_or_3D(inv_pred[0][::slice_by])
    plt.show()
    try:
        print(np.reshape(m_first[0],(3,4)))
        print(np.reshape(m[0],(3,4)))
    except Exception as e:
        pass

2020-12-03 20:08:19,977 INFO Create DataGenerator
2020-12-03 20:08:19,977 INFO Datagenerator created with: 
 shape: [80, 112, 112]
 spacing: [3, 3, 3]
 batchsize: 42
 Scaler: MinMax
 Images: 42 
 Augment_grid: False 
 Thread workers: 2
2020-12-03 20:08:19,978 INFO No augmentation


interactive(children=(IntSlider(value=20, description='im', max=41), IntSlider(value=3, description='slice_by'…

# Temp tests

In [None]:
# check the memory usage
import sys

# These are the usual ipython objects
ipython_vars = ['In', 'Out', 'exit', 'quit', 'get_ipython', 'ipython_vars']

# Get a sorted list of the objects and their sizes
sorted([(x, sys.getsizeof(globals().get(x))) for x in dir() if not x.startswith('_') and x not in sys.modules and x not in ipython_vars], key=lambda x: x[1], reverse=True)
