In [None]:
import matplotlib.pyplot as plt

import seaborn as sns

import numpy as np
from IPython.display import HTML, display, SVG

import pydot
from pydrake.all import (AddMultibodyPlantSceneGraph, DiagramBuilder, BasicVector,
                        Parser, Saturation, Simulator, PlanarSceneGraphVisualizer, 
                        LinearQuadraticRegulator, AbstractValue, MeshcatVisualizerCpp, StartMeshcat)    

from pydrake.systems.primitives import LogVectorOutput, ConstantVectorSource

from pydrake.systems.framework import LeafSystem
from pydrake.multibody import inverse_kinematics
from pydrake.multibody import plant as plnt
from pydrake.multibody import math as m

from pydrake.multibody.plant import ContactResults
from underactuated.meshcat_cpp_utils import MeshcatSliders
from Custom_LeafSystems import *

from tqdm import tqdm as tqdm

from Linearize import getGradients


In [None]:
# Start the Meshcat visualizer server. Run this cell only once per kernel: each
# StartMeshcat() call consumes a new port. The handle is reused by later cells
# for the visualizer, sliders and recording.
meshcat_instance = StartMeshcat()

In [None]:
# LQR gains about an equilibrium (state + torques) that was obtained from a PD controller.

def getTorquesHoppingLeg(path="HoppingLeg/leg_v2/urdf/LEG_002_No_Coll.urdf",
                         eq_state=None, eq_input=None, Q=None, R=None):
    """Linearize the hopping leg about an equilibrium and compute LQR gains.

    Despite its name, this returns the LQR gain (not torques) together with the
    linearized system matrices. The default equilibrium state and input are the
    values obtained from a PD controller.

    Args:
        path: URDF used to build the plant that is linearized.
        eq_state: equilibrium state (6 entries). Defaults to the PD-derived state.
        eq_input: equilibrium actuation (2 entries). Defaults to the PD-derived torques.
        Q: 6x6 state cost. Defaults to identity with weight 10 on entries 1 and 2.
        R: 2x2 input cost. Defaults to diag(10, 1).

    Returns:
        (K, A, B): the LQR gain and the linearized A and B matrices.
    """
    if eq_state is None:
        eq_state = np.array([3.1998e-1, 2.37277e0, -1.047198e0, 1.1931e-5, -4.51202e-4, -4.54405e-4])
    if eq_input is None:
        eq_input = np.array([0.41216, -0.76074367])
    if Q is None:
        Q = np.diag([1., 10., 10., 1., 1., 1.])
    if R is None:
        R = np.diag([10., 1.])

    sys_autodiff = getGradients(eq_state, eq_input, path)
    sys_autodiff.makePlantFromURDF()
    sys_autodiff.getGradients()

    # The Riccati solution (second return value) is not needed here.
    K, _ = LinearQuadraticRegulator(sys_autodiff.A, sys_autodiff.B, Q, R)
    return K, sys_autodiff.A, sys_autodiff.B


lqr_gain, A, B = getTorquesHoppingLeg()
# Closed-loop eigenvalues of A - B K; negative real parts mean the linearized loop is stable.
print(np.linalg.eigvals(A - B @ lqr_gain))

In [None]:
builder = DiagramBuilder()
plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)

# TODO make simple urdf file to do inverse dynamics and foot hinge
object_instance = Parser(plant).AddModelFromFile("HoppingLeg/leg_v2/urdf/LEG_002_Minimal_Coll.urdf")
# The contact model can only be changed before Finalize(). A bare
# `plnt.ContactModel(1)` expression only constructs the enum value and has no effect.
# NOTE(review): ContactModel(1) is presumably point contact; confirm for the installed Drake version.
plant.set_contact_model(plnt.ContactModel(1))
plant.Finalize()

meshcat_instance.Delete()
vis = MeshcatVisualizerCpp.AddToBuilder(builder, scene_graph, meshcat_instance)

# Stand-alone plant context (separate from the diagram's context), used only for
# the static gravity-force check below.
context = plant.CreateDefaultContext()
plant.get_actuation_input_port().FixValue(context, [0., 0.])

plant.SetPositions(context, np.array([0.32, 3*np.pi/4, -np.pi/3]))
plant.SetVelocities(context, np.zeros(plant.num_velocities()))

tau_g = plant.CalcGravityGeneralizedForces(context)
print(tau_g)

# Inverse dynamics calculations
# TODO apply proper force at correct location

# mass = plant.CalcTotalMass(context)
# reaction_force = mass * plant.mutable_gravity_field().gravity_vector()

# force_object = plnt.ExternallyAppliedSpatialForce()
# force_object.body_index = plant.GetBodyByName("Link_foot").index()
# force_object.F_Bq_W = m.SpatialForce(tau=np.zeros(3), f=-plant.CalcGravityGeneralizedForces(context))
# force_object.p_BoBq_B = np.zeros(3)

# sp_force = AbstractValue.Make([force_object])
# test_sp = builder.AddSystem(ConstantValueSource(sp_force))
# builder.Connect(test_sp.get_output_port(), plant.get_applied_spatial_force_input_port())

# multi_force = MultibodyForces(plant)
# set_multi = multi_force.mutable_generalized_forces()
# set_multi = plant.CalcGravityGeneralizedForces(context)

# print(plant.CalcGravityGeneralizedForces(context))

# grav_force = m.SpatialForce(tau=np.zeros(3), f=plant.CalcGravityGeneralizedForces(context))

# plant.GetBodyByName("Link_foot").AddInForce(context, np.zeros(3), grav_force, plant.GetFrameByName("Link_foot"),
#                     multi_force)

# print(multi_force.generalized_forces())

# tau = plant.CalcInverseDynamics(context, np.zeros(plant.num_velocities()), multi_force)
# print(tau)

In [None]:
# TODO Visualize disturbance in meshcat

# --- Systems ---------------------------------------------------------------
converter = builder.AddSystem(ReadContactResults(plant))
controller = builder.AddSystem(Balancing(lqr_gain))
saturation = builder.AddSystem(Saturation(min_value=[-6, -6], max_value=[6, 6]))
switcher = builder.AddSystem(SwitchController())

# Constant reference fed to the controller's input 1 (8 entries, the last two zero).
state_ref = np.array([3.1998e-1, 2.37277e0, -1.047198e0, 1.1931e-5, -4.51202e-4, -4.54405e-4, 0, 0])
ref_vector = builder.AddSystem(ConstantVectorSource(state_ref))

# Slider to toggle between balancing and hopping
# Currently slider does nothing
meshcat_instance.AddSlider('H', min=0, max=1, step=1, value=0.0)
selection = builder.AddSystem(MeshcatSliders(meshcat_instance, ['H']))

disturbance_force = builder.AddSystem(Disturbance(plant))
# vis_force = builder.AddSystem(VisualizeForce(plant, meshcat_instance))

# --- Wiring: (source port, destination port), connected in this order -----
wiring = [
    (plant.get_contact_results_output_port(), converter.get_input_port()),
    (plant.get_state_output_port(), controller.get_input_port(0)),
    (ref_vector.get_output_port(), controller.get_input_port(1)),
    (converter.get_output_port(1), controller.get_input_port(2)),
    (controller.get_output_port(0), switcher.get_input_port(0)),
    (controller.get_output_port(1), switcher.get_input_port(1)),
    (selection.get_output_port(), switcher.get_input_port(2)),
    (switcher.get_output_port(), saturation.get_input_port(0)),
    (saturation.get_output_port(), plant.get_actuation_input_port()),
    (converter.get_output_port(1), disturbance_force.get_input_port()),
    (disturbance_force.get_output_port(0), plant.get_applied_spatial_force_input_port()),
]
for src_port, dst_port in wiring:
    builder.Connect(src_port, dst_port)
# builder.Connect(disturbance_force.get_output_port(0), vis_force.get_input_port())

# --- Loggers (all sampled at the same fixed period) -------------------------
log_period = 1/1000
logger_contact = LogVectorOutput(converter.get_output_port(0), builder, publish_period=log_period)
logger_iscontact = LogVectorOutput(converter.get_output_port(1), builder, publish_period=log_period)
logger_torque = LogVectorOutput(saturation.get_output_port(), builder, publish_period=log_period)

diagram = builder.Build()
dot_graph = pydot.graph_from_dot_data(diagram.GetGraphvizString(max_depth=2))[0]
display(SVG(dot_graph.create_svg()))

# context_diagram = diagram.CreateDefaultContext()
# plant_context_from_diagram = diagram.GetSubsystemContext(plant, context_diagram)

In [6]:

# Set up a simulator to run this diagram

show = True           # True: display figures inline; False: save them as PDFs under Figures/
INIT_SEED = 0         # seed for the initial-state perturbation, so runs are reproducible
INIT_NOISE_STD = 0.1  # std. dev. of the perturbation added to the first three state entries
SIM_DURATION = 2      # end time passed to simulator.AdvanceTo

simulator = Simulator(diagram)
sim_context = simulator.get_mutable_context()
sim_context.SetTime(0.)

# Nominal initial state. Only the first three entries are perturbed; the last
# three start at zero.
rng = np.random.default_rng(INIT_SEED)
state_det = np.array([0.65, 3*np.pi/4, -np.pi/3, 0, 0, 0])
state_rand = np.concatenate((INIT_NOISE_STD * rng.standard_normal(3), np.zeros(3)))
x0 = state_det + state_rand

sim_context.SetContinuousState(x0)

vis.StartRecording()
simulator.AdvanceTo(SIM_DURATION)
vis.StopRecording()
vis.PublishRecording()

# NOTE(review): this removes the 'H' slider that the MeshcatSliders system in the
# diagram was built on. Re-running this cell presumably requires re-running the
# diagram-building cell first.
meshcat_instance.DeleteAddedControls()

contact_log = logger_contact.FindLog(sim_context)
torque_log = logger_torque.FindLog(sim_context)
contact = contact_log.data()
isTouching = logger_iscontact.FindLog(sim_context).data()
torques = torque_log.data()

plt.style.use("ggplot")
plt.rcParams["font.family"] = "sans-serif"
sns.set_palette("dark")
sns.set_context("talk")


def _plot_logged_rows(times, rows, labels, title, ylabel, filename):
    """Plot one line per (row, label) pair against `times`, then show or save it.

    When the notebook-level flag `show` is False, the figure is written to
    `filename` and closed instead of being displayed.
    """
    fig, ax = plt.subplots()
    for row, label in zip(rows, labels):
        ax.plot(times, row, label=label)
    ax.legend()
    ax.set_title(title)
    ax.set_xlabel("Time [s]")
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    if show:
        plt.show()
    else:
        fig.savefig(filename)
        plt.close(fig)


t = contact_log.sample_times()
_plot_logged_rows(t, contact[:3], ["Axis 1", "Axis 2", "Axis 3"],
                  "Hopping Leg Contact Forces", "Contact Force [N]",
                  "Figures/Hopping_ContactF.pdf")

# Motor torques come from the saturation-output logger, on that log's own time axis.
_plot_logged_rows(torque_log.sample_times(), torques[:2], ["Hip", "Knee"],
                  "Hopping Leg Motor Torques", "Torque [Nm]",
                  "Figures/Hopping_torque.pdf")
