In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset
import torch_optimizer as optim  # 提供 Ranger 优化器

import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# 可选模型：SimpleCNN 或 ResNet50
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier.

    Two (3x3 conv -> ReLU -> 2x2 max-pool) stages turn a 3-channel image into a
    64-channel feature map. An adaptive average pool then resizes that map to
    8x8, so the fully connected head (sized for 64 * 8 * 8 features) works for
    any input resolution. For 32x32 inputs the two max-pools already yield 8x8,
    so the adaptive pool is an identity and results match the original design.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Without this, the Linear below only accepts 32x32 inputs (the
            # notebook trains at 120x120 and evaluates at 128x128). Appended at
            # the end and parameter-free, so existing state_dict keys
            # (features.0/3, classifier.1/4) are unchanged.
            nn.AdaptiveAvgPool2d((8, 8)),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Map a (batch, 3, H, W) image tensor to (batch, num_classes) logits."""
        x = self.features(x)
        return self.classifier(x)


def get_model(name='resnet50', num_classes=10):
    """Build an untrained classifier by name.

    Args:
        name: 'resnet50' for torchvision's ResNet-50 with its final fully
            connected layer replaced, or 'simplecnn' for SimpleCNN. Matching
            is case-insensitive.
        num_classes: number of output logits of the returned model.

    Returns:
        An nn.Module whose last layer is an nn.Linear with ``num_classes``
        outputs.

    Raises:
        ValueError: if ``name`` is not one of the supported model names
            (previously any unknown name, e.g. a typo, silently fell back to
            SimpleCNN).
    """
    key = name.lower()
    if key == 'resnet50':
        try:
            # torchvision >= 0.13 API; `pretrained` is deprecated there.
            model = torchvision.models.resnet50(weights=None)
        except TypeError:
            # Older torchvision releases only understand `pretrained`.
            model = torchvision.models.resnet50(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model
    if key == 'simplecnn':
        return SimpleCNN(num_classes)
    raise ValueError(f"unknown model name {name!r}; expected 'resnet50' or 'simplecnn'")







In [None]:
def train_epoch(model, device, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network to train; switched to training mode.
        device: torch device that each batch is moved to.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped once per batch.

    Returns:
        Mean per-sample loss over the samples actually yielded by ``loader``.
        This is correct even with a custom sampler or ``drop_last=True``, where
        ``len(loader.dataset)`` differs from the number of samples seen.
        Assumes ``criterion`` returns a batch mean (the default reduction of
        nn.CrossEntropyLoss).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    n_seen = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        batch_size = inputs.size(0)
        # Re-weight the batch-mean loss by batch size so the epoch average is
        # per sample even when the last batch is smaller.
        running_loss += loss.item() * batch_size
        n_seen += batch_size
    if n_seen == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return running_loss / n_seen


def evaluate(model, device, loader, criterion):
    """Compute mean loss and top-1 accuracy of ``model`` over ``loader``.

    Runs in eval mode under ``torch.no_grad()``.

    Args:
        model: network to evaluate; switched to eval mode.
        device: torch device that each batch is moved to.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``;
            assumed to return a batch mean.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually yielded by
        ``loader``. Accuracy is a fraction in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct, total_loss, n_seen = 0, 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            _, preds = outputs.max(1)
            correct += preds.eq(targets).sum().item()
            # Count what was actually iterated rather than len(loader.dataset),
            # which is wrong with samplers or drop_last.
            n_seen += batch_size
    if n_seen == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return total_loss / n_seen, correct / n_seen

In [None]:
# CIFAR-10 per-channel normalization statistics (values kept from the original notebook).
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2430, 0.2610)


def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def _build_transforms(resize=128, train_crop=120):
    """Build the (train, eval) transform pipelines.

    Both pipelines bicubic-resize to ``resize x resize``, convert to a tensor,
    and normalize with the CIFAR-10 statistics. Training additionally applies a
    4-pixel-padded random crop of ``train_crop`` and a random horizontal flip.
    """
    # NOTE(review): training sees 120x120 crops while evaluation sees 128x128.
    # ResNet-50 tolerates this, but confirm the resolution mismatch is intended.
    transform_train = transforms.Compose([
        transforms.Resize((resize, resize), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomCrop(train_crop, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ])
    transform_test = transforms.Compose([
        transforms.Resize((resize, resize), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ])
    return transform_train, transform_test


def _build_loaders(batch_size, transform_train, transform_test, n_val=5000,
                   seed=0, root='./data', num_workers=2):
    """Create the (train, val, test) DataLoaders for CIFAR-10.

    The CIFAR-10 training split is divided into a training subset and an
    ``n_val``-sample validation subset via a permutation seeded with ``seed``,
    so the split is identical across runs. The two subsets wrap separate
    dataset objects so that only the training subset is augmented.

    Raises:
        ValueError: if ``n_val`` does not leave both subsets non-empty.
    """
    train_base = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    val_base = torchvision.datasets.CIFAR10(root=root, train=True, download=False, transform=transform_test)

    n_total = len(train_base)
    if not 0 < n_val < n_total:
        raise ValueError(f"n_val must be in (0, {n_total}), got {n_val}")
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n_total, generator=generator).tolist()
    val_indices, train_indices = perm[:n_val], perm[n_val:]

    train_loader = DataLoader(Subset(train_base, train_indices), batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(Subset(val_base, val_indices), batch_size=batch_size,
                            shuffle=False, num_workers=num_workers)

    test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, test_loader


def main():
    """Train a CIFAR-10 classifier, checkpoint the best-validation weights, and
    report that checkpoint's test-set loss and accuracy."""
    # Hyperparameters
    batch_size = 128
    lr = 1e-3
    num_epochs = 20
    model_name = 'resnet50'  # options: 'simplecnn' or 'resnet50'
    n_val = 5000
    seed = 42
    checkpoint_path = 'best_model.pth'

    # Seed torch's RNG (model init, DataLoader shuffling) so runs are reproducible.
    torch.manual_seed(seed)
    device = _select_device()

    transform_train, transform_test = _build_transforms()
    train_loader, val_loader, test_loader = _build_loaders(
        batch_size, transform_train, transform_test, n_val=n_val, seed=seed)

    # Model, loss, optimizer
    model = get_model(model_name).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Ranger(model.parameters(), lr=lr)

    # Show the optimizer configuration.
    print(optimizer)

    # Train and validate. Starting from -inf (not 0.0) guarantees the first
    # epoch writes a checkpoint, so a stale best_model.pth from an earlier run
    # is never reloaded below.
    best_val_acc = float('-inf')
    for epoch in range(1, num_epochs + 1):
        train_loss = train_epoch(model, device, train_loader, criterion, optimizer)
        val_loss, val_acc = evaluate(model, device, val_loader, criterion)
        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc * 100:.2f}%")
        # Keep the best model seen so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

    # Evaluate the best checkpoint on the test set.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    test_loss, test_acc = evaluate(model, device, test_loader, criterion)
    print(f"\nFinal Test Loss: {test_loss:.4f} | Final Test Acc: {test_acc * 100:.2f}%")



In [None]:
if __name__ == '__main__':
    # Entry point: run the full train / validate / test pipeline defined in main().
    main()
