In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset
import matplotlib.pyplot as plt
import numpy as np

def load_mnist_data(root='./data', train=True):
    """Load an MNIST split into memory as a normalized ``TensorDataset``.

    The split is downloaded into *root* when it is not already present.
    Every image goes through ``ToTensor`` followed by ``Normalize`` with
    mean 0.5 and std 0.5, i.e. ``(x - 0.5) / 0.5``. The observed min/max of
    the resulting tensor is printed as a sanity check.

    Args:
        root: Directory the MNIST files are stored in / downloaded to.
        train: Passed through to ``datasets.MNIST``; True selects the
            training split.

    Returns:
        ``TensorDataset(X, y)`` where ``X`` stacks the normalized image
        tensors along a new first dimension and ``y`` holds the labels.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    mnist = datasets.MNIST(root=root, train=train, download=True, transform=transform)

    # Index each item exactly once: every __getitem__ call re-runs the
    # transform pipeline, so fetching image and label separately would do
    # that work twice per sample.
    images, labels = [], []
    for idx in range(len(mnist)):
        image, label = mnist[idx]
        images.append(image)
        labels.append(label)

    X = torch.stack(images)
    y = torch.tensor(labels)
    print(f"Normalized X range: min={X.min().item():.4f}, max={X.max().item():.4f}")
    return TensorDataset(X, y)

# Load the data and unpack the image and label tensors.
dataset = load_mnist_data()
X, y = dataset.tensors

# Pick one sample per class: the first image encountered for each label.
NUM_CLASSES = 10  # digits 0-9
class_samples = {}
# tolist() converts the labels to plain Python ints in one call instead of
# calling .item() on a tensor element per iteration.
for idx, label in enumerate(y.tolist()):
    if label not in class_samples:
        class_samples[label] = X[idx]
    if len(class_samples) == NUM_CLASSES:
        break  # every class has a sample; stop scanning

# Order the samples by class 0-9 and undo the normalization for display
# (map [-1, 1] back to [0, 1]); clamp guards against values just outside
# that range.
samples = [
    (class_samples[digit] * 0.5 + 0.5).clamp(0, 1)
    for digit in range(NUM_CLASSES)
]

# Draw the samples on a 2x5 grid; figsize is chosen to suit that layout.
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for digit, (image, ax) in enumerate(zip(samples, axes.flatten())):
    # squeeze() drops the size-1 dimensions so imshow receives a 2-D array.
    ax.imshow(image.squeeze().numpy(), cmap='gray')
    ax.set_title(f'Class {digit}')
    ax.axis('off')

# Operate on the explicit figure handle rather than pyplot's implicit
# "current figure", so exactly this figure is saved and then released.
fig.tight_layout()
fig.savefig('mnist_samples.png', dpi=300, bbox_inches='tight')
plt.close(fig)


Normalized X range: min=-1.0000, max=1.0000
