<a href="https://colab.research.google.com/github/Carba6/deeplearning/blob/main/Wide_Resnet_28_10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from wide_resnet import Wide_ResNet

def main():
    """Train a Wide ResNet (depth 28, widen factor 10) on CIFAR-10.

    Builds the CIFAR-10 train/test loaders, trains with SGD and a step-wise
    learning-rate schedule, prints the mean training loss and the test
    accuracy after every epoch, and finally writes the model's state dict to
    ``wide_resnet_final.pth`` in the current working directory.
    """
    # Hyperparameters.
    batch_size = 128
    learning_rate = 0.1
    epochs = 200
    weight_decay = 0.0005
    momentum = 0.9
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data preprocessing.
    # Random augmentation is applied to the training split only; the test
    # split gets the deterministic part of the pipeline so that the reported
    # accuracy is reproducible and not perturbed by random crops/flips.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Load the CIFAR-10 dataset (downloaded into ./data when missing).
    train_dataset = datasets.CIFAR10(
        root="./data", train=True, transform=train_transform, download=True
    )
    test_dataset = datasets.CIFAR10(
        root="./data", train=False, transform=test_transform, download=True
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
    )

    # Create the Wide ResNet model.
    model = Wide_ResNet(
        depth=28, widen_factor=10, num_classes=10, dropout_rate=0.0
    ).to(device)

    # Loss function and optimizer. The optimizer reads the hyperparameters
    # declared above instead of repeating them as literals, so changing a
    # value at the top of the function actually takes effect.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
    )

    # Learning-rate scheduler: multiply the LR by 0.1 at epochs 60 and 120.
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[60, 120], gamma=0.1
    )

    def train_epoch(model, dataloader, criterion, optimizer, device):
        """Run one optimization pass over ``dataloader``.

        Returns:
            The mean of the per-batch loss values for this epoch.
        """
        model.train()
        running_loss = 0.0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        return running_loss / len(dataloader)

    def test(model, dataloader, device):
        """Evaluate ``model`` on ``dataloader`` without tracking gradients.

        Returns:
            Accuracy as a fraction in [0, 1]: correct predictions / samples.
        """
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                # The predicted class is the index of the largest output
                # along dimension 1.
                predicted = outputs.argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return correct / total

    # Training loop: train, evaluate, then advance the LR schedule.
    for epoch in range(1, epochs + 1):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        test_accuracy = test(model, test_loader, device)
        scheduler.step()
        print(f"Epoch: {epoch}, Loss: {train_loss:.4f}, Test Accuracy: {test_accuracy * 100:.2f}%")

    # Save the final model weights.
    torch.save(model.state_dict(), "wide_resnet_final.pth")

# Script entry point: start training only when this file is run directly.
if __name__ == "__main__":
    main()


Files already downloaded and verified
Files already downloaded and verified
| Wide-Resnet 28x10
Is GPU available? True
Current device: 0
Epoch: 1, Loss: 1.7484, Test Accuracy: 24.22%
Epoch: 2, Loss: 1.2701, Test Accuracy: 52.82%
Epoch: 3, Loss: 1.0027, Test Accuracy: 60.83%
Epoch: 4, Loss: 0.8211, Test Accuracy: 65.47%
Epoch: 5, Loss: 0.6848, Test Accuracy: 68.19%
Epoch: 6, Loss: 0.5792, Test Accuracy: 75.07%
Epoch: 7, Loss: 0.5003, Test Accuracy: 74.53%
Epoch: 8, Loss: 0.4389, Test Accuracy: 78.64%
Epoch: 9, Loss: 0.3989, Test Accuracy: 81.48%
Epoch: 10, Loss: 0.3661, Test Accuracy: 73.05%
Epoch: 11, Loss: 0.3283, Test Accuracy: 82.54%
Epoch: 12, Loss: 0.2996, Test Accuracy: 81.67%
Epoch: 13, Loss: 0.2810, Test Accuracy: 79.89%
Epoch: 14, Loss: 0.2587, Test Accuracy: 85.04%
Epoch: 15, Loss: 0.2415, Test Accuracy: 83.43%
Epoch: 16, Loss: 0.2307, Test Accuracy: 83.15%
Epoch: 17, Loss: 0.2056, Test Accuracy: 83.05%
Epoch: 18, Loss: 0.1943, Test Accuracy: 83.53%
Epoch: 19, Loss: 0.1917, T