Fix loading local sharded optim state #372

Merged: 4 commits into main on Nov 8, 2023
Conversation

epwalsh (Member) commented Nov 8, 2023

Adaptive clipping caused an issue with loading local-type sharded checkpoints as explained in the comment I added.

# So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
# Not the end of the world but there's probably a better way around this without resetting
# the metric.
for param_id in list(optim_state["state"].keys()):
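    # (Sketch) The body of this loop isn't shown in the excerpt; presumably it drops the
    # key from each param's optimizer state, e.g. something like:
    optim_state["state"][param_id].pop("grad_norm_exp_avg", None)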

Collaborator commented:

I wonder if instead you could go into _do_adaptive_clipping and add something like the following early in the for loop.

if p.grad is None or p.grad.numel() == 0:
    continue

Then only parameters with non-trivial grad will have grad_norm_exp_avg state, and this problem will hopefully go away.
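For context, here is a minimal sketch of where that guard would sit inside an adaptive-clipping loop. The function body, names, and constants below are illustrative assumptions, not OLMo's actual _do_adaptive_clipping:

def _do_adaptive_clipping_sketch(optimizer, beta: float = 0.99, sigma: float = 2.0):
    # Hypothetical sketch of adaptive gradient clipping, for illustration only.
    for group in optimizer.param_groups:
        for p in group["params"]:
            # The suggested early-exit: params with no (or empty) grad never get
            # a 'grad_norm_exp_avg' entry in the optimizer state.
            if p.grad is None or p.grad.numel() == 0:
                continue
            state = optimizer.state[p]
            grad_norm = p.grad.detach().norm()
            if "grad_norm_exp_avg" not in state:
                state["grad_norm_exp_avg"] = grad_norm.clone()
            exp_avg = state["grad_norm_exp_avg"]
            exp_avg.mul_(beta).add_(grad_norm, alpha=1 - beta)
            # Scale this param's grad down if its norm exceeds the running average.
            max_norm = sigma * exp_avg
            if grad_norm > max_norm:
                p.grad.detach().mul_(max_norm / (grad_norm + 1e-6))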

epwalsh (Member, Author) replied:

The one thing stopping us from doing it that way is that we want to log the grad_norm_exp_avg of every param (as well as the clipping rate, which depends on the grad_norm_exp_avg) to W&B. The logging only happens from rank 0, so one way or another rank 0 needs the grad_norm_exp_avg of every param. The way it is now, every rank tracks every grad_norm_exp_avg, so rank 0 already has everything it needs for logging without any additional distributed communication.

We could still make your suggestion work, but there would need to be an extra dist.gather or dist.reduce on every step that we log. I think it's worth revisiting this in the future, but that's a bigger change than I want to make right now.

I added a TODO comment for this: 5459c19
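
For the record, a rough sketch of the extra communication that approach would need on logging steps. The helper name and the use of torch.distributed.gather_object are assumptions for illustration, not code from the PR:

import torch.distributed as dist

def collect_grad_norm_exp_avgs(local_metrics: dict) -> dict:
    # local_metrics: {param_name: grad_norm_exp_avg} for the params this rank tracks.
    world_size = dist.get_world_size()
    gathered = [None] * world_size if dist.get_rank() == 0 else None
    # The extra collective on every logging step: ship each rank's metrics to rank 0.
    dist.gather_object(local_metrics, gathered, dst=0)
    merged: dict = {}
    if dist.get_rank() == 0:
        for part in gathered:
            merged.update(part)
    # Rank 0 now has every param's grad_norm_exp_avg (and can derive the clipping rate) for W&B.
    return merged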

@2015aroras self-requested a review on November 8, 2023, 21:08
@AkshitaB merged commit 38be6a7 into main on Nov 8, 2023 (10 checks passed)
@AkshitaB deleted the epwalsh/optim-state-fix branch on November 8, 2023, 21:16