In [1]:
### Utilities
import os
import json
import random 
import argparse
import warnings
import numpy as np
from tqdm import tqdm
import copy

### Torch
import torch
from torch import nn
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import MultiStepLR

### Quantization
from torch.quantization.qconfig import QConfig
from torch.quantization.observer import MinMaxObserver, HistogramObserver
from torch.quantization.observer import default_observer

### Custom
from src.model import model
from src.train import train, validate
from src.utils import Accuracy, AverageMeter, get_model_size

### Configuration

In [2]:
with open("cfg/cfg_64_200.json") as configurations:
    cfg, cfg_CIFAR, cfg_dataloader_train, cfg_dataloader_test, cfg_train = json.load(configurations).values()

### Data

In [3]:
trainset = CIFAR10(transform=transforms.Compose([
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomCrop(32, 4),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])]
                    ), **cfg_CIFAR, train=True)
trainloader = DataLoader(trainset, **cfg_dataloader_train)

testset = CIFAR10(transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])]
                    ), **cfg_CIFAR, train=False)
testloader = DataLoader(testset, **cfg_dataloader_test)

### Model

In [4]:
is_cuda = torch.cuda.is_available()
from_checkpoint = os.path.exists(cfg["checkpoint_path"])
cfg["device"] = torch.device("cuda") if is_cuda \
            else torch.device("cpu")


ResNet20 = model().to(cfg["device"])
optimResNet20 = AdamW(ResNet20.parameters(), lr=1e-2)
schedResNet20 = MultiStepLR(optimResNet20, last_epoch=-1,
                            milestones=[100, 150], gamma=0.1)

if from_checkpoint:
    checkpoint = torch.load(cfg["checkpoint_path"], map_location=cfg["device"])
    last_epoch = checkpoint["epoch"] + 1
    best_acc = checkpoint["best_acc"]
    ResNet20.load_state_dict(checkpoint["state_dict"])
    optimResNet20.load_state_dict(checkpoint["optimizer"])
    schedResNet20.load_state_dict(checkpoint["scheduler"])
    
CELoss = nn.CrossEntropyLoss(reduction="sum")
Acc = Accuracy(reduction="sum")

In [5]:
print("Load Checkpoint")
ce, acc = validate(testloader, ResNet20, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

Load Checkpoint


100%|██████████| 10/10 [00:37<00:00,  3.78s/it]


Cross Entropy: 0.374, Accuracy: 0.926


In [34]:
qconfig_int2 = QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8,
        quant_min=0, 
        quant_max=1
    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min=0, 
        quant_max=1
    )
)
    
ResNet20PTQint2 = torch.quantization.QuantWrapper(
    copy.deepcopy(ResNet20)
)

ResNet20PTQint2.qconfig = qconfig_int2
torch.quantization.prepare(ResNet20PTQint2, inplace=True)

print("Compute Statistics")
ResNet20PTQint2.eval()
ce, acc = validate(testloader, ResNet20PTQint2, CELoss, Acc, cfg["device"], verbose=True)
_ = torch.quantization.convert(ResNet20PTQint2, inplace=True)
ResNet20PTQint2.eval()
print("Evaluate")
ce, acc = validate(testloader, ResNet20PTQint2, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

Compute Statistics


100%|██████████| 10/10 [00:37<00:00,  3.72s/it]


Evaluate


100%|██████████| 10/10 [00:23<00:00,  2.32s/it]

Cross Entropy: 2.303, Accuracy: 0.100





In [46]:
qconfig_int4 = torch.quantization.qconfig.QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8,
        quant_min=0, 
        quant_max=15
    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min=-8, 
        quant_max=7
    )
)

ResNet20PTQint4 = torch.quantization.QuantWrapper(
    copy.deepcopy(ResNet20)
)

ResNet20PTQint4.qconfig = qconfig_int4
torch.quantization.prepare(ResNet20PTQint4, inplace=True)

print("Compute Statistics")
ResNet20PTQint4.eval()
ce, acc = validate(testloader, ResNet20PTQint4, CELoss, Acc, cfg["device"], verbose=True)
_ = torch.quantization.convert(ResNet20PTQint4, inplace=True)
ResNet20PTQint4.eval()
print("Evaluate")
ce, acc = validate(testloader, ResNet20PTQint4, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

Compute Statistics


100%|██████████| 10/10 [00:37<00:00,  3.74s/it]


Evaluate


100%|██████████| 10/10 [00:22<00:00,  2.27s/it]

Cross Entropy: 13.108, Accuracy: 0.168





In [45]:
qconfig_int8 = torch.quantization.qconfig.QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8,
        quant_min=0, 
        quant_max=255
    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min=-128, 
        quant_max=127,
    )
)

ResNet20PTQint8 = torch.quantization.QuantWrapper(
    copy.deepcopy(ResNet20)
)

ResNet20PTQint8.qconfig = qconfig_int8
torch.quantization.prepare(ResNet20PTQint8, inplace=True)

print("Compute Statistics")
ResNet20PTQint8.eval()
ce, acc = validate(testloader, ResNet20PTQint8, CELoss, Acc, cfg["device"], verbose=True)
_ = torch.quantization.convert(ResNet20PTQint8, inplace=True)
ResNet20PTQint8.eval()
print("Evaluate")
ce, acc = validate(testloader, ResNet20PTQint8, CELoss, Acc, cfg["device"], verbose=True)
print("Cross Entropy: {:.3f}, Accuracy: {:.3f}".format(ce.avg, acc.avg))

Compute Statistics


100%|██████████| 10/10 [00:36<00:00,  3.64s/it]


Evaluate


100%|██████████| 10/10 [00:23<00:00,  2.38s/it]

Cross Entropy: 0.390, Accuracy: 0.922





#### Compare size and speed execution

In [70]:
print("Normal Model Size: {:.3f} MB\nQuantized Model Size: {:.3f} MB".format(
    get_model_size(ResNet20), get_model_size(ResNet20PTQint8)))

Normal Model Size: 1.124 MB
Quantized Model Size: 0.012 MB


Qunatized to int8 Model

In [71]:
%%timeit
_ = validate(testloader, ResNet20PTQint8, CELoss, Acc, torch.device("cpu"))

22.9 s ± 283 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Normal Model

In [72]:
%%timeit
_ = validate(testloader, ResNet20, CELoss, Acc, torch.device("cpu"))

35.8 s ± 480 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
