In [1]:
### Utilities
import os
import json
import numpy as np
from time import time
from tqdm import tqdm
import copy

### Torch
import torch
from torch import nn
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import MultiStepLR

### Quantization
from torch.quantization.qconfig import QConfig
from torch.quantization.observer import MinMaxObserver

### Custom
from src.model import model
from src.train import validate
from src.utils import Accuracy, get_model_size
from src.qmodel import quantize_model, compile_module

### Configuration

In [2]:
with open("cfg/cfg_64_200.json") as configurations:
    cfg, cfg_CIFAR, cfg_dataloader_train, cfg_dataloader_test, cfg_train = json.load(configurations).values()

### Data

In [4]:
trainset = CIFAR10(transform=transforms.Compose([
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomCrop(32, 4),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])]
                    ), **cfg_CIFAR, train=True)
trainloader = DataLoader(trainset, **cfg_dataloader_train)

testset = CIFAR10(transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])]
                    ), **cfg_CIFAR, train=False)
testloader = DataLoader(testset, **cfg_dataloader_test)

n_q, idxes = 10, torch.randperm(len(trainset))
qloader = DataLoader([trainset[idxes[i]] for i in range(n_q)], **cfg_dataloader_train)

### Model

In [5]:
is_cuda = torch.cuda.is_available()
from_checkpoint = os.path.exists(cfg["checkpoint_path"])
cfg["device"] = torch.device("cuda") if is_cuda \
            else torch.device("cpu")


ResNet20 = model().to(cfg["device"])
optimResNet20 = AdamW(ResNet20.parameters(), lr=1e-2)
schedResNet20 = MultiStepLR(optimResNet20, last_epoch=-1,
                            milestones=[100, 150], gamma=0.1)

if from_checkpoint:
    checkpoint = torch.load(cfg["checkpoint_path"], map_location=cfg["device"])
    last_epoch = checkpoint["epoch"] + 1
    best_acc = checkpoint["best_acc"]
    ResNet20.load_state_dict(checkpoint["state_dict"])
    optimResNet20.load_state_dict(checkpoint["optimizer"])
    schedResNet20.load_state_dict(checkpoint["scheduler"])
    
CELoss = nn.CrossEntropyLoss(reduction="sum")
Acc = Accuracy(reduction="sum")

In [6]:
print("Load Checkpoint")
ce, acc = validate(testloader, ResNet20, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

Load Checkpoint


100%|██████████| 10/10 [00:35<00:00,  3.55s/it]

Cross Entropy: 0.374, Accuracy: 0.926





In [7]:
qconfig_int2 = QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8,
        quant_min=0,
        quant_max=3

    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min=-2,
        quant_max=1
    )
)
    
ResNet20PTQint2 = torch.quantization.QuantWrapper(
    copy.deepcopy(ResNet20)
)

ResNet20PTQint2.qconfig = qconfig_int2
torch.quantization.prepare(ResNet20PTQint2, inplace=True)

print("Compute Statistics")
ce, acc = validate(qloader, ResNet20PTQint2, CELoss, Acc, cfg["device"], verbose=True)
_ = torch.quantization.convert(ResNet20PTQint2, inplace=True)

print("Evaluate")
ce, acc = validate(testloader, ResNet20PTQint2, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

torch.save(ResNet20PTQint2.state_dict(), cfg["checkpoint_path"]+"_int2")

Compute Statistics


100%|██████████| 1/1 [00:01<00:00,  1.13s/it]


Evaluate


100%|██████████| 10/10 [00:22<00:00,  2.25s/it]

Cross Entropy: 21.469, Accuracy: 0.085





In [8]:
qconfig_int4 = torch.quantization.qconfig.QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8,
        quant_min=0,
        quant_max=15
    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min =-8,
        quant_max =7,
    )
)

ResNet20PTQint4 = torch.quantization.QuantWrapper(
    copy.deepcopy(ResNet20)
)

ResNet20PTQint4.qconfig = qconfig_int4
torch.quantization.prepare(ResNet20PTQint4, inplace=True)

print("Compute Statistics")
ce, acc = validate(qloader, ResNet20PTQint4, CELoss, Acc, cfg["device"], verbose=True)
_ = torch.quantization.convert(ResNet20PTQint4, inplace=True)
print("Evaluate")
ce, acc = validate(testloader, ResNet20PTQint4, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

torch.save(ResNet20PTQint4.state_dict(), cfg["checkpoint_path"]+"_int4")

Compute Statistics


100%|██████████| 1/1 [00:00<00:00,  1.07it/s]


Evaluate


100%|██████████| 10/10 [00:22<00:00,  2.27s/it]

Cross Entropy: 8.334, Accuracy: 0.247





In [9]:
qconfig_int8 = torch.quantization.qconfig.QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8
    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8
    )
)

ResNet20PTQint8 = torch.quantization.QuantWrapper(
    copy.deepcopy(ResNet20)
)

ResNet20PTQint8.qconfig = qconfig_int8
torch.quantization.prepare(ResNet20PTQint8, inplace=True)

print("Compute Statistics")
ResNet20PTQint8.eval()
ce, acc = validate(qloader, ResNet20PTQint8, CELoss, Acc, cfg["device"], verbose=True)
_ = torch.quantization.convert(ResNet20PTQint8, inplace=True)
ResNet20PTQint8.eval()
print("Evaluate")
ce, acc = validate(testloader, ResNet20PTQint8, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

torch.save(ResNet20PTQint8.state_dict(), cfg["checkpoint_path"]+"_int8")

Compute Statistics


100%|██████████| 1/1 [00:00<00:00,  1.17it/s]


Evaluate


100%|██████████| 10/10 [00:23<00:00,  2.33s/it]

Cross Entropy: 0.386, Accuracy: 0.922





In [10]:
QuantizedNN = quantize_model(ResNet20)

ce, acc = validate(qloader, QuantizedNN, CELoss, Acc, cfg["device"], verbose=True)
QuantizedNN.apply(compile_module)
print("Quantized Model Size: {:.3f} MB".format(get_model_size(QuantizedNN)))
ce, acc = validate(testloader, QuantizedNN, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

100%|██████████| 1/1 [00:00<00:00,  1.12it/s]


Quantized Model Size: 0.281 MB


100%|██████████| 10/10 [00:16<00:00,  1.62s/it]

Cross Entropy: 2.215, Accuracy: 0.726





#### Compare size and speed execution

In [12]:
print("Normal Model Size: {:.3f} MB".format(get_model_size(ResNet20)))
time_norm = []
for i in (pbar := tqdm(range(5))):
    start_time = time()
    validate(testloader, ResNet20, CELoss, Acc, torch.device("cpu"))
    time_norm.append(time() - start_time)
    pbar.set_description('Normal Model Time {:.3f}s \u00B1 {:.3f}s'.format(np.mean(time_norm), np.std(time_norm)))

print("Quantized Model Size: {:.3f} MB".format(get_model_size(ResNet20PTQint8)))
time_quant = []
for i in (pbar := tqdm(range(5))):
    start_time = time()
    validate(testloader, ResNet20PTQint8, CELoss, Acc, torch.device("cpu"))
    time_quant.append(time() - start_time)
    pbar.set_description('Qunatized Model Time {:.3f}s \u00B1 {:.3f}s'.format(np.mean(time_quant), np.std(time_quant)))

print("Hands-on Quantized Model Size: {:.3f} MB".format(get_model_size(QuantizedNN)))
time_quant = []
for i in (pbar := tqdm(range(5))):
    start_time = time()
    validate(testloader, QuantizedNN, CELoss, Acc, torch.device("cpu"))
    time_quant.append(time() - start_time)
    pbar.set_description('Qunatized Model Time {:.3f}s \u00B1 {:.3f}s'.format(np.mean(time_quant), np.std(time_quant)))

Normal Model Size: 1.124 MB


Normal Model Time 34.947s ± 0.397s: 100%|██████████| 5/5 [02:54<00:00, 34.95s/it]


Quantized Model Size: 0.012 MB


Qunatized Model Time 21.990s ± 0.057s: 100%|██████████| 5/5 [01:49<00:00, 21.99s/it]


Hands-on Quantized Model Size: 0.281 MB


Qunatized Model Time 15.654s ± 0.077s: 100%|██████████| 5/5 [01:18<00:00, 15.66s/it]
