# 如何加载数据集

In [3]:
import pathlib

# Dataset root; each immediate sub-directory is inspected below as a candidate class.
root_dir = pathlib.Path("../data/COVID-CT")

# Accumulators populated by the following cells.
samples = []
class_to_idx = {}

# iterdir() yields entries lazily; materialising them with list() lets us inspect the result.
list(root_dir.iterdir())

[PosixPath('../data/COVID-CT/non-COVID'), PosixPath('../data/COVID-CT/COVID')]

In [4]:
# Walk the top level of root_dir, showing what a pathlib.Path exposes,
# and keep the names of the sub-directories as class names.
class_names = []
for entry in root_dir.iterdir():
    print(entry)                 # the path itself
    print(type(entry))           # concrete Path subclass
    print(entry.name)            # final path component, e.g. "COVID" or "non-COVID"
    is_directory = entry.is_dir()
    print(is_directory)          # True for a folder
    print(entry.is_file())       # True for a regular file
    if is_directory:
        class_names.append(entry.name)

../data/COVID-CT/non-COVID
<class 'pathlib.PosixPath'>
non-COVID
True
False
../data/COVID-CT/COVID
<class 'pathlib.PosixPath'>
COVID
True
False


In [5]:
# Give each class an integer label in the order it was discovered above,
# printing the label/name and their types along the way.
for label_id, name in enumerate(class_names):
    print(label_id)
    print(type(label_id))
    print(name)
    print(type(name))
    class_to_idx[name] = label_id

0
<class 'int'>
non-COVID
<class 'str'>
1
<class 'int'>
COVID
<class 'str'>


In [6]:
# Collect (image path, integer label) pairs for every image file under each class folder.
# `samples` is rebuilt here rather than appended to, so re-running this cell
# does not duplicate entries.
exts = {".png"}
samples = []
for class_name, class_idx in class_to_idx.items():
    class_dir = root_dir / class_name
    print(class_dir)
    # rglob("*") also descends into nested sub-folders; the suffix check is
    # case-insensitive, and is_file() skips directories that happen to end in ".png".
    for img_path in class_dir.rglob("*"):
        if img_path.is_file() and img_path.suffix.lower() in exts:
            samples.append((img_path, class_idx))

../data/COVID-CT/non-COVID
../data/COVID-CT/COVID


In [7]:
# Total number of (path, label) pairs collected across all classes.
n_samples = len(samples)
n_samples

2483

In [9]:
from PIL import Image

# Inspect one arbitrary sample (index 2): its path, its label, and the decoded image.
img_path, label = samples[2]
img = Image.open(img_path)
img = img.convert("RGB")  # force three channels regardless of the file's original mode

print(img_path, label)

../data/COVID-CT/non-COVID/Non-Covid (201).png 0


## 完整的数据加载类

In [None]:
import pathlib

import torch


class MyImageDataset(torch.utils.data.Dataset):
    """Image-classification dataset laid out as ``root_dir/<class_name>/**/<image>``.

    Every immediate sub-directory of ``root_dir`` is one class. Class names are
    sorted so the name -> label mapping does not depend on the (unspecified)
    order in which the filesystem lists directories. Image files are found
    recursively inside each class folder and matched by suffix,
    case-insensitively; sample order is sorted as well, so indices are
    reproducible across runs and machines.

    Args:
        root_dir: Dataset root (``str`` or path-like).
        transform: Optional callable applied to each loaded RGB ``PIL.Image``
            in ``__getitem__``.
        extensions: Accepted file suffixes including the leading dot, either a
            single string or an iterable of strings. Defaults to ``(".png",)``.

    Attributes:
        class_to_idx: Mapping from class-folder name to integer label.
        samples: List of ``(image_path, label)`` tuples.
    """

    def __init__(self, root_dir, transform=None, extensions=(".png",)):
        self.root_dir = pathlib.Path(root_dir)
        self.transform = transform

        # A bare string would otherwise be iterated character by character.
        if isinstance(extensions, str):
            extensions = (extensions,)
        # Normalise once so the per-file check is a simple set lookup.
        self.extensions = {ext.lower() for ext in extensions}

        # Sorted for a reproducible label assignment.
        class_names = sorted(d.name for d in self.root_dir.iterdir() if d.is_dir())
        self.class_to_idx = {name: idx for idx, name in enumerate(class_names)}

        self.samples = []
        for class_name, class_idx in self.class_to_idx.items():
            class_dir = self.root_dir / class_name
            # rglob descends into nested folders; is_file() skips directories
            # whose names happen to end in an accepted suffix.
            for img_path in sorted(class_dir.rglob("*")):
                if img_path.is_file() and img_path.suffix.lower() in self.extensions:
                    self.samples.append((img_path, class_idx))

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        ``image`` is the RGB ``PIL.Image`` after ``transform`` (if any) has
        been applied; ``label`` is the integer from ``class_to_idx``.
        """
        # Imported here so building the index (``__init__``/``__len__``) does
        # not require Pillow; only decoding an image does.
        from PIL import Image

        img_path, label = self.samples[idx]
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that remains valid after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, label