Skip to content

Commit 4627d6f

Browse files
author
Shunichi09
committed
Add: catpole env
1 parent 4e01264 commit 4627d6f

File tree

17 files changed

+723
-26
lines changed

17 files changed

+723
-26
lines changed

Environments.md

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,48 @@
99

1010
## FistOrderLagEnv
1111

12-
System equations.
12+
### System equation.
1313

1414
<img src="assets/firstorderlag.png" width="550">
1515

1616
You can set arbinatry time constant, tau. The default is 0.63 s
1717

18+
### Cost.
19+
20+
<img src="assets/quadratic_score.png" width="200">
21+
22+
Q = diag[1., 1., 1., 1.],
23+
R = diag[1., 1.]
24+
25+
X_g denote the goal states.
26+
1827
## TwoWheeledEnv
1928

20-
System equations.
29+
### System equation.
2130

2231
<img src="assets/twowheeled.png" width="300">
2332

33+
### Cost.
34+
35+
<img src="assets/quadratic_score.png" width="200">
36+
37+
Q = diag[5., 5., 1.],
38+
R = diag[0.1, 0.1]
39+
40+
X_g denote the goal states.
41+
2442
## CatpoleEnv (Swing up)
2543

26-
System equations.
44+
System equation.
2745

2846
<img src="assets/cartpole.png" width="600">
2947

3048
You can set arbinatry parameters, mc, mp, l and g.
3149

3250
Default settings are as follows:
3351

34-
mc = 1, mp = 0.2, l = 0.5, g = 9.8
52+
mc = 1, mp = 0.2, l = 0.5, g = 9.81
53+
54+
### Cost.
55+
56+
<img src="assets/cartpole_score.png" width="300">
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
import numpy as np
2+
3+
class CartPoleConfigModule():
4+
# parameters
5+
ENV_NAME = "CartPole-v0"
6+
TYPE = "Nonlinear"
7+
TASK_HORIZON = 500
8+
PRED_LEN = 50
9+
STATE_SIZE = 4
10+
INPUT_SIZE = 1
11+
DT = 0.02
12+
# cost parameters
13+
R = np.diag([0.01])
14+
# bounds
15+
INPUT_LOWER_BOUND = np.array([-3.])
16+
INPUT_UPPER_BOUND = np.array([3.])
17+
# parameters
18+
MP = 0.2
19+
MC = 1.
20+
L = 0.5
21+
G = 9.81
22+
23+
def __init__(self):
24+
"""
25+
"""
26+
# opt configs
27+
self.opt_config = {
28+
"Random": {
29+
"popsize": 5000
30+
},
31+
"CEM": {
32+
"popsize": 500,
33+
"num_elites": 50,
34+
"max_iters": 15,
35+
"alpha": 0.3,
36+
"init_var":9.,
37+
"threshold":0.001
38+
},
39+
"MPPI":{
40+
"beta" : 0.6,
41+
"popsize": 5000,
42+
"kappa": 0.9,
43+
"noise_sigma": 0.5,
44+
},
45+
"MPPIWilliams":{
46+
"popsize": 5000,
47+
"lambda": 1.,
48+
"noise_sigma": 0.9,
49+
},
50+
"iLQR":{
51+
"max_iter": 500,
52+
"init_mu": 1.,
53+
"mu_min": 1e-6,
54+
"mu_max": 1e10,
55+
"init_delta": 2.,
56+
"threshold": 1e-6,
57+
},
58+
"DDP":{
59+
"max_iter": 500,
60+
"init_mu": 1.,
61+
"mu_min": 1e-6,
62+
"mu_max": 1e10,
63+
"init_delta": 2.,
64+
"threshold": 1e-6,
65+
},
66+
"NMPC-CGMRES":{
67+
},
68+
"NMPC-Newton":{
69+
},
70+
}
71+
72+
@staticmethod
73+
def input_cost_fn(u):
74+
""" input cost functions
75+
Args:
76+
u (numpy.ndarray): input, shape(pred_len, input_size)
77+
or shape(pop_size, pred_len, input_size)
78+
Returns:
79+
cost (numpy.ndarray): cost of input, shape(pred_len, input_size) or
80+
shape(pop_size, pred_len, input_size)
81+
"""
82+
return (u**2) * np.diag(CartPoleConfigModule.R)
83+
84+
@staticmethod
85+
def state_cost_fn(x, g_x):
86+
""" state cost function
87+
Args:
88+
x (numpy.ndarray): state, shape(pred_len, state_size)
89+
or shape(pop_size, pred_len, state_size)
90+
g_x (numpy.ndarray): goal state, shape(pred_len, state_size)
91+
or shape(pop_size, pred_len, state_size)
92+
Returns:
93+
cost (numpy.ndarray): cost of state, shape(pred_len, 1) or
94+
shape(pop_size, pred_len, 1)
95+
"""
96+
97+
if len(x.shape) > 2:
98+
return (6. * (x[:, :, 0]**2) \
99+
+ 12. * ((np.cos(x[:, :, 2]) + 1.)**2) \
100+
+ 0.1 * (x[:, :, 1]**2) \
101+
+ 0.1 * (x[:, :, 3]**2))[:, :, np.newaxis]
102+
103+
elif len(x.shape) > 1:
104+
return (6. * (x[:, 0]**2) \
105+
+ 12. * ((np.cos(x[:, 2]) + 1.)**2) \
106+
+ 0.1 * (x[:, 1]**2) \
107+
+ 0.1 * (x[:, 3]**2))[:, np.newaxis]
108+
109+
return 6. * (x[0]**2) \
110+
+ 12. * ((np.cos(x[2]) + 1.)**2) \
111+
+ 0.1 * (x[1]**2) \
112+
+ 0.1 * (x[3]**2)
113+
114+
@staticmethod
115+
def terminal_state_cost_fn(terminal_x, terminal_g_x):
116+
"""
117+
Args:
118+
terminal_x (numpy.ndarray): terminal state,
119+
shape(state_size, ) or shape(pop_size, state_size)
120+
terminal_g_x (numpy.ndarray): terminal goal state,
121+
shape(state_size, ) or shape(pop_size, state_size)
122+
Returns:
123+
cost (numpy.ndarray): cost of state, shape(pred_len, ) or
124+
shape(pop_size, pred_len)
125+
"""
126+
127+
if len(terminal_x.shape) > 1:
128+
return (6. * (terminal_x[:, 0]**2) \
129+
+ 12. * ((np.cos(terminal_x[:, 2]) + 1.)**2) \
130+
+ 0.1 * (terminal_x[:, 1]**2) \
131+
+ 0.1 * (terminal_x[:, 3]**2))[:, np.newaxis]
132+
133+
return 6. * (terminal_x[0]**2) \
134+
+ 12. * ((np.cos(terminal_x[2]) + 1.)**2) \
135+
+ 0.1 * (terminal_x[1]**2) \
136+
+ 0.1 * (terminal_x[3]**2)
137+
138+
@staticmethod
139+
def gradient_cost_fn_with_state(x, g_x, terminal=False):
140+
""" gradient of costs with respect to the state
141+
142+
Args:
143+
x (numpy.ndarray): state, shape(pred_len, state_size)
144+
g_x (numpy.ndarray): goal state, shape(pred_len, state_size)
145+
146+
Returns:
147+
l_x (numpy.ndarray): gradient of cost, shape(pred_len, state_size)
148+
or shape(1, state_size)
149+
"""
150+
if not terminal:
151+
return None
152+
153+
return None
154+
155+
@staticmethod
156+
def gradient_cost_fn_with_input(x, u):
157+
""" gradient of costs with respect to the input
158+
159+
Args:
160+
x (numpy.ndarray): state, shape(pred_len, state_size)
161+
u (numpy.ndarray): goal state, shape(pred_len, input_size)
162+
163+
Returns:
164+
l_u (numpy.ndarray): gradient of cost, shape(pred_len, input_size)
165+
"""
166+
return None
167+
168+
@staticmethod
169+
def hessian_cost_fn_with_state(x, g_x, terminal=False):
170+
""" hessian costs with respect to the state
171+
172+
Args:
173+
x (numpy.ndarray): state, shape(pred_len, state_size)
174+
g_x (numpy.ndarray): goal state, shape(pred_len, state_size)
175+
176+
Returns:
177+
l_xx (numpy.ndarray): gradient of cost,
178+
shape(pred_len, state_size, state_size) or
179+
shape(1, state_size, state_size) or
180+
"""
181+
if not terminal:
182+
(pred_len, _) = x.shape
183+
return None
184+
185+
return None
186+
187+
@staticmethod
188+
def hessian_cost_fn_with_input(x, u):
189+
""" hessian costs with respect to the input
190+
191+
Args:
192+
x (numpy.ndarray): state, shape(pred_len, state_size)
193+
u (numpy.ndarray): goal state, shape(pred_len, input_size)
194+
195+
Returns:
196+
l_uu (numpy.ndarray): gradient of cost,
197+
shape(pred_len, input_size, input_size)
198+
"""
199+
(pred_len, _) = u.shape
200+
201+
return None
202+
203+
@staticmethod
204+
def hessian_cost_fn_with_input_state(x, u):
205+
""" hessian costs with respect to the state and input
206+
207+
Args:
208+
x (numpy.ndarray): state, shape(pred_len, state_size)
209+
u (numpy.ndarray): goal state, shape(pred_len, input_size)
210+
211+
Returns:
212+
l_ux (numpy.ndarray): gradient of cost ,
213+
shape(pred_len, input_size, state_size)
214+
"""
215+
(_, state_size) = x.shape
216+
(pred_len, input_size) = u.shape
217+
218+
return np.zeros((pred_len, input_size, state_size))

PythonLinearNonlinearControl/configs/make_configs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .first_order_lag import FirstOrderLagConfigModule
22
from .two_wheeled import TwoWheeledConfigModule
3+
from .cartpole import CartPoleConfigModule
34

45
def make_config(args):
56
"""
@@ -9,4 +10,6 @@ def make_config(args):
910
if args.env == "FirstOrderLag":
1011
return FirstOrderLagConfigModule()
1112
elif args.env == "TwoWheeledConst" or args.env == "TwoWheeled":
12-
return TwoWheeledConfigModule()
13+
return TwoWheeledConfigModule()
14+
elif args.env == "CartPole":
15+
return CartPoleConfigModule()

PythonLinearNonlinearControl/envs/cartpole.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@ class CartPoleEnv(Env):
1414
def __init__(self):
1515
"""
1616
"""
17-
self.config = {"state_size" : 4,\
18-
"input_size" : 1,\
19-
"dt" : 0.02,\
20-
"max_step" : 1000,\
21-
"input_lower_bound": None,\
22-
"input_upper_bound": None,
17+
self.config = {"state_size" : 4,
18+
"input_size" : 1,
19+
"dt" : 0.02,
20+
"max_step" : 500,
21+
"input_lower_bound": [-3.],
22+
"input_upper_bound": [3.],
23+
"mp": 0.2,
24+
"mc": 1.,
25+
"l": 0.5,
26+
"g": 9.81,
2327
}
2428

2529
super(CartPoleEnv, self).__init__(self.config)
@@ -33,13 +37,13 @@ def reset(self, init_x=None):
3337
"""
3438
self.step_count = 0
3539

36-
self.curr_x = np.zeros(self.config["state_size"])
40+
self.curr_x = np.array([0., 0., 0., 0.])
3741

3842
if init_x is not None:
3943
self.curr_x = init_x
4044

4145
# goal
42-
self.g_x = np.array([0., 0., np.pi, 0.])
46+
self.g_x = np.array([0., 0., -np.pi, 0.])
4347

4448
# clear memory
4549
self.history_x = []
@@ -65,20 +69,43 @@ def step(self, u):
6569
self.config["input_upper_bound"])
6670

6771
# step
68-
next_x = np.zeros(self.config["state_size"])
72+
# x
73+
d_x0 = self.curr_x[1]
74+
# v_x
75+
d_x1 = (u[0] + self.config["mp"] * np.sin(self.curr_x[2]) \
76+
* (self.config["l"] * (self.curr_x[3]**2) \
77+
+ self.config["g"] * np.cos(self.curr_x[2]))) \
78+
/ (self.config["mc"] + self.config["mp"] \
79+
* (np.sin(self.curr_x[2])**2))
80+
# theta
81+
d_x2 = self.curr_x[3]
82+
83+
# v_theta
84+
d_x3 = (-u[0] * np.cos(self.curr_x[2]) \
85+
- self.config["mp"] * self.config["l"] * (self.curr_x[3]**2) \
86+
* np.cos(self.curr_x[2]) * np.sin(self.curr_x[2]) \
87+
- (self.config["mc"] + self.config["mp"]) * self.config["g"] \
88+
* np.sin(self.curr_x[2])) \
89+
/ (self.config["l"] * (self.config["mc"] + self.config["mp"] \
90+
* (np.sin(self.curr_x[2])**2)))
91+
92+
next_x = self.curr_x +\
93+
np.array([d_x0, d_x1, d_x2, d_x3]) * self.config["dt"]
6994

7095
# TODO: costs
7196
costs = 0.
7297
costs += 0.1 * np.sum(u**2)
73-
costs += np.sum((self.curr_x - self.g_x)**2)
74-
98+
costs += 6. * self.curr_x[0]**2 \
99+
+ 12. * (np.cos(self.curr_x[2]) + 1.)**2 \
100+
+ 0.1 * self.curr_x[1]**2 \
101+
+ 0.1 * self.curr_x[3]**2
75102

76103
# save history
77104
self.history_x.append(next_x.flatten())
78105
self.history_g_x.append(self.g_x.flatten())
79106

80107
# update
81-
self.curr_x = next_x.flatten()
108+
self.curr_x = next_x.flatten().copy()
82109
# update costs
83110
self.step_count += 1
84111

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .first_order_lag import FirstOrderLagEnv
22
from .two_wheeled import TwoWheeledConstEnv
3-
from .cartpole import CartpoleEnv
3+
from .cartpole import CartPoleEnv
44

55
def make_env(args):
66

@@ -9,6 +9,6 @@ def make_env(args):
99
elif args.env == "TwoWheeledConst":
1010
return TwoWheeledConstEnv()
1111
elif args.env == "CartPole":
12-
return CartpoleEnv()
12+
return CartPoleEnv()
1313

1414
raise NotImplementedError("There is not {} Env".format(args.env))

0 commit comments

Comments
 (0)