In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from transformer.Vision_transformer import VisionTransformer, CustomDataset
import torchvision.transforms as transforms
from pprint import pprint
from torchsummary import summary
import json
import numpy as np
import os


from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import QConfigMapping

In [None]:

# Quantization in this notebook is run on the CPU; we do not perform the
# quantization step on a CUDA GPU because it is not supported there.
device = torch.device("cpu")
print(f"Using device : {device}")

# Custom configuration for the VisionTransformer, read from a JSON file.
# The transformer can be customized through these settings; refer to the class
# documentation (`VisionTransformer.__doc__`, use pprint for cleaner display)
# for the exact meaning of each option.
# The encoding is given explicitly so the file is decoded the same way on every
# platform instead of depending on the locale default.
with open('transformer/config.json', encoding='utf-8') as config_file:
    custom_config = json.load(config_file)


In [None]:
# Build the float (un-quantized) model from the JSON configuration and restore
# its trained weights from the saved checkpoint.
float_model = VisionTransformer(**custom_config).to(device=device)

# For reference, an explicit construction without the config file looks like the
# block below (NOTE(review): whether these values match config.json cannot be
# verified from this notebook).
# float_model = VisionTransformer(
#     img_size=32,
#     patch_size=8,
#     in_chans=3,
#     n_classes=10,
#     embed_dim=128,
#     depth=2,
#     n_heads=2,
#     mlp_ratio=4.,
#     p=0.3,
#     attn_p=0.3
# ).to(device=device)

# `map_location` remaps every stored tensor onto `device`, so a checkpoint that
# was saved from a GPU run still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (consider `weights_only=True` if the checkpoint allows it).
checkpoint = torch.load("fashionMNIST.pth", map_location=device)
float_model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Settings shared by the dataset and loader definitions below.
DATA_ROOT = './data'
BATCH_SIZE = 64
SHUFFLE_SEED = 0

# Transform applied to every sample of the dataset.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load the FashionMNIST splits. `download` can be `False` if the data is already
# present in the root directory; otherwise the dataset is downloaded to the
# root location.
# Alternative dataset:
# train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_dataset = torchvision.datasets.FashionMNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.FashionMNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Wrap the FashionMNIST splits in the project's CustomDataset.
train_ds = CustomDataset(data=train_dataset)
test_ds = CustomDataset(data=test_dataset, device=device)

# DataLoaders for batch-wise iteration. The training loader is later used for
# calibration, so its shuffling is driven by a seeded generator to make the
# batch order (and therefore the collected statistics) repeatable across runs.
shuffle_generator = torch.Generator().manual_seed(SHUFFLE_SEED)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, generator=shuffle_generator
)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)



In [None]:
# transform = transforms.Compose([
#     transforms.ToTensor(),  # Convert PIL Image to tensor
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the image tensors
# ])

# # Load CIFAR-10 training dataset
# train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# # Load CIFAR-10 test dataset
# test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# # Classes in CIFAR-10 dataset
# classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# train_ds = CustomDataset(data=train_dataset)
# test_ds = CustomDataset(data=test_dataset)
# # Made custom dataset objects from the CIFAR-10 dataset.

# train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2048, shuffle=True)
# test_loader = torch.utils.data.DataLoader(test_ds, batch_size=2048, shuffle=False)
# # DataLoaders for fast implementation of loading batch-wise data.


In [None]:
# Alternative: use the stock backend configuration instead of the custom one.
# qconfig = get_default_qconfig("x86")

# Number of bits emulated by the fake-quantize modules. The integer range is
# restricted to [0, 2**max_bit_length - 1] while the declared dtypes stay 8-bit.
max_bit_length = 4
quant_range_max = 2 ** max_bit_length - 1

# Activations: moving-average min/max observer, unsigned 8-bit dtype.
activation_fake_quant = torch.ao.quantization.fake_quantize.FakeQuantize.with_args(
    observer=torch.ao.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.quint8),
    quant_min=0,
    quant_max=quant_range_max,
    dtype=torch.quint8,
)

# Weights: same observer type and integer range, signed 8-bit dtype.
weight_fake_quant = torch.ao.quantization.fake_quantize.FakeQuantize.with_args(
    observer=torch.ao.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8),
    quant_min=0,
    quant_max=quant_range_max,
    dtype=torch.qint8,
)

qconfig = torch.ao.quantization.QConfig(
    activation=activation_fake_quant,
    weight=weight_fake_quant,
)

# Apply the same qconfig to every module via the global entry of the mapping.
qconfig_mapping = QConfigMapping().set_global(qconfig)

In [None]:
# Display the QConfigMapping built above to inspect the configured global qconfig.
qconfig_mapping

In [None]:
def calibrate(model, data_loader):
    """Feed every batch of `data_loader` through `model` without tracking gradients.

    The model is switched to evaluation mode first. Each item yielded by the
    loader is unpacked into an input batch and a second element that is ignored;
    the inputs are moved to the notebook-level `device` before the forward pass.
    The forward outputs are discarded and nothing is returned.
    """
    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            inputs, _unused = batch
            model(inputs.to(device))

In [None]:
# `prepare_fx` expects `example_inputs` to be a tuple of the positional
# arguments of forward(); wrap the single image batch in a one-element tuple
# (plain parentheses without a trailing comma do not create a tuple).
example_inputs = (next(iter(test_loader))[0],)

# Put the float model in eval mode before it is traced, as the prepare_fx
# documentation does for post-training quantization, so no training-mode
# behaviour is captured in the prepared graph.
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs=example_inputs)

In [None]:
# Run the prepared model over every batch of the training loader (forward passes only).
calibrate(prepared_model, train_loader)

In [None]:
# Convert the calibrated, prepared model with convert_fx to obtain the quantized model.
quantized_model = convert_fx(prepared_model)

In [None]:
# Display the converted model to inspect its structure.
quantized_model

In [None]:
def test(model):
    correct, total = 0, 0
    model.eval()
    # Setting the model in evaluation mode.
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Loading batch images and ground truth onto device
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
    return f"Accuracy on test set: {(100 * correct / total):.2f}%"

In [None]:
# Baseline: accuracy of the original float model on the test loader.
test(float_model)

In [None]:
# Accuracy of the quantized model on the same test loader, for comparison with the float baseline.
test(quantized_model)