In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
torch.manual_seed(0)

from torch.nn import Module

# import brevitas.nn as qnn
from brevitas.nn import QuantLinear, QuantHardTanh, QuantIdentity
from brevitas.nn import QuantSigmoid
from brevitas import config

from common import *

cuda:0


In [2]:
class QuantNet(Module):
    def __init__(self):
        super(QuantNet, self).__init__()

        self.fc1 = QuantLinear(784, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn1 = nn.BatchNorm1d(128)
        #self.bn1 = ShiftBatchNorm(1024)

        self.fc2 = QuantLinear(128, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn2 = nn.BatchNorm1d(1024)
        #self.bn2 = ShiftBatchNorm(1024)

        self.fc3 = QuantLinear(128, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn3 = nn.BatchNorm1d(128)
        #self.bn3 = ShiftBatchNorm(128)

        self.fc4 = QuantLinear(128, 10, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn4 = nn.BatchNorm1d(128)
        #self.bn4 = ShiftBatchNorm(128)
        
        self.quant_identity = QuantIdentity(act_quant=Int8ActPerTensorFloatScratch, return_quant_tensor = True)
        self.quant_act = QuantSigmoid(act_quant=Int8ActPerTensorFloatScratch, return_quant_tensor = True)
        self.dropout = nn.Dropout(p=0.0)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)

        x = self.quant_identity(x)
        # x = self.dropout(x)

        x = self.fc1(x)
        # x = self.bn1(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc2(x)
        # x = self.bn2(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc3(x)
        # x = self.bn3(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc4(x)
        # x = self.bn4(x)
        output = F.log_softmax(x, dim=1)
        return output

# Load Pre-trained Model Parameters

In [3]:
config.IGNORE_MISSING_KEYS = True
# pretrend weight path = 'mnist_qnn_mlp.pt'
model_state_dict = torch.load('mnist_qnn_mlp.pt')
net = QuantNet()
print(net.quant_act.quant_act_scale())
print(net.quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value)
net.load_state_dict(model_state_dict)
print(net.quant_act.quant_act_scale())
print(net.quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value)

tensor(0.0078)
Parameter containing:
tensor(1., requires_grad=True)
tensor(0.0089, grad_fn=<DivBackward0>)
Parameter containing:
tensor(2.8619, requires_grad=True)


In [4]:
from brevitas.core.scaling import ConstScaling

from brevitas.inject import ExtendedInjector
class CommonQuantizer(ExtendedInjector):
    quant_type = QuantType.INT
    bit_width_impl_type = BitWidthImplType.CONST
    scaling_impl_type = ScalingImplType.CONST
    restrict_scaling_type = RestrictValueType.FP
    zero_point_impl = ZeroZeroPoint
    float_to_int_impl_type = FloatToIntImplType.ROUND
    scaling_per_output_channel = False
    narrow_range = False
    signed = True

class ActInferenceQuant(CommonQuantizer, ActQuantSolver):
    bit_width = 8
    # scaling_impl = ConstScaling(2.8619)
    def scaling_impl(self):
        return ConstScaling(2.8619)
    

In [5]:
class QuantNetInference(Module):
    def __init__(self):
        super(QuantNetInference, self).__init__()

        self.fc1 = QuantLinear(784, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn1 = nn.BatchNorm1d(128)
        #self.bn1 = ShiftBatchNorm(1024)

        self.fc2 = QuantLinear(128, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn2 = nn.BatchNorm1d(1024)
        #self.bn2 = ShiftBatchNorm(1024)

        self.fc3 = QuantLinear(128, 128, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn3 = nn.BatchNorm1d(128)
        #self.bn3 = ShiftBatchNorm(128)

        self.fc4 = QuantLinear(128, 10, weight_quant=Int8WeightPerTensorFloatScratch, bias=Int32Bias)
        # self.bn4 = nn.BatchNorm1d(128)
        #self.bn4 = ShiftBatchNorm(128)
        
        self.quant_identity = QuantIdentity(act_quant=ActInferenceQuant, return_quant_tensor = True)
        self.quant_act = QuantSigmoid(act_quant=ActInferenceQuant, return_quant_tensor = True)
        self.dropout = nn.Dropout(p=0.0)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)

        x = self.quant_identity(x)
        # x = self.dropout(x)

        x = self.fc1(x)
        # x = self.bn1(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc2(x)
        # x = self.bn2(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc3(x)
        # x = self.bn3(x)
        x = self.quant_act(x)
        # x = self.dropout(x)

        x = self.fc4(x)
        # x = self.bn4(x)
        output = F.log_softmax(x, dim=1)
        return output

In [6]:
quant_model_inference = QuantNetInference()
print(quant_model_inference)

QuantNetInference(
  (fc1): QuantLinear(
    in_features=784, out_features=128, bias=False
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (output_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (weight_quant): WeightQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (tensor_quant): RescalingIntQuant(
        (int_quant): IntQuant(
          (float_to_int_impl): RoundSte()
          (tensor_clamp_impl): TensorClampSte()
          (delay_wrapper): DelayWrapper(
            (delay_impl): _NoDelay()
          )
        )
        (scaling_impl): StatsFromParameterScaling(
          (parameter_list_stats): _ParameterListStats(
            (first_tracked_param): _ViewParameterWrapper(
              (view_shape_impl): OverTensorView()
            )
            (stats): _Stats(
              (stats_impl): AbsMax()
            )
          )
          (stats_scaling_impl): _Stat

In [11]:
config.IGNORE_MISSING_KEYS = False
# pretrend weight path = 'mnist_qnn_mlp.pt'
model_state_dict = torch.load('mnist_qnn_mlp.pt')
print("before loading")
print(quant_model_inference.quant_act.quant_act_scale())
print(quant_model_inference.quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl(torch.empty(1)))

quant_model_inference.load_state_dict(model_state_dict)
print("after loading")
print(quant_model_inference.quant_act.quant_act_scale())
print(quant_model_inference.quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl(torch.empty(1)))

before loading
tensor(0.0224)
tensor(2.8619)
after loading
tensor(0.0224)
tensor(2.8619)


In [13]:
model_state_dict

OrderedDict([('fc1.weight',
              tensor([[ 0.1019,  0.0678,  0.0766,  ...,  0.0903,  0.0795,  0.0861],
                      [-0.0046,  0.0003, -0.0082,  ...,  0.0198,  0.0240,  0.0061],
                      [ 0.1158,  0.0782,  0.0536,  ...,  0.0736,  0.0933,  0.0541],
                      ...,
                      [ 0.0743,  0.1022,  0.0996,  ...,  0.1298,  0.0750,  0.0712],
                      [ 0.2434,  0.1943,  0.2404,  ...,  0.1941,  0.1990,  0.2293],
                      [ 0.1188,  0.1672,  0.1120,  ...,  0.1572,  0.1394,  0.1312]],
                     device='cuda:0')),
             ('fc2.weight',
              tensor([[-0.3910, -0.1140, -0.2444,  ..., -0.1273, -0.1737,  0.0166],
                      [ 0.0476, -0.8722,  0.6388,  ..., -0.4184,  0.1209, -0.1012],
                      [ 0.2996,  0.0529,  0.2631,  ...,  0.0447, -0.9385, -0.0182],
                      ...,
                      [-0.5761,  0.0806, -0.9951,  ..., -0.5898,  0.2716,  0.0159],
         