diff --git a/graph_net/bash_templates/model_path_handler_sh.txt b/graph_net/bash_templates/model_path_handler_sh.txt new file mode 100644 index 000000000..ee599105e --- /dev/null +++ b/graph_net/bash_templates/model_path_handler_sh.txt @@ -0,0 +1,26 @@ +#!/bin/bash + +GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))") +model_path_handler_config_json_str=$(cat < str: + """ + return rewrited model.py file contents + """ + raise NotImplementedError() + + def copy_sample_and_handle_model_py_file(self, rel_model_path: str): + src_model_path = Path(self.config["model_path_prefix"]) / rel_model_path + dst_model_path = Path(self.config["output_dir"]) / rel_model_path + dst_model_path.mkdir(parents=True, exist_ok=True) + shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True) + model_py_code = self.handle_model_py_file(rel_model_path) + (dst_model_path / "model.py").write_text(model_py_code) diff --git a/graph_net/sample_pass/resumable_sample_pass_mixin.py b/graph_net/sample_pass/resumable_sample_pass_mixin.py new file mode 100644 index 000000000..73e9fc7b3 --- /dev/null +++ b/graph_net/sample_pass/resumable_sample_pass_mixin.py @@ -0,0 +1,48 @@ +import abc +import sys +from graph_net.sample_pass.sample_pass_mixin import SamplePassMixin +from pathlib import Path +import os + + +class ResumableSamplePassMixin(SamplePassMixin): + def __init__(self, *args, **kwargs): + self.num_handled_models = 0 + + def declare_config( + self, + model_path_prefix: str, + output_dir: str, + resume: bool = False, + limits_handled_models: int = None, + ): + pass + + def sample_handled(self, rel_model_path: str) -> bool: + dst_model_path = Path(self.config["output_dir"]) / rel_model_path + if not dst_model_path.exists(): + return False + num_model_py_files = len(list(dst_model_path.rglob("model.py"))) + assert num_model_py_files <= 1 + return num_model_py_files == 1 + + @abc.abstractmethod + def resume(self, 
import abc
import copy
import inspect


class SamplePass(abc.ABC):
    """Base class for sample passes configured through ``declare_config``.

    The signature of ``declare_config`` is the configuration schema: every
    parameter name is a config key, parameters without a default are
    required, and a ``**kwargs`` parameter allows arbitrary extra keys.
    The resolved configuration is stored in ``self.config``.
    """

    def __init__(self, config=None):
        """Validate the declared schema and resolve ``config`` against it.

        Args:
            config: Mapping of config keys to values; ``None`` means ``{}``.

        Raises:
            AssertionError: If the declaration is invalid, a required key is
                missing, or an undeclared key is present.
        """
        if config is None:
            config = {}

        self._check_config_declaration_valid()
        self.config = self._make_config_by_config_declare(config)

    @abc.abstractmethod
    def declare_config(self):
        """Declare the accepted config keys through this method's signature."""
        raise NotImplementedError()

    @abc.abstractmethod
    def __call__(self, rel_model_path: str):
        """Run the pass on the sample identified by ``rel_model_path``."""
        raise NotImplementedError()

    def _recursively_check_mixin_declare_config(self, base_class):
        """Check ``declare_config`` of ``type(self)`` against every
        ``SamplePass``/``SamplePassMixin`` class reachable from ``base_class``.
        """
        # Imported lazily, as in the original code, rather than at module level.
        from graph_net.sample_pass.sample_pass_mixin import SamplePassMixin

        if issubclass(base_class, (SamplePass, SamplePassMixin)):
            check_is_base_signature(
                base_class=base_class,
                derived_class=type(self),
                method_name="declare_config",
            )
        for sub_class in base_class.__bases__:
            self._recursively_check_mixin_declare_config(sub_class)

    def _check_config_declaration_parameters(self):
        """Assert every declared parameter has a supported annotation and kind."""
        sig = inspect.signature(self.declare_config)
        for name, param in sig.parameters.items():
            assert param.annotation in {
                int,
                bool,
                float,
                str,
                list,
                dict,
            }, f"{name=} {param.annotation}"
            assert param.kind in {
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.VAR_KEYWORD,
            }, f"{name=} {param.kind=}"

    def _check_config_declaration_valid(self):
        """Run all declaration checks for this pass class."""
        self._recursively_check_mixin_declare_config(type(self))
        self._check_config_declaration_parameters()

    def _make_config_by_config_declare(self, config):
        """Return a deep copy of ``config`` completed with declared defaults.

        Args:
            config: User-supplied mapping; it is never mutated.

        Returns:
            A new dict holding every declared key, where values from
            ``config`` take precedence over declared defaults.

        Raises:
            AssertionError: If a required key is missing, or if ``config``
                has undeclared keys and ``declare_config`` has no ``**kwargs``.
        """
        sig = inspect.signature(self.declare_config)
        mut_config = copy.deepcopy(config)
        class_name = type(self).__name__
        accepts_extra_fields = False
        for name, param in sig.parameters.items():
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                # ``**kwargs`` is not itself a config key; it only permits
                # extra keys, so it must not be required to appear in config.
                accepts_extra_fields = True
                continue
            self._complete_default(name, param, mut_config)
            assert name in mut_config, f"{name=} {class_name=}"

        def get_extra_config_fields():
            return set(mut_config) - set(sig.parameters)

        if not accepts_extra_fields:
            assert not get_extra_config_fields(), f"{get_extra_config_fields()=}"
        return mut_config

    def _complete_default(self, name, param, mut_config):
        """Fill ``mut_config[name]`` with the declared default if it is absent."""
        if name in mut_config:
            # An explicitly configured value must win over the default.
            return
        if param.default is inspect.Parameter.empty:
            return
        mut_config[name] = copy.deepcopy(param.default)


def check_is_base_signature(base_class, derived_class, method_name):
    """Assert the derived method keeps every parameter of the base method.

    Each parameter of ``base_class.<method_name>`` must also exist on
    ``derived_class.<method_name>`` and compare equal as an
    ``inspect.Parameter`` (name, kind, default and annotation). The derived
    method may declare additional parameters.

    Raises:
        AssertionError: If a base parameter is missing or differs.
    """
    base = getattr(base_class, method_name)
    derived = getattr(derived_class, method_name)
    base_parameters = inspect.signature(base).parameters
    derived_parameters = inspect.signature(derived).parameters
    assert len(derived_parameters) >= len(base_parameters)
    for name, param in base_parameters.items():
        # Compare against the derived signature, not the base one again.
        assert name in derived_parameters, f"{name=}"
        assert param == derived_parameters[name], f"{name=}"
from graph_net.sample_pass.sample_pass import SamplePass
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
from graph_net.sample_pass.only_model_file_rewrite_sample_pass_mixin import (
    OnlyModelFileRewriteSamplePassMixin,
)
from graph_net.torch import utils
from pathlib import Path


class DeviceRewriteSamplePass(
    SamplePass, ResumableSamplePassMixin, OnlyModelFileRewriteSamplePassMixin
):
    """Sample pass that copies a sample to ``output_dir`` and rewrites its
    ``model.py`` for the configured ``device``.
    """

    def __init__(self, config):
        """Resolve ``config`` and initialise the handled-model counter."""
        super().__init__(config)
        # SamplePass.__init__ does not chain to the next class in the MRO, so
        # the mixin initialiser that creates ``num_handled_models`` has to be
        # invoked explicitly; otherwise ``limits_handled_models`` would hit an
        # AttributeError on the first handled sample.
        ResumableSamplePassMixin.__init__(self)

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        device: str,
        resume: bool = False,
        limits_handled_models: int = None,
    ):
        """Declare the config keys of this pass (the signature is the schema)."""
        pass

    def __call__(self, rel_model_path: str):
        """Handle one sample, honouring the ``resume`` and limit settings."""
        self.resumable_handle_sample(rel_model_path)

    def resume(self, rel_model_path: str):
        """Copy the sample directory and write the rewritten ``model.py``."""
        return self.copy_sample_and_handle_model_py_file(rel_model_path)

    def handle_model_py_file(self, rel_model_path: str) -> str:
        """Return the sample's ``model.py`` source rewritten for the device.

        The source is read from ``model_path_prefix / rel_model_path`` and
        passed to ``utils.modify_code_by_device`` with the configured device.
        """
        src_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
        model_py_code = (src_model_path / "model.py").read_text()
        device = self.config["device"]
        # NOTE(review): the rewrite itself is delegated to an external helper;
        # what it changes in the source is not visible here.
        return utils.modify_code_by_device(model_py_code, device)