In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
torch.manual_seed(0)

from torch.nn import Module

# import brevitas.nn as qnn
from brevitas.nn import QuantLinear, QuantHardTanh, QuantIdentity
from brevitas.nn import QuantSigmoid
from brevitas import config

from common import *

cuda:0


In [3]:
class QuantNet(Module):
    def __init__(self):
        super(QuantNet, self).__init__()

        self.fc1 = QuantLinear(784, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn1 = nn.BatchNorm1d(128)
        #self.bn1 = ShiftBatchNorm(1024)

        self.fc2 = QuantLinear(128, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn2 = nn.BatchNorm1d(1024)
        #self.bn2 = ShiftBatchNorm(1024)

        self.fc3 = QuantLinear(128, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn3 = nn.BatchNorm1d(128)
        #self.bn3 = ShiftBatchNorm(128)

        self.fc4 = QuantLinear(128, 10, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn4 = nn.BatchNorm1d(128)
        #self.bn4 = ShiftBatchNorm(128)
        
        self.quant_identity = QuantIdentity(act_quant=Int8ActPerTensorFloatScratch, return_quant_tensor = True)
        self.quant_act = QuantSigmoid(act_quant=Int8ActPerTensorFloatScratch, return_quant_tensor = True)
        self.dropout = nn.Dropout(p=0.0)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)

        x = self.quant_identity(x)
        # x = self.dropout(x)

        x = self.fc1(x)
        # x = self.bn1(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc2(x)
        # x = self.bn2(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc3(x)
        # x = self.bn3(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc4(x)
        # x = self.bn4(x)
        output = F.log_softmax(x, dim=1)
        return output

# Load Pre-trained Model Parameters

In [4]:
config.IGNORE_MISSING_KEYS = True
# pretrend weight path = 'mnist_qnn_mlp.pt'
model_state_dict = torch.load('mnist_qnn_mlp.pt')
net = QuantNet()
print(net.quant_act.quant_act_scale())
print(net.quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value)
net.load_state_dict(model_state_dict)
print(net.quant_act.quant_act_scale())
print(net.quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value)

tensor(0.0078)
Parameter containing:
tensor(1., requires_grad=True)
tensor(0.0089, grad_fn=<DivBackward0>)
Parameter containing:
tensor(2.8619, requires_grad=True)


In [5]:
from brevitas.core.scaling import ConstScaling

from brevitas.inject import ExtendedInjector
class CommonQuantizer(ExtendedInjector):
    quant_type = QuantType.INT
    bit_width_impl_type = BitWidthImplType.CONST
    scaling_impl_type = ScalingImplType.CONST
    restrict_scaling_type = RestrictValueType.FP
    zero_point_impl = ZeroZeroPoint
    float_to_int_impl_type = FloatToIntImplType.ROUND
    scaling_per_output_channel = False
    narrow_range = False
    signed = True

class ActInferenceQuant(CommonQuantizer, ActQuantSolver):
    bit_width = 8
    scaling_impl = ConstScaling(2.8619)

In [6]:
class QuantNetInference(Module):
    def __init__(self):
        super(QuantNetInference, self).__init__()

        self.fc1 = QuantLinear(784, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn1 = nn.BatchNorm1d(128)
        #self.bn1 = ShiftBatchNorm(1024)

        self.fc2 = QuantLinear(128, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn2 = nn.BatchNorm1d(1024)
        #self.bn2 = ShiftBatchNorm(1024)

        self.fc3 = QuantLinear(128, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn3 = nn.BatchNorm1d(128)
        #self.bn3 = ShiftBatchNorm(128)

        self.fc4 = QuantLinear(128, 10, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn4 = nn.BatchNorm1d(128)
        #self.bn4 = ShiftBatchNorm(128)
        
        self.quant_identity = QuantIdentity(act_quant=ActInferenceQuant, return_quant_tensor = True)
        self.quant_act = QuantSigmoid(act_quant=ActInferenceQuant, return_quant_tensor = True)
        self.dropout = nn.Dropout(p=0.0)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)

        x = self.quant_identity(x)
        # x = self.dropout(x)

        x = self.fc1(x)
        # x = self.bn1(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc2(x)
        # x = self.bn2(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc3(x)
        # x = self.bn3(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc4(x)
        # x = self.bn4(x)
        output = F.log_softmax(x, dim=1)
        return output

In [7]:
quant_model_inference = QuantNetInference()
print(quant_model_inference)

QuantNetInference(
  (fc1): QuantLinear(
    in_features=784, out_features=128, bias=False
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (output_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (weight_quant): WeightQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (tensor_quant): RescalingIntQuant(
        (int_quant): IntQuant(
          (float_to_int_impl): RoundSte()
          (tensor_clamp_impl): TensorClampSte()
          (delay_wrapper): DelayWrapper(
            (delay_impl): _NoDelay()
          )
        )
        (scaling_impl): StatsFromParameterScaling(
          (parameter_list_stats): _ParameterListStats(
            (first_tracked_param): _ViewParameterWrapper(
              (view_shape_impl): OverTensorView()
            )
            (stats): _Stats(
              (stats_impl): AbsMax()
            )
          )
          (stats_scaling_impl): _Stat

In [8]:
config.IGNORE_MISSING_KEYS = False
# pretrend weight path = 'mnist_qnn_mlp.pt'
model_state_dict = torch.load('mnist_qnn_mlp.pt')
print("before loading")
print(quant_model_inference.quant_act.quant_act_scale())
print(quant_model_inference.quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl(torch.empty(1)))

quant_model_inference.load_state_dict(model_state_dict)
print("after loading")
print(quant_model_inference.quant_act.quant_act_scale())
print(quant_model_inference.quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl(torch.empty(1)))

before loading
tensor(0.0224)
tensor(2.8619)
after loading
tensor(0.0224)
tensor(2.8619)


In [9]:
# random torch tensor
x = torch.randn(1, 1, 28, 28)


In [10]:
print("training model: ", net.quant_identity(x))

training model:  QuantTensor(value=tensor([[[[-1.5204, -0.0894,  2.0123, -0.9838, -0.4248,  0.3577,  0.0671,
            1.2968,  1.0062,  0.0224,  0.8944,  0.3577,  1.1403,  1.1627,
            0.3130,  1.9676,  0.4025, -0.2236, -0.6931,  0.8496,  1.0062,
           -0.4695,  0.0000, -0.4248,  0.3354, -1.4757, -1.9005, -0.4472],
          [-2.2806,  0.9838,  1.7664, -1.0956,  0.7378,  0.9838,  0.0671,
            0.3801, -0.4025, -0.1342, -0.2236,  0.7826, -0.4472,  0.3130,
           -0.3130,  2.1017, -1.4310, -1.1179,  2.0347,  0.3354,  0.1565,
            0.5366,  0.0671,  0.3801,  2.2359, -0.4695, -1.6322,  0.8720],
          [ 1.2968, -0.2459, -0.9838,  1.4757, -0.3801, -0.9838,  0.1789,
           -0.5590, -0.8944, -0.2683,  0.5143, -0.3577,  0.6484, -0.2683,
            1.3415,  1.2074, -0.2907, -0.1342, -1.2074,  0.8944, -0.4248,
            1.6769, -0.7826, -1.7440,  0.1342, -0.9167,  1.2297,  0.6037],
          [ 0.9167,  0.4472,  0.9391,  1.3863,  0.4472, -1.4310, -0.4472,


  return super().rename(names)


In [11]:
print("inference model: ", quant_model_inference.quant_identity(x))

inference model:  QuantTensor(value=tensor([[[[-1.5204, -0.0894,  2.0123, -0.9838, -0.4248,  0.3577,  0.0671,
            1.2968,  1.0061,  0.0224,  0.8943,  0.3577,  1.1403,  1.1626,
            0.3130,  1.9676,  0.4025, -0.2236, -0.6931,  0.8496,  1.0061,
           -0.4695,  0.0000, -0.4248,  0.3354, -1.4757, -1.9005, -0.4472],
          [-2.2806,  0.9838,  1.7663, -1.0956,  0.7378,  0.9838,  0.0671,
            0.3801, -0.4025, -0.1342, -0.2236,  0.7826, -0.4472,  0.3130,
           -0.3130,  2.1017, -1.4310, -1.1179,  2.0346,  0.3354,  0.1565,
            0.5366,  0.0671,  0.3801,  2.2359, -0.4695, -1.6322,  0.8720],
          [ 1.2968, -0.2459, -0.9838,  1.4757, -0.3801, -0.9838,  0.1789,
           -0.5590, -0.8943, -0.2683,  0.5142, -0.3577,  0.6484, -0.2683,
            1.3415,  1.2074, -0.2907, -0.1342, -1.2074,  0.8943, -0.4248,
            1.6769, -0.7826, -1.7440,  0.1342, -0.9167,  1.2297,  0.6037],
          [ 0.9167,  0.4472,  0.9391,  1.3862,  0.4472, -1.4310, -0.4472,

In [15]:
torch.sum(net.quant_identity(x).value - quant_model_inference.quant_identity(x).value)

tensor(0.0009, grad_fn=<SumBackward0>)