Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions graph_net/log2json.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def parse_logs_to_json(log_file: str, output_dir: str):
"datatype": {},
"speedup": {},
},
"result": {
"status": "unknown",
},
}
continue

Expand Down Expand Up @@ -102,16 +105,20 @@ def parse_logs_to_json(log_file: str, output_dir: str):
result_status_match = patterns["result_status"].search(line)
if result_status_match:
status = result_status_match.group(1).strip()
data["result"]["status"] = status
if status == "failed" and (i + 1) < len(lines):
error_reason_match = patterns["failure"].search(lines[i + 1])
if error_reason_match:
reason = error_reason_match.group(1).lower()
if "eager" in reason:
data["performance"]["failure"] = "eager"
data["result"]["status"] = "eager_fail"
elif "compiled" in reason:
data["performance"]["failure"] = "compiled"
data["result"]["status"] = "compile_fail"
else:
data["performance"]["failure"] = "other"
data["result"]["status"] = "runtime_fail"
continue

speedup_match = patterns["speedup"].search(line)
Expand Down Expand Up @@ -141,6 +148,20 @@ def parse_logs_to_json(log_file: str, output_dir: str):
# filename = f"{model_name}_{subgraph_name}_{compiler_name}.json"
filepath = os.path.join(output_dir, filename)

# Build result field with status and speedup
if data["result"]["status"] == "success":
speedup_data = {}
if "e2e" in data["performance"]["speedup"]:
speedup_data["e2e"] = {
"mean": data["performance"]["speedup"]["e2e"]
}
if "gpu" in data["performance"]["speedup"]:
speedup_data["gpu"] = {
"mean": data["performance"]["speedup"]["gpu"]
}
if speedup_data:
data["result"]["speedup"] = speedup_data

with open(filepath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4)

Expand Down
66 changes: 64 additions & 2 deletions graph_net/torch/backend/unstable_to_stable_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import sys
import inspect
import ast
from .graph_compiler_backend import GraphCompilerBackend
from ..fx_graph_serialize_util import serialize_graph_module_to_str

Expand All @@ -12,12 +13,21 @@ def __call__(self, model):
unstable_api = os.getenv("DISALLOWED_UNSTABLE_API", "").strip()
self.unstable_api = unstable_api

# Use torch.compile's backend method to get graph module uniformly
# This ensures all models use the same conversion method, avoiding performance differences
def my_backend(gm, sample_inputs):
# Convert unstable API
gm = self.unstable_to_stable(gm)
self.check_unstable_api(gm)
# Return forward function without additional optimization
return gm.forward

return torch.compile(backend=my_backend)(model)
# Use torch.compile to get graph module and perform conversion
# Although compile is used, the backend only does API conversion, no optimization
# Performance should be close to eager mode (since only API replacement is done)
# Note: Do not use mode parameter to avoid version compatibility issues
compiled_model = torch.compile(model, backend=my_backend)
return compiled_model

"""
TODO: Implement logic to convert unstable APIs in `self.model` into their stable counterparts.
Expand Down Expand Up @@ -165,7 +175,59 @@ def _impl_unstable_to_stable_special_logit(self, gm):

# replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)

# replace this line with modification code for task 127 (torch._C._nn.linear)
def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
"""
Convert torch._C._nn.linear to torch.nn.functional.linear

Args:
gm: torch.fx.GraphModule object

Returns:
Modified GraphModule object
"""
import torch.nn.functional as F

# Get reference to torch._C._nn.linear for comparison
try:
unstable_linear = torch._C._nn.linear
except AttributeError:
unstable_linear = None

# Traverse all nodes to find nodes that need to be replaced
for node in gm.graph.nodes:
if node.op == "call_function":
target = node.target
should_replace = False

# Method 1: Direct target comparison (most reliable)
if unstable_linear is not None and target is unstable_linear:
should_replace = True
# Method 2: Check if it's the same function object (using id comparison)
elif unstable_linear is not None and id(target) == id(unstable_linear):
should_replace = True
# Method 3: Check module and name attributes (most reliable method, as torch.fx preserves these attributes)
elif hasattr(target, "__module__") and hasattr(target, "__name__"):
if (
target.__module__ == "torch._C._nn"
and target.__name__ == "linear"
):
should_replace = True
# Method 4: Check via string representation (fallback method)
elif "torch._C._nn.linear" in str(target) or (
hasattr(target, "__qualname__")
and "linear" in target.__qualname__
and hasattr(target, "__module__")
and "torch._C._nn" in str(target.__module__)
):
should_replace = True

if should_replace:
node.target = F.linear

# Recompile the graph
gm.recompile()

return gm

def unstable_to_stable(self, gm):
methods = (
Expand Down
116 changes: 115 additions & 1 deletion graph_net/torch/fx_graph_serialize_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,120 @@
import torch.fx


# def apply_ast_based_linear_replacement(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
# """
# Apply AST-based replacement of torch._C._nn.linear to torch.nn.functional.linear.
#
# This function uses AST parsing and transformation to replace torch._C._nn.linear
# calls with torch.nn.functional.linear in the GraphModule's code.
#
# Note: This function is currently commented out as the replacement is now handled
# by simple string replacement in serialize_graph_module_to_str.
#
# Args:
# gm: The GraphModule to modify.
#
# Returns:
# Modified GraphModule with torch._C._nn.linear replaced by torch.nn.functional.linear.
# """
# import ast
# import torch
# import types
#
# # First recompile to generate code
# gm.recompile()
#
# # Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
# code_str = gm.code
#
# # Parse AST
# tree = ast.parse(code_str)
#
# class LinearReplacer(ast.NodeTransformer):
# def visit_Call(self, node):
# # Check if it's a torch._C._nn.linear call
# # Structure: torch._C._nn.linear(...)
# filtered_nodes = [
# node
# for node in [node]
# if isinstance(node.func, ast.Attribute)
# if node.func.attr == "linear"
# if isinstance(node.func.value, ast.Attribute)
# if node.func.value.attr == "_nn"
# if isinstance(node.func.value.value, ast.Attribute)
# if node.func.value.value.attr == "_C"
# if isinstance(node.func.value.value.value, ast.Name)
# if node.func.value.value.value.id == "torch"
# ]
# if filtered_nodes:
# # Found torch._C._nn.linear, replace with torch.nn.functional.linear
# new_func = ast.Attribute(
# value=ast.Attribute(
# value=ast.Attribute(
# value=ast.Name(
# id="torch",
# ctx=ast.Load(),
# ),
# attr="nn",
# ctx=ast.Load(),
# ),
# attr="functional",
# ctx=ast.Load(),
# ),
# attr="linear",
# ctx=ast.Load(),
# )
# node.func = new_func
# return self.generic_visit(node)
#
# transformer = LinearReplacer()
# modified_tree = transformer.visit(tree)
# ast.fix_missing_locations(modified_tree)
#
# # Convert the modified AST back to code string
# new_code = ast.unparse(modified_tree)
#
# # Recompile the modified code
# # Need to import device, inf and other modules that may be used
# namespace = {
# "torch": torch,
# }
# # Try to import device (if used in code)
# try:
# from torch import device
#
# namespace["device"] = device
# except ImportError:
# pass
# # Try to import inf (if used in code)
# try:
# from torch import inf
#
# namespace["inf"] = inf
# except ImportError:
# # If torch doesn't have inf, use math.inf
# try:
# import math
#
# namespace["inf"] = math.inf
# except:
# pass
#
# exec(compile(modified_tree, filename="<ast>", mode="exec"), namespace)
#
# # Update GraphModule's forward method
# forward_func = namespace.get("forward")
# if forward_func:
# gm.forward = types.MethodType(forward_func, gm)
#
# # Use serialize_graph_module_to_str to get the serialized code
# # This ensures the code is properly serialized with unstable API replacements
# serialized_code = serialize_graph_module_to_str(gm)
# gm._code = serialized_code
#
# return gm


def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
"""
Serialize a GraphModule to a string representation, replacing unstable APIs
Expand Down Expand Up @@ -34,7 +148,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
# replace this line with modification code for task 123 (torch._C._nn.pad)
# replace this line with modification code for task 125 (torch._C._nn.gelu)
# replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
# replace this line with modification code for task 127 (torch._C._nn.linear)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

把上面的改动放到这里,减少其他 PR 的合入冲突

(r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
]
for pattern, repl in replacements:
code = re.sub(pattern, repl, code)
Expand Down