In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import gymnasium as gym
from gymnasium import spaces

import collections

from dm_control import mujoco, viewer, suite
from dm_control.rl import control
from dm_control.suite import base, common
from dm_control.suite.utils import randomizers
from dm_control.utils import rewards
from dm_control.utils import io as resources
from dm_env import specs
import numpy as np
import os

from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# To force CPU execution, assign 'cpu' to `device` instead.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [3]:
# Each entry is `(task_name, target_position)`.  The name selects the maze XML
# file (`mazes/point_mass_maze_<task_name>.xml`) and the position is the
# 3-vector written into the `target` geom by `MultiTaskPointMassMaze`.
TASKS = [
    ('reach_top_left',
     np.array([-0.15, 0.15, 0.01])),
    ('reach_top_right',
     np.array([0.15, 0.15, 0.01])),
    ('reach_bottom_left',
     np.array([-0.15, -0.15, 0.01])),
    ('reach_bottom_right',
     np.array([0.15, -0.15, 0.01])),
]


class MultiTaskPointMassMaze(base.Task):
    """A point_mass `Task` to reach one of the `TASKS` targets with smooth reward."""

    def __init__(self, target_id, random=None):
        """Initialize an instance of `MultiTaskPointMassMaze`.

        Args:
          target_id: Index into the module-level `TASKS` list; the position
            stored in that entry becomes this task's target.
          random: Optional, either a `numpy.random.RandomState` instance, an
            integer seed for creating a new `RandomState`, or None to select a
            seed automatically (default).
        """
        self._target = TASKS[target_id][1]
        super().__init__(random=random)

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode.

        Randomizes the joints, then overrides the first two `qpos` entries with
        a start drawn uniformly from [-0.29, -0.15] x [0.15, 0.29], and writes
        this task's target position into the `target` geom.

        Args:
          physics: An instance of `mujoco.Physics`.
        """
        randomizers.randomize_limited_and_rotational_joints(
            physics, self.random)
        # Draw the start position from the task's own random state (not the
        # global `np.random`) so that the `random` seed passed to the
        # constructor makes episode starts reproducible.
        physics.data.qpos[0] = self.random.uniform(-0.29, -0.15)
        physics.data.qpos[1] = self.random.uniform(0.15, 0.29)
        physics.named.data.geom_xpos['target'][:] = self._target

        super().initialize_episode(physics)

    def get_observation(self, physics):
        """Returns an ordered dict with the `position` and `velocity` of the physics."""
        obs = collections.OrderedDict()
        obs['position'] = physics.position()
        obs['velocity'] = physics.velocity()
        return obs

    def get_reward_spec(self):
        """Returns the spec this task declares for its reward: a float32 array of shape (1,)."""
        return specs.Array(shape=(1,), dtype=np.float32, name='reward')

    def get_reward(self, physics):
        """Returns a reward to the agent.

        The reward is the product of a proximity term (distance between the
        point mass and this task's target, with tolerance `target_size`) and a
        control term that mildly favours small controls.
        """
        target_size = .015
        control_reward = rewards.tolerance(physics.control(), margin=1,
                                           value_at_margin=0,
                                           sigmoid='quadratic').mean()
        # Compress the control term so it only scales the reward by a small
        # amount: (c + 4) / 5 maps c = 0 to 0.8 and c = 1 to 1.0.
        small_control = (control_reward + 4) / 5
        near_target = rewards.tolerance(physics.mass_to_target_dist(self._target),
                                        bounds=(0, target_size), margin=target_size)
        reward = near_target * small_control
        return reward



class Physics(mujoco.Physics):
    """Physics for the point_mass domain, extended with a distance helper."""

    def mass_to_target_dist(self, target):
        """Returns the Euclidean norm of `target` minus the `pointmass` geom position.

        Args:
          target: Position to measure against; it must be subtractable from the
            `pointmass` entry of `geom_xpos`.
        """
        mass_xpos = self.named.data.geom_xpos['pointmass']
        offset = target - mass_xpos
        return np.linalg.norm(offset)

In [4]:

def convert_dm_control_to_gym_space(dm_control_space):
    r"""Convert dm_control space to gym space.

    Args:
        dm_control_space: A `specs.BoundedArray`, a `specs.Array`, or a dict
            (possibly nested) whose values are themselves convertible.

    Returns:
        A `spaces.Box` for array specs (bounded by the spec's minimum/maximum
        when it is a `BoundedArray`, unbounded otherwise), or a `spaces.Dict`
        mirroring the keys of a dict input.

    Raises:
        TypeError: If `dm_control_space` is none of the supported types.
    """
    # BoundedArray must be tested first: the plain-Array branch below would
    # otherwise also accept it and drop the bounds.
    if isinstance(dm_control_space, specs.BoundedArray):
        shape = dm_control_space.shape
        # Expand the bounds to the full spec shape so that bounds given in a
        # smaller, broadcastable form still yield a Box of the right shape.
        # np.array(...) makes a writable copy of the read-only broadcast view.
        low = np.array(np.broadcast_to(dm_control_space.minimum, shape))
        high = np.array(np.broadcast_to(dm_control_space.maximum, shape))
        space = spaces.Box(low=low, high=high, dtype=dm_control_space.dtype)
        assert space.shape == dm_control_space.shape
        return space
    if isinstance(dm_control_space, specs.Array):
        return spaces.Box(low=-float('inf'),
                          high=float('inf'),
                          shape=dm_control_space.shape,
                          dtype=dm_control_space.dtype)
    if isinstance(dm_control_space, dict):
        return spaces.Dict({key: convert_dm_control_to_gym_space(value)
                            for key, value in dm_control_space.items()})
    # Fail loudly instead of silently returning None, which would only surface
    # later as a confusing error on the consumer side.
    raise TypeError(
        f'Cannot convert spec of type {type(dm_control_space).__name__} to a gym space')


class DMSuiteEnv(gym.Env):
    """Adapter exposing a dm_control environment through the `gym.Env` interface.

    Observation and action spaces are derived from the wrapped environment's
    specs via `convert_dm_control_to_gym_space`.
    """

    def __init__(self, env):
        """Wraps `env`, an object offering the dm_control environment API used below.

        Args:
            env: The environment to wrap. It must provide `control_timestep()`,
                `observation_spec()`, `action_spec()`, `step()`, `reset()`,
                `close()`, and the attributes `task` and `physics`.
        """
        self.env = env
        self.metadata = {'render.modes': ['human', 'rgb_array'],
                         'video.frames_per_second': round(1.0/self.env.control_timestep())}

        self.observation_space = convert_dm_control_to_gym_space(self.env.observation_spec())
        self.action_space = convert_dm_control_to_gym_space(self.env.action_spec())
        # Created lazily on the first `render(mode='human')` call.
        self.viewer = None

    def seed(self, seed):
        """Re-seeds the wrapped task's random state and returns that call's result."""
        return self.env.task.random.seed(seed)

    def step(self, action):
        """Advances the wrapped environment by one step.

        Returns:
            A tuple `(observation, reward, done, truncated, info)`. `done` is
            whether the returned timestep is the last of the episode;
            `truncated` is always False and `info` is always empty.
        """
        timestep = self.env.step(action)
        observation = timestep.observation
        # A timestep may carry no reward (`None`); report it as 0.0 so callers
        # can always do arithmetic on the returned value.
        reward = 0.0 if timestep.reward is None else timestep.reward
        # NOTE(review): every episode end is reported through `done`, including
        # ends caused by the time limit; callers that need to distinguish
        # truncation from termination cannot do so from these flags.
        done = timestep.last()
        info = {}
        truncated = False
        return observation, reward, done, truncated, info

    def reset(self, seed=None, options=None):
        """Starts a new episode and returns `(observation, info)`.

        Args:
            seed: Optional seed. When given, it seeds both the `gym.Env` base
                class and the wrapped task's random state before resetting.
            options: Accepted for API compatibility; unused.
        """
        if seed is not None:
            super().reset(seed=seed)
            self.seed(seed)
        timestep = self.env.reset()
        return timestep.observation, {}

    def render(self, mode='human', **kwargs):
        """Renders the current physics state.

        Args:
            mode: 'rgb_array' returns the rendered image; 'human' shows it in
                a viewer window and returns whether that window is open.
            **kwargs: Forwarded to `physics.render`. `camera_id` defaults to 0.
                `use_opencv_renderer` (default False) is consumed here and
                selects which viewer class is used in 'human' mode.

        Raises:
            NotImplementedError: For any other `mode`.
        """
        if 'camera_id' not in kwargs:
            kwargs['camera_id'] = 0  # Tracking camera
        use_opencv_renderer = kwargs.pop('use_opencv_renderer', False)

        img = self.env.physics.render(**kwargs)
        if mode == 'rgb_array':
            return img
        elif mode == 'human':
            if self.viewer is None:
                if not use_opencv_renderer:
                    # NOTE(review): this imports from the legacy `gym` package
                    # while the rest of the file uses `gymnasium`; confirm the
                    # installed `gym` still ships this rendering module.
                    from gym.envs.classic_control import rendering
                    self.viewer = rendering.SimpleImageViewer(maxwidth=1024)
                else:
                    # NOTE(review): a relative import only resolves inside a
                    # package; confirm it works where this class is defined.
                    from . import OpenCVImageViewer
                    self.viewer = OpenCVImageViewer()
            self.viewer.imshow(img)
            return self.viewer.isopen
        else:
            raise NotImplementedError

    def close(self):
        """Closes the viewer window, if any, then closes the wrapped environment."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
        return self.env.close()
    
    
class FlattenObservation(gym.ObservationWrapper):
    """Observation wrapper that turns a dict observation into one flat vector.

    The wrapped environment's observation space must expose a `spaces` mapping
    (as `gym.spaces.Dict` does). The wrapper advertises an unbounded float32
    `Box` whose length is the total number of elements across all entries.
    """

    def __init__(self, env):
        super().__init__(env)

        # Flatten the observation space by combining the shapes of each dictionary entry
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.flatten_observation_space_shape(),),
            dtype=np.float32
        )

    def flatten_observation_space_shape(self):
        """Returns the total element count of the wrapped env's observation space as an int."""
        # Read the wrapped env's space rather than `self.observation_space`, so
        # the result is still valid after `__init__` has replaced the latter
        # with the flat Box.
        subspaces = self.env.observation_space.spaces
        # `np.prod(())` is the float 1.0, so force an integer product to keep
        # the Box dimension an int even for scalar-shaped entries.
        return sum(int(np.prod(subspaces[key].shape, dtype=np.int64))
                   for key in subspaces)

    def observation(self, obs):
        """Concatenates every entry of `obs` (in its iteration order) into a 1-D array.

        The result is cast to the dtype declared by `self.observation_space`
        so that returned observations match the advertised space.
        """
        flat = np.concatenate([np.ravel(obs[key]) for key in obs], axis=0)
        return flat.astype(self.observation_space.dtype)
    
def flat_obs(obs):
    """Flattens a mapping of arrays into a single 1-D array.

    Entries are flattened individually and joined in the mapping's iteration
    order.

    Args:
        obs: Mapping from key to an array-like object exposing `.flatten()`.

    Returns:
        A 1-D numpy array containing all elements of all entries.
    """
    pieces = []
    for key in obs:
        pieces.append(obs[key].flatten())
    return np.concatenate(pieces, axis=0)

In [36]:
from types import MethodType

# Build a dm_control environment for one of the TASKS entries.
# `target_id` indexes TASKS: element [0] of the entry names the maze XML file,
# element [1] (used inside MultiTaskPointMassMaze) is the target position.
target_id = 3

# Load the maze model matching the selected task and create the physics from it.
xml = resources.GetResource(f'mazes/point_mass_maze_{TASKS[target_id][0]}.xml')
physics = Physics.from_xml_string(xml, common.ASSETS)
task = MultiTaskPointMassMaze(target_id=target_id)

# Combine physics and task into an environment with a time limit of 20
# (presumably seconds of simulated time — confirm against dm_control docs).
dm_env = control.Environment(
    physics,
    task,
    time_limit=20,
)

def get_reward(self, physics):
    """Returns a dense reward that falls off linearly with distance to the target.

    The reward is `1 - 2 * d`, where `d` is the Euclidean distance between the
    `target` and `pointmass` geom positions read from `physics`.

    Args:
        self: The object this function is bound to; unused.
        physics: Object exposing `named.data.geom_xpos` with `'target'` and
            `'pointmass'` entries.
    """
    # The previous version also read `geom_xpos['target', 0]` (a position
    # coordinate, not a size) into `target_size` and fed it to an unused
    # `rewards.tolerance` call, which is rejected when that coordinate is
    # negative. Neither value affected the returned reward, so both are gone.
    geom_xpos = physics.named.data.geom_xpos
    distance = np.linalg.norm(geom_xpos['target'] - geom_xpos['pointmass'])
    reward = 1 - 2*distance
    return reward

# Replace the task's reward with the dense `get_reward` defined above.
# Bind it to the task instance itself so that `self` inside `get_reward` is
# the task (binding to `dm_env.task.get_reward` would make `self` a bound
# method object instead).
dm_env.task.get_reward = MethodType(get_reward, dm_env.task)

# Switch the maze model to the first TASKS entry by reloading the physics.
# NOTE(review): `task` was constructed with target_id=3, so its `_target`
# still holds TASKS[3]'s position after this reload — confirm that is intended.
target_id = 0
xml = resources.GetResource(f'mazes/point_mass_maze_{TASKS[target_id][0]}.xml')
dm_env.physics.reload_from_xml_string(xml, common.ASSETS)

# Open the dm_control viewer on the patched environment.
viewer.launch(dm_env)

In [44]:
# Re-open the dm_control viewer on the existing `dm_env` (created in an earlier cell).
viewer.launch(dm_env)

In [32]:
# Show which attributes of the named physics data mention 'geom'.
[attr_name
 for attr_name in dm_env.physics.named.data.__dir__()
 if 'geom' in attr_name]

['geom_xmat', 'geom_xpos']

In [45]:
# Check whether the point mass currently counts as having reached the target.
# `target_size` is the same tolerance value (0.015) that
# MultiTaskPointMassMaze.get_reward uses.
target_size = 0.015
# Euclidean distance between the target geom and the point mass geom.
distance = np.linalg.norm(dm_env.physics.named.data.geom_xpos['target'] - dm_env.physics.named.data.geom_xpos['pointmass'])
# Score the distance with `rewards.tolerance`; values above 0.9 are treated as "reached".
goal_reached = rewards.tolerance(distance, bounds=(0, target_size), margin=target_size)
if goal_reached > 0.9:
    print('done')
# Display the score as the cell output.
goal_reached


1.9223131121790506e-48