# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Standard library imports
import logging
import os
import tempfile
import time
import uuid
from typing import Any, Callable, List, Optional, Union

# Third-party imports
import mxnet as mx
import mxnet.autograd as autograd
import mxnet.gluon.nn as nn
import numpy as np
from mxnet.metric import ndarray

# First-party imports
from gluonts.core.component import validated
from gluonts.core.exception import GluonTSDataError, GluonTSUserError
from gluonts.dataset.loader import TrainDataLoader, ValidationDataLoader
from gluonts.gluonts_tqdm import tqdm
from gluonts.mx.context import get_mxnet_context
from gluonts.support.util import HybridContext

# Relative imports
from . import learning_rate_scheduler as lrs
from .model_averaging import (
    AveragingStrategy,
    SelectNBestMean,
    save_epoch_info,
)
from .model_iteration_averaging import IterationAveragingStrategy

logger = logging.getLogger("gluonts").getChild("trainer")

MODEL_ARTIFACT_FILE_NAME = "model"
STATE_ARTIFACT_FILE_NAME = "state"

# make the IDE happy: mx.py does not explicitly import autograd
mx.autograd = autograd


def check_loss_finite(val: float) -> None:
    if not np.isfinite(val):
        raise GluonTSDataError(
            "Encountered invalid loss value! Try reducing the learning rate "
            "or try a different likelihood."
        )


def loss_value(loss: mx.metric.Loss) -> float:
    return loss.get_name_value()[0][1]
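
# Note: `mx.metric.Loss.get_name_value()` returns a list of (name, value)
# pairs, e.g. [("loss", 1.234)]; `loss_value` reads the running average from
# the first (and only) entry.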


class Trainer:
    r"""
    A trainer specifies how a network is going to be trained.

    A trainer is mainly defined by two sets of parameters. The first one
    determines the number of examples that the network will be trained on
    (`epochs`, `num_batches_per_epoch` and `batch_size`), while the second
    one specifies how the gradient updates are performed (`learning_rate`,
    `learning_rate_decay_factor`, `patience`, `minimum_learning_rate`,
    `clip_gradient` and `weight_decay`).

    Parameters
    ----------
    ctx
        MXNet context to run the training on (default: the context returned
        by `get_mxnet_context()`).
    epochs
        Number of epochs that the network will train (default: 100).
    batch_size
        Number of examples in each batch (default: 32).
    num_batches_per_epoch
        Number of batches at each epoch (default: 50).
    learning_rate
        Initial learning rate (default: :math:`10^{-3}`).
    learning_rate_decay_factor
        Factor (between 0 and 1) by which to decrease the learning rate
        (default: 0.5).
    patience
        The patience to observe before reducing the learning rate,
        nonnegative integer (default: 10).
    minimum_learning_rate
        Lower bound for the learning rate (default: :math:`5\cdot 10^{-5}`).
    clip_gradient
        Maximum value of gradient. The gradient is clipped if it is too
        large (default: 10).
    weight_decay
        The weight decay (or L2 regularization) coefficient. Modifies the
        objective by adding a penalty for having large weights
        (default: :math:`10^{-8}`).
    init
        Initializer of the weights of the network (default: "xavier").
    hybridize
        If set to True, the network will be hybridized before training
        (default: True).
    avg_strategy
        Model averaging strategy applied to the trained network
        (default: `SelectNBestMean(num_models=1)`).
    post_initialize_cb
        An optional callback function. If provided, the function will be
        called with the initialized network `post_initialize_cb(net)` before
        the training starts. This callback can be used to e.g. overwrite
        parameters for warm starting, or to freeze some of the network
        parameters.
    """

    @validated()
    def __init__(
        self,
        ctx: Optional[mx.Context] = None,
        epochs: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        learning_rate: float = 1e-3,
        learning_rate_decay_factor: float = 0.5,
        patience: int = 10,
        minimum_learning_rate: float = 5e-5,
        clip_gradient: float = 10.0,
        weight_decay: float = 1e-8,
        init: Union[str, mx.initializer.Initializer] = "xavier",
        hybridize: bool = True,
        avg_strategy: Union[
            AveragingStrategy, IterationAveragingStrategy
        ] = SelectNBestMean(num_models=1),
        post_initialize_cb: Optional[Callable[[mx.gluon.Block], None]] = None,
    ) -> None:
        assert (
            0 <= epochs < float("inf")
        ), "The value of `epochs` should be >= 0"
        assert 0 < batch_size, "The value of `batch_size` should be > 0"
        assert (
            0 < num_batches_per_epoch
        ), "The value of `num_batches_per_epoch` should be > 0"
        assert (
            0 < learning_rate < float("inf")
        ), "The value of `learning_rate` should be > 0"
        assert (
            0 <= learning_rate_decay_factor < 1
        ), "The value of `learning_rate_decay_factor` should be in the [0, 1) range"
        assert 0 <= patience, "The value of `patience` should be >= 0"
        assert (
            0 <= minimum_learning_rate
        ), "The value of `minimum_learning_rate` should be >= 0"
        assert 0 < clip_gradient, "The value of `clip_gradient` should be > 0"
        assert 0 <= weight_decay, "The value of `weight_decay` should be >= 0"

        self.epochs = epochs
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch
        self.learning_rate = learning_rate
        self.learning_rate_decay_factor = learning_rate_decay_factor
        self.patience = patience
        self.minimum_learning_rate = minimum_learning_rate
        self.clip_gradient = clip_gradient
        self.weight_decay = weight_decay
        self.init = init
        self.hybridize = hybridize
        self.avg_strategy = avg_strategy
        self.ctx = ctx if ctx is not None else get_mxnet_context()
        self.halt = False
        self.post_initialize_cb = post_initialize_cb

    def set_halt(self, signum: int, stack_frame: Any) -> None:
        logger.info("Received signal: {}".format(signum))
        self.halt = True

    def count_model_params(self, net: nn.HybridBlock) -> int:
        params = net.collect_params()
        num_params = 0
        for p in params:
            v = params[p]
            num_params += np.prod(v.shape)
        return num_params
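
    # Illustrative example (hypothetical layer): for a Dense layer mapping
    # 4 inputs to 10 units, `collect_params()` yields a (10, 4) weight and a
    # (10,) bias, so `count_model_params` returns 40 + 10 = 50.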

    def __call__(
        self,
        net: nn.HybridBlock,
        train_iter: TrainDataLoader,
        validation_iter: Optional[ValidationDataLoader] = None,
    ) -> None:  # TODO: we may want to return some training information here
        """
        Train a network, given an iterable over training (and optionally
        validation) batches.

        Parameters
        ----------
        net
            Network to be trained. This is a Gluon HybridBlock, assumed to
            produce a tensor of loss values as output.
        train_iter
            An iterable over batches to be used for training. Batches are
            assumed to be dictionaries, whose values are MXNet arrays that
            correspond to the network inputs.
        validation_iter
            Similar to `train_iter`, but the batches produced here are used
            to compute validation metrics.
        """
        is_validation_available = validation_iter is not None
        self.halt = False

        with tempfile.TemporaryDirectory(
            prefix="gluonts-trainer-temp-"
        ) as gluonts_temp:

            def base_path() -> str:
                return os.path.join(
                    gluonts_temp,
                    "{}_{}".format(STATE_ARTIFACT_FILE_NAME, uuid.uuid4()),
                )

            logger.info("Start model training")

            net.initialize(ctx=self.ctx, init=self.init)

            with HybridContext(
                net=net,
                hybridize=self.hybridize,
                static_alloc=True,
                static_shape=True,
            ):
                batch_size = train_iter.batch_size

                best_epoch_info = {
                    "params_path": "%s-%s.params" % (base_path(), "init"),
                    "epoch_no": -1,
                    "score": np.Inf,
                }

                lr_scheduler = lrs.MetricAttentiveScheduler(
                    objective="min",
                    patience=self.patience,
                    decay_factor=self.learning_rate_decay_factor,
                    min_lr=self.minimum_learning_rate,
                )

                optimizer = mx.optimizer.Adam(
                    learning_rate=self.learning_rate,
                    lr_scheduler=lr_scheduler,
                    wd=self.weight_decay,
                    clip_gradient=self.clip_gradient,
                )

                trainer = mx.gluon.Trainer(
                    net.collect_params(),
                    optimizer=optimizer,
                    kvstore="device",  # FIXME: initialize properly
                )
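
                # Scheduler/optimizer wiring, for reference: the scheduler
                # multiplies the learning rate by `decay_factor` whenever the
                # tracked loss has not improved for `patience` epochs, never
                # going below `min_lr`. E.g. with learning_rate=1e-3,
                # decay_factor=0.5 and patience=10, ten stagnant epochs in a
                # row yield a learning rate of 5e-4.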

                first_forward = True

                def loop(
                    epoch_no, batch_iter, is_training: bool = True
                ) -> mx.metric.Loss:
                    nonlocal first_forward
                    tic = time.time()

                    epoch_loss = mx.metric.Loss()

                    # use averaged model for validation
                    if not is_training and isinstance(
                        self.avg_strategy, IterationAveragingStrategy
                    ):
                        self.avg_strategy.load_averaged_model(net)

                    with tqdm(batch_iter) as it:
                        for batch_no, batch in enumerate(it, start=1):
                            # `batch` here is expected to be a dictionary whose
                            # fields correspond 1-to-1 with the network inputs;
                            # see below how `batch.values()` is fed into the
                            # network
                            if self.halt:
                                break

                            if first_forward:
                                first_forward = False
                                _ = net(*batch.values())
                                if self.post_initialize_cb:
                                    self.post_initialize_cb(net)

                            with mx.autograd.record():
                                output = net(*batch.values())

                                # the network can return several outputs, the
                                # first always being the loss; with multiple
                                # outputs, the forward pass returns a list for
                                # hybridized networks and a tuple otherwise
                                # (we may wrap network outputs in the future to
                                # avoid this type check)
                                if isinstance(output, (list, tuple)):
                                    loss = output[0]
                                else:
                                    loss = output

                            if not np.isfinite(ndarray.sum(loss).asscalar()):
                                logger.warning(
                                    "Batch [%d] of Epoch[%d] gave NaN loss and it will be ignored",
                                    batch_no,
                                    epoch_no,
                                )
                            else:
                                if is_training:
                                    loss.backward()
                                    trainer.step(batch_size)

                                    # iteration averaging in training
                                    if isinstance(
                                        self.avg_strategy,
                                        IterationAveragingStrategy,
                                    ):
                                        self.avg_strategy.apply(net)

                                epoch_loss.update(None, preds=loss)

                            lv = loss_value(epoch_loss)
                            it.set_postfix(
                                ordered_dict={
                                    "epoch": f"{epoch_no + 1}/{self.epochs}",
                                    ("" if is_training else "validation_")
                                    + "avg_epoch_loss": lv,
                                },
                                refresh=False,
                            )

                            # print out the parameter count of the network on
                            # the first pass
                            if batch_no == 1 and epoch_no == 0:
                                net_name = type(net).__name__
                                num_model_param = self.count_model_params(net)
                                logger.info(
                                    f"Number of parameters in {net_name}: {num_model_param}"
                                )

                    # mark epoch end time and log time cost of current epoch
                    toc = time.time()
                    logger.info(
                        "Epoch[%d] Elapsed time %.3f seconds",
                        epoch_no,
                        (toc - tic),
                    )

                    logger.info(
                        "Epoch[%d] Evaluation metric '%s'=%f",
                        epoch_no,
                        ("" if is_training else "validation_") + "epoch_loss",
                        lv,
                    )

                    if not is_training and isinstance(
                        self.avg_strategy, IterationAveragingStrategy
                    ):
                        # bring back the cached model
                        self.avg_strategy.load_cached_model(net)

                    return epoch_loss
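
                # When a validation loader is given, the `epoch_loss` used
                # below for the learning-rate scheduler and for best-epoch
                # selection is the validation loss; otherwise the training
                # loss is used.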

                for epoch_no in range(self.epochs):
                    if self.halt:
                        logger.info(f"Epoch[{epoch_no}] Interrupting training")
                        break

                    curr_lr = trainer.learning_rate
                    logger.info(f"Epoch[{epoch_no}] Learning rate is {curr_lr}")

                    epoch_loss = loop(epoch_no, train_iter)
                    if is_validation_available:
                        epoch_loss = loop(
                            epoch_no, validation_iter, is_training=False
                        )

                    # update the averaging trigger
                    if isinstance(
                        self.avg_strategy, IterationAveragingStrategy
                    ):
                        self.avg_strategy.update_average_trigger(
                            metric=loss_value(epoch_loss), epoch=epoch_no + 1
                        )
                        # once triggered, update the average immediately
                        self.avg_strategy.apply(net)

                    should_continue = lr_scheduler.step(loss_value(epoch_loss))
                    if isinstance(
                        self.avg_strategy, IterationAveragingStrategy
                    ):
                        logger.info(
                            "Overriding early stopping for iteration-based averaging strategies."
                        )
                        should_continue = True
                    if not should_continue:
                        logger.info("Stopping training")
                        break

                    # save model and epoch info
                    bp = base_path()
                    epoch_info = {
                        "params_path": f"{bp}-0000.params",
                        "epoch_no": epoch_no,
                        "score": loss_value(epoch_loss),
                    }
                    net.save_parameters(
                        epoch_info["params_path"]
                    )  # TODO: handle possible exception
                    save_epoch_info(bp, epoch_info)

                    # update best epoch info - needed for the learning rate
                    # scheduler
                    if loss_value(epoch_loss) < best_epoch_info["score"]:
                        best_epoch_info = epoch_info.copy()

                    if trainer.learning_rate != curr_lr:
                        if best_epoch_info["epoch_no"] == -1:
                            raise GluonTSUserError(
                                "Got NaN in first epoch. Try reducing initial learning rate."
                            )

                        # the scheduler lowered the learning rate, so roll the
                        # network back to the best parameters seen so far
                        logger.info(
                            f"Loading parameters from best epoch "
                            f"({best_epoch_info['epoch_no']})"
                        )
                        net.load_parameters(
                            best_epoch_info["params_path"], self.ctx
                        )

                if isinstance(self.avg_strategy, AveragingStrategy):
                    logger.info("Computing averaged parameters.")
                    averaged_params_path = self.avg_strategy.apply(
                        gluonts_temp
                    )
                    logger.info("Loading averaged parameters.")
                    net.load_parameters(averaged_params_path, self.ctx)

                if isinstance(self.avg_strategy, IterationAveragingStrategy):
                    logger.info("Loading averaged parameters.")
                    self.avg_strategy.load_averaged_model(net)

                logger.info("End model training")