In [None]:
%matplotlib inline
import sys
import os

# Resolve the user's home directory portably: expanduser uses $HOME when it is set
# (same result as reading it directly) and falls back to the platform default
# (pwd database / USERPROFILE) instead of raising KeyError when it is not.
HOME = os.path.expanduser("~")
map_free_path = os.path.join(HOME, "map_free_localization/mapfree")

# Make the map-free checkout importable if it exists — presumably so the project
# imports below (``lib.*``, ``config.*``) resolve; verify against your layout.
# The membership check keeps re-runs of this cell from piling up duplicate entries.
if os.path.exists(map_free_path) and map_free_path not in sys.path:
    sys.path.append(map_free_path)

import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch

from lib.eval.sift_runner import SiftRunner
from lib.eval.mickey_runner import MicKeyRunner
from lib.dataset.mapfree import MapFreeDataset
from config.default import cfg
from lib.camera import Camera


In [None]:
# Dataset / model configuration.
# Relative paths below are resolved against the kernel's current working directory.
data_dir = "../lib/tests/test_data"
config_path = os.path.join(data_dir, "testset.yaml")

# Location of your local copy of the Map-free dataset. Override through environment
# variables rather than editing this cell; the defaults are the original author's paths.
MAPFREE_DATA_ROOT = os.environ.get(
    "MAPFREE_DATA_ROOT", '/media/jprincen/HD/Map Free Localization'
)
MAPFREE_DEPTH_ROOT = os.environ.get(
    "MAPFREE_DEPTH_ROOT", os.path.join(MAPFREE_DATA_ROOT, "mickey_depths")
)

# NOTE: ``config`` aliases the shared default ``cfg`` object; it is modified in place,
# not copied, so these settings are also visible to any other code reading ``cfg``.
config = cfg
config.set_new_allowed(True)
config.DEBUG = False

if not os.path.exists(config_path):
    # Fail fast with the missing path: without the YAML there is no usable dataset
    # config, and passing None on to MapFreeDataset only fails later and less clearly.
    raise FileNotFoundError(
        f"Dataset config not found at {config_path!r}; start the kernel from the "
        "notebook's directory or update `data_dir`."
    )

config.merge_from_file(config_path)
config.DATASET.SCENES = ['s00001']
# Explicitly set to None: when loaded from the YAML this field is a string.
config.DATASET.AUGMENTATION_TYPE = None
config.DATASET.DATA_ROOT = MAPFREE_DATA_ROOT
config.DATASET.DEPTH_ROOT = MAPFREE_DEPTH_ROOT
dataset = MapFreeDataset(config, "val")

# MicKey model: curriculum-learning config plus pretrained checkpoint.
cl_config_path = "../config/MicKey/curriculum_learning.yaml"
checkpoint_path = "../weights/mickey.ckpt"
mickey_runner = MicKeyRunner(cl_config_path, checkpoint_path)


In [None]:
# Run MicKey on the first validation pair.
# The raw arrays keep the dataset's own layout (presumably (C, H, W) floats, since the
# display cell transposes them with (1, 2, 0)). They get dedicated names so that the
# display cell, which binds ``ref_img``/``query_img`` to uint8 display images, does not
# reuse one name for two different kinds of data.
data = dataset[0]
ref_img_chw = data['image0'].numpy()
query_img_chw = data['image1'].numpy()
# Both cameras use the same W/H, i.e. reference and query are assumed to share a resolution.
camera1 = Camera.from_K(data['K_color0'], data['W'], data['H'])
camera2 = Camera.from_K(data['K_color1'], data['W'], data['H'])
R, t, num_inliers, ref_pts, query_pts = mickey_runner.run_one(
    ref_img_chw, query_img_chw, camera1, camera2
)

In [None]:
import matplotlib

def plot_images_correspondences(axis, x, y, color, img, title):
    """Show ``img`` on ``axis`` with keypoints overlaid as small dots.

    Args:
        axis: matplotlib Axes (anything providing ``imshow``, ``scatter`` and ``set_title``).
        x, y: keypoint coordinates, passed straight to ``scatter`` in the axis' data coordinates.
        color: marker colour used for every keypoint.
        img: image handed to ``imshow``.
        title: text shown above the panel.
    """
    marker_size = 1  # tiny markers so a dense keypoint cloud does not hide the image
    axis.imshow(img)
    axis.scatter(x, y, s=marker_size, color=color)
    axis.set_title(title)


In [None]:

def _chw_to_display_uint8(image_chw):
    """Convert a PyTorch-layout (C, H, W) image tensor to an (H, W, C) uint8 array for imshow.

    Assumes float values nominally in [0, 1] (hence the ``* 255``) — TODO confirm against
    MapFreeDataset. Values are clipped to [0, 255] before the cast because converting
    out-of-range floats to uint8 is not well defined and would otherwise wrap around.
    """
    image_hwc = np.transpose(image_chw.numpy().squeeze(), (1, 2, 0))
    return np.clip(image_hwc * 255, 0, 255).astype(np.uint8)


ref_img = _chw_to_display_uint8(data["image0"])
query_img = _chw_to_display_uint8(data["image1"])

fig, axes = plt.subplots(1, 2, figsize=(20, 10))
plot_images_correspondences(axes[0], ref_pts[:, 0], ref_pts[:, 1], "red", ref_img, "Reference Image")
plot_images_correspondences(axes[1], query_pts[:, 0], query_pts[:, 1], "blue", query_img, "Query Image")

# Connect every ``line_sampling``-th correspondence across the two panels; drawing all
# of them would make the figure unreadable.
line_sampling = 25
for i in range(0, ref_pts.shape[0], line_sampling):
    con = ConnectionPatch(
        xyA=ref_pts[i, :], xyB=query_pts[i, :], coordsA="data", coordsB="data",
        axesA=axes[0], axesB=axes[1], color="green",
    )
    axes[1].add_artist(con)
    
    