diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index c04cc12b0..4040dba08 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -743,6 +743,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
 gelu = _register_torch_operation("gelu", module=torch.nn.functional)
 relu = _register_torch_operation("relu", module=torch.nn.functional)
 relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
+hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
 selu = _register_torch_operation("selu", module=torch.nn.functional)
 silu = _register_torch_operation("silu", module=torch.nn.functional)
 
@@ -754,6 +755,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
 _register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
+_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.silu, silu)
 
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 10e0f8e68..9bcfe15ea 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -1598,6 +1598,35 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs):
 elementwise_unary_ops.append(relu6_opinfo)
 
 
+def hardswish_error_generator(op, device, dtype=torch.float32, **kwargs):
+    a = make_tensor((), dtype=dtype, device=device)
+    yield (SampleInput(a, inplace=True), NotImplementedError, "hardswish only supports inplace=False")
+
+
+hardswish_opinfo = OpInfo(
+    ltorch.hardswish,
+    sample_input_generator=elementwise_unary_generator,
+    error_input_generator=hardswish_error_generator,
+    torch_reference=_elementwise_unary_torch(torch.nn.functional.hardswish),
+    dtypes=(datatypes.floating,),
+    test_directives=(
+        # PyTorch does not support CPU Half hardswish
+        DecorateInfo(
+            pytest.mark.xfail,
+            "test_core_vs_torch_consistency",
+            dtypes=(datatypes.float16,),
+            devicetypes=(devices.DeviceType.CPU,),
+        ),
+        # TODO: we might have a tolerance issue here with hardswish, a function of relu6
+        DecorateInfo(
+            pytest.mark.xfail(strict=False),
+            "test_vjp_correctness",
+        ),
+    ),
+)
+elementwise_unary_ops.append(hardswish_opinfo)
+
+
 def selu_error_generator(op, device, dtype=torch.float32, **kwargs):
     a = make_tensor((), dtype=dtype, device=device)
     yield (SampleInput(a, inplace=True), NotImplementedError, "selu only supports inplace=False")
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index a61832343..53bdd5769 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1208,6 +1208,17 @@ def relu6(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
     return clamp(a, 0, 6)
 
 
+@torchsymbol(torch.nn.functional.hardswish, id="torch.hardswish", is_method=False)
+def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
+    utils.check(not inplace, lambda: f"hardswish only supports inplace=False", exception_type=NotImplementedError)
+    utils.check(
+        dtypes.is_float_dtype(a.dtype),
+        lambda: f"hardswish only supports floating point dtypes, got {a.dtype}",
+        exception_type=ValueError,
+    )
+    return a * relu6(a + 3) / 6
+
+
 # id=torch.selu because we ignore inplace argument in torch.nn.functional.selu
 @torchsymbol(torch.selu, torch.nn.functional.selu, id="torch.selu", is_method=False)
 def selu(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
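Note (not part of the patch): the new thunder.torch.hardswish lowers to x * relu6(x + 3) / 6, which matches the piecewise definition PyTorch documents for torch.nn.functional.hardswish. A minimal sketch, assuming only a working PyTorch install, that checks this decomposition against the PyTorch reference on a random float tensor:

    import torch
    import torch.nn.functional as F

    # Decomposition used by the patch: hardswish(x) = x * relu6(x + 3) / 6
    x = torch.randn(1024)
    decomposed = x * F.relu6(x + 3) / 6
    assert torch.allclose(decomposed, F.hardswish(x))

The non-strict xfail on test_vjp_correctness reflects the TODO in the opinfo: because the backward is derived from the relu6-based decomposition rather than PyTorch's dedicated hardswish backward, small tolerance mismatches may occur.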