In [None]:
####################################################
# Variables in this cell shall be assigned by user #
####################################################

# --- Paths (end non-empty folder paths with '/') ---
OUTPUT_DIR = ''            # folder where the weights of the trained model are saved
DATASET_DIR = ''           # folder containing the dataset's 'images' and 'masks' folders

# --- Initialisation ---
TRAIN_FROM_SCRATCH = True  # True: start from random weights; False: load WEIGHTS first
WEIGHTS = ''               # path to a pre-trained weights file (read only if TRAIN_FROM_SCRATCH is False)

# --- Model and training ---
MODEL = 'Segnet'           # network architecture: 'Segnet' or 'Xunet'
BATCHSIZE = 4              # integer, default 4; choose according to GPU memory and RAM
USE_AUG = False            # set to True to use data augmentation

In [None]:
import os
import time
import glob
import cv2
import numpy as np
import random
import imageio
import tensorflow.keras
import tensorflow as tf

import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

from tensorflow.keras.utils import to_categorical
from keras.metrics import MeanIoU

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.utils import normalize
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, \
    Dropout, Lambda, LeakyReLU, Add, ZeroPadding2D

# Input image geometry (height x width x color channels) and number of segmentation classes
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 512, 512, 1
n_classes = 4

In [None]:
def get_model():
    """Build the network selected by the global MODEL setting.

    Returns:
        Unet_Xception_ResNetBlock(...) when MODEL == 'Xunet', otherwise Segnet(...).
        An unrecognised MODEL value falls back to Segnet and emits a warning.
        Both models are built with the notebook's n_classes / IMG_HEIGHT / IMG_WIDTH.
    """
    # Compare strings with ==, never `is`: identity of equal strings depends on
    # interning, so `MODEL is 'Xunet'` can be False even when MODEL equals 'Xunet'.
    if MODEL == 'Xunet':
        return Unet_Xception_ResNetBlock(n_classes, IMG_HEIGHT, IMG_WIDTH)
    if MODEL != 'Segnet':
        import warnings
        warnings.warn("Unknown MODEL %r; falling back to 'Segnet'" % (MODEL,))
    return Segnet(n_classes, IMG_HEIGHT, IMG_WIDTH)

def Segnet(nClasses=4, input_height=512, input_width=512, input_channels=1):
    """Build a SegNet-style encoder/decoder for per-pixel classification.

    Encoder: four stages of two (3x3 Conv2D + ReLU -> BatchNormalization) layers
    followed by 2x2 max pooling, with 64/128/256/512 filters.
    Decoder: four stages of 2x nearest-neighbour upsampling followed by two
    (3x3 Conv2D + ReLU -> BatchNormalization) layers, with 512/256/128/64 filters.
    A 1x1 softmax convolution produces the class probabilities. There are no skip
    connections and pooling indices are not reused.

    Args:
        nClasses: number of output classes (channels of the softmax output).
        input_height, input_width: input size. The output has the same spatial
            size only when both are divisible by 16 (four poolings/upsamplings).
        input_channels: channels of the input image; default 1 (grayscale).

    Returns:
        An uncompiled keras Model.
    """
    def conv_pair(x, filters):
        # Two 3x3 conv + BN layers; the building block of every stage.
        for _ in range(2):
            x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
            x = BatchNormalization()(x)
        return x

    inputs = Input(shape=(input_height, input_width, input_channels))

    # Layers are created in exactly the same order as the original unrolled code,
    # so auto-generated layer names and saved weight files remain compatible.
    x = inputs
    for filters in (64, 128, 256, 512):      # encoder
        x = conv_pair(x, filters)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    for filters in (512, 256, 128, 64):      # decoder
        x = UpSampling2D(size=(2, 2))(x)
        x = conv_pair(x, filters)

    outputs = Conv2D(nClasses, (1, 1), padding='same', activation='softmax')(x)

    return Model(inputs, outputs)

from keras.applications.xception import Xception

def convolution_block(x, filters, size, strides=(1,1), padding='same', activation=True):
    """Conv2D -> BatchNormalization, optionally followed by LeakyReLU(alpha=0.1).

    Args:
        x: input tensor.
        filters, size, strides, padding: forwarded unchanged to Conv2D.
        activation: the LeakyReLU is appended only when this compares equal to True.

    Returns:
        The output tensor.
    """
    y = Conv2D(filters, size, strides=strides, padding=padding)(x)
    y = BatchNormalization()(y)
    if activation != True:
        # Linear output, e.g. the second conv of residual_block before its Add.
        return y
    return LeakyReLU(alpha=0.1)(y)

def residual_block(blockInput, num_filters=16):
    """Residual unit with an identity (batch-normalized) shortcut.

    Main branch: LeakyReLU(0.1) -> BatchNormalization -> convolution_block(3x3)
    -> convolution_block(3x3, no activation). Shortcut: BatchNormalization of the
    input. The two are summed with Add.

    blockInput must already have num_filters channels, because the shortcut has
    no projection convolution.
    """
    # Statement order matches layer creation order; keep it so auto-generated
    # layer names (and thus saved weights) stay the same.
    activated = LeakyReLU(alpha=0.1)(blockInput)
    branch = BatchNormalization()(activated)
    shortcut = BatchNormalization()(blockInput)
    branch = convolution_block(branch, num_filters, (3, 3))
    branch = convolution_block(branch, num_filters, (3, 3), activation=False)
    return Add()([branch, shortcut])


def Unet_Xception_ResNetBlock(nClasses=4, input_height=512, input_width=512):
    """U-Net-style segmentation network with an Xception encoder and residual decoder.

    The encoder is keras' Xception built without pretrained weights
    (weights=None, include_top=False) for single-channel input. Intermediate
    backbone outputs (layers[121], [31], [21], [11]) serve as skip connections.
    A bottleneck (extra 2x2 max-pool, 512-filter conv, two residual blocks) is
    followed by five stride-2 Conv2DTranspose stages. Each stage is followed by
    Dropout, a 3x3 conv, two residual_block calls and LeakyReLU(0.1). The first
    four stages also concatenate a skip connection. A 1x1 softmax conv gives the
    class probabilities.

    Args:
        nClasses: number of output channels of the final softmax conv.
        input_height, input_width: input size.
            NOTE(review): the skip-layer indices and ZeroPadding2D amounts below
            look tuned for the default 512x512 input and for one particular Keras
            Xception layer ordering. Other sizes/versions presumably fail at the
            concatenate calls; verify with backbone.summary().

    Returns:
        An uncompiled keras Model from backbone.input to per-pixel class probabilities.
    """

    # Xception encoder, randomly initialised, single (grayscale) input channel.
    backbone = Xception(input_shape=(input_height, input_width, 1), weights=None, include_top=False)

    inputs = backbone.input

    # Deepest skip connection, taken from a hard-coded backbone layer index
    # (presumably 32x32 for a 512 input -- TODO confirm).
    conv4 = backbone.layers[121].output
    conv4 = LeakyReLU(alpha=0.1)(conv4)
    pool4 = MaxPooling2D((2, 2))(conv4)
    pool4 = Dropout(0.1)(pool4)

    # Bottleneck: 16*32 = 512 filters. Filter counts halve at every decoder stage.
    convm = Conv2D(16*32, (3, 3), activation=None, padding="same")(pool4)
    convm = residual_block(convm, 16*32)
    convm = residual_block(convm, 16*32)
    convm = LeakyReLU(alpha=0.1)(convm)

    # Decoder stage 4: upsample x2 and merge with the conv4 skip.
    deconv4 = Conv2DTranspose(16*16, (3, 3), strides=(2, 2), padding="same")(convm)
    uconv4 = concatenate([deconv4, conv4])
    uconv4 = Dropout(0.1)(uconv4)

    uconv4 = Conv2D(16*16, (3, 3), activation=None, padding="same")(uconv4)
    uconv4 = residual_block(uconv4, 16 * 16)
    uconv4 = residual_block(uconv4, 16*16)
    uconv4 = LeakyReLU(alpha=0.1)(uconv4)

    # Decoder stage 3: skip from backbone layer 31 (hard-coded index).
    deconv3 = Conv2DTranspose(16*8, (3, 3), strides=(2, 2), padding="same")(uconv4)
    conv3 = backbone.layers[31].output
    uconv3 = concatenate([deconv3, conv3])
    uconv3 = Dropout(0.1)(uconv3)

    uconv3 = Conv2D(16*8, (3, 3), activation=None, padding="same")(uconv3)
    uconv3 = residual_block(uconv3, 16*8)
    uconv3 = residual_block(uconv3, 16*8)
    uconv3 = LeakyReLU(alpha=0.1)(uconv3)

    # Decoder stage 2: skip from backbone layer 21 (hard-coded index).
    deconv2 = Conv2DTranspose(16*4, (3, 3), strides=(2, 2), padding="same")(uconv3)
    conv2 = backbone.layers[21].output
    # Pads 1 row on top and 1 column on the left. Presumably this makes the odd-sized
    # skip map (Xception's 'valid' entry convs) match deconv2's size -- verify.
    conv2 = ZeroPadding2D(((1,0),(1,0)))(conv2)
    uconv2 = concatenate([deconv2, conv2])

    uconv2 = Dropout(0.1)(uconv2)
    uconv2 = Conv2D(16*4, (3, 3), activation=None, padding="same")(uconv2)
    uconv2 = residual_block(uconv2, 16*4)
    uconv2 = residual_block(uconv2, 16*4)
    uconv2 = LeakyReLU(alpha=0.1)(uconv2)

    # Decoder stage 1: skip from backbone layer 11 (hard-coded index).
    deconv1 = Conv2DTranspose(16*2, (3, 3), strides=(2, 2), padding="same")(uconv2)
    conv1 = backbone.layers[11].output
    # Pads 3 rows on top and 3 columns on the left, presumably for the same
    # size-matching reason as above -- verify.
    conv1 = ZeroPadding2D(((3,0),(3,0)))(conv1)
    uconv1 = concatenate([deconv1, conv1])

    uconv1 = Dropout(0.1)(uconv1)
    uconv1 = Conv2D(16*2, (3, 3), activation=None, padding="same")(uconv1)
    uconv1 = residual_block(uconv1, 16*2)
    uconv1 = residual_block(uconv1, 16*2)
    uconv1 = LeakyReLU(alpha=0.1)(uconv1)

    # Last x2 upsampling stage; unlike the previous four it has no skip connection.
    uconv0 = Conv2DTranspose(16*1, (3, 3), strides=(2, 2), padding="same")(uconv1)
    uconv0 = Dropout(0.1)(uconv0)
    uconv0 = Conv2D(16*1, (3, 3), activation=None, padding="same")(uconv0)
    uconv0 = residual_block(uconv0, 16*1)
    uconv0 = residual_block(uconv0, 16*1)
    uconv0 = LeakyReLU(alpha=0.1)(uconv0)

    # Lighter dropout (0.05) right before the classifier.
    uconv0 = Dropout(0.1/2)(uconv0)


    # Per-pixel class probabilities.
    outputs = Conv2D(nClasses, (1, 1), padding='same', activation='softmax')(uconv0)

    model = Model(inputs, outputs)

    return model

In [None]:
def _list_image_files(img_dir):
    """Return the sorted names of the regular files in img_dir.

    Raises:
        ValueError: img_dir contains no regular files. Without this check the
            endless batch loop would spin forever without yielding anything.
    """
    names = sorted(name for name in os.listdir(img_dir)
                   if os.path.isfile(os.path.join(img_dir, name)))
    if not names:
        raise ValueError('No image files found in %r' % (img_dir,))
    return names


def _mask_name(image_name):
    """Mask file name for an image: same stem, 'jpg' -> 'png' in the extension only.

    Only the extension is rewritten, so a stem that happens to contain 'jpg'
    (e.g. 'jpg_01.jpg' -> 'jpg_01.png') is kept intact.
    """
    stem, ext = os.path.splitext(image_name)
    return stem + ext.replace('jpg', 'png')


def _iterate_batches(names, batch_size):
    """Yield consecutive slices of at most batch_size names, forever.

    names is reshuffled in place before every pass, so batch composition differs
    between epochs. The last slice of each pass may be shorter.
    """
    while True:
        random.shuffle(names)
        for start in range(0, len(names), batch_size):
            yield names[start:start + batch_size]


def _read_pair(img_dir, label_dir, image_name):
    """Load an image (as grayscale) and its mask from label_dir."""
    img_path = os.path.join(img_dir, image_name)
    img = cv2.imread(img_path, 0)  # flag 0 == cv2.IMREAD_GRAYSCALE
    if img is None:  # cv2.imread reports unreadable files by returning None
        raise IOError('Could not read image %r' % (img_path,))
    mask = imageio.imread(os.path.join(label_dir, _mask_name(image_name)))
    return img, mask


def _to_model_arrays(images, masks, num_classes):
    """Stack a batch of images/masks into network inputs and one-hot targets."""
    x_batch = np.expand_dims(np.array(images, np.float32) / 255., axis=3)
    # NOTE(review): this L2-normalizes along axis 1 (image height), which makes the
    # /255 above redundant. Kept unchanged to preserve the training input pipeline.
    x_batch = normalize(x_batch, axis=1)
    y_batch = np.expand_dims(np.array(masks, np.float32), axis=3)
    y_batch = to_categorical(y_batch, num_classes=num_classes)
    return x_batch, y_batch


def _build_augmenter():
    """imgaug pipeline applied jointly to an image and its segmentation map."""
    # NOTE(review): cval/pad_cval=160 fills only the image. imgaug presumably fills
    # the segmentation map with 0 -- confirm that class 0 is the background.
    return iaa.Sequential([
        iaa.Sometimes(0.2, iaa.CropAndPad(percent=(0, 0.2), pad_mode="constant", pad_cval=160)),
        iaa.Sometimes(0.7, iaa.Affine(rotate=(-180, 180), mode='constant', cval=160)),
        iaa.Fliplr(0.4),
        iaa.Flipud(0.4),
    ], random_order=True)


def img_generator(img_dir, label_dir, batch_size, num_classes=4):
    """Endlessly yield augmented (x_batch, y_batch) batches.

    Args:
        img_dir: folder with the input images (read as grayscale).
        label_dir: folder with one mask per image. The mask name is the image
            name with 'jpg' -> 'png' in the extension.
        batch_size: images per batch; the last batch of each pass may be smaller.
        num_classes: depth of the one-hot targets (default 4).

    Yields:
        x_batch: float array (batch, H, W, 1), scaled as in _to_model_arrays.
        y_batch: one-hot array, (batch, H, W, num_classes) assuming 2-D masks.

    Raises:
        ValueError: img_dir contains no files.
        IOError: an image cannot be decoded.
    """
    names = _list_image_files(img_dir)  # fail fast instead of looping forever
    seq = _build_augmenter()
    for batch_names in _iterate_batches(names, batch_size):
        images, masks = [], []
        for name in batch_names:
            img, mask = _read_pair(img_dir, label_dir, name)
            segmap = SegmentationMapsOnImage(mask, shape=img.shape)
            img_aug, segmap_aug = seq(image=img, segmentation_maps=segmap)
            images.append(img_aug)
            masks.append(segmap_aug.get_arr())
        yield _to_model_arrays(images, masks, num_classes)


def img_generator_not_aug(img_dir, label_dir, batch_size, num_classes=4):
    """Like img_generator, but yields the images and masks without augmentation.

    Use this for validation/test data. Arguments, yielded arrays and raised errors
    are the same as for img_generator.
    """
    names = _list_image_files(img_dir)  # fail fast instead of looping forever
    for batch_names in _iterate_batches(names, batch_size):
        pairs = [_read_pair(img_dir, label_dir, name) for name in batch_names]
        images = [img for img, _ in pairs]
        masks = [mask for _, mask in pairs]
        yield _to_model_arrays(images, masks, num_classes)

In [None]:
# TRAINING OF THE NET


class ArgmaxMeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU for one-hot targets and softmax predictions.

    tf.keras.metrics.MeanIoU expects integer class ids. Fed one-hot labels and
    probabilities directly, it casts them to integers, so it would score 0/1
    indicator values instead of classes. Taking argmax over the last axis first
    recovers the class id of every pixel.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(tf.argmax(y_true, axis=-1),
                                    tf.argmax(y_pred, axis=-1),
                                    sample_weight)


model = get_model()
# NOTE: a full model saved with this custom metric must be reloaded with
# custom_objects={'ArgmaxMeanIoU': ArgmaxMeanIoU} (or compile=False).
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=[ArgmaxMeanIoU(num_classes=n_classes)])

if not TRAIN_FROM_SCRATCH:
    model.load_weights(WEIGHTS)

batchsize = BATCHSIZE

# Dataset layout: DATASET_DIR/{images,masks}/{train,test}/
# os.path.join works whether or not DATASET_DIR ends with a separator.
img_traindir = os.path.join(DATASET_DIR, 'images', 'train')
seg_traindir = os.path.join(DATASET_DIR, 'masks', 'train')

img_testdir = os.path.join(DATASET_DIR, 'images', 'test')
seg_testdir = os.path.join(DATASET_DIR, 'masks', 'test')

# Augmentation (if enabled) applies to training data only: validation batches
# must reflect the real data for val_loss-based checkpointing/early stopping.
make_train_generator = img_generator if USE_AUG else img_generator_not_aug
train_generator = make_train_generator(img_traindir, seg_traindir, batchsize)
test_generator = img_generator_not_aug(img_testdir, seg_testdir, batchsize)

# ModelCheckpoint cannot create missing folders when writing the .h5 file.
if OUTPUT_DIR:
    os.makedirs(OUTPUT_DIR, exist_ok=True)

# TODO: modelname
callbacks = [
    ModelCheckpoint(os.path.join(OUTPUT_DIR, '_' + MODEL + '_.h5'), verbose=1, save_best_only=True),
    EarlyStopping(patience=15, monitor='val_loss'),
]


In [1]:
# Count the files in the same directories the generators read (regular files only).
# The paths must be joined: os.path.isfile on a bare name would look in the CWD.
numof_train_images = sum(os.path.isfile(os.path.join(img_traindir, f))
                         for f in os.listdir(img_traindir))
numof_test_images = sum(os.path.isfile(os.path.join(img_testdir, f))
                        for f in os.listdir(img_testdir))

if numof_train_images == 0 or numof_test_images == 0:
    raise ValueError('Found %d training and %d test images under %r; both must be > 0'
                     % (numof_train_images, numof_test_images, DATASET_DIR))

# Ceil division: the generators also yield the final partial batch. This makes
# one Keras epoch exactly one pass over each directory, and never gives 0 steps
# when there are fewer images than batchsize.
steps_per_epoch = -(-numof_train_images // batchsize)
validation_steps = -(-numof_test_images // batchsize)

# Shuffling is done inside the generators; fit's `shuffle` is ignored for generator input.
history = model.fit(train_generator,
                    verbose=1,
                    epochs=500,
                    callbacks=callbacks,
                    validation_data=test_generator,
                    steps_per_epoch=steps_per_epoch,
                    validation_steps=validation_steps)

IndentationError: ignored