In [1]:
# Show the current working directory; the relative ./util and ./data paths
# used by the cells below are resolved against it.
!pwd

/home/KOHI/KOHI2021_MedicalImage_Team3/AIHUB


In [2]:
# Create the util/ package directory. `-p` makes the cell idempotent, so it can
# stay enabled and the notebook still survives Restart Kernel -> Run All on a
# fresh checkout (the later touch / %%writefile cells need this directory).
!mkdir -p util

In [4]:
# Create an empty __init__.py so that util/ is importable as a package
# (data/dataset_2d.py does `from util.util import *`).
!touch ./util/__init__.py

In [5]:
%%writefile ./util/util.py
import os
import numpy as np
import glob

# Default suffix list for check_extension() and load_file_path(). Both letter
# cases are listed because matching is done with a case-sensitive str.endswith.
FILE_EXTENSION = ['.png', '.PNG', '.jpg', '.JPG', '.dcm', '.DCM', '.raw', '.RAW', '.svs', '.SVS']
# Suffixes grouped by file kind, for callers that want a narrower filter.
IMG_EXTENSION = ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG']
DCM_EXTENSION = ['.dcm', '.DCM']
RAW_EXTENSION = ['.raw', '.RAW']
# NOTE(review): only the lowercase spelling is listed for the two groups below.
NIFTI_EXTENSION = ['.nii']
NP_EXTENSION = ['.npy']

# Root directory of the AIHUB dataset, read by find_aihub_img_label_dirs().
# NOTE(review): hard-coded absolute path; consider making it configurable.
common_dir = '/home/ncp/workspace/202002n050/050.신경계 질환 관련 임상 및 진료 데이터'


def check_extension(filename, extension_ls=FILE_EXTENSION):
    """Tell whether a file name ends with one of the given extensions.

    Parameters:
        filename (str) -- file name or path to test
        extension_ls (list) -- candidate suffixes; matching is case-sensitive

    Return:
        (bool) -- True if 'filename' ends with at least one of the suffixes
    """
    for suffix in extension_ls:
        if filename.endswith(suffix):
            return True
    return False


def load_file_path(folder_path, extension_ls=FILE_EXTENSION, all_sub_folders=False):
    """Collect the paths of files in a folder whose names end with given extensions.

    Parameters:
        folder_path (str) -- folder directory
        extension_ls (list) -- list of extensions (defaults to FILE_EXTENSION)
        all_sub_folders (bool) -- if True, every sub-folder is searched as well;
                                  if False, only files directly inside
                                  'folder_path' are considered

    Return:
        file_paths (list) -- list of 'extension_ls' file paths; folders are
                             visited in sorted order and file names are sorted
                             within each folder, so the result is deterministic

    Raises:
        AssertionError -- if 'folder_path' is not an existing directory
    """
    # Raised explicitly (same exception type as before) so the check is not
    # stripped when Python runs with -O.
    if not os.path.isdir(folder_path):
        raise AssertionError(f'{folder_path} is not a valid directory')

    walker = os.walk(folder_path)
    if all_sub_folders:
        # Sorting needs the full traversal; only pay for it when it is wanted.
        walker = sorted(walker)

    file_paths = []
    for root, _, fnames in walker:
        for fname in sorted(fnames):
            if check_extension(fname, extension_ls):
                file_paths.append(os.path.join(root, fname))
        if not all_sub_folders:
            # os.walk is top-down by default, so the first entry is
            # 'folder_path' itself; stop before descending any further.
            break

    return file_paths


def gen_new_dir(new_dir):
    """Create a directory (including missing parents) if it does not exist yet.

    Parameters:
        new_dir (str) -- directory path to create

    Return:
        None -- failures are reported on stdout instead of being raised
    """
    try:
        if not os.path.exists(new_dir):
            # exist_ok=True closes the race where another process (e.g. a
            # parallel worker) creates the directory between the existence
            # check and this call; previously that was reported as a failure.
            os.makedirs(new_dir, exist_ok=True)
    except OSError:
        print("Error: Failed to create the directory.")


def find_aihub_img_label_dirs(fname, mod='train'):
    """Build the image and mask directory paths of one AIHUB case folder.

    Parameters:
        fname (str) -- name of the case folder
        mod (str) -- 'train' or 'val'; selects the dataset split

    Return:
        (list) -- [img_dir, mask_dir] below the module-level 'common_dir',
                  or None when 'mod' is neither 'train' nor 'val'
    """
    # split -> (source-data sub-path, label-data sub-path)
    split_subdirs = {
        'train': ('01.데이터/1.Training/원천데이터', '01.데이터/1.Training/라벨링데이터'),
        'val': ('01.데이터/2.Validation/원천데이터', '01.데이터/2.Validation/라벨링데이터'),
    }
    subdirs = split_subdirs.get(mod)
    if subdirs is None:
        return None

    source_subdir, label_subdir = subdirs
    return [
        os.path.join(common_dir, source_subdir, fname, 'init/image'),
        os.path.join(common_dir, label_subdir, fname, 'init/mask'),
    ]


def pair_img_mask_path(fname, mod='train'):
    """Pair every '*.png' image of a case folder with its same-named mask.

    Parameters:
        fname (str) -- name of the case folder
        mod (str) -- dataset split, forwarded to find_aihub_img_label_dirs

    Return:
        (list) -- [img_path, mask_path] pairs ordered by image path;
                  mask_path is None when no mask shares the image's base name.
                  None is returned when the image folder holds no '*.png' file.
    """
    img_dir, mask_dir = find_aihub_img_label_dirs(fname, mod)

    def _png_by_stem(directory):
        # Map "file name without extension" -> full path, in sorted path order.
        png_paths = sorted(glob.glob(os.path.join(directory, '*.png')))
        return {os.path.splitext(os.path.basename(path))[0]: path for path in png_paths}

    images_by_stem = _png_by_stem(img_dir)
    if not images_by_stem:
        return None

    masks_by_stem = _png_by_stem(mask_dir) if os.path.isdir(mask_dir) else {}
    return [[img_path, masks_by_stem.get(stem)] for stem, img_path in images_by_stem.items()]


def find_aihub_img_label_paths(common_dir, mod='train'):
    """Gather the image paths and matching mask paths of a whole dataset split.

    Parameters:
        common_dir (str) -- dataset root; the case folders are listed below it
        mod (str) -- 'train' or 'val'; selects the dataset split

    Return:
        img_list (tuple) -- image paths
        mask_list (tuple) -- mask paths aligned with 'img_list' (None where an
                             image has no mask); both tuples are empty when the
                             split contains no images

    Raises:
        ValueError -- if 'mod' is neither 'train' nor 'val'
    """
    if mod == 'train':
        data_dir = os.path.join(common_dir, '01.데이터/1.Training/원천데이터')
    elif mod == 'val':
        data_dir = os.path.join(common_dir, '01.데이터/2.Validation/원천데이터')
    else:
        # Previously fell through and crashed with UnboundLocalError below.
        raise ValueError(f"mod must be 'train' or 'val', got {mod!r}")

    # os.listdir order is arbitrary; sort so the sample order is reproducible.
    case_names = sorted(
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )

    # NOTE(review): pair_img_mask_path() resolves its directories through the
    # module-level 'common_dir', not through this function's 'common_dir'
    # argument -- confirm both always point at the same root.
    paths_list = []
    for case_name in case_names:
        pairs = pair_img_mask_path(case_name, mod)
        if pairs:
            paths_list.extend(pairs)

    if not paths_list:
        # zip(*[]) yields nothing, so unpacking it into two names would raise.
        return (), ()

    img_list, mask_list = zip(*paths_list)
    return img_list, mask_list

Writing ./util/util.py


In [6]:
%%writefile ./util/visualize.py
import numpy as np


def normalize(arr):
    """Linearly rescale an array to the 0-255 range and cast it to uint8.

    Parameters:
        arr (array-like) -- numeric values of any shape

    Return:
        (np.ndarray) -- uint8 array of the same shape; the minimum maps to 0
                        and the maximum to 255 (fractions are truncated by the
                        cast). A constant array maps to all zeros and an empty
                        array is returned empty.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # min()/max() are undefined on empty input.
        return arr.astype(np.uint8)
    if np.issubdtype(arr.dtype, np.integer) or arr.dtype == np.bool_:
        # Do the arithmetic in float so 'max - min' cannot overflow a narrow
        # integer dtype (e.g. int8: 127 - (-128)).
        arr = arr.astype(np.float64)

    lowest = arr.min()
    value_range = arr.max() - lowest
    if value_range == 0:
        # Constant input: avoid 0/0 -> NaN being cast to uint8.
        return np.zeros(arr.shape, dtype=np.uint8)

    scaled = (arr - lowest) / value_range * 255
    return scaled.astype(np.uint8)


def visualize_grayscale(arr):
    """Turn an array into a 3-channel image by repeating its normalized values.

    Parameters:
        arr (np.ndarray) -- values to display; passed through normalize()

    Return:
        (np.ndarray) -- the normalized array copied three times along a new
                        last axis
    """
    gray = normalize(arr)
    channels = [gray for _ in range(3)]
    return np.stack(channels, axis=-1)

Writing ./util/visualize.py


In [7]:
# Create the data/ package directory. `-p` makes the cell idempotent, so it can
# stay enabled and the notebook still survives Restart Kernel -> Run All on a
# fresh checkout (the later touch / %%writefile cells need this directory).
!mkdir -p data

In [8]:
# Create an empty __init__.py so that data/ is importable as a package
# (data/dataset_2d.py does `from data.dataloader import *`).
!touch ./data/__init__.py

In [9]:
%%writefile ./data/dataloader.py
import numpy as np
import PIL.Image as Image

def img_loader(img_path):
    """Read an image file into a NumPy array with a trailing singleton axis.

    Parameters:
        img_path (str) -- path of the image file

    Return:
        (np.ndarray) -- pixel array as decoded by PIL, with one extra axis of
                        length 1 appended at the end
                        (presumably (H, W, 1) for grayscale input -- verify)
    """
    # The context manager releases the file handle as soon as the pixels have
    # been copied into the array, instead of leaving it to garbage collection.
    with Image.open(img_path) as img:
        pixels = np.array(img)
    return np.expand_dims(pixels, axis=-1)
def mask_loader(mask_path):
    """Read a mask file as a binary uint8 array with a trailing singleton axis.

    Parameters:
        mask_path (str) -- path of the mask image file

    Return:
        (np.ndarray) -- uint8 array holding 1 where the decoded pixel is
                        non-zero and 0 elsewhere, with one extra axis of
                        length 1 appended at the end
    """
    # The context manager releases the file handle as soon as the pixels have
    # been copied into the array, instead of leaving it to garbage collection.
    with Image.open(mask_path) as mask_img:
        pixels = np.array(mask_img)
    binary = np.where(pixels, 1, 0).astype(np.uint8)
    return np.expand_dims(binary, axis=-1)

Writing ./data/dataloader.py


In [10]:
%%writefile ./data/dataset_2d.py
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
import os
from util.util import *
from data.dataloader import *

import numpy as np
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2


def get_training_augmentation(params=None):
    """Build the augmentation pipeline for training samples.

    Every augmentation is currently disabled, so the returned pipeline is
    composed from an empty transform list.

    Parameters:
        params -- unused

    Return:
        (A.Compose) -- composition of an empty transform list
    """
    # Disabled candidates, kept for reference:
    #   A.HorizontalFlip(p=.5)
    #   A.VerticalFlip(p=.5)
    #   A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=5, shift_limit=0.2, border_mode=0, p=.5)
    #   A.ShiftScaleRotate(scale_limit=0.01, rotate_limit=5, shift_limit=0., border_mode=0, p=.5)
    return A.Compose([])


def get_preprocessing(params=None, resize=(256, 256), convert=True):
    """Build the preprocessing pipeline applied to every sample.

    Parameters:
        params -- unused
        resize (tuple) -- arguments unpacked into A.Resize
        convert (bool) -- if True, A.Normalize(mean=(0.5,), std=(0.5,)) and
                          ToTensorV2(transpose_mask=True) are appended after
                          the resize step

    Return:
        (A.Compose) -- composition of the selected transforms
    """
    steps = [A.Resize(*resize)]
    if convert:
        # ImageNet alternative, kept for reference:
        #   A.Normalize(mean=(0.485, 0.456, 0.406),  std=(0.229, 0.224, 0.225))
        steps += [
            A.Normalize(mean=(0.5,), std=(0.5,)),
            ToTensorV2(transpose_mask=True),
        ]
    return A.Compose(steps)


class AIHUB_LesionSegDataset(torch.utils.data.Dataset):
    """Dataset of (image, mask) pairs for AIHUB lesion segmentation.

    The path lists are built once in the constructor with
    find_aihub_img_label_paths(); files are read lazily in __getitem__.
    """

    def __init__(self,
                 dataset_dir,
                 img_loader=img_loader,
                 mask_loader=mask_loader,
                 augmentation=None,
                 preprocessing=None,
                 mode='train'
    ):
        """Store the configuration and collect the image / mask paths.

        Parameters:
            dataset_dir (str) -- stored on the instance only (see note below)
            img_loader (callable) -- maps an image path to an array
            mask_loader (callable) -- maps a mask path to an array
            augmentation (callable) -- called as augmentation(image=, mask=);
                                       dropped unless mode is 'train'
            preprocessing (callable) -- called as preprocessing(image=, mask=)
            mode (str) -- dataset split passed to find_aihub_img_label_paths
        """
        self.dataset_dir = dataset_dir
        self.img_loader = img_loader
        self.mask_loader = mask_loader
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.mode = mode
        # NOTE(review): the paths come from the module-level 'common_dir', not
        # from 'dataset_dir' -- confirm whether 'dataset_dir' was meant to be
        # used here.
        self.img_path_ls, self.mask_path_ls = find_aihub_img_label_paths(common_dir, mod=self.mode)
        if self.mode != 'train':
            # Augmentation is a training-only step.
            self.augmentation = None

    def __getitem__(self, index):
        """Load, augment and preprocess the sample at 'index'.

        Return:
            image, mask -- outputs of the last transform that was applied (or
                           the raw loader outputs when no transform is set)
        """
        image = self.img_loader(self.img_path_ls[index])
        mask_path = self.mask_path_ls[index]
        if mask_path:
            mask = self.mask_loader(mask_path)
        else:
            # No label file for this image: use an all-zero mask. It is created
            # as uint8 -- the dtype the default mask_loader returns -- rather
            # than inheriting the image dtype, so masks stay consistent across
            # samples even when the images are not uint8.
            mask = np.zeros(image.shape, dtype=np.uint8)
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.img_path_ls)

Writing ./data/dataset_2d.py
