In [1]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are re-imported
# before each cell execution and edits to imported code take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

# SequencePointcloudRegistrationPipeline

A module that implements an algorithm for optimizing the position and orientation of a vehicle in space based on a sequence of multimodal data using neural network methods.

In [2]:
import copy

from pathlib import Path
from time import time

import faiss
import numpy as np
import open3d as o3d
import open3d.core as o3c
import pandas as pd
import torch
import torchshow as ts

from hydra.utils import instantiate
from omegaconf import OmegaConf
from scipy.spatial.transform import Rotation
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from geotransformer.utils.registration import compute_registration_error


from opr.datasets import NCLTDataset
from opr.pipelines.place_recognition import PlaceRecognitionPipeline
from opr.pipelines.registration.pointcloud import SequencePointcloudRegistrationPipeline
from opr.pipelines.localization import LocalizationPipeline

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


  check_for_updates()


In [3]:
import os

# Point GUI clients at X display ":1".
# NOTE(review): presumably required for the Open3D visualisation windows in this environment — confirm.
os.environ["DISPLAY"] = ":1"

# NOTE(review): no GPU work has been submitted by this notebook at this point; presumably this call is
# only here to initialise CUDA early / check that a GPU is usable — confirm.
torch.cuda.synchronize()

In [4]:
def pose_to_matrix(pose):
    """Convert a pose vector into a 4x4 homogeneous transformation matrix.

    Args:
        pose: Sequence of 7 numbers ``[tx, ty, tz, qx, qy, qz, qw]`` - a translation
            followed by a quaternion with the scalar component last.

    Returns:
        np.ndarray: 4x4 matrix whose upper-left 3x3 block is the rotation and whose
        last column holds the translation.
    """
    translation, quaternion = pose[:3], pose[3:]

    matrix = np.eye(4)
    matrix[:3, :3] = Rotation.from_quat(quaternion).as_matrix()
    matrix[:3, 3] = translation
    return matrix


# def compute_error(estimated_pose, gt_pose):
#     """For the 6D poses in the [tx ty tz qx qy qz qw] format."""
#     estimated_pose = pose_to_matrix(estimated_pose)
#     gt_pose = pose_to_matrix(gt_pose)
#     return compute_registration_error(estimated_pose, gt_pose)

def compute_error(estimated_pose, gt_pose):
    """Compute rotation and translation error between two poses.

    Both poses are vectors in the ``[tx ty tz qx qy qz qw]`` format.

    Args:
        estimated_pose: Pose produced by the pipeline.
        gt_pose: Reference pose.

    Returns:
        tuple: ``(angle_error, dist_error)`` - the rotation error in degrees and the
        Euclidean norm of the translation part of the relative transform.
    """
    # Relative transform that maps the estimated pose onto the reference one.
    relative = np.linalg.inv(pose_to_matrix(estimated_pose)) @ pose_to_matrix(gt_pose)

    dist_error = np.sum(relative[:3, 3]**2) ** 0.5

    rotvec = Rotation.from_matrix(relative[:3, :3]).as_rotvec()
    angle_error = (np.sum(rotvec**2)**0.5) * 180 / np.pi
    # Fold the angle into [0, 90]: an error of `a` degrees and one of `180 - a` degrees give the same value.
    # NOTE(review): presumably intentional, to treat opposite headings as equivalent - confirm.
    angle_error = abs(90 - abs(angle_error-90))
    return angle_error, dist_error


def draw_pc(pc: Tensor, color: str = "blue"):
    """Show a single point cloud in an Open3D window, painted in one uniform colour.

    Args:
        pc: Point coordinates as a torch tensor (handed to Open3D through DLPack).
        color: ``"blue"`` or ``"red"``; any other value results in green.
    """
    named_colors = {
        "blue": [0.0, 0.0, 1.0],
        "red": [1.0, 0.0, 0.0],
    }
    rgb = named_colors.get(color, [0.0, 1.0, 0.0])

    points = o3c.Tensor.from_dlpack(torch.utils.dlpack.to_dlpack(pc))
    cloud = o3d.t.geometry.PointCloud(points).paint_uniform_color(rgb)
    o3d.visualization.draw_geometries(
        [cloud.to_legacy()],
    )


def invert_rigid_transformation_matrix(T: np.ndarray) -> np.ndarray:
    """
    Compute the inverse of a rigid body transformation without a general matrix inversion.

    For ``T = [R | t]`` the inverse is ``[R^T | -R^T t]``.

    Args:
        T (np.ndarray): A 4x4 rigid body transformation matrix.

    Returns:
        np.ndarray: The inverted 4x4 rigid body transformation matrix.
    """
    assert T.shape == (4, 4), "Input matrix must be 4x4."

    rotation_transposed = T[:3, :3].T
    translation = T[:3, 3]

    inverse = np.eye(4)
    inverse[:3, :3] = rotation_transposed
    inverse[:3, 3] = -rotation_transposed @ translation
    return inverse


def draw_pc_pair(
    pc_blue: Tensor,
    pc_blue_pose: np.ndarray | Tensor,
    pc_red: Tensor,
    pc_red_pose: np.ndarray | Tensor,
    voxel_size: float = 0.3,
):
    """Show two point clouds together, expressed in the frame of the blue cloud.

    The blue cloud is drawn as is. The red cloud is first transformed by its own pose and
    then by the inverse of the blue pose, so a well-estimated pair should overlap.

    Args:
        pc_blue: Points of the reference cloud (drawn in blue).
        pc_blue_pose: Pose of the blue cloud in the ``[tx ty tz qx qy qz qw]`` format.
        pc_red: Points of the second cloud (drawn in red).
        pc_red_pose: Pose of the red cloud in the ``[tx ty tz qx qy qz qw]`` format.
        voxel_size: Voxel size used to thin both clouds before drawing.
    """
    # The tensors are deep-copied before being shared with Open3D through DLPack so that
    # in-place operations below cannot touch the caller's data.
    pc_blue_o3d = o3c.Tensor.from_dlpack(torch.utils.dlpack.to_dlpack(copy.deepcopy(pc_blue)))
    pc_red_o3d = o3c.Tensor.from_dlpack(torch.utils.dlpack.to_dlpack(copy.deepcopy(pc_red)))

    # `voxel_down_sample` returns a NEW point cloud; the result has to be kept,
    # otherwise the full-resolution cloud is drawn.
    blue_pcd = o3d.t.geometry.PointCloud(pc_blue_o3d).voxel_down_sample(voxel_size=voxel_size)
    blue_pcd = blue_pcd.paint_uniform_color([0.0, 0.0, 1.0])

    red_pcd = o3d.t.geometry.PointCloud(pc_red_o3d).voxel_down_sample(voxel_size=voxel_size)
    # Move the red cloud into the blue cloud's frame: apply the red pose, then the inverse blue pose.
    red_pcd.transform(pose_to_matrix(pc_red_pose))
    red_pcd.transform(invert_rigid_transformation_matrix(pose_to_matrix(pc_blue_pose)))
    red_pcd = red_pcd.paint_uniform_color([1.0, 0.0, 0.0])

    o3d.visualization.draw_geometries(
        [blue_pcd.to_legacy(), red_pcd.to_legacy()],
    )


The preprocessed NCLT dataset used in this notebook can be downloaded from:

- Kaggle:
  - [NCLT_OpenPlaceRecognition](https://www.kaggle.com/datasets/creatorofuniverses/nclt-iprofi-hack-23)
- Hugging Face:
  - [NCLT_OpenPlaceRecognition](https://huggingface.co/datasets/OPR-Project/NCLT_OpenPlaceRecognition)


In [5]:
# Root directory of the preprocessed dataset; every sub-directory is treated as one track.
# NOTE(review): absolute, machine-specific path — adjust to your environment.
DATASET_ROOT = "/home/docker_opr/Datasets/OpenPlaceRecognition/NCLT_preprocessed"

# Data sources requested from the dataset (passed as `data_to_load`).
SENSOR_SUITE = ["image_Cam5", "image_Cam2", "pointcloud_lidar"]

# DataLoader settings and the device the models are moved to.
BATCH_SIZE = 32
NUM_WORKERS = 4
DEVICE = "cuda"

# Place recognition model: config and weights.
PR_MODEL_CONFIG_PATH = "../../configs/model/place_recognition/multi-image_lidar_late-fusion.yaml"
PR_WEIGHTS_PATH = "../../weights/place_recognition/multi-image_lidar_late-fusion_nclt.pth"

# Point cloud registration model: config and weights.
REGISTRATION_MODEL_CONFIG_PATH = "../../configs/model/registration/hregnet_light_feats.yaml"
REGISTRATION_WEIGHTS_PATH = "../../weights/registration/hregnet_light_feats_nuscenes.pth"

In [6]:
# Every sub-directory of the dataset root is one track; sort the names for a stable order.
TRACK_LIST = sorted(entry.name for entry in Path(DATASET_ROOT).iterdir() if entry.is_dir())
print(f"Found {len(TRACK_LIST)} tracks")
print(TRACK_LIST)

# Only the first two tracks are kept for the rest of the notebook.
print("WARNING: track list limited")
TRACK_LIST = TRACK_LIST[:2]
print(TRACK_LIST)

Found 10 tracks
['2012-01-08', '2012-01-22', '2012-02-12', '2012-02-18', '2012-03-31', '2012-05-26', '2012-08-04', '2012-10-28', '2012-11-04', '2012-12-01']
['2012-01-08', '2012-01-22']


## Init models

In [7]:
def _build_model(model_config, weights_path: str, device: str) -> torch.nn.Module:
    """Instantiate a model from its config, load its weights and switch it to eval mode.

    Args:
        model_config: Config object understood by `hydra.utils.instantiate`.
        weights_path: Path to the checkpoint holding the model's state dict.
        device: Device the model is moved to after the weights are loaded.

    Returns:
        The model on `device`, in evaluation mode.
    """
    model = instantiate(model_config)
    # Load the checkpoint onto the CPU first: without `map_location` the tensors are restored to the
    # device they were saved from, which breaks when that device does not exist on this machine.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.to(device).eval()


pr_model_config = OmegaConf.load(PR_MODEL_CONFIG_PATH)
pr_model = _build_model(pr_model_config, PR_WEIGHTS_PATH, DEVICE)

reg_model_config = OmegaConf.load(REGISTRATION_MODEL_CONFIG_PATH)
reg_model = _build_model(reg_model_config, REGISTRATION_WEIGHTS_PATH, DEVICE)

## Calculate descriptors for databases

In [8]:
dataset = NCLTDataset(
    dataset_root=DATASET_ROOT,
    subset="test",
    data_to_load=SENSOR_SUITE,
    pointcloud_quantization_size=0.5,
    max_point_distance=None,
)

# Keep only the selected tracks and renumber the rows from zero: the descriptor array computed
# from this dataset is later indexed with the row labels of this frame.
selected_tracks_mask = dataset.dataset_df["track"].isin(TRACK_LIST)
dataset.dataset_df = dataset.dataset_df[selected_tracks_mask].reset_index()

# Order must stay fixed (no shuffling) so that descriptors line up with the dataframe rows.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    collate_fn=dataset.collate_fn,
)


In [9]:
# Run the place recognition model over the whole dataset and collect one descriptor per sample.
descriptor_batches = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        device_batch = {key: value.to(DEVICE) for key, value in batch.items()}
        model_output = pr_model(device_batch)
        descriptor_batches.append(model_output["final_descriptor"].detach().cpu().numpy())

# Stack the per-batch arrays along the sample axis.
descriptors = np.concatenate(descriptor_batches, axis=0)

100%|██████████| 19/19 [00:04<00:00,  4.10it/s]


### Saving database indexes

In [10]:
dataset_df = dataset.dataset_df
descriptor_dim = descriptors.shape[1]

# Build one exact L2 index per track and store it inside that track's directory.
for track, row_labels in dataset_df.groupby("track").groups.items():
    index_path = f"{DATASET_ROOT}/{track}/index.faiss"

    track_index = faiss.IndexFlatL2(descriptor_dim)
    # Row labels double as positions in `descriptors` (relies on the frame having a 0..N-1 index).
    track_index.add(descriptors[row_labels])

    faiss.write_index(track_index, index_path)
    print(f"Saved index {index_path}")


Saved index /home/docker_opr/Datasets/OpenPlaceRecognition/NCLT_preprocessed/2012-01-08/index.faiss
Saved index /home/docker_opr/Datasets/OpenPlaceRecognition/NCLT_preprocessed/2012-01-22/index.faiss


# Test

In [11]:
# A database match / refined pose counts as correct when its translation error is below this value.
RECALL_THRESHOLD = 25.0

# Thresholds that only decide which samples are stored as illustrative examples;
# they have no influence on the reported metrics.
CORRECT_EXAMPLE_MAX_ROTATION_ERROR = 3.0  # "below median" cut-offs chosen by the original author
CORRECT_EXAMPLE_MAX_TRANSLATION_ERROR = 1.0
PR_FAILURE_MIN_TRANSLATION_ERROR = 50.0

POSE_COLUMNS = ["tx", "ty", "tz", "qx", "qy", "qz", "qw"]

# NOTE(review): `test_csv` is not used anywhere below in this cell; kept in case other cells rely on it.
test_csv = pd.read_csv(Path(DATASET_ROOT) / "test.csv", index_col=0)

all_pr_recalls = []
all_reg_recalls = []  # it is recall after registration (if estimated pose within RECALL_THRESHOLD), do not confuse with registration recall

all_mean_pr_rotation_errors = []
all_mean_pr_translation_errors = []

all_median_pr_rotation_errors = []
all_median_pr_translation_errors = []

all_mean_reg_rotation_errors = []
all_mean_reg_translation_errors = []

all_median_reg_rotation_errors = []
all_median_reg_translation_errors = []

all_times = []

correct_examples = []  # the most representative correct pairs
pr_incorrect_examples = []  # the most representative incorrect pairs where place recognition failed
reg_incorrect_examples = []  # the most representative incorrect pairs where registration failed

for db_track in TRACK_LIST:
    pr_pipe = PlaceRecognitionPipeline(
        database_dir=Path(DATASET_ROOT) / db_track,
        model=pr_model,
        model_weights_path=PR_WEIGHTS_PATH,
        device=DEVICE,
    )

    # The database side depends on `db_track` only, so it is prepared once per database track
    # instead of being rebuilt for every query track.
    db_dataset = copy.deepcopy(dataset)
    db_dataset.dataset_df = db_dataset.dataset_df[db_dataset.dataset_df["track"] == db_track].reset_index(drop=True)
    db_df = pd.read_csv(Path(DATASET_ROOT) / db_track / "track.csv", index_col=0)
    db_df = db_df[db_df['image'].isin(db_dataset.dataset_df['image'])].reset_index(drop=True)
    # Database samples and database poses are matched by position, so both tables must have the same length.
    assert len(db_df) == len(db_dataset.dataset_df), f"Pose table and dataset differ in length for database track {db_track}"

    for query_track in TRACK_LIST:
        if db_track == query_track:
            continue

        reg_pipe = SequencePointcloudRegistrationPipeline(
            model=reg_model,
            model_weights_path=REGISTRATION_WEIGHTS_PATH,
            device=DEVICE,
            voxel_downsample_size=0.3,
            num_points_downsample=8192,
        )
        loc_pipe = LocalizationPipeline(
            place_recognition_pipeline=pr_pipe,
            registration_pipeline=reg_pipe,
            precomputed_reg_feats=False,
            pointclouds_subdir="velodyne_data",
        )

        query_dataset = copy.deepcopy(dataset)
        query_dataset.dataset_df = query_dataset.dataset_df[query_dataset.dataset_df["track"] == query_track].reset_index(drop=True)
        query_df = pd.read_csv(Path(DATASET_ROOT) / query_track / "track.csv", index_col=0)
        query_df = query_df[query_df['image'].isin(query_dataset.dataset_df['image'])].reset_index(drop=True)
        # Query samples and query poses are matched by position (`q_i`) below.
        assert len(query_df) == len(query_dataset.dataset_df), f"Pose table and dataset differ in length for query track {query_track}"

        loc_pipe.pr_pipe.database_df = db_df
        loc_pipe.database_df = db_df

        pr_matches = []
        pr_rotation_errors = []
        pr_translation_errors = []

        reg_matches = []
        reg_rotation_errors = []
        reg_translation_errors = []

        times = []

        # NOTE(review): iteration starts at 1, so the first query frame is never evaluated; this looks like
        # a leftover from the two-frame sequence variant kept in the comment below — confirm.
        for q_i in tqdm(range(1, len(query_dataset))):
            # query_seq = [query_dataset[q_i-1], query_dataset[q_i]]
            query_seq = [query_dataset[q_i]]
            query = query_seq[-1]
            query_pose = query_df.iloc[q_i][POSE_COLUMNS].to_numpy()

            t = time()
            output = loc_pipe.infer(query_seq)
            times.append(time() - t)

            pr_rotation_error, pr_translation_error = compute_error(output["db_match_pose"], query_pose)
            reg_rotation_error, reg_translation_error = compute_error(output["estimated_pose"], query_pose)

            pr_correct = pr_translation_error < RECALL_THRESHOLD
            reg_correct = reg_translation_error < RECALL_THRESHOLD

            pr_matches.append(pr_correct)
            pr_rotation_errors.append(pr_rotation_error)
            pr_translation_errors.append(pr_translation_error)

            reg_matches.append(reg_correct)
            reg_rotation_errors.append(reg_rotation_error)
            reg_translation_errors.append(reg_translation_error)

            # Decide whether this sample is worth keeping as an example; the three cases are mutually exclusive.
            examples_bucket = None
            if pr_correct and reg_correct \
                and reg_rotation_error < pr_rotation_error and reg_translation_error < pr_translation_error \
                and reg_rotation_error < CORRECT_EXAMPLE_MAX_ROTATION_ERROR \
                and reg_translation_error < CORRECT_EXAMPLE_MAX_TRANSLATION_ERROR:
                # registration improved an already correct match and ended up very close to the reference pose
                examples_bucket = correct_examples
            elif pr_correct and not reg_correct:
                # registration moved a correct match outside the recall threshold
                examples_bucket = reg_incorrect_examples
            elif not pr_correct and pr_translation_error > PR_FAILURE_MIN_TRANSLATION_ERROR:
                # place recognition retrieved a clearly wrong place
                examples_bucket = pr_incorrect_examples

            if examples_bucket is not None:
                query["pose"] = query_pose
                db_match = db_dataset[output["db_match_idx"]]
                db_match["pose"] = output["db_match_pose"]
                examples_bucket.append((query, db_match, output["estimated_pose"]))

        all_pr_recalls.append(np.mean(pr_matches))
        all_reg_recalls.append(np.mean(reg_matches))

        all_mean_pr_rotation_errors.append(np.mean(pr_rotation_errors))
        all_mean_pr_translation_errors.append(np.mean(pr_translation_errors))
        all_median_pr_rotation_errors.append(np.median(pr_rotation_errors))
        all_median_pr_translation_errors.append(np.median(pr_translation_errors))

        all_mean_reg_rotation_errors.append(np.mean(reg_rotation_errors))
        all_mean_reg_translation_errors.append(np.mean(reg_translation_errors))
        all_median_reg_rotation_errors.append(np.median(reg_rotation_errors))
        all_median_reg_translation_errors.append(np.median(reg_translation_errors))
        all_times.extend(times[1:]) # drop the first iteration cause it is always slower

  output = torch.cuda.IntTensor(B, npoint)
100%|██████████| 274/274 [00:35<00:00,  7.83it/s]
100%|██████████| 330/330 [00:42<00:00,  7.73it/s]


In [12]:
# Number of collected examples: (correct, place recognition failures, registration failures).
tuple(len(examples) for examples in (correct_examples, pr_incorrect_examples, reg_incorrect_examples))

(73, 27, 11)

In [13]:
# Summary report. Every "Average ..." value is the mean over all evaluated (database track, query track)
# pairs of the per-pair statistic; the timing is the mean over all stored per-query inference times,
# converted from seconds to milliseconds.
results_str = f"""Average PR Recall@1:   {np.mean(all_pr_recalls)*100:.2f}
Average REG Recall@1:  {np.mean(all_reg_recalls)*100:.2f}

Average Mean RRE PR:   {np.mean(all_mean_pr_rotation_errors):.2f}
Average Mean RTE PR:   {np.mean(all_mean_pr_translation_errors):.2f}
Average Median RRE PR:   {np.mean(all_median_pr_rotation_errors):.2f}
Average Median RTE PR:   {np.mean(all_median_pr_translation_errors):.2f}

Average Mean RRE REG:  {np.mean(all_mean_reg_rotation_errors):.2f}
Average Mean RTE REG:  {np.mean(all_mean_reg_translation_errors):.2f}
Average Median RRE REG:  {np.mean(all_median_reg_rotation_errors):.2f}
Average Median RTE REG:  {np.mean(all_median_reg_translation_errors):.2f}

Mean inference time (PR + REG):     {np.mean(all_times)*1000:.2f} ms
"""

In [14]:
# Print (rather than display) the report so that its line breaks are rendered.
print(results_str)

Average PR Recall@1:   92.78
Average REG Recall@1:  91.75

Average Mean RRE PR:   11.52
Average Mean RTE PR:   12.24
Average Median RRE PR:   6.59
Average Median RTE PR:   4.46

Average Mean RRE REG:  12.99
Average Mean RTE REG:  13.23
Average Median RRE REG:  6.16
Average Median RTE REG:  4.51

Mean inference time (PR + REG):     107.66 ms



In [15]:
# `loc_pipe` is whatever pipeline the evaluation loop created last, so this figure covers only the
# final (database track, query track) pair, not the whole evaluation.
# The first recorded measurement is skipped, mirroring how the first iteration is dropped from `all_times`.
registration_times = loc_pipe.reg_pipe.stats_history["total_time"][1:]
print(f"Mean registration time: {np.mean(registration_times) * 1000:.2f} ms")

Mean registration time: 50.00 ms


In [None]:
# Inspect one stored example where both place recognition and registration succeeded.
correct_example = correct_examples[10]
query_sample, db_match_sample, estimated_pose = correct_example

# draw_pc_pair(
#     query_sample["pointcloud_lidar_coords"],
#     estimated_pose,
#     db_match_sample["pointcloud_lidar_coords"],
#     db_match_sample["pose"],
# )

# Camera images of the query first, then of the retrieved database frame.
for sample in (query_sample, db_match_sample):
    ts.show([sample["image_Cam5"], sample["image_Cam2"]])

# Each error is printed as the tuple returned by `compute_error`.
pr_pose_error = compute_error(db_match_sample["pose"], query_sample["pose"])
print(f"PR pose error: {pr_pose_error}")
reg_pose_error = compute_error(estimated_pose, query_sample["pose"])
print(f"REG pose error: {reg_pose_error}")

In [None]:
# Inspect the first stored example where place recognition retrieved a wrong place.
pr_incorrect_example = pr_incorrect_examples[0]
query_sample, db_match_sample, estimated_pose = pr_incorrect_example

# draw_pc_pair(
#     query_sample["pointcloud_lidar_coords"],
#     estimated_pose,
#     db_match_sample["pointcloud_lidar_coords"],
#     db_match_sample["pose"],
# )

# Camera images of the query first, then of the retrieved database frame.
for sample in (query_sample, db_match_sample):
    ts.show([sample["image_Cam5"], sample["image_Cam2"]])

# Each error is printed as the tuple returned by `compute_error`.
pr_pose_error = compute_error(db_match_sample["pose"], query_sample["pose"])
print(f"PR pose error: {pr_pose_error}")
reg_pose_error = compute_error(estimated_pose, query_sample["pose"])
print(f"REG pose error: {reg_pose_error}")

In [None]:
# Inspect the first stored example where the match was correct but registration made it incorrect.
reg_incorrect_example = reg_incorrect_examples[0]
query_sample, db_match_sample, estimated_pose = reg_incorrect_example

# draw_pc_pair(
#     query_sample["pointcloud_lidar_coords"],
#     estimated_pose,
#     db_match_sample["pointcloud_lidar_coords"],
#     db_match_sample["pose"],
# )

# Camera images of the query first, then of the retrieved database frame.
for sample in (query_sample, db_match_sample):
    ts.show([sample["image_Cam5"], sample["image_Cam2"]])

# Each error is printed as the tuple returned by `compute_error`.
pr_pose_error = compute_error(db_match_sample["pose"], query_sample["pose"])
print(f"PR pose error: {pr_pose_error}")
reg_pose_error = compute_error(estimated_pose, query_sample["pose"])
print(f"REG pose error: {reg_pose_error}")