141 changes: 128 additions & 13 deletions graph_net/paddle/run_model.py
Collaborator: This entry point is outdated. Please use the sample_pass mechanism instead; see https://github.com/PaddlePaddle/GraphNet/pull/442/files for reference.
@@ -1,15 +1,16 @@
import argparse
import base64
import importlib.util
import json
import os
import sys
from typing import Type

os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"

import paddle
from graph_net import imp_util
from graph_net.paddle import utils
from jinja2 import Template


def load_class_from_file(file_path: str, class_name: str):
@@ -41,16 +42,125 @@ def _convert_to_dict(config_str):
    return config

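The body of _convert_to_dict is collapsed in this view. Judging from the test script below, which passes a base64-encoded JSON object as --decorator-config, it presumably decodes and parses that payload. A minimal sketch under that assumption, not the actual implementation:

import base64
import json

def _convert_to_dict_sketch(config_str):
    # Assumption: config_str is a base64-encoded JSON object, as produced by
    # encode_cfg in graph_net/test/test_agent_unittest_generator.sh.
    if not config_str:
        return {}
    return json.loads(base64.b64decode(config_str))
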

def _get_decorator(arg):
    """Backward compatible: accept either an argparse.Namespace or an already parsed dict."""

Collaborator: Please write the comments in English.

    if arg is None:
        return lambda model: model

    decorator_config = (
        _convert_to_dict(arg.decorator_config)
        if hasattr(arg, "decorator_config")
        else arg
    )
    if not decorator_config:
        return lambda model: model

    class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
    decorator_kwargs = decorator_config.get("decorator_config", {})

    if "decorator_path" in decorator_config:
        decorator_class = load_class_from_file(
            decorator_config["decorator_path"], class_name=class_name
        )
        return decorator_class(decorator_kwargs)

    if hasattr(sys.modules[__name__], class_name):
        decorator_class = getattr(sys.modules[__name__], class_name)
        return decorator_class(decorator_kwargs)

    return lambda model: model
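The resolution order, illustrated (paths and option names below are illustrative, not part of the diff):

# 1. Explicit decorator_path: the class is loaded from that file.
deco = _get_decorator({
    "decorator_path": "/path/to/my_decorator.py",  # illustrative
    "decorator_class_name": "RunModelDecorator",
    "decorator_config": {"some_option": 1},  # forwarded to the class constructor
})

# 2. No decorator_path: fall back to a class defined in this module by name,
#    e.g. the AgentUnittestGenerator below.
deco = _get_decorator({
    "decorator_class_name": "AgentUnittestGenerator",
    "decorator_config": {"model_path": "samples/torchvision/resnet18"},
})
model = deco(model)

# 3. Anything else degrades to the identity decorator.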


class AgentUnittestGenerator:
Collaborator: This class should be moved into its own file under a newly created graph_net/torch/sample_passes/agent_unittest_generator.py.

    """Generate a standalone unittest script for a Paddle subgraph and verify that its forward pass runs."""

    def __init__(self, config):
        defaults = {
            "model_path": None,
            "output_path": None,
            "force_device": "auto",  # auto / cpu / gpu
            "use_numpy": True,
        }
        merged = {**defaults, **(config or {})}
        if merged["model_path"] is None:
            raise ValueError("AgentUnittestGenerator requires 'model_path' in config")
        self.model_path = merged["model_path"]
        self.output_path = merged["output_path"] or self._default_output_path()
        self.force_device = merged["force_device"]
        self.use_numpy = merged["use_numpy"]

    def __call__(self, model):
        self._generate_unittest_file()
        return model

    def _default_output_path(self):
        base = os.path.basename(os.path.normpath(self.model_path))
        return os.path.join(self.model_path, f"{base}_test.py")
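For example, with model_path="samples/torchvision/resnet18" (an illustrative sample directory), the default output path resolves like this:

base = os.path.basename(os.path.normpath("samples/torchvision/resnet18"))  # "resnet18"
# -> "samples/torchvision/resnet18/resnet18_test.py", i.e. the test file is
#    written next to the sample's model.py.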

    def _choose_device(self):
        if self.force_device == "cpu":
            return "cpu"
        if self.force_device == "gpu":
            return "gpu"
        return "gpu" if paddle.device.is_compiled_with_cuda() else "cpu"

    def _generate_unittest_file(self):
        target_device = self._choose_device()
        template_str = """
Collaborator: Better to use a dedicated rendering engine such as jinja2 for this.

import importlib.util
import os
import unittest

import paddle
from graph_net.paddle import utils
Collaborator: This is not acceptable. The unittest script should be able to run standalone, independent of graph_net.



def _load_graph_module(model_path: str):
    source_path = os.path.join(model_path, "model.py")
    spec = importlib.util.spec_from_file_location("agent_graph_module", source_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.GraphModule


class AgentGraphTest(unittest.TestCase):
    def setUp(self):
        self.model_path = os.path.dirname(__file__)
        self.target_device = "{{ target_device }}"
        paddle.set_device(self.target_device)
        self.GraphModule = _load_graph_module(self.model_path)
        self.meta = utils.load_converted_from_text(self.model_path)
        self.use_numpy = {{ use_numpy_flag }}

    def _with_device(self, info):
        cloned = {"info": dict(info["info"]), "data": info.get("data")}
        cloned["info"]["device"] = self.target_device
        return cloned

    def _build_tensor(self, meta):
        return utils.replay_tensor(self._with_device(meta), use_numpy=self.use_numpy)

    def test_forward_runs(self):
        model = self.GraphModule()
        inputs = {k: self._build_tensor(v) for k, v in self.meta["input_info"].items()}
        params = {k: self._build_tensor(v) for k, v in self.meta["weight_info"].items()}
        model.__graph_net_file_path__ = self.model_path
        output = model(**params, **inputs)
        self.assertIsNotNone(output)


if __name__ == "__main__":
    unittest.main()
"""

        rendered = Template(template_str).render(
            target_device=target_device, use_numpy_flag=self.use_numpy
        )

        os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
        with open(self.output_path, "w", encoding="utf-8") as f:
            f.write(rendered)
        print(f"[Agent] unittest generated: {self.output_path} (device={target_device})")
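Since the rendered script ends with unittest.main(), the generated file can be executed directly, e.g. python samples/torchvision/resnet18/resnet18_test.py (illustrative path) — provided graph_net is importable, which the review comment above flags as a problem for true standalone use.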


def main(args):
@@ -61,9 +171,14 @@ def main(args):
    assert model_class is not None
    model = model_class()
    print(f"{model_path=}")
    decorator_config = _convert_to_dict(args.decorator_config)
    if decorator_config:
        decorator_config.setdefault("decorator_config", {})
        decorator_config["decorator_config"].setdefault("model_path", model_path)
        decorator_config["decorator_config"].setdefault("use_numpy", True)

    model = _get_decorator(decorator_config)(model)
    input_dict = get_input_dict(args.model_path)
    model(**input_dict)
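With these setdefault calls, a minimal config is expanded before reaching _get_decorator; for example (model path is illustrative):

decorator_config = {"decorator_class_name": "AgentUnittestGenerator"}
# after the setdefault block:
# {
#     "decorator_class_name": "AgentUnittestGenerator",
#     "decorator_config": {"model_path": "samples/torchvision/resnet18", "use_numpy": True},
# }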


40 changes: 40 additions & 0 deletions graph_net/test/test_agent_unittest_generator.sh
@@ -0,0 +1,40 @@
#!/usr/bin/env bash
set -euo pipefail

# Smoke tests for AgentUnittestGenerator on one CV and one NLP sample (Torch side).
# It runs run_model with the decorator, which will drop a *_test.py under each sample directory.

ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
TORCH_RUN="python -m graph_net.torch.run_model"

CV_SAMPLE="$ROOT_DIR/samples/torchvision/resnet18"
NLP_SAMPLE="$ROOT_DIR/samples/transformers-auto-model/albert-base-v2"

encode_cfg() {
  MODEL_PATH="$1" python - <<'PY'
import base64, json, os
cfg = {
    "decorator_class_name": "AgentUnittestGenerator",
    "decorator_config": {
        "model_path": os.environ["MODEL_PATH"],
        "force_device": "auto",
        "output_path": None,
        "use_dummy_inputs": False,
    },
}
print(base64.b64encode(json.dumps(cfg).encode()).decode())
PY
}
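To sanity-check a payload produced by encode_cfg, it can be decoded back on the Python side (cfg_b64 is a placeholder for the shell variable's value):

import base64, json
cfg = json.loads(base64.b64decode(cfg_b64))
assert cfg["decorator_class_name"] == "AgentUnittestGenerator"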
run_case() {
  local sample_path="$1"
  local name="$2"
  echo "[AgentTest] running $name sample at $sample_path"
  cfg_b64="$(encode_cfg "$sample_path")"
  $TORCH_RUN --model-path "$sample_path" --decorator-config "$cfg_b64"
}

run_case "$CV_SAMPLE" "CV (torchvision/resnet18)"
run_case "$NLP_SAMPLE" "NLP (transformers-auto-model/albert-base-v2)"

echo "[AgentTest] done. Generated *_test.py files should now exist beside the samples."