# Example U-Net organ clustering

## Imports + model initialization

In [1]:
import os
import importlib
import numpy as np
import pandas as pd
import tensorflow as tf

from sklearn.utils import shuffle as sklearn_shuffle

from loggers import set_level
from utils import plot, plot_multiple, set_display_options
from datasets import get_dataset, prepare_dataset, test_dataset_time, train_test_split
from models.detection import med_unet_clusterer
from models import get_pretrained
from utils.med_utils import *
from models.model_utils import is_model_name


set_display_options()

tf.config.set_visible_devices(tf.config.list_physical_devices('GPU')[:1], 'GPU')

input_size = (None, None, 1)
model_name = 'cluster_unet_cosine_v2'

print("Tensorflow version : {}".format(tf.__version__))
#print('# GPU(s) : {}'.format(len(tf.config.list_logical_devices('GPU'))))

2023-05-26 10:31:30.315745: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-26 10:31:30.413464: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-26 10:31:30.437905: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until 

Tensorflow version : 2.10.0


## Model creation

In [2]:
try:
    tf.config.set_visible_devices([], 'GPU')
except:
    print('Unable to modify the visible devices')

importlib.reload(med_unet_clusterer)

if 'highres' in model_name:
    voxel_dims = (0.5, 0.5, 1.5)
elif 'lowres' in model_name:
    voxel_dims = (3., 3., 3.)
else:
    voxel_dims = (1.5, 1.5, 1.5)

config = {
    'input_size' : input_size,
    'voxel_dims' : voxel_dims,
    'n_frames'   : None,
    'pad_value'  : 0,

    'normalize'       : False if 'cosine' in model_name else True,
    'embedding_dim'   : 32,
    'distance_metric' : model_name.split('_')[2],
    
    'image_normalization' : 'mean',
    
    'norm_type' : 'instance'
}

if 'scratch' in model_name:
    config.update({
        # Architecture config
        'n_stages'   : 4,
        'n_conv_per_stage'    : 1,
        'up_n_conv_per_stage' : lambda i: min(i, 1),
        'filters'     : list(np.array([16, 32, 64, 128])),
        'bnorm'       : 'never',
        'activation'  : 'leaky',
        'drop_rate'   : lambda i: 0. if i == 0 else 0.25,

        'n_middle_stages' : 2,
        'n_middle_conv'   : 2,
        'middle_filters'  : 64,
        'middle_bnorm'    : 'never',

        'concat_mode'     : lambda i: 'concat' if i > 0 else None,
    })
else:
    config['pretrained_name'] = 'totalsegmentator'

model = med_unet_clusterer.MedUNetClusterer(
    nom = model_name, ** config
)

print(model)

Initializing model with kwargs : {'model': {'architecture_name': 'totalsegmentator', 'input_shape': (None, None, None, 1), 'output_dim': 32, 'final_activation': None, 'normalize_output': True, 'norm_type': 'instance'}}


2023-05-26 10:31:10.692884: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.




Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet

nnUNet_raw_data_base is not defined and nnU-Net can only be used on data for which preprocessed files are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like this. If this is not intended, please read nnunet/paths.md for information on how to set this up properly.
nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing or training. If this is not intended, please read nnunet/pathy.md for information on how to set this up.
RESULTS_FOLDER is not defined and nnU-Net cannot be used for training or inference. If this is not intended behavior, please read nnunet/



Weights transfered successfully !
Initializing submodel : `model` !
Submodel model saved in pretrained_models/cluster_unet_euclidian_v2/saving/model.json !
Model cluster_unet_euclidian_v2 initialized successfully !

Sub model model
- Inputs 	: (None, None, None, None, 1)
- Outputs 	: (None, None, None, None, 32)
- Number of layers 	: 123
- Number of parameters 	: 30.477 Millions
- Model not compiled

Transfer-learning from : totalsegmentator
Already trained on 0 epochs (0 steps)

- Image size : (None, None, 1)
- Normalization style : mean
- Voxel dims : (1.5, 1.5, 1.5)
- # frames : variable
- Embedding dim   : 32
- Distance metric : euclidian



In [None]:
model = med_unet_clusterer.MedUNetClusterer.from_pretrained(
    nom = model_name, pretrained_name = model_name[:-2]
)
print(model)

In [3]:
model.summary()



Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, None, None,  0           []                               
                                 None, 1)]                                                        
                                                                                                  
 zero_padding3d (ZeroPadding3D)  (None, None, None,   0          ['input_image[0][0]']            
                                None, 1)                                                          
                                                                                                  
 conv_blocks_context/0/blocks/0  (None, None, None,   896        ['zero_padding3d[0][0]']         
 /conv (Conv3D)                 None, 32)                                                   

## Model instanciation + dataset loading

In [2]:
model = get_pretrained(model_name)

if 'scratch' in model_name:
    lr = {'name' : 'DivideByStep', 'maxval' : 1e-2,'minval' : 1e-4}
else:
    lr = {'name' : 'DivideByStep', 'maxval' : 1e-3,'minval' : 1e-4}

loss_config = {
    'init_w' : -1 if model.distance_metric == 'euclidian' else 1.,
    'loss_averaging'  : 'micro',
    'background_mode' : 'concat'
}
model.compile(
    optimizer = 'adam', optimizer_config = {'lr' : lr}, loss_config = loss_config
)

print(model)

Model restoration...
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for groups will be overwritten.
The given value for

2023-05-26 10:31:34.316543: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-26 10:31:34.701018: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14783 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:17:00.0, compute capability: 7.5


Initializing submodel : `model` !
Successfully restored model from pretrained_models/cluster_unet_euclidian_v2/saving/model.json !
Model cluster_unet_euclidian_v2 initialized successfully !
Optimizer 'model_optimizer' initilized successfully !
Submodel model compiled !
  Loss : {'reduction': 'none', 'name': 'ge2e_seg_loss', 'mode': 'softmax', 'init_w': -1.0, 'init_b': 0.0, 'loss_averaging': 'micro', 'distance_metric': 'euclidian', 'background_mode': 'concat'}
  Optimizer : {'name': 'Adam', 'learning_rate': {'class_name': 'DivideByStep', 'config': {'factor': <tf.Tensor: shape=(), dtype=float32, numpy=1.0>, 'minval': <tf.Tensor: shape=(), dtype=float32, numpy=1e-04>, 'maxval': <tf.Tensor: shape=(), dtype=float32, numpy=0.001>}}, 'decay': 0.0, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}
  Metrics : []

Sub model model
- Inputs 	: (None, None, None, None, 1)
- Outputs 	: (None, None, None, None, 32)
- Number of layers 	: 123
- Number of parameters 	: 30.477 Millions

In [3]:
dataset_name = 'total_segmentator'
dataset = get_dataset(dataset_name, slice_step = 16, slice_size = 32)

keep_mask = dataset['segmentation'].apply(lambda f: f.endswith('.npz'))
dataset   = dataset[keep_mask]

if isinstance(dataset, dict):
    train, valid = dataset['train'], dataset['valid']
else:
    train, valid = train_test_split(
        dataset, train_size = 0.9, shuffle = True, random_state = 10, split_by_unique = True, min_occurence = 0
    )
    train = sklearn_shuffle(train, random_state = model.epochs)

print('Dataset length ({} data skipped, {} ids) :\n  Train size : {} ({} ids)\n  Valid size : {} ({} ids)'.format(
    len(keep_mask) - np.sum(keep_mask.values), len(dataset['id'].unique()), 
    len(train), len(train['id'].unique()), len(valid), len(valid['id'].unique())
))
print('# ids in valid that are also in train : {}'.format(len([id_i for id_i in valid['id'].unique() if id_i in train['id'].values])))

Loading dataset total_segmentator...
Dataset length (53 data skipped, 1202 ids) :
  Train size : 17969 (1081 ids)
  Valid size : 2066 (121 ids)
# ids in valid that are also in train : 0


In [4]:
valid

Unnamed: 0,subject_id,thickness,images,segmentation,label,start_frame,end_frame,id
128,s0012,-1,/storage/Totalsegmentator_dataset/s0012/ct.nii.gz,/storage/Totalsegmentator_dataset/s0012/masks.npz,"[adrenal_gland_left, adrenal_gland_right, aort...",0,32,s0012
129,s0012,-1,/storage/Totalsegmentator_dataset/s0012/ct.nii.gz,/storage/Totalsegmentator_dataset/s0012/masks.npz,"[adrenal_gland_left, adrenal_gland_right, aort...",16,48,s0012
130,s0012,-1,/storage/Totalsegmentator_dataset/s0012/ct.nii.gz,/storage/Totalsegmentator_dataset/s0012/masks.npz,"[adrenal_gland_left, adrenal_gland_right, aort...",32,64,s0012
131,s0012,-1,/storage/Totalsegmentator_dataset/s0012/ct.nii.gz,/storage/Totalsegmentator_dataset/s0012/masks.npz,"[adrenal_gland_left, adrenal_gland_right, aort...",48,80,s0012
132,s0012,-1,/storage/Totalsegmentator_dataset/s0012/ct.nii.gz,/storage/Totalsegmentator_dataset/s0012/masks.npz,"[adrenal_gland_left, adrenal_gland_right, aort...",64,96,s0012
...,...,...,...,...,...,...,...,...
20083,s1403,-1,/storage/Totalsegmentator_dataset/s1403/ct.nii.gz,/storage/Totalsegmentator_dataset/s1403/masks.npz,"[adrenal_gland_left, adrenal_gland_right, aort...",240,272,s1403
20084,s1403,-1,/storage/Totalsegmentator_dataset/s1403/ct.nii.gz,/storage/Totalsegmentator_dataset/s1403/masks.npz,"[adrenal_gland_left, adrenal_gland_right, aort...",256,288,s1403
20085,s1403,-1,/storage/Totalsegmentator_dataset/s1403/ct.nii.gz,/storage/Totalsegmentator_dataset/s1403/masks.npz,"[adrenal_gland_left, adrenal_gland_right, aort...",272,304,s1403
20086,s1403,-1,/storage/Totalsegmentator_dataset/s1403/ct.nii.gz,/storage/Totalsegmentator_dataset/s1403/masks.npz,"[adrenal_gland_left, adrenal_gland_right, aort...",288,320,s1403


## Training + history analysis

In [5]:
for epochs in [3]:
    batch_size = 1

    if 'highres' in model_name:
        max_size, max_frames = 512, 32
    elif 'lowres' in model_name:
        max_size, max_frames = 128, 128
    else:
        max_size, max_frames = 256 - 64, 32

    if not isinstance(max_size, tuple):
        max_size = (max_size, max_size)
    
    augment_prct = 0.25
    shuffle_size = 0 if epochs + model.epochs < 5 else batch_size * 8

    crop_mode    = ['random_center_80', 'random_center_80', 'random']

    if 'test' in model_name:
        train = train.sample(10, random_state = 0)
        valid = valid.sample(10, random_state = 0)

    if 'scratch' not in model_name and model.epochs < 1 and not is_model_name(model.pretrained_name):
        for l in model.layers[:-2]: l.trainable = False
        epochs = 1
    else:
        for l in model.layers: l.trainable = True

    model.train(
        train, validation_data = valid, epochs = epochs, batch_size = batch_size,

        augment_prct = augment_prct, shuffle_size = shuffle_size,
        is_rectangular = True, cache = False,

        max_size = max_size, max_frames = max_frames, crop_mode = crop_mode, run_eagerly = False
    )

Training config :
HParams :
- augment_prct	: 0.25
- augment_methods	: ['noise']
- max_size	: (192, 192)
- max_frames	: 32
- crop_mode	: ['random_center_80', 'random_center_80', 'random']
- skip_empty_frames	: False
- skip_empty_labels	: True
- batch_size	: 1
- train_batch_size	: None
- valid_batch_size	: None
- test_batch_size	: 1
- shuffle_size	: 0
- epochs	: 1
- verbose	: 1
- train_times	: 1
- valid_times	: 1
- train_size	: None
- valid_size	: None
- test_size	: 4
- pred_step	: -1

Running on 1 GPU

Epoch 1 / 1


2023-05-26 10:32:02.168117: W tensorflow/core/common_runtime/forward_type_inference.cc:332] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_FLOAT
    }
  }
}

	while inferring type of node 'ge2e_seg_loss/cond/else/_621/ge2e_seg_loss/cond/cond_2/output/_1908'
2023-05-26 10:32:04.250383: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204


     21/Unknown - 42s 977ms/step - loss: nan - foreground_loss: nan - background_loss: nan         

ValueError: NaN loss at batch : 20
  Logs : {'loss': nan, 'foreground_loss': nan, 'background_loss': nan}

In [None]:
model.plot_history()
print(model.history)

In [14]:
pd.DataFrame(model.history.trainings_infos)

[{'start': datetime.datetime(2023, 5, 24, 11, 26, 58, 477270), 'end': datetime.datetime(2023, 5, 24, 14, 24, 44, 408554), 'time': 10665.931284, 'interrupted': False, 'start_epoch': -1, 'final_epoch': 0}, {'start': datetime.datetime(2023, 5, 24, 14, 29, 53, 904649), 'end': datetime.datetime(2023, 5, 25, 9, 37, 51, 689042), 'time': 68877.784393, 'interrupted': False, 'start_epoch': 0, 'final_epoch': 3}]


In [15]:
from utils import *
print(time_to_string(68877.784393))

19h 7min 57sec


In [None]:
show_memory()

## Evaluation

In [5]:
model.max_size   = (None, None)
model.max_frames = -1

for idx, row in dataset.iloc[:1].iterrows():
    image, target = model.encode_data(row)
    image  = tf.expand_dims(image, axis = 0)
    target = tf.sparse.expand_dims(target, axis = 0)
    pred   = model.infer(image, 64)
    
    print('Image shape : {} - Mask shape : {} - Output shape : {}'.format(image.shape, target.shape, pred.shape))

2023-05-15 09:20:11.324760: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204


Image shape : (1, 249, 188, 213, 1) - Mask shape : (1, 249, 188, 213, 104) - Output shape : (1, 249, 188, 213, 32)


## Tests

### Test prediction

In [5]:
model.load_checkpoint(directory = model_name[:-2])

Loading checkpoint pretrained_models/cluster_unet_cosine/saving/ckpt-161


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f914c023100>

In [None]:
if 'highres' in model_name:
    image_size   = 512 - 64
    model.max_size     = (image_size, image_size)
    model.max_frames   = 32
elif 'lowres' in model_name:
    image_size   = 128
    model.max_size     = (image_size, image_size)
    model.max_frames   = 128
else:
    image_size   = 256
    model.max_size     = (image_size, image_size)
    model.max_frames   = 32

model.pad_value = 0.
print(model.max_size, model.max_frames)
config = model.get_dataset_config(is_validation = False, batch_size = 0, shuffle_size = 0)

set_level('debug', 'datasets')

ds = prepare_dataset(train.iloc[:1], ** config)

set_level('info', 'datasets')

model.get_loss().skip_empty_frames.assign(True)
model.get_loss().skip_empty_labels.assign(True)

for inp, out in ds:
    print('Input shape : {} - output shape : {}'.format(inp.shape, out.shape))
    #plot_mask(inp[..., 0], out, n = 4)
    pred = model(tf.expand_dims(inp, axis = 0))[0]
    intersect = tf.cast(out, tf.float32) * pred
    print(model.get_loss()(tf.sparse.expand_dims(out, 0), pred[tf.newaxis]))
    print(tf.sparse.reduce_sum(intersect) / len(intersect.indices))
    print(intersect)
    plot_mask(inp[..., 0], tf.cast(pred > 0.01, tf.uint8), n = 4)
    
    

In [None]:
mask   = (pred.numpy() > 0.1).astype(np.uint8)
labels = np.argmax(mask[..., 1:], axis = -1) + 1
voxels = np.any(mask[..., 1:], axis = -1)

organs = model.labels

show_organs = [o for o in organs if (o is not None) and ('rib' in o or 'vertebr' in o )]

skip_indexes = [i for i, organ in enumerate(organs) if organ not in show_organs]

#voxels[np.any(mask[..., skip_indexes], axis = -1)] = 0
#labels[np.any(mask[..., skip_indexes], axis = -1)] = 0


In [None]:
print(np.any(mask, axis = -1).sum())
print(skip_indexes)
print(voxels.sum())

In [None]:
from utils import plot_utils

importlib.reload(plot_utils)

import matplotlib.pyplot as plt

from loggers import set_level

set_level('debug', 'utils.plot_utils')

def add_color_axis(labels, cmap = None):
    mapper = plt.cm.ScalarMappable(cmap = cmap)
    return np.reshape(mapper.to_rgba(np.reshape(labels, [-1]).tolist()), list(labels.shape) + [4])

sx, sy, sz = -3, 3, 1

plot_utils.plot(
    voxels[::sy, ::sx, ::sz].astype(bool),
    figsize    = (10, 10),
    facecolors = add_color_axis(labels[::sy, ::sx, ::sz], cmap = 'magma'),
    plot_type  = 'voxels',
    color      = None,
    is_3d      = True,
    with_legend = True,
    with_colorbar = True
)

### Test dataset performances

In [5]:
from loggers import set_level

model.max_frames = 32
model.max_size   = (256, 256)

set_level('debug', 'datasets')

config = model.get_dataset_config(is_validation = True, batch_size = 1, cache = False, prefetch = False, shuffle_size = 0)

ds_train = prepare_dataset(train.sample(10), ** config)

set_level('info', 'datasets')

test_dataset_time(ds_train, steps = 10)

Original dataset : <TensorSliceDataset element_spec={'subject_id': TensorSpec(shape=(), dtype=tf.string, name=None), 'thickness': TensorSpec(shape=(), dtype=tf.int32, name=None), 'images': TensorSpec(shape=(), dtype=tf.string, name=None), 'segmentation': TensorSpec(shape=(), dtype=tf.string, name=None), 'label': TensorSpec(shape=(104,), dtype=tf.string, name=None), 'start_frame': TensorSpec(shape=(), dtype=tf.int32, name=None), 'end_frame': TensorSpec(shape=(), dtype=tf.int32, name=None), 'id': TensorSpec(shape=(), dtype=tf.string, name=None)}>
- Dataset after encoding : <ParallelMapDataset element_spec=(TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name=None), SparseTensorSpec(TensorShape([None, None, None, None]), tf.uint8))>
- Dataset after filtering : <FilterDataset element_spec=(TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name=None), SparseTensorSpec(TensorShape([None, None, None, None]), tf.uint8))>
- Dataset after batch : <BatchDataset element_spec=(Ten

9it [00:20,  2.31s/it]


10 batchs in 20.750 sec sec (0.482 batch / sec)






TypeError: unsupported format string passed to SparseTensor.__format__

### Dataset visualization

In [None]:
if 'highres' in model_name:
    image_size   = 512 - 64
    model.max_size     = (image_size, image_size)
    model.max_frames   = 32
elif 'lowres' in model_name:
    image_size   = 128
    model.max_size     = (image_size, image_size)
    model.max_frames   = 128
else:
    image_size   = 256
    model.max_size     = (image_size, image_size)
    model.max_frames   = 32

model.pad_value = 0.
print(model.max_size, model.max_frames)
config = model.get_dataset_config(is_validation = False, batch_size = 0, shuffle_size = 0)

set_level('debug', 'datasets')

ds = prepare_dataset(train.iloc[:5], ** config)

set_level('info', 'datasets')

for inp, out in ds:
    print('Input shape : {} - output shape : {}'.format(inp.shape, out.shape))
    plot_mask(inp[..., 0], out, n = 4)
    
    

### Test processing functions

In [None]:
from tqdm import tqdm

n = 5
model.max_frames = 32
model.max_size   = (512, 512, 32)

for _, row in train.sample(n, random_state = 0).iterrows():
    inp, _ = model.get_input(row, False)
    print('Input shape : {} - {}'.format(inp.shape, inp.dtype))
    print("Original   : {} - {}".format(tf.reduce_min(inp), tf.reduce_max(inp)))
    inp = model.normalize_image(inp)
    print("Normalized : {} - {}".format(tf.reduce_min(inp), tf.reduce_max(inp)))
    inp = model.augment_input(inp)
    print("Augmented  : {} - {}".format(tf.reduce_min(inp), tf.reduce_max(inp)))


for _, row in train.sample(n, random_state = 0).iterrows():
    out = model.get_output_fn(row['segmentation'], row['label'])
    print('Output shape : {} - {} - type : {}'.format(out.shape, out.dtype, out.__class__.__name__))

for _, row in tqdm(train.sample(n, random_state = 0).iterrows()):
    inp, out = model.encode_data(row)
    print('Input shape : {} - output shape : {}'.format(inp.shape, out.shape))

In [None]:
from loggers import set_level

model.max_frames = 1
model.max_size   = (256, 256)

set_level('debug', 'datasets')

config = model.get_dataset_config(is_validation = True, batch_size = 0, cache = False, prefetch = False, shuffle_size = 0)

ds_train = prepare_dataset(train, ** config)

set_level('info', 'datasets')

for i, (inp, out) in enumerate(ds_train):
    print('Batch #{} : input shape = {} - output shape = {}'.format(i, inp.shape, out.shape))
    tf.sparse.to_dense(out)


In [None]:
print(model.voxel_dims)

img, vox_dims = load_medical_image(train.loc[0, 'images'], target_voxel_dims = model.voxel_dims)
print(img.shape)

inp = model.get_input(train.loc[0], reshape = False)
print(inp.shape)

### Configure `learning-rate scheduler`

This cell allows you to play with `learning_rate scheduler`'s parameters to get the one you want !

In [None]:
lr = model.model_optimizer.learning_rate
lr.factor = 1024.
lr.warmup_steps = 1024
lr.plot(512 * 15)