In [None]:
import time
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from modules.MNIST import MNIST
from modules.trainer import train_and_test, test, train

In [None]:
(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = MNIST("../MNIST_DATASET")

# Images are converted to float32 tensors and labels to int64 (the dtype the
# classification losses used below expect for class indices).
train_images = torch.tensor(train_images, dtype=torch.float32)
val_images = torch.tensor(val_images, dtype=torch.float32)
test_images = torch.tensor(test_images, dtype=torch.float32)

train_labels = torch.tensor(train_labels, dtype=torch.long)
val_labels = torch.tensor(val_labels, dtype=torch.long)
test_labels = torch.tensor(test_labels, dtype=torch.long)

# Add channel dimension (1 for grayscale images)
train_images = train_images.unsqueeze(1)  # Shape: [N, 1, 28, 28]
val_images = val_images.unsqueeze(1)      # Shape: [N, 1, 28, 28]
test_images = test_images.unsqueeze(1)    # Shape: [N, 1, 28, 28]

# Dataloaders
# Only the training set is shuffled. Evaluation and calibration passes iterate
# in a fixed order so that repeated runs see identical batches.
train_dataset = TensorDataset(train_images, train_labels)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)

test_dataset = TensorDataset(test_images, test_labels)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

val_dataset = TensorDataset(val_images, val_labels)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [None]:
# Define the model
class Net(nn.Module):
  """Small fully-convolutional classifier followed by one linear layer.

  For inputs of shape [N, 1, 28, 28] the four convolutions shrink the
  spatial size 28 -> 14 -> 7 -> 5 -> 1, leaving 32 features per sample that
  `fc1` maps to 10 class scores (raw logits, no softmax applied).
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
    self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
    self.conv4 = nn.Conv2d(64, 32, kernel_size=5, stride=1)
    self.fc1 = nn.Linear(32, 10)

  def forward(self, x):
    """Run the conv stack (each followed by ReLU) and return the logits."""
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
      x = F.relu(conv(x))
    # Flatten the 32 x 1 x 1 feature map to one row per sample.
    features = x.view(-1, 32*1*1)
    return self.fc1(features)
  
def get_model_size(model: nn.Module, data_width=32):
  """Return the storage needed for the model's parameters, in bits.

  Every parameter element is counted at `data_width` bits; buffers and any
  quantization codebooks are not included.
  """
  total_elements = sum(param.numel() for param in model.parameters())
  return total_elements * data_width

# Unit sizes expressed in bits, so a bit count divided by one of these yields
# the size in that unit (e.g. bits / MiB -> mebibytes).
Byte = 8
KiB = Byte * 1024
MiB = KiB * 1024
GiB = MiB * 1024

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Place the model on the target device explicitly (and before the optimizer is
# built) so the data moved to `device` in later cells always meets parameters
# on the same device. `.to()` is a no-op if the trainer also moves the model.
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 25

train_and_test(model, train_loader, test_loader, optimizer, device, epochs)

# Reference accuracy of the full-precision model; later cells compare against it.
fp32_model_accuracy = test(model, device, test_loader)

# Pristine copy of the trained weights; each quantization experiment starts from it.
model_backup = deepcopy(model)

## K-Means Quantization

Quantizer for a single weight tensor

In [None]:
from collections import namedtuple
from fast_pytorch_kmeans import KMeans

# centroids: 1-D float tensor of cluster centres; labels: cluster index of
# every element of the (flattened) quantized tensor.
Codebook = namedtuple('Codebook', ['centroids', 'labels'])

def k_means_quantize(fp32_tensor: torch.Tensor, bitwidth=4, codebook=None):
  """Replace every element of `fp32_tensor` by its cluster centroid, in place.

  Args:
    fp32_tensor: tensor to quantize; its contents are overwritten.
    bitwidth: number of bits per element, i.e. 2**bitwidth clusters. Only
      used when a new codebook has to be fitted.
    codebook: existing Codebook to reuse; when None, k-means is fitted on
      the flattened tensor to build one.

  Returns:
    The Codebook that was applied (the one passed in, or the new one).
  """
  if codebook is None:
    flat_values = fp32_tensor.view(-1, 1)
    clusterer = KMeans(n_clusters=2 ** bitwidth, mode='euclidean', verbose=False)
    assignments = clusterer.fit_predict(flat_values).to(torch.long)
    codebook = Codebook(clusterer.centroids.to(torch.float).view(-1), assignments)

  # Look up each element's centroid and restore the original shape.
  snapped = codebook.centroids[codebook.labels].view(fp32_tensor.shape)
  fp32_tensor.set_(snapped.view_as(fp32_tensor))
  return codebook

Full model quantizer

In [None]:
from torch.nn import parameter
class KMeansQuantizer:
    """Holds one k-means codebook per quantized parameter of a model.

    Constructing the quantizer immediately quantizes `model` in place and
    stores the resulting codebooks in `self.codebook` (parameter name ->
    Codebook).
    """

    def __init__(self, model : nn.Module, bitwidth=4):
        self.codebook = KMeansQuantizer.quantize(model, bitwidth)

    @torch.no_grad()
    def apply(self, model, update_centroids):
        """Re-quantize `model` in place with the stored codebooks.

        When `update_centroids` is true, each codebook's centroids are first
        refreshed from the parameter's current values via `update_codebook`.
        Parameters without a stored codebook are left untouched.
        """
        for name, param in model.named_parameters():
            if name not in self.codebook:
                continue
            entry = self.codebook[name]
            if update_centroids:
                update_codebook(param, codebook=entry) # Defined below
            self.codebook[name] = k_means_quantize(param, codebook=entry)

    @staticmethod
    @torch.no_grad()
    def quantize(model: nn.Module, bitwidth=4):
        """Quantize the model's parameters in place and return their codebooks.

        `bitwidth` is either a dict mapping parameter names to bit widths
        (only the listed parameters are quantized) or a single value applied
        to every parameter with more than one dimension.
        """
        per_parameter = isinstance(bitwidth, dict)
        codebook = {}
        for name, param in model.named_parameters():
            if per_parameter:
                if name not in bitwidth:
                    continue
                bits = bitwidth[name]
            elif param.dim() > 1:
                bits = bitwidth
            else:
                continue
            codebook[name] = k_means_quantize(param, bitwidth=bits)
        return codebook

In [None]:
# Observing accuracy loss upon quantizing
quantizers = {}

for bitwidth in (8, 4, 2):
    # Every bit width starts from the same trained full-precision weights.
    model = deepcopy(model_backup)
    print(f'k-means quantizing model into {bitwidth} bits')
    # Building the quantizer clusters the weights and overwrites them in place.
    quantizer = KMeansQuantizer(model, bitwidth)
    quantized_model_size = get_model_size(model, bitwidth)
    print(f"    {bitwidth}-bit k-means quantized model has size={quantized_model_size/MiB:.2f} MiB")
    quantized_model_accuracy = test(model, device, test_loader)
    print(f"    {bitwidth}-bit k-means quantized model has accuracy={quantized_model_accuracy:.2f}%")
    # Keep the fitted codebooks so later cells can reuse them.
    quantizers[bitwidth] = quantizer

In [None]:
# Quantization aware training...
# updated centroid = mean of weights in same cluster

def update_codebook(fp32_tensor: torch.Tensor, codebook: "Codebook"):
  """Refresh the centroids in place as the mean of the weights in each cluster.

  Args:
    fp32_tensor: current (e.g. fine-tuned) values of the quantized tensor.
    codebook: Codebook whose `labels` assign every element of the flattened
      tensor to a cluster; its `centroids` tensor is updated in place.

  Clusters with no assigned element keep their previous centroid: the mean
  of an empty selection is NaN, which would otherwise poison the codebook.
  """
  centroids = codebook.centroids
  n_clusters = centroids.numel()
  labels = codebook.labels.reshape(-1)
  # detach(): the update is pure bookkeeping and must not enter autograd.
  values = fp32_tensor.detach().reshape(-1).to(centroids.dtype)

  # Per-cluster sum and population in one pass instead of a loop over clusters.
  sums = torch.zeros_like(centroids).index_add_(0, labels, values)
  counts = torch.bincount(labels, minlength=n_clusters)

  occupied = counts > 0
  centroids[occupied] = sums[occupied] / counts[occupied].to(centroids.dtype)

In [None]:
# Will only train if the accuracy diff is more than a certain threshold..
# The threshold is compared against fp32_model_accuracy - quantized accuracy
# (presumably percentage points, since accuracies are printed with '%').
accuracy_drop_threshold = 0.5
# Snapshot of the codebooks before fine-tuning; the deepcopy is independent of
# anything done to `quantizers` below.
quantizers_before_finetune = deepcopy(quantizers)
# NOTE: this is an alias, not a copy -- it is the same dict object as
# `quantizers`, so it reflects whatever the loop below does to the quantizers.
quantizers_after_finetune = quantizers

for bitwidth in [8, 4, 2]:
  # Start from the trained full-precision weights and re-apply the stored
  # codebooks (no new k-means fit, centroids unchanged).
  model = deepcopy(model_backup)
  quantizer = quantizers[bitwidth]
  print(f'k-means quantizing model into {bitwidth} bits')
  quantizer.apply(model, update_centroids=False)
  quantized_model_size = get_model_size(model, bitwidth)
  print(f"    {bitwidth}-bit k-means quantized model has size={quantized_model_size/MiB:.2f} MiB")
  quantized_model_accuracy = test(model, device, test_loader)
  print(f"    {bitwidth}-bit k-means quantized model has accuracy={quantized_model_accuracy:.2f}% before quantization-aware training ")
  accuracy_drop = fp32_model_accuracy - quantized_model_accuracy
  
  if accuracy_drop > accuracy_drop_threshold:
    print(f"        Quantization-aware training due to accuracy drop={accuracy_drop:.2f}% is larger than threshold={accuracy_drop_threshold:.2f}%")
    num_finetune_epochs = 5
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    best_accuracy = 0
    # `epoch` counts DOWN from num_finetune_epochs; the loop stops early once
    # the best accuracy seen is within the threshold of the fp32 accuracy.
    epoch = num_finetune_epochs
    while accuracy_drop > accuracy_drop_threshold and epoch > 0:
      # The callback refreshes the centroids from the current weights and
      # re-quantizes the model. When `train` invokes its callbacks is defined
      # in modules.trainer (not visible here) -- presumably after each
      # optimizer step; TODO confirm.
      train(model, device, train_loader, optimizer, epoch,
            callbacks=[lambda: quantizer.apply(model, update_centroids=True)])
      model_accuracy = test(model, device, test_loader)
      best_accuracy = max(model_accuracy, best_accuracy)
      print(f'        Epoch {num_finetune_epochs-epoch} Accuracy {model_accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')
      # The drop is measured against the best accuracy so far; note the model
      # itself is not rolled back to that best state.
      accuracy_drop = fp32_model_accuracy - best_accuracy
      epoch -= 1
  
  else:
    print(f"        No need for quantization-aware training since accuracy drop={accuracy_drop:.2f}% is smaller than threshold={accuracy_drop_threshold:.2f}%")

## Linear Quantization

In [None]:
# tensor: quantized values; scale / zero_point: affine parameters such that
# real_value ~= scale * (tensor - zero_point).
QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])

def calcScaleZeroPoint(min_val, max_val, num_bits=8):
  """Compute affine quantization parameters for the range [min_val, max_val].

  The range is mapped onto the unsigned integer grid [0, 2**num_bits - 1].

  Args:
    min_val, max_val: bounds of the real-valued range (numbers or 0-dim
      tensors).
    num_bits: width of the quantized representation.

  Returns:
    (scale, zero_point): `scale` is the real-valued step between adjacent
    quantized levels; `zero_point` is the integer level representing real 0,
    clamped to the grid and truncated to int.
  """
  qmin = 0.
  qmax = 2.**num_bits - 1.

  scale = (max_val - min_val) / (qmax - qmin)
  if scale == 0:
    # Degenerate range (max_val == min_val): a zero step would divide by zero
    # below and yield an inf/NaN zero point. Fall back to a unit step.
    scale = 1.0

  initial_zero_point = qmin - min_val / scale

  # Clamp the zero point onto the representable grid.
  if initial_zero_point < qmin:
      zero_point = qmin
  elif initial_zero_point > qmax:
      zero_point = qmax
  else:
      zero_point = initial_zero_point

  zero_point = int(zero_point)

  return scale, zero_point

def quantize_tensor(x, num_bits=8, min_val=None, max_val=None):
    """Affine-quantize `x` onto the unsigned grid [0, 2**num_bits - 1].

    Args:
        x: float tensor to quantize.
        num_bits: width of the quantized representation.
        min_val, max_val: calibration range. A bound left as None is taken
            from `x` itself. A legitimate bound of 0 is honoured (the check
            is `is None`, not truthiness).

    Returns:
        QTensor(tensor, scale, zero_point); `tensor` is uint8 for
        num_bits <= 8 and int64 otherwise, so wider grids do not wrap around.
    """
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()
    if min_val == max_val:
        # An empty calibration range carries no information; use the range
        # actually observed in x instead.
        min_val, max_val = x.min(), x.max()

    qmin = 0.
    qmax = 2.**num_bits - 1.

    scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits)
    # `zero_point + x / scale` is a fresh tensor, so the in-place ops are safe.
    q_x = (zero_point + x / scale).clamp_(qmin, qmax).round_()
    storage_dtype = torch.uint8 if num_bits <= 8 else torch.int64
    q_x = q_x.to(storage_dtype)

    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)

def dequantize_tensor(q_x):
    """Map a QTensor back to real values: scale * (tensor - zero_point)."""
    centred = q_x.tensor.float() - q_x.zero_point
    return q_x.scale * centred

In [None]:
def quantizeLayer(x, layer, stat, scale_x, zp_x):
  """Run one conv/linear layer + ReLU with simulated quantized weights.

  The layer's weight and bias are quantized to 8 bits, mapped back to real
  values and temporarily swapped into the layer; the arithmetic itself is
  carried out in floating point. The layer's original parameters are always
  restored, even if the forward pass raises.

  Args:
    x: quantized input activations (values on the input's quantized grid).
    layer: module exposing `weight` and (optionally) `bias`.
    stat: dict with 'min' and 'max' of this layer's output activations,
      used to derive the output quantization parameters.
    scale_x, zp_x: scale and zero point of the input `x`.

  Returns:
    (x, scale_next, zero_point_next): the output on the next layer's
    quantized grid (not rounded) and that grid's scale and zero point.
  """
  # cache old values
  saved_weight = layer.weight.data
  has_bias = layer.bias is not None
  saved_bias = layer.bias.data if has_bias else None

  # quantise weights, activations are already quantised
  w = quantize_tensor(saved_weight)

  scale_next, zero_point_next = calcScaleZeroPoint(min_val=stat['min'], max_val=stat['max'])

  # Preparing input by shifting
  X = x.float() - zp_x

  try:
    # Folding the input scale into the weights makes layer(X) produce
    # real-valued outputs from the shifted quantized input.
    layer.weight.data = scale_x * w.scale * (w.tensor.float() - w.zero_point)
    if has_bias:
      b = quantize_tensor(saved_bias)
      # Dequantization subtracts the zero point: real = scale * (q - zp).
      layer.bias.data = b.scale * (b.tensor.float() - b.zero_point)

    # Re-express the real-valued output on the next layer's quantized grid.
    x = (layer(X) / scale_next) + zero_point_next
  finally:
    # Reset weights for next forward pass
    layer.weight.data = saved_weight
    if has_bias:
      layer.bias.data = saved_bias

  # Perform relu too: real 0 sits at zero_point_next on the quantized grid,
  # so the clamp happens there (identical to F.relu when the zero point is 0).
  x = torch.clamp(x, min=zero_point_next)

  return x, scale_next, zero_point_next


We now gather the min and max statistics needed to quantise the network's activations: we run the network on around 1000 examples and record the average minimum and maximum activation values before and after each layer.

In [None]:
# Accumulate the per-sample min and max of x into stats[key]
def updateStats(x, stats, key):
  """Add one batch's per-sample extrema to the running totals in `stats`.

  Args:
    x: 2-D tensor with one row per sample.
    stats: dict mapping key -> {'max': float, 'min': float, 'total': int};
      updated in place.
    key: name of the layer the activations belong to.

  Returns:
    The same `stats` dict. 'max' / 'min' hold the sums of the per-sample
    maxima / minima and 'total' the number of samples seen, so that
    sum / total is the average per-sample extremum.
  """
  max_val, _ = torch.max(x, dim=1)
  min_val, _ = torch.min(x, dim=1)

  # Plain Python floats throughout, so the first and later batches
  # accumulate the same way.
  entry = stats.setdefault(key, {"max": 0.0, "min": 0.0, "total": 0})
  entry['max'] += max_val.sum().item()
  entry['min'] += min_val.sum().item()
  # The sums run over samples, so the divisor must count samples, not batches.
  entry['total'] += x.shape[0]

  return stats

# Reworked Forward Pass to access activation Stats through updateStats function
def gatherActivationStats(model, x, stats):
  """Forward `x` through the model, recording the input range of each layer.

  Before every layer the current activations are flattened to one row per
  sample and passed to `updateStats` under that layer's name
  ('conv1' ... 'conv4', 'fc1'). Returns the updated `stats`.
  """
  for layer_name in ('conv1', 'conv2', 'conv3', 'conv4'):
    per_sample = x.clone().view(x.shape[0], -1)
    stats = updateStats(per_sample, stats, layer_name)
    conv = getattr(model, layer_name)
    x = F.relu(conv(x))

  x = x.view(-1, 32*1*1)
  stats = updateStats(x, stats, 'fc1')

  # The final layer is still evaluated, although its output is not recorded.
  x = model.fc1(x)

  return stats

# Entry function to get the activation stats of all layers.
def gatherStats(model, test_loader):
    """Average the per-layer activation extrema over a data loader.

    Args:
        model: network to calibrate; put into eval mode.
        test_loader: iterable of (data, target) batches.

    Returns:
        dict mapping layer name -> {'max': ..., 'min': ...}, each being the
        accumulated value divided by the accumulated 'total'.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    stats = {}
    with torch.no_grad():
        for data, _ in test_loader:
            stats = gatherActivationStats(model, data.to(device), stats)

    return {
        key: {"max": value["max"] / value["total"], "min": value["min"] / value["total"]}
        for key, value in stats.items()
    }

## Forward Pass for Quantised Inference

In [None]:
def quantForward(model, x, stats):
  """Forward pass that simulates quantized inference through the conv layers.

  The input is quantized with the 'conv1' statistics; each conv layer is
  then run through `quantizeLayer`, using the statistics recorded for the
  following layer's input as its output range. The result is dequantized
  before the final linear layer. Returns log-probabilities.
  """
  # Quantise before inputting into incoming layers
  q_input = quantize_tensor(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'])
  x, scale_next, zero_point_next = q_input.tensor, q_input.scale, q_input.zero_point

  # (layer, key of the stats describing that layer's output range)
  pipeline = (
    (model.conv1, 'conv2'),
    (model.conv2, 'conv3'),
    (model.conv3, 'conv4'),
    (model.conv4, 'fc1'),
  )
  for layer, output_key in pipeline:
    x, scale_next, zero_point_next = quantizeLayer(x, layer, stats[output_key], scale_next, zero_point_next)

  x = x.view(-1, 32*1*1)

  # Back to dequant for final layer
  x = dequantize_tensor(QTensor(tensor=x, scale=scale_next, zero_point=zero_point_next))

  logits = model.fc1(x)
  return F.log_softmax(logits, dim=1)

# Testing Function for Quantisation

In [None]:
def testQuant(model, test_loader, quant=False, stats=None):
    device = 'cuda'

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if quant:
              output = quantForward(model, data, stats)
            else:
              output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

# Get Accuracy of Quantised Model

In [None]:
# Calibrate on a fresh copy of the trained fp32 model, then evaluate it with
# the simulated-quantization forward pass.
q_model = deepcopy(model_backup)
stats = gatherStats(q_model, test_loader)
print(stats)
testQuant(q_model, test_loader, stats=stats, quant=True)