# ExecuTorchを用いたQuantization Aware Training実装

In [1]:
import os
from pathlib import Path
import copy
import time
import warnings
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch._export import capture_pre_autograd_graph
from torch.export import export, ExportedProgram
from torch.ao.quantization import move_exported_model_to_eval
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from torchvision.transforms import transforms
from torchvision.datasets import MNIST
import timm

from utils import AverageMeter, seed_everything

warnings.filterwarnings("ignore")
logging.getLogger().setLevel(logging.ERROR)
logging.getLogger('torch._export').setLevel(logging.ERROR)

In [2]:
seed_everything(42)
DATA_DIR = Path(os.path.expanduser("~")) / "Data"
device = torch.device("cuda") if torch.cuda.is_available else torch.device("cpu")

In [3]:
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
])
train_dataset = MNIST(root=DATA_DIR, train=True, download=True, transform=transform)
test_dataset = MNIST(root=DATA_DIR, train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

In [4]:
model = timm.create_model('resnet18d', pretrained=True, in_chans=1, num_classes=10).to(device)
example_inputs = (torch.randn(2, 1, 28, 28).to(device),)
dynamic_shapes = {
    "x": {0: torch.export.Dim("batch", min=2, max=1024)},
}
pre_aten_model = capture_pre_autograd_graph(
    model,
    example_inputs,
    dynamic_shapes=dynamic_shapes,
)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(pre_aten_model, quantizer)

In [5]:
def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified
    values of k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader, device):
    if isinstance(model, ExportedProgram):
        move_exported_model_to_eval(model)
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    with torch.no_grad():
        for image, target in data_loader:
            image = image.to(device)
            target = target.to(device)
            output = model(image)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))

    return top1, top5


def train_one_epoch(model, criterion, optimizer, data_loader, device):
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    avgloss = AverageMeter('Loss', '1.5f')
    for image, target in data_loader:
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1[0], image.size(0))
        top5.update(acc5[0], image.size(0))
        avgloss.update(loss, image.size(0))

    print(f'train set:  * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}')
    return

In [6]:
n_epoch = 10
num_observer_update_epochs = 8
num_batch_norm_update_epochs = 8
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(prepared_model.parameters(), lr=1e-3, momentum=0.9)

for epoch in range(n_epoch):
    train_one_epoch(
        model=prepared_model,
        criterion=criterion,
        optimizer=optimizer,
        data_loader=train_loader,
        device=device,
    )

    # Optionally disable observer/batchnorm stats after certain number of epochs
    if epoch >= num_observer_update_epochs:
        # Freeze quantizer parameters
        prepared_model.apply(torch.ao.quantization.disable_observer)
    if epoch >= num_batch_norm_update_epochs:
        # Freeze batch norm mean and variance estimates
        prepared_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    prepared_model_copy = copy.deepcopy(prepared_model)
    quantized_model = convert_pt2e(prepared_model_copy)
    top1, _ = evaluate(quantized_model, criterion, test_loader, device)
    print(f'Epoch {epoch+1:2d}: Evaluation accuracy : {top1.avg:.2f}')

train set:  * Acc@1 38.772 Acc@5 76.488
Epoch  1: Evaluation accuracy : 67.91
train set:  * Acc@1 80.428 Acc@5 96.277
Epoch  2: Evaluation accuracy : 89.21
train set:  * Acc@1 91.373 Acc@5 98.880
Epoch  3: Evaluation accuracy : 93.44
train set:  * Acc@1 94.445 Acc@5 99.450
Epoch  4: Evaluation accuracy : 95.01
train set:  * Acc@1 95.778 Acc@5 99.642
Epoch  5: Evaluation accuracy : 95.80
train set:  * Acc@1 96.555 Acc@5 99.710
Epoch  6: Evaluation accuracy : 96.35
train set:  * Acc@1 97.043 Acc@5 99.788
Epoch  7: Evaluation accuracy : 96.58
train set:  * Acc@1 97.468 Acc@5 99.837
Epoch  8: Evaluation accuracy : 97.04
train set:  * Acc@1 97.817 Acc@5 99.907
Epoch  9: Evaluation accuracy : 97.30
train set:  * Acc@1 98.045 Acc@5 99.905
Epoch 10: Evaluation accuracy : 97.41


In [7]:
quantized_model = convert_pt2e(prepared_model.to(torch.device("cpu")))
move_exported_model_to_eval(quantized_model)

example_inputs = (torch.randn(1, 1, 28, 28).to(torch.device('cpu')),)
core_aten_ep = export(quantized_model, example_inputs)
edge_m = to_edge(
    core_aten_ep,
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)

edge_m = edge_m.to_backend(XnnpackPartitioner())
exec_prog = edge_m.to_executorch()

exec_prog_path = f'models/resnet18_qat_ep{n_epoch}.pte'
with open(exec_prog_path, 'wb') as f:
    f.write(exec_prog.buffer)