In [1]:
from brax.io import model
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from mujoco_playground.config import locomotion_params

from ml_collections import config_dict

ENV_STR = 'Go1JoystickFlatTerrain'

# Begin from MuJoCo Playground's stock Brax-PPO settings for ENV_STR, then
# override a subset of fields for a shorter run on the cave-exploration task.
ppo_params = locomotion_params.brax_ppo_config(ENV_STR)

# --- Training budget and rollout shape ---------------------------------------
ppo_params.num_timesteps = 10_000_000   # total environment steps (earlier note: reduced from 60_000_000)
ppo_params.episode_length = 2000        # maximum steps per episode
ppo_params.num_envs = 4096              # parallel environments (NOTE(review): an earlier comment said "Reduce from 2048", but 4096 is larger — confirm the baseline)
ppo_params.batch_size = 512             # per Brax PPO docs: size of each SGD minibatch
ppo_params.num_minibatches = 16         # per Brax PPO docs: minibatches per update; batch_size * num_minibatches must be divisible by num_envs
ppo_params.num_updates_per_batch = 8    # passes over the collected data per training step (earlier note: reduced from 16)
ppo_params.unroll_length = 100          # consecutive steps collected per environment for each trajectory segment

# --- Loss / optimiser ---------------------------------------------------------
ppo_params.entropy_cost = 0.02
ppo_params.learning_rate = 3e-4         # optimizer learning rate

# Asymmetric actor-critic: the policy reads the `state` observation and the
# value function reads `privileged_state` (both keys of the env's observation dict).
ppo_params.network_factory = config_dict.create(
    policy_hidden_layer_sizes=(1024, 512, 256, 128),
    value_hidden_layer_sizes=(1024, 512, 256, 128),
    policy_obs_key='state',
    value_obs_key='privileged_state',
)





In [None]:
import os
import sys

# Make the repository root (two directories above the notebook's working
# directory) importable so `tasks.*` and `utils.*` resolve. Notebooks have no
# __file__, so os.getcwd() stands in for the notebook's location (Jupyter
# starts the kernel in the notebook's directory by default).
notebook_dir = os.getcwd()
repo_root = os.path.abspath(os.path.join(notebook_dir, '../..'))
# Idempotent: drop any earlier copies before putting the root first, so re-running
# this cell does not keep growing sys.path with duplicates.
while repo_root in sys.path:
    sys.path.remove(repo_root)
sys.path.insert(0, repo_root)

from tasks.cave_exploration.cave_exploration import default_config as reachbot_config

# Start from the cave-exploration task defaults and override selected fields.
env_cfg = reachbot_config()

# --- Simulation / action settings --------------------------------------------
env_cfg.sim_dt = 0.004        # physics timestep (presumably seconds — confirm in the task config)
env_cfg.action_scale = 0.2    # scale factor applied to policy actions

# --- Reward-term weights ------------------------------------------------------
# Presumably each term is multiplied by its weight, so 0.0 / -0.0 switches it off.
reward_scales = env_cfg.reward_config.scales
reward_scales.orientation = 0.0            # off
reward_scales.lin_vel_z = 0.0              # off
reward_scales.ang_vel_xy = 0.0             # off
reward_scales.feet_slip = 0.0              # off
reward_scales.stand_still = -0.0001        # kept on with a small negative weight
reward_scales.torques = -0.0               # off (negative zero, despite an earlier "Enable" note)
reward_scales.action_rate = 0.0            # off
reward_scales.energy = 0.0                 # off
reward_scales.distance_to_target = 100.0   # by far the largest weight set here

In [None]:
# @title Run RL simulation with joystick control


#os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import multiprocessing as mp

# Use "spawn" instead of "fork" for child processes (presumably because forking
# a process that already initialised JAX, which is multithreaded, is unsafe).
# force=True replaces any start method already set in this kernel, so
# set_start_method never raises "context has already been set" here and no
# RuntimeError guard is needed.
mp.set_start_method('spawn', force=True)


# Tell XLA to use Triton for GEMMs; per the MuJoCo Playground notes this can
# raise steps/sec by ~30% on some GPUs. XLA reads XLA_FLAGS when JAX initialises
# its backend, so this must run before the first device query / computation.
_TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
# Append only once so re-running this cell does not accumulate duplicate flags.
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags = f'{xla_flags} {_TRITON_GEMM_FLAG}'.strip()
os.environ['XLA_FLAGS'] = xla_flags

import jax
from jax import numpy as jp



# NaN debugging knobs (see the Brax GitHub issues). Enabling these can help
# when the model returns NaNs, at a performance cost:
#jax.config.update('jax_enable_x64', True)  # 64-bit floats; slows training a lot
#jax.config.update('jax_default_matmul_precision', 'high')
#jax.config.update("jax_debug_nans", True)

# Report whether JAX's first (default) device is a GPU, and which one.
default_platform = jax.devices()[0].platform
gpu_available = default_platform == 'gpu'
print(f"GPU available: {gpu_available}")

if not gpu_available:
    print("No GPU device found.")
else:
    gpu_device = jax.devices('gpu')[0]
    print(f"GPU device: {gpu_device}")

import signal
import sys

def signal_handler(sig, frame):
    """SIGINT handler: announce the interrupt, then act like Python's default.

    Delegates to signal.default_int_handler, which raises KeyboardInterrupt.
    Calling sys.exit() instead would raise SystemExit, which a Jupyter kernel
    intercepts without stopping, so the interrupt would surface as a confusing
    "SystemExit: 0" rather than a normal interruption.

    Args:
        sig: the signal number being handled.
        frame: the interrupted stack frame (forwarded unchanged).

    Raises:
        KeyboardInterrupt: always.
    """
    print('Program interrupted via keyboard interrupt')
    signal.default_int_handler(sig, frame)


# Register the handler for SIGINT (Ctrl+C).
signal.signal(signal.SIGINT, signal_handler)

import mujoco

import json
# Importing the necessary libraries
from datetime import datetime
import functools
# Run the code on the CPU rather than the GPU
# Normally the code runs on the GPU or any other accelerator that is available
#os.environ['JAX_PLATFORM_NAME'] = 'cpu'

import mujoco
from flax.training import orbax_utils
from orbax import checkpoint as ocp
from mujoco_playground.config import locomotion_params

from tasks.cave_exploration.cave_exploration import CaveExplore
from tasks.common.randomize import domain_randomize as reachbot_randomize
from mujoco_playground import registry

from tensorboardX import SummaryWriter

from pathlib import Path

from utils.telegram_messenger import send_message_sync




# Module-level training bookkeeping.
# x_data / y_data / y_dataerr start empty (not referenced elsewhere in the
# visible notebook). `times` is seeded with the moment this cell ran;
# trainModel's progress callback appends one timestamp per call.
x_data = []
y_data = []
y_dataerr = []
times = [datetime.now()]

class JaxArrayEncoder(json.JSONEncoder):
    """JSON encoder that also serialises JAX / NumPy arrays and scalars.

    Training metrics may arrive as jax.Array values or as NumPy scalars such as
    np.float32, neither of which the stock encoder accepts. Anything exposing a
    callable ``tolist()`` is converted to plain Python lists/scalars; all other
    objects fall back to json.JSONEncoder (which raises TypeError).
    """

    def default(self, obj):
        if isinstance(obj, jp.ndarray):
            return obj.tolist()
        # NumPy arrays and scalars (e.g. np.float32) are not jax.Array instances,
        # but share the tolist() conversion.
        to_list = getattr(obj, "tolist", None)
        if callable(to_list):
            return to_list()
        return super().default(obj)
    
def save_video(frames, video_path, fps):
    """Encode `frames` into a video file at `video_path`, played back at `fps`.

    imageio is imported inside the function, so it is only required when a
    video is actually written.
    """
    from imageio import mimsave

    mimsave(video_path, frames, fps=fps)
       

# Strong references to fire-and-forget Telegram tasks: asyncio's event loop only
# keeps weak references to tasks, so an unreferenced task can vanish mid-run.
_PENDING_NOTIFICATIONS = set()


def _forget_notification(task):
  """Done-callback: drop a finished notification task and report its failure."""
  _PENDING_NOTIFICATIONS.discard(task)
  if not task.cancelled() and task.exception() is not None:
    print(f"Telegram notification failed: {task.exception()!r}")


def _notify_telegram(duration, result):
  """Best-effort Telegram notification about this training run.

  Inside Jupyter an event loop is already running in this thread, so the
  coroutine is scheduled as a task (it runs once the cell yields). In a plain
  script there is no running loop, so the coroutine is run to completion here.
  Failures are printed, never raised, so a notification problem cannot mask
  the training outcome.

  Args:
    duration: human-readable run duration.
    result: short summary of the run's outcome.
  """
  import asyncio
  try:
    from utils.telegram_messenger import send_telegram_message
    coro = send_telegram_message(
        task="Cave Exploration RL Training",
        duration=duration,
        result=result,
    )
    try:
      loop = asyncio.get_running_loop()
    except RuntimeError:
      loop = None
    if loop is None:
      asyncio.run(coro)
    else:
      # asyncio.run() cannot be used from a running loop (Jupyter), so schedule.
      task = loop.create_task(coro)
      _PENDING_NOTIFICATIONS.add(task)
      task.add_done_callback(_forget_notification)
  except Exception as exc:  # best-effort by design
    print(f"Telegram notification failed: {exc!r}")


def _json_safe(obj):
  """Recursively make `obj` representable as strict JSON.

  dicts are walked; lists and tuples become lists (json writes tuples as
  lists anyway); +inf / -inf become +/-1e308 and NaN becomes None, because
  JSON has no literal for them. Other values are returned unchanged.
  """
  import math
  if isinstance(obj, dict):
    return {k: _json_safe(v) for k, v in obj.items()}
  if isinstance(obj, (list, tuple)):
    return [_json_safe(v) for v in obj]
  if isinstance(obj, float):
    if math.isnan(obj):
      return None
    if math.isinf(obj):
      return 1e308 if obj > 0 else -1e308
  return obj


def _nest_flat_dict(flat_dict):
  """Convert a dict with '/'-separated keys into a nested dict.

  E.g. {"eval/episode_reward": 1.0} -> {"eval": {"episode_reward": 1.0}}.
  When a key is both a leaf and a prefix of other keys, its own value is kept
  under 'value' inside the nested group.
  """
  nested = {}
  for key, value in flat_dict.items():
    parts = key.split('/')
    node = nested
    for part in parts[:-1]:
      # Promote an existing leaf to a group so deeper keys can nest under it.
      if not isinstance(node.get(part), dict):
        node[part] = {'value': node[part]} if part in node else {}
      node = node[part]
    leaf = parts[-1]
    if isinstance(node.get(leaf), dict):
      node[leaf]['value'] = value
    else:
      node[leaf] = value
  return nested


def trainModel():
  """Train a PPO policy on CaveExplore and save the run's artefacts.

  Uses the notebook-level `ppo_params` and `env_cfg`. Creates
  logs/cave_exploration-<timestamp>/ under the current directory holding
  config.json, TensorBoard events, orbax checkpoints (via policy_params_fn),
  the final policy params, results.txt, rewards.json and a post-training
  rollout video, then sends a best-effort Telegram notification.

  Raises:
    Exception: any exception from training is re-raised after a failure
      notification is attempted.
  """
  # Create a timestamped log directory for this training run.
  datetime_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
  logdir = os.path.join(os.getcwd(), "logs/cave_exploration-" + datetime_str)
  os.makedirs(logdir, exist_ok=True)

  # Set the horizon before constructing the env so it is in place even if
  # CaveExplore reads or copies its config during construction.
  env_cfg.episode_length = ppo_params["episode_length"]
  env = CaveExplore(
      config=env_cfg,
      lidar_num_horizontal_rays=20,
      lidar_max_range=15.0,
      lidar_horizontal_angle_range=jp.pi * 2,
      lidar_vertical_angle_range=jp.pi / 6,
  )  # LIDAR params for 3D

  print("Training parameters:")
  for key, value in ppo_params.items():
    print(f"  {key}: {value}")

  # Save the configuration before training so it survives a failed run.
  configs = _json_safe({
      "env_cfg": env_cfg.to_dict(),
      # ConfigDict is not a dict subclass; json.dump raises TypeError on it.
      "ppo_params": ppo_params.to_dict(),
  })
  config_text = json.dumps(configs, indent=4, cls=JaxArrayEncoder)
  config_path = os.path.join(logdir, 'config.json')
  with open(config_path, "w", encoding="utf-8") as fp:
    fp.write(config_text)
  print(f"Configuration saved to {config_path}")

  timesteps = []
  rewards = []
  total_rewards = []
  total_rewards_std = []

  def progress(num_steps, metrics):
    """Brax progress callback: log metrics to TensorBoard and stdout."""
    times.append(datetime.now())
    timesteps.append(num_steps)
    total_rewards.append(metrics["eval/episode_reward"])
    total_rewards_std.append(metrics["eval/episode_reward_std"])
    for key, value in metrics.items():
      writer.add_scalar(key, value, num_steps)
    writer.flush()  # once per callback, not once per metric
    # Copy so the metrics dict owned by Brax is not mutated.
    record = dict(metrics)
    record["timesteps"] = num_steps
    record["time"] = (times[-1] - times[0]).total_seconds()
    rewards.append(record)
    percent_complete = (num_steps / ppo_params["num_timesteps"]) * 100
    print(f"step: {num_steps}/{ppo_params['num_timesteps']} ({percent_complete:.1f}%), reward: {total_rewards[-1]:.3f} +/- {total_rewards_std[-1]:.3f}")

  # ppo.train expects a network *factory* (a callable it invokes with the
  # observation/action sizes), not already-built networks.
  ppo_training_params = dict(ppo_params)
  network_factory = ppo_networks.make_ppo_networks
  if "network_factory" in ppo_params:
    ppo_training_params.pop("network_factory", None)
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks,
        **ppo_params.network_factory,
    )

  def policy_params_fn(current_step, make_policy, params):
    """Checkpoint the current policy params under <logdir>/checkpoints/<step>."""
    del make_policy  # Unused.
    orbax_checkpointer = ocp.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(params)
    path = os.path.abspath(os.path.join(logdir, 'checkpoints', f"{current_step}"))
    orbax_checkpointer.save(path, params, force=True, save_args=save_args)

  from mujoco_playground import wrapper

  print("Training the model...")
  train_fn = functools.partial(
      ppo.train,
      **ppo_training_params,
      network_factory=network_factory,
      progress_fn=progress,
      policy_params_fn=policy_params_fn,
      # randomization_fn=reachbot_randomize,  # enable domain randomisation
  )

  writer = SummaryWriter(logdir=logdir)
  writer.add_text('config', config_text)
  try:
    # make_inference_fn builds a policy from the trained params (weights).
    make_inference_fn, params, _final_metrics = train_fn(
        environment=env,
        wrap_env_fn=wrapper.wrap_for_brax_training,
    )
  except Exception as e:
    import traceback
    _notify_telegram(str(datetime.now() - times[0]), f"Failed: {e}")
    traceback.print_exc()
    raise
  finally:
    writer.close()

  # Persist the weights first so later reporting/rendering failures cannot lose them.
  params_path = os.path.join(logdir, 'params')
  model.save_params(params_path, params)

  # times[0] is the notebook-level start mark, so "jit" also includes env setup.
  if len(times) > 1:
    jit_time = times[1] - times[0]
    train_time = times[-1] - times[1]
  else:  # progress never ran, so there are no timing marks
    jit_time = train_time = "n/a"
  print(f"time to jit: {jit_time}")
  print(f"time to train: {train_time}")

  results_path = os.path.join(logdir, 'results.txt')
  with open(results_path, 'w') as f:
    for step, reward, reward_std in zip(timesteps, total_rewards, total_rewards_std):
      f.write(f"step: {step}, reward: {reward}, reward_std: {reward_std}\n")
    f.write(f"Time to jit: {jit_time}\n")
    f.write(f"Time to train: {train_time}\n")

  rewards_path = os.path.join(logdir, 'rewards.json')
  nested_rewards = [_nest_flat_dict(r) for r in rewards]
  with open(rewards_path, 'w') as fp:
    json.dump(nested_rewards, fp, indent=4, cls=JaxArrayEncoder)

  # Roll out the deterministic policy and record the states for rendering.
  jit_reset = jax.jit(env.reset)
  jit_step = jax.jit(env.step)
  jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))

  n_episodes = 1
  steps_per_episode = 1200
  rng = jax.random.PRNGKey(0)
  rollout = []
  for _ in range(n_episodes):
    # Use a fresh subkey for reset; JAX keys must not be reused.
    rng, reset_rng = jax.random.split(rng)
    state = jit_reset(reset_rng)
    rollout.append(state)
    for _ in range(steps_per_episode):
      rng, act_rng = jax.random.split(rng)
      ctrl, _ = jit_inference_fn(state.obs, act_rng)
      state = jit_step(state, ctrl)
      rollout.append(state)

  render_every = 1
  frames = env.render(rollout[::render_every], camera='track_global', width=1920, height=1080)
  video_path = os.path.join(logdir, 'posttraining.mp4')
  # Write in-process: a spawn-started child cannot import functions defined in
  # a notebook's __main__, so a subprocess here fails without the parent noticing.
  save_video(frames, video_path, 1.0 / env.dt / render_every)

  run_duration = str(times[-1] - times[0])
  if total_rewards:
    result = f"Final reward: {total_rewards[-1]:.3f} ± {total_rewards_std[-1]:.3f}"
  else:
    result = "No rewards recorded."
  _notify_telegram(run_duration, result)


trainModel()

GPU available: True
GPU device: cuda:0
Found 3 cave folders in /home/ga53voq/master_thesis/tasks/cave_exploration/environment/caves.
Loading 1 cave environments.
initial_qpos (cave_batch_loader): [-0.006  0.057  0.124  0.987 -0.212 -0.008 -0.031  0.016  0.001  0.028
 -0.016  0.013  0.027 -0.     0.006  0.028 -0.002 -0.     0.028]
Floor boxes detected: 5909
CaveExplore task initialized with model: ReachbotModelType.BASIC
CaveExplore task action space: 12
initial_qpos (cave_batch_loader): [-0.006  0.057  0.124  0.987 -0.212 -0.008 -0.031  0.016  0.001  0.028
 -0.016  0.013  0.027 -0.     0.006  0.028 -0.002 -0.     0.028]
Floor boxes detected: 5909
CaveExplore task initialized with model: ReachbotModelType.BASIC
CaveExplore task action space: 12
CaveExplore task observation space: {'privileged_state': (204,), 'state': (145,)}
Training parameters:
  action_repeat: 1
  batch_size: 512
  discounting: 0.97
  entropy_cost: 0.02
  episode_length: 2000
  learning_rate: 0.0003
  max_grad_norm: 1

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


step: 0/10000000 (0.0%), reward: nan +/- nan
