<a href="https://colab.research.google.com/github/anirudhgudi/quad-sdk-Unitree_Go2/blob/main/quadruped_go2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
!pip install mujoco




In [13]:
!pip install torch numpy gymnasium mujoco scipy



In [14]:
!

In [15]:
""" # @title
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Go2 Quadruped RL Trainer (MuJoCo + PPO)\n",
    "\n",
    "This notebook will set up a complete environment for training the Unitree Go2 robot to walk using Reinforcement Learning (PPO). \n",
    "\n",
    "**Instructions:**\n",
    "1.  Make sure your runtime is set to use a GPU (Runtime > Change runtime type > T4 GPU) for faster training.\n",
    "2.  Run the cells in order from top to bottom."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Install Dependencies\n",
    "\n",
    "This cell installs all necessary Python libraries for the simulation and RL algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q mujoco gymnasium torch scipy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Get Robot Model and Assets\n",
    "\n",
    "We clone the official `unitree_mujoco` repository. This gives us the `go2.xml` file and, most importantly, the `assets/` folder containing all the `.obj` mesh files the robot model needs to load."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/unitreerobotics/unitree_mujoco"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Create Python Training Files\n",
    "\n",
    "We use the `%%writefile` magic command to create our four Python scripts in the Colab environment's main directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile config.py\n",
    "import numpy as np\n",
    "\n",
    "# --- Simulation ---\n",
    "# This path now points inside the cloned repository\n",
    "XML_PATH = 'unitree_mujoco/unitree_robots/go2/go2.xml'\n",
    "\n",
    "SIM_HZ = 500  # (Hz) Frequency of MuJoCo physics steps\n",
    "CONTROL_HZ = 50  # (Hz) Frequency of RL agent policy decisions\n",
    "FRAME_SKIP = SIM_HZ // CONTROL_HZ # Number of physics steps per agent step\n",
    "MAX_EPISODE_STEPS = 1000  # Max steps before env reset\n",
    "\n",
    "# --- Robot Model ---\n",
    "# These names are taken from the official 'go2.xml'\n",
    "TRUNK_BODY_NAME = \"trunk\"\n",
    "JOINT_NAMES = [\n",
    "    \"FR_hip_joint\", \"FR_thigh_joint\", \"FR_calf_joint\",\n",
    "    \"FL_hip_joint\", \"FL_thigh_joint\", \"FL_calf_joint\",\n",
    "    \"RR_hip_joint\", \"RR_thigh_joint\", \"RR_calf_joint\",\n",
    "    \"RL_hip_joint\", \"RL_thigh_joint\", \"RL_calf_joint\",\n",
    "]\n",
    "# We use the calf bodies to check for contact, as they contain the foot geoms\n",
    "FOOT_BODY_NAMES = [\"FR_calf\", \"FL_calf\", \"RR_calf\", \"RL_calf\"]\n",
    "\n",
    "# --- Locomotion ---\n",
    "TARGET_VELOCITY = 0.8  # (m/s) Target forward velocity (x-axis)\n",
    "TARGET_HEIGHT = 0.3    # (m) Target CoM height\n",
    "CONTACT_FORCE_THRESHOLD = 5.0  # (N) Force threshold to register foot contact\n",
    "\n",
    "# --- RL Reward Weights ---\n",
    "W_VEL_X = 2.0         # Reward for matching target x-velocity\n",
    "W_VEL_Y = -1.0        # Penalty for y-velocity\n",
    "W_VEL_Z = -1.0        # Penalty for z-velocity\n",
    "W_ANG_VEL = -0.1      # Penalty for angular velocity\n",
    "W_COM_HEIGHT = 1.5    # Reward for maintaining target height\n",
    "W_ORIENTATION = -2.0  # Penalty for roll and pitch\n",
    "W_ACTION_RATE = -0.01 # Penalty for jerky actions\n",
    "W_TORQUE = -0.00002   # Penalty for motor effort (torques)\n",
    "W_CONTACT_FORCE = -0.0001 # Penalty for high contact forces\n",
    "W_COM_IN_SUPPORT = 3.0 # Reward for keeping CoM in support polygon\n",
    "W_FALL = -200.0       # Large penalty for falling\n",
    "\n",
    "# --- PPO Training ---\n",
    "PPO_STEPS_PER_EPOCH = 4096\n",
    "PPO_EPOCHS = 500\n",
    "PPO_LEARNING_RATE = 3e-4\n",
    "PPO_MINIBATCH_SIZE = 64\n",
    "PPO_UPDATE_EPOCHS = 10\n",
    "PPO_GAMMA = 0.99\n",
    "PPO_LAM = 0.95\n",
    "PPO_CLIP = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile utils.py\n",
    "import numpy as np\n",
    "import mujoco\n",
    "from scipy.spatial import ConvexHull, Delaunay\n",
    "\n",
    "# --- MuJoCo Model/Data Getters ---\n",
    "\n",
    "def get_body_id(model, body_name):\n",
    "    \"\"\"Returns the MuJoCo ID for a body.\"\"\"\n",
    "    return mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, body_name)\n",
    "\n",
    "def get_joint_qpos_ids(model, joint_names):\n",
    "    \"\"\"Returns qpos indices for a list of joint names.\"\"\"\n",
    "    return [model.jnt_qposadr[mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, name)] for name in joint_names]\n",
    "\n",
    "def get_joint_qvel_ids(model, joint_names):\n",
    "    \"\"\"Returns qvel indices for a list of joint names.\"\"\"\n",
    "    return [model.jnt_dofadr[mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, name)] for name in joint_names]\n",
    "\n",
    "def get_actuator_ids(model, joint_names):\n",
    "    \"\"\"Returns actuator indices for a list of joint names.\"\"\"\n",
    "    # Assumes motor name is joint name + \"_motor\"\n",
    "    return [mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, name.replace(\"_joint\", \"_motor\")) for name in joint_names]\n",
    "\n",
    "def get_foot_body_ids(model, foot_body_names):\n",
    "    \"\"\"Returns body IDs for a list of foot body names.\"\"\"\n",
    "    return [get_body_id(model, name) for name in foot_body_names]\n",
    "\n",
    "# --- Kinematics & Dynamics ---\n",
    "\n",
    "def get_com_position(data, trunk_id):\n",
    "    \"\"\"Returns the 3D position of the trunk (CoM).\"\"\"\n",
    "    return data.xpos[trunk_id]\n",
    "\n",
    "def get_com_velocity(data, trunk_id):\n",
    "    \"\"\"Returns the 3D linear velocity of the trunk (CoM).\"\"\"\n",
    "    # Use cvel (velocities in world frame)\n",
    "    return data.cvel[trunk_id, 3:6]\n",
    "\n",
    "def get_body_orientation(data, trunk_id):\n",
    "    \"\"\"Returns the quaternion orientation of the trunk.\"\"\"\n",
    "    return data.xquat[trunk_id]\n",
    "\n",
    "def get_body_angular_velocity(data, trunk_id):\n",
    "    \"\"\"Returns the 3D angular velocity of the trunk.\"\"\"\n",
    "    # Use cvel (velocities in world frame)\n",
    "    return data.cvel[trunk_id, 0:3]\n",
    "\n",
    "def quat_to_rpy(quat):\n",
    "    \"\"\"Converts a quaternion (w, x, y, z) to roll, pitch, yaw.\"\"\"\n",
    "    w, x, y, z = quat\n",
    "    \n",
    "    # Roll (x-axis rotation)\n",
    "    sinr_cosp = 2 * (w * x + y * z)\n",
    "    cosr_cosp = 1 - 2 * (x * x + y * y)\n",
    "    roll = np.arctan2(sinr_cosp, cosr_cosp)\n",
    "    \n",
    "    # Pitch (y-axis rotation)\n",
    "    sinp = 2 * (w * y - z * x)\n",
    "    if np.abs(sinp) >= 1:\n",
    "        pitch = np.copysign(np.pi / 2, sinp)  # Use 90 degrees if out of range\n",
    "    else:\n",
    "        pitch = np.arcsin(sinp)\n",
    "        \n",
    "    # Yaw (z-axis rotation)\n",
    "    siny_cosp = 2 * (w * z + x * y)\n",
    "    cosy_cosp = 1 - 2 * (y * y + z * z)\n",
    "    yaw = np.arctan2(siny_cosp, cosy_cosp)\n",
    "    \n",
    "    return roll, pitch, yaw\n",
    "\n",
    "# --- Contact & Stability ---\n",
    "\n",
    "def get_foot_contacts(model, data, foot_body_ids, contact_force_threshold):\n",
    "    \"\"\"\n",
    "    Checks for foot contact with the ground.\n",
    "    Returns a boolean array [FR, FL, RR, RL]\n",
    "    \"\"\"\n",
    "    contacts = [False] * 4\n",
    "    for i in range(data.ncon):\n",
    "        contact = data.contact[i]\n",
    "        \n",
    "        # Check if geom1 or geom2 is a foot\n",
    "        geom1_body = model.geom_bodyid[contact.geom1]\n",
    "        geom2_body = model.geom_bodyid[contact.geom2]\n",
    "        \n",
    "        is_geom1_foot = geom1_body in foot_body_ids\n",
    "        is_geom2_foot = geom2_body in foot_body_ids\n",
    "        \n",
    "        if not (is_geom1_foot or is_geom2_foot):\n",
    "            continue # Not a foot contact\n",
    "\n",
    "        # Check if the other geom is the ground (geom ID 0)\n",
    "        is_geom1_ground = contact.geom1 == 0\n",
    "        is_geom2_ground = contact.geom2 == 0\n",
    "\n",
    "        if not (is_geom1_ground or is_geom2_ground):\n",
    "            continue # Not a ground contact\n",
    "            \n",
    "        # Get contact force\n",
    "        force_normal = np.zeros(3)\n",
    "        mujoco.mj_contactForce(model, data, i, force_normal)\n",
    "        \n",
    "        if np.linalg.norm(force_normal) > contact_force_threshold:\n",
    "            if is_geom1_foot:\n",
    "                foot_idx = foot_body_ids.index(geom1_body)\n",
    "                contacts[foot_idx] = True\n",
    "            if is_geom2_foot:\n",
    "                foot_idx = foot_body_ids.index(geom2_body)\n",
    "                contacts[foot_idx] = True\n",
    "                \n",
    "    return np.array(contacts)\n",
    "\n",
    "def get_foot_positions(data, foot_body_ids):\n",
    "    \"\"\"Returns the 3D world positions of the feet (calf bodies).\"\"\"\n",
    "    return data.xpos[foot_body_ids]\n",
    "\n",
    "def get_support_polygon(foot_positions, contact_states):\n",
    "    \"\"\"\n",
    "    Returns the 2D vertices (x, y) of the support polygon.\n",
    "    Returns an empty list if fewer than 2 feet are in contact.\n",
    "    \"\"\"\n",
    "    stance_feet_pos = foot_positions[contact_states, :2] # Get (x, y) of stance feet\n",
    "    \n",
    "    if stance_feet_pos.shape[0] < 2:\n",
    "        return [] # Not enough points to form a polygon\n",
    "        \n",
    "    if stance_feet_pos.shape[0] == 2:\n",
    "        return stance_feet_pos # Support polygon is a line\n",
    "        \n",
    "    try:\n",
    "        # A Bounding Box is simpler and more stable than Convex Hull for 3 points\n",
    "        if stance_feet_pos.shape[0] == 3:\n",
    "             return stance_feet_pos\n",
    "        \n",
    "        hull = ConvexHull(stance_feet_pos)\n",
    "        return stance_feet_pos[hull.vertices]\n",
    "    except Exception:\n",
    "        return [] # Error during hull calculation (e.g., colinear points)\n",
    "\n",
    "def is_com_stable(com_pos_2d, support_polygon):\n",
    "    \"\"\"\n",
    "    Checks if the 2D CoM position is inside the 2D support polygon.\n",
    "    Uses scipy.spatial.Delaunay for robust point-in-polygon check.\n",
    "    \"\"\"\n",
    "    if len(support_polygon) < 3:\n",
    "        # If support is a line (2 feet) or point (1 foot),\n",
    "        # we can't use a polygon check.\n",
    "        # For simplicity, we'll call it \"unstable\"\n",
    "        return False\n",
    "        \n",
    "    try:\n",
    "        # Create a Delaunay triangulation of the support polygon\n",
    "        hull = Delaunay(support_polygon)\n",
    "        \n",
    "        # find_simplex returns -1 if the point is outside the hull\n",
    "        return hull.find_simplex(com_pos_2d) >= 0\n",
    "    except Exception:\n",
    "        # Error (e.g., flat polygon)\n",
    "        return False\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile go2_env.py\n",
    "import mujoco\n",
    "import gymnasium as gym\n",
    "from gymnasium import spaces\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import config\n",
    "import utils\n",
    "\n",
    "class Go2Env(gym.Env):\n",
    "    \"\"\"\n",
    "    Custom Gymnasium environment for the Unitree Go2 robot using MuJoCo.\n",
    "    \"\"\"\n",
    "    metadata = {\"render_modes\": [\"human\", \"rgb_array\"], \"render_fps\": 50}\n",
    "\n",
    "    def __init__(self, render_mode=None):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.render_mode = render_mode\n",
    "        \n",
    "        # Construct the full path to the XML file\n",
    "        # This path is relative to where the script is run (e.g., /content/)\n",
    "        xml_path = config.XML_PATH\n",
    "        \n",
    "        if not os.path.exists(xml_path):\n",
    "            raise FileNotFoundError(\n",
    "                f\"Could not find XML file: {xml_path}. \"\n",
    "                f\"Make sure the 'unitree_mujoco' repo was cloned successfully.\"\n",
    "            )\n",
    "            \n",
    "        self.model = mujoco.MjModel.from_xml_path(xml_path)\n",
    "        self.data = mujoco.MjData(self.model)\n",
    "        \n",
    "        # --- Get element IDs from model ---\n",
    "        self.trunk_id = utils.get_body_id(self.model, config.TRUNK_BODY_NAME)\n",
    "        self.joint_qpos_ids = utils.get_joint_qpos_ids(self.model, config.JOINT_NAMES)\n",
    "        self.joint_qvel_ids = utils.get_joint_qvel_ids(self.model, config.JOINT_NAMES)\n",
    "        self.actuator_ids = utils.get_actuator_ids(self.model, config.JOINT_NAMES)\n",
    "        self.foot_body_ids = utils.get_foot_body_ids(self.model, config.FOOT_BODY_NAMES)\n",
    "        \n",
    "        # Store initial state for resets\n",
    "        self.init_qpos = self.data.qpos.copy()\n",
    "        self.init_qvel = self.data.qvel.copy()\n",
    "        self.action_history = np.zeros(12)\n",
    "        \n",
    "        # --- Define Action Space ---\n",
    "        act_dim = 12\n",
    "        ctrl_range = self.model.actuator_ctrlrange[self.actuator_ids]\n",
    "        self.action_low = ctrl_range[:, 0]\n",
    "        self.action_high = ctrl_range[:, 1]\n",
    "        \n",
    "        self.action_space = spaces.Box(\n",
    "            low=self.action_low,\n",
    "            high=self.action_high,\n",
    "            dtype=np.float32\n",
    "        )\n",
    "        \n",
    "        # --- Define Observation Space ---\n",
    "        obs_dim = 3 + 3 + 4 + 12 + 12 + 4 # com_vel, ang_vel, quat, qpos, qvel, contacts\n",
    "        self.observation_space = spaces.Box(\n",
    "            low=-np.inf,\n",
    "            high=np.inf,\n",
    "            shape=(obs_dim,),\n",
    "            dtype=np.float64\n",
    "        )\n",
    "        \n",
    "        # --- Rendering (Colab) ---\n",
    "        # We'll use 'rgb_array' for Colab. 'human' mode won't work.\n",
    "        if self.render_mode == \"human\":\n",
    "            print(\"Warning: 'human' render_mode not supported in Colab. Use 'rgb_array' instead.\")\n",
    "            self.render_mode = \"rgb_array\"\n",
    "            \n",
    "        if self.render_mode == \"rgb_array\":\n",
    "            self.renderer = mujoco.Renderer(self.model, 480, 640)\n",
    "        else:\n",
    "            self.renderer = None\n",
    "\n",
    "    def _get_obs(self):\n",
    "        \"\"\"Constructs the observation vector from simulation data.\"\"\"\n",
    "        com_vel = utils.get_com_velocity(self.data, self.trunk_id)\n",
    "        ang_vel = utils.get_body_angular_velocity(self.data, self.trunk_id)\n",
    "        quat = utils.get_body_orientation(self.data, self.trunk_id)\n",
    "        qpos = self.data.qpos[self.joint_qpos_ids]\n",
    "        qvel = self.data.qvel[self.joint_qvel_ids]\n",
    "        contacts = utils.get_foot_contacts(\n",
    "            self.model, self.data, self.foot_body_ids, config.CONTACT_FORCE_THRESHOLD\n",
    "        )\n",
    "        \n",
    "        return np.concatenate([\n",
    "            com_vel, ang_vel, quat, qpos, qvel, contacts.astype(float)\n",
    "        ])\n",
    "\n",
    "    def _compute_reward(self, action):\n",
    "        \"\"\"Calculates the reward based on the current state and action.\"\"\"\n",
    "        \n",
    "        # --- Get current state data ---\n",
    "        com_pos = utils.get_com_position(self.data, self.trunk_id)\n",
    "        com_vel = utils.get_com_velocity(self.data, self.trunk_id)\n",
    "        ang_vel = utils.get_body_angular_velocity(self.data, self.trunk_id)\n",
    "        quat = utils.get_body_orientation(self.data, self.trunk_id)\n",
    "        torques = self.data.actuator_force[self.actuator_ids]\n",
    "        \n",
    "        # 1. Velocity Tracking (X, Y, Z)\n",
    "        vel_error_x = (com_vel[0] - config.TARGET_VELOCITY)**2\n",
    "        vel_error_y = com_vel[1]**2\n",
    "        vel_error_z = com_vel[2]**2\n",
    "        \n",
    "        reward_vel_x = config.W_VEL_X * np.exp(-vel_error_x * 5.0)\n",
    "        penalty_vel_y = config.W_VEL_Y * vel_error_y\n",
    "        penalty_vel_z = config.W_VEL_Z * vel_error_z\n",
    "        \n",
    "        # 2. CoM Height\n",
    "        height_error = (com_pos[2] - config.TARGET_HEIGHT)**2\n",
    "        reward_height = config.W_COM_HEIGHT * np.exp(-height_error * 20.0)\n",
    "        \n",
    "        # 3. Orientation\n",
    "        roll, pitch, _ = utils.quat_to_rpy(quat)\n",
    "        orientation_penalty = config.W_ORIENTATION * (roll**2 + pitch**2)\n",
    "        ang_vel_penalty = config.W_ANG_VEL * np.sum(ang_vel**2)\n",
    "        \n",
    "        # 4. Effort / Torque / Action Rate\n",
    "        torque_penalty = config.W_TORQUE * np.sum(torques**2)\n",
    "        action_rate_penalty = config.W_ACTION_RATE * np.sum((action - self.action_history)**2)\n",
    "        self.action_history = action # store for next step\n",
    "\n",
    "        # 5. Stability: CoM within Support Polygon\n",
    "        contacts = utils.get_foot_contacts(\n",
    "            self.model, self.data, self.foot_body_ids, config.CONTACT_FORCE_THRESHOLD\n",
    "        )\n",
    "        foot_positions = utils.get_foot_positions(self.data, self.foot_body_ids)\n",
    "        support_polygon = utils.get_support_polygon(foot_positions, contacts)\n",
    "        \n",
    "        is_stable = utils.is_com_stable(com_pos[:2], support_polygon)\n",
    "        reward_com_stable = config.W_COM_IN_SUPPORT if is_stable else 0.0\n",
    "\n",
    "        # --- Sum Rewards ---\n",
    "        total_reward = (\n",
    "            reward_vel_x + penalty_vel_y + penalty_vel_z +\n",
    "            reward_height + \n",
    "            orientation_penalty + ang_vel_penalty +\n",
    "            torque_penalty + action_rate_penalty +\n",
    "            reward_com_stable\n",
    "        )\n",
    "        \n",
    "        reward_info = {\n",
    "            \"r_vel_x\": reward_vel_x,\n",
    "            \"p_vel_y\": penalty_vel_y,\n",
    "            \"p_vel_z\": penalty_vel_z,\n",
    "            \"r_height\": reward_height,\n",
    "            \"p_orientation\": orientation_penalty,\n",
    "            \"p_ang_vel\": ang_vel_penalty,\n",
    "            \"p_torque\": torque_penalty,\n",
    "            \"p_action_rate\": action_rate_penalty,\n",
    "            \"r_com_stable\": reward_com_stable,\n",
    "        }\n",
    "        \n",
    "        return total_reward, reward_info\n",
    "\n",
    "    def _check_termination(self):\n",
    "        \"\"\"Checks if the episode should terminate (e.g., robot fell).\"\"\"\n",
    "        com_pos = utils.get_com_position(self.data, self.trunk_id)\n",
    "        quat = utils.get_body_orientation(self.data, self.trunk_id)\n",
    "        roll, pitch, _ = utils.quat_to_rpy(quat)\n",
    "        \n",
    "        # Fell if CoM is too low or if roll/pitch is too high\n",
    "        is_fallen = (com_pos[2] < 0.15) or (abs(roll) > 1.0) or (abs(pitch) > 1.0)\n",
    "        \n",
    "        return is_fallen\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Run one timestep of the environment's dynamics.\"\"\"\n",
    "        \n",
    "        # Clip action to be safe\n",
    "        action = np.clip(action, self.action_low, self.action_high)\n",
    "        \n",
    "        # Set the target joint positions as the control signal\n",
    "        self.data.ctrl[self.actuator_ids] = action\n",
    "        \n",
    "        # Step the simulation forward\n",
    "        mujoco.mj_step(self.model, self.data, nstep=config.FRAME_SKIP)\n",
    "        \n",
    "        # Get new state, reward, and termination status\n",
    "        observation = self._get_obs()\n",
    "        terminated = self._check_termination()\n",
    "        reward, reward_info = self._compute_reward(action)\n",
    "        \n",
    "        if terminated:\n",
    "            reward += config.W_FALL # Add large fall penalty\n",
    "\n",
    "        self.step_count += 1\n",
    "        truncated = self.step_count >= config.MAX_EPISODE_STEPS\n",
    "\n",
    "        return observation, reward, terminated, truncated, reward_info\n",
    "\n",
    "    def reset(self, seed=None):\n",
    "        \"\"\"Reset the environment to an initial state.\"\"\"\n",
    "        super().reset(seed=seed)\n",
    "        \n",
    "        self.step_count = 0\n",
    "        self.action_history = np.zeros(12)\n",
    "\n",
    "        mujoco.mj_resetData(self.model, self.data)\n",
    "        self.data.qpos[:] = self.init_qpos\n",
    "        self.data.qvel[:] = self.init_qvel\n",
    "        \n",
    "        # Add small random noise to initial joint positions\n",
    "        self.data.qpos[self.joint_qpos_ids] += self.np_random.uniform(\n",
    "            -0.1, 0.1, size=len(self.joint_qpos_ids)\n",
    "        )\n",
    "        \n",
    "        mujoco.mj_forward(self.model, self.data)\n",
    "        \n",
    "        observation = self._get_obs()\n",
    "        return observation, {}\n",
    "\n",
    "    def render(self):\n",
    "        \"\"\"Render the environment (if in 'rgb_array' mode).\"\"\"\n",
    "        if self.renderer:\n",
    "            self.renderer.update_scene(self.data)\n",
    "            return self.renderer.render()\n",
    "        return None\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"Close the viewer.\"\"\"\n",
    "        if self.renderer:\n",
    "            self.renderer = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile train_ppo.py\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.distributions.normal import Normal\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "\n",
    "from go2_env import Go2Env  # Import the custom environment\n",
    "import config  # Import config file\n",
    "\n",
    "class ActorCritic(nn.Module):\n",
    "    \"\"\"PPO Actor-Critic network.\"\"\"\n",
    "    def __init__(self, obs_dim, action_dim, action_low, action_high):\n",
    "        super().__init__()\n",
    "        \n",
    "        # Store action scaling parameters\n",
    "        self.action_low = torch.tensor(action_low, dtype=torch.float32)\n",
    "        self.action_high = torch.tensor(action_high, dtype=torch.float32)\n",
    "        self.action_scale = (self.action_high - self.action_low) / 2.0\n",
    "        self.action_bias = (self.action_high + self.action_low) / 2.0\n",
    "\n",
    "        hidden_dim = 256\n",
    "        \n",
    "        # Critic network\n",
    "        self.critic = nn.Sequential(\n",
    "            nn.Linear(obs_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, 1)\n",
    "        )\n",
    "        \n",
    "        # Actor network\n",
    "        self.actor = nn.Sequential(\n",
    "            nn.Linear(obs_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, action_dim)\n",
    "        )\n",
    "        \n",
    "        # Standard deviation for the action distribution\n",
    "        self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))\n",
    "\n",
    "    def get_value(self, x):\n",
    "        return self.critic(x)\n",
    "\n",
    "    def get_action_and_value(self, x, action=None):\n",
    "        \"\"\"\n",
    "        Gets an action (and its log_prob) and the state value.\n",
    "        If action is provided, it evaluates that action.\n",
    "        If action is None, it samples a new action.\n",
    "        \"\"\"\n",
    "        # Actor output is the mean of a distribution in unbounded space\n",
    "        action_mean = self.actor(x)\n",
    "        \n",
    "        action_logstd = self.actor_logstd.expand_as(action_mean)\n",
    "        action_std = torch.exp(action_logstd)\n",
    "        \n",
    "        probs = Normal(action_mean, action_std)\n",
    "        \n",
    "        if action is None:\n",
    "            # Sample new action from the unbounded distribution\n",
    "            action_unbounded = probs.sample()\n",
    "            # Squash to [-1, 1] using Tanh\n",
    "            action_tanh = torch.tanh(action_unbounded)\n",
    "            # Scale and shift to the correct action range\n",
    "            action = self.action_bias + self.action_scale * action_tanh\n",
    "        else:\n",
    "            # Evaluate given action\n",
    "            # We need to reverse the scaling to get the \"tanh\" value\n",
    "            action_tanh = (action - self.action_bias) / self.action_scale\n",
    "            # Clip to avoid numerical issues at the bounds\n",
    "            action_tanh = torch.clamp(action_tanh, -0.9999, 0.9999)\n",
    "            # Reverse the Tanh to get the unbounded action\n",
    "            action_unbounded = torch.atanh(action_tanh)\n",
    "            \n",
    "        # Log-prob of the scaled action\n",
    "        log_prob = probs.log_prob(action_unbounded)\n",
    "        log_prob -= torch.log(self.action_scale * (1 - action_tanh.pow(2)) + 1e-6)\n",
    "        log_prob = log_prob.sum(1, keepdim=True)\n",
    "        \n",
    "        entropy = probs.entropy().sum(1)\n",
    "        value = self.critic(x)\n",
    "        \n",
    "        return action, log_prob, entropy, value\n",
    "\n",
    "def main():\n",
    "    # Set device\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(f\"Using device: {device}\")\n",
    "    \n",
    "    print(\"Initializing Go2 RL Environment...\")\n",
    "    env = Go2Env()\n",
    "    \n",
    "    obs_dim = env.observation_space.shape[0]\n",
    "    action_dim = env.action_space.shape[0]\n",
    "    \n",
    "    print(f\"Observation space dim: {obs_dim}\")\n",
    "    print(f\"Action space dim: {action_dim}\")\n",
    "\n",
    "    # --- PPO Agent ---\n",
    "    agent = ActorCritic(obs_dim, action_dim, env.action_low, env.action_high).to(device)\n",
    "    optimizer = optim.Adam(agent.parameters(), lr=config.PPO_LEARNING_RATE, eps=1e-5)\n",
    "    \n",
    "    # --- Storage ---\n",
    "    num_steps = config.PPO_STEPS_PER_EPOCH\n",
    "    \n",
    "    obs = torch.zeros((num_steps, obs_dim)).to(device)\n",
    "    actions = torch.zeros((num_steps, action_dim)).to(device)\n",
    "    logprobs = torch.zeros(num_steps).to(device)\n",
    "    rewards = torch.zeros(num_steps).to(device)\n",
    "    dones = torch.zeros(num_steps).to(device)\n",
    "    values = torch.zeros(num_steps).to(device)\n",
    "\n",
    "    print(\"Starting PPO Training...\")\n",
    "    start_time = time.time()\n",
    "    \n",
    "    next_obs, _ = env.reset()\n",
    "    next_obs = torch.Tensor(next_obs).to(device)\n",
    "    next_done = torch.zeros(1).to(device)\n",
    "    \n",
    "    for epoch in range(config.PPO_EPOCHS):\n",
    "        epoch_rewards = []\n",
    "        epoch_reward_components = {}\n",
    "\n",
    "        for step in range(num_steps):\n",
    "            obs[step] = next_obs\n",
    "            dones[step] = next_done\n",
    "\n",
    "            # Get action and value from agent\n",
    "            with torch.no_grad():\n",
    "                action, logprob, _, value = agent.get_action_and_value(next_obs.unsqueeze(0))\n",
    "                values[step] = value.flatten()\n",
    "            \n",
    "            actions[step] = action.squeeze(0)\n",
    "            logprobs[step] = logprob.squeeze()\n",
    "\n",
    "            # Step the environment\n",
    "            next_obs_np, reward, terminated, truncated, info = env.step(action.cpu().numpy().squeeze(0))\n",
    "            epoch_rewards.append(reward)\n",
    "            \n",
    "            # Log reward components\n",
    "            for key, val in info.items():\n",
    "                if key not in epoch_reward_components:\n",
    "                    epoch_reward_components[key] = []\n",
    "                epoch_reward_components[key].append(val)\n",
    "            \n",
    "            rewards[step] = torch.tensor(reward, device=device).view(-1)\n",
    "            next_obs = torch.Tensor(next_obs_np).to(device)\n",
    "            next_done = torch.tensor(float(terminated or truncated), device=device)\n",
    "            \n",
    "            if next_done:\n",
    "                epoch_rewards = []\n",
    "                epoch_reward_components = {}\n",
    "                next_obs, _ = env.reset()\n",
    "                next_obs = torch.Tensor(next_obs).to(device)\n",
    "\n",
    "        # --- Calculate Advantages (GAE) ---\n",
    "        with torch.no_grad():\n",
    "            next_value = agent.get_value(next_obs.unsqueeze(0)).reshape(1, -1)\n",
    "            advantages = torch.zeros_like(rewards).to(device)\n",
    "            lastgaelam = 0\n",
    "            for t in reversed(range(num_steps)):\n",
    "                if t == num_steps - 1:\n",
    "                    nextnonterminal = 1.0 - next_done\n",
    "                    nextvalues = next_value\n",
    "                else:\n",
    "                    nextnonterminal = 1.0 - dones[t + 1]\n",
    "                    nextvalues = values[t + 1]\n",
    "                \n",
    "                delta = rewards[t] + config.PPO_GAMMA * nextvalues * nextnonterminal - values[t]\n",
    "                advantages[t] = lastgaelam = delta + config.PPO_GAMMA * config.PPO_LAM * nextnonterminal * lastgaelam\n",
    "            returns = advantages + values\n",
    "\n",
    "        # --- Update Policy ---\n",
    "        b_obs = obs.reshape((-1,) + env.observation_space.shape)\n",
    "        b_actions = actions.reshape((-1,) + env.action_space.shape)\n",
    "        b_logprobs = logprobs.reshape(-1)\n",
    "        b_advantages = advantages.reshape(-1)\n",
    "        b_returns = returns.reshape(-1)\n",
    "\n",
    "        # Normalize advantages\n",
    "        b_advantages = (b_advantages - b_advantages.mean()) / (b_advantages.std() + 1e-8)\n",
    "        \n",
    "        b_inds = np.arange(num_steps)\n",
    "        for _ in range(config.PPO_UPDATE_EPOCHS):\n",
    "            np.random.shuffle(b_inds)\n",
    "            for start in range(0, num_steps, config.PPO_MINIBATCH_SIZE):\n",
    "                end = start + config.PPO_MINIBATCH_SIZE\n",
    "                mb_inds = b_inds[start:end]\n",
    "\n",
    "                _, newlogprob, entropy, newvalue = agent.get_action_and_value(\n",
    "                    b_obs[mb_inds], b_actions[mb_inds]\n",
    "                )\n",
    "                logratio = newlogprob.squeeze() - b_logprobs[mb_inds]\n",
    "                ratio = logratio.exp()\n",
    "\n",
    "                # Policy loss\n",
    "                pg_loss1 = -b_advantages[mb_inds] * ratio\n",
    "                pg_loss2 = -b_advantages[mb_inds] * torch.clamp(\n",
    "                    ratio, 1 - config.PPO_CLIP, 1 + config.PPO_CLIP\n",
    "                )\n",
    "                pg_loss = torch.max(pg_loss1, pg_loss2).mean()\n",
    "\n",
    "                # Value loss\n",
    "                v_loss = 0.5 * ((newvalue.squeeze() - b_returns[mb_inds]) ** 2).mean()\n",
    "\n",
    "                # Entropy loss\n",
    "                entropy_loss = entropy.mean()\n",
    "\n",
    "                # Total loss\n",
    "                loss = pg_loss - 0.01 * entropy_loss + v_loss * 0.5\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                nn.utils.clip_grad_norm_(agent.parameters(), 0.5)\n",
    "                optimizer.step()\n",
    "\n",
    "        # --- Logging ---\n",
    "        num_episodes = dones.sum().item()\n",
    "        if num_episodes == 0:\n",
    "            avg_reward = np.nan # Avoid division by zero if no episodes finished\n",
    "        else:\n",
    "            avg_reward = rewards.sum().item() / num_episodes\n",
    "            \n",
    "        print(f\"Epoch {epoch+1}/{config.PPO_EPOCHS} | Avg. Ep Reward: {avg_reward:.2f} | Time: {time.time()-start_time:.2f}s\")\n",
    "        \n",
    "        # Log mean of reward components\n",
    "        for key, val_list in epoch_reward_components.items():\n",
    "            if val_list:\n",
    "                print(f\"  ... avg {key}: {np.mean(val_list):.3f}\")\n",
    "        \n",
    "    env.close()\n",
    "    print(\"Training finished.\")\n",
    "    \n",
    "    # Save the trained policy\n",
    "    model_path = \"ppo_go2_policy.pth\"\n",
    "    torch.save(agent.state_dict(), model_path)\n",
    "    print(f\"Trained policy saved to {model_path}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Run Training\n",
    "\n",
    "This final cell executes the training script. It will print the average reward for each epoch. Training will take a while!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python train_ppo.py"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}

"""



# Go2 Quadruped RL Trainer (MuJoCo + PPO)

This notebook will set up a complete environment for training the Unitree Go2 robot to walk using Reinforcement Learning (PPO).

**Instructions:**
1.  Make sure your runtime is set to use a GPU (Runtime > Change runtime type > T4 GPU) for faster training.
2.  Run the cells in order from top to bottom.

## Step 1: Install Dependencies

This cell installs all necessary Python libraries for the simulation and RL algorithm.

In [16]:
!pip install -q mujoco gymnasium torch scipy

## Step 2: Get Robot Model and Assets

We clone the official `unitree_mujoco` repository. This gives us the `go2.xml` file and, most importantly, the `assets/` folder containing all the `.obj` mesh files the robot model needs to load.

In [17]:
!git clone https://github.com/unitreerobotics/unitree_mujoco

Cloning into 'unitree_mujoco'...
remote: Enumerating objects: 729, done.[K
remote: Counting objects: 100% (235/235), done.[K
remote: Compressing objects: 100% (111/111), done.[K
remote: Total 729 (delta 163), reused 131 (delta 124), pack-reused 494 (from 1)[K
Receiving objects: 100% (729/729), 62.56 MiB | 17.47 MiB/s, done.
Resolving deltas: 100% (264/264), done.
Updating files: 100% (390/390), done.


## Step 3: Create Python Training Files

We use the `%%writefile` magic command to create our four Python scripts in the Colab environment's main directory.

In [18]:
%%writefile config.py
import numpy as np

# --- Simulation ---
# This path now points inside the cloned repository
XML_PATH = 'unitree_mujoco/unitree_robots/go2/go2.xml'

SIM_HZ = 500  # (Hz) Frequency of MuJoCo physics steps
CONTROL_HZ = 50  # (Hz) Frequency of RL agent policy decisions
FRAME_SKIP = SIM_HZ // CONTROL_HZ # Number of physics steps per agent step
MAX_EPISODE_STEPS = 1000  # Max steps before env reset

# --- Robot Model ---
# These names are taken from the official 'go2.xml'
TRUNK_BODY_NAME = "trunk"
JOINT_NAMES = [
    "FR_hip_joint", "FR_thigh_joint", "FR_calf_joint",
    "FL_hip_joint", "FL_thigh_joint", "FL_calf_joint",
    "RR_hip_joint", "RR_thigh_joint", "RR_calf_joint",
    "RL_hip_joint", "RL_thigh_joint", "RL_calf_joint",
]
# We use the calf bodies to check for contact, as they contain the foot geoms
FOOT_BODY_NAMES = ["FR_calf", "FL_calf", "RR_calf", "RL_calf"]

# --- Locomotion ---
TARGET_VELOCITY = 0.8  # (m/s) Target forward velocity (x-axis)
TARGET_HEIGHT = 0.3    # (m) Target CoM height
CONTACT_FORCE_THRESHOLD = 5.0  # (N) Force threshold to register foot contact

# --- RL Reward Weights ---
W_VEL_X = 2.0         # Reward for matching target x-velocity
W_VEL_Y = -1.0        # Penalty for y-velocity
W_VEL_Z = -1.0        # Penalty for z-velocity
W_ANG_VEL = -0.1      # Penalty for angular velocity
W_COM_HEIGHT = 1.5    # Reward for maintaining target height
W_ORIENTATION = -2.0  # Penalty for roll and pitch
W_ACTION_RATE = -0.01 # Penalty for jerky actions
W_TORQUE = -0.00002   # Penalty for motor effort (torques)
W_CONTACT_FORCE = -0.0001 # Penalty for high contact forces
W_COM_IN_SUPPORT = 3.0 # Reward for keeping CoM in support polygon
W_FALL = -200.0       # Large penalty for falling

# --- PPO Training ---
PPO_STEPS_PER_EPOCH = 4096
PPO_EPOCHS = 500
PPO_LEARNING_RATE = 3e-4
PPO_MINIBATCH_SIZE = 64
PPO_UPDATE_EPOCHS = 10
PPO_GAMMA = 0.99
PPO_LAM = 0.95
PPO_CLIP = 0.2

Writing config.py


In [19]:
%%writefile utils.py
import numpy as np
import mujoco
from scipy.spatial import ConvexHull, Delaunay

# --- MuJoCo Model/Data Getters ---

def get_body_id(model, body_name):
    """Returns the MuJoCo ID for a body."""
    return mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, body_name)

def get_joint_qpos_ids(model, joint_names):
    """Returns qpos indices for a list of joint names."""
    return [model.jnt_qposadr[mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, name)] for name in joint_names]

def get_joint_qvel_ids(model, joint_names):
    """Returns qvel indices for a list of joint names."""
    return [model.jnt_dofadr[mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, name)] for name in joint_names]

def get_actuator_ids(model, joint_names):
    """Returns actuator indices for a list of joint names."""
    # Assumes motor name is joint name + "_motor"
    return [mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, name.replace("_joint", "_motor")) for name in joint_names]

def get_foot_body_ids(model, foot_body_names):
    """Returns body IDs for a list of foot body names."""
    return [get_body_id(model, name) for name in foot_body_names]

# --- Kinematics & Dynamics ---

def get_com_position(data, trunk_id):
    """Returns the 3D position of the trunk (CoM)."""
    return data.xpos[trunk_id]

def get_com_velocity(data, trunk_id):
    """Returns the 3D linear velocity of the trunk (CoM)."""
    # Use cvel (velocities in world frame)
    return data.cvel[trunk_id, 3:6]

def get_body_orientation(data, trunk_id):
    """Returns the quaternion orientation of the trunk."""
    return data.xquat[trunk_id]

def get_body_angular_velocity(data, trunk_id):
    """Returns the 3D angular velocity of the trunk."""
    # Use cvel (velocities in world frame)
    return data.cvel[trunk_id, 0:3]

def quat_to_rpy(quat):
    """Converts a quaternion (w, x, y, z) to roll, pitch, yaw."""
    w, x, y, z = quat

    # Roll (x-axis rotation)
    sinr_cosp = 2 * (w * x + y * z)
    cosr_cosp = 1 - 2 * (x * x + y * y)
    roll = np.arctan2(sinr_cosp, cosr_cosp)

    # Pitch (y-axis rotation)
    sinp = 2 * (w * y - z * x)
    if np.abs(sinp) >= 1:
        pitch = np.copysign(np.pi / 2, sinp)  # Use 90 degrees if out of range
    else:
        pitch = np.arcsin(sinp)

    # Yaw (z-axis rotation)
    siny_cosp = 2 * (w * z + x * y)
    cosy_cosp = 1 - 2 * (y * y + z * z)
    yaw = np.arctan2(siny_cosp, cosy_cosp)

    return roll, pitch, yaw

# --- Contact & Stability ---

def get_foot_contacts(model, data, foot_body_ids, contact_force_threshold):
    """
    Checks for foot contact with the ground.
    Returns a boolean array [FR, FL, RR, RL]
    """
    contacts = [False] * 4
    for i in range(data.ncon):
        contact = data.contact[i]

        # Check if geom1 or geom2 is a foot
        geom1_body = model.geom_bodyid[contact.geom1]
        geom2_body = model.geom_bodyid[contact.geom2]

        is_geom1_foot = geom1_body in foot_body_ids
        is_geom2_foot = geom2_body in foot_body_ids

        if not (is_geom1_foot or is_geom2_foot):
            continue # Not a foot contact

        # Check if the other geom is the ground (geom ID 0)
        is_geom1_ground = contact.geom1 == 0
        is_geom2_ground = contact.geom2 == 0

        if not (is_geom1_ground or is_geom2_ground):
            continue # Not a ground contact

        # Get contact force
        force_normal = np.zeros(3)
        mujoco.mj_contactForce(model, data, i, force_normal)

        if np.linalg.norm(force_normal) > contact_force_threshold:
            if is_geom1_foot:
                foot_idx = foot_body_ids.index(geom1_body)
                contacts[foot_idx] = True
            if is_geom2_foot:
                foot_idx = foot_body_ids.index(geom2_body)
                contacts[foot_idx] = True

    return np.array(contacts)

def get_foot_positions(data, foot_body_ids):
    """Returns the 3D world positions of the feet (calf bodies)."""
    return data.xpos[foot_body_ids]

def get_support_polygon(foot_positions, contact_states):
    """
    Returns the 2D vertices (x, y) of the support polygon.
    Returns an empty list if fewer than 2 feet are in contact.
    """
    stance_feet_pos = foot_positions[contact_states, :2] # Get (x, y) of stance feet

    if stance_feet_pos.shape[0] < 2:
        return [] # Not enough points to form a polygon

    if stance_feet_pos.shape[0] == 2:
        return stance_feet_pos # Support polygon is a line

    try:
        # A Bounding Box is simpler and more stable than Convex Hull for 3 points
        if stance_feet_pos.shape[0] == 3:
             return stance_feet_pos

        hull = ConvexHull(stance_feet_pos)
        return stance_feet_pos[hull.vertices]
    except Exception:
        return [] # Error during hull calculation (e.g., colinear points)

def is_com_stable(com_pos_2d, support_polygon):
    """
    Checks if the 2D CoM position is inside the 2D support polygon.
    Uses scipy.spatial.Delaunay for robust point-in-polygon check.
    """
    if len(support_polygon) < 3:
        # If support is a line (2 feet) or point (1 foot),
        # we can't use a polygon check.
        # For simplicity, we'll call it "unstable"
        return False

    try:
        # Create a Delaunay triangulation of the support polygon
        hull = Delaunay(support_polygon)

        # find_simplex returns -1 if the point is outside the hull
        return hull.find_simplex(com_pos_2d) >= 0
    except Exception:
        # Error (e.g., flat polygon)
        return False

Writing utils.py


In [20]:
%%writefile go2_env.py
import mujoco
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import os

import config
import utils

class Go2Env(gym.Env):
    """
    Custom Gymnasium environment for the Unitree Go2 robot using MuJoCo.
    """
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50}

    def __init__(self, render_mode=None):
        super().__init__()

        self.render_mode = render_mode

        # Construct the full path to the XML file
        # This path is relative to where the script is run (e.g., /content/)
        xml_path = config.XML_PATH

        if not os.path.exists(xml_path):
            raise FileNotFoundError(
                f"Could not find XML file: {xml_path}. "
                f"Make sure the 'unitree_mujoco' repo was cloned successfully."
            )

        self.model = mujoco.MjModel.from_xml_path(xml_path)
        self.data = mujoco.MjData(self.model)

        # --- Get element IDs from model ---
        self.trunk_id = utils.get_body_id(self.model, config.TRUNK_BODY_NAME)
        self.joint_qpos_ids = utils.get_joint_qpos_ids(self.model, config.JOINT_NAMES)
        self.joint_qvel_ids = utils.get_joint_qvel_ids(self.model, config.JOINT_NAMES)
        self.actuator_ids = utils.get_actuator_ids(self.model, config.JOINT_NAMES)
        self.foot_body_ids = utils.get_foot_body_ids(self.model, config.FOOT_BODY_NAMES)

        # Store initial state for resets
        self.init_qpos = self.data.qpos.copy()
        self.init_qvel = self.data.qvel.copy()
        self.action_history = np.zeros(12)

        # --- Define Action Space ---
        act_dim = 12
        ctrl_range = self.model.actuator_ctrlrange[self.actuator_ids]
        self.action_low = ctrl_range[:, 0]
        self.action_high = ctrl_range[:, 1]

        self.action_space = spaces.Box(
            low=self.action_low,
            high=self.action_high,
            dtype=np.float32
        )

        # --- Define Observation Space ---
        obs_dim = 3 + 3 + 4 + 12 + 12 + 4 # com_vel, ang_vel, quat, qpos, qvel, contacts
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(obs_dim,),
            dtype=np.float64
        )

        # --- Rendering (Colab) ---
        # We'll use 'rgb_array' for Colab. 'human' mode won't work.
        if self.render_mode == "human":
            print("Warning: 'human' render_mode not supported in Colab. Use 'rgb_array' instead.")
            self.render_mode = "rgb_array"

        if self.render_mode == "rgb_array":
            self.renderer = mujoco.Renderer(self.model, 480, 640)
        else:
            self.renderer = None

    def _get_obs(self):
        """Constructs the observation vector from simulation data."""
        com_vel = utils.get_com_velocity(self.data, self.trunk_id)
        ang_vel = utils.get_body_angular_velocity(self.data, self.trunk_id)
        quat = utils.get_body_orientation(self.data, self.trunk_id)
        qpos = self.data.qpos[self.joint_qpos_ids]
        qvel = self.data.qvel[self.joint_qvel_ids]
        contacts = utils.get_foot_contacts(
            self.model, self.data, self.foot_body_ids, config.CONTACT_FORCE_THRESHOLD
        )

        return np.concatenate([
            com_vel, ang_vel, quat, qpos, qvel, contacts.astype(float)
        ])

    def _compute_reward(self, action):
        """Calculates the reward based on the current state and action."""

        # --- Get current state data ---
        com_pos = utils.get_com_position(self.data, self.trunk_id)
        com_vel = utils.get_com_velocity(self.data, self.trunk_id)
        ang_vel = utils.get_body_angular_velocity(self.data, self.trunk_id)
        quat = utils.get_body_orientation(self.data, self.trunk_id)
        torques = self.data.actuator_force[self.actuator_ids]

        # 1. Velocity Tracking (X, Y, Z)
        vel_error_x = (com_vel[0] - config.TARGET_VELOCITY)**2
        vel_error_y = com_vel[1]**2
        vel_error_z = com_vel[2]**2

        reward_vel_x = config.W_VEL_X * np.exp(-vel_error_x * 5.0)
        penalty_vel_y = config.W_VEL_Y * vel_error_y
        penalty_vel_z = config.W_VEL_Z * vel_error_z

        # 2. CoM Height
        height_error = (com_pos[2] - config.TARGET_HEIGHT)**2
        reward_height = config.W_COM_HEIGHT * np.exp(-height_error * 20.0)

        # 3. Orientation
        roll, pitch, _ = utils.quat_to_rpy(quat)
        orientation_penalty = config.W_ORIENTATION * (roll**2 + pitch**2)
        ang_vel_penalty = config.W_ANG_VEL * np.sum(ang_vel**2)

        # 4. Effort / Torque / Action Rate
        torque_penalty = config.W_TORQUE * np.sum(torques**2)
        action_rate_penalty = config.W_ACTION_RATE * np.sum((action - self.action_history)**2)
        self.action_history = action # store for next step

        # 5. Stability: CoM within Support Polygon
        contacts = utils.get_foot_contacts(
            self.model, self.data, self.foot_body_ids, config.CONTACT_FORCE_THRESHOLD
        )
        foot_positions = utils.get_foot_positions(self.data, self.foot_body_ids)
        support_polygon = utils.get_support_polygon(foot_positions, contacts)

        is_stable = utils.is_com_stable(com_pos[:2], support_polygon)
        reward_com_stable = config.W_COM_IN_SUPPORT if is_stable else 0.0

        # --- Sum Rewards ---
        total_reward = (
            reward_vel_x + penalty_vel_y + penalty_vel_z +
            reward_height +
            orientation_penalty + ang_vel_penalty +
            torque_penalty + action_rate_penalty +
            reward_com_stable
        )

        reward_info = {
            "r_vel_x": reward_vel_x,
            "p_vel_y": penalty_vel_y,
            "p_vel_z": penalty_vel_z,
            "r_height": reward_height,
            "p_orientation": orientation_penalty,
            "p_ang_vel": ang_vel_penalty,
            "p_torque": torque_penalty,
            "p_action_rate": action_rate_penalty,
            "r_com_stable": reward_com_stable,
        }

        return total_reward, reward_info

    def _check_termination(self):
        """Checks if the episode should terminate (e.g., robot fell)."""
        com_pos = utils.get_com_position(self.data, self.trunk_id)
        quat = utils.get_body_orientation(self.data, self.trunk_id)
        roll, pitch, _ = utils.quat_to_rpy(quat)

        # Fell if CoM is too low or if roll/pitch is too high
        is_fallen = (com_pos[2] < 0.15) or (abs(roll) > 1.0) or (abs(pitch) > 1.0)

        return is_fallen

    def step(self, action):
        """Run one timestep of the environment's dynamics."""

        # Clip action to be safe
        action = np.clip(action, self.action_low, self.action_high)

        # Set the target joint positions as the control signal
        self.data.ctrl[self.actuator_ids] = action

        # Step the simulation forward
        mujoco.mj_step(self.model, self.data, nstep=config.FRAME_SKIP)

        # Get new state, reward, and termination status
        observation = self._get_obs()
        terminated = self._check_termination()
        reward, reward_info = self._compute_reward(action)

        if terminated:
            reward += config.W_FALL # Add large fall penalty

        self.step_count += 1
        truncated = self.step_count >= config.MAX_EPISODE_STEPS

        return observation, reward, terminated, truncated, reward_info

    def reset(self, seed=None):
        """Reset the environment to an initial state."""
        super().reset(seed=seed)

        self.step_count = 0
        self.action_history = np.zeros(12)

        mujoco.mj_resetData(self.model, self.data)
        self.data.qpos[:] = self.init_qpos
        self.data.qvel[:] = self.init_qvel

        # Add small random noise to initial joint positions
        self.data.qpos[self.joint_qpos_ids] += self.np_random.uniform(
            -0.1, 0.1, size=len(self.joint_qpos_ids)
        )

        mujoco.mj_forward(self.model, self.data)

        observation = self._get_obs()
        return observation, {}

    def render(self):
        """Render the environment (if in 'rgb_array' mode)."""
        if self.renderer:
            self.renderer.update_scene(self.data)
            return self.renderer.render()
        return None

    def close(self):
        """Close the viewer."""
        if self.renderer:
            self.renderer = None

Writing go2_env.py


In [21]:
%%writefile train_ppo.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
import numpy as np
import time
import os

from go2_env import Go2Env  # Import the custom environment
import config  # Import config file

class ActorCritic(nn.Module):
    """PPO Actor-Critic network."""
    def __init__(self, obs_dim, action_dim, action_low, action_high):
        super().__init__()

        # Store action scaling parameters
        self.action_low = torch.tensor(action_low, dtype=torch.float32)
        self.action_high = torch.tensor(action_high, dtype=torch.float32)
        self.action_scale = (self.action_high - self.action_low) / 2.0
        self.action_bias = (self.action_high + self.action_low) / 2.0

        hidden_dim = 256

        # Critic network
        self.critic = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Actor network
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

        # Standard deviation for the action distribution
        self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """
        Gets an action (and its log_prob) and the state value.
        If action is provided, it evaluates that action.
        If action is None, it samples a new action.
        """
        # Actor output is the mean of a distribution in unbounded space
        action_mean = self.actor(x)

        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)

        probs = Normal(action_mean, action_std)

        if action is None:
            # Sample new action from the unbounded distribution
            action_unbounded = probs.sample()
            # Squash to [-1, 1] using Tanh
            action_tanh = torch.tanh(action_unbounded)
            # Scale and shift to the correct action range
            action = self.action_bias + self.action_scale * action_tanh
        else:
            # Evaluate given action
            # We need to reverse the scaling to get the "tanh" value
            action_tanh = (action - self.action_bias) / self.action_scale
            # Clip to avoid numerical issues at the bounds
            action_tanh = torch.clamp(action_tanh, -0.9999, 0.9999)
            # Reverse the Tanh to get the unbounded action
            action_unbounded = torch.atanh(action_tanh)

        # Log-prob of the scaled action
        log_prob = probs.log_prob(action_unbounded)
        log_prob -= torch.log(self.action_scale * (1 - action_tanh.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)

        entropy = probs.entropy().sum(1)
        value = self.critic(x)

        return action, log_prob, entropy, value

def main():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Initializing Go2 RL Environment...")
    env = Go2Env()

    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    print(f"Observation space dim: {obs_dim}")
    print(f"Action space dim: {action_dim}")

    # --- PPO Agent ---
    agent = ActorCritic(obs_dim, action_dim, env.action_low, env.action_high).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=config.PPO_LEARNING_RATE, eps=1e-5)

    # --- Storage ---
    num_steps = config.PPO_STEPS_PER_EPOCH

    obs = torch.zeros((num_steps, obs_dim)).to(device)
    actions = torch.zeros((num_steps, action_dim)).to(device)
    logprobs = torch.zeros(num_steps).to(device)
    rewards = torch.zeros(num_steps).to(device)
    dones = torch.zeros(num_steps).to(device)
    values = torch.zeros(num_steps).to(device)

    print("Starting PPO Training...")
    start_time = time.time()

    next_obs, _ = env.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(1).to(device)

    for epoch in range(config.PPO_EPOCHS):
        epoch_rewards = []
        epoch_reward_components = {}

        for step in range(num_steps):
            obs[step] = next_obs
            dones[step] = next_done

            # Get action and value from agent
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs.unsqueeze(0))
                values[step] = value.flatten()

            actions[step] = action.squeeze(0)
            logprobs[step] = logprob.squeeze()

            # Step the environment
            next_obs_np, reward, terminated, truncated, info = env.step(action.cpu().numpy().squeeze(0))
            epoch_rewards.append(reward)

            # Log reward components
            for key, val in info.items():
                if key not in epoch_reward_components:
                    epoch_reward_components[key] = []
                epoch_reward_components[key].append(val)

            rewards[step] = torch.tensor(reward, device=device).view(-1)
            next_obs = torch.Tensor(next_obs_np).to(device)
            next_done = torch.tensor(float(terminated or truncated), device=device)

            if next_done:
                epoch_rewards = []
                epoch_reward_components = {}
                next_obs, _ = env.reset()
                next_obs = torch.Tensor(next_obs).to(device)

        # --- Calculate Advantages (GAE) ---
        with torch.no_grad():
            next_value = agent.get_value(next_obs.unsqueeze(0)).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(num_steps)):
                if t == num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]

                delta = rewards[t] + config.PPO_GAMMA * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + config.PPO_GAMMA * config.PPO_LAM * nextnonterminal * lastgaelam
            returns = advantages + values

        # --- Update Policy ---
        b_obs = obs.reshape((-1,) + env.observation_space.shape)
        b_actions = actions.reshape((-1,) + env.action_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)

        # Normalize advantages
        b_advantages = (b_advantages - b_advantages.mean()) / (b_advantages.std() + 1e-8)

        b_inds = np.arange(num_steps)
        for _ in range(config.PPO_UPDATE_EPOCHS):
            np.random.shuffle(b_inds)
            for start in range(0, num_steps, config.PPO_MINIBATCH_SIZE):
                end = start + config.PPO_MINIBATCH_SIZE
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(
                    b_obs[mb_inds], b_actions[mb_inds]
                )
                logratio = newlogprob.squeeze() - b_logprobs[mb_inds]
                ratio = logratio.exp()

                # Policy loss
                pg_loss1 = -b_advantages[mb_inds] * ratio
                pg_loss2 = -b_advantages[mb_inds] * torch.clamp(
                    ratio, 1 - config.PPO_CLIP, 1 + config.PPO_CLIP
                )
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                v_loss = 0.5 * ((newvalue.squeeze() - b_returns[mb_inds]) ** 2).mean()

                # Entropy loss
                entropy_loss = entropy.mean()

                # Total loss
                loss = pg_loss - 0.01 * entropy_loss + v_loss * 0.5

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
                optimizer.step()

        # --- Logging ---
        num_episodes = dones.sum().item()
        if num_episodes == 0:
            avg_reward = np.nan # Avoid division by zero if no episodes finished
        else:
            avg_reward = rewards.sum().item() / num_episodes

        print(f"Epoch {epoch+1}/{config.PPO_EPOCHS} | Avg. Ep Reward: {avg_reward:.2f} | Time: {time.time()-start_time:.2f}s")

        # Log mean of reward components
        for key, val_list in epoch_reward_components.items():
            if val_list:
                print(f"  ... avg {key}: {np.mean(val_list):.3f}")

    env.close()
    print("Training finished.")

    # Save the trained policy
    model_path = "ppo_go2_policy.pth"
    torch.save(agent.state_dict(), model_path)
    print(f"Trained policy saved to {model_path}")

if __name__ == "__main__":
    main()


Writing train_ppo.py


## Step 4: Run Training

This final cell executes the training script. It will print the average reward for each epoch. Training will take a while!

In [22]:
!python train_ppo.py

Using device: cuda
Initializing Go2 RL Environment...
  gym.logger.warn(
  gym.logger.warn(
Observation space dim: 38
Action space dim: 12
Starting PPO Training...
Traceback (most recent call last):
  File "/content/train_ppo.py", line 244, in <module>
    main()
  File "/content/train_ppo.py", line 134, in main
    action, logprob, _, value = agent.get_action_and_value(next_obs.unsqueeze(0))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/train_ppo.py", line 69, in get_action_and_value
    action = self.action_bias + self.action_scale * action_tanh
                                ~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
