test_phantom_grad: support ops which return non-differentiable outputs
kshitij12345 committed May 23, 2024
1 parent c55798f commit b7154dc
Showing 2 changed files with 41 additions and 19 deletions.
19 changes: 2 additions & 17 deletions thunder/tests/opinfos.py
@@ -5354,23 +5354,8 @@ def topk_error_generator(op, device, **kwargs):
     yield (SampleInput(make(3, 3), 1, -3), IndexError, err_msg)
 
 
-# Phantom grad tests do not handle tensor outputs
-# that do not require grad and/or do not have grad_fn.
-# Therefore we explicitly filter outputs.
-# See https://github.com/Lightning-AI/lightning-thunder/issues/119 {
-def topk_thunder_ref(*args, **kwargs):
-    return clang.topk(*args, **kwargs)[0]
-
-
-def topk_torch_ref(*args, **kwargs):
-    return torch.topk(*args, **kwargs)[0]
-
-
-# }
-
-
 topk_opinfo = OpInfo(
-    topk_thunder_ref,
+    clang.topk,
     name="topk",
     supports_grad=True,
     # Without the fixed seed this generator does not guarantee
@@ -5380,7 +5365,7 @@ def topk_torch_ref(*args, **kwargs):
     # fix the issue.
     sample_input_generator=topk_sample_generator,
     error_input_generator=topk_error_generator,
-    torch_reference=topk_torch_ref,
+    torch_reference=torch.topk,
     dtypes=(datatypes.signedinteger, datatypes.unsignedinteger, datatypes.floating),
     test_directives=(
        DecorateInfo(
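For context on why the wrappers could be removed: torch.topk returns a differentiable `values` tensor alongside a non-differentiable integer `indices` tensor, and the deleted `topk_*_ref` helpers existed only to hide `indices` from the phantom-grad test. A minimal standalone sketch of what they papered over (illustration only, not part of the commit):

import torch

# What the deleted wrappers did: expose only the differentiable `values`
# output so the grad test never saw the integer `indices`.
def topk_torch_ref(*args, **kwargs):
    return torch.topk(*args, **kwargs)[0]

x = torch.ones(4, requires_grad=True)
values, indices = torch.topk(x, k=2)
print(values.grad_fn)   # <TopkBackward0 ...> -> differentiable
print(indices.grad_fn)  # None -> autograd cannot flow through it
print(topk_torch_ref(x, k=2).grad_fn)  # wrapper kept only `values`

With the filtering moved into the test harness (next file), the OpInfo can reference clang.topk and torch.topk directly.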
41 changes: 39 additions & 2 deletions thunder/tests/test_grad.py
@@ -1184,9 +1184,43 @@ def snippet_phantom_grad_vs_torch_consistency(op, torch_op, sample, comp, singul
 
     args, kwargs = sample.args, sample.kwargs
 
+    def is_output_differentiable(x):
+        # grad_fn is set only if one of the inputs has `requires_grad=True`
+        # and the op is differentiable.
+        # Example:
+        # >>> x = torch.ones(3, requires_grad=True)
+        # >>> y = torch.ones(3, requires_grad=False)
+        # >>> (x + x).grad_fn  # <AddBackward0 object at 0x7f0502edcf40>
+        # >>> (y + y).grad_fn  # None
+        # >>> (y + x).grad_fn  # <AddBackward0 object at 0x7f0502e21060>
+        # >>> (x < 1).grad_fn  # None (non-differentiable op)
+        # Op with differentiable and non-differentiable outputs:
+        # >>> torch.topk(x, k=2)
+        # torch.return_types.topk(
+        #     values=tensor([1., 1.], grad_fn=<TopkBackward0>),
+        #     indices=tensor([0, 1]))
+        # >>> torch.topk(torch.ones(3, requires_grad=False), k=2)
+        # torch.return_types.topk(
+        #     values=tensor([1., 1.]),
+        #     indices=tensor([0, 1]))
+        return x.grad_fn is not None
+
+    def filter_differentiable_outputs(outputs):
+        if isinstance(outputs, torch.Tensor):
+            # Otherwise `filter` below will
+            # iterate over the Tensor data.
+            outputs = [outputs]
+
+        return list(filter(is_output_differentiable, outputs))
+
     # Computes PyTorch (competition) result
     torch_flats, spec = tree_flatten((args, kwargs))
     torch_result = torch_op(*args, **kwargs)
+    torch_result = filter_differentiable_outputs(torch_result)
+    if torch_result == []:
+        raise RuntimeError(
+            f"phantom_grad: Expected at least 1 differentiable output. If {op.name} is non-differentiable, set op.supports_grad=False."
+        )
 
     grads = []
     assert isinstance(torch_result, torch.Tensor) or isinstance(
@@ -1197,9 +1231,11 @@ def snippet_phantom_grad_vs_torch_consistency(op, torch_op, sample, comp, singul
             assert isinstance(
                 x, torch.Tensor
             ), "Expected a single torch tensor or a sequence of torch tensors when testing phantom grad torch consistency"
-            grads.append(torch.ones_like(x))
+            if is_output_differentiable(x):
+                grads.append(torch.ones_like(x))
     else:
-        grads = [torch.ones_like(torch_result)]
+        if is_output_differentiable(torch_result):
+            grads = [torch.ones_like(torch_result)]
 
     torch_tensors_requiring_grad = tuple(f for f in torch_flats if isinstance(f, torch.Tensor) and f.requires_grad)
     torch_grad_result = torch.autograd.grad(torch_result, torch_tensors_requiring_grad, grads)
@@ -1219,6 +1255,7 @@ def upcast_tensors(x: Any) -> Any:
         f for f in reference_flats if isinstance(f, torch.Tensor) and f.requires_grad
     )
     reference_result = torch_op(*reference_args, **reference_kwargs)
+    reference_result = filter_differentiable_outputs(reference_result)
     reference_grad_result = torch.autograd.grad(reference_result, reference_tensors_requiring_grad, grads)
 
     # Computes thunder result
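Taken together, is_output_differentiable and filter_differentiable_outputs let the consistency check drive torch.autograd.grad with whatever subset of an op's outputs actually participates in autograd, and the new RuntimeError flags OpInfos whose op produces no differentiable output at all. A standalone sketch of that flow, using torch.topk and torch.argsort as stand-in ops (the helper names mirror the commit, but the snippet itself is an illustration, not part of it):

import torch

def is_output_differentiable(t):
    # Same check the commit adds: only tensors with a grad_fn can be
    # passed to torch.autograd.grad as outputs.
    return t.grad_fn is not None

def filter_differentiable_outputs(outputs):
    if isinstance(outputs, torch.Tensor):
        outputs = [outputs]
    return [t for t in outputs if is_output_differentiable(t)]

x = torch.arange(6.0, requires_grad=True)

# Mixed outputs: only `values` survives the filter, so autograd.grad
# receives matching outputs and grad_outputs.
outputs = filter_differentiable_outputs(torch.topk(x, k=3))
grads = [torch.ones_like(t) for t in outputs]
(dx,) = torch.autograd.grad(outputs, (x,), grads)
print(dx)  # 1.0 at the positions of the top-3 values, 0.0 elsewhere

# Fully non-differentiable op: the filter leaves nothing, which is the
# situation the new RuntimeError reports as a misconfigured OpInfo.
assert filter_differentiable_outputs(torch.argsort(x)) == []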
