Fix loading local sharded optim state #372
Conversation
# So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
# Not the end of the world but there's probably a better way around this without resetting
# the metric.
for param_id in list(optim_state["state"].keys()):
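For reference, a minimal sketch of what that removal could look like, assuming the standard PyTorch optimizer state_dict layout ({"state": {param_id: {...}}, "param_groups": [...]}); the loop body isn't shown in the excerpt above, so this is illustrative rather than the PR's actual code.

def drop_grad_norm_exp_avg(optim_state: dict) -> None:
    # Illustrative only: remove 'grad_norm_exp_avg' from every param's state so
    # the sharded optimizer state lines up across ranks; this resets the metric.
    for param_id in list(optim_state["state"].keys()):
        optim_state["state"][param_id].pop("grad_norm_exp_avg", None)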
I wonder if instead you could go into _do_adaptive_clipping and add something like the following early in the for loop:

if p.grad is None or p.grad.numel() == 0:
    continue

Then only parameters with non-trivial grad will have grad_norm_exp_avg state, and this problem will hopefully go away.
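A minimal sketch of where that check could sit in a per-parameter clipping loop; the real _do_adaptive_clipping isn't shown in this thread, so the surrounding loop, the state dict, and the beta smoothing factor are assumptions made for illustration.

import torch

def adaptive_clip_sketch(params, state, beta=0.99):
    for p in params:
        # Suggested early skip: params with no (or empty) grad on this rank
        # never accumulate grad_norm_exp_avg state.
        if p.grad is None or p.grad.numel() == 0:
            continue
        grad_norm = p.grad.detach().norm()
        p_state = state.setdefault(p, {})
        if "grad_norm_exp_avg" not in p_state:
            p_state["grad_norm_exp_avg"] = grad_norm.clone()
        else:
            p_state["grad_norm_exp_avg"].mul_(beta).add_(grad_norm, alpha=1 - beta)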
The one thing stopping us from doing it that way is that we want to log the grad_norm_exp_avg of every param (as well as the clipping rate, which depends on the grad_norm_exp_avg) to W&B. The logging only happens from rank 0, so one way or another rank 0 needs the grad_norm_exp_avg of every param. The way it is now, every rank tracks every grad_norm_exp_avg, so rank 0 of course has all it needs for logging without any additional distributed communication.

We could still make your suggestion work, but there would need to be an extra dist.gather or dist.reduce on every step that we log. I think it's worth revisiting this in the future, but that's a bigger change than I want to make right now.
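For the record, a rough sketch of what that extra collective could look like if each rank only tracked grad_norm_exp_avg for the params whose grads it owns; the function and variable names here are made up for illustration, not taken from the repo.

import torch
import torch.distributed as dist

def reduce_grad_norm_exp_avg_to_rank0(param_ids, local_exp_avg, device):
    # One slot per param; each rank fills in only the params it owns, leaving
    # zeros elsewhere.
    buf = torch.zeros(len(param_ids), device=device)
    for i, pid in enumerate(param_ids):
        if pid in local_exp_avg:
            buf[i] = local_exp_avg[pid]
    # Since each param is owned by exactly one rank, a SUM reduce to rank 0
    # recovers every param's value there for logging.
    dist.reduce(buf, dst=0, op=dist.ReduceOp.SUM)
    if dist.get_rank() == 0:
        return {pid: buf[i].item() for i, pid in enumerate(param_ids)}
    return None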
I added a TODO comment for this: 5459c19
Adaptive clipping caused an issue with loading local-type sharded checkpoints as explained in the comment I added.