In [None]:
#   1 🎯 Test Loss: 0.2043 | Test Accuracy: 0.9418
#   2 🎯 Test Loss: 0.5611 | Test Accuracy: 0.9136
#   3 🎯 Test Loss: 0.0806 | Test Accuracy: 0.9912
#   病变 🎯 Test Loss: 0.1399 | Test Accuracy: 0.9606

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import vit_b_16, ViT_B_16_Weights
from tqdm.notebook import tqdm

# —— 1. 忽略隐藏目录 —— #
class FilteredImageFolder(ImageFolder):
    def find_classes(self, directory):
        classes = [d.name for d in os.scandir(directory) if d.is_dir() and not d.name.startswith('.')]
        classes.sort()
        class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
        return classes, class_to_idx

# —— 2. 配置 —— #
train_dir = "/root/autodl-fs/isic19_20_split/train"
val_dir   = "/root/autodl-fs/isic19_20_split/val"
test_dir  = "/root/autodl-fs/isic19_20_split/test"
ckpt_path     = "/root/autodl-fs/best_vit.pth"
batch_size    = 32
num_epochs    = 10
learning_rate = 5e-5  # ViT 通常使用较小的学习率
device        = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# —— 3. 使用 ViT 的官方预处理 —— #
weights = ViT_B_16_Weights.DEFAULT
transform = weights.transforms()

# —— 4. 加载数据 —— #
train_dataset = FilteredImageFolder(root=train_dir, transform=transform)
test_dataset  = FilteredImageFolder(root=test_dir,  transform=transform)

if os.path.exists(val_dir):
    val_dataset = FilteredImageFolder(root=val_dir, transform=transform)
else:
    val_len = int(len(train_dataset) * 0.15)
    train_len = len(train_dataset) - val_len
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_len, val_len])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=4)

# —— 5. 模型定义 —— #
model = vit_b_16(weights=weights)
num_classes = len(train_dataset.dataset.classes if hasattr(train_dataset, 'dataset') else train_dataset.classes)
model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# —— 6. 训练 + 验证 + 保存最佳模型 —— #
best_val_acc = 0.0
for epoch in range(num_epochs):
    model.train()
    train_loss, train_acc = 0.0, 0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)
        train_acc += (out.argmax(1) == y).sum().item()
    
    train_loss /= len(train_loader.dataset)
    train_acc  /= len(train_loader.dataset)

    # 验证
    model.eval()
    val_loss, val_acc = 0.0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            val_loss += loss.item() * x.size(0)
            val_acc  += (out.argmax(1) == y).sum().item()
    val_loss /= len(val_loader.dataset)
    val_acc  /= len(val_loader.dataset)

    print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        torch.save(model.state_dict(), ckpt_path)
        print("✅ Saved Best Model!")

    # 每个 epoch 保存一次模型
    epoch_ckpt_path = f"/root/autodl-fs/ckpt_vit/epoch_{epoch+1}.pth"
    os.makedirs(os.path.dirname(epoch_ckpt_path), exist_ok=True)
    torch.save(model.state_dict(), epoch_ckpt_path)

# —— 7. 加载并评估测试集 —— #
model.load_state_dict(torch.load(ckpt_path))
model.eval()
test_loss, test_acc, total = 0.0, 0, 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = criterion(out, y)
        test_loss += loss.item() * x.size(0)
        test_acc  += (out.argmax(1) == y).sum().item()
        total += y.size(0)

test_loss /= total
test_acc  /= total
print(f"🎯 Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}")

# —— 8. 所有 epoch 的模型推理 —— #
print("\n📊 Evaluating all epoch checkpoints on test set:")
epoch_results = []

for e in range(1, num_epochs + 1):
    ckpt_file = f"/root/autodl-fs/ckpt_vit/epoch_{e}.pth"
    if not os.path.exists(ckpt_file):
        print(f"❌ Epoch {e} model not found.")
        continue

    model.load_state_dict(torch.load(ckpt_file))
    model.eval()

    test_loss, test_acc, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            test_loss += loss.item() * x.size(0)
            test_acc  += (out.argmax(1) == y).sum().item()
            total += y.size(0)

    test_loss /= total
    test_acc  /= total
    epoch_results.append((e, test_loss, test_acc))
    print(f"📁 Epoch {e:02d} | Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}")


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [01:32<00:00, 3.76MB/s] 


Epoch 1/10:   0%|          | 0/251 [00:00<?, ?it/s]

[Epoch 1] Train Loss: 0.2430, Acc: 0.8967 | Val Loss: 0.1823, Acc: 0.9295
✅ Saved Best Model!


Epoch 2/10:   0%|          | 0/251 [00:00<?, ?it/s]

[Epoch 2] Train Loss: 0.1655, Acc: 0.9364 | Val Loss: 0.1745, Acc: 0.9319
✅ Saved Best Model!


Epoch 3/10:   0%|          | 0/251 [00:00<?, ?it/s]

[Epoch 3] Train Loss: 0.1211, Acc: 0.9552 | Val Loss: 0.1801, Acc: 0.9388
✅ Saved Best Model!


Epoch 4/10:   0%|          | 0/251 [00:00<?, ?it/s]

[Epoch 4] Train Loss: 0.1008, Acc: 0.9636 | Val Loss: 0.2495, Acc: 0.9185


Epoch 5/10:   0%|          | 0/251 [00:00<?, ?it/s]

[Epoch 5] Train Loss: 0.0795, Acc: 0.9717 | Val Loss: 0.2793, Acc: 0.8893


Epoch 6/10:   0%|          | 0/251 [00:00<?, ?it/s]

[Epoch 6] Train Loss: 0.0447, Acc: 0.9842 | Val Loss: 0.2561, Acc: 0.9190


Epoch 7/10:   0%|          | 0/251 [00:00<?, ?it/s]

[Epoch 7] Train Loss: 0.0416, Acc: 0.9860 | Val Loss: 0.2922, Acc: 0.9045


Epoch 8/10:   0%|          | 0/251 [00:00<?, ?it/s]

[Epoch 8] Train Loss: 0.0314, Acc: 0.9896 | Val Loss: 0.3221, Acc: 0.8992


Epoch 9/10:   0%|          | 0/251 [00:00<?, ?it/s]

[Epoch 9] Train Loss: 0.0243, Acc: 0.9924 | Val Loss: 0.2801, Acc: 0.9330


Epoch 10/10:   0%|          | 0/251 [00:00<?, ?it/s]

[Epoch 10] Train Loss: 0.0309, Acc: 0.9889 | Val Loss: 0.2986, Acc: 0.9255


  model.load_state_dict(torch.load(ckpt_path))


🎯 Test Loss: 0.1621 | Test Accuracy: 0.9418

📊 Evaluating all epoch checkpoints on test set:


  model.load_state_dict(torch.load(ckpt_file))


📁 Epoch 01 | Test Loss: 0.1781 | Test Accuracy: 0.9331
📁 Epoch 02 | Test Loss: 0.1664 | Test Accuracy: 0.9424
📁 Epoch 03 | Test Loss: 0.1621 | Test Accuracy: 0.9418
📁 Epoch 04 | Test Loss: 0.2043 | Test Accuracy: 0.9255
📁 Epoch 05 | Test Loss: 0.2699 | Test Accuracy: 0.8958
📁 Epoch 06 | Test Loss: 0.2317 | Test Accuracy: 0.9348
📁 Epoch 07 | Test Loss: 0.2387 | Test Accuracy: 0.9185
📁 Epoch 08 | Test Loss: 0.2975 | Test Accuracy: 0.9092
📁 Epoch 09 | Test Loss: 0.2634 | Test Accuracy: 0.9383
📁 Epoch 10 | Test Loss: 0.2613 | Test Accuracy: 0.9331


In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import vit_b_16, ViT_B_16_Weights
from tqdm.notebook import tqdm

# —— 1. 忽略隐藏目录 —— #
class FilteredImageFolder(ImageFolder):
    def find_classes(self, directory):
        classes = [d.name for d in os.scandir(directory) if d.is_dir() and not d.name.startswith('.')]
        classes.sort()
        class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
        return classes, class_to_idx

# —— 2. 配置 —— #
train_dir = "/root/autodl-fs/processed/processed/train"
val_dir   = "/root/autodl-fs/processed/processed/val"
test_dir  = "/root/autodl-fs/processed/processed/test"
ckpt_path     = "/root/autodl-fs/best_vit.pth"
batch_size    = 32
num_epochs    = 10
learning_rate = 5e-5  # ViT 通常使用较小的学习率
device        = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# —— 3. 使用 ViT 的官方预处理 —— #
weights = ViT_B_16_Weights.DEFAULT
transform = weights.transforms()

# —— 4. 加载数据 —— #
train_dataset = FilteredImageFolder(root=train_dir, transform=transform)
test_dataset  = FilteredImageFolder(root=test_dir,  transform=transform)

if os.path.exists(val_dir):
    val_dataset = FilteredImageFolder(root=val_dir, transform=transform)
else:
    val_len = int(len(train_dataset) * 0.15)
    train_len = len(train_dataset) - val_len
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_len, val_len])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=4)

# —— 5. 模型定义 —— #
model = vit_b_16(weights=weights)
num_classes = len(train_dataset.dataset.classes if hasattr(train_dataset, 'dataset') else train_dataset.classes)
model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# —— 6. 训练 + 验证 + 保存最佳模型 —— #
best_val_acc = 0.0
for epoch in range(num_epochs):
    model.train()
    train_loss, train_acc = 0.0, 0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)
        train_acc += (out.argmax(1) == y).sum().item()
    
    train_loss /= len(train_loader.dataset)
    train_acc  /= len(train_loader.dataset)

    # 验证
    model.eval()
    val_loss, val_acc = 0.0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            val_loss += loss.item() * x.size(0)
            val_acc  += (out.argmax(1) == y).sum().item()
    val_loss /= len(val_loader.dataset)
    val_acc  /= len(val_loader.dataset)

    print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        torch.save(model.state_dict(), ckpt_path)
        print("✅ Saved Best Model!")

    # 每个 epoch 保存一次模型
    epoch_ckpt_path = f"/root/autodl-fs/ckpt_vit/epoch_{epoch+1}.pth"
    os.makedirs(os.path.dirname(epoch_ckpt_path), exist_ok=True)
    torch.save(model.state_dict(), epoch_ckpt_path)

# —— 7. 加载并评估测试集 —— #
model.load_state_dict(torch.load(ckpt_path))
model.eval()
test_loss, test_acc, total = 0.0, 0, 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = criterion(out, y)
        test_loss += loss.item() * x.size(0)
        test_acc  += (out.argmax(1) == y).sum().item()
        total += y.size(0)

test_loss /= total
test_acc  /= total
print(f"🎯 Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}")

# —— 8. 所有 epoch 的模型推理 —— #
print("\n📊 Evaluating all epoch checkpoints on test set:")
epoch_results = []

for e in range(1, num_epochs + 1):
    ckpt_file = f"/root/autodl-fs/ckpt_vit/epoch_{e}.pth"
    if not os.path.exists(ckpt_file):
        print(f"❌ Epoch {e} model not found.")
        continue

    model.load_state_dict(torch.load(ckpt_file))
    model.eval()

    test_loss, test_acc, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            test_loss += loss.item() * x.size(0)
            test_acc  += (out.argmax(1) == y).sum().item()
            total += y.size(0)

    test_loss /= total
    test_acc  /= total
    epoch_results.append((e, test_loss, test_acc))
    print(f"📁 Epoch {e:02d} | Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}")


Epoch 1/10:   0%|          | 0/287 [00:00<?, ?it/s]

[Epoch 1] Train Loss: 0.2267, Acc: 0.9080 | Val Loss: 0.1853, Acc: 0.9238
✅ Saved Best Model!


Epoch 2/10:   0%|          | 0/287 [00:00<?, ?it/s]

[Epoch 2] Train Loss: 0.1535, Acc: 0.9410 | Val Loss: 0.1820, Acc: 0.9276
✅ Saved Best Model!


Epoch 3/10:   0%|          | 0/287 [00:00<?, ?it/s]

[Epoch 3] Train Loss: 0.1075, Acc: 0.9602 | Val Loss: 0.1763, Acc: 0.9350
✅ Saved Best Model!


Epoch 4/10:   0%|          | 0/287 [00:00<?, ?it/s]

[Epoch 4] Train Loss: 0.0752, Acc: 0.9704 | Val Loss: 0.2333, Acc: 0.9251


Epoch 5/10:   0%|          | 0/287 [00:00<?, ?it/s]

[Epoch 5] Train Loss: 0.0609, Acc: 0.9786 | Val Loss: 0.2206, Acc: 0.9245


Epoch 6/10:   0%|          | 0/287 [00:00<?, ?it/s]

[Epoch 6] Train Loss: 0.0494, Acc: 0.9825 | Val Loss: 0.2702, Acc: 0.9294


Epoch 7/10:   0%|          | 0/287 [00:00<?, ?it/s]

[Epoch 7] Train Loss: 0.0335, Acc: 0.9883 | Val Loss: 0.2822, Acc: 0.9269


Epoch 8/10:   0%|          | 0/287 [00:00<?, ?it/s]

[Epoch 8] Train Loss: 0.0335, Acc: 0.9884 | Val Loss: 0.3024, Acc: 0.9288


Epoch 9/10:   0%|          | 0/287 [00:00<?, ?it/s]

[Epoch 9] Train Loss: 0.0213, Acc: 0.9934 | Val Loss: 0.2894, Acc: 0.9276


Epoch 10/10:   0%|          | 0/287 [00:00<?, ?it/s]

[Epoch 10] Train Loss: 0.0185, Acc: 0.9927 | Val Loss: 0.3948, Acc: 0.9238


  model.load_state_dict(torch.load(ckpt_path))


🎯 Test Loss: 0.2742 | Test Accuracy: 0.9136

📊 Evaluating all epoch checkpoints on test set:


  model.load_state_dict(torch.load(ckpt_file))


📁 Epoch 01 | Test Loss: 0.2411 | Test Accuracy: 0.9004
📁 Epoch 02 | Test Loss: 0.2298 | Test Accuracy: 0.9107
📁 Epoch 03 | Test Loss: 0.2742 | Test Accuracy: 0.9136
📁 Epoch 04 | Test Loss: 0.3461 | Test Accuracy: 0.9078
📁 Epoch 05 | Test Loss: 0.2891 | Test Accuracy: 0.9019
📁 Epoch 06 | Test Loss: 0.5391 | Test Accuracy: 0.8975
📁 Epoch 07 | Test Loss: 0.4414 | Test Accuracy: 0.9034
📁 Epoch 08 | Test Loss: 0.4667 | Test Accuracy: 0.9136
📁 Epoch 09 | Test Loss: 0.4153 | Test Accuracy: 0.9048
📁 Epoch 10 | Test Loss: 0.5611 | Test Accuracy: 0.8990


In [6]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import vit_b_16, ViT_B_16_Weights
from tqdm.notebook import tqdm

# —— 1. 忽略隐藏目录 —— #
class FilteredImageFolder(ImageFolder):
    def find_classes(self, directory):
        classes = [d.name for d in os.scandir(directory) if d.is_dir() and not d.name.startswith('.')]
        classes.sort()
        class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
        return classes, class_to_idx

# —— 2. 配置 —— #
train_dir = "/root/autodl-fs/generate/test"
val_dir   = "/root/autodl-fs/generate/val"
test_dir  = "/root/autodl-fs/generate/test"
ckpt_path     = "/root/autodl-fs/best_vit.pth"
batch_size    = 32
num_epochs    = 10
learning_rate = 5e-5  # ViT 通常使用较小的学习率
device        = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# —— 3. 使用 ViT 的官方预处理 —— #
weights = ViT_B_16_Weights.DEFAULT
transform = weights.transforms()

# —— 4. 加载数据 —— #
train_dataset = FilteredImageFolder(root=train_dir, transform=transform)
test_dataset  = FilteredImageFolder(root=test_dir,  transform=transform)

if os.path.exists(val_dir):
    val_dataset = FilteredImageFolder(root=val_dir, transform=transform)
else:
    val_len = int(len(train_dataset) * 0.15)
    train_len = len(train_dataset) - val_len
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_len, val_len])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=4)

# —— 5. 模型定义 —— #
model = vit_b_16(weights=weights)
num_classes = len(train_dataset.dataset.classes if hasattr(train_dataset, 'dataset') else train_dataset.classes)
model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# —— 6. 训练 + 验证 + 保存最佳模型 —— #
best_val_acc = 0.0
for epoch in range(num_epochs):
    model.train()
    train_loss, train_acc = 0.0, 0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)
        train_acc += (out.argmax(1) == y).sum().item()
    
    train_loss /= len(train_loader.dataset)
    train_acc  /= len(train_loader.dataset)

    # 验证
    model.eval()
    val_loss, val_acc = 0.0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            val_loss += loss.item() * x.size(0)
            val_acc  += (out.argmax(1) == y).sum().item()
    val_loss /= len(val_loader.dataset)
    val_acc  /= len(val_loader.dataset)

    print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        torch.save(model.state_dict(), ckpt_path)
        print("✅ Saved Best Model!")

    # 每个 epoch 保存一次模型
    epoch_ckpt_path = f"/root/autodl-fs/ckpt_vit/epoch_{epoch+1}.pth"
    os.makedirs(os.path.dirname(epoch_ckpt_path), exist_ok=True)
    torch.save(model.state_dict(), epoch_ckpt_path)

# —— 7. 加载并评估测试集 —— #
model.load_state_dict(torch.load(ckpt_path))
model.eval()
test_loss, test_acc, total = 0.0, 0, 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = criterion(out, y)
        test_loss += loss.item() * x.size(0)
        test_acc  += (out.argmax(1) == y).sum().item()
        total += y.size(0)

test_loss /= total
test_acc  /= total
print(f"🎯 Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}")

# —— 8. 所有 epoch 的模型推理 —— #
print("\n📊 Evaluating all epoch checkpoints on test set:")
epoch_results = []

for e in range(1, num_epochs + 1):
    ckpt_file = f"/root/autodl-fs/ckpt_vit/epoch_{e}.pth"
    if not os.path.exists(ckpt_file):
        print(f"❌ Epoch {e} model not found.")
        continue

    model.load_state_dict(torch.load(ckpt_file))
    model.eval()

    test_loss, test_acc, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            test_loss += loss.item() * x.size(0)
            test_acc  += (out.argmax(1) == y).sum().item()
            total += y.size(0)

    test_loss /= total
    test_acc  /= total
    epoch_results.append((e, test_loss, test_acc))
    print(f"📁 Epoch {e:02d} | Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}")


Epoch 1/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 1] Train Loss: 0.4730, Acc: 0.8055 | Val Loss: 0.2508, Acc: 0.9020
✅ Saved Best Model!


Epoch 2/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 2] Train Loss: 0.2161, Acc: 0.9277 | Val Loss: 0.2269, Acc: 0.9118
✅ Saved Best Model!


Epoch 3/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 3] Train Loss: 0.1181, Acc: 0.9570 | Val Loss: 0.2425, Acc: 0.9216
✅ Saved Best Model!


Epoch 4/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 4] Train Loss: 0.0926, Acc: 0.9673 | Val Loss: 0.2842, Acc: 0.9412
✅ Saved Best Model!


Epoch 5/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 5] Train Loss: 0.0428, Acc: 0.9862 | Val Loss: 0.4285, Acc: 0.8627


Epoch 6/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 6] Train Loss: 0.0449, Acc: 0.9880 | Val Loss: 0.2778, Acc: 0.9314


Epoch 7/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 7] Train Loss: 0.0156, Acc: 0.9983 | Val Loss: 0.2899, Acc: 0.9412


Epoch 8/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 8] Train Loss: 0.0040, Acc: 0.9983 | Val Loss: 0.3510, Acc: 0.9314


Epoch 9/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 9] Train Loss: 0.0226, Acc: 0.9966 | Val Loss: 0.2795, Acc: 0.9510
✅ Saved Best Model!


Epoch 10/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 10] Train Loss: 0.0044, Acc: 0.9983 | Val Loss: 0.2930, Acc: 0.9412


  model.load_state_dict(torch.load(ckpt_path))


🎯 Test Loss: 0.0475 | Test Accuracy: 0.9912

📊 Evaluating all epoch checkpoints on test set:


  model.load_state_dict(torch.load(ckpt_file))


📁 Epoch 01 | Test Loss: 0.2462 | Test Accuracy: 0.9004
📁 Epoch 02 | Test Loss: 0.1472 | Test Accuracy: 0.9531
📁 Epoch 03 | Test Loss: 0.1069 | Test Accuracy: 0.9678
📁 Epoch 04 | Test Loss: 0.0829 | Test Accuracy: 0.9766
📁 Epoch 05 | Test Loss: 0.0806 | Test Accuracy: 0.9780
📁 Epoch 06 | Test Loss: 0.0556 | Test Accuracy: 0.9883
📁 Epoch 07 | Test Loss: 0.0450 | Test Accuracy: 0.9912
📁 Epoch 08 | Test Loss: 0.0624 | Test Accuracy: 0.9883
📁 Epoch 09 | Test Loss: 0.0475 | Test Accuracy: 0.9912
📁 Epoch 10 | Test Loss: 0.0446 | Test Accuracy: 0.9912


In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import vit_b_16, ViT_B_16_Weights
from tqdm.notebook import tqdm

# —— 1. 忽略隐藏目录 —— #
class FilteredImageFolder(ImageFolder):
    def find_classes(self, directory):
        classes = [d.name for d in os.scandir(directory) if d.is_dir() and not d.name.startswith('.')]
        classes.sort()
        class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
        return classes, class_to_idx

# —— 2. 配置 —— #
train_dir = "/root/autodl-fs/generate_twice/test"
val_dir   = "/root/autodl-fs/generate_twice/val"
test_dir  = "/root/autodl-fs/generate_twice/test"
ckpt_path     = "/root/autodl-fs/best_vit.pth"
batch_size    = 32
num_epochs    = 10
learning_rate = 5e-5  # ViT 通常使用较小的学习率
device        = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# —— 3. 使用 ViT 的官方预处理 —— #
weights = ViT_B_16_Weights.DEFAULT
transform = weights.transforms()

# —— 4. 加载数据 —— #
train_dataset = FilteredImageFolder(root=train_dir, transform=transform)
test_dataset  = FilteredImageFolder(root=test_dir,  transform=transform)

if os.path.exists(val_dir):
    val_dataset = FilteredImageFolder(root=val_dir, transform=transform)
else:
    val_len = int(len(train_dataset) * 0.15)
    train_len = len(train_dataset) - val_len
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_len, val_len])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=4)

# —— 5. 模型定义 —— #
model = vit_b_16(weights=weights)
num_classes = len(train_dataset.dataset.classes if hasattr(train_dataset, 'dataset') else train_dataset.classes)
model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# —— 6. 训练 + 验证 + 保存最佳模型 —— #
best_val_acc = 0.0
for epoch in range(num_epochs):
    model.train()
    train_loss, train_acc = 0.0, 0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)
        train_acc += (out.argmax(1) == y).sum().item()
    
    train_loss /= len(train_loader.dataset)
    train_acc  /= len(train_loader.dataset)

    # 验证
    model.eval()
    val_loss, val_acc = 0.0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            val_loss += loss.item() * x.size(0)
            val_acc  += (out.argmax(1) == y).sum().item()
    val_loss /= len(val_loader.dataset)
    val_acc  /= len(val_loader.dataset)

    print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        torch.save(model.state_dict(), ckpt_path)
        print("✅ Saved Best Model!")

    # 每个 epoch 保存一次模型
    epoch_ckpt_path = f"/root/autodl-fs/ckpt_vit/epoch_{epoch+1}.pth"
    os.makedirs(os.path.dirname(epoch_ckpt_path), exist_ok=True)
    torch.save(model.state_dict(), epoch_ckpt_path)

# —— 7. 加载并评估测试集 —— #
model.load_state_dict(torch.load(ckpt_path))
model.eval()
test_loss, test_acc, total = 0.0, 0, 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = criterion(out, y)
        test_loss += loss.item() * x.size(0)
        test_acc  += (out.argmax(1) == y).sum().item()
        total += y.size(0)

test_loss /= total
test_acc  /= total
print(f"🎯 Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}")

# —— 8. 所有 epoch 的模型推理 —— #
print("\n📊 Evaluating all epoch checkpoints on test set:")
epoch_results = []

for e in range(1, num_epochs + 1):
    ckpt_file = f"/root/autodl-fs/ckpt_vit/epoch_{e}.pth"
    if not os.path.exists(ckpt_file):
        print(f"❌ Epoch {e} model not found.")
        continue

    model.load_state_dict(torch.load(ckpt_file))
    model.eval()

    test_loss, test_acc, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            test_loss += loss.item() * x.size(0)
            test_acc  += (out.argmax(1) == y).sum().item()
            total += y.size(0)

    test_loss /= total
    test_acc  /= total
    epoch_results.append((e, test_loss, test_acc))
    print(f"📁 Epoch {e:02d} | Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}")


Epoch 1/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 1] Train Loss: 0.4655, Acc: 0.7873 | Val Loss: 0.3960, Acc: 0.8529
✅ Saved Best Model!


Epoch 2/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 2] Train Loss: 0.2537, Acc: 0.8885 | Val Loss: 0.4012, Acc: 0.8431


Epoch 3/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 3] Train Loss: 0.1300, Acc: 0.9537 | Val Loss: 0.4311, Acc: 0.8529


Epoch 4/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 4] Train Loss: 0.0524, Acc: 0.9828 | Val Loss: 0.7007, Acc: 0.8333


Epoch 5/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 5] Train Loss: 0.0618, Acc: 0.9828 | Val Loss: 0.4629, Acc: 0.8529


Epoch 6/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 6] Train Loss: 0.0506, Acc: 0.9828 | Val Loss: 0.3732, Acc: 0.8824
✅ Saved Best Model!


Epoch 7/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 7] Train Loss: 0.0378, Acc: 0.9897 | Val Loss: 0.4176, Acc: 0.8824


Epoch 8/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 8] Train Loss: 0.0139, Acc: 0.9966 | Val Loss: 0.6797, Acc: 0.8529


Epoch 9/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 9] Train Loss: 0.0025, Acc: 1.0000 | Val Loss: 0.5838, Acc: 0.8824


Epoch 10/10:   0%|          | 0/19 [00:00<?, ?it/s]

[Epoch 10] Train Loss: 0.0007, Acc: 1.0000 | Val Loss: 0.7245, Acc: 0.8431


  model.load_state_dict(torch.load(ckpt_path))


🎯 Test Loss: 0.0654 | Test Accuracy: 0.9825

📊 Evaluating all epoch checkpoints on test set:


  model.load_state_dict(torch.load(ckpt_file))


📁 Epoch 01 | Test Loss: 0.2961 | Test Accuracy: 0.8745
📁 Epoch 02 | Test Loss: 0.1996 | Test Accuracy: 0.9241
📁 Epoch 03 | Test Loss: 0.1119 | Test Accuracy: 0.9635
📁 Epoch 04 | Test Loss: 0.1399 | Test Accuracy: 0.9606
📁 Epoch 05 | Test Loss: 0.0816 | Test Accuracy: 0.9766
📁 Epoch 06 | Test Loss: 0.0654 | Test Accuracy: 0.9825
📁 Epoch 07 | Test Loss: 0.0777 | Test Accuracy: 0.9781
📁 Epoch 08 | Test Loss: 0.1053 | Test Accuracy: 0.9766
📁 Epoch 09 | Test Loss: 0.0876 | Test Accuracy: 0.9825
📁 Epoch 10 | Test Loss: 0.1083 | Test Accuracy: 0.9766
