In [None]:
import numpy as np
from tqdm import tqdm
from pathlib import Path
import gtsam
from gtsam.symbol_shorthand import X
import matplotlib.pyplot as plt
import plotly.graph_objects as go

from lac.slam.gtsam_factor_graph import GtsamFactorGraph
from lac.slam.slam import SLAM
from lac.slam.gtsam_util import remove_outliers, plot_reprojection_residuals
from lac.slam.visual_odometry import StereoVisualOdometry
from lac.slam.feature_tracker import FeatureTracker
from lac.utils.plotting import plot_poses, plot_surface, plot_3d_points
from lac.utils.visualization import image_grid
from lac.util import load_data, load_stereo_images, load_images, positions_rmse_from_poses

# Reload all imported modules (e.g. `lac.*` while it is being edited) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# Read the logged run: initial/lander poses, per-frame poses, IMU stream, and camera config.
data_path = "/home/shared/data_raw/LAC/runs/full_spiral_map1_preset0"
(
    initial_pose,
    lander_pose,
    poses,
    imu_data,
    cam_config,
) = load_data(data_path)
print(f"Loaded {len(poses)} poses")

In [None]:
# Load the front stereo pair once; later cells index these containers by frame number.
# NOTE(review): with start_frame=0 this presumably makes the index equal the frame id — confirm.
left_imgs, right_imgs = load_stereo_images(data_path, start_frame=0, end_frame=10000)
images = dict(FrontLeft=left_imgs, FrontRight=right_imgs)

In [None]:
# Ground-truth heightmap for competition map 01 / preset 0 (used by the optional surface plots).
# NOTE(review): `map` shadows the Python builtin, and `allow_pickle=True` permits unpickling
# object arrays (which can execute arbitrary code) — only load trusted files here.
map = np.load(
    allow_pickle=True,
    file="/home/shared/data_raw/LAC/heightmaps/competition/Moon_Map_01_preset_0.dat",
)

In [None]:
# Seed visual odometry, the feature tracker, and the factor graph from the
# ground-truth pose and stereo pair at START_FRAME.
START_FRAME = 80

start_pose = poses[START_FRAME]
start_left = images["FrontLeft"][START_FRAME]
start_right = images["FrontRight"][START_FRAME]

svo = StereoVisualOdometry(cam_config)
svo.initialize(start_pose, start_left, start_right)

tracker = FeatureTracker(cam_config)
tracker.initialize(start_pose, start_left, start_right)

# Graph pose key 0 is anchored at the START_FRAME ground-truth pose.
graph = SLAM()
graph.add_pose(0, start_pose)
graph.add_vision_factors(0, tracker)

In [None]:
# Run stereo VO frame by frame and, every GRAPH_UPDATE_RATE frames, add a pose
# with vision and odometry factors to the factor graph.
IMG_RATE = 2
GRAPH_UPDATE_RATE = 2
END_FRAME = 1000

# Seed every trajectory with the pose that VO, the tracker, and graph key 0 were
# initialized from (poses[START_FRAME]). Index i of svo_poses / eval_poses then
# refers to the same frame as graph key i. Seeding with `initial_pose` paired
# graph key 0 with a different frame in the RMSE evaluation and drew a spurious
# first VO segment.
curr_pose = poses[START_FRAME]
svo_poses = [curr_pose]
eval_poses = [curr_pose]
pose_key = 1

# Main loop over image frames (END_FRAME is processed inclusively)
for step in tqdm(range(START_FRAME + 1, END_FRAME + 1)):
    # graph.accumulate_imu_measurement(imu_data[step])

    if step % IMG_RATE == 0:
        # Run VO for real-time pose tracking
        svo.track(images["FrontLeft"][step], images["FrontRight"][step])
        curr_pose = svo.get_pose()
        svo_poses.append(curr_pose)

        # Track features
        tracker.track_keyframe(curr_pose, images["FrontLeft"][step], images["FrontRight"][step])

        # Add to the graph.
        # NOTE(review): this is nested under the IMG_RATE check and uses
        # svo.pose_delta from the latest track() call, which presumably spans
        # only IMG_RATE frames. Keep GRAPH_UPDATE_RATE == IMG_RATE unless the
        # delta is accumulated between consecutive graph poses — confirm.
        if step % GRAPH_UPDATE_RATE == 0:
            graph.add_pose(pose_key, curr_pose)
            graph.add_vision_factors(pose_key, tracker)
            # graph.add_imu_factor(pose_key)
            graph.add_odometry_factor(pose_key, svo.pose_delta)

            # Ground truth for this graph key, consumed by the RMSE cell.
            eval_poses.append(poses[step])

            pose_key += 1

In [None]:
# Assemble the graph over every pose key added so far (0 .. pose_key - 1).
g, v, landmarks = graph.build_graph([*range(pose_key)])

In [None]:
# Optimize over the full window of pose keys with outlier removal enabled.
window = [key for key in range(pose_key)]
result, g, v = graph.optimize(window, remove_outliers=True, verbose=True)

In [None]:
# Graph trajectory (without landmarks) overlaid on ground truth and raw VO.
fig = graph.plot(show_landmarks=False)
# The main loop processes frame END_FRAME inclusively, so include it in the ground truth.
fig = plot_poses(poses[: END_FRAME + 1], fig=fig, no_axes=True, color="black", name="Ground truth")
fig = plot_poses(svo_poses, fig=fig, no_axes=True, color="orange", name="VO poses")
fig.update_layout(height=900, width=1600, scene_aspectmode="data")
fig.show()

In [None]:
# Plot reprojection residuals for `g` and `result` as returned by `graph.optimize` above.
plot_reprojection_residuals(g, result)

In [None]:
# Position RMSE of the graph poses against the ground-truth poses collected alongside each graph key.
# NOTE(review): confirm `graph.poses` holds the optimized estimates (not the initial guesses) after optimize.
opt_poses = list(graph.poses.values())
positions_rmse_from_poses(opt_poses, eval_poses)

In [None]:
# Subsampled graph view overlaid on ground truth and raw VO.
# NOTE(review): start/end/step semantics are defined by SLAM.plot — confirm end=1000 is meant as a key/frame bound.
fig = graph.plot(start=0, end=1000, step=50)
# fig = plot_surface(map, fig=fig, showscale=False)
# The main loop processes frame END_FRAME inclusively, so include it in the ground truth.
fig = plot_poses(poses[: END_FRAME + 1], fig=fig, no_axes=True, color="black", name="Ground truth")
fig = plot_poses(svo_poses, fig=fig, no_axes=True, color="orange", name="VO poses")
fig.update_layout(height=900, width=1600, scene_aspectmode="data")
fig.show()

In [None]:
# Debug: error of one projection factor evaluated at the optimized values `result`.
# NOTE(review): 200 is a hard-coded key into `graph.projection_factors` (presumably a landmark id);
# this lookup may fail if that key is absent in a different run.
graph.projection_factors[200][0].error(result)

In [None]:
# Snapshot the graph's pose estimates, in the order `graph.poses` yields them, for plotting.
graph_poses = list(graph.poses.values())

In [None]:
# Side-by-side trajectories: ground truth, raw VO, and factor-graph estimates.
fig = go.Figure()
# fig = plot_surface(map, fig=fig, showscale=False)
# The main loop processes frame END_FRAME inclusively, so include it in the ground truth.
fig = plot_poses(poses[: END_FRAME + 1], fig=fig, no_axes=True, color="black", name="Ground truth")
fig = plot_poses(svo_poses, fig=fig, no_axes=True, color="orange", name="VO poses")
fig = plot_poses(graph_poses, fig=fig, no_axes=True, color="green", name="Graph poses")
fig.update_layout(height=900, width=1600, scene_aspectmode="data")
fig.show()

In [None]:
# Zoomed graph view; the returned figure is displayed but not assigned to `fig`.
# NOTE(review): start/end/step semantics are defined by SLAM.plot — confirm the intended range.
graph.plot(start=0, end=100, step=10)

In [None]:
# Export whichever figure is currently bound to `fig`. In top-to-bottom order that is the
# trajectory-comparison figure, because the zoomed graph.plot cell does not reassign `fig`.
fig.write_html("gtsam_slam.html")