# In [4]:
# https://xilinx.github.io/brevitas/getting_started.html

from torch.nn import Module
import torch.nn.functional as F

import brevitas.nn as qnn
from brevitas.quant import Int32Bias

import os
import onnx
import torch
import numpy as np
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, Dataset
from brevitas.nn import QuantLinear, QuantReLU
import torch.nn as nn
from sklearn.metrics import accuracy_score
from tqdm import tqdm, trange
import ipynbname


class QuantWeightActBiasLeNet(Module):
    """LeNet-style CNN with quantized weights, activations and biases (Brevitas).

    The float input is quantized by ``quant_inp``. Every conv/linear layer
    uses ``weight_bit_width``-bit weights and an ``Int32Bias`` bias quantizer.
    Every hidden activation is a ``QuantReLU`` with ``act_bit_width`` bits.
    ``fc3`` produces the raw logits with no final activation.

    The ``16 * 5 * 5`` input size of ``fc1`` assumes a 32x32 spatial input:
    32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.

    Args:
        weight_bit_width: Bit width of all conv/linear weights (default 4).
        act_bit_width: Bit width of the input quantizer and all ReLUs
            (default 4).
        in_channels: Number of input channels (default 3).
        num_classes: Number of output logits (default 10).
    """

    def __init__(
        self,
        *,
        weight_bit_width: int = 4,
        act_bit_width: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
    ):
        super().__init__()
        # Every quantizer feeding an Int32Bias layer returns a QuantTensor.
        # This lets the bias quantizer derive its scale from the incoming
        # activation scale.
        self.quant_inp = qnn.QuantIdentity(bit_width=act_bit_width, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(
            in_channels, 6, 5, bias=True, weight_bit_width=weight_bit_width, bias_quant=Int32Bias)
        self.relu1 = qnn.QuantReLU(bit_width=act_bit_width, return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(
            6, 16, 5, bias=True, weight_bit_width=weight_bit_width, bias_quant=Int32Bias)
        self.relu2 = qnn.QuantReLU(bit_width=act_bit_width, return_quant_tensor=True)
        # 16 channels x 5 x 5 spatial positions, valid for 32x32 inputs only.
        self.fc1 = qnn.QuantLinear(
            16 * 5 * 5, 120, bias=True, weight_bit_width=weight_bit_width, bias_quant=Int32Bias)
        self.relu3 = qnn.QuantReLU(bit_width=act_bit_width, return_quant_tensor=True)
        self.fc2 = qnn.QuantLinear(
            120, 84, bias=True, weight_bit_width=weight_bit_width, bias_quant=Int32Bias)
        self.relu4 = qnn.QuantReLU(bit_width=act_bit_width, return_quant_tensor=True)
        self.fc3 = qnn.QuantLinear(
            84, num_classes, bias=True, weight_bit_width=weight_bit_width, bias_quant=Int32Bias)

    def forward(self, x):
        """Return the logits for a batch ``x`` (assumed N x C x 32 x 32)."""
        out = self.quant_inp(x)
        out = self.relu1(self.conv1(out))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        # Flatten everything except the batch dimension for the linear layers.
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out

# Model instance that is trained below and then exported to ONNX (QCDQ format).
quant_weight_act_bias_lenet = QuantWeightActBiasLeNet()

# ... training ...

#### Convert to ONNX model

# In [6]:
from brevitas.export import export_onnx_qcdq
import torch

# Export the weight-activation-bias model to ONNX in QCDQ form. Tracing needs
# one example input: a single random 3-channel 32x32 sample.
dummy_input = torch.randn(1, 3, 32, 32)
onnx_output_path = './4b_weight_act_bias_lenet.onnx'
export_onnx_qcdq(
    quant_weight_act_bias_lenet,
    dummy_input,
    export_path=onnx_output_path,
)




ir_version: 7
producer_name: "pytorch"
producer_version: "1.13.1"
graph {
  node {
    input: "x.1"
    input: "/quant_inp/act_quant/export_handler/Constant_output_0"
    input: "/quant_inp/act_quant/export_handler/Constant_1_output_0"
    output: "/quant_inp/act_quant/export_handler/QuantizeLinear_output_0"
    name: "/quant_inp/act_quant/export_handler/QuantizeLinear"
    op_type: "QuantizeLinear"
  }
  node {
    input: "/quant_inp/act_quant/export_handler/QuantizeLinear_output_0"
    input: "/quant_inp/act_quant/export_handler/Constant_2_output_0"
    input: "/quant_inp/act_quant/export_handler/Constant_3_output_0"
    output: "/quant_inp/act_quant/export_handler/Clip_output_0"
    name: "/quant_inp/act_quant/export_handler/Clip"
    op_type: "Clip"
  }
  node {
    input: "/quant_inp/act_quant/export_handler/Clip_output_0"
    input: "/quant_inp/act_quant/export_handler/Constant_output_0"
    input: "/quant_inp/act_quant/export_handler/Constant_1_output_0"
    output: "/quant_inp/