In [None]:
import batch_urdf
import trimesh
import torch
import numpy as np
import trimesh.transformations as tra
from typing import Dict, List
import pprint
import os
import shutil

In [None]:
# Load the full two-gripper Galbot Zero description with a batch size of one.
large_urdf = batch_urdf.URDF(
    urdf_path="./galbot_zero_description/galbot_zero_two_grippers.urdf",
    batch_size=1,
)

In [None]:
# Size of the original model: all joints, actuated joints, then links.
print(len(large_urdf.joint_map))
print(len(large_urdf.actuated_joints_map))
print(len(large_urdf.link_map))

In [None]:
# Inspect the scene attached to the left TCP frame and how many geometries it holds.
tcp_link_name = "left_gripper_inspire_tcp_frame"
print(large_urdf.get_link_scene(tcp_link_name))
print(len(large_urdf.get_link_scene(tcp_link_name).geometry))

In [None]:
class Link:
    """Mutable node of the kinematic tree that is being simplified.

    Attributes:
        name: Link name; also the hash key.
        parent_joint: Joint whose child is this link, or None until one is set.
        child_joint: Joints whose parent is this link (a list, despite the
            singular attribute name).
        scene: Geometry attached to the link, presumably in the link's own
            frame (collapse_edge places child scenes with parent^-1 @ child).
        transform: Homogeneous pose of the link, as sampled from the source
            model (presumably world/base frame -- TODO confirm).
    """

    def __init__(self, name: str, scene: "trimesh.Scene", transform: np.ndarray) -> None:
        self.name = name

        self.parent_joint = None  # Optional[Joint]; set via set_parent_joint
        self.child_joint: List["Joint"] = []

        self.scene = scene
        self.transform = transform

    def __hash__(self):
        # Hashed by name; there is no __eq__, so equality stays identity-based.
        return self.name.__hash__()

    def __repr__(self) -> str:
        # Show graph connectivity by name so pprint of the link table is readable.
        parent_name = None if self.parent_joint is None else self.parent_joint.name
        child_names = [joint.name for joint in self.child_joint]
        return f"Link(name={self.name!r}, parent_joint={parent_name!r}, child_joints={child_names!r})"

    def set_parent_joint(self, joint: "Joint"):
        """Record ``joint`` as the joint this link hangs from."""
        self.parent_joint = joint

    def add_child_joint(self, joint: "Joint"):
        """Append ``joint`` to the joints that hang from this link."""
        self.child_joint.append(joint)

class Joint:
    """Mutable edge of the kinematic tree that is being simplified.

    Attributes:
        name: Joint name; also the hash key.
        joint_type: URDF joint type string (e.g. "fixed", "revolute", "prismatic").
        parent: Parent Link, or None until set.
        child: Child Link, or None until set.
        axis: Joint axis, written verbatim to the URDF ``<axis>`` element.
        origin: Homogeneous transform of the joint frame relative to the parent
            link (its [:3, :3] / [:3, 3] parts are exported as rpy / xyz).
        limit: Limit object from batch_urdf, or None when the joint has none.
    """

    def __init__(self, name: str, joint_type: str, axis: np.ndarray, origin: np.ndarray, limit: "batch_urdf.Limit") -> None:
        self.name = name
        self.joint_type = joint_type

        self.parent = None  # Optional[Link]
        self.child = None  # Optional[Link]

        self.axis = axis
        self.origin = origin
        self.limit = limit

    def __hash__(self):
        # Hashed by name; there is no __eq__, so equality (and list.remove) stays identity-based.
        return self.name.__hash__()

    def __repr__(self) -> str:
        # Show type and endpoints by name so pprint of the joint table is readable.
        parent_name = None if self.parent is None else self.parent.name
        child_name = None if self.child is None else self.child.name
        return (
            f"Joint(name={self.name!r}, type={self.joint_type!r}, "
            f"parent={parent_name!r}, child={child_name!r})"
        )

    def set_parent_link(self, link: "Link"):
        """Record ``link`` as the link this joint hangs from."""
        self.parent = link

    def set_child_link(self, link: "Link"):
        """Record ``link`` as the link this joint moves."""
        self.child = link

def connect(parent: Link, child: Link, joint: Joint):
    """Wire ``joint`` between ``parent`` and ``child`` in both directions.

    The joint learns its two endpoint links, the child learns its parent joint,
    and the joint is appended to the parent's child-joint list. That list is
    only ever appended to, so a caller re-parenting a joint is responsible for
    detaching it from its previous parent.
    """
    # link -> joint back-references
    child.set_parent_joint(joint)
    parent.add_child_joint(joint)
    # joint -> link references
    joint.set_parent_link(parent)
    joint.set_child_link(child)

def merge_scene(s1: trimesh.Scene, s2: trimesh.Scene, s2_origin: np.ndarray):
    """Return a new scene holding ``s1`` plus ``s2`` placed at ``s2_origin``.

    Args:
        s1: Scene whose frame becomes the result's frame; may be empty.
        s2: Scene to add; may be empty. It is copied before being transformed,
            so the caller's scene object is not modified.
        s2_origin: Homogeneous transform applied to ``s2`` via
            ``Scene.apply_transform`` (pose of ``s2``'s frame in ``s1``'s frame).

    Returns:
        A new ``trimesh.Scene`` combining the geometry of both inputs.
    """
    base = s1 if len(s1.geometry) > 0 else trimesh.Scene()
    if len(s2.geometry) > 0:
        # Scene.apply_transform modifies the scene it is called on; work on a
        # copy so a shared/cached input scene is never transformed twice.
        placed = s2.copy()
        placed.apply_transform(s2_origin)
    else:
        placed = trimesh.Scene()
    return base + placed

def torch_to_numpy(t: torch.Tensor) -> np.ndarray:
    """Return an independent NumPy copy of ``t``, detached from autograd and moved to the CPU."""
    host_tensor = t.detach().cpu()
    # .numpy() shares memory with the CPU tensor; copy so callers can mutate freely.
    return host_tensor.numpy().copy()

def collapse_edge(joint_name: str, merged_joints: Dict[str, Joint], merged_links: Dict[str, Link]):
    """Remove ``joint_name`` by welding its child link rigidly onto its parent link.

    The child's geometry is merged into the parent's scene. The joint and the
    child link are removed from the tables. Every joint that hung off the child
    is re-attached to the parent, with its origin re-expressed relative to the
    parent link.

    The child is frozen at the relative pose implied by the two links'
    ``transform`` attributes. For a fixed joint this equals ``joint.origin``.
    For a prismatic joint it keeps whatever offset the sampled configuration had.

    Args:
        joint_name: Key of the joint to collapse.
        merged_joints: Joint table; modified in place.
        merged_links: Link table; modified in place.

    Raises:
        KeyError: If ``joint_name`` is not (or no longer) in ``merged_joints``.
    """
    joint = merged_joints[joint_name]
    parent = joint.parent
    child = joint.child

    # Pose of the child link frame in the parent link frame. Used for both the
    # geometry and the re-parented joint origins so the two stay consistent.
    child_in_parent = np.linalg.inv(parent.transform) @ child.transform

    # remove joint and child link from the graph
    parent.child_joint.remove(joint)
    merged_joints.pop(joint.name)
    merged_links.pop(child.name)

    # bake the child's geometry into the parent's scene
    parent.scene = merge_scene(parent.scene, child.scene, child_in_parent)

    for child_joint in child.child_joint:
        # re-attach the grandchild joint directly to the parent link
        connect(parent, child_joint.child, child_joint)

        # The grandchild joint's frame does not move in the world; only its
        # origin has to be re-expressed relative to the new parent link.
        # URDF joint axes are given in the joint frame itself, so ``axis`` is
        # intentionally left unchanged (rotating it here would be wrong).
        child_joint.origin = child_in_parent @ child_joint.origin
        

In [None]:
# Mirror batch_urdf's tree into mutable Link/Joint objects. The tensors carry a
# leading batch dimension; batch element 0 is used ([0, ...]).
merged_joints: Dict[str, Joint] = {}
merged_links: Dict[str, Link] = {
    link_name: Link(
        link_name,
        large_urdf.get_link_scene(link_name),
        torch_to_numpy(large_urdf.link_transform_map[link_name])[0, ...],
    )
    for link_name in large_urdf.link_map.keys()
}

for joint_name, urdf_joint in large_urdf.joint_map.items():
    simple_joint = Joint(
        joint_name,
        urdf_joint.type,
        torch_to_numpy(urdf_joint.axis)[0, ...],
        torch_to_numpy(urdf_joint.origin)[0, ...],
        urdf_joint.limit,
    )
    merged_joints[joint_name] = simple_joint
    connect(merged_links[urdf_joint.parent], merged_links[urdf_joint.child], simple_joint)

# Joint types that get welded away, and child links that must remain separate
# frames regardless (gripper TCP frames and camera frames, judging by name).
COLLAPSIBLE_JOINT_TYPES = ("fixed", "prismatic")
KEEP_LINK_NAMES = (
    "left_gripper_inspire_tcp_frame",
    "right_gripper_inspire_tcp_frame",
    "left_arm_camera_link",
    "right_arm_camera_link",
    "head_camera_normal_frame",
)

for joint_name, urdf_joint in large_urdf.joint_map.items():
    if urdf_joint.type not in COLLAPSIBLE_JOINT_TYPES:
        continue
    if large_urdf.link_map[urdf_joint.child].name in KEEP_LINK_NAMES:
        continue
    print(joint_name)
    collapse_edge(joint_name, merged_joints, merged_links)

In [None]:
# Joints that survived the collapse.
pprint.PrettyPrinter().pprint(merged_joints)

In [None]:
# Links that survived the collapse.
pprint.PrettyPrinter().pprint(merged_links)

In [None]:
# Destination of the simplified URDF; export_urdf writes per-link STL meshes
# into a "meshes" directory next to this file.
output_urdf_path = "./galbot_zero_description_simplified/urdf.urdf"

In [None]:
def export_urdf(output_urdf_path: str, merged_joints: Dict[str, Joint], merged_links: Dict[str, Link]):
    """Write the links/joints as a URDF file plus one STL mesh per link.

    Links with geometry get an ``<inertial>`` block computed from their mesh at
    a uniform density, and ``<visual>``/``<collision>`` entries referencing
    ``meshes/<link name>.stl``. Links without geometry are written as empty
    ``<link/>`` elements.

    Args:
        output_urdf_path: Path of the URDF file to write. Its directory and a
            ``meshes`` sub-directory are created if missing. Existing files are
            overwritten, so the export can be re-run.
        merged_joints: Joints to export, keyed by name.
        merged_links: Links to export, keyed by name.
    """
    space_str = ' '
    # Uniform density for every link mesh (1000, i.e. water, assuming SI mesh units -- TODO confirm).
    rho = 1e3

    def export_str_link(link: Link):
        if len(link.scene.geometry) > 0:
            # All of the link's geometry concatenated into one mesh, in the link frame.
            m: trimesh.Trimesh = link.scene.dump(True)
            # URDF <inertia> is expressed about the centre of mass given in
            # <inertial><origin>. Trimesh.moment_inertia is taken at center_mass,
            # axis-aligned with the mesh frame, at the mesh density (default 1),
            # so scaling by rho matches mass = volume * rho below.
            i = m.moment_inertia * rho
            return (
                f'\t<link name="{link.name}">\n' + 

                f'\t\t<inertial>\n' + 
                f'\t\t\t<origin rpy="0 0 0" xyz="{space_str.join(str(x) for x in m.center_mass)}"/>\n' + 
                f'\t\t\t<mass value="{m.volume * rho}"/>\n' + 
                f'\t\t\t<inertia ixx="{i[0, 0]}" ixy="{i[0, 1]}" ixz="{i[0, 2]}" iyy="{i[1, 1]}" iyz="{i[1, 2]}" izz="{i[2, 2]}"/>\n' + 
                f'\t\t</inertial>\n' + 

                f'\t\t<visual>\n' + 
                f'\t\t\t<origin rpy="0 0 0" xyz="0 0 0"/>\n' + 
                f'\t\t\t<geometry>\n' + 
                f'\t\t\t\t<mesh filename="meshes/{link.name}.stl"/>\n' + 
                f'\t\t\t</geometry>\n' + 
                f'\t\t</visual>\n' + 

                f'\t\t<collision>\n' + 
                f'\t\t\t<origin rpy="0 0 0" xyz="0 0 0"/>\n' + 
                f'\t\t\t<geometry>\n' + 
                f'\t\t\t\t<mesh filename="meshes/{link.name}.stl"/>\n' + 
                f'\t\t\t</geometry>\n' + 
                f'\t\t</collision>\n' + 
                f'\t</link>\n'
            )
        else:
            return f'\t<link name="{link.name}"/>\n'

    def export_str_joint(joint: Joint):
        if joint.limit is not None:
            limit_str = f'\t\t<limit effort="{joint.limit.effort}" lower="{joint.limit.lower}" upper="{joint.limit.upper}" velocity="{joint.limit.velocity}"/>\n'
        else:
            limit_str = ''
        # euler_from_matrix defaults to static xyz axes, which is URDF's rpy convention.
        return (
            f'\t<joint name="{joint.name}" type="{joint.joint_type}">\n' + 
            f'\t\t<origin rpy="{space_str.join(str(x) for x in tra.euler_from_matrix(joint.origin))}" xyz="{space_str.join(str(x) for x in joint.origin[:3, 3])}"/>\n' + 
            f'\t\t<parent link="{joint.parent.name}"/>\n' + 
            f'\t\t<child link="{joint.child.name}"/>\n' + 
            f'\t\t<axis xyz="{space_str.join(str(x) for x in joint.axis)}"/>\n' + 
            limit_str + 
            f'\t</joint>\n'
        )

    full_str = (
        '<?xml version="1.0" ?>\n' + 
        '<robot name="galbot_zero">\n' + 
        ''.join([export_str_link(l) for l in merged_links.values()]) + 
        ''.join([export_str_joint(j) for j in merged_joints.values()]) + 
        '</robot>\n'
    )

    output_dir = os.path.dirname(output_urdf_path)
    mesh_dir = os.path.join(output_dir, "meshes")

    print(output_urdf_path)
    # Creates output_dir as well; exist_ok lets the export be re-run, and this
    # also works when output_urdf_path has no directory component.
    os.makedirs(mesh_dir, exist_ok=True)

    with open(output_urdf_path, "w", encoding="utf-8") as f_obj:
        f_obj.write(full_str.replace("\t", "  "))
    for link in merged_links.values():
        if len(link.scene.geometry) > 0:
            link.scene.export(
                os.path.join(mesh_dir, link.name + ".stl")
            )

In [None]:
# Write the simplified URDF and its meshes to disk.
export_urdf(
    output_urdf_path=output_urdf_path,
    merged_joints=merged_joints,
    merged_links=merged_links,
)

In [None]:
# Reload the exported file with batch_urdf to check that it parses.
# NOTE(review): the positional 1 is presumably batch_size, matching the
# original load above -- TODO confirm against batch_urdf.URDF's signature.
simplified_urdf = batch_urdf.URDF(
    1,
    output_urdf_path,
)

In [None]:
# Render batch element 0 of the reloaded simplified model.
simplified_scene = simplified_urdf.get_scene(0)
simplified_scene.show()

In [None]:
# Independent copy of the joint configuration so the edits below don't alias simplified_urdf.cfg.
cfg = dict((joint_name, value.clone()) for joint_name, value in simplified_urdf.cfg.items())

In [None]:
# Drive one arm joint to 1.0 and re-render to check that the simplified model articulates.
target_joint = "left_arm_joint1"
cfg[target_joint][...] = 1.0
simplified_urdf.update_cfg(cfg)
posed_scene = simplified_urdf.get_scene(0)
posed_scene.show()