# Preparation

In [None]:
from plot import explore, patches_format
import numpy as np
import faiss
import torch
%cd ..
import plotly.io as pio
pio.renderers.default = "browser"
#pio.renderers.default = "jupyterlab"
from sklearn.neighbors import NearestNeighbors
import time
from torchvision import transforms

In [None]:
def process(ref, query, n_neighbors=32):
    """Look up the nearest reference points for every query point.

    Args:
        ref: Samples the nearest-neighbour index is fitted on.
        query: Samples whose neighbours are searched for in ``ref``.
        n_neighbors: Number of neighbours requested per query sample.

    Returns:
        The result of ``kneighbors(query, return_distance=False)``, i.e. the
        neighbour indices without the distances.
    """
    index = NearestNeighbors(n_neighbors=n_neighbors).fit(ref)
    return index.kneighbors(query, return_distance=False)

def get_moments(centers):
    """Print nearest-neighbour spacing statistics and return the distance matrix.

    Args:
        centers: ``(N, D)`` array-like or ``torch.Tensor`` of center
            coordinates. A 3-D input is treated as batched and only its first
            element is used.

    Returns:
        ``(dist, centers)``: the ``(N, N)`` pairwise Euclidean distance matrix
        in which every zero entry (self-distances and coincident centers) is
        replaced by ``inf``, and the centers as a NumPy array.
    """
    if isinstance(centers, torch.Tensor):
        # A bare .numpy() raises for CUDA tensors and tensors that require grad.
        centers = centers.detach().cpu().numpy()
    else:
        centers = np.asarray(centers)
    if centers.ndim == 3:
        centers = centers[0]

    dist = np.linalg.norm(centers[None, :, :] - centers[:, None, :], axis=2)
    # Mask zero distances with inf rather than a finite sentinel, so the row
    # minimum is never capped when the real spacing is larger than the sentinel.
    dist[dist == 0] = np.inf

    nearest = dist.min(axis=1)
    print("min distance:", np.min(nearest))
    print("mean distance:", np.mean(nearest))
    return dist, centers

def get_dist(centers):
    """Return the smallest and the average nearest-neighbour distance of centers.

    Args:
        centers: ``(N, D)`` array-like or ``torch.Tensor`` of center
            coordinates. A 3-D input is treated as batched and only its first
            element is used.

    Returns:
        ``(min_distance, mean_distance)``: the minimum and the mean, taken over
        all centers, of the Euclidean distance to the closest other center.
        Zero distances (self-distances and coincident centers) are ignored;
        both values are ``inf`` when no center has a neighbour at non-zero
        distance.
    """
    if isinstance(centers, torch.Tensor):
        # A bare .numpy() raises for CUDA tensors and tensors that require grad.
        centers = centers.detach().cpu().numpy()
    else:
        centers = np.asarray(centers)
    if centers.ndim == 3:
        centers = centers[0]

    dist = np.linalg.norm(centers[None, :, :] - centers[:, None, :], axis=2)
    # Mask zero distances with inf rather than a finite sentinel, so the row
    # minimum is never capped when the real spacing is larger than the sentinel.
    dist[dist == 0] = np.inf

    nearest = dist.min(axis=1)
    return nearest.min(), nearest.mean()
    
import pandas as pd
def get_no_leftout(output):
    """Count the distinct points in ``output``.

    This is the number of points that are kept (present), not the number of
    points that are left out.

    Args:
        output: Two-dimensional array whose second dimension must be 3.

    Returns:
        Number of unique rows in ``output``.
    """
    assert output.shape[1] == 3
    unique_rows = pd.DataFrame(output).drop_duplicates()
    return unique_rows.shape[0]

def get_leftout(input, output):
    """Return the points that occur in only one of ``input`` and ``output``.

    Both arrays are de-duplicated first, then stacked; every row that still
    appears more than once (i.e. is present in both) is removed entirely.

    Args:
        input: Two-dimensional array whose second dimension must be 3.
        output: Two-dimensional array whose second dimension must be 3.

    Returns:
        NumPy array with the rows of the symmetric difference, rows from
        ``input`` first.
    """
    for cloud in (output, input):
        assert cloud.shape[1] == 3

    unique_frames = [pd.DataFrame(cloud).drop_duplicates() for cloud in (input, output)]
    stacked = pd.concat(unique_frames)
    # keep=False drops every copy of a duplicated row, leaving only rows that
    # are exclusive to one of the two sets.
    return stacked.drop_duplicates(keep=False).to_numpy()



In [None]:
import utils.parser as parser
from utils.config import *
import tools.builder as builder

# Choose the YAML config to load. The commented-out lines are alternative
# configs; only one `argv` assignment should be active at a time.
#argv = ["--config", "cfgs/pretraining/pretrain64.yaml"]
#argv = ["--config", "cfgs/segmentation/offset.yaml"]
argv = ["--config", "cfgs/pretraining/only_shapenet_pretrain.yaml"]
#argv = ["--config", "cfgs/classification/cls_treeset_fewshot.yaml"]
#argv = ["--config", "cfgs/classification/scanobject_hardest.yaml"]
#argv = ["--config", "cfgs/classification/modelnet.yaml"]
# Parse the argument list and build the config from it.
# NOTE(review): get_config presumably comes from the `utils.config` star import — confirm.
args = parser.get_args(argv)
config = get_config(args)
# Manual overrides applied on top of the parsed arguments and the loaded config.
args.distributed = False
args.task = "cls"
args.sampling_method = "kmeans"
config.model.mask_type = "rand"
# Attach the model settings to the test-split dataset config and set its `bs` to 1.
# NOTE(review): `bs` is presumably the batch size — confirm against the dataset builder.
config.dataset.test.others.model = config.model
config.dataset.test.others.bs = 1
# Show the resulting test-split configuration for inspection.
print(config.dataset.test)


In [None]:
# Build the test split; the first value returned by the builder is discarded
# and only the loader is kept.
_, loader = builder.dataset_builder(args, config.dataset.test)
# Work directly on the dataset wrapped by the loader so samples can be indexed.
dataset = loader.dataset
# NOTE(review): presumably turns off centering of the returned samples — confirm in the dataset class.
dataset.center = False

# Test on dataset

In [None]:
# For the first n samples: measure the spacing between group centers and count
# how many distinct points end up in the grouped neighbourhoods, then print the
# aggregated numbers as one LaTeX table row.
dataset.grouper.sampling_method = "fps"  # alternatives: "rand", "fps", "slice_fps"
# Replace both transform pipelines with empty compositions.
dataset.transforms = transforms.Compose([])
dataset.token_transforms = transforms.Compose([])
n = 1000
# Per-sample result buffers. `mean_dist` and `NNdist` are allocated here but
# not written or read anywhere in this cell.
min_distance, mean_distance, lnout, mean_dist, NNdist = np.empty(n), np.empty(n), np.empty(n), np.empty(n), np.empty(n)
# Pause before the timer starts.
# NOTE(review): the reason for the 2.5 s wait is not evident from this cell.
time.sleep(2.5)
start = time.time()
# The timed region covers sample loading AND the metric computations below.
for i in range(n):
    # Alternative unpacking for datasets that return seven values per sample:
    #neighborhood, center, label, _, _, _, _ = dataset[i]
    neighborhood, center, label = dataset[i]
    min_distance[i], mean_distance[i] = get_dist(center)
    # Flatten all neighbourhoods into one list of 3-D points and count the
    # distinct ones.
    sample = neighborhood.reshape(-1, 3)
    lnout[i] = get_no_leftout(sample.numpy())
end = time.time()
# Fraction of `dataset.npoints` not covered by the distinct grouped points.
# NOTE(review): assumes dataset.npoints is the number of points per input cloud — confirm.
missing_points = 1 - lnout/dataset.npoints
std = np.std(missing_points) * 100
missing_points = np.mean(missing_points) * 100
# Columns: missing points in percent (mean ± std), mean of the per-sample
# minimum center distance, mean of the per-sample mean nearest-center distance,
# and elapsed time per sample multiplied by 10.
print(f"{missing_points:.1f} \\% $\\pm$ {std:.1f} & {np.mean(min_distance):.2f} & {np.mean(mean_distance):.2f}  & {((end - start)/n*10):.2f}")

In [None]:
# Timing
# Measure how long it takes to fetch the first n samples from the dataset with
# the sampling method selected below.
dataset.grouper.sampling_method = "kmeans"  # alternatives: "rand", "fps", "slice_fps"
dataset.normalization = False
n = 100
# These buffers are not filled in this cell; the allocation is kept so the
# notebook-level names are still reset to length n, as before.
min_distance, mean_distance, lnout = np.empty(n), np.empty(n), np.empty(n)
# perf_counter is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() can jump when the system clock is adjusted.
start = time.perf_counter()
for i in range(n):
    neighborhood, center, mask = dataset[i]
end = time.perf_counter()

# Elapsed time per sample, multiplied by 10. Dividing by n (instead of a
# hard-coded 10 that is only right for n == 100) keeps the figure correct when
# n changes and matches the scaling used by the dataset statistics cell.
print(f"{((end - start) / n * 10):.3f}")