In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# settings
# Mini-batch size shared by the train and test DataLoaders.
batch_size = 128
# Number of epochs for both the teacher run and the distilled student run.
epochs = 10
# Number of fine-tuning epochs run after the SVD compression step.
finetune_epochs = 10
# Learning rate passed to every SGD optimizer created below.
lr = 0.01
# Fraction of min(in_features, out_features) kept as the rank of each compressed linear layer.
rank_ratio = 0.3
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data
# Convert images to tensors and normalise each of the three channels with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 train / test splits, downloaded into ./data when missing.
trainset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
testset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Only the training loader shuffles; the test loader keeps a fixed order.
trainloader = DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
testloader = DataLoader(
    dataset=testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# models
class TeacherNet(nn.Module):
    """Teacher CNN: two conv + ReLU + 2x2 max-pool stages, then two linear layers.

    ``fc1`` expects 8*8*128 features per sample, which corresponds to
    3-channel 32x32 inputs (each of the two pools halves the spatial size).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(8*8*128, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10).

        Args:
            x: image tensor of shape (batch, 3, 32, 32); a single unbatched
                (3, 32, 32) image is treated as a batch of one.

        Raises:
            RuntimeError: if the spatial size does not produce the 8*8*128
                features that ``fc1`` expects.
        """
        # Promote an unbatched image to a batch of one so the flatten below
        # can always keep axis 0 as the batch axis.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. A `view(-1, features)` would silently fold any
        # surplus features into the batch axis for wrongly sized inputs;
        # flattening from dim 1 makes such inputs fail loudly in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class StudentNet(nn.Module):
    """Student CNN: same layout as the teacher with half the channels and hidden units.

    ``fc1`` expects 8*8*64 features per sample, which corresponds to
    3-channel 32x32 inputs (each of the two pools halves the spatial size).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(8*8*64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10).

        Args:
            x: image tensor of shape (batch, 3, 32, 32); a single unbatched
                (3, 32, 32) image is treated as a batch of one.

        Raises:
            RuntimeError: if the spatial size does not produce the 8*8*64
                features that ``fc1`` expects.
        """
        # Promote an unbatched image to a batch of one so the flatten below
        # can always keep axis 0 as the batch axis.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. A `view(-1, features)` would silently fold any
        # surplus features into the batch axis for wrongly sized inputs;
        # flattening from dim 1 makes such inputs fail loudly in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# utils
def train_epoch(model, loader, opt, criterion):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss.

    Args:
        model: module to train; it is switched to training mode.
        loader: iterable of ``(inputs, targets)`` batches with a ``len()``.
        opt: optimizer stepped once per batch.
        criterion: callable mapping ``(outputs, targets)`` to a scalar loss.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0.0
    for inputs, targets in tqdm(loader, leave=False):
        # Move the batch to the module-level ``device`` before the forward pass.
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        opt.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        opt.step()

        running_loss += batch_loss.item()
    return running_loss / len(loader)

def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode and all forward passes run without
    gradient tracking.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            # Predicted class is the index of the largest logit per sample.
            predictions = model(inputs).argmax(1)
            correct += (predictions == targets).sum().item()
            seen += targets.size(0)
    return correct / seen

def distill_loss(s_out, t_out, y, T=3.0, alpha=0.7):
    """Blend a temperature-softened KL term with hard-label cross-entropy.

    Args:
        s_out: student logits, classes along dim 1.
        t_out: teacher logits, same shape as ``s_out``.
        y: ground-truth class indices.
        T: temperature dividing both sets of logits before the softmax.
        alpha: weight of the KL term; the cross-entropy term gets ``1 - alpha``.

    Returns:
        ``alpha * KL * T**2 + (1 - alpha) * CE`` as a scalar tensor.
    """
    soft_targets = F.softmax(t_out / T, dim=1)
    student_log_probs = F.log_softmax(s_out / T, dim=1)
    # The KL term is multiplied by T^2, as in the usual distillation recipe.
    soft_term = F.kl_div(student_log_probs, soft_targets, reduction="batchmean") * (T * T)
    hard_term = F.cross_entropy(s_out, y)
    return alpha * soft_term + (1 - alpha) * hard_term

def count_total_params(m):
    """Return the total number of scalar elements across ``m.parameters()``."""
    n_params = 0
    for tensor in m.parameters():
        n_params += tensor.numel()
    return n_params

def count_linear_params(m):
    """Return the number of scalar parameters held by ``nn.Linear`` modules in ``m``.

    Every module yielded by ``m.modules()`` (``m`` itself included) is
    inspected; parameters of non-linear modules are ignored.
    """
    return sum(
        param.numel()
        for sub in m.modules()
        if isinstance(sub, nn.Linear)
        for param in sub.parameters()
    )

# SVD compression (device-safe)
def svd_compress_linear(layer: nn.Linear, rank_ratio: float):
    """Factor a linear layer into two smaller ones via truncated SVD.

    With ``W = U @ diag(S) @ Vh`` the layer ``x -> W x + b`` is replaced by
    ``first`` (weight ``Vh[:r]``, no bias) followed by ``second``
    (weight ``U[:, :r] * S[:r]``, bias ``b``).

    Args:
        layer: the ``nn.Linear`` to compress; it is not modified.
        rank_ratio: fraction of ``min(out_features, in_features)`` to keep.
            The resulting rank is clamped to ``[1, min(out, in)]``.

    Returns:
        Tuple ``(new_seq, orig_params, new_params, r)``: the replacement
        ``nn.Sequential``, parameter counts before and after, and the rank
        actually used.
    """
    W = layer.weight.detach()
    out_features, in_features = W.shape
    max_rank = min(out_features, in_features)
    # Clamp the rank: a ratio above 1 would otherwise create layers whose
    # declared feature sizes disagree with the (shorter) SVD factors.
    r = max(1, min(max_rank, int(rank_ratio * max_rank)))

    U, S, Vh = torch.linalg.svd(W, full_matrices=False)

    # A bias-free layer stays bias-free so compression never adds parameters.
    has_bias = layer.bias is not None
    # first: in -> r, weight shape (r, in)
    first = nn.Linear(in_features, r, bias=False, dtype=W.dtype, device=W.device)
    # second: r -> out, weight shape (out, r)
    second = nn.Linear(r, out_features, bias=has_bias, dtype=W.dtype, device=W.device)

    with torch.no_grad():
        # copy_ writes into the layers' own storage, so the parameters do not
        # alias (and keep alive) the full-size SVD factors.
        first.weight.copy_(Vh[:r, :])
        # Scaling the columns of U by S equals U_r @ diag(S_r) without
        # materialising the dense diagonal matrix.
        second.weight.copy_(U[:, :r] * S[:r])
        if has_bias:
            second.bias.copy_(layer.bias)

    new_seq = nn.Sequential(first, second)
    orig_params = sum(p.numel() for p in layer.parameters())
    new_params = sum(p.numel() for p in new_seq.parameters())
    return new_seq, orig_params, new_params, r

def compress_model_linear_only(model: nn.Module, rank_ratio: float):
    """Replace every ``nn.Linear`` under ``model`` with its SVD factorisation, in place.

    Children that are not linear layers are searched recursively.

    Returns:
        A list of ``(name, in_features, out_features, rank, params_before,
        params_after)`` tuples, one per replaced layer. ``name`` is the
        attribute name on the direct parent, not a dotted path.
    """
    report = []
    for child_name, child in model.named_children():
        if not isinstance(child, nn.Linear):
            # Not a linear layer: look for linear layers nested inside it.
            report += compress_model_linear_only(child, rank_ratio)
            continue

        factorised, n_before, n_after, rank = svd_compress_linear(child, rank_ratio)
        setattr(model, child_name, factorised)
        report.append(
            (child_name, child.in_features, child.out_features, rank, n_before, n_after)
        )
    return report

# training
# Teacher: SGD with momentum, optimised with plain cross-entropy.
teacher = TeacherNet().to(device)
opt_t = torch.optim.SGD(teacher.parameters(), lr=lr, momentum=0.9)
crit = nn.CrossEntropyLoss()

print("обучаем teacher")
for e in range(epochs):
    loss = train_epoch(teacher, trainloader, opt_t, crit)
    # Accuracy is measured on the test loader after every epoch.
    acc  = evaluate(teacher, testloader)
    print(f"epoch {e+1}/{epochs}, loss={loss:.4f}, acc={acc:.4f}")
# Final teacher metrics, kept for the summary printed at the end.
teacher_acc    = evaluate(teacher, testloader)
teacher_params = count_total_params(teacher)

# Student: trained against the teacher's logits with the distillation loss.
student = StudentNet().to(device)
opt_s = torch.optim.SGD(student.parameters(), lr=lr, momentum=0.9)

print("обучаем student с distillation")
for e in range(epochs):
    # Only the student is trained; the teacher stays in eval mode.
    student.train(); teacher.eval()
    total = 0.0
    for x,y in tqdm(trainloader, leave=False):
        x,y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt_s.zero_grad()
        # Teacher logits are computed without gradient tracking, so the
        # backward pass below only reaches the student.
        with torch.no_grad():
            t_out = teacher(x)
        s_out = student(x)
        loss  = distill_loss(s_out, t_out, y, T=3.0, alpha=0.7)
        loss.backward()
        opt_s.step()
        total += loss.item()
    acc = evaluate(student, testloader)
    # The reported loss is the mean distillation loss per batch.
    print(f"epoch {e+1}/{epochs}, loss={total/len(trainloader):.4f}, acc={acc:.4f}")

# Snapshot of accuracy and parameter counts (total / linear-only) before compression.
student_acc_before      = evaluate(student, testloader)
student_params_before_T = count_total_params(student)
student_params_before_L = count_linear_params(student)

# compression of the linear layers
print("сжимаем линейные слои svd с rank_ratio=", rank_ratio)
# Replaces the student's nn.Linear modules in place and returns per-layer statistics.
stats = compress_model_linear_only(student, rank_ratio=rank_ratio)

# Parameter counts (total / linear-only) right after compression, before fine-tuning.
student_params_after_T = count_total_params(student)
student_params_after_L = count_linear_params(student)

# short fine-tune after SVD
# A new optimizer is built because compression replaced the linear modules,
# so student.parameters() now yields tensors that opt_s never saw.
opt_f = torch.optim.SGD(student.parameters(), lr=lr, momentum=0.9)
print("finetune student после svd")
for e in range(finetune_epochs):
    # Fine-tuning uses plain cross-entropy (crit), not the distillation loss.
    loss = train_epoch(student, trainloader, opt_f, crit)
    acc  = evaluate(student, testloader)
    print(f"finetune epoch {e+1}/{finetune_epochs}, loss={loss:.4f}, acc={acc:.4f}")

# Accuracy of the compressed student once fine-tuning is finished.
student_acc_after = evaluate(student, testloader)

# report according to the task criteria
# Per-layer breakdown for each linear layer that was compressed.
print("\nподробные слои (только linear):")
for name, inf, outf, r, orig_p, new_p in stats:
    # Guard against division by zero when a layer reports no parameters.
    ratio = orig_p / new_p if new_p>0 else 0.0
    print(f"layer {name}: in={inf} out={outf} k={r} params {orig_p}->{new_p}, сжатие {ratio:.2f}x")

# Aggregate compression over linear-layer parameters only.
print("\nитог по линейным слоям:")
print(f"linear params до: {student_params_before_L}, после: {student_params_after_L}, сжатие: {student_params_before_L/student_params_after_L:.2f}x")

# Whole-model parameter counts for teacher and student.
print("\nсравнение моделей:")
print(f"teacher params (все): {teacher_params}")
print(f"student params (все) до: {student_params_before_T}, после: {student_params_after_T}")
print(f"степень сжатия по всем параметрам student: {student_params_before_T/student_params_after_T:.2f}x")

# Test accuracy of the teacher and of the student before / after compression + fine-tune.
print("\nточность:")
print(f"точность teacher: {teacher_acc:.4f}")
print(f"точность student до сжатия: {student_acc_before:.4f}")
print(f"точность student после сжатия+finetune: {student_acc_after:.4f}")

обучаем teacher




epoch 1/10, loss=1.6366, acc=0.5442




epoch 2/10, loss=1.1786, acc=0.6134




epoch 3/10, loss=0.9720, acc=0.6602




epoch 4/10, loss=0.8320, acc=0.6837




epoch 5/10, loss=0.7138, acc=0.7126




epoch 6/10, loss=0.5960, acc=0.7312




epoch 7/10, loss=0.4914, acc=0.7310




epoch 8/10, loss=0.3748, acc=0.7427




epoch 9/10, loss=0.2718, acc=0.7404




epoch 10/10, loss=0.1838, acc=0.7425
обучаем student с distillation




epoch 1/10, loss=4.6862, acc=0.5946




epoch 2/10, loss=2.1613, acc=0.6774




epoch 3/10, loss=1.4373, acc=0.7112




epoch 4/10, loss=1.0322, acc=0.7282




epoch 5/10, loss=0.7648, acc=0.7277




epoch 6/10, loss=0.5780, acc=0.7388




epoch 7/10, loss=0.4501, acc=0.7351




epoch 8/10, loss=0.3611, acc=0.7367




epoch 9/10, loss=0.3009, acc=0.7377




epoch 10/10, loss=0.2557, acc=0.7400
сжимаем линейные слои svd с rank_ratio= 0.3
finetune student после svd




finetune epoch 1/10, loss=0.9227, acc=0.6769




finetune epoch 2/10, loss=0.6916, acc=0.6978




finetune epoch 3/10, loss=0.5638, acc=0.7045




finetune epoch 4/10, loss=0.4764, acc=0.7008




finetune epoch 5/10, loss=0.3950, acc=0.7046




finetune epoch 6/10, loss=0.3300, acc=0.7005




finetune epoch 7/10, loss=0.2786, acc=0.7037




finetune epoch 8/10, loss=0.2521, acc=0.7028




finetune epoch 9/10, loss=0.2072, acc=0.6891




finetune epoch 10/10, loss=0.1827, acc=0.7011

подробные слои (только linear):
layer fc1: in=4096 out=256 k=76 params 1048832->331008, сжатие 3.17x
layer fc2: in=256 out=10 k=3 params 2570->808, сжатие 3.18x

итог по линейным слоям:
linear params до: 1051402, после: 331816, сжатие: 3.17x

сравнение моделей:
teacher params (все): 4275594
student params (все) до: 1070794, после: 351208
степень сжатия по всем параметрам student: 3.05x

точность:
точность teacher: 0.7425
точность student до сжатия: 0.7400
точность student после сжатия+finetune: 0.7011
