In [1]:
import jupyter_black

# Turn on Black formatting for the cells of this notebook, with the line
# length limit set to 79 characters.
jupyter_black.load(line_length=79)

In [2]:
import SimpleITK as sitk
import os
import numpy as np
import shutil
import random

In [3]:
# Segmentation filename -> integer label written into the combined mask.
# Labels are assigned consecutively from 1 in the order listed here; 0 is
# left for voxels that belong to none of these structures.
classes = {
    segment_file: label
    for label, segment_file in enumerate(
        (
            "aorta.nii.gz",
            "heart_atrium_left.nii.gz",
            "heart_atrium_right.nii.gz",
            "heart_myocardium.nii.gz",
            "heart_ventricle_left.nii.gz",
            "heart_ventricle_right.nii.gz",
        ),
        start=1,
    )
}

In [4]:
def _build_total_mask(reader, scan_folder, scan_id, shape):
    """Merge the per-class segmentation files of one scan into a label volume.

    Returns an array of ``shape`` holding the label from ``classes`` at every
    voxel covered by the corresponding segmentation, or ``None`` (after
    printing the reason) when a segmentation file is missing or empty.
    """
    total_mask = np.zeros(shape)
    for key, label in classes.items():
        segment_filename = os.path.join(scan_folder, "segmentations", key)
        if not os.path.exists(segment_filename):
            print("segment file missing", scan_id)
            return None
        reader.SetFileName(segment_filename)
        # Read and decode each segmentation exactly once.
        segment = sitk.GetArrayFromImage(reader.Execute())
        if segment.max() == 0:
            print(f"segment file {key} is all zeros for {scan_id}")
            return None
        # Assign rather than accumulate: summing would produce a wrong or
        # out-of-range label wherever two masks overlap or a mask is not
        # strictly 0/1. On overlap the class listed later in `classes` wins.
        total_mask[segment > 0] = label
    return total_mask


def crop_data(
    starting_index: int = 1,
    hide_one_class_at_random: bool = False,
    output_folder: str = "D_0",
    max_scans: int = 100,
) -> int:
    """Build a cropped dataset of CT scans and combined label masks.

    Scans are taken from ``$HOME/data`` in sorted order, beginning at
    ``starting_index``. For each scan, ``ct.nii.gz`` and every segmentation
    listed in ``classes`` are loaded; the scan is skipped (with a printed
    message) if the CT cannot be read or any segmentation is missing or
    empty. Kept scans are cropped to the bounding box of the combined mask
    and written to ``<output_folder>/imagesTr`` and
    ``<output_folder>/labelsTr``.

    ``output_folder`` is deleted and recreated on every call.

    Args:
        starting_index: Position in the sorted directory listing at which to
            start. The default of 1 skips the first entry (presumably a
            non-scan file that sorts before the scan folders -- TODO confirm).
        hide_one_class_at_random: If true, one label drawn with
            ``random.choice`` is set to 0 in each written mask.
        output_folder: Destination directory for the dataset.
        max_scans: Stop after this many scans have been written.

    Returns:
        Index (into the sorted listing) of the last entry examined, so that
        the next call can continue from ``returned_value + 1``. When there is
        nothing left to examine, ``starting_index - 1`` is returned.

    Raises:
        ValueError: If ``max_scans`` is smaller than 1.
    """
    if max_scans < 1:
        # Checked before touching the filesystem so a bad argument cannot
        # wipe an existing output folder.
        raise ValueError("max_scans must be at least 1")

    reader = sitk.ImageFileReader()
    reader.SetImageIO("NiftiImageIO")
    writer = sitk.ImageFileWriter()

    images_folder = os.path.join(output_folder, "imagesTr")
    labels_folder = os.path.join(output_folder, "labelsTr")
    shutil.rmtree(output_folder, ignore_errors=True)
    os.makedirs(images_folder)
    os.makedirs(labels_folder)

    # List the same directory the scans are read from, so the function does
    # not depend on the current working directory.
    data_folder = os.path.join(os.environ["HOME"], "data")
    sorted_scans = sorted(os.listdir(data_folder))
    labels = list(classes.values())

    # Defined up front so an empty slice still yields a valid return value.
    last_index = starting_index - 1
    written = 0
    for last_index, scan_id in enumerate(
        sorted_scans[starting_index:], start=starting_index
    ):
        scan_folder = os.path.join(data_folder, scan_id)
        reader.SetFileName(os.path.join(scan_folder, "ct.nii.gz"))
        try:
            full_scan = sitk.GetArrayFromImage(reader.Execute())
        except Exception:
            # Skip unreadable scans, but let KeyboardInterrupt/SystemExit
            # through so a long run can still be stopped.
            print("bad scan", scan_id)
            continue

        total_mask = _build_total_mask(
            reader, scan_folder, scan_id, full_scan.shape
        )
        if total_mask is None:
            continue

        # Tight bounding box around all labelled voxels, one slice per axis.
        bounding_box = tuple(
            slice(axis_indices.min(), axis_indices.max() + 1)
            for axis_indices in np.where(total_mask > 0)
        )
        cropped_mask = total_mask[bounding_box]
        cropped_scan = full_scan[bounding_box]

        if hide_one_class_at_random:
            hidden_class = random.choice(labels)
            cropped_mask[cropped_mask == hidden_class] = 0

        # NOTE(review): GetImageFromArray builds an image with default
        # spacing/origin/direction; the geometry of the source scan is not
        # carried over to the written files. Confirm this is intended.
        case_name = f"la_{scan_id[1:]}"
        writer.SetFileName(os.path.join(labels_folder, f"{case_name}.nii.gz"))
        writer.Execute(sitk.GetImageFromArray(cropped_mask))
        writer.SetFileName(
            os.path.join(images_folder, f"{case_name}_0000.nii.gz")
        )
        writer.Execute(sitk.GetImageFromArray(cropped_scan))

        written += 1
        if written >= max_scans:
            break
    return last_index

In [5]:
# First split: written to D_0 with every class kept in the masks. The
# returned index is where the next split continues from.
last_scan = crop_data(
    starting_index=1,
    hide_one_class_at_random=False,
    output_folder="D_0",
)

bad scan s0000
bad scan s0001
bad scan s0002
segment file aorta.nii.gz is all zeros for s0003
bad scan s0004
segment file heart_atrium_left.nii.gz is all zeros for s0006
segment file heart_atrium_left.nii.gz is all zeros for s0009
segment file heart_atrium_left.nii.gz is all zeros for s0022
bad scan s0025
segment file heart_atrium_left.nii.gz is all zeros for s0034
segment file heart_ventricle_left.nii.gz is all zeros for s0035
segment file aorta.nii.gz is all zeros for s0036
bad scan s0043
segment file heart_atrium_left.nii.gz is all zeros for s0044
bad scan s0048
segment file aorta.nii.gz is all zeros for s0056
bad scan s0061
bad scan s0062
segment file heart_atrium_left.nii.gz is all zeros for s0063
bad scan s0066
bad scan s0067
segment file heart_atrium_left.nii.gz is all zeros for s0068
segment file heart_atrium_left.nii.gz is all zeros for s0069
segment file heart_atrium_left.nii.gz is all zeros for s0073
segment file heart_atrium_left.nii.gz is all zeros for s0074
bad scan s0079

In [6]:
# Second split: continues after the last scan examined for the previous
# split and hides one randomly chosen class in each mask.
last_scan = crop_data(
    starting_index=last_scan + 1,
    hide_one_class_at_random=True,
    output_folder="D_1",
)

bad scan s0226
segment file heart_atrium_left.nii.gz is all zeros for s0228
segment file aorta.nii.gz is all zeros for s0229
segment file heart_atrium_left.nii.gz is all zeros for s0233
segment file aorta.nii.gz is all zeros for s0234
bad scan s0235
bad scan s0236
segment file heart_atrium_left.nii.gz is all zeros for s0237
bad scan s0242
segment file heart_atrium_left.nii.gz is all zeros for s0246
bad scan s0254
segment file aorta.nii.gz is all zeros for s0259
segment file heart_atrium_left.nii.gz is all zeros for s0261
segment file heart_atrium_right.nii.gz is all zeros for s0263
segment file heart_atrium_left.nii.gz is all zeros for s0265
segment file heart_atrium_left.nii.gz is all zeros for s0277
bad scan s0278
segment file heart_atrium_left.nii.gz is all zeros for s0279
segment file heart_atrium_left.nii.gz is all zeros for s0281
bad scan s0282
segment file heart_atrium_left.nii.gz is all zeros for s0283
segment file heart_myocardium.nii.gz is all zeros for s0286
segment file hea

In [7]:
# Third split: continues after the previous split, with every class kept,
# written to D_val.
last_scan = crop_data(
    starting_index=last_scan + 1,
    hide_one_class_at_random=False,
    output_folder="D_val",
)

segment file aorta.nii.gz is all zeros for s0410
segment file heart_atrium_left.nii.gz is all zeros for s0411
bad scan s0417
segment file aorta.nii.gz is all zeros for s0418
segment file heart_atrium_left.nii.gz is all zeros for s0419
segment file heart_atrium_left.nii.gz is all zeros for s0426
bad scan s0436
segment file heart_atrium_left.nii.gz is all zeros for s0442
segment file aorta.nii.gz is all zeros for s0443
segment file aorta.nii.gz is all zeros for s0449
segment file heart_atrium_left.nii.gz is all zeros for s0450
segment file aorta.nii.gz is all zeros for s0453
segment file heart_atrium_left.nii.gz is all zeros for s0454
bad scan s0457
segment file aorta.nii.gz is all zeros for s0460
bad scan s0462
segment file aorta.nii.gz is all zeros for s0466
segment file aorta.nii.gz is all zeros for s0474
segment file heart_atrium_left.nii.gz is all zeros for s0475
segment file heart_atrium_left.nii.gz is all zeros for s0478
segment file heart_myocardium.nii.gz is all zeros for s0491


In [8]:
from tqdm import tqdm

# Notebook-level reader shared by the label checks in the following cells.
reader = sitk.ImageFileReader()
# Select the NIfTI image IO explicitly for every file this reader loads.
reader.SetImageIO("NiftiImageIO")

In [9]:
# Every D_0 label volume must contain background (0) and all classes.
expected_labels = {0, *classes.values()}
for filename in tqdm(os.listdir("D_0/labelsTr")):
    reader.SetFileName(os.path.join("D_0/labelsTr", filename))
    scan = reader.Execute()
    # Convert the image once and collect its distinct values, instead of
    # re-converting it for every label that is tested.
    present_labels = set(np.unique(sitk.GetArrayFromImage(scan)).tolist())
    missing_labels = expected_labels - present_labels
    assert (
        not missing_labels
    ), f"{filename}: missing labels {sorted(missing_labels)}"

100%|██████████| 100/100 [00:05<00:00, 16.74it/s]


In [10]:
# Every D_1 label volume must contain all foreground classes except exactly
# one (the class hidden when the split was written).
foreground_labels = set(classes.values())
for filename in tqdm(os.listdir("D_1/labelsTr")):
    reader.SetFileName(os.path.join("D_1/labelsTr", filename))
    scan = reader.Execute()
    # Convert the image once and collect its distinct values, instead of
    # re-converting it for every label that is tested.
    present_labels = set(np.unique(sitk.GetArrayFromImage(scan)).tolist())
    visible_labels = foreground_labels & present_labels
    assert len(visible_labels) == len(foreground_labels) - 1, (
        f"{filename}: expected exactly one hidden class, "
        f"found labels {sorted(visible_labels)}"
    )

100%|██████████| 100/100 [00:05<00:00, 17.57it/s]


In [11]:
# Every D_val label volume must contain background (0) and all classes.
expected_labels = {0, *classes.values()}
for filename in tqdm(os.listdir("D_val/labelsTr")):
    reader.SetFileName(os.path.join("D_val/labelsTr", filename))
    scan = reader.Execute()
    # Convert the image once and collect its distinct values, instead of
    # re-converting it for every label that is tested.
    present_labels = set(np.unique(sitk.GetArrayFromImage(scan)).tolist())
    missing_labels = expected_labels - present_labels
    assert (
        not missing_labels
    ), f"{filename}: missing labels {sorted(missing_labels)}"

100%|██████████| 100/100 [00:06<00:00, 14.99it/s]
