In [1]:
# train_cnn_from_single_root.py
import os, math, time, random
from pathlib import Path
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models

# ========== 0) Basic configuration ==========
# Dataset root: contains one sub-folder per class
# (Stage 1 / Stage 2 / Stage 3 / Stage 4).
DATA_ROOT = r"F:\5703Dataset\image"
SAVE_PATH = "best_stage_resnet18.pth"  # file the best checkpoint is written to

VAL_RATIO = 0.2          # fraction of every class held out for validation
BATCH_SIZE = 64
EPOCHS = 25              # upper bound; early stopping may finish sooner
LR = 1e-3
WEIGHT_DECAY = 1e-4
PATIENCE = 5             # epochs without val-loss improvement before stopping
IMG_SIZE = 224           # images are resized to IMG_SIZE x IMG_SIZE
NUM_WORKERS = 4          # DataLoader worker processes
USE_PRETRAINED = False   # switch to True when online or when weights are cached

# Seed every RNG this script draws from.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `torch.cuda.amp.GradScaler` is deprecated in recent PyTorch releases and
# emits a FutureWarning; prefer the device-agnostic `torch.amp.GradScaler`
# when the installed version provides it and fall back otherwise.
_GradScaler = getattr(getattr(torch, "amp", None), "GradScaler", None)
if _GradScaler is not None:
    scaler = _GradScaler("cuda", enabled=torch.cuda.is_available())
else:
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# ========== 1) Data and stratified split ==========
# Per-channel statistics handed to Normalize by both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Random augmentations applied to training images only.
_train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02),
]

# Training pipeline: resize -> augment -> tensor -> normalize.
train_tfms = transforms.Compose(
    [transforms.Resize(size=(IMG_SIZE, IMG_SIZE))]
    + _train_augmentations
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Validation pipeline: same resize and normalization, no randomness.
val_tfms = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Scan the folder once without transforms: only the labels are needed to
# compute the split.
full_ds = datasets.ImageFolder(DATA_ROOT, transform=None)
classes = full_ds.classes  # class names taken from the sub-folder names
print("Classes (from folder names):", classes)  # e.g. ['Stage 1','Stage 2','Stage 3','Stage 4']

# Group sample indices by their integer label.
targets = full_ds.targets  # list[int]
indices_per_class = {}
for idx, y in enumerate(targets):
    indices_per_class.setdefault(y, []).append(idx)

# Stratified split: every class contributes roughly VAL_RATIO of its samples
# to validation and the remainder to training.
train_indices, val_indices = [], []
for idxs in indices_per_class.values():
    n = len(idxs)
    n_val = max(1, int(round(n * VAL_RATIO)))
    if n > 1:
        # Never hand every sample of a class to validation: a class with no
        # training samples cannot be learned and breaks the class weighting.
        # (A class with a single sample still goes to validation, as before.)
        n_val = min(n_val, n - 1)
    # A fresh RNG seeded with SEED per class keeps the split reproducible
    # regardless of how much global random state was consumed earlier.
    random.Random(SEED).shuffle(idxs)
    val_indices.extend(idxs[:n_val])
    train_indices.extend(idxs[n_val:])

# Build two views of the same folder, one per transform pipeline, and
# restrict each to its share of the split.
train_ds = datasets.ImageFolder(DATA_ROOT, transform=train_tfms)
val_ds = datasets.ImageFolder(DATA_ROOT, transform=val_tfms)

train_subset = Subset(train_ds, train_indices)
val_subset = Subset(val_ds, val_indices)

# Settings shared by both loaders; only the shuffling differs.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_subset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_subset, shuffle=False, **_loader_kwargs)

# Per-class sample counts of the training split, used to weight the loss.
train_counts = Counter(full_ds.targets[i] for i in train_indices)
num_classes = len(classes)
counts_tensor = torch.tensor(
    [train_counts.get(c, 0) for c in range(num_classes)], dtype=torch.float
)

# Inverse-frequency weights, rescaled so the classes present in training
# average to 1. A class with no training samples keeps a neutral weight of 1
# instead of an enormous one that would squash every other class towards 0
# after normalisation.
_present = counts_tensor > 0
class_weights = torch.ones_like(counts_tensor)
if _present.any():
    _inv_freq = counts_tensor.sum() / counts_tensor[_present]
    class_weights[_present] = _inv_freq / _inv_freq.mean()
print("Train class counts:", dict(train_counts))
print("Class weights:", class_weights.tolist())

# ========== 2) Model ==========
def build_model(nc=num_classes, pretrained=USE_PRETRAINED):
    """Create a ResNet-18 whose final fully connected layer outputs ``nc`` logits.

    Args:
        nc: Number of output classes for the replacement ``fc`` layer.
        pretrained: When true, start from ``ResNet18_Weights.DEFAULT``;
            otherwise the network is randomly initialised.

    Returns:
        The ``torchvision`` ResNet-18 module with its ``fc`` layer replaced.
    """
    initial_weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    net = models.resnet18(weights=initial_weights)
    # Swap the classification head for one sized to this task.
    net.fc = nn.Linear(net.fc.in_features, nc)
    return net

# Instantiate the network and move it to the selected device.
model = build_model()
model = model.to(device)

# ========== 3) Training setup ==========
# Cross-entropy weighted by the per-class weights computed from the training split.
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
optimizer = optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# Early-stopping state: best validation loss so far and the number of
# consecutive epochs that failed to improve on it.
best_val = float("inf")
no_improve = 0

def run_epoch(loader, train=True):
    """Run one pass over ``loader`` and return loss, accuracy and confusion matrix.

    Relies on the module-level ``model``, ``criterion``, ``optimizer``,
    ``scaler``, ``device`` and ``num_classes``.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        train: When true the model is put in training mode and the optimizer
            is stepped on every batch. Otherwise the model is put in eval mode
            and the forward pass runs without gradient tracking.

    Returns:
        A tuple ``(loss, acc, cm)``: the per-sample average of the batch
        losses, the fraction of correct predictions, and a
        ``num_classes x num_classes`` ``torch.long`` CPU tensor with true
        labels on the rows and predictions on the columns. Loss and accuracy
        are ``0.0`` when ``loader`` yields no samples.
    """
    model.train(train)  # train(False) is equivalent to eval()
    use_amp = device.type == "cuda"  # mixed precision only on CUDA
    total_loss, total_correct, total_num = 0.0, 0, 0
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        if train:
            optimizer.zero_grad(set_to_none=True)
        # Gradient tracking is switched off during evaluation so no autograd
        # graph is built (less memory, faster validation).
        with torch.set_grad_enabled(train), \
                torch.autocast(device_type=device.type, enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, labels)
        if train:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        total_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_num += labels.size(0)

        # Vectorised confusion-matrix update: encode each (true, pred) pair as
        # one flat index and count all pairs at once instead of looping over
        # samples in Python.
        pair_index = labels.view(-1) * num_classes + preds.view(-1)
        pair_counts = torch.bincount(pair_index, minlength=num_classes * num_classes)
        cm += pair_counts.view(num_classes, num_classes).cpu()

    return total_loss / max(total_num, 1), total_correct / max(total_num, 1), cm

# ========== 4) Training loop ==========
# One training pass and one validation pass per epoch. The checkpoint is
# rewritten whenever validation loss improves by more than 1e-4, and training
# stops after PATIENCE consecutive epochs without such an improvement.
for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    tr_loss, tr_acc, _ = run_epoch(train_loader, train=True)
    val_loss, val_acc, val_cm = run_epoch(val_loader, train=False)
    print(f"[{epoch:02d}/{EPOCHS}] Train {tr_loss:.4f}/{tr_acc:.4f} | Val {val_loss:.4f}/{val_acc:.4f} | {time.time()-t0:.1f}s")

    improved = val_loss < best_val - 1e-4
    if not improved:
        no_improve += 1
        if no_improve >= PATIENCE:
            print("Early stopping.")
            break
        continue

    best_val = val_loss
    no_improve = 0
    checkpoint = {
        "model_state": model.state_dict(),
        "classes": classes,  # original folder / class names
        "args": {"img_size": IMG_SIZE, "mean":[0.485,0.456,0.406], "std":[0.229,0.224,0.225]}
    }
    torch.save(checkpoint, SAVE_PATH)
    print(f"  ↳ Saved best to {SAVE_PATH}")

# ========== 5) Evaluation ==========
# Report metrics for the saved best checkpoint, not for whatever weights the
# final epoch left in memory: the last epochs before early stopping are by
# construction worse than the best one. A finite `best_val` means a checkpoint
# was written to SAVE_PATH during this run.
if math.isfinite(best_val):
    best_ckpt = torch.load(SAVE_PATH, map_location=device)
    model.load_state_dict(best_ckpt["model_state"])
    best_loss, best_acc, val_cm = run_epoch(val_loader, train=False)
    print(f"\nBest checkpoint ({SAVE_PATH}): Val {best_loss:.4f}/{best_acc:.4f}")

print("\nConfusion Matrix (rows=true, cols=pred):")
print(val_cm.cpu().numpy())

# Per-class precision / recall / F1 from the confusion matrix; `eps` guards
# against division by zero for classes that were never predicted or never seen.
eps = 1e-12
tp = val_cm.diag().float()
pred_pos = val_cm.sum(dim=0).float()  # column sums: predictions per class
true_pos = val_cm.sum(dim=1).float()  # row sums: true samples per class
precision = tp / (pred_pos + eps)
recall    = tp / (true_pos + eps)
f1        = 2 * precision * recall / (precision + recall + eps)

print("\nPer-class metrics:")
for i, cname in enumerate(classes):
    print(f"{cname:>8s} | P: {precision[i]:.3f}  R: {recall[i]:.3f}  F1: {f1[i]:.3f}")
print(f"\nMacro-F1: {f1.mean().item():.3f}")
print("Done.")


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


Classes (from folder names): ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
Train class counts: {0: 1985, 1: 3574, 2: 4006, 3: 11849}
Class weights: [1.8030756711959839, 1.0014283657073975, 0.8934361338615417, 0.30205968022346497]


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


[01/25] Train 0.6228/0.7945 | Val 2.7167/0.3996 | 115.7s
  ↳ Saved best to best_stage_resnet18.pth
[02/25] Train 0.3914/0.8770 | Val 0.7689/0.7984 | 81.8s
  ↳ Saved best to best_stage_resnet18.pth
[03/25] Train 0.3156/0.9012 | Val 0.4831/0.8489 | 82.2s
  ↳ Saved best to best_stage_resnet18.pth
[04/25] Train 0.2585/0.9197 | Val 7.9083/0.3249 | 81.6s
[05/25] Train 0.2339/0.9324 | Val 0.6318/0.7848 | 83.2s
[06/25] Train 0.2020/0.9395 | Val 11.4751/0.2916 | 81.9s
[07/25] Train 0.2027/0.9412 | Val 0.3790/0.8976 | 81.6s
  ↳ Saved best to best_stage_resnet18.pth
[08/25] Train 0.1987/0.9406 | Val 0.7999/0.7988 | 82.4s
[09/25] Train 0.1820/0.9472 | Val 0.1763/0.9432 | 83.4s
  ↳ Saved best to best_stage_resnet18.pth
[10/25] Train 0.1707/0.9503 | Val 0.0838/0.9733 | 86.4s
  ↳ Saved best to best_stage_resnet18.pth
[11/25] Train 0.1539/0.9548 | Val 4.7823/0.4257 | 83.6s
[12/25] Train 0.1440/0.9571 | Val 0.2175/0.9215 | 81.3s
[13/25] Train 0.1412/0.9595 | Val 7.4541/0.3688 | 82.0s
[14/25] Train 0.13