In [None]:
# import sys
# # python library compile from cpp
# sys.path.insert(0, './persis_lib_cpp')
# from persis_homo_optimal import *

%run Utility_general.ipynb
%run Utility_topo.ipynb

from persim import PersImage
from numpy import linalg as LA
from mpl_toolkits.mplot3d import Axes3D
from cvxopt import matrix, spmatrix, sparse, solvers

# # installed lib for persistence image
# import PersistenceImages.persistence_images as pimg

class Edges_(object):
    
    def __init__(self, params, debug=False):
        '''
        Build the helper objects and copy the "Topo_edge" settings onto the instance.
        ===== inputs
        params: nested dictionary; params["persimg"] is forwarded to PersImg_ and
                params["Topo_edge"] supplies every option listed below
        debug: stored as self.debug and forwarded to the homology helper by some methods
        '''
        self.pc      = Persistence_Computer()
        self.persimg = PersImg_(params["persimg"], False)
        self.debug   = debug

        # (attribute name, key inside params["Topo_edge"]) in the order they are read
        cfg = params["Topo_edge"]
        attr_to_key = (
            ("target_dim",      "target_topo_dimension"),
            ("tp_loss_weight",  "topology_loss_weight"),
            ("blind_force",     "use_blind_force"),
            ("project2dim",     "project_2_dim"),
            ("watershed",       "image_watershed"),
            ("topo_threshold",  "target_topo_threshold"),
            ("pts_to_fix",      "number_of_pts_to_fix"),
            ("hole_kernel_r",   "hole_test_kernel_radius"),
            ("hole_iterations", "hole_test_iteration"),
            ("hole_bordwidth",  "hole_test_border_width"),
            ("hole_threshold",  "hole_test_threshold"),
            ("opposite2",       "detect_opposite_2pts"),
            ("connect",         "connect_2pts"),
            ("thickness",       "segment_thickness"),
            ("pd_subset_rate",  "pd_subset_rate"),
            ("shuffle_subset",  "shuffle_subset"),
        )
        for attr_name, key in attr_to_key:
            setattr(self, attr_name, cfg[key])
        
    def load_pd_pool(self, dir_in, ext_in, percentage, batch_size):
        '''
        Load the persistence diagrams pre-computed for the database and build the
        LP solver used to match a generated batch against them. Has to be called
        before subset_pd_pool(), which reads self.set_ref.
        If self.topo_threshold > 0 the diagrams are filtered first; if
        self.project2dim >= 0 (must be 0 or 1) they are then reduced to that
        coordinate.
        ===== inputs
        dir_in/ext_in: forwarded to FileIO.read_pd_subset
        percentage: percentage of database to be loaded, max 1.0
        batch_size: number of images in each generated batch, has to be fixed all along, integer
        ===== side effects
        sets self.set_ref, self.ref_num, self.read_num and self.lp_
        '''
        loaded = FileIO.read_pd_subset(dir_in, ext_in, percentage, shuffle=False, dummyifempty=True)
        self.set_ref, _unused = loaded
        print("Reference persistence diagram read complete.")

        threshold = self.topo_threshold
        if threshold > 0:
            self.set_ref = Utility_topo.topo_filter_retmat(self.set_ref, threshold)
            print("Reference persistence diagram filter complete.")

        proj_axis = self.project2dim
        if proj_axis >= 0:
            assert(proj_axis < 2)
            self.set_ref = Utility_topo.extract_dim_from_list(self.set_ref, proj_axis)
            print("Reference persistence diagram extracted from dimension %d." %proj_axis)

        # Only a fraction (pd_subset_rate) of the pool is matched against each batch.
        pool_size     = len(self.set_ref)
        self.ref_num  = pool_size
        self.read_num = int(np.floor(pool_size * self.pd_subset_rate))
        self.lp_      = LPSolver_(batch_size, self.read_num, False)
        print("Compute wasserstein distance between %d and %d" %(batch_size, self.read_num))
        
    def subset_pd_pool(self, num, read_num, shuffle=True):
        if shuffle:
            ind_list = random.sample(range(num), read_num)
        else:
            ind_list = np.arange(read_num)
        subset = [None] * read_num
        for i in range(read_num):
            subset[i] = self.set_ref[ind_list[i]]
        return subset
    
    def fix_with_topo(self, gen, dim, value, wassertein_dist=1.0, blind=False):
        '''
        Compute the persistent homology of a generated batch, derive a "force"
        on its persistence diagrams (either blindly or by matching against the
        reference pool) and paint the resulting pixel coordinates onto the images.
        ===== inputs
        gen: generated 1-channel images from generator, should have value from -1.0 to 1.0
        dim: dimension of the topological structures to extract, has to be 0 or 1
        value: the value assigned to detected points, in range of -1.0 and 1.0
        wassertein_dist: 1-wasserstein or 2-wasserstein distance, has to be 1.0 or 2.0
        blind: if True use Utility_topo.topo_force_blind_ and skip the matching
               against the reference pool
        ===== outputs
        gen_res: copy of gen (squeezed by plot_on_images) with the detected points set to value
        mean_wasserstein_dist: mean of the matched distances, or -1.0 when blind is True
        '''
        assert(dim == 0 or dim == 1)
        gen_bin   = Utility_topo.binarize_data(gen, self.watershed)
        tsfm_list = Utility_topo.dist_trfm_batch(gen_bin)
        if dim == 0:
            # dim 0 needs the extra debug output (red1) that apply_force_dim0_ consumes below
            B, D, set1, _, red1 = Utility_topo.compute_dist_homology(gen_bin.shape,
                         tsfm_list, self.pc, dim, debug=True, old_form=False)
        else:
            B, D, set1 = Utility_topo.compute_dist_homology(gen_bin.shape,
                     tsfm_list, self.pc, dim, debug=False, old_form=False)
            
        if self.topo_threshold > 0:
            if dim == 0:
                print("Dim 0 with topology threshold > 0 detected.")
                # must run before set1 is replaced by its filtered version on the next statement
                red1 = Utility_topo.topo_filter_retmat_bndorred_mul(red1, set1, self.topo_threshold)
            B, D, set1 = Utility_topo.topo_filter_retmat_mul(B, D, set1, self.topo_threshold)
        if self.project2dim >= 0:
            set1 = Utility_topo.extract_dim_from_list(set1, self.project2dim)
        if blind:
            force = Utility_topo.topo_force_blind_(set1, dim)
            # sentinel: no distance is computed in blind mode
            mean_wasserstein_dist = -1.0
        else:
            # requires load_pd_pool() to have been called (self.ref_num, self.read_num, self.lp_)
            set2 = self.subset_pd_pool(self.ref_num, self.read_num, self.shuffle_subset)
            dist, G = Utility_general.wasserstein_set_distance(set1, set2, wassertein_dist)
            # keep, for every generated sample, only its LP-selected reference match
            mapping, dist, G = self.lp_.linear_program_(dist, G)
            force = Utility_topo.topo_force_(set1, set2, G, mapping)
            mean_wasserstein_dist = np.mean(np.asarray(dist))

        if dim == 0:
            flt_lst, frc_x_, frc_y_ = Utility_topo.apply_force_dim0_(force, red1, False)
        else:
            bnd_cv, hcy_cv, red_cv  = Utility_topo.compute_bnd_red_cv_batch(gen_bin)
            flt_lst, frc_x_, frc_y_ = Utility_topo.apply_force_(set1, force, B, bnd_cv, hcy_cv, self.pts_to_fix,
                                      False, self.opposite2, self.connect, gen_bin.shape[1:], self.thickness)
        # frc_x_/frc_y_ are used as per-image pixel coordinates; flt_lst is not used here
        gen_res = self.plot_on_images(gen, frc_x_, frc_y_, value)        
        
        return gen_res, mean_wasserstein_dist
    
    def treat_edges(self, d_ori):
        '''
        Detect fix points for the dimension-1 features of a batch and paint
        them with -1.0.
        ===== inputs
        d_ori: image batch; it is squeezed before processing
        ===== outputs
        the painted batch with a new axis inserted at position 1
        '''
        batch  = np.squeeze(d_ori)
        binary = Utility_topo.binarize_data(batch, self.watershed)

        # persistent homology (dimension 1) of every image
        transforms = Utility_topo.dist_trfm_batch(binary)
        birth_x, birth_y, _, _, diagrams = Utility_topo.compute_dist_homology(
            batch.shape, transforms, self.pc, 1, self.debug)
        kept_idx = Utility_topo.topo_filter_retindex(diagrams, self.topo_threshold)

        bnd_res, _, red_res = Utility_topo.compute_bnd_red_cv_batch(binary)
        fix_x, fix_y, _ = self.detection_fix_points(
            birth_x, birth_y, diagrams, kept_idx, bnd_res, red_res, batch.shape[1:3])

        painted = self.plot_on_images(batch, fix_x, fix_y, -1.0)
        return np.expand_dims(painted, axis=1)
        
    def persimg_batch(self, d, binarize=False):
        '''
        Compute persistence images for the input batch (dimension-1 features).
        ===== inputs
        d: input images, if binarize=False, d should have value EITHER 0 OR 255
           (binary output from binarize_data() function)
        binarize: if to binarize d with self.watershed first
        ===== outputs
        the result of self.persimg.pim_batch() on the batch persistence diagrams
        '''
        d_bin = Utility_topo.binarize_data(d, self.watershed) if binarize else d

        # The transform used to be computed twice on the same input; once is enough.
        tsfm_list = Utility_topo.dist_trfm_batch(d_bin)

        # Only the diagrams (fifth output) are needed for the persistence images.
        # NOTE(review): five outputs are unpacked, which appears to rely on the
        # helper's default output form with self.debug being False; confirm.
        _, _, _, _, pd = Utility_topo.compute_dist_homology(d_bin.shape, tsfm_list, self.pc, 1, self.debug)
        return self.persimg.pim_batch(pd)
    
    def pd_batch(self, d, dim, debug=False, old_form=False, binarize=False, disttrfm=True):
        '''
        Compute persistence homology for the input batch.
        ===== inputs
        d: input images, should be numpy array NOT list, if binarize=False,
           d should have value EITHER 0 OR 255 (binary output from binarize_data() function)
        dim: dimension to extract topological information
        debug/old_form: refer to Utility_topo.compute_dist_homology
        binarize: if to binarize d
        disttrfm: True feeds Utility_topo.dist_trfm_batch(d) to the homology helper,
                  False feeds Utility_general.flatten_image_batch(d) instead
        ===== outputs
        old_form and debug:     birth_x, birth_y, death_x, death_y, pd, bnd, red
        old_form only:          birth_x, birth_y, death_x, death_y, pd
        debug only:             B, D, PD, tsfm_list, bnd, red
        neither:                B, D, PD, tsfm_list
        '''
        images = Utility_topo.binarize_data(d, self.watershed) if binarize else d

        if disttrfm:
            tsfm_list = Utility_topo.dist_trfm_batch(images)
        else:
            tsfm_list = Utility_general.flatten_image_batch(images)

        shape = images.shape

        if old_form and debug:
            bx, by, dx, dy, diagrams, bnd, red = Utility_topo.compute_dist_homology(
                shape, tsfm_list, self.pc, dim, debug=True, old_form=True)
            return bx, by, dx, dy, diagrams, bnd, red

        if old_form:
            bx, by, dx, dy, diagrams = Utility_topo.compute_dist_homology(
                shape, tsfm_list, self.pc, dim, debug=False, old_form=True)
            return bx, by, dx, dy, diagrams

        # New form: the transformed inputs are handed back alongside the results.
        if debug:
            births, deaths, diagrams, bnd, red = Utility_topo.compute_dist_homology(
                shape, tsfm_list, self.pc, dim, debug=True, old_form=False)
            return births, deaths, diagrams, tsfm_list, bnd, red

        births, deaths, diagrams = Utility_topo.compute_dist_homology(
            shape, tsfm_list, self.pc, dim, debug=False, old_form=False)
        return births, deaths, diagrams, tsfm_list
    
    def plot_on_images(self, images, x, y, value):
        '''
        plot coordinates on the images
        ===== inputs
        images: image batch, batch_size * [images]
        x/y: batch_size * [integers]
        '''
        d = copy.copy(np.squeeze(images))
        bat_size = d.shape[0]
        for i in range(bat_size):
            for j in range(len(x[i])):
                d[i][y[i][j]][x[i][j]] = value
        return d
    
    def fd_cycle_distance(self, folder_A, folder_B, ext, normalize=True, topothresh=-1.0, disttrfm=False):
        '''
        folder_A: path to the reference folder
        folder_B: path to the target folder
        if folder_A equals to folder_B, program will save the second load
        '''             
        %run cyckernel/Cyckernel.ipynb
        
        img_batch = Utility_general.read_image_subset(folder_A, ext, 1.0, shuffle=False)
        img_batch = np.stack(img_batch)
        _, _, PDA, _, bnd_phA, _ = self.pd_batch(img_batch, 1, debug=True, old_form=False, binarize=False, disttrfm=disttrfm)
        if topothresh > 0.0:
            bnd_phA = Utility_topo.topo_filter_retmat_bndorred_mul(bnd_phA, PDA, 1.0)
       
        if folder_A == folder_B:
            kernel  = CycleKernel(sigma=10.).fit(bnd_phA)
            cycdist = kernel.transform(bnd_phA)
        else:
            img_batch = Utility_general.read_image_subset(folder_B, ext, 1.0, shuffle=False)
            img_batch = np.stack(img_batch)
            _, _, PDB, _, bnd_phB, _ = self.pd_batch(img_batch, 1, debug=True, old_form=False, binarize=False, disttrfm=disttrfm)
            if topothresh > 0.0:
                bnd_phB = Utility_topo.topo_filter_retmat_bndorred_mul(bnd_phB, PDB, 1.0)
            kernel  = CycleKernel(sigma=10.).fit(bnd_phA)
            cycdist = kernel.transform(bnd_phB)
            
        if normalize:
            column_ = cycdist.shape[0]
            for i in range(column_):
                cycdist[i,:] = cycdist[i,:] / np.linalg.norm(cycdist[i,:]) * 10
        
        return cycdist
    
    def fd_wasserstein_distance(self, folder_A, folder_B, ext, dim, wassertein_dist,
        topothresh=-1.0, binarize=False, disttrfm=False):
        '''
        Wasserstein distances between the persistence diagrams of two image folders.
        ===== inputs
        folder_A: path to the target folder
        folder_B: path to the reference folder
        Note folder order is REVERSE of fd_cycle_distance function!!
        If both paths are equal the diagrams are computed once and compared with themselves.
        ext: extension of the images like "png"
        dim: dimension of the topology features, integer
        wassertein_dist: 1.0 or 2.0
        topothresh: topo threshold, diagrams are filtered when it is > 0
        binarize: if to binarize input images to 0 OR 255
        disttrfm: forwarded to pd_batch
        Diagrams are additionally projected when self.project2dim >= 0.
        ===== outputs
        dist: first output of Utility_general.wasserstein_set_distance
        '''
        def _diagrams(folder):
            # Load a folder and return its (optionally filtered / projected) diagrams.
            images = np.stack(Utility_general.read_image_subset(folder, ext, 1.0, shuffle=False))
            _, _, diagrams, _ = self.pd_batch(images, dim, debug=False, old_form=False,
                                              binarize=binarize, disttrfm=disttrfm)
            if topothresh > 0.0:
                diagrams = Utility_topo.topo_filter_retmat(diagrams, topothresh)
            if self.project2dim >= 0:
                diagrams = Utility_topo.extract_dim_from_list(diagrams, self.project2dim)
            return diagrams

        target    = _diagrams(folder_A)
        reference = target if folder_A == folder_B else _diagrams(folder_B)

        dist, _ = Utility_general.wasserstein_set_distance(target, reference, wassertein_dist)
        return dist
        
    def detection_fix_points(self, bx, by, pd, ind, bnd_res, red_res, shape):
        bat_size = len(ind)
        fix_x    = [None] * bat_size
        fix_y    = [None] * bat_size
        crt_birth_distrfm_val = [None] * bat_size
        kernel = np.ones((self.hole_kernel_r, self.hole_kernel_r), np.uint8)
        
        for i in range(bat_size):
            set_x = []
            set_y = []
            crt_b_dt_val = []
            good_section_record = []
            for idx in ind[i]:
                countour_idx = Utility_topo.return_countour_with_p_inside(bnd_res[i], (bx[i][idx], by[i][idx]))
                if countour_idx >= 0:
                    section_label = red_res[i][1][by[i][idx]][bx[i][idx]]
                    if section_label not in good_section_record:
                        hole_test = Utility_topo.dangling_edge_test(section_label, red_res[i],
                                    shape, kernel, self.hole_iterations, self.hole_bordwidth)
                        if hole_test > self.hole_threshold:
                            good_section_record.append(section_label)
                    if section_label in good_section_record:
                        pts_x, pts_y = Utility_general.find_closest_N_points((bx[i][idx], by[i][idx]),
                                       bnd_res[i][countour_idx], self.pts_to_fix)
                        set_x = set_x + pts_x
                        set_y = set_y + pts_y
                        crt_b_dt_val.append(pd[i][idx][0])
            fix_x[i] = set_x
            fix_y[i] = set_y
            crt_birth_distrfm_val[i] = crt_b_dt_val
        return fix_x, fix_y, crt_birth_distrfm_val
    
    def compute_topo_loss(self, data_origin, fix_x, fix_y, crt_birth_distrfm_val):
        bat_size = data_origin.shape[0]
        loss = [0.0] * bat_size
        
        for i in range(bat_size):
            l_ = 0.0
            for j in range(len(fix_x[i])):
                cur_val = crt_birth_distrfm_val[i][int(j/self.pts_to_fix)]
                l_ = l_ + (data_origin[i][fix_y[i][j]][fix_x[i][j]] - (-1.0)) * cur_val
            loss[i] = l_
        return loss
    
    def test(self, d, device):
        c = torch.randn([128, 1, 64, 64]).to(device)
        loss = torch.abs(c - d).mean()
        return loss
    
    def return_tp_weight(self):
        return self.tp_loss_weight
    
    def return_target_dim(self):
        return self.target_dim
    
    def blind(self):
        return self.blind_force
    
class PersImg_(object):
    
    def __init__(self, params, verbose):
        '''
        Create the persim PersImage transformer.
        ===== inputs
        params: dictionary providing "spread" and "pixels"
        verbose: forwarded to PersImage
        '''
        spread = params["spread"]
        pixels = params["pixels"]
        self.pim = PersImage(spread=spread, pixels=pixels, verbose=verbose)
        
    def pim_single(self, dgm):
        '''Return self.pim.transform(dgm), the persistence image of one diagram.'''
        image = self.pim.transform(dgm)
        return image
    
    def pim_batch(self, phc_pd):
        '''
        Persistence images for a batch of diagrams.
        ===== inputs
        phc_pd: batch of persistence diagrams; converted with
                Utility_topo.convert_phc_pd_2_persim_batch before the transform
        ===== outputs
        array holding one persistence image per diagram along axis 0
        '''
        count = len(phc_pd)
        dgms = Utility_topo.convert_phc_pd_2_persim_batch(phc_pd, np.arange(count))
        images = [self.pim_single(dgms[k]) for k in range(count)]
        # stacking on a new leading axis == expand_dims(…, 0) followed by concatenate
        return np.stack(images, axis=0)
    
class LPSolver_(object):
    
    def __init__(self, M, N, show_progress):
        '''
        Find a match so that the distance is minimized through
        linear programming. Distance matrix should be M X N.
        Samples on left are targets, samples on top are database.
        ===== inputs
        M: number of rows of the distance matrix
        N: number of columns of the distance matrix
        show_progress: written to the global cvxopt solvers.options['show_progress']
        ===== side effects
        also sets the global solvers.options['glpk'] to silence GLPK messages
        '''
        solvers.options['show_progress'] = show_progress
        solvers.options['glpk'] = {'msg_lev': 'GLP_MSG_OFF'}

        # Constraint matrix with one row per (i, j) pair, stored at row i * N + j:
        # +1.0 in column i of the left block and -1.0 in column j of the right
        # block. Built directly from index lists instead of re-concatenating a
        # growing sparse matrix M times.
        n_pairs  = M * N
        pair_row = range(n_pairs)
        left_col  = [i for i in range(M) for _ in range(N)]
        right_col = list(range(N)) * M
        A_ = spmatrix(1.0, pair_row, left_col, (n_pairs, M))
        D_ = spmatrix(-1.0, pair_row, right_col, (n_pairs, N))
        self.A = sparse([[A_], [D_]])

        # Objective: -1/M on the first M variables, +1/N on the remaining N.
        self.c = matrix([-1.0 / M] * M + [1.0 / N] * N)

        # determine initial point
        self.pStart = {}
        self.pStart['x'] = matrix([1.0] * M + [-1.0] * N)
        self.pStart['s'] = matrix([1.0]*(M + N))
        self.M           = M
        self.N           = N
        
    def approx_OT(self, sol):
        '''
        Each sample on the left has a best match.
        '''
        ResMat = np.array(sol['z']).reshape((self.M, self.N))
        mapping = np.argmax(ResMat, axis=1).astype(np.int64)
        return mapping
        
    def Wasserstein_LP(self, dist):
        '''
        Solve the linear program for one M x N distance matrix with GLPK.
        ===== inputs
        dist: numpy array of shape (M, N); flattened row by row into the right-hand side
        ===== outputs
        the cvxopt solution dictionary; its 'x' and 's' entries are also stored
        in self.pStart for the next call
        '''
        assert(dist.shape[0] == self.M and dist.shape[1] == self.N)
        rhs = matrix(dist.astype(np.double).flatten())
        # NOTE(review): confirm whether the glpk backend honors primalstart.
        result = solvers.lp(self.c, self.A, rhs, primalstart=self.pStart, solver='glpk')
        self.pStart.update(x=result['x'], s=result['s'])
        return result
    
    def linear_program_(self, dist, G):
        '''
        Solve the matching for dist and keep, per row, only the selected column.
        ===== inputs
        dist: M x N distance matrix (numpy array)
        G: per-pair data indexed as G[i][j]
        ===== outputs
        mapping: selected column index for every row
        dist_: list with dist[i, mapping[i]] for every row
        G_: list with G[i][mapping[i]] for every row
        '''
        mapping = self.approx_OT(self.Wasserstein_LP(dist))
        picked_dist = [dist[row, col] for row, col in enumerate(mapping)]
        picked_G    = [G[row][col] for row, col in enumerate(mapping)]
        return mapping, picked_dist, picked_G