**Portrait Segmentation Using Prisma-Unet**

Set up the GPU runtime

In [0]:
 # Check GPU: run nvidia-smi as a shell command to show which GPU (if any) is attached to this runtime
!nvidia-smi

In [0]:
# Mount Google Drive at /content/drive.
# Later cells read pretrained models from, and write checkpoints to,
# "/content/drive/My Drive/portrait256/".
from google.colab import drive
drive.mount('/content/drive')

**Imports**

In [0]:
# Import libraries
import os
import tensorflow as tf
import keras
from keras.models import Model
from keras.layers import Dense, Input,Flatten, concatenate,Reshape, Conv2D, MaxPooling2D, Lambda,Activation,Conv2DTranspose, SeparableConv2D
from keras.layers import UpSampling2D, Conv2DTranspose, BatchNormalization, Dropout, DepthwiseConv2D, Add
from keras.callbacks import TensorBoard, ModelCheckpoint, Callback, ReduceLROnPlateau
from keras.regularizers import l1
from keras.optimizers import SGD, Adam
import keras.backend as K
from keras.utils import plot_model
from keras.callbacks import TensorBoard, ModelCheckpoint, Callback
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage.filters import gaussian_filter
from random import randint
from keras.models import load_model
from keras.preprocessing.image import ImageDataGenerator
from PIL import Image
import matplotlib.pyplot as plt
from random import randint
%matplotlib inline

**Load dataset**

Load the dataset for training the model from a directory.

Ensure the images are in **RGB** format and masks (**ALPHA**) have pixel values **0 or 255**.

In [0]:
# Root folders of the dataset: one for the images, one for the masks
IMGDS = "/content/portrait256/images"
MSKDS = "/content/portrait256/masks"

# Total number of images = number of entries in the 'img' sub-directory
num_images = len(os.listdir(IMGDS + "/img"))

Copy pretrained model to local runtime disk. Save the checkpoints to your google drive (safe).

In [0]:
# Training configuration: batch size, TensorBoard log folder and checkpoint path.
# The checkpoint name is a template filled in with the epoch number and the
# validation loss of the epoch being saved.
BATCH_SIZE = 64
LOGS = "./logs"
CHECKPOINT = "/content/drive/My Drive/portrait256/prisma-net-{epoch:02d}-{val_loss:.2f}.hdf5"

**Data Generator**

Create a data generator to load images and masks together at runtime. 
Use the same seed for the image and mask generators so that run-time augmentation is applied identically to each image and its mask. Here we use an 80/20 train-val split.

**Note:** The keras 'flow_from_directory' expects a specific directory structure for loading datasets. Your parent dataset directory should contain two sub-directories, 'images' and 'masks'. Each of these directories should in turn have a sub-directory (say 'img' and 'msk') for storing the images or masks.

In [0]:
# Data generators for training and validation

# Random augmentation used for TRAINING only. Images and masks get their own
# ImageDataGenerator built from the same arguments so that, combined with the
# shared seed below, both receive the same random transform.
data_gen_args = dict(rescale=1./255,
                     width_shift_range=0.2,
                     height_shift_range=0.2,
                     zoom_range=0.2,
                     horizontal_flip=True,
                     validation_split=0.2
                    )

# Validation data is only rescaled: random shifts/zooms/flips would make the
# monitored val_loss (used for checkpointing and LR decay) noisy.
# validation_split must match the training generators so both sides slice the
# same file list into the same 80/20 partition.
val_gen_args = dict(rescale=1./255,
                    validation_split=0.2)

image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)
val_image_datagen = ImageDataGenerator(**val_gen_args)
val_mask_datagen = ImageDataGenerator(**val_gen_args)

# Provide the same seed and keyword arguments to the flow methods
seed = 1
batch_sz = BATCH_SIZE

# Train-val split (80-20)
num_train = int(num_images*0.8)
num_val = int(num_images*0.2)


def _flow_subset(datagen, directory, subset, color_mode):
    """Create a directory iterator for one subset of the images or masks.

    Every iterator shares the batch size, shuffling and seed so that the
    image iterator and the mask iterator of a subset yield matching batches.

    Args:
        datagen: the ImageDataGenerator to read through.
        directory: dataset root (IMGDS or MSKDS).
        subset: 'training' or 'validation'.
        color_mode: 'rgb' for images, 'grayscale' for masks.

    Returns:
        The iterator returned by `flow_from_directory` (class_mode=None, so it
        yields only image batches, no labels).
    """
    return datagen.flow_from_directory(
        directory,
        target_size=(256, 256),  # made explicit: the network input is 256x256
        batch_size=batch_sz,
        shuffle=True,
        subset=subset,
        color_mode=color_mode,
        class_mode=None,
        seed=seed)


train_image_generator = _flow_subset(image_datagen, IMGDS, 'training', "rgb")
train_mask_generator = _flow_subset(mask_datagen, MSKDS, 'training', "grayscale")

val_image_generator = _flow_subset(val_image_datagen, IMGDS, 'validation', "rgb")
val_mask_generator = _flow_subset(val_mask_datagen, MSKDS, 'validation', "grayscale")

# Combine generators into one which yields (image batch, mask batch) pairs
train_generator = zip(train_image_generator, train_mask_generator)
val_generator = zip(val_image_generator, val_mask_generator)

**Model Architecture**

The prisma-net basically uses a **U-Net** encoder-decoder structure. However, the architecture incorporates a few significant **changes**. Firstly, we replace the concatenation of features after upsampling with **element-wise addition**. Further, instead of normal Conv+ReLu block, we use a **residual block with depth-wise separable convolutions**. Finally, to improve the accuracy the **decoder** part contains **more blocks** than encoder.

In [0]:
def residual_block(x, nfilters):
  """Residual block built from two ReLU + depth-wise separable 3x3 convolutions.

  The input tensor is passed twice through (ReLU -> SeparableConv2D) and the
  result is added element-wise to the unmodified input.

  Args:
    x: input tensor; it is added to the branch output, so it presumably must
      already have `nfilters` channels (verify against the Keras `Add` layer).
    nfilters: number of filters of both separable convolutions.

  Returns:
    The tensor `x + branch(x)`.
  """
  branch = x
  for _ in range(2):
    branch = Activation("relu")(branch)
    branch = SeparableConv2D(filters=nfilters, kernel_size=3, padding="same")(branch)

  return Add()([x, branch])

In [0]:
def prisma_unet(finetuene=False, alpha=1):
    """Build and compile the prisma-net U-Net style segmentation model.

    The encoder halves the resolution four times with strided convolutions;
    the decoder restores it with four transposed convolutions, each followed
    by an element-wise addition with the matching encoder feature map and
    three residual blocks. The output is a single-channel sigmoid map.

    Args:
        finetuene: accepted for interface compatibility; not used by this
            function.
        alpha: accepted for interface compatibility; not used by this
            function.

    Returns:
        A Keras `Model` taking a (256, 256, 3) input, compiled with binary
        cross-entropy loss, the Adam optimizer (lr=1e-3) and accuracy metric.
    """

    def _stack(tensor, nfilters, count):
        # Apply `count` residual blocks of the same width one after another.
        for _ in range(count):
            tensor = residual_block(tensor, nfilters=nfilters)
        return tensor

    inputs = Input(shape=(256, 256, 3))

    # ---- Encoder ----
    x = Conv2D(filters=8, kernel_size=3, padding='same')(inputs)
    res1 = residual_block(x, nfilters=8)

    x = Conv2D(filters=32, kernel_size=3, strides=2, padding='same')(res1)
    res2 = residual_block(x, nfilters=32)

    x = Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(res2)
    res3 = residual_block(x, nfilters=64)

    x = Conv2D(filters=128, kernel_size=3, strides=2, padding='same')(res3)
    res4 = _stack(x, 128, 2)

    # Bottleneck: one more down-sampling followed by six residual blocks
    x = Conv2D(filters=128, kernel_size=3, strides=2, padding='same')(res4)
    x = _stack(x, 128, 6)

    # ---- Decoder ----
    # Each stage: up-sample, add the encoder skip tensor, three residual blocks.
    decoder_stages = ((128, res4), (64, res3), (32, res2), (8, res1))
    for width, skip in decoder_stages:
        x = Conv2DTranspose(filters=width, kernel_size=3, strides=2, padding='same')(x)
        x = Add()([x, skip])
        x = _stack(x, width, 3)

    # 1x1 projection to a single channel, squashed to [0, 1]
    x = Conv2DTranspose(1, (1, 1), padding='same')(x)
    x = Activation('sigmoid', name="op")(x)

    model = Model(inputs=inputs, outputs=x)
    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(lr=1e-3),
                  metrics=['accuracy'])

    return model

In [0]:
# Build the prisma network
model = prisma_unet()

# Print the model summary
model.summary()

# List every layer: index, output tensor name and output shape
for idx, lyr in enumerate(model.layers):
    print(idx, lyr.output.name, lyr.output.shape)

# Write a diagram of the architecture to an image file
plot_model(model, to_file='prisma-net.png')

# Checkpoint callback: writes the full model (not only the weights) whenever
# the validation loss improves
checkpoint = ModelCheckpoint(
    CHECKPOINT,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
    verbose=1)

# Learning-rate decay: halve the rate after 3 epochs without improvement,
# never going below 1e-6
reduce_lr = ReduceLROnPlateau(factor=0.5, patience=3, min_lr=1e-6, verbose=1)

# TensorBoard logging
tensorboard = TensorBoard(
    log_dir=LOGS,
    histogram_freq=0,
    write_graph=True,
    write_images=True)

callbacks_list = [checkpoint, tensorboard, reduce_lr]

**Train**

Initially train the model for **300 epochs** using **supervisely person** dataset. Finally, train the model on **portrait datasets**, using the result of the previous step as initial values for **weights**.

Use keras callbacks for **tensorboard** visualization and **learning rate decay** as shown below. You can resume your training from a previous session by loading the entire **pretrained model** (weights & optimizer state) as a hdf5 file.

In [0]:
# Load pretrained model (if any).
# load_model replaces the freshly built `model` with the saved one. When the
# checkpoint file is missing, keep the freshly built model instead of crashing,
# so the notebook still runs top to bottom on a first training session.
pretrained_path = '/content/drive/My Drive/portrait256/prisma-net-07-0.09.hdf5'
if os.path.isfile(pretrained_path):
    model = load_model(pretrained_path)
else:
    print("No pretrained model found at %r; training from scratch." % pretrained_path)

In [0]:
# Train the model

# Number of batches per epoch, rounded up to whole steps (the step arguments
# are batch counts and should be integers), with at least one step each.
steps_per_epoch = max(1, (num_train + batch_sz - 1) // batch_sz)
validation_steps = max(1, (num_val + batch_sz - 1) // batch_sz)

# train_generator / val_generator are plain zipped Python generators, not
# keras.utils.Sequence objects. Keras warns that such generators may yield
# duplicated data with use_multiprocessing=True and several workers, and
# concurrent consumers could pull an image batch and a mask batch that no
# longer belong together. A single in-process worker keeps the pairs aligned.
model.fit_generator(
    train_generator,
    epochs=300,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_generator,
    validation_steps=validation_steps,
    use_multiprocessing=False,
    workers=1,
    callbacks=callbacks_list)

**Test**

Test the model on a new portrait image and plot the results.

In [0]:
# Open the portrait used for the single-image test
test_image_path = '/content/baby.jpg'
im = Image.open(test_image_path)

# Restore a trained checkpoint from Google Drive
trained_model_path = '/content/drive/My Drive/portrait256/prisma-net-15-0.08.hdf5'
model = load_model(trained_model_path)

In [0]:
# Inference
im=im.resize((256,256),Image.ANTIALIAS)
img=np.float32(np.array(im)/255.0)
plt.imshow(img[:,:,0:3])
img=img[:,:,0:3]

# Reshape input and threshold output
out=model.predict(img.reshape(1,256,256,3))
out=np.float32((out>0.5))

In [0]:
# Output mask
# Display the thresholded prediction as a 256x256 image
mask_2d = np.squeeze(out.reshape((256, 256)))
plt.imshow(mask_2d)

**Export Model**

Export the model to **tflite** format for **real-time** inference on a **smart-phone**.

In [0]:
# Flatten output and save model
output = model.output
newout=Reshape((65536,))(output)
new_model=Model(model.input,newout)

new_model.save('prisma-net.h5')

# For Float32 Model
converter = tf.lite.TFLiteConverter.from_keras_model_file('/content/prisma-net.h5')
tflite_model = converter.convert()
open("prisma-net.tflite", "wb").write(tflite_model)

**Post-training Quantization**

We can **reduce the model size and latency** by performing post-training quantization. Fixed precision conversion (**UINT8**) allows us to reduce the model size significantly by quantizing the model weights. We can run this model on the mobile **CPU**. The **FP16** (experimental) conversion allows us to reduce the model size by half and the corresponding model can be run directly on mobile **GPU**.

In [0]:
# For UINT8 Quantization

# Convert the saved Keras model with the size optimization enabled
converter = tf.lite.TFLiteConverter.from_keras_model_file('/content/prisma-net.h5')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_model = converter.convert()
# Use a context manager so the file is flushed and closed deterministically
with open("prisma-net_uint8.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)


In [0]:
# For Float16 Quantization (Experimental)

import tensorflow as tf

# Convert the saved Keras model with default optimizations, restricting the
# supported types to float16
converter = tf.lite.TFLiteConverter.from_keras_model_file('/content/prisma-net.h5')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]
tflite_model = converter.convert()
# Use a context manager so the file is flushed and closed deterministically
with open("prisma-net_fp16.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)

**Plot sample output**

Load the test data as a batch using a numpy array. 

Crop the image using the output mask and plot the result.

In [0]:
# Load test images and model
model=load_model('/content/prisma-net.h5',compile=False)
test_imgs=np.load('/content/kids.npy')
test_imgs= np.float32(np.array(test_imgs)/255.0)

In [0]:
# Perform batch prediction
out=model.predict(test_imgs)
out=np.float32((out>0.5))
out=out.reshape((4,256,256,1))

In [0]:
# Plot the output using matplotlib
fig=plt.figure(figsize=(16, 16))
columns = 4
rows = 2

for i in range(1, columns+1):
    img = test_imgs[i-1].squeeze()
    fig.add_subplot(rows, columns, i)
    plt.imshow(img)
plt.show()

fig=plt.figure(figsize=(16, 16))
columns = 4
rows = 2

for i in range(1, columns+1):
    img = out[i-1].squeeze()/255.0
    fig.add_subplot(rows, columns, 4+i)
    plt.imshow(out[i-1]*test_imgs[i-1])

plt.show()