Bring back global gradient clipping and improve speed of collecting metrics #326
Conversation
```diff
@@ -185,69 +194,128 @@ def clip_grads_and_collect_metrics(
            all_metrics[metric_name] = metric.squeeze(0)
        for metric_name, metric in zip(per_param_norm_metric_names, per_param_norm_metrics):
            all_metrics[metric_name] = metric.squeeze(0)
        all_metrics["total_grad_norm"] = total_grad_norm
```
I love that we're getting this back.
`olmo/optim.py` (outdated)
```python
for param_was_clipped in clipping_iter:
    if param_was_clipped is not None:
        num_eligible_grads += 1
        if param_was_clipped:
            num_grads_clipped += 1
```
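The loop above consumes an iterator of per-parameter clip results. A minimal sketch of that contract (the name `count_clipping` is hypothetical, not from the PR): `None` means the parameter had no gradient on this rank, otherwise the value says whether its gradient was scaled.

```python
from typing import Iterable, Optional, Tuple

def count_clipping(clip_results: Iterable[Optional[bool]]) -> Tuple[int, int]:
    """Tally how many grads were eligible for clipping and how many were clipped."""
    num_eligible_grads = 0
    num_grads_clipped = 0
    for param_was_clipped in clip_results:
        if param_was_clipped is not None:  # None => no grad for this param on this rank
            num_eligible_grads += 1
            if param_was_clipped:
                num_grads_clipped += 1
    return num_eligible_grads, num_grads_clipped
```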
Wouldn't it be easier if the clipping functions just returned `num_eligible_grads` and `num_grads_clipped`?
How does this work with FSDP when a parameter exists on multiple devices? Doesn't that double-count some parameters? But I also don't see any place where these clipping counts get aggregated across ranks.
> Wouldn't it be easier if the clipping functions just returned `num_eligible_grads` and `num_grads_clipped`?

That's what I had initially, but it got a little weird because when we're not collecting metrics we're not counting these things. Ultimately I thought it was cleaner to have it this way.

> How does this work with FSDP when a parameter exists on multiple devices? Doesn't that double-count some parameters? But I also don't see any place where these clipping counts get aggregated across ranks.

These don't need to be aggregated: the grad norm metrics are synced across all ranks, so every rank ends up with the same counts.
I checked whether the two `do_*_clipping()` functions do anything different when a parameter is present vs. absent. They do, but I think they do it in a way that doesn't matter for these counts to be correct.

This is some intricate code, and it'll be hard to keep correct as it changes. Not that I have a better idea for how to structure it, either.
Similar to what Dirk said, returning an optional tuple of `num_eligible_grads` and `num_grads_clipped` was also an option.
OK, I found another way to clean this up: 1b1ee91
```python
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
if p.grad is not None:
    # p.grad could be None for some ranks when using FSDP.
    p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
```
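For context, the snippet above is the tail end of global clipping. A minimal self-contained sketch of the overall technique (the function name and structure here are illustrative, not the PR's exact code): compute the total L2 norm over all grads, scale every grad by `max_norm / total_norm`, and clamp the coefficient at 1 so the multiply is a no-op when no clipping is needed.

```python
import torch

def clip_grads_global_(params, max_norm: float, eps: float = 1e-6) -> torch.Tensor:
    """Scale all gradients in-place so their combined L2 norm is at most max_norm."""
    grads = [p.grad for p in params if p.grad is not None]
    # Total norm across all parameters (norm of the per-param norms).
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g.detach()) for g in grads])
    )
    clip_coef = max_norm / (total_norm + eps)
    # Clamping to 1.0 makes the multiply a no-op when no clipping is needed,
    # avoiding the host-device sync an `if clip_coef < 1` check would cause.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device, g.dtype))
    return total_norm
```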
Is `detach()` the right thing here? Doesn't that create a detached copy? Shouldn't this be `with no_grad():`?
`detach()` creates a new tensor over the same data, so in-place modifications are visible through the original tensor. PyTorch does it the same way:
https://github.com/pytorch/pytorch/blob/7827ae2864afa1955bc9ce04d168b274700d24e5/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1184

Actually, our whole `clip_grads_and_collect_metrics` is wrapped with `no_grad()`, so this probably isn't necessary, but I think it's good to keep just in case someone (me?) refactors this later and forgets a `no_grad()`.
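A quick illustration of the point about `detach()` (a sketch, not code from the PR): it returns a new tensor object with no autograd history, but over the same underlying storage, so an in-place op on the detached tensor mutates the original.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.detach()  # new tensor object, no grad history, but shares storage with `x`
y.mul_(2.0)     # in-place multiply is therefore visible through `x`
```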
ec8db60 adds `no_grad()` on these methods too, just to be safe.
```python
# equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
if p.grad is not None:
    # p.grad could be None for some ranks when using FSDP.
    p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
```
Same question about `detach()`.
```python
# TODO (dirkgr): If we did this much earlier, like, right after the forward step, but then didn't
# call `.item()` for a long time, would it use laziness to interleave this reduce call with the backward step?
# Collect metrics and check for NaN loss.
# NOTE: this involves a bunch of host-device syncs, so we wait until the last moment to do this.
```
Does checking for NaN loss cause the host-device syncs too? If not, maybe it would be better to check for NaN loss before taking the optimizer step.
Yeah, checking for NaN requires a sync.
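For context on why the NaN check forces a sync (a minimal sketch, assuming CUDA-style asynchronous execution; on CPU there is no queued work to wait for):

```python
import torch

def loss_is_nan(loss: torch.Tensor) -> bool:
    # `torch.isnan` runs on the device and returns a 0-dim bool tensor
    # without forcing a sync; it's the `.item()` call below that copies the
    # flag back to the host, blocking until queued device work finishes.
    return bool(torch.isnan(loss).item())
```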
Reworked `Trainer.train_step()` to delay host-device syncs as long as possible.

Speed up after the metrics change:

![image](https://private-user-images.githubusercontent.com/8812459/274394524-5db734bd-cb1c-4b0e-b002-07477d0a6cae.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjAxODA2NzgsIm5iZiI6MTcyMDE4MDM3OCwicGF0aCI6Ii84ODEyNDU5LzI3NDM5NDUyNC01ZGI3MzRiZC1jYjFjLTRiMGUtYjAwMi0wNzQ3N2QwYTZjYWUucG5nP1gtQW16LUFsZ29yaXRobT1BV1M0LUhNQUMtU0hBMjU2JlgtQW16LUNyZWRlbnRpYWw9QUtJQVZDT0RZTFNBNTNQUUs0WkElMkYyMDI0MDcwNSUyRnVzLWVhc3QtMSUyRnMzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyNDA3MDVUMTE1MjU4WiZYLUFtei1FeHBpcmVzPTMwMCZYLUFtei1TaWduYXR1cmU9YTMyNGYxMTlkMzMwYTIyN2M2ZGVhNDQwZTRiMTVkNjgyMWE5YjhmODI4MzAzMTBkOGYxNzVkNzQ5NTNhNWFiMiZYLUFtei1TaWduZWRIZWFkZXJzPWhvc3QmYWN0b3JfaWQ9MCZrZXlfaWQ9MCZyZXBvX2lkPTAifQ.x3tQQLR85hJfKKov6zJz1NSBT8j2nBN7_e_6pUclIP8)

Speed up after the `Trainer.train_step()` change:

![image](https://private-user-images.githubusercontent.com/8812459/274407871-3e92b5b8-448e-4b8d-9b82-cc3e61241c7d.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjAxODA2NzgsIm5iZiI6MTcyMDE4MDM3OCwicGF0aCI6Ii84ODEyNDU5LzI3NDQwNzg3MS0zZTkyYjViOC00NDhlLTRiOGQtOWI4Mi1jYzNlNjEyNDFjN2QucG5nP1gtQW16LUFsZ29yaXRobT1BV1M0LUhNQUMtU0hBMjU2JlgtQW16LUNyZWRlbnRpYWw9QUtJQVZDT0RZTFNBNTNQUUs0WkElMkYyMDI0MDcwNSUyRnVzLWVhc3QtMSUyRnMzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyNDA3MDVUMTE1MjU4WiZYLUFtei1FeHBpcmVzPTMwMCZYLUFtei1TaWduYXR1cmU9Zjg0MzUwNWM5YjdjYzY5OWViMDEyZDRlYzRmYzE1YjkzNzg2MjM4ZTcyMWZiNjRmODgxZTljMjQ5NzhhODU4NiZYLUFtei1TaWduZWRIZWFkZXJzPWhvc3QmYWN0b3JfaWQ9MCZrZXlfaWQ9MCZyZXBvX2lkPTAifQ.6oW67hjFuEQgEuSITffTYifXnIvJyFSYPuhGaRsO1mU)