# Block Function Simple Example

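In this notebook, we work through a simple illustrative example of the induced block function: we sample a point cloud $Z$, take two subsets $X, Y \subseteq Z$, and compute the block functions that match the barcodes of $X$ and of $Y$ to the barcode of $Z$.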

In [None]:
import os
import numpy as np
from numpy.random import default_rng
import matplotlib as mpl
import matplotlib.pyplot as plt

import gudhi

import IBloFunMatch_inter as ibfm

# Create the figure output folder up front so later cells can save into it.
os.makedirs("plots/class_match/illustrative_example", exist_ok=True)

# Directory used to exchange files with the C++ program.
output_dir = "output"

# Numerical tolerance. NOTE(review): not referenced by the cells below — confirm whether still needed.
_tol = 1e-10

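Next, we generate the data: a point cloud $Z$ built from four sampled circles, together with two subsets $X, Y \subseteq Z$. $X$ drops one of the circles, and $Y$ removes the points of $Z$ inside a small disk.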

In [None]:
# Fixed seed so the sampled clouds (and hence all later indices) are reproducible.
rng = default_rng(2)
# Four circle samples; arguments presumably (r_min, r_max, n_points, rng) — TODO confirm
# against ibfm.sampled_circle. C3 and C4 are shifted 5 units along the x-axis.
C1 = ibfm.sampled_circle(5, 10, 40, rng)
C2 = ibfm.sampled_circle(1, 1.5, 20, rng)
C3 = ibfm.sampled_circle(5, 10, 20, rng) + [5, 0]
C4 = ibfm.sampled_circle(1, 1.5, 20, rng) + [5, 0]

# Full cloud Z contains all four samples.
Z = np.vstack([C1, C2, C3, C4])
# X omits C3 entirely.
X = np.vstack([C1, C2, C4])

# Y is Z with the points inside a disk removed (centre and squared radius below).
removed_centre = np.array([2.3, 0.0])
removed_sq_radius = 2
middle_points = np.sum((Z - removed_centre) ** 2, axis=1) < removed_sq_radius
outer_points = ~middle_points
Y = Z[outer_points]


In [None]:
# Side-by-side view of the subset X, the full cloud Z and the subset Y,
# each panel titled so the clouds can be told apart.
fig, ax = plt.subplots(ncols=3, figsize=(12,4))
for i, (name, pts) in enumerate([("X", X), ("Z", Z), ("Y", Y)]):
    ax[i].scatter(pts[:,0], pts[:,1], color=mpl.colormaps["RdYlGn"](i/3))
    ax[i].set_aspect("equal")
    ax[i].set_title(name)

In [None]:
# Squared-distance tolerance for deciding that two points coincide.
coincide_sq_tol = 10e-8

def _indices_in(sample, cloud, sq_tol):
    """Return, in ascending order, the indices of rows of `cloud` lying within
    squared Euclidean distance `sq_tol` of some row of `sample`."""
    return [
        i for i, pt in enumerate(cloud)
        if np.any(np.sum((sample - pt) ** 2, axis=1) < sq_tol)
    ]

X_indices = _indices_in(X, Z, coincide_sq_tol)
assert len(X_indices) == X.shape[0], "X is not a subset of Z (or Z contains repeated points)"
Y_indices = _indices_in(Y, Z, coincide_sq_tol)
assert len(Y_indices) == Y.shape[0], "Y is not a subset of Z (or Z contains repeated points)"
# Reorder X and Y to follow Z's row order, so that row k of X is Z[X_indices[k]].
X = Z[X_indices]
Y = Z[Y_indices]

Now compute the induced block functions for both inclusions, $X \subseteq Z$ and $Y \subseteq Z$.

In [None]:
# One run per inclusion, in this order: X ⊆ Z (index 0), then Y ⊆ Z (index 1).
exp_ibfm = [
    ibfm.get_IBloFunMatch_subset(None, Z, sub_indices, output_dir, num_it=4, max_rad=-1, points=True)
    for sub_indices in (X_indices, Y_indices)
]

In [None]:
not_indices_block_1 = list()  # passed as `codomain_int` to plot_matching below; empty here

In [None]:
# Shared radius limit for the right-hand panels: 20% beyond the largest value in the
# subset barcode ("S_barcode_1") and full-cloud barcode ("X_barcode_1") of both runs.
max_rad_1 = np.max([[np.max(ibfm_out["S_barcode_1"]), np.max(ibfm_out["X_barcode_1"])] for ibfm_out in exp_ibfm])*1.2
# Row cidx = run cidx (X ⊆ Z, then Y ⊆ Z). Columns 0-1: dim=0 matching;
# columns 2-3: plot_matching's default dim (presumably 1 — TODO confirm) with the common max_rad.
fig, ax = plt.subplots(nrows=2, ncols=4, figsize=(12,12))
for cidx, ibfm_out in enumerate(exp_ibfm):
    ibfm.plot_matching(ibfm_out, ax[cidx, [0,1]], fig, block_function=True, dim=0)
    ibfm.plot_matching(ibfm_out, ax[cidx, [2,3]], fig, max_rad=max_rad_1, block_function=True, codomain_int=not_indices_block_1)
# Save this figure explicitly rather than whatever pyplot considers "current".
fig.savefig("plots/class_match/illustrative_example/block_function.png")

Let us inspect the long intervals (those of length greater than 0.5) of the domains and the codomain, together with the associated matrices.

In [None]:
# Minimum bar length (second column minus first column) for a bar to count as "long".
long_threshold = 0.5

def _long_mask(barcode, threshold):
    """Boolean mask of the bars whose length (column 1 minus column 0) exceeds
    `threshold`. Rows are assumed to be (birth, death) pairs."""
    return (barcode[:, 1] - barcode[:, 0]) > threshold

X_barcode = exp_ibfm[0]["S_barcode_1"]
Y_barcode = exp_ibfm[1]["S_barcode_1"]
# Both runs share the codomain Z; its barcode is taken from the first run
# (presumably identical to exp_ibfm[1]["X_barcode_1"] — TODO confirm).
Z_barcode = exp_ibfm[0]["X_barcode_1"]
X_long = _long_mask(X_barcode, long_threshold)
Y_long = _long_mask(Y_barcode, long_threshold)
Z_long = _long_mask(Z_barcode, long_threshold)

In [None]:
# Keep only the rows selected by each long-bar mask.
X_long_barcode, Y_long_barcode, Z_long_barcode = (
    X_barcode[X_long],
    Y_barcode[Y_long],
    Z_barcode[Z_long],
)

In [None]:
# Display the bars of Z selected by the Z_long mask.
Z_long_barcode

In [None]:
def _pm_to_dense(pm_columns, n_rows, n_cols):
    """Build an (n_rows, n_cols) 0/1 integer matrix with mat[i, j] = 1 for every
    row index i listed in column j of the sparse column list `pm_columns`."""
    mat = np.zeros((n_rows, n_cols), dtype="int")
    for j, col in enumerate(pm_columns):
        for i in col:
            mat[i, j] = 1
    return mat

# Rows index bars of Z; columns index bars of the domain (X or Y).
XZ_mat = _pm_to_dense(exp_ibfm[0]["pm_matrix_1"], Z_barcode.shape[0], X_barcode.shape[0])
YZ_mat = _pm_to_dense(exp_ibfm[1]["pm_matrix_1"], Z_barcode.shape[0], Y_barcode.shape[0])


In [None]:
XZ_mat[np.ix_(Z_long, X_long)]  # rows: long Z bars, columns: long X bars

In [None]:
YZ_mat[np.ix_(Z_long, Y_long)]  # rows: long Z bars, columns: long Y bars

In [None]:
Y_barcode[np.flatnonzero(Y_long)]  # long bars of Y

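Next, restrict the block functions to the long bars: display the long bars of $X$ and $Z$, locate the image of each long bar among the long bars of $Z$, and plot the resulting matchings.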

In [None]:
np.compress(X_long, X_barcode, axis=0)  # long bars of X

In [None]:
Z_barcode[np.flatnonzero(Z_long)]  # long bars of Z

In [None]:
# Indices (into the full Z barcode) of the long Z bars, in increasing order.
# Position k in this list is the row of that bar in Z_barcode[Z_long].
Z_long_index = np.nonzero(Z_long)[0].tolist()
Z_long_index

In [None]:
# For each long bar of X, the position of its block-function image among the long bars of Z.
block_function_X = exp_ibfm[0]["block_function_1"]
XZ_block_long = []
for x_bar in np.flatnonzero(X_long):
    XZ_block_long.append(Z_long_index.index(block_function_X[x_bar]))
XZ_block_long

In [None]:
# For each long bar of Y, the position of its block-function image among the long bars of Z.
YZ_block_long = list(map(
    lambda y_bar: Z_long_index.index(exp_ibfm[1]["block_function_1"][y_bar]),
    np.nonzero(Y_long)[0],
))
YZ_block_long

In [None]:
# Block function X -> Z restricted to the long bars, drawn on two side-by-side axes.
fig, ax = plt.subplots(figsize=(8, 4), ncols=2)
ibfm.plot_from_block_function(
    X_barcode[X_long],
    Z_barcode[Z_long],
    XZ_block_long,
    fig,
    ax[[0, 1]],
)
plt.savefig("plots/class_match/illustrative_example/block_function_0.png")

In [None]:
# Block function Y -> Z restricted to the long bars. YZ_block_long has one entry
# per long bar of Y, so the domain passed here must be Y's long bars (not X's).
fig, ax = plt.subplots(ncols=2, figsize=(8,4))
ibfm.plot_from_block_function(Y_barcode[Y_long], Z_barcode[Z_long], YZ_block_long, fig, ax)
fig.savefig("plots/class_match/illustrative_example/block_function_1.png")

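Finally, plot the representative cycles of the long intervals of $X$, $Z$ and $Y$, drawn on top of the corresponding point clouds (one column per point cloud, one row per long interval).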

In [None]:
# Representative cycles of the long bars: one column per point cloud (X, Z, Y),
# one row per long bar of that cloud.
X_long_reps = [exp_ibfm[0]["S_reps_1"][i] for i, is_long in enumerate(X_long) if is_long]
Y_long_reps = [exp_ibfm[1]["S_reps_1"][i] for i, is_long in enumerate(Y_long) if is_long]
Z_long_reps = [exp_ibfm[0]["X_reps_1"][i] for i, is_long in enumerate(Z_long) if is_long]
repr_long_list = [X_long_reps, Z_long_reps, Y_long_reps]
point_list = [X, Z, Y]
panel_titles = ["X", "Z", "Y"]

nrows = max(len(reps) for reps in repr_long_list)
n_panel_rows = max(nrows, 1)  # plt.subplots cannot build a grid with zero rows
# squeeze=False keeps `ax` two-dimensional even with a single row, so ax[i, j] always works.
fig, ax = plt.subplots(nrows=n_panel_rows, ncols=3, figsize=(12, 9), squeeze=False)
for j, (reps, points) in enumerate(zip(repr_long_list, point_list)):
    ax[0, j].set_title(panel_titles[j])
    for i, cycle in enumerate(reps):
        # A representative is a flat sequence of vertex indices read in consecutive
        # pairs, each pair being one edge; coordinates are looked up in Z
        # (presumably the indices refer to Z's rows for all three clouds — TODO confirm).
        edges = np.asarray(cycle, dtype=int).reshape(-1, 2)
        for edge in edges:
            edge_pts = Z[edge]
            ax[i, j].plot(edge_pts[:, 0], edge_pts[:, 1], linewidth=3, c="blue", zorder=3)
        ax[i, j].scatter(points[:, 0], points[:, 1], color=mpl.colormaps["Set1"](j / 8), zorder=2)
        ax[i, j].set_aspect("equal")
    # Hide the unused panels below the last cycle of this column.
    for i in range(len(reps), n_panel_rows):
        ax[i, j].set_axis_off()
fig.savefig("plots/class_match/illustrative_example/representatives.png")