Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion torch/_inductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2247,21 +2247,43 @@ def topological_sort_schedule(
name_to_node: Dict[str, BaseSchedulerNode] = dict()
result: List[BaseSchedulerNode] = []

def has_mutations(node: BaseSchedulerNode) -> bool:
    """Return True if any output buffer of ``node`` reports a non-empty
    ``get_mutations()`` result, False otherwise."""
    for out_buf in node.get_outputs():
        # Stop at the first output that reports mutations.
        if out_buf.get_mutations():
            return True
    return False

def visit(n: BaseSchedulerNode) -> None:
    """Append ``n`` to ``result`` after recursively visiting its prerequisites.

    Prerequisites are (1) the nodes that provide ``n``'s unmet dependencies
    and (2) the nodes owning the names that ``n``'s outputs list in
    ``get_mutations()``. Only names present in ``name_to_node`` are
    followed. A node already recorded in ``seen`` is ignored, so each node
    is appended at most once.
    """
    if n in seen:
        return
    seen.add(n)

    # Regular dependencies, iterated in name order.
    ordered_deps = sorted(n.unmet_dependencies, key=lambda d: d.name)
    for dep in ordered_deps:
        # Names outside `name_to_node` are not part of this sort.
        if dep.name in name_to_node:
            visit(name_to_node[dep.name])

    # Mutation dependencies: visit the owner of every mutated name first.
    for out_buf in n.get_outputs():
        for mutated_name in out_buf.get_mutations():
            if mutated_name not in name_to_node:
                continue
            owner = name_to_node[mutated_name]
            # Skip the case where the mutated name belongs to `n` itself.
            if owner != n:
                visit(owner)

    # Post-order: all prerequisites above have been handled already.
    result.append(n)

# Build name to node mapping
for node in nodes:
for name in node.get_buffer_names():
name_to_node[name] = node

# Visit non-mutation nodes first
for node in nodes:
if not has_mutations(node):
visit(node)

# Then visit mutation nodes
for node in nodes:
visit(node)
if has_mutations(node):
visit(node)

return result

def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]:
Expand Down