In [1]:
# poss_dataset.py
import numpy as np
import yaml
import open3d.ml.torch as ml3d

# The 17 POSS class names, listed in id order; each name's position in the
# tuple (0..16) becomes its integer key in the resulting dict.
POSS_LABELS = dict(enumerate((
    "unlabeled",
    "1 person",
    "2+ person",
    "rider",
    "car",
    "trunk",
    "plants",
    "traffic sign 1",
    "traffic sign 2",
    "traffic sign 3",
    "pole",
    "trashcan",
    "building",
    "cone/stone",
    "fence",
    "bike",
    "ground",
)))

class POSSDataset(ml3d.datasets.SemanticKITTI):
    """SemanticKITTI-style dataset that uses the POSS label set.

    The parent class is initialised first; afterwards the label names, the
    class count and both label lookup tables are replaced with values derived
    from ``POSS_LABELS`` and from the ``learning_map`` / ``learning_map_inv``
    sections of the YAML file at ``poss_yaml_path``.

    Attributes set by this class:
        label_to_names: copy of ``POSS_LABELS`` (id -> class name).
        num_classes: number of entries in ``label_to_names``.
        remap_lut: int32 array built from ``learning_map_inv``
            (train_id -> raw_id).
        remap_lut_val: int32 array built from ``learning_map``
            (raw_id -> train_id).
    """

    def __init__(
        self,
        dataset_path: str,
        poss_yaml_path: str,  # path to poss.yaml
        name: str = "poss",
        cache_dir: str = "./logs/cache_poss",
        use_cache: bool = False,
        class_weights=None,
        ignored_label_inds=None,
        test_result_folder=None,
        test_split=None,
        training_split=None,
        validation_split=None,
        all_split=None,
        **kwargs,
    ):
        """Initialise the parent dataset, then load the POSS label maps.

        Args:
            dataset_path: forwarded to the parent as ``dataset_path``.
            poss_yaml_path: YAML file that must contain the mappings
                ``learning_map`` and ``learning_map_inv``.
            name: forwarded to the parent as ``name``.
            cache_dir: forwarded to the parent as ``cache_dir``.
            use_cache: forwarded to the parent as ``use_cache``.
            class_weights: per-class weights; ``None`` is forwarded as an
                empty list.
            ignored_label_inds: label ids to ignore; ``None`` (or an empty
                value) is forwarded as an empty list.
            test_result_folder: output folder; ``None`` (or an empty value)
                is forwarded as ``"./test"``.
            test_split: sequences for testing, default ``["05"]``.
            training_split: sequences for training, default ``"00"``-``"03"``.
            validation_split: sequences for validation, default ``["04"]``.
            all_split: all sequences, default ``"00"``-``"05"``.
            **kwargs: forwarded unchanged to the parent constructor.

        Raises:
            TypeError: if the YAML document is not a mapping (for example an
                empty file).
            KeyError: if ``learning_map`` or ``learning_map_inv`` is missing.
        """
        # NOTE(review): None is normalised for class_weights and
        # test_result_folder (as was already done for ignored_label_inds)
        # because Open3D-ML code appears to call len() on class_weights and
        # to join paths onto test_result_folder -- confirm against the
        # installed Open3D-ML version. "./test" is believed to be the
        # parent's own default.
        super().__init__(
            dataset_path=dataset_path,
            name=name,
            cache_dir=cache_dir,
            use_cache=use_cache,
            class_weights=[] if class_weights is None else class_weights,
            ignored_label_inds=ignored_label_inds or [],
            test_result_folder=test_result_folder or "./test",
            training_split=training_split or ["00", "01", "02", "03"],
            validation_split=validation_split or ["04"],
            test_split=test_split or ["05"],
            all_split=all_split or ["00", "01", "02", "03", "04", "05"],
            **kwargs,
        )

        self._poss_yaml_path = poss_yaml_path
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(self._poss_yaml_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        # safe_load returns None for an empty file; fail with a clear message
        # instead of an opaque "NoneType is not subscriptable".
        if not isinstance(config, dict):
            raise TypeError(
                f"{poss_yaml_path!r} must contain a YAML mapping, "
                f"got {type(config).__name__}"
            )
        missing = [
            key for key in ("learning_map_inv", "learning_map")
            if key not in config
        ]
        if missing:
            raise KeyError(
                f"{poss_yaml_path!r} is missing required key(s): {missing}"
            )

        # Replace whatever the parent constructor set with the POSS values.
        self.label_to_names = self.get_label_to_names()
        self.num_classes = len(self.label_to_names)

        # learning_map_inv: train_id -> raw_id
        self.remap_lut = self._build_lut(config["learning_map_inv"])
        # learning_map: raw_id -> train_id
        self.remap_lut_val = self._build_lut(config["learning_map"])

    @staticmethod
    def _build_lut(mapping, padding=100):
        """Return an int32 lookup array where ``lut[key] == value``.

        The array has ``max(key) + padding`` entries (``padding`` entries when
        the mapping is empty or ``None``); every index that is not a key of
        ``mapping`` holds 0.
        """
        mapping = mapping or {}
        lut = np.zeros(max(mapping, default=0) + padding, dtype=np.int32)
        if mapping:
            lut[list(mapping)] = list(mapping.values())
        return lut

    @staticmethod
    def get_label_to_names():
        """Return a new dict copy of ``POSS_LABELS`` (id -> class name)."""
        return dict(POSS_LABELS)


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
