# Triplet Merge Tree and component evolution example.

In this notebook, we compute a couple of examples that illustrate the evolution of connected components and of triplet merge trees along the one-dimensional Vietoris-Rips filtration.

This notebook has been adapted from a notebook in the topological data quality repository, see:
https://github.com/Cimagroup/tdqual/blob/main/notebooks/example_Mf_computation.ipynb

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import scipy.spatial.distance as dist
import itertools

import tdqual.topological_data_quality_0 as tdqual

import os 
# All figures produced by this notebook are saved under plots/tmt-components/.
# exist_ok=True lets this cell be re-run without raising if the folder already exists.
plots_dir = os.path.join("plots", "tmt-components")
os.makedirs(plots_dir, exist_ok=True)

Some plots in this notebook require GUDHI, a library that makes it easy to work with simplicial complexes.

In [None]:
# pip install gudhi
import gudhi

# Computation of Block Function in dimension 0

Consider the following example, with points taken from a sample.

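We consider a small random sample of points together with three extra points that form a small equilateral triangle away from the sample.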

In [None]:
# Seeded generator so the sample (and every figure below) is reproducible.
RandGen = np.random.default_rng(2)
# Random sample from tdqual's sampled_circle.
# NOTE(review): presumably 6 points drawn between radii 0 and 2 — confirm against tdqual.
Z = tdqual.sampled_circle(0,2,6, RandGen)
# Append three extra points: an equilateral triangle of side 0.5
# (base from x=-0.1 to x=0.4, apex at height sqrt(0.5**2 - 0.25**2)), translated by (-1.6, 1).
Z = np.vstack([Z, np.array([[-0.1,0], [0.4,0], [0.15, np.sqrt(0.5**2 - 0.25**2)]]) + [-1.6,1]])
# Plot point cloud
fig, ax = plt.subplots(ncols=1, figsize=(3,3))
ax.scatter(Z[:,0], Z[:,1], color=mpl.colormaps["RdBu"](1/1.3), s=40, marker="o", zorder=1)
ax.set_axis_off()
ax.set_aspect("equal")
# Join with os.path.join so the file is written inside plots_dir (plain string
# concatenation would drop the path separator and save next to the folder instead).
fig.savefig(os.path.join(plots_dir, "points_0.png"))

For illustration, we plot the Vietoris-Rips complex at a sequence of filtration values, labelling each point with the connected component it belongs to.

In [None]:
### Connected components of a graph given as an edge list
def compute_components(edgelist, num_points):
    """Label every point with the connected component it belongs to.

    Edges are processed one at a time. When an edge joins two different
    components, every point carrying the larger label is relabelled with the
    smaller one, so in the end each point is labelled by the smallest point
    index in its component.

    Parameters
    ----------
    edgelist : iterable
        Edges given as index pairs (e.g. ``[i, j]``); each edge is used as a
        fancy index into the label array.
    num_points : int
        Number of points (vertices).

    Returns
    -------
    numpy.ndarray
        Integer array of length ``num_points``; entry ``i`` is the label of point ``i``.
    """
    labels = np.array(range(num_points))
    for edge in edgelist:
        endpoint_labels = labels[edge]
        keep = endpoint_labels.min()
        absorb = endpoint_labels.max()
        # A label always equals the smallest index in its component, so relabelling
        # all points tagged `absorb` merges the two components (a no-op when equal).
        labels[labels == absorb] = keep
    return labels

In [None]:
def plot_Vietoris_Rips(Z,  filt_val, ax, labels=False, fontsize=15):
    """Draw the vertices and edges of the Vietoris-Rips complex of ``Z`` on ``ax``.

    Parameters
    ----------
    Z : numpy.ndarray
        Point cloud indexed as ``Z[i, 0]`` (x) and ``Z[i, 1]`` (y).
        NOTE(review): assumes planar points — only the first two columns are drawn.
    filt_val : float
        Scale passed to GUDHI's ``RipsComplex`` as ``max_edge_length``; every
        edge GUDHI keeps at that scale is drawn.
    ax : matplotlib Axes
        Axes to draw on. Ticks are removed and limits are padded by 10%.
    labels : bool, optional
        If True, draw larger markers and write inside each one the label
        returned by ``compute_components`` for the drawn edges.
    fontsize : int, optional
        Font size of the component labels.
    """
    point_color = mpl.colormaps["RdBu"](1/1.3)
    # Larger markers leave room for the label text written inside each point.
    marker_size = 230 if labels else 40
    ax.scatter(Z[:,0], Z[:,1], color=point_color, s=marker_size, marker="o", zorder=1)
    # Only edges are drawn and used for components, so a 1-dimensional simplex
    # tree is enough; expanding it to triangles would be wasted work.
    rips_complex = gudhi.RipsComplex(points=Z, max_edge_length=filt_val)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=1)
    edgelist = []
    # get_filtration() yields (simplex, filtration_value) pairs.
    for simplex, _ in simplex_tree.get_filtration():
        if len(simplex) == 2:
            edgelist.append(simplex)
            ax.plot(Z[simplex][:,0], Z[simplex][:,1], linewidth=2, c=point_color, zorder=0.5)
    ax.set_aspect("equal")
    # Pad both axes by 10% of their current extent so markers are not clipped.
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    xscale = x_hi - x_lo
    yscale = y_hi - y_lo
    ax.set_xlim(x_lo - xscale*0.1, x_hi + xscale*0.1)
    ax.set_ylim(y_lo - yscale*0.1, y_hi + yscale*0.1)
    if labels:
        components = compute_components(edgelist, Z.shape[0])
        # Offset the text slightly down-left so it sits roughly centred in the marker.
        for i in range(Z.shape[0]):
            ax.text(Z[i,0]-0.035*xscale, Z[i,1]-0.035*yscale, f"{components[i]}", fontsize=fontsize, color="white", fontweight="bold")
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:
# One panel per filtration value; in each panel every point is labelled with its component.
filtrations = [0.0, 0.8, 1.1, 2]
fig, ax = plt.subplots(ncols=len(filtrations), figsize=(3*len(filtrations),3))
for j, filt_val in enumerate(filtrations):
    panel = ax[j]
    plot_Vietoris_Rips(Z, filt_val, panel, labels=True)
    # The panel title is the filtration value.
    panel.set_title(f"{filt_val:1.1f}", fontsize=20)
fig.tight_layout()
fig.savefig(os.path.join(plots_dir, "VR_filtration.png"))


Now we describe the $0$-dimensional persistence barcode in terms of the evolution of components, drawn as a merge tree.

In [None]:
def plot_merge_tree(endpoints_0, reps_0, ax):
    """Draw a 0-dimensional barcode as a merge tree on ``ax``.

    Each finite bar ``[0, endpoints_0[k]]`` is drawn on its own row. A vertical
    segment at its death value connects it to the row of the component it merges
    into. An extra top row is an infinite bar that runs past the right x-limit.

    Parameters
    ----------
    endpoints_0 : sequence of float
        Death value of each finite bar.
    reps_0 : sequence of index pairs
        One pair per bar. ``max(pair)`` is the index that dies (written at the
        left end of the row). ``min(pair)`` is the index it merges into (written
        at the bar's end).
    ax : matplotlib Axes
        Axes to draw on. Y ticks are removed.

    Notes
    -----
    The top (infinite) row is labelled 0. Every merge target must be either 0
    or one of the dying indices; otherwise ``list.index`` raises ``ValueError``.
    """
    blue = mpl.colormaps["RdBu"](1/1.3)
    red = mpl.colormaps["RdBu"](0.3/1.3)
    right_edge = np.max(endpoints_0)*1.1
    n_rows = len(endpoints_0)+1
    rows = np.linspace(0, 0.3*n_rows, n_rows)

    # Finite bars, one per row, while recording who dies and who absorbs them.
    dying, absorbing, deaths = [], [], []
    for row, (death, pair) in enumerate(zip(endpoints_0, reps_0)):
        ax.plot([0,death], [rows[row], rows[row]], c=blue, linewidth=3, zorder=0.5)
        dying.append(np.max(pair))
        absorbing.append(np.min(pair))
        deaths.append(death)
    # The top (infinite) row belongs to index 0.
    dying.append(0)

    # Vertical merge segments (red) from each bar to the row of its absorber.
    for row, (target, death) in enumerate(zip(absorbing, deaths)):
        target_row = dying.index(target)
        ax.plot([death, death], [rows[row], rows[target_row]], linewidth=3, c=red, zorder=0.5)

    # Half-extents of the current view, used to offset labels and pad the limits.
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    half_w = (x_hi-x_lo)*0.5
    half_h = (y_hi-y_lo)*0.5
    last_row = len(dying)-1
    for row, born in enumerate(dying):
        ax.text(-0.015*half_w, rows[row]-0.04*half_h, f"{born}", zorder=0.7, fontsize=10, color="white", fontweight="bold")
        if row != last_row:
            ax.text(endpoints_0[row]-0.015*half_w, rows[row]-0.04*half_h, f"{absorbing[row]}", zorder=0.7, fontsize=10, color="white", fontweight="bold")

    ax.scatter(np.zeros(len(rows)), rows, s=100, marker="o", color=blue, zorder=0.6)
    ax.scatter(endpoints_0, rows[:-1], s=100, marker="o", color=red, zorder=0.6)
    x_lo, x_hi = ax.get_xlim()
    ax.set_xlim(x_lo-0.1*half_w, x_hi+0.1*half_w)
    y_lo, y_hi = ax.get_ylim()
    ax.set_ylim(y_lo-0.1*half_h, y_hi+0.1*half_h)
    # Infinite bar: drawn after the limits are fixed so it is clipped at the right edge.
    ax.plot([0,right_edge*2], [rows[-1],rows[-1]], linewidth=3, c=blue, zorder=0.5)
    ax.set_yticks([])

In [None]:
# Edge filtration of Z and triplet-merge-tree pairs, both computed by tdqual.
# NOTE(review): presumably filt_Z holds the death values of the MST edges and TMT_Z_pairs
# one index pair per finite bar (plot_merge_tree reads them that way) — confirm in tdqual.
filt_Z, pairs_arr_Z = tdqual.mst_edge_filtration(Z)
TMT_Z_pairs = tdqual.compute_tmt_pairs(filt_Z, pairs_arr_Z)

In [None]:
# Draw the merge tree of Z and save it alongside the other figures.
fig, ax = plt.subplots(figsize=(8, 3))
plot_merge_tree(filt_Z, TMT_Z_pairs, ax)
ax.set_title("Merge Tree", fontsize=20)
fig.tight_layout()
fig.savefig(os.path.join(plots_dir, "merge_tree.png"))