If saved_for_backward returns NumberProxy, the value is taken from compile time, not runtime #231
Comments
It happens due to this line: `thunder/core/transforms.py`, line 3705 at commit `6cd19c4`. Not sure if there is any impact of not baking in the value from compile time. The tests in … I think another option could be to instead return … cc: @IvanYashchuk
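For context, here is a minimal, self-contained sketch (hypothetical `NumberProxy` and `tree_map` stand-ins, not Thunder's actual code) of the pattern that line implements: every number proxy in saved_for_backward is replaced by its compile-time `.value`, so whatever the forward pass computes at runtime never reaches backward.

```python
from dataclasses import dataclass

@dataclass
class NumberProxy:
    name: str
    value: int  # placeholder picked at compile time

def tree_map(fn, tree):
    # simplified tree_map: recurse into tuples/lists, apply fn to leaves
    if isinstance(tree, (tuple, list)):
        return type(tree)(tree_map(fn, t) for t in tree)
    return fn(tree)

def bake(leaf):
    # the questionable step: concretize proxies with their compile-time value
    return leaf.value if isinstance(leaf, NumberProxy) else leaf

saved_for_backward = (NumberProxy("i3", 0), NumberProxy("i4", 0))
print(tree_map(bake, saved_for_backward))  # (0, 0) -- runtime values are lost
```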
Thank you @kshitij12345! Removing this line works for me.
Sadly, when this line is removed, there are other cases where a NumberProxy name is returned but it doesn't exist in the trace.
That's probably the reason why these numbers were made concrete. We can't solve this particular problem easily; a lot of parts of Thunder rely on concrete numbers.
Does this cause any problems in your work?
If we have an operator that produces NumberProxy results (…
triage review: …
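As one illustration of why such results matter (a hypothetical mock, not Thunder code): the traces below advance the Philox offset by 4 per call (see `operator.add(i4, 4)`), so every iteration produces a different (seed, offset) pair, and a value frozen at compile time would replay iteration 0's randomness in every backward pass.

```python
# Hypothetical mock of the RNG bookkeeping visible in the traces below.
seed, offset = 1234, 0

def forward_step():
    global offset
    saved = (seed, offset)  # what saved_for_backward must capture at runtime
    offset += 4             # uniform_philox consumes 4 counter slots here
    return saved

for step in range(3):
    print(f"iteration {step}: backward needs {forward_step()}")
# baking in the compile-time value would save (1234, 0) on every iteration
```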
Does this look about right?

```
root@c574d9980ec8:/volume# python thunder_issue_231.py
# Constructed by Delete Last Used (took 0 milliseconds)
import operator
import thunder.core.devices as devices
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t2 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t2: "cpu ui8[16]"
  (i3, i4) = unpack_rng_state_prim_impl(t2)
  del t2
  [t1] = nvFusion0(a, i3, i4)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"
  i6 = operator.add(i4, 4)  # i6: "int 4"
    # i6 = prims.add(i4, 4)  # i6: "int 4"
  t7 = pack_rng_state_prim_impl(i3, i6)  # t7: "cpu ui8[16]"
  del i6
  set_rng_state_prim_impl(t7, devices.Device("cuda:0"))
  del t7
  return {'output': t1, 'flat_args': [a], 'flat_output': (t1,)}, ((), (i3, i4))

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  _, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_collection(cotangents)
  del cotangents
  i3, i4, = C1
  clear_collection(C1)
  del C1
  [t12] = nvFusion0(i3, i4, t2)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t12 = prims.mul(t0, t2)  # t12: "cuda:0 f32[2, 2]"
  del i3, i4, t2
  return (t12,)
```

I see … This is with … I happen to be playing with this recently, and my hack seems to have plumbed it through for you.
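For reference, a self-contained mock (plain PyTorch with hypothetical function names, not the generated trace above) of the plumbing the backward pass relies on: the runtime (seed, offset) pair travels through saved_for_backward so backward can regenerate the exact mask forward used.

```python
import torch

def augmented_forward(a, rng_state):
    # mirrors the trace's (i3, i4): runtime values, not compile-time ones
    seed, offset = rng_state
    gen = torch.Generator().manual_seed(seed + offset)  # crude stand-in for Philox state
    t0 = torch.rand(a.shape, generator=gen)
    return t0 * a, (seed, offset)  # runtime pair goes into saved_for_backward

def backward(saved_for_backward, grad_out):
    seed, offset = saved_for_backward
    gen = torch.Generator().manual_seed(seed + offset)
    t0 = torch.rand(grad_out.shape, generator=gen)  # same numbers as forward
    return t0 * grad_out

a = torch.ones(2, 2)
out, saved = augmented_forward(a, rng_state=(1234, 8))
grad = backward(saved, torch.ones_like(out))
assert torch.equal(out, grad)  # identical mask => consistent gradient
```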
I don't think my PR helps at all… since it apparently runs fine without it. 😆 Wondering what other cases you are looking at?
When I remove the tree_map line and run the dropout case, it has …
In the dropout case the NumberProxy is just a constant number (…
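A minimal sketch of that distinction (hypothetical `NumberProxy` stand-in with an invented `is_static` flag): a number fixed at compile time, like dropout's probability, can safely be concretized, while a runtime quantity like the RNG offset has to stay a proxy.

```python
from dataclasses import dataclass

@dataclass
class NumberProxy:
    name: str
    value: float
    is_static: bool  # invented flag: is the value fixed at compile time?

def concretize(p):
    # dropout's p is safe to bake in; the RNG offset must stay symbolic
    return p.value if p.is_static else p

print(concretize(NumberProxy("p", 0.5, is_static=True)))    # 0.5
print(concretize(NumberProxy("i4", 0.0, is_static=False)))  # NumberProxy(...)
```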
Linking issue #403
@kiya00, I think this issue was closed automatically. Is it really fixed, or should it be reopened?
Ah… forgot that we had this one open already. I'm linking it to #541 and I'll verify this when I close the other.
🐛 Bug
When I add an operator that returns numbers, the values in saved_for_backward are the compile-time values defined in `Symbol.meta`, not the real values computed at runtime. An example trace is like: …
To Reproduce
Reproduction is based on the branch `uniform_rng`.
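To make the compile-time/runtime split concrete, here is a hedged sketch (hypothetical names, not Thunder's actual symbol machinery): a meta function runs while tracing and can only supply placeholder numbers, while the impl runs at execution time and produces the real ones. Saving the meta's placeholders is exactly the bug described above.

```python
# Hypothetical sketch: meta vs. impl for an op that returns numbers.

def unpack_rng_state_meta(state):
    # trace/compile time: real seed/offset unknown, so placeholders are used
    return 0, 0

def unpack_rng_state_impl(state):
    # runtime: the actual values the backward pass needs
    return state["seed"], state["offset"]

state = {"seed": 1234, "offset": 8}
print("compile time:", unpack_rng_state_meta(state))   # (0, 0)
print("runtime:     ", unpack_rng_state_impl(state))   # (1234, 8)
```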