value_and_grad returns None gradients with thunder.jit #211
cc: @IvanYashchuk
Grabbing for investigation
In function lightning-thunder/thunder/core/transforms.py Line 3338 in 2578766
This happens because the isinstance lookaside in general_jit changes the check for lightning-thunder/thunder/core/jit_ext.py Lines 708 to 720 in 2578766
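A minimal, self-contained sketch of the behavior described here, assuming the failing check has the form `isinstance(x, TensorProxy)`. The `Tensor` and `TensorProxy` classes below are hypothetical stand-ins (not the real `torch.Tensor` or thunder classes), the `unwrap` step is omitted, and the two functions only mirror the branch shown in the diff in this comment; they are not thunder's actual implementation.

```python
class Tensor:
    """Hypothetical stand-in for torch.Tensor."""


class TensorProxy:
    """Hypothetical stand-in for thunder's TensorProxy (not a Tensor subclass)."""


def lookaside_original(uobj, ucls):
    # Mirrors the pre-patch branch: a proxy is treated as if it were a Tensor,
    # so the answer depends only on whether Tensor is a subclass of ucls.
    if isinstance(uobj, TensorProxy):
        return issubclass(Tensor, ucls)
    return isinstance(uobj, ucls)


def lookaside_patched(uobj, ucls):
    # Mirrors the proposed branch: additionally accept the proxy's own class.
    if isinstance(uobj, TensorProxy):
        return issubclass(Tensor, ucls) or isinstance(uobj, ucls)
    return isinstance(uobj, ucls)


proxy = TensorProxy()

# (is it a Tensor?, is it a TensorProxy?)
original_results = (
    lookaside_original(proxy, Tensor),
    lookaside_original(proxy, TensorProxy),
)
patched_results = (
    lookaside_patched(proxy, Tensor),
    lookaside_patched(proxy, TensorProxy),
)

print(original_results)  # (True, False)
print(patched_results)   # (True, True)
```

Under these assumptions, the original branch answers "is this a `TensorProxy`?" with `False` even for a proxy object, which is the kind of flipped check the comment points at.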
Applying this potential fix:
diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 9ccce83..21f77b5 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -709,7 +709,7 @@ def _general_jit_isinstance_lookaside(obj: Any, cls: type | UnionType | tuple[ty
     uobj = unwrap(obj)
     ucls = unwrap(cls)
     if isinstance(uobj, TensorProxy):
-        res = issubclass(torch.Tensor, ucls)
+        res = issubclass(torch.Tensor, ucls) or isinstance(uobj, ucls)
     else:
         res = isinstance(uobj, ucls)
leads to a different error:
Note that it works with the deprecated thunder.compile.
Output: