use torch.get_default_dtype and torch.get_default_device for factory methods in thunder/torch/__init__.py #621
Comments
triage review: added high-priority since it's needed to support transformers
I dug a bit to understand where this change might go. It looks to me that we could actually put the default device/dtype information in the cache info; however, we would still have the problem of what happens when the user changes the default device/dtype inside the compiled function, like:

```python
import thunder
import torch

# default dtype here is torch.float32
def foo():
    a = torch.ones((1,), device="cuda")
    torch.set_default_dtype(torch.float64)
    b = torch.ones((1,), device="cuda")
    return a, b

jfoo = thunder.jit(foo)
a, b = jfoo()
assert a.dtype != b.dtype
```

This work might also require some better standardization of these tensor-creating operations, since each of them currently behaves slightly differently. For example, the snippet above should pass in torch, or at least produce …
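As a rough sketch of the first idea, the call-site defaults could be folded into the cache key so a change of the global default triggers a re-trace. The names here (`make_cache_key`) are hypothetical, not existing thunder internals, and this deliberately does not address the in-function `set_default_dtype` case above:

```python
import torch

# Hypothetical sketch: snapshot the global defaults at the call site so they
# become part of the cache key / guards for the compiled function.
def make_cache_key(args, kwargs):
    return (
        tuple(type(a) for a in args),
        tuple(sorted(kwargs)),
        torch.get_default_dtype(),        # guard on the default dtype
        str(torch.get_default_device()),  # guard on the default device
    )

# A later call with a different default dtype/device produces a different
# key, so the function would be re-traced instead of reusing a stale trace.
```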
Is the problem here that we are worried about reordering the default-dtype/device change relative to the factory calls inside the trace? In terms of caching, the cache information only guards things at the call site of the compiled function. The user should be able to update the default device/dtype inside the program; that's the same behavior as eager. If the end state is the same, the follow-up iteration would be a cache hit as well.
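To make that concrete, here is a sketch of how the `foo`/`jfoo` example above would behave under such call-site guards (this describes the proposed semantics, not thunder's current behavior):

```python
# Continuing the foo/jfoo example above (sketch of intended semantics).
# First call: the guard records default dtype float32 at the call site;
# inside the function the default is changed to float64 and stays that way.
a, b = jfoo()

# Second call: the call-site default is now float64, so the guard misses
# once and the function is re-traced under the new default.
a2, b2 = jfoo()

# Later calls: the default is still float64 and the program leaves the end
# state unchanged, so these calls are cache hits.
a3, b3 = jfoo()
```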
So the dtype part of this works now that @kshitij12345 fixed #750; we should likely also do the device part?
🐛 Bug
thunder produces output with a different dtype inside the compiled function.
To Reproduce
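A minimal sketch of the kind of reproduction described in this issue (the exact original snippet may have differed), assuming the default dtype is changed before calling a factory op such as `torch.full`:

```python
import torch
import thunder

torch.set_default_dtype(torch.float64)

def foo():
    # No explicit dtype: should follow torch.get_default_dtype() (float64).
    return torch.full((2, 2), 1.0)

eager_out = foo()
jfoo = thunder.jit(foo)
thunder_out = jfoo()

print(eager_out.dtype)    # torch.float64 in eager
print(thunder_out.dtype)  # reported to come out with a different dtype
```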
Pitch
thunder/torch/__init__.py isn't properly pulling torch's default dtype/device in ops like `torch.full` and `torch.empty`, resulting in wrong behavior.
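A sketch of the kind of fix the issue title suggests, i.e. having the factory-op wrappers fall back to the global defaults; the helper names (`_resolve_dtype`, `_resolve_device`) are illustrative, not thunder's actual internals:

```python
import torch

# Illustrative helpers: fall back to torch's global defaults when the caller
# did not pass an explicit dtype/device to a factory op like full/empty.
def _resolve_dtype(dtype):
    return dtype if dtype is not None else torch.get_default_dtype()

def _resolve_device(device):
    if device is not None:
        return torch.device(device)
    return torch.get_default_device()  # available in recent PyTorch versions

def full(shape, fill_value, *, dtype=None, device=None):
    dtype = _resolve_dtype(dtype)
    device = _resolve_device(device)
    # In thunder this would then be lowered to the corresponding prim with
    # the resolved (rather than hard-coded) dtype and device; torch.full is
    # used here only as a stand-in.
    return torch.full(shape, fill_value, dtype=dtype, device=device)
```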
The trace from the reproduction above shows that the lowered `prims.full` is executed by nvfuser with an explicit dtype that differs from the one the vanilla function would use.

cc @apaz-cli
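For reference, the trace for the jitted function can be inspected with something like the following, continuing the reproduction snippet above (assuming thunder's `last_traces` helper; the exact output format may vary between versions):

```python
# Print the final trace thunder produced; the prims.full call in it carries
# an explicit dtype argument, which is where the mismatch with the eager
# default shows up.
trace = thunder.last_traces(jfoo)[-1]
print(trace)
```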