In [None]:
import os
from datetime import datetime
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

from torchvision import transforms
from torchvision.transforms import InterpolationMode

from torchvision.models import vit_b_16
# The pretrained ViT-B/16 backbone is instantiated in the "Training Setup" cell, which
# assigns `vit_backbone` before it is ever used; building it here too only created an
# extra, unused model instance.

from sklearn.model_selection import train_test_split

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the dataset paths under /content/drive/MyDrive resolve.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Dataset for the two-level task: every image has one superclass label AND one subclass
# label (two single-label multi-class targets, not a multi-label target).
class MultiClassImageDataset(Dataset):
    """Labelled image dataset backed by an annotation DataFrame.

    Args:
        ann_df: annotation table with columns ``image`` (file name relative to
            ``img_dir``), ``superclass_index`` and ``subclass_index``. Rows are read by
            position (``.iloc``), so the frame does not need a 0..N-1 RangeIndex
            (e.g. a split taken with ``.iloc`` but not ``reset_index``-ed still works).
        super_map_df: table whose ``class`` column is looked up by index label with the
            superclass index.
        sub_map_df: same as ``super_map_df`` for subclasses.
        img_dir: directory containing the image files.
        transform: optional callable applied to the RGB PIL image.

    Each item is ``(image, super_idx, super_label, sub_idx, sub_label)``.
    """

    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.ann_df)

    def __getitem__(self, idx):
        # Positional access: `series[idx]` would be a label lookup on the frame's index.
        img_name = self.ann_df['image'].iloc[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # The context manager releases the file handle; convert() returns an independent image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        super_idx = self.ann_df['superclass_index'].iloc[idx]
        super_label = self.super_map_df['class'][super_idx]

        sub_idx = self.ann_df['subclass_index'].iloc[idx]
        sub_label = self.sub_map_df['class'][sub_idx]

        if self.transform:
            image = self.transform(image)

        return image, super_idx, super_label, sub_idx, sub_label

class MultiClassImageTestDataset(Dataset):
    """Unlabelled test images named ``0.jpg``, ``1.jpg``, ... inside ``img_dir``.

    Args:
        super_map_df: stored but not used; accepted so the constructor mirrors the
            labelled dataset.
        sub_map_df: stored but not used (see ``super_map_df``).
        img_dir: directory containing the test images.
        transform: optional callable applied to the RGB PIL image.

    Each item is ``(image, img_name)`` where ``img_name == f"{idx}.jpg"``.
    """

    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform
        # Count the .jpg files once. __getitem__ can only address "<idx>.jpg", so any
        # other file in the directory must not contribute to the length; listing the
        # directory once also avoids repeated (slow) Drive listings on every len() call.
        self._num_images = sum(1 for fname in os.listdir(img_dir) if fname.endswith('.jpg'))

    def __len__(self):
        return self._num_images

    def __getitem__(self, idx):
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        # The context manager releases the file handle; convert() returns an independent image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name

In [None]:
# ── Data: paths, transforms, stratified train/val split, datasets and loaders ──
DATA_ROOT = '/content/drive/MyDrive/Released_Data_NNDL_2025'
SEED = 42  # also used as the train/val split seed below

# Seed the Python, NumPy and torch RNGs (augmentation, shuffling, and any later torch RNG
# use such as head initialisation) so Restart & Run All is reproducible. GPU kernels may
# still be nondeterministic.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_ann_df = pd.read_csv(os.path.join(DATA_ROOT, 'train_data_novel.csv'))
super_map_df = pd.read_csv(os.path.join(DATA_ROOT, 'superclass_mapping.csv'))
sub_map_df = pd.read_csv(os.path.join(DATA_ROOT, 'subclass_mapping.csv'))

train_img_dir = os.path.join(DATA_ROOT, 'train_images_with_novel')
test_img_dir = os.path.join(DATA_ROOT, 'test_images')

# torchvision's ViT_B_16_Weights.IMAGENET1K_V1 were trained on inputs normalised with the
# ImageNet channel statistics; feeding the pretrained backbone (x - 0.5) / 0.5 instead
# shifts its input distribution, which matters most while the backbone is frozen.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0), interpolation=InterpolationMode.LANCZOS),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Resize((224, 224)) already yields exactly 224x224, so no CenterCrop is needed.
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=InterpolationMode.LANCZOS),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# 90/10 split stratified on the superclass, so superclass proportions (including the
# novel superclass) match between train and val.
train_idx, val_idx = train_test_split(
    np.arange(len(train_ann_df)),
    test_size=0.1,
    random_state=SEED,
    stratify=train_ann_df["superclass_index"].values,
)

train_df = train_ann_df.iloc[train_idx].reset_index(drop=True)
val_df = train_ann_df.iloc[val_idx].reset_index(drop=True)

train_dataset = MultiClassImageDataset(train_df, super_map_df, sub_map_df, train_img_dir, transform=train_transform)
val_dataset = MultiClassImageDataset(val_df, super_map_df, sub_map_df, train_img_dir, transform=val_test_transform)
test_dataset = MultiClassImageTestDataset(super_map_df, sub_map_df, test_img_dir, transform=val_test_transform)

batch_size = 64
loader_kwargs = dict(num_workers=4, pin_memory=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)
# Fixed order: the validation loss is a mean of per-batch means, so shuffling made it
# vary from epoch to epoch even for identical weights (the last batch is smaller).
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, **loader_kwargs)

In [None]:
class ViTMultiLabel(nn.Module):
    """ViT feature extractor followed by a bottleneck MLP and two parallel linear heads.

    The backbone's ``heads`` module is replaced with ``nn.Identity`` (this mutates the
    module that is passed in), so ``self.vit(x)`` returns ``hidden_dim``-wide features.
    These go through ``adapter`` and then into the superclass and subclass heads.
    Despite the class name, each head is a single-label multi-class classifier.

    Args:
        vit_backbone: torchvision-style ViT exposing ``heads`` and ``hidden_dim``.
        num_super_classes: output width of the superclass head.
        num_sub_classes: output width of the subclass head.
        bottleneck_dim: inner width of the adapter MLP.
        dropout_p: dropout probability after each adapter activation.
    """

    def __init__(self, vit_backbone, num_super_classes=4, num_sub_classes=88, bottleneck_dim=512,
                 dropout_p=0.3):
        super().__init__()
        # Submodules are created in the order vit -> super_fc -> sub_fc -> adapter: this
        # fixes the RNG draw order for weight init and the parameter/state_dict order.
        vit_backbone.heads = nn.Identity()  # drop the ImageNet classifier, keep features
        self.vit = vit_backbone

        width = vit_backbone.hidden_dim
        self.hidden_dim = width
        self.super_fc = nn.Linear(width, num_super_classes)
        self.sub_fc = nn.Linear(width, num_sub_classes)

        adapter_layers = [
            nn.Linear(width, bottleneck_dim), nn.GELU(), nn.Dropout(dropout_p),
            nn.Linear(bottleneck_dim, width), nn.GELU(), nn.Dropout(dropout_p),
        ]
        self.adapter = nn.Sequential(*adapter_layers)

    def forward(self, x):
        """Return ``(super_logits, sub_logits)`` for the image batch ``x``."""
        features = self.adapter(self.vit(x))
        return self.super_fc(features), self.sub_fc(features)

In [None]:
# Trainer: owns the train / validate / test loops for the two-headed model.
class Trainer():
    """Train, validate and run inference for a model with superclass and subclass heads.

    ``model(x)`` must return ``(super_logits, sub_logits)``. Batches from
    ``train_loader`` / ``val_loader`` are indexed as
    ``(images, super_idx, super_label, sub_idx, sub_label)``; batches from
    ``test_loader`` as ``(images, img_names)``.

    Args:
        model: network to train; moved to ``device`` on construction.
        criterion: loss applied to each head's logits (e.g. ``nn.CrossEntropyLoss()``);
            used for both training and validation.
        optimizer: optimizer over the trainable parameters. It may be swapped later by
            assigning ``trainer.optimizer``; the training loop always reads the attribute.
        train_loader: iterable of labelled training batches.
        val_loader: iterable of labelled validation batches (must support ``len``).
        test_loader: optional iterable of unlabelled ``(images, names)`` batches.
        device: device the model and every batch are moved to.
        novel_super_idx: superclass label that marks a novel sample.
        novel_sub_idx: subclass label that marks a novel sample.
        tv_weight: weight of the uniform-output penalty applied to novel samples.
    """

    def __init__(self, model, criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda',
                 novel_super_idx=3, novel_sub_idx=87, tv_weight=0.05):
        self.device = device  # stored so the loops never depend on a global `device`
        self.model = model.to(device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.novel_super_idx = novel_super_idx
        self.novel_sub_idx = novel_sub_idx
        self.tv_weight = tv_weight

    @staticmethod
    def _uniform_l1(logits, mask):
        """Mean L1 distance between softmax(logits[mask]) and the uniform distribution.

        The L1 distance is twice the total-variation distance. Returns a zero scalar
        when ``mask`` selects no rows.
        """
        if not mask.any():
            return logits.new_zeros(())
        probs = F.softmax(logits[mask], dim=1)
        return (probs - 1.0 / probs.size(1)).abs().sum(dim=1).mean()

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Per-batch loss = CE(super) + CE(sub) + tv_weight * (U(super) + U(sub)), where
        U is the mean L1 distance to the uniform distribution over samples labelled with
        the novel class (outlier-exposure-style regulariser).

        Returns:
            float: mean training loss per batch (0.0 if the loader yields nothing).
        """
        self.model.train()  # dropout on; validate_epoch/test switch it off
        running_loss = 0.0
        num_batches = 0
        for data in self.train_loader:
            inputs = data[0].to(self.device)
            super_labels = data[1].to(self.device)
            sub_labels = data[3].to(self.device)

            self.optimizer.zero_grad()
            super_out, sub_out = self.model(inputs)

            ce_super = self.criterion(super_out, super_labels)
            ce_sub = self.criterion(sub_out, sub_labels)

            tv_super = self._uniform_l1(super_out, super_labels == self.novel_super_idx)
            tv_sub = self._uniform_l1(sub_out, sub_labels == self.novel_sub_idx)

            loss = ce_super + ce_sub + self.tv_weight * (tv_super + tv_sub)
            loss.backward()
            self.optimizer.step()  # the trainer's optimizer, not a module-level global

            running_loss += loss.item()
            num_batches += 1

        avg_loss = running_loss / num_batches if num_batches else 0.0
        print(f'Training loss: {avg_loss:.3f}')
        return avg_loss

    def validate_epoch(self):
        """Evaluate on ``val_loader`` and print losses and accuracies.

        Accuracies are reported overall and split into samples whose true label is the
        novel class vs. all other ("seen") samples; an empty group reports 0.

        Returns:
            float: mean per-batch (CE_super + CE_sub); NaN if the loader is empty, so an
            empty validation set never registers as an improvement.
        """
        self.model.eval()  # dropout off: validation metrics are deterministic
        num_batches = 0
        running_loss = ce_super_total = ce_sub_total = 0.0
        total = super_correct_all = sub_correct_all = 0
        super_correct_seen = super_correct_novel = seen_super_total = novel_super_total = 0
        sub_correct_seen = sub_correct_novel = seen_sub_total = novel_sub_total = 0

        with torch.no_grad():
            for inputs, super_labels, _, sub_labels, _ in self.val_loader:
                inputs = inputs.to(self.device)
                super_labels = super_labels.to(self.device)
                sub_labels = sub_labels.to(self.device)

                super_out, sub_out = self.model(inputs)

                ce_super = self.criterion(super_out, super_labels)
                ce_sub = self.criterion(sub_out, sub_labels)
                running_loss += (ce_super + ce_sub).item()
                ce_super_total += ce_super.item()
                ce_sub_total += ce_sub.item()
                num_batches += 1

                # Vectorised hit / novelty masks instead of a per-sample Python loop.
                super_hit = super_out.argmax(dim=1) == super_labels
                sub_hit = sub_out.argmax(dim=1) == sub_labels
                super_novel = super_labels == self.novel_super_idx
                sub_novel = sub_labels == self.novel_sub_idx

                total += super_labels.size(0)
                super_correct_all += super_hit.sum().item()
                sub_correct_all += sub_hit.sum().item()

                novel_super_total += super_novel.sum().item()
                seen_super_total += (~super_novel).sum().item()
                super_correct_novel += (super_hit & super_novel).sum().item()
                super_correct_seen += (super_hit & ~super_novel).sum().item()

                novel_sub_total += sub_novel.sum().item()
                seen_sub_total += (~sub_novel).sum().item()
                sub_correct_novel += (sub_hit & sub_novel).sum().item()
                sub_correct_seen += (sub_hit & ~sub_novel).sum().item()

        def pct(num, den):
            # Percentage, or 0 for an empty group (avoids ZeroDivisionError).
            return 100 * num / den if den > 0 else 0

        def per_batch(value):
            return value / num_batches if num_batches else float('nan')

        overall_cross_entropy = per_batch(running_loss)
        print(f'Cross-Entropy: Superclass={per_batch(ce_super_total):.4f} | Subclass={per_batch(ce_sub_total):.4f}')
        print(f'Overall Cross-Entropy Loss: {overall_cross_entropy:.4f}')
        print(f'Superclass Acc: Overall={pct(super_correct_all, total):.2f}% | Seen={pct(super_correct_seen, seen_super_total):.2f}% | Novel={pct(super_correct_novel, novel_super_total):.2f}%')
        print(f'Subclass  Acc: Overall={pct(sub_correct_all, total):.2f}% | Seen={pct(sub_correct_seen, seen_sub_total):.2f}% | Novel={pct(sub_correct_novel, novel_sub_total):.2f}%')

        return overall_cross_entropy

    def test(self, save_to_csv=False, return_predictions=False):
        """Predict superclass and subclass indices for every image in ``test_loader``.

        No special handling is applied for novel/unseen classes: each head's argmax is used.

        Args:
            save_to_csv: also write the predictions to ``test_predictions.csv`` in the CWD.
            return_predictions: return the predictions DataFrame.

        Returns:
            DataFrame with columns ``image``, ``superclass_index``, ``subclass_index``
            when ``return_predictions`` is true, otherwise ``None``.

        Raises:
            NotImplementedError: if the trainer was built without a ``test_loader``.
        """
        if self.test_loader is None:
            raise NotImplementedError('test_loader not specified')

        self.model.eval()  # dropout off: predictions are deterministic
        images, super_preds, sub_preds = [], [], []
        with torch.no_grad():
            for data in self.test_loader:
                inputs, img_names = data[0].to(self.device), data[1]
                super_out, sub_out = self.model(inputs)
                images.extend(img_names[j] for j in range(inputs.size(0)))
                super_preds.extend(super_out.argmax(dim=1).tolist())
                sub_preds.extend(sub_out.argmax(dim=1).tolist())

        test_predictions = pd.DataFrame({'image': images,
                                         'superclass_index': super_preds,
                                         'subclass_index': sub_preds})

        if save_to_csv:
            test_predictions.to_csv('test_predictions.csv', index=False)

        if return_predictions:
            return test_predictions

In [None]:
# Training setup: choose the device, build a fresh ImageNet-pretrained ViT-B/16 wrapped
# with the adapter and two heads, and define the per-head loss.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vit_backbone = vit_b_16(weights='IMAGENET1K_V1')
model = ViTMultiLabel(vit_backbone).to(device)
criterion = nn.CrossEntropyLoss()

for caption, where in (("Current device:", device),
                       ("Model device:", next(model.parameters()).device)):
    print(caption, where)

Current device: cuda
Model device: cuda:0


In [None]:
# The following two-phase training strategy was designed with the assistance of OpenAI's ChatGPT.
# ChatGPT was used to help clarify the fine-tuning protocol for Vision Transformers,
# especially for staged training with adapter tuning and partial unfreezing.
# Final implementation and experimental validation were conducted independently.

In [None]:
global_epoch = 0

# Schedule hyperparameters (previously inline magic numbers).
PHASE1_EPOCHS = 5          # adapter + heads only
PHASE2_MAX_EPOCHS = 30     # upper bound; early stopping usually ends sooner
N_UNFROZEN_BLOCKS = 2      # trailing encoder blocks fine-tuned in phase 2
HEAD_LR, BACKBONE_LR, WEIGHT_DECAY = 1e-3, 1e-5, 1e-4
PATIENCE, MIN_DELTA = 5, 1e-4

# ─── Phase 1：Freeze backbone ─────────────────────────────────
for p in model.vit.parameters():
    p.requires_grad = False
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=HEAD_LR, weight_decay=WEIGHT_DECAY
)

trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, test_loader, device=device)

print("=== Phase 1: train Adapter+FC ===")
for _ in range(PHASE1_EPOCHS):
    global_epoch += 1
    print(f"\n— Epoch {global_epoch}:")
    trainer.train_epoch()
    trainer.validate_epoch()

# ─── Phase 2：Unfreeze last blocks with Early Stopping ────────
backbone_params = list(model.vit.encoder.layers[-N_UNFROZEN_BLOCKS:].parameters())
for p in backbone_params:
    p.requires_grad = True
backbone_ids = {id(p) for p in backbone_params}

# Everything else that is trainable (adapter + heads) keeps the higher learning rate.
head_params = [p for p in model.parameters()
               if p.requires_grad and id(p) not in backbone_ids]

optimizer = optim.AdamW([
    {"params": head_params,     "lr": HEAD_LR},
    {"params": backbone_params, "lr": BACKBONE_LR},
], weight_decay=WEIGHT_DECAY)
trainer.optimizer = optimizer

best_val_loss = float('inf')
patience, patience_cnt = PATIENCE, 0
best_epoch, best_model_path = 0, "best_model.pth"

print(f"\n=== Phase 2: fine‑tune last {N_UNFROZEN_BLOCKS} blocks (early stop) ===")
for _ in range(PHASE2_MAX_EPOCHS):
    global_epoch += 1
    print(f"\n— Epoch {global_epoch}:")
    trainer.train_epoch()
    vl = trainer.validate_epoch()

    if vl < best_val_loss - MIN_DELTA:
        # Record the global epoch so it matches the "— Epoch N" log lines and the
        # "Best Epoch" written next to the test predictions.
        best_val_loss, best_epoch = vl, global_epoch
        torch.save(model.state_dict(), best_model_path)
        patience_cnt = 0
        print(f"new best!  saved to {best_model_path}")
    else:
        patience_cnt += 1
        print(f"no improv. patience {patience_cnt}/{patience}")

    if patience_cnt >= patience:
        print("Early stopping")
        break

print(f"\nTraining done. Best epoch = {best_epoch}, val_loss = {best_val_loss:.4f}")

=== Phase 1: train Adapter+FC ===

— Epoch 1:
Training loss: 1.422
Cross-Entropy: Superclass=0.0101 | Subclass=0.3932
Overall Cross-Entropy Loss: 0.4033
Superclass Acc: Overall=99.77% | Seen=100.00% | Novel=96.97%
Subclass  Acc: Overall=86.31% | Seen=80.52% | Novel=100.00%

— Epoch 2:
Training loss: 0.459
Cross-Entropy: Superclass=0.0122 | Subclass=0.2840
Overall Cross-Entropy Loss: 0.2962
Superclass Acc: Overall=99.55% | Seen=99.76% | Novel=96.97%
Subclass  Acc: Overall=91.52% | Seen=87.92% | Novel=100.00%

— Epoch 3:
Training loss: 0.366
Cross-Entropy: Superclass=0.0082 | Subclass=0.2614
Overall Cross-Entropy Loss: 0.2696
Superclass Acc: Overall=99.66% | Seen=99.88% | Novel=96.97%
Subclass  Acc: Overall=91.29% | Seen=87.60% | Novel=100.00%

— Epoch 4:
Training loss: 0.325
Cross-Entropy: Superclass=0.0069 | Subclass=0.1943
Overall Cross-Entropy Loss: 0.2012
Superclass Acc: Overall=99.77% | Seen=99.88% | Novel=98.48%
Subclass  Acc: Overall=94.34% | Seen=91.95% | Novel=100.00%

— Epoch 

In [None]:
# Test and Save Prediction
# Reload the best phase-2 checkpoint (a plain state_dict). map_location puts the tensors
# on the current `device` even if the file was saved from a different one, and
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
test_predictions = trainer.test(save_to_csv=False, return_predictions=True)

drive_root = "/content/drive/MyDrive"
output_dir = os.path.join(drive_root, "Released_Data_NNDL_2025", "test_predictions")
os.makedirs(output_dir, exist_ok=True)

notebook_name = "NNDL_ViT_v11.ipynb"

# Timestamped file names so successive runs never overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
base_name = f'test_predictions_{timestamp}'
csv_filename = os.path.join(output_dir, base_name + ".csv")
meta_filename = os.path.join(output_dir, base_name + "_info.txt")

test_predictions.to_csv(csv_filename, index=False)

with open(meta_filename, 'w') as f:
    f.write(f"Best Epoch: {best_epoch}\n")
    f.write(f"Validation Loss: {best_val_loss:.4f}\n")
    f.write(f"Notebook Source: {notebook_name}\n")

print(f"Saved to: {csv_filename}")

Saved to: /content/drive/MyDrive/Released_Data_NNDL_2025/test_predictions/test_predictions_20250511_203808.csv
