In [1]:
import os
import nibabel as nib
import numpy as np
import tqdm

In [2]:
# --- Conversion settings -----------------------------------------------------
# Root folder of the AMOS22 data that is read by the cells below.
amos_path = "/local/scratch/clmn1/data/amos22"

align_orientations = False  # not implemented yet
modality = "ct"
# Either the string 'all' or a list of original label ids to keep.
classes = [1, 2, 3, 5, 6, 7, 8, 9]  # 'all'

# The task name depends on the modality when every class is kept; for a class
# subset it is "Task603_AMOS" followed by the concatenated class ids.
if classes != 'all':
    task_name = "Task603_AMOS" + ''.join(map(str, classes))
else:
    task_names_by_modality = {"all": "Task600_AMOSall", "mri": "Task601_AMOSmri", "ct": "Task602_AMOSct"}
    task_name = task_names_by_modality[modality]

# Output folder of this conversion. exist_ok=False makes this cell raise
# FileExistsError when the folder is already present.
nnunet_path = f"/local/scratch/clmn1/cardiacProstate/nnUnet_raw_data_base/nnUNet_raw_data/{task_name}"
os.makedirs(nnunet_path, exist_ok=False)

In [3]:
# Create the image and label sub-folders of the task directory.
for subfolder in ("imagesTr", "labelsTr"):
    os.makedirs(os.path.join(nnunet_path, subfolder), exist_ok=True)

In [4]:
# Case lists are taken from the label folders. os.listdir() returns entries in
# arbitrary order, so the lists are sorted to make the conversion (and the
# 'training' list written to dataset.json later) reproducible across runs.
cases_tr = sorted(os.listdir(os.path.join(amos_path, "labelsTr")))
cases_va = sorted(os.listdir(os.path.join(amos_path, "labelsVa")))

# Characters [5:9] of a file name are parsed as the numeric case id
# (presumably names look like "amos_0001.nii.gz" -- TODO confirm).
# Ids below 500 are kept for "ct", ids of 500 and above for "mri";
# any other modality value keeps every case.
if modality in ("mri", "ct"):
    def _matches_modality(file_name: str) -> bool:
        """Return True if the case id in ``file_name`` belongs to ``modality``."""
        is_mri_case = int(file_name[5:9]) >= 500
        return is_mri_case == (modality == "mri")

    cases_tr = [c for c in cases_tr if _matches_modality(c)]
    cases_va = [c for c in cases_va if _matches_modality(c)]

# Training and validation cases are merged into a single list.
all_cases = cases_tr + cases_va

In [5]:
# Sanity checks: the two splits must not share a case, and every listed file
# must be a compressed NIfTI file.
assert set(cases_tr).isdisjoint(set(cases_va)), "Training and validation cases overlap!"
for case in all_cases:
    assert case.endswith(".nii.gz"), f"Unexpected case name {case}, does not end with '.nii.gz'"

# Drop the trailing ".nii.gz" (7 characters) from every name in all three lists.
all_cases, cases_tr, cases_va = (
    [file_name[:-7] for file_name in case_list]
    for case_list in (all_cases, cases_tr, cases_va)
)

In [6]:
def copy_or_link(folder: str, case_name: str, label: bool):
    """Place one AMOS file into the nnU-Net task folder.

    Images are symlinked into ``imagesTr`` as ``<case_name>_0000.nii.gz``.
    Label maps go to ``labelsTr`` as ``<case_name>.nii.gz`` (no ``_0000``
    suffix): they are symlinked unchanged when ``classes == 'all'``; otherwise
    a new label file is written in which only the ids listed in ``classes``
    remain, renumbered 1..len(classes) in list order, and every other voxel
    is set to 0 (background).

    Args:
        folder: Sub-folder of ``amos_path`` that holds the source file.
        case_name: File name without the ".nii.gz" extension.
        label: True if the file is a segmentation, False if it is an image.

    Raises:
        FileExistsError: from ``os.symlink`` if the link target already exists.
        AssertionError: if a label map contains an id larger than 15.
    """
    source_path = os.path.join(amos_path, folder, case_name + ".nii.gz")
    target_dir = os.path.join(nnunet_path, "labelsTr" if label else "imagesTr")

    if not label:
        # Only images carry the "_0000" suffix.
        os.symlink(source_path, os.path.join(target_dir, case_name + "_0000.nii.gz"))
        return

    # Labels are always stored as <case_name>.nii.gz so that they match the
    # "./labelsTr/<case>.nii.gz" entries written to dataset.json.
    target_path = os.path.join(target_dir, case_name + ".nii.gz")

    if classes == 'all':
        # No remapping needed: link the original label map.
        os.symlink(source_path, target_path)
        return

    # This is a label and classes is not 'all', so the ids must be remapped.
    file = nib.load(source_path)
    label_map = file.get_fdata()
    assert label_map.max() <= 15, f"Unexpected label id {label_map.max()} in {source_path}"

    new_label = np.zeros_like(label_map, dtype=int)
    for new, old in enumerate(classes, start=1):
        new_label[label_map == old] = new

    # Affine and header of the source file are reused for the new label file.
    new_img = nib.Nifti1Image(new_label, file.affine, file.header)
    nib.save(new_img, target_path)

In [7]:
# Convert both AMOS splits: for every case the image is handled first, then its
# label map. One progress bar is shown per split.
for split_suffix, split_cases in (("Tr", cases_tr), ("Va", cases_va)):
    for case in tqdm.tqdm(split_cases):
        copy_or_link("images" + split_suffix, case, label=False)
        copy_or_link("labels" + split_suffix, case, label=True)

# Re-attach the file extension to the combined case list.
all_cases = [case_name + ".nii.gz" for case_name in all_cases]

100%|██████████| 200/200 [07:04<00:00,  2.12s/it]
100%|██████████| 100/100 [04:14<00:00,  2.55s/it]


In [8]:
# Write the dataset.json file of the task. Layout adapted from:
# https://github.com/MIC-DKFZ/nnUNet/blob/nnunetv1/nnunet/dataset_conversion/Task017_BeyondCranialVaultAbdominalOrganSegmentation.py
from batchgenerators.utilities.file_and_folder_operations import save_json
from collections import OrderedDict

json_dict = OrderedDict()
json_dict['name'] = "AMOS"
json_dict['description'] = "Amos: A large-scale abdominal multi-organ benchmark for versatile medical image segmentation"
json_dict['tensorImageSize'] = "3D"
json_dict['reference'] = ""
json_dict['licence'] = "see challenge website"
json_dict['release'] = "0.0"
# Single input channel; it is named "MRI" only for modality == "mri",
# every other modality value is written as "CT".
json_dict['modality'] = {
    "0": "MRI" if modality == "mri" else "CT",
}
# Names of the original AMOS label ids.
# NOTE(review): "arota" is presumably a misspelling of "aorta"; the string is
# left unchanged here -- confirm before renaming.
json_dict['labels'] = OrderedDict({
      "0":"background",
      "1":"spleen",
      "2":"right kidney",
      "3":"left kidney",
      "4":"gall bladder",
      "5":"esophagus",
      "6":"liver",
      "7":"stomach",
      "8":"arota",
      "9":"postcava",
      "10":"pancreas",
      "11":"right adrenal gland",
      "12":"left adrenal gland",
      "13":"duodenum",
      "14":"bladder",
      "15":"prostate/uterus"
    }
)
json_dict['numTraining'] = len(all_cases)
json_dict['numTest'] = 0
# all_cases holds file names including ".nii.gz" at this point; image and label
# entries use the same file name.
json_dict['training'] = [
    {'image': "./imagesTr/%s" % train_patient_name, "label": "./labelsTr/%s" % train_patient_name}
    for train_patient_name in all_cases
]
json_dict['test'] = []

if classes != 'all':
    # Reduce the label names to the selected classes. New ids are assigned with
    # enumerate(classes, start=1), i.e. in the order of `classes`, exactly as
    # the label maps are renumbered in copy_or_link -- so names and voxel
    # values stay in agreement even if `classes` is not sorted.
    original_labels = json_dict['labels']
    selected_labels = OrderedDict({"0": original_labels["0"]})
    for new_id, old_id in enumerate(classes, start=1):
        # An id without a name raises KeyError here instead of being dropped silently.
        selected_labels[str(new_id)] = original_labels[str(old_id)]
    json_dict['labels'] = selected_labels


save_json(json_dict, os.path.join(nnunet_path, "dataset.json"))

In [None]:
# Follow-up commands (kept as comments, not executed by this notebook):
#nnUNet_plan_and_preprocess -t 602
#nnUNet_plan_and_preprocess_ext -t 603 -copy_intensity_props_from 602