OOM with bfloat16, no speed-up #39

Closed
lhatsk opened this issue Dec 25, 2021 · 14 comments

Comments

@lhatsk
Contributor

lhatsk commented Dec 25, 2021

New issue based on: #34

Turning on bfloat16 in DeepSpeed doesn't seem to have the desired effect: the model parameter size remains unchanged, and I hit an OOM during validation, which works fine in FP16.

Training with bfloat16 in pytorch-lightning fails:

File "openfold/openfold/utils/loss.py", line 46, in sigmoid_cross_entropy
log_p = torch.nn.functional.logsigmoid(logits)
RuntimeError: "log_sigmoid_forward_cuda" not implemented for 'BFloat16'

Is support still missing in DeepSpeed? See microsoft/DeepSpeed#974.

Tested on A100 with torch 1.10.1+cu113

@lhatsk changed the title from "No OOM with bfloat16, no speed-up" to "OOM with bfloat16, no speed-up" on Dec 25, 2021
@gahdritz
Collaborator

gahdritz commented Dec 25, 2021

Support isn't missing in DeepSpeed anymore. See the current official documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-options.

As a first step, are you running DeepSpeed with ZeRO stage 1 or 2? Have you updated the DeepSpeed config in addition to passing bf16 as a parameter to the training script?

Our testing was done with the following:

CUDA Driver 465.19.01
CUDA 11.3 Update 1 (11.3.1.005)
cuBLAS 11.5.1.109 (part of CUDA 11.3 U1)
CUDNN 8.2.1.32
NCCL 2.9.9

As for the log_sigmoid thing, I don't know why that call would succeed with DeepSpeed and fail without it. As a sanity check, I'd try just spelling out the sigmoid formula instead.
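A minimal sketch of a numerically stable way to spell the formula out, assuming the loss takes raw logits and {0, 1} labels; `stable_logsigmoid` and this `sigmoid_cross_entropy` body are illustrative, not OpenFold's actual code. It relies on the identity log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|)) and uses only basic elementwise ops (whether each of them has a bf16 CUDA kernel in a given PyTorch version is worth checking separately):

```python
import torch


def stable_logsigmoid(x: torch.Tensor) -> torch.Tensor:
    # log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|)).
    # exp(-|x|) always lies in (0, 1], so nothing overflows for large |x|.
    return torch.minimum(x, torch.zeros_like(x)) - torch.log1p(torch.exp(-torch.abs(x)))


def sigmoid_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Hypothetical replacement for a logsigmoid-based loss:
    # -[y * log(p) + (1 - y) * log(1 - p)], using log(1 - p) = logsigmoid(-x).
    log_p = stable_logsigmoid(logits)
    log_not_p = stable_logsigmoid(-logits)
    return -labels * log_p - (1 - labels) * log_not_p
```

If the missing bf16 kernel is the only problem, upcasting just this call, e.g. `torch.nn.functional.logsigmoid(logits.float())`, is the other obvious workaround.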

@lhatsk
Contributor Author

lhatsk commented Dec 27, 2021

Stage 2. My DeepSpeed config:

{
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0005,
      "eps": 1e-05
    }
  },
  "amp": {
    "enabled": false,
    "opt_level": "O2"
  },
  "bfloat16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2,
    "cpu_offload": true,
    "contiguous_gradients": true
  },
  "activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": false,
    "profile": false
  },
  "gradient_clipping": 0.1
}

Driver Version: 460.27.04 CUDA Version: 11.2

Is there no automatic casting with DeepSpeed? Features are still in torch.float32...
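To make "params size remains unchanged" and "features still in float32" concrete, a small helper that tallies parameter dtypes and bytes, plus the dtypes in a feature dict, may help. A sketch; `param_dtype_summary` and `feature_dtypes` are hypothetical names, and this only inspects the module's own parameters, not any float32 master copy an optimizer may keep.

```python
from collections import Counter

import torch


def param_dtype_summary(model: torch.nn.Module) -> dict:
    # Count parameter elements per dtype and total parameter bytes.
    # A model actually held in bf16 should report torch.bfloat16 and
    # roughly half the bytes of its float32 version.
    counts = Counter()
    total_bytes = 0
    for p in model.parameters():
        counts[p.dtype] += p.numel()
        total_bytes += p.numel() * p.element_size()
    return {"numel_by_dtype": dict(counts), "bytes": total_bytes}


def feature_dtypes(batch: dict) -> dict:
    # dtype of every tensor in a (flat) feature dict; non-tensors are skipped.
    return {k: v.dtype for k, v in batch.items() if torch.is_tensor(v)}
```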

Without DeepSpeed, the logsigmoid error disappears if I replace it with log(sigmoid(.)), but in that version all losses go to NaN immediately and never recover. Same OOM. Features are still in float32.
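One possible (unverified) contributor to the immediate NaNs: evaluating log(sigmoid(x)) literally returns -inf once sigmoid(x) rounds to zero, and backpropagating through log(0) gives inf * 0 = NaN gradients. The sketch below shows the -inf in plain float32 on CPU; bf16 has roughly the same exponent range as float32 but far fewer mantissa bits, so it is no safer here.

```python
import torch
import torch.nn.functional as F


def naive_logsigmoid(x: torch.Tensor) -> torch.Tensor:
    # Literal transcription of log(sigmoid(x)). For very negative x,
    # sigmoid(x) = 1 / (1 + exp(-x)) rounds to exactly 0, and log(0) = -inf.
    return torch.log(torch.sigmoid(x))


logits = torch.tensor([-200.0, -10.0, 0.0, 10.0])
print(naive_logsigmoid(logits))  # first entry is -inf
print(F.logsigmoid(logits))      # all entries finite
```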

@gahdritz
Collaborator

Just to clarify: you're also passing --precision bf16 to the training script, right?

@lhatsk
Contributor Author

lhatsk commented Dec 28, 2021

Yes

@lhatsk
Contributor Author

lhatsk commented Dec 28, 2021

In pure pytorch (no deepspeed, no lightning) I get the same logsigmoid error and all losses go to NaN immediately. There is no autocasting, features stay in torch.float32. Maybe because of the dict argument? I have to cast all floating point features manually.
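Casting a feature dict by hand only needs to touch floating-point tensors; integer-valued features should keep their dtype. A minimal sketch, assuming the batch is a flat dict of tensors (nested structures would need recursion); `cast_float_features` is a hypothetical name.

```python
import torch


def cast_float_features(batch: dict, dtype: torch.dtype = torch.bfloat16) -> dict:
    # Cast only floating-point tensors. Integer tensors (indices, integer masks)
    # and non-tensor entries pass through unchanged; the input dict is not modified.
    out = {}
    for key, value in batch.items():
        if torch.is_tensor(value) and value.is_floating_point():
            out[key] = value.to(dtype)
        else:
            out[key] = value
    return out
```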

Interestingly, if I run DeepSpeed (with bfloat16 enabled) and manually cast the model to bfloat16, it complains: ValueError: fp32 is enabled but the following parameters have dtype that is not fp32: module.model.input_embedder.linear_tf_z_i.weight [..]

If I only use pytorch-lightning, it does report at startup that it is in bfloat16 mixed precision mode (there is no such message with DeepSpeed), but there is still no autocasting and the losses are NaN.
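One possible reading of "still no autocasting": `torch.autocast` does not change the dtype of the tensors passed in. Inputs and parameters stay float32, while eligible ops (such as linear layers and matmuls) produce bf16 outputs inside the autocast region. A small CPU sketch, assuming PyTorch 1.10 or later, where `torch.autocast` accepts `device_type` and `dtype`; whether Lightning's bf16 mode wraps the training step exactly like this is an assumption here.

```python
import torch

layer = torch.nn.Linear(8, 4)  # parameters stay float32
x = torch.randn(2, 8)          # input features stay float32

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)               # linear is on CPU autocast's lower-precision list

print(x.dtype, layer.weight.dtype, y.dtype)
# torch.float32 torch.float32 torch.bfloat16
```

Tensors created before the autocast region keep their dtype, so float32 features at the model input are expected with this style of mixed precision.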

@gahdritz
Collaborator

Hm. Interesting. I won't be personally up on A100s for a few more days, so there's not much I can do at the moment to investigate. I'll loop back when I do.

What happens if you manually cast both the model and the input features when you're using DeepSpeed? Do you think casting the model interferes with DeepSpeed's master FP32 copy of the weights? Maybe this would make a good DeepSpeed issue too.
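For context on the master-copy question: the usual mixed-precision scheme keeps a float32 copy of every weight for the optimizer and copies the updated values back into the low-precision model after each step. A simplified sketch of that idea in plain PyTorch; it is not DeepSpeed's actual implementation, and `make_master_params` / `master_step` are hypothetical names.

```python
import torch


def make_master_params(model: torch.nn.Module) -> list:
    # float32 copies of the (possibly bf16) model weights; the optimizer owns these.
    return [p.detach().clone().float().requires_grad_(True) for p in model.parameters()]


def master_step(model: torch.nn.Module, master: list, optimizer: torch.optim.Optimizer) -> None:
    # 1. Upcast the low-precision gradients onto the float32 master copies.
    for p, m in zip(model.parameters(), master):
        m.grad = None if p.grad is None else p.grad.detach().float()
    # 2. Update in float32, so small updates are not lost to bf16 rounding.
    optimizer.step()
    # 3. Copy the updated weights back into the model (rounded to its dtype)
    #    and clear the model's gradients for the next backward pass.
    with torch.no_grad():
        for p, m in zip(model.parameters(), master):
            p.copy_(m)
            p.grad = None


model = torch.nn.Linear(4, 2).to(torch.bfloat16)
master = make_master_params(model)
optimizer = torch.optim.SGD(master, lr=0.1)

loss = model(torch.randn(3, 4, dtype=torch.bfloat16)).float().sum()
loss.backward()
master_step(model, master, optimizer)
```

In this scheme the low-precision model is only a working copy; casting it by hand is harmless as long as the master copies were created from (or kept in sync with) the full-precision weights.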

@lhatsk
Contributor Author

lhatsk commented Dec 28, 2021

"ValueError: fp32 is enabled but the following parameters have dtype that is not fp32: module.model.input_embedder.linear_tf_z_i.weight [..]"

This happens when I cast the model to bfloat16. The error is triggered in the setup phase in pytorch-lightning. When I only convert the data, the error occurs inside the network, which expects a float instead.
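The "expects a float" error is the generic dtype mismatch PyTorch raises when a float32 layer receives bf16 input (or the reverse); outside autocast, inputs and weights have to agree. A minimal reproduction on CPU; the exact message differs across PyTorch versions, so only the exception type is relied on here.

```python
import torch

layer = torch.nn.Linear(4, 2)  # float32 weights
x_bf16 = torch.randn(3, 4, dtype=torch.bfloat16)

try:
    layer(x_bf16)  # bf16 input into a float32 layer
except RuntimeError as err:
    print("dtype mismatch:", err)

# Casting the layer to the same dtype as the input avoids the error.
y = layer.to(torch.bfloat16)(x_bf16)
print(y.dtype)  # torch.bfloat16
```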

I'm not sure what's happening exactly. I have the feeling that the bfloat16 flag in the configuration file is just ignored for some reason. Or it's some miscommunication between deepspeed and pytorch-lightning?

@gahdritz
Collaborator

gahdritz commented Dec 31, 2021

FYI: Earlier today I removed a line from the training script that silently changed the value of "CUDA_VISIBLE_DEVICES". This is a long shot, but could it be that this was moving computation onto GPUs that aren't compatible with bfloat16?
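A quick way to rule this out is to log, at startup, which GPUs the process actually sees and their compute capability. A sketch; treating compute capability 8.0 and above as "native bf16" is an assumption (A100s report 8.0), and recent PyTorch versions also provide `torch.cuda.is_bf16_supported()`.

```python
import os

import torch


def bf16_device_report() -> list:
    # One entry per GPU visible to this process (i.e. after CUDA_VISIBLE_DEVICES
    # filtering). Assumes capability >= 8.0 (Ampere) means native bf16 support.
    report = []
    if not torch.cuda.is_available():
        return report
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        report.append({
            "index": i,
            "name": torch.cuda.get_device_name(i),
            "capability": (major, minor),
            "native_bf16": major >= 8,
        })
    return report


print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
for entry in bf16_device_report():
    print(entry)
```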

> I'm not sure what's happening exactly. I have the feeling that the bfloat16 flag in the configuration file is just ignored for some reason. Or it's some miscommunication between deepspeed and pytorch-lightning?

Both of these sound plausible to me.

@lhatsk
Contributor Author

lhatsk commented Dec 31, 2021

> FYI: Earlier today I removed a line from the training script that silently changed the value of "CUDA_VISIBLE_DEVICES". This is a long shot, but could it be that this was moving computation onto GPUs that aren't compatible with bfloat16?

No, the nodes are in different queues.

@gahdritz
Collaborator

gahdritz commented Jan 5, 2022

I'm up on an A100 now, and I can replicate your NaN issue. I'll circle back if I can figure out how to fix it.

@lhatsk
Contributor Author

lhatsk commented Jan 6, 2022

> I'm up on an A100 now, and I can replicate your NaN issue. I'll circle back if I can figure out how to fix it.

Is this when only using pytorch-lightning and changing the loss calculation? Does the bfloat16 flag in DeepSpeed have any effect for you? I piggybacked on the DeepSpeed issue above in the hope that we get a response from the DeepSpeed team.

Thanks for following up!

@gahdritz
Collaborator

gahdritz commented Jan 6, 2022

It's just PyTorch Lightning, but I haven't changed the loss calculation. What do you mean by that?

@lhatsk
Contributor Author

lhatsk commented Jan 7, 2022

> It's just PyTorch Lightning, but I haven't changed the loss calculation. What do you mean by that?

Replacing logsigmoid with log(sigmoid(.)).

@gahdritz
Collaborator

gahdritz commented Feb 6, 2022

This should have been resolved by this week's commits. Closing this for now.

@gahdritz gahdritz closed this as completed Feb 6, 2022