In [None]:
import numpy as np
import pandas as pd

import os
# Print the full path of every file found under the Kaggle input mount.
for path in (os.path.join(root, name)
             for root, _subdirs, names in os.walk('/kaggle/input')
             for name in names):
    print(path)

In [None]:
!pip install imutils
!pip install keras-segmentation
import imutils

In [None]:
import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from keras.optimizers import SGD, Adam, RMSprop
from keras.layers import (
    Conv2D, BatchNormalization, MaxPooling2D, Dropout, Activation, UpSampling2D, concatenate, Reshape, ZeroPadding2D
)
import cv2
import imutils
from sklearn.utils import shuffle
import keras.backend as K
from keras.models import Sequential, Model, Input
from keras.regularizers import l2
import numpy as np
from keras.callbacks import ModelCheckpoint
import imgaug as ia
from imgaug import augmenters as iaa

## Configuration constants

In [None]:
# Dataset locations; the *_VAL_* entries point at the dataset's "test" split.
DATASET_ROOT = "../input/segmentsampledataset/dataset1/dataset1/"
IMG_TRAIN_DIR = DATASET_ROOT + "images_prepped_train/"
ANNO_TRAIN_DIR = DATASET_ROOT + "annotations_prepped_train/"
IMG_VAL_DIR = DATASET_ROOT + "images_prepped_test/"
ANNO_VAL_DIR = DATASET_ROOT + "annotations_prepped_test/"

# Input geometry and batching.
BATCH_SIZE = 16
HEIGHT, WIDTH, CHANNEL = 224, 224, 3

# Augmentation state: the sequence is created on demand by load_aug() / aug_image().
IMAGE_AUGMENTATION_SEQUENCE = None
IMAGE_AUGMENTATION_NUM_TRIES = 10
IMAGE_AUMENTATION_NAME_LOADED = ""  # (sic) name of the augmentation currently loaded

# Number of segmentation classes.
N_CLASSES = 12

## Data Augmentation

In [None]:
def _load_augmenation_aug_geometric():
    """Build an augmenter that applies exactly one geometric transform per call.

    ``iaa.OneOf`` chooses among seven candidates:
    - random flips;
    - crop-and-pad;
    - a light crop;
    - a heavy 30-50% crop, which occupies three of the seven slots;
    - an affine warp, optionally followed by a heavy crop.
    """
    flips = iaa.Sequential([iaa.Fliplr(0.5), iaa.Flipud(0.2)])
    crop_and_pad = iaa.CropAndPad(percent=(-0.05, 0.1),
                                  pad_mode="constant",
                                  pad_cval=(0, 255))
    light_crop = iaa.Crop(percent=(0.0, 0.1))
    heavy_crops = [iaa.Crop(percent=(0.3, 0.5)) for _ in range(3)]
    affine_then_crop = iaa.Sequential([
        iaa.Affine(
            scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
            translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
            rotate=(-45, 45),
            shear=(-16, 16),
            order=[0, 1],  # nearest-neighbour or bilinear interpolation
            mode="constant",
            cval=(0, 255),
        ),
        iaa.Sometimes(0.3, iaa.Crop(percent=(0.3, 0.5))),
    ])
    return iaa.OneOf([flips, crop_and_pad, light_crop, *heavy_crops, affine_then_crop])

In [None]:
def _load_augmentation_agu_non_geometric():
    """Build the photometric ("non-geometric") augmenter.

    Each entry below is wrapped in ``iaa.Sometimes`` with its own probability.
    The wrapped steps run in the listed order inside an ``iaa.Sequential``.
    They cover brightness, compression, blur, hue/saturation, grayscale,
    colour temperature, contrast, histogram equalisation and emboss.
    """
    probability_and_augmenter = (
        (0.3, iaa.Multiply((0.5, 1.5), per_channel=0.5)),
        (0.2, iaa.JpegCompression(compression=(70, 99))),
        (0.2, iaa.GaussianBlur(sigma=(0, 3.0))),
        (0.2, iaa.MotionBlur(k=15, angle=[-45, 45])),
        (0.2, iaa.MultiplyHue((0.5, 1.5))),
        (0.2, iaa.MultiplySaturation((0.5, 1.5))),
        (0.34, iaa.MultiplyHueAndSaturation((0.5, 1.5), per_channel=True)),
        (0.34, iaa.Grayscale(alpha=(0.0, 1.0))),
        (0.2, iaa.ChangeColorTemperature((1100, 10000))),
        (0.1, iaa.GammaContrast((0.5, 2.0))),
        (0.2, iaa.SigmoidContrast(gain=(3, 10), cutoff=(0.4, 0.6))),
        (0.1, iaa.CLAHE()),
        (0.1, iaa.HistogramEqualization()),
        (0.2, iaa.LinearContrast((0.5, 2.0), per_channel=0.5)),
        (0.1, iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))),
    )
    return iaa.Sequential([iaa.Sometimes(probability, augmenter)
                           for probability, augmenter in probability_and_augmenter])

In [None]:
def _load_augmentation_aug_all2():
    """Chain the geometric and photometric augmenters, each applied with probability 0.65."""
    geometric = iaa.Sometimes(0.65, _load_augmenation_aug_geometric())
    photometric = iaa.Sometimes(0.65, _load_augmentation_agu_non_geometric())
    return iaa.Sequential([geometric, photometric])

In [None]:
def _load_augmentation_aug_all():
    """Build the broad default augmenter (registered as "aug_all").

    The outer ``Sequential`` runs its children in random order. Its children are:
    - a horizontal flip (p=0.5);
    - a vertical flip (p=0.2);
    - crop-and-pad (p=0.5);
    - an affine warp (p=0.5);
    - a ``SomeOf`` that applies 0-5 of the augmenters in ``optional_steps``,
      also in random order.
    """
    def sometimes(aug):
        # Apply ``aug`` to roughly half of the inputs.
        return iaa.Sometimes(0.5, aug)

    base_steps = [
        iaa.Fliplr(0.5),
        iaa.Flipud(0.2),
        sometimes(iaa.CropAndPad(
            percent=(-0.05, 0.1),
            pad_mode="constant",
            pad_cval=(0, 255),
        )),
        sometimes(iaa.Affine(
            scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
            translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
            rotate=(-45, 45),
            shear=(-16, 16),
            order=[0, 1],  # nearest-neighbour or bilinear interpolation
            mode="constant",
            cval=(0, 255),
        )),
    ]

    optional_steps = [
        # replace regions by their superpixel representation
        sometimes(iaa.Superpixels(p_replace=(0, 1.0), n_segments=(20, 200))),
        # one blur: gaussian (sigma 0-3), local mean (k 2-7) or local median (k 3-11)
        iaa.OneOf([
            iaa.GaussianBlur((0, 3.0)),
            iaa.AverageBlur(k=(2, 7)),
            iaa.MedianBlur(k=(3, 11)),
        ]),
        iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5)),
        iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0)),
        # edge detection blended in through a simplex-noise mask
        iaa.SimplexNoiseAlpha(iaa.OneOf([
            iaa.EdgeDetect(alpha=(0.5, 1.0)),
            iaa.DirectedEdgeDetect(alpha=(0.5, 1.0), direction=(0.0, 1.0)),
        ])),
        iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),
        # either drop 1-10% of single pixels, or drop coarse rectangular regions
        iaa.OneOf([
            iaa.Dropout((0.01, 0.1), per_channel=0.5),
            iaa.CoarseDropout((0.03, 0.15), size_percent=(0.02, 0.05), per_channel=0.2),
        ]),
        iaa.Invert(0.05, per_channel=True),
        # add a value in [-10, 10] to pixel intensities
        iaa.Add((-10, 10), per_channel=0.5),
        iaa.AddToHueAndSaturation((-20, 20)),
        # brightness of the whole image, or of frequency-noise-masked regions
        iaa.OneOf([
            iaa.Multiply((0.5, 1.5), per_channel=0.5),
            iaa.FrequencyNoiseAlpha(
                exponent=(-4, 0),
                first=iaa.Multiply((0.5, 1.5), per_channel=True),
                second=iaa.ContrastNormalization((0.5, 2.0)),
            ),
        ]),
        iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5),
        iaa.Grayscale(alpha=(0.0, 1.0)),
        # local pixel displacement and small warps
        sometimes(iaa.ElasticTransformation(alpha=(0.5, 3.5), sigma=0.25)),
        sometimes(iaa.PiecewiseAffine(scale=(0.01, 0.05))),
        sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1))),
    ]

    return iaa.Sequential(
        base_steps + [iaa.SomeOf((0, 5), optional_steps, random_order=True)],
        random_order=True,
    )

In [None]:
# Registry: augmentation name accepted by load_aug() -> function that builds it.
list_augm_support = dict(
    aug_all=_load_augmentation_aug_all,
    aug_all2=_load_augmentation_aug_all2,
    aug_geometric=_load_augmenation_aug_geometric,
    aug_non_geometric=_load_augmentation_agu_non_geometric,
)

In [None]:
def load_aug(aug_name="aug_all"):
    """Build the augmenter registered under ``aug_name`` and store it globally.

    The result is assigned to the module-level ``IMAGE_AUGMENTATION_SEQUENCE``
    and also returned.

    Raises:
        ValueError: if ``aug_name`` is not a key of ``list_augm_support``.
    """
    global IMAGE_AUGMENTATION_SEQUENCE

    if aug_name not in list_augm_support:
        raise ValueError(
            "Unsupported augmentation '{}'; expected one of: {}".format(
                aug_name, ", ".join(sorted(list_augm_support))))

    IMAGE_AUGMENTATION_SEQUENCE = list_augm_support[aug_name]()
    return IMAGE_AUGMENTATION_SEQUENCE

In [None]:
def aug_image(image, seg, aug_name="aug_all"):
    """Apply one random augmentation identically to an image and its segmentation map.

    The augmenter for ``aug_name`` is built via ``load_aug`` on first use. It is
    rebuilt whenever a different ``aug_name`` is requested, so a name loaded
    earlier no longer sticks.

    Args:
        image: image array passed to imgaug.
        seg: integer class-id map for ``image``. Its maximum + 1 is used as the
            class count.
        aug_name: key of ``list_augm_support``.

    Returns:
        (image_aug, segmap_aug), where ``segmap_aug`` comes from ``get_arr_int()``.
        NOTE(review): its rank (H, W) vs (H, W, 1) depends on the installed
        imgaug version; confirm before relying on it.
    """
    global IMAGE_AUMENTATION_NAME_LOADED

    if IMAGE_AUGMENTATION_SEQUENCE is None or IMAGE_AUMENTATION_NAME_LOADED != aug_name:
        load_aug(aug_name)
        IMAGE_AUMENTATION_NAME_LOADED = aug_name

    # Freeze the random parameters so the image and the mask get the same transform.
    aug_det = IMAGE_AUGMENTATION_SEQUENCE.to_deterministic()
    image_aug = aug_det.augment_image(image)

    segmap = ia.SegmentationMapOnImage(seg, nb_classes=np.max(seg) + 1, shape=image.shape)
    segmap_aug = aug_det.augment_segmentation_maps(segmap)
    segmap_aug = segmap_aug.get_arr_int()

    return image_aug, segmap_aug

In [None]:
def try_n_times(fn, n, *args, **kwargs):
    """Call ``fn(*args, **kwargs)``, retrying when it raises.

    Up to ``n`` failing attempts are swallowed. One final attempt is then made,
    and its exception (if any) propagates, so ``fn`` runs at most ``n + 1`` times.

    Only ``Exception`` subclasses trigger a retry. ``KeyboardInterrupt`` and
    ``SystemExit`` propagate immediately, so a cell stuck in retries can still
    be interrupted.
    """
    attempts = 0
    while attempts < n:
        try:
            return fn(*args, **kwargs)
        except Exception:
            attempts += 1
    return fn(*args, **kwargs)

In [None]:
def augment_seg(img, seg, aug_name="aug_all"):
    """Augment ``img`` and its mask ``seg`` together using the ``aug_name`` pipeline.

    Delegates to ``aug_image`` through ``try_n_times``. The call is retried up to
    ``IMAGE_AUGMENTATION_NUM_TRIES`` times before the last error propagates.
    ``aug_name`` is forwarded to ``aug_image``.
    """
    return try_n_times(aug_image, IMAGE_AUGMENTATION_NUM_TRIES, img, seg, aug_name=aug_name)

## Data Generation

In [None]:
class DataGeneration:
    """Load resized image / one-hot mask batches from an images + annotations directory pair.

    Each annotation file is located by replacing ``'images'`` with
    ``'annotations'`` in the image path. ``anno_train_dir`` / ``anno_test_dir``
    are stored but not used for that lookup.

    Output formats:
    - images: RGB floats in [0, 1];
    - masks: one-hot float arrays flattened to (height * width, nb_classes),
      matching the model's reshaped output.
    """

    def __init__(self, img_train_dir, img_test_dir, anno_train_dir, anno_test_dir, batch_size,
                 width, height, channel, nb_classes, apply_aug=True):
        self.img_train_dir = img_train_dir
        self.img_test_dir = img_test_dir
        self.anno_train_dir = anno_train_dir
        self.anno_test_dir = anno_test_dir
        self.batch_size = batch_size
        self.width = width
        self.height = height
        self.channel = channel
        self.nb_classes = nb_classes
        self.apply_aug = apply_aug
        # Cursors into the path arrays, advanced by next_train / next_test.
        self.current_train = 0
        self.current_test = 0
        self.image_train_paths = self.load_image_paths(self.img_train_dir)
        self.image_test_paths = self.load_image_paths(self.img_test_dir)

    def load_image_paths(self, data_path):
        """Return the full paths of all entries in ``data_path`` as a numpy array.

        Entries are sorted by name so batch order does not depend on the
        filesystem's listing order.
        """
        return np.array([os.path.join(data_path, name) for name in sorted(os.listdir(data_path))])

    @staticmethod
    def _center_crop(img, height, width):
        """Return the central ``height`` x ``width`` window of ``img``, clipped to its bounds."""
        top = max((img.shape[0] - height) // 2, 0)
        left = max((img.shape[1] - width) // 2, 0)
        return img[top:top + height, left:left + width]

    @staticmethod
    def _one_hot(label, nb_classes):
        """Turn an (H, W) or (H, W, 1) class-id map into an (H * W, nb_classes) float one-hot matrix.

        Ids outside ``[0, nb_classes)`` produce an all-zero row.
        """
        label = np.asarray(label)
        if label.ndim == 3:
            # Drop the singleton channel axis so the comparison below broadcasts per pixel.
            label = label[..., 0]
        one_hot = (label[..., np.newaxis] == np.arange(nb_classes)).astype("float")
        return one_hot.reshape(-1, nb_classes)

    @staticmethod
    def _imread(path, *flags):
        """``cv2.imread`` that raises instead of silently returning None."""
        img = cv2.imread(path, *flags)
        if img is None:
            raise FileNotFoundError("Could not read image file: {}".format(path))
        return img

    def image_preprocess(self, img, interpolation=None):
        """Resize so the shorter side matches the target, center-crop, then force the exact size.

        Args:
            img: 2-D (grayscale/mask) or 3-D (colour) image array.
            interpolation: OpenCV interpolation flag used for every resize;
                ``None`` means ``cv2.INTER_AREA``. Pass ``cv2.INTER_NEAREST``
                for label maps so class ids are never blended.

        Returns:
            Array of shape (self.height, self.width[, channels]).
        """
        if interpolation is None:
            interpolation = cv2.INTER_AREA

        h, w = img.shape[:2]
        if h > w:
            img = imutils.resize(img, width=self.width, inter=interpolation)
        else:
            img = imutils.resize(img, height=self.height, inter=interpolation)

        img = self._center_crop(img, self.height, self.width)
        # cv2.resize takes dsize as (width, height).
        return cv2.resize(img, (self.width, self.height), interpolation=interpolation)

    def load_data(self, image_paths, apply_aug=None):
        """Load, preprocess and optionally augment the given images and their masks.

        Args:
            image_paths: iterable of image file paths.
            apply_aug: override for ``self.apply_aug``; ``None`` keeps the instance setting.

        Returns:
            (images, segs) as numpy arrays. ``images`` holds the preprocessed RGB
            images scaled to [0, 1]. ``segs`` holds one
            (height * width, nb_classes) one-hot matrix per image.

        Raises:
            FileNotFoundError: if an image or its annotation cannot be read.
        """
        if apply_aug is None:
            apply_aug = self.apply_aug

        images, segs = [], []
        for image_path in image_paths:
            # cv2 loads BGR; flip the channel axis to RGB.
            image = self._imread(image_path)[:, :, ::-1]
            image = self.image_preprocess(image)

            label = self._imread(image_path.replace('images', 'annotations'), 0)
            label = self.image_preprocess(label, interpolation=cv2.INTER_NEAREST)
            label = np.expand_dims(label, axis=-1)

            if apply_aug:
                image, label = augment_seg(image, label)

            images.append(image.astype('float') / 255)
            segs.append(self._one_hot(label, self.nb_classes))

        return np.array(images), np.array(segs)

    def next_train(self):
        """Return the next (augmented, per ``self.apply_aug``) training batch.

        When fewer than ``batch_size`` paths remain, the cursor is reset and the
        path array is reshuffled. NOTE(review): it is always reshuffled with
        ``random_state=42``, so the epoch orders follow a fixed cycle.
        """
        if self.current_train + self.batch_size > len(self.image_train_paths):
            self.current_train = 0
            self.image_train_paths = shuffle(self.image_train_paths, random_state=42)

        start = self.current_train
        batch = self.load_data(self.image_train_paths[start:start + self.batch_size])
        self.current_train += self.batch_size
        return batch

    def next_test(self):
        """Return the next test batch; test images are never augmented."""
        if self.current_test + self.batch_size > len(self.image_test_paths):
            self.current_test = 0
            self.image_test_paths = shuffle(self.image_test_paths, random_state=42)

        start = self.current_test
        batch = self.load_data(self.image_test_paths[start:start + self.batch_size], apply_aug=False)
        self.current_test += self.batch_size
        return batch

    def get_batch(self, s="train"):
        """Yield ``(batch_X, batch_y)`` forever from the train split, or the test split for any other ``s``."""
        while True:
            if s == "train":
                batch_X, batch_y = self.next_train()
            else:
                batch_X, batch_y = self.next_test()

            yield (batch_X, batch_y)

In [None]:
# apply_aug is left at its default (True), so batches loaded through this object are augmented.
data_gener = DataGeneration(
    img_train_dir=IMG_TRAIN_DIR,
    img_test_dir=IMG_VAL_DIR,
    anno_train_dir=ANNO_TRAIN_DIR,
    anno_test_dir=ANNO_VAL_DIR,
    batch_size=BATCH_SIZE,
    width=WIDTH,
    height=HEIGHT,
    channel=CHANNEL,
    nb_classes=N_CLASSES,
)

In [None]:
# First test image path (order as produced by load_image_paths).
img_test_path = data_gener.image_test_paths[0]

In [None]:
# Matching annotation: same path with 'images' replaced by 'annotations', as load_data does.
img_seg_test_path = data_gener.image_test_paths[0].replace('images', 'annotations')

In [None]:
# Raw sample for the augmentation preview. Unlike load_data there is no [:, :, ::-1]
# flip, so img_test stays in OpenCV's BGR order. The mask is read as single-channel
# grayscale and given a trailing channel axis.
img_test = cv2.imread(img_test_path)
img_seg_test = cv2.imread(img_seg_test_path, 0)
img_seg_test = np.expand_dims(img_seg_test, axis=-1)
img_seg_test.shape

In [None]:
# NOTE(review): data_gener was created with the default apply_aug=True, so this
# validation set is randomly augmented on load. Validation data is normally left
# un-augmented; confirm this is intended.
X_val, y_val = data_gener.load_data(data_gener.image_test_paths)

In [None]:
X_val.shape

In [None]:
len(data_gener.image_train_paths)

In [None]:
img_aug, img_seg_aug = augment_seg(img_test, img_seg_test)

In [None]:
# segmap = ia.SegmentationMapOnImage(img_seg_test, shape=img_test.shape, nb_classes=np.max(img_seg_test)+1)

In [None]:
plt.imshow(img_aug)

In [None]:
plt.imshow(img_seg_aug)

## Model Architecture

In [None]:
def initial_bias(shape, dtype=None):
    """Keras-style initializer: Gaussian samples with mean 0.5 and std 0.01 (``dtype`` is ignored)."""
    mean, std = 0.5, 1e-2
    return np.random.normal(mean, std, shape)

In [None]:
def initial_weight(shape, dtype=None):
    """Keras-style initializer: zero-mean Gaussian samples with std 0.01 (``dtype`` is ignored)."""
    mean, std = 0.0, 1e-2
    return np.random.normal(mean, std, shape)

In [None]:
def base_model(intput_height=HEIGHT, input_width=WIDTH, image_ordering="channels_last"):
    """VGG-style encoder: five stages of two conv units followed by 2x2 max-pooling.

    Args:
        intput_height: input height (parameter name kept as-is for compatibility).
        input_width: input width.
        image_ordering: "channels_first" or "channels_last". It is passed as
            ``data_format`` to every padding, conv and pooling layer and selects
            the BatchNormalization axis.

    Returns:
        (img_input, levels): the 3-channel Input tensor and the five pooled
        feature maps, with 64, 128, 256, 256 and 256 filters respectively.
    """
    channels_first = image_ordering == "channels_first"
    bn_axis = 1 if channels_first else -1
    if channels_first:
        input_shape = (3, intput_height, input_width)
    else:
        input_shape = (intput_height, input_width, 3)

    def conv_bn_relu(x, filters):
        # 1-pixel zero padding + 'valid' 3x3 conv keeps the spatial size unchanged.
        x = ZeroPadding2D((1, 1), data_format=image_ordering)(x)
        x = Conv2D(filters, (3, 3), padding='valid', data_format=image_ordering)(x)
        x = BatchNormalization(axis=bn_axis)(x)
        return Activation("relu")(x)

    img_input = Input(shape=input_shape)
    levels = []
    x = img_input
    for filters in (64, 128, 256, 256, 256):
        x = conv_bn_relu(x, filters)
        x = conv_bn_relu(x, filters)
        # Pool in the same data layout as the convolutions.
        x = MaxPooling2D(pool_size=(2, 2), data_format=image_ordering)(x)
        levels.append(x)

    return img_input, levels

In [None]:
# class UnetSegment:
#     @staticmethod
#     def build(height, width, channel, n_classes):
#         input_shape = (height, width, channel)
        
#         if K.image_data_format() == "channel_firsts":
#             input_shape = (channel, height, width)
        
#         model = Model()
#         inputs = Input(shape=input_shape)
        
#         #Conv block 1
#         conv_1 = Conv2D(32, (3, 3), padding="same", activation="relu", 
#                         bias_initializer=initial_bias)(inputs)
#         conv_1 = Conv2D(32, (3, 3), padding="same", activation="relu", 
#                         bias_initializer=initial_bias, 
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(conv_1)
#         batch_1 = BatchNormalization()(conv_1)
#         maxpool_1 = MaxPooling2D(pool_size=(2,2))(batch_1)
        
#         #Conv block 2
#         conv_2 = Conv2D(64, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(maxpool_1)
#         conv_2 = Conv2D(64, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(conv_2)
#         batch_2 = BatchNormalization()(conv_2)
#         maxpool_2 = MaxPooling2D(pool_size=(2,2))(batch_2)
        
#         #Conv block 3
#         conv_3 = Conv2D(128, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(maxpool_2)
#         conv_3 = Conv2D(128, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(conv_3)
#         batch_3 = BatchNormalization()(conv_3)
#         maxpool_3 = MaxPooling2D(pool_size=(2,2))(batch_3)
        
#         #Conv block 4
#         conv_4 = Conv2D(256, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(maxpool_3)
#         conv_4 = Conv2D(256, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(conv_4)
#         batch_4 = BatchNormalization()(conv_4)
#         maxpool_4 = MaxPooling2D(pool_size=(2,2))(batch_4)
        
#         #Conv block 5
#         conv_4_ = Conv2D(512, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(maxpool_4)
#         drop_4_ = Dropout(0.2)(conv_4_)
#         conv_4_ = Conv2D(512, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(conv_4_)
        
#         # Decoder block 1
#         up_1 = concatenate([UpSampling2D((2, 2))(conv_4_), conv_4], axis=-1)
#         conv_5 = Conv2D(256, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(up_1)
#         conv_5 = Conv2D(256, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(conv_5)
#         batch_5 = BatchNormalization()(conv_5)
        
#         # Decoder block 2
#         up_2 = concatenate([UpSampling2D((2, 2))(batch_5), conv_3], axis=-1)
#         conv_6 = Conv2D(128, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(up_2)
#         conv_6 = Conv2D(128, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(conv_6)
#         batch_6 = BatchNormalization()(conv_6)
        
#         # Decoder block 3
#         up_3 = concatenate([UpSampling2D((2, 2))(batch_6), conv_2], axis=-1)
#         conv_7 = Conv2D(64, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(up_3)
#         conv_7 = Conv2D(64, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(conv_7)
#         batch_7 = BatchNormalization()(conv_7)
        
#         # Decoder block 4
#         up_4 = concatenate([UpSampling2D((2, 2))(batch_7), conv_1], axis=-1)
#         conv_8 = Conv2D(32, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(up_4)
#         conv_8 = Conv2D(32, (3, 3), padding="same", activation="relu",
#                         bias_initializer=initial_bias,
#                         kernel_initializer=initial_weight,
#                         kernel_regularizer=l2(1e-4))(conv_8)
#         batch_8 = BatchNormalization()(conv_8)
        
#         # Outputs model
#         out = Conv2D(N_CLASSES, (1, 1), padding='same', activation="relu")(batch_8)
#         out = Reshape((out.shape[1]*out.shape[2], -1))(out)
#         out = Activation("softmax")(out)
# #         print(out.shape)
        
#         model = Model(inputs=inputs, outputs=out)
        
        
#         return model

In [None]:
class Unet:
    """U-Net style decoder built on top of a 5-level encoder such as ``base_model``."""

    @staticmethod
    def _double_conv(x, filters, bn_axis, image_ordering):
        """Two (zero-pad, 3x3 'valid' ReLU conv, BatchNorm) units; the spatial size is preserved."""
        for _ in range(2):
            x = ZeroPadding2D((1, 1), data_format=image_ordering)(x)
            x = Conv2D(filters, (3, 3), padding='valid', activation='relu',
                       data_format=image_ordering)(x)
            x = BatchNormalization(axis=bn_axis)(x)
        return x

    @staticmethod
    def build(encoder, n_classes=N_CLASSES, image_ordering="channels_last", input_width=WIDTH, input_height=HEIGHT):
        """Assemble the segmentation model.

        Args:
            encoder: callable invoked as ``encoder(input_height, input_width, image_ordering)``.
                It must return ``(img_input, levels)`` with exactly five levels.
            n_classes: number of output classes.
            image_ordering: "channels_first" or "channels_last".
            input_width, input_height: spatial input size forwarded to ``encoder``.

        Returns:
            A keras ``Model`` whose output is reshaped to (batch, H * W, n_classes)
            and passed through a softmax over the last axis.
        """
        channels_first = image_ordering == "channels_first"
        bn_axis = merge_axis = 1 if channels_first else -1

        img_input, levels = encoder(input_height, input_width, image_ordering)
        f1, f2, f3, f4, _ = levels

        # Bottleneck on the 4th encoder level, then upsample and merge skip connections.
        o = Unet._double_conv(f4, 512, bn_axis, image_ordering)
        for skip, filters in ((f3, 256), (f2, 256), (f1, 128)):
            o = UpSampling2D((2, 2), data_format=image_ordering)(o)
            o = concatenate([o, skip], axis=merge_axis)
            o = Unet._double_conv(o, filters, bn_axis, image_ordering)

        o = UpSampling2D((2, 2), data_format=image_ordering)(o)
        o = Conv2D(n_classes, (3, 3), padding='same', data_format=image_ordering)(o)

        if channels_first:
            # Move classes to the last axis so the per-pixel reshape below is correct.
            from keras.layers import Permute
            o = Permute((2, 3, 1))(o)

        out_h, out_w = K.int_shape(o)[1:3]
        o = Reshape((out_h * out_w, -1))(o)
        out = Activation("softmax")(o)

        return Model(inputs=img_input, outputs=out)

In [None]:
unet_segment = Unet.build(base_model)

In [None]:
unet_segment.summary()

In [None]:
opt = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False)

In [None]:
weight_dir = "./weights"

if not os.path.exists(weight_dir):
    os.mkdir(weight_dir)
    
file_path = os.path.join(weight_dir, "best_weigth.hdf5")

In [None]:
checkpoint = ModelCheckpoint(filepath=file_path, mode="min", monitor="val_loss", save_best_only=True, verbose=1)

In [None]:
unet_segment.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])

In [None]:
# Train from the endless training-batch generator. The in-memory (X_val, y_val)
# arrays are the validation set, and the checkpoint callback keeps the best val_loss.
H = unet_segment.fit_generator(
    data_gener.get_batch(),
    steps_per_epoch=len(data_gener.image_train_paths) // BATCH_SIZE,
    epochs=200,
    initial_epoch=0,
    validation_data=(X_val, y_val),
    validation_steps=X_val.shape[0] // BATCH_SIZE,
    callbacks=[checkpoint],
    verbose=1,
)

## Prediction

In [None]:
# random.randint includes both bounds.
from random import randint

# prediction() below calls the misspelled name `ranint`; keep an alias so it resolves.
ranint = randint

In [None]:
def prediction():
    idx = ranint(0, len(data_gener.image_test_paths))
    img_path = data_gener.image_test_paths[idx]
    image, seg = data_gener.load_data([img_path])
    p = unet_segment.predict(image)[0]
    p = p.argmax(axis=2)
    
    seg_img = np.zeros(())