/
cartpole.py
executable file
· 183 lines (155 loc) · 6 KB
/
cartpole.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""JAX compatible version of CartPole-v1 OpenAI gym environment."""
from typing import Any, Dict, Optional, Tuple, Union
import chex
from flax import struct
import jax
from jax import lax
import jax.numpy as jnp
from gymnax.environments import environment
from gymnax.environments import spaces
@struct.dataclass
class EnvState(environment.EnvState):
    """State of one CartPole episode.

    Created by `CartPole.reset_env` and advanced by one Euler step per call to
    `CartPole.step_env`.
    """

    x: jnp.ndarray  # cart position; episode terminates when |x| > x_threshold
    x_dot: jnp.ndarray  # cart velocity (time derivative of x)
    theta: jnp.ndarray  # pole angle in radians; terminates when |theta| > theta_threshold_radians
    theta_dot: jnp.ndarray  # pole angular velocity (time derivative of theta)
    time: int  # number of steps taken in the episode; 0 after reset, +1 per step_env
@struct.dataclass
class EnvParams(environment.EnvParams):
    """Physical and episode parameters for CartPole (defaults correspond to CartPole-v1).

    NOTE: `total_mass` and `polemass_length` are stored constants, not values
    derived from the other fields. `CartPole.step_env` reads them directly, so
    if you override `masscart`, `masspole` or `length` you must also override
    these two fields consistently, or the dynamics will be inconsistent.
    """

    gravity: float = 9.8
    masscart: float = 1.0
    masspole: float = 0.1
    total_mass: float = 1.0 + 0.1  # (masscart + masspole) -- NOT recomputed automatically
    length: float = 0.5  # presumably half the pole length, as in the gym reference -- TODO confirm
    polemass_length: float = 0.05  # (masspole * length) -- NOT recomputed automatically
    force_mag: float = 10.0  # magnitude of the applied force; the action selects its sign
    tau: float = 0.02  # time step of the explicit Euler integration
    theta_threshold_radians: float = 12 * 2 * jnp.pi / 360  # 12 degrees expressed in radians
    x_threshold: float = 2.4  # |x| beyond this terminates the episode
    max_steps_in_episode: int = 500  # v0 had only 200 steps!
class CartPole(environment.Environment[EnvState, EnvParams]):
    """JAX compatible version of the CartPole-v1 OpenAI gym environment.

    Source: github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py

    Observations are the 4-vector ``[x, x_dot, theta, theta_dot]``. Actions
    are discrete: 0 applies ``-force_mag`` and 1 applies ``+force_mag``.
    An episode terminates when any of the following holds:

    - the cart leaves the closed interval ``[-x_threshold, x_threshold]``;
    - the pole angle leaves ``[-theta_threshold_radians, theta_threshold_radians]``;
    - ``max_steps_in_episode`` steps have elapsed.
    """

    def __init__(self):
        super().__init__()
        self.obs_shape = (4,)

    @property
    def default_params(self) -> EnvParams:
        """Default environment parameters for CartPole-v1."""
        return EnvParams()

    def step_env(
        self,
        key: chex.PRNGKey,
        state: EnvState,
        action: Union[int, float, chex.Array],
        params: EnvParams,
    ) -> Tuple[chex.Array, EnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
        """Advance the cart-pole dynamics by one Euler step of length ``params.tau``.

        Args:
            key: Unused; the transition is deterministic.
            state: Current environment state.
            action: 0 or 1, selecting the sign of the applied force.
            params: Environment parameters.

        Returns:
            ``(obs, next_state, reward, done, info)``. ``reward`` is 1.0 unless
            ``state`` was already terminal (then 0.0), ``done`` is
            ``is_terminal(next_state)``, and ``info["discount"]`` is
            ``self.discount(next_state, params)``. Observation and state are
            wrapped in ``lax.stop_gradient``.
        """
        prev_terminal = self.is_terminal(state, params)

        # Map action 1 -> +force_mag and action 0 -> -force_mag.
        force = params.force_mag * action - params.force_mag * (1 - action)
        costheta = jnp.cos(state.theta)
        sintheta = jnp.sin(state.theta)

        # Equations of motion, following the gym reference implementation.
        temp = (
            force + params.polemass_length * state.theta_dot**2 * sintheta
        ) / params.total_mass
        thetaacc = (params.gravity * sintheta - costheta * temp) / (
            params.length
            * (4.0 / 3.0 - params.masspole * costheta**2 / params.total_mass)
        )
        xacc = temp - params.polemass_length * thetaacc * costheta / params.total_mass

        # Explicit Euler integration only. Positions are advanced with the OLD
        # velocities, so the order of these four updates matters.
        x = state.x + params.tau * state.x_dot
        x_dot = state.x_dot + params.tau * xacc
        theta = state.theta + params.tau * state.theta_dot
        theta_dot = state.theta_dot + params.tau * thetaacc

        # Reward depends on whether the PREVIOUS state was already terminal:
        # 1.0 for a normal step, 0.0 for a step taken from a terminal state.
        reward = 1.0 - prev_terminal

        next_state = EnvState(
            x=x,
            x_dot=x_dot,
            theta=theta,
            theta_dot=theta_dot,
            time=state.time + 1,
        )
        done = self.is_terminal(next_state, params)
        return (
            lax.stop_gradient(self.get_obs(next_state)),
            lax.stop_gradient(next_state),
            jnp.array(reward),
            done,
            {"discount": self.discount(next_state, params)},
        )

    def reset_env(
        self, key: chex.PRNGKey, params: EnvParams
    ) -> Tuple[chex.Array, EnvState]:
        """Start a new episode.

        Each of ``x, x_dot, theta, theta_dot`` is drawn independently from
        ``jax.random.uniform`` on ``[-0.05, 0.05)``, and ``time`` is set to 0.

        Returns:
            ``(obs, state)`` for the freshly sampled state.
        """
        init_state = jax.random.uniform(key, minval=-0.05, maxval=0.05, shape=(4,))
        state = EnvState(
            x=init_state[0],
            x_dot=init_state[1],
            theta=init_state[2],
            theta_dot=init_state[3],
            time=0,
        )
        return self.get_obs(state), state

    def get_obs(
        self,
        state: EnvState,
        params: Optional[EnvParams] = None,
        key: Optional[chex.PRNGKey] = None,
    ) -> chex.Array:
        """Return the observation ``[x, x_dot, theta, theta_dot]``.

        The observation is the full state minus ``time``; ``params`` and
        ``key`` are accepted for interface compatibility and are unused.
        """
        return jnp.array([state.x, state.x_dot, state.theta, state.theta_dot])

    def is_terminal(self, state: EnvState, params: EnvParams) -> jnp.ndarray:
        """Return a boolean array that is True when ``state`` ends the episode.

        The episode ends when the cart position or pole angle strictly exceeds
        its threshold in absolute value, or when
        ``time >= max_steps_in_episode``.
        """
        cart_out_of_bounds = jnp.abs(state.x) > params.x_threshold
        pole_out_of_bounds = jnp.abs(state.theta) > params.theta_threshold_radians
        time_limit_reached = state.time >= params.max_steps_in_episode
        return jnp.logical_or(
            jnp.logical_or(cart_out_of_bounds, pole_out_of_bounds),
            time_limit_reached,
        )

    @property
    def name(self) -> str:
        """Environment name."""
        return "CartPole-v1"

    @property
    def num_actions(self) -> int:
        """Number of actions possible in environment."""
        return 2

    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
        """Action space of the environment: ``Discrete(num_actions)``."""
        return spaces.Discrete(self.num_actions)

    @staticmethod
    def _space_bounds(params: EnvParams) -> jnp.ndarray:
        """Upper bounds for ``[x, x_dot, theta, theta_dot]``; lower bounds are the negation.

        Position and angle bounds are twice the termination thresholds
        (presumably, as in the gym reference, so that the first failing state is
        still inside the space). Velocities are bounded only by the float32
        maximum.
        """
        return jnp.array(
            [
                params.x_threshold * 2,
                jnp.finfo(jnp.float32).max,
                params.theta_threshold_radians * 2,
                jnp.finfo(jnp.float32).max,
            ]
        )

    def observation_space(self, params: EnvParams) -> spaces.Box:
        """Observation space of the environment: a 4-dim float32 Box."""
        high = self._space_bounds(params)
        return spaces.Box(-high, high, (4,), dtype=jnp.float32)

    def state_space(self, params: EnvParams) -> spaces.Dict:
        """State space of the environment, one entry per `EnvState` field."""
        high = self._space_bounds(params)
        return spaces.Dict(
            {
                "x": spaces.Box(-high[0], high[0], (), jnp.float32),
                "x_dot": spaces.Box(-high[1], high[1], (), jnp.float32),
                "theta": spaces.Box(-high[2], high[2], (), jnp.float32),
                "theta_dot": spaces.Box(-high[3], high[3], (), jnp.float32),
                "time": spaces.Discrete(params.max_steps_in_episode),
            }
        )