In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

In [7]:
# Loading the MNIST Dataset
# Make torch deterministic
_ = torch.manual_seed(0)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # Format (mean, ) every value is a channel

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Dataloader
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size = 10, shuffle = True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size = 10, shuffle = True)

# Define the device
device = "cpu"

In [8]:
# Model Definition

class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1 = 100, hidden_size_2=100):
        super(VerySimpleNet, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        
    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x
    
net = VerySimpleNet().to(device)

In [10]:
# Training the loop

def train(train_loader, net, epochs=5, total_iterations_limit=None):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    
    # Set Total iterations to zero
    total_iterations = 0
    
    # Start the for loop for epochs
    for epoch in range(epochs):
        loss_sum = 0
        num_iterations = 0

        # Set net to train mode
        net.train()
        # Set loss and num_iterations to zero
        loss, num_iterations = 0, 0
        # Have a data iterator using tqdm
        data_iter = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        # if total iterations limit is not none:
        if total_iterations_limit is not None:
            # Set total_iterations to total_iterations_limit
            total_iterations = total_iterations_limit
        # Start the for loop for iterator
        for x, y in data_iter:
            
            # Increment iterations and total_iterations
            total_iterations += 1
            
            # Move data and labels to device
            x = x.to(device)
            y= y.to(device)
            # Zero the optimizer gradients
            optimizer.zero_grad()
            # Forward pass
            output = net(x)
            # Loss calculation
            loss = ce_loss(output, y)
            # Perform backpropagation calculation
            loss.backward()
            # Perform update step
            optimizer.step()
            
            # metrics stuff:
            loss_sum += loss
            avg_loss = loss_sum / num_iterations
            data_iter.set_postfix(loss = avg_loss)
            
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

TEMP_MODEL_FILENAME = 'temp_simplenet_ptq.pt'
MODEL_FILENAME = "simplenet_ptq.pt"

def print_size_of_model(model):
    # Save the model?
    torch.save(model, TEMP_MODEL_FILENAME)
    # Print the size in KB using os.path.getsize
    print(f"Model size in KB {os.path.getsize(TEMP_MODEL_FILENAME)/1e3}")
    # Remove the model with os.remove
    os.remove(TEMP_MODEL_FILENAME)

# If the path exists, 
if Path(MODEL_FILENAME).exists():
    # Load the state dictionaries
    net.load_state_dict(torch.load(MODEL_FILENAME))
else:
    train(train_loader, net, epochs=1)
    # Save the model once the training is done
    torch.save(net.state_dict(), MODEL_FILENAME)
            
            

Epoch 1: 100%|██████████| 6000/6000 [00:15<00:00, 382.47it/s, loss=tensor(inf, grad_fn=<DivBackward0>)]


In [83]:
import time
from tqdm import tqdm

# Defining the testing loop with average running time calculation
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0
    iterations = 0
    times = []  # List to store batch times

    model.eval()

    # With torch no gradient
    with torch.no_grad():
        for x, y in tqdm(test_loader, desc="Testing"):
            # Start timing for this batch
            start_time = time.time()
            
            # Move input and labels to device
            x, y = x.to(device), y.to(device)
            
            # Get model output
            output = model(x)
            
            # Calculate accuracy for the batch
            for idx, i in enumerate(output):
                answer = torch.argmax(i)
                if answer == y[idx]:
                    correct += 1
                total += 1
            
            # Stop timing for this batch and store it
            batch_time = time.time() - start_time
            times.append(batch_time)

            # Increment iteration count
            iterations += 1
            if total_iterations == iterations:
                break
    
    # Compute overall accuracy
    accuracy = correct / total
    avg_time_per_batch = sum(times) / len(times)
    
    print(f"Accuracy: {round(accuracy, 3)}")
    print(f"Average Time per Batch: {round(avg_time_per_batch, 6)} seconds")

# Run the test on original model
print("Testing Original Model:")
test(net, 1000)


Testing Original Model:


Testing: 100%|█████████▉| 999/1000 [00:01<00:00, 751.74it/s]

Accuracy: 0.954
Average Time per Batch: 0.000211 seconds





In [23]:
# Printing out weights of the model before quantization
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)

Weights before quantization
Parameter containing:
tensor([[ 0.0138,  0.0333, -0.0153,  ...,  0.0360,  0.0178,  0.0162],
        [-0.0154, -0.0106, -0.0061,  ..., -0.0159, -0.0016, -0.0256],
        [-0.0080,  0.0270, -0.0211,  ..., -0.0082,  0.0133,  0.0202],
        ...,
        [ 0.0246,  0.0283, -0.0064,  ..., -0.0117,  0.0075, -0.0294],
        [-0.0184, -0.0103,  0.0217,  ...,  0.0020,  0.0021,  0.0016],
        [ 0.0084,  0.0031, -0.0110,  ...,  0.0254, -0.0239, -0.0038]],
       requires_grad=True)
torch.float32


In [24]:
# Print size of model before quantization
print_size_of_model(net)

Model size in KB 361.839


In [29]:
print(f'Accuracy of the model before quantization')
test(net)

Accuracy of the model before quantization


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 719.71it/s]

Accuracy: 0.954





In [31]:
# Inserting min-max observers in the model
class QuantizedVerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(QuantizedVerySimpleNet, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()
        
    # Forward
    def forward(self, x):
        x = x.view((-1, 28*28))
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x
        
        # Flatten out imagae
        # Quantize inputs
        # Forward till end
        # Dequantize
        # return

In [69]:
# Define net and put it todevide
import torch.ao.quantization


static_quantized_net = QuantizedVerySimpleNet().to(device)

# Copy weights by loading state dict from other model
static_quantized_net.load_state_dict(net.state_dict())

# Set to eval mode
static_quantized_net.eval()

# Get Quantization config and run quantization stuff
# torch.ao.quantization.default_qconfig
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MinMaxObserver
import functools
config = QConfig(
    activation = functools.partial(MinMaxObserver, quant_min=0, quant_max=127),
    weight = functools.partial(MinMaxObserver, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
)

static_quantized_net.qconfig = config
static_quantized_net = torch.ao.quantization.prepare(static_quantized_net) # Preparing observers
torch.ao.quantization.default_qconfig

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})

In [71]:
torch.ao.quantization.qconfig.default_per_channel_qconfig

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})

In [72]:
print(static_quantized_net)

QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)


In [73]:
## Calibrate the model using the test set 
test(static_quantized_net)

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 749.40it/s]

Accuracy: 0.954





In [74]:
print(static_quantized_net)

QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-46.70039367675781, max_val=28.700359344482422)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-27.921293258666992, max_val=27.60137939453125)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-32.47340393066406, max_val=22.86193084716797)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)


In [75]:
static_quantized_net = torch.ao.quantization.convert(static_quantized_net)

In [76]:
print('Original weights: ')
print(net.linear1.weight)
print('')
print(f'Dequantized weights: ')
print(torch.dequantize(static_quantized_net.linear1.weight()))
print('')

Original weights: 
Parameter containing:
tensor([[ 0.0138,  0.0333, -0.0153,  ...,  0.0360,  0.0178,  0.0162],
        [-0.0154, -0.0106, -0.0061,  ..., -0.0159, -0.0016, -0.0256],
        [-0.0080,  0.0270, -0.0211,  ..., -0.0082,  0.0133,  0.0202],
        ...,
        [ 0.0246,  0.0283, -0.0064,  ..., -0.0117,  0.0075, -0.0294],
        [-0.0184, -0.0103,  0.0217,  ...,  0.0020,  0.0021,  0.0016],
        [ 0.0084,  0.0031, -0.0110,  ...,  0.0254, -0.0239, -0.0038]],
       requires_grad=True)

Dequantized weights: 
tensor([[ 0.0132,  0.0352, -0.0132,  ...,  0.0352,  0.0176,  0.0176],
        [-0.0132, -0.0088, -0.0044,  ..., -0.0176,  0.0000, -0.0264],
        [-0.0088,  0.0264, -0.0220,  ..., -0.0088,  0.0132,  0.0220],
        ...,
        [ 0.0264,  0.0264, -0.0044,  ..., -0.0132,  0.0088, -0.0308],
        [-0.0176, -0.0088,  0.0220,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0088,  0.0044, -0.0132,  ...,  0.0264, -0.0220, -0.0044]])



In [77]:
print('Size of the model after quantization')
print_size_of_model(static_quantized_net)

Size of the model after quantization
Model size in KB 95.613


In [78]:
print('Testing the model after quantization')
test(static_quantized_net)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 759.20it/s]

Accuracy: 0.953





In [80]:
# Dynamic Post-training Quantization

from torch.quantization import default_dynamic_qconfig
model_fp32 = VerySimpleNet()

# Method # 1 to do it
qconfig_spec = {
    # Apply int8 to all layers
    torch.nn.Linear: torch.quantization.default_dynamic_qconfig,
    # For layer2, we use float16 quantization for weights
    # "linear2": torch.quantization.float16_dynamic_qconfig
}

net.eval()

dynamic_quantized_model = torch.quantization.quantize_dynamic(
    net,
    qconfig_spec= qconfig_spec,
    inplace=False
)

# # Method #2 to do it
# quantized_model = torch.quantization.quantize_dynamic(
#     model_fp32,
#     qconfig_spec={torch.nn.Linear},
#     dtype=torch.qint8,
#     inplace=False
# )
print(dynamic_quantized_model.linear1)
print(dynamic_quantized_model.linear2)
print(dynamic_quantized_model.linear1.weight())

DynamicQuantizedLinear(in_features=784, out_features=100, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
DynamicQuantizedLinear(in_features=100, out_features=100, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
tensor([[ 0.0132,  0.0352, -0.0132,  ...,  0.0352,  0.0176,  0.0176],
        [-0.0132, -0.0088, -0.0044,  ..., -0.0176,  0.0000, -0.0264],
        [-0.0088,  0.0264, -0.0220,  ..., -0.0088,  0.0132,  0.0220],
        ...,
        [ 0.0264,  0.0264, -0.0044,  ..., -0.0132,  0.0088, -0.0308],
        [-0.0176, -0.0088,  0.0220,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0088,  0.0044, -0.0132,  ...,  0.0264, -0.0220, -0.0044]],
       size=(100, 784), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.004404082428663969,
       zero_point=0)


In [81]:
# On the side, mapping can be used as such:
import torch
import torch.nn.quantized.dynamic as nnqd

# Assume you have a custom linear layer
class MyLinearLayer(torch.nn.Linear):
    pass

# Create a model with MyLinearLayer
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = MyLinearLayer(28*28, 100)
        
    def forward(self, x):
        return self.fc(x)
    
# Define the mapping to replace MyLinearLayer with dynamically quantized Linear
mapping = {
    MyLinearLayer: nnqd.Linear  # Map MyLinearLayer to torch.nn.quantized.dynamic.Linear
}
    
    
# Apply dynamic quantization
model = MyModel()
model_quantized = torch.quantization.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, mapping=mapping, inplace=False
)

print(model_quantized)

MyModel(
  (fc): MyLinearLayer(in_features=784, out_features=100, bias=True)
)


In [85]:
# Run the test on original model
print("Testing Original Model:")
test(net, 1000)

# Test dynamic quantized model
print("Testing Dynamic Quantized Model:")
test(dynamic_quantized_model, 1000)

# Test static quantized model
print("\nTesting Static Quantized Model:")
test(static_quantized_net, 1000)

Testing Original Model:


Testing: 100%|█████████▉| 999/1000 [00:01<00:00, 757.89it/s]


Accuracy: 0.954
Average Time per Batch: 0.000215 seconds
Testing Dynamic Quantized Model:


Testing: 100%|█████████▉| 999/1000 [00:01<00:00, 775.16it/s]


Accuracy: 0.953
Average Time per Batch: 0.00024 seconds

Testing Static Quantized Model:


Testing: 100%|█████████▉| 999/1000 [00:01<00:00, 784.83it/s]

Accuracy: 0.953
Average Time per Batch: 0.000247 seconds



