In [None]:
import numpy as np
import cv2
from tqdm import tqdm
from pathlib import Path
from gtsam.symbol_shorthand import X
import matplotlib.pyplot as plt
from lac.perception.depth import (
    stereo_depth_from_segmentation,
    project_pixel_to_world,
    project_pixels_to_world,
)

from lac.slam.feature_tracker import FeatureTracker, prune_features
from lac.perception.segmentation import UnetSegmentation
from lac.utils.plotting import plot_poses, plot_surface, plot_3d_points
from lac.util import load_data, load_stereo_images, load_side_images
from lac.utils.visualization import overlay_mask
from lac.params import LAC_BASE_PATH

%load_ext autoreload
%autoreload 2

In [None]:
# --- Run logs ---------------------------------------------------------------
# Alternative dataset, kept for reference:
# data_path = Path("/home/shared/data_raw/LAC/segmentation/slam_map1_preset1_teleop")
data_path = Path("/home/shared/data_raw/LAC/runs/full_spiral_map1_preset1_recovery_agent")
(
    initial_pose,
    lander_pose,
    poses,
    imu_data,
    cam_config,
) = load_data(data_path)
print(f"Loaded {len(poses)} poses")

# --- Images -----------------------------------------------------------------
# `df` is handed to the loaders as `step`; later cells also use it as the frame
# offset between the two views of a mono pair (presumably only every df-th frame
# is loaded -- confirm against the loaders).
df = 8
front_stereo = load_stereo_images(data_path, step=df)
side_stereo = load_side_images(data_path, step=df)
left_imgs, right_imgs = front_stereo
side_left_imgs, side_right_imgs = side_stereo

# --- Ground-truth map -------------------------------------------------------
# NOTE: `allow_pickle=True` executes pickled objects on load; only use it on
# trusted files. The name `map` shadows the Python builtin but is kept because
# later cells refer to it.
heightmap_path = "/home/shared/data_raw/LAC/heightmaps/competition/Moon_Map_01_preset_1.dat"
map = np.load(heightmap_path, allow_pickle=True)

In [None]:
# Rock segmentation model; `get_rock_centroids` and the tracking loop call
# `segmentation.segment_rocks(...)` on it.
segmentation = UnetSegmentation()
# Feature tracker built from the camera config returned by `load_data`.
feature_tracker = FeatureTracker(cam_config)

Frontend:

- Run both segmentation and feature extraction on left and right images
- Triangulate feature matches. For keypoints labeled as rock, group them together

Rock map:

- Each rock has a set of world points with associated descriptors (these descriptors should probably also be associated with a viewing direction, since the appearance of a rock can change with viewing direction)
-

Graph SLAM:

- In the graph, we have a landmark for each rock corresponding to its centroid point
- For each keyframe, we determine the observed pixel centroid of the rock based on segmentation outputs, and use that to add reprojection factors


# Tracking


In [None]:
import plotly.graph_objects as go
from lac.perception.depth import stereo_depth_from_segmentation, project_pixel_to_world
from lac.utils.plotting import plot_rock_map
from lac.utils.visualization import int_to_color

In [None]:
# NOTE(review): this repeats the instantiation done right after the data-loading
# cell. Re-running it builds a fresh segmentation model and a fresh tracker;
# presumably that is only wanted to reset state -- confirm or drop this cell.
segmentation = UnetSegmentation()
feature_tracker = FeatureTracker(cam_config)

In [None]:
from lac.perception.vision import get_camera_intrinsics
from lac.utils.frames import get_cam_pose_rover, CAMERA_TO_OPENCV_PASSIVE
from scipy.linalg import inv

# Sanity check for two-view (mono) triangulation with known poses:
#   1. place synthetic points in the camera frame of the previous pose,
#   2. project them into both views,
#   3. triangulate with cv2.triangulatePoints and compare the recovered depth
#      in the current view against the depth known by construction.
# def get_depths_mono(cam_name, points0, points1, prev_pose, curr_pose):

frame = 200
cam_name = "Right"
# Take the images from the camera named by `cam_name` (this cell used to read the
# left side camera while computing with the "Right" intrinsics/extrinsics).
img0 = side_right_imgs[frame - df]
img1 = side_right_imgs[frame]
prev_pose = poses[frame - df]
curr_pose = poses[frame]

# Camera intrinsics and extrinsics
K = get_camera_intrinsics(cam_name, cam_config)
rover_T_cam = get_cam_pose_rover(cam_name)
# Re-express the camera axes in the OpenCV convention; translation is unchanged.
rover_T_cam_ocv = rover_T_cam.copy()
rover_T_cam_ocv[:3, :3] = rover_T_cam_ocv[:3, :3] @ CAMERA_TO_OPENCV_PASSIVE

# Projection matrices (world -> pixel) for the two views
cam_T_world_0 = inv(prev_pose @ rover_T_cam_ocv)
cam_T_world_1 = inv(curr_pose @ rover_T_cam_ocv)

P0 = K @ cam_T_world_0[:3]
P1 = K @ cam_T_world_1[:3]

# Synthetic test points, expressed in the camera frame of the previous pose
points_cam = np.array(
    [
        [0.0, 0.0, 2.0],
        [-1.0, 0.0, 3.0],
        [0.0, -1.5, 4.0],
    ]
)
points_cam_h = np.hstack((points_cam, np.ones((points_cam.shape[0], 1))))
points_world = (prev_pose @ rover_T_cam_ocv @ points_cam_h.T).T
points_world_h = np.hstack((points_world[:, :3], np.ones((points_world.shape[0], 1))))
points_cam_0 = (cam_T_world_0 @ points_world_h.T).T[:, :3]
points_cam_1 = (cam_T_world_1 @ points_world_h.T).T[:, :3]

# Pinhole projection into each image, then perspective divide
points0 = (K @ points_cam_0.T).T
points0 = (points0[:, :2].T / points0[:, 2]).T
points1 = (K @ points_cam_1.T).T
points1 = (points1[:, :2].T / points1[:, 2]).T

# Triangulate
points_4d_h = cv2.triangulatePoints(P0, P1, points0.T, points1.T)
points_3d_est = (points_4d_h[:3] / points_4d_h[3]).T
# Depth = z coordinate of the triangulated points in the current camera frame
depths_est = (cam_T_world_1[:3, :3] @ points_3d_est.T + cam_T_world_1[:3, 3:4]).T[:, 2]

print(depths_est)

# The true depths are known by construction, so show them next to the estimate.
depths_true = points_cam_1[:, 2]
print("expected depths:", depths_true)
print("max abs depth error:", np.max(np.abs(depths_est - depths_true)))

# return points_3d_next, depths_next

In [None]:
from lac.perception.matching import get_depths_mono, get_submask, insert_submask

# Smoke call of the library implementation with one hand-picked pixel pair.
frame = 200
cam_name = "Right"
# Keep the images consistent with `cam_name` (this cell used to read the left side camera).
img0 = side_right_imgs[frame - df]
img1 = side_right_imgs[frame]
prev_pose = poses[frame - df]
curr_pose = poses[frame]

# Last expression of the cell, so the notebook displays the returned value.
get_depths_mono(
    cam_name,
    cam_config,
    np.array([640, 360]),
    np.array([704.33040004, 377.72892952]),
    prev_pose,
    curr_pose,
)

In [None]:
from lac.perception.segmentation_util import get_mask_centroids, centroid_matching
from lac.params import FL_X, STEREO_BASELINE

# Scratch inputs for trying the functions below on a single mono pair.
frame = 200
cam_name = "Right"
# Keep the images consistent with `cam_name` (this cell used to read the left side camera).
img0 = side_right_imgs[frame - df]
img1 = side_right_imgs[frame]
prev_pose = poses[frame - df]
curr_pose = poses[frame]

from multiprocessing import Pool


def process_match(args):
    """Refine one matched pair of rock masks and return the rock centroid in the world frame.

    Args:
        args: 8-tuple ``(centroid0, centroid1, mask0, mask1, cam_name, cam_config,
            curr_pose, prev_pose)``, packed into one argument. ``centroid0`` /
            ``mask0`` belong to the first image of the pair and ``centroid1`` /
            ``mask1`` to the second. ``prev_pose`` is only used for cameras whose
            name does not contain "Front".

    Returns:
        The value returned by ``project_pixel_to_world`` for the refined centroid
        of the first image, or ``None`` when either mask sums to less than 100 or
        the estimated depth is not inside the open interval (0, 5).
    """
    centroid0, centroid1, mask0, mask1, cam_name, cam_config, curr_pose, prev_pose = args

    # Skip masks that are too small to be matched.
    if mask0.sum() < 100 or mask1.sum() < 100:
        return None

    # Bounding-box extent of mask0, as (extent along axis 0, extent along axis 1),
    # which is also the shape of the window used for the overlap test.
    nonzero = np.where(mask0)
    axis0_idx, axis1_idx = nonzero[0], nonzero[1]
    size = (
        np.max(axis0_idx) - np.min(axis0_idx) + 1,
        np.max(axis1_idx) - np.min(axis1_idx) + 1,
    )

    # The first-image window does not depend on the shift, so build it once.
    center0 = (centroid0[0], centroid0[1])
    mask0_new = np.logical_and(
        mask0, insert_submask(np.zeros_like(mask0), np.ones(size, np.uint8), center0)
    )
    submask0 = get_submask(mask0_new, size, centroid0)

    # NOTE(review): range(-2, 2) covers -2..1 only, and the shifted centre is used
    # to window mask1 but not to extract `submask1` or to place the final mask, so
    # the shift can only clip the overlap. Intent unclear -- left as in the original.
    best_submask, best_score = None, 0
    for dx in range(-2, 2):
        for dy in range(-2, 2):
            center1 = (centroid1[0] + dx, centroid1[1] + dy)
            mask1_new = np.logical_and(
                mask1, insert_submask(np.zeros_like(mask1), np.ones(size, np.uint8), center1)
            )
            submask1 = get_submask(mask1_new, size, centroid1)
            submask = np.logical_and(submask0, submask1)
            score = np.sum(submask)
            if score > best_score:
                best_score, best_submask = score, submask

    if best_submask is not None:
        # Re-compute both centroids from the region the two masks have in common.
        mask0_common = insert_submask(np.zeros_like(mask0), best_submask, centroid0)
        mask1_common = insert_submask(np.zeros_like(mask1), best_submask, centroid1)
        # argwhere yields (axis0, axis1) indices; reverse to match the centroid order.
        centroid0_new = np.argwhere(mask0_common).mean(axis=0)[::-1]
        centroid1_new = np.argwhere(mask1_common).mean(axis=0)[::-1]
    else:
        centroid0_new, centroid1_new = centroid0, centroid1

    if "Front" in cam_name:
        # Stereo pair: depth from horizontal disparity (epsilon avoids division by zero).
        disparity = centroid0_new[0] - centroid1_new[0] + 1e-8
        depth = (FL_X * STEREO_BASELINE) / disparity
        world_cam = "FrontLeft"
    else:
        # Mono pair from one camera at two poses. `cam_config` is passed as the
        # second argument, matching the direct call of get_depths_mono in this notebook.
        _, depth = get_depths_mono(
            cam_name, cam_config, centroid0_new, centroid1_new, prev_pose, curr_pose
        )
        depth = depth[0]
        world_cam = cam_name

    # Reject implausible depths before paying for the projection.
    if not 0 < depth < 5:
        return None
    return project_pixel_to_world(curr_pose, centroid0_new, depth, world_cam, cam_config)


def get_rock_centroids(cam_name, img0, img1, curr_pose, prev_pose=None, refine_centroids=True):
    """Segment rocks in an image pair and return their centroids in the world frame.

    Relies on the notebook-level ``segmentation`` model and ``cam_config``.

    Args:
        cam_name: Camera name. Names containing "Front" are treated as a stereo
            pair (depth from disparity); any other name is treated as one camera
            observed at two poses (depth from ``get_depths_mono``).
        img0: First image of the pair.
        img1: Second image of the pair.
        curr_pose: Rover pose used to project the centroid into the world frame.
        prev_pose: Rover pose of the earlier view. Required when ``cam_name`` does
            not contain "Front"; unused otherwise.
        refine_centroids: If True, each match goes through ``process_match``, which
            also drops matches with small masks or with depth outside (0, 5). If
            False, raw mask centroids are used and no depth filter is applied.

    Returns:
        List of world-frame centroids (empty when either image has no rock mask).

    Raises:
        ValueError: If a non-"Front" camera is used without ``prev_pose``.
    """
    is_stereo = "Front" in cam_name
    if not is_stereo and prev_pose is None:
        # Fail up front instead of deep inside the mono depth computation.
        raise ValueError(f"prev_pose is required for mono camera '{cam_name}'")

    seg_masks0, _ = segmentation.segment_rocks(img0)
    seg_masks1, _ = segmentation.segment_rocks(img1)
    if len(seg_masks0) == 0 or len(seg_masks1) == 0:
        return []
    centroids0, centroids1 = get_mask_centroids(seg_masks0), get_mask_centroids(seg_masks1)
    matches = centroid_matching(centroids0, centroids1, max_y_diff=100, max_x_diff=100)

    all_centroids = []
    if not refine_centroids:
        for match in matches:
            centroid0, centroid1 = centroids0[match[0]], centroids1[match[1]]
            if is_stereo:
                # Epsilon avoids division by zero for coincident centroids.
                disparity = centroid0[0] - centroid1[0] + 1e-8
                depth = (FL_X * STEREO_BASELINE) / disparity
                centroid = project_pixel_to_world(
                    curr_pose, centroid0, depth, "FrontLeft", cam_config
                )
            else:
                # `cam_config` is passed as the second argument, matching the direct
                # call of get_depths_mono in this notebook.
                _, depth = get_depths_mono(
                    cam_name, cam_config, centroid0, centroid1, prev_pose, curr_pose
                )
                depth = depth[0]
                centroid = project_pixel_to_world(curr_pose, centroid0, depth, cam_name, cam_config)
            all_centroids.append(centroid)
        return all_centroids

    # Refine: process_match returns None for rejected matches.
    for match in matches:
        refined = process_match(
            (
                centroids0[match[0]],
                centroids1[match[1]],
                seg_masks0[match[0]],
                seg_masks1[match[1]],
                cam_name,
                cam_config,
                curr_pose,
                prev_pose,
            )
        )
        if refined is not None:
            all_centroids.append(refined)

    return all_centroids

In [None]:
# Collect unrefined rock centroids from the front stereo pair over the whole run.
START_FRAME = 80
# For a short test run use `START_FRAME + 100` here instead.
END_FRAME = max(left_imgs.keys())

centroids_no_refine = []
for frame in tqdm(range(START_FRAME, END_FRAME, 4 * df)):
    stereo_pair = (left_imgs[frame], right_imgs[frame])
    centroids_no_refine += get_rock_centroids(
        "FrontLeft", *stereo_pair, poses[frame], refine_centroids=False
    )
    # Side cameras (disabled): call get_rock_centroids("Left"/"Right", ...) on
    # side_left_imgs / side_right_imgs at frames (frame - df, frame), passing
    # poses[frame] and poses[frame - df]. The original draft looped
    # `for df in range(10, 14, 2)`, which would rebind the global stride `df`.

In [None]:
# Collect refined rock centroids (refine_centroids defaults to True) from the
# front stereo pair over the whole run.
START_FRAME = 80
# For a short test run use `START_FRAME + 100` here instead.
END_FRAME = max(left_imgs.keys())

centroids = []
for frame in tqdm(range(START_FRAME, END_FRAME, 4 * df)):
    stereo_pair = (left_imgs[frame], right_imgs[frame])
    centroids += get_rock_centroids("FrontLeft", *stereo_pair, poses[frame])
    # Side cameras (disabled): call get_rock_centroids("Left"/"Right", ...) on
    # side_left_imgs / side_right_imgs at frames (frame - df, frame), passing
    # poses[frame] and poses[frame - df]. The original draft looped
    # `for df in range(10, 14, 2)`, which would rebind the global stride `df`.

In [None]:
# Collect unrefined rock centroids over (at most) the first 5000 frames after START_FRAME.
START_FRAME = 80
# Clamp to the last loaded frame so that a run shorter than START_FRAME + 5000
# does not index past the available images.
END_FRAME = min(START_FRAME + 5000, max(left_imgs.keys()))

centroids_1st = []
for frame in tqdm(range(START_FRAME, END_FRAME, 4 * df)):
    centroids_1st.extend(
        get_rock_centroids(
            "FrontLeft", left_imgs[frame], right_imgs[frame], poses[frame], refine_centroids=False
        )
    )
    # Side cameras (disabled): call get_rock_centroids("Left"/"Right", ...) on
    # side_left_imgs / side_right_imgs at frames (frame - df, frame), passing
    # poses[frame] and poses[frame - df]. The original draft looped
    # `for df in range(10, 14, 2)`, which would rebind the global stride `df`.

In [None]:
# Extent of the ground-truth map, read from its first two channels (used below as
# x and y coordinates), and the grid size along the first two axes.
map_x, map_y = map[:, :, 0], map[:, :, 1]
xmin, xmax = map_x.min(), map_x.max()
ymin, ymax = map_y.min(), map_y.max()
nx, ny = map.shape[0], map.shape[1]

In [None]:
from lac.perception.matching import get_rocks_score

# Rasterise the first-pass centroids onto the map grid and score them against the ground truth.
centroids_1st = np.array(centroids_1st)
agent_map = np.zeros_like(map)
for c in centroids_1st:
    # np.floor instead of int() truncation: int() rounds toward zero, so a point
    # up to one cell below xmin / ymin would land in row / column 0 instead of
    # being rejected by the bounds check.
    # NOTE(review): this treats [xmin, xmax] as spanning nx cells; if the map
    # stores cell centres the scale should be (nx - 1) -- confirm.
    i = int(np.floor((c[0] - xmin) / (xmax - xmin) * nx))
    j = int(np.floor((c[1] - ymin) / (ymax - ymin) * ny))
    if 0 <= i < nx and 0 <= j < ny:
        agent_map[i, j, 3] = 1.0  # channel 3 is the rock flag compared below

print(get_rocks_score(map, agent_map))

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(map[:, :, 3], cmap="gray")
ax[0].set_title("Ground truth rocks")
ax[1].imshow(agent_map[:, :, 3], cmap="gray")
ax[1].set_title("Detected rocks (first pass)")
plt.show()

In [None]:
from lac.perception.matching import get_rocks_score

# Rasterise the unrefined centroids onto the map grid and score them against the ground truth.
centroids_no_refine = np.array(centroids_no_refine)
agent_map = np.zeros_like(map)
for c in centroids_no_refine:
    # np.floor instead of int() truncation: int() rounds toward zero, so a point
    # up to one cell below xmin / ymin would land in row / column 0 instead of
    # being rejected by the bounds check.
    # NOTE(review): this treats [xmin, xmax] as spanning nx cells; if the map
    # stores cell centres the scale should be (nx - 1) -- confirm.
    i = int(np.floor((c[0] - xmin) / (xmax - xmin) * nx))
    j = int(np.floor((c[1] - ymin) / (ymax - ymin) * ny))
    if 0 <= i < nx and 0 <= j < ny:
        agent_map[i, j, 3] = 1.0  # channel 3 is the rock flag compared below

print(get_rocks_score(map, agent_map))

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(map[:, :, 3], cmap="gray")
ax[0].set_title("Ground truth rocks")
ax[1].imshow(agent_map[:, :, 3], cmap="gray")
ax[1].set_title("Detected rocks (no refinement)")
plt.show()

In [None]:
from lac.perception.matching import get_rocks_score

# Rasterise the refined centroids onto the map grid and score them against the ground truth.
agent_map = np.zeros_like(map)
for c in centroids:
    # np.floor instead of int() truncation: int() rounds toward zero, so a point
    # up to one cell below xmin / ymin would land in row / column 0 instead of
    # being rejected by the bounds check.
    # NOTE(review): this treats [xmin, xmax] as spanning nx cells; if the map
    # stores cell centres the scale should be (nx - 1) -- confirm.
    i = int(np.floor((c[0] - xmin) / (xmax - xmin) * nx))
    j = int(np.floor((c[1] - ymin) / (ymax - ymin) * ny))
    if 0 <= i < nx and 0 <= j < ny:
        agent_map[i, j, 3] = 1.0  # channel 3 is the rock flag compared below

print(get_rocks_score(map, agent_map))

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(map[:, :, 3], cmap="gray")
ax[0].set_title("Ground truth rocks")
ax[1].imshow(agent_map[:, :, 3], cmap="gray")
ax[1].set_title("Detected rocks (refined)")
plt.show()

In [None]:
# 3D view: ground-truth rock map, the logged trajectory, and the estimated rock centroids.
fig = go.Figure()
fig = plot_rock_map(map, fig=fig)
fig = plot_poses(poses, fig=fig, no_axes=True, color="black")
# END_FRAME may point past the last logged pose; clamp before indexing.
last_pose_idx = min(END_FRAME, len(poses) - 1)
fig = plot_poses([poses[last_pose_idx]], fig=fig, no_axes=False, color="black")
# `centroids` is built as a Python list, which cannot be column-sliced; convert first.
centroids_xyz = np.asarray(centroids)
if centroids_xyz.ndim == 2 and len(centroids_xyz) > 0:
    fig.add_scatter3d(
        x=centroids_xyz[:, 0],
        y=centroids_xyz[:, 1],
        z=centroids_xyz[:, 2],
        mode="markers",
        marker=dict(size=3, color="green"),
    )
fig.show()

In [None]:
# Detect rocks in every second frame from the front stereo pair, feed the
# detections to the tracker and record each tracked rock's observations.
# NOTE(review): `Detection`, `tracker` and `rock_detections` are not defined
# anywhere in this notebook; they must be created before this cell can run.
# NOTE(review): frames are visited with stride 2 while the loading cell passes
# step=df to the image loaders -- confirm these frames are actually loaded.
# END_FRAME = img_idxs[-1]
END_FRAME = 2000
for frame in tqdm(range(2, END_FRAME, 2)):
    left_seg_masks, _ = segmentation.segment_rocks(left_imgs[frame])
    right_seg_masks, _ = segmentation.segment_rocks(right_imgs[frame])

    stereo_depth_results = stereo_depth_from_segmentation(
        left_seg_masks, right_seg_masks, STEREO_BASELINE, FL_X
    )

    # Only rocks closer than 5.0 become detections. The notebook-level `centroids`
    # (world points used by the plotting and scoring cells) is no longer
    # overwritten here with an unused per-frame pixel list.
    detections = []
    for result in stereo_depth_results:
        centroid = result["left_centroid"]
        depth = result["depth"]
        if depth < 5.0:
            rock_point_world_frame = project_pixel_to_world(
                poses[frame], centroid, depth, "FrontLeft", cam_config
            )
            # The pixel centroid drives the tracker; the world point rides along as payload.
            detections.append(Detection(points=centroid, data=rock_point_world_frame))
    tracked_objects = tracker.update(detections)

    for rock in tracked_objects:
        centroid_pixel = rock.last_detection.points[0]
        if rock.id not in rock_detections:
            rock_detections[rock.id] = {"frame": [], "points": [], "pixels": []}
        rock_detections[rock.id]["frame"].append(frame)
        rock_detections[rock.id]["points"].append(rock.last_detection.data)
        rock_detections[rock.id]["pixels"].append(centroid_pixel)

In [None]:
# Overlay every tracked rock on the ground-truth map: all of its world-frame
# observations (small markers) plus their mean (large marker), in one colour per track.
fig = plot_poses(
    poses,
    fig=plot_rock_map(map, fig=go.Figure()),
    no_axes=True,
    color="black",
)
for id, detections in rock_detections.items():
    points = np.array(detections["points"])
    avg_point = np.mean(points, axis=0)
    traces = (
        (points, 2, f"rock_{id}"),
        (avg_point[None, :], 5, f"rock_{id}_avg"),
    )
    for trace_points, trace_size, trace_name in traces:
        fig = plot_3d_points(
            trace_points,
            fig=fig,
            color=int_to_color(id, hex=True),
            markersize=trace_size,
            name=trace_name,
        )
fig.show()

# SLAM


In [None]:
import gtsam
from gtsam.symbol_shorthand import X, L

from lac.slam.visual_odometry import StereoVisualOdometry
from lac.slam.slam import ROVER_T_CAM
from lac.params import FL_X, FL_Y, IMG_HEIGHT, IMG_WIDTH

In [None]:
# Run stereo visual odometry on the front stereo pair and cache the per-step
# rover poses and pose deltas for the factor graph built below.
svo = StereoVisualOdometry(cam_config)
START_FRAME = 80
# NOTE(review): VO is seeded with `initial_pose` but with the images of
# START_FRAME; presumably the rover has not moved before START_FRAME -- confirm.
svo.initialize(initial_pose, left_imgs[START_FRAME], right_imgs[START_FRAME])

# Pre-process the VO
# svo_poses[0] is the seed pose at START_FRAME; svo_poses[k] / pose_deltas[k - 1]
# come from the k-th tracked frame, i.e. frame START_FRAME + 2 * k.
svo_poses = [initial_pose]
pose_deltas = []

END_FRAME = 4500

# NOTE(review): frames are visited with stride 2 while the loading cell passes
# step=df to the image loaders, and END_FRAME is hard-coded -- confirm that all
# of these frames are loaded.
for idx in tqdm(np.arange(START_FRAME + 2, END_FRAME, 2)):
    svo.track(left_imgs[idx], right_imgs[idx])
    svo_poses.append(svo.rover_pose)
    pose_deltas.append(svo.pose_delta)

In [None]:
# Empty factor graph and initial estimate.
graph = gtsam.NonlinearFactorGraph()
values = gtsam.Values()

# Noise models: isotropic sigma of 5.0 on the 2D reprojection error, and a sigma
# of 1e-2 on each of the 6 pose components of a VO between-factor.
PIXEL_NOISE = gtsam.noiseModel.Isotropic.Sigma(2, 5.0)
svo_pose_sigma = np.full(6, 1e-2)
svo_pose_noise = gtsam.noiseModel.Diagonal.Sigmas(svo_pose_sigma)

# Pinhole calibration with zero skew and the principal point at the image centre.
# Note: this rebinds `K`, which an earlier cell uses for a camera intrinsics value.
K = gtsam.Cal3_S2(FL_X, FL_Y, 0.0, IMG_WIDTH / 2, IMG_HEIGHT / 2)

# Anchor the first pose with a hard equality constraint.
initial_pose3 = gtsam.Pose3(initial_pose)
values.insert(X(0), initial_pose3)
graph.add(gtsam.NonlinearEqualityPose3(X(0), initial_pose3))

In [None]:
# Map image frame number -> pose index in the graph.
# X(0) is the pose VO was initialised with at START_FRAME, and svo_poses[k] /
# pose_deltas[k - 1] come from tracking frame START_FRAME + 2 * k. The mapping
# therefore starts at START_FRAME -> 0; the previous {0: 0} plus a loop starting
# at START_FRAME attached every frame to the pose of the following VO step.
frame_to_i = {START_FRAME: 0}

# Add poses and VO odometry
i = 1
for frame in range(START_FRAME + 2, END_FRAME, 2):
    frame_to_i[frame] = i
    # values.insert(X(i), gtsam.Pose3(poses[frame]))
    values.insert(X(i), gtsam.Pose3(svo_poses[i]))
    graph.push_back(
        gtsam.BetweenFactorPose3(X(i - 1), X(i), gtsam.Pose3(pose_deltas[i - 1]), svo_pose_noise)
    )
    i += 1
# `i` now equals the number of poses in the graph; it is deliberately not reused
# below because the pose read-back cell iterates over range(i).

# Tracker rock id -> consecutive landmark index, assigned on first use.
active_rock_ids = {}
rock_id_count = 0

# Add rock landmarks and observations
for rock_id, rock_obs in rock_detections.items():
    points = np.array(rock_obs["points"])
    pixels = np.array(rock_obs["pixels"])
    frames = np.array(rock_obs["frame"])
    # Initial landmark guess: per-axis median of the rock's world-frame observations.
    landmark_init = np.median(points, axis=0)

    for j in range(len(frames)):
        # Skip observations from frames that have no pose in the graph.
        if frames[j] not in frame_to_i:
            continue

        # Create the landmark only once it has at least one usable observation.
        if rock_id not in active_rock_ids:
            active_rock_ids[rock_id] = rock_id_count
            rock_id_count += 1
            values.insert(L(active_rock_ids[rock_id]), landmark_init)

        pose_idx = frame_to_i[frames[j]]
        graph.add(
            gtsam.GenericProjectionFactorCal3_S2(
                pixels[j], PIXEL_NOISE, X(pose_idx), L(active_rock_ids[rock_id]), K, ROVER_T_CAM
            )
        )

In [None]:
# Solve the factor graph with Levenberg-Marquardt, starting from `values`.
params = gtsam.LevenbergMarquardtParams()
# Verbosity level "TERMINATION" is set on the optimizer parameters.
params.setVerbosity("TERMINATION")
optimizer = gtsam.LevenbergMarquardtOptimizer(graph, values, params)
# `result` holds the optimised estimates read back in the next cell.
result = optimizer.optimize()

In [None]:
# Read back every optimised pose X(0)..X(n-1) as a matrix. The pose count is
# taken from `frame_to_i` rather than from the loop counter `i`, which other
# cells may rebind to something other than the number of poses.
num_poses = max(frame_to_i.values()) + 1
opt_poses = [result.atPose3(X(k)).matrix() for k in range(num_poses)]

In [None]:
# Compare the optimised trajectory against the logged poses over the same frame
# window. START_FRAME replaces the hard-coded 80 so the window follows the VO cell.
fig = plot_poses(poses[START_FRAME:END_FRAME], no_axes=True, color="black", name="Ground Truth")
fig = plot_poses(opt_poses, no_axes=True, fig=fig, color="green", name="Opt")
fig.show()