diff --git a/graph_net/torch/backend/unstable_to_stable_backend.py b/graph_net/torch/backend/unstable_to_stable_backend.py index 472e43432..49cadfb3d 100644 --- a/graph_net/torch/backend/unstable_to_stable_backend.py +++ b/graph_net/torch/backend/unstable_to_stable_backend.py @@ -104,6 +104,28 @@ def _impl_unstable_to_stable_rfft(self, gm): return gm + def _impl_unstable_to_stable_fftn(self, gm): + """ + Convert torch._C._fft.fft_fftn to torch.fft.fftn + """ + # Update graph nodes: replace torch._C._fft.fft_fftn with torch.fft.fftn + issue_nodes = ( + node + for node in gm.graph.nodes + if node.op == "call_function" + if hasattr(node.target, "__module__") + if node.target.__module__ == "torch._C._fft" + if hasattr(node.target, "__name__") + if node.target.__name__ == "fft_fftn" + ) + for node in issue_nodes: + node.target = torch.fft.fftn + + # Recompile the graph + gm.recompile() + + return gm + def unstable_to_stable(self, gm): methods = ( name diff --git a/graph_net/torch/fx_graph_serialize_util.py b/graph_net/torch/fx_graph_serialize_util.py index 51b0b6616..f89716d53 100644 --- a/graph_net/torch/fx_graph_serialize_util.py +++ b/graph_net/torch/fx_graph_serialize_util.py @@ -23,6 +23,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str: (r"torch\._C\._nn\.avg_pool2d\(", "torch.nn.functional.avg_pool2d("), (r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("), (r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("), + (r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("), # Add new rules to this list as needed ] for pattern, repl in replacements: