In [None]:
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-ai-edge/ai-edge-quantizer/blob/main/colabs/getting_started.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google-ai-edge/ai-edge-quantizer/blob/main/colabs/getting_started.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

In [None]:
!pip install ai-edge-quantizer-nightly
!pip install ai-edge-model-explorer
!pip install ai-edge-litert-nightly

## Imports

In [None]:
import logging
import numpy as np

import matplotlib.pylab as plt
import pathlib
import random
import json

import numpy as np
import model_explorer

from ai_edge_litert.interpreter import Interpreter
import tensorflow as tf

from ai_edge_quantizer import quantizer
from ai_edge_quantizer import recipe
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

## Create and train an MNIST classifier in Keras

In [None]:
# Seed Python, NumPy and TensorFlow RNGs so weight initialization and data
# shuffling are repeatable across Restart & Run All (some GPU kernels may
# still be nondeterministic).
SEED = 0
tf.keras.utils.set_random_seed(SEED)

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1, and
# reshape to (N, 28, 28, 1) so each image has the channel axis Conv2D expects.
train_images = train_images.astype(np.float32) / 255.0
train_images = train_images.reshape([-1, 28, 28, 1])
test_images = test_images.astype(np.float32) / 255.0
test_images = test_images.reshape([-1, 28, 28, 1])

num_classes = 10
hidden_dim = 32

# Declare the input with an explicit Input layer; passing `input_shape` to the
# first layer is deprecated in Keras 3.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(
        hidden_dim // 4,
        3,
        activation="relu",
        padding="same",
        use_bias=True,
    ),
    tf.keras.layers.AveragePooling2D(pool_size=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(hidden_dim, activation="relu", use_bias=True),
    # Softmax output: the model emits class probabilities, not logits.
    tf.keras.layers.Dense(num_classes, use_bias=False, activation="softmax"),
])

# Train the digit classification model. The last layer already applies
# softmax, so the loss receives probabilities (from_logits=False);
# from_logits=True would apply softmax a second time inside the loss.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(
    train_images,
    train_labels,
    epochs=5,
    validation_data=(test_images, test_labels),
)

## Helper functions

In [None]:
def run_model(model_path, test_image_indices, images=None):
  """Runs a LiteRT model on selected images and returns the predicted classes.

  Args:
    model_path: Path (str or path-like) of the .tflite model to load.
    test_image_indices: Sequence of indices into `images` to run inference on.
    images: Optional indexable collection of float input images. Defaults to
      the notebook-global `test_images`.

  Returns:
    A 1-D int array holding `argmax` of the model output for each requested
    image, in the same order as `test_image_indices`.
  """
  if images is None:
    images = test_images

  # Initialize the interpreter.
  interpreter = Interpreter(model_path=str(model_path))
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()[0]
  print(f"input details: {input_details}")
  output_details = interpreter.get_output_details()[0]

  input_dtype = input_details["dtype"]
  input_scale, input_zero_point = input_details["quantization"]
  # An integer input with a non-zero scale belongs to a quantized model: float
  # pixels must be mapped onto the integer grid before being fed in.
  quantize_input = np.issubdtype(input_dtype, np.integer) and input_scale != 0
  if quantize_input:
    dtype_info = np.iinfo(input_dtype)

  predictions = np.zeros((len(test_image_indices),), dtype=int)
  for i, test_image_index in enumerate(test_image_indices):
    test_image = images[test_image_index]

    if quantize_input:
      # Round to nearest and saturate to the dtype's range; a bare astype()
      # truncates toward zero and may wrap out-of-range values.
      test_image = np.clip(
          np.round(test_image / input_scale + input_zero_point),
          dtype_info.min,
          dtype_info.max,
      )

    test_image = np.expand_dims(test_image, axis=0).astype(input_dtype)
    interpreter.set_tensor(input_details["index"], test_image)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details["index"])[0]

    predictions[i] = output.argmax()

  return predictions

def test_model(model_path, test_image_index, model_type):
  global test_labels

  predictions = run_model(model_path, [test_image_index])

  plt.imshow(test_images[test_image_index])
  template = model_type + " Model \n True:{true}, Predicted:{predict}"
  _ = plt.title(template.format(true= str(test_labels[test_image_index]), predict=str(predictions[0])))
  plt.grid(False)

def evaluate_model(model_path, model_type):
  """Runs a LiteRT model on every test image and prints its accuracy.

  Args:
    model_path: Path of the .tflite model, forwarded to `run_model`.
    model_type: Label printed in front of the accuracy, e.g. "Float".
  """
  num_samples = len(test_images)
  predictions = run_model(model_path, range(test_images.shape[0]))

  num_correct = np.sum(test_labels == predictions)
  accuracy = (num_correct * 100) / num_samples

  print('%s model accuracy is %.4f%% (Number of test samples=%d)' % (
      model_type, accuracy, num_samples))

## Convert to flatbuffer and visualize float model

In [6]:
#@title Parameter to visualize LiteRT model
# When True, the conversion and quantization cells below open each model in
# Model Explorer. Exposed as a Colab form checkbox via `#@param`.
visualize_model = True  #@param {type:"boolean"}

In [None]:
# Serialize the trained Keras model into a LiteRT flatbuffer on disk.
model_path = "mnist_model.tflite"
converter = tf.lite.TFLiteConverter.from_keras_model(model)
litert_model = converter.convert()
pathlib.Path(model_path).write_bytes(litert_model)

# Optionally inspect the float graph in Model Explorer.
if visualize_model:
  model_explorer.visualize(model_path)

## Create a dynamically quantized LiteRT model with AI Edge Quantizer

In [None]:
dynamic_quant_mnist_model_path = "mnist_model_quantized.tflite"

# Dynamic-range quantization recipe (presumably int8 weights with float32
# activations, going by the name `dynamic_wi8_afp32`).
qt = quantizer.Quantizer(model_path, recipe.dynamic_wi8_afp32())
# Keep the quantization result itself. `export_model` only writes the
# flatbuffer to disk, so chaining it would bind `quant_result` to its return
# value (None per the library's API) instead of the result object.
quant_result = qt.quantize()
quant_result.export_model(dynamic_quant_mnist_model_path)

if visualize_model:
  model_explorer.visualize(dynamic_quant_mnist_model_path)

## Sanity-check the float model on one image

In [None]:
# Index of the MNIST test image to inspect; change it to try another digit.
test_image_index = 1

# Run the float LiteRT model on that single image and plot its prediction.
test_model(
    model_path=model_path,
    test_image_index=test_image_index,
    model_type="Float",
)

## Sanity-check the dynamically quantized LiteRT model on the same image

In [None]:
# Same image as above, now classified by the dynamically quantized model.
test_model(
    model_path=dynamic_quant_mnist_model_path,
    test_image_index=test_image_index,
    model_type="Dynamic_wi8_afp32",
)

## Evaluate the models on all images

In [None]:
# Evaluate the float model, then the dynamically quantized model, on the whole
# test split.
for eval_model_path, eval_model_type in (
    (model_path, "Float"),
    (dynamic_quant_mnist_model_path, "Dynamic_wi8_afp32"),
):
  evaluate_model(eval_model_path, model_type=eval_model_type)

## Compare the sizes of the flatbuffers

In [None]:
# Compare the on-disk sizes of the float and dynamically quantized flatbuffers.
# pathlib works on every platform, unlike shelling out to `ls`.
float_model_size = pathlib.Path(model_path).stat().st_size
quant_model_size = pathlib.Path(dynamic_quant_mnist_model_path).stat().st_size
print(f"{model_path}: {float_model_size / 1024:.1f} KiB")
print(f"{dynamic_quant_mnist_model_path}: {quant_model_size / 1024:.1f} KiB")
print(f"Quantized model is {float_model_size / quant_model_size:.2f}x smaller")