value_and_grad returns None gradients with thunder.jit #211

Open
kshitij12345 opened this issue Apr 17, 2024 · 3 comments
Labels: bug (Something isn't working), jit, transforms

Comments

@kshitij12345
Collaborator

Note that it works with the deprecated thunder.compile.

import torch
from thunder.core.transforms import value_and_grad
import thunder

def model(x, w1):
    return x + w1

inp = torch.randn(1, 1)  # doesn't matter if requires_grad is True/False
w1 = torch.randn(1, 1)  # doesn't matter if requires_grad is True/False

print(thunder.compile(value_and_grad(model), disable_preprocessing=True)(inp, w1))
print(thunder.jit(value_and_grad(model))(inp, w1))

Output:

(tensor([[0.4349]]), (tensor([[1.]]), tensor([[1.]])))
(tensor([[0.4349]]), (None, None))
kshitij12345 added the bug (Something isn't working) and transforms labels on Apr 17, 2024
@kshitij12345
Collaborator Author

cc: @IvanYashchuk

kshitij12345 self-assigned this on Apr 24, 2024
@kshitij12345
Collaborator Author

Grabbing for investigation

@kshitij12345
Collaborator Author

In the function is_constant_for_vjp, the following line incorrectly evaluates to True:

are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args)

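This happens because the isinstance lookaside in general_jit replaces the check on a TensorProxy with res = issubclass(torch.Tensor, ucls). Here ucls is (FloatProxy, TensorProxy), and torch.Tensor is a subclass of neither, so every isinstance call returns False and are_all_args_non_differentiable ends up True.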

def _general_jit_isinstance_lookaside(obj: Any, cls: type | UnionType | tuple[type | UnionType]):
    uobj = unwrap(obj)
    ucls = unwrap(cls)
    if isinstance(uobj, TensorProxy):
        res = issubclass(torch.Tensor, ucls)
    else:
        res = isinstance(uobj, ucls)
    pr = ProvenanceRecord(
        PseudoInst.LOOKASIDE, inputs=[wrap_const(isinstance).provenance, obj.provenance, cls.provenance]
    )
    return wrap(res, provenance=pr)
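
For readers without a thunder checkout, the interaction described above can be sketched in isolation. This is a minimal illustration, not thunder's code: Tensor, TensorProxy and FloatProxy are empty stand-in classes, lookaside_isinstance is a hypothetical simplification of the lookaside with the wrapping and provenance logic dropped, and it assumes (as the analysis above implies) that TensorProxy does not inherit from torch.Tensor.

```python
# Stand-in classes (assumptions for illustration; not the real torch/thunder types).
class Tensor:
    """Stand-in for torch.Tensor."""


class TensorProxy:
    """Stand-in for thunder's TensorProxy; deliberately not a Tensor subclass."""


class FloatProxy:
    """Stand-in for thunder's FloatProxy."""


def lookaside_isinstance(obj, cls):
    """Hypothetical simplification of the isinstance lookaside (no wrap/provenance)."""
    if isinstance(obj, TensorProxy):
        # Proxies are made to look like tensors: the answer depends only on
        # whether Tensor matches cls, not on the proxy's own class.
        return issubclass(Tensor, cls)
    return isinstance(obj, cls)


flat_args = [TensorProxy(), TensorProxy()]
proxy_types = (FloatProxy, TensorProxy)

# What the check in is_constant_for_vjp is meant to compute (plain isinstance).
expected = not any(isinstance(arg, proxy_types) for arg in flat_args)

# What it computes when isinstance is routed through the lookaside.
observed = not any(lookaside_isinstance(arg, proxy_types) for arg in flat_args)

# The lookaside's intended effect: a proxy passes an isinstance(x, Tensor) check.
treated_as_tensor = lookaside_isinstance(flat_args[0], Tensor)

print(expected, observed, treated_as_tensor)
```

In this sketch the lookaside makes a proxy pass a check against Tensor, but the same rewrite makes a check against the proxy classes themselves fail, which is the case is_constant_for_vjp relies on.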

A potential fix:

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 9ccce83..21f77b5 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -709,7 +709,7 @@ def _general_jit_isinstance_lookaside(obj: Any, cls: type | UnionType | tuple[ty
     uobj = unwrap(obj)
     ucls = unwrap(cls)
     if isinstance(uobj, TensorProxy):
-        res = issubclass(torch.Tensor, ucls)
+        res = issubclass(torch.Tensor, ucls) or isinstance(uobj, ucls)
     else:
         res = isinstance(uobj, ucls)

However, this leads to a different error:

  File "/home/kkalambarkar/lightning-thunder/thunder/core/transforms.py", line 3662, in _vjp
    result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/interpreter.py", line 6179, in partial_call_impl
    return partial_function.func(*(partial_function.args + args), **(partial_function.keywords | kwargs))
  File "/home/kkalambarkar/lightning-thunder/thunder/core/transforms.py", line 3636, in vjp_call_metafunc
    result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)
TypeError: expected 1 argument, got 2
