How to get ddp working for custom collate functions #3429
edraizen asked this question in DDP / multi-GPU / multi-node (Unanswered)
Replies: 1 comment
I was able to get it to work by correcting the batch size and using the normal scatter function. It turns out I had manually multiplied the batch size by the number of GPUs, which works for dp but not for ddp. Would it be possible for PyTorch Lightning to scale the batch size by the number of GPUs when using dp, or to warn the user that they should increase it? It would also be useful if there was a …
Thanks,
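For anyone hitting the same issue, here is a minimal sketch of the difference (my own illustration, not code from this thread; the dataset is a stand-in for the real sparse one):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

NUM_GPUS = 8
PER_GPU_BATCH = 4

# Stand-in dataset; the real one yields sparse (coords, feats, label) samples.
dataset = TensorDataset(torch.randn(256, 3))

# dp: a single process builds one large batch and Lightning scatters it
# across GPUs, so the DataLoader batch size is the *global* batch size.
dp_loader = DataLoader(dataset, batch_size=PER_GPU_BATCH * NUM_GPUS)

# ddp: every GPU runs its own process with its own DataLoader, so the
# DataLoader batch size is the *per-GPU* batch size; the effective global
# batch size is already PER_GPU_BATCH * NUM_GPUS without manual scaling.
ddp_loader = DataLoader(dataset, batch_size=PER_GPU_BATCH)
```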
❓ Questions and Help
Hello,
I am confused about ddp training, batch sizes, and custom collate functions. I am using MinkowskiEngine for sparse 3D image segmentation, which requires batches to consist of lists of minibatches (lists of samples) because the data cannot be split randomly or by a set length. However, PyTorch Lightning scatters the data prior to sending it to training_step, incorrectly scattering individual samples across different GPUs. For dp training, this can be corrected with a custom collate function that makes a minibatch for each GPU, where each minibatch has the length of the batch size. ddp is slightly more confusing. I am first trying to get it to work using one node with 8 GPUs, with a batch size of 4.
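For context, a rough sketch of the dp-style collate I mean (my own illustration, not the exact code: it assumes each dataset sample is a (coords, feats, label) tuple, uses MinkowskiEngine's ME.utils.sparse_collate, and hard-codes NUM_GPUS):

```python
import MinkowskiEngine as ME

NUM_GPUS = 8  # hard-coded for this sketch

def per_gpu_collate(samples):
    """Group a dp batch into one minibatch per GPU so that Lightning's
    scatter sends whole minibatches, not individual samples, to each device."""
    chunks = [samples[i::NUM_GPUS] for i in range(NUM_GPUS)]
    minibatches = []
    for chunk in chunks:
        if not chunk:
            continue
        coords, feats, labels = zip(*chunk)
        # sparse_collate batches variable-length sparse samples by prepending
        # a batch index to each set of coordinates.
        minibatches.append(ME.utils.sparse_collate(coords, feats, labels))
    return minibatches
```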
What is your question?
When and where does the DataLoader (and LightningDataModule) get initialized and create the data? Does the DataLoader get initialized for each ddp process? If so, is it initialized with the same num_workers each time? When I printed it out, the DataLoader was initialized 16 times (or 32; PL seems to print everything twice), all with the same number of GPUs, but distributed_backend became None for half of them. If distributed_backend is None, should I set num_workers to 0?
If I split the data the same way I do for dp, each ddp process (8 for 8 GPUs) calls the scatter function with a list of 8 minibatches, where each minibatch has batch-size samples (=> 32 samples over 8 minibatches), but only one device in device_ids. Since each ddp process is separate and only has access to one GPU, it makes sense that there is only one device, but is there a way to make sure the scatter function only gets one minibatch? Or, if there is another way to think about this, I'd really appreciate it!
Thanks for your help!
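(Not from the original post: a small self-contained sketch for checking how often the DataLoader is built. Under ddp each spawned process should call train_dataloader once and print its own LOCAL_RANK, which Lightning sets in the child processes' environment; the toy dataset is a placeholder.)

```python
import os
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class ToyDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        # Runs in every ddp process once it has been spawned/relaunched.
        self.train_set = TensorDataset(torch.randn(32, 3))

    def train_dataloader(self):
        # Under ddp this should be called once per GPU process.
        print("rank", os.environ.get("LOCAL_RANK", "0"), "is building its DataLoader")
        return DataLoader(self.train_set, batch_size=4, num_workers=4)
```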
Code
What have you tried?
What's your environment?