diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 73a1bb636..126c208fc 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -2209,6 +2209,12 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b if bias is not None and not is_supported_tensor(bias): return False + # nvFuser supports only fp16 and bf16 inputs. Only checking the first tensor + # dtype, as all tensors should have the same dtype which is checked by + # linear_meta. + if a.dtype not in (dtypes.float16, dtypes.bfloat16): + return False + # nvFuser only supports 2D inputs in v0.2.3. if not a.ndim == 2: return False diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 276ae8314..10e518054 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -34,7 +34,7 @@ TorchExecutor, ) from thunder.tests.make_tensor import make_tensor, make_tensor_like -from thunder.tests.opinfos import opinfos, push_away_from_singularities, tensor_creation_ops, get_opinfo +from thunder.tests.opinfos import linear_opinfo from looseversion import LooseVersion @@ -854,32 +854,35 @@ def get_num_fusions(cfn): @instantiate( - dtypes=(thunder.float16, thunder.bfloat16), devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,) + dtypes=(thunder.float16, thunder.bfloat16), + devicetypes=(devices.DeviceType.CUDA,), + executors=(nvFuserExecutor,), + decorators=( + pytest.mark.skipif(nvfuser_version() < LooseVersion("0.2.3"), reason="Requires nvFuser version 0.2.3 or later"), + ), ) def test_linear(executor, device: str, dtype: dtypes.dtype): def fn(a, b, bias=None): return torch.nn.functional.linear(a, b, bias) - m, n, k = 128, 64, 32 - torch_dtype = ltorch.to_torch_dtype(dtype) - a = torch.randn((m, k), dtype=torch_dtype, device=device) - b = torch.randn((n, k), dtype=torch_dtype, device=device) + at_least_one_tested = False - for has_bias in [True, False]: - bias = None - - if has_bias: - bias = torch.randn(n, dtype=torch_dtype, device=device) + for sample in linear_opinfo.sample_inputs(device, dtype): + # nvFuser doesn't support batched input yet + if sample.args[0].ndim != 2: + continue compiled_func = thunder.jit(fn, executors_list=executor.executors_list(), nv_enable_linear=True) - out = compiled_func(a, b, bias) + out = compiled_func(*sample.args) traces = thunder.last_traces(compiled_func) fusions = examine.get_fusions(traces[-1]) - nv_version = nvfuser_version() - expected_fusions = 1 if nv_version >= LooseVersion("0.2.3") else 0 + expected_fusions = 1 assert len(fusions) == expected_fusions - torch.testing.assert_close(out, torch.nn.functional.linear(a, b, bias)) + torch.testing.assert_close(out, torch.nn.functional.linear(*sample.args)) + at_least_one_tested = True + + assert at_least_one_tested