# In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim

class CustomDataset(Dataset):
    """Dataset of NumPy arrays stored as files in per-class sub-directories.

    Every immediate sub-directory of ``root_dir`` is one class; a sample's
    label is the index of its class directory name in sorted order. Every
    regular file inside a class directory is one sample and is read with
    ``np.load`` when it is accessed.
    """

    def __init__(self, root_dir, transform=None):
        """Index all samples found under ``root_dir``.

        Args:
            root_dir: Directory containing one sub-directory per class.
            transform: Optional callable applied to each loaded float32
                array before it is returned.

        Raises:
            FileNotFoundError: Propagated from ``os.listdir`` when
                ``root_dir`` does not exist.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Only directories are classes. A stray file in the root (README,
        # .DS_Store, ...) would otherwise crash the per-class listing below.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.image_paths = []
        self.labels = []
        for label, cls in enumerate(self.classes):
            cls_dir = os.path.join(root_dir, cls)
            # os.listdir returns entries in arbitrary order; sort so that
            # dataset[i] refers to the same sample on every run and machine.
            for img_name in sorted(os.listdir(cls_dir)):
                img_path = os.path.join(cls_dir, img_name)
                if not os.path.isfile(img_path):
                    # Nested directories are not samples.
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The array is loaded with ``np.load`` and cast to ``float32``; the
        transform, if one was given, is applied to that array.
        """
        img_path = self.image_paths[idx]
        image = np.load(img_path).astype(np.float32)
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Compute the mean and standard deviation of the training set.
# NOTE: 'TRAIN' and 'TEST' are relative paths, so they are resolved against
# the current working directory; a missing directory surfaces as
# FileNotFoundError from os.listdir inside CustomDataset.
train_dataset = CustomDataset(root_dir='TRAIN', transform=transforms.ToTensor())
if len(train_dataset) == 0:
    # Fail with a clear message instead of DataLoader's batch_size complaint.
    raise ValueError("No samples found under 'TRAIN'; cannot compute mean/std.")
# Load the whole training set as one batch. Sample order does not matter for
# global statistics, so shuffling is disabled.
# NOTE(review): the default collate stacks samples, which presumably requires
# every array to have the same shape -- confirm against the data.
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=False)
data = next(iter(train_loader))[0]
train_mean = data.mean()
train_std = data.std()
print(f"Train Mean: {train_mean}, Train Std: {train_std}")

# Define the preprocessing pipeline: convert to tensor, then standardise with
# the training-set statistics (passed as plain Python floats).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[train_mean.item()], std=[train_std.item()])
])

# Reload the datasets with normalisation applied, ready for visualisation.
train_dataset = CustomDataset(root_dir='TRAIN', transform=transform)
test_dataset = CustomDataset(root_dir='TEST', transform=transform)

def imshow(img, title):
    """Draw ``img`` as a grayscale image on the current axes.

    Args:
        img: Object exposing ``.numpy()`` (the datasets here yield tensors);
            singleton dimensions are squeezed away before plotting.
        title: Text shown above the image.
    """
    pixels = np.squeeze(img.numpy())
    plt.imshow(pixels, cmap='gray')
    plt.title(title)
    # Hide ticks and frame so only the image is visible.
    plt.axis('off')

# Show the first five samples of each split: training set on the top row,
# test set on the bottom row of a 2 x 5 grid.
SAMPLES_PER_ROW = 5
plt.figure(figsize=(10, 5))
for row, (split_name, dataset) in enumerate(
    (('Train', train_dataset), ('Test', test_dataset))
):
    for col in range(SAMPLES_PER_ROW):
        sample, target = dataset[col]
        # Subplot indices are 1-based and run row by row.
        plt.subplot(2, SAMPLES_PER_ROW, row * SAMPLES_PER_ROW + col + 1)
        imshow(sample, f'{split_name} Label {target}')
plt.show()

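# Recorded output of the cell above:
# FileNotFoundError: [WinError 3] 系统找不到指定的路径。: 'TRAIN'
# (English: "The system cannot find the path specified." -- the relative
# directory 'TRAIN' did not exist in the working directory.)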