In [16]:
import os
import time
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
from PIL import Image
from sklearn.metrics import accuracy_score

In [17]:
path_to_model = "weights/my_model"

extractor = AutoFeatureExtractor.from_pretrained(path_to_model)
vit_model = AutoModelForImageClassification.from_pretrained(path_to_model)



In [18]:
def model_use(model, img):
    model.eval()
    with torch.no_grad():
        logits = model(torch.quantize_per_tensor(img['pixel_values'], 0.1, 10, torch.quint8)).logits
        logits = model(img['pixel_values']).logits
        logits = logits.int_repr().to(torch.float32)

    predicted_label = logits.argmax(-1).item()

    return model.module.config.id2label[predicted_label]

In [19]:
path = "data/"
images_list = os.listdir(path)

In [20]:
def size_measurement(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()

    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / (1024 ** 2)
    print('model size: {:.3f}MB'.format(size_all_mb))

In [21]:
# Найдем исходный размер модели.
size_measurement(vit_model)

model size: 327.302MB


In [22]:
# Функция для просмотра всех слоев модели
def get_children(model: torch.nn.Module):
    children = list(model.children())
    flatt_children = []
    if children == []:
        return model
    else:
       for child in children:
            try:
                flatt_children.extend(get_children(child))
            except TypeError:
                flatt_children.append(get_children(child))
    return flatt_children

In [23]:
get_children(vit_model)

[Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16)),
 Dropout(p=0.0, inplace=False),
 Linear(in_features=768, out_features=768, bias=True),
 Linear(in_features=768, out_features=768, bias=True),
 Linear(in_features=768, out_features=768, bias=True),
 Dropout(p=0.0, inplace=False),
 Linear(in_features=768, out_features=768, bias=True),
 Dropout(p=0.0, inplace=False),
 Linear(in_features=768, out_features=3072, bias=True),
 GELUActivation(),
 Linear(in_features=3072, out_features=768, bias=True),
 Dropout(p=0.0, inplace=False),
 LayerNorm((768,), eps=1e-12, elementwise_affine=True),
 LayerNorm((768,), eps=1e-12, elementwise_affine=True),
 Linear(in_features=768, out_features=768, bias=True),
 Linear(in_features=768, out_features=768, bias=True),
 Linear(in_features=768, out_features=768, bias=True),
 Dropout(p=0.0, inplace=False),
 Linear(in_features=768, out_features=768, bias=True),
 Dropout(p=0.0, inplace=False),
 Linear(in_features=768, out_features=3072, bias=True),
 GELUActivati

In [24]:
vit_model

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [25]:
# Статическая квантизация модели

backend = "fbgemm"
vit_model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
quantized_model = torch.quantization.prepare(vit_model, inplace=False)
quantized_model = torch.quantization.convert(quantized_model, inplace=False)
model_static_quantized = torch.quantization.QuantWrapper(quantized_model)

In [26]:
start_time = time.time()

# Собака 1, кошка 0.
target_list = []
predict_list = []

for element in images_list:

    image = Image.open(path + element, mode='r', formats=None)

    inputs = extractor(image, return_tensors="pt")
    predict = model_use(model_static_quantized, inputs)
    target = element[:element.find(".")]

    if target == "dog":
        label = 1
    else:
        label = 0

    target_list.append(label) 

    if predict == "dogs":
        pr = 1
    else:
        pr = 0

    predict_list.append(pr)

end_time = time.time()

acc = accuracy_score(target_list, predict_list)
print("Точность квантизированной модели = ", acc)
print("Время обработки изображений квантизированной модели = ", end_time-start_time, " секунд")
print("Скорость обработки изображений у квантизированной модели составила  ", len(images_list)/(end_time-start_time), " картинок в секунду")

Точность квантизированной модели =  0.5
Время обработки изображений квантизированной модели =  12.914300203323364  секунд
Скорость обработки изображений у квантизированной модели составила   12.389366630862865  картинок в секунду


In [28]:
# Итоговый размер модели

size_measurement(model_static_quantized)

model size: 0.727MB
