In [65]:
import os
import random

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, random_split

from torchvision import transforms
from torchvision.transforms import InterpolationMode
from datasets import load_from_disk
from PIL import Image

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
_cuda_available = torch.cuda.is_available()
DEVICE = torch.device("cuda") if _cuda_available else torch.device("cpu")
print("Using device:", DEVICE)


Using device: cpu


In [66]:
# Cell 2: Config & Seeds
# Hyperparameters and image geometry shared by every later cell.

RANDOM_SEED = 42
BATCH_SIZE = 32
EPOCHS = 25
IMG_SIZE = 224
PATCH_SIZE = 16   # 224 / 16 = 14 patches per side -> 196 patches total

# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU/GPU) so runs are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

print("Random seed set to:", RANDOM_SEED)


Random seed set to: 42


In [67]:
from pathlib import Path
import pandas as pd

DATA_DIR = Path("data")   # root folder holding the CSVs, attributes.npy and the image folders

# --- CSVs with paths + labels ---
train_df = pd.read_csv(DATA_DIR / "train_images.csv")
test_df  = pd.read_csv(DATA_DIR / "test_images_path.csv")

print("Train CSV shape:", train_df.shape)
print("Test  CSV shape:", test_df.shape)
print(train_df.head())

# --- Attributes ---
ATTR_PATH = DATA_DIR / "attributes.npy"
attributes = np.load(ATTR_PATH)
# Rows are indexed by class, columns by attribute; anything but 2-D breaks the lines below.
assert attributes.ndim == 2, f"expected a 2-D attribute matrix, got shape {attributes.shape}"
NUM_CLASSES = attributes.shape[0]
NUM_ATTR    = attributes.shape[1]

print("\nAttributes shape:", attributes.shape)
print("NUM_CLASSES:", NUM_CLASSES, "| NUM_ATTR:", NUM_ATTR)

# --- Sanity checks: fail here instead of deep inside a DataLoader ---
# The dataset classes read these columns and treat CSV labels as 1-based
# (label 1..NUM_CLASSES -> attribute row 0..NUM_CLASSES-1). A label of 0 would
# otherwise silently index attribute row -1 (the last class).
_missing_train_cols = {"image_path", "label"} - set(train_df.columns)
assert not _missing_train_cols, f"train CSV is missing columns: {_missing_train_cols}"
_missing_test_cols = {"image_path", "id"} - set(test_df.columns)
assert not _missing_test_cols, f"test CSV is missing columns: {_missing_test_cols}"
_label_min, _label_max = int(train_df["label"].min()), int(train_df["label"].max())
assert 1 <= _label_min and _label_max <= NUM_CLASSES, (
    f"train labels must lie in [1, {NUM_CLASSES}], got [{_label_min}, {_label_max}]"
)

# image root folders
TRAIN_IMG_DIR = DATA_DIR / "train_images"
TEST_IMG_DIR  = DATA_DIR / "test_images"


Train CSV shape: (3926, 2)
Test  CSV shape: (4000, 3)
            image_path  label
0  /train_images/1.jpg      1
1  /train_images/2.jpg      1
2  /train_images/3.jpg      1
3  /train_images/4.jpg      1
4  /train_images/5.jpg      1

Attributes shape: (200, 312)
NUM_CLASSES: 200 | NUM_ATTR: 312


In [68]:
# Cell 4: Transforms

# Shared final step: ToTensor yields pixels in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1].
_NORMALIZE = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])

# Training pipeline: random crop/flip/rotation plus mild colour jitter and blur.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=25, interpolation=InterpolationMode.BILINEAR),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ToTensor(),
    _NORMALIZE,
])

# Evaluation pipeline: deterministic resize only, no augmentation.
eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    _NORMALIZE,
])


In [69]:
# Cell 5: Datasets & Dataloaders

from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
import torch

class BirdTrainDataset(Dataset):
    """Labelled training images paired with their class-attribute vectors.

    Each item is ``(image, attribute_vector, label)``:
      - ``image``: the RGB PIL image, passed through ``transform`` if given;
      - ``attribute_vector``: float32 tensor ``attributes[label]``;
      - ``label``: long tensor holding the 0-based class index.

    Args:
        df: DataFrame with an image-path column and a label column.
        attributes: 2-D array; row ``i`` holds the attributes of 0-based class ``i``.
        img_root: folder containing the image files. Only the file *name* of
            ``df[img_col]`` is used, so CSV paths like ``/train_images/1.jpg``
            resolve to ``img_root / "1.jpg"``.
        img_col: name of the image-path column.
        label_col: name of the label column.
        transform: optional callable applied to the PIL image.
        label_offset: subtracted from the CSV label to get the 0-based class
            index (default 1, i.e. CSV labels are 1..num_classes).

    Raises (from ``__getitem__``):
        IndexError: if a row's label does not map to a valid attribute row.
    """

    def __init__(self, df, attributes, img_root, img_col="image_path", label_col="label",
                 transform=None, label_offset=1):
        self.df = df.reset_index(drop=True)
        self.attributes = attributes.astype("float32")
        self.img_root = Path(img_root)
        self.img_col = img_col
        self.label_col = label_col
        self.transform = transform
        self.label_offset = label_offset

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Convert to a 0-based class index and validate BEFORE indexing: a CSV
        # label of 0 would otherwise become -1 and, via negative indexing,
        # silently return the LAST class's attribute vector.
        label = int(row[self.label_col]) - self.label_offset
        num_classes = len(self.attributes)
        if not 0 <= label < num_classes:
            raise IndexError(
                f"row {idx}: label {row[self.label_col]} maps to class index {label}, "
                f"expected 0..{num_classes - 1} (label_offset={self.label_offset})"
            )
        attr_vec = self.attributes[label]

        # Resolve by file name only so absolute paths stored in the CSV work.
        img_path = self.img_root / Path(row[self.img_col]).name
        # Context manager guarantees the file handle is released even on errors.
        with Image.open(img_path) as pil_img:
            img = pil_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return (
            img,
            torch.tensor(attr_vec, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long),
        )


class BirdTestDataset(Dataset):
    """Unlabelled test images; each item is ``(image, image_id)``.

    Args:
        df: DataFrame with an image-path column and an id column.
        img_root: folder containing the image files. Only the file *name* of
            ``df[img_col]`` is used, so absolute CSV paths resolve under ``img_root``.
        img_col: name of the image-path column.
        id_col: name of the id column; its value is returned as a Python int.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, df, img_root, img_col="image_path", id_col="id", transform=None):
        self.df = df.reset_index(drop=True)
        self.img_root = Path(img_root)
        self.img_col = img_col
        self.id_col = id_col
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Resolve by file name only so absolute paths stored in the CSV work.
        img_path = self.img_root / Path(row[self.img_col]).name
        # Context manager guarantees the file handle is released even on errors.
        with Image.open(img_path) as pil_img:
            img = pil_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        img_id = int(row[self.id_col])
        return img, img_id

# --- Setup Datasets & Loaders ---

# Two views over the SAME rows of train_df: an augmented one for training and
# a deterministic (eval_transform) one for validation. Previously the val split
# was carved out of the augmented dataset, so validation images were randomly
# cropped/rotated/blurred and val metrics were noisy and pessimistic.
full_train_dataset = BirdTrainDataset(
    df=train_df,
    attributes=attributes,
    img_root=TRAIN_IMG_DIR,
    img_col="image_path",
    label_col="label",
    transform=train_transform,
)
full_eval_view = BirdTrainDataset(
    df=train_df,
    attributes=attributes,
    img_root=TRAIN_IMG_DIR,
    img_col="image_path",
    label_col="label",
    transform=eval_transform,
)

# 80/20 split into train / val. The seeded generator keeps the split identical
# across runs; the held-out indices are then served through the eval view.
train_size = int(0.8 * len(full_train_dataset))
val_size   = len(full_train_dataset) - train_size

train_dataset, _val_split = random_split(
    full_train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(RANDOM_SEED),
)
val_dataset = Subset(full_eval_view, _val_split.indices)

test_dataset = BirdTestDataset(
    df=test_df,
    img_root=TEST_IMG_DIR,
    img_col="image_path",
    id_col="id",
    transform=eval_transform,
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print("Train batches:", len(train_loader))
print("Val batches:", len(val_loader))
print("Test batches:", len(test_loader))

Train batches: 99
Val batches: 25
Test batches: 125


In [70]:
# Cell 6: ViT Building Blocks (PatchEmbed & TransformerEncoderBlock)

class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one linearly.

    Implemented as a Conv2d whose kernel size and stride both equal the patch
    size, so every output position corresponds to exactly one patch.

    Input:  [B, in_chans, H, W]
    Output: [B, N, embed_dim], N = (H // patch_size) * (W // patch_size),
            patches ordered row-major.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=192):
        super().__init__()
        grid = img_size // patch_size
        self.img_size, self.patch_size = img_size, patch_size
        self.grid_size = grid
        self.num_patches = grid * grid
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # [B, C, H, W] -> [B, D, H/P, W/P] -> [B, D, N] -> [B, N, D]
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)


class TransformerEncoderBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + MHA(LN1(x))`` followed by ``x + MLP(LN2(x))``, where the MLP
    is Linear -> GELU -> Dropout -> Linear -> Dropout with hidden size
    ``int(embed_dim * mlp_ratio)``. Attention uses ``batch_first=True``, so
    inputs/outputs are [B, N, embed_dim].
    """

    def __init__(self, embed_dim=192, num_heads=3, mlp_ratio=4.0, drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=drop, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)

        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(drop),
        )

    def _self_attention(self, x):
        # Self-attention on the normalised tokens (query = key = value).
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        return attended

    def forward(self, x):
        x = x + self._self_attention(x)
        return x + self.mlp(self.norm2(x))


In [71]:
# Cell 7: Loss Functions (classification + attribute regression)
# Joint objective: CrossEntropy(logits, labels) + LAMBDA_ATTR * MSE(attr_pred, attr_targets).

criterion_class, criterion_attr = nn.CrossEntropyLoss(), nn.MSELoss()
LAMBDA_ATTR = 0.05   # weight for attribute regression loss

print("LAMBDA_ATTR =", LAMBDA_ATTR)


LAMBDA_ATTR = 0.05


In [72]:
# Cell 8: MAEViT ("Nuclear Option": Mask 75% and reconstruct)

class MAEViT(nn.Module):
    """
    Masked Autoencoder with ViT backbone.

    Encoder = patch embedding + learnable positional embedding + transformer
              encoder that only processes the visible (unmasked) patches.
    Decoder = tiny transformer that receives the encoded visible tokens plus a
              learnable mask token for every hidden patch and predicts the
              pixels of every patch.
    The reconstruction loss is averaged over the masked patches only.

    Args:
        img_size: side length of the square input image; must be divisible
            by ``patch_size``.
        patch_size: side length of each square patch.
        in_chans: number of image channels.
        embed_dim, depth, num_heads: encoder width, layer count and heads.
        decoder_embed_dim, decoder_depth, decoder_num_heads: decoder sizes.
        mlp_ratio: feed-forward size as a multiple of the model width
            (applied to both encoder and decoder).
        mask_ratio: fraction of patches hidden from the encoder, in [0, 1).
        norm_pix_loss: if True (default), each target patch is normalised to
            zero mean / unit variance before computing the loss.

    Raises:
        ValueError: if ``img_size`` is not divisible by ``patch_size`` or
            ``mask_ratio`` is outside [0, 1).
    """
    def __init__(
        self,
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        in_chans=3,
        embed_dim=192,
        depth=6,
        num_heads=3,
        decoder_embed_dim=128,
        decoder_depth=4,
        decoder_num_heads=4,
        mlp_ratio=4.0,
        mask_ratio=0.75,
        norm_pix_loss=True,
    ):
        super().__init__()

        # Fail fast on configurations the rest of the class cannot handle:
        # patchify() reshapes by exact patch counts, and mask_ratio == 1 would
        # leave the encoder with no visible tokens.
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        if not 0.0 <= mask_ratio < 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1), got {mask_ratio}")

        self.mask_ratio = mask_ratio
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, decoder_embed_dim)
        )

        decoder_layer = nn.TransformerEncoderLayer(
            d_model=decoder_embed_dim,
            nhead=decoder_num_heads,
            dim_feedforward=int(decoder_embed_dim * mlp_ratio),
            batch_first=True,
        )
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=decoder_depth)

        patch_dim = patch_size * patch_size * in_chans
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_dim)

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.num_patches = num_patches
        self.norm_pix_loss = norm_pix_loss

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for embeddings/mask token; Xavier for all Linear layers."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    # ---- Patchify / Unpatchify ----

    def patchify(self, imgs):
        """
        imgs: [B, C, H, W] -> [B, L, patch_dim], each patch flattened as (p, p, C).
        """
        p = self.patch_size
        B, C, H, W = imgs.shape
        assert H == self.img_size and W == self.img_size

        h = H // p
        w = W // p
        x = imgs.reshape(B, C, h, p, w, p)
        x = x.permute(0, 2, 4, 3, 5, 1)      # [B, h, w, p, p, C]
        x = x.reshape(B, h * w, p * p * C)   # [B, L, patch_dim]
        return x

    def unpatchify(self, x):
        """
        x: [B, L, patch_dim] -> [B, C, H, W]; exact inverse of patchify().
        """
        p = self.patch_size
        B, L, _ = x.shape
        C = self.in_chans
        h = w = int(L ** 0.5)
        assert h * w == L

        x = x.reshape(B, h, w, p, p, C)
        x = x.permute(0, 5, 1, 3, 2, 4)      # [B, C, h, p, w, p]
        imgs = x.reshape(B, C, h * p, w * p)
        return imgs

    # ---- Random Masking ----

    def random_masking(self, x, mask_ratio):
        """
        Per-sample random masking by shuffling tokens with uniform noise.

        x: [B, L, D]
        Returns: x_masked [B, len_keep, D], mask [B, L], ids_restore [B, L]
        mask: 1 = masked, 0 = visible (in the original patch order)
        ids_restore: permutation that undoes the shuffle.
        """
        B, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(B, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep tokens of the random permutation.
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )

        # Build the mask in shuffled order, then unshuffle it to patch order.
        mask = torch.ones(B, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore

    # ---- Encoder & Decoder ----

    def forward_encoder(self, imgs):
        """Embed patches, add positions, drop masked tokens, encode the rest."""
        x = self.patch_embed(imgs)          # [B, L, D]
        x = x + self.pos_embed
        x_masked, mask, ids_restore = self.random_masking(x, self.mask_ratio)
        x_encoded = self.encoder(x_masked)
        return x_encoded, mask, ids_restore

    def forward_decoder(self, x_encoded, ids_restore):
        """Append mask tokens, unshuffle to patch order, decode to pixel predictions."""
        B, L_vis, _ = x_encoded.shape
        x = self.decoder_embed(x_encoded)   # [B, L_vis, D_dec]

        L = self.num_patches
        L_mask = L - L_vis
        mask_tokens = self.mask_token.repeat(B, L_mask, 1)

        x_ = torch.cat([x, mask_tokens], dim=1)        # [B, L, D_dec] (shuffled order)
        index = ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2])
        x_full = torch.gather(x_, dim=1, index=index)  # [B, L, D_dec] (patch order)

        x_full = x_full + self.decoder_pos_embed
        x_full = self.decoder(x_full)
        pred = self.decoder_pred(x_full)               # [B, L, patch_dim]
        return pred

    def loss(self, imgs, pred, mask):
        """Mean squared reconstruction error over masked patches only."""
        target = self.patchify(imgs)

        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1e-6).sqrt()

        loss_per_patch = (pred - target) ** 2
        loss_per_patch = loss_per_patch.mean(dim=-1)
        # clamp_min avoids 0/0 = NaN when no patch is masked (mask_ratio == 0).
        loss = (loss_per_patch * mask).sum() / mask.sum().clamp_min(1.0)
        return loss

    def forward(self, imgs):
        """Returns (loss scalar, pred [B, L, patch_dim], mask [B, L])."""
        x_encoded, mask, ids_restore = self.forward_encoder(imgs)
        pred = self.forward_decoder(x_encoded, ids_restore)
        loss = self.loss(imgs, pred, mask)
        return loss, pred, mask


In [73]:
import pandas as pd
from pathlib import Path

# Inspect the training CSV that Cell 3 already loaded. Re-reading it here would
# rebind `DATA_DIR` and `train_df` after the datasets were built from them,
# making notebook state depend on execution order.
assert {"image_path", "label"} <= set(train_df.columns), list(train_df.columns)
print(train_df.columns)
print(train_df.head())


Index(['image_path', 'label'], dtype='object')
            image_path  label
0  /train_images/1.jpg      1
1  /train_images/2.jpg      1
2  /train_images/3.jpg      1
3  /train_images/4.jpg      1
4  /train_images/5.jpg      1


In [None]:
# Cell 9: Self-supervised pretraining of MAE on train images
from tqdm.auto import tqdm  # specific import for progress bars in notebooks

def pretrain_mae(mae_model, loader, epochs=20, lr=1e-4, weight_decay=0.05):
    """Self-supervised MAE pretraining: minimise masked-patch reconstruction loss.

    Only the images from each batch are used; the attribute vectors and labels
    yielded by ``loader`` are ignored.

    Args:
        mae_model: model whose ``forward(imgs)`` returns ``(loss, pred, mask)``.
        loader: iterable of ``(imgs, attrs, labels)`` batches.
        epochs: number of passes over ``loader``.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        The same ``mae_model`` object, trained in place.

    Raises:
        ValueError: if ``loader`` yields no samples in an epoch (the average
            reconstruction loss would be undefined).
    """
    optimizer = optim.AdamW(mae_model.parameters(), lr=lr, weight_decay=weight_decay)
    mae_model.train()

    print(f"Starting MAE Pretraining for {epochs} epochs...")
    print("Metric to watch: 'Reconstruction Loss' (Lower is better)")

    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        samples = 0

        # Progress bar shows the running average loss for this epoch.
        pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=True)

        for imgs, _, _ in pbar:   # attributes + labels are not used for pretraining
            imgs = imgs.to(DEVICE)

            optimizer.zero_grad()
            loss, _, _ = mae_model(imgs)
            loss.backward()
            optimizer.step()

            bs = imgs.size(0)
            total_loss += loss.item() * bs
            samples += bs

            current_avg_loss = total_loss / samples
            pbar.set_postfix({"Recon Loss": f"{current_avg_loss:.4f}"})

        if samples == 0:
            raise ValueError("pretrain_mae: loader yielded no samples; cannot compute an average loss")

        print(f"[MAE] Epoch {epoch}/{epochs} completed. Avg Loss: {total_loss / samples:.4f}")

    return mae_model


# Architecture for the MAE; 75% of patches are masked (the "nuclear option").
_mae_arch = dict(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_chans=3,
    embed_dim=192,
    depth=6,
    num_heads=3,
    decoder_embed_dim=128,
    decoder_depth=4,
    decoder_num_heads=4,
    mask_ratio=0.75,
)
mae = MAEViT(**_mae_arch).to(DEVICE)

# Pretrain on the training images, then checkpoint the weights for fine-tuning.
mae = pretrain_mae(mae, train_loader, epochs=20, lr=1e-4, weight_decay=0.05)
torch.save(mae.state_dict(), "mae_pretrained_nuclear.pth")
print("Saved MAE nuclear weights -> mae_pretrained_nuclear.pth")

Starting MAE Pretraining for 20 epochs...
Metric to watch: 'Reconstruction Loss' (Lower is better)


Epoch 1/20:   0%|          | 0/99 [00:00<?, ?it/s]

In [None]:
# Cell 10: Classifier + attribute head reusing MAE encoder

class MAEViTClassifierWithAttributes(nn.Module):
    """
    Classifier built on top of a (pretrained) MAE encoder.

    Shares ``patch_embed``, ``pos_embed`` and ``encoder`` with ``mae_model``
    (the same module/parameter objects, not copies) and adds:
      - a LayerNorm + dropout on the mean-pooled patch tokens,
      - a class head producing ``num_classes`` logits,
      - an attribute regression head producing ``num_attr`` values.

    ``forward(x)`` returns ``(logits, attr_pred)``.
    """
    def __init__(self, mae_model: MAEViT, num_classes: int, num_attr: int, drop: float = 0.1):
        super().__init__()

        self.patch_embed = mae_model.patch_embed
        self.pos_embed   = mae_model.pos_embed
        self.encoder     = mae_model.encoder
        self.embed_dim   = self.pos_embed.size(-1)

        self.norm = nn.LayerNorm(self.embed_dim)
        self.dropout = nn.Dropout(drop)

        self.head_class = nn.Linear(self.embed_dim, num_classes)
        self.head_attr  = nn.Linear(self.embed_dim, num_attr)

    def _pooled_features(self, x):
        # Encode ALL patches (no masking), normalise, then global-average-pool over tokens.
        tokens = self.encoder(self.patch_embed(x) + self.pos_embed)   # [B, L, D]
        return self.dropout(self.norm(tokens).mean(dim=1))            # [B, D]

    def forward(self, x):
        features = self._pooled_features(x)
        return self.head_class(features), self.head_attr(features)


# Reload MAE weights into a fresh encoder with the same architecture used for pretraining.
mae_for_ft = MAEViT(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_chans=3,
    embed_dim=192,
    depth=6,
    num_heads=3,
    decoder_embed_dim=128,
    decoder_depth=4,
    decoder_num_heads=4,
    mask_ratio=0.75,
).to(DEVICE)

mae_for_ft.load_state_dict(torch.load("mae_pretrained_nuclear.pth", map_location=DEVICE))

# Freeze the pretrained backbone (patch embedding, positional embedding, encoder)
# so only the new norm/dropout/heads of the classifier are trained.
for frozen_module in (mae_for_ft.patch_embed, mae_for_ft.encoder):
    frozen_module.requires_grad_(False)
mae_for_ft.pos_embed.requires_grad_(False)

model = MAEViTClassifierWithAttributes(
    mae_model=mae_for_ft,
    num_classes=NUM_CLASSES,
    num_attr=NUM_ATTR,
    drop=0.1,
).to(DEVICE)

# LAMBDA_ATTR comes from the loss-config cell (Cell 7); it is deliberately not
# redefined here so that changing it there actually takes effect.

optimizer = optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-3,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

print("MAE classifier ready. Trainable params:",
      sum(p.numel() for p in model.parameters() if p.requires_grad))


In [None]:
# Cell 11: Training & evaluation loops (no Bayesian optimisation)

def _run_epoch(loader, train, epoch_idx=None, log_every=20):
    """Run one pass of the notebook-level ``model`` over ``loader``.

    Uses the globals ``model``, ``optimizer``, ``criterion_class``,
    ``criterion_attr``, ``LAMBDA_ATTR`` and ``DEVICE``. The per-batch objective
    is ``CE(logits, labels) + LAMBDA_ATTR * MSE(attr_pred, attr_targets)``.

    Args:
        loader: iterable of ``(imgs, attr_targets, labels)`` batches.
        train: if True, train mode + one optimizer step per batch; otherwise
            eval mode with gradients disabled.
        epoch_idx: epoch number shown in the training progress log.
        log_every: print a progress line every this many training batches.

    Returns:
        ``(avg_loss, avg_cls, avg_attr, acc)`` averaged over all samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train(train)
    total_loss = total_cls = total_attr = 0.0
    correct = samples = 0

    with torch.set_grad_enabled(train):
        for batch_idx, (imgs, attr_targets, labels) in enumerate(loader):
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            attr_targets = attr_targets.to(DEVICE)

            if train:
                optimizer.zero_grad()

            logits, attr_pred = model(imgs)
            loss_cls = criterion_class(logits, labels)
            loss_attr = criterion_attr(attr_pred, attr_targets)
            loss = loss_cls + LAMBDA_ATTR * loss_attr

            if train:
                loss.backward()
                optimizer.step()

            bs = imgs.size(0)
            total_loss += loss.item() * bs
            total_cls  += loss_cls.item() * bs
            total_attr += loss_attr.item() * bs

            batch_correct = (logits.argmax(dim=1) == labels).sum().item()
            correct += batch_correct
            samples += bs

            if train and batch_idx % log_every == 0:
                print(f"[Epoch {epoch_idx}] Batch {batch_idx}/{len(loader)} "
                      f"- Loss: {loss.item():.4f} | Batch Acc: {batch_correct / bs:.4f}")

    if samples == 0:
        raise ValueError("loader yielded no samples; epoch averages are undefined")

    return (total_loss / samples, total_cls / samples,
            total_attr / samples, correct / samples)


def train_one_epoch(epoch_idx):
    """Train ``model`` for one epoch on ``train_loader``.

    Returns ``(avg_loss, avg_cls, avg_attr, acc)`` over the epoch.
    """
    avg_loss, avg_cls, avg_attr, acc = _run_epoch(train_loader, train=True, epoch_idx=epoch_idx)
    print(f"[Epoch {epoch_idx}] Train Loss: {avg_loss:.4f} "
          f"(cls: {avg_cls:.4f}, attr: {avg_attr:.4f}) | "
          f"Train Acc: {acc:.4f}")
    return avg_loss, avg_cls, avg_attr, acc


def evaluate():
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Returns ``(avg_loss, avg_cls, avg_attr, acc)`` over the validation set.
    """
    avg_loss, avg_cls, avg_attr, acc = _run_epoch(val_loader, train=False)
    print(f"[Validation] Loss: {avg_loss:.4f} "
          f"(cls: {avg_cls:.4f}, attr: {avg_attr:.4f}) | "
          f"Val Acc: {acc:.4f}")
    return avg_loss, avg_cls, avg_attr, acc


In [None]:
# Cell 12: Train MAE-backed classifier ("nuclear option", no Bayes)

# Start below any achievable accuracy so the first epoch always writes a
# checkpoint. With 0.0 and a 0% val accuracy, no file would be saved and
# Cell 13 (which loads "mae_nuclear_classifier.pth") would crash.
best_val_acc = float("-inf")

for epoch in range(1, EPOCHS + 1):
    print(f"\n[NUCLEAR MAE] Epoch {epoch}/{EPOCHS}")
    train_loss, train_cls, train_attr, train_acc = train_one_epoch(epoch)
    val_loss, val_cls, val_attr, val_acc = evaluate()
    scheduler.step()

    print(f"  -> Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "mae_nuclear_classifier.pth")
        print("Saved new best nuclear MAE classifier -> mae_nuclear_classifier.pth")

print("Best validation accuracy (MAE nuclear):", best_val_acc)


In [None]:
# Cell 13: Final validation eval with best-saved MAE classifier

# Rebuild the architecture from scratch and load the saved checkpoints, so this
# cell reports the best-saved model rather than whatever `model` ended training as.
_eval_arch = dict(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_chans=3,
    embed_dim=192,
    depth=6,
    num_heads=3,
    decoder_embed_dim=128,
    decoder_depth=4,
    decoder_num_heads=4,
    mask_ratio=0.75,
)
mae_eval = MAEViT(**_eval_arch).to(DEVICE)
mae_eval.load_state_dict(torch.load("mae_pretrained_nuclear.pth", map_location=DEVICE))

best_mae = MAEViTClassifierWithAttributes(
    mae_model=mae_eval,
    num_classes=NUM_CLASSES,
    num_attr=NUM_ATTR,
    drop=0.1,
).to(DEVICE)
best_mae.load_state_dict(torch.load("mae_nuclear_classifier.pth", map_location=DEVICE))
best_mae.eval()

val_correct, val_samples, val_total_loss = 0, 0, 0.0

with torch.no_grad():
    for batch_imgs, batch_attrs, batch_labels in val_loader:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)
        batch_attrs = batch_attrs.to(DEVICE)

        batch_logits, batch_attr_pred = best_mae(batch_imgs)
        # Same joint objective as training: CE + LAMBDA_ATTR * MSE.
        batch_loss = (criterion_class(batch_logits, batch_labels)
                      + LAMBDA_ATTR * criterion_attr(batch_attr_pred, batch_attrs))

        n_batch = batch_imgs.size(0)
        val_correct += (batch_logits.argmax(dim=1) == batch_labels).sum().item()
        val_samples += n_batch
        val_total_loss += batch_loss.item() * n_batch

val_acc = val_correct / val_samples
val_loss_avg = val_total_loss / val_samples

print("\n=== Final Validation Evaluation (MAE Nuclear) ===")
print(f"Validation Accuracy: {val_acc:.4f}")
print(f"Validation Loss:     {val_loss_avg:.4f}")
print(f"Correct Predictions: {val_correct}/{val_samples}")


In [None]:
# Cell 14: Test predictions + submission using MAE nuclear model

# Requires `best_mae` from Cell 13, which rebuilds the classifier from the saved
# checkpoints; after a kernel restart, re-run Cell 13 before this cell.
best_mae.eval()

all_ids = []
all_preds = []
test_samples = 0

with torch.no_grad():
    for imgs, img_ids in test_loader:
        imgs = imgs.to(DEVICE)
        logits, _ = best_mae(imgs)  # attribute predictions are not needed for the submission

        # Store plain Python ints. The default collate function batches the int
        # ids into a tensor; extending a list with it would store 0-d tensors,
        # which `to_csv` writes as "tensor(123)" instead of "123".
        all_ids.extend(torch.as_tensor(img_ids).tolist())
        all_preds.extend(logits.argmax(dim=1).cpu().tolist())
        test_samples += len(img_ids)

print(f"Total test samples processed: {test_samples}")
print(f"Predictions made:             {len(all_preds)}")
assert len(all_ids) == len(all_preds) == test_samples

# Show predicted class distribution
all_preds_np = np.array(all_preds)
unique, counts = np.unique(all_preds_np, return_counts=True)
print("\nPredicted class distribution on test set (0-based):")
for u, c in zip(unique, counts):
    print(f"  class {u}: {c} samples")

# The model predicts 0-based class indices (BirdTrainDataset subtracts 1 from the
# CSV labels); shift back to the CSV's 1-based labels for the submission.
kaggle_labels = [p + 1 for p in all_preds]

submission = pd.DataFrame({
    "id": all_ids,
    "label": kaggle_labels,
}).sort_values("id")
assert submission["id"].is_unique, "duplicate test ids in submission"

print("\nFirst few predictions:")
print(submission.head())

submission.to_csv("mae_nuclear_submission.csv", index=False)
print("\n✓ Saved mae_nuclear_submission.csv")