# Дистилляция [SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)

Будем дистиллировать SegFormer для [задачи сегментации людей](https://www.kaggle.com/datasets/laurentmih/aisegmentcom-matting-human-datasets).

Сам датасет для [сегментации](https://drive.google.com/file/d/1YOEDzZvhLb2DS1Yn7p7MSs41ou3ZBXUq/view?usp=sharing).

### Установим библиотеки

Установим рекомендованные библиотеки

In [None]:
# Install transformers (SegFormer), TensorBoard (logging), Pillow and ipywidgets
# (presumably for the tqdm.auto notebook progress bars used below).
!pip3 install transformers tensorboard pillow ipywidgets --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m64.4 MB/s[0m eta [36m0:00:00[0m
[?25h

С библиотеками ниже возникают трудности, поэтому установим их с зафиксированными версиями.


In [None]:
# Pin an older `datasets` release: `load_metric` (imported below) is not
# available in newer versions of the library.
!pip3 install  datasets==v2.11.0 --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/468.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 kB[0m [31m30.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/110.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/134.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/194.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.8/194.8 kB[0m [31m16.9 MB/s[0m eta [36m0:00:00[0m
[?25h

Проверим наличие и версию CUDA toolkit (компилятора `nvcc`) в системе

In [None]:
# Print the installed CUDA toolkit (nvcc compiler) version.
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0


Установим PyTorch, собранный под подходящую версию CUDA

In [None]:
# Install the CUDA 12.1 builds of torch/torchvision/torchaudio.
# NOTE(review): the next cell reports 2.5.1+cu124, so the already-installed
# build seems to be the one actually in use — confirm.
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121/ --quiet

In [None]:
import torch

# Sanity check: is a CUDA device visible, and which torch build is installed?
print(torch.cuda.is_available())
print(torch.__version__)

True
2.5.1+cu124


Для колаба

In [None]:
# Colab only: mount Google Drive so the project folder used below is reachable.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
# Work from the project folder: `utils` is imported from here and data/checkpoint
# paths below are relative to it.
%cd /content/drive/MyDrive/distill/

/content/drive/MyDrive/distill


### Работа с кодом

Не люблю смотреть на FutureWarnings

In [None]:
# Silence FutureWarning messages to keep the notebook output readable.
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
import os

import typing as tp
import torch

from copy import deepcopy
from datasets import load_metric
from torch import nn
from torch.nn import functional as F
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.auto import tqdm

# utils у нас появились при скачивании вспомогательного кода. При желании можно в них провалиться-поизучать
from utils.data import init_dataloaders
from utils.model import evaluate_model
from utils.model import init_model_with_pretrain

from torch import nn
from transformers.models.segformer.modeling_segformer import SegformerLayer

In [None]:
# Teacher checkpoint, relative to the working directory.
teacher_path = 'runs/baseline_ckpt.pth'
# TensorBoard logs and student checkpoints are written here.
save_dir = 'runs/distillation'

In [None]:
# Global TensorBoard writer; `train` logs its losses and eval metrics through it.
tb_writer = SummaryWriter(save_dir)

In [None]:
# маппинг названия классов и индексов
# Class index <-> class name mappings (index 0 is the background).
id2label = dict(enumerate(("background", "human")))
label2id = {name: idx for idx, name in id2label.items()}

Создадим лоадеры:

In [None]:
# Loaders built from the current directory. NOTE: `train` creates its own
# loaders from TrainParams; at notebook level only `valid_dataloader` is used
# (to evaluate the teacher).
train_dataloader, valid_dataloader = init_dataloaders(
    root_dir=".",
    batch_size=16,
    num_workers=8,
)



Создадим модель учителя:

In [None]:
# Teacher SegFormer; presumably init_model_with_pretrain (utils.model) builds the
# architecture and loads the fine-tuned weights from `teacher_path` — confirm there.
teacher_model = init_model_with_pretrain(label2id=label2id, id2label=id2label, pretrain_path=teacher_path)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


И сразу отвалидируем:

In [None]:
# Validation metrics of the teacher; the returned dict is rendered as the cell output.
evaluate_model(teacher_model, valid_dataloader, id2label)



Mean_iou: 0.9859686841316206
Mean accuracy: 0.9929542821587765


{'mean_iou': 0.9859686841316206,
 'mean_accuracy': 0.9929542821587765,
 'overall_accuracy': 0.9929354231878622,
 'per_category_iou': array([0.98610262, 0.98583475]),
 'per_category_accuracy': array([0.99129329, 0.99461527])}

## Делаем ученика

Посмотрим, как выглядит модель учителя:

Также можем воспользоваться сайтом https://netron.app/

In [None]:
# Display the teacher's module tree (the cell's last expression is rendered).
teacher_model

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

Нас интересует `(block)`: он состоит из нескольких `ModuleList`, и нам нужны первые четыре (по одному на каждую стадию энкодера). Посмотрим на первый из них:

```
(0): ModuleList(
          (0): SegformerLayer(
            .... тут много понаписано
          )
          (1): SegformerLayer(
            .... и тут тоже много всего
        )
```

В каждом из четырёх ModuleList сидит по два `SegformerLayer`. Напишем функцию, которая оставит только один (последний) из них.

In [None]:
def create_small_network(model, keep_index: int = -1):
    """Keep a single ``SegformerLayer`` in every encoder stage of ``model``.

    The model is modified in place and returned, so pass a copy (for example
    ``deepcopy(teacher_model)``) if the original must stay intact.

    Args:
        model: SegFormer model exposing ``model.segformer.encoder.block``, a
            sequence of per-stage ``nn.ModuleList`` objects.
        keep_index: index of the layer kept in each stage. The default ``-1``
            keeps the last layer.

    Returns:
        The same ``model`` object with the pruned encoder.
    """
    blocks = model.segformer.encoder.block

    # Iterate over a snapshot: the loop replaces entries of `blocks`.
    # Stages are counted on the model being pruned, not on a global.
    for stage_idx, stage in enumerate(list(blocks)):
        blocks[stage_idx] = nn.ModuleList([stage[keep_index]])

    # Keep the HF config describing the pruned architecture, if there is one.
    config = getattr(model, "config", None)
    if config is not None and hasattr(config, "depths"):
        config.depths = [1] * len(blocks)
    return model

def n_params(model):
    """Return the total number of scalar parameters in ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [None]:
# Prune a deep copy: create_small_network modifies its argument in place,
# so passing a copy keeps `teacher_model` intact.
student_model = create_small_network(deepcopy(teacher_model))

In [None]:
# Compression factor: how many times more parameters the teacher has than the student.
n_params(teacher_model) / n_params(student_model)

In [None]:
# Visualize the student and make sure the intended layers were actually removed.
student_model

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

## Train Loop

Напишем старый добрый цикл обучения (train loop) и добавим в него дистилляционные лоссы.

In [None]:
from dataclasses import dataclass

@dataclass
class TrainParams:
    """Hyper-parameters consumed by ``train``."""
    n_epochs: int  # number of passes over the training loader
    lr: float  # AdamW learning rate
    batch_size: int  # passed to init_dataloaders
    n_workers: int  # passed to init_dataloaders as num_workers
    device: torch.device  # both models are moved here

    loss_weight: float  # multiplier of the student's own loss (outputs.loss from labels)
    last_layer_loss_weight: float  # multiplier of the logits-matching loss
    intermediate_attn_layers_weights: tp.Tuple[float, float, float, float]  # i-th weight for the i-th attention-mapping pair
    intermediate_feat_layers_weights: tp.Tuple[float, float, float, float]  # i-th weight for the i-th hidden state

In [None]:
# Hyper-parameters for the first distillation run (equal 0.5 weights everywhere).
train_params = TrainParams(
    n_epochs=1,
    lr=6e-5,
    batch_size=16,
    n_workers=8,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    loss_weight=0.5,
    last_layer_loss_weight=0.5,
    intermediate_attn_layers_weights=(0.5, 0.5, 0.5, 0.5),
    intermediate_feat_layers_weights=(0.5, 0.5, 0.5, 0.5),
)

In [None]:
def _loss_to_float(value):
    """Convert a loss term (tensor or plain number) to a Python float for logging."""
    return value.item() if torch.is_tensor(value) else float(value)


def train(
    teacher_model,
    student_model,
    train_params: TrainParams,
    student_teacher_attention_mapping,
):
    """Distill ``teacher_model`` into ``student_model``.

    For every batch the student is optimised on::

        loss_weight * student_loss
        + calc_last_layer_loss(...)
        + calc_intermediate_layers_attn_loss(...)   # skipped when it returns None
        + calc_intermediate_layers_feat_loss(...)   # skipped when it returns None

    ``student_loss`` is the loss the student returns for ``labels``. Every loss
    term is logged to TensorBoard through the global ``tb_writer``. After each
    epoch, the student is saved to ``{save_dir}/ckpt_{epoch}.pth``. The saved
    dict holds the whole module, its ``state_dict`` and the optimizer state.
    The student is then evaluated on the validation loader.

    Args:
        teacher_model: frozen teacher; moved to ``train_params.device`` and put
            in eval mode.
        student_model: model being trained; updated in place.
        train_params: hyper-parameters, see ``TrainParams``.
        student_teacher_attention_mapping: ``{student_attn_idx: teacher_attn_idx}``,
            forwarded to the attention distillation loss.
    """
    teacher_model.to(train_params.device)
    student_model.to(train_params.device)

    teacher_model.eval()

    train_dataloader, valid_dataloader = init_dataloaders(
        root_dir=".",
        batch_size=train_params.batch_size,
        num_workers=train_params.n_workers,
    )

    optimizer = torch.optim.AdamW(student_model.parameters(), lr=train_params.lr)
    step = 0
    for epoch in range(train_params.n_epochs):
        # Nothing in the batch loop changes the mode. Only the end-of-epoch
        # evaluation may (presumably evaluate_model calls .eval()), so set
        # train mode once per epoch.
        student_model.train()
        pbar = tqdm(train_dataloader, total=len(train_dataloader))
        for batch in pbar:
            pixel_values = batch['pixel_values'].to(train_params.device)
            labels = batch['labels'].to(train_params.device)

            optimizer.zero_grad()

            # forward + backward + optimize
            student_outputs = student_model(
                pixel_values=pixel_values,
                labels=labels,
                output_attentions=True,
                output_hidden_states=True,
            )
            loss, student_logits = student_outputs.loss, student_outputs.logits

            # The teacher only provides targets. No gradients are needed through
            # it, which saves memory and compute. Its own loss is unused, so
            # labels are not passed.
            with torch.no_grad():
                teacher_output = teacher_model(
                    pixel_values=pixel_values,
                    output_attentions=True,
                    output_hidden_states=True,
                )

            last_layer_loss = calc_last_layer_loss(
                student_logits,
                teacher_output.logits,
                train_params.last_layer_loss_weight,
            )

            intermediate_layer_att_loss = calc_intermediate_layers_attn_loss(
                student_outputs.attentions,
                teacher_output.attentions,
                train_params.intermediate_attn_layers_weights,
                student_teacher_attention_mapping,
            )

            intermediate_layer_feat_loss = calc_intermediate_layers_feat_loss(
                student_outputs.hidden_states,
                teacher_output.hidden_states,
                train_params.intermediate_feat_layers_weights,
            )

            total_loss = loss * train_params.loss_weight + last_layer_loss
            # Disabled distillation terms come back as None.
            for extra_loss in (intermediate_layer_att_loss, intermediate_layer_feat_loss):
                if extra_loss is not None:
                    total_loss = total_loss + extra_loss

            step += 1

            total_loss.backward()
            optimizer.step()
            pbar.set_description(f'total loss: {total_loss.item():.3f}')

            for loss_value, loss_name in (
                (loss, 'loss'),
                (total_loss, 'total_loss'),
                (last_layer_loss, 'last_layer_loss'),
                (intermediate_layer_att_loss, 'intermediate_layer_att_loss'),
                (intermediate_layer_feat_loss, 'intermediate_layer_feat_loss'),
            ):
                if loss_value is None:  # distillation term switched off
                    continue
                tb_writer.add_scalar(
                    tag=loss_name,
                    scalar_value=_loss_to_float(loss_value),
                    global_step=step,
                )

        # The architecture was modified, so save the whole module as well;
        # a state_dict alone cannot be loaded into a fresh, unpruned model.
        torch.save(
            {
                'model': student_model,
                'state_dict': student_model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            },
            f'{save_dir}/ckpt_{epoch}.pth',
        )

        eval_metrics = evaluate_model(student_model, valid_dataloader, id2label)

        for metric_key, metric_value in eval_metrics.items():
            # Only scalar metrics go to TensorBoard; per-category arrays are skipped.
            if not isinstance(metric_value, float):
                continue
            tb_writer.add_scalar(
                tag=f'eval_{metric_key}',
                scalar_value=metric_value,
                global_step=epoch,
            )


### Лосс для дистилляции последних слоёв

Напишем функцию `calc_last_layer_loss`, которая считает лосс между выходами (логитами) последних слоёв учителя и ученика.

In [None]:
# Loss modules shared by the distillation losses.
mse_loss = nn.MSELoss()
# NOTE: with the default reduction='mean' this averages over every element,
# which PyTorch documents as not being the true KL divergence value.
kl_loss = nn.KLDivLoss()


def calc_last_layer_loss(student_logits, teacher_logits, weight):
    """Return ``weight`` times the MSE between student and teacher logits."""
    distance = mse_loss(student_logits, teacher_logits)
    return distance * weight


def calc_intermediate_layers_attn_loss(student_logits, teacher_logits, weights, student_teacher_attention_mapping):
    """Placeholder that disables attention distillation.

    ``train`` skips loss terms that are ``None``.
    """
    return


def calc_intermediate_layers_feat_loss(student_feat, teacher_feat, weights):
    """Placeholder that disables feature-map distillation (returns ``None``)."""
    return

### Включим-посмотрим, как учится

In [None]:
# Release unreferenced Python objects and cached CUDA memory before training.
import gc
import torch
gc.collect()
torch.cuda.empty_cache()

In [None]:
# First run. At this point the attention/feature losses are the stubs that
# return None, so only the student's own loss and the logits loss are optimised.
# A copy of the student is trained, so `student_model` itself stays untouched.
train(
    teacher_model=teacher_model,
    student_model=deepcopy(student_model),
    train_params=train_params,
    student_teacher_attention_mapping={}, # filled in later
)

### Лосс для дистилляции атеншн-мап

Из каждого сегформер-блока можно достать атеншн-мапы:

In [None]:
# Run both models once on a dummy batch to inspect their attention maps.
# Both models are moved to the input's device explicitly: train() moved only
# deep copies of the student. Eval mode keeps dropout-style layers inactive,
# so the maps are deterministic.
dummy_pixel_values = torch.ones(1, 3, 512, 512, device=train_params.device)
teacher_model.to(train_params.device).eval()
student_model.to(train_params.device).eval()
with torch.no_grad():
    teacher_attentions = teacher_model(pixel_values=dummy_pixel_values, output_attentions=True).attentions
    student_attentions = student_model(pixel_values=dummy_pixel_values, output_attentions=True).attentions

In [None]:
# Shape of the first teacher attention map (rendered as the cell output).
teacher_attentions[0].shape

torch.Size([1, 1, 16384, 256])

In [None]:
# Compare the first stage: two teacher layers vs the single student layer.
for header, attention_maps in (
    ("Два слоя учителя", (teacher_attentions[0], teacher_attentions[1])),
    ("Против одного слоя ученика", (student_attentions[0],)),
):
    print(header)
    for attention_map in attention_maps:
        print(attention_map.shape)

Два слоя учителя
torch.Size([1, 1, 16384, 256])
torch.Size([1, 1, 16384, 256])
Против одного слоя ученика
torch.Size([1, 1, 16384, 256])


In [None]:
# Display the raw values of the student's first attention map.
student_attentions[0]

tensor([[[[0.0034, 0.0025, 0.0025,  ..., 0.0039, 0.0039, 0.0039],
          [0.0026, 0.0022, 0.0022,  ..., 0.0040, 0.0040, 0.0040],
          [0.0026, 0.0022, 0.0022,  ..., 0.0040, 0.0040, 0.0040],
          ...,
          [0.0038, 0.0023, 0.0023,  ..., 0.0039, 0.0039, 0.0039],
          [0.0038, 0.0023, 0.0023,  ..., 0.0039, 0.0039, 0.0039],
          [0.0038, 0.0023, 0.0023,  ..., 0.0039, 0.0039, 0.0039]]]],
       device='cuda:0')

In [None]:
# Sanity check: 8 teacher attention maps (2 layers x 4 stages) vs 4 for the
# pruned student (1 layer per stage).
assert len(teacher_attentions) == 8
assert len(student_attentions) == 4

Будем дистиллировать и их!
Но у учителя их целых 8, а у ученика — 4. Поэтому нужно задать соответствие: какую карту внимания ученика
к какой карте внимания учителя мы будем подтягивать.

In [None]:
# Student map i is matched to teacher map 2*i + 1. This assumes the teacher's
# maps are ordered stage by stage, two per stage, so that index is the
# teacher's last layer of stage i. The student kept the last layer of each stage.
student_teacher_attention_mapping = {stage: 2 * stage + 1 for stage in range(4)}

Теперь напишем лосс, который принимает на вход списки карт внимания ученика и учителя и подтягивает первые ко вторым.

In [None]:
def calc_intermediate_layers_attn_loss(
    student_attentions,
    teacher_attentions,
    weights,
    student_teacher_attention_mapping,
    eps: float = 1e-8,
):
    """Weighted KL divergence between matched teacher and student attention maps.

    The pairs ``(student_idx, teacher_idx)`` of
    ``student_teacher_attention_mapping`` are taken in insertion order, and the
    i-th pair is weighted by ``weights[i]``. For each pair this computes
    ``KL(teacher || student)`` along the last axis of the maps. The result is
    averaged over all remaining positions, i.e. over every attention row.
    Assumes every map is a probability distribution over its last axis
    (softmax attention) — TODO confirm for other models.

    This is the name ``train`` calls. It was previously defined only as
    ``calc_intermediate_layers_loss``, which left the ``None`` stub in use.

    Args:
        student_attentions: sequence of student attention maps.
        teacher_attentions: sequence of teacher attention maps.
        weights: per-pair loss weights, indexed by the pair's position.
        student_teacher_attention_mapping: ``{student_idx: teacher_idx}``.
        eps: lower clamp for student probabilities, so ``log`` never yields -inf.

    Returns:
        Scalar tensor, or ``None`` when the mapping is empty. ``train`` then
        skips the term.
    """
    if not student_teacher_attention_mapping:
        return None

    total = None
    for i, (stud_attn_idx, teach_attn_idx) in enumerate(student_teacher_attention_mapping.items()):
        student_log_probs = torch.log(student_attentions[stud_attn_idx].clamp_min(eps))
        teacher_probs = teacher_attentions[teach_attn_idx]
        # reduction='none' gives p * (log p - log q) elementwise. Summing over
        # the key axis yields one true KL value per attention row.
        # reduction='mean' would instead divide by the number of keys too.
        per_row_kl = F.kl_div(student_log_probs, teacher_probs, reduction='none').sum(dim=-1)
        term = weights[i] * per_row_kl.mean()
        total = term if total is None else total + term
    return total


# Previous name of this function; kept so existing calls keep working.
calc_intermediate_layers_loss = calc_intermediate_layers_attn_loss

In [None]:
# Evaluate the attention loss on the dummy-input maps computed above.
calc_intermediate_layers_attn_loss(student_attentions, teacher_attentions, (0.5, 0.5, 0.5, 0.5), student_teacher_attention_mapping)

tensor(-0.0420, device='cuda:0')

### Лосс для дистилляции промежуточных фиче-мап

Помимо внимания, у вас также есть карты признаков, которые можно стягивать.

In [None]:
def calc_intermediate_layers_feat_loss(student_feats, teacher_feats, weights):
    """Weighted MSE between student and teacher hidden states, level by level.

    The i-th student feature map is compared with the i-th teacher feature
    map and scaled by ``weights[i]``.

    Args:
        student_feats: sequence of student hidden states (may be empty or None).
        teacher_feats: sequence of teacher hidden states, at least as long as
            ``student_feats``; extra entries are ignored.
        weights: per-level loss weights.

    Returns:
        Scalar tensor, or ``None`` when there are no student features.
        ``train`` then skips the term instead of calling ``.item()`` on a float.

    Raises:
        ValueError: if the teacher provides fewer feature maps than the student.
    """
    if not student_feats:
        return None
    if len(teacher_feats) < len(student_feats):
        raise ValueError(
            f"teacher has {len(teacher_feats)} feature maps, "
            f"student has {len(student_feats)}"
        )

    total = None
    for i, (student_feat, teacher_feat) in enumerate(zip(student_feats, teacher_feats)):
        # F.mse_loss with its default 'mean' reduction equals nn.MSELoss().
        term = weights[i] * F.mse_loss(student_feat, teacher_feat)
        total = term if total is None else total + term
    return total

### Теперь можем тренировать со стягиванием разных фич

In [None]:
# Hyper-parameters for the second run: the student's own loss is weighted
# higher (0.7) than the logits-matching loss (0.3).
train_params = TrainParams(
    n_epochs=1,
    lr=6e-5,
    batch_size=16,
    n_workers=8,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    loss_weight=0.7,
    last_layer_loss_weight=0.3,
    intermediate_attn_layers_weights=(0.5, 0.5, 0.5, 0.5),
    intermediate_feat_layers_weights=(0.5, 0.5, 0.5, 0.5),
)

In [None]:
# Release unreferenced Python objects and cached CUDA memory before training.
import gc
import torch
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Second run: logits plus intermediate hidden states, and attention maps via
# the mapping. NOTE(review): attention distillation takes effect only if
# `calc_intermediate_layers_attn_loss` has been redefined beyond the stub that
# returns None — confirm before comparing runs.
train(
    teacher_model=teacher_model,
    student_model=deepcopy(student_model),
    train_params=train_params,
    student_teacher_attention_mapping=student_teacher_attention_mapping,
)



  0%|          | 0/1129 [00:00<?, ?it/s]

