In [11]:
from datasets.classification import ImageClassificationDataset
from datasets.segmentation import ImageSegmentationDateset
from transforms.base import Compose,Identity,ToTensor
from transforms.vision_basic import Resize,RandomHorizontalFlip,Normalize
from torch.utils.data import DataLoader

In [17]:
from pathlib import Path

# 1. Collect image paths and integer labels from the per-class sub-folders.
root = Path("../data/COVID-CT")
classes = ["COVID", "non-COVID"]  # class names; the list index is the numeric label

image_paths = []
labels = []

for label, cls in enumerate(classes):
    cls_dir = root / cls
    # Path.glob yields entries in arbitrary, filesystem-dependent order.
    # Sorting makes the sample order (and anything derived from it, such as
    # index-based splits) reproducible across machines and kernel restarts.
    for p in sorted(cls_dir.glob("*.png")):
        image_paths.append(str(p))
        labels.append(label)

print("num images:", len(image_paths))
print("labels sample:", labels[:10])


num images: 2481
labels sample: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [18]:
# 2. Define the preprocessing / augmentation pipeline applied to every sample.
# Per-channel statistics used for normalisation; the values match the
# commonly used ImageNet RGB mean / std.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

transform = Compose([
    Resize((224, 224)),            # bring every image to a common 224x224 size
    RandomHorizontalFlip(p=0.5),   # augmentation: flip left-right with probability 0.5
    ToTensor(),                    # convert the image to a tensor
    Normalize(mean=norm_mean, std=norm_std),  # per-channel normalisation
])

In [19]:
# 3. Build the Dataset (the DataLoader is created in the next cell).
# The dataset receives the path list, the label list and the transform
# pipeline prepared in the cells above.
ds = ImageClassificationDataset(image_paths=image_paths, labels=labels, transform=transform)

In [20]:
# One sample per batch, in dataset order (no shuffling) — convenient for
# inspecting individual samples.
loader = DataLoader(
    dataset=ds,
    batch_size=1,
    shuffle=False,
)

In [21]:
# Smoke test: pull only the first batch and show its structure.
# The None default keeps this silent (instead of raising) if the loader is empty.
batch = next(iter(loader), None)
if batch is not None:
    print("keys:", batch.keys())
    print("image shape:", batch["image"].shape)
    print("label:", batch["label"])

keys: dict_keys(['image', 'label', 'meta'])
image shape: torch.Size([1, 3, 224, 224])
label: tensor([0])
