Add TimeLimitCallback to mx/trainer callbacks. #1631

Merged
16 commits merged into awslabs:dev on Jun 15, 2022

Conversation

@yx1215 (Contributor) commented on Jul 19, 2021

Issue #, if available:

Description of changes:
Add TimeLimitCallback so that users can set a time limit on the training process.
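
For illustration, usage could look roughly like the sketch below (it assumes the callback name proposed in this PR and the mx Trainer's callbacks argument; DeepAREstimator is only an example model, and names may differ from the merged code):

```python
from gluonts.model.deepar import DeepAREstimator
from gluonts.mx.trainer import Trainer
from gluonts.mx.trainer.callback import TimeLimitCallback  # name as proposed here

# Cap training at roughly one hour of wall-clock time, independent of epochs.
estimator = DeepAREstimator(
    freq="H",
    prediction_length=24,
    trainer=Trainer(
        epochs=9999,  # effectively unbounded; the callback ends training
        callbacks=[TimeLimitCallback(time_limit=3600)],
    ),
)
```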

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@@ -338,3 +339,31 @@ def on_network_initializing_end(
self, training_network: nn.HybridBlock
) -> None:
copy_parameters(self.predictor.prediction_net, training_network)


class TimeLimitCallback(Callback):
Contributor

Can we have a short doc-string describing the class and its parameters?

Contributor

Let's give it a more descriptive name.

Contributor Author

How about TrainingTimeLimitCallback?


class TimeLimitCallback(Callback):
    @validated()
    def __init__(self, time_limit=None):
Contributor

`validated` only really makes sense if you have type annotations.
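
For context, `validated` builds its validation (and serialization) model from the constructor's type annotations, so an un-annotated signature gives it little to work with. A minimal sketch, assuming the class shape under discussion:

```python
from typing import Optional

from gluonts.core.component import validated
from gluonts.mx.trainer.callback import Callback


class TimeLimitCallback(Callback):
    @validated()
    def __init__(self, time_limit: Optional[int] = None) -> None:
        # The annotation lets `validated` type-check `time_limit` and makes
        # the callback round-trip through gluonts' serialization helpers.
        self.time_limit = time_limit
        self.start_time: Optional[float] = None
```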

Contributor

Suggested change
-    def __init__(self, time_limit=None):
+    def __init__(self, time_limit: Optional[int] = None) -> None:

Comment on lines 347 to 378
self.start_time = None
self.time_limit = time_limit
Contributor

Can we reverse these? Parameters which are passed should be handled first.

if self.time_limit is not None:
    cur_time = time.time()
    if cur_time - self.start_time > self.time_limit:
        logging.warning(
Contributor

We don't want to use `logging` directly, but use a logger instance instead.
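
The usual pattern, for reference (a sketch; the module may already define such a logger):

```python
import logging

logger = logging.getLogger(__name__)

# ...inside the callback, instead of logging.warning(...):
logger.warning("Time limit exceeded during training, stopping training.")
```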

@jaheba (Contributor) left a comment

Before I forget: Thanks for the PR!

Comment on lines 373 to 374
cur_time = time.time()
if cur_time - self.start_time > self.time_limit:
Contributor

Suggested change
-    cur_time = time.time()
-    if cur_time - self.start_time > self.time_limit:
+    elapsed = time.time() - self.start_time
+    if elapsed > self.time_limit:

@@ -338,3 +341,39 @@ def on_network_initializing_end(
self, training_network: nn.HybridBlock
) -> None:
copy_parameters(self.predictor.prediction_net, training_network)


class TrainingTimeLimitCallback(Callback):
Contributor

I think we omitted the Callback suffix from the names of other callbacks.

Suggested change
- class TrainingTimeLimitCallback(Callback):
+ class TrainingTimeLimit(Callback):


class TimeLimitCallback(Callback):
    @validated()
    def __init__(self, time_limit=None):
Contributor

Suggested change
-    def __init__(self, time_limit=None):
+    def __init__(self, time_limit: Optional[int] = None) -> None:

Comment on lines 354 to 355
time_limit: int
    time in seconds, after which your training process will end.
Contributor

Suggested change
- time_limit: int
-     time in seconds, after which your training process will end.
+ time_limit: int
+     time in seconds, after which training ends

src/gluonts/mx/trainer/callback.py (outdated review thread, resolved)
cur_time = time.time()
if cur_time - self.start_time > self.time_limit:
    logger.warning(
        "Time limit exceed during training, stop training."
Contributor

Suggested change
-    "Time limit exceed during training, stop training."
+    "Time limit exceeded during training, stopping training."

@jaheba requested a review from borchero on July 19, 2021 16:00
@jaheba (Contributor) commented on Jul 19, 2021

@borchero Can you also take a look?

@borchero (Contributor) left a comment

This is very nice in general; I was implementing something similar very recently :D However, I would consider a couple more points:

  • Should we time validation? We could "stop" timing on_validation_epoch_start and resume on on_train_epoch_start.
  • Should we extend the Callback base class to return a value from on_train_batch_end? This way, we could stop training after the first batch exceeding the time limit instead of the first epoch (which is more useful in my opinion).
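
For the second point, the shape of the change could be something like the following sketch (illustrative only; not the actual base-class code):

```python
class Callback:
    # ...existing hooks...

    def on_train_batch_end(self, training_network) -> bool:
        # Return False to ask the trainer to stop training early;
        # returning True keeps the previous behaviour.
        return True


# Schematically, the trainer's batch loop would then do something like:
#
#   should_continue = all(cb.on_train_batch_end(net) for cb in callbacks)
#   if not should_continue:
#       break
```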

src/gluonts/mx/trainer/callback.py (outdated review thread, resolved)
src/gluonts/mx/trainer/callback.py (outdated review thread, resolved)
if self.time_limit is not None:
    elapsed = time.time() - self.start_time
    if elapsed > self.time_limit:
        logger.warning(
Contributor

logger.info?

@jaheba (Contributor) commented on Jul 19, 2021

  • Should we time validation? We could "stop" timing on_validation_epoch_start and resume on on_train_epoch_start.

What is the intuition here? Why would I want to stop training after a certain amount of time?

  • Should we extend the Callback base class to return a value from on_train_batch_end? This way, we could stop training after the first batch exceeding the time limit instead of the first epoch (which is more useful in my opinion).

Would there be other places where we also would like to be able to stop training?

@yx1215 (Contributor, Author) commented on Jul 20, 2021

  • Should we time validation? We could "stop" timing on_validation_epoch_start and resume on on_train_epoch_start.
  • What is the intuition here? Why would I want to stop training after a certain amount of time?

Sometimes users might only have a limited time budget. They won't know how many epochs they can run within that limit, so they can set the epoch count to 9999 and give a time limit to make training stop once the time is used up.

  • Should we extend the Callback base class to return a value from on_train_batch_end? This way, we could stop training after the first batch exceeding the time limit instead of the first epoch (which is more useful in my opinion).
  • Would there be other places where we also would like to be able to stop training?

Maybe we can do what @borchero says: record only the training process, not the validation process. And if we want to be more precise with the time limit, we can check after every batch.

@borchero (Contributor) commented on Jul 20, 2021

What is the intuition here? Why would I want to stop training after a certain amount of time?

One very common use case is hyperparameter optimization. In the successive halving algorithm, you train many configurations for a "budget" N, then take the best 50% of configurations, train for a total budget of 2N and so on ... While you can look at the budget as the number of epochs, it is often useful to actually use time as the budget (as it also often determines the money spent) -- especially if you want your budget to be independent of model size and/or dataset size.

Would there be other places where we also would like to be able to stop training?

I think we could consider allowing that after every hook which is called at the end of some iteration (i.e. end of training/validation batch, end of training/validation epoch, end of epoch).

@borchero (Contributor)

Maybe we can do what @borchero says: record only the training process, not the validation process. And if we want to be more precise with the time limit, we can check after every batch.

I would say that we should have a flag in the __init__ of the callback to determine whether validation should be recorded.

@jaheba (Contributor) commented on Jul 20, 2021

I think I'm against treating validation differently and having a net-train mode. It makes the code more complicated and is less intuitive (if I have a budget, I don't care too much how it is spent). If we realise there is still a need for this, we can add it later.

@jaheba (Contributor) commented on Jul 20, 2021

We can have a flag, which controls whether we check after each batch or after each full epoch.

What happens if we stop after a batch? Is that entire epoch invalidated?

@borchero (Contributor)

I think I'm against treating validation differently and having a net-train mode. It makes the code more complicated and is less intuitive (if I have a budget, I don't care too much how it is spent). If we realise there is still a need for this, we can add it later.

I actually needed exactly this (i.e. only track training time and not callbacks/validation, since callbacks were potentially very time-consuming) a few weeks ago and ended up rewriting the Trainer class (for some other reasons as well). It would be nice to have this option out-of-the-box, and it's not too much work imo.

@borchero (Contributor)

What happens if we stop after a batch? Is that entire epoch invalidated?

I would just stop the epoch prematurely and treat it like the epoch has been completed.

@jaheba (Contributor) commented on Jul 20, 2021

I think I'm against treating validation differently and having a net-train mode. It makes the code more complicated and is less intuitive (if I have a budget, I don't care too much how it is spent). If we realise there is still a need for this, we can add it later.

I actually needed exactly this (i.e. only track training time and not callbacks/validation, since callbacks were potentially very time-consuming) a few weeks ago and ended up rewriting the Trainer class (for some other reasons as well). It would be nice to have this option out-of-the-box, and it's not too much work imo.

Fair enough.

@jaheba (Contributor) commented on Jul 20, 2021

What happens if we stop after a batch? Is that entire epoch invalidated?

I would just stop the epoch prematurely and treat it like the epoch has been completed.

Then let's do this.

@yx1215 (Contributor, Author) commented on Jul 20, 2021

I'm writing the conclusion here to make sure we are on the same page.
We will have to add the following:

  • a flag that controls whether validation epochs should be counted toward the time limit
  • a flag that controls whether we stop at the end of each epoch or after each batch; if we stop after a batch, we will treat it as if the whole epoch had ended.

Did I miss anything?
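
Putting the agreed points together, a callback along these lines could look roughly like the sketch below (flag and attribute names taken from later revisions in this thread; hook signatures are simplified, see the Callback base class for the exact ones):

```python
import logging
import time
from typing import Optional

from mxnet.gluon import nn

from gluonts.core.component import validated
from gluonts.mx.trainer.callback import Callback

logger = logging.getLogger(__name__)


class TrainingTimeLimit(Callback):
    """Stop training once `time_limit` seconds have been spent, optionally
    excluding validation time, and optionally stopping mid-epoch."""

    @validated()
    def __init__(
        self,
        time_limit: int,
        track_validation_duration: bool = True,
        stop_during_epoch: bool = False,
    ) -> None:
        self.time_limit = time_limit
        self.track_validation_duration = track_validation_duration
        self.stop_during_epoch = stop_during_epoch
        self.time_spent = 0.0
        self.checkpoint: Optional[float] = None

    def on_train_start(self, max_epochs: int) -> None:
        self.checkpoint = time.time()

    def _bank_time(self) -> None:
        # Add the time elapsed since the last checkpoint to the budget used.
        now = time.time()
        if self.checkpoint is not None:
            self.time_spent += now - self.checkpoint
        self.checkpoint = now

    def _should_continue(self) -> bool:
        if self.time_spent > self.time_limit:
            logger.warning(
                "Time limit exceeded during training, stopping training."
            )
            return False
        return True

    def on_validation_epoch_start(
        self, training_network: nn.HybridBlock
    ) -> None:
        # Bank the training time accumulated so far; if validation is not
        # tracked, pause the clock until training resumes.
        self._bank_time()
        if not self.track_validation_duration:
            self.checkpoint = None

    def on_train_epoch_start(self, training_network: nn.HybridBlock) -> None:
        if self.checkpoint is None:
            self.checkpoint = time.time()

    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
        self._bank_time()
        return self._should_continue() if self.stop_during_epoch else True

    def on_train_epoch_end(self, training_network: nn.HybridBlock) -> bool:
        self._bank_time()
        return self._should_continue()
```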

@borchero (Contributor) left a comment

Nice, thanks!!

src/gluonts/mx/trainer/callback.py (review thread, resolved)
src/gluonts/mx/trainer/callback.py (review thread, resolved)
def __init__(
    self,
    time_limit: int,
    include_validation_in_time_limit: bool = True,
Contributor

What about track_validation_duration?

@@ -355,25 +362,96 @@ def __init__(self, time_limit: Optional[int] = None) -> None:
time in seconds, after which training ends
Contributor

We should properly document the parameters.

def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
    print(
        "on_train_batch_end", self.time_spent
    )  # for debugging purpose, will be deleted before merging
Contributor

Just to keep track of the comment^^

self.checkpoint = time.time()
print(
    "on_train_epoch_end", self.time_spent
)  # for debugging purpose, will be deleted before merging
Contributor

Only for tracking.

self.time_spent += time.time() - self.checkpoint
self.checkpoint = time.time()
if self.stop_during_epoch:
    if self.time_spent > self.time_limit:
Contributor

Log for consistency?

@yx1215 (Contributor, Author), Jul 21, 2021

I think we don't need a log here; otherwise the message will be printed twice when we stop after one batch (once after the batch, and once after the epoch), because regardless of whether we are to stop after one batch, the time limit will always be checked after each epoch.


if self.stop_during_epoch:
    if self.time_spent > self.time_limit:
        return False
Contributor

Logging for consistency

Contributor Author

Same as above.

@jaheba (Contributor) commented on Jul 24, 2021

Looks like there is some code duplication. Can we make the checking more reusable?

@yx1215 (Contributor, Author) commented on Jul 25, 2021

Looks like there is some code duplication. Can we make the checking more reusable?

I've just reduced the redundancy of the code.

@@ -404,6 +408,8 @@ def loop( # todo call run epoch
logger.info(
    f"Number of parameters in {net_name}: {num_model_param}"
)
if not should_continue:
    break
Contributor

This doesn't set the outer should_continue and thus we will call loop again and again.

Contributor Author

I fixed this by setting self.halt=True before we break from the loop.
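
Schematically, the fix described here looks like this (a sketch; the real trainer loop has more going on):

```python
# inside the trainer's loop, after the callbacks have been consulted:
if not should_continue:
    self.halt = True  # let the outer loop see the stop signal
    break
```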

def __init__(
    self,
    time_limit: int,
    track_validation_duration: bool = True,
Contributor

I think something like use_net_training is better, since time can be spent in a lot of places, including validation.

Contributor Author

I'm not sure what this actually means.... What would we use use_net_training to record?

def on_train_start(self, max_epochs: int) -> None:
    self.checkpoint = time.time()

def should_continue_by_timelimit(self, record_time=True, should_stop=True):
Contributor

I think this should be two separate methods, which do not have default arguments. It's not really clear to me at a quick glance what this is supposed to be doing.
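
One possible split, just to illustrate the shape (hypothetical method names, to be defined inside the callback):

```python
def _record_time(self) -> None:
    # Bank the time elapsed since the last checkpoint.
    now = time.time()
    self.time_spent += now - self.checkpoint
    self.checkpoint = now

def _within_time_limit(self) -> bool:
    # Pure check: no side effects and no default arguments.
    return self.time_spent <= self.time_limit
```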

@yx1215 requested a review from borchero on August 4, 2021 18:47
@lostella added the "new feature" and "BREAKING" labels on Aug 9, 2021
@lostella (Contributor)

Marked as breaking because of the change in the output signature for the on_train_batch_end and on_validation_batch_end methods
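
Concretely, the breaking part is the new boolean return value (a sketch of the changed signatures; see the merged callback.py for the exact ones):

```python
def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
    ...  # return False to stop training early

def on_validation_batch_end(self, training_network: nn.HybridBlock) -> bool:
    ...  # return False to stop training early
```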

@lostella added this to the v0.9 milestone on Aug 24, 2021
@lostella modified the milestones: v0.9, v0.10 on Feb 17, 2022
@jaheba changed the title from "add TimeLimitCallback in callback.py" to "Add TimeLimitCallback to mx/trainer callbacks." on Jun 14, 2022
@jaheba (Contributor) commented on Jun 14, 2022

@lostella Should we have tests, at least for serde?

@lostella (Contributor)

@lostella Should we have tests, at least for serde?

Yes that would be good. For example the TerminateOnNan callback doesn't seem to use any of the serialization mechanisms other classes rely on.
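
A minimal serde round-trip test could look like the following sketch (it assumes the callback name and arguments discussed above; the merged names may differ):

```python
from gluonts.core import serde
from gluonts.mx.trainer.callback import TrainingTimeLimit


def test_training_time_limit_serde() -> None:
    callback = TrainingTimeLimit(time_limit=60, stop_during_epoch=True)
    decoded = serde.decode(serde.encode(callback))
    # The round-tripped object should encode back to the same representation.
    assert serde.encode(decoded) == serde.encode(callback)
```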

@jaheba (Contributor) commented on Jun 15, 2022

However, I'm not sure we can really test it functionally. It would also take some time, which we don't have.

@jaheba merged commit 7032ada into awslabs:dev on Jun 15, 2022
@jaheba mentioned this pull request on Jun 17, 2022
kashif pushed a commit to kashif/gluon-ts that referenced this pull request Jun 24, 2022
Co-authored-by: Jasper <schjaspe@amazon.de>
Labels
BREAKING (this is a breaking change), new feature

4 participants