In [1]:
import os
import sys
import warnings
warnings.filterwarnings("ignore")

import numpy as np
from keras.layers import (
    Conv2D, BatchNormalization, Dense, 
    ZeroPadding2D, Activation, GlobalAveragePooling2D,
    Reshape, Permute, multiply, AveragePooling2D,
    UpSampling2D, Concatenate, Add, Lambda, Multiply
)
from keras.models import Model, Sequential
from keras.layers import Input
import keras.backend as K
from keras.layers import DepthwiseConv2D, PReLU

Using TensorFlow backend.


In [2]:
import cv2
import imutils

In [3]:
from glob import glob
import pandas as pd
from sklearn.utils import shuffle
import imgaug as ia
from imgaug import augmenters as iaa

In [140]:
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, BaseLogger
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

## Constant Variables

In [5]:
IMG_HEIGHT = 224
IMG_WIDTH = 224
IMG_CHANNEL = 3
CLASSES = ["Background", "Person"]
N_CLASSES = 2
IMAGE_DATA_FORMAT = K.image_data_format()

## Building model

In [574]:
class SiNet:
    def __init__(self, img_height, img_width, img_channel, n_classes, reg=1e-4):
        self.img_height = img_height
        self.img_width = img_width
        self.img_channel = img_channel
        self.n_classes = n_classes
        self.reg = reg
        self.channel_axis = 1 if K.image_data_format() == "channels_first" else -1
        self.alpha = 1.0
        
    def relu6(self, x):
        return K.relu(x, max_value=6)
    
    def _conv_block(self, inputs, filters, alpha, strides=(1,1), kernel=(3,3), block_id=1, padding="valid"):
        """"""
        filters = int(filters * alpha)
        
        if padding=="same":
            x = inputs
        else:
            x = ZeroPadding2D((1, 1), data_format=IMAGE_DATA_FORMAT, 
                              name="conv_%s_pad" % block_id)(inputs)        
        
        x = Conv2D(filters, kernel, data_format=IMAGE_DATA_FORMAT, 
                   padding=padding, use_bias=False, strides=strides, 
                   kernel_initializer="he_normal", name="conv_%s" % block_id)(x)
        x = BatchNormalization(axis=self.channel_axis, name="conv_%s_bn" % block_id)(x)
        x = Activation("relu", name="conv_%s_act" % block_id)(x)
        
        return x
    
    def _pointwise_conv_block(self, inputs, pointwise_conv_filters, alpha, 
                              strides=(1, 1), block_id=1):
        x = Conv2D(pointwise_conv_filters, 
                   (1, 1),
                   data_format=IMAGE_DATA_FORMAT,
                   padding="same",
                   use_bias=False,
                   strides=(1, 1),
                   name="conv_pw_%s" % block_id)(inputs)
        x = BatchNormalization(axis=self.channel_axis,
                               name="conv_pw_%s_bn" % block_id)(x)
        x = Activation(self.relu6, name="conv_pw_%s_relu" % block_id)(x)
        
        return x
        
    def _depthwise_conv_block(self, inputs, pointwise_conv_filters, alpha, 
                              depth_multiplier=1, strides=(1, 1), block_id=1, 
                              kernel=(3,3), padding_size=(1, 1)):
        """"""
        pointwise_conv_filters = int(pointwise_conv_filters * alpha)
        
        x = ZeroPadding2D(padding_size, 
                          data_format=IMAGE_DATA_FORMAT,
                          name="conv_pad_%s" % block_id)(inputs)
        x = DepthwiseConv2D(kernel_size=kernel,
                            data_format=IMAGE_DATA_FORMAT,
                            depth_multiplier=depth_multiplier,
                            strides=strides,
                            use_bias=False,
                            name="conv_dw_%s" % block_id)(x)
        x = BatchNormalization(axis=self.channel_axis,
                               name="conv_dw_%s_bn" % block_id)(x)
#         x = Activation("PReLu", name="conv_dw_%d_Prelu" % block_id)(x)
        x = PReLU(name="conv_dw_%s_Prelu" % block_id)(x)
        
        x = self._pointwise_conv_block(x, pointwise_conv_filters, self.alpha, block_id=block_id)
        
        return x
    
    def _squeeze_excite_block(self, inputs, ratio=16, block_id=1):
        """"""
        filters = inputs._keras_shape[self.channel_axis]
        se_shape = (1, 1, filters) if self.channel_axis == -1 else (filters, 1, 1)
        
        se = GlobalAveragePooling2D(name="squeeze_glo_avg_%s" % block_id)(inputs)
        se = Dense(filters // ratio, activation="relu", 
                   kernel_initializer="he_normal", 
                   use_bias=False, name="squeeze_squ_%s" % block_id)(se)
        se = Dense(filters, activation="relu", kernel_initializer="he_normal", 
                   use_bias=False, name="squeeze_exci_%s" % block_id)(se)
        se = multiply([inputs, se], name="squeeze_scale_%s" % block_id)
        
        return se
    
    def _depthwise_conv_se_block(self, inputs, pointwise_conv_filters, alpha, 
                                 depth_multiplier=1, strides=(2, 2), block_id=1,
                                 kernel=(3,3), ratio=16):
        """
        DS-Conv + SE
        """
        x = self._depthwise_conv_block(inputs, pointwise_conv_filters, alpha, 
                                       block_id=block_id, strides=strides)
        x = self._squeeze_excite_block(x, ratio=ratio, block_id=block_id)
        
        return x
    
    def _s2_block(self, inputs, pointwise_conv_filters, alpha, 
                  depth_multiplier=1, strides=(1, 1), block_id=1,
                  kernel=(3,3), pool_size=(1,1), padding_size=(1, 1)):
        x = AveragePooling2D(pool_size=pool_size, strides=(2, 2), 
                             data_format=IMAGE_DATA_FORMAT, padding="same")(inputs)
#         print(x._keras_shape)
        x = self._depthwise_conv_block(x, pointwise_conv_filters, alpha, 
                                       block_id=block_id, kernel=kernel, 
                                       padding_size=padding_size)
#         print(x._keras_shape)
        x = UpSampling2D(size=(2, 2), interpolation="bilinear", name="s2_block_%s" % block_id)(x)
        
        x = BatchNormalization(axis=self.channel_axis)(x)
        x = Activation(self.relu6)(x)
        
        return x
    
    def _s2_module(self, inputs, pointwise_conv_filters, alpha,
                   depth_multiplier=1, strides=(1, 1), block_id=1,
                   kernel_conv=(3, 3), kernel_ds_1=(3, 3), 
                   kernel_ds_2=(3, 3), pad_ds_1=(1, 1), pad_ds_2=(1, 1),
                   pool_block_1=(1, 1), pool_block_2=(1, 1)):
        """
        The function to build S2 block
        """
        x = self._conv_block(inputs, pointwise_conv_filters, alpha, 
                             kernel=(1, 1), block_id=block_id, padding="same")
#         print(x._keras_shape)
        x1 = self._s2_block(x, pointwise_conv_filters, alpha, depth_multiplier=depth_multiplier,
                            strides=strides, kernel=kernel_ds_1, block_id=str(block_id) + "_1",
                            padding_size=pad_ds_1, pool_size=pool_block_1)
    
        x2 = self._s2_block(x, pointwise_conv_filters, alpha, depth_multiplier=depth_multiplier,
                            strides=strides, kernel=kernel_ds_2, block_id=str(block_id) + "_2", 
                            padding_size=pad_ds_2, pool_size=pool_block_2)
        
        x = Concatenate(axis=self.channel_axis)([x1, x2])
        x = Add()([inputs, x])
        x = BatchNormalization(axis=self.channel_axis)(x)
        x = PReLU()(x)
        
        return x
    
    def build_encoder(self, mean_substraction=[117, 117, 117]):
        """
        Build encoder function
        """
        
        input_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNEL)
        
        if IMAGE_DATA_FORMAT == "channels_first":
            input_shape = (IMG_CHANNEL, IMG_HEIGHT, IMG_WIDTH)
            
        inputs = Input(shape=input_shape)
        
#         x = Lambda(lambda z: z[...,::-1], output_shape=input_shape, 
#                    name="swap_color_channel")(inputs)
        
#         if mean_substraction:
#             x = Lambda(lambda z: z - np.array(mean_substraction),
#                        output_shape=input_shape,
#                        name="mean_substraction_inputs")(x)
        x = inputs
    
        x1 = self._conv_block(x, 12, self.alpha, strides=(2, 2), block_id=1)
        x2 = self._depthwise_conv_se_block(x1, 16, self.alpha, block_id=2)
        x3 = self._depthwise_conv_se_block(x2, 48, self.alpha, block_id=3, strides=(1, 1))
        x4 = self._s2_module(x3, 24, self.alpha, block_id=4, kernel_ds_2=(5, 5), pad_ds_2=(2, 2))
        x5 = self._s2_module(x4, 24, self.alpha, block_id=5)
        
        x6 = Concatenate(axis=self.channel_axis, name="concat_2_5")([x2, x5])
        
        x7 = self._depthwise_conv_se_block(x6, 48, self.alpha, block_id=6)
        x8 = self._depthwise_conv_se_block(x7, 96, self.alpha, block_id=7, strides=(1, 1))
        x9 = self._s2_module(x8, 48, self.alpha, block_id=8, kernel_ds_2=(5, 5), pad_ds_2=(2, 2))
        x10 = self._s2_module(x9, 48, self.alpha, block_id=9)
        x11 = self._s2_module(x10, 48, self.alpha, block_id=10, 
                              kernel_ds_1=(5, 5), pad_ds_1=(2, 2),
                              kernel_ds_2=(3, 3), pool_block_2=(2, 2))
        x12 = self._s2_module(x11, 48, self.alpha, block_id=11,
                              kernel_ds_1=(5, 5), pad_ds_1=(2, 2),
                              kernel_ds_2=(3, 3), pool_block_2=(4, 4))
        x13 = self._s2_module(x12, 48, self.alpha, block_id=12)
        x14 = self._s2_module(x13, 48, self.alpha, block_id=13,
                              kernel_ds_1=(5, 5), pad_ds_1=(2, 2),
                              kernel_ds_2=(5, 5), pad_ds_2=(2, 2))
        x15 = self._s2_module(x14, 48, self.alpha, block_id=14,
                              kernel_ds_1=(3, 3), pool_block_1=(2, 2),
                              kernel_ds_2=(3, 3), pool_block_2=(4, 4))
        x16 = self._s2_module(x15, 48, self.alpha, block_id=15,
                              kernel_ds_1=(3, 3), pool_block_1=(1, 1),
                              kernel_ds_2=(5, 5), pad_ds_2=(2, 2), pool_block_2=(2, 2))
        x17 = Concatenate(axis=self.channel_axis, name="concat_16_7")([x16, x7])
        
        x = self._pointwise_conv_block(x17, N_CLASSES, self.alpha, block_id=16)
        
#         x = Reshape((-1, self.n_classes))(x)
        
#         x = Activation("softmax")(x)
        
#         model = Model(inputs=inputs, outputs=x)
        
        return inputs, x, x1, x2
    
    def build_decoder(self):
        inputs, x, x1, x2 = self.build_encoder()
        
        x = UpSampling2D((2, 2), data_format=IMAGE_DATA_FORMAT, 
                         interpolation="bilinear")(x)
        x_ac = Activation("softmax")(x)
        x_blocking = Lambda(lambda x: 1 - x, name="information_blocking_decoder")(x_ac)
        
        x2_pws = self._pointwise_conv_block(x2, self.n_classes, self.alpha, block_id=17)
        
        x_mul = Multiply()([x2_pws, x_blocking])
        x = Add()([x_mul, x])
        x = UpSampling2D((2, 2), interpolation="bilinear")(x)
        x = self._conv_block(x, self.n_classes, self.alpha, kernel=(1, 1), padding="same", block_id=18)
        x = UpSampling2D((2, 2), interpolation="bilinear")(x)
        x = Reshape((-1, self.n_classes))(x)
        x = Activation("softmax")(x)
        
        model = Model(inputs=inputs, outputs=x)
        
        return model

In [575]:
sinet = SiNet(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNEL, N_CLASSES)

In [576]:
model = sinet.build_decoder()

In [577]:
model.summary()

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
conv_1_pad (ZeroPadding2D)      (None, 226, 226, 3)  0           input_4[0][0]                    
__________________________________________________________________________________________________
conv_1 (Conv2D)                 (None, 112, 112, 12) 324         conv_1_pad[0][0]                 
__________________________________________________________________________________________________
conv_1_bn (BatchNormalization)  (None, 112, 112, 12) 48          conv_1[0][0]                     
____________________________________________________________________________________________

## Dataset Generation

In [578]:
os.listdir("Nukki/baidu_V1/")

['train.txt', '.DS_Store', 'input', 'val.txt', 'target']

In [579]:
# with open("Nukki/baidu_V1/train.txt") as f:
#     data = f.read()

In [580]:
files = pd.read_csv("Nukki/baidu_V1/train.txt", header=None)[0]

In [581]:
files.values

array(['501.png', '502.png', '503.png', ..., '5379.png', '5380.png',
       '5381.png'], dtype=object)

In [582]:
class DataAugmentation:
    def __init__(self):
        self.IMAGE_AUGMENTATION_SEQUENCE = None
        self.IMAGE_AUGMENTATION_NUM_TRIES = 10
        self.loaded_augmentation_name = ""
    
        self.augmentation_functions = {
            "aug_all": self._load_augmentation_aug_all,
            "aug_geometric": self._load_augmentation_aug_geometric,
            "aug_non_geometric": self._load_augmentation_aug_non_geometric,
        }
    
    def _load_augmentation_aug_geometric(self):
        return iaa.OneOf([
            iaa.Sometimes(0.5, iaa.Fliplr()),
            iaa.Sometimes(0.5, iaa.Rotate((-45, 45))),
            iaa.Sometimes(0.5, iaa.Affine(
                scale={"x": (0.5, 1.5), "y": (0.5, 1.5)},
                order=[0, 1],
                mode='constant',
                cval=(0, 255),
            )),
            iaa.Sometimes(0.5, iaa.Affine(
                translate_percent={"x": (-0.25, 0.25), "y": (-0.25, 0.25)},
                order=[0, 1],
                mode='constant',
                cval=(0, 255),
            )),
        ])
    
    def _load_augmentation_aug_non_geometric(self):
        return iaa.OneOf([
            iaa.Sometimes(0.5, iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05*255), per_channel=0.5)),
            iaa.Sometimes(0.5, iaa.OneOf([
                iaa.GaussianBlur(sigma=(0.0, 3.0)),
                iaa.GaussianBlur(sigma=(0.0, 5.0))
            ])),
            iaa.Sometimes(0.5, iaa.MultiplyAndAddToBrightness(mul=(0.4, 1.7))),
            iaa.Sometimes(0.5, iaa.GammaContrast((0.4, 1.7))),
            iaa.Sometimes(0.5, iaa.Multiply((0.4, 1.7), per_channel=0.5)),
            iaa.Sometimes(0.5, iaa.MultiplyHue((0.4, 1.7))),
            iaa.Sometimes(0.5, iaa.MultiplyHueAndSaturation((0.4, 1.7), per_channel=True)),
            iaa.Sometimes(0.5, iaa.LinearContrast((0.4, 1.7), per_channel=0.5))
        ])
    
    def _load_augmentation_aug_all(self):
        return iaa.OneOf([
            iaa.Sometimes(0.5, self._load_augmentation_aug_geometric()),
            iaa.Sometimes(0.5, self._load_augmentation_aug_non_geometric())
        ])
    
    def _load_aug_by_name(self, aug_name="aug_all"):
        if not len(self.loaded_augmentation_name):
            self.loaded_augmentation_name = aug_name
            self.IMAGE_AUGMENTATION_SEQUENCE = self.augmentation_functions[aug_name]()
            
        return self.IMAGE_AUGMENTATION_SEQUENCE        

In [583]:
DATA_DIR = "Nukki/"
TRAIN_ANNO_FILE1 = "Nukki/baidu_V1/train.txt"
TRAIN_ANNO_FILE2 = "Nukki/baidu_V2/train.txt"
VAL_ANNO_FILE1 = "Nukki/baidu_V1/val.txt"
VAL_ANNO_FILE2 = "Nukki/baidu_V2/val.txt"

In [584]:
class DataGenerator:
    def __init__(self, 
                 data_dir, 
                 anno_paths, 
                 augment_func,
                 img_height=IMG_HEIGHT,
                 img_width=IMG_WIDTH,
                 img_channel=IMG_CHANNEL,
                 batch_size=36,
                 n_classes=N_CLASSES,
                 augmentation=True,
                 task="train"):
        
        self.data_dir = data_dir
        self.anno_paths = anno_paths
        self.augment_func = augment_func
        self.batch_size = batch_size
        self.task = task
        self.current_index = 0
#         self.current_test = 0
        self.img_height = img_height
        self.img_width = img_width
        self.img_channel = img_channel
        self.n_classes = n_classes
        self.augmentation = augmentation
        self.image_paths, self.label_paths = self.load_image_paths()
        
    def load_image_paths(self):
        image_paths = []
        label_paths = []
        
        for anno_path in self.anno_paths:
            anno_dir = os.path.dirname(anno_path)
            df_file_names = pd.read_csv(anno_path, header=None)
            file_names = df_file_names[0].values
            
            for fn in file_names:
                if "DS_Store" in fn: continue
                img_train_dir = os.path.join(anno_dir, "input")
                img_label_dir = os.path.join(anno_dir, "target")
                img_train_path = os.path.join(img_train_dir, fn)
                img_label_path = os.path.join(img_label_dir, fn)
        
                image_paths.append(img_train_path)
                label_paths.append(img_label_path)
            
        image_paths = np.array(image_paths)
        label_paths = np.array(label_paths)

        return image_paths, label_paths
    
    def get_n_examples(self):
        return len(self.image_paths)
        
    def load_image(self, img_path):
        image = cv2.imread(img_path)
#         image = image[..., ::-1]
        return image
    
    def parse_label(self, seg_image):
        label = seg_image
        label = label[:, :, 0]
        seg = np.zeros((self.img_height, self.img_width, self.n_classes))
        
        for i in range(self.n_classes):
            seg[:, :, i] = (label == i).astype("int")
        
        seg = np.reshape(seg, (-1, self.n_classes))
        return seg
    
    def load_batch_pair_in_pair(self, img_paths, label_paths):
        images = []
        segs = []
        preprocessors = [self.resize_img, self.mean_substraction]
        
        for i in range(len(img_paths)):
            image = self.load_image(img_paths[i])
#             image = self.preprocessing(image, preprocessors)
            
            seg = self.load_image(label_paths[i])
#             seg = self.preprocessing(seg, [self.resize_img])
#             print(seg.dtype)
#             seg = seg.astype(np.uint8)
#             seg = self.parse_label(seg)
#             seg = np.argmax(seg, axis=-1).astype(np.int32)
            segmap = SegmentationMapsOnImage(seg, shape=image.shape)
            
            if self.augmentation:
                image, segmap = self.augment_func(image=image, segmentation_maps=segmap)
            
            seg = segmap.get_arr()
#             print(seg.shape)
            image = self.preprocessing(image, preprocessors)
            
            seg = self.preprocessing(seg, [self.resize_img])
            
            seg = self.parse_label(seg)
            
            images.append(image)
            segs.append(seg)
            
#         if self.augmentation:
#             images, segs = self.augment_func(images=images, segmentation_maps=segs)
        
        return np.array(images), np.array(segs)
    
    def resize_img(self, image):
#         h, w = image.shape[:2]
#         print(h, w)
#         p_h, p_w = 0, 0
        
#         if h>w:
#             image = imutils.resize(image, width=self.img_width)
#             p_h = int((image.shape[0] - self.img_height) / 2)
#         else:
#             image = imutils.resize(image, height=self.img_height)
#             p_w = int((image.shape[1] - self.img_width) / 2)
        
#         image = image[p_w:image.shape[1]-p_w, p_h: image.shape[0]-p_h]
        image = cv2.resize(image, (self.img_height, self.img_width), interpolation=cv2.INTER_AREA)
        return image
    
    def mean_substraction(self, image, mean=[103.94, 116.78, 123.68], image_val=0.017):
        image = (image - np.array(mean))*image_val
        
        return image
    
    def preprocessing(self, image, preprocessors):
        for p in preprocessors:
            image = p(image)
        
        return image
    
    def load_batch(self):
        if self.current_index + self.batch_size >= len(self.image_paths):
            self.current_index = 0
            self.image_paths, self.label_paths = shuffle(self.image_paths, self.label_paths, random_state=42)
            
        img_batch_paths = self.image_paths[self.current_index:self.current_index+self.batch_size]
        seg_batch_paths = self.label_paths[self.current_index:self.current_index+self.batch_size]
        self.current_index += self.batch_size
        inputs, segs = self.load_batch_pair_in_pair(img_batch_paths, seg_batch_paths)
        
        return inputs, segs
    
    def generate(self):
        while True:
            images, labels = self.load_batch()
            
            yield (images, labels)

In [585]:
data_aug = DataAugmentation()

In [586]:
aug = data_aug._load_aug_by_name()

In [587]:
train_datagen = DataGenerator(DATA_DIR, [TRAIN_ANNO_FILE1, TRAIN_ANNO_FILE2], aug)

In [588]:
train_datagen.get_n_examples()

10160

In [589]:
train_datagen.label_paths

array(['Nukki/baidu_V1/target/501.png', 'Nukki/baidu_V1/target/502.png',
       'Nukki/baidu_V1/target/503.png', ...,
       'Nukki/baidu_V2/target/26852.jpg',
       'Nukki/baidu_V2/target/17457.jpg',
       'Nukki/baidu_V2/target/16749.jpg'], dtype='<U31')

In [590]:
val_datagen = DataGenerator(DATA_DIR, [VAL_ANNO_FILE1, VAL_ANNO_FILE2], aug)

In [591]:
val_datagen.get_n_examples()

600

In [592]:
# val_datagen.label_paths

In [593]:
def show_sample_image(img_paths):
    img_path = np.random.choice(img_paths, 1)[0]
    img = cv2.imread(img_path)
    img = cv2.resize(img, (224, 224))
#     img = imutils.resize(img, height=224, width=224)
#     a  = (img[..., 1] == img[..., 2])
    print(img[..., 0].shape)
    print(np.nonzero(img[..., 1]))
#     print(a.astype("int"))
#     ia.imshow(np.zeros((224, 224, 3)))
#     ia.imshow(img)

In [594]:
# show_sample_image(val_datagen.label_paths)

In [595]:
# imutils.resize?

## Training Model

In [596]:
from keras.optimizers import Adam
import keras.backend as K
import tensorflow as tf
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

In [597]:
# frotensorflow.no_op.n

In [598]:
class SINetLoss:
    def __init__(self, lamda=0.9):
        self.lamda = lamda
        
    def gt_dilation(self, y_true):
#         dilation = cv2.dilate(y_numpy, kernel, iterations=1)
        y_true = tf.reshape(y_true, (-1, IMG_HEIGHT, IMG_WIDTH, N_CLASSES))
        dilation = tf.nn.max_pool2d(y_true, ksize=(15, 15), strides=1, name='dilation2D', padding="SAME")
        dilation = tf.reshape(dilation, (-1, IMG_HEIGHT*IMG_WIDTH, N_CLASSES))
        
        return dilation
    
    def gt_erosion(self, y_true):
#         erosion = cv2.erode(y_numpy, kernel, iterations=1)
        y_true = tf.reshape(y_true, (-1, IMG_HEIGHT, IMG_WIDTH, N_CLASSES))
        erosion = -tf.nn.max_pool2d(-y_true, ksize=(15, 15), strides=1, name='erosion2D', padding="SAME")
        erosion = tf.reshape(erosion, (-1, IMG_HEIGHT*IMG_WIDTH, N_CLASSES))
        
        return erosion
        
    def log_loss(self, y_true, y_pred):
        loss = -tf.reduce_sum(y_true * tf.log(y_pred), axis=-1)
        return loss
        
    def boundary_loss(self, y_true, y_pred):
#         batch_size = tf.to_int32(tf.shape(y_true)[0])
#         batch_size = tf.shape(y_true)[0]
#         _y_true = []
        
#         for i in range(tf.to_float(batch_size)):
#             proto_tensor = tf.make_tensor_proto(y_true[i])
#             _y_true.append(tf.make_ndarray(proto_tensor))
            
#         y_true = _y_true
#         kernel = tf.ones((15, 15), dtype=tf.uint8)
#        with tf.Session() as sess:
#             sess.run(tf.global_variables_initializer())
#         kernel = (15, 15)
#         y_numpy = y_true.eval()
        dilation = self.gt_dilation(y_true)
        erosion = self.gt_erosion(y_true)

#         y_true = tf.reshape(y_true, (-1, IMG_HEIGHT, IMG_WIDTH, N_CLASSES))
#         dilation = tf.nn.dilation2d(y_true, 
#                                     filter=[15, 15, N_CLASSES], 
#                                     strides=[1, 1, 1, 1], 
#                                     rates=[1, 15, 15, 1],
#                                     padding="SAME")

#         erosion = tf.nn.erosion2d(y_true, 
#                                   filter=[15, 15, N_CLASSES], 
#                                   strides=[1, 1, 1, 1], 
#                                   rates=[1, 15, 15, 1])
#         boundary = dilation - erosion
        boundary = tf.math.subtract(dilation, erosion)
#         boundary = tf.reshape(boundary, (-1, IMG_HEIGHT*IMG_WIDTH, N_CLASSES))
#         assign_indices = tf.count_nonzero(boundary, axis=-1)
#         assign_indices = tf.where(boundary)
#         print(tf.shape(boundary).eval(session=tf.compat.v1.Session()))
#         mask = tf.math.greater(tf.to_float(boundary), tf.constant(0.0))
#         zeros_mask = tf.zeros_like(boundary)
        
#         b_true = tf.where(mask, y_true, zeros_mask)
#         b_pred = tf.where(mask, y_pred, zeros_mask)
        loss = -tf.reduce_sum(boundary * tf.log(y_pred), axis=-1)
        
        return loss
    
    def compute_loss(self, y_true, y_pred):
        batch_size = tf.shape(y_pred)[0]
        self.lamda = tf.constant(self.lamda)
        log_loss = tf.to_float(self.log_loss(y_true, y_pred))
        boundary_loss = tf.to_float(self.boundary_loss(y_true, y_pred))
        loss = log_loss + self.lamda * boundary_loss
#         loss = log_loss
        loss *= tf.to_float(batch_size)
        
        return loss

In [599]:
# tf.nn.dilation2d?

In [600]:
sinet_loss = SINetLoss().compute_loss

In [601]:
init_lr = 7.5e-3

opt = Adam(lr=init_lr, decay=2e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-07,)

In [602]:
model.compile(optimizer=opt, loss=sinet_loss, metrics=["accuracy"])

In [603]:
trained_weight = "trained_weight"

if not os.path.exists(trained_weight):
    os.mkdir(trained_weight)

In [604]:
import keras.backend as K
import keras.callbacks as cbks

In [605]:
class CustomMetrics(cbks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        for k in logs:
            if k.endswith('boundary_loss'):
                print(logs[k])

In [606]:
checkpoints = ModelCheckpoint(filepath=os.path.join(trained_weight, "best_weight.hdf5"), 
                              monitor="val_loss", 
                              save_best_only=True, 
                              verbose=1)

In [607]:
for idx, layer in enumerate(model.layers):
    print(idx, layer.name)
    
    if idx > 346:
        layer.trainable = False

0 input_4
1 conv_1_pad
2 conv_1
3 conv_1_bn
4 conv_1_act
5 conv_pad_2
6 conv_dw_2
7 conv_dw_2_bn
8 conv_dw_2_Prelu
9 conv_pw_2
10 conv_pw_2_bn
11 conv_pw_2_relu
12 squeeze_glo_avg_2
13 squeeze_squ_2
14 squeeze_exci_2
15 squeeze_scale_2
16 conv_pad_3
17 conv_dw_3
18 conv_dw_3_bn
19 conv_dw_3_Prelu
20 conv_pw_3
21 conv_pw_3_bn
22 conv_pw_3_relu
23 squeeze_glo_avg_3
24 squeeze_squ_3
25 squeeze_exci_3
26 squeeze_scale_3
27 conv_4
28 conv_4_bn
29 conv_4_act
30 average_pooling2d_41
31 average_pooling2d_42
32 conv_pad_4_1
33 conv_pad_4_2
34 conv_dw_4_1
35 conv_dw_4_2
36 conv_dw_4_1_bn
37 conv_dw_4_2_bn
38 conv_dw_4_1_Prelu
39 conv_dw_4_2_Prelu
40 conv_pw_4_1
41 conv_pw_4_2
42 conv_pw_4_1_bn
43 conv_pw_4_2_bn
44 conv_pw_4_1_relu
45 conv_pw_4_2_relu
46 s2_block_4_1
47 s2_block_4_2
48 batch_normalization_61
49 batch_normalization_62
50 activation_45
51 activation_46
52 concatenate_21
53 add_23
54 batch_normalization_63
55 p_re_lu_21
56 conv_5
57 conv_5_bn
58 conv_5_act
59 average_pooling2d_43
60

In [None]:
H = model.fit_generator(train_datagen.generate(),
                        epochs=300,
                        steps_per_epoch=train_datagen.get_n_examples() // train_datagen.batch_size,
                        validation_data=val_datagen.generate(),
                        validation_steps=val_datagen.get_n_examples() // val_datagen.batch_size,
                        callbacks=[checkpoints],
                        initial_epoch=0
                        )

Epoch 1/300

Epoch 00001: val_loss improved from inf to 42.11344, saving model to trained_weight/best_weight.hdf5
