Data parallel error with O2 and not O1 #227

Closed
aclyde11 opened this issue Mar 28, 2019 · 32 comments

@aclyde11

When using O2, DataParallel does not work:
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

However, with O1, everything works fine.

model = GeneralVae(encoder, decoder, rep_size=500).cuda()
optimizer = optim.Adam(model.parameters(), lr=LR)
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
if data_para and torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
    model = model.cuda()

loss_picture = customLoss()

val_losses = []
train_losses = []

def train(epoch):
    train_loader_food = generate_data_loader(train_root, get_batch_size(epoch), int(rampDataSize * data_size))
    print("Epoch {}: batch_size {}".format(epoch, get_batch_size(epoch)))
    model.train()
    train_loss = 0
    loss = None
    for batch_idx, (data, _, aff) in enumerate(train_loader_food):
        data = data[0].cuda(0)
@tullie

tullie commented Mar 29, 2019

I'm running into the same error for O0, O2, O3:
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

O1 is working as expected.
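One plausible explanation for O1 being the odd one out, assuming the apex source matches the `_initialize.py` snippet quoted later in this thread: amp only patches `model.forward` when the opt level sets a `cast_model_type`. As far as I can tell, O0 sets it to float32 and O2/O3 set it to float16, while O1 leaves it as `None` and patches torch functions instead. The dictionary below is a hand-copied assumption of those values, not something imported from apex.

```python
# Assumed cast_model_type per opt level (hand-copied from my reading of apex
# at the time of this issue; NOT imported from apex).
CAST_MODEL_TYPE = {
    "O0": "float32",
    "O1": None,       # O1 patches torch functions instead of casting the model
    "O2": "float16",
    "O3": "float16",
}


def patches_forward(opt_level: str) -> bool:
    """Mirror the `if properties.cast_model_type:` check that guards the
    forward patching in apex/amp/_initialize.py (assumed)."""
    return bool(CAST_MODEL_TYPE[opt_level])


for level in sorted(CAST_MODEL_TYPE):
    print(level, "patches forward:", patches_forward(level))
```

Under this assumption, exactly the levels that patch `forward` are the ones reported to fail with DataParallel.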

@jshanna100

Same.

@zplizzi

zplizzi commented Mar 31, 2019

Same here, at least for O2. O1 does work.

@mcarilli
Contributor

In general we strongly recommend DistributedDataParallel (either torch.nn.parallel.DistributedDataParallel or apex.parallel.DistributedDataParallel) over DataParallel, because global interpreter lock sharing within a single process is not great for performance. Currently, I don't test with or claim to support DataParallel. If you are open to trying DistributedDataParallel, I have a simple example showing proper DDP initialization and launch. The Imagenet example also shows DDP use along with distributed data sampling.

That being said, I don't think DataParallel is fundamentally incompatible with Amp control flow. I see one potential problem with your code above: you are calling .cuda on the model after it's been returned from amp.initialize. You should be doing things in the following order:

model.cuda() # Cuda-ing your model should occur before the call to amp.initialize
model, optimizer = amp.initialize(model, optimizer)
model = nn.DataParallel(model)

Try this and let me know if it works.

The fact that cuda-ing your model should occur before amp.initialize is a general truth, independent of DataParallel or DistributedDataParallel. However, I can't really set up hard checks for that, because people may legitimately want part of their model to remain on the CPU.

@aclyde11
Author

aclyde11 commented Apr 1, 2019

@mcarilli I tried the fix above but it still produces the same error:
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

I will transition to DistributedDataParallel.

@mcarilli
Contributor

mcarilli commented Apr 24, 2019

Historically we only test with DistributedDataParallel because performance tends to be better, but the dataset sharing issue raised by @seongwook-ham in #269 is a compelling use case. @ptrblck and I will look into it. Current to-do list is better fused optimizers, checkpointing, sparse gradients, and then DataParallel, so it may be a couple weeks before I can give it undivided attention.

@seongwook-ham

I find that the old API (FP16_Optimizer) works well with nn.DataParallel.
If you need DataParallel, you could use it with FP16_Optimizer and model.half().

@mcarilli
Contributor

mcarilli commented May 13, 2019

Please don't use the old FP16_Optimizer API, that may break at any time. It might already be broken.

In general, O1 is preferable over O2 anyway, so if O1 works with DataParallel currently, using DataParallel + O1 with the current API is a much better workaround than DataParallel + reverting to the old FP16_Optimizer.

We do want to support O2 + DataParallel but haven't gotten a chance to look at it yet. @seongwook-ham I am interested to hear if people have compelling reasons for requiring O2 (or FP16_Optimizer) over O1, because O1 is safer in general. Is O1 significantly slower?

@seongwook-ham

seongwook-ham commented May 14, 2019

Yes. With nn.DataParallel, O1 in the new API is significantly slower than the old API with FP16_Optimizer and model.half(): 1.41 it/s vs 2.4 it/s. In the same setting, FP32 runs at 1.2 it/s.
Also, with apex DistributedDataParallel, O1 with Adam is significantly slower than O2 with FusedAdam: 2.9 it/s vs 4.58 it/s.
These tests are based on a modified version of https://github.com/huggingface/pytorch-pretrained-BERT, in an environment with a 6950X, 3x 2080 Ti, and 1x Titan RTX.
When tested on a DGX-1 (8x V100 32 GB, NVLink), the results are similar.
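To put these figures side by side, the relative speedups can be computed directly. This only restates the numbers reported in the comment above (iterations per second on the poster's setup); it is not an independent benchmark.

```python
# Throughputs (iterations per second) as reported in the comment above.
reported = {
    "DataParallel, FP32": 1.2,
    "DataParallel, amp O1": 1.41,
    "DataParallel, FP16_Optimizer + half()": 2.4,
    "apex DDP, O1 + Adam": 2.9,
    "apex DDP, O2 + FusedAdam": 4.58,
}


def speedup(fast: str, slow: str) -> float:
    """Ratio of two reported throughputs (> 1 means `fast` is faster)."""
    return reported[fast] / reported[slow]


print(f"old API vs O1 (DP): {speedup('DataParallel, FP16_Optimizer + half()', 'DataParallel, amp O1'):.2f}x")
print(f"O2 vs O1 (DDP):     {speedup('apex DDP, O2 + FusedAdam', 'apex DDP, O1 + Adam'):.2f}x")
print(f"O1 vs FP32 (DP):    {speedup('DataParallel, amp O1', 'DataParallel, FP32'):.2f}x")
```

In other words, in this report O1 only modestly beats FP32 under DataParallel, while the O2-style paths are roughly 1.6x to 1.7x faster than O1.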

@greathope

Same

@mwyborski

Same

@askerlee

@mcarilli Unfortunately, I'm using FusedAdam, which requires O2. It seems like a deadlock, so I have to revert to FP16_Optimizer...

@iariav

iariav commented Aug 11, 2019

@mcarilli
I'm still seeing this issue.
Any idea when support for O2 + DataParallel will land?

Thanks

@Hesene

Hesene commented Aug 13, 2019

I'm running into the same problem, sadly.

@xiongzhangdavid

same issue

@AtsunoriFujita

AtsunoriFujita commented Aug 20, 2019

In my case, torch.nn.parallel.DistributedDataParallel doesn't work with any opt level except O1.

@visionscaper

Same issue with DataParallel and O2. With O1 I get a CUDA out-of-memory error, while the float32 version (without amp) fits.

@chenyilun95

Same

@williamFalcon

williamFalcon commented Oct 6, 2019

There's a version of Distributed Data Parallel that acts like DP on a node and DDP across nodes. (https://pytorch-lightning.readthedocs.io/en/latest/Trainer/Distributed%20training/#distributeddataparallel-2-ddp2).

However, this is incompatible with apex because of the issue above. What happens is that the casting done here has a bug:

  File "/private/home/user/.conda/envs/myenv/lib/python3.7/site-packages/apex/amp/_initialize.py", line 194, in new_fwd
    **applier(kwargs, input_caster))

After digging into the code, it looks like forward is being patched as a "convenience" to cast inputs to .half() and outputs back to float32.

A good alternative might be to remove this patching and instead detect 16-bit mode in PyTorch itself, doing the casting there. That would avoid patching forward altogether.

for model in models:
    # Patch the forward method to cast incoming data to the correct type, and
    # outgoing data to float32, so "the user never needs to call .half()."
    # I like writing things explicitly more than decorators.
    def patch_forward(old_fwd):
        def new_fwd(*args, **kwargs):
            output = old_fwd(*applier(args, input_caster),
                             **applier(kwargs, input_caster))
            return applier(output, output_caster)
        return new_fwd

    model.forward = patch_forward(model.forward)
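The patching pattern in the apex snippet above can be illustrated in plain Python. In this sketch, `FakeTensor` (a hypothetical stand-in that only tracks a dtype name), the string dtypes, and the simplified `applier` are assumptions for illustration; apex's real `applier` handles more container and batch types. Note that `model.forward` becomes an instance attribute wrapping the original bound method, which is the detail that matters once DataParallel copies the module.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: only tracks a dtype name."""
    dtype: str

    def to(self, dtype: str) -> "FakeTensor":
        return FakeTensor(dtype)


def applier(value: Any, fn: Callable[[FakeTensor], FakeTensor]) -> Any:
    """Recursively apply `fn` to every FakeTensor in nested tuples/lists/dicts."""
    if isinstance(value, FakeTensor):
        return fn(value)
    if isinstance(value, dict):
        return {k: applier(v, fn) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(applier(v, fn) for v in value)
    return value


def input_caster(t: FakeTensor) -> FakeTensor:
    return t.to("float16")


def output_caster(t: FakeTensor) -> FakeTensor:
    return t.to("float32")


def patch_forward(old_fwd):
    # Same shape as the apex snippet: cast inputs in, cast outputs back out.
    def new_fwd(*args, **kwargs):
        output = old_fwd(*applier(args, input_caster),
                         **applier(kwargs, input_caster))
        return applier(output, output_caster)
    return new_fwd


class Model:
    def forward(self, x, extra=None):
        # Record which dtypes actually reach the model body.
        self.seen = (x.dtype, extra.dtype if extra is not None else None)
        return (x, [x])


model = Model()
model.forward = patch_forward(model.forward)  # stored on the instance
out = model.forward(FakeTensor("float32"), extra=FakeTensor("float32"))
seen = model.seen
print("inside forward:", seen)
print("outputs:", out[0].dtype, out[1][0].dtype)
```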

The workaround I'm using in Lightning right now is to do this:

def training_step(self, batch, batch_nb):
    x, y = batch
    if self.trainer.use_amp:
        x = x.half()
        y = y.half()

    # process the rest without using forward()
    out = self.model(x)
    ...

Whereas normally I'd do this:

def training_step(self, batch, batch_nb):
    x, y = batch

    # process the input using forward()
    out = self.forward(x)    # <-- ONLY CHANGE + NO CASTING
    ...

@ptrblck

@ewrfcas

ewrfcas commented Oct 11, 2019

the same problem

@John1231983

Same problem with opt level O0 and nn.DataParallel. I tried all the above suggestions and none of them worked.

@lonelylingoes

I have the same problem. O1 works well, but the others do not.
I also replaced nn.DataParallel with parallel.DistributedDataParallel, and the result is the same.

@wasiahmad

Any solution to the problem?

@ewrfcas

ewrfcas commented Nov 3, 2019

For now I'm using the old FP16_Optimizer to work around this problem.

@wasiahmad

@ewrfcas I saw that solution, but it didn't work for me. Is there a specific version of Apex I should install to use FP16_Optimizer? I recently installed apex from its master branch and it says FP16_Optimizer is not found.

@ewrfcas

ewrfcas commented Nov 3, 2019

@wasiahmad FP16_Optimizer is in apex.contrib now; you can simply copy the code from the repository and use it directly.
https://github.com/NVIDIA/apex/blob/master/apex/contrib/optimizers/fp16_optimizer.py

@vadimkantorov

vadimkantorov commented Dec 16, 2019

The problem seems to be that DataParallel's replication mechanism doesn't work well with forward-method patching. The patched method still refers to the single original model (through the old bound forward method, which references the old "self"), not to the replica it should apply to. Hence the device mismatch: each replica receives inputs on its own device, but runs with the original model's weights.

It seems everything would work if the forward patching that does the tensor casting were applied after the DataParallel wrapper is created.
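The replication problem described above can be reproduced without GPUs. In this minimal sketch, `Module`, `patch_forward`, and `replicate` are hypothetical stand-ins: `replicate` makes a shallow copy of the instance, which is roughly what DataParallel does when it copies a module's `__dict__` for each replica, and the device check imitates the cuDNN error from this issue.

```python
import copy


class Module:
    """Hypothetical stand-in for an nn.Module whose weights live on one device."""

    def __init__(self, device: int):
        self.device = device

    def forward(self, input_device: int) -> str:
        # Imitates cuDNN's check that input and weight share a device.
        if input_device != self.device:
            raise RuntimeError(
                f"device {input_device} does not equal {self.device}")
        return f"ok on device {self.device}"


def patch_forward(old_fwd):
    """Wrap a bound forward, as amp does (the casting itself is omitted)."""
    def new_fwd(*args, **kwargs):
        return old_fwd(*args, **kwargs)
    return new_fwd


def replicate(module: Module, device: int) -> Module:
    """Hypothetical replication: shallow-copy the instance dict, then 'move' it."""
    replica = copy.copy(module)
    replica.device = device
    return replica


# Patched BEFORE replication: the copied instance attribute still wraps the
# original object's bound method, so the replica runs with device-0 weights.
original = Module(device=0)
original.forward = patch_forward(original.forward)
replica = replicate(original, device=1)
try:
    before_msg = replica.forward(1)
except RuntimeError as err:
    before_msg = str(err)
print("patched before replicate:", before_msg)

# Patched AFTER replication: the wrapper is bound to the replica itself.
fresh = replicate(Module(device=0), device=1)
fresh.forward = patch_forward(fresh.forward)
after_msg = fresh.forward(1)
print("patched after replicate:", after_msg)
```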

@vadimkantorov

Workaround:

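import torch
import apex

# As far as I can tell, amp.initialize (called without an optimizer) returns the
# patched nn.Sequential wrapper; [0] then yields the inner model, whose weights
# are cast but whose own forward is left unpatched.
model = apex.amp.initialize(torch.nn.Sequential(model), opt_level='O2')[0]
model = torch.nn.DataParallel(model, device_ids=args.devices)  # args.devices: your GPU ids

# Re-apply amp's input/output casting on the DataParallel wrapper instead.
applier = apex.amp._initialize.applier

def input_caster(tensor):
    return tensor.to(apex.amp._amp_state.opt_properties.options['cast_model_type'])

def output_caster(tensor):
    cast_outputs = apex.amp._amp_state.opt_properties.options.get('cast_model_outputs')
    return tensor.to(cast_outputs if cast_outputs is not None else torch.float32)

def patch_forward(old_fwd):
    def new_fwd(*args, **kwargs):
        output = old_fwd(*applier(args, input_caster), **applier(kwargs, input_caster))
        return applier(output, output_caster)
    return new_fwd

model.forward = patch_forward(model.forward)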

With DataParallel, forward must be patched after the DataParallel(...) call.
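As I read the workaround above, it relies on two details: amp patches the forward of whatever module it is handed (assuming `amp.initialize` returns that same module when called without an optimizer), and indexing an `nn.Sequential` with `[0]` returns its first child. So the inner model comes back with its weights cast but its own forward unpatched, and the casting is then reapplied on the DataParallel wrapper, which scatters inputs and replicates only its inner module, never itself. The pure-Python sketch below mimics that flow; `Net`, `Wrapper`, and `patch_forward` are hypothetical stand-ins, not torch or apex classes.

```python
class Net:
    """Hypothetical user model with a class-level forward."""

    def forward(self, x):
        return f"net({x})"


class Wrapper:
    """Hypothetical container playing the role of nn.Sequential / DataParallel."""

    def __init__(self, child):
        self.child = child

    def __getitem__(self, index):
        return [self.child][index]

    def forward(self, x):
        return self.child.forward(x)


def patch_forward(module):
    """Hypothetical stand-in for amp's casting patch: wraps module.forward
    in place (strings mark where inputs/outputs would be cast)."""
    old_fwd = module.forward
    module.forward = lambda x: f"out_cast({old_fwd(f'in_cast({x})')})"
    return module


# Like amp.initialize(nn.Sequential(model))[0]: the patch lands on the
# wrapper, and [0] hands back the inner model with its forward untouched.
net = patch_forward(Wrapper(Net()))[0]
print("inner forward patched:", "forward" in vars(net))

# Wrap for data parallelism first, then apply the casting to that wrapper.
dp = patch_forward(Wrapper(net))
dp_out = dp.forward("x")
print(dp_out)
```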

@mcarilli
Contributor

mcarilli commented Dec 18, 2019

Right now I'm working hard on native PyTorch support for mixed precision that will accommodate DistributedDataParallel, DataParallel, and model-parallel training, targeting the 1.5 release. Apex as a source for mixed precision is not a future-proof path; it's annoying for people to install something separate. If Apex helps, that's great, but the sooner we get something that's packaged and tested as a native component of PyTorch, the better. If Apex does not work for you currently, my best advice is to wait for the upstream support. See #269 (comment).

@phosseini

Problem solved. In my case, the problem was that I was wrapping the model in nn.DataParallel before calling amp.initialize(). I thought this might be the cause for others too.

@RSKothari

@phosseini What are your APEX and Torch versions?

@ChaoFan96

ChaoFan96 commented Sep 29, 2020

@vadimkantorov Thank you so much! The workaround you proposed works well for my issue. This indeed helped me a lot.
