In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from Vision_transformer import CustomDataset
from Vision_transformer import VisionTransformer
import torchvision.transforms as transforms
from pprint import pprint
from torchsummary import summary
import json

In [None]:

# Quantization is not supported on a CUDA GPU here, so every model and tensor
# in this notebook stays on the CPU.
device = torch.device("cpu")

# Keyword arguments used to build the network as `VisionTransformer(**custom_config)`.
# See the class documentation (`pprint(VisionTransformer.__doc__)`) for what
# each key customizes.
with open('config.json') as config_file:
    custom_config = json.load(config_file)


In [None]:
# Rebuild the architecture from the same config and load the trained weights.
MNIST_ViT = VisionTransformer(**custom_config).to(device=device)
# map_location remaps tensors that were saved from a GPU onto `device`, so the
# checkpoint also loads on a CPU-only machine (this notebook is CPU-only).
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load("model.pth", map_location=device)
MNIST_ViT.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Display the module tree of the loaded float model.
MNIST_ViT

In [None]:
# # Load saved model
# MNIST_ViT_quant = VisionTransformerForPTQ(**custom_config).to(device=device)
# checkpoint = torch.load("model.pth")
# MNIST_ViT_quant.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# inp = torch.rand((1, 1, 28, 28)).to(device)
# MNIST_ViT_quant(inp)

In [None]:
# Simulated low-bit (fake-quant) QConfig with `max_bit_length` bits for both
# activations and weights.
max_bit_length = 4

qconfig = torch.quantization.QConfig(
    # Activations: unsigned quint8, affine, range [0, 2**b - 1].
    activation=torch.quantization.fake_quantize.FakeQuantize.with_args(
        observer=torch.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.quint8),
        quant_min=0,
        quant_max=2**max_bit_length - 1,
        dtype=torch.quint8,
    ),
    # Weights: qint8 is signed, so the b-bit range is centred on zero,
    # [-2**(b-1), 2**(b-1) - 1]. The unsigned range [0, 2**b - 1] would use only
    # the non-negative half of the dtype. A symmetric scheme keeps the weight
    # zero point at 0.
    weight=torch.quantization.fake_quantize.FakeQuantize.with_args(
        observer=torch.quantization.observer.MovingAverageMinMaxObserver.with_args(
            dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
        ),
        quant_min=-(2 ** (max_bit_length - 1)),
        quant_max=2 ** (max_bit_length - 1) - 1,
        dtype=torch.qint8,
    ),
)

In [None]:
# MNIST_ViT_quant_fused = torch.ao.quantization.fuse_modules(MNIST_ViT_quant, [['linear', 'gelu'], ])


In [None]:
# Leftover inspection cell: displays the `quantize_dynamic` function object itself.
torch.ao.quantization.quantize_dynamic

In [None]:
# Create a dynamically quantized copy of the float model (the original is untouched).
# `qconfig_spec` must MAP module types (or submodule names) to a QConfig.
# A *set* is read as "these module types/names, with the default dynamic
# qconfig". The previous `{qconfig}` therefore keyed a QConfig object that
# matches no module, so nothing was quantized and `model_int8` was just a
# float copy.
model_int8 = torch.ao.quantization.quantize_dynamic(
    model=MNIST_ViT,
    qconfig_spec={nn.Linear: qconfig},
    dtype=torch.qint8,
)
# NOTE(review): the dynamic quantized Linear presumably uses only the weight
# part of `qconfig` (activations are quantized on the fly) and stores weights
# as int8. The low-bit range is therefore only approximated -- confirm this
# matches the intended experiment.

# Smoke test: one random 1x28x28 input through the quantized model.
input_fp32 = torch.randn(1, 1, 28, 28)
res = model_int8(input_fp32)

In [None]:
# class CustomDataset(Dataset):
#     """Puts incoming MNIST dataset into an object 
#         which can be loaded onto cuda gpu.
#     Parameters
#     ----------
#     data : torchvision.datasets.mnist.MNIST

#     Attributes
#     ----------
#     X : torch.Tensor
#         Shape `(n_samples, n_channels, img_height, img_width)`
#     """
#     def __init__(self, data, device = device):
#         self.X = torch.cat([torch.unsqueeze(data[i][0], dim=0) for i in range(len(data))], dim=0).to(device)
#         self.Y = torch.tensor([data[i][1] for i in range(len(data))]).to(device)
    
#     def __len__(self):
#         """Length method.
#         Parameters
#         ----------
#         None
#         Returns
#         ----------
#         int
#             n_samples

#         """
#         return self.X.shape[0]
    
#     def __getitem__(self, idx):
#         """Indexing call.
#         Parameters:
#         idx : int
#             index of element to be returned.
        
#         Returns : 
#         torch.Tensor
#             Shape `(n_channels, img_height, img_width)`
#         torch.Tensor
#             Shape `(class_idx)`
#         """
#         return self.X[idx], self.Y[idx]


In [None]:
# Transform applied to every sample: convert each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Only the MNIST test split is loaded here (no training split is needed).
# With download=True the dataset is fetched into `root` unless it is already
# present; set download=False when './data' already holds MNIST.
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)


In [None]:
# Wrap the MNIST test split in the project's CustomDataset ...
test_ds = CustomDataset(data=test_dataset)

# ... and serve it in batches of 64, always in the same order.
test_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=64,
    shuffle=False,
)


In [None]:
def test(model : VisionTransformer):
    correct, total = 0, 0
    model.eval()
    # Setting the model in evaluation mode.
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Loading batch images and ground truth onto device
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
    return f"Accuracy on test set: {(100 * correct / total):.2f}%"
            
# Baseline: test-set accuracy of the float (unquantized) model.
test(MNIST_ViT)

In [None]:
# Test-set accuracy of the dynamically quantized model. Use the same `test`
# routine as the float baseline, instead of a copy of its loop, so the two
# numbers are computed identically. `test` also puts the model in eval mode.
print(test(model_int8))

In [None]:
# NOTE(review): `convert` swaps modules that still carry a `qconfig` for their
# quantized counterparts, using the default *static* mapping.
# `quantize_dynamic` has presumably already converted (and stripped the
# qconfig from) `model_int8`, so this call is expected to change nothing --
# confirm, and consider removing it.
torch.quantization.convert(model_int8, inplace=True)

In [None]:
# Re-evaluate the quantized model after the in-place convert() call.
test(model_int8)

In [None]:
# Print every registered parameter of the float model: running index,
# dotted parameter name, and the tensor itself.
for idx, named_param in enumerate(MNIST_ViT.named_parameters()):
    param_name, tensor = named_param
    print(idx, param_name, tensor)

In [None]:
# Same parameter dump for the quantized copy.
# NOTE(review): dynamically quantized Linear layers presumably keep their
# weights in packed form, which named_parameters() does not yield. Use
# model_int8.state_dict() to inspect them -- confirm.
for idx, named_param in enumerate(model_int8.named_parameters()):
    param_name, tensor = named_param
    print(idx, param_name, tensor)

In [None]:
# Inspect a stand-alone FakeQuantize restricted to the 7-bit range [0, 127].
torch.ao.quantization.fake_quantize.FakeQuantize(
    quant_min=0,
    quant_max=(1 << 7) - 1,
)

In [None]:
# QuantStub takes a QConfig made of observer / fake-quant *factories*, not a
# module instance. Passing a FakeQuantize object merely registers it as a
# child module named `qconfig`, and a later prepare step presumably fails on
# it. Wrap the 8-bit fake-quant factory in a QConfig instead.
quant_stub_fake_quant = torch.quantization.QuantStub(
    torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(
            quant_min=0, quant_max=2**8 - 1
        ),
        weight=torch.ao.quantization.fake_quantize.default_weight_fake_quant,
    )
)
quant_stub_fake_quant

In [None]:
# As above: give QuantStub a QConfig whose activation entry is the
# MinMaxObserver *class*, rather than an already-built observer instance.
quant_stub_minmax = torch.quantization.QuantStub(
    torch.ao.quantization.QConfig(
        activation=torch.ao.quantization.observer.MinMaxObserver,
        weight=torch.ao.quantization.observer.default_weight_observer,
    )
)
quant_stub_minmax

In [None]:
# Display PyTorch's built-in default QConfig for comparison with the custom ones.
torch.ao.quantization.default_qconfig

In [None]:
# 8-bit fake-quant QConfig.
# NOTE(review): this cell rebinds `max_bit_length` (set to 4 earlier) and
# `custom_config`, which earlier held the VisionTransformer keyword arguments
# read from config.json. After running this cell, `VisionTransformer(**custom_config)`
# in the model-loading cell would fail. Consider a distinct name such as
# `custom_qconfig`.
max_bit_length = 8
custom_config = torch.ao.quantization.QConfig(
    # Activations: unsigned quint8, affine, range [0, 2**b - 1]. quant_min is
    # passed too because FakeQuantize only forwards the range to its observer
    # when both bounds are given.
    activation=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(
        observer=torch.ao.quantization.observer.MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=2**max_bit_length - 1,
        dtype=torch.quint8,
    ),
    # Weights: signed qint8 with the symmetric range [-2**(b-1), 2**(b-1) - 1].
    weight=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(
        observer=torch.ao.quantization.observer.MovingAverageMinMaxObserver,
        quant_min=-(2 ** (max_bit_length - 1)),
        quant_max=2 ** (max_bit_length - 1) - 1,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
    ),
)

In [None]:
# Largest unsigned code representable with `max_bit_length` bits.
(1 << max_bit_length) - 1

In [None]:
# Display PyTorch's default observer factory for comparison.
torch.ao.quantization.default_observer

In [None]:
# Display the QConfig most recently bound to `custom_config`.
custom_config

In [None]:
# NOTE(review): duplicate of the previous display cell; safe to delete.
custom_config