# Semantic Segmentation Inference using TF Lite Runtime
In this example notebook, we describe how to use a pre-trained Semantic Segmentation model for inference using the TF Lite Runtime interface.
   - The user can choose the model (see section titled *Choosing a Pre-Compiled Model*)
   - The models used in this example were trained on either the ***Cityscapes*** or the ***ADE20K*** dataset, both of which are widely used for training and benchmarking semantic segmentation models.
   - We perform inference on a few sample images.
   - We also describe the input preprocessing and output postprocessing steps, demonstrate how to collect various benchmarking statistics and how to visualize the data.

## Choosing a Pre-Compiled Model
We provide a set of pre-compiled model artifacts for use with this notebook. They appear in a drop-down list once the first code cell is executed.

<img src=docs/images/drop_down.PNG width="400">

## Semantic Segmentation

Semantic segmentation is a popular computer vision task that assigns a class label to every pixel of an image. It is used in many applications, such as free space detection and lane detection. The image below shows semantic segmentation results on a few sample images.

<img src=docs/images/SEG.PNG width="700">

## TensorFlow Lite Runtime based Workflow
The diagram below describes the steps of the TensorFlow Lite Runtime based workflow.

Note:
- The user needs to compile models (sub-graph creation and quantization) on a PC to generate model artifacts.
    - For this notebook, we use pre-compiled model artifacts.
- The generated artifacts can then be used to run inference on the target.
- Users can run this notebook as-is; the only action required is to select a model.

<img src=docs/images/tflrt_work_flow_2.png width="400">

In [None]:
import os
import cv2
import numpy as np
import ipywidgets as widgets
from scripts.utils import get_eval_configs
#grab a set of model configurations locally defined in a script
last_artifacts_id = selected_model_id.value if "selected_model_id" in locals() else None
prebuilt_configs, selected_model_id = get_eval_configs('segmentation','tflitert', num_quant_bits = 8, last_artifacts_id = last_artifacts_id)
display(selected_model_id)

In [None]:
print(f'Selected Model: {selected_model_id.label}')
config = prebuilt_configs[selected_model_id.value]
config['session'].set_param('model_id', selected_model_id.value)
config['session'].start()

## Define utility function to preprocess input images

Below, we define a utility function to preprocess images for the model. This function takes an image path as input, loads the image, and preprocesses it as required by the model. The steps below are shown as a reference (no user action required):

 1. Load the image
 2. Convert the BGR image to RGB
 3. Resize the image to the model's input dimensions
 4. Apply per-channel mean subtraction and scaling
 5. Convert the RGB image back to BGR, if the model requires it
 6. Add a batch dimension and, if the model requires it, convert the image to NCHW format


- The input arguments of this utility function are selected automatically by this notebook, based on the model selected in the drop-down list.
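
As a rough illustration of steps 4 and 6, the sketch below applies the per-channel arithmetic with NumPy broadcasting on a synthetic image. The helper name `normalize_and_layout` and the mean and scale values are placeholders chosen for this example; they are not taken from any model configuration in this notebook.

```python
import numpy as np

def normalize_and_layout(img_rgb, mean, scale, layout="NCHW"):
    """Apply per-channel (pixel - mean) * scale, then add a batch dimension."""
    img = img_rgb.astype(np.float32)
    # Broadcasting applies mean[c] and scale[c] to every pixel of channel c
    img = (img - np.asarray(mean, dtype=np.float32)) * np.asarray(scale, dtype=np.float32)
    if layout == "NCHW":
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(img, axis=0)      # add the batch dimension N

# Synthetic 2x2 RGB image in which every pixel is (128, 64, 255)
dummy = np.tile(np.array([128, 64, 255], dtype=np.uint8), (2, 2, 1))
# Placeholder statistics, not the values of any particular model
out = normalize_and_layout(dummy, mean=[128.0, 128.0, 128.0], scale=[1 / 64, 1 / 64, 1 / 64])
print(out.shape)  # (1, 3, 2, 2)
```

With these placeholder values, a pixel value of 128 maps to 0.0 and a value of 64 maps to -1.0, which shows how the mean shifts and the scale stretches each channel.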

In [None]:
def preprocess(image_path, size, mean, scale, layout, reverse_channels):
    # Step 1 - read image
    img = cv2.imread(image_path)
    
    # Step 2 - Flip from BGR to RGB
    img = img[:,:,::-1]
    
    # Step 3 -- resize to match model input dimensions 
    img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
     
    # Step 4 - subtract the per-channel mean and multiply by the per-channel scale to match the model's expected data distribution
    if mean is not None and scale is not None:
        img = img.astype('float32')
        # use distinct loop names so the 'mean' and 'scale' arguments are not shadowed
        for ch_mean, ch_scale, ch in zip(mean, scale, range(img.shape[2])):
            img[:,:,ch] = (img[:,:,ch] - ch_mean) * ch_scale
    # Step 5 - If needed, flip back to BGR
    if reverse_channels:
        img = img[:,:,::-1]
        
    # Step 6 -- Reorder tensor dimensions as NCHW (number, channel, height, width) or NHWC
    if layout == 'NCHW':
        img = np.expand_dims(np.transpose(img, (2,0,1)),axis=0)
    else:
        img = np.expand_dims(img,axis=0)
    
    return img

## Create the model using the stored artifacts

In [None]:
import tflite_runtime.interpreter as tflite

tflite_model_path = config['session'].get_param('model_file')
artifacts_dir = config['session'].get_param('artifacts_folder')
#setup the Tensorflow-Lite delegate to use TIDL
tidl_delegate = [tflite.load_delegate('libtidl_tfl_delegate.so', {'artifacts_folder': artifacts_dir})]
#create an interpreter object for this model, using the TIDL delegate
interpreter = tflite.Interpreter(model_path=tflite_model_path, experimental_delegates=tidl_delegate)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
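
The entries returned by `get_input_details()` and `get_output_details()` are dictionaries that include the tensor's `name`, `index`, `shape`, and `dtype`. The sketch below shows one way to summarize such an entry. The helper `describe_tensor` is hypothetical, and the mock entry uses made-up values so that the example runs without the target hardware.

```python
import numpy as np

def describe_tensor(detail):
    """Summarize one entry of get_input_details() / get_output_details()."""
    shape = tuple(int(d) for d in detail["shape"])
    dtype = np.dtype(detail["dtype"]).name
    return f"{detail['name']}: shape={shape}, dtype={dtype}"

# Mock entry with made-up values, standing in for input_details[0]
mock_detail = {
    "name": "input",
    "index": 0,
    "shape": np.array([1, 512, 512, 3], dtype=np.int32),
    "dtype": np.uint8,
}
print(describe_tensor(mock_detail))  # input: shape=(1, 512, 512, 3), dtype=uint8
```

In the notebook itself, the same helper could be applied to each entry of `input_details` and `output_details` to check the expected input size and data type before running inference.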

## Run the model for inference

### Preprocessing and Inference
  - We perform inference on a set of images from the `sample-images` directory.
  - We use a loop to preprocess the selected images, and provide them as the input to the network.

### Postprocessing and Visualization
  - Once the inference results are available, we postprocess the results and visualize the inferred classes for each of the input images.
  - Semantic segmentation models return the results as an array (i.e. `numpy.ndarray`) in which each element represents the class ID of a pixel.
  - We use the `seg_mask_overlay()` function to postprocess the results.
  - Then, in this notebook, we use *matplotlib* to plot the original images and the corresponding results.
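
To give an idea of what such an overlay involves, the sketch below colors a class-ID map with a lookup palette and blends it over an image. It is a minimal illustration under stated assumptions, not the implementation of `seg_mask_overlay()`; the function name `overlay_class_mask`, the three-class palette, and the blending factor are all hypothetical.

```python
import numpy as np

def overlay_class_mask(image, class_ids, palette, alpha=0.5):
    """Blend a color-coded class-ID mask over an RGB image of the same height and width."""
    color_mask = palette[class_ids]  # (H, W) class IDs -> (H, W, 3) colors
    blended = (1.0 - alpha) * image + alpha * color_mask
    return blended.astype(np.uint8)

# Hypothetical 3-class palette: black, red, green
palette = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0]], dtype=np.uint8)
image = np.full((2, 2, 3), 100, dtype=np.uint8)  # uniform gray test image
class_ids = np.array([[0, 1], [2, 1]])           # one class ID per pixel
result = overlay_class_mask(image, class_ids, palette)
print(result[0, 1].tolist())  # [177, 50, 50]
```

Indexing the palette with the class-ID array performs the per-pixel color lookup in a single vectorized step, so no explicit loop over pixels is needed.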

In [None]:
from scripts.utils import get_preproc_props

# sample images to run inference on, each paired with its matplotlib subplot position
images = [('sample-images/ADE_val_00001801.jpg', 221),
          ('sample-images/ADE_val_00000590.jpg', 222)]

size, mean, scale, layout, reverse_channels = get_preproc_props(config)
print(f'Image size: {size}')

In [None]:
import tqdm
import matplotlib.pyplot as plt
from PIL import Image
from scripts.utils import seg_mask_overlay

plt.figure(figsize=(20,10))

for num in tqdm.trange(len(images)):
    image_file, grid = images[num]
    img  = Image.open(image_file).convert('RGB')
    ax = plt.subplot(grid)
    #preprocess the image into a tensor that matches the model's input specifications
    img_in = preprocess(image_file , size, mean, scale, layout, reverse_channels)
    # cast the input to uint8 when the model does not take float32 input
    if input_details[0]['dtype'] != np.float32:
        img_in = np.uint8(img_in)
    # Pass the input tensor to TFLite, run inference, and retrieve the output
    interpreter.set_tensor(input_details[0]['index'], img_in)
    interpreter.invoke()
    res = [interpreter.get_tensor(output_detail['index']) for output_detail in output_details]

    # Postprocessing -- overlay a segmentation mask to show pixel classifications
    org_size = img.size
    img = seg_mask_overlay(res, img, layout).resize(org_size)
    ax.imshow(img)
    
plt.show()

## Plot Inference benchmarking statistics
 - During model execution, several benchmarking statistics, such as timestamps at different checkpoints and DDR bandwidth, are collected and stored.
 - The `get_TI_benchmark_data()` function can be used to collect these statistics. The statistics are returned as a dictionary of `annotations` and corresponding markers.
 - We provide the utility function `plot_TI_performance_data()` to visualize these benchmark KPIs.

<div class="alert alert-block alert-info">
<b>Note:</b> The values reported as <i>Inferences Per Second</i> and <i>Inference Time Per Image</i> use the total time taken by the inference, excluding the time taken to copy inputs and outputs. In a performance-oriented system, these copy operations can be bypassed by writing the data directly into shared memory and performing on-the-fly input/output normalization.
</div>
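
The arithmetic behind these two values can be sketched as follows. The dictionary keys and the nanosecond units below are hypothetical, chosen only to illustrate the calculation; they do not describe the actual format returned by `get_TI_benchmark_data()`.

```python
def inference_time_ms(stats, copy_keys=("copy_in", "copy_out")):
    """Total run time minus input/output copy time, in milliseconds."""
    total_ns = stats["run_end"] - stats["run_start"]
    copy_ns = sum(stats[f"{k}_end"] - stats[f"{k}_start"] for k in copy_keys)
    return (total_ns - copy_ns) / 1e6

# Hypothetical nanosecond timestamps for a single inference
mock_stats = {
    "run_start": 0, "run_end": 12_000_000,
    "copy_in_start": 0, "copy_in_end": 1_000_000,
    "copy_out_start": 11_000_000, "copy_out_end": 12_000_000,
}
tt_ms = inference_time_ms(mock_stats)
fps = 1000.0 / tt_ms  # inferences per second from milliseconds per image
print(f"{tt_ms:.2f} ms per image, {fps:.2f} fps")  # 10.00 ms per image, 100.00 fps
```

In this made-up example, 2 ms of the 12 ms run is spent copying data, so the reported inference time is 10 ms per image, which corresponds to 100 inferences per second.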


In [None]:
from scripts.utils import plot_TI_performance_data, plot_TI_DDRBW_data, get_benchmark_output, print_soc_info
# Pull TI performance measurements from the runtime
stats = interpreter.get_TI_benchmark_data()
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10,5))
plot_TI_performance_data(stats, axis=ax)
plt.show()
# Process stats to get the total time (tt, in ms), processing time (st), and DDR read (rb) and DDR write (wb) traffic (in MB) for one model inference
tt, st, rb, wb = get_benchmark_output(stats)

print_soc_info()
print(f'{selected_model_id.label} :')
print(f' Inferences Per Second    : {1000.0/tt :7.2f} fps')
print(f' Inference Time Per Image : {tt :7.2f} ms')
print(f' DDR usage Per Image      : {rb+ wb : 7.2f} MB')