In [None]:
import numpy as np
import gudhi
import matplotlib.pyplot as plt
import matplotlib as mpl

import scipy.spatial.distance as dist
import itertools

import iblofunmatch.inter as ibfm
# Output locations: output_dir is handed to ibfm.get_IBloFunMatch_subset further
# down, plots_dir receives the saved figures. plots_dir keeps its trailing "/"
# because later cells build file names by concatenation (plots_dir + "<name>.png").
import os

output_dir = "output"
plots_dir = "plots/example_computation/"
for _dir in (output_dir, plots_dir):
    os.makedirs(_dir, exist_ok=True)
del _dir

from scipy.sparse.csgraph import minimum_spanning_tree

In [None]:
# Sample a point cloud X with ibfm.sampled_circle (arguments 0, 2, 15 -- presumably
# 15 points around a circle; see ibfm) and draw a random 6-point subset S.
# X is then reordered so the points of S come first: afterwards S is exactly the
# first rows of X and S_indices == range(6), so a vertex index of S is also the
# index of the same point in X.
RandGen = np.random.default_rng(2)
X = ibfm.sampled_circle(0, 2, 15, RandGen)
S_indices = RandGen.choice(X.shape[0], 6, replace=False)

# S_compl masks the points outside S; boolean masking keeps their relative order.
S_compl = np.ones(X.shape[0], dtype="bool")
S_compl[S_indices] = False
X = X[np.concatenate((S_indices, np.flatnonzero(S_compl)))]
S_indices = range(len(S_indices))

# Save the reordered cloud (rounded to 4 decimals) and the indices of S.
np.savetxt("Z_example_2.txt", X, fmt="%.4f")
np.savetxt("X_idx_example_2.txt", S_indices, fmt="%d")

# Alternative inputs kept for reference (uncomment to replace X / S_indices):
# X = np.loadtxt("Z_example.txt")
# S_indices = np.loadtxt("X_idx_example.txt", dtype="int")
# X = np.array(
# [[-0.2544,  1.6085],
#  [-1.1095, -0.4577],
#  [ 0.856 ,  1.1792],
#  [ 0.199 ,  0.4839],
#  [ 0.5615,  0.2027],
#  [-0.101 , -0.1536],
#  [-1.    ,  0.5   ]])
# S_indices = list(range(5))

S = X[S_indices]

# Plot: S as filled circles drawn above the whole cloud X (crosses).
subset_color, cloud_color = (mpl.colormaps["RdBu"](t) for t in (0.3 / 1.3, 1 / 1.3))
fig, ax = plt.subplots(ncols=1, figsize=(3, 3))
ax.scatter(S[:, 0], S[:, 1], color=subset_color, s=60, marker="o", zorder=2)
ax.scatter(X[:, 0], X[:, 1], color=cloud_color, s=40, marker="x", zorder=1)
ax.set_axis_off()
plt.savefig(plots_dir + "points_0.png")

In [None]:
import itertools

def filtration_pairs(points):
    """Return the minimum-spanning-tree filtration of a point cloud and its merge pairs.

    The Euclidean MST of ``points`` is computed and its edges are processed in
    increasing length. Each edge merges two connected components. Every
    component is labelled by its smallest vertex index; the component with the
    larger label is absorbed into the one with the smaller label.

    Parameters
    ----------
    points : array_like, shape (n, d)
        Point cloud, one point per row.

    Returns
    -------
    filtration_list : list of float
        MST edge lengths in non-decreasing order.
    pairs_arr : numpy.ndarray of int, shape (m, 2)
        Row ``k`` is ``[smaller_label, larger_label]`` for the two components
        merged by the ``k``-th edge of ``filtration_list``. An empty ``(0, 2)``
        array is returned when fewer than two points are given.

    Notes
    -----
    Edges of equal length are processed in lexicographic ``(i, j)`` order, so
    the pairing does not depend on the sorting algorithm.

    SciPy treats zero entries of a dense distance matrix as missing edges.
    Coincident points are therefore never joined by a zero-length edge, so
    duplicates yield a filtration that differs from the true one (and no edges
    at all if every point coincides).
    """
    points = np.asarray(points)
    n_points = points.shape[0]
    if n_points < 2:
        # pdist/squareform would build a spurious 1x1 matrix for 0 points.
        return [], np.empty((0, 2), dtype=int)

    dist_points = dist.squareform(dist.pdist(points))
    # Read the tree edges straight from the sparse nonzeros instead of probing
    # all n^2 entries. SciPy keeps each undirected edge once, in whichever
    # orientation it retained, so normalise every edge to (smaller, larger).
    mst = minimum_spanning_tree(dist_points).tocoo()
    heads = np.minimum(mst.row, mst.col)
    tails = np.maximum(mst.row, mst.col)
    lengths = mst.data
    # Primary key: edge length; ties broken by (head, tail). lexsort is stable.
    order = np.lexsort((tails, heads, lengths))

    # labels[v] is the smallest vertex index in v's current component.
    labels = np.arange(n_points)
    merge_pairs = []
    for i, j in zip(heads[order], tails[order]):
        smaller, larger = sorted((labels[i], labels[j]))
        assert smaller < larger, "MST edge joins a component to itself"
        merge_pairs.append([smaller, larger])
        labels[labels == larger] = smaller

    pairs_arr = np.array(merge_pairs, dtype=labels.dtype).reshape(-1, 2)
    return lengths[order].tolist(), pairs_arr

In [None]:
# MST filtrations and merge pairs of the full cloud X and of the subset S.
(filt_X, pairs_X), (filt_S, pairs_S) = filtration_pairs(X), filtration_pairs(S)

In [None]:
# MST edge lengths of S in increasing order (from filtration_pairs).
filt_S

In [None]:
# [smaller, larger] component labels merged by each MST edge of S; row k matches filt_S[k].
pairs_S

In [None]:
# MST edge lengths of X in increasing order.
filt_X

In [None]:
# [smaller, larger] component labels merged by each MST edge of X; row k matches filt_X[k].
pairs_X

In [None]:
# Inspect the raw sparse MST of S -- the same object filtration_pairs starts from.
dist_points = dist.squareform(dist.pdist(S, metric="euclidean"))
mst = minimum_spanning_tree(csgraph=dist_points)
print(mst)

In [None]:
# Re-display pairs_S next to the raw MST printed in the previous cell, for comparison.
# NOTE(review): duplicates an earlier display cell; consider removing.
pairs_S

In [None]:
# Re-display filt_S (duplicate of an earlier display cell).
filt_S

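Next, we assemble the points of X (as vertices) and the MST edges of S and of X into a single PHAT boundary matrix, and compute its persistence pairs.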

In [None]:
import phat

In [None]:
# Column list for PHAT, as (dimension, boundary) tuples:
#   * one 0-dimensional column per point of X, with an empty boundary;
#   * then the MST edges of S, then the MST edges of X, each a 1-dimensional
#     column whose boundary is its two vertex indices.
# S occupies the first rows of X, so vertex indices of S are valid X indices.
# A comprehension gives each vertex its own empty list; `[(0, [])] * n` would
# make every column share one list object.
phat_input = [(0, []) for _ in range(X.shape[0])]
# Convert numpy integer entries to plain Python ints for the PHAT bindings.
phat_input.extend((1, [int(v) for v in pair]) for pair in pairs_S)
phat_input.extend((1, [int(v) for v in pair]) for pair in pairs_X)

In [None]:
# Build the PHAT boundary matrix from the column list assembled above.
boundary_matrix = phat.boundary_matrix(representation=phat.representations.vector_vector)
boundary_matrix.columns = phat_input
# Equivalently: phat.boundary_matrix(representation=..., columns=phat_input)
# creates the matrix and assigns the columns in one call.

# Print a summary of the boundary matrix.
print("\nThe boundary matrix has %d columns:" % len(boundary_matrix.columns))
for col in boundary_matrix.columns:
    s = "Column %d represents a cell of dimension %d." % (col.index, col.dimension)
    if col.boundary:
        s = s + " Its boundary consists of the cells " + " ".join(str(c) for c in col.boundary)
    # print sits outside the `if` so vertex columns (empty boundary) are listed
    # too, matching the column count announced above.
    print(s)

print("Overall, the boundary matrix has %d entries." % len(boundary_matrix))

# Reduce the matrix, then sort the resulting pairs so they print in order below.
pairs = boundary_matrix.compute_persistence_pairs()
pairs.sort()


In [None]:
# List every (birth, death) column-index pair produced by the reduction.
print("\nThere are %d persistence pairs: " % len(pairs))
for birth, death in pairs:
    print("Birth: %d, Death: %d" % (birth, death))

In [None]:
# Print every column object of the PHAT boundary matrix.
for col in boundary_matrix.columns:
    print(col)

In [None]:
# Print the preimages reported by the boundary matrix, one per line.
# NOTE(review): get_preimages() is not part of every phat build -- confirm its
# semantics in the installed bindings.
for preim in boundary_matrix.get_preimages():
    print(preim)

In [None]:
ibfm_out["S_barcode_0"][:,1]

In [None]:
filt_S

In [None]:
ibfm_out["X_barcode_0"][:,1]

In [None]:
filt_X

In [None]:
ibfm_out["induced_matching_0"]

In [None]:
ibfm_out = ibfm.get_IBloFunMatch_subset(None, X, S_indices, output_dir, num_it=4, max_rad=-1, points=True, store_0_pm=True)
fig, ax = plt.subplots(ncols=2, figsize=(5,2))
ibfm.plot_matching(ibfm_out, ax, fig, block_function=True, dim=0)
plt.savefig(plots_dir + "block_function_0.png")