In [None]:
# Imports
import os
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("ggplot")
%matplotlib inline

from tqdm import tqdm_notebook, tnrange
from itertools import chain
from skimage.io import imread, imshow, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from sklearn.model_selection import train_test_split

import tensorflow as tf

from keras.models import Model, load_model
from keras.layers import Input, BatchNormalization, Activation, Dense, Dropout
from keras.layers.core import Lambda, RepeatVector, Reshape
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D, GlobalMaxPool2D
from keras.layers.merge import concatenate, add
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from PIL import Image

In [None]:
def segment_eyes_mask(array):
    """Turn an RGB label image into a binary mask of its pure-green pixels.

    A pixel becomes 1 when its channels are exactly (R, G, B) == (0, 255, 0),
    the colour this notebook uses to select the eye region, and 0 otherwise.

    Parameters
    ----------
    array : numpy.ndarray
        Image indexed as [row, column, channel], with red, green and blue in
        channels 0, 1 and 2.

    Returns
    -------
    numpy.ndarray
        A new 2-D (rows x columns) array of 0/1 values with the same dtype as
        ``array``. The input array is not modified.
    """
    red = array[:, :, 0]
    green = array[:, :, 1]
    blue = array[:, :, 2]
    # One vectorised comparison replaces the per-pixel Python loop. It also
    # builds a fresh array, so the caller's green channel is no longer
    # overwritten as a side effect.
    is_pure_green = (green == 255) & (red == 0) & (blue == 0)
    return is_pure_green.astype(array.dtype)

In [None]:
# Load the grayscale head images and their eye masks into dense arrays.
DATA_DIR = "../input/head-segmentation-masks"
LABELS_DIR = os.path.join(DATA_DIR, "labels")
IMG_SIZE = (256, 256)  # (height, width) of every image fed to the network

# Every sub-folder except these two is treated as one patient. "labels" is the
# mask source read below. Filtering (rather than list.remove) does not raise
# when one of the folders is absent.
patients = [d for d in next(os.walk(DATA_DIR))[1] if d not in ("labels", "real_photos")]
im_per_patient =30
# Allocate for the maximum possible number of samples; trimmed after the loop.
X = np.zeros((len(patients)*im_per_patient, 256, 256,1), dtype=np.float32)
y = np.zeros((len(patients)*im_per_patient, 256, 256,1), dtype=np.float32)
image_number=0
for patient_id_ in tqdm_notebook(patients, total=len(patients)):
    # Label folders are keyed by the part of the folder name before the first
    # underscore. split() keeps the whole name when there is no underscore,
    # whereas find() + slicing would silently drop the last character.
    name = patient_id_.split("_", 1)[0]
    if name == "male06":
        # This subject is excluded (presumably its labels are unusable - TODO confirm).
        # Skipping here avoids loading images that would be thrown away.
        continue
    patient_dir = os.path.join(DATA_DIR, patient_id_)
    patient_image_ids = next(os.walk(patient_dir))[2]
    for img_id_ in tqdm_notebook(patient_image_ids[:im_per_patient]):
        # color_mode="grayscale" replaces the deprecated grayscale=True flag;
        # target_size is a (height, width) pair.
        img = load_img(os.path.join(patient_dir, img_id_), color_mode="grayscale", target_size=IMG_SIZE)
        x_img = img_to_array(img)
        mask = load_img(os.path.join(LABELS_DIR, name, img_id_), target_size=IMG_SIZE)
        mask = segment_eyes_mask(img_to_array(mask))
        X[image_number]=x_img
        # The mask already has the target size, so only the channel axis is added.
        y[image_number]=mask[..., np.newaxis]
        image_number+=1

# Drop the unused tail of the preallocated arrays. Skipped patients and folders
# with fewer than im_per_patient files would otherwise leave all-zero
# image/mask pairs that end up in the train/validation split.
X = X[:image_number]
y = y[:image_number]
print("loaded", image_number, "image/mask pairs")

In [None]:
# Visual sanity check: one input image next to its eye mask.
sample_idx = 5
fig, (ax1, ax2) = plt.subplots(1,2,figsize=(20,15))
# Index away the trailing channel axis so imshow receives a plain 2-D array
# instead of relying on the matplotlib version accepting (H, W, 1) input.
ax1.imshow(X[sample_idx, ..., 0], cmap="gray")
ax1.set_title("Input image")
ax2.imshow(y[sample_idx, ..., 0], cmap="gray")
ax2.set_title("Eye mask");

In [None]:
# Hold out 20% of the samples for validation; the fixed random_state makes the
# shuffle repeatable from run to run.
X_train, X_valid, y_train, y_valid = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=20,
)
# Not referenced in the cells visible here; presumably a sample index used by
# later cells - TODO confirm.
index = 2

In [None]:
def conv2d_block(input_tensor, n_filters, kernel_size = 3, batchnorm = True):
    """Apply one convolution, optional batch normalisation, and a ReLU.

    Parameters
    ----------
    input_tensor : tensor
        Tensor the convolution is applied to.
    n_filters : int
        Number of output filters of the convolution.
    kernel_size : int
        Side length of the square convolution kernel.
    batchnorm : bool
        When truthy, a BatchNormalization layer sits between the convolution
        and the activation.

    Returns
    -------
    tensor
        Output of the ReLU activation. The convolution uses 'same' padding.
    """
    convolution = Conv2D(
        filters=n_filters,
        kernel_size=(kernel_size, kernel_size),
        kernel_initializer='he_normal',
        padding='same',
    )
    features = convolution(input_tensor)

    if batchnorm:
        features = BatchNormalization()(features)

    return Activation('relu')(features)
  
def get_unet(input_img, n_filters = 16, dropout = 0.1, batchnorm = True):
    """Build a five-level U-Net style segmentation model.

    Encoder: five stages of (conv block -> 2x2 max-pool -> dropout), with
    ``n_filters * 1, 2, 4, 8, 16`` filters. A bottleneck conv block with
    ``n_filters * 32`` filters follows. Decoder: five mirrored stages of
    (stride-2 transposed conv -> concatenate with the encoder output of the
    same resolution -> dropout -> conv block). A final 1x1 convolution with a
    sigmoid activation yields a single-channel output.

    Parameters
    ----------
    input_img : tensor
        Model input (as created with ``Input``). Height and width must be
        divisible by 32, otherwise the five 2x up-samplings no longer match
        the resolution of the skip connections.
    n_filters : int
        Filter count of the first encoder stage; deeper stages scale it by
        powers of two.
    dropout : float
        Rate passed to every Dropout layer.
    batchnorm : bool
        Forwarded to ``conv2d_block``.

    Returns
    -------
    Model
        Uncompiled model mapping ``input_img`` to the sigmoid output.
    """
    depth = 5  # number of down-/up-sampling stages

    # Contracting path. Each stage's pre-pooling output is kept as the skip
    # connection for the decoder stage working at the same resolution.
    skips = []
    x = input_img
    for level in range(depth):
        features = conv2d_block(x, n_filters * (2 ** level), kernel_size = 3, batchnorm = batchnorm)
        skips.append(features)
        x = MaxPooling2D((2, 2))(features)
        x = Dropout(dropout)(x)

    # Bottleneck at the coarsest resolution.
    x = conv2d_block(x, n_filters * (2 ** depth), kernel_size = 3, batchnorm = batchnorm)

    # Expanding path. Building it in a loop guarantees every stage consumes
    # the previous stage's output and its own concatenation, which the
    # hand-unrolled version got wrong.
    for level in reversed(range(depth)):
        stage_filters = n_filters * (2 ** level)
        x = Conv2DTranspose(stage_filters, (3, 3), strides = (2, 2), padding = 'same')(x)
        x = concatenate([x, skips[level]])
        x = Dropout(dropout)(x)
        x = conv2d_block(x, stage_filters, kernel_size = 3, batchnorm = batchnorm)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    model = Model(inputs=[input_img], outputs=[outputs])
    return model

In [None]:
# Build the network for single-channel 256x256 inputs and compile it with a
# per-pixel binary cross-entropy loss.
input_img = Input(shape=(256, 256, 1), name='img')
model = get_unet(
    input_img,
    n_filters=16,
    dropout=0.05,
    batchnorm=True,
)
model.compile(
    loss="binary_crossentropy",
    optimizer=Adam(),
    metrics=["accuracy"],
)
model.summary()

In [None]:
# Training callbacks; each one relies on its default monitored quantity.
callbacks = [
    # Abort training after 10 epochs without improvement.
    EarlyStopping(patience=10, verbose=1),
    # Cut the learning rate tenfold after 5 stagnant epochs, never below 1e-5.
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=1e-5, verbose=1),
    # Write weights only, and only when the monitored quantity improves.
    ModelCheckpoint(
        'segmentation_256.h5',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
    ),
]
results = model.fit(
    X_train,
    y_train,
    batch_size=32,
    epochs=40,
    callbacks=callbacks,
    validation_data=(X_valid, y_valid),
)