In [6]:
import os
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import timm
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import copy
from experiments.EMA_for_weights import EMA

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = torch.load("../data/cifar10_cutmix.pt")
images, labels = data

train_loader = DataLoader(
    TensorDataset(images, labels),
    batch_size=64,
    shuffle=True
)



In [7]:
teacher_model = timm.create_model("deit_small_patch16_224", pretrained=False, num_classes=10)
teacher_model.load_state_dict(torch.load("../data/model_weights/deit_s_cifar10_aug.pt"))
teacher_model.to(device)
teacher_model.eval()

class LowRankLinear(nn.Module):
    def __init__(self, in_features, out_features, rank=None, bias=True):
        super().__init__()
        if rank is None:
            rank = min(in_features, out_features) // 4
        self.rank = rank
        self.in_features = in_features
        self.out_features = out_features
        self.U = nn.Parameter(torch.Tensor(in_features, rank))
        self.V = nn.Parameter(torch.Tensor(rank, out_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.U, a=np.sqrt(5))
        nn.init.kaiming_uniform_(self.V, a=np.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.U)
            bound = 1 / np.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        output = input @ self.U @ self.V
        if self.bias is not None:
            output += self.bias
        return output

    @staticmethod
    def from_linear(linear_layer, rank=None):
        W = linear_layer.weight.data.t()
        in_f, out_f = W.shape
        if rank is None:
            rank = min(in_f, out_f) // 4
        rank = min(rank, in_f, out_f)
        low_rank = LowRankLinear(in_f, out_f, rank, bias=(linear_layer.bias is not None))
        U_full, S, V_full = torch.svd(W)
        low_rank.U.data = U_full[:, :rank] * S[:rank].unsqueeze(0)
        low_rank.V.data = V_full.t()[:rank, :]
        if linear_layer.bias is not None:
            low_rank.bias.data = linear_layer.bias.data.clone()
        return low_rank

In [8]:
def create_low_rank_model(model, rank_factor=4):
    new_model = copy.deepcopy(model)
    skip = ['qkv', 'norm', 'cls_token', 'pos_embed']
    replaced = skipped = 0
    modules = dict(new_model.named_modules())
    for name, module in modules.items():
        if not isinstance(module, nn.Linear):
            continue
        if any(k in name for k in skip):
            skipped += 1
            continue
        rank = min(module.in_features, module.out_features) // rank_factor
        if rank < 2:
            skipped += 1
            continue
        low_rank = LowRankLinear.from_linear(module, rank=rank)
        # Встраиваем обратно
        parent_name, attr = (name.rsplit('.', 1) + [''])[:2]
        parent = modules[parent_name] if parent_name else new_model
        setattr(parent, attr, low_rank)
        replaced += 1
    print(f"Replaced {replaced} layers, skipped {skipped}")
    return new_model

base_model = timm.create_model("deit_tiny_patch16_224", pretrained=True, num_classes=10)
student_model = create_low_rank_model(base_model, rank_factor=4)
student_model.to(device)

ema = EMA(student_model, decay=0.999)

class DistillationLoss(nn.Module):
    def __init__(self, alpha=0.5, temperature=4.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.ce = nn.CrossEntropyLoss()
        self.kl = nn.KLDivLoss(reduction='batchmean')
    def forward(self, student_out, teacher_out, labels):
        if labels.ndim > 1:
            labels = labels.argmax(dim=1)
        ce_loss = self.ce(student_out, labels)
        s_logits = student_out / self.temperature
        t_logits = teacher_out / self.temperature
        distill = self.kl(F.log_softmax(s_logits,1), F.softmax(t_logits,1)) * (self.temperature**2)
        return (1-self.alpha)*ce_loss + self.alpha*distill

Replaced 37 layers, skipped 12


In [9]:
criterion = DistillationLoss(alpha=0.7, temperature=4.0)
optimizer = optim.AdamW(student_model.parameters(), lr=5e-5)
epochs = 5

student_model.train()
for epoch in range(1, epochs+1):
    total_loss = total_corr = total_samples = 0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
        x = F.interpolate(x, size=(224,224), mode='bilinear', align_corners=False).to(device)
        y = y.to(device)
        with torch.no_grad():
            t_out = teacher_model(x)
        optimizer.zero_grad()
        s_out = student_model(x)
        loss = criterion(s_out, t_out, y)
        loss.backward()
        optimizer.step()
        ema.update(student_model)
        total_loss += loss.item()
        preds = s_out.argmax(dim=1)
        true = y.argmax(1) if y.ndim>1 else y
        total_corr += (preds == true).sum().item()
        total_samples += y.size(0)
    print(f"[Epoch {epoch}] Loss: {total_loss/len(train_loader):.4f}, Acc: {100*total_corr/total_samples:.2f}%")

Epoch 1/5: 100%|██████████| 782/782 [01:50<00:00,  7.09it/s]


[Epoch 1] Loss: 5.0986, Acc: 42.55%


Epoch 2/5: 100%|██████████| 782/782 [01:50<00:00,  7.06it/s]


[Epoch 2] Loss: 3.4735, Acc: 60.41%


Epoch 3/5: 100%|██████████| 782/782 [01:50<00:00,  7.05it/s]


[Epoch 3] Loss: 2.7883, Acc: 67.26%


Epoch 4/5: 100%|██████████| 782/782 [01:50<00:00,  7.07it/s]


[Epoch 4] Loss: 2.3646, Acc: 71.76%


Epoch 5/5: 100%|██████████| 782/782 [01:50<00:00,  7.07it/s]

[Epoch 5] Loss: 2.0297, Acc: 75.60%





In [10]:
torch.save(student_model.state_dict(), "../data/model_weights/deit_tiny_low_rank.pt")
torch.save(ema.state_dict(), "../data/model_weights/deit_tiny_low_rank_ema.pt")

with torch.no_grad():
    torch.save(teacher_model.cpu().state_dict(), "../data/model_weights/teacher_temp.pt")
    torch.save(base_model.cpu().state_dict(), "../data/model_weights/student_base_temp.pt")
    torch.save(student_model.cpu().state_dict(), "../data/model_weights/student_low_rank_temp.pt")

    teacher_size = os.path.getsize("../data/model_weights/teacher_temp.pt") / (1024 * 1024)
    student_base_size = os.path.getsize("../data/model_weights/student_base_temp.pt") / (1024 * 1024)
    student_low_rank_size = os.path.getsize("../data/model_weights/student_low_rank_temp.pt") / (1024 * 1024)

    print(f"Размер модели учителя: {teacher_size:.2f} МБ")
    print(f"Размер базовой модели студента: {student_base_size:.2f} МБ")
    print(f"Размер низкоранговой модели студента: {student_low_rank_size:.2f} МБ")
    print(f"Сжатие относительно учителя: {teacher_size / student_low_rank_size:.2f}x")
    print(f"Сжатие относительно базовой модели: {student_base_size / student_low_rank_size:.2f}x")


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f"Количество параметров в учителе: {count_parameters(teacher_model):,}")
print(f"Количество параметров в базовом студенте: {count_parameters(base_model):,}")
print(f"Количество параметров в низкоранговом студенте: {count_parameters(student_model):,}")

Размер модели учителя: 82.72 МБ
Размер базовой модели студента: 21.14 МБ
Размер низкоранговой модели студента: 11.03 МБ
Сжатие относительно учителя: 7.50x
Сжатие относительно базовой модели: 1.92x
Количество параметров в учителе: 21,669,514
Количество параметров в базовом студенте: 5,526,346
Количество параметров в низкоранговом студенте: 2,872,552
