In [1]:
# Adapted from
# https://github.com/zzzxxxttt/pytorch_DoReFaNet/blob/master/utils/quant_dorefa.py and
# https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/dorefa.py#L25

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class SwitchBatchNorm2d(nn.Module):
    """Switchable BatchNorm2d: one independent ``nn.BatchNorm2d`` per bit width.

    Adapted from https://github.com/JiahuiYu/slimmable_networks

    Args:
        num_features: channel count given to every ``nn.BatchNorm2d``.
        bit_list: bit widths; one BN is created per entry, keyed by ``str(bit)``
            in ``bn_dict``. The active activation/weight width is ``bit_list[-1]``.

    Raises:
        ValueError: if the activation and weight bit widths differ.
    """
    def __init__(self, num_features, bit_list):
        super(SwitchBatchNorm2d, self).__init__()
        self.bit_list = bit_list
        self.bn_dict = nn.ModuleDict()
        for i in self.bit_list:
            self.bn_dict[str(i)] = nn.BatchNorm2d(num_features)

        self.abit = self.bit_list[-1]
        self.wbit = self.bit_list[-1]
        # Both widths are taken from bit_list[-1] here, so this can only trip if
        # future code initialises them independently.
        if self.abit != self.wbit:
            raise ValueError('Currently only support same activation and weight bit width!')

    def forward(self, x):
        # Normalise with the BN belonging to the current activation bit width.
        x = self.bn_dict[str(self.abit)](x)
        return x


def batchnorm2d_fn(bit_list):
    """Return a ``SwitchBatchNorm2d`` subclass whose ``bit_list`` defaults to ``bit_list``.

    The returned class only needs ``num_features``; callers may still pass an
    explicit ``bit_list`` to override the captured default.
    """
    default_bits = bit_list

    class SwitchBatchNorm2d_(SwitchBatchNorm2d):
        def __init__(self, num_features, bit_list=default_bits):
            super().__init__(num_features=num_features, bit_list=bit_list)

    return SwitchBatchNorm2d_


class SwitchBatchNorm1d(nn.Module):
    """Switchable BatchNorm1d: one independent ``nn.BatchNorm1d`` per bit width.

    Adapted from https://github.com/JiahuiYu/slimmable_networks

    Args:
        num_features: feature count given to every ``nn.BatchNorm1d``.
        bit_list: bit widths; one BN is created per entry, keyed by ``str(bit)``
            in ``bn_dict``. The active activation/weight width is ``bit_list[-1]``.

    Raises:
        ValueError: if the activation and weight bit widths differ.
    """
    def __init__(self, num_features, bit_list):
        super(SwitchBatchNorm1d, self).__init__()
        self.bit_list = bit_list
        self.bn_dict = nn.ModuleDict()
        for i in self.bit_list:
            self.bn_dict[str(i)] = nn.BatchNorm1d(num_features)

        self.abit = self.bit_list[-1]
        self.wbit = self.bit_list[-1]
        # Both widths are taken from bit_list[-1] here, so this can only trip if
        # future code initialises them independently.
        if self.abit != self.wbit:
            raise ValueError('Currently only support same activation and weight bit width!')

    def forward(self, x):
        # Normalise with the BN belonging to the current activation bit width.
        x = self.bn_dict[str(self.abit)](x)
        return x


def batchnorm1d_fn(bit_list):
    """Return a ``SwitchBatchNorm1d`` subclass whose ``bit_list`` defaults to ``bit_list``.

    The returned class only needs ``num_features``; callers may still pass an
    explicit ``bit_list`` to override the captured default.
    """
    default_bits = bit_list

    class SwitchBatchNorm1d_(SwitchBatchNorm1d):
        def __init__(self, num_features, bit_list=default_bits):
            super().__init__(num_features=num_features, bit_list=bit_list)

    return SwitchBatchNorm1d_


class qfn(torch.autograd.Function):
    """k-bit uniform rounding with a straight-through gradient.

    Forward: ``round(input * n) / n`` with ``n = 2**k - 1`` (no clipping is done
    here, so inputs are expected to already lie in the range the caller wants).
    Backward: the incoming gradient is passed through unchanged; ``k`` receives
    no gradient.
    """
    @staticmethod
    def forward(ctx, input, k):
        levels = float(2**k - 1)
        scaled = torch.round(input * levels)
        return scaled / levels

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat rounding as identity for gradients.
        return grad_output.clone(), None


class weight_quantize_fn(nn.Module):
    """DoReFa-style weight quantizer with mean-|w| rescaling.

    The bit width ``wbit`` is ``bit_list[-1]``:

    * ``wbit == 32``: weights are ``tanh``-squashed, divided by their max |value|
      and multiplied by ``E = mean(|x|)`` (no rounding).
    * ``1 <= wbit <= 8``: the normalised weights are mapped to [0, 1], rounded to
      ``wbit`` bits with ``qfn``, mapped back to [-1, 1] and multiplied by ``E``.

    ``E`` is detached, so it acts as a constant scale for the backward pass.
    """
    def __init__(self, bit_list):
        super(weight_quantize_fn, self).__init__()
        self.bit_list = bit_list
        self.wbit = self.bit_list[-1]
        # qfn divides by 2**wbit - 1, so wbit must be at least 1.
        assert 1 <= self.wbit <= 8 or self.wbit == 32

    def forward(self, x):
        E = torch.mean(torch.abs(x)).detach()
        weight = torch.tanh(x)
        max_abs = torch.max(torch.abs(weight))
        # An all-zero weight tensor would otherwise give 0 / 0 = NaN (and NaN
        # gradients). Dividing by 1 instead leaves it at zero; nonzero inputs are
        # unaffected because the original max is kept whenever it is positive.
        max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
        if self.wbit == 32:
            weight = weight / max_abs
            weight_q = weight * E
        else:
            # [-1, 1] -> [0, 1], quantize, then back to [-1, 1].
            weight = weight / 2 / max_abs + 0.5
            weight_q = 2 * qfn.apply(weight, self.wbit) - 1
            weight_q = weight_q * E
        return weight_q


class activation_quantize_fn(nn.Module):
    """Activation quantizer: ``abit = bit_list[-1]`` bits via ``qfn``, or identity at 32 bits.

    No clipping is applied here. NOTE(review): callers presumably feed inputs that
    are already in [0, 1]; confirm at the call sites.
    """
    def __init__(self, bit_list):
        super(activation_quantize_fn, self).__init__()
        self.bit_list = bit_list
        self.abit = self.bit_list[-1]
        assert self.abit <= 8 or self.abit == 32

    def forward(self, x):
        # Full precision: hand the input back untouched.
        if self.abit == 32:
            return x
        return qfn.apply(x, self.abit)


class Conv2d_Q(nn.Conv2d):
    """Plain ``nn.Conv2d`` subclass used as the base of the quantized conv classes.

    Every positional and keyword argument is forwarded unchanged to ``nn.Conv2d``.
    """
    def __init__(self, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)


def conv2d_quantize_fn(bit_list):
    """Return a conv class whose weights are quantized on every forward pass.

    The returned ``Conv2d_Q_`` closes over ``bit_list``: it stores it, records the
    active width ``bit_list[-1]`` as ``w_bit`` and runs ``self.weight`` through
    ``weight_quantize_fn(bit_list)`` before calling ``F.conv2d``. The bias is used
    unquantized. ``padding_mode`` is not exposed, so padding is always zeros.
    """
    class Conv2d_Q_(Conv2d_Q):
        def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                     bias=True):
            super().__init__(in_channels, out_channels, kernel_size,
                             stride=stride, padding=padding, dilation=dilation,
                             groups=groups, bias=bias)
            self.bit_list = bit_list
            self.w_bit = bit_list[-1]
            self.quantize_fn = weight_quantize_fn(bit_list)

        def forward(self, input, order=None):
            # `order` is accepted but not used.
            quantized = self.quantize_fn(self.weight)
            return F.conv2d(input, quantized, bias=self.bias, stride=self.stride,
                            padding=self.padding, dilation=self.dilation, groups=self.groups)

    return Conv2d_Q_


class Linear_Q(nn.Linear):
    """Plain ``nn.Linear`` subclass used as the base of the quantized linear classes.

    Every positional and keyword argument is forwarded unchanged to ``nn.Linear``.
    """
    def __init__(self, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)


def linear_quantize_fn(bit_list):
    """Return a linear class whose weights are quantized on every forward pass.

    The returned ``Linear_Q_`` closes over ``bit_list``: it stores it, records the
    active width ``bit_list[-1]`` as ``w_bit`` and runs ``self.weight`` through
    ``weight_quantize_fn(bit_list)`` before ``F.linear``. The bias is used unquantized.
    """
    class Linear_Q_(Linear_Q):
        def __init__(self, in_features, out_features, bias=True):
            super().__init__(in_features, out_features, bias=bias)
            self.bit_list = bit_list
            self.w_bit = bit_list[-1]
            self.quantize_fn = weight_quantize_fn(bit_list)

        def forward(self, input):
            quantized = self.quantize_fn(self.weight)
            return F.linear(input, quantized, bias=self.bias)

    return Linear_Q_


# Default batch-norm factory: an alias for the 2-d switchable variant (batchnorm2d_fn).
batchnorm_fn = batchnorm2d_fn




In [2]:
# Build a grouped 4-bit quantized conv (16 -> 16 channels, 2 groups) and push a
# random input through it to check the output and weight shapes.
torch.manual_seed(0)  # make the random weights and input reproducible across runs
conv2d = conv2d_quantize_fn([4])
conv_org = conv2d(in_channels=16, out_channels=8*2, kernel_size=3, stride=1, padding=1, bias=False, groups=2)

# initialize random input tensor (`input` shadows the builtin, but later cells use this name)
input = torch.rand(1, 16, 32, 32)

# forward pass
output = conv_org(input)
print(output.shape)
print(conv_org.weight.shape)

torch.Size([1, 16, 32, 32])
torch.Size([16, 8, 3, 3])


In [6]:
# Low-rank factorisation 16 -> rank -> 16 channels using two stacked 3x3 quantized convs.
rank = 8
conv_R = conv2d(in_channels=16, out_channels=rank, kernel_size=3, stride=1, padding=1, bias=False)
conv_L = conv2d(in_channels=rank, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False)

# Forward pass through both factors.
output = conv_L(conv_R(input))
print(output.shape)

torch.Size([1, 16, 32, 32])


In [4]:
class Conv2d_Q_LR(nn.Module):
    """Base class for the low-rank quantized conv built by ``conv2d_lr_quantize_fn``.

    Arguments are forwarded verbatim to ``nn.Module.__init__``.
    """
    def __init__(self, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)

def conv2d_lr_quantize_fn(bit_list):
    """Return a quantized "low-rank" conv module class bound to ``bit_list``.

    Instances of the returned ``Conv2d_Q_LR_`` build their sub-convs with
    ``conv2d_quantize_fn(bit_list)`` (all with ``bias=False``):

    * ``stride == 1``: ``conv_R`` maps ``in_channels`` to
      ``self.rank * self.groups`` (= ``in_channels // 2 * 4``) channels, grouped
      by the constructor's ``groups`` argument. ``conv_L`` then maps those back to
      ``out_channels``. Both use the given ``kernel_size`` and ``padding``, so two
      spatial kernels are stacked.
    * any other stride: a single ``conv_ds`` from ``in_channels`` to ``out_channels``.

    NOTE(review): the ``rank``, ``dilation`` and ``bias`` constructor arguments are
    accepted but ignored. Rank is always ``in_channels // 2``, the width multiplier
    ``self.groups`` is fixed at 4, dilation stays at the sub-conv default and no
    bias is used. TODO confirm this is intended before relying on ``rank=``.
    """
    class Conv2d_Q_LR_(Conv2d_Q_LR):
        def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                     bias=True, rank=8):
            super().__init__()
            self.stride = stride
            # NOTE(review): the `rank` argument is not used here.
            self.rank = in_channels // 2
            # Width multiplier for the hidden layer -- distinct from the `groups`
            # argument, which only sets conv_R's channel grouping.
            self.groups = 4
            conv2d = conv2d_quantize_fn(bit_list)
            shared = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
            if stride != 1:
                self.conv_ds = conv2d(in_channels=in_channels, out_channels=out_channels, **shared)
                return
            hidden = self.rank * self.groups
            self.conv_R = conv2d(in_channels=in_channels, out_channels=hidden, groups=groups, **shared)
            self.conv_L = conv2d(in_channels=hidden, out_channels=out_channels, **shared)

        def forward(self, x):
            if self.stride != 1:
                return self.conv_ds(x)
            return self.conv_L(self.conv_R(x))

    return Conv2d_Q_LR_

In [3]:
# Instantiate the 4-bit low-rank conv (stride 1 -> conv_R then conv_L) and check its output shape.
conv_lr = conv2d_lr_quantize_fn([4])
conv0 = conv_lr(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False)
output = conv0(input)
print(output.shape)

NameError: name 'conv2d_lr_quantize_fn' is not defined

In [13]:
# Display the module tree of the low-rank conv built above.
# NOTE(review): the saved output shows conv_R as 16 -> 8 channels, which predates the
# current Conv2d_Q_LR_ definition (in_channels // 2 * 4 = 32 hidden channels); re-run to refresh.
conv0

Conv2d_Q_LR_(
  (conv_R): Conv2d_Q_(
    16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    (quantize_fn): weight_quantize_fn()
  )
  (conv_L): Conv2d_Q_(
    8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    (quantize_fn): weight_quantize_fn()
  )
)

In [5]:
# Same 16 -> rank -> 16 factorisation, but the first conv is split into 2 channel groups.
rank = 8
conv_R = conv2d(in_channels=16, out_channels=rank, kernel_size=3, stride=1, padding=1, bias=False, groups=2)
conv_L = conv2d(in_channels=rank, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False)

# Forward pass through both factors.
output = conv_L(conv_R(input))
print(output.shape)

torch.Size([1, 16, 32, 32])


[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.


In [30]:
# Grouped first factor: 4 groups, each emitting `rank` channels (rank * groups outputs in total).
rank = 8
groups = 4
conv_R = conv2d(in_channels=16, out_channels=rank*groups, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
print(conv_R.weight.shape)

torch.Size([32, 4, 3, 3])


In [27]:
# Count the trainable parameters of the conv_R defined above.
print(sum(param.numel() for param in conv_R.parameters() if param.requires_grad))

1152


In [23]:
# Inspect conv_R's weight shape.
# NOTE(review): the saved output ([16, 8, 3, 3]) is stale -- it came from an earlier conv_R;
# the grouped conv_R defined above has weight shape [32, 4, 3, 3]. Re-run to refresh.
conv_R.weight.shape

torch.Size([16, 8, 3, 3])

In [24]:
# Full-rank baseline: a single dense 16 -> 16 quantized 3x3 conv (no factorisation),
# so the unused `rank` assignment is dropped.
conv_R = conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False, groups=1)

output = conv_R(input)
print(output.shape)

torch.Size([1, 16, 32, 32])




PreActResNet(
  (conv0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layers): ModuleList(
    (0-2): 3 x PreActBasicBlockQ(
      (bn0): SwitchBatchNorm2d_(
        (bn_dict): ModuleDict(
          (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (act0): Activate(
        (acti): ReLU(inplace=True)
        (quan): activation_quantize_fn()
      )
      (conv0): Conv2d_Q_LR_(
        (conv_R): Conv2d_Q_(
          16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False
          (quantize_fn): weight_quantize_fn()
        )
        (conv_L): Conv2d_Q_(
          32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (quantize_fn): weight_quantize_fn()
        )
      )
      (bn1): SwitchBatchNorm2d_(
        (bn_dict): ModuleDict(
          (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
 

In [5]:
import models

# Build the 4-bit ResNet-20 from the project's `models` package and count its trainable parameters.
# NOTE(review): the recorded run failed with
#   TypeError: __init__() got multiple values for argument 'num_classes'
# presumably because resnet20q already passes num_classes to the underlying constructor;
# check its signature in `models` before passing num_classes here.
model_fn = models.__dict__['resnet20q']
model = model_fn(bit_list=[4], rank=2, groups=1, num_classes=10)
print(sum(param.numel() for param in model.parameters() if param.requires_grad))

TypeError: __init__() got multiple values for argument 'num_classes'