diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 7615f1ffc..e9017b8b8 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -5354,23 +5354,8 @@ def topk_error_generator(op, device, **kwargs):
     yield (SampleInput(make(3, 3), 1, -3), IndexError, err_msg)
 
 
-# Phantom grad tests do not handle tensor outputs
-# that do not require grad and/or do not have grad_fn.
-# Therefore we explicitly filter outputs.
-# See https://github.com/Lightning-AI/lightning-thunder/issues/119 {
-def topk_thunder_ref(*args, **kwargs):
-    return clang.topk(*args, **kwargs)[0]
-
-
-def topk_torch_ref(*args, **kwargs):
-    return torch.topk(*args, **kwargs)[0]
-
-
-# }
-
-
 topk_opinfo = OpInfo(
-    topk_thunder_ref,
+    clang.topk,
     name="topk",
     supports_grad=True,
     # Without the fixed seed this generator does not guarantee
@@ -5380,7 +5365,7 @@ def topk_torch_ref(*args, **kwargs):
     # fix the issue.
     sample_input_generator=topk_sample_generator,
     error_input_generator=topk_error_generator,
-    torch_reference=topk_torch_ref,
+    torch_reference=torch.topk,
     dtypes=(datatypes.signedinteger, datatypes.unsignedinteger, datatypes.floating),
     test_directives=(
         DecorateInfo(
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 0e40da319..c27d57381 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1184,9 +1184,43 @@ def snippet_phantom_grad_vs_torch_consistency(op, torch_op, sample, comp, singul
 
     args, kwargs = sample.args, sample.kwargs
 
+    def is_output_differentiable(x):
+        # grad_fn is set only if one of the inputs has `requires_grad=True`
+        # and the op is differentiable.
+        # Example:
+        # >>> x = torch.ones(3, requires_grad=True)
+        # >>> y = torch.ones(3, requires_grad=False)
+        # >>> (x + x).grad_fn  # <AddBackward0 object at ...>
+        # >>> (y + y).grad_fn  # None
+        # >>> (y + x).grad_fn  # <AddBackward0 object at ...>
+        # >>> (x < 1).grad_fn  # None (non-differentiable op)
+        # Op with differentiable and non-differentiable outputs.
+        # >>> torch.topk(x, k=2)
+        # torch.return_types.topk(
+        # values=tensor([1., 1.], grad_fn=<TopkBackward0>),
+        # indices=tensor([0, 1]))
+        # >>> torch.topk(torch.ones(3, requires_grad=False), k=2)
+        # torch.return_types.topk(
+        # values=tensor([1., 1.]),
+        # indices=tensor([0, 1]))
+        return x.grad_fn is not None
+
+    def filter_differentiable_outputs(outputs):
+        if isinstance(outputs, torch.Tensor):
+            # Otherwise `filter` below will
+            # iterate over the Tensor data.
+            outputs = [outputs]
+
+        return list(filter(is_output_differentiable, outputs))
+
     # Computes PyTorch (competition) result
     torch_flats, spec = tree_flatten((args, kwargs))
     torch_result = torch_op(*args, **kwargs)
+    torch_result = filter_differentiable_outputs(torch_result)
+    if torch_result == []:
+        raise RuntimeError(
+            f"phantom_grad: Expected at least 1 differentiable output. If {op.name} is non-differentiable, set op.supports_grad=False."
+        )
 
     grads = []
     assert isinstance(torch_result, torch.Tensor) or isinstance(
@@ -1197,9 +1231,11 @@ def snippet_phantom_grad_vs_torch_consistency(op, torch_op, sample, comp, singul
             assert isinstance(
                 x, torch.Tensor
             ), "Expected a single torch tensor or a sequence of torch tensors when testing phantom grad torch consistency"
-            grads.append(torch.ones_like(x))
+            if is_output_differentiable(x):
+                grads.append(torch.ones_like(x))
     else:
-        grads = [torch.ones_like(torch_result)]
+        if is_output_differentiable(torch_result):
+            grads = [torch.ones_like(torch_result)]
 
     torch_tensors_requiring_grad = tuple(f for f in torch_flats if isinstance(f, torch.Tensor) and f.requires_grad)
     torch_grad_result = torch.autograd.grad(torch_result, torch_tensors_requiring_grad, grads)
@@ -1219,6 +1255,7 @@ def upcast_tensors(x: Any) -> Any:
         f for f in reference_flats if isinstance(f, torch.Tensor) and f.requires_grad
    )
     reference_result = torch_op(*reference_args, **reference_kwargs)
+    reference_result = filter_differentiable_outputs(reference_result)
     reference_grad_result = torch.autograd.grad(reference_result, reference_tensors_requiring_grad, grads)
 
     # Computes thunder result
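
For readers skimming the patch, here is a minimal standalone sketch (not part of the patch) of the filtering behavior it introduces. The two helpers mirror those added to snippet_phantom_grad_vs_torch_consistency; the driver lines at the bottom are illustrative only. Only outputs that carry a grad_fn are kept, so non-differentiable outputs such as topk's integer indices never reach torch.autograd.grad.

import torch


def is_output_differentiable(x):
    # grad_fn is populated only when at least one input requires grad
    # and the op itself is differentiable.
    return x.grad_fn is not None


def filter_differentiable_outputs(outputs):
    if isinstance(outputs, torch.Tensor):
        # Wrap a lone tensor so we filter whole outputs, not tensor elements.
        outputs = [outputs]
    return list(filter(is_output_differentiable, outputs))


x = torch.ones(3, requires_grad=True)
values, indices = torch.topk(x, k=2)
print(filter_differentiable_outputs((values, indices)))  # keeps only `values`
print(filter_differentiable_outputs(x < 1))  # [] since comparisons are non-differentiable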