In [4]:
import os
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import timm
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from EMA_for_weights import EMA

# Run on the GPU when one is available; models and batches are moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The file stores an (images, labels) pair. map_location="cpu" lets a file that
# was saved from a GPU session load on a CPU-only machine; batches are moved to
# `device` inside the training loop.
data = torch.load("../data/cifar10_cutmix.pt", map_location="cpu")
images, labels = data

# 1-D labels: number of distinct values; otherwise the width of the label matrix.
num_classes = len(labels.unique()) if labels.ndim == 1 else labels.shape[1]

train_loader = DataLoader(
    TensorDataset(images, labels),
    batch_size=64,
    shuffle=True
)

# Teacher: DeiT-small restored from a local checkpoint and kept in eval mode.
# The head size follows the dataset instead of a hard-coded 10, so a mismatch
# with the checkpoint fails loudly in load_state_dict.
teacher_model = timm.create_model("deit_small_patch16_224", pretrained=False, num_classes=num_classes)
teacher_model.load_state_dict(
    torch.load("../data/model_weights/deit_s_cifar10_aug.pt", map_location="cpu")
)
teacher_model.to(device)
teacher_model.eval()

# Student: DeiT-tiny with pretrained backbone weights and a head of the same width
# as the teacher's, so the two logit tensors can be compared class by class.
student_model = timm.create_model("deit_tiny_patch16_224", pretrained=True, num_classes=num_classes)
student_model.to(device)

# NOTE: the student is intentionally NOT quantized here. Dynamic quantization is
# applied after training, on the CPU copy of the trained weights, in the cell
# that saves the checkpoints; quantizing the untrained model at this point only
# produced an unused copy.

class DistillationLoss(nn.Module):
    """Knowledge-distillation loss: hard-label cross-entropy blended with a
    temperature-softened KL term between student and teacher predictions.

    loss = (1 - alpha) * CE(student, labels)
           + alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))

    Args:
        alpha: Weight of the distillation (KL) term; the cross-entropy term
            gets ``1 - alpha``.
        temperature: Softening temperature ``T`` applied to both sets of
            logits before the KL term. Must be positive.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, alpha=0.5, temperature=4.0):
        super().__init__()
        if temperature <= 0:
            # Dividing logits by a non-positive temperature yields inf/NaN or
            # flips the ranking of the classes; fail early instead.
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.alpha = alpha
        self.temperature = temperature
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_outputs, teacher_outputs, labels):
        """Compute the blended loss for one batch.

        Args:
            student_outputs: Student logits; softmax is taken over dim 1.
            teacher_outputs: Teacher logits with the same layout. They are
                treated as constants (detached), so no gradient reaches the
                teacher even if the caller did not use ``torch.no_grad()``.
            labels: Either class indices, or a tensor with more than one
                dimension (e.g. one-hot / mixed targets) which is collapsed to
                the index of its largest entry along dim 1.

        Returns:
            Scalar loss tensor.
        """
        # Multi-dimensional targets are reduced to their dominant class.
        if labels.ndim > 1:
            labels = torch.argmax(labels, dim=1)
        labels = labels.long()
        ce_loss = self.ce_loss(student_outputs, labels)

        student_logits = student_outputs / self.temperature
        # Detach: the teacher is a fixed target, never an optimisation variable.
        teacher_logits = teacher_outputs.detach() / self.temperature
        # KLDivLoss expects log-probabilities as input and probabilities as
        # target. The T^2 factor compensates for the 1/T^2 scaling that the
        # softened softmax introduces into the gradients.
        distill_loss = self.kl_loss(
            F.log_softmax(student_logits, dim=1),
            F.softmax(teacher_logits, dim=1)
        ) * (self.temperature ** 2)

        loss = (1 - self.alpha) * ce_loss + self.alpha * distill_loss
        return loss

In [5]:
# Distillation setup: 70% of the loss comes from matching the teacher's softened
# predictions, 30% from cross-entropy against the labels.
criterion = DistillationLoss(alpha=0.7, temperature=4.0)

# Make the cell safe to re-run after a later cell has moved the models to CPU:
# both calls are no-ops when the models are already on `device`.
teacher_model.to(device).eval()
student_model.to(device)

optimizer = optim.AdamW(student_model.parameters(), lr=5e-5)
# EMA is a project-local helper; presumably it tracks an exponential moving
# average of the student's weights — confirm against EMA_for_weights.
ema = EMA(student_model, decay=0.999)

epochs = 5
student_model.train()

for epoch in range(epochs):
    # Running totals used for the per-epoch summary and the progress bar.
    total_loss = 0
    total_correct = 0
    total_samples = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for x, y in loop:
        # Move the small source batch to the device first, then upsample there:
        # resizing on the CPU would transfer the much larger 224x224 batch and
        # do the interpolation on the host.
        x, y = x.to(device), y.to(device)
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)

        # The teacher only provides targets; no graph is built for it.
        with torch.no_grad():
            teacher_outputs = teacher_model(x)

        optimizer.zero_grad()
        student_outputs = student_model(x)

        loss = criterion(student_outputs, teacher_outputs, y)
        loss.backward()
        optimizer.step()
        # Refresh the averaged weights after every optimizer step.
        ema.update(student_model)

        total_loss += loss.item()
        preds = student_outputs.argmax(dim=1)
        # Labels with more than one dimension are scored against their largest entry.
        if y.ndim > 1:
            y_true = torch.argmax(y, dim=1)
        else:
            y_true = y
        correct = (preds == y_true).sum().item()
        total_correct += correct
        total_samples += y.size(0)

        # Running training accuracy over the batches seen so far in this epoch.
        accuracy = 100 * total_correct / total_samples
        loop.set_postfix(loss=loss.item(), accuracy=f"{accuracy:.2f}%")

    avg_loss = total_loss / len(train_loader)
    avg_accuracy = 100 * total_correct / total_samples
    print(f"[Epoch {epoch+1}] Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_accuracy:.2f}%")

Epoch 1/5: 100%|██████████| 782/782 [01:56<00:00,  6.74it/s, accuracy=76.39%, loss=2.21] 


[Epoch 1] Avg Loss: 2.1087, Avg Accuracy: 76.39%


Epoch 2/5: 100%|██████████| 782/782 [02:00<00:00,  6.49it/s, accuracy=88.84%, loss=0.349]


[Epoch 2] Avg Loss: 0.8599, Avg Accuracy: 88.84%


Epoch 3/5: 100%|██████████| 782/782 [01:56<00:00,  6.70it/s, accuracy=92.36%, loss=0.403]


[Epoch 3] Avg Loss: 0.5830, Avg Accuracy: 92.36%


Epoch 4/5: 100%|██████████| 782/782 [01:56<00:00,  6.69it/s, accuracy=94.55%, loss=0.725]


[Epoch 4] Avg Loss: 0.4367, Avg Accuracy: 94.55%


Epoch 5/5: 100%|██████████| 782/782 [02:02<00:00,  6.39it/s, accuracy=96.03%, loss=0.295]

[Epoch 5] Avg Loss: 0.3454, Avg Accuracy: 96.03%





In [6]:
# Persist the trained student and the state of the EMA helper.
torch.save(student_model.state_dict(), "../data/model_weights/deit_tiny_full_distilled.pt")
torch.save(ema.state_dict(), "../data/model_weights/deit_tiny_full_distilled_ema.pt")

student_model.eval().cpu()  # switch to eval mode and move to CPU
# Dynamic quantization returns a converted copy and leaves `student_model` in
# full precision. Only nn.Linear is listed: nn.Conv2d has no dynamically
# quantized counterpart, so naming it had no effect on the result.
student_model_quantized = torch.quantization.quantize_dynamic(
    student_model,
    {nn.Linear},
    dtype=torch.qint8
)

torch.save(student_model_quantized.state_dict(), "../data/model_weights/deit_tiny_8bit_distilled.pt")

# --- Size report ------------------------------------------------------------
# The two student checkpoints written above are measured directly instead of
# being serialised a second time. The teacher is written to a temporary file
# that is removed once its size is known, so no *_temp.pt files are left behind.
# Note: teacher_model.cpu() moves the teacher to the CPU in place.
teacher_tmp_path = "../data/model_weights/teacher_temp.pt"
torch.save(teacher_model.cpu().state_dict(), teacher_tmp_path)
try:
    teacher_size = os.path.getsize(teacher_tmp_path) / (1024 * 1024)
finally:
    os.remove(teacher_tmp_path)

student_full_size = os.path.getsize("../data/model_weights/deit_tiny_full_distilled.pt") / (1024 * 1024)
student_quant_size = os.path.getsize("../data/model_weights/deit_tiny_8bit_distilled.pt") / (1024 * 1024)

print(f"Размер модели учителя: {teacher_size:.2f} МБ")
print(f"Размер студента (полная точность): {student_full_size:.2f} МБ")
print(f"Размер квантизированной модели студента: {student_quant_size:.2f} МБ")
print(f"Сжатие: {teacher_size/student_quant_size:.2f}x")


Размер модели учителя: 82.72 МБ
Размер студента (полная точность): 21.14 МБ
Размер квантизированной модели студента: 5.98 МБ
Сжатие: 13.84x
