In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os 

In [2]:
_=torch.manual_seed(0) ## make torch deterministic 

In [3]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])
## Loading MNIST dataset
mnist_train = datasets.MNIST(root='/data',train=True,download=True,transform=transform)
mnist_test = datasets.MNIST(root='/data',train=False,download=True,transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_train,batch_size=10,shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test,batch_size=10,shuffle=True)

In [4]:
class ClassNetQuantized(nn.Module):
    def __init__(self,  hidden_size_1=100,hidden_size_2=100):
        super(ClassNetQuantized,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.Linear1 = nn.Linear(28*28,hidden_size_1)
        self.Linear2 = nn.Linear(hidden_size_1,hidden_size_2)
        self.Linear3 = nn.Linear(hidden_size_2,10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()
    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.quant(x)
        x = self.relu(self.Linear1(x))
        x = self.relu(self.Linear2(x))
        x = self.Linear3(x)
        x = self.dequant(x)
        return x

In [5]:
device = "cuda"
net = ClassNetQuantized().to(device)

In [6]:
net.qconfig = torch.ao.quantization.default_qconfig
net.train()
net_quantized = torch.ao.quantization.prepare_qat(net)
net_quantized

ClassNetQuantized(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (Linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (Linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (Linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [7]:
def train(train_loader,model,epochs,total_iteration_limit=None) : 
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
    total_iterations = 0 
    
    for epoch in range(epochs) :
        model.train()
        loss_sum = 0
        num_iterations = 0
        data_iterator = tqdm(train_loader,desc=f"Epoch {epoch+1}")
        if  total_iteration_limit is not None:
            data_iterator.total = total_iteration_limit
        for data in data_iterator : 
            num_iterations += 1 
            total_iterations += 1
            x,y=data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            out = model(x.view(-1,28*28))
            loss = loss_fn(out,y)
            loss_sum += loss
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()
            if total_iteration_limit is not None and total_iterations >= total_iteration_limit:
                return


def print_size_model(model):
    torch.save(model.state_dict(),"temp_delme.p")
    print(f"size (KB) :" , os.path.getsize("temp_delme.p")/1e3)
    os.remove("temp_delme.p")
    
MODEL_FILE_NAME = 'classnet.pt'
if Path(MODEL_FILE_NAME).exists():
    net.load_state_dict(torch.load(MODEL_FILE_NAME))
    print(f"Model Loaded !")
else : 
    train(train_loader,net,epochs=2)
    torch.save(net.state_dict(),MODEL_FILE_NAME)

Model Loaded !


  net.load_state_dict(torch.load(MODEL_FILE_NAME))


In [8]:
train(train_loader,net_quantized,epochs=2)

Epoch 1: 100%|██████████| 6000/6000 [00:52<00:00, 113.88it/s, loss=tensor(0.2235, device='cuda:0', grad_fn=<DivBackward0>)]
Epoch 2: 100%|██████████| 6000/6000 [00:51<00:00, 116.17it/s, loss=tensor(0.1145, device='cuda:0', grad_fn=<DivBackward0>)]


In [9]:
print(f"Checking Stats collected during training")
net_quantized

Checking Stats collected during training


ClassNetQuantized(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (Linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.6958182454109192, max_val=0.4646841287612915)
    (activation_post_process): MinMaxObserver(min_val=-53.30250930786133, max_val=46.89039611816406)
  )
  (Linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5422224998474121, max_val=0.4541158676147461)
    (activation_post_process): MinMaxObserver(min_val=-55.56305694580078, max_val=32.61673355102539)
  )
  (Linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5476884841918945, max_val=0.22404539585113525)
    (activation_post_process): MinMaxObserver(min_val=-53.4817008972168, max_val=28.439607620239258)
  )
  (relu): ReLU()
  (dequant): DeQuant

In [10]:
net_quantized.eval()
net_quantized = torch.ao.quantization.convert(net_quantized)

In [11]:
print(f"Stats of # layers")
net_quantized

Stats of # layers


ClassNetQuantized(
  (quant): Quantize(scale=tensor([0.0256], device='cuda:0'), zero_point=tensor([17], device='cuda:0'), dtype=torch.quint8)
  (Linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.7889204621315002, zero_point=68, qscheme=torch.per_tensor_affine)
  (Linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.6943290829658508, zero_point=80, qscheme=torch.per_tensor_affine)
  (Linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.6450496912002563, zero_point=83, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [12]:
print(f"Weights after quantization")
print(torch.int_repr(net_quantized.Linear1.weight()))

Weights after quantization
tensor([[ 2,  5, -4,  ...,  6,  2,  2],
        [-6, -5, -4,  ..., -6, -3, -7],
        [ 1,  8, -1,  ...,  1,  5,  7],
        ...,
        [11, 11,  5,  ...,  4,  8,  1],
        [-3, -1,  5,  ...,  1,  1,  1],
        [ 3,  2,  0,  ...,  6, -3,  1]], device='cuda:0', dtype=torch.int8)


In [13]:
def test(model,device,total_iterations=None):
    correct=0
    total=0
    iterations=0
    model.eval()
    with torch.no_grad():
        for data in tqdm(test_loader,desc=f"Testing") : 
            x,y=data
            x = x.to(device)
            y = y.to(device)
            out = model(x.view(-1,28*28))
            for idx,i in enumerate(out):
                if torch.argmax(i) == y[idx]:
                    correct+=1
                total+=1
            iterations+=1
            if total_iterations is not None and iterations>= total_iterations : 
                break
    print(f"Accuracy : {round(correct/total,3)}")
        