diff --git a/graph_net/paddle/run_model.py b/graph_net/paddle/run_model.py
index 7c873a9bb..68144875f 100644
--- a/graph_net/paddle/run_model.py
+++ b/graph_net/paddle/run_model.py
@@ -1,15 +1,16 @@
+import argparse
+import base64
+import importlib.util
+import json
 import os
 import sys
-import json
-import base64
-import argparse
-from typing import Type
 
 os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
 import paddle
 
 from graph_net import imp_util
 from graph_net.paddle import utils
+from jinja2 import Template
 
 
 def load_class_from_file(file_path: str, class_name: str):
@@ -41,16 +42,125 @@ def _convert_to_dict(config_str):
     return config
 
 
-def _get_decorator(args):
-    if args.decorator_config is None:
-        return lambda model: model
-    decorator_config = _convert_to_dict(args.decorator_config)
-    if "decorator_path" not in decorator_config:
+def _get_decorator(arg):
+    """Backward compatible: accepts an argparse.Namespace or an already-parsed dict."""
+    if arg is None:
         return lambda model: model
-    decorator_class = load_class_from_file(
-        decorator_config["decorator_path"], class_name="RunModelDecorator"
+
+    decorator_config = (
+        _convert_to_dict(arg.decorator_config)
+        if hasattr(arg, "decorator_config")
+        else arg
     )
-    return decorator_class(decorator_config.get("decorator_config", {}))
+    if not decorator_config:
+        return lambda model: model
+
+    class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
+    decorator_kwargs = decorator_config.get("decorator_config", {})
+
+    if "decorator_path" in decorator_config:
+        decorator_class = load_class_from_file(
+            decorator_config["decorator_path"], class_name=class_name
+        )
+        return decorator_class(decorator_kwargs)
+
+    if hasattr(sys.modules[__name__], class_name):
+        decorator_class = getattr(sys.modules[__name__], class_name)
+        return decorator_class(decorator_kwargs)
+
+    return lambda model: model
+
+
+class AgentUnittestGenerator:
+    """Generates a standalone unittest script for a Paddle subgraph, verifying the forward pass runs."""
+
+    def __init__(self, config):
+        defaults = {
+            "model_path": None,
+            "output_path": None,
+            "force_device": "auto",  # auto / cpu / gpu
+            "use_numpy": True,
+        }
+        merged = {**defaults, **(config or {})}
+        if merged["model_path"] is None:
+            raise ValueError("AgentUnittestGenerator requires 'model_path' in config")
+        self.model_path = merged["model_path"]
+        self.output_path = merged["output_path"] or self._default_output_path()
+        self.force_device = merged["force_device"]
+        self.use_numpy = merged["use_numpy"]
+
+    def __call__(self, model):
+        self._generate_unittest_file()
+        return model
+
+    def _default_output_path(self):
+        base = os.path.basename(os.path.normpath(self.model_path))
+        return os.path.join(self.model_path, f"{base}_test.py")
+
+    def _choose_device(self):
+        if self.force_device == "cpu":
+            return "cpu"
+        if self.force_device == "gpu":
+            return "gpu"
+        return "gpu" if paddle.device.is_compiled_with_cuda() else "cpu"
+
+    def _generate_unittest_file(self):
+        target_device = self._choose_device()
+        template_str = """
+import importlib.util
+import os
+import unittest
+
+import paddle
+from graph_net.paddle import utils
+
+
+def _load_graph_module(model_path: str):
+    source_path = os.path.join(model_path, "model.py")
+    spec = importlib.util.spec_from_file_location("agent_graph_module", source_path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module.GraphModule
+
+
+class AgentGraphTest(unittest.TestCase):
+    def setUp(self):
+        self.model_path = os.path.dirname(__file__)
+        self.target_device = "{{ target_device }}"
+        paddle.set_device(self.target_device)
+        self.GraphModule = _load_graph_module(self.model_path)
+        self.meta = utils.load_converted_from_text(self.model_path)
+        self.use_numpy = {{ use_numpy_flag }}
+
+    def _with_device(self, info):
+        cloned = {"info": dict(info["info"]), "data": info.get("data")}
+        cloned["info"]["device"] = self.target_device
+        return cloned
+
+    def _build_tensor(self, meta):
+        return utils.replay_tensor(self._with_device(meta), use_numpy=self.use_numpy)
+
+    def test_forward_runs(self):
+        model = self.GraphModule()
+        inputs = {k: self._build_tensor(v) for k, v in self.meta["input_info"].items()}
+        params = {k: self._build_tensor(v) for k, v in self.meta["weight_info"].items()}
+        model.__graph_net_file_path__ = self.model_path
+        output = model(**params, **inputs)
+        self.assertIsNotNone(output)
+
+
+if __name__ == "__main__":
+    unittest.main()
+"""
+
+        rendered = Template(template_str).render(
+            target_device=target_device, use_numpy_flag=self.use_numpy
+        )
+
+        os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
+        with open(self.output_path, "w", encoding="utf-8") as f:
+            f.write(rendered)
+        print(f"[Agent] unittest generated: {self.output_path} (device={target_device})")
 
 
 def main(args):
@@ -61,9 +171,14 @@ def main(args):
     assert model_class is not None
     model = model_class()
     print(f"{model_path=}")
+    decorator_config = _convert_to_dict(args.decorator_config)
+    if decorator_config:
+        decorator_config.setdefault("decorator_config", {})
+        decorator_config["decorator_config"].setdefault("model_path", model_path)
+        decorator_config["decorator_config"].setdefault("use_numpy", True)
+    model = _get_decorator(decorator_config)(model)
     input_dict = get_input_dict(args.model_path)
-    model = _get_decorator(args)(model)
     model(**input_dict)
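Note: the Paddle runner gets no smoke test in this patch. A minimal sketch of driving the built-in generator end to end, assuming the Paddle CLI exposes the same --model-path/--decorator-config flags as the torch runner exercised below, and with a placeholder sample path:

import base64
import json
import subprocess
import sys

model_path = "samples/paddle/some_subgraph"  # placeholder sample directory

# Same config shape the torch-side smoke test encodes below.
cfg = {
    "decorator_class_name": "AgentUnittestGenerator",
    "decorator_config": {
        "model_path": model_path,
        "force_device": "auto",  # resolves to gpu only if Paddle has CUDA
        "use_numpy": True,
    },
}
cfg_b64 = base64.b64encode(json.dumps(cfg).encode()).decode()

# Runs the forward pass and drops <sample>/<sample>_test.py next to model.py.
subprocess.run(
    [sys.executable, "-m", "graph_net.paddle.run_model",
     "--model-path", model_path, "--decorator-config", cfg_b64],
    check=True,
)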
diff --git a/graph_net/test/test_agent_unittest_generator.sh b/graph_net/test/test_agent_unittest_generator.sh
new file mode 100644
index 000000000..e1cea64b2
--- /dev/null
+++ b/graph_net/test/test_agent_unittest_generator.sh
@@ -0,0 +1,40 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Smoke tests for AgentUnittestGenerator on one CV and one NLP sample (Torch side).
+# It runs run_model with the decorator, which will drop a *_test.py under each sample directory.
+
+ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
+TORCH_RUN="python -m graph_net.torch.run_model"
+
+CV_SAMPLE="$ROOT_DIR/samples/torchvision/resnet18"
+NLP_SAMPLE="$ROOT_DIR/samples/transformers-auto-model/albert-base-v2"
+
+encode_cfg() {
+  MODEL_PATH="$1" python - <<'PY'
+import base64, json, os
+cfg = {
+    "decorator_class_name": "AgentUnittestGenerator",
+    "decorator_config": {
+        "model_path": os.environ["MODEL_PATH"],
+        "force_device": "auto",
+        "output_path": None,
+        "use_dummy_inputs": False,
+    },
+}
+print(base64.b64encode(json.dumps(cfg).encode()).decode())
+PY
+}
+
+run_case() {
+  local sample_path="$1"
+  local name="$2"
+  echo "[AgentTest] running $name sample at $sample_path"
+  cfg_b64="$(encode_cfg "$sample_path")"
+  $TORCH_RUN --model-path "$sample_path" --decorator-config "$cfg_b64"
+}
+
+run_case "$CV_SAMPLE" "CV (torchvision/resnet18)"
+run_case "$NLP_SAMPLE" "NLP (transformers-auto-model/albert-base-v2)"
+
+echo "[AgentTest] done. Generated *_test.py files should now exist beside the samples."
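Note: once the smoke test has run, each generated file is a self-contained unittest (it ends in unittest.main()), so it can be executed directly. A minimal sketch, reusing the resnet18 sample path from the script above:

import glob
import os
import subprocess
import sys

sample = "samples/torchvision/resnet18"  # path used by the smoke test above
for test_file in glob.glob(os.path.join(sample, "*_test.py")):
    # AgentUnittestGenerator names the file <dir-basename>_test.py and the
    # generated script guards unittest.main(), so it runs standalone.
    subprocess.run([sys.executable, test_file], check=True)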
diff --git a/graph_net/torch/run_model.py b/graph_net/torch/run_model.py
index 58549fca8..7e8094318 100644
--- a/graph_net/torch/run_model.py
+++ b/graph_net/torch/run_model.py
@@ -1,10 +1,14 @@
 from . import utils
 import argparse
+import base64
 import importlib.util
-import torch
-from typing import Type
 import json
-import base64
+import os
+import sys
+from typing import Type
+
+import torch
+from jinja2 import Template
 
 
 def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
@@ -25,26 +29,152 @@ def _convert_to_dict(config_str):
     return config
 
 
-def _get_decorator(decorator_config):
-    if "decorator_path" not in decorator_config:
+def _get_decorator(arg):
+    """Backward compatible: accepts an argparse.Namespace or an already-parsed dict."""
+    if arg is None:
         return lambda model: model
-    class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
-    decorator_class = load_class_from_file(
-        decorator_config["decorator_path"],
-        class_name=class_name,
+
+    decorator_config = (
+        _convert_to_dict(arg.decorator_config)
+        if hasattr(arg, "decorator_config")
+        else arg
     )
-    return decorator_class(decorator_config.get("decorator_config", {}))
+    if not decorator_config:
+        return lambda model: model
+
+    class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
+    decorator_kwargs = decorator_config.get("decorator_config", {})
+
+    if "decorator_path" in decorator_config:
+        decorator_class = load_class_from_file(
+            decorator_config["decorator_path"], class_name=class_name
+        )
+        return decorator_class(decorator_kwargs)
+
+    if hasattr(sys.modules[__name__], class_name):
+        decorator_class = getattr(sys.modules[__name__], class_name)
+        return decorator_class(decorator_kwargs)
+
+    return lambda model: model
 
 
 def get_flag_use_dummy_inputs(decorator_config):
-    return "use_dummy_inputs" in decorator_config
+    return "use_dummy_inputs" in decorator_config if decorator_config else False
 
 
 def replay_tensor(info, use_dummy_inputs):
     if use_dummy_inputs:
         return utils.get_dummy_tensor(info)
-    else:
-        return utils.replay_tensor(info)
+    return utils.replay_tensor(info)
+
+
+class AgentUnittestGenerator:
+    """Generates a standalone, runnable unittest script to verify the subgraph's forward pass executes."""
+
+    def __init__(self, config):
+        defaults = {
+            "model_path": None,
+            "output_path": None,
+            "force_device": "auto",  # auto / cpu / cuda
+            "use_dummy_inputs": False,
+        }
+        merged = {**defaults, **(config or {})}
+        if merged["model_path"] is None:
+            raise ValueError("AgentUnittestGenerator requires 'model_path' in config")
+        self.model_path = merged["model_path"]
+        self.output_path = merged["output_path"] or self._default_output_path()
+        self.force_device = merged["force_device"]
+        self.use_dummy_inputs = merged["use_dummy_inputs"]
+
+    def __call__(self, model):
+        self._generate_unittest_file()
+        return model
+
+    def _default_output_path(self):
+        base = os.path.basename(os.path.normpath(self.model_path))
+        return os.path.join(self.model_path, f"{base}_test.py")
+
+    def _choose_device(self):
+        if self.force_device == "cpu":
+            return "cpu"
+        if self.force_device == "cuda":
+            return "cuda"
+        return "cuda" if torch.cuda.is_available() else "cpu"
+
+    def _generate_unittest_file(self):
+        target_device = self._choose_device()
+        template_str = """
+import importlib.util
+import os
+import tempfile
+import unittest
+
+import torch
+from graph_net.torch import utils
+
+
+def _load_graph_module(model_path: str, target_device: str):
+    source_path = os.path.join(model_path, "model.py")
+    with open(source_path, "r", encoding="utf-8") as f:
+        code = f.read()
+
+    if target_device != "cuda":
+        code = utils.modify_code_by_device(code, target_device)
+
+    tmp_dir = tempfile.mkdtemp(prefix="agent_unittest_")
+    tmp_file = os.path.join(tmp_dir, "model_tmp.py")
+    with open(tmp_file, "w", encoding="utf-8") as f:
+        f.write(code)
+
+    spec = importlib.util.spec_from_file_location("agent_graph_module", tmp_file)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module.GraphModule
+
+
+class AgentGraphTest(unittest.TestCase):
+    def setUp(self):
+        self.model_path = os.path.dirname(__file__)
+        self.target_device = "{{ target_device }}"
+        self.GraphModule = _load_graph_module(self.model_path, self.target_device)
+        self.meta = utils.load_converted_from_text(self.model_path)
+        self.use_dummy_inputs = {{ use_dummy_inputs }}
+
+    def _with_device(self, info):
+        cloned = {"info": dict(info["info"]), "data": info.get("data")}
+        cloned["info"]["device"] = self.target_device
+        return cloned
+
+    def test_forward_runs(self):
+        model = self.GraphModule()
+        weight_info = self.meta["weight_info"]
+
+        def _build_tensor(val):
+            wrapped = self._with_device(val)
+            return (
+                utils.get_dummy_tensor(wrapped)
+                if self.use_dummy_inputs
+                else utils.replay_tensor(wrapped)
+            )
+
+        state_dict = {k: _build_tensor(v) for k, v in weight_info.items()}
+        model.__graph_net_file_path__ = self.model_path
+        output = model(**state_dict)
+        self.assertIsNotNone(output)
+
+
+if __name__ == "__main__":
+    unittest.main()
+"""
+
+        rendered = Template(template_str).render(
+            target_device=target_device, use_dummy_inputs=self.use_dummy_inputs
+        )
+
+        os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
+        with open(self.output_path, "w", encoding="utf-8") as f:
+            f.write(rendered)
+        print(f"[Agent] unittest generated: {self.output_path} (device={target_device})")
 
 
 def main(args):
@@ -56,8 +186,12 @@ def main(args):
     model = model_class()
     print(f"{model_path=}")
     decorator_config = _convert_to_dict(args.decorator_config)
-    if "decorator_path" in decorator_config:
-        model = _get_decorator(decorator_config)(model)
+    if decorator_config:
+        decorator_config.setdefault("decorator_config", {})
+        decorator_config["decorator_config"].setdefault("model_path", model_path)
+        decorator_config["decorator_config"].setdefault("use_dummy_inputs", False)
+
+        model = _get_decorator(decorator_config)(model)
 
     inputs_params = utils.load_converted_from_text(f"{model_path}")
     params = inputs_params["weight_info"]
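Note: with both runners sharing the same _get_decorator shape, a decorator config now resolves in one of three ways. A small sketch (my_decorator.py is hypothetical; the other names mirror the diff):

import base64
import json

# 1) Explicit file: the class named by decorator_class_name (default
#    "RunModelDecorator") is loaded from decorator_path via load_class_from_file.
cfg_external = {
    "decorator_path": "my_decorator.py",  # hypothetical user-supplied file
    "decorator_config": {"verbose": True},
}

# 2) Built-in: no decorator_path, so the class is looked up on the run_model
#    module itself through sys.modules[__name__].
cfg_builtin = {
    "decorator_class_name": "AgentUnittestGenerator",
    "decorator_config": {"model_path": "samples/torchvision/resnet18"},
}

# 3) Fallback: None, an empty dict, or an unknown class name all degrade to
#    the identity decorator, lambda model: model.

# Either config travels on the CLI as base64-encoded JSON, which
# _convert_to_dict decodes on the receiving side.
print(base64.b64encode(json.dumps(cfg_builtin).encode()).decode())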