In [2]:
import os
import sys
import torch
import torchvision
from torchvision import transforms
from tqdm import tqdm

def _infer_data_shape(dataset):
    """Return the per-sample shape of *dataset*, or ``None`` if unknown.

    Prefers ``dataset.data.shape[1:]`` (the first axis is dropped, as the
    original code did). When the dataset has no ``data`` attribute with a
    ``shape``, falls back to the shape of the first sample obtained via
    ``dataset[0]``; note that this fallback reflects whatever ``dataset[0]``
    returns, which may differ in layout from a raw ``data`` array.
    """
    raw_shape = getattr(getattr(dataset, 'data', None), 'shape', None)
    if raw_shape is not None:
        return raw_shape[1:]

    try:
        sample = dataset[0]
    except (IndexError, KeyError, TypeError):
        # Empty or non-indexable dataset: nothing to inspect.
        return None

    # Samples are commonly (input, target) pairs; inspect the input part.
    if isinstance(sample, (tuple, list)) and sample:
        sample = sample[0]
    return getattr(sample, 'shape', None)


def print_dataset_info(dataset_name, train_set, test_set):
    """Print a short summary of a train/test dataset pair.

    Prints the dataset name, the sizes of both splits, the per-sample data
    shape of the training split and, when the training split exposes a
    ``classes`` attribute, the class names.

    Args:
        dataset_name: Label used in the printed header.
        train_set: Training split; must support ``len()``.
        test_set: Test split; must support ``len()``.
    """
    print(f"\n{dataset_name} 数据集信息:")
    print(f"训练集大小: {len(train_set)}")
    print(f"测试集大小: {len(test_set)}")

    # The original conditional had identical branches and crashed with
    # AttributeError for datasets lacking ``.data``; use a real fallback.
    shape = _infer_data_shape(train_set)
    print(f"数据形状: {shape if shape is not None else '未知'}")

    if hasattr(train_set, 'classes'):
        print(f"类别: {train_set.classes}")

def download_dataset(dataset_class, dataset_name, root_path, transform=None):
    """Download the train and test splits of a dataset.

    ``dataset_class`` is called twice with the keyword arguments ``root``,
    ``train``, ``download=True`` and ``transform`` -- first with
    ``train=True``, then with ``train=False``. Progress messages are printed
    before each call and after both succeed.

    Args:
        dataset_class: Callable that builds one split of the dataset.
        dataset_name: Name used in the progress and error messages.
        root_path: Directory passed as ``root`` to ``dataset_class``.
        transform: Optional transform forwarded to ``dataset_class``.

    Returns:
        ``(train_set, test_set)`` on success, or ``(None, None)`` if any
        exception is raised along the way (the error is printed, not raised).
    """
    def build_split(is_train):
        # Both splits share every argument except the ``train`` flag.
        return dataset_class(
            root=root_path,
            train=is_train,
            download=True,
            transform=transform
        )

    try:
        print(f"\n正在下载 {dataset_name} 训练集...")
        train_split = build_split(True)

        print(f"正在下载 {dataset_name} 测试集...")
        test_split = build_split(False)

        print(f"{dataset_name} 下载完成！")
    except Exception as err:
        # Best-effort: report the failure and let the caller carry on.
        print(f"下载 {dataset_name} 时出错: {str(err)}")
        return None, None
    else:
        return train_split, test_split

def main():
    """Download MNIST, Fashion-MNIST and CIFAR-10 and print a summary of each.

    The data root is ``../data`` when the current working directory is named
    ``federated_learning`` and ``./data`` otherwise; each dataset gets its own
    sub-directory beneath it. A dataset whose download fails is skipped (the
    failure is reported by ``download_dataset``) and the remaining datasets
    are still processed.
    """
    # Report whether CUDA is available; the device is only printed here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # Choose the data root relative to the current working directory.
    data_root = '../data' if os.path.basename(os.getcwd()) == 'federated_learning' else './data'

    # Target directory and torchvision dataset class for each dataset.
    datasets = {
        'mnist': {'path': os.path.join(data_root, 'mnist'), 'class': torchvision.datasets.MNIST},
        'fmnist': {'path': os.path.join(data_root, 'fmnist'), 'class': torchvision.datasets.FashionMNIST},
        'cifar10': {'path': os.path.join(data_root, 'cifar10'), 'class': torchvision.datasets.CIFAR10}
    }

    # Create every dataset directory up front; only the values are needed.
    for info in datasets.values():
        os.makedirs(info['path'], exist_ok=True)

    # Single-channel normalisation used for both MNIST and Fashion-MNIST.
    # NOTE(review): these look like the usual MNIST mean/std -- confirm that
    # reusing them for Fashion-MNIST is intended.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Three-channel normalisation used for CIFAR-10.
    cifar_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Download each dataset and print its summary.
    for name, info in datasets.items():
        transform = cifar_transform if name == 'cifar10' else mnist_transform
        train_set, test_set = download_dataset(
            info['class'],
            name.upper(),
            info['path'],
            transform
        )

        # download_dataset signals failure with (None, None). Compare against
        # None explicitly: truth-testing a dataset goes through __len__, so an
        # empty (but successfully loaded) dataset would be wrongly skipped.
        if train_set is not None and test_set is not None:
            print_dataset_info(name.upper(), train_set, test_set)

# Run the download routine only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


使用设备: cpu

正在下载 MNIST 训练集...
正在下载 MNIST 测试集...
MNIST 下载完成！

MNIST 数据集信息:
训练集大小: 60000
测试集大小: 10000
数据形状: torch.Size([28, 28])
类别: ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

正在下载 FMNIST 训练集...
正在下载 FMNIST 测试集...
FMNIST 下载完成！

FMNIST 数据集信息:
训练集大小: 60000
测试集大小: 10000
数据形状: torch.Size([28, 28])
类别: ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

正在下载 CIFAR10 训练集...
正在下载 CIFAR10 测试集...
CIFAR10 下载完成！

CIFAR10 数据集信息:
训练集大小: 50000
测试集大小: 10000
数据形状: (32, 32, 3)
类别: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
