In [1]:
# reference: https://github.com/ellisdg/3DUnetCNN/
# Reload imported project modules (e.g. model.data, model.model) before every cell
# executes, so edits to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

config = {}

# --- Paths ----------------------------------------------------------------------
# Directory passed to get_callbacks(); a later cell reloads checkpoints from
# <cwd>/model/weight/... . os.path.join avoids hard-coding the '/' separator.
config["weights_file"] = os.path.join(os.getcwd(), "model", "weight")

# --- Data / patching ------------------------------------------------------------
config["patch_size"] = (64, 64, 64)  # switch to None to train on the whole image
config["patch_gap"] = 16  # forwarded to Data.prekfold / DataGenerator; presumably the patch stride — TODO confirm
config["batch_size"] = 2
config["kfold"] = 5  # number of cross-validation folds

# --- Model ----------------------------------------------------------------------
# One input channel with free spatial dimensions; presumably channels-first
# (channels, x, y, z) — TODO confirm against unet_model_3d.
config["input_shape"] = (1, None, None, None)
config["depth"] = 4  # depth of layers for V/Unet
config["n_base_filters"] = 32
config["pool_size"] = (2, 2, 2)  # pool size for the max pooling operations
config["deconvolution"] = True  # if False, will use upsampling instead of deconvolution

# --- Optimisation -----------------------------------------------------------------
config["patience"] = 10  # learning rate will be reduced after this many epochs if the validation loss is not improving
config["early_stop"] = 10  # training will be stopped after this many epochs without the validation loss improving
# NOTE(review): patience == early_stop, so (assuming get_callbacks maps these to
# LR-on-plateau and early-stopping callbacks) the reduced learning rate would never
# be used before training stops. With n_epochs == 10, neither is likely to trigger
# at all. Consider early_stop > patience.
config["initial_learning_rate"] = 0.00001
config["learning_rate_drop"] = 0.5  # factor by which the learning rate will be reduced
config["n_epochs"] = 10

In [3]:
from model.data import *
from model.model import *

# Load the raw data; the configured patch size is forwarded to the loader.
d = Data()
d.load_data(config["patch_size"])

# Prepare data for training (patching + k-fold assignment, judging by the
# arguments — confirm in model/data.py). The two returned counts are used
# below as steps_per_epoch and validation_steps.
train_num, valid_num = d.prekfold(
    config["patch_size"],
    config["patch_gap"],
    config["batch_size"],
    config["kfold"],
)


def _make_generator(is_training):
    """Build a DataGenerator over the fold layout prepared by ``d.prekfold``.

    The training and validation generators share every argument; only the
    trailing flag differs (True for training, False for validation).
    """
    return DataGenerator(d.data, d.patch_index, d.kfold, d.batch_size,
                         d.patch_size, d.patch_gap, d.valid_index, is_training)


train_generator = _make_generator(True)
valid_generator = _make_generator(False)

Using TensorFlow backend.


In [4]:
# Buffers filled by the (currently commented-out) evaluation cells further down.
result = []
target = []
image = []


def train(config, data, train_generator, valid_generator, train_num, valid_num, max_folds=1):
    """Train a freshly built 3D U-Net on successive cross-validation folds.

    Args:
        config: hyper-parameter dict from the configuration cell. Reads
            "input_shape", "pool_size", "initial_learning_rate", "deconvolution",
            "depth", "n_base_filters", "weights_file", "learning_rate_drop",
            "patience", "early_stop" and "n_epochs".
        data: prepared data object; only ``data.kfold`` (number of folds) is read.
        train_generator, valid_generator: generators whose ``set_index(i)`` is
            called to select fold ``i`` before training on it.
        train_num, valid_num: passed to Keras as ``steps_per_epoch`` and
            ``validation_steps``.
        max_folds: upper bound on the number of folds to train. The default of 1
            trains fold 0 only. Pass ``None`` to run all ``data.kfold`` folds.

    Returns:
        list of the trained models, one per fold that was run.
    """
    print(train_num, valid_num)

    n_folds = data.kfold if max_folds is None else min(max_folds, data.kfold)

    models = []
    for i in range(n_folds):
        print('-' * 100)
        print("Fold:", i)

        # Point both generators at fold i.
        train_generator.set_index(i)
        valid_generator.set_index(i)

        # Build a new model for every fold rather than continuing from the previous one.
        model = unet_model_3d(input_shape=config["input_shape"],
                              pool_size=config["pool_size"],
                              initial_learning_rate=config["initial_learning_rate"],
                              deconvolution=config["deconvolution"],
                              depth=config["depth"],
                              n_base_filters=config["n_base_filters"])

        # print(model.summary())

        # NOTE(review): every fold receives the same config["weights_file"]. If
        # get_callbacks() checkpoints there, later folds may overwrite earlier ones.
        callbacks = get_callbacks(config["weights_file"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  learning_rate_drop=config["learning_rate_drop"],
                                  learning_rate_patience=config["patience"],
                                  early_stopping_patience=config["early_stop"])

        # NOTE(review): train_num/valid_num come straight from Data.prekfold(), which
        # also receives batch_size. Confirm they are batch counts, not patch counts;
        # otherwise each epoch runs batch_size times longer than intended.
        model.fit_generator(generator=train_generator,
                            steps_per_epoch=train_num,
                            epochs=config["n_epochs"],
                            validation_data=valid_generator,
                            validation_steps=valid_num,
                            callbacks=callbacks,
                            workers=2,
                            verbose=1)
        models.append(model)
    return models


trained_models = train(config, d, train_generator, valid_generator, train_num, valid_num)

26496 6624
----------------------------------------------------------------------------------------------------
Fold: 0
Epoch 1/10
  323/26496 [..............................] - ETA: 3:59:19 - loss: 0.9359 - dice_coefficient: 0.0641

KeyboardInterrupt: 

In [None]:
# import ipywidgets as widgets
# from ipywidgets import interact, interactive

# def show_image(images):
#     def show_frame(id):
#         length = len(images)
#         for i in range(length):
#             plt.subplot(1, length, i+1)
#             plt.imshow(images[i][0, 0, id, :, :], cmap='gray')
#     interact(show_frame, id=widgets.IntSlider(min=0, max=images[0].shape[2]-1, step=1, value=images[0].shape[2]/2))

# sel = next(iter(d.valid_index))
# show_image([result[0], image[0][None, :, :, :, :], target[0][None, :, :, :, :]])


In [None]:
# dice = 0
# for i in range(len(result)):
#     dice += dice_coefficient(target[i], result[i])

# print("dice:", dice / len(result))
# print("loss:", 1 - dice / len(result)) 

# print(dice_coefficient(target, result))
# print(np.array(result).shape)

In [None]:

# model = unet_model_3d(input_shape=config["input_shape"],
#                               pool_size=config["pool_size"],
#                               initial_learning_rate=config["initial_learning_rate"],
#                               deconvolution=config["deconvolution"],
#                               depth=config["depth"],
#                               n_base_filters=config["n_base_filters"])

# model.load_weights(os.getcwd() + '/model/weight/weights-10--1.00.hdf5.h5')

# for j in d.valid_index:
#     valid = d.data[j][0] # fold = 0
#     shape = valid[0].shape
# #             x = shape[0] // 2
# #             y = shape[1] // 2
# #             z = shape[2] // 2
# #             input = np.array([np.expand_dims(valid[0][x : x + patch_size, y : y + patch_size, z : z + patch_size], axis=0)])
# #             result.append(model.predict(input))
# #             image.append(input)
# #             target.append([np.expand_dims(valid[1][x : x + patch_size, y : y + patch_size, z : z + patch_size], axis=0)])

#     input = np.array([np.expand_dims(ndimage.zoom(valid[0], (32/shape[0], 32/shape[1], 32/shape[2])), axis=0)])
#     target.append(np.expand_dims(ndimage.zoom(valid[1], (32/shape[0], 32/shape[1], 32/shape[2])), axis=0))
#     image.append(input)
#     result.append(model.predict(input))
