In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import os
import cv2

import sys
sys.path.append('..')

from shared.data import KITTIData,  VisualOdometry, draw_matches, draw_keypoints
from shared import common

%matplotlib widget
import matplotlib.pyplot as plt

In [None]:
# Load the KITTI dataset, its ground-truth poses and the left-camera
# calibration used by the visual-odometry pipeline below.
DATASET_DIR = os.path.join('../', 'data/KITTI/dataset')
dataset = KITTIData(DATASET_DIR)
gt_poses = dataset.get_poses()
Q_left = dataset.get_left_Q_matrix()  # passed to vo.process_depth as its Q matrix
# NOTE: the method name below contains what appears to be a Cyrillic 'С'
# (U+0421), not a Latin 'C'. It must match the name defined in shared.data,
# so do not "fix" it here.
# The second return value is unused. It is bound to a named throwaway rather
# than `_` because assigning `_` at notebook top level stops IPython from
# updating its `_` / `__` / `___` output-history variables.
C_left, _C_unused = dataset.get_С_matrix()

In [None]:
from concurrent.futures import ProcessPoolExecutor

# Shared visual-odometry helper, created at module level so that
# `process_transform` (and the worker processes running it) reach it as a global.
vo = VisualOdometry()

def process_transform(idx, use_max_clique=True):
    """Estimate the relative transform between stereo frames ``idx`` and ``idx + 1``.

    Pipeline:
      1. compute a depth frame for both stereo pairs (``vo.process_depth``),
      2. match features between the two left images (``vo.get_features``),
      3. lift the matched features to 3D using each frame's depth,
      4. keep only matches that were lifted in *both* frames,
      5. solve for the transform with ``vo.get_transform``.

    Relies on the notebook globals ``dataset``, ``vo``, ``Q_left`` and ``C_left``.

    Args:
        idx: index of the current frame; frame ``idx + 1`` must also exist.
        use_max_clique: if True (the previously hard-coded behaviour), prune the
            3D correspondences with ``vo.max_clique_filter`` and solve with
            ``type_='PnP'``; otherwise skip the filter and solve with
            ``type_='PnPRansac'``.

    Returns:
        Whatever ``vo.get_transform`` returns for this frame pair.
    """
    c_l_img, c_r_img = dataset.get_images(idx)
    n_l_img, n_r_img = dataset.get_images(idx + 1)

    # Depth for both stereo pairs. Frame idx+1's depth is computed again as the
    # "current" frame by the task for idx+1; reusing it would need state shared
    # across worker processes.
    c_depth_frame = vo.process_depth(c_l_img, c_r_img, Q_left)
    n_depth_frame = vo.process_depth(n_l_img, n_r_img, Q_left)

    # Feature matches between the current and next left images.
    c_feats, n_feats = vo.get_features(c_l_img, n_l_img)

    # Lift matches to 3D and keep only those valid in both frames.
    # NOTE(review): assumes c_ft_idxs / n_ft_idxs are boolean masks over the
    # same list of matches. On integer index arrays `&` would be a bitwise AND,
    # not an intersection. Confirm against reproject_2d_to_3d_points.
    c_pnts_3d, c_ft_idxs = vo.reproject_2d_to_3d_points(c_feats, c_depth_frame)
    n_pnts_3d, n_ft_idxs = vo.reproject_2d_to_3d_points(n_feats, n_depth_frame)

    ft_idxs = c_ft_idxs & n_ft_idxs
    c_pnts_3d = c_pnts_3d[ft_idxs]
    n_pnts_3d = n_pnts_3d[ft_idxs]

    if use_max_clique:
        # Prune inconsistent correspondences before a plain PnP solve.
        cl_idxs, _ = vo.max_clique_filter(c_pnts_3d, n_pnts_3d)
        c_pnts_3d = c_pnts_3d[cl_idxs]
        n_pnts_3d = n_pnts_3d[cl_idxs]
        transform = vo.get_transform(c_pnts_3d, n_pnts_3d, C_left, type_='PnP')
    else:
        transform = vo.get_transform(c_pnts_3d, n_pnts_3d, C_left, type_='PnPRansac')

    # Output from worker processes may go to the kernel's stdout (terminal)
    # rather than the notebook.
    print(f'Transform for idx {idx} done')
    return transform
    
# Cap on the number of consecutive frame pairs to process, and the size of
# the worker pool.
MAX_FRAME_PAIRS = 500
N_WORKERS = 8

# One transform per consecutive frame pair, so at most len(gt_poses) - 1.
# Clamp at 0 so the count is never negative.
poses_count = max(0, min(len(gt_poses) - 1, MAX_FRAME_PAIRS))

# NOTE(review): the workers must be able to unpickle `process_transform` and
# see the notebook globals it uses. That holds with the 'fork' start method.
# Under 'spawn' or 'forkserver', functions defined in the notebook's __main__
# cannot be loaded by the workers. Move the worker into a module if this must
# run there.
with ProcessPoolExecutor(N_WORKERS) as ex:
    # Materialise the results. Executor.map returns a one-shot iterator, which
    # the pose loop below would exhaust, leaving nothing to inspect afterwards.
    # Results come back in input order, so transforms[i] is for frame pair
    # (i, i + 1).
    transforms = list(ex.map(process_transform, range(poses_count)))

# Chain the relative transforms into absolute poses, starting from identity.
poses = [np.eye(4)]
for transform in transforms:
    poses.append(vo.get_next_pose(transform, poses[-1]))
poses = np.array(poses)

# TODO - PnP without RANSAC fails

In [None]:
# Ground truth vs. estimate. The estimate is drawn from the x (row 0) and
# z (row 2) translation components of each 4x4 pose.
plt.figure()
common.plot_trajectory(gt_poses)
plt.plot(poses[:, 0, 3], poses[:, 2, 3], label='estimated')
plt.title('Ground truth vs. estimated trajectory')
plt.xlabel('x')
plt.ylabel('z')
plt.legend()

# Estimated trajectory on its own.
plt.figure()
common.plot_trajectory(poses)
plt.title('Estimated trajectory');