In [None]:
# Mount Google Drive so the project files are reachable from this Colab session
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
%cd drive/My Drive/keras-u-net-master

In [None]:
from __future__ import print_function
#Import proper libraries
import os
import numpy as np

from skimage.io import imsave, imread
from PIL import Image
# Root folder holding the train/ and test/ sub-directories
data_path = 'data/'
# Expected height and width of every source image and mask (no resizing happens on load)
image_rows = image_cols = 512

#Function to create the training data
def create_train_data():
    """Read every training image and its mask and cache them as .npy files.

    Images are read from ``<data_path>/train/Image PP/``. Masks are read from
    ``<data_path>/train/Label/`` under the same file name. Both are loaded as
    grayscale and stored as uint8 arrays of shape ``(N, image_rows, image_cols)``.
    They are written to ``imgs_train.npy`` / ``imgs_mask_train.npy`` in the
    current working directory.

    Raises:
        ValueError: if an image or its mask is not ``image_rows x image_cols``.
        (A missing mask file surfaces as the error raised by ``imread``.)
    """
    train_data_path = os.path.join(data_path, 'train/Image PP/')
    train_data_Label_path = os.path.join(data_path, 'train/Label/')
    # Sort for a reproducible order, and skip hidden files / sub-folders
    # (e.g. .ipynb_checkpoints) that imread cannot decode.
    images = sorted(
        name for name in os.listdir(train_data_path)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(train_data_path, name))
    )
    total = len(images)

    imgs = np.ndarray((total, image_rows, image_cols), dtype=np.uint8)
    imgs_mask = np.ndarray((total, image_rows, image_cols), dtype=np.uint8)

    def to_uint8_range(img):
        # Per the scikit-image docs, as_gray=True converts colour files to float64
        # in [0, 1]. Rescale to 0-255 so the uint8 store below does not truncate
        # every pixel to 0 or 1.
        if img.dtype.kind == 'f':
            img = np.clip(np.round(img * 255.), 0, 255)
        return img

    print('-'*30)
    print('Creating training images...')
    print('-'*30)
    for i, image_name in enumerate(images):
        img = to_uint8_range(imread(os.path.join(train_data_path, image_name), as_gray=True))
        img_mask = to_uint8_range(imread(os.path.join(train_data_Label_path, image_name), as_gray=True))

        expected = (image_rows, image_cols)
        if img.shape != expected or img_mask.shape != expected:
            raise ValueError('{0}: expected image and mask of shape {1}, got {2} and {3}'.format(
                image_name, expected, img.shape, img_mask.shape))

        imgs[i] = img
        imgs_mask[i] = img_mask

        if i % 50 == 0:
            print('Done: {0}/{1} images'.format(i, total))
    print('Loading done.')

    np.save('imgs_train.npy', imgs)
    np.save('imgs_mask_train.npy', imgs_mask)
    print('Saving to .npy files done.')

#Function to load the training data
def load_train_data():
    """Load the cached training arrays from the current working directory.

    Returns:
        tuple: ``(imgs_train, imgs_mask_train)``, read from ``imgs_train.npy``
        and ``imgs_mask_train.npy`` in that order.
    """
    return tuple(np.load(file_name) for file_name in ('imgs_train.npy', 'imgs_mask_train.npy'))

#Function to create the testing data
def create_test_data():
    """Read every test image and cache the images and their numeric ids as .npy files.

    Images are read as grayscale from ``<data_path>/test/Image PP``. Each file's
    id is the integer before the first '.' in its name (e.g. ``12.png`` -> 12).
    Results are written to the current working directory:
    ``imgs_test.npy`` (uint8, ``(N, image_rows, image_cols)``) and
    ``imgs_id_test.npy`` (int32, ``(N,)``).

    Raises:
        ValueError: if a file name does not start with an integer id, or an
            image is not ``image_rows x image_cols``.
    """
    test_data_path = os.path.join(data_path, 'test/Image PP')
    # Sort for a reproducible order, and skip hidden files / sub-folders
    # (e.g. .ipynb_checkpoints) that are not test images.
    images = sorted(
        name for name in os.listdir(test_data_path)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(test_data_path, name))
    )
    total = len(images)

    imgs = np.ndarray((total, image_rows, image_cols), dtype=np.uint8)
    imgs_id = np.ndarray((total, ), dtype=np.int32)

    print('-'*30)
    print('Creating test images...')
    print('-'*30)
    for i, image_name in enumerate(images):
        img_id = int(image_name.split('.')[0])
        img = imread(os.path.join(test_data_path, image_name), as_gray=True)
        if img.dtype.kind == 'f':
            # Per the scikit-image docs, as_gray=True converts colour files to
            # float64 in [0, 1]. Rescale so the uint8 store below does not
            # truncate every pixel to 0 or 1.
            img = np.clip(np.round(img * 255.), 0, 255)
        if img.shape != (image_rows, image_cols):
            raise ValueError('{0}: expected shape {1}, got {2}'.format(
                image_name, (image_rows, image_cols), img.shape))

        imgs[i] = img
        imgs_id[i] = img_id

        if i % 10 == 0:
            print('Done: {0}/{1} images'.format(i, total))
    print('Loading done.')

    np.save('imgs_test.npy', imgs)
    np.save('imgs_id_test.npy', imgs_id)
    print('Saving to .npy files done.')

#Function to load the testing data
def load_test_data():
    """Load the cached test images and their ids from the current working directory.

    Returns:
        tuple: ``(imgs_test, imgs_id)``, read from ``imgs_test.npy`` and
        ``imgs_id_test.npy`` in that order.
    """
    images, ids = (np.load(file_name) for file_name in ('imgs_test.npy', 'imgs_id_test.npy'))
    return images, ids

if __name__ == '__main__':
    # Build and cache both datasets as .npy files, training set first.
    for build_dataset in (create_train_data, create_test_data):
        build_dataset()

In [None]:
from __future__ import print_function
#Import proper Libraries
import os
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from keras.models import Sequential,Model  
from keras.layers import Conv2D,MaxPooling2D,UpSampling2D,BatchNormalization,Reshape,Permute,Activation
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras import backend as K
from keras import losses,metrics
import tensorflow as tf

from data_preparation import load_train_data, load_test_data

K.set_image_data_format('channels_last')  # TF dimension ordering in this code
# Network input size; images and masks are resized to this before training/prediction
img_rows, img_cols = 256, 256
n_label = 2  # number of classes; not used by the active model code
smooth = 1.  # additive smoothing term in the Dice coefficient

#Function to calculate the Dice Coefficient
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient between ground truth and prediction.

    Both tensors are flattened, so one score is computed over the whole batch:
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    numerator = 2. * K.sum(truth * pred) + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator

#Function to calculate the Dice Coefficient Loss
def dice_coef_loss(y_true, y_pred):
    """Dice loss, ``1 - dice_coef``: minimising it maximises mask overlap.

    Returning the coefficient itself (as before) would make an optimiser
    drive the overlap *down*.
    """
    return 1. - dice_coef(y_true, y_pred)

#Function to calculate a thresholded mean IoU for binary masks
def binary_mean_iou(y_true, y_pred, threshold=0.5):
    """Mean IoU of background and foreground after thresholding both tensors.

    Used instead of ``tf.keras.metrics.MeanIoU(num_classes=2)``, which casts its
    inputs to integer class ids. That cast truncates sigmoid probabilities in
    [0, 1) to class 0, so the metric would never see a foreground prediction.
    The value is computed per batch; a class absent from both tensors scores 1.
    """
    truth = K.cast(K.greater(K.flatten(y_true), threshold), K.floatx())
    pred = K.cast(K.greater(K.flatten(y_pred), threshold), K.floatx())
    eps = K.epsilon()
    ious = []
    for t, p in ((truth, pred), (1. - truth, 1. - pred)):
        intersection = K.sum(t * p)
        union = K.sum(t) + K.sum(p) - intersection
        ious.append((intersection + eps) / (union + eps))
    return (ious[0] + ious[1]) / 2.


#Helper that appends a stack of 3x3 Conv2D + BatchNormalization layers
def _add_conv_bn(model, filters, count, **first_conv_kwargs):
    """Append ``count`` (3x3 'same' ReLU Conv2D -> BatchNormalization) pairs to ``model``.

    ``first_conv_kwargs`` go only to the first Conv2D; they are used for the
    network's ``input_shape``.
    """
    for i in range(count):
        extra = first_conv_kwargs if i == 0 else {}
        model.add(Conv2D(filters, (3, 3), strides=(1, 1), padding='same', activation='relu', **extra))
        model.add(BatchNormalization())


#Function to define the SegNet Model
def SegNet(pretrained_weights=None):
    """Build and compile a SegNet-style encoder/decoder for binary segmentation.

    Encoder: five stages of Conv+BN blocks, each followed by 2x2 max-pooling.
    Together they shrink the ``img_rows x img_cols x 1`` input by 32x.
    Decoder: five mirrored stages of 2x upsampling followed by Conv+BN blocks.
    A final 1x1 conv with a sigmoid yields a per-pixel probability.

    Args:
        pretrained_weights: optional path of a weights file to load after compiling.

    Returns:
        The compiled Keras ``Sequential`` model. It uses Adam and binary
        cross-entropy, with Dice, accuracy, binary mean-IoU and MAE metrics.
    """
    # (filters, number of Conv+BN blocks) per stage, outermost stage first.
    encoder_stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))
    decoder_stages = ((512, 3), (512, 3), (256, 3), (128, 2), (64, 2))

    model = Sequential()
    # encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8 (for the 256x256 input)
    for stage, (filters, count) in enumerate(encoder_stages):
        if stage == 0:
            _add_conv_bn(model, filters, count,
                         input_shape=(img_rows, img_cols, 1), data_format='channels_last')
        else:
            _add_conv_bn(model, filters, count)
        model.add(MaxPooling2D(pool_size=(2, 2)))
    # decoder: 8 -> 16 -> 32 -> 64 -> 128 -> 256
    for filters, count in decoder_stages:
        model.add(UpSampling2D(size=(2, 2)))
        _add_conv_bn(model, filters, count)
    model.add(Conv2D(1, (1, 1), strides=(1, 1), padding='same'))
    model.add(Activation('sigmoid'))

    model.compile(optimizer='Adam', loss='binary_crossentropy',
                  metrics=[dice_coef, 'accuracy', binary_mean_iou, metrics.mae])
    model.summary()

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

#Function to resize the images for training and predictions
def preprocess(imgs):
    """Resize a stack of 2-D uint8 images to the network input size and add a channel axis.

    Args:
        imgs: array indexed as ``imgs[i]`` -> one 2-D image, i.e. ``(N, H, W)``.

    Returns:
        uint8 array of shape ``(N, img_rows, img_cols, 1)``.
    """
    imgs_p = np.ndarray((imgs.shape[0], img_rows, img_cols), dtype=np.uint8)
    for i in range(imgs.shape[0]):
        # skimage's output_shape is (rows, cols).
        resized = resize(imgs[i], (img_rows, img_cols), preserve_range=True)
        # resize returns floats; round (not truncate) before storing as uint8.
        imgs_p[i] = np.clip(np.round(resized), 0, 255)

    return imgs_p[..., np.newaxis]

#Helper that writes predicted masks as PNG files
def _save_predicted_masks(imgs_mask_test, imgs_id_test, pred_dir='preds_SegNet'):
    """Save each predicted probability map (first channel) as ``<id>_pred.png`` in ``pred_dir``."""
    if not os.path.exists(pred_dir):
        os.mkdir(pred_dir)
    for image, image_id in zip(imgs_mask_test, imgs_id_test):
        image = (image[:, :, 0] * 255.).astype(np.uint8)
        imsave(os.path.join(pred_dir, str(image_id) + '_pred.png'), image)


#Helper that plots and pickles the training curves
def _plot_history(hist):
    """Plot train/validation curves for loss, Dice and accuracy and pickle the plotted lines."""
    import pickle
    import matplotlib.pyplot as plt

    curves = (
        ('loss', 'Loss Curve', 'Loss_Seg.fig.pickle', 'Val_Loss_Seg.fig.pickle'),
        ('dice_coef', 'Dice Coefficient Curve', 'Dice_Seg.fig.pickle', 'Val_Dice_Seg.fig.pickle'),
        ('accuracy', 'Accuracy Curve', 'Acc_Seg.fig.pickle', 'Val_Acc_Seg.fig.pickle'),
    )
    for key, title, train_pickle, val_pickle in curves:
        fig, ax = plt.subplots()
        train_lines = ax.plot(hist.history[key], color='b', label='train')
        val_lines = ax.plot(hist.history['val_' + key], color='r', label='validation')
        ax.set(title=title, xlabel='epoch', ylabel=key)
        ax.legend()
        # `with` guarantees the pickle files are flushed and closed.
        with open(train_pickle, 'wb') as fh:
            pickle.dump(train_lines, fh)
        with open(val_pickle, 'wb') as fh:
            pickle.dump(val_lines, fh)
        plt.show()


#Main function to train the model and predict the masks.
def train_and_predict():
    """Train SegNet, predict masks for the test set and plot the training curves.

    1. Load the cached training arrays and resize them. Standardise the images
       with the training mean/std and scale the masks to [0, 1].
    2. Fit SegNet for 1000 epochs with a 20% validation split. The weights with
       the lowest ``val_loss`` are kept in ``unet_weights_150_Img_SegNet.h5``.
    3. Reload those weights and standardise the test images with the *training*
       statistics. Predict, then save ``imgs_mask_test.npy`` plus one PNG per
       test image into ``preds_SegNet/``.
    4. Plot and pickle the loss / Dice / accuracy curves.
    """
    weights_file = 'unet_weights_150_Img_SegNet.h5'

    print('Loading and preprocessing train data...')
    imgs_train, imgs_mask_train = load_train_data()

    imgs_train = preprocess(imgs_train).astype('float32')
    mean = np.mean(imgs_train)  # mean for data centering
    std = np.std(imgs_train)  # std for data normalization
    imgs_train -= mean
    imgs_train /= std

    imgs_mask_train = preprocess(imgs_mask_train).astype('float32')
    imgs_mask_train /= 255.  # scale masks to [0, 1]

    #model = SegNet('unet_weights_150_Img_SegNet.h5') #To continue training, uncomment to load the previously saved weights
    model = SegNet()
    # Keep only the checkpoint with the lowest validation loss.
    model_checkpoint = ModelCheckpoint(weights_file, monitor='val_loss', save_best_only=True)
    print('Fitting model...')
    # `epochs` replaces the Keras 1 `nb_epoch` keyword, which tf.keras' fit() rejects.
    hist = model.fit(imgs_train, imgs_mask_train, batch_size=16, epochs=1000, verbose=1, shuffle=True,
                     validation_split=0.2,
                     callbacks=[model_checkpoint])
    model.summary()  # summary() prints itself and returns None

    imgs_test, imgs_id_test = load_test_data()
    imgs_test = preprocess(imgs_test).astype('float32')
    # Standardise with the *training* mean/std: the network was fitted on that
    # scale, and per-set statistics would shift the test inputs.
    imgs_test -= mean
    imgs_test /= std

    model.load_weights(weights_file)

    print('Predicting masks on test data...')
    imgs_mask_test = model.predict(imgs_test, verbose=1)
    np.save('imgs_mask_test.npy', imgs_mask_test)

    #After predictions, the mask images are saved in the preds_SegNet folder
    _save_predicted_masks(imgs_mask_test, imgs_id_test, 'preds_SegNet')

    #History of the model is then plotted to show the various curves
    _plot_history(hist)

if '__main__' == __name__:
    # Running this cell/script trains the network, predicts the test masks and plots the curves.
    train_and_predict()