# Install MuJoCo, MJX, and Brax

In [None]:
!pip install git+https://github.com/A-H-Mansoury/brax.git

In [2]:
!pip install -q mujoco wandb mujoco_mjx ml_collections
#!pip install -q --upgrade ipykernel

In [3]:
# #@title Check if MuJoCo installation was successful

# import distutils.util
# import os
# import subprocess
# if subprocess.run('nvidia-smi').returncode:
#   raise RuntimeError(
#       'Cannot communicate with GPU. '
#       'Make sure you are using a GPU Colab runtime. '
#       'Go to the Runtime menu and select Choose runtime type.')

# # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# # This is usually installed as part of an Nvidia driver package, but the Colab
# # kernel doesn't install its driver via APT, and as a result the ICD is missing.
# # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
# # NVIDIA_ICD_CONFIG_PATH = '../../usr/share/glvnd/egl_vendor.d/10_nvidia.json'
# # if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
# #   with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
# #     f.write("""{
# #     "file_format_version" : "1.0.0",
# #     "ICD" : {
# #         "library_path" : "libEGL_nvidia.so.0"
# #     }
# # }
# # """)

# # Configure MuJoCo to use the EGL rendering backend (requires GPU)
# print('Setting environment variable to use GPU rendering:')
# %env MUJOCO_GL=egl

# try:
#   print('Checking that the installation succeeded:')
#   import mujoco
#   mujoco.MjModel.from_xml_string('<mujoco/>')
# except Exception as e:
#   raise e from RuntimeError(
#       'Something went wrong during installation. Check the shell output above '
#       'for more information.\n'
#       'If using a hosted Colab runtime, make sure you enable GPU acceleration '
#       'by going to the Runtime menu and selecting "Choose runtime type".')

# print('Installation successful.')

# # Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
# xla_flags = os.environ.get('XLA_FLAGS', '')
# xla_flags += ' --xla_gpu_triton_gemm_any=True'
# os.environ['XLA_FLAGS'] = xla_flags


In [4]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

Installing mediapy:


In [5]:
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

import wandb

#login to wandb

In [6]:
api_token = '50c8814867c630e4f1649261257a9e728b7cebb6'
try:
    if wandb.login(key=api_token, relogin=False):
        print("Successfully logged in to Weights & Biases.")
    else:
        print("Login failed. Please check your API token.")
except Exception as e:
    print(f"An error occurred: {e}")

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Successfully logged in to Weights & Biases.


# Experiment Config

In [7]:
config={
    "algorithm": "PPO",
    "Policy Network Architecture": "default",
    "Value Network Architecture": "default",
    "Obs":"data.qpos[2:]+data.qvel[3:6]+data.qacc[0:2]",
    "reward":"multi-stage",
    "num_timesteps":100_000_000,
    "num_evals":10,
    "reward_scaling":1,
    "episode_length":1000,
    "normalize_observations": True,
    "action_repeat": 1,
    "unroll_length": 10,
    "num_minibatches": 32,
    "num_updates_per_batch": 8,
    "discounting": 0.97,
    "learning_rate": 3e-4,
    "entropy_cost": 1e-3,
    "num_envs": 2048,
    "batch_size": 1024,
    "seed": 0,
    "xml_path": None,
    "restore_checkpoint_path": None,
    "create_checkpoint_path": None
}

# Utilities

In [8]:
# a utility function to create a subset from a given config
sub_config = lambda config, wanted_keys: dict((k, config[k]) for k in wanted_keys if k in config)

# Train

In [None]:
#@title Humanoid Env
class Humanoid(PipelineEnv):

  def __init__(
      self,
      reset_noise_scale=1e-2,
      **kwargs,
  ):
    path = epath.Path(epath.resource_path('mujoco')) / (
        'mjx/test_data/humanoid'
    )
    mj_model = mujoco.MjModel.from_xml_path(
        (path / 'humanoid.xml').as_posix())
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6
    
    self.mj_model = mj_model

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)
    self._reset_noise_scale = reset_noise_scale


  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    #qpos = qpos.at[18].set(0.01)
    qpos = qpos.at[12].set(-0.1)
#     qvel = qvel.at[18].set(-1)
#     qvel = qvel.at[12].set(1)

    qpos = qpos.at[jp.array([22, 25])].set(0.76)
    qpos = qpos.at[jp.array([23, 26])].set(-0.46)
    qpos = qpos.at[jp.array([24, 27])].set(-1.75)
    
    
    
    data = self.pipeline_init(qpos, qvel)

    reward, done, zero = jp.zeros(3)
      
    metrics = {
        "orient_reward": zero,
        "foot_contact_reward": zero,
        "base_height_reward": zero,
        "feet_airtime_reward": zero,
        "arm_reward": zero,
        "base_acceleration_reward": zero,
        "action_difference_reward": zero,
        "torque_reward": zero,
        "left_foot_placment_reward": zero,
        'right_foot_placment_reward': zero,
        'single_support_time': zero,
        'right_foot_airtime': zero,
        'left_foot_airtime': zero,
        'previous_action':jp.zeros(self.sys.nu),
        'Crightfoot': zero,
        'Cleftfoot': zero-0.21,
        'turn': zero<12,
    }
    
    info = {
      'single_support_time': zero,
      'right_foot_airtime': zero,
      'left_foot_airtime': zero,
      'previous_action':jp.zeros(self.sys.nu),
      'Crightfoot': zero,
      'Cleftfoot': zero-0.21,
      'turn': zero<12,
    }
    

    obs = self._get_obs(data, jp.zeros(self.sys.nu))

    return State(data, obs, reward, done, metrics, info)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""

    indexes = jp.array([15, 17, 18, 20])
    action = action.at[indexes].set(0)
    
    data0 = state.pipeline_state


    # Conditional update for action[5]
    action = jax.lax.cond(jp.abs(data0.qpos[12]) < 0.35,
                      lambda _: action,
                      lambda _: action.at[5].set(0),
                      None)

    # Conditional update for action[11]
    action = jax.lax.cond(jp.abs(data0.qpos[18]) < 0.35,
                      lambda _: action,
                      lambda _: action.at[11].set(0),
                      None)


    data = self.pipeline_step(data0, action)

    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    
    collision = data.contact.dist < 0.01

    is_touchdown = [
      jax.lax.bitwise_or(collision[0], collision[1]),
      jax.lax.bitwise_or(collision[2], collision[3])
      ]
    
    is_single_support = jax.lax.bitwise_xor(
      is_touchdown[0],
      is_touchdown[1]
    )
    
    state.info['single_support_time'] = jp.where(
      is_single_support,
      state.info['single_support_time']+self.dt,
      0
    )

    state.info['right_foot_airtime'] = jp.where(
      is_touchdown[0],
      state.info['right_foot_airtime']+self.dt,
      0
    )

    state.info['left_foot_airtime'] = jp.where(
      is_touchdown[1],
      state.info['left_foot_airtime']+self.dt,
      0
    )

    epsilon = 0.01
    swing_length = 0.21
    
    right_foot_placment_reward = 0.15*jp.exp(-5*(jp.abs(data.xpos[:,7,0]-state.info['Crightfoot']))**2)

    left_foot_placment_reward = 0.15*jp.exp(-5*(jp.abs(data.xpos[:,10,0]-state.info['Cleftfoot']))**2)
    
    is_right_foot_inplace = jax.lax.bitwise_and(
        state.info['turn'],
        jp.abs(data.xpos[:,7,0]-state.info['Crightfoot']) < epsilon
    )
    is_left_foot_inplace = jax.lax.bitwise_and(
        jax.lax.bitwise_not(state.info['turn']),
        jp.abs(data.xpos[:,10,0]-state.info['Cleftfoot']) < epsilon
    )
    
    state.info['Crightfoot'] = jp.where(
         is_right_foot_inplace,
        state.info['Crightfoot'] + swing_length,
        jp.zeros(state.info['Crightfoot'].shape)
    )
    state.info['Cleftfoot'] = jp.where(
         is_left_foot_inplace,
        state.info['Cleftfoot'] + swing_length,
        jp.zeros(state.info['Cledtfoot'].shape)
    )
    
    state.info['turn'] = jax.lax.cond(
                      jax.lax.bitwise_xor(is_right_foot_inplace, is_left_foot_inplace),
                      lambda _: jax.lax.bitwise_not(state.info['turn']),
                      lambda _: state.info['turn'],
                      None)

    
    
    Cyrp = jp.array([1,0,0,0])
    orient_reward = 0.1*jp.exp(-300*jp.linalg.norm(data.qpos[3:7] - Cyrp))
    
    foot_contact_reward = jp.where(
      state.info['single_support_time'] >= 0.2,
      1.0,
      0.0
    )

    Ch = 1.282
    base_height_reward = 0.05*jp.exp(-20*jp.abs(data.qpos[2]-Ch))

    Cair = 0.4
    feet_airtime_reward = jp.where(
        is_touchdown[0],
        state.info['right_foot_airtime']-Cair,
        0
    )

    feet_airtime_reward += jp.where(
        is_touchdown[1],
        state.info['left_foot_airtime']-Cair,
        0
    )

    # feet_orientation_reward = 0 #dropped the is very little information to its implemetation and probably it is for standing task in walking it would harm agents adaptability to different terrains.

    # feet_position = 0 #dropped not same reason

    Carm = jp.array([0.76, -0.46, -1.75, 0.76, -0.46, -1.75])
    arm_reward = 0.03*jp.exp(-3*jp.linalg.norm(data.qpos[22:]-Carm))
    
    base_acceleration_reward = 0.1*jp.exp(-0.01*jp.sum(jp.abs(data.qacc[:3])))

    action_difference_reward = 0.002*jp.exp(-0.02*jp.sum(jp.abs(action-state.info['previous_action'])))

    torque_reward = 0.02*jp.exp(-0.02*1/21*jp.sum(jp.abs(data.qfrc_actuator[6:])))

    state.info['previous_action'] = action


    done = jp.where(data.q[2] < 1.0, 1.0, 0.0)

    obs = self._get_obs(data, action)
    
    reward = orient_reward \
            + foot_contact_reward \
            + base_height_reward \
            + feet_airtime_reward \
            + arm_reward \
            + base_acceleration_reward \
            + action_difference_reward \
            + torque_reward \
            + right_foot_placment_reward\
            + left_foot_placment_reward            


    state.metrics.update(
        orient_reward=orient_reward,
        foot_contact_reward=foot_contact_reward,
        base_height_reward=base_height_reward,
        feet_airtime_reward=feet_airtime_reward,
        arm_reward=arm_reward,
        base_acceleration_reward=base_acceleration_reward,
        action_difference_reward=action_difference_reward,
        torque_reward=torque_reward,
        right_foot_placment_reward=right_foot_placment_reward,
        left_foot_placment_reward=left_foot_placment_reward,
        single_support_time=state.info['single_support_time'],
        right_foot_airtime=state.info['right_foot_airtime'],
        left_foot_airtime=state.info['left_foot_airtime'],
        previous_action=state.info['previous_action'],
        Crightfoot=state.info['Crightfoot'],
        Cleftfoot=state.info['Cleftfoot'],
        turn=state.info['turn'],
    )
    

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done, info=state.info
    )
  
  def _get_obs(
        self, data: mjx.Data, action: jp.ndarray
    ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    return jp.concatenate([
            data.qpos[2:],
            data.qvel[2:],
            action,
            data.qfrc_actuator
        ])
  
envs.register_environment('humanoid', Humanoid)

In [None]:

env_name = 'humanoid'
env_config = {}#sub_config(config, ['reset_noise_scale'])
env = envs.get_environment(env_name, **env_config)

In [None]:
wandb.init(
    project="tmp",
    config=config
)

## Train Humanoid Policy

In [52]:
import zipfile
import os
def zip_directory(folder_path, zip_path):
    # Ensure the folder exists
    if not os.path.exists(folder_path):
        print(f"The directory {folder_path} does not exist.")
        return

    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(folder_path):
            for file in files:
                zipf.write(os.path.join(root, file),
                           os.path.relpath(os.path.join(root, file),
                                           os.path.join(folder_path, '..')))
    print(f"Successfully created zip file at {zip_path}")
    
    
    

In [53]:
%notebook notebook.ipynb
wandb.save("/kaggle/working/notebook.ipynb", base_path="/kaggle/working/")

['/kaggle/working/wandb/run-20240828_204556-tapggqag/files/notebook.ipynb']

In [54]:
ckpt_path = epath.Path('/kaggle/working/tmp/humanoid/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
  #save checkpoints
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = ckpt_path / f'{current_step}'
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)
  zip_path = ckpt_path / f'{current_step}.zip'
  zip_directory(path,zip_path)
  wandb.save(zip_path, base_path="/kaggle/working/")
  

def save_checkpoint(path, params):
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)

train_config = sub_config(config, ['num_timesteps', 'num_evals', 'reward_scaling', 'episode_length', 'normalize_observations', 'action_repeat', 'unroll_length', 'num_minibatches', 'num_updates_per_batch', 'discounting', 'learning_rate', 'entropy_cost', 'num_envs', 'batch_size', 'seed'])

# make_networks_factory = functools.partial(
#     ppo_networks.make_ppo_networks,
#         policy_hidden_layer_sizes=(16, 32, 32)
# )

train_fn = functools.partial(ppo.train, **train_config)


times = [datetime.now()]
def progress(num_steps, metrics):
  times.append(datetime.now())
  _metrics = metrics.copy()
  _metrics['numsteps'] = num_steps
  wandb.log(_metrics)

make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress, policy_params_fn=policy_params_fn)

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

ValueError: Incompatible shapes for broadcasting: shapes=[(128, 21), (128,)]

In [None]:
# import zipfile
# import os
# def zip_directory(folder_path, zip_path):
#     # Ensure the folder exists
#     if not os.path.exists(folder_path):
#         print(f"The directory {folder_path} does not exist.")
#         return

#     with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
#         for root, dirs, files in os.walk(folder_path):
#             for file in files:
#                 zipf.write(os.path.join(root, file),
#                            os.path.relpath(os.path.join(root, file),
#                                            os.path.join(folder_path, '..')))
#     print(f"Successfully created zip file at {zip_path}")

# # Define the folder and zip file paths
# folder_path = "/kaggle/working/tmp/humanoid/ckpts"
# zip_path = "/kaggle/working/ckpts.zip"

# # Call the function to zip the directory
# wandb.init(project='tmp', resume='psw0ok5x')
# save_checkpoint(folder_path, params)
# zip_directory(folder_path, zip_path)
# wandb.save(zip_path, base_path="/kaggle/working/")

In [None]:
wandb.finish()