
Make the trainer more configurable #3913

schmmd opened this issue Mar 6, 2020 · 0 comments · Fixed by #3970


@schmmd commented Mar 6, 2020

From #3519 (comment):

Ok, I'm scoping this out right now, taking the approach of adding a few targeted callbacks to the original trainer and simplifying what's there to make it more modular. Here's a list of specific items I'm thinking of doing; I'm putting this here to get feedback before starting. My process for deciding what to change was to (1) see which parameters to __init__ I can remove or push into other, configurable dependencies, and (2) see what other places might want some customization, and propose specific callbacks for them.

  1. Change the name of TrainerBase to Trainer, and Trainer to GradientDescentTrainer. This is just a cosmetic change that makes the trainer code consistent with other naming conventions in the library, and makes it so we can use Trainer annotations in places where it's appropriate, instead of the more awkward-looking TrainerBase. I could maybe be persuaded to just make it one class, even, that's still Registrable and can be subclassed (like Embedding is).
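The rename in item 1 hinges on the trainer base class being Registrable. As a rough illustration of what that pattern looks like (a minimal stand-in for AllenNLP's Registrable, not its actual implementation), the proposed naming could work like this:

```python
# Minimal sketch of the proposed naming: Trainer is a registrable base
# class, and GradientDescentTrainer is the default implementation that
# registers itself by name. This mimics AllenNLP's Registrable pattern
# in plain Python; everything except the two class names is illustrative.

class Trainer:
    """Base class; concrete trainers register themselves by name."""
    _registry: dict = {}

    @classmethod
    def register(cls, name):
        def decorator(subclass):
            cls._registry[name] = subclass
            return subclass
        return decorator

    @classmethod
    def by_name(cls, name):
        return cls._registry[name]

    def train(self):
        raise NotImplementedError


@Trainer.register("gradient_descent")
class GradientDescentTrainer(Trainer):
    def train(self):
        return "training with gradient descent"
```

With this shape, type annotations throughout the library can say `Trainer` rather than the more awkward `TrainerBase`, and config files can still select an implementation by its registered name.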

  2. Make TensorboardWriter an argument to __init__, and to from_partial_objects (with a Lazy annotation), so I can remove several of the __init__ parameters that are specific to the tensorboard writer. Specifically, that's these parameters:

    summary_interval: int = 100,
    histogram_interval: int = None,
    should_log_parameter_statistics: bool = True,
    should_log_learning_rate: bool = False,
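To make the move in item 2 concrete, here is a hedged sketch of a TensorboardWriter that owns those four parameters itself, so the trainer no longer needs them. The class and method names below are illustrative, not AllenNLP's exact API:

```python
from typing import Optional

# Sketch: a TensorboardWriter that carries the logging-frequency settings
# the trainer used to hold. The trainer would receive a constructed (or
# Lazy) instance and simply ask it whether to log.

class TensorboardWriter:
    def __init__(
        self,
        summary_interval: int = 100,
        histogram_interval: Optional[int] = None,
        should_log_parameter_statistics: bool = True,
        should_log_learning_rate: bool = False,
    ) -> None:
        self._summary_interval = summary_interval
        self._histogram_interval = histogram_interval
        self._should_log_parameter_statistics = should_log_parameter_statistics
        self._should_log_learning_rate = should_log_learning_rate
        self._batch_num_total = 0  # updated by the trainer each batch

    def should_log_this_batch(self) -> bool:
        return self._batch_num_total % self._summary_interval == 0

    def should_log_histograms_this_batch(self) -> bool:
        return (
            self._histogram_interval is not None
            and self._batch_num_total % self._histogram_interval == 0
        )
```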

  3. Move the tensorboard logging logic into the TensorboardWriter class, so it's more easily configurable with your own TensorboardWriter:

    if self._tensorboard.should_log_this_batch() and self._master:
        self._tensorboard.log_parameter_and_gradient_statistics(self.model, batch_grad_norm)
        self._tensorboard.log_learning_rates(self.model, self.optimizer)
        self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"])
        self._tensorboard.log_metrics({"epoch_metrics/" + k: v for k, v in metrics.items()})

    if self._tensorboard.should_log_histograms_this_batch() and self._master:
        self._tensorboard.log_histograms(self.model, histogram_parameters)

    if self._log_batch_size_period:
        batch_group_size = sum(training_util.get_batch_size(batch) for batch in batch_group)
        cumulative_batch_group_size += batch_group_size
        if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
            average = cumulative_batch_group_size / batches_this_epoch
            logger.info(f"current batch size: {batch_group_size} mean batch size: {average}")
            self._tensorboard.add_train_scalar("current_batch_size", batch_group_size)
            self._tensorboard.add_train_scalar("mean_batch_size", average)

    This also could let us remove this parameter:
    log_batch_size_period: Optional[int] = None,
    Question here: I could generalize TensorboardWriter to handle other kinds of logging, and just add a bunch of methods / calls to it, which might simplify some things. If the only reason we want callbacks is to provide better logging, we could accomplish that with a specific, overridable TrainLogger class, or something, which is a generalization of the TensorboardWriter. Doing this would probably also let me completely remove the serialization_dir parameter to the Trainer. It's tempting to also rip out tqdm and roll our own thing somehow (or figure out a way to push it into the TrainLogger, or whatever it gets called), as that would simplify the trainer and fix some bugs. Decision on the question: add a couple of simple callbacks, one for the end of each batch, and one for the end of each epoch. See item 7.
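The shape of the item 3 move can be sketched as follows: the trainer makes one call per batch and the writer applies its own interval logic internally. This is a rough illustration with a dict standing in for a real TensorBoard file, and all names are illustrative rather than the final API:

```python
# Sketch of item 3: the per-batch interval check and scalar logging live
# inside the writer, so the trainer just calls log_batch(...). Scalars are
# recorded in a dict here instead of written to a TensorBoard event file.

class TensorboardWriter:
    def __init__(self, summary_interval: int = 100) -> None:
        self._summary_interval = summary_interval
        self.scalars = {}

    def add_train_scalar(self, name: str, value: float) -> None:
        self.scalars[name] = value

    def log_batch(self, metrics: dict, batch_num_total: int) -> None:
        # The interval check now lives here, not in the trainer.
        if batch_num_total % self._summary_interval == 0:
            self.add_train_scalar("loss/loss_train", metrics["loss"])
            for k, v in metrics.items():
                self.add_train_scalar("epoch_metrics/" + k, v)
```

A user who wants different logging behavior then overrides `log_batch` in a TensorboardWriter subclass instead of subclassing the whole trainer.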

  4. Remove these two parameters, as they are redundant with the Checkpointer (and were already removed from from_partial_objects):

    num_serialized_models_to_keep: int = 20,
    keep_serialized_model_every_num_seconds: int = None,
    I'll make sure that the default Checkpointer behavior is reasonable if none is passed. Also, move this parameter into the Checkpointer and remove it from the Trainer:
    model_save_interval: float = None,
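After the item 4 move, the Checkpointer would own all three settings. A hedged sketch of that shape (a hypothetical simplified API; the real Checkpointer also serializes model and training state to disk):

```python
from typing import Optional

# Sketch of item 4: the Checkpointer carries the retention and
# save-interval parameters the trainer used to hold. Only the retention
# logic is shown; checkpoint "identifiers" stand in for files on disk.

class Checkpointer:
    def __init__(
        self,
        num_serialized_models_to_keep: int = 20,
        keep_serialized_model_every_num_seconds: Optional[int] = None,
        model_save_interval: Optional[float] = None,
    ) -> None:
        self._num_to_keep = num_serialized_models_to_keep
        self._keep_every_seconds = keep_serialized_model_every_num_seconds
        self._model_save_interval = model_save_interval
        self._saved = []  # checkpoint identifiers, oldest first

    def save_checkpoint(self, epoch: int) -> None:
        self._saved.append(epoch)
        # Prune the oldest checkpoints beyond the retention limit.
        while len(self._saved) > self._num_to_keep:
            self._saved.pop(0)
```

A trainer constructed without an explicit Checkpointer would then get one with these defaults, which is the "reasonable default behavior" mentioned above.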

  5. On grad_norm and grad_clipping: I could try to remove these, but I'm not sure it's worth it. This would mean adding some kind of callback, except it's a weird callback, because clipping requires setting hooks on the model before training starts. I'm leaning towards just leaving these two alone; they don't add much complexity to the Trainer, and we don't get issues asking to make them more configurable than they already are. (Nothing to do, crossing it off the list)

  6. I don't think num_gradient_accumulation_steps can be simplified or put into a callback. There are three parameters for distributed training, which I could maybe simplify to one object, but that doesn't really seem worth it. All of the rest of the parameters are either core training parameters or configurable objects that don't seem like they need to be further separated out. (Nothing to do, crossing it off the list)
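To illustrate why num_gradient_accumulation_steps resists being factored out (item 6): it changes the structure of the inner training loop itself, rather than hooking into it at a boundary. A toy, dependency-free version of that loop, with a plain float standing in for the accumulated gradient:

```python
# Toy gradient-accumulation loop: micro-batch losses are scaled and
# accumulated over N steps, and the optimizer steps once per N batches.
# A float stands in for real gradients to keep the sketch self-contained.

def train_epoch(losses, num_gradient_accumulation_steps):
    accumulated = 0.0
    steps_taken = 0
    for i, loss in enumerate(losses, start=1):
        # Scale each micro-batch loss so N accumulated micro-batches
        # approximate one large-batch gradient step.
        accumulated += loss / num_gradient_accumulation_steps
        if i % num_gradient_accumulation_steps == 0:
            steps_taken += 1     # optimizer.step() would go here
            accumulated = 0.0    # optimizer.zero_grad() would go here
    return steps_taken
```

Because the accumulation interleaves with the loss scaling and the optimizer step, there is no single before/after hook point where a callback could take it over.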

  7. Add two very simple callbacks, one that gets called at the end of each batch (getting passed the batch inputs and outputs, and the trainer object), and one that gets called at the end of each epoch (getting metrics and the trainer object). This allows for an incredible amount of flexibility, for logging, saving predictions, or whatever you want.
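The two callbacks in item 7 could look roughly like this. The signatures and the stub trainer below are illustrative sketches of the proposal, not the final API:

```python
from typing import Any, Callable, Dict, List

# Sketch of item 7: two plain callback hooks, invoked at the end of every
# batch and the end of every epoch. The Trainer here is a stub whose only
# job is to drive the hooks.

BatchCallback = Callable[["Trainer", List[dict], Dict[str, Any]], None]
EpochCallback = Callable[["Trainer", Dict[str, Any], int], None]

class Trainer:
    def __init__(self, batch_callbacks=None, epoch_callbacks=None):
        self._batch_callbacks = batch_callbacks or []
        self._epoch_callbacks = epoch_callbacks or []

    def train(self, num_epochs: int, batches_per_epoch: int) -> None:
        for epoch in range(num_epochs):
            for _ in range(batches_per_epoch):
                # Placeholders for the real batch inputs and model outputs.
                batch_inputs, batch_outputs = [{}], {"loss": 0.0}
                for callback in self._batch_callbacks:
                    callback(self, batch_inputs, batch_outputs)
            metrics = {"epoch": epoch}
            for callback in self._epoch_callbacks:
                callback(self, metrics, epoch)
```

Since each callback receives the trainer object itself, a callback can reach the model, optimizer, and metrics without the trainer needing a dedicated parameter for every use case.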

This issue was closed.