Demonstrate selective quantization capabilities of AI Edge Quantizer.


In [None]:
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-ai-edge/ai-edge-quantizer/blob/main/colabs/selective_quantization_isnet.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google-ai-edge/ai-edge-quantizer/blob/main/colabs/selective_quantization_isnet.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

In [None]:
# When running in Google Colab, the pre-installed versions of some packages
# might be incompatible with the AI Edge libraries, so remove them first.
# NOTE(review): `import tensorflow` in the next cell presumably relies on one of
# the nightly packages below reinstalling a TF build — verify after install.
!pip uninstall -y tensorflow jax jaxlib
# Nightly builds are unpinned, so results may change between runs.
# NOTE(review): consider `%pip` so installs target the running kernel.
!pip install ai-edge-litert-nightly
!pip install ai-edge-model-explorer
!pip install ai-edge-quantizer-nightly
!pip install ai-edge-torch-nightly
!pip install pillow requests matplotlib

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import skimage
import tensorflow as tf
import ai_edge_quantizer
import model_explorer

from ai_edge_litert.interpreter import Interpreter

In [None]:
# @title Preprocess/postprocess utilities (unrelated to quantization) { display-mode: "form" }

# Model input spatial size as (height, width); the conversion cell builds its
# sample input as (1, *MODEL_INPUT_HW, 3).
MODEL_INPUT_HW = (1024, 1024)

def make_channels_first(image):
  """Converts a single HWC image into a batched NCHW array.

  Args:
    image: Array-like image laid out as (height, width, channels).

  Returns:
    A C-contiguous NumPy array of shape (1, channels, height, width) with the
    same dtype as `image`.
  """
  # Pure NumPy: going through tf.transpose only produced a TF tensor that
  # np.expand_dims immediately converted back to NumPy.
  channels_first = np.transpose(np.asarray(image), (2, 0, 1))
  # Return a dense copy rather than a strided view of the caller's array.
  return np.ascontiguousarray(np.expand_dims(channels_first, axis=0))

def preprocess_image(file_path):
  """Loads an image file and turns it into a channels-first float32 input.

  The image is resized to MODEL_INPUT_HW, cast to float32 and divided by 255
  (assumes 8-bit pixel values — confirm for other inputs) before being
  rearranged by make_channels_first.

  Args:
    file_path: Path of the image file to read.

  Returns:
    Whatever make_channels_first returns for the scaled image.
  """
  raw_pixels = skimage.io.imread(file_path)
  resized = tf.image.resize(raw_pixels, MODEL_INPUT_HW)
  scaled = resized.numpy().astype(np.float32) / 255.0
  return make_channels_first(scaled)

def preprocess_image_ai_edge_torch(test_image_path):
  """Loads an image as the batched NHWC float32 input of the converted model.

  Pixel values are kept in [0, 255]; no scaling is done here.

  Args:
    test_image_path: Path of the image file to read.

  Returns:
    A float32 array of shape (1, height, width, 3), where (height, width) is
    MODEL_INPUT_HW.
  """
  height, width = MODEL_INPUT_HW
  # `with` closes the file handle deterministically. convert('RGB') forces three
  # channels so grayscale/RGBA/palette images still yield a (..., 3) array.
  with Image.open(test_image_path) as image:
    # PIL's resize expects (width, height), the reverse of MODEL_INPUT_HW.
    resized = image.convert('RGB').resize(
        (width, height), Image.Resampling.BILINEAR
    )
  test_image = np.asarray(resized, dtype=np.float32)
  return np.expand_dims(test_image, axis=0)

def run_segmentation(image, model_path, output_index=0):
  """Runs a LiteRT segmentation model and returns a mask scaled to [0, 255].

  Args:
    image: Input array for the model's first input tensor; it must already
      match that tensor's shape and dtype.
    model_path: Path to the .tflite model to execute.
    output_index: Index of the model output used as the mask (defaults to the
      first output).

  Returns:
    The squeezed output tensor, min-max normalized and scaled to [0, 255]. A
    constant output (max == min) yields an all-zero mask instead of NaNs.
  """
  interpreter = Interpreter(model_path=model_path)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()[0]
  interpreter.set_tensor(input_details['index'], image)
  interpreter.invoke()

  # Only fetch the output that is actually used.
  output_details = interpreter.get_output_details()[output_index]
  mask = tf.squeeze(interpreter.get_tensor(output_details['index']))

  # Min-max normalization.
  mask_min = np.min(mask)
  mask_range = np.max(mask) - mask_min
  if mask_range == 0:
    # Degenerate (constant) output: avoid 0 / 0 -> NaN.
    return tf.zeros_like(mask)
  mask = (mask - mask_min) / mask_range
  # Scale [0, 1] -> [0, 255].
  return mask * 255


def draw_segmentation(image, float_mask, quant_mask, info):
  """Plots the input image next to the float and quantized model masks.

  Args:
    image: The input image (anything np.array can convert).
    float_mask: Mask produced by the float model.
    quant_mask: Mask produced by the quantized model.
    info: Label describing the quantized model, shown in the third title.
  """
  _, ax = plt.subplots(1, 3, figsize=(15, 10))

  ax[0].imshow(np.array(image))
  ax[1].imshow(np.array(float_mask), cmap='gray')
  ax[2].imshow(np.array(quant_mask), cmap='gray')

  # One title per panel; the input image lives on ax[0].
  ax[0].set_title('Image')
  ax[1].set_title('Float Mask')
  ax[2].set_title('Quant Mask: {}'.format(info))

  plt.show()

def save_model(model_content, save_path):
  """Writes serialized model bytes to `save_path`, overwriting any existing file.

  Args:
    model_content: Bytes-like model content (e.g. a bytearray).
    save_path: Destination file path.
  """
  output_file = open(save_path, 'wb')
  try:
    output_file.write(model_content)
  finally:
    output_file.close()



In [None]:
# Download the sample input image; the Accept header asks the GitHub contents
# API for the raw file, and -O saves it as input_image.jpg.
!curl -H 'Accept: application/vnd.github.v3.raw'  -O   -L https://api.github.com/repos/google-ai-edge/ai-edge-quantizer/contents/colabs/test_data/input_image.jpg

IMAGE_PATH = 'input_image.jpg'

# Unresized image, passed to draw_segmentation for display.
image = Image.open(IMAGE_PATH)
# Batched float32 model input.
test_image = preprocess_image_ai_edge_torch(IMAGE_PATH)

# Getting a LiteRT model from PyTorch

Our first step is to convert a PyTorch model to a float LiteRT model (which will be the input to AI Edge Quantizer).

In [None]:
# Start from a clean state in /content.
%cd /content
!rm -rf DIS sample_data

# The DIS repository provides the IS-Net model definition (`models.ISNetDIS`).
!git clone https://github.com/xuebinqin/DIS.git
%cd DIS/IS-Net/

# Download and unpack the pretrained weights; the archive is expected to
# provide isnet-general-use.pth, which the next cell loads.
!curl -o ./model.tar.gz -L https://www.kaggle.com/api/v1/models/paulruiz/dis/pyTorch/8-17-22/1/download
!tar -xvf 'model.tar.gz'

In [None]:
import torch
from models import ISNetDIS

pytorch_model_filename = 'isnet-general-use.pth'
pt_model = ISNetDIS()
# The checkpoint comes from an internet download; weights_only=True restricts
# unpickling to tensors and plain containers so loading cannot execute code.
pt_model.load_state_dict(
    torch.load(
        pytorch_model_filename,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)

import torch
from torch import nn
from torchvision.transforms.functional import normalize


class ImageSegmentationModelWrapper(nn.Module):
  """Adapts a segmentation model to NHWC input in [0, 255] and NHWC output.

  Rescaling and normalization happen inside forward(), so the converted model
  is fed raw pixel values.
  """

  # Divisor mapping pixel values from [0, 255] to [0, 1].
  RESCALING_FACTOR = 255.0
  # Normalization applied after rescaling: (x - MEAN) / STD.
  MEAN = 0.5
  STD = 1.0

  def __init__(self, pt_model):
    """Stores the wrapped model.

    Args:
      pt_model: Module whose output element [0][0] is used as the mask (for
        ISNetDIS, presumably the first side output — confirm in the DIS repo).
    """
    super().__init__()
    self.model = pt_model

  def forward(self, image: torch.Tensor):
    """Runs the wrapped model on an NHWC image batch.

    Args:
      image: Tensor laid out as (batch, height, width, channels) with values
        in [0, 255].

    Returns:
      Output element [0][0] of the wrapped model, permuted from
      (batch, channels, height, width) to (batch, height, width, channels).
    """
    # BHWC -> BCHW, the layout the PyTorch model consumes.
    image = image.permute(0, 3, 1, 2)

    # Rescale [0, 255] -> [0, 1].
    image = image / self.RESCALING_FACTOR

    # Normalize.
    image = (image - self.MEAN) / self.STD

    # The model returns nested outputs; take the first element of the first group.
    result = self.model(image)[0][0]

    # BCHW -> BHWC.
    result = result.permute(0, 2, 3, 1)

    return result


# Wrap the loaded model and switch it to inference mode before conversion.
wrapped_pt_model = ImageSegmentationModelWrapper(pt_model)
wrapped_pt_model.eval();

In [None]:
# @title Convert torch model to LiteRT using AI Edge Torch

import ai_edge_torch

# Example input matching the wrapper's NHWC signature: (1, H, W, 3).
sample_args = (torch.rand((1, *MODEL_INPUT_HW, 3)),)
edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)
# Ensure the output directory exists; it is otherwise only present if the
# downloaded archive happens to create it.
os.makedirs('model', exist_ok=True)
edge_model.export('model/isnet_float.tflite')

# AI Edge Quantizer

To use the `Quantizer`, we need to provide:
* the float `.tflite` model, and
* a quantization recipe (i.e., apply quantization algorithm X on operator Y with configuration Z).






### Quantizing the model with dynamic quantization


The following example shows how to produce a dynamically quantized model with AI Edge Quantizer.

In [None]:
from ai_edge_quantizer import recipe

# Build the predefined recipe (per its name: int8 weights, float32 activations),
# attach it to a quantizer for the float model, then quantize and export.
dynamic_recipe = recipe.dynamic_wi8_afp32()
quantizer = ai_edge_quantizer.Quantizer(float_model='model/isnet_float.tflite')
quantizer.load_quantization_recipe(recipe=dynamic_recipe)

quantization_result = quantizer.quantize()
quantization_result.export_model('model/isnet_dynamic_wi8_afp32.tflite')

`quantization_result` has two components:

* the quantized LiteRT model (as a bytearray), and
* the corresponding quantization recipe.

Let's take a look at what is in this recipe.

In [None]:
# Display the recipe that was applied (the last expression is the cell output).
quantization_result.recipe

Here the recipe means: apply the naive min/max uniform algorithm (`min_max_uniform_quantize`) to all ops supported by the AI Edge Quantizer (indicated by `*`) in layers matching the regex `.*` (i.e., all layers). We want the weights of these ops to be quantized as int8, symmetric, channel-wise, and we want to execute the ops in `Integer` mode.


Now let's try running both the float model and the newly quantized model and see how they compare.

In [None]:
# Segment the same input with the float and the dynamically quantized model.
float_mask = run_segmentation(test_image, 'model/isnet_float.tflite')
quantized_mask = run_segmentation(
    test_image, 'model/isnet_dynamic_wi8_afp32.tflite'
)
draw_segmentation(image, float_mask, quantized_mask, 'Dynamic_wi8_afp32')

# Debug through Model Explorer (visualization)

Now we can see that float execution gives a better-quality result at a larger model size, while dynamic quantization gives a smaller model whose quality can be worse.

Let's try to understand where dynamic quantization is introducing precision loss to see if we can do better.

The following code generates a tensor-by-tensor comparison between the dynamically quantized model and the original float model.



In [None]:
#@title Parameter to visualize LiteRT model
# Rendered as a Colab checkbox; set to False to skip launching Model Explorer.
visualize_model = True  # @param {type:"boolean"}

*Note*: `quantizer.validate` is a memory-intensive operation and might cause out-of-memory errors in the cell below.

In [None]:
# Tensor-by-tensor comparison of the quantized model against the float model
# on the sample image. Keep the comparison object itself in
# `comparison_result`; `save` is called separately for its file output.
comparison_result = quantizer.validate(
    test_data={'serving_default': [{'args_0': test_image}]},
    error_metrics='median_diff_ratio',
    use_xnnpack=False,
    num_threads=1,
)
# Expected to write the 'dynamic' comparison JSON files into the current
# directory (the next cell reads dynamic_comparison_result_me_input.json).
comparison_result.save('', 'dynamic')

In [None]:
if visualize_model:
  # Overlay the per-tensor comparison data on the quantized model graph.
  explorer_config = (
      model_explorer.config()
      .add_model_from_path('model/isnet_dynamic_wi8_afp32.tflite')
      .add_node_data_from_path('dynamic_comparison_result_me_input.json')
  )
  model_explorer.visualize_from_config(explorer_config)

Using Model Explorer, we find that the errors come from the last few layers ('RSU6_stage2d', 'RSU7_stage1d', 'Conv2d_side1'). Let's try not quantizing them.

# Selective Dynamic Quantization

Here we'll override the original `dynamic_wi8_afp32` recipe to skip the three scopes that produce inaccurate results. Note that for each scope, the newly added rule always takes precedence.

In [None]:
# For each scope regex, add a rule that excludes matching CONV_2D ops from
# quantization ('no_quantize'); newly added rules take precedence over the
# original dynamic_wi8_afp32 rule.
# NOTE(review): 'RSU6' / 'RSU7' are broader than the scope names identified in
# Model Explorer ('RSU6_stage2d', 'RSU7_stage1d') and may also match other
# RSU6/RSU7 blocks — confirm against how update_quantization_recipe matches regexes.
scopes = ['RSU6', 'RSU7', 'Conv2d_side1']
for scope in scopes:
  quantizer.update_quantization_recipe(
      regex=scope,
      operation_name='CONV_2D',
      algorithm_key='no_quantize',
  )
# Display the updated recipe.
quantizer.get_quantization_recipe()

In [None]:
# Re-quantize with the updated recipe and compare against the float mask.
selective_model_path = 'model/isnet_selective_dynamic_wi8_afp32.tflite'
quantizer.quantize().export_model(selective_model_path)
quantized_mask = run_segmentation(test_image, selective_model_path)
draw_segmentation(image, float_mask, quantized_mask, 'Selective Dynamic')

In [None]:
# Compare on-disk sizes of the exported .tflite models.
!ls -lh model/*.tflite