
Using DistributedDataParallel for multi GPU training #2536

Closed
scarecrow1123 opened this issue Feb 22, 2019 · 23 comments

@scarecrow1123
Contributor

I'm looking into the codebase to possibly use torch's DistributedDataParallel for multi-GPU training. Based on the docs, it certainly would improve training speed compared with DataParallel. I'm only looking at the single-node, multi-GPU case for now. I'd like a heads-up on possible caveats I could face when integrating it into the current Trainer class; I'm mainly worried about goofing up any stateful parts of the code in ways that could cause leaks. It would be great if someone could point out such instances, if any.

@joelgrus
Contributor

we recently refactored the Trainer class to make it easier to implement different trainer variants.

so probably my recommendation would be to implement the simplest possible ParallelTrainer subclass, see how well it works, and then after that we could decide if/how to incorporate it into the library.

does that make sense?

@scarecrow1123
Contributor Author

Thanks for the heads-up, @joelgrus. I did a preliminary analysis of how something like a ParallelTrainer could be achieved. Here are a few observations and questions based on what I understood:

  • Essentially, a single worker function, say ParallelTrainer.train, would be spawned as separate processes
  • The optimizer runs independently in each process, hence the speedup compared to DataParallel
  • The main caveat here could be sampling subsets of mini-batches across the workers. In DataParallel the mini-batch is split and dispatched from a single place. Here, however, the processes themselves have to sample their data, so things like shuffling and the indices to sample must be deterministic across processes. The PyTorch way is to pass a DistributedSampler to the DataLoader, which uses the local rank of the GPU process (simply a process ID) to choose the indices of data points; see the sketch below. AllenNLP, as far as I can tell, has no such provision. I'm still thinking about a potential solution. Would it be good to introduce a Sampler notion in the current data iterators? Or have I missed anything?
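For reference, here is roughly how plain PyTorch does this with a DistributedSampler; a minimal sketch assuming a map-style dataset and an already initialized process group (the ToyDataset name is just for illustration):

import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

class ToyDataset(Dataset):
    """Toy map-style dataset of the integers 0..99."""
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return idx

def make_loader(batch_size):
    dataset = ToyDataset()
    # DistributedSampler picks a disjoint subset of indices for each process
    # based on its rank and the world size; call sampler.set_epoch(epoch)
    # every epoch so that shuffling stays deterministic across processes.
    sampler = DistributedSampler(dataset,
                                 num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank())
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)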

@scarecrow1123
Contributor Author

So I've been fiddling with the code and hacked around to make DistributedDataParallel (DDP) work with a standard AllenNLP setup. The interesting thing is that I'm seeing only small improvements from DDP compared with DataParallel (DP). The speedup compared with a single-GPU run is good, though. Here are some numbers I've got so far for a DeepSpeech model that I have implemented.

Dataset: Librispeech 100 hour
GPU: 2080Ti on Ubuntu
(Times reported are for one epoch)

Mini-batch size 32
  • Single GPU: 17m 40s
  • DDP, 4 GPUs: 06m 27s (32 × 4 mini-batch), 12m 14s (8 × 4)
  • DP, 4 GPUs: 06m 34s (32 × 4), 12m 32s (8 × 4)

Mini-batch size 16
  • Single GPU: 21m 40s
  • DDP, 4 GPUs: 07m 49s (16 × 4), 17m 03s (4 × 4)
  • DP, 4 GPUs: 08m 00s (16 × 4), 16m 50s (4 × 4)

Mini-batch size 32, 2 GPUs
  • DDP: 10m 15s (32 × 2), 12m 30s (16 × 2)
  • DP: 10m 47s (32 × 2), 13m 30s (16 × 2)

I'll post my WIP code over the weekend; I know I may only get useful insights once it's up. However, my main concern is that DP works so well here, while I've seen people complaining about getting no speedup from DP. One major difference in AllenNLP's DP that I've observed is that the loss calculation happens on the respective GPUs, unlike the standard PyTorch pattern of computing the criterion on a single GPU. Might this avoid some tensor reductions and explain the good performance even with DP?

Another piece of advice I'd like: what would be a good benchmark experiment to run so that the numbers are sound? This is my first time working with multiple GPUs, and I suspect I need to get the workload right to report valid numbers. Would BiDAF be a suitable load for gathering the numbers? Or perhaps a language model?

Would like to hear from folks @2200 too!
@joelgrus @matt-gardner

@mihail911

@scarecrow1123 Did you ever get this working completely? I'd be curious to hear whether you were able to get the speedups you had hoped for or any strategies you used for implementation.

@scarecrow1123
Contributor Author

scarecrow1123 commented Sep 28, 2019

@mihail911 I have a working version in my fork here. It tries to patch the existing trainer to handle the distributed and fp16 cases. The broad changes I've made to get it working are:

Logging

Modify the current logging setup so that the workers can also log to stdout.log and stderr.log. The workers now log to a multiprocessing.Queue object; the master allennlp process listens to this queue and logs to the files and stdout.

train command & Trainer
  • Split the train_model function in the existing train.py and introduce a train_worker function, which is the entry point for the workers. Each worker modifies the params to assign its corresponding GPU device, and a separate Trainer instance is created per worker (a rough sketch of this entry point follows this list).
  • Trainer receives the following extra parameters for obvious reasons:
    • rank: int
    • world_size: int
    • distributed: int (may not be necessary)
    • mixed_precision: bool
  • training_util.get_metrics is modified to aggregate metrics across the distributed workers
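As a rough sketch of what that entry point and the metric aggregation look like (the names train_worker and aggregate_metric are illustrative, not the exact API in the fork), assuming one process per GPU on a single node:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def train_worker(rank, world_size):
    # Each spawned process binds to one GPU and joins the process group.
    dist.init_process_group(backend="nccl",
                            init_method="tcp://127.0.0.1:29500",
                            world_size=world_size,
                            rank=rank)
    torch.cuda.set_device(rank)
    # ... build the model, wrap it in DistributedDataParallel, construct a
    # Trainer with rank/world_size, and run the usual training loop here.

def aggregate_metric(value):
    # The kind of reduction get_metrics needs: sum a scalar across workers,
    # then average it.
    tensor = torch.tensor(float(value), device="cuda")
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor.item() / dist.get_world_size()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # mp.spawn starts world_size processes and passes the process index as
    # the first argument of train_worker.
    mp.spawn(train_worker, args=(world_size,), nprocs=world_size)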
fp16

Apex amp is used to initialize the model for fp16. Note: the masked_log_softmax function may have to be converted to a model layer to make it work with amp.
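For context, the standard apex amp workflow being referred to looks roughly like this (a toy model stands in for the real one; this is apex's documented usage pattern, not the fork's exact code):

import torch
from apex import amp  # requires NVIDIA apex

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# opt_level "O1" runs whitelisted ops in fp16 while keeping fp32 master weights.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

inputs = torch.randn(4, 10).cuda()
loss = model(inputs).sum()

# amp scales the loss so that fp16 gradients do not underflow.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()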

Dataset reader

When doing distributed training, the worker processes have to selectively choose the data, i.e., if the dataset is an array of [0, 1, 2, 3] and there are 4 workers, every worker only feeds one data item to the iterator. With the current reader, I'm not entirely sure how this can be achieved. For illustration, I've made a distributed version of the snli reader here and used the process rank/world_size to selectively read the input data.

My suggestion for solving the dataset-related issue would be to start using the native PyTorch Dataset and DataLoader classes for reading data, and to retain the current AllenNLP Iterator to do batching in place of PyTorch's collate function. This is primarily because Dataset enforces a proper definition of a dataset: reading/pulling data out of it is decoupled from the dataset definition and handed to the DataLoader. That decoupling also makes it easy to sample the data selectively (see the sketch below).
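A minimal sketch of that decoupling, with a stand-in collate function playing the role of the AllenNLP iterator's batching (InstanceDataset and batch_instances are hypothetical names, not part of AllenNLP):

from torch.utils.data import DataLoader, Dataset

class InstanceDataset(Dataset):
    """Only defines what the data is; sampling and batching live outside."""
    def __init__(self, instances):
        self.instances = instances

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        return self.instances[idx]

def batch_instances(instances):
    # Stand-in for the AllenNLP iterator's batching/tensorization step.
    return instances

dataset = InstanceDataset([{"id": i} for i in range(100)])
# In a distributed run a DistributedSampler would be passed via sampler=...;
# here we just shuffle locally.
loader = DataLoader(dataset, batch_size=8, shuffle=True,
                    collate_fn=batch_instances)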

Another issue with my current setup: if you're running training for the first time without the dataset/embeddings on your machine, each worker in the distributed setup will download its own copy of the data. I haven't handled this case, so if you're trying it out, download the resources separately before starting the training.

P.S.: For distributed training the data has to be evenly distributed across the GPUs; in other words, each worker must process the same number of batches every epoch.

P.P.S.: This currently works only for a single-node, multi-GPU setup. Making it work for the multi-node, multi-GPU case would involve a few minor changes, I reckon, especially around configuring the coordinator IP and port.

P.P.P.S: The current version may break at times as I'm still trying to smooth off the rough edges.

@joelgrus @matt-gardner It would be super useful if you guys could give some feedback on this to make this a PR candidate.

@matt-gardner
Contributor

@brendan-ai2, @DeNeutoy, I think you two have more context on this one than @joelgrus or me. What do you think?

@DeNeutoy
Contributor

@scarecrow1123 I took a look at this - looks good. Here are my top level comments:

  1. We do not need to worry about the dataset readers reading multiple copies of the data. People who care about this can write new dataset readers in exactly the way you have for the snli reader.

  2. Your idea of using multiprocessing.spawn inside train.py is smart - this is a good design.

  3. Multi-node, multi-gpu should be a non-objective for now until we have the single node version working.

  4. Again, the dataset caching is annoying but is definitely not a blocker to merging into allennlp.

  5. I'm not sure your point (first P.S.) is correct - in many cases it's fine to just read streams either randomly, or as fast as each GPU can process them.

One thing that I didn't quite follow in your diff was the logging. Could you give some examples of what that looks like? It might be nice to be able to stream the logs from different workers to different files, or something like that.

I'm in two minds about whether this needs to wait until we switch to the torch Dataset and DataLoader model. Do you have some opinion about that?

The try/except blocks in your PR scare me a little - were these just for debugging, and have you used this to successfully train a model with DDP + half precision?

After we've discussed this a little more, I might suggest the following PRs:

PR 1: Distributed logging and metric aggregation

PR 2: Changes to trainer and train.py to support multiple worker processes

PR 3: additional changes to support AMP

@brendan-ai2, would be good to get your review of @scarecrow1123's diff too!

Just as a note, I am on holiday from 31st Sept - 6th Oct, so I might be a bit slow replying here. However, @scarecrow1123, your work is very much welcome and I'd love to find a way to merge it into allennlp.

DeNeutoy added this to the 1.0.0 milestone Sep 30, 2019
@brendan-ai2
Contributor

@scarecrow1123, thanks for working on this! The diff will take me a bit of time to go over carefully, but I'll try to have some high level feedback tomorrow.

@scarecrow1123
Contributor Author

Thanks for the detailed feedback @DeNeutoy .

I'm not sure your point(first P.S) is correct - in many cases it's fine to just read streams either randomly, or as fast as each gpu can process them.

I think I was not clear when I mentioned the number of batches. From my understanding, the model forward/backward are synchronization points in DistributedDataParallel. From the docs:

Constructor, forward method, and differentiation of the output (or a function of the output of this module) is a distributed synchronization point. Take that into account in case different processes might be executing different code.

So let's take the case of a 2-GPU setup and say worker 0 has 5 batches and worker 1 has 4 batches in total. Once its 4 batches are computed, worker 1 exits the iteration, whereas worker 0 will call model.forward for its fifth batch. Since this method is a synchronization point, worker 0 will wait forever for worker 1 to join. This cannot be allowed to happen, and hence the number of batches the workers compute needs to be the same. This is a fundamental constraint, and in the example esim experiment in the fork I've configured the total number of samples to be an even number to avoid this problem.

Another option to make this work would be to duplicate a few training examples to even out the number of batches. This is what PyTorch's DistributedSampler does; see the sketch below.
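For reference, a simplified version of the index-padding trick that DistributedSampler uses internally (repeat the first few indices so every worker ends up with the same number of examples):

import math

def shard_indices(num_examples, rank, world_size):
    indices = list(range(num_examples))
    # Pad by repeating the first few indices so that every worker gets the
    # same number of examples, and therefore the same number of batches.
    num_per_worker = math.ceil(num_examples / world_size)
    total = num_per_worker * world_size
    indices += indices[: total - num_examples]
    # Each worker takes a strided slice of the padded index list.
    return indices[rank:total:world_size]

# 9 examples across 2 workers -> each worker gets 5 indices.
print(shard_indices(9, rank=0, world_size=2))  # [0, 2, 4, 6, 8]
print(shard_indices(9, rank=1, world_size=2))  # [1, 3, 5, 7, 0]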

Could you give some examples of what that looks like? It might be nice to be able to stream the logs from different workers to different files, or something like that.

Workers logging to their own respective log files would be pretty straightforward with AllenNLP's existing logging setup. However, it could be more useful if they logged to the same stdout/stderr logs. Here's a gist with an example program to test the logging setup present in the fork. I've written a mini tutorial of sorts to explain this case here.

Basically, if the workers need to log to the same file:

  • there can't be multiple tqdm bars, so only the progress bar of the master process (the worker with rank=0) is shown, and of course the metrics are aggregated before tqdm is updated
  • the workers can't use Python's FileHandler directly to log to the same file, as this would corrupt the log file; instead, the workers log their messages only to a multiprocessing queue (refer to the prepare_worker_logging method)
  • the master allennlp process sets this queue up initially
  • the workers' log messages are always pushed to the queue object; the master allennlp process keeps listening to this queue and passes the messages on to the respective file handlers

Refer to the original docs for QueueHandler, QueueListener and a list of useful handlers.
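A standalone sketch of that queue-based pattern (the standard QueueHandler/QueueListener recipe from the logging docs, not the fork's exact prepare_worker_logging):

import logging
import logging.handlers
import multiprocessing as mp

def worker(rank, queue):
    # Workers attach only a QueueHandler, never a FileHandler, so they
    # cannot corrupt the shared log file.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.handlers.QueueHandler(queue))
    logger.info("hello from worker %d", rank)

if __name__ == "__main__":
    queue = mp.Queue()
    # The master process owns the real handlers and drains the queue.
    listener = logging.handlers.QueueListener(
        queue, logging.FileHandler("stdout.log"))
    listener.start()

    processes = [mp.Process(target=worker, args=(rank, queue)) for rank in range(4)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()

    listener.stop()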

I'm in two minds about whether this needs to wait until we switch to the torch Dataset and DataLoader model. Do you have some opinion about that?

My only qualm is that the current DatasetReader abstraction would be completely unaware of the distributed nature of the training. IMO, using the rank to skip iterations inside the read method wouldn't be very clean. It'd be great if there were a graceful alternative, but I'm not too sure.

The try except blocks in your PR scare me a little - were these just for debugging and have you used this to successfully train a model in DDP + half-precision?

I'll try to run existing AllenNLP experiments and share the results with you in a few days.

Your PR suggestions make sense. I'll wait for your further comments and Brendan's comments too.

@scarecrow1123
Contributor Author

@brendan-ai2 @DeNeutoy Were you able to take a closer look at the code? I'm keen to start making some PRs for this. I've also started doing some comparisons between the current implementation and the one in the fork. You can find the numbers and the corresponding experiments in this repo. The numbers are in no way complete, and I'll be adding more of them along with accuracy stats.

@DeNeutoy
Contributor

DeNeutoy commented Oct 15, 2019

@scarecrow1123 Looking good! I'll do another review of this tomorrow before we go ahead. In the meantime, would you mind getting the single GPU number filled out for Esim in your repo, so we have that baseline?

Sorry for the delay on this! I'm back from holiday now so I should be more responsive.

@scarecrow1123
Contributor Author

would you mind getting the single GPU number filled out for Esim in your repo

@DeNeutoy I've added single GPU numbers and also comparison for Bidaf experiment.

@DeNeutoy
Contributor

DeNeutoy commented Oct 15, 2019

@scarecrow1123 sweet! Looks like a good speed up. Just a sanity check - what numbers are you getting for Acc/F1 for those runs? I'm assuming they are similar to the single gpu numbers. If so, we can start the PRs I think, starting with the logging, then the changes to the trainer/dataset readers.

In particular, we should leave out the amp stuff for now, because 1) it's a small code change and should be easy to add in later, and 2) for various reasons it makes packaging a lot harder for us, as there are C extensions which it has to build which require GPUs. So for now let's focus on the DDP stuff. (just to check, your runs don't include using fp16/amp, right?)

@brendan-ai2
Contributor

@scarecrow1123, thanks again for this! Could you upload the model output folder for those runs, so we can take a peek at the models and vocabs produced?

Also, just to confirm, the numbers from https://github.com/scarecrow1123/allennlp-distributed-training/blob/master/README.md under "4x Data Parallel" are produced using AllenNLP's current (limited) multi-GPU support, correct? (That is, the support available simply by passing a list of device ids for cuda_device in the trainer config?) I've recently made some fixes to the multiprocess dataset reader and iterator that may be relevant here. (See https://github.com/allenai/allennlp/blob/master/training_config/bidirectional_language_model.jsonnet for an example of how to use those.) I strongly suspect your code will be faster anyway, but it would be great to additionally try using that reader and iterator to see how the performance changes, ideally for both the "4x Distributed" and "4x Data Parallel" variants.

@scarecrow1123
Contributor Author

I'm assuming they are similar to the single gpu numbers.

@DeNeutoy Right now I'm seeing a bit of overfitting happening in the 4x experiments, presumably because of 4x larger batches. I'll have to do single GPU experiments with 4x larger batches to get a proper comparison. I'll try to get the accuracy stats for all these combinations and get back to you along with the model outputs as @brendan-ai2 has asked.

your runs don't include using fp16/amp, right?

Yes my runs do not include amp. I omitted it mainly because of masked_log_softmax which I'll happily ignore for now ;)

"4x Data Parallel" are produced using AllenNLP's current (limited) multi-GPU support, correct?

Yes that's right. Those numbers are using upstream AllenNLP HEAD. Only the distributed version is run from the fork.

but it would be great to additionally try using that reader and iterator to see how the performance changes

Fair point. Let me try running some experiments with multiprocess reader and get back.

@DeNeutoy
Contributor

@scarecrow1123 I did some benchmarking with the new torch dataset loaders vs our current multiprocess implementation, and it doesn't look like the multiprocessing in allennlp will provide a speedup for your code. So don't worry about running those experiments using the MultiProcessDatasetReader / MultiProcessIterator.

@brendan-ai2 and I will set up an upstream branch that we will collect these big changes into before we merge them into master. This branch will also include the changes I'm making to support the IterableDataset and DataLoader from torch (re: #3079).

@brendan-ai2
Contributor

Thanks for running those benchmarks, @DeNeutoy! In case anyone is curious about the details (and for my own memory), what Mark showed with master...DeNeutoy:benchmark was that the time spent indexing and tensorizing with the SnliReader was negligible. This indicates that the speedup shown in https://github.com/scarecrow1123/allennlp-distributed-training/blob/master/README.md for esim.jsonnet between "4x Data Parallel" and "4x Distributed" should be essentially entirely attributable to the training proper.

@DeNeutoy
Contributor

Full details of that benchmark available in #3079 💯

@brendan-ai2
Contributor

@scarecrow1123, please open your PRs against https://github.com/allenai/allennlp/tree/torch-distributed when ready. :) It's even with master right now.

@scarecrow1123
Contributor Author

@brendan-ai2 @DeNeutoy Thanks for the clarification (one item cleared from my todo! ;)).
Here's the entire serialization directory of a sample esim experiment, run for just 2 epochs. This should help you verify the vocab and model files produced.

I'm still seeing an accuracy dip with DistributedDataParallel, which to my understanding could be due to a learning-rate vs. batch-size mismatch. But I'm not making assumptions here, and am trying to figure out whether there are any mistakes in my code. Once I've done a self-review, I'll start with the first PR. Thanks again!

@scarecrow1123
Contributor Author

scarecrow1123 commented Oct 17, 2019

The dip in the accuracy/loss metrics turned out to be entirely due to inconsistent vocabulary indices across workers. To explain in detail: I had modified the DatasetReader._read method to selectively filter instances across GPUs based on rank and world_size attributes, as below:

# distributed snli reader
def _read(self):
    ...
    for idx, line in enumerate(snli_file):
        # Each worker keeps only the lines assigned to its rank and skips the rest.
        if idx % world_size != rank:
            continue
        yield instance

This works well during the actual training. However, this selective filtering is also applied during vocabulary creation, before the training actually starts, and hence the workers end up dealing with different vocabularies. I'm not sure how to get around this issue: the DatasetReader, Trainer, and vocabulary-creation parts are pretty much disconnected, and from what I can see there is no way to tell the DatasetReader when (not) to use selective filtering.

To test the trainer code, I just created the vocabulary beforehand using the allennlp make-vocab command with the selective sampling manually switched off. During training, sampling as shown in the snippet above is applied, and everything works just fine with comparable accuracy scores. Any suggestions, @DeNeutoy, on anything I may have completely overlooked?

@scarecrow1123
Contributor Author

We may have to move vocabulary creation out of TrainerPieces.from_params to alleviate the above issue. Once #3372 is reviewed, I'll do a PR to the same branch to sort out the vocab-related changes and then perhaps proceed with the trainer changes.

@DeNeutoy
Contributor

This has landed in #3529. Huge thanks to @scarecrow1123!!
