# KerasでU-Net  
[U-NetでPascal VOC 2012の画像をSemantic Segmentationする (TensorFlow)](https://qiita.com/tktktks10/items/0f551aea27d2f62ef708)
上記記事のTensorFlow実装をKerasで書き換えたもの

* keras == 2.0.4  
* tensorflow == 1.15.0  

In [1]:
# Side length in pixels that load_X / load_Y resize every image to (square output).
IMAGE_SIZE = 256

## モデル  

In [2]:
from keras.models import Model
from keras.layers import Input
from keras.layers.convolutional import Conv2D, ZeroPadding2D, Conv2DTranspose
from keras.layers.merge import concatenate
from keras.layers import LeakyReLU, BatchNormalization, Activation, Dropout

class UNet(object):
    """Encoder/decoder segmentation network with skip connections, built from Keras layers.

    The encoder applies eight 4x4, stride-2 convolutions, each halving the spatial
    size. The decoder applies eight 2x2, stride-2 transposed convolutions, each
    doubling it back. After each of the first seven decoder stages, the output is
    concatenated with the encoder output of the same size along the channel axis.
    The final layer is a sigmoid.

    Args:
        input_channel_count: number of channels of the input image.
        output_channel_count: number of channels of the output map.
        first_layer_filter_count: filter count N of the first convolution;
            deeper stages use 2N, 4N and 8N filters.
        input_image_size: height and width of the square input. Must be a
            positive multiple of 256 (2**8) so that all eight halvings and
            doublings line up with the skip connections. Defaults to 256.

    Raises:
        ValueError: if input_image_size is not a positive multiple of 256.
    """

    def __init__(self, input_channel_count, output_channel_count, first_layer_filter_count,
                 input_image_size=256):
        # Eight stride-2 encoder stages halve the side length eight times and the
        # decoder doubles it back, concatenating with the matching encoder output,
        # so the side length must be divisible by 2**8 = 256.
        if input_image_size <= 0 or input_image_size % 256 != 0:
            raise ValueError('input_image_size must be a positive multiple of 256, '
                             'got %r' % (input_image_size,))
        self.INPUT_IMAGE_SIZE = input_image_size
        # Last axis is channels: Input below is (S, S, C), so skip connections stack channels.
        self.CONCATENATE_AXIS = -1
        self.CONV_FILTER_SIZE = 4
        self.CONV_STRIDE = 2
        # Padding 1 + 4x4 kernel + stride 2 maps an even side length S to exactly S/2.
        self.CONV_PADDING = (1, 1)
        # 2x2 kernel + stride 2 (valid padding) maps a side length S to exactly 2S.
        self.DECONV_FILTER_SIZE = 2
        self.DECONV_STRIDE = 2

        # Shape comments use S = input_image_size and N = first_layer_filter_count.
        # (S x S x input_channel_count)
        inputs = Input((self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE, input_channel_count))

        # ---- Encoder ----
        # (S/2 x S/2 x N): the first stage is a bare conv (no activation / batch norm).
        enc1 = ZeroPadding2D(self.CONV_PADDING)(inputs)
        enc1 = Conv2D(first_layer_filter_count, self.CONV_FILTER_SIZE, strides=self.CONV_STRIDE)(enc1)

        # (S/4 x S/4 x 2N)
        filter_count = first_layer_filter_count*2
        enc2 = self._add_encoding_layer(filter_count, enc1)

        # (S/8 x S/8 x 4N)
        filter_count = first_layer_filter_count*4
        enc3 = self._add_encoding_layer(filter_count, enc2)

        # (S/16 x S/16 x 8N)
        filter_count = first_layer_filter_count*8
        enc4 = self._add_encoding_layer(filter_count, enc3)

        # (S/32 x S/32 x 8N)
        enc5 = self._add_encoding_layer(filter_count, enc4)

        # (S/64 x S/64 x 8N)
        enc6 = self._add_encoding_layer(filter_count, enc5)

        # (S/128 x S/128 x 8N)
        enc7 = self._add_encoding_layer(filter_count, enc6)

        # (S/256 x S/256 x 8N)
        enc8 = self._add_encoding_layer(filter_count, enc7)

        # ---- Decoder ----
        # Shapes below are the transposed-conv output before concatenation with the skip.
        # (S/128 x S/128 x 8N), with dropout
        dec1 = self._add_decoding_layer(filter_count, True, enc8)
        dec1 = concatenate([dec1, enc7], axis=self.CONCATENATE_AXIS)

        # (S/64 x S/64 x 8N), with dropout
        dec2 = self._add_decoding_layer(filter_count, True, dec1)
        dec2 = concatenate([dec2, enc6], axis=self.CONCATENATE_AXIS)

        # (S/32 x S/32 x 8N), with dropout
        dec3 = self._add_decoding_layer(filter_count, True, dec2)
        dec3 = concatenate([dec3, enc5], axis=self.CONCATENATE_AXIS)

        # (S/16 x S/16 x 8N)
        dec4 = self._add_decoding_layer(filter_count, False, dec3)
        dec4 = concatenate([dec4, enc4], axis=self.CONCATENATE_AXIS)

        # (S/8 x S/8 x 4N)
        filter_count = first_layer_filter_count*4
        dec5 = self._add_decoding_layer(filter_count, False, dec4)
        dec5 = concatenate([dec5, enc3], axis=self.CONCATENATE_AXIS)

        # (S/4 x S/4 x 2N)
        filter_count = first_layer_filter_count*2
        dec6 = self._add_decoding_layer(filter_count, False, dec5)
        dec6 = concatenate([dec6, enc2], axis=self.CONCATENATE_AXIS)

        # (S/2 x S/2 x N)
        filter_count = first_layer_filter_count
        dec7 = self._add_decoding_layer(filter_count, False, dec6)
        dec7 = concatenate([dec7, enc1], axis=self.CONCATENATE_AXIS)

        # (S x S x output_channel_count), sigmoid output
        dec8 = Activation(activation='relu')(dec7)
        dec8 = Conv2DTranspose(output_channel_count, self.DECONV_FILTER_SIZE, strides=self.DECONV_STRIDE)(dec8)
        dec8 = Activation(activation='sigmoid')(dec8)

        # Keras 2 functional-API keywords (the Keras 1 `input=`/`output=` names are legacy).
        self.UNET = Model(inputs=inputs, outputs=dec8)

    def _add_encoding_layer(self, filter_count, sequence):
        """Append LeakyReLU(0.2) -> zero-pad -> 4x4 stride-2 conv -> batch norm; halves the size."""
        new_sequence = LeakyReLU(0.2)(sequence)
        new_sequence = ZeroPadding2D(self.CONV_PADDING)(new_sequence)
        new_sequence = Conv2D(filter_count, self.CONV_FILTER_SIZE, strides=self.CONV_STRIDE)(new_sequence)
        new_sequence = BatchNormalization()(new_sequence)
        return new_sequence

    def _add_decoding_layer(self, filter_count, add_drop_layer, sequence):
        """Append ReLU -> 2x2 stride-2 transposed conv -> batch norm [-> Dropout(0.5)]; doubles the size."""
        new_sequence = Activation(activation='relu')(sequence)
        new_sequence = Conv2DTranspose(filter_count, self.DECONV_FILTER_SIZE, strides=self.DECONV_STRIDE,
                                       kernel_initializer='he_uniform')(new_sequence)
        new_sequence = BatchNormalization()(new_sequence)
        if add_drop_layer:
            new_sequence = Dropout(0.5)(new_sequence)
        return new_sequence

    def get_model(self):
        """Return the built (uncompiled) Keras Model."""
        return self.UNET
    

Using TensorFlow backend.


## 前処理関連の関数

In [3]:
# 値を-1から1に正規化する関数
def normalize_x(image):
    """Map pixel values from [0, 255] onto [-1, 1] (0 -> -1, 255 -> 1)."""
    return image / 127.5 - 1


# 値を0から1に正規化する関数
def normalize_y(image):
    """Scale mask values from [0, 255] onto [0, 1]."""
    return image / 255


# 値を0から255に戻す関数
def denormalize_y(image):
    """Inverse of normalize_y: scale values from [0, 1] back onto [0, 255]."""
    return image * 255


# インプット画像を読み込む関数
def load_X(folder_path):
    """Load every image in folder_path as one normalized float32 batch.

    Files are read in sorted-filename order with cv2.imread (3-channel, in
    OpenCV's BGR order), resized to IMAGE_SIZE x IMAGE_SIZE and passed through
    normalize_x (which maps 8-bit values onto [-1, 1]).

    Returns:
        (images, image_files): images is a float32 array of shape
        (len(image_files), IMAGE_SIZE, IMAGE_SIZE, 3); image_files is the
        sorted list of file names, aligned index-for-index with images.

    Raises:
        IOError: if a file in the folder cannot be decoded as an image
            (cv2.imread returned None), e.g. a stray non-image file.
    """
    import os, cv2
    import numpy as np

    image_files = sorted(os.listdir(folder_path))
    images = np.zeros((len(image_files), IMAGE_SIZE, IMAGE_SIZE, 3), np.float32)
    for i, image_file in enumerate(image_files):
        path = os.path.join(folder_path, image_file)
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising;
        # without this check the failure surfaces later as an opaque resize error.
        if image is None:
            raise IOError('could not read image: %s' % path)
        image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
        images[i] = normalize_x(image)
    return images, image_files


# ラベル画像を読み込む関数
def load_Y(folder_path):
    """Load every mask image in folder_path as one float32 batch scaled to [0, 1].

    Files are read in sorted-filename order as single-channel grayscale,
    resized to IMAGE_SIZE x IMAGE_SIZE, given a trailing channel axis and
    passed through normalize_y.

    Pairing with the inputs from load_X relies on the mask files sorting into
    the same order as the image files; names are not compared here.

    Returns:
        A float32 array of shape (num_files, IMAGE_SIZE, IMAGE_SIZE, 1).

    Raises:
        IOError: if a file in the folder cannot be decoded as an image
            (cv2.imread returned None).
    """
    import os, cv2
    import numpy as np

    image_files = sorted(os.listdir(folder_path))
    images = np.zeros((len(image_files), IMAGE_SIZE, IMAGE_SIZE, 1), np.float32)
    for i, image_file in enumerate(image_files):
        path = os.path.join(folder_path, image_file)
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise IOError('could not read mask image: %s' % path)
        # NOTE(review): cv2.resize defaults to bilinear interpolation, which blends
        # values along mask edges; cv2.INTER_NEAREST would keep the labels crisp.
        # Confirm which is intended before changing.
        image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
        # cv2.resize returns (H, W) for one channel; add the channel axis back.
        image = image[:, :, np.newaxis]
        images[i] = normalize_y(image)
    return images

## メイン処理

In [5]:
import os
import numpy as np
from keras.optimizers import Adam
import keras.backend as K
from keras.callbacks import ModelCheckpoint, EarlyStopping
#from unet import UNet

# ダイス係数を計算する関数
def dice_coef(y_true, y_pred, smooth=1.0):
    """Return the smoothed Dice coefficient between ground truth and prediction.

    Both tensors are flattened, so the score is computed over the whole batch
    at once rather than averaged per sample:

        dice = (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)

    `smooth` is added to the numerator as well as the denominator. With
    smoothing only in the denominator, an all-zero y_true gives a score of
    exactly 0 whatever y_pred is. A perfect "empty" prediction would then
    count as a total miss, and the gradient through it would be zero.

    Args:
        y_true: ground-truth tensor.
        y_pred: predicted tensor of the same shape.
        smooth: additive smoothing term (default 1.0).
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


# ロス関数
def dice_coef_loss(y_true, y_pred):
    """Dice loss: 1 minus dice_coef, so better overlap means lower loss."""
    score = dice_coef(y_true, y_pred)
    return 1.0 - score


# U-Netのトレーニングを実行する関数
def train_unet(batch_size=12, epochs=100, weights_path='unet_weights.h5'):
    """Train the U-Net on trainingData/image + trainingData/mask and save its weights.

    Args:
        batch_size: samples per gradient step. The default of 12 matches the
            original setting. A run with it on CPU ran out of memory (OOM
            allocating a [12, 128, 128, 64] tensor), so lower it on
            memory-constrained machines.
        epochs: number of training epochs.
        weights_path: file the trained weights are written to.

    Returns:
        The History object returned by model.fit.
    """
    X_train, _ = load_X(os.path.join('trainingData', 'image'))
    Y_train = load_Y(os.path.join('trainingData', 'mask'))

    # Input: 3-channel BGR images (as loaded by OpenCV).
    input_channel_count = 3
    # Output: 1-channel grayscale mask.
    output_channel_count = 1
    # Filter count of the first convolution (deeper stages use multiples of it).
    first_layer_filter_count = 64
    network = UNet(input_channel_count, output_channel_count, first_layer_filter_count)
    model = network.get_model()
    model.compile(loss=dice_coef_loss, optimizer=Adam(), metrics=[dice_coef])

    history = model.fit(X_train, Y_train, batch_size=batch_size, epochs=epochs, verbose=1)
    model.save_weights(weights_path)
    return history


# 学習後のU-Netによる予測を行う関数
def predict(batch_size=12, weights_path='unet_weights.h5'):
    """Run the trained U-Net over testData/image and write masks to prediction/.

    Each predicted mask is resized back to the size of its source image.
    It is written as prediction/prediction<i>.png, where i is the source
    file's index in sorted-filename order (as returned by load_X).

    Args:
        batch_size: batch size used for inference.
        weights_path: weights file produced by train_unet.

    Raises:
        IOError: if a source image cannot be read or an output file cannot
            be written.
    """
    import cv2

    image_dir = os.path.join('testData', 'image')
    output_dir = 'prediction'
    X_test, file_names = load_X(image_dir)

    input_channel_count = 3
    output_channel_count = 1
    first_layer_filter_count = 64
    network = UNet(input_channel_count, output_channel_count, first_layer_filter_count)
    model = network.get_model()
    model.load_weights(weights_path)
    Y_pred = model.predict(X_test, batch_size=batch_size)

    # cv2.imwrite neither creates missing directories nor raises on failure
    # (it just returns False), so create the folder and check each write.
    os.makedirs(output_dir, exist_ok=True)

    for i, (file_name, y) in enumerate(zip(file_names, Y_pred)):
        source_path = os.path.join(image_dir, file_name)
        img = cv2.imread(source_path)
        if img is None:
            raise IOError('could not read image: %s' % source_path)
        # cv2.resize takes (width, height).
        y = cv2.resize(y, (img.shape[1], img.shape[0]))
        # Convert the [0, 1] sigmoid output explicitly to 8-bit for PNG output.
        mask = np.clip(np.rint(denormalize_y(y)), 0, 255).astype(np.uint8)
        out_path = os.path.join(output_dir, 'prediction' + str(i) + '.png')
        if not cv2.imwrite(out_path, mask):
            raise IOError('could not write prediction: %s' % out_path)


if __name__ == '__main__':
    # Train first (train_unet saves the weights), then predict with those weights.
    for step in (train_unet, predict):
        step()



Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100

ResourceExhaustedError: OOM when allocating tensor with shape[12,128,128,64] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu
	 [[node gradients_1/concatenate_14/concat_grad/Slice (defined at C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\framework\ops.py:1748) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.


Original stack trace for 'gradients_1/concatenate_14/concat_grad/Slice':
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\traitlets\config\application.py", line 664, in launch_instance
    app.start()
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\ipykernel\kernelapp.py", line 597, in start
    self.io_loop.start()
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tornado\platform\asyncio.py", line 149, in start
    self.asyncio_loop.run_forever()
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\asyncio\base_events.py", line 541, in run_forever
    self._run_once()
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\asyncio\base_events.py", line 1786, in _run_once
    handle._run()
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\asyncio\events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tornado\ioloop.py", line 690, in <lambda>
    lambda f: self._run_callback(functools.partial(callback, future))
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tornado\ioloop.py", line 743, in _run_callback
    ret = callback()
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tornado\gen.py", line 787, in inner
    self.run()
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tornado\gen.py", line 748, in run
    yielded = self.gen.send(value)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\ipykernel\kernelbase.py", line 365, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tornado\gen.py", line 209, in wrapper
    yielded = next(result)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\ipykernel\kernelbase.py", line 268, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tornado\gen.py", line 209, in wrapper
    yielded = next(result)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\ipykernel\kernelbase.py", line 545, in execute_request
    user_expressions, allow_stdin,
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tornado\gen.py", line 209, in wrapper
    yielded = next(result)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\ipykernel\ipkernel.py", line 300, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\ipykernel\zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\IPython\core\interactiveshell.py", line 2867, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\IPython\core\interactiveshell.py", line 2895, in _run_cell
    return runner(coro)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\IPython\core\async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\IPython\core\interactiveshell.py", line 3072, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\IPython\core\interactiveshell.py", line 3263, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\IPython\core\interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-5-200b219201df>", line 65, in <module>
    train_unet()
  File "<ipython-input-5-200b219201df>", line 39, in train_unet
    history = model.fit(X_train, Y_train, batch_size=BATCH_SIZE, epochs=NUM_EPOCH, verbose=1)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\keras\engine\training.py", line 1481, in fit
    self._make_train_function()
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\keras\engine\training.py", line 1013, in _make_train_function
    self.total_loss)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\keras\optimizers.py", line 381, in get_updates
    grads = self.get_gradients(loss, params)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\keras\optimizers.py", line 47, in get_gradients
    grads = K.gradients(loss, params)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\keras\backend\tensorflow_backend.py", line 2264, in gradients
    return tf.gradients(loss, variables, colocate_gradients_with_ops=True)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\ops\gradients_impl.py", line 158, in gradients
    unconnected_gradients)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\ops\gradients_util.py", line 679, in _GradientsHelper
    lambda: grad_fn(op, *out_grads))
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\ops\gradients_util.py", line 350, in _MaybeCompile
    return grad_fn()  # Exit early
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\ops\gradients_util.py", line 679, in <lambda>
    lambda: grad_fn(op, *out_grads))
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\ops\array_grad.py", line 228, in _ConcatGradV2
    op, grad, start_value_index=0, end_value_index=-1, dim_index=-1)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\ops\array_grad.py", line 156, in _ConcatGradHelper
    out_grads.append(array_ops.slice(grad, begin, size))
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\ops\array_ops.py", line 855, in slice
    return gen_array_ops._slice(input_, begin, size, name=name)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\ops\gen_array_ops.py", line 9222, in _slice
    "Slice", input=input, begin=begin, size=size, name=name)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\framework\op_def_library.py", line 794, in _apply_op_helper
    op_def=op_def)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\framework\ops.py", line 3357, in create_op
    attrs, op_def, compute_device)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\framework\ops.py", line 3426, in _create_op_internal
    op_def=op_def)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\framework\ops.py", line 1748, in __init__
    self._traceback = tf_stack.extract_stack()

...which was originally created as op 'concatenate_14/concat', defined at:
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
[elided 26 identical lines from previous traceback]
  File "<ipython-input-5-200b219201df>", line 65, in <module>
    train_unet()
  File "<ipython-input-5-200b219201df>", line 33, in train_unet
    network = UNet(input_channel_count, output_channel_count, first_layer_filter_count)
  File "<ipython-input-2-a9958481612b>", line 79, in __init__
    dec7 = concatenate([dec7, enc1], axis=self.CONCATENATE_AXIS)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\keras\layers\merge.py", line 508, in concatenate
    return Concatenate(axis=axis, **kwargs)(inputs)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\keras\engine\topology.py", line 585, in __call__
    output = self.call(inputs, **kwargs)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\keras\layers\merge.py", line 283, in call
    return K.concatenate(inputs, axis=self.axis)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\keras\backend\tensorflow_backend.py", line 1681, in concatenate
    return tf.concat([to_dense(x) for x in tensors], axis)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\util\dispatch.py", line 180, in wrapper
    return target(*args, **kwargs)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\ops\array_ops.py", line 1420, in concat
    return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\ops\gen_array_ops.py", line 1257, in concat_v2
    "ConcatV2", values=values, axis=axis, name=name)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\framework\op_def_library.py", line 794, in _apply_op_helper
    op_def=op_def)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\framework\ops.py", line 3357, in create_op
    attrs, op_def, compute_device)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\framework\ops.py", line 3426, in _create_op_internal
    op_def=op_def)
  File "C:\Users\Teppei\Anaconda3\envs\u-net-keras\lib\site-packages\tensorflow_core\python\framework\ops.py", line 1748, in __init__
    self._traceback = tf_stack.extract_stack()
