# Прунинг [SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer) (50 баллов)

Будем прунить [SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer) для [задачи сегментации людей](https://www.kaggle.com/datasets/laurentmih/aisegmentcom-matting-human-datasets).

## Скачаем вспомогательный код и чекпоинт бейзлайна (не то же, что в первой домашке)

Скачайте архив, распакуйте, распакованные файлы выложите рядом с этим ноутбуком: https://drive.google.com/file/d/1kpWDZeYSWUM8o4TvnfP2yFg4ZDqrm2tv/view?usp=drive_link

### Скачаем датасет (Если остался с 1ой домшки можно переиспользовать)

Датасет находится по ссылке https://drive.google.com/file/d/1YOEDzZvhLb2DS1Yn7p7MSs41ou3ZBXUq/view?usp=sharing

Нужно его скачать и распаковать в папке, в которой находится ноутбук

### Установим библиотеки

Эти из прошлой домашки:

In [18]:
from torch.utils.data import dataloader
!pip install torch transformers datasets tensorboard pillow

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


А эти новые:

In [19]:
!pip install torch_pruning

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [20]:
import os

import typing
import torch

from copy import deepcopy
from evaluate import load as load_metric # alternative for load_metric
from torch import nn
from torch.nn import functional as F
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.auto import tqdm

# utils у нас появились при скачивании вспомогательного кода. При желании можно в них провалиться-поизучать
from utils.data import init_dataloaders
from utils.model import evaluate_model
from utils.model import init_model_with_pretrain

from torch import nn
from transformers.models.segformer.modeling_segformer import SegformerLayer, SegformerEfficientSelfAttention, SegformerDecodeHead, SegformerMixFFN

import torch_pruning as tp

In [21]:
baseline_path = 'runs/baseline_ckpt.pth'
distilled_ckpt = 'runs/distillation/ckpt_2.pth'

In [22]:
# маппинг названия классов и индексов
id2label = {
    0: "background",
    1: "human",
}
label2id = {v: k for k, v in id2label.items()}

Создадим лоадеры:

In [23]:
train_dataloader, valid_dataloader = init_dataloaders(
    root_dir=".",
    batch_size=16,
    num_workers=32,
)



Создадим baseline модель:

In [24]:
baseline_model = init_model_with_pretrain(label2id=label2id, id2label=id2label, pretrain_path=baseline_path).cuda()

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  model = torch.load(pretrain_path, map_location="cpu")["model"]


И сразу отвалидируем:

In [25]:
evaluate_model(baseline_model, valid_dataloader, id2label)

Mean_iou: 0.9859656551860897
Mean accuracy: 0.9929527605498385


{'mean_iou': 0.9859656551860897,
 'mean_accuracy': 0.9929527605498385,
 'overall_accuracy': 0.9929338872037976,
 'per_category_iou': array([0.9860996 , 0.98583171]),
 'per_category_accuracy': array([0.9912905 , 0.99461502])}

Создадим модель после дистилляции (можно использовать модель,полученную в первой домашке):

In [26]:
distilled_model = init_model_with_pretrain(label2id=label2id, id2label=id2label, pretrain_path=distilled_ckpt).cuda()

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Проверим точность:

In [27]:
evaluate_model(distilled_model, valid_dataloader, id2label)

Mean_iou: 0.9779576921804504
Mean accuracy: 0.9888835711945456


{'mean_iou': 0.9779576921804504,
 'mean_accuracy': 0.9888835711945456,
 'overall_accuracy': 0.9888568953962515,
 'per_category_iou': array([0.97815124, 0.97776415]),
 'per_category_accuracy': array([0.98653412, 0.99123302])}

Оценим вычислительную сложность и количество параметров моделей:

In [28]:
input_example = torch.rand(1,3,512,512, device="cuda")

In [29]:
ops, params = tp.utils.count_ops_and_params(baseline_model, input_example)
print(f"Baseline model complexity: {ops/1e6} MMAC, {params/1e6} M params")

Baseline model complexity: 6761.228288 MMAC, 3.714658 M params


In [30]:
ops, params = tp.utils.count_ops_and_params(distilled_model, input_example)
print(f"Distilled model complexity: {ops/1e6} MMAC, {params/1e6} M params")

Distilled model complexity: 5841.819136 MMAC, 2.29821 M params


Проверим, что модель после дистилляции имеет по одному SegformerLayer в block-е:

In [31]:
distilled_model

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

## Magnitude pruning (10 баллов)

Выполните one-shot прунинг модели по L2 норме весов в uniform режиме. Помните, что последний слой желательно не прунить. Поставьте pruning_ratio=0.5

In [32]:
def prune_model_l2(model):
    # check sizes of attentions heads before prunning
    for module in model.modules():
        if isinstance(module, SegformerEfficientSelfAttention):
            actual_channels = module.query.out_features
            print(f"Basic: out_features={actual_channels}, num_heads={ module.num_attention_heads}, head_dim={module.attention_head_size} (must be = {actual_channels / module.num_attention_heads})")
    
    # set importance
    imp = tp.importance.GroupNormImportance(p=2) # use L2 norm
    
    # ignore last layer
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d) and m.out_channels == 2:
            ignored_layers.append(m)
            
    # set pruner params
    pruner = tp.pruner.MetaPruner(
        model,
        input_example,
        importance=imp,
        pruning_ratio=0.5,
        ignored_layers=ignored_layers,
        round_to=8,
    )
    
    # prune
    pruner.step()
    
    return model

pruned_model = prune_model_l2(deepcopy(distilled_model))

Basic: out_features=32, num_heads=1, head_dim=32 (must be = 32.0)
Basic: out_features=64, num_heads=2, head_dim=32 (must be = 32.0)
Basic: out_features=160, num_heads=5, head_dim=32 (must be = 32.0)
Basic: out_features=256, num_heads=8, head_dim=32 (must be = 32.0)


In [33]:
pruned_model

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 16, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((80,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(80, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
     

In [34]:
# Проверим, запускается ли наша запруненная сеть
pruned_model(input_example)

RuntimeError: shape '[1, 16384, 1, 32]' is invalid for input of size 262144

## Почините модельку (20 баллов)

In [150]:
# Проанализируйте лог ошибки, и поймите почему модель перестала запускаться после прунинга
# Подсказка, это связано со слоем внимания и размером голов

def fix_attention_layer(pruned_model):
    for module in pruned_model.modules():
        if isinstance(module, SegformerEfficientSelfAttention):
            actual_channels = module.query.out_features
            print(f"actual_channels = {actual_channels}")
            print(f"module.num_attention_heads {module.num_attention_heads}")
            print(f"module.attention_head_size was {module.attention_head_size}")
            module.attention_head_size = actual_channels // module.num_attention_heads
            print(f"module.attention_head_size now {module.attention_head_size}")
            module.all_head_size = module.num_attention_heads * module.attention_head_size
            print(f"module.all_head_size = {module.all_head_size}")
            # assert actual_channels % module.num_attention_heads == 0, f"actual_channels must be divisible by num_heads"
    return pruned_model

pruned_model_fixed = fix_attention_layer(deepcopy(pruned_model))

actual_channels = 16
module.num_attention_heads 1
module.attention_head_size was 32
module.attention_head_size now 16
module.all_head_size = 16
actual_channels = 32
module.num_attention_heads 2
module.attention_head_size was 32
module.attention_head_size now 16
module.all_head_size = 32
actual_channels = 80
module.num_attention_heads 5
module.attention_head_size was 32
module.attention_head_size now 16
module.all_head_size = 80
actual_channels = 128
module.num_attention_heads 8
module.attention_head_size was 32
module.attention_head_size now 16
module.all_head_size = 128


In [36]:
pruned_model_fixed

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 16, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((80,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(80, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
     

In [37]:
# Убедитесь, что модель запускается после фикса
pruned_model_fixed(input_example)

SemanticSegmenterOutput(loss=None, logits=tensor([[[[ 1.1423,  0.9427,  1.1825,  ...,  0.8781,  0.7674,  0.9538],
          [ 0.6324,  0.7941,  1.2474,  ...,  0.8479,  0.8587,  0.7896],
          [ 0.6281,  0.9249,  0.8642,  ...,  0.7581,  1.0199,  0.9772],
          ...,
          [ 0.4184,  0.6597,  0.6155,  ...,  0.6422,  0.8185,  0.8279],
          [ 0.2564,  0.4469,  0.4740,  ...,  0.6793,  0.7476,  0.8134],
          [ 0.4957,  0.5007,  0.5791,  ...,  0.6984,  0.6923,  0.7908]],

         [[-1.2877, -1.1341, -1.3187,  ..., -0.9192, -0.8073, -0.9824],
          [-0.9034, -1.0037, -1.3637,  ..., -0.8841, -0.9800, -0.8035],
          [-0.8044, -1.0815, -0.9879,  ..., -0.7882, -1.0589, -1.0020],
          ...,
          [-0.4889, -0.6870, -0.6548,  ..., -0.6741, -0.8868, -0.8847],
          [-0.4652, -0.5836, -0.5401,  ..., -0.6897, -0.7704, -0.8392],
          [-0.5292, -0.5369, -0.6092,  ..., -0.7096, -0.7169, -0.7926]]]],
       device='cuda:0', grad_fn=<ConvolutionBackward0>), hi

In [38]:
# Оценим вычислительную сложность получившейся модели
ops, params = tp.utils.count_ops_and_params(pruned_model_fixed, input_example)
print(f"Distilled model complexity (After magnitude pruning): {ops/1e6} MMAC, {params/1e6} M params")

Distilled model complexity (After magnitude pruning): 1497.484032 MMAC, 0.583858 M params


In [39]:
# Попробуем уменьшать модель еще сильнее, запрунив головы в attention.
# Функционал torch pruning это не поддерживает, однако это доступно в transformers
# Для выбора наименее полезных голов можно воспользоваться L2 нормой весов. 
# Мы же тут выкинем все, кроме нулевой.

pruned_model_fixed.segformer.encoder.block[1][0].attention.prune_heads([1])
pruned_model_fixed.segformer.encoder.block[2][0].attention.prune_heads([1,2,3,4])
pruned_model_fixed.segformer.encoder.block[3][0].attention.prune_heads([1,2,3,4,5,6,7])

In [40]:
# Снова оценим вычислительную сложность
ops, params = tp.utils.count_ops_and_params(pruned_model_fixed, input_example)
print(f"Distilled model complexity (After magnitude pruning): {ops/1e6} MMAC, {params/1e6} M params")

Distilled model complexity (After magnitude pruning): 1465.239744 MMAC, 0.50341 M params


##  Дообучение запруненной модели (5 баллов)

In [41]:
# перенесите свой трейновый пайплайн из предыдущей домашки в отдельный файл и воспользуйтесь им
from speed_up_nn.hw_02.utils.train import train, TrainParams

In [42]:
train_params = TrainParams(
    n_epochs=10,
    lr=1e-4,
    batch_size=16,
    n_workers=32,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    temperature=10.0,
    loss_weight=0.5,
    last_layer_loss_weight=0.5,
    intermediate_layers_weights=(0.5, 0.5, 0.5, 0.5),
    intermediate_attn_layers_weights=(0.5, 0.5, 0.5, 0.5),
    intermediate_feat_layers_weights=(0.5, 0.5, 0.5, 0.5),
    warmup_steps=1
)

In [43]:
save_dir = 'runs/magnitude_equal_pruning'
tb_writer = SummaryWriter(save_dir)
student_teacher_attention_mapping = {
    0: 0,
    1: 2,
    2: 4,
    3: 6,
}

In [44]:
train(
    teacher_model=baseline_model,
    student_model=deepcopy(pruned_model_fixed),
    train_params=train_params,
    student_teacher_attention_mapping=student_teacher_attention_mapping,
    tb_writer=tb_writer,
    save_dir=save_dir,
)

total loss: 0.440: 100%|██████████| 1129/1129 [05:33<00:00,  3.39it/s]


Mean_iou: 0.8043854644353918
Mean accuracy: 0.8919619601564333


total loss: 0.237: 100%|██████████| 1129/1129 [05:35<00:00,  3.37it/s]


Mean_iou: 0.8582472163625408
Mean accuracy: 0.9242506007229528


total loss: 0.174: 100%|██████████| 1129/1129 [05:27<00:00,  3.44it/s]


Mean_iou: 0.8865743980982412
Mean accuracy: 0.9398130361226941


total loss: 0.126: 100%|██████████| 1129/1129 [05:27<00:00,  3.45it/s]


Mean_iou: 0.9025849495382567
Mean accuracy: 0.9486952646177247


total loss: 0.090: 100%|██████████| 1129/1129 [05:25<00:00,  3.47it/s]


Mean_iou: 0.9144972774387226
Mean accuracy: 0.9554812447561024


total loss: 0.084: 100%|██████████| 1129/1129 [05:23<00:00,  3.49it/s]


Mean_iou: 0.9234864218574639
Mean accuracy: 0.960328962222267


total loss: 0.091: 100%|██████████| 1129/1129 [05:23<00:00,  3.49it/s]


Mean_iou: 0.921074263617037
Mean accuracy: 0.9588204783649343


total loss: 0.094: 100%|██████████| 1129/1129 [05:23<00:00,  3.49it/s]


Mean_iou: 0.9211735711302158
Mean accuracy: 0.958836230959981


total loss: 0.093: 100%|██████████| 1129/1129 [05:23<00:00,  3.49it/s]


Mean_iou: 0.9288159882186029
Mean accuracy: 0.9630465695728686


total loss: 0.115: 100%|██████████| 1129/1129 [05:23<00:00,  3.49it/s]


Mean_iou: 0.9319221630441362
Mean accuracy: 0.9648285064397568


# Taylor pruning (15 баллов)

Далее требуется выполнить прунинг по Taylor критерию важности, и сравнить точности полученных моделей после тюнинга. Уровень прунинга и структуру  (uniform) оставьте такой же, как для L2.

In [153]:
distilled_model = init_model_with_pretrain(label2id=label2id, id2label=id2label, pretrain_path=distilled_ckpt).cuda()

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [152]:
def prune_model_taylor(model):
    imp = tp.importance.TaylorImportance()
    
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, torch.nn.Conv2d) and m.out_channels == 2:
            ignored_layers.append(m)

    pruner = tp.pruner.MetaPruner(
        model,
        input_example,
        importance=imp,
        pruning_ratio=0.5,
        global_pruning=True,
        round_to=8,
        ignored_layers=ignored_layers,
    )

    return pruner

pruner = prune_model_taylor(distilled_model)

Для прунинга по Тейлору необходимо накопить градиенты на весах, они используются для оценки важности каналов

In [142]:
def calibrate_model(model, dataloader, device='cuda', num_batches=32):
    model.train()
    model.to(device)
    
    # clear prev grads
    for param in model.parameters():
        if param.grad is not None:
            param.grad.zero_()
    
    # Run the data through the model and accumulate gradients to obtain the most important weights
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    for idx, batch in pbar:
        if idx >= num_batches:
            break
        inputs = batch['pixel_values'].to(device)
        targets = batch['labels'].to(device)
        
        outputs = model(inputs)
        outputs_resized = torch.nn.functional.interpolate(outputs.logits, size=(512, 512), mode="bilinear", align_corners=False)
        loss = torch.nn.functional.cross_entropy(outputs_resized, targets)
        loss.backward()

    return model

In [143]:
# Обратите внимание, у вас применение прунинга и его создание разнесены по функциям.
def apply_taylor_pruning(pruner):
    pruner.step()
    return pruner.model

In [144]:
calibrated_model = calibrate_model(distilled_model, train_dataloader)
prunned_model_via_taylor = apply_taylor_pruning(pruner)

  3%|▎         | 32/1129 [00:11<06:43,  2.72it/s] 


In [145]:
# Проверим, запускается ли наша запруненная сеть
prunned_model_via_taylor(input_example);

RuntimeError: shape '[1, 16384, 1, 32]' is invalid for input of size 262144

In [146]:
prunned_model_via_taylor

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 16, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 112, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((112,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(112, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
  

Попробуйте тот же фикс, как для прунинга по L2

In [149]:
prunned_model_via_taylor_fixed = fix_attention_layer(prunned_model_via_taylor)

actual_channels = 16
module.attention_head_size was 16
module.attention_head_size now 16
module.num_attention_heads 1
module.all_head_size = 16
actual_channels = 32
module.attention_head_size was 16
module.attention_head_size now 16
module.num_attention_heads 2
module.all_head_size = 32
actual_channels = 80
module.attention_head_size was 16
module.attention_head_size now 16
module.num_attention_heads 5
module.all_head_size = 80
actual_channels = 128
module.attention_head_size was 16
module.attention_head_size now 16
module.num_attention_heads 8
module.all_head_size = 128


In [148]:
# Убедитесь, что модель запускается после фикса
prunned_model_via_taylor_fixed(input_example)

RuntimeError: shape '[1, 112, -1]' is invalid for input of size 22528

In [None]:
# Оценим сложность полученной модели
ops, params = tp.utils.count_ops_and_params(prunned_model_via_taylor_fixed, input_example)
print(f"Distilled model complexity (After taylor pruning): {ops/1e6} MMAC, {params/1e6} M params")

Выполним дообучение, и сравним точности

In [None]:
train_params = TrainParams(
    n_epochs=10,
    lr=1e-4,
    batch_size=16,
    n_workers=32,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    temperature=10.0,
    loss_weight=0.5,
    last_layer_loss_weight=0.5,
    intermediate_layers_weights=(0.5, 0.5, 0.5, 0.5),
    intermediate_attn_layers_weights=(0.5, 0.5, 0.5, 0.5),
    intermediate_feat_layers_weights=(0.5, 0.5, 0.5, 0.5),
    warmup_steps=1
)

In [None]:
save_dir = 'runs/taylor_equal_pruning'
tb_writer = SummaryWriter(save_dir)

In [None]:
train(
    teacher_model=baseline_model,
    student_model=deepcopy(prunned_model_via_taylor_fixed),
    train_params=train_params,
    student_teacher_attention_mapping=student_teacher_attention_mapping,
    tb_writer=tb_writer,
    save_dir=save_dir,
)