In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST


# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
torch.manual_seed(191009)

class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel,self).__init__()
        self.conv1=nn.Conv2d(1,32,3,1,1)
        self.relu1=nn.ReLU()
        self.conv2=nn.Conv2d(32,64,3,1,1)
        self.relu2=nn.ReLU()
        self.pool=nn.MaxPool2d(kernel_size=2,stride=2)
        self.dropout=nn.Dropout(0.5)
        self.flatten=nn.Flatten()
        self.fc1=nn.Linear(in_features=12544,out_features=128)
        self.relu3=nn.ReLU()
        self.fc2=nn.Linear(in_features=128,out_features=10)
        self.log_softmax=nn.LogSoftmax(dim=1)

    def forward(self,x):
        x=self.conv1(x)
        x=self.relu1(x)
        x=self.conv2(x)
        x=self.relu2(x)
        x=self.pool(x)
        x=self.dropout(x)
        x=self.flatten(x)
        x=self.fc1(x)
        x=self.relu3(x)
        x=self.fc2(x)
        output=self.log_softmax(x)
        return output


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def fmevaluate(model, criterion, data_loader, neval_batches):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end ="")
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                 return top1, top5

    return top1, top5

def qmevaluate(model, criterion, data_loader, neval_batches):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            output = F.log_softmax(output, dim=1)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end ="")
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                 return top1, top5

    return top1, top5

def load_model(model_file):
    model = BaseModel()
    state_dict = torch.load(model_file, weights_only=True)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

In [2]:
train_batch_size = 30
eval_batch_size = 50

def prepare_data_loaders():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)


    train_sampler = torch.utils.data.RandomSampler(train_dataset)
    test_sampler = torch.utils.data.SequentialSampler(test_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size,
        sampler=train_sampler)
    
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=eval_batch_size,
        sampler=test_sampler)

    return train_loader, test_loader

train_loader, test_loader = prepare_data_loaders()
criterion = nn.CrossEntropyLoss()

In [3]:
float_model_file = 'base_model.pth'
scripted_float_model_file = 'float_model_scripted.pth'
scripted_quantized_model_file = 'quantized_model_scripted.pth'

float_model = load_model(float_model_file).to('cpu')

  return self.fget.__get__(instance, owner)()


In [4]:
float_model.eval()
float_model_fused = torch.ao.quantization.fuse_modules(float_model, [['conv1', 'relu1'], ['conv2', 'relu2'], ['fc1', 'relu3']])

In [5]:
float_model_file = 'base_model.pth'

num_eval_batches = 1000

print("Size of baseline model")
print_size_of_model(float_model)

top1, top5 = fmevaluate(float_model, criterion, test_loader, neval_batches=num_eval_batches)
print('\nEvaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(float_model), scripted_float_model_file)

Size of baseline model
Size (MB): 6.506268
........................................................................................................................................................................................................
Evaluation accuracy on 50000 images, 96.98


In [6]:
class QuantizedModel(nn.Module):
    def __init__(self, model):
        super(QuantizedModel, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.model = model
        self.dequant = torch.quantization.DeQuantStub()
    
    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

# Assuming per_channel_quantized_model is already loaded
per_channel_quantized_model = load_model(float_model_file)

per_channel_quantized_model.log_softmax = None
def modified_forward(self,x):
        x=self.conv1(x)
        x=self.relu1(x)
        x=self.conv2(x)
        x=self.relu2(x)
        x=self.pool(x)
        x=self.dropout(x)
        x=self.flatten(x)
        x=self.fc1(x)
        x=self.relu3(x)
        x=self.fc2(x)
        return x
per_channel_quantized_model.forward = modified_forward.__get__(per_channel_quantized_model, BaseModel)

In [7]:
# Fuse the modules in the original model
fused_model = torch.ao.quantization.fuse_modules(per_channel_quantized_model, [['conv1', 'relu1'], ['conv2', 'relu2'], ['fc1', 'relu3']])

# Create a modified model with the fused original model
quantized_model = QuantizedModel(fused_model)

# Set the model to evaluation mode
quantized_model.eval()

# Set the quantization configuration
quantized_model.qconfig = torch.ao.quantization.get_default_qconfig('x86')
# Prepare the model for quantization
torch.ao.quantization.prepare(quantized_model, inplace=True)

# Calibrate the model with representative data
for data, _ in train_loader:
    quantized_model(data)

# Convert the model to a quantized version
torch.ao.quantization.convert(quantized_model, inplace=True)

# Print the modified model architecture
print(quantized_model)



QuantizedModel(
  (quant): Quantize(scale=tensor([0.0255]), zero_point=tensor([17]), dtype=torch.quint8)
  (model): BaseModel(
    (conv1): QuantizedConvReLU2d(1, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.036530353128910065, zero_point=0, padding=(1, 1))
    (relu1): Identity()
    (conv2): QuantizedConvReLU2d(32, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.04108048230409622, zero_point=0, padding=(1, 1))
    (relu2): Identity()
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (dropout): QuantizedDropout(p=0.5, inplace=False)
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (fc1): QuantizedLinearReLU(in_features=12544, out_features=128, scale=0.18267779052257538, zero_point=0, qscheme=torch.per_channel_affine)
    (relu3): Identity()
    (fc2): QuantizedLinear(in_features=128, out_features=10, scale=0.2601340711116791, zero_point=61, qscheme=torch.per_channel_affine)
    (log_softmax): None
  )
  (dequant): DeQuantize()
)


In [9]:
print("Size of quantized model")
print_size_of_model(quantized_model)
top1, top5 = qmevaluate(quantized_model, criterion, test_loader, neval_batches=num_eval_batches)
print('\nEvaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))

Size of quantized model
Size (MB): 1.637678
........................................................................................................................................................................................................
Evaluation accuracy on 50000 images, 96.96


In [10]:
torch.jit.save(torch.jit.script(quantized_model), scripted_quantized_model_file)

## Compare Model Sizes

In [11]:
print("Original Model Size :")
print_size_of_model(float_model)
print("Quantized Model Size :")
print_size_of_model(quantized_model)

Original Model Size :
Size (MB): 6.506268
Quantized Model Size :
Size (MB): 1.637678


## Compare Model Accuracy

In [12]:
print("Original Model Accuracy :")
top1, top5 = fmevaluate(float_model, criterion, test_loader, neval_batches=num_eval_batches)
print('\nEvaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
print("\n\n\n")
print("Quantized Model Accuracy :")
top1, top5 = qmevaluate(quantized_model, criterion, test_loader, neval_batches=num_eval_batches)
print('\nEvaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))

Original Model Accuracy :
........................................................................................................................................................................................................
Evaluation accuracy on 50000 images, 96.98




Quantized Model Accuracy :
........................................................................................................................................................................................................
Evaluation accuracy on 50000 images, 96.96


## Compare Model Speed

In [16]:
def measure_inference_time(model, input_data, num_iterations=100):
    # Set the model to evaluation mode
    model.eval()
    
    # Warm up the model (optional but recommended for more accurate timing)
    with torch.no_grad():
        for _ in range(10):
            _ = model(input_data)
    
    # Measure the inference time
    start_time = time.time()
    with torch.no_grad():
        for _ in range(num_iterations):
            _ = model(input_data)
    end_time = time.time()
    
    avg_inference_time = (end_time - start_time) / num_iterations
    return avg_inference_time

# Assuming float_model and modified_model are already defined and loaded
# Create some dummy input data
input_data = torch.randn(1, 1, 28, 28)  # Adjust the shape according to your model's input

# Measure inference time for the float model
float_model_inference_time = measure_inference_time(float_model, input_data)
print(f"Float model average inference time: {float_model_inference_time:.6f} seconds")

# Measure inference time for the quantized model
quantized_model_inference_time = measure_inference_time(quantized_model, input_data)
print(f"Quantized model average inference time: {quantized_model_inference_time:.6f} seconds")

Float model average inference time: 0.000971 seconds
Quantized model average inference time: 0.000443 seconds


#### We see that the Quantized Model offers a significant speed boost, consumes less memory, and doesn't compromise on accuracy. This follows the general case with quantization where we can maintain efficiency of a model while significantly reducing memory size requirements and inference time