In [12]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset
import torch_optimizer as optim  # 提供 Ranger 优化器
from torch.amp import autocast, GradScaler

import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# 超参数
batch_size = 128
lr = 1e-3
num_epochs = 20
model_name = 'resnet50'  # 可选：'simplecnn' 或 'resnet50'

# --- Setup MPS device ---
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')

In [4]:
import torch.nn as nn
import torch

class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        # 1×1 降维
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1   = nn.BatchNorm2d(planes)
        # 3×3 空洞感受野
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(planes)
        # 1×1 恢复维度
        self.conv3 = nn.Conv2d(planes, planes * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3   = nn.BatchNorm2d(planes * self.expansion)
        # 如果跨步或通道数变化，用下采样调整捷径
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return self.relu(out)

# def make_layer(in_planes, planes, blocks, stride=1):
#     downsample = None
#     out_planes = planes * Bottleneck.expansion
#     if stride != 1 or in_planes != out_planes:
#         # 用 1×1 卷积来匹配维度 & 跨步下采样
#         downsample = nn.Sequential(
#             nn.Conv2d(in_planes, out_planes,
#                       kernel_size=1, stride=stride, bias=False),
#             nn.BatchNorm2d(out_planes),
#         )
#     layers = [Bottleneck(in_planes, planes, stride, downsample)]
#     for _ in range(1, blocks):
#         layers.append(Bottleneck(out_planes, planes))
#     return nn.Sequential(*layers)

class SimpleResNetLike(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()

        # 初始层（可以改成 3×3 conv + BN + ReLU，去掉大核7×7）
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # 四个 stage，block 数量可按 ResNet-50 ([3,4,6,3]) 或简化
        self.layer1 = self.make_layer( 64,  64, blocks=3, stride=1)  # 输出 256 通道
        self.layer2 = self.make_layer(256, 128, blocks=4, stride=2)  # 输出 512 通道
        self.layer3 = self.make_layer(512, 256, blocks=6, stride=2)  # 输出1024 通道
        self.layer4 = self.make_layer(1024,512, blocks=3, stride=2)  # 输出2048 通道

        # 全局池化 + 全连接
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

        self._init_weights()

    def make_layer(self, in_planes, planes, blocks, stride=1):
        downsample = None
        out_planes = planes * Bottleneck.expansion
        if stride != 1 or in_planes != out_planes:
            # 用 1×1 卷积来匹配维度 & 跨步下采样
            downsample = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                        kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        layers = [Bottleneck(in_planes, planes, stride, downsample)]
        for _ in range(1, blocks):
            layers.append(Bottleneck(out_planes, planes))
        return nn.Sequential(*layers)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He 初始化
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # 全连接层也可以做类似初始化
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # BN 的 weight 初始化为 1，bias 为 0
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

In [5]:
# 可选模型：SimpleCNN 或 ResNet50
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 32 * 32, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)


def get_model(name='resnet50', num_classes=10):
    if name.lower() == 'resnet50':
        model = torchvision.models.resnet50(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        model = SimpleResNetLike(num_classes)
    return model







In [14]:
def train_epoch(model, device, loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(loader.dataset)


def evaluate(model, device, loader, criterion):
    model.eval()
    correct, total_loss = 0, 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(targets).sum().item()
    return total_loss / len(loader.dataset), correct / len(loader.dataset)

In [7]:
def split_train_val_index(full_train, train_ratio=0.8):
    train_indices, val_indices = random_split(
        list(range(len(full_train))),
        [int(len(full_train) * train_ratio), len(full_train) - int(len(full_train) * train_ratio)]
    )
    return train_indices, val_indices

In [8]:
def get_data_augmentation(train_indices, val_indices):
    # 数据增强与标准化
    transform_train = transforms.Compose([
        transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.BICUBIC),

        transforms.RandomCrop(128, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2430, 0.2610))
    ])
    transform_test = transforms.Compose([
        transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.BICUBIC),

        
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2430, 0.2610))
    ])

    train_set = Subset(
        torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train),
        train_indices.indices if hasattr(train_indices, 'indices') else train_indices
    )
    val_set = Subset(
        torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_test),
        val_indices.indices if hasattr(val_indices, 'indices') else val_indices
    )

    test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    
    return train_set, val_set, test_set

In [9]:
# 加载训练集并拆分为 train/val
full_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
train_indices, val_indices = split_train_val_index(full_train, train_ratio=0.8)

# get train/val/test sets after data pre-processing and augmentation
train_set, val_set, test_set = get_data_augmentation(train_indices, val_indices)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [15]:


# 模型、损失、优化器
model = get_model("SimpleResNetLike").to(device)

# Use ResNet50 as the model
num_classes = 10
# model = torchvision.models.resnet50(pretrained=False)
# model.fc = nn.Linear(model.fc.in_features, num_classes)

# model = SimpleCNN(num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Ranger(model.parameters(), lr=lr)
scaler = GradScaler()  # 用于混合精度训练

# get the information of the optimizer
print(optimizer)

# 训练与验证
best_val_acc = 0.0
for epoch in range(num_epochs):
    train_loss = train_epoch(model, device, train_loader, criterion, optimizer)
    val_loss, val_acc = evaluate(model, device, val_loader, criterion)
    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc * 100:.2f}%")
    # 保存最优模型
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), './model/best_model_cnn_cifar.pth')





Ranger (
Parameter Group 0
    N_sma_threshhold: 5
    alpha: 0.5
    betas: (0.95, 0.999)
    eps: 1e-05
    k: 6
    lr: 0.001
    step_counter: 0
    weight_decay: 0
)
Epoch 00 | Train Loss: 1.9357 | Val Loss: 1.6175 | Val Acc: 41.58%
Epoch 01 | Train Loss: 1.5166 | Val Loss: 1.4149 | Val Acc: 48.95%
Epoch 02 | Train Loss: 1.2782 | Val Loss: 1.3064 | Val Acc: 53.91%
Epoch 03 | Train Loss: 1.0907 | Val Loss: 1.0505 | Val Acc: 62.52%
Epoch 04 | Train Loss: 0.9225 | Val Loss: 1.0745 | Val Acc: 62.75%
Epoch 05 | Train Loss: 0.8163 | Val Loss: 0.8176 | Val Acc: 71.36%
Epoch 06 | Train Loss: 0.7186 | Val Loss: 0.7963 | Val Acc: 72.60%
Epoch 07 | Train Loss: 0.6406 | Val Loss: 0.7573 | Val Acc: 74.21%
Epoch 08 | Train Loss: 0.5677 | Val Loss: 1.0246 | Val Acc: 67.82%
Epoch 09 | Train Loss: 0.5127 | Val Loss: 0.7489 | Val Acc: 75.75%
Epoch 10 | Train Loss: 0.4673 | Val Loss: 0.8663 | Val Acc: 72.51%
Epoch 11 | Train Loss: 0.4207 | Val Loss: 0.5818 | Val Acc: 80.23%
Epoch 12 | Train Loss: 0.

In [16]:
# 测试集评估
model.load_state_dict(torch.load('./model/best_model_cnn_cifar.pth'))
test_loss, test_acc = evaluate(model, device, test_loader, criterion)
print(f"\nFinal Test Loss: {test_loss:.4f} | Final Test Acc: {test_acc * 100:.2f}%")

  model.load_state_dict(torch.load('./model/best_model_cnn_cifar.pth'))



Final Test Loss: 0.5429 | Final Test Acc: 83.48%
