In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs so that edits to imported source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# PlaceRecognitionPipeline with SOC

A module that implements a neural-network place recognition algorithm: it searches a database of places already visited by a vehicle for the records most similar to the current observation. The search uses data from lidars and cameras and highlights special elements of the three-dimensional scene (doors, buildings, street signs, etc.).

In [2]:
from copy import copy
from time import time
from pathlib import Path

import faiss

import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt

from hydra.utils import instantiate
import numpy as np
from omegaconf import OmegaConf
from scipy.spatial.transform import Rotation
import torch
from torch.utils.data import DataLoader

from opr.datasets import NCLTDataset
from opr.pipelines.place_recognition import PlaceRecognitionPipeline

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [3]:
def pose_to_matrix(pose):
    """Build a 4x4 homogeneous transform from a [tx ty tz qx qy qz qw] pose.

    The first three entries are the translation and the remaining four are the
    orientation quaternion in scalar-last (x, y, z, w) order, which is what
    ``Rotation.from_quat`` is given here.
    """
    transform = np.eye(4)
    # Translation fills the last column, rotation the upper-left 3x3 block;
    # the bottom row stays [0, 0, 0, 1] from the identity.
    transform[:3, 3] = pose[:3]
    transform[:3, :3] = Rotation.from_quat(pose[3:]).as_matrix()
    return transform


def compute_error(estimated_pose, gt_pose):
    """Return ``(dist_error, angle_error)`` between two poses.

    Both poses are given in the [tx ty tz qx qy qz qw] format. The errors are
    measured on the relative transform ``inv(estimated) @ gt``: the distance
    error is the Euclidean norm of its translation, the angle error is derived
    from the magnitude of its rotation vector and expressed in degrees.
    """
    relative = np.linalg.inv(pose_to_matrix(estimated_pose)) @ pose_to_matrix(gt_pose)

    translation = relative[:3, 3]
    dist_error = np.sum(translation**2) ** 0.5

    rotvec = Rotation.from_matrix(relative[:3, :3]).as_rotvec()
    angle_deg = (np.sum(rotvec**2) ** 0.5) * 180 / np.pi
    # Fold the angle around 90 degrees: x and (180 - x) give the same error,
    # so a relative rotation of 180 degrees maps to 0.
    # NOTE(review): presumably intentional, so that opposite headings are not
    # penalised - confirm.
    angle_error = abs(90 - abs(angle_deg - 90))
    return dist_error, angle_error


In [4]:
# --- Data ---------------------------------------------------------------
DATASET_ROOT = "/home/docker_opr/Datasets/NCLT_preprocessed"
SEMANTIC_ANNO = "/home/docker_opr/OpenPlaceRecognition/configs/dataset/anno/oneformer.yaml"
# Names of the data items requested from the dataset for every sample.
SENSOR_SUITE = [
    "image_Cam5",
    "image_Cam2",
    "mask_Cam5",
    "pointcloud_lidar",
]

# --- Model configs and weights ------------------------------------------
MODEL_CONFIG_PATH = "../../configs/model/place_recognition/multi-image_lidar_late-fusion.yaml"
SOC_CONFIG_PATH = "../../configs/model/place_recognition/soc_mixer.yaml"
WEIGHTS_PATH = "../../weights/place_recognition/multi-image_lidar_late-fusion_nclt.pth"
SOC_WEIGHTS_PATH = "../../weights/place_recognition/soc_nclt.pth"

# --- Runtime ------------------------------------------------------------
DEVICE = "cuda"
BATCH_SIZE = 32
NUM_WORKERS = 4

In [5]:
# Every sub-directory of the dataset root is treated as one track.
dataset_subdirs = (entry for entry in Path(DATASET_ROOT).iterdir() if entry.is_dir())
TRACK_LIST = sorted(entry.name for entry in dataset_subdirs)
print(f"Found {len(TRACK_LIST)} tracks")
print(TRACK_LIST)

# Only the first two tracks (in sorted order) are kept for the evaluation below.
print("WARNING: track list limited")
TRACK_LIST = TRACK_LIST[:2]
print(TRACK_LIST)

Found 10 tracks
['2012-01-08', '2012-01-22', '2012-02-12', '2012-02-18', '2012-03-31', '2012-05-26', '2012-08-04', '2012-10-28', '2012-11-04', '2012-12-01']
['2012-01-08', '2012-01-22']


## Init model

In [6]:
# Checkpoints are loaded onto the CPU first (map_location="cpu") so that
# loading does not depend on the device the weights were saved from; the
# assembled model is moved to DEVICE afterwards.
soc_config = OmegaConf.load(SOC_CONFIG_PATH)
soc_model = instantiate(soc_config)
soc_checkpoint = torch.load(SOC_WEIGHTS_PATH, map_location="cpu")
# The SOC checkpoint keeps its weights under the "model_state_dict" key.
soc_model.load_state_dict(soc_checkpoint["model_state_dict"])

model_config = OmegaConf.load(MODEL_CONFIG_PATH)
model = instantiate(model_config)
# The main checkpoint file is used directly as the state dict.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))

# The SOC module is attached only after the main weights have been loaded.
# NOTE(review): presumably because the main checkpoint does not contain SOC
# parameters - confirm.
model.soc_module = soc_model
model = model.to(DEVICE)
model.eval();  # trailing ';' suppresses the module repr in the cell output

## Calculate descriptors for databases

In [7]:
semantic_anno_cfg = OmegaConf.load(SEMANTIC_ANNO)

# Test subset of the dataset with the configured sensor suite and SOC loading
# enabled, using the semantic annotation config loaded above.
dataset_kwargs = dict(
    dataset_root=DATASET_ROOT,
    subset="test",
    data_to_load=SENSOR_SUITE,
    pointcloud_quantization_size=0.5,
    max_point_distance=None,
    load_soc=True,
    anno=semantic_anno_cfg,
    top_k_soc=5,
)
dataset = NCLTDataset(**dataset_kwargs)

# shuffle=False keeps batches in dataset order: the descriptors computed from
# this loader are later indexed by row of ``dataset.dataset_df``.
dataloader = DataLoader(
    dataset,
    collate_fn=dataset.collate_fn,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
)


In [8]:
# Run the model over the whole dataloader without gradient tracking and gather
# the "final_descriptor" output of every batch as a NumPy array on the CPU.
descriptor_chunks = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
        model_output = model(batch)
        chunk = model_output["final_descriptor"].detach().cpu().numpy()
        descriptor_chunks.append(chunk)

# Stack the per-batch chunks along the sample axis, in dataloader order.
descriptors = np.concatenate(descriptor_chunks, axis=0)

100%|██████████| 86/86 [02:06<00:00,  1.47s/it]


In [9]:
dataset_df = dataset.dataset_df

# Build one exact-L2 faiss index per track from that track's descriptors and
# write it into the track's directory as ``index.faiss``.
#
# ``GroupBy.indices`` maps each track to *positional* row numbers, which is
# what is needed to slice ``descriptors`` (one row per dataset sample, in
# dataset order). ``GroupBy.groups`` returns index *labels* instead and only
# matches positions when ``dataset_df`` has a default 0..N-1 index.
for track, row_positions in dataset_df.groupby("track").indices.items():
    # faiss operates on C-contiguous float32 matrices.
    track_descriptors = np.ascontiguousarray(descriptors[row_positions], dtype=np.float32)
    track_index = faiss.IndexFlatL2(track_descriptors.shape[1])
    track_index.add(track_descriptors)
    index_path = f"{DATASET_ROOT}/{track}/index.faiss"
    faiss.write_index(track_index, index_path)
    print(f"Saved index {index_path}")


Saved index /home/docker_opr/Datasets/NCLT_preprocessed/2012-01-08/index.faiss
Saved index /home/docker_opr/Datasets/NCLT_preprocessed/2012-01-22/index.faiss
Saved index /home/docker_opr/Datasets/NCLT_preprocessed/2012-02-12/index.faiss
Saved index /home/docker_opr/Datasets/NCLT_preprocessed/2012-02-18/index.faiss
Saved index /home/docker_opr/Datasets/NCLT_preprocessed/2012-03-31/index.faiss
Saved index /home/docker_opr/Datasets/NCLT_preprocessed/2012-05-26/index.faiss
Saved index /home/docker_opr/Datasets/NCLT_preprocessed/2012-08-04/index.faiss
Saved index /home/docker_opr/Datasets/NCLT_preprocessed/2012-10-28/index.faiss
Saved index /home/docker_opr/Datasets/NCLT_preprocessed/2012-11-04/index.faiss
Saved index /home/docker_opr/Datasets/NCLT_preprocessed/2012-12-01/index.faiss


# debug

## Test

In [10]:
# A retrieval counts as a correct match when the distance between the
# retrieved pose and the query pose is below this threshold
# (NOTE(review): presumably meters - confirm against the dataset poses).
PR_THRESHOLD = 25.0
# Distance bounds used to collect clearly-correct / clearly-wrong retrievals.
TRUE_PAIR_MAX_DIST = 10
FALSE_PAIR_MIN_DIST = 100
POSE_COLUMNS = ["tx", "ty", "tz", "qx", "qy", "qz", "qw"]

test_csv = pd.read_csv(Path(DATASET_ROOT) / "test.csv", index_col=0)

# One entry per evaluated (database track, query track) pair.
all_recalls = []
all_mean_dist_errors = []
all_mean_angle_errors = []
all_median_dist_errors = []
all_median_angle_errors = []
# Per-query inference times pooled over all pairs.
all_times = []

for db_track in TRACK_LIST:
    pipe = PlaceRecognitionPipeline(
        database_dir=Path(DATASET_ROOT) / db_track,
        model=model,
        device=DEVICE,
    )
    # Restrict the pipeline's database to the test subset. This does not depend
    # on the query track, so it is done once per database track.
    pipe.database_df = pipe.database_df[pipe.database_df["image"].isin(test_csv["image"])].reset_index(drop=True)

    for query_track in TRACK_LIST:
        if db_track == query_track:
            continue

        # Shallow copy: rebinding ``dataset_df`` on the copy leaves ``dataset`` untouched.
        query_dataset = copy(dataset)
        query_dataset.dataset_df = query_dataset.dataset_df[query_dataset.dataset_df["track"] == query_track]

        # Ground-truth poses of the query track, filtered to the test subset.
        query_df = pd.read_csv(Path(DATASET_ROOT) / query_track / "track.csv", index_col=0)
        query_df = query_df[query_df["image"].isin(query_dataset.dataset_df["image"])].reset_index(drop=True)

        # Poses are matched to dataset samples purely by position, so both
        # tables must list the same images in the same order.
        assert np.array_equal(
            query_df["image"].to_numpy(), query_dataset.dataset_df["image"].to_numpy()
        ), f"track.csv rows of {query_track} are not aligned with the dataset samples"

        # Extract all query poses once as a float array instead of doing a
        # pandas row lookup for every sample.
        query_poses = query_df[POSE_COLUMNS].to_numpy(dtype=float)

        pr_matches = []
        dist_errors = []
        angle_errors = []
        times = []

        # (query index, retrieved database index) pairs for later inspection.
        true_pairs = []
        false_pairs = []

        for q_i, query in tqdm(enumerate(query_dataset), total=len(query_dataset)):
            query["pose"] = query_poses[q_i]
            t = time()
            output = pipe.infer(query)
            times.append(time() - t)
            dist_error, angle_error = compute_error(output["pose"], query["pose"])
            pr_matches.append(dist_error < PR_THRESHOLD)
            dist_errors.append(dist_error)
            angle_errors.append(angle_error)
            if dist_error < TRUE_PAIR_MAX_DIST:
                true_pairs.append((q_i, output["idx"]))
            elif dist_error > FALSE_PAIR_MIN_DIST:
                false_pairs.append((q_i, output["idx"]))

        all_recalls.append(np.mean(pr_matches))
        all_mean_dist_errors.append(np.mean(dist_errors))
        all_mean_angle_errors.append(np.mean(angle_errors))
        all_median_dist_errors.append(np.median(dist_errors))
        all_median_angle_errors.append(np.median(angle_errors))
        all_times.extend(times[1:])  # drop the first iteration because it is always slower

275it [00:30,  8.92it/s]
331it [00:37,  8.78it/s]


In [11]:
# Recall@1 averaged over all evaluated (database, query) track pairs.
np.mean(all_recalls)


0.928025267783576

In [12]:
# Mean distance error, averaged over all evaluated (database, query) track pairs.
np.mean(all_mean_dist_errors)

12.217167697169343

In [13]:
# Median distance error of each (database, query) track pair, averaged over all pairs.
np.mean(all_median_dist_errors)


4.470159108054585

In [14]:
# One line per metric; every metric is the average of the per-pair values
# collected during the evaluation. Recall is reported as a percentage.
summary_lines = [
    f"Average Recall@1: {np.mean(all_recalls)*100:.2f}",
    f"Average mean dist error: {np.mean(all_mean_dist_errors):.2f}",
    f"Average mean angle error: {np.mean(all_mean_angle_errors):.2f}",
    f"Average median dist error: {np.mean(all_median_dist_errors):.2f}",
    f"Average median angle error: {np.mean(all_median_angle_errors):.2f}",
]
# Each line, including the last one, is terminated by a newline.
results_str = "\n".join(summary_lines) + "\n"

In [15]:
# Print the formatted evaluation summary as plain text.
print(results_str)

Average Recall@1: 92.80
Average mean dist error: 12.22
Average mean angle error: 11.51
Average median dist error: 4.47
Average median angle error: 6.65

