In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from Vision_transformer import VisionTransformer, VisionTransformerForPTQ, CustomDataset
import torchvision.transforms as transforms
from pprint import pprint
from torchsummary import summary
import json
import numpy as np
import os

In [None]:
# Scratch lookup: evaluating this bare attribute only makes the notebook display the
# `quantize_dynamic` function object; nothing is called or assigned here.
torch.ao.quantization.quantize_dynamic

In [None]:

device = torch.device("cpu")
# We don't want to perform our quantization step on cuda GPU. It is not supported.

# Custom configurations for the VisionTransformer.
# Transformer can be customized with these configurations.
# Refer to documentation of the class VisionTransformer
# (`VisionTransformer.__doc__`, use pprint for cleaner display)
# for exact details of the customization.
#
# The file is decoded explicitly as UTF-8 so that parsing does not depend on the
# platform's default locale encoding.
with open('config.json', encoding='utf-8') as f:
    custom_config = json.load(f)


In [None]:
# Load saved model
MNIST_ViT = VisionTransformer(**custom_config).to(device=device)
# `map_location` remaps every stored tensor onto `device`, so a checkpoint that was
# saved from a CUDA run still loads on a machine without a GPU.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load("model.pth", map_location=device)
MNIST_ViT.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Transform object to apply on the dataset: images are only converted to tensors.
transform = transforms.Compose([transforms.ToTensor()])

# train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Loading/Downloading the MNIST test split. `download` can be `False` if the data is
# already present in the root directory; otherwise it is downloaded to the root location.
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Custom dataset object built from the MNIST test split.
test_ds = CustomDataset(data=test_dataset, device=device)

# DataLoader for batch-wise loading (batches of 64, order kept fixed).
test_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=64,
    shuffle=False,
)



In [None]:
def test(model : VisionTransformer):
    correct, total = 0, 0
    model.eval()
    # Setting the model in evaluation mode.
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Loading batch images and ground truth onto device
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
    return f"Accuracy on test set: {(100 * correct / total):.2f}%"
            


In [None]:
# Baseline: evaluate the un-quantized model; the accuracy string returned by
# `test` is shown as this cell's output.
test(MNIST_ViT)

In [None]:
# Weights matrix of the model before quantization: heading, the classifier
# head's weight tensor, then its dtype -- one item per output line.
print(
    'Weights before quantization',
    MNIST_ViT.head.weight,
    MNIST_ViT.head.weight.dtype,
    sep='\n',
)

In [None]:
def print_size_of_model(model):
    """Print the serialized size of ``model``'s state dict in kilobytes.

    The state dict is written with ``torch.save`` to a scratch file
    (``temp_delme.p`` in the current working directory), the file size is
    printed as ``Size (KB): <bytes / 1000>``, and the scratch file is removed.

    Args:
        model: Object exposing ``state_dict()`` whose result ``torch.save``
            can serialize.
    """
    scratch_path = "temp_delme.p"
    try:
        torch.save(model.state_dict(), scratch_path)
        print('Size (KB):', os.path.getsize(scratch_path)/1e3)
    finally:
        # Clean up even when saving or measuring fails, so an interrupted run
        # does not leave a stale file behind.
        if os.path.exists(scratch_path):
            os.remove(scratch_path)

# Report the serialized size of the original model's state dict as the baseline
# for the post-quantization size check.
print('Size of the model before quantization')
print_size_of_model(MNIST_ViT)

In [None]:
# Heading first; the accuracy string returned by `test` is then displayed as the
# cell's result.
print('Accuracy of the model before quantization: ')
test(MNIST_ViT)

In [None]:
# Build the PTQ variant of the transformer from the same configuration, place it
# on `device`, and copy in the trained weights from the loaded checkpoint.
net_quantized = VisionTransformerForPTQ(**custom_config)
net_quantized = net_quantized.to(device=device)
net_quantized.load_state_dict(checkpoint['model_state_dict'])

In [None]:
net_quantized.eval()
max_bit_length = 4
# Alternative (stock configuration):
# net_quantized.qconfig = torch.ao.quantization.default_qconfig

# Both activations and weights are fake-quantized to the integer range
# [0, 2**max_bit_length - 1], with ranges tracked by a moving-average min/max
# observer. Activations use quint8 and weights use qint8 as the carrier dtype.
# The `torch.ao.quantization` namespace is used throughout, matching `prepare`.
net_quantized.qconfig = torch.ao.quantization.QConfig(
    activation=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(
        observer=torch.ao.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.quint8),
        quant_min=0,
        quant_max=2**(max_bit_length)-1,
        dtype=torch.quint8,
    ),
    weight=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(
        observer=torch.ao.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8),
        quant_min=0,
        quant_max=2**(max_bit_length)-1,
        dtype=torch.qint8,
    ),
)

# net_quantized.qconfig = torch.ao.quantization.QConfig(
#     activation=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(observer = torch.ao.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.quint8), quant_min =-2**(max_bit_length-1) ,quant_max=2**(max_bit_length-1)-1, dtype=torch.quint8), 
#     weight=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(observer = torch.ao.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.quint8), quant_min =-2**(max_bit_length-1) ,quant_max=2**(max_bit_length-1)-1, dtype=torch.quint8)
# )

In [None]:
# Display the quantization configuration currently attached to the model.
net_quantized.qconfig

In [None]:
# torch.ao.quantization.QConfig(
#     activation=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(quant_min =-2**(max_bit_length-1) ,quant_max=2**(max_bit_length-1)-1, dtype=torch.qint8), 
#     weight=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(quant_min =-2**(max_bit_length-1) ,quant_max=2**(max_bit_length-1)-1, dtype=torch.qint8)
# )

In [None]:
# NOTE: this rebinds `net_quantized` to the prepared model, so re-running the cell
# would pass an already-prepared model to `prepare` again; re-create the model
# first if the cell has to be repeated.
net_quantized = torch.ao.quantization.prepare(net_quantized) # Insert observers
# Display the prepared module structure.
net_quantized

In [None]:
# Run the test set through the prepared model and show the resulting accuracy.
# NOTE(review): presumably this doubles as the calibration pass that lets the
# inserted observers record value ranges before conversion -- confirm.
test(net_quantized)

In [None]:
# Show the module tree again after the evaluation pass above.
print('Check statistics of the various layers')
net_quantized

In [None]:
# Move the model to `device`; as the last expression, the call's return value is
# also displayed as the cell output.
net_quantized.to(device)

In [None]:
# Convert the prepared (observed) model into its quantized form. Uses the
# `torch.ao.quantization` namespace, consistent with the `prepare` call.
net_quantized = torch.ao.quantization.convert(net_quantized)

In [None]:
# Show the module tree of the converted model.
print('Check statistics of the various layers')
net_quantized

In [None]:
# Print the weights matrix of the model after quantization
print('Weights after quantization')
print(net_quantized.head)

# The module repr alone does not show the weight values, so print the weight
# tensor and its dtype as well, mirroring the pre-quantization cell.
# Quantized layers expose the weight through a `weight()` method rather than a
# plain attribute, hence the callable check.
head_weight = getattr(net_quantized.head, 'weight', None)
if callable(head_weight):
    head_weight = head_weight()
if head_weight is not None:
    print(head_weight)
    print(head_weight.dtype)

In [None]:
# Report the serialized state-dict size of the converted model, for comparison
# with the size printed before quantization.
print('Size of the model after quantization')
print_size_of_model(net_quantized)

In [None]:
# Evaluate the converted model on the same test loader; the accuracy string
# returned by `test` is displayed as the cell output.
print('Testing the model after quantization')
test(net_quantized)

In [None]:
# Record the installed PyTorch version alongside the results.
torch.__version__

In [None]:
# Index, name and shape of every parameter registered on the original model.
for idx, (param_name, tensor) in enumerate(MNIST_ViT.named_parameters()):
    print(idx, param_name, tensor.shape)

In [None]:
# Index, name and shape of every parameter still registered on the converted model.
for idx, (param_name, tensor) in enumerate(net_quantized.named_parameters()):
    print(idx, param_name, tensor.shape)