In [3]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os 

In [None]:
_=torch.manual_seed(0) ## make torch deterministic 

In [5]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])
## Loading MNIST dataset
mnist_train = datasets.MNIST(root='/data',train=True,download=True,transform=transform)
mnist_test = datasets.MNIST(root='/data',train=False,download=True,transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_train,batch_size=10,shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test,batch_size=10,shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:14<00:00, 688kB/s] 


Extracting /data\MNIST\raw\train-images-idx3-ubyte.gz to /data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 202kB/s]


Extracting /data\MNIST\raw\train-labels-idx1-ubyte.gz to /data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:02<00:00, 690kB/s]


Extracting /data\MNIST\raw\t10k-images-idx3-ubyte.gz to /data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]

Extracting /data\MNIST\raw\t10k-labels-idx1-ubyte.gz to /data\MNIST\raw






In [7]:
device = "cuda"
class ClassNet(nn.Module) : 
    def __init__(self,  hidden_size_1=100,hidden_size_2=100):
        super().__init__()
        self.Linear1 = nn.Linear(28*28,hidden_size_1)
        self.Linear2 = nn.Linear(hidden_size_1,hidden_size_2)
        self.Linear3 = nn.Linear(hidden_size_2,10)
        self.relu = nn.ReLU()
    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.relu(self.Linear1(x))
        x = self.relu(self.Linear2(x))
        x = self.Linear3(x)
        
        return x

In [8]:
net = ClassNet().to(device)

In [9]:
def train(train_loader,model,epochs,total_iteration_limit=None) : 
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
    total_iterations = 0 
    
    for epoch in range(epochs) :
        model.train()
        loss_sum = 0
        num_iterations = 0
        data_iterator = tqdm(train_loader,desc=f"Epoch {epoch+1}")
        if  total_iteration_limit is not None:
            data_iterator.total = total_iteration_limit
        for data in data_iterator : 
            num_iterations += 1 
            total_iterations += 1
            x,y=data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            out = model(x.view(-1,28*28))
            loss = loss_fn(out,y)
            loss_sum += loss
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()
            if total_iteration_limit is not None and total_iterations >= total_iteration_limit:
                return


def print_size_model(model):
    torch.save(model.state_dict(),"temp_delme.p")
    print(f"size (KB) :" , os.path.getsize("temp_delme.p")/1e3)
    os.remove("temp_delme.p")
    
MODEL_FILE_NAME = 'classnet.pt'
if Path(MODEL_FILE_NAME).exists():
    net.load_state_dict(torch.load(MODEL_FILE_NAME))
    print(f"Model Loaded !")
else : 
    train(train_loader,net,epochs=2)
    torch.save(net.state_dict(),MODEL_FILE_NAME)

Epoch 1: 100%|██████████| 6000/6000 [00:46<00:00, 128.19it/s, loss=tensor(0.2231, device='cuda:0', grad_fn=<DivBackward0>)]
Epoch 2: 100%|██████████| 6000/6000 [00:42<00:00, 141.77it/s, loss=tensor(0.1129, device='cuda:0', grad_fn=<DivBackward0>)]


In [10]:
def test(model,total_iterations=None):
    correct=0
    total=0
    iterations=0
    model.eval()
    with torch.no_grad():
        for data in tqdm(test_loader,desc=f"Testing") : 
            x,y=data
            x = x.to(device)
            y = y.to(device)
            out = model(x.view(-1,28*28))
            for idx,i in enumerate(out):
                if torch.argmax(i) == y[idx]:
                    correct+=1
                total+=1
            iterations+=1
            if total_iterations is not None and iterations>= total_iterations : 
                break
    print(f"Accuracy : {round(correct/total,3)}")
        

In [11]:
test(net)

Testing: 100%|██████████| 1000/1000 [00:05<00:00, 177.71it/s]

Accuracy : 0.963





In [13]:
print(f"Weights before quantization")
print(net.Linear1.weight)
print(net.Linear1.weight.dtype) ## torch.float32

Weights before quantization
Parameter containing:
tensor([[-0.0029,  0.0165, -0.0321,  ...,  0.0193,  0.0011, -0.0006],
        [ 0.0196,  0.0244,  0.0290,  ...,  0.0191,  0.0334,  0.0095],
        [ 0.0075,  0.0425, -0.0056,  ...,  0.0073,  0.0288,  0.0357],
        ...,
        [ 0.0838,  0.0876,  0.0529,  ...,  0.0476,  0.0668,  0.0299],
        [ 0.0105,  0.0187,  0.0507,  ...,  0.0310,  0.0310,  0.0305],
        [ 0.0149,  0.0096, -0.0046,  ...,  0.0318, -0.0175,  0.0027]],
       device='cuda:0', requires_grad=True)
torch.float32


In [14]:
print(f"Size of the model before quantization")
print_size_model(net)

Size of the model before quantization
size (KB) : 361.062


In [15]:
class ClassNetQuantized(nn.Module):
    def __init__(self,  hidden_size_1=100,hidden_size_2=100):
        super(ClassNetQuantized,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.Linear1 = nn.Linear(28*28,hidden_size_1)
        self.Linear2 = nn.Linear(hidden_size_1,hidden_size_2)
        self.Linear3 = nn.Linear(hidden_size_2,10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()
    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.quant(x)
        x = self.relu(self.Linear1(x))
        x = self.relu(self.Linear2(x))
        x = self.Linear3(x)
        x = self.dequant(x)
        return x

In [36]:
net_quantized = ClassNetQuantized().to(device)

In [37]:
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()
net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized)
net_quantized

ClassNetQuantized(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (Linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (Linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (Linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [20]:
print(f"Statistics gathered by the observers after testing the model")
net_quantized

Statistics gathered by the observers after testing the model


ClassNetQuantized(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (Linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-65.63235473632812, max_val=35.30819320678711)
  )
  (Linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-38.02914810180664, max_val=34.694541931152344)
  )
  (Linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-40.19949722290039, max_val=24.341419219970703)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [21]:
### Quantize the model using the collected statistics

net_quantized = torch.ao.quantization.convert(net_quantized)

In [22]:
print(f"Statistics of the various layers after quantization")
net_quantized

Statistics of the various layers after quantization


ClassNetQuantized(
  (quant): Quantize(scale=tensor([0.0256], device='cuda:0'), zero_point=tensor([17], device='cuda:0'), dtype=torch.quint8)
  (Linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.794807493686676, zero_point=83, qscheme=torch.per_tensor_affine)
  (Linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.5726274847984314, zero_point=66, qscheme=torch.per_tensor_affine)
  (Linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.5081961750984192, zero_point=79, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [26]:
print(f"first linear layer")
print(torch.int_repr(net_quantized.Linear1.weight()))

first linear layer
tensor([[ 0,  3, -5,  ...,  3,  0,  0],
        [ 3,  4,  5,  ...,  3,  5,  1],
        [ 1,  7, -1,  ...,  1,  5,  6],
        ...,
        [13, 14,  8,  ...,  7, 11,  5],
        [ 2,  3,  8,  ...,  5,  5,  5],
        [ 2,  2, -1,  ...,  5, -3,  0]], device='cuda:0', dtype=torch.int8)


In [32]:
print(f"Original Weights")
print(net.Linear1.weight)
print(f"")
print(f"Dequantized Weights")
print(torch.dequantize(net_quantized.Linear1.weight()))

Original Weights
Parameter containing:
tensor([[-0.0029,  0.0165, -0.0321,  ...,  0.0193,  0.0011, -0.0006],
        [ 0.0196,  0.0244,  0.0290,  ...,  0.0191,  0.0334,  0.0095],
        [ 0.0075,  0.0425, -0.0056,  ...,  0.0073,  0.0288,  0.0357],
        ...,
        [ 0.0838,  0.0876,  0.0529,  ...,  0.0476,  0.0668,  0.0299],
        [ 0.0105,  0.0187,  0.0507,  ...,  0.0310,  0.0310,  0.0305],
        [ 0.0149,  0.0096, -0.0046,  ...,  0.0318, -0.0175,  0.0027]],
       device='cuda:0', requires_grad=True)

Dequantized Weights
tensor([[ 0.0000,  0.0191, -0.0318,  ...,  0.0191,  0.0000,  0.0000],
        [ 0.0191,  0.0254,  0.0318,  ...,  0.0191,  0.0318,  0.0064],
        [ 0.0064,  0.0445, -0.0064,  ...,  0.0064,  0.0318,  0.0381],
        ...,
        [ 0.0826,  0.0889,  0.0508,  ...,  0.0445,  0.0699,  0.0318],
        [ 0.0127,  0.0191,  0.0508,  ...,  0.0318,  0.0318,  0.0318],
        [ 0.0127,  0.0127, -0.0064,  ...,  0.0318, -0.0191,  0.0000]],
       device='cuda:0')


In [34]:
print_size_model(net)
print_size_model(net_quantized)

size (KB) : 361.062
size (KB) : 95.458


In [38]:
## Accuracies 
test(net)
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 387.63it/s]


Accuracy : 0.963


Testing: 100%|██████████| 1000/1000 [00:02<00:00, 340.00it/s]

Accuracy : 0.963



