In [None]:
%run Utility_general.ipynb
import glob
import struct
from pathlib import Path
from PIL import Image as PImg
from shutil import copyfile
import cv2
import numpy as np
import nibabel as nib
import os
#import h5py


from matplotlib import image as mpimg

# Byte size of each struct format character, in struct's native mode (no
# byte-order prefix), which is how every pack/unpack call in this file is used.
# Refer to https://docs.python.org/3/library/struct.html
# f: float; d: double; q: long long; l: long; i: int; I: unsigned int
# b/B: (unsigned) char; h/H: (unsigned) short; L: unsigned long; Q: unsigned long long
# Sizes come from struct.calcsize so they always agree with struct itself:
# native 'l'/'L' are 8 bytes on 64-bit Linux/macOS but 4 on Windows, so a
# hard-coded 4 makes reads of 'l' data fail on most 64-bit Unix systems.
dlength_dict = {fmt: struct.calcsize(fmt) for fmt in 'fdqliIbBhHLQ'}

class FileIO_DMT(object):
    '''
    Conversions between images / 2D arrays and DIPHA binary image files (.bin),
    the input format of the DMT program.

    Layout written by cvt_dat2D2dipha_ (native byte order):
        int64   DIPHA_CONST             (file identifier, checked on read)
        int64   DIPHA_IMAGE_TYPE_CONST  (file type tag, checked on read)
        int64   total number of pixels
        int64   number of dimensions (3)
        int64   lattice resolution nx, ny, nz (= 1)
        float64 pixel values, negated, first coordinate changing fastest
    '''

    DIPHA_CONST = 8067171840
    DIPHA_IMAGE_TYPE_CONST = 1

    @staticmethod
    def cvt_dat2D2dipha_(dat, pathOut):
        '''
        Write a 2D array out as a DIPHA binary image file.

        dat: 2D array-like of pixel values; values are negated before writing
        pathOut: path to the binary output file; has to end with .bin
        The output file format is compatible with the DMT program.
        '''
        assert(pathOut.endswith('.bin'))
        # Convert to float64 *before* negating: negating an unsigned integer
        # array (e.g. a uint8 image) would wrap around instead of going negative.
        dat = np.asarray(dat, dtype=np.float64)
        assert(dat.ndim == 2)
        nx, ny = dat.shape
        header = struct.pack('7q',
                             FileIO_DMT.DIPHA_CONST,
                             FileIO_DMT.DIPHA_IMAGE_TYPE_CONST,
                             nx * ny,
                             3,
                             nx, ny, 1)
        # Transposing and then serialising in C order stores the first
        # coordinate fastest, which is what the (nx, ny, 1) header describes.
        payload = np.ascontiguousarray(np.transpose(-dat)).tobytes()
        with open(pathOut, "wb") as fout:
            fout.write(header)
            fout.write(payload)

    @staticmethod
    def cvt_img2dipha_(pathIn, pathOut, transform_argv=None):
        '''
        Convert an image file (read as grayscale) to a DIPHA binary file.

        pathIn: path to the image
        pathOut: path to the binary output file (.bin)
        - transform_argv: optional transform arguments (None/empty = no transform).
        - Example: ['dist_trfm', True, True] applies
          Util_gen.dist_trfm(img, binarize=True, inverse=True) before writing.
        Raises FileNotFoundError if the image cannot be read.
        '''
        img = cv2.imread(pathIn, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError("Could not read image: " + str(pathIn))
        img = np.float64(img)
        if transform_argv is not None and len(transform_argv) != 0 \
                and transform_argv[0] == 'dist_trfm':
            assert(len(transform_argv) == 3)
            img = Util_gen.dist_trfm(img, binarize=transform_argv[1], inverse=transform_argv[2])
        FileIO_DMT.cvt_dat2D2dipha_(img, pathOut)

    @staticmethod
    def cvt_dipha2img_(pathIn, uint8=False):
        '''
        Read a DIPHA binary image file back into a 2D array.

        pathIn: path to the binary dipha file
        uint8: convert output matrix to uint8 data type
        Returns an (nx, ny) array with the stored values negated back, i.e. the
        inverse of cvt_dat2D2dipha_. The file may declare 2 or more dimensions
        as long as the dimensions after the second are all 1.
        '''
        q_size = struct.calcsize('q')
        with open(pathIn, "rb") as fin:
            def read_int64(count=1):
                return struct.unpack(str(count) + 'q', fin.read(q_size * count))
            assert(read_int64()[0] == FileIO_DMT.DIPHA_CONST)
            assert(read_int64()[0] == FileIO_DMT.DIPHA_IMAGE_TYPE_CONST)
            total_pixel = read_int64()[0]
            ndim = read_int64()[0]
            dims = read_int64(ndim)
            raw = fin.read(struct.calcsize('d') * total_pixel)
        assert(len(dims) >= 2 and int(np.prod(dims)) == total_pixel)
        nx, ny = dims[0], dims[1]
        assert(nx * ny == total_pixel)
        # Pixels are stored with the first coordinate changing fastest, so a
        # C-order reshape needs (ny, nx); the transpose then yields (nx, ny).
        img = np.frombuffer(raw, dtype=np.float64, count=total_pixel)
        img = np.transpose(-np.reshape(img, [ny, nx]))
        if uint8:
            img = np.uint8(img)
        return img

    @staticmethod
    def cvt_img2dipha_batch_(dirIn, dirOut, ext, transform_argv=None):
        '''
        Convert every dirIn/*.<ext> image to dirOut/<name>.bin via cvt_img2dipha_.

        dirIn: the folder containing the images
        dirOut: the folder to output .bin dipha files (created if missing)
        ext: target extension of the images (without the dot)
        transform_argv: transform arguments, forwarded to cvt_img2dipha_.
        The process working directory is left unchanged.
        '''
        Path(dirOut).mkdir(parents=True, exist_ok=True)
        pattern = os.path.join(glob.escape(dirIn), "*." + ext)
        for pathIn in sorted(glob.glob(pattern)):
            name = os.path.basename(pathIn)[:-len(ext) - 1]
            pathOut = os.path.join(dirOut, name + ".bin")
            FileIO_DMT.cvt_img2dipha_(pathIn, pathOut, transform_argv)
    
class FileIO(object):
    
    @staticmethod
    def read_binary(path, shape, dtype='d'):
        file_in  = open(path, "rb")
        data_arr = struct.unpack(str(np.prod(shape))+dtype, file_in.read(dlength_dict[dtype]*np.prod(shape)))
        data_arr = np.reshape(data_arr, shape)
        file_in.close()
        if dtype=='f':
            data_arr = np.float32(data_arr)
        return data_arr
    
    @staticmethod
    def write_binary(path, data, shape, dtype='d'):
        '''
        data has to be flattened.
        '''
        file_out = open(path, "wb")
        file_out.write(struct.pack(str(np.prod(shape))+dtype, *(data)))
        file_out.close()
    
    @staticmethod
    def read_matrix_binary(path, dtype='d'):
        file_in = open(path, "rb")
        dims    = struct.unpack('I', file_in.read(4))[0]
        if dims == 0:
            return None
        shape   = struct.unpack(str(dims)+'I', file_in.read(4 * dims))
        mat     = struct.unpack(str(np.prod(shape))+dtype, file_in.read(dlength_dict[dtype]*np.prod(shape)))
        mat     = np.reshape(mat, shape)
        return mat
    
    @staticmethod
    def write_matrix_binary(path, mat, dtype='d'):
        dims  = len(mat.shape)
        shape = mat.shape
        file_out = open(path, "wb")
        file_out.write(struct.pack('I', dims))
        file_out.write(struct.pack(str(dims)+'I', *(shape)))
        file_out.write(struct.pack(str(np.prod(shape))+dtype, *(mat.flatten())))
        file_out.close()
    
    @staticmethod
    def load_nii_(path, verbose=False):
        '''
        Load a volume with nibabel and return its voxel data as a 3D numpy array.

        path: path to the volume file passed to nib.load
        verbose: additionally print the affine matrix and the header
        Singleton dimensions are squeezed away; the result must then be 3D.
        Always prints the volume shape and the type of its first voxel.
        Returns the squeezed array built from get_fdata().
        '''
        vol_struct = nib.load(path)
        vol = np.squeeze(np.array(vol_struct.get_fdata()))
        # Use ndim directly rather than `size`, which this file never imports
        # (it presumably comes from the %run'd notebook).
        assert(vol.ndim == 3)
        print(vol.shape)
        print(type(vol[0,0,0]))

        if verbose:
            print(vol_struct.affine)
            print(vol_struct.header)
        return vol
    
#     @staticmethod
#     def read_h5py(path, shape):
#         with h5py.File(path, "r") as hf:    

#             # Split the data into training/test features/targets
#             X_train = hf["X_train"][:]
#             targets_train = hf["y_train"][:]
#             X_test = hf["X_test"][:] 
#             targets_test = hf["y_test"][:]

#             # Determine sample shape
#             sample_shape = shape

#             # Reshape data into 3D format
#             X_train = Util_gen.rgb_data_transform(X_train)
#             X_test = Util_gen.rgb_data_transform(X_test)
#         return X_train, X_test, targets_train, targets_test
    
    @staticmethod
    def save_image_batch(data, folder, prefix, scalor, number_offset, binary_out, num=-1):
        '''
        data: the image data to be saved
        folder: the path to the folder where the images are saved
        prefix: prefix_00000.png
        scalor: data * scalor + scalor, 0 indicates no scaling
        number_offset: the index of the image to start saving from (included)
        num: number of images to save out, -1 means all
        '''
        data = np.squeeze(data)
        assert(len(data.shape) == 3 or len(data.shape) == 4)
        batch_size  = data.shape[0]
        image_shape = data.shape[-2:]
        
        if num == -1 or num >= batch_size:
            out_num = batch_size
        else:
            out_num = num
            
        for idx in range(out_num):
            out_name = ''
            for _ in range(5 - len(str(idx + number_offset))):
                out_name += '0'
            dat_ = np.reshape(data[idx], image_shape)
            if scalor > 0:
                dat_ = dat_ * scalor + scalor
            if binary_out == False:
                out_name = folder + "/" + prefix + "_" + out_name + str(idx + number_offset) + ".png"
                dat_ = dat_.astype(np.uint8)
                cv2.imwrite(out_name, dat_)
            else:
                out_name = folder + "/" + prefix + "_" + out_name + str(idx + number_offset) + ".dat"
                FileIO.write_matrix_binary(out_name, dat_, 'f')
        print("Data write out complete.")
    
    @staticmethod
    def img_boundbox_batch_(dirIn, dirOut, ext, thick):
        '''
        dirIn: the folder containing the images
        dirOut: the folder to output .bin dipha files
        ext: target extension of the images
        thick: thickness of the bounding box
        This function reads and output images in GRAYSCALE
        '''
        os.chdir(dirIn)
        Path(dirOut).mkdir(parents=True, exist_ok=True)
        for file in glob.glob("*."+ext):
            pathIn = os.path.join(dirIn, file)
            pathOut = os.path.join(dirOut, file)
            img = cv2.imread(pathIn, cv2.IMREAD_GRAYSCALE)
            cv2.rectangle(img, (0,0),(img.shape[1]-1, img.shape[0]-1),(0), thick)
            cv2.imwrite(pathOut, img)
            
    @staticmethod
    def img_connected_component_cv_batch_(dirIn, dirOut, ext, binarize, fillRift, shuffle_num):
        '''
        dirIn: the folder containing the images
        dirOut: the folder to output .bin dipha files
        ext: target extension of the images
        This function reads images and outputs their connected components ranging from 0 to N
        '''
        filesIn  = []
        filesOut = []
        os.chdir(dirIn)
        Path(dirOut).mkdir(parents=True, exist_ok=True)
        for file in glob.glob("*."+ext):
            name = file[:-len(ext)-1]
            filesIn.append(os.path.join(dirIn, file))
            filesOut.append(os.path.join(dirOut, name+".dat"))
        for i in range(len(filesIn)):
            img = cv2.imread(filesIn[i], cv2.IMREAD_GRAYSCALE)
            if binarize == True:
                img[img>127]  = 255
                img[img<=127] = 0
            cnt, hry, red = Util_cv.compute_bnd_red_cv(img, 0, 255, 8)
            if fillRift == True:
                filled = Util_gen.fill_rift_(red[1], 8)
                if shuffle_num > 0:
                    Util_gen.shuffle_partition_label(filled, shuffle_num, filesOut[i])
                else:
                    FileIO.write_matrix_binary(filesOut[i], filled, 'i')
            else:
                if shuffle_num > 0:
                    Util_gen.shuffle_partition_label(red[1], shuffle_num, filesOut[i])
                else:
                    FileIO.write_matrix_binary(filesOut[i], red[1], 'i')