From bd71e1b7c91eb6660c539f3ec97e2a27e9bc0098 Mon Sep 17 00:00:00 2001 From: eellison Date: Wed, 21 May 2025 15:35:00 -0700 Subject: [PATCH] update mutation renames (#153895) Thanks to @PaulZhang12 for original find. When we finalize a multi template buffer, we need to reflect mutation renaming in dependencies. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153895 Approved by: https://github.com/PaulZhang12 (cherry picked from commit 35ddad284d350681073aa06a33904d1f31e6fff0) --- test/inductor/test_max_autotune.py | 49 ++++++++++++++++++++++++++++++ torch/_inductor/scheduler.py | 27 ++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 3be2e6983ba3f..bb38620cb1f4e 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1,6 +1,9 @@ # Owner(s): ["module: inductor"] import contextlib +import functools +import inspect import json +import logging import math import os import tempfile @@ -34,6 +37,7 @@ parametrize, TEST_WITH_ROCM, ) +from torch.testing._internal.logging_utils import multiple_logs_to_string from torch.utils._triton import has_triton_tma_device @@ -928,6 +932,51 @@ def f(x, y): f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f) self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25) + @config.patch("trace.enabled", True) + @config.patch({"test_configs.force_extern_kernel_in_multi_template": True}) + def test_mutation_rename(self): + torch._logging.set_logs(ir_post_fusion=True) + + def f(x, y, z, other): + mul = x * y + diag = torch.diagonal(mul) + diag.copy_(other) + x = torch.mm(mul, z) + y = torch.diagonal(x).add_(torch.tensor(1, device="cuda")) + return y + + t = functools.partial(torch.randn, device="cuda") + inps = (t(3, 3), t(3, 3), t(3, 3), t(3)) + fn = torch.compile(f, mode="max-autotune-no-cudagraphs") + ( + pre_fusion_tream, + post_fusion_stream, + ), ctx = multiple_logs_to_string( + "torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion" + ) + + with config.patch({"trace.debug_dir": tempfile.mkdtemp()}): + with self.assertLogs( + logging.getLogger("torch._inductor.debug"), level=logging.INFO + ) as cm, ctx(): + out = fn(*inps) + + self.assertEqual(f(*inps), out) + + pre_fusion_stream = cm.output[0] + post_fusion_stream = cm.output[1] + + # before and after finalizing multi template buffer, deps should have the same normalization + # wrt writes + FileCheck().check("MultiTemplateBuffer").check("unmet").check_same("buf1").run( + pre_fusion_stream + ) + FileCheck().check("ExternKernelSchedulerNode").check("unmet").check_same( + "buf1" + ).run(post_fusion_stream) + + torch._logging.set_logs() + @config.patch({"test_configs.force_extern_kernel_in_multi_template": True}) def test_cat_max_autotune_extern(self): self._test_cat_max_autotune_impl(using_triton_mm=False) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 349e400b77493..27858e076a896 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2635,6 +2635,15 @@ def benchmark_codegened_module( return backend.benchmark_codegened_module(module) def finalize_multi_template_buffers(self) -> None: + """ + Finalize a backing choice for MultiTemplateBuffers which did not already have a + choice finalized through fusion. In the case of an extern choice, this will result + in replacing the SchedulerNode. + + If a MultiTemplateBuffer did not have any fusion opportunities, finalizing a choie + will force completion of compilation and benchmarking. + """ + def replace_operation_buffer( orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer ) -> None: @@ -2702,6 +2711,24 @@ def replace_operation_buffer( self.name_to_node[node.get_name()] = new_scheduler_node self.name_to_fused_node[node.get_name()] = new_scheduler_node + # We need to reflect the mutation renames that were recorded in the original node + mutation_renames = {} + for dep in itertools.chain( + node.read_writes.reads, node.unmet_dependencies + ): + if real_name := self.mutation_real_name.get(dep.name, None): + mutation_renames[real_name] = dep.name + + def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]: + return OrderedSet(dep.rename(mutation_renames) for dep in deps) + + new_scheduler_node.unmet_dependencies = rename_deps( + new_scheduler_node.unmet_dependencies + ) + new_scheduler_node.read_writes.reads = rename_deps( + new_scheduler_node.read_writes.reads + ) + for new_out, old_out in zip( new_scheduler_node.get_outputs(), node.get_outputs() ):