Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions test/inductor/test_max_autotune.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Owner(s): ["module: inductor"]
import contextlib
import functools
import inspect
import json
import logging
import math
import os
import tempfile
Expand Down Expand Up @@ -34,6 +37,7 @@
parametrize,
TEST_WITH_ROCM,
)
from torch.testing._internal.logging_utils import multiple_logs_to_string
from torch.utils._triton import has_triton_tma_device


Expand Down Expand Up @@ -928,6 +932,51 @@ def f(x, y):
f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f)
self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25)

@config.patch("trace.enabled", True)
@config.patch({"test_configs.force_extern_kernel_in_multi_template": True})
def test_mutation_rename(self):
    """
    Compile a function that mutates a matmul input through a diagonal view
    and inspect the scheduler IR logs emitted on ``torch._inductor.debug``.

    The first captured log record must list ``buf1`` on the ``unmet`` line
    following ``MultiTemplateBuffer``; the second must list ``buf1`` on the
    ``unmet`` line following ``ExternKernelSchedulerNode``.
    """
    torch._logging.set_logs(ir_post_fusion=True)
    # Restore the default logging configuration even when compilation or an
    # assertion fails, so the setting does not leak into other tests.
    try:

        def f(x, y, z, other):
            mul = x * y
            diag = torch.diagonal(mul)
            # In-place write through a view of ``mul`` before it feeds the mm.
            diag.copy_(other)
            x = torch.mm(mul, z)
            y = torch.diagonal(x).add_(torch.tensor(1, device="cuda"))
            return y

        t = functools.partial(torch.randn, device="cuda")
        inps = (t(3, 3), t(3, 3), t(3, 3), t(3))
        fn = torch.compile(f, mode="max-autotune-no-cudagraphs")

        # The returned streams are not read here; the assertions below use
        # the records captured by ``assertLogs``. Only ``ctx`` is needed.
        (
            _pre_fusion_stream,
            _post_fusion_stream,
        ), ctx = multiple_logs_to_string(
            "torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion"
        )

        with config.patch({"trace.debug_dir": tempfile.mkdtemp()}):
            with self.assertLogs(
                logging.getLogger("torch._inductor.debug"), level=logging.INFO
            ) as cm, ctx():
                out = fn(*inps)

        self.assertEqual(f(*inps), out)

        pre_fusion_stream = cm.output[0]
        post_fusion_stream = cm.output[1]

        # Before and after finalizing the multi template buffer, deps should
        # have the same normalization with respect to writes.
        FileCheck().check("MultiTemplateBuffer").check("unmet").check_same(
            "buf1"
        ).run(pre_fusion_stream)
        FileCheck().check("ExternKernelSchedulerNode").check("unmet").check_same(
            "buf1"
        ).run(post_fusion_stream)
    finally:
        torch._logging.set_logs()

@config.patch({"test_configs.force_extern_kernel_in_multi_template": True})
def test_cat_max_autotune_extern(self):
    """Run the shared cat max-autotune check with ``using_triton_mm=False``."""
    use_triton_mm = False
    self._test_cat_max_autotune_impl(using_triton_mm=use_triton_mm)
Expand Down
27 changes: 27 additions & 0 deletions torch/_inductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2635,6 +2635,15 @@ def benchmark_codegened_module(
return backend.benchmark_codegened_module(module)

def finalize_multi_template_buffers(self) -> None:
"""
Finalize a backing choice for MultiTemplateBuffers which did not already have a
choice finalized through fusion. In the case of an extern choice, this will result
in replacing the SchedulerNode.

If a MultiTemplateBuffer did not have any fusion opportunities, finalizing a choice
will force completion of compilation and benchmarking.
"""

def replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
Expand Down Expand Up @@ -2702,6 +2711,24 @@ def replace_operation_buffer(
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node

# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(
node.read_writes.reads, node.unmet_dependencies
):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name

def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)

new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)

for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
Expand Down