In [1]:
import os
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, random_split

from torchvision import transforms
from torchvision.transforms import InterpolationMode
from datasets import load_from_disk
from PIL import Image

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

Using device: cpu


In [2]:
RANDOM_SEED = 42
BATCH_SIZE = 32
EPOCHS = 25
IMG_SIZE = 224
PATCH_SIZE = 16  
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

print("Random seed set to:", RANDOM_SEED)

Random seed set to: 42


In [3]:
from pathlib import Path

DATA_DIR = Path("data")   

train_df = pd.read_csv(DATA_DIR / "train_images.csv")
test_df  = pd.read_csv(DATA_DIR / "test_images_path.csv")

print("Train CSV shape:", train_df.shape)
print("Test  CSV shape:", test_df.shape)
print(train_df.head())

# attributes 
ATTR_PATH = DATA_DIR / "attributes.npy"
attributes = np.load(ATTR_PATH)
NUM_CLASSES = attributes.shape[0]
NUM_ATTR    = attributes.shape[1]

print("\nAttributes shape:", attributes.shape)
print("NUM_CLASSES:", NUM_CLASSES, "| NUM_ATTR:", NUM_ATTR)

# image root folders
TRAIN_IMG_DIR = DATA_DIR / "train_images"
TEST_IMG_DIR  = DATA_DIR / "test_images"

Train CSV shape: (3926, 2)
Test  CSV shape: (4000, 3)
            image_path  label
0  /train_images/1.jpg      1
1  /train_images/2.jpg      1
2  /train_images/3.jpg      1
3  /train_images/4.jpg      1
4  /train_images/5.jpg      1

Attributes shape: (200, 312)
NUM_CLASSES: 200 | NUM_ATTR: 312


In [4]:
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=25, interpolation=InterpolationMode.BILINEAR),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

In [None]:
import torch

class BirdTrainDataset(Dataset):
    def __init__(self, df, attributes, img_root, img_col="image_path", label_col="label", transform=None):
        self.df = df.reset_index(drop=True)
        self.attributes = attributes.astype("float32")
        self.img_root = Path(img_root)
        self.img_col = img_col
        self.label_col = label_col
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        filename = Path(row[self.img_col]).name
        img_path = self.img_root / filename

        img = Image.open(img_path).convert("RGB")
        
        label = int(row[self.label_col]) - 1

        attr_vec = self.attributes[label]  
        if self.transform is not None:
            img = self.transform(img)

        return (
            img,
            torch.tensor(attr_vec, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long),
        )


class BirdTestDataset(Dataset):
    def __init__(self, df, img_root, img_col="image_path", id_col="id", transform=None):
        self.df = df.reset_index(drop=True)
        self.img_root = Path(img_root)
        self.img_col = img_col
        self.id_col = id_col
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        filename = Path(row[self.img_col]).name
        img_path = self.img_root / filename

        img = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        img_id = int(row[self.id_col])
        return img, img_id


full_train_dataset = BirdTrainDataset(
    df=train_df,
    attributes=attributes,
    img_root=TRAIN_IMG_DIR,
    img_col="image_path",
    label_col="label",
    transform=train_transform,
)

# 80/20 split into train / val
train_size = int(0.8 * len(full_train_dataset))
val_size   = len(full_train_dataset) - train_size

train_dataset, val_dataset = random_split(
    full_train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(RANDOM_SEED),
)

test_dataset = BirdTestDataset(
    df=test_df,
    img_root=TEST_IMG_DIR,
    img_col="image_path",
    id_col="id",
    transform=eval_transform,
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print("Train batches:", len(train_loader))
print("Val batches:", len(val_loader))
print("Test batches:", len(test_loader))

In [6]:
# ViT Building Blocks (PatchEmbed & TransformerEncoderBlock)

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        #x:[B,3,H,W]
        x = self.proj(x) # [B, embed_dim, H/P, W/P]
        x = x.flatten(2) # [B, embed_dim, N]
        x = x.transpose(1, 2) # [B, N, embed_dim]
        return x


class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim=192, num_heads=3, mlp_ratio=4.0, drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=drop,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(embed_dim)

        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


In [7]:
# loss functions 
criterion_class = nn.CrossEntropyLoss()
criterion_attr  = nn.MSELoss()
LAMBDA_ATTR = 0.05 # weight for attribute regression loss

print("LAMBDA_ATTR =", LAMBDA_ATTR)

LAMBDA_ATTR = 0.05


In [8]:
class MAEViT(nn.Module):
    """
    Masked Autoencoder with ViT backbone.
    Encoder = patch embedding + transformer encoder
    Decoder = tiny transformer that reconstructs masked patches.
    """
    def __init__(
        self,
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        in_chans=3,
        embed_dim=192,
        depth=6,
        num_heads=3,
        decoder_embed_dim=128,
        decoder_depth=4,
        decoder_num_heads=4,
        mlp_ratio=4.0,
        mask_ratio=0.75,
    ):
        super().__init__()

        self.mask_ratio = mask_ratio
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, decoder_embed_dim)
        )

        decoder_layer = nn.TransformerEncoderLayer(
            d_model=decoder_embed_dim,
            nhead=decoder_num_heads,
            dim_feedforward=int(decoder_embed_dim * mlp_ratio),
            batch_first=True,
        )
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=decoder_depth)

        patch_dim = patch_size * patch_size * in_chans
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_dim)

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.num_patches = num_patches
        self.norm_pix_loss = True

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    # patchify / Unpatchify 
    def patchify(self, imgs):
        """
        imgs: [B, 3, H, W] -> [B, L, patch_dim]
        """
        p = self.patch_size
        B, C, H, W = imgs.shape
        assert H == self.img_size and W == self.img_size

        h = H // p
        w = W // p
        x = imgs.reshape(B, C, h, p, w, p)
        x = x.permute(0, 2, 4, 3, 5, 1)  # [B, h, w, p, p, C]
        x = x.reshape(B, h * w, p * p * C) # [B, L, patch_dim]
        return x

    def unpatchify(self, x):
        """
        x: [B, L, patch_dim] -> [B, 3, H, W]
        """
        p = self.patch_size
        B, L, patch_dim = x.shape
        C = self.in_chans
        h = w = int(L ** 0.5)
        assert h * w == L

        x = x.reshape(B, h, w, p, p, C)
        x = x.permute(0, 5, 1, 3, 2, 4)
        imgs = x.reshape(B, C, h * p, w * p)
        return imgs

    # Random Masking 
    def random_masking(self, x, mask_ratio):
        """
        x: [B, L, D]
        Returns: x_masked, mask, ids_restore
        mask: 1 = masked, 0 = visible
        """
        B, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(B, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )

        mask = torch.ones(B, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore

    # encoder, decoder
    def forward_encoder(self, imgs):
        x = self.patch_embed(imgs) # [B, L, D]
        x = x + self.pos_embed
        x_masked, mask, ids_restore = self.random_masking(x, self.mask_ratio)
        x_encoded = self.encoder(x_masked)
        return x_encoded, mask, ids_restore

    def forward_decoder(self, x_encoded, ids_restore):
        B, L_vis, D = x_encoded.shape
        x = self.decoder_embed(x_encoded) # [B,L_vis,D_dec]

        L = self.num_patches
        L_mask = L - L_vis
        mask_tokens = self.mask_token.repeat(B, L_mask, 1)

        x_ = torch.cat([x, mask_tokens], dim=1) # [B, L, D_dec]
        index = ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2])
        x_full = torch.gather(x_, dim=1, index=index) # [B, L, D_dec]

        x_full = x_full + self.decoder_pos_embed
        x_full = self.decoder(x_full)
        pred = self.decoder_pred(x_full) #[B,L,patch_dim]
        return pred

    def loss(self, imgs, pred, mask):
        target = self.patchify(imgs)

        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1e-6).sqrt()

        loss_per_patch = (pred - target) ** 2
        loss_per_patch = loss_per_patch.mean(dim=-1)
        loss = (loss_per_patch * mask).sum() / mask.sum()
        return loss

    def forward(self, imgs):
        x_encoded, mask, ids_restore = self.forward_encoder(imgs)
        pred = self.forward_decoder(x_encoded, ids_restore)
        loss = self.loss(imgs, pred, mask)
        return loss, pred, mask

In [9]:
encoder = MAEViT(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_chans=3,
).encoder



In [None]:
from tqdm.auto import tqdm

def pretrain_mae(mae_model, loader, epochs=20, lr=1e-4, weight_decay=0.05):
    optimizer = optim.AdamW(mae_model.parameters(), lr=lr, weight_decay=weight_decay)
    mae_model.train()

    print(f"Starting MAE Pretraining for {epochs} epochs...")
    print("Metrics: Recon Loss (lower better), Masked MSE (lower better)")

    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        total_masked_mse = 0.0
        n_batches = 0
        samples = 0

        pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=True)

        for imgs, _, _ in pbar:
            imgs = imgs.to(DEVICE)

            optimizer.zero_grad()
            loss, pred_patches, mask = mae_model(imgs)  # pred_patches: [B,N,D], mask: [B,N]
            loss.backward()
            optimizer.step()

            bs = imgs.size(0)
            samples += bs
            total_loss += loss.item() * bs

            with torch.no_grad():
                mask = mask.bool()  # [B,N]
                per_patch_mse = (pred_patches ** 2).mean(dim=-1)  # [B,N]
                masked_mse = per_patch_mse[mask].mean().item()

            total_masked_mse += masked_mse
            n_batches += 1

            pbar.set_postfix({
                "Recon Loss": f"{(total_loss / samples):.4f}",
                "Masked MSE": f"{(total_masked_mse / n_batches):.4f}"
            })

        print(f"[MAE] Epoch {epoch}/{epochs} done | Avg Loss: {total_loss / samples:.4f} | Avg Masked MSE: {total_masked_mse / n_batches:.4f}")

    return mae_model

mae = MAEViT(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_chans=3,
    embed_dim=192,
    depth=6,
    num_heads=3,
    decoder_embed_dim=128,
    decoder_depth=4,
    decoder_num_heads=4,
    mask_ratio=0.75,   # 75% masked: "nuclear option"
).to(DEVICE)

mae = pretrain_mae(mae, train_loader, epochs=20, lr=1e-4, weight_decay=0.05)

os.makedirs("checkpoints", exist_ok=True)

torch.save(
    mae.encoder.state_dict(),
    "checkpoints/mae_encoder_pretrained.pt"
)
print("MAE encoder saved correctly")

Starting MAE Pretraining for 20 epochs...
Metrics: Recon Loss (lower better), Masked MSE (lower better)


Epoch 1/20:   0%|          | 0/99 [00:00<?, ?it/s]

In [None]:
# classifier + attribute head reusing MAE encoder
LR = 3e-4
class MAEViTClassifierWithAttributes(nn.Module):
    """
    Uses pretrained MAE encoder as backbone.
    """
    def __init__(self, patch_embed, pos_embed, encoder, num_classes, num_attr, drop=0.1):
        super().__init__()

        self.patch_embed = patch_embed
        self.pos_embed   = pos_embed
        self.encoder     = encoder

        self.embed_dim = pos_embed.shape[-1]

        self.norm = nn.LayerNorm(self.embed_dim)
        self.dropout = nn.Dropout(drop)

        self.head_class = nn.Linear(self.embed_dim, num_classes)
        self.head_attr  = nn.Linear(self.embed_dim, num_attr)

    def forward(self, x):
        x = self.patch_embed(x) # [B, L, D]
        x = x + self.pos_embed
        x = self.encoder(x) # [B, L, D]
        x = self.norm(x)
        x = x.mean(dim=1)
        x = self.dropout(x)

        return self.head_class(x), self.head_attr(x)

# Reload MAE weights into fresh encoder
mae_for_ft = MAEViT(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_chans=3,
    embed_dim=192,
    depth=6,
    num_heads=3,
    decoder_embed_dim=128,
    decoder_depth=4,
    decoder_num_heads=4,
    mask_ratio=0.75, 
).to(DEVICE)

mae_for_ft.encoder.load_state_dict(
    torch.load(
        "checkpoints/mae_encoder_pretrained.pt",
        map_location=DEVICE
    )
)

print("MAE encoder weights loaded (decoder ignored)")

for p in mae_for_ft.patch_embed.parameters():
    p.requires_grad = False
for p in mae_for_ft.encoder.parameters():
    p.requires_grad = False

mae_for_ft.pos_embed.requires_grad = False
model = MAEViTClassifierWithAttributes(
    patch_embed=mae_for_ft.patch_embed,
    pos_embed=mae_for_ft.pos_embed,
    encoder=mae_for_ft.encoder,
    num_classes=NUM_CLASSES,
    num_attr=NUM_ATTR,
    drop=0.1,
).to(DEVICE)

optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR
)

print("MAE classifier ready. Trainable params:",
      sum(p.numel() for p in model.parameters() if p.requires_grad))

In [None]:
w = next(iter(mae_for_ft.encoder.parameters()))
print("Mean abs encoder weight:", torch.mean(torch.abs(w)).item())

encoder.load_state_dict(torch.load("checkpoints/mae_encoder_pretrained.pt"))
encoder.to(DEVICE)
print("Pretrained MAE encoder loaded")

In [None]:
def train_one_epoch(epoch_idx):
    model.train()
    total_loss = 0.0
    total_cls  = 0.0
    total_attr = 0.0
    correct = 0
    samples = 0

    for batch_idx, (imgs, attr_targets, labels) in enumerate(train_loader):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        attr_targets = attr_targets.to(DEVICE)

        optimizer.zero_grad()

        logits, attr_pred = model(imgs)

        loss_cls = criterion_class(logits, labels)
        loss_attr = criterion_attr(attr_pred, attr_targets)
        loss = loss_cls + LAMBDA_ATTR * loss_attr

        loss.backward()
        optimizer.step()

        bs = imgs.size(0)
        total_loss += loss.item() * bs
        total_cls  += loss_cls.item() * bs
        total_attr += loss_attr.item() * bs

        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        samples += bs

        if batch_idx % 20 == 0:
            batch_acc = (preds == labels).float().mean().item()
            print(f"[Epoch {epoch_idx}] Batch {batch_idx}/{len(train_loader)} "
                  f"- Loss: {loss.item():.4f} | Batch Acc: {batch_acc:.4f}")

    avg_loss = total_loss / samples
    avg_cls  = total_cls  / samples
    avg_attr = total_attr / samples
    acc = correct / samples
    print(f"[Epoch {epoch_idx}] Train Loss: {avg_loss:.4f} "
          f"(cls: {avg_cls:.4f}, attr: {avg_attr:.4f}) | "
          f"Train Acc: {acc:.4f}")
    return avg_loss, avg_cls, avg_attr, acc

def evaluate():
    model.eval()
    total_loss = 0.0
    total_cls  = 0.0
    total_attr = 0.0
    correct = 0
    samples = 0

    with torch.no_grad():
        for imgs, attr_targets, labels in val_loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            attr_targets = attr_targets.to(DEVICE)

            logits, attr_pred = model(imgs)

            loss_cls = criterion_class(logits, labels)
            loss_attr = criterion_attr(attr_pred, attr_targets)
            loss = loss_cls + LAMBDA_ATTR * loss_attr

            bs = imgs.size(0)
            total_loss += loss.item() * bs
            total_cls  += loss_cls.item() * bs
            total_attr += loss_attr.item() * bs

            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            samples += bs

    avg_loss = total_loss / samples
    avg_cls  = total_cls  / samples
    avg_attr = total_attr / samples
    acc = correct / samples
    print(f"[Validation] Loss: {avg_loss:.4f} "
          f"(cls: {avg_cls:.4f}, attr: {avg_attr:.4f}) | "
          f"Val Acc: {acc:.4f}")
    return avg_loss, avg_cls, avg_attr, acc


In [None]:
# Train MAE-backed classifier + checkpoints

import os
os.makedirs("checkpoints", exist_ok=True)

best_val_acc = 0.0
CHECKPOINT_EVERY = 5   # save every epoch 
for epoch in range(1, EPOCHS + 1):
    print(f"\n[CLS + MAE] Epoch {epoch}/{EPOCHS}")

    # Train + Validate
    train_loss, train_cls, train_attr, train_acc = train_one_epoch(epoch)
    val_loss, val_cls, val_attr, val_acc = evaluate()

    print(f"→ Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "mae_nuclear_classifier.pt")
        print("Saved new best MAE classifier")

    # Save checkpoint (resume-safe)
    if epoch % CHECKPOINT_EVERY == 0:
        checkpoint = {
            "epoch": epoch,
            "model_state": {k: v.cpu() for k, v in model.state_dict().items()},
            "optimizer_state": optimizer.state_dict(),
            "best_val_acc": best_val_acc,
        }

        torch.save(
            checkpoint,
            f"checkpoints/checkpoint_epoch_{epoch}.pt"
)

        print(f"Checkpoint saved at epoch {epoch}")

print("Best validation accuracy (MAE-backed):", best_val_acc)

In [None]:
criterion_cls = nn.CrossEntropyLoss()
criterion_attr = nn.MSELoss()

# Rebuild MAE skeleton
mae_eval = MAEViT(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_chans=3,
    embed_dim=192,
    depth=6,
    num_heads=3,
    decoder_embed_dim=128,
    decoder_depth=4,
    decoder_num_heads=4,
    mask_ratio=0.75,
)

# encoder weights ON CPU
mae_eval.encoder.load_state_dict(
    torch.load(
        "checkpoints/mae_encoder_pretrained.pt",
        map_location="cpu"
    )
)

# MAE backbone to GPU AFTER loading
mae_eval = mae_eval.to(DEVICE)

print("MAE encoder loaded for evaluation")

best_mae = MAEViTClassifierWithAttributes(
    patch_embed=mae_eval.patch_embed,
    pos_embed=mae_eval.pos_embed,
    encoder=mae_eval.encoder,
    num_classes=NUM_CLASSES,
    num_attr=NUM_ATTR,
    drop=0.1,
)

# classifier weights ON CPU
state_dict = torch.load(
    "mae_nuclear_classifier.pt",
    map_location="cpu"
)

best_mae.load_state_dict(state_dict)

# moving the model to GPU AFTER loading
best_mae = best_mae.to(DEVICE)
best_mae.eval()

print("Best MAE classifier loaded")

# Final validation loop
val_correct = 0
val_samples = 0
val_total_loss = 0.0

with torch.no_grad():
    for imgs, attr_targets, labels in val_loader:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        attr_targets = attr_targets.to(DEVICE)

        logits, attr_pred = best_mae(imgs)

        loss_cls = criterion_cls(logits, labels)
        loss_attr = criterion_attr(attr_pred, attr_targets)
        loss = loss_cls + LAMBDA_ATTR * loss_attr

        preds = logits.argmax(dim=1)
        val_correct += (preds == labels).sum().item()
        val_samples += imgs.size(0)
        val_total_loss += loss.item() * imgs.size(0)

val_acc = val_correct / val_samples
val_loss_avg = val_total_loss / val_samples

print("\n Final Validation Evaluation (MAE-pretrained ViT) ")
print(f"Validation Accuracy: {val_acc:.4f}")
print(f"Validation Loss:     {val_loss_avg:.4f}")
print(f"Correct Predictions: {val_correct}/{val_samples}")

In [None]:
from tqdm import tqdm

best_mae.eval()

all_ids = []
all_preds = []

with torch.no_grad():
    for imgs, img_ids in tqdm(test_loader, desc="Running test inference"):
        imgs = imgs.to(DEVICE)

        logits, _ = best_mae(imgs)   # attr_pred not needed
        preds = logits.argmax(dim=1).cpu().numpy()

        all_ids.extend(img_ids.tolist())
        all_preds.extend(preds.tolist())

print(f"\nTotal test samples processed: {len(all_preds)}")

# prediction distribution
all_preds_np = np.array(all_preds)
unique, counts = np.unique(all_preds_np, return_counts=True)

print("\nPredicted class distribution (0-based):")
for u, c in zip(unique, counts):
    print(f"  class {u}: {c} samples")

# submission (1-based labels if required)
kaggle_labels = [p + 1 for p in all_preds]

submission = pd.DataFrame({
    "id": all_ids,
    "label": kaggle_labels,
}).sort_values("id")

print("\nFirst few predictions:")
print(submission.head())

submission.to_csv("mae_nuclear_submission.csv", index=False)
print("\n Saved mae_nuclear_submission.csv")

Running test inference: 100%|██████████| 125/125 [01:55<00:00,  1.08it/s]


Total test samples processed: 4000

Predicted class distribution (0-based):
  class 0: 462 samples
  class 1: 3 samples
  class 2: 5 samples
  class 3: 67 samples
  class 6: 155 samples
  class 7: 33 samples
  class 9: 165 samples
  class 12: 499 samples
  class 13: 344 samples
  class 15: 75 samples
  class 16: 606 samples
  class 17: 3 samples
  class 19: 70 samples
  class 20: 12 samples
  class 21: 1 samples
  class 25: 71 samples
  class 26: 9 samples
  class 29: 3 samples
  class 30: 357 samples
  class 33: 299 samples
  class 36: 9 samples
  class 39: 52 samples
  class 43: 9 samples
  class 44: 45 samples
  class 49: 39 samples
  class 57: 1 samples
  class 63: 1 samples
  class 68: 1 samples
  class 70: 261 samples
  class 71: 228 samples
  class 84: 34 samples
  class 87: 2 samples
  class 130: 79 samples

First few predictions:
   id  label
0   1     34
1   2     17
2   3      1
3   4     10
4   5     17

✓ Saved mae_nuclear_submission.csv



