In [1]:
# python
import os
import io
import copy
import json
import math
import torch
import torch.nn as nn
import pandas as pd
import timm

# ---------- Paths / checkpoints ----------
# Every input and output of this cell lives under the local 'data' directory.
ISIC_ROOT = os.path.join('data', 'ISIC')
WEIGHTS_DIR = os.path.join('data', 'model_weights')
RESULTS_DIR = os.path.join('data', 'results')
# Make sure the results directory exists before anything is written into it.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Table-row label -> path of the matching checkpoint file inside WEIGHTS_DIR.
CKPTS = {
    label: os.path.join(WEIGHTS_DIR, filename)
    for label, filename in (
        ('DeiT-T FP32 (KD)', 'student_best.pth'),
        ('DeiT-T FP32 (KD+EMA)', 'student_best_ema.pth'),
        ('DeiT-T INT8 (PTQ)', 'student_best_int8.pth'),
        ('DeiT-T INT8 (KD+EMA+PTQ)', 'student_best_ema_int8.pth'),
        ('DeiT-T QAT-lite FP32', 'student_best_qatlite.pth'),
        ('DeiT-T QAT-lite INT8', 'student_best_qatlite_int8.pth'),
    )
}

# ---------- Хелперы ----------
def state_dict_size_mb(state_dict: dict) -> float:
    """Return the size, in mebibytes, of ``state_dict`` as written by ``torch.save``.

    The object is serialized into an in-memory buffer (nothing is written to
    disk) and the number of bytes written is converted to MiB.
    """
    bytes_per_mib = float(1 << 20)
    with io.BytesIO() as sink:
        torch.save(state_dict, sink)
        written = sink.tell()  # stream position after the save == bytes written
    return written / bytes_per_mib

def count_params(m: nn.Module) -> int:
    """Return the total number of elements over everything in ``m.parameters()``."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total

def count_flops(model: nn.Module, img_size=224):
    """Estimate the FLOPs of ``model`` for one 3-channel square input via ptflops.

    Args:
        model: Network to measure. A deep copy is switched to eval mode and
            moved to CPU, so the caller's instance is left untouched.
        img_size: Height and width of the ``(3, img_size, img_size)`` input
            shape handed to ptflops.

    Returns:
        FLOPs as an ``int`` (twice the MAC count reported by ptflops), or
        ``None`` when ptflops cannot be imported or reports no MAC count.
    """
    try:
        from ptflops import get_model_complexity_info  # type: ignore
    except Exception:
        # Deliberately broad: ptflops is an optional dependency, and any
        # import-time failure simply means "no FLOPs figure available".
        return None
    m_eval = copy.deepcopy(model).eval().cpu()
    with torch.no_grad():
        macs, _ = get_model_complexity_info(
            m_eval, (3, img_size, img_size),
            as_strings=False, print_per_layer_stat=False
        )
    if macs is None:
        # ptflops may hand back None when its estimate fails; return the same
        # "unavailable" value instead of raising TypeError on None * 2.
        return None
    return int(macs * 2)  # FLOPs ~= 2 * MACs

def to_int8_dynamic(fp32_model: nn.Module) -> nn.Module:
    """Return a dynamically quantized (``torch.qint8``) copy of ``fp32_model``.

    Only ``nn.Linear`` modules are selected for quantization. The input model
    is deep-copied first, so the quantization call never sees the original.
    """
    model_copy = copy.deepcopy(fp32_model)
    return torch.ao.quantization.quantize_dynamic(
        model_copy,
        qconfig_spec={nn.Linear},
        dtype=torch.qint8,
    )

# ---------- Classes / number of classes ----------
# labels.json is parsed as-is; the number of top-level entries is used as the
# class count (presumably a label -> index mapping, judging by the name).
labels_path = os.path.join(ISIC_ROOT, 'labels.json')
with open(labels_path, mode='r', encoding='utf-8') as labels_file:
    label2idx = json.load(labels_file)
num_classes = len(label2idx)

# ---------- Reference FP32 architectures (for Params/FLOPs) ----------
# Built without pretrained weights and with a head sized for num_classes.
archs = {
    arch_label: timm.create_model(timm_name, pretrained=False, num_classes=num_classes)
    for arch_label, timm_name in (
        ('DeiT-T', 'deit_tiny_patch16_224'),
        ('ResNet-18', 'resnet18'),
        ('MobileNetV3-L', 'mobilenetv3_large_100'),
        ('ConvNeXt-T', 'convnext_tiny'),
    )
}

# Params/FLOPs are computed once per FP32 architecture.
arch_stats = {}
for arch_label, arch_model in archs.items():
    n_params = count_params(arch_model)
    n_flops = count_flops(arch_model, img_size=224)  # may be None (e.g. ptflops missing)
    if n_flops is None:
        flops_g = None
    else:
        flops_g = n_flops / 1e9
    arch_stats[arch_label] = {
        'params_m': n_params / 1e6,  # millions of parameters
        'flops_g': flops_g,          # GFLOPs, or None when unavailable
    }

# ---------- Table rows ----------
# Each entry is (display name, architecture key, size source). The size source
# is one of 'ckpt' / 'ckpt_int8' (measure a saved checkpoint) or
# 'fp32_live' / 'int8_live' (measure a freshly built model).
rows_def = [
    # DeiT-T student variants: sizes come from saved checkpoints.
    *(
        (label, 'DeiT-T', source)
        for label, source in (
            ('DeiT-T FP32 (KD)', 'ckpt'),
            ('DeiT-T FP32 (KD+EMA)', 'ckpt'),
            ('DeiT-T INT8 (PTQ)', 'ckpt_int8'),
            ('DeiT-T INT8 (KD+EMA+PTQ)', 'ckpt_int8'),
            ('DeiT-T QAT-lite FP32', 'ckpt'),
            ('DeiT-T QAT-lite INT8', 'ckpt_int8'),
        )
    ),
    # Baselines: one FP32 row and one INT8 (PTQ) row each, built on the fly.
    *(
        row
        for arch in ('ResNet-18', 'MobileNetV3-L', 'ConvNeXt-T')
        for row in (
            (f'{arch} FP32', arch, 'fp32_live'),
            (f'{arch} INT8 (PTQ)', arch, 'int8_live'),
        )
    ),
]

# ---------- Подсчёт Size для каждой строки ----------
def size_mb_for_row(row_name: str, arch_key: str, how: str) -> float:
    """Return the serialized size in MB for one table row.

    Args:
        row_name: Display name of the row; used as the key into ``CKPTS``.
        arch_key: One of 'DeiT-T', 'ResNet-18', 'MobileNetV3-L', 'ConvNeXt-T'.
        how: Where the size comes from. Anything starting with 'ckpt' measures
            the saved checkpoint if the file exists; otherwise 'ckpt_int8'
            falls back to 'int8_live' and every other 'ckpt*' value to
            'fp32_live'. 'fp32_live' measures a freshly built timm model,
            'int8_live' measures its dynamically quantized copy.

    Returns:
        Size of the selected state dict as reported by ``state_dict_size_mb``.

    Raises:
        ValueError: If ``how`` is not a recognised mode, or a live model is
            needed and ``arch_key`` is unknown.
    """
    timm_names = {
        'DeiT-T': 'deit_tiny_patch16_224',
        'ResNet-18': 'resnet18',
        'MobileNetV3-L': 'mobilenetv3_large_100',
        'ConvNeXt-T': 'convnext_tiny',
    }

    # 1) Prefer the saved checkpoint when one exists for this row.
    if how.startswith('ckpt'):
        path = CKPTS.get(row_name, '')
        if os.path.isfile(path):
            loaded = torch.load(path, map_location='cpu')
            # A checkpoint may wrap the weights under a 'model' key.
            if isinstance(loaded, dict) and 'model' in loaded:
                loaded = loaded['model']
            return state_dict_size_mb(loaded)
        # No file on disk: fall back to measuring a live model instead.
        how = 'int8_live' if how == 'ckpt_int8' else 'fp32_live'

    # 2) Live models: build the architecture and measure its state_dict().
    if how not in ('fp32_live', 'int8_live'):
        raise ValueError(f'Unknown how={how}')
    if arch_key not in timm_names:
        raise ValueError(f'Unknown arch_key: {arch_key}')

    net = timm.create_model(timm_names[arch_key], pretrained=False, num_classes=num_classes)
    if how == 'int8_live':
        net = to_int8_dynamic(net)
    return state_dict_size_mb(net.state_dict())

# Assemble the table rows
out_rows = []
for row_label, row_arch, row_how in rows_def:
    size_mb = size_mb_for_row(row_label, row_arch, row_how)
    arch_info = arch_stats[row_arch]
    flops_g = arch_info['flops_g']
    out_rows.append({
        'Model': row_label,
        'Params_M': round(arch_info['params_m'], 2),
        # FLOPs stay None (an empty CSV cell) when no estimate was available.
        'FLOPs_G@224': None if flops_g is None else round(flops_g, 2),
        'Size_MB': round(size_mb, 2),
    })

# ---------- Save CSV ----------
csv_path = os.path.join(RESULTS_DIR, 'models_desc.csv')
table = pd.DataFrame(out_rows)
table.to_csv(csv_path, index=False)
print(f'CSV -> {csv_path}')


  from .autonotebook import tqdm as notebook_tqdm
  device=storage.device,


CSV -> data\results\models_desc.csv
