# In [None]:
import numpy as np
import sliced_wasserstein 

class sliced_wasserstein_python(object):
    '''
    Python front-end for the C++ ``sliced_wasserstein`` extension.

    Converts persistence diagrams stored as 2-column arrays into the
    extension's ``PD`` containers and computes pairwise sliced Wasserstein
    distances between files, separately for each homology dimension.
    '''

    def __init__(self):
        # Backend object providing compute_exact_SW / compute_approximate_SW.
        self.sw = sliced_wasserstein.sliced_wasserstein()

    @staticmethod
    def _persistent_rows(diagram, threshold):
        '''
        Return the row indices of ``diagram`` whose persistence
        (column 1 - column 0) is >= ``threshold``, in ascending order.
        '''
        diagram = np.asarray(diagram)
        return np.flatnonzero((diagram[:, 1] - diagram[:, 0]) >= threshold)

    @staticmethod
    def _pairwise_distances(pers, dims, dist_fn):
        '''
        Build one symmetric (file_num x file_num) float64 matrix per dimension.

        @pers: pers[f][d] is the diagram of file f in dimension d
        @dims: number of dimensions to process
        @dist_fn: callable dist_fn(a, b) returning the distance of two diagrams
        @return: list of ``dims`` independent ndarrays; diagonals stay 0
        '''
        file_num = len(pers)
        # Allocate a separate array per dimension: ``[arr] * dims`` would
        # alias one array across every dimension.
        sw_dist = [np.zeros((file_num, file_num), dtype=np.float64)
                   for _ in range(dims)]
        for i in range(dims):
            for j in range(file_num - 1):
                for k in range(j + 1, file_num):
                    dist = dist_fn(pers[j][i], pers[k][i])
                    sw_dist[i][j, k] = dist
                    sw_dist[i][k, j] = dist  # distance is stored symmetrically
                # Progress report: fraction of rows processed for this dimension.
                print("Dim ", i, " ", j/file_num)
            print("Dim ", i, " completes")
        return sw_dist

    def convert_pd_2_cpp(self, data, threshold):
        '''
        Convert persistence diagrams to ``sliced_wasserstein.PD`` objects,
        keeping only points with persistence (death - birth) >= threshold.

        @data: [[struct_num x 2] x dimensions] x file_num
               column 0 is read as birth and column 1 as death.
               The number of dimensions is taken from data[0].
        @threshold: minimum persistence a point needs to be kept
        @return: cpp_pd[file][dim] -> sliced_wasserstein.PD; [] for empty data
        '''
        if not data:
            return []
        dims = len(data[0])
        cpp_pd = []
        for file_diagrams in data:
            file_pd = []
            for dim in range(dims):
                diagram = np.asarray(file_diagrams[dim])
                keep = self._persistent_rows(diagram, threshold)
                dim_pd = sliced_wasserstein.PD(int(len(keep)))
                for cnt, row in enumerate(keep):
                    dim_pd[cnt].first = diagram[row, 0]
                    dim_pd[cnt].second = diagram[row, 1]
                file_pd.append(dim_pd)
            cpp_pd.append(file_pd)
        return cpp_pd

    def compute_sw(self, data, threshold, mode):
        '''
        Compute pairwise sliced Wasserstein distances per dimension.

        @data: [[struct_num x 2] x dimensions] x file_num
        @threshold: persistence threshold forwarded to convert_pd_2_cpp
        @mode: "exact" uses compute_exact_SW; any other value (e.g.
               "approximate") uses compute_approximate_SW
        @return: list of one symmetric (file_num x file_num) float64 matrix
                 per dimension; [] for empty data
        '''
        if not data:
            return []
        dims = len(data[0])
        pers_cpp = self.convert_pd_2_cpp(data, threshold)
        if mode == "exact":
            sw_fn = self.sw.compute_exact_SW
        else:
            sw_fn = self.sw.compute_approximate_SW
        return self._pairwise_distances(pers_cpp, dims, sw_fn)