In [12]:
from EMA_for_weights import EMA
import os
import json
import random
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms as T
import timm
from tqdm.auto import tqdm
import pandas as pd

# --- Repro/paths/hparams ---
SEED = 42
ISIC_ROOT = os.path.join('data', 'ISIC')
TRAIN_CSV = os.path.join(ISIC_ROOT, 'split_train.csv')
VAL_CSV = os.path.join(ISIC_ROOT, 'split_test.csv')
LABELS_JSON = os.path.join(ISIC_ROOT, 'labels.json')

CKPT_TEACHER = os.path.join('data', 'model_weights', 'deit_s_best.pth')
CKPT_STUDENT = os.path.join('data', 'model_weights', 'student_best.pth')
CKPT_STUDENT_EMA = CKPT_STUDENT.replace('.pth', '_ema.pth')
os.makedirs(os.path.dirname(CKPT_STUDENT), exist_ok=True)

IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 15
BASE_LR = 3e-4
WEIGHT_DECAY = 0.05
NUM_WORKERS = 0  # Windows/Jupyter
KD_T = 4.0
KD_ALPHA = 0.7

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

class ISICCsvDataset(Dataset):
    def __init__(self, csv_path: str, label2idx=None, tfm=None):
        df = pd.read_csv(csv_path)
        if 'path' not in df.columns or 'label' not in df.columns:
            raise ValueError("CSV должен содержать столбцы 'path' и 'label'")
        self.paths = df['path'].tolist()
        self.labels = df['label'].tolist()
        if label2idx is None:
            uniq = sorted(pd.unique(self.labels).tolist())
            label2idx = {lbl: i for i, lbl in enumerate(uniq)}
        self.label2idx = label2idx
        self.tfm = tfm

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        y_str = self.labels[idx]
        y = self.label2idx[y_str]
        with Image.open(p) as img:
            img = img.convert('RGB')
        if self.tfm:
            img = self.tfm(img)
        return img, y

def build_transforms(img_size=224):
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    train_tfm = T.Compose([
        T.RandomResizedCrop(img_size, scale=(0.7, 1.0), ratio=(0.75, 1.33)),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(0.2, 0.2, 0.2, 0.1),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    val_tfm = T.Compose([
        T.Resize(int(img_size * 1.14)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    return train_tfm, val_tfm

def make_weighted_sampler(labels_idx, num_classes):
    counts = np.bincount(labels_idx, minlength=num_classes).astype(np.float32)
    inv = 1.0 / np.maximum(counts, 1.0)
    weights = [inv[y] for y in labels_idx]
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True), counts

@torch.no_grad()
def evaluate(model, loader, device, criterion):
    model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        loss = criterion(logits, y)
        total_loss += loss.item() * y.size(0)
        preds = logits.argmax(1)
        total_correct += (preds == y).sum().item()
        total += y.size(0)
    return total_loss / max(total, 1), total_correct / max(total, 1)

def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    soft = nn.KLDivLoss(reduction="batchmean")(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1)
    ) * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

# --- Data/teacher/student/EMA ---
set_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

with open(LABELS_JSON, 'r', encoding='utf-8') as f:
    label2idx = json.load(f)
idx2label = {v: k for k, v in label2idx.items()}
num_classes = len(label2idx)

train_tfm, val_tfm = build_transforms(IMG_SIZE)
train_df = pd.read_csv(TRAIN_CSV)
train_labels_idx = [label2idx[l] for l in train_df['label'].tolist()]
sampler, class_counts = make_weighted_sampler(train_labels_idx, num_classes)

train_ds = ISICCsvDataset(TRAIN_CSV, label2idx=label2idx, tfm=train_tfm)
val_ds = ISICCsvDataset(VAL_CSV, label2idx=label2idx, tfm=val_tfm)

pin = device.type == 'cuda'
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler,
                          num_workers=NUM_WORKERS, pin_memory=pin, persistent_workers=False)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=pin, persistent_workers=False)

print('Распределение классов (train):', class_counts.tolist())

Device: cuda
Распределение классов (train): [262.0, 411.0, 879.0, 92.0, 890.0, 5364.0, 114.0]


In [13]:
teacher = timm.create_model('deit_small_patch16_224', pretrained=False, num_classes=num_classes)
ckpt_t = torch.load(CKPT_TEACHER, map_location='cpu')
teacher.load_state_dict(ckpt_t['model'], strict=True)
teacher.to(device).eval()
for p in teacher.parameters():
    p.requires_grad_(False)

student = timm.create_model('deit_tiny_patch16_224', pretrained=True, num_classes=num_classes).to(device)
ema = EMA(student, decay=0.999)

criterion_ce = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(student.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))

best_acc = 0.0
best_acc_ema = 0.0

# --- Train KD + EMA ---
for epoch in range(1, EPOCHS + 1):
    student.train()
    running_loss, running_correct, seen = 0.0, 0, 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{EPOCHS}', leave=False)
    for x, y in pbar:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

        with torch.no_grad():
            t_logits = teacher(x)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
            s_logits = student(x)
            loss = kd_loss(s_logits, t_logits, y, T=KD_T, alpha=KD_ALPHA)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # EMA update
        ema.update(student)

        # stats
        bs = y.size(0)
        running_loss += loss.item() * bs
        seen += bs
        preds = s_logits.argmax(1)
        running_correct += (preds == y).sum().item()
        batch_acc = (preds == y).float().mean().item()

        pbar.set_postfix(batch_loss=f'{loss.item():.4f}',
                         batch_acc=f'{batch_acc:.4f}',
                         lr=f"{optimizer.param_groups[0]['lr']:.2e}")

    scheduler.step()
    train_loss = running_loss / max(seen, 1)
    train_acc = running_correct / max(seen, 1)

    val_loss, val_acc = evaluate(student, val_loader, device, criterion_ce)
    val_loss_ema, val_acc_ema = evaluate(ema.ema_model, val_loader, device, criterion_ce)

    print(f'Epoch {epoch:03d} | '
          f'train_loss={train_loss:.4f}  train_acc={train_acc:.4f}  '
          f'val_loss={val_loss:.4f}  val_acc={val_acc:.4f}  '
          f'val_loss_ema={val_loss_ema:.4f}  val_acc_ema={val_acc_ema:.4f}  '
          f'lr={scheduler.get_last_lr()[0]:.2e}')

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({'model': student.state_dict(),
                    'label2idx': label2idx,
                    'epoch': epoch,
                    'val_acc': val_acc}, CKPT_STUDENT)
        print(f'Save best student -> {CKPT_STUDENT}')

    if val_acc_ema > best_acc_ema:
        best_acc_ema = val_acc_ema
        torch.save({'model': ema.ema_model.state_dict(),
                    'label2idx': label2idx,
                    'epoch': epoch,
                    'val_acc': val_acc_ema}, CKPT_STUDENT_EMA)
        print(f'Save best student EMA -> {CKPT_STUDENT_EMA}')

print(f'Best val_acc (student): {best_acc:.4f}')
print(f'Best val_acc (student EMA): {best_acc_ema:.4f}')

                                                                                                               

Epoch 001 | train_loss=4.5493  train_acc=0.5549  val_loss=0.9843  val_acc=0.7144  val_loss_ema=1.6258  val_acc_ema=0.4413  lr=2.97e-04
Save best student -> data\model_weights\student_best.pth
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                               

Epoch 002 | train_loss=2.3536  train_acc=0.7159  val_loss=0.7749  val_acc=0.7424  val_loss_ema=1.0361  val_acc_ema=0.6540  lr=2.87e-04
Save best student -> data\model_weights\student_best.pth
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                               

Epoch 003 | train_loss=1.7870  train_acc=0.7681  val_loss=0.9308  val_acc=0.7054  val_loss_ema=0.7617  val_acc_ema=0.7189  lr=2.71e-04
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                               

Epoch 004 | train_loss=1.3483  train_acc=0.8100  val_loss=1.0153  val_acc=0.7004  val_loss_ema=0.6464  val_acc_ema=0.7609  lr=2.50e-04
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                               

Epoch 005 | train_loss=1.1082  train_acc=0.8298  val_loss=0.7065  val_acc=0.7818  val_loss_ema=0.6133  val_acc_ema=0.7718  lr=2.25e-04
Save best student -> data\model_weights\student_best.pth
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                               

Epoch 006 | train_loss=0.8252  train_acc=0.8676  val_loss=0.6541  val_acc=0.7868  val_loss_ema=0.5915  val_acc_ema=0.7863  lr=1.96e-04
Save best student -> data\model_weights\student_best.pth
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                               

Epoch 007 | train_loss=0.7681  train_acc=0.8702  val_loss=0.6470  val_acc=0.7888  val_loss_ema=0.5792  val_acc_ema=0.7948  lr=1.66e-04
Save best student -> data\model_weights\student_best.pth
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                               

Epoch 008 | train_loss=0.6103  train_acc=0.8895  val_loss=0.6202  val_acc=0.7973  val_loss_ema=0.5735  val_acc_ema=0.7978  lr=1.34e-04
Save best student -> data\model_weights\student_best.pth
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                               

Epoch 009 | train_loss=0.4925  train_acc=0.9033  val_loss=0.6828  val_acc=0.7773  val_loss_ema=0.5707  val_acc_ema=0.8033  lr=1.04e-04
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                                

Epoch 010 | train_loss=0.4323  train_acc=0.9177  val_loss=0.5504  val_acc=0.8273  val_loss_ema=0.5663  val_acc_ema=0.8078  lr=7.50e-05
Save best student -> data\model_weights\student_best.pth
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                                

Epoch 011 | train_loss=0.3914  train_acc=0.9230  val_loss=0.5468  val_acc=0.8233  val_loss_ema=0.5573  val_acc_ema=0.8138  lr=4.96e-05
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                                

Epoch 012 | train_loss=0.3449  train_acc=0.9331  val_loss=0.5685  val_acc=0.8173  val_loss_ema=0.5507  val_acc_ema=0.8173  lr=2.86e-05
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                                

Epoch 013 | train_loss=0.3279  train_acc=0.9392  val_loss=0.5678  val_acc=0.8228  val_loss_ema=0.5483  val_acc_ema=0.8223  lr=1.30e-05
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                                

Epoch 014 | train_loss=0.3138  train_acc=0.9403  val_loss=0.5394  val_acc=0.8362  val_loss_ema=0.5462  val_acc_ema=0.8253  lr=3.28e-06
Save best student -> data\model_weights\student_best.pth
Save best student EMA -> data\model_weights\student_best_ema.pth


                                                                                                                

Epoch 015 | train_loss=0.3152  train_acc=0.9395  val_loss=0.5419  val_acc=0.8333  val_loss_ema=0.5445  val_acc_ema=0.8308  lr=0.00e+00
Save best student EMA -> data\model_weights\student_best_ema.pth
Best val_acc (student): 0.8362
Best val_acc (student EMA): 0.8308


## Quantization (PTQ) and comparison

In [14]:
# python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearWQAT(nn.Module):
    def __init__(self, base_linear: nn.Linear, per_channel: bool = True):
        super().__init__()
        self.in_features = base_linear.in_features
        self.out_features = base_linear.out_features
        self.has_bias = base_linear.bias is not None

        self.weight = nn.Parameter(base_linear.weight.detach().clone())
        self.bias = nn.Parameter(base_linear.bias.detach().clone()) if self.has_bias else None

        from torch.ao.quantization import FakeQuantize, default_per_channel_weight_observer, default_weight_observer
        if per_channel:
            self.weight_fake = FakeQuantize(
                observer=default_per_channel_weight_observer,
                quant_min=-128, quant_max=127,
                dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0
            )
        else:
            self.weight_fake = FakeQuantize(
                observer=default_weight_observer,
                quant_min=-128, quant_max=127,
                dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
            )

    def forward(self, x):
        w_q = self.weight_fake(self.weight)
        return F.linear(x, w_q, self.bias)

In [15]:
import io
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

torch.backends.quantized.engine = 'fbgemm'  # Windows/x86

CKPT_STUDENT = os.path.join('data', 'model_weights', 'student_best.pth')
CKPT_STUDENT_INT8 = CKPT_STUDENT.replace('.pth', '_int8.pth')
CKPT_STUDENT_EMA = CKPT_STUDENT.replace('.pth', '_ema.pth')
CKPT_STUDENT_EMA_INT8 = CKPT_STUDENT_EMA.replace('.pth', '_int8.pth')

@torch.no_grad()
def eval_logits(model, loader, device):
    model.eval().to(device)
    total, correct, total_loss = 0, 0, 0.0
    all_logits, all_labels = [], []
    ce = nn.CrossEntropyLoss()
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        loss = ce(logits, y)
        total_loss += loss.item() * y.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += y.size(0)
        all_logits.append(logits.cpu())
        all_labels.append(y.cpu())
    return torch.cat(all_logits, 0), torch.cat(all_labels, 0), total_loss / max(total, 1), correct / max(total, 1)

def state_dict_size_bytes(state_dict):
    buf = io.BytesIO()
    torch.save(state_dict, buf)
    return buf.tell()

# 1) Load FP32 student and EMA-student
student_fp32 = timm.create_model('deit_tiny_patch16_224', pretrained=False, num_classes=num_classes)
ckpt_s = torch.load(CKPT_STUDENT, map_location='cpu')
student_fp32.load_state_dict(ckpt_s['model'], strict=True)

student_ema_fp32 = timm.create_model('deit_tiny_patch16_224', pretrained=False, num_classes=num_classes)
ckpt_se = torch.load(CKPT_STUDENT_EMA, map_location='cpu')
student_ema_fp32.load_state_dict(ckpt_se['model'], strict=True)

# 2) Evaluate FP32 on CPU
logits_s_fp32, labels_cpu, s_fp32_loss, s_fp32_acc = eval_logits(student_fp32, val_loader, device=torch.device('cpu'))
logits_se_fp32, _, se_fp32_loss, se_fp32_acc = eval_logits(student_ema_fp32, val_loader, device=torch.device('cpu'))

# 3) Dynamic INT8 quantization and eval
student_int8 = torch.ao.quantization.quantize_dynamic(copy.deepcopy(student_fp32), {nn.Linear}, dtype=torch.qint8)
student_ema_int8 = torch.ao.quantization.quantize_dynamic(copy.deepcopy(student_ema_fp32), {nn.Linear}, dtype=torch.qint8)

logits_s_int8, _, s_int8_loss, s_int8_acc = eval_logits(student_int8, val_loader, device=torch.device('cpu'))
logits_se_int8, _, se_int8_loss, se_int8_acc = eval_logits(student_ema_int8, val_loader, device=torch.device('cpu'))

# 4) Outputs and sizes comparison
with torch.no_grad():
    mae_s = (logits_s_fp32 - logits_s_int8.float()).abs().mean().item()
    top1_s = (logits_s_fp32.argmax(1) == logits_s_int8.argmax(1)).float().mean().item()
    kl_s = F.kl_div(F.log_softmax(logits_s_int8.float(), dim=1), F.softmax(logits_s_fp32, dim=1), reduction='batchmean').item()

    mae_se = (logits_se_fp32 - logits_se_int8.float()).abs().mean().item()
    top1_se = (logits_se_fp32.argmax(1) == logits_se_int8.argmax(1)).float().mean().item()
    kl_se = F.kl_div(F.log_softmax(logits_se_int8.float(), dim=1), F.softmax(logits_se_fp32, dim=1), reduction='batchmean').item()

sd_s_fp32 = student_fp32.state_dict()
sd_s_int8 = student_int8.state_dict()
sd_se_fp32 = student_ema_fp32.state_dict()
sd_se_int8 = student_ema_int8.state_dict()

size_s_fp32 = state_dict_size_bytes(sd_s_fp32)
size_s_int8 = state_dict_size_bytes(sd_s_int8)
size_se_fp32 = state_dict_size_bytes(sd_se_fp32)
size_se_int8 = state_dict_size_bytes(sd_se_int8)

# 5) Save INT8 checkpoints
torch.save({'model': sd_s_int8, 'label2idx': label2idx, 'from': CKPT_STUDENT}, CKPT_STUDENT_INT8)
torch.save({'model': sd_se_int8, 'label2idx': label2idx, 'from': CKPT_STUDENT_EMA}, CKPT_STUDENT_EMA_INT8)

# 6) Print
print('=== Сравнение FP32 vs INT8 (CPU) — Student ===')
print(f'val_acc FP32: {s_fp32_acc:.4f} | INT8: {s_int8_acc:.4f} | Δ: {s_int8_acc - s_fp32_acc:+.4f}')
print(f'val_loss FP32: {s_fp32_loss:.4f} | INT8: {s_int8_loss:.4f} | Δ: {s_int8_loss - s_fp32_loss:+.4f}')
print(f'Top-1 совпадение: {top1_s:.4f} | MAE logits: {mae_s:.6f} | KL(fp32||int8): {kl_s:.6f}')
print(f'Размер FP32: {size_s_fp32/1024/1024:.2f} MB | INT8: {size_s_int8/1024/1024:.2f} MB '
      f'({(size_s_int8/size_s_fp32)*100:.1f}% от FP32)')
print(f'INT8 сохранён: {CKPT_STUDENT_INT8}')

print('\n=== Сравнение FP32 vs INT8 (CPU) — Student EMA ===')
print(f'val_acc FP32: {se_fp32_acc:.4f} | INT8: {se_int8_acc:.4f} | Δ: {se_int8_acc - se_fp32_acc:+.4f}')
print(f'val_loss FP32: {se_fp32_loss:.4f} | INT8: {se_int8_loss:.4f} | Δ: {se_int8_loss - se_fp32_loss:+.4f}')
print(f'Top-1 совпадение: {top1_se:.4f} | MAE logits: {mae_se:.6f} | KL(fp32||int8): {kl_se:.6f}')
print(f'Размер FP32: {size_se_fp32/1024/1024:.2f} MB | INT8: {size_se_int8/1024/1024:.2f} MB '
      f'({(size_se_int8/size_se_fp32)*100:.1f}% от FP32)')
print(f'INT8 сохранён: {CKPT_STUDENT_EMA_INT8}')

=== Сравнение FP32 vs INT8 (CPU) — Student ===
val_acc FP32: 0.8362 | INT8: 0.8337 | Δ: -0.0025
val_loss FP32: 0.5394 | INT8: 0.5437 | Δ: +0.0043
Top-1 совпадение: 0.9870 | MAE logits: 0.065218 | KL(fp32||int8): 0.001641
Размер FP32: 21.13 MB | INT8: 5.97 MB (28.3% от FP32)
INT8 сохранён: data\model_weights\student_best_int8.pth

=== Сравнение FP32 vs INT8 (CPU) — Student EMA ===
val_acc FP32: 0.8308 | INT8: 0.8288 | Δ: -0.0020
val_loss FP32: 0.5445 | INT8: 0.5456 | Δ: +0.0011
Top-1 совпадение: 0.9870 | MAE logits: 0.065542 | KL(fp32||int8): 0.002174
Размер FP32: 21.13 MB | INT8: 5.97 MB (28.3% от FP32)
INT8 сохранён: data\model_weights\student_best_ema_int8.pth


## QAT-lite (fake-quant weights of nn.Linear) and PTQ vs QAT-lite comparison

In [16]:
import copy
import torch
import torch.nn as nn
import timm

# Пути для QAT-lite
CKPT_STUDENT_QAT = CKPT_STUDENT.replace('.pth', '_qatlite.pth')
CKPT_STUDENT_QAT_INT8 = CKPT_STUDENT_QAT.replace('.pth', '_int8.pth')

# Гиперпараметры QAT-lite
QAT_EPOCHS = 5  # 3–10
QAT_LR = 3e-5  # малый LR
QAT_WD = 0.0
QAT_PER_CHANNEL = True  # per-channel fake-quant


def wrap_linears_with_wqat(module: nn.Module, per_channel=True):
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, LinearWQAT(child, per_channel=per_channel))
        else:
            wrap_linears_with_wqat(child, per_channel=per_channel)
    return module


def unwrap_wqat_to_linear(module: nn.Module):
    for name, child in list(module.named_children()):
        if isinstance(child, LinearWQAT):
            lin = nn.Linear(child.in_features, child.out_features, bias=child.has_bias)
            with torch.no_grad():
                lin.weight.copy_(child.weight.data)
                if child.has_bias:
                    lin.bias.copy_(child.bias.data)
            setattr(module, name, lin)
        else:
            unwrap_wqat_to_linear(child)
    return module


@torch.no_grad()
def eval_acc_loss(model, loader, device):
    model.eval().to(device)
    ce = nn.CrossEntropyLoss()
    total, correct, total_loss = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        loss = ce(logits, y)
        total_loss += loss.item() * y.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += y.size(0)
    return total_loss / max(total, 1), correct / max(total, 1)


# 0) База для QAT-lite: берём EMA-студента если есть
student_base = timm.create_model('deit_tiny_patch16_224', pretrained=False, num_classes=num_classes)
if os.path.isfile(CKPT_STUDENT_EMA):
    ckpt_base = torch.load(CKPT_STUDENT_EMA, map_path='cpu')
else:
    ckpt_base = torch.load(CKPT_STUDENT, map_path='cpu')  # fallback
    if 'model' not in ckpt_base:
        ckpt_base = torch.load(CKPT_STUDENT, map_location='cpu')
student_base.load_state_dict(ckpt_base['model'], strict=True)

# Базовая PTQ (динамическая) для сравнения
ptq_base = torch.ao.quantization.quantize_dynamic(copy.deepcopy(student_base), {nn.Linear}, dtype=torch.qint8)
ptq_loss, ptq_acc = eval_acc_loss(ptq_base, val_loader, device=torch.device('cpu'))

# 1) Готовим и обучаем QAT-lite
student_qat = copy.deepcopy(student_base)
student_qat = wrap_linears_with_wqat(student_qat, per_channel=QAT_PER_CHANNEL)
student_qat.to(device)

teacher.eval().to(device)
for p in teacher.parameters():
    p.requires_grad_(False)

optimizer_qat = torch.optim.AdamW(student_qat.parameters(), lr=QAT_LR, weight_decay=QAT_WD)

print(f'QAT-lite start: epochs={QAT_EPOCHS}, lr={QAT_LR}, wd={QAT_WD}')
for epoch in range(1, QAT_EPOCHS + 1):
    student_qat.train()
    run_loss, seen = 0.0, 0
    for x, y in tqdm(train_loader, desc=f'QAT-lite {epoch}/{QAT_EPOCHS}', leave=False):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        with torch.no_grad():
            t_logits = teacher(x)
        optimizer_qat.zero_grad(set_to_none=True)
        # без AMP для стабильности fake-quant
        s_logits = student_qat(x)
        loss = kd_loss(s_logits, t_logits, y, T=KD_T, alpha=KD_ALPHA)
        loss.backward()
        optimizer_qat.step()

        bs = y.size(0)
        run_loss += loss.item() * bs
        seen += bs
    val_loss_qat, val_acc_qat = eval_acc_loss(student_qat, val_loader, device)
    print(
        f'Epoch {epoch:02d} | train_loss={run_loss / max(seen, 1):.4f}  val_loss={val_loss_qat:.4f}  val_acc={val_acc_qat:.4f}')

# 2) Разворачиваем в Linear и сохраняем FP32 чекпойнт
student_qat_fp32 = copy.deepcopy(student_qat).cpu()
student_qat_fp32 = unwrap_wqat_to_linear(student_qat_fp32)
torch.save({'model': student_qat_fp32.state_dict(), 'label2idx': label2idx}, CKPT_STUDENT_QAT)
print(f'Save QAT-lite FP32 -> {CKPT_STUDENT_QAT}')

# 3) INT8 после QAT-lite и оценка
student_qat_int8 = torch.ao.quantization.quantize_dynamic(copy.deepcopy(student_qat_fp32), {nn.Linear},
                                                          dtype=torch.qint8)
qat_int8_loss, qat_int8_acc = eval_acc_loss(student_qat_int8, val_loader, device=torch.device('cpu'))
torch.save({'model': student_qat_int8.state_dict(), 'label2idx': label2idx, 'from': CKPT_STUDENT_QAT},
           CKPT_STUDENT_QAT_INT8)
print(f'Save QAT-lite INT8 -> {CKPT_STUDENT_QAT_INT8}')

# 4) Финальное сравнение INT8: PTQ(EMA) vs QAT-lite(EMA)
print('\n=== CPU INT8: PTQ(EMA) vs QAT-lite(EMA) ===')
print(f'PTQ  INT8 -> val_loss: {ptq_loss:.4f}  val_acc: {ptq_acc:.4f}')
print(f'QATL INT8 -> val_loss: {qat_int8_loss:.4f}  val_acc: {qat_int8_acc:.4f}')
print(f'Δacc (QAT-lite - PTQ): {qat_int8_acc - ptq_acc:+.4f}')

QAT-lite start: epochs=5, lr=3e-05, wd=0.0


                                                               

Epoch 01 | train_loss=0.3321  val_loss=0.5259  val_acc=0.8342


                                                               

Epoch 02 | train_loss=0.3282  val_loss=0.5572  val_acc=0.8208


                                                               

Epoch 03 | train_loss=0.3249  val_loss=0.5293  val_acc=0.8342


                                                               

Epoch 04 | train_loss=0.3198  val_loss=0.5771  val_acc=0.8138


                                                               

Epoch 05 | train_loss=0.3078  val_loss=0.5415  val_acc=0.8253
Save QAT-lite FP32 -> data\model_weights\student_best_qatlite.pth
Save QAT-lite INT8 -> data\model_weights\student_best_qatlite_int8.pth

=== CPU INT8: PTQ(EMA) vs QAT-lite(EMA) ===
PTQ  INT8 -> val_loss: 0.5456  val_acc: 0.8288
QATL INT8 -> val_loss: 0.5446  val_acc: 0.8238
Δacc (QAT-lite - PTQ): -0.0050


##  CPU-бенчмарк

In [17]:
# python
# =========================
# Блок 1. CPU-бенчмарк
# =========================
import os
import time
import math
import psutil
import numpy as np
import torch
import torch.nn as nn
import timm

torch.backends.quantized.engine = 'fbgemm'  # x86/Windows

def cache_batches(loader, limit=None):
    """Загружает несколько батчей в RAM для стабильных замеров (исключая I/O/аугментации во время замеров)."""
    cached = []
    for i, (x, y) in enumerate(loader):
        cached.append((x.cpu(), y.cpu()))
        if limit is not None and (i + 1) >= limit:
            break
    return cached

@torch.no_grad()
def benchmark_cpu(model: torch.nn.Module,
                  cached_batches,
                  warmup_steps=50,
                  measure_steps=100,
                  reps=5):
    """Возвращает словарь: p50/p90 latency (мс), throughput (img/s), peak_ram_mb."""
    proc = psutil.Process(os.getpid())
    model.eval().to('cpu')
    lat_ms = []
    total_imgs = 0
    total_time = 0.0
    peak_rss = proc.memory_info().rss

    # Фиксируем список батчей для циклического прохода
    if not cached_batches:
        raise RuntimeError('Нет закешированных батчей для бенчмарка.')
    nb = len(cached_batches)

    # Прогрев
    j = 0
    for _ in range(warmup_steps):
        x, _ = cached_batches[j % nb]
        j += 1
        _ = model(x)  # forward на CPU
        peak_rss = max(peak_rss, proc.memory_info().rss)

    # Замеры
    j = 0
    for _ in range(reps):
        run_imgs = 0
        run_time = 0.0
        for _ in range(measure_steps):
            x, _ = cached_batches[j % nb]
            j += 1
            t0 = time.perf_counter()
            _ = model(x)
            dt = time.perf_counter() - t0
            lat_ms.append(dt * 1e3)
            run_imgs += x.size(0)
            run_time += dt
            peak_rss = max(peak_rss, proc.memory_info().rss)
        total_imgs += run_imgs
        total_time += run_time

    lat_ms = np.asarray(lat_ms, dtype=np.float64)
    p50 = float(np.percentile(lat_ms, 50))
    p90 = float(np.percentile(lat_ms, 90))
    thr = float(total_imgs / total_time) if total_time > 0 else float('nan')
    peak_mb = peak_rss / (1024.0 * 1024.0)
    return {'latency_p50_ms': p50, 'latency_p90_ms': p90, 'throughput_img_s': thr, 'peak_ram_mb': peak_mb}

def build_student_fp32(num_classes: int, ckpt_path: str):
    m = timm.create_model('deit_tiny_patch16_224', pretrained=False, num_classes=num_classes)
    sd = torch.load(ckpt_path, map_location='cpu')['model']
    m.load_state_dict(sd, strict=True)
    return m

def build_student_ema_fp32(num_classes: int, ckpt_path: str):
    m = timm.create_model('deit_tiny_patch16_224', pretrained=False, num_classes=num_classes)
    sd = torch.load(ckpt_path, map_location='cpu')['model']
    m.load_state_dict(sd, strict=True)
    return m

def to_int8_dynamic(fp32_model: nn.Module):
    return torch.ao.quantization.quantize_dynamic(fp32_model, {nn.Linear}, dtype=torch.qint8)

# Подхватим уже определённые объекты и пути (есть в ноутбуке выше),
# а при их отсутствии — создадим/определим.
try:
    num_classes
except NameError:
    # На всякий случай — читаем labels.json
    import json
    with open(os.path.join('data', 'ISIC', 'labels.json'), 'r', encoding='utf-8') as f:
        label2idx = json.load(f)
    num_classes = len(label2idx)

CKPT_STUDENT = os.path.join('data', 'model_weights', 'student_best.pth')
CKPT_STUDENT_EMA = CKPT_STUDENT.replace('.pth', '_ema.pth')
CKPT_STUDENT_QAT = CKPT_STUDENT.replace('.pth', '_qatlite.pth')

# Собираем модели для сравнения
models_for_bench = {}

# FP32 student
try:
    models_for_bench['student_fp32'] = student_fp32  # из предыдущего блока, если есть
except NameError:
    if os.path.isfile(CKPT_STUDENT):
        models_for_bench['student_fp32'] = build_student_fp32(num_classes, CKPT_STUDENT)

# FP32 student EMA
try:
    models_for_bench['student_ema_fp32'] = student_ema_fp32
except NameError:
    if os.path.isfile(CKPT_STUDENT_EMA):
        models_for_bench['student_ema_fp32'] = build_student_ema_fp32(num_classes, CKPT_STUDENT_EMA)

# INT8 (PTQ) — строим на лету из FP32
if 'student_fp32' in models_for_bench:
    models_for_bench['student_int8_ptq'] = to_int8_dynamic(timm.create_model(
        'deit_tiny_patch16_224', pretrained=False, num_classes=num_classes
    ).load_state_dict(models_for_bench['student_fp32'].state_dict()) or models_for_bench['student_fp32'])

if 'student_ema_fp32' in models_for_bench:
    models_for_bench['student_ema_int8_ptq'] = to_int8_dynamic(timm.create_model(
        'deit_tiny_patch16_224', pretrained=False, num_classes=num_classes
    ).load_state_dict(models_for_bench['student_ema_fp32'].state_dict()) or models_for_bench['student_ema_fp32'])

# QAT-lite FP32 (если есть чекпойнт) + его INT8
if os.path.isfile(CKPT_STUDENT_QAT):
    m_qat = timm.create_model('deit_tiny_patch16_224', pretrained=False, num_classes=num_classes)
    m_qat.load_state_dict(torch.load(CKPT_STUDENT_QAT, map_location='cpu')['model'], strict=True)
    models_for_bench['student_qatlite_fp32'] = m_qat
    models_for_bench['student_qatlite_int8'] = to_int8_dynamic(copy.deepcopy(m_qat))

# Готовим батчи для стабильных замеров (например, 64 батча из val_loader)
try:
    val_loader
except NameError:
    raise RuntimeError('val_loader не найден. Выполните блок подготовки данных выше.')
cached_val = cache_batches(val_loader, limit=64)

print('=== CPU benchmark (batch из val_loader) ===')
for name, model in models_for_bench.items():
    m = model  # уже собран
    res = benchmark_cpu(m, cached_val, warmup_steps=50, measure_steps=100, reps=5)
    print(f'{name:>22s} | p50={res["latency_p50_ms"]:.2f} ms | p90={res["latency_p90_ms"]:.2f} ms | '
          f'thr={res["throughput_img_s"]:.1f} img/s | peak RAM={res["peak_ram_mb"]:.1f} MB')


AttributeError: '_IncompatibleKeys' object has no attribute 'eval'

##  Метрики качества

In [None]:
# =========================
# Блок 2. Метрики качества
# =========================
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_recall_fscore_support,
    confusion_matrix,
    roc_auc_score,
)
import torch.nn.functional as F

@torch.no_grad()
def collect_logits_labels(model: nn.Module, loader, device=torch.device('cpu')):
    model.eval().to(device)
    logits_list, labels_list = [], []
    for x, y in loader:
        x = x.to(device)
        logits = model(x)
        logits_list.append(logits.cpu())
        labels_list.append(y.cpu())
    logits = torch.cat(logits_list, dim=0)
    labels = torch.cat(labels_list, dim=0)
    return logits, labels

def compute_classification_metrics(logits: torch.Tensor,
                                   labels: torch.Tensor,
                                   class_names=None):
    y_true = labels.numpy()
    y_pred = logits.argmax(dim=1).numpy()
    num_classes = logits.size(1)

    acc = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average='macro')

    # per-class precision/recall/f1/support
    prec_c, rec_c, f1_c, supp_c = precision_recall_fscore_support(
        y_true, y_pred, labels=np.arange(num_classes), average=None, zero_division=0
    )

    # ROC-AUC (macro)
    y_prob = F.softmax(logits, dim=1).numpy()
    try:
        if num_classes == 2:
            roc_auc_macro = roc_auc_score(y_true, y_prob[:, 1])
        else:
            roc_auc_macro = roc_auc_score(y_true, y_prob, multi_class='ovr', average='macro')
    except ValueError:
        roc_auc_macro = float('nan')  # если в y_true отсутствует какой-то класс

    cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_classes))

    per_class = []
    for i in range(num_classes):
        name = class_names[i] if (class_names is not None and i < len(class_names)) else f'class_{i}'
        per_class.append({
            'class': name,
            'precision': float(prec_c[i]),
            'recall': float(rec_c[i]),
            'f1': float(f1_c[i]),
            'support': int(supp_c[i]),
        })

    return {
        'accuracy': float(acc),
        'macro_f1': float(macro_f1),
        'roc_auc_macro': float(roc_auc_macro) if not math.isnan(roc_auc_macro) else float('nan'),
        'per_class': per_class,
        'confusion_matrix': cm.astype(int).tolist(),
    }

def print_metrics(tag: str, m: dict, max_classes=20):
    print(f'[{tag}] accuracy={m["accuracy"]:.4f} | macro-F1={m["macro_f1"]:.4f} | ROC-AUC(macro)={m["roc_auc_macro"]:.4f}')
    print('per-class (precision/recall/f1/support):')
    for i, row in enumerate(m['per_class'][:max_classes]):
        print(f'  {row["class"]:<16s} P={row["precision"]:.3f} R={row["recall"]:.3f} F1={row["f1"]:.3f} n={row["support"]}')
    if len(m['per_class']) > max_classes:
        print(f'  ... {len(m["per_class"]) - max_classes} классов скрыто')
    cm = np.array(m['confusion_matrix'], dtype=int)
    print(f'confusion matrix: shape={cm.shape}, sum={cm.sum()}')

# Имена классов (если есть)
try:
    idx2label
except NameError:
    try:
        with open(os.path.join('data', 'ISIC', 'labels.json'), 'r', encoding='utf-8') as f:
            label2idx = json.load(f)
        idx2label = {v: k for k, v in label2idx.items()}
    except Exception:
        idx2label = None
class_names = [idx2label[i] for i in range(len(idx2label))] if idx2label else None

# Примеры расчёта метрик на CPU для FP32/INT8
metrics_targets = {}
if 'student_fp32' in models_for_bench:
    lg, lb = collect_logits_labels(models_for_bench['student_fp32'], val_loader, device=torch.device('cpu'))
    metrics_targets['student_fp32'] = compute_classification_metrics(lg, lb, class_names)

if 'student_int8_ptq' in models_for_bench:
    lg, lb = collect_logits_labels(models_for_bench['student_int8_ptq'], val_loader, device=torch.device('cpu'))
    metrics_targets['student_int8_ptq'] = compute_classification_metrics(lg, lb, class_names)

if 'student_ema_fp32' in models_for_bench:
    lg, lb = collect_logits_labels(models_for_bench['student_ema_fp32'], val_loader, device=torch.device('cpu'))
    metrics_targets['student_ema_fp32'] = compute_classification_metrics(lg, lb, class_names)

if 'student_ema_int8_ptq' in models_for_bench:
    lg, lb = collect_logits_labels(models_for_bench['student_ema_int8_ptq'], val_loader, device=torch.device('cpu'))
    metrics_targets['student_ema_int8_ptq'] = compute_classification_metrics(lg, lb, class_names)

if 'student_qatlite_fp32' in models_for_bench:
    lg, lb = collect_logits_labels(models_for_bench['student_qatlite_fp32'], val_loader, device=torch.device('cpu'))
    metrics_targets['student_qatlite_fp32'] = compute_classification_metrics(lg, lb, class_names)

if 'student_qatlite_int8' in models_for_bench:
    lg, lb = collect_logits_labels(models_for_bench['student_qatlite_int8'], val_loader, device=torch.device('cpu'))
    metrics_targets['student_qatlite_int8'] = compute_classification_metrics(lg, lb, class_names)

print('\n=== Качество (val, CPU) ===')
for tag, m in metrics_targets.items():
    print_metrics(tag, m)