Skip to content
26 changes: 26 additions & 0 deletions graph_net/bash_templates/model_path_handler_sh.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash
# Template for running graph_net.model_path_handler with a custom sample pass.
# Replace every "customize_your_*" placeholder before use.

# Directory that contains the installed graph_net package.
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")

# Handler configuration as JSON; $GRAPH_NET_ROOT is expanded by the heredoc.
model_path_handler_config_json_str=$(cat <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/graph_net/customize_your_sample_pass.py",
    "handler_class_name": "customize_your_class_name",
    "handler_config": {
        "resume": true,
        "model_path_prefix": "/customize_your_model_path_prefix",
        "output_dir": "/customize_your_output_file"
    }
}
EOF
)

model_path_handler_model_path_list="customize_your_model_path_list"
# Quote the expansion so the JSON is not subject to word splitting or globbing
# before it is base64-encoded onto a single line.
MODEL_PATH_HANDLER_CONFIG=$(echo "$model_path_handler_config_json_str" | base64 -w 0)

# Arguments are quoted so paths containing spaces survive; the final argument
# has no trailing backslash, so the command does not depend on a blank line.
python3 -m graph_net.model_path_handler \
    --model-path-list "$model_path_handler_model_path_list" \
    --handler-config "$MODEL_PATH_HANDLER_CONFIG"

# Clean up helper variables in case this script is sourced.
unset model_path_handler_config_json_str
unset model_path_handler_model_path_list
unset MODEL_PATH_HANDLER_CONFIG

28 changes: 28 additions & 0 deletions graph_net/sample_pass/only_model_file_rewrite_sample_pass_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import abc
import shutil
from graph_net.sample_pass.sample_pass_mixin import SamplePassMixin
from pathlib import Path


class OnlyModelFileRewriteSamplePassMixin(SamplePassMixin):
    """Mixin for passes that copy a sample directory and replace its ``model.py``.

    Reads ``self.config["model_path_prefix"]`` and ``self.config["output_dir"]``.
    """

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
    ):
        """Declare the config keys this mixin reads; the body is intentionally empty."""

    @abc.abstractmethod
    def handle_model_py_file(self, rel_model_path: str) -> str:
        """Return the rewritten ``model.py`` contents for ``rel_model_path``."""
        raise NotImplementedError()

    def copy_sample_and_handle_model_py_file(self, rel_model_path: str):
        """Copy the sample into the output tree, then overwrite its ``model.py``.

        The source directory is ``model_path_prefix / rel_model_path`` and the
        destination is ``output_dir / rel_model_path``. The destination is
        created first, the whole source tree is copied over it, and finally
        ``model.py`` in the destination is replaced with the text returned by
        ``handle_model_py_file``.
        """
        source_dir = Path(self.config["model_path_prefix"]).joinpath(rel_model_path)
        target_dir = Path(self.config["output_dir"]).joinpath(rel_model_path)
        target_dir.mkdir(parents=True, exist_ok=True)
        shutil.copytree(source_dir, target_dir, dirs_exist_ok=True)
        rewritten_code = self.handle_model_py_file(rel_model_path)
        target_file = target_dir.joinpath("model.py")
        target_file.write_text(rewritten_code)
48 changes: 48 additions & 0 deletions graph_net/sample_pass/resumable_sample_pass_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import abc
import sys
from graph_net.sample_pass.sample_pass_mixin import SamplePassMixin
from pathlib import Path
import os


class ResumableSamplePassMixin(SamplePassMixin):
    """Mixin adding skip-if-already-done and handled-model limiting to a pass.

    Reads ``model_path_prefix``, ``output_dir``, ``resume`` and
    ``limits_handled_models`` from ``self.config``.
    """

    # Class-level default: in a multiple-inheritance chain an earlier base's
    # __init__ may not call super().__init__(), in which case this mixin's
    # __init__ never runs. Without this default the counter would be missing
    # and _inc_num_handled_models_or_exit would raise AttributeError.
    num_handled_models = 0

    def __init__(self, *args, **kwargs):
        """Reset the per-instance counter; all arguments are accepted and ignored."""
        self.num_handled_models = 0

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        resume: bool = False,
        limits_handled_models: int = None,
    ):
        """Declare the config keys this mixin reads; the body is intentionally empty."""

    def sample_handled(self, rel_model_path: str) -> bool:
        """Return True when the output directory for the sample holds one ``model.py``.

        Returns False when ``output_dir / rel_model_path`` does not exist or
        contains no ``model.py`` (searched recursively). Asserts that at most
        one ``model.py`` is found.
        """
        dst_model_path = Path(self.config["output_dir"]) / rel_model_path
        if not dst_model_path.exists():
            return False
        num_model_py_files = len(list(dst_model_path.rglob("model.py")))
        assert num_model_py_files <= 1
        return num_model_py_files == 1

    @abc.abstractmethod
    def resume(self, rel_model_path: str):
        """Process one sample; invoked by ``resumable_handle_sample``."""
        raise NotImplementedError()

    def resumable_handle_sample(self, rel_model_path: str):
        """Process a sample unless resuming and it was already handled.

        Asserts that the input prefix and the output directory resolve to
        different paths. Samples skipped because of ``resume`` do not count
        towards ``limits_handled_models``.
        """
        assert os.path.realpath(self.config["model_path_prefix"]) != os.path.realpath(
            self.config["output_dir"]
        )
        if self.config["resume"] and self.sample_handled(rel_model_path):
            return
        self.resume(rel_model_path)
        self._inc_num_handled_models_or_exit()

    def _inc_num_handled_models_or_exit(self):
        """Count a handled sample and exit the process once the limit is reached.

        Does nothing when ``limits_handled_models`` is None. Otherwise the
        counter is incremented and, when it reaches the limit, a message is
        printed and ``sys.exit(0)`` is called.
        """
        if self.config["limits_handled_models"] is None:
            return
        self.num_handled_models += 1
        if self.num_handled_models >= self.config["limits_handled_models"]:
            print("limits_handled_models expired.", flush=True)
            sys.exit(0)
92 changes: 92 additions & 0 deletions graph_net/sample_pass/sample_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import abc
import copy
import inspect


class SamplePass(abc.ABC):
    """Base class for passes applied to one sample (a relative model path).

    Subclasses describe their configuration through the signature of
    ``declare_config``: each parameter is a config key, its annotation is the
    expected type and its default (if any) is the fallback value. The
    validated configuration is stored in ``self.config``.
    """

    def __init__(self, config=None):
        """Validate the declared config signature and build ``self.config``.

        Args:
            config: mapping of config keys to values; ``None`` means empty.

        Raises:
            AssertionError: if the declaration is invalid, a required key is
                missing, or an undeclared key is supplied while
                ``declare_config`` has no ``**kwargs`` parameter.
        """
        if config is None:
            config = {}

        self._check_config_declaration_valid()
        self.config = self._make_config_by_config_declare(config)

    @abc.abstractmethod
    def declare_config(self):
        """Declare the accepted config keys via this method's parameters."""
        raise NotImplementedError()

    @abc.abstractmethod
    def __call__(self, rel_model_path: str):
        """Run the pass on the sample identified by ``rel_model_path``."""
        raise NotImplementedError()

    def _recursively_check_mixin_declare_config(self, base_class):
        """Check ``declare_config`` of every SamplePass/SamplePassMixin ancestor.

        Walks ``base_class`` and all of its bases, comparing each relevant
        class's ``declare_config`` signature against this instance's class.
        """
        # Imported at call time, as in the original code.
        from graph_net.sample_pass.sample_pass_mixin import SamplePassMixin

        if issubclass(base_class, (SamplePass, SamplePassMixin)):
            check_is_base_signature(
                base_class=base_class,
                derived_class=type(self),
                method_name="declare_config",
            )
        for sub_class in base_class.__bases__:
            self._recursively_check_mixin_declare_config(sub_class)

    def _check_config_declaration_parameters(self):
        """Assert every declared parameter has a supported annotation and kind."""
        sig = inspect.signature(self.declare_config)
        for name, param in sig.parameters.items():
            assert param.annotation in {
                int,
                bool,
                float,
                str,
                list,
                dict,
            }, f"{name=} {param.annotation}"
            assert param.kind in {
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.VAR_KEYWORD,
            }, f"{name=} {param.kind=}"

    def _check_config_declaration_valid(self):
        """Run all checks on the ``declare_config`` declaration."""
        self._recursively_check_mixin_declare_config(type(self))
        self._check_config_declaration_parameters()

    def _make_config_by_config_declare(self, config):
        """Return a deep copy of ``config`` completed with declared defaults.

        Raises:
            AssertionError: if a declared key without default is missing, or
                if ``config`` holds undeclared keys and ``declare_config`` has
                no ``**kwargs`` parameter.
        """
        sig = inspect.signature(self.declare_config)
        mut_config = copy.deepcopy(config)
        class_name = type(self).__name__
        for name, param in sig.parameters.items():
            # A **kwargs parameter only signals "extra keys allowed"; it is
            # not itself a config key and must not be required.
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                continue
            self._complete_default(name, param, mut_config)
            assert name in mut_config, f"{name=} {class_name=}"

        def get_extra_config_fields():
            return set(mut_config) - set(sig.parameters)

        no_varadic_keyword = all(
            param.kind != inspect.Parameter.VAR_KEYWORD
            for param in sig.parameters.values()
        )
        if no_varadic_keyword:
            no_extra_config_fields = all(name in sig.parameters for name in mut_config)
            assert no_extra_config_fields, f"{get_extra_config_fields()=}"
        return mut_config

    def _complete_default(self, name, param, mut_config):
        """Fill ``mut_config[name]`` with the declared default when absent.

        A value supplied by the caller always wins over the declared default.
        """
        if param.default is inspect.Parameter.empty:
            return
        if name in mut_config:
            return
        mut_config[name] = copy.deepcopy(param.default)


def check_is_base_signature(base_class, derived_class, method_name):
    """Assert the derived method's signature extends the base method's.

    Every parameter of ``base_class.<method_name>`` must also appear in
    ``derived_class.<method_name>`` with an equal ``inspect.Parameter``
    (same name, kind, default and annotation). The derived method may
    declare additional parameters.

    Args:
        base_class: class providing the reference signature.
        derived_class: class whose method is checked.
        method_name: name of the method looked up on both classes.

    Raises:
        AssertionError: if a base parameter is missing from, or differs in,
            the derived signature.
    """
    base = getattr(base_class, method_name)
    derived = getattr(derived_class, method_name)
    base_parameters = inspect.signature(base).parameters
    derived_parameters = inspect.signature(derived).parameters
    assert len(derived_parameters) >= len(base_parameters)
    for name, param in base_parameters.items():
        # Compare against the derived signature; comparing the base with
        # itself would make the check always pass.
        assert name in derived_parameters, f"{name=}"
        assert param == derived_parameters[name], f"{name=}"
7 changes: 7 additions & 0 deletions graph_net/sample_pass/sample_pass_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import abc


class SamplePassMixin(abc.ABC):
    """Abstract base class for sample-pass mixins."""

    @abc.abstractmethod
    def declare_config(self):
        """Declare config keys through this method's parameters; must be overridden."""
        raise NotImplementedError
27 changes: 27 additions & 0 deletions graph_net/test/device_rewrite_sample_pass_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash
# Runs DeviceRewriteSamplePass over the validation-error model list and writes
# the rewritten samples to /tmp/device_rewrited.

# Directory that contains the installed graph_net package.
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")

# Handler configuration as JSON; $GRAPH_NET_ROOT is expanded by the heredoc.
model_path_handler_config_json_str=$(cat <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/device_rewrite_sample_pass.py",
    "handler_class_name": "DeviceRewriteSamplePass",
    "handler_config": {
        "device": "cuda",
        "resume": false,
        "model_path_prefix": "$GRAPH_NET_ROOT",
        "output_dir": "/tmp/device_rewrited"
    }
}
EOF
)

model_path_handler_model_path_list="$GRAPH_NET_ROOT/graph_net/test/dev_model_list/validation_error_model_list.txt"
# Quote the expansion so the JSON is not subject to word splitting or globbing
# before it is base64-encoded onto a single line.
MODEL_PATH_HANDLER_CONFIG=$(echo "$model_path_handler_config_json_str" | base64 -w 0)

# Arguments are quoted so paths containing spaces survive; the final argument
# has no trailing backslash, so the command does not depend on a blank line.
python3 -m graph_net.model_path_handler \
    --model-path-list "$model_path_handler_model_path_list" \
    --handler-config "$MODEL_PATH_HANDLER_CONFIG"

# Clean up helper variables in case this script is sourced.
unset model_path_handler_config_json_str
unset model_path_handler_model_path_list
unset MODEL_PATH_HANDLER_CONFIG

2 changes: 0 additions & 2 deletions graph_net/torch/fx_graph_parse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@ def parse_sole_graph_module_without_varify(module, inputs):
# Backend callback: stores the graph module and sample inputs it receives in
# the enclosing function's variables and returns the module's forward method.
def my_backend(gm, sample_inputs):
nonlocal traced_module
nonlocal traced_sample_inputs
# NOTE(review): the hunk header reports 2 deletions and 0 additions;
# presumably the two asserts below are the removed lines - confirm against
# the target file.
assert traced_module is None
assert traced_sample_inputs is None
traced_module = gm
traced_sample_inputs = sample_inputs
return gm.forward
Expand Down
11 changes: 7 additions & 4 deletions graph_net/torch/graph_decomposer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import shutil
from pathlib import Path
import torch
import json
Expand Down Expand Up @@ -197,13 +198,15 @@ def _is_model_handled(self, rel_model_path, split_positions):
num_subgraphs = len(split_positions) + 1
decomposed_model_path = Path(self.config["output_dir"]) / rel_model_path
num_decomposed = len(list(decomposed_model_path.rglob("model.py")))
if num_decomposed > 0:
assert (
num_subgraphs <= num_decomposed
), f"{num_subgraphs=} {num_decomposed=} {str(decomposed_model_path)=}"
if num_decomposed > 0 and num_subgraphs != num_decomposed:
shutil.rmtree(decomposed_model_path / "_decomposed")
return False
return num_subgraphs == num_decomposed

def __call__(self, rel_model_path):
assert os.path.realpath(self.config["model_path_prefix"]) != os.path.realpath(
self.config["output_dir"]
)
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
split_results = load_json(self.config["split_results_path"])
split_positions = split_results[rel_model_path]["split_positions"]
Expand Down
36 changes: 36 additions & 0 deletions graph_net/torch/sample_passes/device_rewrite_sample_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from graph_net.sample_pass.sample_pass import SamplePass
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
from graph_net.sample_pass.only_model_file_rewrite_sample_pass_mixin import (
OnlyModelFileRewriteSamplePassMixin,
)
from graph_net.torch import utils
from pathlib import Path


class DeviceRewriteSamplePass(
    SamplePass, ResumableSamplePassMixin, OnlyModelFileRewriteSamplePassMixin
):
    """Sample pass that copies each sample and rewrites its ``model.py`` for a device.

    The rewrite is delegated to ``utils.modify_code_by_device`` with the
    configured ``device`` value.
    """

    def __init__(self, config):
        """Build the validated config and initialise the handled-model counter."""
        super().__init__(config)
        # SamplePass.__init__ does not chain to the mixins' __init__, so the
        # counter read by ResumableSamplePassMixin is initialised here;
        # otherwise setting limits_handled_models raises AttributeError.
        self.num_handled_models = 0

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        device: str,
        resume: bool = False,
        limits_handled_models: int = None,
    ):
        """Declare the config keys of this pass; the body is intentionally empty."""

    def __call__(self, rel_model_path: str):
        """Handle one sample, honouring the ``resume`` and limit settings."""
        self.resumable_handle_sample(rel_model_path)

    def resume(self, rel_model_path: str):
        """Copy the sample to the output directory and rewrite its ``model.py``."""
        return self.copy_sample_and_handle_model_py_file(rel_model_path)

    def handle_model_py_file(self, rel_model_path: str) -> str:
        """Return the source sample's ``model.py`` text rewritten for the device.

        Reads ``model_path_prefix / rel_model_path / model.py`` and passes its
        contents together with ``self.config["device"]`` to
        ``utils.modify_code_by_device``.
        """
        src_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
        model_py_code = (src_model_path / "model.py").read_text()
        device = self.config["device"]
        return utils.modify_code_by_device(model_py_code, device)
Loading