-
Notifications
You must be signed in to change notification settings - Fork 65
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reducing the weight of NumberProxy used in mincut in rematerialization #425
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The change makes sense to me.
The PR description has only the "before" trace. Could you please update it adding how the trace looks like with this PR?
Not blocking the merge due to an absence of tests, but maybe you could come up with a short test for this change? The rematerialization tests currently live in tests/test_nvfuser_remat.py
.
Trace after this PR(based on branch uniform_rng):
sure, let me try if I can construct a graph corresponding to the case By the way, the final trace is the same on Llama-2-7b-hf fsdp_zero2_none_bucket before and after this pr @IvanYashchuk |
Thank you for checking! This is very good and what I wanted to see. |
Hi @mruberry , could you take a look for review |
Hi @nikitaved @t-vi , could you help to take a look for merging |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, stamped!
Before submitting
What does this PR do?
Fixes part of #114 .
Background:
When decomposing
dropout
touniform_philox
, the rematerialization doesn't pass the expected seed/offset to the backwardTrace before this PR(based on branch uniform_rng):
The mincut is
(a_in, a_out), (t1_in, t1_out)
withweight=2+1=3
, the corresponding weight is:we expect to have the mincut
((a_in, a_out), (i7_in, i7_out), (i8_in, i8_out))
which actually has the sameweight=0.5+0.5+2=3
Here's the trace with this PR: #425 (comment)
The proposal of fixing it would be add a factor(e.g.
0.1
) to reduce the weight for NumberProxycc: @IvanYashchuk