diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 5f10367cdf311..8837c4cc13985 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1380,6 +1380,11 @@ def forward(self):
     # NB: This also addresses a triton limitation where if the kernels are
     # getting called indirectly, triton cannot find the kernels unless they
     # are at top level.
+    # Define constants here for the same triton limitation
+    CONSTANT_C = 4
+    STRING_CONSTANT_C = "CONSTANT_C"
+    BOOL_CONSTANT_C = True
+
     @triton.jit
     def add_kernel(
         in_ptr0,
@@ -1931,6 +1936,57 @@ def call_triton(
         self.assertEqual(float_result, result)
         self.assertEqual(int_result, resulti)
 
+    @requires_cuda()
+    @requires_triton()
+    def test_triton_kernel_constants(self):
+        @triton.jit
+        def mulC_kernel(
+            in_ptr0,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+            CONSTANT_NAME: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            if CONSTANT_NAME.value == STRING_CONSTANT_C:
+                output = CONSTANT_C * x
+            if BOOL_CONSTANT_C:
+                output *= CONSTANT_C
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        def call_triton(
+            x: torch.Tensor,
+        ):
+            output = torch.zeros_like(x)
+            n_elements = output.numel()
+
+            grid = (x.numel(),)
+            mulC_kernel[grid](
+                x, output, n_elements, BLOCK_SIZE=16, CONSTANT_NAME="CONSTANT_C"
+            )
+            return output
+
+        # Triton kernels capture global constants by their parse-time value,
+        # not their runtime value
+        global CONSTANT_C
+        prev_c = CONSTANT_C
+        # If the behavior of triton kernels changes, this test will fail
+        CONSTANT_C = 10
+        assert CONSTANT_C != prev_c
+
+        t = torch.randn(5, device="cuda")
+        torch_result = call_triton(t)
+        compiled_result = torch.compile(call_triton)(t)
+
+        self.assertEqual(torch_result, compiled_result)
+
+        # reset back
+        CONSTANT_C = prev_c
+
     @requires_cuda()
     @requires_triton()
     @common_utils.parametrize("grad", [False, True])
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 8cadb4f87f144..99dc56cef2179 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -857,6 +857,10 @@ def traverse(cur_kernel):
                     compile_wrapper.splice(symbol.src, strip=True)
                     symbols_included.add(symbol_name)
                     traverse(symbol)
+                elif isinstance(symbol, (int, str, bool)):
+                    compile_wrapper.newline()
+                    compile_wrapper.writeline(f"{symbol_name} = {symbol!r}")
+                    symbols_included.add(symbol_name)
 
         traverse(kernel)
 
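For readers skimming the wrapper.py hunk, the sketch below restates what the modified traversal does in a self-contained form. It is a minimal illustration only: the helper name `collect_kernel_symbols`, the `emitted_lines` list, and the duck-typed check standing in for the real `isinstance(symbol, JITFunction)` test are assumptions of this sketch, not names from the PR.

```python
# Minimal sketch (assumed names, not the PR's code) of the symbol walk the
# inductor wrapper codegen performs when inlining a user-defined Triton kernel.
def collect_kernel_symbols(jit_fn, emitted_lines, seen=None):
    # `jit_fn.fn` is the underlying Python function; its code object lists
    # every global name the kernel body refers to.
    seen = set() if seen is None else seen
    for name in jit_fn.fn.__code__.co_names:
        if name in seen:
            continue
        value = jit_fn.fn.__globals__.get(name)
        # The real code checks isinstance(value, triton.JITFunction); duck
        # typing on `.fn`/`.src` keeps this sketch import-free.
        if hasattr(value, "fn") and hasattr(value, "src"):
            emitted_lines.append("@triton.jit")
            emitted_lines.append(value.src)
            seen.add(name)
            collect_kernel_symbols(value, emitted_lines, seen)
        elif isinstance(value, (int, str, bool)):
            # New in this PR: write the constant's current value by repr, so
            # e.g. "CONSTANT_C = 10" appears directly in the generated source.
            emitted_lines.append(f"{name} = {value!r}")
            seen.add(name)
```

With the new `elif`, a kernel like `mulC_kernel` in the test above would roughly get `CONSTANT_C = 10`, `STRING_CONSTANT_C = 'CONSTANT_C'`, and `BOOL_CONSTANT_C = True` emitted ahead of its spliced source, so the generated code no longer depends on the defining module's globals at runtime.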