In [1]:
import torch

import principal_curvature as pc
import utils
import matplotlib.pyplot as plt
import sampling_algs
import math
import getopt
import os.path as osp
import sys

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import ModelNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, radius
import sampling_algs
import utils
import time
import csv
import numpy as np

import principal_curvature
import numpy as np

In [2]:
def weighted_fps():
    """Placeholder for a curvature-weighted farthest point sampler (not implemented).

    Planned idea: like FPS, but the next point is chosen with a probability that
    depends on its local curvature. Currently does nothing and returns None.
    """
    # TODO: torch.bincount accepts a ``weights`` argument; coordinate inputs could
    # be combined with curvature weights that way.
    return None

In [36]:
# Number of points requested per shape when loading the training split (2^14).
n_points = 2 ** 14
train_dataset = utils.import_train(n_points)
n_points

16384

In [37]:
# Experiment configuration.
ratio = .5   # fraction of the cloud's points to keep
bias = .99   # passed to sampling_algs.bias_anyvsfps_sampler further down
k = 10       # neighbourhood size passed to the knn / max-curvature helpers further down

h = 1        # index of the training shape to inspect
cloud = train_dataset[h].pos
desired_num_points = int(ratio * cloud.size(0))

In [38]:
# Per-point curvature values for `cloud`, using the neighbourhood size `k` from the
# configuration cell (previously a hard-coded 10 that could drift from `k`).
curvatures = principal_curvature.curvatures_knn(cloud, k)

In [74]:
import torch

def compute_distances(points, reference_point):
    """Euclidean distance from every row of ``points`` to ``reference_point``.

    :param points: Tensor whose rows (dim 0) are points; the norm is taken over dim 1.
    :param reference_point: Tensor broadcastable against ``points`` (e.g. one row).
    :return: Tensor with one distance per row of ``points``.
    """
    offsets = points - reference_point
    return offsets.norm(dim=1)

def fps_pure(points, num_points):
    """
    Plain farthest point sampling (FPS).

    A seed index is drawn uniformly at random with ``torch.randint``; every
    following pick is the not-yet-selected point whose distance to its nearest
    selected point is largest.

    :param points: Tensor of shape [N, D] representing the point cloud.
    :param num_points: Total number of points to sample (the random seed included).
    :return: 1D long tensor of ``num_points`` distinct indices into ``points``
        (empty when ``num_points <= 0``).
    :raises ValueError: if ``num_points`` is larger than N.
    """
    num_total_points = points.shape[0]
    if num_points <= 0:
        return torch.empty(0, dtype=torch.long)
    if num_points > num_total_points:
        raise ValueError(
            f"cannot sample {num_points} distinct points from a cloud of {num_total_points}"
        )

    # Keep the mask on the same device as `points` so it can index distances there.
    selected_mask = torch.zeros(num_total_points, dtype=torch.bool, device=points.device)
    seed_index = int(torch.randint(0, num_total_points, (1,)).item())
    selected_indices = [seed_index]
    selected_mask[seed_index] = True

    # Distance from every point to its nearest selected point, maintained
    # incrementally: one O(N) update per pick instead of recomputing the distance
    # to every selected point on every iteration.
    min_distances = compute_distances(points, points[seed_index])

    # The seed already counts as one sample, so only num_points - 1 picks remain.
    for _ in range(num_points - 1):
        # Exclude selected points explicitly; with duplicate coordinates every
        # distance can be 0 and an index could otherwise be picked twice.
        candidate_scores = min_distances.masked_fill(selected_mask, float('-inf'))
        next_index = int(torch.argmax(candidate_scores).item())
        selected_indices.append(next_index)
        selected_mask[next_index] = True
        min_distances = torch.minimum(min_distances, compute_distances(points, points[next_index]))

    return torch.tensor(selected_indices, dtype=torch.long)

def fps_weighted(points, num_points, curvature_values, curvature_scalar):
    """
    Perform weighted farthest point sampling based on both distance and curvature.

    Each step picks the unselected point maximising
    ``min_distance_to_selected + curvature_scalar * curvature``. The two terms are
    summed without normalisation, so ``curvature_scalar`` also has to bridge the
    scale of distances and curvatures. Higher curvature scalar = more weight to
    curvature, less weight to distance; ``curvature_scalar == 0`` gives plain FPS
    (for finite curvature values).

    :param points: Tensor of shape [N, 3] representing the point cloud.
    :param num_points: Total number of points to sample (the random seed included).
    :param curvature_values: Tensor with N elements (shape [N] or [N, 1]) holding
        one curvature value per point.
    :param curvature_scalar: A scalar weight for the curvature values.
    :return: 1D long tensor of ``num_points`` distinct indices into ``points``
        (empty when ``num_points <= 0``).
    :raises ValueError: if ``num_points`` exceeds N or ``curvature_values`` does
        not have exactly N elements.
    """
    num_total_points = points.shape[0]
    if num_points <= 0:
        return torch.empty(0, dtype=torch.long)
    if num_points > num_total_points:
        raise ValueError(
            f"cannot sample {num_points} distinct points from a cloud of {num_total_points}"
        )
    if curvature_values.numel() != num_total_points:
        raise ValueError(
            f"expected {num_total_points} curvature values, got {curvature_values.numel()}"
        )

    # The curvature bonus never changes between iterations, so compute it once.
    # Flattening accepts [N] as well as [N, 1]; an [N, 1] tensor would otherwise
    # broadcast against the [N] distances into an [N, N] score matrix.
    curvature_bonus = curvature_values.reshape(-1).to(points.device) * curvature_scalar

    selected_mask = torch.zeros(num_total_points, dtype=torch.bool, device=points.device)
    seed_index = int(torch.randint(0, num_total_points, (1,)).item())
    selected_indices = [seed_index]
    selected_mask[seed_index] = True

    # Running distance from every point to its nearest selected point
    # (one O(N) update per pick).
    min_distances = compute_distances(points, points[seed_index])

    # The seed already counts as one sample, so only num_points - 1 picks remain.
    for _ in range(num_points - 1):
        # Mask AFTER adding the curvature term so an inf/NaN curvature on an
        # already-selected point cannot bring it back.
        weighted_scores = (min_distances + curvature_bonus).masked_fill(selected_mask, float('-inf'))
        next_index = int(torch.argmax(weighted_scores).item())
        selected_indices.append(next_index)
        selected_mask[next_index] = True
        min_distances = torch.minimum(min_distances, compute_distances(points, points[next_index]))

    return torch.tensor(selected_indices, dtype=torch.long)

def fps_top_n(points, num_points, n, curvature_values):
    """
    Perform farthest point sampling by selecting the n farthest points based on distance and
    then choosing the one with the highest curvature value among those n points.

    Only points that have not been selected yet are candidates. When fewer than
    ``n`` of them remain, all remaining points are considered.

    :param points: Tensor of shape [N, 3] representing the point cloud.
    :param num_points: Total number of points to sample (the random seed included).
    :param n: Number of farthest candidates considered for the curvature-based choice (>= 1).
    :param curvature_values: Tensor with N elements (shape [N] or [N, 1]) containing
        one curvature value per point.
    :return: 1D long tensor of ``num_points`` distinct indices into ``points``
        (empty when ``num_points <= 0``).
    :raises ValueError: if ``n < 1``, ``num_points`` exceeds N, or
        ``curvature_values`` does not have exactly N elements.
    """
    num_total_points = points.shape[0]
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    if num_points <= 0:
        return torch.empty(0, dtype=torch.long)
    if num_points > num_total_points:
        raise ValueError(
            f"cannot sample {num_points} distinct points from a cloud of {num_total_points}"
        )
    if curvature_values.numel() != num_total_points:
        raise ValueError(
            f"expected {num_total_points} curvature values, got {curvature_values.numel()}"
        )

    curvature = curvature_values.reshape(-1).to(points.device)

    selected_mask = torch.zeros(num_total_points, dtype=torch.bool, device=points.device)
    seed_index = int(torch.randint(0, num_total_points, (1,)).item())
    selected_indices = [seed_index]
    selected_mask[seed_index] = True

    # Running distance from every point to its nearest selected point
    # (one O(N) update per pick).
    min_distances = compute_distances(points, points[seed_index])

    # The seed already counts as one sample, so only num_points - 1 picks remain.
    for _ in range(num_points - 1):
        # Mask already-selected points BEFORE taking the top-n; otherwise they
        # can end up among the candidates and be selected a second time.
        candidate_distances = min_distances.masked_fill(selected_mask, float('-inf'))
        # Never ask topk for more candidates than there are unselected points
        # (this also keeps k <= N, which torch.topk requires).
        num_candidates = min(n, num_total_points - len(selected_indices))
        farthest_indices = torch.topk(candidate_distances, num_candidates).indices

        best = torch.argmax(curvature[farthest_indices])
        next_index = int(farthest_indices[best].item())

        selected_indices.append(next_index)
        selected_mask[next_index] = True
        min_distances = torch.minimum(min_distances, compute_distances(points, points[next_index]))

    return torch.tensor(selected_indices, dtype=torch.long)

In [75]:
# Plain-FPS baseline without curvature (the sampler defined above is fps_pure):
# selected_indices = fps_pure(cloud, desired_num_points)

In [None]:
# A curvature weight of 0 removes the curvature term (for finite curvatures),
# so this run is effectively plain farthest point sampling.
curve_scale = 0
selected_indices = fps_weighted(
    cloud,
    desired_num_points,
    curvature_values=curvatures,
    curvature_scalar=curve_scale,
)


In [52]:
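# Top-n variant: among the n farthest candidates, take the highest-curvature point.
# NOTE: despite its name, the "curve200" run below uses n = 10.
# selected_indices_curve200 = fps_top_n(cloud, desired_num_points, 10, curvatures)
# selected_indices_curve10 = fps_top_n(cloud, desired_num_points, 10, curvatures)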

In [None]:
%matplotlib qt
point_cloud_np = cloud.numpy()
selected_points_np = cloud[selected_indices].numpy()

# Plot the point clouds
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Plot the original point cloud in blue
# ax.scatter(point_cloud_np[:, 0], point_cloud_np[:, 1], point_cloud_np[:, 2], c='blue', s=5,alpha=0, label='Original Points')

# Plot the selected indices in red
# ax.scatter(selected_points_np[:, 0], selected_points_np[:, 1], selected_points_np[:, 2], c='grey', label='Selected Points')
# ax.scatter(cloud[selected_indices_curve10, 0], cloud[selected_indices_curve10, 1], cloud[selected_indices_curve10, 2], c='green', s=20, label='Selected Points')
ax.scatter(cloud[selected_indices_curve200, 0], cloud[selected_indices_curve200, 1], cloud[selected_indices_curve200, 2], c='orange', s=2, label='Selected Points')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('Original Point Cloud and Selected Indices')
ax.legend()

plt.show()

In [None]:
# Deliberate stop: halts "Run All" here so the exploratory cells below are only
# executed by hand.
raise KeyboardInterrupt()

In [None]:
# Neighbour lookup with `k` neighbours (from the configuration cell); its result is
# handed to the principal-curvature estimator for the same cloud.
knn = principal_curvature.k_nearest_neighbors(cloud, k)
curves = principal_curvature.principal_curvature(cloud, knn)

In [None]:
# Display the principal curvatures as a NumPy array; detach/cpu make this work
# even if `curves` lives on a GPU or is part of an autograd graph.
curves.detach().cpu().numpy()

In [None]:
%matplotlib qt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
scatter3 = ax.scatter(cloud[sample_indices,0], cloud[sample_indices,1], cloud[sample_indices,2], marker='.',alpha=.5,color="green")
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
# rotate_plot()
ax.view_init(elev=30, azim=340)
plt.show(block=True)

#### Bias sampler in isolation

In [None]:


# Target sample count for the bias sampler. It is derived from the actual size of
# `cloud` (same rule as desired_num_points) rather than the requested n_points, so
# the ratio stays correct even if the loaded cloud size differs from the request.
des_nr_points = int(ratio * cloud.size(0))
test_idxs = sampling_algs.bias_anyvsfps_sampler(
    cloud, des_nr_points, bias,
    func1=sampling_algs.max_curve_sampler, args1=k,
)

In [None]:
# 3-D view of the points picked by the bias sampler.
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
sampled_xyz = cloud[test_idxs]
scatter3 = ax.scatter(sampled_xyz[:, 0], sampled_xyz[:, 1], sampled_xyz[:, 2],
                      marker='.', alpha=.5, color="green")
for axis_setter, axis_name in ((ax.set_xlabel, 'X'), (ax.set_ylabel, 'Y'), (ax.set_zlabel, 'Z')):
    axis_setter(axis_name)
ax.view_init(elev=30, azim=340)
plt.show(block=True)

In [None]:
# Deliberate stop at the end of the notebook: keeps "Run All" from continuing
# into any cells appended below.
raise KeyboardInterrupt()