
Remove scattering for multi-GPU training. #2200

Merged: 87 commits merged into allenai:master on Jan 18, 2019

4 participants
@brendan-ai2 (Member) commented Dec 18, 2018

  • Instead, just pull off a batch for each GPU (sketched below).
  • Enables increasing the effective batch size for bidirectional_language_model.jsonnet by 2x giving a 1.5x speedup.
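
Roughly, the idea is: instead of building one large batch and scattering it across devices, group consecutive batches from the iterator and hand each GPU a whole batch. A minimal sketch using AllenNLP's lazy_groups_of helper; the function name train_one_epoch and its arguments are illustrative, not the exact trainer code in this PR:

    from typing import Any, Callable, Dict, Iterator, List

    import torch
    from allennlp.common.util import lazy_groups_of

    def train_one_epoch(batch_iterator: Iterator[Dict[str, Any]],
                        cuda_devices: List[int],
                        batch_loss: Callable[[List[Dict[str, Any]]], torch.Tensor]) -> None:
        # Previously a single batch was built and then scattered across GPUs.
        # Now consecutive batches are simply grouped, one per GPU.
        num_gpus = len(cuda_devices)
        for batch_group in lazy_groups_of(batch_iterator, num_gpus):
            # batch_group is a list of up to num_gpus batches; downstream, each
            # batch is moved to its own device and run on a model replica.
            loss = batch_loss(batch_group)
            loss.backward()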

brendan-ai2 added some commits Nov 21, 2018

Transformer ELMo
- Configuration for training a transformer based bidirectional LM.
  - Training ongoing with sampled loss currently at 3.8411.
- Minor fixes to CnnHighwayEncoder.
  - LayerNorm was needed instead of MaskedLayerNorm.
- Log average batch size during training.
Merge branch 'lm_without_dataset_modifications_3' of github.com:brendan-ai2/allennlp into lm_without_dataset_modifications_3

brendan-ai2 added some commits Jan 17, 2019

Fix

@brendan-ai2 requested a review from matt-gardner, Jan 17, 2019

@@ -34,7 +34,7 @@ local BASE_ITERATOR = {
   // samples in every batch.
   "batch_size": 512 * NUM_GPUS,
   "sorting_keys": [["source", "num_tokens"]],
-  "maximum_samples_per_batch": ["num_tokens", NUM_GPUS * 1000]
+  "maximum_samples_per_batch": ["num_tokens", 2000]

@brendan-ai2 (Member, Author) commented Jan 17, 2019

There's a minor backwards compatibility issue here. We're effectively multiplying the batch size (for multi-GPU users) by the number of GPUs. In practice this will result in some OOMs for users who were running close to their memory limits. Given that we had an experimental warning for that use case I think this is okay, but I'm curious if you have other thoughts.
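
To make the arithmetic concrete (illustrative numbers only, not taken from any particular config):

    # Old behavior: one batch of `batch_size` instances was scattered across GPUs.
    # New behavior: each GPU receives a full `batch_size` batch.
    batch_size = 64
    num_gpus = 4

    per_gpu_before = batch_size // num_gpus  # 16 instances per GPU
    per_gpu_after = batch_size               # 64 instances per GPU, i.e. num_gpus times the memory
    print(per_gpu_before, per_gpu_after)     # 16 64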

@matt-gardner (Member) commented Jan 17, 2019

This seems fine to me, too.

@brendan-ai2 (Member, Author) commented Jan 18, 2019

Thanks.

@brendan-ai2 (Member, Author) commented Jan 17, 2019

fyi, @vidurj you should be able to merge this down if you need it ASAP.

@sai-prasanna (Contributor) commented Jan 17, 2019

@brendan-ai2 Do you get better multi-GPU utilization with this method?
For a sequence-to-sequence model, neither the current implementation nor this one gets full utilization for me. I use fairseq models imported into allennlp.

Using fairseq directly, however, makes training faster as expected with bigger batch sizes. Fairseq uses DistributedDataParallel with multiprocessing. I don't know what the bottlenecks are in the DataParallel approach we currently use. I suspect the GIL: even though the operations run on CUDA, the instructions are issued from Python, which can make the GIL a bottleneck for models like fairseq's CNN that have a lot of Python code.

Even the PyTorch docs state this might be the case: https://pytorch.org/docs/stable/distributed.html

In the single-machine synchronous case, torch.distributed or the torch.nn.parallel.DistributedDataParallel() wrapper may still have advantages over other approaches to data-parallelism, including torch.nn.DataParallel():

  • Each process maintains its own optimizer and performs a complete optimization step with each iteration. While this may appear redundant, since the gradients have already been gathered together and averaged across processes and are thus the same for every process, this means that no parameter broadcast step is needed, reducing time spent transferring tensors between nodes.
  • Each process contains an independent Python interpreter, eliminating the extra interpreter overhead and “GIL-thrashing” that comes from driving several execution threads, model replicas, or GPUs from a single Python process. This is especially important for models that make heavy use of the Python runtime, including models with recurrent layers or many small components.
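
For reference, a minimal sketch of the multi-process DistributedDataParallel pattern the docs are describing (plain PyTorch, not AllenNLP-specific; the model, data, and address here are placeholders):

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.nn.parallel import DistributedDataParallel as DDP

    def worker(rank: int, world_size: int) -> None:
        # One process (and one Python interpreter) per GPU, so there is no GIL
        # contention between model replicas.
        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:29500",
                                rank=rank, world_size=world_size)
        model = torch.nn.Linear(128, 128).to(rank)    # placeholder model
        model = DDP(model, device_ids=[rank])
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        data = torch.randn(32, 128, device=rank)      # placeholder batch
        loss = model(data).sum()
        loss.backward()       # gradients are all-reduced across processes here
        optimizer.step()      # every process takes an identical step
        dist.destroy_process_group()

    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        mp.spawn(worker, args=(world_size,), nprocs=world_size)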
@matt-gardner (Member) left a comment

I had one minor question, but other than that the code looks fine to me, and better than what was here previously. There's still the issue that @sai-prasanna brings up, but figuring out how to make this faster should be a separate issue (and one that I have no experience with).

Oh, we had a custom scatter_kwargs function, right? That should be deleted now, shouldn't it? It was broken anyway - the functionality that was the only reason we made it custom instead of using pytorch's version didn't actually work.


used_device_ids = cuda_devices[:len(inputs)]
inputs = [()] * len(batch_group)

@matt-gardner (Member) commented Jan 17, 2019

inputs is supposed to be a list of empty tuples? This never gets updated before getting passed to parallel_apply.

@brendan-ai2 (Member, Author) commented Jan 18, 2019

Added a comment to clarify. You can see that () was passed to the old scatter_kwargs as well.

# We pass all our arguments as kwargs. Create a list of empty tuples of the
# correct shape to serve as (non-existent) positional arguments.
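
Putting the pieces together, the per-GPU forward pass looks roughly like the following. This is a simplified sketch of the data-parallel path, not a verbatim copy of the code in this PR; it assumes the model already lives on the first listed device:

    from typing import Any, Dict, List

    import torch
    from torch.nn.parallel import gather, parallel_apply, replicate
    from allennlp.nn import util as nn_util

    def data_parallel(batch_group: List[Dict[str, Any]],
                      model: torch.nn.Module,
                      cuda_devices: List[int]) -> Dict[str, torch.Tensor]:
        # One whole batch per GPU: move each batch's tensors to its own device.
        moved = [nn_util.move_to_device(batch, device)
                 for batch, device in zip(batch_group, cuda_devices)]
        used_device_ids = cuda_devices[:len(moved)]

        # We pass all our arguments as kwargs. Create a list of empty tuples of the
        # correct shape to serve as (non-existent) positional arguments.
        inputs = [()] * len(batch_group)

        replicas = replicate(model, used_device_ids)
        outputs = parallel_apply(replicas, inputs, moved, used_device_ids)

        # Gather the per-GPU losses onto the first device and average them.
        losses = gather([output["loss"].unsqueeze(0) for output in outputs],
                        used_device_ids[0], 0)
        return {"loss": losses.mean()}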

@matt-peters (Contributor) commented Jan 17, 2019

@sai-prasanna good ideas, thanks for the input. Can you provide some more details about how you integrate allennlp with fairseq? How do you train the model -- with fairseq or the allennlp Trainer?

@brendan-ai2 (Member, Author) commented Jan 18, 2019

@sai-prasanna, the main benefit of this PR is that it allows one to have larger batches (and thus train faster). For reasons I don't entirely understand, our scatter_kwargs implementation seemed to result in decidedly unbalanced GPU memory usage. Utilization seemed marginally better with this change, but I didn't look at that closely. In general I don't think we can promise full utilization, so we'll need to look at things in more detail if that's a major issue for you.

Your points about using torch.distributed are well taken! :) We'd definitely like to investigate that more, but that's out of scope for this PR. Would you be willing to open an issue with your insights and/or requests?

brendan-ai2 added some commits Jan 18, 2019

@brendan-ai2 (Member, Author) commented Jan 18, 2019

@matt-gardner , thanks for the review! I can delete scatter_kwargs, but is it worth deprecating first? It's not clear to me if it was obviously strictly internal.

@matt-gardner (Member) commented Jan 18, 2019

Re: deleting the method, it was part of experimental behavior. We added the method (instead of using pytorch's version) for one purpose (to handle complex data structures) for which it didn't actually work, as evidenced by my failing test (I guess it worked for Metadata, but not more complex stuff?). So if anyone was using it externally (which seems very unlikely), it was broken for them too. I'd say to just remove it. And with it, we can probably also remove the ScatterableList.

@sai-prasanna (Contributor) commented Jan 18, 2019

@matt-gardner Yeah, it should be a separate issue. I thought this commit had an effect on performance.

@matt-peters I am using an allennlp model and trainer (https://gist.github.com/sai-prasanna/9b02b282894a3b01647c8704dc28b013), which has poor performance. I tested it on two different multi-GPU machines, each with 3 1080 Tis.

Our team separately compared against fairseq's default trainer. Fairseq uses its own training flow in which the dataset -> token-index preprocessing happens first; it then uses those indices to form tensors directly for training. I controlled for that affecting the speed by using the multiprocess data iterator in allennlp.

The performance difference is stark. Fairseq gets better performance (1.4-1.5x) on two GPUs, but the allennlp trainer is slower than single-GPU training.

We will try to make the allennlp trainer use DistributedDataParallel right away if the changes are simple, and test out the performance.

@matt-gardner (Member) commented Jan 18, 2019

@sai-prasanna, if you can figure out ways to make our multi-GPU code work faster, the help would be greatly appreciated. We're a very small team, and we have a lot of other things to focus on. Any help diagnosing particular issues or giving recommendations (or PRs!) on how to make things faster would be great.

@brendan-ai2 (Member, Author) commented Jan 18, 2019

@sai-prasanna, an important point of clarification: this PR does have an effect on runtime performance. For the language modeling task described in training_config/bidirectional_language_model.jsonnet it effectively increased the batch size by 2x. This led to epochs taking only 67% of the time they did previously, i.e. a 1.5x speedup.

Of course, I can't guarantee every model will see such improvements, but it might be worth double checking your batch size (or maximum_samples_per_batch if you're using that) to see if it can be increased.

Thanks again for the feedback!

@brendan-ai2 (Member, Author) commented Jan 18, 2019

Thanks for the review, @matt-gardner! (ScatterableList and friends deleted as requested.)

@brendan-ai2 merged commit 7525c61 into allenai:master on Jan 18, 2019

2 of 3 checks passed:

  • codecov/patch: 83% of diff hit (target 90%)
  • Pull Requests (AllenNLP Library): TeamCity build finished
  • codecov/project: 92% (+<1%) compared to d0a5a40