**SlimNet: Real-time Portrait Segmentation on High Resolution Images**

Slim-net is a lightweight CNN for **real-time portrait segmentation** of high-resolution images on mobile devices. We achieved **99% training accuracy** on the AISegment portrait dataset, and the deployed model (**1.5MB**) runs at **20 fps** on a mid-range Android smartphone. Because the network works on high-resolution input images, the segmentation masks preserve **fine details** and **avoid jagged edges** at inference time. The architecture is heavily inspired by the MediaPipe **hair-segmentation** model for Android, and the TFLite model runs on any **Android** device without additional APIs.

**Environment and Dataset**

Choose a **GPU runtime** in Colab for training the network.

In [None]:
!nvidia-smi

Install **TensorFlow 1.15** using pip. The notebook relies on TF 1.x APIs such as `tf.image.resize_bilinear` and `TFLiteConverter.from_keras_model_file`.

In [None]:
# Install TensorFlow 1.15 (GPU)
!pip install tensorflow-gpu==1.15

Extract the training **dataset** from Google Drive (this assumes the drive is already mounted at `/content/drive`).

In [None]:
!unzip /content/drive/My\ Drive/slim512/aiseg_small.zip

**Packages and Libraries**

In [1]:
# Import the packages
import os

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from IPython.display import clear_output
from PIL import Image, ImageEnhance
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, Callback, ReduceLROnPlateau
from tensorflow.keras.layers import (Input, Conv2D, Add, Multiply, Reshape, MaxPool2D, Conv2DTranspose,
                                     PReLU, concatenate, Lambda, Dropout, BatchNormalization, Activation,
                                     AveragePooling2D, UpSampling2D, Flatten, Dense, GlobalAveragePooling2D,
                                     DepthwiseConv2D, SeparableConv2D)
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import plot_model

**Data-loader and Augmentations**

Use the **custom data loader** to perform **augmentation** on the fly and feed the network with **batches of images**. We train the model with augmentations such as **brightness, saturation, contrast, cropping and flipping**. To train with the **SparseCategoricalCrossentropy** loss, the masks must be sparse (class-indexed) labels, i.e. 0 for background and 1 for foreground; if your masks are stored in another form (for example as 0/255 images), convert them first. If you prefer **one-hot** labels, use the **CategoricalCrossentropy** loss instead.
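As a rough illustration of the two label formats, the NumPy sketch below converts a grayscale mask into sparse labels (0 = background, 1 = foreground) and then into one-hot labels. It assumes masks stored as 0/255 images and uses a threshold of 127; both the threshold and the helper names `to_sparse_labels` and `to_one_hot` are hypothetical and not part of this notebook.

```python
import numpy as np

def to_sparse_labels(mask, threshold=127):
    """Map a grayscale mask (assumed 0 = background, 255 = person)
    to sparse class indices: 0 for background, 1 for foreground."""
    return (np.asarray(mask) > threshold).astype(np.uint8)

def to_one_hot(sparse, num_classes=2):
    """Expand sparse labels of shape (H, W) to one-hot labels of shape (H, W, C)."""
    return np.eye(num_classes, dtype=np.float32)[sparse]

mask = np.array([[0, 255],
                 [255, 0]], dtype=np.uint8)
sparse = to_sparse_labels(mask)
one_hot = to_one_hot(sparse)
print(sparse.tolist())         # [[0, 1], [1, 0]]
print(one_hot[0, 1].tolist())  # [0.0, 1.0]
```

Sparse labels like `sparse` go with `SparseCategoricalCrossentropy`; the `(H, W, 2)` array `one_hot` is what `CategoricalCrossentropy` would expect.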

In [None]:
import tensorflow as tf
import random

class DataLoader(object):
    """A TensorFlow Dataset API based loader for semantic segmentation problems."""

    def __init__(self, image_paths, mask_paths, image_size, channels=[3, 3], crop_percent=None, palette=None, seed=None):
        """
        Initializes the data loader object
        Args:
            image_paths: List of paths of train images.
            mask_paths: List of paths of train masks (segmentation masks)
            image_size: Int, the side length (in pixels) of the square images
                        and masks after resizing.
            channels: List of ints, first element is number of channels in images,
                      second is the number of channels in the mask image (needed to
                      correctly read the images into tensorflow.)
            crop_percent: Float in the range 0-1, defining percentage of image 
                          to randomly crop.
            palette: A list of RGB pixel values in the mask. If specified, the mask
                     will be one hot encoded along the channel dimension.
            seed: An int, if not specified, chosen randomly. Used as the seed for 
                  the RNG in the data pipeline.
        """
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.palette = palette
        self.image_size = image_size
        if crop_percent is not None:
            if 0.0 < crop_percent <= 1.0:
                self.crop_percent = tf.constant(crop_percent, tf.float32)
            elif 0 < crop_percent <= 100:
                self.crop_percent = tf.constant(crop_percent / 100., tf.float32)
            else:
                raise ValueError("Invalid value entered for crop size. Please use an \
                                  integer between 0 and 100, or a float between 0 and 1.0")
        else:
            self.crop_percent = None
        self.channels = channels
        if seed is None:
            self.seed = random.randint(0, 1000)
        else:
            self.seed = seed

    def _corrupt_brightness(self, image, mask):
        """
        Randomly applies a random brightness change.
        """
        cond_brightness = tf.cast(tf.random.uniform(
            [], maxval=2, dtype=tf.int32), tf.bool)
        image = tf.cond(cond_brightness, lambda: tf.image.random_brightness(
            image, 0.1), lambda: tf.identity(image))
        return image, mask


    def _corrupt_contrast(self, image, mask):
        """
        Randomly applies a random contrast change.
        """
        cond_contrast = tf.cast(tf.random.uniform(
            [], maxval=2, dtype=tf.int32), tf.bool)
        image = tf.cond(cond_contrast, lambda: tf.image.random_contrast(
            image, 0.1, 0.8), lambda: tf.identity(image))
        return image, mask


    def _corrupt_saturation(self, image, mask):
        """
        Randomly applies a random saturation change.
        """
        cond_saturation = tf.cast(tf.random.uniform(
            [], maxval=2, dtype=tf.int32), tf.bool)
        image = tf.cond(cond_saturation, lambda: tf.image.random_saturation(
            image, 0.1, 0.8), lambda: tf.identity(image))
        return image, mask


    def _crop_random(self, image, mask):
        """
        Randomly crops image and mask in accord.
        Image and mask are concatenated along the channel axis and cropped by a
        single op, so both always get the same crop window (separately seeded ops
        are not guaranteed to draw the same values under parallel map calls).
        """
        cond_crop = tf.cast(tf.random.uniform(
            [], maxval=2, dtype=tf.int32, seed=self.seed), tf.bool)

        shape = tf.cast(tf.shape(image), tf.float32)
        h = tf.cast(shape[0] * self.crop_percent, tf.int32)
        w = tf.cast(shape[1] * self.crop_percent, tf.int32)

        combined = tf.concat([image, mask], axis=-1)
        combined = tf.cond(cond_crop, lambda: tf.image.random_crop(
            combined, [h, w, self.channels[0] + self.channels[1]], seed=self.seed),
            lambda: tf.identity(combined))
        image = combined[..., :self.channels[0]]
        mask = combined[..., self.channels[0]:]

        return image, mask


    def _flip_left_right(self, image, mask):
        """
        Randomly flips image and mask left or right in accord,
        using a single random draw for both.
        """
        cond_flip = tf.cast(tf.random.uniform(
            [], maxval=2, dtype=tf.int32, seed=self.seed), tf.bool)
        image = tf.cond(cond_flip, lambda: tf.image.flip_left_right(image), lambda: tf.identity(image))
        mask = tf.cond(cond_flip, lambda: tf.image.flip_left_right(mask), lambda: tf.identity(mask))

        return image, mask


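    def _resize_data(self, image, mask):
        """
        Resizes images and masks to the specified size and normalizes the image to [0, 1].
        """
        image = tf.image.resize(image, [self.image_size, self.image_size]) / 255.0
        # Nearest-neighbour keeps mask values discrete; ResizeMethod works on TF 1.x and 2.x
        mask = tf.image.resize(mask, [self.image_size, self.image_size],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        # Masks should be binary with 0 as background; map any non-zero value to 1
        # (unlike dividing by the mask maximum, this avoids integer division by zero
        # on all-background masks)
        mask = tf.cast(mask > 0, tf.uint8)

        return image, mask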


    def _parse_data(self, image_paths, mask_paths):
        """
        Reads a JPEG image and its PNG mask from disk.
        """
        image_content = tf.io.read_file(image_paths)
        mask_content = tf.io.read_file(mask_paths)

        images = tf.image.decode_jpeg(image_content, channels=self.channels[0])
        masks = tf.image.decode_png(mask_content, channels=self.channels[1])

        return images, masks


    def _one_hot_encode(self, image, mask):
        """
        Converts mask to a one-hot encoding specified by the semantic map.
        """
        one_hot_map = []
        for colour in self.palette:
            class_map = tf.reduce_all(tf.equal(mask, colour), axis=-1)
            one_hot_map.append(class_map)
        one_hot_map = tf.stack(one_hot_map, axis=-1)
        one_hot_map = tf.cast(one_hot_map, tf.float32)
        
        return image, one_hot_map

    def data_batch(self, batch_size, augment, shuffle=False, one_hot_encode=False):
        """
        Reads the data, optionally augments it, resizes and normalizes it, then
        optionally shuffles it, repeats it and batches it.
        Inputs:
            batch_size: Number of images/masks in each batch returned.
            augment: Boolean, whether to augment data or not.
            shuffle: Boolean, whether to shuffle data in buffer or not.
            one_hot_encode: Boolean, whether to one hot encode the mask image or not.
                            Encoding will be done according to the palette specified when
                            initializing the object.
        Returns:
            data: A tf dataset object.
        """

        # Create dataset out of the 2 files:
        data = tf.data.Dataset.from_tensor_slices((self.image_paths, self.mask_paths))

        # Parse images and labels
        data = data.map(self._parse_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)

        # If augmentation is to be applied
        if augment:
            data = data.map(self._corrupt_brightness,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)

            data = data.map(self._corrupt_contrast,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)

            data = data.map(self._corrupt_saturation,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)

            if self.crop_percent is not None:
                data = data.map(self._crop_random, 
                                num_parallel_calls=tf.data.experimental.AUTOTUNE)

            data = data.map(self._flip_left_right,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)

        # Resize to smaller dims for speed
        data = data.map(self._resize_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)

        # One hot encode the mask
        if one_hot_encode:
            if self.palette is None:
                raise ValueError('No Palette for one-hot encoding specified in the data loader! \
                                  please specify one when initializing the loader.')
            data = data.map(self._one_hot_encode, num_parallel_calls=tf.data.experimental.AUTOTUNE)

        if shuffle:
            # Shuffle, repeat, batch and prefetch
            data = data.shuffle(1000).repeat().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
        else:
            # Batch and prefetch
            data = data.repeat().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

        return data
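Image and mask must receive exactly the same geometric augmentation, or the labels no longer line up with the pixels. A simple way to guarantee this is to draw the crop window and the flip decision once and apply them to the image and mask stacked along the channel axis. The NumPy sketch below illustrates the idea outside TensorFlow; `joint_random_crop_flip` is a hypothetical helper, and the 0.8 crop fraction mirrors the `crop_percent` used in this notebook.

```python
import numpy as np

def joint_random_crop_flip(image, mask, crop_percent=0.8, rng=None):
    """Crop and flip an image and its mask with one shared random draw,
    so the two arrays stay pixel-aligned."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    ch, cw = int(h * crop_percent), int(w * crop_percent)
    # Stack along the channel axis so a single crop window applies to both
    combined = np.concatenate([image, mask], axis=-1)
    top = rng.integers(0, h - ch + 1)
    left = rng.integers(0, w - cw + 1)
    combined = combined[top:top + ch, left:left + cw]
    if rng.random() < 0.5:
        combined = combined[:, ::-1]  # horizontal flip of both at once
    c = image.shape[-1]
    return combined[..., :c], combined[..., c:]

image = np.arange(10 * 10 * 3).reshape(10, 10, 3)
mask = image[..., :1].copy()  # mask equals the first image channel
img_c, msk_c = joint_random_crop_flip(image, mask, rng=np.random.default_rng(0))
print(img_c.shape, msk_c.shape)               # (8, 8, 3) (8, 8, 1)
print(np.array_equal(img_c[..., :1], msk_c))  # True
```

Because the mask in the example is a copy of the first image channel, the final check confirms that both arrays were cropped and flipped identically.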

Configure the **data loader, image paths and logging** options.

In [None]:
import tensorflow as tf
import os

IMAGE_DIR_PATH = '/content/imgs'
MASK_DIR_PATH = '/content/msks'

image_paths = [os.path.join(IMAGE_DIR_PATH, x) for x in sorted(os.listdir(IMAGE_DIR_PATH)) if x.endswith('.jpg')]
mask_paths = [os.path.join(MASK_DIR_PATH, x) for x in sorted(os.listdir(MASK_DIR_PATH)) if x.endswith('.png')]

train_image_paths, val_image_paths, train_mask_paths, val_mask_paths = train_test_split(image_paths, mask_paths, test_size = 0.2, random_state = 0)

CHECKPOINT="/content/drive/My Drive/slim512/ckpt/slim-net-{epoch:02d}-{val_loss:.2f}.hdf5"
LOGS='./logs'

num_train=len(train_image_paths)
num_val=len(val_image_paths)
batch_sz=64
epochs=100


# Initialize the dataloader object
train_dataset = DataLoader(image_paths=train_image_paths,
                     mask_paths=train_mask_paths,
                     image_size=512,
                     crop_percent=0.8,
                     channels=[3, 1],
                     seed=47)
val_dataset = DataLoader(image_paths=val_image_paths,
                     mask_paths=val_mask_paths,
                     image_size=512,
                     crop_percent=0.8,
                     channels=[3, 1],
                     seed=47)

# Parse the images and masks, and return the data in batches, augmented optionally.
train_dataset = train_dataset.data_batch(batch_size=batch_sz,
                             augment=True, 
                             shuffle=True)
val_dataset = val_dataset.data_batch(batch_size=batch_sz,
                             augment=False, 
                             shuffle=True)
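The cell above pairs images with masks by sorting the two directory listings independently, which silently misaligns every later pair if one directory contains an extra or missing file. A defensive alternative, sketched below, matches files by their stem; it assumes each mask has the same base name as its image (for example `0001.jpg` and `0001.png`), and `pair_by_stem` is a hypothetical helper.

```python
import os

def pair_by_stem(image_paths, mask_paths):
    """Pair each image with the mask that has the same file stem,
    e.g. 'imgs/0001.jpg' with 'msks/0001.png'. Unmatched files are dropped."""
    masks = {os.path.splitext(os.path.basename(p))[0]: p for p in mask_paths}
    pairs = []
    for img in image_paths:
        stem = os.path.splitext(os.path.basename(img))[0]
        if stem in masks:
            pairs.append((img, masks[stem]))
    return pairs

imgs = ['/content/imgs/a.jpg', '/content/imgs/b.jpg', '/content/imgs/c.jpg']
msks = ['/content/msks/a.png', '/content/msks/c.png']
pairs = pair_by_stem(imgs, msks)
print(pairs)  # 'b.jpg' is dropped because it has no mask
```

To use it, one could run `image_paths, mask_paths = map(list, zip(*pair_by_stem(image_paths, mask_paths)))` before calling `train_test_split`.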

**Model Architecture**

The following is a brief **summary** of the model's **architectural features**:

1. The model uses an **encoder-decoder** architecture with **PReLU** activations throughout the network, which helps achieve **faster convergence** and **improved accuracy**.

2. The input is first **downsampled** from 512 to 128 pixels per side (i.e. 1/4), which **reduces** the overall **computation** cost while preserving details.

3. **Skip connections** between the encoder and decoder blocks (as in U-Net) help recover **fine details** and improve **gradient flow** across layers.

4. **Bottleneck** layers (as in ResNet) with **depthwise** convolutions keep **inference fast**.

5. **Dilated** convolutions (as in DeepLab) enlarge the **receptive field** at the **same computation and memory cost** as an undilated kernel of the same size, while **preserving resolution**.

6. Finally, **transposed convolutions** **upsample** the features back to full resolution (512).
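To make point 5 concrete: a k x k kernel with dilation rate d spans k + (k - 1)(d - 1) pixels per side while keeping only k * k weights. The short Python sketch below computes this for the dilation rates used in `slim_net()` (2, 4 and 8) and the receptive field of one dilated 2-4-8 sequence. It counts only those three stride-1 layers, measures in feature-map pixels, and the helper names are hypothetical.

```python
def effective_kernel(k, d):
    """Spatial extent covered by a k x k kernel with dilation rate d."""
    return k + (k - 1) * (d - 1)

def stacked_receptive_field(kernels_and_dilations):
    """Receptive field of a stack of stride-1 convolutions:
    each layer adds (effective kernel size - 1) pixels."""
    rf = 1
    for k, d in kernels_and_dilations:
        rf += effective_kernel(k, d) - 1
    return rf

print([effective_kernel(3, d) for d in (1, 2, 4, 8)])      # [3, 5, 9, 17]
print(stacked_receptive_field([(3, 2), (3, 4), (3, 8)]))   # 29
```

Every one of these kernels still has 9 weights per filter, which is why dilation widens the context without adding parameters or multiply-adds.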



In [None]:

def bilinear_resize(x, rsize):
  # align_corners=True avoids the pixel shift of TF 1.x's default bilinear resize
  return tf.image.resize_bilinear(x, [rsize, rsize], align_corners=True)

def encode_bottleneck(x, proj_ch, out_ch, strides=1, dilation=1, separable=True, depthwise=True, preluop=False, pool=False):
  """Residual encoder block: projection -> spatial conv -> 1x1 expansion, added to a shortcut.
  Returns (output, pooled shortcut) if pool=True, (output, activated input) if preluop=True, else output."""

  x = PReLU(shared_axes=[1, 2])(x)
  # Projection; with strides=2 a 2x2/2 conv also halves the spatial size
  y = Conv2D(filters=proj_ch, kernel_size=strides, strides=strides, padding='same')(x)
  y = PReLU(shared_axes=[1, 2])(y)

  if separable:
    if depthwise:
      y = SeparableConv2D(filters=proj_ch, kernel_size=3, strides=1, padding='same')(y)
      y = PReLU(shared_axes=[1, 2])(y)
      y = DepthwiseConv2D(kernel_size=3, padding='same')(y)
    else:
      y = SeparableConv2D(filters=proj_ch, kernel_size=5, strides=1, padding='same')(y)
  else:
    # Regular (optionally dilated) 3x3 convolution
    y = Conv2D(filters=out_ch, kernel_size=3, dilation_rate=dilation, strides=1, padding='same')(y)

  # Expansion back to out_ch channels
  y = PReLU(shared_axes=[1, 2])(y)
  y = Conv2D(filters=out_ch, kernel_size=1, strides=1, padding='same')(y)

  if pool:
    # Downsampling shortcut: max-pool, then match channels with a 1x1 conv if needed
    m = MaxPool2D((2, 2), padding='same')(x)
    if m.shape[-1] != out_ch:
      x = Conv2D(filters=out_ch, kernel_size=1, strides=1, padding='same')(m)
    else:
      x = m
    z = Add()([x, y])
    return z, m

  z = Add()([x, y])

  if preluop:
    return z, x

  return z


def decode_bottleneck(x, res1, res2, proj_ch1, out_ch, proj_ch2, strides=1, rsize=32, pconv=True):
  """Decoder block: upsamples x by 2 and fuses it with encoder skip connections.
  res1 is added before bilinear resizing to rsize; res2 is concatenated when pconv=True."""

  x = PReLU(shared_axes=[1, 2])(x)
  y = Conv2D(filters=proj_ch1, kernel_size=strides, strides=strides, padding='same')(x)
  y = PReLU(shared_axes=[1, 2])(y)

  # Learned 2x upsampling
  y = Conv2DTranspose(filters=8, kernel_size=3, strides=2, padding='same')(y)

  y = PReLU(shared_axes=[1, 2])(y)
  y = Conv2D(filters=out_ch, kernel_size=1, strides=1, padding='same')(y)

  # Shortcut: add the pooled encoder feature, then bilinearly resize to rsize x rsize
  x = Conv2D(filters=out_ch, kernel_size=1, strides=1, padding='same')(x)
  r = Add()([x, res1])
  x = Lambda(lambda r: bilinear_resize(r, rsize))(r)

  z = Add()([x, y])
  z = PReLU(shared_axes=[1, 2])(z)
  if pconv:
    # Concatenate the encoder feature at the upsampled resolution (skip connection)
    z = concatenate([z, res2])
  b = Conv2D(filters=proj_ch2, kernel_size=strides, strides=strides, padding='same')(z)
  b = PReLU(shared_axes=[1, 2])(b)
  b = Conv2D(filters=proj_ch2, kernel_size=3, strides=1, padding='same')(b)
  b = PReLU(shared_axes=[1, 2])(b)
  b = Conv2D(filters=out_ch, kernel_size=1, strides=1, padding='same')(b)

  if pconv:
    # Project the concatenated tensor back to out_ch so it can be added to b
    z = Conv2D(filters=out_ch, kernel_size=1, strides=1, padding='same')(z)

  c = Add()([z, b])

  return c


**Note:** The **bilinear resize** must use **align_corners=True** for proper upsampling; with TF 1.15's default setting the resized features are slightly shifted, which shows up as misaligned masks at inference time. Also, to **reduce** the number of **parameters**, set the **PReLU** option **shared_axes=[1,2]**, so that each channel learns a single slope shared across spatial axes 1 and 2 instead of one slope per pixel.
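Both points in the note can be checked with a little arithmetic. The NumPy sketch below computes, for a 16 to 32 upsampling (the size change in the first decoder block), which input coordinate each output pixel samples under TF 1.x's default bilinear mapping versus `align_corners=True`, and counts PReLU slopes with and without shared axes. It models only the coordinate mapping, not the interpolation or edge clamping, and the function names are hypothetical.

```python
import numpy as np

def tf1_bilinear_source_coords(in_size, out_size, align_corners):
    """Input coordinate sampled for each output index by TF 1.x bilinear resizing
    (legacy mode without half-pixel centres); interpolation itself is omitted."""
    dst = np.arange(out_size, dtype=np.float64)
    if align_corners and out_size > 1:
        scale = (in_size - 1) / (out_size - 1)  # first and last pixels map exactly
    else:
        scale = in_size / out_size              # TF 1.x default mapping
    return dst * scale

def prelu_params(height, width, channels, shared):
    """Number of PReLU slopes: one per channel if axes 1 and 2 are shared, else one per element."""
    return channels if shared else height * width * channels

legacy = tf1_bilinear_source_coords(16, 32, align_corners=False)
aligned = tf1_bilinear_source_coords(16, 32, align_corners=True)
# The input centre is at (16 - 1) / 2 = 7.5
print(legacy.mean(), round(float(aligned.mean()), 6))  # 7.75 7.5
print(prelu_params(128, 128, 32, shared=True),
      prelu_params(128, 128, 32, shared=False))        # 32 524288
```

With the default mapping the sampled positions are centred at 7.75 instead of the input centre 7.5, a quarter-pixel bias at this scale, and such offsets can accumulate across the three decoder stages; with `align_corners=True` they stay centred. Sharing axes 1 and 2 reduces a 128x128x32 PReLU from 524,288 slopes to 32.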

In [None]:
# Define the network using the basic bottleneck layers
def slim_net():

  # Initial spatial phase [Reduces input by a factor 1/4]
  input = Input(shape=(512,512,3), name='ip')
  x = Conv2D(filters=8, kernel_size=2, strides=2, padding='valid')(input)
  x = PReLU(shared_axes=[1, 2])(x)
  x = Conv2D(filters=32, kernel_size=2, strides=2, padding='valid')(x)

  b1, r1 = encode_bottleneck(x, proj_ch=16, out_ch=64, strides=2, separable=True, depthwise=True, pool=True)
  b2, p1 = encode_bottleneck(b1, proj_ch=16, out_ch=64, strides=1, separable=True, depthwise=True, preluop=True, pool=False)
  b3 = encode_bottleneck(b2, proj_ch=16, out_ch=64, strides=1, separable=True, depthwise=True, pool=False)
  b4, r2 = encode_bottleneck(b3, proj_ch=32, out_ch=128, strides=2, separable=True, depthwise=True, pool=True)

  b5, p2 = encode_bottleneck(b4, proj_ch=16, out_ch=128, strides=1, separable=True, depthwise=True, preluop=True,pool=False)
  b6 = encode_bottleneck(b5, proj_ch=16, out_ch=128, strides=1, separable=True, depthwise=True, pool=False)
  b7 = encode_bottleneck(b6, proj_ch=16, out_ch=128, strides=1, separable=True, depthwise=True, pool=False)
  b8 = encode_bottleneck(b7, proj_ch=16, out_ch=128, strides=1, separable=True, depthwise=True, pool=False)

  b9, r3 = encode_bottleneck(b8, proj_ch=16, out_ch=128, strides=2, separable=True, depthwise=True, pool=True)
  b10 = encode_bottleneck(b9, proj_ch=8, out_ch=128, strides=1, separable=True, depthwise=True, pool=False)
  b11 = encode_bottleneck(b10, proj_ch=8, out_ch=128, strides=1, dilation=2, separable=False, depthwise=False, pool=False) # dil -2
  b12 = encode_bottleneck(b11, proj_ch=8, out_ch=128, strides=1, separable=True, depthwise=False, pool=False)
  b13 = encode_bottleneck(b12, proj_ch=8, out_ch=128, strides=1, dilation=4, separable=False, depthwise=False, pool=False) # dil -4
  b14 = encode_bottleneck(b13, proj_ch=8, out_ch=128, strides=1, separable=True, depthwise=True, pool=False)
  b15 = encode_bottleneck(b14, proj_ch=8, out_ch=128, strides=1, dilation=8, separable=False, depthwise=False, pool=False) # dil -8 
  b16 = encode_bottleneck(b15, proj_ch=8, out_ch=128, strides=1, separable=True, depthwise=True, pool=False)
  b17 = encode_bottleneck(b16, proj_ch=8, out_ch=128, strides=1, dilation=2, separable=False, depthwise=False, pool=False) # dil -2
  b18 = encode_bottleneck(b17, proj_ch=8, out_ch=128, strides=1, separable=True, depthwise=False, pool=False)
  b19 = encode_bottleneck(b18, proj_ch=8, out_ch=128, strides=1, dilation=4, separable=False, depthwise=False, pool=False) # dil -4
  b20 = encode_bottleneck(b19, proj_ch=8, out_ch=128, strides=1, separable=True, depthwise=True, pool=False)
  b21 = encode_bottleneck(b20, proj_ch=8, out_ch=128, strides=1, dilation=8, separable=False, depthwise=False, pool=False) # dil -8 

  b22 = encode_bottleneck(b21, proj_ch=4, out_ch=128, strides=1, separable=False, depthwise=False, pool=False) # dil -1

 
  d1 = decode_bottleneck(b22,res1=r3, res2=p2 ,proj_ch1=8, proj_ch2=8, out_ch=128, strides=1, rsize=32, pconv=True)
  d2 = decode_bottleneck(d1,res1=r2, res2=p1 ,proj_ch1=8, proj_ch2=4 , out_ch=64, strides=1, rsize=64,pconv=True)
  d3 = decode_bottleneck(d2,res1=r1, res2=None ,proj_ch1=4, proj_ch2=4, out_ch=32, strides=1, rsize=128,pconv=False)

  pout1 = PReLU(shared_axes=[1, 2])(d3)
  cout1 = Conv2DTranspose(filters=8, kernel_size=2, strides=2, padding = 'same' )(pout1) # output size: 256
  pout2 = PReLU(shared_axes=[1, 2])(cout1)
  cout2 = Conv2DTranspose(filters=2, kernel_size=2, strides=2, padding = 'same' )(pout2) # output size: 512

  model = Model(inputs=input, outputs=cout2)
  model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy']) # Ensure you have sparse labels

  return model

# Initialize the model and plot summary
model = slim_net()
model.summary()

# Plot model architecture
plot_model(model, to_file='slim-net.png')
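As a quick sanity check of the architecture summary, the sketch below traces the feature-map side length through the stride-changing stages of `slim_net()`. The sizes are derived by hand from the strides in the code (not read from `model.summary()`), and `trace_spatial_sizes` is a hypothetical helper; the assertion reflects that each decoder's `rsize` must equal the output size of its `Conv2DTranspose` for the `Add` to work.

```python
def trace_spatial_sizes(input_size=512):
    """Side length of the feature map after each stride-changing stage of slim_net()."""
    size = input_size
    trace = [("input", size)]
    for name in ("conv1 2x2/2", "conv2 2x2/2", "b1 (strides=2)", "b4 (strides=2)", "b9 (strides=2)"):
        size //= 2
        trace.append((name, size))
    for name, rsize in (("d1", 32), ("d2", 64), ("d3", 128)):
        size *= 2  # Conv2DTranspose(strides=2) inside decode_bottleneck
        assert size == rsize, "skip path must be resized to the same size"
        trace.append((name, size))
    for name in ("cout1", "cout2"):
        size *= 2  # final Conv2DTranspose layers
        trace.append((name, size))
    return trace

for name, size in trace_spatial_sizes():
    print(f"{name:>16}: {size}x{size}")
```

The trace shows that the dilated blocks `b10` to `b21` operate on 16x16 maps, i.e. 1/32 of the input side, which keeps their enlarged receptive field cheap.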


**Training and Callbacks**

Configure the **callbacks** for learning-rate decay, logging and checkpointing.

In [None]:
# Save checkpoints
checkpoint = ModelCheckpoint(CHECKPOINT, monitor='val_loss', verbose=1, save_weights_only=False , save_best_only=True, mode='min')

# Callbacks 
reduce_lr = ReduceLROnPlateau(factor=0.5, patience=10, min_lr=0.000001, verbose=1)
tensorboard = TensorBoard(log_dir=LOGS, histogram_freq=0,
                          write_graph=True, write_images=True)

callbacks_list = [checkpoint, tensorboard, reduce_lr]
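To see what `ReduceLROnPlateau(factor=0.5, patience=10, min_lr=0.000001)` does to the learning rate, here is a simplified pure-Python re-implementation of its rule for `mode='min'` with no cooldown; it is an illustration, not the Keras source. It assumes Adam's default starting learning rate of 0.001 and the Keras default `min_delta` of 1e-4, and `simulate_reduce_lr` is a hypothetical name.

```python
def simulate_reduce_lr(val_losses, lr=1e-3, factor=0.5, patience=10,
                       min_lr=1e-6, min_delta=1e-4):
    """Simplified ReduceLROnPlateau rule (mode='min', cooldown=0).
    Returns the learning rate in effect after each epoch."""
    best = float("inf")
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best - min_delta:
            # Improvement: remember it and reset the patience counter
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience and lr > min_lr:
                # Plateau: scale the rate down, but never below min_lr
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history

lrs = simulate_reduce_lr([1.0] * 25)
print(lrs[9], lrs[10], lrs[20])  # 0.001 0.0005 0.00025
```

For a validation loss that never improves, the rate is halved at epochs 11, 21, 31 and so on, until it reaches `min_lr`.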

**Train and validate** the model on the dataset; checkpoints and logs are saved through the callbacks.

In [None]:
# Train the model
model_history = model.fit(train_dataset, epochs=epochs,
                          steps_per_epoch=num_train//batch_sz,
                          validation_steps=num_val//batch_sz,
                          validation_data=val_dataset,
                          callbacks=callbacks_list)

**Testing**

Load the trained model and the **test images**.

In [None]:
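# Load the final model for inference
model=load_model('/content/drive/My Drive/slim512/ckpt/slim-net-157-0.02.hdf5',compile=False)

# Prepare the batch of images as lists (sorted for a deterministic order)
TEST_DIR='/content/test/'
images=sorted(os.listdir(TEST_DIR))
inputs=[]
for img in images:
    # Convert to RGB so grayscale/RGBA files also yield 3 channels
    im=Image.open(os.path.join(TEST_DIR, img)).convert('RGB')
    # LANCZOS is the filter behind the ANTIALIAS alias, which was removed in Pillow 10
    im=im.resize((512,512), Image.LANCZOS)
    inputs.append(np.array(im))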

Perform **inference** on the input image batch.

In [None]:
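# Perform batch prediction and obtain masks
batch=np.float32(inputs)/255.0
result=model.predict(batch)                  # logits of shape (N, 512, 512, 2)
argmax=np.argmax(result, axis=3)             # 0 = background, 1 = foreground
outputs=list(argmax[...,np.newaxis]*batch)   # keep only the foreground pixels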

Plot the **input and output** images using matplotlib

In [None]:
# Combine the input & output list of images
plot_lists=inputs+outputs

# Plot all the images using matplotlib 
fig=plt.figure(figsize=(16, 8))
columns = len(inputs)  # one column per test image
rows = 2               # inputs in the first row, masked outputs in the second

# Show all inputs and their corresponding outputs
for i in range(1, columns*rows+1):
    img = plot_lists[i-1].squeeze()
    fig.add_subplot(rows, columns, i)
    plt.imshow(img)
plt.show()

**Exporting**

Load the trained **model checkpoint** for export.

In [None]:
import tensorflow as tf

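# Must be defined again so that load_model can rebuild the Lambda layers that call it
def bilinear_resize(x, rsize):
  return tf.image.resize_bilinear(x, [rsize,rsize], align_corners=True)

def slice_foreground(x):
    # Keep channel 1 (foreground) of a single-image batch: (1, 512, 512, 2) -> (1, 512, 512, 1)
    return tf.strided_slice(x, [0,0, 0, 1], [1,512, 512, 2], [1, 1, 1, 1])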

model=load_model('/content/drive/My Drive/slim512/ckpt/slim-net-157-0.02.hdf5',compile=False)
model.summary()

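Add new layers to **modify** the network output: a softmax, a strided slice that keeps only the foreground channel, and a reshape to a flat vector of 512 x 512 = 262,144 values.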

In [None]:
sm=tf.keras.layers.Softmax()(model.output) # softmax
str_slice=Lambda(slice_foreground, name="strided_slice")(sm) # strided slice
newout=Reshape((262144,))(str_slice) # reshape to a flat vector of 512*512 = 262144 values
reshape_model=Model(model.input,newout)
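The export head turns the two-channel logits into a single foreground probability per pixel, and the verification code later thresholds it at 0.5. For two classes, `softmax[..., 1] > 0.5` holds exactly when the foreground logit exceeds the background logit, so this reproduces the `argmax` used in the Testing section. The NumPy sketch below checks that equivalence on random stand-in logits for a 4x4 image; the array sizes and the `softmax` helper are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 4, 4, 2))  # stand-in for model.output: (batch, H, W, classes)

# Same chain as the export head: softmax -> keep channel 1 -> flatten
probs = softmax(logits)
foreground = probs[0:1, :, :, 1:2]      # like slice_foreground
flat = foreground.reshape(1, -1)        # (1, 16) here, (1, 262144) for 512x512

mask_from_export = flat.reshape(4, 4) > 0.5            # as in run_tflite_model
mask_from_argmax = np.argmax(logits[0], axis=-1) == 1  # as in the Testing section
print(np.array_equal(mask_from_export, mask_from_argmax))  # True
```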

Save the final **keras model** for deployment.

In [None]:
reshape_model.summary()
reshape_model.save('/content/slim_reshape_v2.h5')

Convert the final keras model to **tflite** format.

In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model_file('/content/slim_reshape_v2.h5')
tflite_model = converter.convert()
with open("slim_reshape_v2.tflite", "wb") as f:
    f.write(tflite_model)

**Verification**

Load the exported slim-net **TFLite model** for portrait segmentation.

In [None]:
!wget -O slim_reshape_v2.tflite https://github.com/anilsathyan7/Portrait-Segmentation/blob/master/models/slim_seg_512/slim_reshape%20v2.tflite?raw=true

Download a sample **portrait image** for testing the model.

In [None]:
import cv2
import numpy as np
from skimage import io
import tensorflow as tf
from matplotlib import pyplot as plt

image=io.imread('https://images.newindianexpress.com/uploads/user/imagelibrary/2020/6/9/w600X390/Javed_Akhtar_PTI.jpg')
image=cv2.resize(image,(512,512))
plt.imshow(image)

Perform inference on test images using the **TFLite interpreter**.

In [15]:
def run_tflite_model(tflite_file, test_image):

  # Initialize the interpreter
  interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
  interpreter.allocate_tensors()

  # Get input and output details
  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]

  # Preprocess the input image
  test_image = test_image/255.0
  test_image = np.expand_dims(test_image, axis=0).astype(input_details["dtype"])

  # Run the interpreter and get the output
  interpreter.set_tensor(input_details["index"], test_image)
  interpreter.invoke()
  output = interpreter.get_tensor(output_details["index"])[0]

  # Compute a binary mask by thresholding the foreground probability at 0.5
  mask = np.reshape(output, (512,512))>0.5

  return mask

**Crop** the input using output mask and plot the results.

In [None]:
import cv2
import numpy as np
from skimage import io
import tensorflow as tf
from matplotlib import pyplot as plt

mask=run_tflite_model(tflite_file='/content/slim_reshape_v2.tflite',test_image=image)
crop_float=image*mask[...,np.newaxis]
plt.imshow(crop_float/255.0)
plt.title('Float Model Output')
plt.show()