In [None]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import timm
import wandb

def evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to evaluation mode and gradients are disabled for
    the duration of the pass. The caller is responsible for calling
    ``model.train()`` again before resuming training.

    Args:
        model: Callable module mapping a batch of inputs to per-class scores
            of shape ``(batch, num_classes)``.
        loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a float in ``[0, 1]``;
        ``0.0`` when ``loader`` yields no samples (avoids a division by zero).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    return correct / total if total else 0.0


def main():
    """Fine-tune a ViT classifier, track it with wandb, and save the weights.

    Steps: start a wandb run, build the train/val/test ``ImageFolder`` loaders,
    load local weights into the timm model, replace the classification head,
    train for ``config.epochs`` epochs (validating after each one), evaluate
    once on the test split, and save the state dict to ``vit_model.pth``.

    Raises:
        ValueError: If the training folder contains more classes than the
            classification head has outputs.
    """
    num_classes = 15  # size of the new classification head

    # The run is used as a context manager so it is always closed, even when
    # training fails part-way (an unfinished run otherwise lingers in the kernel).
    with wandb.init(project="ViT-aerial-classification", config={
        "epochs": 20,
        "learning_rate": 1e-4,
        "batch_size": 32,
        "model": "vit_base_patch16_224",
        "dataset": "your-dataset"
    }) as run:
        config = run.config

        # Use the GPU when one is available.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # Convert PIL images to normalized tensors at the 224x224 input size.
        # The mean/std values are the ones commonly used with ImageNet-pretrained models.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        train_dataset = datasets.ImageFolder(root='datasets/train', transform=transform)
        val_dataset   = datasets.ImageFolder(root='datasets/val', transform=transform)
        test_dataset  = datasets.ImageFolder(root='datasets/test', transform=transform)

        # Fail fast with a clear message: labels >= num_classes would otherwise
        # surface later as an obscure error inside the loss computation.
        if len(train_dataset.classes) > num_classes:
            raise ValueError(
                f"datasets/train has {len(train_dataset.classes)} classes, "
                f"but the classification head only has {num_classes} outputs"
            )

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
        val_loader   = torch.utils.data.DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
        test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)

        # Build the architecture named in the config without downloading weights;
        # the weights come from a local checkpoint instead.
        model = timm.create_model(config.model, pretrained=False)
        state_dict = torch.load('pytorch_model.bin', map_location=device, weights_only=True)
        # strict=False tolerates key mismatches, so report them instead of
        # silently training from (partially) random weights.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"Warning: checkpoint mismatch - "
                  f"{len(load_result.missing_keys)} missing key(s), "
                  f"{len(load_result.unexpected_keys)} unexpected key(s)")

        # Replace the original classification head with a fresh linear layer.
        num_features = model.head.in_features
        model.head = nn.Linear(num_features, num_classes)
        model = model.to(device)

        # Loss function and optimizer.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

        num_epochs = config.epochs

        for epoch in range(num_epochs):
            epoch_start = time.time()
            model.train()  # evaluate_accuracy() leaves the model in eval mode
            running_loss = 0.0

            # Training phase.
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch loss is a per-sample mean.
                running_loss += loss.item() * inputs.size(0)

            epoch_loss = running_loss / len(train_dataset)
            epoch_duration = time.time() - epoch_start

            # Estimate the remaining time from this epoch's training duration
            # (the validation pass is not included in the measurement).
            remaining_epochs = num_epochs - (epoch + 1)
            total_remaining_time = remaining_epochs * epoch_duration
            rem_minutes = int(total_remaining_time // 60)
            rem_seconds = int(total_remaining_time % 60)

            print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Time: {epoch_duration:.2f} sec')
            print(f'Estimated remaining time: {rem_minutes} min {rem_seconds} sec')

            # Validation phase.
            val_accuracy = evaluate_accuracy(model, val_loader, device)
            print(f'Epoch {epoch+1}/{num_epochs}, Val Accuracy: {val_accuracy:.4f}')

            # Log the per-epoch metrics to wandb.
            wandb.log({
                "epoch": epoch+1,
                "training_loss": epoch_loss,
                "val_accuracy": val_accuracy,

                "epoch_duration_sec": epoch_duration,
                "total_remaining_time_sec": total_remaining_time
            })
            print()  # blank line between epochs for readability

        # Test phase: runs once after training. It uses num_epochs rather than
        # the loop variable so it also works when no epoch was executed.
        test_accuracy = evaluate_accuracy(model, test_loader, device)
        print(f'Epoch {num_epochs}/{num_epochs}, Test Accuracy: {test_accuracy:.4f}')
        wandb.log({"test_accuracy": test_accuracy})

        # Save the model weights locally and attach them to the wandb run.
        torch.save(model.state_dict(), "vit_model.pth")
        wandb.save("vit_model.pth")

# Run the training pipeline when this cell/script is executed as the main module.
# NOTE(review): the guard presumably also keeps the DataLoader worker processes
# (num_workers=4) from re-running main() on platforms that start workers by
# re-importing the main module - confirm for the target platform.
if __name__ == '__main__':
    main()

In [None]:
# After training, confirm that the weights file was written to disk.
# main() saves the state dict to "vit_model.pth", so that is the path checked here.
import os
print(os.path.exists('vit_model.pth'))