# Introduction
This Jupyter notebook contains the code used to run the numerical experiments in the paper (url for preprint). It is organized into four sections:

- Verify convergence rate on torus
- Persistence measure vs. Frechet mean
- PH approximation on massive data
- Shape clustering

The computation of persistent homology is implemented using `gudhi` (https://gudhi.inria.fr), and the computation of optimal transport is implemented using `pot` (https://pythonot.github.io).

**Caveat**: The code is generally not fast, and some cells can take up to 7 hours to run. We therefore do not recommend the `Run all cells` option. Instead, please run the cells section by section.

Author: Yueqi Cao

Contact: y.cao21@imperial.ac.uk

In [8]:
# import necessary modules
import numpy as np
np.random.seed(20211121) # set random seed for reproducible experiments
import ApproxPH
import matplotlib.pyplot as plt
from gudhi.wasserstein.barycenter import lagrangian_barycenter as bary

# function to compute mean persistence measure and Frechet mean
def compute_mean(original_set, nb_subs, nb_sub_points, max_edge_length, min_persistence, scenario):
    """Compute the mean persistence measure and/or the Frechet mean of subsample diagrams.

    Subsamples are requested from ``ApproxPH.get_subsample``, one persistence
    diagram is computed per subsample with ``ApproxPH.get_PD``, and every
    infinite coordinate in a diagram is replaced by ``max_edge_length``.

    Parameters
    ----------
    original_set
        Point set handed to ``ApproxPH.get_subsample``.
    nb_subs
        Number of subsamples requested; each diagram point receives mass
        ``1/nb_subs`` in the mean persistence measure.
    nb_sub_points
        Number of points requested for each subsample.
    max_edge_length
        Forwarded to ``ApproxPH.get_PD``; also the value substituted for
        infinite diagram coordinates.
    min_persistence
        Forwarded to ``ApproxPH.get_PD``.
    scenario : {'mpm', 'fm', 'both'}
        Which summary to compute.

    Returns
    -------
    ``(mean_mesr, mean_mesr_vis)`` for ``'mpm'`` (the pair returned by
    ``ApproxPH.diag_to_mesr``), ``wmean`` for ``'fm'`` (the barycenter
    returned by ``lagrangian_barycenter``), and
    ``(mean_mesr, mean_mesr_vis, wmean)`` for ``'both'``.

    Raises
    ------
    ValueError
        If ``scenario`` is not one of the three supported values.
    """
    # Reject a mistyped scenario before any expensive work is done; falling
    # through every branch would otherwise return None only at the very end.
    if scenario not in ('mpm', 'fm', 'both'):
        raise ValueError(
            "scenario must be 'mpm', 'fm' or 'both', got %r" % (scenario,)
        )

    subs = ApproxPH.get_subsample(original_set, nb_sub_points, nb_subs)
    # list of PDs, with infinite coordinates capped at max_edge_length
    diags = []
    for points in subs:
        diag = ApproxPH.get_PD(points, max_edge_length=max_edge_length, min_persistence=min_persistence)
        diag[np.isinf(diag)] = max_edge_length
        diags.append(diag)

    wmean = None
    if scenario in ('fm', 'both'):
        # compute Frechet mean (verbose=True makes bary return a log as well)
        wmean, _ = bary(diags, init=0, verbose=True)
    if scenario == 'fm':
        return wmean

    # compute mean persistence measure: stack all diagrams in one go instead
    # of re-copying the growing array on every append. The leading [0, 0] row
    # is kept because the stacked array has always started with it.
    sub_pers = np.concatenate([np.array([[0, 0]])] + diags, axis=0)
    unit_mass = 1/nb_subs
    mean_mesr, mean_mesr_vis = ApproxPH.diag_to_mesr(sub_pers, unit_mass)
    if scenario == 'mpm':
        return mean_mesr, mean_mesr_vis
    return mean_mesr, mean_mesr_vis, wmean

# Verify convergence rate on torus

In this section, we test the convergence rate on a synthetic data set sampled from a 2-dimensional torus.

State main result here:....

In [7]:
# generate true set and compute persistent homology
# this step takes about 4 mins.
import os

# np.save does not create missing folders, so make sure the output directory
# exists before spending minutes on the computation
os.makedirs('outputs', exist_ok=True)

X_torus = ApproxPH.sample_torus(50000, 0.8, 0.3)
np.save('outputs/true-torus-points.npy', X_torus)
diag_torus = ApproxPH.get_PD(X_torus, max_edge_length=0.9)
np.save('outputs/true-torus-diagram.npy', diag_torus)

In [None]:
# visualize the true persistence diagram held in diag_torus
# if true-torus-diagram.npy was already written to outputs/, uncomment the
# next line to load it from disk instead of recomputing it:
# diag_torus = np.load('outputs/true-torus-diagram.npy')
ApproxPH.plot_diag(diag_torus)

In [None]:
# Draw subsamples of growing size from the true set and store the mean
# persistence measure obtained in each simulation.
# If true-torus-points.npy is already on disk, it can be loaded with
# X_torus = np.load('outputs/true-torus-points.npy')
# NOTE: running all 15 simulations takes about 7 hours
nb_simulates = 15
for i in range(nb_simulates):
    # simulation i uses 20*(i+2) subsamples of 200*(i+2) points each
    scale = i + 2
    simulation_args = dict(
        original_set=X_torus,
        nb_subs=20 * scale,
        nb_sub_points=200 * scale,
        max_edge_length=0.9,
        min_persistence=0.01,
        scenario='mpm',
    )
    mean_mesr, mean_mesr_vis = compute_mean(**simulation_args)
    np.save('outputs/mean_mesr_nb%d.npy' % i, mean_mesr)
    print('mean persistence measure for %dth simulation' % i)
    # to visualize the mean persistence measure, uncomment the next line
    # NOTE(review): `mpd` is not imported in this notebook - confirm which
    # module provides plot_mesr
    # mpd.plot_mesr(mean_mesr_vis)

In [None]:
# Distance between each stored mean persistence measure and the true diagram.
# reload the 15 measures saved as outputs/mean_mesr_nb<i>.npy
mesr_list = [np.load('outputs/mean_mesr_nb%d.npy' % i) for i in range(15)]

# true PD from disk, turned into a persistence measure with mass 1 per point
true_PD = np.load('outputs/true-torus-diagram.npy')
true_mesr, true_mesr_vis = ApproxPH.diag_to_mesr(true_PD, 1)

# grid and the matrix built from it for power 3; both are shared by all
# distance computations below
power_index = 3
grid = ApproxPH.mesh_gen()
Mp = ApproxPH.dist_mat(grid, power_index)

dist_list = []
point_list = []
for i, mesr in enumerate(mesr_list):
    distance = ApproxPH.wass_dist(mesr, true_mesr, Mp)
    # measure i was built from subsamples of 200*(i+2) points
    point_list.append(200 * (i + 2))
    dist_list.extend(distance.tolist())

# plot fitting curve of distance against subsample size
ApproxPH.plot_fitting_curve(point_list, dist_list)

# Comparison with the Frechet mean method

In this section, we compare the performance of mean persistence measure and Frechet mean.

In [5]:
# Reference data for the comparison: a point set sampled from an annulus,
# its persistence diagram, and the diagram as a measure with mass 1 per point.
nb_points = 5000
true_set = ApproxPH.sample_annulus(nb_points, r1=0.2, r2=0.5)
annulus_pd_options = {'max_edge_length': 0.4, 'min_persistence': 0.01}
true_PD = ApproxPH.get_PD(true_set, **annulus_pd_options)
true_mesr, true_mesr_vis = ApproxPH.diag_to_mesr(true_PD, 1)

In [None]:
# each time we draw nb_subs (= 20) subsets from the true_set
# each subset has number of points in nb_sub_points_list
# we compute the 2-Wasserstein distance of mean persistence measure & Frechet mean to the true diagram

# set parameters
nb_subs = 20
# mass of one subsample diagram; it is derived from nb_subs because this cell
# draws nb_subs subsets (nb_samples is not defined at this point of the notebook)
unit_mass = 1/nb_subs
nb_sub_points_list = [50,100,150,200,250,300,350,400]
power_index = 2
w_list = []
permesr_list = []

# the grid and the matrix built from it take no input that changes with the
# subset size, so they are built once instead of once per iteration
grid = ApproxPH.mesh_gen()
Mp = ApproxPH.dist_mat(grid, power_index)

for nb_sub_points in nb_sub_points_list:
    print('number of points in each subset: %d' %(nb_sub_points))
    mean_mesr, mean_mesr_vis, wmean = compute_mean(original_set = true_set,
                                            nb_subs = nb_subs,
                                            nb_sub_points = nb_sub_points,
                                            max_edge_length = 0.4,
                                            min_persistence = 0.01,
                                            scenario = 'both'
                                           )
    # the Frechet mean is a diagram: convert it to a measure with mass 1 per point
    wmean_mesr, wmean_mesr_vis = ApproxPH.diag_to_mesr(wmean, 1)
    # compute distance of both summaries to the true measure
    permesr_distance = ApproxPH.wass_dist(mean_mesr, true_mesr, Mp)
    wmean_distance = ApproxPH.wass_dist(wmean_mesr, true_mesr, Mp)
    # store the power_index-th root of each value returned by wass_dist
    permesr_list.append(permesr_distance**(1/power_index))
    w_list.append(wmean_distance**(1/power_index))

In [None]:
# visualize the comparison: one line plus markers per method, drawn against
# the number of points in each subset
fig = plt.figure(figsize=(8,8))
comparison_curves = [
    (permesr_list,
     dict(linestyle='-', color='blue', linewidth=2, label='Mean Persistence Measure'),
     dict(s=70, color='red', marker='o')),
    (w_list,
     dict(linestyle='--', color='green', linewidth=2, label='Frechet Mean'),
     dict(s=70, color='black', marker='P')),
]
for curve_values, line_options, marker_options in comparison_curves:
    plt.plot(nb_sub_points_list, curve_values, **line_options)
    plt.scatter(nb_sub_points_list, curve_values, **marker_options)
plt.xlabel('Number of Points')
plt.ylabel('2-Wasserstein distance')
plt.title('Comparison of Frechet mean\n and mean persistence measure')
plt.legend()
plt.show()

# PH approximation on massive data

In this section, we compute the mean persistence measure and the Frechet mean for a large real-world data set.

description of data

In [None]:
import numpy as np
from plyfile import *

# read the vertex coordinates stored in grayloc.ply
pltdata = PlyData.read('grayloc.ply')
vertex = pltdata['vertex']
x, y, z = (vertex[axis] for axis in ('x', 'y', 'z'))

# stack the three coordinate arrays and transpose so that they become columns
large_points = np.array([x,y,z]).T
print(large_points.shape)

# rescale
def rescale_points(points):
    """Rescale the first three coordinate columns of ``points`` to [-1, 1], in place.

    Each column is mapped affinely so that its minimum becomes -1 and its
    maximum becomes 1. A column whose values are all equal has no extent to
    stretch; it is set to 0 rather than divided by zero (which would fill the
    column with NaN).

    Parameters
    ----------
    points : numpy.ndarray
        2-D array with one point per row. It is modified in place, so it
        should have a floating-point dtype: the rescaled values are written
        back into ``points``. Columns beyond the third are left untouched;
        arrays with fewer than three columns have all their columns rescaled.

    Returns
    -------
    numpy.ndarray
        The same array object that was passed in.
    """
    for i in range(min(3, points.shape[1])):
        max_scale = np.max(points[:,i])
        min_scale = np.min(points[:,i])
        span = max_scale - min_scale
        if span == 0:
            # degenerate axis: every point shares this coordinate
            points[:,i] = 0
            continue
        # affine map a*x + b with a*max + b = 1 and a*min + b = -1
        a = 2/span
        b = 1 - a * max_scale
        points[:,i] = a * points[:,i] + b
    return points

# compute persistent homology
import ssp
import gudhi as gd
import matplotlib.pyplot as plt
from gudhi.wasserstein.barycenter import lagrangian_barycenter as bary

# experiment parameters
nb_samples = 30
unit_mass  = 1/nb_samples
nb_sub_ratio = 0.02
power_index = 2
max_edge_length = 0.55
homology_dimension = 1

# rescale the cloud and draw nb_samples subsamples, each holding a fraction
# nb_sub_ratio of the points
points = rescale_points(large_points)
nb_sub_points = int(nb_sub_ratio*large_points.shape[0])
point_set = ssp.get_subsample(points, nb_sub_points, nb_samples)

# one PD per subsample; infinite coordinates are capped at max_edge_length
diags = []
for point in point_set:
    diag = ssp.get_PD(point, homology_dimension, max_edge_length)
    diag[np.isinf(diag)] = max_edge_length
    diags.append(diag)

# Frechet (Wasserstein) mean of the diagrams
wmean, log = bary(diags, init=0, verbose=True)

# mean persistence measure: stack a leading [0, 0] row and all diagrams,
# then convert with mass unit_mass
sub_pers = np.concatenate([np.array([[0,0]])] + diags, axis=0)
grid = ssp.mesh_gen()
Mp = ssp.dist_mat(grid, power_index)
mean_mesr, mean_mesr_vis = ssp.diag_to_mesr(sub_pers, unit_mass)

# keep both results on disk for the plotting code below
np.save('Fmean-1-loc.npy',wmean)
np.save('meanM-1-loc.npy',mean_mesr_vis)


# two stacked panels: Frechet mean (top) and mean persistence measure (bottom)
fig = plt.figure(figsize=(5, 10))
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 16,
    'text.usetex': True,
})

# top panel: scatter of the Frechet mean with the diagonal and the shaded
# region underneath it
fm_x, fm_y = wmean[:,0], wmean[:,1]
diagonal_ends = [0,0.2]
main_ax = fig.add_subplot(211)
main_ax.scatter(fm_x, fm_y, s=75, marker='o', c='red', alpha=0.8)
main_ax.plot(diagonal_ends, diagonal_ends, linewidth=0.5)
main_ax.fill_between(diagonal_ends, diagonal_ends, [0,0], facecolor='green', alpha=0.2)
main_ax.set_xlim((0,0.2))
main_ax.set_ylim((0,0.5))
main_ax.set_xticks([0,0.1])
main_ax.set_title("FM")

# bottom panel: image of the measure, cropped to half of its grid size on
# both axes; tick positions are grid indices, tick labels are fixed values
m_ax = fig.add_subplot(212)
m_ax.imshow(mean_mesr_vis.T, origin='lower', cmap='hot_r',
            interpolation='bilinear', aspect='auto')
L = mean_mesr_vis.shape[0]
half_extent = L/2
m_ax.set_xlim((0,half_extent))
m_ax.set_ylim((0,half_extent))
m_ax.set_xticks([0,10,20])
m_ax.set_xticklabels([0.0,0.2,0.4])
m_ax.set_yticks([0,5,10,15,20,25])
m_ax.set_yticklabels([0.0,0.1,0.2,0.3,0.4,0.5])
m_ax.fill_between([0,half_extent], [0,half_extent], [0,0], facecolor='green', alpha=0.2)
m_ax.set_title("MPM")

plt.savefig('1-homology.png',dpi=400)

# reload the saved Frechet mean and mean persistence measure image
wmean = np.load('Fmean-1-loc.npy')
mean_mesr_vis = np.load('meanM-1-loc.npy')

# standalone figure of the Frechet mean
fig = plt.figure(figsize=(12, 12))
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 28,
    'text.usetex': True,
})
fm_x, fm_y = wmean[:,0], wmean[:,1]
diagonal_ends = [0,0.2]
main_ax = fig.add_subplot(111)
main_ax.scatter(fm_x, fm_y, s=100, marker='o', c='red', alpha=0.8)
main_ax.plot(diagonal_ends, diagonal_ends, linewidth=0.5)
main_ax.fill_between(diagonal_ends, diagonal_ends, [0,0], facecolor='green', alpha=0.2)
main_ax.set_xlim((0,0.2))
main_ax.set_ylim((0,0.5))
main_ax.set_xticks([0,0.1,0.2])
plt.savefig('loc-FM-1.png',dpi=400)
# title intentionally left off:
#main_ax.set_title("FM")

# standalone figure of the mean persistence measure
fig = plt.figure(figsize=(12, 12))
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 28,
    'text.usetex': True,
})
m_ax = fig.add_subplot(111)
m_ax.imshow(mean_mesr_vis.T, origin='lower', cmap='hot_r',
            interpolation='bilinear', aspect='auto')
L = mean_mesr_vis.shape[0]
print(L)
# crop to a fifth of the grid horizontally and half of it vertically
x_extent = L/5
y_extent = L/2
m_ax.set_xlim((0,x_extent))
m_ax.set_ylim((0,y_extent))
m_ax.set_xticks([0,5,10])
m_ax.set_xticklabels([0.0,0.1,0.2])
m_ax.set_yticks([0,5,10,15,20,25])
m_ax.set_yticklabels([0.0,0.1,0.2,0.3,0.4,0.5])
m_ax.fill_between([0,x_extent], [0,x_extent], [0,0], facecolor='green', alpha=0.2)
# title intentionally left off:
#m_ax.set_title("MPM")
plt.savefig('loc-MPM-1.png',dpi=400)

# Shape clustering

In this section, we apply subsampling methods to shape clustering.

description of data

In [None]:
# TODO: add the shape clustering code here