# **Module imports**

In [None]:
# importing necessary modules
import numpy as np
import math

import os
import sys
import time
import numpy as np
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.init as init


# Warning policy: DeprecationWarning is ignored for every module, then the
# 'default' action is installed for warnings coming from torch.ao.quantization.
import warnings

warnings.filterwarnings('ignore', category=DeprecationWarning, module=r'.*')
warnings.filterwarnings('default', module=r'torch.ao.quantization')

# Fixed torch RNG seed so that repeated runs give the same results
torch.manual_seed(191009)

<torch._C.Generator at 0x7f0ab8329f30>

# **Floating point quantization function experiment**

In [None]:
# FP8 floating-point quantization of a single scalar (m mantissa bits, e exponent bits)
x = 2.22545454
sign = np.sign(x)  # sign of number
m = 4  # mantissa bits
e = 3  # exponent bits
b = 2**(e-1)  # exponent bias
c = (2-2**(-m))*2**(2**e-b-1)  # largest representable magnitude
# Exponent of the quantization step. floor() selects the binade that contains
# |x|; round() would pick the next binade for values in the upper part of a
# binade and halve the available precision there.
p = np.floor(np.log2(abs(x))) - m
# Below the smallest normal exponent the step size stops shrinking.
if p < 1-b-m:
    p = 1-b-m
s = 2**p
# Quantise the magnitude and apply the sign exactly once; x/s already carries
# the sign, so sign*round(x/s) would turn negative inputs positive.
x_q = sign*s*np.round(abs(x)/s)
x_q = np.clip(x_q, -c, c)
print('quantised number:', x_q)


quantised number: 2.25


In [None]:
# Variant of the FP8 experiment: the exponent itself is clipped to the range
# [1-b, 2**e-b-1] instead of clipping the quantised value afterwards.
# Relies on x, sign, m, e and b defined by the preceding experiment cell.
# floor() selects the binade containing |x| (round() would overshoot for values
# in the upper part of a binade).
p = np.clip(np.floor(np.log2(abs(x))), 1-b, 2**e-b-1)
s = 2**(p-m)
# The sign is applied once, to the quantised magnitude.
x_q = sign*s*np.round(abs(x)/s)
print('quantised number:', x_q)

quantised number: 2.25


In [None]:
from math import frexp
def round_mantissa(x, mantissa_bits=4):
    """Round ``x`` so that its significand keeps ``mantissa_bits`` binary digits.

    ``x`` is split with ``math.frexp`` into ``significand * 2**exponent`` where
    ``0.5 <= |significand| < 1`` (or ``significand == 0``); the significand is
    rounded to a multiple of ``2**-mantissa_bits`` and the value is rebuilt.

    Args:
        x: finite real number to round.
        mantissa_bits: number of binary digits kept in the significand
            (default 4, the value that was previously hard-coded).

    Returns:
        The rounded value as a float, with the same sign as ``x``.
    """
    # frexp returns a significand that already carries the sign of x, so no
    # separate sign factor is applied (doing so would flip negative inputs).
    significand, exponent = frexp(x)
    scale = 2.0 ** mantissa_bits
    newsignificand = round(significand * scale) / scale
    return newsignificand * 2.0 ** exponent

In [None]:
# Quick check of round_mantissa on a sample value (recorded output: 2.25).
round_mantissa(2.22545)

2.25

# **Supplementary Functions**

In [None]:
class AverageMeter:
    """Tracks the latest value of a metric together with its running weighted mean.

    Attributes:
        name: label used when the meter is printed.
        fmt: format spec (e.g. ``':6.2f'``) applied to both ``val`` and ``avg``.
        val: most recent value passed to ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. '{name} {val:6.2f} ({avg:6.2f})' and fills it from the
        # instance attributes.
        template = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
        return template.format(**vars(self))


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor; ``topk`` is taken along dimension 1.
        target: tensor of true class indices; its first dimension is used as
            the batch size.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per requested k holding the
        percentage of samples whose target is among the k highest scores.
    """
    with torch.no_grad():
        n_samples = target.size(0)

        # Indices of the max(topk) highest scores per sample, transposed so
        # that row r holds every sample's rank-r prediction.
        ranked = output.topk(max(topk), 1, True, True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]


def evaluate(model, criterion, data_loader, neval_batches):
    """Measure top-1 / top-5 accuracy of ``model`` on at most ``neval_batches`` batches.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Batches are moved to the module-level ``device``. One '.' is printed per
    processed batch.

    Args:
        model: module to evaluate.
        criterion: loss callable; it is invoked on every batch but its value
            is not used for the returned statistics.
        data_loader: iterable yielding ``(image, target)`` batches.
        neval_batches: stop after this many batches (at least one batch is
            processed if the loader is non-empty).

    Returns:
        ``(top1, top5)`` AverageMeter instances weighted by batch size.
    """
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    with torch.no_grad():
        for batches_done, (image, target) in enumerate(data_loader, start=1):
            image = image.to(device)
            target = target.to(device)
            output = model(image)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end = '')
            n_images = image.size(0)
            top1.update(acc1[0], n_images)
            top5.update(acc5[0], n_images)
            if batches_done >= neval_batches:
                break

    return top1, top5

def print_size_of_model(model):
    """Print the serialized size of ``model.state_dict()`` in megabytes.

    The state dict is written to a scratch file ``temp.p`` in the current
    working directory, its size is printed, and the file is deleted again.
    The scratch file is removed even when saving or measuring raises.

    Args:
        model: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
    """
    scratch_path = "temp.p"
    try:
        torch.save(model.state_dict(), scratch_path)
        print('Size (MB):', os.path.getsize(scratch_path)/1e6)
    finally:
        # Clean up on both the success and the failure path.
        if os.path.exists(scratch_path):
            os.remove(scratch_path)



# Layout constants and timing state shared by progress_bar.
term_width = 80
TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    """Draw a one-line text progress bar on stdout.

    Args:
        current: zero-based index of the step just completed; ``0`` restarts
            the total-time clock.
        total: total number of steps.
        msg: optional text appended after the timing information.

    Updates the module-level ``last_time`` / ``begin_time`` timestamps. The
    line ends with a carriage return until the last step, which ends with a
    newline.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    filled = int(TOTAL_BAR_LENGTH*current/total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1

    sys.stdout.write(' [')
    sys.stdout.write('=' * filled)
    sys.stdout.write('>')
    sys.stdout.write('.' * remaining)
    sys.stdout.write(']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    status = '  Step: %s' % format_time(step_time)
    status += ' | Tot: %s' % format_time(tot_time)
    if msg:
        status += ' | ' + msg

    sys.stdout.write(status)
    # Pad with spaces up to the terminal width (no-op when the count is <= 0).
    sys.stdout.write(' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(status)-3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width-int(TOTAL_BAR_LENGTH/2)+2))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    sys.stdout.write('\r' if current < total-1 else '\n')
    sys.stdout.flush()
def format_time(seconds):
    """Format a duration in seconds as a short human-readable string.

    The duration is split into days, hours, minutes, whole seconds and
    milliseconds; only the first two non-zero units are shown, e.g.
    ``'1h1m'``, ``'1m5s'`` or ``'500ms'``.

    Args:
        seconds: duration in seconds (int or float).

    Returns:
        The formatted string, or ``'0ms'`` when every unit is zero.
    """
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (secondsf, 's'),
        (millis, 'ms'),
    )
    # Keep at most the two most significant non-zero units.
    shown = [str(value) + suffix for value, suffix in units if value > 0][:2]
    # The result must be returned; a bare `return` would hand None to callers.
    return ''.join(shown) if shown else '0ms'

# **Cifar10 dataloader**

In [None]:
# CIFAR-10 datasets and loaders
data_path = '/content/drive/MyDrive/Extra'
print('==> Preparing data..')

# Mean / std triples passed to Normalize for both splits.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Training images get a padded random crop and a random horizontal flip
# before tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Test images are only converted and normalised.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

trainset = datasets.CIFAR10(
    data_path, train=True, download=True, transform=transform_train)
trainloader = DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(
    data_path, train=False, download=True, transform=transform_test)
testloader = DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/drive/MyDrive/Extra/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /content/drive/MyDrive/Extra/cifar-10-python.tar.gz to /content/drive/MyDrive/Extra
Files already downloaded and verified


# **Train/Test function**

In [None]:
def train(epoch,model):
    """Run one training epoch of ``model`` over the module-level ``trainloader``.

    Uses the module-level ``device``, ``optimizer`` and ``criterion``. A
    progress bar with the running loss and accuracy is refreshed after every
    batch, and a summary line is printed at the end of the epoch.

    Args:
        epoch: epoch number, used only for the printed header.
        model: network to optimise; switched to training mode.
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0
    n_batches = len(trainloader)
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        summary = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
            running_loss/(batch_idx+1), 100.*n_correct/n_seen, n_correct, n_seen)
        progress_bar(batch_idx, n_batches, summary)
    print('Loss:',(running_loss/(batch_idx+1)), 'Acc: ',100.*n_correct/n_seen,'correct',n_correct,'total',n_seen)


def test(epoch,model, checkpoint_dir='/content/drive/MyDrive/Extra/checkpoint'):
    """Evaluate ``model`` on the module-level ``testloader`` and checkpoint the best result.

    Uses the module-level ``device`` and ``criterion``. When the accuracy
    exceeds the module-level ``best_acc``, the model state, accuracy and epoch
    are saved to ``<checkpoint_dir>/ckpt.pth`` and ``best_acc`` is updated.

    Args:
        epoch: epoch number stored in the checkpoint.
        model: network to evaluate; switched to eval mode.
        checkpoint_dir: directory receiving ``ckpt.pth``. Defaults to the
            location that was previously hard-coded.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    print('Loss:',(test_loss/(batch_idx+1)),'Acc:',100.*correct/total,'Correct',correct,'total',total)

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # makedirs also creates missing parent directories and does not fail
        # when the directory already exists.
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(state, os.path.join(checkpoint_dir, 'ckpt.pth'))
        best_acc = acc


# **Resnet18 model**

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm stages and a shortcut.

    The shortcut is the identity unless the stride or the channel count
    changes, in which case it is a strided 1x1 convolution followed by
    batch-norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion*planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride != 1 or in_planes != out_planes:
            # Input and output shapes differ: project the input to match.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        return F.relu(out)


class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last 1x1 convolution outputs ``expansion * planes`` channels. The
    shortcut is the identity unless the stride or the channel count changes,
    in which case it is a strided 1x1 convolution followed by batch-norm.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion*planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # The stride is applied in the 3x3 convolution.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride != 1 or in_planes != out_planes:
            # Input and output shapes differ: project the input to match.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the three conv/bn stages, add the shortcut, then ReLU."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = F.relu(self.bn2(out))
        out = self.conv3(out)
        out = self.bn3(out)
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """ResNet backbone: a 3x3 stem, four residual stages and a linear classifier.

    Args:
        block: residual block class; it must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride)``.
        num_blocks: sequence with the number of blocks in each of the four stages.
        num_classes: size of the classifier output.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Channel count fed to the next block; advanced by _make_layer.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        stage = []
        for block_stride in [stride] + [1]*(num_blocks-1):
            stage.append(block(self.in_planes, planes, block_stride))
            # Later blocks receive the expanded output of this one.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer4(self.layer3(self.layer2(self.layer1(out))))
        out = F.avg_pool2d(out, 4)
        # Flatten everything but the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        return self.linear(out)


def ResNet18():
    """Build a ResNet made of BasicBlock with two blocks in each of the four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage)

# **Building the baseline (floating-point) ResNet model**

In [None]:
# Baseline: ResNet18 trained in ordinary floating point
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

print('==> Building model..')
net = ResNet18().to(device)

# Loss, optimiser and learning-rate schedule; train() reads `criterion` and
# `optimizer` as module-level names.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Ten training epochs; the per-epoch test pass is currently disabled.
for epoch in range(start_epoch, start_epoch+10):
    train(epoch, net)
    scheduler.step()



==> Building model..

Epoch: 0
Loss: 1.8474685081740474 Acc:  33.476 correct 16738 total 50000

Epoch: 1
Loss: 1.327060457538156 Acc:  51.552 correct 25776 total 50000

Epoch: 2
Loss: 1.0478629082669992 Acc:  62.806 correct 31403 total 50000

Epoch: 3
Loss: 0.8672097018917503 Acc:  69.482 correct 34741 total 50000

Epoch: 4
Loss: 0.7232816175883993 Acc:  74.768 correct 37384 total 50000

Epoch: 5
Loss: 0.643893344551706 Acc:  77.806 correct 38903 total 50000

Epoch: 6
Loss: 0.5835074847159178 Acc:  79.76 correct 39880 total 50000

Epoch: 7
Loss: 0.5497747966090737 Acc:  81.3 correct 40650 total 50000

Epoch: 8
Loss: 0.5227898905420547 Acc:  82.084 correct 41042 total 50000

Epoch: 9
Loss: 0.49972571451645675 Acc:  82.906 correct 41453 total 50000


In [None]:
# Evaluation of the floating-point model
num_eval_batches = 1000  # upper bound on the number of test batches evaluated

train_batch_size = 30
eval_batch_size = 50

saved_model_dir = '/content/drive/MyDrive/Extra/'
scripted_float_model_file = 'resnet18.pth'
print("Size of baseline model")
print_size_of_model(net)

top1, top5 = evaluate(net, criterion, testloader, neval_batches=num_eval_batches)
# Report the number of images that were actually evaluated: top1.count
# accumulates the batch sizes seen by evaluate(), whereas
# num_eval_batches * eval_batch_size is unrelated to the loader that is used.
print('Evaluation accuracy on %d images, %2.2f'%(top1.count, top1.avg))
# Script the trained network and store it next to the other artefacts.
torch.jit.save(torch.jit.script(net), saved_model_dir + scripted_float_model_file)

# **Quantizer Function**

In [None]:
class Quantizer(nn.Module):
    """Simulated low-precision floating-point quantizer.

    Every element of the input is rounded to the nearest value representable
    with ``m`` mantissa bits and ``e`` exponent bits (exponent bias
    ``2**(e-1)``) and clamped to the largest representable magnitude. The
    result is still an ordinary floating-point tensor holding the rounded
    values.

    Args:
        m: number of mantissa bits.
        e: number of exponent bits.

    NOTE(review): ``torch.round`` has zero gradient almost everywhere, so no
    useful gradient flows back through this module; training *through* it
    would need a straight-through estimator.
    """

    def __init__(self, m, e):
        super().__init__()
        self.m = m
        self.e = e

    # Quantize
    def forward(self, input):
        """Return ``input`` rounded to the simulated floating-point grid."""
        b = 2**(self.e-1)  # exponent bias
        c = (2-2**(-self.m))*2**(2**self.e-b-1)  # largest representable magnitude
        min_p = 1-b-self.m  # step exponent used below the smallest normal value

        magnitude = torch.abs(input)
        # floor() selects the binade that contains |x|; round() would move
        # values in the upper part of a binade onto a coarser grid.
        p = torch.floor(torch.log2(magnitude))-self.m
        # Lower bound only: this also maps log2(0) = -inf to a finite step, so
        # zero inputs (including all-zero tensors) quantize to 0, not NaN.
        p = torch.clamp(p, min=min_p)
        s = 2**p
        # The sign is applied exactly once, to the quantized magnitude;
        # sign * round(input / s) would make every output non-negative.
        output = torch.sign(input)*s*torch.round(magnitude/s)
        return torch.clamp(output, -c, c)




class QuantizedResNet18(nn.Module):
    """Wrap a model so that its input and its output both pass through a Quantizer.

    A single ``Quantizer(4, 3)`` instance is shared by both ends; the wrapped
    model itself is left unchanged.

    Args:
        model: the network to wrap.
    """

    def __init__(self, model):
        super().__init__()
        # Quantizer applied to the input tensor and to the model output.
        self.quant = Quantizer(4,3)
        # Wrapped network.
        self.model = model

    def forward(self, x):
        """Quantize ``x``, run the wrapped model, and quantize its output."""
        quantized_in = self.quant(x)
        raw_out = self.model(quantized_in)
        return self.quant(raw_out)

# **Building quantized resnet model**

In [None]:
# Quantization-aware training of ResNet18
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
print('==> Building quantized model..')

# quant_model wraps net_quant, so both names refer to the same weights.
net_quant = ResNet18()
quant_model = QuantizedResNet18(model=net_quant)
net_quant = net_quant.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net_quant.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# NOTE(review): the loop trains net_quant directly, so the Quantizer inside
# quant_model is not applied during training - confirm this is intended.
# The per-epoch test pass is currently disabled.
for epoch in range(start_epoch, start_epoch+10):
    train(epoch, net_quant)
    scheduler.step()


==> Building quantized model..

Epoch: 0
Loss: 1.9011738364348936 Acc:  31.992 correct 15996 total 50000

Epoch: 1
Loss: 1.4244765113381779 Acc:  47.92 correct 23960 total 50000

Epoch: 2
Loss: 1.1736782999599682 Acc:  58.148 correct 29074 total 50000

Epoch: 3
Loss: 0.9914183830056349 Acc:  64.872 correct 32436 total 50000

Epoch: 4
Loss: 0.8502378007944893 Acc:  70.092 correct 35046 total 50000

Epoch: 5
Loss: 0.7231217928402259 Acc:  74.82 correct 37410 total 50000

Epoch: 6
Loss: 0.6359348762828065 Acc:  77.994 correct 38997 total 50000

Epoch: 7
Loss: 0.5901471950361491 Acc:  79.556 correct 39778 total 50000

Epoch: 8
Loss: 0.5553777747599365 Acc:  81.026 correct 40513 total 50000

Epoch: 9
Loss: 0.5331915782388214 Acc:  81.714 correct 40857 total 50000


In [None]:
# Evaluation of the quantized model
num_eval_batches = 1000  # upper bound on the number of test batches evaluated

train_batch_size = 30
eval_batch_size = 50

saved_model_dir = '/content/drive/MyDrive/Extra/'
scripted_float_model_file = 'quant_resnet18.pth'
# This cell measures quant_model, so the label says so (it previously
# repeated the baseline label).
print("Size of quantized model")
print_size_of_model(quant_model)

top1, top5 = evaluate(quant_model, criterion, testloader, neval_batches=num_eval_batches)
# Report the number of images that were actually evaluated: top1.count
# accumulates the batch sizes seen by evaluate(), whereas
# num_eval_batches * eval_batch_size is unrelated to the loader that is used.
print('Evaluation accuracy on %d images, %2.2f'%(top1.count, top1.avg))
# Script the wrapped model and store it next to the other artefacts.
torch.jit.save(torch.jit.script(quant_model), saved_model_dir + scripted_float_model_file)

# **Toy experiment**

# **Per-tensor / per-channel quantization**