# 1. Imports & Device

In [1]:
import os
from pathlib import Path
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import torchvision.transforms as T

from transformers import ViTMAEModel, get_cosine_schedule_with_warmup

from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

print(f"PyTorch version: {torch.__version__}")
import transformers
print(f"Transformers version: {transformers.__version__}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 2.5.1+cu121
Transformers version: 4.50.3
Using device: cuda


# 2. Configuration

In [2]:
data_root = Path(r"D:\Swinburne\Degree3_S2\COS30082_AML\GroupAssignment\dataset")
list_dir  = data_root / "list"

class Config:
    data_root = data_root
    list_dir  = list_dir
    
    train_list_path = list_dir / "train.txt"
    test_list_path  = list_dir / "test.txt"
    gt_list_path    = list_dir / "groundtruth.txt"

    model_name = "facebook/vit-mae-large"
    num_classes = 100

    img_size   = 224
    batch_size = 16
    num_workers = 0
    pin_memory = False
    persistent_workers = False

    epochs        = 30
    lr            = 5e-5
    weight_decay  = 0.05
    
    # lr            = 1e-3
    # weight_decay  = 1e-4
    warmup_ratio  = 0.1

    seed = 42
    # out_dir = Path("./test")
    out_dir = Path("./runs_mae_large_AUG_5e-5")
    out_dir.mkdir(parents=True, exist_ok=True)

# 3. Seed

In [3]:
def set_seed(seed: int = 42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(Config.seed)

# 4. Transforms

In [4]:
USE_AUG = True

# ===============================
# Define all transforms
# ===============================

IMAGE_SIZE = 224  # 如果你文件里是别的，请同步修改

# Herbarium heavy augmentation
train_herbarium_transform = T.Compose([
    T.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(45),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Photo light augmentation
train_photo_transform = T.Compose([
    T.RandomResizedCrop(IMAGE_SIZE),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation/Test: no augmentation
eval_transform = T.Compose([
    T.Resize(IMAGE_SIZE + 32),
    T.CenterCrop(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ===============================
# Unified transform selector
# ===============================
def pick_transform(rel_path: str, train: bool = True):
    """
    Automatically pick transform based on folder name.
    rel_path: dataset path (string), used to detect 'herbarium' or 'photo'.
    train: if False → always use eval_transform
    """
    low = rel_path.lower()

    # -----------------------
    # Validation/test mode: never augment
    # -----------------------
    if not train:
        return eval_transform

    # -----------------------
    # Training but augmentation disabled
    # -----------------------
    if not USE_AUG:
        return eval_transform

    # -----------------------
    # Training with augmentation enabled
    # -----------------------
    if "herbarium" in low:
        return train_herbarium_transform

    if "photo" in low:
        return train_photo_transform

    # Default → treat as Photo dataset
    return train_photo_transform

# 5. Loading train/test labels

In [5]:
def read_groundtruth(gt_path: Path):
    mapping = {}
    with open(gt_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 2:
                mapping[parts[0]] = int(parts[-1])
    return mapping


def read_train_list(list_path: Path):
    """Expect: train.txt lines are '<rel_path> <label>'"""
    samples = []
    with open(list_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 2:
                raise ValueError(f"Train line has no label: {line}")
            rel_path = parts[0]
            label = int(parts[1])
            samples.append((rel_path, label))
    return samples


def read_test_list(list_path: Path):
    """Expect: test.txt lines are either '<rel_path>' or '<rel_path> <label>'"""
    samples = []
    with open(list_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if not parts:
                continue
            rel_path = parts[0]
            label = int(parts[1]) if len(parts) >= 2 else None
            samples.append((rel_path, label))
    return samples


print("=== Path Check ===")
print(f"{Config.data_root} -> {Config.data_root.exists()}")
print(f"{Config.train_list_path} -> {Config.train_list_path.exists()}")
print(f"{Config.test_list_path} -> {Config.test_list_path.exists()}")
print(f"{Config.gt_list_path} -> {Config.gt_list_path.exists()}")

gt_mapping = read_groundtruth(Config.gt_list_path)
train_raw = read_train_list(Config.train_list_path)
test_raw  = read_test_list(Config.test_list_path)

print(f"Line counts -> train={len(train_raw)}, test={len(test_raw)}, groundtruth={len(gt_mapping)}")

# Build raw samples with original labels
train_samples_raw = []
val_samples_raw   = []
unmatched = 0

# Train: labels come directly from train.txt
for rel_path, label in train_raw:
    full_path = Config.data_root / rel_path
    train_samples_raw.append((full_path, label))

# Val/Test: prefer label in test.txt, otherwise use groundtruth.txt
for rel_path, label in test_raw:
    if label is None:
        label = gt_mapping.get(rel_path, None)
    if label is None:
        unmatched += 1
        continue
    full_path = Config.data_root / rel_path
    val_samples_raw.append((full_path, label))

all_labels_raw = [lbl for _, lbl in train_samples_raw] + [lbl for _, lbl in val_samples_raw]
unique_labels = sorted(set(all_labels_raw))

# Map original labels -> contiguous [0, num_classes-1]
label_to_idx = {lab: idx for idx, lab in enumerate(unique_labels)}
idx_to_label = {idx: lab for lab, idx in label_to_idx.items()}

# Apply mapping
train_samples = [(path, label_to_idx[lbl]) for path, lbl in train_samples_raw]
val_samples   = [(path, label_to_idx[lbl]) for path, lbl in val_samples_raw]

Config.num_classes = len(unique_labels)

print(f"[Info] Original label min={min(all_labels_raw)}, max={max(all_labels_raw)}")
print(f"[Info] Number of unique labels={Config.num_classes}")
print(f"[Info] Train samples: {len(train_samples)}, Val samples: {len(val_samples)}, Unmatched test: {unmatched}")
print(f"[Info] Labels have been remapped to range [0, {Config.num_classes - 1}]")

=== Path Check ===
D:\Swinburne\Degree3_S2\COS30082_AML\GroupAssignment\dataset -> True
D:\Swinburne\Degree3_S2\COS30082_AML\GroupAssignment\dataset\list\train.txt -> True
D:\Swinburne\Degree3_S2\COS30082_AML\GroupAssignment\dataset\list\test.txt -> True
D:\Swinburne\Degree3_S2\COS30082_AML\GroupAssignment\dataset\list\groundtruth.txt -> True
Line counts -> train=4744, test=207, groundtruth=207
[Info] Original label min=12254, max=285398
[Info] Number of unique labels=100
[Info] Train samples: 4744, Val samples: 207, Unmatched test: 0
[Info] Labels have been remapped to range [0, 99]


# 6. Dataset & DataLoaders

In [6]:
class PlantsDataset(Dataset):
    def __init__(self, samples, train: bool = True):
        """
        samples: list of (full_path: Path, label_idx: int)
        train : True 表示训练集（会根据 USE_AUG 决定要不要 augmentation）
                False 表示验证/测试集（永远不用 augmentation）
        """
        self.samples = samples
        self.train = train

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert("RGB")

        # 根据路径和 train / val 选择对应的 transform
        # pick_transform 在前面已经定义好了：
        #   def pick_transform(rel_path: str, train: bool = True): ...
        rel_path = str(img_path)
        transform = pick_transform(rel_path, train=self.train)

        img = transform(img)
        return img, label


# === Datasets ===
train_dataset = PlantsDataset(train_samples, train=True)
val_dataset   = PlantsDataset(val_samples,   train=False)

# === Dataloaders ===
train_loader = DataLoader(
    train_dataset,
    batch_size=Config.batch_size,
    shuffle=True,
    num_workers=Config.num_workers,
    pin_memory=Config.pin_memory,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=Config.batch_size,
    shuffle=False,
    num_workers=Config.num_workers,
    pin_memory=Config.pin_memory,
)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

Train batches: 297, Val batches: 13


# 7. MAE Classification Model

In [7]:
class MAEForPlants(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()
        self.backbone = ViTMAEModel.from_pretrained(model_name)
        embed_dim = self.backbone.config.hidden_size

        self.production = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(0.2),       # 你可以保持 0.1 或改 0.2 都可以
        )
        # 如果你之前加了第二个 dropout 也可以保留
        # self.dropout2 = nn.Dropout(0.2)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        out = self.backbone(pixel_values=x)
        cls_token = out.last_hidden_state[:, 0]
        feats = self.production(cls_token)
        # feats = self.dropout2(feats)  # 如果你有第二个 dropout
        logits = self.classifier(feats)
        return logits



model = MAEForPlants(Config.model_name, Config.num_classes).to(device)
print(model)

MAEForPlants(
  (backbone): ViTMAEModel(
    (embeddings): ViTMAEEmbeddings(
      (patch_embeddings): ViTMAEPatchEmbeddings(
        (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      )
    )
    (encoder): ViTMAEEncoder(
      (layer): ModuleList(
        (0-23): 24 x ViTMAELayer(
          (attention): ViTMAEAttention(
            (attention): ViTMAESelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (output): ViTMAESelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTMAEIntermediate(
            (dense): Linear(in_features=1024, out_features=4096, bias=True)
            (intermediate_act_fn): GELU

# 7. Loss / Optimizer / Scheduler / AMP

In [8]:
criterion = nn.CrossEntropyLoss()

head_params = list(model.production.parameters()) + list(model.classifier.parameters())

optimizer = torch.optim.AdamW(
    model.parameters(),         # ⭐ 重点：整支 model 都参与训练
    lr=Config.lr,               # 建议：比 freeze 版本稍微小一点，例如 5e-5 或 1e-5
    weight_decay=Config.weight_decay
)

total_steps = max(1, len(train_loader)) * Config.epochs
warmup_steps = int(Config.warmup_ratio * total_steps)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

scaler = torch.cuda.amp.GradScaler()


  scaler = torch.cuda.amp.GradScaler()


# 8. Training Loop

In [9]:
log_dir = Config.out_dir / "logs"
log_dir.mkdir(parents=True, exist_ok=True)

# models 目录（只存 best model）
models_dir = Config.out_dir / "models"
models_dir.mkdir(parents=True, exist_ok=True)

log_path = log_dir / "training_log.txt"
log_f = open(log_path, "w", encoding="utf-8")

print(f"[INFO] Training log will be saved to: {log_path}")
print(f"[INFO] Best model will be saved to: {models_dir}")


def train_one_epoch(epoch):
    model.train()
    running_loss = 0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast("cuda", enabled=device.type == "cuda"):
            logits = model(imgs)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        running_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        pbar.set_postfix(loss=running_loss/total, acc=correct/total*100)

    return running_loss/total, correct/total*100


@torch.no_grad()
def validate(epoch):
    model.eval()
    running_loss = 0
    correct = 0
    total = 0

    pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]")
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)

        with torch.amp.autocast("cuda", enabled=device.type == "cuda"):
            logits = model(imgs)
            loss = criterion(logits, labels)

        running_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        pbar.set_postfix(loss=running_loss/total, acc=correct/total*100)

    return running_loss/total, correct/total*100


best_val_acc = 0
best_model_path = models_dir / "mae_frozen_best.pth"

for epoch in range(Config.epochs):
    train_loss, train_acc = train_one_epoch(epoch)
    val_loss, val_acc = validate(epoch)

    print(f"[Epoch {epoch+1}] Train Acc={train_acc:.2f}%, Val Acc={val_acc:.2f}%")

    # ===== 写 log =====
    log_line = (
        f"Epoch {epoch+1}/{Config.epochs} | "
        f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}% | "
        f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}% | "
        f"Best Val Acc={max(best_val_acc, val_acc):.2f}%\n"
    )
    log_f.write(log_line)
    log_f.flush()

    # ===== 只保存 best model =====
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"[INFO] Saved BEST model to: {best_model_path}")

log_f.close()
print(f"[✓] Training completed. Logs saved to {log_path}")


[INFO] Training log will be saved to: runs_mae_large_AUG_5e-5\logs\training_log.txt
[INFO] Best model will be saved to: runs_mae_large_AUG_5e-5\models


Epoch 1 [Train]:   0%|          | 0/297 [00:00<?, ?it/s]

Epoch 1 [Train]: 100%|██████████| 297/297 [19:24<00:00,  3.92s/it, acc=2.02, loss=4.58]
Epoch 1 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.72it/s, acc=4.83, loss=4.53]


[Epoch 1] Train Acc=2.02%, Val Acc=4.83%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 2 [Train]: 100%|██████████| 297/297 [19:56<00:00,  4.03s/it, acc=11.7, loss=3.98]
Epoch 2 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.75it/s, acc=13, loss=4.07]  


[Epoch 2] Train Acc=11.70%, Val Acc=13.04%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 3 [Train]: 100%|██████████| 297/297 [19:56<00:00,  4.03s/it, acc=22.5, loss=3.25]
Epoch 3 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.75it/s, acc=21.7, loss=3.53]


[Epoch 3] Train Acc=22.47%, Val Acc=21.74%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 4 [Train]: 100%|██████████| 297/297 [20:00<00:00,  4.04s/it, acc=31.9, loss=2.7] 
Epoch 4 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.75it/s, acc=33.3, loss=3.17]


[Epoch 4] Train Acc=31.89%, Val Acc=33.33%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 5 [Train]: 100%|██████████| 297/297 [20:18<00:00,  4.10s/it, acc=41.2, loss=2.31]
Epoch 5 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.71it/s, acc=39.6, loss=2.99]


[Epoch 5] Train Acc=41.19%, Val Acc=39.61%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 6 [Train]: 100%|██████████| 297/297 [20:17<00:00,  4.10s/it, acc=47.8, loss=1.99]
Epoch 6 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.75it/s, acc=41.5, loss=2.89]


[Epoch 6] Train Acc=47.81%, Val Acc=41.55%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 7 [Train]: 100%|██████████| 297/297 [20:18<00:00,  4.10s/it, acc=53.3, loss=1.77]
Epoch 7 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.67it/s, acc=45.4, loss=2.77]


[Epoch 7] Train Acc=53.27%, Val Acc=45.41%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 8 [Train]: 100%|██████████| 297/297 [20:45<00:00,  4.19s/it, acc=56.8, loss=1.57]
Epoch 8 [Val]: 100%|██████████| 13/13 [00:09<00:00,  1.40it/s, acc=49.3, loss=2.62]


[Epoch 8] Train Acc=56.77%, Val Acc=49.28%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 9 [Train]: 100%|██████████| 297/297 [20:33<00:00,  4.15s/it, acc=63, loss=1.4]   
Epoch 9 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.71it/s, acc=48.8, loss=2.68]


[Epoch 9] Train Acc=62.96%, Val Acc=48.79%


Epoch 10 [Train]: 100%|██████████| 297/297 [20:15<00:00,  4.09s/it, acc=64.8, loss=1.28]
Epoch 10 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.70it/s, acc=50.2, loss=2.63]


[Epoch 10] Train Acc=64.84%, Val Acc=50.24%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 11 [Train]: 100%|██████████| 297/297 [20:17<00:00,  4.10s/it, acc=67.4, loss=1.18]
Epoch 11 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.71it/s, acc=54.1, loss=2.58]


[Epoch 11] Train Acc=67.37%, Val Acc=54.11%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 12 [Train]: 100%|██████████| 297/297 [20:19<00:00,  4.11s/it, acc=70.4, loss=1.07]
Epoch 12 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.69it/s, acc=56, loss=2.57]  


[Epoch 12] Train Acc=70.45%, Val Acc=56.04%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 13 [Train]: 100%|██████████| 297/297 [20:16<00:00,  4.10s/it, acc=74.2, loss=0.92] 
Epoch 13 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.72it/s, acc=54.6, loss=2.56]


[Epoch 13] Train Acc=74.22%, Val Acc=54.59%


Epoch 14 [Train]: 100%|██████████| 297/297 [20:15<00:00,  4.09s/it, acc=76.9, loss=0.836]
Epoch 14 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.72it/s, acc=56, loss=2.58]  


[Epoch 14] Train Acc=76.94%, Val Acc=56.04%


Epoch 15 [Train]: 100%|██████████| 297/297 [20:16<00:00,  4.10s/it, acc=78.3, loss=0.767]
Epoch 15 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.70it/s, acc=56, loss=2.56]  


[Epoch 15] Train Acc=78.29%, Val Acc=56.04%


Epoch 16 [Train]: 100%|██████████| 297/297 [20:28<00:00,  4.14s/it, acc=79.4, loss=0.705]
Epoch 16 [Val]: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s, acc=56.5, loss=2.56]


[Epoch 16] Train Acc=79.41%, Val Acc=56.52%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 17 [Train]: 100%|██████████| 297/297 [20:59<00:00,  4.24s/it, acc=82.1, loss=0.636]
Epoch 17 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.68it/s, acc=58.5, loss=2.62]


[Epoch 17] Train Acc=82.12%, Val Acc=58.45%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 18 [Train]: 100%|██████████| 297/297 [20:23<00:00,  4.12s/it, acc=83.9, loss=0.557]
Epoch 18 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.69it/s, acc=54.1, loss=2.6] 


[Epoch 18] Train Acc=83.92%, Val Acc=54.11%


Epoch 19 [Train]: 100%|██████████| 297/297 [20:20<00:00,  4.11s/it, acc=85.6, loss=0.508]
Epoch 19 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.74it/s, acc=58.5, loss=2.56]


[Epoch 19] Train Acc=85.60%, Val Acc=58.45%


Epoch 20 [Train]: 100%|██████████| 297/297 [20:20<00:00,  4.11s/it, acc=87.2, loss=0.456]
Epoch 20 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.69it/s, acc=58, loss=2.52]  


[Epoch 20] Train Acc=87.23%, Val Acc=57.97%


Epoch 21 [Train]: 100%|██████████| 297/297 [20:19<00:00,  4.11s/it, acc=88.5, loss=0.415]
Epoch 21 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.72it/s, acc=59.4, loss=2.47]


[Epoch 21] Train Acc=88.53%, Val Acc=59.42%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 22 [Train]: 100%|██████████| 297/297 [20:06<00:00,  4.06s/it, acc=89.8, loss=0.382]
Epoch 22 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.71it/s, acc=58.5, loss=2.59]


[Epoch 22] Train Acc=89.84%, Val Acc=58.45%


Epoch 23 [Train]: 100%|██████████| 297/297 [19:56<00:00,  4.03s/it, acc=90.3, loss=0.348]
Epoch 23 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.83it/s, acc=56.5, loss=2.68]


[Epoch 23] Train Acc=90.32%, Val Acc=56.52%


Epoch 24 [Train]: 100%|██████████| 297/297 [19:50<00:00,  4.01s/it, acc=91.7, loss=0.305]
Epoch 24 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.82it/s, acc=59.4, loss=2.59]


[Epoch 24] Train Acc=91.69%, Val Acc=59.42%


Epoch 25 [Train]: 100%|██████████| 297/297 [19:49<00:00,  4.01s/it, acc=91.5, loss=0.314]
Epoch 25 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.84it/s, acc=59.4, loss=2.54]


[Epoch 25] Train Acc=91.48%, Val Acc=59.42%


Epoch 26 [Train]: 100%|██████████| 297/297 [19:50<00:00,  4.01s/it, acc=92.5, loss=0.296]
Epoch 26 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.78it/s, acc=59.9, loss=2.55]


[Epoch 26] Train Acc=92.54%, Val Acc=59.90%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 27 [Train]: 100%|██████████| 297/297 [19:51<00:00,  4.01s/it, acc=92.6, loss=0.275]
Epoch 27 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.81it/s, acc=61.4, loss=2.58]


[Epoch 27] Train Acc=92.60%, Val Acc=61.35%
[INFO] Saved BEST model to: runs_mae_large_AUG_5e-5\models\mae_frozen_best.pth


Epoch 28 [Train]: 100%|██████████| 297/297 [19:50<00:00,  4.01s/it, acc=93, loss=0.265]  
Epoch 28 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.78it/s, acc=58.9, loss=2.56]


[Epoch 28] Train Acc=93.04%, Val Acc=58.94%


Epoch 29 [Train]: 100%|██████████| 297/297 [19:52<00:00,  4.02s/it, acc=93.5, loss=0.255]
Epoch 29 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.76it/s, acc=60.4, loss=2.54]


[Epoch 29] Train Acc=93.47%, Val Acc=60.39%


Epoch 30 [Train]: 100%|██████████| 297/297 [19:57<00:00,  4.03s/it, acc=93.1, loss=0.265]
Epoch 30 [Val]: 100%|██████████| 13/13 [00:07<00:00,  1.80it/s, acc=59.4, loss=2.58]

[Epoch 30] Train Acc=93.09%, Val Acc=59.42%
[✓] Training completed. Logs saved to runs_mae_large_AUG_5e-5\logs\training_log.txt





# 9. Evaluation

In [11]:
# === Folder to save results ===
eval_dir = Config.out_dir / "evaluation"
eval_dir.mkdir(parents=True, exist_ok=True)
print(f"[INFO] Saving evaluation results to: {eval_dir}")


# === helper: top-k accuracy ===
def topk_acc_from_topk_array(y_true_group, topk_array):
    """
    y_true_group : [N]
    topk_array   : [N, K]，每行是该样本 top-K 预测的类别 index
    返回: 百分比 (0-100)
    """
    if len(y_true_group) == 0:
        return None
    correct = np.any(topk_array == y_true_group[:, None], axis=1)
    return correct.mean() * 100.0


# === Collect predictions (Top-1 + Top-5) ===
@torch.no_grad()
def collect_preds(loader, k=5):
    model.eval()
    all_top1, all_topk, all_labels = [], [], []
    for imgs, labels in loader:
        imgs = imgs.to(device, non_blocking=True)
        logits = model(imgs)

        top1 = torch.argmax(logits, dim=1)           # [B]
        topk = torch.topk(logits, k=k, dim=1).indices  # [B, k]

        all_top1.append(top1.cpu().numpy())
        all_topk.append(topk.cpu().numpy())
        all_labels.append(labels.numpy())

    y_pred_top1 = np.concatenate(all_top1)      # [N]
    y_pred_topk = np.concatenate(all_topk)      # [N, k]
    y_true      = np.concatenate(all_labels)    # [N]
    return y_pred_top1, y_pred_topk, y_true


y_pred, y_top5, y_true = collect_preds(val_loader, k=5)

# === Overall Top-1 & Top-5 ===
overall_top1 = accuracy_score(y_true, y_pred) * 100.0
overall_top5 = topk_acc_from_topk_array(y_true, y_top5)

print(f"\n[Overall] Top-1 Accuracy: {overall_top1:.2f}%")
print(f"[Overall] Top-5 Accuracy: {overall_top5:.2f}%\n")


# === 1. Save Classification Report ===
report_str = classification_report(
    y_true,
    y_pred,
    digits=4,
    target_names=[str(idx_to_label[i]) for i in range(len(idx_to_label))]
)

report_path = eval_dir / "classification_report.txt"
with open(report_path, "w", encoding="utf-8") as f:
    f.write(f"Overall Top-1 Accuracy: {overall_top1:.4f}%\n")
    f.write(f"Overall Top-5 Accuracy: {overall_top5:.4f}%\n")
    f.write("\n")
    f.write(report_str)

print(f"[Saved] Classification report → {report_path}")


# === 2. Save Confusion Matrix Plot (Top-1) ===
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(idx_to_label))))

plt.figure(figsize=(8, 7))
plt.imshow(cm, cmap="Blues", interpolation="nearest")
plt.title("Confusion Matrix (counts)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.colorbar()
plt.tight_layout()

cm_path = eval_dir / "confusion_matrix.png"
plt.savefig(cm_path, dpi=300)
plt.close()

print(f"[Saved] Confusion matrix image → {cm_path}")


# === 3. Save Per-Class Metrics as CSV (from Top-1) ===
report_dict = classification_report(
    y_true,
    y_pred,
    digits=4,
    target_names=[str(idx_to_label[i]) for i in range(len(idx_to_label))],
    output_dict=True
)

metrics_df = pd.DataFrame(report_dict).transpose()
metrics_path = eval_dir / "per_class_metrics.csv"
metrics_df.to_csv(metrics_path, encoding="utf-8")

print(f"[Saved] Per-class CSV → {metrics_path}")

print("\n[✓] All evaluation results (including Top-1 & Top-5) saved successfully.")



# ====== Step 9 (extra): WITH / WITHOUT PAIRS 分组分析并保存到 evaluation 文件夹 ======

# 0. 确保 evaluation 目录存在（和前面 Step 9 保存 report 的目录一致）
eval_dir = Config.out_dir / "evaluation"
eval_dir.mkdir(parents=True, exist_ok=True)

print(f"\n[INFO] Saving with/without-pairs results to: {eval_dir}")

# 1. 设置 with / without pairs 的 class list 文件路径
#    确保这两个文件在当前工作目录下，或者改成你的完整路径
WITH_PAIRS_FILE = list_dir / "class_with_pairs.txt"
WITHOUT_PAIRS_FILE = list_dir / "class_without_pairs.txt"

# 2. 读入原始 class ID（每行一个）
with open(WITH_PAIRS_FILE, "r", encoding="utf-8") as f:
    with_pairs_ids = {line.strip() for line in f if line.strip()}

with open(WITHOUT_PAIRS_FILE, "r", encoding="utf-8") as f:
    without_pairs_ids = {line.strip() for line in f if line.strip()}

print(f"#classes in WITH-PAIRS list   : {len(with_pairs_ids)}")
print(f"#classes in WITHOUT-PAIRS list: {len(without_pairs_ids)}")

# 3. 把 y_true 里的「类别 index」转换成「原始 class ID 字符串」
#    假设 idx_to_label[i] 就是原始 ID（数字或字符串），统一转成 str 来对比
y_true_class_ids = np.array([str(idx_to_label[int(c)]) for c in y_true])

# 4. 根据 class ID mask 出 with-pair / without-pair 的样本
mask_with_pairs = np.isin(y_true_class_ids, list(with_pairs_ids))
mask_without_pairs = np.isin(y_true_class_ids, list(without_pairs_ids))

y_true_with = y_true[mask_with_pairs]
y_pred_with = y_pred[mask_with_pairs]
y_top5_with = y_top5[mask_with_pairs]

y_true_without = y_true[mask_without_pairs]
y_pred_without = y_pred[mask_without_pairs]
y_top5_without = y_top5[mask_without_pairs]

print("\n=== WITH-PAIRS / WITHOUT-PAIRS Result on Validation Set ===")
print(f"Samples in WITH-PAIRS group   : {len(y_true_with)}")
print(f"Samples in WITHOUT-PAIRS group: {len(y_true_without)}")

# 防止某个 group 为空
if len(y_true_with) > 0:
    acc_with_top1 = accuracy_score(y_true_with, y_pred_with) * 100.0
    acc_with_top5 = topk_acc_from_topk_array(y_true_with, y_top5_with)
    print(f"\nGroup: WITH PAIRS")
    print(f"  Top-1 Accuracy: {acc_with_top1:.2f}%")
    print(f"  Top-5 Accuracy: {acc_with_top5:.2f}%")
else:
    acc_with_top1 = None
    acc_with_top5 = None
    print("\nGroup: WITH PAIRS")
    print("  (No samples from WITH-PAIRS class IDs found in this validation set.)")

if len(y_true_without) > 0:
    acc_without_top1 = accuracy_score(y_true_without, y_pred_without) * 100.0
    acc_without_top5 = topk_acc_from_topk_array(y_true_without, y_top5_without)
    print(f"\nGroup: WITHOUT PAIRS")
    print(f"  Top-1 Accuracy: {acc_without_top1:.2f}%")
    print(f"  Top-5 Accuracy: {acc_without_top5:.2f}%")
else:
    acc_without_top1 = None
    acc_without_top5 = None
    print("\nGroup: WITHOUT PAIRS")
    print("  (No samples from WITHOUT-PAIRS class IDs found in this validation set.)")

# 5. 把 with / without 的结果也存到当前 run 的 evaluation 目录里
out_path = eval_dir / "val_with_without_pairs_results.txt"
with open(out_path, "w", encoding="utf-8") as f:
    f.write("WITH-PAIRS / WITHOUT-PAIRS Result on Validation Set\n")
    f.write(f"Samples in WITH-PAIRS group   : {len(y_true_with)}\n")
    f.write(f"Samples in WITHOUT-PAIRS group: {len(y_true_without)}\n\n")

    if acc_with_top1 is not None:
        f.write(f"WITH PAIRS Top-1 Accuracy   : {acc_with_top1:.4f}%\n")
        f.write(f"WITH PAIRS Top-5 Accuracy   : {acc_with_top5:.4f}%\n")
    else:
        f.write("WITH PAIRS: no samples in this val set\n")

    if acc_without_top1 is not None:
        f.write(f"WITHOUT PAIRS Top-1 Accuracy: {acc_without_top1:.4f}%\n")
        f.write(f"WITHOUT PAIRS Top-5 Accuracy: {acc_without_top5:.4f}%\n")
    else:
        f.write("WITHOUT PAIRS: no samples in this val set\n")

print(f"\n[Saved] With/without-pairs results → {out_path}")

[INFO] Saving evaluation results to: runs_mae_large_AUG_5e-5\evaluation

[Overall] Top-1 Accuracy: 59.42%
[Overall] Top-5 Accuracy: 68.60%

[Saved] Classification report → runs_mae_large_AUG_5e-5\evaluation\classification_report.txt


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Saved] Confusion matrix image → runs_mae_large_AUG_5e-5\evaluation\confusion_matrix.png


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Saved] Per-class CSV → runs_mae_large_AUG_5e-5\evaluation\per_class_metrics.csv

[✓] All evaluation results (including Top-1 & Top-5) saved successfully.

[INFO] Saving with/without-pairs results to: runs_mae_large_AUG_5e-5\evaluation
#classes in WITH-PAIRS list   : 60
#classes in WITHOUT-PAIRS list: 40

=== WITH-PAIRS / WITHOUT-PAIRS Result on Validation Set ===
Samples in WITH-PAIRS group   : 153
Samples in WITHOUT-PAIRS group: 54

Group: WITH PAIRS
  Top-1 Accuracy: 80.39%
  Top-5 Accuracy: 92.81%

Group: WITHOUT PAIRS
  Top-1 Accuracy: 0.00%
  Top-5 Accuracy: 0.00%

[Saved] With/without-pairs results → runs_mae_large_AUG_5e-5\evaluation\val_with_without_pairs_results.txt
