In [2]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Add the project's files to the python path
# file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script
file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook
sys.path.append(file_path)

import torch
from src.datasets.s3dis import CLASS_NAMES, CLASS_COLORS, STUFF_CLASSES
from src.datasets.s3dis import S3DIS_NUM_CLASSES as NUM_CLASSES
from src.transforms import *

The main data structures of this project are `Data` and `NAG`.

`Data` stores a single-level graph. 
It inherits from `torch_geometric`'s `Data` and has a similar behavior (see the [official documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) for more on this). 
Important specificities of our `Data` object are:
- `Data.super_index` stores the parent's index for each node in `Data`
- `Data.sub` holds a `Cluster` object indicating the children of each node in `Data`
- `Data.to_trimmed()` works like `torch_geometric`'s `Data.coalesce()` with the additional constraint that (i,j) and (j,i) edges are considered duplicates
- `Data.save()` and `Data.load()` allow optimized, memory-friendly I/O operations
- `Data.select()` indexes the nodes à la numpy

`NAG` (Nested Acyclic Graph) stores the hierarchical partition in the form of a list of `Data` objects.
Important specificities of our `NAG` object are:
- `NAG[i]` returns a `Data` object holding the partition level `i`
- `NAG.get_super_index()` returns the index mapping nodes from any level `i` to `j` with `i<j`
- `NAG.get_sampling()` produces indices for sampling the superpoints with certain constraints
- `NAG.save()` and `NAG.load()` allow optimized, memory-friendly I/O operations
- `NAG.select()` indexes the nodes of a specified partition level à la numpy and updates the rest of the `NAG` structure accordingly

## Load a NAG

In [2]:
# Load a pre-computed NAG saved with `torch.save`. The file holds a full
# pickled `NAG` object (not just tensors), so `weights_only=False` is needed:
# since PyTorch 2.6 `torch.load` defaults to `weights_only=True`, which refuses
# to unpickle custom classes. Only load files you trust this way.
nag = torch.load('demo_nag.pt', weights_only=False)

In [11]:
# Check the class of the loaded object (the output shows `src.data.nag.NAG`)
type(nag)

src.data.nag.NAG

In [12]:
# Save the whole NAG object to a new .pt file with `torch.save`.
# The project also provides `NAG.save()` / `NAG.load()` for optimized I/O.
torch.save(nag, 'demo_nag_2.pt')

In [13]:
# Reload the copy saved above to check the torch.save/torch.load round trip.
# `weights_only=False` is required to unpickle the custom `NAG` class on
# PyTorch >= 2.6, where `weights_only=True` became the default.
nag = torch.load('demo_nag_2.pt', weights_only=False)

In [14]:
# Print general info about the NAG: number of partition levels, number of
# nodes at each level, and device
print(nag)

NAG(num_levels=4, num_points=[41568, 1192, 501, 166], device=cpu)


In [15]:
# Summarize each partition level: a "Level-<i>:" header, the attributes of the
# corresponding `Data` object, then a blank separator line
for i_level, data in enumerate(nag):
    print(f"Level-{i_level}:", data, "", sep="\n")

Level-0:
Data(super_index=[41568], y=[41568, 14], pos=[41568, 3], elevation=[41568, 1], linearity=[41568, 1], planarity=[41568, 1], rgb=[41568, 3], scattering=[41568, 1], verticality=[41568, 1])

Level-1:
Data(edge_index=[2, 9158], sub=Cluster(num_clusters=1192, num_points=41568, device=cpu), super_index=[1192], edge_attr=[9158, 7], y=[1192, 14], pos=[1192, 3], log_length=[1192, 1], log_size=[1192, 1], log_surface=[1192, 1], log_volume=[1192, 1], normal=[1192, 3])

Level-2:
Data(edge_index=[2, 7232], sub=Cluster(num_clusters=501, num_points=1192, device=cpu), super_index=[501], edge_attr=[7232, 7], y=[501, 14], pos=[501, 3], log_length=[501, 1], log_size=[501, 1], log_surface=[501, 1], log_volume=[501, 1], normal=[501, 3])

Level-3:
Data(edge_index=[2, 2545], sub=Cluster(num_clusters=166, num_points=501, device=cpu), edge_attr=[2545, 7], y=[166, 14], pos=[166, 3], log_length=[166, 1], log_size=[166, 1], log_surface=[166, 1], log_volume=[166, 1], normal=[166, 3])



## Visualizing a NAG

In [None]:
# Visualize the hierarchical partition of the whole scene.
# The class-related arguments forward the S3DIS label metadata imported in the
# first cell (presumably used to render the `y` labels — confirm).
# NOTE(review): judging from the names, `max_points` presumably caps the
# number of displayed points, `centroids` toggles drawing the superpoint
# centroids and `h_edge` toggles drawing the graph edges; check the
# `NAG.show()` docstring.
nag.show( 
    class_names=CLASS_NAMES,
    class_colors=CLASS_COLORS, 
    stuff_classes=STUFF_CLASSES,
    num_classes=NUM_CLASSES,
    max_points=100000,
    centroids=True,
    h_edge=True
)

## Selecting a portion of the hierarchical partition
The NAG structure can be subselected using `nag.select()`.

This function expects an `int` specifying the partition level from which we should select, along with an index or a mask in the form of a `list`, `numpy.ndarray`, `torch.Tensor`, or `slice`.
This index/mask describes which nodes to select at the specified level.

The output NAG will only contain children, parents and edges of the selected nodes.

In [6]:
# Pick a center and radius for the spherical sample
center = nag[0].pos.mean(dim=0)
radius = 1  # same units as `pos`

# Indices of the level-0 nodes (i.e. points) lying within `radius` of
# `center`, used for indexing the NAG structure (`nag.select()` accepts an
# index as well as a boolean mask)
mask = torch.where(torch.linalg.norm(nag[0].pos - center, dim=1) < radius)[0]

# The mean position of a scene can fall in empty space (e.g. the middle of a
# room), in which case nothing is selected: fail early with an explicit
# message rather than passing an empty selection to `nag.select()`
assert mask.numel() > 0, (
    f"No level-0 point within radius={radius} of center={center.tolist()}; "
    "increase `radius` or choose another `center`.")

# Subselect the hierarchical partition based on the level-0 selection
nag_visu = nag.select(0, mask)

In [None]:
# Visualize the spherical sample `nag_visu`, with the same options as the
# full-scene visualization (class metadata, centroids and edges shown)
nag_visu.show(
    class_names=CLASS_NAMES,
    class_colors=CLASS_COLORS, 
    stuff_classes=STUFF_CLASSES,
    num_classes=NUM_CLASSES,
    max_points=100000,
    centroids=True,
    h_edge=True
)

In [None]:
# Visualize the same spherical sample, this time with `centroids=False`
# (all other options unchanged)
nag_visu.show(
    class_names=CLASS_NAMES,
    class_colors=CLASS_COLORS, 
    stuff_classes=STUFF_CLASSES,
    num_classes=NUM_CLASSES,
    max_points=100000,
    centroids=False,
    h_edge=True
)

In [None]:
# Visualize the same spherical sample without passing any class metadata
# (the class-related arguments are commented out on purpose)
nag_visu.show(
    #class_names=CLASS_NAMES,
    #class_colors=CLASS_COLORS, 
    #stuff_classes=STUFF_CLASSES,
    #num_classes=NUM_CLASSES,
    max_points=100000,
    centroids=False,
    h_edge=True
)

**Tip** - the above example illustrates the `nag.select()` method, which is not limited to spherical sampling masks. However, since visualizing a spherical sample of a large cloud is a common operation, the `show()` function lets you do the same as above by specifying a `radius` and a `center`. See the `show()` documentation for more details.

In [7]:
# Load a NAG computed on another scene; the commented-out line is an
# alternative input file.
# NOTE(review): these are absolute, machine-specific paths — adapt them to
# your setup. `weights_only=False` is needed to unpickle the custom `NAG`
# class on PyTorch >= 2.6 (only load trusted files this way).
# nag = torch.load("/workspace/superpoint_transformer/data/kitti360gs/processed/train/6d12657d4d93f80b2cc65f0c502012c5/scenes/kitchen.pt", weights_only=False)
nag = torch.load("/workspace/superpoint_transformer/supergaussians/kitchen.pt", weights_only=False)

In [4]:
# Summarize each partition level of the newly loaded NAG: a "Level-<i>:"
# header, the attributes of its `Data` object, then a blank separator line
for i_level, data in enumerate(nag):
    print(f"Level-{i_level}:", data, "", sep="\n")

Level-0:
Data(pos=[10565, 3], pos_offset=[3], rgb=[10565, 3], f_dc=[10565, 3], f_rest=[10565, 45], scales=[10565, 3], rots=[10565, 4], sub=Cluster(num_clusters=10565, num_points=241367, device=cpu), linearity=[10565, 1], planarity=[10565, 1], scattering=[10565, 1], verticality=[10565, 1], curvature=[10565, 1], length=[10565, 1], volume=[10565, 1], normal=[10565, 3], super_index=[10565])

Level-1:
Data(pos=[354, 3], sub=Cluster(num_clusters=354, num_points=10565, device=cpu), node_size=[354], super_index=[354], log_length=[354, 1], log_surface=[354, 1], log_volume=[354, 1], normal=[354, 3], log_size=[354, 1], edge_index=[2, 5621], edge_attr=[5621, 7])

Level-2:
Data(pos=[118, 3], sub=Cluster(num_clusters=118, num_points=354, device=cpu), node_size=[118], super_index=[118], log_length=[118, 1], log_surface=[118, 1], log_volume=[118, 1], normal=[118, 3], log_size=[118, 1], edge_index=[2, 1864], edge_attr=[1864, 7])

Level-3:
Data(pos=[24, 3], sub=Cluster(num_clusters=24, num_points=118, d

In [10]:
# Optional spherical subsampling of the scene, kept commented out.
# NOTE(review): even if uncommented, the `nag.show()` call below still
# displays the full `nag`; call `nag_visu.show(...)` instead to visualize the
# sample.
# # Pick a center and radius for the spherical sample
# center = nag[0].pos.mean(dim=0)
# radius = 3

# # Create a mask on level-0 (i.e. points) to be used for indexing the NAG 
# # structure
# mask = torch.where(torch.linalg.norm(nag[0].pos - center, dim=1) < radius)[0]

# # Subselect the hierarchical partition based on the level-0 mask
# nag_visu = nag.select(0, mask)

# Visualize the hierarchical partition of the loaded scene, without any class
# metadata and without centroids
nag.show( 
    max_points=100000,
    centroids=False,
    h_edge=True
)