In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
import os
from tqdm import tqdm

def calculate_mean_std(data_path, extensions=('.jpg', '.jpeg', '.png')):
    """Compute the per-channel mean and standard deviation of all images under ``data_path``.

    Every matching image found recursively is converted to RGB and scaled to [0, 1];
    statistics are accumulated over all pixels of all images (population std, ddof=0).

    Args:
        data_path: Root directory searched recursively with ``os.walk``.
        extensions: File suffixes treated as images (matched case-insensitively).

    Returns:
        ``(mean, std)``: two length-3 float arrays in R, G, B order.

    Raises:
        ValueError: If no image pixels were found under ``data_path``.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    channel_sum = np.zeros(3)
    channel_sum_squared = np.zeros(3)
    pixel_count = 0

    # Collect all image files; sorted so the traversal order is reproducible.
    image_files = sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(data_path)
        for name in files
        if name.lower().endswith(suffixes)
    )

    # Iterate over all images with a progress bar.
    for img_path in tqdm(image_files, desc="Processing images"):
        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(img_path) as img:
            pixels = np.asarray(img.convert('RGB'), dtype=np.float64) / 255.0  # scale to [0, 1]
        pixel_count += pixels.shape[0] * pixels.shape[1]
        channel_sum += pixels.sum(axis=(0, 1))
        channel_sum_squared += np.square(pixels).sum(axis=(0, 1))

    if pixel_count == 0:
        raise ValueError(f'no images with extensions {suffixes} found under {data_path!r}')

    mean = channel_sum / pixel_count
    # Var = E[x^2] - E[x]^2 can come out slightly negative from floating-point
    # cancellation; clamp at 0 so sqrt never produces NaN.
    variance = np.maximum(channel_sum_squared / pixel_count - np.square(mean), 0.0)
    std = np.sqrt(variance)
    return mean, std

# NOTE(review): `data_path` was never defined in this notebook, so this cell raised
# NameError on a fresh kernel. The later cells read images from `<data_dir>/images`
# with data_dir './' (and train.csv keys look like 'images/0.jpg'), so that
# location is assumed here; confirm it matches the intended image folder.
data_path = os.path.join('.', 'images')
mean, std = calculate_mean_std(data_path)
print('均值: ', mean)
print('标准差: ', std)


Processing images: 100%|██████████| 27153/27153 [02:28<00:00, 182.43it/s]

均值:  [0.75885325 0.77881755 0.75984938]
标准差:  [0.24770346 0.23468548 0.26303263]





In [49]:
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

In [53]:
#@save
def read_csv_labels(fname, encoding='utf-8'):
    """Read a two-column CSV (filename, label) and return a ``{filename: label}`` dict.

    The first row is treated as a header (column names) and skipped. Blank rows,
    e.g. a trailing newline at the end of the file, are ignored. Fields are parsed
    with the ``csv`` module, so quoted values containing commas are handled, and
    surrounding whitespace is stripped from each field.

    Args:
        fname: Path to the CSV file.
        encoding: Text encoding used to open the file.

    Returns:
        A dict mapping the first column (filename) to the second column (label).

    Raises:
        ValueError: If a non-blank row does not have exactly two fields.
    """
    import csv  # local import keeps this #@save function self-contained

    labels = {}
    with open(fname, 'r', newline='', encoding=encoding) as f:
        rows = csv.reader(f)
        next(rows, None)  # skip the header row (column names)
        for row in rows:
            fields = [field.strip() for field in row]
            if not any(fields):
                continue  # blank line
            if len(fields) != 2:
                raise ValueError(f'{fname}: expected 2 columns, got {len(fields)}: {row!r}')
            name, label = fields
            labels[name] = label
    return labels

# Load the training label table (path relative to the working directory) and report counts.
labels = read_csv_labels('train.csv')
print('# 训练样本 :', len(labels))
print('# 类别 :', len({label for label in labels.values()}))
# Show the first key so the filename format is visible.
print(list(labels)[0])

# 训练样本 : 18353
# 类别 : 176
images/0.jpg


In [57]:
#@save
def copyfile(filename, target_dir):
    """Copy ``filename`` into ``target_dir``, creating the directory first if it is missing."""
    # exist_ok=True: repeated calls for the same directory must not fail.
    os.makedirs(name=target_dir, exist_ok=True)
    shutil.copy(src=filename, dst=target_dir)

#@save
def reorg_data(data_dir, labels, valid_ratio):
    """Split labelled images into train/valid folders and copy unlabelled ones to test.

    Images are read from ``<data_dir>/images``. A file whose key ``'images/<name>'``
    is in ``labels`` goes to ``<data_dir>/data/valid/<label>`` until that label has
    ``n_valid_per_label`` validation images, and to ``<data_dir>/data/train/<label>``
    afterwards. Files with no label go to ``<data_dir>/data/test``. Sub-directories
    inside ``images`` are skipped. Files are processed in sorted name order, so the
    split is reproducible.

    Args:
        data_dir: Dataset root directory containing the ``images`` folder.
        labels: Mapping ``'images/<filename>' -> label`` (as from ``read_csv_labels``).
        valid_ratio: Fraction of the smallest class reserved for validation. Every
            class gets the same number of validation images, and at least one.

    Returns:
        The number of validation images taken per label.
    """
    # Number of samples in the smallest class; default=0 handles an empty dict.
    n = min(collections.Counter(labels.values()).values(), default=0)
    # Validation samples per class.
    n_valid_per_label = max(1, math.floor(n * valid_ratio))
    image_dir = os.path.join(data_dir, 'images')
    out_dir = os.path.join(data_dir, 'data')
    label_count = collections.Counter()
    # sorted(): os.listdir order is filesystem-dependent, which would otherwise
    # make the train/valid split differ between machines.
    for train_file in sorted(os.listdir(image_dir)):
        src = os.path.join(image_dir, train_file)
        if not os.path.isfile(src):
            continue  # shutil.copy cannot copy a directory
        key = 'images/' + train_file
        if key not in labels:
            # NOTE(review): if this folder is later loaded with
            # torchvision.datasets.ImageFolder, it needs a class sub-folder
            # (d2l uses test/unknown); confirm against the loading code.
            copyfile(src, os.path.join(out_dir, 'test'))
            continue
        label = labels[key]
        if label_count[label] < n_valid_per_label:
            copyfile(src, os.path.join(out_dir, 'valid', label))
            label_count[label] += 1
        else:
            copyfile(src, os.path.join(out_dir, 'train', label))
    return n_valid_per_label

In [59]:
data_dir = './'
valid_ratio = 0.1  # fraction of the smallest class held out for validation
batch_size = 32  # batch size (not used in this cell)
labels = read_csv_labels(os.path.join(data_dir, 'train.csv'))
reorg_data(data_dir, labels, valid_ratio)

5