In [2]:
def xyz_projected(pos, normal, x_basis, y_basis, edge_index, k=None):
    """Project every point's neighbors into that point's local tangent frame.

    For every edge ``row -> col`` the offset ``pos[col] - pos[row]`` is split into
    its height along the normal of the centre point ``row`` and its in-plane
    coordinates along that point's two tangent basis vectors.

    Args:
        pos (np.ndarray): [N, 3] point positions.
        normal (np.ndarray): [N, 3] normal per point.
        x_basis (np.ndarray): [N, 3] first tangent basis vector per point.
        y_basis (np.ndarray): [N, 3] second tangent basis vector per point.
        edge_index (tuple of np.ndarray): ``(row, col)`` indices of the k-nn
            graph, each of length N * k. Edges must be grouped by centre point:
            the k edges of point i occupy positions i*k ... i*k + k - 1.
        k (int, optional): number of neighbors per point. If None it is
            inferred as the number of edges whose centre (``row``) is point 0.

    Returns:
        coords (np.ndarray): [N * k, 2] tangent-plane (x, y) coordinates of each
            neighbor relative to its centre point.
        z_pos (np.ndarray): signed height of each neighbor along its centre
            point's normal, in the shape returned by ``batch_dot`` (it must
            broadcast against [N * k, 3]; presumably [N * k, 1]).

    Raises:
        ValueError: if the number of edges is not ``len(normal) * k``.
    """
    row, col = edge_index
    if k is None:
        k = int((row == 0).sum())

    num_points = normal.shape[0]
    if len(row) != num_points * k:
        raise ValueError(
            f"edge_index has {len(row)} edges, expected {num_points} points * k={k}"
        )

    # Repeat each per-point frame vector k times so that row e of the result
    # belongs to the centre point of edge e (relies on the grouping above).
    normal = np.repeat(normal, k, axis=0)
    x_basis = np.repeat(x_basis, k, axis=0)
    y_basis = np.repeat(y_basis, k, axis=0)

    local_pos = pos[col] - pos[row]
    # Height of each neighbor above/below the centre point's tangent plane.
    z_pos = batch_dot(local_pos, normal)
    # Remove the normal component so only the in-plane offset remains; this
    # keeps the coordinates correct even if the basis is not exactly orthogonal.
    local_pos = local_pos - normal * z_pos
    x_pos = batch_dot(local_pos, x_basis).flatten()
    y_pos = batch_dot(local_pos, y_basis).flatten()
    coords = np.stack([x_pos, y_pos], axis=1)

    return coords, z_pos

In [1]:
import numpy as np
from numpy import linalg as LA
from pcdiff import knn_graph, estimate_basis, build_grad_div, laplacian, coords_projected, gaussian_weights, weighted_least_squares, batch_dot

def grad_curvature(pos, k, kernel_width=1, regularizer=1e-8, shape_regularizer=None):
    """Fit a local quadratic height field around every point and differentiate it.

    In each point's tangent frame a surface
    ``z = c0 + c1*x + c2*y + c3*x^2 + c4*x*y + c5*y^2`` is fitted to its k
    neighbors by weighted least squares (Gaussian distance weights).

    Args:
        pos (np.ndarray): [N, 3] point positions.
        k (int): number of neighbors per point.
        kernel_width (float): width passed to ``gaussian_weights``.
        regularizer (float): regularizer passed to ``weighted_least_squares``.
        shape_regularizer (optional): if given, forwarded to
            ``weighted_least_squares``; its second return value is ignored.

    Returns:
        grad (np.ndarray): [N, k, 2] gradient (df/dx, df/dy) of each point's
            fitted surface, evaluated at the tangent coordinates of each of its
            k neighbors.
        curvature (np.ndarray): [N] mean curvature of each fitted surface at
            its centre point (x = y = 0). The sign follows the orientation of
            the estimated normal.
        edge_index (tuple): the ``(row, col)`` k-nn graph that was used.
    """
    edge_index = knn_graph(pos, k)
    normal, x_basis, y_basis = estimate_basis(pos, edge_index)
    row, col = edge_index

    coords, z_pos = xyz_projected(pos, normal, x_basis, y_basis, edge_index, k)

    dist = LA.norm(pos[col] - pos[row], axis=1)
    weights = gaussian_weights(dist, k, kernel_width)

    if shape_regularizer is None:
        wls = weighted_least_squares(coords, weights, k, regularizer)
    else:
        # The shape-regularized variant also returns a second operator, unused here.
        wls, _ = weighted_least_squares(coords, weights, k, regularizer, shape_regularizer)

    # Per-point polynomial coefficients c0..c5, shape [N, 6].
    C = (wls * z_pos).reshape(-1, k, 6).sum(axis=1)

    # First derivatives of the fit at the centre point (x = y = 0).
    f_x0 = C[:, 1]
    f_y0 = C[:, 2]
    # Second derivatives (constant for a quadratic).
    grad_xx = 2 * C[:, 3]  # d^2f/dx^2 = 2*c3
    grad_xy = C[:, 4]      # d^2f/dxdy = c4
    grad_yy = 2 * C[:, 5]  # d^2f/dy^2 = 2*c5

    # Mean curvature of the graph z = f(x, y) (Monge patch formula). Unlike
    # f_xx + 2*f_xy + f_yy, it does not depend on how x_basis / y_basis happen
    # to be rotated within the tangent plane.
    slope_sq = 1 + f_x0 ** 2 + f_y0 ** 2
    curvature = (
        (1 + f_y0 ** 2) * grad_xx
        - 2 * f_x0 * f_y0 * grad_xy
        + (1 + f_x0 ** 2) * grad_yy
    ) / (2 * slope_sq ** 1.5)

    # Give every edge the coefficients of its centre point, then evaluate the
    # gradient at that neighbor's tangent coordinates.
    C_edge = C[row]
    x, y = coords[:, 0], coords[:, 1]
    # df/dx = c1 + 2*c3*x + c4*y
    grad_x = C_edge[:, 1] + 2 * C_edge[:, 3] * x + C_edge[:, 4] * y
    # df/dy = c2 + 2*c5*y + c4*x
    grad_y = C_edge[:, 2] + 2 * C_edge[:, 5] * y + C_edge[:, 4] * x
    grad = np.column_stack((grad_x, grad_y))

    return grad.reshape(-1, k, 2), curvature, edge_index

In [53]:
import torch
import jax
import jax.numpy as jnp

def L2norm_nbh(data, comparison_size, origin_index=0):
    """Euclidean distance from an origin entry to the other entries of each neighborhood.

    Args:
        data: array-like or torch.Tensor of shape
            (batch_size, num_entries, num_features); each batch row is one
            neighborhood. Torch tensors are detached and copied to the CPU.
        comparison_size (int): number of non-origin entries to compare
            against, taken in index order while skipping ``origin_index``.
        origin_index (int): index of the reference entry within each
            neighborhood (negative indices count from the end).

    Returns:
        jax.Array of shape (batch_size, min(comparison_size, num_entries - 1))
        holding the distances.

    Raises:
        IndexError: if ``origin_index`` is outside ``[-num_entries, num_entries)``
            (JAX indexing would otherwise clamp it silently).
    """
    if isinstance(data, torch.Tensor):
        # JAX cannot consume torch tensors directly; hand it a host numpy copy.
        data = data.detach().cpu().numpy()

    arr = jnp.asarray(data)
    num_entries = arr.shape[1]
    if not -num_entries <= origin_index < num_entries:
        raise IndexError(
            f"origin_index {origin_index} out of range for {num_entries} entries"
        )
    origin_index = origin_index % num_entries

    # Reference entry per batch row: (batch_size, num_features).
    origin = arr[:, origin_index, :]

    # The first `comparison_size` entries other than the origin itself.
    neighbor_idx = jnp.asarray(
        [i for i in range(num_entries) if i != origin_index][:comparison_size],
        dtype=jnp.int32,
    )
    neighbors = arr[:, neighbor_idx, :]

    return jnp.linalg.norm(neighbors - origin[:, None, :], axis=2)

In [None]:
def get_nn_data(pos, k, comparison_size, kernel_width=1, regularizer=1e-8,
                shape_regularizer=None, verbose=False):
    """Build per-point neighborhood features from a point cloud.

    Args:
        pos (np.ndarray): [N, 3] point positions.
        k (int): number of neighbors per point.
        comparison_size (int): how many neighbor slots (after the first) to
            compare against the first slot of each neighborhood.
        kernel_width, regularizer, shape_regularizer: forwarded unchanged to
            ``grad_curvature``; defaults match the values previously hard-coded.
        verbose (bool): if True, print the shapes of the three outputs.

    Returns:
        pos_dist: distances from each point's first neighbor slot to its next
            ``comparison_size`` neighbors (presumably the first slot is the
            point itself if ``knn_graph`` lists self first; TODO confirm).
        grad_dist: the same distances computed on the per-neighbor gradients.
        curv: per-point curvature from ``grad_curvature``.
    """
    grad, curv, edge_index = grad_curvature(
        pos,
        k,
        kernel_width=kernel_width,
        regularizer=regularizer,
        shape_regularizer=shape_regularizer,
    )
    _, col = edge_index

    grad_dist = L2norm_nbh(grad, comparison_size)

    # Neighbor positions grouped per point: (N, k, 3).
    neighbor_pos = pos[col].reshape(-1, k, 3)
    pos_dist = L2norm_nbh(neighbor_pos, comparison_size)

    if verbose:
        print(grad_dist.shape)
        print(pos_dist.shape)
        print(curv.shape)

    return pos_dist, grad_dist, curv

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# data = np.loadtxt('C:/Users/aagaa/Documents/GitHub/R-D/Code/Leihui Code/src/data/bunny.xyz')
# pos = data[:, 0:3]
# k = 6

# edge_index = knn_graph(pos, k)

# row, col = edge_index

# normal, __, __ = estimate_basis(pos, edge_index)

# normal_reshaped = normal[col, :].reshape(-1, k, 3)

# norm = jnp.linalg.norm(normal[col] - normal[row], axis=1).reshape(-1, k)

# Norm_diff = L2norm_nbh(normal_reshaped, k)

# high_normdiff = np.max(norm, axis=1)

# # Convert normalized values to RGB using a colormap
# colormap = plt.get_cmap("jet")  # Choose a colormap
# colors = colormap(high_normdiff)[:, :3] # Extract RGB channels (ignore alpha)

# pos = np.hstack((pos, colors))

# np.savetxt("bunny_normal_test.xyz", pos, fmt="%.6f", delimiter=" ")