In [1]:
#!pip install -q git+https://github.com/A-H-Mansoury/brax.git@02d3bb6b758fd2f50f97c98d5386d115e168ec9f

In [2]:
!pip install -q mujoco wandb mujoco_mjx ml_collections brax
#!pip install -q --upgrade ipykernel

In [3]:
# #@title Check if MuJoCo installation was successful

# import distutils.util
# import os
# import subprocess
# if subprocess.run('nvidia-smi').returncode:
#   raise RuntimeError(
#       'Cannot communicate with GPU. '
#       'Make sure you are using a GPU Colab runtime. '
#       'Go to the Runtime menu and select Choose runtime type.')

# # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# # This is usually installed as part of an Nvidia driver package, but the Colab
# # kernel doesn't install its driver via APT, and as a result the ICD is missing.
# # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
# # NVIDIA_ICD_CONFIG_PATH = '../../usr/share/glvnd/egl_vendor.d/10_nvidia.json'
# # if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
# #   with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
# #     f.write("""{
# #     "file_format_version" : "1.0.0",
# #     "ICD" : {
# #         "library_path" : "libEGL_nvidia.so.0"
# #     }
# # }
# # """)

# # Configure MuJoCo to use the EGL rendering backend (requires GPU)
# print('Setting environment variable to use GPU rendering:')
# %env MUJOCO_GL=egl

# try:
#   print('Checking that the installation succeeded:')
#   import mujoco
#   mujoco.MjModel.from_xml_string('<mujoco/>')
# except Exception as e:
#   raise e from RuntimeError(
#       'Something went wrong during installation. Check the shell output above '
#       'for more information.\n'
#       'If using a hosted Colab runtime, make sure you enable GPU acceleration '
#       'by going to the Runtime menu and selecting "Choose runtime type".')

# print('Installation successful.')

# # Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
# xla_flags = os.environ.get('XLA_FLAGS', '')
# xla_flags += ' --xla_gpu_triton_gemm_any=True'
# os.environ['XLA_FLAGS'] = xla_flags

In [4]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

In [5]:
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

import wandb

In [6]:
# Log in to Weights & Biases without storing the secret in the notebook.
# The API key is taken from the WANDB_API_KEY environment variable; when it is
# unset, `key=None` is passed and wandb resolves credentials on its own
# (presumably via its stored login or an interactive prompt — confirm for the
# target runtime).
# SECURITY: a literal API key used to be committed on this line. Treat that key
# as leaked and revoke it in the W&B account settings.
api_token = os.environ.get('WANDB_API_KEY')
try:
    if wandb.login(key=api_token, relogin=False):
        print("Successfully logged in to Weights & Biases.")
    else:
        print("Login failed. Please check your API token.")
except Exception as e:
    # Deliberate best-effort: report the problem but keep the notebook running.
    print(f"An error occurred: {e}")

In [7]:
# Experiment configuration. It is handed to `wandb.init(config=...)` further
# down; the numeric entries are presumably also forwarded to the Brax PPO
# trainer (that call is not visible in this part of the notebook — confirm).
config = {
    # Free-text labels describing the experiment.
    "algorithm": "PPO",
    "Policy Network Architecture": "default",
    "Value Network Architecture": "default",
    # NOTE(review): this label does not match what `Humanoid._get_obs` in this
    # notebook concatenates — looks stale, confirm before relying on it.
    "Obs": "data.qpos[2:]+data.qvel[3:6]+data.qacc[0:2]",
    "reward": "multi-stage",

    # Training budget and evaluation cadence.
    "num_timesteps": 1_000_000_000,
    "num_evals": 500,

    # Rollout / optimisation hyperparameters.
    "reward_scaling": 1,
    "episode_length": 1000,
    "normalize_observations": True,
    "action_repeat": 1,
    "unroll_length": 10,
    "num_minibatches": 32,
    "num_updates_per_batch": 8,
    "discounting": 0.97,
    "learning_rate": 3e-4,
    "entropy_cost": 1e-3,
    "num_envs": 2048,
    "batch_size": 1024,
    "seed": 0,

    # Optional paths; all unset here.
    "xml_path": None,
    "restore_checkpoint_path": None,
    "create_checkpoint_path": None,
}

In [8]:
def sub_config(config, wanted_keys):
    """Return a new dict restricted to the requested keys.

    Args:
        config: Mapping to read values from.
        wanted_keys: Iterable of keys to keep. Keys missing from `config` are
            silently skipped.

    Returns:
        A new dict whose entries follow the order of `wanted_keys`.
    """
    return {key: config[key] for key in wanted_keys if key in config}

In [9]:
# MJCF (MuJoCo XML) description of a humanoid, kept as an in-notebook string.
# Inside the XML, the hip_x / hip_z, ankle_x and elbow joints, the shoulder1 and
# elbow motors, the hamstring tendons and the keyframes are commented out, so
# this is a reduced variant of the model named in the license header.
# NOTE(review): nothing visible in this part of the notebook reads
# `model_string`; `Humanoid.__init__` loads `humanoid.xml` from the installed
# `mujoco` package instead — confirm which model is meant to be used.
model_string ="""<!-- Copyright 2021 DeepMind Technologies Limited

     Licensed under the Apache License, Version 2.0 (the "License");
     you may not use this file except in compliance with the License.
     You may obtain a copy of the License at

         http://www.apache.org/licenses/LICENSE-2.0

     Unless required by applicable law or agreed to in writing, software
     distributed under the License is distributed on an "AS IS" BASIS,
     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     See the License for the specific language governing permissions and
     limitations under the License.
-->

<mujoco model="Humanoid">
  <option timestep="0.005" iterations="1" ls_iterations="4">
    <flag eulerdamp="disable"/>
  </option>

  <visual>
    <map force="0.1" zfar="30"/>
    <rgba haze="0.15 0.25 0.35 1"/>
    <global offwidth="2560" offheight="1440" elevation="-20" azimuth="120"/>
  </visual>

  <statistic center="0 0 0.7"/>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1=".3 .5 .7" rgb2="0 0 0" width="32" height="512"/>
    <texture name="body" type="cube" builtin="flat" mark="cross" width="128" height="128" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01"/>
    <material name="body" texture="body" texuniform="true" rgba="0.8 0.6 .4 1"/>
    <texture name="grid" type="2d" builtin="checker" width="512" height="512" rgb1=".1 .2 .3" rgb2=".2 .3 .4"/>
    <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
  </asset>

  <default>
    <motor ctrlrange="-1 1" ctrllimited="true"/>
    <default class="body">

      <!-- geoms -->
      <!-- TODO(robotics-simulation): support condim=1 for humanoid capsules. -->
      <geom type="capsule" condim="3" friction=".7" solimp=".9 .99 .003" solref=".015 1" material="body" contype="0" conaffinity="0"/>
      <default class="thigh">
        <geom size=".06"/>
      </default>
      <default class="shin">
        <geom fromto="0 0 0 0 0 -.3"  size=".049"/>
      </default>
      <default class="foot">
        <geom size=".027"/>
        <default class="foot1">
          <geom fromto="-.07 -.01 0 .14 -.03 0"/>
        </default>
        <default class="foot2">
          <geom fromto="-.07 .01 0 .14  .03 0"/>
        </default>
      </default>
      <default class="arm_upper">
        <geom size=".04"/>
      </default>
      <default class="arm_lower">
        <geom size=".031"/>
      </default>
      <default class="hand">
        <geom type="sphere" size=".04"/>
      </default>

      <!-- joints -->
      <joint type="hinge" damping=".2" stiffness="1" armature=".01" limited="true" solimplimit="0 .99 .01"/>
      <default class="joint_big">
        <joint damping="5" stiffness="10"/>
        <!--default class="hip_x">
          <joint range="-30 10"/>
        </default>
        <default class="hip_z">
          <joint range="-60 35"/>
        </default-->
        <default class="hip_y">
          <joint axis="0 1 0" range="-50 20"/>
        </default>
        <default class="joint_big_stiff">
          <joint stiffness="20"/>
        </default>
      </default>
      <default class="knee">
        <joint pos="0 0 .02" axis="0 -1 0" range="-90 2"/>
      </default>
      <default class="ankle">
        <joint range="-50 50"/>
        <default class="ankle_y">
          <joint pos="0 0 .08" axis="0 1 0" stiffness="6"/>
        </default>
        <!--default class="ankle_x">
          <joint pos="0 0 .04" stiffness="3"/>
        </default-->
      </default>
      <default class="shoulder">
        <joint range="-85 60"/>
      </default>
      <!--default class="elbow">
        <joint range="-100 50" stiffness="0"/>
      </default-->
    </default>
  </default>

  <worldbody>
    <geom name="floor" size="0 0 .05" type="plane" material="grid" condim="3"/>
    <light name="spotlight" mode="targetbodycom" target="torso" diffuse=".8 .8 .8" specular="0.3 0.3 0.3" pos="0 -6 4" cutoff="30"/>
    <body name="torso" pos="0 0 1.282" childclass="body">
      <light name="top" pos="0 0 2" mode="trackcom"/>
      <camera name="back" pos="-3 0 1" xyaxes="0 -1 0 1 0 2" mode="trackcom"/>
      <camera name="side" pos="0 -3 1" xyaxes="1 0 0 0 1 2" mode="trackcom"/>
      <freejoint name="root"/>
      <geom name="torso" fromto="0 -.07 0 0 .07 0" size=".07"/>
      <geom name="waist_upper" fromto="-.01 -.06 -.12 -.01 .06 -.12" size=".06"/>
      <body name="head" pos="0 0 .19">
        <geom name="head" type="sphere" size=".09"/>
        <camera name="egocentric" pos=".09 0 0" xyaxes="0 -1 0 .1 0 1" fovy="80"/>
      </body>
      <body name="waist_lower" pos="-.01 0 -.26">
        <geom name="waist_lower" fromto="0 -.06 0 0 .06 0" size=".06"/>
        <joint name="abdomen_z" pos="0 0 .065" axis="0 0 1" range="-45 45" class="joint_big_stiff"/>
        <joint name="abdomen_y" pos="0 0 .065" axis="0 1 0" range="-75 30" class="joint_big"/>
        <body name="pelvis" pos="0 0 -.165">
          <joint name="abdomen_x" pos="0 0 .1" axis="1 0 0" range="-35 35" class="joint_big"/>
          <geom name="butt" fromto="-.02 -.07 0 -.02 .07 0" size=".09"/>
          <body name="thigh_right" pos="0 -.1 -.04">
            <!--joint name="hip_x_right" axis="1 0 0" class="hip_x"/>
            <joint name="hip_z_right" axis="0 0 1" class="hip_z"/-->
            <joint name="hip_y_right" class="hip_y"/>
            <geom name="thigh_right" fromto="0 0 0 0 .01 -.34" class="thigh"/>
            <body name="shin_right" pos="0 .01 -.4">
              <joint name="knee_right" class="knee"/>
              <geom name="shin_right" class="shin"/>
              <body name="foot_right" pos="0 0 -.39">
                <joint name="ankle_y_right" class="ankle_y"/>
                <!--joint name="ankle_x_right" class="ankle_x" axis="1 0 .5"/-->
                <geom name="foot1_right" class="foot1"/>
                <geom name="foot2_right" class="foot2"/>
              </body>
            </body>
          </body>
          <body name="thigh_left" pos="0 .1 -.04">
            <!--joint name="hip_x_left" axis="-1 0 0" class="hip_x"/>
            <joint name="hip_z_left" axis="0 0 -1" class="hip_z"/-->
            <joint name="hip_y_left" class="hip_y"/>
            <geom name="thigh_left" fromto="0 0 0 0 -.01 -.34" class="thigh"/>
            <body name="shin_left" pos="0 -.01 -.4">
              <joint name="knee_left" class="knee"/>
              <geom name="shin_left" fromto="0 0 0 0 0 -.3" class="shin"/>
              <body name="foot_left" pos="0 0 -.39">
                <joint name="ankle_y_left" class="ankle_y"/>
                <!--joint name="ankle_x_left" class="ankle_x" axis="-1 0 -.5"/-->
                <geom name="foot1_left" class="foot1"/>
                <geom name="foot2_left" class="foot2"/>
              </body>
            </body>
          </body>
        </body>
      </body>
      <body name="upper_arm_right" pos="0 -.17 .06">
        <joint name="shoulder1_right" axis="2 1 1"  class="shoulder"/>
        <joint name="shoulder2_right" axis="0 -1 1" class="shoulder"/>
        <geom name="upper_arm_right" fromto="0 0 0 .16 -.16 -.16" class="arm_upper"/>
        <body name="lower_arm_right" pos=".18 -.18 -.18">
          <!--joint name="elbow_right" axis="0 -1 1" class="elbow"/-->
          <geom name="lower_arm_right" fromto=".01 -.01 -.01 .17 -.17 -.17" class="arm_lower"/>
          <body name="hand_right" pos=".18 -.18 -.18">
            <geom name="hand_right" zaxis="1 1 1" class="hand"/>
          </body>
        </body>
      </body>
      <body name="upper_arm_left" pos="0 .17 .06">
        <joint name="shoulder1_left" axis="-2 1 -1" class="shoulder"/>
        <joint name="shoulder2_left" axis="0 -1 -1"  class="shoulder"/>
        <geom name="upper_arm_left" fromto="0 0 0 .16 .16 -.16" class="arm_upper"/>
        <body name="lower_arm_left" pos=".18 .18 -.18">
          <!--joint name="elbow_left" axis="0 -1 -1" class="elbow"/-->
          <geom name="lower_arm_left" fromto=".01 .01 -.01 .17 .17 -.17" class="arm_lower"/>
          <body name="hand_left" pos=".18 .18 -.18">
            <geom name="hand_left" zaxis="1 -1 1" class="hand"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>

  <contact>
    <exclude body1="waist_lower" body2="thigh_right"/>
    <exclude body1="waist_lower" body2="thigh_left"/>
    <pair geom1="foot1_left" geom2="floor"/>
    <pair geom1="foot1_right" geom2="floor"/>
    <pair geom1="foot2_left" geom2="floor"/>
    <pair geom1="foot2_right" geom2="floor"/>
  </contact>

  <!-- <tendon>
    <fixed name="hamstring_right" limited="true" range="-0.3 2">
      <joint joint="hip_y_right" coef=".5"/>
      <joint joint="knee_right" coef="-.5"/>
    </fixed>
    <fixed name="hamstring_left" limited="true" range="-0.3 2">
      <joint joint="hip_y_left" coef=".5"/>
      <joint joint="knee_left" coef="-.5"/>
    </fixed>
  </tendon> -->

  <actuator>
    <motor name="abdomen_y"       gear="40"  joint="abdomen_y"/>
    <motor name="abdomen_z"       gear="40"  joint="abdomen_z"/>
    <motor name="abdomen_x"       gear="40"  joint="abdomen_x"/>
    <!--motor name="hip_x_right"     gear="40"  joint="hip_x_right"/>
    <motor name="hip_z_right"     gear="40"  joint="hip_z_right"/-->
    <motor name="hip_y_right"     gear="50" joint="hip_y_right"/>
    <motor name="knee_right"      gear="80"  joint="knee_right"/>
    <!--motor name="ankle_x_right"   gear="20"  joint="ankle_x_right"/-->
    <motor name="ankle_y_right"   gear="20"  joint="ankle_y_right"/>
    <!--motor name="hip_x_left"      gear="40"  joint="hip_x_left"/>
    <motor name="hip_z_left"      gear="40"  joint="hip_z_left"/-->
    <motor name="hip_y_left"      gear="50" joint="hip_y_left"/>
    <motor name="knee_left"       gear="80"  joint="knee_left"/>
    <!--motor name="ankle_x_left"    gear="20"  joint="ankle_x_left"/-->
    <motor name="ankle_y_left"    gear="20"  joint="ankle_y_left"/>
    <!--motor name="shoulder1_right" gear="20"  joint="shoulder1_right"/> -->
    <motor name="shoulder2_right" gear="20"  joint="shoulder2_right"/> 
    <!--<motor name="elbow_right"     gear="40"  joint="elbow_right"/>
    <motor name="shoulder1_left"  gear="20"  joint="shoulder1_left"/> -->
    <motor name="shoulder2_left"  gear="20"  joint="shoulder2_left"/>
    <!--<motor name="elbow_left"      gear="40"  joint="elbow_left"/> -->
  </actuator>

  <!--keyframe>
    <!--
    The values below are split into rows for readibility:
      torso position
      torso orientation
      spinal
      right leg
      left leg
      arms
    								->
    <key name="squat" qpos="0 0 0.596
                            0.988015 0 0.154359 0
                            0 0 0"/>
    <key name="stand_on_left_leg" qpos="0 0 1.21948
                                        0.971588 -0.179973 0.135318 -0.0729076
                                        -0.0516 -0.202 0.23
                                        -0.24 -0.007 -0.34 -1.76 -0.466 -0.0415
                                        -0.08 -0.01 -0.37 -0.685 -0.35 -0.09
                                        0.109 -0.067 -0.7 -0.05 0.12 0.16"/>
    <key name="home" qpos='0 0 1.282 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0'/>
  </keyframe-->
</mujoco>"""

In [10]:
# The value returned by tolerance() at `margin` distance from `bounds` interval.
_DEFAULT_VALUE_AT_MARGIN = 0.1
import warnings

def _sigmoids(x, value_at_1, sigmoid):
  """Returns 1 when `x` == 0, between 0 and 1 otherwise.

  Args:
    x: A scalar or numpy array.
    value_at_1: A float between 0 and 1 specifying the output when `x` == 1.
    sigmoid: String, choice of sigmoid type.

  Returns:
    A numpy array with values between 0.0 and 1.0.

  Raises:
    ValueError: If not 0 < `value_at_1` < 1, except for `linear`, `cosine` and
      `quadratic` sigmoids which allow `value_at_1` == 0.
    ValueError: If `sigmoid` is of an unknown type.
  """
  if sigmoid in ('cosine', 'linear', 'quadratic'):
    if not 0 <= value_at_1 < 1:
      raise ValueError('`value_at_1` must be nonnegative and smaller than 1, '
                       'got {}.'.format(value_at_1))
  else:
    if not 0 < value_at_1 < 1:
      raise ValueError('`value_at_1` must be strictly between 0 and 1, '
                       'got {}.'.format(value_at_1))

  if sigmoid == 'gaussian':
    scale = jp.sqrt(-2 * jp.log(value_at_1))
    return jp.exp(-0.5 * (x*scale)**2)

  elif sigmoid == 'hyperbolic':
    scale = jp.arccosh(1/value_at_1)
    return 1 / jp.cosh(x*scale)

  elif sigmoid == 'long_tail':
    scale = jp.sqrt(1/value_at_1 - 1)
    return 1 / ((x*scale)**2 + 1)

  elif sigmoid == 'reciprocal':
    scale = 1/value_at_1 - 1
    return 1 / (abs(x)*scale + 1)

  elif sigmoid == 'cosine':
    scale = jp.arccos(2*value_at_1 - 1) / jp.pi
    scaled_x = x*scale
    with warnings.catch_warnings():
      warnings.filterwarnings(
          action='ignore', message='invalid value encountered in cos')
      cos_pi_scaled_x = jp.cos(jp.pi*scaled_x)
    return jp.where(abs(scaled_x) < 1, (1 + cos_pi_scaled_x)/2, 0.0)

  elif sigmoid == 'linear':
    scale = 1-value_at_1
    scaled_x = x*scale
    return jp.where(abs(scaled_x) < 1, 1 - scaled_x, 0.0)

  elif sigmoid == 'quadratic':
    scale = jp.sqrt(1-value_at_1)
    scaled_x = x*scale
    return jp.where(abs(scaled_x) < 1, 1 - scaled_x**2, 0.0)

  elif sigmoid == 'tanh_squared':
    scale = jp.arctanh(jp.sqrt(1-value_at_1))
    return 1 - jp.tanh(x*scale)**2

  else:
    raise ValueError('Unknown sigmoid type {!r}.'.format(sigmoid))


def tolerance(x, bounds=(0.0, 0.0), margin=0.0, sigmoid='gaussian',
              value_at_margin=_DEFAULT_VALUE_AT_MARGIN):
  """Returns 1 when `x` falls inside the bounds, between 0 and 1 otherwise.

  Args:
    x: A Python/NumPy scalar or an array (including JAX arrays and tracers).
    bounds: A tuple of floats specifying inclusive `(lower, upper)` bounds for
      the target interval. These can be infinite if the interval is unbounded
      at one or both ends, or they can be equal to one another if the target
      value is exact.
    margin: Float. Parameter that controls how steeply the output decreases as
      `x` moves out-of-bounds.
      * If `margin == 0` then the output will be 0 for all values of `x`
        outside of `bounds`.
      * If `margin > 0` then the output will decrease sigmoidally with
        increasing distance from the nearest bound.
    sigmoid: String, choice of sigmoid type. Valid values are: 'gaussian',
       'hyperbolic', 'long_tail', 'reciprocal', 'cosine', 'linear',
       'quadratic', 'tanh_squared'.
    value_at_margin: A float between 0 and 1 specifying the output value when
      the distance from `x` to the nearest bound is equal to `margin`. Ignored
      if `margin == 0`.

  Returns:
    A Python float when `x` is a Python or NumPy scalar; otherwise an array
    with the shape of `x` (zero-dimensional arrays stay arrays).

  Raises:
    ValueError: If `bounds[0] > bounds[1]`.
    ValueError: If `margin` is negative.
  """
  lower, upper = bounds
  if lower > upper:
    raise ValueError('Lower bound must be <= upper bound.')
  if margin < 0:
    raise ValueError('`margin` must be non-negative.')

  in_bounds = jp.logical_and(lower <= x, x <= upper)
  if margin == 0:
    value = jp.where(in_bounds, 1.0, 0.0)
  else:
    # Distance to the nearest bound in units of `margin`. The result is only
    # used for out-of-bounds entries; in-bounds entries are masked to 1 below.
    d = jp.where(x < lower, lower - x, x - upper) / margin
    value = jp.where(in_bounds, 1.0, _sigmoids(d, value_at_margin, sigmoid))

  # Convert back to `float` only for genuine Python/NumPy scalar inputs.
  # `jp.isscalar` also reports True for zero-dimensional JAX arrays, which
  # includes tracers inside `jax.jit`; calling `float()` on a tracer raises, so
  # the reward code could not be jitted. `np.isscalar` is False for both.
  return float(value) if np.isscalar(x) else value

In [11]:
#@title Humanoid Env
class Humanoid(PipelineEnv):
  """Brax `PipelineEnv` (MJX backend) for a humanoid with a stand-and-move reward.

  The model is the `humanoid.xml` found under `mjx/test_data/humanoid` in the
  installed `mujoco` package. Observations and rewards read `geom_xpos` /
  `geom_xmat` through hard-coded geom indices.

  NOTE(review): indices 1 and 3 are presumably the torso and head geoms, and
  9, 16, 13, 19 presumably feet and hands — confirm against the loaded model.
  """

  def __init__(
      self,
      reset_noise_scale=1e-2,
      **kwargs,
  ):
    """Loads the MuJoCo model and configures the MJX pipeline.

    Args:
      reset_noise_scale: Half-width of the uniform noise that `reset` applies
        to the initial `qpos` and `qvel`.
      **kwargs: Forwarded to `PipelineEnv.__init__`. `n_frames` defaults to 5
        when not supplied; `backend` is always overwritten with 'mjx'.
    """
    model_dir = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'
    xml_file = (model_dir / 'humanoid.xml').as_posix()
    mj_model = mujoco.MjModel.from_xml_path(xml_file)

    # Solver settings: conjugate-gradient solver, 6 solver and 6 line-search
    # iterations.
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    self.mj_model = mj_model

    sys = mjcf.load_model(mj_model)

    # Physics steps per control step, unless the caller passed `n_frames`.
    kwargs.setdefault('n_frames', 5)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)
    self._reset_noise_scale = reset_noise_scale

  def reset(self, rng: jp.ndarray) -> State:
    """Returns a fresh `State` with uniformly perturbed `qpos` and `qvel`.

    Reward, done flag and all metrics start at zero; the initial observation
    is built with a zero centre-of-mass velocity.
    """
    _, key_qpos, key_qvel = jax.random.split(rng, 3)

    noise = self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        key_qpos, (self.sys.nq,), minval=-noise, maxval=noise
    )
    qvel = jax.random.uniform(
        key_qvel, (self.sys.nv,), minval=-noise, maxval=noise
    )

    data = self.pipeline_init(qpos, qvel)

    reward, done, zero = jp.zeros(3)

    metric_names = (
        'com_position_x',
        'com_position_y',
        'com_position_z',
        'com_velocity_x',
        'com_velocity_y',
        'com_velocity_z',
    )
    metrics = {name: zero for name in metric_names}

    obs = self._get_obs(data, jp.zeros(3))

    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Advances the simulation by one control step.

    The centre-of-mass velocity is a finite difference of `subtree_com[1]`
    over `self.dt`. The episode is flagged done once `q[2]` drops below 1.0
    (presumably the root height — confirm).
    """
    prev_data = state.pipeline_state
    data = self.pipeline_step(prev_data, action)

    com_now = data.subtree_com[1]
    com_velocity = (com_now - prev_data.subtree_com[1]) / self.dt

    done = jp.where(data.q[2] < 1.0, 1.0, 0.0)

    # The metrics dict of the incoming state is updated in place.
    state.metrics.update(
        com_position_x=com_now[0],
        com_position_y=com_now[1],
        com_position_z=com_now[2],
        com_velocity_x=com_velocity[0],
        com_velocity_y=com_velocity[1],
        com_velocity_z=com_velocity[2],
    )

    obs = self._get_obs(data, com_velocity)
    reward = self._get_reward(data, com_velocity)
    return state.replace(pipeline_state=data, obs=obs, reward=reward, done=done)

  def _get_reward(
        self, data: mjx.Data, com_velocity: jp.ndarray
    ):
    """Product of a standing term, a small-control term and a forward-motion term."""
    unbounded = float('inf')

    # Height term: z-coordinate of geom 3 should be at least 1.4.
    standing = tolerance(
        data.geom_xpos[3, 2],
        bounds=(1.4, unbounded),
        margin=1.4/4,
    )
    # Orientation term: the [2, 2] entry of geom 1's rotation should be >= 0.9.
    upright = tolerance(
        data.geom_xmat[1, 2, 2],
        bounds=(0.9, unbounded),
        sigmoid='linear',
        margin=1.9,
        value_at_margin=0,
    )
    stand_reward = standing*upright

    # Control term: mean over actuators, then squashed into [0.8, 1].
    control_score = tolerance(
        data.ctrl,
        margin=1,
        value_at_margin=0,
        sigmoid='quadratic',
    ).mean()
    small_control = (4 + control_score) / 5

    # Motion term: x-velocity of the centre of mass should reach 1; squashed
    # into [1/6, 1].
    speed_score = tolerance(
        com_velocity[0],
        bounds=(1, unbounded),
        margin=1,
        value_at_margin=0,
        sigmoid='linear',
    )
    move = (5*speed_score + 1) / 6

    return small_control * stand_reward * move

  def _get_obs(
        self, data: mjx.Data, com_velocity: jp.ndarray
    ) -> jp.ndarray:
    """Builds the flat observation vector.

    Concatenates, in order: `qpos[7:]`, the z-coordinate of geom 3, the
    offsets of geoms 9, 16, 13 and 19 from geom 1 multiplied by geom 1's
    rotation matrix, column 2 of that matrix, `com_velocity`, and `qvel[7:]`.

    NOTE(review): the model's root is a free joint (7 qpos / 6 qvel entries),
    so `qvel[7:]` presumably also drops the first hinge velocity — confirm
    whether `qvel[6:]` was intended before changing the observation size.
    """
    torso_pos = data.geom_xpos[1]
    torso_mat = data.geom_xmat[1]
    limb_offsets = [
        (data.geom_xpos[geom_id] - torso_pos).dot(torso_mat)
        for geom_id in (9, 16, 13, 19)
    ]
    positions = jp.hstack(limb_offsets)

    height = jp.expand_dims(data.geom_xpos[3, 2], axis=-1)
    return jp.concatenate([
        data.qpos[7:],
        height,
        positions,
        data.geom_xmat[1, :, 2],
        com_velocity,
        data.qvel[7:],
    ])


envs.register_environment('humanoid', Humanoid)

In [12]:

# Instantiate the environment that was registered under the name 'humanoid'.
env_name = 'humanoid'
# No constructor overrides are passed; the commented-out alternative is
# sub_config(config, ['reset_noise_scale']).
env_config = dict()
env = envs.get_environment(env_name, **env_config)

In [13]:
# Start a Weights & Biases run in project "tmp", logging `config` with it.
wandb.init(project="tmp", config=config)

In [14]:
import zipfile
import os
def zip_directory(folder_path, zip_path):
    """Recursively compress a directory into a ZIP archive.

    Entries are stored relative to the parent of `folder_path`, so the archive
    contains the folder itself as its top-level directory.

    Args:
        folder_path: Directory whose files are added to the archive.
        zip_path: Destination of the archive. It may live inside
            `folder_path`; the archive is then excluded from its own contents.

    Returns:
        None. A status message is printed; when `folder_path` is not an
        existing directory nothing is written.
    """
    # `isdir` rather than `exists`: a regular file would otherwise pass the
    # check, be walked as an empty tree and yield an empty archive.
    if not os.path.isdir(folder_path):
        print(f"The directory {folder_path} does not exist.")
        return

    archive_root = os.path.join(folder_path, '..')
    zip_abspath = os.path.abspath(zip_path)

    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(folder_path):
            for file in files:
                file_path = os.path.join(root, file)
                # Never add the archive that is currently being written.
                if os.path.abspath(file_path) == zip_abspath:
                    continue
                zipf.write(file_path, os.path.relpath(file_path, archive_root))
    print(f"Successfully created zip file at {zip_path}")
    
    
    