## Load Spot quadruped with arm

In [16]:
# load model
import gymnasium
import numpy as np
from IPython import display
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo

# Keyword arguments for the stock "Ant-v5" task, pointed at the Spot-with-arm scene.
# NOTE(review): the original notes on `ctrl_cost_weight` and `main_body` referred
# to a `Go1` robot, while the XML loaded here is Spot — confirm these values
# (and `healthy_z_range`) are appropriate for this model.
spot_ant_kwargs = dict(
    xml_file="./robots/boston_dynamics_spot/scene_arm.xml",
    # Reward shaping; the first, third and fourth weights were noted as
    # "kept the same as the 'Ant' environment".
    forward_reward_weight=1,
    ctrl_cost_weight=0.05,  # originally annotated as changed for stronger motors — TODO confirm for Spot
    contact_cost_weight=5e-4,
    healthy_reward=1,
    main_body=1,  # presumably the robot's trunk body — verify against the XML
    healthy_z_range=(0.195, 0.95),
    # Observation layout.
    include_cfrc_ext_in_observation=True,
    exclude_current_positions_from_observation=False,
    # Simulation / episode settings.
    reset_noise_scale=0.1,
    frame_skip=25,
    max_episode_steps=1000,
    render_mode="rgb_array",
)
env = gymnasium.make("Ant-v5", **spot_ant_kwargs)

# Record a video for every episode, then track episode statistics on top of that.
env = RecordEpisodeStatistics(
    RecordVideo(env, video_folder="./videos", episode_trigger=lambda _episode_id: True)
)

  logger.warn("Unable to save last video! Did you call close()?")
  logger.warn(


In [6]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo
import mujoco

class LocoManipulationEnv(gym.Env):
    """Minimal Gymnasium environment around the Spot-with-arm MuJoCo scene.

    Observation: ``qpos`` concatenated with ``qvel``, cast to float32.
    Action: written directly to ``data.ctrl``; bounds come from the model's
    ``actuator_ctrlrange``.
    Reward: negative L2 norm of the observation.
    Episodes never terminate or truncate on their own; the caller must bound
    the episode length.
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30
    }

    def __init__(self, render_mode=None):
        """Load the model and build the observation/action spaces.

        Args:
            render_mode: ``"rgb_array"`` for offscreen frames, ``"human"`` for
                an interactive passive viewer, or ``None`` for no rendering.
        """
        super().__init__()
        self.render_mode = render_mode

        # Load MuJoCo model
        self.model = mujoco.MjModel.from_xml_path("./robots/boston_dynamics_spot/scene_arm.xml")
        self.data = mujoco.MjData(self.model)

        # Rendering handles; None means "not opened (yet)" so close() can be
        # called safely at any time and more than once.
        self._renderer = None
        self.viewer = None

        # Setup render context for offscreen rendering
        if render_mode == "rgb_array":
            self.frame_width = 640
            self.frame_height = 480
            self._renderer = mujoco.Renderer(self.model, width=self.frame_width, height=self.frame_height)

        self.nq = self.model.nq
        self.nv = self.model.nv
        self.nu = self.model.nu

        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.nq + self.nv,),
            dtype=np.float32
        )

        # Cast the control range to float32 up front so the bounds already match
        # the declared dtype, instead of handing Box higher-precision arrays.
        ctrl_range = np.asarray(self.model.actuator_ctrlrange, dtype=np.float32)
        self.action_space = spaces.Box(
            low=ctrl_range[:, 0],
            high=ctrl_range[:, 1],
            dtype=np.float32
        )

    def _get_obs(self):
        """Return the current observation: ``[qpos, qvel]`` as float32."""
        return np.concatenate([self.data.qpos, self.data.qvel]).astype(np.float32)

    def reset(self, seed=None, options=None):
        """Reset the simulation state and return ``(observation, info)``."""
        super().reset(seed=seed)
        mujoco.mj_resetData(self.model, self.data)
        # Recompute derived quantities so a render() issued right after reset
        # (before any step) sees the reset state.
        mujoco.mj_forward(self.model, self.data)

        return self._get_obs(), {}

    def step(self, action):
        """Apply ``action`` as controls and advance the simulation one step.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``terminated``
            and ``truncated`` are always ``False`` and ``info`` is empty.
        """
        self.data.ctrl[:] = action
        mujoco.mj_step(self.model, self.data)

        obs = self._get_obs()
        reward = -np.linalg.norm(obs)
        terminated = False
        truncated = False
        return obs, reward, terminated, truncated, {}

    def render(self):
        """Render according to ``render_mode``.

        Returns:
            The rendered frame in ``"rgb_array"`` mode, otherwise ``None``.
        """
        if self.render_mode == "rgb_array":
            self._renderer.update_scene(self.data)
            return self._renderer.render()
        if self.render_mode == "human":
            if self.viewer is None:
                # Imported lazily so the viewer module is only needed for human mode.
                import mujoco.viewer as mujoco_viewer
                self.viewer = mujoco_viewer.launch_passive(self.model, self.data)
            # A passive viewer does not advance on its own; push the latest
            # simulation state to it on every render call.
            self.viewer.sync()
        return None

    def close(self):
        """Release the offscreen renderer and/or viewer, if they were opened."""
        # getattr guards against close() being reached after a failed __init__.
        renderer = getattr(self, "_renderer", None)
        if renderer is not None:
            renderer.close()
            self._renderer = None
        viewer = getattr(self, "viewer", None)
        if viewer is not None:
            viewer.close()
            self.viewer = None
    
# Register the custom environment under a Gymnasium id. Re-running this cell
# registers the id again, which Gymnasium reports as overriding the existing entry.
gym.register(
    id='LocoManipulation-v0',
    entry_point=LocoManipulationEnv,
    kwargs={"render_mode": "rgb_array"}
)

# Use the `gym` alias imported in this cell; the bare `gymnasium` name is only
# bound if a different cell has been executed first.
env = gym.make(
    "LocoManipulation-v0",
    render_mode="rgb_array"
)
# Record a video for every episode, then track episode statistics on top of that.
env = RecordVideo(env, video_folder="./videos", episode_trigger=lambda episode_id: True)
env = RecordEpisodeStatistics(env)

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  gym.logger.warn(
  gym.logger.warn(
  logger.warn(


In [13]:
# Display the lower and upper bounds of the environment's action space.
env.action_space.low, env.action_space.high

(array([-0.785398, -0.898845, -2.7929  , -0.785398, -0.898845, -2.7929  ,
        -0.785398, -0.898845, -2.7929  , -0.785398, -0.898845, -2.7929  ,
        -2.61799 , -3.14159 ,  0.      , -2.79253 , -1.8326  , -2.87979 ,
        -1.57    ], dtype=float32),
 array([ 0.785398,  2.29511 , -0.254402,  0.785398,  2.24363 , -0.255648,
         0.785398,  2.29511 , -0.247067,  0.785398,  2.29511 , -0.248282,
         3.14159 ,  0.523599,  3.14159 ,  2.79253 ,  1.8326  ,  2.87979 ,
         0.      ], dtype=float32))

## Episode loop

In [17]:
N_EPISODES = 4
# Manual per-episode step cap; needed whenever the environment does not report
# `terminated` or `truncated` on its own.
MAX_STEPS_PER_EPISODE = 200

try:
    for ep in range(N_EPISODES):
        print(f"Starting episode {ep + 1}")
        obs, _ = env.reset()
        done = False
        reward_total = 0
        step_count = 0

        # Random-policy rollout: sample an action, step, accumulate reward.
        while not done and step_count < MAX_STEPS_PER_EPISODE:
            action = env.action_space.sample()
            obs, reward, terminated, truncated, _ = env.step(action)
            # NOTE(review): the returned frame is discarded; presumably the
            # RecordVideo wrapper captures frames on its own — confirm whether
            # this explicit call is still needed.
            env.render()
            reward_total += reward
            done = terminated or truncated
            step_count += 1

        print(f"Episode {ep + 1} ended with reward {reward_total:.2f}")
finally:
    # Close the wrapped environment so the video recorder can write out the
    # final episode's recording (otherwise it warns "Unable to save last video!").
    # The environment must be re-created before this cell is run again.
    env.close()

Starting episode 1
Episode 1 ended with reward -0.94
Starting episode 2
Episode 2 ended with reward -182.30
Starting episode 3
Episode 3 ended with reward -3.76
Starting episode 4
Episode 4 ended with reward -184.43


In [None]:
import torch as t
import torch.nn as nn

class Actor(nn.Module):
    """MLP policy that outputs an action distribution with a fixed std.

    The network maps an observation to a ``tanh``-squashed mean and wraps it,
    together with a constant standard deviation, in ``utils.TruncatedNormal``.

    NOTE(review): `utils` (providing `weight_init` and `TruncatedNormal`) is not
    imported in the visible cells — it must be made available before this class
    is instantiated.
    """

    def __init__(self, obs_shape, action_shape, hidden_dim, std=0.1):
        """Build the policy network.

        Args:
            obs_shape: Observation shape; only ``obs_shape[0]`` is used as the
                input width.
            action_shape: Action shape; only ``action_shape[0]`` is used as the
                output width.
            hidden_dim: Width of the two hidden layers.
            std: Constant standard deviation used for every action dimension.
        """
        super().__init__()

        self.std = std
        self.policy = nn.Sequential(nn.Linear(obs_shape[0], hidden_dim),
                                    nn.ReLU(inplace=True),
                                    nn.Linear(hidden_dim, hidden_dim),
                                    nn.ReLU(inplace=True),
                                    nn.Linear(hidden_dim, action_shape[0]))

        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return the action distribution for ``obs``."""
        mu = self.policy(obs)
        # The cell imports torch under the alias `t`; the bare name `torch` is
        # not bound here, so all tensor ops go through `t`.
        mu = t.tanh(mu)
        std = t.ones_like(mu) * self.std

        dist = utils.TruncatedNormal(mu, std)
        return dist


In [10]:
# Scratch check: np.concatenate joins the two lists end to end into one 1-D array.
np.concatenate([[1,2], [3,4]])

array([1, 2, 3, 4])