In [12]:
# Resize every image and segmentation to TARGET_SIZE x TARGET_SIZE (224 x 224).
# Step 1 (this cell): rewrite each dataset's dataset.json so that shapes, voxel
# spacings and segmentation bounding boxes describe the resized arrays.
# Step 2 (next cell): resize the arrays stored in the .npz files.
import json
import os
import shutil
from typing import Any

datasets = [
    "~/dataset/cmri",
    "~/dataset/acdc",
    "~/dataset/mnms",
]
datasets = [os.path.expanduser(dataset) for dataset in datasets]

TARGET_SIZE = 224


def _rescale_shape_and_spacing(entry: dict[str, Any]) -> tuple[float, float]:
    """Rewrite ``entry["shape"]`` and ``entry["spacing"]`` in place for the in-plane resize.

    The .npz arrays are resized on their last two axes to exactly
    (TARGET_SIZE, TARGET_SIZE) without keeping the aspect ratio, so the h and w axes
    each get their own scale factor. The first (depth) axis and its spacing are unchanged.

    Args:
        entry: metadata dict with a 3-element ``"shape"`` ``[d, h, w]`` and a
            3-element ``"spacing"``.

    Returns:
        ``(scale_x, scale_y)``: the factors applied along w and h respectively.
    """
    d, h, w = entry["shape"]
    scale_x = TARGET_SIZE / w
    scale_y = TARGET_SIZE / h
    # Store the target size directly; int(h * (TARGET_SIZE / h)) may truncate to
    # TARGET_SIZE - 1 because of floating-point rounding.
    entry["shape"] = [d, TARGET_SIZE, TARGET_SIZE]

    # NOTE(review): assumes spacing is (x, y, z) with x along w and y along h — confirm.
    sx, sy, sz = entry["spacing"]
    entry["spacing"] = [sx / scale_x, sy / scale_y, sz]
    return scale_x, scale_y


for dataset in datasets:
    dataset_json = os.path.join(dataset, "dataset.json")
    backup_json = f"{dataset_json}.bak"
    # First run only: keep a copy of the untouched metadata. shutil instead of a `cp`
    # shell command so paths with spaces work and a failed copy raises instead of
    # being silently ignored.
    if not os.path.exists(backup_json):
        shutil.copy2(dataset_json, backup_json)

    # Always rescale from the original metadata so re-running this cell is idempotent.
    with open(backup_json, "r") as f:
        dataset_dict = json.load(f)

    for subject in dataset_dict["subjects"]:
        _rescale_shape_and_spacing(subject)

    for file in dataset_dict["files"]:
        scale_x, scale_y = _rescale_shape_and_spacing(file)

        if file["segment_box_3d"] is not None:
            x1, y1, z1, x2, y2, z2 = file["segment_box_3d"]
            file["segment_box_3d"] = [
                int(x1 * scale_x), int(y1 * scale_y), z1,
                int(x2 * scale_x), int(y2 * scale_y), z2,
            ]

        # Per-slice 2D boxes, presumably [x1, y1, x2, y2] (mirroring the 3D box order) —
        # confirm. The rescaled dict is stored back into `file` so it is persisted.
        if file["segment_box_2d"] is not None:
            file["segment_box_2d"] = {
                k: [int(v[0] * scale_x), int(v[1] * scale_y), int(v[2] * scale_x), int(v[3] * scale_y)]
                for k, v in file["segment_box_2d"].items()
            }

    with open(dataset_json, "w") as f:
        json.dump(dataset_dict, f, indent=2)

In [None]:
# Step 2: resize the "image" (and optional "segment") arrays in every .npz under each
# dataset to TARGET_SIZE x TARGET_SIZE, overwriting the files in place.
# Depends on `os`, `datasets` and `TARGET_SIZE` from the previous cell.
import numpy as np
import torch
import torchvision.transforms as transforms

# Resize acts on the last two axes, so any leading axes are preserved.
# Images: bilinear + antialias to limit aliasing when downsampling.
_image_transform = transforms.Compose(
    [
        transforms.Resize(
            (TARGET_SIZE, TARGET_SIZE), antialias=True, interpolation=transforms.InterpolationMode.BILINEAR
        ),
    ]
)
# Segmentations: nearest neighbour so label values are never interpolated
# (torchvision ignores `antialias` for NEAREST).
_segment_transform = transforms.Compose(
    [
        transforms.Resize(
            (TARGET_SIZE, TARGET_SIZE), antialias=True, interpolation=transforms.InterpolationMode.NEAREST
        ),
    ]
)


def _is_already_resized(data: dict[str, np.ndarray]) -> bool:
    """Return True if the arrays are already TARGET_SIZE x TARGET_SIZE (and the segment is uint8).

    Lets a re-run of this cell skip files it has already converted instead of rewriting them.
    """
    target = (TARGET_SIZE, TARGET_SIZE)
    if data["image"].shape[-2:] != target:
        return False
    if "segment" in data:
        segment = data["segment"]
        return segment.shape[-2:] == target and segment.dtype == np.uint8
    return True


def _save_npz_atomically(path: str, arrays: dict[str, np.ndarray]) -> None:
    """Write ``arrays`` to ``path`` as a compressed .npz without leaving a half-written file.

    The data is written to a sibling temporary file first and then moved over ``path``
    with ``os.replace``, so a crash mid-write leaves the original file intact.
    """
    # Ends in ".tmp", so a leftover from an interrupted run is not mistaken for a .npz.
    tmp_path = f"{path}.tmp"
    try:
        # Passing a file object stops np.savez_compressed from appending ".npz" to the name.
        with open(tmp_path, "wb") as fh:
            np.savez_compressed(fh, **arrays)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


for dataset in datasets:
    # Collect every .npz first, then rewrite, so the walk never sees files mid-update.
    npz_files = [
        (root, file)
        for root, _dirs, files in os.walk(dataset)
        for file in files
        if file.endswith(".npz")
    ]

    for root, file in npz_files:
        full_path = os.path.join(root, file)
        print("process file:", full_path)
        # Close the archive before overwriting it; np.load otherwise keeps the file open.
        with np.load(full_path) as npz:
            data = dict(npz)

        if _is_already_resized(data):
            print("already resized, skipping:", full_path)
            continue

        data["image"] = _image_transform(torch.from_numpy(data["image"])).numpy()
        if "segment" in data:
            segmentation = _segment_transform(torch.from_numpy(data["segment"]))
            # NOTE(review): uint8 assumes every label id fits in 0..255 — confirm.
            segmentation = segmentation.to(torch.uint8)
            data["segment"] = segmentation.numpy()
            print("segmentation shape:", segmentation.shape)
        _save_npz_atomically(full_path, data)