In [None]:
import torch
import sys

sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/ShapeAXI/')

import shapeaxi
from shapeaxi import utils

import plotly.graph_objects as go
import plotly.express as px



In [None]:


class OctreeNode:
    """One cubic cell of an ``Octree``.

    A cell is a leaf until it is subdivided. At that point ``Octree`` creates all
    eight ``children`` together, so checking ``children[0]`` is enough to detect a leaf.
    """

    def __init__(self, center, half_size):
        self.center = center  # center of the cube (a coordinate tensor when built by Octree)
        self.half_size = half_size  # half the side length of the cube
        self.children = [None] * 8  # child cells indexed by octant; all None for a leaf
        self.points = []  # points stored in this cell; Octree only fills this on leaves

    def is_leaf(self):
        """Return True when this cell has not been subdivided."""
        return self.children[0] is None


class Octree:
    """Octree over a set of 3-D points stored as a torch tensor.

    The root cube encloses all points. A cell is split into eight octants until it
    holds at most ``min_points`` points or ``max_depth`` is reached.

    Args:
        max_depth: maximum subdivision depth (the root is depth 0).
        min_points: cells with this many points or fewer are not subdivided.
            Despite the name, this acts as the maximum number of points per leaf.
    """

    def __init__(self, max_depth, min_points):
        self.root = None
        self.max_depth = max_depth
        self.min_points = min_points

    def build(self, points):
        """Build the tree from a non-empty ``(N, 3)`` tensor of points.

        Raises:
            ValueError: if ``points`` is not a non-empty ``(N, 3)`` tensor.
        """
        if points.ndim != 2 or points.shape[1] != 3 or points.shape[0] == 0:
            raise ValueError(
                f"expected a non-empty (N, 3) tensor of points, got shape {tuple(points.shape)}"
            )
        min_corner = points.min(dim=0)[0]
        max_corner = points.max(dim=0)[0]
        center = (min_corner + max_corner) / 2
        # The largest extent makes the root a cube that encloses every point.
        half_size = (max_corner - min_corner).max().item() / 2
        self.root = OctreeNode(center, half_size)
        self._insert_points(self.root, points, 0)

    def _insert_points(self, node, points, depth):
        """Store ``points`` in ``node``, or subdivide it and recurse into the eight octants."""
        if len(points) <= self.min_points or depth >= self.max_depth:
            node.points = points
            return

        child_half_size = node.half_size / 2
        # Bits 0/1/2 of a child index select the +x/+y/+z side of the parent center.
        # Build on the data's device so GPU point clouds work too.
        bits = torch.tensor([1, 2, 4], device=node.center.device)
        for i in range(8):
            signs = ((bits & i) > 0).to(node.center.dtype) * 2 - 1
            node.children[i] = OctreeNode(node.center + signs * child_half_size, child_half_size)

        # Octant index of every point, computed once.
        # Points exactly on a split plane go to the lower side.
        above = points > node.center
        octant = (above.long() * bits).sum(dim=1)
        for i, child in enumerate(node.children):
            child_points = points[octant == i]
            if len(child_points) > 0:
                self._insert_points(child, child_points, depth + 1)

    def _stack_centers(self, centers):
        """Stack collected center tensors; return an empty ``(0, 3)`` tensor if there are none."""
        if centers:
            return torch.stack(centers)
        like = self.root.center if self.root is not None else torch.empty(0)
        return torch.empty((0, 3), dtype=like.dtype, device=like.device)

    def get_centers_at_depth(self, depth):
        """Return the centers of all cells at ``depth`` (root = 0) as a ``(K, 3)`` tensor.

        Every subdivision creates eight children, so cells that received no points
        are included. Returns an empty ``(0, 3)`` tensor if no cell exists at ``depth``.
        """
        centers = []
        self._collect_centers_at_depth(self.root, depth, 0, centers)
        return self._stack_centers(centers)

    def _collect_centers_at_depth(self, node, target_depth, current_depth, centers):
        if node is None:
            return
        if current_depth == target_depth:
            centers.append(node.center)
            return
        for child in node.children:
            self._collect_centers_at_depth(child, target_depth, current_depth + 1, centers)

    def get_leaf_centers(self):
        """Return the centers of all non-empty leaf cells as a ``(K, 3)`` tensor (possibly ``(0, 3)``)."""
        centers = []
        self._collect_leaf_centers(self.root, centers)
        return self._stack_centers(centers)

    def _collect_leaf_centers(self, node, centers):
        if node is None:
            return
        if node.is_leaf():
            if len(node.points) > 0:
                centers.append(node.center)
            return
        for child in node.children:
            self._collect_leaf_centers(child, centers)

    def query(self, point, radius):
        """Return the stored points within Euclidean distance ``radius`` of ``point``.

        Returns:
            A list of 1-D point tensors (one per match), in tree traversal order.
            The list is empty if the tree has not been built.
        """
        result = []
        if self.root is not None:
            self._query(self.root, point, radius, result)
        return result

    def _query(self, node, point, radius, result):
        if node.is_leaf():
            # A leaf's cube can overlap the ball without all of its points being inside
            # it, so filter the candidates by their actual distance.
            if len(node.points) > 0:
                dist = torch.linalg.norm(node.points - point, dim=1)
                result.extend(node.points[dist <= radius])
            return

        for child in node.children:
            if child is None:
                continue
            # Distance from `point` to the child's cube (0 when the point is inside it).
            # Skip cubes that lie entirely outside the query ball.
            gap = (torch.abs(child.center - point) - child.half_size).clamp(min=0)
            if torch.linalg.norm(gap).item() <= radius:
                self._query(child, point, radius, result)


# Example usage
# points = torch.rand((1000, 3)) * 100  # 1000 points in 3D space
# octree = Octree(max_depth=10, min_points=27)
# octree.build(points)


In [None]:

# Load one ModelNet40 airplane mesh and pass it through utils.GetUnitSurf
# (presumably a scale normalization; confirm in shapeaxi.utils).
mount_point = "/mnt/raid/home/jprieto"
surf = utils.GetUnitSurf(
    utils.ReadSurf(f'{mount_point}/ModelNet40/airplane/train/airplane_0129.off')
)
V, F = utils.PolyDataToTensors_v_f(surf)

# Index the mesh vertices in an octree, then look up vertices near a sample location.
octree = Octree(max_depth=10, min_points=27)
octree.build(V)
radius = 0.00001
query_point = torch.tensor([0.3602121, -0.3602121, 0.3602121], dtype=torch.float32)
result = octree.query(query_point, radius)

print(f"Number of points within radius {radius} of {query_point}: {len(result)}")
print(result)
# To inspect internal cells instead: octree.get_centers_at_depth(8)



In [None]:
centers = octree.get_leaf_centers()

# One marker per non-empty octree leaf, colored by the leaf center's z coordinate.
leaf_marker = dict(size=2, color=centers[:, 2], colorscale='Viridis', opacity=0.8)
leaf_trace = go.Scatter3d(
    x=centers[:, 0],
    y=centers[:, 1],
    z=centers[:, 2],
    mode='markers',
    marker=leaf_marker,
)
fig = go.Figure(data=[leaf_trace])
fig.show()



In [None]:
# Raw mesh vertices, for comparison with the leaf-center plot.
# Colors must come from V itself (one value per plotted vertex). The previous code used
# centers[:, 2], an unrelated array of a different length from the prior cell.
fig = go.Figure(data=[go.Scatter3d(
    x=V[:, 0],
    y=V[:, 1],
    z=V[:, 2],
    mode='markers',
    marker=dict(
        size=2,
        color=V[:, 2],  # color each vertex by its own z coordinate
        colorscale='Viridis',
        opacity=0.8,
    ),
)])
fig.show()