In [None]:
%load_ext autoreload
%autoreload 2

import h5py
# Ask h5py to track the creation order of groups/datasets/attributes
# for every object created after this point.
h5py.get_config().track_order = True

import pickle
import glob
import numpy as np
import nibabel as nib
from scipy.ndimage import find_objects

import os.path
from tqdm import tqdm

import sys
# Fail fast when the kernel is not a Python 3 interpreter.
assert sys.version_info.major == 3, 'Not running on Python 3'

from IPython.utils import io
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

import logging
# Send INFO-and-above log records to stdout so they appear in the cell output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

In [None]:
# Output path stem; each export cell appends its own extension to it.
ds_path = "/home/imag2/IMAG2_DL/KDCompression/Dataset/ds"

# Input directories: T2 images, masks and segmentations share one folder,
# the teacher volumes live in a separate one.
img_dir = "/home/imag2/IMAG2_DL/APMRI-DNN/Dataset/All"
teacher_dir = "/home/imag2/IMAG2_DL/KDCompression/Dataset/Teacher/"

# Each list is sorted so that the i-th entries of all four lists are expected
# to belong to the same subject.
imgs_fpath = sorted(glob.glob(os.path.join(img_dir, '*[tT]2*.nii.gz')))
masks_fpath = sorted(glob.glob(os.path.join(img_dir, '*[mM]ask*.nii.gz')))
labels_fpath = sorted(glob.glob(os.path.join(img_dir, '*[sS]egmentation*.nii.gz')))
teacher_fpath = sorted(glob.glob(os.path.join(teacher_dir, '*[tT]eacher*.nii.gz')))

# The export cells pair these lists positionally with zip(), which silently
# stops at the shortest list. A missing file therefore truncates the dataset
# and can shift every later pairing, so surface a count mismatch early.
_fpath_counts = {
    'img': len(imgs_fpath),
    'mask': len(masks_fpath),
    'label': len(labels_fpath),
    'teacher': len(teacher_fpath),
}
if len(set(_fpath_counts.values())) > 1:
    logging.warning(
        "Input file lists differ in length; zip() will truncate and may mis-pair them: %s",
        _fpath_counts,
    )

In [None]:
def load_nii(f_name):
    """Load a NIfTI file and return its voxel data as a floating-point array.

    The image is passed through ``nib.as_closest_canonical`` before the data
    is read, so the returned array follows nibabel's closest-canonical axis
    ordering rather than the on-disk ordering.

    Parameters
    ----------
    f_name : str
        Path of the file to load.

    Returns
    -------
    numpy.ndarray
        Result of ``get_fdata()`` on the reoriented image.
    """
    return nib.as_closest_canonical(nib.load(f_name)).get_fdata()

def preprocess(img, mask=None, normalize=False, dtype=np.float32, forth_dim=False):
    """Mask, optionally z-score, and reorder a volume into a batch-first layout.

    The input array is never modified: all arithmetic is done out of place, so
    callers can keep using ``img`` afterwards and integer inputs (e.g. label
    maps) can be combined with a floating-point mask.

    Parameters
    ----------
    img : numpy.ndarray
        3-D volume, or 4-D volume when ``forth_dim`` is true.
    mask : numpy.ndarray, optional
        Array multiplied element-wise into ``img``. For 4-D input a trailing
        axis is appended so it broadcasts over the last axis of ``img``.
    normalize : bool
        When true, subtract the mean and divide by the standard deviation,
        both computed over the whole array after masking. An array with zero
        standard deviation is replaced by zeros.
    dtype : numpy dtype
        dtype of the returned array.
    forth_dim : bool
        Set when ``img`` has a fourth axis (parameter name kept as-is for
        backward compatibility).

    Returns
    -------
    numpy.ndarray
        For 3-D input of shape ``(a, b, c)``: shape ``(1, c, a, b)``.
        For 4-D input of shape ``(a, b, c, d)``: shape ``(1, d, c, a, b)``.
    """
    if mask is not None:
        weights = np.expand_dims(mask, axis=-1) if forth_dim else mask
        # Out-of-place multiply: `img *= weights` would overwrite the caller's
        # array and fails with a casting error for integer `img` / float mask.
        img = img * weights

    if normalize:
        mean = np.mean(img)
        std = np.std(img)
        if std > 0:
            img = (img - mean) / std
        else:
            # Constant volume: nothing to scale, return zeros (out of place).
            img = img * 0.

    # Move the last spatial axis (and the extra fourth axis, if any) to the
    # front, then prepend a singleton axis.
    axes = (3, 2, 0, 1) if forth_dim else (2, 0, 1)
    return np.expand_dims(np.transpose(img, axes), axis=0).astype(dtype)

In [None]:
# Source label value -> training class. Every label not listed here becomes
# background (0).
LABEL_MAP = {10: 1, 14: 1, 45: 2, 49: 3, 43: 4, 44: 4}

# Lookup table indexed by the uint8 label value; one vectorised indexing step
# replaces a chain of masked assignments.
# NOTE(review): preprocess() casts labels to uint8 first, so a source label
# >= 256 would wrap before being remapped - confirm the atlas stays below 256.
label_lut = np.zeros(256, dtype=np.uint8)
for src_label, dst_label in LABEL_MAP.items():
    label_lut[src_label] = dst_label


def _write_volume(group, name, arr):
    """Store ``arr`` in ``group`` as a single-chunk, gzip-9 compressed dataset."""
    group.create_dataset(name=name, shape=arr.shape, data=arr, chunks=arr.shape,
                         compression='gzip', compression_opts=9, dtype=arr.dtype)


# 'w' rather than 'a': sample names restart at sample_0 on every run, so
# appending to a file left by a previous run fails in create_group() because
# the name already exists. Truncating makes the cell safe to re-run.
with h5py.File(ds_path + '.h5', 'w', libver='latest') as f:
    idx = 0
    with tqdm(total=len(imgs_fpath), desc="Compressing", unit="sample") as pbar:
        for img, mask, label, teacher in zip(imgs_fpath, masks_fpath, labels_fpath, teacher_fpath):
            mask = load_nii(mask)
            # Silence anything printed to stdout/stderr while loading and writing.
            with io.capture_output():
                arr = preprocess(load_nii(label), mask, dtype=np.uint8)
                arr = label_lut[arr]

                # Samples without any label of interest inside the mask are skipped.
                if arr.any():
                    sample = f.create_group("sample_{}".format(idx))
                    _write_volume(sample, 'label', arr)

                    arr = preprocess(load_nii(img), mask, normalize=True)
                    _write_volume(sample, 'img', arr)

                    arr = preprocess(load_nii(teacher), mask, forth_dim=True)
                    _write_volume(sample, 'teacher', arr)

                    idx += 1
                pbar.update()

In [None]:
# Source label value -> training class. Every label not listed here becomes
# background (0).
LABEL_MAP = {10: 1, 14: 1, 45: 2, 49: 3, 43: 4, 44: 4}

# Lookup table indexed by the uint8 label value.
# NOTE(review): preprocess() casts labels to uint8 first, so a source label
# >= 256 would wrap before being remapped - confirm the atlas stays below 256.
label_lut = np.zeros(256, dtype=np.uint8)
for src_label, dst_label in LABEL_MAP.items():
    label_lut[src_label] = dst_label

# Build the whole dataset in memory first; the output file is only opened once
# every sample has been processed, so a failure mid-loop cannot leave an empty
# or truncated file in place of a previous export.
sample = {}
idx = 0
with tqdm(total=len(imgs_fpath), desc="Compressing", unit="sample") as pbar:
    for img, mask, label, teacher in zip(imgs_fpath, masks_fpath, labels_fpath, teacher_fpath):
        mask = load_nii(mask)

        # Bounding box of the non-zero mask region; all volumes are cropped to it.
        boxes = find_objects(mask > 0)
        if not boxes:
            # All-zero mask: there is no region to crop and every label would
            # be masked out anyway, so skip the sample instead of raising.
            pbar.update()
            continue
        loc = boxes[0]
        mask = mask[loc]

        arr = preprocess(load_nii(label)[loc], mask, dtype=np.uint8)
        arr = label_lut[arr]

        # Samples without any label of interest inside the mask are skipped.
        if arr.any():
            sample["sample_{}".format(idx)] = {
                'label': arr,
                'img': preprocess(load_nii(img)[loc], mask, normalize=True),
                # loc holds three slices, so the teacher's fourth axis is kept whole.
                'teacher': preprocess(load_nii(teacher)[loc], mask, forth_dim=True),
            }
            idx += 1
        pbar.update()

# NOTE: despite the '.npy' extension this file holds a pickled dict, not an
# array in NumPy's .npy format. The name is kept so existing readers still work.
with open(ds_path + '.npy', 'wb') as f:
    pickle.dump(sample, f, protocol=pickle.HIGHEST_PROTOCOL)