In [None]:
%load_ext autoreload
%autoreload 2

import time
import itertools

import numpy as np
import pickle as pkl

from scipy import spatial as spatial
from scipy import sparse as sp
from sklearn.preprocessing import scale

from gtda.homology import VietorisRipsPersistence

from steenroder import *
import gudhi

from matplotlib import pyplot as plt

In [None]:
def get_density_filtered_point_cloud(weight_array,
                                     k_th_nearest=200,
                                     percentage=0.3,
                                     user_axis=1):
    """
    Returns a density filtered point cloud of 9-vectors.

    The input is standardized with ``sklearn.preprocessing.scale``. Each
    standardized point is then scored by its Euclidean distance to its
    ``k_th_nearest`` neighbour; a small distance means a dense
    neighbourhood. The densest ``int(percentage * n_points)`` points are
    returned.

    Parameters
    ----------
    weight_array: array of shape (n_points, n_features)
        An array of 9-dimensional points i.e the weights of 3x3 patches.
        Any number of features works.
    k_th_nearest : int, optional, default: 200
        Column of the ascending-sorted distance row used as the density
        estimator. Column 0 is normally the point itself (distance 0), so
        200 selects the 200th nearest *other* point when there are no
        duplicate points. Must be smaller than ``n_points``.
    percentage : float, optional, default: 0.3
        Fraction (in [0, 1], not [0, 100]) of the points to keep after
        filtration. For example, 0.3 keeps the densest 30%.
    user_axis : int, optional, default: 1
        Axis passed to ``scale``: 0 standardizes each column (feature),
        1 standardizes each row (point).

    Returns
    -------
    ndarray of shape (min(int(percentage * n_points), n_points), n_features)
        The kept points, taken from the *standardized* array and ordered
        from densest to sparsest.

    Raises
    ------
    ValueError
        If ``k_th_nearest`` is not smaller than the number of points.
    """
    normalized_weight_array = scale(weight_array, axis=user_axis)
    n_points, _ = normalized_weight_array.shape
    if k_th_nearest >= n_points:
        raise ValueError(
            f"k_th_nearest={k_th_nearest} must be smaller than the number "
            f"of points ({n_points})"
        )
    n_keep = int(percentage * n_points)

    # Dense (n_points, n_points) Euclidean distance matrix; memory grows
    # quadratically with n_points.
    distance_matrix = spatial.distance.squareform(
        spatial.distance.pdist(normalized_weight_array)
    )
    # np.partition places the k-th smallest entry of every row at column k in
    # linear time per row. This yields exactly the same distances as a full
    # argsort followed by picking column k, without sorting whole rows.
    kth_nearest_distances = np.partition(
        distance_matrix, k_th_nearest, axis=1
    )[:, k_th_nearest]

    # Smallest k-th-neighbour distance first = densest points first.
    densest_indices = np.argsort(kth_nearest_distances)[:n_keep]
    return normalized_weight_array[densest_indices]

In [None]:
# Path is relative to the notebook's working directory.
# NOTE(review): presumably an (n_patches, 9) array of 3x3 filter weights,
# per the docstring of get_density_filtered_point_cloud; confirm.
WEIGHTS_PATH = 'data/vgg16_lay2.npy'
weights = np.load(WEIGHTS_PATH)

In [None]:
# Keep the densest 20% of the row-standardized weight vectors.
X = get_density_filtered_point_cloud(weights, k_th_nearest=200, percentage=0.2, user_axis=1)
print(len(X))

# Pairwise Euclidean distances of the filtered cloud. Their spread is
# presumably what motivates the max_edge_length used for the Rips complexes.
dm = spatial.distance.pdist(X)
fig, ax = plt.subplots()
ax.hist(dm, bins=100)
ax.set(title="Histogram of pairwise distances in X",
       xlabel="Pairwise Euclidean distance",
       ylabel="Number of point pairs")
plt.show()

In [None]:
# Edge-length cutoff for the Rips filtration; the gudhi cells below reuse it.
max_edge_length = 5

vr_persistence = VietorisRipsPersistence(
    homology_dimensions=(0, 1),
    max_edge_length=max_edge_length,
    infinity_values=np.inf,
)
vr_persistence.fit_transform_plot([X]);

In [None]:
# First argument passed to `barcodes` below.
# NOTE(review): presumably the Steenrod square index (Sq^k); confirm.
k: int = 1

In [None]:
# gudhi Rips complex on the filtered cloud, using the same edge-length cutoff as above.
rips_complex_coll = gudhi.RipsComplex(
    points=X,
    max_edge_length=max_edge_length,
)
# Build only the 1-skeleton (vertices and edges). It is edge-collapsed next,
# and higher-dimensional simplices are added afterwards via `expansion`.
simplex_tree_coll = rips_complex_coll.create_simplex_tree(max_dimension=1)

In [None]:
# Edge-collapse the 1-skeleton (10 passes) to reduce its size before expansion.
simplex_tree_coll.collapse_edges(nb_iterations=10)

# Materialize the (simplex, filtration value) pairs to count them.
filtration_coll = list(simplex_tree_coll.get_filtration())

print(f"Filtration with {len(filtration_coll)} simplices")

In [None]:
# Expand the collapsed 1-skeleton into its flag complex up to dimension 2,
# i.e. add triangles (2-simplices), not 3-simplices.
# NOTE(review): `barcodes` below uses maxdim=2. If H2 classes must be able
# to die, 3-simplices (expansion(3)) may be needed. Confirm intent.
expansion_max_dim = 2
simplex_tree_coll.expansion(expansion_max_dim)

In [None]:
# Re-read the filtration now that the expansion has added triangles.
filtration_coll = [*simplex_tree_coll.get_filtration()]

print(f"Filtration with {len(filtration_coll)} simplices")

In [None]:
# Split gudhi's filtration entries into vertex tuples and a matching array of
# filtration values, as passed to `barcodes` (from the steenroder star import).
simplices_coll = [tuple(entry[0]) for entry in filtration_coll]
filtration_values_coll = np.array([entry[1] for entry in filtration_coll])

barcode_coll, st_barcode_coll = barcodes(
    k,
    simplices_coll,
    homology=True,
    filtration_values=filtration_values_coll,
    return_filtration_values=True,
    maxdim=2,
)

In [None]:
# gudhi's persistence over Z/2Z coefficients on the same simplex tree,
# including the complex's top dimension.
gudhi_barcode = simplex_tree_coll.persistence(
    homology_coeff_field=2,
    persistence_dim_max=True,
)
# NOTE(review): presumably compares gudhi's barcode with steenroder's;
# the helper comes from the steenroder star import. Confirm.
check_agreement_with_gudhi(gudhi_barcode, barcode_coll)