In [1]:
import numpy as np
from pydrake.all import (
    AddMultibodyPlantSceneGraph,
    DiagramBuilder,
    Parser,
    Simulator,
    PlanarSceneGraphVisualizer,
    LinearQuadraticRegulator,
)
from value_iteration import *
import matplotlib

# Default size (in inches) for every matplotlib figure created in this notebook.
matplotlib.rcParams.update({"figure.figsize": [10, 10]})

# Cartpole LQR

In [3]:
def cartpole_balancing_example(
    target_state=None,
    num_trials=5,
    noise_std=0.1,
    sim_time=10,
    seed=None,
):
    """Balance the cart-pole with LQR and animate several perturbed rollouts.

    Builds a diagram containing the ``cartpole.urdf`` MultibodyPlant, an LQR
    controller linearized about ``target_state``, and a planar scene-graph
    visualizer. It then runs ``num_trials`` simulations, each started from
    ``target_state`` plus independent Gaussian noise on every component.

    Args:
        target_state: dict with exactly the keys ``"x"``, ``"theta"``,
            ``"x_dot"``, ``"theta_dot"``. Defaults to
            ``{"x": 0, "theta": np.pi, "x_dot": 0, "theta_dot": 0}``. It is
            both the LQR operating point and the mean of the initial states.
        num_trials: number of perturbed simulations to run (default 5).
        noise_std: standard deviation of the Gaussian perturbation added to
            each state component (default 0.1).
        sim_time: duration of each trial, forwarded to
            ``simulate_and_animate`` (default 10).
        seed: optional seed for the perturbations. When ``None`` the legacy
            global ``np.random`` stream is used, so an earlier
            ``np.random.seed(...)`` call still controls the draws.

    Raises:
        ValueError: if ``target_state`` does not have exactly the four keys
            listed above.
    """
    if target_state is None:
        # Built per call to avoid a shared mutable default argument.
        target_state = {"x": 0, "theta": np.pi, "x_dot": 0, "theta_dot": 0}

    # The LQR operating point is filled positionally from the dict's values, so
    # normalize the key order instead of trusting the caller's insertion order.
    # NOTE(review): assumes the plant's continuous state is ordered
    # [x, theta, x_dot, theta_dot], as the original default dict implies.
    state_names = ("x", "theta", "x_dot", "theta_dot")
    if set(target_state) != set(state_names):
        raise ValueError(
            f"target_state must have exactly the keys {state_names}, "
            f"got {tuple(target_state)}"
        )
    target_state = {name: target_state[name] for name in state_names}

    # Seeded Generator for reproducible runs; otherwise keep the global stream.
    rng = np.random if seed is None else np.random.default_rng(seed)

    def BalancingLQR(plant):
        """Return an LQR controller system linearized at ``target_state``."""
        context = plant.CreateDefaultContext()
        # Fix the actuation input to zero at the operating point.
        plant.get_actuation_input_port().FixValue(context, [0])
        context.get_mutable_continuous_state_vector().SetFromVector(
            list(target_state.values())
        )

        # Heavier weight on the first two state components (positions, under
        # the ordering assumed above) than on the velocities.
        Q = np.diag((10.0, 10.0, 1.0, 1.0))
        R = [1]

        return LinearQuadraticRegulator(
            plant,
            context,
            Q,
            R,
            input_port_index=plant.get_actuation_input_port().get_index(),
        )

    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)
    Parser(plant).AddModelFromFile("cartpole.urdf")
    plant.set_name("cartpole plant")
    plant.Finalize()

    # Full-state feedback: plant state -> LQR -> plant actuation.
    controller = builder.AddSystem(BalancingLQR(plant))
    builder.Connect(plant.get_state_output_port(), controller.get_input_port(0))
    builder.Connect(controller.get_output_port(0), plant.get_actuation_input_port())

    # Add visualizer (show=False: rendering is left to simulate_and_animate).
    visualizer = builder.AddSystem(
        PlanarSceneGraphVisualizer(
            scene_graph, xlim=[-3.0, 3.0], ylim=[-0.5, 1.2], show=False
        )
    )
    visualizer.set_name("visualizer")
    builder.Connect(scene_graph.get_query_output_port(), visualizer.get_input_port(0))

    diagram = builder.Build()
    simulator = Simulator(diagram)

    # Simulate from randomly perturbed initial conditions around the target.
    for _ in range(num_trials):
        initial_state = {
            name: value + noise_std * rng.standard_normal()
            for name, value in target_state.items()
        }
        simulate_and_animate(initial_state, visualizer, simulator, sim_time=sim_time)


# Run the balancing demo with its default target_state; the trailing ";"
# suppresses the cell's (None) output in IPython.
cartpole_balancing_example();