In [1]:
!gdown 1WGaxwrpoBzvS3njzm-rBp80cCqwapyZw # веса модели

Downloading...
From: https://drive.google.com/uc?id=1WGaxwrpoBzvS3njzm-rBp80cCqwapyZw
To: /content/ViT_2.pt
100% 343M/343M [00:07<00:00, 47.5MB/s]


In [None]:
!pip install transformers datasets

In [3]:
import os

import torch
from torch import nn
import torch.nn.utils.prune as prune

from torch.utils.data import Dataset, DataLoader
from torch.ao.quantization.qconfig import default_dynamic_qconfig, float_qparams_weight_only_qconfig

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

from transformers import ViTImageProcessor, ViTForImageClassification
from datasets import load_dataset

from tqdm.notebook import tqdm

In [4]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device

'cpu'

Модель процессора:

In [7]:
!lscpu | grep 'Model name'

Model name:                      Intel(R) Xeon(R) CPU @ 2.20GHz


# Подготовка:

In [5]:
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = nn.Linear(768, 2)

Downloading (…)rocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [6]:
model.load_state_dict(torch.load('ViT_2.pt', map_location=torch.device('cpu')))
model.to(device);

In [8]:
ds = load_dataset('cats_vs_dogs')

Downloading builder script:   0%|          | 0.00/3.33k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.94k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/8.06k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/825M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/23410 [00:00<?, ? examples/s]

In [9]:
indexes = list(range((len(ds['train']))))
train, test = train_test_split(indexes, test_size=0.05, random_state=0)

In [10]:
class CustomDataset(Dataset):
    def __init__(self, ids, dataset):
        self.ids = ids
        self.ds = dataset

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
      image = self.ds['train'][index]['image']
      label = self.ds['train'][index]['labels']

      image = processor(
          image.convert("RGB"),
          return_tensors='pt'
          )

      image['pixel_values'] = image['pixel_values'].squeeze(0)

      return image, label

In [11]:
val_dataset = CustomDataset(
    ids=test,
    dataset=ds
)

In [12]:
val_loader = DataLoader(val_dataset, batch_size=50, shuffle=False, num_workers=2)

# До применения методов:

In [29]:
def get_model_size(model) -> None:
    model_filepath = 'tmp.pth'
    torch.save(model.state_dict(), model_filepath)
    print(
        'Model size (MB): {0:.5f}'.format(
            os.path.getsize(model_filepath) / 1024 ** 2,
        ),
    )

    os.remove(model_filepath)

get_model_size(model)

Model size (MB): 327.36292


In [23]:
val_loss = []
val_targets = []
val_preds = []

with torch.no_grad():
    for i, (batch, targets) in enumerate(tqdm(val_loader)):

        batch = batch.to(device)
        targets = targets.to(device)

        outputs = model(**batch)
        logits = outputs.logits
        val_targets.extend(targets.cpu().numpy())
        val_preds.extend(logits.argmax(axis=1).cpu().numpy())

print('F1:', f1_score(val_targets, val_preds, average='macro'))

  0%|          | 0/24 [00:00<?, ?it/s]

F1: 1.0


In [24]:
%%timeit
outputs = model(**batch)

16.8 s ± 317 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [45]:
model

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

# Динамическая квантизация:

In [97]:
qconfig = {
    nn.Linear: default_dynamic_qconfig,
}

model_int8 = torch.quantization.quantize_dynamic(model, qconfig)

In [98]:
get_model_size(model_int8)

Model size (MB): 84.41269


In [27]:
val_loss = []
val_targets = []
val_preds = []

with torch.no_grad():
  for i, (batch, targets) in enumerate(tqdm(val_loader)):

      batch = batch.to(device)
      targets = targets.to(device)

      outputs = model_int8(**batch)
      logits = outputs.logits
      val_targets.extend(targets.cpu().numpy())
      val_preds.extend(logits.argmax(axis=1).cpu().numpy())

print('F1:', f1_score(val_targets, val_preds, average='macro'))

  0%|          | 0/24 [00:00<?, ?it/s]

F1: 1.0


In [96]:
%%timeit
outputs = model_int8(**batch)

12.6 s ± 1.16 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
model_int8

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (key): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (value): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): DynamicQuantizedLinear(in_features=768, out_fea

# Прунинг:

Неструктурированный:

In [40]:
def prune_module(module: torch.nn.Module) -> None:
    if list(module.children()):
        return None
    named_parameters = [
        parameter_name for parameter_name, _ in module.named_parameters()
    ]
    for parameter_name in named_parameters:
        prune.l1_unstructured(module, name=parameter_name, amount=0.1)
        prune.remove(module, name=parameter_name)


model.apply(prune_module);

Размер после прунинга:

In [41]:
get_model_size(model)

Model size (MB): 327.36292


In [42]:
val_loss = []
val_targets = []
val_preds = []

with torch.no_grad():
  for i, (batch, targets) in enumerate(tqdm(val_loader)):

      batch = batch.to(device)
      targets = targets.to(device)

      outputs = model(**batch)
      logits = outputs.logits
      val_targets.extend(targets.cpu().numpy())
      val_preds.extend(logits.argmax(axis=1).cpu().numpy())

print('F1:', f1_score(val_targets, val_preds, average='macro'))

  0%|          | 0/24 [00:00<?, ?it/s]

F1: 0.4908695652173913


In [153]:
%%timeit
outputs = model(**batch)

15.2 s ± 462 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Структурированный:

In [38]:
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = nn.Linear(768, 2)

In [39]:
model.load_state_dict(torch.load('ViT_2.pt', map_location=torch.device('cpu')))
model.to(device);

In [34]:
for name, module in model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.ln_structured(module, name='weight', amount=0.1, n=1, dim=0)
        prune.remove(module, name='weight')

    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.ln_structured(module, name='weight', amount=0.1, n=1, dim=0)
        prune.remove(module, name='weight')

Размер после прунинга:

In [35]:
get_model_size(model)

Model size (MB): 327.36292


In [36]:
val_loss = []
val_targets = []
val_preds = []

with torch.no_grad():
  for i, (batch, targets) in enumerate(tqdm(val_loader)):

      batch = batch.to(device)
      targets = targets.to(device)

      outputs = model(**batch)
      logits = outputs.logits
      val_targets.extend(targets.cpu().numpy())
      val_preds.extend(logits.argmax(axis=1).cpu().numpy())

print('F1:', f1_score(val_targets, val_preds, average='macro'))

  0%|          | 0/24 [00:00<?, ?it/s]

F1: 0.4935121107266436


In [37]:
%%timeit
outputs = model(**batch)

14.5 s ± 290 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
