## Split train and valid

In [1]:
import os
import glob

In [2]:
base = 'data'

task_name = 'Task500_ATLAS'
target_base = os.path.join(base, task_name)
imagesTr_path = os.path.join(target_base, "imagesTr")
labelsTr_path = os.path.join(target_base, "labelsTr")

In [3]:
from collections import Counter

train_pathes = sorted(glob.glob(os.path.join(imagesTr_path, '*.nii.gz')))

prefix_list = []
for train_path in train_pathes:
    prefix = train_path.split('\\')[-1][:6]
    prefix_list.append(prefix)
    
cnt_prefix_list = Counter(prefix_list)     
train_cnt_list = list(cnt_prefix_list.values())
print("train 데이터의 총합: ", sum(train_cnt_list))
print("prefix 별 train 개수: ", train_cnt_list)


train 데이터의 총합:  654
prefix 별 train 개수:  [38, 12, 15, 36, 18, 111, 26, 29, 5, 12, 8, 6, 7, 8, 11, 16, 13, 5, 37, 24, 8, 49, 2, 45, 18, 2, 2, 7, 25, 23, 12, 7, 17]


In [4]:
train_pathes

['data\\Task500_ATLAS\\imagesTr\\sub001_001_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_002_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_003_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_004_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_005_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_006_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_007_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_008_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_009_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_010_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_011_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_012_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_013_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_014_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_015_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_016_0000.nii.gz',
 'data\\Task500_ATLAS\\imagesTr\\sub001_017_0000.nii.gz',
 'data\\Task50

In [5]:
val_cnt_list = []
for train_cnt in train_cnt_list:
    val_ratio = train_cnt // 10 
    if val_ratio == 0:
        val_cnt_list.append(1)
    else:
        val_cnt_list.append(val_ratio)

print("valid 데이터의 총합: ", sum(val_cnt_list))
print("prefix 별 valid 개수: ", val_cnt_list)

valid 데이터의 총합:  60
prefix 별 valid 개수:  [3, 1, 1, 3, 1, 11, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 2, 1, 4, 1, 4, 1, 1, 1, 1, 2, 2, 1, 1, 1]


In [6]:
def flag_train_val(count, val_cnt):
    if count <= val_cnt:
        return 'valid'
    return 'train'

In [7]:
train_pathes = sorted(glob.glob(os.path.join(imagesTr_path, '*.nii.gz')))

target_imagesTr = []
target_imagesVal = []
count = 0
total_train_cnt = 0
current_idx = 0
for train_path  in train_pathes:
    name = train_path.split('\\')[-1][:-7]
    prefix_idx = int(train_path.split('\\')[-1][:6][-1]) - 1
    count += 1

    flag = flag_train_val(count, val_cnt_list[prefix_idx])
    
    if flag == 'train':
        target_imagesTr.append(name)
    else:
        target_imagesVal.append(name)
    
    if current_idx != prefix_idx:
        count = 0
        current_idx = prefix_idx

In [8]:
print("imagesTr 개수: ", len(target_imagesTr))
print("imagesVal 개수: ", len(target_imagesVal))
print("총합: ", len(target_imagesTr) + len(target_imagesVal) )

imagesTr 개수:  594
imagesVal 개수:  60
총합:  654


## Make json file

In [9]:
import json
from typing import List

from typing import Tuple
import numpy as np

In [10]:
def subfiles(folder: str, join: bool = True, prefix: str = None, suffix: str = None, sort: bool = True) -> List[str]:
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
           and (prefix is None or i.startswith(prefix))
           and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res

def save_json(obj, file: str, indent: int = 4, sort_keys: bool = True) -> None:
    with open(file, 'w') as f:
        json.dump(obj, f, sort_keys=sort_keys, indent=indent)

def get_identifiers_from_splitted_files(folder: str):
    uniques = np.unique([i[:-7] for i in subfiles(folder, suffix='.nii.gz', join=False)]) # edit i[:-12] to i[:-7]
    return uniques

def generate_dataset_json(output_file: str, imagesTr_dir: List, imagesVal_dir: List, modalities: Tuple,
                          labels: dict, dataset_name: str, sort_keys=True, license: str = "hands off!", dataset_description: str = "",
                          dataset_reference="", dataset_release='0.0'):
    """
    :param output_file: This needs to be the full path to the dataset.json you intend to write, so
    output_file='DATASET_PATH/dataset.json' where the folder DATASET_PATH points to is the one with the
    imagesTr and labelsTr subfolders
    :param imagesTr_dir: path to the imagesTr folder of that dataset
    :param imagesTs_dir: path to the imagesTs folder of that dataset. Can be None
    :param modalities: tuple of strings with modality names. must be in the same order as the images (first entry
    corresponds to _0000.nii.gz, etc). Example: ('T1', 'T2', 'FLAIR').
    :param labels: dict with int->str (key->value) mapping the label IDs to label names. Note that 0 is always
    supposed to be background! Example: {0: 'background', 1: 'edema', 2: 'enhancing tumor'}
    :param dataset_name: The name of the dataset. Can be anything you want
    :param sort_keys: In order to sort or not, the keys in dataset.json
    :param license:
    :param dataset_description:
    :param dataset_reference: website of the dataset, if available
    :param dataset_release:
    :return:
    """
    # train_identifiers = get_identifiers_from_splitted_files(imagesTr_dir)

    # if imagesTs_dir is not None:
    #     test_identifiers = get_identifiers_from_splitted_files(imagesTs_dir)
    # else:
    #     test_identifiers = []

    json_dict = {}
    json_dict['name'] = dataset_name
    json_dict['description'] = dataset_description
    json_dict['tensorImageSize'] = "4D"
    json_dict['reference'] = dataset_reference
    json_dict['licence'] = license
    json_dict['release'] = dataset_release
    json_dict['modality'] = {str(i): modalities[i] for i in range(len(modalities))}
    json_dict['labels'] = {str(i): labels[i] for i in labels.keys()}

    json_dict['numTraining'] = len(imagesTr_dir)
    json_dict['numTest'] = len(imagesVal_dir)
    json_dict['training'] = [
        {'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i
        in
        imagesTr_dir]
    json_dict['validation'] = [
        {'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i
        in
        imagesVal_dir]

    if not output_file.endswith("dataset.json"):
        print("WARNING: output file name is not dataset.json! This may be intentional or not. You decide. "
              "Proceeding anyways...")
    save_json(json_dict, os.path.join(output_file), sort_keys=sort_keys)

In [11]:
base = 'data'

task_name = 'Task500_ATLAS'

In [12]:
generate_dataset_json(
    output_file=os.path.join(target_base, 'dataset.json'),
    imagesTr_dir=target_imagesTr,
    imagesVal_dir=target_imagesVal,
    modalities=('T1',),
    labels={0: 'background', 1: 'Stroke Lesion'},
    dataset_name=task_name,
    license='hands off!'
)

: 