# Quantization Aware Training
* [dataset](https://www.kaggle.com/datasets/justin900429/3d-printer-defected-dataset)

In [None]:
# Use my own version of timm to support quantized option
!pip install git+https://github.com/Justin900429/pytorch-image-models.git
!pip install gdown==4.4.0
!gdown 1Fq0DkvzoB3wI6a8IgPeYplD01c-WmXvn -O tmp.zip && unzip -q tmp.zip && rm tmp.zip
!gdown 1-5w5uWGPVL43a2xkFHRI6WY5KPak_ZHQ

In [1]:
import os
import copy
import glob
import random

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.utils import make_grid
from torchvision import transforms
from torchvision.datasets import ImageFolder

import timm
from tqdm import tqdm

from PIL import Image
import cv2
import matplotlib.pyplot as plt
from matplotlib import cm 

from sklearn.metrics import f1_score, accuracy_score

## Create training dataset

In [2]:
random.seed(0)

# Set up training dataset
train_no_defect = [
    file for file in glob.glob("no_defected/*.jpg") if "scratch_2" not in file]
train_yes_defect = [
    file for file in glob.glob("defected/*.jpg") if "no_bottom" not in file
]
train_yes_defect = random.choices(train_yes_defect, k=len(train_no_defect))

# Set up validation dataset
val_no_defect = [
    file for file in glob.glob("no_defected/*.jpg") if "scratch_2" in file]
val_yes_defect = [
    file for file in glob.glob("defected/*.jpg") if "no_bottom" in file]

## Create Dataset

In [3]:
class ListDataset(Dataset):
    def __init__(self, yes_defect, no_defect, transform=None):
        self.img_list = yes_defect + no_defect
        self.label = [1] * len(yes_defect) + [0] * len(no_defect)
        self.transform = transform

    def __len__(self):
        return len(self.img_list)
    
    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx])
        label = self.label[idx]
        img = self.transform(img)
        return img, label

In [4]:
def make_loader(yes_defect, no_defect, transform, batch_size,
                shuffle=True, num_workers=2, pin_memory=True,
                train=True):
    dataset = ListDataset(
        yes_defect=yes_defect, no_defect=no_defect, transform=transform)
    loader = DataLoader(
        dataset, batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=pin_memory)
    
    return loader

In [5]:
@torch.no_grad()
def evaluate(model, val_loader, device):
    model.eval()

    total_predict = []
    total_ground_truth = []
    for data, label in val_loader:
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        prediction = output.argmax(dim=-1)

        total_predict.extend(prediction.cpu().tolist())
        total_ground_truth.extend(label.cpu().tolist())

    return accuracy_score(total_ground_truth, total_predict), \
           f1_score(total_ground_truth, total_predict, average="macro")

In [6]:
def quantize_train(model, train_loader, val_loader, criterion, optimizer, args):
    best_f1 = 0
    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)
    for epoch in range(args.epochs):
        train_progress_bar = tqdm(
            train_loader, desc=f"Epochs: {epoch + 1}/{args.epochs}")
        
        model.train()
        for data, label in train_progress_bar:
            data = data.to(args.device)
            label = label.to(args.device)

            # Send data into the model and compute the loss
            output = model(data)
            loss = criterion(output, label)

            # Update the model with back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Obtain the quantization model for evaluation
        quantized_model = torch.ao.quantization.convert(
            copy.deepcopy(model).cpu().eval(), inplace=False)

        # Compute the accuracy ans save the best model
        eval_acc, eval_f1 = evaluate(
            quantized_model, val_loader, "cpu")
        print(f"Validation accuracy: {eval_acc:.8f} f1-score: {eval_f1:.8f}")
        if eval_f1 > best_f1:
            best_f1 = eval_f1
            torch.save(model.state_dict(), "best.pt")

In [7]:
def get_quantized_model_from_weight(model, weight="best.pt"):
    new_model = copy.deepcopy(model).cpu().eval()
    new_model.load_state_dict(torch.load(weight, map_location="cpu"))
    quantized_model = torch.quantization.convert(new_model, inplace=False)
    return quantized_model

## Create Model

In [8]:
class QuantizeTrainModel(nn.Module):
    def __init__(self, model_name="resnet34", pretrained=True, num_classes=2):
        super().__init__()
        
        # Model settings
        self.model_name = model_name
        self.pretrained=pretrained
        self.num_classes = num_classes
        
        # Floating point -> Integer for input
        self.quant = torch.ao.quantization.QuantStub()

        # Check out the doc: https://rwightman.github.io/pytorch-image-models/
        #  for different models
        self.model = timm.create_model(
            model_name, pretrained=pretrained, 
            block_args={"use_quantized": True})
        
        # Change the output linear layers to fit the output classes
        self.model.fc = nn.Linear(
            self.model.fc.weight.shape[1],
            num_classes
        )

        # Integer to Floating point for output
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        return self.dequant(x)

    def clone(self):
        clone = QuantizeTrainModel(self.model_name, self.pretrained, self.num_classes)
        clone.load_state_dict(self.state_dict())
        if self.is_cuda():
            clone.cuda()
        return clone
    
    def fused_module_inplace(self):
        """
        Fusing the model for resnet family only. Print out the model architecture
        to know how the logic works. Note that during the quantize fusion, 
        conv + bn + act or conv + bn should be bundled to together
        """
        self.train()
        
        for module_name, module in self.named_children():
            # This coding style should not be correct but the code can be
            #  more readable
            if "model" not in module_name:
                continue
            
            torch.ao.quantization.fuse_modules_qat(
                module, [["conv1", "bn1", "act1"]], inplace=True
            )
            for basic_block_name, basic_block in module.named_children():
                # Same as above reason :)
                if "layer" not in basic_block_name:
                    continue

                for sub_block_name, sub_block in basic_block.named_children():
                    torch.ao.quantization.fuse_modules_qat(
                        sub_block, 
                        [["conv1", "bn1", "act1"], ["conv2", "bn2", "act2"]],
                        inplace=True
                    )
                    for sub_sub_block_name, sub_sub_block in sub_block.named_children():
                        if sub_block_name == "downsample":
                            torch.ao.quantization.fuse_modules_qat(
                                sub_block, [["0", "1"]], inplace=True
                            )

In [9]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "tmp.pt")
    print("Size (MB):", os.path.getsize("tmp.pt") / 1e6)
    os.remove("tmp.pt")

## Quantized awared training

In [10]:
class args:
    # Training
    epochs = 30
    batch_size = 32
    lr = 3e-4
    weight_decay=1e-5
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Transform
    size = 400
    crop_size = 352
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

In [11]:
# Set up train loader and test loader
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((args.size, args.size)),
    transforms.CenterCrop((args.crop_size, args.crop_size)),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=args.mean, std=args.std)
])
val_transform = transforms.Compose([
    transforms.ToTensor(),                                    
    transforms.Resize((args.size, args.size)),
    transforms.CenterCrop((args.crop_size, args.crop_size)),
    transforms.Normalize(mean=args.mean, std=args.std)
])

train_loader = make_loader(
    yes_defect=train_yes_defect, no_defect=train_no_defect,
    batch_size=args.batch_size,
    transform=train_transform)
val_loader = make_loader(
    yes_defect=val_yes_defect, no_defect=val_no_defect,
    batch_size=args.batch_size,
    transform=val_transform, train=False)

In [12]:
train_loader = make_loader(
    yes_defect=train_yes_defect, no_defect=train_no_defect,
    batch_size=args.batch_size,
    transform=train_transform)
val_loader = make_loader(
    yes_defect=val_yes_defect, no_defect=val_no_defect,
    batch_size=args.batch_size,
    transform=val_transform, train=False)

In [None]:
model = QuantizeTrainModel().to(args.device)
model.load_state_dict(torch.load("cur.pt", map_location=args.device))
model.fused_module_inplace()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss()

# Set up quantization config to fbgemm
# See more: https://github.com/pytorch/FBGEMM
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
# Shiftign to train model is required for quantization training
model.train()
torch.ao.quantization.prepare_qat(model, inplace=True)

In [14]:
quantized_model = get_quantized_model_from_weight(model)
print_size_of_model(quantized_model)

Size (MB): 21.538525


## Start Training

In [34]:
# Start training
quantize_train(model, train_loader, val_loader, criterion, optimizer, args)

Epochs: 1/30: 100%|██████████| 32/32 [00:11<00:00,  2.82it/s]


Validation accuracy: 0.97943445 f1-score: 0.97152686


Epochs: 2/30: 100%|██████████| 32/32 [00:11<00:00,  2.79it/s]


Validation accuracy: 0.99485861 f1-score: 0.99313303


Epochs: 3/30: 100%|██████████| 32/32 [00:11<00:00,  2.82it/s]


Validation accuracy: 0.99742931 f1-score: 0.99655463


Epochs: 4/30: 100%|██████████| 32/32 [00:11<00:00,  2.80it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 5/30: 100%|██████████| 32/32 [00:11<00:00,  2.68it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 6/30: 100%|██████████| 32/32 [00:11<00:00,  2.81it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 7/30: 100%|██████████| 32/32 [00:11<00:00,  2.84it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 8/30: 100%|██████████| 32/32 [00:11<00:00,  2.81it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 9/30: 100%|██████████| 32/32 [00:11<00:00,  2.81it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 10/30: 100%|██████████| 32/32 [00:11<00:00,  2.80it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 11/30: 100%|██████████| 32/32 [00:11<00:00,  2.78it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 12/30: 100%|██████████| 32/32 [00:11<00:00,  2.78it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 13/30: 100%|██████████| 32/32 [00:11<00:00,  2.79it/s]


Validation accuracy: 0.99742931 f1-score: 0.99655463


Epochs: 14/30: 100%|██████████| 32/32 [00:11<00:00,  2.81it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 15/30: 100%|██████████| 32/32 [00:11<00:00,  2.78it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 16/30: 100%|██████████| 32/32 [00:11<00:00,  2.79it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 17/30: 100%|██████████| 32/32 [00:11<00:00,  2.76it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 18/30: 100%|██████████| 32/32 [00:11<00:00,  2.81it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 19/30: 100%|██████████| 32/32 [00:11<00:00,  2.79it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 20/30: 100%|██████████| 32/32 [00:11<00:00,  2.80it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 21/30: 100%|██████████| 32/32 [00:11<00:00,  2.76it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 22/30: 100%|██████████| 32/32 [00:11<00:00,  2.79it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 23/30: 100%|██████████| 32/32 [00:11<00:00,  2.76it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 24/30: 100%|██████████| 32/32 [00:11<00:00,  2.78it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 25/30: 100%|██████████| 32/32 [00:11<00:00,  2.78it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 26/30: 100%|██████████| 32/32 [00:11<00:00,  2.79it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 27/30: 100%|██████████| 32/32 [00:11<00:00,  2.77it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 28/30: 100%|██████████| 32/32 [00:11<00:00,  2.78it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 29/30: 100%|██████████| 32/32 [00:11<00:00,  2.78it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


Epochs: 30/30: 100%|██████████| 32/32 [00:11<00:00,  2.79it/s]


Validation accuracy: 1.00000000 f1-score: 1.00000000


## Save quantized model

In [23]:
quantized_model = get_quantized_model_from_weight(model)
torch.save(quantized_model.state_dict(), "quantized.pt")