diff --git a/graph_net/log2json.py b/graph_net/log2json.py
index 47edde716..98b2f0dcd 100644
--- a/graph_net/log2json.py
+++ b/graph_net/log2json.py
@@ -53,6 +53,9 @@ def parse_logs_to_json(log_file: str, output_dir: str):
                     "datatype": {},
                     "speedup": {},
                 },
+                "result": {
+                    "status": "unknown",
+                },
             }
             continue

@@ -102,16 +105,20 @@ def parse_logs_to_json(log_file: str, output_dir: str):
         result_status_match = patterns["result_status"].search(line)
         if result_status_match:
             status = result_status_match.group(1).strip()
+            data["result"]["status"] = status
             if status == "failed" and (i + 1) < len(lines):
                 error_reason_match = patterns["failure"].search(lines[i + 1])
                 if error_reason_match:
                     reason = error_reason_match.group(1).lower()
                     if "eager" in reason:
                         data["performance"]["failure"] = "eager"
+                        data["result"]["status"] = "eager_fail"
                     elif "compiled" in reason:
                         data["performance"]["failure"] = "compiled"
+                        data["result"]["status"] = "compile_fail"
                     else:
                         data["performance"]["failure"] = "other"
+                        data["result"]["status"] = "runtime_fail"
             continue

         speedup_match = patterns["speedup"].search(line)
@@ -141,6 +148,20 @@ def parse_logs_to_json(log_file: str, output_dir: str):
     # filename = f"{model_name}_{subgraph_name}_{compiler_name}.json"
     filepath = os.path.join(output_dir, filename)

+    # Build result field with status and speedup
+    if data["result"]["status"] == "success":
+        speedup_data = {}
+        if "e2e" in data["performance"]["speedup"]:
+            speedup_data["e2e"] = {
+                "mean": data["performance"]["speedup"]["e2e"]
+            }
+        if "gpu" in data["performance"]["speedup"]:
+            speedup_data["gpu"] = {
+                "mean": data["performance"]["speedup"]["gpu"]
+            }
+        if speedup_data:
+            data["result"]["speedup"] = speedup_data
+
     with open(filepath, "w", encoding="utf-8") as f:
         json.dump(data, f, indent=4)

diff --git a/graph_net/torch/backend/unstable_to_stable_backend.py b/graph_net/torch/backend/unstable_to_stable_backend.py
index f8cb3969c..4302eeaf2 100644
--- a/graph_net/torch/backend/unstable_to_stable_backend.py
+++ b/graph_net/torch/backend/unstable_to_stable_backend.py
@@ -2,6 +2,7 @@
 import torch
 import sys
 import inspect
+import ast

 from .graph_compiler_backend import GraphCompilerBackend
 from ..fx_graph_serialize_util import serialize_graph_module_to_str
@@ -12,12 +13,21 @@ def __call__(self, model):
         unstable_api = os.getenv("DISALLOWED_UNSTABLE_API", "").strip()
         self.unstable_api = unstable_api

+        # Use a torch.compile backend hook to obtain the FX graph module uniformly,
+        # so every model goes through the same conversion path.
         def my_backend(gm, sample_inputs):
+            # Convert unstable APIs in the captured graph.
             gm = self.unstable_to_stable(gm)
             self.check_unstable_api(gm)
+            # Return the forward function without further optimization.
             return gm.forward

-        return torch.compile(backend=my_backend)(model)
+        # torch.compile is used only to capture the graph; the backend does nothing
+        # beyond API replacement, so performance should stay close to eager mode.
+        # Note: the `mode` parameter is deliberately omitted to avoid version
+        # compatibility issues.
+        compiled_model = torch.compile(model, backend=my_backend)
+        return compiled_model

         """
         TODO: Implement logic to convert unstable APIs in `self.model` into their stable counterparts.
@@ -165,7 +175,59 @@ def _impl_unstable_to_stable_special_logit(self, gm):

     # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
-    # replace this line with modification code for task 127 (torch._C._nn.linear)
+    def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
+        """
+        Convert torch._C._nn.linear to torch.nn.functional.linear.
+
+        Args:
+            gm: torch.fx.GraphModule object
+
+        Returns:
+            Modified GraphModule object
+        """
+        import torch.nn.functional as F
+
+        # Get a reference to torch._C._nn.linear for comparison
+        try:
+            unstable_linear = torch._C._nn.linear
+        except AttributeError:
+            unstable_linear = None
+
+        # Traverse all graph nodes and find calls that need to be replaced
+        for node in gm.graph.nodes:
+            if node.op == "call_function":
+                target = node.target
+                should_replace = False
+
+                # Method 1: direct identity comparison with torch._C._nn.linear
+                if unstable_linear is not None and target is unstable_linear:
+                    should_replace = True
+                # Method 2: compare object ids (equivalent fallback to Method 1)
+                elif unstable_linear is not None and id(target) == id(unstable_linear):
+                    should_replace = True
+                # Method 3: check __module__ and __name__ (torch.fx preserves these attributes)
+                elif hasattr(target, "__module__") and hasattr(target, "__name__"):
+                    if (
+                        target.__module__ == "torch._C._nn"
+                        and target.__name__ == "linear"
+                    ):
+                        should_replace = True
+                # Method 4: check the string representation (last-resort fallback)
+                elif "torch._C._nn.linear" in str(target) or (
+                    hasattr(target, "__qualname__")
+                    and "linear" in target.__qualname__
+                    and hasattr(target, "__module__")
+                    and "torch._C._nn" in str(target.__module__)
+                ):
+                    should_replace = True
+
+                if should_replace:
+                    node.target = F.linear
+
+        # Recompile the graph so the generated code reflects the new targets
+        gm.recompile()
+
+        return gm

     def unstable_to_stable(self, gm):
         methods = (
diff --git a/graph_net/torch/fx_graph_serialize_util.py b/graph_net/torch/fx_graph_serialize_util.py
index a67f83882..eb50d7d4e 100644
--- a/graph_net/torch/fx_graph_serialize_util.py
+++ b/graph_net/torch/fx_graph_serialize_util.py
@@ -2,6 +2,120 @@
 import torch.fx


+# def apply_ast_based_linear_replacement(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+#     """
+#     Apply AST-based replacement of torch._C._nn.linear to torch.nn.functional.linear.
+#
+#     This function uses AST parsing and transformation to replace torch._C._nn.linear
+#     calls with torch.nn.functional.linear in the GraphModule's code.
+#
+#     Note: This function is currently commented out as the replacement is now handled
+#     by simple string replacement in serialize_graph_module_to_str.
+#
+#     Args:
+#         gm: The GraphModule to modify.
+#
+#     Returns:
+#         Modified GraphModule with torch._C._nn.linear replaced by torch.nn.functional.linear.
+#     """
+#     import ast
+#     import torch
+#     import types
+#
+#     # First recompile to generate code
+#     gm.recompile()
+#
+#     # Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
+#     code_str = gm.code
+#
+#     # Parse AST
+#     tree = ast.parse(code_str)
+#
+#     class LinearReplacer(ast.NodeTransformer):
+#         def visit_Call(self, node):
+#             # Check if it's a torch._C._nn.linear call
+#             # Structure: torch._C._nn.linear(...)
+#             filtered_nodes = [
+#                 node
+#                 for node in [node]
+#                 if isinstance(node.func, ast.Attribute)
+#                 if node.func.attr == "linear"
+#                 if isinstance(node.func.value, ast.Attribute)
+#                 if node.func.value.attr == "_nn"
+#                 if isinstance(node.func.value.value, ast.Attribute)
+#                 if node.func.value.value.attr == "_C"
+#                 if isinstance(node.func.value.value.value, ast.Name)
+#                 if node.func.value.value.value.id == "torch"
+#             ]
+#             if filtered_nodes:
+#                 # Found torch._C._nn.linear, replace with torch.nn.functional.linear
+#                 new_func = ast.Attribute(
+#                     value=ast.Attribute(
+#                         value=ast.Attribute(
+#                             value=ast.Name(
+#                                 id="torch",
+#                                 ctx=ast.Load(),
+#                             ),
+#                             attr="nn",
+#                             ctx=ast.Load(),
+#                         ),
+#                         attr="functional",
+#                         ctx=ast.Load(),
+#                     ),
+#                     attr="linear",
+#                     ctx=ast.Load(),
+#                 )
+#                 node.func = new_func
+#             return self.generic_visit(node)
+#
+#     transformer = LinearReplacer()
+#     modified_tree = transformer.visit(tree)
+#     ast.fix_missing_locations(modified_tree)
+#
+#     # Convert the modified AST back to code string
+#     new_code = ast.unparse(modified_tree)
+#
+#     # Recompile the modified code
+#     # Need to import device, inf and other modules that may be used
+#     namespace = {
+#         "torch": torch,
+#     }
+#     # Try to import device (if used in code)
+#     try:
+#         from torch import device
+#
+#         namespace["device"] = device
+#     except ImportError:
+#         pass
+#     # Try to import inf (if used in code)
+#     try:
+#         from torch import inf
+#
+#         namespace["inf"] = inf
+#     except ImportError:
+#         # If torch doesn't have inf, use math.inf
+#         try:
+#             import math
+#
+#             namespace["inf"] = math.inf
+#         except:
+#             pass
+#
+#     exec(compile(modified_tree, filename="", mode="exec"), namespace)
+#
+#     # Update GraphModule's forward method
+#     forward_func = namespace.get("forward")
+#     if forward_func:
+#         gm.forward = types.MethodType(forward_func, gm)
+#
+#     # Use serialize_graph_module_to_str to get the serialized code
+#     # This ensures the code is properly serialized with unstable API replacements
+#     serialized_code = serialize_graph_module_to_str(gm)
+#     gm._code = serialized_code
+#
+#     return gm
+
+
 def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
     """
     Serialize a GraphModule to a string representation, replacing unstable APIs
@@ -34,7 +148,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
         # replace this line with modification code for task 123 (torch._C._nn.pad)
         # replace this line with modification code for task 125 (torch._C._nn.gelu)
         # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
-        # replace this line with modification code for task 127 (torch._C._nn.linear)
+        (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
     ]
     for pattern, repl in replacements:
         code = re.sub(pattern, repl, code)
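Below is a minimal, standalone sketch (not part of the patch) illustrating the string-level replacement that the new tuple in serialize_graph_module_to_str performs. The sample `code` string and the names `out`, `x`, `weight`, and `bias` are made up for illustration; only the regex pair is taken from the patch.

import re

# Replacement pair added to serialize_graph_module_to_str by this patch.
replacements = [
    (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
]

# Hypothetical line of FX-generated code that calls the unstable API.
code = "out = torch._C._nn.linear(x, weight, bias)"
for pattern, repl in replacements:
    code = re.sub(pattern, repl, code)

print(code)  # out = torch.nn.functional.linear(x, weight, bias)

On the log2json.py side, a successful entry now also carries a result field shaped roughly like {"status": "success", "speedup": {"e2e": {"mean": ...}, "gpu": {"mean": ...}}}, with the mean values taken from the parsed performance section and the e2e/gpu keys present only when the corresponding speedup was found in the log.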