# Cave Exploration RL Training

This notebook contains the cave exploration reinforcement learning training pipeline, organized into separate cells for better modularity and experimentation.

In [1]:
# Necessary imports and setup
import sys
import os

# Add execution tracking to debug duplicate output
print("=== STARTING IMPORTS CELL ===")

# Make the repository root (two directories above the working directory)
# importable for the project-local packages. Guarded so re-running this cell
# does not stack duplicate entries onto sys.path.
_repo_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)

import multiprocessing as mp
try:
    # NOTE(review): 'spawn' is presumably chosen so that child processes do not
    # fork a parent that has already initialised CUDA — confirm.
    mp.set_start_method('spawn', force=True)
except RuntimeError:
    pass

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs.
# The flag must be in the environment before the JAX backend is initialised
# (the first device query below). It is appended only once so that re-running
# the cell keeps XLA_FLAGS idempotent.
_TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags = f"{xla_flags} {_TRITON_GEMM_FLAG}".strip()
os.environ['XLA_FLAGS'] = xla_flags

import jax
from jax import numpy as jp
from jax.lib import xla_bridge  # deprecated module; no longer used below

# Query the backend through the public jax API. The xla_bridge.get_backend()
# accessor is deprecated in recent JAX releases.
print("JAX backend info:")
print(f"Platform: {jax.default_backend()}")
print(f"Device count: {jax.device_count()}")
print(f"Devices: {jax.devices()}")

# JAX configuration for numerical stability
jax.config.update('jax_default_matmul_precision', 'high')
# Show full, unfiltered tracebacks when something inside JAX fails.
jax.config.update('jax_traceback_filtering', 'off')

# Check GPU availability
gpu_available = jax.devices()[0].platform == 'gpu'
print(f"GPU available: {gpu_available}")

if gpu_available:
    gpu_device = jax.devices('gpu')[0]
    print(f"GPU device: {gpu_device}")
else:
    print("No GPU device found.")

# Remaining library imports.
# NOTE(review): `mujoco` and `Path` are not referenced by any visible cell;
# they may be leftovers.
import signal
import json
import functools
import mujoco
from datetime import datetime
from pathlib import Path
import imageio
import gc

print("Basic imports completed...")

# Brax and training imports
# - model: parameter serialisation (model.save_params after training)
# - ppo_networks / ppo: network construction and the PPO training entry point
# - orbax_utils / ocp: Orbax checkpointing inside policy_params_fn
# - locomotion_params / wrapper: mujoco_playground PPO presets and the Brax
#   training wrapper handed to ppo.train
# - SummaryWriter: TensorBoard logging
from brax.io import model
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from flax.training import orbax_utils
from orbax import checkpoint as ocp
from mujoco_playground.config import locomotion_params
from mujoco_playground import wrapper
from tensorboardX import SummaryWriter

print("Brax imports completed...")

# Task-specific imports. These are project-local packages, presumably resolved
# through the repository root inserted into sys.path by the setup code.
from tasks.cave_exploration.cave_exploration import CaveExplore
from tasks.common.randomize import domain_randomize as reachbot_randomize
from utils.telegram_messenger import send_message_sync

print("Task-specific imports completed...")

# Global variables
# ENV_STR selects the mujoco_playground PPO hyperparameter preset (the Go1
# joystick task), which the PPO cell uses as its starting point.
ENV_STR = 'Go1JoystickFlatTerrain'
# NOTE(review): x_data / y_data / y_dataerr are not used by any visible cell.
x_data, y_data, y_dataerr = [], [], []
# Overwritten when the training cell runs; this initial value is not used for timing.
times = [datetime.now()]

# Signal handler for graceful interruption
def signal_handler(sig, frame):
    """Handle SIGINT: report the interruption, then raise KeyboardInterrupt.

    Delegates to `signal.default_int_handler` rather than calling `sys.exit(0)`.
    KeyboardInterrupt is what both plain Python and Jupyter's "Interrupt kernel"
    expect, whereas SystemExit(0) would report an interrupted run as a clean
    success.

    Args:
        sig: Signal number delivered by the interpreter.
        frame: Current stack frame (or None), forwarded unchanged.

    Raises:
        KeyboardInterrupt: always, via signal.default_int_handler.
    """
    print('Program exited via keyboard interrupt')
    signal.default_int_handler(sig, frame)

signal.signal(signal.SIGINT, signal_handler)

# JSON encoder for JAX / NumPy arrays and NumPy scalars
class JaxArrayEncoder(json.JSONEncoder):
    """JSON encoder that serialises array-likes as plain Python values.

    Handles jax.Array values as well as NumPy arrays and NumPy scalars such as
    np.float32, which the stock encoder rejects. Any object exposing `.tolist()`
    is converted to nested lists or a Python scalar. Everything else falls back
    to the default behaviour, which raises TypeError.
    """

    def default(self, obj):
        # jax.Array, np.ndarray and np.generic all provide .tolist().
        if isinstance(obj, jp.ndarray) or callable(getattr(obj, "tolist", None)):
            return obj.tolist()
        return super().default(obj)

print("=== ALL IMPORTS LOADED SUCCESSFULLY! ===")

=== STARTING IMPORTS CELL ===
JAX backend info:
JAX backend info:


  print(f"Platform: {xla_bridge.get_backend().platform}")


Platform: gpu
Device count: 2
Devices: [CudaDevice(id=0), CudaDevice(id=1)]
GPU available: True
GPU device: cuda:0
Basic imports completed...
Brax imports completed...
Task-specific imports completed...
=== ALL IMPORTS LOADED SUCCESSFULLY! ===
Brax imports completed...
Task-specific imports completed...
=== ALL IMPORTS LOADED SUCCESSFULLY! ===


In [None]:
# Environment configuration: start from the task defaults, then override
# simulation, control and reward settings as attributes on the config object.
from tasks.cave_exploration.cave_exploration import default_config as reachbot_config

env_cfg = reachbot_config()

# Basic simulation parameters, followed by the PID control gains.
for _field, _value in (
    ("sim_dt", 0.004),
    ("action_scale", 1),
    ("Kp_pri", 60.0),
    ("Kd_pri", 20.0),
    ("Kp_rot", 25.0),
    ("Kd_rot", 2.0),
):
    setattr(env_cfg, _field, _value)

# Reward scales: regularisation terms first (several switched off with -0.0),
# then the target-based terms.
reward_scales = env_cfg.reward_config.scales
for _term, _scale in (
    ("orientation", -0.0),
    ("lin_vel_z", -0.0),
    ("ang_vel_xy", -0.00),
    ("torques", -0.00005),
    ("action_rate", -0.0001),
    ("dof_pos_limits", -0.5),
    ("energy", -0.00001),
    ("feet_slip", -0.0),
    ("distance_to_target", 10.0),
    ("vel_to_target", 3.0),
    ("exploration_rate", 0.0),
):
    setattr(reward_scales, _term, _scale)

print("Environment configuration completed!")
print(f"Simulation dt: {env_cfg.sim_dt}")
print(f"Distance to target reward scale: {env_cfg.reward_config.scales.distance_to_target}")

Environment configuration completed!
Simulation dt: 0.004
Distance to target reward scale: 70.0


In [None]:
# PPO training parameters configuration: begin from the mujoco_playground
# preset selected by ENV_STR and override budget, batch geometry and optimiser.
ppo_params = locomotion_params.brax_ppo_config(ENV_STR)
ppo_training_params = dict(ppo_params)
ppo_training_params.update(
    num_timesteps=30_000_000,
    episode_length=10000,
    num_envs=4096,
    batch_size=1024,
    num_minibatches=32,
    num_updates_per_batch=16,
    unroll_length=200,
    entropy_cost=0.02,
    learning_rate=0.0003,
)

# Keep the environment's episode horizon in sync with PPO's episode length.
env_cfg.episode_length = ppo_training_params["episode_length"]

print("PPO training parameters:")
for param_name, param_value in ppo_training_params.items():
    print(f"  {param_name}: {param_value}")

print("\nPPO parameters configuration completed!")

PPO training parameters:
  action_repeat: 1
  batch_size: 512
  discounting: 0.97
  entropy_cost: 0.02
  episode_length: 10000
  learning_rate: 0.0003
  max_grad_norm: 1.0
  network_factory: policy_hidden_layer_sizes: &id001 !!python/tuple
- 512
- 256
- 128
policy_obs_key: state
value_hidden_layer_sizes: *id001
value_obs_key: privileged_state

  normalize_observations: True
  num_envs: 4096
  num_evals: 10
  num_minibatches: 16
  num_resets_per_eval: 1
  num_timesteps: 60000000
  num_updates_per_batch: 8
  reward_scaling: 1.0
  unroll_length: 100

PPO parameters configuration completed!


In [4]:
import os
import threading
print(f"🔥 PID: {os.getpid()} | Thread: {threading.get_ident()} | Time: {datetime.now()}")

# Training execution: every run gets its own timestamped log directory.
datetime_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logdir = os.path.join(os.getcwd(), f"logs/cave_exploration-{datetime_str}")
os.makedirs(logdir, exist_ok=True)

# Create environment
env = CaveExplore(config=env_cfg)

# Per-evaluation history filled in by progress(). times[0] is the reference
# point for every elapsed-time figure printed later.
timesteps, rewards, total_rewards, total_rewards_std = [], [], [], []
times = [datetime.now()]

writer = SummaryWriter(logdir=logdir)

# Progress tracking function
def progress(num_steps, metrics):
    """Brax progress_fn callback: record, log and report evaluation metrics.

    Called by ppo.train with the current environment step count and a flat
    metrics dict (keys such as "eval/episode_reward"). It does the following:
    - appends to the module-level tracking lists;
    - writes every finite metric to TensorBoard;
    - stores a copy of the metrics (plus "timesteps"/"time") in `rewards`;
    - prints a progress/ETA line.

    The caller's `metrics` dict is not modified.

    Raises:
        ValueError: if the evaluation episode reward is NaN/Inf. No notification
            is sent here. The exception propagates out of train_fn, whose
            exception handler in the training cell sends the single failure
            message (this function used to send a duplicate one).
    """
    episode_reward = metrics["eval/episode_reward"]
    if not jp.isfinite(episode_reward):
        print("Warning: NaN/Inf reward encountered, aborting.")
        raise ValueError(f"NaN/Inf reward encountered at step {num_steps}: {episode_reward}")

    times.append(datetime.now())
    timesteps.append(num_steps)
    total_rewards.append(episode_reward)
    total_rewards_std.append(metrics["eval/episode_reward_std"])

    # Log to TensorBoard, skipping values TensorBoard cannot plot.
    for key, value in metrics.items():
        if jp.isfinite(value):
            writer.add_scalar(key, value, num_steps)
        else:
            print(f"Warning: Skipping NaN/Inf value for metric '{key}' at step {num_steps}")
    writer.flush()

    elapsed_s = (times[-1] - times[0]).total_seconds()
    # Store a copy so brax's own metrics dict is not mutated.
    record = dict(metrics)
    record["timesteps"] = num_steps
    record["time"] = elapsed_s
    rewards.append(record)

    total_steps = ppo_training_params["num_timesteps"]
    percent_complete = (num_steps / total_steps) * 100
    if num_steps == 0:
        remaining_time_str = "unknown"
    else:
        # Linear extrapolation from the average rate so far (includes JIT time).
        remaining_time = (total_steps - num_steps) * elapsed_s / num_steps / 60
        remaining_time_str = f"{remaining_time:.2f}"

    print(f"step: {num_steps}/{total_steps} ({percent_complete:.1f}%), reward: {total_rewards[-1]:.3f} +/- {total_rewards_std[-1]:.3f}, time passed (min): {elapsed_s / 60:.2f} min, calculated time left (min): {remaining_time_str} min")

# Network factory setup.
# ppo.train expects a *callable* that builds the networks, not a networks
# object. Default to brax's make_ppo_networks itself. When the preset specifies
# an architecture, bind it with functools.partial.
if "network_factory" in ppo_params:
    # Drop it from the kwargs splatted into ppo.train; it is passed separately.
    ppo_training_params.pop("network_factory", None)
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks,
        **ppo_params.network_factory
    )
else:
    network_factory = ppo_networks.make_ppo_networks

# Checkpoint saving function
def policy_params_fn(current_step, make_policy, params):
    """Checkpoint callback for ppo.train.

    Saves `params` with an Orbax PyTreeCheckpointer under
    {logdir}/checkpoints/{current_step}, passing force=True. `make_policy` is
    accepted to match the callback signature but is not used.
    """
    del make_policy  # Unused.
    target_dir = os.path.abspath(
        os.path.join(logdir, 'checkpoints', f"{current_step}")
    )
    ocp.PyTreeCheckpointer().save(
        target_dir,
        params,
        force=True,
        save_args=orbax_utils.save_args_from_target(params),
    )

# Save configurations
print("Saving configs")
configs = {
    "env_cfg": env_cfg.to_dict(),
    "ppo_params": ppo_training_params
}
# "network_factory" is removed from ppo_training_params because it is handed to
# ppo.train as a separate argument. Record the network architecture explicitly
# so config.json is enough to rebuild the policy.
if "network_factory" in ppo_params:
    configs["network_factory"] = dict(ppo_params.network_factory)

def replace_infinity(obj):
    """Recursively replace infinite floats so the result dumps as strict JSON.

    json.dump writes float infinities as the non-standard tokens `Infinity` and
    `-Infinity`. They are replaced with 1e308 and -1e308 respectively. Dicts,
    lists and tuples are traversed recursively (tuples stay tuples); every
    other value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: replace_infinity(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [replace_infinity(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(replace_infinity(v) for v in obj)
    if isinstance(obj, float) and obj in (float('inf'), float('-inf')):
        return 1e308 if obj > 0 else -1e308
    return obj

configs = replace_infinity(configs)
# Serialise once and reuse the text for both config.json and TensorBoard.
# Doing it before opening the file means a serialisation error cannot leave a
# truncated config.json behind.
config_text = json.dumps(configs, indent=4)
config_path = os.path.join(logdir, 'config.json')
with open(config_path, "w", encoding="utf-8") as fp:
    fp.write(config_text)
print(f"Configuration saved to {config_path}")
writer.add_text('config', config_text)

# Setup training function
# NOTE(review): `randomizer` is bound here but never passed to ppo.train (there
# is no `randomization_fn=` argument below), so domain randomisation is not
# active in this run. Confirm whether that is intended.
randomizer = reachbot_randomize
train_fn = functools.partial(
    ppo.train,
    **ppo_training_params,
    network_factory=network_factory,
    progress_fn=progress,
    policy_params_fn=policy_params_fn,
)

# Run training. On any failure, notify with the elapsed time and the error,
# print the traceback, then re-raise so the cell still fails visibly.
print("Training the model...")
try:
    make_inference_fn, params, metrics = train_fn(
        environment=env,
        wrap_env_fn=wrapper.wrap_for_brax_training,
    )
except Exception as train_error:
    import traceback
    run_duration = str(datetime.now() - times[0])
    send_message_sync(
        task="Cave Exploration RL Training",
        duration=run_duration,
        result=f"Failed: {train_error}",
    )
    traceback.print_exc()
    raise
else:
    print("Training completed successfully!")

# times[1] is appended by the first progress() call (the step-0 evaluation).
# The span before it therefore covers compilation plus that first evaluation.
jit_duration = times[1] - times[0]
train_duration = times[-1] - times[1]
print(f"time to jit: {jit_duration}")
print(f"time to train: {train_duration}")

# Save results: one line per evaluation, followed by the timing summary.
results_path = os.path.join(logdir, 'results.txt')
with open(results_path, 'w') as results_file:
    for step_count, reward, reward_std in zip(timesteps, total_rewards, total_rewards_std):
        results_file.write(f"step: {step_count}, reward: {reward}, reward_std: {reward_std}\n")
    results_file.write(f"Time to jit: {times[1] - times[0]}\n")
    results_file.write(f"Time to train: {times[-1] - times[1]}\n")

# Save rewards as JSON
def nest_flat_dict(flat_dict):
    """Turn a flat dict with '/'-separated keys into nested dicts.

    Example: {"eval/reward": 1} becomes {"eval": {"reward": 1}}.

    When a key is both a leaf and a prefix of other keys (e.g. "a" and "a/b"),
    the leaf value is stored under a "value" entry of the nested dict. This
    holds regardless of which key arrives first.
    """
    nested = {}
    for flat_key, leaf_value in flat_dict.items():
        *parents, leaf = flat_key.split('/')
        node = nested
        for name in parents:
            child = node.get(name)
            if not isinstance(child, dict):
                # Promote an existing scalar at this prefix into {'value': ...}.
                child = {'value': node[name]} if name in node else {}
                node[name] = child
            node = child
        existing = node.get(leaf)
        if isinstance(existing, dict):
            existing['value'] = leaf_value
        else:
            node[leaf] = leaf_value
    return nested

nested_rewards = [nest_flat_dict(r) for r in rewards]
rewards_path = os.path.join(logdir, 'rewards.json')
with open(rewards_path, 'w', encoding="utf-8") as fp:
    json.dump(nested_rewards, fp, indent=4, cls=JaxArrayEncoder)

# Save trained parameters
params_path = os.path.join(logdir, 'params')
model.save_params(params_path, params)

# Training (and therefore every progress() call) has finished: flush and
# release the TensorBoard event file.
writer.close()

print(f"Training completed! Results saved to: {logdir}")
if total_rewards:
    print(f"Final reward: {total_rewards[-1]:.3f} ± {total_rewards_std[-1]:.3f}")
else:
    print("No rewards recorded.")

# Store these variables for the video generation cell
trained_params = params
trained_make_inference_fn = make_inference_fn
trained_env = env
trained_logdir = logdir

🔥 PID: 2198632 | Thread: 139864404193920 | Time: 2025-07-05 18:20:14.578427
Found 3 cave folders in /home/ga53voq/master_thesis/tasks/cave_exploration/environment/caves.
Loading 1 cave environments.
initial_qpos (cave_batch_loader): [-0.234  0.156  0.009  0.339 -0.683 -0.587 -0.283  0.003 -0.    -0.
  0.    -0.    -0.    -0.003 -0.    -0.     0.004 -0.001  0.021]
Floor boxes detected: 4404
initial_qpos (cave_batch_loader): [-0.234  0.156  0.009  0.339 -0.683 -0.587 -0.283  0.003 -0.    -0.
  0.    -0.    -0.    -0.003 -0.    -0.     0.004 -0.001  0.021]
Floor boxes detected: 4404
Found 4 boom end geoms
Found 4405 floor/wall geoms
CaveExplore task initialized with model: ReachbotModelType.BASIC
CaveExplore task action space: 12
Found 4 boom end geoms
Found 4405 floor/wall geoms
CaveExplore task initialized with model: ReachbotModelType.BASIC
CaveExplore task action space: 12
CaveExplore task observation space: {'privileged_state': (134,), 'state': (75,)}
CaveExplore task observation spa

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


Max contacts updated to 500 based on current step.
step: 0/60000000 (0.0%), reward: 0.174 +/- 0.232, time passed (min): 7.91 min, calculated time left (min): unknown min
step: 0/60000000 (0.0%), reward: 0.174 +/- 0.232, time passed (min): 7.91 min, calculated time left (min): unknown min
step: 7372800/60000000 (12.3%), reward: 1.319 +/- 1.154, time passed (min): 22.93 min, calculated time left (min): 163.68 min
step: 7372800/60000000 (12.3%), reward: 1.319 +/- 1.154, time passed (min): 22.93 min, calculated time left (min): 163.68 min
step: 14745600/60000000 (24.6%), reward: 2.416 +/- 2.321, time passed (min): 36.60 min, calculated time left (min): 112.33 min
step: 14745600/60000000 (24.6%), reward: 2.416 +/- 2.321, time passed (min): 36.60 min, calculated time left (min): 112.33 min
step: 22118400/60000000 (36.9%), reward: 4.609 +/- 3.530, time passed (min): 50.27 min, calculated time left (min): 86.09 min
step: 22118400/60000000 (36.9%), reward: 4.609 +/- 3.530, time passed (min): 50

In [None]:
# Video creation from trained model
# Free up training memory before rendering. Names are popped from the notebook
# namespace instead of `del`-ed so that re-running this cell does not raise
# NameError. The TensorBoard writer is closed first so buffered events reach disk.
_writer = globals().pop("writer", None)
if _writer is not None:
    _writer.close()
del _writer
for _name in ("train_fn", "network_factory"):
    globals().pop(_name, None)
gc.collect()

# Use the already loaded model and environment from the training cell
env = trained_env
params = trained_params
make_inference_fn = trained_make_inference_fn
logdir = trained_logdir

# Setup JIT compiled functions for inference
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
inference_fn = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(inference_fn)

print("Setting up rollout for video creation...")

# Rollout parameters
rng = jax.random.PRNGKey(0)
rollout = []
n_episodes = 1
rollout_steps = 2000

# Set command (if needed for environment)
x_vel = 0.2
y_vel = 0.2
yaw_vel = 0.0
command = jp.array([x_vel, y_vel, yaw_vel])


def _apply_command(env_state):
    """Overwrite env_state.info["command"] in place when the state carries one."""
    if hasattr(env_state, 'info') and 'command' in env_state.info:
        env_state.info["command"] = command


# Rollout policy and record simulation
print(f"Running rollout for {n_episodes} episode(s) with {rollout_steps} steps each...")
for episode in range(n_episodes):
    print(f"Episode {episode + 1}/{n_episodes}")
    # Use a dedicated key for the reset instead of reusing `rng`, which is
    # split again below to draw the per-step action keys.
    rng, reset_rng = jax.random.split(rng)
    state = jit_reset(reset_rng)
    # Apply the command before the first step as well, not only after each step.
    _apply_command(state)
    rollout.append(state)

    for i in range(rollout_steps):
        if i % 500 == 0:
            print(f"  Step {i}/{rollout_steps}")

        act_rng, rng = jax.random.split(rng)
        ctrl, _ = jit_inference_fn(state.obs, act_rng)

        # Check for numerical issues (any NaN or +/-Inf in the action)
        if not jp.all(jp.isfinite(ctrl)):
            print(f"Numerical issue detected in control at step {i}. Stopping rollout.")
            break

        state = jit_step(state, ctrl)

        # Set command if the environment supports it
        _apply_command(state)

        rollout.append(state)

print(f"Rollout completed with {len(rollout)} states")

# Render video
print("Rendering video...")
render_every = 1  # Render every frame
width = 1920      # Full HD width
height = 1080     # Full HD height

frames = env.render(rollout[::render_every], camera='track_global', width=width, height=height)
print(f"Rendered {len(frames)} frames")

# Save video. Consecutive rollout states are one env.dt apart and only every
# `render_every`-th state is rendered, so divide the control rate by
# render_every to keep real-time playback.
video_path = os.path.join(logdir, 'posttraining.mp4')
fps = 1.0 / env.dt / render_every

print(f"Saving video to {video_path} at {fps} FPS...")
imageio.mimsave(video_path, frames, fps=fps)
print(f"Video saved successfully to {video_path}")

# Send completion notification. The duration spans from the training cell's
# start to the last progress() call.
run_duration = str(times[-1] - times[0])
result = (
    f"Final reward: {total_rewards[-1]:.3f} ± {total_rewards_std[-1]:.3f}"
    if total_rewards
    else "No rewards recorded."
)

send_message_sync(
    task="Cave Exploration RL Training",
    duration=run_duration,
    result=result,
)

for summary_line in (
    "\n=== TRAINING AND VIDEO CREATION COMPLETE ===",
    f"Log directory: {logdir}",
    f"Video file: {video_path}",
    f"Training duration: {run_duration}",
    f"Final result: {result}",
):
    print(summary_line)

Setting up rollout for video creation...
Running rollout for 1 episode(s) with 2000 steps each...
Episode 1/1
  Step 0/2000
  Step 500/2000
  Step 1000/2000
  Step 1500/2000
Rollout completed with 2001 states
Rendering video...


100%|██████████| 2001/2001 [00:39<00:00, 51.26it/s]
  self.pid = _fork_exec(


Rendered 2001 frames
Saving video to /home/ga53voq/master_thesis/tasks/cave_exploration/logs/cave_exploration-2025-07-05_18-20-14/posttraining.mp4 at 50.0 FPS...
Video saved successfully to /home/ga53voq/master_thesis/tasks/cave_exploration/logs/cave_exploration-2025-07-05_18-20-14/posttraining.mp4

=== TRAINING AND VIDEO CREATION COMPLETE ===
Log directory: /home/ga53voq/master_thesis/tasks/cave_exploration/logs/cave_exploration-2025-07-05_18-20-14
Video file: /home/ga53voq/master_thesis/tasks/cave_exploration/logs/cave_exploration-2025-07-05_18-20-14/posttraining.mp4
Training duration: 2:12:23.700416
Final result: Final reward: 12.860 ± 9.270
