# 1. Imports & Device

In [1]:
import os
from pathlib import Path
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.optim as optim
from transformers import ViTMAEModel, get_cosine_schedule_with_warmup
import pandas as pd
import torchvision.transforms as T

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime

print(f"PyTorch version: {torch.__version__}")
import transformers
print(f"Transformers version: {transformers.__version__}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 2.5.1+cu121
Transformers version: 4.50.3
Using device: cuda


# 2. Configuration

In [2]:
data_root = Path(r"D:\Swinburne\Degree3_S2\COS30082_AML\GroupAssignment\dataset")
list_dir  = data_root / "list"

class Config:
    data_root = data_root
    list_dir  = list_dir

    train_list_path = list_dir / "train.txt"
    test_list_path  = list_dir / "test.txt"
    gt_list_path    = list_dir / "groundtruth.txt"

    model_name = "facebook/vit-mae-large"
    num_classes = 100

    img_size   = 224
    batch_size = 16
    num_workers = 0
    pin_memory = False
    persistent_workers = False

    epochs        = 30
    lr            = 1e-3
    weight_decay  = 1e-4
    warmup_ratio  = 0.1
    warmup_epochs = 1

    seed = 42
    # out_dir = Path("./test")
    out_dir = Path("./runs_mae_freeze_large_NOAUG")
    out_dir.mkdir(parents=True, exist_ok=True)

# 3. Seed

In [3]:
def set_seed(seed: int = 42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(Config.seed)

# 4. Transforms

In [4]:
USE_AUG = False

# ===============================
# Define all transforms
# ===============================

IMAGE_SIZE = 224  # 如果你文件里是别的，请同步修改

# Herbarium heavy augmentation
train_herbarium_transform = T.Compose([
    T.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(45),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Photo light augmentation
train_photo_transform = T.Compose([
    T.RandomResizedCrop(IMAGE_SIZE),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation/Test: no augmentation
eval_transform = T.Compose([
    T.Resize(IMAGE_SIZE + 32),
    T.CenterCrop(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ===============================
# Unified transform selector
# ===============================
def pick_transform(rel_path: str, train: bool = True):
    """
    Automatically pick transform based on folder name.
    rel_path: dataset path (string), used to detect 'herbarium' or 'photo'.
    train: if False → always use eval_transform
    """
    low = rel_path.lower()

    # -----------------------
    # Validation/test mode: never augment
    # -----------------------
    if not train:
        return eval_transform

    # -----------------------
    # Training but augmentation disabled
    # -----------------------
    if not USE_AUG:
        return eval_transform

    # -----------------------
    # Training with augmentation enabled
    # -----------------------
    if "herbarium" in low:
        return train_herbarium_transform

    if "photo" in low:
        return train_photo_transform

    # Default → treat as Photo dataset
    return train_photo_transform

# 5. Loading train/test labels

In [5]:
def read_groundtruth(gt_path: Path):
    mapping = {}
    with open(gt_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 2:
                mapping[parts[0]] = int(parts[-1])
    return mapping


def read_train_list(list_path: Path):
    """Expect: train.txt lines are '<rel_path> <label>'"""
    samples = []
    with open(list_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 2:
                raise ValueError(f"Train line has no label: {line}")
            rel_path = parts[0]
            label = int(parts[1])
            samples.append((rel_path, label))
    return samples


def read_test_list(list_path: Path):
    """Expect: test.txt lines are either '<rel_path>' or '<rel_path> <label>'"""
    samples = []
    with open(list_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if not parts:
                continue
            rel_path = parts[0]
            label = int(parts[1]) if len(parts) >= 2 else None
            samples.append((rel_path, label))
    return samples


print("=== Path Check ===")
print(f"{Config.data_root} -> {Config.data_root.exists()}")
print(f"{Config.train_list_path} -> {Config.train_list_path.exists()}")
print(f"{Config.test_list_path} -> {Config.test_list_path.exists()}")
print(f"{Config.gt_list_path} -> {Config.gt_list_path.exists()}")

gt_mapping = read_groundtruth(Config.gt_list_path)
train_raw = read_train_list(Config.train_list_path)
test_raw  = read_test_list(Config.test_list_path)

print(f"Line counts -> train={len(train_raw)}, test={len(test_raw)}, groundtruth={len(gt_mapping)}")

# Build raw samples with original labels
train_samples_raw = []
val_samples_raw   = []
unmatched = 0

# Train: labels come directly from train.txt
for rel_path, label in train_raw:
    full_path = Config.data_root / rel_path
    train_samples_raw.append((full_path, label))

# Val/Test: prefer label in test.txt, otherwise use groundtruth.txt
for rel_path, label in test_raw:
    if label is None:
        label = gt_mapping.get(rel_path, None)
    if label is None:
        unmatched += 1
        continue
    full_path = Config.data_root / rel_path
    val_samples_raw.append((full_path, label))

all_labels_raw = [lbl for _, lbl in train_samples_raw] + [lbl for _, lbl in val_samples_raw]
unique_labels = sorted(set(all_labels_raw))

# Map original labels -> contiguous [0, num_classes-1]
label_to_idx = {lab: idx for idx, lab in enumerate(unique_labels)}
idx_to_label = {idx: lab for lab, idx in label_to_idx.items()}

# Apply mapping
train_samples = [(path, label_to_idx[lbl]) for path, lbl in train_samples_raw]
val_samples   = [(path, label_to_idx[lbl]) for path, lbl in val_samples_raw]

Config.num_classes = len(unique_labels)

print(f"[Info] Original label min={min(all_labels_raw)}, max={max(all_labels_raw)}")
print(f"[Info] Number of unique labels={Config.num_classes}")
print(f"[Info] Train samples: {len(train_samples)}, Val samples: {len(val_samples)}, Unmatched test: {unmatched}")
print(f"[Info] Labels have been remapped to range [0, {Config.num_classes - 1}]")

=== Path Check ===
D:\Swinburne\Degree3_S2\COS30082_AML\GroupAssignment\dataset -> True
D:\Swinburne\Degree3_S2\COS30082_AML\GroupAssignment\dataset\list\train.txt -> True
D:\Swinburne\Degree3_S2\COS30082_AML\GroupAssignment\dataset\list\test.txt -> True
D:\Swinburne\Degree3_S2\COS30082_AML\GroupAssignment\dataset\list\groundtruth.txt -> True
Line counts -> train=4744, test=207, groundtruth=207
[Info] Original label min=12254, max=285398
[Info] Number of unique labels=100
[Info] Train samples: 4744, Val samples: 207, Unmatched test: 0
[Info] Labels have been remapped to range [0, 99]


# 6. Dataset & DataLoaders

In [6]:
class PlantsDataset(Dataset):
    def __init__(self, samples, train: bool = True):
        """
        samples: list of (full_path: Path, label_idx: int)
        train : True 表示训练集（会根据 USE_AUG 决定要不要 augmentation）
                False 表示验证/测试集（永远不用 augmentation）
        """
        self.samples = samples
        self.train = train

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert("RGB")

        # 根据路径和 train / val 选择对应的 transform
        # pick_transform 在前面已经定义好了：
        #   def pick_transform(rel_path: str, train: bool = True): ...
        rel_path = str(img_path)
        transform = pick_transform(rel_path, train=self.train)

        img = transform(img)
        return img, label


# === Datasets ===
train_dataset = PlantsDataset(train_samples, train=True)
val_dataset   = PlantsDataset(val_samples,   train=False)

# === Dataloaders ===
train_loader = DataLoader(
    train_dataset,
    batch_size=Config.batch_size,
    shuffle=True,
    num_workers=Config.num_workers,
    pin_memory=Config.pin_memory,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=Config.batch_size,
    shuffle=False,
    num_workers=Config.num_workers,
    pin_memory=Config.pin_memory,
)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

Train batches: 297, Val batches: 13


# 7. MAE Classification Model (Frozen Backbone)

In [7]:
class PlantMAEClassifier(nn.Module):
    def __init__(self, model_name: str, num_classes: int):
        super().__init__()
        # 1. 加载 MAE backbone（这里会根据 Config.model_name 加载 base 或 large）
        self.backbone = ViTMAEModel.from_pretrained(model_name)

        # 2. 自动拿到 hidden size（Base=768, Large=1024），不用手写
        hidden_dim = self.backbone.config.hidden_size

        # 3. 定义我们自己的分类 head
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, num_classes),
        )

        # 4. 冻结 backbone，只训练 head
        for p in self.backbone.parameters():
            p.requires_grad = False

    def forward(self, pixel_values):
        # ViTMAEModel 输出 last_hidden_state: [B, num_patches+1, hidden_dim]
        outputs = self.backbone(pixel_values=pixel_values)
        cls_token = outputs.last_hidden_state[:, 0]   # 取 [CLS] token
        logits = self.head(cls_token)
        return logits

# 实例化模型
model = PlantMAEClassifier(
    model_name=Config.model_name,          # 这里建议已经设为 "facebook/vit-mae-large"
    num_classes=Config.num_classes,
).to(device)

# 打印一下参数情况，确认只有 head 在训练
backbone_trainable = sum(p.numel() for p in model.backbone.parameters() if p.requires_grad)
head_trainable     = sum(p.numel() for p in model.head.parameters() if p.requires_grad)

print(f"Backbone trainable params: {backbone_trainable}")
print(f"Head trainable params:     {head_trainable}")
print(f"Total params:              {sum(p.numel() for p in model.parameters())}")

Backbone trainable params: 0
Head trainable params:     1154148
Total params:              304455780


# 7. Loss / Optimizer / Scheduler / AMP

In [8]:
criterion = nn.CrossEntropyLoss()

# 只训练 head（backbone 已经 requires_grad=False）
head_params = model.head.parameters()

optimizer = optim.AdamW(
    head_params,
    lr=Config.lr,                    # 比如 1e-3
    weight_decay=Config.weight_decay # 比如 1e-4
)

# 训练总步数 & warmup 步数
num_train_steps = max(1, len(train_loader)) * Config.epochs
num_warmup_steps = int(Config.warmup_epochs * len(train_loader))

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_train_steps,
)

scaler = torch.amp.GradScaler("cuda")

print(f"Total training steps: {num_train_steps}")
print(f"Warmup steps:         {num_warmup_steps}")

Total training steps: 8910
Warmup steps:         297


# 8. Training Loop

In [9]:
# Create logs directory
logs_dir = Config.out_dir / "logs"
logs_dir.mkdir(parents=True, exist_ok=True)

# Create models directory (only for BEST model)
models_dir = Config.out_dir / "models"
models_dir.mkdir(parents=True, exist_ok=True)

# Fixed log file name (NO timestamp)
log_path = logs_dir / "training_log.txt"

def LogWrite(text):
    """Helper function to write logs to file + print on screen."""
    print(text)
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(text + "\n")


def train_one_epoch(epoch):
    model.train()
    running_loss = 0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast("cuda", enabled=device.type == "cuda"):
            logits = model(imgs)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        running_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        pbar.set_postfix(loss=running_loss/total, acc=correct/total*100)

    epoch_loss = running_loss / total
    epoch_acc = correct / total * 100
    return epoch_loss, epoch_acc


@torch.no_grad()
def validate(epoch):
    model.eval()
    running_loss = 0
    correct = 0
    total = 0

    pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]")
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)

        with torch.amp.autocast("cuda", enabled=device.type == "cuda"):
            logits = model(imgs)
            loss = criterion(logits, labels)

        running_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        pbar.set_postfix(loss=running_loss/total, acc=correct/total*100)

    epoch_loss = running_loss / total
    epoch_acc = correct / total * 100
    return epoch_loss, epoch_acc


# ========== Training Loop ==========
best_val_acc = 0
best_model_path = models_dir / "mae_frozen_best.pth"

LogWrite("==== Training Started ====")
LogWrite(f"Saving logs to: {log_path}")
LogWrite(f"Best model will be saved to: {best_model_path}\n")

for epoch in range(Config.epochs):
    train_loss, train_acc = train_one_epoch(epoch)
    val_loss, val_acc = validate(epoch)

    LogWrite(
        f"[Epoch {epoch+1}] "
        f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}% | "
        f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%"
    )

    # ===== Save ONLY best model =====
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        LogWrite(
            f"--> New Best Model Saved at {best_model_path} "
            f"(Val Acc: {val_acc:.2f}%)\n"
        )

LogWrite("==== Training Completed ====")
LogWrite(f"Best Validation Accuracy: {best_val_acc:.2f}%")
LogWrite(f"Best Model Path: {best_model_path}")

==== Training Started ====
Saving logs to: runs_mae_freeze_large_NOAUG\logs\training_log.txt
Best model will be saved to: runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth



Epoch 1 [Train]:   0%|          | 0/297 [00:00<?, ?it/s]

Epoch 1 [Train]: 100%|██████████| 297/297 [02:07<00:00,  2.34it/s, acc=7.76, loss=4.16]
Epoch 1 [Val]: 100%|██████████| 13/13 [00:05<00:00,  2.19it/s, acc=13.5, loss=4.3] 


[Epoch 1] Train Loss=4.1567, Train Acc=7.76% | Val Loss=4.3049, Val Acc=13.53%
--> New Best Model Saved at runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth (Val Acc: 13.53%)



Epoch 2 [Train]: 100%|██████████| 297/297 [01:20<00:00,  3.71it/s, acc=31.1, loss=2.78]
Epoch 2 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.35it/s, acc=29.5, loss=3.75]


[Epoch 2] Train Loss=2.7845, Train Acc=31.13% | Val Loss=3.7514, Val Acc=29.47%
--> New Best Model Saved at runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth (Val Acc: 29.47%)



Epoch 3 [Train]: 100%|██████████| 297/297 [01:20<00:00,  3.68it/s, acc=45.2, loss=2.05]
Epoch 3 [Val]: 100%|██████████| 13/13 [00:04<00:00,  3.24it/s, acc=35.3, loss=3.67]


[Epoch 3] Train Loss=2.0549, Train Acc=45.15% | Val Loss=3.6673, Val Acc=35.27%
--> New Best Model Saved at runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth (Val Acc: 35.27%)



Epoch 4 [Train]: 100%|██████████| 297/297 [01:18<00:00,  3.78it/s, acc=57.6, loss=1.57]
Epoch 4 [Val]: 100%|██████████| 13/13 [00:04<00:00,  3.23it/s, acc=39.1, loss=3.59]


[Epoch 4] Train Loss=1.5679, Train Acc=57.59% | Val Loss=3.5860, Val Acc=39.13%
--> New Best Model Saved at runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth (Val Acc: 39.13%)



Epoch 5 [Train]: 100%|██████████| 297/297 [01:17<00:00,  3.81it/s, acc=64.5, loss=1.26]
Epoch 5 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.27it/s, acc=41.5, loss=3.58]


[Epoch 5] Train Loss=1.2633, Train Acc=64.50% | Val Loss=3.5814, Val Acc=41.55%
--> New Best Model Saved at runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth (Val Acc: 41.55%)



Epoch 6 [Train]: 100%|██████████| 297/297 [01:19<00:00,  3.72it/s, acc=72.1, loss=1.02] 
Epoch 6 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.29it/s, acc=43.5, loss=3.72]


[Epoch 6] Train Loss=1.0176, Train Acc=72.07% | Val Loss=3.7166, Val Acc=43.48%
--> New Best Model Saved at runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth (Val Acc: 43.48%)



Epoch 7 [Train]: 100%|██████████| 297/297 [01:22<00:00,  3.58it/s, acc=76.6, loss=0.828]
Epoch 7 [Val]: 100%|██████████| 13/13 [00:04<00:00,  2.68it/s, acc=43.5, loss=3.69]


[Epoch 7] Train Loss=0.8278, Train Acc=76.62% | Val Loss=3.6929, Val Acc=43.48%


Epoch 8 [Train]: 100%|██████████| 297/297 [01:20<00:00,  3.67it/s, acc=80.5, loss=0.671]
Epoch 8 [Val]: 100%|██████████| 13/13 [00:04<00:00,  3.10it/s, acc=43.5, loss=3.91]


[Epoch 8] Train Loss=0.6707, Train Acc=80.50% | Val Loss=3.9072, Val Acc=43.48%


Epoch 9 [Train]: 100%|██████████| 297/297 [01:17<00:00,  3.83it/s, acc=83.6, loss=0.573]
Epoch 9 [Val]: 100%|██████████| 13/13 [00:04<00:00,  3.24it/s, acc=50.2, loss=3.93]


[Epoch 9] Train Loss=0.5732, Train Acc=83.64% | Val Loss=3.9321, Val Acc=50.24%
--> New Best Model Saved at runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth (Val Acc: 50.24%)



Epoch 10 [Train]: 100%|██████████| 297/297 [01:17<00:00,  3.84it/s, acc=86.1, loss=0.468]
Epoch 10 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.37it/s, acc=47.8, loss=4.21]


[Epoch 10] Train Loss=0.4678, Train Acc=86.07% | Val Loss=4.2146, Val Acc=47.83%


Epoch 11 [Train]: 100%|██████████| 297/297 [01:16<00:00,  3.91it/s, acc=88.4, loss=0.394]
Epoch 11 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.31it/s, acc=46.4, loss=4.53]


[Epoch 11] Train Loss=0.3942, Train Acc=88.39% | Val Loss=4.5269, Val Acc=46.38%


Epoch 12 [Train]: 100%|██████████| 297/297 [01:15<00:00,  3.92it/s, acc=89.9, loss=0.338]
Epoch 12 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.38it/s, acc=48.3, loss=4.33]


[Epoch 12] Train Loss=0.3379, Train Acc=89.92% | Val Loss=4.3311, Val Acc=48.31%


Epoch 13 [Train]: 100%|██████████| 297/297 [01:15<00:00,  3.93it/s, acc=91.8, loss=0.289]
Epoch 13 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.38it/s, acc=49.3, loss=4.33]


[Epoch 13] Train Loss=0.2885, Train Acc=91.84% | Val Loss=4.3326, Val Acc=49.28%


Epoch 14 [Train]: 100%|██████████| 297/297 [01:15<00:00,  3.92it/s, acc=92.8, loss=0.241]
Epoch 14 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.31it/s, acc=48.8, loss=4.58]


[Epoch 14] Train Loss=0.2408, Train Acc=92.83% | Val Loss=4.5833, Val Acc=48.79%


Epoch 15 [Train]: 100%|██████████| 297/297 [01:15<00:00,  3.92it/s, acc=94.2, loss=0.209]
Epoch 15 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.42it/s, acc=51.2, loss=4.32]


[Epoch 15] Train Loss=0.2095, Train Acc=94.20% | Val Loss=4.3169, Val Acc=51.21%
--> New Best Model Saved at runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth (Val Acc: 51.21%)



Epoch 16 [Train]: 100%|██████████| 297/297 [01:15<00:00,  3.94it/s, acc=94.9, loss=0.19] 
Epoch 16 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.29it/s, acc=47.3, loss=4.61]


[Epoch 16] Train Loss=0.1900, Train Acc=94.86% | Val Loss=4.6111, Val Acc=47.34%


Epoch 17 [Train]: 100%|██████████| 297/297 [01:16<00:00,  3.89it/s, acc=95.2, loss=0.165]
Epoch 17 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.41it/s, acc=49.8, loss=4.58]


[Epoch 17] Train Loss=0.1647, Train Acc=95.24% | Val Loss=4.5777, Val Acc=49.76%


Epoch 18 [Train]: 100%|██████████| 297/297 [01:16<00:00,  3.90it/s, acc=96.6, loss=0.137]
Epoch 18 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.46it/s, acc=53.6, loss=4.58]


[Epoch 18] Train Loss=0.1369, Train Acc=96.56% | Val Loss=4.5825, Val Acc=53.62%
--> New Best Model Saved at runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth (Val Acc: 53.62%)



Epoch 19 [Train]: 100%|██████████| 297/297 [01:16<00:00,  3.89it/s, acc=96.2, loss=0.131]
Epoch 19 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.43it/s, acc=50.7, loss=4.49]


[Epoch 19] Train Loss=0.1309, Train Acc=96.25% | Val Loss=4.4923, Val Acc=50.72%


Epoch 20 [Train]: 100%|██████████| 297/297 [01:15<00:00,  3.93it/s, acc=97, loss=0.118]  
Epoch 20 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.42it/s, acc=48.3, loss=4.73]


[Epoch 20] Train Loss=0.1179, Train Acc=97.05% | Val Loss=4.7324, Val Acc=48.31%


Epoch 21 [Train]: 100%|██████████| 297/297 [01:15<00:00,  3.92it/s, acc=97.6, loss=0.102] 
Epoch 21 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.47it/s, acc=51.2, loss=4.72]


[Epoch 21] Train Loss=0.1018, Train Acc=97.64% | Val Loss=4.7247, Val Acc=51.21%


Epoch 22 [Train]: 100%|██████████| 297/297 [01:14<00:00,  3.96it/s, acc=97.7, loss=0.0952]
Epoch 22 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.34it/s, acc=51.2, loss=4.73]


[Epoch 22] Train Loss=0.0952, Train Acc=97.66% | Val Loss=4.7286, Val Acc=51.21%


Epoch 23 [Train]: 100%|██████████| 297/297 [01:14<00:00,  3.97it/s, acc=98.1, loss=0.0836]
Epoch 23 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.49it/s, acc=50.7, loss=4.82]


[Epoch 23] Train Loss=0.0836, Train Acc=98.15% | Val Loss=4.8209, Val Acc=50.72%


Epoch 24 [Train]: 100%|██████████| 297/297 [01:18<00:00,  3.81it/s, acc=98.5, loss=0.0775]
Epoch 24 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.25it/s, acc=51.7, loss=4.83]


[Epoch 24] Train Loss=0.0775, Train Acc=98.48% | Val Loss=4.8297, Val Acc=51.69%


Epoch 25 [Train]: 100%|██████████| 297/297 [01:15<00:00,  3.93it/s, acc=98.1, loss=0.0819]
Epoch 25 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.44it/s, acc=51.2, loss=4.72]


[Epoch 25] Train Loss=0.0819, Train Acc=98.10% | Val Loss=4.7199, Val Acc=51.21%


Epoch 26 [Train]: 100%|██████████| 297/297 [01:19<00:00,  3.74it/s, acc=98.7, loss=0.0685]
Epoch 26 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.37it/s, acc=49.8, loss=4.79]


[Epoch 26] Train Loss=0.0685, Train Acc=98.67% | Val Loss=4.7905, Val Acc=49.76%


Epoch 27 [Train]: 100%|██████████| 297/297 [01:14<00:00,  3.97it/s, acc=98.6, loss=0.0692]
Epoch 27 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.47it/s, acc=52.7, loss=4.89]


[Epoch 27] Train Loss=0.0692, Train Acc=98.61% | Val Loss=4.8882, Val Acc=52.66%


Epoch 28 [Train]: 100%|██████████| 297/297 [01:14<00:00,  3.97it/s, acc=98.5, loss=0.0699]
Epoch 28 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.35it/s, acc=54.1, loss=4.72]


[Epoch 28] Train Loss=0.0699, Train Acc=98.55% | Val Loss=4.7245, Val Acc=54.11%
--> New Best Model Saved at runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth (Val Acc: 54.11%)



Epoch 29 [Train]: 100%|██████████| 297/297 [01:14<00:00,  3.99it/s, acc=98.9, loss=0.0604]
Epoch 29 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.46it/s, acc=52.7, loss=4.69]


[Epoch 29] Train Loss=0.0604, Train Acc=98.92% | Val Loss=4.6945, Val Acc=52.66%


Epoch 30 [Train]: 100%|██████████| 297/297 [01:14<00:00,  3.96it/s, acc=98.7, loss=0.0612]
Epoch 30 [Val]: 100%|██████████| 13/13 [00:03<00:00,  3.45it/s, acc=52.7, loss=4.7] 

[Epoch 30] Train Loss=0.0612, Train Acc=98.69% | Val Loss=4.6958, Val Acc=52.66%
==== Training Completed ====
Best Validation Accuracy: 54.11%
Best Model Path: runs_mae_freeze_large_NOAUG\models\mae_frozen_best.pth





# 9. Evaluation

In [10]:
# === Folder to save results ===
eval_dir = Config.out_dir / "evaluation"
eval_dir.mkdir(parents=True, exist_ok=True)
print(f"[INFO] Saving evaluation results to: {eval_dir}")


# === helper: top-k accuracy ===
def topk_acc_from_topk_array(y_true_group, topk_array):
    """
    y_true_group : [N]
    topk_array   : [N, K]，每行是该样本 top-K 预测的类别 index
    返回: 百分比 (0-100)
    """
    if len(y_true_group) == 0:
        return None
    correct = np.any(topk_array == y_true_group[:, None], axis=1)
    return correct.mean() * 100.0


# === Collect predictions (Top-1 + Top-5) ===
@torch.no_grad()
def collect_preds(loader, k=5):
    model.eval()
    all_top1, all_topk, all_labels = [], [], []
    for imgs, labels in loader:
        imgs = imgs.to(device, non_blocking=True)
        logits = model(imgs)

        top1 = torch.argmax(logits, dim=1)           # [B]
        topk = torch.topk(logits, k=k, dim=1).indices  # [B, k]

        all_top1.append(top1.cpu().numpy())
        all_topk.append(topk.cpu().numpy())
        all_labels.append(labels.numpy())

    y_pred_top1 = np.concatenate(all_top1)      # [N]
    y_pred_topk = np.concatenate(all_topk)      # [N, k]
    y_true      = np.concatenate(all_labels)    # [N]
    return y_pred_top1, y_pred_topk, y_true


y_pred, y_top5, y_true = collect_preds(val_loader, k=5)

# === Overall Top-1 & Top-5 ===
overall_top1 = accuracy_score(y_true, y_pred) * 100.0
overall_top5 = topk_acc_from_topk_array(y_true, y_top5)

print(f"\n[Overall] Top-1 Accuracy: {overall_top1:.2f}%")
print(f"[Overall] Top-5 Accuracy: {overall_top5:.2f}%\n")


# === 1. Save Classification Report ===
report_str = classification_report(
    y_true,
    y_pred,
    digits=4,
    target_names=[str(idx_to_label[i]) for i in range(len(idx_to_label))]
)

report_path = eval_dir / "classification_report.txt"
with open(report_path, "w", encoding="utf-8") as f:
    f.write(f"Overall Top-1 Accuracy: {overall_top1:.4f}%\n")
    f.write(f"Overall Top-5 Accuracy: {overall_top5:.4f}%\n")
    f.write("\n")
    f.write(report_str)

print(f"[Saved] Classification report → {report_path}")


# === 2. Save Confusion Matrix Plot (Top-1) ===
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(idx_to_label))))

plt.figure(figsize=(8, 7))
plt.imshow(cm, cmap="Blues", interpolation="nearest")
plt.title("Confusion Matrix (counts)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.colorbar()
plt.tight_layout()

cm_path = eval_dir / "confusion_matrix.png"
plt.savefig(cm_path, dpi=300)
plt.close()

print(f"[Saved] Confusion matrix image → {cm_path}")


# === 3. Save Per-Class Metrics as CSV (from Top-1) ===
report_dict = classification_report(
    y_true,
    y_pred,
    digits=4,
    target_names=[str(idx_to_label[i]) for i in range(len(idx_to_label))],
    output_dict=True
)

metrics_df = pd.DataFrame(report_dict).transpose()
metrics_path = eval_dir / "per_class_metrics.csv"
metrics_df.to_csv(metrics_path, encoding="utf-8")

print(f"[Saved] Per-class CSV → {metrics_path}")

print("\n[✓] All evaluation results (including Top-1 & Top-5) saved successfully.")

# ====== Step 9 (extra): WITH / WITHOUT PAIRS 分组分析并保存到 evaluation 文件夹 ======

# 0. 确保 evaluation 目录存在（和前面 Step 9 保存 report 的目录一致）
eval_dir = Config.out_dir / "evaluation"
eval_dir.mkdir(parents=True, exist_ok=True)

print(f"\n[INFO] Saving with/without-pairs results to: {eval_dir}")

# 1. 设置 with / without pairs 的 class list 文件路径
#    确保这两个文件在当前工作目录下，或者改成你的完整路径
WITH_PAIRS_FILE = list_dir / "class_with_pairs.txt"
WITHOUT_PAIRS_FILE = list_dir / "class_without_pairs.txt"

# 2. 读入原始 class ID（每行一个）
with open(WITH_PAIRS_FILE, "r", encoding="utf-8") as f:
    with_pairs_ids = {line.strip() for line in f if line.strip()}

with open(WITHOUT_PAIRS_FILE, "r", encoding="utf-8") as f:
    without_pairs_ids = {line.strip() for line in f if line.strip()}

print(f"#classes in WITH-PAIRS list   : {len(with_pairs_ids)}")
print(f"#classes in WITHOUT-PAIRS list: {len(without_pairs_ids)}")

# 3. 把 y_true 里的「类别 index」转换成「原始 class ID 字符串」
#    假设 idx_to_label[i] 就是原始 ID（数字或字符串），统一转成 str 来对比
y_true_class_ids = np.array([str(idx_to_label[int(c)]) for c in y_true])

# 4. 根据 class ID mask 出 with-pair / without-pair 的样本
mask_with_pairs = np.isin(y_true_class_ids, list(with_pairs_ids))
mask_without_pairs = np.isin(y_true_class_ids, list(without_pairs_ids))

y_true_with = y_true[mask_with_pairs]
y_pred_with = y_pred[mask_with_pairs]
y_top5_with = y_top5[mask_with_pairs]

y_true_without = y_true[mask_without_pairs]
y_pred_without = y_pred[mask_without_pairs]
y_top5_without = y_top5[mask_without_pairs]

print("\n=== WITH-PAIRS / WITHOUT-PAIRS Result on Validation Set ===")
print(f"Samples in WITH-PAIRS group   : {len(y_true_with)}")
print(f"Samples in WITHOUT-PAIRS group: {len(y_true_without)}")

# 防止某个 group 为空
if len(y_true_with) > 0:
    acc_with_top1 = accuracy_score(y_true_with, y_pred_with) * 100.0
    acc_with_top5 = topk_acc_from_topk_array(y_true_with, y_top5_with)
    print(f"\nGroup: WITH PAIRS")
    print(f"  Top-1 Accuracy: {acc_with_top1:.2f}%")
    print(f"  Top-5 Accuracy: {acc_with_top5:.2f}%")
else:
    acc_with_top1 = None
    acc_with_top5 = None
    print("\nGroup: WITH PAIRS")
    print("  (No samples from WITH-PAIRS class IDs found in this validation set.)")

if len(y_true_without) > 0:
    acc_without_top1 = accuracy_score(y_true_without, y_pred_without) * 100.0
    acc_without_top5 = topk_acc_from_topk_array(y_true_without, y_top5_without)
    print(f"\nGroup: WITHOUT PAIRS")
    print(f"  Top-1 Accuracy: {acc_without_top1:.2f}%")
    print(f"  Top-5 Accuracy: {acc_without_top5:.2f}%")
else:
    acc_without_top1 = None
    acc_without_top5 = None
    print("\nGroup: WITHOUT PAIRS")
    print("  (No samples from WITHOUT-PAIRS class IDs found in this validation set.)")

# 5. 把 with / without 的结果也存到当前 run 的 evaluation 目录里
out_path = eval_dir / "val_with_without_pairs_results.txt"
with open(out_path, "w", encoding="utf-8") as f:
    f.write("WITH-PAIRS / WITHOUT-PAIRS Result on Validation Set\n")
    f.write(f"Samples in WITH-PAIRS group   : {len(y_true_with)}\n")
    f.write(f"Samples in WITHOUT-PAIRS group: {len(y_true_without)}\n\n")

    if acc_with_top1 is not None:
        f.write(f"WITH PAIRS Top-1 Accuracy   : {acc_with_top1:.4f}%\n")
        f.write(f"WITH PAIRS Top-5 Accuracy   : {acc_with_top5:.4f}%\n")
    else:
        f.write("WITH PAIRS: no samples in this val set\n")

    if acc_without_top1 is not None:
        f.write(f"WITHOUT PAIRS Top-1 Accuracy: {acc_without_top1:.4f}%\n")
        f.write(f"WITHOUT PAIRS Top-5 Accuracy: {acc_without_top5:.4f}%\n")
    else:
        f.write("WITHOUT PAIRS: no samples in this val set\n")

print(f"\n[Saved] With/without-pairs results → {out_path}")

[INFO] Saving evaluation results to: runs_mae_freeze_large_NOAUG\evaluation

[Overall] Top-1 Accuracy: 52.66%
[Overall] Top-5 Accuracy: 67.15%

[Saved] Classification report → runs_mae_freeze_large_NOAUG\evaluation\classification_report.txt


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Saved] Confusion matrix image → runs_mae_freeze_large_NOAUG\evaluation\confusion_matrix.png
[Saved] Per-class CSV → runs_mae_freeze_large_NOAUG\evaluation\per_class_metrics.csv

[✓] All evaluation results (including Top-1 & Top-5) saved successfully.

[INFO] Saving with/without-pairs results to: runs_mae_freeze_large_NOAUG\evaluation
#classes in WITH-PAIRS list   : 60
#classes in WITHOUT-PAIRS list: 40

=== WITH-PAIRS / WITHOUT-PAIRS Result on Validation Set ===
Samples in WITH-PAIRS group   : 153
Samples in WITHOUT-PAIRS group: 54

Group: WITH PAIRS
  Top-1 Accuracy: 71.24%
  Top-5 Accuracy: 89.54%

Group: WITHOUT PAIRS
  Top-1 Accuracy: 0.00%
  Top-5 Accuracy: 3.70%

[Saved] With/without-pairs results → runs_mae_freeze_large_NOAUG\evaluation\val_with_without_pairs_results.txt


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
