In [None]:
import time
import struct
from sklearn import svm
from sklearn import preprocessing
from sklearn.kernel_approximation import *
from sklearn.preprocessing import normalize

class CycleKernel(BaseEstimator, TransformerMixin):
    """Mean-embedding kernel between samples that each consist of several cycles.

    A sample is a sequence of cycles and every cycle is a 2-D array with one
    row per point.  Each cycle is summarised by the mean of its point
    features (raw coordinates, or random Fourier features when ``sigma > 0``).
    The kernel value of two samples is the mean of the inner products over
    all pairs of their cycle summaries.

    Parameters
    ----------
    sigma : float, default=-1
        ``-1`` uses the raw point coordinates.  A positive value is passed as
        ``gamma`` to ``RBFSampler`` and the resulting features are used.
        Any other value is rejected by ``fit``.
    num_components : int, default=100
        Number of random features produced by ``RBFSampler`` (unused when
        ``sigma == -1``).

    Attributes
    ----------
    red : list of ndarray
        One array of shape (n_cycles, n_features) per training sample.
    approx : RBFSampler
        Fitted feature map; only set when ``sigma > 0``.
    """

    def __init__(self, sigma=-1, num_components=100):
        self.sigma, self.num_components = sigma, num_components

    def _embed_sample(self, cycles):
        """Return the per-cycle mean features of one sample, one row per cycle."""
        if len(cycles) == 0:
            raise ValueError("every sample must contain at least one cycle")
        if self.sigma == -1:
            rows = [np.mean(cyc, axis=0) for cyc in cycles]
        else:
            rows = [np.mean(self.approx.transform(cyc), axis=0) for cyc in cycles]
        return np.stack(rows, axis=0)

    @staticmethod
    def _gram(left, right):
        """Return the kernel matrix between two lists of embedded samples.

        The mean over all entries of ``A @ B.T`` equals the inner product of
        the row means of ``A`` and ``B``, so every sample is collapsed to a
        single vector and one matrix product yields all kernel values.
        """
        if len(left) == 0 or len(right) == 0:
            return np.zeros((len(left), len(right)))
        left_means = np.stack([block.mean(axis=0) for block in left], axis=0)
        right_means = np.stack([block.mean(axis=0) for block in right], axis=0)
        return np.matmul(left_means, right_means.T)

    def fit(self, X, y=None):
        """Embed and store the training samples.

        Parameters
        ----------
        X : sequence of sequences of ndarray
            ``X[i][c]`` holds the points of cycle ``c`` of sample ``i``, one
            point per row.
        y : ignored

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``sigma`` is neither ``-1`` nor positive, or a sample has no
            cycle.
        """
        self.red = [None] * len(X)
        if self.sigma > 0:
            # RBFSampler.fit only needs the feature dimension, so one cycle is enough.
            self.approx = RBFSampler(gamma=self.sigma, n_components=self.num_components, random_state=10).fit(X[0][0])
        elif self.sigma != -1:
            raise ValueError("sigma must be -1 (raw coordinates) or positive, got %r" % (self.sigma,))
        for i in range(len(X)):
            print("fit: ", i/len(X))
            self.red[i] = self._embed_sample(X[i])
        return self

    def transform(self, X):
        """Return the kernel matrix between ``X`` and the training samples.

        The input is always embedded afresh: guessing "this is the training
        set" from ``len(X)`` would hand back the training Gram matrix for any
        unseen batch that happens to have the same size.

        Parameters
        ----------
        X : sequence of sequences of ndarray
            Same layout as in ``fit``.

        Returns
        -------
        ndarray of shape (len(X), n_training_samples)
        """
        embedded = []
        for i in range(len(X)):
            print("transform: ", i/len(X))
            embedded.append(self._embed_sample(X[i]))
        return self._gram(embedded, self.red)

    def fit_transform(self, X, y=None, **fit_params):
        """Fit on ``X`` and return its Gram matrix without embedding ``X`` twice."""
        self.fit(X, y)
        return self._gram(self.red, self.red)
    
class Compute_CycleKernel(object):
    """Static helpers that build cycle-kernel matrices from persistence-homology structures.

    NOTE: the ``normalize`` parameter of the public methods shadows any
    module-level function of the same name inside those methods; it is kept
    because it is part of the public signature.
    """

    _MODES = (">=", "<=", "dummy")

    @staticmethod
    def _check_mode(mode, name):
        """Raise ValueError unless ``mode`` is one of the supported operators."""
        if mode not in Compute_CycleKernel._MODES:
            raise ValueError("%s must be one of %r, got %r"
                             % (name, Compute_CycleKernel._MODES, mode))

    @staticmethod
    def _apply_mode(values, mode, threshold):
        """Return the boolean filter ``values <mode> threshold`` ('dummy' keeps everything)."""
        if mode == ">=":
            return values >= threshold
        if mode == "<=":
            return values <= threshold
        return np.ones(values.shape[0], dtype=bool)

    @staticmethod
    def _select_cycles(structures, keep):
        """Return the transposed structures whose entry in the boolean array ``keep`` is True."""
        return [np.transpose(structures[j]) for j in np.flatnonzero(keep)]

    @staticmethod
    def _kernel_matrix(dim_list, normalize, scale, sigma):
        """Fit a CycleKernel on ``dim_list`` and return its kernel matrix.

        When ``normalize`` is true every row is divided by its Euclidean norm
        and multiplied by ``scale``; all-zero rows are left as zeros instead
        of becoming NaN.
        """
        for idx, cycles in enumerate(dim_list):
            if len(cycles) == 0:
                raise ValueError("file %d has no structure left to build a kernel from" % idx)
        trfmed = CycleKernel(sigma=sigma).fit_transform(dim_list)
        if normalize:
            norms = np.linalg.norm(trfmed, axis=1, keepdims=True)
            norms[norms == 0] = 1.0
            trfmed = trfmed / norms * scale
        return trfmed

    @staticmethod
    def compute_cyckernel(bnd, dims, normalize, scale=1.0, sigma=10.0):
        '''
        Compute the cycle kernel matrix for persistence homology
        -- Input
        @bnd:  output from function read_bnd_red_unifieddim (format must be compatible)
               [[[3 x pt_num] x struct_num] x dim] x file_num
        @dims: persistence homology dimensions, 2 for 2D files, 3 for 3D files
        @normalize: if true, scale every row of the matrix to Euclidean norm `scale`
        @scale: norm of the rows after normalization
        @sigma: value forwarded to CycleKernel(sigma=...)
        -- Output:
            List with one [#files X #files] matrix per homology dimension,
            e.g. for 3D input: 0 dimension, 1 dimension, 2 dimension
        -- Raises:
            ValueError if a file has no structure in some dimension
        '''
        assert(dims == len(bnd[0]))
        cycs_grand_list = [None] * dims

        for dim in range(dims):
            # CycleKernel expects one point per row, hence the transpose.
            dim_list = [[np.transpose(cyc) for cyc in subject[dim]] for subject in bnd]
            cycs_grand_list[dim] = Compute_CycleKernel._kernel_matrix(dim_list, normalize, scale, sigma)
        return cycs_grand_list

    @staticmethod
    def compute_cyckernel_wThreshold(bnd, dims, pers, threshold, normalize, scale=1.0, sigma=10.0):
        '''
        Compute the cycle kernel matrix for persistence homology, keeping only the
        structures whose persistence (death - birth) is at least `threshold`
        -- Input
        @bnd:  output from function read_bnd_red_unifieddim (format must be compatible) [[[3 x pt_num] x struct_num] x dim] x file_num
        @dims: persistence homology dimensions, 2 for 2D files, 3 for 3D files
        @pers: output from read_pers_txt, [[struct_num x 2] x dim] x file_num
        @threshold: threshold for persistence
        @normalize: if true, scale every row of the matrix to Euclidean norm `scale`
        @scale: norm of the rows after normalization
        @sigma: value forwarded to CycleKernel(sigma=...)
        -- Output:
            List with one [#files X #files] matrix per homology dimension,
            e.g. for 3D input: 0 dimension, 1 dimension, 2 dimension
        -- Raises:
            ValueError if the threshold removes every structure of a file
        '''
        assert(dims == len(bnd[0]))
        file_number = len(bnd)
        kernel_dist = [None] * dims

        for dim in range(dims):
            dim_list = []
            for i in range(file_number):
                keep = pers[i][dim][:, 1] - pers[i][dim][:, 0] >= threshold
                dim_list.append(Compute_CycleKernel._select_cycles(bnd[i][dim], keep))
            kernel_dist[dim] = Compute_CycleKernel._kernel_matrix(dim_list, normalize, scale, sigma)
            print("dimension ", dim, " completes...")

        return kernel_dist

    @staticmethod
    def compute_cyckernel_wBornDeathLimit(bnd, dims, pers, labels, mode_born, threshold_born,
                                          mode_death, threshold_death, normalize, scale=1.0,
                                          sigma=10.0):
        '''
        Compute the cycle kernel matrix for persistence homology, the threshold is applied to the birth
        and/or death time. labels is passed in because some subjects may have no structure left
        after thresholding; those subjects are dropped together with their label
        -- Input
        @bnd:  output from function read_bnd_red_unifieddim (format must be compatible) [[[3 x pt_num] x struct_num] x dim] x file_num
        @dims: persistence homology dimensions, 2 for 2D files, 3 for 3D files
        @pers: output from read_pers_txt, [[struct_num x 2] x dim] x file_num
        @labels: [file_num]
        @mode_born: operator for birth time: ">=", "<=" or "dummy" (not used)
        @threshold_born: e.g. -650, not used if mode_born == "dummy"
        @mode_death: operator for death time: ">=", "<=" or "dummy" (not used)
        @threshold_death: e.g. -650, not used if mode_death == "dummy"
        @normalize: if true, scale every row of the matrix to Euclidean norm `scale`
        @scale: norm of the rows after normalization
        @sigma: value forwarded to CycleKernel(sigma=...)
        -- Output:
            (kernel_dist, filt_labels): per homology dimension a
            [#kept files X #kept files] matrix and the labels of the kept files.
            A dimension in which no file survives yields a (0, 0) matrix and [].
        -- Raises:
            ValueError if mode_born or mode_death is not a supported operator
        '''
        assert(dims == len(bnd[0]))
        file_number = len(bnd)
        assert(len(labels) == file_number)
        # Fail early: an unknown operator used to be reported with a print only,
        # after which an undefined or stale filter was silently used.
        Compute_CycleKernel._check_mode(mode_born, "mode_born")
        Compute_CycleKernel._check_mode(mode_death, "mode_death")
        kernel_dist = [None] * dims
        filt_labels = [None] * dims

        for dim in range(dims):
            dim_list = []
            dim_labels = []
            for i in range(file_number):
                born_filt = Compute_CycleKernel._apply_mode(pers[i][dim][:, 0], mode_born, threshold_born)
                death_filt = Compute_CycleKernel._apply_mode(pers[i][dim][:, 1], mode_death, threshold_death)
                combined_filt = np.logical_and(born_filt, death_filt)
                if not combined_filt.any():
                    # Nothing left for this subject: drop it and its label.
                    continue
                dim_list.append(Compute_CycleKernel._select_cycles(bnd[i][dim], combined_filt))
                dim_labels.append(labels[i])

            if dim_list:
                kernel_dist[dim] = Compute_CycleKernel._kernel_matrix(dim_list, normalize, scale, sigma)
            else:
                # No subject survived the filter; there is nothing to fit a kernel on.
                kernel_dist[dim] = np.zeros((0, 0))
            filt_labels[dim] = dim_labels
            print("dimension ", dim, " completes...")

        return kernel_dist, filt_labels

    @staticmethod
    def filter_bnd_or_red_by_mask(target, mask):
        '''
        Count the points of the topological structures that fall on a zero voxel of the mask.
        NOTE: despite its name this function does not remove anything from `target`;
        it prints the number of such points and returns it.
        @target: bnd or red for single file [[3 x pt_num] x struct_num] x dim;
                 rows 2, 1 and 0 of a structure index the first, second and
                 third axis of `mask` and must hold integers
        @mask: corresponding binary mask
        -- Output:
            number of points whose mask value is 0
        '''
        dims = len(target)
        assert(dims == 3)
        mask_arr = np.asarray(mask)

        cnt = 0
        for structures in target:
            for pts in structures:
                pts = np.asarray(pts)
                if pts.shape[1] == 0:
                    continue
                # One fancy-indexing lookup per structure instead of one per point.
                outside = mask_arr[pts[2], pts[1], pts[0]] == 0
                cnt += int(np.count_nonzero(outside))
        print("Number of points out of mask:", cnt)
        return cnt

    @staticmethod
    def filter_bnd_or_red(target, pers, threshold):
        '''
        Keep only the structures whose persistence (death - birth) is at least `threshold`.
        For every dimension the percentage of kept structures is printed.
        @target: bnd or red for single file [[3 x pt_num] x struct_num] x dim
        @pers: persistence for single file [struct_num x 2] x dim
        @threshold: threshold for persistence
        -- Output:
            list with, per dimension, the list of kept structures (original order)
        '''
        dims = len(pers)
        assert(dims == len(target))
        filtered_target = [None] * dims

        for dim in range(dims):
            keep = pers[dim][:, 1] - pers[dim][:, 0] >= threshold
            total = pers[dim].shape[0]
            # A dimension without structures has no meaningful percentage.
            percent = np.sum(keep) / total * 100 if total else float('nan')
            print('%.3f' % percent)
            filtered_target[dim] = [target[dim][i] for i in np.flatnonzero(keep)]
        return filtered_target