In [1]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

In [2]:
class SimpleImageClassificationDataset(Dataset):
    """Image classification dataset backed by a folder-per-class layout.

    ``root_dir`` is scanned once at construction time. Every immediate
    sub-directory is treated as one class, and every regular file inside it
    whose name ends with one of ``IMG_EXTENSIONS`` (case-insensitive) becomes
    one sample::

        root_dir/
            cat/  a.jpg  b.png ...
            dog/  c.jpg  ...

    Attributes:
        root_dir: The directory that was scanned.
        transform: Optional callable applied to each loaded ``PIL.Image``.
        classes: Class names (sub-directory names) in sorted order.
        class_to_idx: Mapping from class name to integer label, e.g.
            ``{"cat": 0, "dog": 1}``; labels follow the sorted class order.
        samples: List of ``(image_path, label)`` tuples, ordered by class and
            then by file name.
    """

    # Lower-case file suffixes accepted as images when scanning class folders.
    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")

    def __init__(self, root_dir, transform=None):
        """Index all images found under ``root_dir``.

        Args:
            root_dir: Dataset root directory, e.g. "D:/data/cats_and_dogs".
            transform: Optional preprocessing applied to a single PIL image
                (Resize / ToTensor / Normalize, etc.).

        Raises:
            OSError: Propagated from ``os.listdir`` if ``root_dir`` (or a
                class folder) cannot be listed.
        """
        self.root_dir = root_dir
        self.transform = transform

        # 1) Class names are the sub-directory names; sorting makes the
        #    name -> label assignment independent of directory listing order.
        self.classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )

        # 2) Class name -> integer label, e.g. {"cat": 0, "dog": 1}.
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        # 3) Walk each class folder and record (image_path, label).
        self.samples = []
        for class_name in self.classes:
            class_folder = os.path.join(root_dir, class_name)
            label = self.class_to_idx[class_name]

            # os.listdir returns entries in arbitrary order; sort so that
            # sample indices are reproducible across runs and machines.
            for fname in sorted(os.listdir(class_folder)):
                if not fname.lower().endswith(self.IMG_EXTENSIONS):
                    continue
                img_path = os.path.join(class_folder, fname)
                # Skip directories that merely look like images by name;
                # they would otherwise fail later when the sample is loaded.
                if os.path.isfile(img_path):
                    self.samples.append((img_path, label))

    def __len__(self):
        """Return the total number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position into ``self.samples``.

        Returns:
            A ``(image, label)`` pair. ``image`` is the RGB ``PIL.Image`` when
            no transform was given, otherwise whatever ``transform`` returns.
        """
        img_path, label = self.samples[idx]

        # Open inside a context manager so the underlying file is closed even
        # if decoding fails; convert() yields an image that no longer depends
        # on the open file.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, label


In [3]:
# ====== Minimal preprocessing + DataLoader example ======

# Dataset root directory, relative to this notebook (change it to your own path).
data_root = os.path.join("..", "data", "COVID-CT")

# Per-image preprocessing: resize -> convert to tensor -> per-channel normalize.
# NOTE(review): the mean/std values look like the commonly used ImageNet channel
# statistics; confirm they are appropriate for this dataset.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# Index the image folder and attach the preprocessing pipeline.
dataset = SimpleImageClassificationDataset(data_root, transform)

# Wrap the dataset so training code can pull shuffled mini-batches of 8.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=8)

# Smoke test: fetch a single batch and inspect its shapes.
images, labels = next(iter(dataloader))
print("images shape:", images.shape)   # expected: [batch_size, 3, 224, 224]
print("labels shape:", labels.shape)   # expected: [batch_size]
print("labels:", labels)

images shape: torch.Size([8, 3, 224, 224])
labels shape: torch.Size([8])
labels: tensor([0, 0, 0, 0, 0, 0, 0, 1])
