In [None]:
import albumentations as A
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import timm
import os
import random
from collections import defaultdict
from glob import glob
from tqdm import tqdm
import pandas as pd

In [None]:
# Root folder of the input data. The trailing slash matters: later cells build
# sub-paths by plain string concatenation (e.g. base_path + 'train').
base_path = "/kaggle/input/datasets/iowiqo/lab1vegs/imgs/"

In [None]:
# Important: seed every random number generator the notebook relies on.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(SEED)
del _seed_rng
# Request deterministic cuDNN behaviour.
torch.backends.cudnn.deterministic = True

In [None]:
# Maps a class folder name to its integer class index (0-14); the indices are
# used as training labels when the image list is collected.
# NOTE(review): the keys look like Cyrillic names whose UTF-8 bytes were decoded
# with a single-byte codec. They must match the directory names on disk exactly,
# so confirm the notebook's encoding before relying on them.
class_to_idx = {
  "–ê–ø–µ–ª—å—Å–∏–Ω—ã": 0,
  "–ë–∞–Ω–∞–Ω—ã": 1,
  "–ì—Ä—É—à–∏": 2,
  "–ö–∞–±–∞—á–∫–∏": 3,
  "–ö–∞–ø—É—Å—Ç–∞": 4,
  "–ö–∞—Ä—Ç–æ—Ñ–µ–ª—å": 5,
  "–ö–∏–≤–∏": 6,
  "–õ–∏–º–æ–Ω": 7,
  "–õ—É–∫": 8,
  "–ú–∞–Ω–¥–∞—Ä–∏–Ω—ã": 9,
  "–ú–æ—Ä–∫–æ–≤—å": 10,
  "–û–≥—É—Ä—Ü—ã": 11,
  "–¢–æ–º–∞—Ç—ã": 12,
  "–Ø–±–ª–æ–∫–∏ –∑–µ–ª—ë–Ω—ã–µ": 13,
  "–Ø–±–ª–æ–∫–∏ –∫—Ä–∞—Å–Ω—ã–µ": 14
}

In [None]:
class EMA:
    """Exponential moving average (EMA) of a model's trainable parameters.

    A "shadow" copy of every parameter with ``requires_grad=True`` is kept and
    blended towards the live weights on each :meth:`update`. The averaged
    weights can be copied into the model with :meth:`apply_shadow` (e.g. for
    validation) and the training weights brought back with :meth:`restore`.

    Args:
        model: module whose trainable parameters are tracked.
        decay: weight of the previous average in every update; the live
            parameter contributes ``1 - decay``.
    """

    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}  # parameter name -> averaged tensor
        self.backup = {}  # parameter name -> training weights saved by apply_shadow()
        self.register()

    def register(self):
        """Initialise the shadow weights from the model's current parameters."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    def update(self, model=None):
        """Blend the current parameters of ``model`` (default: the tracked model)
        into the shadow weights: ``shadow = (1 - decay) * param + decay * shadow``.
        """
        if model is None:
            model = self.model
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                current = param.detach()
                average = self.shadow.get(name)
                if average is None:
                    # Parameter became trainable after register(): start tracking
                    # it now instead of failing with a KeyError.
                    self.shadow[name] = current.clone()
                    continue
                # Write into the existing shadow tensor; no extra clone is needed.
                average.copy_((1.0 - self.decay) * current + self.decay * average)

    def apply_shadow(self):
        """Copy the averaged weights into the model, saving the training weights.

        Calling it again before :meth:`restore` is a no-op, so the saved training
        weights can never be overwritten by the averaged ones.
        """
        if self.backup:
            return
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if param.requires_grad and name in self.shadow:
                    self.backup[name] = param.detach().clone()
                    # Copy values instead of aliasing the shadow tensor, so later
                    # in-place changes to the parameter cannot corrupt the average.
                    param.copy_(self.shadow[name])

    def restore(self):
        """Put back the training weights saved by :meth:`apply_shadow`.

        Does nothing when there is no pending backup.
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                saved = self.backup.get(name)
                if saved is not None:
                    param.copy_(saved)
        self.backup = {}

    @property
    def ema(self):
        """The tracked model object itself.

        Note that this is the live model: it carries the averaged weights only
        between :meth:`apply_shadow` and :meth:`restore`.
        """
        return self.model

In [None]:
class MyDataset(Dataset):
    """Image dataset that picks an augmentation pipeline per sample label.

    Args:
        images_filepaths: list of image paths, or of ``(path, label)`` pairs.
        name2label: mapping class-folder name -> label index. Only used for
            entries that are plain paths; the label is then looked up from the
            third path component from the end. Ignored for ``(path, label)`` pairs.
        base_transform: callable applied as ``transform(image=...)['image']`` to
            samples whose label is not in ``weak_classes``.
        strong_transform: same, for samples whose label is in ``weak_classes``.
        weak_classes: labels that receive ``strong_transform``; defaults to none.
    """

    def __init__(self, images_filepaths, name2label,
                 base_transform=None,
                 strong_transform=None,
                 weak_classes=None):
        self.images_filepaths = images_filepaths
        self.name2label = name2label
        self.base_transform = base_transform
        self.strong_transform = strong_transform
        self.weak_classes = weak_classes or []

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises:
            ValueError: if the file cannot be decoded as an image.
        """
        entry = self.images_filepaths[idx]
        if isinstance(entry, (tuple, list)):
            image_filepath, label = entry
        else:
            # Legacy form: path only, label derived from the folder name.
            image_filepath = entry
            label = self.name2label[os.path.normpath(image_filepath).split(os.sep)[-3]]

        # np.fromfile + imdecode is used instead of imread to read the path via numpy.
        # IMREAD_COLOR always yields a 3-channel 8-bit BGR image, which is what
        # the BGR->RGB conversion below expects and what the inference cell uses;
        # IMREAD_UNCHANGED would pass grayscale files through with one channel.
        image = cv2.imdecode(np.fromfile(image_filepath, dtype=np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError(f"Could not decode image: {image_filepath}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Choose the pipeline that matches the sample's class.
        if label in self.weak_classes:
            transform = self.strong_transform
        else:
            transform = self.base_transform

        if transform is not None:
            image = transform(image=image)['image']

        return image, label


def train_test_split_from_directory(root_path, folder2class, train_size=0.8):
    """Split the images under ``root_path`` into train/test lists, class by class.

    Expected layout: ``root_path/<class folder>/<sub-folder>/<image>``; files with
    the extensions ``.jpg``, ``.png`` and ``.jpeg`` are collected.

    Args:
        root_path: directory that contains one folder per class.
        folder2class: mapping folder name -> class name; folders that are not in
            the mapping keep their own name.
        train_size: fraction of every class that goes to the train list.

    Returns:
        ``(train, test)`` - two lists of image paths. Each class is shuffled with
        the global ``random`` module and cut at ``int(train_size * n_images)``.
    """
    class_to_images = defaultdict(list)

    # Directory listings come back in arbitrary order; sorting makes the result
    # of the seeded shuffle below reproducible.
    for folder_name in sorted(os.listdir(root_path)):
        class_path = os.path.join(root_path, folder_name)
        if not os.path.isdir(class_path):
            continue

        # The class is determined by the first-level folder. Resolve it once,
        # into its own variable, so the mapping is never re-applied to an
        # already mapped name while iterating over the sub-folders.
        global_class = folder2class.get(folder_name, folder_name)

        for subclass_name in sorted(os.listdir(class_path)):
            subclass_path = os.path.join(class_path, subclass_name)
            if not os.path.isdir(subclass_path):
                continue

            for pattern in ('*.jpg', '*.png', '*.jpeg'):
                class_to_images[global_class].extend(
                    sorted(glob(os.path.join(subclass_path, pattern)))
                )

    # Split every class separately so both parts keep the class proportions.
    train, test = [], []
    for images in class_to_images.values():
        random.shuffle(images)
        split_idx = int(train_size * len(images))
        train.extend(images[:split_idx])
        test.extend(images[split_idx:])

    return train, test


In [None]:
from sklearn.model_selection import StratifiedKFold
from collections import Counter

# Collect every image of the train folder together with its class index.
all_images = []
all_labels = []

dataset_path = base_path + 'train'
# os.listdir()/glob() return entries in arbitrary order. Sorting keeps the sample
# order - and therefore the seeded StratifiedKFold split - identical between runs.
for class_name in sorted(os.listdir(dataset_path)):
    class_path = os.path.join(dataset_path, class_name)
    if not os.path.isdir(class_path):
        continue
    # The label depends only on the class folder, so resolve it once per folder
    # (the warning is then printed once per unknown class, not once per sub-folder).
    class_idx = class_to_idx.get(class_name)
    if class_idx is None:
        print(f"–ü—Ä–µ–¥—É–ø—Ä–µ–∂–¥–µ–Ω–∏–µ: –Ω–µ–∏–∑–≤–µ—Å—Ç–Ω—ã–π –∫–ª–∞—Å—Å {class_name}")
        continue
    # Images live one level deeper, in sub-folders of the class folder.
    for subclass_name in sorted(os.listdir(class_path)):
        subclass_path = os.path.join(class_path, subclass_name)
        if not os.path.isdir(subclass_path):
            continue
        for pattern in ('*.jpg', '*.png', '*.jpeg'):
            for img_path in sorted(glob(os.path.join(subclass_path, pattern))):
                all_images.append(img_path)
                all_labels.append(class_idx)

print(f"–í—Å–µ–≥–æ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π: {len(all_images)}")
print(f"–†–∞—Å–ø—Ä–µ–¥–µ–ª–µ–Ω–∏–µ –∫–ª–∞—Å—Å–æ–≤: {dict(Counter(all_labels))}")

In [None]:
# Per-class loss weights: inverse class frequency, with an extra boost for the
# classes listed in WEAK_CLASSES.
WEAK_CLASSES = [0, 2, 6, 8, 14]  # class indices to boost (adjust to your own observations)
BOOST_FACTOR = 1.2  # multiplier applied to the weight of every weak class

num_classes = len(class_to_idx)
train_counts_dict = dict(Counter(all_labels))
max_count = max(train_counts_dict.values()) if train_counts_dict else 1

# Base weights: largest class count divided by this class's count (0 if the class is absent).
class_weights = torch.zeros(num_classes)
for class_index in range(num_classes):
    n_samples = train_counts_dict.get(class_index, 0)
    class_weights[class_index] = max_count / n_samples if n_samples > 0 else 0.0

print("Raw class weights (inverse frequency):", class_weights.cpu().numpy())

# Rescale so that the weights sum to num_classes.
class_weights = class_weights / class_weights.sum() * num_classes

for weak_index in WEAK_CLASSES:
    class_weights[weak_index] *= BOOST_FACTOR

# Rescale once more after boosting.
class_weights = class_weights / class_weights.sum() * num_classes

print("Class weights after boosting weak classes:", class_weights.cpu().numpy())

# class_weights is moved to the training device later, once `device` is defined.

# Вернёмся к обучению

### Аугментации

In [None]:
def create_tta_batch(x: torch.Tensor):
    """Build the test-time-augmentation (TTA) views of a batch.

    Views: identity, flip along the last dim, flip along dim 2, 180-degree
    rotation and - for square inputs only - the 90- and 270-degree rotations.
    Every view appears exactly once (a flip over both spatial dims is the same
    transform as the 180-degree rotation, so it is not added a second time).

    Args:
        x: batch of shape (B, C, H, W).

    Returns:
        ``(tta_batch, n_tta)``: all views concatenated along dim 0, view after
        view, with shape (B * n_tta, C, H, W), and the number of views.
    """
    tta_versions = [
        x,
        torch.flip(x, dims=[3]),
        torch.flip(x, dims=[2]),
        torch.rot90(x, k=2, dims=[2, 3]),
    ]
    # Quarter turns swap H and W; they can only be concatenated with the other
    # views when the images are square.
    if x.size(2) == x.size(3):
        tta_versions.append(torch.rot90(x, k=1, dims=[2, 3]))
        tta_versions.append(torch.rot90(x, k=3, dims=[2, 3]))

    tta_batch = torch.cat(tta_versions, dim=0)  # shape: (B * n_tta, C, H, W)
    return tta_batch, len(tta_versions)

def calc_tta_logits(model, X_batch):
    """Average the model's logits over all TTA views of a batch.

    Args:
        model: callable mapping a (N, C, H, W) batch to (N, num_classes) logits.
        X_batch: input batch of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, num_classes): mean of the logits over the views
        produced by ``create_tta_batch``.
    """
    views, n_views = create_tta_batch(X_batch)  # (B * n_views, C, H, W)
    stacked_logits = model(views)                # (B * n_views, num_classes)

    n_samples = X_batch.size(0)
    # The views are stored one after another, so the leading axis splits into
    # (view, sample).
    per_view = stacked_logits.view(n_views, n_samples, -1)
    return per_view.mean(dim=0)


In [None]:
def mixup_data(x, y, alpha=0.2):
    """Create a MixUp batch by blending every sample with a randomly chosen partner.

    Args:
        x: batch of images, shape (B, C, H, W).
        y: batch of labels, shape (B,), or one-hot labels.
        alpha: parameter of the Beta(alpha, alpha) distribution the mixing
            coefficient is drawn from; for ``alpha <= 0`` the coefficient is 1.

    Returns:
        mixed_x: the blended images ``lam * x + (1 - lam) * x[perm]``.
        y_a, y_b: labels of the original and of the partner samples.
        lam: the mixing coefficient.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # Random pairing of the samples inside the batch.
    perm = torch.randperm(x.size(0)).to(x.device)
    partner_x, partner_y = x[perm], y[perm]

    blended = lam * x + (1 - lam) * partner_x
    return blended, y, partner_y, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """MixUp loss: the two label sets' losses blended with the mixing coefficient.

    Args:
        criterion: loss function called as ``criterion(pred, target)``.
        pred: model predictions.
        y_a, y_b: the two label sets returned by ``mixup_data``.
        lam: mixing coefficient; weight of the ``y_a`` loss.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

In [None]:
# ---------- Augmentation pipelines ----------
# Indices (from class_to_idx) of the classes that get the stronger pipeline.
WEAK_CLASSES = [0, 2, 6, 8, 14]

IMAGE_W, IMAGE_H = 320, 320

# Per-channel normalisation statistics shared by all pipelines.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Base training pipeline (every class that is not in WEAK_CLASSES).
# A.Resize takes (height, width); keywords keep the two from being swapped if
# the image ever stops being square.
base_train_transforms = A.Compose([
    A.Resize(height=IMAGE_H, width=IMAGE_W),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15, rotate_limit=30, p=0.5),
    A.GridDropout(ratio=0.2, p=0.2),
    A.Normalize(mean=NORM_MEAN, std=NORM_STD),
    A.ToTensorV2(),
])

# Stronger pipeline for the weak classes: more flips, 90-degree rotations,
# larger shift/scale/rotate limits and more grid dropout.
strong_transforms = A.Compose([
    A.Resize(height=IMAGE_H, width=IMAGE_W),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=45, p=0.7),
    A.GridDropout(ratio=0.25, p=0.25),
    A.Normalize(mean=NORM_MEAN, std=NORM_STD),
    A.ToTensorV2(),
])

# Validation / inference: no augmentation, only Resize + Normalize.
val_transforms = A.Compose([
    A.Resize(height=IMAGE_H, width=IMAGE_W),
    A.Normalize(mean=NORM_MEAN, std=NORM_STD),
    A.ToTensorV2(),
])

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, loss_fn, device, desc="Val"):
    """Evaluate ``model`` on ``dataloader`` using TTA-averaged logits.

    Args:
        model: network to evaluate; it is switched to eval mode.
        dataloader: iterable of ``(inputs, targets)`` batches.
        loss_fn: loss called as ``loss_fn(logits, targets)``.
        device: device the batches are moved to.
        desc: label of the progress bar.

    Returns:
        ``(accuracy, avg_loss)``, both averaged over the samples seen
        (both are 0 when the dataloader yields nothing).
    """
    model.eval()

    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(dataloader, desc=desc, leave=False)
    for inputs, targets in progress:
        inputs, targets = inputs.to(device), targets.to(device)

        logits = calc_tta_logits(model, inputs)
        n_batch = targets.size(0)

        # Weight the batch loss by the batch size so the final figure is a
        # per-sample average even when the last batch is shorter.
        loss_sum += loss_fn(logits, targets).item() * n_batch
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += n_batch

        denom = max(n_seen, 1)
        progress.set_postfix(loss=f"{loss_sum / denom:.4f}", acc=f"{n_correct / denom:.4f}")

    denom = max(n_seen, 1)
    return n_correct / denom, loss_sum / denom

In [None]:
def train(model, loss_fn, optimizer, scheduler, train_loader, val_loader, device,
          ema=None, writer=None, n_epoch=10, save_path="best_model.pth",
          patience=5, mixup_alpha=0.1, mixup_prob=0.5):
    """Train ``model`` with optional MixUp, EMA weights and early stopping.

    Args:
        model: network to train.
        loss_fn: loss called as ``loss_fn(logits, targets)``.
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler stepped once per epoch, or ``None``.
            ``ReduceLROnPlateau`` is stepped with the validation loss.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: validation batches, passed to ``evaluate``.
        device: device the batches are moved to.
        ema: optional EMA tracker with ``update``, ``apply_shadow``, ``restore``
            and an ``ema`` property. When given, validation and the saved
            checkpoint use the averaged weights.
        writer: optional object with ``add_scalar(tag, value, step)``.
        n_epoch: maximum number of epochs.
        save_path: file the best ``state_dict`` is written to.
        patience: stop after this many consecutive epochs without a new best
            validation accuracy.
        mixup_alpha: Beta-distribution parameter passed to ``mixup_data``.
        mixup_prob: probability of applying MixUp to a batch.

    Returns:
        ``model`` with the weights of the last epoch that ran (the best weights
        are only in ``save_path``).
    """
    num_iter = 0
    best_val_acc = 0.0
    patience_counter = 0

    for epoch in range(1, n_epoch + 1):
        model.train()

        total_loss = 0.0
        seen_samples = 0
        clean_correct = 0
        clean_samples = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{n_epoch}", leave=True)

        for X_batch, y_batch in pbar:
            X_batch = X_batch.to(device, non_blocking=True)
            y_batch = y_batch.to(device, non_blocking=True)

            # Decide at random whether this batch is mixed.
            use_mixup = random.random() < mixup_prob

            if use_mixup:
                X_batch, y_a, y_b, lam = mixup_data(X_batch, y_batch, alpha=mixup_alpha)
                logits = model(X_batch)
                loss = mixup_criterion(loss_fn, logits, y_a, y_b, lam)
                # Accuracy is not tracked on mixed batches (labels are blended).
            else:
                logits = model(X_batch)
                loss = loss_fn(logits, y_batch)

                # Accuracy is measured on clean batches only.
                y_pred = logits.argmax(dim=1)
                clean_correct += (y_pred == y_batch).sum().item()
                clean_samples += y_batch.size(0)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            if ema is not None:
                ema.update(model)

            # Running per-sample loss: divide by the samples actually seen, so a
            # shorter last batch does not skew the average.
            batch_size = X_batch.size(0)
            total_loss += loss.item() * batch_size
            seen_samples += batch_size

            avg_loss = total_loss / max(seen_samples, 1)
            clean_acc = clean_correct / max(clean_samples, 1)

            pbar.set_postfix(
                train_loss=f"{avg_loss:.4f}",
                train_acc=f"{clean_acc:.4f}",
                mixup=f"{use_mixup}"
            )

            # Logging
            num_iter += 1
            if writer is not None:
                writer.add_scalar("Loss/train", loss.item(), num_iter)
                if not use_mixup:
                    writer.add_scalar("Accuracy/train_clean", (y_pred == y_batch).float().mean().item(), num_iter)

        # Validation. ema.ema is the tracked model object; the averaged weights
        # are only in effect between apply_shadow() and restore(), so swap them
        # in for the evaluation and the checkpoint, then always swap back.
        val_model = ema.ema if ema is not None else model
        if ema is not None:
            ema.apply_shadow()
        try:
            val_acc, val_loss = evaluate(val_model, val_loader, loss_fn, device, desc=f"Val {epoch}/{n_epoch}")

            # Keep the best model.
            improved = val_acc > best_val_acc
            if improved:
                torch.save(val_model.state_dict(), save_path)
        finally:
            if ema is not None:
                ema.restore()

        if improved:
            best_val_acc = val_acc
            patience_counter = 0
            print(f"üî• New best model saved! val_acc = {val_acc:.4f}")
        else:
            patience_counter += 1

        # Step the scheduler.
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
        # Read the LR from the optimizer: this works for every scheduler type
        # and also when no scheduler is used.
        current_lr = optimizer.param_groups[0]['lr']

        if writer is not None:
            writer.add_scalar("Loss/val", val_loss, num_iter)
            writer.add_scalar("Accuracy/val", val_acc, num_iter)
            writer.add_scalar("LR", current_lr, epoch)

        # Report the epoch before a possible early stop, so the last epoch is
        # not missing from the log.
        print(f"Epoch {epoch}/{n_epoch}: val_loss={val_loss:.4f}  val_acc={val_acc:.4f}  lr={current_lr:.6f}")

        # Early stopping
        if patience_counter >= patience:
            print(f"‚õî Early stopping triggered. Best val_acc = {best_val_acc:.4f}")
            break

    return model

In [None]:
# ============================================
# START OF K-FOLD TRAINING
# ============================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Move the class weights to the training device.
class_weights = class_weights.to(device)

# K-Fold parameters
n_folds = 5
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=SEED)
EPOCHS_PER_FOLD = 30  # can be raised to 50 if time allows
BATCH_SIZE = 32

# Checkpoint path of the best model of every fold.
best_model_paths = []

# Loop over the folds
for fold, (train_idx, val_idx) in enumerate(skf.split(all_images, all_labels)):
    print(f"\n{'='*60}")
    print(f"Fold {fold+1}/{n_folds}")
    print(f"{'='*60}")

    # (path, label) pairs of the current fold
    train_data = [(all_images[i], all_labels[i]) for i in train_idx]
    val_data = [(all_images[i], all_labels[i]) for i in val_idx]

    print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")

    # Datasets: class-dependent augmentation for training, plain preprocessing
    # for validation.
    train_dataset_fold = MyDataset(
        images_filepaths=train_data,
        name2label=None,
        base_transform=base_train_transforms,
        strong_transform=strong_transforms,
        weak_classes=WEAK_CLASSES
    )

    val_dataset_fold = MyDataset(
        images_filepaths=val_data,
        name2label=None,
        base_transform=val_transforms,
        strong_transform=val_transforms
    )

    # Data loaders
    train_loader_fold = DataLoader(
        train_dataset_fold,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True
    )

    val_loader_fold = DataLoader(
        val_dataset_fold,
        batch_size=BATCH_SIZE // 2,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True
    )

    # Fresh model for the fold; the head size follows the class table instead
    # of a hard-coded number.
    model_fold = timm.create_model('tf_efficientnetv2_s', pretrained=True,
                                   num_classes=len(class_to_idx), drop_rate=0.3)
    model_fold.to(device)

    # Loss, optimizer, scheduler, EMA
    loss_fn_fold = torch.nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

    optimizer_fold = torch.optim.AdamW(model_fold.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler_fold = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_fold, T_0=5, T_mult=1)

    ema_fold = EMA(model_fold, decay=0.999)

    # Where the best model of this fold is saved
    fold_save_path = f"best_model_fold{fold}.pth"
    best_model_paths.append(fold_save_path)

    # Training
    train(
        model=model_fold,
        loss_fn=loss_fn_fold,
        optimizer=optimizer_fold,
        scheduler=scheduler_fold,
        train_loader=train_loader_fold,
        val_loader=val_loader_fold,
        device=device,
        ema=ema_fold,
        writer=None,  # a SummaryWriter per fold can be passed here if needed
        n_epoch=EPOCHS_PER_FOLD,
        save_path=fold_save_path
    )

    # Free memory. The optimizer, scheduler and EMA tracker also hold references
    # to the model's parameters, so they have to go as well - otherwise the old
    # weights stay allocated while the next fold's model is being created.
    del ema_fold, scheduler_fold, optimizer_fold, loss_fn_fold
    del model_fold, train_loader_fold, val_loader_fold, train_dataset_fold, val_dataset_fold
    torch.cuda.empty_cache()

print("\n‚úÖ –í—Å–µ —Ñ–æ–ª–¥—ã –æ–±—É—á–µ–Ω—ã! –õ—É—á—à–∏–µ –º–æ–¥–µ–ª–∏ —Å–æ—Ö—Ä–∞–Ω–µ–Ω—ã.")
print("–ü—É—Ç–∏ –∫ –º–æ–¥–µ–ª—è–º:", best_model_paths)

# ============================================
# END OF K-FOLD TRAINING
# ============================================

# Посмотрим на итоговые метрики

In [None]:
import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix

@torch.no_grad()
def sklearn_report(model, dataloader, device, idx2class=None, digits=4):
    """Print scikit-learn's classification report for TTA predictions.

    Args:
        model: network to evaluate; it is switched to eval mode.
        dataloader: iterable of ``(inputs, targets)`` batches.
        device: device the inputs are moved to.
        idx2class: optional mapping class index -> display name. When given, the
            report covers exactly these indices, in sorted order.
        digits: number of decimals in the report.

    Returns:
        None; the report is printed.
    """
    model.eval()

    pred_chunks, true_chunks = [], []
    for inputs, targets in dataloader:
        batch_logits = calc_tta_logits(model, inputs.to(device, non_blocking=True))
        pred_chunks.append(batch_logits.argmax(dim=1).cpu().numpy())
        true_chunks.append(targets.numpy())

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)

    # Row labels / names of the report.
    labels, target_names = None, None
    if idx2class is not None:
        labels = sorted(idx2class.keys())
        target_names = [idx2class[label] for label in labels]

    print(classification_report(
        y_true, y_pred,
        labels=labels,
        target_names=target_names,
        digits=digits,
        zero_division=0
    ))

# Оформляем предсказание

In [None]:
# ============================================
# FINAL PREDICTION WITH THE ENSEMBLE OF FOLD MODELS
# ============================================
test_images_dir = base_path + "test_images"
submission_path = base_path + "sample_submission.csv"
output_path = "/kaggle/working/submission.csv"

print(f'test_images_dir({test_images_dir})')
print(f'submission_path({submission_path})')

submission = pd.read_csv(submission_path)
image_ids = submission["image_id"].tolist()

# Images per forward pass (TTA multiplies the effective batch size).
test_batch_size = 32

# Logits of every fold model, one (n_images, n_classes) tensor per fold.
all_fold_logits = []

for fold, model_path in enumerate(best_model_paths):
    print(f"\nüì¶ –ó–∞–≥—Ä—É–∑–∫–∞ –º–æ–¥–µ–ª–∏ —Ñ–æ–ª–¥–∞ {fold} –∏–∑ {model_path}")

    # Same architecture as in training; the head size follows the class table
    # instead of a hard-coded number.
    model_fold = timm.create_model('tf_efficientnetv2_s', pretrained=False,
                                   num_classes=len(class_to_idx))
    model_fold.load_state_dict(torch.load(model_path, map_location='cpu'))
    model_fold.eval().to(device)

    fold_logits = []

    with torch.no_grad():
        for start_idx in tqdm(range(0, len(image_ids), test_batch_size), desc=f"Predicting fold {fold}"):
            batch_ids = image_ids[start_idx:start_idx + test_batch_size]
            images = []

            for image_id in batch_ids:
                image_path = os.path.join(test_images_dir, image_id)
                image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
                if image is None:
                    # imdecode signals failure with None; fail with the file
                    # name instead of an opaque cvtColor error.
                    raise ValueError(f"Could not decode test image: {image_path}")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = val_transforms(image=image)["image"]
                images.append(image)

            X_batch = torch.stack(images).to(device, non_blocking=True)
            logits = calc_tta_logits(model_fold, X_batch)  # TTA-averaged logits
            fold_logits.append(logits.cpu())

    fold_logits = torch.cat(fold_logits, dim=0)
    all_fold_logits.append(fold_logits)

    # Free memory before the next fold model is loaded.
    del model_fold
    torch.cuda.empty_cache()

print("\nüîÑ –£—Å—Ä–µ–¥–Ω–µ–Ω–∏–µ –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π –ø–æ –≤—Å–µ–º —Ñ–æ–ª–¥–∞–º...")
# Ensemble: mean of the fold logits, then the most likely class per image.
mean_logits = torch.mean(torch.stack(all_fold_logits), dim=0)
pred_labels = mean_logits.argmax(dim=1).tolist()

submission["label"] = pred_labels
submission.to_csv(output_path, index=False)

print(f"‚úÖ –°–∞–±–º–∏—Ç —Å–æ—Ö—Ä–∞–Ω—ë–Ω –≤ {output_path}")
print("\n–ü–µ—Ä–≤—ã–µ 5 —Å—Ç—Ä–æ–∫ –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π:")
submission.head()