In [11]:
import os
import json
import pandas as pd
import random
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms as T
import timm

from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import csv

SEED = 42
ISIC_ROOT = os.path.join('data', 'ISIC')
TRAIN_CSV = os.path.join(ISIC_ROOT, 'split_train.csv')
VAL_CSV = os.path.join(ISIC_ROOT, 'split_test.csv')
LABELS_JSON = os.path.join(ISIC_ROOT, 'labels.json')
CKPT_PATH = os.path.join('data/model_weights', 'deit_s_best.pth')

LOG_DIR = os.path.join(ISIC_ROOT, 'runs', 'deit_s')
CSV_LOG = os.path.join(ISIC_ROOT, 'train_log.csv')
os.makedirs(LOG_DIR, exist_ok=True)

IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 20
BASE_LR = 5e-4
WEIGHT_DECAY = 0.05
NUM_WORKERS = 0

In [12]:
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


class ISICCsvDataset(Dataset):
    def __init__(self, csv_path: str, label2idx=None, tfm=None):
        df = pd.read_csv(csv_path)
        if 'path' not in df.columns or 'label' not in df.columns:
            raise ValueError("CSV должен содержать столбцы 'path' и 'label'")
        self.paths = df['path'].tolist()
        self.labels = df['label'].tolist()
        if label2idx is None:
            uniq = sorted(pd.unique(self.labels).tolist())
            label2idx = {lbl: i for i, lbl in enumerate(uniq)}
        self.label2idx = label2idx
        self.tfm = tfm

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        y_str = self.labels[idx]
        y = self.label2idx[y_str]
        with Image.open(p) as img:
            img = img.convert('RGB')
        if self.tfm:
            img = self.tfm(img)
        return img, y


def build_transforms(img_size=224):
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    train_tfm = T.Compose([
        T.RandomResizedCrop(img_size, scale=(0.7, 1.0), ratio=(0.75, 1.33)),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(0.2, 0.2, 0.2, 0.1),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    val_tfm = T.Compose([
        T.Resize(int(img_size * 1.14)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    return train_tfm, val_tfm


def make_weighted_sampler(labels_idx, num_classes):
    counts = np.bincount(labels_idx, minlength=num_classes).astype(np.float32)
    inv = 1.0 / np.maximum(counts, 1.0)
    weights = [inv[y] for y in labels_idx]
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True), counts


@torch.no_grad()
def evaluate(model, loader, device, criterion):
    model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        loss = criterion(logits, y)
        total_loss += loss.item() * y.size(0)
        preds = logits.argmax(1)
        total_correct += (preds == y).sum().item()
        total += y.size(0)
    return total_loss / max(total, 1), total_correct / max(total, 1)


set_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

train_tfm, val_tfm = build_transforms(IMG_SIZE)

tmp_train = pd.read_csv(TRAIN_CSV)
uniq_labels = sorted(pd.unique(tmp_train['label']).tolist())
label2idx = {lbl: i for i, lbl in enumerate(uniq_labels)}
with open(LABELS_JSON, 'w', encoding='utf-8') as f:
    json.dump(label2idx, f, ensure_ascii=False, indent=2)

train_ds = ISICCsvDataset(TRAIN_CSV, label2idx=label2idx, tfm=train_tfm)
val_ds = ISICCsvDataset(VAL_CSV, label2idx=label2idx, tfm=val_tfm)
num_classes = len(label2idx)
print(f'Классов: {num_classes} -> {uniq_labels}')

train_labels_idx = [label2idx[l] for l in tmp_train['label'].tolist()]
sampler, class_counts = make_weighted_sampler(train_labels_idx, num_classes)
print('Распределение классов (train):', class_counts.tolist())

Device: cuda
Классов: 7 -> ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
Распределение классов (train): [262.0, 411.0, 879.0, 92.0, 890.0, 5364.0, 114.0]


In [13]:
pin = device.type == 'cuda'
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, sampler=sampler,
    num_workers=NUM_WORKERS, pin_memory=pin, persistent_workers=False
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=pin, persistent_workers=False
)

model = timm.create_model('deit_small_patch16_224', pretrained=True, num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))

writer = SummaryWriter(LOG_DIR)
best_acc = 0.0
global_step = 0

if not os.path.exists(CSV_LOG):
    with open(CSV_LOG, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['epoch', 'train_loss', 'val_loss', 'val_acc', 'lr'])


In [14]:
best_acc = 0.0
global_step = 0

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0
    running_correct = 0
    seen = 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{EPOCHS}', leave=False)
    for x, y in pbar:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
            logits = model(x)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # accumulate loss
        bs = y.size(0)
        running_loss += loss.item() * bs
        seen += bs

        # batch accuracy
        preds = logits.argmax(1)
        correct = (preds == y).sum().item()
        running_correct += correct
        batch_acc = correct / bs

        # log batch metrics
        writer.add_scalar('train/batch_loss', loss.item(), global_step)
        writer.add_scalar('train/batch_acc', batch_acc, global_step)
        writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], global_step)
        pbar.set_postfix(batch_loss=f'{loss.item():.4f}',
                         batch_acc=f'{batch_acc:.4f}',
                         lr=f"{optimizer.param_groups[0]['lr']:.2e}")
        global_step += 1

    scheduler.step()
    train_loss = running_loss / max(seen, 1)
    train_acc = running_correct / max(seen, 1)

    # validation
    val_loss, val_acc = evaluate(model, val_loader, device, criterion)

    # log epoch metrics
    writer.add_scalar('train/epoch_loss', train_loss, epoch)
    writer.add_scalar('train/epoch_acc', train_acc, epoch)
    writer.add_scalar('val/loss', val_loss, epoch)
    writer.add_scalar('val/acc', val_acc, epoch)

    print(f'Epoch {epoch:03d} | '
          f'train_loss={train_loss:.4f}  train_acc={train_acc:.4f}  '
          f'val_loss={val_loss:.4f}  val_acc={val_acc:.4f}  '
          f'lr={scheduler.get_last_lr()[0]:.2e}')

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            'model': model.state_dict(),
            'label2idx': label2idx,
            'epoch': epoch,
            'val_acc': val_acc
        }, CKPT_PATH)
        print(f'Save best weights -> {CKPT_PATH}')

print(f'Best val_acc: {best_acc:.4f}')
writer.close()

                                                                                                               

Epoch 001 | train_loss=1.4073  train_acc=0.4472  val_loss=1.0675  val_acc=0.5632  lr=4.97e-04
Save best weights -> data/model_weights\deit_s_best.pth


                                                                                                               

Epoch 002 | train_loss=1.0272  train_acc=0.6090  val_loss=0.7592  val_acc=0.7284  lr=4.88e-04
Save best weights -> data/model_weights\deit_s_best.pth


                                                                                                               

Epoch 003 | train_loss=0.8149  train_acc=0.6935  val_loss=0.8974  val_acc=0.6500  lr=4.73e-04


                                                                                                               

Epoch 004 | train_loss=0.7501  train_acc=0.7169  val_loss=0.7428  val_acc=0.7424  lr=4.52e-04
Save best weights -> data/model_weights\deit_s_best.pth


                                                                                                               

Epoch 005 | train_loss=0.6669  train_acc=0.7469  val_loss=0.9142  val_acc=0.6191  lr=4.27e-04


                                                                                                               

Epoch 006 | train_loss=0.5994  train_acc=0.7733  val_loss=0.6421  val_acc=0.7439  lr=3.97e-04
Save best weights -> data/model_weights\deit_s_best.pth


                                                                                                               

Epoch 007 | train_loss=0.5643  train_acc=0.7820  val_loss=0.8439  val_acc=0.6520  lr=3.63e-04


                                                                                                               

Epoch 008 | train_loss=0.4956  train_acc=0.8142  val_loss=0.9573  val_acc=0.6405  lr=3.27e-04


                                                                                                               

Epoch 009 | train_loss=0.4381  train_acc=0.8339  val_loss=0.6943  val_acc=0.7269  lr=2.89e-04


                                                                                                                

Epoch 010 | train_loss=0.3965  train_acc=0.8474  val_loss=0.8342  val_acc=0.6715  lr=2.50e-04


                                                                                                                

Epoch 011 | train_loss=0.3626  train_acc=0.8640  val_loss=0.7037  val_acc=0.7504  lr=2.11e-04
Save best weights -> data/model_weights\deit_s_best.pth


                                                                                                                

Epoch 012 | train_loss=0.3046  train_acc=0.8868  val_loss=0.8006  val_acc=0.7149  lr=1.73e-04


                                                                                                                

Epoch 013 | train_loss=0.2497  train_acc=0.9031  val_loss=0.6754  val_acc=0.7678  lr=1.37e-04
Save best weights -> data/model_weights\deit_s_best.pth


                                                                                                                

Epoch 014 | train_loss=0.2040  train_acc=0.9237  val_loss=0.6764  val_acc=0.7693  lr=1.03e-04
Save best weights -> data/model_weights\deit_s_best.pth


                                                                                                                

Epoch 015 | train_loss=0.1792  train_acc=0.9321  val_loss=0.6349  val_acc=0.7888  lr=7.32e-05
Save best weights -> data/model_weights\deit_s_best.pth


                                                                                                                

Epoch 016 | train_loss=0.1459  train_acc=0.9457  val_loss=0.6102  val_acc=0.8168  lr=4.77e-05
Save best weights -> data/model_weights\deit_s_best.pth


                                                                                                                

Epoch 017 | train_loss=0.1228  train_acc=0.9546  val_loss=0.6519  val_acc=0.8058  lr=2.72e-05


                                                                                                                

Epoch 018 | train_loss=0.1068  train_acc=0.9612  val_loss=0.6532  val_acc=0.7968  lr=1.22e-05


                                                                                                                

Epoch 019 | train_loss=0.0866  train_acc=0.9674  val_loss=0.6504  val_acc=0.8008  lr=3.08e-06


                                                                                                                

Epoch 020 | train_loss=0.0838  train_acc=0.9703  val_loss=0.6575  val_acc=0.7973  lr=0.00e+00
Best val_acc: 0.8168
