This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

WIP: Trainer changes for distributed training #3414

Merged
merged 23 commits into allenai:torch-distributed on Nov 21, 2019

Conversation

@scarecrow1123 (Contributor) commented Oct 30, 2019

Followup PR to #3390 and #3372 to bring in distributed training support. Following are the major changes done:

  • Workers are spawned using mp.spawn and each worker creates its own Trainer instance (a condensed sketch follows the config example below)
  • Trainer.__init__ wraps self.model with DistributedDataParallel
  • Logging and metric aggregation were already handled in the previous PRs
  • Vocabulary creation in the distributed case is done before spawning the workers and creating the Trainer instance

To run distributed training, the trainer config needs to have the following flag enabled:

```jsonnet
{
    "trainer": {
        "distributed": true,
        // ...
    }
}
```
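
For reference, here is that condensed sketch of the flow the bullets above describe; function and variable names are illustrative, not the exact code in this PR:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _train_worker(process_rank: int, world_size: int) -> None:
    # Each spawned worker joins the process group and builds its own Trainer,
    # which in turn wraps its model with DistributedDataParallel.
    dist.init_process_group(backend="nccl", rank=process_rank, world_size=world_size)
    torch.cuda.set_device(process_rank)
    # trainer = Trainer.from_params(...)  # every worker constructs its own Trainer
    # trainer.train()


if __name__ == "__main__":
    # MASTER_ADDR / MASTER_PORT are expected to be set before spawning.
    world_size = torch.cuda.device_count()
    mp.spawn(_train_worker, args=(world_size,), nprocs=world_size)
```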

TODO:

  • Try to reproduce comparable results and share extensive results for existing/selected models
  • Check if other commands like evaluate, predict, fine-tune work well with the new changes
  • Should all the callbacks be called from every worker in the case of callback-based training?
  • Should the current dataset readers be changed to support distributed training as well (to selectively yield data based on worker rank; a rough sketch of this follows the list)?
  • Write tests - would be happy to get some suggestions on how to write tests for this
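
On the dataset reader point above, a rough sketch of rank-based selective yielding (purely illustrative, not part of this PR):

```python
import torch.distributed as dist


def shard_lines(file_path: str):
    """Yield only the lines this worker is responsible for, based on its rank."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    with open(file_path) as data_file:
        for i, line in enumerate(data_file):
            if i % world_size == rank:
                yield line
```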

@brendan-ai2 (Contributor) left a comment

Hi @scarecrow1123, thanks for the PR! Here's some initial feedback. One quick thing, @DeNeutoy will be out through next week and I'll be out a few days in the beginning of the week as well. Sorry for the reviewing delay!

As for tests, let's start by just ensuring that existing functionality continues to work with multiple GPUs. We have a fairly simplistic test for the old method of doing multi-GPU training here:

`def test_trainer_can_run_multiple_gpu(self):`
Perhaps that can serve as a starting point.

Possible things to test:

  1. Are the metrics what we expect after a training run? (Do we need some kind of metric aggregation?)
  2. Do the different workers see their portion of the dataset? Of course, this isn't a test of the trainer proper, but more of an end-to-end test. Still seems worth having.
  3. Can we train, save and load? Essentially https://github.com/allenai/allennlp/blob/master/allennlp/common/testing/model_test_case.py#L45, but for the trainer rather than a model. Just pick a standard model and train it with ddp, making assertions similar to those in the linked code. Actually, you may be able to call that helper function directly... Worth verifying.
  4. Can we overfit a standard model on a trivial dataset when training with ddp?

Also, you'll need to use the decorator `@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need multiple GPUs.")` on each test that relies on multiple GPUs. We only run these tests on master check-ins given the paucity of GPUs on our integration server, so you'll need to be careful to run them locally.
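
For example, a guarded test skeleton (names purely illustrative, not actual allennlp test code) might look like:

```python
import pytest
import torch


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need multiple GPUs.")
def test_trainer_can_run_distributed():
    # Illustrative stub: build a small model and data loader, construct the
    # trainer with "distributed": true, run a short training loop, and assert
    # that training completes and metrics are reported on the master worker.
    ...
```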

Thanks again!


```python
logging.info("Switching to distributed training mode since multiple GPUs are configured")
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
```
Contributor:

Is there any reason why this can't/shouldn't be configurable?

Contributor Author:

The reason I did not make this configurable was the assumption that it would only apply to single-node multi-GPU training. These properties only have significance in a multi-node setup, and I did not have such a setup to test on until last week. Also, as discussed in #2536, @DeNeutoy suggested not to worry about multi-node training until everything works. But I was able to run distributed training in a multi-node setup with only a few more changes today, so this will be changed to be configurable.
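
(The change is roughly to read these from the trainer config with the old values as defaults; a sketch, not the exact final code:)

```python
import os

# Hypothetical trainer config section; in the real code this comes from `Params`.
trainer_config = {"master_address": "172.21.197.156", "master_port": 29500}

os.environ["MASTER_ADDR"] = trainer_config.get("master_address", "127.0.0.1")
os.environ["MASTER_PORT"] = str(trainer_config.get("master_port", 29500))
```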

```python
make_vocab_from_params(params.duplicate(), serialization_dir)
params["vocabulary"] = {
    "directory_path": os.path.join(serialization_dir, "vocabulary"),
    "extend": False,  # vocab extension would have been done above
```
Contributor:

Maybe s/above/in make_vocab_from_params/ for clarity?

f"Elmo object has num_output_representations={len(self._elmo._scalar_mixes)}, but this "
f"does not match the number of use_*_elmo flags set to true. use_input_elmo "
f"is {self._use_input_elmo}, and use_integrator_output_elmo "
f"is {self._use_integrator_output_elmo}"
Contributor:

Aside: Did black force this change? If not, while I appreciate cleanups, it's probably best to leave them for another PR just to keep the diff small.

Contributor Author:

Yes, this was enforced by black, hence the change to keep CI passing.

@@ -146,6 +148,18 @@ def __init__(
```python
        # For capturing errors that occur during the train loop.
        self.exception: Optional[Exception] = None

        # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its
        # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model`
        # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc.
```
Contributor:

Hmm, this rings some alarm bells. Have you and @DeNeutoy talked much about how metrics will work across processes? The regularization penalty also seems like it could be an issue...

Contributor Author:

> Have you and @DeNeutoy talked much about how metrics will work across processes?

Metrics aggregation across processes is already handled in training_util.get_metrics, which was done in #3372, so that part should be fine for now.
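
(The general idea behind that aggregation, not necessarily the exact code in #3372, is an all-reduce of each scalar metric across workers, roughly:)

```python
import torch
import torch.distributed as dist


def aggregate_metric(value: float) -> float:
    # Average a scalar metric across workers; a sketch of the general approach,
    # not the actual implementation in training_util.get_metrics.
    if not dist.is_initialized():
        return value
    tensor = torch.tensor(value)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor.item() / dist.get_world_size()
```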

> The regularization penalty also seems like it could be an issue...

I'm not sure if that would be an issue as the penalty is computed from the named_parameters of the model. Since the gradients are synced across workers before the optimization step, this might just work well as is. Please correct me if I'm wrong.

The reason I added that comment is primarily to highlight the inconsistencies that may arise if we wrap allennlp.models.Model with DistributedDataParallel and use the same object throughout the trainer. Wrapping is necessary, but it would break the Model-specific interface methods such as get_regularization_penalty, etc. Hence the wrapped reference gets a different name (self._pytorch_model) and is used only for PyTorch-specific ops such as the model forward pass, while all other usages go through the actual Model instance.
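
A simplified sketch of the pattern (not the exact Trainer code):

```python
import torch
from torch.nn.parallel import DistributedDataParallel


class TrainerSketch:
    """Illustrative only: shows the raw model vs. wrapped model split."""

    def __init__(self, model: torch.nn.Module, local_rank: int, distributed: bool):
        # Keep the raw AllenNLP `Model` for interface calls such as
        # `get_regularization_penalty()` and `get_metrics()` ...
        self.model = model
        # ... and use a separately named wrapped reference only for the
        # forward/backward passes, so DDP can synchronize gradients.
        if distributed:
            self._pytorch_model = DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank
            )
        else:
            self._pytorch_model = model
```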

Contributor:

I think this makes sense. Thanks for explaining!

```diff
-vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
+# Save the vocab only in the master
+if not dist.is_initialized() or dist.get_rank() == 0:
+    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
```
Contributor:

Do we even need to do this given that we've called make_vocab_from_params directly?

Contributor Author:

This still needs to be done in the non-distributed case, as vocab creation wouldn't have happened before this point. In the distributed case, it is done before spawning the workers, since an individual worker does not see the entire dataset.

Contributor:

Ah, sorry. I think I was confused as to the original purpose of "# Initializing the model can have side effect of expanding the vocabulary", but that makes sense now.

@scarecrow1123 (Contributor Author) commented Nov 12, 2019

@brendan-ai2 Thank you for the suggestions and apologies for the delay in my response. I'm yet to start working on the tests. As I mentioned in one of the replies above, I've also made a few changes to get distributed training working in a multi-node multi-GPU setup. Adding those commits as well. Below is the configuration for doing multi-node training:

"trainer": {
        "distributed": true,
        "num_nodes": 2,
        "master_address": "172.21.197.156",
        "master_port": 29500,
        // ...
}

And the command would have an extra argument:

On node 1:

```
allennlp train experiment.jsonnet -s output --node-rank 0
```

On node 2:

```
allennlp train experiment.jsonnet -s output --node-rank 1
```
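
The mechanics follow the standard torch.distributed recipe: each node spawns one worker per GPU, the global rank is derived from the node rank, and every worker joins the same process group. A rough sketch, not the exact code in these commits:

```python
import torch
import torch.distributed as dist


def init_worker(local_rank: int, node_rank: int, num_nodes: int, gpus_per_node: int) -> None:
    # Global rank and world size derived from the node rank and local GPU index.
    global_rank = node_rank * gpus_per_node + local_rank
    world_size = num_nodes * gpus_per_node
    # MASTER_ADDR / MASTER_PORT come from "master_address" / "master_port"
    # in the trainer config and must be set in the environment by this point.
    dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
```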

@DeNeutoy (Contributor) left a comment

Basically LGTM!

We need to add some integration-type tests for this at some point - it's going to be hard to unit test because of all the machines etc. required. So I think for now it looks good!

@@ -19,4 +21,8 @@ def run():


```python
if __name__ == "__main__":
    # First, let Pytorch's multiprocessing module know how to create child processes.
    # Refer https://docs.python.org/3.7/library/multiprocessing.html#multiprocessing.set_start_method
    torch.multiprocessing.set_start_method("spawn")
```
Contributor:

Can you add a note to the Trainer that setting this is required if you aren't using allennlp as the entry point?
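
(i.e. something along these lines for users driving training from their own script; a sketch only:)

```python
import torch.multiprocessing


if __name__ == "__main__":
    # Required when training is launched from your own entry point rather than
    # the `allennlp` CLI, so that child workers are created with "spawn".
    torch.multiprocessing.set_start_method("spawn")
    # ... then call your usual training code, e.g. train_model(...).
```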

@DeNeutoy (Contributor):

Oh, sorry, I missed @brendan-ai2's note about adding tests to this PR. This seems like a good idea; I forgot we already have tests for the multi-GPU stuff, which we run occasionally.

@brendan-ai2 (Contributor):

Yeah, if we can at least test the single-node stuff in this PR, I'll be happy :)

@brendan-ai2 (Contributor):

Hi @scarecrow1123, just wanted to check in on this PR. In case you're busy with other things, @DeNeutoy and I were discussing merging this into the branch as-is and then we could work on some of the tests ourselves. Definitely don't want to step on your toes, but I'd love to keep the momentum here going. Thanks again for all your work on this! :)

@scarecrow1123 (Contributor Author):

Hey @brendan-ai2, I had planned to bring this PR to a close over the last weekend but couldn't. Apologies for dragging this out a bit. (Basically it's the end of fall term in grad school + full-time work ;) ) I'll work on this the coming weekend, but if you think it would be better to merge it into the branch before that, please go ahead. Either option works, as I'll catch up with more tests myself anyway. Thanks for the nudge!

@brendan-ai2 (Contributor):

Oof, grad school + full time work is a lot! Thanks again for finding time to work on this. I'll merge now and see if I can't write a few tests this afternoon.

@brendan-ai2 brendan-ai2 merged commit 8d004f2 into allenai:torch-distributed Nov 21, 2019
@DeNeutoy DeNeutoy mentioned this pull request Dec 16, 2019
DeNeutoy added a commit that referenced this pull request Dec 17, 2019
* Logging and metrics changes for distributed training (#3372)

* Refactor logging setup to support distributed attrs

* `cleanup_logging()` is replaced with stdlib's `logging.shutdown()`
* Remove `TeeLogger` and use standard log handlers
* Remove `replace_cr_with_newline` and use the standard logging practice of using
`logging.Filter`
* Introduce `rank` and `world_size` optional attributes to support
distributed workers

* Support for distributed training in `get_metrics`

* Remove bad import

* Fix duplicate log messages in stdout

* Remove preemptive `logging.shutdown`

`logging.shutdown` is called by the logging module
by default during exit which makes it unnecessary to
be called from `train_model`

* Fix black formatting issues

* Remove `tee_logger` references in API doc

* Set log level from `ALLENNLP_DEBUG` env

* Changes to `train_model` for distributed training support (#3390)

* High level API changes to support distributed training

* Fix flake8 error

* Fix mypy error

* Add docstring and misc fixes

* Fix flake tests

* `Trainer` changes for distributed training (#3414)

Followup PR to #3390 and #3372 to bring in distributed training support. Following are the major changes done:

* Workers are spawned using `mp.spawn` and each worker creates its own `Trainer` instance
* `Trainer.__init__` wraps up `self.model` with `DistributedDataParallel`
*  Logging and metric aggregation are already done in the previous PRs
* `Vocabulary` creation in case of distributed training is done before spawning the workers and creating `Trainer` class

To run distributed training, the trainer needs to have the following flag to be enabled:

```jsonnet
{
    "trainer": {
        "distributed": true,
        // ...
    }
}
```

TODO:
* Try to reproduce comparable results and share extensive results for existing/selected models
* Check if other commands like `evaluate`, `predict`, `fine-tune` work well with the new changes
* Should all the callbacks be called from every worker in the case of callback-based training?
* Should the current dataset readers be changed to support distributed training as well (to selectively yield data based on worker rank)?
* Write tests - _would be happy to get some suggestions on how to write tests for this_

* Dist tests (#3515)

* add some tests

* another test, fix incorrect type annotations

* torch mp uses its own context, no need to set default

* lint

* strip out old DP stuff, ensure multiple cuda devices raises errors (#3516)

* strip out old DP stuff, ensure multiple cuda devices raises errors

* lint

* remove unused attribute

* remove _cuda_devices everywhere

* fixes

* move distributed config up to top level

* lint

* clean up

* rename occurrences of batch_group

* remove hack from find_learning_rate

* fix last tests

* black

* use a top level distributed config

* correct error for int

* change up parse_cuda_devices to raise good error and be strongly typed

* fix merge