In this notebook, we compute the matching using Kruskal's minimum spanning tree algorithm.

In [None]:
import numpy as np
import gudhi
import matplotlib.pyplot as plt
import matplotlib as mpl

import scipy.spatial.distance as dist
import itertools
from scipy.sparse.csgraph import minimum_spanning_tree
import numpy as np

import iblofunmatch.topological_data_quality_0 as tdq0

# Directory under which the per-class MNIST figures are saved; it is created
# here (no error if it already exists) before any figure is written.
plots_dir = "plots/mnist_experiment/"
import os 
os.makedirs(plots_dir, exist_ok=True)

In [None]:
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn import decomposition

# Read the MNIST training table, then separate it into the feature columns
# (X_raw) and the "label" column converted to a NumPy vector (y).
X_raw_l = pd.read_csv("mnist_data/train.csv")
X_raw = X_raw_l.drop(columns="label")
y = np.array(X_raw_l["label"])

In [None]:
# Restrict to the first num_class classes of the dataset, i.e. the labels
# 0, ..., num_class - 1. These are exactly the labels visited by
# range(num_class), so every class kept in X is also analysed as a subset.
num_class = 3
class_mask = y < num_class
X_raw_class_idx = X_raw[class_mask]
y_sub = y[class_mask]
# Scale using the StandardScaler
X_scal = StandardScaler().fit_transform(X_raw_class_idx)
# Reduce dimension with PCA, keeping 10 components
pca = decomposition.PCA(n_components=10)
X = pca.fit_transform(X_scal)
print("dataset shape")
X.shape

Compute the filtration pairs for the whole dataset. This takes a while.

In [None]:
# Filtration values and pairs of the whole point cloud X. Both results are
# reused for every class in the loop below, so they are computed only once.
filtration_list_X, pairs_arr_X = tdq0.filtration_pairs(X)

For each class, compute the matching and plot the matched pairs $(a,b)$, with $a>b$, as well as the density matrix of the scattered points $(a,b-a)$.

In [None]:
for i in range(num_class):
    # Rows of X whose label is i form the subset S.
    class_rows = np.nonzero(y_sub == i)[0]
    S = X[class_rows]
    filtration_list_S, pairs_arr_S = tdq0.filtration_pairs(S)
    # The pivots of the inclusion matrix are used below as indices into
    # filtration_list_X, one per entry of filtration_list_S.
    matching = tdq0.get_inclusion_matrix_pivots(
        tdq0.get_inclusion_matrix(pairs_arr_S, pairs_arr_X, class_rows),
        X.shape[0],
    )
    fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
    pairs_ax, density_ax = ax
    # Left panel: filtration value in S against the matched filtration value in X.
    matched_X = np.array(filtration_list_X)[matching]
    pairs_ax.scatter(filtration_list_S, matched_X, s=10, marker="o", c="blue")
    pairs_ax.set_title(f"Class {i}")
    # Right panel: density matrix drawn by the library helper.
    tdq0.plot_density_matrix(
        filtration_list_S, filtration_list_X, matching, density_ax, nbins=10
    )
    plt.savefig(plots_dir + f"MNIST_class_{i}.png")

Now we check, on a few examples, that the matching it produces is the same as the one given by the standard algorithm.

In [None]:
import iblofunmatch.inter as ibfm
# Directory handed to ibfm.get_IBloFunMatch_subset in the comparison loop
# (presumably where it writes its files -- confirm against the library);
# created here if it does not exist yet.
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)

In [None]:
for iterate in range(10):
    # One generator per iteration, seeded with the iteration number, so that
    # every example can be reproduced independently.
    RandGen = np.random.default_rng(iterate)
    # NOTE: this rebinds X (and, below, filtration_list_X / pairs_arr_X),
    # replacing the MNIST arrays of the same names. Re-run the MNIST cells
    # before re-running the per-class plots.
    X = ibfm.sampled_circle(0,2,500, RandGen)
    S_indices = RandGen.choice(X.shape[0],200, replace=False)
    S = X[S_indices]
    # Compute matching using minimum spanning trees
    filtration_list_X, pairs_arr_X = tdq0.filtration_pairs(X)
    filtration_list_S, pairs_arr_S = tdq0.filtration_pairs(S)
    F = tdq0.get_inclusion_matrix(pairs_arr_S, pairs_arr_X, S_indices)
    matching = tdq0.get_inclusion_matrix_pivots(F, X.shape[0])
    # Compute matching using iblofunmatch
    ibfm_out = ibfm.get_IBloFunMatch_subset(None, X, S_indices, output_dir, num_it=4, max_rad=-1, points=True, store_0_pm=True)
    # Check that both outputs coincide. The endpoint differences are compared
    # in absolute value: a signed comparison would accept any negative
    # discrepancy, however large.
    diff_X = np.abs(ibfm_out["X_barcode_0"][:,1] - filtration_list_X)
    assert np.all(diff_X < 1e-5), f"X barcode endpoints differ (iterate {iterate})"
    diff_S = np.abs(ibfm_out["S_barcode_0"][:,1] - filtration_list_S)
    assert np.all(diff_S < 1e-5), f"S barcode endpoints differ (iterate {iterate})"
    ibfm_matching = np.array(ibfm_out["induced_matching_0"], dtype="int")
    mst_matching = np.array(matching, dtype="int")
    assert np.all(ibfm_matching - mst_matching == 0), f"matchings differ (iterate {iterate})"
    # Plot example: point clouds, matched endpoints, density matrix
    fig, ax = plt.subplots(ncols=3, figsize=(12,4))
    ax[0].scatter(S[:,0], S[:,1], color=mpl.colormaps["RdBu"](0.3/1.3), s=60, marker="o", zorder=2)
    ax[0].scatter(X[:,0], X[:,1], color=mpl.colormaps["RdBu"](1/1.3), s=40, marker="x", zorder=1)
    ax[0].set_axis_off()
    endpoints_X = np.array(filtration_list_X)[matching]
    ax[1].scatter(filtration_list_S, endpoints_X, s=10, marker="o", c="blue")
    tdq0.plot_density_matrix(filtration_list_S, filtration_list_X, matching, ax[2], nbins=10)
    print("################################## TEST PASSED :) ====================================")