In [1]:
import os
import copy
import random
from time import time

import cv2
import faiss
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
from omegaconf import OmegaConf
from hydra.utils import instantiate
from torch.utils.data import DataLoader
from scipy.spatial.transform import Rotation
from geotransformer.utils.pointcloud import get_transform_from_rotation_translation

from opr.datasets.itlp import ITLPCampus
from opr.pipelines.localization import LocalizationPipeline
from opr.pipelines.place_recognition import PlaceRecognitionPipeline
from opr.pipelines.registration import PointcloudRegistrationPipeline

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
# Root directory of the ITLP Campus outdoor dataset. The default is the original
# in-container location; set the OPR_DATASET_ROOT environment variable to point
# the notebook at a different copy without editing this cell.
DATASET_ROOT = os.environ.get(
    "OPR_DATASET_ROOT",
    "/home/docker_opr/Datasets/OpenPlaceRecognition/itlp_campus_outdoor_part2",
)
# Sensors requested from the dataset for every frame.
SENSOR_SUITE = ["front_cam", "back_cam", "lidar"]

# Track directory name -> human-readable season label used in the result tables.
SEASON_MAPPING = {
    "00_2023-02-10": "winter",
    "03_2023-04-11": "spring",
    "05_2023-08-15-day": "summer",
    "07_2023-10-04-day": "fall",
}

# Tracks to evaluate, in table order. Derived from SEASON_MAPPING so the two can
# never drift apart (every evaluated track is guaranteed to have a season label).
TRACK_LIST = list(SEASON_MAPPING)

# Descriptor-extraction dataloader settings and the device used for all models.
BATCH_SIZE = 16
NUM_WORKERS = 4
DEVICE = "cuda:0"

In [4]:
def set_seed(seed: int = 18) -> None:
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and current CUDA device),
    switches cuDNN to deterministic mode, exports ``PYTHONHASHSEED`` and prints
    a confirmation line.

    Args:
        seed: Value applied to all generators.
    """
    # Export the hash seed. Python reads PYTHONHASHSEED at interpreter start-up,
    # so this matters for processes launched afterwards, not the running kernel.
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # cuDNN: always pick deterministic kernels and disable the auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    print(f"Random seed set as {seed}")


set_seed()

Random seed set as 18


In [5]:
def pose_to_matrix(pose):
    """Convert a [tx ty tz qx qy qz qw] pose into a 4x4 homogeneous matrix.

    Args:
        pose: Sequence whose first three entries are the translation and whose
            remaining entries are a quaternion in (x, y, z, w) order.

    Returns:
        4x4 array with the rotation in the upper-left 3x3 block, the translation
        in the last column and ``[0, 0, 0, 1]`` as the bottom row.
    """
    translation, quaternion = pose[:3], pose[3:]
    rotation_matrix = Rotation.from_quat(quaternion).as_matrix()

    matrix = np.eye(4)
    matrix[:3, 3] = translation
    matrix[:3, :3] = rotation_matrix
    return matrix

def compute_error(estimated_pose, gt_pose):
    """Return ``(angle_error, dist_error)`` between an estimated and a reference pose.

    Both poses use the 6D [tx ty tz qx qy qz qw] format. The relative transform
    ``inv(estimated) @ gt`` is formed; the Euclidean norm of its translation is
    the distance error (in the unit of the input translations) and the angle of
    its rotation, converted to degrees and folded around 90, is the angle error.
    """
    relative = np.linalg.inv(pose_to_matrix(estimated_pose)) @ pose_to_matrix(gt_pose)

    offset = relative[:3, 3]
    dist_error = np.sum(offset**2) ** 0.5

    rotvec = Rotation.from_matrix(relative[:3, :3]).as_rotvec()
    angle_deg = (np.sum(rotvec**2)**0.5) * 180 / np.pi
    # Fold the angle around 90 degrees: 0 and 180 both map to 0, 90 stays 90.
    # NOTE(review): this scores a 180-degree flip as zero rotation error —
    # confirm that this is the intended metric.
    angle_error = abs(90 - abs(angle_deg - 90))
    return angle_error, dist_error

def inference():
    """Evaluate localization for every ordered (database, query) track pair.

    For each database track a place-recognition pipeline is built on that
    track's directory; every other track is then used as the query set. Each
    query frame is passed through the localization pipeline and its estimated
    pose is compared with the frame's recorded pose.

    Reads the notebook globals ``TRACK_LIST``, ``SEASON_MAPPING``,
    ``DATASET_ROOT``, ``DEVICE``, ``dataset``, ``pr_model``, ``PR_WEIGHTS_PATH``,
    ``reg_model``, ``REGISTRATION_WEIGHTS_PATH`` and ``RECALL_THRESHOLD``; they
    must be defined before this function is called.

    Writes, keyed by ``"DB <season>, Query <season>"``:
        * ``all_reg_recalls`` - share of queries whose translation error is
          below ``RECALL_THRESHOLD``;
        * ``all_mean_reg_rotation_errors`` / ``all_median_reg_rotation_errors``;
        * ``all_mean_reg_translation_errors`` / ``all_median_reg_translation_errors``.
    Per-query durations (seconds) are appended to ``all_times``; the first
    measurement of every pair is dropped (presumably to exclude warm-up cost).

    Returns:
        None. Results are stored in the global containers listed above.
    """
    # Local import keeps this cell self-contained; perf_counter is monotonic and
    # therefore better suited to interval timing than the wall-clock time().
    from time import perf_counter

    pose_columns = ["tx", "ty", "tz", "qx", "qy", "qz", "qw"]
    full_df = dataset.dataset_df

    for db_track in TRACK_LIST:
        pr_pipe = PlaceRecognitionPipeline(
            database_dir=Path(DATASET_ROOT) / db_track,
            model=pr_model,
            model_weights_path=PR_WEIGHTS_PATH,
            device=DEVICE,
        )
        for query_track in TRACK_LIST:
            if db_track == query_track:
                continue

            key_str = f"DB {SEASON_MAPPING[db_track]}, Query {SEASON_MAPPING[query_track]}"

            reg_pipe = PointcloudRegistrationPipeline(
                model=reg_model,
                model_weights_path=REGISTRATION_WEIGHTS_PATH,
                device=DEVICE,
                voxel_downsample_size=0.3,
                num_points_downsample=8192,
            )
            loc_pipe = LocalizationPipeline(
                place_recognition_pipeline=pr_pipe,
                registration_pipeline=reg_pipe,
                precomputed_reg_feats=True,
                pointclouds_subdir="lidar"
            )

            # The query side needs a real dataset object (frames are loaded by
            # iterating it), so copy the dataset and restrict it to one track.
            query_dataset = copy.deepcopy(dataset)
            query_dataset.dataset_df = query_dataset.dataset_df[
                query_dataset.dataset_df["track"] == query_track
            ].reset_index(drop=True)
            query_df = query_dataset.dataset_df

            # Only the database DataFrame is needed, so filter it directly
            # instead of deep-copying the whole dataset a second time.
            db_df = full_df[full_df["track"] == db_track].reset_index(drop=True)

            loc_pipe.pr_pipe.database_df = db_df
            loc_pipe.database_df = db_df

            reg_matches = []
            reg_rotation_errors = []
            reg_translation_errors = []
            times = []

            progress = tqdm(enumerate(query_dataset), total=len(query_dataset), desc=key_str)
            for q_i, query in progress:
                query_pose = query_df.iloc[q_i][pose_columns].to_numpy()

                start = perf_counter()
                estimated_pose = loc_pipe.infer(query)["estimated_pose"]
                # Wait for queued GPU work so the measurement covers the full inference.
                torch.cuda.current_stream().synchronize()
                times.append(perf_counter() - start)

                reg_rotation_error, reg_translation_error = compute_error(estimated_pose, query_pose)
                reg_matches.append(reg_translation_error < RECALL_THRESHOLD)
                reg_rotation_errors.append(reg_rotation_error)
                reg_translation_errors.append(reg_translation_error)

            all_reg_recalls[key_str] = np.nanmean(reg_matches)
            all_mean_reg_rotation_errors[key_str] = np.nanmean(reg_rotation_errors)
            all_mean_reg_translation_errors[key_str] = np.nanmean(reg_translation_errors)
            all_median_reg_rotation_errors[key_str] = np.nanmedian(reg_rotation_errors)
            all_median_reg_translation_errors[key_str] = np.nanmedian(reg_translation_errors)
            # Skip the first timing of each pair when aggregating.
            all_times.extend(times[1:])

In [6]:
from albumentations.pytorch import ToTensorV2
import albumentations as A
from opr.datasets.augmentations import DefaultImageTransform

class ToTensorTransform:
    """Callable that runs an image through an albumentations ``ToTensorV2`` pipeline."""

    def __init__(self) -> None:
        # Single-step albumentations pipeline: only the tensor conversion.
        self.transform = A.Compose([ToTensorV2()])

    def __call__(self, img: np.ndarray):
        """Apply the pipeline to ``img`` and return its ``"image"`` entry."""
        transformed = self.transform(image=img)
        return transformed["image"]

# Test split of ITLP Campus with images, semantic masks and lidar.
dataset = ITLPCampus(
    dataset_root=DATASET_ROOT,
    subset="test",
    csv_file="full_test.csv",
    sensors=SENSOR_SUITE,
    load_semantics=True,
    exclude_dynamic_classes=True,
    image_transform=ToTensorTransform(),
    semantic_transform=ToTensorTransform(),
    late_image_transform=DefaultImageTransform(resize=(320, 192), train=False)
)

# Keep only the evaluated tracks and renumber the rows 0..N-1. The previous row
# labels are kept as an "index" column, exactly as a plain reset_index() does.
track_mask = dataset.dataset_df["track"].isin(TRACK_LIST)
dataset.dataset_df = dataset.dataset_df[track_mask].reset_index()

# Ordered (unshuffled) loader over all retained frames, used for descriptor extraction.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=dataset.collate_fn,
)

In [7]:
# Registration model: build it from its Hydra/OmegaConf config, then load weights.
REGISTRATION_MODEL_CONFIG_PATH = "../../configs/model/registration/hregnet_light_feats.yaml"
REGISTRATION_WEIGHTS_PATH = "../../weights/registration/hregnet_light_feats_nuscenes.pth"
reg_model_config = OmegaConf.load(REGISTRATION_MODEL_CONFIG_PATH)
reg_model = instantiate(reg_model_config)
# Deserialize the checkpoint onto the CPU: without map_location, torch.load
# restores tensors onto the device they were saved from, which fails when that
# device does not exist on this machine. The model is moved to DEVICE right after.
reg_model.load_state_dict(torch.load(REGISTRATION_WEIGHTS_PATH, map_location="cpu"))
reg_model = reg_model.to(DEVICE)
# Trailing semicolon suppresses the module repr in the cell output.
reg_model.eval();

In [8]:
# MIPT finetune

# Result containers filled by inference(). The dictionaries are keyed by
# "DB <season>, Query <season>"; all_times collects per-query durations in seconds.
# Re-run this cell before calling inference() again to start from empty containers.
all_reg_recalls = {}
all_mean_reg_rotation_errors = {}
all_mean_reg_translation_errors = {}
all_median_reg_rotation_errors = {}
all_median_reg_translation_errors = {}
all_times = []
# A query counts as correctly localized when its translation error is below this
# value (same unit as the dataset's tx/ty/tz columns).
RECALL_THRESHOLD = 25.0

# Place-recognition model: build it from its config, then load the fine-tuned weights.
PR_MODEL_CONFIG_PATH = "../../configs/model/place_recognition/multi-image_lidar_late-fusion.yaml"
PR_WEIGHTS_PATH = "../../weights/place_recognition/multi-image_lidar_late-fusion_itlp-finetune.pth"
pr_model_config = OmegaConf.load(PR_MODEL_CONFIG_PATH)
pr_model = instantiate(pr_model_config)
# Deserialize the checkpoint onto the CPU: without map_location, torch.load
# restores tensors onto the device they were saved from, which fails when that
# device does not exist on this machine. The model is moved to DEVICE right after.
pr_model.load_state_dict(torch.load(PR_WEIGHTS_PATH, map_location="cpu"))
pr_model = pr_model.to(DEVICE)
# Trailing semicolon suppresses the module repr in the cell output.
pr_model.eval();

In [9]:
# Compute one place-recognition descriptor per dataset frame, batch by batch.
descriptor_chunks = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        batch_on_device = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
        model_output = pr_model(batch_on_device)
        descriptor_chunks.append(model_output["final_descriptor"].detach().cpu().numpy())
descriptors = np.concatenate(descriptor_chunks, axis=0)

# Write one exact-L2 FAISS index per track next to that track's data.
# The groupby row labels are used as positions into `descriptors`, which relies
# on dataset_df having a 0..N-1 index and on the dataloader not shuffling.
dataset_df = dataset.dataset_df
for track, row_labels in dataset_df.groupby("track").groups.items():
    index_path = f"{DATASET_ROOT}/{track}/index.faiss"
    track_descriptors = descriptors[row_labels]
    track_index = faiss.IndexFlatL2(track_descriptors.shape[1])
    track_index.add(track_descriptors)
    faiss.write_index(track_index, index_path)
    print(f"Saved index {index_path}")

100%|██████████| 37/37 [00:58<00:00,  1.59s/it]

Saved index /home/docker_opr/Datasets/OpenPlaceRecognition/itlp_campus_outdoor_part2/00_2023-02-10/index.faiss
Saved index /home/docker_opr/Datasets/OpenPlaceRecognition/itlp_campus_outdoor_part2/03_2023-04-11/index.faiss
Saved index /home/docker_opr/Datasets/OpenPlaceRecognition/itlp_campus_outdoor_part2/05_2023-08-15-day/index.faiss
Saved index /home/docker_opr/Datasets/OpenPlaceRecognition/itlp_campus_outdoor_part2/07_2023-10-04-day/index.faiss





In [10]:
# Run the evaluation over all (database, query) track pairs; this fills the
# all_* result containers that the following cells print.
inference()

136it [00:52,  2.59it/s]
152it [00:49,  3.08it/s]
152it [00:49,  3.04it/s]
139it [00:45,  3.03it/s]
152it [00:48,  3.11it/s]
152it [00:48,  3.11it/s]
139it [00:46,  2.99it/s]
136it [00:43,  3.12it/s]
152it [00:49,  3.05it/s]
139it [00:45,  3.03it/s]
136it [00:44,  3.08it/s]
152it [00:49,  3.09it/s]


In [11]:
# Share of queries with translation error below RECALL_THRESHOLD, in percent,
# per (database, query) pair, followed by the mean over all pairs.
print("Recall@1:")
for pair_name, recall in all_reg_recalls.items():
    print(f"{pair_name}: {recall*100:.2f}")

mean_recall = np.nanmean(list(all_reg_recalls.values()))
print(f"Mean: {mean_recall*100:.2f}")

Recall@1:
DB winter, Query spring: 100.00
DB winter, Query summer: 92.11
DB winter, Query fall: 99.34
DB spring, Query winter: 99.28
DB spring, Query summer: 96.05
DB spring, Query fall: 99.34
DB summer, Query winter: 94.96
DB summer, Query spring: 91.91
DB summer, Query fall: 100.00
DB fall, Query winter: 96.40
DB fall, Query spring: 94.85
DB fall, Query summer: 100.00
Mean: 97.02


In [12]:
# Median rotation error (degrees) per (database, query) pair, followed by the
# mean of those medians.
print("Median RRE:")
for pair_name, median_rre in all_median_reg_rotation_errors.items():
    print(f"{pair_name}: {median_rre:.2f}")

mean_of_median_rre = np.nanmean(list(all_median_reg_rotation_errors.values()))
print(f"Mean: {mean_of_median_rre:.2f}")

Median RRE:
DB winter, Query spring: 1.78
DB winter, Query summer: 7.25
DB winter, Query fall: 4.43
DB spring, Query winter: 1.75
DB spring, Query summer: 5.76
DB spring, Query fall: 4.44
DB summer, Query winter: 7.25
DB summer, Query spring: 5.58
DB summer, Query fall: 4.91
DB fall, Query winter: 4.13
DB fall, Query spring: 4.65
DB fall, Query summer: 4.27
Mean: 4.68


In [13]:
# Median translation error per (database, query) pair, followed by the mean of
# those medians.
print("Median RTE:")
for pair_name, median_rte in all_median_reg_translation_errors.items():
    print(f"{pair_name}: {median_rte:.2f}")

mean_of_median_rte = np.nanmean(list(all_median_reg_translation_errors.values()))
print(f"Mean: {mean_of_median_rte:.2f}")

Median RTE:
DB winter, Query spring: 1.05
DB winter, Query summer: 4.25
DB winter, Query fall: 4.37
DB spring, Query winter: 0.73
DB spring, Query summer: 3.29
DB spring, Query fall: 4.06
DB summer, Query winter: 3.82
DB summer, Query spring: 3.56
DB summer, Query fall: 2.99
DB fall, Query winter: 4.18
DB fall, Query spring: 4.49
DB fall, Query summer: 3.17
Mean: 3.33


In [14]:
# all_times holds per-query durations in seconds; report their mean in milliseconds.
mean_time_ms = np.nanmean(all_times) * 1000
print(f"Mean inference time: {mean_time_ms:.2f} ms")

Mean inference time: 119.42 ms
