In [1]:
import os
import time
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import accuracy_score
import torch.nn.utils.prune as prune
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
path_to_model = "weights/my_model"

extractor = AutoFeatureExtractor.from_pretrained(path_to_model)
vit_model = AutoModelForImageClassification.from_pretrained(path_to_model)



In [3]:
def model_use(model, img):
    with torch.no_grad():
        logits = model(**img).logits

    predicted_label = logits.argmax(-1).item()

    return model.config.id2label[predicted_label]

In [4]:
path = "data/"
images_list = os.listdir(path)

In [5]:
def size_measurement(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()

    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / (1024 ** 2)
    print('model size: {:.3f}MB'.format(size_all_mb))

In [6]:
# num of parameters
vit_model.num_parameters()

85800194

In [7]:
# Найдем исходный размер модели.
size_measurement(vit_model)

model size: 327.302MB


In [8]:
parameters_to_prune = []

vit_model_copy = copy.deepcopy(vit_model)

# prune several layers
parameters_to_prune.append((vit_model_copy.vit.embeddings.patch_embeddings.projection, 'weight'))
parameters_to_prune.append((vit_model_copy.vit.layernorm, 'weight'))
parameters_to_prune.append((vit_model_copy.classifier, 'weight'))


In [9]:
print(parameters_to_prune)

[(Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16)), 'weight'), (LayerNorm((768,), eps=1e-12, elementwise_affine=True), 'weight'), (Linear(in_features=768, out_features=2, bias=True), 'weight')]


In [10]:
print(
    "Sparsity in vit_model_copy.vit.embeddings.patch_embeddings.projection.weight: {:.2f}%".format(
        100. * float(torch.sum(vit_model_copy.vit.embeddings.patch_embeddings.projection.weight == 0))
        / float(vit_model_copy.vit.embeddings.patch_embeddings.projection.weight.nelement())
    )
)



print(
    "Sparsity in vit_model_copy.vit.layernorm.weight: {:.2f}%".format(
        100. * float(torch.sum(vit_model_copy.vit.layernorm.weight == 0))
        / float(vit_model_copy.vit.layernorm.weight.nelement())
    )
)

print(
    "Sparsity in vit_model_copy.classifier.weight: {:.2f}%".format(
        100. * float(torch.sum(vit_model_copy.classifier.weight == 0))
        / float(vit_model_copy.classifier.weight.nelement())
    )
)


print(  
    "Global sparsity: {:.2f}%".format(  
        100. * float(torch.sum(vit_model_copy.vit.embeddings.patch_embeddings.projection.weight == 0)  
            +  torch.sum(vit_model_copy.vit.layernorm.weight == 0)  
            + torch.sum(vit_model_copy.classifier.weight == 0)  
 
        )  
        / float(vit_model_copy.vit.embeddings.patch_embeddings.projection.weight.nelement()  
            + vit_model_copy.vit.layernorm.weight.nelement()  
            + vit_model_copy.classifier.weight.nelement()  
 
        )  
    )  
) 

Sparsity in vit_model_copy.vit.embeddings.patch_embeddings.projection.weight: 0.00%
Sparsity in vit_model_copy.vit.layernorm.weight: 0.00%
Sparsity in vit_model_copy.classifier.weight: 0.00%
Global sparsity: 0.00%


In [11]:
print(vit_model.prune_heads)

<bound method PreTrainedModel.prune_heads of ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermedia

In [12]:
pruning_method = torch.nn.utils.prune.L1Unstructured

prune.global_unstructured(  
    parameters_to_prune,  
    pruning_method=pruning_method,  
    amount=0.5)    

In [13]:
print(
    "Sparsity in vit_model_copy.vit.embeddings.patch_embeddings.projection.weight: {:.2f}%".format(
        100. * float(torch.sum(vit_model_copy.vit.embeddings.patch_embeddings.projection.weight == 0))
        / float(vit_model_copy.vit.embeddings.patch_embeddings.projection.weight.nelement())
    )
)

print(
    "Sparsity in vit_model_copy.vit.layernorm.weight: {:.2f}%".format(
        100. * float(torch.sum(vit_model_copy.vit.layernorm.weight == 0))
        / float(vit_model_copy.vit.layernorm.weight.nelement())
    )
)

print(
    "Sparsity in vit_model_copy.classifier.weight: {:.2f}%".format(
        100. * float(torch.sum(vit_model_copy.classifier.weight == 0))
        / float(vit_model_copy.classifier.weight.nelement())
    )
)


print(  
    "Global sparsity: {:.2f}%".format(  
        100. * float(torch.sum(vit_model_copy.vit.embeddings.patch_embeddings.projection.weight == 0)  
            +  torch.sum(vit_model_copy.vit.layernorm.weight == 0)  
            + torch.sum(vit_model_copy.classifier.weight == 0)  
 
        )  
        / float(vit_model_copy.vit.embeddings.patch_embeddings.projection.weight.nelement()  
            + vit_model_copy.vit.layernorm.weight.nelement()  
            + vit_model_copy.classifier.weight.nelement()  
 
        )  
    )  
)

Sparsity in vit_model_copy.vit.embeddings.patch_embeddings.projection.weight: 50.01%
Sparsity in vit_model_copy.vit.layernorm.weight: 0.39%
Sparsity in vit_model_copy.classifier.weight: 70.31%
Global sparsity: 50.00%


In [14]:
# num of parameters
vit_model_copy.num_parameters()

85800194

In [15]:
# Найдем размер модели после глобального неструктурированного прунинга.
size_measurement(vit_model_copy)

model size: 329.561MB


In [16]:
start_time = time.time()

# Собака 1, кошка 0.
target_list = []
predict_list = []

for element in images_list:

    image = Image.open(path + element, mode='r', formats=None)

    inputs = extractor(image, return_tensors="pt")
    predict = model_use(vit_model_copy, inputs)
    target = element[:element.find(".")]

    if target == "dog":
        label = 1
    else:
        label = 0

    target_list.append(label)

    if predict == "dogs":
        pr = 1
    else:
        pr = 0

    predict_list.append(pr)

end_time = time.time()

acc = accuracy_score(target_list, predict_list)
print("Точность модели после прунинга= ", acc)
print("Время обработки изображений модели после прунинга = ", end_time-start_time, " секунд")
print("Скорость обработки изображений у модели после прунинга составила  ", len(images_list)/(end_time-start_time), " картинок в секунду")

Точность модели после прунинга=  0.95
Время обработки изображений модели после прунинга =  879.3829545974731  секунд
Скорость обработки изображений у модели после прунинга составила   0.18194575999399268  картинок в секунду
