In [None]:
import os
import numpy as np
import cv2
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch

from lightglue import LightGlue, SuperPoint, viz2d, match_pair
from lightglue.utils import rbd

from lac.perception.depth import project_pixel_to_rover
from lac.utils.frames import apply_transform
from lac.utils.plotting import plot_3d_points, plot_surface, plot_poses, plot_path_3d
from lac.util import load_data, grayscale_to_3ch_tensor
from lac.params import LAC_BASE_PATH, DT

%load_ext autoreload
%autoreload 2

In [None]:
# Locate the recorded run under the project base path and load its logged data.
# NOTE(review): the five-way unpacking below is what fixes the expected shape of
# `load_data`'s return value; the meaning of each element is taken from its name only.
data_path = Path(LAC_BASE_PATH).joinpath(
    "output/DataCollectionAgent/map1_preset0_stereo_lights1.0"
)
run_data = load_data(data_path)
initial_pose, lander_pose, poses, imu_data, cam_config = run_data
print(f"Num poses: {len(poses)}")

In [None]:
# Load every front-left / front-right frame as a grayscale image, keyed by the integer
# that precedes the first "." in its file name.
left_imgs = {}
right_imgs = {}

for cam_name, cam_imgs in (("FrontLeft", left_imgs), ("FrontRight", right_imgs)):
    cam_dir = data_path / cam_name
    for img_name in os.listdir(cam_dir):
        img_path = cam_dir / img_name
        img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising; fail here
        # so that a bad frame is not discovered much later as an obscure downstream error.
        if img is None:
            raise OSError(f"cv2.imread could not read image: {img_path}")
        cam_imgs[int(img_name.split(".")[0])] = img

# Both cameras are expected to have recorded the same number of frames.
assert len(left_imgs) == len(right_imgs)
img_idxs = sorted(left_imgs)

In [None]:
# SuperPoint extractor (max_num_keypoints=2048) and a LightGlue matcher configured for
# SuperPoint features; both are switched to eval mode and moved to the GPU.
extractor = SuperPoint(max_num_keypoints=2048)
extractor = extractor.eval().cuda()

matcher = LightGlue(features="superpoint")
matcher = matcher.eval().cuda()

In [None]:
# Extract SuperPoint features for a single left-camera frame.
image = left_imgs[1500]

image_tensor = grayscale_to_3ch_tensor(image).cuda()
# `rbd` comes from lightglue.utils; later cells index its result without a leading
# batch axis.
feats = rbd(extractor.extract(image_tensor))

In [None]:
# All detected keypoints, plus the subset whose detector score exceeds 0.05.
kps = feats["keypoints"]
confident = feats["keypoint_scores"] > 0.05
good_kps = kps[confident]
print(f"Num keypoints: {len(kps)}, {len(good_kps)}")

In [None]:
# Show the frame, then draw every detected keypoint using viz2d's default colour,
# and finally draw the score-filtered subset (`good_kps`) in red.
# The calls act on the current figure, so their order matters.
viz2d.plot_images([image])
viz2d.plot_keypoints([kps], ps=10)
viz2d.plot_keypoints([good_kps], colors=["red"], ps=10)

# LightGlue Tracking


In [None]:
# Match SuperPoint features between two left-camera frames with LightGlue.
prev_img, next_img = left_imgs[1500], left_imgs[1502]

prev_tensor = grayscale_to_3ch_tensor(prev_img).cuda()
next_tensor = grayscale_to_3ch_tensor(next_img).cuda()
feats0, feats1, matches01 = match_pair(extractor, matcher, prev_tensor, next_tensor)

# Each row of `matches` pairs a keypoint index in image #0 with one in image #1 (K,2).
matches = matches01["matches"]
idx0 = matches[..., 0]
idx1 = matches[..., 1]
points0 = feats0["keypoints"][idx0]  # matched coordinates in image #0, shape (K,2)
points1 = feats1["keypoints"][idx1]  # matched coordinates in image #1, shape (K,2)

In [None]:
# Draw one line segment per match, from its location in the previous frame to its
# location in the next frame, on top of the next frame.
plt.figure(figsize=(10, 6))
plt.imshow(next_img, cmap="gray")
for (x0, y0), (x1, y1) in zip(points0, points1):
    plt.plot([x0, x1], [y0, y1], color="lime")
plt.axis("off")
plt.show()

# OpenCV LK Optical Flow


In [None]:
# OpenCV pyramidal Lucas-Kanade optical flow: track the SuperPoint keypoints of the
# previous frame into the next frame.
prev_img = left_imgs[1500]
next_img = left_imgs[1502]

# calcOpticalFlowPyrLK needs host-side point coordinates, so copy the keypoints off the GPU.
prev_pts = kps.cpu().numpy()

lk_params = dict(
    winSize=(15, 15),
    maxLevel=3,
    criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03),
)

# Pass the parameters defined above; previously they were built but never used, so the
# tracker silently ran with OpenCV's defaults instead.
next_pts, status, err = cv2.calcOpticalFlowPyrLK(prev_img, next_img, prev_pts, None, **lk_params)

# status is 1 where the flow for a point was found. Flatten with reshape(-1) rather than
# squeeze() so the mask stays 1-D even when only a single point is tracked.
tracked_mask = status.reshape(-1) == 1
next_pts_tracked = next_pts[tracked_mask]
prev_pts_tracked = prev_pts[tracked_mask]

In [None]:
# Draw an arrow for every successfully tracked point, from its previous-frame position
# to its tracked next-frame position, on top of the next frame.
plt.figure(figsize=(10, 6))
plt.imshow(next_img, cmap="gray")
flow_starts = prev_pts_tracked.reshape(-1, 2)
flow_ends = next_pts_tracked.reshape(-1, 2)
for (c, d), (a, b) in zip(flow_starts, flow_ends):
    plt.arrow(c, d, a - c, b - d, color="lime", head_width=1, head_length=2, linewidth=1)
plt.show()

In [None]:
# Build a feature set for the LK-tracked points: positions come from the optical-flow
# result, while scores and descriptors are carried over from the previous frame.
prev_feats = extractor.extract(grayscale_to_3ch_tensor(prev_img).cuda())

# NOTE(review): `status` was computed for `kps`, which came from an earlier extraction of
# the same frame; this assumes re-extraction yields the keypoints in the same order — confirm.
lk_ok = status.squeeze() == 1

tracked_feats = prev_feats.copy()
tracked_feats.update(
    {
        "keypoints": torch.from_numpy(next_pts_tracked).unsqueeze(0).cuda(),
        "keypoint_scores": prev_feats["keypoint_scores"][0][lk_ok].unsqueeze(0),
        "descriptors": prev_feats["descriptors"][0][lk_ok].unsqueeze(0),
    }
)

In [None]:
# Extract features in the next frame and match the LK-tracked features against them.
next_tensor = grayscale_to_3ch_tensor(next_img).cuda()
next_feats = extractor.extract(next_tensor)

pair = {"image0": tracked_feats, "image1": next_feats}
matches = matcher(pair)

In [None]:
# NOTE(review): `matches` is rebound here from the matcher's output dict to the index
# tensor, so this cell cannot be re-run without first re-running the matcher cell.
matches = rbd(matches)["matches"]  # indices with shape (K,2)

tracked_kps = rbd(tracked_feats)["keypoints"]
next_kps = rbd(next_feats)["keypoints"]

# Matched coordinates in image #0 / image #1, shape (K,2), copied to host memory as NumPy arrays.
points0 = tracked_kps[matches[..., 0]].cpu().numpy()
points1 = next_kps[matches[..., 1]].cpu().numpy()

In [None]:
from lac.localization.slam.feature_tracker import prune_features

In [None]:
# Call the project helper `prune_features` on the next-frame features, passing the
# image-1 index column of the LightGlue matches.
# NOTE(review): the return value is not stored, and whether the helper edits `next_feats`
# in place is not visible from this notebook — confirm against the helper.
prune_features(next_feats, matches[:, 1])

In [None]:
# Display the next-frame keypoints (batch entry 0) selected by the matched image-1 indices.
# NOTE(review): if the preceding `prune_features` call modified `next_feats` in place,
# these indices may no longer line up with the keypoints — confirm.
next_feats["keypoints"][0, matches[:, 1]]

In [None]:
# For every match, connect the LK-tracked position (points0) to the matched keypoint
# detected in the next frame (points1), drawn on top of the next frame.
plt.figure(figsize=(10, 6))
plt.imshow(next_img, cmap="gray")
for (x0, y0), (x1, y1) in zip(points0, points1):
    plt.plot([x0, x1], [y0, y1], color="lime")
plt.axis("off")
plt.show()