In [4]:
# python
# eval_cpu_models.py
import json
import os
import warnings

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import timm

from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

# ----- Paths / constants -----
# All paths are relative to the current working directory.
ISIC_ROOT = os.path.join('data', 'ISIC')
# NOTE(review): named VAL_* but points at the *test* split — confirm which
# split this evaluation is meant to report on.
VAL_CSV = os.path.join(ISIC_ROOT, 'split_test.csv')
# JSON object mapping label string -> integer class index; its length sets
# the number of model outputs.
LABELS_JSON = os.path.join(ISIC_ROOT, 'labels.json')

# Full-precision checkpoints (regular and, presumably, EMA-averaged weights).
CKPT_STUDENT_FP32 = os.path.join('data', 'model_weights', 'student_best.pth')
CKPT_STUDENT_EMA_FP32 = os.path.join('data', 'model_weights', 'student_best_ema.pth')
# INT8 checkpoints: state_dicts of dynamically quantized models (nn.Linear -> qint8).
CKPT_STUDENT_INT8 = os.path.join('data', 'model_weights', 'student_best_int8.pth')
CKPT_STUDENT_EMA_INT8 = os.path.join('data', 'model_weights', 'student_best_ema_int8.pth')

# QAT-lite checkpoints (optional; each model is evaluated only if its file exists)
CKPT_STUDENT_QATL_FP32 = os.path.join('data', 'model_weights', 'student_best_qatlite.pth')
CKPT_STUDENT_QATL_INT8 = os.path.join('data', 'model_weights', 'student_best_qatlite_int8.pth')

# Input resolution; should match the 224 in the timm architecture name
# 'deit_tiny_patch16_224'.
IMG_SIZE = 224
BATCH_SIZE = 64
NUM_WORKERS = 0  # load in the main process; multi-worker loading is presumably avoided for Windows
DEVICE = torch.device('cpu')

# Quantized-kernel backend used by the INT8 models. 'fbgemm' is PyTorch's x86
# backend — NOTE(review): it may be unsupported on ARM CPUs; check
# torch.backends.quantized.supported_engines there.
torch.backends.quantized.engine = 'fbgemm'


In [5]:
# ----- Dataset / transforms -----
class ISICCsvDataset(Dataset):
    """Image-classification dataset backed by a CSV with 'path' and 'label' columns.

    Each item is ``(image, class_index)``. The image is opened with PIL,
    converted to RGB and passed through ``tfm`` when one is given.

    Args:
        csv_path: CSV file with at least the columns 'path' (image file path)
            and 'label' (class label).
        label2idx: mapping from label value to integer class index. When None
            it is built from the CSV's distinct labels in sorted order, so the
            indices are only consistent across splits that contain exactly the
            same label set — pass an explicit mapping for evaluation.
        tfm: optional callable applied to the RGB PIL image.

    Raises:
        ValueError: if the CSV lacks the 'path' or 'label' column.
        KeyError: if any CSV label is missing from ``label2idx``; the message
            lists every missing label.
    """

    def __init__(self, csv_path: str, label2idx=None, tfm=None):
        df = pd.read_csv(csv_path)
        if 'path' not in df.columns or 'label' not in df.columns:
            raise ValueError("CSV должен содержать столбцы 'path' и 'label'")
        self.paths = df['path'].tolist()
        self.labels_str = df['label'].tolist()
        if label2idx is None:
            # set + sorted rather than pd.unique(list): passing a plain list to
            # pd.unique is deprecated since pandas 2.1.
            uniq = sorted(set(self.labels_str))
            label2idx = {lbl: i for i, lbl in enumerate(uniq)}
        self.label2idx = label2idx
        # Report all unknown labels at once instead of failing on the first one
        # with a bare KeyError that gives no context.
        unknown = sorted({s for s in self.labels_str if s not in label2idx}, key=str)
        if unknown:
            raise KeyError(f'labels not present in label2idx: {unknown}')
        self.labels = [label2idx[s] for s in self.labels_str]
        self.tfm = tfm

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        y = self.labels[idx]
        # convert() returns a new, fully loaded image, so the source file can be
        # closed as soon as the with-block exits.
        with Image.open(p) as img:
            img = img.convert('RGB')
        if self.tfm:
            img = self.tfm(img)
        return img, y


def build_transforms(img_size=224):
    """Return the deterministic evaluation transform pipeline.

    The shorter image side is resized to ``int(img_size * 1.14)`` (roughly the
    usual 256/224 ratio), then center-cropped to ``img_size``, converted to a
    tensor and normalized with ImageNet channel statistics.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    resize_to = int(img_size * 1.14)
    steps = [
        T.Resize(resize_to),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(imagenet_mean, imagenet_std),
    ]
    return T.Compose(steps)


# ----- Metrics -----
def _safe_roc_auc(name, y_true, y_score, **kwargs):
    """Return ``roc_auc_score(y_true, y_score, **kwargs)``, or None with a warning when undefined."""
    try:
        return roc_auc_score(y_true, y_score, **kwargs)
    except ValueError as err:
        # sklearn raises ValueError when the AUC is undefined for the given
        # labels (e.g. a class never occurs in y_true); report it instead of
        # silently hiding it.
        warnings.warn(f'{name} could not be computed: {err}')
        return None


@torch.no_grad()
def eval_metrics(model: nn.Module, loader: DataLoader, num_classes: int):
    """Run ``model`` over ``loader`` and compute classification metrics.

    Args:
        model: classifier returning logits; softmax is taken over dim 1, so the
            output is assumed to be (batch, num_classes).
        loader: yields ``(inputs, labels)`` batches with integer class labels.
        num_classes: number of classes. 2 -> binary ROC-AUC on the class-1
            probability; > 2 -> one-vs-rest macro and micro ROC-AUC.

    Returns:
        dict with keys 'loss', 'acc', 'f1_macro', 'roc_auc', 'roc_auc_macro'
        and 'roc_auc_micro'. All values are None for an empty loader; an
        ROC-AUC value is None when it does not apply to ``num_classes`` or is
        undefined for the observed labels (a warning is emitted in that case).
    """
    ce = nn.CrossEntropyLoss()
    model.eval()
    total_loss, total = 0.0, 0
    all_logits, all_labels = [], []

    for x, y in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        logits = model(x)
        bs = y.size(0)
        # CrossEntropyLoss averages over the batch; re-weight by batch size so
        # the final value is the per-sample mean even if the last batch is short.
        total_loss += ce(logits, y).item() * bs
        total += bs
        all_logits.append(logits.detach().cpu())
        all_labels.append(y.detach().cpu())

    metrics = {'loss': None, 'acc': None, 'f1_macro': None,
               'roc_auc': None, 'roc_auc_macro': None, 'roc_auc_micro': None}
    if total == 0:
        return metrics

    logits = torch.cat(all_logits, dim=0).float()
    labels = torch.cat(all_labels, dim=0).numpy()
    probs = torch.softmax(logits, dim=1).numpy()
    preds = probs.argmax(axis=1)

    metrics['loss'] = total_loss / total
    metrics['acc'] = accuracy_score(labels, preds)
    metrics['f1_macro'] = f1_score(labels, preds, average='macro')

    if num_classes == 2:
        # Binary AUC uses the probability of class 1.
        metrics['roc_auc'] = _safe_roc_auc('roc_auc', labels, probs[:, 1])
    elif num_classes > 2:
        # Explicit labels keep the probability columns aligned with class ids even
        # when some class is absent from this split (micro-AUC then still works).
        # Each average is computed separately so one failing does not drop the other.
        class_ids = np.arange(num_classes)
        for avg in ('macro', 'micro'):
            key = f'roc_auc_{avg}'
            metrics[key] = _safe_roc_auc(key, labels, probs, multi_class='ovr',
                                         average=avg, labels=class_ids)

    return metrics


def print_metrics(tag: str, m: dict):
    """Print a one-line summary of metrics dict ``m`` labelled with ``tag``.

    acc, f1_macro and loss are always shown (as 'NA' when missing or None);
    the ROC-AUC variants are appended only when they have a value. Numbers are
    printed with 4 decimals and fields are separated by two spaces.
    """
    def fmt(key):
        value = m.get(key)
        return f"{key}=NA" if value is None else f"{key}={value:.4f}"

    fields = [f'{tag}:'] + [fmt(k) for k in ('acc', 'f1_macro', 'loss')]
    fields += [fmt(k) for k in ('roc_auc', 'roc_auc_macro', 'roc_auc_micro')
               if m.get(k) is not None]
    print('  '.join(fields))


# ----- Model loaders -----
def build_deit_tiny(num_classes: int):
    """Create an untrained timm DeiT-Tiny (patch 16, 224 px) with ``num_classes`` outputs."""
    arch = 'deit_tiny_patch16_224'
    return timm.create_model(arch, num_classes=num_classes, pretrained=False)


def _read_state_dict(ckpt_path: str):
    """Load ``ckpt_path`` onto the CPU and return the state_dict it contains.

    Accepts either a bare state_dict or a checkpoint dict that stores the
    weights under the ``'model'`` key.
    """
    ckpt = torch.load(ckpt_path, map_location='cpu')
    if isinstance(ckpt, dict) and 'model' in ckpt:
        return ckpt['model']
    return ckpt


def load_fp32_model(ckpt_path: str, num_classes: int):
    """Build a DeiT-Tiny and load full-precision weights from ``ckpt_path``.

    Errors from ``torch.load`` (e.g. FileNotFoundError) propagate, as does the
    ``load_state_dict`` error on any key/shape mismatch (strict=True).
    """
    model = build_deit_tiny(num_classes)
    model.load_state_dict(_read_state_dict(ckpt_path), strict=True)
    model.to(DEVICE)
    return model


def load_int8_model(ckpt_path: str, num_classes: int):
    """Build a dynamically quantized DeiT-Tiny and load INT8 weights from ``ckpt_path``.

    The checkpoint's keys/tensors only match a model whose ``nn.Linear`` layers
    have already been replaced by qint8 dynamic-quantized ones, so that model
    is rebuilt first. Errors propagate as in ``load_fp32_model``.
    """
    # inplace=True: the freshly built float model is a throwaway, so skip the
    # deep copy quantize_dynamic would otherwise make of it.
    model_int8 = torch.ao.quantization.quantize_dynamic(
        build_deit_tiny(num_classes), {nn.Linear}, dtype=torch.qint8, inplace=True)
    model_int8.load_state_dict(_read_state_dict(ckpt_path), strict=True)
    model_int8.to(DEVICE)
    return model_int8


def maybe_eval(tag: str, builder, val_loader, num_classes: int):
    """Build a model with ``builder()``, evaluate it on ``val_loader`` and print the metrics.

    A FileNotFoundError while *building* (missing optional checkpoint) skips
    this model silently. Any other build failure, and any failure during
    evaluation, is printed with ``tag`` instead of propagating, so one broken
    model does not stop the remaining evaluations.
    """
    try:
        model = builder()
    except FileNotFoundError:
        return
    except Exception as e:
        print(f'{tag}: ошибка — {e}')
        return

    try:
        print_metrics(tag, eval_metrics(model, val_loader, num_classes))
    except Exception as e:
        # Includes FileNotFoundError for an image listed in the CSV but missing
        # on disk: it is reported rather than skipped, so a broken dataset
        # cannot silently hide a model's results.
        print(f'{tag}: ошибка — {e}')


In [6]:
# Label -> class-index mapping; its size fixes the number of model outputs.
with open(LABELS_JSON, 'r', encoding='utf-8') as f:
    label2idx = json.load(f)
num_classes = len(label2idx)

# Deterministic (unshuffled) evaluation loader.
val_tfm = build_transforms(IMG_SIZE)
val_ds = ISICCsvDataset(VAL_CSV, label2idx=label2idx, tfm=val_tfm)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
    persistent_workers=False,
)

print('=== Оценка CPU‑моделей (accuracy, macro‑F1, ROC‑AUC) ===')

# (printed tag, checkpoint path, loader function) for every model variant,
# in report order; variants whose checkpoint file does not exist are skipped.
EVAL_PLAN = (
    ('Student FP32', CKPT_STUDENT_FP32, load_fp32_model),
    ('Student EMA FP32', CKPT_STUDENT_EMA_FP32, load_fp32_model),
    ('Student INT8 (PTQ)', CKPT_STUDENT_INT8, load_int8_model),
    ('Student EMA INT8 (PTQ)', CKPT_STUDENT_EMA_INT8, load_int8_model),
    ('Student QAT‑lite FP32', CKPT_STUDENT_QATL_FP32, load_fp32_model),
    ('Student QAT‑lite INT8', CKPT_STUDENT_QATL_INT8, load_int8_model),
)

for eval_tag, eval_ckpt, eval_loader_fn in EVAL_PLAN:
    if not os.path.isfile(eval_ckpt):
        continue
    # Default arguments bind this iteration's path/loader into the builder.
    maybe_eval(eval_tag,
               lambda path=eval_ckpt, load=eval_loader_fn: load(path, num_classes),
               val_loader, num_classes)

=== Оценка CPU‑моделей (accuracy, macro‑F1, ROC‑AUC) ===
Student FP32:  acc=0.8362  f1_macro=0.7797  loss=0.5394  roc_auc_macro=0.9626  roc_auc_micro=0.9789
Student EMA FP32:  acc=0.8308  f1_macro=0.7786  loss=0.5445  roc_auc_macro=0.9616  roc_auc_micro=0.9781


  device=storage.device,


Student INT8 (PTQ):  acc=0.8352  f1_macro=0.7764  loss=0.5437  roc_auc_macro=0.9619  roc_auc_micro=0.9786
Student EMA INT8 (PTQ):  acc=0.8308  f1_macro=0.7762  loss=0.5451  roc_auc_macro=0.9617  roc_auc_micro=0.9781
Student QAT‑lite FP32:  acc=0.8253  f1_macro=0.7747  loss=0.5418  roc_auc_macro=0.9621  roc_auc_micro=0.9790
Student QAT‑lite INT8:  acc=0.8243  f1_macro=0.7714  loss=0.5464  roc_auc_macro=0.9618  roc_auc_micro=0.9787
