In [None]:
import numpy as np
import cv2
from tqdm import tqdm
from pathlib import Path
from gtsam.symbol_shorthand import X
import matplotlib.pyplot as plt

from lac.slam.feature_tracker import FeatureTracker, prune_features
from lac.perception.segmentation import UnetSegmentation
from lac.utils.plotting import plot_poses, plot_surface, plot_3d_points
from lac.util import load_data, load_stereo_images
from lac.utils.visualization import overlay_mask
from lac.params import LAC_BASE_PATH

%load_ext autoreload
%autoreload 2

In [None]:
# Read the recorded run (poses, IMU data, camera config) from the log directory
data_path = Path("/home/shared/data_raw/LAC/segmentation/slam_map1_preset1_teleop")
initial_pose, lander_pose, poses, imu_data, cam_config = load_data(data_path)
print(f"Loaded {len(poses)} poses")

# Stereo image streams recorded in the same run
left_imgs, right_imgs = load_stereo_images(data_path)

# Ground truth map used by the plotting cells further down.
# NOTE(review): allow_pickle=True unpickles the file, so only load trusted data.
# NOTE(review): the name `map` shadows the Python builtin; later cells read it under this name.
heightmap_file = "/home/shared/data_raw/LAC/heightmaps/competition/Moon_Map_01_preset_1.dat"
map = np.load(heightmap_file, allow_pickle=True)

In [None]:
# Segmentation model and stereo feature tracker reused by the cells below
segmentation = UnetSegmentation()
feature_tracker = FeatureTracker(cam_config)

Frontend:

- Run both segmentation and feature extraction on left and right images
- Triangulate feature matches. For keypoints labeled as rock, group them together

Rock map:

- Each rock has a set of world points with associated descriptors (these descriptors should probably also be associated with a viewing direction since the appearance of a rock can change with viewing direction)
-

Graph SLAM:

- In the graph, we have a landmark for each rock corresponding to its centroid point
- For each keyframe, we determine the observed pixel centroid of the rock based on segmentation outputs, and use that to add reprojection factors


# Detection

For each left, right image pair:

- For each image: a list of rock detections, where each detection is (pixel mask, pixel centroid, detected keypoints, keypoint descriptors, descriptor scores)
- Triangulated 3D points


In [None]:
# First frame index to process; the same value is assigned again in later cells
START_FRAME = 80

In [None]:
# Stereo pair inspected in the detection cells below
FRAME = 200
left_image, right_image = left_imgs[FRAME], right_imgs[FRAME]

# Left and right image side by side, without axes
fig, ax = plt.subplots(1, 2, figsize=(18, 10))
for panel, image in zip(ax, (left_image, right_image)):
    panel.imshow(image, cmap="gray")
    panel.axis("off")
plt.subplots_adjust(wspace=0.03)
plt.show()

In [None]:
# Segment rocks in both views
left_masks, left_labels = segmentation.segment_rocks(left_image)
right_masks, right_labels = segmentation.segment_rocks(right_image)

# Clip the label output to {0, 1} to get one binary mask per view
left_full_mask, right_full_mask = (
    np.clip(label_img, 0, 1).astype(np.uint8) for label_img in (left_labels, right_labels)
)

# Tint the masked pixels in the original images
rock_tint = (1, 0, 0)
left_seg_overlay = overlay_mask(left_image, left_full_mask, color=rock_tint)
right_seg_overlay = overlay_mask(right_image, right_full_mask, color=rock_tint)

fig, ax = plt.subplots(1, 2, figsize=(18, 10))
for panel, overlay in zip(ax, (left_seg_overlay, right_seg_overlay)):
    panel.imshow(overlay)
    panel.axis("off")
plt.subplots_adjust(wspace=0.03)
plt.show()

In [None]:
# Extract features in both images and match them across the stereo pair
left_feats, right_feats, matches, depths = feature_tracker.process_stereo(left_image, right_image)

# Keep only the matched features: column 0 of `matches` selects from the left
# features, column 1 from the right features
left_matched_feats, right_matched_feats = (
    prune_features(feats, matches[:, column])
    for column, feats in enumerate((left_feats, right_feats))
)
left_matched_pts = left_matched_feats["keypoints"][0]
right_matched_pts = right_matched_feats["keypoints"][0]

In [None]:
# Keep only the matches whose keypoints fall on a rock mask in both views

# Grow both binary masks with a single 5x5 dilation pass
kernel = np.ones((5, 5), np.uint8)
left_full_mask_dilated = cv2.dilate(left_full_mask, kernel, iterations=1)
right_full_mask_dilated = cv2.dilate(right_full_mask, kernel, iterations=1)

# Positions (into the matched keypoint arrays) of matches that hit both dilated masks.
# Keypoints are (x, y); the masks are indexed [row, col] = [y, x].
rock_pt_idxs = [
    match_idx
    for match_idx, ((x_l, y_l), (x_r, y_r)) in enumerate(
        zip(left_matched_pts, right_matched_pts)
    )
    if left_full_mask_dilated[int(y_l), int(x_l)]
    and right_full_mask_dilated[int(y_r), int(x_r)]
]

In [None]:
# Select the keypoints and depths of the matches that were flagged as rock in both views.
# NOTE: `rock_pt_idxs` must still be the index list built by the filtering cell; a later
# cell rebinds the same name to a dict, so this cell only works when run in order.
left_rock_matched_pts = left_matched_pts[rock_pt_idxs]
right_rock_matched_pts = right_matched_pts[rock_pt_idxs]
depths_rock_matched = depths[rock_pt_idxs]

In [None]:
from lightglue import viz2d

# Draw the rock-only stereo matches on top of the two segmentation overlays
viz2d.plot_images([left_seg_overlay, right_seg_overlay])
viz2d.plot_matches(left_rock_matched_pts, right_rock_matched_pts, color="lime", lw=0.2)

In [None]:
# Turn the left-image rock keypoints into 3D points using their stereo depths and the
# logged pose of this frame, then plot them.
# NOTE(review): presumably the returned points are in the world frame -- confirm in FeatureTracker.
rock_points = feature_tracker.project_stereo(
    poses[FRAME], left_rock_matched_pts, depths_rock_matched
)
plot_3d_points(rock_points)

In [None]:
from lac.perception.segmentation import SemanticClasses
from lac.perception.segmentation_util import dilate_mask

# Binary rock mask from the semantic prediction, grown by one 5x5 dilation pass
kernel = np.ones((5, 5), np.uint8)
left_pred = segmentation.predict(left_image)
left_rock_mask = cv2.dilate(
    (left_pred == SemanticClasses.ROCK.value).astype(np.uint8), kernel, iterations=1
)

# Give every connected rock region its own integer label (0 is the background)
num_labels, labels = cv2.connectedComponents(left_rock_mask)

# Matches further away than this stereo depth are ignored
MAX_DEPTH = 5.0

# Connected-component label -> positions of the matches that fall on that component.
# NOTE: this rebinds `rock_pt_idxs`, which an earlier cell uses as a plain list.
rock_pt_idxs = {}

for match_idx, (left_pt, right_pt) in enumerate(zip(left_matched_pts, right_matched_pts)):
    if depths[match_idx] > MAX_DEPTH:
        continue
    x_left, y_left = left_pt
    x_right, y_right = right_pt
    row_left, col_left = int(y_left), int(x_left)
    # The keypoint has to be on rock in the left view ...
    if left_rock_mask[row_left, col_left] == 0:
        continue
    # ... and on the dilated rock mask in the right view
    if right_full_mask_dilated[int(y_right), int(x_right)] == 0:
        continue
    rock_pt_idxs.setdefault(labels[row_left, col_left], []).append(match_idx)

In [None]:
from lightglue import viz2d

viz2d.plot_images([left_seg_overlay, right_seg_overlay])
# Draw each rock's matches in its own randomly chosen RGB colour
for match_idxs in rock_pt_idxs.values():
    viz2d.plot_matches(
        left_matched_pts[match_idxs],
        right_matched_pts[match_idxs],
        color=np.random.rand(3),
        lw=0.2,
    )

# Tracking


In [None]:
from norfair import Detection, Tracker
import plotly.graph_objects as go
from lac.perception.depth import (
    stereo_depth_from_segmentation,
    project_pixel_to_world,
    project_pixels_to_world,
)
from lac.params import STEREO_BASELINE, FL_X
from lac.utils.plotting import plot_rock_map
from lac.utils.visualization import int_to_color

In [None]:
# Sorted frame indices for which a left image is available
img_idxs = sorted(left_imgs.keys())
# Norfair tracker that associates detections across frames by Euclidean distance
tracker = Tracker(distance_function="euclidean", distance_threshold=100, hit_counter_max=5)
# Filled by the tracking loop below with one entry per track id:
#   id -> {"frame": [frame_num, ...], "points": [3d point, ...], "pixels": [pixel, ...]}
rock_detections = {}

In [None]:
from lac.perception.segmentation_util import (
    get_mask_centroids,
    centroid_matching,
)


def get_valid_region(center, size, mask_shape):
    """Clip a box centred on `center` to a mask and return the matching slices.

    Args:
        center: (cx, cy) integer position of the box centre (x = column, y = row).
        size: (w, h) width and height of the box.
        mask_shape: (H, W) shape of the mask the box is placed on.

    Returns:
        Two tuples ``(x0, y0, x1, y1)`` and ``(sx0, sy0, sx1, sy1)``: the
        half-open bounds of the visible part of the box in mask coordinates
        and in box-local coordinates. Both regions always have the same
        extent. If the box lies completely outside the mask along an axis,
        both regions are empty along that axis (start == stop) instead of
        having a negative extent.
    """
    cx, cy = center
    w, h = size
    H, W = mask_shape

    # Unclipped box bounds in mask coordinates
    left = cx - w // 2
    top = cy - h // 2
    right = cx + (w + 1) // 2
    bottom = cy + (h + 1) // 2

    # Clamp into [0, W] x [0, H]; the stop is never allowed below the start so
    # that a fully out-of-bounds box gives an empty slice rather than a
    # negative index (which numpy would interpret as counting from the end).
    x0 = min(max(left, 0), W)
    y0 = min(max(top, 0), H)
    x1 = max(min(right, W), x0)
    y1 = max(min(bottom, H), y0)

    # Part of the box cut off at the left/top edge, capped at the box size
    sx0 = min(max(0, -left), w)
    sy0 = min(max(0, -top), h)
    sx1 = sx0 + (x1 - x0)
    sy1 = sy0 + (y1 - y0)

    return (x0, y0, x1, y1), (sx0, sy0, sx1, sy1)


def insert_submask(mask, submask, center):
    """Return a copy of `mask` with `submask` pasted in, centred on `center`.

    Args:
        mask: 2D array that receives the patch; it is not modified.
        submask: 2D patch to paste. Parts that fall outside `mask` are dropped.
        center: (cx, cy) position in `mask` where the patch centre goes.

    Returns:
        A new array like `mask` in which the covered region is overwritten
        with the visible part of `submask`.
    """
    patch_h, patch_w = submask.shape
    dst_bounds, src_bounds = get_valid_region(center, (patch_w, patch_h), mask.shape)
    dst_x0, dst_y0, dst_x1, dst_y1 = dst_bounds
    src_x0, src_y0, src_x1, src_y1 = src_bounds

    pasted = mask.copy()
    pasted[dst_y0:dst_y1, dst_x0:dst_x1] = submask[src_y0:src_y1, src_x0:src_x1]
    return pasted


def get_submask_at_center(mask, size, center, pad_value=0):
    """Cut a (h, w) window centred on `center` out of `mask`, padding where needed.

    Args:
        mask: 2D array to crop from.
        size: (w, h) width and height of the window.
        center: (cx, cy) integer position of the window centre (x = column, y = row).
        pad_value: Value used for window pixels that lie outside `mask`.

    Returns:
        Array of shape (h, w). Pixels inside `mask` are copied from it, the
        rest is filled with `pad_value`. A window that lies completely
        outside `mask` is returned filled entirely with `pad_value`.
    """
    cx, cy = center
    w, h = size
    H, W = mask.shape

    # Unclipped window bounds in mask coordinates
    x0 = cx - w // 2
    y0 = cy - h // 2
    x1 = x0 + w
    y1 = y0 + h

    # Padding per side, capped at the window size so that a window far outside
    # the mask still produces exactly (h, w) output
    px0 = min(max(0, -x0), w)
    py0 = min(max(0, -y0), h)
    px1 = min(max(0, x1 - W), w)
    py1 = min(max(0, y1 - H), h)

    # Slice bounds clamped into the mask; the stop never drops below the start,
    # which prevents a negative stop from being read as "count from the end"
    x0_clipped = min(max(x0, 0), W)
    y0_clipped = min(max(y0, 0), H)
    x1_clipped = max(min(x1, W), x0_clipped)
    y1_clipped = max(min(y1, H), y0_clipped)

    cropped = mask[y0_clipped:y1_clipped, x0_clipped:x1_clipped]
    if any([py0, py1, px0, px1]):
        cropped = np.pad(
            cropped, ((py0, py1), (px0, px1)), mode="constant", constant_values=pad_value
        )

    return cropped


# Triangulate one 3D point per rock and frame from the stereo segmentation masks.
points_world = []
START_FRAME = 80
END_FRAME = START_FRAME + 2000
# Rocks whose coarse centroid depth exceeds this value are skipped
MAX_ROCK_DEPTH = 5.0

for frame in tqdm(range(START_FRAME, END_FRAME, 2)):
    left_seg_masks, left_seg_labels = segmentation.segment_rocks(left_imgs[frame])
    right_seg_masks, right_seg_labels = segmentation.segment_rocks(right_imgs[frame])

    left_rock_centroids = get_mask_centroids(left_seg_masks)
    right_rock_centroids = get_mask_centroids(right_seg_masks)

    # Pair up rocks between the two views. A dedicated name is used so that the
    # feature `matches` computed by the stereo-matching cell are not overwritten.
    centroid_matches = centroid_matching(left_rock_centroids, right_rock_centroids)

    # Coarse depth from the horizontal centroid disparity. The values are put in
    # an array so the division is element-wise regardless of the scalar type of
    # the constants; the small offset avoids a division by exactly zero.
    disparities_centroids = np.array(
        [
            left_rock_centroids[match[0]][0] - right_rock_centroids[match[1]][0] + 1e-8
            for match in centroid_matches
        ],
        dtype=float,
    )
    depths_centroids = (FL_X * STEREO_BASELINE) / disparities_centroids

    # Visit the matched rocks in order of increasing coarse depth
    for i in np.argsort(depths_centroids):
        if depths_centroids[i] > MAX_ROCK_DEPTH:
            continue

        match = centroid_matches[i]
        left_centroid = left_rock_centroids[match[0]]
        right_centroid = right_rock_centroids[match[1]]
        left_mask = left_seg_masks[match[0]]
        right_mask = right_seg_masks[match[1]]

        # Bounding-box extent of the left mask
        mask_rows, mask_cols = np.where(left_mask)
        width = mask_cols.max() - mask_cols.min() + 1
        height = mask_rows.max() - mask_rows.min() + 1
        size = (width, height)

        # Box of the bounding-box size. numpy shapes are (rows, cols), i.e.
        # (height, width); building it from `size` directly would transpose it.
        box = np.ones((height, width), np.uint8)

        # Restrict each mask to the box centred on its own centroid.
        # (A search over small centroid offsets was sketched here but is disabled,
        # so the box is always placed exactly on the centroid.)
        new_mask_left = np.logical_and(
            left_mask, insert_submask(np.zeros_like(left_mask), box, left_centroid)
        )
        new_mask_right = np.logical_and(
            right_mask, insert_submask(np.zeros_like(right_mask), box, right_centroid)
        )

        # Centroid-aligned crops of both masks and the pixels set in both of them
        submask_left = get_submask_at_center(new_mask_left, size, left_centroid)
        submask_right = get_submask_at_center(new_mask_right, size, right_centroid)
        common_submask = np.logical_and(submask_left, submask_right)
        if not common_submask.any():
            # The two views share no pixel for this rock
            continue

        # Paste the common shape back into both images and recompute the
        # centroids; argwhere yields (row, col), reversed here to (x, y)
        left_mask_common = insert_submask(np.zeros_like(left_mask), common_submask, left_centroid)
        right_mask_common = insert_submask(
            np.zeros_like(right_mask), common_submask, right_centroid
        )
        new_left_centroid = np.argwhere(left_mask_common).mean(axis=0)[::-1]
        new_right_centroid = np.argwhere(right_mask_common).mean(axis=0)[::-1]

        # Refined depth from the recomputed centroids
        disparity = new_left_centroid[0] - new_right_centroid[0] + 1e-8
        depth = (FL_X * STEREO_BASELINE) / disparity
        points_world.append(
            project_pixel_to_world(poses[frame], new_left_centroid, depth, "FrontLeft", cam_config)
        )

        # Disabled alternative: project every pixel of the common mask instead of the centroid.
        # pixels_left = np.stack(np.where(left_mask_common)).T[:, ::-1]
        # pixels_right = np.stack(np.where(right_mask_common)).T[:, ::-1]
        # disparities = pixels_left[:, 0] - pixels_right[:, 0] + 1e-8
        # depths = (FL_X * STEREO_BASELINE) / disparities
        # points_world.extend(project_pixels_to_world(poses[frame], pixels_left, depths, "FrontLeft", cam_config))

points_world = np.array(points_world)

In [None]:
# Ground truth rock map and trajectory with the triangulated rock points on top
fig = plot_rock_map(map, fig=go.Figure())
fig = plot_poses(poses, fig=fig, no_axes=True, color="black")

rock_xs, rock_ys, rock_zs = (points_world[:, axis] for axis in range(3))
fig.add_scatter3d(
    x=rock_xs,
    y=rock_ys,
    z=rock_zs,
    mode="markers",
    marker={"size": 3, "color": "green"},
)
fig.show()

In [None]:
# END_FRAME = img_idxs[-1]
END_FRAME = 2000
# Detections with a stereo depth at or beyond this value are not tracked
MAX_TRACK_DEPTH = 5.0

# NOTE: this loop keeps updating the existing `tracker` and appends to `rock_detections`.
# Re-create both before running this cell a second time, otherwise observations are duplicated.
for frame in tqdm(range(2, END_FRAME, 2)):
    left_seg_masks, left_seg_labels = segmentation.segment_rocks(left_imgs[frame])
    right_seg_masks, right_seg_labels = segmentation.segment_rocks(right_imgs[frame])

    stereo_depth_results = stereo_depth_from_segmentation(
        left_seg_masks, right_seg_masks, STEREO_BASELINE, FL_X
    )

    # One tracker detection per nearby rock: the pixel centroid is what gets
    # tracked, the projected 3D point rides along as payload
    detections = []
    for stereo_result in stereo_depth_results:
        centroid = stereo_result["left_centroid"]
        depth = stereo_result["depth"]
        # NOTE(review): only an upper bound is checked; confirm that
        # stereo_depth_from_segmentation never reports non-positive depths.
        if depth < MAX_TRACK_DEPTH:
            rock_point_world_frame = project_pixel_to_world(
                poses[frame], centroid, depth, "FrontLeft", cam_config
            )
            detections.append(Detection(points=centroid, data=rock_point_world_frame))
    tracked_objects = tracker.update(detections)

    # Record the latest observation of every track returned by the tracker
    for rock in tracked_objects:
        track = rock_detections.setdefault(rock.id, {"frame": [], "points": [], "pixels": []})
        track["frame"].append(frame)
        track["points"].append(rock.last_detection.data)
        track["pixels"].append(rock.last_detection.points[0])

In [None]:
# Ground truth rock map and trajectory, plus every tracked rock in its own colour
fig = plot_rock_map(map, fig=go.Figure())
fig = plot_poses(poses, fig=fig, no_axes=True, color="black")

for rock_id, track in rock_detections.items():
    track_points = np.array(track["points"])
    track_color = int_to_color(rock_id, hex=True)

    # All individual observations of this track (small markers)
    fig = plot_3d_points(
        track_points, fig=fig, color=track_color, markersize=2, name=f"rock_{rock_id}"
    )

    # Mean of the observations (larger marker)
    mean_point = np.mean(track_points, axis=0)
    fig = plot_3d_points(
        mean_point[None, :],
        fig=fig,
        color=track_color,
        markersize=5,
        name=f"rock_{rock_id}_avg",
    )

fig.show()

# SLAM


In [None]:
import gtsam
from gtsam.symbol_shorthand import X, L

from lac.slam.visual_odometry import StereoVisualOdometry
from lac.slam.slam import ROVER_T_CAM
from lac.params import FL_X, FL_Y, IMG_HEIGHT, IMG_WIDTH

In [None]:
# Frame range covered by the visual odometry (every second frame is used)
START_FRAME = 80
END_FRAME = 4500

# Start the stereo VO at START_FRAME with the known initial pose
svo = StereoVisualOdometry(cam_config)
svo.initialize(initial_pose, left_imgs[START_FRAME], right_imgs[START_FRAME])

# Run the VO once up front and keep its outputs:
#   svo_poses[k]   -> rover pose after tracking frame START_FRAME + 2 * k
#   pose_deltas[k] -> pose delta reported for the step from k to k + 1
svo_poses = [initial_pose]
pose_deltas = []

for idx in tqdm(range(START_FRAME + 2, END_FRAME, 2)):
    svo.track(left_imgs[idx], right_imgs[idx])
    svo_poses.append(svo.rover_pose)
    pose_deltas.append(svo.pose_delta)

In [None]:
# Noise on the 2D rock centroid observations and the camera calibration
# (principal point at the image centre, zero skew)
PIXEL_NOISE = gtsam.noiseModel.Isotropic.Sigma(2, 5.0)
K = gtsam.Cal3_S2(FL_X, FL_Y, 0.0, IMG_WIDTH / 2, IMG_HEIGHT / 2)

# Same sigma on all six components of a VO pose delta
svo_pose_sigma = np.full(6, 1e-2)
svo_pose_noise = gtsam.noiseModel.Diagonal.Sigmas(svo_pose_sigma)

# Empty graph and initial estimate; the first pose is fixed to the known initial pose
graph = gtsam.NonlinearFactorGraph()
values = gtsam.Values()

first_pose = gtsam.Pose3(initial_pose)
values.insert(X(0), first_pose)
graph.add(gtsam.NonlinearEqualityPose3(X(0), first_pose))

In [None]:
# Pose variable X(k) holds the VO estimate svo_poses[k], which was produced at
# frame START_FRAME + 2 * k. X(0) (the initial pose at START_FRAME) is already in
# `values`, so it only needs an entry in the frame lookup.
frame_to_i = {START_FRAME: 0}

# Add poses and VO odometry. The frames iterated here are exactly the frames the
# VO tracked, so that frame_to_i[frame] refers to the pose estimated at `frame`.
i = 1
for frame in range(START_FRAME + 2, END_FRAME, 2):
    frame_to_i[frame] = i
    # values.insert(X(i), gtsam.Pose3(poses[frame]))
    values.insert(X(i), gtsam.Pose3(svo_poses[i]))
    graph.push_back(
        gtsam.BetweenFactorPose3(X(i - 1), X(i), gtsam.Pose3(pose_deltas[i - 1]), svo_pose_noise)
    )
    i += 1
# `i` now equals the number of pose variables and is not modified below.

# Tracker id -> landmark index, assigned in order of first use
active_rock_ids = {}
rock_id_count = 0

# Add rock landmarks and observations
for rock_id, detections in rock_detections.items():
    points = np.array(detections["points"])
    pixels = np.array(detections["pixels"])
    frames = np.array(detections["frame"])
    # Initial landmark position: per-axis median of all triangulated observations
    median_point = np.median(points, axis=0)

    for obs_frame, obs_pixel in zip(frames, pixels):
        # Observations from frames without a pose variable cannot be used
        if obs_frame not in frame_to_i:
            continue

        # Create the landmark only once it has at least one usable observation
        if rock_id not in active_rock_ids:
            active_rock_ids[rock_id] = rock_id_count
            rock_id_count += 1
            values.insert(L(active_rock_ids[rock_id]), median_point)

        # A separate name is used for the pose index so the pose counter `i`
        # keeps its value for the cells that follow
        pose_idx = frame_to_i[obs_frame]
        graph.add(
            gtsam.GenericProjectionFactorCal3_S2(
                obs_pixel, PIXEL_NOISE, X(pose_idx), L(active_rock_ids[rock_id]), K, ROVER_T_CAM
            )
        )

In [None]:
# Optimize the factor graph with Levenberg-Marquardt, starting from the estimates in `values`
params = gtsam.LevenbergMarquardtParams()
params.setVerbosity("TERMINATION")
optimizer = gtsam.LevenbergMarquardtOptimizer(graph, values, params)
result = optimizer.optimize()

In [None]:
# Collect the optimized pose matrices X(0), X(1), ... in order. The number of pose
# variables is taken from frame_to_i, which holds exactly one entry per X(k), rather
# than from the loop counter `i`, which is only correct if no later code rebinds it.
opt_poses = [result.atPose3(X(k)).matrix() for k in range(len(frame_to_i))]

In [None]:
# Compare the optimized trajectory (green) with the logged ground truth poses (black).
# The ground truth slice contains every frame from 80 on, while the graph has one pose
# per second frame, so the two lists differ in length.
fig = plot_poses(poses[80:END_FRAME], no_axes=True, color="black", name="Ground Truth")
fig = plot_poses(opt_poses, no_axes=True, fig=fig, color="green", name="Opt")
fig.show()