From ac16de1dcd3c8a41478189bf461fa1240d109a2b Mon Sep 17 00:00:00 2001 From: Gruge <202421090105@std.uestc.edu.cn> Date: Thu, 6 Nov 2025 14:13:40 +0800 Subject: [PATCH 1/3] replace torch._C._fft.fft_fftn with torch.fft.fftn --- .../backend/unstable_to_stable_backend.py | 22 +++++++++++++++++++ graph_net/torch/fx_graph_serialize_util.py | 5 +++++ 2 files changed, 27 insertions(+) diff --git a/graph_net/torch/backend/unstable_to_stable_backend.py b/graph_net/torch/backend/unstable_to_stable_backend.py index 3cb3b81c4..fd4c3bcc2 100644 --- a/graph_net/torch/backend/unstable_to_stable_backend.py +++ b/graph_net/torch/backend/unstable_to_stable_backend.py @@ -52,10 +52,32 @@ def avg_pool2d_to_avg_pool2d(self, gm): return gm + def fft_fftn_to_fftn(self, gm): + """ + Convert torch._C._fft.fft_fftn to torch.fft.fftn + """ + # Update graph nodes: replace torch._C._fft.fft_fftn with torch.fft.fftn + for node in gm.graph.nodes: + if node.op == "call_function": + if ( + hasattr(node.target, "__module__") + and hasattr(node.target, "__name__") + and node.target.__module__ == "torch._C._fft" + and node.target.__name__ == "fft_fftn" + ): + node.target = torch.fft.fftn + + # Recompile the graph + gm.recompile() + + return gm + def unstable_to_stable(self, gm): # Convert based on unstable_api environment variable if self.unstable_api == "torch._C._nn.avg_pool2d": gm = self.avg_pool2d_to_avg_pool2d(gm) + elif self.unstable_api == "torch._C._fft.fft_fftn": + gm = self.fft_fftn_to_fftn(gm) return gm def check_unstable_api(self, gm): diff --git a/graph_net/torch/fx_graph_serialize_util.py b/graph_net/torch/fx_graph_serialize_util.py index 6d6117ee0..1e8be2e5d 100644 --- a/graph_net/torch/fx_graph_serialize_util.py +++ b/graph_net/torch/fx_graph_serialize_util.py @@ -24,4 +24,9 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str: "torch.nn.functional.avg_pool2d(", code, ) + code = re.sub( + r"torch\._C\._fft\.fft_fftn\(", + "torch.fft.fftn(", + code, + ) return code From 1e4381a0084ed5a8705e8eeecda915361e311485 Mon Sep 17 00:00:00 2001 From: Gruge <202421090105@std.uestc.edu.cn> Date: Thu, 6 Nov 2025 14:59:11 +0800 Subject: [PATCH 2/3] update code --- .../backend/unstable_to_stable_backend.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/graph_net/torch/backend/unstable_to_stable_backend.py b/graph_net/torch/backend/unstable_to_stable_backend.py index fd4c3bcc2..f69b3f506 100644 --- a/graph_net/torch/backend/unstable_to_stable_backend.py +++ b/graph_net/torch/backend/unstable_to_stable_backend.py @@ -57,15 +57,17 @@ def fft_fftn_to_fftn(self, gm): Convert torch._C._fft.fft_fftn to torch.fft.fftn """ # Update graph nodes: replace torch._C._fft.fft_fftn with torch.fft.fftn - for node in gm.graph.nodes: - if node.op == "call_function": - if ( - hasattr(node.target, "__module__") - and hasattr(node.target, "__name__") - and node.target.__module__ == "torch._C._fft" - and node.target.__name__ == "fft_fftn" - ): - node.target = torch.fft.fftn + issue_nodes = ( + node + for node in gm.graph.nodes + if node.op == "call_function" + if hasattr(node.target, "__module__") + if node.target.__module__ == "torch._C._fft" + if hasattr(node.target, "__name__") + if node.target.__name__ == "fft_fftn" + ) + for node in issue_nodes: + node.target = torch.fft.fftn # Recompile the graph gm.recompile() From 12724e54bf2861a56ff33d1d9b696312d17331d3 Mon Sep 17 00:00:00 2001 From: Gruge <202421090105@std.uestc.edu.cn> Date: Fri, 7 Nov 2025 10:12:49 +0800 Subject: [PATCH 3/3] resolve merge conflicts --- .../backend/unstable_to_stable_backend.py | 68 +++++++++++++++++-- graph_net/torch/fx_graph_serialize_util.py | 19 +++--- 2 files changed, 70 insertions(+), 17 deletions(-) diff --git a/graph_net/torch/backend/unstable_to_stable_backend.py b/graph_net/torch/backend/unstable_to_stable_backend.py index f69b3f506..49cadfb3d 100644 --- a/graph_net/torch/backend/unstable_to_stable_backend.py +++ b/graph_net/torch/backend/unstable_to_stable_backend.py @@ -30,7 +30,37 @@ def my_backend(gm, sample_inputs): **Stable API reference link:** """ - def avg_pool2d_to_avg_pool2d(self, gm): + def _impl_unstable_to_stable_irfft(self, gm): + def replace_in_graph(graph_mod): + # Register stable implementation on GraphModule, codegen can use self.irfft + try: + setattr(graph_mod, "irfft", torch.fft.irfft) + except Exception: + pass + + for node in graph_mod.graph.nodes: + if node.op == "call_function": + # Match for all forms of target names + if "fft_irfft" in str(node.target): + # Directly point target to Python layer function + node.target = torch.fft.irfft + # Validate and recompile the graph + graph_mod.graph.lint() + graph_mod.recompile() + + # Process main gm and all nested GraphModules + modules = [gm] + modules += [ + m + for _, m in gm.named_modules() + if isinstance(m, torch.fx.GraphModule) and m is not gm + ] + for m in modules: + replace_in_graph(m) + + return gm + + def _impl_unstable_to_stable_avg_pool2d(self, gm): """ Convert torch._C._nn.avg_pool2d to torch.nn.functional.avg_pool2d """ @@ -52,7 +82,29 @@ def avg_pool2d_to_avg_pool2d(self, gm): return gm - def fft_fftn_to_fftn(self, gm): + def _impl_unstable_to_stable_rfft(self, gm): + """ + Convert torch._C._fft.fft_rfft to torch.fft.rfft + """ + # Update graph nodes: replace torch._C._fft.fft_rfft with torch.fft.rfft + issue_nodes = ( + node + for node in gm.graph.nodes + if node.op == "call_function" + if hasattr(node.target, "__module__") + if node.target.__module__ == "torch._C._fft" + if hasattr(node.target, "__name__") + if node.target.__name__ == "fft_rfft" + ) + for node in issue_nodes: + node.target = torch.fft.rfft + + # Recompile the graph + gm.recompile() + + return gm + + def _impl_unstable_to_stable_fftn(self, gm): """ Convert torch._C._fft.fft_fftn to torch.fft.fftn """ @@ -75,11 +127,13 @@ def fft_fftn_to_fftn(self, gm): return gm def unstable_to_stable(self, gm): - # Convert based on unstable_api environment variable - if self.unstable_api == "torch._C._nn.avg_pool2d": - gm = self.avg_pool2d_to_avg_pool2d(gm) - elif self.unstable_api == "torch._C._fft.fft_fftn": - gm = self.fft_fftn_to_fftn(gm) + methods = ( + name + for name in vars(type(self)).keys() + if name.startswith("_impl_unstable_to_stable") + ) + for method in methods: + gm = getattr(self, method)(gm) return gm def check_unstable_api(self, gm): diff --git a/graph_net/torch/fx_graph_serialize_util.py b/graph_net/torch/fx_graph_serialize_util.py index 1e8be2e5d..f89716d53 100644 --- a/graph_net/torch/fx_graph_serialize_util.py +++ b/graph_net/torch/fx_graph_serialize_util.py @@ -19,14 +19,13 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str: """ code = gm.code # Replace torch._C._nn.avg_pool2d with torch.nn.functional.avg_pool2d - code = re.sub( - r"torch\._C\._nn\.avg_pool2d\(", - "torch.nn.functional.avg_pool2d(", - code, - ) - code = re.sub( - r"torch\._C\._fft\.fft_fftn\(", - "torch.fft.fftn(", - code, - ) + replacements = [ + (r"torch\._C\._nn\.avg_pool2d\(", "torch.nn.functional.avg_pool2d("), + (r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("), + (r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("), + (r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("), + # Add new rules to this list as needed + ] + for pattern, repl in replacements: + code = re.sub(pattern, repl, code) return code