In [15]:
!pip install optuna
!pip install optuna-integration 'pytorch_lightning'



In [16]:
import os
import random
import numpy as np
import pandas as pd
import optuna
from optuna.integration import PyTorchLightningPruningCallback  # optional, not required

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from datasets import load_from_disk
from PIL import Image

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

Using device: cpu


In [17]:
RANDOM_SEED = 42
BATCH_SIZE = 32
EPOCHS = 25
TUNE_EPOCHS = 5      # for Bayesian optimisation
IMG_SIZE = 224
PATCH_SIZE = 16
NUM_CLASSES = 2 

TUNE_EPOCHS = 5

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

Right now, your functions use global state (model, optimizer, LAMBDA_ATTR, train_loader, val_loader, etc.).
We can exploit that and just re-assign the globals inside the Optuna objective to avoid refactoring everything.

Idea: for each trial we:
* create a new model with its own drop
* create a new optimizer (lr, weight_decay)
* set LAMBDA_ATTR
* train for TUNE_EPOCHS
* return best validation accuracy seen

In [22]:
TRAIN_VAL_PATH = "processed_bird_data"
TEST_PATH = "processed_bird_test_data"

print("Loading train/val dataset from:", TRAIN_VAL_PATH)
full_ds = load_from_disk(TRAIN_VAL_PATH)
train_hf = full_ds["train"]
val_hf = full_ds["validation"]

print("Train size:", len(train_hf))
print("Val size:", len(val_hf))

print("\nLoading test dataset from:", TEST_PATH)
test_hf = load_from_disk(TEST_PATH)
print("Test size:", len(test_hf))

# attributes
ATTR_PATH = "data/attributes.npy"
attributes = np.load(ATTR_PATH)
NUM_CLASSES = attributes.shape[0]
NUM_ATTR = attributes.shape[1]

print("\nAttributes shape:", attributes.shape)
print("NUM_CLASSES:", NUM_CLASSES, "| NUM_ATTR:", NUM_ATTR)

Loading train/val dataset from: processed_bird_data


FileNotFoundError: Directory processed_bird_data not found

In [None]:
from torchvision.transforms import InterpolationMode

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=25, interpolation=InterpolationMode.BILINEAR),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

In [None]:
class BirdTrainDataset(Dataset):
    def __init__(self, hf_dataset, attributes, transform=None):
        self.ds = hf_dataset
        self.attributes = attributes.astype("float32")
        self.transform = transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        img = item["image"]
        if isinstance(img, Image.Image):
            img = img.convert("RGB")
        label = int(item["label"]) 

        if self.transform is not None:
            img = self.transform(img)

        attr_vec = self.attributes[label]
        attr_vec = torch.from_numpy(attr_vec)

        return img, attr_vec, label


class BirdTestDataset(Dataset):
    def __init__(self, hf_dataset, transform=None):
        self.ds = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        img = item["image"]
        if isinstance(img, Image.Image):
            img = img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        img_id = int(item["id"])
        return img, img_id


train_dataset = BirdTrainDataset(train_hf, attributes, transform=train_transform)
val_dataset   = BirdTrainDataset(val_hf,   attributes, transform=eval_transform)
test_dataset  = BirdTestDataset(test_hf,   transform=eval_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size//patch_size
        self.num_patches = self.grid_size**2

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x [B,3,H,W]
        x = self.proj(x) #[B,embed_dim,H/P,W/P]
        x = x.flatten(2) #[B,embed_dim,num_patches]
        x = x.transpose(1, 2) #[B,num_patches,embed_dim]
        return x


class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim=192, num_heads=3, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=drop, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)

        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        #x [B,N,D]
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class SimpleViTWithAttributes(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 num_classes=200, num_attr=312, embed_dim=192, depth=6,
                 num_heads=3, mlp_ratio=4.0, drop=0.1):
        super().__init__()

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(drop)

        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, drop)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        self.head_class = nn.Linear(embed_dim, num_classes)
        # głowa atrybutowa
        self.head_attr  = nn.Linear(embed_dim, num_attr)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.head_class.weight, std=0.02)
        nn.init.trunc_normal_(self.head_attr.weight, std=0.02)

    def forward(self, x):
        # x:[B,3,224,224]
        B = x.size(0)
        x = self.patch_embed(x) #[B,N,D]

        cls_tokens = self.cls_token.expand(B,-1,-1) #[B,1,D]
        x = torch.cat((cls_tokens, x), dim=1) #[B,1+N,D]
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls = x[:, 0] #[B,D]

        logits = self.head_class(cls)
        attr_pred = self.head_attr(cls)
        return logits, attr_pred
    
model = SimpleViTWithAttributes(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    num_attr=NUM_ATTR,
    embed_dim=192,
    depth=6,
    num_heads=3,
    mlp_ratio=4.0,
    drop=0.1
).to(DEVICE)

model

SimpleViTWithAttributes(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-5): 6 x TransformerEncoderBlock(
      (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
      )
      (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=192, out_features=768, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=768, out_features=192, bias=True)
        (4): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  (head_class): Linear(in_features=192, out_features=200, bias=True)
  (head_attr): Linear(in_features=192, out_features=312, b

In [None]:
criterion_class = nn.CrossEntropyLoss()
criterion_attr  = nn.MSELoss()
LAMBDA_ATTR = 0.05

In [None]:
# ===== Option B: Self-Supervised MAE ("Nuclear Option") =====
# This uses the same IMG_SIZE, PATCH_SIZE, DEVICE, and train_loader already defined above.

class MAEViT(nn.Module):
    """
    Masked Autoencoder with ViT backbone.
    - Encoder: patch embedding + TransformerEncoder
    - Decoder: tiny transformer to reconstruct masked patches
    """
    def __init__(
        self,
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        in_chans=3,
        embed_dim=192,
        depth=6,
        num_heads=3,
        decoder_embed_dim=128,
        decoder_depth=4,
        decoder_num_heads=4,
        mlp_ratio=4.0,
        mask_ratio=0.75,
    ):
        super().__init__()

        # --- Encoder ---
        self.mask_ratio = mask_ratio
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # --- Decoder ---
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, decoder_embed_dim)
        )

        decoder_layer = nn.TransformerEncoderLayer(
            d_model=decoder_embed_dim,
            nhead=decoder_num_heads,
            dim_feedforward=int(decoder_embed_dim * mlp_ratio),
            batch_first=True,
        )
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=decoder_depth)

        patch_dim = patch_size * patch_size * in_chans
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_dim)

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.num_patches = num_patches
        self.norm_pix_loss = True

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    # ------- patchify / unpatchify helpers -------

    def patchify(self, imgs):
        """
        imgs: [B, 3, H, W] -> [B, L, patch_dim]
        """
        p = self.patch_size
        B, C, H, W = imgs.shape
        assert H == self.img_size and W == self.img_size

        h = H // p
        w = W // p
        x = imgs.reshape(B, C, h, p, w, p)
        x = x.permute(0, 2, 4, 3, 5, 1)      # [B, h, w, p, p, C]
        x = x.reshape(B, h * w, p * p * C)   # [B, L, patch_dim]
        return x

    def unpatchify(self, x):
        """
        x: [B, L, patch_dim] -> [B, 3, H, W]
        """
        p = self.patch_size
        B, L, patch_dim = x.shape
        C = self.in_chans
        h = w = int(L ** 0.5)
        assert h * w == L

        x = x.reshape(B, h, w, p, p, C)
        x = x.permute(0, 5, 1, 3, 2, 4)
        imgs = x.reshape(B, C, h * p, w * p)
        return imgs

    # ------- random masking -------

    def random_masking(self, x, mask_ratio):
        """
        x: [B, L, D]
        Returns:
          x_masked, mask, ids_restore
        mask: 1 = masked, 0 = visible
        """
        B, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(B, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )

        mask = torch.ones(B, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    # ------- encoder / decoder forward -------

    def forward_encoder(self, imgs):
        x = self.patch_embed(imgs)              # [B, L, D]
        x = x + self.pos_embed                  # [B, L, D]
        x_masked, mask, ids_restore = self.random_masking(x, self.mask_ratio)
        x_encoded = self.encoder(x_masked)      # [B, L_visible, D]
        return x_encoded, mask, ids_restore

    def forward_decoder(self, x_encoded, ids_restore):
        B, L_vis, D = x_encoded.shape
        x = self.decoder_embed(x_encoded)       # [B, L_vis, D_dec]

        L = self.num_patches
        L_mask = L - L_vis
        mask_tokens = self.mask_token.repeat(B, L_mask, 1)

        x_ = torch.cat([x, mask_tokens], dim=1)       # [B, L, D_dec]
        index = ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2])
        x_full = torch.gather(x_, dim=1, index=index) # [B, L, D_dec]

        x_full = x_full + self.decoder_pos_embed
        x_full = self.decoder(x_full)                 # [B, L, D_dec]

        pred = self.decoder_pred(x_full)              # [B, L, patch_dim]
        return pred

    def loss(self, imgs, pred, mask):
        """
        imgs: [B, 3, H, W]
        pred: [B, L, patch_dim]
        mask: [B, L]
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1e-6).sqrt()

        loss_per_patch = (pred - target) ** 2
        loss_per_patch = loss_per_patch.mean(dim=-1)      # [B, L]
        loss = (loss_per_patch * mask).sum() / mask.sum()
        return loss

    def forward(self, imgs):
        x_encoded, mask, ids_restore = self.forward_encoder(imgs)
        pred = self.forward_decoder(x_encoded, ids_restore)
        loss = self.loss(imgs, pred, mask)
        return loss, pred, mask


In [None]:
def pretrain_mae(mae_model, loader, epochs=20, lr=1e-4, weight_decay=0.05):
    """
    Self-supervised MAE pretraining on the train_loader.
    We ignore attributes and labels here.
    Tracks reconstruction loss and patch-level MSE as a proxy metric.
    """
    optimizer = optim.AdamW(mae_model.parameters(), lr=lr, weight_decay=weight_decay)
    mae_model.train()

    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        total_mse = 0.0
        samples = 0

        for imgs, _, _ in loader:   # ignore attr_vec, label
            imgs = imgs.to(DEVICE)

            optimizer.zero_grad()
            loss, pred, mask = mae_model(imgs)
            loss.backward()
            optimizer.step()

            bs = imgs.size(0)
            total_loss += loss.item() * bs
            
            # Compute patch-level MSE as reconstruction quality metric
            # (lower is better, inverse of "accuracy")
            patch_mse = ((pred - mae_model.patchify(imgs)) ** 2).mean().item()
            total_mse += patch_mse * bs
            
            samples += bs

        avg_loss = total_loss / samples
        avg_mse = total_mse / samples
        
        # MSE is inverted: lower MSE = better reconstruction = "higher accuracy"
        # Compute a proxy accuracy: 1 / (1 + MSE) to show improvement trend
        proxy_acc = 1.0 / (1.0 + avg_mse)
        
        print(f"[MAE] Epoch {epoch}/{epochs} - loss: {avg_loss:.4f} | Recon MSE: {avg_mse:.4f} | Proxy Acc: {proxy_acc:.4f}")

    return mae_model

# ---- actually run MAE pretraining (once) ----
mae = MAEViT(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_chans=3,
    embed_dim=192,
    depth=6,
    num_heads=3,
    decoder_embed_dim=128,
    decoder_depth=4,
    decoder_num_heads=4,
    mask_ratio=0.75,   # 75% masked = nuclear
).to(DEVICE)

mae = pretrain_mae(mae, train_loader, epochs=20, lr=1e-4, weight_decay=0.05)
torch.save(mae.state_dict(), "mae_pretrained_nuclear.pth")
print("Saved MAE nuclear weights to mae_pretrained_nuclear.pth")

[MAE] Epoch 1/20 - loss: 0.9152 | Recon MSE: 0.4048 | Proxy Acc: 0.7118
[MAE] Epoch 2/20 - loss: 0.7871 | Recon MSE: 0.3894 | Proxy Acc: 0.7197
[MAE] Epoch 2/20 - loss: 0.7871 | Recon MSE: 0.3894 | Proxy Acc: 0.7197
[MAE] Epoch 3/20 - loss: 0.7655 | Recon MSE: 0.4013 | Proxy Acc: 0.7136
[MAE] Epoch 3/20 - loss: 0.7655 | Recon MSE: 0.4013 | Proxy Acc: 0.7136
[MAE] Epoch 4/20 - loss: 0.7570 | Recon MSE: 0.4044 | Proxy Acc: 0.7120
[MAE] Epoch 4/20 - loss: 0.7570 | Recon MSE: 0.4044 | Proxy Acc: 0.7120
[MAE] Epoch 5/20 - loss: 0.7527 | Recon MSE: 0.4035 | Proxy Acc: 0.7125
[MAE] Epoch 5/20 - loss: 0.7527 | Recon MSE: 0.4035 | Proxy Acc: 0.7125
[MAE] Epoch 6/20 - loss: 0.7387 | Recon MSE: 0.3947 | Proxy Acc: 0.7170
[MAE] Epoch 6/20 - loss: 0.7387 | Recon MSE: 0.3947 | Proxy Acc: 0.7170
[MAE] Epoch 7/20 - loss: 0.7381 | Recon MSE: 0.3937 | Proxy Acc: 0.7175
[MAE] Epoch 7/20 - loss: 0.7381 | Recon MSE: 0.3937 | Proxy Acc: 0.7175
[MAE] Epoch 8/20 - loss: 0.7352 | Recon MSE: 0.3951 | Proxy Acc:

In [19]:
def train_one_epoch(epoch_idx):
    model.train()
    total_loss = 0.0
    total_cls  = 0.0
    total_attr = 0.0
    correct = 0
    samples = 0

    for batch_idx, (imgs, attr_targets, labels) in enumerate(train_loader):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        attr_targets = attr_targets.to(DEVICE)

        optimizer.zero_grad()

        logits, attr_pred = model(imgs)

        loss_cls = criterion_class(logits, labels)
        loss_attr = criterion_attr(attr_pred, attr_targets)
        loss = loss_cls + LAMBDA_ATTR * loss_attr

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        total_cls  += loss_cls.item() * imgs.size(0)
        total_attr += loss_attr.item() * imgs.size(0)

        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        samples += imgs.size(0)

        if batch_idx % 20 == 0:
            print(f"[Epoch {epoch_idx}] Batch {batch_idx}/{len(train_loader)} "
                  f"loss={loss.item():.4f}")

    avg_loss = total_loss / samples
    avg_cls  = total_cls  / samples
    avg_attr = total_attr / samples
    acc = correct / samples
    return avg_loss, avg_cls, avg_attr, acc


def evaluate():
    model.eval()
    total_loss = 0.0
    total_cls  = 0.0
    total_attr = 0.0
    correct = 0
    samples = 0

    with torch.no_grad():
        for imgs, attr_targets, labels in val_loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            attr_targets = attr_targets.to(DEVICE)

            logits, attr_pred = model(imgs)

            loss_cls = criterion_class(logits, labels)
            loss_attr = criterion_attr(attr_pred, attr_targets)
            loss = loss_cls + LAMBDA_ATTR * loss_attr

            total_loss += loss.item() * imgs.size(0)
            total_cls  += loss_cls.item() * imgs.size(0)
            total_attr += loss_attr.item() * imgs.size(0)

            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            samples += imgs.size(0)

    avg_loss = total_loss / samples
    avg_cls  = total_cls  / samples
    avg_attr = total_attr / samples
    acc = correct / samples
    return avg_loss, avg_cls, avg_attr, acc

In [20]:
# ---- Training loop reusing your existing functions ----
best_val_acc = 0.0
for epoch in range(1, EPOCHS + 1):
    print(f"\n[NUCLEAR MAE] Epoch {epoch}/{EPOCHS}")
    train_loss, train_cls, train_attr, train_acc = train_one_epoch(epoch)
    val_loss, val_cls, val_attr, val_acc = evaluate()
    scheduler.step()
    
    print(f"  Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "mae_nuclear_classifier.pth")
        print("Saved new best nuclear MAE classifier")

print("Best val acc (nuclear / no Bayes):", best_val_acc)


[NUCLEAR MAE] Epoch 1/25


NameError: name 'optimizer' is not defined

In [21]:
best_params = study.best_trial.params
best_lr           = best_params["lr"]
best_weight_decay = best_params["weight_decay"]
best_drop         = best_params["drop"]
best_lambda_attr  = best_params["lambda_attr"]

reset_seed()

best_vit = SimpleViTWithAttributes(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    num_attr=NUM_ATTR,
    embed_dim=192,
    depth=6,
    num_heads=3,
    mlp_ratio=4.0,
    drop=best_drop,   
).to(DEVICE)

LAMBDA_ATTR = best_lambda_attr

optimizer = optim.AdamW(model.parameters(), lr=best_lr, weight_decay=best_weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

best_val_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    print(f"\n[FINAL TRAIN] Epoch {epoch}/{EPOCHS}")
    train_one_epoch(epoch)
    _, _, _, val_acc = evaluate()
    scheduler.step()

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "vit_best_model_optimized.pth")
        print("Saved new best optimized model")


NameError: name 'study' is not defined

In [None]:
print("Number of finished trials:", len(study.trials))
print("Best trial value (val_acc):", study.best_trial.value)
print("Best hyperparameters:")
for k, v in study.best_trial.params.items():
    print(f"  {k}: {v}")


Number of finished trials: 25
Best trial value (val_acc): 0.023769100169779286
Best hyperparameters:
  lr: 0.0004273669312062059
  weight_decay: 2.3739423955292538e-05
  drop: 0.10912464381815776
  lambda_attr: 0.010368427651183114


In [None]:
best_vit = SimpleViTWithAttributes(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    num_attr=NUM_ATTR,
    embed_dim=192,
    depth=6,
    num_heads=3,
    mlp_ratio=4.0,
    drop=best_drop,  
).to(DEVICE)

best_vit.load_state_dict(torch.load("vit_best_model.pth", map_location=DEVICE))
best_vit.eval()

all_ids = []
all_preds = []

# ---- Evaluate on test set ----
print("\n=== Evaluating on Test Set ===")
test_correct = 0
test_samples = 0

with torch.no_grad():
    for imgs, img_ids in test_loader:
        imgs = imgs.to(DEVICE)
        logits, attr_pred = best_vit(imgs)
        preds = logits.argmax(dim=1).cpu().numpy()

        all_ids.extend(img_ids.numpy().tolist())
        all_preds.extend(preds.tolist())
        
        # Note: test_loader doesn't have labels, so we can't compute true accuracy
        # But we can track predictions
        test_samples += len(img_ids)

print(f"Total test samples processed: {test_samples}")
print(f"Predictions made: {len(all_preds)}")

kaggle_labels = [p + 1 for p in all_preds]

submission_vit = pd.DataFrame({
    "id": all_ids,
    "label": kaggle_labels
}).sort_values("id")

print("\nFirst few predictions:")
print(submission_vit.head())

submission_vit.to_csv("vit_submission.csv", index=False)
print("\n✓ Saved vit_submission.csv")

   id  label
0   1     33
1   2     55
2   3     28
3   4     34
4   5     16

Saved vit_submission.csv


In [12]:
# ---- Final Evaluation on Validation Set (for reference accuracy) ----
print("\n=== Final Validation Set Evaluation ===")
best_vit.eval()
val_correct = 0
val_samples = 0
val_total_loss = 0.0

with torch.no_grad():
    for imgs, attr_targets, labels in val_loader:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        attr_targets = attr_targets.to(DEVICE)

        logits, attr_pred = best_vit(imgs)
        
        loss_cls = criterion_class(logits, labels)
        loss_attr = criterion_attr(attr_pred, attr_targets)
        loss = loss_cls + LAMBDA_ATTR * loss_attr
        
        preds = logits.argmax(dim=1)
        val_correct += (preds == labels).sum().item()
        val_samples += imgs.size(0)
        val_total_loss += loss.item() * imgs.size(0)

val_acc = val_correct / val_samples
val_loss_avg = val_total_loss / val_samples

print(f"Validation Accuracy: {val_acc:.4f}")
print(f"Validation Loss: {val_loss_avg:.4f}")
print(f"Correct Predictions: {val_correct}/{val_samples}")



=== Final Validation Set Evaluation ===


NameError: name 'best_vit' is not defined