In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torch.quantization import FakeQuantize, QConfig

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os

In [15]:
def get_activation_function(activation: str):
    if activation == 'relu':
        return nn.ReLU(inplace=True)
    elif activation == 'hardtanh':
        return nn.Hardtanh(inplace=True)
    elif activation == 'relu6':
        return nn.ReLU6(inplace=True)
    else:
        raise ValueError("Unsupported activation: %s" % activation)

In [16]:
class LeNet5(nn.Module):
    def __init__(self, num_classes=10, activation='relu', quantize=False):
        super(LeNet5, self).__init__()
        self.quantize = quantize
        if quantize:
            self.quant = torch.quantization.QuantStub()
            self.dequant = torch.quantization.DeQuantStub()

        act_fn = get_activation_function(activation)

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),
            act_fn,
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5),
            act_fn,
            nn.MaxPool2d(2)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(400, 120),
            act_fn
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            act_fn
        )
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        if self.quantize:
            x = self.quant(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        if self.quantize:
            x = self.dequant(x)
        return x

In [17]:
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None, activation='relu'):
        super(ResidualBlock, self).__init__()

        self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.act_fn = get_activation_function(activation)

        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act_fn(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.act_fn(out)
        return out

In [18]:
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10, activation='relu', initial_channels=16, quantize=False):
        super(ResNet, self).__init__()
        self.quantize = quantize
        if quantize:
            self.quant = torch.quantization.QuantStub()
            self.dequant = torch.quantization.DeQuantStub()

        self.in_channels = initial_channels

        self.conv = conv3x3(3, initial_channels)
        self.bn = nn.BatchNorm2d(initial_channels)

        self.act_fn = get_activation_function(activation)

        self.layer1 = self.make_layer(block, initial_channels, layers[0], activation=activation)
        self.layer2 = self.make_layer(block, initial_channels * 2, layers[1], stride=2, activation=activation)
        self.layer3 = self.make_layer(block, initial_channels * 4, layers[2], stride=2, activation=activation)
        if len(layers) == 4:
            self.layer4 = self.make_layer(block, initial_channels * 8, layers[3], stride=2, activation=activation)
        else:
            self.layer4 = None

        pool_size = 8 if initial_channels == 16 else 4
        fc_in_features = initial_channels * 4 if len(layers) == 3 else initial_channels * 8
        self.avg_pool = nn.AvgPool2d(pool_size)
        self.fc = nn.Linear(fc_in_features, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1, activation='relu'):
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample, activation))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels, activation=activation))
        return nn.Sequential(*layers)

    def forward(self, x):
        if self.quantize:
            x = self.quant(x)

        out = self.conv(x)
        out = self.bn(out)
        out = self.act_fn(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        if self.layer4:
            out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        if self.quantize:
          x = self.dequant(x)
        return out

In [19]:
def ResNet20(num_classes=10, activation='relu', quantize=False):
    return ResNet(ResidualBlock, [3, 3, 3], num_classes, activation, initial_channels=16, quantize=quantize)

def ResNet18(num_classes=10, activation='relu', quantize=False):
    return ResNet(ResidualBlock, [2, 2, 2, 2], num_classes, activation, initial_channels=64, quantize=quantize)

**Обучение моделей**

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_EPOCHS = 100
BATCH_SIZE = 100
LR = 0.001
NUM_CLASSES = 10
QAT_EPOCHS = 100
QAT_LR = 0.00025

In [34]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])

train_transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
    normalize])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])

train_dataset = datasets.CIFAR10(root='data',
                                 train=True,
                                 transform=train_transform,
                                 download=True)

test_dataset = datasets.CIFAR10(root='data',
                                train=False,
                                transform=test_transform)


train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=8,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         num_workers=8,
                         shuffle=False)

Using downloaded and verified file: data/cifar-10-python.tar.gz
Extracting data/cifar-10-python.tar.gz to data


In [22]:
def evaluate(model, model_name):
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f'Accuracy of {model_name} on test images: {accuracy:.2f}%')
        return accuracy


In [24]:

def train(model, criterion, optimizer, num_epochs=NUM_EPOCHS, writer=None, model_name="model"):
    total_step = len(train_loader)
    best_acc = 0.0
    patience = 6
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}")
                if writer:
                    writer.add_scalar(f"Loss/{model_name}", loss.item(), epoch * total_step + i)

        if (epoch + 1) % 5 == 0:
            acc = evaluate(model, model_name)
            if writer:
                writer.add_scalar(f"Accuracy/{model_name}", acc, epoch)

            if acc > best_acc:
                best_acc = acc
                patience_counter = 0
                torch.save(model.state_dict(), f"best_{model_name}.ckpt")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping triggered at epoch {epoch+1}")
                    break


In [25]:
def train_model(model_class, activation):
    model_name = f"{model_class.__name__}_{activation}"
    model = model_class(activation=activation).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    writer = SummaryWriter(f"runs/{model_name}")

    train(model, criterion, optimizer, writer=writer, model_name=model_name)
    torch.save(model.state_dict(), f"{model_name}.ckpt")
    writer.close()
    return model

In [27]:
def configure_qat(model):
    # Fake quantizer for activations
    activation_quant = FakeQuantize.with_args(
        quant_min=-8,
        quant_max=7,
        dtype=torch.qint8,
        reduce_range=False
    )

    # Fake quantizer for weights
    weight_quant = FakeQuantize.with_args(
        quant_min=-8,
        quant_max=7,
        dtype=torch.qint8,
        reduce_range=False
    )

    # Custom QConfig
    custom_qconfig = QConfig(
        activation=activation_quant,
        weight=weight_quant
    )
    model.qconfig = custom_qconfig

    torch.quantization.prepare_qat(model, inplace=True)

In [28]:
def qat_model(model_class, activation, ckpt_path):
    model_name = f"{model_class.__name__}_{activation}_quantized"
    model = model_class(activation=activation, quantize=True).to(device)
    model.load_state_dict(torch.load(ckpt_path))

    configure_qat(model)

    # fine-tune via QAT
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=QAT_LR)
    writer = SummaryWriter(f"runs/{model_name}")
    train(model, criterion, optimizer, num_epochs=QAT_EPOCHS, writer=writer, model_name=model_name)

    # Convert into quantized model
    model.eval()
    quantized_model = torch.quantization.convert(model, inplace=False)
    torch.save(quantized_model.state_dict(), f"{model_name}.ckpt")
    writer.close()
    return quantized_model

In [None]:
# All possible architectures and activations
models = [LeNet5, ResNet20, ResNet18]
activations = ['relu', 'hardtanh', 'relu6']

# Train all combinations (without quantization)
trained_models = {}
for model_class in models:
    for activation in activations:
        print(f"\nTraining {model_class.__name__} with {activation}")
        trained_models[(model_class.__name__, activation)] = train_model(model_class, activation)

In [69]:
# Train quantized versions

quantized_models = {}
for model_class in models:
    for activation in activations:
        ckpt_path = f"{model_class.__name__}_{activation}.ckpt"
        print(f"\nQuantizing {model_class.__name__} with {activation}")
        quantized_models[(model_class.__name__, activation)] = qat_model(model_class, activation, ckpt_path)


Quantizing LeNet5 with relu


  model.load_state_dict(torch.load(ckpt_path))


Epoch [1/100], Step [100/500], Loss: 2.3495
Epoch [1/100], Step [200/500], Loss: 2.2334
Epoch [1/100], Step [300/500], Loss: 1.7769
Epoch [1/100], Step [400/500], Loss: 2.2118
Epoch [1/100], Step [500/500], Loss: 1.7697
Epoch [2/100], Step [100/500], Loss: 1.8594
Epoch [2/100], Step [200/500], Loss: 1.5725
Epoch [2/100], Step [300/500], Loss: 1.7045
Epoch [2/100], Step [400/500], Loss: 1.7547
Epoch [2/100], Step [500/500], Loss: 1.8432
Epoch [3/100], Step [100/500], Loss: 1.9086
Epoch [3/100], Step [200/500], Loss: 1.7211
Epoch [3/100], Step [300/500], Loss: 1.9559
Epoch [3/100], Step [400/500], Loss: 1.7410
Epoch [3/100], Step [500/500], Loss: 1.7541
Epoch [4/100], Step [100/500], Loss: 1.8094
Epoch [4/100], Step [200/500], Loss: 2.0754
Epoch [4/100], Step [300/500], Loss: 1.7386
Epoch [4/100], Step [400/500], Loss: 1.8519
Epoch [4/100], Step [500/500], Loss: 1.7801
Epoch [5/100], Step [100/500], Loss: 1.6510
Epoch [5/100], Step [200/500], Loss: 1.6272
Epoch [5/100], Step [300/500], L

**Сравнение полученных моделей**

In [35]:
results = {}
for model_class in models:
    for activation in activations:
        key = (model_class.__name__, activation)
        orig_model = model_class(activation=activation).to(device)
        quant_model = model_class(activation=activation, quantize=True).to(device)
        configure_qat(quant_model)
        
        orig_model.load_state_dict(torch.load(f"best_{model_class.__name__}_{activation}.ckpt"))
        quant_model.load_state_dict(torch.load(f"best_{model_class.__name__}_{activation}_quantized.ckpt"), strict=True)

        orig_acc = evaluate(orig_model, f"{model_class.__name__} {activation}")
        quant_acc = evaluate(quant_model, f"Quantized {model_class.__name__} {activation}")

        results[key] = {
            "orig_acc": orig_acc,
            "quant_acc": quant_acc,
        }


for key, result in results.items():
    print(f"\nModel: {key[0]} with {key[1]}")
    print(f"Original Accuracy: {result['orig_acc']:.2f}%")
    print(f"Quantized Accuracy: {result['quant_acc']:.2f}%")

  orig_model.load_state_dict(torch.load(f"best_{model_class.__name__}_{activation}.ckpt"))
  quant_model.load_state_dict(torch.load(f"best_{model_class.__name__}_{activation}_quantized.ckpt"), strict=True)


Accuracy of LeNet5 relu on test images: 71.82%
Accuracy of Quantized LeNet5 relu on test images: 52.70%
Accuracy of LeNet5 hardtanh on test images: 65.54%
Accuracy of Quantized LeNet5 hardtanh on test images: 56.63%
Accuracy of LeNet5 relu6 on test images: 69.21%
Accuracy of Quantized LeNet5 relu6 on test images: 50.74%
Accuracy of ResNet20 relu on test images: 89.39%
Accuracy of Quantized ResNet20 relu on test images: 49.05%
Accuracy of ResNet20 hardtanh on test images: 85.53%
Accuracy of Quantized ResNet20 hardtanh on test images: 65.95%
Accuracy of ResNet20 relu6 on test images: 88.78%
Accuracy of Quantized ResNet20 relu6 on test images: 53.99%
Accuracy of ResNet18 relu on test images: 92.19%
Accuracy of Quantized ResNet18 relu on test images: 56.10%
Accuracy of ResNet18 hardtanh on test images: 88.78%
Accuracy of Quantized ResNet18 hardtanh on test images: 80.73%
Accuracy of ResNet18 relu6 on test images: 92.03%
Accuracy of Quantized ResNet18 relu6 on test images: 63.46%

Model: Le