In [2]:
import torch
import torch.nn as nn
from brevitas.nn import QuantConv2d, QuantLinear, QuantReLU
from brevitas.quant import Int8WeightPerTensorFloat, Int8ActPerTensorFloat, Uint8ActPerTensorFloat
from brevitas.core.quant import QuantType
from brevitas.core.scaling import ScalingImplType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.bit_width import BitWidthImplType

In [3]:
# Quantization bit-width settings for the network.
# NOTE(review): the names suggest the first and last layers are kept at 8 bits
# while the layers in between use 4 bits — confirm against the model builder.
INTERNAL_BIT_WIDTH = 4
FIRST_LAYER_BIT_WIDTH = 8
LAST_LAYER_BIT_WIDTH = 8

In [None]:
class CommonIntWeightPerTensorQuant(Int8WeightPerTensorFloat):
    """
    Shared per-tensor weight quantizer built on top of Int8WeightPerTensorFloat.

    The bit width is deliberately cleared (None) so that every layer using this
    quantizer has to provide its own value.
    """
    # No default bit width: each layer must specify one explicitly.
    bit_width = None
    # Minimum value allowed for the scaling factor.
    scaling_min_val = 2e-16


class CommonIntWeightPerChannelQuant(CommonIntWeightPerTensorQuant):
    """
    Common per-channel weight quantizer.

    Derives from CommonIntWeightPerTensorQuant, so it inherits ``bit_width = None``
    (each layer is forced to specify its own bit width) and the same
    ``scaling_min_val``; the only change is that scaling is enabled per output
    channel.
    """
    scaling_per_output_channel = True


class CommonIntActQuant(Int8ActPerTensorFloat):
    """
    Shared signed activation quantizer built on top of Int8ActPerTensorFloat.

    The bit width is deliberately cleared (None) so that every layer using this
    quantizer has to provide its own value.
    """
    # No default bit width: each layer must specify one explicitly.
    bit_width = None
    # Scaling restriction is set to RestrictValueType.LOG_FP.
    restrict_scaling_type = RestrictValueType.LOG_FP
    # Minimum value allowed for the scaling factor.
    scaling_min_val = 2e-16


class CommonUintActQuant(Uint8ActPerTensorFloat):
    """
    Shared unsigned activation quantizer built on top of Uint8ActPerTensorFloat.

    The bit width is deliberately cleared (None) so that every layer using this
    quantizer has to provide its own value.
    """
    # No default bit width: each layer must specify one explicitly.
    bit_width = None
    # Scaling restriction is set to RestrictValueType.LOG_FP.
    restrict_scaling_type = RestrictValueType.LOG_FP
    # Minimum value allowed for the scaling factor.
    scaling_min_val = 2e-16

# Convolutional Block

In [5]:
class ConvBlock(nn.Module):
    """
    Quantized convolution followed by batch normalization and a quantized ReLU.

    The block chains ``QuantConv2d`` -> ``nn.BatchNorm2d`` -> ``QuantReLU``.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution; also used as
            the number of batch-norm features and for the per-channel
            broadcastable shape of the activation.
        kernel_size: Convolution kernel size, either an int or a tuple/list of
            ints (one per spatial dimension).
        weight_bit_width: Bit width passed to the weight quantizer.
        act_bit_width: Bit width passed to the activation quantizer.
        stride: Convolution stride.
        padding: Convolution padding. When ``None`` it defaults to
            ``(kernel_size - 1) // 2``, computed per dimension for tuple kernels.
        groups: Number of convolution groups.
        bn_eps: ``eps`` passed to ``nn.BatchNorm2d``.
        activation_scaling_per_channel: Forwarded to ``QuantReLU`` as
            ``scaling_per_channel``.
        bias: Whether the convolution has a bias term.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        weight_bit_width,
        act_bit_width,
        stride=1,
        padding=None,
        groups=1,
        bn_eps=1e-5,
        activation_scaling_per_channel=False,
        bias=False
    ):
        super().__init__()
        # Identity check: `== None` would invoke __eq__ on whatever was passed in.
        if padding is None:
            padding = self._default_padding(kernel_size)
        # NOTE(review): weight_scaling_per_output_channel and weight_scaling_min_val
        # repeat values that CommonIntWeightPerChannelQuant already declares;
        # they are kept so the layer configuration stays exactly as before.
        self.conv = QuantConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
            weight_bit_width=weight_bit_width,  # weight bit width chosen by the caller
            weight_quant=CommonIntWeightPerChannelQuant,  # quantization per output channel
            weight_scaling_per_output_channel=True,  # each output channel has its own scaling factor
            weight_scaling_impl_type=ScalingImplType.STATS,  # scaling derived from weight statistics
            weight_scaling_stats_op='abs_max',  # statistic used: max absolute value
            weight_narrow_range=True,  # range narrowed
            weight_scaling_min_val=2e-16  # min value for scaling factor
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels, eps=bn_eps)
        # NOTE(review): the activation quantizer is configured purely through
        # keyword arguments (CommonUintActQuant is not passed here) — confirm
        # these keywords match the installed Brevitas version.
        self.activation = QuantReLU(
            bit_width=act_bit_width,
            max_val=6,
            quant_type=QuantType.INT,
            scaling_impl_type=ScalingImplType.PARAMETER,
            restrict_scaling_type=RestrictValueType.LOG_FP,
            scaling_per_channel=activation_scaling_per_channel,
            per_channel_broadcastable_shape=(1, out_channels, 1, 1),
            return_quant_tensor=True
        )

    @staticmethod
    def _default_padding(kernel_size):
        """Return ``(k - 1) // 2`` for an int kernel, or per dimension for a tuple/list."""
        if isinstance(kernel_size, (tuple, list)):
            return tuple((k - 1) // 2 for k in kernel_size)
        return (kernel_size - 1) // 2

    def forward(self, x):
        """Apply convolution, batch normalization and the quantized activation in order."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        return x
        

# Inverted Residual Block

In [None]:
class InvertedResidual(nn.Module):
    """
    Inverted residual block (skeleton).

    At this point the constructor only initialises the ``nn.Module`` base class:
    the arguments are accepted but not stored, and no layers or ``forward``
    method are defined yet.

    Args:
        inp: Accepted but currently unused.
        oup: Accepted but currently unused.
        stride: Accepted but currently unused.
        expand_ratio: Accepted but currently unused.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()

# MobileNet

# Model Instantiator

In [None]:
def get_mobilenet_v2():
    """
    Model instantiator (skeleton).

    Currently only builds the nested list of output channels; nothing is
    constructed from it and the function implicitly returns ``None``.
    """
    # output channels, one inner list per group
    channels = [
        [32],
        [16],
        [24, 24],
        [32, 32, 32],
        [64, 64, 64, 64, 96, 96, 96],
        [160, 160, 160],
        [320],
    ]

# Average Pooling