In [1]:
import jax 
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap
import torch
from scipy.spatial import cKDTree


##### Check for a GPU; fall back to the CPU if none is available

In [2]:
# Prefer a GPU backend when one is present, otherwise fall back to the CPU.
# Note: `Device.device_kind` is a hardware description (e.g. a GPU's model
# name, or "cpu"), never the literal string "gpu", so comparing it to "gpu"
# can never select the GPU. Probe the GPU backend explicitly instead.
def gpu_available():
    """Return True if JAX can see at least one GPU device."""
    try:
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        # Raised when no GPU backend is installed or it failed to initialise.
        return False


if gpu_available():
    jax.config.update("jax_platform_name", "gpu")
    print("Running on GPU:", jax.devices("gpu")[0])
# Use CPU if there is no GPU device available
else:
    jax.config.update("jax_platform_name", "cpu")
    print("Running on CPU:", jax.devices("cpu")[0])

Running on CPU: TFRT_CPU_0


##### Loading datasets to process

In [3]:
# Point cloud to process; other files from "./Data/Full point clouds/" can be
# swapped in here. `datasets` is a single array with one row per point.
DATA_FILE = "./Data/Full point clouds/bunny_order3_normal_beta.txt"
datasets = np.loadtxt(DATA_FILE)
print(f"Points in dataset: {len(datasets)}")

Points in dataset: 35947


In [4]:
@jit
def preprocess(points):
    """Centre a point cloud on its mean and scale by half its bounding-box diagonal.

    JAX (jitted) counterpart of `preprocessL`.

    Args:
        points: array of coordinates, one row per point.

    Returns:
        Array of the same shape whose bounding-box diagonal has length 2.
        A cloud with zero extent divides by zero (yields NaN/inf).
    """
    lower = jnp.min(points, axis=0)
    upper = jnp.max(points, axis=0)
    # Euclidean length of the axis-aligned bounding-box diagonal, halved.
    half_diagonal = 0.5 * jnp.linalg.norm(upper - lower, ord=2)
    centroid = points.mean(axis=0)
    return (points - centroid) / half_diagonal


In [5]:
# Leihui's reference implementation (plain NumPy), kept to compare speed and
# results against the jitted `preprocess`.
def preprocessL(points):
    """Centre on the mean and divide by half the bounding-box diagonal.

    The result has a bounding-box diagonal of length 2. Because centring uses
    the mean rather than the box centre, points are not guaranteed to lie in
    the unit sphere (only within radius 2).
    """
    extent = points.max(0) - points.min(0)
    half_diag = 0.5 * float(np.linalg.norm(extent, 2))
    centered = points - points.mean(0)
    return centered / half_diag

##### Speed comparison of `preprocess` (JAX) and `preprocessL` (NumPy)

In [6]:
# Keep only the xyz coordinates (first three columns) of each point.
rawpoints = datasets[:, 0:3]

# Timing (IPython magics; uncomment to run). JAX dispatches asynchronously, so
# block on the result or %timeit measures only dispatch. The first call of the
# jitted function also includes compilation, so warm it up before timing.
# %timeit preprocess(rawpoints).block_until_ready()
# %timeit preprocessL(rawpoints)

In [None]:
lei = preprocessL(rawpoints)
our = preprocess(rawpoints)

# Elementwise check |lei - our| <= atol + rtol * |our|. rtol=1e-5 is the
# default; it is written out so the actual tolerance is visible.
print(jnp.allclose(lei, our, rtol=1e-5, atol=1e-6))

#### Patch-processing functions

In [8]:
def pca_points(patch_points):
    '''
    Rotate a patch of points into its principal-component frame.

    The principal axes are computed from the mean-centred patch, but the
    rotation is applied about the origin, so a point at (0, 0, 0) (the query
    point, when the caller has centred the patch on it) stays at the origin.

    Args:
        patch_points: (N, 3) tensor of xyz points, one row per point.

    Returns:
        patch_points: (N, 3) tensor of the points expressed in the PCA frame
            (equal to the input points multiplied by `trans`).
        trans: tensor whose columns are the principal axes, ordered by
            decreasing singular value ((3, 3) when N >= 3).
    '''
    # Centre the patch on its mean so the SVD captures its spread.
    pts_mean = patch_points.mean(0)
    centered = patch_points - pts_mean
    # Left singular vectors of the (3, N) matrix are the principal axes.
    # torch.linalg.svd(A, full_matrices=False) is the documented replacement
    # for the deprecated torch.svd(A) (some=True); U is the same.
    trans, _, _ = torch.linalg.svd(centered.t(), full_matrices=False)
    patch_points = torch.mm(centered, trans)
    # The original centre point (origin) sits at -pts_mean in centred coords;
    # rotate it too, then shift so it returns to the origin.
    cp_new = torch.matmul(-pts_mean, trans)
    patch_points = patch_points - cp_new
    return patch_points, trans


In [40]:

def save_neighborhood_to_txt(patch_points, filename="neighborhood.txt"):
    """Write a patch of points to a space-delimited text file.

    Args:
        patch_points: torch tensor (or any array-like) with one row per point.
        filename: output path; an existing file is overwritten.

    Values are written with six decimal places.
    """
    # A bare .numpy() fails on tensors that require grad or live on a GPU,
    # and on inputs that are not tensors at all.
    if torch.is_tensor(patch_points):
        array = patch_points.detach().cpu().numpy()
    else:
        array = np.asarray(patch_points)
    np.savetxt(filename, array, fmt="%.6f", delimiter=" ")
    print(f"Saved neighborhood to {filename}")
    
    
# Leihui's patch-extraction code, modified to optionally save the neighbourhood
# to a text file.
def processPartL(kdtree, index, points, searchK, save_to_file=False,
                 filename="neighborhood.txt"):
    """Extract the k-nearest-neighbour patch around one point and PCA-align it.

    Args:
        kdtree: spatial index built over `points` supporting
            ``query(x, k=...)`` (e.g. scipy.spatial.KDTree).
        index: row of `points` used as the query point / patch centre.
        points: NumPy array of xyz coordinates, one row per point.
        searchK: number of neighbours to gather (the query point itself is
            normally the nearest one).
        save_to_file: if True, write the centred, un-rotated patch to `filename`.
        filename: output path used when `save_to_file` is True.

    Returns:
        (patch, trans, rad): the PCA-aligned patch transposed to (3, searchK),
        the rotation returned by `pca_points`, and the distance to the
        farthest neighbour.
    """
    query = points[index, :]
    point_distances, patch_point_inds = kdtree.query(query, k=searchK)
    # SciPy returns scalars (not arrays) when k == 1; normalise so the max and
    # fancy indexing below behave the same for every k.
    point_distances = np.atleast_1d(point_distances)
    patch_point_inds = np.atleast_1d(patch_point_inds)
    rad = point_distances.max()
    patch_points = torch.from_numpy(points[patch_point_inds, :])

    # Centre the patch on the query point. The patch is NOT scaled by `rad`;
    # `rad` is returned so callers can normalise if they need to.
    patch_points = patch_points - torch.from_numpy(query)

    if save_to_file:
        save_neighborhood_to_txt(patch_points, filename)  # rows are points; no transpose

    patch_points, trans = pca_points(patch_points)
    return torch.transpose(patch_points, 0, 1), trans, rad

In [53]:
from scipy.spatial import KDTree
import numpy as np

# Preprocessed bunny cloud (`our`, from the parity-check cell) as a NumPy array.
our_n = np.array(our.block_until_ready())

# Query point and neighbourhood size.
index = 1000   # row of our_n used as the patch centre
searchK = 20   # neighbours to gather; the query point itself is the nearest

kdtree = KDTree(our_n)

# Extract the neighbourhood, save it to neighborhood.txt, and PCA-align it.
processPartL(kdtree, index, our_n, searchK, save_to_file=True)


Saved neighborhood to neighborhood.txt


(tensor([[ 0.0000, -0.0075,  0.0075, -0.0113, -0.0038,  0.0038, -0.0151,  0.0150,
           0.0114, -0.0037,  0.0036, -0.0188,  0.0189, -0.0114, -0.0227,  0.0225,
           0.0112, -0.0263, -0.0149,  0.0266],
         [ 0.0000,  0.0033, -0.0040, -0.0104, -0.0157,  0.0157,  0.0059, -0.0079,
           0.0131,  0.0183, -0.0189, -0.0071,  0.0098,  0.0195,  0.0092, -0.0119,
          -0.0229, -0.0038, -0.0227,  0.0086],
         [ 0.0000, -0.0002,  0.0007, -0.0012,  0.0002, -0.0011,  0.0007,  0.0011,
          -0.0015,  0.0001,  0.0007, -0.0015, -0.0021,  0.0016,  0.0008,  0.0016,
           0.0010, -0.0013, -0.0042, -0.0036]]),
 tensor([[ 0.9539, -0.2405, -0.1793],
         [-0.2996, -0.7321, -0.6119],
         [-0.0159, -0.6374,  0.7704]]),
 np.float64(0.028196911705704875))