In [1]:
import os
import torch
import torch.nn as nn 
from tqdm import tqdm
from pathlib import Path
import torchvision.datasets as datasets
from torchvision.transforms import transforms

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## Load MNIST datasets

In [3]:
transform = transforms.Compose([transforms.ToTensor(),
                    transforms.Normalize((0.1307),(0.3081)),
                    ])

mnist_trainset  = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader    = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

mnist_testset   = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader     = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## Create Model

In [4]:
class SimpleNet(nn.Module):

  def __init__(self,hidden_size_1=100,hidden_size_2=100):
    super(SimpleNet,self).__init__()
    self.linear1  = nn.Linear(28*28,hidden_size_1)
    self.linear2  = nn.Linear(hidden_size_1,hidden_size_2)
    self.linear3  = nn.Linear(hidden_size_2,10)
    self.relu     = nn.ReLU()

  def forward(self,img):
    x = img.reshape(-1,28*28)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)
    return x

In [5]:
from torchsummary import summary

model = SimpleNet().to(device)
print(model)
summary(model, (1, 28, 28))

SimpleNet(
  (linear1): Linear(in_features=784, out_features=100, bias=True)
  (linear2): Linear(in_features=100, out_features=100, bias=True)
  (linear3): Linear(in_features=100, out_features=10, bias=True)
  (relu): ReLU()
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 100]          78,500
              ReLU-2                  [-1, 100]               0
            Linear-3                  [-1, 100]          10,100
              ReLU-4                  [-1, 100]               0
            Linear-5                   [-1, 10]           1,010
Total params: 89,610
Trainable params: 89,610
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.34
Estimated Total Size (MB): 0.35
----------------------------------------------------------------


## Model Training

In [6]:
def train(train_loader,model,epochs=5,total_iteration_limit=None):

  cross_entro_loss  = nn.CrossEntropyLoss()
  optimizer         = torch.optim.Adam(model.parameters(),lr=0.001) #Passing the parameters to an optimizer.
  total_iteration   = 0

  for epoch in range(epochs):
    model.train()

    loss_sum        = 0
    num_iterations  = 0

    data_iterator   = tqdm(train_loader, desc=f'Epoch  {epoch+1}')

    if total_iteration_limit is not None:
      data_iterator.total = total_iteration_limit

    for data in data_iterator:
      num_iterations  += 1
      total_iteration += 1
      image,label     = data
      image           = image.to(device)
      label           = label.to(device)
      optimizer.zero_grad()
      output          = model(image.view(-1,28*28))
      loss            = cross_entro_loss(output,label)
      loss_sum       += loss.item()
      avg_loss        = loss_sum / num_iterations
      data_iterator.set_postfix(loss=avg_loss)          # update progress bar with loss
      loss.backward()                                   # backward pass to calculate gradients
      optimizer.step()                                  # Updates the model parameters using the computed gradients to minimize the loss.

      if total_iteration_limit is not None and total_iteration >= total_iteration_limit:
        return

In [7]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    model.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_loader, model, epochs=2)
    # Save the model to disk
    torch.save(model.state_dict(), MODEL_FILENAME)

Loaded model from disk


In [8]:
print_size_of_model(model)

Size (KB): 361.062


- **optimizer.zero_grad()**

    - For every mini-batch during the training phase, we typically want to explicitly set the gradients to zero before starting to do backpropagation (i.e., updating the Weights and biases) because PyTorch accumulates the gradients on subsequent backward passes.

    - This accumulating behavior is convenient while training RNNs or when we want to compute the gradient of the loss summed over multiple mini-batches. So, the default action has been set to accumulate (i.e. sum) the gradients on every loss.backward() call.
    
    - when you start your training loop, ideally you should zero out the gradients so that you do the parameter update correctly. Otherwise, the gradient would be a combination of the old gradient, which you have already used to update your model parameters and the newly-computed gradient.
    It would therefore point in some other direction than the intended direction towards the minimum (or maximum, in case of maximization objectives).

- **data_iterator.set_postfix()**


    allows you to display additional information (in this case, the loss) next to the progress bar.



In [9]:
for key in model.state_dict():
  print(f"{key}: {model.state_dict().get(key).mean()}")

linear1.weight: 0.007502851076424122
linear1.bias: -0.03206382691860199
linear2.weight: -0.027867062017321587
linear2.bias: -0.1109653189778328
linear3.weight: -0.10847935825586319
linear3.bias: -0.021669870242476463


## Test function

In [10]:
def test(model:nn.Module,total_iteration:int=None):
  correct     = 0
  total       = 0
  iterations  = 0
  model.eval()

  with torch.no_grad():                           # When evaluating the model on a validation set or performing inference on new data, you do not need gradient computations.
    for data in tqdm(test_loader,desc="Testing"):
      image, label = data
      image   = image.to(device)
      label   = label.to(device)
      output  = model(image.reshape(-1,784))

      for index, i in enumerate(output):
        if torch.argmax(i) == label[index]:
          correct += 1
        total +=1

      iterations += 1
      if total_iteration is not None and iterations >= total_iteration:
        break
    print(f'Accuracy: {round(correct/total, 3)}')

## Print weights and size of the model before quantization

In [11]:
print("Weight before quantization")
print(model.linear1.weight)
print(model.linear1.weight.dtype)
print_size_of_model(model)

Weight before quantization
Parameter containing:
tensor([[ 0.0393, -0.0006,  0.0137,  ...,  0.0013, -0.0181,  0.0328],
        [-0.0207,  0.0032,  0.0330,  ..., -0.0303, -0.0180,  0.0186],
        [ 0.0638,  0.1237,  0.0910,  ...,  0.0883,  0.1152,  0.1136],
        ...,
        [ 0.0175, -0.0349, -0.0273,  ..., -0.0308, -0.0263, -0.0293],
        [ 0.0780,  0.0357,  0.0261,  ...,  0.0798,  0.0573,  0.0307],
        [ 0.0030,  0.0069,  0.0268,  ...,  0.0027, -0.0172,  0.0402]],
       device='cuda:0', requires_grad=True)
torch.float32
Size (KB): 361.062


In [12]:
print(f'Accuracy of the model before quantization: ')
test(model)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 1000/1000 [00:02<00:00, 382.65it/s]

Accuracy: 0.969





## Insert Min-Max Observation in the Model

In [28]:
class QuantizedVerySimpleNet(nn.Module):

  def __init__(self,hidden_size_1=100,hidden_size_2=100):
    super(QuantizedVerySimpleNet,self).__init__()
    self.quantize   = torch.quantization.QuantStub()
    self.linear1    = nn.Linear(28*28,hidden_size_1)
    self.linear2    = nn.Linear(hidden_size_1,hidden_size_2)
    self.linear3    = nn.Linear(hidden_size_2,10)
    self.relu       = nn.ReLU()
    self.dequantize = torch.quantization.DeQuantStub()

  def forward(self,img):
    x = img.reshape(-1,(28*28))
    x = self.quantize(x)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)
    x = self.dequantize(x)
    return x

quantize_model = QuantizedVerySimpleNet().to(device)
print(quantize_model)
summary(quantize_model, (1, 28, 28))

QuantizedVerySimpleNet(
  (quantize): QuantStub()
  (linear1): Linear(in_features=784, out_features=100, bias=True)
  (linear2): Linear(in_features=100, out_features=100, bias=True)
  (linear3): Linear(in_features=100, out_features=10, bias=True)
  (relu): ReLU()
  (dequantize): DeQuantStub()
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         QuantStub-1                  [-1, 784]               0
            Linear-2                  [-1, 100]          78,500
              ReLU-3                  [-1, 100]               0
            Linear-4                  [-1, 100]          10,100
              ReLU-5                  [-1, 100]               0
            Linear-6                   [-1, 10]           1,010
       DeQuantStub-7                   [-1, 10]               0
Total params: 89,610
Trainable params: 89,610
Non-trainable params: 0
---------------------------------------------------------

In [29]:
 # copy the weight from unquntized model
quantize_model.load_state_dict(model.state_dict())
quantize_model.eval()

QuantizedVerySimpleNet(
  (quantize): QuantStub()
  (linear1): Linear(in_features=784, out_features=100, bias=True)
  (linear2): Linear(in_features=100, out_features=100, bias=True)
  (linear3): Linear(in_features=100, out_features=10, bias=True)
  (relu): ReLU()
  (dequantize): DeQuantStub()
)

In [30]:
quantize_model.qconfig  = torch.ao.quantization.default_qconfig         # set default quntize
quantize_model          = torch.ao.quantization.prepare(quantize_model) # Insert observers
quantize_model

QuantizedVerySimpleNet(
  (quantize): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequantize): DeQuantStub()
)

## Calibrate the model using the test set

In [31]:
test(quantize_model)

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 256.33it/s]

Accuracy: 0.969





In [33]:
print('Testing the model after quantization')
test(quantize_model)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:04<00:00, 235.26it/s]

Accuracy: 0.969





In [32]:
print(f'Check statistics of the various layers')
quantize_model

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quantize): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-86.30210876464844, max_val=53.823486328125)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-83.1438980102539, max_val=57.77360534667969)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-125.01261901855469, max_val=38.11832046508789)
  )
  (relu): ReLU()
  (dequantize): DeQuantStub()
)

## Quantize the model using the statistics collected

In [35]:
quantize_model_c = torch.ao.quantization.convert(quantize_model)

print(f'Check statistics of the various layers')
quantize_model_c

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quantize): Quantize(scale=tensor([0.0256], device='cuda:0'), zero_point=tensor([17], device='cuda:0'), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=1.10335111618042, zero_point=78, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=1.1095867156982422, zero_point=75, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=1.2844955921173096, zero_point=97, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequantize): DeQuantize()
)

In [36]:
quantize_model

QuantizedVerySimpleNet(
  (quantize): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-86.30210876464844, max_val=53.823486328125)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-83.1438980102539, max_val=57.77360534667969)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-125.01261901855469, max_val=38.11832046508789)
  )
  (relu): ReLU()
  (dequantize): DeQuantStub()
)

## Print weights of the model after quantization

In [20]:
# Print the weights matrix of the model after quantization
print('Weights after quantization')
print(torch.int_repr(quantize_model.linear1.weight()))

Weights after quantization
tensor([[ 4,  0,  1,  ...,  0, -2,  3],
        [-2,  0,  3,  ..., -3, -2,  2],
        [ 6, 12,  9,  ...,  9, 11, 11],
        ...,
        [ 2, -3, -3,  ..., -3, -3, -3],
        [ 8,  3,  3,  ...,  8,  6,  3],
        [ 0,  1,  3,  ...,  0, -2,  4]], device='cuda:0', dtype=torch.int8)


## Compare the dequantized weights and the original weights

In [21]:
print('Original weights: ')
print(model.linear1.weight)
print('')
print(f'Dequantized weights: ')
print(torch.dequantize(quantize_model.linear1.weight()))
print('')

Original weights: 
Parameter containing:
tensor([[ 0.0393, -0.0006,  0.0137,  ...,  0.0013, -0.0181,  0.0328],
        [-0.0207,  0.0032,  0.0330,  ..., -0.0303, -0.0180,  0.0186],
        [ 0.0638,  0.1237,  0.0910,  ...,  0.0883,  0.1152,  0.1136],
        ...,
        [ 0.0175, -0.0349, -0.0273,  ..., -0.0308, -0.0263, -0.0293],
        [ 0.0780,  0.0357,  0.0261,  ...,  0.0798,  0.0573,  0.0307],
        [ 0.0030,  0.0069,  0.0268,  ...,  0.0027, -0.0172,  0.0402]],
       device='cuda:0', requires_grad=True)

Dequantized weights: 
tensor([[ 0.0409,  0.0000,  0.0102,  ...,  0.0000, -0.0204,  0.0307],
        [-0.0204,  0.0000,  0.0307,  ..., -0.0307, -0.0204,  0.0204],
        [ 0.0613,  0.1226,  0.0920,  ...,  0.0920,  0.1124,  0.1124],
        ...,
        [ 0.0204, -0.0307, -0.0307,  ..., -0.0307, -0.0307, -0.0307],
        [ 0.0817,  0.0307,  0.0307,  ...,  0.0817,  0.0613,  0.0307],
        [ 0.0000,  0.0102,  0.0307,  ...,  0.0000, -0.0204,  0.0409]],
       device='cuda:0')


In [22]:
print_size_of_model(quantize_model)

Size (KB): 95.458


In [23]:
model_tensor = [value.data for key,value in model.named_parameters()]

In [24]:
linear1_weight = quantize_model.linear1.weight().dequantize().detach().cpu()
model_tensor[0].detach().cuda()

tensor([[ 0.0393, -0.0006,  0.0137,  ...,  0.0013, -0.0181,  0.0328],
        [-0.0207,  0.0032,  0.0330,  ..., -0.0303, -0.0180,  0.0186],
        [ 0.0638,  0.1237,  0.0910,  ...,  0.0883,  0.1152,  0.1136],
        ...,
        [ 0.0175, -0.0349, -0.0273,  ..., -0.0308, -0.0263, -0.0293],
        [ 0.0780,  0.0357,  0.0261,  ...,  0.0798,  0.0573,  0.0307],
        [ 0.0030,  0.0069,  0.0268,  ...,  0.0027, -0.0172,  0.0402]],
       device='cuda:0')

## Print size and accuracy of the quantized model

In [38]:
print('Size of the model after quantization')
print_size_of_model(quantize_model)

Size of the model after quantization
Size (KB): 364.834


In [37]:
print('Testing the model after quantization')
test(quantize_model)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:04<00:00, 226.00it/s]

Accuracy: 0.969



