In [69]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torch.quantization import FakeQuantize, QConfig

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os

In [70]:
def get_activation_function(activation: str, quantize=False):
    if activation == 'relu6':
        # ReLU6 is already quantizable
        return nn.ReLU6(inplace=True)
    if activation == 'relu':
        act = nn.ReLU(inplace=True)
    elif activation == 'hardtanh':
        act = nn.Hardtanh(inplace=True)
    else:
        raise ValueError("Unsupported activation: %s" % activation)
    
    if not quantize:
        return act
    else:
        # Put QuantStub after activation, to force quantization of it's output
        # (By default, Pytorch framework supposes that user fuses ReLU with previous layers
        # but you can't fuse with Hardtanh anyway)
        return nn.Sequential(act, torch.quantization.QuantStub())


In [71]:
import copy


class LeNet5(nn.Module):
    def __init__(self, num_classes=10, activation="relu", quantize=False):
        super(LeNet5, self).__init__()
        self.quantize = quantize
        if quantize:
            self.quant = torch.quantization.QuantStub()
            self.dequant = torch.quantization.DeQuantStub()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),
            # Never quantize first activation func
            get_activation_function(activation, quantize=False),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5),
            get_activation_function(activation, quantize=quantize),
            nn.MaxPool2d(2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(400, 120), get_activation_function(activation, quantize=quantize)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84), get_activation_function(activation, quantize=quantize)
        )
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        if self.quantize:
            x = self.quant(x)
            
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)

        if self.quantize:
            x = self.dequant(x)
        x = self.fc3(x)
        return x

In [72]:
def conv3x3(in_channels: int, out_channels: int, stride: int = 1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
    )


def conv1x1(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d:
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=1, stride=stride, bias=False
    )


class ResidualBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, stride=1, downsample=None, activation="relu", quantize=False
    ):
        super(ResidualBlock, self).__init__()

        self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act_fn1 = get_activation_function(activation, quantize=quantize)

        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = downsample
        self.skip_add = nn.quantized.FloatFunctional()

        # Remember to use two independent ReLU for layer fusion.
        self.act_fn2 = get_activation_function(activation, quantize=quantize)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act_fn1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # Use FloatFunctional for addition for quantization compatibility
        out = self.skip_add.add(residual, out)
        out = self.act_fn2(out)

        return out

In [73]:
import copy


class ResNet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        num_classes=10,
        activation="relu",
        initial_channels=16,
        quantize=False,
    ):
        super(ResNet, self).__init__()
        self.quantize = quantize
        if quantize:
            self.quant = torch.quantization.QuantStub()
            self.dequant = torch.quantization.DeQuantStub()

        self.in_channels = initial_channels
        self.initial_layer = nn.Sequential(
            conv3x3(3, initial_channels),
            nn.BatchNorm2d(initial_channels),
            # Never quantize first activation func
            get_activation_function(activation, quantize=False)
        )

        self.layer1 = self.make_layer(
            block, initial_channels, layers[0], activation=activation
        )
        self.layer2 = self.make_layer(
            block, initial_channels * 2, layers[1], stride=2, activation=activation, 
        )
        self.layer3 = self.make_layer(
            block, initial_channels * 4, layers[2], stride=2, activation=activation
        )
        if len(layers) == 4:
            self.layer4 = self.make_layer(
                block, initial_channels * 8, layers[3], stride=2, activation=activation
            )
        else:
            self.layer4 = None

        pool_size = 8 if initial_channels == 16 else 4
        fc_in_features = (
            initial_channels * 4 if len(layers) == 3 else initial_channels * 8
        )
        self.avg_pool = nn.AvgPool2d(pool_size)
        self.fc = nn.Linear(fc_in_features, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1, activation="relu"):
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                conv1x1(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels),
            )

        layers = []
        layers.append(
            block(self.in_channels, out_channels, stride, downsample, activation, quantize=self.quantize)
        )
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels, activation=activation, quantize=self.quantize))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.initial_layer(x)
        # Start quantizing AFTER first layer
        if self.quantize:
            out = self.quant(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        if self.layer4:
            out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        
        if self.quantize:
            out = self.dequant(out)
        out = self.fc(out)
        return out

In [74]:
def ResNet20(num_classes=10, activation='relu', quantize=False):
    return ResNet(ResidualBlock, [3, 3, 3], num_classes, activation, initial_channels=16, quantize=quantize)

def ResNet18(num_classes=10, activation='relu', quantize=False):
    return ResNet(ResidualBlock, [2, 2, 2, 2], num_classes, activation, initial_channels=64, quantize=quantize)

**Обучение моделей**

In [None]:
cuda_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Quantize-converted models can be run only on CPU
cpu_device = torch.device("cpu")

NUM_EPOCHS = 120
BATCH_SIZE = 128
LR = 1e-3
NUM_CLASSES = 10
QAT_EPOCHS = 50
QAT_LR = 1e-3

In [76]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])

train_transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
    normalize])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])

train_dataset = datasets.CIFAR10(root='data',
                                 train=True,
                                 transform=train_transform,
                                 download=True)

test_dataset = datasets.CIFAR10(root='data',
                                train=False,
                                transform=test_transform)


train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=8,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         num_workers=8,
                         shuffle=False)

Files already downloaded and verified


In [77]:
def evaluate_model(model, device=cuda_device, criterion=None):

    model.eval()
    # We don't want Observers to change quantization params
    model.apply(torch.ao.quantization.disable_observer)
    model.to(device)

    running_loss = 0
    running_corrects = 0

    for inputs, labels in test_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0

        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)

    model.apply(torch.ao.quantization.enable_observer)
    return eval_loss, eval_accuracy


In [None]:
def train_model(
    model,
    learning_rate,
    num_epochs=NUM_EPOCHS,
    writer=None,
    model_name="model",
    device=cuda_device,
):
    patience = 15
    cur_idle = 0
    criterion = nn.CrossEntropyLoss()

    model.to(device)
    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4
    )

    if "quantized" in model_name:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[30, 40], gamma=0.25, last_epoch=-1
        )
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[70, 100, 130], gamma=0.1, last_epoch=-1
        )

    best_eval_acc = 0
    for epoch in range(num_epochs):
        model.train()

        running_loss = 0
        running_corrects = 0

        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        eval_loss, eval_accuracy = evaluate_model(
            model=model, device=device, criterion=criterion
        )

        scheduler.step()
        if writer:
            writer.add_scalar(f"TrainLoss/{model_name}", train_loss, epoch)
            writer.add_scalar(f"TrainAcc/{model_name}", train_accuracy, epoch)
            writer.add_scalar(f"EvalLoss/{model_name}", eval_loss, epoch)
            writer.add_scalar(f"EvalAcc/{model_name}", eval_accuracy, epoch)

        print(
            "[{}] Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
                model_name,
                epoch + 1,
                train_loss,
                train_accuracy,
                eval_loss,
                eval_accuracy,
            )
        )

        cur_idle += 1
        if eval_accuracy > best_eval_acc:
            best_eval_acc = eval_accuracy
            torch.save(model.state_dict(), f"checkpoint/best_{model_name}.ckpt")
            cur_idle = 0
        elif cur_idle >= patience:
            print("Early stopping was triggered!")
            break

In [79]:
def train_orig_model(model_class, activation):
    model_name = f"{model_class.__name__}_{activation}"
    model = model_class(activation=activation).to(cuda_device)
    writer = SummaryWriter(f"runs/{model_name}")

    train_model(
        model,
        learning_rate=LR,
        num_epochs=NUM_EPOCHS,
        writer=writer,
        model_name=model_name,
        device=cuda_device,
    )

    model.eval()
    torch.save(model.state_dict(), f"checkpoint/{model_name}.ckpt")
    writer.close()
    return model

In [80]:
def configure_qat(model, activation_bitwidth=4, weight_bitwidth=4):
    # Fake quantizer for activations
    fq_activation = FakeQuantize.with_args(
        observer=torch.quantization.MovingAverageMinMaxObserver.with_args(
            quant_min=0,
            quant_max=2**activation_bitwidth - 1,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            reduce_range=False,
        )
    )

    # Fake quantizer for weights
    fq_weights = FakeQuantize.with_args(
        observer=torch.quantization.MovingAveragePerChannelMinMaxObserver.with_args(
            quant_min=-(2**weight_bitwidth) // 2,
            quant_max=(2**weight_bitwidth) // 2 - 1,
            dtype=torch.qint8,
            qscheme=torch.per_channel_symmetric,
            reduce_range=False,
            ch_axis=0,
        )
    )

    # We don't want non-activation layers to quantize it's output
    # (for example, Conv2d)
    # Because it will harm subsequent Activation performance
    # (other solution will be to fuse Activation with previous layer
    # but pytorch can't fuse with anything, except for nn.ReLU)
    weight_only_qconfig = QConfig(
        activation=torch.quantization.NoopObserver.with_args(dtype=torch.float32),
        weight=fq_weights,
    )
    
    activation_qconfig = QConfig(activation=fq_activation, weight=fq_weights)
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Hardtanh, nn.ReLU6, nn.ReLU, torch.quantization.QuantStub, torch.quantization.DeQuantStub)):
            module.qconfig = activation_qconfig
        else:
            module.qconfig = weight_only_qconfig

     # Avoid quantizing first and last layers
    if isinstance(model, LeNet5):
        # First Layer
        model.conv1.qconfig = None
        for name, module in model.conv1.named_modules():
            module.qconfig = None
        model.fc3.qconfig = None # Last Linear

    elif isinstance(model, ResNet):
        model.initial_layer.qconfig = None
        for name, module in model.initial_layer.named_modules():
            module.qconfig = None
        model.fc.qconfig = None # Last Linear

    torch.quantization.prepare_qat(model, inplace=True)

In [None]:
def train_quantized_model(model_class, activation, ckpt_path):
    device=cuda_device
    
    model_name = f"{model_class.__name__}_{activation}_quantized"
    model = model_class(activation=activation, quantize=True).to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

    model.train()
    configure_qat(model)
    
    print(f"[{model_name}] after configure_qat:", model)
    
    # fine-tune via QAT
    writer = SummaryWriter(f"runs/{model_name}")
    train_model(model, learning_rate=QAT_LR, num_epochs=QAT_EPOCHS, writer=writer, model_name=model_name, device=device)

    # Save quantized model
    model.eval()
    torch.save(model.state_dict(), f"checkpoint/{model_name}.ckpt")
    writer.close()
    return model

In [82]:
# All possible architectures and activations
models = [LeNet5, ResNet20, ResNet18]
activations = ['relu', 'hardtanh', 'relu6']

In [83]:
# Train all combinations (without quantization)
trained_models = {}
for model_class in models:
    for activation in activations:
        print(f"\nTraining {model_class.__name__} with {activation}")
        trained_models[(model_class.__name__, activation)] = train_orig_model(model_class, activation)


Training LeNet5 with relu


[LeNet5_relu] Epoch: 001 Train Loss: 2.301 Train Acc: 0.110 Eval Loss: 2.295 Eval Acc: 0.124
[LeNet5_relu] Epoch: 002 Train Loss: 2.270 Train Acc: 0.153 Eval Loss: 2.206 Eval Acc: 0.193


KeyboardInterrupt: 

In [84]:
# Train quantized versions

quantized_models = {}
for model_class in models:
    for activation in activations:
        ckpt_path = f"checkpoint/{model_class.__name__}_{activation}.ckpt"
        print(f"\nQuantizing {model_class.__name__} with {activation}")
        quantized_models[(model_class.__name__, activation)] = train_quantized_model(model_class, activation, ckpt_path)


Quantizing LeNet5 with relu
[LeNet5_relu_quantized] initial state: LeNet5(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (conv1): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): Sequential(
      (0): ReLU(inplace=True)
      (1): QuantStub()
    )
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): Sequential(
      (0): ReLU(inplace=True)
      (1): QuantStub()
    )
  )
  (fc2): Sequential(
    (0): Linear(in_features=120, out_features=84, bias=True)
    (1): Sequential(
      (0): ReLU(inplace=True)
      (1): QuantStub()
    )
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
[LeNet5_relu_quantized] after conf

  model.load_state_dict(torch.load(ckpt_path, map_location=device))


[LeNet5_relu_quantized] Epoch: 001 Train Loss: 1.176 Train Acc: 0.584 Eval Loss: 1.060 Eval Acc: 0.622

Quantizing LeNet5 with hardtanh
[LeNet5_hardtanh_quantized] initial state: LeNet5(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (conv1): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): Hardtanh(min_val=-1.0, max_val=1.0, inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): Sequential(
      (0): Hardtanh(min_val=-1.0, max_val=1.0, inplace=True)
      (1): QuantStub()
    )
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): Sequential(
      (0): Hardtanh(min_val=-1.0, max_val=1.0, inplace=True)
      (1): QuantStub()
    )
  )
  (fc2): Sequential(
    (0): Linear(in_features=120, out_

RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "initial_layer.0.weight", "initial_layer.1.weight", "initial_layer.1.bias", "initial_layer.1.running_mean", "initial_layer.1.running_var". 
	Unexpected key(s) in state_dict: "conv.weight", "bn.weight", "bn.bias", "bn.running_mean", "bn.running_var", "bn.num_batches_tracked". 

**Сравнение полученных моделей**

In [None]:
results = {}
for model_class in models:
    for activation in activations:
        if model_class != LeNet5 and (model_class == ResNet18 or activation != "relu"):
            continue

        key = (model_class.__name__, activation)
        orig_model = model_class(activation=activation).to(cuda_device)
        quant_model = model_class(activation=activation, quantize=True).to(cuda_device)
        configure_qat(quant_model)
        
        orig_model.eval()
        quant_model.eval()
        orig_model.load_state_dict(torch.load(f"checkpoint/best_{model_class.__name__}_{activation}.ckpt"))
        quant_model.load_state_dict(torch.load(f"checkpoint/best_{model_class.__name__}_{activation}_quantized.ckpt"), strict=True)

        print(quant_model)
        _, orig_acc = evaluate_model(orig_model)
        _, quant_acc = evaluate_model(quant_model)

        results[key] = {
            "orig_acc": orig_acc,
            "quant_acc": quant_acc,
        }


for key, result in results.items():
    print(f"\nModel: {key[0]} with {key[1]}")
    print(f"Original Accuracy: {result['orig_acc']:.4f}%")
    print(f"Quantized Accuracy: {result['quant_acc']:.4f}%")

LeNet5(
  (quant): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=15, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.3172], device='cuda:0'), zero_point=tensor([7], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=-2.1179039478302, max_val=2.640000104904175)
    )
  )
  (dequant): DeQuantStub()
  (conv1): Sequential(
    (0): Conv2d(
      3, 6, kernel_size=(5, 5), stride=(1, 1)
      (weight_fake_quant): FakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-8, quant_max=7, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([0.0826, 0.0611, 0.0373, 0.0597, 0.1061, 0.0572], device='cuda:0')

  orig_model.load_state_dict(torch.load(f"checkpoint/best_{model_class.__name__}_{activation}.ckpt"))
  quant_model.load_state_dict(torch.load(f"checkpoint/best_{model_class.__name__}_{activation}_quantized.ckpt"), strict=True)


LeNet5(
  (quant): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=15, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.3172], device='cuda:0'), zero_point=tensor([7], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=-2.1179039478302, max_val=2.640000104904175)
    )
  )
  (dequant): DeQuantStub()
  (conv1): Sequential(
    (0): Conv2d(
      3, 6, kernel_size=(5, 5), stride=(1, 1)
      (weight_fake_quant): FakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-8, quant_max=7, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([0.0643, 0.0567, 0.0619, 0.0634, 0.0559, 0.0742], device='cuda:0')