In [2]:
import os, copy, gc
import numpy as np
import pandas as pd
from PIL import Image
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision.transforms import InterpolationMode

import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import f1_score
from tqdm import tqdm
from torch.optim.lr_scheduler import OneCycleLR
from torch.amp import autocast, GradScaler


In [3]:
class ImageDataset(Dataset):
    def __init__(self, img_dir, csv_path, transform=None, has_labels=True):
        self.df = pd.read_csv(csv_path)
        self.img_dir = img_dir
        self.transform = transform
        self.has_labels = has_labels
        self.labels_dict = { 
            'Ink scenery': 0, 'comic': 1, 'cyberpunk': 2, 'futuristic UI': 3, 'lowpoly': 4, 'oil painting': 5, 
            'pixel': 6, 'realistic': 7, 'steampunk': 8, 'water color': 9, 'UNK': 10
        }

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        uuid = row['uuid']
        img_path = os.path.join(self.img_dir, uuid + '.png')
        image = np.array(Image.open(img_path).convert("RGB"))
        if self.transform:
            image = self.transform(image=image)['image']
        if self.has_labels:
            style = row['style']
            label = self.labels_dict[style] if style in self.labels_dict else 10
            return image, label
        else:
            return image, uuid


In [4]:
means = [0.485, 0.456, 0.406]
stds  = [0.229, 0.224, 0.225]
transform_train = A.Compose([
    A.RandomResizedCrop(size=(288, 288), scale=(0.85, 1.0), ratio=(0.9, 1.11), p=1.0),
    A.HorizontalFlip(p=0.5),
    A.ImageCompression(compression_type='jpeg', quality_range=(70, 100), p=0.3),
    A.RandomBrightnessContrast(0.15, 0.15, p=0.3),
    A.Normalize(mean=means, std=stds),
    ToTensorV2()
])
transform_eval = A.Compose([
    A.Resize(288, 288),
    A.Normalize(mean=means, std=stds),
    ToTensorV2()
])


In [5]:
DIR_TRAIN = '../train'
DIR_VALID = '../valid'
CSV_TRAIN = '../train.csv'
CSV_VALID = '../valid.csv'
BATCH_SIZE = 16

train_dataset = ImageDataset(DIR_TRAIN, CSV_TRAIN, transform=transform_train, has_labels=True)
valid_dataset = ImageDataset(DIR_VALID, CSV_VALID, transform=transform_eval, has_labels=True)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

# Class weight (for FocalLoss)
class_counts = [875, 869, 753, 684, 557, 473, 376, 290, 193, 106]
counts = torch.tensor(class_counts + [1], dtype=torch.float32)
weights = 1.0 / counts
weights = weights / weights.sum()
weights[-1] = 0.0



In [6]:
def get_model():
    model = models.efficientnet_b4(weights='IMAGENET1K_V1')
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, 11)
    return model

model = get_model()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
print(device)
print(model)


cuda:0
EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
            (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(48, 12, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(12, 48, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNo

In [7]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=1.0, reduction='mean', ignore_index=10):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index
    def forward(self, inputs, targets):
        mask = (targets != self.ignore_index)
        if mask.sum() == 0:
            return torch.tensor(0.0, device=inputs.device, requires_grad=True)
        inputs, targets = inputs[mask], targets[mask]
        ce_loss = F.cross_entropy(inputs, targets, weight=self.alpha, reduction='none')
        pt = torch.exp(-ce_loss)
        loss = (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


In [8]:
weights = weights.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
# criterion = FocalLoss(alpha=class_weights.to(device), gamma=1.0, ignore_index=10)
criterion = FocalLoss(alpha=weights,gamma=1.0,ignore_index=10)
scaler = GradScaler()
scheduler = OneCycleLR(optimizer, max_lr=1e-3, steps_per_epoch=len(train_loader), epochs=15, pct_start=0.3)


In [9]:
num_epochs = 15
best_f1 = -1
counter = 0
patience = 4
min_delta = 1e-4
best_thr = 0.5
best_model_wts = copy.deepcopy(model.state_dict())
for epoch in range(num_epochs):
    # TRAIN
    model.train()
    for images, labels in tqdm(train_loader, desc=f'Train {epoch+1}'):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast(device_type='cuda'):
            out = model(images)
            loss = criterion(out, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
    # VALID + best-threshold
    model.eval()
    all_confs, all_preds0, all_labels = [], [], []
    with torch.no_grad():
        for images, labels in tqdm(valid_loader, desc=f'Valid {epoch+1}'):
            images, labels = images.to(device), labels.to(device)
            with autocast(device_type='cuda'):
                logits = model(images)
            probs = F.softmax(logits, 1)
            confs0, preds0 = probs.max(1)
            all_confs.extend(confs0.cpu().numpy())
            all_preds0.extend(preds0.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    all_confs = np.array(all_confs); all_preds0 = np.array(all_preds0); all_labels = np.array(all_labels)
    thr_list = np.linspace(0.3, 0.8, 51)
    best_f1_ep, best_thr_ep = -1, 0.5
    for t in thr_list:
        preds = np.where(all_confs < t, 10, all_preds0)
        f1 = f1_score(all_labels, preds, average='macro', labels=list(range(11)))
        if f1 > best_f1_ep: best_f1_ep, best_thr_ep = f1, t
    print(f"Epoch {epoch+1} — Best valid F1={best_f1_ep:.4f} (thr={best_thr_ep:.2f})")
    # Early stopping
    if best_f1_ep > best_f1 + min_delta:
        best_f1, best_thr = best_f1_ep, best_thr_ep
        best_model_wts = copy.deepcopy(model.state_dict())
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stop at {epoch+1} — Best F1={best_f1:.4f} (thr={best_thr:.2f})")
            break
# Restore best
model.load_state_dict(best_model_wts)


Train 1: 100%|██████████| 324/324 [01:50<00:00,  2.92it/s]
Valid 1: 100%|██████████| 82/82 [00:19<00:00,  4.31it/s]


Epoch 1 — Best valid F1=0.1162 (thr=0.42)


Train 2: 100%|██████████| 324/324 [01:32<00:00,  3.51it/s]
Valid 2: 100%|██████████| 82/82 [00:15<00:00,  5.46it/s]


Epoch 2 — Best valid F1=0.6016 (thr=0.30)


Train 3: 100%|██████████| 324/324 [01:32<00:00,  3.51it/s]
Valid 3: 100%|██████████| 82/82 [00:13<00:00,  5.90it/s]


Epoch 3 — Best valid F1=0.6631 (thr=0.31)


Train 4: 100%|██████████| 324/324 [01:31<00:00,  3.54it/s]
Valid 4: 100%|██████████| 82/82 [00:14<00:00,  5.69it/s]


Epoch 4 — Best valid F1=0.7544 (thr=0.30)


Train 5: 100%|██████████| 324/324 [01:31<00:00,  3.56it/s]
Valid 5: 100%|██████████| 82/82 [00:14<00:00,  5.82it/s]


Epoch 5 — Best valid F1=0.7601 (thr=0.46)


Train 6: 100%|██████████| 324/324 [01:32<00:00,  3.50it/s]
Valid 6: 100%|██████████| 82/82 [00:15<00:00,  5.46it/s]


Epoch 6 — Best valid F1=0.7778 (thr=0.41)


Train 7: 100%|██████████| 324/324 [01:33<00:00,  3.48it/s]
Valid 7: 100%|██████████| 82/82 [00:15<00:00,  5.33it/s]


Epoch 7 — Best valid F1=0.7852 (thr=0.30)


Train 8: 100%|██████████| 324/324 [01:32<00:00,  3.50it/s]
Valid 8: 100%|██████████| 82/82 [00:14<00:00,  5.76it/s]


Epoch 8 — Best valid F1=0.7860 (thr=0.38)


Train 9: 100%|██████████| 324/324 [01:31<00:00,  3.54it/s]
Valid 9: 100%|██████████| 82/82 [00:14<00:00,  5.74it/s]


Epoch 9 — Best valid F1=0.8019 (thr=0.48)


Train 10: 100%|██████████| 324/324 [01:31<00:00,  3.53it/s]
Valid 10: 100%|██████████| 82/82 [00:14<00:00,  5.49it/s]


Epoch 10 — Best valid F1=0.8002 (thr=0.52)


Train 11: 100%|██████████| 324/324 [01:32<00:00,  3.50it/s]
Valid 11: 100%|██████████| 82/82 [00:14<00:00,  5.55it/s]


Epoch 11 — Best valid F1=0.8090 (thr=0.48)


Train 12: 100%|██████████| 324/324 [01:34<00:00,  3.43it/s]
Valid 12: 100%|██████████| 82/82 [00:13<00:00,  5.88it/s]


Epoch 12 — Best valid F1=0.8116 (thr=0.36)


Train 13: 100%|██████████| 324/324 [01:39<00:00,  3.25it/s]
Valid 13: 100%|██████████| 82/82 [00:15<00:00,  5.43it/s]


Epoch 13 — Best valid F1=0.8148 (thr=0.47)


Train 14: 100%|██████████| 324/324 [01:31<00:00,  3.56it/s]
Valid 14: 100%|██████████| 82/82 [00:15<00:00,  5.35it/s]


Epoch 14 — Best valid F1=0.8133 (thr=0.59)


Train 15: 100%|██████████| 324/324 [01:30<00:00,  3.57it/s]
Valid 15: 100%|██████████| 82/82 [00:13<00:00,  5.90it/s]


Epoch 15 — Best valid F1=0.8109 (thr=0.55)


<All keys matched successfully>

In [10]:
torch.save(model.state_dict(), "../go/b4_letsgo.pth")

In [16]:
my_dict = {0:'Ink scenery',
           1:'comic',
           2:'cyberpunk',
           3:'futuristic UI',
           4:'lowpoly',
           5:'oil painting',
           6:'pixel',
           7:'realistic',
           8:'steampunk',
           9:'water color',
           10:'UNK'}
test_dataset = ImageDataset('../test', '../sub_dir/submission01.csv', transform=transform_eval, has_labels=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)
uuid_list, pred_list = [], []
model.eval()
with torch.no_grad():
    for images, uuids in tqdm(test_loader, desc="Test"):
        images = images.to(device)
        with autocast(device_type='cuda'):
            logits = model(images)
        probs = F.softmax(logits, 1)
        confs, preds0 = probs.max(1)
        preds = torch.where(confs < 0.6, torch.full_like(preds0, 10), preds0)
        uuid_list.extend(uuids)
        pred_list.extend([my_dict[i.item()] for i in preds])
df_sub = pd.DataFrame({'uuid': uuid_list, 'style': pred_list})
df_sub.to_csv('../sub_dir/submission_1111.csv', index=False)


Test: 100%|██████████| 45/45 [00:14<00:00,  3.11it/s]
