In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from Vision_transformer import CustomDataset
from Vision_transformer import VisionTransformer
import torchvision.transforms as transforms
from pprint import pprint
from torchsummary import summary
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

device = torch.device("cpu")
# We don't want to perform our quantization step on cuda GPU. It is not supported.
with open('config.json') as f:
    custom_config = json.load(f)
# Custom configurations for the VisionTransformer.
# Transformer can be customized with these configurations.
# Refer to documentation of the class VisionTransformer
# (`VisionTransformer.__doc__`, use pprint for cleaner display)
# for exact details of the customization.


In [3]:
# Load saved model
MNIST_ViT = VisionTransformer(**custom_config).to(device=device)
checkpoint = torch.load("model.pth")
MNIST_ViT.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [5]:
MNIST_ViT

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 32, kernel_size=(4, 4), stride=(4, 4))
  )
  (pos_drop): Dropout(p=0.2, inplace=False)
  (blocks): ModuleList(
    (0-1): 2 x Block(
      (norm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=32, out_features=96, bias=True)
        (attn_drop): Dropout(p=0.2, inplace=False)
        (proj): Linear(in_features=32, out_features=32, bias=True)
        (proj_drop): Dropout(p=0.2, inplace=False)
      )
      (norm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=32, out_features=12, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=12, out_features=32, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=32, out_features=10, bias=True)
)

In [4]:
# # Load saved model
# MNIST_ViT_quant = VisionTransformerForPTQ(**custom_config).to(device=device)
# checkpoint = torch.load("model.pth")
# MNIST_ViT_quant.load_state_dict(checkpoint['model_state_dict'])

In [5]:
# inp = torch.rand((1, 1, 28, 28)).to(device)
# MNIST_ViT_quant(inp)

In [6]:
# qconfig = torch.ao.quantization.get_default_qconfig('x86')
max_bit_length = 4
# net_quantized.qconfig = torch.ao.quantization.default_qconfig

qconfig = torch.quantization.QConfig(
    activation=torch.quantization.fake_quantize.FakeQuantize.with_args(observer = torch.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.quint8), quant_min = 0 ,quant_max=2**(max_bit_length)-1, dtype=torch.quint8), 
    weight=torch.quantization.fake_quantize.FakeQuantize.with_args(observer = torch.quantization.observer.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8), quant_min = 0 ,quant_max=2**(max_bit_length)-1, dtype=torch.qint8)
)

In [7]:
# MNIST_ViT_quant_fused = torch.ao.quantization.fuse_modules(MNIST_ViT_quant, [['linear', 'gelu'], ])


In [8]:
torch.ao.quantization.quantize_dynamic

<function torch.ao.quantization.quantize.quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False)>

In [9]:
# create a quantized model instance


model_int8 = torch.ao.quantization.quantize_dynamic(
    model = MNIST_ViT,
    qconfig_spec = {qconfig}
)



# model_int8 = torch.ao.quantization.quantize_dynamic(
#     MNIST_ViT,  # the original model
#     qconfig,
#     {nn.Linear, nn.Conv2d, nn.LayerNorm, nn.GELU, nn.Parameter},  # a set of layers to dynamically quantize
#     dtype=torch.qint8)  # the target dtype for quantized weights

# run the model
input_fp32 = torch.randn(1, 1, 28, 28)
res = model_int8(input_fp32)

In [10]:
# class CustomDataset(Dataset):
#     """Puts incoming MNIST dataset into an object 
#         which can be loaded onto cuda gpu.
#     Parameters
#     ----------
#     data : torchvision.datasets.mnist.MNIST

#     Attributes
#     ----------
#     X : torch.Tensor
#         Shape `(n_samples, n_channels, img_height, img_width)`
#     """
#     def __init__(self, data, device = device):
#         self.X = torch.cat([torch.unsqueeze(data[i][0], dim=0) for i in range(len(data))], dim=0).to(device)
#         self.Y = torch.tensor([data[i][1] for i in range(len(data))]).to(device)
    
#     def __len__(self):
#         """Length method.
#         Parameters
#         ----------
#         None
#         Returns
#         ----------
#         int
#             n_samples

#         """
#         return self.X.shape[0]
    
#     def __getitem__(self, idx):
#         """Indexing call.
#         Parameters:
#         idx : int
#             index of element to be returned.
        
#         Returns : 
#         torch.Tensor
#             Shape `(n_channels, img_height, img_width)`
#         torch.Tensor
#             Shape `(class_idx)`
#         """
#         return self.X[idx], self.Y[idx]


In [11]:
transform = transforms.Compose([
    transforms.ToTensor(),
])   # Transform object to apply on the dataset.

# train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Loading/Downloading dataset. `download` can be `False` if the data is present in the root directory
# Else it will download the dataset to to the root location.


In [12]:
test_ds = CustomDataset(data=test_dataset)
# Made custom dataset objects from the MNIST dataset.

test_loader = torch.utils.data.DataLoader(test_ds, batch_size=64, shuffle=False)
# DataLoaders for fast implementation of loading batch-wise data.


In [13]:
def test(model : VisionTransformer):
    correct, total = 0, 0
    model.eval()
    # Setting the model in evaluation mode.
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Loading batch images and ground truth onto device
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
    return f"Accuracy on test set: {(100 * correct / total):.2f}%"
            
test(MNIST_ViT)

'Accuracy on test set: 95.49%'

In [14]:
model_int8.eval()
# Setting the model in evaluation mode.
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        # Loading batch images and ground truth onto device
        outputs = model_int8(images)
        # Calculating logits.
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # Updated number of correct predictions and total predictions.

print(f"Accuracy on test set: {(100 * correct / total):.2f}%")

Accuracy on test set: 95.49%


In [15]:
torch.quantization.convert(model_int8, inplace=True)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 32, kernel_size=(4, 4), stride=(4, 4))
  )
  (pos_drop): Dropout(p=0.2, inplace=False)
  (blocks): ModuleList(
    (0-1): 2 x Block(
      (norm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=32, out_features=96, bias=True)
        (attn_drop): Dropout(p=0.2, inplace=False)
        (proj): Linear(in_features=32, out_features=32, bias=True)
        (proj_drop): Dropout(p=0.2, inplace=False)
      )
      (norm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=32, out_features=12, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=12, out_features=32, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=32, out_features=10, bias=True)
)

In [16]:
test(model_int8)

'Accuracy on test set: 95.49%'

In [19]:
for i, (name, param) in enumerate(MNIST_ViT.named_parameters()):
    print(i, name,param)

0 cls_token Parameter containing:
tensor([[[-0.0122, -0.1364, -0.0094, -0.0141,  0.0223,  0.0042,  0.0245,
          -0.0475, -0.0308, -0.0016,  0.0057,  0.0205,  0.1222,  0.0767,
          -0.0336,  0.0040, -0.0123, -0.0162, -0.4628,  0.0262,  0.0148,
          -0.0054, -0.1923,  0.0093,  0.0606, -0.0081,  0.0104, -0.0111,
           0.0501,  0.1046, -0.0188,  0.0386]]], requires_grad=True)
1 pos_embed Parameter containing:
tensor([[[-0.0122, -0.1364, -0.0094,  ...,  0.1046, -0.0188,  0.0386],
         [-0.1002,  0.2293,  0.1266,  ..., -0.2201,  0.0879,  0.2167],
         [ 0.1388,  0.2504,  0.1720,  ..., -0.0359,  0.2639,  0.1905],
         ...,
         [-0.3200,  0.4005,  0.3438,  ...,  0.2797, -0.2134,  0.1789],
         [ 0.1282,  0.2640,  0.3549,  ...,  0.2232,  0.1646,  0.0010],
         [-0.6103,  0.2312, -0.1634,  ...,  0.1081, -0.1291,  0.5354]]],
       requires_grad=True)
2 patch_embed.proj.weight Parameter containing:
tensor([[[[ 9.8508e-02, -1.6454e-01, -1.0587e-01,  5.7

In [18]:
for i, (name, param) in enumerate(model_int8.named_parameters()):
    print(i, name,param)

0 cls_token Parameter containing:
tensor([[[-0.0122, -0.1364, -0.0094, -0.0141,  0.0223,  0.0042,  0.0245,
          -0.0475, -0.0308, -0.0016,  0.0057,  0.0205,  0.1222,  0.0767,
          -0.0336,  0.0040, -0.0123, -0.0162, -0.4628,  0.0262,  0.0148,
          -0.0054, -0.1923,  0.0093,  0.0606, -0.0081,  0.0104, -0.0111,
           0.0501,  0.1046, -0.0188,  0.0386]]], requires_grad=True)
1 pos_embed Parameter containing:
tensor([[[-0.0122, -0.1364, -0.0094,  ...,  0.1046, -0.0188,  0.0386],
         [-0.1002,  0.2293,  0.1266,  ..., -0.2201,  0.0879,  0.2167],
         [ 0.1388,  0.2504,  0.1720,  ..., -0.0359,  0.2639,  0.1905],
         ...,
         [-0.3200,  0.4005,  0.3438,  ...,  0.2797, -0.2134,  0.1789],
         [ 0.1282,  0.2640,  0.3549,  ...,  0.2232,  0.1646,  0.0010],
         [-0.6103,  0.2312, -0.1634,  ...,  0.1081, -0.1291,  0.5354]]],
       requires_grad=True)
2 patch_embed.proj.weight Parameter containing:
tensor([[[[ 9.8508e-02, -1.6454e-01, -1.0587e-01,  5.7

In [39]:
torch.ao.quantization.fake_quantize.FakeQuantize(quant_min = 0, quant_max=2**7-1)

FakeQuantize(
  fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=127, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
  (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
)

In [93]:
torch.quantization.QuantStub(torch.ao.quantization.fake_quantize.FakeQuantize(quant_min = 0, quant_max=2**8-1))

QuantStub(
  (qconfig): FakeQuantize(
    fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  )
)

In [45]:
torch.quantization.QuantStub(torch.ao.quantization.observer.MinMaxObserver())

QuantStub(
  (qconfig): MinMaxObserver(min_val=inf, max_val=-inf)
)

In [47]:
torch.ao.quantization.default_qconfig

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})

In [87]:
max_bit_length = 8
custom_config = torch.ao.quantization.QConfig(
    activation=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(quant_max=2**max_bit_length-1, dtype=torch.quint8),  # Use HistogramObserver for activations
    weight=torch.ao.quantization.fake_quantize.FakeQuantize.with_args(quant_max=2**max_bit_length-1, dtype=torch.quint8)  # Keep the default observer for weights (can be changed too)
)

In [79]:
2**max_bit_length-1

255

In [77]:
torch.ao.quantization.default_observer

functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}

In [90]:
custom_config

QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, quant_max=255, dtype=torch.quint8){}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, quant_max=255, dtype=torch.quint8){})

In [86]:
custom_config

QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, quant_max=1237940039285380274899124223, dtype=torch.quint8){}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, quant_max=1237940039285380274899124223, dtype=torch.quint8){})