In [2]:
def xyz_projected(pos, normal, x_basis, y_basis, edge_index, k=None):
    """Express every k-nn edge in the local tangent frame of its source point.

    For an edge (i -> j) the offset ``pos[j] - pos[i]`` is split into a
    component along ``normal[i]`` (the returned height) and an in-plane
    remainder, which is then written in (``x_basis[i]``, ``y_basis[i]``)
    coordinates.

    Args:
        pos (ndarray): [N, 3] point positions.
        normal (ndarray): [N, 3] per-point normals.
        x_basis (ndarray): [N, 3] per-point tangent x axis.
        y_basis (ndarray): [N, 3] per-point tangent y axis.
        edge_index (ndarray): [2, N * k] k-nn graph as (row, col). Edges are
            assumed grouped by row, k consecutive edges per point, in point
            order (the frame repetition below relies on this layout).
        k (int, optional): neighbours per point. If None, it is taken as the
            number of edges whose row is 0.

    Returns:
        coords (ndarray): [N * k, 2] in-plane (x, y) coordinates per edge.
        z_pos: per-edge height along the normal, exactly as returned by
            ``batch_dot`` (presumably [N * k, 1] given how it broadcasts
            against the [N * k, 3] normals — confirm against pcdiff).
    """
    src, dst = edge_index
    if k is None:
        k = (src == 0).sum()

    def per_edge(frame):
        # Row i*k + j of the result is frame[i], so it lines up with edge j of point i.
        return np.repeat(np.asarray(frame), k, axis=0).reshape(-1, 3)

    n_edge = per_edge(normal)
    ex_edge = per_edge(x_basis)
    ey_edge = per_edge(y_basis)

    offset = pos[dst] - pos[src]
    height = batch_dot(offset, n_edge)
    # Drop the normal component to keep only the tangent-plane part of the offset.
    in_plane = offset - n_edge * height

    u = batch_dot(in_plane, ex_edge).flatten()
    v = batch_dot(in_plane, ey_edge).flatten()
    return np.stack([u, v], axis=1), height

In [3]:
import numpy as np
from numpy import linalg as LA
from pcdiff import knn_graph, estimate_basis, build_grad_div, laplacian, coords_projected, gaussian_weights, weighted_least_squares, batch_dot

def grad_curvature(pos, k, kernel_width=1, regularizer=1e-8, shape_regularizer=None):
    """Fit a local quadratic height field at every point and differentiate it.

    Each point's k-nn neighbourhood is projected into the point's tangent
    frame (``xyz_projected``). A weighted least-squares fit of the neighbours'
    heights along the normal then gives, per point,

        f(x, y) = c0 + c1*x + c2*y + c3*x**2 + c4*x*y + c5*y**2

    NOTE(review): this coefficient order is the one the derivative formulas
    below assume for the 6 columns produced by ``weighted_least_squares``.
    Confirm it against pcdiff.

    Args:
        pos: [N, 3] point positions.
        k: number of neighbours per point.
        kernel_width: width passed to ``gaussian_weights``.
        regularizer: regularizer passed to ``weighted_least_squares``.
        shape_regularizer: optional shape regularizer passed to
            ``weighted_least_squares``. When given, the extra operator it
            returns is not used here.

    Returns:
        grad: [N, k, 2] gradient (df/dx, df/dy) of each point's fitted f,
            evaluated at the projected coordinates of each of its k neighbours.
        curvature: [N] trace of the Hessian of f, f_xx + f_yy. It does not
            depend on the in-plane orientation of the tangent basis. Under a
            small-slope assumption (c1, c2 ~ 0) it is twice the mean
            curvature, with sign set by the normal orientation.
        edge_index: the k-nn graph from ``knn_graph`` as (row, col).
    """
    edge_index = knn_graph(pos, k)
    row, col = edge_index

    normal, x_basis, y_basis = estimate_basis(pos, edge_index)
    coords, z_pos = xyz_projected(pos, normal, x_basis, y_basis, edge_index, k)

    dist = LA.norm(pos[col] - pos[row], axis=1)
    weights = gaussian_weights(dist, k, kernel_width)

    if shape_regularizer is None:
        wls = weighted_least_squares(coords, weights, k, regularizer)
    else:
        # The shape-regularized operator is not needed for these quantities.
        wls, _ = weighted_least_squares(coords, weights, k, regularizer, shape_regularizer)

    # Contract the per-edge WLS rows with the neighbours' heights and sum over
    # each point's k edges to get its six polynomial coefficients: [N, 6].
    coeffs = (wls * z_pos).reshape(-1, k, 6).sum(axis=1)

    # Second derivatives of a quadratic are constant over the patch.
    f_xx = 2 * coeffs[:, 3]
    f_yy = 2 * coeffs[:, 5]

    # First derivatives vary with position. Evaluate them at every edge's
    # projected coordinates using the coefficients of the edge's source point.
    edge_coeffs = coeffs[row]
    x, y = coords[:, 0], coords[:, 1]
    grad_x = edge_coeffs[:, 1] + 2 * edge_coeffs[:, 3] * x + edge_coeffs[:, 4] * y
    grad_y = edge_coeffs[:, 2] + 2 * edge_coeffs[:, 5] * y + edge_coeffs[:, 4] * x
    grad = np.column_stack((grad_x, grad_y))

    # Use the Hessian trace. Do not add 2*f_xy: f_xx + 2*f_xy + f_yy is the
    # second derivative along the tangent direction (1, 1), which changes with
    # the arbitrary choice of x_basis/y_basis and is therefore not a curvature.
    curvature = f_xx + f_yy

    return grad.reshape(-1, k, 2), curvature, edge_index

In [4]:
import torch
import jax
import jax.numpy as jnp

def L2norm_nbh(data, comparison_size, origin_index=0):
    """Euclidean distance from one reference entry to every entry, per neighbourhood.

    Args:
        data: array-like or ``torch.Tensor`` of shape
            (batch_size, num_neighbors, num_features).
        comparison_size: number of leading neighbours kept
            (``data[:, :comparison_size]``). None keeps all of them.
        origin_index: index, within the kept neighbours, of the reference
            entry. Negative indices count from the end, as in NumPy.

    Returns:
        jax.Array of shape (batch_size, kept_neighbors) holding
        ``||data[b, i] - data[b, origin_index]||_2``.

    Raises:
        ValueError: if ``origin_index`` is outside the kept neighbours. JAX
            clamps out-of-bounds indices instead of raising, so without this
            check a bad index would silently pick the last kept neighbour.
    """
    if isinstance(data, torch.Tensor):
        # JAX cannot consume torch tensors directly; go through a host NumPy copy.
        data = data.detach().cpu().numpy()

    neighborhood = jnp.asarray(data)[:, :comparison_size, :]

    n_kept = neighborhood.shape[1]
    if not -n_kept <= origin_index < n_kept:
        raise ValueError(
            f"origin_index={origin_index} is out of range for the {n_kept} "
            f"neighbours kept by comparison_size={comparison_size}"
        )

    # Reference entry of each neighbourhood: (batch_size, num_features).
    origin = neighborhood[:, origin_index, :]

    return jnp.linalg.norm(neighborhood - origin[:, None, :], axis=2)

In [None]:
def get_nn_data(pos, k, comparison_size, kernel_width=1, regularizer=1e-8,
                shape_regularizer=None, verbose=True):
    """Build per-neighbourhood distance features and curvature for a point cloud.

    Args:
        pos: [N, 3] point positions.
        k: neighbours per point, for both the k-nn graph and the local fit.
        comparison_size: number of leading neighbours kept per neighbourhood
            in the distance features (forwarded to ``L2norm_nbh``).
        kernel_width: forwarded to ``grad_curvature``. The default matches
            ``grad_curvature``'s own default.
        regularizer: forwarded to ``grad_curvature``. The default matches
            ``grad_curvature``'s own default.
        shape_regularizer: forwarded to ``grad_curvature``. The default
            matches ``grad_curvature``'s own default.
        verbose: if True, print the shapes of the three outputs.

    Returns:
        pos_dist: [N, min(comparison_size, k)] distances from each
            neighbourhood's first entry to its neighbours' positions.
        grad_dist: [N, min(comparison_size, k)] the same distances, computed
            on the fitted gradients.
        curv: [N] curvature returned by ``grad_curvature``.
    """
    grad, curv, edge_index = grad_curvature(
        pos,
        k,
        kernel_width=kernel_width,
        regularizer=regularizer,
        shape_regularizer=shape_regularizer,
    )
    _, col = edge_index

    grad_dist = L2norm_nbh(grad, comparison_size)

    # Gather neighbour positions into [N, k, 3] neighbourhoods. This relies on
    # the edges being grouped k per point, the same layout grad_curvature uses.
    pos_dist = L2norm_nbh(pos[col].reshape(-1, k, 3), comparison_size)

    if verbose:
        print(grad_dist.shape)
        print(pos_dist.shape)
        print(curv.shape)

    return pos_dist, grad_dist, curv

In [None]:
# data = np.loadtxt('C:/Users/aagaa/Documents/GitHub/R-D/Code/Leihui Code/src/data/bunny.xyz')
# pos = data[:, 0:3]
# k = 20

# get_nn_data(pos, k, 6)

(35947, 6)
(35947, 6)
(35947,)


(Array([[0.        , 0.00106694, 0.00110564, 0.00139692, 0.00143117,
         0.00170653],
        [0.        , 0.00041083, 0.0010538 , 0.0014084 , 0.00144653,
         0.00167976],
        [0.        , 0.00101533, 0.00103636, 0.00167158, 0.00181286,
         0.00181987],
        ...,
        [0.        , 0.00077309, 0.00091525, 0.00098134, 0.00117566,
         0.0011796 ],
        [0.        , 0.00086968, 0.00089027, 0.00109637, 0.00110807,
         0.00132429],
        [0.        , 0.00111946, 0.00112335, 0.00139004, 0.00150597,
         0.00159717]], dtype=float32),
 Array([[0.0000000e+00, 6.7159667e-06, 6.8769832e-06, 7.4498157e-06,
         7.0971555e-06, 1.4111134e-05],
        [0.0000000e+00, 1.4830521e-05, 3.4024673e-05, 4.9797964e-05,
         2.8253722e-05, 5.9923361e-05],
        [0.0000000e+00, 4.9959210e-05, 4.9881623e-05, 2.2313137e-05,
         7.1261478e-05, 2.2709648e-05],
        ...,
        [0.0000000e+00, 3.5556997e-05, 4.7817190e-05, 4.6445872e-05,
         6.2464