In [1]:
from datasets.segmentation import ImageSegmentationDataset
from transforms.base import Compose, ToTensor
from transforms.vision_basic import Resize, RandomHorizontalFlip,RandomVerticalFlip,Normalize
from torch.utils.data import DataLoader

In [8]:
from pathlib import Path

img_dir  = Path("../data/DRIVE/training/images")
mask_dir = Path("../data/DRIVE/training/1st_manual")


def sample_id(path):
    """Return the sample number that prefixes a DRIVE file name.

    File names in this dataset look like ``21_training.tif`` and
    ``21_manual1.gif``; the text before the first underscore of the *file
    name* (not of the whole path) identifies the sample.
    """
    return Path(path).name.split("_", 1)[0]


# 1. Collect every .tif image and .gif mask and sort both lists so that
#    index i of one list is meant to correspond to index i of the other.
image_paths = sorted(str(p) for p in img_dir.glob("*.tif"))
mask_paths  = sorted(str(p) for p in mask_dir.glob("*.gif"))

print("num images:", len(image_paths))
print("num masks :", len(mask_paths))
print("first 3 images:", image_paths[:3])
print("first 3 masks :", mask_paths[:3])

# Sorting alone only pairs the files correctly when both folders contain
# exactly the same samples. A missing or extra file would silently shift every
# later pair, so verify the pairing explicitly (after the diagnostic prints so
# the counts are visible when a check fails).
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f"found {len(image_paths)} images but {len(mask_paths)} masks"
    )
mismatched_pairs = [
    (img, msk)
    for img, msk in zip(image_paths, mask_paths)
    if sample_id(img) != sample_id(msk)
]
if mismatched_pairs:
    raise ValueError(
        f"image/mask sample ids do not line up, e.g. {mismatched_pairs[0]}"
    )


num images: 20
num masks : 20
first 3 images: ['../data/DRIVE/training/images/21_training.tif', '../data/DRIVE/training/images/22_training.tif', '../data/DRIVE/training/images/23_training.tif']
first 3 masks : ['../data/DRIVE/training/1st_manual/21_manual1.gif', '../data/DRIVE/training/1st_manual/22_manual1.gif', '../data/DRIVE/training/1st_manual/23_manual1.gif']


In [9]:
# Spatial size every sample is resized to, and the per-channel statistics
# passed to Normalize (these are the commonly used ImageNet mean/std values).
INPUT_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Pipeline order: resize -> random horizontal/vertical flips (each with
# probability 0.5) -> tensor conversion -> normalization.
# NOTE(review): how each transform treats the mask is defined in the project's
# `transforms` package, not here; confirm there.
seg_transform = Compose([
    Resize(INPUT_SIZE),
    RandomHorizontalFlip(p=0.5),
    RandomVerticalFlip(p=0.5),
    ToTensor(),
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [10]:
# 3. Build the Dataset and DataLoader.
# The dataset receives the sorted image and mask path lists plus the
# transform pipeline defined above.
dataset_args = {
    "image_paths": image_paths,
    "mask_paths": mask_paths,
    "transform": seg_transform,
}
ds = ImageSegmentationDataset(**dataset_args)

In [11]:
# One sample per batch, drawn in shuffled order.
seg_loader = DataLoader(
    dataset=ds,
    batch_size=1,
    shuffle=True,
)

In [15]:
# Pull a single batch and print the tensor shapes, image first, then mask.
# Expected image shape: [1, 3, 224, 224]
# Expected mask shape:  [1, 224, 224] or [224, 224]; we will adjust this in the next step
for batch in seg_loader:
    for key in ("image", "mask"):
        print(batch[key].shape)
    break  # one batch is enough for this check

torch.Size([1, 3, 224, 224])
torch.Size([1, 224, 224])


In [13]:
# Inspect dtype and shape of one batch, plus the distinct label values
# present in its mask.
for batch in seg_loader:
    for label, key in (("image:", "image"), ("mask:", "mask")):
        print(label, batch[key].dtype, batch[key].shape)
    print("mask unique values:", batch["mask"].unique())
    break  # only the first batch is inspected

image: torch.float32 torch.Size([1, 3, 224, 224])
mask: torch.int64 torch.Size([1, 224, 224])
mask unique values: tensor([0, 1])


In [14]:
# Report the dataset size, then the image and mask shapes of one batch.
dataset_size = len(ds)
print("len(ds) =", dataset_size)
for batch in seg_loader:
    batch_shapes = [batch[key].shape for key in ("image", "mask")]
    print(*batch_shapes)
    break  # only the first batch is inspected


len(ds) = 20
torch.Size([1, 3, 224, 224]) torch.Size([1, 224, 224])
