In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import time
import copy
import numpy as np
from tqdm import tqdm
from torchvision import datasets, transforms
from torchvision.models import resnet18

# Reproducibility and dataset constants
random_seed = 0
num_classes = 10  # CIFAR-10 has 10 image classes

# Checkpoint locations: both model files live under `model_dir`
model_dir = "./data"
model_fn = "resnet18_cifar10.pt"
quantized_model_fn = "resnet18_quantized_cifar10.pt"
model_file_path, quantized_model_file_path = (
    os.path.join(model_dir, file_name)
    for file_name in (model_fn, quantized_model_fn)
)

# Device string handed to the training / evaluation helpers
device = "cpu:0"

In [2]:
# Get resnet model to train and test
# Built with `num_classes` outputs and no pretrained weights (weights=None).
# The training cell further down rebinds `model` to a fresh instance created
# after the random seeds have been set.
model = resnet18(num_classes=num_classes, weights=None)

In [None]:
def set_random_seeds(random_seed=0):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        random_seed: Seed value passed to ``random``, ``numpy.random`` and
            ``torch``.
    """
    # Seed each random number generator used in this notebook.
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Ask cuDNN for deterministic kernels and switch off its benchmark mode.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256):
    """Create the CIFAR10 training and test data loaders.

    Training images get a random 32x32 crop (padding 4) and a random
    horizontal flip before being converted to normalized tensors; test images
    are only converted and normalized. Both datasets are downloaded into the
    local ``data`` directory when missing.

    Args:
        num_workers: Number of worker processes used by each loader.
        train_batch_size: Batch size of the training loader.
        eval_batch_size: Batch size of the test loader.

    Returns:
        Tuple ``(train_loader, test_loader)``. The training loader draws
        samples through a ``RandomSampler``, the test loader through a
        ``SequentialSampler``.
    """

    def tensor_and_normalize():
        # Shared tail of both pipelines. Alternative statistics kept for reference:
        # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_transform = transforms.Compose(augmentation + tensor_and_normalize())
    test_transform = transforms.Compose(tensor_and_normalize())

    cifar10 = torchvision.datasets.CIFAR10
    train_set = cifar10(
        root="data", train=True, download=True, transform=train_transform
    )
    # We will use test set for validation and test in this project.
    # Do not use test set for validation in practice!
    test_set = cifar10(
        root="data", train=False, download=True, transform=test_transform
    )

    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(
        dataset=train_set,
        batch_size=train_batch_size,
        sampler=torch.utils.data.RandomSampler(train_set),
        num_workers=num_workers,
    )
    test_loader = make_loader(
        dataset=test_set,
        batch_size=eval_batch_size,
        sampler=torch.utils.data.SequentialSampler(test_set),
        num_workers=num_workers,
    )

    return train_loader, test_loader


def evaluate_model(model, test_loader, device, criterion=None):
    """Compute the average loss and accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode (and left in that mode). All forward
    passes run under ``torch.no_grad()`` so no autograd graph is recorded
    during evaluation.

    Args:
        model: Callable module mapping a batch of inputs to per-class scores.
        test_loader: Iterable of ``(inputs, labels)`` batches that exposes a
            sized ``dataset`` attribute.
        device: Device the batches are moved to before the forward pass.
        criterion: Optional loss function called as ``criterion(outputs, labels)``.
            Its value is weighted by the batch size, i.e. it is assumed to be
            a per-batch mean. When ``None`` the reported loss is 0.

    Returns:
        Tuple ``(eval_loss, eval_accuracy)``: the summed weighted loss and the
        number of correct predictions, each divided by
        ``len(test_loader.dataset)``. The accuracy is a 0-dim tensor.
    """
    running_loss = 0
    running_corrects = 0

    model.eval()
    # Evaluation never calls backward(), so skip autograd bookkeeping.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            if criterion is not None:
                loss = criterion(outputs, labels).item()
            else:
                loss = 0

            # statistics
            running_loss += loss * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)

    return eval_loss, eval_accuracy


def train_model(
    model, train_loader, test_loader, device, learning_rate=1e-1, num_epochs=200
):
    """Train ``model`` with SGD, evaluating on ``test_loader`` after every epoch.

    The model is moved to ``device``, evaluated once before training (reported
    as epoch -1), then trained for ``num_epochs`` epochs with cross-entropy
    loss, SGD (momentum 0.9, weight decay 1e-4) and a ``MultiStepLR`` schedule
    that multiplies the learning rate by 0.1 at epochs 100 and 150.

    Args:
        model: Module to train; it is updated in place.
        train_loader: Loader yielding ``(inputs, labels)`` training batches;
            must expose a sized ``dataset`` attribute.
        test_loader: Loader passed to ``evaluate_model`` for the evaluations.
        device: Device used for the model and for every batch.
        learning_rate: Initial SGD learning rate.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, left in eval mode by the last evaluation.
    """
    # The training configurations were not carefully selected.
    # Settings below is just to represent the normal training process.
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()

    # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10.
    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[100, 150], gamma=0.1, last_epoch=-1
    )

    def run_evaluation():
        # One place for the evaluation call used before and during training.
        return evaluate_model(
            model=model, test_loader=test_loader, device=device, criterion=loss_fn
        )

    # Baseline metrics of the model as passed in
    eval_loss, eval_accuracy = run_evaluation()
    print(
        "Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
            -1, eval_loss, eval_accuracy
        )
    )

    for epoch in range(num_epochs):
        # Training pass
        model.train()

        loss_sum = 0
        correct_count = 0

        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for inputs, labels in progress:
            # Drop the gradients accumulated by the previous step
            optimizer.zero_grad()

            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward + backward + optimize
            outputs = model(inputs)  # |batch_size, num_classes|
            preds = torch.max(outputs, dim=1).indices
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # Running statistics; the batch loss is weighted by the batch size
            batch_loss = loss.item()
            progress.set_postfix(loss=batch_loss)
            loss_sum += batch_loss * inputs.size(0)
            correct_count += (preds == labels).sum()

        train_loss = loss_sum / len(train_loader.dataset)
        train_accuracy = correct_count / len(train_loader.dataset)

        # Evaluation pass
        model.eval()
        eval_loss, eval_accuracy = run_evaluation()

        # Advance the learning rate schedule once per epoch
        scheduler.step()

        print(
            "Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
                epoch, train_loss, train_accuracy, eval_loss, eval_accuracy
            )
        )

    return model


def save_model(model, model_dir, model_filename):
    """Write ``model.state_dict()`` to ``model_dir/model_filename``.

    Args:
        model: Module whose ``state_dict()`` is saved.
        model_dir: Target directory; created (including parents) if missing.
        model_filename: File name of the checkpoint inside ``model_dir``.
    """
    # exist_ok=True replaces the exists()/makedirs() pair, which could raise
    # if the directory appeared between the check and the creation.
    os.makedirs(model_dir, exist_ok=True)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.save(model.state_dict(), model_filepath)


# Set random seed from the notebook-level constant instead of a second literal
set_random_seeds(random_seed=random_seed)

# Get train dataset loader and test dataset loader
train_loader, test_loader = prepare_dataloader(
    num_workers=1, train_batch_size=4, eval_batch_size=8
)

# Get resnet model to train and test (created after seeding)
model = resnet18(num_classes=num_classes, weights=None)

# Train model
model = train_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    learning_rate=1e-1,
    num_epochs=10,
)

# Save the trained weights to model_dir/model_fn
save_model(model=model, model_dir=model_dir, model_filename=model_fn)

In [None]:
def load_model(model, model_filepath, device):
    """Load a saved ``state_dict`` into ``model`` and return the same object.

    Args:
        model: Module that receives the loaded parameters (modified in place).
        model_filepath: Path of the checkpoint file read with ``torch.load``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The ``model`` that was passed in.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state_dict = torch.load(model_filepath, map_location=device)
    model.load_state_dict(state_dict)
    return model


def model_equivalence(
    model_1,
    model_2,
    device,
    rtol=1e-05,
    atol=1e-08,
    num_tests=100,
    input_size=(1, 3, 32, 32),
):
    """Check that two models produce numerically close outputs on random inputs.

    Both models are moved to ``device`` and fed the same uniformly random
    tensors. Their train/eval mode is not changed here, so callers set it
    beforehand.

    Args:
        model_1: First model.
        model_2: Second model.
        device: Device both models and the random inputs are moved to.
        rtol: Relative tolerance passed to ``np.allclose``.
        atol: Absolute tolerance passed to ``np.allclose``.
        num_tests: Number of random inputs to compare.
        input_size: Shape of each random input tensor.

    Returns:
        ``True`` when every pair of outputs is close; ``False`` at the first
        mismatch, after printing both outputs.
    """
    model_1.to(device)
    model_2.to(device)

    # Pure inference comparison: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(num_tests):
            x = torch.rand(size=input_size).to(device)
            y1 = model_1(x).detach().cpu().numpy()
            y2 = model_2(x).detach().cpu().numpy()
            if not np.allclose(a=y1, b=y2, rtol=rtol, atol=atol, equal_nan=False):
                print("Model equivalence test sample failed: ")
                print(y1, y2)
                return False

    return True


# Load pretrained model
model = load_model(model, model_file_path, device)

# Move the model to CPU since static quantization does not support CUDA currently.
model.to("cpu:0")
fused_model = copy.deepcopy(model)

# Switch both copies to eval() mode before fusing; the un-fused `model` is kept
# as the reference for the equivalence check below.
model.eval()
fused_model.eval()

# Stem: conv1 -> bn1 -> relu
fused_model = torch.quantization.fuse_modules(
    fused_model, [["conv1", "bn1", "relu"]], inplace=True
)

# Residual stages (layer1 ... layer4)
for module_name, module in fused_model.named_children():
    if "layer" not in module_name:
        continue
    for basic_block in module.children():
        # torchvision's BasicBlock has a single `relu` module (there is no
        # `relu1`) and its forward pass reuses it after the residual addition.
        # Fusing it into conv1 would replace it with an identity and drop that
        # second activation, so only the conv/bn pairs are fused here.
        # NOTE(review): the stock BasicBlock adds the skip connection with a
        # plain tensor add; presumably that needs a quantization-aware
        # replacement before the converted INT8 model can run - confirm.
        torch.quantization.fuse_modules(
            basic_block, [["conv1", "bn1"], ["conv2", "bn2"]], inplace=True
        )
        # The optional shortcut branch is a (conv, bn) sequence indexed "0", "1".
        downsample = getattr(basic_block, "downsample", None)
        if downsample is not None:
            torch.quantization.fuse_modules(downsample, [["0", "1"]], inplace=True)

# Check if the results are same
assert model_equivalence(
    model_1=model,
    model_2=fused_model,
    device="cpu:0",
    rtol=1e-03,
    atol=1e-06,
    num_tests=100,
    input_size=(1, 3, 32, 32),
), "Fused model is not equivalent to the original model!"

In [None]:
class QuantizedResNet18(nn.Module):
    """Wrap a floating-point model between a ``QuantStub`` and a ``DeQuantStub``.

    The stubs mark where tensors are converted from floating point to
    quantized (model input) and back to floating point (model output).
    """

    def __init__(self, model_fp32):
        """Store ``model_fp32`` and create the two conversion stubs."""
        super().__init__()
        # Float -> quantized conversion point; applied to the inputs only.
        self.quant = torch.quantization.QuantStub()
        # Quantized -> float conversion point; applied to the outputs only.
        self.dequant = torch.quantization.DeQuantStub()
        # The wrapped FP32 model
        self.model_fp32 = model_fp32

    def forward(self, x):
        """Run ``x`` through quant stub, wrapped model and dequant stub in order."""
        quantized_input = self.quant(x)
        model_output = self.model_fp32(quantized_input)
        return self.dequant(model_output)


# Wrap the fused FP32 model with quant/dequant stubs as the starting point for
# quantization aware training.
quantized_model = QuantizedResNet18(model_fp32=fused_model)
# Using un-fused model will fail.
# Because there is no quantized layer implementation for a single batch normalization layer.
# quantized_model = QuantizedResNet18(model_fp32=model)

# Select quantization schemes from
# https://pytorch.org/docs/stable/quantization-support.html
print(torch.backends.quantized.supported_engines)
quantization_backend = "qnnpack"
torch.backends.quantized.engine = quantization_backend

# Use the QAT qconfig for the selected backend. The post-training
# `default_qconfig` only attaches observers, so training with it would never
# simulate quantization error in the forward pass; the QAT variant uses
# fake-quantize modules for activations and weights.
quantization_config = torch.ao.quantization.get_default_qat_qconfig(
    quantization_backend
)
quantized_model.qconfig = quantization_config

# Print quantization configurations
print(quantized_model.qconfig)

In [None]:
def save_torchscript_model(model, model_dir, model_filename):
    """Script ``model`` with TorchScript and save it to ``model_dir/model_filename``.

    Args:
        model: Module compiled with ``torch.jit.script`` before saving.
        model_dir: Target directory; created (including parents) if missing.
        model_filename: File name of the saved TorchScript module.
    """
    # exist_ok=True replaces the exists()/makedirs() pair, which could raise
    # if the directory appeared between the check and the creation.
    os.makedirs(model_dir, exist_ok=True)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.jit.save(torch.jit.script(model), model_filepath)


def load_torchscript_model(model_filepath, device):
    """Load a TorchScript module from disk.

    Args:
        model_filepath: Path of the file read with ``torch.jit.load``.
        device: Passed to ``torch.jit.load`` as ``map_location``.

    Returns:
        The loaded TorchScript module.
    """
    return torch.jit.load(model_filepath, map_location=device)


def create_model(num_classes=10):
    """Create a ResNet18 without pretrained weights.

    Args:
        num_classes: Value forwarded to ``resnet18`` as ``num_classes``.

    Returns:
        The newly constructed model.
    """
    # The number of channels in ResNet18 is divisible by 8.
    # This is required for fast GEMM integer matrix multiplication.
    # `weights=None` is the same spelling used by the other resnet18 calls in
    # this notebook; it replaces the deprecated `pretrained=False` flag.
    model = resnet18(num_classes=num_classes, weights=None)

    return model


# https://pytorch.org/docs/stable/_modules/torch/quantization/quantize.html#prepare_qat
# Put the whole module tree (wrapper and the eval-mode fused model inside it)
# into training mode first, then prepare - the same order the high-level
# wrapper `torch.quantization.quantize_qat` uses (model.train(); prepare_qat).
quantized_model.train()
torch.quantization.prepare_qat(quantized_model, inplace=True)

# Fine-tune the prepared model on the training data.
print("Training QAT Model...")
train_model(
    model=quantized_model,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    learning_rate=1e-3,
    num_epochs=10,
)

quantized_model.to("cpu:0")
# The manual steps in this cell correspond to what
# torch.quantization.quantize_qat(model, run_fn, run_args, inplace=False) does:
#     model.train()
#     prepare_qat(model, inplace=True)
#     run_fn(model, *run_args)
#     convert(model, inplace=True)

# Convert with the model explicitly in eval mode instead of relying on
# train_model having left it there after its last evaluation.
quantized_model.eval()
quantized_model = torch.quantization.convert(quantized_model, inplace=True)
# The modules swapped in by convert() are new objects; set eval mode on them too.
quantized_model.eval()

# Print quantized model.
print(quantized_model)

# Save quantized model.
save_torchscript_model(
    model=quantized_model, model_dir=model_dir, model_filename=quantized_model_fn
)

# Load quantized model.
quantized_jit_model = load_torchscript_model(
    model_filepath=quantized_model_file_path, device="cpu:0"
)

# Accuracy of the FP32 reference model and of the reloaded INT8 model on the test set
_, fp32_eval_accuracy = evaluate_model(
    model=model, test_loader=test_loader, device="cpu:0", criterion=None
)
_, int8_eval_accuracy = evaluate_model(
    model=quantized_jit_model, test_loader=test_loader, device="cpu:0", criterion=None
)