
Commit d39a969

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 3, 2024
1 parent b3a52f6 commit d39a969
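
pre-commit.ci runs the hooks declared in the repository's .pre-commit-config.yaml and pushes a commit like this one whenever a hook rewrites files. The changes below add a space after the leading "#" of commented-out code, the style pycodestyle reports as E265 ("block comment should start with '# '"). The repository's actual hook configuration is not part of this diff; the sketch below is only an illustration of what such a config can look like, and the specific hooks and revisions are assumptions, not the ones used here.

# Illustrative sketch only -- hooks and revs are assumptions,
# not the configuration actually used by this repository.
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace   # strips trailing spaces
      - id: end-of-file-fixer     # ensures files end with a single newline
  - repo: https://github.com/psf/black
    rev: 24.3.0
    hooks:
      - id: black                 # reformats Python code

Locally, the same fixes a CI bot would push can be reproduced by running pre-commit run --all-files before pushing.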
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 3 additions & 3 deletions thunder/core/transforms.py
@@ -1108,8 +1108,8 @@ def _sum_prim_grad(a: TensorProxy, /, dims: Sequence[int]) -> TensorProxy:
register_grad(pids.SUM, _sum_prim_grad)


-#@torchctx
-#def _topk_prim_grad(a: TensorProxy, /, k: int, dim: None | int = None, largest: bool = True, sorted: bool = True, *, out=None):
+# @torchctx
+# def _topk_prim_grad(a: TensorProxy, /, k: int, dim: None | int = None, largest: bool = True, sorted: bool = True, *, out=None):
# fwd = prims.topk(a, k, dim, largest, sorted, out=out)
# val, idx = fwd
#
@@ -1124,7 +1124,7 @@ def _sum_prim_grad(a: TensorProxy, /, dims: Sequence[int]) -> TensorProxy:
# return fwd
#
#
-#register_grad(pids.TOPK, _topk_prim_grad)
+# register_grad(pids.TOPK, _topk_prim_grad)


# TODO Fix division by zero when n_elem_reduced == 0 or when mean.numel == 0
4 changes: 3 additions & 1 deletion thunder/tests/opinfos.py
@@ -4771,13 +4771,15 @@ def topk_thunder_ref(*args, **kwargs):

def topk_torch_ref(*args, **kwargs):
return torch.topk(*args, **kwargs)[0]


# }


topk_opinfo = OpInfo(
topk_thunder_ref,
name="topk",
-#supports_grad=True,
+# supports_grad=True,
# Without the fixed seed this generator does not guarantee
# to produce inputs at which topk is differentiable
# (i.e. when topk(x, ...).indices == topk(x + dx, ...).indices).
