In [146]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torch.quantization import FakeQuantize, QConfig

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os

In [147]:
def get_activation_function(activation: str):
    if activation == 'relu':
        return nn.ReLU(inplace=True)
    elif activation == 'hardtanh':
        return nn.Hardtanh(inplace=True)
    elif activation == 'relu6':
        return nn.ReLU6(inplace=True)
    else:
        raise ValueError("Unsupported activation: %s" % activation)

In [148]:
import copy


class LeNet5(nn.Module):
    def __init__(self, num_classes=10, activation="relu", quantize=False):
        super(LeNet5, self).__init__()
        self.quantize = quantize
        if quantize:
            self.quant = torch.quantization.QuantStub()
            self.dequant = torch.quantization.DeQuantStub()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),
            get_activation_function(activation),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5),
            get_activation_function(activation),
            nn.MaxPool2d(2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(400, 120), get_activation_function(activation)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84), get_activation_function(activation)
        )
        self.fc3 = nn.Linear(84, num_classes)
    
    def forward(self, x):
        if self.quantize:
            x = self.quant(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        if self.quantize:
            x = self.dequant(x)
        return x

In [149]:
def conv3x3(in_channels: int, out_channels: int, stride: int = 1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
    )


def conv1x1(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d:
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=1, stride=stride, bias=False
    )


class ResidualBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, stride=1, downsample=None, activation="relu"
    ):
        super(ResidualBlock, self).__init__()

        self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act_fn1 = get_activation_function(activation)

        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = downsample
        self.skip_add = nn.quantized.FloatFunctional()

        # Remember to use two independent ReLU for layer fusion.
        self.act_fn2 = get_activation_function(activation)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act_fn1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # Use FloatFunctional for addition for quantization compatibility
        out = self.skip_add.add(residual, out)
        out = self.act_fn2(out)

        return out

In [150]:
import copy


class ResNet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        num_classes=10,
        activation="relu",
        initial_channels=16,
        quantize=False,
    ):
        super(ResNet, self).__init__()
        self.quantize = quantize
        if quantize:
            self.quant = torch.quantization.QuantStub()
            self.dequant = torch.quantization.DeQuantStub()

        self.in_channels = initial_channels

        self.conv = conv3x3(3, initial_channels)
        self.bn = nn.BatchNorm2d(initial_channels)
        self.act_fn = get_activation_function(activation)

        self.layer1 = self.make_layer(
            block, initial_channels, layers[0], activation=activation
        )
        self.layer2 = self.make_layer(
            block, initial_channels * 2, layers[1], stride=2, activation=activation
        )
        self.layer3 = self.make_layer(
            block, initial_channels * 4, layers[2], stride=2, activation=activation
        )
        if len(layers) == 4:
            self.layer4 = self.make_layer(
                block, initial_channels * 8, layers[3], stride=2, activation=activation
            )
        else:
            self.layer4 = None

        pool_size = 8 if initial_channels == 16 else 4
        fc_in_features = (
            initial_channels * 4 if len(layers) == 3 else initial_channels * 8
        )
        self.avg_pool = nn.AvgPool2d(pool_size)
        self.fc = nn.Linear(fc_in_features, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1, activation="relu"):
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                conv1x1(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels),
            )

        layers = []
        layers.append(
            block(self.in_channels, out_channels, stride, downsample, activation)
        )
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels, activation=activation))
        return nn.Sequential(*layers)

    def create_fused_model(self):
        fused_model = copy.deepcopy(self)
        fused_model.train()
        fused_model = torch.quantization.fuse_modules(
            fused_model, [["conv", "bn", "act_fn"]], inplace=True
        )

        for module_name, module in fused_model.named_children():
            if "layer" in module_name:
                for basic_block_name, basic_block in module.named_children():
                    torch.quantization.fuse_modules(
                        basic_block,
                        [["conv1", "bn1", "act_fn1"], ["conv2", "bn2"], ["skip_add", "act_fn2"]],
                        inplace=True,
                    )
                    for sub_block_name, sub_block in basic_block.named_children():
                        if sub_block_name == "downsample":
                            torch.quantization.fuse_modules(
                                sub_block, [["0", "1"]], inplace=True
                            )
        
        return fused_model
    
    def forward(self, x):
        if self.quantize:
            x = self.quant(x)

        out = self.conv(x)
        out = self.bn(out)
        out = self.act_fn(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        if self.layer4:
            out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        if self.quantize:
            x = self.dequant(x)
        return out

In [151]:
def ResNet20(num_classes=10, activation='relu', quantize=False):
    return ResNet(ResidualBlock, [3, 3, 3], num_classes, activation, initial_channels=16, quantize=quantize)

def ResNet18(num_classes=10, activation='relu', quantize=False):
    return ResNet(ResidualBlock, [2, 2, 2, 2], num_classes, activation, initial_channels=64, quantize=quantize)

**Обучение моделей**

In [None]:
cuda_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Quantize-converted models can be run only on CPU
cpu_device = torch.device("cpu")

NUM_EPOCHS = 120
BATCH_SIZE = 128
LR = 1e-3
NUM_CLASSES = 10
QAT_EPOCHS = 30
QAT_LR = 1e-3

In [153]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])

train_transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
    normalize])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])

train_dataset = datasets.CIFAR10(root='data',
                                 train=True,
                                 transform=train_transform,
                                 download=True)

test_dataset = datasets.CIFAR10(root='data',
                                train=False,
                                transform=test_transform)


train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=8,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         num_workers=8,
                         shuffle=False)

Files already downloaded and verified


In [166]:
def evaluate_model(model, device=cuda_device, criterion=None):

    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    for inputs, labels in test_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0

        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)

    return eval_loss, eval_accuracy


In [None]:
def train_model(
    model,
    learning_rate,
    num_epochs=NUM_EPOCHS,
    writer=None,
    model_name="model",
    device=cuda_device,
):
    patience = 15
    cur_idle = 0
    criterion = nn.CrossEntropyLoss()

    model.to(device)
    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4
    )

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[70, 100, 130], gamma=0.1, last_epoch=-1
    )

    eval_loss, eval_accuracy = evaluate_model(
        model=model, device=device, criterion=criterion
    )

    print(
        "[{}] Epoch: {:03d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
            model_name, 0, eval_loss, eval_accuracy
        )
    )
    best_eval_acc = 0
    for epoch in range(num_epochs):
        model.train()

        running_loss = 0
        running_corrects = 0

        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        eval_loss, eval_accuracy = evaluate_model(
            model=model, device=device, criterion=criterion
        )

        scheduler.step()
        if writer:
            writer.add_scalar(f"TrainLoss/{model_name}", train_loss, epoch)
            writer.add_scalar(f"TrainAcc/{model_name}", train_accuracy, epoch)
            writer.add_scalar(f"EvalLoss/{model_name}", eval_loss, epoch)
            writer.add_scalar(f"EvalAcc/{model_name}", eval_accuracy, epoch)

        print(
            "[{}] Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
                model_name,
                epoch + 1,
                train_loss,
                train_accuracy,
                eval_loss,
                eval_accuracy,
            )
        )
        
        cur_idle += 1
        if eval_accuracy > best_eval_acc:
            best_eval_acc = eval_accuracy
            torch.save(model.state_dict(), f"checkpoint/best_{model_name}.ckpt")
            cur_idle = 0
        elif cur_idle >= patience:
            print("Early stopping was triggered!")
            break

In [156]:
def train_orig_model(model_class, activation):
    model_name = f"{model_class.__name__}_{activation}"
    model = model_class(activation=activation).to(cuda_device)
    writer = SummaryWriter(f"runs/{model_name}")

    train_model(
        model,
        learning_rate=LR,
        num_epochs=NUM_EPOCHS,
        writer=writer,
        model_name=model_name,
        device=cuda_device,
    )

    model.eval()
    torch.save(model.state_dict(), f"checkpoint/{model_name}.ckpt")
    writer.close()
    return model

In [None]:
def configure_qat(model, activation_bitwidth=4, weight_bitwidth=4):
    # Fake quantizer for activations
    fq_activation = FakeQuantize.with_args(
        observer=torch.quantization.MovingAverageMinMaxObserver.with_args(
            quant_min=0,
            quant_max=2**activation_bitwidth - 1,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            reduce_range=False,
        )
    )

    # Fake quantizer for weights
    fq_weights = FakeQuantize.with_args(
        observer=torch.quantization.MovingAveragePerChannelMinMaxObserver.with_args(
            quant_min=-(2**weight_bitwidth) // 2,
            quant_max=(2**weight_bitwidth) // 2 - 1,
            dtype=torch.qint8,
            qscheme=torch.per_channel_symmetric,
            reduce_range=False,
            ch_axis=0,
        )
    )

    # We don't want non-activation layers to quantize it's output
    # (for example, Conv2d)
    # Because it will harm subsequent Activation performance
    # (other solution will be to fuse Activation with previous layer
    # but pytorch can't fuse with anything, except for nn.ReLU)
    weight_only_qconfig = QConfig(
        activation=torch.quantization.NoopObserver.with_args(dtype=torch.float32),
        weight=fq_weights,
    )
    
    activation_qconfig = QConfig(activation=fq_activation, weight=fq_weights)
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Hardtanh, nn.ReLU6, nn.ReLU, torch.quantization.QuantStub, torch.quantization.DeQuantStub)):
            module.qconfig = activation_qconfig
        else:
            module.qconfig = weight_only_qconfig

    torch.quantization.prepare_qat(model, inplace=True)

In [169]:
def train_quantized_model(model_class, activation, ckpt_path):
    device=cuda_device
    
    model_name = f"{model_class.__name__}_{activation}_quantized"
    model = model_class(activation=activation, quantize=True).to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

    model.train()
    print(f"[{model_name}] initial state:", model)
    
    configure_qat(model)
    
    print(f"[{model_name}] after configure_qat:", model)
    
    # fine-tune via QAT
    writer = SummaryWriter(f"runs/{model_name}")
    train_model(model, learning_rate=QAT_LR, num_epochs=QAT_EPOCHS, writer=writer, model_name=model_name, device=device)

    # Convert into quantized model
    model.eval()
    torch.save(model.state_dict(), f"checkpoint/{model_name}.ckpt")
    writer.close()
    return model

In [159]:
# All possible architectures and activations
models = [LeNet5, ResNet20, ResNet18]
activations = ['relu', 'hardtanh', 'relu6']

In [160]:
# Train all combinations (without quantization)
trained_models = {}
for model_class in models:
    for activation in activations:
        print(f"\nTraining {model_class.__name__} with {activation}")
        trained_models[(model_class.__name__, activation)] = train_orig_model(model_class, activation)


Training LeNet5 with relu


[LeNet5_relu] Epoch: 000 Eval Loss: 2.304 Eval Acc: 0.095
[LeNet5_relu] Epoch: 001 Train Loss: 2.298 Train Acc: 0.123 Eval Loss: 2.287 Eval Acc: 0.180
[LeNet5_relu] Epoch: 002 Train Loss: 2.217 Train Acc: 0.203 Eval Loss: 2.069 Eval Acc: 0.253
[LeNet5_relu] Epoch: 003 Train Loss: 2.008 Train Acc: 0.262 Eval Loss: 1.924 Eval Acc: 0.306
[LeNet5_relu] Epoch: 004 Train Loss: 1.908 Train Acc: 0.302 Eval Loss: 1.795 Eval Acc: 0.353
[LeNet5_relu] Epoch: 005 Train Loss: 1.812 Train Acc: 0.334 Eval Loss: 1.705 Eval Acc: 0.377
[LeNet5_relu] Epoch: 006 Train Loss: 1.745 Train Acc: 0.353 Eval Loss: 1.650 Eval Acc: 0.397
[LeNet5_relu] Epoch: 007 Train Loss: 1.703 Train Acc: 0.372 Eval Loss: 1.592 Eval Acc: 0.416
[LeNet5_relu] Epoch: 008 Train Loss: 1.666 Train Acc: 0.382 Eval Loss: 1.557 Eval Acc: 0.434
[LeNet5_relu] Epoch: 009 Train Loss: 1.634 Train Acc: 0.398 Eval Loss: 1.537 Eval Acc: 0.441
[LeNet5_relu] Epoch: 010 Train Loss: 1.607 Train Acc: 0.405 Eval Loss: 1.488 Eval Acc: 0.459
[LeNet5_relu

KeyboardInterrupt: 

In [170]:
# Train quantized versions

quantized_models = {}
for model_class in models:
    for activation in activations:
        ckpt_path = f"checkpoint/{model_class.__name__}_{activation}.ckpt"
        print(f"\nQuantizing {model_class.__name__} with {activation}")
        quantized_models[(model_class.__name__, activation)] = train_quantized_model(model_class, activation, ckpt_path)


Quantizing LeNet5 with relu
[LeNet5_relu_quantized] initial state: LeNet5(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (conv1): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU(inplace=True)
  )
  (fc2): Sequential(
    (0): Linear(in_features=120, out_features=84, bias=True)
    (1): ReLU(inplace=True)
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
[LeNet5_relu_quantized] after configure_qat: LeNet5(
  (quant): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint

  model.load_state_dict(torch.load(ckpt_path, map_location=device))


[LeNet5_relu_quantized] Epoch: 000 Eval Loss: 1.124 Eval Acc: 0.610
[LeNet5_relu_quantized] Epoch: 001 Train Loss: 1.143 Train Acc: 0.595 Eval Loss: 1.027 Eval Acc: 0.642
[LeNet5_relu_quantized] Epoch: 002 Train Loss: 1.150 Train Acc: 0.593 Eval Loss: 1.041 Eval Acc: 0.633
[LeNet5_relu_quantized] Epoch: 003 Train Loss: 1.144 Train Acc: 0.594 Eval Loss: 1.055 Eval Acc: 0.630
[LeNet5_relu_quantized] Epoch: 004 Train Loss: 1.148 Train Acc: 0.594 Eval Loss: 1.051 Eval Acc: 0.634
[LeNet5_relu_quantized] Epoch: 005 Train Loss: 1.144 Train Acc: 0.595 Eval Loss: 1.039 Eval Acc: 0.633
[LeNet5_relu_quantized] Epoch: 006 Train Loss: 1.147 Train Acc: 0.595 Eval Loss: 1.084 Eval Acc: 0.622
[LeNet5_relu_quantized] Epoch: 007 Train Loss: 1.137 Train Acc: 0.597 Eval Loss: 1.028 Eval Acc: 0.644
[LeNet5_relu_quantized] Epoch: 008 Train Loss: 1.135 Train Acc: 0.599 Eval Loss: 1.021 Eval Acc: 0.639
[LeNet5_relu_quantized] Epoch: 009 Train Loss: 1.141 Train Acc: 0.597 Eval Loss: 1.057 Eval Acc: 0.632
[LeNe

FileNotFoundError: [Errno 2] No such file or directory: 'checkpoint/ResNet20_hardtanh.ckpt'

**Сравнение полученных моделей**

In [174]:
results = {}
for model_class in models:
    for activation in activations:
        if model_class != LeNet5 and (model_class == ResNet18 or activation != "relu"):
            continue

        key = (model_class.__name__, activation)
        orig_model = model_class(activation=activation).to(cuda_device)
        quant_model = model_class(activation=activation, quantize=True).to(cuda_device)
        configure_qat(quant_model)
        
        orig_model.eval()
        quant_model.eval()
        orig_model.load_state_dict(torch.load(f"checkpoint/best_{model_class.__name__}_{activation}.ckpt"))
        quant_model.load_state_dict(torch.load(f"checkpoint/best_{model_class.__name__}_{activation}_quantized.ckpt"), strict=True)

        print(quant_model)
        _, orig_acc = evaluate_model(orig_model)
        _, quant_acc = evaluate_model(quant_model)

        results[key] = {
            "orig_acc": orig_acc,
            "quant_acc": quant_acc,
        }


for key, result in results.items():
    print(f"\nModel: {key[0]} with {key[1]}")
    print(f"Original Accuracy: {result['orig_acc']:.4f}%")
    print(f"Quantized Accuracy: {result['quant_acc']:.4f}%")

LeNet5(
  (quant): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=15, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.3172], device='cuda:0'), zero_point=tensor([7], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=-2.1179039478302, max_val=2.640000104904175)
    )
  )
  (dequant): DeQuantStub()
  (conv1): Sequential(
    (0): Conv2d(
      3, 6, kernel_size=(5, 5), stride=(1, 1)
      (weight_fake_quant): FakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-8, quant_max=7, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([0.0826, 0.0611, 0.0373, 0.0597, 0.1061, 0.0572], device='cuda:0')

  orig_model.load_state_dict(torch.load(f"checkpoint/best_{model_class.__name__}_{activation}.ckpt"))
  quant_model.load_state_dict(torch.load(f"checkpoint/best_{model_class.__name__}_{activation}_quantized.ckpt"), strict=True)


LeNet5(
  (quant): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=15, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.3172], device='cuda:0'), zero_point=tensor([7], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=-2.1179039478302, max_val=2.640000104904175)
    )
  )
  (dequant): DeQuantStub()
  (conv1): Sequential(
    (0): Conv2d(
      3, 6, kernel_size=(5, 5), stride=(1, 1)
      (weight_fake_quant): FakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-8, quant_max=7, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([0.0643, 0.0567, 0.0619, 0.0634, 0.0559, 0.0742], device='cuda:0')