# Install MuJoCo, MJX, and Brax

In [1]:
# Install the simulation/training stack (MuJoCo, MJX, Brax) plus wandb and ml_collections.
# NOTE(review): no versions are pinned here — consider pinning for reproducible runs.
!pip install -q mujoco wandb mujoco_mjx ml_collections brax
#!pip install -q --upgrade ipykernel

In [2]:
# #@title Check if MuJoCo installation was successful

# import distutils.util
# import os
# import subprocess
# if subprocess.run('nvidia-smi').returncode:
#   raise RuntimeError(
#       'Cannot communicate with GPU. '
#       'Make sure you are using a GPU Colab runtime. '
#       'Go to the Runtime menu and select Choose runtime type.')

# # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# # This is usually installed as part of an Nvidia driver package, but the Colab
# # kernel doesn't install its driver via APT, and as a result the ICD is missing.
# # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
# # NVIDIA_ICD_CONFIG_PATH = '../../usr/share/glvnd/egl_vendor.d/10_nvidia.json'
# # if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
# #   with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
# #     f.write("""{
# #     "file_format_version" : "1.0.0",
# #     "ICD" : {
# #         "library_path" : "libEGL_nvidia.so.0"
# #     }
# # }
# # """)

# # Configure MuJoCo to use the EGL rendering backend (requires GPU)
# print('Setting environment variable to use GPU rendering:')
# %env MUJOCO_GL=egl

# try:
#   print('Checking that the installation succeeded:')
#   import mujoco
#   mujoco.MjModel.from_xml_string('<mujoco/>')
# except Exception as e:
#   raise e from RuntimeError(
#       'Something went wrong during installation. Check the shell output above '
#       'for more information.\n'
#       'If using a hosted Colab runtime, make sure you enable GPU acceleration '
#       'by going to the Runtime menu and selecting "Choose runtime type".')

# print('Installation successful.')

# # Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
# xla_flags = os.environ.get('XLA_FLAGS', '')
# xla_flags += ' --xla_gpu_triton_gemm_any=True'
# os.environ['XLA_FLAGS'] = xla_flags


In [3]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting: make sure an ffmpeg binary is present (installed via
# apt when `command -v` cannot find one), then install and import mediapy.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy: 3 decimals, no scientific notation, 100-char lines.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

Installing mediapy:


In [4]:
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

import wandb

# Log in to Weights & Biases

In [5]:
# Authenticate with Weights & Biases without storing the key in the notebook.
# SECURITY: an API key used to be hard-coded in this cell. Treat that key as
# leaked and revoke it in the W&B account settings.
import getpass

# Preferred source: the WANDB_API_KEY environment variable.
api_token = os.environ.get("WANDB_API_KEY")
try:
    if not api_token:
        # Interactive fallback; the hidden prompt keeps the key out of the
        # saved notebook and its outputs.
        api_token = getpass.getpass("Weights & Biases API key: ").strip()
    # An empty key is passed as None so wandb can use already-stored credentials.
    if wandb.login(key=api_token or None, relogin=False):
        print("Successfully logged in to Weights & Biases.")
    else:
        print("Login failed. Please check your API token.")
except Exception as e:
    # Best-effort by design: report the problem but keep the notebook running.
    print(f"An error occurred: {e}")

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Successfully logged in to Weights & Biases.


# Experiment Config

In [6]:
# Experiment configuration: passed whole to wandb.init, while the PPO subset is
# picked out by key name in the training cell.
config = {
    # Descriptive entries (not selected for the trainer by the training cell).
    "algorithm": "PPO",
    "Policy Network Architecture": "default",
    "Value Network Architecture": "default",
    # NOTE(review): Humanoid._get_obs concatenates qpos[2:], qvel[3:] and the
    # action, which does not match the description below — confirm.
    "Obs": "data.qpos[2:]+data.qvel[3:6]+data.qacc[0:2]",
    "reward": "multi-stage",

    # PPO hyperparameters (forwarded to ppo.train by the training cell).
    "num_timesteps": 100_000_000,
    "num_evals": 10,
    "reward_scaling": 1,
    "episode_length": 1000,
    "normalize_observations": True,
    "action_repeat": 1,
    "unroll_length": 10,
    "num_minibatches": 32,
    "num_updates_per_batch": 8,
    "discounting": 0.97,
    "learning_rate": 3e-4,
    "entropy_cost": 1e-3,
    "num_envs": 2048,
    "batch_size": 1024,
    "seed": 0,

    # Optional paths; all unset here.
    "xml_path": None,
    "restore_checkpoint_path": None,
    "create_checkpoint_path": None,
}

# Utilities

In [7]:
def sub_config(config, wanted_keys):
  """Returns a new dict with only the entries of `config` named in `wanted_keys`.

  Args:
    config: mapping to select from.
    wanted_keys: iterable of keys to keep.

  Returns:
    A dict ordered like `wanted_keys`; keys missing from `config` are skipped
    rather than raising.
  """
  return {key: config[key] for key in wanted_keys if key in config}

# Train

In [14]:
#@title Humanoid Env

class Humanoid(PipelineEnv):
  """MJX humanoid locomotion environment with a periodic action/observation swap.

  A counter kept in `state.info['counter']` advances by one per `step` and
  wraps at `phase_threshold`. While `counter % phase_threshold` is at most
  `phase_threshold / 2`, actions and observations pass through unchanged. For
  the rest of the cycle two pairs of index blocks are exchanged in the action
  (3:9 <-> 9:15 and 15:18 <-> 18:21) and the corresponding blocks of qpos and
  qvel are exchanged in the observation.
  NOTE(review): the exchanged blocks presumably are the right/left leg and arm
  groups of humanoid.xml, and the slices assume a 21-dim action — confirm
  against the model's joint/actuator order.

  The per-step reward is
  `forward_reward_weight * x_velocity + healthy_reward - ctrl_cost`, where the
  velocity is the finite difference of `subtree_com[1]` over `self.dt` and
  `ctrl_cost = ctrl_cost_weight * sum(action ** 2)`. `done` is set when the
  head geom is less than 0.9 above the mean height of the two `foot1_*` geoms.
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      phase_threshold = 40,
      **kwargs,
  ):
    """Loads the MJX test humanoid model and configures the task.

    Args:
      forward_reward_weight: multiplier on the x velocity of the center of mass.
      ctrl_cost_weight: multiplier on the squared-action penalty.
      healthy_reward: constant reward added on every step.
      terminate_when_unhealthy: stored but not consulted by this class;
        termination always follows the head-above-feet rule in `step`.
      healthy_z_range: stored but not consulted by this class.
      reset_noise_scale: half-width of the uniform noise applied in `reset`.
      exclude_current_positions_from_observation: stored but not consulted;
        the observation always starts at `qpos[2:]`.
      phase_threshold: length, in control steps, of one swap cycle.
      **kwargs: forwarded to `PipelineEnv.__init__`. `n_frames` defaults to 5
        when absent and `backend` is always overwritten with 'mjx'.

    Raises:
      ValueError: if `phase_threshold` is not positive, or if one of the geoms
        'head', 'foot1_right', 'foot1_left' is missing from the model.
    """
    # The counter is reduced modulo phase_threshold on every step.
    if phase_threshold <= 0:
      raise ValueError(
          f'phase_threshold must be positive, got {phase_threshold}.')

    path = epath.Path(epath.resource_path('mujoco')) / (
        'mjx/test_data/humanoid'
    )
    mj_model = mujoco.MjModel.from_xml_path(
        (path / 'humanoid.xml').as_posix())
    # Solver settings: conjugate-gradient with 6 solver / 6 line-search iterations.
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6
    self.mj_model = mj_model

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    def geom_id(name):
      # mj_name2id reports a missing name as -1, which would otherwise index
      # the last geom silently.
      idx = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_GEOM, name)
      if idx < 0:
        raise ValueError(f"Geom '{name}' not found in the humanoid model.")
      return idx

    self.head_id = geom_id('head')
    self.foot1_right_id = geom_id('foot1_right')
    self.foot1_left_id = geom_id('foot1_left')
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self.phase_threshold = phase_threshold
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def _in_first_half(self, counter: jp.ndarray) -> jp.ndarray:
    """True while the phase counter is in the unswapped part of the cycle."""
    return counter % self.phase_threshold <= (self.phase_threshold / 2)

  @staticmethod
  def _swap_sides(action: jp.ndarray) -> jp.ndarray:
    """Returns `action` with blocks 3:9 <-> 9:15 and 15:18 <-> 18:21 exchanged."""
    return jp.concatenate([
        action[0:3],
        action[9:15],
        action[3:9],
        action[18:21],
        action[15:18],
    ])

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state.

    Positions are `qpos0` plus uniform noise and velocities are uniform noise,
    both drawn from `[-reset_noise_scale, reset_noise_scale]`. All metrics and
    the phase counter start at zero.
    """
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    info = {'counter': zero}

    obs = self._get_obs(data, jp.zeros(self.sys.nu), info['counter'])

    return State(data, obs, reward, done, metrics, info)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics.

    The action is swapped according to the counter value *before* this step,
    while the returned observation uses the counter *after* incrementing. The
    input `state` is not modified; fresh `info` and `metrics` dicts are put
    on the returned state.

    NOTE(review): wrappers that auto-reset the physics state presumably do not
    reset `info['counter']` — confirm the phase is still what you expect after
    an episode ends.
    """
    counter = state.info['counter']
    action = jp.where(
        self._in_first_half(counter),
        action,
        self._swap_sides(action),
    )

    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    # Center-of-mass velocity by finite difference over one control step.
    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    # Index 2 of geom_xpos is the z coordinate.
    head_z = data.geom_xpos[self.head_id, 2]
    mean_feet_z = (
        data.geom_xpos[self.foot1_right_id, 2]
        + data.geom_xpos[self.foot1_left_id, 2]
    ) / 2

    # Terminate once the head is less than 0.9 above the feet.
    done = jp.where((head_z - mean_feet_z) < 0.9, 1.0, 0.0)

    healthy_reward = self._healthy_reward

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    next_counter = (counter + 1) % self.phase_threshold
    # Copy rather than mutate, so the caller's `state` keeps its own counter.
    info = {**state.info, 'counter': next_counter}

    obs = self._get_obs(data, action, next_counter)
    reward = forward_reward + healthy_reward - ctrl_cost

    metrics = {
        **state.metrics,
        'forward_reward': forward_reward,
        'reward_linvel': forward_reward,
        'reward_quadctrl': -ctrl_cost,
        'reward_alive': healthy_reward,
        'x_position': com_after[0],
        'y_position': com_after[1],
        'distance_from_origin': jp.linalg.norm(com_after),
        'x_velocity': velocity[0],
        'y_velocity': velocity[1],
    }

    return state.replace(
        pipeline_state=data,
        obs=obs,
        reward=reward,
        done=done,
        metrics=metrics,
        info=info,
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray, counter: jp.ndarray
  ) -> jp.ndarray:
    """Builds the observation from `qpos[2:]`, `qvel[3:]` and the action.

    In the first half of the phase the three pieces are concatenated as-is; in
    the second half the qpos blocks 10:16 <-> 16:22 and 22:25 <-> 25:, the
    qvel blocks 9:15 <-> 15:21 and 21:24 <-> 24:, and the action blocks (see
    `_swap_sides`) are exchanged.
    """
    # external_contact_forces are excluded
    unswapped = jp.concatenate([
        data.qpos[2:],
        data.qvel[3:],
        action,
    ])
    swapped = jp.concatenate([
        data.qpos[2:10],
        data.qpos[16:22],
        data.qpos[10:16],
        data.qpos[25:],
        data.qpos[22:25],
        data.qvel[3:9],
        data.qvel[15:21],
        data.qvel[9:15],
        data.qvel[24:],
        data.qvel[21:24],
        self._swap_sides(action),
    ])
    return jp.where(self._in_first_half(counter), unswapped, swapped)


# Make the class above available by name through the Brax environment registry.
# NOTE(review): presumably this overrides Brax's stock 'humanoid' entry — confirm that is intended.
envs.register_environment('humanoid', Humanoid)

In [15]:

# Build the registered environment with a phase counter that wraps every 26 steps.
env_name = 'humanoid'
env_config = dict(phase_threshold=26)  # alternative: sub_config(config, ['reset_noise_scale'])
env = envs.get_environment(env_name, **env_config)

In [10]:
# Start the W&B run in project "tmp" and record the full experiment config with it.
wandb.init(project="tmp", config=config)

[34m[1mwandb[0m: Currently logged in as: [33mtmptmp[0m ([33mtmptmp-tmp[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Train Humanoid Policy

In [11]:
import zipfile
import os
def zip_directory(folder_path, zip_path):
    """Compress every file under `folder_path` into a new archive at `zip_path`.

    Member names are taken relative to the parent of `folder_path`, so the
    folder's own name becomes the top-level entry of the archive. When
    `folder_path` does not exist a message is printed and nothing is written.
    """
    if os.path.exists(folder_path):
        # Names inside the archive are computed relative to this directory.
        archive_root = os.path.join(folder_path, '..')
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
            for current_dir, _subdirs, filenames in os.walk(folder_path):
                for filename in filenames:
                    source = os.path.join(current_dir, filename)
                    archive.write(source, os.path.relpath(source, archive_root))
        print(f"Successfully created zip file at {zip_path}")
    else:
        print(f"The directory {folder_path} does not exist.")
    
    
    

In [12]:
# Export the session to notebook.ipynb with the %notebook magic, then attach that file to the W&B run.
# NOTE(review): the Kaggle working directory is hard-coded in the call below.
%notebook notebook.ipynb
wandb.save("/kaggle/working/notebook.ipynb", base_path="/kaggle/working/")

['/kaggle/working/wandb/run-20240909_124222-zrm7pyza/files/notebook.ipynb']

In [None]:
# Directory that collects one checkpoint folder and one zip archive per saved step.
ckpt_path = epath.Path('/kaggle/working/tmp/humanoid/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
  """Checkpoint callback passed to the trainer as `policy_params_fn`.

  Saves `params` under `ckpt_path/<current_step>`, zips that folder next to it
  and uploads the archive to the active W&B run.

  Args:
    current_step: step count used to name the checkpoint folder and archive.
    make_policy: part of the callback signature; unused here.
    params: parameter pytree to store.
  """
  path = ckpt_path / f'{current_step}'
  # Reuse the shared helper instead of repeating the Orbax save sequence.
  save_checkpoint(path, params)
  zip_path = ckpt_path / f'{current_step}.zip'
  zip_directory(path, zip_path)
  # Pass a plain string: not every wandb release accepts path objects here.
  wandb.save(str(zip_path), base_path="/kaggle/working/")
  

def save_checkpoint(path, params):
  """Writes `params` to `path` with an Orbax PyTreeCheckpointer (force=True)."""
  checkpointer = ocp.PyTreeCheckpointer()
  checkpointer.save(
      path,
      params,
      force=True,
      save_args=orbax_utils.save_args_from_target(params),
  )

# Only these hyperparameters are forwarded to the trainer; the remaining
# entries of `config` are not passed on.
train_config = sub_config(config, [
    'num_timesteps',
    'num_evals',
    'reward_scaling',
    'episode_length',
    'normalize_observations',
    'action_repeat',
    'unroll_length',
    'num_minibatches',
    'num_updates_per_batch',
    'discounting',
    'learning_rate',
    'entropy_cost',
    'num_envs',
    'batch_size',
    'seed',
])

# make_networks_factory = functools.partial(
#     ppo_networks.make_ppo_networks,
#         policy_hidden_layer_sizes=(16, 32, 32)
# )

# PPO entry point with the hyperparameters pre-bound.
train_fn = functools.partial(ppo.train, **train_config)


# Timestamps: the first entry is taken now, progress() appends one per call.
times = [datetime.now()]
def progress(num_steps, metrics):
  """Progress callback: records a timestamp and logs the metrics to W&B.

  `metrics` itself is left untouched; a shallow copy gains a 'numsteps' entry
  holding `num_steps` before it is handed to `wandb.log`.
  """
  times.append(datetime.now())
  logged = metrics.copy()
  logged.update(numsteps=num_steps)
  wandb.log(logged)

# Run PPO: progress() logs metrics whenever the trainer calls it and
# policy_params_fn() checkpoints the parameters it is given.
make_inference_fn, params, _ = train_fn(
    environment=env,
    progress_fn=progress,
    policy_params_fn=policy_params_fn,
)

# times[0] predates training, times[1] is the first progress() call, times[-1] the last.
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

In [None]:
# import zipfile
# import os
# def zip_directory(folder_path, zip_path):
#     # Ensure the folder exists
#     if not os.path.exists(folder_path):
#         print(f"The directory {folder_path} does not exist.")
#         return

#     with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
#         for root, dirs, files in os.walk(folder_path):
#             for file in files:
#                 zipf.write(os.path.join(root, file),
#                            os.path.relpath(os.path.join(root, file),
#                                            os.path.join(folder_path, '..')))
#     print(f"Successfully created zip file at {zip_path}")

# # Define the folder and zip file paths
# folder_path = "/kaggle/working/tmp/humanoid/ckpts"
# zip_path = "/kaggle/working/ckpts.zip"

# # Call the function to zip the directory
# wandb.init(project='tmp', resume='psw0ok5x')
# save_checkpoint(folder_path, params)
# zip_directory(folder_path, zip_path)
# wandb.save(zip_path, base_path="/kaggle/working/")

In [None]:
# Close the active W&B run.
wandb.finish()