In [None]:
import numpy as np
import glob
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import cv2

In [None]:
def get_idx(cnt):
    """Return ``cnt`` as text, left-padded with zeros to at least 4 characters.

    Used to build fixed-width slice indices for output file names
    (e.g. ``7`` -> ``'0007'``). Values longer than 4 characters are
    returned unchanged.
    """
    as_text = str(cnt)
    return as_text.zfill(4)


def adjust_gray(data_resampled, w_width, w_center):
    """Clamp intensities to the window ``w_center +/- w_width / 2``.

    Args:
        data_resampled: numpy array of intensities; it is not modified.
        w_width: full width of the intensity window.
        w_center: centre of the intensity window.

    Returns:
        A copy of ``data_resampled`` (same shape and dtype) in which values
        below the window are set to its lower bound and values above the
        window are set to its upper bound.
    """
    half_width = w_width / 2
    lower = w_center - half_width
    upper = w_center + half_width

    # Build both masks from the untouched input before writing anything.
    too_low = data_resampled < lower
    too_high = data_resampled > upper

    windowed = data_resampled.copy()
    windowed[too_low] = lower
    windowed[too_high] = upper
    return windowed

def select_label(labels):
    '''
    Keep four organ classes and renumber them; every other id becomes 0.

    original id -> new id
        liver    (6) -> 1
        R.kidney (2) -> 2
        L.kidney (3) -> 3
        Spleen   (1) -> 4
    Ids that are not listed (e.g. 4, 5, 7, 8, 9, 10, 11, 12, 13) map to
    background (0).

    Returns a new array with the same shape and dtype as ``labels``;
    the input is not modified.
    '''
    # (original id, new id) pairs; each lookup is done against the untouched
    # input, so the order of the pairs does not matter.
    remap = ((6, 1), (2, 2), (3, 3), (1, 4))

    new_labels = np.zeros_like(labels)
    for old_id, new_id in remap:
        new_labels[labels == old_id] = new_id
    return new_labels

def one_hot_encoder(input):
    """One-hot encode a label map with class ids 0..4.

    For each class id ``i`` in ``range(5)`` a mask ``input == i`` is built,
    and the five masks are joined along a new axis inserted at position 2,
    so a 2-D ``(H, W)`` label map becomes ``(H, W, 5)``.

    Returns:
        float32 array holding 1.0 where the label equals the channel's
        class id and 0.0 elsewhere. Labels outside 0..4 yield all zeros.
    """
    class_planes = [np.expand_dims(input == class_id, 2) for class_id in range(5)]
    stacked = np.concatenate(class_planes, axis=2)
    return stacked.astype(np.float32)

def select_label_mr(labels):
    '''
    Convert MR mask grey levels to class ids; everything else becomes 0.

    grey level (accepted range)        -> class id
        Liver:        63  (55<<<70)    -> 1
        Right kidney: 126 (110<<<135)  -> 2
        Left kidney:  189 (175<<<200)  -> 3
        Spleen:       252 (240<<<255)  -> 4

    Any grey level strictly inside a range is accepted, not only the nominal
    value, so masks whose grey levels are slightly off still map to the right
    organ. The four nominal values map exactly as before.

    Returns a new array with the same shape and dtype as ``labels``;
    the input is not modified.
    '''
    # (low, high, class id). Bounds are treated as exclusive, mirroring the
    # "low<<<high" notation above.
    # NOTE(review): confirm against the dataset description whether the
    # bounds are meant to be inclusive.
    grey_ranges = (
        (55, 70, 1),
        (110, 135, 2),
        (175, 200, 3),
        (240, 255, 4),
    )

    new_labels = np.zeros_like(labels)
    for low, high, class_id in grey_ranges:
        new_labels[(labels > low) & (labels < high)] = class_id
    return new_labels

In [None]:
# ct
# Split every CT volume into 2-D slices and save, per foreground slice, one
# .npy array holding the windowed image plus the one-hot encoded label.
file_path = 'Data/chaos/chao_ct_ori'
# sorted() only makes the processing order reproducible; the files written do
# not depend on it.
name_list = sorted(os.listdir(f'{file_path}/img'))

out_root = '/mnt/ExtData/Data/processed/chaos'
ct_train_dir = f'{out_root}/ct_train'
ct_val_dir = f'{out_root}/ct_val'
# np.save does not create missing directories.
os.makedirs(ct_train_dir, exist_ok=True)
os.makedirs(ct_val_dir, exist_ok=True)

# Volume ids held out for validation; every other volume goes to training.
ct_val_ids = (1, 6, 30, 32, 33, 39)

total_slice_cnt_train = 0
total_slice_cnt_val = 0

for name_i in name_list:
    # Expected name: img<4-digit id>.nii.gz. Skip anything else (hidden files,
    # other formats) instead of failing half-way through the run.
    case_id = name_i[3:7]
    if not (name_i.startswith('img') and name_i.endswith('.nii.gz') and case_id.isdigit()):
        continue

    imgs = sitk.GetArrayFromImage(sitk.ReadImage(f'{file_path}/img/img{case_id}.nii.gz'))
    labels = sitk.GetArrayFromImage(sitk.ReadImage(f'{file_path}/label/label{case_id}.nii.gz'))

    # The split is a property of the volume, so decide it once per volume.
    is_val = int(case_id) in ct_val_ids
    save_dir = ct_val_dir if is_val else ct_train_dir

    ##########################
    ## 1. label preprocess
    ## The original labels include 13 organs, see
    ## https://www.synapse.org/#!Synapse:syn3193805/wiki/217789
    ## Only liver (6), R.kidney (2), L.kidney (3) and spleen (1) are kept and
    ## renumbered to liver 1, R.kidney 2, L.kidney 3, spleen 4 (background 0).
    ## NOTE(review): the ids above come from the Synapse wiki; confirm they
    ## match the label volumes stored under file_path.
    labels = select_label(labels)

    ##########################
    ## 2. split into slices
    ## NOTE: no per-volume min/max normalisation is applied here; slices are
    ## saved with the windowed intensity values.
    slice_cnt = 0  # number of foreground slices written for this volume

    for i in range(imgs.shape[0]):
        label_i = labels[i, ...]

        # Keep only slices that contain foreground, i.e. more than one
        # distinct label value.
        if len(np.unique(label_i)) == 1:
            continue

        slice_cnt += 1
        slice_idx = get_idx(slice_cnt)

        # Clamp intensities to the window [-160, 240] (width 400, centre 40).
        img_i = adjust_gray(imgs[i, ...], 400, 40)

        # Flip image and label along axis 0 (rows).
        # NOTE(review): the reason for the flip is not documented here;
        # presumably it aligns orientation with another dataset - confirm.
        img_i = np.flip(img_i, axis=0)
        label_i = np.flip(label_i, axis=0)

        ## optional visual check
        # plt.figure(figsize=(15,10))
        # plt.subplot(121)
        # plt.imshow(img_i, 'gray')
        # plt.subplot(122)
        # plt.imshow(label_i, vmin=0, vmax=4)

        label_i_onehot = one_hot_encoder(label_i)

        # (H, W, 6): channel 0 is the image, channels 1-5 are the one-hot
        # labels (background + 4 organs), e.g. (512, 512, 6).
        all_i = np.concatenate((np.expand_dims(img_i, 2), label_i_onehot), 2)

        ######################
        # 3. save
        np.save(f'{save_dir}/img{case_id}_{slice_idx}.npy', all_i)

    if is_val:
        total_slice_cnt_val += slice_cnt
    else:
        total_slice_cnt_train += slice_cnt

print('total_slice_cnt', total_slice_cnt_train+total_slice_cnt_val, 'train', total_slice_cnt_train, 'val', total_slice_cnt_val)
        

In [None]:
# mr
# For every patient folder, pair each T2SPIR DICOM slice with its PNG mask and
# save, per foreground slice, one .npy array holding the intensity-capped image
# plus the one-hot encoded label.
file_path = '/Data/chaos_MR_ori'
# sorted() only makes the processing order reproducible; the files written do
# not depend on it.
name_list = sorted(os.listdir(f'{file_path}'))

out_root = '/mnt/ExtData/Data/processed/chaos'
mr_train_dir = f'{out_root}/mr_train'
mr_val_dir = f'{out_root}/mr_val'
# np.save does not create missing directories.
os.makedirs(mr_train_dir, exist_ok=True)
os.makedirs(mr_val_dir, exist_ok=True)

# Patient ids held out for validation; every other patient goes to training.
mr_val_ids = (1, 31, 32, 38)

total_slice_cnt_train = 0
total_slice_cnt_val = 0

for name_i in name_list:
    # Patient folders are identified by a numeric name (it is parsed with
    # int() below); skip stray entries instead of failing on them.
    if not name_i.isdigit():
        continue
    case_id = name_i

    # The split is a property of the patient, so decide it once per patient.
    is_val = int(case_id) in mr_val_ids
    save_dir = mr_val_dir if is_val else mr_train_dir

    slice_cnt = 0  # number of foreground slices written for this patient

    img_path = f'{file_path}/{case_id}/T2SPIR/DICOM_anon/'
    slices = sorted(os.listdir(img_path))

    for slice_i in slices:
        # Only DICOM files are slices; the 4-character '.dcm' suffix is
        # stripped to get the name shared with the PNG mask.
        if not slice_i.endswith('.dcm'):
            continue
        slice_idx = slice_i[:-4]

        label_path = f'{file_path}/{case_id}/T2SPIR/Ground/{slice_idx}.png'
        labels = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)  # (256, 256)
        if labels is None:
            # cv2.imread returns None instead of raising when the file is
            # missing or unreadable. The slice is still skipped, but visibly.
            print(f'warning: could not read label {label_path}; slice skipped')
            continue

        ##########################
        ## 1. label preprocess
        ## The PNG mask encodes organs as grey levels (liver 63,
        ## right kidney 126, left kidney 189, spleen 252).
        ## Keep only slices that contain foreground, i.e. whose raw mask has
        ## more than one distinct grey level.
        if len(np.unique(labels)) == 1:
            continue

        # Read the image only for slices that are kept.
        # The DICOM reader returns (1, H, W); drop the leading axis.
        imgs = sitk.GetArrayFromImage(sitk.ReadImage(f'{img_path}{slice_i}'))[0, ...]

        # Map grey levels to class ids: liver 1, R.kidney 2, L.kidney 3,
        # spleen 4, background 0.
        labels = select_label_mr(labels)
        slice_cnt += 1

        # Cap intensities at 1200.
        imgs[imgs > 1200] = 1200

        ## optional visual check
        # plt.figure(figsize=(15,10))
        # plt.subplot(121)
        # plt.imshow(imgs, 'gray')
        # plt.subplot(122)
        # plt.imshow(labels, vmin=0, vmax=4)

        label_i_onehot = one_hot_encoder(labels)

        # (H, W, 6): channel 0 is the image, channels 1-5 are the one-hot
        # labels (background + 4 organs), e.g. (256, 256, 6).
        all_i = np.concatenate((np.expand_dims(imgs, 2), label_i_onehot), 2)

        ########################
        # 3. save
        np.save(f'{save_dir}/{case_id}_{slice_idx}.npy', all_i)

    if is_val:
        total_slice_cnt_val += slice_cnt
    else:
        total_slice_cnt_train += slice_cnt

print('total_slice_cnt', total_slice_cnt_train+total_slice_cnt_val, 'train', total_slice_cnt_train, 'val', total_slice_cnt_val)
