Test for U-Net

In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import random
import numpy as np
from skimage.io import imread, imshow
from skimage.transform import resize

U-Net architecture with an ImageNet-pretrained ResNet50 encoder

In [3]:
def conv_block(inputs, num_filters):
    """Apply two stacked (3x3 Conv2D 'same' -> BatchNorm -> ReLU) stages.

    Args:
        inputs: Tensor fed to the first convolution.
        num_filters: Output channel count used by both convolutions.

    Returns:
        The tensor produced by the second ReLU activation.
    """
    out = inputs
    # Identical stage applied twice; layers are created in the same order
    # (conv, bn, relu, conv, bn, relu) as a hand-unrolled version would.
    for _ in range(2):
        out = Conv2D(num_filters, 3, padding='same')(out)
        out = BatchNormalization()(out)
        out = Activation('relu')(out)
    return out


In [4]:
def decoder_block(inputs, skip_features, num_filters):
    """One U-Net decoder step: upsample 2x, merge the skip tensor, refine.

    Args:
        inputs: Tensor coming from the previous (coarser) decoder/bridge level.
        skip_features: Encoder tensor to concatenate after upsampling; its
            spatial size must equal twice that of `inputs`.
        num_filters: Channel count for the transposed conv and the conv block.

    Returns:
        Output tensor of `conv_block` applied to the concatenation.
    """
    # 2x2 transposed convolution with stride 2 doubles height and width.
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(inputs)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

In [5]:
def build_resnet50_unet(input_shape, verbose=False):
    """Build a U-Net whose encoder is a frozen, ImageNet-pretrained ResNet50.

    Skip connections come from the raw input and three intermediate ResNet50
    stages, and ``conv4_block4_out`` is the bridge. Each of the four decoder
    blocks upsamples by 2x, so the output has the input's spatial size.
    This assumes height and width are divisible by 16 so that skip and
    upsampled shapes line up.

    Args:
        input_shape: (height, width, channels) of the input images.
        verbose: If True, print the shapes of the skip, bridge and decoder
            tensors (handy for checking that the concatenations line up).

    Returns:
        An uncompiled ``tf.keras.Model`` that maps images to a single-channel
        sigmoid output (per-pixel probability).

    NOTE(review): ImageNet ResNet50 weights expect inputs passed through
    ``tf.keras.applications.resnet50.preprocess_input``. This model feeds the
    raw input straight into the encoder. Confirm whether the caller should
    preprocess.
    """
    # Named explicitly so it can be looked up below as the first skip tensor.
    inputs = Input(input_shape, name='input_1')

    # Pretrained encoder built directly on `inputs`.
    resnet50 = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', input_tensor=inputs)
    # Freeze the encoder. In TF2 Keras this also runs its BatchNorm layers in
    # inference mode, so the pretrained statistics are kept. Unfreezing the
    # non-BatchNorm layers later is an option for fine-tuning.
    resnet50.trainable = False

    # Encoder skip connections (spatial size relative to input height H).
    s1 = resnet50.get_layer('input_1').output           # H
    s2 = resnet50.get_layer('conv1_relu').output        # H/2
    s3 = resnet50.get_layer('conv2_block3_out').output  # H/4
    s4 = resnet50.get_layer('conv3_block4_out').output  # H/8

    # Bridge between the encoder and the decoder.
    b1 = resnet50.get_layer('conv4_block4_out').output  # H/16

    # Decoder: each block upsamples 2x and concatenates the matching skip.
    d1 = decoder_block(b1, s4, 512)  # H/8
    d2 = decoder_block(d1, s3, 256)  # H/4
    d3 = decoder_block(d2, s2, 128)  # H/2
    d4 = decoder_block(d3, s1, 64)   # H, 64 channels

    # 1x1 conv + sigmoid -> one probability channel per pixel.
    outputs = Conv2D(1, 1, padding='same', activation='sigmoid')(d4)

    model = tf.keras.Model(inputs=[inputs], outputs=[outputs])

    if verbose:
        print(s1.shape, s2.shape, s3.shape, s4.shape, b1.shape)
        print(d1.shape, d2.shape, d3.shape, d4.shape)

    return model

Preparing the training and test data

In [6]:
# Seed every RNG this notebook draws from. np.random.seed must be *called*;
# assigning to it would replace NumPy's seeding function with an int.
seed = 42
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

IMG_WIDTH = 256
IMG_HEIGHT = 256
IMG_CHANNELS = 3

TRAIN_PATH = 'data-science-bowl-2018/stage1_train/'
TEST_PATH = 'data-science-bowl-2018/stage1_test/'

# One sub-directory per sample id. os.walk's order depends on the filesystem,
# so sort to keep the sample order (and the later 90/10 split) reproducible.
train_ids = sorted(next(os.walk(TRAIN_PATH))[1])
test_ids = sorted(next(os.walk(TEST_PATH))[1])

X_train = np.zeros((len(train_ids), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
Y_train = np.zeros((len(train_ids), IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.bool_)

print('Resizing training images and masks')
for n, id_ in tqdm(enumerate(train_ids), total=len(train_ids)):
    path = os.path.join(TRAIN_PATH, id_)
    # Keep only the first IMG_CHANNELS channels (e.g. drops an alpha channel if present).
    img = imread(os.path.join(path, 'images', id_ + '.png'))[:, :, :IMG_CHANNELS]
    # resize returns floats in the original range; storing into uint8 truncates.
    X_train[n] = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)

    # Merge every file in masks/ with a pixel-wise max (presumably one file per nucleus).
    mask = np.zeros((IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.bool_)
    mask_dir = os.path.join(path, 'masks')
    for mask_file in next(os.walk(mask_dir))[2]:
        mask_ = imread(os.path.join(mask_dir, mask_file))
        mask_ = np.expand_dims(
            resize(mask_, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True),
            axis=-1,
        )
        mask = np.maximum(mask, mask_)
    # NOTE(review): resize interpolates, so `mask` is float here. Storing it into
    # the bool Y_train marks every non-zero pixel True, which may slightly grow
    # object borders. Consider an explicit threshold if that matters.
    Y_train[n] = mask

# Test images. The original sizes are kept, presumably so that predictions
# can be resized back later.
X_test = np.zeros((len(test_ids), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
sizes_test = []
print('Resizing test images')
for n, id_ in tqdm(enumerate(test_ids), total=len(test_ids)):
    path = os.path.join(TEST_PATH, id_)
    img = imread(os.path.join(path, 'images', id_ + '.png'))[:, :, :IMG_CHANNELS]
    sizes_test.append([img.shape[0], img.shape[1]])
    X_test[n] = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)

print('Done!')

# Sanity-check one random training image and its mask. randrange excludes
# the upper bound (randint includes it and could index past the end).
image_x = random.randrange(len(train_ids))
imshow(X_train[image_x])
plt.show()
imshow(np.squeeze(Y_train[image_x]))
plt.show()

Resizing training images and masks


 15%|█▍        | 100/670 [00:18<01:44,  5.43it/s]


KeyboardInterrupt: 

Building, compiling and (optionally) training the model

In [8]:
# Build and compile the model. A notebook kernel always runs with
# __name__ == '__main__', so no main-guard is needed here.
input_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)  # must match the prepared data
model = build_resnet50_unet(input_shape)
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[
        'accuracy',
        # The model emits sigmoid probabilities. BinaryIoU thresholds them at
        # 0.5, whereas plain IoU expects integer class ids and would truncate
        # probabilities to 0. Class 1 = mask foreground (True in Y_train).
        tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=0.5),
    ],
)
model.summary()

# Training is off by default. Enable it once X_train / Y_train are fully loaded.
# validation_split holds out the last 10% of samples (the same slice used for
# "validation" predictions below).
RUN_TRAINING = False
if RUN_TRAINING:
    results = model.fit(X_train, Y_train, validation_split=0.1, batch_size=16, epochs=25)


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 262, 262, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 128, 128, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                            

In [None]:
# Predictions are only meaningful once `model` has been trained.
# Split X_train the same way Keras' validation_split=0.1 does: the first 90%
# of samples form the training part and the last 10% the validation part.
TRAIN_FRACTION = 0.9
PRED_THRESHOLD = 0.5  # sigmoid probability above which a pixel counts as foreground

split_idx = int(X_train.shape[0] * TRAIN_FRACTION)
X_train_part = X_train[:split_idx]
X_val_part = X_train[split_idx:]

preds_train = model.predict(X_train_part, verbose=1)
preds_val = model.predict(X_val_part, verbose=1)
preds_test = model.predict(X_test, verbose=1)

preds_train_t = (preds_train > PRED_THRESHOLD).astype(np.uint8)
preds_val_t = (preds_val > PRED_THRESHOLD).astype(np.uint8)
preds_test_t = (preds_test > PRED_THRESHOLD).astype(np.uint8)

# Sanity check on a random validation sample: overlay the thresholded
# prediction on the input image. randrange excludes the upper bound.
ix = random.randrange(len(preds_val_t))
imshow(X_val_part[ix])
imshow(np.squeeze(preds_val_t[ix]), alpha=0.1)
plt.show()