In [1]:
import numpy as np
import torch
import theseus as th
import torchlie.functional as lieF # use this instead of th.SE3
import matplotlib.pyplot as plt
from scipy.spatial import KDTree
from scipy.spatial.transform import Rotation
from typing import Union, List, Tuple, Optional, cast

In [2]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def torch2np(tensor: torch.Tensor) -> np.ndarray:
    """ Converts a PyTorch tensor to a NumPy ndarray.
    Args:
        tensor: The PyTorch tensor to convert.
    Returns:
        A NumPy ndarray with the same data and dtype as the input tensor.
    """
    return tensor.detach().cpu().numpy()

In [3]:
class GaussianSLAMEdge:
    def __init__(
        self,
        vertex_idx_i: int,
        vertex_idx_j: int,
        relative_pose: th.SE3,
        cost_weight: th.CostWeight
    ):
        self.vertex_idx_i = vertex_idx_i
        self.vertex_idx_j = vertex_idx_j
        self.relative_pose = relative_pose
        self.cost_weight = cost_weight
        

class GaussianSLAMPoseGraph:
    def __init__(
        self, 
        requires_auto_grad = True
    ):
        self._requires_auto_grad = requires_auto_grad
        self._objective = th.Objective()
        self._theseus_inputs = {} 
    
    def add_odometry_edge(
            self,
            vertex_i: th.SE3,
            vertex_j: th.SE3,
            edge: GaussianSLAMEdge,
            gaussian_means: torch.Tensor
        ):

        if self._requires_auto_grad:
            for point in gaussian_means:
                cost_function = th.AutoDiffCostFunction(
                    optim_vars=[vertex_i, vertex_j], 
                    err_fn=GaussianSLAMPoseGraph.dense_surface_alignment, 
                    dim=3, 
                    cost_weight=edge.cost_weight, 
                    aux_vars=[edge.relative_pose, th.Point3(tensor=point)]
                )
                self._objective.add(cost_function)
            self._theseus_inputs.update({
                vertex_i.name: vertex_i.tensor, 
                vertex_j.name: vertex_j.tensor
            })
        else:
            raise NotImplementedError()
    
    def add_loop_closure_edge(
            self,
            vertex_i: th.SE3,
            vertex_j: th.SE3,
            edge: GaussianSLAMEdge,
            gaussian_means: torch.tensor,
            match_num : int,
            tau: float=0.2, # fairly liberal distance threshold
        ):

        cost_weight_registration = edge.cost_weight # for dense surface alignment
        cost_weight_mu = cost_weight_registration.scale.tensor.squeeze() * np.sqrt(match_num) * tau
        print(f"cost_weight_mu = {cost_weight_mu}")
        cost_weight_line_process = th.ScaleCostWeight(cost_weight_mu) # for line process
        
        l_ij = th.Vector(tensor=torch.ones(1, 1), name=f"line_process_{edge.vertex_idx_i}_{edge.vertex_idx_j}")    

        if self._requires_auto_grad:
            for point in gaussian_means:
                cost_function_registration = th.AutoDiffCostFunction(
                    optim_vars=[vertex_i, vertex_j, l_ij], 
                    err_fn=GaussianSLAMPoseGraph.dense_surface_alignment, 
                    dim=3, 
                    cost_weight=cost_weight_registration, 
                    aux_vars=[edge.relative_pose, th.Point3(tensor=point)]
                )
                self._objective.add(cost_function_registration)

            cost_function_line_process = th.AutoDiffCostFunction(
                optim_vars=[l_ij,], 
                err_fn=GaussianSLAMPoseGraph.line_process, 
                dim=1, 
                cost_weight=cost_weight_line_process
            ) # auxiliary variables can be not declared
            self._objective.add(cost_function_line_process)
        
            self._theseus_inputs.update({
                vertex_i.name: vertex_i.tensor, 
                vertex_j.name: vertex_j.tensor,
                l_ij.name: l_ij.tensor
            })
        else:
            raise NotImplementedError()

    
    def _remove_loop_outlier(self, threshold: float, substring="line_process"):
        """ 
        find all l_ij inside the objective, remove those whose value are smaller than threshold 
        and all cost functions that are connected to them
        (currently implemented with the help of a dictionary called self._theseus_inputs, can it be
        directly done with self._objective?)
        """
        for key in self._theseus_inputs.keys():
            if substring in key and self._objective.get_optim_var(key).tensor < threshold:
                del self._theseus_inputs[key]
                for cost_func in self._objective.get_functions_connected_to_optim_var(key):
                    self._objective.erase(cost_func.name)

    def _optimize(self, 
            max_iterations=1e3, 
            step_size=0.01, 
            damping=0.1,
            track_best_solution=True, 
            verbose=False
        ):
        optimizer = th.LevenbergMarquardt(
            objective=self._objective,
            max_iterations=max_iterations,
            step_size=step_size
        )
        layer = th.TheseusLayer(optimizer)

        with torch.no_grad():
            _, info = layer.forward(
                #self._theseus_inputs, # 这个要输入吗？
                optimizer_kwargs={"damping":damping, "track_best_solution":track_best_solution, "verbose":verbose}
                )
        return info

    def optimize_two_steps(
            self, 
            max_iterations=1e3, 
            step_size=0.01, 
            l_ij_threshold=0.25,
            damping=0.1,
            track_best_solution=True, 
            verbose=False
        ):
        """
        optimization in two steps: 
        1. optimize with initial guess of all optim variables (T_i, l_ij)
        2. remove all l_ij < threshold, and all cost functions that are connected to them,
           optimize again with optimized variables
        """

        print(f"First step optimization, dealing with {self._objective.size_cost_functions()} cost functions")
        _ = self._optimize(max_iterations, step_size, damping, track_best_solution, verbose)
        self._remove_loop_outlier(threshold=l_ij_threshold)
        print(f"Second step optimization, dealing with {self._objective.size_cost_functions()} cost functions")
        info = self._optimize(max_iterations, step_size, damping, track_best_solution,verbose)
        return info

    @ staticmethod
    def dense_surface_alignment(
        optim_vars: Union[Tuple[th.SE3, th.SE3], Tuple[th.SE3, th.SE3, th.Vector]],
        aux_vars: Tuple[th.SE3, th.Point3]
    ) -> torch.Tensor:
        """
        Compute the dense surface alignment error between two vertices, can be used as the error
        function input to instantiate a th.CostFunction variable
    
        Args:
            optim_vars: optimizaiton variables registered in cost function, should contain
                pose_i, pose_j: correction matrix for pose i, j
                l_ij (optional): line process coefficient
    
            aux_vars: auxiliary variables registered in cost function, should contain
                relative_pose: constraint between vertex_i and vertex_j
                gaussian_means_i: mean positions of the 3D Gaussians inside camera frustum, 
                    represented in coordinate i and coordinate (those in coordinate j are not needed).
                    should have shape (batch_size, num_points, dim)
    
        Returns:
            square root of global place recognition error
        """
        # determine whether the edge is odometry edge or loop closure edge
        tuple_size = len(optim_vars)
        if tuple_size == 2:
            pose_i, pose_j = optim_vars
        elif tuple_size == 3:
            pose_i, pose_j, l_ij = optim_vars
        else:
            raise ValueError(f"optim_vars tuple size is {tuple_size}, which can only be 2 or 3.")
        relative_pose, gaussian_means_i = aux_vars
    
        #print(f"gaussian_means shape = {gaussian_means_i.shape}")
        
        # transform all points in coordinate i and j to world coordinate
        gaussian_means_i_transformed: th.Point3 = pose_i.transform_from(gaussian_means_i)
        gaussian_means_j_transformed: th.Point3 = pose_j.transform_from(
            relative_pose.transform_from(gaussian_means_i))
    
        residual = (gaussian_means_i_transformed - gaussian_means_j_transformed).tensor
    
        #gaussian_means_w1 = lieF.SE3.transform(pose_i.tensor, gaussian_means_i.tensor)
        #gaussian_means_j = lieF.SE3.transform(relative_pose.tensor, gaussian_means_i.tensor)
        #gaussian_means_w2 = lieF.SE3.transform(pose_j.tensor, gaussian_means_j)
        #residual = gaussian_means_w1 - gaussian_means_w2
    
        # check if this error function is used for odometry edge or loop edge
        if tuple_size == 2:
            return residual
        else:
            return torch.sqrt(l_ij.tensor) * residual

    
    @ staticmethod
    def line_process(optim_vars: th.Vector, aux_vars=None) -> torch.Tensor:
        """
        Computes the line process error of a loop closrue edge, can be used as the error
        input to instantiate a th.CostFunction variable
    
        Args:
            optim_vars:
                l_ij: jointly optimized weight (l_ij ∈ [0, 1]) over the loop edges
                (note that the scaling factor mu is considered as cost_weight)
    
        Returns:
            square root of line process error
        """
        l_ij, = optim_vars
        return l_ij.tensor.sqrt() - 1

    
    @ staticmethod
    def match_gaussian_means(
        pts_1: torch.tensor,
        pts_2: torch.tensor,
        transformation: torch.tensor,
        epsilon:float=10e-2
    ) -> List[Tuple[int, int]]:
        """
        Select inlier correspondences from two Gaussian clouds, use kd-tree to speed up
    
        Args:
            pts_1, pts_2: mean positions of 3D Gaussians
            transformation: prior transformation matrix from one Gaussian cloud to the other
            epsilon: threshold for finding inlier correspondence
    
        Returns:
            a list contains tuples of matching indices
        """
        if transformation.size() != torch.Size([4, 4]):
            raise ValueError(f"The size of input transformation matrix must be (4, 4), but get {transformation.size()}")
    
        if pts_1.size(-1) != 1:
            pts_1 = pts_1.unsqueeze(-1)
    
        if isinstance(pts_1, th.Point3) or isinstance(pts_2, th.Point3):
                raise TypeError("To be matched points must be torch.Tensor")
    
        rotation = transformation[:3, :3]
        translation = transformation[:3, 3]
        pts_1_new = (rotation @ pts_1).squeeze() + translation
    
        pts_1_numpy = torch2np(pts_1_new)
        pts_2_numpy = torch2np(pts_2)
        pts2_kdtree = KDTree(pts_2_numpy)
    
        _, query_idx = pts2_kdtree.query(pts_1_numpy, distance_upper_bound=epsilon, workers=-1)
    
        data_size = pts_1.size()[0]
        res_list = []
        for i in range(data_size):
            if query_idx[i] != data_size:
                res_list.append((i, query_idx[i]))
    
        return res_list, len(res_list)

In [4]:
def create_data(
        num_pts: int = 500, 
        num_poses: int = 10, 
        translation_noise: float = 0.05, 
        rotation_noise: float = 0.1, 
        weight: float = 1.0,
        batch_size: int = 1,
        #dtype = torch.float32 # will get error if changed to torch.float64, don't know why
        ) -> Tuple[List[th.Point3], List[th.SE3], List[th.SE3], List[GaussianSLAMEdge]]:
    """
    create point clouds represented in different coordinates, record their ground truth 
    absolute pose, noisy absolute pose, also return an empty list to put loop edges

    Returns:
        point_list: a list stores points clouds, represented in different coordinates
        abs_pose_list_gt: a list stores ground truth absolute poses
        abs_pose_list: a list stores noisy (odometry) absolute poses
        edge_list: a list stores custum GaussianSLAMEdge
        TODO: Do I need to put the first edge that connets vertex_0 and vertex_1 into the list?
    """

    points_0 = th.Point3(2*torch.rand(num_pts, 3)-1, name="POINT_CLOUD__0") # initial points in world frame
    point_list = [points_0] # represented in different frames
    abs_pose_list_gt = [] # frame i to world frame
    abs_pose_list = [] # frame i to world frame (noisy)
    edge_list = []

    abs_pose_list_gt.append(th.SE3(
        #tensor=torch.tile(torch.eye(3, 4), [1, 1, 1]),
        tensor=torch.eye(3, 4).unsqueeze(0),
        name="VERTEX_SE3_GT__0"
        ))
    
    abs_pose_list.append(th.SE3(
        #tensor=torch.tile(torch.eye(3, 4), [1, 1, 1]),
        tensor=torch.eye(3, 4).unsqueeze(0),
        name="VERTEX_SE3__0"
        ))

    for idx in range(1, num_poses):

        # ground truth relative pose from frame_{idx-1} to frame_{idx}
        relative_pose_gt = th.SE3.exp_map(
            torch.cat([
                torch.rand(batch_size, 3) - 0.5, 
                 2.0 * torch.rand(batch_size, 3 ) -1
            ], dim=1),
        )

        # generate points represented in frame_{idx}
        points = relative_pose_gt.transform_from(point_list[-1])
        points.name = f"POINT_CLOUD__{idx}"
        point_list.append(points)

        # add noise to get odometry relative pose from frame_{idx-1} to frame_{idx}
        
        relative_pose_noise = th.SE3.exp_map(
            torch.cat([
                translation_noise * (2.0 * torch.rand(batch_size, 3) - 1),
                rotation_noise * (2.0 * torch.rand(batch_size, 3) - 1),
            ] ,dim=1),
        )

        relative_pose = cast(th.SE3, relative_pose_noise.compose(relative_pose_gt))
        #relative_pose = cast(th.SE3, relative_pose_gt.compose(relative_pose_noise))
        relative_pose.name = f"EDGE_SE3__{idx-1}_{idx}"
        cost_weight = th.ScaleCostWeight(weight, name=f"EDGE_WEIGHT__{idx-1}_{idx}")

        # absolute pose of frame_{idx}
        absolute_pose_gt = cast(th.SE3, abs_pose_list_gt[-1].compose(relative_pose_gt.inverse()))
        #absolute_pose_gt = cast(th.SE3, abs_pose_list_gt[-1].compose(relative_pose_gt))
        absolute_pose_gt.name = f"VERTEX_SE3_GT__{idx}"

        absolute_pose = cast(th.SE3, abs_pose_list[-1].compose(relative_pose.inverse()))
        #absolute_pose = cast(th.SE3, abs_pose_list[-1].compose(relative_pose))
        absolute_pose.name = f"VERTEX_SE3__{idx}"

        abs_pose_list_gt.append(absolute_pose_gt)
        abs_pose_list.append(absolute_pose)

        # construct odometry edge between vertex_{idx-1} and vertex_{idx}
        edge_list.append(GaussianSLAMEdge(idx-1, idx, relative_pose, cost_weight))

    return point_list, abs_pose_list_gt, abs_pose_list, edge_list


def add_loop_data(
        i: int, 
        j: int, 
        abs_pose_list_gt: List[th.SE3], 
        edge_list: List[GaussianSLAMEdge],
        weight: float = 1.0,
        measurement_noise:float = 0.01,
        batch_size: int = 1,
        ) -> None:
    """
    Add loop closure between two arbitray coordinates i and j (i < j), and stores generated edge
    """

    if i >= j:
        raise ValueError(f"The first frame index {i} is greater than the second frame index {j}!")

    abs_pose_i_gt = abs_pose_list_gt[i]
    abs_pose_j_gt = abs_pose_list_gt[j]
    rel_pose_ij_gt = th.SE3.compose(abs_pose_j_gt.inverse(), abs_pose_i_gt)
    rel_pose_ij_gt.name = f"EDGE_SE3_GT__{i}_{j}"

    relative_pose_noise = th.SE3.exp_map(
            torch.cat([
                measurement_noise * (2.0 * torch.rand(batch_size, 3) - 1),
                measurement_noise * (2.0 * torch.rand(batch_size, 3) - 1),
            ] ,dim=1),
            )
    rel_pose_ij = cast(th.SE3, relative_pose_noise.compose(rel_pose_ij_gt))
    rel_pose_ij.name = f"EDGE_SE3__{i}_{j}"

    cost_weight = th.ScaleCostWeight(weight, name=f"EDGE_WEIGHT__{i}_{j}")
    edge = GaussianSLAMEdge(i, j, rel_pose_ij, cost_weight)
    edge_list.append(edge)

In [5]:
point_list, abs_pose_gt_list, abs_pose_list, edge_list = create_data(
    num_poses=10,
    num_pts=100,
    rotation_noise=1e-3,
    translation_noise=1e-3
)
add_loop_data(0, 7, abs_pose_gt_list, edge_list)
add_loop_data(1, 8, abs_pose_gt_list, edge_list)
add_loop_data(2, 9, abs_pose_gt_list, edge_list)
#add_loop_data(0, 9, abs_pose_gt_list, edge_list)


rot_error_before = []
trans_error_before = []
for idx in range(len(abs_pose_list)):
    abs_pose = abs_pose_list[idx]
    rot = abs_pose.rotation().tensor.squeeze()
    trans = abs_pose.translation().tensor
    #trans = abs_pose.to_x_y_z_quaternion()[..., :3]
    
    abs_pose_gt = abs_pose_gt_list[idx]
    rot_gt = abs_pose_gt.rotation().tensor.squeeze()
    #trans_gt = abs_pose_gt.translation().tensor
    trans_gt = abs_pose_gt.to_x_y_z_quaternion()[..., :3]
    
    rot_error_before.append(torch.acos((torch.trace(torch.matmul(rot.t(), rot_gt)) - 1) / 2))
    #trans_error_before.append(torch.abs(trans @ trans_gt.t().squeeze()))
    trans_error_before.append(torch.norm(trans - trans_gt))


print("Constructing a pose graph for Gaussian Splatting SLAM.")
pose_graph = GaussianSLAMPoseGraph(requires_auto_grad=True)

for edge in edge_list:
    vertex_idx_i = edge.vertex_idx_i
    vertex_idx_j = edge.vertex_idx_j
    
    vertex_i = abs_pose_list[vertex_idx_i]
    vertex_j = abs_pose_list[vertex_idx_j]

    if vertex_idx_j - vertex_idx_i == 1:
        print("adding an odometry edge to pose graph.")
        pose_graph.add_odometry_edge(vertex_i, vertex_j, edge, point_list[vertex_idx_i].tensor)
    else:
        print("adding an loop edge to pose graph.")
        inlier_idx, num_matches = GaussianSLAMPoseGraph.match_gaussian_means(
            point_list[vertex_idx_i].tensor, point_list[vertex_idx_j].tensor, edge.relative_pose.to_matrix().squeeze(), epsilon=5e-2)
        inlier_idx_i = [idx[0] for idx in inlier_idx]
        pose_graph.add_loop_closure_edge(vertex_i, vertex_j, edge, point_list[vertex_idx_i].tensor[inlier_idx_i, :], num_matches)


Constructing a pose graph for Gaussian Splatting SLAM.
adding an odometry edge to pose graph.
adding an odometry edge to pose graph.
adding an odometry edge to pose graph.
adding an odometry edge to pose graph.
adding an odometry edge to pose graph.
adding an odometry edge to pose graph.
adding an odometry edge to pose graph.
adding an odometry edge to pose graph.
adding an odometry edge to pose graph.
adding an loop edge to pose graph.
cost_weight_mu = 2.0
adding an loop edge to pose graph.
cost_weight_mu = 2.0
adding an loop edge to pose graph.
cost_weight_mu = 2.0


In [6]:
#print(pose_graph._objective.error().shape)
pose_graph._optimize(max_iterations=1e3, step_size=0.01, damping=0.01, verbose=False)
#info = pose_graph.optimize_two_steps(max_iterations=1e3, step_size=0.01, l_ij_threshold=0.25, damping=0.2, verbose=False)
#print(info)

NonlinearOptimizerInfo(best_solution={'VERTEX_SE3__0': tensor([[[ 1.0000e+00,  1.7481e-03, -3.3121e-03,  2.1088e-03],
         [-1.7480e-03,  1.0000e+00,  4.9391e-05, -4.0411e-03],
         [ 3.3122e-03, -4.3601e-05,  1.0000e+00, -1.8621e-03]]]), 'VERTEX_SE3__1': tensor([[[ 0.5511, -0.2059, -0.8086,  0.3082],
         [-0.0522,  0.9587, -0.2797, -0.1478],
         [ 0.8328,  0.1964,  0.5175,  0.1600]]]), 'VERTEX_SE3__2': tensor([[[ 0.0159,  0.2135, -0.9768,  0.4148],
         [ 0.5701,  0.8006,  0.1843, -0.1026],
         [ 0.8214, -0.5598, -0.1089,  0.2680]]]), 'VERTEX_SE3__3': tensor([[[-0.3713,  0.7687, -0.5209,  0.5513],
         [ 0.4739,  0.6393,  0.6056, -0.1522],
         [ 0.7985, -0.0220, -0.6016,  0.7392]]]), 'VERTEX_SE3__4': tensor([[[-0.3223, -0.0770, -0.9435,  0.3214],
         [ 0.2711,  0.9474, -0.1699, -0.1065],
         [ 0.9070, -0.3105, -0.2845,  0.4638]]]), 'VERTEX_SE3__5': tensor([[[-0.9503, -0.1541, -0.2705,  0.2047],
         [-0.3088,  0.5778,  0.7555,  0.0421]

In [7]:
rot_error_after = []
trans_error_after = []
for idx in range(len(abs_pose_list)):  
    abs_pose_optimized = pose_graph._objective.get_optim_var(f"VERTEX_SE3__{idx}")
    rot_opt = abs_pose_optimized.rotation().tensor.squeeze()
    #trans_opt = abs_pose_optimized.translation().tensor
    trans_opt = abs_pose_optimized.to_x_y_z_quaternion()[..., :3]
    
    abs_pose_gt = abs_pose_gt_list[idx]
    rot_gt = abs_pose_gt.rotation().tensor.squeeze()
    #trans_gt = abs_pose_gt.translation().tensor
    trans_gt = abs_pose_gt.to_x_y_z_quaternion()[..., :3]

    rot_error_after.append(torch.acos((torch.trace(torch.matmul(rot_opt.t(), rot_gt)) - 1) / 2))
    #trans_error_after.append(torch.abs(trans_opt @ trans_gt.t().squeeze()))
    trans_error_after.append(torch.norm(trans_opt - trans_gt))

for idx in range(len(abs_pose_list)):  
    print(f"Rotation error of vertex_{idx} before optimization {rot_error_before[idx]}, after optimization: {rot_error_after[idx]}")
    print(f"Translation error of vertex_{idx} before optimization {trans_error_before[idx]}, after optimization: {trans_error_after[idx]}")
    print()

Rotation error of vertex_0 before optimization 0.0, after optimization: 0.0
Translation error of vertex_0 before optimization 0.0, after optimization: 0.0049241213127970695

Rotation error of vertex_1 before optimization nan, after optimization: nan
Translation error of vertex_1 before optimization 0.0007512992597185075, after optimization: 0.002209974220022559

Rotation error of vertex_2 before optimization nan, after optimization: 0.003201875602826476
Translation error of vertex_2 before optimization 0.0008897980442270637, after optimization: 0.0023524491116404533

Rotation error of vertex_3 before optimization 0.0009765625, after optimization: 0.0037822125013917685
Translation error of vertex_3 before optimization 0.0012935486156493425, after optimization: 0.0027890850324183702

Rotation error of vertex_4 before optimization 0.001691456069238484, after optimization: 0.004085256718099117
Translation error of vertex_4 before optimization 0.0021228499244898558, after optimization: 0.00

In [8]:
# test match_gaussian_means function
pose_0, pose_1 = abs_pose_list[0], abs_pose_list[1]
pose_01 = th.SE3.compose(pose_1.inverse(), pose_0).to_matrix().squeeze()
idx, num = GaussianSLAMPoseGraph.match_gaussian_means(point_list[0].tensor, point_list[1].tensor, pose_01)
num

100

In [9]:
print(pose_graph._objective.size_aux_vars())
print(pose_graph._objective.size_variables())

1227
13
