In [1]:
import cv2
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from torchvision import tv_tensors, datasets
from PIL import Image
import json
plt.rcParams["savefig.bbox"] = "tight"

from torchvision.transforms import functional as F
from torchvision.utils import draw_segmentation_masks
from torchvision.utils import draw_bounding_boxes

from helpers import plot


class MedicalImageDataset(Dataset):
    def __init__(self, data_dict, transform=None):
        self.data_dict = data_dict
        self.transform = transform
        self.image_path_list = []
        self.mask_path_list = []
        # 遍历字典，整理数据
        for img_path, mask_path in data_dict.items():
            self.image_path_list.append(img_path)
            self.mask_path_list.append(tuple(mask_path))


    def __len__(self):
        return len(self.image_path_list)

    def __getitem__(self, idx):
        img_path = self.image_path_list[idx]
        mask_path_tuple = self.mask_path_list[idx]
        image = Image.open(img_path).convert('RGB')
        image = np.array(image)

        # print(img_path)
        # print(mask_path_tuple)

        segments = []
        bboxes = []
        labels = []
        for mask_path in mask_path_tuple:
            label = 0 if "benign" in mask_path else 1   # 假设mask文件名中包含benign则认为是良性，否则为恶性
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            # print(mask.shape) # （512，512）
            # 获取掩码中的所有唯一值
            unique_values = np.unique(mask)

            for value in unique_values:
                if value == 0:  # 忽略背景
                    continue

                # 提取特定值的掩码
                category_mask = (mask == value).astype(np.uint8) * 255
                # print(category_mask.shape)
                contours, _ = cv2.findContours(category_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

                for contour in contours:
                    # 计算边界框
                    x, y, w, h = cv2.boundingRect(contour)
                    # 计算多边形坐标
                    # segmentation = contour.flatten().tolist()

                    # segments.append(category_mask)
                    bboxes.append([x, y, x+w, y+h])
                    labels.append(label)
        # print(segments)
        targets = {
            "image_id": idx,
            # "segments": segments,
            "bboxes": bboxes,
            "labels": labels
        }
        # print(targets["segments"][0].shape)
        if self.transform:
            image = self.transform(image, targets)

        return image, targets

In [2]:


def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, image in enumerate(imgs):
        image = image.detach()  # 转换为普通tensor，不携带梯度、设备等信息
        image = F.to_pil_image(image)  # 转换为PIL Image
        axs[0, i].imshow(np.asarray(image))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [3]:

def collate_fn(batch):
    """
    自定义collate_fn函数，直接返回dataset中__getitem__方法返回的结果，而不要自动打包
    """
    for batch_img, batch_target in batch:
        # batch_img = tv_tensors.Image(batch_img)
        # segments = tv_tensors.Mask(polygon_to_mask(batch_target["segments"], batch_img.shape[0], batch_img.shape[1]))
        # segments = tv_tensors.Mask(batch_target["segments"])
        bboxes = tv_tensors.BoundingBoxes(batch_target["bboxes"], format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=batch_img.shape[-2:])
        labels = tv_tensors.TVTensor(batch_target["labels"])
        return batch_img, {"bboxes": bboxes, "labels": labels}
        # return batch_img, batch_target


# 定义转换规则，例如调整大小、归一化等
transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.RandomPhotometricDistort(p=1),
        # 随机放大图像并在放大后的图像周围填充背景。
        # fill={v2.Image: (123, 117, 104), "others": 0}：图像的填充颜色为 RGB 值 (123, 117, 104)，其他数据类型的填充值为 0。
        v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
        # RandomIoUCrop 是一个随机裁剪操作，它会根据一定的 IoU（Intersection over Union）阈值来裁剪图像。这个操作可能会导致一些边界框变得无效，所以我们需要使用 SanitizeBoundingBoxes 来移除无效的边界框。
        # v2.RandomIoUCrop(),
        v2.RandomHorizontalFlip(p=1),
        # 用于清理退化的边界框及其对应的标签和掩码。它会检查边界框的有效性，并移除那些无效的边界框。
        v2.SanitizeBoundingBoxes(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# 假设data_dict是您提供的字典
with open("img_mask_path.json", "r") as f:
    data_dict = json.load(f)


# 创建数据集实例
dataset = MedicalImageDataset(data_dict, transform=transforms)

# 创建DataLoader实例，用于批量加载数据
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

for img, target in dataloader:
    # img = tv_tensors.Image(img)
    # print(target)
    print(img.shape)
    print(target['bboxes'].shape)
    print(f"{type(target['bboxes']) = }\n{type(target['labels']) = }\n)")
    # plot([(img, target)])

    break


AttributeError: 'tuple' object has no attribute 'shape'