Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 86 additions & 7 deletions test/inductor/test_cooperative_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch._inductor.codegen.triton import FixedTritonConfig, TritonKernel
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing import assert_close
from torch.testing._internal.common_cuda import IS_SM89
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
Expand Down Expand Up @@ -57,12 +58,90 @@ def setUp(self):
torch._inductor.metrics.generated_kernel_count = 0
torch._dynamo.reset()

def run_and_check(self, fn, args, *, expect_kernel_count=1):
args_cpu = [tensor.cpu().to(torch.float32) for tensor in args]
expected = fn(*args_cpu).to(torch.float16)
fn = torch.compile(fn, fullgraph=True)
result, (source_code,) = run_and_get_code(fn, *args)
self.assertEqual(result, expected)
def run_and_check(self, fn, args, dtype=None, *, expect_kernel_count=1):
# Define fixed tolerances
RTOL = 1e-5
ATOL = 1e-6

# calculate reference value in higher precision when input dtype is float16
ref_dtype = dtype
if dtype == torch.float16:
ref_dtype = torch.float64

# Cast to the determined reference dtype
args_ref = [tensor.to(ref_dtype) for tensor in args]

# Calculate expected output
raw_expected = fn(*args_ref)

if isinstance(raw_expected, (tuple, list)):
# If it's a tuple or list, apply .to(dtype) to each tensor within it
# Also, handle cases where dtype might not be provided (e.g., for bool reductions)
if dtype is not None:
expected = type(raw_expected)(
[
t.to(dtype) if isinstance(t, torch.Tensor) else t
for t in raw_expected
]
)
else:
expected = type(raw_expected)(
[
t.to(torch.float64) if isinstance(t, torch.Tensor) else t
for t in raw_expected
]
)
else:
# If it's a single tensor
if dtype is not None:
expected = raw_expected.to(dtype)
else:
expected = raw_expected.to(torch.float64)

fn_compiled = torch.compile(fn, fullgraph=True)
result, (source_code,) = run_and_get_code(fn_compiled, *args)

# For comparison, ensure result is also a tuple/list if expected is
if isinstance(expected, (tuple, list)):
if isinstance(result, torch.Tensor):
result = (result,)
elif not isinstance(result, type(expected)):
result = type(expected)(result)

if dtype is not None:
result = type(result)(
[t.to(dtype) if isinstance(t, torch.Tensor) else t for t in result]
)
else:
result = type(result)(
[
t.to(torch.float64) if isinstance(t, torch.Tensor) else t
for t in result
]
)
else:
if dtype is not None and isinstance(result, torch.Tensor):
result = result.to(dtype)
elif isinstance(result, torch.Tensor):
result = result.to(torch.float64)

# Apply assert_close with fixed tolerances for tensor comparisons
if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
assert_close(result, expected, rtol=RTOL, atol=ATOL)
elif isinstance(result, (tuple, list)) and isinstance(expected, (tuple, list)):
# Iterate through elements for comparison
for r_item, e_item in zip(result, expected):
if isinstance(r_item, torch.Tensor) and isinstance(
e_item, torch.Tensor
):
assert_close(r_item, e_item, rtol=RTOL, atol=ATOL)
else:
# Fallback to assertEqual for non-tensor elements (e.g., bool, int)
self.assertEqual(r_item, e_item)
else:
# Fallback to assertEqual for other types not handled by assert_close
self.assertEqual(result, expected)

if "@triton_heuristics.fixed_config" in source_code:
self.assertIn("cooperative_reduction_grid", source_code)
else:
Expand Down Expand Up @@ -98,7 +177,7 @@ def fn(x, y):

reduction_fn = getattr(torch, name)
args = [torch.randn(1, 1024**2, device="cuda", dtype=dtype) for _ in range(2)]
self.run_and_check(fn, args)
self.run_and_check(fn, args, dtype)

def test_bool_reduction_fns(self):
def fn(x, y):
Expand Down
2 changes: 1 addition & 1 deletion test/quantization/core/test_quantized_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class PointwisePostOp(NamedTuple):
def avoid_vpmaddubsw_overflow_linear(
batch_size, input_channels, output_channels, X, X_min, X_max, W, W_min, W_max
):
if sys.version_info >= (3, 13):
if np.lib.NumpyVersion(np.__version__) >= '2.1.0':
raise unittest.SkipTest("numpy 2.1 overflow error")
for i, j in np.ndindex((batch_size, output_channels)):
for k in range(0, input_channels // 2 * 2, 2):
Expand Down