diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 279f65d9db8ff..59f88b33d1315 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -2247,21 +2247,43 @@ def topological_sort_schedule(
         name_to_node: Dict[str, BaseSchedulerNode] = dict()
         result: List[BaseSchedulerNode] = []
 
+        def has_mutations(node: BaseSchedulerNode) -> bool:
+            return any(buf.get_mutations() for buf in node.get_outputs())
+
         def visit(n: BaseSchedulerNode) -> None:
             if n not in seen:
                 seen.add(n)
+
+                # Visit regular dependencies
                 for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
                     # We only care about doing toposort within `nodes`
                     if dep.name not in name_to_node:
                         continue
                     visit(name_to_node[dep.name])
+
+                # Visit mutation dependencies
+                for buf in n.get_outputs():
+                    for mutation in buf.get_mutations():
+                        if mutation in name_to_node and name_to_node[mutation] != n:
+                            visit(name_to_node[mutation])
+
                 result.append(n)
 
+        # Build name to node mapping
         for node in nodes:
             for name in node.get_buffer_names():
                 name_to_node[name] = node
+
+        # Visit non-mutation nodes first
+        for node in nodes:
+            if not has_mutations(node):
+                visit(node)
+
+        # Then visit mutation nodes
         for node in nodes:
-            visit(node)
+            if has_mutations(node):
+                visit(node)
+
         return result
 
     def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]: