In [29]:
!pip install optuna
!pip install optuna-integration 'pytorch_lightning'



In [30]:
import os
import random
import numpy as np
import pandas as pd
import optuna
from optuna.integration import PyTorchLightningPruningCallback  # optional, not required

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from datasets import load_from_disk
from PIL import Image

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

Using device: cpu


In [31]:
RANDOM_SEED = 42
BATCH_SIZE = 32
EPOCHS = 25
TUNE_EPOCHS = 5      # for Bayesian optimisation
IMG_SIZE = 224
PATCH_SIZE = 16
NUM_CLASSES = 2 

TUNE_EPOCHS = 5

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

Right now, your functions use global state (model, optimizer, LAMBDA_ATTR, train_loader, val_loader, etc.).
We can exploit that and just re-assign the globals inside the Optuna objective to avoid refactoring everything.

Idea: for each trial we:
* create a new model with its own drop
* create a new optimizer (lr, weight_decay)
* set LAMBDA_ATTR
* train for TUNE_EPOCHS
* return best validation accuracy seen

In [32]:
TRAIN_VAL_PATH = "processed_bird_data"
TEST_PATH = "processed_bird_test_data"

print("Loading train/val dataset from:", TRAIN_VAL_PATH)
full_ds = load_from_disk(TRAIN_VAL_PATH)
train_hf = full_ds["train"]
val_hf = full_ds["validation"]

print("Train size:", len(train_hf))
print("Val size:", len(val_hf))

print("\nLoading test dataset from:", TEST_PATH)
test_hf = load_from_disk(TEST_PATH)
print("Test size:", len(test_hf))

# attributes
ATTR_PATH = "data/attributes.npy"
attributes = np.load(ATTR_PATH)
NUM_CLASSES = attributes.shape[0]
NUM_ATTR = attributes.shape[1]

print("\nAttributes shape:", attributes.shape)
print("NUM_CLASSES:", NUM_CLASSES, "| NUM_ATTR:", NUM_ATTR)

Loading train/val dataset from: processed_bird_data
Train size: 3337
Val size: 589

Loading test dataset from: processed_bird_test_data
Test size: 4000

Attributes shape: (200, 312)
NUM_CLASSES: 200 | NUM_ATTR: 312


In [33]:
from torchvision.transforms import InterpolationMode

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=25, interpolation=InterpolationMode.BILINEAR),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

In [34]:
class BirdTrainDataset(Dataset):
    def __init__(self, hf_dataset, attributes, transform=None):
        self.ds = hf_dataset
        self.attributes = attributes.astype("float32")
        self.transform = transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        img = item["image"]
        if isinstance(img, Image.Image):
            img = img.convert("RGB")
        label = int(item["label"]) 

        if self.transform is not None:
            img = self.transform(img)

        attr_vec = self.attributes[label]
        attr_vec = torch.from_numpy(attr_vec)

        return img, attr_vec, label


class BirdTestDataset(Dataset):
    def __init__(self, hf_dataset, transform=None):
        self.ds = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        img = item["image"]
        if isinstance(img, Image.Image):
            img = img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        img_id = int(item["id"])
        return img, img_id


train_dataset = BirdTrainDataset(train_hf, attributes, transform=train_transform)
val_dataset   = BirdTrainDataset(val_hf,   attributes, transform=eval_transform)
test_dataset  = BirdTestDataset(test_hf,   transform=eval_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [35]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size//patch_size
        self.num_patches = self.grid_size**2

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x [B,3,H,W]
        x = self.proj(x) #[B,embed_dim,H/P,W/P]
        x = x.flatten(2) #[B,embed_dim,num_patches]
        x = x.transpose(1, 2) #[B,num_patches,embed_dim]
        return x


class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim=192, num_heads=3, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=drop, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)

        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        #x [B,N,D]
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class SimpleViTWithAttributes(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 num_classes=200, num_attr=312, embed_dim=192, depth=6,
                 num_heads=3, mlp_ratio=4.0, drop=0.1):
        super().__init__()

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(drop)

        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, drop)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        self.head_class = nn.Linear(embed_dim, num_classes)
        # gÅ‚owa atrybutowa
        self.head_attr  = nn.Linear(embed_dim, num_attr)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.head_class.weight, std=0.02)
        nn.init.trunc_normal_(self.head_attr.weight, std=0.02)

    def forward(self, x):
        # x:[B,3,224,224]
        B = x.size(0)
        x = self.patch_embed(x) #[B,N,D]

        cls_tokens = self.cls_token.expand(B,-1,-1) #[B,1,D]
        x = torch.cat((cls_tokens, x), dim=1) #[B,1+N,D]
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls = x[:, 0] #[B,D]

        logits = self.head_class(cls)
        attr_pred = self.head_attr(cls)
        return logits, attr_pred
    
model = SimpleViTWithAttributes(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    num_attr=NUM_ATTR,
    embed_dim=192,
    depth=6,
    num_heads=3,
    mlp_ratio=4.0,
    drop=0.1
).to(DEVICE)

model

SimpleViTWithAttributes(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-5): 6 x TransformerEncoderBlock(
      (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
      )
      (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=192, out_features=768, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=768, out_features=192, bias=True)
        (4): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  (head_class): Linear(in_features=192, out_features=200, bias=True)
  (head_attr): Linear(in_features=192, out_features=312, b

In [None]:
criterion_class = nn.CrossEntropyLoss()
criterion_attr  = nn.MSELoss()
LAMBDA_ATTR = 0.05

In [37]:
def train_one_epoch(epoch_idx):
    model.train()
    total_loss = 0.0
    total_cls  = 0.0
    total_attr = 0.0
    correct = 0
    samples = 0

    for batch_idx, (imgs, attr_targets, labels) in enumerate(train_loader):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        attr_targets = attr_targets.to(DEVICE)

        optimizer.zero_grad()

        logits, attr_pred = model(imgs)

        loss_cls = criterion_class(logits, labels)
        loss_attr = criterion_attr(attr_pred, attr_targets)
        loss = loss_cls + LAMBDA_ATTR * loss_attr

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        total_cls  += loss_cls.item() * imgs.size(0)
        total_attr += loss_attr.item() * imgs.size(0)

        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        samples += imgs.size(0)

        if batch_idx % 20 == 0:
            print(f"[Epoch {epoch_idx}] Batch {batch_idx}/{len(train_loader)} "
                  f"loss={loss.item():.4f}")

    avg_loss = total_loss / samples
    avg_cls  = total_cls  / samples
    avg_attr = total_attr / samples
    acc = correct / samples
    return avg_loss, avg_cls, avg_attr, acc


def evaluate():
    model.eval()
    total_loss = 0.0
    total_cls  = 0.0
    total_attr = 0.0
    correct = 0
    samples = 0

    with torch.no_grad():
        for imgs, attr_targets, labels in val_loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            attr_targets = attr_targets.to(DEVICE)

            logits, attr_pred = model(imgs)

            loss_cls = criterion_class(logits, labels)
            loss_attr = criterion_attr(attr_pred, attr_targets)
            loss = loss_cls + LAMBDA_ATTR * loss_attr

            total_loss += loss.item() * imgs.size(0)
            total_cls  += loss_cls.item() * imgs.size(0)
            total_attr += loss_attr.item() * imgs.size(0)

            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            samples += imgs.size(0)

    avg_loss = total_loss / samples
    avg_cls  = total_cls  / samples
    avg_attr = total_attr / samples
    acc = correct / samples
    return avg_loss, avg_cls, avg_attr, acc

In [38]:
def reset_seed(seed=RANDOM_SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def create_model_for_trial(drop: float):
    # rebuild your ViT with this trial's dropout
    model = SimpleViTWithAttributes(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        num_classes=NUM_CLASSES,
        num_attr=NUM_ATTR,
        embed_dim=192,   # fixed for now
        depth=6,
        num_heads=3,
        mlp_ratio=4.0,
        drop=drop
    ).to(DEVICE)
    return model


def objective(trial: optuna.Trial):
    global model, optimizer, scheduler, LAMBDA_ATTR

    # 1. Sample hyperparameters
    lr = trial.suggest_float("lr", 1e-5, 5e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    drop = trial.suggest_float("drop", 0.0, 0.3)
    lambda_attr = trial.suggest_float("lambda_attr", 0.01, 0.2, log=True)

    # 2. Reset random seeds for reproducibility
    reset_seed()

    # 3. Create model & optimizer for this trial
    model = create_model_for_trial(drop)

    LAMBDA_ATTR = lambda_attr

    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=TUNE_EPOCHS
    )

    best_val_acc = 0.0

    # 4. Train for a small number of epochs
    for epoch in range(1, TUNE_EPOCHS + 1):
        train_loss, train_cls, train_attr, train_acc = train_one_epoch(epoch)
        val_loss, val_cls, val_attr, val_acc = evaluate()

        scheduler.step()

        # Report to Optuna (for pruning)
        trial.report(val_acc, step=epoch)

        # Optional: prune bad trials early
        if trial.should_prune():
            raise optuna.TrialPruned()

        if val_acc > best_val_acc:
            best_val_acc = val_acc

    # We want to maximize validation accuracy
    return best_val_acc


In [39]:
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=RANDOM_SEED),
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=2),
)

study.optimize(objective, n_trials=25)

print("Best value:", study.best_trial.value)
print("Best params:", study.best_trial.params)

[I 2025-12-10 18:20:42,674] A new study created in memory with name: no-name-9b0d2cc0-70d0-4698-8d9d-f880548e8f7e


[Epoch 1] Batch 0/105 loss=5.4260
[Epoch 1] Batch 20/105 loss=5.3527
[Epoch 1] Batch 40/105 loss=5.2618
[Epoch 1] Batch 60/105 loss=5.2006
[Epoch 1] Batch 80/105 loss=5.3158
[Epoch 1] Batch 100/105 loss=5.1076
[Epoch 2] Batch 0/105 loss=5.1222
[Epoch 2] Batch 20/105 loss=5.2400
[Epoch 2] Batch 40/105 loss=5.1746
[Epoch 2] Batch 60/105 loss=5.2009
[Epoch 2] Batch 80/105 loss=5.1619
[Epoch 2] Batch 100/105 loss=5.1624
[Epoch 3] Batch 0/105 loss=5.1895
[Epoch 3] Batch 20/105 loss=5.0314
[Epoch 3] Batch 40/105 loss=5.0652
[Epoch 3] Batch 60/105 loss=5.0949
[Epoch 3] Batch 80/105 loss=5.2163
[Epoch 3] Batch 100/105 loss=5.0958
[Epoch 4] Batch 0/105 loss=5.0459
[Epoch 4] Batch 20/105 loss=5.2918
[Epoch 4] Batch 40/105 loss=5.0723
[Epoch 4] Batch 60/105 loss=4.9639
[Epoch 4] Batch 80/105 loss=5.0111
[Epoch 4] Batch 100/105 loss=5.0958
[Epoch 5] Batch 0/105 loss=4.9436
[Epoch 5] Batch 20/105 loss=5.0018
[Epoch 5] Batch 40/105 loss=5.0140
[Epoch 5] Batch 60/105 loss=5.1614
[Epoch 5] Batch 80/10

[I 2025-12-10 19:02:49,787] Trial 0 finished with value: 0.013582342954159592 and parameters: {'lr': 4.3284502212938785e-05, 'weight_decay': 0.006351221010640699, 'drop': 0.21959818254342153, 'lambda_attr': 0.060099747183803134}. Best is trial 0 with value: 0.013582342954159592.


[Epoch 1] Batch 0/105 loss=5.4528
[Epoch 1] Batch 20/105 loss=5.3027
[Epoch 1] Batch 40/105 loss=5.2589
[Epoch 1] Batch 60/105 loss=5.2371
[Epoch 1] Batch 80/105 loss=5.3567
[Epoch 1] Batch 100/105 loss=5.1825
[Epoch 2] Batch 0/105 loss=5.1942
[Epoch 2] Batch 20/105 loss=5.2541
[Epoch 2] Batch 40/105 loss=5.2165
[Epoch 2] Batch 60/105 loss=5.2089
[Epoch 2] Batch 80/105 loss=5.1774
[Epoch 2] Batch 100/105 loss=5.2399
[Epoch 3] Batch 0/105 loss=5.1988
[Epoch 3] Batch 20/105 loss=5.0773
[Epoch 3] Batch 40/105 loss=5.1182
[Epoch 3] Batch 60/105 loss=5.1042
[Epoch 3] Batch 80/105 loss=5.2409
[Epoch 3] Batch 100/105 loss=5.0990
[Epoch 4] Batch 0/105 loss=5.0704
[Epoch 4] Batch 20/105 loss=5.3023
[Epoch 4] Batch 40/105 loss=5.0692
[Epoch 4] Batch 60/105 loss=5.0104
[Epoch 4] Batch 80/105 loss=5.0952
[Epoch 4] Batch 100/105 loss=5.1612
[Epoch 5] Batch 0/105 loss=4.9891
[Epoch 5] Batch 20/105 loss=5.0519
[Epoch 5] Batch 40/105 loss=5.0897
[Epoch 5] Batch 60/105 loss=5.2013
[Epoch 5] Batch 80/10

[I 2025-12-10 19:29:32,060] Trial 1 finished with value: 0.011884550084889643 and parameters: {'lr': 1.8410729205738674e-05, 'weight_decay': 4.207053950287936e-06, 'drop': 0.017425083650459836, 'lambda_attr': 0.13394334706750485}. Best is trial 0 with value: 0.013582342954159592.


[Epoch 1] Batch 0/105 loss=5.4594
[Epoch 1] Batch 20/105 loss=5.3370
[Epoch 1] Batch 40/105 loss=5.2075
[Epoch 1] Batch 60/105 loss=5.1840
[Epoch 1] Batch 80/105 loss=5.3240
[Epoch 1] Batch 100/105 loss=5.0135
[Epoch 2] Batch 0/105 loss=5.0544
[Epoch 2] Batch 20/105 loss=5.1223
[Epoch 2] Batch 40/105 loss=5.0559
[Epoch 2] Batch 60/105 loss=5.1437
[Epoch 2] Batch 80/105 loss=5.0775
[Epoch 2] Batch 100/105 loss=5.0911
[Epoch 3] Batch 0/105 loss=5.1149
[Epoch 3] Batch 20/105 loss=4.8674
[Epoch 3] Batch 40/105 loss=4.9322
[Epoch 3] Batch 60/105 loss=4.9254
[Epoch 3] Batch 80/105 loss=5.1680
[Epoch 3] Batch 100/105 loss=4.9827
[Epoch 4] Batch 0/105 loss=4.8563
[Epoch 4] Batch 20/105 loss=5.1926
[Epoch 4] Batch 40/105 loss=4.9588
[Epoch 4] Batch 60/105 loss=4.8563
[Epoch 4] Batch 80/105 loss=4.7585
[Epoch 4] Batch 100/105 loss=4.8862
[Epoch 5] Batch 0/105 loss=4.7830
[Epoch 5] Batch 20/105 loss=4.8867
[Epoch 5] Batch 40/105 loss=4.7933
[Epoch 5] Batch 60/105 loss=4.9405
[Epoch 5] Batch 80/10

[I 2025-12-10 19:54:41,427] Trial 2 finished with value: 0.01867572156196944 and parameters: {'lr': 0.00010502105436744271, 'weight_decay': 0.0006796578090758161, 'drop': 0.006175348288740734, 'lambda_attr': 0.1827602783178572}. Best is trial 2 with value: 0.01867572156196944.


[Epoch 1] Batch 0/105 loss=5.4327
[Epoch 1] Batch 20/105 loss=5.2992
[Epoch 1] Batch 40/105 loss=5.1957
[Epoch 1] Batch 60/105 loss=5.1298
[Epoch 1] Batch 80/105 loss=5.2571
[Epoch 1] Batch 100/105 loss=4.9224
[Epoch 2] Batch 0/105 loss=4.9923
[Epoch 2] Batch 20/105 loss=5.0779
[Epoch 2] Batch 40/105 loss=4.9025
[Epoch 2] Batch 60/105 loss=5.1107
[Epoch 2] Batch 80/105 loss=5.0122
[Epoch 2] Batch 100/105 loss=5.0659
[Epoch 3] Batch 0/105 loss=4.9803
[Epoch 3] Batch 20/105 loss=4.7116
[Epoch 3] Batch 40/105 loss=4.8328
[Epoch 3] Batch 60/105 loss=4.8801
[Epoch 3] Batch 80/105 loss=5.0854
[Epoch 3] Batch 100/105 loss=4.9809
[Epoch 4] Batch 0/105 loss=4.7405
[Epoch 4] Batch 20/105 loss=5.1703
[Epoch 4] Batch 40/105 loss=4.8663
[Epoch 4] Batch 60/105 loss=4.7475
[Epoch 4] Batch 80/105 loss=4.5950
[Epoch 4] Batch 100/105 loss=4.7174
[Epoch 5] Batch 0/105 loss=4.6956
[Epoch 5] Batch 20/105 loss=4.8556
[Epoch 5] Batch 40/105 loss=4.6713
[Epoch 5] Batch 60/105 loss=4.7500
[Epoch 5] Batch 80/10

[I 2025-12-10 20:21:06,869] Trial 3 finished with value: 0.022071307300509338 and parameters: {'lr': 0.00025959425503112657, 'weight_decay': 7.068974950624607e-06, 'drop': 0.05454749016213018, 'lambda_attr': 0.017322667470546258}. Best is trial 3 with value: 0.022071307300509338.


[Epoch 1] Batch 0/105 loss=5.4259
[Epoch 1] Batch 20/105 loss=5.3146
[Epoch 1] Batch 40/105 loss=5.2339
[Epoch 1] Batch 60/105 loss=5.2286
[Epoch 1] Batch 80/105 loss=5.3364
[Epoch 1] Batch 100/105 loss=5.1218
[Epoch 2] Batch 0/105 loss=5.1438
[Epoch 2] Batch 20/105 loss=5.2286
[Epoch 2] Batch 40/105 loss=5.1965
[Epoch 2] Batch 60/105 loss=5.1832
[Epoch 2] Batch 80/105 loss=5.1625
[Epoch 2] Batch 100/105 loss=5.1759
[Epoch 3] Batch 0/105 loss=5.1950
[Epoch 3] Batch 20/105 loss=5.0218
[Epoch 3] Batch 40/105 loss=5.0694
[Epoch 3] Batch 60/105 loss=5.0837
[Epoch 3] Batch 80/105 loss=5.2361
[Epoch 3] Batch 100/105 loss=5.0759
[Epoch 4] Batch 0/105 loss=5.0502
[Epoch 4] Batch 20/105 loss=5.2984
[Epoch 4] Batch 40/105 loss=5.0698
[Epoch 4] Batch 60/105 loss=4.9803
[Epoch 4] Batch 80/105 loss=5.0357
[Epoch 4] Batch 100/105 loss=5.1196
[Epoch 5] Batch 0/105 loss=4.9590
[Epoch 5] Batch 20/105 loss=5.0145
[Epoch 5] Batch 40/105 loss=5.0245
[Epoch 5] Batch 60/105 loss=5.1638
[Epoch 5] Batch 80/10

[I 2025-12-10 20:46:45,064] Trial 4 finished with value: 0.015280135823429542 and parameters: {'lr': 3.2877474139911175e-05, 'weight_decay': 0.0001256104370001356, 'drop': 0.12958350559263473, 'lambda_attr': 0.023927528765580634}. Best is trial 3 with value: 0.022071307300509338.


[Epoch 1] Batch 0/105 loss=5.4372
[Epoch 1] Batch 20/105 loss=5.3238
[Epoch 1] Batch 40/105 loss=5.2118
[Epoch 1] Batch 60/105 loss=5.1567
[Epoch 1] Batch 80/105 loss=5.3034
[Epoch 1] Batch 100/105 loss=5.0070
[Epoch 2] Batch 0/105 loss=5.0464
[Epoch 2] Batch 20/105 loss=5.1368
[Epoch 2] Batch 40/105 loss=5.0743
[Epoch 2] Batch 60/105 loss=5.1290
[Epoch 2] Batch 80/105 loss=5.0894
[Epoch 2] Batch 100/105 loss=5.0651
[Epoch 3] Batch 0/105 loss=5.1162
[Epoch 3] Batch 20/105 loss=4.8837
[Epoch 3] Batch 40/105 loss=4.9450
[Epoch 3] Batch 60/105 loss=4.9538
[Epoch 3] Batch 80/105 loss=5.1505
[Epoch 3] Batch 100/105 loss=5.0012
[Epoch 4] Batch 0/105 loss=4.8735
[Epoch 4] Batch 20/105 loss=5.2152
[Epoch 4] Batch 40/105 loss=4.9798
[Epoch 4] Batch 60/105 loss=4.8718
[Epoch 4] Batch 80/105 loss=4.7645
[Epoch 4] Batch 100/105 loss=4.8681
[Epoch 5] Batch 0/105 loss=4.7906
[Epoch 5] Batch 20/105 loss=4.8934
[Epoch 5] Batch 40/105 loss=4.8142
[Epoch 5] Batch 60/105 loss=4.9449
[Epoch 5] Batch 80/10

[I 2025-12-10 21:20:38,295] Trial 5 finished with value: 0.01867572156196944 and parameters: {'lr': 0.00010952662748632558, 'weight_decay': 3.6138942712165278e-06, 'drop': 0.08764339456056544, 'lambda_attr': 0.029967309097101588}. Best is trial 3 with value: 0.022071307300509338.


[Epoch 1] Batch 0/105 loss=5.4367
[Epoch 1] Batch 20/105 loss=5.3136
[Epoch 1] Batch 40/105 loss=5.2447
[Epoch 1] Batch 60/105 loss=5.1892
[Epoch 1] Batch 80/105 loss=5.3288
[Epoch 1] Batch 100/105 loss=5.0653
[Epoch 2] Batch 0/105 loss=5.0910
[Epoch 2] Batch 20/105 loss=5.1663
[Epoch 2] Batch 40/105 loss=5.1218
[Epoch 2] Batch 60/105 loss=5.1667
[Epoch 2] Batch 80/105 loss=5.0981
[Epoch 2] Batch 100/105 loss=5.1443


[I 2025-12-10 21:37:45,626] Trial 6 pruned. 


[Epoch 1] Batch 0/105 loss=5.4364
[Epoch 1] Batch 20/105 loss=5.3443
[Epoch 1] Batch 40/105 loss=5.2219
[Epoch 1] Batch 60/105 loss=5.1506
[Epoch 1] Batch 80/105 loss=5.3181
[Epoch 1] Batch 100/105 loss=5.0083
[Epoch 2] Batch 0/105 loss=5.0518
[Epoch 2] Batch 20/105 loss=5.1598
[Epoch 2] Batch 40/105 loss=5.0937
[Epoch 2] Batch 60/105 loss=5.1389
[Epoch 2] Batch 80/105 loss=5.0732
[Epoch 2] Batch 100/105 loss=5.0752
[Epoch 3] Batch 0/105 loss=5.1591
[Epoch 3] Batch 20/105 loss=4.9120
[Epoch 3] Batch 40/105 loss=4.9500
[Epoch 3] Batch 60/105 loss=4.9765
[Epoch 3] Batch 80/105 loss=5.1779
[Epoch 3] Batch 100/105 loss=5.0238
[Epoch 4] Batch 0/105 loss=4.9036
[Epoch 4] Batch 20/105 loss=5.2212
[Epoch 4] Batch 40/105 loss=4.9930
[Epoch 4] Batch 60/105 loss=4.8724
[Epoch 4] Batch 80/105 loss=4.8049
[Epoch 4] Batch 100/105 loss=4.9475
[Epoch 5] Batch 0/105 loss=4.8150
[Epoch 5] Batch 20/105 loss=4.9128
[Epoch 5] Batch 40/105 loss=4.8472
[Epoch 5] Batch 60/105 loss=5.0278
[Epoch 5] Batch 80/10

[I 2025-12-10 22:17:26,514] Trial 7 finished with value: 0.01867572156196944 and parameters: {'lr': 0.0001015066704592858, 'weight_decay': 1.5339162591163623e-06, 'drop': 0.1822634555704315, 'lambda_attr': 0.016666983286066417}. Best is trial 3 with value: 0.022071307300509338.


[Epoch 1] Batch 0/105 loss=5.4390
[Epoch 1] Batch 20/105 loss=5.3386
[Epoch 1] Batch 40/105 loss=5.3229
[Epoch 1] Batch 60/105 loss=5.2921
[Epoch 1] Batch 80/105 loss=5.3105
[Epoch 1] Batch 100/105 loss=5.2810
[Epoch 2] Batch 0/105 loss=5.2753
[Epoch 2] Batch 20/105 loss=5.2989
[Epoch 2] Batch 40/105 loss=5.2712
[Epoch 2] Batch 60/105 loss=5.2650
[Epoch 2] Batch 80/105 loss=5.2711
[Epoch 2] Batch 100/105 loss=5.2845


[I 2025-12-10 22:32:03,458] Trial 8 pruned. 


[Epoch 1] Batch 0/105 loss=5.4158
[Epoch 1] Batch 20/105 loss=5.3589
[Epoch 1] Batch 40/105 loss=5.2610
[Epoch 1] Batch 60/105 loss=5.2153
[Epoch 1] Batch 80/105 loss=5.3240
[Epoch 1] Batch 100/105 loss=5.1332
[Epoch 2] Batch 0/105 loss=5.1446
[Epoch 2] Batch 20/105 loss=5.2611
[Epoch 2] Batch 40/105 loss=5.1914
[Epoch 2] Batch 60/105 loss=5.1965
[Epoch 2] Batch 80/105 loss=5.1737
[Epoch 2] Batch 100/105 loss=5.1942


[I 2025-12-10 22:47:35,305] Trial 9 pruned. 


[Epoch 1] Batch 0/105 loss=5.4318
[Epoch 1] Batch 20/105 loss=5.3263
[Epoch 1] Batch 40/105 loss=5.2086
[Epoch 1] Batch 60/105 loss=5.0601
[Epoch 1] Batch 80/105 loss=5.2930
[Epoch 1] Batch 100/105 loss=4.8588
[Epoch 2] Batch 0/105 loss=4.9598
[Epoch 2] Batch 20/105 loss=5.0336
[Epoch 2] Batch 40/105 loss=4.7960
[Epoch 2] Batch 60/105 loss=5.0809
[Epoch 2] Batch 80/105 loss=4.9756
[Epoch 2] Batch 100/105 loss=5.0964
[Epoch 3] Batch 0/105 loss=4.9586
[Epoch 3] Batch 20/105 loss=4.6231
[Epoch 3] Batch 40/105 loss=4.7724
[Epoch 3] Batch 60/105 loss=4.8924
[Epoch 3] Batch 80/105 loss=5.0989
[Epoch 3] Batch 100/105 loss=4.9115
[Epoch 4] Batch 0/105 loss=4.6868
[Epoch 4] Batch 20/105 loss=5.1476
[Epoch 4] Batch 40/105 loss=4.8164
[Epoch 4] Batch 60/105 loss=4.7146
[Epoch 4] Batch 80/105 loss=4.5273
[Epoch 4] Batch 100/105 loss=4.6886
[Epoch 5] Batch 0/105 loss=4.6335
[Epoch 5] Batch 20/105 loss=4.8666
[Epoch 5] Batch 40/105 loss=4.6510
[Epoch 5] Batch 60/105 loss=4.6420
[Epoch 5] Batch 80/10

[I 2025-12-10 23:24:39,567] Trial 10 finished with value: 0.02037351443123939 and parameters: {'lr': 0.00040342112732688975, 'weight_decay': 2.85043206278715e-05, 'drop': 0.1106186318163292, 'lambda_attr': 0.010552829926879392}. Best is trial 3 with value: 0.022071307300509338.


[Epoch 1] Batch 0/105 loss=5.4301
[Epoch 1] Batch 20/105 loss=5.3295
[Epoch 1] Batch 40/105 loss=5.2124
[Epoch 1] Batch 60/105 loss=5.0611
[Epoch 1] Batch 80/105 loss=5.2928
[Epoch 1] Batch 100/105 loss=4.8667
[Epoch 2] Batch 0/105 loss=4.9661
[Epoch 2] Batch 20/105 loss=5.0341
[Epoch 2] Batch 40/105 loss=4.8159
[Epoch 2] Batch 60/105 loss=5.0669
[Epoch 2] Batch 80/105 loss=5.0208
[Epoch 2] Batch 100/105 loss=5.0565
[Epoch 3] Batch 0/105 loss=5.0246
[Epoch 3] Batch 20/105 loss=4.6948
[Epoch 3] Batch 40/105 loss=4.7720
[Epoch 3] Batch 60/105 loss=4.9463
[Epoch 3] Batch 80/105 loss=5.0897
[Epoch 3] Batch 100/105 loss=4.9374
[Epoch 4] Batch 0/105 loss=4.6616
[Epoch 4] Batch 20/105 loss=5.1644
[Epoch 4] Batch 40/105 loss=4.8731
[Epoch 4] Batch 60/105 loss=4.7837
[Epoch 4] Batch 80/105 loss=4.5331
[Epoch 4] Batch 100/105 loss=4.7550
[Epoch 5] Batch 0/105 loss=4.6825
[Epoch 5] Batch 20/105 loss=4.8361
[Epoch 5] Batch 40/105 loss=4.6429
[Epoch 5] Batch 60/105 loss=4.6562
[Epoch 5] Batch 80/10

[I 2025-12-11 00:02:47,839] Trial 11 finished with value: 0.023769100169779286 and parameters: {'lr': 0.0004273669312062059, 'weight_decay': 2.3739423955292538e-05, 'drop': 0.10912464381815776, 'lambda_attr': 0.010368427651183114}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4338
[Epoch 1] Batch 20/105 loss=5.3244
[Epoch 1] Batch 40/105 loss=5.2346
[Epoch 1] Batch 60/105 loss=5.0826
[Epoch 1] Batch 80/105 loss=5.2832
[Epoch 1] Batch 100/105 loss=4.8608
[Epoch 2] Batch 0/105 loss=4.9830
[Epoch 2] Batch 20/105 loss=5.0437
[Epoch 2] Batch 40/105 loss=4.8224
[Epoch 2] Batch 60/105 loss=5.1209
[Epoch 2] Batch 80/105 loss=5.0114
[Epoch 2] Batch 100/105 loss=5.0302


[I 2025-12-11 00:18:09,985] Trial 12 pruned. 


[Epoch 1] Batch 0/105 loss=5.4460
[Epoch 1] Batch 20/105 loss=5.3148
[Epoch 1] Batch 40/105 loss=5.1792
[Epoch 1] Batch 60/105 loss=5.1083
[Epoch 1] Batch 80/105 loss=5.2954
[Epoch 1] Batch 100/105 loss=4.9376
[Epoch 2] Batch 0/105 loss=4.9889
[Epoch 2] Batch 20/105 loss=5.0862
[Epoch 2] Batch 40/105 loss=4.9767
[Epoch 2] Batch 60/105 loss=5.1022
[Epoch 2] Batch 80/105 loss=4.9954
[Epoch 2] Batch 100/105 loss=5.0377
[Epoch 3] Batch 0/105 loss=5.0318
[Epoch 3] Batch 20/105 loss=4.7897
[Epoch 3] Batch 40/105 loss=4.8416
[Epoch 3] Batch 60/105 loss=4.8580
[Epoch 3] Batch 80/105 loss=5.1047
[Epoch 3] Batch 100/105 loss=5.0053
[Epoch 4] Batch 0/105 loss=4.7910
[Epoch 4] Batch 20/105 loss=5.1592
[Epoch 4] Batch 40/105 loss=4.8833
[Epoch 4] Batch 60/105 loss=4.8082
[Epoch 4] Batch 80/105 loss=4.6330
[Epoch 4] Batch 100/105 loss=4.7527
[Epoch 5] Batch 0/105 loss=4.7232
[Epoch 5] Batch 20/105 loss=4.8327
[Epoch 5] Batch 40/105 loss=4.7208
[Epoch 5] Batch 60/105 loss=4.8357
[Epoch 5] Batch 80/10

[I 2025-12-11 00:54:36,681] Trial 13 finished with value: 0.02037351443123939 and parameters: {'lr': 0.00022711179190406156, 'weight_decay': 2.11950469770016e-05, 'drop': 0.15888285342194056, 'lambda_attr': 0.017501117294230414}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4343
[Epoch 1] Batch 20/105 loss=5.2967
[Epoch 1] Batch 40/105 loss=5.2041
[Epoch 1] Batch 60/105 loss=5.1125
[Epoch 1] Batch 80/105 loss=5.2896
[Epoch 1] Batch 100/105 loss=4.9234
[Epoch 2] Batch 0/105 loss=4.9925
[Epoch 2] Batch 20/105 loss=5.0985
[Epoch 2] Batch 40/105 loss=4.9041
[Epoch 2] Batch 60/105 loss=5.1129
[Epoch 2] Batch 80/105 loss=5.0043
[Epoch 2] Batch 100/105 loss=5.0408
[Epoch 3] Batch 0/105 loss=4.9897
[Epoch 3] Batch 20/105 loss=4.7166
[Epoch 3] Batch 40/105 loss=4.8201
[Epoch 3] Batch 60/105 loss=4.8861
[Epoch 3] Batch 80/105 loss=5.1219
[Epoch 3] Batch 100/105 loss=4.9784
[Epoch 4] Batch 0/105 loss=4.7497
[Epoch 4] Batch 20/105 loss=5.1778
[Epoch 4] Batch 40/105 loss=4.8654
[Epoch 4] Batch 60/105 loss=4.7864
[Epoch 4] Batch 80/105 loss=4.5880
[Epoch 4] Batch 100/105 loss=4.7568
[Epoch 5] Batch 0/105 loss=4.6777
[Epoch 5] Batch 20/105 loss=4.8417
[Epoch 5] Batch 40/105 loss=4.7246
[Epoch 5] Batch 60/105 loss=4.7243
[Epoch 5] Batch 80/10

[I 2025-12-11 01:28:23,850] Trial 14 finished with value: 0.02037351443123939 and parameters: {'lr': 0.0002453634175888639, 'weight_decay': 0.00012799717396110157, 'drop': 0.05912714493419006, 'lambda_attr': 0.016393724345057703}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4323
[Epoch 1] Batch 20/105 loss=5.2966
[Epoch 1] Batch 40/105 loss=5.1802
[Epoch 1] Batch 60/105 loss=5.1193
[Epoch 1] Batch 80/105 loss=5.2922
[Epoch 1] Batch 100/105 loss=4.9288
[Epoch 2] Batch 0/105 loss=4.9944
[Epoch 2] Batch 20/105 loss=5.0926
[Epoch 2] Batch 40/105 loss=4.9647
[Epoch 2] Batch 60/105 loss=5.1091
[Epoch 2] Batch 80/105 loss=4.9966
[Epoch 2] Batch 100/105 loss=5.0612
[Epoch 3] Batch 0/105 loss=5.0281
[Epoch 3] Batch 20/105 loss=4.7696
[Epoch 3] Batch 40/105 loss=4.8439
[Epoch 3] Batch 60/105 loss=4.8593
[Epoch 3] Batch 80/105 loss=5.0858
[Epoch 3] Batch 100/105 loss=4.9652
[Epoch 4] Batch 0/105 loss=4.7892
[Epoch 4] Batch 20/105 loss=5.2129
[Epoch 4] Batch 40/105 loss=4.8618
[Epoch 4] Batch 60/105 loss=4.7597
[Epoch 4] Batch 80/105 loss=4.6000
[Epoch 4] Batch 100/105 loss=4.7935
[Epoch 5] Batch 0/105 loss=4.7176
[Epoch 5] Batch 20/105 loss=4.8492
[Epoch 5] Batch 40/105 loss=4.7197
[Epoch 5] Batch 60/105 loss=4.7280
[Epoch 5] Batch 80/10

[I 2025-12-11 02:07:53,690] Trial 15 finished with value: 0.01867572156196944 and parameters: {'lr': 0.00021768474487790407, 'weight_decay': 9.188227716652774e-06, 'drop': 0.10076176448993387, 'lambda_attr': 0.012952946089363222}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4430
[Epoch 1] Batch 20/105 loss=5.3846
[Epoch 1] Batch 40/105 loss=5.2248
[Epoch 1] Batch 60/105 loss=5.0419
[Epoch 1] Batch 80/105 loss=5.2933
[Epoch 1] Batch 100/105 loss=4.8484
[Epoch 2] Batch 0/105 loss=4.9504
[Epoch 2] Batch 20/105 loss=5.0183
[Epoch 2] Batch 40/105 loss=4.7340
[Epoch 2] Batch 60/105 loss=5.1283
[Epoch 2] Batch 80/105 loss=5.0015
[Epoch 2] Batch 100/105 loss=5.0624
[Epoch 3] Batch 0/105 loss=4.9560
[Epoch 3] Batch 20/105 loss=4.6832
[Epoch 3] Batch 40/105 loss=4.8000
[Epoch 3] Batch 60/105 loss=4.9305
[Epoch 3] Batch 80/105 loss=5.1266
[Epoch 3] Batch 100/105 loss=4.8951
[Epoch 4] Batch 0/105 loss=4.6537
[Epoch 4] Batch 20/105 loss=5.1423
[Epoch 4] Batch 40/105 loss=4.8147
[Epoch 4] Batch 60/105 loss=4.7087
[Epoch 4] Batch 80/105 loss=4.5099
[Epoch 4] Batch 100/105 loss=4.6726
[Epoch 5] Batch 0/105 loss=4.6519
[Epoch 5] Batch 20/105 loss=4.8519
[Epoch 5] Batch 40/105 loss=4.6161
[Epoch 5] Batch 60/105 loss=4.6340
[Epoch 5] Batch 80/10

[I 2025-12-11 02:46:47,199] Trial 16 finished with value: 0.02037351443123939 and parameters: {'lr': 0.00048183492480022197, 'weight_decay': 4.9146810539426464e-05, 'drop': 0.03742862534893712, 'lambda_attr': 0.026060376348876303}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4382
[Epoch 1] Batch 20/105 loss=5.3208
[Epoch 1] Batch 40/105 loss=5.1920
[Epoch 1] Batch 60/105 loss=5.1450
[Epoch 1] Batch 80/105 loss=5.3074
[Epoch 1] Batch 100/105 loss=4.9787
[Epoch 2] Batch 0/105 loss=5.0201
[Epoch 2] Batch 20/105 loss=5.1095
[Epoch 2] Batch 40/105 loss=5.0336
[Epoch 2] Batch 60/105 loss=5.1102
[Epoch 2] Batch 80/105 loss=5.0630
[Epoch 2] Batch 100/105 loss=5.0691
[Epoch 3] Batch 0/105 loss=5.0668
[Epoch 3] Batch 20/105 loss=4.8229
[Epoch 3] Batch 40/105 loss=4.9086
[Epoch 3] Batch 60/105 loss=4.8980
[Epoch 3] Batch 80/105 loss=5.1468
[Epoch 3] Batch 100/105 loss=4.9954
[Epoch 4] Batch 0/105 loss=4.8484
[Epoch 4] Batch 20/105 loss=5.2001
[Epoch 4] Batch 40/105 loss=4.9467
[Epoch 4] Batch 60/105 loss=4.8495
[Epoch 4] Batch 80/105 loss=4.7031
[Epoch 4] Batch 100/105 loss=4.8059
[Epoch 5] Batch 0/105 loss=4.7558
[Epoch 5] Batch 20/105 loss=4.8805
[Epoch 5] Batch 40/105 loss=4.7796
[Epoch 5] Batch 60/105 loss=4.9080
[Epoch 5] Batch 80/10

[I 2025-12-11 03:22:08,229] Trial 17 finished with value: 0.01867572156196944 and parameters: {'lr': 0.00015409834605716456, 'weight_decay': 0.000336775626774932, 'drop': 0.14592206217745002, 'lambda_attr': 0.06663296217608104}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4405
[Epoch 1] Batch 20/105 loss=5.3188
[Epoch 1] Batch 40/105 loss=5.1800
[Epoch 1] Batch 60/105 loss=5.1090
[Epoch 1] Batch 80/105 loss=5.2826
[Epoch 1] Batch 100/105 loss=4.9172
[Epoch 2] Batch 0/105 loss=4.9676
[Epoch 2] Batch 20/105 loss=5.0590
[Epoch 2] Batch 40/105 loss=4.9346
[Epoch 2] Batch 60/105 loss=5.0846
[Epoch 2] Batch 80/105 loss=4.9687
[Epoch 2] Batch 100/105 loss=4.9900
[Epoch 3] Batch 0/105 loss=5.0093
[Epoch 3] Batch 20/105 loss=4.7442
[Epoch 3] Batch 40/105 loss=4.8527
[Epoch 3] Batch 60/105 loss=4.8875
[Epoch 3] Batch 80/105 loss=5.1489
[Epoch 3] Batch 100/105 loss=4.9736
[Epoch 4] Batch 0/105 loss=4.7774
[Epoch 4] Batch 20/105 loss=5.1488
[Epoch 4] Batch 40/105 loss=4.8524
[Epoch 4] Batch 60/105 loss=4.7743
[Epoch 4] Batch 80/105 loss=4.5636
[Epoch 4] Batch 100/105 loss=4.6971
[Epoch 5] Batch 0/105 loss=4.7123
[Epoch 5] Batch 20/105 loss=4.8214
[Epoch 5] Batch 40/105 loss=4.6813
[Epoch 5] Batch 60/105 loss=4.7763
[Epoch 5] Batch 80/10

[I 2025-12-11 04:00:46,760] Trial 18 finished with value: 0.022071307300509338 and parameters: {'lr': 0.00029791204863096585, 'weight_decay': 9.2516614469223e-06, 'drop': 0.2531940305133584, 'lambda_attr': 0.02020396268393445}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4432
[Epoch 1] Batch 20/105 loss=5.3096
[Epoch 1] Batch 40/105 loss=5.1970
[Epoch 1] Batch 60/105 loss=5.1428
[Epoch 1] Batch 80/105 loss=5.2937
[Epoch 1] Batch 100/105 loss=4.9872
[Epoch 2] Batch 0/105 loss=5.0290
[Epoch 2] Batch 20/105 loss=5.1182
[Epoch 2] Batch 40/105 loss=5.0339
[Epoch 2] Batch 60/105 loss=5.1145
[Epoch 2] Batch 80/105 loss=5.0630
[Epoch 2] Batch 100/105 loss=5.0545


[I 2025-12-11 04:17:13,322] Trial 19 pruned. 


[Epoch 1] Batch 0/105 loss=5.4254
[Epoch 1] Batch 20/105 loss=5.2943
[Epoch 1] Batch 40/105 loss=5.1883
[Epoch 1] Batch 60/105 loss=5.0816
[Epoch 1] Batch 80/105 loss=5.2998
[Epoch 1] Batch 100/105 loss=4.8817
[Epoch 2] Batch 0/105 loss=4.9821
[Epoch 2] Batch 20/105 loss=5.0730
[Epoch 2] Batch 40/105 loss=4.8858
[Epoch 2] Batch 60/105 loss=5.1068
[Epoch 2] Batch 80/105 loss=4.9910
[Epoch 2] Batch 100/105 loss=5.0984
[Epoch 3] Batch 0/105 loss=5.0215
[Epoch 3] Batch 20/105 loss=4.6901
[Epoch 3] Batch 40/105 loss=4.8266
[Epoch 3] Batch 60/105 loss=4.8878
[Epoch 3] Batch 80/105 loss=5.1098
[Epoch 3] Batch 100/105 loss=4.9715
[Epoch 4] Batch 0/105 loss=4.7235
[Epoch 4] Batch 20/105 loss=5.1828
[Epoch 4] Batch 40/105 loss=4.8140
[Epoch 4] Batch 60/105 loss=4.7038
[Epoch 4] Batch 80/105 loss=4.5772
[Epoch 4] Batch 100/105 loss=4.7396
[Epoch 5] Batch 0/105 loss=4.6841
[Epoch 5] Batch 20/105 loss=4.8780
[Epoch 5] Batch 40/105 loss=4.7003
[Epoch 5] Batch 60/105 loss=4.6569
[Epoch 5] Batch 80/10

[I 2025-12-11 04:57:40,227] Trial 20 finished with value: 0.02037351443123939 and parameters: {'lr': 0.000302777629761442, 'weight_decay': 8.274974937412168e-06, 'drop': 0.12456198490899, 'lambda_attr': 0.012191839002861509}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4383
[Epoch 1] Batch 20/105 loss=5.2932
[Epoch 1] Batch 40/105 loss=5.2024
[Epoch 1] Batch 60/105 loss=5.0780
[Epoch 1] Batch 80/105 loss=5.2679
[Epoch 1] Batch 100/105 loss=4.8934
[Epoch 2] Batch 0/105 loss=4.9608
[Epoch 2] Batch 20/105 loss=4.9982
[Epoch 2] Batch 40/105 loss=4.8999
[Epoch 2] Batch 60/105 loss=5.0837
[Epoch 2] Batch 80/105 loss=4.9368
[Epoch 2] Batch 100/105 loss=5.0076
[Epoch 3] Batch 0/105 loss=4.9914
[Epoch 3] Batch 20/105 loss=4.7266
[Epoch 3] Batch 40/105 loss=4.8305
[Epoch 3] Batch 60/105 loss=4.8549
[Epoch 3] Batch 80/105 loss=5.1406
[Epoch 3] Batch 100/105 loss=4.9430
[Epoch 4] Batch 0/105 loss=4.7407
[Epoch 4] Batch 20/105 loss=5.1223
[Epoch 4] Batch 40/105 loss=4.8575
[Epoch 4] Batch 60/105 loss=4.7849
[Epoch 4] Batch 80/105 loss=4.5368
[Epoch 4] Batch 100/105 loss=4.7202
[Epoch 5] Batch 0/105 loss=4.6655
[Epoch 5] Batch 20/105 loss=4.8540
[Epoch 5] Batch 40/105 loss=4.6655
[Epoch 5] Batch 60/105 loss=4.7143
[Epoch 5] Batch 80/10

[I 2025-12-11 05:31:13,692] Trial 21 finished with value: 0.02037351443123939 and parameters: {'lr': 0.00035546031788461787, 'weight_decay': 8.756805280546209e-06, 'drop': 0.28325069170998785, 'lambda_attr': 0.02364721757248912}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4250
[Epoch 1] Batch 20/105 loss=5.3050
[Epoch 1] Batch 40/105 loss=5.1933
[Epoch 1] Batch 60/105 loss=5.1411
[Epoch 1] Batch 80/105 loss=5.2800
[Epoch 1] Batch 100/105 loss=4.9655
[Epoch 2] Batch 0/105 loss=5.0061
[Epoch 2] Batch 20/105 loss=5.1062
[Epoch 2] Batch 40/105 loss=5.0190
[Epoch 2] Batch 60/105 loss=5.1073
[Epoch 2] Batch 80/105 loss=5.0247
[Epoch 2] Batch 100/105 loss=5.0295
[Epoch 3] Batch 0/105 loss=5.0580
[Epoch 3] Batch 20/105 loss=4.8290
[Epoch 3] Batch 40/105 loss=4.9135
[Epoch 3] Batch 60/105 loss=4.9167
[Epoch 3] Batch 80/105 loss=5.1531
[Epoch 3] Batch 100/105 loss=5.0081
[Epoch 4] Batch 0/105 loss=4.8434
[Epoch 4] Batch 20/105 loss=5.1893
[Epoch 4] Batch 40/105 loss=4.9142
[Epoch 4] Batch 60/105 loss=4.8358
[Epoch 4] Batch 80/105 loss=4.6633
[Epoch 4] Batch 100/105 loss=4.8176
[Epoch 5] Batch 0/105 loss=4.7503
[Epoch 5] Batch 20/105 loss=4.8525
[Epoch 5] Batch 40/105 loss=4.7862
[Epoch 5] Batch 60/105 loss=4.8735
[Epoch 5] Batch 80/10

[I 2025-12-11 06:05:43,180] Trial 22 finished with value: 0.01867572156196944 and parameters: {'lr': 0.00018675839994556782, 'weight_decay': 9.65561126905972e-06, 'drop': 0.21858131828053876, 'lambda_attr': 0.020131623578867264}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4457
[Epoch 1] Batch 20/105 loss=5.3135
[Epoch 1] Batch 40/105 loss=5.1845
[Epoch 1] Batch 60/105 loss=5.1059
[Epoch 1] Batch 80/105 loss=5.2724
[Epoch 1] Batch 100/105 loss=4.9158
[Epoch 2] Batch 0/105 loss=4.9597
[Epoch 2] Batch 20/105 loss=5.0515
[Epoch 2] Batch 40/105 loss=4.9289
[Epoch 2] Batch 60/105 loss=5.0932
[Epoch 2] Batch 80/105 loss=4.9823
[Epoch 2] Batch 100/105 loss=4.9983
[Epoch 3] Batch 0/105 loss=4.9885
[Epoch 3] Batch 20/105 loss=4.7443
[Epoch 3] Batch 40/105 loss=4.8461
[Epoch 3] Batch 60/105 loss=4.8727
[Epoch 3] Batch 80/105 loss=5.1413
[Epoch 3] Batch 100/105 loss=4.9635
[Epoch 4] Batch 0/105 loss=4.7503
[Epoch 4] Batch 20/105 loss=5.1217
[Epoch 4] Batch 40/105 loss=4.8487
[Epoch 4] Batch 60/105 loss=4.7717
[Epoch 4] Batch 80/105 loss=4.5665
[Epoch 4] Batch 100/105 loss=4.7145
[Epoch 5] Batch 0/105 loss=4.7232
[Epoch 5] Batch 20/105 loss=4.8542
[Epoch 5] Batch 40/105 loss=4.6776
[Epoch 5] Batch 60/105 loss=4.7830
[Epoch 5] Batch 80/10

[I 2025-12-11 06:43:18,154] Trial 23 finished with value: 0.022071307300509338 and parameters: {'lr': 0.00030777398303933327, 'weight_decay': 1.5558205341022404e-06, 'drop': 0.25716261354321895, 'lambda_attr': 0.03179408669133864}. Best is trial 11 with value: 0.023769100169779286.


[Epoch 1] Batch 0/105 loss=5.4390
[Epoch 1] Batch 20/105 loss=5.3236
[Epoch 1] Batch 40/105 loss=5.1836
[Epoch 1] Batch 60/105 loss=5.0769
[Epoch 1] Batch 80/105 loss=5.2655
[Epoch 1] Batch 100/105 loss=4.9200
[Epoch 2] Batch 0/105 loss=4.9859
[Epoch 2] Batch 20/105 loss=5.0699
[Epoch 2] Batch 40/105 loss=4.9367
[Epoch 2] Batch 60/105 loss=5.0837
[Epoch 2] Batch 80/105 loss=4.9637
[Epoch 2] Batch 100/105 loss=4.9914


[I 2025-12-11 06:57:08,023] Trial 24 pruned. 


Best value: 0.023769100169779286
Best params: {'lr': 0.0004273669312062059, 'weight_decay': 2.3739423955292538e-05, 'drop': 0.10912464381815776, 'lambda_attr': 0.010368427651183114}


{'lr': 4.3284502212938785e-05, 'weight_decay': 0.006351221010640699, 'drop': 0.21959818254342153, 'lambda_attr': 0.060099747183803134}

In [40]:
best_params = study.best_trial.params
best_lr           = best_params["lr"]
best_weight_decay = best_params["weight_decay"]
best_drop         = best_params["drop"]
best_lambda_attr  = best_params["lambda_attr"]

reset_seed()

model = SimpleViTWithAttributes(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    num_attr=NUM_ATTR,
    embed_dim=192,
    depth=6,
    num_heads=3,
    mlp_ratio=4.0,
    drop=best_drop,
).to(DEVICE)

LAMBDA_ATTR = best_lambda_attr

optimizer = optim.AdamW(model.parameters(), lr=best_lr, weight_decay=best_weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

best_val_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    print(f"\n[FINAL TRAIN] Epoch {epoch}/{EPOCHS}")
    train_one_epoch(epoch)
    _, _, _, val_acc = evaluate()
    scheduler.step()

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "vit_best_model_optimized.pth")
        print("Saved new best optimized model")



[FINAL TRAIN] Epoch 1/25
[Epoch 1] Batch 0/105 loss=5.4301
[Epoch 1] Batch 20/105 loss=5.3295
[Epoch 1] Batch 40/105 loss=5.2124
[Epoch 1] Batch 60/105 loss=5.0611
[Epoch 1] Batch 80/105 loss=5.2928
[Epoch 1] Batch 100/105 loss=4.8667
Saved new best optimized model

[FINAL TRAIN] Epoch 2/25
[Epoch 2] Batch 0/105 loss=4.9661
[Epoch 2] Batch 20/105 loss=5.0323
[Epoch 2] Batch 40/105 loss=4.7495
[Epoch 2] Batch 60/105 loss=5.0976
[Epoch 2] Batch 80/105 loss=4.9452
[Epoch 2] Batch 100/105 loss=5.0395

[FINAL TRAIN] Epoch 3/25
[Epoch 3] Batch 0/105 loss=4.9850
[Epoch 3] Batch 20/105 loss=4.6469
[Epoch 3] Batch 40/105 loss=4.7907
[Epoch 3] Batch 60/105 loss=4.9544
[Epoch 3] Batch 80/105 loss=5.1524
[Epoch 3] Batch 100/105 loss=4.9507

[FINAL TRAIN] Epoch 4/25
[Epoch 4] Batch 0/105 loss=4.6722
[Epoch 4] Batch 20/105 loss=5.3330
[Epoch 4] Batch 40/105 loss=4.8650
[Epoch 4] Batch 60/105 loss=4.7634
[Epoch 4] Batch 80/105 loss=4.5779
[Epoch 4] Batch 100/105 loss=4.8708
Saved new best optimized 

In [41]:
print("Number of finished trials:", len(study.trials))
print("Best trial value (val_acc):", study.best_trial.value)
print("Best hyperparameters:")
for k, v in study.best_trial.params.items():
    print(f"  {k}: {v}")


Number of finished trials: 25
Best trial value (val_acc): 0.023769100169779286
Best hyperparameters:
  lr: 0.0004273669312062059
  weight_decay: 2.3739423955292538e-05
  drop: 0.10912464381815776
  lambda_attr: 0.010368427651183114


In [42]:
best_vit = SimpleViTWithAttributes(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    num_attr=NUM_ATTR,
    embed_dim=192,
    depth=6,
    num_heads=3,
    mlp_ratio=4.0,
    drop=0.1
).to(DEVICE)

best_vit.load_state_dict(torch.load("vit_best_model.pth", map_location=DEVICE))
best_vit.eval()

all_ids = []
all_preds = []

with torch.no_grad():
    for imgs, img_ids in test_loader:
        imgs = imgs.to(DEVICE)
        logits, attr_pred = best_vit(imgs)
        preds = logits.argmax(dim=1).cpu().numpy()

        all_ids.extend(img_ids.numpy().tolist())
        all_preds.extend(preds.tolist())

kaggle_labels = [p + 1 for p in all_preds]

submission_vit = pd.DataFrame({
    "id": all_ids,
    "label": kaggle_labels
}).sort_values("id")

print(submission_vit.head())

submission_vit.to_csv("vit_submission.csv", index=False)
print("\nSaved vit_submission.csv")

   id  label
0   1     33
1   2     55
2   3     28
3   4     34
4   5     16

Saved vit_submission.csv
