In [None]:
import numpy as np
import pandas as pd
import nibabel as nib
import PIL.Image as Image
import os
import numpy as np
import glob
from tqdm import tqdm
import scipy.ndimage

# Root directory under which the image and mask folders are looked up.
common_dir = '/home/ncp/workspace/202002n050/050.신경계 질환 관련 임상 및 진료 데이터'

# File-name suffixes, grouped by file type. check_extension() compares them with
# str.endswith, which is case-sensitive, so the lower- and upper-case spelling of
# each suffix is listed side by side.
FILE_EXTENSION = [
    variant
    for suffix in ('.png', '.jpg', '.dcm', '.raw', '.svs')
    for variant in (suffix, suffix.upper())
]
IMG_EXTENSION = [
    variant
    for suffix in ('.png', '.jpg', '.jpeg')
    for variant in (suffix, suffix.upper())
]
DCM_EXTENSION = [variant for variant in ('.dcm', '.dcm'.upper())]
RAW_EXTENSION = [variant for variant in ('.raw', '.raw'.upper())]
# Only the lower-case spelling is listed for these two.
NIFTI_EXTENSION = ['.nii']
NP_EXTENSION = ['.npy']


def check_extension(filename, extension_ls=FILE_EXTENSION):
    """Tell whether ``filename`` ends with one of the suffixes in ``extension_ls``.

    Parameters:
        filename (str) -- file name or path to test
        extension_ls (list) -- suffixes to accept (defaults to FILE_EXTENSION)

    Return:
        bool -- True on the first matching suffix, False when none matches.
        The comparison is ``str.endswith`` and therefore case-sensitive.
    """
    for suffix in extension_ls:
        if filename.endswith(suffix):
            return True
    return False


def load_file_path(folder_path, extension_ls=FILE_EXTENSION, all_sub_folders=False):
    """Collect the paths of files in a folder whose names end with an accepted suffix.

    Parameters:
        folder_path (str) -- directory to search
        extension_ls (list) -- accepted file-name suffixes (defaults to FILE_EXTENSION)
        all_sub_folders (bool) -- when True, every sub-directory is searched as well;
            when False (default), only files directly inside ``folder_path`` are returned

    Return:
        file_paths (list) -- matching paths, ordered by directory path and then by file name

    Raises:
        AssertionError -- if ``folder_path`` is not an existing directory
    """
    assert os.path.isdir(folder_path), f'{folder_path} is not a valid directory'

    walker = os.walk(folder_path)
    if all_sub_folders:
        # Directory paths are unique, so ordering on them alone is a total order.
        visited = sorted(walker, key=lambda entry: entry[0])
    else:
        # os.walk is lazy and, walking top-down, yields ``folder_path`` itself first.
        # Taking a single entry avoids traversing the whole tree only to discard it.
        top_level = next(walker, None)
        visited = [] if top_level is None else [top_level]

    file_paths = []
    for root, _, fnames in visited:
        # os.walk reports file names in arbitrary order; sort for a reproducible result.
        for fname in sorted(fnames):
            if check_extension(fname, extension_ls):
                file_paths.append(os.path.join(root, fname))

    return file_paths


def gen_new_dir(new_dir):
    """Create directory ``new_dir`` (including missing parents) unless it already exists.

    Parameters:
        new_dir (str) -- directory path to create

    A failure is reported on stdout instead of being raised, so the caller keeps
    running; the message names the path and the underlying error.
    """
    try:
        # exist_ok=True makes the call idempotent without a separate existence check,
        # which could race with another process creating the same directory. It also
        # surfaces the case where the path exists but is not a directory.
        os.makedirs(new_dir, exist_ok=True)
    except OSError as err:
        print(f"Error: Failed to create the directory. ({new_dir}: {err})")
        

def find_aihub_img_mask_dir(common_dir, fname, folder='train'):
    """Compose the image and mask directory paths of one case.

    Parameters:
        common_dir (str) -- dataset root directory
        fname (str) -- case folder name
        folder (str) -- 'train' or 'val'

    Return:
        [img_dir, mask_dir] (list) -- the paths are only composed, not checked for existence

    Raises:
        ValueError -- if ``folder`` is neither 'train' nor 'val'
    """
    # (image root, mask root) relative to ``common_dir`` for each supported folder.
    split_roots = {
        'train': ('01.데이터/1.Training/원천데이터', '01.데이터/1.Training/라벨링데이터'),
        'val': ('01.데이터/2.Validation/원천데이터', '01.데이터/2.Validation/라벨링데이터'),
    }
    try:
        img_root, mask_root = split_roots[folder]
    except (KeyError, TypeError):
        # Without this, an unexpected value left img_dir/mask_dir unassigned and the
        # function failed with an UnboundLocalError that hid the real cause.
        raise ValueError(f"folder must be 'train' or 'val', got {folder!r}") from None

    img_dir = os.path.join(common_dir, img_root, fname, 'init/image')
    mask_dir = os.path.join(common_dir, mask_root, fname, 'init/mask')
    return [img_dir, mask_dir]


def pair_img_mask_path(common_dir, fname, folder='train'):
    """Pair every image slice of a case with the mask file of the same base name.

    Parameters:
        common_dir (str) -- dataset root directory
        fname (str) -- case folder name
        folder (str) -- passed through to find_aihub_img_mask_dir

    Return:
        list of [img_path, mask_path] pairs ordered by image path; ``mask_path`` is
        None for an image without a same-named mask PNG. Returns None (not a list)
        when the image directory contains no '*.png' file.
    """
    def _png_by_stem(directory):
        # Base name without extension -> full path, for every PNG in ``directory``.
        return {
            os.path.splitext(os.path.basename(png_path))[0]: png_path
            for png_path in sorted(glob.glob(os.path.join(directory, '*.png')))
        }

    img_dir, mask_dir = find_aihub_img_mask_dir(common_dir, fname, folder)

    images = _png_by_stem(img_dir)
    if not images:
        return None

    masks = _png_by_stem(mask_dir) if os.path.isdir(mask_dir) else {}
    return [[img_path, masks.get(stem)] for stem, img_path in images.items()]


def find_aihub_img_mask_fname(common_dir, folder='train'):
    """List the image paths and the matching mask paths of every case in a folder.

    Parameters:
        common_dir (str) -- dataset root directory
        folder (str) -- 'train' or 'val'

    Return:
        (img_list, mask_list) -- two equally long tuples; ``mask_list[i]`` is the mask
        path paired with ``img_list[i]`` by pair_img_mask_path (None when there is no
        mask). Both tuples are empty when no case contains an image.

    Raises:
        ValueError -- if ``folder`` is neither 'train' nor 'val'
    """
    if folder == 'train':
        data_dir = os.path.join(common_dir, '01.데이터/1.Training/원천데이터')
    elif folder == 'val':
        data_dir = os.path.join(common_dir, '01.데이터/2.Validation/원천데이터')
    else:
        # Previously this fell through and failed later on the unassigned data_dir.
        raise ValueError(f"folder must be 'train' or 'val', got {folder!r}")

    # os.listdir gives no ordering guarantee; sort so the output is reproducible.
    case_names = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )

    paths_list = []
    for case_name in case_names:
        # The earlier call passed (fname, mod): ``mod`` is undefined and ``common_dir``
        # was missing, so the function could never run.
        pairs = pair_img_mask_path(common_dir, case_name, folder)
        if pairs:
            paths_list.extend(pairs)

    if not paths_list:
        # zip(*[]) yields nothing to unpack; return the empty result explicitly.
        return (), ()

    img_list, mask_list = zip(*paths_list)
    return img_list, mask_list

In [None]:
# Load the merged CSV into `aihub_df`.
# NOTE(review): the cells below operate on `pred_aihub_df`, which is not created in any
# cell shown here -- presumably it is derived from this frame; confirm and add that step.
aihub_df = pd.read_csv('/home/ncp/workspace/AIHUB_dataset/df_csv_merged_v.2.1.csv')

In [None]:
# Case-folder names present in each source-data directory. check_folder_dir() tests
# membership by case name, so these must hold names: find_aihub_img_mask_fname()
# returns an (img_list, mask_list) pair of path tuples, in which a bare case name is
# never found -- every case was therefore tagged with folder None.
_train_src_dir = os.path.join(common_dir, '01.데이터/1.Training/원천데이터')
_val_src_dir = os.path.join(common_dir, '01.데이터/2.Validation/원천데이터')
train_fname = {
    entry for entry in os.listdir(_train_src_dir)
    if os.path.isdir(os.path.join(_train_src_dir, entry))
}
val_fname = {
    entry for entry in os.listdir(_val_src_dir)
    if os.path.isdir(os.path.join(_val_src_dir, entry))
}
def check_folder_dir(fname):
    """Return 'train' or 'val' depending on which collection contains ``fname``.

    Membership is tested against the module-level ``train_fname`` and ``val_fname``
    as they are bound when the function is called; ``train_fname`` is checked first.
    Returns None when ``fname`` is in neither.
    """
    if fname in train_fname:
        return 'train'
    if fname in val_fname:
        return 'val'
    return None

def split_train_val_test(fname):
    """Return 'train', 'val' or 'test' depending on which collection contains ``fname``.

    Membership is tested, in that order, against the module-level ``train_fname``,
    ``val_fname`` and ``test_fname`` as they are bound when the function is called.
    Returns None when ``fname`` is in none of them.
    """
    if fname in train_fname:
        return 'train'
    if fname in val_fname:
        return 'val'
    if fname in test_fname:
        return 'test'
    return None

In [None]:
# Display the column names of the loaded table.
aihub_df.columns.values

In [None]:
# Tag every case with the folder ('train' / 'val' / None) reported by check_folder_dir.
# NOTE(review): `pred_aihub_df` is not created in any cell shown in this notebook --
# confirm where it comes from.
pred_aihub_df['folder'] = pred_aihub_df['name'].map(check_folder_dir)

In [None]:
# Work on a copy; where 'mrs_3m' is missing, take the value of the 'mrs3mo' column.
pred_aihub_df_clear = pred_aihub_df.assign(
    mrs_3m=pred_aihub_df['mrs_3m'].fillna(pred_aihub_df['mrs3mo'])
)

In [None]:
# Keep only the case name, the two outcome columns and the source folder.
pred_aihub_df_clear = pred_aihub_df_clear.loc[
    :, ['name', 'good_outcome_3m', 'mrs_3m', 'folder']
]

In [None]:
# Display the number of missing values per column (checked before the integer cast below).
pred_aihub_df_clear.isna().sum()

In [None]:
# Cast both outcome columns to int. pandas refuses this cast while a column still
# holds NaN, so the missing-value count above should be zero for these two columns.
pred_aihub_df_clear = pred_aihub_df_clear.astype(
    dict.fromkeys(('good_outcome_3m', 'mrs_3m'), int)
)

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Column 0 holds the case name, column 1 the good_outcome_3m label.
tot_fname_label = pred_aihub_df_clear[['name', 'good_outcome_3m']].values
tot_fname, tot_label = tot_fname_label[:, 0], tot_fname_label[:, 1]

# First split: 80 % train, 20 % held out. Second split: the held-out part is halved
# into validation and test. Both splits are stratified on the label and seeded.
# Note that this rebinds `train_fname` and `val_fname`, which split_train_val_test reads.
train_fname, valtest_fname, train_label, valtest_label = train_test_split(
    tot_fname,
    tot_label,
    test_size=0.2,
    random_state=77,
    stratify=tot_label,
)
val_fname, test_fname, val_label, test_label = train_test_split(
    valtest_fname,
    valtest_label,
    test_size=0.5,
    random_state=77,
    stratify=valtest_label,
)

In [None]:
# Record for every case which of the train / val / test name arrays it landed in.
pred_aihub_df_clear['split'] = pred_aihub_df_clear['name'].map(split_train_val_test)

In [None]:
# Distribution of good_outcome_3m within the training split.
pred_aihub_df_clear.loc[pred_aihub_df_clear['split'] == 'train', 'good_outcome_3m'].value_counts()

In [None]:
# Distribution of mrs_3m within the training split (before rows with value 9 are dropped).
pred_aihub_df_clear.loc[pred_aihub_df_clear['split'] == 'train', 'mrs_3m'].value_counts()

In [None]:
# Remove every row whose mrs_3m equals 9.
# NOTE(review): 9 is presumably a code for an unusable score -- confirm. The rows are
# dropped after the train/val/test split was drawn, so the split sizes shrink slightly.
idx_mrs_3m_9 = pred_aihub_df_clear.index[pred_aihub_df_clear['mrs_3m'] == 9]
pred_aihub_df_clear = pred_aihub_df_clear.drop(index=idx_mrs_3m_9)

In [None]:
# Distribution of mrs_3m within the training split, re-displayed after the drop above.
pred_aihub_df_clear.loc[pred_aihub_df_clear['split'] == 'train', 'mrs_3m'].value_counts()

In [None]:
# Persist the cleaned table (name, outcomes, folder, split) without the index column.
pred_aihub_df_clear.to_csv('/home/ncp/workspace/blocks1/3D_CNN_for_PRED/aihub_df.csv', index=False)

In [None]:
# 2-column array of (case name, folder), iterated by the export loops below.
fname_folder = pred_aihub_df_clear[['name', 'folder']].values

In [None]:
def read_png_file(filepath):
    """Open the image at ``filepath`` with PIL and return its pixels as a numpy array."""
    with Image.open(filepath) as img:
        pixels = np.array(img)
    return pixels


def read_mask_file(filepath):
    """Open the image at ``filepath`` and return it binarised.

    Every non-zero pixel becomes 1 and every zero pixel stays 0.
    """
    with Image.open(filepath) as img:
        pixels = np.array(img)
    return np.where(pixels, 1, 0)


def resample_3d(image_3d, dsize=(36,256,256), order=3):
    """Resize a volume to the shape ``dsize`` by spline interpolation.

    Parameters:
        image_3d (np.ndarray) -- volume to resize; needs as many axes as ``dsize``
        dsize (tuple) -- target shape
        order (int) -- spline order handed to ``scipy.ndimage.zoom``. The default 3
            (cubic) is what this function always used; 0 is nearest-neighbour, which
            only copies existing values and is therefore suited to label masks.

    Return:
        np.ndarray -- the resized volume
    """
    # Per-axis scale factor; zoom() rounds input_size * factor to get each output size.
    resize_factor = np.array(dsize) / image_3d.shape

    # scipy.ndimage.interpolation is a deprecated namespace; zoom is public in scipy.ndimage.
    return scipy.ndimage.zoom(image_3d, resize_factor, order=order, mode='nearest')

In [None]:
def save_arr_to_np(arr, savepoint, fname):
    """Write ``arr`` with ``np.save`` to the file ``<fname>.npy`` inside ``savepoint``.

    Parameters:
        arr (np.ndarray) -- array to store
        savepoint (str) -- existing target directory
        fname (str) -- file name without the '.npy' suffix
    """
    target_path = os.path.join(savepoint, fname + '.npy')
    np.save(target_path, arr)

In [None]:
savepoint = '/home/ncp/workspace/blocks1/3D_CNN_for_PRED/data_np'
dwi_savepoint = os.path.join(savepoint, 'dwi')
mask_savepoint = os.path.join(savepoint, 'mask')
gen_new_dir(dwi_savepoint)
gen_new_dir(mask_savepoint)

# Export each case as one stacked image volume and one stacked mask volume (.npy),
# with slices in the order returned by pair_img_mask_path.
skipped_cases = []  # (case name, reason) for every case that could not be exported
for fname, folder in tqdm(fname_folder):
    # check_folder_dir() gives None for a case found in neither source folder, and
    # find_aihub_img_mask_dir cannot resolve directories for such a value.
    if folder not in ('train', 'val'):
        skipped_cases.append((fname, f'unknown folder {folder!r}'))
        continue

    dwi_mask_paths_ls = pair_img_mask_path(common_dir, fname, folder)
    if not dwi_mask_paths_ls:
        # No image slice on disk: skipped as before, but now recorded.
        skipped_cases.append((fname, 'no image slices found'))
        continue

    # pair_img_mask_path uses None as the mask path of a slice without a mask file.
    # Reading None crashes, and dropping the slice would misalign the two stacks.
    if any(mask_path is None for _, mask_path in dwi_mask_paths_ls):
        skipped_cases.append((fname, 'mask file missing for at least one slice'))
        continue

    # NOTE(review): masks are read with read_png_file (raw pixel values) rather than
    # read_mask_file (binarised) -- confirm this is intended.
    dwi_stack = np.stack([read_png_file(dwi_path) for dwi_path, _ in dwi_mask_paths_ls], axis=0)
    mask_stack = np.stack([read_png_file(mask_path) for _, mask_path in dwi_mask_paths_ls], axis=0)
    save_arr_to_np(dwi_stack, dwi_savepoint, fname)
    save_arr_to_np(mask_stack, mask_savepoint, fname)

if skipped_cases:
    print(f'Skipped {len(skipped_cases)} case(s):')
    for skipped_name, skipped_reason in skipped_cases:
        print(f'  {skipped_name}: {skipped_reason}')

In [None]:
savepoint = '/home/ncp/workspace/blocks1/3D_CNN_for_PRED/data_np_resampled'
dwi_savepoint = os.path.join(savepoint, 'dwi')
mask_savepoint = os.path.join(savepoint, 'mask')
gen_new_dir(dwi_savepoint)
gen_new_dir(mask_savepoint)

# Same export as the un-resampled one, but every volume is resized with resample_3d
# (to its default target shape) before it is saved.
skipped_cases = []  # (case name, reason) for every case that could not be exported
for fname, folder in tqdm(fname_folder):
    # check_folder_dir() gives None for a case found in neither source folder, and
    # find_aihub_img_mask_dir cannot resolve directories for such a value.
    if folder not in ('train', 'val'):
        skipped_cases.append((fname, f'unknown folder {folder!r}'))
        continue

    dwi_mask_paths_ls = pair_img_mask_path(common_dir, fname, folder)
    if not dwi_mask_paths_ls:
        # No image slice on disk: skipped as before, but now recorded.
        skipped_cases.append((fname, 'no image slices found'))
        continue

    # pair_img_mask_path uses None as the mask path of a slice without a mask file.
    # Reading None crashes, and dropping the slice would misalign the two stacks.
    if any(mask_path is None for _, mask_path in dwi_mask_paths_ls):
        skipped_cases.append((fname, 'mask file missing for at least one slice'))
        continue

    dwi_stack = np.stack([read_png_file(dwi_path) for dwi_path, _ in dwi_mask_paths_ls], axis=0)
    mask_stack = np.stack([read_png_file(mask_path) for _, mask_path in dwi_mask_paths_ls], axis=0)

    dwi_stack_resampled = resample_3d(dwi_stack)
    # NOTE(review): the mask goes through the same interpolation as the image, so the
    # saved mask may contain values that are not in the source mask -- confirm whether
    # nearest-neighbour resampling is wanted for masks.
    mask_stack_resampled = resample_3d(mask_stack)
    save_arr_to_np(dwi_stack_resampled, dwi_savepoint, fname)
    save_arr_to_np(mask_stack_resampled, mask_savepoint, fname)

if skipped_cases:
    print(f'Skipped {len(skipped_cases)} case(s):')
    for skipped_name, skipped_reason in skipped_cases:
        print(f'  {skipped_name}: {skipped_reason}')