In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

def download_cifar100(batch_size=128, num_workers=2, root='./data'):
    """Download CIFAR-100 if needed and build train/test DataLoaders.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        num_workers: Number of worker subprocesses each DataLoader uses.
        root: Directory the dataset is stored in / downloaded to. Defaults
            to './data', the path that was previously hard-coded here.

    Returns:
        A ``(trainloader, testloader)`` tuple. The train loader shuffles and
        applies random-crop / horizontal-flip augmentation; the test loader
        keeps dataset order and only converts to tensors and normalizes.
    """
    # One mean and one std value per image channel (presumably precomputed
    # CIFAR-100 statistics -- confirm if the source matters). Defined once so
    # the train and test pipelines can never normalize differently.
    normalize = transforms.Normalize((0.5071, 0.4867, 0.4408),
                                     (0.2675, 0.2565, 0.2761))

    # Training pipeline: augmentation (random 32x32 crop after padding by 4,
    # random horizontal flip), then tensor conversion and normalization.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation pipeline: no random augmentation, only conversion and
    # the same normalization as training.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # download=True skips re-downloading when the files are already present
    # and verified (torchvision prints "Files already downloaded and verified").
    trainset = torchvision.datasets.CIFAR100(
        root=root, train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    testset = torchvision.datasets.CIFAR100(
        root=root, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return trainloader, testloader

# Example: download CIFAR-100 and build the train/test loaders.
trainloader, testloader = download_cifar100()

# Sanity check: fetch only the first training batch and print its size,
# per-sample image shape and labels to confirm the data loads correctly.
# The None default mirrors a for-loop that simply does nothing on an
# empty loader.
first_batch = next(iter(trainloader), None)
if first_batch is not None:
    images, labels = first_batch
    print(f"Batch size: {images.size(0)}, Image shape: {images.size()[1:]}, Labels: {labels}")

  warn(
  from .autonotebook import tqdm as notebook_tqdm


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:29<00:00, 5744829.19it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
Batch size: 128, Image shape: torch.Size([3, 32, 32]), Labels: tensor([19,  3,  9, 31, 59, 77, 61, 24, 11, 49, 98, 13, 75, 88, 13, 49, 32, 64,
         0, 85, 45, 72, 70, 41, 35, 44, 24, 47,  0, 22, 79, 38, 63, 57, 27, 40,
        31, 76, 49, 51, 66, 40, 24, 48, 72, 92, 18, 52, 63, 55, 61, 20, 66, 46,
        12, 39, 90, 94, 73, 15, 41, 35, 15, 43,  5, 25, 50, 67, 69, 57, 83, 23,
        88, 47, 54, 28, 45, 29, 93, 33, 95, 30, 44, 21, 27, 24, 14, 12, 12, 62,
        14,  4, 30, 69,  3, 61, 14, 93, 38, 27, 41, 54, 98, 35, 75, 45,  5,  0,
        56,  1, 42, 54, 46,  0, 63,  1, 68, 29, 41, 78, 18, 84, 95, 87, 23, 72,
        82, 41])
