In [1]:
import os
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor
from torch.utils.data import DataLoader


In [9]:

# 数据路径（假设已解压到此目录）
data_dir = "./data/mnist-varres"

# 数据变换：保持 RGB 格式
transform_32 = Compose([Resize((32, 32)), ToTensor()])  # 输入保持 RGB
transform_48 = Compose([Resize((48, 48)), ToTensor()])
transform_64 = Compose([Resize((64, 64)), ToTensor()])

# 加载不同分辨率的数据
dataset_32 = ImageFolder(root=data_dir, transform=transform_32)
dataset_48 = ImageFolder(root=data_dir, transform=transform_48)
dataset_64 = ImageFolder(root=data_dir, transform=transform_64)

# 数据加载器
batch_size = 16
dataloader_32 = DataLoader(dataset_32, batch_size=batch_size, shuffle=True)
dataloader_48 = DataLoader(dataset_48, batch_size=batch_size, shuffle=True)
dataloader_64 = DataLoader(dataset_64, batch_size=batch_size, shuffle=True)

# 合并所有分辨率的数据加载器
dataloaders = [dataloader_32, dataloader_48, dataloader_64]


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VariableResolutionCNN(nn.Module):
    def __init__(self, output_channels=64):
        super(VariableResolutionCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # 输入通道数改为 3
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, output_channels, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.global_pool = nn.AdaptiveMaxPool2d((1, 1))  # 全局最大池化
        self.fc = nn.Linear(output_channels, 10)  # 输出类别数为 10

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.global_pool(x)  # 输出尺寸固定为 (batch_size, channels, 1, 1)
        x = x.view(x.size(0), -1)  # 展平为 (batch_size, channels)
        x = self.fc(x)
        return x


In [11]:
# 训练函数
def train_variable_resolution(model, dataloaders, optimizer, criterion, device, epochs=10):
    model.to(device)
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        for dataloader in dataloaders:  # 遍历每个分辨率的数据集
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)

                # 前向传播
                outputs = model(images)
                loss = criterion(outputs, labels)

                # 反向传播和优化
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # 统计损失和准确率
                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        accuracy = 100 * correct / total
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, Accuracy: {accuracy:.2f}%")


In [12]:
# 测试函数
def test_variable_resolution(model, dataloaders, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for dataloader in dataloaders:
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)

                # 前向传播
                outputs = model(images)
                loss = criterion(outputs, labels)

                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

    accuracy = 100 * correct / total
    print(f"Test Loss: {total_loss:.4f}, Test Accuracy: {accuracy:.2f}%")


In [13]:
# 初始化模型、优化器和损失函数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VariableResolutionCNN(output_channels=64)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型
train_variable_resolution(model, dataloaders, optimizer, criterion, device, epochs=10)

# 测试模型
test_variable_resolution(model, dataloaders, criterion, device)


Epoch 1/10, Loss: 5437.2775, Accuracy: 85.70%
Epoch 2/10, Loss: 5390.9052, Accuracy: 85.71%
Epoch 3/10, Loss: 5384.3496, Accuracy: 85.71%
Epoch 4/10, Loss: 5381.5310, Accuracy: 85.71%
Epoch 5/10, Loss: 5379.4571, Accuracy: 85.71%
Epoch 6/10, Loss: 5380.2745, Accuracy: 85.71%
Epoch 7/10, Loss: 5379.3883, Accuracy: 85.71%
Epoch 8/10, Loss: 5377.4326, Accuracy: 85.71%
Epoch 9/10, Loss: 5375.9705, Accuracy: 85.71%
Epoch 10/10, Loss: 5374.8203, Accuracy: 85.71%
Test Loss: 5372.6744, Test Accuracy: 85.71%
