In [1]:

from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap

In [2]:
from utils.util_data import *


In [3]:
with open('models/final_humanoid_no_gravity.xml', 'r') as file:
    xml = file.read()


# Make model, data, and renderer
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=480, width=640)

In [4]:
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())

[0.  0.  0.9 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. ] <class 'numpy.ndarray'>
[0.  0.  0.9 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. ] <class 'jaxlib.xla_extension.ArrayImpl'> {CpuDevice(id=0)}


In [5]:
#get th kp and kd for the agent
kp,kd = generate_kp_kd_gains()
print(kp)
print(kd)

[1000 1000 1000  100  100  100  400  400  400  300  400  400  400  300
  500  500  500  500  400  400  400  500  500  500  500  400  400  400]
[100 100 100  10  10  10  40  40  40  30  40  40  40  30  50  50  50  50
  40  40  40  50  50  50  50  40  40  40]


In [6]:
t_init = 1
t_end = 3

In [7]:
from some_math.math_utils import *

In [8]:
#generate trajectories
trajec_dict = dict()


a_jnt_right =get_actuator_indx(mj_model,'right_shoulder','Y')
a_jnt_left =get_actuator_indx(mj_model,'left_shoulder','X')
a_jnt_left_elbow = get_actuator_indx(mj_model,'left_elbow','X')
a_jnt_right_elbow = get_actuator_indx(mj_model,'right_elbow','X')

right_knee = get_actuator_indx(mj_model,'right_knee','X')
left_knee = get_actuator_indx(mj_model,'left_knee','X')

trajec_dict[a_jnt_right] = generate_trajectory(t_init,t_end, 0, -1.5)
#trajectory left
trajec_dict[a_jnt_left] = generate_trajectory(t_init, t_end, 0, 1.5)
# #left elbow   
trajec_dict[a_jnt_left_elbow]= generate_trajectory(t_init, t_end, 0, 1.5)
# #right elbow
trajec_dict[a_jnt_right_elbow] = generate_trajectory(t_init, t_end, 0, 1.5)

trajec_dict[right_knee] = generate_trajectory(t_init, t_end, 0, 0)

trajec_dict[left_knee] = generate_trajectory(t_init, t_end, 0, 0)


start_trajec = start_trajectories(trajec_dict)

In [9]:
def visualize_mjx(model,mjData, pdControl, stable=False):
    scene_option = mujoco.MjvOption()
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True
    duration = 7.0  # (seconds)
    framerate = 60  # (Hz)
    delta_time = 1.0 / framerate
    jit_step = jax.jit(mjx.step)

    camera = mujoco.MjvCamera()
    mujoco.mjv_defaultFreeCamera(model, camera)
    camera.distance = 5
    camera.elevation = -15
    camera.azimuth = 0 
    camera.lookat = mjData.body('root').subtree_com

    frames = []
    mujoco.mj_resetData(model, mjData)
    mjx_data = mjx.put_data(model, mjData)
    while mjx_data.time < duration:
     
      
      time = mjx_data.time
      time = jp.clip(time, t_init, t_end)
      
      qpos = mjx_data.qpos
      qvel = mjx_data.qvel
      
      if not stable:
        
        target = compute_cubic_trajectory(time,start_trajec)
        
        torque= pdControl(target,model,mjData,mjx_data,
                           qpos,qvel,kp,kd,time,delta_time)
        
      else:
        print('stable')
        torque= pdControl(start_trajec,model,mjData,mjx_data,
                           qpos,qvel,kp,kd,time,delta_time)
        
        
  
      
      updated_ctrl = mjx_data.ctrl.at[:].set(torque)
      
      new_mjxData = mjx_data.replace(ctrl = updated_ctrl)
      
      mjx_data = jit_step(mjx_model, new_mjxData)
      #print(mjx_data.actuator)
      if len(frames) < mjx_data.time * framerate:
        mjData = mjx.get_data(model, mjx_data)
        renderer.update_scene(mjData, camera,scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)

    media.show_video(frames, fps=framerate)

Start with the standard pd

In [10]:
from agents_env.pds_controllers_mjx import standard_pd_controller_mjx

In [11]:
visualize_mjx(mj_model,mj_data,standard_pd_controller_mjx)

0
This browser does not support the video tag.


Now with feedback

In [12]:
from agents_env.pds_controllers_mjx import feedback_pd_controller_mjx

In [13]:
visualize_mjx(mj_model,mj_data,feedback_pd_controller_mjx)

0
This browser does not support the video tag.


Now with Stable

In [14]:
from agents_env.pds_controllers_mjx import stable_pd_controller_custom_trajectory_mjx

In [15]:
visualize_mjx(mj_model,mj_data,stable_pd_controller_custom_trajectory_mjx,True)

stable
M: [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
stable
M: [[ 4.5000000e+01  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
  -2.2499979e-02  0.0000000e+00]
 [ 0.0000000e+00  4.5000000e+01  0.0000000e+00 ...  2.2499979e-02
   0.0000000e+00  4.5000002e-02]
 [ 0.0000000e+00  0.0000000e+00  4.5000000e+01 ...  0.0000000e+00
  -4.5000002e-02  0.0000000e+00]
 ...
 [ 0.0000000e+00  2.2499979e-02  0.0000000e+00 ...  1.0014334e+00
   2.9148817e-10  1.0124983e-03]
 [-2.2499979e-02  0.0000000e+00 -4.5000002e-02 ...  2.9148817e-10
   1.0053940e+00  0.0000000e+00]
 [ 0.0000000e+00  4.5000002e-02  0.0000000e+00 ...  1.0124983e-03
   0.0000000e+00  1.0053108e+00]]
stable
M: [[ 4.5000000e+01  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
  -2.2499979e-02  0.0000000e+00]
 [ 0.0000000e+00  4.5000000e+01  0.0000000e+00 ...  2.2499979e-02
   0.0000000e+00  4.5000002e-02]
 [ 0.0000000e+00

0
This browser does not support the video tag.
