In [None]:
# Dependencies for this notebook (uncomment when running in a fresh environment).
# `%pip` installs into the running kernel's environment, unlike `!pip`.
# TODO: pin versions (`pkg==X.Y.Z`) so the conversion results are reproducible.
# NOTE(review): `tensorflow` is imported in the next cell but is not installed here;
# presumably it is preinstalled in the target environment — confirm.
# %pip install ai-edge-torch-nightly
# %pip install torchao
# %pip install ai-edge-model-explorer
# %pip install ai-edge-litert

In [None]:
import os
import shutil
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import random
from tqdm import tqdm

# Seed every RNG the notebook relies on so Restart & Run All reproduces results
# (the random sample input below comes from torch.randn).
seed = 42
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # Seed the current CUDA device and then every visible GPU explicitly.
    for _seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_cuda(seed)

from ai_edge_litert.interpreter import Interpreter
import tensorflow as tf

In [None]:
# Load SqueezeNet 1.1 with pretrained weights from a pinned torchvision hub tag.
# `.eval()` switches the module to inference mode (affects dropout / batch-norm layers)
# and returns the same module, so it can be chained onto the load.
model = torch.hub.load('pytorch/vision:v0.10.0', 'squeezenet1_1', pretrained=True).eval()
model

In [None]:
import ai_edge_torch
import numpy
import torch
import torchvision

In [None]:
# One random example batch of shape (N=1, C=3, H=224, W=224). It is used both as the
# example input for conversion and as the input for the PyTorch-vs-edge comparison.
sample_inputs = (torch.randn(1, 3, 224, 224),)

# Reference forward pass. Gradients are never needed here (the comparison detaches
# the output), so skip building the autograd graph.
with torch.no_grad():
    torch_output = model(*sample_inputs)

In [None]:
# Make sure the module is in inference mode (it already is; eval() is idempotent),
# then convert it with ai-edge-torch using the sample inputs as example arguments.
model.eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)

In [None]:
# Run the converted edge model on the exact same inputs as the PyTorch reference pass,
# so the two outputs can be compared element-wise.
edge_args = sample_inputs
edge_output = edge_model(*edge_args)

In [None]:
# Numerically compare the PyTorch reference output with the converted model's output.
# Tight tolerances: a float32 conversion is expected to be near-lossless.
ATOL = 1e-5
RTOL = 1e-5

torch_output_np = torch_output.detach().numpy()
edge_output_np = numpy.asarray(edge_output)

# numpy.allclose broadcasts its arguments, so e.g. (1, 1000) vs (1000,) or a
# wrapped (1, 1, 1000) result could silently "pass"; require identical shapes first.
shapes_match = torch_output_np.shape == edge_output_np.shape

if shapes_match and numpy.allclose(torch_output_np, edge_output_np, atol=ATOL, rtol=RTOL):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")
    # Give the reader something actionable instead of a bare failure message.
    if shapes_match:
        max_abs_diff = numpy.abs(torch_output_np - edge_output_np).max()
        print(f"  max abs diff: {max_abs_diff:.3e} (atol={ATOL}, rtol={RTOL})")
    else:
        print(f"  shape mismatch: torch {torch_output_np.shape} vs edge {edge_output_np.shape}")


In [None]:
# Serialize the converted edge model to a .tflite file in the working directory;
# the interpreter and Model Explorer cells below read this same file.
export_path = './squeezenet.tflite'
edge_model.export(export_path)

In [None]:
# Load the exported model with the standalone LiteRT interpreter.
interpreter = Interpreter(model_path='./squeezenet.tflite')

# Tensor buffers must be allocated before the interpreter is invoked or its
# tensors are read/written.
interpreter.allocate_tensors()

In [None]:
# Inspect the converted model: list every tensor, then tally how many tensor
# elements are stored quantized vs. unquantized.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Details of every tensor in the model (constants, activations, inputs, outputs).
tensor_details = interpreter.get_tensor_details()

for tensor in tensor_details:
    print(f"Tensor Name: {tensor['name']}, Type: {tensor['dtype']}, Quantization: {tensor['quantization']}")


def _is_quantized(detail):
    """Return True if a TFLite tensor-detail dict carries quantization parameters.

    Unquantized tensors report an empty ``scales`` array (legacy ``quantization``
    tuple ``(0.0, 0)``). Quantized tensors carry at least one scale.
    Checking the parameters rather than only ``dtype == int8`` also counts
    uint8 / int16 tensors, int32 biases, and per-channel schemes.
    """
    params = detail.get('quantization_parameters')
    if params is not None:
        return len(params['scales']) > 0
    # Fallback for interpreters that only expose the legacy (scale, zero_point) tuple.
    scale, _zero_point = detail['quantization']
    return scale != 0.0


# NOTE: these totals cover *all* tensors (including activations and I/O),
# not only trainable weights.
quantized_count = 0
non_quantized_count = 0
for tensor in tensor_details:
    # numpy.prod of an empty shape (scalar tensor) is 1; int64 avoids int32 overflow.
    num_elements = int(numpy.prod(tensor['shape'], dtype=numpy.int64))
    if _is_quantized(tensor):
        quantized_count += num_elements
    else:
        non_quantized_count += num_elements

print(f'Total quantized parameters: {quantized_count}')
print(f'Total non-quantized parameters: {non_quantized_count}')

In [None]:
import model_explorer
# Open the exported .tflite file in the Model Explorer graph viewer.
explorer_model_file = 'squeezenet.tflite'
model_explorer.visualize(explorer_model_file)

In [None]:
# Reference (quantization with AI Edge Torch):
# https://medium.com/axinc-ai/quantization-with-ai-edge-torch-1efe17b93cd7