# Lab4: Quantize DeiT

### Setup

In [1]:
# install the newest version of torch, torchvision, and timm
#!pip3 uninstall --yes torch torchaudio torchvision torchtext torchdata timm
#!pip3 install torch torchaudio torchvision torchtext torchdata timm

In [2]:
import copy
import math
import os
import tempfile
import time

import numpy as np
import torch
from torch import nn
from tqdm.auto import tqdm

from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import DataLoader

from torch.export import export, ExportedProgram
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# To force CPU execution, uncomment:
# device = torch.device("cpu")
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda:0


In [3]:
def data_loader_to_list(data_loader, length=128):
    """Materialise the first `length` batches of `data_loader` as a list.

    A negative `length` means "take everything". When the loader is longer than
    `length`, one extra item is pulled from it and discarded (the same
    consumption as a for/enumerate loop that breaks at index `length`).
    """
    if length < 0:
        return list(data_loader)
    # zip pulls from data_loader before range, so item `length` is fetched and
    # dropped when range runs out, matching the break-based loop exactly.
    return [batch for batch, _ in zip(data_loader, range(length))]

def build_dataset_CIFAR100(is_train, data_path):
    """Return ``(dataset, num_classes)`` for the CIFAR-100 train or test split.

    The split is downloaded into `data_path` if it is not already there, and is
    wrapped with the matching preprocessing from ``build_transform(is_train)``.
    """
    num_classes = 100
    split_transform = build_transform(is_train)
    cifar = datasets.CIFAR100(
        data_path,
        train=is_train,
        transform=split_transform,
        download=True,
    )
    return cifar, num_classes

def build_transform(is_train, input_size=224, eval_crop_ratio=1.0):
    """Build the image preprocessing pipeline for CIFAR-100 at DeiT resolution.

    Training: timm's ``create_transform`` with RandAugment
    ('rand-m9-mstd0.5-inc1'), colour jitter 0.3, bicubic interpolation and no
    random erasing (``re_prob=0.0``). For tiny inputs (``input_size <= 32``) the
    first (random-resized-crop) step is replaced by ``RandomCrop(padding=4)``.

    Evaluation: resize to ``int(input_size / eval_crop_ratio)`` with bicubic
    interpolation, centre-crop to ``input_size`` (both skipped when
    ``input_size <= 32``), then ``ToTensor`` and ImageNet mean/std normalisation.

    Args:
        is_train: build the training (augmenting) pipeline when True.
        input_size: output square resolution (default 224).
        eval_crop_ratio: centre-crop ratio for evaluation; 1.0 means no extra
            margin is resized in before cropping.

    Returns:
        A callable transform (a ``transforms.Compose`` for evaluation).
    """
    resize_im = input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=input_size,
            is_training=True,
            color_jitter=0.3,
            auto_augment='rand-m9-mstd0.5-inc1',
            interpolation='bicubic',
            re_prob=0.0,
            re_mode='pixel',
            re_count=1,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int(input_size / eval_crop_ratio)
        t.append(
            # BICUBIC is the named equivalent of the previous magic value 3 (PIL.Image.BICUBIC).
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),  # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)

def prepare_data(batch_size):
    """Return ``(train_loader, test_loader, nb_classes)`` for CIFAR-100 under ./data.

    The train loader shuffles and the test loader does not; both drop an
    incomplete final batch so every batch has exactly `batch_size` items.
    """
    train_set, nb_classes = build_dataset_CIFAR100(is_train=True, data_path='./data')
    test_set, _ = build_dataset_CIFAR100(is_train=False, data_path='./data')

    common = dict(batch_size=batch_size, drop_last=True)
    train_loader = DataLoader(train_set, shuffle=True, **common)
    test_loader = DataLoader(test_set, shuffle=False, **common)
    return train_loader, test_loader, nb_classes


In [4]:
def calibrate(model: nn.Module, data_loader, num_images: int = 128, device=None) -> None:
    """Feed calibration batches through `model` so its observers record activation ranges.

    Takes the first ``ceil(num_images / data_loader.batch_size)`` batches of
    `data_loader` (roughly `num_images` images, rounded up to whole batches) and
    runs them through `model` under ``torch.no_grad()``. Outputs are discarded.

    Args:
        model: a prepared (observer-instrumented) model.
        data_loader: yields ``(image, target)`` batches and exposes ``batch_size``.
        num_images: how many images to calibrate on (default 128).
        device: if given, each batch is moved there first; otherwise batches are
            used as-is and must already be on the model's device.
    """
    num_batches = math.ceil(num_images / data_loader.batch_size)
    calibration_data = data_loader_to_list(data_loader, num_batches)
    # Observers only need forward activations; skip autograd bookkeeping.
    with torch.no_grad():
        for image, _ in calibration_data:
            if device is not None:
                image = image.to(device)
            model(image)

def train_one_epoch(model, criterion, optimizer, data_loader, device):
    """Run one optimisation pass of `model` over `data_loader`.

    Each ``(image, target)`` batch is moved to `device`, scored with
    `criterion`, and followed by one ``optimizer`` step. The model's train/eval
    mode is not changed here; the caller is responsible for it.

    Returns:
        Mean training loss over the epoch as a Python float (0.0 for an empty loader).
    """
    loss_sum = None
    num_batches = 0
    for image, target in tqdm(data_loader):
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate on-device; calling .item() every step would force a sync.
        batch_loss = loss.detach()
        loss_sum = batch_loss if loss_sum is None else loss_sum + batch_loss
        num_batches += 1

    if num_batches == 0:
        return 0.0
    return (loss_sum / num_batches).item()

In [5]:
def get_size_of_model(model):
    """Return the size in MB (1e6 bytes) of ``model.state_dict()`` saved with ``torch.save``.

    The state dict is written to a private temporary directory, so an existing
    ``temp.p`` in the working directory is never overwritten. The directory is
    always removed, even if saving or measuring fails. The file keeps the name
    ``temp.p`` to mirror the previous measurement exactly (torch.save's zip
    archive may embed the file stem).
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), tmp_path)
        return os.path.getsize(tmp_path) / 1e6

def evaluate_model(model, data_loader, device):
    """Compute and print the top-1 accuracy (percent) of `model` on `data_loader`.

    `model` is moved to `device` via ``model.to(device)``. Every
    ``(images, labels)`` batch is then scored under ``torch.no_grad()``, and a
    prediction counts as correct when the arg-max class equals the label.
    """
    model.to(device)
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(data_loader):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            _, predictions = torch.max(model(batch_images), 1)
            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()

    accuracy = 100 * n_correct / n_seen
    print(f'Accuracy of the model on the test images: {accuracy}%')
    return accuracy

def getMiniTestDataset(images_per_class=5, num_classes=100):
    """Collect a small, class-balanced subset of the CIFAR-100 test split.

    Walks the unshuffled test loader (batch size 1) and keeps the first
    `images_per_class` images seen for each class. It stops as soon as every
    class has its quota.

    Args:
        images_per_class: images to keep per class (default 5).
        num_classes: number of label classes (default 100 for CIFAR-100).

    Returns:
        A flat list of ``(image, label)`` pairs grouped by class index. Each
        ``image`` is a batch of one and each ``label`` is the loader's label
        tensor, so the list can be passed straight to ``evaluate_model``.
    """
    if images_per_class <= 0:
        return []

    # Create a test_loader with batch size = 1
    _, test_loader, _ = prepare_data(batch_size=1)

    class_images = [[] for _ in range(num_classes)]
    # Count classes that still need images instead of re-scanning every bucket.
    classes_remaining = num_classes
    for image, label in test_loader:
        bucket = class_images[int(label)]
        if len(bucket) < images_per_class:
            bucket.append((image, label))
            if len(bucket) == images_per_class:
                classes_remaining -= 1
                if classes_remaining == 0:
                    break  # every class has images_per_class images

    # Flatten per-class buckets, preserving class order.
    return [pair for bucket in class_images for pair in bucket]

# The TA uses the following code to evaluate your score.
def lab4_cifar100_evaluation(quantized_model_path='deits_quantized.pth'):
    """Score a saved quantized model the way the TA grades it.

    Loads the ``torch.export`` program at `quantized_model_path` and evaluates it
    on CPU over the class-balanced mini test set from ``getMiniTestDataset()``.
    It times that evaluation and measures the serialized state-dict size with
    ``get_size_of_model`` (MB = 1e6 bytes). It prints all three values and
    computes the score:

    * +10 if size <= 30 MB, plus 2 * floor(27 - size) more if size <= 27 MB;
    * if accuracy >= 86%: +10 + 2 * floor(accuracy - 86).

    Execution time is printed but does not contribute to the score.

    Args:
        quantized_model_path: path to a file written by ``torch.export.save``.

    Returns:
        The integer score (also printed).
    """
    # Prepare data
    mini_test_dataset = getMiniTestDataset()

    # Load quantized model
    quantized_ep = torch.export.load(quantized_model_path)
    # .module() gives a callable module that evaluate_model can run.
    quantized_model = quantized_ep.module()

    # Evaluate model (wall-clock time covers only the evaluation loop)
    start_time = time.time()
    acc = evaluate_model(quantized_model, mini_test_dataset, device="cpu")
    exec_time = time.time() - start_time
    model_size = get_size_of_model(quantized_model)

    print(f"Model Size: {model_size:.2f} MB")
    print(f"Accuracy: {acc:.2f}%")
    print(f"Execution Time: {exec_time:.2f} s")

    # Size and accuracy bonuses are independent; exec_time is not scored.
    score = 0
    if model_size <= 30: score += 10
    if model_size <= 27: score += 2 * math.floor(27-model_size)
    if acc >= 86:
      score += 10 + 2 * math.floor(acc-86)
    print(f'Model Score: {score:.2f}')
    return score

## Part1: Simple Quantization Pipeline (0%)

Below is a naive pipeline for quantizing DeiT-S. You may need to modify this pipeline or build your own later on.

[Setting **use_reference_representation=False** in **convert_pt2e()** produces a fake-quantized model: tensors pass through quantize/dequantize ops, but the matmuls themselves still run in fp32.](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model)

However, when this flag is set to True, execution becomes extremely slow.

In this lab, it is fine to keep **use_reference_representation=False**.

In [6]:
# Replacements for the annotation filter functions that XNNPACKQuantizer uses, extended with an ignore list.

def get_module_names(name):
    """Return every dotted suffix of `name`, shortest first.

    ``"a.b.c"`` -> ``["c", "b.c", "a.b.c"]``; a name without dots yields ``[name]``.
    """
    parts = name.split(".")
    suffixes = []
    for start in range(len(parts) - 1, -1, -1):
        suffixes.append(".".join(parts[start:]))
    return suffixes

def parse_string(name: str) -> str:
    """Convert an ``nn_module_stack`` path string into a dotted module name.

    Two input shapes are handled (presumably the forms the tracer records in
    ``nn_module_stack`` -- confirm against your torch version):

    * ``"L['self'].<path>"`` -> ``"<path>"``: the 10-character ``L['self'].``
      prefix is dropped, e.g. ``"L['self'].head"`` -> ``"head"``.
    * ``"getattr(L['self'].<left>, '<idx>')<rest>"`` -> ``"<left>.<idx><rest>"``,
      e.g. ``"getattr(L['self'].blocks, '2').attn.qkv"`` -> ``"blocks.2.attn.qkv"``.

    Only the first two ``)``-separated pieces are used; anything after a second
    ``)`` is silently dropped. A string that neither starts with ``L`` nor
    contains ``)`` raises IndexError.
    """
    # Any name starting with "L" is assumed to carry the "L['self']." prefix.
    if name.startswith("L"): return name[10:]
    split_getattr = name.split(")")
    # Text between "L['self']." and the first comma: the container attribute.
    ig_left = split_getattr[0].split("L['self'].")[-1].split(",")[0]
    # Text after ", '" with the trailing quote stripped: the index/key.
    ig_right = split_getattr[0].split(", '")[-1][:-1]
    return ig_left + "." + str(ig_right) + split_getattr[1]

def is_name_in_ignore_list(name, IGNORE_LIST):
    """Return True when `name` is an exact member of IGNORE_LIST (no suffix matching)."""
    found = name in IGNORE_LIST
    return found

def name_not_in_ignore_list(n, IGNORE_LIST) -> bool:
    """Return False when FX node `n` belongs to an ignored module, True otherwise.

    The innermost entry of ``n.meta["nn_module_stack"]`` is normalised with
    ``parse_string`` and expanded into all its dotted suffixes. The node is
    ignored if any suffix appears in IGNORE_LIST. Nodes with no module-stack
    metadata are never ignored.
    """
    module_paths = [path for path, _cls in n.meta.get("nn_module_stack", {}).values()]
    if not module_paths:
        return True

    candidate_names = set(get_module_names(parse_string(module_paths[-1])))
    return candidate_names.isdisjoint(set(IGNORE_LIST))

def get_module_name_filter(module_name: str, IGNORE_LIST):
    """Build a node predicate selecting nodes inside module `module_name`.

    The returned predicate is True when some dotted suffix of the node's
    innermost module path equals `module_name` and the node is not in an ignored
    module. Nodes without module-stack metadata are rejected.
    """
    def module_name_filter(n) -> bool:
        module_paths = [path for path, _cls in n.meta.get("nn_module_stack", {}).values()]
        if not module_paths:
            return False

        suffixes = get_module_names(parse_string(module_paths[-1]))
        if module_name not in suffixes:
            return False
        return name_not_in_ignore_list(n, IGNORE_LIST)

    return module_name_filter


def get_module_type_filter(tp, IGNORE_LIST):
    """Build a node predicate selecting nodes traced inside any module of type `tp`.

    Any level of the node's ``nn_module_stack`` may match. Nodes that fall in an
    ignored module are always rejected.
    """
    def module_type_filter(n) -> bool:
        module_types = [klass for _path, klass in n.meta.get("nn_module_stack", {}).values()]
        if tp not in module_types:
            return False
        return name_not_in_ignore_list(n, IGNORE_LIST)

    return module_type_filter


def get_not_module_type_or_name_filter(
    tp_list, module_name_list, IGNORE_LIST
):
    """Build the predicate used for the quantizer's global config.

    A node is selected when no module-type filter (for `tp_list`) and no
    module-name filter (for `module_name_list`) claims it, and it does not
    belong to a module in IGNORE_LIST.
    """
    # Both factories take IGNORE_LIST as a required argument; omitting it made
    # this raise TypeError whenever any type/name-specific config was set.
    specific_filters = (
        [get_module_type_filter(tp, IGNORE_LIST) for tp in tp_list]
        + [get_module_name_filter(m, IGNORE_LIST) for m in module_name_list]
    )

    def not_module_type_or_name_filter(n) -> bool:
        return not any(f(n) for f in specific_filters) and name_not_in_ignore_list(n, IGNORE_LIST)

    return not_module_type_or_name_filter

class PartialXNNPACKQuantizer(XNNPACKQuantizer):
    """XNNPACKQuantizer that leaves nodes belonging to ignored modules unannotated.

    Both annotation hooks are overridden so that every filter passed to the
    pattern annotators also rejects nodes whose innermost module path matches an
    entry of ``ignore_list`` (see ``name_not_in_ignore_list``).
    """

    def __init__(self, ignore_list):
        super().__init__()
        # Dotted module names to keep out of quantization.
        self.ignore_list = ignore_list

    def _annotate_with_ignore_list(self, model, annotate_patterns):
        """Apply `annotate_patterns` per module-name config, per module-type config, then globally."""
        configured_names = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            annotate_patterns(model, config, get_module_name_filter(module_name, self.ignore_list))

        configured_types = list(self.module_type_config.keys())
        for module_type, config in self.module_type_config.items():
            annotate_patterns(model, config, get_module_type_filter(module_type, self.ignore_list))

        # The global config covers whatever no name/type-specific config claimed.
        annotate_patterns(
            model,
            self.global_config,
            get_not_module_type_or_name_filter(configured_types, configured_names, self.ignore_list),
        )
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        print("annotating for static quantization")
        return self._annotate_with_ignore_list(model, self._annotate_all_static_patterns)

    def _annotate_for_dynamic_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        print("annotating for dynamic quantization")
        return self._annotate_with_ignore_list(model, self._annotate_all_dynamic_patterns)

# quantizer = XNNPACKQuantizer()
# quantizer = PartialXNNPACKQuantizer(ignore_list=["head"]) # replace XNNPACKQuantizer()

In [7]:
from torch.quantization import get_default_qconfig, prepare, convert

def partial_quantize_ptq_model(quantizer, model: nn.Module, device, data_loader, per_channel=False,
                               num_calibration_batches=128) -> torch.fx.GraphModule:
    """Post-training-quantize `model` with `quantizer` using the PT2E flow.

    Steps:
      1. Capture `model` (moved to `device`, in eval mode) as an FX graph, using
         the first batch of `data_loader` as the example input.
      2. Install a symmetric PTQ config (optionally per-channel) as the
         quantizer's global config. Note that this overwrites any global config
         already set on `quantizer`.
      3. Insert observers with ``prepare_pt2e``.
      4. Calibrate on the first `num_calibration_batches` batches of `data_loader`.
      5. Convert with ``use_reference_representation=False`` (Q/DQ, fp32 compute).

    Returns:
        The converted, quantized GraphModule.
    """
    example_inputs = (next(iter(data_loader))[0].to(device),)
    model = model.to(device)
    model.eval()
    model = capture_pre_autograd_graph(model, example_inputs)

    quantization_config = get_symmetric_quantization_config(is_per_channel=per_channel, is_qat=False)
    quantizer.set_global(quantization_config)
    # prepare_pt2e folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model.
    model = prepare_pt2e(model, quantizer)

    # Calibration only needs forward passes so observers see activation ranges.
    # zip with range stops before fetching an extra, unused batch.
    with torch.no_grad():
        for _batch_idx, (image, _label) in zip(range(num_calibration_batches), data_loader):
            model(image.to(device))

    return convert_pt2e(model, use_reference_representation=False)

In [8]:
def prepare_qat_model(model: nn.Module, quantizer, device, example_inputs=None,
                      per_channel=False) -> torch.fx.GraphModule:
    """Capture `model` and insert QAT fake-quantize nodes with `quantizer`.

    Args:
        model: float model to fine-tune with quantization-aware training.
        quantizer: PT2E quantizer (e.g. PartialXNNPACKQuantizer). Its global
            config is set to a symmetric QAT config, mirroring
            ``partial_quantize_ptq_model``. Without any global config the
            quantizer annotates nothing, so QAT would be a no-op.
        device: device that the model and example input are placed on for capture.
        example_inputs: optional tuple of tracing inputs; defaults to one random
            1x3x224x224 image on `device`.
        per_channel: use per-channel weight quantization.

    Returns:
        The prepared GraphModule, ready for training.
    """
    ############### YOUR CODE STARTS HERE ###############

    # Step 1. program capture (model and example input on the same device)
    # NOTE(review): a recorded run failed here with a cpu/cuda FakeTensor mismatch
    # inside timm's _pos_embed; if it recurs, try capturing on CPU and moving the
    # prepared model to the GPU afterwards -- unverified.
    model = model.to(device)
    if example_inputs is None:
        example_inputs = (torch.randn(1, 3, 224, 224, device=device),)
    model = capture_pre_autograd_graph(model, example_inputs)
    # Step 2. set quantizer config (is_qat=True selects fake-quant observers)
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=per_channel, is_qat=True))
    # Step 3. prepare qat pt2e
    model = prepare_qat_pt2e(model, quantizer)

    ############### YOUR CODE ENDS HERE #################

    return model

In [9]:
# Modules to keep in floating point. Each entry is compared with every dotted
# suffix of a node's innermost module path.
# NOTE(review): "cls.token" probably matches nothing -- timm ViTs usually expose the
# class token as a parameter named `cls_token`, not as a submodule; confirm.
# Earlier candidate: ["blocks.9.mlp.fc1", "blocks.2.attn.qkv", "blocks.10.mlp.fc1"]
ignore_list = [
    "cls.token",
    "blocks.2.attn.qkv",
]
quantizer = PartialXNNPACKQuantizer(ignore_list)  # replaces a plain XNNPACKQuantizer()


In [10]:
# batch_size = 1
# model = torch.load('0.9099_deit3_small_patch16_224.pth', map_location='cpu')
# model = model.to(device)
# train_loader, test_loader, nb_classes = prepare_data(batch_size)

# simple_test_loader = data_loader_to_list(test_loader)

# mini_test_dataset = getMiniTestDataset()
# # evaluate before quantization
# print('Before quantization:')
# print('Device:', device)
# acc = evaluate_model(model, mini_test_dataset, device)
# # acc = evaluate_model(model, test_loader, device) # acc: 90.99%
# print('Size (MB) before quantization:', get_size_of_model(model))
# print(f'Accuracy of the model on the test images: {acc}%') # 92.8%
# # evaluate_model(model, test_loader, device) # acc: 90.99%

# # quantize model
# print('Quantizing model...')
# model.cpu()
# quantized_model = partial_quantize_ptq_model(quantizer, model, device, train_loader, per_channel=False)
# torch.ao.quantization.move_exported_model_to_eval(quantized_model)

# print('Exporting model...')
# quantized_model_path = "deits_quantized_1.pth"

# quantized_model.cpu()
# cpu_example_inputs = (torch.randn([1, 3, 224, 224]), ) # batch_size should equal to 1 on inference.
# quantized_ep = torch.export.export(quantized_model, cpu_example_inputs)
# torch.export.save(quantized_ep, quantized_model_path)

# print('Evaluating model...')
# lab4_cifar100_evaluation(quantized_model_path) # 84.4%

In [11]:

batch_size = 1
# The checkpoint is a fully pickled nn.Module, so it must be unpickled with
# weights_only=False (newer torch versions default to True). Unpickling can run
# arbitrary code: only load checkpoints from a trusted source.
model = torch.load('0.9099_deit3_small_patch16_224.pth', map_location='cpu', weights_only=False)
model = model.to(device)
train_loader, test_loader, nb_classes = prepare_data(batch_size)

############### YOUR CODE STARTS HERE ###############
num_epochs = 3
learning_rate = 0.001
############### YOUR CODE ENDS HERE #################

# With num_epochs = 3, neither freeze below ever triggers (epochs run 0..2).
num_observer_update_epochs = 4
num_batch_norm_update_epochs = 3
num_epochs_between_evals = 2

criterion = nn.CrossEntropyLoss()

prepared_model = prepare_qat_model(model, quantizer, device)
prepared_model = prepared_model.to(device)

# Build the optimizer over the prepared graph's parameters: that graph is the
# module actually trained below.
optimizer = torch.optim.Adam(prepared_model.parameters(), lr=learning_rate)

# Captured graphs are switched with move_exported_model_to_train/eval rather
# than .train()/.eval(). This is done once here; calling it every epoch could
# presumably revert the manual BN freeze below.
torch.ao.quantization.move_exported_model_to_train(prepared_model)

# QAT takes time and one needs to train over a few epochs.
for epoch in range(num_epochs):
    train_one_epoch(prepared_model, criterion, optimizer, train_loader, device)

    # Optionally disable observer/batchnorm stats after certain number of epochs
    if epoch >= num_observer_update_epochs:
        print("Disabling observer for subseq epochs, epoch = ", epoch)
        prepared_model.apply(torch.ao.quantization.disable_observer)
    if epoch >= num_batch_norm_update_epochs:
        print("Freezing BN for subseq epochs, epoch = ", epoch)
        for n in prepared_model.graph.nodes:
            # Args: input, weight, bias, running_mean, running_var, training, momentum, eps
            # set the `training` flag to False here to freeze BN stats
            if n.target in [
                torch.ops.aten._native_batch_norm_legit.default,
                torch.ops.aten.cudnn_batch_norm.default,
            ]:
                new_args = list(n.args)
                new_args[5] = False
                n.args = new_args
        prepared_model.recompile()

    # Check the quantized accuracy every N epochs on a converted copy, leaving
    # prepared_model itself trainable.
    if (epoch + 1) % num_epochs_between_evals == 0:
        quantized_snapshot = convert_pt2e(copy.deepcopy(prepared_model))
        torch.ao.quantization.move_exported_model_to_eval(quantized_snapshot)
        evaluate_model(quantized_snapshot, test_loader, device)

# Convert the model as of the final epoch rather than reusing the last periodic
# snapshot, which would miss later epochs and is undefined if no eval ran.
quantized_model = convert_pt2e(prepared_model)
torch.ao.quantization.move_exported_model_to_eval(quantized_model)

print('Exporting model...')
quantized_model_path = "deits_quantized_1.pth"

quantized_model.cpu()
cpu_example_inputs = (torch.randn([1, 3, 224, 224]), ) # batch_size should equal to 1 on inference.
quantized_ep = torch.export.export(quantized_model, cpu_example_inputs)
torch.export.save(quantized_ep, quantized_model_path)

print('Evaluating model...')
lab4_cifar100_evaluation(quantized_model_path)

Files already downloaded and verified
Files already downloaded and verified


TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., size=(1, 196, 384), grad_fn=<TransposeBackward0>), Parameter(FakeTensor(..., device='cuda:0', size=(1, 196, 384), requires_grad=True))), **{}):
Unhandled FakeTensor Device Propagation for aten.add.Tensor, found two different devices cpu, cuda:0

from user code:
   File "/home/aa35037123/miniconda3/envs/lab4/lib/python3.10/site-packages/timm/models/vision_transformer.py", line 704, in forward
    x = self.forward_features(x)
  File "/home/aa35037123/miniconda3/envs/lab4/lib/python3.10/site-packages/timm/models/vision_transformer.py", line 682, in forward_features
    x = self._pos_embed(x)
  File "/home/aa35037123/miniconda3/envs/lab4/lib/python3.10/site-packages/timm/models/vision_transformer.py", line 619, in _pos_embed
    x = x + pos_embed

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
