From af92b864d4f27b722904103e090df71120618185 Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Thu, 20 Nov 2025 16:05:09 +0800 Subject: [PATCH 1/6] add test --- graph_net/test/split_points.py | 250 ++++++++++++++++++ .../torch/rp_expr/longest_rp_expr_parser.py | 2 + graph_net/torch/rp_expr/rp_expr.py | 2 +- 3 files changed, 253 insertions(+), 1 deletion(-) create mode 100644 graph_net/test/split_points.py diff --git a/graph_net/test/split_points.py b/graph_net/test/split_points.py new file mode 100644 index 000000000..24236e966 --- /dev/null +++ b/graph_net/test/split_points.py @@ -0,0 +1,250 @@ +import sys +import os +import argparse +import importlib.util +import torch +import torch.nn as nn +from pathlib import Path +from typing import List, Dict, Any, Callable +from graph_net.torch import utils as graph_utils +from graph_net.torch.rp_expr.longest_rp_expr_parser import LongestRpExprParser +from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser + + +class GraphExtractor: + def __init__(self): + self.extract_node = [] + + def _extract_operators_from_graph( + self, gm: nn.Module, example_inputs: List[torch.Tensor] = None + ) -> List[Dict[str, Any]]: + operator_list = [] + named_modules = dict(gm.named_modules()) + + for node in gm.graph.nodes: + if node.op in ("call_method", "call_function", "call_module"): + target_name = str(node.target) + + if node.op == "call_module": + module_instance = named_modules.get(node.target) + if module_instance is not None: + target_name = type(module_instance).__name__ + elif node.op == "call_function": + if isinstance(node.target, Callable): + target_name = node.target.__name__ + elif node.op == "call_method": + target_name = str(node.target) + + operator_info = { + "op_type": node.op, + "target": node.target, + "name": node.name, + "target_name": target_name, + } + operator_list.append(operator_info) + + return operator_list + + def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]): + operator = self._extract_operators_from_graph(gm, inputs) + self.extract_node = operator + return gm.forward + + +class ModelLoader: + def load_class_from_file(self, model_path: str, device: str) -> Any: + file_path = os.path.join(model_path, "model.py") + file = Path(file_path).resolve() + module_name = file.stem + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Model file not found: {file_path}") + + with open(file_path, "r", encoding="utf-8") as f: + model_code = f.read() + + model_code = graph_utils.modify_code_by_device(model_code, device) + + spec = importlib.util.spec_from_loader(module_name, loader=None) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + compiled_code = compile(model_code, filename=file, mode="exec") + exec(compiled_code, module.__dict__) + + model_class = getattr(module, "GraphModule", None) + if model_class is None: + raise ImportError(f"Class 'GraphModule' not found in {file_path}") + + setattr(model_class, "__graph_net_file_path__", str(file_path)) + setattr(model_class, "__graph_net_device__", device) + return model_class + + def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]: + inputs_params = graph_utils.load_converted_from_text(f"{model_path}") + params = inputs_params["weight_info"] + for tensor_meta in params.values(): + if hasattr(tensor_meta, "device"): + tensor_meta.device = device + input_dict = { + k: graph_utils.replay_tensor(v).to(torch.device(device)) + for k, v in params.items() + } + return input_dict + + +def extract_ops_via_compile(model_path: str, device: str = "cpu") -> List[str]: + loader = ModelLoader() + print(f"Loading model from {model_path} on {device}...") + try: + model_class = loader.load_class_from_file(model_path, device) + model = model_class().to(torch.device(device)) + model.eval() + input_dict = loader.get_input_dict(model_path, device) + except Exception as e: + print(f"Error loading/preparing model {model_path}: {e}") + return [] + + extractor = GraphExtractor() + compiled_model = torch.compile(model, backend=extractor.extract_compiler) + + with torch.no_grad(): + compiled_model(**input_dict) + + ops_info = extractor.extract_node + if not ops_info: + print(f"Warning: No operators extracted from {model_path}.") + return [] + return [op["target_name"] for op in ops_info] + + +def calculate_token_lengths(rp_expr, num_primitives, symbol_map) -> Dict[int, int]: + token2len = {} + + def get_len(tid): + if tid in token2len: + return token2len[tid] + if tid < num_primitives: + token2len[tid] = 1 + return 1 + if tid in symbol_map: + sub_tokens = symbol_map[tid].tolist() + length = sum(get_len(t) for t in sub_tokens) + token2len[tid] = length + return length + token2len[tid] = 1 + return 1 + + for sym_id in rp_expr.symbol_token_ids: + get_len(sym_id) + return token2len + + +def main(): + parser = argparse.ArgumentParser( + description="Extract graph patterns and split points from multiple models." + ) + parser.add_argument( + "--models", + nargs="+", + required=True, + help="List of paths to model directories (e.g. --models path/to/m1 path/to/m2)", + ) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--window", type=int, default=10) + args = parser.parse_args() + + inputs = [] + valid_model_names = [] + + for model_path in args.models: + seq = extract_ops_via_compile(model_path, args.device) + inputs.append(seq) + valid_model_names.append(os.path.basename(model_path)) + + rp_parser = RpExprParser( + window_size=args.window, fold_policy="default", fold_times=0 + ) + rp_expr, token_id2primitive_id = rp_parser(inputs) + + rp_expr.try_unwrap_body_of_sole_symbol_token() + rp_expr.try_recursive_inline_symbol_sole_used(token_id2primitive_id) + + num_primitives = len(token_id2primitive_id) + symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors)) + token2len = calculate_token_lengths(rp_expr, num_primitives, symbol_map) + + # print ops func + def resolve_token_to_ops(tid) -> List[str]: + if tid < num_primitives: + return [token_id2primitive_id[tid]] + if tid in symbol_map: + sub_tokens = symbol_map[tid].tolist() + ops = [] + for t in sub_tokens: + ops.extend(resolve_token_to_ops(t)) + return ops + return [f"Unknown({tid})"] + + for sym_id in sorted(symbol_map.keys()): + length = token2len.get(sym_id, 0) + ops_seq = resolve_token_to_ops(sym_id) + ops_str = str(ops_seq) + if len(ops_str) > 100: + ops_str = ops_str[:100] + " ...]" + + for i, model_name in enumerate(valid_model_names): + if i >= len(rp_expr.body_rp_expr): + break + + target_body_tensor = rp_expr.body_rp_expr[i] + seq_tokens = target_body_tensor.tolist() + + current_idx = 0 + split_points = set() + total_len = sum(token2len.get(t, 1) for t in seq_tokens) + + full_model_ops = [] + for t in seq_tokens: + full_model_ops.extend(resolve_token_to_ops(t)) + + for token_id in seq_tokens: + length = token2len.get(token_id, 1) + is_pattern = token_id >= num_primitives + + if is_pattern: + if current_idx > 0: + split_points.add(current_idx) + end_idx = current_idx + length + if end_idx < total_len: + split_points.add(end_idx) + + current_idx += length + + sorted_splits = sorted(list(set(split_points))) + print("=" * 50) + print(f"model_name: {model_name}") + print(f"Split Sequence Indices: {sorted_splits}") + print("Segments info:") + last_split = 0 + for split in sorted_splits + [total_len]: + segment_len = split - last_split + if last_split < len(full_model_ops) and split <= len(full_model_ops): + segment_ops = full_model_ops[last_split:split] + if len(segment_ops) > 5: + ops_display = f"[{segment_ops[0]}, ..., {segment_ops[-1]}]" + else: + ops_display = str(segment_ops) + print( + f" Range [{last_split:3d}, {split:3d}), Length: {segment_len:3d} | Ops: {ops_display}" + ) + else: + print( + f" Range [{last_split:3d}, {split:3d}), Length: {segment_len:3d} | (Index Error Warning)" + ) + + last_split = split + + +if __name__ == "__main__": + main() diff --git a/graph_net/torch/rp_expr/longest_rp_expr_parser.py b/graph_net/torch/rp_expr/longest_rp_expr_parser.py index c6c6b17f9..42e3bf832 100644 --- a/graph_net/torch/rp_expr/longest_rp_expr_parser.py +++ b/graph_net/torch/rp_expr/longest_rp_expr_parser.py @@ -30,6 +30,8 @@ def __call__(self, primitive_id_lists: t.List[t.List[PrimitiveId]]): token_id2primitive_id ) ] + if not cur_primitive_id_lists: + continue cur_lets_list_rp_expr, cur_token_id2primitive_id = rp_expr_parser( cur_primitive_id_lists ) diff --git a/graph_net/torch/rp_expr/rp_expr.py b/graph_net/torch/rp_expr/rp_expr.py index 679254bb6..7fcd7a1f6 100644 --- a/graph_net/torch/rp_expr/rp_expr.py +++ b/graph_net/torch/rp_expr/rp_expr.py @@ -369,7 +369,7 @@ def get_range(size): segments = [ token_tensor[start:end] for consecutive_tensor in consecutive_tensors - for start, end in [get_range(int(consecutive_tensor.size(0)))] + for start, end in [get_range(len(consecutive_tensor))] ] return segments From b59dc338afde95e83d7d64c6f6b4796f1897cf29 Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Mon, 24 Nov 2025 16:57:09 +0800 Subject: [PATCH 2/6] add backend --- .../torch/backend/range_decomposer_backend.py | 337 ++++++++++++++++++ graph_net/torch/test_compiler.py | 11 +- 2 files changed, 346 insertions(+), 2 deletions(-) create mode 100644 graph_net/torch/backend/range_decomposer_backend.py diff --git a/graph_net/torch/backend/range_decomposer_backend.py b/graph_net/torch/backend/range_decomposer_backend.py new file mode 100644 index 000000000..12a0f78da --- /dev/null +++ b/graph_net/torch/backend/range_decomposer_backend.py @@ -0,0 +1,337 @@ +import argparse +import base64 +import importlib.util +import inspect +import itertools +import json +import os +import subprocess +import sys +from pathlib import Path +from typing import Any, Callable, Dict, List, Tuple + +import torch +import torch.nn as nn + +import graph_net +from graph_net.torch import utils as graph_utils +from graph_net.torch.rp_expr.longest_rp_expr_parser import LongestRpExprParser +from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser + + +def encode_config(config: Dict[str, Any]) -> str: + json_str = json.dumps(config) + return base64.b64encode(json_str.encode("utf-8")).decode("utf-8") + + +class GraphExtractor: + def __init__(self): + self.extract_node = [] + + def _extract_operators_from_graph( + self, gm: nn.Module, example_inputs: List[torch.Tensor] = None + ) -> List[Dict[str, Any]]: + operator_list = [] + named_modules = dict(gm.named_modules()) + + for node in gm.graph.nodes: + if node.op in ("call_method", "call_function", "call_module"): + target_name = str(node.target) + + if node.op == "call_module": + module_instance = named_modules.get(node.target) + if module_instance is not None: + target_name = type(module_instance).__name__ + elif node.op == "call_function": + if isinstance(node.target, Callable): + target_name = node.target.__name__ + elif node.op == "call_method": + target_name = str(node.target) + + operator_info = { + "op_type": node.op, + "target": node.target, + "name": node.name, + "target_name": target_name, + } + operator_list.append(operator_info) + + return operator_list + + def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]): + operator = self._extract_operators_from_graph(gm, inputs) + self.extract_node = operator + return gm.forward + + +class ModelLoader: + def load_class_from_file(self, model_path: str, device: str) -> Any: + file_path = os.path.join(model_path, "model.py") + file = Path(file_path).resolve() + module_name = file.stem + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Model file not found: {file_path}") + + with open(file_path, "r", encoding="utf-8") as f: + model_code = f.read() + + model_code = graph_utils.modify_code_by_device(model_code, device) + + spec = importlib.util.spec_from_loader(module_name, loader=None) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + compiled_code = compile(model_code, filename=file, mode="exec") + exec(compiled_code, module.__dict__) + + model_class = getattr(module, "GraphModule", None) + if model_class is None: + raise ImportError(f"Class 'GraphModule' not found in {file_path}") + + return model_class + + def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]: + inputs_params = graph_utils.load_converted_from_text(f"{model_path}") + params = inputs_params["weight_info"] + for tensor_meta in params.values(): + if hasattr(tensor_meta, "device"): + tensor_meta.device = device + input_dict = { + k: graph_utils.replay_tensor(v).to(torch.device(device)) + for k, v in params.items() + } + return input_dict + + +class RangeDecomposerBackend: + def __init__(self): + self.window_size = 10 + self.graph_net_root = Path(graph_net.__file__).parent + self.workspace_root = Path.cwd() / "naive_decompose_workspace" + + def _resolve_token_to_ops( + self, tid, num_primitives, token_id2primitive_id, symbol_map + ) -> List[str]: + if tid < num_primitives: + return [token_id2primitive_id[tid]] + if tid in symbol_map: + sub_tokens = symbol_map[tid].tolist() + ops = [] + for t in sub_tokens: + ops.extend( + self._resolve_token_to_ops( + t, num_primitives, token_id2primitive_id, symbol_map + ) + ) + return ops + return [f"Unknown({tid})"] + + def _extract_ops_via_compile( + self, model_path: str, device: str = "cpu" + ) -> List[str]: + loader = ModelLoader() + print(f"Loading model from {model_path} on {device}...") + try: + model_class = loader.load_class_from_file(model_path, device) + model = model_class().to(torch.device(device)) + model.eval() + input_dict = loader.get_input_dict(model_path, device) + except Exception as e: + print(f"Error loading/preparing model {model_path}: {e}") + return [] + + extractor = GraphExtractor() + compiled_model = torch.compile(model, backend=extractor.extract_compiler) + + with torch.no_grad(): + compiled_model(**input_dict) + + ops_info = extractor.extract_node + if not ops_info: + print(f"Warning: No operators extracted from {model_path}.") + return [] + return [op["target_name"] for op in ops_info] + + def _calculate_token_lengths( + self, rp_expr, num_primitives, symbol_map + ) -> Dict[int, int]: + token2len = {} + + def get_len(tid): + if tid in token2len: + return token2len[tid] + if tid < num_primitives: + token2len[tid] = 1 + return 1 + if tid in symbol_map: + sub_tokens = symbol_map[tid].tolist() + length = sum(get_len(t) for t in sub_tokens) + token2len[tid] = length + return length + token2len[tid] = 1 + return 1 + + for sym_id in rp_expr.symbol_token_ids: + get_len(sym_id) + return token2len + + def _analyze_and_get_splits(self, args) -> Dict[str, Dict]: + input_file = Path(args.model_path) + if not input_file.exists(): + print(f"Error: Input file {input_file} does not exist.") + return {} + + with open(input_file, "r") as f: + model_paths = [ + Path(line.strip()) + for line in f + if line.strip() and not line.startswith("#") + ] + + if not model_paths: + print("No valid model paths found.") + return {} + + inputs_seqs = [] + valid_models = [] + + for p in model_paths: + seq = self._extract_ops_via_compile(p, args.device) + if seq: + inputs_seqs.append(seq) + valid_models.append((p.name, p)) + + if not inputs_seqs: + return {} + + rp_parser = RpExprParser( + window_size=self.window_size, fold_policy="default", fold_times=0 + ) + rp_expr, token_id2primitive_id = rp_parser(inputs_seqs) + rp_expr.try_unwrap_body_of_sole_symbol_token() + rp_expr.try_recursive_inline_symbol_sole_used(token_id2primitive_id) + num_primitives = len(token_id2primitive_id) + symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors)) + token2len = self._calculate_token_lengths(rp_expr, num_primitives, symbol_map) + + results = {} + + for i, (model_name, original_path) in enumerate(valid_models): + if i >= len(rp_expr.body_rp_expr): + break + + target_body_tensor = rp_expr.body_rp_expr[i] + seq_tokens = target_body_tensor.tolist() + + full_model_ops = [] + for t in seq_tokens: + full_model_ops.extend( + self._resolve_token_to_ops( + t, num_primitives, token_id2primitive_id, symbol_map + ) + ) + + current_idx = 0 + split_points_set = set() + total_len = sum(token2len.get(t, 1) for t in seq_tokens) + + for token_id in seq_tokens: + length = token2len.get(token_id, 1) + is_pattern = token_id >= num_primitives + + if is_pattern: + if current_idx > 0: + split_points_set.add(current_idx) + end_idx = current_idx + length + if end_idx < total_len: + split_points_set.add(end_idx) + + current_idx += length + + sorted_splits = sorted(list(split_points_set)) + + self._print_analysis( + model_name, original_path, sorted_splits, total_len, full_model_ops + ) + + results[model_name] = { + "path": str(original_path), + "split_points": sorted_splits, + } + + return results + + def _print_analysis(self, name, path, splits, total_len, full_ops): + print("=" * 60) + print(f"Model: {name}") + print(f"Path: {path}") + print(f"Splits: {splits}") + print("-" * 60) + + last_split = 0 + for split in splits + [total_len]: + segment_len = split - last_split + + start_safe = min(last_split, len(full_ops)) + end_safe = min(split, len(full_ops)) + segment_ops = full_ops[start_safe:end_safe] + + ops_display = str(segment_ops) + if len(segment_ops) > 5: + ops_display = f"[{segment_ops[0]}, ..., {segment_ops[-1]}]" + + print( + f" Range [{last_split:3d}, {split:3d}), Len: {segment_len:3d} | Ops: {ops_display}" + ) + last_split = split + print("\n") + + def __call__(self, args): + model_data_map = self._analyze_and_get_splits(args) + + for model_name, info in model_data_map.items(): + model_path = info["path"] + split_points = info["split_points"] + + model_output_dir = self.workspace_root / f"{model_name}_decomposed" + model_output_dir.mkdir(parents=True, exist_ok=True) + + config_dict = { + "decorator_path": str(self.graph_net_root / "torch/extractor.py"), + "decorator_config": { + "name": model_name, + "custom_extractor_path": str( + self.graph_net_root / "torch/naive_graph_decomposer.py" + ), + "custom_extractor_config": { + "output_dir": str(model_output_dir), + "split_positions": split_points, + "group_head_and_tail": True, + "filter_path": str( + self.graph_net_root / "torch/naive_subgraph_filter.py" + ), + "filter_config": {}, + }, + }, + } + + encoded_config = encode_config(config_dict) + + cmd = [ + sys.executable, + "-m", + "graph_net.torch.run_model", + "--model-path", + model_path, + "--decorator-config", + encoded_config, + ] + + try: + subprocess.run(cmd, check=True) + print(f" [Success] Saved to {model_output_dir}") + except subprocess.CalledProcessError as e: + print(f" [Error] Process failed: {e}") + except Exception as e: + print(f" [Error] Unexpected: {e}") diff --git a/graph_net/torch/test_compiler.py b/graph_net/torch/test_compiler.py index ff2eea94f..86bda9b02 100644 --- a/graph_net/torch/test_compiler.py +++ b/graph_net/torch/test_compiler.py @@ -23,6 +23,7 @@ from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend from graph_net.torch.backend.nope_backend import NopeBackend from graph_net.torch.backend.unstable_to_stable_backend import UnstableToStableBackend +from graph_net.torch.backend.range_decomposer_backend import RangeDecomposerBackend from graph_net.torch.backend.range_decomposer_validator_backend import ( RangeDecomposerValidatorBackend, ) @@ -39,6 +40,7 @@ "bladedisc": BladeDISCBackend(), "nope": NopeBackend(), "unstable_to_stable": UnstableToStableBackend(), + "range_decomposer": RangeDecomposerBackend(), "range_decomposer_validator": RangeDecomposerValidatorBackend(), } @@ -385,11 +387,16 @@ def test_multi_models(args): def main(args): - assert os.path.isdir(args.model_path) - initalize_seed = 123 set_seed(random_seed=initalize_seed) + if args.compiler == "range_decomposer": + compiler = get_compiler_backend(args) + compiler(args) + return + + assert os.path.isdir(args.model_path) + if path_utils.is_single_model_dir(args.model_path): test_single_model(args) else: From a47d29af5a0cc3c1a87b81f513818b65741e6584 Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Fri, 28 Nov 2025 14:08:03 +0800 Subject: [PATCH 3/6] split --- .../torch/backend/range_decomposer_backend.py | 258 --------------- .../torch/typical_sequence_split_points.py | 295 ++++++++++++++++++ 2 files changed, 295 insertions(+), 258 deletions(-) create mode 100644 graph_net/torch/typical_sequence_split_points.py diff --git a/graph_net/torch/backend/range_decomposer_backend.py b/graph_net/torch/backend/range_decomposer_backend.py index 12a0f78da..3229c637f 100644 --- a/graph_net/torch/backend/range_decomposer_backend.py +++ b/graph_net/torch/backend/range_decomposer_backend.py @@ -24,269 +24,11 @@ def encode_config(config: Dict[str, Any]) -> str: return base64.b64encode(json_str.encode("utf-8")).decode("utf-8") -class GraphExtractor: - def __init__(self): - self.extract_node = [] - - def _extract_operators_from_graph( - self, gm: nn.Module, example_inputs: List[torch.Tensor] = None - ) -> List[Dict[str, Any]]: - operator_list = [] - named_modules = dict(gm.named_modules()) - - for node in gm.graph.nodes: - if node.op in ("call_method", "call_function", "call_module"): - target_name = str(node.target) - - if node.op == "call_module": - module_instance = named_modules.get(node.target) - if module_instance is not None: - target_name = type(module_instance).__name__ - elif node.op == "call_function": - if isinstance(node.target, Callable): - target_name = node.target.__name__ - elif node.op == "call_method": - target_name = str(node.target) - - operator_info = { - "op_type": node.op, - "target": node.target, - "name": node.name, - "target_name": target_name, - } - operator_list.append(operator_info) - - return operator_list - - def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]): - operator = self._extract_operators_from_graph(gm, inputs) - self.extract_node = operator - return gm.forward - - -class ModelLoader: - def load_class_from_file(self, model_path: str, device: str) -> Any: - file_path = os.path.join(model_path, "model.py") - file = Path(file_path).resolve() - module_name = file.stem - - if not os.path.exists(file_path): - raise FileNotFoundError(f"Model file not found: {file_path}") - - with open(file_path, "r", encoding="utf-8") as f: - model_code = f.read() - - model_code = graph_utils.modify_code_by_device(model_code, device) - - spec = importlib.util.spec_from_loader(module_name, loader=None) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - - compiled_code = compile(model_code, filename=file, mode="exec") - exec(compiled_code, module.__dict__) - - model_class = getattr(module, "GraphModule", None) - if model_class is None: - raise ImportError(f"Class 'GraphModule' not found in {file_path}") - - return model_class - - def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]: - inputs_params = graph_utils.load_converted_from_text(f"{model_path}") - params = inputs_params["weight_info"] - for tensor_meta in params.values(): - if hasattr(tensor_meta, "device"): - tensor_meta.device = device - input_dict = { - k: graph_utils.replay_tensor(v).to(torch.device(device)) - for k, v in params.items() - } - return input_dict - - class RangeDecomposerBackend: def __init__(self): - self.window_size = 10 self.graph_net_root = Path(graph_net.__file__).parent self.workspace_root = Path.cwd() / "naive_decompose_workspace" - def _resolve_token_to_ops( - self, tid, num_primitives, token_id2primitive_id, symbol_map - ) -> List[str]: - if tid < num_primitives: - return [token_id2primitive_id[tid]] - if tid in symbol_map: - sub_tokens = symbol_map[tid].tolist() - ops = [] - for t in sub_tokens: - ops.extend( - self._resolve_token_to_ops( - t, num_primitives, token_id2primitive_id, symbol_map - ) - ) - return ops - return [f"Unknown({tid})"] - - def _extract_ops_via_compile( - self, model_path: str, device: str = "cpu" - ) -> List[str]: - loader = ModelLoader() - print(f"Loading model from {model_path} on {device}...") - try: - model_class = loader.load_class_from_file(model_path, device) - model = model_class().to(torch.device(device)) - model.eval() - input_dict = loader.get_input_dict(model_path, device) - except Exception as e: - print(f"Error loading/preparing model {model_path}: {e}") - return [] - - extractor = GraphExtractor() - compiled_model = torch.compile(model, backend=extractor.extract_compiler) - - with torch.no_grad(): - compiled_model(**input_dict) - - ops_info = extractor.extract_node - if not ops_info: - print(f"Warning: No operators extracted from {model_path}.") - return [] - return [op["target_name"] for op in ops_info] - - def _calculate_token_lengths( - self, rp_expr, num_primitives, symbol_map - ) -> Dict[int, int]: - token2len = {} - - def get_len(tid): - if tid in token2len: - return token2len[tid] - if tid < num_primitives: - token2len[tid] = 1 - return 1 - if tid in symbol_map: - sub_tokens = symbol_map[tid].tolist() - length = sum(get_len(t) for t in sub_tokens) - token2len[tid] = length - return length - token2len[tid] = 1 - return 1 - - for sym_id in rp_expr.symbol_token_ids: - get_len(sym_id) - return token2len - - def _analyze_and_get_splits(self, args) -> Dict[str, Dict]: - input_file = Path(args.model_path) - if not input_file.exists(): - print(f"Error: Input file {input_file} does not exist.") - return {} - - with open(input_file, "r") as f: - model_paths = [ - Path(line.strip()) - for line in f - if line.strip() and not line.startswith("#") - ] - - if not model_paths: - print("No valid model paths found.") - return {} - - inputs_seqs = [] - valid_models = [] - - for p in model_paths: - seq = self._extract_ops_via_compile(p, args.device) - if seq: - inputs_seqs.append(seq) - valid_models.append((p.name, p)) - - if not inputs_seqs: - return {} - - rp_parser = RpExprParser( - window_size=self.window_size, fold_policy="default", fold_times=0 - ) - rp_expr, token_id2primitive_id = rp_parser(inputs_seqs) - rp_expr.try_unwrap_body_of_sole_symbol_token() - rp_expr.try_recursive_inline_symbol_sole_used(token_id2primitive_id) - num_primitives = len(token_id2primitive_id) - symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors)) - token2len = self._calculate_token_lengths(rp_expr, num_primitives, symbol_map) - - results = {} - - for i, (model_name, original_path) in enumerate(valid_models): - if i >= len(rp_expr.body_rp_expr): - break - - target_body_tensor = rp_expr.body_rp_expr[i] - seq_tokens = target_body_tensor.tolist() - - full_model_ops = [] - for t in seq_tokens: - full_model_ops.extend( - self._resolve_token_to_ops( - t, num_primitives, token_id2primitive_id, symbol_map - ) - ) - - current_idx = 0 - split_points_set = set() - total_len = sum(token2len.get(t, 1) for t in seq_tokens) - - for token_id in seq_tokens: - length = token2len.get(token_id, 1) - is_pattern = token_id >= num_primitives - - if is_pattern: - if current_idx > 0: - split_points_set.add(current_idx) - end_idx = current_idx + length - if end_idx < total_len: - split_points_set.add(end_idx) - - current_idx += length - - sorted_splits = sorted(list(split_points_set)) - - self._print_analysis( - model_name, original_path, sorted_splits, total_len, full_model_ops - ) - - results[model_name] = { - "path": str(original_path), - "split_points": sorted_splits, - } - - return results - - def _print_analysis(self, name, path, splits, total_len, full_ops): - print("=" * 60) - print(f"Model: {name}") - print(f"Path: {path}") - print(f"Splits: {splits}") - print("-" * 60) - - last_split = 0 - for split in splits + [total_len]: - segment_len = split - last_split - - start_safe = min(last_split, len(full_ops)) - end_safe = min(split, len(full_ops)) - segment_ops = full_ops[start_safe:end_safe] - - ops_display = str(segment_ops) - if len(segment_ops) > 5: - ops_display = f"[{segment_ops[0]}, ..., {segment_ops[-1]}]" - - print( - f" Range [{last_split:3d}, {split:3d}), Len: {segment_len:3d} | Ops: {ops_display}" - ) - last_split = split - print("\n") - def __call__(self, args): model_data_map = self._analyze_and_get_splits(args) diff --git a/graph_net/torch/typical_sequence_split_points.py b/graph_net/torch/typical_sequence_split_points.py new file mode 100644 index 000000000..17f97bb9b --- /dev/null +++ b/graph_net/torch/typical_sequence_split_points.py @@ -0,0 +1,295 @@ +import argparse +import importlib.util +import json +import os +import sys +from pathlib import Path +from typing import Any, Callable, Dict, List + +import torch +import torch.nn as nn +import tempfile +import graph_net.imp_util +from graph_net.torch import utils as graph_utils +from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser + + +class TypicalSequenceExtractor: + def __init__(self): + self.extract_node = [] + + def _extract_operators_from_graph( + self, gm: nn.Module, example_inputs: List[torch.Tensor] = None + ) -> List[Dict[str, Any]]: + operator_list = [] + named_modules = dict(gm.named_modules()) + + for node in gm.graph.nodes: + if node.op not in ("call_method", "call_function", "call_module"): + continue + + if node.op == "call_module": + target_name = type(named_modules[node.target]).__name__ + else: + target_name = getattr(node.target, "__name__", str(node.target)) + + operator_list.append( + { + "op_type": node.op, + "target": node.target, + "name": node.name, + "target_name": target_name, + } + ) + + return operator_list + + def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]): + operator = self._extract_operators_from_graph(gm, inputs) + self.extract_node = operator + return gm.forward + + +class TypicalSequenceModelLoader: + def load_class_from_file(self, model_path: str, device: str) -> Any: + file_path = os.path.join(model_path, "model.py") + file = Path(file_path).resolve() + module_name = file.stem + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Model file not found: {file_path}") + + with open(file_path, "r", encoding="utf-8") as f: + model_code = f.read() + model_code = graph_utils.modify_code_by_device(model_code, device) + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", encoding="utf-8" + ) as temp_file: + temp_file.write(model_code) + module = graph_net.imp_util.load_module(temp_file.name) + model_class = getattr(module, "GraphModule", None) + + return model_class + + def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]: + inputs_params = graph_utils.load_converted_from_text(f"{model_path}") + params = inputs_params["weight_info"] + for tensor_meta in params.values(): + if hasattr(tensor_meta, "device"): + tensor_meta.device = device + input_dict = { + k: graph_utils.replay_tensor(v).to(torch.device(device)) + for k, v in params.items() + } + return input_dict + + +class SplitAnalyzer: + def __init__(self, window_size: int = 10): + self.window_size = window_size + + def _resolve_token_to_ops( + self, tid, num_primitives, token_id2primitive_id, symbol_map + ) -> List[str]: + if tid < num_primitives: + return [token_id2primitive_id[tid]] + if tid in symbol_map: + sub_tokens = symbol_map[tid].tolist() + ops = [] + for t in sub_tokens: + ops.extend( + self._resolve_token_to_ops( + t, num_primitives, token_id2primitive_id, symbol_map + ) + ) + return ops + return [f"Unknown({tid})"] + + def _extract_ops_via_compile( + self, model_path: str, device: str = "cpu" + ) -> List[str]: + loader = TypicalSequenceModelLoader() + print(f"Loading model from {model_path} on {device}...") + try: + model_class = loader.load_class_from_file(model_path, device) + model = model_class().to(torch.device(device)) + model.eval() + input_dict = loader.get_input_dict(model_path, device) + except Exception as e: + print(f"Error loading/preparing model {model_path}: {e}") + return [] + + extractor = TypicalSequenceExtractor() + compiled_model = torch.compile(model, backend=extractor.extract_compiler) + compiled_model(**input_dict) + ops_info = extractor.extract_node + + return [op["target_name"] for op in ops_info] + + def _calculate_token_lengths( + self, rp_expr, num_primitives, symbol_map + ) -> Dict[int, int]: + token2len = {} + + def get_len(tid): + if tid in token2len: + return token2len[tid] + if tid < num_primitives: + token2len[tid] = 1 + return 1 + if tid in symbol_map: + sub_tokens = symbol_map[tid].tolist() + length = sum(get_len(t) for t in sub_tokens) + token2len[tid] = length + return length + token2len[tid] = 1 + return 1 + + for sym_id in rp_expr.symbol_token_ids: + get_len(sym_id) + return token2len + + def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]: + input_file = Path(model_paths_file) + + with open(input_file, "r") as f: + model_paths = [ + Path(line.strip()) + for line in f + if line.strip() and not line.startswith("#") + ] + + inputs_seqs = [] + valid_models = [] + + for p in model_paths: + seq = self._extract_ops_via_compile(str(p), device) + if seq: + inputs_seqs.append(seq) + valid_models.append((p.name, p)) + + if not inputs_seqs: + return {} + + rp_parser = RpExprParser( + window_size=self.window_size, fold_policy="default", fold_times=0 + ) + rp_expr, token_id2primitive_id = rp_parser(inputs_seqs) + rp_expr.try_unwrap_body_of_sole_symbol_token() + rp_expr.try_recursive_inline_symbol_sole_used(token_id2primitive_id) + + num_primitives = len(token_id2primitive_id) + symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors)) + token2len = self._calculate_token_lengths(rp_expr, num_primitives, symbol_map) + + results = {} + + for i, (model_name, original_path) in enumerate(valid_models): + if i >= len(rp_expr.body_rp_expr): + break + + target_body_tensor = rp_expr.body_rp_expr[i] + seq_tokens = target_body_tensor.tolist() + + full_model_ops = [] + for t in seq_tokens: + full_model_ops.extend( + self._resolve_token_to_ops( + t, num_primitives, token_id2primitive_id, symbol_map + ) + ) + + current_idx = 0 + split_points_set = set() + total_len = sum(token2len.get(t, 1) for t in seq_tokens) + + for token_id in seq_tokens: + length = token2len.get(token_id, 1) + is_pattern = token_id >= num_primitives + + if is_pattern: + if current_idx > 0: + split_points_set.add(current_idx) + end_idx = current_idx + length + if end_idx < total_len: + split_points_set.add(end_idx) + + current_idx += length + + sorted_splits = sorted(list(split_points_set)) + + self._print_analysis( + model_name, str(original_path), sorted_splits, total_len, full_model_ops + ) + + results[model_name] = { + "path": str(original_path), + "split_points": sorted_splits, + "total_length": total_len, + } + + return results + + def _print_analysis(self, name, path, splits, total_len, full_ops): + print("=" * 60) + print(f"Model: {name}") + print(f"Path: {path}") + print(f"Splits: {splits}") + print("-" * 60) + + last_split = 0 + for split in splits + [total_len]: + segment_len = split - last_split + start_safe = min(last_split, len(full_ops)) + end_safe = min(split, len(full_ops)) + segment_ops = full_ops[start_safe:end_safe] + + ops_display = str(segment_ops) + if len(segment_ops) > 5: + ops_display = f"[{segment_ops[0]}, ..., {segment_ops[-1]}]" + + print( + f"Range [{last_split:3d}, {split:3d}), Len: {segment_len:3d} | Ops: {ops_display}" + ) + last_split = split + print("\n") + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze graph and calculate split points." + ) + parser.add_argument( + "--model-list", + type=str, + required=True, + help="Path to a text file containing paths to models (one per line).", + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="Device to load models on (cpu, cuda).", + ) + parser.add_argument( + "--window-size", type=int, default=10, help="Window size for RP Parser." + ) + parser.add_argument( + "--output-json", + type=str, + default="split_results.json", + help="Path to save the analysis results in JSON format.", + ) + args = parser.parse_args() + + analyzer = SplitAnalyzer(window_size=args.window_size) + results = analyzer.analyze(args.model_list, args.device) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + main() From e5abd5f02fcae37056c01253b4f4d039335dc0c6 Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Fri, 28 Nov 2025 19:00:01 +0800 Subject: [PATCH 4/6] fix --- .../test/typical_sequence_decomposer_test.sh | 46 +++++++ .../torch/backend/range_decomposer_backend.py | 125 ++++++++++-------- graph_net/torch/test_compiler.py | 21 +-- .../torch/typical_sequence_split_points.py | 4 - 4 files changed, 129 insertions(+), 67 deletions(-) create mode 100644 graph_net/test/typical_sequence_decomposer_test.sh diff --git a/graph_net/test/typical_sequence_decomposer_test.sh b/graph_net/test/typical_sequence_decomposer_test.sh new file mode 100644 index 000000000..c32507f4f --- /dev/null +++ b/graph_net/test/typical_sequence_decomposer_test.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))") + +MODEL1="$GRAPH_NET_ROOT/samples/torchvision/resnet18" +MODEL2="$GRAPH_NET_ROOT/samples/torchvision/resnet34" +MODEL_LIST_FILE=$(mktemp) +echo "$MODEL1" > "$MODEL_LIST_FILE" +echo "$MODEL2" >> "$MODEL_LIST_FILE" + +python3 -m graph_net.torch.typical_sequence_split_points \ + --model-list "$MODEL_LIST_FILE" \ + --device "cuda" \ + --window-size 10 \ + --output-json "$GRAPH_NET_ROOT/split_results.json" + +rm -f "$MODEL_LIST_FILE" + + +MODEL_PATH_IN_SAMPLES=/torchvision/resnet18 +MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES") + +decomposer_config_json_str=$(cat < "$DECOMPOSE_PATH/log.log" 2>&1 + +python3 -m graph_net.plot_ESt \ + --benchmark-path $DECOMPOSE_PATH/log.log \ + --output-dir $DECOMPOSE_PATH \ \ No newline at end of file diff --git a/graph_net/torch/backend/range_decomposer_backend.py b/graph_net/torch/backend/range_decomposer_backend.py index 3229c637f..5a410c7b2 100644 --- a/graph_net/torch/backend/range_decomposer_backend.py +++ b/graph_net/torch/backend/range_decomposer_backend.py @@ -1,22 +1,21 @@ -import argparse import base64 -import importlib.util -import inspect -import itertools import json -import os import subprocess import sys from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Dict import torch -import torch.nn as nn - import graph_net -from graph_net.torch import utils as graph_utils -from graph_net.torch.rp_expr.longest_rp_expr_parser import LongestRpExprParser -from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser + + +def convert_to_dict(config_str): + if config_str is None: + return {} + config_str = base64.b64decode(config_str).decode("utf-8") + config = json.loads(config_str) + assert isinstance(config, dict), f"config should be a dict. {config_str=}" + return config def encode_config(config: Dict[str, Any]) -> str: @@ -24,56 +23,72 @@ def encode_config(config: Dict[str, Any]) -> str: return base64.b64encode(json_str.encode("utf-8")).decode("utf-8") +def load_json(file_path): + with open(file_path, "r", encoding="utf-8") as file: + data_dict = json.load(file) + return data_dict + + class RangeDecomposerBackend: def __init__(self): self.graph_net_root = Path(graph_net.__file__).parent - self.workspace_root = Path.cwd() / "naive_decompose_workspace" - def __call__(self, args): - model_data_map = self._analyze_and_get_splits(args) + def __call__(self, model: torch.nn.Module) -> torch.nn.Module: + config = convert_to_dict(self.config) + workspace_path = Path(config["workspace_path"]) + chain_style = config["chain_style"] - for model_name, info in model_data_map.items(): - model_path = info["path"] - split_points = info["split_points"] + model_file_path = Path(model.__class__.__graph_net_file_path__) + model_name = model_file_path.parent.name - model_output_dir = self.workspace_root / f"{model_name}_decomposed" - model_output_dir.mkdir(parents=True, exist_ok=True) + model_info = load_json(config["split_results_path"])[model_name] + model_path = model_info["path"] + split_points = model_info["split_points"] - config_dict = { - "decorator_path": str(self.graph_net_root / "torch/extractor.py"), - "decorator_config": { - "name": model_name, - "custom_extractor_path": str( - self.graph_net_root / "torch/naive_graph_decomposer.py" + model_output_dir = workspace_path / f"{model_name}_decomposed" + model_output_dir.mkdir(parents=True, exist_ok=True) + + config_dict = { + "decorator_path": str(self.graph_net_root / "torch/extractor.py"), + "decorator_config": { + "name": model_name, + "custom_extractor_path": str( + self.graph_net_root / "torch/naive_graph_decomposer.py" + ), + "custom_extractor_config": { + "output_dir": str(model_output_dir), + "split_positions": split_points, + "group_head_and_tail": True, + "filter_path": str( + self.graph_net_root / "torch/naive_subgraph_filter.py" ), - "custom_extractor_config": { - "output_dir": str(model_output_dir), - "split_positions": split_points, - "group_head_and_tail": True, - "filter_path": str( - self.graph_net_root / "torch/naive_subgraph_filter.py" - ), - "filter_config": {}, - }, + "filter_config": {}, + "chain_style": chain_style, }, - } - - encoded_config = encode_config(config_dict) - - cmd = [ - sys.executable, - "-m", - "graph_net.torch.run_model", - "--model-path", - model_path, - "--decorator-config", - encoded_config, - ] - - try: - subprocess.run(cmd, check=True) - print(f" [Success] Saved to {model_output_dir}") - except subprocess.CalledProcessError as e: - print(f" [Error] Process failed: {e}") - except Exception as e: - print(f" [Error] Unexpected: {e}") + }, + } + + encoded_config = encode_config(config_dict) + + cmd = [ + sys.executable, + "-m", + "graph_net.torch.run_model", + "--model-path", + model_path, + "--decorator-config", + encoded_config, + ] + + try: + subprocess.run(cmd, check=True) + print(f"[Success] Saved to {model_output_dir}") + except subprocess.CalledProcessError as e: + print(f"[Error] Process failed: {e}") + except Exception as e: + print(f"[Error] Unexpected: {e}") + return model + + def synchronize(self): + if torch.cuda.is_available(): + torch.cuda.synchronize() diff --git a/graph_net/torch/test_compiler.py b/graph_net/torch/test_compiler.py index 5561c1ba9..f06a13bdf 100644 --- a/graph_net/torch/test_compiler.py +++ b/graph_net/torch/test_compiler.py @@ -96,7 +96,10 @@ def load_class_from_file( def get_compiler_backend(args) -> GraphCompilerBackend: assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}" - return registry_backend[args.compiler] + backend = registry_backend[args.compiler] + if args.config is not None: + backend.config = args.config + return backend def get_model(args): @@ -396,16 +399,11 @@ def test_multi_models(args): def main(args): + assert os.path.isdir(args.model_path) + initalize_seed = 123 set_seed(random_seed=initalize_seed) - if args.compiler == "range_decomposer": - compiler = get_compiler_backend(args) - compiler(args) - return - - assert os.path.isdir(args.model_path) - if path_utils.is_single_model_dir(args.model_path): test_single_model(args) else: @@ -454,5 +452,12 @@ def main(args): default=None, help="Path to samples list, each line contains a sample path", ) + parser.add_argument( + "--config", + type=str, + required=False, + default=None, + help="Path to configuration file.", + ) args = parser.parse_args() main(args=args) diff --git a/graph_net/torch/typical_sequence_split_points.py b/graph_net/torch/typical_sequence_split_points.py index 17f97bb9b..74dae9e28 100644 --- a/graph_net/torch/typical_sequence_split_points.py +++ b/graph_net/torch/typical_sequence_split_points.py @@ -1,8 +1,6 @@ import argparse -import importlib.util import json import os -import sys from pathlib import Path from typing import Any, Callable, Dict, List @@ -53,8 +51,6 @@ def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]) class TypicalSequenceModelLoader: def load_class_from_file(self, model_path: str, device: str) -> Any: file_path = os.path.join(model_path, "model.py") - file = Path(file_path).resolve() - module_name = file.stem if not os.path.exists(file_path): raise FileNotFoundError(f"Model file not found: {file_path}") From fbdf5724e08c5571a808b1de1ab8ae2ad176799e Mon Sep 17 00:00:00 2001 From: fangfangssj <99968055+fangfangssj@users.noreply.github.com> Date: Fri, 28 Nov 2025 19:10:25 +0800 Subject: [PATCH 5/6] Delete graph_net/test/split_points.py --- graph_net/test/split_points.py | 250 --------------------------------- 1 file changed, 250 deletions(-) delete mode 100644 graph_net/test/split_points.py diff --git a/graph_net/test/split_points.py b/graph_net/test/split_points.py deleted file mode 100644 index 24236e966..000000000 --- a/graph_net/test/split_points.py +++ /dev/null @@ -1,250 +0,0 @@ -import sys -import os -import argparse -import importlib.util -import torch -import torch.nn as nn -from pathlib import Path -from typing import List, Dict, Any, Callable -from graph_net.torch import utils as graph_utils -from graph_net.torch.rp_expr.longest_rp_expr_parser import LongestRpExprParser -from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser - - -class GraphExtractor: - def __init__(self): - self.extract_node = [] - - def _extract_operators_from_graph( - self, gm: nn.Module, example_inputs: List[torch.Tensor] = None - ) -> List[Dict[str, Any]]: - operator_list = [] - named_modules = dict(gm.named_modules()) - - for node in gm.graph.nodes: - if node.op in ("call_method", "call_function", "call_module"): - target_name = str(node.target) - - if node.op == "call_module": - module_instance = named_modules.get(node.target) - if module_instance is not None: - target_name = type(module_instance).__name__ - elif node.op == "call_function": - if isinstance(node.target, Callable): - target_name = node.target.__name__ - elif node.op == "call_method": - target_name = str(node.target) - - operator_info = { - "op_type": node.op, - "target": node.target, - "name": node.name, - "target_name": target_name, - } - operator_list.append(operator_info) - - return operator_list - - def extract_compiler(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]): - operator = self._extract_operators_from_graph(gm, inputs) - self.extract_node = operator - return gm.forward - - -class ModelLoader: - def load_class_from_file(self, model_path: str, device: str) -> Any: - file_path = os.path.join(model_path, "model.py") - file = Path(file_path).resolve() - module_name = file.stem - - if not os.path.exists(file_path): - raise FileNotFoundError(f"Model file not found: {file_path}") - - with open(file_path, "r", encoding="utf-8") as f: - model_code = f.read() - - model_code = graph_utils.modify_code_by_device(model_code, device) - - spec = importlib.util.spec_from_loader(module_name, loader=None) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - - compiled_code = compile(model_code, filename=file, mode="exec") - exec(compiled_code, module.__dict__) - - model_class = getattr(module, "GraphModule", None) - if model_class is None: - raise ImportError(f"Class 'GraphModule' not found in {file_path}") - - setattr(model_class, "__graph_net_file_path__", str(file_path)) - setattr(model_class, "__graph_net_device__", device) - return model_class - - def get_input_dict(self, model_path: str, device: str) -> Dict[str, torch.Tensor]: - inputs_params = graph_utils.load_converted_from_text(f"{model_path}") - params = inputs_params["weight_info"] - for tensor_meta in params.values(): - if hasattr(tensor_meta, "device"): - tensor_meta.device = device - input_dict = { - k: graph_utils.replay_tensor(v).to(torch.device(device)) - for k, v in params.items() - } - return input_dict - - -def extract_ops_via_compile(model_path: str, device: str = "cpu") -> List[str]: - loader = ModelLoader() - print(f"Loading model from {model_path} on {device}...") - try: - model_class = loader.load_class_from_file(model_path, device) - model = model_class().to(torch.device(device)) - model.eval() - input_dict = loader.get_input_dict(model_path, device) - except Exception as e: - print(f"Error loading/preparing model {model_path}: {e}") - return [] - - extractor = GraphExtractor() - compiled_model = torch.compile(model, backend=extractor.extract_compiler) - - with torch.no_grad(): - compiled_model(**input_dict) - - ops_info = extractor.extract_node - if not ops_info: - print(f"Warning: No operators extracted from {model_path}.") - return [] - return [op["target_name"] for op in ops_info] - - -def calculate_token_lengths(rp_expr, num_primitives, symbol_map) -> Dict[int, int]: - token2len = {} - - def get_len(tid): - if tid in token2len: - return token2len[tid] - if tid < num_primitives: - token2len[tid] = 1 - return 1 - if tid in symbol_map: - sub_tokens = symbol_map[tid].tolist() - length = sum(get_len(t) for t in sub_tokens) - token2len[tid] = length - return length - token2len[tid] = 1 - return 1 - - for sym_id in rp_expr.symbol_token_ids: - get_len(sym_id) - return token2len - - -def main(): - parser = argparse.ArgumentParser( - description="Extract graph patterns and split points from multiple models." - ) - parser.add_argument( - "--models", - nargs="+", - required=True, - help="List of paths to model directories (e.g. --models path/to/m1 path/to/m2)", - ) - parser.add_argument("--device", type=str, default="cuda") - parser.add_argument("--window", type=int, default=10) - args = parser.parse_args() - - inputs = [] - valid_model_names = [] - - for model_path in args.models: - seq = extract_ops_via_compile(model_path, args.device) - inputs.append(seq) - valid_model_names.append(os.path.basename(model_path)) - - rp_parser = RpExprParser( - window_size=args.window, fold_policy="default", fold_times=0 - ) - rp_expr, token_id2primitive_id = rp_parser(inputs) - - rp_expr.try_unwrap_body_of_sole_symbol_token() - rp_expr.try_recursive_inline_symbol_sole_used(token_id2primitive_id) - - num_primitives = len(token_id2primitive_id) - symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors)) - token2len = calculate_token_lengths(rp_expr, num_primitives, symbol_map) - - # print ops func - def resolve_token_to_ops(tid) -> List[str]: - if tid < num_primitives: - return [token_id2primitive_id[tid]] - if tid in symbol_map: - sub_tokens = symbol_map[tid].tolist() - ops = [] - for t in sub_tokens: - ops.extend(resolve_token_to_ops(t)) - return ops - return [f"Unknown({tid})"] - - for sym_id in sorted(symbol_map.keys()): - length = token2len.get(sym_id, 0) - ops_seq = resolve_token_to_ops(sym_id) - ops_str = str(ops_seq) - if len(ops_str) > 100: - ops_str = ops_str[:100] + " ...]" - - for i, model_name in enumerate(valid_model_names): - if i >= len(rp_expr.body_rp_expr): - break - - target_body_tensor = rp_expr.body_rp_expr[i] - seq_tokens = target_body_tensor.tolist() - - current_idx = 0 - split_points = set() - total_len = sum(token2len.get(t, 1) for t in seq_tokens) - - full_model_ops = [] - for t in seq_tokens: - full_model_ops.extend(resolve_token_to_ops(t)) - - for token_id in seq_tokens: - length = token2len.get(token_id, 1) - is_pattern = token_id >= num_primitives - - if is_pattern: - if current_idx > 0: - split_points.add(current_idx) - end_idx = current_idx + length - if end_idx < total_len: - split_points.add(end_idx) - - current_idx += length - - sorted_splits = sorted(list(set(split_points))) - print("=" * 50) - print(f"model_name: {model_name}") - print(f"Split Sequence Indices: {sorted_splits}") - print("Segments info:") - last_split = 0 - for split in sorted_splits + [total_len]: - segment_len = split - last_split - if last_split < len(full_model_ops) and split <= len(full_model_ops): - segment_ops = full_model_ops[last_split:split] - if len(segment_ops) > 5: - ops_display = f"[{segment_ops[0]}, ..., {segment_ops[-1]}]" - else: - ops_display = str(segment_ops) - print( - f" Range [{last_split:3d}, {split:3d}), Length: {segment_len:3d} | Ops: {ops_display}" - ) - else: - print( - f" Range [{last_split:3d}, {split:3d}), Length: {segment_len:3d} | (Index Error Warning)" - ) - - last_split = split - - -if __name__ == "__main__": - main() From dab37597488068ea14514abffabb15e50498057b Mon Sep 17 00:00:00 2001 From: fangfangssj <99968055+fangfangssj@users.noreply.github.com> Date: Fri, 28 Nov 2025 19:24:06 +0800 Subject: [PATCH 6/6] Update help text for configuration argumentf --- graph_net/torch/test_compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_net/torch/test_compiler.py b/graph_net/torch/test_compiler.py index f06a13bdf..ae91c633c 100644 --- a/graph_net/torch/test_compiler.py +++ b/graph_net/torch/test_compiler.py @@ -457,7 +457,7 @@ def main(args): type=str, required=False, default=None, - help="Path to configuration file.", + help="base64 encode configuration json.", ) args = parser.parse_args() main(args=args)