In [None]:
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

This Colab shows how to take a PyTorch model, convert it to LiteRT using AI Edge Torch, and then quantize it with AI Edge Quantizer. More details on converting PyTorch models to LiteRT are available at https://ai.google.dev/edge/litert/models/convert_pytorch

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-ai-edge/ai-edge-quantizer/blob/main/colabs/torch_convert_and_quantize.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google-ai-edge/ai-edge-quantizer/blob/main/colabs/torch_convert_and_quantize.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

In [None]:
# When running in google colab the pre-installed versions of some packages
# might be incompatible with AI edge libraries.
!pip uninstall -y tensorflow jax jaxlib
!pip install ai-edge-torch-nightly
!pip install ai-edge-quantizer-nightly
!pip install ai-edge-model-explorer

In [None]:
import os

import ai_edge_torch
import model_explorer
import numpy as np
import torch
import torchvision

from ai_edge_quantizer import quantizer
from ai_edge_quantizer import recipe

In [None]:
#@title Parameter to visualize LiteRT model
# When True, the conversion cell below opens the exported float model in
# Model Explorer; set to False to skip visualization. The `#@param` annotation
# exposes this flag as a checkbox in the Colab form.
visualize_model = True  #@param {type:"boolean"}

In [None]:
# Load a pretrained ResNet-18 and switch it to inference mode before
# conversion. `weights` is passed by keyword: torchvision >= 0.13 declares it
# keyword-only and accepts a positional value only through a deprecation shim.
resnet18 = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
).eval()

# Seed the RNG so the random sample input (and therefore the comparison
# below) is the same on every run.
torch.manual_seed(0)
sample_inputs = (torch.randn(1, 3, 224, 224),)
torch_output = resnet18(*sample_inputs)

# Conversion
edge_model = ai_edge_torch.convert(resnet18, sample_inputs)

# Inference
edge_output = edge_model(*sample_inputs)

# Validation: the converted float model should reproduce the PyTorch output.
if np.allclose(torch_output.detach().numpy(), edge_output, atol=1e-5):
    print("Inference result with Pytorch and LiteRT was within tolerance")
else:
    print("Something wrong with Pytorch --> LiteRT")

# Serialization. Nothing earlier in this notebook creates `model/`, so make
# sure it exists before writing into it (a no-op if it is already there).
MODEL_DIR = 'model'
os.makedirs(MODEL_DIR, exist_ok=True)
float_model_path = os.path.join(MODEL_DIR, 'resnet.tflite')
quantized_model_path = os.path.join(MODEL_DIR, 'resnet_quantized.tflite')
edge_model.export(float_model_path)

# Model Explorer Visualization
if visualize_model:
    model_explorer.visualize(float_model_path)

# Quantization with the `dynamic_wi8_afp32` recipe (per its name: int8
# weights, float32 activations), then save the result as a .tflite flatbuffer.
qt = quantizer.Quantizer(float_model_path, recipe.dynamic_wi8_afp32())
# Keep the quantization result itself rather than whatever `export_model`
# returns. `overwrite=True` lets this cell be re-run without tripping over the
# file written by a previous run.
quant_result = qt.quantize()
quant_result.export_model(quantized_model_path, overwrite=True)

## Compare the sizes of the float and quantized flatbuffers

In [None]:
!ls -lh model/*.tflite