# Notebook Setup

The following cell will:
- on Colab (only), install Drake to `/opt/drake`, install Drake's prerequisites via `apt`, and add pydrake to `sys.path`. This will take approximately two minutes the first time it runs (to provision the machine), but it should only need to reinstall once every 12 hours. If you navigate between notebooks using Colab's "File->Open" menu, then you can avoid provisioning a separate machine for each notebook.
- import packages used throughout the notebook.

You will need to rerun this cell if you restart the kernel, but it should be fast (even on Colab) because the machine will already have drake installed.

In [None]:
import importlib
import sys
from urllib.request import urlretrieve

# Install drake (and underactuated).
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule has been loaded, so import it explicitly before using find_spec().
import importlib.util

# Only provision on Colab, and only when `underactuated` is not importable yet
# (i.e. the first run on a fresh machine).
if 'google.colab' in sys.modules and importlib.util.find_spec('underactuated') is None:
    urlretrieve("http://underactuated.csail.mit.edu/scripts/setup/setup_underactuated_colab.py",
                "setup_underactuated_colab.py")
    from setup_underactuated_colab import setup_underactuated
    setup_underactuated(underactuated_sha='15cfd96b0bdfd1b0c67597c24f91907776c02a6d', drake_version='0.27.0', drake_build='release')

# Extra command-line arguments for the meshcat server; on Colab the server is
# started with the ngrok HTTP tunnel flag.
server_args = []
if 'google.colab' in sys.modules:
    server_args = ['--ngrok_http_tunnel']

# Start a single meshcat server instance to use for the remainder of this notebook.
from meshcat.servers.zmqserver import start_zmq_server_as_subprocess
proc, zmq_url, web_url = start_zmq_server_as_subprocess(server_args=server_args)

    
import numpy as np
import meshcat

from pydrake.common.containers import namedview
from pydrake.common.value import AbstractValue
from pydrake.math import RigidTransform, RotationMatrix
from pydrake.geometry import FramePoseVector
from pydrake.multibody.plant import MultibodyPlant
from pydrake.multibody.parsing import Parser
from pydrake.systems.framework import (BasicVector, BasicVector_, LeafSystem_,
                                       LeafSystem)
from pydrake.systems.scalar_conversion import TemplateSystem
from pydrake.all import (
    SceneGraph, DiagramBuilder, Parser, ConnectPlanarSceneGraphVisualizer, DrakeVisualizer, 
    ConnectMeshcatVisualizer, DirectCollocation, Solve, PiecewisePolynomial
)

from underactuated import FindResource

# Note: In order to use the Python system with drake's autodiff features, we
# have to add a little "TemplateSystem" boilerplate (for now).  For details,
# see https://drake.mit.edu/pydrake/pydrake.systems.scalar_conversion.html

# Named view over the 7-element glider state vector: four positions
# (x, z, pitch, elevator) followed by three velocities (xdot, zdot, pitchdot).
# Fields can be accessed by name (e.g. `s.pitch`) or by index/slice (`s[4:7]`).
GliderState = namedview(
    "GliderState", ["x", "z", "pitch", "elevator", "xdot", "zdot", "pitchdot"])

@TemplateSystem.define("GliderPlant_")
def GliderPlant_(T):
    """Return the glider plant class instantiated for scalar type ``T``.

    The resulting system has one input port ("elevatordot", size 1), a
    continuous state with four positions and three velocities laid out as in
    ``GliderState``, and one output port ("state", size 7) that copies the
    full state.
    """

    class Impl(LeafSystem_[T]):

        def _construct(self, converter=None):
            """Declare the ports, the continuous state and the elevator limits."""
            LeafSystem_[T].__init__(self, converter)
            # one input (elevator velocity)
            self.DeclareVectorInputPort("elevatordot", BasicVector_[T](1))
            # four positions, three velocities
            self.DeclareContinuousState(4, 3, 0)
            # seven outputs (full state)
            self.DeclareVectorOutputPort("state", BasicVector_[T](7),
                                         self.CopyStateOut)

            # TODO(russt): Declare elevator constraints:
            self.elevator_lower_limit = -0.9473
            self.elevator_upper_limit = 0.4463

        def _construct_copy(self, other, converter=None):
            """Scalar-converting copy constructor.

            All members are constants set in ``_construct``, so nothing needs
            to be read from ``other``.
            """
            Impl._construct(self, converter=converter)

        def DoCalcTimeDerivatives(self, context, derivatives):
            """Write the time derivative of the glider state into ``derivatives``.

            Args:
                context: Context holding the current continuous state and the
                    "elevatordot" input.
                derivatives: Continuous-state object whose vector is
                    overwritten with the 7-element state derivative.
            """
            # parameters based on Rick Cory's "R1 = no dihedral" model.
            Sw = 0.0885  # surface area of wing + fuselage + tail.
            Se = 0.0147  # surface area of elevator.
            lw = 0  # horizontal offset of wing center.
            le = 0.022  # elevator aerodynamic center from hinge.
            lh = 0.317  # elevator hinge.
            inertia = 0.0015  # body inertia.
            m = 0.08  # body mass.
            rho = 1.204  # air density (kg/m^3).
            gravity = 9.81  # gravity

            # Read-only access is all that is needed here (same accessor as
            # CopyStateOut); the mutable accessor is reserved for code that
            # actually changes the state.
            s = GliderState(
                context.get_continuous_state_vector().CopyToVector())
            elevatordot = self.EvalVectorInput(context, 0)[0]

            # Velocity of the wing center and the resulting wing force.
            xwdot = s.xdot + lw * s.pitchdot * np.sin(s.pitch)
            zwdot = s.zdot + lw * s.pitchdot * np.cos(s.pitch)
            alpha_w = -np.arctan2(zwdot, xwdot) - s.pitch
            fw = rho * Sw * np.sin(alpha_w) * (zwdot**2 + xwdot**2)

            # Velocity of the elevator aerodynamic center and the resulting
            # elevator force; `e` is the absolute elevator angle.
            e = s.pitch + s.elevator
            edot = s.pitchdot + elevatordot
            xedot = s.xdot + lh * s.pitchdot * np.sin(s.pitch) \
                + le * edot * np.sin(e)
            zedot = s.zdot + lh * s.pitchdot * np.cos(s.pitch) \
                + le * edot * np.cos(e)
            alpha_e = -np.arctan2(zedot, xedot) - e
            fe = rho * Se * np.sin(alpha_e) * (zedot**2 + xedot**2)

            # `s[:]` is a NumPy slice (a view on the same buffer) and a
            # namedview wraps its argument without copying, so `sdot` must be
            # built from an explicit copy. Otherwise the writes below would
            # overwrite `s.pitch` and `s.elevator` before they are read.
            sdot = GliderState(s[:].copy())
            sdot[0:3] = s[4:7]
            sdot.elevator = elevatordot
            sdot.xdot = (fw * np.sin(s.pitch) + fe * np.sin(e)) / m
            sdot.zdot = (fw * np.cos(s.pitch) + fe * np.cos(e)) / m - gravity
            sdot.pitchdot = (fw * lw + fe * (lh * np.cos(s.elevator) + le)) / inertia
            derivatives.get_mutable_vector().SetFromVector(sdot[:])

        def CopyStateOut(self, context, output):
            """Copy the full continuous state to the "state" output port."""
            x = context.get_continuous_state_vector().CopyToVector()
            output.SetFromVector(x)

    return Impl


# To use glider.urdf for visualization, follow the pattern from e.g.
# drake::examples::quadrotor::QuadrotorGeometry.
class GliderGeometry(LeafSystem):
    """Turn a glider state vector into frame poses for a SceneGraph.

    The system has one 7-element vector input ("state") and one abstract
    output ("geometry_pose") holding a FramePoseVector with the poses of the
    body frame and the elevator frame of the glider URDF model.
    """

    def __init__(self, scene_graph):
        """Load the glider model and register its geometry with `scene_graph`."""
        LeafSystem.__init__(self)
        assert scene_graph

        # Timestep doesn't matter, and this avoids a warning
        plant = MultibodyPlant(1.0)
        urdf_parser = Parser(plant, scene_graph)
        urdf_path = FindResource("models/glider/glider.urdf")
        glider_model = urdf_parser.AddModelFromFile(urdf_path)
        plant.Finalize()

        self.source_id = plant.get_source_id()
        # The first two bodies of the model are used as the body and the
        # elevator, in that order.
        body_indices = plant.GetBodyIndices(glider_model)
        self.body_frame_id = plant.GetBodyFrameIdOrThrow(body_indices[0])
        self.elevator_frame_id = plant.GetBodyFrameIdOrThrow(body_indices[1])

        self.DeclareVectorInputPort("state", BasicVector(7))
        self.DeclareAbstractOutputPort(
            "geometry_pose",
            lambda: AbstractValue.Make(FramePoseVector()),
            self.OutputGeometryPose)

    def OutputGeometryPose(self, context, poses):
        """Compute the body and elevator poses from the input state.

        Args:
            context: Context from which the "state" input port is evaluated.
            poses: Abstract output value; its FramePoseVector receives one
                pose per frame id.
        """
        assert self.body_frame_id.is_valid()
        assert self.elevator_frame_id.is_valid()
        lh = 0.317  # elevator hinge.
        glider = GliderState(self.get_input_port(0).Eval(context))

        # The body rotates about the y axis by the pitch angle.
        body_rotation = RotationMatrix.MakeYRotation(glider.pitch)
        body_pose = RigidTransform(body_rotation, [glider.x, 0, glider.z])

        # The elevator sits at distance lh from the body origin and rotates by
        # pitch + elevator angle.
        elevator_rotation = RotationMatrix.MakeYRotation(
            glider.pitch + glider.elevator)
        hinge_position = [
            glider.x - lh * np.cos(glider.pitch),
            0,
            glider.z + lh * np.sin(glider.pitch),
        ]
        elevator_pose = RigidTransform(elevator_rotation, hinge_position)

        frame_poses = poses.get_mutable_value()
        frame_poses.set_value(self.body_frame_id, body_pose)
        frame_poses.set_value(self.elevator_frame_id, elevator_pose)

    @staticmethod
    def AddToBuilder(builder, glider_state_port, scene_graph):
        """Add a GliderGeometry to `builder` and wire it up.

        Connects `glider_state_port` to the new system's input and the new
        system's output to the matching source pose port of `scene_graph`.

        Returns:
            The GliderGeometry system that was added.
        """
        assert builder
        assert scene_graph

        geometry = builder.AddSystem(GliderGeometry(scene_graph))
        builder.Connect(glider_state_port, geometry.get_input_port(0))
        pose_port = scene_graph.get_source_pose_port(geometry.source_id)
        builder.Connect(geometry.get_output_port(0), pose_port)

        return geometry


# Concrete glider plant class used by the rest of the notebook.
GliderPlant = GliderPlant_[None]  # Default instantiation

def draw_glider(x):
    """Show the glider in meshcat at the continuous state `x`.

    Builds a fresh diagram (glider plant, scene graph, glider geometry and a
    meshcat visualizer connected to `zmq_url`), sets the diagram's continuous
    state to `x`, and publishes once.

    Args:
        x: Value passed to `SetContinuousState` on the diagram's context.

    Returns:
        The meshcat visualizer system that was added to the diagram.
    """
    diagram_builder = DiagramBuilder()
    plant = diagram_builder.AddSystem(GliderPlant())
    scene_graph = diagram_builder.AddSystem(SceneGraph())
    state_port = plant.GetOutputPort("state")
    GliderGeometry.AddToBuilder(diagram_builder, state_port, scene_graph)

    visualizer = ConnectMeshcatVisualizer(
        diagram_builder, scene_graph=scene_graph, zmq_url=zmq_url)
    visualizer.set_planar_viewpoint(xmin=-4, xmax=1, ymin=-1, ymax=1)

    diagram = diagram_builder.Build()
    root_context = diagram.CreateDefaultContext()
    root_context.SetContinuousState(x)

    visualizer.load()
    diagram.Publish(root_context)
    return visualizer

In [None]:
def dircol_perching():
    """Set up and solve the glider perching problem with direct collocation.

    Adds input limits, state bounds, initial and final conditions and costs to
    a DirectCollocation program for a GliderPlant, draws the initial state,
    plots each candidate state trajectory in meshcat, and solves. When the
    solve does not succeed, the infeasible constraints are printed.
    """
    glider = GliderPlant()

    num_time_samples = 41
    dircol = DirectCollocation(glider, glider.CreateDefaultContext(),
                               num_time_samples, 0.01, 0.05)
    dircol.AddEqualTimeIntervalsConstraints()

    # Input limits
    u = dircol.input()
    max_elevator_rate = 13  # max servo velocity (rad/sec)
    dircol.AddConstraintToAllKnotPoints(-max_elevator_rate <= u[0])
    dircol.AddConstraintToAllKnotPoints(u[0] <= max_elevator_rate)

    # State constraints, as (lower bound, variable, upper bound).
    knot = GliderState(dircol.state())
    state_bounds = (
        (-4, knot.x, 1),
        (-1, knot.z, 1),
        (-np.pi/2.0, knot.pitch, np.pi/2.0),
        (glider.elevator_lower_limit, knot.elevator,
         glider.elevator_upper_limit),
    )
    for lower, variable, upper in state_bounds:
        dircol.AddConstraintToAllKnotPoints(lower <= variable)
        dircol.AddConstraintToAllKnotPoints(variable <= upper)
#    dircol.AddConstraintToAllKnotPoints(0.1 <= knot.xdot)

    # Initial conditions
    s0 = GliderState(np.zeros(7))
    s0.x = -3.5
    s0.z = 0.1
    s0.xdot = 7.0
    dircol.AddBoundingBoxConstraint(s0[:], s0[:], dircol.initial_state())
    draw_glider(s0[:])

    # Final conditions, as (lower bound, upper bound, variable).
    final = GliderState(dircol.final_state())
    final_bounds = (
        (0, 0, final.x),
        (0, 0, final.z),
        (np.pi/6.0, 1.0, final.pitch),
        (-2.0, 2.0, final.xdot),
        (-2.0, 2.0, final.zdot),
    )
    for lower, upper, variable in final_bounds:
        dircol.AddBoundingBoxConstraint(lower, upper, variable)

    # Cost: penalize the input along the trajectory and the distance of the
    # final state from the desired final state sf_d.
    dircol.AddRunningCost(100*u*u)
    sf_d = GliderState(np.zeros(7))
    sf_d.pitch = np.pi/4.0
    sf_d.xdot = 0.01
#    sf_d.zdot = -.01
    final_weights = np.diag([10, 10, 1, 10, 1, 1, 1])
    dircol.AddQuadraticErrorCost(final_weights, sf_d[:], dircol.final_state())

    vis = meshcat.Visualizer(zmq_url=zmq_url, server_args=server_args)

    def plot_trajectory(times, states):
        """Draw the (x, z) path of the candidate trajectory in the y=0 plane."""
        path = np.vstack([states[0, :], 0*times, states[1, :]])
        points = meshcat.geometry.PointsGeometry(path)
        material = meshcat.geometry.LineBasicMaterial(color=0x0000dd)
        vis["dircol"].set_object(meshcat.geometry.Line(points, material))

    dircol.AddStateTrajectoryCallback(plot_trajectory)

    # Initial guess: straight line from the initial to the desired final
    # state, with an empty input trajectory.
    state_guess = PiecewisePolynomial.FirstOrderHold(
        [0., 1.], np.column_stack((s0[:], sf_d[:])))
    dircol.SetInitialTrajectory(PiecewisePolynomial(), state_guess)

    result = Solve(dircol)
    if result.is_success():
        return

    infeasible = result.GetInfeasibleConstraints(dircol)
    print("Infeasible constraints:")
    for constraint in infeasible:
        print(constraint)
    
dircol_perching()