In [34]:
from batchgenerators.utilities.file_and_folder_operations import load_json
import os

In [35]:
# Raw IXI directories used to match subject ids and to recover each subject's
# scanner/site (IOP, HH or Guys) from its original filename. That name is lost
# once cases are renamed to `<id>_0000.nii.gz` for nnU-Net.
t1_dir, vessel_dir = (
    "/gscratch/nrdg/asagil/raw_data/t1",
    "/gscratch/nrdg/asagil/raw_data/vessel",
)

def get_matched_ids(dirs, split_char="-"):
    """
    Return the subject ids that have at least one file in every directory of ``dirs``.

    A file's id is the part of its name before the first ``split_char``
    (e.g. ``"IXI002-Guys-0828-T1.nii.gz"`` -> ``"IXI002"`` with the default ``"-"``).

    Parameters
    ----------
    dirs : iterable of str
        Directories to scan (non-recursively, via ``os.listdir``).
    split_char : str
        Separator that terminates the id prefix of each filename.

    Returns
    -------
    list of str
        The ids common to all directories, sorted ascending. Empty if ``dirs``
        is empty.
    """
    id_sets = [
        {file_name.split(split_char)[0] for file_name in os.listdir(directory)}
        for directory in dirs
    ]
    if not id_sets:
        # set.intersection() needs at least one set; no directories -> no ids.
        return []
    return sorted(set.intersection(*id_sets))


def get_filepath_list_from_id(dir, id):
    """
    Return the full paths of all files in ``dir`` whose name contains ``id``.

    Matching is a plain substring test, so an id that is a prefix of another
    id (e.g. ``"IXI01"`` vs ``"IXI012"``) would match both. The fixed-width
    IXI ids avoid this, but other naming schemes may not.

    Paths are returned sorted, so the result does not depend on the arbitrary
    order of ``os.listdir``.
    """
    return [
        os.path.join(dir, file_name)
        for file_name in sorted(os.listdir(dir))
        if id in file_name
    ]

def get_filename_list_from_id(dir, id):
    """
    Return the names (not full paths) of all files in ``dir`` whose name contains ``id``.

    Matching is a plain substring test, so an id that is a prefix of another
    id would match both. The fixed-width IXI ids avoid this.

    Names are returned sorted, so the result does not depend on the arbitrary
    order of ``os.listdir``.
    """
    return sorted(file_name for file_name in os.listdir(dir) if id in file_name)

In [36]:
# all ids present in both the t1 and vessel dirs (sorted)
ids = get_matched_ids([t1_dir, vessel_dir])

# now we group the ids by the three scanners (sites), read from the t1 filename
IOP_ids = []
HH_ids = []
Guys_ids = []
# ids that could not be assigned to a scanner; reported instead of silently dropped
unassigned_ids = []

for sub_id in ids:
    file_names = get_filename_list_from_id(t1_dir, sub_id)
    if len(file_names) != 1:
        print(f"ID {sub_id} has {len(file_names)} files in t1 dir")
        unassigned_ids.append(sub_id)
        continue

    file_name = file_names[0]
    if "IOP" in file_name:
        IOP_ids.append(sub_id)
    elif "HH" in file_name:
        HH_ids.append(sub_id)
    elif "Guys" in file_name:
        Guys_ids.append(sub_id)
    else:
        print(f"ID {sub_id} has no known scanner in t1 file {file_name}")
        unassigned_ids.append(sub_id)

print(f"IOP: {len(IOP_ids)}")
print(f"HH: {len(HH_ids)}")
print(f"Guys: {len(Guys_ids)}")
print(f"Unassigned: {len(unassigned_ids)}")

# These two totals should agree; a mismatch means some ids were left unassigned.
print(f"Total (assigned to a scanner): {len(IOP_ids) + len(HH_ids) + len(Guys_ids)}")
print(f"Total (matched ids): {len(ids)}")


IOP: 65
HH: 176
Guys: 293
Total: 534
Total: 534


In [37]:
import random

def get_train_test(ids, test_percent=0.2, rng=None):
    """
    Randomly hold out a fraction of ``ids`` as a test set.

    Parameters
    ----------
    ids : list
        Subject ids (must be hashable). Not modified.
    test_percent : float
        Fraction of ids to put in the test set; the count is truncated with ``int``.
    rng : random.Random, optional
        Source of randomness. Defaults to the global ``random`` module, so seeding
        with ``random.seed`` makes the split reproducible.

    Returns
    -------
    (train_ids, test_ids)
        ``train_ids`` keeps the input order. ``test_ids`` is in sampled order.
    """
    if rng is None:
        rng = random
    num_test = int(len(ids) * test_percent)
    test_ids = rng.sample(ids, num_test)
    # A set gives O(1) membership tests; scanning the list was O(n^2) overall.
    test_set = set(test_ids)
    train_ids = [sub_id for sub_id in ids if sub_id not in test_set]
    return train_ids, test_ids

# Seed the global RNG so the held-out test split is reproducible on re-run.
# The fold shuffling in the next cell also draws from `random`, so it is
# reproducible too when the cells run in order.
# NOTE: splits written by earlier unseeded runs will differ from these.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)

IOP_train_ids, IOP_test_ids = get_train_test(IOP_ids)
HH_train_ids, HH_test_ids = get_train_test(HH_ids)
Guys_train_ids, Guys_test_ids = get_train_test(Guys_ids)

assert(len(IOP_train_ids) + len(IOP_test_ids) == len(IOP_ids))
assert(len(HH_train_ids) + len(HH_test_ids) == len(HH_ids))
assert(len(Guys_train_ids) + len(Guys_test_ids) == len(Guys_ids))

## we now have our train and test ids for each scanner

In [38]:
def get_5_folds(ids, n_folds=5, rng=None):
    """
    Shuffle ``ids`` and split them into ``n_folds`` contiguous, near-equal folds.

    The input list is NOT modified; a shuffled copy is split instead. The
    permutation is the same one an in-place ``shuffle`` would produce.

    Parameters
    ----------
    ids : sequence
        Ids to distribute across folds.
    n_folds : int
        Number of folds (default 5). Must be >= 1.
    rng : random.Random, optional
        Source of randomness. Defaults to the global ``random`` module.

    Returns
    -------
    list of list
        ``n_folds`` folds. The first ``len(ids) % n_folds`` folds each hold one
        extra id.

    Raises
    ------
    ValueError
        If ``n_folds`` < 1.
    """
    if n_folds < 1:
        raise ValueError(f"n_folds must be >= 1, got {n_folds}")
    if rng is None:
        rng = random

    shuffled = list(ids)
    rng.shuffle(shuffled)
    fold_size, remainder = divmod(len(shuffled), n_folds)

    folds = []
    start = 0
    for fold_idx in range(n_folds):
        end = start + fold_size + (1 if fold_idx < remainder else 0)
        folds.append(shuffled[start:end])
        start = end

    return folds

# Five folds per scanner, drawn from that scanner's training ids.
IOP_folds, HH_folds, Guys_folds = (
    get_5_folds(train_ids)
    for train_ids in (IOP_train_ids, HH_train_ids, Guys_train_ids)
)

# Each scanner yields exactly five folds...
assert all(len(folds) == 5 for folds in (IOP_folds, HH_folds, Guys_folds))
# ...which together hold as many ids as that scanner's training pool.
assert all(
    sum(map(len, folds)) == len(train_ids)
    for folds, train_ids in (
        (IOP_folds, IOP_train_ids),
        (HH_folds, HH_train_ids),
        (Guys_folds, Guys_train_ids),
    )
)

In [39]:
import json

# Build nnU-Net's custom cross-validation splits. Fold i validates on the i-th
# fold of every scanner and trains on the other four, so every fold keeps the
# per-scanner proportions of the training pool.
scanner_folds = (IOP_folds, HH_folds, Guys_folds)
final_out = []
for fold_idx in range(5):
    fold_val = [sub_id for folds in scanner_folds for sub_id in folds[fold_idx]]
    fold_train = [
        sub_id
        for folds in scanner_folds
        for other_idx, fold in enumerate(folds)
        if other_idx != fold_idx
        for sub_id in fold
    ]
    final_out.append({"val": fold_val, "train": fold_train})

# Save final_out as JSON, once, after all folds are built.
with open('splits_final_manual.json', 'w') as f:
    json.dump(final_out, f, indent=4)


final_test = IOP_test_ids + HH_test_ids + Guys_test_ids
final_test_set = set(final_test)  # O(1) membership instead of a list scan per id
final_train_val = [sub_id for sub_id in ids if sub_id not in final_test_set]

assert(len(final_test) + len(final_train_val) == len(ids))
# Every train/val case must be validated in exactly one fold. Otherwise ids that
# were not assigned to a scanner would land in imagesTr but in no CV split.
assert sorted(sub_id for fold in final_out for sub_id in fold["val"]) == sorted(final_train_val), \
    "CV folds do not cover final_train_val exactly once"
print(f'Test: {len(final_test)}')
print(f'Train + Val: {len(final_train_val)}')

# Save final_test as JSON
with open('test_ids.json', 'w') as f:
    json.dump(final_test, f, indent=4)

# Save final_train_val as JSON
with open('train_val_ids.json', 'w') as f:
    json.dump(final_train_val, f, indent=4)

Test: 106
Train + Val: 428


In [40]:
# All splits are done; now set up nnU-Net.

# Set the nnU-Net environment variables. They must be in place before
# `nnunetv2.paths` is imported below, because that module reads them at import time.
os.environ.update({
    "nnUNet_raw": "/gscratch/nrdg/asagil/nnUNET_data/nnunet_raw",
    "nnUNet_preprocessed": "/gscratch/nrdg/asagil/nnUNET_data/nnunet_preprocessed",
    "nnUNet_results": "/gscratch/nrdg/asagil/nnUNET_data/nnunet_results",
})

from nnunetv2.paths import nnUNet_raw
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from tqdm import tqdm
import SimpleITK as sitk
import numpy as np
import shutil

In [41]:


# Source directories for this dataset: T1 images and costa vessel labels.
IXI_t1_dir = "/gscratch/nrdg/asagil/raw_data/t1"
IXI_vessel_dir = "/gscratch/nrdg/asagil/raw_data/costa"

task_id = 95
task_name = "IXI-costa-even-split"

# nnU-Net v2 dataset folders are named Dataset<3-digit zero-padded id>_<name>.
foldername = f"Dataset{task_id:03d}_{task_name}"

# setting up nnU-Net folders
print(f'raw data folder: {nnUNet_raw}')

out_base = join(nnUNet_raw, foldername)
imagestr, imagests, labelstr, labelsts = (
    join(out_base, sub_folder)
    for sub_folder in ("imagesTr", "imagesTs", "labelsTr", "labelsTs")
)

for folder in (imagestr, labelstr, imagests, labelsts):
    maybe_mkdir_p(folder)

raw data folder: /gscratch/nrdg/asagil/nnUNET_data/nnunet_raw


In [42]:
# NOTE(review): images/labels previously came from `t1_dir` / `vessel_dir`
# (raw_data/vessel), the directories meant for scanner lookup. Meanwhile
# `IXI_vessel_dir` (raw_data/costa) was defined for this "IXI-costa" dataset but
# never used. The IXI_* source dirs are used here; confirm costa is the intended
# label source.
def _copy_cases(sub_ids, src_image_dir, src_label_dir, image_dir, label_dir):
    """
    Copy one T1 image and one label per id into the nnU-Net raw layout.

    Images are written as ``<id>_0000.nii.gz`` (channel 0) and labels as
    ``<id>.nii.gz``. Ids without exactly one matching image and one matching
    label are reported and skipped.

    Returns the list of ids that were actually copied.
    """
    copied = []
    for sub_id in tqdm(sub_ids):
        t1_paths = get_filepath_list_from_id(src_image_dir, sub_id)
        label_paths = get_filepath_list_from_id(src_label_dir, sub_id)
        if len(t1_paths) != 1 or len(label_paths) != 1:
            print(f"ID {sub_id} has {len(t1_paths)} t1 files and {len(label_paths)} vessel files")
            continue
        # NOTE(review): the source files are assumed to already be .nii.gz; confirm.
        shutil.copy(t1_paths[0], join(image_dir, f"{sub_id}_0000.nii.gz"))
        shutil.copy(label_paths[0], join(label_dir, f"{sub_id}.nii.gz"))
        copied.append(sub_id)
    return copied


copied_test_ids = _copy_cases(final_test, IXI_t1_dir, IXI_vessel_dir, imagests, labelsts)
copied_train_ids = _copy_cases(final_train_val, IXI_t1_dir, IXI_vessel_dir, imagestr, labelstr)

generate_dataset_json(out_base,
                      channel_names={0: 'T1'},
                      labels={
                          'background': 0,
                          'vessel': 1,
                      },
                      # Count what actually landed in imagesTr; skipped ids are not there.
                      num_training_cases=len(copied_train_ids),
                      file_ending='.nii.gz',
                      # The old values pointed at the BraTS 2021 Synapse page (copied from
                      # an nnU-Net example). TODO: add the license of the costa labels.
                      license='IXI data: CC BY-SA 3.0, see https://brain-development.org/ixi-dataset/',
                      reference='IXI dataset, https://brain-development.org/ixi-dataset/',
                      dataset_release='1.0')


100%|██████████| 106/106 [00:13<00:00,  7.64it/s]
100%|██████████| 428/428 [02:10<00:00,  3.27it/s]
