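# QAT, Prune, Extract

This notebook covers three stages for a low-bit ResNet-20 on CIFAR-10:

1. Quantization-aware training (QAT).
2. L1-unstructured pruning of its `QuantConv2d` weights, followed by a short fine-tune.
3. Extraction of every quantized conv layer's integer weights, inputs and outputs as channel-last NumPy arrays, pickled to `np/np_dict_<model name>.pickle`.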

In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import math
import os

import training
from quant_layer import QuantConv2d, act_quantization

from resnet_quant import ResNet_Cifar
from vgg_quant import VGG_quant

# ---------------------------------------------------------------------------
# Parameters
# ---------------------------------------------------------------------------

# Data: CIFAR-10 images are input_size x input_size, normalized per channel.
input_size = 32
mean = [0.491, 0.482, 0.447]
std = [0.247, 0.243, 0.262]

dataset = torchvision.datasets.CIFAR10
batch_size = 128

# Training schedule
n_epochs_train = 96      # quantization-aware training epochs
n_epochs_retrain = 2     # fine-tuning epochs after pruning
lr = 0.01
adjust_list = [20, 30]   # presumably the epochs at which training.train_val_save lowers lr — TODO confirm
optim = torch.optim.SGD  # optimizer *class*; the training helpers receive the class, not an instance

# Fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = torch.nn.CrossEntropyLoss().to(device)

# Seed torch's RNGs (all devices) so weight init, shuffling and augmentation
# are repeatable across Restart & Run All (cuDNN kernels may still be nondeterministic).
seed = 0
torch.manual_seed(seed)

# Pruning / quantization
prune_ratio = 0.8   # fraction of each QuantConv2d weight tensor removed by L1 unstructured pruning
w_bit = 4           # weight bit-width
x_bit = 4           # activation (conv input) bit-width
x_alpha = 8.0       # activation quantization range passed to the model constructor
w_alpha_init = 3.0  # initial weight quantization range passed to the model constructor

print_freq = 100    # batches between training progress lines

## Model

In [2]:
# Build the quantized ResNet-20 (x_bit-bit activations, w_bit-bit weights) and
# move it to `device`. The pre-hook capture cell sends its input batch to
# `device`, so the weights must live there too. Swap in the commented line to
# use VGG-16 instead.
model = ResNet_Cifar('resnet20_quant', x_bit, w_bit, x_alpha, w_alpha_init).to(device)
# model = VGG_quant('vgg16_quant', x_bit, w_bit, x_alpha, w_alpha_init).to(device)

model

ResNet_Cifar(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): QuantConv2d(
        16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()
      )
      (conv2): QuantConv2d(
        16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()
      )
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): QuantConv2d(
        16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()
      )
      (conv2

## Dataloaders

In [3]:
normalize = transforms.Normalize(mean=mean, std=std)

# Training split: random 4-pixel-padded crops and horizontal flips for augmentation.
train_transform = transforms.Compose([
    transforms.RandomCrop(input_size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Test split: tensor conversion and normalization only.
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

train_dataset = dataset(root='./data', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

test_dataset = dataset(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


## Load Checkpoint

In [4]:
# Resume from the best checkpoint in save_dir if one exists (presumably written
# there by training.train_val_save); otherwise QAT starts from epoch 0.
save_dir = f'checkpoints/{model.name}'
path = f"{save_dir}/model_best.pth.tar"

if os.path.exists(path):
    # map_location lets a checkpoint saved on one device (e.g. a specific GPU)
    # be loaded on whatever `device` this session uses.
    checkpoint = torch.load(path, map_location=device)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
else:
    start_epoch = 0

## Quantization Aware Training

In [5]:
# Run quantization-aware training from `start_epoch` up to `n_epochs_train`.
# The helper receives the optimizer class `optim` plus `lr` rather than an
# optimizer instance; `save_dir` is where the Load Checkpoint cell looks for
# model_best.pth.tar.
training.train_val_save(
    model, trainloader, testloader,
    optim, lr, criterion,
    start_epoch, n_epochs_train, adjust_list,
    print_freq, save_dir,
)

Epoch: [95][0/391]	Loss 0.9652 (0.9652)	Prec 72.656% (72.656%)
Epoch: [95][100/391]	Loss 0.7631 (0.8891)	Prec 70.312% (68.595%)
Epoch: [95][200/391]	Loss 0.8231 (0.8848)	Prec 66.406% (68.711%)
Epoch: [95][300/391]	Loss 1.0208 (0.8849)	Prec 57.812% (68.638%)
Validation starts
Test: [0/79]	Loss 0.8056 (0.8056)	Prec 67.969% (67.969%)
 * Prec 68.250% 
best acc: 68.250000


## Pruning & Training

In [6]:
# Apply L1 unstructured pruning to every QuantConv2d weight, then fine-tune briefly.
# The is_pruned guard keeps the cell idempotent. Re-running
# prune.l1_unstructured on an already-pruned layer would remove `prune_ratio`
# of the *remaining* weights again, compounding the sparsity.
for layer in model.modules():
    if isinstance(layer, QuantConv2d) and not prune.is_pruned(layer):
        prune.l1_unstructured(layer, name='weight', amount=prune_ratio)

# Like train_val_save, simple_train_test receives the optimizer class and lr,
# so no optimizer instance is built here.
training.simple_train_test(
    model, trainloader, testloader,
    optim, lr, criterion,
    n_epochs_retrain,
)

Epoch: 1 	Training Loss: 1.557315
Epoch: 2 	Training Loss: 1.317897

Test set: Accuracy: 5584/10000 (56%)



## Capture each QuantConv2d layer's input (x) with forward pre-hooks

In [7]:
class SaveOutput:
    """Forward pre-hook that stores a layer's input on the layer itself.

    Despite the name, this saves the *input*. The first positional argument
    of the layer's forward call is assigned to ``layer.prehooked``.
    """

    def __init__(self, layer):
        self.layer = layer

    def __call__(self, module, module_in):
        # module_in is the tuple of positional arguments passed to forward().
        self.layer.prehooked = module_in[0]


# Attach a capture hook to every QuantConv2d. Keep the handles so the hooks can
# be removed afterwards; this way re-running the cell does not stack duplicate hooks.
hook_handles = [
    layer.register_forward_pre_hook(SaveOutput(layer))
    for layer in model.modules()
    if isinstance(layer, QuantConv2d)
]

# Push one test batch through the network. eval() + no_grad() make this a pure
# inference pass: BatchNorm uses, and does not update, its running statistics,
# and no autograd graph is kept alive through the captured tensors.
dataiter = iter(testloader)
images, labels = next(dataiter)  # DataLoader iterators no longer provide .next() (PyTorch >= 1.13)
images = images.to(device)
model.eval()
try:
    with torch.no_grad():
        out = model(images)
finally:
    for handle in hook_handles:
        handle.remove()

## Compute integer versions of w, x & y and store them as channel-last NumPy arrays

In [8]:
import numpy as np

# For every QuantConv2d, rebuild integer-domain tensors from the quantized
# float values:
#   w_int = w_q / w_delta
#   x_int = x_q / x_delta
#   y_int = conv(x_int, w_int)
# y_f rescales y_int back to the float domain.
act_quant_fn = act_quantization(x_bit)  # loop-invariant: depends only on x_bit

with torch.no_grad():
    for layer in model.modules():
        if not isinstance(layer, QuantConv2d):
            continue

        # Weights: step size w_alpha / (2**(w_bit-1) - 1), i.e. a symmetric signed grid.
        weight_q = layer.weight_q  # presumably cached by QuantConv2d.forward during the capture pass — TODO confirm
        w_alpha = layer.weight_quant.w_alpha
        layer.w_delta = w_alpha / (2 ** (w_bit - 1) - 1)
        layer.w_int = weight_q / layer.w_delta

        # Input: step size x_alpha / (2**x_bit - 1). Use a local name so the
        # notebook-level `x_alpha` config value (used to build the model) is
        # not overwritten by the last layer's value.
        layer_x_alpha = layer.x_alpha.item()
        layer.x_delta = layer_x_alpha / (2 ** x_bit - 1)
        x_q = act_quant_fn(layer.prehooked, layer_x_alpha)
        layer.x_int = x_q / layer.x_delta

        # Output: integer-domain convolution, then rescale to float.
        # layer.bias is passed through unchanged. It is None for these
        # bias=False layers; a float bias would first need to be divided by
        # (w_delta * x_delta).
        layer.y_int = F.conv2d(
            layer.x_int, layer.w_int, layer.bias,
            layer.stride, layer.padding, layer.dilation, layer.groups)
        layer.y_f = layer.y_int * layer.w_delta * layer.x_delta
        

In [9]:
import numpy as np
import pickle

# Convert each layer's integer tensors to rounded NumPy int arrays in
# channel-last layout, then pickle them keyed conv_1, conv_2, ...
# in model.modules() order.
#   weights: (C_out, C_in, kH, kW) -> (kH, kW, C_in, C_out)
#   x / y:   (N, C, H, W)          -> (N, H, W, C)


def _to_int_numpy(tensor, axes):
    """Round `tensor` to the nearest integers and return a NumPy int array permuted to `axes`."""
    # Builtin `int` is what the removed `np.int` alias pointed to (NumPy >= 1.24 has no np.int).
    return tensor.detach().cpu().numpy().round().astype(int).transpose(axes)


conv_i = 0
d = {}

for layer in model.modules():
    if isinstance(layer, QuantConv2d):
        layer.w_npcl = _to_int_numpy(layer.w_int, (2, 3, 1, 0))
        layer.x_npcl = _to_int_numpy(layer.x_int, (0, 2, 3, 1))
        layer.y_npcl = _to_int_numpy(layer.y_int, (0, 2, 3, 1))

        conv_i += 1
        channel_out = layer.weight.shape[0]
        d[f'conv_{conv_i}'] = {
            'w': layer.w_npcl,
            'x': layer.x_npcl,
            # The conv layers are built with bias=False, so export an all-zero bias.
            'b': np.zeros(channel_out, dtype=np.int8),
            'y': layer.y_npcl,
        }

# Create the output directory on a fresh checkout instead of failing in open().
os.makedirs('np', exist_ok=True)
with open(f'np/np_dict_{model.name}.pickle', 'wb') as f:
    pickle.dump(d, f)

In [10]:
# Input-activation shape (N, H, W, C) of every exported conv layer.
for conv_name, arrays in d.items():
    print(conv_name, arrays['x'].shape)

conv_1 (128, 32, 32, 16)
conv_2 (128, 32, 32, 16)
conv_3 (128, 32, 32, 16)
conv_4 (128, 32, 32, 16)
conv_5 (128, 32, 32, 16)
conv_6 (128, 32, 32, 16)
conv_7 (128, 32, 32, 16)
conv_8 (128, 16, 16, 32)
conv_9 (128, 32, 32, 16)
conv_10 (128, 16, 16, 32)
conv_11 (128, 16, 16, 32)
conv_12 (128, 16, 16, 32)
conv_13 (128, 16, 16, 32)
conv_14 (128, 16, 16, 32)
conv_15 (128, 8, 8, 64)
conv_16 (128, 16, 16, 32)
conv_17 (128, 8, 8, 64)
conv_18 (128, 8, 8, 64)
conv_19 (128, 8, 8, 64)
conv_20 (128, 8, 8, 64)


In [11]:
# conv_5 shapes: weights (kH, kW, C_in, C_out); input and output (N, H, W, C).
conv5 = d['conv_5']
for key in ('w', 'x', 'y'):
    print(conv5[key].shape)

(3, 3, 16, 16)
(128, 32, 32, 16)
(128, 32, 32, 16)


In [12]:
# Disabled sanity check: compare one layer's float reference convolution with the
# rescaled integer output `layer.y_f` computed in the integer-conversion cell.
# NOTE(review): fix these before re-enabling:
#   * list(model.modules())[2] is `bn1` in the printed ResNet_Cifar repr (index 0
#     is the model itself, index 1 is the plain conv1). It is not a QuantConv2d
#     and has no in_channels / prehooked / y_f. Select a QuantConv2d instead.
#   * Assigning `mean` / `std` here would overwrite the CIFAR normalization lists
#     from the Parameters cell; use different names.
#   * The reference conv uses standardized float weights rather than weight_q, so
#     some difference from y_f is expected. Presumably the standardization mirrors
#     what weight_quantize_fn does internally — TODO confirm.
# layer = list(model.modules())[2]

# conv_ref = torch.nn.Conv2d(in_channels=layer.in_channels, 
#                             out_channels=layer.out_channels, 
#                             kernel_size=layer.kernel_size, 
#                             stride=layer.stride, 
#                             padding=layer.padding, 
#                             dilation=layer.dilation, 
#                             groups=layer.groups, 
#                             bias=layer.bias)

# weight = layer.weight
# # mean = 0
# # std = 1
# mean = weight.data.mean()
# std = weight.data.std()
# conv_ref.weight = torch.nn.parameter.Parameter(weight.add(-mean).div(std))

# output_ref = conv_ref(layer.prehooked)

# difference = abs(output_ref - layer.y_f )
# print(difference.mean())