In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, Normalize
from matplotlib.cm import ScalarMappable

from tqdm import tqdm
import os
import sys
sys.path.append('../')
from PIL import Image
import cv2
import open3d as o3d

from skeletor.skeleton import Octree
from skeletor.data import loadTestDataset, loadPointCloud, plotTestDatasets, TEST_DATASETS_2D, TEST_DATASETS_3D, printTestDatasets

import robust_laplacian

from scipy.signal import convolve
from scipy.spatial import KDTree
from scipy.spatial.transform import Rotation

import scipy.sparse as sparse
import scipy.sparse.linalg as sla

from pepe.topology import spatialClusterLabels

In [None]:
def removeDenseClusters(points, radius=5, removeFraction=2.):
    """
    Remove points that sit in unusually dense regions of a point cloud.

    The neighborhood radius is expressed in units of the cloud's mean
    nearest-neighbor distance, so the filter does not depend on the absolute
    scale of the data. A point is kept only if its neighbor count (itself
    included) is strictly below ``removeFraction`` times the median neighbor
    count over all points.

    Parameters
    ----------
    points : array_like, shape (N, d)
        Point coordinates.
    radius : float
        Neighborhood radius, as a multiple of the mean nearest-neighbor distance.
    removeFraction : float
        Points whose neighbor count is at least ``removeFraction * median``
        are removed.

    Returns
    -------
    numpy.ndarray, shape (M, d)
        The retained points, in their original order (M <= N).
    """
    points = np.asarray(points)
    # With fewer than two points there is no nearest neighbor and no density
    # to compare against, so nothing can be classified as a dense cluster.
    if len(points) < 2:
        return points.copy()

    kdTree = KDTree(points)
    # k=2: column 0 is the query point itself (distance 0), column 1 its nearest neighbor.
    nnDistances, _ = kdTree.query(points, 2)
    avgNNDistance = np.mean(nnDistances[:, 1])

    # Only the neighbor counts are needed, so ask for lengths directly instead
    # of building every neighbor list.
    numNeighbors = kdTree.query_ball_point(points, avgNNDistance * radius, return_length=True)

    return points[numNeighbors < np.median(numNeighbors) * removeFraction]

In [None]:
# Load the medial-axis skeleton at full resolution, then drop points lying in
# over-dense clumps (neighbor count >= 4x the median neighbor count).
skeletonPoints = loadPointCloud(
    '../medial_axis_2024-10-09_LG_A_PNG_T4.0_clean.npy',
    downsample=1,
)
cleanedPoints = removeDenseClusters(skeletonPoints, radius=5, removeFraction=4)
print(cleanedPoints.shape[0])

In [None]:
# Overlay: full skeleton in black, points kept by removeDenseClusters in red.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(skeletonPoints))
pcd2 = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(cleanedPoints))

pcd.paint_uniform_color((0, 0, 0))
pcd2.paint_uniform_color((1, 0, 0))

o3d.visualization.draw_geometries([pcd, pcd2])

In [None]:
def contractPoints(points, referencePoints=None, pointMasses=None, attraction=1.0, contraction=0.5):
    """
    Perform one step of Laplacian contraction on a set of points, optionally in
    relation to a set of reference points (points that attract the others but
    are themselves held almost fixed).

    Each coordinate axis of the result is the least-squares solution x of

        [ L @ WL ]         [      0      ]
        [   WH   ] @ x  ~  [ WH @ p_axis ]

    solved through the normal equations, where L is the point-cloud Laplacian
    from ``robust_laplacian``, WL holds the contraction weights and WH the
    attraction weights.

    Parameters
    ----------
    points : array_like, shape (N, dim) with dim 2 or 3
        Points to contract. 2D input is embedded at z = 0 to build the
        Laplacian and only the two input axes are solved.
    referencePoints : array_like, shape (K, dim), optional
        Extra points included in the Laplacian. Their attraction weight is
        1e6 times larger, so they barely move; they are not returned.
    pointMasses : array_like, shape (N,), optional
        Per-point multiplier on ``attraction`` for ``points``.
    attraction : float
        Weight pulling each point toward its current position.
    contraction : float
        Weight on the Laplacian (contraction) term.

    Returns
    -------
    numpy.ndarray, shape (N, dim)
        Contracted positions of ``points``. If the solve yields only NaNs
        (singular system), a copy of the input points is returned instead.
    """
    points = np.asarray(points, dtype=float)
    numPoints = len(points)
    dim = points.shape[-1]

    hasReference = hasattr(referencePoints, '__iter__')
    if hasReference:
        referencePoints = np.asarray(referencePoints, dtype=float)
        allPoints = np.concatenate((points, referencePoints))
    else:
        allPoints = points

    # NOTE(review): robust_laplacian is assumed to require 3D positions, so 2D
    # data is embedded in the z = 0 plane for the Laplacian only.
    if dim == 2:
        laplacianPoints = np.column_stack((allPoints, np.zeros(len(allPoints))))
    else:
        laplacianPoints = allPoints

    # Compute the laplacian and mass matrix
    L, M = robust_laplacian.point_cloud_laplacian(laplacianPoints)

    # Attraction weights (WH): positive for every point, optionally scaled per
    # point by pointMasses. Reference points get a very large weight so they
    # stay (almost) at their original positions.
    if hasattr(pointMasses, '__iter__'):
        pointAttraction = attraction * np.asarray(pointMasses, dtype=float)
    else:
        pointAttraction = attraction * np.ones(numPoints)
    if hasReference:
        referenceAttraction = attraction * np.ones(len(referencePoints)) * 1e6
        attractionWeights = np.concatenate((pointAttraction, referenceAttraction))
    else:
        attractionWeights = pointAttraction

    # Contraction weights (WL): uniform, scaled by the sqrt of the mean of the
    # mass-matrix diagonal (empirical scaling kept from the original code).
    contractionWeights = contraction * 1e3 * np.sqrt(np.mean(M.diagonal())) * np.ones(len(allPoints))

    WH = sparse.diags(attractionWeights)
    WL = sparse.diags(contractionWeights)

    A = sparse.vstack([L.dot(WL), WH]).tocsc()
    # One right-hand-side column per solved axis (previously hard-coded to 3,
    # which broke 2D input).
    b = np.vstack([np.zeros((len(allPoints), dim)), WH.dot(allPoints)])

    # Normal equations; each axis is solved independently.
    AtA = A.T @ A
    solvedAxes = [sla.spsolve(AtA, A.T @ b[:, i], permc_spec='COLAMD') for i in range(dim)]
    ret = np.column_stack(solvedAxes)

    # spsolve fills the solution with NaN when the matrix is exactly singular;
    # in that case leave the points where they were.
    if np.isnan(ret).all():
        return points.copy()

    # Reference points (if any) were appended after `points`; drop them.
    return ret[:numPoints]

In [None]:
# Number of successive contraction steps to apply to the cleaned skeleton.
NUM_CONTRACTION_STEPS = 10

originalPoints = contractPoints(cleanedPoints, attraction=5000, contraction=1)

# Keep every intermediate iterate (starting with the cleaned points) and stack
# them once at the end, instead of re-concatenating a growing array each step.
contractionHistory = [cleanedPoints]
adjustedSkeletonPoints = np.copy(cleanedPoints)
for _ in range(NUM_CONTRACTION_STEPS):
    adjustedSkeletonPoints = contractPoints(adjustedSkeletonPoints, attraction=1000, contraction=1)
    contractionHistory.append(adjustedSkeletonPoints)

allPoints = np.concatenate(contractionHistory)

In [None]:
# Directory holding the raw point clouds. Set the POINT_CLOUD_DIR environment
# variable to override the author-specific default path.
POINT_CLOUD_DIR = os.environ.get('POINT_CLOUD_DIR', '/home/jack/Workspaces/data/point_clouds')
referencePoints = loadPointCloud(os.path.join(POINT_CLOUD_DIR, '2024-10-21_LG_A_PNG_T4.0.npy'), downsample=50)

# NOTE: this overwrites adjustedSkeletonPoints from the iterative-contraction cell.
adjustedSkeletonPoints = contractPoints(cleanedPoints, referencePoints, attraction=50, contraction=0.5)

In [None]:
# Overlay: cleaned skeleton in black, contracted points (adjustedSkeletonPoints) in red.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(cleanedPoints))
pcd2 = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(adjustedSkeletonPoints))

pcd.paint_uniform_color((0, 0, 0))
pcd2.paint_uniform_color((1, 0, 0))

o3d.visualization.draw_geometries([pcd, pcd2])

In [None]:
# Show only `pcd` on its own, without the red overlay.
o3d.visualization.draw_geometries(geometry_list=[pcd])

In [None]:
# Make the sibling skeletor checkout importable; the guard keeps re-runs of
# this cell from appending the same entry to sys.path over and over.
# NOTE(review): skeletor.skeleton was already imported in the first cell, so
# Python reuses the cached module from sys.modules and this (appended, hence
# lowest-priority) path entry likely has no effect on which
# OctreeContractionSkeleton is loaded. Confirm which checkout is intended.
SKELETOR_PATH = '../../skeletor'
if SKELETOR_PATH not in sys.path:
    sys.path.append(SKELETOR_PATH)
from skeletor.skeleton import OctreeContractionSkeleton

skeleton = OctreeContractionSkeleton(cleanedPoints, 1000, verbose=True)
contractedPoints = skeleton.contractPointCloud(iterations=10, attraction=500, contraction=1)
print(len(contractedPoints))

In [None]:
# Overlay: cleaned skeleton in black, octree-contraction output in red.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(cleanedPoints))
pcd2 = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(contractedPoints))

pcd.paint_uniform_color((0, 0, 0))
pcd2.paint_uniform_color((1, 0, 0))

o3d.visualization.draw_geometries([pcd, pcd2])