-
Notifications
You must be signed in to change notification settings - Fork 52
/
purerl.py
110 lines (93 loc) 路 3.74 KB
/
purerl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import jax
import jax.numpy as jnp
import chex
import numpy as np
from flax import struct
from functools import partial
from gymnax.environments import environment, spaces
from typing import Optional, Tuple, Union
class GymnaxWrapper(object):
    """Base class for Gymnax wrappers.

    Stores the wrapped environment and transparently forwards any attribute
    lookup that the wrapper itself does not define to that environment.
    """

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, so attributes
        # defined on the wrapper (including ``_env``) take precedence.
        return getattr(self._env, name)
class FlattenObservationWrapper(GymnaxWrapper):
    """Flatten the observations of the environment to a 1-D array."""

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    def observation_space(self, params) -> spaces.Box:
        """Return the wrapped env's Box observation space with a flat shape.

        Raises:
            AssertionError: if the underlying space is not a ``spaces.Box``.
        """
        # Query the underlying space once instead of four separate calls.
        space = self._env.observation_space(params)
        assert isinstance(space, spaces.Box), "Only Box spaces are supported for now."
        return spaces.Box(
            low=space.low,
            high=space.high,
            # np.prod returns a NumPy scalar; cast to a plain int for the shape.
            shape=(int(np.prod(space.shape)),),
            dtype=space.dtype,
        )

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        """Reset the wrapped env; return the observation reshaped to 1-D."""
        obs, state = self._env.reset(key, params)
        obs = jnp.reshape(obs, (-1,))
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        """Step the wrapped env; return the observation reshaped to 1-D."""
        obs, state, reward, done, info = self._env.step(
            key, state, action, params
        )
        obs = jnp.reshape(obs, (-1,))
        return obs, state, reward, done, info
@struct.dataclass
class LogEnvState:
    """Wrapped env state plus running per-episode statistics for logging."""
    env_state: environment.EnvState  # state of the underlying environment
    episode_returns: float  # reward accumulated so far in the current episode
    episode_lengths: int  # steps taken so far in the current episode
    returned_episode_returns: float  # return of the most recently completed episode
    returned_episode_lengths: int  # length of the most recently completed episode
class LogWrapper(GymnaxWrapper):
    """Log the episode returns and lengths."""

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        """Reset the wrapped env and wrap its state in a zeroed LogEnvState."""
        obs, env_state = self._env.reset(key, params)
        # Float zeros for the return accumulators so the initial values match
        # the float annotations on LogEnvState (the length fields stay int).
        state = LogEnvState(env_state, 0.0, 0, 0.0, 0)
        return obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: LogEnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        """Step the wrapped env while accumulating episode statistics.

        The running return/length accumulators grow each step; when ``done``
        is true they are snapshotted into the ``returned_*`` fields and then
        zeroed. The info dict gains "returned_episode_returns",
        "returned_episode_lengths", and "returned_episode" entries.
        """
        obs, env_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        new_episode_return = state.episode_returns + reward
        new_episode_length = state.episode_lengths + 1
        # `done` acts as a 0/1 mask: on episode end the running totals move
        # into the returned_* fields and the accumulators reset to zero.
        state = LogEnvState(
            env_state=env_state,
            episode_returns=new_episode_return * (1 - done),
            episode_lengths=new_episode_length * (1 - done),
            returned_episode_returns=state.returned_episode_returns * (1 - done)
            + new_episode_return * done,
            returned_episode_lengths=state.returned_episode_lengths * (1 - done)
            + new_episode_length * done,
        )
        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["returned_episode"] = done
        return obs, state, reward, done, info