In [None]:
%env MUJOCO_GL=egl
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.5

import sys, os
import numpy as np
import torch
import torch.utils.dlpack as tpack
import jax
import jax.numpy as jp
from etils import epath
import mujoco
import mediapy

from mujoco_playground import wrapper
from mujoco_playground._src.manipulation.airbot_play.pick import AirbotPlayPickCube, default_config

from discoverse.gaussian_renderer.batch_splat import BatchSplatConfig, BatchSplatRenderer

In [None]:
# --- Build the batched MJX environment (swap in your own env here) ---
num_envs = 64  # number of parallel environments; adjust as needed
env = AirbotPlayPickCube(config=default_config())

# Episode horizon in control steps: 4 divided by the control timestep.
episode_length = int(4 / env._config.ctrl_dt)

rng = jax.random.PRNGKey(0)
env_wrap = wrapper.wrap_for_brax_training(
    env,
    vision=False,
    num_vision_envs=num_envs,
    episode_length=episode_length,
    action_repeat=1,
)

# Reset every environment with its own key, then advance once with an
# all-zero control so `state` holds post-step data.
reset_keys = jax.random.split(rng, num_envs)
zero_ctrl = jp.zeros((num_envs, env.mj_model.nu))
state = env_wrap.step(env_wrap.reset(reset_keys), zero_ctrl)

# Summarise the model and list every body name (useful for the PLY mapping).
print('Bodies:', env.mj_model.nbody, 'Cams:', env.mj_model.ncam)
for body_id in range(env.mj_model.nbody):
    print(env.mj_model.body(body_id).name, end=", ")

In [None]:
# --- Map MuJoCo body names to Gaussian-splat PLY assets ---
# Keys must match the body names in your MJCF; the sample below assumes
# bodies such as 'arm_base' exist in the model.
from discoverse import DISCOVERSE_ASSETS_DIR

AIRBOT_ASSETS_PATH = os.path.join(DISCOVERSE_ASSETS_DIR, "3dgs/manipulator/airbot_play_224")

# Arm bodies whose PLY file is named after the body itself.
_AIRBOT_BODY_NAMES = (
    "arm_base",
    "link1",
    "link2",
    "link3",
    "link4",
    "link5",
    "link6",
    "left",
    "right",
)

body_gaussians = {
    body_name: os.path.join(AIRBOT_ASSETS_PATH, body_name + ".ply")
    for body_name in _AIRBOT_BODY_NAMES
}
# The cube asset lives outside the arm asset folder.
body_gaussians["box"] = os.path.join(DISCOVERSE_ASSETS_DIR, "3dgs/manipulator/green_cube.ply")
# Add more bodies as needed, e.g. a scene attached to the world body:
# body_gaussians["world"] = os.path.join(DISCOVERSE_ASSETS_DIR, "3dgs/scene/lab3/point_cloud.ply")

cfg = BatchSplatConfig(body_gaussians=body_gaussians, background_ply=None)
renderer = BatchSplatRenderer(cfg, mj_model=env.mj_model)

In [None]:
# --- Step 2: batch_update_gaussians ---
# Body poses are handed from JAX to torch through DLPack.
body_pos = tpack.from_dlpack(state.data.xpos)    # (Nenv, Nbody, 3)
body_quat = tpack.from_dlpack(state.data.xquat)  # (Nenv, Nbody, 4, wxyz)
gsb = renderer.batch_update_gaussians(body_pos, body_quat)

# --- Step 3: batch_env_render ---
H, W = 90, 120  # render height and width
cam_pos = tpack.from_dlpack(state.data.cam_xpos)   # (Nenv, Ncam, 3)
cam_xmat = tpack.from_dlpack(state.data.cam_xmat)  # (Nenv, Ncam, 3, 3)

# One vertical field of view per model camera, with a leading axis of size 1.
fovy = np.array(env.mj_model.cam_fovy)[None, :]  # (1, Ncam)

# All-ones background image for every (env, camera) pair.
# NOTE(review): allocated on torch's default device; confirm the renderer
# accepts this alongside the DLPack-imported camera tensors.
bg_shape = (num_envs, fovy.shape[-1], H, W, 3)
bg_img = torch.ones(bg_shape, dtype=torch.float32, requires_grad=False)

rgb, depth = renderer.batch_env_render(gsb, cam_pos, cam_xmat, H, W, fovy, bg_img)

print('RGB:', rgb.shape, 'Depth:', depth.shape)
# Preview environment 0, camera 0.
first_view = rgb[0, 0].cpu().numpy()
mediapy.show_image(first_view)

In [None]:
# --- Load Policy ---
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from mujoco_playground.config import manipulation_params
import functools
from etils import epath

# Checkpoint directory. The default is the original hard-coded location; set
# the AIRBOT_PICK_CKPT environment variable to point somewhere else without
# editing the notebook.
ckpt_path = epath.Path(
    os.environ.get(
        'AIRBOT_PICK_CKPT',
        '/root/code/mujoco_playground/learning/logs/AirbotPlayPickCube-20251209-034846',
    )
)
env_name = 'AirbotPlayPickCube'
ppo_params = manipulation_params.brax_ppo_config(env_name)
# Zero timesteps: presumably makes ppo.train only restore the checkpoint
# instead of training further.
ppo_params.num_timesteps = 0

# Keyword arguments forwarded to ppo.train. The config's `network_factory`
# entry holds kwargs for make_ppo_networks, not a ppo.train argument, so it is
# removed here; leaving it in would collide with the explicit
# `network_factory=` keyword below and raise
# "TypeError: ... got multiple values for keyword argument 'network_factory'".
train_kwargs = dict(ppo_params)

# Setup network factory
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
    train_kwargs.pop("network_factory")
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks,
        **ppo_params.network_factory
    )

# Create train_fn to restore checkpoint
train_fn = functools.partial(
    ppo.train,
    **train_kwargs,
    network_factory=network_factory,
    restore_checkpoint_path=ckpt_path,
    wrap_env_fn=wrapper.wrap_for_brax_training,
)

# Load inference function and params
# Note: We pass the existing 'env' from the notebook
make_inference_fn, params, _ = train_fn(environment=env)
inference_fn = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(inference_fn)
print("Policy loaded successfully.")

In [None]:
# --- Run Batch Rollout & Render ---
import time

# Always split the carried key and keep one half as the new carry, so no key
# is ever consumed twice (the reset key and the first action key used to be
# identical because the same parent key was split two times).
rng = jax.random.PRNGKey(0)
rng, reset_rng = jax.random.split(rng)

# Compile the wrapped step once instead of re-tracing it on every iteration.
jit_step = jax.jit(env_wrap.step)

# Reset env
state = env_wrap.reset(jax.random.split(reset_rng, num_envs))

frames = []
start_time = time.time()

print(f"Simulating {episode_length} steps...")
for i in range(episode_length):
    # Step Policy
    rng, act_rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)

    # Render
    # Update Gaussians
    body_pos = tpack.from_dlpack(state.data.xpos)
    body_quat = tpack.from_dlpack(state.data.xquat)
    gsb = renderer.batch_update_gaussians(body_pos, body_quat)

    # Render Image
    cam_pos = tpack.from_dlpack(state.data.cam_xpos)
    cam_xmat = tpack.from_dlpack(state.data.cam_xmat)
    rgb, depth = renderer.batch_env_render(gsb, cam_pos, cam_xmat, H, W, fovy, bg_img)

    # Save frame (CPU). `rgb` is indexed as (env, camera, H, W, 3) elsewhere in
    # this notebook, while tile_frames expects (N, H, W, 3), so keep camera 0
    # only. detach() is a no-op unless the renderer tracks gradients.
    frames.append(rgb[:, 0].detach().cpu().numpy())

print(f"Rendered {len(frames)} frames in {time.time() - start_time:.2f}s")

# --- Display Video ---
def tile_frames(frame_batch, d):
    """Arrange the first ``d * d`` frames of a batch into a ``d x d`` image grid.

    Args:
        frame_batch: array of shape (N, H, W, C) with ``N >= d * d``. Frames
            beyond the first ``d * d`` are ignored, so a batch whose size is
            not a perfect square can still be tiled.
        d: number of tiles along each side of the grid.

    Returns:
        Array of shape (d * H, d * W, C). Frames are laid out row-major:
        frame ``k`` lands in grid row ``k // d`` and grid column ``k % d``.

    Raises:
        ValueError: if the batch holds fewer than ``d * d`` frames.
    """
    N, H, W, C = frame_batch.shape
    cells = d * d
    if N < cells:
        raise ValueError(
            f"need at least {cells} frames for a {d}x{d} grid, got {N}"
        )
    # (d*d, H, W, C) -> (row, col, H, W, C)
    grid = frame_batch[:cells].reshape(d, d, H, W, C)
    # Interleave grid rows with pixel rows and grid columns with pixel columns.
    grid = grid.transpose(0, 2, 1, 3, 4).reshape(d * H, d * W, C)
    return grid

# Side length of the tile grid: integer part of sqrt(num_envs).
grid_size = int(np.sqrt(num_envs))

# Tile each stored per-step batch into a single image.
video_frames = []
for frame_batch in frames:
    video_frames.append(tile_frames(frame_batch, grid_size))

# One video frame per environment step, hence fps = 1 / env.dt.
mediapy.show_video(video_frames, fps=1.0/env.dt)

In [None]:
# --- Timing loop example (optional) ---
num_run = 10

# One untimed pass before the measurement starts.
gsb = renderer.batch_update_gaussians(body_pos, body_quat)
rgb, depth = renderer.batch_env_render(gsb, cam_pos, cam_xmat, H, W, fovy, bg_img)

# CUDA events are only available (and only used) when a GPU is present;
# without one the loop still runs but nothing is timed.
cuda_ok = torch.cuda.is_available()
if cuda_ok:
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
else:
    start = None
    end = None

for _ in range(num_run):
    gsb = renderer.batch_update_gaussians(body_pos, body_quat)
    # The timed render is called without a background image
    # (unlike the untimed pass above, which passes bg_img).
    rgb, depth = renderer.batch_env_render(gsb, cam_pos, cam_xmat, H, W, fovy)

if cuda_ok:
    end.record()
    torch.cuda.synchronize()
    # elapsed_time is in milliseconds; average over the timed iterations.
    ms = start.elapsed_time(end) / num_run
    print(f'Avg time: {ms:.2f} ms, batch size={body_pos.shape[0]}')
    # Images per second: product of the first two dims of rgb per iteration.
    print(f"total fps = {1e3 * np.prod(rgb.shape[:2]) / ms}, width={W}, height={H}")
print('Done')

def tile(img, d):
    """Tile a batch of ``d * d`` images into one ``d x d`` grid image.

    Args:
        img: array of shape (d*d, H, W, ...); any trailing axes (e.g. colour
            channels) are carried through unchanged.
        d: number of tiles along each side of the grid.

    Returns:
        Array of shape (d*H, d*W, ...), images laid out row-major.

    Raises:
        AssertionError: if the batch size is not exactly ``d * d``.
    """
    assert img.shape[0] == d*d
    img = img.reshape((d,d)+img.shape[1:])
    # np.concatenate instead of np.concat: the latter is only available from
    # NumPy 2.0 onwards, while np.concatenate works on every version.
    # The inner call stacks grid rows along the height axis, the outer call
    # stacks grid columns along the width axis.
    return np.concatenate(np.concatenate(img, axis=1), axis=1)

img_arr = rgb.detach().cpu().numpy()
# Derive the grid side from the actual batch size instead of hard-coding 8,
# which only matched num_envs == 64. Only the first tile_d**2 environments
# are shown so the exact-size check inside tile() holds for any batch size.
tile_d = int(np.sqrt(img_arr.shape[0]))
# Camera 0 of every selected environment.
mediapy.show_image(tile(img_arr[:tile_d * tile_d, 0, ...], tile_d))