In [1]:
import timm
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1) Модели
teacher = timm.create_model("deit_small_patch16_224", num_classes=10, pretrained=False)
ckpt = torch.load("../data/model_weights/deit_small_cifar10.pth", map_location='cpu')
teacher.load_state_dict(ckpt["model_state_dict"])
teacher.eval().to(device)

student = timm.create_model("deit_tiny_patch16_224", num_classes=10, pretrained=False)
student.train().to(device)

  from .autonotebook import tqdm as notebook_tqdm


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)


In [11]:
from ffcv.fields.rgb_image import RandomResizedCropRGBImageDecoder
import os
import timm
import torch
import torch.nn.functional as F
from torch.optim import Adam
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import SimpleRGBImageDecoder, NDArrayDecoder
from ffcv.transforms import (
    RandomResizedCrop, RandomHorizontalFlip,
    ToTensor, ToTorchImage, Convert, ToDevice,Cutout
)

# 2) Даталоадер из .beton
data_path = './data/cifar10_cutmix.beton'
assert os.path.exists(data_path), f"{data_path} not found"

# 1) Image pipeline: только FFCV‑де­кодер + FFCV‑трансформы
image_pipeline = [
    RandomResizedCropRGBImageDecoder((224, 224)),
    RandomHorizontalFlip(),
    Cutout(16),                # нарезает «дырки» — эквивалент augment
    ToTensor(),              # делает torch.tensor и нормализует [0,255]→[0,255]
    ToTorchImage(),          # переставляет оси [H,W,C]→[C,H,W]
    ToDevice(device, non_blocking=True),
]

# 2) Label pipeline: soft‑labels
label_pipeline = [
    NDArrayDecoder(),        # читает float32[10] soft‑labels
    ToTensor(),              # делает torch.tensor
    ToDevice(device, non_blocking=True),
]

train_loader = Loader(
    data_path,
    batch_size=128,
    num_workers=4,
    order=OrderOption.RANDOM,
    drop_last=True,
    pipelines={
        'image': image_pipeline,
        'label': label_pipeline
    },
)

In [12]:
# 3) Loss с учётом soft-лейблов и teacher logits
def distillation_loss(student_logits, teacher_logits, soft_labels, alpha=0.5, T=2.0):
    # KL между student и teacher (температурная дистилляция)
    kl_st = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    ) * (T * T)

    # Soft‑CE между student и CutMix‑метками
    ce_soft = F.kl_div(
        F.log_softmax(student_logits, dim=1),
        soft_labels,
        reduction='batchmean'
    )

    return alpha * kl_st + (1. - alpha) * ce_soft


In [15]:
# 4) Оптимизатор и тренировочный цикл
optimizer = Adam(student.parameters(), lr=3e-4)

for epoch in range(1, 11):
    total_loss = 0.0
    for batch in train_loader:
        imgs = batch[0]
        soft_labels = batch[1]
        #imgs = batch['image']
        #soft_labels = batch['label']    # float32 shape [B,10]

        with torch.no_grad():
            t_logits = teacher(imgs)

        s_logits = student(imgs)
        loss = distillation_loss(s_logits, t_logits, soft_labels, alpha=0.7, T=2.0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg = total_loss / len(train_loader)
    print(f"Epoch {epoch:2d} — avg loss: {avg:.4f}")

RuntimeError: Input type (unsigned char) and bias type (float) should be the same