This notebook provides examples to go along with the [textbook](https://underactuated.csail.mit.edu/acrobot.html).  I recommend having both windows open, side-by-side!


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pydot
from IPython.display import SVG, display
from pydrake.all import (
    DiagramBuilder,
    Linearize,
    LinearQuadraticRegulator,
    LogVectorOutput,
    MeshcatVisualizer,
    SceneGraph,
    Simulator,
    StartMeshcat,
    SteadyStateKalmanFilter,
)
from pydrake.examples import AcrobotGeometry, AcrobotWEncoder

In [None]:
# Start Meshcat and keep the handle in a notebook-level variable; demo()
# passes it to MeshcatVisualizer.AddToBuilder.
meshcat = StartMeshcat()

# Acrobot w/ Encoder


In [None]:
def BalancingLQRController(acrobot):
    """Build an LQR controller around the acrobot's upright fixed point.

    The operating point is the continuous state ``[pi, 0, 0, 0]`` with the
    "elbow_torque" input port fixed to zero, both set on a fresh default
    context of ``acrobot``.

    Args:
        acrobot: The acrobot system to stabilize. It must provide
            ``CreateDefaultContext()`` and an input port named
            "elbow_torque".

    Returns:
        The result of ``LinearQuadraticRegulator(acrobot, context, Q, R)``.
    """
    # LQR cost weights: position error is penalized 10x more than velocity
    # error to roughly address the difference in units (using sqrt(g/l) as
    # the time constant).
    state_cost = np.diag([10, 10, 1, 1])
    torque_cost = [1]

    upright_context = acrobot.CreateDefaultContext()

    # Nominal state: the upright fixed point.
    upright_context.SetContinuousState([np.pi, 0, 0, 0])

    # Nominal input: zero elbow torque.
    acrobot.GetInputPort("elbow_torque").FixValue(upright_context, 0.0)

    return LinearQuadraticRegulator(
        acrobot, upright_context, state_cost, torque_cost
    )


def demo():
    """Balance the acrobot with LQR using Kalman-filter state estimates.

    Builds a diagram containing the acrobot-with-encoder plant, a Meshcat
    visualizer, a steady-state Kalman filter observer, and an LQR controller
    fed by the observer's estimate. Displays the diagram as an SVG, simulates
    from a perturbed initial condition to t = 5.0, and then plots the true
    joint angles and the estimation error (true state minus estimate).
    """
    diagram_builder = DiagramBuilder()

    # Plant: output port 0 is wired to the observer below, output port 1 is
    # used as the true state for visualization and logging.
    plant = diagram_builder.AddSystem(AcrobotWEncoder(True))
    plant.set_name("acrobot_w_encoder")
    measurement_port = plant.get_output_port(0)
    true_state_port = plant.get_output_port(1)

    acrobot = plant.acrobot_plant()

    # Visualization.
    scene_graph = diagram_builder.AddSystem(SceneGraph())
    AcrobotGeometry.AddToBuilder(diagram_builder, true_state_port, scene_graph)
    MeshcatVisualizer.AddToBuilder(diagram_builder, scene_graph, meshcat)

    # Kalman filter observer, designed about the upright fixed point with
    # zero elbow torque.
    observer_model = AcrobotWEncoder()
    observer_model_context = observer_model.CreateDefaultContext()
    observer_model_context.SetContinuousState([np.pi, 0, 0, 0])
    observer_model.GetInputPort("elbow_torque").FixValue(
        observer_model_context, 0.0
    )

    process_noise_cov = np.eye(4)
    # Position measurements are relatively clean.
    measurement_noise_cov = 0.1 * np.eye(2)
    kalman_filter = SteadyStateKalmanFilter(
        observer_model,
        observer_model_context,
        process_noise_cov,
        measurement_noise_cov,
    )
    observer = diagram_builder.AddSystem(kalman_filter)
    observer.set_name("observer")
    diagram_builder.Connect(measurement_port, observer.get_input_port(0))
    estimate_port = observer.get_output_port(0)

    # LQR controller, driven by the estimated state. Its output goes to both
    # the plant and the observer.
    controller = diagram_builder.AddSystem(BalancingLQRController(acrobot))
    controller.set_name("controller")
    diagram_builder.Connect(estimate_port, controller.get_input_port())
    diagram_builder.Connect(
        controller.get_output_port(), plant.get_input_port(0)
    )
    diagram_builder.Connect(
        controller.get_output_port(), observer.get_input_port(1)
    )

    # Loggers for the true state and the estimated state.
    x_logger = LogVectorOutput(true_state_port, diagram_builder)
    x_logger.set_name("x_logger")
    xhat_logger = LogVectorOutput(estimate_port, diagram_builder)
    xhat_logger.set_name("xhat_logger")

    diagram = diagram_builder.Build()

    # Render the block diagram inline.
    dot_graph = pydot.graph_from_dot_data(diagram.GetGraphvizString())[0]
    display(SVG(dot_graph.create_svg()))

    simulator = Simulator(diagram)
    root_context = simulator.get_mutable_context()

    # The plant starts near (but not at) the upright fixed point...
    plant_context = plant.GetMyMutableContextFromRoot(root_context)
    plant_context.SetContinuousState([np.pi + 0.1, -0.1, 0, 0])

    # ...while the observer starts exactly at the upright fixed point.
    observer_context = observer.GetMyMutableContextFromRoot(root_context)
    observer_context.SetContinuousState([np.pi, 0, 0, 0])

    # Simulate with a fixed integration step of 0.01.
    # (Optional: simulator.set_target_realtime_rate(1.0).)
    integrator = simulator.get_mutable_integrator()
    integrator.set_maximum_step_size(0.01)
    integrator.set_fixed_step_mode(True)
    simulator.Initialize()
    simulator.AdvanceTo(5.0)

    x_log = x_logger.FindLog(root_context)
    xhat_log = xhat_logger.FindLog(root_context)
    sample_times = x_log.sample_times()
    true_states = x_log.data()

    # True joint angles, with theta1 shown relative to upright.
    plt.figure()
    plt.plot(sample_times, (true_states[0, :] - np.pi).T, label="theta1 - pi")
    plt.plot(sample_times, true_states[1, :].T, label="theta2")
    plt.legend()

    # Estimation error for every state component.
    plt.figure(2)
    estimation_error = true_states - xhat_log.data()
    plt.plot(sample_times, estimation_error.T)
    plt.xlabel("seconds")
    plt.ylabel("error")
    plt.legend(["theta1", "theta2", "theta1_dot", "theta2_dot"])


# Run the demo: builds the diagram, simulates to t = 5.0, and plots the results.
demo()