From 3c9c02054cfc77f34525ce9bed524eb8492ad98f Mon Sep 17 00:00:00 2001 From: xilzy <22451302@zju.edu.cn> Date: Sat, 8 Nov 2025 13:52:19 +0800 Subject: [PATCH] Convert torch._C._linalg.linalg_norm to torch.linalg.norm --- .../backend/unstable_to_stable_backend.py | 21 ++++++++++++++++++- graph_net/torch/fx_graph_serialize_util.py | 2 +- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/graph_net/torch/backend/unstable_to_stable_backend.py b/graph_net/torch/backend/unstable_to_stable_backend.py index 100f77a28..d7f36379e 100644 --- a/graph_net/torch/backend/unstable_to_stable_backend.py +++ b/graph_net/torch/backend/unstable_to_stable_backend.py @@ -178,7 +178,26 @@ def _impl_unstable_to_stable_linalg_vector_norm(self, gm): return gm - # replace this line with modification code for task 117 (torch._C._linalg.linalg_norm) + def _impl_unstable_to_stable_linalg_norm(self, gm): + """ + Convert torch._C._linalg.linalg_norm to torch.linalg.norm + """ + # Update graph nodes: replace torch._C._linalg.linalg_norm with torch.linalg.norm + issue_nodes = ( + node + for node in gm.graph.nodes + if node.op == "call_function" + if hasattr(node.target, "__module__") + if node.target.__module__ == "torch._C._linalg" + if hasattr(node.target, "__name__") + if node.target.__name__ == "linalg_norm" + ) + for node in issue_nodes: + node.target = torch.linalg.norm + + # Recompile the graph + gm.recompile() + return gm def _impl_unstable_to_stable_softplus(self, gm): """ diff --git a/graph_net/torch/fx_graph_serialize_util.py b/graph_net/torch/fx_graph_serialize_util.py index 2ad78a4c0..bc3e37df4 100644 --- a/graph_net/torch/fx_graph_serialize_util.py +++ b/graph_net/torch/fx_graph_serialize_util.py @@ -140,7 +140,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str: (r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("), (r"torch\._C\._special\.special_logit\(", "torch.special.logit("), (r"torch\._C\._linalg\.linalg_vector_norm\(", "torch.linalg.vector_norm("), - # replace this line with modification code for task 117 (torch._C._linalg.linalg_norm) + (r"torch\._C\._linalg\.linalg_norm\(", "torch.linalg.norm("), (r"torch\._C\._nn\.softplus\(", "torch.nn.functional.softplus("), (r"torch\._C\._nn\.one_hot\(", "torch.nn.functional.one_hot("), (r"torch\._C\._set_grad_enabled\(", "torch.set_grad_enabled("),