In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset
import torch_optimizer as optim  # 提供 Ranger 优化器

import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# 超参数
batch_size = 128
lr = 1e-3
num_epochs = 20
model_name = 'resnet50'  # 可选：'simplecnn' 或 'resnet50'

# --- Setup MPS device ---
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')

In [5]:
# 可选模型：SimpleCNN 或 ResNet50
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)


def get_model(name='resnet50', num_classes=10):
    if name.lower() == 'resnet50':
        model = torchvision.models.resnet50(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        model = SimpleCNN(num_classes)
    return model







In [6]:
def train_epoch(model, device, loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(loader.dataset)


def evaluate(model, device, loader, criterion):
    model.eval()
    correct, total_loss = 0, 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(targets).sum().item()
    return total_loss / len(loader.dataset), correct / len(loader.dataset)

In [7]:
def split_train_val_index(full_train, train_ratio=0.8):
    train_indices, val_indices = random_split(
        list(range(len(full_train))),
        [int(len(full_train) * train_ratio), len(full_train) - int(len(full_train) * train_ratio)]
    )
    return train_indices, val_indices

In [8]:
def get_data_augmentation(train_indices, val_indices):
    # 数据增强与标准化
    transform_train = transforms.Compose([
        transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.BICUBIC),

        transforms.RandomCrop(120, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2430, 0.2610))
    ])
    transform_test = transforms.Compose([
        transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.BICUBIC),

        
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2430, 0.2610))
    ])

    train_set = Subset(
        torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train),
        train_indices.indices if hasattr(train_indices, 'indices') else train_indices
    )
    val_set = Subset(
        torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_test),
        val_indices.indices if hasattr(val_indices, 'indices') else val_indices
    )

    test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    
    return train_set, val_set, test_set

In [None]:
# 加载训练集并拆分为 train/val
full_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
train_indices, val_indices = split_train_val_index(full_train, train_ratio=0.8)

# get train/val/test sets after data pre-processing and augmentation
train_set, val_set, test_set = get_data_augmentation(train_indices, val_indices)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:


# 模型、损失、优化器
# model = get_model(model_name).to(device)

# Use ResNet50 as the model
num_classes = 10
model = torchvision.models.resnet50(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# model = SimpleCNN(num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Ranger(model.parameters(), lr=lr)

# get the information of the optimizer
print(optimizer)

# 训练与验证
best_val_acc = 0.0
for epoch in range(num_epochs):
    train_loss = train_epoch(model, device, train_loader, criterion, optimizer)
    val_loss, val_acc = evaluate(model, device, val_loader, criterion)
    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc * 100:.2f}%")
    # 保存最优模型
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), './model/best_model_cnn_cifar.pth')





Files already downloaded and verified
Files already downloaded and verified




Ranger (
Parameter Group 0
    N_sma_threshhold: 5
    alpha: 0.5
    betas: (0.95, 0.999)
    eps: 1e-05
    k: 6
    lr: 0.001
    step_counter: 0
    weight_decay: 0
)


	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value = 1) (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/python_arg_parser.cpp:1581.)
  exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)


Epoch 00 | Train Loss: 1.9154 | Val Loss: 1.7450 | Val Acc: 37.40%
Epoch 01 | Train Loss: 1.4780 | Val Loss: 1.3336 | Val Acc: 51.47%
Epoch 02 | Train Loss: 1.2491 | Val Loss: 1.2772 | Val Acc: 54.82%
Epoch 03 | Train Loss: 1.0720 | Val Loss: 1.1537 | Val Acc: 59.79%
Epoch 04 | Train Loss: 0.9241 | Val Loss: 0.8925 | Val Acc: 68.89%
Epoch 05 | Train Loss: 0.8250 | Val Loss: 0.8123 | Val Acc: 71.11%
Epoch 06 | Train Loss: 0.7251 | Val Loss: 0.6800 | Val Acc: 75.66%
Epoch 07 | Train Loss: 0.6416 | Val Loss: 0.6667 | Val Acc: 77.08%
Epoch 08 | Train Loss: 0.5728 | Val Loss: 0.6540 | Val Acc: 77.21%
Epoch 09 | Train Loss: 0.5313 | Val Loss: 0.7083 | Val Acc: 75.52%
Epoch 10 | Train Loss: 0.4791 | Val Loss: 0.8700 | Val Acc: 71.42%
Epoch 11 | Train Loss: 0.4458 | Val Loss: 0.5106 | Val Acc: 82.61%
Epoch 12 | Train Loss: 0.4163 | Val Loss: 0.5449 | Val Acc: 81.77%
Epoch 13 | Train Loss: 0.3813 | Val Loss: 0.5040 | Val Acc: 82.97%
Epoch 14 | Train Loss: 0.3624 | Val Loss: 0.5453 | Val Acc: 81

  model.load_state_dict(torch.load('best_model.pth'))



Final Test Loss: 0.4944 | Final Test Acc: 84.23%


In [15]:
# 测试集评估
model.load_state_dict(torch.load('./model/best_model_cnn_cifar.pth'))
test_loss, test_acc = evaluate(model, device, test_loader, criterion)
print(f"\nFinal Test Loss: {test_loss:.4f} | Final Test Acc: {test_acc * 100:.2f}%")

  model.load_state_dict(torch.load('./model/best_model_cnn_cifar.pth'))



Final Test Loss: 0.4944 | Final Test Acc: 84.23%
