In [16]:
import jax 
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap
import torch
from scipy.spatial import cKDTree


##### Check for a GPU; fall back to the CPU if none is available

In [17]:
# Check if a GPU is available.
# NOTE: `Device.device_kind` is a hardware model name (e.g. a GPU product string),
# never the literal "gpu", so asking JAX for the "gpu" backend directly is the
# reliable test. JAX raises RuntimeError when no GPU backend/device is present.
try:
    gpu_devices = jax.devices("gpu")
except RuntimeError:
    gpu_devices = []

if gpu_devices:
    jax.config.update("jax_platform_name", "gpu")
    print("Running on GPU:", gpu_devices[0])
else:
    # Use CPU if there is no GPU device available
    jax.config.update("jax_platform_name", "cpu")
    print("Running on CPU:", jax.devices("cpu")[0])

Running on CPU: TFRT_CPU_0


##### Load the point-cloud dataset to process

In [18]:
# Point cloud to process. To run on another shape, point this at a different
# file in "./Data/Full point clouds/".
DATA_PATH = "./Data/Full point clouds/bunny_order3_normal_beta.txt"

# ndmin=2 keeps the result 2-D even for a single-row file, so the column
# slicing used later (datasets[:, 0:3]) always works.
datasets = np.loadtxt(DATA_PATH, ndmin=2)
assert datasets.shape[1] >= 3, f"expected at least 3 columns (x, y, z), got {datasets.shape[1]}"
print(f"Points in dataset: {len(datasets)}")

Points in dataset: 35947


In [19]:
@jit
def preprocess(points):
    """Normalize a point cloud: center it on its centroid and scale by half the bbox diagonal.

    Args:
        points: array of points, one point per row (reductions run over axis 0),
            e.g. the (N, 3) xyz columns of the dataset.

    Returns:
        Array of the same shape equal to
        ``(points - centroid) / (0.5 * ||max(points) - min(points)||_2)``,
        where max/min are taken per axis (the axis-aligned bounding box).
        Note: because the shift uses the centroid rather than the bbox center,
        points are not guaranteed to fall inside the unit sphere.
    """
    lower = jnp.min(points, axis=0)
    upper = jnp.max(points, axis=0)
    # Euclidean length of the bounding-box diagonal, halved.
    half_diagonal = 0.5 * jnp.linalg.norm(upper - lower, ord=2)
    centered = points - points.mean(axis=0)
    return centered / half_diagonal


In [20]:
# Reference NumPy implementation (Leihui's code), kept to compare against the JAX `preprocess`.
def preprocessL(points):
    """NumPy reference normalization: center on the centroid, divide by half the bbox diagonal.

    Args:
        points: NumPy array with one point per row (reductions run over axis 0).

    Returns:
        New array of the same shape. As with `preprocess`, the centroid is not
        the bbox center, so the result is only roughly unit-scaled, not strictly
        inside the unit sphere.
    """
    extent = points.max(0) - points.min(0)
    half_diag = 0.5 * float(np.linalg.norm(extent, 2))
    return (points - points.mean(0)) / half_diag

##### Speed comparison: `preprocess` (JAX) vs. `preprocessL` (NumPy)

In [21]:
rawpoints = datasets[:,0:3]

%timeit our = preprocess(rawpoints)
%timeit lei = preprocessL(rawpoints)

AttributeError: module 'jax._src.linear_util' has no attribute 'transformation2'

In [12]:
our = preprocess(rawpoints)
lei = preprocessL(rawpoints)

# JAX computes in float32 unless jax_enable_x64 is set, while the NumPy reference
# is float64, so the results agree only up to round-off. The check is
# |our - lei| <= atol + rtol * |lei| with atol=1e-6, rtol=1e-5, evaluated in float64
# (jnp.allclose would first cast the float64 reference down to float32).
np.testing.assert_allclose(np.asarray(our), lei, rtol=1e-5, atol=1e-6)
print("preprocess matches preprocessL within tolerance")

ModuleNotFoundError: No module named 'jax.experimental.attrs'

#### Patch-processing functions

In [13]:
def pca_points(patch_points):
    '''Rotate a patch of points into its PCA frame.

    The principal axes are the left singular vectors of the mean-centered
    points (arranged as a D x N matrix). The points are expressed in that frame
    while the original origin stays fixed, so the result equals
    ``patch_points @ trans``. A patch that was centered on its query point
    therefore keeps the query point at the origin.

    Args:
        patch_points: (N, D) torch tensor, one point per row (D = 3 for xyz).

    Returns:
        patch_points: (N, D) tensor of the points after aligning with PCA.
        trans: (D, D) orthogonal tensor whose columns are the principal axes,
            ordered by decreasing singular value.
    '''
    # center the patch around the mean to find its principal axes
    pts_mean = patch_points.mean(0)
    centered = patch_points - pts_mean
    # torch.svd is deprecated; full_matrices=True keeps `trans` D x D even when
    # the patch has fewer than D points (the reduced U would be D x N).
    trans, _, _ = torch.linalg.svd(centered.t(), full_matrices=True)
    rotated = torch.mm(centered, trans)
    # the original origin sits at -pts_mean relative to the mean; rotate that offset too
    cp_new = torch.matmul(-pts_mean, trans)
    # re-center on the original origin
    return rotated - cp_new, trans


In [14]:

def save_neighborhood_to_txt(patch_points, filename="neighborhood.txt"):
    """Write a patch of points to a space-separated text file, one point per row.

    Args:
        patch_points: torch tensor (or NumPy array-like) of points, one per row.
        filename: output path; an existing file is overwritten.
    """
    if isinstance(patch_points, torch.Tensor):
        # Tensor.numpy() raises for tensors that require grad or live on a GPU.
        patch_points = patch_points.detach().cpu().numpy()
    np.savetxt(filename, patch_points, fmt="%.6f", delimiter=" ")
    print(f"Saved neighborhood to {filename}")
    
    
## Leihui's code, modified to optionally save the neighborhood to a file
    
def processPartL(kdtree, index, points, searchK, save_to_file=False, filename="neighborhood.txt"):
    """Extract the k-nearest-neighbor patch around one point and align it with PCA.

    Args:
        kdtree: spatial index over `points`, queried via ``kdtree.query(point, k=searchK)``.
        index: row of `points` used as the query point.
        points: NumPy array with one xyz point per row.
        searchK: number of nearest neighbors that form the patch.
        save_to_file: if True, write the centered (not yet PCA-rotated) patch to `filename`.
        filename: output path used when `save_to_file` is True.

    Returns:
        (patch, trans, rad):
            patch: PCA-aligned patch, transposed so rows are coordinates (3 x searchK for xyz).
            trans: PCA rotation returned by `pca_points`.
            rad: distance from the query point to its farthest neighbor in the patch.
    """
    query = points[index, :]
    point_distances, patch_point_inds = kdtree.query(query, k=searchK)
    # With k=1 the KD-tree returns scalars instead of arrays; normalize to 1-D.
    point_distances = np.atleast_1d(point_distances)
    patch_point_inds = np.atleast_1d(patch_point_inds)
    rad = np.max(point_distances)
    patch_points = torch.from_numpy(points[patch_point_inds, :])

    # center the points around the query point
    patch_points = patch_points - torch.from_numpy(query)
    # NOTE(review): scaling the patch to the unit sphere (patch_points / rad) is
    # currently disabled — confirm this is intended.

    if save_to_file:
        # saved before the PCA rotation, one point per row (no transpose needed)
        save_neighborhood_to_txt(patch_points, filename)

    patch_points, trans = pca_points(patch_points)
    return torch.transpose(patch_points, 0, 1), trans, rad

In [15]:
from scipy.spatial import KDTree
import numpy as np

# Build a KD-tree over the normalized point cloud `our` (output of `preprocess`).
# np.array makes a writable host copy of the JAX array.
our_n = np.array(our.block_until_ready())
kdtree = KDTree(our_n)

# Select a query point and its neighborhood size
index = 1000  # Query point index
searchK = 20  # Number of neighbors

# Fail early with a clear message instead of an IndexError inside processPartL
# (a KD-tree query for more neighbors than points pads with out-of-range indices).
assert 0 <= index < len(our_n), f"index {index} is out of range for {len(our_n)} points"
assert 1 <= searchK <= len(our_n), f"searchK must be between 1 and {len(our_n)}, got {searchK}"

# Extract the neighborhood, save it to neighborhood.txt, and PCA-align it
processPartL(kdtree, index, our_n, searchK, save_to_file=True)


NameError: name 'our' is not defined