In [84]:
import pydot
import numpy as np
from IPython.display import display, Javascript, SVG
from pydrake.examples.manipulation_station import ManipulationStation
from manipulation.scenarios import (
    AddIiwa, AddShape
)
from manipulation.meshcat_cpp_utils import (
    StartMeshcat, AddMeshcatTriad
)
from pydrake.all import (
    AddMultibodyPlantSceneGraph, DiagramBuilder, MeshcatVisualizerCpp, MeshcatVisualizerParams, Parser, 
    RollPitchYaw, RigidTransform, RevoluteJoint, Sphere, Simulator, InverseDynamicsController, MultibodyPlant,
    RotationMatrix, Rgba
)
import pydrake.all

In [2]:
# Start the Meshcat server. The visualizer and drawing helpers defined below
# read this module-level `meshcat` object, so this cell must run before them.
meshcat = StartMeshcat()

In [85]:
def AddFloatingIiwa(plant, collision_model="no_collision", q0=None):
    """Parse the local iiwa7 SDF into `plant` and set its default joint angles.

    Args:
        plant: the MultibodyPlant to add the model to.
        collision_model: suffix selecting the SDF variant, i.e. the file
            ``iiwa7_{collision_model}.sdf`` in the local iiwa7 model directory
            (a relative path, resolved against the working directory).
        q0: default angles for the model's revolute joints, in the order
            given by ``plant.GetJointIndices``. Defaults to
            ``[0.0, 0.1, 0, -1.2, 0, 1.6, 0]``.

    Returns:
        The model instance returned by ``Parser.AddModelFromFile``.

    Raises:
        ValueError: if the number of revolute joints in the model does not
            match ``len(q0)``.
    """
    if q0 is None:
        q0 = [0.0, 0.1, 0, -1.2, 0, 1.6, 0]

    sdf_path = ("iiwa_rock_climbing/models/iiwa_description/iiwa7/"
                f"iiwa7_{collision_model}.sdf")
    iiwa = Parser(plant).AddModelFromFile(sdf_path)

    # Only revolute joints receive a default angle; other joint types are skipped.
    joints = (plant.get_mutable_joint(joint_index)
              for joint_index in plant.GetJointIndices(iiwa))
    revolute_joints = [joint for joint in joints
                       if isinstance(joint, RevoluteJoint)]
    if len(revolute_joints) != len(q0):
        raise ValueError(
            f"Model has {len(revolute_joints)} revolute joints but "
            f"{len(q0)} default angles were given.")

    for joint, angle in zip(revolute_joints, q0):
        joint.set_default_angle(angle)

    return iiwa

def AddFloatingBase(plant, iiwa):
    """Add a 0.2 m box body named "mobile_base" and weld the iiwa to it.

    The arm's ``iiwa_link_0`` frame is welded 0.1 m above the base body's
    frame with no rotation. This is presumably the top face of the box,
    assuming AddShape puts the body frame at the box center (confirm).

    Args:
        plant: the MultibodyPlant being built.
        iiwa: model instance of the arm. It scopes the ``iiwa_link_0`` lookup
            so the weld targets this arm even if another model in the plant
            has a frame with the same name.

    Returns:
        Whatever ``AddShape`` returns for the base (presumably its model
        instance; confirm against manipulation.scenarios).
    """
    base = AddShape(plant, pydrake.geometry.Box(0.2, 0.2, 0.2), "mobile_base",
                    mass=0.1, mu=1, color=[.5, .5, .9, 1.0])
    plant.WeldFrames(plant.GetFrameByName("mobile_base"),
                     plant.GetFrameByName("iiwa_link_0", iiwa),
                     RigidTransform([0, 0, 0.1]))
    return base

In [103]:
def AddMeshcatSphere(meshcat,
                     path,
                     radius=0.01,
                     opacity=1.,
                     p_WP=None,
                     rgba=None):
    """Draw a sphere in Meshcat centered at a given world position.

    Args:
        meshcat: the Meshcat instance to draw into.
        path: Meshcat scene-tree path for the sphere. Meshcat replaces any
            object already at this path.
        radius: sphere radius.
        opacity: alpha of the default red color. It is ignored when `rgba`
            is given.
        p_WP: position of the sphere center P, expressed in world frame W.
            Defaults to the origin.
        rgba: optional Rgba color. Defaults to ``Rgba(1, 0, 0, opacity)``.
    """
    # None sentinel instead of a shared mutable numpy array as the default.
    if p_WP is None:
        p_WP = np.zeros(3)
    if rgba is None:
        rgba = Rgba(1, 0, 0, opacity)

    meshcat.SetTransform(path, RigidTransform(RotationMatrix(), p_WP))
    meshcat.SetObject(path, Sphere(radius), rgba)

In [104]:
class IIWA_Climber:
    """An iiwa7 welded to a floating box base, driven by a joint controller.

    The constructor builds a Drake diagram containing:
      * a MultibodyPlant (1 ms time step) holding the iiwa welded to a
        "mobile_base" box, with the gravity vector set to [0, 0, -0.1];
      * a Meshcat visualizer;
      * an InverseDynamicsController whose internal model is a separate plant
        built by ``AddIiwa``. NOTE(review): that model presumably does not
        include the floating base, so the controller's dynamics differ from
        the simulated plant; confirm this is acceptable.
    The controller's desired state is fixed per context in
    ``CreateDefaultContext``.
    """

    def __init__(self, meshcat_instance=None):
        """Build the diagram and publish the default configuration.

        Args:
            meshcat_instance: Meshcat used for visualization. Defaults to the
                notebook-level ``meshcat`` global.
        """
        self.meshcat = meshcat_instance if meshcat_instance is not None else meshcat
        builder = DiagramBuilder()

        # Simulated plant ------------------------------------------
        self.plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.001)
        self.iiwa = AddFloatingIiwa(self.plant)
        self.base = AddFloatingBase(self.plant, self.iiwa)
        # NOTE(review): gravity is reduced to 0.1 m/s^2 (vs. ~9.81);
        # confirm this is intentional rather than a debugging leftover.
        self.plant.gravity_field().set_gravity_vector([0., 0., -0.1])
        self.plant.Finalize()

        MeshcatVisualizerCpp.AddToBuilder(
            builder,
            scene_graph,
            self.meshcat,
            MeshcatVisualizerParams(delete_prefix_initialization_event=False))

        # Controller ----------------------------------------------
        controller_plant = MultibodyPlant(time_step=1e-2)
        AddIiwa(controller_plant)
        controller_plant.Finalize()

        Kp = np.full(7, 10)
        # NOTE(review): 2*sqrt(Kp) is the usual critical-damping choice for
        # Kd, not Ki; confirm these gains are intended.
        Ki = 2 * np.sqrt(Kp)
        Kd = np.full(7, 1)
        self.iiwa_controller = builder.AddSystem(
            InverseDynamicsController(controller_plant, Kp, Ki, Kd, False))
        self.iiwa_controller.set_name("iiwa_controller")
        builder.Connect(self.plant.get_state_output_port(self.iiwa),
                        self.iiwa_controller.get_input_port_estimated_state())
        builder.Connect(self.iiwa_controller.get_output_port_control(),
                        self.plant.get_actuation_input_port())

        # Build ----------------------------------------------------
        self.diagram = builder.Build()
        self.gripper_frame = self.plant.GetFrameByName("iiwa_link_7")
        self.world_frame = self.plant.world_frame()

        # Push the initial configuration to the visualizer.
        self.diagram.Publish(self.CreateDefaultContext())

    def CreateDefaultContext(self):
        """Create a root diagram context with the arm at its start pose.

        Sets the iiwa joint positions to ``q0``. Fixes the controller's
        ``desired_state`` input to ``[q0, 0]``, so the controller tracks that
        pose at rest.

        Returns:
            The new root diagram context.
        """
        context = self.diagram.CreateDefaultContext()

        q0 = np.array([-1.57, 0.1, 0, -1.2, 0, 1.6, 0])
        x0 = np.hstack((q0, 0 * q0))  # desired state: positions then zero velocities

        plant_context = self.plant.GetMyMutableContextFromRoot(context)
        self.plant.SetPositions(plant_context, self.iiwa, q0)
        self.iiwa_controller.GetInputPort('desired_state').FixValue(
            self.iiwa_controller.GetMyMutableContextFromRoot(context), x0)

        return context

    def visualize_frame(self, name, X_WF, length=0.15, radius=0.006):
        """Draw a triad for a frame that is not attached to any plant body.

        Args:
            name: frame name (str). It is drawn at Meshcat path
                ``"climber/" + name``.
            X_WF: RigidTransform giving the pose of frame F in world frame W.
            length: triad axis length, passed to AddMeshcatTriad.
            radius: triad axis radius, passed to AddMeshcatTriad.

        Drawing a name that already exists overwrites the previous triad.
        """
        AddMeshcatTriad(self.meshcat, "climber/" + name,
                        length=length, radius=radius, X_PT=X_WF)

    def visualize_com(self, name, p_WP, radius=0.06):
        """Draw a sphere marking a point such as a center of mass.

        Args:
            name: marker name (str). It is drawn at Meshcat path
                ``"climber/" + name``.
            p_WP: position of point P, expressed in world frame W.
            radius: sphere radius.

        Drawing a name that already exists overwrites the previous marker.
        """
        AddMeshcatSphere(self.meshcat, "climber/" + name,
                         radius=radius, p_WP=p_WP)

    def get_X_WG(self, context=None):
        """Return the world pose X_WG of the gripper frame (``iiwa_link_7``).

        Args:
            context: root diagram context to evaluate. If None, a fresh
                context from ``CreateDefaultContext`` is used.
        """
        if context is None:
            context = self.CreateDefaultContext()
        plant_context = self.plant.GetMyContextFromRoot(context)
        return self.plant.CalcRelativeTransform(
            plant_context,
            frame_A=self.world_frame,
            frame_B=self.gripper_frame)

    def run(self, duration=2.0, realtime_rate=0.25):
        """Simulate the diagram starting from ``CreateDefaultContext()``.

        Args:
            duration: simulation time to advance to, in seconds.
            realtime_rate: target real-time rate passed to the Simulator.
        """
        simulator = Simulator(self.diagram, self.CreateDefaultContext())
        simulator.set_target_realtime_rate(realtime_rate)
        simulator.AdvanceTo(duration)

In [105]:
# Pose of the world origin, drawn later as the "world_center" triad.
R0 = RotationMatrix(np.eye(3))  # identity rotation
p0 = [0, 0, 0]
X_WorldCenter = RigidTransform(R0, p0)

# Constructing the climber also publishes its initial pose to Meshcat.
climber = IIWA_Climber()

In [106]:
X_WG = climber.get_X_WG()
# Triads for the current gripper pose and the world origin, then a test marker sphere.
for frame_name, X_WF in (('gripper_current', X_WG),
                         ('world_center', X_WorldCenter)):
    climber.visualize_frame(frame_name, X_WF)
climber.visualize_com('com_test', np.array([1, 1, 1]))

In [98]:
# Simulate the climber for 2 s at 0.25x real time, starting from its default context.
climber.run()

In [41]:
# AddMeshcatTriad(meshcat, "climber/" + "iiwa_link_0",
#                 length=2, radius=0.1, X_PT=X_WF)

In [10]:
# X_WF


# AddMeshcatTriad(meshcat, "painter/" + name,
#                 length=length, radius=radius, X_PT=X_WF)

In [42]:
# simulator = Simulator(diagram, context)
# simulator.set_target_realtime_rate(0.25)
# simulator.AdvanceTo(2.0)