In [1]:
import scipy.spatial as spatial
import numpy as np
import matplotlib.pyplot as plt
#from more_itertools import powerset
import mpl_toolkits.mplot3d as a3
from rich import print
import gudhi as gd
import time
import cupy as cpy

def cpy_distance_matrix(X,Y):
    """Pairwise Euclidean distances between the rows of X and the rows of Y.

    Row i of the result holds the distances from ``X[i, :]`` to every row of
    ``Y``, so the output has shape ``(X.shape[0], Y.shape[0])``. The work is
    done one row of X at a time with ``cpy`` (CuPy) operations, so only a
    single ``Y``-shaped temporary is alive at once.
    """
    n_x, n_y = X.shape[0], Y.shape[0]
    result = cpy.zeros((n_x, n_y))
    for row_index in range(n_x):
        # Tile X[row_index] to Y's shape so the subtraction is row-wise.
        differences = cpy.broadcast_to(X[row_index, :], Y.shape) - Y
        result[row_index, :] = cpy.linalg.norm(differences, ord=2, axis=1)
    return result

In [2]:
# Point-cloud configuration.
SEED = 0                 # fixes the random cloud so Restart & Run All is reproducible
N_POINTS = 800
DIM = 3
USE_TOY_EXAMPLE = False  # True -> use the small 2-D toy example below (few enough points to plot)

# Toy example: an equilateral triangle of side `side`, offset by (a, b) so that its
# centroid sits at (0.5, 0.5), plus that centroid as a fourth point.
side = .5
a, b = (0.5 - side/2, 0.5 - side*np.sqrt(3)/6)
X_toy = np.asarray([
    [a, b],
    [side + a, b],
    [side/2 + a, side*np.sqrt(3)/2 + b],
    [.5, .5],
])

rng = np.random.default_rng(SEED)
X = X_toy.copy() if USE_TOY_EXAMPLE else rng.random((N_POINTS, DIM))

In [3]:
# Delaunay triangulation of the host copy of the point cloud, then move the points to
# the GPU. `cpy.asnumpy` accepts both NumPy and CuPy input, so re-running this cell
# after X has already been moved to the device still works.
X_host = cpy.asnumpy(X)
tri = spatial.Delaunay(X_host)
X = cpy.asarray(X_host)

# Squared pairwise distances: circumspheres() expects squared distances.
distance_matrix = cpy.power(cpy_distance_matrix(X, X), 2)

# Every Delaunay simplex (and, through insert, its faces) starts at +inf;
# the alpha-complex cell later lowers these values.
simplex_tree = gd.SimplexTree()
for face in tri.simplices:
    simplex_tree.insert(face, filtration = np.inf)

# Vertices are born at 0. The vertex list is materialised first so the tree is
# not modified while a generator over it is still being consumed.
for vertex, _ in list(simplex_tree.get_skeleton(0)):
    simplex_tree.insert(vertex, filtration = 0.0)

In [4]:
def circumspheres(simplices, distance_matrix, X):
    """Circumradii and circumcentres of a batch of equal-size simplices.

    Uses the Cayley-Menger formulation: for a k-vertex simplex with squared
    pairwise distances D, the bordered matrix

        M = [[0, 1^T],
             [1, D  ]]

    has an inverse whose (0, 0) entry is -2 r^2 and whose first column (below
    the corner) holds the barycentric coordinates of the circumcentre.

    Parameters
    ----------
    simplices : sequence of sequences of int
        Vertex indices into ``X``; every simplex must have the same number of vertices.
    distance_matrix : (n, n) array
        *Squared* pairwise distances between the rows of ``X``.
    X : (n, d) array
        Point coordinates.

    Returns
    -------
    (circumradii, circumcentres)
        Arrays of shape (T,) and (T, d), one row per simplex in input order.
        Both are empty when ``simplices`` is empty.

    Raises
    ------
    ValueError
        If the simplices do not all have the same number of vertices.
    """
    n_simplices = len(simplices)
    if n_simplices == 0:
        return cpy.zeros((0,)), cpy.zeros((0, X.shape[1]))

    n_vertices = len(simplices[0])
    if any(len(tau) != n_vertices for tau in simplices):
        raise ValueError("all simplices must have the same number of vertices")

    idx = cpy.asarray(simplices)  # (T, k) vertex indices

    # Batched Cayley-Menger matrices, shape (T, k+1, k+1); built with one gather
    # instead of one device write per matrix entry.
    cayley_menger = cpy.zeros((n_simplices, n_vertices + 1, n_vertices + 1))
    cayley_menger[:, 0, 1:] = 1
    cayley_menger[:, 1:, 0] = 1
    cayley_menger[:, 1:, 1:] = distance_matrix[idx[:, :, None], idx[:, None, :]]

    # linalg.inv inverts every matrix in the stack at once.
    inverse = cpy.linalg.inv(cayley_menger)

    circumradii = cpy.sqrt(inverse[:, 0, 0] / -2)
    barycentric = inverse[:, 1:, 0]  # (T, k) weights of the simplex vertices
    circumcentres = (barycentric[:, :, None] * X[idx]).sum(axis=1)
    return (circumradii, circumcentres)

# Plotting

In [5]:
%matplotlib notebook
# Plot the Delaunay complex with circumcentres and circumradii, but only for small
# 2-D / 3-D inputs (larger clouds are unreadable and slow to draw).
is_3d = X.shape[1] == 3
plot = X.shape[1] in (2, 3) and X.shape[0] < 40

# Group the simplices of the complex by vertex count (2 = edges, 3 = triangles, 4 = tetrahedra).
simplices_by_size = {}
for tau, _ in simplex_tree.get_simplices():
    simplices_by_size.setdefault(len(tau), []).append(tau)
lines = simplices_by_size.get(2, [])
triangles = simplices_by_size.get(3, [])
if is_3d:
    tetrahedrons = simplices_by_size.get(4, [])

# Colour of the circumcentre markers / radii for each simplex size.
PLOT_COLORS = {2: "black", 3: "red", 4: "blue"}


def get_centres(points):
    """Split an (n, d) CuPy/NumPy array into d host arrays (one per axis) for matplotlib."""
    host = cpy.asnumpy(points)
    return tuple(host[:, k] for k in range(host.shape[1]))


def plot_circumspheres(ax, simplices):
    """Mark each simplex's circumcentre and draw dashed radii to its vertices.

    The centres are passed explicitly, so nothing depends on a global
    variable being set before the call.
    """
    if len(simplices) == 0:
        return
    color = PLOT_COLORS[len(simplices[0])]
    _, centres = circumspheres(simplices, distance_matrix, X)
    ax.scatter(*get_centres(centres), color = color, marker = "o", facecolors = "none")

    # One device->host copy for all coordinates instead of one per simplex.
    points_host = cpy.asnumpy(X)
    centres_host = cpy.asnumpy(centres)
    for tau, c in zip(simplices, centres_host):
        for y in points_host[tau]:
            segment = [[c[k], y[k]] for k in range(len(c))]  # per-axis [start, end]
            ax.plot(*segment, "--", color = color, alpha = 0.2)


if plot:
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(projection = "3d") if is_3d else fig.add_subplot()

    plot_circumspheres(ax, lines)
    plot_circumspheres(ax, triangles)

    if is_3d:
        # the faces of a tetrahedron are triangles
        triangle_idx = np.asarray(triangles, dtype=int).reshape(-1, 3)
        plt_tri = a3.art3d.Poly3DCollection(cpy.asnumpy(X)[triangle_idx])
        plt_tri.set_alpha(0.1)
        plt_tri.set_color('grey')
        ax.add_collection3d(plt_tri)
        plot_circumspheres(ax, tetrahedrons)

    ax.scatter(*get_centres(X), s = 10, color = "black")
    ax.set_title("Delaunay complex: circumcentres and circumradii")
    ax.set_xlim(0,1)
    ax.set_ylim(0,1)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    if is_3d:
        ax.set_zlim(0,1)
        ax.set_zlabel("z")
    plt.show()

In [6]:
# A face is Gabriel when no data point lies strictly inside its circumsphere
# (points exactly on the sphere, including the face's own vertices, are allowed).
# Only Gabriel faces are re-inserted with their own circumradius below; other faces
# keep whatever value the insertion of their cofaces assigns them.
# `cpy.asnumpy` works whether X is currently a CuPy or a NumPy array.
kdtree = spatial.KDTree(cpy.asnumpy(X))
def is_gabriel(face, circumcentre, kdtree, rtol=1e-9):
    """Return True if no indexed point lies strictly inside the circumsphere of ``face``.

    Parameters
    ----------
    face : sequence of int
        Indices (into ``kdtree.data``) of the face's vertices.
    circumcentre : array-like, shape (d,)
        Host-side circumcentre of the face.
    kdtree : scipy.spatial.KDTree
        Tree built over the full point cloud.
    rtol : float, optional
        Relative shrink applied to the radius so that the face's own vertices and
        any cospherical points, which sit on the sphere up to rounding, are not
        counted as interior.

    Returns
    -------
    bool
        False for a non-finite (degenerate) circumcentre.

    Unlike a k-nearest-neighbour check, this does not depend on how ties between
    equidistant (cospherical) points are broken.
    """
    face = np.asarray(face, dtype=int)
    circumcentre = np.asarray(circumcentre, dtype=float)
    if not np.all(np.isfinite(circumcentre)):
        return False
    # The vertices are equidistant from the centre up to rounding; take the max.
    radius = np.linalg.norm(kdtree.data[face] - circumcentre, axis=1).max()
    inside = kdtree.query_ball_point(circumcentre, radius * (1.0 - rtol))
    # Face vertices appearing in `inside` (rounding) are harmless; anything else is not.
    return set(inside).issubset(face.tolist())
    
def alpha_complex(dimension, simplex_tree):
    """Assign circumradius filtration values to Gabriel simplices, highest dimension first.

    For each dimension from ``dimension`` down to 1, every simplex with
    ``dimension + 1`` vertices whose circumsphere is empty (``is_gabriel``) is
    re-inserted into ``simplex_tree`` with its circumradius as filtration value.
    NOTE(review): this relies on GUDHI's documented ``SimplexTree.insert``
    semantics (inserting also touches all faces and only lowers existing
    filtration values), so non-Gabriel faces inherit values from their cofaces.

    The filtration value is the circumradius r itself (not r**2).

    Reads the module-level ``distance_matrix``, ``X`` and ``kdtree``.

    Parameters
    ----------
    dimension : int
        Highest simplex dimension to process (ambient dimension of the points).
    simplex_tree : gudhi.SimplexTree
        Tree to update; modified in place.

    Returns
    -------
    gudhi.SimplexTree
        The same ``simplex_tree`` object.
    """
    for dim in range(dimension, 0, -1):
        faces = [ tau for tau,_ in simplex_tree.get_simplices() if len(tau) == ( dim+1 ) ]
        if not faces:
            continue  # nothing of this dimension (e.g. degenerate input)
        radii, centres = circumspheres(faces, distance_matrix, X)
        # One device->host transfer per dimension instead of one sync per face.
        radii_host = cpy.asnumpy(radii)
        centres_host = cpy.asnumpy(centres)
        for tau, radius, centre in zip(faces, radii_host, centres_host):
            if is_gabriel(tau, centre, kdtree):
                simplex_tree.insert(tau, filtration = float(radius))
    return simplex_tree
    

# Process every simplex dimension of the triangulation, from the ambient dimension
# down. alpha_complex updates `simplex_tree` in place and returns that same object,
# so `cech_complex` and `simplex_tree` refer to one tree.
ambient_dimension = X.shape[1]
cech_complex = alpha_complex(ambient_dimension, simplex_tree)