In [None]:
%cd..

import point_cloud_utils as pcu
import numpy as np
import torch
from losses.chamfer import chamfer_loss
from eval.sampling_util import sample_pc_random


def load_node_mesh(args, node_id):
    """
    Read the ``.obj`` mesh belonging to ``node_id`` and return ``(v, f)``.

    The file is looked up at ``<args.data_path>obj/<node_id>.obj``. The path is
    assembled by plain string joining, so ``args.data_path`` must already end
    with a path separator and ``node_id`` must be a string.
    """
    mesh_file = "".join((args.data_path, "obj/", node_id, ".obj"))
    vertices, faces = pcu.load_mesh_vf(mesh_file)
    return vertices, faces


def sample_mesh(args, node_id, n_points):
    """
    Draw ``n_points`` random samples from the surface of ``node_id``'s mesh.

    Loading is delegated to ``load_node_mesh`` and sampling to
    ``sample_pc_random``; whatever the latter returns is passed through.
    """
    mesh = load_node_mesh(args, node_id)  # (vertices, faces)
    return sample_pc_random(*mesh, n_points)


def get_pc_ae_transform(args):
    """
    Build the point-cloud transform applied before the PC-AE encoder.

    The returned callable takes an (N, 3) array of points and:
      1. negates the z coordinate (the matrix below is diag(1, 1, -1), a
         reflection rather than a proper rotation),
      2. moves the bounding-box center to the origin and divides by the
         largest point norm so the farthest point lies on the unit sphere,
      3. multiplies each axis by a dataset-specific scale factor chosen from
         ``args.data_path`` ("RND" datasets vs. all others).

    Args:
        args: object with a ``data_path`` string attribute.

    Returns:
        Callable[[array_like], np.ndarray]: the transform. It returns a new
        float64 array and never modifies its input.
    """
    def center_in_unit_sphere(pc, in_place=True):
        """Center ``pc``'s bounding box at the origin and scale it into the unit sphere."""
        if not in_place:
            pc = pc.copy()

        # Shift every axis by the midpoint of its extent (bounding-box center).
        pc -= (pc.max(axis=0) + pc.min(axis=0)) / 2.0

        # A degenerate cloud (all points identical) is all zeros after
        # centering; skip the division so it stays finite instead of NaN.
        largest_distance = np.max(np.sqrt(np.sum(pc**2, axis=1)))
        if largest_distance > 0:
            pc /= largest_distance
        return pc

    # Per-axis scale factors; which set is used depends on the dataset variant.
    if "RND" in args.data_path:
        scale_factor = np.array([0.46350697, 0.35710097, 0.40755142])  # RND
    else:
        scale_factor = np.array([0.33408034, 0.39906635, 0.35794342])  # non-RND datasets

    # diag(1, 1, -1): mirrors the z axis (determinant -1, so this is a
    # reflection, not a rotation). NOTE(review): presumably matches the axis
    # convention the encoder was trained with — confirm before changing.
    rotate_matrix = np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])

    def transform(x):
        # Cast to float64 up front so the in-place centering also accepts
        # integer input (the matmul would otherwise keep an integer dtype).
        pts = np.asarray(x, dtype=np.float64)
        return center_in_unit_sphere(pts @ rotate_matrix.T) * scale_factor

    return transform


def distance_to_mesh(args, in_p, node_id, n_samples=8, device="cuda"):
    """
    Measure how closely a point cloud matches the surface of a node's mesh.

    ``n_samples * len(in_p)`` points are sampled from the mesh, shuffled, and
    split into batches the same size as ``in_p``. ``in_p`` and every batch go
    through the PC-AE transform, the Chamfer distance between each batch and
    ``in_p`` is computed, and the smallest of these distances is returned.

    Args:
        args: object with a ``data_path`` attribute (see ``load_node_mesh``).
        in_p: input point cloud, an (N, 3) NumPy array.
        node_id: identifier of the mesh to compare against.
        n_samples: number of mesh batches to draw (default 8).
        device: torch device the distance is computed on (default "cuda").

    Returns:
        The minimum per-batch Chamfer distance, as a NumPy float.
    """
    n_input = in_p.shape[0]
    pc_t = get_pc_ae_transform(args)

    # Transform the query cloud once; it is reused for every batch.
    in_p = torch.from_numpy(pc_t(in_p)).unsqueeze(0).float().to(device)

    # Sample n_samples * n_input points from the mesh surface.
    p = sample_mesh(args, node_id, n_samples * n_input)

    # Shuffle rows (in place) and split into batches of n_input points. Any
    # remainder is dropped so the reshape cannot fail if the sampler returns a
    # count that is not an exact multiple of n_input.
    np.random.shuffle(p)
    n_batches = p.shape[0] // n_input
    batches = p[: n_batches * n_input].reshape(n_batches, n_input, 3)

    batch_dists = np.zeros(n_batches)
    with torch.no_grad():
        for i, batch in enumerate(batches):
            p_c = torch.from_numpy(pc_t(batch)).float().to(device).unsqueeze(0)
            batch_dists[i] = chamfer_loss(p_c, in_p, reduction="mean").mean().item()

    # Report the best-matching batch (minimum, not mean).
    return np.min(batch_dists)


# Minimal stand-in for the argument namespace the helpers above expect.
class Args:
    def __init__(self):
        # Dataset root; meshes are read from "<data_path>obj/<node_id>.obj".
        self.data_path = "/ibex/user/slimhy/ShapeWalk/"


args = Args()

# Sanity check: sample a cloud from a mesh, then measure its distance back to
# that same mesh. The result (scaled by 1e4 for readability) is expected to be
# small — NOTE(review): confirm the expected magnitude.
debug_node_id = "100062696957012786"
in_p = sample_mesh(args, debug_node_id, 4096)
distance_to_mesh(args, in_p, debug_node_id, n_samples=15) * 10_000