-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
395 lines (328 loc) · 14.7 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
import glob
import os
import random
from collections import deque
from itertools import zip_longest
from typing import Callable, Iterable, Optional, Union
import gym
import numpy as np
import torch as th
# Check if tensorboard is available for pytorch
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = None
from stable_baselines3.common import logger
from stable_baselines3.common.type_aliases import GymEnv
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
    """
    Seed the random number generators of Python, NumPy and PyTorch.

    :param seed: value handed to each generator's seeding function
    :param using_cuda: when True, additionally switch CuDNN to deterministic
        mode and turn off its benchmark mode
    """
    # Same order as before: Python ``random``, then NumPy, then PyTorch
    # (``th.manual_seed`` covers CPU and CUDA devices).
    for seed_generator in (random.seed, np.random.seed, th.manual_seed):
        seed_generator(seed)

    if not using_cuda:
        return

    # Deterministic operations for CuDNN, it may impact performances
    th.backends.cudnn.deterministic = True
    th.backends.cudnn.benchmark = False
# From stable baselines
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
    """
    Compute the fraction of the variance of ``y_true`` explained by ``y_pred``,
    i.e. ``1 - Var[y_true - y_pred] / Var[y_true]``.

    Reading the result:

        - 1: perfect prediction
        - 0: no better than predicting zero
        - negative: worse than predicting zero

    :param y_pred: the prediction (1-D array)
    :param y_true: the expected value (1-D array)
    :return: the explained variance, or NaN when ``y_true`` has zero variance
    """
    assert y_true.ndim == 1 and y_pred.ndim == 1
    target_variance = np.var(y_true)
    if target_variance == 0:
        # The ratio is undefined for a constant target
        return np.nan
    residual_variance = np.var(y_true - y_pred)
    return 1 - residual_variance / target_variance
def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
    """
    Set the same learning rate on every parameter group of an optimizer.
    Handy when the learning rate follows a schedule (e.g. linear decay).

    :param optimizer: optimizer whose ``param_groups`` are modified in place
    :param learning_rate: value written to the ``"lr"`` entry of each group
    """
    for group in optimizer.param_groups:
        group.update(lr=learning_rate)
def get_schedule_fn(value_schedule: Union[Callable, float]) -> Callable:
    """
    Make sure a schedule (learning rate, PPO clip range, ...) is a callable.

    :param value_schedule: either a number or an already callable schedule
    :return: ``value_schedule`` itself when it is callable, otherwise a
        constant function returning the number cast to float
    """
    if not isinstance(value_schedule, (float, int)):
        # Anything that is not a number must already be a schedule
        assert callable(value_schedule)
        return value_schedule

    # Plain number: cast to float to avoid errors and wrap it
    return constant_fn(float(value_schedule))
def get_linear_fn(start: float, end: float, end_fraction: float) -> Callable:
    """
    Create a function that interpolates linearly between start and end
    between ``progress_remaining`` = 1 and ``progress_remaining`` = 1 - ``end_fraction``.
    This is used in DQN for linearly annealing the exploration fraction
    (epsilon for the epsilon-greedy strategy).

    :param start: value returned when ``progress_remaining`` = 1
    :param end: value returned once more than ``end_fraction`` of the
        progress has elapsed
    :param end_fraction: fraction of the complete training process over which
        the value moves from ``start`` to ``end``, e.g. with 0.1 ``end`` is
        reached after 10% of the training. With 0 there is no interpolation
        phase and the returned function always yields ``end``.
    :return: function mapping ``progress_remaining`` (1 at the beginning,
        0 at the end) to the interpolated value
    """

    def func(progress_remaining: float) -> float:
        progress_elapsed = 1 - progress_remaining
        # ``end_fraction == 0`` is handled here so the division below
        # can never raise ZeroDivisionError.
        if end_fraction == 0 or progress_elapsed > end_fraction:
            return end
        return start + progress_elapsed * (end - start) / end_fraction

    return func
def constant_fn(val: float) -> Callable:
    """
    Build a schedule that ignores its argument and always returns ``val``.
    Lets constant values share the code path used by real schedules.

    :param val: the value to return on every call
    :return: a one-argument function returning ``val``
    """

    def func(_progress_remaining):
        # The argument is accepted only to match the schedule call signature
        return val

    return func
def get_device(device: Union[th.device, str] = "auto") -> th.device:
    """
    Resolve the PyTorch device to use.

    ``"auto"`` is treated as a request for CUDA. Any CUDA request is
    replaced by the CPU device when CUDA is not available; every other
    device is returned as requested.

    :param device: one of 'auto', 'cuda', 'cpu' (or a ``th.device``)
    :return: the resolved ``th.device``
    """
    # "auto" means: try the GPU first. Force conversion to th.device.
    requested = th.device("cuda" if device == "auto" else device)

    if requested.type != "cuda":
        return requested

    # CUDA was asked for: fall back to the CPU when it is not available
    return requested if th.cuda.is_available() else th.device("cpu")
def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int:
    """
    Returns the latest run number for the given log name and log path,
    by finding the greatest number in the entries named ``{log_name}_{number}``.

    :param log_path: directory that is scanned; with ``None`` there is
        nothing to scan and 0 is returned
    :param log_name: name prefix of the runs; only entries whose name is
        exactly ``{log_name}_{number}`` are considered
    :return: latest run number, 0 when no matching entry exists
    """
    if log_path is None:
        return 0

    # Escape both parts so that glob metacharacters ("[", "*", "?") in the
    # directory or in the log name are matched literally.
    pattern = os.path.join(glob.escape(log_path), f"{glob.escape(log_name)}_[0-9]*")

    max_run_id = 0
    for path in glob.glob(pattern):
        # basename copes with whichever separator glob returned
        prefix, _, run_suffix = os.path.basename(path).rpartition("_")
        # isdecimal() only accepts strings that int() is able to parse
        if prefix == log_name and run_suffix.isdecimal():
            max_run_id = max(max_run_id, int(run_suffix))
    return max_run_id
def configure_logger(
    verbose: int = 0, tensorboard_log: Optional[str] = None, tb_log_name: str = "", reset_num_timesteps: bool = True
) -> None:
    """
    Configure the logger's outputs.

    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param tb_log_name: tensorboard log name, used as prefix of the run directory
    :param reset_num_timesteps: when True a new run index (latest + 1) is used,
        when False the latest existing run index is reused
    """
    tensorboard_enabled = tensorboard_log is not None and SummaryWriter is not None

    if not tensorboard_enabled:
        # Without tensorboard the logger is only reconfigured to silence it
        if verbose == 0:
            logger.configure(format_strings=[""])
        return

    run_id = get_latest_run_id(tensorboard_log, tb_log_name) + 1
    if not reset_num_timesteps:
        # Continue training in the same directory
        run_id -= 1
    save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{run_id}")

    output_formats = ["stdout", "tensorboard"] if verbose >= 1 else ["tensorboard"]
    logger.configure(save_path, output_formats)
def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> None:
    """
    Verify that an environment exposes the same spaces as the ones provided.
    Used by BaseAlgorithm to check if spaces match after loading the model with given env.

    The observation space is compared first, then the action space.

    :param env: Environment to check for valid spaces
    :param observation_space: Observation space to check against
    :param action_space: Action space to check against
    :raises ValueError: if either space differs from the environment's one
    """
    observation_mismatch = observation_space != env.observation_space
    if observation_mismatch:
        raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}")

    action_mismatch = action_space != env.action_space
    if action_mismatch:
        raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")
def is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
    """
    For every observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    Accepted shapes per space type:

        - Box: ``observation_space.shape`` (single) or any shape whose
          trailing dimensions equal it (vectorized)
        - Discrete: ``()`` (single) or any 1-D shape (vectorized)
        - MultiDiscrete: ``(len(nvec),)`` (single) or ``(n_env, len(nvec))``
        - MultiBinary: ``(n,)`` (single) or ``(n_env, n)``

    :param observation: the input observation to validate
    :param observation_space: the observation space
    :return: whether the given observation is vectorized or not
    :raises ValueError: if the shape fits neither form, or if the space type
        is not one of the four handled above
    """
    if isinstance(observation_space, gym.spaces.Box):
        if observation.shape == observation_space.shape:
            return False
        if observation.shape[1:] == observation_space.shape:
            return True
        raise ValueError(
            f"Error: Unexpected observation shape {observation.shape} for "
            + f"Box environment, please use {observation_space.shape} "
            + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape)))
        )

    if isinstance(observation_space, gym.spaces.Discrete):
        if observation.shape == ():  # A numpy array of a number, has shape empty tuple '()'
            return False
        if len(observation.shape) == 1:
            return True
        # The message lists the shapes accepted just above; it previously
        # advertised (1,) / (n_env, 1), and (n_env, 1) is itself rejected.
        raise ValueError(
            f"Error: Unexpected observation shape {observation.shape} for "
            + "Discrete environment, please use () or (n_env,) for the observation shape."
        )

    if isinstance(observation_space, gym.spaces.MultiDiscrete):
        n_dims = len(observation_space.nvec)
        if observation.shape == (n_dims,):
            return False
        if len(observation.shape) == 2 and observation.shape[1] == n_dims:
            return True
        raise ValueError(
            f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
            + f"environment, please use ({n_dims},) or "
            + f"(n_env, {n_dims}) for the observation shape."
        )

    if isinstance(observation_space, gym.spaces.MultiBinary):
        if observation.shape == (observation_space.n,):
            return False
        if len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
            return True
        raise ValueError(
            f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
            + f"environment, please use ({observation_space.n},) or "
            + f"(n_env, {observation_space.n}) for the observation shape."
        )

    raise ValueError(
        "Error: Cannot determine if the observation is vectorized " + f" with the space type {observation_space}."
    )
def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
    """
    Mean of a collection that tolerates empty input (used for logging only).

    :param arr: the values to average
    :return: ``np.mean(arr)``, or NaN when ``arr`` has no element
    """
    if len(arr) == 0:
        return np.nan
    return np.mean(arr)
def zip_strict(*iterables: Iterable) -> Iterable:
    r"""
    ``zip()`` function but enforces that iterables are of equal length.
    Raises ``ValueError`` if iterables not of equal length.
    Code inspired by Stackoverflow answer for question #32954486.

    This is a generator: the length check happens lazily, while iterating.

    :param \*iterables: iterables to ``zip()``
    :return: generator yielding one tuple per position, like ``zip()``
    :raises ValueError: as soon as one iterable is exhausted before the others
    """
    # As in Stackoverflow #32954486, use
    # new object for "empty" in case we have
    # Nones in iterable.
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        # Identity test on purpose: ``sentinel in combo`` falls back to ``==``,
        # which is element-wise for NumPy arrays and then fails with
        # "truth value of an array is ambiguous" even for equal lengths.
        if any(item is sentinel for item in combo):
            raise ValueError("Iterables have different lengths")
        yield combo
def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th.nn.Parameter], tau: float) -> None:
    """
    Soft ("Polyak") update of ``target_params`` towards ``params``:
    every target parameter becomes ``(1 - tau) * target + tau * param``.

    ``tau`` is the soft update coefficient: ``tau=1`` copies the parameters
    to the target ones, ``tau=0`` leaves the target untouched.

    The update runs under ``no_grad`` and writes into the target tensors
    in place: the target is first scaled by ``1 - tau``, then the new weights
    scaled by ``tau`` are added, with the sum stored back into the target.
    See https://github.com/DLR-RM/stable-baselines3/issues/93

    :param params: parameters to use to update the target params
    :param target_params: parameters to update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    """
    with th.no_grad():
        keep_fraction = 1 - tau
        # zip_strict (unlike zip) raises if the two parameter lists differ in length
        for source, target in zip_strict(params, target_params):
            target_data = target.data
            target_data.mul_(keep_fraction)
            th.add(target_data, source.data, alpha=tau, out=target_data)
class DQNClippedLoss(th.nn.Module):
    """
    Module form of the "DQNClipped" loss from https://arxiv.org/abs/2101.03958
    In DQN, replaces Huber/MSE loss between train and target network.

    With ``delta = current_q - target_q`` the loss is the mean of
    ``max(current_q, delta^2 + target_q) + max(delta, gamma * target_q_a_max^2)``.

    :param current_q: Q(st, at) of training network
    :param target_q: Max Q value from the target network, including the reward and gamma. r + gamma * Q_target(st+1,a)
    :param target_q_a_max: Q_target(st+1, a) where a is the action of the max Q value
    :param gamma: the discount factor
    """

    def __init__(self):
        super().__init__()

    def forward(self, current_q, target_q, target_q_a_max, gamma):
        q_error = current_q - target_q
        # First term: max[current_q, delta^2 + target_q]
        first_term = th.max(current_q, q_error ** 2 + target_q)
        # Second term: max[delta, gamma * (target_q_a_max)^2]
        second_term = th.max(q_error, gamma * target_q_a_max ** 2)
        return (first_term + second_term).mean()
class DQNRegLoss(th.nn.Module):
    """
    Module form of the "DQNReg" loss from https://arxiv.org/abs/2101.03958
    In DQN, replaces Huber/MSE loss between train and target network.

    With ``delta = current_q - target_q`` the loss is the mean of
    ``weight * current_q + delta^2``.

    :param current_q: Q(st, at) of training network
    :param target_q: Max Q value from the target network, including the reward and gamma. r + gamma * Q_target(st+1,a)
    :param weight: scalar. weighted term that regularizes Q value. Paper defaults to 0.1 but theorizes that tuning this
        per env to some positive value may be beneficial.
    """

    def __init__(self):
        super().__init__()

    def forward(self, current_q, target_q, weight=0.1):
        q_error = current_q - target_q
        # Q-value regulariser plus squared error, averaged over all elements
        regulariser = weight * current_q
        return (regulariser + q_error ** 2).mean()
def dqn_clipped_loss(current_q, target_q, target_q_a_max, gamma):
    """
    Functional form of the "DQNClipped" loss from https://arxiv.org/abs/2101.03958
    In DQN, replaces Huber/MSE loss between train and target network.

    With ``delta = current_q - target_q`` the loss is the mean of
    ``max(current_q, delta^2 + target_q) + max(delta, gamma * target_q_a_max^2)``.

    :param current_q: Q(st, at) of training network
    :param target_q: Max Q value from the target network, including the reward and gamma. r + gamma * Q_target(st+1,a)
    :param target_q_a_max: Q_target(st+1, a) where a is the action of the max Q value
    :param gamma: the discount factor
    :return: the loss averaged over all elements
    """
    q_error = current_q - target_q
    # First term: max[current_q, delta^2 + target_q]
    first_term = th.max(current_q, q_error ** 2 + target_q)
    # Second term: max[delta, gamma * (target_q_a_max)^2]
    second_term = th.max(q_error, gamma * target_q_a_max ** 2)
    return (first_term + second_term).mean()
def dqn_reg_loss(current_q, target_q, weight=0.1):
    """
    Functional form of the "DQNReg" loss from https://arxiv.org/abs/2101.03958
    In DQN, replaces Huber/MSE loss between train and target network.

    With ``delta = current_q - target_q`` the loss is the mean of
    ``weight * current_q + delta^2``.

    :param current_q: Q(st, at) of training network
    :param target_q: Max Q value from the target network, including the reward and gamma. r + gamma * Q_target(st+1,a)
    :param weight: scalar. weighted term that regularizes Q value. Paper defaults to 0.1 but theorizes that tuning this
        per env to some positive value may be beneficial.
    :return: the loss averaged over all elements
    """
    q_error = current_q - target_q
    # Q-value regulariser plus squared error, averaged over all elements
    regulariser = weight * current_q
    return (regulariser + q_error ** 2).mean()