In [41]:
from data.transform import (
    volume_transform,
    # slice_transform_train,
    # slice_transform_valid,
    # FilterSliced,
)
from glob import glob
import os
from monai.transforms.utils import generate_spatial_bounding_box
import os
import numpy as np
import torch
import nibabel as nib
from monai.transforms import Compose, LoadImaged, Spacingd  # for example
from torch.utils.data import Dataset

In [None]:
import os
import numpy as np
import nibabel as nib
import torch


def process_and_save_slices(
    path_dicts,
    volume_transform,
    samples_per_image=12,
    output_dir="sliced_output",
    source_key="label",  # which key to use for bounding-box
):
    """
    1) For each (image, label) file in `path_dicts`, load & preprocess them in 3D.
    2) Find the non-zero bounding box based on `source_key` (e.g. label).
    3) Randomly pick `samples_per_image` slice indices within that bounding box's D-range.
    4) Extract and (optionally) transform each 2D slice, then save as a separate NIfTI.
    """
    os.makedirs(output_dir, exist_ok=True)

    for subj_idx, data_dict in enumerate(path_dicts):
        # ---------------------------------------------------------
        # 1) Load & apply 3D transforms (e.g. LoadImaged, Spacingd, etc.)
        # ---------------------------------------------------------
        # img_nii = nib.load(data_dict['image'])
        # print(img_nii.shape) #(512, 512, 275)
        processed_3d = volume_transform(data_dict)
        img_3d = processed_3d["image"]  
        lbl_3d = processed_3d["label"]
        # print(img_3d.shape)  shape assumed [1, 380, 380, 275])

        # ---------------------------------------------------------
        # 2) Find bounding box from the `source_key`, usually "label"
        #    box_start, box_end are each [d_start, h_start, w_start].
        # ---------------------------------------------------------
        vol_for_bbox = processed_3d[source_key]
        box_start, box_end = generate_spatial_bounding_box(vol_for_bbox)

        # box_start[-1] / box_end[-1] is W,
        # box_start[-2] / box_end[-2] is H,
        # box_start[-3] / box_end[-3] is D if shape is [C, D, H, W]
        depth_min, depth_max = box_start[0], box_end[0]  # bounding box along the D dimension

        if depth_min >= depth_max:
            # If bounding box is invalid or empty, skip
            print(f"No non-zero region found for subject {subj_idx}, skipping.")
            continue
        # ---------------------------------------------------------
        # 3) Randomly pick `samples_per_image` slices from [depth_min, depth_max)
        # ---------------------------------------------------------
        slice_idx = torch.randint(depth_min, depth_max, (samples_per_image,))

        # ---------------------------------------------------------
        # 4) Loop over the chosen slice indices, extract 2D slices, apply optional slice_transform, then save
        # ---------------------------------------------------------
        for _, d_idx in enumerate(slice_idx):
            slice_image = img_3d[:, d_idx].clone()  # shape [C, H, W]
            slice_label = lbl_3d[:, d_idx].clone()

            # Convert to numpy for saving
            slice_image_np = slice_image.cpu().numpy()
            slice_label_np = slice_label.cpu().numpy()

            # Often channel dimension is [C=1, H, W], so remove channel if desired
            if slice_image_np.shape[0] == 1:
                slice_image_np = slice_image_np[0]
            if slice_label_np.shape[0] == 1:
                slice_label_np = slice_label_np[0]

            # Build file names
            out_img_name = f"subject_{subj_idx}_slice_{d_idx}_img.nii.gz"
            out_lbl_name = f"subject_{subj_idx}_slice_{d_idx}_lbl.nii.gz"
            out_img_path = os.path.join(output_dir, out_img_name)
            out_lbl_path = os.path.join(output_dir, out_lbl_name)

            # Save NIfTI
            nib.save(nib.Nifti1Image(slice_image_np, affine=np.eye(4)), out_img_path)
            nib.save(nib.Nifti1Image(slice_label_np, affine=np.eye(4)), out_lbl_path)

            print(f"Saved: {out_img_path}, {out_lbl_path}")


In [48]:
import torch
from monai.data.meta_obj import get_track_meta
from monai.transforms.utils import generate_spatial_bounding_box
from monai.utils.type_conversion import convert_data_type, convert_to_tensor
from monai.transforms import (
    Compose,
    RandAffined,
    RandGaussianNoised,
    MapTransform,
    ToTensord,
    LoadImaged,
    Orientationd,
    CenterSpatialCropd,
    Resized,
    NormalizeIntensityd,
    Spacingd,
    Rand2DElasticd,
)

slice_transform_train = Compose(
    [
        LoadImaged(keys=["image", "label"], image_only=False, ensure_channel_first=True),
        Resized(
            keys=["image", "label"],
            spatial_size=[256, 256],
            mode=("bilinear", "nearest"),
        ),
        RandAffined(
            keys=["image", "label"],
            mode=("bilinear", "nearest"),
            prob=0.5,
            rotate_range=(3.14 / 6, 3.14 / 6),
            scale_range=(0.2, 0.2),
            translate_range=(10, 10),
        ),
        Rand2DElasticd(
            keys=["image", "label"],
            spacing=(20, 20),
            magnitude_range=(1, 2),
            prob=0.5,
            padding_mode="zeros",
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys=["image"]),
        RandGaussianNoised(keys=["image"], prob=0.5, std=0.5),
        ToTensord(keys=["image", "label"]),
    ]
)

slice_transform_valid = Compose(
    [   LoadImaged(keys=["image", "label"], image_only=False, ensure_channel_first=True),
        Resized(
            keys=["image", "label"],
            spatial_size=[192, 192],
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys=["image"]),
        ToTensord(keys=["image", "label"]),
    ]
)

In [80]:
# Suppose you have something like:
image_set = 'train'
file_paths = glob(os.path.join(r'C:\Users\SongL\Desktop\BayeSeg\dataset\ImageCAS', image_set, "*.nii.gz"))
image_paths, label_paths = [], []
for path in file_paths:
    if path.split("\\")[-1][9:12] in ["seg", "Seg"]:
        label_paths.append(path)
    else:
        image_paths.append(path)

image_paths, label_paths = sorted(image_paths), sorted(label_paths)

path_dicts = [
    {"image": image_path, "label": label_path}
    for image_path, label_path in zip(image_paths, label_paths)
]

In [81]:
process_and_save_slices(
    path_dicts,
    volume_transform=volume_transform, 
    samples_per_image=12,
    output_dir= os.path.join("E:\sliced_output",image_set),
)



Saved: E:\sliced_output\train\subject_0_slice_78_img.nii.gz, E:\sliced_output\train\subject_0_slice_78_lbl.nii.gz
Saved: E:\sliced_output\train\subject_0_slice_265_img.nii.gz, E:\sliced_output\train\subject_0_slice_265_lbl.nii.gz
Saved: E:\sliced_output\train\subject_0_slice_229_img.nii.gz, E:\sliced_output\train\subject_0_slice_229_lbl.nii.gz
Saved: E:\sliced_output\train\subject_0_slice_48_img.nii.gz, E:\sliced_output\train\subject_0_slice_48_lbl.nii.gz
Saved: E:\sliced_output\train\subject_0_slice_227_img.nii.gz, E:\sliced_output\train\subject_0_slice_227_lbl.nii.gz
Saved: E:\sliced_output\train\subject_0_slice_193_img.nii.gz, E:\sliced_output\train\subject_0_slice_193_lbl.nii.gz
Saved: E:\sliced_output\train\subject_0_slice_100_img.nii.gz, E:\sliced_output\train\subject_0_slice_100_lbl.nii.gz
Saved: E:\sliced_output\train\subject_0_slice_115_img.nii.gz, E:\sliced_output\train\subject_0_slice_115_lbl.nii.gz
Saved: E:\sliced_output\train\subject_0_slice_134_img.nii.gz, E:\sliced_outp

In [79]:
# Suppose you have something like:
image_set = 'val'
file_paths = glob(os.path.join(r'E:\sliced_output', image_set, "*.nii.gz"))
image_paths, label_paths = [], []
for path in file_paths:
    print(path.split("\\")[-1][-10:-7])
    # if path.split("\\")[-1][-10:-7] == 'lbl':
    #     label_paths.append(path)
#     else:
#         image_paths.append(path)

# image_paths, label_paths = sorted(image_paths), sorted(label_paths)
# slice_dicts = [
#     {"image": image_path, "label": label_path}
#     for image_path, label_path in zip(image_paths, label_paths)
# ]

img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl
img
lbl


In [78]:
if image_set == 'train':
    slice_transform = slice_transform_train
else:
    slice_transform = slice_transform_valid
    
class Slices2DDataset(Dataset):
    def __init__(self, data, transform=None):
        """
        data_dicts: list of {"image": <path>, "label": <path>}
        slice_transform: optional transform pipeline (e.g. LoadImaged, etc.)
        """
        self.data_dicts = data
        self.slice_transform = transform

    def __len__(self):
        return len(self.data_dicts)

    def __getitem__(self, idx):
        d = self.data_dicts[idx]
        # If you want to load the data and transform it:
        if self.slice_transform is not None:
            d = self.slice_transform(d)
        return d

In [74]:
for _, data_dict in enumerate(slice_dicts):
    # ---------------------------------------------------------
    # 1) Load & apply 3D transforms (e.g. LoadImaged, Spacingd, etc.)
    # ---------------------------------------------------------
    img_nii = nib.load(data_dict['image'])
    processed_3d = slice_transform(data_dict)


In [76]:
dataset = Slices2DDataset(data=slice_dicts, transform=slice_transform)
from torch.utils.data import DataLoader
valid_loader = DataLoader(
    dataset,
    32,
    False,
    num_workers=0,
    pin_memory=True,
)

In [77]:
total_step = len(valid_loader)
train_iterator = iter(valid_loader)

for step in range(total_step):
    data_dict = next(train_iterator)
    samples = data_dict["image"]
    print(samples.shape)  # should be [B, C, H, W]

{'image': 'E:\\sliced_output\\val\\subject_0_slice_101_img.nii.gz', 'label': 'E:\\sliced_output\\val\\subject_0_slice_101_lbl.nii.gz'}
{'image': 'E:\\sliced_output\\val\\subject_0_slice_116_img.nii.gz', 'label': 'E:\\sliced_output\\val\\subject_0_slice_116_lbl.nii.gz'}
{'image': 'E:\\sliced_output\\val\\subject_0_slice_138_img.nii.gz', 'label': 'E:\\sliced_output\\val\\subject_0_slice_138_lbl.nii.gz'}
{'image': 'E:\\sliced_output\\val\\subject_0_slice_143_img.nii.gz', 'label': 'E:\\sliced_output\\val\\subject_0_slice_143_lbl.nii.gz'}
{'image': 'E:\\sliced_output\\val\\subject_0_slice_185_img.nii.gz', 'label': 'E:\\sliced_output\\val\\subject_0_slice_185_lbl.nii.gz'}
{'image': 'E:\\sliced_output\\val\\subject_0_slice_220_img.nii.gz', 'label': 'E:\\sliced_output\\val\\subject_0_slice_220_lbl.nii.gz'}
{'image': 'E:\\sliced_output\\val\\subject_0_slice_242_img.nii.gz', 'label': 'E:\\sliced_output\\val\\subject_0_slice_242_lbl.nii.gz'}
{'image': 'E:\\sliced_output\\val\\subject_0_slice_255_