In [29]:
import collections

from dm_control import mujoco
from dm_control import viewer, suite
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import rewards
from dm_control.utils import io as resources
from dm_env import specs
import numpy as np
import os

from tqdm import tqdm

In [30]:
# Load the environment
env = suite.load(domain_name="point_mass", task_name="easy")

def random_policy(time_step):
    return env.action_spec().minimum + (env.action_spec().maximum - env.action_spec().minimum) * np.random.rand(*env.action_spec().shape)

# Launch the viewer
# viewer.launch(env, policy=random_policy)

In [37]:
TASKS = [('reach_top_left', np.array([-0.15, 0.15, 0.01])),
         ('reach_top_right', np.array([0.15, 0.15, 0.01])),
         ('reach_bottom_left', np.array([-0.15, -0.15, 0.01])),
         ('reach_bottom_right', np.array([0.15, -0.15, 0.01]))]


class MultiTaskPointMassMaze(base.Task):
    """A point_mass `Task` to reach target with smooth reward."""
    def __init__(self, target_id, random=None):
        """Initialize an instance of `PointMassMaze`.

    Args:
      randomize_gains: A `bool`, whether to randomize the actuator gains.
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
        self._target = TASKS[target_id][1]
        super().__init__(random=random)

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode.

       If _randomize_gains is True, the relationship between the controls and
       the joints is randomized, so that each control actuates a random linear
       combination of joints.

    Args:
      physics: An instance of `mujoco.Physics`.
    """
        randomizers.randomize_limited_and_rotational_joints(
            physics, self.random)
        physics.data.qpos[0] = np.random.uniform(-0.29, -0.15)
        physics.data.qpos[1] = np.random.uniform(0.15, 0.29)
        #import ipdb; ipdb.set_trace()
        physics.named.data.geom_xpos['target'][:] = self._target
        

        super().initialize_episode(physics)

    def get_observation(self, physics):
        """Returns an observation of the state."""
        obs = collections.OrderedDict()
        obs['position'] = physics.position()
        obs['velocity'] = physics.velocity()
        return obs
    
    def get_reward_spec(self):
        return specs.Array(shape=(1,), dtype=np.float32, name='reward')

    def get_reward(self, physics):
        """Returns a reward to the agent."""
        target_size = .015
        control_reward = rewards.tolerance(physics.control(), margin=1,
                                       value_at_margin=0,
                                       sigmoid='quadratic').mean()
        small_control = (control_reward + 4) / 5
        near_target = rewards.tolerance(physics.mass_to_target_dist(self._target),
                                bounds=(0, target_size), margin=target_size)
        reward = near_target * small_control
        return reward


In [38]:
class Physics(mujoco.Physics):
    """physics for the point_mass domain."""

    def mass_to_target_dist(self, target):
        """Returns the distance from mass to the target."""
        d = target - self.named.data.geom_xpos['pointmass']
        return np.linalg.norm(d)

In [39]:
import gymnasium as gym
from gymnasium import spaces

from dm_control import suite
from dm_env import specs


def convert_dm_control_to_gym_space(dm_control_space):
    r"""Convert dm_control space to gym space. """
    if isinstance(dm_control_space, specs.BoundedArray):
        space = spaces.Box(low=dm_control_space.minimum, 
                           high=dm_control_space.maximum, 
                           dtype=dm_control_space.dtype)
        assert space.shape == dm_control_space.shape
        return space
    elif isinstance(dm_control_space, specs.Array) and not isinstance(dm_control_space, specs.BoundedArray):
        space = spaces.Box(low=-float('inf'), 
                           high=float('inf'), 
                           shape=dm_control_space.shape, 
                           dtype=dm_control_space.dtype)
        return space
    elif isinstance(dm_control_space, dict):
        space = spaces.Dict({key: convert_dm_control_to_gym_space(value)
                             for key, value in dm_control_space.items()})
        return space


class DMSuiteEnv(gym.Env):
    def __init__(self, env):
        self.env = env
        self.metadata = {'render.modes': ['human', 'rgb_array'],
                         'video.frames_per_second': round(1.0/self.env.control_timestep())}

        self.observation_space = convert_dm_control_to_gym_space(self.env.observation_spec())
        self.action_space = convert_dm_control_to_gym_space(self.env.action_spec())
        self.viewer = None
    
    def seed(self, seed):
        return self.env.task.random.seed(seed)
    
    def step(self, action):
        timestep = self.env.step(action)
        observation = timestep.observation
        reward = timestep.reward
        done = timestep.last()
        info = {}
        truncated = False
        return observation, reward, done, truncated, info
    
    def reset(self, seed=None, options=None):
        timestep = self.env.reset()
        return timestep.observation, {}
    
    def render(self, mode='human', **kwargs):
        if 'camera_id' not in kwargs:
            kwargs['camera_id'] = 0  # Tracking camera
        use_opencv_renderer = kwargs.pop('use_opencv_renderer', False)
        
        img = self.env.physics.render(**kwargs)
        if mode == 'rgb_array':
            return img
        elif mode == 'human':
            if self.viewer is None:
                if not use_opencv_renderer:
                    from gym.envs.classic_control import rendering
                    self.viewer = rendering.SimpleImageViewer(maxwidth=1024)
                else:
                    from . import OpenCVImageViewer
                    self.viewer = OpenCVImageViewer()
            self.viewer.imshow(img)
            return self.viewer.isopen
        else:
            raise NotImplementedError

    def close(self):
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
        return self.env.close()
    
    
class FlattenObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        
        # Flatten the observation space by combining the shapes of each dictionary entry
        self.observation_space = gym.spaces.Box(
            low=-np.inf, 
            high=np.inf, 
            shape=(self.flatten_observation_space_shape(),), 
            dtype=np.float32
        )

    def flatten_observation_space_shape(self):
        # Calculate the total number of elements in the observation space after flattening
        total_shape = 0
        for key in self.observation_space.spaces:
            total_shape += np.prod(self.observation_space.spaces[key].shape)
        return total_shape

    def observation(self, obs):
        # Flatten the dictionary of numpy arrays into a single vector
        return np.concatenate([obs[key].flatten() for key in obs], axis=0)

In [50]:
target_id = 3

xml = resources.GetResource(f'mazes/point_mass_maze_{TASKS[target_id][0]}.xml')
physics = Physics.from_xml_string(xml, common.ASSETS)
task = MultiTaskPointMassMaze(target_id=target_id)

dm_env = control.Environment(
    physics,
    task,
    time_limit=20,
)

viewer.launch(dm_env, policy=random_policy)

In [7]:
def make_env():
    # dm_env = control.Environment(
    #     physics,
    #     task,
    #     time_limit=20,
    # )
    dm_env = suite.load(domain_name="point_mass", task_name="easy")
    env = DMSuiteEnv(dm_env)
    env = FlattenObservation(env)
    env = gym.wrappers.TimeLimit(env, max_episode_steps=20)
    return env

env = gym.vector.SyncVectorEnv([lambda: make_env() for _ in tqdm(range(16))])

100%|██████████| 16/16 [00:00<00:00, 383479.22it/s]


In [136]:
env.reset()

(array([[-0.26691705, -0.18439485,  0.        ,  0.        ],
        [ 0.26258624, -0.01790569,  0.        ,  0.        ],
        [-0.00136795, -0.18059953,  0.        ,  0.        ],
        [ 0.04473701,  0.14295596,  0.        ,  0.        ],
        [-0.16896723, -0.2574089 ,  0.        ,  0.        ],
        [ 0.23174173, -0.16264512,  0.        ,  0.        ],
        [-0.10622237,  0.08820276,  0.        ,  0.        ],
        [ 0.23072541,  0.13541777,  0.        ,  0.        ],
        [ 0.26880175,  0.01268177,  0.        ,  0.        ],
        [ 0.01134793,  0.1385897 ,  0.        ,  0.        ],
        [ 0.14140862, -0.19728428,  0.        ,  0.        ],
        [-0.09924069,  0.11472593,  0.        ,  0.        ],
        [-0.06487781, -0.02408356,  0.        ,  0.        ],
        [-0.02990355,  0.14837378,  0.        ,  0.        ],
        [-0.16277723, -0.12779978,  0.        ,  0.        ],
        [ 0.15549557,  0.23205476,  0.        ,  0.        ]],
       

In [140]:
for _ in tqdm(range(1000)):
    action = env.action_space.sample()
    next_state, reward, done, truncated, info = env.step(action)
    # print(next_state.mean(axis=0), next_state.std(axis=0))

100%|██████████| 1000/1000 [00:00<00:00, 1186.40it/s]


In [132]:
next_state.min(axis=0), next_state.max(axis=0)

(array([-0.287916  , -0.2885625 , -0.02364   , -0.02546857], dtype=float32),
 array([0.2897761 , 0.2862624 , 0.02604527, 0.02253521], dtype=float32))

In [53]:
env.flatten_observation_space_shape

AttributeError: 'SyncVectorEnv' object has no attribute 'flatten_observation_space_shape'