In [None]:
import os
import time

import mujoco
import mujoco.viewer
import numpy as np
import torch
import yaml


def get_gravity_orientation(quaternion):
    """Express the world gravity direction in the base (body) frame.

    Args:
        quaternion: base orientation ordered (w, x, y, z); only the first four
            entries are read. The closed-form expressions assume a unit quaternion.

    Returns:
        float64 ndarray of shape (3,). For a unit quaternion this is the world
        vector [0, 0, -1] rotated into the body frame, so the identity
        orientation yields [0, 0, -1].
    """
    w, x, y, z = (quaternion[k] for k in range(4))

    # Negated third row of the rotation matrix R(q), i.e. R^T @ [0, 0, -1].
    # The z component uses the unit-norm identity 1 - 2(w^2 + z^2) == 2(x^2 + y^2) - 1.
    return np.array(
        [
            2 * (w * y - x * z),
            -2 * (w * x + y * z),
            1 - 2 * (w * w + z * z),
        ],
        dtype=np.float64,
    )


def pd_control(target_q, q, kp, target_dq, dq, kd):
    """Compute PD torques: kp * (target_q - q) + kd * (target_dq - dq).

    Every argument is combined element-wise, so NumPy arrays broadcast
    against each other and the result has the broadcast shape of the inputs.
    """
    position_error = target_q - q
    velocity_error = target_dq - dq
    return position_error * kp + velocity_error * kd


In [None]:
# Load configuration
# For Jupyter notebook, use current working directory; h12.yaml is expected next
# to the notebook, and {H12_ROOT} in the config resolves to its parent directory.
mujoco_dir = os.getcwd()
h12_root = os.path.dirname(mujoco_dir)

config_file = "h12.yaml"
config_path = os.path.join(mujoco_dir, config_file)

print(f"Loading config from: {config_path}")
print(f"H12 root: {h12_root}")

# The config is plain data, so the safe loader is sufficient.
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

policy_path = config["policy_path"].replace("{H12_ROOT}", h12_root)
xml_path = config["xml_path"].replace("{H12_ROOT}", h12_root)

simulation_duration = config["simulation_duration"]
simulation_dt = config["simulation_dt"]
control_decimation = config["control_decimation"]

# NOTE(review): kp gains are boosted 10% above the configured values -- confirm
# this sim-side tuning is intentional.
kp_gain_scale = 1.1
kps = np.array(config["kps"], dtype=np.float32) * kp_gain_scale
kds = np.array(config["kds"], dtype=np.float32)

legs_motor_pos_lower_limit_list = np.array(config["legs_motor_pos_lower_limit_list"], dtype=np.float32)
legs_motor_pos_upper_limit_list = np.array(config["legs_motor_pos_upper_limit_list"], dtype=np.float32)

default_angles = np.array(config["default_angles"], dtype=np.float32)

ang_vel_scale = config["ang_vel_scale"]
dof_pos_scale = config["dof_pos_scale"]
dof_vel_scale = config["dof_vel_scale"]
action_scale = config["action_scale"]
cmd_scale = np.array(config["cmd_scale"], dtype=np.float32)

cmd = np.array(config["cmd_init"], dtype=np.float32)

print("\nLoaded configuration:")
print(f"  Policy: {policy_path}")
print(f"  XML: {xml_path}")

# The policy drives the 12 leg DOF only.
num_actions = 12
# base_ang_vel(3) + projected_gravity(3) + velocity_commands(3)
# + joint_pos_rel(12) + joint_vel_rel(12) + last_action(12)
num_obs = 3 + 3 + 3 + num_actions + num_actions + num_actions

print(f"  num_actions: {num_actions}, num_obs: {num_obs}")

# Fail early with a clear message: the sim loop would otherwise die later with an
# opaque broadcasting error against the 12 leg joints / 3 command slots.
for _cfg_key, _cfg_arr in (("kps", kps), ("kds", kds), ("default_angles", default_angles)):
    assert _cfg_arr.shape == (num_actions,), (
        f"{_cfg_key} must have {num_actions} entries, got shape {_cfg_arr.shape}"
    )
assert cmd.shape == (3,), f"cmd_init must have 3 entries, got shape {cmd.shape}"

# define context variables
action = np.zeros(num_actions, dtype=np.float32)
target_dof_pos = default_angles.copy()
obs = np.zeros(num_obs, dtype=np.float32)  # NOTE: the sim loop builds obs_single / obs_history instead

counter = 0

# Observation history buffer - 12 DOF only. The policy input is this buffer
# flattened, i.e. history_length * obs_dim values.
history_length = 5
obs_dim = num_obs  # one frame uses the same layout as num_obs
obs_single = np.zeros(obs_dim, dtype=np.float32)
obs_history = np.zeros((history_length, obs_dim), dtype=np.float32)
last_action = np.zeros(num_actions, dtype=np.float32)

# Load robot model
print(f"\nLoading MuJoCo model from: {xml_path}")
m = mujoco.MjModel.from_xml_path(xml_path)
d = mujoco.MjData(m)
m.opt.timestep = simulation_dt

# load policy; map to CPU because the sim loop converts outputs with .numpy(),
# which requires CPU tensors even if the policy was exported from a GPU.
print(f"Loading policy from: {policy_path}")
policy = torch.jit.load(policy_path, map_location="cpu")

print("\nSetup complete!")


In [None]:
# Inspect MuJoCo model structure
print("=" * 80)
print("MUJOCO MODEL STRUCTURE")
print("=" * 80)

print(f"\nTotal bodies: {m.nbody}")
print(f"Total joints: {m.njnt}")
print(f"Total actuators: {m.nu}")
print(f"Total position coordinates (nq): {m.nq}")
print(f"Total generalized velocities / DOFs (nv): {m.nv}")

# mjtJoint values: 0=free, 1=ball, 2=slide, 3=hinge.
JOINT_TYPE_NAMES = {0: "free", 1: "ball", 2: "slide", 3: "hinge"}
# Transmission types for which actuator_trnid[i, 0] is a joint id.
JOINT_TRANSMISSIONS = {int(mujoco.mjtTrn.mjTRN_JOINT), int(mujoco.mjtTrn.mjTRN_JOINTINPARENT)}


def _object_name(model, obj_type, obj_id):
    """Return the full name of a MuJoCo object, or "<unnamed>" if it has none.

    mj_id2name returns the complete name, with no fixed-width slicing of
    model.names, so long joint/actuator names are not truncated.
    """
    return mujoco.mj_id2name(model, obj_type, obj_id) or "<unnamed>"


print("\n" + "=" * 80)
print("JOINTS:")
print("=" * 80)
for i in range(m.njnt):
    joint_name = _object_name(m, mujoco.mjtObj.mjOBJ_JOINT, i)
    joint_type = int(m.jnt_type[i])
    print(f"  {i:2d}: {joint_name:30s} | type: {JOINT_TYPE_NAMES.get(joint_type, joint_type)}")

print("\n" + "=" * 80)
print("ACTUATORS (MOTORS):")
print("=" * 80)
for i in range(m.nu):
    actuator_name = _object_name(m, mujoco.mjtObj.mjOBJ_ACTUATOR, i)
    target_id = int(m.actuator_trnid[i, 0])
    if int(m.actuator_trntype[i]) in JOINT_TRANSMISSIONS:
        target_name = _object_name(m, mujoco.mjtObj.mjOBJ_JOINT, target_id)
    else:
        # Tendon/site/body transmissions: the id does not index the joint table.
        target_name = f"<non-joint transmission, id {target_id}>"
    print(f"  {i:2d}: {actuator_name:30s} → {target_name}")

print("\n" + "=" * 80)
print("POSITION/VELOCITY INDICES:")
print("=" * 80)
print(f"d.qpos shape: {d.qpos.shape} (positions)")
print(f"d.qvel shape: {d.qvel.shape} (velocities)")
print(f"d.ctrl shape: {d.ctrl.shape} (controls/torques)")

# Assumes joint 0 is the floating-base free joint (7 qpos / 6 qvel entries).
print("\nFloating base (qpos[0:7], qvel[0:6]):")
print(f"  Position [0:7]: {d.qpos[0:7]}")  # x, y, z, qw, qx, qy, qz
print(f"  Velocity [0:6]: {d.qvel[0:6]}")  # vx, vy, vz, wx, wy, wz

print(f"\nJoint positions [7:]: {d.qpos[7:]}")
print(f"Joint velocities [6:]: {d.qvel[6:]}")

In [None]:
# # Print joint sequence with qpos/qvel indices
# print("=" * 120)
# print("JOINT SEQUENCE WITH ARRAY INDICES")
# print("=" * 120)

# print("\n⚠️  d.qpos and d.qvel are indexed by JOINT order (XML order), not actuator order!\n")
# print(f"{'Joint':<8} {'Type':<10} {'Name':<35} {'qpos index':<25} {'qvel index':<25}")
# print("-" * 120)

# qpos_idx = 0
# qvel_idx = 0

# for i in range(m.njnt):
#     joint_name = m.names[m.name_jntadr[i]:m.name_jntadr[i]+20].decode().split('\0')[0]
#     joint_type = m.jnt_type[i]
#     type_names = {0: "free", 1: "ball", 2: "slide", 3: "hinge"}
#     type_str = type_names.get(joint_type, str(joint_type))
    
#     if joint_type == 0:  # free joint (7 position dims + 6 velocity dims)
#         qpos_range = f"qpos[{qpos_idx}:{qpos_idx+7}]"
#         qvel_range = f"qvel[{qvel_idx}:{qvel_idx+6}]"
#         print(f"{i:<8} {type_str:<10} {joint_name:<35} {qpos_range:<25} {qvel_range:<25}")
#         qpos_idx += 7
#         qvel_idx += 6
#     elif joint_type == 3:  # hinge joint (1 position dim + 1 velocity dim)
#         qpos_range = f"qpos[{qpos_idx}]"
#         qvel_range = f"qvel[{qvel_idx}]"
#         print(f"{i:<8} {type_str:<10} {joint_name:<35} {qpos_range:<25} {qvel_range:<25}")
#         qpos_idx += 1
#         qvel_idx += 1
#     else:  # other types
#         print(f"{i:<8} {type_str:<10} {joint_name:<35} (multiple dims)          (multiple dims)")

# print("\n" + "=" * 120)
# print("INDEX RANGES SUMMARY:")
# print("=" * 120)
# print(f"Floating base (joint 0):  qpos[0:7]     qvel[0:6]")
# print(f"Leg joints (12 DOF):      qpos[7:19]    qvel[6:18]")
# print(f"Arm joints (15 DOF):      qpos[19:34]   qvel[18:33]")
# print(f"\nTotal qpos: {m.nq} dims")
# print(f"Total qvel: {m.nv} dims")

# print("\n" + "=" * 120)
# print("CURRENT VALUES:")
# print("=" * 120)
# print(f"\nFloating base position (qpos[0:3]):  {d.qpos[0:3]}")
# print(f"Floating base quaternion (qpos[3:7]):  {d.qpos[3:7]}")
# print(f"\nLeg joint positions (qpos[7:19]):")
# for j in range(12):
#     print(f"  [{7+j:2d}] = {d.qpos[7+j]:8.4f}")
    
# print(f"\nArm joint positions (qpos[19:34]):")
# for j in range(15):
#     print(f"  [{19+j:2d}] = {d.qpos[19+j]:8.4f}")


In [None]:
# Closed-loop simulation: PD torques every physics step, policy every
# `control_decimation` physics steps.
#
# Reset simulation and controller state first so re-running this cell starts from
# the same initial conditions instead of continuing the previous (possibly fallen)
# run with a shifted decimation counter and stale observation history.
mujoco.mj_resetData(m, d)
counter = 0
action = np.zeros(num_actions, dtype=np.float32)
last_action = np.zeros(num_actions, dtype=np.float32)
target_dof_pos = default_angles.copy()
obs_history = np.zeros((history_length, obs_dim), dtype=np.float32)

# Observation scales hard-coded to match the IsaacLab training setup.
# NOTE(review): the config's ang_vel_scale / dof_pos_scale / dof_vel_scale /
# cmd_scale are loaded but not applied here -- confirm which set the policy expects.
ang_vel_obs_scale = 0.2   # base_ang_vel
dof_vel_obs_scale = 0.05  # joint_vel_rel

log_every = 100       # physics steps between debug prints
clip_targets = False  # set True to clamp targets to the configured leg joint limits
zero_target_dq = np.zeros_like(kds)  # PD tracks zero joint velocity; allocated once

with mujoco.viewer.launch_passive(m, d) as viewer:
    # Close the viewer automatically after simulation_duration wall-seconds.
    start = time.time()
    step_count = 0
    while viewer.is_running() and time.time() - start < simulation_duration:
        step_start = time.time()

        # PD control for the 12 leg joints; qpos[7:19] / qvel[6:18] skip the
        # 7 position / 6 velocity coordinates of the floating base.
        tau_legs = pd_control(target_dof_pos, d.qpos[7:19], kps, zero_target_dq, d.qvel[6:18], kds)

        d.ctrl[:12] = tau_legs  # Apply leg control (only 12 actuators in XML)

        mujoco.mj_step(m, d)

        counter += 1
        step_count += 1

        # Periodic tracking diagnostics.
        if step_count % log_every == 0:
            print(f"\n--- Step {step_count} ---")
            print(f"Target pos:  {target_dof_pos}")
            print(f"Actual pos:  {d.qpos[7:19]}")
            print(f"Error:       {d.qpos[7:19] - target_dof_pos}")
            print(f"Velocities:  {d.qvel[6:18]}")
            print(f"Torques:     {tau_legs}")

        if counter % control_decimation == 0:
            # Extract state from MuJoCo - 12 DOF legs only
            qj = d.qpos[7:19]    # 12 leg joint positions
            dqj = d.qvel[6:18]   # 12 leg joint velocities
            quat = d.qpos[3:7]   # floating-base quaternion (w, x, y, z)
            omega = d.qvel[3:6]  # floating-base angular velocity

            omega_scaled = omega * ang_vel_obs_scale
            gravity_orientation = get_gravity_orientation(quat)  # projected_gravity, unscaled
            qj_rel = qj - default_angles  # joint_pos_rel, unscaled
            dqj_scaled = dqj * dof_vel_obs_scale
            cmd_scaled = cmd  # velocity_commands fed unscaled (cmd_scale is NOT applied)

            # Build observation: [base_ang_vel(3), projected_gravity(3), velocity_commands(3),
            #                     joint_pos_rel(12), joint_vel_rel(12), last_action(12)]
            obs_single[0:3] = omega_scaled
            obs_single[3:6] = gravity_orientation
            obs_single[6:9] = cmd_scaled
            obs_single[9:21] = qj_rel
            obs_single[21:33] = dqj_scaled
            obs_single[33:45] = last_action

            # Shift history and put the newest frame at row 0.
            # NOTE(review): confirm the training pipeline also stacks newest-first.
            obs_history = np.roll(obs_history, shift=1, axis=0)
            obs_history[0] = obs_single

            # Flatten history for policy input: shape (1, history_length * obs_dim).
            obs_tensor = torch.from_numpy(obs_history.flatten()).unsqueeze(0).float()

            # Policy inference; inference_mode skips autograd bookkeeping.
            with torch.inference_mode():
                action = policy(obs_tensor).numpy().squeeze()
            last_action = action.copy()

            action_scaled = action_scale * action

            # Transform action to target positions around the default pose.
            target_dof_pos = action_scaled + default_angles
            if clip_targets:
                # Constrain target positions within the configured leg joint limits.
                target_dof_pos = np.clip(
                    target_dof_pos, legs_motor_pos_lower_limit_list, legs_motor_pos_upper_limit_list
                )

        # Pick up changes to the physics state, apply perturbations, update options from GUI.
        viewer.sync()

        # Rudimentary time keeping, will drift relative to wall clock.
        time_until_next_step = m.opt.timestep - (time.time() - step_start)
        if time_until_next_step > 0:
            time.sleep(time_until_next_step)
