# 1. Reading and visualizing raw point clouds using `Data` objects

### 1.1. Preparing a `Data` reader

In [1]:
import os
import sys

# Add the project's files to the Python path.
# `os.path.abspath('')` is the notebook's working directory; its parent is
# assumed to be the project root holding the `src` package imported below
# (TODO confirm when running the notebook from another location).
file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook
# Guard against accumulating duplicate entries when the cell is re-executed
if file_path not in sys.path:
    sys.path.append(file_path)

import torch
from plyfile import PlyData
from src.data import Data
from src.utils.color import to_float_rgb

def read_ply_file(
        filepath,
        xyz=True,
        rgb=True,
        intensity=False,
        semantic=False,
        instance=False):
    """Read a PLY file into a `Data` object.

    Optional attributes are only populated when their flag is True AND the
    corresponding vertex properties exist in the file. Coordinates are read
    without a presence check when ``xyz`` is True.

    :param filepath: str
        Path to the PLY file
    :param xyz: bool
        Whether XYZ coordinates (``x``, ``y``, ``z`` vertex properties)
        should be saved in the output Data.pos, as a float32 tensor of
        shape (N, 3)
    :param rgb: bool
        Whether RGB colors (``red``, ``green``, ``blue`` vertex properties)
        should be saved in the output Data.rgb. Values are divided by 255
        and then passed through `to_float_rgb`
    :param intensity: bool
        Whether intensity (``intensity`` vertex property) should be saved
        in the output Data.intensity, as a float32 tensor
    :param semantic: bool
        Whether semantic labels (``label`` vertex property) should be saved
        in the output Data.y, as an int64 tensor
    :param instance: bool
        Whether instance labels (``instance`` vertex property) should be
        saved in the output Data.obj, as an int64 tensor
    :return: Data
        The populated `Data` object
    """
    # Local import so this function does not depend on notebook cell order
    import numpy as np

    def _to_tensor(values, np_dtype, torch_dtype):
        # PLY files can be stored big-endian, in which case plyfile returns
        # arrays in the file's byte order, and torch does not support
        # converting non-native byte orders. Casting to a native numpy dtype
        # first makes the conversion byte-order safe (no-op otherwise).
        return torch.tensor(
            np.asarray(values, dtype=np_dtype), dtype=torch_dtype)

    # Create an empty Data object
    data = Data()

    plydata = PlyData.read(filepath)
    vertices = plydata['vertex']

    # Populate data with point coordinates
    if xyz:
        data.pos = torch.stack([
            _to_tensor(vertices[axis], np.float32, torch.float32)
            for axis in ('x', 'y', 'z')
        ], dim=-1)

    # Populate data with point RGB colors
    rgb_axes = ('red', 'green', 'blue')
    if rgb and all(axis in vertices for axis in rgb_axes):
        data.rgb = to_float_rgb(torch.stack([
            _to_tensor(vertices[axis], np.float32, torch.float32) / 255
            for axis in rgb_axes
        ], dim=-1))

    # Populate data with point LiDAR intensity
    if intensity and 'intensity' in vertices:
        data.intensity = _to_tensor(
            vertices['intensity'], np.float32, torch.float32)

    # Populate data with point semantic segmentation labels
    if semantic and 'label' in vertices:
        data.y = _to_tensor(vertices['label'], np.int64, torch.long)

    # Populate data with point panoptic segmentation labels
    if instance and 'instance' in vertices:
        data.obj = _to_tensor(vertices['instance'], np.int64, torch.long)

    return data


Often, we need to remap the raw labels provided in a dataset to another set of labels to be used for training. 
In the next cell, we define some constants for remapping the raw class indices of our point cloud (which here follow the 13 S3DIS classes), along with corresponding customized class names and colors for downstream visualization.

> **Tip 💡**: As described in our [datasets documentation](../docs/datasets.md/#semantic-label-format) we consider labels in `[0, num_classes - 1]` to be valid classes and use the `num_classes` label for void/ignored/unlabeled points (whichever you call it). Check out the [documentation](../docs/datasets.md/#semantic-label-format) for more details.

In [2]:
import numpy as np

# Number of classes in the dataset (excluding void/unlabeled/ignored)
PLY_NUM_CLASSES = 13

# Class names indexed by training label; the void/unlabeled/ignored class
# comes last, at index PLY_NUM_CLASSES
PLY_CLASS_NAMES = [
    'ceiling', 'floor', 'wall', 'beam', 'column',
    'window', 'door', 'chair', 'table', 'bookcase',
    'sofa', 'board', 'clutter', 'ignored']

# Mapping from original label indices to training label indices. The raw
# labels already follow the order of PLY_CLASS_NAMES, so this is the
# identity over [0, PLY_NUM_CLASSES]: 0 (ceiling) ... 13 (ignored)
ID2TRAINID = np.arange(PLY_NUM_CLASSES + 1)

# 0-255 RGB color palette, one row per entry of PLY_CLASS_NAMES (same order)
PLY_CLASS_COLORS = np.asarray([
    [233, 229, 107],  # ceiling
    [ 95, 156, 196],  # floor
    [179, 116,  81],  # wall
    [241, 149, 131],  # beam
    [ 81, 163, 148],  # column
    [ 77, 174,  84],  # window
    [108, 135,  75],  # door
    [ 41,  49, 101],  # chair
    [ 79,  79,  76],  # table
    [223,  52,  52],  # bookcase
    [ 89,  47,  95],  # sofa
    [ 81, 109, 114],  # board
    [233, 233, 229],  # clutter
    [  0,   0,   0],  # ignored
])


### 1.2. `Data` visualization

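We can now read our point cloud into a `Data` object. Here we load a local PLY map (set `filepath` to your own file). Any PLY file exposing `x`, `y`, `z` vertex properties can be used, e.g. tiles from [Vancouver LiDAR 2022](https://opendata.vancouver.ca/explore/dataset/lidar-2022/map/?location=12,49.25672,-123.14434) after conversion to PLY.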

In [3]:
# Path to the input PLY point cloud; point this to your own file
filepath = '/home/yuanyan/autonomous-exploration-with-lio-sam/maps/test.ply'

# Load coordinates and colors only; intensity and labels are not read
data = read_ply_file(
    filepath,
    xyz=True,
    rgb=True,
    intensity=False,
    semantic=False,
    instance=False,
)

We have created a `Data` object containing our point cloud and associated attributes. 
Let's have a closer look at it!

The basic `Data.__repr__()` will show the attributes (i.e. keys) in `Data` and their respective shapes.

In [None]:
# Display the Data summary (attribute names and their shapes); kept as the
# cell's last bare expression so the notebook renders its repr
data

You can check the number of points (i.e. nodes) in a `Data` object with `data.num_points` (or `data.num_nodes`).

In [None]:
# Number of points (nodes) held in `data`, displayed as the cell output
data.num_points

You can check the list of attributes stored in a `Data` object with `data.keys`.

In [None]:
# Names of the attributes stored in `data`, displayed as the cell output
data.keys

We provide a [Plotly](https://plotly.com/python)-based tool for visualizing `Data` objects. To use it, simply call `data.show()`. This function offers many options for customizing your plot. We will see later on that it can also be used for visualizing hierarchical superpoint partitions held in `NAG` objects.

First, let's visualize the whole point cloud contained in `Data` (this may take a couple of seconds if your cloud has $\sim10^5$ points or more).
We can specify our `class_names` and `class_colors` to `show()` to customize the displaying of semantic segmentation labels.

In [None]:
# Plot the whole cloud, using our own class names and palette for the
# semantic labels
data.show(
    class_names=PLY_CLASS_NAMES,
    class_colors=PLY_CLASS_COLORS,
)

By default, the point cloud is subsampled to `max_points=50000` points to reduce visualization time.
To get a clearer, high-resolution view, you can increase `max_points` or visualize smaller scenes.
You can, for instance, display only a spherical crop of the point cloud by specifying a `center` and a `radius`.

In [None]:
# Display only a spherical crop: the points within `radius` of `center`
# (both in the cloud's coordinate units), requesting the 'intensity'
# attribute via `keys` and using our own class names and palette.
# NOTE(review): `data` was loaded with `read_ply_file(filepath)`, whose
# `intensity` flag defaults to False, so `data.intensity` is never set;
# confirm `show()` tolerates the missing 'intensity' key, or read the file
# with `intensity=True`.
data.show(center=[-17, 36, 0], radius=30, keys=['intensity'], class_names=PLY_CLASS_NAMES, class_colors=PLY_CLASS_COLORS)

> **Tips 💡**
> - More info on our `Data` structure ? 👉 see [`docs/data_structures.md`](../docs/data_structures.md), our source code in `src.data.data`, and the [PyG Data documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) it builds upon
> - More info on our `show()` visualization tool ? 👉 see [`docs/visualization.md`](../docs/visualization.md) and the source code in `src.visualization`

# 2. Using a pretrained model for inference

We provide pretrained weights and preprocessing parametrization for several datasets (see [README](../README.md) and [datasets documentation](../docs/datasets.md)). Since we are interested in S3DIS-style classes (walls, columns, doors, etc.), we would like to check how an S3DIS-pretrained SPT would fare on our present `Data` object.

As mentioned in the [introductory slides](../media/superpoint_transformer_tutorial.pdf), running an inference with a pretrained SPT requires more than just the model weights. Indeed, we also need to apply to our `Data` the same `pre_transform` and `on_device_transform` as the ones used for training the model.

### 2.1. Instantiating transforms from `configs/`

We will first need to recover the transforms used in the S3DIS experiments, as provided in `configs/experiment`, using [Hydra](https://hydra.cc/docs/intro/). 
In the next cell, we show how to use the `init_config()` utility to get the **exact configuration used for training the released S3DIS model**.

> **Tips 💡**
> - More info on how `configs/` & [Hydra](https://hydra.cc/docs/intro/) work ? 👉 see the [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template) repository
> - More info on a specific experiment's settings ? 👉 explore our configuration files in `configs/`, these are fairly commented 😉

In [9]:
from src.utils import init_config

# Load the Hydra configuration of the S3DIS semantic segmentation experiment
# (the `experiment=semantic/s3dis` override selects it). Plain string: the
# override has no placeholders, so an f-string is unnecessary.
cfg = init_config(overrides=["experiment=semantic/s3dis"])

This `cfg` is an [omegaconf](https://omegaconf.readthedocs.io) `DictConfig` object. It contains all the necessary hyperparameters for reproducing the pretraining experiment: dataset, model structure, training recipe, etc. We can explore its content just like a basic dictionary, or a simple object.

In [None]:
# List the top-level sections of the experiment configuration
cfg.keys()

The parametrization of the transforms is specified in the datamodule config in `cfg.datamodule`.
We can instantiate the transforms from an [omegaconf](https://omegaconf.readthedocs.io) `DictConfig` object without instantiating the whole dataset by using the `instantiate_datamodule_transforms()` utility.

In [None]:
from src.transforms import instantiate_datamodule_transforms

# Build only the transforms described in `cfg.datamodule` (without
# instantiating the whole dataset) and display the resulting dict
transforms_dict = instantiate_datamodule_transforms(cfg.datamodule)
transforms_dict

The transforms are chained operations applied to a `Data` or a `NAG` object. Their order and parametrization play a significant role, and modifying them may have non-negligible downstream effects. **These must be thought of as part of the model itself**.

### 2.2. Applying transforms

As explained in the [introductory slides](../media/superpoint_transformer_tutorial.pdf), we will be using `pre_transform` and `on_device_test_transform` to reproduce the behavior of the pretrained model at inference time.

> **Note 🤓**: In the next cell, we manually apply `NAGRemoveKeys()` transforms after the `pre_transform`. This is because we occasionally need to mimic the full behavior of the pretraining `Dataset`: after the `pre_transform` is executed, the preprocessed `NAG` is saved to disk. When later read from disk by the `Dataset`, only the `point_load_keys` attributes of `NAG[0]` and the `segment_load_keys` attributes of `NAG[i], i>0` are loaded. This mechanism ensures we only load what is strictly necessary during training, hence saving I/O time. Since we are running the `pre_transform` manually here, we need to account for this mechanism and discard the preprocessed attributes that the S3DIS dataset did not read from disk. The attributes that are kept can be found in `cfg.datamodule.point_load_keys` and `cfg.datamodule.segment_load_keys`.

In [12]:
from src.transforms import NAGRemoveKeys

# Preprocess the raw point cloud into a NAG (points + superpoint hierarchy)
pre_transform = transforms_dict['pre_transform']
nag = pre_transform(data)

# Mimic the Dataset I/O: once preprocessed, only `point_load_keys` are read
# back for level 0 and only `segment_load_keys` for levels 1+, so every other
# attribute is discarded. Level-1+ keys are collected from level 1 only and
# the removal is applied to all levels >= 1.
point_keep = cfg.datamodule.point_load_keys
dropped_point_keys = [key for key in nag[0].keys if key not in point_keep]
nag = NAGRemoveKeys(level=0, keys=dropped_point_keys)(nag)

segment_keep = cfg.datamodule.segment_load_keys
dropped_segment_keys = [key for key in nag[1].keys if key not in segment_keep]
nag = NAGRemoveKeys(level='1+', keys=dropped_segment_keys)(nag)

# Move to the GPU, then run the on-device test-time transforms
nag = nag.cuda()
on_device_transform = transforms_dict['on_device_test_transform']
nag = on_device_transform(nag)

The output of the transforms is no longer a `Data` object, but a `NAG`. This is the data structure we use to carry around **point clouds** and **hierarchical superpoint partitions**. 

Essentially, it is a list of `Data` objects, each representing a partition level:
- `nag[0]` is $P_0$, the (voxelized) points
- `nag[i]` is $P_i$, the $\text{i}^\text{th}$ superpoint partition level 

At each level $i>0$, the `edge_index` and `edge_attr` attributes carry the **superpoint adjacency graph** and corresponding **adjacency features**.

> **Tip 💡** More info on our `NAG` structure ? 👉 see [`docs/data_structures.md`](../docs/data_structures.md) and source code in `src.data.nag`

Now we have preprocessed our data, we need to run an inference with the pretrained model.

> **Tip 💡**: If you want to store your progress on disk, both `Data` and `NAG` have `.save()` and `.load()` methods specially designed with fast I/O and disk usage in mind 😉.

### 2.3. Instantiating a pretrained model from `configs/` and a `*.ckpt`

Similar to the transforms, we will use the S3DIS experiment configuration files to instantiate the **pretrained model**. 
This time, the part of the [omegaconf](https://omegaconf.readthedocs.io) `DictConfig` object we are interested in is stored under `cfg.model`.

As stated in the [README](../README.md), the pretrained weights for our models can be recovered from [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.8042712.svg)](https://doi.org/10.5281/zenodo.8042712).

In [None]:
import os

import hydra
from src.utils import init_config

# Path to the checkpoint file downloaded from https://zenodo.org/records/8042712
ckpt_path = "/home/yuanyan/Documents/superpoint_transformer/models/spt-2_s3dis_fold6.ckpt"

# Fail early with an explicit message rather than deep inside model loading
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

# S3DIS experiment configuration; the checkpoint (spt-2_s3dis_fold6) is
# expected to come from this same experiment so that the instantiated
# architecture matches the stored weights. Plain string: no placeholders.
cfg = init_config(overrides=["experiment=semantic/s3dis"])

# Instantiate the model and load pretrained weights
model = hydra.utils.instantiate(cfg.model)
model = model._load_from_checkpoint(ckpt_path)

### 2.4. Applying SPT

Now everything is ready for running our inference!

In [14]:
# Switch to evaluation mode, then move the model onto the NAG's device
model.eval()
model = model.to(nag.device)

# Run inference without gradient tracking; the result is a task-specific
# output object carrying the predictions
with torch.no_grad():
    output = model(nag)

The output of the model is a `SemanticSegmentationOutput` object. It is a simple class dedicated to holding onto predictions in `output.semantic_pred()` and facilitating certain basic post-processing operations such as metrics computation. 

In [None]:
# Display the shape of the prediction tensor next to `nag.num_points`, for
# comparison (predictions are made at the superpoint level, not per point)
output.semantic_pred().shape, nag.num_points

As stated in [introductory slides](../media/superpoint_transformer_tutorial.pdf), it is important to remember that, by default, **SPT outputs predictions on the $P_1$ level** (ie `nag[1]`). Since the superpoints $P_1$ are assumed to be semantically pure, simply classifying those is equivalent to classifying each point in the scene. In doing so, we save a lot of computation and memory during training.

Yet, at inference time, we often want the predictions at the voxel level $P_0$ (ie `nag[0]`) or even at the full-resolution of the raw input cloud. 
To this end, we simply need to distribute the $P_1$ predictions to the lower partition levels.
The `SemanticSegmentationOutput.voxel_semantic_pred()` and `SemanticSegmentationOutput.full_res_semantic_pred()` methods were designed just for that!

In the next cell, we will convert $P_1$ predictions into $P_0$ predictions.

> **Tip 💡**: For **full-resolution predictions**, see our [`demo.ipynb` notebook](../notebooks/demo.ipynb), and have a look at [`src.utils.output_semantic.py`](../src/utils/output_semantic.py#L140). Remember that if you have applied a tiling to your data, your full-resolution predictions will be given for the tile at hand and not the original point cloud.

> **Note 🤓**: Although SPT does make predictions as $P_1$ node classifications, all losses and metrics are properly computed so as to take into account the true labels assigned to full-resolution points. To make these efficient, our pipeline always tracks the **histogram of ground truth labels** for each voxel in $P_0$ and superpoint in $P_i, i>0$.

In [16]:
# Compute the level-0 (voxel-wise) semantic segmentation predictions
# from the predictions on level-1 superpoints. `super_index` presumably
# gives each voxel's parent level-1 superpoint (TODO confirm). The result
# is stored on the level-0 Data under the 'semantic_pred' attribute for
# visualization.
nag[0].semantic_pred = output.voxel_semantic_pred(super_index=nag[0].super_index)

Let's visualize the resulting predictions on a small area for high-resolution display.

Note that the predicted labels follow the S3DIS classes the model was trained on. 
For visualization, we therefore use the S3DIS `CLASS_NAMES` and `CLASS_COLORS`.

In [None]:
from src.datasets.s3dis_room import CLASS_COLORS, CLASS_NAMES

# The model was trained with the S3DIS experiment config, so its predicted
# labels follow the S3DIS classes: display them with the S3DIS names/palette
print(CLASS_NAMES)

# Plot the points within a radius of 100 around the origin
nag.show(class_names=CLASS_NAMES, class_colors=CLASS_COLORS, center=[0, 0, 0], radius=100)

### 2.5. Post processing

To adapt the pretrained model to a simple construction-site environment, we optionally remap colors: when `simple_env` is `True`, points predicted as one of the selected classes (wall, beam, column, window, door, bookcase, board) are colored red, and all other points grey.

The function then saves the segmented point cloud to a PLY file, with points colored according to their labels, for waypoint generation.

In [None]:
import open3d as o3d
import numpy as np


def save_segmented_ply(data, file_path, simple_env=False, *,
                       class_names=None, class_colors=None):
    """Save a segmented point cloud to a PLY file, colored by predicted label.

    :param data: Data
        Object holding point coordinates in ``data.pos`` and per-point
        integer predictions in ``data.semantic_pred`` (torch tensors,
        possibly on GPU)
    :param file_path: str
        Destination PLY path
    :param simple_env: bool
        If True, points predicted as wall, beam, column, window, door,
        bookcase or board are painted red and all other points grey.
        Otherwise each point takes the palette color of its class
    :param class_names: list of str, optional
        Class names indexed by label, used to resolve the red classes when
        ``simple_env`` is True. Defaults to ``PLY_CLASS_NAMES``
    :param class_colors: array-like of shape (num_classes, 3), optional
        0-255 RGB palette indexed by label, used when ``simple_env`` is
        False. Defaults to ``PLY_CLASS_COLORS``
    :raises OSError:
        If Open3D reports that the file could not be written
    """
    # Resolve defaults at call time so redefining the globals is honored
    if class_names is None:
        class_names = PLY_CLASS_NAMES
    if class_colors is None:
        class_colors = PLY_CLASS_COLORS

    # Open3D's Vector3dVector expects float64 arrays of shape (n, 3)
    points = data.pos.cpu().numpy().astype(np.float64)
    labels = data.semantic_pred.cpu().numpy()

    if simple_env:
        # Remap specific classes to red, all others to grey
        red_color = np.array([255, 0, 0])
        grey_color = np.array([128, 128, 128])
        remap_classes = ['wall', 'beam', 'column', 'window', 'door', 'bookcase', 'board']
        remap_indices = [class_names.index(cls) for cls in remap_classes]
        colors = np.where(
            np.isin(labels, remap_indices)[:, None], red_color, grey_color
        )
    else:
        # Look up each point's class color in the palette
        colors = np.asarray(class_colors)[labels]

    # Create an Open3D point cloud object
    pcl = o3d.geometry.PointCloud()
    pcl.points = o3d.utility.Vector3dVector(points)
    pcl.colors = o3d.utility.Vector3dVector(colors / 255.0)  # Normalize colors to [0, 1]

    # write_point_cloud reports failure through its boolean return value
    # rather than raising, so check it instead of always printing success
    if not o3d.io.write_point_cloud(file_path, pcl):
        raise OSError(f"Failed to write segmented point cloud to {file_path}")
    print(f"Segmented point cloud saved to {file_path}")

# Example usage
output_path = '/home/yuanyan/autonomous-exploration-with-lio-sam/segmented_output.ply'  # Replace with desired file path
save_segmented_ply(data=nag[0], file_path=output_path, simple_env=True)
