Add TimeLimitCallback to mx/trainer callbacks #1631

Conversation
src/gluonts/mx/trainer/callback.py
Outdated
@@ -338,3 +339,31 @@ def on_network_initializing_end(
         self, training_network: nn.HybridBlock
     ) -> None:
         copy_parameters(self.predictor.prediction_net, training_network)
+
+
+class TimeLimitCallback(Callback):
Can we have a short doc-string describing the class and its parameters?
Let's give it a more descriptive name.
How about TrainingTimeLimitCallback?
src/gluonts/mx/trainer/callback.py
Outdated
class TimeLimitCallback(Callback):
    @validated()
    def __init__(self, time_limit=None):
validated only really makes sense if you have type annotations.
Suggested change:
-    def __init__(self, time_limit=None):
+    def __init__(self, time_limit: Optional[int] = None) -> None:
src/gluonts/mx/trainer/callback.py
Outdated
        self.start_time = None
        self.time_limit = time_limit
Can we reverse these? Parameters which are passed should be handled first.
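Concretely, the reordered constructor would look something like this (a sketch, folding in the earlier type-annotation suggestion):

    def __init__(self, time_limit: Optional[int] = None) -> None:
        self.time_limit = time_limit  # passed parameter handled first
        self.start_time = None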
src/gluonts/mx/trainer/callback.py
Outdated
        if self.time_limit is not None:
            cur_time = time.time()
            if cur_time - self.start_time > self.time_limit:
                logging.warning(
We don't want to use logging directly, but use a logger instance instead.
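What the reviewer means, as a minimal sketch: create a module-level logger once and call it instead of the logging module directly.

import logging

logger = logging.getLogger(__name__)  # module-level logger instance

# later, inside the callback:
logger.warning("Time limit exceeded during training, stopping training.")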
Before I forget: Thanks for the PR!
src/gluonts/mx/trainer/callback.py
Outdated
            cur_time = time.time()
            if cur_time - self.start_time > self.time_limit:
Suggested change:
-            cur_time = time.time()
-            if cur_time - self.start_time > self.time_limit:
+            elapsed = time.time() - self.start_time
+            if elapsed > self.time_limit:
src/gluonts/mx/trainer/callback.py
Outdated
@@ -338,3 +341,39 @@ def on_network_initializing_end(
         self, training_network: nn.HybridBlock
     ) -> None:
         copy_parameters(self.predictor.prediction_net, training_network)
+
+
+class TrainingTimeLimitCallback(Callback):
I think we omitted the Callback suffix from the names of other callbacks.

Suggested change:
-class TrainingTimeLimitCallback(Callback):
+class TrainingTimeLimit(Callback):
src/gluonts/mx/trainer/callback.py
Outdated
class TimeLimitCallback(Callback):
    @validated()
    def __init__(self, time_limit=None):
Suggested change:
-    def __init__(self, time_limit=None):
+    def __init__(self, time_limit: Optional[int] = None) -> None:
src/gluonts/mx/trainer/callback.py
Outdated
    time_limit: int
        time in seconds, after which your training process will end.
Suggested change:
-    time_limit: int
-        time in seconds, after which your training process will end.
+    time_limit: int
+        time in seconds, after which training ends
src/gluonts/mx/trainer/callback.py
Outdated
            cur_time = time.time()
            if cur_time - self.start_time > self.time_limit:
                logger.warning(
                    "Time limit exceed during training, stop training."
Suggested change:
-                    "Time limit exceed during training, stop training."
+                    "Time limit exceeded during training, stopping training."
@borchero Can you also take a look?
This is very nice in general, I was implementing something similar very recently :D However, I would consider a couple more points:
- Should we time validation? We could "stop" timing on on_validation_epoch_start and resume on on_train_epoch_start.
- Should we extend the Callback base class to return a value from on_train_batch_end? This way, we could stop training after the first batch exceeding the time limit instead of after the first epoch (which is more useful in my opinion).
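A minimal sketch of that second point, assuming hooks return True to continue and False to stop (the direction the PR ends up taking):

import time
from typing import Optional

from gluonts.core.component import validated
from gluonts.mx.trainer.callback import Callback


class TimeLimitCallback(Callback):
    @validated()
    def __init__(self, time_limit: Optional[int] = None) -> None:
        self.time_limit = time_limit
        self.start_time: Optional[float] = None

    def on_train_start(self, max_epochs: int) -> None:
        self.start_time = time.time()

    def on_train_batch_end(self, training_network) -> bool:
        # Returning False signals the trainer to stop after this batch.
        if self.time_limit is not None:
            elapsed = time.time() - self.start_time
            if elapsed > self.time_limit:
                return False
        return True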
src/gluonts/mx/trainer/callback.py
Outdated
        if self.time_limit is not None:
            elapsed = time.time() - self.start_time
            if elapsed > self.time_limit:
                logger.warning(
logger.info?
What is the intuition here? Why would I want to stop training after a certain amount of time?
Would there be other places where we also would like to be able to stop training?
One very common use case is hyperparameter optimization. In the successive halving algorithm, you train many configurations for a "budget" N, then take the best 50% of configurations, train for a total budget of 2N and so on ... While you can look at the budget as the number of epochs, it is often useful to actually use time as the budget (as it also often determines the money spent) -- especially if you want your budget to be independent of model size and/or dataset size.
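To make that concrete, a toy sketch of successive halving with time as the budget (all names here are hypothetical, not gluonts API):

def successive_halving(configs, train_and_score, initial_budget, rounds):
    # train_and_score(cfg, time_limit) is assumed to (resume) training of
    # config `cfg` for `time_limit` seconds and return a validation loss.
    budget = initial_budget
    for _ in range(rounds):
        scores = {cfg: train_and_score(cfg, time_limit=budget) for cfg in configs}
        # keep the best half, double the budget for the next round
        configs = sorted(configs, key=scores.get)[: max(1, len(configs) // 2)]
        budget *= 2
    return configs[0]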
I think we could consider allowing this after every hook which is called at the end of some iteration (i.e. end of training/validation batch, end of training/validation epoch).
I would say that we should have a flag in the […]
I think I'm against treating validation differently and having a net-train mode. It makes the code more complicated and is less intuitive (if I have a budget, I don't care too much how it is spent). If we realise there is still a need for this, we can add it later.
We can have a flag which controls whether we check after each batch or after each full epoch. What happens if we stop after a batch? Is that entire epoch invalidated?
I actually needed exactly this (i.e. only track training time and not callbacks/validation, since callbacks were potentially very time-consuming) a few weeks ago and ended up rewriting the […]
I would just stop the epoch prematurely and treat it like the epoch has been completed.
Fair enough.
Then let's do this.
I'm writing the conclusion here to make sure we are on the same page:
- The time limit covers total training time, including validation (no separate net-training mode).
- A flag controls whether we stop after a batch (mid-epoch) or only after a full epoch.
- If we stop mid-epoch, the partial epoch is treated as if it had been completed.

Did I miss anything?
Nice, thanks!!
src/gluonts/mx/trainer/callback.py
Outdated
    def __init__(
        self,
        time_limit: int,
        include_validation_in_time_limit: bool = True,
What about track_validation_duration?
src/gluonts/mx/trainer/callback.py
Outdated
@@ -355,25 +362,96 @@ def __init__(self, time_limit: Optional[int] = None) -> None:
         time in seconds, after which training ends
We should properly document the parameters.
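For example, something along these lines (wording is only a suggestion; parameter names follow the review discussion):

class TrainingTimeLimit(Callback):
    """
    Stop training when a time limit is exceeded.

    Parameters
    ----------
    time_limit
        Time in seconds after which training ends.
    stop_during_epoch
        If True, the limit is also checked after each batch, so training
        may stop mid-epoch; otherwise it is only checked at the end of
        each epoch.
    """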
src/gluonts/mx/trainer/callback.py
Outdated
    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
        print(
            "on_train_batch_end", self.time_spent
        )  # for debugging purpose, will be deleted before merging
Just to keep track of the comment^^
src/gluonts/mx/trainer/callback.py
Outdated
        self.checkpoint = time.time()
        print(
            "on_train_epoch_end", self.time_spent
        )  # for debugging purpose, will be deleted before merging
Only for tracking.
src/gluonts/mx/trainer/callback.py
Outdated
        self.time_spent += time.time() - self.checkpoint
        self.checkpoint = time.time()
        if self.stop_during_epoch:
            if self.time_spent > self.time_limit:
Log for consistency?
I think we don't need a log here; otherwise the log would be printed twice when we stop after one batch (once after the batch, and once after the epoch), because regardless of whether we stop after a batch, the time limit is always checked after the epoch.
src/gluonts/mx/trainer/callback.py
Outdated
        if self.stop_during_epoch:
            if self.time_spent > self.time_limit:
                return False
Logging for consistency
Same as above.
Looks like there is some code duplication. Can we make the checking more reusable?
I've just reduced the redundancy of the code.
@@ -404,6 +408,8 @@ def loop(  # todo call run epoch
             logger.info(
                 f"Number of parameters in {net_name}: {num_model_param}"
             )
+            if not should_continue:
+                break
This doesn't set the outer should_continue, and thus we will call loop again and again.
I fixed this by setting self.halt = True before we break from the loop.
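As a sketch, the fixed break inside the inner loop (surrounding structure assumed from the diff above):

            if not should_continue:
                self.halt = True  # make the outer loop stop as well, not just this epoch
                break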
src/gluonts/mx/trainer/callback.py
Outdated
    def __init__(
        self,
        time_limit: int,
        track_validation_duration: bool = True,
I think something like use_net_training is better, since time can be spent in a lot of places, including validation.
I'm not sure what this actually means... What would we use use_net_training to record?
src/gluonts/mx/trainer/callback.py
Outdated
    def on_train_start(self, max_epochs: int) -> None:
        self.checkpoint = time.time()

    def should_continue_by_timelimit(self, record_time=True, should_stop=True):
I think this should be two separate methods, which do not have default arguments. It's not really clear to me at a quick glance what this is supposed to be doing.
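A sketch of that split into two methods without default arguments (method names here are assumptions, not from the PR):

    def record_time_spent(self) -> None:
        # accumulate the time elapsed since the last checkpoint
        now = time.time()
        self.time_spent += now - self.checkpoint
        self.checkpoint = now

    def time_limit_exceeded(self) -> bool:
        return self.time_limit is not None and self.time_spent > self.time_limit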
Marked as breaking because of the change in the output signature for the […]
Force-pushed from 938a2ec to cd19b02: Add TimeLimitCallback to mx/trainer callbacks.
@lostella Should we have tests, at least for […]
Yes, that would be good. For example the […]
However, I'm not sure we can really test it functionally. It would also take some time, which we don't have.
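A rough sketch of a unit test exercising the callback in isolation (class, flag, and hook names assumed from the review discussion, not confirmed by the PR):

import time

from gluonts.mx.trainer.callback import TrainingTimeLimit


def test_time_limit_stops_after_first_batch():
    cb = TrainingTimeLimit(time_limit=0, stop_during_epoch=True)
    cb.on_train_start(max_epochs=10)
    time.sleep(0.01)
    # with a zero-second limit, the very first batch should request a stop
    assert cb.on_train_batch_end(training_network=None) is False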
Co-authored-by: Jasper <schjaspe@amazon.de>
Issue #, if available:
Description of changes:
Add TimeLimitCallback so that users can set a time limit on the training process.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Please tag this PR with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup