In [1]:
import itertools
import os

import numpy as np

# Experiment configuration shared by the data-loading, model and evaluation cells.
config = {
    # Checkpoint directory; os.path.join keeps the path separator correct on every OS
    # (the previous string concatenation hard-coded '/').
    "weights_file": os.path.join(os.getcwd(), "model", "weight"),
    "patch_size": (64, 64, 64),  # switch to None to train on the whole image
    "patch_gap": 16,  # passed to Data.prekfold together with patch_size
    "batch_size": 10,
    "kfold": 5,  # k for the k-fold split performed by Data.prekfold

    # Architecture arguments forwarded to unet_model_3d.
    "input_shape": (1, None, None, None),
    "depth": 4,  # depth of layers for V/Unet
    "n_base_filters": 32,
    "pool_size": (2, 2, 2),  # pool size for the max pooling operations
    "deconvolution": True,  # if False, will use upsampling instead of deconvolution

    # Optimisation / training schedule.
    "patience": 10,  # learning rate will be reduced after this many epochs if the validation loss is not improving
    "early_stop": 10,  # training will be stopped after this many epochs without the validation loss improving
    "initial_learning_rate": 0.00001,
    "learning_rate_drop": 0.5,  # factor by which the learning rate will be reduced
    "n_epochs": 1,
}

In [2]:
from model.data import *
from model.model import *

# Load the raw dataset into the Data helper.
d = Data()
d.load_data()

# Prepare patch / fold bookkeeping for training and validation via Data.prekfold.
# NOTE(review): train_num / valid_num look like item counts - confirm against Data.prekfold.
fold_settings = [config[key] for key in ("patch_size", "patch_gap", "batch_size", "kfold")]
train_num, valid_num = d.prekfold(*fold_settings)

Using TensorFlow backend.


In [3]:

# Build the 3D U-Net from the architecture hyper-parameters in `config`.
model = unet_model_3d(
    input_shape=config["input_shape"],
    pool_size=config["pool_size"],
    initial_learning_rate=config["initial_learning_rate"],
    deconvolution=config["deconvolution"],
    depth=config["depth"],
    n_base_filters=config["n_base_filters"],
)

# Restore a saved checkpoint. The directory comes from config["weights_file"] rather
# than being re-derived from the working directory, so both places stay in sync.
WEIGHTS_CHECKPOINT = "weights-01-0.00.hdf5"
model.load_weights(os.path.join(config["weights_file"], WEIGHTS_CHECKPOINT))


In [None]:
# TODO: need a way to merge all patches for testing the resulting weights.
# For the first entry of d.valid_index, crop every patch of the sample selected by
# `fold_index`, keep the image / target crops, and run the model on each image crop
# so the pieces can be stitched back together in the next cell.
image = []
target = []
result = []
indices = None
shape = None
fold_index = 0
for key in d.valid_index:
    sample = d.valid_index[key][fold_index]
    sample_pair = d.data[key][sample]
    volume = sample_pair[0]
    indices = d.patch_index[key][sample]
    shape = volume.shape
    for k in range(d.patch_index[key].shape[1]):
        corner = indices[k]
        # Patch-sized window starting at `corner` along the first three axes.
        window = tuple(slice(corner[axis], corner[axis] + d.patch_size[axis]) for axis in range(3))
        image_crop = np.expand_dims(volume[window], axis=0)
        target_crop = np.expand_dims(sample_pair[1][window], axis=0)
        image.append(image_crop)
        target.append(target_crop)
        # Prepend a batch axis before predicting on the single crop.
        result.append(model.predict([image_crop[None, :]]))
    # Only the first entry is processed.
    break


In [None]:
def reconstruct_from_patches(patches, patch_indices, data_shape, default_value=0):
    """
    Reconstruct an array of ``data_shape`` from patches and their corner indices.

    Voxels covered by several patches receive the mean of the overlapping patch
    values; voxels covered by no patch keep ``default_value``.

    :param patches: Iterable of array-likes. Each patch must contain exactly as many
        elements as the region it maps onto; it is reshaped (C order) onto that region,
        so extra singleton axes such as a (1, 1, X, Y, Z) prediction are accepted.
    :param patch_indices: Iterable of corner indices, one per patch. Entries 0-2 give
        the start of the patch along the last three axes of ``data_shape``.
    :param data_shape: Shape of the array from which the patches were extracted.
    :param default_value: Value for voxels not covered by any patch.
    :return: float numpy array of shape ``data_shape``.
    :raises ValueError: if a patch's size does not match the region it maps onto
        (e.g. a patch that runs past the array boundary).
    """
    # Accumulate sums and hit counts with slice views instead of allocating two
    # full-volume boolean/float arrays per patch; mean = sum / count at the end.
    total = np.zeros(data_shape, dtype=float)
    count = np.zeros(data_shape, dtype=int)  # builtin int: np.int was removed in NumPy 1.24
    for patch, index in zip(patches, patch_indices):
        patch = np.asarray(patch)
        # Patches address the last three axes; leading axes are covered entirely.
        region = (
            Ellipsis,
            slice(index[0], index[0] + patch.shape[-3]),
            slice(index[1], index[1] + patch.shape[-2]),
            slice(index[2], index[2] + patch.shape[-1]),
        )
        region_shape = total[region].shape
        total[region] += patch.reshape(region_shape)
        count[region] += 1

    data = np.full(data_shape, default_value, dtype=float)
    covered = count > 0
    data[covered] = total[covered] / count[covered]
    return data


# Stitch the image crops, target crops and per-patch predictions back into volumes
# shaped like the last three axes of the source sample.
merge_image, merge_target, merge_result = (
    reconstruct_from_patches(patch_list, indices, shape[-3:])
    for patch_list in (image, target, result)
)

In [None]:
# Display the stitched image, ground truth and prediction side by side.
merged_panels = [merge_image, merge_target, merge_result]
d.show_image(merged_panels)

In [4]:
# Collect a small preview of validation batches (fold 0) plus the model's
# predictions for visual inspection in the next cell.
N_PREVIEW_BATCHES = 20

image = []
target = []
result = []

# islice pulls exactly N_PREVIEW_BATCHES batches; the previous counter-based loop
# drew one extra batch from the generator before breaking and then discarded it.
for img, tar in itertools.islice(d.valid_generator(0), N_PREVIEW_BATCHES):
    image.append(img)
    target.append(tar)
    result.append(model.predict(img))


In [7]:
sel = 13  # which preview batch to inspect
sel_image, sel_target, sel_result = image[sel], target[sel], result[sel]

# NOTE(review): `> 0` keeps every strictly positive score, including the ~1e-38
# values visible in the captured output, so the mask is nearly "anything not exactly
# zero". If the model outputs probabilities, a cutoff such as 0.5 is the usual
# choice - confirm the intended threshold.
d.show_image([sel_image, sel_target, sel_result > 0])
print("target:", np.mean(sel_target))
print("result:", np.mean(sel_result))
print(dice_coefficient(sel_target, sel_result))

print(np.unique(sel_result))

interactive(children=(IntSlider(value=32, description='id', max=63), Output()), _dom_classes=('widget-interact…

target: 0.0
result: 0.00013920286
0.026672928987640735
[0.0000000e+00 1.1769894e-38 1.1810282e-38 ... 3.5310113e-01 3.5785753e-01
 3.6344498e-01]
