In [None]:
# Environment variables are set before the imports below so that the libraries read them
# when they are first imported/initialised.
# MUJOCO_GL=egl selects MuJoCo's EGL (headless) OpenGL backend.
%env MUJOCO_GL=egl
# Cap the share of GPU memory that JAX/XLA preallocates at 50%.
# NOTE(review): presumably this leaves room for the PyTorch-side renderer that shares
# the same GPU — confirm the fraction suits your hardware.
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.5

import sys, os
import time

import numpy as np
import torch
import torch.utils.dlpack as tpack
import jax
import jax.numpy as jp
from etils import epath
import mujoco
import mediapy

from mujoco_playground import wrapper
from mujoco_playground._src.manipulation.airbot_play.pick import AirbotPlayPick, default_config

from examples.gsplat.batch_splat import BatchSplatConfig, BatchSplatRenderer

In [None]:
# --- Load MJX env (replace with your env) ---
# Build the single-environment task from its MJCF description.
xml_path = "/root/code/DISCOVERSE/models/mjcf/manipulator/robot_airbot_play.xml"
env = AirbotPlayPick(
    xml_path=epath.Path(xml_path),
    config=default_config(),
)

# Batch size and episode horizon.
num_envs = 8  # adjust
# Number of control steps in an episode: 4 time units divided by the control period
# (presumably seconds — confirm against the env config).
episode_length = int(4 / env._config.ctrl_dt)

# Wrap the env for batched (Brax-style) stepping; built-in vision is disabled here.
env_wrap = wrapper.wrap_for_brax_training(
    env,
    vision=False,
    num_vision_envs=num_envs,
    episode_length=episode_length,
    action_repeat=1,
)

# Reset every env with its own PRNG key, then advance once with an all-zero action.
rng = jax.random.PRNGKey(0)
reset_keys = jax.random.split(rng, num_envs)
zero_action = jp.zeros((num_envs, env.mj_model.nu))
state = env_wrap.reset(reset_keys)
state = env_wrap.step(state, zero_action)

print('Bodies:', env.mj_model.nbody, 'Cams:', env.mj_model.ncam)

In [None]:
# --- Configure body->PLY mapping ---
# Maps an MJCF body name to the Gaussian-splat PLY file used for that body.
# Update the body names to match your MJCF; the sample entry assumes a body
# named 'arm_base' exists. Add more entries as needed.
body_gaussians = {}
body_gaussians['arm_base'] = '/root/code/DISCOVERSE/models/3dgs/manipulator/7k/arm_base_7000.ply'

# No background splat is loaded (background_ply=None).
cfg = BatchSplatConfig(
    background_ply=None,
    body_gaussians=body_gaussians,
)
renderer = BatchSplatRenderer(cfg, mj_model=env.mj_model)

In [None]:
# --- Step 2: batch_update_gaussians ---
# Hand the JAX body poses over to PyTorch through DLPack.
#   xpos  -> (Nenv, Nbody, 3)
#   xquat -> (Nenv, Nbody, 4, wxyz)
body_pos, body_quat = (
    tpack.from_dlpack(jax_array)
    for jax_array in (state.data.xpos, state.data.xquat)
)

# Build the per-env Gaussian batch from the current body poses.
gsb = renderer.batch_update_gaussians(body_pos, body_quat)
print('Gaussian batch shape:', gsb.xyz.shape)

In [None]:
# --- Step 3: batch_env_render ---
# Camera poses are converted from JAX to PyTorch through DLPack.
#   cam_xpos -> (Nenv, Ncam, 3)
#   cam_xmat -> (Nenv, Ncam, 3, 3)
cam_pos, cam_xmat = (
    tpack.from_dlpack(jax_array)
    for jax_array in (state.data.cam_xpos, state.data.cam_xmat)
)

# Output resolution in pixels.
H, W = 128, 128

# Per-camera vertical field of view from the model, with a leading axis added:
# broadcast to (1, Ncam) or match cam count.
fovy = np.array(env.mj_model.cam_fovy)[np.newaxis, :]

rgb, depth = renderer.batch_env_render(gsb, cam_pos, cam_xmat, H, W, fovy)
print('RGB:', rgb.shape, 'Depth:', depth.shape)

# Preview the first camera of the first environment.
first_frame = rgb[0, 0].cpu().numpy()
mediapy.show_image(first_frame)

In [None]:
# --- Timing loop example (optional) ---
# Reports the average time of one batch_update_gaussians + batch_env_render pass over
# `num_run` repetitions. With CUDA available the interval is measured with CUDA events
# (kernels launch asynchronously, so host timers alone would under-report); otherwise it
# falls back to a wall-clock measurement so a timing is always printed.
num_run = 20
use_cuda = torch.cuda.is_available()  # query once instead of on every line

if use_cuda:
    # Drain pending GPU work so it is not billed to the timed region.
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
else:
    # `start`/`end` stay defined (as None) exactly as before when CUDA is unavailable.
    start = end = None
    wall_start = time.perf_counter()

for _ in range(num_run):
    gsb = renderer.batch_update_gaussians(body_pos, body_quat)
    rgb, depth = renderer.batch_env_render(gsb, cam_pos, cam_xmat, H, W, fovy)

if use_cuda:
    end.record()
    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / num_run  # elapsed_time() returns milliseconds
else:
    ms = (time.perf_counter() - wall_start) * 1000.0 / num_run
print(f'Avg time: {ms:.2f} ms')
print('Done')