In [None]:
from math import e
from re import A
import re
from typing import Dict, List, Tuple, Union, Optional
from sympy import per

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.utils import get_same_padding, resize, val2list, val2tuple, merge_tensor
from models.nn.norm import build_norm
from models.nn.act import build_act, Quanhswish
from models.nn.quant_lsq import QuanConv, PActFn, PACT, SymmetricQuantFunction
from models.nn.lsq import LsqQuantizer4input, LsqQuantizer4weight
from models.nn.ops import ConvLayer, DSConv, MBConv, EfficientViTBlock, OpSequential, ResidualBlock, IdentityLayer, LiteMSA

In [None]:
class Conv(nn.Conv2d):
    """Conv2d that simulates an integer convolution using floating-point arithmetic.

    The stored weight and bias are divided by their scales before the
    convolution, and the convolution result is multiplied by
    ``input_scale * weight_scale / output_scale``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded positionally to ``nn.Conv2d``.
        per_channel: when True, ``weight_scale`` holds one entry per output
            channel; when False, it holds a single entry.
        input_bitdepth, weight_bitdepth, output_bitdepth: stored on the
            instance; they are not read anywhere inside this class.

    Attributes:
        input_scale, weight_scale, output_scale: plain tensors (not
            parameters or buffers) initialised to ones. They are expected to
            be overwritten by the caller; ``forward`` moves them to the
            weight's device on use.
        fixBN: set to True in ``__init__``; not read inside this class.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, per_channel=True,
                 dilation=1, groups=1, bias=True, input_bitdepth=8, weight_bitdepth=8, output_bitdepth=8):
        super(Conv, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias)
        self.fixBN = True
        self.per_channel = per_channel
        # scale
        self.input_scale = torch.ones(1, requires_grad=False) # 2^N
        # reciprocal format (c / 2^N)
        # Per-channel quantization needs one scale per output channel,
        # per-tensor quantization needs exactly one.
        num_weight_scales = self.weight.shape[0] if self.per_channel else 1
        self.weight_scale = torch.ones(num_weight_scales, requires_grad=False)
        self.output_scale = torch.ones(1, requires_grad=False)
        # bit-width
        self.input_bitdepth = input_bitdepth
        self.weight_bitdepth = weight_bitdepth
        self.output_bitdepth = output_bitdepth

    @staticmethod
    def _channel_view(scale, trailing_dims):
        """Reshape a multi-element scale to ``(C, 1, ..., 1)`` so it broadcasts
        along the channel axis; single-element scales are returned unchanged."""
        if scale.numel() > 1:
            return scale.reshape(-1, *([1] * trailing_dims))
        return scale

    def forward(self, input): # the value of input is int8 but the type is float32
        """Run the simulated integer convolution.

        Args:
            input: tensor passed to ``F.conv2d``.

        Returns:
            ``conv2d(input, weight / weight_scale, bias / (weight_scale * input_scale))``
            multiplied by ``input_scale * weight_scale / output_scale``.
        """
        # The scales are plain attributes, so Module.to()/cuda() does not move
        # them; align them with the weight's device here.
        device = self.weight.device
        in_scale = torch.as_tensor(self.input_scale, device=device)
        w_scale = torch.as_tensor(self.weight_scale, device=device)
        out_scale = torch.as_tensor(self.output_scale, device=device)

        # weight: stored as Sw*W_int in fp format; scale applies per output channel (dim 0)
        w = self.weight / self._channel_view(w_scale, 3)
        # bias
        b = None
        if self.bias is not None:
            b = self.bias / (w_scale * in_scale)
        # perform convolution in fp format to simulate integer convolution
        output = F.conv2d(input, w, b,
                          self.stride, self.padding, self.dilation, self.groups)
        # dyadic scale, broadcast over the channel axis of the output
        # ((C, 1, 1) lines up with the last three dims of the output)
        dyadic_scale = self._channel_view(in_scale * w_scale / out_scale, 2)
        # final output
        return output * dyadic_scale

In [None]:
class QConvLayer(nn.Module):
    """Convolution layer that quantizes its input and then applies a ``QuanConv``.

    Args:
        in_channels: number of input channels, forwarded to ``QuanConv``.
        out_channels: number of output channels, forwarded to ``QuanConv``.
        kernel_size: side length of the square kernel.
        stride: stride, applied to both spatial dimensions.
        dilation: dilation, applied to both spatial dimensions.
        padding: explicit padding. ``None`` selects
            ``get_same_padding(kernel_size)``. The chosen value is then
            multiplied by ``dilation``.
        groups: convolution groups, forwarded to ``QuanConv``.
        use_bias: forwarded to ``QuanConv`` as ``bias``.
        dropout_rate: when > 0 an ``nn.Dropout2d`` is stored in
            ``self.dropout``. NOTE(review): ``forward`` never applies it.
        norm: forwarded to ``QuanConv``.
        act_func: accepted but not used by this class.
        per_channel: forwarded to ``QuanConv``.
        quan_a: activation quantizer name. ``'acy'`` uses a ``PACT`` module
            owned by this layer, ``'lsq'`` uses ``self.conv.lsq_a``; any other
            value skips input quantization in ``forward``.
        quan_w: weight quantizer name, forwarded to ``QuanConv``.
        nbit_w: weight bit-width, forwarded to ``QuanConv``.
        nbit_a: activation bit-width, forwarded to ``QuanConv`` and ``PACT``.
        res: when truthy, ``forward`` skips the input quantization step.
        psum_quan: forwarded to ``QuanConv``.
        cg: forwarded to ``QuanConv``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        padding=None,
        groups=1,
        use_bias=False,
        dropout_rate=0,
        norm="bn2d",
        act_func="relu",
        per_channel=True,
        quan_a='acy',
        quan_w='lsq',
        nbit_w=8,
        nbit_a=8,
        res = False,
        psum_quan=False,
        cg=32
    ):
        super(QConvLayer, self).__init__()

        # Only fall back to "same" padding when no padding was given; an
        # explicit padding=0 must be honoured rather than treated as missing.
        if padding is None:
            padding = get_same_padding(kernel_size)
        padding = padding * dilation
        self.res = res
        self.dropout = nn.Dropout2d(dropout_rate, inplace=False) if dropout_rate > 0 else None
        self.quan_a_name = quan_a
        if self.quan_a_name == 'acy':
            self.pact = PACT(nbit_a)
        self.conv = QuanConv(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            quan_name_w=quan_w,
            quan_name_a=self.quan_a_name,
            nbit_w=nbit_w,
            nbit_a=nbit_a,
            per_channel=per_channel,
            stride=(stride, stride),
            padding=padding,
            dilation=(dilation, dilation),
            groups=groups,
            bias=use_bias,
            norm=norm,
            psum_quan=psum_quan,
            cg=cg,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Quantize ``x`` (unless ``self.res`` is set) and run the wrapped convolution.

        When quantization runs, the detached activation scale returned by the
        quantizer is stored on ``self.conv.scale_a`` before the convolution.
        """
        if not self.res:
            if self.quan_a_name == 'acy':
                x, scale_a = self.pact(x)
                self.conv.scale_a = scale_a.detach()
            elif self.quan_a_name == 'lsq':
                x, scale_a = self.conv.lsq_a(x)
                self.conv.scale_a = scale_a.detach()

        return self.conv(x)

In [None]:
# Example layer: 3 -> 16 channels, 3x3 kernel, stride 2, explicit padding of 1,
# with 8-bit weights ('lsq') and 8-bit activations ('acy').
model_train = QConvLayer(
    in_channels=3,
    out_channels=16,
    # geometry
    kernel_size=3,
    stride=2,
    padding=1,
    # normalisation / activation options
    norm='bn2d',
    act_func='relu',
    # quantization options
    per_channel=True,
    quan_a='acy',
    quan_w='lsq',
    nbit_w=8,
    nbit_a=8,
    res=False,
    psum_quan=False,
    cg=32,
)
