In [None]:
import numpy as np
from tqdm import tqdm
from pathlib import Path
from gtsam.symbol_shorthand import X

from lac.slam.gtsam_factor_graph import GtsamFactorGraph
from lac.slam.slam import SLAM
from lac.slam.visual_odometry import StereoVisualOdometry
from lac.slam.feature_tracker import FeatureTracker
from lac.utils.plotting import plot_poses, plot_surface, plot_3d_points
from lac.util import load_data, load_stereo_images
from lac.params import LAC_BASE_PATH

%load_ext autoreload
%autoreload 2

In [None]:
# Load the recorded run: poses, IMU data, and camera configuration
base_path = Path(LAC_BASE_PATH)
data_path = base_path / "output/DataCollectionAgent/map1_preset0_stereo_lights1.0"
initial_pose, lander_pose, poses, imu_data, cam_config = load_data(data_path)
print(f"Loaded {len(poses)} poses")

# Load the stereo image pairs recorded in the same run
left_imgs, right_imgs = load_stereo_images(data_path)

# Load the ground truth height map.
# NOTE(review): `map` shadows the Python builtin; kept because later cells reference it.
# NOTE(review): allow_pickle=True lets np.load execute code embedded in the file — only use on trusted data.
map = np.load(
    base_path / "data/heightmaps/competition/Moon_Map_01_preset_0.dat",
    allow_pickle=True,
)

In [None]:
# Initialize the VO front end, the feature tracker and the factor graph at START_FRAME
START_FRAME = 80

# Both VO and the feature tracker are seeded with the same stereo pair.
# NOTE(review): initial_pose is paired with frame START_FRAME rather than frame 0 —
# confirm the rover has not moved before START_FRAME.
start_left = left_imgs[START_FRAME]
start_right = right_imgs[START_FRAME]

svo = StereoVisualOdometry(cam_config)
svo.initialize(initial_pose, start_left, start_right)

tracker = FeatureTracker(cam_config)
tracker.initialize(initial_pose, start_left, start_right)

# Pose key 0 is the initial pose, with vision factors from the tracker's first observations
graph = SLAM()
pose_key = 0
graph.add_pose(pose_key, initial_pose)
graph.add_vision_factors(pose_key, tracker.world_points, tracker.prev_pts, tracker.track_ids)

In [None]:
# Loop configuration (frame indices into the recorded image sequence)
IMG_RATE = 2  # process every IMG_RATE-th frame
KEYFRAME_RATE = 10  # currently unused: track_keyframe is called on every processed frame
GRAPH_UPDATE_RATE = 10  # add a graph pose on frames where step % GRAPH_UPDATE_RATE == 0
GRAPH_OPTIMIZE_RATE = 1000  # optimize on frames where step % GRAPH_OPTIMIZE_RATE == 0
END_FRAME = 4000  # last frame processed (inclusive)

curr_pose = initial_pose
svo_poses = [initial_pose]
opt_first_key = pose_key


def optimize_window(first_key: int, last_key: int) -> None:
    """Optimize the graph over pose keys ``first_key``..``last_key`` (inclusive).

    Uses ``tqdm.write`` rather than ``print`` so the message does not corrupt the
    progress bar of the surrounding ``tqdm`` loop.
    """
    tqdm.write(f"Optimizing window {first_key} to {last_key}")
    graph.optimize(list(range(first_key, last_key + 1)))


# Main loop over image frames
for step in tqdm(range(START_FRAME + IMG_RATE, END_FRAME + 1, IMG_RATE)):
    # Run VO for real-time pose tracking
    svo.track(left_imgs[step], right_imgs[step])
    curr_pose = svo.get_pose()
    svo_poses.append(curr_pose)

    # Track features using the current VO pose (called on every processed frame)
    tracker.track_keyframe(curr_pose, left_imgs[step], right_imgs[step])

    # Add a pose and its vision factors to the graph
    if step % GRAPH_UPDATE_RATE == 0:
        pose_key += 1
        graph.add_pose(pose_key, curr_pose)
        graph.add_vision_factors(
            pose_key, tracker.world_points, tracker.prev_pts, tracker.track_ids
        )

    # Optimize the graph over the poses added since the previous optimization
    if step % GRAPH_OPTIMIZE_RATE == 0:
        optimize_window(opt_first_key, pose_key)
        # The next window starts at this window's last pose, so consecutive windows share one pose
        opt_first_key = pose_key

# Optimize trailing poses that no scheduled optimization covered. This happens only when
# graph poses were added after the last frame divisible by GRAPH_OPTIMIZE_RATE.
if pose_key > opt_first_key:
    optimize_window(opt_first_key, pose_key)
    opt_first_key = pose_key

In [None]:
# Collect the graph's pose estimates in the iteration order of graph.poses
graph_poses = list(graph.poses.values())

In [None]:
# Stack every landmark held by the graph into one array, one row per landmark
# (assumes each landmark value is a 3D point — TODO confirm against SLAM.landmarks)
landmark_points = np.vstack([*graph.landmarks.values()])

In [None]:
from lac.utils.geometry import crop_points

# Crop box bounds: |x|, |y| <= MAX_XY and MIN_Z <= z <= MAX_Z
MAX_XY = 20.0
MIN_Z = 0.0
MAX_Z = 10.0

# scene_bbox rows are [lower corner, upper corner]
bbox_lower = [-MAX_XY, -MAX_XY, MIN_Z]
bbox_upper = [MAX_XY, MAX_XY, MAX_Z]
scene_bbox = np.array([bbox_lower, bbox_upper])

# Presumably keeps only the landmarks inside scene_bbox — semantics live in crop_points
landmark_points_cropped = crop_points(landmark_points, scene_bbox)

In [None]:
# Compare the landmark array shapes before and after cropping
(landmark_points.shape, landmark_points_cropped.shape)

In [None]:
# Overlay ground truth, VO and graph trajectories plus cropped landmarks on the ground truth map surface
fig = plot_surface(map, showscale=False)
# The processing loop includes END_FRAME (range(..., END_FRAME + 1, ...)), so include it in the GT slice too
fig = plot_poses(poses[: END_FRAME + 1], fig=fig, no_axes=True, color="black", name="Ground truth")
fig = plot_poses(svo_poses, fig=fig, no_axes=True, color="orange", name="VO poses")
fig = plot_poses(graph_poses, fig=fig, no_axes=True, color="green", name="Graph poses")
fig = plot_3d_points(landmark_points_cropped, fig=fig, color="lightblue", name="Landmarks")
fig.update_layout(
    title=f"Trajectories and landmarks, frames {START_FRAME}-{END_FRAME}",
    height=900,
    width=1600,
    scene_aspectmode="data",
)
fig.show()