# In [4]:
from tensorflow.keras.applications import Xception, DenseNet201, MobileNetV2, DenseNet169, DenseNet121, ResNet50, MobileNetV2
from tensorflow.keras.layers import Softmax, ReLU, GlobalAveragePooling2D, Dense, UpSampling2D, Input, Activation, Concatenate
from tensorflow.keras.layers import Conv2D, GlobalAveragePooling1D, Flatten, MaxPooling2D, Add
from tensorflow.keras.layers import Dropout, BatchNormalization, Input
from tensorflow.keras.models import Sequential, save_model, load_model, Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from glob import glob
from skimage.io import imread, imsave
from skimage.transform import resize
from sklearn.utils import shuffle
import numpy as np
import matplotlib.pyplot as plt
import PIL
from tensorflow.math import multiply

from tensorflow.keras import backend as K
def IoULoss(targets, inputs, smooth=1e-6):
    """Soft IoU (Jaccard) loss for Keras: ``1 - IoUMetrics(targets, inputs)``.

    Keras calls losses as ``loss(y_true, y_pred)``, so ``targets`` is the ground
    truth and ``inputs`` the prediction. Both tensors are flattened over the whole
    batch, so this is one IoU for the batch rather than a mean of per-image IoUs.
    ``smooth`` keeps the ratio finite when both tensors sum to zero.
    """
    # Reuse the metric so loss and metric cannot drift apart.
    return 1 - IoUMetrics(targets, inputs, smooth=smooth)
def IoUMetrics(targets, inputs, smooth=1e-6):
    """Soft IoU (Jaccard index) between ground truth ``targets`` and prediction ``inputs``.

    Both tensors are flattened over the whole batch. The intersection is the sum
    of elementwise products, and the union is ``sum(targets) + sum(inputs) -
    intersection``. ``smooth`` is added to numerator and denominator to avoid 0/0.
    """
    y_true = K.flatten(targets)
    y_pred = K.flatten(inputs)
    overlap = K.sum(multiply(y_true, y_pred))
    union = K.sum(y_true) + K.sum(y_pred) - overlap
    return (overlap + smooth) / (union + smooth)
def IoUSingle(gt, pred):
    """Hard (binary) IoU between a ground-truth mask and a predicted mask.

    Any nonzero element counts as foreground in either mask. If both masks are
    entirely background the union is empty. The masks then agree perfectly, so
    1.0 is returned instead of the ``0/0 = nan`` that would otherwise poison any
    mean computed over many images.
    """
    gt = np.asarray(gt, dtype=bool)
    pred = np.asarray(pred, dtype=bool)
    union = np.logical_or(gt, pred).sum()
    if union == 0:
        return 1.0
    return np.logical_and(gt, pred).sum() / union

def DiceLoss(targets, inputs, smooth=1e-6):
    """Soft Dice loss: ``1 - (2*|T*P| + smooth) / (|T| + |P| + smooth)``.

    ``targets`` is the ground truth and ``inputs`` the prediction (Keras
    ``y_true, y_pred`` order). Both are flattened over the whole batch.
    """
    y_true = K.flatten(targets)
    y_pred = K.flatten(inputs)
    overlap = K.sum(multiply(y_true, y_pred))
    denominator = K.sum(y_true) + K.sum(y_pred) + smooth
    return 1 - (2 * overlap + smooth) / denominator
  
# Fixed network input size: height x width x RGB channels. Smaller images are
# zero-padded into this buffer at the top-left corner.
basic_shape = (512, 512, 3)
# Matching single-channel mask size (same spatial size as the input).
basic_out_shape = basic_shape[:2] + (1,)
def _conv_bn_relu(x, filters, kernel_size):
    """Apply a same-padded Conv2D -> BatchNormalization -> ReLU stage to ``x``."""
    x = Conv2D(filters, kernel_size=kernel_size, padding='same')(x)
    x = BatchNormalization()(x)
    return Activation('relu')(x)


def _read_rgb(path):
    """Read the image at ``path`` and return it with exactly 3 channels."""
    img = imread(path)
    if img.ndim == 2:
        # Grayscale: replicate the single channel to fake RGB.
        img = np.dstack((img, img, img))
    elif img.shape[2] < 3:
        # Gray (+ alpha): replicate the gray channel, drop the rest.
        img = np.dstack((img[:, :, 0],) * 3)
    # Drop an alpha (or any extra) channel.
    return img[:, :, :3]


def _read_mask(path):
    """Read the mask at ``path`` as a 2-D array scaled to [0, 1]."""
    mask = imread(path, as_gray=True)
    # NOTE: skimage leaves single-channel files in their integer dtype (e.g. uint8
    # 0..255), but converts colour files through rgb2gray into floats in [0, 1].
    # Scale per dtype instead of dividing everything by 255.
    if np.issubdtype(mask.dtype, np.integer):
        mask = mask / np.iinfo(mask.dtype).max
    return mask


def train_model(train_img_dir='./train_data/'):
    """Build the encoder/decoder segmentation network and train it chunk by chunk.

    Images are read from ``<train_img_dir>birds/*.jpg`` and masks from
    ``<train_img_dir>gt/*.png``. The two sorted lists are paired by position and
    shuffled jointly (``random_state=1337``). Every chunk of up to ``train_split``
    pairs is loaded into memory and trained for ``epochs`` epochs before moving
    on to the next chunk. Images/masks are placed at the top-left of zero buffers
    of ``basic_shape`` / ``basic_out_shape`` and cropped if larger. The model is
    saved to ``segm.hdf5`` after every chunk.

    Returns:
        The trained Keras model (compiled with ``IoULoss`` and ``IoUMetrics``).

    Raises:
        ValueError: if no training images are found, or (from ``shuffle``) if
            the image and mask counts differ.
    """
    height, width = basic_shape[:2]
    in_layer = Input(shape=basic_shape)

    # Encoder: six 2x max-poolings (512 -> 8); every stage is kept for skip links.
    conv1 = _conv_bn_relu(in_layer, 20, (3, 3))
    conv2 = _conv_bn_relu(MaxPooling2D()(conv1), 30, (5, 5))
    conv3 = _conv_bn_relu(MaxPooling2D()(conv2), 40, (5, 5))
    conv4 = _conv_bn_relu(MaxPooling2D()(conv3), 50, (5, 5))
    conv5 = _conv_bn_relu(MaxPooling2D()(conv4), 60, (3, 3))
    conv6 = _conv_bn_relu(MaxPooling2D()(conv5), 80, (3, 3))
    conv7 = _conv_bn_relu(MaxPooling2D()(conv6), 110, (3, 3))

    # Decoder: upsample 2x, concatenate the same-resolution encoder stage, convolve.
    decoder_stages = (
        (conv6, 80, (3, 3)),
        (conv5, 60, (3, 3)),
        (conv4, 50, (3, 3)),
        (conv3, 40, (5, 5)),
        (conv2, 30, (5, 5)),
        (conv1, 20, (3, 3)),
    )
    convb = conv7
    for skip, filters, kernel_size in decoder_stages:
        convb = UpSampling2D()(convb)
        convb = Concatenate()([convb, skip])
        convb = _conv_bn_relu(convb, filters, kernel_size)

    # One-channel per-pixel foreground probability.
    convb = Conv2D(1, kernel_size=(3, 3), padding='same')(convb)
    convf = Activation('sigmoid')(convb)

    model = Model(inputs=in_layer, outputs=convf)
    model.compile(loss=IoULoss, optimizer=Adam(), metrics=[IoUMetrics])

    img_names_train = sorted(glob(train_img_dir + 'birds/' + '*.jpg'))
    img_names_gt = sorted(glob(train_img_dir + 'gt/' + '*.png'))
    if not img_names_train:
        raise ValueError('no training images found under ' + train_img_dir + 'birds/')

    train_split = 762  # images per chunk; the original 8382-image set is 11 * 762
    epochs = 30        # epochs per chunk, not over the whole dataset

    img_names_train, img_names_gt = shuffle(img_names_train, img_names_gt, random_state=1337)
    n_images = len(img_names_train)

    for j in range(0, n_images, train_split):
        print('=========================' + str(j) + '=========================')
        # The last chunk is shorter when n_images is not a multiple of train_split.
        chunk = min(train_split, n_images - j)
        img_train = np.zeros((chunk, ) + basic_shape, dtype='uint8')
        img_gt = np.zeros((chunk, ) + basic_out_shape, dtype='float32')

        pairs = zip(img_names_train[j:j + chunk], img_names_gt[j:j + chunk])
        for i, (img_name, gt_name) in enumerate(pairs):
            img = _read_rgb(img_name)[:height, :width]
            img_train[i, :img.shape[0], :img.shape[1], :] = img
            mask = _read_mask(gt_name)[:height, :width]
            img_gt[i, :mask.shape[0], :mask.shape[1], 0] = mask

        model.fit(img_train, img_gt, batch_size=8, epochs=epochs)
        model.save('segm.hdf5')
        print('=========================End of' + str(j) + '=========================')

    return model

def predict(model, img_path='./1classt/'):
    """Segment the test images with ``model``, plot each result and report the mean IoU.

    Images are read from ``<img_path>albatross_pred/*.jpg`` and masks from
    ``<img_path>albatross_gt/*.png``. The two sorted lists are paired by position.
    Each image is placed at the top-left of a zero buffer of ``basic_shape`` and
    cropped if larger. Predictions are thresholded at 0.5, and IoU is computed
    only over the original (unpadded) image area. For every image a 1x4 figure
    is drawn: input, prediction, ground truth, and their overlap.

    Returns:
        The mean IoU over all test images (also printed).

    Raises:
        ValueError: if no test images are found or the image and mask counts differ.
    """
    img_names_test = sorted(glob(img_path + 'albatross_pred/' + '*.jpg'))
    img_names_gt = sorted(glob(img_path + 'albatross_gt/' + '*.png'))
    if not img_names_test:
        raise ValueError('no test images found under ' + img_path + 'albatross_pred/')
    if len(img_names_test) != len(img_names_gt):
        raise ValueError('found %d test images but %d masks'
                         % (len(img_names_test), len(img_names_gt)))

    height, width = basic_shape[:2]
    n_images = len(img_names_test)
    img_test = np.zeros((n_images, ) + basic_shape, dtype='uint8')
    img_gt = np.zeros((n_images, ) + basic_out_shape, dtype='float32')
    sizes = np.zeros((n_images, 2), dtype=int)  # (height, width) actually used per image

    for i, (test_name, gt_name) in enumerate(zip(img_names_test, img_names_gt)):
        img = imread(test_name)
        if img.ndim == 2:
            img = np.dstack((img, img, img))        # grayscale -> RGB
        elif img.shape[2] < 3:
            img = np.dstack((img[:, :, 0],) * 3)    # gray + alpha -> RGB
        img = img[:height, :width, :3]              # crop; drop alpha/extra channels
        sizes[i] = img.shape[:2]
        img_test[i, :img.shape[0], :img.shape[1], :] = img

        mask = imread(gt_name, as_gray=True)
        # NOTE: skimage returns single-channel files in their integer dtype but
        # colour files as floats in [0, 1]; scale per dtype, not a blanket /255.
        if np.issubdtype(mask.dtype, np.integer):
            mask = mask / np.iinfo(mask.dtype).max
        mask = mask[:height, :width]
        img_gt[i, :mask.shape[0], :mask.shape[1], 0] = mask

    images = (model.predict(img_test) > 0.5).astype('uint8') * 255
    total_iou = 0
    for i, (h, w) in enumerate(sizes):
        pred_mask = images[i, :h, :w, 0]
        gt_mask = img_gt[i, :h, :w, 0]
        total_iou += IoUSingle(gt_mask, pred_mask)

        _, ax = plt.subplots(1, 4)
        ax[0].imshow(img_test[i, :h, :w, :])
        ax[1].imshow(pred_mask, cmap='gray')
        ax[2].imshow(gt_mask, cmap='gray')
        ax[3].imshow(np.logical_and(pred_mask, gt_mask))

    mean_iou = total_iou / n_images
    print('Mean IoU =', mean_iou)
    return mean_iou
  

# In [None]:
# Train on the default ./train_data/ layout, then persist the final model.
model = train_model(train_img_dir='./train_data/')
model.save(filepath='segm.hdf5')