In [None]:
import numpy as np
import cv2
from tqdm import tqdm
from pathlib import Path
from gtsam.symbol_shorthand import X
import matplotlib.pyplot as plt

from lac.slam.feature_tracker import FeatureTracker, prune_features
from lac.perception.segmentation import UnetSegmentation
from lac.utils.plotting import plot_poses, plot_surface, plot_3d_points
from lac.util import load_data, load_stereo_images
from lac.utils.visualization import overlay_mask
from lac.params import LAC_BASE_PATH

%load_ext autoreload
%autoreload 2

In [None]:
# Load the data logs recorded for one run (poses, IMU samples and camera configuration)
data_path = Path(LAC_BASE_PATH) / "output/DataCollectionAgent/stereo_lights1.0_map1_preset1"
initial_pose, lander_pose, poses, imu_data, cam_config = load_data(data_path)
print(f"Loaded {len(poses)} poses")

# Load the stereo images from the same run directory; both are indexed by frame number below
left_imgs, right_imgs = load_stereo_images(data_path)

# Load the ground truth map
# NOTE(review): the run directory is named "map1_preset1" while this file is "preset_0" --
# confirm both refer to the same terrain.
# NOTE(review): `map` shadows the Python builtin for the rest of the kernel session; the
# surface-plot cell reads this name, so a rename has to touch both places.
# allow_pickle=True lets np.load unpickle arbitrary objects -- only use it on trusted files.
map = np.load(
    Path(LAC_BASE_PATH) / "data/heightmaps/competition/Moon_Map_01_preset_0.dat",
    allow_pickle=True,
)

In [None]:
# Rock segmentation model and stereo feature tracker used by the detection cells below.
segmentation = UnetSegmentation()
# The tracker is built from the camera configuration returned by load_data.
tracker = FeatureTracker(cam_config)

Frontend:

- Run both segmentation and feature extraction on left and right images
- Triangulate feature matches. For keypoints labeled as rock, group them together

Rock map:

- Each rock has a set of world points with associated descriptors (these descriptors should probably also be associated with a viewing direction, since the appearance of a rock can change with viewing direction)
-

Graph SLAM:

- In the graph, we have a landmark for each rock corresponding to its centroid point
- For each keyframe, we determine the observed pixel centroid of the rock based on segmentation outputs, and use that to add reprojection factors


# Detection

For each left, right image pair:

- For each image: a list of rock detections, where each detection is (pixel mask, pixel centroid, detected keypoints, keypoint descriptors, descriptor scores)
- Triangulated 3D points


In [None]:
# First frame index visited by the rock-map loop at the end of the notebook.
START_FRAME = 80

In [None]:
# Frame index of the single stereo pair inspected in the detection cells below.
FRAME = 200
left_image, right_image = left_imgs[FRAME], right_imgs[FRAME]

# Show the left and right images side by side in grayscale, without axes.
fig, ax = plt.subplots(1, 2, figsize=(18, 10))
for panel, image in zip(ax, (left_image, right_image)):
    panel.imshow(image, cmap="gray")
    panel.axis("off")
plt.subplots_adjust(wspace=0.03)
plt.show()

In [None]:
# Segment rocks in each image; segment_rocks returns a pair that is unpacked here as
# (masks, full mask).
left_masks, left_full_mask = segmentation.segment_rocks(left_image)
right_masks, right_full_mask = segmentation.segment_rocks(right_image)

# Overlay the full rock mask on each image using the color (1, 0, 0).
left_seg_overlay = overlay_mask(left_image, left_full_mask, color=(1, 0, 0))
right_seg_overlay = overlay_mask(right_image, right_full_mask, color=(1, 0, 0))

# Show both overlays side by side, without axes.
fig, ax = plt.subplots(1, 2, figsize=(18, 10))
for panel, overlay in zip(ax, (left_seg_overlay, right_seg_overlay)):
    panel.imshow(overlay)
    panel.axis("off")
plt.subplots_adjust(wspace=0.03)
plt.show()

In [None]:
# Extract and match features across the stereo pair; `depths` is indexed per match below.
left_feats, right_feats, matches, depths = tracker.process_stereo(left_image, right_image)

# Column 0 of `matches` indexes the left features, column 1 the right features.
left_matched_feats = prune_features(left_feats, matches[:, 0])
right_matched_feats = prune_features(right_feats, matches[:, 1])

# Keypoints of the matched features (first entry of the "keypoints" field).
# NOTE(review): presumably row i of the left and right arrays form match i -- confirm
# against prune_features.
left_matched_pts = left_matched_feats["keypoints"][0]
right_matched_pts = right_matched_feats["keypoints"][0]

In [None]:
# Filter to points that are within segmentations

# Grow both full rock masks with a 5x5 kernel before the membership test
kernel = np.ones((5, 5), np.uint8)
left_full_mask_dilated = cv2.dilate(left_full_mask, kernel, iterations=1)
right_full_mask_dilated = cv2.dilate(right_full_mask, kernel, iterations=1)

# Indices of matches whose left AND right keypoints both fall on a non-zero pixel of the
# corresponding dilated mask (masks are indexed [row, col], i.e. [y, x])
rock_pt_idxs = [
    idx
    for idx, ((x_left, y_left), (x_right, y_right)) in enumerate(
        zip(left_matched_pts, right_matched_pts)
    )
    if left_full_mask_dilated[int(y_left), int(x_left)]
    and right_full_mask_dilated[int(y_right), int(x_right)]
]

In [None]:
# Gather the matched keypoints and depths selected by rock_pt_idxs.
# NOTE(review): this expects rock_pt_idxs to be the index list from the filter cell above;
# a later cell rebinds the same name to a dict, so re-running this cell afterwards will fail.
left_rock_matched_pts = left_matched_pts[rock_pt_idxs]
right_rock_matched_pts = right_matched_pts[rock_pt_idxs]
depths_rock_matched = depths[rock_pt_idxs]

In [None]:
from lightglue import viz2d

# Draw the two segmentation overlays and connect the matched rock keypoints between them.
viz2d.plot_images([left_seg_overlay, right_seg_overlay])
viz2d.plot_matches(left_rock_matched_pts, right_rock_matched_pts, color="lime", lw=0.2)

In [None]:
# Project the rock keypoints of the left image with their stereo depths, using the pose at FRAME.
# NOTE(review): presumably the result is in the world frame -- confirm against project_stereo.
rock_points = tracker.project_stereo(poses[FRAME], left_rock_matched_pts, depths_rock_matched)
plot_3d_points(rock_points)

In [None]:
from lac.perception.segmentation import SemanticClasses
from lac.perception.segmentation_util import dilate_mask

# NOTE(review): dilate_mask is imported but not used in this cell.

# Binary rock mask for the left image: 1 where the predicted class equals ROCKS, 0 elsewhere.
left_pred = segmentation.predict(left_image)
left_rock_mask = (left_pred == SemanticClasses.ROCKS.value).astype(np.uint8)
# Grow the mask with a 5x5 kernel (one dilation pass).
kernel = np.ones((5, 5), np.uint8)
left_rock_mask = cv2.dilate(left_rock_mask, kernel, iterations=1)

# Label the connected regions of the mask; cv2.connectedComponents gives label 0 to the
# background, so a non-zero label is equivalent to a non-zero mask pixel.
num_labels, labels = cv2.connectedComponents(left_rock_mask)

# Maps a connected-component label to the list of match indices that fall on that component.
rock_pt_idxs = {}
MAX_DEPTH = 5.0  # matches with a larger stereo depth are skipped

for i in range(len(left_matched_pts)):
    if depths[i] > MAX_DEPTH:
        continue
    x_left, y_left = left_matched_pts[i]
    x_right, y_right = right_matched_pts[i]
    # Named `rock_label` rather than `id` so the Python builtin is not shadowed in the kernel.
    rock_label = int(labels[int(y_left), int(x_left)])
    # Keep the match only if the left keypoint lies on a labeled component and the right
    # keypoint lies on the dilated right mask computed in the earlier filter cell.
    if rock_label != 0 and right_full_mask_dilated[int(y_right), int(x_right)] != 0:
        rock_pt_idxs.setdefault(rock_label, []).append(i)

In [None]:
from lightglue import viz2d

# Draw the overlays, then plot the matches of each rock group in its own random RGB color.
viz2d.plot_images([left_seg_overlay, right_seg_overlay])
for match_idxs in rock_pt_idxs.values():
    viz2d.plot_matches(
        left_matched_pts[match_idxs],
        right_matched_pts[match_idxs],
        color=np.random.rand(3),
        lw=0.2,
    )

# Tracking


In [None]:
from lac.slam.rock_tracker import RockTracker

In [None]:
# Run rock detection for the inspected stereo pair using the pose at FRAME.
# This rebinds rock_points, which an earlier cell set from tracker.project_stereo.
rock_tracker = RockTracker(cam_config)
rock_points = rock_tracker.detect_rocks(poses[FRAME], left_image, right_image)

In [None]:
from norfair import Detection, Tracker

# Bound to its own name so the FeatureTracker stored in `tracker` is not overwritten;
# otherwise re-running the stereo matching cells after this one would fail.
# NOTE(review): the 0.5 distance threshold is untuned here -- confirm units and value.
norfair_tracker = Tracker(distance_function="euclidean", distance_threshold=0.5)

# Rock map


In [None]:
from lac.perception.depth import stereo_depth_from_segmentation, project_rock_depths_to_world
import lac.params as params

In [None]:
# Frame indices for which a left image was loaded, in ascending order.
img_idxs = sorted(left_imgs.keys())
if not img_idxs:
    raise ValueError("No stereo images were loaded; cannot build the rock map")

# Per-frame detections are collected here and stacked once after the loop.
rock_points_per_frame = []

# Visit every second frame index starting at START_FRAME.
# NOTE(review): range() excludes img_idxs[-1], so the last logged frame is never processed --
# confirm this is intentional.
# TODO: call the frontend; also ignore rocks that are too far (noisy depth).
for i in tqdm(range(START_FRAME, img_idxs[-1], 2)):
    # Skip indices without a logged stereo pair instead of failing on the lookup.
    if i not in left_imgs or i not in right_imgs:
        continue
    left_img = left_imgs[i]
    right_img = right_imgs[i]

    rock_world_points = rock_tracker.detect_rocks(poses[i], left_img, right_img)
    # Frames without detections contribute nothing and could break the concatenation
    # if they come back with a different dimensionality.
    if len(rock_world_points) == 0:
        continue
    rock_points_per_frame.append(rock_world_points)

if not rock_points_per_frame:
    raise ValueError(f"No rock points detected in any frame from START_FRAME={START_FRAME}")

# Stack the per-frame detections along the first axis
all_rock_world_points = np.concatenate(rock_points_per_frame, axis=0)

In [None]:
# Display the shape of the stacked rock points as a quick check on the loop output.
all_rock_world_points.shape

In [None]:
# Plot the ground truth map surface, then add every 10th rock point to the same figure in red.
fig = plot_surface(map)
fig = plot_3d_points(all_rock_world_points[::10], fig=fig, color="red")
fig.show()

In [None]:
# Export the figure as HTML; the relative path resolves against the kernel's working directory.
fig.write_html("rock_points_stereo_seg.html")