In [1]:
import time
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torch.nn.utils import prune
from torchsummary import summary

In [2]:
BATCH_SIZE = 128

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load data

In [3]:
train_transforms = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.4914, 0.4822, 0.4465],
                std=[0.2023, 0.1994, 0.2010])
])

val_transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.4914, 0.4822, 0.4465],
                std=[0.2023, 0.1994, 0.2010])
])

In [4]:
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transforms)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


# Init model

In [5]:
class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.block1 = self._make_conv_block(in_ch=3,   out_ch=64)
        self.block2 = self._make_conv_block(in_ch=64,  out_ch=128)
        self.block3 = self._make_conv_block(in_ch=128, out_ch=256)
        self.block4 = self._make_conv_block(in_ch=256, out_ch=512)

        self.classifier = nn.Sequential(
            nn.Linear(512 * 2 * 2, 2048),  # из 512x2x2 -> 2048
            nn.ReLU(inplace=True),
            nn.Linear(2048, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes)
        )

    def _make_conv_block(self, in_ch, out_ch):
        block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(kernel_size=2)
        )
        return block

    def forward(self, x):
        x = self.block1(x) 
        x = self.block2(x)  
        x = self.block3(x)  
        x = self.block4(x) 

        x = x.view(x.size(0), -1)  
        x = self.classifier(x)
        return x

model = CNN(num_classes=10)
model.load_state_dict(torch.load('./model.pt', weights_only=True))

_ = summary(model, input_size=(BATCH_SIZE, 3, 32, 32), device=DEVICE, depth=4)

Layer (type:depth-idx)                   Param #
├─Sequential: 1-1                        --
|    └─Conv2d: 2-1                       1,728
|    └─BatchNorm2d: 2-2                  128
|    └─ReLU: 2-3                         --
|    └─Conv2d: 2-4                       36,864
|    └─BatchNorm2d: 2-5                  128
|    └─ReLU: 2-6                         --
|    └─MaxPool2d: 2-7                    --
├─Sequential: 1-2                        --
|    └─Conv2d: 2-8                       73,728
|    └─BatchNorm2d: 2-9                  256
|    └─ReLU: 2-10                        --
|    └─Conv2d: 2-11                      147,456
|    └─BatchNorm2d: 2-12                 256
|    └─ReLU: 2-13                        --
|    └─MaxPool2d: 2-14                   --
├─Sequential: 1-3                        --
|    └─Conv2d: 2-15                      294,912
|    └─BatchNorm2d: 2-16                 512
|    └─ReLU: 2-17                        --
|    └─Conv2d: 2-18                      589,

# Optimization

## Quantization

In [6]:
quantized_model = torch.quantization.quantize_dynamic(model, {nn.Conv2d, nn.Linear}, dtype=torch.qint8)

## Pruning

In [7]:
def apply_pruning(model, amount=0.3):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.ln_structured(module, name='weight', amount=amount, n=1, dim=0)
        elif isinstance(module, nn.Linear):
            prune.ln_structured(module, name='weight', amount=amount, n=1, dim=1)

def remove_pruning_masks(model):
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.remove(module, 'weight')

pruned_model = CNN(num_classes=10)
pruned_model.load_state_dict(torch.load('./model.pt', weights_only=True))
apply_pruning(pruned_model)
remove_pruning_masks(pruned_model)

# Eval

## Quantization

In [8]:
param_size = 0
for param in quantized_model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in quantized_model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

size_all_mb = (param_size + buffer_size) / 1024**2
print('Model size: {:.3f}MB'.format(size_all_mb))

Model size: 17.895MB


In [9]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

params = count_parameters(quantized_model)
print("Params:", params)

Params: 4687296


In [10]:
inp = torch.randn(1, 3, 32, 32)

num_samples = 100
start_time = time.time()
for _ in range(num_samples):
    output = quantized_model(inp)
end_time = time.time()

infer_time = ((end_time - start_time) / num_samples) * 1000
print(f'Avg inference time: {infer_time:.4f} ms')

Avg inference time: 1.9107 ms


In [11]:
correct = 0
total = 0
with torch.no_grad():
    for X_, y_ in val_loader:
        outputs = quantized_model(X_)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == y_).sum().item()
        total += y_.size(0)

accuracy = 100.0 * correct / total
print("Accuracy:", accuracy)

Accuracy: 87.98


## Pruning

### Размеры остались прежними, потому что torch не пересоздает архитектуру 

In [12]:
param_size = 0
for param in pruned_model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in pruned_model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

size_all_mb = (param_size + buffer_size) / 1024**2
print('Model size: {:.3f}MB'.format(size_all_mb))

Model size: 41.946MB


In [13]:
params = count_parameters(pruned_model)
print("Params:", params)

Params: 10992074


In [14]:
inp = torch.randn(1, 3, 32, 32)

num_samples = 100
start_time = time.time()
for _ in range(num_samples):
    output = pruned_model(inp)
end_time = time.time()

infer_time = ((end_time - start_time) / num_samples) * 1000
print(f'CPU Avg inference time: {infer_time:.4f} ms')

CPU Avg inference time: 2.3608 ms


In [15]:
num_samples = 100

pruned_model.to(DEVICE)
start_time = time.time()
for _ in range(num_samples):
    output = pruned_model(inp.to(DEVICE))
end_time = time.time()

infer_time = ((end_time - start_time) / num_samples) * 1000
print(f'GPU Avg inference time: {infer_time:.4f} ms')

GPU Avg inference time: 1.1596 ms


In [16]:
correct = 0
total = 0
with torch.no_grad():
    for X_, y_ in val_loader:
        X_, y_ = X_.to(DEVICE), y_.to(DEVICE)
        outputs = pruned_model(X_)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == y_).sum().item()
        total += y_.size(0)

accuracy = 100.0 * correct / total
print("Accuracy:", accuracy)

Accuracy: 65.66
