Update test_linear_nvfuser #357
@@ -34,7 +34,7 @@
     TorchExecutor,
 )
 from thunder.tests.make_tensor import make_tensor, make_tensor_like
-from thunder.tests.opinfos import opinfos, push_away_from_singularities, tensor_creation_ops, get_opinfo
+from thunder.tests.opinfos import linear_opinfo
 from looseversion import LooseVersion
@@ -854,32 +854,35 @@ def get_num_fusions(cfn):
 @instantiate(
-    dtypes=(thunder.float16, thunder.bfloat16), devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,)
+    dtypes=(thunder.float16, thunder.bfloat16),
+    devicetypes=(devices.DeviceType.CUDA,),
+    executors=(nvFuserExecutor,),
+    decorators=(
+        pytest.mark.skipif(nvfuser_version() < LooseVersion("0.2.3"), reason="Requires nvFuser version 0.2.3 or later"),
+    ),
 )
 def test_linear(executor, device: str, dtype: dtypes.dtype):

     def fn(a, b, bias=None):
         return torch.nn.functional.linear(a, b, bias)

     m, n, k = 128, 64, 32
     torch_dtype = ltorch.to_torch_dtype(dtype)
-    a = torch.randn((m, k), dtype=torch_dtype, device=device)
-    b = torch.randn((n, k), dtype=torch_dtype, device=device)
+    at_least_one_tested = False

Review thread on lines +869 to 870:

Comment: What does that mean?

Reply: There's a for loop over the sample inputs, and inside it a continue statement skips inputs that are not 2D. If someone modified the sample input generator for linear so that it no longer produced a 2D sample, this test would silently stop testing anything. With an assert that at least one sample input was tested, there would be a failure instead.

Comment: Or we could just create a single sample and append it to the cases?

Reply: Yes, it's a viable option.

(A sketch of the single-sample alternative appears after the diff below.)
-    for has_bias in [True, False]:
-        bias = None
-
-        if has_bias:
-            bias = torch.randn(n, dtype=torch_dtype, device=device)
+    for sample in linear_opinfo.sample_inputs(device, dtype):
+        # nvFuser doesn't support batched input yet
+        if sample.args[0].ndim != 2:
+            continue

         compiled_func = thunder.jit(fn, executors_list=executor.executors_list(), nv_enable_linear=True)

-        out = compiled_func(a, b, bias)
+        out = compiled_func(*sample.args)
         traces = thunder.last_traces(compiled_func)
         fusions = examine.get_fusions(traces[-1])
-        nv_version = nvfuser_version()

-        expected_fusions = 1 if nv_version >= LooseVersion("0.2.3") else 0
+        expected_fusions = 1

         assert len(fusions) == expected_fusions
-        torch.testing.assert_close(out, torch.nn.functional.linear(a, b, bias))
+        torch.testing.assert_close(out, torch.nn.functional.linear(*sample.args))
+        at_least_one_tested = True
+
+    assert at_least_one_tested
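Below is a minimal sketch, not part of this PR, of the single-sample alternative suggested in the review thread on lines +869 to 870: build one known-good 2D case by hand and append it to the generated samples, so the loop always tests something even if the opinfo generator changes. It assumes `SampleInput` is importable from thunder.tests.opinfos and stores its positional tensors in `.args`, and that `ltorch` is `thunder.torch`, as in the test file.

```python
# Sketch of the "single explicit sample" alternative (assumptions noted above).
import torch
import thunder.torch as ltorch  # assumed to be the `ltorch` alias used by the test file

from thunder.tests.opinfos import SampleInput, linear_opinfo  # SampleInput import is an assumption


def collect_linear_2d_samples(device: str, dtype):
    torch_dtype = ltorch.to_torch_dtype(dtype)

    # Keep only the generated samples nvFuser currently supports (2D input).
    cases = [s for s in linear_opinfo.sample_inputs(device, dtype) if s.args[0].ndim == 2]

    # Append one hand-built 2D case so the list is never empty, even if the
    # opinfo generator stops producing 2D samples.
    m, n, k = 128, 64, 32
    a = torch.randn((m, k), dtype=torch_dtype, device=device)
    b = torch.randn((n, k), dtype=torch_dtype, device=device)
    cases.append(SampleInput(a, b))
    return cases
```

With such a helper, the test body could iterate over `collect_linear_2d_samples(device, dtype)` and drop both the `continue` and the `at_least_one_tested` bookkeeping, since coverage of at least one case is guaranteed by construction.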
Review thread on the bf16/fp16 restriction:

Comment: I wasn't aware there's such a restriction. cc'ing @Priya2698: is this intended behavior? I thought we are just dispatching to aten now, so I don't see why it's rejecting full precision.

Reply: At the moment we do have that restriction, since we use fusedMultiplySum and MmaOp, which only accept bf16/fp16.

Reply: This restriction will be lifted with the new IR nodes, though.

Reply: Gotcha. Thanks @IvanYashchuk for catching this. @Priya2698, remember to do a version bump when that restriction is lifted, so we can guard the proper behavior with nvfuser_version here.
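As a hypothetical illustration of the version gating discussed above, not code from this PR: once nvFuser lifts the bf16/fp16-only restriction for linear and bumps its version, the tested dtypes could be widened behind a version check, mirroring the pytest.mark.skipif guard in the diff. The cutoff "0.2.999" and the helper name below are placeholders; the real release is unknown at review time.

```python
# Hypothetical dtype selection gated on the nvFuser version (placeholder cutoff).
import thunder
from looseversion import LooseVersion


def linear_test_dtypes(nv_version: LooseVersion) -> tuple:
    # bf16/fp16 are what the current fusedMultiplySum/MmaOp path accepts.
    dtypes = (thunder.float16, thunder.bfloat16)
    # Placeholder cutoff: replace with the actual release that lifts the restriction.
    if nv_version >= LooseVersion("0.2.999"):
        dtypes = dtypes + (thunder.float32,)
    return dtypes
```

The decorator's `dtypes=` argument could then be fed `linear_test_dtypes(nvfuser_version())`, keeping the skip and dtype-selection logic next to each other.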