In [None]:
from airo_drake.prebuilt_diagrams import make_ur3e
from IPython.display import SVG, clear_output, display
import pydot

# Import some basic libraries and functions for this tutorial.
import numpy as np
import os
from pydrake.geometry import (
    MeshcatVisualizer,
    MeshcatVisualizerParams,
    Role,
    StartMeshcat,
)
from pydrake.math import RigidTransform, RollPitchYaw
from pydrake.multibody.meshcat import JointSliders
from pydrake.multibody.parsing import Parser
from pydrake.systems.analysis import Simulator
from pydrake.all import (
    InverseDynamicsController,
    PiecewisePose,
    Rgba,
    LeafSystem,
    JacobianWrtVariable,
    TrajectorySource,
    PassThrough,
    Demultiplexer,
    RevoluteJoint,
    StateInterpolatorWithDiscreteDerivative,
    Integrator,
    Meshcat,
    SpatialInertia
)
from pydrake.systems.framework import Diagram, DiagramBuilder, System
from pydrake.common import FindResourceOrThrow, temp_directory
from pydrake.multibody.plant import AddMultibodyPlantSceneGraph, MultibodyPlant
import time

In [None]:
# Start the Meshcat instance that the rest of the notebook draws into
# (trajectory line, frame axes, and the recorded simulation).
meshcat = StartMeshcat()
# Scratch directory that receives the model file written by the next cell.
temp_dir = temp_directory()

In [None]:
# Create a table top SDFormat model and save it into the scratch directory.
#
# The model is a single link with matching visual and collision boxes of size
# 1.2 x 1.0 x 0.02, centred on the link origin. Two named frames are attached
# at z = 0.01 (half the box thickness, i.e. on the upper face):
#   - "table_top_center": directly above the link origin
#   - "table_top_left":   shifted by -0.40 along x
table_top_sdf_file = os.path.join(temp_dir, "table_top.sdf")
table_top_sdf = """<?xml version="1.0"?>
<sdf version="1.7">
  <model name="table_top">
    <link name="table_top_link">
      <visual name="visual">
        <pose>0 0 0 0 0 0</pose>
        <geometry>
          <box>
            <size>1.2 1.0 0.02</size>
          </box>
        </geometry>
        <material>
         <diffuse>0.9 0.8 0.7 1.0</diffuse>
        </material>
      </visual>
      <collision name="collision">
        <pose>0 0 0 0 0 0</pose>
        <geometry>
          <box>
            <size>1.2 1.0 0.02</size>
          </box>
        </geometry>
      </collision>
    </link>
    <frame name="table_top_center">
      <pose relative_to="table_top_link">0 0 0.01 0 0 0</pose>
    </frame>
    <frame name="table_top_left">
      <pose relative_to="table_top_link">-0.40 0 0.01 0 0 0</pose>
    </frame>
  </model>
</sdf>

"""

# Persist the model text so it can be handed to the station builder by path.
with open(table_top_sdf_file, "w") as f:
    f.write(table_top_sdf)

In [None]:
class PseudoInverseController(LeafSystem):
    """Jacobian pseudo-inverse velocity controller for the "ur3e" model.

    Takes a commanded 6-element spatial velocity for the "ur_ee_link" frame
    and the current joint positions, and outputs the joint velocities given
    by the pseudo-inverse of the frame's spatial-velocity Jacobian.

    Ports:
        "V_WG" (input, size 6): commanded end-effector spatial velocity.
        "robot_position" (input, size 6): joint positions at which the
            Jacobian is evaluated.
        "robot_velocity" (output, size 6): resulting joint velocities.
    """

    def __init__(self, plant):
        """Cache plant handles and declare the system's ports.

        Args:
            plant: plant that contains a model instance named "ur3e", a frame
                named "ur_ee_link", and the joints "ur_shoulder_pan_joint"
                and "ur_wrist_3_joint".
        """
        LeafSystem.__init__(self)
        self._plant = plant
        # Private scratch context: positions are written into it on every
        # output evaluation, so it must not be shared with the simulation.
        self._plant_context = plant.CreateDefaultContext()
        self._robot = plant.GetModelInstanceByName("ur3e")
        self._G = plant.GetFrameByName("ur_ee_link")
        self._W = plant.world_frame()

        # Declaration order fixes the port indices; keep it stable.
        self.V_G_port = self.DeclareVectorInputPort("V_WG", 6)
        self.q_port = self.DeclareVectorInputPort("robot_position", 6)
        self.DeclareVectorOutputPort("robot_velocity", 6, self.CalcOutput)

        # Velocity indices of the first and last arm joints; they delimit the
        # Jacobian columns that belong to the arm.
        first_joint = plant.GetJointByName("ur_shoulder_pan_joint")
        last_joint = plant.GetJointByName("ur_wrist_3_joint")
        self.start_index = first_joint.velocity_start()
        self.end_index = last_joint.velocity_start()
        print(self.start_index, self.end_index)

    def CalcOutput(self, context, output):
        """Write the joint velocities that realise the commanded velocity.

        Args:
            context: this system's context, used to evaluate both inputs.
            output: output vector that receives the joint velocities.
        """
        V_desired = self.V_G_port.Eval(context)
        q_current = self.q_port.Eval(context)

        # Evaluate the Jacobian at the joint positions received on the input.
        self._plant.SetPositions(self._plant_context, self._robot, q_current)
        full_jacobian = self._plant.CalcJacobianSpatialVelocity(
            self._plant_context,
            JacobianWrtVariable.kV,
            self._G,
            [0, 0, 0],  # point of interest: origin of the end-effector frame
            self._W,
            self._W,
        )

        # Keep only the arm's columns.
        # NOTE(review): assumes the arm joints occupy a contiguous run of
        # velocity indices from start_index through end_index -- confirm.
        arm_columns = slice(self.start_index, self.end_index + 1)
        arm_jacobian = full_jacobian[:, arm_columns]

        output.SetFromVector(np.linalg.pinv(arm_jacobian) @ V_desired)

In [None]:
# Build the UR3e station, passing the table-top model file written earlier,
# and embed it as a subsystem of a fresh diagram.
ur3e_diagram = make_ur3e(additional_model_files=[table_top_sdf_file])

builder = DiagramBuilder()
station = builder.AddSystem(ur3e_diagram)
plant = station.GetSubsystemByName("plant")

# Start position of the end-effector path.
# NOTE(review): presumably the position of "ur_ee_link" at the start
# configuration set in the simulation section -- confirm if that changes.
p_Ginitial = np.array([0.131, -0.06307968978321705, 0.38314779900225227])
X_Ginitial = RigidTransform(p_Ginitial)

# Goal position, expressed as an offset from the start position.
# p_G = p_Ginitial + [0.1, -0.2, -0.2]
p_G = p_Ginitial + [0.2, -0.1, 0.0]
X_G = RigidTransform(p_G)

# Linear pose trajectory from the start pose (t = 0.0) to the goal (t = 2.0).
poses = [X_Ginitial, X_G]
sample_times = [0.0, 2.0]

traj_X_G = PiecewisePose.MakeLinear(sample_times, poses)
traj_p_G = traj_X_G.get_position_trajectory()
# The derivative of the pose trajectory is the velocity command fed to the
# controller.
traj_V_G = traj_X_G.MakeDerivative()

V_G_source = builder.AddSystem(TrajectorySource(traj_V_G))
# V_G_source.set_name("v_WG")
controller = builder.AddSystem(PseudoInverseController(plant))
controller.set_name("PseudoInverseController")
builder.Connect(V_G_source.get_output_port(), controller.GetInputPort("V_WG"))

# The controller outputs joint velocities while the station takes joint
# positions, so an integrator sits between them. The measured positions are
# fed back to the controller to close the loop.
n_dof = 6
integrator = builder.AddSystem(Integrator(n_dof))
integrator.set_name("integrator")
builder.Connect(controller.get_output_port(), integrator.get_input_port())
builder.Connect(integrator.get_output_port(), station.GetInputPort("robot_position"))
builder.Connect(station.GetOutputPort("robot_position_measured"), controller.GetInputPort("robot_position"))

# Clear anything already shown in Meshcat before attaching the visualizer.
meshcat.Delete()
visualizer = MeshcatVisualizer.AddToBuilder(builder, station.GetOutputPort("query_object"), meshcat)

# Draw the planned end-effector path.
# NOTE: p_G is rebound here -- from this line on it holds the position
# trajectory sampled at its segment times, no longer the goal position.
p_G = traj_p_G.vector_values(traj_p_G.get_segment_times())
meshcat.SetLine('p_G', p_G, 2.0, rgba=Rgba(1, 0.65, 0))
print(p_G.shape)

def draw_frame(meshcat: Meshcat, frame: RigidTransform, size: float = 0.05):
    """Draw the axes of ``frame`` in Meshcat as three coloured line segments.

    The x, y and z axes are drawn in red, green and blue respectively, each
    starting at the frame's origin. The Meshcat paths are suffixed with the
    current wall-clock time, so each call adds a new triad rather than
    replacing one drawn earlier.

    Args:
        meshcat: Meshcat instance to draw into.
        frame: pose whose origin and axis directions are visualised.
        size: length of each axis segment; the line thickness is 100 * size.
    """
    origin = frame.translation()
    # Rows of the transposed rotation matrix are the frame's axis directions.
    scaled_axes = size * frame.rotation().matrix().T
    stamp = time.time()
    thickness = 100 * size

    axis_colors = (Rgba(1, 0, 0), Rgba(0, 1, 0), Rgba(0, 0, 1))
    for label, axis, color in zip("XYZ", scaled_axes, axis_colors):
        segment = np.array([origin, origin + axis]).T
        meshcat.SetLine(f'{label}_{stamp}', segment, thickness, rgba=color)


# Finalize the diagram and create a simulator for it.
diagram = builder.Build()
diagram.set_name("pick_and_place")

simulator = Simulator(diagram)
context = simulator.get_mutable_context()
station_context = station.GetMyContextFromRoot(context)

# Put the arm into its start configuration inside the simulator's context.
plant_context = plant.GetMyMutableContextFromRoot(context)
ur3e_instance = plant.GetModelInstanceByName("ur3e")
q0 = np.array([-np.pi / 2, -3.0 / 4.0 * np.pi, np.pi / 2, -np.pi / 4, -np.pi / 2, 0])
plant.SetPositions(plant_context, ur3e_instance, q0)

# Show the end-effector frame at the start configuration.
eef_frame = plant.GetFrameByName("ur_ee_link").CalcPoseInWorld(plant_context)
print(eef_frame)
draw_frame(meshcat, eef_frame)

# Seed the integrator with the start configuration so that the position
# command begins at the arm's current pose. This writes state, so the
# integrator's *mutable* sub-context is requested.
integrator.set_integral_value(
    integrator.GetMyMutableContextFromRoot(context),
    plant.GetPositions(plant_context, ur3e_instance),
)


# Record the run over the full trajectory and publish it for playback.
visualizer.StartRecording(False)
simulator.AdvanceTo(traj_p_G.end_time())
visualizer.PublishRecording()

# Show the end-effector frame reached at the end of the run. Only a pose is
# read here, so the read-only context accessors are sufficient.
context = simulator.get_context()
plant_context = plant.GetMyContextFromRoot(context)
eef_frame = plant.GetFrameByName("ur_ee_link").CalcPoseInWorld(plant_context)
draw_frame(meshcat, eef_frame)

In [None]:
# TODO 

In [None]:
# Render the block diagram of the complete system (station, controller,
# integrator, trajectory source, visualizer) inline as SVG; max_depth=1
# limits how deeply nested sub-diagrams are expanded.
SVG(pydot.graph_from_dot_data(diagram.GetGraphvizString(max_depth=1))[0].create_svg())


In [None]:
# Render the block diagram of the UR3e station on its own, inline as SVG;
# max_depth=1 limits how deeply nested sub-diagrams are expanded.
SVG(pydot.graph_from_dot_data(ur3e_diagram.GetGraphvizString(max_depth=1))[0].create_svg())
