From a7bf482ed2a348be28b1f27b6b993cf5f7b3d6d3 Mon Sep 17 00:00:00 2001
From: charlifu
Date: Fri, 11 Jul 2025 09:06:25 +0000
Subject: [PATCH] [Bugfix][Inductor] Fix dependency list merged incorrectly for
 a custom op with multiple mutated inputs and None return type. (#157133)

This is an attempt to fix a memory allocation issue when using `torch.compile`
with a custom layernorm kernel in vllm:

```C++
// In-place fused Add and RMS Normalization.
ops.def(
    "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
    "float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
```

We observed abnormal extra memory allocations under `torch.compile` with this
op enabled:

(screenshot: memory usage with the custom op enabled)

and without this op:

(screenshot: memory usage without the custom op)

After investigation, we found that this happens because the compiler decides
that the buffers of the two mutated inputs, `Tensor input` and
`Tensor residual`, should share the same dependency list, which prevents it
from reusing the buffer of `Tensor input`.

```
buf1.users = [
    NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
buf16.users = [
    NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
    NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
]
```

```
op13: ExternKernelSchedulerNode(FallbackKernel)
op13.writes = [
    StarDep(name='buf17', mode=None),
    StarDep(name='buf18', mode=None),
    StarDep(name='buf19', mode=None)]
op13.unmet_dependencies = [
    StarDep(name='buf13', mode=None),
    StarDep(name='buf16', mode=None),
    WeakDep(name='buf11', mutating_buf='buf18'),
    WeakDep(name='buf12', mutating_buf='buf18'),
    WeakDep(name='buf13', mutating_buf='buf18'),
    WeakDep(name='buf2', mutating_buf='buf18'),
    WeakDep(name='buf3', mutating_buf='buf18')]
op13.met_dependencies = [StarDep(name='arg11_1', mode=None)]
op13.outputs = [
    buf17: FallbackKernel
    buf17.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf17.aliases = ['buf16', 'buf1']
    buf17.users = [
        NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False),
    ]
    buf18: MutationOutput
    buf18.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf18.mutations = ['buf16']
    buf18.users = [
        NodeUser(node=ExternKernelSchedulerNode(name='op14'), can_inplace=False, is_weak=False),
        NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=True),
        NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=True),
    ]
    buf19: MutationOutput
    buf19.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0])
    buf19.mutations = ['buf1']
    buf19.users = [NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False)]
]
op13.node.kernel = torch.ops._C.fused_add_rms_norm.default
```

Here we can see that `buf16` shares the same dependency list as `buf1`,
because both `buf16` and `buf1` are in the aliases list of `buf17`. This is
incorrect, since the two are separate tensors, and it prevents the compiler
from reusing `buf16` for subsequent ops.
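For a concrete picture of the pattern that triggers this, here is a minimal,
self-contained sketch modeled on the regression test added below; the op name
`mylib::foo`, the toy CPU implementation, and the tensor shapes are
illustrative stand-ins rather than the actual vllm kernel:

```python
import torch

# A custom op that mutates two inputs in place and returns nothing,
# mirroring the fused_add_rms_norm schema above.
torch.library.define(
    "mylib::foo",
    "(Tensor! x, Tensor! y, Tensor z) -> ()",
    tags=torch.Tag.pt2_compliant_tag,
)

@torch.library.impl("mylib::foo", "cpu")
def foo_cpu(x, y, z):
    # Toy stand-in for a fused in-place kernel that writes to both x and y.
    x.add_(z)
    y.add_(z)

@torch.compile
def func(x, w):
    a = torch.empty_like(x)
    b = torch.empty_like(x)
    torch.ops.mylib.foo(a, b, x)  # mutates both a and b, returns None
    c = torch.mm(a, w)            # a is not needed after this point
    torch.ops.mylib.foo(c, b, x)
    return c

# Before this fix, the scheduler merged the user lists of the two mutated
# buffers (a and b here), which blocked reuse of a's buffer later on.
out = func(torch.rand(16, 16), torch.rand(16, 16))
```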
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157133
Approved by: https://github.com/jansel

(cherry picked from commit 02724b5f649b93ef7960962bdde7a667c0893d21)
---
 test/dynamo/test_logging.py              |  1 +
 test/inductor/test_auto_functionalize.py | 48 ++++++++++++++++++++++--
 torch/_inductor/scheduler.py             | 24 ++++++++++++
 torch/_logging/_internal.py              |  2 +
 torch/_logging/_registrations.py         |  1 +
 5 files changed, 72 insertions(+), 4 deletions(-)

diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py
index 2b120349ea01a..0ff58e49008cd 100644
--- a/test/dynamo/test_logging.py
+++ b/test/dynamo/test_logging.py
@@ -959,6 +959,7 @@ def bar():
             "autotuning",
             "graph_region_expansion",
             "hierarchical_compile",
+            "compute_dependencies",
         }
         for name in torch._logging._internal.log_registry.artifact_names:
             if name not in exclusions:
diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py
index 6f15b493ec1bd..0cc2c9e3a7836 100644
--- a/test/inductor/test_auto_functionalize.py
+++ b/test/inductor/test_auto_functionalize.py
@@ -445,12 +445,17 @@ def run_aot_eager(self, f, orig_args, _dynamic=False):
         graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
         return [aot_eager_args, result, graph]
 
-    def run_inductor(self, f, orig_args, _dynamic=False):
+    def run_inductor(
+        self,
+        f,
+        orig_args,
+        _dynamic=False,
+        log_module="torch._inductor.compile_fx",
+        log_function="post_grad_graphs",
+    ):
         compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
 
-        log_stream, ctx = logs_to_string(
-            "torch._inductor.compile_fx", "post_grad_graphs"
-        )
+        log_stream, ctx = logs_to_string(log_module, log_function)
         result = None
         with ctx():
             result = torch.compile(
@@ -1733,6 +1738,41 @@ def f(x, w):
         y = f(x, w)
         self.assertEqual(y, x.sin())
 
+    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
+    def test_scheduling_with_multiple_mutates(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor! x, Tensor! y, Tensor z) -> ()",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            @torch._dynamo.disable
+            def foo(x, y, z):
+                pass
+
+            def func(x, w):
+                a = torch.empty_like(x)  # buf0
+                b = torch.empty_like(x)  # buf1
+                torch.ops.mylib.foo(a, b, x)  # buf2, buf3, buf4
+                c = torch.mm(a, w)  # buf5
+                torch.ops.mylib.foo(c, b, x)  # buf6, buf7, buf8
+                return c
+
+            input = torch.rand(2, 2)
+            weight = torch.rand(2, 2)
+            [inductor_args, output, graph_inductor] = self.run_inductor(
+                func,
+                [input, weight],
+                False,
+                "torch._inductor.scheduler",
+                "compute_dependencies",
+            )
+            name_to_users = eval(graph_inductor)
+            self.assertNotEqual(name_to_users["buf1"], name_to_users["buf5"])
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 687ba95e1dd1d..f855cc1de922d 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -74,6 +74,9 @@
 log = logging.getLogger(__name__)
 fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
 loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering")
+compute_dependencies_log = torch._logging.getArtifactLogger(
+    __name__, "compute_dependencies"
+)
 
 PartitionType = list["BaseSchedulerNode"]
 
@@ -2278,6 +2281,15 @@ def __add__(self, other: DedupList[T]) -> DedupList[T]:
         for node in self.nodes:
             for buf1 in node.get_outputs():
                 buf1_name = buf1.get_name()
+                # Handle auto-functionalized ops that return None and mutate
+                # more than one input: don't let all of their buffers point to
+                # the same user list, since the buffers in the aliases list
+                # might not be aliases of each other.
+                if (
+                    isinstance(buf1.node.layout, ir.NoneLayout)
+                    and len(buf1.get_aliases()) > 1
+                ):
+                    continue
                 for buf2_name in buf1.get_aliases():
                     if buf1_name in name_to_users and buf2_name in name_to_users:
                         # merge the two
@@ -2445,6 +2457,18 @@ def add_user(
         for name in self.name_to_donated_buffer:
             self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
 
+        # For debug logging
+        logbuf = IndentedBuffer()
+        logbuf.splice("{")
+        for key, value in name_to_users.items():
+            with logbuf.indent():
+                users = [v.get_name() for v in value.items]
+                logbuf.splice(f"'{key}': {users},")
+        logbuf.splice("}")
+        str = logbuf.getrawvalue().rstrip()
+        compute_dependencies_log.debug("BUFFER USER LIST\n")
+        compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str)
+
     def dead_node_elimination(self) -> None:
         """
         Remove any nodes without users
diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py
index 3821218cefec9..f56f0165b206f 100644
--- a/torch/_logging/_internal.py
+++ b/torch/_logging/_internal.py
@@ -252,6 +252,7 @@ def set_logs(
     graph_region_expansion: bool = False,
     inductor_metrics: bool = False,
     hierarchical_compile: bool = False,
+    compute_dependencies: bool = False,
 ) -> None:
     """
     Sets the log level for individual components and toggles individual log
@@ -565,6 +566,7 @@ def _set_logs(**kwargs) -> None:
         graph_region_expansion=graph_region_expansion,
         inductor_metrics=inductor_metrics,
         hierarchical_compile=hierarchical_compile,
+        compute_dependencies=compute_dependencies,
     )
 
 
diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py
index 62e5d9b7064ca..3c6f092ed4d24 100644
--- a/torch/_logging/_registrations.py
+++ b/torch/_logging/_registrations.py
@@ -183,6 +183,7 @@
 )
 register_artifact("perf_hints", "", off_by_default=True)
 register_artifact("onnx_diagnostics", "", off_by_default=True)
+register_artifact("compute_dependencies", "", off_by_default=True)
 register_artifact(
     "fusion",
     "Detailed Inductor fusion decisions. More detailed than 'schedule'",