[Inductor] Support top level constants in user defined triton kernels (pytorch#111970)

Pull Request resolved: pytorch#111970
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#111956
oulgen authored and andreigh committed Oct 26, 2023
1 parent c521771 commit a86ef76
Showing 2 changed files with 60 additions and 0 deletions.
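For context, the user-facing pattern this change supports looks roughly like the sketch below: a @triton.jit kernel that reads module-level constants and is launched from a function run through torch.compile. This is an illustrative sketch rather than code from the commit; the names (SCALE, USE_SCALE, scale_kernel, scale) are hypothetical, and it assumes a CUDA device with triton installed and a PyTorch build that contains this change.

import torch
import triton
import triton.language as tl

# Module-level constants referenced from inside the kernel.
SCALE = 3
USE_SCALE = True


@triton.jit
def scale_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    # SCALE and USE_SCALE are ordinary globals, not kernel arguments; with this
    # change inductor inlines their values into the kernel source it generates.
    output = SCALE * x
    if USE_SCALE:
        output = output + x
    tl.store(out_ptr + offsets, output, mask=mask)


def scale(x: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    scale_kernel[grid](x, output, n_elements, BLOCK_SIZE=1024)
    return output


x = torch.randn(4096, device="cuda")
torch.testing.assert_close(torch.compile(scale)(x), scale(x))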
56 changes: 56 additions & 0 deletions test/dynamo/test_functions.py
@@ -1380,6 +1380,11 @@ def forward(self):
# NB: This also addresses a triton limitation where if the kernels are
# getting called indirectly, triton cannot find the kernels unless they
# are at top level.
# Define constants here for the same triton limitation
CONSTANT_C = 4
STRING_CONSTANT_C = "CONSTANT_C"
BOOL_CONSTANT_C = True

@triton.jit
def add_kernel(
    in_ptr0,
@@ -1931,6 +1936,57 @@ def call_triton(
        self.assertEqual(float_result, result)
        self.assertEqual(int_result, resulti)

    @requires_cuda()
    @requires_triton()
    def test_triton_kernel_constants(self):
        @triton.jit
        def mulC_kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
            CONSTANT_NAME: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            if CONSTANT_NAME.value == STRING_CONSTANT_C:
                output = CONSTANT_C * x
            if BOOL_CONSTANT_C:
                output *= CONSTANT_C
            tl.store(out_ptr + offsets, output, mask=mask)

        def call_triton(
            x: torch.Tensor,
        ):
            output = torch.zeros_like(x)
            n_elements = output.numel()

            grid = (x.numel(),)
            mulC_kernel[grid](
                x, output, n_elements, BLOCK_SIZE=16, CONSTANT_NAME="CONSTANT_C"
            )
            return output

        # Triton kernels capture global constants by their parse-time value,
        # not their runtime value
        global CONSTANT_C
        prev_c = CONSTANT_C
        # If the behavior of triton kernels changes, this test will fail
        CONSTANT_C = 10
        assert CONSTANT_C != prev_c

        t = torch.randn(5, device="cuda")
        torch_result = call_triton(t)
        compiled_result = torch.compile(call_triton)(t)

        self.assertEqual(torch_result, compiled_result)

        # reset back
        CONSTANT_C = prev_c

    @requires_cuda()
    @requires_triton()
    @common_utils.parametrize("grad", [False, True])
4 changes: 4 additions & 0 deletions torch/_inductor/codegen/wrapper.py
@@ -857,6 +857,10 @@ def traverse(cur_kernel):
                        compile_wrapper.splice(symbol.src, strip=True)
                        symbols_included.add(symbol_name)
                        traverse(symbol)
                    elif isinstance(symbol, (int, str, bool)):
                        compile_wrapper.newline()
                        compile_wrapper.writeline(f"{symbol_name} = {symbol!r}")
                        symbols_included.add(symbol_name)

        traverse(kernel)

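With this change, when inductor splices a user-defined kernel's source into the wrapper code it generates, any module-level int, str, or bool the kernel references is re-declared as a top-level assignment using the repr of the value the global held at trace time, so the pasted definition is self-contained. The snippet below is an illustrative approximation of what that emitted source contains for the test kernel above; it is not verbatim inductor output, and the real wrapper also includes imports, heuristics decorators, and launch machinery, possibly in a different order.

import triton
import triton.language as tl

# Globals referenced by the kernel, re-emitted via f"{symbol_name} = {symbol!r}".
# CONSTANT_C is 10 here because that is the value the global held when the
# compiled call was traced in the test above.
CONSTANT_C = 10
STRING_CONSTANT_C = 'CONSTANT_C'
BOOL_CONSTANT_C = True


@triton.jit
def mulC_kernel(
    in_ptr0, out_ptr, n_elements,
    BLOCK_SIZE: "tl.constexpr", CONSTANT_NAME: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    if CONSTANT_NAME.value == STRING_CONSTANT_C:
        output = CONSTANT_C * x
    if BOOL_CONSTANT_C:
        output *= CONSTANT_C
    tl.store(out_ptr + offsets, output, mask=mask)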
