In [None]:
import sys
import os
import datetime
import time
import numpy as np
import tensorflow as tf
import skimage
import sklearn
import sklearn.metrics
import sklearn.utils

import data.patch2D
import data.patch3D

import models.layers.morphology
import models.custom_backbone
import models.custom_architecture
import tism

import metrics.connected_components
import metrics.distance_contour

from utils.clean_string import clean_string

In [None]:
# Parameters come from the command line when run as a script. In an interactive kernel
# a default configuration is used instead (swap in one of the alternatives below);
# without it `params` would be undefined and the next cell would fail.
if 'ipykernel' in sys.modules:
    verbose = 1
    params = ["MODEL_ARCHITECTURE", "2", "UNET", "32", "(256, 256, 1)", "16", "MSE_DT_20_20", "MitoEM-H"]
    # Alternative configurations:
    # params = ["MODEL_ARCHITECTURE", "2", "UNET_MM_ALPHA", "32", "(256, 256, 1)", "16", "MSE_DT_20_20", "MitoEM-H"]
    # params = ["MODEL_ARCHITECTURE", "2", "UNET_MM_GAMMA", "32", "(256, 256, 1)", "16", "MSE_DT_20_20", "MitoEM-H"]
    # params = ["MODEL_ARCHITECTURE", "2", "UNET_MM_BETA", "32", "(256, 256, 1)", "16", "MSE_DT_20_20", "MitoEM-H"]
    # params = ["REU_IGBMC_9C", "3", "UNET4", "32", "(16, 128, 128, 1)", "8", "MSE_DT", "LW4_40_9"]
    # params = ["REU_IGBMC_9C", "2", "UNET", "32", "(128, 128, 1)", "64", "MSE_DT", "LW4_40_9"]
    # params = ["MODEL_ARCHITECTURE", "2", "UNETMULT", "32", "(256, 256, 1)", "16", "MSE_DT_20_20", "MitoEM-H"]
    # params = ["MODEL_ARCHITECTURE", "3", "UNETMULT", "32", "(32,256,256,1)", "1", "MSE_DT_20_20", "MitoEM"]
else:
    verbose = 0
    params = sys.argv[1:]

### Experiment setup

In [None]:
# Experiment parameters, positional:
#   [str_expname, int_model_dim, str_modelname, int_model_param,
#    tuple_input_size, int_batch_size, str_lossname, str_dataset]
# e.g. ["MODEL_ARCHITECTURE", "2", "UNET", "32", "(256, 256, 1)", "16", "MSE_DT_20_20", "MitoEM-H"]
N_PARAMS = 8
if len(params) != N_PARAMS:
    print(f"error: number of param (expected {N_PARAMS}, got {len(params)})", flush=True)
    # sys.exit, not the site-module `exit`, which is meant for the interactive shell only
    sys.exit(1)

EXP_NAME = str(params[0])
MODEL_DIM = int(params[1])
MODEL_NAME = str(params[2])
MODEL_PARAM = int(params[3])
# "(256, 256, 1)" -> (256, 256, 1); spaces around the numbers are accepted by int()
PATCH_SIZE = tuple(map(int, params[4].replace("(", "").replace(")", "").split(',')))
BATCH_SIZE = int(params[5])
LOSS = str(params[6])
DATASET = str(params[7])

if MODEL_DIM not in (2, 3):
    print("error: model dimension must be equal to 2 or 3", flush=True)
    sys.exit(1)

# Fixed params
OUTPUT_CLASSES = 2

# Distance-transform ("DT") losses use a tanh output thresholded at 0;
# every other loss uses a sigmoid output thresholded at 0.5.
OUTPUT_ACT = 'sigmoid'
BINARY_THRESHOLD = 0.5
if "DT" in LOSS:
    OUTPUT_ACT = 'tanh'
    BINARY_THRESHOLD = 0

EPOCHS = 750
TRAIN_PER_EPOCHS = 300   # steps_per_epoch
VALID_PER_EPOCHS = 100   # validation_steps
EARLY_PATIENCE = 20      # only used by the (disabled) EarlyStopping callback
REDUCE_PATIENCE = 10     # ReduceLROnPlateau patience

In [None]:
# Model construction. Spatial input dims are None so the network is not tied to PATCH_SIZE.
if MODEL_DIM == 3:
    input_shape = (None, None, None, 1)
else:
    input_shape = (None, None, 1)

# Encoder and decoder both start from a VGG backbone parameterised by MODEL_PARAM
# (initial_block_depth).
be = tism.backbone.VGG(initial_block_depth=MODEL_PARAM, initial_block_length=2, batch_normalization=True)
bd = tism.backbone.VGG(initial_block_depth=MODEL_PARAM, initial_block_length=2, batch_normalization=True)

# Optional backbone wrappers: *_RES wrap both sides in a ResBlock,
# the UNET_MM_* variants wrap only the encoder.
if MODEL_NAME in ("UNET_RES", "LINKNET_RES"):
    be = tism.backbone.ResBlock(backbone=be)
    bd = tism.backbone.ResBlock(backbone=bd)
elif MODEL_NAME == "UNET_MM_ALPHA":
    be = models.custom_backbone.MM_Alpha(backbone=be)
elif MODEL_NAME == "UNET_MM_BETA":
    be = models.custom_backbone.MM_Beta(backbone=be)
elif MODEL_NAME == "UNET_MM_GAMMA":
    be = models.custom_backbone.MM_Gamma(backbone=be)

# Architecture family and depth for each supported model name
if MODEL_NAME in ("UNET", "UNET_RES", "UNET_MM_ALPHA", "UNET_MM_BETA", "UNET_MM_GAMMA"):
    architecture_cls, model_depth = tism.architecture.UNet, 5
elif MODEL_NAME == "UNET4":
    architecture_cls, model_depth = tism.architecture.UNet, 4
elif MODEL_NAME in ("LINKNET", "LINKNET_RES"):
    architecture_cls, model_depth = tism.architecture.LinkNet, 5
elif MODEL_NAME == "UNETMULT":
    architecture_cls, model_depth = models.custom_architecture.UNetMultiply, 5
else:
    print("error: model does not exist", flush=True)
    sys.exit(1)

MODEL = tism.model.get(architecture=architecture_cls(input_shape=input_shape,
                                                     depth=model_depth,
                                                     output_classes=OUTPUT_CLASSES,
                                                     output_activation=OUTPUT_ACT,
                                                     op_dim=MODEL_DIM,
                                                     dropout=0.50),
                       backbone_encoder=be,
                       backbone_decoder=bd)

print(EXP_NAME, str(MODEL_DIM) + "D", MODEL_NAME, MODEL_PARAM, PATCH_SIZE, BATCH_SIZE, LOSS, DATASET, flush=True)

### Experiment run

In [None]:
def _dt_params(loss_name, default=(20, 20)):
    """Return the two integer arguments for ``D.load_dt`` encoded in a DT loss name.

    A name with four ``_``-separated fields such as ``MSE_DT_20_20`` yields ``(20, 20)``;
    any other name (e.g. plain ``MSE_DT``) falls back to ``default``.
    """
    fields = loss_name.split("_")
    if len(fields) == 4:
        return int(fields[2]), int(fields[3])
    return default


# Dataset module D. Every MitoEM variant switches to its *_dt module whenever the loss is a
# distance-transform loss ("DT" in LOSS), matching the tanh/threshold-0 setup chosen for
# such losses, and passes the loss-name parameters to load_dt.
if DATASET == "I3":
    import data.datasets.I3 as D
elif DATASET == "LW4":
    import data.datasets.LW4 as D
elif DATASET == "LW4_40_9":
    import data.datasets.LW4_40_9 as D
elif DATASET == "MitoEM":
    if "DT" in LOSS:
        import data.datasets.MitoEM_dt as D
        D.load_dt(*_dt_params(LOSS))
    else:
        import data.datasets.MitoEM as D
elif DATASET == "MitoEM-H":
    if "DT" in LOSS:
        import data.datasets.MitoEMH_dt as D
        D.load_dt(*_dt_params(LOSS))
    else:
        import data.datasets.MitoEMH as D
elif DATASET == "MitoEM-R":
    if "DT" in LOSS:
        import data.datasets.MitoEMR_dt as D
        D.load_dt(*_dt_params(LOSS))
    else:
        import data.datasets.MitoEMR as D
else:
    print("error: dataset does not exist", flush=True)
    sys.exit(1)

# Output root depends on the host (local workstation vs. cluster)
if os.uname()[1] == 'lythandas':
    OUTPUT_FOLDER = "/home/cyril/Development/NeNISt/" + EXP_NAME
else:
    OUTPUT_FOLDER = "/b/home/miv/cmeyer/NeNISt/" + EXP_NAME

os.makedirs(OUTPUT_FOLDER, exist_ok=True)

In [None]:
# Run-unique suffix: day-of-year, hour, minute, second and microseconds (last two digits dropped)
dt = datetime.datetime.today().strftime("%j%H%M%S%f")[:-2]

_exp_name_parts = [f"{MODEL_DIM}D", str(MODEL_NAME), str(MODEL_PARAM), str(PATCH_SIZE),
                   str(BATCH_SIZE), LOSS, DATASET, dt]
CURRENT_EXP_NAME = clean_string("_".join(_exp_name_parts)).replace(",", "x")
print(CURRENT_EXP_NAME, flush=True)

In [None]:
# Train / validation arrays taken from the dataset module D selected above.
if DATASET == "LW4_40_9":
    # The test split serves as the validation set; labels come from the *_labels_dt
    # attributes (presumably distance-transform maps).
    train_image = D.train_image_normalized_f16
    valid_image = D.test_image_normalized_f16
    train_label = D.train_labels_dt
    valid_label = D.test_labels_dt

elif DATASET in ("MitoEM", "MitoEM-R", "MitoEM-H"):
    train_image = D.train_image_normalized_f16
    train_label = D.train_label
    valid_image = D.valid_image_normalized_f16
    valid_label = D.valid_label
    # A disabled variant handled the case where the X and Z axes are exchanged
    # (PATCH_SIZE[-1] != 1) by applying np.moveaxis(arr, 0, 2) to all four arrays.
else:
    raise NotImplementedError(f"no train/valid arrays defined for dataset {DATASET!r}")

In [None]:
# Class weights passed to the patch generators -- currently disabled.
# For reference, a one-off run of
#   sklearn.utils.class_weight.compute_class_weight('balanced', classes=[0, 1],
#                                                   y=train_label[0:10].flatten())
# printed "BCE WEIGHTS: [0.52861168 9.23769132]"; the rounded np.array([0.5, 10.0])
# was the alternative setting considered for LOSS == "BCE".
weights = None

In [None]:
# Augmented patch-batch generators for training and validation.
def _make_patch_generator(image, label):
    """Return an augmented patch-batch generator over ``(image, label)``.

    The generator family follows the patch rank (``len(PATCH_SIZE) == 3`` -> data.patch2D,
    ``== 4`` -> data.patch3D) and OUTPUT_CLASSES (> 2 -> multi-class generator, otherwise
    the binary ``*_bin`` variant). All PATCH_SIZE entries except the last are passed as the
    patch dimensions (the last one is presumably the channel count).
    """
    patch_dims = PATCH_SIZE[:-1]
    if OUTPUT_CLASSES > 2:
        if len(PATCH_SIZE) == 3:
            generator = data.patch2D.gen_patches_batch_augmented_2d
        else:
            generator = data.patch3D.gen_patches_batch_augmented_3d
        return generator(*patch_dims, OUTPUT_CLASSES, image, label, batch_size=BATCH_SIZE, weights=weights)

    if len(PATCH_SIZE) == 3:
        generator = data.patch2D.gen_patches_batch_augmented_2d_bin
    else:
        generator = data.patch3D.gen_patches_batch_augmented_3d_bin
    return generator(*patch_dims, image, label, batch_size=BATCH_SIZE, weights=weights)


if len(PATCH_SIZE) not in (3, 4):
    print("error: patch size", flush=True)
    sys.exit(1)

train = _make_patch_generator(train_image, train_label)
valid = _make_patch_generator(valid_image, valid_label)

In [None]:
# losses
from tensorflow.keras import backend as K
def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """Smoothed Jaccard (IoU) distance, scaled by ``smooth``.

    Adapted from https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96.
    Sums are reduced over the last axis only.
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    total = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    union = total - overlap
    similarity = (overlap + smooth) / (union + smooth)
    return (1 - similarity) * smooth

def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient with squared terms in the denominator.

    Adapted from https://gist.github.com/wassname/7793e2058c5c9dacb5212c0ac0b18a8a.
    Sums are reduced over the last axis only.
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    denominator = K.sum(K.square(y_true), axis=-1) + K.sum(K.square(y_pred), axis=-1) + smooth
    return (2. * overlap + smooth) / denominator

def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the smoothed Dice coefficient of ``dice_coef``."""
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient

In [None]:
# Loss selection. The substring tests come first: any LOSS containing "MSE"
# (e.g. "MSE_DT_20_20") uses mean squared error, any containing "MAE_DT" mean absolute error.
if "MSE" in LOSS:
    loss = tf.keras.losses.MeanSquaredError()
elif "MAE_DT" in LOSS:
    loss = tf.keras.losses.MeanAbsoluteError()
elif LOSS == "DICE":
    loss = dice_coef_loss
elif LOSS == "JACCARD":
    loss = jaccard_distance_loss
# NOTE(review): the SM_* losses reference `sm` (segmentation_models), which this notebook
# never imports; these branches raise NameError unless `sm` was defined in the kernel elsewhere.
elif LOSS == "SM_DICE":
    loss = sm.losses.dice_loss
elif LOSS == "SM_JACCARD":
    loss = sm.losses.jaccard_loss
elif LOSS == "SM_BIN_FOCAL":
    loss = sm.losses.binary_focal_loss
elif LOSS == "BCE":
    loss = tf.keras.losses.BinaryCrossentropy()
else:
    # Stop here: without a loss the compile below cannot run.
    print("error: loss name", flush=True)
    sys.exit(1)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07)

model = MODEL
# Extra metrics, e.g. tf.keras.metrics.BinaryAccuracy(), can be added via metrics=[...]
model.compile(optimizer=optimizer, loss=loss)

In [None]:
# Callbacks: reduce the learning rate by a factor 0.5 when val_loss plateaus
# (patience REDUCE_PATIENCE), and keep on disk only the weights of the epoch with the
# lowest val_loss. An EarlyStopping callback (patience EARLY_PATIENCE) is not used.
reducelrplateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=REDUCE_PATIENCE,
)
checkpoint_path = f"{OUTPUT_FOLDER}/{CURRENT_EXP_NAME}.h5"
savebestmodel = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)

In [None]:
# Smoke-test budget when run interactively from a notebook kernel
if 'ipykernel' in sys.modules:
    EPOCHS = 10
    TRAIN_PER_EPOCHS = 64
    VALID_PER_EPOCHS = 6

# perf_counter is monotonic, so the measured duration is unaffected by system clock changes
t0 = time.perf_counter()
fit_history = model.fit(train, steps_per_epoch=TRAIN_PER_EPOCHS, epochs=EPOCHS,
                        validation_data=valid, validation_steps=VALID_PER_EPOCHS,
                        verbose=verbose, callbacks=[reducelrplateau, savebestmodel])
t1 = time.perf_counter()
train_time = int(t1 - t0)  # whole seconds; written to the results CSV

In [None]:
# if not savebestmodel
# checkpoint_path = OUTPUT_FOLDER + "/" + CURRENT_EXP_NAME + ".h5"
# model.save_weights(checkpoint_path)

# model.save_weights(OUTPUT_FOLDER + "/" + CURRENT_EXP_NAME + ".h5")
# model.evaluate(valid, steps=TRAIN_PER_EPOCHS)

In [None]:
# Restore the weights saved by the ModelCheckpoint callback (save_best_only=True on val_loss),
# so the cells below use the best-validation epoch rather than the last trained one.
# NOTE(review): presumably fails if no checkpoint file was ever written -- confirm.
model.load_weights(checkpoint_path)

In [None]:
# model.save_weights(checkpoint_path)
# model.save(OUTPUT_FOLDER + "/MODEL_" + CURRENT_EXP_NAME + ".h5")

In [None]:
# Dump the Keras history dict (per-epoch losses/metrics) as text; the file name records
# the number of epochs actually run.
n_epochs = len(fit_history.history['loss'])
f_history_name = OUTPUT_FOLDER + "/" + clean_string(CURRENT_EXP_NAME + "_history_" + str(n_epochs)) + ".txt"
with open(f_history_name, "w") as f:
    f.write(str(fit_history.history))

### Evaluation

In [None]:
# Results CSV: write the header now; the evaluation cell appends one row with matching columns.
f_results_name = OUTPUT_FOLDER + "/" + clean_string(CURRENT_EXP_NAME) + ".csv"
if len(PATCH_SIZE) == 3 and PATCH_SIZE[-1] == 1:
    # 2D: metrics for the single 2048 x 2048 window, plus voxel metrics for 384 x 384 tiling
    results_header = "modelname,train_time,ACC_2048,IOU_2048,F1_2048,MEAN_DIST_GD_2048,MEAN_DIST_PRED_2048,CC_TP_2048,CC_FN_2048,CC_UD_2048,ACC_384,IOU_384,F1_384\n"
elif len(PATCH_SIZE) in (3, 4):
    # 3D patches and the axis-exchanged 2D case share the 384-tiling columns
    results_header = "modelname,train_time,ACC_384,IOU_384,F1_384,MEAN_DIST_GD_384,MEAN_DIST_PRED_384,CC_TP_384,CC_FN_384,CC_UD_384\n"
else:
    results_header = ""
with open(f_results_name, "w") as f_results:
    f_results.write(results_header)

# Same as CURRENT_EXP_NAME without the timestamp suffix
modelname = clean_string(str(MODEL_DIM) + "D" + "_" + str(MODEL_NAME) + "_" + str(MODEL_PARAM) + "_" + str(PATCH_SIZE) + "_" + str(BATCH_SIZE) + "_" + LOSS + "_" + DATASET).replace(",","x")

In [None]:
# Evaluation crop taken from the validation volume: 50 slices starting at index 10 and a
# 1500 x 1500 window starting at 1250 on the two other axes.
test_image = valid_image
test_label = valid_label
if PATCH_SIZE[-1] == 1:
    # the 50-slice range is on the first axis
    test_section_label = test_label[10:10+50, 1250:1250+1500, 1250:1250+1500]
else:
    # trailing PATCH_SIZE entry is not 1: the 50-slice range is on the last axis
    test_section_label = test_label[1250:1250+1500, 1250:1250+1500, 10:10+50]

# Prediction buffer filled by the evaluation cell. It is float32 instead of the label dtype:
# with an integer/boolean label array, assigning model outputs (e.g. sigmoid probabilities)
# would truncate them before thresholding at BINARY_THRESHOLD.
test_section_pred = np.zeros(test_section_label.shape, dtype=np.float32)

In [None]:
# Proof (disabled sanity checks): the two string blocks below are never executed.
# Each one fills test_section_pred by copying *label* windows through the same
# windowing/cropping logic the evaluation cell uses, then prints whether the result
# equals test_section_label exactly (i.e. the tiles land in the right place with no gaps).
# To run one, remove its surrounding quotes.
# First block: 2D 2048-window check, 2D 384-tile check, and 3D 384-tile check.
'''
for z in range(50):
    p = np.expand_dims(np.expand_dims(test_label[z+10, 1000:1000+2048, 1000:1000+2048], -1), 0)
    p = p[0,250:250+1500,250:250+1500,0]
    test_section_pred[z] = p
print((test_section_pred == test_section_label).all())

for z in range(50):
    for y_ in range(4):
        for x_ in range(4):
            x = x_*384
            y = y_*384
            pad = 18
            p = np.expand_dims(np.expand_dims(test_label[z+10, 1250-pad+y:1250-pad+384+y, 1250-pad+x:1250-pad+384+x], -1), 0)
            p = p[0, :, :, 0]
            
            if x_ == 0:
                p = p[:, pad:]
            else:
                x = x-pad
            if x_ == 3:
                p = p[:, :-pad]
            if y_ == 0:
                p = p[pad:, :]
            else:
                y = y-pad
            if y_ == 3:
                p = p[:-pad, :]

            test_section_pred[z, y:y+p.shape[0], x:x+p.shape[1]] = p[:, :]

print((test_section_pred == test_section_label).all())


for y_ in range(4):
    for x_ in range(4):
        x = x_*384
        y = y_*384
        pad = 18
        p = np.expand_dims(np.expand_dims(test_label[0:64, 1250-pad+y:1250-pad+384+y, 1250-pad+x:1250-pad+384+x], -1), 0)
        
        p = p[0,10:10+50, :, :, 0]

        if x_ == 0:
            p = p[:, :, pad:]
        else:
            x = x-pad
        if x_ == 3:
            p = p[:, :, :-pad]
        if y_ == 0:
            p = p[:, pad:, :]
        else:
            y = y-pad
        if y_ == 3:
            p = p[:, :-pad, :]
        
        test_section_pred[:, y:y+p.shape[1], x:x+p.shape[2]] = p[:, :, :]
print((test_section_pred == test_section_label).all())
'''
# Second block: 384-tile check for the layout with the 50-slice range on the last axis.
'''
for y_ in range(4):
    for x_ in range(4):
        x = x_*384
        y = y_*384
        pad = 18
        p = np.expand_dims(test_label[1250-pad+x:1250-pad+384+x, 1250-pad+y:1250-pad+384+y, 0:64], 0)
        
        p = p[0, :, :, 10:10+50]

        if x_ == 0:
            p = p[pad:, :, :]
        else:
            x = x-pad
        if x_ == 3:
            p = p[:-pad, :, :]
        if y_ == 0:
            p = p[:, pad:, :]
        else:
            y = y-pad
        if y_ == 3:
            p = p[:, :-pad, :]
            
        test_section_pred[x:x+p.shape[0], y:y+p.shape[1], :] = p[:, :, :]
print((test_section_pred == test_section_label).all())
'''

In [None]:
# Evaluation on the fixed crop test_section_label: predictions are written into
# test_section_pred, binarised at BINARY_THRESHOLD and scored, and one row is appended to
# the results CSV (columns match the header written when f_results_name was created).

# Crop geometry (must agree with the cell that builds test_section_label)
EVAL_Z0, EVAL_NZ = 10, 50             # evaluated slices 10 .. 59
EVAL_Z_CONTEXT = 64                   # slices 0..63 are fed to the model, 10..59 are kept
EVAL_YX0, CROP_SIZE = 1250, 1500      # in-plane crop [1250, 2750) on both axes
FULL_YX0, FULL_SIZE = 1000, 2048      # single large 2D window containing the crop
TILE, TILE_PAD, N_TILES = 384, 18, 4  # 4 tiles of 384 px per axis, see _tile_keep_and_dst
CC_THRESHOLD = 0.75                   # passed unchanged to connected_components_detection


def _tile_src_start(t):
    """Volume index where tile ``t``'s input window starts on one in-plane axis.

    Windows are shifted back by TILE_PAD, so the first tile reads TILE_PAD px of context
    before the crop and the last tile reads TILE_PAD px past its end.
    """
    return EVAL_YX0 - TILE_PAD + t * TILE


def _tile_keep_and_dst(t):
    """Return (slice of tile ``t``'s prediction to keep, start index of that part in the crop).

    The first tile drops its leading TILE_PAD px and the last tile its trailing TILE_PAD px,
    i.e. the out-of-crop context read via ``_tile_src_start``; the kept parts are contiguous.
    """
    keep = slice(TILE_PAD if t == 0 else None, -TILE_PAD if t == N_TILES - 1 else None)
    dst = 0 if t == 0 else t * TILE - TILE_PAD
    return keep, dst


def _predict_2d_full():
    """2D: predict each slice from one FULL_SIZE x FULL_SIZE window and keep its central crop."""
    inner = EVAL_YX0 - FULL_YX0  # offset of the crop inside the window
    for z in range(EVAL_NZ):
        window = test_image[EVAL_Z0 + z, FULL_YX0:FULL_YX0 + FULL_SIZE, FULL_YX0:FULL_YX0 + FULL_SIZE]
        out = model.predict(window[np.newaxis, :, :, np.newaxis])
        test_section_pred[z] = out[0, inner:inner + CROP_SIZE, inner:inner + CROP_SIZE, 0]


def _predict_2d_tiled():
    """2D: predict each slice as N_TILES x N_TILES tiles of TILE x TILE px."""
    for z in range(EVAL_NZ):
        for ty in range(N_TILES):
            y0 = _tile_src_start(ty)
            keep_y, dst_y = _tile_keep_and_dst(ty)
            for tx in range(N_TILES):
                x0 = _tile_src_start(tx)
                keep_x, dst_x = _tile_keep_and_dst(tx)
                tile = test_image[EVAL_Z0 + z, y0:y0 + TILE, x0:x0 + TILE]
                out = model.predict(tile[np.newaxis, :, :, np.newaxis])[0, :, :, 0]
                out = out[keep_y, keep_x]
                test_section_pred[z, dst_y:dst_y + out.shape[0], dst_x:dst_x + out.shape[1]] = out


def _predict_3d_tiled():
    """3D: predict EVAL_Z_CONTEXT-slice blocks tiled in-plane, keeping the evaluated slices."""
    for ty in range(N_TILES):
        y0 = _tile_src_start(ty)
        keep_y, dst_y = _tile_keep_and_dst(ty)
        for tx in range(N_TILES):
            x0 = _tile_src_start(tx)
            keep_x, dst_x = _tile_keep_and_dst(tx)
            block = test_image[0:EVAL_Z_CONTEXT, y0:y0 + TILE, x0:x0 + TILE]
            out = model.predict(block[np.newaxis, ..., np.newaxis])[0, EVAL_Z0:EVAL_Z0 + EVAL_NZ, :, :, 0]
            out = out[:, keep_y, keep_x]
            test_section_pred[:, dst_y:dst_y + out.shape[1], dst_x:dst_x + out.shape[2]] = out


def _predict_swapped_tiled():
    """2D with the slice range on the last axis: tiles over the first two axes.

    NOTE(review): input_shape is built with a single trailing channel while these windows
    carry EVAL_Z_CONTEXT values on the last axis, and the output slice EVAL_Z0:EVAL_Z0+EVAL_NZ
    is taken on an axis of size OUTPUT_CLASSES -- this path looks inconsistent; confirm.
    """
    for ty in range(N_TILES):
        y0 = _tile_src_start(ty)
        keep_y, dst_y = _tile_keep_and_dst(ty)
        for tx in range(N_TILES):
            x0 = _tile_src_start(tx)
            keep_x, dst_x = _tile_keep_and_dst(tx)
            # Feed the image (the original code fed test_label here, i.e. the ground truth)
            block = test_image[x0:x0 + TILE, y0:y0 + TILE, 0:EVAL_Z_CONTEXT]
            out = model.predict(block[np.newaxis])[0, :, :, EVAL_Z0:EVAL_Z0 + EVAL_NZ]
            out = out[keep_x, keep_y, :]
            test_section_pred[dst_x:dst_x + out.shape[0], dst_y:dst_y + out.shape[1], :] = out


def _voxel_scores(label, pred):
    """Voxel-wise (accuracy, IoU, F1) of the two volumes binarised at BINARY_THRESHOLD."""
    label_flat = (label > BINARY_THRESHOLD).flatten()
    pred_flat = (pred > BINARY_THRESHOLD).flatten()
    return (sklearn.metrics.accuracy_score(label_flat, pred_flat),
            sklearn.metrics.jaccard_score(label_flat, pred_flat),
            sklearn.metrics.f1_score(label_flat, pred_flat))


def _object_scores(label, pred):
    """Contour-distance and connected-component metrics of the binarised volumes.

    Returns (mean_dist_gd, mean_dist_pred, cc, cc_tp, cc_fn, cc_ud). Fresh masks are built
    for each metric call, as in the original code, in case a metric modifies its inputs.
    """
    dist = metrics.distance_contour.distance_contour_segmentation_3D(
        label > BINARY_THRESHOLD, pred > BINARY_THRESHOLD)
    cc, cc_tp, cc_fn, cc_ud = metrics.connected_components.connected_components_detection(
        label > BINARY_THRESHOLD, pred > BINARY_THRESHOLD, CC_THRESHOLD)
    return dist[0][0], dist[1][0], cc, cc_tp, cc_fn, cc_ud


def _append_result_row(values):
    """Append one comma-separated row (each value converted with str) to the results CSV."""
    with open(f_results_name, "a") as results_file:
        results_file.write(",".join(str(v) for v in values) + "\n")


if len(PATCH_SIZE) == 3 and PATCH_SIZE[-1] == 1:
    # 2D model: full 2048 x 2048 window first ...
    test_section_pred.fill(0)
    _predict_2d_full()
    acc_2048, iou_2048, f1_2048 = _voxel_scores(test_section_label, test_section_pred)
    (mean_dist_gd_2048, mean_dist_pred_2048,
     cc_2048, cc_tp_2048, cc_fn_2048, cc_ud_2048) = _object_scores(test_section_label, test_section_pred)

    # ... then 384 x 384 tiling (voxel metrics only)
    test_section_pred.fill(0)
    _predict_2d_tiled()
    acc_384, iou_384, f1_384 = _voxel_scores(test_section_label, test_section_pred)

    _append_result_row([modelname, train_time,
                        acc_2048, iou_2048, f1_2048,
                        mean_dist_gd_2048, mean_dist_pred_2048,
                        cc_tp_2048, cc_fn_2048, cc_ud_2048,
                        acc_384, iou_384, f1_384])

elif len(PATCH_SIZE) in (3, 4):
    test_section_pred.fill(0)
    if len(PATCH_SIZE) == 4:
        _predict_3d_tiled()
    else:
        _predict_swapped_tiled()

    acc_384, iou_384, f1_384 = _voxel_scores(test_section_label, test_section_pred)
    (mean_dist_gd_384, mean_dist_pred_384,
     cc_384, cc_tp_384, cc_fn_384, cc_ud_384) = _object_scores(test_section_label, test_section_pred)

    _append_result_row([modelname, train_time,
                        acc_384, iou_384, f1_384,
                        mean_dist_gd_384, mean_dist_pred_384,
                        cc_tp_384, cc_fn_384, cc_ud_384])

else:
    print("error: patch size", flush=True)
    sys.exit(1)