In [2]:
import argparse
import os
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
   

import torchvision
import torchvision.transforms as transforms

from models import *



def act_quantization(b):
    """Build a b-bit uniform quantizer for activations.

    Args:
        b: number of quantization bits; the normalized range [0, 1] is split
            into ``2**b - 1`` equal steps.

    Returns:
        A callable ``quant(input, alpha)`` that divides ``input`` by the
        clipping value ``alpha``, clips the result from above at 1, rounds it
        to the b-bit grid and multiplies back by ``alpha``. ``alpha`` may be a
        Python number or a tensor broadcastable against ``input``. Intended
        for non-negative activations: values below 0 are not clamped.

        Gradients use a straight-through estimator: ``input`` receives the
        incoming gradient where it was not clipped and 0 where it was; a tensor
        ``alpha`` that requires grad receives 1 per clipped element and
        ``q - input/alpha`` per unclipped element (``q`` = normalized quantized
        value), summed back to ``alpha``'s shape.
    """

    def uniform_quant(x, b=3):
        # Snap x to the nearest multiple of 1/(2**b - 1); b is the bit width.
        xdiv = x.mul(2 ** b - 1)
        xhard = xdiv.round().div(2 ** b - 1)
        return xhard

    class uq(torch.autograd.Function):
        # forward/backward must be static: autograd invokes them on the class.

        @staticmethod
        def forward(ctx, input, alpha):
            input_d = input / alpha
            # Mingu edited for AlexNet: clip from above only; the lower side is left unclamped.
            input_c = input_d.clamp(max=1)
            input_q = uniform_quant(input_c, b)
            # Keep the normalized values: backward needs them to tell clipped from unclipped elements.
            ctx.save_for_backward(input_d, input_q)
            ctx.alpha_shape = alpha.shape if torch.is_tensor(alpha) else torch.Size([])
            return input_q.mul(alpha)

        @staticmethod
        def backward(ctx, grad_output):
            input_d, input_q = ctx.saved_tensors
            clipped = input_d > 1

            grad_input = None
            if ctx.needs_input_grad[0]:
                # Straight-through round(); no gradient flows where the input was clipped.
                grad_input = torch.where(clipped, torch.zeros_like(grad_output), grad_output)

            grad_alpha = None
            if ctx.needs_input_grad[1]:  # False when alpha is a plain Python number
                # d(out)/d(alpha): 1 where clipped (out == alpha), q - input/alpha elsewhere.
                dalpha = torch.where(clipped, torch.ones_like(input_q), input_q - input_d)
                grad_alpha = grad_output * dalpha
                if len(ctx.alpha_shape) == 0:
                    grad_alpha = grad_alpha.sum()
                else:
                    grad_alpha = grad_alpha.sum_to_size(ctx.alpha_shape)
            return grad_input, grad_alpha

    # Use the class-level apply; instantiating an autograd Function is deprecated.
    return uq.apply




def weight_quantization(b):
    """Build a sign-magnitude uniform quantizer for weights.

    Args:
        b: number of magnitude bits (the sign is handled separately); the
            normalized magnitude range [0, 1] is split into ``2**b - 1`` steps.

    Returns:
        A callable ``quant(input, alpha)`` that divides ``input`` by the
        clipping value ``alpha``, clips to [-1, 1], rounds the magnitude to the
        b-bit grid, restores the sign and multiplies back by ``alpha``.
        ``alpha`` may be a Python number or a tensor broadcastable against
        ``input``.

        Gradients use a straight-through estimator: ``input`` receives the
        incoming gradient unchanged (clipped elements included); a tensor
        ``alpha`` that requires grad receives ``sign(input)`` per clipped
        element and ``q - input/alpha`` per unclipped element (``q`` =
        normalized quantized value), summed back to ``alpha``'s shape.
    """

    def uniform_quant(x, b):
        # Snap x to the nearest multiple of 1/(2**b - 1).
        xdiv = x.mul((2 ** b - 1))
        xhard = xdiv.round().div(2 ** b - 1)
        return xhard

    class uq(torch.autograd.Function):
        # forward/backward must be static: autograd invokes them on the class.

        @staticmethod
        def forward(ctx, input, alpha):
            input_d = input / alpha                          # weights are first divided by alpha
            input_c = input_d.clamp(min=-1, max=1)           # then clipped to [-1, 1]
            sign = input_c.sign()                            # -1, 0 or +1, e.g. 0.8 -> +1, -0.6 -> -1
            input_abs = input_c.abs()
            input_q = uniform_quant(input_abs, b).mul(sign)  # quantize the magnitude, then restore the sign
            # Keep the normalized values: backward needs them to tell clipped from unclipped elements.
            ctx.save_for_backward(input_d, input_q)
            ctx.alpha_shape = alpha.shape if torch.is_tensor(alpha) else torch.Size([])
            return input_q.mul(alpha)                        # rescale to the original range

        @staticmethod
        def backward(ctx, grad_output):
            input_d, input_q = ctx.saved_tensors
            clipped = input_d.abs() > 1

            # Weight gradients are passed straight through, even for clipped weights.
            grad_input = grad_output.clone() if ctx.needs_input_grad[0] else None

            grad_alpha = None
            if ctx.needs_input_grad[1]:  # False when alpha is a plain Python number
                # d(out)/d(alpha): sign(input) where clipped (out == +/-alpha), q - input/alpha elsewhere.
                dalpha = torch.where(clipped, input_d.sign(), input_q - input_d)
                grad_alpha = grad_output * dalpha
                if len(ctx.alpha_shape) == 0:
                    grad_alpha = grad_alpha.sum()
                else:
                    grad_alpha = grad_alpha.sum_to_size(ctx.alpha_shape)
            return grad_input, grad_alpha

    # Use the class-level apply; instantiating an autograd Function is deprecated.
    return uq.apply



class weight_quantize_fn(nn.Module):
    """nn.Module wrapper that quantizes a weight tensor with ``weight_quantization``.

    Args:
        w_bit: total weight bit width, including the sign bit. Must be >= 2 so
            that at least one magnitude bit remains.

    Attributes:
        w_bit: number of magnitude bits (``w_bit - 1``).
        weight_q: the quantizer callable built by ``weight_quantization``.
        wgt_alpha: clipping value passed to the quantizer. It starts as the
            placeholder 0.0; assign a positive value (float or tensor) before
            calling ``forward``, because the quantizer divides by it.
    """

    def __init__(self, w_bit):
        super().__init__()
        if w_bit < 2:
            # With 0 magnitude bits the grid has 2**0 - 1 = 0 steps and the quantizer divides by zero.
            raise ValueError(
                f"w_bit must be >= 2 (1 sign bit + at least 1 magnitude bit), got {w_bit}"
            )
        self.w_bit = w_bit - 1    # one bit is reserved for the sign
        self.weight_q = weight_quantization(b=self.w_bit)
        self.wgt_alpha = 0.0      # placeholder; must be replaced by a positive clipping value

    def forward(self, weight):
        """Return ``weight`` quantized with clipping value ``self.wgt_alpha``."""
        return self.weight_q(weight, self.wgt_alpha)
    
    

In [13]:
# Quantization and layer configuration.
w_alpha = 4.0  # clipping value for weights
w_bits = 16    # weight bit width incl. sign bit; more bits -> smaller gap to the float reference psum_ref

x_alpha = 4.0  # clipping value for activations
x_bits = 16    # activation bit width (x is non-negative, so no sign bit)

n_batch = 10
n_ch_in = 4    # arbitrary number of input channels
n_ch_out = 8

k_size = 3
img_size = 12  # height and width of the input feature map

seed = 0
torch.manual_seed(seed)  # makes x and the conv weights drawn in later cells reproducible on Restart & Run All

## (batch size, input channel size, height, width)
x = torch.abs(torch.randn(n_batch, n_ch_in, img_size, img_size))  # non-negative, i.e. unsigned activations
print("x size is:", x.size())

x size is: torch.Size([10, 4, 12, 12])


In [14]:
act_quant = act_quantization(b=x_bits)   # activation quantization function
x_quant = act_quant(x, x_alpha)          # quantized values on the original real scale ("dirty")
x_delta = x_alpha / (2**x_bits - 1)      # step size (resolution) of the activation grid
# Integer grid index ("clean"). It counts steps of size x_delta rather than real
# units, which is why its values are far larger than x_alpha. round() strips the
# floating-point residue of the division so the levels are whole numbers.
x_int = (x_quant / x_delta).round()
# x is non-negative (torch.abs in the setup cell) and clipped at x_alpha, so levels lie in [0, 2**x_bits - 1].
assert 0 <= x_int.min().item() and x_int.max().item() <= 2**x_bits - 1

print("X with integer:", x_int)

X with integer: tensor([[[[3.9339e+04, 1.7948e+04, 9.1600e+03,  ..., 3.9390e+03,
           7.9170e+03, 2.2770e+03],
          [1.4648e+04, 3.5550e+03, 2.0690e+04,  ..., 1.3698e+04,
           2.3736e+04, 1.7148e+04],
          [2.0469e+04, 3.0629e+04, 7.0190e+03,  ..., 2.5180e+04,
           2.3105e+04, 1.8448e+04],
          ...,
          [1.7123e+04, 1.5164e+04, 1.0400e+03,  ..., 3.0580e+03,
           1.0159e+04, 5.8410e+03],
          [1.5765e+04, 1.7407e+04, 4.2180e+03,  ..., 8.8030e+03,
           1.4459e+04, 6.1470e+03],
          [1.2250e+03, 9.0100e+02, 4.1850e+03,  ..., 6.0380e+03,
           2.2350e+03, 1.7600e+03]],

         [[4.1800e+03, 3.7800e+02, 2.0440e+03,  ..., 3.1460e+03,
           1.8730e+03, 3.4325e+04],
          [5.5740e+03, 1.7575e+04, 1.5406e+04,  ..., 1.6530e+03,
           2.2086e+04, 6.1120e+03],
          [6.6780e+03, 3.1460e+03, 1.3193e+04,  ..., 5.0340e+03,
           5.2770e+03, 1.1012e+04],
          ...,
          [5.1350e+03, 1.1625e+04, 6.9060e+

In [15]:

### Initialize conv module
conv = nn.Conv2d(n_ch_in, n_ch_out, kernel_size=k_size, padding=1, bias=False)
conv.weight = torch.nn.Parameter(torch.randn(n_ch_out, n_ch_in, k_size, k_size))
##########################


weight_quant = weight_quantize_fn(w_bit=w_bits)  # signed quantizer: 1 sign bit + (w_bits - 1) magnitude bits
weight_quant.wgt_alpha = torch.tensor(w_alpha)   # clipping value; must be set before use
w_quant = weight_quant(conv.weight)              # quantized weights on the original real scale
w_delta = w_alpha / (2**(w_bits - 1) - 1)        # step size; one of the w_bits is the sign
# Signed integer grid index; round() strips the floating-point residue of the division.
w_int = (w_quant / w_delta).round()
assert w_int.abs().max().item() <= 2**(w_bits - 1) - 1

print("W with integer:", w_int)

W with integer: tensor([[[[ 1.1490e+03,  1.4540e+03, -9.3150e+03],
          [ 4.1170e+03, -6.4010e+03, -9.5460e+03],
          [ 4.6700e+02,  2.0972e+04, -2.2750e+03]],

         [[ 7.9550e+03, -1.4832e+04, -3.4870e+03],
          [-1.7376e+04,  7.3940e+03,  3.1790e+03],
          [ 3.9040e+03, -8.4920e+03, -1.5880e+03]],

         [[ 5.1200e+03, -1.4087e+04,  1.0800e+04],
          [ 1.1202e+04, -1.7535e+04,  6.7800e+03],
          [-5.9730e+03, -5.3700e+03, -1.5040e+03]],

         [[ 1.7400e+03, -3.6430e+03, -8.7300e+02],
          [ 1.8028e+04,  1.8210e+03,  3.5940e+03],
          [-1.9410e+04, -1.1909e+04,  9.3900e+02]]],


        [[[ 6.6100e+02, -5.7700e+02,  4.0500e+02],
          [ 1.4080e+03,  2.8220e+03,  1.4910e+03],
          [-2.0400e+02, -6.1620e+03, -1.2160e+03]],

         [[ 7.3890e+03, -2.9520e+03,  9.9250e+03],
          [-3.8660e+03, -2.4800e+03,  2.4319e+04],
          [-6.8390e+03,  7.8590e+03, -1.3855e+04]],

         [[ 5.8730e+03,  6.6080e+03,  1.0071e+04],
 

In [16]:
###### Same computation as the red box in the slide #######

# Integer-only convolution: same geometry as `conv`, but its kernel holds the integer weight levels.
conv2 = nn.Conv2d(
    in_channels=n_ch_in,
    out_channels=n_ch_out,
    kernel_size=k_size,
    padding=1,
    bias=False,
)
conv2.weight = nn.Parameter(w_int)  # replace the random init with the quantized integer weights

psum_int = conv2(x_int)  # partial sums of x_int * w_int
print("psum int:", psum_int)

psum int: tensor([[[[-5.8229e+08, -6.3274e+08, -5.3181e+08,  ...,  4.9634e+07,
            2.3709e+08,  1.8783e+08],
          [-3.2294e+08,  3.2993e+08, -7.7268e+08,  ...,  7.1436e+07,
            2.6376e+08, -9.0951e+08],
          [-6.2128e+08, -4.1305e+08, -1.8223e+08,  ..., -7.3012e+08,
           -1.5842e+09,  5.3108e+08],
          ...,
          [-4.9216e+08, -7.8586e+08,  5.6054e+07,  ..., -1.5759e+09,
           -6.6603e+08, -9.9985e+08],
          [-7.7521e+08, -2.3644e+08, -1.6069e+09,  ..., -6.1591e+08,
           -5.1096e+08, -1.9414e+08],
          [-4.0294e+08,  2.7181e+08,  6.9042e+08,  ..., -2.5426e+07,
            1.9030e+08, -5.7255e+06]],

         [[-4.1943e+08,  1.9977e+08,  6.2721e+08,  ..., -5.5880e+08,
            8.8233e+08, -4.0781e+07],
          [ 2.0058e+08,  1.0980e+08,  2.6020e+08,  ...,  5.2806e+08,
           -3.5484e+08,  2.0122e+08],
          [ 2.8230e+08,  4.3490e+08,  5.6692e+08,  ...,  8.2950e+08,
           -5.9904e+08,  3.5986e+08],
          

In [17]:
# Undo both scalings: every integer product carries a factor of w_delta * x_delta.
psum_recovered = psum_int.mul(w_delta).mul(x_delta)
print("psum recovered:", psum_recovered)

psum recovered: tensor([[[[-4.3386e+00, -4.7145e+00, -3.9624e+00,  ...,  3.6982e-01,
            1.7665e+00,  1.3995e+00],
          [-2.4062e+00,  2.4583e+00, -5.7572e+00,  ...,  5.3226e-01,
            1.9652e+00, -6.7767e+00],
          [-4.6291e+00, -3.0776e+00, -1.3578e+00,  ..., -5.4401e+00,
           -1.1803e+01,  3.9570e+00],
          ...,
          [-3.6670e+00, -5.8554e+00,  4.1766e-01,  ..., -1.1742e+01,
           -4.9625e+00, -7.4498e+00],
          [-5.7761e+00, -1.7617e+00, -1.1973e+01,  ..., -4.5891e+00,
           -3.8071e+00, -1.4466e+00],
          [-3.0023e+00,  2.0253e+00,  5.1442e+00,  ..., -1.8945e-01,
            1.4179e+00, -4.2661e-02]],

         [[-3.1251e+00,  1.4884e+00,  4.6733e+00,  ..., -4.1636e+00,
            6.5742e+00, -3.0386e-01],
          [ 1.4945e+00,  8.1815e-01,  1.9387e+00,  ...,  3.9345e+00,
           -2.6439e+00,  1.4993e+00],
          [ 2.1034e+00,  3.2404e+00,  4.2241e+00,  ...,  6.1805e+00,
           -4.4634e+00,  2.6813e+00],
    

In [18]:
#### floating point system result
psum_ref = conv(x)  # float convolution on the unquantized x and weights
#################################

# Quantization error of the integer path; the max magnitude summarizes it in one number.
gap = psum_ref - psum_recovered
print("gap", gap)
print("max |gap|:", gap.abs().max().item())

gap tensor([[[[-1.0109e-04, -2.2888e-05,  2.9802e-05,  ...,  2.8515e-04,
            1.4544e-04,  2.6393e-04],
          [-1.9097e-04, -2.3913e-04, -2.1458e-04,  ..., -1.0270e-04,
            2.7764e-04,  2.1410e-04],
          [-6.1035e-05, -1.1945e-04, -4.9233e-05,  ...,  7.2002e-05,
            1.4973e-04,  1.0252e-04],
          ...,
          [-5.2214e-05, -9.5367e-07,  5.0604e-04,  ...,  4.3869e-05,
           -6.1989e-05,  2.9421e-04],
          [ 3.0041e-05, -6.9737e-05,  2.4986e-04,  ...,  1.6642e-04,
            2.9755e-04,  1.8132e-04],
          [ 2.0719e-04,  3.3617e-04,  2.7466e-04,  ...,  2.4782e-04,
            5.7220e-06,  2.9951e-04]],

         [[ 1.1301e-04,  3.0160e-05, -2.0027e-04,  ..., -2.2697e-04,
            1.2016e-04,  9.8467e-05],
          [ 2.0564e-04,  5.2452e-06,  3.2413e-04,  ...,  8.8215e-06,
            5.2214e-05, -1.0216e-04],
          [ 5.4598e-05,  1.7643e-05, -2.7657e-04,  ..., -3.1948e-05,
           -5.3883e-05,  1.2684e-04],
          ...,
 