In [1]:
import numpy as np
import torch
import theseus as th
import torchlie.functional as lieF # use this instead of th.SE3
import matplotlib.pyplot as plt
from scipy.spatial import KDTree
from scipy.spatial.transform import Rotation
from typing import Union, List, Tuple, Optional, cast, Dict

def torch2np(tensor: torch.Tensor) -> np.ndarray:
    """ Converts a PyTorch tensor to a NumPy ndarray.
    Args:
        tensor: The PyTorch tensor to convert.
    Returns:
        A NumPy ndarray with the same data and dtype as the input tensor.
    """
    return tensor.detach().cpu().numpy()


def to_skew_symmetric(tensor: torch.Tensor):
    """
    Transform a (3, ) tensor to a (3, 3) tensor, or
    Transform a (num_pts, 3) tensor to a (num_pts, 3, 3) tensor, or
    Transform a (batch_size, num_pts, 3) tensor to a (batch_size, num_pts, 3, 3) tensor

    Args:
        tensor (torch.Tensor): 3d point cloud(s) that need(s) to be transformed to skew symmetric matrix

    Returns:
        skew_symmetric (torch.Tensor): transformed skew symmetric matrices
    """

    tensor_shape = tensor.shape
    if len(tensor_shape) > 3 or tensor_shape[-1] != 3:
        raise ValueError("Incorrect tensor dimension!")

    if len(tensor_shape) == 1:
        skew_symmetric = tensor.new_zeros((1, )+tensor_shape+(3, ))
    else:
        skew_symmetric = tensor.new_zeros(tensor_shape+(3, ))

    skew_symmetric[..., 0, 1] = -tensor[..., 2]
    skew_symmetric[..., 0, 2] = tensor[..., 1]
    skew_symmetric[..., 1, 0] = tensor[..., 2]
    skew_symmetric[..., 1, 2] = -tensor[..., 0]
    skew_symmetric[..., 2, 0] = -tensor[..., 1]
    skew_symmetric[..., 2, 1] = tensor[..., 0]

    return skew_symmetric




In [2]:
class GaussianSLAMEdge:
    def __init__(
        self,
        vertex_idx_i: int,
        vertex_idx_j: int,
        relative_pose: th.SE3,
        cost_weight: th.CostWeight
    ):
        self.vertex_idx_i = vertex_idx_i
        self.vertex_idx_j = vertex_idx_j
        self.relative_pose = relative_pose
        self.cost_weight = cost_weight


class GaussianSLAMPoseGraph:
    def __init__(
        self, 
        requires_auto_grad = True
    ):
        self._requires_auto_grad = requires_auto_grad
        self._objective = th.Objective()
        self._theseus_inputs = {} 

    def add_odometry_edge(
            self,
            vertex_i: th.SE3,
            vertex_j: th.SE3,
            edge: GaussianSLAMEdge,
            gaussian_means: torch.Tensor
        ):

        # (batch_size, num_pts, 3)
        gaussian_means_th = th.Variable(
            tensor=gaussian_means.unsqueeze(0), 
            name=f"gaussian_means_odometry__{edge.vertex_idx_i}_{edge.vertex_idx_j}"
        )

        if self._requires_auto_grad:
            cost_function = th.AutoDiffCostFunction(
                optim_vars=[vertex_i, vertex_j], 
                err_fn=GaussianSLAMPoseGraph.dense_surface_alignment, 
                dim=1, 
                cost_weight=edge.cost_weight, 
                aux_vars=[edge.relative_pose, gaussian_means_th]
            )
            self._objective.add(cost_function)
            self._theseus_inputs.update({
                vertex_i.name: vertex_i.tensor, 
                vertex_j.name: vertex_j.tensor
            })
        else:
            raise NotImplementedError()

    def add_loop_closure_edge(
            self,
            vertex_i: th.SE3,
            vertex_j: th.SE3,
            edge: GaussianSLAMEdge,
            gaussian_means: torch.tensor,
            match_num : int, # kapa
            tau: float=0.2, # fairly liberal distance threshold
        ):
        
        cost_weight_registration = edge.cost_weight # for dense surface alignment
        cost_weight_mu = cost_weight_registration.scale.tensor.squeeze() * np.sqrt(match_num) * tau
        print(f"cost_weight_mu = {cost_weight_mu}")
        cost_weight_line_process = th.ScaleCostWeight(cost_weight_mu) # for line process

        l_ij = th.Vector(tensor=torch.ones(1, 1), name=f"line_process_{edge.vertex_idx_i}_{edge.vertex_idx_j}")

        gaussian_means_th = th.Variable(tensor=gaussian_means.unsqueeze(0), name=f"gaussian_means_odometry__{edge.vertex_idx_i}_{edge.vertex_idx_j}")

        if self._requires_auto_grad:
            cost_function_registration = th.AutoDiffCostFunction(
                optim_vars=[vertex_i, vertex_j, l_ij], 
                err_fn=GaussianSLAMPoseGraph.dense_surface_alignment, 
                dim=1, 
                cost_weight=cost_weight_registration, 
                aux_vars=[edge.relative_pose, gaussian_means_th]
            )
            self._objective.add(cost_function_registration)

            cost_function_line_process = th.AutoDiffCostFunction(
                optim_vars=[l_ij,], 
                err_fn=GaussianSLAMPoseGraph.line_process, 
                dim=1, 
                cost_weight=cost_weight_line_process
            )
            self._objective.add(cost_function_line_process)
        
            self._theseus_inputs.update({
                vertex_i.name: vertex_i.tensor, 
                vertex_j.name: vertex_j.tensor,
                l_ij.name: l_ij.tensor
            })
        else:
            raise NotImplementedError()

    def _remove_loop_outlier(self, threshold: float, substring="line_process"):
        """ 
        find all l_ij inside the objective, remove those whose value are smaller than threshold 
        and all cost functions that are connected to them
        (currently implemented with the help of a dictionary called self._theseus_inputs, can it be
        directly done with self._objective?)
        """
        for key in self._theseus_inputs.keys():
            if substring in key and self._objective.get_optim_var(key).tensor < threshold:
                del self._theseus_inputs[key]
                for cost_func in self._objective.get_functions_connected_to_optim_var(key):
                    self._objective.erase(cost_func.name)

    def _optimize(self, 
            max_iterations=1e3, 
            step_size=0.01, 
            damping=0.1,
            track_best_solution=True, 
            verbose=False
        ):
        optimizer = th.LevenbergMarquardt(
            objective=self._objective,
            max_iterations=max_iterations,
            step_size=step_size
        )
        layer = th.TheseusLayer(optimizer)

        with torch.no_grad():
            _, info = layer.forward(
                self._theseus_inputs, 
                optimizer_kwargs={"damping":damping, "track_best_solution":track_best_solution, "verbose":verbose}
                )
        return info

    def optimize_two_steps(
            self, 
            max_iterations=1e3, 
            step_size=0.01, 
            l_ij_threshold=0.25,
            damping=0.1,
            track_best_solution=True, 
            verbose=False
        ):
        """
        optimization in two steps: 
        1. optimize with initial guess of all optim variables (T_i, l_ij)
        2. remove all l_ij < threshold, and all cost functions that are connected to them,
           optimize again with optimized variables
        """

        print(f"First step optimization, dealing with {self._objective.size_cost_functions()} cost functions")
        info = self._optimize(max_iterations, step_size, damping, track_best_solution, verbose)
        self._remove_loop_outlier(threshold=l_ij_threshold)
        print(f"Second step optimization, dealing with {self._objective.size_cost_functions()} cost functions")
        # TODO: if all loops are true, no need to do second step optimization
        info = self._optimize(max_iterations, step_size, damping, track_best_solution,verbose)
        return info

    @ staticmethod
    def match_gaussian_means(
        pts_1: torch.tensor,
        pts_2: torch.tensor,
        transformation: torch.tensor,
        epsilon:float=5e-2
    ) -> List[Tuple[int, int]]:
        """
        Select inlier correspondences from two Gaussian clouds, use kd-tree to speed up
    
        Args:
            pts_1, pts_2: mean positions of 3D Gaussians
            transformation: prior transformation matrix from one Gaussian cloud to the other
            epsilon: threshold for finding inlier correspondence
    
        Returns:
            a list contains tuples of matching indices
        """
        if transformation.size() != torch.Size([4, 4]):
            raise ValueError(f"The size of input transformation matrix must be (4, 4), but get {transformation.size()}")
    
        if pts_1.size(-1) != 1:
            pts_1 = pts_1.unsqueeze(-1)
    
        if isinstance(pts_1, th.Point3) or isinstance(pts_2, th.Point3):
                raise TypeError("To be matched points must be torch.Tensor")
    
        rotation = transformation[:3, :3]
        translation = transformation[:3, 3]
        pts_1_new = (rotation @ pts_1).squeeze() + translation
    
        pts_1_numpy = torch2np(pts_1_new)
        pts_2_numpy = torch2np(pts_2)
        pts2_kdtree = KDTree(pts_2_numpy)
    
        _, query_idx = pts2_kdtree.query(pts_1_numpy, distance_upper_bound=epsilon, workers=-1)
    
        data_size = pts_1.size()[0]
        res_list = []
        for i in range(data_size):
            if query_idx[i] != data_size:
                res_list.append((i, query_idx[i]))
    
        return res_list, len(res_list)

    @ staticmethod
    def dense_surface_alignment(
        optim_vars: Union[Tuple[th.SE3, th.SE3], Tuple[th.SE3, th.SE3, th.Vector]],
        aux_vars: Tuple[th.SE3, th.Variable]
    ) -> torch.Tensor:
        """
        Compute the dense surface alignment error between two vertices, can be used as the error
        function input to instantiate a th.CostFunction variable

        Args:
            optim_vars: optimizaiton variables registered in cost function, should contain
                pose_i, pose_j: correction matrix for pose i, j
                l_ij (optional): line process coefficient

            aux_vars: auxiliary variables registered in cost function, should contain
                relative_pose: constraint between vertex_i and vertex_j
                gaussian_means_i: mean positions of the 3D Gaussians inside camera frustum, 
                    represented in coordinate i and coordinate (those in coordinate j are not needed),
                    shape = (batch_size, num_pts, dim)

        Returns:
            square root of global place recognition error
        """
        # determine whether the edge is odometry edge or loop closure edge
        tuple_size = len(optim_vars)
        if tuple_size == 2:
            pose_i, pose_j = optim_vars
        elif tuple_size == 3:
            pose_i, pose_j, l_ij = optim_vars
        else:
            raise ValueError(f"optim_vars tuple size is {tuple_size}, which can only be 2 or 3.")
        pose_ij_measurement, gaussian_means = aux_vars

        pose_ij_odometry : th.SE3 = pose_j.inverse().compose(pose_i) # (batch_size, 3, 4)
        pose_residual : th.SE3 = pose_ij_measurement.inverse().compose(pose_ij_odometry) # (batch_size, 3, 4)

        rot_residual = pose_residual.rotation().log_map().unsqueeze(1) # (batch_size, 1, 3)
        trans_residual = pose_residual.translation().tensor.unsqueeze(1) # (batch_size, 1, 3)
        xi = torch.cat((rot_residual, trans_residual), dim=-1) # (batch_size, 1, 6)
        
        p_skew_symmetric = to_skew_symmetric(gaussian_means.tensor) # (batch_size, num_pts, 3, 3)
        # tensor.expand() will not allocate new memory, modification on one sample will change values for all,
        # use tensor.repeat() instead
        G_p = torch.cat(( # (batch_size, num, 3, 6)
            -p_skew_symmetric, 
            #torch.eye(3).repeat(gaussian_means.shape[0], gaussian_means.shape[-2], 1, 1)
            torch.tile(torch.eye(3), (gaussian_means.shape[0], gaussian_means.shape[-2], 1, 1)),
            ), dim=-1)
        Lambda = torch.sum(G_p.transpose(-2, -1) @ G_p, axis=1) # (batch_size, 6, 6)
        res = (xi @ Lambda @ xi.transpose(-2, -1)).squeeze(1) # (batch_size, 1)
        
        if tuple_size == 3:
            return l_ij.tensor.sqrt() * res.sqrt()
        else:
            return res.sqrt()
        
    @ staticmethod
    def line_process(optim_vars: th.Vector, aux_vars=None) -> torch.Tensor:
        """
        Computes the line process error of a loop closrue edge, can be used as the error
        input to instantiate a th.CostFunction variable

        Args:
            optim_vars:
                l_ij: jointly optimized weight (l_ij ∈ [0, 1]) over the loop edges
                (note that the scaling factor mu is considered as cost_weight)

        Returns:
            square root of line process error
        """
        l_ij, = optim_vars
        return l_ij.tensor.sqrt() - 1

In [3]:
def create_data(
        num_pts: int = 1000, 
        num_poses: int = 10, 
        translation_noise: float = 0.05, 
        rotation_noise: float = 0.1, 
        weight = 1.0,
        batch_size: int = 1,
        #dtype = torch.float32 # will get error if changed to torch.float64, don't know why
        ) -> Tuple[List[th.Point3], List[th.SE3], List[th.SE3], List[GaussianSLAMEdge]]:
    """
    create point clouds represented in different coordinates, record their ground truth 
    absolute pose, noisy absolute pose, also return an empty list to put loop edges

    Returns:
        point_list: a list stores points clouds, represented in different coordinates
        abs_pose_list_gt: a list stores ground truth absolute poses
        abs_pose_list: a list stores noisy (odometry) absolute poses
        edge_list: a list stores custum GaussianSLAMEdge
        TODO: Do I need to put the first edge that connets vertex_0 and vertex_1 into the list?
    """

    points_0 = th.Point3(2*torch.rand(num_pts, 3)-1, name="POINT_CLOUD__0") # initial points in world frame
    point_list = [points_0] # represented in different frames
    abs_pose_list_gt = [] # frame i to world frame
    abs_pose_list = [] # frame i to world frame (noisy)
    edge_list = []

    abs_pose_list_gt.append(th.SE3(
        tensor=torch.tile(torch.eye(3, 4), [1, 1, 1]),
        name="VERTEX_SE3_GT__0"
        ))
    
    abs_pose_list.append(th.SE3(
        tensor=torch.tile(torch.eye(3, 4), [1, 1, 1]),
        name="VERTEX_SE3__0"
        ))

    for idx in range(1, num_poses):

        # ground truth relative pose from frame_{idx-1} to frame_{idx}
        relative_pose_gt = th.SE3.exp_map(
            torch.cat([torch.rand(batch_size, 3)-0.5, 2.0 * torch.rand(batch_size, 3)-1], dim=1),
        )

        # generate points represented in frame_{idx}
        points = relative_pose_gt.transform_from(point_list[-1])
        points.name = f"POINT_CLOUD__{idx}"
        point_list.append(points)

        # add noise to get odometry relative pose from frame_{idx-1} to frame_{idx}
        relative_pose_noise = th.SE3.exp_map(
            torch.cat([
                translation_noise * (2.0 * torch.rand(batch_size, 3) - 1),
                rotation_noise * (2.0 * torch.rand(batch_size, 3) - 1),
            ] ,dim=1),
        )

        relative_pose = cast(th.SE3, relative_pose_noise.compose(relative_pose_gt))
        relative_pose.name = f"EDGE_SE3__{idx-1}_{idx}"
        cost_weight = th.ScaleCostWeight(weight, name=f"EDGE_WEIGHT__{idx-1}_{idx}")

        # absolute pose of frame_{idx}
        absolute_pose_gt = cast(th.SE3, abs_pose_list_gt[-1].compose(relative_pose_gt.inverse()))
        absolute_pose_gt.name = f"VERTEX_SE3_GT__{idx}"

        absolute_pose = cast(th.SE3, abs_pose_list[-1].compose(relative_pose.inverse()))
        absolute_pose.name = f"VERTEX_SE3__{idx}"

        abs_pose_list_gt.append(absolute_pose_gt)
        abs_pose_list.append(absolute_pose)

        # construct odometry edge between vertex_{idx-1} and vertex_{idx}
        edge_list.append(GaussianSLAMEdge(idx-1, idx, relative_pose, cost_weight))

    return point_list, abs_pose_list_gt, abs_pose_list, edge_list


def add_loop_data(
        i: int, 
        j: int, 
        abs_pose_list_gt: List[th.SE3], 
        edge_list: List[GaussianSLAMEdge],
        weight: float = 2.0,
        measurement_noise:float = 0.001,
        batch_size: int = 1,
        ) -> None:
    """
    Add loop closure between two arbitray coordinates i and j (i < j), and stores generated edge
    """

    if i >= j:
        raise ValueError(f"The first frame index {i} is greater than the second frame index {j}!")

    abs_pose_i_gt = abs_pose_list_gt[i]
    abs_pose_j_gt = abs_pose_list_gt[j]
    rel_pose_ij_gt = th.SE3.compose(abs_pose_j_gt.inverse(), abs_pose_i_gt)
    rel_pose_ij_gt.name = f"EDGE_SE3_GT__{i}_{j}"

    relative_pose_noise = th.SE3.exp_map(
            torch.cat([
                measurement_noise * (2.0 * torch.rand(batch_size, 3) - 1),
                measurement_noise * (2.0 * torch.rand(batch_size, 3) - 1),
            ] ,dim=1),
            )
    rel_pose_ij = cast(th.SE3, relative_pose_noise.compose(rel_pose_ij_gt))
    rel_pose_ij.name = f"EDGE_SE3__{i}_{j}"

    cost_weight = th.ScaleCostWeight(weight, name=f"EDGE_WEIGHT__{i}_{j}")
    edge = GaussianSLAMEdge(i, j, rel_pose_ij, cost_weight)
    edge_list.append(edge)

In [4]:
point_list, abs_pose_gt_list, abs_pose_list, edge_list = create_data()
add_loop_data(0, 7, abs_pose_gt_list, edge_list)
add_loop_data(1, 8, abs_pose_gt_list, edge_list)
add_loop_data(2, 9, abs_pose_gt_list, edge_list)

rot_error_before = []
trans_error_before = []
for idx in range(len(abs_pose_list)):
    abs_pose = abs_pose_list[idx]
    rot = abs_pose.rotation().tensor.squeeze()
    trans = abs_pose.translation().tensor
    
    abs_pose_gt = abs_pose_gt_list[idx]
    rot_gt = abs_pose_gt.rotation().tensor.squeeze()
    trans_gt = abs_pose.translation().tensor
    
    rot_error_before.append(torch.acos((torch.trace(torch.matmul(rot.t(), rot_gt)) - 1) / 2))
    trans_error_before.append(torch.abs(trans @ trans_gt.t().squeeze()))


print("Constructing a pose graph for Gaussian Splatting SLAM.")
pose_graph = GaussianSLAMPoseGraph(requires_auto_grad=True)

for idx in range(len(edge_list)):
    edge = edge_list[idx]
    vertex_idx_i = edge.vertex_idx_i
    vertex_idx_j = edge.vertex_idx_j
    
    vertex_i = abs_pose_list[vertex_idx_i]
    vertex_j = abs_pose_list[vertex_idx_j]

    if vertex_idx_j - vertex_idx_i == 1:
        print(f"adding edge {idx} to pose graph, current edge is an odometry edge.")
        pose_graph.add_odometry_edge(vertex_i, vertex_j, edge, point_list[idx].tensor)
    else:
        print(f"adding edge {idx} to pose graph, current edge is an loop edge.")
        inlier_idx, num_matches = GaussianSLAMPoseGraph.match_gaussian_means(
            point_list[vertex_idx_i].tensor, point_list[vertex_idx_j].tensor, edge.relative_pose.to_matrix().squeeze(), epsilon=5e-2)
        inlier_idx_i = [idx_inlier[0] for idx_inlier in inlier_idx]
        pose_graph.add_loop_closure_edge(vertex_i, vertex_j, edge, point_list[vertex_idx_i].tensor[inlier_idx_i, :], num_matches, tau=0.2)


Constructing a pose graph for Gaussian Splatting SLAM.
adding edge 0 to pose graph, current edge is an odometry edge.
adding edge 1 to pose graph, current edge is an odometry edge.
adding edge 2 to pose graph, current edge is an odometry edge.
adding edge 3 to pose graph, current edge is an odometry edge.
adding edge 4 to pose graph, current edge is an odometry edge.
adding edge 5 to pose graph, current edge is an odometry edge.
adding edge 6 to pose graph, current edge is an odometry edge.
adding edge 7 to pose graph, current edge is an odometry edge.
adding edge 8 to pose graph, current edge is an odometry edge.
adding edge 9 to pose graph, current edge is an loop edge.
cost_weight_mu = 12.649110794067383
adding edge 10 to pose graph, current edge is an loop edge.
cost_weight_mu = 12.649110794067383
adding edge 11 to pose graph, current edge is an loop edge.
cost_weight_mu = 12.649110794067383


In [5]:
info = pose_graph.optimize_two_steps(max_iterations=1e3, step_size=0.02, l_ij_threshold=0.25, damping=0.1, verbose=False)
print(info)

First step optimization, dealing with 15 cost functions
Second step optimization, dealing with 15 cost functions
NonlinearOptimizerInfo(best_solution={'VERTEX_SE3__0': tensor([[[ 0.9967,  0.0627, -0.0524, -0.0174],
         [-0.0641,  0.9976, -0.0255, -0.0143],
         [ 0.0507,  0.0288,  0.9983, -0.0517]]]), 'VERTEX_SE3__1': tensor([[[ 0.7357, -0.5913,  0.3304,  0.0890],
         [ 0.1662,  0.6305,  0.7582,  0.2937],
         [-0.6566, -0.5029,  0.5621, -0.4849]]]), 'VERTEX_SE3__2': tensor([[[ 0.7741, -0.3741,  0.5106,  0.1566],
         [-0.0653,  0.7552,  0.6523,  0.6237],
         [-0.6296, -0.5383,  0.5602, -0.6817]]]), 'VERTEX_SE3__3': tensor([[[-0.1135, -0.4142,  0.9031,  0.0714],
         [ 0.4462,  0.7909,  0.4188,  0.6758],
         [-0.8877,  0.4505,  0.0951, -0.5051]]]), 'VERTEX_SE3__4': tensor([[[ 0.5421, -0.1870,  0.8192, -0.0935],
         [ 0.8339,  0.2397, -0.4971,  0.8443],
         [-0.1034,  0.9527,  0.2859, -0.0371]]]), 'VERTEX_SE3__5': tensor([[[ 0.1213, -0.9631,

In [6]:
rot_error_after = []
trans_error_after = []
for idx in range(len(abs_pose_list)):  
    abs_pose_optimized = pose_graph._objective.get_optim_var(f"VERTEX_SE3__{idx}")
    rot_opt = abs_pose_optimized.rotation().tensor.squeeze()
    trans_opt = abs_pose_optimized.translation().tensor
    
    abs_pose_gt = abs_pose_gt_list[idx]
    rot_gt = abs_pose_gt.rotation().tensor.squeeze()
    trans_gt = abs_pose.translation().tensor

    rot_error_after.append(torch.acos((torch.trace(torch.matmul(rot_opt.t(), rot_gt)) - 1) / 2))
    trans_error_after.append(torch.abs(trans_opt @ trans_gt.t().squeeze()))

for idx in range(len(abs_pose_list)):  
    print(f"Rotation error of vertex_{idx} before optimization {rot_error_before[idx]}, after optimization: {rot_error_after[idx]}")
    print(f"Translation error of vertex_{idx} before optimization {trans_error_before[idx]}, after optimization: {trans_error_after[idx]}")
    print()

Rotation error of vertex_0 before optimization 0.0, after optimization: 0.08538861572742462
Translation error of vertex_0 before optimization tensor([0.]), after optimization: tensor([0.0559])

Rotation error of vertex_1 before optimization 0.10006792098283768, after optimization: 0.13244080543518066
Translation error of vertex_1 before optimization tensor([0.3377]), after optimization: tensor([0.0175])

Rotation error of vertex_2 before optimization 0.07194383442401886, after optimization: 0.10689190775156021
Translation error of vertex_2 before optimization tensor([0.9590]), after optimization: tensor([0.2433])

Rotation error of vertex_3 before optimization 0.13692331314086914, after optimization: 0.1700340062379837
Translation error of vertex_3 before optimization tensor([0.7582]), after optimization: tensor([0.4397])

Rotation error of vertex_4 before optimization 0.16074852645397186, after optimization: 0.18198218941688538
Translation error of vertex_4 before optimization tensor(

In [7]:
torch.eye(3).repeat(1, 2, 1, 1).shape

torch.Size([1, 2, 3, 3])

In [8]:
torch.tile(torch.eye(3), (1, 2, 1, 1)).shape

torch.Size([1, 2, 3, 3])

In [9]:
print(pose_graph._objective.size_aux_vars())
print(pose_graph._objective.size_variables())

39
13
