In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from transformer.Vision_transformer import VisionTransformer, CustomDataset
import torchvision.transforms as transforms
from pprint import pprint
from torchsummary import summary
import json
import numpy as np
import os


from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import QConfigMapping

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

device = torch.device("cpu")
print(f"Using device : {device}")
# We don't want to perform our quantization step on cuda GPU. It is not supported.
with open('transformer/config.json') as f:
    custom_config = json.load(f)
# Custom configurations for the VisionTransformer.
# Transformer can be customized with these configurations.
# Refer to documentation of the class VisionTransformer
# (`VisionTransformer.__doc__`, use pprint for cleaner display)
# for exact details of the customization.


Using device : cpu


In [4]:
# Load saved model
float_model = VisionTransformer(**custom_config).to(device=device)

# float_model = VisionTransformer(
#     img_size=32,
#     patch_size=8,
#     in_chans=3,
#     n_classes=10,
#     embed_dim=128,
#     depth=2,
#     n_heads=2,
#     mlp_ratio=4.,
#     p=0.3,
#     attn_p=0.3
# ).to(device=device)
checkpoint = torch.load("models/model.pth")
float_model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
])   # Transform object to apply on the dataset.

# train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Loading/Downloading dataset. `download` can be `False` if the data is present in the root directory
# Else it will download the dataset to to the root location.

train_ds = CustomDataset(data=train_dataset)
test_ds = CustomDataset(data=test_dataset, device=device)
# Made custom dataset objects from the MNIST dataset.

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=64, shuffle=False)
# DataLoaders for fast implementation of loading batch-wise data.



In [5]:
# transform = transforms.Compose([
#     transforms.ToTensor(),  # Convert PIL Image to tensor
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the image tensors
# ])

# # Load CIFAR-10 training dataset
# train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# # Load CIFAR-10 test dataset
# test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# # Classes in CIFAR-10 dataset
# classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# train_ds = CustomDataset(data=train_dataset)
# test_ds = CustomDataset(data=test_dataset)
# # Made custom dataset objects from the MNIST dataset.

# train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2048, shuffle=True)
# test_loader = torch.utils.data.DataLoader(test_ds, batch_size=2048, shuffle=False)
# # DataLoaders for fast implementation of loading batch-wise data.


In [6]:
# # qconfig = get_default_qconfig("x86")

# max_bit_length = 4
# qconfig = torch.quantization.QConfig(
#     activation=torch.quantization.fake_quantize.FakeQuantize.with_args(observer = torch.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.quint8), quant_min = 0 ,quant_max=2**(max_bit_length)-1, dtype=torch.quint8), 
#     weight=torch.quantization.fake_quantize.FakeQuantize.with_args(observer = torch.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8), quant_min = 0 ,quant_max=2**(max_bit_length)-1, dtype=torch.qint8)
# )
# qconfig_mapping = QConfigMapping().set_global(qconfig)

# def calibrate(model, data_loader):
#     model.eval()
#     with torch.no_grad():
#         for image, _ in data_loader:
#             image = image.to(device)
#             model(image)


# example_inputs = (next(iter(test_loader))[0]) # get an example input
# prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs=example_inputs ) 

# calibrate(prepared_model, train_loader)
# quantized_model = convert_fx(prepared_model)

In [6]:
# qconfig = get_default_qconfig("x86")

max_bit_length = 4
qconfig = torch.quantization.QConfig(
    activation=torch.quantization.fake_quantize.FakeQuantize.with_args(observer = torch.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.quint8), quant_min = 0 ,quant_max=2**(max_bit_length)-1, dtype=torch.quint8), 
    weight=torch.quantization.fake_quantize.FakeQuantize.with_args(observer = torch.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8), quant_min = 0 ,quant_max=2**(max_bit_length)-1, dtype=torch.qint8)
)
qconfig_mapping = QConfigMapping().set_global(qconfig)

In [7]:
qconfig_mapping

QConfigMapping (
 global_qconfig
  QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, observer=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, dtype=torch.quint8){}, quant_min=0, quant_max=15, dtype=torch.quint8){}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, observer=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, dtype=torch.qint8){}, quant_min=0, quant_max=15, dtype=torch.qint8){})
 object_type_qconfigs
  OrderedDict()
 module_name_regex_qconfigs
  OrderedDict()
 module_name_qconfigs
  OrderedDict()
 module_name_object_type_order_qconfigs
  OrderedDict()
)

In [8]:
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, _ in data_loader:
            image = image.to(device)
            model(image)

In [9]:
example_inputs = (next(iter(test_loader))[0]) # get an example input
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs=example_inputs ) 

In [10]:
calibrate(prepared_model, train_loader)

In [11]:
quantized_model = convert_fx(prepared_model)

In [19]:
for name, param in quantized_model.named_parameters():
    print(name, param)

cls_token Parameter containing:
tensor([[[-0.0036,  0.0393,  0.0284,  0.0412,  0.0239, -0.0122,  0.0125,
           0.0209,  0.0172, -0.0591, -0.0249, -0.0518,  0.0008, -0.0724,
           0.0681,  0.0243, -0.2622,  0.6902,  0.0107, -0.0901, -0.1576,
          -0.0022,  0.0790, -0.0144,  0.0019, -0.0107,  0.0219, -0.0109,
           0.0872, -0.0216,  0.0366,  0.0085]]], requires_grad=True)
pos_embed Parameter containing:
tensor([[[-0.0036,  0.0393,  0.0284,  ..., -0.0216,  0.0366,  0.0085],
         [ 0.8553, -0.0934,  0.0290,  ...,  0.0155, -0.0228, -0.2379],
         [ 0.8468, -0.0868, -0.0270,  ...,  0.0681, -0.1890, -0.3481],
         ...,
         [-0.1675,  0.1178,  0.2676,  ...,  0.4275, -0.9559,  0.0071],
         [ 0.0372,  0.4384, -0.0244,  ..., -0.0046, -0.1864, -0.0974],
         [ 0.5809, -0.0199,  0.1729,  ...,  0.0371, -0.1984, -0.2354]]],
       requires_grad=True)
blocks.0.norm1.weight Parameter containing:
tensor([0.6856, 0.9604, 0.8827, 0.6854, 1.0185, 0.8176, 1.0376

In [13]:
def test(model):
    correct, total = 0, 0
    model.eval()
    # Setting the model in evaluation mode.
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Loading batch images and ground truth onto device
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
    return f"Accuracy on test set: {(100 * correct / total):.2f}%"

In [14]:
test(float_model)

'Accuracy on test set: 95.30%'

In [15]:
test(quantized_model)

'Accuracy on test set: 70.79%'

In [28]:
for name, module in prepared_model.named_modules():
    if not hasattr(module, 'qconfig'):
        print(f"{name} is not quantized.")
    # else:
    #     print(f"{name} is not quantized")

 is not quantized.
activation_post_process_0 is not quantized.
activation_post_process_0.activation_post_process is not quantized.
patch_embed is not quantized.
activation_post_process_1 is not quantized.
activation_post_process_1.activation_post_process is not quantized.
activation_post_process_2 is not quantized.
activation_post_process_2.activation_post_process is not quantized.
activation_post_process_3 is not quantized.
activation_post_process_3.activation_post_process is not quantized.
activation_post_process_6 is not quantized.
activation_post_process_6.activation_post_process is not quantized.
activation_post_process_7 is not quantized.
activation_post_process_7.activation_post_process is not quantized.
activation_post_process_8 is not quantized.
activation_post_process_8.activation_post_process is not quantized.
blocks is not quantized.
blocks.0 is not quantized.
blocks.0.attn is not quantized.
blocks.0.mlp is not quantized.
blocks.1 is not quantized.
blocks.1.attn is not quan

In [23]:
float_model.state_dict()['pos_embed'] == quantized_model.state_dict()['pos_embed'] 

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])

In [13]:

# Function to extract scale and zero point values from the prepared model
def extract_quantization_params(model):
    quantization_params = {}
    for name, module in model.named_modules():
        if hasattr(module, 'activation_post_process'):
            observer = module.activation_post_process
            if isinstance(observer, torch.quantization.observer.MovingAverageMinMaxObserver):
                scale, zero_point = observer.calculate_qparams()
                quantization_params[name] = {'scale': scale.item(), 'zero_point': zero_point.item()}
    return quantization_params

# Extract quantization parameters
quantization_params = extract_quantization_params(prepared_model)

# Print the scale and zero point values for each observer
for name, params in quantization_params.items():
    print(f"Layer: {name}")
    print(f"Scale: {params['scale']}")
    print(f"Zero Point: {params['zero_point']}")

Layer: activation_post_process_0
Scale: 0.06666667014360428
Zero Point: 0
Layer: activation_post_process_1
Scale: 0.3135177195072174
Zero Point: 7
Layer: activation_post_process_2
Scale: 0.31371864676475525
Zero Point: 7
Layer: activation_post_process_3
Scale: 0.18223203718662262
Zero Point: 7
Layer: activation_post_process_6
Scale: 0.16426581144332886
Zero Point: 7
Layer: activation_post_process_7
Scale: 0.27967262268066406
Zero Point: 7
Layer: activation_post_process_8
Scale: 0.26970183849334717
Zero Point: 7
Layer: activation_post_process_9
Scale: 0.43550142645835876
Zero Point: 7
Layer: activation_post_process_10
Scale: 0.7412218451499939
Zero Point: 7
Layer: activation_post_process_13
Scale: 0.7402080297470093
Zero Point: 7
Layer: activation_post_process_15
Scale: 3.5287368297576904
Zero Point: 7
Layer: activation_post_process_17
Scale: 0.06555232405662537
Zero Point: 0
Layer: activation_post_process_18
Scale: 0.06545073539018631
Zero Point: 0
Layer: activation_post_process_19
Sca