In [None]:
# External libraries
import numpy as np
from matplotlib import pyplot as plt

# Drake dependencies
from pydrake.all import (
    DiagramBuilder,
    Simulator,
    StartMeshcat,
    MultibodyPlant,
    Demultiplexer,
    DiscreteContactApproximation,
    ConstantVectorSource,
    Parser,
    AddMultibodyPlantSceneGraph,
    ConstantVectorSource,
    DiagramBuilder,
    JointSliders,
    MeshcatVisualizer,
    MeshcatVisualizerParams,
    MultibodyPlant,
    MultibodyPositionToGeometryPose,
    Multiplexer,
    Parser,
    PrismaticJoint,
    SceneGraph,
    SpatialInertia,
    Sphere,
    UnitInertia,
    MeshcatVisualizerParams,
    LoadModelDirectivesFromString,
    ProcessModelDirectives,
    AddDefaultVisualization,
    SimIiwaDriver,
    IiwaControlMode
)

# Custom classes and functions
from ShishKebot.CartesianStiffnessController import CartesianStiffnessController

# Helper functions
import manipulation
from manipulation.meshcat_utils import MeshcatSliders, StopButton
from manipulation.scenarios import AddShape
from manipulation.utils import RenderDiagram
from manipulation.scenarios import AddRgbdSensors
import os
import sys 

In [None]:
# Launch the Meshcat server, then adjust the default view: hide the background
# and zoom the default camera in.
meshcat = StartMeshcat()
_meshcat_view_properties = (
    ("/Background", "visible", False),
    ("/Cameras/default/rotated/<object>", "zoom", 10.5),
)
for _path, _prop, _value in _meshcat_view_properties:
    meshcat.SetProperty(_path, _prop, _value)

In [None]:
iiwa1_directive = f"""
directives:
- add_model:
    name: iiwa1
    file: package://drake_models/iiwa_description/sdf/iiwa7_with_box_collision.sdf
    default_joint_positions:
        iiwa_joint_1: [0.0]
        iiwa_joint_2: [0.0]
        iiwa_joint_3: [0.0]
        iiwa_joint_4: [0.0]
        iiwa_joint_5: [0.0]
        iiwa_joint_6: [0.0]
        iiwa_joint_7: [0.0]
- add_weld:
    parent: world
    child: iiwa1::iiwa_link_0
    X_PC:
        translation: [0, 0.5, 0]
- add_model:
    name: wsg1
    file: package://manipulation/schunk_wsg_50_welded_fingers.sdf
- add_weld:
    parent: iiwa1::iiwa_link_7
    child: wsg1::body
    X_PC:
        translation: [0, 0, 0.09]
        rotation: !Rpy {{ deg: [90, 0, 90]}}

- add_model:
    name: skewer
    file: file://{os.getcwd()}/Models/skewer_5mm.sdf
    # default_free_body_pose:
    #     skewer_5mm:
    #         translation: [0.05, 0.2, 2]         
- add_weld:
    parent: wsg1::body
    child: skewer::skewer_5mm
    X_PC:
        translation: [0, 0, 0]
        rotation: !Rpy {{ deg: [270, 0, 0]}}
"""

iiwa2_directive = f"""
directives:
- add_model:
    name: iiwa2
    file: package://drake_models/iiwa_description/sdf/iiwa7_with_box_collision.sdf
    default_joint_positions:
        iiwa_joint_1: [0.0]
        iiwa_joint_2: [0.6]
        iiwa_joint_3: [0.0]
        iiwa_joint_4: [-1.75]
        iiwa_joint_5: [0.0]
        iiwa_joint_6: [1.0]
        iiwa_joint_7: [0.0]
- add_weld:
    parent: world
    child: iiwa2::iiwa_link_0
    X_PC:
        translation: [0, -0.5, 0]
- add_model:
    name: wsg2
    file: package://manipulation/schunk_wsg_50_welded_fingers.sdf
- add_weld:
    parent: iiwa2::iiwa_link_7
    child: wsg2::body
    X_PC:
        translation: [0, 0, 0.09]
        rotation: !Rpy {{ deg: [90, 0, 90]}}
"""

world_directive = f"""
directives:
- add_model:
    name: table
    file: file://{os.getcwd()}/Models/ground.sdf
- add_weld:
    parent: world
    child: table::base

- add_model:
    name: cube
    file: file://{os.getcwd()}/Models/cube_food.sdf
    default_free_body_pose:
        cube_food:
            rotation: !Rpy {{ deg: [{np.random.rand()*180}, {np.random.rand()*180}, {np.random.rand()*180}]}}
            translation: {[-0.4-np.random.rand()*0.6, np.random.rand() - 1, 0]}
- add_model:
    name: cube2
    file: file://{os.getcwd()}/Models/cube_food.sdf
    default_free_body_pose:
        cube_food:
            rotation: !Rpy {{ deg: [{np.random.rand()*180}, {np.random.rand()*180}, {np.random.rand()*180}]}}
            translation: {[-0.4-np.random.rand()*0.6, np.random.rand() - 1, .1]}
- add_model:
    name: cube3
    file: file://{os.getcwd()}/Models/cube_food.sdf
    default_free_body_pose:
        cube_food:
            rotation: !Rpy {{ deg: [{np.random.rand()*180}, {np.random.rand()*180}, {np.random.rand()*180}]}}
            translation: {[-0.4-np.random.rand()*0.6, np.random.rand() - 1, .2]}

- add_frame:
    name: camera0_origin
    X_PF:
        base_frame: world
        # rotation: !Rpy {{ deg: [180, 0, 0]}}
        # translation: [0, 0, 4]
        rotation: !Rpy {{ deg: [225, 0, 0]}}
        translation: [-0.7, -1.5, 0.5]
- add_model:
    name: camera0
    file: package://manipulation/camera_box.sdf
- add_weld:
    parent: camera0_origin
    child: camera0::base

- add_frame:
    name: camera1_origin
    X_PF:
        base_frame: world
        rotation: !Rpy {{ deg: [135, 0, 0]}}
        translation: [-0.7, 0.5, 0.5]
- add_model:
    name: camera1
    file: package://manipulation/camera_box.sdf
- add_weld:
    parent: camera1_origin
    child: camera1::base
"""

In [None]:
# One DiagramBuilder collects every system added in the cells below (plants,
# visualizers, controllers, state machine) until builder.Build() is called.
builder = DiagramBuilder()

In [None]:
# --- iiwa1 (skewering arm) ---------------------------------------------------
# iiwa1 gets its own MultibodyPlant + SceneGraph pair (1 ms discrete step,
# SAP contact approximation), populated from iiwa1_directive.
iiwa1_plant, iiwa1_scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=1e-3)
iiwa1_plant.set_name("iiwa1_plant")
iiwa1_scene_graph.set_name("iiwa1_scene_graph")
iiwa1_plant.set_discrete_contact_approximation(DiscreteContactApproximation.kSap)

iiwa1_parser = Parser(iiwa1_plant)
# Resolve package://manipulation/... URIs to the installed manipulation models.
iiwa1_parser.package_map().Add("manipulation", manipulation.__path__[0] + "/models/")
ProcessModelDirectives(
    LoadModelDirectivesFromString(iiwa1_directive), iiwa1_plant, iiwa1_parser
)

# Model-instance handles used later to select per-instance ports.
iiwa1 = iiwa1_plant.GetModelInstanceByName("iiwa1")
wsg1 = iiwa1_plant.GetModelInstanceByName("wsg1")
iiwa1_plant.Finalize()

# Draw this scene graph in Meshcat under the "iiwa1" prefix.
iiwa1_viz_params = MeshcatVisualizerParams()
iiwa1_viz_params.prefix = "iiwa1"
MeshcatVisualizer.AddToBuilder(builder, iiwa1_scene_graph, meshcat, iiwa1_viz_params)

In [None]:
# --- iiwa2 -------------------------------------------------------------------
# iiwa2 gets its own MultibodyPlant + SceneGraph pair (1 ms discrete step,
# SAP contact approximation), populated from iiwa2_directive.
iiwa2_plant, iiwa2_scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=1e-3)
iiwa2_plant.set_name("iiwa2_plant")
iiwa2_scene_graph.set_name("iiwa2_scene_graph")
iiwa2_plant.set_discrete_contact_approximation(DiscreteContactApproximation.kSap)

iiwa2_parser = Parser(iiwa2_plant)
# Resolve package://manipulation/... URIs to the installed manipulation models.
iiwa2_parser.package_map().Add("manipulation", manipulation.__path__[0] + "/models/")
ProcessModelDirectives(
    LoadModelDirectivesFromString(iiwa2_directive), iiwa2_plant, iiwa2_parser
)

# Model-instance handles used later to select per-instance ports.
iiwa2 = iiwa2_plant.GetModelInstanceByName("iiwa2")
wsg2 = iiwa2_plant.GetModelInstanceByName("wsg2")
iiwa2_plant.Finalize()

# Draw this scene graph in Meshcat under the "iiwa2" prefix.
iiwa2_viz_params = MeshcatVisualizerParams()
iiwa2_viz_params.prefix = "iiwa2"
MeshcatVisualizer.AddToBuilder(builder, iiwa2_scene_graph, meshcat, iiwa2_viz_params)

In [None]:
# --- World: table, food cubes, cameras ---------------------------------------
# The world models get their own MultibodyPlant + SceneGraph pair (1 ms
# discrete step, SAP contact approximation), populated from world_directive.
world_plant, world_scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=1e-3)
world_plant.set_name("world_plant")
world_scene_graph.set_name("world_scene_graph")
world_plant.set_discrete_contact_approximation(DiscreteContactApproximation.kSap)

world_parser = Parser(world_plant)
# Resolve package://manipulation/... URIs to the installed manipulation models.
world_parser.package_map().Add("manipulation", manipulation.__path__[0] + "/models/")
ProcessModelDirectives(
    LoadModelDirectivesFromString(world_directive), world_plant, world_parser
)
world_plant.Finalize()

# Draw this scene graph in Meshcat under the "world" prefix.
world_viz_params = MeshcatVisualizerParams()
world_viz_params.prefix = "world"
MeshcatVisualizer.AddToBuilder(builder, world_scene_graph, meshcat, world_viz_params)

# Add the RGB-D cameras (presumably the camera0/camera1 models from
# world_directive) to the diagram.
# NOTE(review): the sensors are given world_scene_graph only, and each iiwa
# lives in a separate plant/scene graph. Rendered images will presumably not
# show the robots, and the robots presumably cannot make contact with the
# cubes. Confirm this split is intended.
AddRgbdSensors(builder, world_plant, world_scene_graph)

In [None]:
# Cartesian stiffness (torque) controller for iiwa1, the arm carrying the skewer.
# It must be built on iiwa1_plant and wired to the iiwa1 model instance; there
# is no plain `plant`/`iiwa` in this notebook.
controller: CartesianStiffnessController = builder.AddSystem(
    CartesianStiffnessController(iiwa1_plant, "iiwa1", "wsg1")
)
controller.SetGains(
    position=(10.0, 5.0),
    orientation=(10.0, 5.0),
    null_space=5.0,
)

# Controller torque drives iiwa1's actuation input directly.
builder.Connect(
    controller.GetOutputPort("iiwa_torque_cmd"),
    iiwa1_plant.get_actuation_input_port(iiwa1),
)

# iiwa1's state output is split into its first 7 entries (positions) and its
# last 7 entries (velocities) for the controller's measured-state inputs.
state_demultiplexer = builder.AddSystem(Demultiplexer([7, 7]))
builder.Connect(iiwa1_plant.get_state_output_port(iiwa1), state_demultiplexer.get_input_port(0))
builder.Connect(state_demultiplexer.get_output_port(0), controller.GetInputPort("iiwa_position_measured"))
builder.Connect(state_demultiplexer.get_output_port(1), controller.GetInputPort("iiwa_velocity_measured"))


In [None]:
# Add a SimIiwaDriver for each iiwa, both in position-and-torque mode.
# iiwa1's "position" and "torque" inputs are wired through the PortSwitches in
# the state-machine cell; iiwa2's are wired here.
IIWA_KP = 100.0  # joint-space proportional gain shared by both drivers


def _add_sim_iiwa_driver(builder, plant, model_instance):
    """Add a SimIiwaDriver for ``model_instance`` and feed it the arm's measurements.

    ``plant`` is also used as the driver's controller plant. The driver's
    "state" and "generalized_contact_forces" inputs are connected to the
    matching per-instance outputs of ``plant``. Its "position" and "torque"
    inputs and "actuation" output are left for the caller to wire.

    Returns:
        The SimIiwaDriver system added to ``builder``.
    """
    driver = builder.AddSystem(SimIiwaDriver(
        control_mode=IiwaControlMode.kPositionAndTorque,
        controller_plant=plant,
        ext_joint_filter_tau=0.01,
        kp_gains=np.full(plant.num_positions(model_instance), IIWA_KP),
    ))
    builder.Connect(plant.get_state_output_port(model_instance),
                    driver.GetInputPort("state"))
    builder.Connect(plant.get_generalized_contact_forces_output_port(model_instance),
                    driver.GetInputPort("generalized_contact_forces"))
    return driver


sim_iiwa_controller_1 = _add_sim_iiwa_driver(builder, iiwa1_plant, iiwa1)
sim_iiwa_controller_2 = _add_sim_iiwa_driver(builder, iiwa2_plant, iiwa2)

# iiwa2 has no trajectory source yet. Hold it at the default joint positions
# from iiwa2_directive, with zero feed-forward torque. A disconnected required
# "position" input would presumably make the driver throw once the simulator
# evaluates it.
iiwa2_q_default = iiwa2_plant.GetPositions(iiwa2_plant.CreateDefaultContext(), iiwa2)
iiwa2_position_hold = builder.AddSystem(ConstantVectorSource(iiwa2_q_default))
iiwa2_zero_torque = builder.AddSystem(
    ConstantVectorSource(np.zeros(iiwa2_plant.num_positions(iiwa2)))
)
builder.Connect(iiwa2_position_hold.get_output_port(0),
                sim_iiwa_controller_2.GetInputPort("position"))
builder.Connect(iiwa2_zero_torque.get_output_port(0),
                sim_iiwa_controller_2.GetInputPort("torque"))

# TODO: replace iiwa2_position_hold with a trajectory source (e.g. an
#   "iiwa_position_cmd" output) once one exists.
# NOTE(review): neither driver's "actuation" output is connected to its plant,
#   e.g. builder.Connect(sim_iiwa_controller_2.GetOutputPort("actuation"),
#   iiwa2_plant.get_actuation_input_port(iiwa2)). Confirm this is intended.

In [None]:
from enum import Enum
from pydrake.all import (
    LeafSystem,
    PortSwitch,
    InputPortIndex,
    AbstractValue,
)
class State(Enum):
    """Phases of the skewering task, in the order they are expected to run.

    START: initial phase.
    PICK_UP_OBJECTS: diff IK controller picks up the objects.
    GET_TO_SKEWER_POSITION: robots move to set positions aligned to the skewer.
    SKEWER: skewering action using the stiffness controller.
    """

    START = 1
    PICK_UP_OBJECTS = 2
    GET_TO_SKEWER_POSITION = 3
    SKEWER = 4

class StateMachine(LeafSystem):
    """
    State machine for switching between states for skewering.

    Outputs:
        pose_desired_1: 6-vector pose target for iiwa1.
        pose_desired_2: 6-vector pose target for iiwa2 (the stored value).
        close_gripper: 1-vector gripper command.
        control_mode: InputPortIndex used by the PortSwitches to choose between
            the diff-IK path and the stiffness-controller path for iiwa1.

    NOTE(review): the current phase lives in a Python attribute (``self.state``)
    rather than in the Drake context, so it is presumably not captured by
    context cloning or reset. Consider abstract state if that matters.
    """
    def __init__(
        self,
        iiwa1_plant: MultibodyPlant,
        iiwa2_plant: MultibodyPlant,
        ) -> None:
        LeafSystem.__init__(self)
        self._plant1 = iiwa1_plant
        self._plant2 = iiwa2_plant

        # Current phase. Must be a State member so the State.* comparisons in
        # the Calc* methods can match.
        self.state: State = State.START
        # Stored 6-vector pose targets for each arm.
        self.desired_pose_iiwa1 = np.zeros(6)
        self.desired_pose_iiwa2 = np.zeros(6)
        self.gripper_state = False

        # Controller inputs
        self._q_in = self.DeclareVectorInputPort("iiwa_position_measured", 7)
        self._torque_in = self.DeclareVectorInputPort("iiwa_torque_measured", 7)
        self._x_d_in = self.DeclareVectorInputPort("pose_desired", 6)

        # Controller outputs
        self.DeclareVectorOutputPort("pose_desired_1", 6, self.CalcDesiredPose1)
        self.DeclareVectorOutputPort("pose_desired_2", 6, self.CalcDesiredPose2)
        self.DeclareVectorOutputPort("close_gripper", 1, self.CloseGripper)
        # Selector for the PortSwitches (diff IK vs. stiffness controller).
        # The model value is InputPortIndex(1), the first switch input declared
        # by the caller; index 0 is occupied by PortSwitch's own port_selector input.
        self.DeclareAbstractOutputPort(
                    "control_mode",
                    lambda: AbstractValue.Make(InputPortIndex(1)),
                    self.CalcControlMode)

    def CalcDesiredPose1(self, context, output):
        """Compute iiwa1's desired pose; this is where the state machine lives.

        The per-phase branches are still TODO stubs. The output is the
        ``pose_desired`` input when that port is connected (or fixed), and the
        stored ``self.desired_pose_iiwa1`` otherwise.
        """
        if self.state == State.START:
            # TODO: take in point clouds, select grasps, plan trajectory to objects.
            pass
        elif self.state == State.PICK_UP_OBJECTS:
            # TODO: move to objects, grasp them (close gripper output).
            pass
        elif self.state == State.GET_TO_SKEWER_POSITION:
            # TODO: move to skewer position.
            pass
        else:  # State.SKEWER
            # TODO: switch iiwa1 to the stiffness controller and skewer the objects.
            pass

        # Evaluating a disconnected input port throws, so only read
        # "pose_desired" when it actually has a value.
        if self._x_d_in.HasValue(context):
            x_d = self._x_d_in.Eval(context)
        else:
            x_d = self.desired_pose_iiwa1
        output.SetFromVector(x_d)

    def CalcDesiredPose2(self, context, output):
        """Output the stored iiwa2 pose target (``self.desired_pose_iiwa2``)."""
        output.SetFromVector(self.desired_pose_iiwa2)

    def CalcControlMode(self, context, output):
        """Select the PortSwitch input: 2 (stiffness) in SKEWER, else 1 (diff IK)."""
        if self.state == State.SKEWER:
            output.set_value(InputPortIndex(2))  # Use stiffness control
        else:
            output.set_value(InputPortIndex(1))  # Use diff IK

    def CloseGripper(self, context, output):
        """Output 1.0 when ``self.gripper_state`` is truthy, else 0.0."""
        # The port is a size-1 vector, so pass a one-element sequence rather
        # than a bare scalar bool.
        output.SetFromVector([1.0 if self.gripper_state else 0.0])

state_machine = builder.AddSystem(StateMachine(iiwa1_plant, iiwa2_plant))

# Zero feed-forward torque for iiwa1 while it is on the diff-IK path.
iiwa1_zero_torque = builder.AddSystem(ConstantVectorSource(np.zeros(7)))

# Both switches are driven by the state machine's "control_mode" output:
#   InputPortIndex(1) -> first declared input  ("diff_ik" path)
#   InputPortIndex(2) -> second declared input ("stiffness" path)
# Index 0 is each PortSwitch's own port_selector input.

# PortSwitch 1: joint position command for iiwa1 (the skewer arm).
switch_position = builder.AddSystem(PortSwitch(7))
# diff-IK path: no diff-IK controller exists yet, so feed back iiwa1's measured
# position (presumably holding the current pose).
builder.Connect(sim_iiwa_controller_1.GetOutputPort("position_measured"),
                switch_position.DeclareInputPort("diff_ik"))
# stiffness path: position command from the Cartesian stiffness controller.
builder.Connect(controller.GetOutputPort("iiwa_position_cmd"),
                switch_position.DeclareInputPort("stiffness"))
builder.Connect(switch_position.get_output_port(),
                sim_iiwa_controller_1.GetInputPort("position"))
builder.Connect(state_machine.GetOutputPort("control_mode"),
                switch_position.get_port_selector_input_port())

# PortSwitch 2: feed-forward torque for iiwa1 (the skewer arm).
switch_torque = builder.AddSystem(PortSwitch(7))
builder.Connect(iiwa1_zero_torque.get_output_port(0),
                switch_torque.DeclareInputPort("diff_ik"))
# stiffness path: torque command from the Cartesian stiffness controller.
builder.Connect(controller.GetOutputPort("iiwa_torque_cmd"),
                switch_torque.DeclareInputPort("stiffness"))
builder.Connect(switch_torque.get_output_port(),
                sim_iiwa_controller_1.GetInputPort("torque"))
builder.Connect(state_machine.GetOutputPort("control_mode"),
                switch_torque.get_port_selector_input_port())

In [None]:
# Assemble every system added to `builder` into one Diagram and wrap it in a
# Simulator.
diagram = builder.Build()
simulator = Simulator(diagram)
# RenderDiagram(diagram, max_depth=1)  # uncomment to inspect the wiring

In [None]:
# Run at (approximately) real time until the Meshcat "Stop Simulation" button
# is pressed. A bare simulator.AdvanceTo(np.inf) never returns, which blocks
# the kernel and prevents Restart & Run All from finishing.
SIM_STEP = 0.1  # seconds of sim time between stop-button checks
STOP_BUTTON = "Stop Simulation"

simulator.set_target_realtime_rate(1.0)
meshcat.AddButton(STOP_BUTTON)
try:
    while meshcat.GetButtonClicks(STOP_BUTTON) < 1:
        simulator.AdvanceTo(simulator.get_context().get_time() + SIM_STEP)
finally:
    # Remove the button even if the simulation raises, so re-running this cell
    # can add it again.
    meshcat.DeleteButton(STOP_BUTTON)