In [1]:
import torch
class M(torch.nn.Module):
   def __init__(self):
      super().__init__()
      self.linear = torch.nn.Linear(5, 10)

   def forward(self, x):
      return self.linear(x)


example_inputs = (torch.randn(1, 5),)
m = M().eval()

# Step 1. program capture
# This is available for pytorch 2.6+, for more details on lower pytorch versions
# please check `Export the model with torch.export` section
m = torch.export.export(m, example_inputs).module()
# we get a model with aten ops


# Step 2. quantization
from torchao.quantization.pt2e.quantize_pt2e import (
  prepare_pt2e,
  convert_pt2e,
)

# install executorch: `pip install executorch`
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
  get_symmetric_quantization_config,
  XNNPACKQuantizer,
)
# backend developer will write their own Quantizer and expose methods to allow
# users to express how they
# want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)

# calibration omitted

m = convert_pt2e(m)
# we have a model with aten ops doing integer computations when possible

Skipping import of cpp extensions due to incompatible torch version 2.9.0+cu129 for torchao version 0.14.0         Please see GitHub issue #2919 for more info


In [17]:
import torch
from torch import nn
import copy

backend = "fbgemm"  # running on a x86 CPU. Use "qnnpack" if running on ARM.

model = nn.Sequential(
     nn.Conv2d(2,64,3),
     nn.ReLU(),
     nn.Conv2d(64, 128, 3),
     nn.ReLU()
)

## EAGER MODE
m = copy.deepcopy(model)
m.eval()
"""Fuse
- Inplace fusion replaces the first module in the sequence with the fused module, and the rest with identity modules
"""
# torch.quantization.fuse_modules(m, ['0','1'], inplace=True) # fuse first Conv-ReLU pair
# torch.quantization.fuse_modules(m, ['2','3'], inplace=True) # fuse second Conv-ReLU pair

"""Insert stubs"""
m = nn.Sequential(torchao.quantization.QuantStub(), 
                  *m, 
                  torchao.quantization.DeQuantStub())

"""Prepare"""
m.qconfig = torchao.quantization.get_default_qconfig(backend)
torchao.quantization.prepare(m, inplace=True)

"""Calibrate
- This example uses random data for convenience. Use representative (validation) data instead.
"""
with torch.inference_mode():
  for _ in range(10):
    x = torch.rand(1,2, 28, 28)
    m(x)
    
"""Convert"""
torchao.quantization.convert(m, inplace=True)

"""Check"""
print(m[1].weight().element_size()) # 1 byte instead of 4 bytes for FP32

AttributeError: module 'torchao.quantization' has no attribute 'QuantStub'

In [14]:
m[1].weight

<bound method Conv2d.weight of QuantizedConv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.018997007980942726, zero_point=71)>

In [None]:
import copy
import torch
device = "cpu"
dtype = torch.float32

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, k, bias=True)
        self.linear2 = torch.nn.Linear(k, n, bias=True)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (
            torch.randn(
                batch_size, self.linear1.in_features, dtype=dtype, device=device
            ),
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

m = ToyLinearModel(4, 4, 4).eval().to(dtype).to(device)
m = torch.compile(m, mode="max-autotune")

from torchao.quantization.granularity import PerAxis, PerTensor
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

# per tensor input activation asymmetric quantization
act_obs = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    torch.int8,
    granularity=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.float32,
    zero_point_domain=ZeroPointDomain.NONE
)

# per channel weight asymmetric quantization
weight_obs = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    torch.int8,
    granularity=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.float32,
    zero_point_domain=ZeroPointDomain.NONE
)

import torch.nn.functional as F

class ObservedLinear(torch.nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        bias: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: torch.Tensor):
        observed_input = self.act_obs(input)
        observed_weight = self.weight_obs(self.weight)
        return F.linear(observed_input, observed_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        observed_linear = cls(
            float_linear.in_features,
            float_linear.out_features,
            act_obs,
            weight_obs,
            True,
            device=float_linear.weight.device,
            dtype=float_linear.weight.dtype,
        )
        observed_linear.weight = float_linear.weight
        observed_linear.bias = float_linear.bias
        return observed_linear


from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
)

def insert_observers_(model, act_obs, weight_obs):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)

    def replacement_fn(m):
        copied_act_obs = copy.deepcopy(act_obs)
        copied_weight_obs = copy.deepcopy(weight_obs)
        return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs)

    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

insert_observers_(m, act_obs, weight_obs)

for _ in range(10):
    example_inputs = m.example_inputs(dtype=dtype, device=device)
    m(*example_inputs)

from torchao.dtypes import to_affine_quantized_intx_static

class QuantizedLinear(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        weight: torch.Tensor,
        bias: torch.Tensor,
        target_dtype: torch.dtype,
    ):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        block_size = (weight.shape[0], weight.shape[1])
        self.target_dtype = target_dtype
        self.bias = bias
        self.qweight = to_affine_quantized_intx_static(
            weight, weight_scale, weight_zero_point, block_size, self.target_dtype, zero_point_domain=ZeroPointDomain.NONE
        )

    def forward(self, input: torch.Tensor):
        block_size = input.shape
        qinput = to_affine_quantized_intx_static(
            input,
            self.act_scale,
            self.act_zero_point,
            block_size,
            self.target_dtype,
            zero_point_domain=ZeroPointDomain.NONE,
        )
        return F.linear(qinput, self.qweight, self.bias)

    def forward_int8(self, input: torch.Tensor):
        block_size = input.shape
        qinput = to_affine_quantized_intx_static(
            input,
            self.act_scale,
            self.act_zero_point,
            block_size,
            self.target_dtype,
            zero_point_domain=ZeroPointDomain.NONE,
        )
        int8_i, scale_i, _ = qinput.tensor_impl.get_plain()
        int8_w, scale_w, _ = self.qweight.tensor_impl.get_plain()
        print(F.linear(int8_i.to(torch.int32), int8_w.to(torch.int32)))
        return F.linear(int8_i.to(torch.int32), int8_w.to(torch.int32)).to(torch.float32) * (scale_i*scale_w) + self.bias

    @classmethod
    def from_observed(cls, observed_linear, target_dtype):
        quantized_linear = cls(
            observed_linear.in_features,
            observed_linear.out_features,
            observed_linear.act_obs,
            observed_linear.weight_obs,
            observed_linear.weight,
            observed_linear.bias,
            target_dtype,
        )
        return quantized_linear


from dataclasses import dataclass

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)

@dataclass
class StaticQuantConfig(AOBaseConfig):
    target_dtype: torch.dtype

@register_quantize_module_handler(StaticQuantConfig)
def _apply_static_quant(
    module: torch.nn.Module,
    config: StaticQuantConfig,
):
    """
    Define a transformation associated with `StaticQuantConfig`.
    This is called by `quantize_`, not by the user directly.
    """
    return QuantizedLinear.from_observed(module, config.target_dtype)

# filter function to identify which modules to swap
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

# perform static quantization
quantize_(m, StaticQuantConfig(torch.int8), is_observed_linear)

Skipping import of cpp extensions due to incompatible torch version 2.9.0+cu129 for torchao version 0.14.0         Please see GitHub issue #2919 for more info


In [40]:
x = torch.randn(1, 100, 4, dtype=dtype, device=device)
y = m.linear1.forward_int8(x)

tensor([[[-13782,   4494,   8964,  12220],
         [ 17631,   6600,  -1550,  -9179],
         [ -3563,  -2746,  -2081,   -352],
         [ -6578,    680,    790,   3260],
         [ -2544,  -1523,  -5267,  -2687],
         [     3,  -3224, -12286,  -6949],
         [-14733, -11205,   6339,  13776],
         [ -7560,  -1663,    130,  -1248],
         [ -8867,  11408,  -6545,  -2031],
         [  3933,  11864, -11642, -14015],
         [ 11648,  -2793,  -6540, -11410],
         [ -8015,  -3694,   -529,  -2376],
         [  1789,    606, -11456, -10634],
         [ -6238,  -7694, -16127,  -7283],
         [  5782,   9062,  -1450,  -5505],
         [  7691,    180,   5236,   4404],
         [ 12443,   3018,  -6990,  -9822],
         [-11288,  -4266,  -6744,   -902],
         [-11572,   3514,   6744,   9116],
         [  1184,  -4354,   2481,   2398],
         [ -8030,    977,   2732,   2441],
         [-10714,  -3620,  12765,  17869],
         [  7956, -13467,  -6927,  -6749],
         [ 

In [41]:
z = m.linear1.forward(x)
torch.allclose(y, z)

True

In [4]:
x = torch.tensor([-10, -0.0253, -0.0127, -0.0064, 0.0, 0.0064, 0.0127, 0.0253, 0.0380, 0.0506, 1.0, 10])
to_affine_quantized_intx_static(
    x,
    m.linear1.act_scale,
    m.linear1.act_zero_point,
    x.shape,
    m.linear1.target_dtype,
    zero_point_domain=ZeroPointDomain.NONE,
).tensor_impl.get_plain()

(tensor([-128,   -1,   -1,    0,    0,    0,    1,    1,    2,    2,   44,  127],
        dtype=torch.int8),
 tensor([0.0228]),
 None)

In [5]:
torch.round(x / m.linear1.act_scale)

tensor([-439.,   -1.,   -1.,   -0.,    0.,    0.,    1.,    1.,    2.,    2.,
          44.,  439.])

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="4"

In [55]:
import copy
import torch
device = "cpu"
dtype = torch.float32
dtype_q = torch.int8

class ToyModel(torch.nn.Module):
    def __init__(self, ci=64, ch=32, co=64, k=3):
        super().__init__()
        self.conv = torch.nn.Conv1d(ci, ch, k, padding=(k-1)//2, bias=True)
        self.relu = torch.nn.ReLU()
        self.linear = torch.nn.Linear(ch, co, bias=True)

    def set_int8_mode(self):
        self.conv.set_int8_mode()
        self.linear.set_int8_mode()

    def set_quantized_mode(self):
        self.conv.set_quantized_mode()
        self.linear.set_quantized_mode()

    def forward(self, x):
        """x: [B, Ci, T]"""
        x = self.conv(x)
        x = self.relu(x)
        x = x.transpose(1, 2)
        x = self.linear(x)
        return x

model = ToyModel(4, 4, 4).eval().to(dtype).to(device)
m = copy.deepcopy(model)
# m = torch.compile(m, mode="max-autotune")

from torchao.quantization.granularity import PerTensor
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.dtypes import to_affine_quantized_intx_static

# per tensor input activation asymmetric quantization
act_obs = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    dtype_q,
    granularity=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.float32,
    zero_point_domain=ZeroPointDomain.NONE
)

# per channel weight asymmetric quantization
weight_obs = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    dtype_q,
    granularity=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.float32,
    zero_point_domain=ZeroPointDomain.NONE
)

import torch.nn.functional as F

class QModule(torch.nn.Module):
    def __init__(
        self,
        module: torch.nn.Module,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
    ):
        super().__init__()
        self.module = module
        self.act_obs = act_obs
        self.weight_obs = weight_obs
        self.mode = "observe"
        self.act_scale = 0.0
        self.act_zero_point = None
        self.target_dtype = dtype_q

    def set_quantized_mode(self):
        if self.mode == "int8":
            self.mode = "quantized"
            return
        elif self.mode == "quantized":
            return
        self.mode = "quantized"
        self.act_scale, self.act_zero_point = self.act_obs.calculate_qparams()
        weight_scale, weight_zero_point = self.weight_obs.calculate_qparams()
        self.weight = to_affine_quantized_intx_static(
            self.module.weight,
            weight_scale,
            weight_zero_point,
            self.module.weight.shape,
            self.target_dtype,
            zero_point_domain=ZeroPointDomain.NONE
        )
        self.bias = torch.nn.Parameter(self.module.bias.data.clone())
        # delattr(self.module, "weight")
        # delattr(self.module, "bias")

    def set_int8_mode(self):
        assert self.mode == "quantized"
        self.mode = "int8"

    def forward(self, x: torch.Tensor):
        if self.mode == "int8":
            qinput = to_affine_quantized_intx_static(
                x,
                self.act_scale,
                self.act_zero_point,
                x.shape,
                self.target_dtype,
                zero_point_domain=ZeroPointDomain.NONE,
            )
            int8_i, scale_i, _ = qinput.tensor_impl.get_plain()
            int8_w, scale_w, _ = self.weight.tensor_impl.get_plain()
            return self._forward_int8(int8_i.to(torch.int32), int8_w.to(torch.int32), scale_i * scale_w)
        if self.mode == "observe":
            x = self.act_obs(x)
            w = self.weight_obs(self.module.weight)
            b = self.module.bias
        elif self.mode == "quantized":
            x = to_affine_quantized_intx_static(
                x,
                self.act_scale,
                self.act_zero_point,
                x.shape,
                self.target_dtype,
                zero_point_domain=ZeroPointDomain.NONE,
            )
            int8_i, scale_i, _ = x.tensor_impl.get_plain()
            int8_w, scale_w, _ = self.weight.tensor_impl.get_plain()
            x = int8_i.to(torch.float32) * scale_i
            w = int8_w.to(torch.float32) * scale_w
            b = self.bias
        return self.forward_module(x, w, b)

    def forward_module(self, x, w):
        raise NotImplementedError()


class QLinear(QModule):
    def forward_module(self, x, w, b):
        return F.linear(x, w, b)

    def _forward_int8(self, x: torch.Tensor, w: torch.Tensor, scale: float):
        return F.linear(x, w).to(torch.float32) * scale + self.bias


class QConv1d(QModule):
    def forward_module(self, x, w, b):
        m = self.module
        return F.conv1d(x, w, b, stride=m.stride, padding=m.padding, groups=m.groups, dilation=m.dilation)

    def _forward_int8(self, x: torch.Tensor, w: torch.Tensor, scale: float):
        m = self.module
        return F.conv1d(
            x, w, stride=m.stride, padding=m.padding, groups=m.groups, dilation=m.dilation
        ).to(torch.float32) * scale + self.bias.unsqueeze(1)


from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
)

def insert_observers_(model, act_obs, weight_obs):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
    _is_conv = lambda m, fqn: isinstance(m, torch.nn.Conv1d)

    def replacement_fn_linear(m):
        copied_act_obs = copy.deepcopy(act_obs)
        copied_weight_obs = copy.deepcopy(weight_obs)
        return QLinear(m, copied_act_obs, copied_weight_obs)

    def replacement_fn_conv(m):
        copied_act_obs = copy.deepcopy(act_obs)
        copied_weight_obs = copy.deepcopy(weight_obs)
        return QConv1d(m, copied_act_obs, copied_weight_obs)

    _replace_with_custom_fn_if_matches_filter(model, replacement_fn_linear, _is_linear)
    _replace_with_custom_fn_if_matches_filter(model, replacement_fn_conv, _is_conv)

insert_observers_(m, act_obs, weight_obs)

for _ in range(10):
    example_inputs = torch.randn(10, 4, 11, dtype=dtype, device=device)
    m(example_inputs)


from dataclasses import dataclass

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)

@dataclass
class StaticQuantConfig(AOBaseConfig):
    target_dtype: torch.dtype

@register_quantize_module_handler(StaticQuantConfig)
def _apply_static_quant(
    module: torch.nn.Module,
    config: StaticQuantConfig,
):
    """
    Define a transformation associated with `StaticQuantConfig`.
    This is called by `quantize_`, not by the user directly.
    """
    return module.set_quantized_mode()

# filter function to identify which modules to swap
is_observed = lambda m, fqn: isinstance(m, QModule)

# perform static quantization
quantize_(m, StaticQuantConfig(dtype_q), is_observed)

In [56]:
x = torch.randn(1, 4, 10, dtype=dtype, device=device)
y_orig = model(x)
m.set_quantized_mode()
y = m(x)
m.set_int8_mode()
z = m.forward(x)
print(y_orig, y, z)

tensor([[[-0.5158, -0.5071, -0.3130,  0.5694],
         [-0.4772, -0.0242, -0.2216,  0.5604],
         [-0.4678,  0.0796, -0.1654,  0.2828],
         [-0.5277, -0.0567, -0.1809,  0.3696],
         [-0.4450, -0.2705, -0.2657,  0.4772],
         [-0.4696, -0.3269, -0.2753,  0.4750],
         [-0.6127, -0.1233, -0.1858,  0.5636],
         [-0.5405,  0.0497, -0.1543,  0.3666],
         [-0.5064,  0.0536, -0.1598,  0.3175],
         [-0.4504, -0.1618, -0.2346,  0.3581]]], grad_fn=<AddBackward0>) tensor([[[-0.5223, -0.4976, -0.3084,  0.5709],
         [-0.4771, -0.0275, -0.2213,  0.5595],
         [-0.4680,  0.0768, -0.1654,  0.2805],
         [-0.5270, -0.0602, -0.1829,  0.3697],
         [-0.4441, -0.2729, -0.2650,  0.4781],
         [-0.4725, -0.3253, -0.2732,  0.4767],
         [-0.6136, -0.1215, -0.1863,  0.5613],
         [-0.5404,  0.0496, -0.1551,  0.3646],
         [-0.5069,  0.0556, -0.1601,  0.3194],
         [-0.4508, -0.1663, -0.2352,  0.3616]]], grad_fn=<AddBackward0>) tensor([

In [29]:
F.conv1d(torch.randn(1, 4, 100, dtype=torch.int32), torch.randn(4,4,3, dtype=torch.int32), padding=1)

NotImplementedError: "normal_kernel_cpu" not implemented for 'Int'