In [1]:
# 计算数据集的均值和标准差
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

def calculate_dataset_stats(dataset_path, batch_size=64, num_workers=4):
    # 只使用 ToTensor 变换，不进行标准化
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # 加载数据集
    dataset = datasets.ImageFolder(dataset_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

    # 初始化变量
    total_mean = torch.zeros(3)
    total_var = torch.zeros(3)
    total_images = 0

    # 计算均值和方差
    for images, _ in dataloader:
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        total_mean += images.mean(2).sum(0)
        total_var += images.var(2).sum(0)
        total_images += batch_samples

    # 计算最终的均值和标准差
    mean = total_mean / total_images
    std = torch.sqrt(total_var / total_images)

    return mean, std

# 使用函数
dataset_path = './archive'
mean, std = calculate_dataset_stats(dataset_path)

print(f"计算得到的均值: {mean}")
print(f"计算得到的标准差: {std}")

计算得到的均值: tensor([0.6953, 0.6752, 0.6424])
计算得到的标准差: tensor([0.0941, 0.0914, 0.0880])
