## Downloading model

In [147]:
from torchvision.models import vgg16
from torchvision.models import VGG16_Weights

model = vgg16(weights=VGG16_Weights.DEFAULT)

In [173]:
import torch
def calc_weights(model):
    result = 0
    for p in model.features:
        if hasattr(p, "weight"):
            result += p.weight.numel()
            result += p.bias.numel()
    for p in model.classifier:
        if hasattr(p, "weight"):
            result += p.weight.numel()
            result += p.bias.numel()
    return result
def calc_nonzero_weights(model):
    result = 0
    for p in model.features:
        if hasattr(p, "weight"):
            result += torch.sum(p.weight != 0.)
            result += torch.sum(p.bias != 0.)
    for p in model.classifier:
        if hasattr(p, "weight"):
            result += torch.sum(p.weight != 0.)
            result += torch.sum(p.bias != 0.)
    return result

print(calc_weights(model), "elements;", calc_nonzero_weights(model).item(), "non-zero elements")

138357544 elements; 138357544 non-zero elements


In [152]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import max as torch_max

dataset = ImageFolder("ImageNet-Mini/images", transform=VGG16_Weights.DEFAULT.transforms())
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

def evaluate(model):
    correct = 0 
    for test_imgs, test_labels in dataloader:
        output = model(test_imgs)
        predicted = torch_max(output, 1)[1]
        correct += (predicted == test_labels).sum()
    return float(correct) / (len(dataloader))

In [153]:
%%time

print("accuracy:", evaluate(model))

accuracy: 0.6790721386693856
CPU times: user 23min 17s, sys: 1.29 s, total: 23min 18s
Wall time: 24min 41s


In [154]:
from torch import save
import os

def get_size_on_disk(model):
    torch.save(model.state_dict(), "file")
    sz = os.path.getsize("file")
    os.remove("file")
    return sz

print(get_size_on_disk(model), "bytes")

553438751 bytes


## Неструктурированный прунинг для полносвязных и сверточных слоев


In [155]:
import torch.nn.utils.prune as prune
import torch.nn as nn
from copy import deepcopy

class PrunedModel(nn.Module):
    def __init__(self, model):
        super(PrunedModel, self).__init__()
        self.model = deepcopy(model)
    
    def forward(self,X):
        return self.model(X)
    
    def prune(self, rate):
        # Используем l1_unstructured вместо нашего подхода
        # unstructured говорит о том, что нет ограничений на удаляемые веса
        # l1 говорит о том, что нужно смотреть на модуль веса
        for i, p in enumerate(self.model.features):
            if isinstance(p, nn.Linear) or isinstance(p, nn.Conv2d):
                self.model.features[i] = prune.l1_unstructured(p, 'weight', amount=rate)
        for i, p in enumerate(self.model.classifier):
            if isinstance(p, nn.Linear) or isinstance(p, nn.Conv2d):
                self.model.classifier[i] = prune.l1_unstructured(p, 'weight', amount=rate)
        
p_model = PrunedModel(model)
p_model.prune(0.5)

In [177]:
print(calc_weights(p_model.model), "elements;", calc_nonzero_weights(p_model.model).item(), "non-zero elements")

138357544 elements; 69185480 non-zero elements


In [179]:
%%time

print("accuracy:", evaluate(p_model))

accuracy: 0.5962273770073923
CPU times: user 28min 49s, sys: 9min 41s, total: 38min 31s
Wall time: 39min 58s


In [180]:
print(get_size_on_disk(p_model), "bytes")

1106819839 bytes


It works! We have twice less non-zero parameters, and significantly less accuracy (it is expected because we do quite aggressive pruning). However, the computation time also doubled, and the number of bytes on disk too. This is because we also store the mask, not only the weights. Let us remove the mask and compute the metrics again.

In [183]:
from torch.nn.utils.prune import remove
def remove_reparametrisation(model):
    for i, p in enumerate(model.model.features):
        if isinstance(p, nn.Linear) or isinstance(p, nn.Conv2d):
            model.model.features[i] = remove(p, 'weight')
    for i, p in enumerate(model.model.classifier):
        if isinstance(p, nn.Linear) or isinstance(p, nn.Conv2d):
            model.model.classifier[i] = remove(p, 'weight')
            
remove_reparametrisation(p_model)

In [184]:
print(calc_weights(p_model.model), "elements;", calc_nonzero_weights(p_model.model).item(), "non-zero elements")

138357544 elements; 69185480 non-zero elements


In [186]:
%%time

print("accuracy:", evaluate(p_model))

accuracy: 0.5926586795819526
CPU times: user 23min 18s, sys: 727 ms, total: 23min 19s
Wall time: 23min 19s


In [185]:
print(get_size_on_disk(p_model), "bytes")

553439263 bytes


Супер! Теперь время и место на диске не увеличились, идем дальше.

## Динамическая квантизация для полносвязных слоев

(Сверточные не квантизуются из коробки, в чате разрешили не квантизовать)

In [187]:
pq_model = torch.quantization.quantize_dynamic(
    p_model, {nn.Linear}, dtype=torch.qint8
)

In [190]:
import torch
def q_calc_weights(model):
    result = 0
    for module in model.children():
        if type(module) == torch.ao.nn.quantized.dynamic.modules.linear.Linear:
            result += module.weight().numel() + module.bias().numel()
        else:
            result += q_calc_weights(module)
    return result
def q_calc_nonzero_weights(model):
    result = 0
    for module in model.children():
        if type(module) == torch.ao.nn.quantized.dynamic.modules.linear.Linear:
            result += torch.sum(module.weight() != 0.) + torch.sum(module.bias() != 0.)
        else:
            result += q_calc_nonzero_weights(module)
    return result

print(q_calc_weights(pq_model), "elements;", q_calc_nonzero_weights(pq_model).item(), "non-zero elements")

123642856 elements; 61826024 non-zero elements


In [195]:
%%time

print("accuracy:", evaluate(pq_model))

accuracy: 0.6311496303849095
CPU times: user 22min 8s, sys: 596 ms, total: 22min 9s
Wall time: 22min 9s


In [192]:
print(get_size_on_disk(pq_model), "bytes")

182540539 bytes


Вроде все тоже работает - остались нули от прунинга, размер на диске уменьшился где-то на 30% (мы квантизовали только линейные слои, так что это нормально), accuracy после прунинга подросло, а время немного понизилось