In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.ops import nn_ops
from tensorflow.python.keras import initializers, regularizers, constraints
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import tf_utils

from tensorflow.python.keras import layers
from tensorflow.python.keras.layers import Input, Layer, Flatten
from tensorflow.python.keras.layers.core import Activation, Reshape
from tensorflow.python.keras.layers.convolutional import Convolution2D
from tensorflow.python.keras.layers.normalization import BatchNormalization

In [None]:
class MaxPoolingWithArgmax2D(Layer):
    """2D max pooling layer that also returns the argmax of every pooled value.

    The layer returns a list ``[pooled, argmax]``. ``argmax`` is the index
    tensor produced by ``nn_ops.max_pool_with_argmax``, cast to the Keras
    float type.

    Args:
        pool_size: int or pair of ints, pooling window (height, width).
        strides: int or pair of ints, pooling strides (height, width).
        padding: padding mode string, normalized by
            ``conv_utils.normalize_padding``.
        **kwargs: forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, pool_size=(2, 2), strides=(2, 2), padding='same', **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.pool_size = conv_utils.normalize_tuple(pool_size, 2, 'pool_size')
        self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
        self.padding = conv_utils.normalize_padding(padding)

    def call(self, inputs, **kwargs):
        """Pool ``inputs`` and return ``[pooled_values, argmax_indices]``."""
        # The window/stride lists put 1 in the first and last positions, so
        # pooling is applied only over the two middle axes.
        # NOTE(review): this presumes channels-last (NHWC) inputs - confirm.
        ksize = [1, self.pool_size[0], self.pool_size[1], 1]
        strides = [1, self.strides[0], self.strides[1], 1]
        padding = self.padding.upper()
        output, argmax = nn_ops.max_pool_with_argmax(inputs, ksize, strides, padding)
        # The indices are cast to the Keras float type.
        # NOTE(review): if floatx is float32, integers above 2**24 are not
        # exactly representable, so very large feature maps could yield
        # inexact indices - confirm for the intended input sizes.
        argmax = tf.cast(argmax, K.floatx())
        return [output, argmax]

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Return the (identical) static shapes of the two outputs.

        Spatial dimensions are derived from the configured pool size, strides
        and padding rather than from a fixed factor of 2, so non-default
        strides and odd input sizes are reported correctly. Unknown (None)
        dimensions stay unknown.
        """
        rows = conv_utils.conv_output_length(
            input_shape[1], self.pool_size[0], self.padding, self.strides[0])
        cols = conv_utils.conv_output_length(
            input_shape[2], self.pool_size[1], self.padding, self.strides[1])
        output_shape = (input_shape[0], rows, cols, input_shape[3])
        return [output_shape, output_shape]

    def compute_mask(self, inputs, mask=None):
        """Return one ``None`` mask per output (no Keras masking is propagated)."""
        return 2 * [None]

    def get_config(self):
        """Return the layer configuration, including the constructor arguments."""
        config = super(MaxPoolingWithArgmax2D, self).get_config()
        config.update({
            'pool_size': self.pool_size,
            'strides': self.strides,
            'padding': self.padding,
        })
        return config

In [None]:
class MaxUnpooling2D(Layer):
    """Unpooling layer that scatters values back to positions given by an index mask.

    The layer is called on a list ``[updates, mask]``. Each value of
    ``updates`` is written into a zero-initialised tensor (via
    ``tf.scatter_nd``) at the location decoded from the matching entry of
    ``mask``.

    Args:
        size: int or pair of ints, upsampling factors (height, width).
        **kwargs: forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.size = conv_utils.normalize_tuple(size, 2, 'size')

    @staticmethod
    def _scale(dim, factor):
        """Multiply a static dimension by ``factor``, keeping ``None`` as ``None``."""
        return None if dim is None else dim * factor

    def call(self, inputs, output_shape=None):
        """Scatter ``inputs[0]`` into a larger tensor using the indices in ``inputs[1]``.

        Args:
            inputs: list ``[updates, mask]``; ``mask`` is cast to int32 and
                decoded as flattened per-sample indices.
            output_shape: optional 4-element shape of the result. When omitted
                it is the shape of ``updates`` with the two middle axes
                multiplied by ``self.size``.

        Returns:
            The tensor produced by ``tf.scatter_nd`` with shape ``output_shape``.
        """
        updates, mask = inputs[0], inputs[1]
        shape_is_inferred = output_shape is None
        # No variables are created here, so a plain name scope is enough for
        # grouping the ops; unlike tf.variable_scope it is also available
        # outside the TF1 API.
        with tf.name_scope(self.name):
            mask = tf.cast(mask, 'int32')
            input_shape = tf.shape(updates, out_type='int32')
            # Compute the target shape when the caller did not supply one.
            if output_shape is None:
                output_shape = (input_shape[0], input_shape[1] * self.size[0], input_shape[2] * self.size[1], input_shape[3])

            # Build the four index components (batch, row, column, feature).
            one_like_mask = K.ones_like(mask, dtype='int32')
            batch_shape = K.concatenate([[input_shape[0]], [1], [1], [1]], axis=0)
            batch_range = K.reshape(tf.range(output_shape[0], dtype='int32'), shape=batch_shape)
            b = one_like_mask * batch_range
            # Row/column decoding treats each mask value as
            # (row * width + col) * channels + channel in the OUTPUT layout.
            # NOTE(review): this presumes the mask carries no batch offset -
            # confirm against the op that produced it.
            y = mask // (output_shape[2] * output_shape[3])
            x = (mask // output_shape[3]) % output_shape[2]
            # The feature index is taken from the position on the last axis.
            feature_range = tf.range(output_shape[3], dtype='int32')
            f = one_like_mask * feature_range

            # Flatten to one (b, y, x, f) row per update value.
            updates_size = tf.size(updates)
            indices = K.transpose(K.reshape(K.stack([b, y, x, f]), [4, updates_size]))
            values = K.reshape(updates, [updates_size])
            ret = tf.scatter_nd(indices, values, output_shape)

        # The scatter shape above is assembled from dynamic tf.shape pieces,
        # so static shape inference may leave the result (partly) unknown.
        # When this layer chose the output shape itself, restore the static
        # shape from the input's static shape so downstream layers can rely
        # on it (e.g. a defined last dimension).
        if shape_is_inferred and updates.shape.ndims == 4:
            batch, rows, cols, channels = updates.shape.as_list()
            ret.set_shape([
                batch,
                self._scale(rows, self.size[0]),
                self._scale(cols, self.size[1]),
                channels,
            ])
        return ret

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Return the static output shape derived from the mask's shape.

        Unknown (None) spatial dimensions are propagated as None instead of
        raising a TypeError on ``None * int``.
        """
        mask_shape = input_shape[1]
        output_shape = [
            mask_shape[0],
            self._scale(mask_shape[1], self.size[0]),
            self._scale(mask_shape[2], self.size[1]),
            mask_shape[3],
        ]
        return tuple(output_shape)

    def get_config(self):
        """Return the layer configuration, including ``size``."""
        config = super(MaxUnpooling2D, self).get_config()
        config.update({'size': self.size})
        return config

In [None]:
# Build the SegNet model: configuration, input tensor, encoder, start of decoder.
IMG_HEIGHT = 128
IMG_WIDTH = 128
IMG_CHANNELS = 3
n_labels = 2
output_mode = "softmax"


def _conv_bn_relu(tensor, filters):
    """Apply a 3x3 'same' convolution (he_normal init), batch norm and ReLU."""
    tensor = Convolution2D(filters, (3, 3), kernel_initializer='he_normal', padding="same")(tensor)
    tensor = BatchNormalization()(tensor)
    return Activation("relu")(tensor)


inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

# ---- Encoder ----
# Each stage is a stack of conv/BN/ReLU units followed by max pooling that
# also yields the argmax mask reused later for unpooling.

# Stage 1: two 64-filter units.
conv_1 = _conv_bn_relu(inputs, 64)
conv_2 = _conv_bn_relu(conv_1, 64)
pool_1, mask_1 = MaxPoolingWithArgmax2D((2, 2))(conv_2)

# Stage 2: two 128-filter units.
conv_3 = _conv_bn_relu(pool_1, 128)
conv_4 = _conv_bn_relu(conv_3, 128)
pool_2, mask_2 = MaxPoolingWithArgmax2D((2, 2))(conv_4)

# Stage 3: three 256-filter units.
conv_5 = _conv_bn_relu(pool_2, 256)
conv_6 = _conv_bn_relu(conv_5, 256)
conv_7 = _conv_bn_relu(conv_6, 256)
pool_3, mask_3 = MaxPoolingWithArgmax2D((2, 2))(conv_7)

# Stage 4: three 512-filter units.
conv_8 = _conv_bn_relu(pool_3, 512)
conv_9 = _conv_bn_relu(conv_8, 512)
conv_10 = _conv_bn_relu(conv_9, 512)
pool_4, mask_4 = MaxPoolingWithArgmax2D((2, 2))(conv_10)

# Stage 5: three 512-filter units.
conv_11 = _conv_bn_relu(pool_4, 512)
conv_12 = _conv_bn_relu(conv_11, 512)
conv_13 = _conv_bn_relu(conv_12, 512)
pool_5, mask_5 = MaxPoolingWithArgmax2D((2, 2))(conv_13)

# ---- Decoder ----
# Unpooling consumes the masks in reverse order of the encoder stages.

unpool_1 = MaxUnpooling2D((2, 2))([pool_5, mask_5])

conv_14 = _conv_bn_relu(unpool_1, 512)
conv_15 = _conv_bn_relu(conv_14, 512)
conv_16 = _conv_bn_relu(conv_15, 512)

unpool_2 = MaxUnpooling2D((2, 2))([conv_16, mask_4])
