In [2]:
import numpy as np
from pydrake.all import (
    AddMultibodyPlantSceneGraph,
    DiagramBuilder,
    LeafSystem,
    MeshcatVisualizerCpp,
    MultilayerPerceptron,
    Parser,
    PerceptronActivationType,
    RandomGenerator,
    RigidTransform,
    RotationMatrix,
    SceneGraph,
    Simulator,
    StartMeshcat,
    ZeroOrderHold,
)
from pydrake.examples.pendulum import PendulumGeometry, PendulumPlant
from underactuated import FindResource, running_as_notebook
from value_iteration import *
from functools import partial

In [6]:
# Start the Meshcat server; the `meshcat` handle is passed to the visualizer
# in the closed-loop simulation cell further down.
meshcat = StartMeshcat()

Meshcat is now available at http://localhost:7002


In [5]:
# Set up the cart-pole system (continuous-time plant: time_step=0.0).
builder = DiagramBuilder()
cart_plant, cart_scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)
file_name = FindResource("models/cartpole.urdf")
Parser(cart_plant).AddModelFromFile(file_name)
cart_plant.Finalize()
cart_plant_context = cart_plant.CreateDefaultContext()

cart_diagram = builder.Build()

num_states = cart_plant.num_continuous_states()

# NOTE(review): the actuation port is addressed by a hard-coded index; this
# depends on the port ordering of the installed Drake version -- confirm it
# still refers to the actuation input after any Drake upgrade.
cart_actuation_port_index = 3
# Number of scalar inputs on the actuation port. `.size()` is required here:
# without it this name would hold the InputPort object, not a count.
num_inputs = cart_plant.get_input_port(cart_actuation_port_index).size()

# Set up training data. states are (x, theta, x_dot, theta_dot)
time_step = 0.01
num_samples = 50  # grid points per state dimension (num_samples**4 states total)
x_states_cart = np.linspace(-2, 2, num_samples)
theta_states_cart = np.linspace(0, 2 * np.pi, num_samples)
x_dot_states_cart = np.linspace(-10, 10, num_samples)
theta_dot_states_cart = np.linspace(-10, 10, num_samples)
state_grid_cart = np.meshgrid(
    x_states_cart,
    theta_states_cart,
    x_dot_states_cart,
    theta_dot_states_cart,
    indexing="ij",
)
# One row per state dimension, one column per grid point.
state_data_cart = np.vstack([s.flatten() for s in state_grid_cart])

# zero cost state (column vector), in the (x, theta, x_dot, theta_dot) order above
cart_target_state = np.array([0, np.pi, 0, 0]).reshape(-1, 1)

# Quadratic cost weights on the state (Q) and the input (R).
Q_cart = np.diag([0.1, 20, 1, 1])
R_cart = np.array([2])

# A neural network for the cartpole: two 128-unit ReLU hidden layers and a
# scalar linear output. The first argument flags, per input, whether the
# network is fed sin/cos features of it; only the second input (theta) is
# flagged.
cart_value_mlp = MultilayerPerceptron(
    [False, True, False, False],
    [128, 128, 1],
    [
        PerceptronActivationType.kReLU,
        PerceptronActivationType.kReLU,
        PerceptronActivationType.kIdentity,
    ],
)

state_cost_function_cart = partial(compute_state_cost, Q_cart, cart_target_state)
# train the neural network; the same `time_step` is reused for the
# zero-order hold in the closed-loop simulation, so pass the variable rather
# than repeating the literal.
cart_value_mlp_context = ContinuousFittedValueIteration(
    cart_plant,
    cart_plant_context,
    cart_value_mlp,
    state_cost_function_cart,
    compute_u_star,
    R_cart,
    state_data_cart,
    time_step=time_step,
    discount_factor=0.9999,
    input_port_index=cart_actuation_port_index,
    lr=1e-4,
    minibatch=64,
    epochs=300,
    optimization_steps_per_epoch=100,
    input_limits=None,
    target_state=cart_target_state,
)

epoch 0: loss = 0.11710697837294064
epoch 1: loss = 0.08321165021248118
epoch 2: loss = 0.0780985613548365
epoch 3: loss = 0.07546664387904568
epoch 4: loss = 0.08555367595802893
epoch 5: loss = 0.08016439424912769
epoch 6: loss = 0.04663409108416899
epoch 7: loss = 0.04833643305826999
epoch 8: loss = 0.039867798298860814
epoch 9: loss = 0.043322141659040685
epoch 10: loss = 0.030690624309467895
epoch 11: loss = 0.04400796822983323
epoch 12: loss = 0.02890341539244498
epoch 13: loss = 0.04243039331076075
epoch 14: loss = 0.0442159820901976
epoch 15: loss = 0.039991611392951584
epoch 16: loss = 0.04396619853290439
epoch 17: loss = 0.04773254122689792
epoch 18: loss = 0.039098930868610375
epoch 19: loss = 0.048763557362438155
epoch 20: loss = 0.056561183736750335
epoch 21: loss = 0.04208077004204335
epoch 22: loss = 0.05539134821768695
epoch 23: loss = 0.051207420599503686
epoch 24: loss = 0.03930488998575729
epoch 25: loss = 0.07236711646731789
epoch 26: loss = 0.0565571839721975
epoch 

In [8]:
# Build a separate diagram for the closed-loop run: a fresh cart-pole plant
# driven by the policy derived from the trained value network.
closed_loop_builder_cart = DiagramBuilder()
cart_plant_cl, cart_scene_graph_cl = AddMultibodyPlantSceneGraph(
    closed_loop_builder_cart, time_step=0.0
)
file_name = FindResource("models/cartpole.urdf")
Parser(cart_plant_cl).AddModelFromFile(file_name)
cart_plant_cl.Finalize()
cart_plant_context_cl = cart_plant_cl.CreateDefaultContext()

# Controller system wrapping the trained network and its context.
cart_controller_sys = ContinuousFittedValueIterationPolicyComputeUStar(
    cart_plant_cl,
    cart_value_mlp,
    cart_value_mlp_context,
    R_cart,
    compute_u_star,
    input_port_index=cart_actuation_port_index,
)
cart_controller = closed_loop_builder_cart.AddSystem(cart_controller_sys)

# The control input is held constant between time steps (zero-order hold).
zoh_cart = closed_loop_builder_cart.AddSystem(ZeroOrderHold(time_step, 1))

# Wiring: plant state -> controller -> zero-order hold -> plant actuation.
for source_port, sink_port in (
    (cart_plant_cl.get_state_output_port(), cart_controller.get_input_port()),
    (cart_controller.get_output_port(), zoh_cart.get_input_port()),
    (
        zoh_cart.get_output_port(),
        cart_plant_cl.get_input_port(cart_actuation_port_index),
    ),
):
    closed_loop_builder_cart.Connect(source_port, sink_port)

# Reset the viewer and attach a visualizer to the closed-loop scene graph.
meshcat.Delete()
meshcat.Set2dRenderMode(xmin=-2.5, xmax=2.5, ymin=-1.0, ymax=2.5)
vis = MeshcatVisualizerCpp.AddToBuilder(
    closed_loop_builder_cart, cart_scene_graph_cl, meshcat
)

cart_diagram_closed_loop = closed_loop_builder_cart.Build()

cart_simulator = Simulator(cart_diagram_closed_loop)
cart_simulator_context = cart_simulator.get_mutable_context()

# Real-time, 10 s run when executed as a notebook; otherwise run as fast as
# possible for a very short horizon.
if running_as_notebook:
    realtime_rate, duration = 1.0, 10.0
else:
    realtime_rate, duration = 0.0, 0.1
cart_simulator.set_target_realtime_rate(realtime_rate)

# Single rollout starting from the all-zero state at t = 0.
for i in range(1):
    cart_simulator_context.SetTime(0.0)
    cart_simulator_context.SetContinuousState([0, 0, 0, 0])
    cart_simulator.Initialize()
    cart_simulator.AdvanceTo(duration)
