-
Notifications
You must be signed in to change notification settings - Fork 18
/
dm_control_compatibility.py
231 lines (190 loc) · 8.95 KB
/
dm_control_compatibility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""Wrapper to convert a dm_env environment into a gymnasium compatible environment.
Taken from
https://github.com/ikostrikov/dmcgym/blob/main/dmcgym/env.py
and modified to the modern Gymnasium API.
"""
from __future__ import annotations
import math
from enum import Enum
from typing import Any, Callable, Optional
import dm_env
import gymnasium
import numpy as np
from dm_control import composer
from dm_control.mujoco.engine import Physics as MujocoEnginePhysics
from dm_control.rl import control
from gymnasium.core import ObsType
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
from gymnasium.utils import EzPickle
from mujoco._structs import MjvScene
from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space
class EnvType(Enum):
    """The environment type.

    Discriminates which dm-control base ``Environment`` class the wrapped
    env derives from, since the two differ in implementation details
    (e.g. where the random state lives).
    """

    COMPOSER = 0  # `dm_control.composer.Environment` (locomotion/manipulation envs)
    RL_CONTROL = 1  # `dm_control.rl.control.Environment` (`dm_control.suite` envs)
class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray], EzPickle):
    """This compatibility wrapper converts a dm-control environment into a gymnasium environment.

    Dm-control is DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo physics.

    Dm-control actually has two Environment classes, `dm_control.composer.Environment` and
    `dm_control.rl.control.Environment` that, while both inheriting from `dm_env.Environment`, differ
    in implementation.

    Environments in `dm_control.suite` are `dm_control.rl.control.Environment`, while
    dm-control locomotion and manipulation environments use `dm_control.composer.Environment`.
    This wrapper supports both Environment classes by determining the base environment type.

    Note:
        dm-control uses `np.random.RandomState`, a legacy random number generator, while gymnasium
        uses `np.random.Generator`, therefore the return type of `np_random` is different from expected.
    """

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array", "multi_camera"],
        "render_fps": 10,  # placeholder; overwritten per-instance in `__init__`
    }

    def __init__(
        self,
        env: composer.Environment | control.Environment | dm_env.Environment,
        render_mode: str | None = None,
        render_kwargs: dict[str, Any] | None = None,
    ):
        """Initialises the environment with a render mode along with render information.

        Note: this wrapper supports multi-camera rendering via the `render_mode` argument (render_mode = "multi_camera")

        For more information on DM Control rendering, see https://github.com/deepmind/dm_control/blob/main/dm_control/mujoco/engine.py#L178

        Args:
            env (Optional[composer.Environment | control.Environment | dm_env.Environment]): DM Control env to wrap
            render_mode (Optional[str]): rendering mode (options: "human", "rgb_array", "depth_array", "multi_camera")
            render_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for rendering.
                For the width, height and camera id use "width", "height" and "camera_id" respectively.
                See the dm_control implementation for the list of possible kwargs, https://github.com/deepmind/dm_control/blob/330c91f41a21eacadcf8316f0a071327e3f5c017/dm_control/mujoco/engine.py#L178
                Note: kwargs are not used for human rendering, which uses simpler Gymnasium MuJoCo rendering.
        """
        EzPickle.__init__(self, env, render_mode, render_kwargs)
        self._env: Any = env
        self.env_type = self._find_env_type(env)

        # Shadow the class-level dict with an instance copy so that writing
        # `render_fps` below does not mutate `metadata` for every other
        # instance of this class.
        self.metadata = dict(self.metadata)
        # NOTE(review): `control_timestep()` is in seconds, so this value is
        # milliseconds per control step rather than a true frames-per-second;
        # kept unchanged for backwards compatibility — confirm intent upstream.
        self.metadata["render_fps"] = self._env.control_timestep() * 1000

        self.observation_space = dm_spec2gym_space(env.observation_spec())
        self.action_space = dm_spec2gym_space(env.action_spec())

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
        if render_kwargs is None:
            render_kwargs = {}
        self.render_kwargs = render_kwargs

        if self.render_mode == "human":
            # We use the gymnasium mujoco rendering, dm-control provides more complex rendering options.
            self.viewer = MujocoRenderer(
                self._env.physics.model.ptr, self._env.physics.data.ptr
            )

    @property
    def dt(self):
        """Returns the environment control timestep which is equivalent to the number of actions per second."""
        return self._env.control_timestep()

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the dm-control environment.

        Args:
            seed: optional seed for the wrapped env's legacy `RandomState`
            options: unused, present for Gymnasium API compatibility

        Returns:
            The initial observation and the info dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            # dm-control consumes a legacy `RandomState`, not a `Generator`.
            self.np_random = np.random.RandomState(seed=seed)

        timestep = self._env.reset()
        # Only the observation and info are meaningful on reset; reward and
        # termination flags from the first timestep are discarded.
        obs, _, _, _, info = dm_env_step2gym_step(timestep)

        if self.render_mode == "human":
            # Recreate the viewer so the window reflects the freshly reset state.
            self.viewer.close()
            self.viewer = MujocoRenderer(
                self._env.physics.model.ptr, self._env.physics.data.ptr
            )
        return obs, info

    def step(
        self, action: np.ndarray
    ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
        """Steps through the dm-control environment.

        Args:
            action: the action to apply for this control step

        Returns:
            The standard Gymnasium `(obs, reward, terminated, truncated, info)` tuple.
        """
        timestep = self._env.step(action)
        obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep)

        if self.render_mode == "human":
            self.viewer.render(self.render_mode)
        return obs, reward, terminated, truncated, info

    def render(self) -> np.ndarray | None:
        """Renders the dm-control env.

        Returns:
            An RGB frame ("rgb_array"), a depth frame ("depth_array"), a tiled
            frame of all model cameras ("multi_camera"), or None otherwise
            ("human" rendering happens in `step`/`reset` instead).
        """
        if self.render_mode == "rgb_array":
            return self._env.physics.render(
                **self.render_kwargs,
            )
        elif self.render_mode == "depth_array":
            return self._env.physics.render(
                depth=True,
                **self.render_kwargs,
            )
        elif self.render_mode == "multi_camera":
            physics = self._env.physics
            num_cameras = physics.model.ncam
            # Lay the cameras out on a near-square grid.
            num_columns = int(math.ceil(math.sqrt(num_cameras)))
            num_rows = int(math.ceil(float(num_cameras) / num_columns))
            # 240 and 320 are the default values in dm-control
            height = self.render_kwargs.get("height", 240)
            width = self.render_kwargs.get("width", 320)
            frame = np.zeros(
                (num_rows * height, num_columns * width, 3),
                dtype=np.uint8,
            )
            assert (
                "camera_id" not in self.render_kwargs
            ), "The camera_id is specified in `multi_camera` render so don't include it in the render_kwargs"
            for col in range(num_columns):
                for row in range(num_rows):
                    camera_id = row * num_columns + col
                    # Cells past the last camera stay black (zero-filled).
                    if camera_id >= num_cameras:
                        break
                    subframe = physics.render(
                        camera_id=camera_id,
                        **self.render_kwargs,
                    )
                    frame[
                        row * height : (row + 1) * height,
                        col * width : (col + 1) * width,
                    ] = subframe
            return frame

    def close(self):
        """Closes the environment."""
        self._env.physics.free()
        self._env.close()

        if hasattr(self, "viewer"):
            self.viewer.close()

    @property
    def np_random(self) -> np.random.RandomState:
        """This should be np.random.Generator but dm-control uses np.random.RandomState."""
        # The two dm-control Environment classes store their random state in
        # different places, hence the env_type dispatch.
        if self.env_type is EnvType.RL_CONTROL:
            return self._env.task._random
        else:
            return self._env._random_state

    @np_random.setter
    def np_random(self, value: np.random.RandomState):
        if self.env_type is EnvType.RL_CONTROL:
            self._env.task._random = value
        else:
            self._env._random_state = value

    def __getattr__(self, item: str):
        """If the attribute is missing, try getting the attribute from dm_control env."""
        # Guard against infinite recursion: if `_env` itself is missing (e.g.
        # during unpickling, before `__init__` has run), looking it up on self
        # would re-enter __getattr__ forever.
        if item == "_env":
            raise AttributeError(
                f"{type(self).__name__} instance has no attribute '_env'"
            )
        return getattr(self._env, item)

    def _find_env_type(self, env) -> EnvType:
        """Tries to discover env types, in particular for environments with wrappers.

        Recursively unwraps `_env`/`env` attributes until a known dm-control
        Environment class is found.

        Raises:
            AttributeError: if no known dm-control Environment type is found.
        """
        if isinstance(env, composer.Environment):
            return EnvType.COMPOSER
        elif isinstance(env, control.Environment):
            return EnvType.RL_CONTROL
        else:
            assert isinstance(env, dm_env.Environment)

            if hasattr(env, "_env"):
                return self._find_env_type(
                    env._env  # pyright: ignore[reportGeneralTypeIssues]
                )
            elif hasattr(env, "env"):
                return self._find_env_type(
                    env.env  # pyright: ignore[reportGeneralTypeIssues]
                )
            else:
                raise AttributeError(
                    f"Can't know the dm-control environment type, actual type: {type(env)}"
                )