In [1]:
import jax 
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap
import torch
from scipy.spatial import cKDTree
from scipy.spatial import KDTree

##### Check for a GPU; fall back to the CPU if none is available

In [9]:
# Use the GPU backend when JAX can see one; otherwise fall back to the CPU.
# NOTE: `device.platform` is the backend name ("gpu" / "cpu"), whereas
# `device.device_kind` is a hardware description (e.g. "Tesla V100-SXM2-16GB"),
# so comparing `device_kind` with "gpu" would never detect a GPU.
backend = "gpu" if any(device.platform == "gpu" for device in jax.devices()) else "cpu"
jax.config.update("jax_platform_name", backend)
print(f"Running on {backend.upper()}:", jax.devices(backend)[0])

Running on CPU: TFRT_CPU_0


# Preprocess function $\color{green} ✅$ 

In [3]:
# Load the point cloud used in this notebook (one point per row of the text file).
# np.loadtxt reads exactly ONE file: further positional arguments would be taken
# as `dtype`, `comments`, ... rather than extra paths, so load additional clouds
# with separate np.loadtxt calls instead of listing their paths here.
DATA_PATH = "./Data/Full_point_clouds/bunny_order3_normal_beta.txt"
datasets = np.loadtxt(DATA_PATH)
print(f"Points in dataset: {len(datasets)}")

Points in dataset: 35947


In [4]:
@jit
def preprocess(points):
    """Centre a point cloud on its centroid and rescale it to roughly unit size.

    The points are translated so their mean sits at the origin, then divided by
    half the Euclidean length of their axis-aligned bounding-box diagonal.

    Args:
        points: array of shape (N, D), one point per row (reductions run over axis 0).

    Returns:
        Array of the same shape with the normalised points. For a degenerate cloud
        whose bounding-box diagonal is zero (e.g. a single point, or all points
        identical) the result is all zeros instead of NaNs.
    """
    mean_p = points.mean(axis=0)
    min_p, max_p = jnp.min(points, axis=0), jnp.max(points, axis=0)
    bbdiag = jnp.linalg.norm(max_p - min_p, ord=2)  # Bounding-box diagonal (Euclidean length)
    # A zero diagonal would make the division below 0/0 = NaN. In that case every
    # centred point is already zero, so any non-zero scale works; use 1.0.
    # jnp.where keeps this traceable under jit (no Python-level branch on a tracer).
    scale = jnp.where(bbdiag > 0, 0.5 * bbdiag, 1.0)
    return (points - mean_p) / scale


In [5]:
# Leihui's NumPy reference implementation, used for speed and accuracy comparison
def preprocessL(points):
    """NumPy reference normalisation, kept unchanged in behaviour for comparison.

    Subtracts the per-column mean of ``points`` and divides by half the Euclidean
    length of the axis-aligned bounding-box diagonal, giving a roughly unit-sized
    cloud centred on the origin. No guard for a zero diagonal: a degenerate cloud
    yields inf/NaN, exactly as the original reference did.
    """
    lowest = points.min(0)
    highest = points.max(0)
    half_diagonal = 0.5 * float(np.linalg.norm(highest - lowest, 2))
    centred = points - points.mean(0)
    return centred / half_diagonal

In [6]:
rawpoints = datasets[:,0:3]

%timeit our = preprocess(rawpoints)
%timeit lei = preprocessL(rawpoints)

453 μs ± 23.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
2.75 ms ± 69.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
our = preprocess(rawpoints)
lei = preprocessL(rawpoints)

# JAX computes in float32 unless `jax_enable_x64` is enabled, while the NumPy
# reference works in float64, so compare with a tolerance rather than exactly.
# jnp.allclose checks |lei - our| <= atol + rtol * |our| using atol=1e-6 and the
# default rtol=1e-5 (so it is not a pure 1e-6 absolute tolerance).
matches = bool(jnp.allclose(lei, our, atol=1e-6))
# Fail loudly under Restart & Run All instead of silently printing False.
assert matches, "JAX preprocess() disagrees with the NumPy reference preprocessL()"
print(matches)

True


# ProcessPart function