# In [1]:
import numpy as np
from gym import utils
from gym.envs.mujoco import MujocoEnv
from gym.spaces import Discrete, Box

# In [2]:
class CustomDiscreteManipulatorEnv(MujocoEnv, utils.EzPickle):
    """
    Custom Gym Environment for a 3-DOF robotic arm with a discrete action space
    and an additional "push the button" action for the end-effector.

    Actions are integers in ``[0, 54)``:

    * ``0..26``  -- add row ``action`` of ``_TORQUE_DELTAS`` to the joint
      torques (each joint changes by -0.1, 0 or +0.1).
    * ``27..53`` -- the same torque change (row ``action - 27``), followed by
      an attempt to push the button with the end-effector.

    A push succeeds when the end-effector body (``link_3``) is within
    ``_PUSH_RADIUS`` of the ``target`` body; the episode then ends.
    """

    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
        # NOTE(review): gym's MujocoEnv checks render_fps against 1 / dt with
        # dt = timestep * frame_skip; frame_skip is 2 here, so this presumably
        # relies on a 0.01 s timestep in robotic_arm.xml -- confirm.
        "render_fps": 50,
    }

    # Torque increments for the three joints; row k is applied by action k
    # and by action k + 27. The last joint varies fastest.
    _TORQUE_DELTAS = np.array([
        (-0.1, -0.1, -0.1), (-0.1, -0.1, 0), (-0.1, -0.1, 0.1),
        (-0.1, 0, -0.1), (-0.1, 0, 0), (-0.1, 0, 0.1),
        (-0.1, 0.1, -0.1), (-0.1, 0.1, 0), (-0.1, 0.1, 0.1),
        (0, -0.1, -0.1), (0, -0.1, 0), (0, -0.1, 0.1),
        (0, 0, -0.1), (0, 0, 0), (0, 0, 0.1),
        (0, 0.1, -0.1), (0, 0.1, 0), (0, 0.1, 0.1),
        (0.1, -0.1, -0.1), (0.1, -0.1, 0), (0.1, -0.1, 0.1),
        (0.1, 0, -0.1), (0.1, 0, 0), (0.1, 0, 0.1),
        (0.1, 0.1, -0.1), (0.1, 0.1, 0), (0.1, 0.1, 0.1),
    ])
    _NUM_TORQUE_ACTIONS = len(_TORQUE_DELTAS)  # 27 torque-only actions
    _NUM_ACTIONS = 2 * _NUM_TORQUE_ACTIONS  # each torque action, with or without a push

    # End-effector/target distance below which the arm counts as "at the
    # target" (bonus reward) and a push switches the button on.
    _PUSH_RADIUS = 0.05

    def __init__(self, **kwargs):
        utils.EzPickle.__init__(self, **kwargs)

        # 19 = cos(3) + sin(3) + target xyz (3) + end-effector xyz (3)
        #      + joint velocities (3) + position difference (3) + push flag (1)
        observation_space = Box(low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64)

        # Torques currently applied to the three joints, kept in [-1, 1].
        self.torques = np.zeros(3)

        # Button state (0: off, 1: on) and whether a push was attempted this step.
        self.button_state = 0
        self.end_effector_pushed = 0

        # gym's MujocoEnv.__init__ has no ``action_space`` parameter (passing
        # one raises TypeError); it derives a continuous Box from the model's
        # actuators. The discrete space is therefore installed afterwards.
        MujocoEnv.__init__(
            self, "robotic_arm.xml", 2, observation_space=observation_space, **kwargs
        )
        self.action_space = Discrete(self._NUM_ACTIONS)

    def step(self, action):
        """
        Apply one discrete action and advance the simulation.

        Args:
            action: integer in ``[0, 54)`` (see the class docstring).

        Returns:
            ``(observation, reward, done, info)``; ``info`` holds the current
            ``button_state``.

        Raises:
            ValueError: if ``action`` is outside ``[0, 54)``.
        """
        # NOTE(review): this is the 4-tuple step API. gym>=0.26 wrappers such
        # as TimeLimit expect (obs, reward, terminated, truncated, info).
        action = int(action)
        if not 0 <= action < self._NUM_ACTIONS:
            # Without this check a negative action would silently wrap
            # around the torque table.
            raise ValueError(
                f"action must be in [0, {self._NUM_ACTIONS}), got {action}"
            )

        # The push flag only reflects the current step.
        self.end_effector_pushed = 0

        # Actions >= 27 reuse the torque table and additionally request a push.
        push_requested, delta_index = divmod(action, self._NUM_TORQUE_ACTIONS)
        self.torques = np.clip(
            self.torques + self._TORQUE_DELTAS[delta_index], -1.0, 1.0
        )

        self._simulate(self.torques)

        # The push is evaluated after simulating, so the push and the reward
        # below are judged on the same pose.
        if push_requested:
            self._attempt_push_button()

        ob = self._get_obs()
        reward, done = self._calculate_reward()

        info = {"button_state": self.button_state}
        return ob, reward, done, info

    def _simulate(self, ctrl):
        """
        Advance the simulation ``frame_skip`` frames with ``ctrl`` as control input.

        In gym 0.26, ``MujocoEnv.do_simulation`` rejects ``ctrl`` when its
        shape differs from ``self.action_space.shape``. That shape is ``()``
        for the Discrete space used here, so the underlying stepping routine
        is called directly when the installed version provides it.
        """
        step_fn = getattr(self, "_step_mujoco_simulation", None)
        if step_fn is None:
            self.do_simulation(ctrl, self.frame_skip)
        else:
            step_fn(ctrl, self.frame_skip)

    def _distance_to_target(self):
        """Return the Euclidean distance between the ``link_3`` and ``target`` body COMs."""
        return np.linalg.norm(self.get_body_com("link_3") - self.get_body_com("target"))

    def _attempt_push_button(self):
        """
        Register a push attempt and switch the button on if it is within reach.

        ``end_effector_pushed`` is set for every attempt, so that
        ``_calculate_reward`` can penalise pushes made away from the target.
        ``button_state`` becomes 1 only when the end-effector is within
        ``_PUSH_RADIUS`` of the target; otherwise it is 0.
        """
        self.end_effector_pushed = 1
        self.button_state = int(self._distance_to_target() < self._PUSH_RADIUS)

    def _calculate_reward(self):
        """
        Compute the reward for the current state.

        Returns:
            ``(reward, done)``; ``done`` is True once the button is on.
        """
        distance = self._distance_to_target()
        within_reach = distance < self._PUSH_RADIUS

        # 1. Dense shaping term: the farther from the target, the lower the reward.
        reward = -0.01 * distance

        # 2. Bonus for being within reach of the target.
        if within_reach:
            reward += 1

        # 3. Penalty for a push attempted out of reach.
        if self.end_effector_pushed == 1 and not within_reach:
            reward -= 1

        # 4. Large bonus, and end of episode, once the button has been pushed.
        done = self.button_state == 1
        if done:
            reward += 100

        return reward, done

    def reset_model(self):
        """
        Reset the environment to a slightly perturbed initial state.

        Joint positions are offset by U(-0.1, 0.1) and velocities by
        U(-0.005, 0.005). Torques and button/push state are cleared.

        Returns:
            The initial observation.
        """
        qpos = self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
        qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
        self.set_state(qpos, qvel)

        self.torques = np.zeros(3)
        self.button_state = 0
        self.end_effector_pushed = 0

        return self._get_obs()

    def _get_obs(self):
        """
        Return the current 19-dimensional observation.

        Layout: cos(theta) (3), sin(theta) (3), target position (3),
        end-effector position (3), joint velocities (3),
        end-effector minus target (3), push flag (1).
        """
        # Assumes the first three qpos/qvel entries are the three arm joints
        # -- TODO confirm against robotic_arm.xml.
        theta = self.data.qpos.flat[:3]
        angular_velocities = self.data.qvel.flat[:3]

        end_effector_pos = self.get_body_com("link_3")
        target_pos = self.get_body_com("target")
        position_diff = end_effector_pos - target_pos

        return np.concatenate([
            np.cos(theta),
            np.sin(theta),
            target_pos,
            end_effector_pos,
            angular_velocities,
            position_diff,
            [self.end_effector_pushed],
        ])

    def viewer_setup(self):
        """Make the viewer camera track body 0."""
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = 0
