Skip to content

Commit 7886070

Browse files
SherlockNoMad authored and pytorchmergebot committed
Use stable topological sort in fuse_by_partitions (pytorch#167397)
legalize_graph() performs a topo sort that shuffles the nodes in a global way, making the result unpredictable. We should avoid this in graph passes in general. This problem was discovered when testing regional_inductor: a single fused region triggers the global reordering. Before https://www.internalfb.com/intern/diffing/?before_paste_number=2029217728&after_paste_number=2029218006&regex_remove_pattern=&enable_regex_remove=0&strip_empty_lines=0&line_wrap=0&selected_tab=plain_diff After https://www.internalfb.com/intern/diffing/?paste_number=2029162294&regex_remove_pattern=&enable_regex_remove=0&strip_empty_lines=0&line_wrap=0&selected_tab=plain_diff Left is the gm before regional_inductor, right is after. Pull Request resolved: pytorch#167397 Approved by: https://github.com/ezyang
1 parent 87d17e9 commit 7886070

File tree

2 files changed

+38
-17
lines changed

2 files changed

+38
-17
lines changed

torch/fx/passes/tools_common.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,9 @@ def __call__(self) -> dict[torch.fx.Node, NodeSet]:
245245

246246

247247
@compatibility(is_backward_compatible=False)
248-
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
248+
def legalize_graph(
249+
gm: torch.fx.GraphModule, stable_topo_sort: bool = False
250+
) -> torch.fx.GraphModule:
249251
"""
250252
Replace the graph of the given GraphModule with one that contains the same nodes as the
251253
original, but in topologically sorted order.
@@ -255,6 +257,7 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
255257
256258
Arguments:
257259
gm: The graph module to topologically sort. It is modified in-place.
260+
stable_topo_sort: when True, PRIORITIZED_OPS would be ignored.
258261
259262
Returns:
260263
The graph module in-place sorted
@@ -304,7 +307,11 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
304307
for user in cur.users:
305308
indeg[user] -= 1
306309
if indeg[user] == 0:
307-
if user.op == "call_function" and user.target in PRIORITIZED_OPS:
310+
if (
311+
not stable_topo_sort
312+
and user.op == "call_function"
313+
and user.target in PRIORITIZED_OPS
314+
):
308315
queue.appendleft(user)
309316
else:
310317
queue.append(user)

torch/fx/passes/utils/fuser_utils.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -220,22 +220,36 @@ def insert_subgm(
220220
submodule_name = sub_gm.__class__.__name__
221221
gm.add_submodule(submodule_name, sub_gm)
222222

223-
# Create a call_module node in main graph.
224-
module_node = gm.graph.call_module(submodule_name, args=orig_inputs, kwargs=None)
223+
def last_node(target_nodes: tuple[Node, ...]) -> Node | None:
224+
for node in reversed(gm.graph.nodes):
225+
if node in target_nodes:
226+
return node
227+
return None
225228

226-
output_node = sub_gm.graph.output_node()
227-
if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple):
228-
# main_remapping[comp.orig_outputs[0]] = module_node
229-
orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
230-
else:
231-
for i, orig_output in enumerate(orig_outputs):
232-
# Use Proxy to record getitem access.
233-
proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index]
234-
orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
229+
last_input_node: Node | None = last_node(orig_inputs)
230+
assert last_input_node is not None
235231

236-
module_node.meta["val"] = tuple(
237-
orig_output.meta.get("val", None) for orig_output in orig_outputs
232+
# Create a call_module node in main graph.
233+
with gm.graph.inserting_after(last_input_node):
234+
module_node = gm.graph.call_module(
235+
submodule_name, args=orig_inputs, kwargs=None
238236
)
237+
output_node = sub_gm.graph.output_node()
238+
239+
next_node = module_node.next
240+
with gm.graph.inserting_before(next_node):
241+
if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple):
242+
# main_remapping[comp.orig_outputs[0]] = module_node
243+
orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
244+
else:
245+
for i, orig_output in enumerate(orig_outputs):
246+
# Use Proxy to record getitem access.
247+
proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index]
248+
orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
249+
250+
module_node.meta["val"] = tuple(
251+
orig_output.meta.get("val", None) for orig_output in orig_outputs
252+
)
239253
return gm
240254

241255

@@ -269,7 +283,7 @@ def fuse_by_partitions(
269283

270284
erase_nodes(gm, sorted_nodes)
271285

272-
# topological sort original gm with newly created sub_gm
273-
legalize_graph(gm)
286+
legalize_graph(gm, stable_topo_sort=True)
287+
gm.graph.lint()
274288

275289
return gm

0 commit comments

Comments
 (0)