In [19]:
### Utilities
import os
import json
import numpy as np
from time import time
from tqdm import tqdm
import copy

### Torch
import torch
from torch import nn
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import MultiStepLR

### Quantization
from torch.quantization.qconfig import QConfig
from torch.quantization.observer import MinMaxObserver

### Custom
from src.model import ResNet20
from src.train import validate
from src.utils import Accuracy, get_model_size
from src.qmodel import *

### Configuration

In [20]:
with open("cfg/new_64_200.json") as configurations:
    cfg, cfg_CIFAR, cfg_dataloader_train, cfg_dataloader_test, cfg_train = json.load(configurations).values()

### Data

In [21]:
trainset = CIFAR10(transform=transforms.Compose([
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomCrop(32, 4),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])]
                    ), **cfg_CIFAR, train=True)
trainloader = DataLoader(trainset, **cfg_dataloader_train)

testset = CIFAR10(transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])]
                    ), **cfg_CIFAR, train=False)
testloader = DataLoader(testset, **cfg_dataloader_test)

n_q, idxes = 10, torch.randperm(len(trainset))
qloader = DataLoader([trainset[idxes[i]] for i in range(n_q)], **cfg_dataloader_train)

In [22]:
is_cuda = torch.cuda.is_available()
from_checkpoint = os.path.exists(cfg["checkpoint_path"])
cfg["device"] = torch.device("cuda") if is_cuda \
            else torch.device("cpu")


In [23]:
ResNet = ResNet20().to(cfg["device"])
CELoss = nn.CrossEntropyLoss(reduction="sum")
Acc = Accuracy(reduction="sum")

### Model

In [24]:

optimResNet = AdamW(ResNet.parameters(), lr=1e-2)
schedResNet = MultiStepLR(optimResNet, last_epoch=-1,
                            milestones=[100, 150], gamma=0.1)

if from_checkpoint:
    checkpoint = torch.load(cfg["checkpoint_path"], map_location=cfg["device"])
    last_epoch = checkpoint["epoch"] + 1
    best_acc = checkpoint["best_acc"]
    ResNet.load_state_dict(checkpoint["state_dict"])
    optimResNet.load_state_dict(checkpoint["optimizer"])
    schedResNet.load_state_dict(checkpoint["scheduler"])
    
CELoss = nn.CrossEntropyLoss(reduction="sum")
Acc = Accuracy(reduction="sum")

In [25]:
print("Load Checkpoint")
ce, acc = validate(qloader, ResNet, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

Load Checkpoint


100%|██████████| 1/1 [00:01<00:00,  1.10s/it]

Cross Entropy: 0.087, Accuracy: 0.900





In [26]:
qconfig_int2 = QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8,
        quant_min=0,
        quant_max=3
    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min=-2,
        quant_max=1
    )
)
    
ResNetPTQint2 = torch.quantization.QuantWrapper(
    copy.deepcopy(ResNet)
)

ResNetPTQint2.qconfig = qconfig_int2
torch.quantization.prepare(ResNetPTQint2, inplace=True)

print("Compute Statistics")
ce, acc = validate(qloader, ResNetPTQint2, CELoss, Acc, cfg["device"], verbose=True)
_ = torch.quantization.convert(ResNetPTQint2, inplace=True)

print("Evaluate")
ce, acc = validate(testloader, ResNetPTQint2, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

torch.save(ResNetPTQint2.state_dict(), cfg["checkpoint_path"]+"_int2")

Compute Statistics


100%|██████████| 1/1 [00:00<00:00,  1.30it/s]


Evaluate


100%|██████████| 10/10 [00:23<00:00,  2.40s/it]

Cross Entropy: 22.003, Accuracy: 0.097





In [27]:
qconfig_int4 = torch.quantization.qconfig.QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8,
        quant_min=0,
        quant_max=15
    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min =-8,
        quant_max =7,
    )
)

ResNetPTQint4 = torch.quantization.QuantWrapper(
    copy.deepcopy(ResNet)
)

ResNetPTQint4.qconfig = qconfig_int4
torch.quantization.prepare(ResNetPTQint4, inplace=True)

print("Compute Statistics")
ce, acc = validate(qloader, ResNetPTQint4, CELoss, Acc, cfg["device"], verbose=True)
_ = torch.quantization.convert(ResNetPTQint4, inplace=True)
print("Evaluate")
ce, acc = validate(testloader, ResNetPTQint4, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

torch.save(ResNetPTQint4.state_dict(), cfg["checkpoint_path"]+"_int4")

Compute Statistics


100%|██████████| 1/1 [00:00<00:00,  1.33it/s]


Evaluate


100%|██████████| 10/10 [00:24<00:00,  2.44s/it]

Cross Entropy: 8.808, Accuracy: 0.266





In [28]:
qconfig_int8 = torch.quantization.qconfig.QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8
    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8
    )
)

ResNetPTQint8 = torch.quantization.QuantWrapper(
    copy.deepcopy(ResNet)
)

ResNetPTQint8.qconfig = qconfig_int8
torch.quantization.prepare(ResNetPTQint8, inplace=True)

print("Compute Statistics")
ResNetPTQint8.eval()
ce, acc = validate(qloader, ResNetPTQint8, CELoss, Acc, cfg["device"], verbose=True)
_ = torch.quantization.convert(ResNetPTQint8, inplace=True)
ResNetPTQint8.eval()
print("Evaluate")
ce, acc = validate(testloader, ResNetPTQint8, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

torch.save(ResNetPTQint8.state_dict(), cfg["checkpoint_path"]+"_int8")

Compute Statistics


100%|██████████| 1/1 [00:02<00:00,  2.09s/it]


Evaluate


100%|██████████| 10/10 [00:29<00:00,  2.91s/it]

Cross Entropy: 0.416, Accuracy: 0.880





In [29]:
QuantizedNN = quantize_merge_model(ResNet)
ce, acc = validate(qloader, QuantizedNN, CELoss, Acc, cfg["device"], verbose=True)
QuantizedNN.apply(compile_module)
print("Quantized Model Size: {:.3f} MB".format(get_model_size(QuantizedNN)))
ce, acc = validate(testloader, QuantizedNN, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

100%|██████████| 1/1 [00:01<00:00,  1.51s/it]


Quantized Model Size: 0.370 MB


100%|██████████| 10/10 [00:17<00:00,  1.72s/it]

Cross Entropy: 0.779, Accuracy: 0.830





#### Compare size and speed execution

In [30]:
print("Normal Model Size: {:.3f} MB".format(get_model_size(ResNet)))
time_norm = []
for i in (pbar := tqdm(range(5))):
    start_time = time()
    validate(testloader, ResNet, CELoss, Acc, torch.device("cpu"))
    time_norm.append(time() - start_time)
    pbar.set_description('Normal Model Time {:.3f}s \u00B1 {:.3f}s'.format(np.mean(time_norm), np.std(time_norm)))

print("Quantized Model Size: {:.3f} MB".format(get_model_size(ResNetPTQint8)))
time_quant = []
for i in (pbar := tqdm(range(5))):
    start_time = time()
    validate(testloader, ResNetPTQint8, CELoss, Acc, torch.device("cpu"))
    time_quant.append(time() - start_time)
    pbar.set_description('Qunatized Model Time {:.3f}s \u00B1 {:.3f}s'.format(np.mean(time_quant), np.std(time_quant)))

print("Hands-on Quantized Model Size: {:.3f} MB".format(get_model_size(QuantizedNN)))
time_quant = []
for i in (pbar := tqdm(range(5))):
    start_time = time()
    validate(testloader, QuantizedNN, CELoss, Acc, torch.device("cpu"))
    time_quant.append(time() - start_time)
    pbar.set_description('Qunatized Model Time {:.3f}s \u00B1 {:.3f}s'.format(np.mean(time_quant), np.std(time_quant)))

Normal Model Size: 1.478 MB


Normal Model Time 43.272s ± 5.276s: 100%|██████████| 5/5 [03:36<00:00, 43.27s/it]


Quantized Model Size: 0.015 MB


Qunatized Model Time 22.778s ± 0.090s: 100%|██████████| 5/5 [01:53<00:00, 22.78s/it]


Hands-on Quantized Model Size: 0.370 MB


Qunatized Model Time 16.079s ± 0.023s: 100%|██████████| 5/5 [01:20<00:00, 16.08s/it]
