❓ The question
Hi all,
why did you reduce the gradients in float32, as reported in Section 3.1 of the OLMo paper?
We ran some experiments on this and observed that setting `reduce_dtype=bfloat16` for training setups with more than 4 nodes causes the output logit norm to grow. I am curious: did you make a similar observation? Did you track the output logit norm during training?
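For concreteness, this is the kind of metric we mean (a minimal sketch, assuming the model's forward returns logits of shape [batch, seq, vocab]; the function name is ours):

```python
import torch

def output_logit_norm(logits: torch.Tensor) -> float:
    # L2 norm of the output logits over the vocab dimension, averaged
    # over batch and sequence positions; logged once per step, this
    # makes the slow drift visible long before the loss reacts.
    return logits.detach().float().norm(dim=-1).mean().item()
```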
More concretely, while training our models we also observed a growth of the output logit norm, which led to Infs in our PPL metrics (NOT in the loss; the loss was still fine) at some point later in training.
Even though we observed that we could mitigate this by adding a regularizing loss that pushes down the output logits, similar to the z-loss suggested in the PaLM paper, we tried to avoid using such a loss.
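For reference, the kind of regularizer we mean is a PaLM-style z-loss (a minimal sketch; the function name is ours, and 1e-4 is the coefficient the PaLM paper reports):

```python
import torch

def z_loss(logits: torch.Tensor, coef: float = 1e-4) -> torch.Tensor:
    # log(Z) is the log of the softmax normalizer Z = sum(exp(logits));
    # penalizing log(Z)^2 pulls Z towards 1 and keeps the logits small.
    log_z = torch.logsumexp(logits.float(), dim=-1)
    return coef * (log_z ** 2).mean()

# total loss = cross-entropy + z_loss(logits)
```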
Instead, we investigated PyTorch FSDP MixedPrecision settings, as we suspected bfloat16 to be the cause.
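For concreteness, this is the configuration knob in question (a sketch; the param_dtype and buffer_dtype choices here are assumptions about the rest of the setup, only reduce_dtype is the setting under discussion):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# bfloat16 compute, but gradient all-reduce in float32 -- the variant
# that did not show the logit norm growth in our runs.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,   # vs. torch.bfloat16, the faster variant
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.NO_SHARD,
    mixed_precision=mp_policy,
)
```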
We trained two Transformer-like models, with 125M and 1.3B parameters, on next-token prediction on 4, 16, and 32 nodes (see below).
We trained for approx. 10k steps with the hyperparameters specified below.
Note: `drd` in the run names below corresponds to the `reduce_dtype` setting of FSDP `MixedPrecision`.

Experiment 1:
A model with 125M parameters trained on 4 nodes and on 16 nodes, both with `reduce_dtype=bfloat16`.
As sharding strategy we use NO_SHARD.
Brown: B24E768gbs256--s-NO_SHARD-nn-16-drd-bfloat16-sn-125M-utc-1-l-0.0003-wd-0.1-nb-24-ed-768-seed-42
Blue: B24E768gbs256--s-NO_SHARD-nn-4-drd-bfloat16-sn-125M-utc-1-l-0.0003-wd-0.1-nb-24-ed-768-seed-0
Experiment 2:
A model with 1.3B parameters trained on 32 nodes with DDP, with FSDP NO_SHARD and `reduce_dtype=float32`, and with FSDP NO_SHARD and `reduce_dtype=bfloat16`.
We compare FSDP with sharding strategy NO_SHARD to DDP.
Grey: B48E2048gbs512--s-NO_SHARD-nn-32-drd-float32-sn-1_3B-utc-1-l-0.0002-wd-0.1-nb-48-ed-2048-seed-42
Red: B48E2048gbs512--s-DDP-nn-32-drd-bfloat16-sn-1_3B-utc-1-l-0.0002-wd-0.1-nb-48-ed-2048-seed-42
Green: B48E2048gbs512--s-NO_SHARD-nn-32-drd-bfloat16-sn-1_3B-utc-1-l-0.0002-wd-0.1-nb-48-ed-2048-seed-42
We observe that setting `reduce_dtype=bfloat16` for training setups with more than 4 nodes causes the output logit norm to grow. When training with FSDP and `reduce_dtype=float32`, or when training with DDP (we think that DDP also reduces gradients in float32), the output logit norm did not grow. In other experiments we even observed that the growth of the output logit norm scales roughly linearly with the number of nodes (when using `reduce_dtype=bfloat16`). The tricky thing is that this behavior is not visible in the loss (see screenshots), so it is hard to trace the issue back to FSDP MixedPrecision.
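To make it plausible why this would scale with the number of nodes, here is a toy numeric sketch (our construction, not the actual reduction NCCL performs): summing per-rank gradient contributions in bfloat16 drops the low-order bits of each addend, and the accumulated rounding error grows with the number of ranks.

```python
import torch

torch.manual_seed(0)
for n_ranks in (4, 16, 32, 128):
    shards = torch.randn(n_ranks, 10_000) * 1e-3  # fake per-rank grads
    exact = shards.sum(dim=0)                     # float32 reference
    acc = torch.zeros(10_000, dtype=torch.bfloat16)
    for s in shards:                              # sequential bf16 adds
        acc = acc + s.bfloat16()
    rel_err = ((acc.float() - exact).norm() / exact.norm()).item()
    print(f"{n_ranks:4d} ranks: relative error {rel_err:.2e}")
```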
We think this is a severe issue that needs more investigation, since the reduce dtype has a major impact on training speed, and one would actually prefer bfloat16 for its higher throughput.