In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from numba import njit

from scipy.spatial.distance import squareform

from gtda.homology import VietorisRipsPersistence

from steenroder import *

import gudhi

In [None]:
def cartesian_product(*arrays):
    """Cartesian product of 1-D arrays as a 2-D array of rows.

    Row ``r`` of the result is one combination ``(a0[i0], a1[i1], ...)``; the
    last input varies fastest. All inputs are cast to their common NumPy
    result type.

    Parameters
    ----------
    *arrays : 1-D array_like
        The factors of the product.

    Returns
    -------
    numpy.ndarray of shape (prod(len(a) for a in arrays), len(arrays))
    """
    n_factors = len(arrays)
    common_dtype = np.result_type(*arrays)
    grid_shape = tuple(len(factor) for factor in arrays) + (n_factors,)
    grid = np.empty(grid_shape, dtype=common_dtype)
    # np.ix_ yields open-mesh views that broadcast against the grid's leading axes.
    open_mesh = np.ix_(*arrays)
    for axis in range(n_factors):
        grid[..., axis] = open_mesh[axis]
    return grid.reshape(-1, n_factors)

In [None]:
# Grid resolution: number of sample points along each side of the unit square,
# so the grid has num * num vertices in total.
num = 15
print("The total number of vertices is", num * num)

In [None]:
square = cartesian_product(np.linspace(0, 1, num=num), np.linspace(0, 1, num=num))
n = len(square)

# The copies below realise the gluing of the unit square
#   (x, y) ~ (x, y + 1)        plain vertical translation
#   (x, y) ~ (x + 1, 1 - y)    horizontal translation combined with a vertical flip
# (the Klein bottle). `squares` stacks nine images of `square` covering the 3x3
# block of tiles around [0, 1]^2, in this order:
#   0 identity, 1 up, 2 down, 3 flipped-right, 4 flipped-left,
#   5 flipped-right + up, 6 flipped-right + down,
#   7 flipped-left + up, 8 flipped-left + down.
# Row j + m * n of `squares` is therefore the m-th image of grid point j.
# NOTE(review): linspace includes both endpoints, so boundary grid points coincide
# with images of other grid points under this gluing (zero pairwise distance).
_flip_y = square * [1, -1]
_right = _flip_y + [1, 1]
_left = _flip_y + [-1, 1]
squares = np.concatenate([
    square,
    square + [0, 1],
    square + [0, -1],
    _right,
    _left,
    _right + [0, 1],
    _right + [0, -1],
    _left + [0, 1],
    _left + [0, -1],
])
del _flip_y, _right, _left

In [None]:
@njit
def _flat_kb_db_kernel(base_points, tiled_points):
    """Numba kernel: condensed min-over-images distances (see compute_flat_kb_db)."""
    n_pts = base_points.shape[0]
    dm_condensed = np.empty((n_pts * (n_pts - 1)) // 2, dtype=np.float64)
    k = 0
    for i in range(n_pts):
        x = base_points[i]
        for j in range(i + 1, n_pts):
            # Rows j, j + n, j + 2n, ... are the images of point j; keep the closest.
            sq_dists = np.sum((tiled_points[j::n_pts, :] - x) ** 2, axis=1)
            dm_condensed[k] = np.min(sq_dists)
            k += 1
    # Square root taken once at the end (the min of squared distances gives the same argmin).
    return np.sqrt(dm_condensed)


def compute_flat_kb_db(square_points=None, tiled_points=None):
    """Pairwise distances between grid points under the stacked-images tiling.

    For every pair ``i < j`` the distance is the minimum Euclidean distance from
    ``square_points[i]`` to any image of point ``j`` in ``tiled_points`` (rows
    ``j, j + n, j + 2n, ...``).

    The arrays are passed to the compiled kernel explicitly instead of being
    read as globals inside ``@njit`` code. Numba freezes global arrays at
    compile time, so the old version silently kept using stale data once
    ``square`` / ``squares`` were redefined in the notebook.

    Parameters
    ----------
    square_points : array of shape (n, d), optional
        Points in the fundamental domain. Defaults to the notebook-level
        ``square``, looked up when the function is called.
    tiled_points : array of shape (m * n, d), optional
        Stacked images of ``square_points``. Defaults to the notebook-level
        ``squares``, looked up when the function is called.

    Returns
    -------
    numpy.ndarray of shape (n * (n - 1) // 2,)
        Condensed distance vector (pairs ``i < j`` in row-major order), the
        layout expected by ``scipy.spatial.distance.squareform``.

    Raises
    ------
    ValueError
        If the row count of ``tiled_points`` is not a multiple of ``n``.
    """
    if square_points is None:
        square_points = square
    if tiled_points is None:
        tiled_points = squares
    base = np.ascontiguousarray(square_points, dtype=np.float64)
    tiled = np.ascontiguousarray(tiled_points, dtype=np.float64)
    n_pts = base.shape[0]
    if n_pts > 0 and tiled.shape[0] % n_pts != 0:
        raise ValueError(
            f"tiled_points has {tiled.shape[0]} rows, not a multiple of n={n_pts}"
        )
    return _flat_kb_db_kernel(base, tiled)

In [None]:
# Vietoris-Rips persistence in dimensions 0, 1 and 2 on a precomputed distance
# matrix; `squareform` expands the condensed vector into a square matrix, and the
# single matrix is wrapped in a list as a one-sample collection.
VR = VietorisRipsPersistence(
    metric="precomputed",
    homology_dimensions=(0, 1, 2),
)
VR.fit_transform_plot([squareform(compute_flat_kb_db())]);

In [None]:
# Edge-length threshold for the Rips complex (kept as a named constant so it is
# easy to tune).
MAX_EDGE_LENGTH = 0.37
RC = gudhi.RipsComplex(
    distance_matrix=squareform(compute_flat_kb_db()),
    max_edge_length=MAX_EDGE_LENGTH,
)

In [None]:
# Build only the 1-skeleton (max_dimension=1: vertices and edges) of the Rips
# complex; higher-dimensional simplices are added later via `expansion`, after
# the edge collapse.
spx_tree = RC.create_simplex_tree(max_dimension=1)

In [None]:
# Collapse edges of the 1-skeleton (one iteration), then expand the result to
# cliques of dimension up to 3.
spx_tree.collapse_edges(nb_iterations=1)
spx_tree.expansion(3)
# Ask the tree for its size directly. The previous version printed the last
# 0-based `enumerate` index, which under-reports by one and raises NameError on
# an empty tree.
num_simplices = spx_tree.num_simplices()
print(f"There are now {num_simplices} simplices.")

In [None]:
# Split gudhi's (simplex, filtration value) pairs into a list of vertex tuples and
# a NumPy array of filtration values, both in `get_filtration` order.
simplices_with_values = list(spx_tree.get_filtration())
filtration = [tuple(simplex) for simplex, _ in simplices_with_values]
filtration_values = np.asarray([value for _, value in simplices_with_values])
del simplices_with_values

In [None]:
# NOTE(review): the leading `1` is passed positionally to steenroder's `barcodes`;
# presumably the Steenrod square index — confirm against steenroder's docs.
barcode, st_barcodes = barcodes(
    1,
    filtration,
    filtration_values=filtration_values,
    return_filtration_values=True,
    homology=True,
)

In [None]:
# Display the second value returned by the `barcodes` call above; it is left as the
# cell's last expression so Jupyter renders it.
st_barcodes