In [None]:
import numpy as np
from numba import njit

from scipy.spatial.distance import squareform

from gtda.homology import VietorisRipsPersistence

from steenroder import *

import gudhi

## Uncomment to use giotto-ph
# from scipy import sparse as sp
# from gph.modules import gph_collapser

In [None]:
def cartesian_product(*arrays):
    """Return the Cartesian product of several 1-D arrays as rows.

    Parameters
    ----------
    *arrays : sequence of 1-D array-likes
        The factors of the product.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(arrays[0]) * ... * len(arrays[-1]), len(arrays))``
        whose rows enumerate every combination, with the last factor varying
        fastest. The dtype is the common result type of the inputs.
    """
    n_factors = len(arrays)
    grid_shape = tuple(len(factor) for factor in arrays)
    out = np.empty(grid_shape + (n_factors,), dtype=np.result_type(*arrays))
    # np.ix_ yields one broadcastable open-mesh array per factor; assigning it
    # to the matching trailing slot fills that coordinate across the whole grid.
    open_mesh = np.ix_(*arrays)
    for axis in range(n_factors):
        out[..., axis] = open_mesh[axis]
    return out.reshape(-1, n_factors)

In [None]:
# Number of grid points along each side of the unit square; the sampled
# point cloud therefore has num * num vertices.
num = 15
print(f"The total number of vertices is {num**2}")

In [None]:
# Fundamental domain: a num x num grid of points in [0, 1) x [0, 1).
grid_axis = np.linspace(0, 1, num=num, endpoint=False)
square = cartesian_product(grid_axis, grid_axis)
n = len(square)

# Nine stacked copies of the grid: rows c * n .. (c + 1) * n - 1 hold copy c,
# so rows j, j + n, ..., j + 8 * n are the nine images of grid point j.
# Copy 0 is the untouched fundamental domain.
squares = np.tile(square.T, 9).T

# Copies 1 and 2: vertical translations (x, y) -> (x, y +/- 1).
for copy_idx, y_shift in ((1, 1), (2, -1)):
    squares[copy_idx * n:(copy_idx + 1) * n] += [0, y_shift]

# Copies 3 and 4: shift x by +/- 1 and flip y, i.e. (x, y) -> (x +/- 1, 1 - y).
# Presumably these are the glide reflections of the flat Klein bottle.
for copy_idx, x_shift in ((3, 1), (4, -1)):
    rows = slice(copy_idx * n, (copy_idx + 1) * n)
    squares[rows] += [x_shift, 0]
    squares[rows, 1] *= -1
    squares[rows] += [0, 1]

# Copies 5-8: the two flipped copies translated vertically by +/- 1.
for copy_idx, source_idx, y_shift in ((5, 3, 1), (6, 3, -1), (7, 4, 1), (8, 4, -1)):
    source_rows = squares[source_idx * n:(source_idx + 1) * n]
    squares[copy_idx * n:(copy_idx + 1) * n] = source_rows + [0, y_shift]

In [None]:
@njit
def compute_flat_kb_db():
    """Return the condensed pairwise distance vector of the grid points.

    Reads the module-level ``n``, ``square`` and ``squares``. The distance
    between grid points ``i < j`` is the smallest Euclidean distance from
    ``square[i]`` to any of the images of point ``j`` stored in rows
    ``j, j + n, j + 2 * n, ...`` of ``squares``.

    Returns
    -------
    numpy.ndarray
        1-D float64 array of length ``n * (n - 1) // 2`` in the row-major
        upper-triangular order filled by the ``i < j`` double loop.

    Notes
    -----
    Numba is understood to freeze global arrays at compile time, so re-run
    this cell after changing ``square`` or ``squares`` (TODO confirm).
    """
    n_pairs = (n * (n - 1)) // 2
    min_sq_dists = np.empty(n_pairs, dtype=np.float64)
    pair_idx = 0
    for i in range(n):
        base_point = square[i]
        for j in range(i + 1, n):
            # Every n-th row starting at j is an image of grid point j.
            offsets = squares[j::n, :] - base_point
            min_sq_dists[pair_idx] = np.min(np.sum(offsets ** 2, axis=1))
            pair_idx += 1
    # Minimise on squared distances and take a single square root at the end.
    return np.sqrt(min_sq_dists)

In [None]:
# Expand the condensed distance vector into a full symmetric n x n matrix.
condensed_dists = compute_flat_kb_db()
dm = squareform(condensed_dists)

In [None]:
# Persistence diagram in dimensions 0-2 computed directly from the
# precomputed distance matrix; fit_transform_plot also displays the diagram.
VR = VietorisRipsPersistence(
    homology_dimensions=(0, 1, 2),
    metric="precomputed",
)
diagram = VR.fit_transform_plot([dm])[0]

In [None]:
# Plotting setup: seaborn theme plus LaTeX-rendered serif fonts in matplotlib.
import seaborn as sns

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

sns.set_theme()
# A single set_style call is enough: "ticks" assigns every seaborn style
# parameter, so an earlier set_style("whitegrid") would be fully overwritten.
sns.set_style("ticks")
sns.set_palette("bright")

# NOTE: text.usetex needs a working LaTeX installation when figures are drawn.
matplotlib.rcParams['text.usetex'] = True
matplotlib.rcParams['font.family'] = "serif"
matplotlib.rcParams['font.style'] = "normal"
matplotlib.rcParams['font.variant'] = "normal"

matplotlib.rcParams['font.serif'] = "Computer Modern Roman"

In [None]:
# Largest edge length kept when building the Rips complex; the same value is
# drawn as a vertical "thresh" line on the barcode plots.
max_edge_length = 0.37

In [None]:
## Uncomment if using giotto-ph
# row, col, data = gph_collapser.flag_complex_collapse_edges_dense(dm, thresh=max_edge_length)

# spx_tree = gudhi.simplex_tree.SimplexTree()
# for i in range(dm.shape[0]):
#     spx_tree.insert([i], 0.)
# for i, v in enumerate(data):
#     spx_tree.insert([row[i], col[i]], v)

# spx_tree.expansion(3)
# for i, _ in enumerate(spx_tree.get_filtration()):
#     pass
# print(f"There are now {i} simplices.")

In [None]:
# Build only the 1-skeleton (vertices and edges no longer than
# max_edge_length) of the Rips complex from the distance matrix.
RC = gudhi.RipsComplex(distance_matrix=dm.astype(np.float32), max_edge_length=max_edge_length)
spx_tree = RC.create_simplex_tree(max_dimension=1)

# Collapse edges on the 1-skeleton first, then expand the remaining graph
# to simplices of dimension up to 3.
spx_tree.collapse_edges(nb_iterations=1)

spx_tree.expansion(3)
# Ask the tree for its size directly: the last index of an enumerate() over
# the filtration is one less than the number of simplices.
print(f"There are now {spx_tree.num_simplices()} simplices.")

In [None]:
# Materialise the filtration once, then split each (simplex, value) pair into
# a list of vertex tuples and a parallel array of filtration values.
simplices_with_values = list(spx_tree.get_filtration())
filtration = [tuple(simplex) for simplex, _ in simplices_with_values]
filtration_values = np.asarray([value for _, value in simplices_with_values])

In [None]:
# Index of the Steenrod square to compute; the plot labels further down are
# written for Sq^1, so they assume k = 1.
k = 1

# `barcodes` is provided by the star import of steenroder. Judging by the plot
# titles used later, the first result is the persistent relative cohomology
# barcode and the second the Steenrod barcode (one array per dimension) --
# confirm against the steenroder documentation.
barcode, st_barcodes = barcodes(
    k,
    filtration,
    filtration_values=filtration_values,
    return_filtration_values=True,
    verbose=True
)

In [None]:
# Display the raw Steenrod barcodes returned by the previous cell.
st_barcodes

In [None]:
from collections import Counter


def _draw_barcode(ax, diagrams, bar_colors, bar_labels, linestyle,
                  left_edge, min_length=None):
    """Draw one stacked barcode panel on ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    diagrams : sequence of ndarray
        One array per dimension; each row is a ``(death, birth)`` pair, where
        ``death`` may be ``-inf`` for a bar that never dies.
    bar_colors, bar_labels : sequence of str
        Colour and legend label for each dimension, indexed like ``diagrams``.
    linestyle : str
        Matplotlib line style used for every bar of the panel.
    left_edge : float
        x coordinate where infinite bars are cut off and marked with an arrow.
    min_length : float, optional
        If given, only bars with ``birth - death > min_length`` are drawn.

    Returns
    -------
    float
        y coordinate of the last bar drawn, or 0 if the panel is empty.
    """
    next_row = 0
    last_y = 0
    for dim, dgm in enumerate(diagrams):
        if min_length is not None:
            dgm = dgm[dgm[:, 1] - dgm[:, 0] > min_length]
        # Identical bars are drawn once and annotated with their multiplicity;
        # Counter keeps first-seen order, so rows follow the diagram order.
        multiplicities = Counter(tuple(bar) for bar in dgm)

        segs = []
        for offset, ((death, birth), mult) in enumerate(multiplicities.items()):
            last_y = -(next_row + offset)
            if death == -np.inf:
                # Infinite bar: clip at the left edge and add an arrow head.
                ax.arrow(left_edge, last_y, -0.0000001, 0, head_starts_at_zero=False, width=0, head_width=0.3, head_length=0.005, color=bar_colors[dim], ec=bar_colors[dim])
                death = left_edge
            segs.append([[birth, last_y], [death, last_y]])
            if mult > 1:
                ax.annotate(f"{mult}", (death, last_y + 0.2))

        if segs:
            line_segments = LineCollection(np.array(segs, dtype=np.float64),
                                           linewidths=2,
                                           colors=bar_colors[dim],
                                           label=bar_labels[dim],
                                           linestyle=linestyle)
            ax.add_collection(line_segments)

        # Leave a gap of two empty rows between consecutive dimensions.
        next_row += len(multiplicities) + 2
    return last_y


n_dims = len(barcode)

# Steenrod bars no longer than thresh are hidden; eps is the margin left of
# the smallest filtration value at which infinite bars are clipped.
thresh = 0.1
eps = 0.01
min_filtration_value = np.min(filtration_values)

fig, (ax_rel_coho, ax_st) = plt.subplots(2, 1,
                                         figsize=(16, 8),
                                         sharex='col',
                                         gridspec_kw={'height_ratios': [2, 1]},
                                         tight_layout=True)

colors = ["Orange", "Green", "Blue", "Red"]
labels_rel_coho = [r"$H^0_R$",
                   r"$H^1_R$",
                   r"$H^2_R$",
                   r"$H^3_R$"]
labels_st = [r"$\mathrm{img}(Sq^1) \cap H^0_R$",
             r"$\mathrm{img}(Sq^1) \cap H^1_R$",
             r"$\mathrm{img}(Sq^1) \cap H^2_R$",
             r"$\mathrm{img}(Sq^1) \cap H^3_R$"]

# Top panel: every bar of the relative cohomology barcode (no length filter).
rel_coho_last_y = _draw_barcode(ax_rel_coho,
                                [barcode[dim] for dim in range(n_dims)],
                                colors, labels_rel_coho, "solid",
                                left_edge=min_filtration_value - eps)

ax_rel_coho.axvline(x=max_edge_length, color="gray", alpha=0.3)
ax_rel_coho.text(max_edge_length, rel_coho_last_y, rf"thresh = {max_edge_length}", rotation=90, fontdict={"fontsize": 15})

ax_rel_coho.autoscale()
ax_rel_coho.get_yaxis().set_visible(False)
ax_rel_coho.legend(loc="upper right", fontsize=18)
ax_rel_coho.set_title("Persistent relative cohomology barcode", fontdict={"fontsize": 22}, pad=15)

# Bottom panel: Steenrod bars longer than thresh. The label position comes
# from this panel's own rows, never from whatever the top panel drew last.
st_last_y = _draw_barcode(ax_st,
                          [st_barcodes[dim] for dim in range(n_dims)],
                          colors, labels_st, "dashed",
                          left_edge=min_filtration_value - eps,
                          min_length=thresh)

ax_st.axvline(x=max_edge_length, color="gray", alpha=0.3)
ax_st.text(max_edge_length, st_last_y - 0.26, rf"thresh = {max_edge_length}", rotation=90, fontdict={"fontsize": 15})

ax_st.tick_params(axis="x", labelsize=18)

ax_st.autoscale()
ax_st.get_yaxis().set_visible(False)
ax_st.legend(loc="upper right", fontsize=18)
ax_st.set_title("Steenrod barcode", fontdict={"fontsize": 22}, pad=15)

fig.savefig("flat_Klein_bottle.pdf")