# NumPy version

In [34]:
import numpy as np

import time
from itertools import combinations, groupby
from bisect import bisect
from operator import attrgetter

from dataclasses import dataclass
from dataclasses import astuple, asdict

from typing import Tuple, Dict, List

@dataclass
class Simplex:
    vertices: Tuple[int]
    index: int = None
    time: float = None
    weight: float = None

    def __repr__(self):
        return "({})".format(", ".join(map(str, self.vertices)))

    @property
    def dim(self):
        return len(self.vertices) - 1

    @property
    def boundary(self):
        if self.dim==0:
            faces = []
        else:
            faces = [Simplex(item) for item in combinations(self.vertices, self.dim)][::-1]
        return faces

@dataclass
class PersistenceRepresentative:
    birth_simplex: Simplex
    death_simplex: Simplex

@dataclass
class PersistenceDiagram:
    elements: List[PersistenceRepresentative]

    def num_representatives(self, dim=0):
        n_representatives = {0: 0, 1: 0}

        for representative in self.elements:
            representative_dim = representative.birth_simplex.dim
            n_representatives[representative_dim] = n_representatives[representative_dim] + 1

        return n_representatives[dim]

    def representatives_graded(self, k=0):

        representatives_graded = {}

        representatives = sorted(self.elements, key=lambda element: (element.birth_simplex.dim)) # , element.birth_simplex.index, element.death_simplex.index

        for k_repr, k_representatives in groupby(representatives, key=lambda representative: representative.birth_simplex.dim):
            k_representatives = list(k_representatives)
            representatives_graded[k_repr] = k_representatives

        return representatives_graded[k]

    def as_numpy(self, index=False):
        pd = np.zeros((len(self.elements), 3))
        
        sorted_elements = sorted(self.elements, key=lambda element: (element.birth_simplex.dim, element.birth_simplex.index, element.death_simplex.index))

        for i, element in enumerate(sorted_elements):
            if index==False:
                pd[i,:] = np.array([element.birth_simplex.dim, element.birth_simplex.time, element.death_simplex.time])
            else:
                pd[i,:] = np.array([element.birth_simplex.dim, element.birth_simplex.index, element.death_simplex.index])

        return pd#.astype(int)

class FilteredComplex:

    def __init__(self, filtration: List[Simplex], oriented=False):
        self.filtration = filtration
        self.oriented = oriented
        self.boundary_matrix = None
        self.reduced_boundary_matrix = None
        self.persistence_diagram = None

        self.simplex_to_index = {}
        for simplex in self.filtration:
            self.simplex_to_index[simplex.vertices] = simplex.index

        n_simplices = len(self.filtration)
        self.boundary_matrix = np.zeros((n_simplices, n_simplices), dtype=int)

        # building boundary matrix
        for simplex in self.filtration:
            for q, face in enumerate(simplex.boundary):
                i, j = self.simplex_to_index[face.vertices], simplex.index
                self.boundary_matrix[i,j] = 1

    def get_reduced_boundary_matrix(self):
        
        def matrix_reduction(matrix: np.ndarray) -> np.ndarray:
            
            def low(column: np.ndarray) -> int:
                if np.any(column!=0):
                    return np.flatnonzero(column)[-1] 
                return -1

            def reduceable(matrix, j, lows, pivots):
                is_reduceable = False
                if lows[j]!=-1 and pivots[lows[j]]!=-1:
                    is_reduceable = pivots[lows[j]]<j
                return is_reduceable
            
            t0 = time.time()
            # set lows and pivots
            lows = [low(column) for column in matrix.T]
            
            pivots = np.ones(matrix.shape[0]).astype(int) * -1
            for i in range(matrix.shape[0]):
                for j in range(i+1, matrix.shape[0]):
                    if (matrix[i,j]!=0 and lows[j]==i):
                        pivots[i] = j
                        break
            # print(time.time() - t0)
            
            t0 = time.time()
            pivots = list(pivots)
            for i in range(0, matrix.shape[1]):
                while reduceable(matrix, i, lows, pivots):
                    j = pivots[lows[i]]
                    matrix[:,i] = (matrix[:,j] + matrix[:,i]) % 2
                    lows[i] = low(matrix[:,i]) # update lows
                
                if lows[i]!=-1:
                    pivots[lows[i]] = i; # update pivots
            # print(time.time() - t0)
                    
            return matrix

        if (self.reduced_boundary_matrix is None): # cached
            self.reduced_boundary_matrix = matrix_reduction(self.boundary_matrix)
            # self.persistence_diagram = self.get_persistence_diagram()

        return self.reduced_boundary_matrix

    def view_boundary_matrix(self, index=None, order=1):
        
        self.simplices_at_index = {}
        self.simplices_index_idx = {}

        filtration_index = sorted(self.filtration[:index+1], key=lambda simplex: (len(simplex.vertices), simplex.index))
        for k, k_simplices in groupby(filtration_index, key=lambda simplex: len(simplex.vertices)):
            k_simplices = list(k_simplices)
            self.simplices_at_index[k-1] = k_simplices
            self.simplices_index_idx[k-1] = [simplex.index for simplex in k_simplices]

        if order==1:
            B = self.oriented_boundary_matrix[self.simplices_index_idx[0],:][:,self.simplices_index_idx[1]]
        elif order==2:
            B = self.oriented_boundary_matrix[self.simplices_index_idx[1],:][:,self.simplices_index_idx[2]]

        return B

    def get_persistence_diagram(self):
        def low(column):
            column = (column!=0).astype(int)
            argwhere = np.argwhere(column)
            if argwhere.shape[0]==0:
                lowest = -1
            else:
                lowest = argwhere[-1,0]
            return lowest

        persistence_representatives = []
        for j in range(len(self.filtration)):
            i_low = low(self.reduced_boundary_matrix[:,j])
            if i_low!=-1:
                birth_simplex, death_simplex = self.filtration[i_low], self.filtration[j]
                if (death_simplex.index - birth_simplex.index) > 1:
                    persistence_representative = PersistenceRepresentative(birth_simplex, death_simplex)
                    persistence_representatives.append(persistence_representative)

        return PersistenceDiagram(persistence_representatives)

    @property
    def harmonic_persistence_diagram(self):
        pass

class IndexFiltration:
    
    def __init__(self, cmplx):
        self.cmplx = cmplx

    def __call__(self, identity=False):
        
        if identity==False:
            filtered_cmplx = sorted(self.cmplx, key=lambda simplex: (simplex.index, simplex.vertices))
        else: # if identity - set index and time as they passed
            filtered_cmplx = self.cmplx
            for i, simplex in enumerate(filtered_cmplx):
                simplex.index = i

        for simplex in filtered_cmplx:
            simplex.time = simplex.index

        return FilteredComplex(filtered_cmplx)

class VietorisRipsFiltration:
    
    def __init__(self, X, distance_matrix=False):
        def pairwise_distances(X):
            return np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)

        if (distance_matrix):
            self.X = X
        else:
            self.X = pairwise_distances(X)

        self.n_vertices = X.shape[0]

    def __call__(self):
        def f(simplex):
            if simplex.dim==0:
                f = 0
            elif simplex.dim==1:
                i, j = simplex.vertices
                f = self.X[i,j]
            else:
                i, j, k = simplex.vertices
                f = max([self.X[i,j], self.X[i,k], self.X[j,k]])
            return f

        # TODO: refactor
        vertices = [Simplex(item) for item in combinations(range(self.n_vertices), 1)]
        edges = [Simplex(item) for item in combinations(range(self.n_vertices), 2)]
        triangles = [Simplex(item) for item in combinations(range(self.n_vertices), 3)]
        cmplx = [item for lst in [vertices, edges, triangles] for item in lst]

        for simplex in cmplx:
            simplex.time = f(simplex)

        filtered_cmplx = sorted(cmplx, key=lambda simplex: (simplex.time, simplex.dim, simplex.vertices))

        for i, simplex in enumerate(filtered_cmplx):
            simplex.index = i

        return FilteredComplex(filtered_cmplx)

In [35]:
cloud1 = np.random.randn(20, 10)
cloud2 = np.random.randn(20, 10)

In [36]:
import time
t0 = time.time()
barc = VietorisRipsFiltration(cloud1)()
print(time.time() - t0)

0.008058786392211914


In [37]:
t0 = time.time()
barc.get_reduced_boundary_matrix()
print(time.time() - t0)

0.24260950088500977


In [38]:
t0 = time.time()
barc.get_persistence_diagram().as_numpy()
print(time.time() - t0)

0.012668371200561523


In [40]:
from ripser import ripser

In [41]:
ripser(cloud1)['dgms']

[array([[0.        , 1.97354317],
        [0.        , 2.00263047],
        [0.        , 2.04324102],
        [0.        , 2.11252642],
        [0.        , 2.12036419],
        [0.        , 2.29703283],
        [0.        , 2.38572741],
        [0.        , 2.38801169],
        [0.        , 2.40740585],
        [0.        , 2.45936656],
        [0.        , 2.72796679],
        [0.        , 2.73919749],
        [0.        , 2.7594595 ],
        [0.        , 2.85893655],
        [0.        , 2.94371986],
        [0.        , 3.22633338],
        [0.        , 3.31528974],
        [0.        , 3.41085625],
        [0.        , 3.55512404],
        [0.        ,        inf]]),
 array([[3.36920309, 3.43974566],
        [2.9860785 , 3.29701471],
        [2.83171391, 2.86610675],
        [2.55396271, 2.67745852],
        [2.48277783, 2.77688313]])]

The python code is correct

In [42]:
%%timeit

barc = VietorisRipsFiltration(cloud1)()
barc.get_reduced_boundary_matrix()
barc.get_persistence_diagram().as_numpy()

257 ms ± 6.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# CuPy version

In [43]:
import numpy as np
import cupy as cp
import cupyx as cpx
from cupyx.scipy.sparse import csr_matrix

from itertools import combinations, groupby
from bisect import bisect
from operator import attrgetter

from dataclasses import dataclass
from dataclasses import astuple, asdict

from typing import Tuple, Dict, List


# class SparseMatrix:

#     def __init__(self, shape, dtype):
#         self.row_matrix = dict()
#         self.col_matrix = dict()
#         self.transposed = False
#         assert len(shape) == 2 and isinstance(shape[0], int) and isinstance(shape[1], int)
#         self.shape = tuple(shape)
#         self.dtype = dtype

#     def __getitem__(self, ind):

#         i, j = ind
#         row, col = (i, j) if not self.transposed else (j, i)
#         assert row < self.shape[0] and col < self.shape[1]
#         if isinstance(row, int) and isinstance(col, int):
#             if row in self.row_matrix and col in self.row_matrix[row]:
#                 return self.row_matrix[row][col]
#             else:
#                 return 0
#         elif isinstance(row, slice) and isinstance(col, int):
#             if col in self.col_matrix:
#                 arr = cp.zeros((row.end - row.start,))
#                 for r in range(row.start, row.end):
#                     for c, v in self.row_matrix.items():
#                         if c == col: arr[r - row.start] = v
#                 return arr
#             else:
#                 return cp.zeros((row.end - row.start,))
#         elif isinstance(row, int) and isinstance(col, slice):
#             if row in self.row_matrix:
#                 arr = cp.zeros((col.end - col.start,))
#                 for c in range(col.start, col.end):
#                     for r, v in self.col_matrix.items():
#                         if r == row: arr[c - col.start] = v
#                 return arr
#             else:
#                 return cp.zeros((col.end - col.start,))
#         elif isinstance(row, slice) and isinstance(col, slice):
#             raise Exception(f'Not implemented yet')
#         else:
#             raise Exception(f'Index type must be int or slice, not ({i.type, j.type})')

#     def __setitem__(self, ind, value):
#         i, j = ind
#         row, col = (i, j) if not self.transposed else (j, i)
#         assert row < self.shape[0] and col < self.shape[1]
#         assert isinstance(value, self.dtype)
#         if isinstance(row, int) and isinstance(col, int):
#             if not row in self.row_matrix:
#                 self.row_matrix[row] = dict()
#             if not col in self.col_matrix:
#                 self.col_matrix[col] = dict()
#             self.row_matrix[row][col] = value
#             self.col_matrix[col][row] = value
#         elif isinstance(row, slice) and isinstance(col, int):
#             assert len(value) == self.shape[0]
#             if not col in self.col_matrix:
#                 self.col_matrix[col] = dict()
#             for i in range(row.start, row.end):
#                 v = value[i - row.start]
#                 if v != 0:
#                     self.col_matrix[col][i] = v
#                     if i not in self.row_matrix:
#                         self.row_matrix[i] = dict()
#                     self.row_matrix[i][col] = v
#         elif isinstance(row, int) and isinstance(col, slice):
#             assert len(value) == self.shape[1]
#             if not row in self.row_matrix:
#                 self.row_matrix[row] = dict()
#             for i in range(row.start, row.end):
#                 v = value[i - row.start]
#                 if v != 0:
#                     self.row_matrix[row][i] = v
#                     if i not in self.col_matrix:
#                         self.col_matrix[i] = dict()
#                     self.col_matrix[i][row] = v
#         elif isinstance(row, slice) and isinstance(col, slice):
#             pass
#         else:
#             raise Exception(f'Index type must be int or slice, not ({i.type, j.type})')

#     def transpose(self):
#         self.transposed = not self.transposed

@dataclass
class Simplex:
    vertices: Tuple[int]
    index: int = None
    time: float = None
    weight: float = None

    def __repr__(self):
        return "({})".format(", ".join(map(str, self.vertices)))

    @property
    def dim(self):
        return len(self.vertices) - 1

    @property
    def boundary(self):
        if self.dim==0:
            faces = []
        else:
            faces = [Simplex(item) for item in combinations(self.vertices, self.dim)][::-1]
        return faces

@dataclass
class PersistenceRepresentativeCupy:
    birth_simplex: Simplex
    death_simplex: Simplex

@dataclass
class PersistenceDiagramCupy:
    elements: List[PersistenceRepresentativeCupy]

    def num_representatives(self, dim=0):
        n_representatives = {0: 0, 1: 0}

        for representative in self.elements:
            representative_dim = representative.birth_simplex.dim
            n_representatives[representative_dim] = n_representatives[representative_dim] + 1

        return n_representatives[dim]

    def representatives_graded(self, k=0):

        representatives_graded = {}

        representatives = sorted(self.elements, key=lambda element: (element.birth_simplex.dim)) # , element.birth_simplex.index, element.death_simplex.index

        for k_repr, k_representatives in groupby(representatives, key=lambda representative: representative.birth_simplex.dim):
            k_representatives = list(k_representatives)
            representatives_graded[k_repr] = k_representatives

        return representatives_graded[k]

    def as_numpy(self, index=False):
        pd = np.zeros((len(self.elements), 3))
        
        sorted_elements = sorted(self.elements, key=lambda element: (element.birth_simplex.dim, element.birth_simplex.index, element.death_simplex.index))

        for i, element in enumerate(sorted_elements):
            if index==False:
                pd[i,:] = np.array([element.birth_simplex.dim, element.birth_simplex.time, element.death_simplex.time])
            else:
                pd[i,:] = np.array([element.birth_simplex.dim, element.birth_simplex.index, element.death_simplex.index])

        return pd#.astype(int)

class FilteredComplexCupy:

    def __init__(self, filtration: List[Simplex], oriented=False):
        self.filtration = filtration
        self.oriented = oriented
        self.boundary_matrix = None
        self.reduced_boundary_matrix = None
        self.persistence_diagram = None

        self.simplex_to_index = {}
        for simplex in self.filtration:
            self.simplex_to_index[simplex.vertices] = simplex.index

        n_simplices = len(self.filtration)
        self.boundary_matrix = cp.zeros((n_simplices, n_simplices), dtype=int)

        # building boundary matrix
        for simplex in self.filtration:
            for q, face in enumerate(simplex.boundary):
                i, j = self.simplex_to_index[face.vertices], simplex.index
                self.boundary_matrix[i,j] = 1

    def get_reduced_boundary_matrix(self):
        
        def matrix_reduction(matrix):
            
            def low(column):
                nz = np.flatnonzero(column)
                if len(nz) > 0:
                    return int(nz[-1])
                return -1

            def reduceable(matrix, j, lows, pivots):
                is_reduceable = False
                if lows[j] != -1 and pivots[int(lows[j])] != -1:
                    is_reduceable = pivots[int(lows[j])] < j
                return is_reduceable
            
            # set lows and pivots
            lows = cp.array([low(column) for column in matrix.T])
            
            pivots = cp.ones(matrix.shape[0]) * -1
            mat_inds = (matrix != 0).astype(bool) & (lows.reshape((1, -1)) == cp.arange(len(lows)).reshape((-1, 1)))
            mat_inds = cp.flatnonzero(mat_inds)
            row = mat_inds // matrix.shape[0]
            col = mat_inds % matrix.shape[0]
            pivots[row] = col
            del mat_inds
            
            pivots = pivots.get()
            matrix = matrix.get()
            lows = lows.get()
            
            for i in range(0, matrix.shape[1]):
                while reduceable(matrix, i, lows, pivots):
                    j = int(pivots[lows[i]])
                    matrix[:,i] = (matrix[:,j] + matrix[:,i]) % 2
                    lows[i] = low(matrix[:,i]) # update lows

                if lows[i] != -1:
                    pivots[int(lows[i])] = i; # update pivots
                    
            return matrix

        if (self.reduced_boundary_matrix is None): # cached
            self.reduced_boundary_matrix = matrix_reduction(self.boundary_matrix)
            # self.persistence_diagram = self.get_persistence_diagram()

        return self.reduced_boundary_matrix

    def view_boundary_matrix(self, index=None, order=1):
        
        self.simplices_at_index = {}
        self.simplices_index_idx = {}

        filtration_index = sorted(self.filtration[:index+1], key=lambda simplex: (len(simplex.vertices), simplex.index))
        for k, k_simplices in groupby(filtration_index, key=lambda simplex: len(simplex.vertices)):
            k_simplices = list(k_simplices)
            self.simplices_at_index[k-1] = k_simplices
            self.simplices_index_idx[k-1] = [simplex.index for simplex in k_simplices]

        if order == 1:
            B = self.oriented_boundary_matrix[self.simplices_index_idx[0],:][:,self.simplices_index_idx[1]]
        elif order == 2:
            B = self.oriented_boundary_matrix[self.simplices_index_idx[1],:][:,self.simplices_index_idx[2]]

        return B

    def get_persistence_diagram(self):
        def low(column):
            column = (column!=0).astype(int)
            argwhere = np.argwhere(column)
            if argwhere.shape[0]==0:
                lowest = -1
            else:
                lowest = argwhere[-1,0]
            return int(lowest)

        persistence_representatives = []
        for j in range(len(self.filtration)):
            i_low = low(self.reduced_boundary_matrix[:,j])
            if i_low!=-1:
                birth_simplex, death_simplex = self.filtration[i_low], self.filtration[j]
                if (death_simplex.index - birth_simplex.index) > 1:
                    persistence_representative = PersistenceRepresentativeCupy(birth_simplex, death_simplex)
                    persistence_representatives.append(persistence_representative)

        return PersistenceDiagramCupy(persistence_representatives)

    @property
    def harmonic_persistence_diagram(self):
        pass

class IndexFiltration:
    
    def __init__(self, cmplx):
        self.cmplx = cmplx

    def __call__(self, identity=False):
        
        if identity==False:
            filtered_cmplx = sorted(self.cmplx, key=lambda simplex: (simplex.index, simplex.vertices))
        else: # if identity - set index and time as they passed
            filtered_cmplx = self.cmplx
            for i, simplex in enumerate(filtered_cmplx):
                simplex.index = i

        for simplex in filtered_cmplx:
            simplex.time = simplex.index

        return FilteredComplexCupy(filtered_cmplx)

class VietorisRipsFiltrationCupy:
    
    def __init__(self, X, distance_matrix=False):
        def pairwise_distances(X):
            return np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)

        if (distance_matrix):
            self.X = X
        else:
            self.X = pairwise_distances(X)

        self.n_vertices = X.shape[0]

    def __call__(self):
        def f(simplex):
            if simplex.dim==0:
                f = 0
            elif simplex.dim==1:
                i, j = simplex.vertices
                f = self.X[i,j]
            else:
                i, j, k = simplex.vertices
                f = max([self.X[i,j], self.X[i,k], self.X[j,k]])
            return f

        # TODO: refactor
        vertices = [Simplex(item) for item in combinations(range(self.n_vertices), 1)]
        edges = [Simplex(item) for item in combinations(range(self.n_vertices), 2)]
        triangles = [Simplex(item) for item in combinations(range(self.n_vertices), 3)]
        cmplx = [item for lst in [vertices, edges, triangles] for item in lst]

        for simplex in cmplx:
            simplex.time = f(simplex)

        filtered_cmplx = sorted(cmplx, key=lambda simplex: (simplex.time, simplex.dim, simplex.vertices))

        for i, simplex in enumerate(filtered_cmplx):
            simplex.index = i

        return FilteredComplexCupy(filtered_cmplx)

In [57]:
# %%timeit
barc = VietorisRipsFiltrationCupy(cloud1)()
barc.get_reduced_boundary_matrix()
barc.get_persistence_diagram().as_numpy()

array([[0.        , 0.        , 2.98448869],
       [0.        , 0.        , 2.93823717],
       [0.        , 0.        , 3.05616011],
       [0.        , 0.        , 2.70809174],
       [0.        , 0.        , 3.4840921 ],
       [0.        , 0.        , 2.48254916],
       [0.        , 0.        , 2.31644014],
       [0.        , 0.        , 3.19989339],
       [0.        , 0.        , 3.2685538 ],
       [0.        , 0.        , 2.95687037],
       [0.        , 0.        , 4.02932574],
       [0.        , 0.        , 2.43861846],
       [0.        , 0.        , 2.49536529],
       [0.        , 0.        , 2.43521019],
       [0.        , 0.        , 3.13172491],
       [0.        , 0.        , 3.14116846],
       [0.        , 0.        , 3.79352935],
       [0.        , 0.        , 2.63489257],
       [0.        , 0.        , 2.77147828],
       [1.        , 3.13643258, 3.34235952],
       [1.        , 3.1517784 , 3.46692956],
       [1.        , 3.19880941, 3.48303815],
       [1.

In [44]:
%%timeit

barc = VietorisRipsFiltrationCupy(cloud1)()
barc.get_reduced_boundary_matrix()
barc.get_persistence_diagram().as_numpy()

337 ms ± 6.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
