In [42]:
import os
import io
import json
import math
import time
import copy
import random
import platform
import psutil
import warnings
from typing import Dict, Tuple, List, Optional

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms as T
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, f1_score, roc_auc_score,
    precision_recall_fscore_support, confusion_matrix
)

import timm

from EMA_for_weights import EMA
from wqat import LinearWQAT

In [43]:
# ======== Константы/пути ========
SEED = 42
ISIC_ROOT = os.path.join('data', 'ISIC')
TRAIN_CSV = os.path.join(ISIC_ROOT, 'split_train.csv')
VAL_CSV = os.path.join(ISIC_ROOT, 'split_val.csv')          # будет создан при первом запуске
TEST_CSV = os.path.join(ISIC_ROOT, 'split_test.csv')         # финальная оценка (без аугм., 1 проход)
LABELS_JSON = os.path.join(ISIC_ROOT, 'labels.json')

WEIGHTS_DIR = os.path.join('data', 'model_weights')
RESULTS_DIR = os.path.join('data', 'results')
os.makedirs(WEIGHTS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

CKPT_TEACHER = os.path.join(WEIGHTS_DIR, 'deit_s_best.pth')
CKPT_STUDENT = os.path.join(WEIGHTS_DIR, 'student_best.pth')
CKPT_STUDENT_EMA = os.path.join(WEIGHTS_DIR, 'student_best_ema.pth')
CKPT_STUDENT_INT8 = os.path.join(WEIGHTS_DIR, 'student_best_int8.pth')
CKPT_STUDENT_EMA_INT8 = os.path.join(WEIGHTS_DIR, 'student_best_ema_int8.pth')

# QAT-lite
CKPT_STUDENT_QATL_FP32 = os.path.join(WEIGHTS_DIR, 'student_best_qatlite.pth')
CKPT_STUDENT_QATL_INT8 = os.path.join(WEIGHTS_DIR, 'student_best_qatlite_int8.pth')

# HParams
IMG_SIZE = 224
BATCH_TRAIN = 32
BATCH_EVAL = 64
NUM_WORKERS = 0  # Windows
EPOCHS = 15
BASE_LR = 3e-4
WEIGHT_DECAY = 0.05
KD_T = 4.0
KD_ALPHA = 0.7

# QAT-lite hparams
QAT_EPOCHS = 5
QAT_LR = 3e-5
QAT_WD = 0.0
QAT_PER_CHANNEL = True

# CPU quant backend
torch.backends.quantized.engine = 'fbgemm'

In [44]:
# ======== Утилиты ========
def set_global_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.benchmark = False  # воспроизводимость


def ensure_val_split(train_csv: str, val_csv: str, val_size=0.1, seed=42):
    if os.path.isfile(val_csv):
        return
    df = pd.read_csv(train_csv)
    if not {'path', 'label'} <= set(df.columns):
        raise ValueError("Ожидаются столбцы 'path' и 'label' в split_train.csv")
    idx = np.arange(len(df))
    y = df['label'].values
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=seed)
    tr_idx, val_idx = next(sss.split(idx, y))
    df_val = df.iloc[val_idx].copy()
    df_val.to_csv(val_csv, index=False)
    print(f'Создан валид. сплит: {val_csv} (n={len(df_val)})')


def system_info() -> Dict[str, str]:
    cpu = platform.processor()
    if not cpu:
        try:
            import cpuinfo  # type: ignore
            cpu = cpuinfo.get_cpu_info().get('brand_raw', '')
        except Exception:
            cpu = 'unknown'
    return {
        'python': platform.python_version(),
        'pytorch': torch.__version__,
        'os': platform.platform(),
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'cpu': cpu,
        'quant_backend': getattr(torch.backends.quantized, 'engine', None),
        'inference_backend': 'torch-eager'
    }


def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())


def count_flops(m: nn.Module, img_size=224) -> Optional[int]:
    try:
        from ptflops import get_model_complexity_info  # type: ignore
        m_eval = copy.deepcopy(m).eval().cpu()
        with torch.no_grad():
            macs, _ = get_model_complexity_info(
                m_eval, (3, img_size, img_size), as_strings=False, print_per_layer_stat=False
            )
        return int(macs * 2)
    except Exception:
        return None


def state_dict_size_mb(state_dict: Dict[str, torch.Tensor]) -> float:
    buf = io.BytesIO()
    torch.save(state_dict, buf)
    return buf.tell() / (1024.0 * 1024.0)


In [45]:
# ======== Данные ========
class ISICCsvDataset(Dataset):
    def __init__(self, csv_path: str, label2idx=None, tfm=None):
        df = pd.read_csv(csv_path)
        if 'path' not in df.columns or 'label' not in df.columns:
            raise ValueError("CSV должен содержать столбцы 'path' и 'label'")
        self.paths = df['path'].tolist()
        self.labels_str = df['label'].tolist()
        if label2idx is None:
            uniq = sorted(pd.unique(self.labels_str).tolist())
            label2idx = {s: i for i, s in enumerate(uniq)}
        self.label2idx = label2idx
        self.labels = [self.label2idx[s] for s in self.labels_str]
        self.tfm = tfm

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        y = self.labels[idx]
        with Image.open(p) as img:
            img = img.convert('RGB')
        if self.tfm:
            img = self.tfm(img)
        return img, y


def build_transforms(img_size=224):
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    train_tfm = T.Compose([
        T.RandomResizedCrop(img_size, scale=(0.7, 1.0), ratio=(0.75, 1.33)),
        T.RandomHorizontalFlip(0.5),
        T.ColorJitter(0.2, 0.2, 0.2, 0.1),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    val_tfm = T.Compose([
        T.Resize(int(img_size * 1.14)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    return train_tfm, val_tfm


def make_weighted_sampler(labels_idx: List[int], num_classes: int):
    counts = np.bincount(labels_idx, minlength=num_classes).astype(np.float32)
    inv = 1.0 / np.maximum(counts, 1.0)
    weights = [inv[y] for y in labels_idx]
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True), counts


In [46]:
# ======== Модели ========
def build_teacher(num_classes: int) -> nn.Module:
    # DeiT-S
    m = timm.create_model('deit_small_patch16_224', pretrained=False, num_classes=num_classes)
    if os.path.isfile(CKPT_TEACHER):
        ckpt = torch.load(CKPT_TEACHER, map_location='cpu')
        sd = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
        m.load_state_dict(sd, strict=True)
    else:
        warnings.warn('CKPT_TEACHER не найден — используем timm pretrained head, num_classes адаптирован')
        m = timm.create_model('deit_small_patch16_224', pretrained=True, num_classes=num_classes)
    m.eval()
    for p in m.parameters():
        p.requires_grad_(False)
    return m


def build_student(num_classes: int, pretrained=True) -> nn.Module:
    return timm.create_model('deit_tiny_patch16_224', pretrained=pretrained, num_classes=num_classes)


def build_baselines(num_classes: int) -> Dict[str, nn.Module]:
    names = {
        'resnet18': 'resnet18',
        'mobilenetv3': 'mobilenetv3_large_100',
        'convnext_tiny': 'convnext_tiny',
    }
    out = {}
    for tag, name in names.items():
        try:
            out[f'{tag}_fp32'] = timm.create_model(name, pretrained=True, num_classes=num_classes)
        except Exception as e:
            print(f'Бейзлайн {name} пропущен: {e}')
    return out


def to_int8_dynamic(fp32_model: nn.Module) -> nn.Module:
    # Только nn.Linear для ViT/DeiT и большинства CNN голов
    return torch.ao.quantization.quantize_dynamic(copy.deepcopy(fp32_model), {nn.Linear}, dtype=torch.qint8)


In [47]:
# ======== Потери/оценка ========
def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    ) * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard


@torch.no_grad()
def evaluate_loss_acc(model: nn.Module, loader: DataLoader, device) -> Tuple[float, float]:
    ce = nn.CrossEntropyLoss()
    model.eval().to(device)
    total, correct, total_loss = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = ce(logits, y)
        total_loss += loss.item() * y.size(0)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    return total_loss / max(1, total), correct / max(1, total)


@torch.no_grad()
def collect_logits_labels(model: nn.Module, loader: DataLoader, device=torch.device('cpu')):
    model.eval().to(device)
    lg, lb = [], []
    for x, y in loader:
        x = x.to(device)
        lg.append(model(x).cpu())
        lb.append(y.cpu())
    return torch.cat(lg, 0), torch.cat(lb, 0)


def compute_metrics(logits: torch.Tensor, labels: torch.Tensor, class_names: Optional[List[str]] = None) -> Dict:
    y_true = labels.numpy()
    y_prob = F.softmax(logits, dim=1).numpy()
    y_pred = y_prob.argmax(1)
    nc = logits.size(1)

    acc = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    try:
        if nc == 2:
            roc_macro = roc_auc_score(y_true, y_prob[:, 1])
        else:
            roc_macro = roc_auc_score(y_true, y_prob, multi_class='ovr', average='macro')
    except Exception:
        roc_macro = float('nan')

    prec_c, rec_c, f1_c, supp_c = precision_recall_fscore_support(
        y_true, y_pred, labels=np.arange(nc), average=None, zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(nc))
    per_class = []
    for i in range(nc):
        name = class_names[i] if (class_names and i < len(class_names)) else f'class_{i}'
        per_class.append({
            'class': name, 'precision': float(prec_c[i]), 'recall': float(rec_c[i]),
            'f1': float(f1_c[i]), 'support': int(supp_c[i])
        })
    return {
        'accuracy': float(acc),
        'macro_f1': float(macro_f1),
        'roc_auc_macro': float(roc_macro) if not math.isnan(roc_macro) else float('nan'),
        'per_class': per_class,
        'confusion_matrix': cm.astype(int).tolist(),
    }


In [48]:
# ======== Обучение KD + EMA ========
def train_kd_ema(student: nn.Module, teacher: nn.Module,
                 train_loader: DataLoader, val_loader: DataLoader,
                 device, epochs=15, base_lr=3e-4, weight_decay=0.05,
                 kd_T=4.0, kd_alpha=0.7,
                 ckpt_student=CKPT_STUDENT, ckpt_student_ema=CKPT_STUDENT_EMA) -> Tuple[str, str]:
    ema = EMA(student, decay=0.999)
    optimizer = torch.optim.AdamW(student.parameters(), lr=base_lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))

    best_acc, best_acc_ema = 0.0, 0.0

    for epoch in range(1, epochs + 1):
        student.train()
        run_loss, run_correct, seen = 0.0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                t_logits = teacher(x)
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
                s_logits = student(x)
                loss = kd_loss(s_logits, t_logits, y, T=kd_T, alpha=kd_alpha)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            ema.update(student)

            bs = y.size(0)
            seen += bs
            run_loss += loss.item() * bs
            run_correct += (s_logits.argmax(1) == y).sum().item()
        scheduler.step()

        tr_loss = run_loss / max(1, seen)
        tr_acc = run_correct / max(1, seen)
        v_loss, v_acc = evaluate_loss_acc(student, val_loader, device)
        v_loss_e, v_acc_e = evaluate_loss_acc(ema.ema_model, val_loader, device)
        print(f'Epoch {epoch:03d} | train: loss={tr_loss:.4f} acc={tr_acc:.4f} | '
              f'val: s=({v_loss:.4f},{v_acc:.4f}) ema=({v_loss_e:.4f},{v_acc_e:.4f})')

        if v_acc > best_acc:
            best_acc = v_acc
            torch.save({'model': student.state_dict()}, ckpt_student)
        if v_acc_e > best_acc_ema:
            best_acc_ema = v_acc_e
            torch.save({'model': ema.ema_model.state_dict()}, ckpt_student_ema)

    print(f'Best val acc: student={best_acc:.4f} | ema={best_acc_ema:.4f}')
    return ckpt_student, ckpt_student_ema


In [49]:
# ======== QAT-lite ========
def wrap_linears_with_wqat(module: nn.Module, per_channel=True) -> nn.Module:
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, LinearWQAT(child, per_channel=per_channel))
        else:
            wrap_linears_with_wqat(child, per_channel=per_channel)
    return module


def unwrap_wqat_to_linear(module: nn.Module) -> nn.Module:
    for name, child in list(module.named_children()):
        if isinstance(child, LinearWQAT):
            base = nn.Linear(child.in_features, child.out_features, bias=child.has_bias)
            with torch.no_grad():
                base.weight.copy_(child.weight)
                if child.has_bias:
                    base.bias.copy_(child.bias)
            setattr(module, name, base)
        else:
            unwrap_wqat_to_linear(child)
    return module


def train_qat_lite(student_base: nn.Module, teacher: nn.Module,
                   train_loader: DataLoader, val_loader: DataLoader,
                   device, epochs=5, lr=3e-5, wd=0.0, per_channel=True,
                   ckpt_qat_fp32=CKPT_STUDENT_QATL_FP32,
                   ckpt_qat_int8=CKPT_STUDENT_QATL_INT8) -> Tuple[str, str]:
    student_qat = wrap_linears_with_wqat(copy.deepcopy(student_base), per_channel=per_channel).to(device)
    opt = torch.optim.AdamW(student_qat.parameters(), lr=lr, weight_decay=wd)

    teacher.eval().to(device)
    for p in teacher.parameters():
        p.requires_grad_(False)

    print(f'QAT-lite: epochs={epochs}, lr={lr}, wd={wd}, per_channel={per_channel}')
    for ep in range(1, epochs + 1):
        student_qat.train()
        run_loss, seen = 0.0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                t_logits = teacher(x)
            opt.zero_grad(set_to_none=True)
            s_logits = student_qat(x)
            loss = kd_loss(s_logits, t_logits, y, T=KD_T, alpha=KD_ALPHA)
            loss.backward()
            opt.step()
            run_loss += loss.item() * y.size(0)
            seen += y.size(0)
        v_loss, v_acc = evaluate_loss_acc(student_qat, val_loader, device)
        print(f'QAT {ep:02d}/{epochs} | train_loss={run_loss/max(1,seen):.4f} | val_loss={v_loss:.4f} val_acc={v_acc:.4f}')

    # Unwrap to Linear and save FP32
    student_qat_fp32 = unwrap_wqat_to_linear(copy.deepcopy(student_qat).cpu())
    torch.save({'model': student_qat_fp32.state_dict()}, ckpt_qat_fp32)

    # INT8 (dynamic)
    student_qat_int8 = to_int8_dynamic(student_qat_fp32)
    torch.save({'model': student_qat_int8.state_dict()}, ckpt_qat_int8)
    return ckpt_qat_fp32, ckpt_qat_int8


In [50]:
# ======== CPU-бенчмарк ========
def cache_batches(loader: DataLoader, limit: Optional[int] = None):
    cached = []
    for i, (x, y) in enumerate(loader):
        cached.append((x.cpu(), y))
        if limit is not None and i + 1 >= limit:
            break
    return cached


@torch.no_grad()
def benchmark_cpu(model: nn.Module, cached_batches,
                  warmup=50, measure=100, reps=5) -> Dict[str, float]:
    proc = psutil.Process(os.getpid())
    model.eval().cpu()
    lat_ms = []
    total_imgs, total_time = 0, 0.0
    peak = proc.memory_info().rss

    if not cached_batches:
        raise RuntimeError('Нет батчей для бенчмарка')
    nb = len(cached_batches)

    # Warmup
    j = 0
    for _ in range(warmup):
        x, _ = cached_batches[j % nb]; j += 1
        _ = model(x)
        peak = max(peak, proc.memory_info().rss)

    # Measure
    j = 0
    for _ in range(reps):
        for _ in range(measure):
            x, _ = cached_batches[j % nb]; j += 1
            t0 = time.perf_counter()
            _ = model(x)
            dt = time.perf_counter() - t0
            lat_ms.append(dt * 1e3)
            total_imgs += x.size(0)
            total_time += dt
            peak = max(peak, proc.memory_info().rss)

    lat = np.asarray(lat_ms, dtype=np.float64)
    p50 = float(np.percentile(lat, 50))
    p90 = float(np.percentile(lat, 90))
    thr = float(total_imgs / total_time) if total_time > 0 else float('nan')
    peak_mb = peak / (1024.0 * 1024.0)
    size_mb = state_dict_size_mb(model.state_dict()) if hasattr(model, 'state_dict') else float('nan')
    return {'p50_ms': p50, 'p90_ms': p90, 'thr_img_s': thr, 'peak_ram_mb': peak_mb, 'size_mb': size_mb}



In [51]:
# python
import torch
import torch.nn as nn

class ExportOnlyLogits(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        if isinstance(out, torch.Tensor):
            return out
        if isinstance(out, (list, tuple)):
            for v in out:
                if isinstance(v, torch.Tensor) and v.dim() == 2:
                    return v
        if isinstance(out, dict):
            for v in out.values():
                if isinstance(v, torch.Tensor) and v.dim() == 2:
                    return v
        raise RuntimeError('Экспорт: forward не вернул тензор с логитами.')

In [52]:
# ======== Повторы (k) для mean±std ========
def k_repeats_eval(model: nn.Module, base_csv: str, dataset_cls, label2idx, tfm, k=5, seed=42):
    df = pd.read_csv(base_csv)
    X = df['path'].values
    y = df['label'].values
    idx = np.arange(len(df))
    sss = StratifiedShuffleSplit(n_splits=k, test_size=0.5, random_state=seed)
    recs = []
    for _, val_idx in sss.split(idx, y):
        tmp = df.iloc[val_idx].copy()
        tmp_csv = os.path.join(ISIC_ROOT, '_tmp_val.csv')
        tmp.to_csv(tmp_csv, index=False)
        ds = dataset_cls(tmp_csv, label2idx=label2idx, tfm=tfm)
        ld = DataLoader(ds, batch_size=BATCH_EVAL, shuffle=False, num_workers=NUM_WORKERS)
        lg, lb = collect_logits_labels(model, ld, device=torch.device('cpu'))
        m = compute_metrics(lg, lb, None)
        recs.append(m)
        os.remove(tmp_csv)
    def _agg(key):
        vals = [r[key] for r in recs if isinstance(r[key], float) and not math.isnan(r[key])]
        return (float(np.mean(vals)), float(np.std(vals))) if vals else (float('nan'), float('nan'))
    return {'acc': _agg('accuracy'), 'macro_f1': _agg('macro_f1'), 'roc_auc_macro': _agg('roc_auc_macro')}


In [53]:
# ======== Главный сценарий ========
print('System:', system_info())
set_global_seed(SEED)
ensure_val_split(TRAIN_CSV, VAL_CSV, val_size=0.1, seed=SEED)


System: {'python': '3.11.0', 'pytorch': '2.7.1+cu118', 'os': 'Windows-10-10.0.26100-SP0', 'device': 'cuda', 'cpu': 'Intel64 Family 6 Model 151 Stepping 2, GenuineIntel', 'quant_backend': 'fbgemm', 'inference_backend': 'torch-eager'}


In [54]:
# labels
with open(LABELS_JSON, 'r', encoding='utf-8') as f:
    label2idx = json.load(f)
idx2label = {v: k for k, v in label2idx.items()}
num_classes = len(label2idx)

# data
train_tfm, val_tfm = build_transforms(IMG_SIZE)
train_df = pd.read_csv(TRAIN_CSV)
train_labels_idx = [label2idx[s] for s in train_df['label'].tolist()]
sampler, cls_counts = make_weighted_sampler(train_labels_idx, num_classes)

train_ds = ISICCsvDataset(TRAIN_CSV, label2idx=label2idx, tfm=train_tfm)
val_ds = ISICCsvDataset(VAL_CSV, label2idx=label2idx, tfm=val_tfm)
test_ds = ISICCsvDataset(TEST_CSV, label2idx=label2idx, tfm=val_tfm) if os.path.isfile(TEST_CSV) else None

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pin = device.type == 'cuda'
train_loader = DataLoader(train_ds, batch_size=BATCH_TRAIN, sampler=sampler,
                          num_workers=NUM_WORKERS, pin_memory=pin)
val_loader = DataLoader(val_ds, batch_size=BATCH_EVAL, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=pin)
test_loader = None
if test_ds is not None:
    test_loader = DataLoader(test_ds, batch_size=BATCH_EVAL, shuffle=False,
                             num_workers=NUM_WORKERS, pin_memory=pin)

print('Train class counts:', cls_counts.tolist())



Train class counts: [262.0, 411.0, 879.0, 92.0, 890.0, 5364.0, 114.0]


In [55]:
# teacher / student
teacher = build_teacher(num_classes).to(device).eval()
student = build_student(num_classes, pretrained=True).to(device)

# KD+EMA training
s_ckpt, s_ema_ckpt = train_kd_ema(
    student, teacher, train_loader, val_loader, device,
    epochs=EPOCHS, base_lr=BASE_LR, weight_decay=WEIGHT_DECAY,
    kd_T=KD_T, kd_alpha=KD_ALPHA,
    ckpt_student=CKPT_STUDENT, ckpt_student_ema=CKPT_STUDENT_EMA
)

# Load best FP32 models
student_fp32 = build_student(num_classes, pretrained=False)
student_fp32.load_state_dict(torch.load(s_ckpt, map_location='cpu')['model'], strict=True)
student_ema_fp32 = build_student(num_classes, pretrained=False)
student_ema_fp32.load_state_dict(torch.load(s_ema_ckpt, map_location='cpu')['model'], strict=True)

# INT8 (PTQ dynamic)
student_int8 = to_int8_dynamic(student_fp32)
student_ema_int8 = to_int8_dynamic(student_ema_fp32)
torch.save({'model': student_int8.state_dict()}, CKPT_STUDENT_INT8)
torch.save({'model': student_ema_int8.state_dict()}, CKPT_STUDENT_EMA_INT8)

# QAT-lite (из EMA)
qatl_fp32_path, qatl_int8_path = train_qat_lite(
    student_base=student_ema_fp32, teacher=teacher,
    train_loader=train_loader, val_loader=val_loader, device=device,
    epochs=QAT_EPOCHS, lr=QAT_LR, wd=QAT_WD, per_channel=QAT_PER_CHANNEL,
    ckpt_qat_fp32=CKPT_STUDENT_QATL_FP32, ckpt_qat_int8=CKPT_STUDENT_QATL_INT8
)
student_qatl_fp32 = build_student(num_classes, pretrained=False)
student_qatl_fp32.load_state_dict(torch.load(qatl_fp32_path, map_location='cpu')['model'], strict=True)
student_qatl_int8 = build_student(num_classes, pretrained=False)
student_qatl_int8 = to_int8_dynamic(student_qatl_fp32)

# Бейзлайны (FP32 + INT8 dynamic)
baselines = build_baselines(num_classes)
baselines_int8 = {}
for name, m in baselines.items():
    try:
        baselines_int8[f'{name.replace("_fp32","")}_int8_ptq'] = to_int8_dynamic(m)
    except Exception:
        pass

# Метрики качества (val) для всех вариантов
eval_targets: Dict[str, nn.Module] = {
    'student_fp32': student_fp32,
    'student_ema_fp32': student_ema_fp32,
    'student_int8_ptq': student_int8,
    'student_ema_int8_ptq': student_ema_int8,
    'student_qatlite_fp32': student_qatl_fp32,
    'student_qatlite_int8': student_qatl_int8,
    **baselines,
    **baselines_int8
}

metrics_summary = {}
for tag, model in eval_targets.items():
    lg, lb = collect_logits_labels(model, val_loader, device=torch.device('cpu'))
    m = compute_metrics(lg, lb, [idx2label[i] for i in range(num_classes)])
    metrics_summary[tag] = m
    print(f'[VAL] {tag}: acc={m["accuracy"]:.4f} macro-F1={m["macro_f1"]:.4f} ROC-AUC={m["roc_auc_macro"]:.4f}')

# Финальный тест (один прогон, без аугм.)
if test_loader is not None:
    for tag in ['student_fp32', 'student_ema_fp32', 'student_int8_ptq', 'student_ema_int8_ptq',
                'student_qatlite_fp32', 'student_qatlite_int8']:
        if tag in eval_targets:
            lg, lb = collect_logits_labels(eval_targets[tag], test_loader, device=torch.device('cpu'))
            m = compute_metrics(lg, lb, [idx2label[i] for i in range(num_classes)])
            metrics_summary[f'{tag}_TEST'] = m
            print(f'[TEST] {tag}: acc={m["accuracy"]:.4f} macro-F1={m["macro_f1"]:.4f} ROC-AUC={m["roc_auc_macro"]:.4f}')


# CPU-бенчмарк (batch=1 и 8)
cached_val_b1 = cache_batches(DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0), limit=64)
cached_val_b8 = cache_batches(DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=0), limit=64)

bench_rows = []
for tag, model in eval_targets.items():
    for bs, cached in [(1, cached_val_b1), (8, cached_val_b8)]:
        res = benchmark_cpu(model, cached, warmup=50, measure=100, reps=5)
        row = {'model': tag, 'batch': bs, **res}
        bench_rows.append(row)
        print(f'[CPU] {tag:24s} b={bs} | p50={res["p50_ms"]:.2f}ms p90={res["p90_ms"]:.2f}ms '
              f'thr={res["thr_img_s"]:.1f} img/s RAM={res["peak_ram_mb"]:.1f}MB size={res["size_mb"]:.2f}MB')

# Params/FLOPs/Size
desc_rows = []
for tag, m in eval_targets.items():
    try:
        flops = count_flops(m, img_size=IMG_SIZE)
    except Exception:
        flops = None
    params = count_params(m)
    size_mb = state_dict_size_mb(m.state_dict()) if hasattr(m, 'state_dict') else float('nan')
    desc_rows.append({'model': tag, 'params': int(params), 'flops': (int(flops) if flops else None), 'size_mb': float(size_mb)})


Epoch 001 | train: loss=4.3309 acc=0.5741 | val: s=(0.7053,0.7431) ema=(1.4188,0.6122)
Epoch 002 | train: loss=2.3092 acc=0.7229 | val: s=(0.5824,0.7855) ema=(0.8248,0.7556)
Epoch 003 | train: loss=1.7262 acc=0.7721 | val: s=(1.1302,0.6833) ema=(0.6049,0.7893)
Epoch 004 | train: loss=1.3769 acc=0.8090 | val: s=(0.5919,0.7693) ema=(0.5291,0.8030)
Epoch 005 | train: loss=1.1160 acc=0.8356 | val: s=(0.7693,0.7207) ema=(0.4875,0.8155)
Epoch 006 | train: loss=0.8964 acc=0.8571 | val: s=(0.4566,0.8242) ema=(0.4499,0.8304)
Epoch 007 | train: loss=0.6969 acc=0.8798 | val: s=(0.4469,0.8192) ema=(0.4235,0.8392)
Epoch 008 | train: loss=0.6293 acc=0.8877 | val: s=(0.4365,0.8441) ema=(0.4001,0.8441)
Epoch 009 | train: loss=0.4860 acc=0.9090 | val: s=(0.3975,0.8416) ema=(0.3781,0.8541)
Epoch 010 | train: loss=0.4292 acc=0.9164 | val: s=(0.3546,0.8579) ema=(0.3660,0.8566)
Epoch 011 | train: loss=0.3748 acc=0.9264 | val: s=(0.3114,0.8803) ema=(0.3529,0.8641)
Epoch 012 | train: loss=0.3392 acc=0.9335 |

In [57]:
import os
import torch
import torch.nn as nn

@torch.no_grad()
def export_torchscript(model: nn.Module, img_size=224, out_path='model_ts.pt'):
    model = model.cpu().eval()
    wrapped = ExportOnlyLogits(model)
    dummy = torch.randn(1, 3, img_size, img_size, dtype=torch.float32)
    ts = torch.jit.trace(wrapped, (dummy,), strict=False)
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
    ts.save(out_path)
    try:
        loaded = torch.jit.load(out_path, map_location='cpu')
        _ = loaded(dummy)
        print('TorchScript verify: OK (forward)')
    except Exception as e:
        print(f'TorchScript verify: пропущен ({e})')
    print(f'TorchScript -> {out_path}')

# ======== Финальные вызовы экспорта (только TorchScript) ========
export_dir = os.path.join(RESULTS_DIR, 'export')
os.makedirs(export_dir, exist_ok=True)

try:
    if 'student_ema_int8' in globals():
        export_torchscript(student_ema_int8, IMG_SIZE, os.path.join(export_dir, 'student_ema_int8_ts.pt'))
    elif 'student_int8' in globals():
        export_torchscript(student_int8, IMG_SIZE, os.path.join(export_dir, 'student_int8_ts.pt'))
    elif 'student_qatl_int8' in globals():
        export_torchscript(student_qatl_int8, IMG_SIZE, os.path.join(export_dir, 'student_qatl_int8_ts.pt'))
    else:
        print('Нет INT8‑модели в памяти для экспорта (ожидаются `student_ema_int8`/`student_int8`/`student_qatl_int8`).')
except Exception as e:
    print(f'Экспорт TorchScript пропущен: {e}')

  assert condition, message


TorchScript verify: OK (forward)
TorchScript -> data\results\export\student_ema_int8_ts.pt
