# In[ ]:
import copy
import os
import random

import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
    
class Util_cv(object):
    '''
    OpenCV-based helpers that extract contours (boundaries) and connected
    components from thresholded images. All members are static methods.
    '''

    @staticmethod
    def compute_bnd_red_cv(img, low_th, high_th, connectivity):
        '''
        Threshold an image, then extract its contours and connected components.
        img: image passed to cv2.threshold with THRESH_OTSU; OpenCV implements
             Otsu only for single-channel 8-bit images.
        low_th: threshold value given to cv2.threshold; with THRESH_OTSU OpenCV
                computes the threshold itself and uses it instead of this value.
        high_th: value assigned to pixels above the threshold (THRESH_BINARY maxval).
        connectivity: 4 or 8, used for connected-component labelling.
        Returns (contours, hierarchy, reduction): the contours and hierarchy from
        cv2.findContours (RETR_CCOMP, CHAIN_APPROX_NONE) and the
        (num_labels, labels) tuple returned by cv2.connectedComponents.
        '''
        _, thresh = cv2.threshold(img, low_th, high_th,
                                  cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
        # (contours, hierarchy); the last two items mean the same in both.
        contours, hierarchy = cv2.findContours(
            thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2:]
        # Pass by keyword: the second positional parameter of
        # cv2.connectedComponents is the optional output `labels` array, so a
        # positional `connectivity` would be misrouted.
        reduction = cv2.connectedComponents(
            thresh, connectivity=connectivity, ltype=cv2.CV_32S)
        return contours, hierarchy, reduction

    @staticmethod
    def compute_bnd_red_cv_batch(data):
        '''
        Run compute_bnd_red_cv on every image of a batch.
        data: batch indexed along axis 0, in format batch_size * [height * width];
              each data[i] is processed with thresholds (0, 255) and 8-connectivity.
        Returns three lists of length batch_size (contours, hierarchies,
        connected-component results), one entry per image.
        '''
        bnd_res, hcy_res, red_res = [], [], []
        for i in range(data.shape[0]):
            bnd, hcy, red = Util_cv.compute_bnd_red_cv(data[i, :], 0, 255, 8)
            bnd_res.append(bnd)
            hcy_res.append(hcy)
            red_res.append(red)
        return bnd_res, hcy_res, red_res

class Util_gen(object):
    '''
    General-purpose numpy/scipy helpers: color mapping, distance transforms,
    Gaussian kernels and label-map manipulation. All members are static methods.
    '''

    @staticmethod
    def array_to_color(array, cmap="Oranges"):
        '''
        Map scalar values to RGB colors with a matplotlib colormap.
        array: 1-D or 2-D array of scalar values, normalized by the
               ScalarMappable's default norm before the colormap is applied.
        cmap: name of a matplotlib colormap.
        Returns an array of shape array.shape + (3,); the alpha channel of the
        RGBA result is dropped.
        '''
        s_m = plt.cm.ScalarMappable(cmap=cmap)
        # Slice the last axis (not [:, :-1]) so the alpha channel is removed
        # correctly for inputs of any supported rank.
        return s_m.to_rgba(array)[..., :-1]

    @staticmethod
    def rgb_data_transform(data, shape=(16, 16, 16)):
        '''
        Color-map every sample of a batch and reshape it into an RGB volume.
        data: batch indexed along axis 0; each data[i] is passed to
              array_to_color and must hold prod(shape) values.
        shape: spatial shape of one sample (default (16, 16, 16)).
        Returns a float32 array of shape (batch,) + shape + (3,).
        '''
        out_shape = tuple(shape) + (3,)
        colored = [Util_gen.array_to_color(data[i]).reshape(out_shape)
                   for i in range(data.shape[0])]
        return np.asarray(colored, dtype=np.float32)

    @staticmethod
    def dist_trfm(dat, binarize=False, inverse=False):
        '''
        compute distance transform: distance from a pixel to nearest zero value pixel
        dat: data to compute on (not modified)
        binarize: binarize the data so it has values either 0 or 255 (> 127 -> 255)
        inverse: inverse the distance transform result, i.e. max(dt) - dt
        Returns the Euclidean distance transform from scipy.ndimage.
        '''
        if binarize:
            # Build a new array rather than thresholding `dat` in place, so the
            # caller's data is left untouched.
            dat = np.where(np.asarray(dat) > 127, 255, 0)
        dt = ndimage.distance_transform_edt(dat)
        if inverse:
            dt = np.amax(dt) - dt
        return dt

    @staticmethod
    def generate_gaussian_kernel(size, sigma, dims):
        '''
        Build a separable Gaussian kernel.
        @size: samples per axis, taken at np.linspace(-(size//2), size//2, size)
        @sigma: gaussian kernel sigma
        @dims: supports 1, 2, or 3
        Returns:
            dims == 1: float64 1-D Gaussian pdf values (not rescaled);
            dims 2 or 3: float32 array of shape (size,)*dims built from the
                outer product of the 1-D kernel, rescaled so its maximum is 1.
        Raises ValueError for any other dims.
        '''
        if dims not in (1, 2, 3):
            raise ValueError(
                "generate_gaussian_kernel: unsupported dimension %r" % (dims,))

        def dnorm(x, mu, sd):
            return 1 / (np.sqrt(2 * np.pi) * sd) * np.e ** (-np.power((x - mu) / sd, 2) / 2)

        # Gaussian pdf evaluated (vectorized) at evenly spaced points centred on 0.
        kernel_1D = dnorm(np.linspace(-(size // 2), size // 2, size), 0, sigma)
        if dims == 1:
            return kernel_1D
        if dims == 2:
            prod = np.outer(kernel_1D, kernel_1D)
        else:
            prod = (kernel_1D[:, None, None] * kernel_1D[None, :, None]
                    * kernel_1D[None, None, :])
        # Cast to float32 before normalizing, matching the original per-element
        # assignment into a float32 buffer.
        res = prod.astype(np.float32)
        res = res * 1.0 / res.max()
        return res

    @staticmethod
    def shuffle_partition_label(part, shuffle_num, dirOut):
        '''
        Produce randomly relabelled copies of an integer label map.
        part: integer label array; labels span [part.min(), part.max()].
        shuffle_num: number of permutations drawn with random.sample; duplicate
                     permutations are discarded, so at most shuffle_num copies result.
        dirOut: "nosave" to return the copies; otherwise a file path. Copy i is
                written to "<dirOut without extension>_<i>.dat" and `part` itself
                to dirOut, via FileIO.write_matrix_binary with type code 'i'.
        Returns the list of relabelled arrays when dirOut == "nosave", else None.
        '''
        min_val = int(np.amin(part))
        max_val = int(np.amax(part))
        labels = range(min_val, max_val + 1)
        num = len(labels)

        # Keep the first occurrence of each permutation in draw order; the set
        # makes the duplicate check O(1) instead of scanning the stash.
        shuffle_stash = []
        seen = set()
        for _ in range(shuffle_num):
            shuffle = random.sample(labels, num)
            key = tuple(shuffle)
            if key not in seen:
                seen.add(key)
                shuffle_stash.append(shuffle)

        # Relabel through a lookup table (label j -> shuffle[j - min_val]) in one
        # vectorized pass; no temporary offset is added, so narrow integer
        # dtypes cannot overflow.
        part_arr = np.asarray(part)
        offsets = part_arr - min_val
        shuffle_results = [np.asarray(shuffle, dtype=part_arr.dtype)[offsets]
                           for shuffle in shuffle_stash]

        if dirOut == "nosave":
            return shuffle_results
        # splitext strips only the final extension; split('.')[0] truncated
        # paths such as "./out/x.dat" at the first dot.
        stem = os.path.splitext(dirOut)[0]
        for i, result in enumerate(shuffle_results):
            FileIO.write_matrix_binary(stem + '_' + str(i) + '.dat', result, 'i')
        FileIO.write_matrix_binary(dirOut, part, 'i')

    @staticmethod
    def fill_rift_(mat_, connectivity):
        '''
        Fill the rift (0-valued pixels) between labelled regions of a 2-D matrix.
        mat_: the input matrix (not modified); 0 marks the rift, positive values
              are region labels, negative values are never used as fill sources.
        connectivity: 4 or 8
        Works in passes: every 0 pixel with at least one positive neighbour takes
        the value of one such neighbour chosen with random.choice; pixels filled
        in a pass only act as sources in the next pass. Stops when no 0 pixel can
        be filled (0 pixels with no reachable positive label remain 0).
        Returns the filled copy. Raises ValueError if connectivity is not 4 or 8.
        '''
        # (dy, dx) offsets in the order neighbours are collected; the order
        # matters because the random pick indexes into this sequence.
        edge = [(-1, 0), (1, 0), (0, -1), (0, 1)]          # top, bottom, left, right
        diagonal = [(-1, -1), (-1, 1), (1, -1), (1, 1)]    # tl, tr, bl, br
        if connectivity == 4:
            offsets = edge
        elif connectivity == 8:
            offsets = edge + diagonal
        else:
            raise ValueError(
                "fill_rift_: connectivity must be 4 or 8, got %r" % (connectivity,))

        mat = copy.copy(mat_)
        height, width = mat.shape
        while True:
            to_be_changed = []
            # np.nonzero yields positions in row-major order, like a nested loop.
            for i, j in zip(*np.nonzero(mat == 0)):
                neighbors = []
                for dy, dx in offsets:
                    y, x = i + dy, j + dx
                    if 0 <= y < height and 0 <= x < width and mat[y, x] > 0:
                        neighbors.append(mat[y, x])
                if neighbors:
                    to_be_changed.append((i, j, random.choice(neighbors)))
            if not to_be_changed:
                # Either nothing is left to fill or the remaining zeros can never
                # reach a positive label; stop instead of looping forever.
                break
            for i, j, label in to_be_changed:
                mat[i, j] = label
        return mat