In [None]:

import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import Grayscale, ToTensor

In [None]:
from torchvision.datasets import ImageFolder

# Root directory of the variable-resolution MNIST images (assumed to be downloaded here already).
data_dir = "./data/mnist-varres"

# Convert each PIL image to a tensor. ToTensor does not resize, so images keep
# their native resolution in this dataset.
transform = ToTensor()

# Index every image file under data_dir (ImageFolder assigns one class per sub-directory).
# NOTE(review): if data_dir holds e.g. train/ and test/ sub-folders, those become the
# "classes" here — confirm the on-disk layout.
dataset = ImageFolder(root=data_dir, transform=transform)

# Partition the (path, class_index) entries of dataset.samples by the resolution tag
# found in each file path. Only paths are filtered; no image is decoded at this point.
res_32, res_48, res_64 = (
    [sample for sample in dataset.samples if tag in sample[0]]
    for tag in ("32x32", "48x48", "64x64")
)

# Optional: load the files with `glob` and build a separate DataLoader per resolution.


In [None]:
from torch import nn


class VariableResolutionCNN(nn.Module):
    """Small CNN that maps images of any (large enough) spatial size to class logits.

    Three conv -> ReLU -> 2x2 max-pool stages are followed by a global adaptive
    max pool that collapses whatever spatial size remains to 1x1. The linear head
    therefore always receives a vector of length ``output_channels``, which lets
    the same model consume 32x32, 48x48 and 64x64 batches.

    Args:
        output_channels: Channels produced by the last conv layer; also the input
            width of the linear classifier.
        in_channels: Channels of the input images. Defaults to 1 (grayscale), the
            value that used to be hard-coded.
        num_classes: Number of output logits. Defaults to 10.

    Note:
        Each pooling stage halves H and W (rounding down), so inputs must be at
        least 8x8 to survive the three pools.
    """

    def __init__(self, output_channels=64, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, output_channels, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Global max pooling: output is (batch, output_channels, 1, 1) for any H, W.
        self.global_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.fc = nn.Linear(output_channels, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for input (batch, in_channels, H, W)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.global_pool(x)
        x = x.flatten(1)  # (batch, output_channels, 1, 1) -> (batch, output_channels)
        return self.fc(x)


In [None]:
import torch
from torchvision.transforms import Compose, Resize


# Data loaders: one per target resolution.
def load_variable_resolution_data(root=None, resolutions=(32, 48, 64), batch_size=16, shuffle=True):
    """Build one DataLoader per target resolution over the same image folder.

    Every image under ``root`` is resized to each resolution in turn, so each
    loader covers the whole folder (not only images whose native size matches).
    Images are converted to single-channel grayscale because ImageFolder's
    default loader returns RGB images, while the model's first conv layer
    expects 1 input channel; without this the first forward pass fails with a
    channel-mismatch error.

    Args:
        root: Image folder to read; defaults to the notebook-level ``data_dir``.
        resolutions: Square side lengths to resize to, one loader per entry.
        batch_size: Batch size for every loader.
        shuffle: Whether each loader shuffles its samples every epoch.

    Returns:
        Tuple of DataLoaders in the same order as ``resolutions``
        (by default: 32x32, 48x48, 64x64).
    """
    if root is None:
        root = data_dir

    loaders = []
    for side in resolutions:
        resize_to_tensor = Compose([
            Grayscale(num_output_channels=1),
            Resize((side, side)),
            ToTensor(),
        ])
        folder = ImageFolder(root=root, transform=resize_to_tensor)
        loaders.append(DataLoader(folder, batch_size=batch_size, shuffle=shuffle))
    return tuple(loaders)


# Main training loop.
def train_variable_resolution(model, dataloaders, optimizer, criterion, device, epochs=10):
    """Train ``model`` for ``epochs`` passes over every loader in ``dataloaders``.

    Loaders are consumed one after another each epoch (all batches of the first
    loader, then the second, ...), so batches are grouped by resolution.

    Args:
        model: Module mapping an image batch to class logits.
        dataloaders: Re-iterable collection (e.g. tuple) of iterables yielding
            ``(images, labels)`` batches; it is iterated once per epoch.
        optimizer: Optimizer over the model's parameters; stepped once per batch.
        criterion: Loss callable ``criterion(outputs, labels)`` returning a scalar
            tensor. Assumed to use mean reduction (the nn.CrossEntropyLoss default).
        device: Device each batch is moved to.
        epochs: Number of passes.

    Returns:
        List of ``(mean_loss, accuracy_percent)`` tuples, one per epoch. If an
        epoch sees no samples, both values are reported as 0.0.
    """
    history = []
    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        correct = 0
        total = 0

        for dataloader in dataloaders:  # one DataLoader per resolution
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)

                # Forward pass.
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Backward pass and parameter update.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Weight the batch-mean loss by batch size so the epoch figure is a
                # per-sample mean, independent of dataset size / batch count.
                batch_size = labels.size(0)
                loss_sum += loss.item() * batch_size
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += batch_size

        # Guard against empty loaders instead of raising ZeroDivisionError.
        mean_loss = loss_sum / total if total else 0.0
        accuracy = 100 * correct / total if total else 0.0
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%")
        history.append((mean_loss, accuracy))
    return history


In [None]:
def test_variable_resolution(model, dataloaders, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for dataloader in dataloaders:
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

    accuracy = 100 * correct / total
    print(f"Test Loss: {total_loss:.4f}, Test Accuracy: {accuracy:.2f}%")


In [None]:
# Fix the global RNG seed so weight initialisation and DataLoader shuffling
# (which draws from torch's default generator) are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# Load data: one DataLoader per target resolution.
dataloaders = load_variable_resolution_data()

# Initialise the model on GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VariableResolutionCNN(output_channels=64).to(device)

# Loss function and optimiser.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train.
train_variable_resolution(model, dataloaders, optimizer, criterion, device, epochs=10)

# Evaluate.
# NOTE(review): these are the same loaders used for training, so this reports
# training-set accuracy, not generalisation. TODO: evaluate on a held-out split.
test_variable_resolution(model, dataloaders, criterion, device)
