In [None]:
import sys
import os
import datetime
import time
import numpy as np
import tensorflow as tf
import segmentation_models as sm
import skimage
import sklearn
import sklearn.metrics

import data.patch2D
import data.patch25D
import data.patch2DM
import data.patch3D

import models.UNet
import models.UNet3D
import models.EfficientUNet

import metrics.connected_components
import metrics.distance_contour

In [None]:
# Experiment parameters, in the order [model codename, patch size tuple, batch size].
# Inside a Jupyter kernel the hard-coded default below is used; when run as a
# script they are taken from the command line.
# Note: a (16,64,64) 3D patch holds as many voxels as a (256,256) 2D patch
# (64*64*16 == 256*256 == 65536).
if 'ipykernel' in sys.modules:
    # Other configurations that have been used:
    # params = ["2DUNET-32", "(256,256,1)", "4"]
    # params = ["3DUNET-32", "(16,64,64,1)", "4"]
    # params = ["3DUNET-32", "(16,128,128,1)", "1"]
    params = ["2DUNET-32", "(1024,1024,1)", "2"]
else:
    params = sys.argv[1:]

### Experiment setup

In [None]:
EXP_NAME = "UNET_vs_UNET3D"

# Experiment parameters (from the command line, or the notebook default):
#   P1 : model codename, one of the keys of _MODEL_SPECS below (e.g. "2DUNET-32")
#   P2 : patch size tuple, e.g. "(256,256,1)" for 2D or "(16,64,64,1)" for 3D
#   P3 : batch size
N_PARAMS = 3
if len(params) != N_PARAMS:
    print("error: expected %d parameters (model codename, patch size, batch size), got %d"
          % (N_PARAMS, len(params)))
    sys.exit(1)

MODELNAME = params[0]
# "(16,64,64,1)" -> (16, 64, 64, 1)
PATCH_SIZE = tuple(map(int, params[1].replace("(", "").replace(")", "").split(',')))
BATCH_SIZE = int(params[2])

# Fixed params
DATASET = "LW4_40_9"
EPOCHS = 100
TRAIN_PER_EPOCHS = 512
# NOTE(review): not used by the visible model.fit call, which passes a
# hard-coded validation_steps=32 — confirm which is intended.
VALID_PER_EPOCHS = 0

# Hyper-parameters shared by every UNet variant.
_MODEL_COMMON_KWARGS = dict(output_classes=9, output_activation='tanh',
                            conv_per_block=2, dropouts=0.50, batch_normalization=True)

# codename -> (dimensionality, filters, depth).
# The trailing comment is the total parameter count recorded for that variant.
_MODEL_SPECS = {
    "2DUNET-8":    ("2D", 8, 5),    # Total params: 488,689
    "2DUNET-16":   ("2D", 16, 5),   # Total params: 1,946,841
    "2DUNET-32":   ("2D", 32, 5),   # Total params: 7,771,561
    "2DUNET-64":   ("2D", 64, 5),   # Total params: 31,054,665
    "3DUNET-8":    ("3D", 8, 5),    # Total params: 1,474,433
    "3DUNET-16_4": ("3D", 16, 4),   # Total params: 1,462,401
    "3DUNET-16":   ("3D", 16, 5),   # Total params: 5,889,921
    "3DUNET-32_4": ("3D", 32, 4),   # Total params: 5,841,665
    "3DUNET-32":   ("3D", 32, 5),   # Total params: 23,544,065
}

if MODELNAME not in _MODEL_SPECS:
    print("error: model does not exist")
    sys.exit(1)

# Only the selected variant is instantiated.
_model_dim, _model_filters, _model_depth = _MODEL_SPECS[MODELNAME]
if _model_dim == "2D":
    MODEL = models.UNet.UNet(input_shape=(None, None, 1),
                             filters=_model_filters, depth=_model_depth,
                             **_MODEL_COMMON_KWARGS)
else:
    MODEL = models.UNet3D.UNet(input_shape=(None, None, None, 1),
                               filters=_model_filters, depth=_model_depth,
                               pool_size=(2, 2, 2),
                               **_MODEL_COMMON_KWARGS)

### Experiment run

In [None]:
# Select the dataset module. The data-loading cell reads its
# train_*/test_* attributes through the alias D.
if DATASET == "I3":
    import data.datasets.I3 as D
elif DATASET == "LW4":
    import data.datasets.LW4 as D
elif DATASET == "LW4_40_9":
    import data.datasets.LW4_40_9 as D
else:
    print("error: dataset does not exist")
    sys.exit(1)

# Output root: overridable through the NENIST_OUTPUT_ROOT environment variable.
# Otherwise it is picked from the host name, as before.
_output_root = os.environ.get("NENIST_OUTPUT_ROOT")
if _output_root is None:
    if os.uname().nodename == 'lythandas':
        _output_root = "/home/cyril/Development/NeNISt"
    else:
        _output_root = "/b/home/miv/cmeyer/NeNISt"
OUTPUT_FOLDER = os.path.join(_output_root, EXP_NAME)

# exist_ok=True avoids the race between an existence check and creation.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

In [None]:
# Unique run tag: day-of-year + H M S + the first 4 digits of the microseconds
# (%f gives 6 digits; the last 2 are dropped).
dt = datetime.datetime.today().strftime("%j%H%M%S%f")[:-2]

# Suffix the experiment name with model, patch size and timestamp; spaces from
# the tuple repr are stripped so the name is filename-friendly.
# NOTE(review): this rebinds EXP_NAME from itself, so re-running this cell
# appends a second suffix — restart the kernel before re-running.
_run_parts = [EXP_NAME, str(MODELNAME), str(PATCH_SIZE), dt]
EXP_NAME = "_".join(_run_parts).replace(" ", "")

In [None]:
# Pull the normalized float32 images, the `*_labels_dt` label arrays and the
# per-class index arrays (labels 1..9) out of the selected dataset module D.
train_image = D.train_image_normalized_f32
train_labels_dt = D.train_labels_dt
train_labels_indexes = [getattr(D, "train_label_%d_indexes" % label) for label in range(1, 10)]

test_image = D.test_image_normalized_f32
test_labels_dt = D.test_labels_dt
test_labels_indexes = [getattr(D, "test_label_%d_indexes" % label) for label in range(1, 10)]

In [None]:
# Data generators: pick the patch sampler matching the model dimensionality,
# then build identical train and test generators from it.
#   3D models expect PATCH_SIZE = (d, h, w, channels)
#   2D models expect PATCH_SIZE = (h, w, channels) with h == w
if "3D" in MODELNAME and len(PATCH_SIZE) == 4:
    _gen_patches = data.patch3D.gen_patches_batch_augmented_3d_label_indexes_one_hot
    _patch_args = PATCH_SIZE[:3]
elif "2D" in MODELNAME and len(PATCH_SIZE) == 3:
    if PATCH_SIZE[0] != PATCH_SIZE[1]:
        print("error: non square 2D patch size, check data.patch2D")
        sys.exit(1)
    # the 2D sampler takes a single side length
    _gen_patches = data.patch2D.gen_patches_batch_augmented_label_indexes_one_hot
    _patch_args = PATCH_SIZE[:1]
else:
    print("error: patch size and model are not compatible")
    sys.exit(1)

train = _gen_patches(*_patch_args, train_image, train_labels_dt, train_labels_indexes, batch_size=BATCH_SIZE)
test = _gen_patches(*_patch_args, test_image, test_labels_dt, test_labels_indexes, batch_size=BATCH_SIZE)

In [None]:
def compute_class_weights(labels_one_hot):
    """Return per-class weights ``1 - class frequency`` for one-hot labels.

    Same formula as the earlier commented-out snippet, which only handled a
    4-D (N, H, W, C) array. Here the class axis is the last one and every
    other axis is averaged over, so any rank >= 2 works. Currently not used
    by the MSE training below; it is kept for weighted losses such as Dice.

    Parameters
    ----------
    labels_one_hot : array-like
        One-hot labels with classes on the last axis.

    Returns
    -------
    numpy.ndarray
        float64 array of length C. Rare classes get weights near 1 and
        frequent classes near 0.

    Raises
    ------
    ValueError
        If the input has fewer than 2 dimensions or contains no pixels.
    """
    labels = np.asarray(labels_one_hot, dtype=np.float64)
    if labels.ndim < 2:
        raise ValueError("expected at least 2 dims (..., classes), got shape %r" % (labels.shape,))
    if labels.size == 0:
        raise ValueError("cannot compute class weights of an empty label array")
    # mean over all non-class axes == per-class pixel count / number of pixels
    pixel_axes = tuple(range(labels.ndim - 1))
    return 1.0 - labels.mean(axis=pixel_axes)

In [None]:
# Adam with its arguments spelled out explicitly (these equal the Keras defaults).
optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.001,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
)
# Training objective: mean-squared error on the model output.
# Alternative kept for reference: sm.losses.DiceLoss(class_weights=class_weights)
loss = tf.keras.losses.MeanSquaredError()

model = MODEL
model.compile(loss=loss, optimizer=optimizer)

In [None]:
# Train on the patch generator; validate on batches drawn from the test generator.
# NOTE(review): validation_steps is hard-coded to 32 while the config declares
# VALID_PER_EPOCHS = 0 — confirm which is intended.
model.fit(
    train,
    epochs=EPOCHS,
    steps_per_epoch=TRAIN_PER_EPOCHS,
    validation_data=test,
    validation_steps=32,
)

In [None]:
# Persist the trained weights as <OUTPUT_FOLDER>/<EXP_NAME>.h5
model.save_weights(f"{OUTPUT_FOLDER}/{EXP_NAME}.h5")

In [None]:
# Final evaluation over batches drawn from the test generator
# (four times as many steps as one training epoch).
model.evaluate(
    test,
    steps=TRAIN_PER_EPOCHS * 4,
)