## Check env

In [2]:
import platform
machine = platform.machine()
print(f"This is a {machine} machine")

backend = ""
if machine == 'AMD64':
    # backend = 'x86'
    backend = 'fbgemm'
elif machine == 'arm64':
    backend = 'qnnpack'
print(f"Backend is {backend}")

This is a AMD64 machine
Backend is fbgemm


## Libraries & Config

In [3]:
from dataclasses import dataclass
import torch as T
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Subset
import torch.ao.quantization as Q
from torchvision import datasets, transforms
import math
import os
import numpy as np

T.backends.quantized.engine = backend

In [4]:
@dataclass
class VisionConfig:
    num_hidden_layers: int = 12 # number of hidden layers in the encoder as in the paper
    num_channels: int = 3
    embed_dim: int = 512  # patch_size * patch_size * num_channels
    image_size: int = 32
    patch_size: int = 4
    num_attention_heads: int = 8  # embed_dim // 64
    hidden_size: int = 512  # embed_dim
    intermediate_size: int = 144  # 4 * hidden_size
    layer_norm_eps: float = 1e-6
    attention_dropout: float = 0.0
    
model_config = VisionConfig()

@dataclass
class DatasetConfig:
    batch_size: int = 1
    subset_size: int = 100
    num_workers: int = 4
    
dataset_config = DatasetConfig()

## Load Dataset

In [5]:
# Load the dataset with float32 and int8
transform_fp32 = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(
    #         mean=[0.485, 0.456, 0.406],
    #         std=[0.229, 0.224, 0.225]
    # )
])
def pil_to_tensor(img):
    return T.from_numpy(np.array(img)).permute(2, 0, 1)
transform_int8 = transforms.Compose([
    transforms.Lambda(pil_to_tensor),    
])

# Shape: (B, C, H, W)
full_dataset_fp32 = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_fp32,
)
full_dataset_int8 = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_int8,
)

# Subset for calibration
indices = np.random.choice(
    a=len(full_dataset_fp32),
    size=dataset_config.subset_size,
    replace=False
)
calibration_dataset_fp32 = Subset(full_dataset_fp32, indices)
calibration_loader_fp32 = T.utils.data.DataLoader(
    dataset=calibration_dataset_fp32,
    batch_size=dataset_config.batch_size,
    shuffle=False,
    num_workers=dataset_config.num_workers,
)

calibration_dataset_int8 = Subset(full_dataset_int8, indices)
calibration_loader_int8 = T.utils.data.DataLoader(
    dataset=calibration_dataset_int8,
    batch_size=dataset_config.batch_size,
    shuffle=False,
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
print(f"Number of data points: {len(calibration_dataset_fp32)}\nNumber of calibration batches: {len(calibration_loader_fp32)}")

Number of data points: 100
Number of calibration batches: 100


In [7]:
# float32 input for check the model
input_tensor_fp32 = next(iter(calibration_loader_fp32))[0]
print(f"Float32 input shape: {input_tensor_fp32.shape}\n{input_tensor_fp32[0, 0, 0]}")

# int8 input for quantization
input_tensor_int8 = next(iter(calibration_loader_int8))[0]
print(f"Int8 input shape: {input_tensor_int8.shape}\n{input_tensor_int8[0, 0, 0]}")

Float32 input shape: torch.Size([1, 3, 32, 32])
tensor([0.6353, 0.6353, 0.6549, 0.6745, 0.6980, 0.7137, 0.7216, 0.7216, 0.7059,
        0.7176, 0.6784, 0.5412, 0.4902, 0.5569, 0.6157, 0.5490, 0.4157, 0.3882,
        0.3843, 0.4314, 0.4667, 0.4314, 0.3412, 0.3529, 0.4118, 0.4824, 0.6078,
        0.6667, 0.6824, 0.6863, 0.6902, 0.6588])
Int8 input shape: torch.Size([1, 3, 32, 32])
tensor([162, 162, 167, 172, 178, 182, 184, 184, 180, 183, 173, 138, 125, 142,
        157, 140, 106,  99,  98, 110, 119, 110,  87,  90, 105, 123, 155, 170,
        174, 175, 176, 168], dtype=torch.uint8)


## Test

In [5]:
# define a floating point model where some layers could be statically quantized
class M(T.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = Q.QuantStub()
        self.conv = T.nn.Conv2d(1, 1, 1)
        self.relu = T.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = Q.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()


M(
  (quant): QuantStub()
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [None]:
# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = Q.get_default_qconfig('fbgemm')
model_fp32.qconfig = Q.get_default_qconfig(backend)

## 運算融合
# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = Q.fuse_modules(model_fp32, [['conv', 'relu']])

# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = Q.prepare(model_fp32_fused)

# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = T.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = Q.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)

In [12]:
model_int8

M(
  (quant): Quantize(scale=tensor([0.0233]), zero_point=tensor([107]), dtype=torch.quint8)
  (conv): QuantizedConvReLU2d(1, 1, kernel_size=(1, 1), stride=(1, 1), scale=0.007877849042415619, zero_point=0)
  (relu): Identity()
  (dequant): DeQuantize()
)

## Embeddings

In [8]:
class VisionEmbeddings(nn.Module):
  def __init__(self, config: VisionConfig):
    super().__init__()
    self.config = config

    self.num_channels = config.num_channels  # 3 for RGB
    self.embed_dim = config.embed_dim  # 512
    self.image_size = config.image_size  # 32
    self.patch_size = config.patch_size  # 4

    self.patch_embedding = nn.Conv2d(
      in_channels=self.num_channels,
      out_channels=self.embed_dim,
      kernel_size=self.patch_size,
      stride=self.patch_size,
      padding=0,
    )

    self.num_patches = (self.image_size // self.patch_size) ** 2  # （32/4）^2 = 64
    self.num_positions = self.num_patches
    self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
    self.register_buffer(
      "position_ids",
      T.arange(self.num_positions).expand((1, -1)),
      persistent=False,
    )
    
    self.quant = Q.QuantStub()
    self.dequant = Q.DeQuantStub()

  def forward(self, pixel_values: T.FloatTensor) -> T.Tensor:
    # B, C, H, W = pixel_values.shape
    pixel_values = self.quant(pixel_values)
    patch_embeds = self.patch_embedding(pixel_values)
    patch_embeds = self.dequant(patch_embeds)
    
    embeddings = patch_embeds.flatten(start_dim=2, end_dim=-1)
    embeddings = embeddings.transpose(1, 2)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings

In [9]:
embd_fp32 = VisionEmbeddings(model_config).eval()
print(f"Shape: {embd_fp32(input_tensor_fp32).shape}")

Shape: torch.Size([1, 64, 512])


### Quantize the VisionEmbeddings module

In [10]:
def embd_calibrate(model, data_loader):
    model.eval()
    with T.no_grad():
        for img, _ in data_loader:
            model(img)

In [11]:
def quantize_model(model: nn.Module, model_name: str, calibrate_fn: callable) -> nn.Module:
    model.qconfig = Q.get_default_qconfig(backend)
    if model_name == 'embd':
        model.position_embedding.qconfig = Q.float_qparams_weight_only_qconfig

    # Prepare the model for static quantization. This inserts observers in
    # the model that will observe activation tensors during calibration.
    model = Q.prepare(model)
    calibrate_fn(model, calibration_loader_fp32)
    model = Q.convert(model)
    return model

embd_int8 = quantize_model(
    model=embd_fp32,
    model_name='embd',
    calibrate_fn=embd_calibrate
)



In [12]:
embd_int8

VisionEmbeddings(
  (patch_embedding): QuantizedConv2d(3, 512, kernel_size=(4, 4), stride=(4, 4), scale=0.02523711696267128, zero_point=62)
  (position_embedding): QuantizedEmbedding(num_embeddings=64, embedding_dim=512, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
  (quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)

### Save Input Data

In [13]:
attn_fp32 = None
def save_data(
    data_loader: T.utils.data.DataLoader,
    model: nn.Module,
    model_name: str = "embd",
    dir_path: str = "result/input",
) -> None:
    os.makedirs(dir_path, exist_ok=True)
    scale = model.quant.scale
    zero_point = model.quant.zero_point
    if scale is not None:
        scale = scale.detach().numpy()
        index =  math.ceil(math.log2(0.5/scale.item()))
        scale_file_path = os.path.join(dir_path, "scale.npy")
        index_file_path = os.path.join(dir_path, "index.npy")
        np.save(scale_file_path, scale)
        np.save(index_file_path, index)
    if zero_point is not None:
        zero_point = zero_point.detach().numpy()
        zero_point_file_path = os.path.join(dir_path, "zero_point.npy")
        np.save(zero_point_file_path, zero_point)
    datas_int8 = []
    if model_name == "embd":
        for img_fp32, _ in data_loader:
            data_fp32 = model(img_fp32)
            data_int8 = (T.round(data_fp32 / scale) + zero_point).to(T.int8)
            datas_int8.append(data_int8.numpy())
    elif model_name == "attn":
        for img_fp32, _ in data_loader:
            data_fp32 = attn_fp32(embd_int8(img_fp32)).detach()
            data_int8 = (T.round(data_fp32 / scale) + zero_point).to(T.int8)
            datas_int8.append(data_int8.numpy())
    datas_int8_np = np.stack(datas_int8, axis=0)
    print(f"Shape: {datas_int8_np.shape}")
    if model_name == "embd":
        data_file_path = os.path.join(dir_path, "input.npy")
    elif model_name == "attn":
        data_file_path = os.path.join(dir_path, "qk.npy")
    np.save(data_file_path, datas_int8_np)
    
    print(f"Saved input data to {dir_path}/")

In [14]:
save_data(
    data_loader=calibration_loader_fp32,
    model=embd_int8,
    model_name="embd",
    dir_path="result/input",
)

Shape: (100, 1, 64, 512)
Saved input data to result/input/


In [15]:
# Check input type and shape
input_data = np.load("result/input/input.npy")
print(f"Type: {type(input_data[0, 0, 0, 0])}, Shpae: {input_data.shape}\n{input_data[:, 0, 0, 0]}")

Type: <class 'numpy.int8'>, Shpae: (100, 1, 64, 512)
[ 113  106   71  -85 -121   84   81  113 -127  100   84 -127  126   81
   97  -98 -111  103  122 -118  100  103 -108  -92 -124   78  122   87
   90 -127  110 -121 -108  116 -108  119  116 -121   81 -118 -127  -89
   97 -127  100  106  106 -121  -95  119 -118  126  116  -98 -111  126
  122   74  -85   74 -118 -127 -127 -124  122 -102  122   90  -95 -121
  106  119   84   87 -102 -105   81   94   97   81   97 -121  100   94
  -95  -85   78  110  -98  -98  -95 -127 -124   90  122  122   81  -92
 -121   94]


## MSA/Attention

### Model

In [None]:
class Attention(nn.Module):
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
        
        self.quant = Q.QuantStub()
        self.dequant = Q.DeQuantStub()

    def forward(self, hidden_states):
        # the hidden states are the embeddings of the patches, so (batch_size, num_patches, embed_dim)
        B, T, E = hidden_states.shape
        hidden_states = self.quant(hidden_states)  # int8
        q_states = self.q_proj(hidden_states)
        k_states = self.k_proj(hidden_states)
        v_states = self.v_proj(hidden_states)
        
        np.save("result/output/q.npy", q_states.detach().numpy())
        np.save("result/output/k.npy", k_states.detach().numpy())
        np.save("result/output/v.npy", v_states.detach().numpy())
        
        q_states = q_states.view(B, T, self.num_heads, E // self.num_heads).transpose(1, 2)
        k_states = k_states.view(B, T, self.num_heads, E // self.num_heads).transpose(1, 2)
        v_states = v_states.view(B, T, self.num_heads, E // self.num_heads).transpose(1, 2)
        
        # int8 quantization
        q_states = self.quant(q_states)
        k_states = self.quant(k_states)
        v_states = self.quant(v_states)

        attn_weights = (q_states @ k_states.transpose(-2, -1)) * (1.0 / math.sqrt(k_states.size(-1)))
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dequant(attn_weights)  # float32
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
        
        attn_outs = attn_weights @ v_states
        
        attn_outs = self.quant(attn_outs)  # int8
        attn_outs = attn_outs.transpose(1, 2)
        attn_outs = attn_outs.reshape(B, T, E).contiguous()
        attn_outs = self.out_proj(attn_outs)
        attn_outs = self.dequant(attn_outs)
        return attn_outs

In [18]:
class AttentionQK(nn.Module):
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
                
        self.quant = Q.QuantStub()
        self.dequant = Q.DeQuantStub()

    def forward(self, hidden_states):
        # the hidden states are the embeddings of the patches, so (batch_size, num_patches, embed_dim)
        B, T, E = hidden_states.shape
        hidden_states = self.quant(hidden_states)  # int8
        q_states = self.q_proj(hidden_states)
        k_states = self.k_proj(hidden_states)
        v_states = self.v_proj(hidden_states)
        
        os.makedirs("result/output/Q", exist_ok=True)
        os.makedirs("result/output/K", exist_ok=True)
        os.makedirs("result/output/V", exist_ok=True)
        np.save("result/output/Q/q_fp32.npy", q_states.detach().numpy())
        np.save("result/output/K/k_fp32.npy", k_states.detach().numpy())
        np.save("result/output/V/v_fp32.npy", v_states.detach().numpy())

        q_states = q_states.view(B, T, self.num_heads, E // self.num_heads).transpose(1, 2)
        k_states = k_states.view(B, T, self.num_heads, E // self.num_heads).transpose(1, 2)
        v_states = v_states.view(B, T, self.num_heads, E // self.num_heads).transpose(1, 2)
        
        # int8 quantization
        q_states = self.quant(q_states)
        k_states = self.quant(k_states)
        v_states = self.quant(v_states)

        QK = (q_states @ k_states.transpose(-2, -1)) * (1.0 / math.sqrt(k_states.size(-1)))
        QK = self.quant(QK)
        
        return QK

In [19]:
attn_fp32 = AttentionQK(model_config).eval()
QK = attn_fp32(embd_int8(input_tensor_fp32))
print(f"Shape: {QK.shape}")

Shape: torch.Size([1, 8, 64, 64])


### Quantize the attention module

In [20]:
def attn_calibrate(model, data_loader):
    model.eval()
    with T.no_grad():
        for img, _ in data_loader:
            model(embd_int8(img))

In [21]:
attn_int8 = quantize_model(attn_fp32, 'attn', attn_calibrate)
attn_int8



AttentionQK(
  (k_proj): QuantizedLinear(in_features=512, out_features=512, scale=0.03887570649385452, zero_point=62, qscheme=torch.per_channel_affine)
  (v_proj): QuantizedLinear(in_features=512, out_features=512, scale=0.03652804717421532, zero_point=66, qscheme=torch.per_channel_affine)
  (q_proj): QuantizedLinear(in_features=512, out_features=512, scale=0.03661142289638519, zero_point=62, qscheme=torch.per_channel_affine)
  (quant): Quantize(scale=tensor([0.0617]), zero_point=tensor([64]), dtype=torch.quint8)
  (dequant): DeQuantize()
)

### Save Weight Data

In [22]:
def save_weight(
    model: nn.Module,
    model_name: str = "attn",
    dir_path: str = "result/weights",
):
    os.makedirs(dir_path, exist_ok=True)
    def extract_weight_bias(qlinear: T.nn.quantized.Linear) -> tuple[T.Tensor, T.Tensor]:
        """Extract packed weight, bias, scale, zero_point from quantized Linear layer"""
        weight, bias = qlinear.weight(), qlinear.bias()
        # scale = qlinear.scale
        # zero_point = qlinear.zero_point
        return weight, bias

    for layer_name, param in model.state_dict().items():
        if isinstance(param, T.Tensor):
            weight_name = layer_name.split('.')[0]
            if "scale" in layer_name:
                scale_file_path = os.path.join(dir_path, f"{weight_name}_scale.npy")
                index_file_path = os.path.join(dir_path, f"{weight_name}_index.npy")
                index =  math.ceil(math.log2(0.5/param.item()))
                np.save(scale_file_path, param.detach().numpy())
                np.save(index_file_path, index)
                print(f"✅Saved {weight_name} scale")
            elif "zero_point" in layer_name:
                file_path = os.path.join(dir_path, f"{weight_name}_zero_point.npy")
                np.save(file_path, param.detach().numpy())
                print(f"✅Saved {weight_name} zero_point")
        elif isinstance(param, tuple):
            weight_file_path = os.path.join(dir_path, f"{weight_name}_weight.npy")
            bias_file_path = os.path.join(dir_path, f"{weight_name}_bias_fp32.npy")
            layer = getattr(model, weight_name)
            weight, bias = extract_weight_bias(layer)
            np.save(weight_file_path, weight.detach().int_repr().numpy())
            np.save(bias_file_path, bias.detach().numpy())
            print(f"✅Saved {weight_name} weight and bias")
        else:
            print(f"⚠️Skip {layer_name} (Not Tensor, type: {type(param)})")
        print("----------------------------------")

    print(f"{model_name} weights have been saved!")

In [23]:
save_weight(
    model=attn_int8,
    model_name="attn",
    dir_path="result/weights",
)

✅Saved k_proj scale
----------------------------------
✅Saved k_proj zero_point
----------------------------------
⚠️Skip k_proj._packed_params.dtype (Not Tensor, type: <class 'torch.dtype'>)
----------------------------------
✅Saved k_proj weight and bias
----------------------------------
✅Saved v_proj scale
----------------------------------
✅Saved v_proj zero_point
----------------------------------
⚠️Skip v_proj._packed_params.dtype (Not Tensor, type: <class 'torch.dtype'>)
----------------------------------
✅Saved v_proj weight and bias
----------------------------------
✅Saved q_proj scale
----------------------------------
✅Saved q_proj zero_point
----------------------------------
⚠️Skip q_proj._packed_params.dtype (Not Tensor, type: <class 'torch.dtype'>)
----------------------------------
✅Saved q_proj weight and bias
----------------------------------
✅Saved quant scale
----------------------------------
✅Saved quant zero_point
----------------------------------
attn weight

### Handle Bias

In [26]:
def bias_to_int8(
    dir_path: str = "result/weights",
    weight_name: str = "k_proj",
):
    bias_path = os.path.join(dir_path, f"{weight_name}_bias_fp32.npy")
    scale = np.load(os.path.join(dir_path, f"{weight_name}_scale.npy"))
    zero_point = np.load(os.path.join(dir_path, f"{weight_name}_zero_point.npy"))
    bias_fp32 = np.load(bias_path)
    bias_int8 = (T.round(T.tensor(bias_fp32) / scale) + zero_point).to(T.int8)
    bias_int8_file_path = bias_path.replace("bias", "bias_int8")
    np.save(bias_int8_file_path, bias_int8)
    print(f"Saved int8 bias to {bias_int8_file_path}")

In [27]:
bias_to_int8(weight_name="k_proj")
bias_to_int8(weight_name="v_proj")
bias_to_int8(weight_name="q_proj")

Saved int8 bias to result/weights\k_proj_bias_int8_fp32.npy
Saved int8 bias to result/weights\v_proj_bias_int8_fp32.npy
Saved int8 bias to result/weights\q_proj_bias_int8_fp32.npy


In [124]:
# Check weight type and shape
weight_data = np.load("result/weights/k_proj_weight.npy")
print(f"Type: {type(weight_data[0, 0])}, Shape: {weight_data.shape}\n{weight_data[0, 0]}")

bias_data = np.load("result/weights/k_proj_bias_int8.npy")
print(f"Type: {type(bias_data[0])}, Shape: {bias_data.shape}\n{bias_data[0]}")

Type: <class 'numpy.int8'>, Shape: (512, 512)
-82
Type: <class 'numpy.int8'>, Shape: (512,)
64


### Handle Output Data

In [31]:
def output_to_int8(
    dir_path: str = "result/output/Q",
    output_name: str = "k",
):
    output_path = os.path.join(dir_path, f"{output_name}_fp32.npy")
    output = np.load(output_path)
    scale = np.load(f"result/weights/{output_name}_proj_scale.npy")
    zero_point = np.load(f"result/weights/{output_name}_proj_zero_point.npy")
    output_int8 = (T.round(T.tensor(output) / scale) + zero_point).to(T.int8)
    output_int8_file_path = os.path.join(dir_path, f"{output_name}_int8.npy")
    np.save(output_int8_file_path, output_int8)
    print(f"Saved int8 bias to {output_int8_file_path}")

In [33]:
output_to_int8(
    dir_path="result/output/Q",
    output_name="q",
)
output_to_int8(
    dir_path="result/output/K",
    output_name="k",
)
output_to_int8(
    dir_path="result/output/V",
    output_name="v",
)

Saved int8 bias to result/output/Q\q_int8.npy
Saved int8 bias to result/output/K\k_int8.npy
Saved int8 bias to result/output/V\v_int8.npy


### Save Output Data

In [28]:
save_data(
    data_loader=calibration_loader_fp32,
    model=attn_int8,
    model_name="attn",
    dir_path="result/output/QK",
)

Shape: (100, 1, 8, 64, 64)
Saved input data to result/output/QK/


## MLP

In [None]:
class MLP(nn.Module):
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.quant = Q.QuantStub()
        self.dequant = Q.DeQuantStub()

    def forward(self, hidden_states: T.Tensor) -> T.Tensor:
        hidden_states = self.quant(hidden_states)
        hidden_states = self.fc1(hidden_states)
        hidden_states = nn.functional.gelu(hidden_states, approximate="tanh")
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dequant(hidden_states)
        return hidden_states

mlp = MLP(model_config)
print(f"MLP: {mlp}\nShape: {mlp(embd_fp32(input_tensor_fp32[:1])).shape}")

MLP: MLP(
  (fc1): Linear(in_features=48, out_features=144, bias=True)
  (fc2): Linear(in_features=144, out_features=48, bias=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)
Shape: torch.Size([1, 64, 48])


## Encoder

In [19]:
class EncoderLayer(nn.Module):
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Attention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = MLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        
        # quantization
        self.quant = T.ao.quantization.QuantStub()
        self.dequant = T.ao.quantization.DeQuantStub()


    def forward(self, hidden_states):
        hidden_states = self.quant(hidden_states)
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.dequant(hidden_states)

        hidden_states = self.quant(hidden_states)
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.dequant(hidden_states)
        return hidden_states

encoder_layer_32 = EncoderLayer(VisionConfig(hidden_size=768, intermediate_size=3072))
encoder_layer_32(T.randn(1, 196, 768)).shape

torch.Size([1, 196, 768])

In [24]:
encoder_layer_32.qconfig = T.ao.quantization.get_default_qconfig(backend)

# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
encoder_layer_32_prepared = T.ao.quantization.prepare(encoder_layer_32)

# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = T.randn(1, 196, 768)
encoder_layer_32_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
encoder_layer_int8 = T.ao.quantization.convert(encoder_layer_32_prepared)
encoder_layer_int8

EncoderLayer(
  (self_attn): Attention(
    (k_proj): QuantizedLinear(in_features=768, out_features=768, scale=0.017260996624827385, zero_point=130, qscheme=torch.per_tensor_affine)
    (v_proj): QuantizedLinear(in_features=768, out_features=768, scale=0.01868443191051483, zero_point=128, qscheme=torch.per_tensor_affine)
    (q_proj): QuantizedLinear(in_features=768, out_features=768, scale=0.017791934311389923, zero_point=125, qscheme=torch.per_tensor_affine)
    (out_proj): QuantizedLinear(in_features=768, out_features=768, scale=0.0009516198770143092, zero_point=124, qscheme=torch.per_tensor_affine)
    (quant): Quantize(scale=tensor([0.0284]), zero_point=tensor([127]), dtype=torch.quint8)
    (dequant): DeQuantize()
  )
  (layer_norm1): QuantizedLayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (mlp): MLP(
    (fc1): QuantizedLinear(in_features=768, out_features=3072, scale=0.017856968566775322, zero_point=128, qscheme=torch.per_tensor_affine)
    (fc2): QuantizedLinear(in_fe