In [None]:
%load_ext autoreload
%autoreload 2

import time

import numpy as np

from scipy import spatial
from sklearn.preprocessing import scale

from gtda.homology import VietorisRipsPersistence

from steenroder import *
import gudhi

from matplotlib import pyplot as plt

In [None]:
def get_density_filtered_point_cloud(weight_array,
                                     k_th_nearest=200,
                                     percentage=0.3,
                                     user_axis=1):
    """
    Returns a density filtered point cloud of 9-vectors.

    The weights are standardized with ``scale`` along ``user_axis``. For every
    point, the distance to the point at position ``k_th_nearest`` in ascending
    distance order (position 0 is normally the point itself) is used as an
    inverse density estimate. The ``percentage`` fraction of points with the
    smallest such distance, i.e. the densest points, is kept.

    Parameters
    ----------
    weight_array : array of shape (n_points, n_features)
        An array of 9-dimensional points i.e the weights of 3x3 patches
    k_th_nearest : int, optional, default: 200
        Indicates the k-th neighbour used as a density estimator.
        Must satisfy ``0 <= k_th_nearest < n_points``.
    percentage : float, optional, default: 0.3
        Fraction (between 0 and 1) of the point cloud cardinality to be left
        after filtration; ``int(percentage * n_points)`` points are kept.
    user_axis : int, optional, default: 1
        Axis passed to ``scale``: 0 is column normalization, 1 is row
        normalization.

    Returns
    -------
    pth_densest_points : ndarray of shape (int(percentage * n_points), n_features)
        The standardized densest points, ordered from densest to least dense.

    Raises
    ------
    ValueError
        If ``percentage`` is outside [0, 1] or ``k_th_nearest`` is not a valid
        neighbour position for the given number of points.
    """
    if not 0 <= percentage <= 1:
        raise ValueError(f"percentage must lie in [0, 1], got {percentage}")

    normalized_weight_array = scale(weight_array, axis=user_axis)
    m_dimension = normalized_weight_array.shape[0]
    if not 0 <= k_th_nearest < m_dimension:
        # A negative index would silently pick the farthest points instead.
        raise ValueError(
            f"k_th_nearest must lie in [0, {m_dimension - 1}] for a point "
            f"cloud of {m_dimension} points, got {k_th_nearest}")
    number_of_kth_densest_points = int(percentage * m_dimension)

    distance_matrix = spatial.distance.squareform(
        spatial.distance.pdist(normalized_weight_array))
    # np.partition puts the k-th smallest value of each row at column k
    # without fully sorting the row, which is all we need.
    kth_nearest_distances = np.partition(
        distance_matrix, k_th_nearest, axis=1)[:, k_th_nearest]

    # A small k-th-neighbour distance means a dense neighbourhood: keep the
    # points with the smallest distances first.
    densest_first = np.argsort(kth_nearest_distances)
    return normalized_weight_array[densest_first[:number_of_kth_densest_points]]

In [None]:
# NOTE(review): presumably VGG16 layer-2 convolution weights, one 3x3 patch
# (9 values) per row, as expected by get_density_filtered_point_cloud — confirm.
WEIGHTS_PATH = '../data/vgg16_lay2.npy'
weights = np.load(WEIGHTS_PATH)

In [None]:
# Keep the 30% densest standardized patches (density measured by the
# distance to the 15th nearest neighbour).
X = get_density_filtered_point_cloud(weights, k_th_nearest=15, percentage=0.3, user_axis=1)
print(f"Density-filtered point cloud: {X.shape[0]} points of dimension {X.shape[1]}")

# Condensed vector of pairwise (Euclidean, pdist's default) distances; its
# spread guides the choice of the Rips max_edge_length below.
dm = spatial.distance.pdist(X)
fig, ax = plt.subplots()
ax.hist(dm, bins=100)
ax.set(title="Histogram of pairwise distances in X",
       xlabel="Euclidean distance",
       ylabel="Number of point pairs")
plt.show()

In [None]:
max_edge_length = 5  # Rips threshold for the giotto-tda overview diagram

# Persistence diagram of X in dimensions 0, 1 and 2; X is wrapped in a list
# so giotto-tda receives a batch containing a single point cloud.
vr_diagrams = VietorisRipsPersistence(
    homology_dimensions=(0, 1, 2),
    max_edge_length=max_edge_length,
    infinity_values=np.inf,
).fit_transform_plot([X])

In [None]:
# NOTE(review): this gudhi complex uses a 4.5 threshold whereas the
# giotto-tda cell uses max_edge_length = 5 — confirm the difference is intended.
rips_complex = gudhi.RipsComplex(points=X, max_edge_length=4.5)

# Only the 1-skeleton (vertices and edges) is built here; higher-dimensional
# simplices are added later, after edge collapse.
simplex_tree_coll = rips_complex.create_simplex_tree(max_dimension=1)

In [None]:
def count_simplices(simplex_tree):
    """Return the number of entries yielded by ``simplex_tree.get_filtration()``.

    Unlike reading the index left over from an ``enumerate`` loop, this is
    well defined for an empty filtration (it returns 0) and never reuses a
    stale value from a previous count.
    """
    return sum(1 for _ in simplex_tree.get_filtration())


# Apply gudhi's edge collapse repeatedly until a pass leaves the number of
# simplices unchanged (a fixed point), printing the size after each pass
# that shrank the complex.
len_filtration = count_simplices(simplex_tree_coll)

while True:
    simplex_tree_coll.collapse_edges()
    collapsed_len = count_simplices(simplex_tree_coll)
    if collapsed_len == len_filtration:
        break
    len_filtration = collapsed_len
    print(len_filtration)

print(f"Filtration with {len_filtration} simplices")

In [None]:
# Expand the collapsed 1-skeleton so it contains simplices up to dimension 3.
simplex_tree_coll.expansion(3)

# Each entry of get_filtration() holds the simplex at index 0 and its
# filtration value at index 1.
filtration_entries = list(simplex_tree_coll.get_filtration())
filtration = [tuple(entry[0]) for entry in filtration_entries]
filtration_values = np.asarray([entry[1] for entry in filtration_entries],
                               dtype=np.float64)

print(f"Filtration with {len(filtration)} simplices")

In [None]:
# NOTE(review): the leading positional argument 1 is presumably the Steenrod
# square index k expected by steenroder.barcodes — confirm.
barcode, st_barcode = barcodes(1, filtration,
                               homology=False,
                               filtration_values=filtration_values,
                               return_filtration_values=True,
                               verbose=True)

st_barcode

In [None]:
# Reference barcode from gudhi over the field with 2 elements, including the
# top dimension of the complex, checked against the barcode computed above.
gudhi_barcode = simplex_tree_coll.persistence(
    homology_coeff_field=2,
    persistence_dim_max=True,
)
check_agreement_with_gudhi(gudhi_barcode, barcode)