Bring back global gradient clipping and improve speed of collecting metrics #326

Merged
epwalsh merged 11 commits into main from petew/global-grad-clipping on Oct 12, 2023

Conversation

@epwalsh (Member) commented Oct 11, 2023

  • Changes the behavior of non-adaptive clipping to revert to the default PyTorch behavior, where gradients are clipped relative to the total gradient norm across all params (viewed as one big parameter) instead of relative to each gradient's own norm. See the sketch just below this list.
  • Aggregates parameter metrics entirely on GPU to avoid unnecessary host-device syncs. This results in a big speed-up with my test model (our 7B mitch-ish with only 2 layers) 🔥 🔥 🔥 See below.
  • Also cleans up the code for all of this by pulling the clipping logic out into separate methods.
  • Makes other changes in Trainer.train_step() to delay host-device syncs as long as possible.
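
For readers skimming the diff, here is a rough sketch of both ideas together. This is an illustration, not the repo's code; the names global_clip_and_metrics, max_grad_norm, and eps are placeholders, and it assumes all grads live on one device.

import torch

def global_clip_and_metrics(params, max_grad_norm: float = 1.0, eps: float = 1e-6):
    # Sketch only, not the OLMo implementation.
    grads = [p.grad for p in params if p.grad is not None]
    # Per-parameter L2 norms, computed and kept on device.
    per_param_norms = torch.stack([g.detach().float().norm(2.0) for g in grads])
    # Total norm over all params, as if they were one big parameter.
    total_norm = per_param_norms.norm(2.0)
    # Always multiply by the clamped coefficient; a Python-level
    # `if clip_coef < 1:` check would force a host-device sync.
    clip_coef = max_grad_norm / (total_norm + eps)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device, g.dtype))
    # Return GPU tensors; only call .item() when the values are actually logged.
    return {"total_grad_norm": total_norm, "per_param_grad_norms": per_param_norms}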

Speed up after the metrics change:
[image]

Speed up after the Trainer.train_step() change:
[image]

@epwalsh changed the title from "Bring back global gradient clipping" to "Bring back global gradient clipping and improve speed of collecting metrics" Oct 11, 2023
@epwalsh marked this pull request as ready for review October 11, 2023 19:28
@@ -185,69 +194,128 @@ def clip_grads_and_collect_metrics(
    all_metrics[metric_name] = metric.squeeze(0)
for metric_name, metric in zip(per_param_norm_metric_names, per_param_norm_metrics):
    all_metrics[metric_name] = metric.squeeze(0)
all_metrics["total_grad_norm"] = total_grad_norm
Member

I love that we're getting this back.

olmo/optim.py Outdated
Comment on lines 215 to 219
for param_was_clipped in clipping_iter:
    if param_was_clipped is not None:
        num_eligible_grads += 1
        if param_was_clipped:
            num_grads_clipped += 1
Member

Wouldn't it be easier if the clipping functions just returned num_eligible_grads and num_grads_clipped?

Member

How does this work with FSDP when a parameter exists on multiple devices? Doesn't that double-count some parameters? But I also don't see any place where these clipping counts get aggregated across ranks.

Member Author

Wouldn't it be easier if the clipping functions just returned num_eligible_grads and num_grads_clipped?

That's what I had initially, but it got a little weird because when we're not collecting metrics we don't count these things. Ultimately I thought it was cleaner to have it this way.

How does this work with FSDP when a parameter exists on multiple devices? Doesn't that double-count some parameters? But I also don't see any place where these clipping counts get aggregated across ranks.

These don't need to be aggregated since the grad norm metrics are synced across all ranks, and so every rank will end up with the same counts for these.
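
(For context, a minimal sketch of why the counts come out the same everywhere. This is an illustration, not the repo's code; synced_per_param_grad_norms is a made-up helper and it assumes torch.distributed is already initialized.)

import torch
import torch.distributed as dist

def synced_per_param_grad_norms(params, device) -> torch.Tensor:
    # One entry per parameter, in the same order on every rank; a rank that
    # doesn't hold a shard of some grad just contributes 0 for it.
    local_sq_norms = torch.stack([
        p.grad.detach().float().pow(2).sum() if p.grad is not None
        else torch.zeros((), device=device)
        for p in params
    ])
    # A SUM all-reduce gives every rank the same full-parameter squared norms,
    # so any clip decision or count derived from them matches across ranks.
    dist.all_reduce(local_sq_norms, op=dist.ReduceOp.SUM)
    return local_sq_norms.sqrt()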

Member

I checked whether the two do_*_clipping() functions do anything different when a parameter is present vs. absent. They do, but I think they do it in a way that doesn't affect the correctness of these counts.

This is some intricate code and it'll be hard to keep correct as it changes. Not that I have a better idea for how to structure it either.

Contributor

Similar to what Dirk said, returning an optional tuple of num_eligible_grads and num_grads_clipped was also an option.

Member Author

OK, I found another way to clean this up: 1b1ee91

clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
if p.grad is not None:
    # p.grad could be None for some ranks when using FSDP.
    p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
Member

Is detach() the right thing here? Doesn't that create a detached copy? Shouldn't this be with no_grad():?

Member Author

detach() creates a new tensor that shares the same underlying data, so in-place modifications are still seen by the original tensor. PyTorch does it the same way:
https://github.com/pytorch/pytorch/blob/7827ae2864afa1955bc9ce04d168b274700d24e5/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1184

Actually, our whole clip_grads_and_collect_metrics is wrapped with no_grad(), so this probably isn't necessary, but I think it's good to keep just in case someone (me?) refactors this later and forgets a no_grad().
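
(A quick standalone illustration of that point, not taken from the PR:)

import torch

p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.full((3,), 2.0)

# detach() returns a new tensor that shares storage with p.grad,
# so the in-place mul_ scales the original gradient.
p.grad.detach().mul_(0.5)
print(p.grad)  # tensor([1., 1., 1.])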

Member Author

ec8db60 adds no_grad() on these methods too just to be safe.

# equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`.
if p.grad is not None:
    # p.grad could be None for some ranks when using FSDP.
    p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
Member

Same question about detach().

# TODO (dirkgr): If we did this much earlier, like, right after the forwards step, but then didn't
# call `.item()` for a long time, would it use laziness to interleave this reduce call with the backward step?
# Collect metrics and check for NaN loss.
# NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
Contributor

Does checking for NaN loss cause a host-device sync too? If not, maybe it would be better to check for NaN loss before taking the optimizer step.

Member Author

Yeah, checking for NaN requires a sync.
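
(A tiny illustration of where the sync happens, not from the trainer; it assumes a CUDA device is available:)

import torch

loss = torch.randn((), device="cuda")

has_nan = torch.isnan(loss).any()  # still a GPU tensor: no sync yet
if has_nan:                        # bool() coercion forces a host-device sync here
    raise ValueError("got NaN loss")

value = loss.item()                # .item() also blocks until the GPU catches up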

@epwalsh requested a review from dirkgr October 12, 2023 01:39
@epwalsh merged commit d4744d0 into main Oct 12, 2023
10 checks passed
@epwalsh deleted the petew/global-grad-clipping branch October 12, 2023 15:35