In [1]:
import os
from pathlib import Path

import numpy as np
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityRanged,
    Resized,
)

from src.transforms import (
    FindCentroid,
    GetFixedROISize,
    CropToROI,
    Pad,
    Save,
    OrthogonalSlices
)


def extract_3d_roi(
    segmentation_path,
    destination_path,
    roi_size=100,
    save_size=96,
    a_min=-1024,
    a_max=3000,
):
    """Crop an ROI from the scan matching a segmentation and save it.

    The scan is located with ``find_corresponding_scan``. The pipeline loads
    both volumes, moves the channel axis first, and then runs the project
    transforms ``FindCentroid`` -> ``GetFixedROISize(roi_size)`` -> ``Pad`` ->
    ``CropToROI``. These presumably centre a fixed-size region on the
    segmentation; see ``src.transforms`` to confirm. The pipeline then:

    - resizes the image to ``save_size`` voxels per axis,
    - clips its intensities to ``[a_min, a_max]`` and rescales them to
      ``[0, 1]``,
    - hands the result to ``Save``.

    Args:
        segmentation_path: Path (str or ``Path``) to the split segmentation.
        destination_path: Output directory passed to ``Save``.
        roi_size: Passed to ``GetFixedROISize``.
        save_size: Edge length of the resized image cube.
        a_min: Lower bound of the intensity window. The default -1024 is
            presumably Hounsfield units for CT; confirm.
        a_max: Upper bound of the intensity window.
    """
    pipeline = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            EnsureChannelFirstd(keys=["img", "seg"]),
            FindCentroid(),
            GetFixedROISize(roi_size),
            Pad(),
            CropToROI(),
            # Only the image is resized; the segmentation keeps its cropped size.
            Resized(keys=["img"], spatial_size=[save_size] * 3),
            ScaleIntensityRanged(
                keys=["img"], a_min=a_min, a_max=a_max, b_min=0, b_max=1, clip=True
            ),
            Save(output_dir=destination_path),
        ]
    )

    segmentation_path = Path(segmentation_path)
    scan_path = find_corresponding_scan(segmentation_path)
    data = {"img": str(scan_path), "seg": str(segmentation_path)}

    pipeline(data)


def preprocess_dir(
    segmentation_dir, destination_dir, roi_size=100, save_size=96, pattern="*.nii*"
):
    """Run ``extract_3d_roi`` on every segmentation file in ``segmentation_dir``.

    Args:
        segmentation_dir: Directory containing the split segmentation volumes.
        destination_dir: Output directory forwarded to ``extract_3d_roi``.
        roi_size: Forwarded to ``extract_3d_roi``.
        save_size: Forwarded to ``extract_3d_roi``.
        pattern: Glob pattern selecting segmentation files. The default
            matches ``.nii`` and ``.nii.gz``. Sub-directories are always
            skipped.
    """
    # Skip non-files and stray non-NIfTI entries, which would crash the
    # pipeline. Sort so the processing order does not depend on the
    # filesystem's directory listing order.
    segmentation_paths = sorted(
        path for path in Path(segmentation_dir).glob(pattern) if path.is_file()
    )
    for segmentation_path in segmentation_paths:
        extract_3d_roi(
            segmentation_path, destination_dir, roi_size=roi_size, save_size=save_size
        )


def find_corresponding_scan(segmention_path):
    """Return the path of the scan that a split segmentation belongs to.

    Split segmentations are expected to be named ``<scan_id>_<index>.<ext>``
    (e.g. ``PREM_AM_001_0.nii.gz``). The matching scan is
    ``<grandparent dir>/scans/<scan_id>.nii.gz``.

    Args:
        segmention_path: Path (str or ``Path``) to the split segmentation.
            The misspelt parameter name is kept for compatibility with
            keyword callers.

    Returns:
        ``Path`` to the corresponding scan. Existence is not checked.

    Raises:
        ValueError: If the file name does not end in ``_<digits>`` before the
            first ``.``.
    """
    segmentation_path = Path(segmention_path)
    scan_folder = segmentation_path.parent.parent / "scans"
    stem = segmentation_path.name.split(".")[0]
    # Strip the whole "_<index>" suffix instead of a fixed two characters,
    # so lesion indices >= 10 (e.g. "_12") map to the right scan as well.
    scan_id, separator, index = stem.rpartition("_")
    if not separator or not scan_id or not index.isdecimal():
        raise ValueError(
            f"Cannot derive scan name from {segmentation_path.name!r}: "
            "expected '<scan_id>_<index>.nii.gz'"
        )
    return scan_folder / f"{scan_id}.nii.gz"

# Example case used to inspect the pipeline interactively in the cells below.
# The default is an absolute local (Windows) path; set the environment variable
# PREMIUM_SEGMENTATION_PATH to run the notebook on another machine/case
# without editing this cell.
segmentation_path = os.environ.get(
    "PREMIUM_SEGMENTATION_PATH",
    r"D:\premium_data\amphia\monotherapy\split_segmentations\PREM_AM_001_0.nii.gz",
)
scan_path = find_corresponding_scan(segmentation_path)
data = {"img": str(scan_path), "seg": str(segmentation_path)}

In [2]:
# Same preprocessing as `extract_3d_roi`, except that the final step is
# `OrthogonalSlices` instead of `Save`, so the result stays in memory for
# inspection.
roi_size = 100
save_size = 96

inspection_steps = [
    LoadImaged(keys=["img", "seg"]),
    EnsureChannelFirstd(keys=["img", "seg"]),
    FindCentroid(),
    GetFixedROISize(roi_size),
    Pad(),
    CropToROI(),
    Resized(keys=["img"], spatial_size=[save_size, save_size, save_size]),
    ScaleIntensityRanged(
        keys=["img"],
        a_min=-1024,
        a_max=3000,
        b_min=0,
        b_max=1,
        clip=True,
    ),
    OrthogonalSlices(),
]
pipeline = Compose(inspection_steps)
output = pipeline(data)

In [30]:
# Middle index along each spatial axis (the leading channel axis is dropped).
center_slices = np.array(output['img'].shape[1:]) // 2


In [43]:
# Middle slice through each spatial axis of channel 0.
# NOTE(review): the anatomical names assume the spatial axes are ordered
# (sagittal, coronal, transverse); this depends on the volume orientation —
# confirm.
roi_image = output['img']
sagittal, coronal, transverse = (
    roi_image[0, center_slices[0]],
    roi_image[0, :, center_slices[1]],
    roi_image[0, :, :, center_slices[2]],
)

tensor([[[0.2322, 0.2322, 0.2333,  ..., 0.2795, 0.2826, 0.2826],
         [0.2289, 0.2289, 0.2307,  ..., 0.2802, 0.2818, 0.2818],
         [0.2270, 0.2270, 0.2292,  ..., 0.2862, 0.2857, 0.2857],
         ...,
         [0.0518, 0.0518, 0.0566,  ..., 0.1306, 0.1323, 0.1323],
         [0.0214, 0.0214, 0.0220,  ..., 0.0304, 0.0306, 0.0306],
         [0.0239, 0.0239, 0.0229,  ..., 0.0214, 0.0198, 0.0198]],

        [[0.2218, 0.2218, 0.2239,  ..., 0.2344, 0.2208, 0.2208],
         [0.2229, 0.2229, 0.2246,  ..., 0.2437, 0.2397, 0.2397],
         [0.2222, 0.2222, 0.2227,  ..., 0.2444, 0.2465, 0.2465],
         ...,
         [0.2273, 0.2273, 0.2250,  ..., 0.2547, 0.2506, 0.2506],
         [0.2249, 0.2249, 0.2230,  ..., 0.2539, 0.2517, 0.2517],
         [0.2294, 0.2294, 0.2271,  ..., 0.2567, 0.2529, 0.2529]],

        [[0.2218, 0.2218, 0.2239,  ..., 0.2344, 0.2208, 0.2208],
         [0.2229, 0.2229, 0.2246,  ..., 0.2437, 0.2397, 0.2397],
         [0.2222, 0.2222, 0.2227,  ..., 0.2444, 0.2465, 0.