In [1]:
import argparse
import os
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
   

import torchvision
import torchvision.transforms as transforms

from models import *



def act_quantization(b):

    def uniform_quant(x, b=3):
        xdiv = x.mul(2 ** b - 1)
        xhard = xdiv.round().div(2 ** b - 1)
        return xhard

    class uq(torch.autograd.Function):   # here single underscore means this class is for internal use

        def forward(ctx, input, alpha):
            input_d = input/alpha
            input_c = input_d.clamp(max=1)  # Mingu edited for Alexnet
            input_q = uniform_quant(input_c, b)
            ctx.save_for_backward(input, input_q)
            input_q_out = input_q.mul(alpha)
            return input_q_out

    return uq().apply




def weight_quantization(b):

    def uniform_quant(x, b):
        xdiv = x.mul((2 ** b - 1))
        xhard = xdiv.round().div(2 ** b - 1)  
        return xhard

    class uq(torch.autograd.Function):

        def forward(ctx, input, alpha):
            input_d = input/alpha                          # weights are first divided by alpha                       
            input_c = input_d.clamp(min=-1, max=1)       # then clipped to [-1,1]
            sign = input_c.sign()
            input_abs = input_c.abs()
            input_q = uniform_quant(input_abs, b).mul(sign)
            ctx.save_for_backward(input, input_q)
            input_q_out = input_q.mul(alpha)               # rescale to the original range
            return input_q_out

    return uq().apply



class weight_quantize_fn(nn.Module):
    def __init__(self, w_bit):
        super(weight_quantize_fn, self).__init__()
        self.w_bit = w_bit-1
        self.weight_q = weight_quantization(b=self.w_bit)
        self.wgt_alpha = 0.0

    def forward(self, weight):
        weight_q = self.weight_q(weight, self.wgt_alpha)
        
        return weight_q
    
    

In [2]:

w_alpha = 4.0  # cliping value
w_bits = 8

x_alpha = 4.0  # clipping value
x_bits = 8

n_batch = 10
n_ch_in = 4
n_ch_out = 8

k_size = 3


x =torch.abs(torch.randn(n_batch,n_ch_in,12,12))  # Note X is unsigned number
act_quant = act_quantization(b= x_bits)  # define activation quantization function
x_quant = act_quant(x, x_alpha)   
x_delta = x_alpha/(2**x_bits-1)   # resolution
x_int   = x_quant/x_delta

#print("X_quant:", x_quant)
#print("resolution:", x_delta)
print("X with integer:", x_int) # why this value is bigger than alpha ?

X with integer: tensor([[[[ 70.0000,  50.0000,  67.0000,  ...,   7.0000,  43.0000, 101.0000],
          [ 38.0000,   2.0000,  81.0000,  ...,  69.0000,  40.0000,  35.0000],
          [ 27.0000,  55.0000,  60.0000,  ...,  66.0000,  69.0000,  41.0000],
          ...,
          [  5.0000,  74.0000,  56.0000,  ...,  62.0000,  32.0000,   2.0000],
          [ 26.0000,  25.0000, 130.0000,  ...,   1.0000,   1.0000,  69.0000],
          [ 21.0000,  13.0000,  11.0000,  ...,  22.0000, 125.0000,  50.0000]],

         [[124.0000,  10.0000,  48.0000,  ..., 105.0000,  26.0000,  55.0000],
          [ 57.0000, 118.0000,  20.0000,  ...,  77.0000,  62.0000,  33.0000],
          [117.0000,  58.0000,  13.0000,  ...,  29.0000,  35.0000, 101.0000],
          ...,
          [ 64.0000, 123.0000, 163.0000,  ...,  53.0000,  38.0000,  43.0000],
          [ 46.0000,  50.0000,  60.0000,  ...,  34.0000,  32.0000,  37.0000],
          [ 20.0000,  61.0000,  60.0000,  ...,  32.0000,  68.0000,  70.0000]],

         [[ 49

In [3]:

### Initialize conv module
conv = nn.Conv2d(n_ch_in, n_ch_out, kernel_size=k_size, padding=1, bias=False)
conv.weight  = torch.nn.Parameter(torch.randn(n_ch_out,n_ch_in,k_size,k_size))
##########################


weight_quant = weight_quantize_fn(w_bit= w_bits)  ## define quant function
weight_quant.wgt_alpha = torch.tensor(w_alpha)
w_quant      = weight_quant(conv.weight)
w_delta      = w_alpha/(2**(w_bits-1)-1)
w_int        = w_quant/w_delta

print("W with integer:", w_int)

W with integer: tensor([[[[-48., -30.,  -2.],
          [-18.,  23.,  48.],
          [ -5.,   6.,  37.]],

         [[-12.,  -6., -54.],
          [ -5.,  23.,  86.],
          [-39.,   7.,  56.]],

         [[ 36., -18.,  14.],
          [-28.,  39., -26.],
          [-74., -11.,  -5.]],

         [[ 45.,  22.,  26.],
          [ 19.,  26.,  23.],
          [ -3.,  24.,   2.]]],


        [[[-40.,  -8., -38.],
          [ 13.,   2.,  22.],
          [-93.,  45.,  -7.]],

         [[-59.,  27.,  32.],
          [-27., -48.,  22.],
          [  6.,  47.,  18.]],

         [[-65.,  -1.,  14.],
          [ -9., -43.,  42.],
          [ 31.,  30., -15.]],

         [[-39.,   4.,  21.],
          [  2., -47.,  13.],
          [ 35.,  11.,  83.]]],


        [[[ 25., -31., -12.],
          [-46.,  19., -43.],
          [-10.,  54., -18.]],

         [[ -8., -36.,  24.],
          [  4.,  71.,  -8.],
          [ 29.,  -4., -26.]],

         [[ 19., -41.,  -4.],
          [-32.,  15.,  -2.],


In [4]:
###### This cell is the same as red box in the slide #######

conv2 = nn.Conv2d(n_ch_in, n_ch_out, kernel_size=k_size, padding=1, bias=False)
conv2.weight = torch.nn.Parameter(w_int)

psum_int = conv2(x_int)
print("psum int:", psum_int)

psum int: tensor([[[[ 2.0785e+04,  1.4239e+04,  4.6160e+03,  ...,  5.1150e+03,
            2.8030e+03, -3.0940e+03],
          [ 2.2846e+04,  5.2200e+03,  5.2180e+03,  ...,  1.0964e+04,
           -5.1400e+02, -1.3337e+04],
          [ 1.2256e+04,  5.9640e+03,  7.7000e+03,  ...,  4.1050e+03,
            1.2905e+04, -4.4300e+03],
          ...,
          [ 1.9132e+04,  3.2302e+04,  1.9908e+04,  ...,  7.9330e+03,
           -3.6200e+02,  4.5200e+02],
          [ 9.9170e+03,  1.0081e+04,  5.7300e+03,  ...,  1.4829e+04,
            6.6230e+03, -6.0600e+02],
          [ 7.0360e+03,  1.0105e+04,  9.3180e+03,  ...,  2.5265e+04,
            1.9750e+04,  2.4360e+03]],

         [[ 9.5320e+03, -2.4100e+02,  1.1360e+04,  ...,  1.4610e+03,
            4.5050e+03, -2.5800e+02],
          [ 2.0151e+04, -1.0712e+04, -1.3960e+04,  ..., -5.0200e+02,
            1.1120e+03, -4.0510e+03],
          [ 7.9530e+03, -1.5708e+04, -1.0224e+04,  ...,  1.1145e+04,
           -1.3439e+04, -1.7527e+04],
          

In [5]:
psum_recovered = psum_int*w_delta*x_delta
print("psum recovered:", psum_recovered)

psum recovered: tensor([[[[ 1.0269e+01,  7.0349e+00,  2.2806e+00,  ...,  2.5271e+00,
            1.3848e+00, -1.5286e+00],
          [ 1.1287e+01,  2.5790e+00,  2.5780e+00,  ...,  5.4168e+00,
           -2.5394e-01, -6.5892e+00],
          [ 6.0551e+00,  2.9465e+00,  3.8042e+00,  ...,  2.0281e+00,
            6.3758e+00, -2.1887e+00],
          ...,
          [ 9.4523e+00,  1.5959e+01,  9.8357e+00,  ...,  3.9193e+00,
           -1.7885e-01,  2.2331e-01],
          [ 4.8996e+00,  4.9806e+00,  2.8309e+00,  ...,  7.3264e+00,
            3.2721e+00, -2.9940e-01],
          [ 3.4762e+00,  4.9924e+00,  4.6036e+00,  ...,  1.2482e+01,
            9.7576e+00,  1.2035e+00]],

         [[ 4.7093e+00, -1.1907e-01,  5.6125e+00,  ...,  7.2182e-01,
            2.2257e+00, -1.2747e-01],
          [ 9.9557e+00, -5.2923e+00, -6.8970e+00,  ..., -2.4802e-01,
            5.4939e-01, -2.0014e+00],
          [ 3.9292e+00, -7.7606e+00, -5.0512e+00,  ...,  5.5063e+00,
           -6.6396e+00, -8.6593e+00],
    

In [6]:
#### floating point system result
psum_ref = conv(x)
#################################

print("gap", psum_ref -psum_recovered)

gap tensor([[[[ 7.4270e-02, -7.1104e-02, -4.7727e-02,  ..., -3.4266e-02,
           -1.6858e-02, -5.5637e-02],
          [ 4.5961e-02,  1.8716e-02, -5.3405e-02,  ...,  9.0661e-03,
           -3.8610e-02, -2.6696e-02],
          [ 1.6518e-02,  3.2850e-02,  2.6267e-02,  ...,  2.3345e-02,
           -2.1091e-02,  2.4374e-02],
          ...,
          [ 8.7528e-02,  7.3021e-02,  6.7367e-02,  ...,  4.2431e-03,
            7.6391e-02, -1.2729e-02],
          [ 5.9956e-02,  5.4739e-02,  3.0748e-02,  ...,  1.4976e-02,
            3.1790e-02, -2.0675e-02],
          [ 7.7489e-02,  7.3452e-02,  2.9769e-02,  ..., -9.4147e-03,
            5.0355e-02,  3.4674e-02]],

         [[ 6.1777e-02, -5.3184e-02,  4.6600e-02,  ...,  5.5040e-02,
            3.0536e-02,  2.3327e-02],
          [ 7.0719e-02,  7.1895e-02,  1.1579e-01,  ...,  2.0328e-02,
           -7.6136e-03,  2.8992e-02],
          [ 4.4076e-02,  1.6612e-02,  4.0874e-02,  ...,  6.9376e-02,
            7.5018e-02,  6.3675e-02],
          ...,
 