# Handwritten digit classifier
---
## Incremental network quantization

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from skimage import io

### Importing MNIST dataset

In [2]:
# The only preprocessing applied to each image is the conversion to a tensor.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_data = torchvision.datasets.MNIST(
    root='../', train=True, download=True, transform=transform
)
test_set = torchvision.datasets.MNIST(
    root="../", train=False, download=True, transform=transform
)

# Materialise the samples up front: indices 0..49999 form the training
# split and indices 50000..59999 form the validation split.
train_set = []
validation_set = []
for sample_index in range(60000):
    sample = train_data[sample_index]
    if sample_index < 50000:
        train_set.append(sample)
    else:
        validation_set.append(sample)

### Setting up data loaders

In [3]:
# All three loaders share batch size and worker count; only the training
# loader reshuffles its samples every epoch.
loader_options = dict(batch_size=64, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_options)
validation_loader = torch.utils.data.DataLoader(validation_set, shuffle=False, **loader_options)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_options)

### Defining a VGG-7 inspired architecture model
---
Featuring 4 convolutional and 3 fully connected layers

In [4]:
class VGG7(nn.Module):
    """Small VGG-style classifier: four 3x3 convolutions and three linear layers.

    None of the layers has a bias term. The forward pass returns
    log-probabilities over 10 classes (``log_softmax`` along dim 1).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they are applied so that
        # seeded initialisation and state-dict keys stay stable.
        conv_options = dict(kernel_size=3, padding="same", stride=1, bias=False)
        self.conv1 = nn.Conv2d(1, 64, **conv_options)
        self.conv2 = nn.Conv2d(64, 64, **conv_options)
        self.conv3 = nn.Conv2d(64, 128, **conv_options)
        self.conv4 = nn.Conv2d(128, 128, **conv_options)

        self.fc1 = nn.Linear(7*7*128, 512, bias=False)
        self.fc2 = nn.Linear(512, 256, bias=False)
        self.fc3 = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        """Run the network on ``x`` and return per-class log-probabilities."""
        # First convolutional stage followed by 2x2 max pooling.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, stride=2)

        # Second convolutional stage followed by 2x2 max pooling.
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x, 2, stride=2)

        # Flatten to rows of 7*7*128 features, the input width of fc1.
        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)

        return F.log_softmax(logits, dim=1)

### Utility functions

In [None]:
def evaluateModel(net, test_loader):
    """Evaluate ``net`` on ``test_loader`` and report loss, accuracy and a confusion matrix.

    The model is moved to CUDA when available (CPU otherwise) and switched
    to evaluation mode; it is left in that state when the function returns.
    The average NLL loss and the accuracy are printed, and a 10x10 confusion
    matrix (rows: predictions, columns: ground truth), normalised by the
    number of evaluated samples, is drawn with matplotlib.

    Args:
        net: model whose output is a batch of per-class log-probabilities.
        test_loader: data loader yielding ``(data, target)`` batches; its
            ``dataset`` length is used for normalisation.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)
    net.eval()

    num_samples = len(test_loader.dataset)
    correct = 0
    loss = 0
    confusion_matrix = np.zeros((10, 10))

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = net(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            loss += F.nll_loss(output, target, reduction="sum").item()
            # np.add.at accumulates repeated (prediction, truth) pairs; a
            # plain fancy-indexed "+= 1" would count each distinct pair only
            # once per batch. Indices are moved to host memory first so the
            # update also works when the batch lives on the GPU.
            predicted_labels = pred.view(-1).cpu().numpy()
            true_labels = target.view(-1).cpu().numpy()
            np.add.at(confusion_matrix, (predicted_labels, true_labels), 1)

    print("[OK] Model evaluation complete [OK]")
    print("Average loss: {:.5f}".format(loss/num_samples))
    print("Test data accuracy: {:.2f}%".format(100.*(correct/num_samples)))

    fig, ax = plt.subplots()
    # Normalise by the evaluated dataset, not by an unrelated global loader.
    ax.imshow(confusion_matrix/num_samples)
    ax.set_xticks(np.arange(10))
    ax.set_yticks(np.arange(10))
    ax.set_xlabel("Truth")
    ax.set_ylabel("Predictions")
    ax.set_title("Confusion matrix")
    fig.set_size_inches(4, 4)

### Shift Quantization operations

In [5]:
def getBounderyExponents(W, b):
    """Return the largest and smallest power-of-two exponents ``(n1, n2)``.

    ``n1`` is ``floor(log2(4 * s / 3))`` where ``s`` is the largest absolute
    value in ``W``; ``n2`` is ``n1 + 1 - 2**(b - 1) / 2`` for bit width ``b``.
    """
    largest_magnitude = W.abs().max().item()
    upper_exponent = np.floor(np.log2(4*(largest_magnitude/3)))
    exponent_span = (2**(b - 1))/2
    lower_exponent = upper_exponent + 1 - exponent_span
    return upper_exponent, lower_exponent

def getQuantizationMask(W, percentage, T):
    """Extend the quantization mask ``T`` so that ``percentage`` of ``W`` is selected.

    Entries of ``T`` equal to 1 mark weights that are already quantized.
    Among the remaining weights, those with the largest magnitude are added
    until ``int(percentage * W.numel())`` entries are selected in total.

    Args:
        W: weight tensor; it is only read.
        percentage: target fraction (0..1) of weights to be marked.
        T: mask tensor with the same number of elements as ``W``. It is
            updated in place.

    Returns:
        The updated mask, shaped like ``T`` and sharing its storage.
    """
    flat_mask = T.view(-1)  # a view, so writes below also update T
    already_selected = flat_mask == 1

    total = flat_mask.numel()
    missing = int(percentage*total - int(already_selected.sum().item()))

    # Nothing to add. This guard matters: a slice of the form [-0:] selects
    # every index, which would mark the whole tensor as quantized.
    if missing <= 0:
        return flat_mask.view(T.size())

    # Rank the free weights by magnitude. Already selected entries get a
    # negative score so they can never tie with a free weight equal to 0.
    scores = torch.abs(W.detach().reshape(-1)).masked_fill(already_selected, -1.0)
    _, order = scores.sort()
    flat_mask[order[-missing:]] = 1

    return flat_mask.view(T.size())

def quantizeWeights(W, T, n1, n2):
    """Round the masked weights of ``W`` to signed powers of two.

    For every entry where ``T`` is non-zero the weight is replaced by
    ``sign(w) * 2**floor(log2(|w| * 4 / 3))``. Selected entries whose
    exponent is below ``n2`` become 0, and those whose exponent is above
    ``n1`` are clamped to ``sign(w) * 2**n1``. Entries where ``T`` is zero
    are returned unchanged.

    Args:
        W: weight tensor; it is not modified.
        T: mask tensor shaped like ``W`` (non-zero marks weights to quantize).
        n1: largest allowed exponent.
        n2: smallest allowed exponent.

    Returns:
        ``(closestExp, Q)``: the per-weight exponents and the tensor holding
        quantized values at masked positions and original values elsewhere.
    """
    eps = 1e-6
    weights = W.detach()
    selected = T.bool()

    # log2(0) is -inf, so zeros are replaced by a tiny positive value on a
    # working copy only; the caller's tensor keeps its exact zeros.
    safe = torch.where(weights == 0, torch.full_like(weights, eps), weights)
    signs = torch.sign(safe)

    closestExp = torch.floor(torch.log2(torch.abs(safe*4/3)))
    Q = torch.where(selected, signs*(2**closestExp), weights)

    # Range checks apply to selected entries only, whatever the sign of n2.
    too_small = selected & (closestExp < float(n2))
    Q[too_small] = 0

    # Clamp to the largest representable magnitude while keeping the sign.
    too_large = selected & (closestExp > float(n1))
    clamped = signs*float(2.0**n1)
    Q[too_large] = clamped[too_large]

    return closestExp, Q

def quantize_conv_layer(W, T, percentage, number_of_bits):
    """Quantize a 4-D weight tensor in place, one ``W[i, j]`` slice at a time.

    The exponent bounds are computed once from the whole tensor. Each
    ``[i, j, :, :]`` slice then gets its mask in ``T`` extended to
    ``percentage`` and its masked weights written back into ``W.data``.
    """
    outer = T.size(dim=0)
    inner = T.size(dim=1)

    n1, n2 = getBounderyExponents(W, number_of_bits)

    # Walk the (i, j) pairs in row-major order with a single flat counter.
    for flat_index in range(outer*inner):
        i, j = divmod(flat_index, inner)
        slice_mask = getQuantizationMask(W[i, j, :, :], percentage, T[i, j, :, :])
        T[i, j, :, :] = slice_mask
        quantized_slice = quantizeWeights(W[i, j, :, :], T[i, j, :, :], n1, n2)[1]
        W.data[i, j, :, :] = quantized_slice

In [None]:
# Small hand-written example exercising the quantization helpers.
# First 5x5 weight matrix (built from a NumPy array, hence float64).
W = torch.tensor(np.array([
    [0.01, 0.02, -0.2, 0.04, 0.33],
    [0.17, -0.42, -0.33, 0.02, -0.05], 
    [0.02, 0.83, -0.03, 0.03, 0.06],
    [-0.9, 0.07, 0.11, 0.87, -0.36], 
    [-0.73, 0.41, 0.42, 0.39, 0.47]]))
bit_length = 4
# Exponent bounds derived from the first matrix; they are reused below.
n1, n2 = getBounderyExponents(W, bit_length)
print(n1, n2)
# Start from an empty mask and select 50% of the weights.
T = torch.Tensor(np.zeros_like(W))
T = getQuantizationMask(W, 0.5, T)
print(T)
_, W = quantizeWeights(W, T, n1, n2)

# Second hand-written matrix; presumably it stands for the weights after a
# retraining step -- TODO confirm with the author.
W = torch.Tensor(np.array([
    [0.11, 0.04, -0.7, 0.19, -0.25],
    [0.15, -0.5, -0.25, -0.09, -0.02],
    [-0.02, 1, -0.06, 0.21, 0.15],
    [-1, 0.27, -0.09, 1, -0.25],
    [-0.5, 0.5, 0.5, 0.5, 0.5]
]))

# Extend the existing mask to 75% and quantize with the bounds from above.
T = getQuantizationMask(W, 0.75, T)
_, W = quantizeWeights(W, T, n1, n2)
print(T)
print(W)

### Device initialization for training

In [6]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Incremental quantization of the four convolutional layers: at each stage a
# larger share of the weights is quantized and frozen, and the remaining
# weights are retrained to compensate.

# NOTE(review): torch.load unpickles a complete model object, so only load
# checkpoints from a trusted source. Recent PyTorch releases may also need
# an explicit weights_only=False for full-model checkpoints -- confirm
# against the installed version.
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
net = torch.load("../Baseline/baseline.pth", map_location=device)
net.to(device)

optimizer = optim.SGD(net.parameters(), lr = 1e-2, weight_decay=0)

quantization_percentages = [0.5, 0.75, 0.875]
quantization_precision = 8

epochs = 5
logs_interval = 100
iteration = 0
train_loss = []
running_loss = 0.0  # sum of batch losses since the last log line

# One mask per convolutional layer; 1 marks a weight that is quantized.
Tconv1 = torch.zeros_like(net.conv1.weight)
Tconv2 = torch.zeros_like(net.conv2.weight)
Tconv3 = torch.zeros_like(net.conv3.weight)
Tconv4 = torch.zeros_like(net.conv4.weight)
conv_masks = [
    (net.conv1, Tconv1),
    (net.conv2, Tconv2),
    (net.conv3, Tconv3),
    (net.conv4, Tconv4),
]

net.train()

for q_stage, percentage in enumerate(quantization_percentages):
    # quantize layers
    for layer, mask in conv_masks:
        quantize_conv_layer(layer.weight, mask, percentage, quantization_precision)

    # The masks only change between stages, so the complementary
    # "still trainable" masks are computed once per stage.
    trainable = [(layer.weight, 1 - mask) for layer, mask in conv_masks]

    # correct remaining weights by training
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = net(data)
            loss = F.nll_loss(output, target)
            loss.backward()

            # set the gradients of the quantized weights to 0 so that the
            # optimizer step leaves them untouched
            for weight, keep in trainable:
                weight.grad.mul_(keep)

            optimizer.step()

            running_loss += loss.item()
            iteration = iteration + 1
            if iteration % logs_interval == 0:
                # Report the mean loss over the last logs_interval batches
                # (previously a single batch loss was divided by the interval).
                print('Quantization step: {}/{}, Train epoch:{}, batch index:{}, loss:{}'.format(
                    q_stage + 1, len(quantization_percentages),
                    epoch, batch_idx, running_loss/logs_interval))
                # train_loss keeps the loss of the current batch, as before.
                train_loss.append(loss.item())
                running_loss = 0.0

# Final pass: quantize every remaining convolutional weight.
for layer, mask in conv_masks:
    quantize_conv_layer(layer.weight, mask, 1, quantization_precision)

torch.save(net, "INQ.pth")

In [None]:
# Print loss and accuracy of `net` on the test loader and plot its confusion matrix.
evaluateModel(net, test_loader)