In [None]:
# Reload edited local modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

import h5py
# Make groups/datasets created from here on iterate in creation order rather than
# alphabetical order (so sample_2 is listed before sample_10). Must be set before
# the HDF5 file below is written.
h5py.get_config().track_order = True

import glob
import numpy as np
import nibabel as nib

import os.path
from tqdm import tqdm

import sys
assert sys.version_info.major == 3, 'Not running on Python 3'

# `io.capture_output` is used by the compression cell to swallow per-sample output.
from IPython.utils import io
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every expression statement in a cell, not only the last one.
# NOTE(review): this also echoes expression statements nested in module-level
# loops/with-blocks, so calls whose return value is not assigned will print it.
InteractiveShell.ast_node_interactivity = "all"

import logging
# Route INFO-and-above log records to stdout so they appear in the cell output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

In [None]:
# Destination HDF5 container plus the two source directories.
ds_path = "/home/imag2/IMAG2_DL/KDCompression/Dataset/ds.h5"

img_dir = "/home/imag2/IMAG2_DL/APMRI-DNN/Dataset/All"
teacher_dir = "/home/imag2/IMAG2_DL/KDCompression/Dataset/Teacher/"

# Every list is sorted: the compression cell pairs files across these lists purely
# by sorted position. The [xX] character classes accept either capitalisation of
# the keyword in the file name.
imgs_fpath, masks_fpath, labels_fpath = (
    sorted(glob.glob(os.path.join(img_dir, pattern)))
    for pattern in ('*[tT]2*.nii.gz', '*[mM]ask*.nii.gz', '*[sS]egmentation*.nii.gz')
)
teacher_fpath = sorted(glob.glob(os.path.join(teacher_dir, '*[tT]eacher*.nii.gz')))

In [None]:
def load_nii(f_name):
    """Read a NIfTI file and return its voxel data in closest-canonical orientation.

    The image is passed through ``nib.as_closest_canonical`` (reorientation to the
    closest RAS+ orientation, per nibabel docs) before its data is read, so volumes
    from different files share the same axis convention.

    Args:
        f_name: path to a NIfTI (``.nii`` / ``.nii.gz``) file.

    Returns:
        The array returned by ``get_fdata()`` (floating point; float64 by default).
    """
    return nib.as_closest_canonical(nib.load(f_name)).get_fdata()

In [None]:
def _write_volume(group, name, nii_path, dtype):
    """Load one NIfTI volume and store it in ``group`` as a gzip-compressed dataset.

    The whole volume is written as a single chunk (``chunks=arr.shape``) at the
    maximum gzip level (9).

    Args:
        group: open h5py group to write into.
        name: dataset name inside ``group`` (e.g. 'img', 'mask', 'label', 'teacher').
        nii_path: path of the NIfTI file, read with ``load_nii``.
        dtype: numpy dtype the voxel data is cast to before writing.
    """
    arr = load_nii(nii_path).astype(dtype)
    # Shape and dtype are inferred from `data`, so they are not passed separately.
    group.create_dataset(name=name, data=arr, chunks=arr.shape,
                         compression='gzip', compression_opts=9)


# Samples are paired by sorted position; zip() would silently stop at the shortest
# list, so a single missing file would misalign every later sample. Fail loudly.
n_samples = len(imgs_fpath)
if not (len(masks_fpath) == len(labels_fpath) == len(teacher_fpath) == n_samples):
    raise ValueError(
        "Mismatched file counts: {} images, {} masks, {} labels, {} teacher outputs".format(
            len(imgs_fpath), len(masks_fpath), len(labels_fpath), len(teacher_fpath)))

with h5py.File(ds_path, 'a', libver='latest') as f:
    samples = zip(imgs_fpath, masks_fpath, labels_fpath, teacher_fpath)
    for i, (img, mask, label, teacher) in enumerate(
            tqdm(samples, total=n_samples, desc="Compressing", unit="sample")):
        with io.capture_output():
            group_name = "sample_{}".format(i)
            # The file is opened in append mode: replace a group left by a previous
            # run instead of letting create_group raise. (Groups with indices
            # >= n_samples from an earlier, larger run are left untouched.)
            if group_name in f:
                del f[group_name]
            sample = f.create_group(group_name)
            _write_volume(sample, 'img', img, np.float32)
            _write_volume(sample, 'mask', mask, np.uint8)
            # The label comes from the segmentation file, not the mask.
            _write_volume(sample, 'label', label, np.uint8)
            _write_volume(sample, 'teacher', teacher, np.float32)