In [None]:
import torch
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

# PT2E (torch.ao)
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

# AI Edge Torch
import ai_edge_torch as aet
from ai_edge_torch.quantize import pt2e_quantizer as aet_q
from ai_edge_torch.quantize import quant_config as aet_qc

# Seed the global RNG so the example input and the random calibration batches
# below are identical on every Restart Kernel -> Run All.
torch.manual_seed(0)

# Number of forward passes used to populate the PT2E observers.
NUM_CALIBRATION_BATCHES = 32

# Float reference model, switched to eval mode before export.
m = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
example_inputs = (torch.randn(1,3,224,224),)

# 1) Capture the ATen graph. Note that `.module()` is called on the exported
#    program, so `ep` is the callable graph module handed to prepare_pt2e(),
#    not the ExportedProgram object itself.
ep = torch.export.export(m, example_inputs).module()  # 2.6+ API

# 2) Configure an AET PT2E quantizer (symmetric, per-channel)
qspec = aet_q.get_symmetric_quantization_config(is_per_channel=True)
quantizer = aet_q.PT2EQuantizer().set_global(qspec)

# 3) Prepare + calibrate
# NOTE(review): calibration uses random noise, so the observed activation
# ranges are unlikely to match real images; swap in normalized real samples
# before trusting accuracy numbers.
prepared = prepare_pt2e(ep, quantizer)
with torch.no_grad():
    for _ in range(NUM_CALIBRATION_BATCHES):
        prepared(torch.randn(1, 3, 224, 224))

# 4) Convert (keep Q/DQ explicit for TFLite lowering)
quantized = convert_pt2e(prepared, fold_quantize=False)

print(quantized)

# 5) Convert to TFLite; the same quantizer instance is passed again through
#    QuantConfig so the converter knows how the graph was quantized.
edge_model = aet.convert(
    quantized,
    example_inputs,
    quant_config=aet_qc.QuantConfig(pt2e_quantizer=quantizer),
)
edge_model.export("mobilenetv2_int8.tflite")


<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  prepared = prepare_pt2e(ep, quantizer)
For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (

GraphModule(
  (features): Module(
    (0): Module(
      (0): Module()
    )
    (1): Module(
      (conv): Module(
        (0): Module(
          (0): Module()
        )
        (1): Module()
      )
    )
    (2): Module(
      (conv): Module(
        (0): Module(
          (0): Module()
        )
        (1): Module(
          (0): Module()
        )
        (2): Module()
      )
    )
    (3): Module(
      (conv): Module(
        (0): Module(
          (0): Module()
        )
        (1): Module(
          (0): Module()
        )
        (2): Module()
      )
    )
    (4): Module(
      (conv): Module(
        (0): Module(
          (0): Module()
        )
        (1): Module(
          (0): Module()
        )
        (2): Module()
      )
    )
    (5): Module(
      (conv): Module(
        (0): Module(
          (0): Module()
        )
        (1): Module(
          (0): Module()
        )
        (2): Module()
      )
    )
    (6): Module(
      (conv): Module(
        (0): 

W0000 00:00:1757303413.778856    6423 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1757303413.778882    6423 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
I0000 00:00:1757303413.779160    6423 reader.cc:83] Reading SavedModel from: /tmp/tmp6ohyr2g1
I0000 00:00:1757303413.783841    6423 reader.cc:52] Reading meta graph with tags { serve }
I0000 00:00:1757303413.783866    6423 reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmp6ohyr2g1
I0000 00:00:1757303413.820058    6423 loader.cc:236] Restoring SavedModel bundle.
I0000 00:00:1757303414.158779    6423 loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmp6ohyr2g1
I0000 00:00:1757303414.245415    6423 loader.cc:471] SavedModel load for tags { serve }; Status: success: OK. Took 466282 microseconds.
I0000 00:00:1757303415.540812    6423 flatbuffer_export.cc:4150] Estimated count of arithmetic ops: 608.445 M  ops, equivalently 304.223 M  MACs


In [42]:
import numpy as np
import torch
import torch.nn.functional as F

# 1) Build a single test input and the float reference output.
#    (Random noise here; prefer real, normalized data when available.)
x = torch.randn(1, 3, 224, 224)
m.eval()
with torch.no_grad():
    logits_fp = m(x)
y_fp = logits_fp.cpu().numpy()

# 2) Load the exported TFLite model and allocate its tensors.
import tensorflow as tf  # or tflite_runtime.interpreter
interpreter = tf.lite.Interpreter(model_path="mobilenetv2_int8.tflite")
interpreter.allocate_tensors()

# Only the first input and the first output are used.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
inp = input_details[0]
out = output_details[0]

# float32 NumPy copy of the test input for the interpreter.
x_np = np.array(x.cpu().numpy(), dtype=np.float32)

# Handle quantized or float I/O automatically
def set_input(interpreter, detail, x_float):
    """Feed a float array to an interpreter input, quantizing when required.

    Args:
        interpreter: Object exposing ``set_tensor(index, value)``.
        detail: Input detail dict. Reads ``"dtype"`` and ``"index"``, and for
            non-float dtypes ``"quantization"`` as ``(scale, zero_point)``.
        x_float: Float array holding the real-valued input.

    Float inputs are passed through unchanged. Integer inputs are quantized as
    ``round(x / scale + zero_point)`` and clipped to the dtype's range. If the
    scale is 0 (TFLite's convention for "no quantization parameters"), the
    values are cast to the target dtype directly instead of dividing by zero.
    """
    dtype = detail["dtype"]
    index = detail["index"]

    if np.issubdtype(dtype, np.floating):
        interpreter.set_tensor(index, x_float)
        return

    scale, zero = detail["quantization"]
    if scale == 0:
        # No quantization parameters: x / 0 would produce inf/nan, so the
        # only meaningful conversion is a plain cast.
        interpreter.set_tensor(index, np.asarray(x_float).astype(dtype))
        return

    x_q = np.round(x_float / scale + zero)
    limits = np.iinfo(dtype)
    # Saturate instead of wrapping around when the value exceeds the int range.
    x_q = np.clip(x_q, limits.min, limits.max).astype(dtype)
    interpreter.set_tensor(index, x_q)

def get_output(interpreter, detail):
    """Read an interpreter output, dequantizing it to float32 when required.

    Args:
        interpreter: Object exposing ``get_tensor(index)``.
        detail: Output detail dict. Reads ``"index"`` and ``"dtype"``, and for
            non-float dtypes ``"quantization"`` as ``(scale, zero_point)``.

    Returns:
        The tensor as returned by the interpreter for float outputs and for
        integer outputs with no quantization parameters (scale == 0);
        otherwise ``(y - zero_point) * scale`` as float32.
    """
    y = interpreter.get_tensor(detail["index"])
    if np.issubdtype(detail["dtype"], np.floating):
        return y

    scale, zero = detail["quantization"]
    if scale == 0:
        # No quantization parameters: multiplying by a zero scale would
        # silently wipe the result to all zeros, so return the raw values.
        return y
    return (y.astype(np.float32) - zero) * scale

# Run one inference through the TFLite interpreter.
set_input(interpreter, inp, x_np)
interpreter.invoke()
y_tfl = get_output(interpreter, out)

# 3) Metrics (PyTorch vs TFLite)
# Mean squared error over all output elements.
mse = np.mean((y_fp - y_tfl) ** 2)

# Row-wise cosine similarity, averaged; the epsilon guards a zero norm.
dot_per_row = np.sum(y_fp * y_tfl, axis=1)
norm_product = np.linalg.norm(y_fp, axis=1) * np.linalg.norm(y_tfl, axis=1) + 1e-12
cos = np.mean(dot_per_row / norm_product)

# Fraction of rows where both models pick the same arg-max index.
top1_pt  = y_fp.argmax(axis=1)
top1_tfl = y_tfl.argmax(axis=1)
agree = float(np.mean(top1_pt == top1_tfl))

metrics = {"mse": mse, "cosine": cos, "top1_agree": agree}
print(metrics)


{'mse': np.float32(0.011216138), 'cosine': np.float32(0.9931354), 'top1_agree': 1.0}
