diff --git a/graph_net/agent/README.md b/graph_net/agent/README.md
index 0670d7ccfe..2fcf1aa7c9 100644
--- a/graph_net/agent/README.md
+++ b/graph_net/agent/README.md
@@ -6,16 +6,7 @@
 ### 基础依赖
 ```bash
-# 已包含在 GraphNet 主依赖中
-pip install torch torchvision
-```
-
-### Agent 可选依赖
-```bash
-# 安装 Agent 相关依赖(包括 huggingface_hub)
-pip install -e ".[agent]"
-
-# 或单独安装
+pip install torch transformers accelerate
 pip install "huggingface_hub>=0.20.0"
 ```
 
@@ -26,48 +17,103 @@ pip install "huggingface_hub>=0.20.0"
 export GRAPH_NET_EXTRACT_WORKSPACE=/path/to/your/workspace
 ```
 
-或在代码中指定:
+未设置时默认使用 `~/graphnet_workspace`。也可在代码中显式指定:
 ```python
 from graph_net.agent import GraphNetAgent
 agent = GraphNetAgent(workspace="/path/to/workspace")
 ```
 
+### HuggingFace Token(访问私有或受限模型)
+```bash
+export HF_TOKEN=hf_xxx
+```
+
 ## 使用示例
 
+### 单模型抽取
 ```python
 from graph_net.agent import GraphNetAgent
 
-# 初始化 Agent
 agent = GraphNetAgent(
-    workspace="./agent_workspace",
-    hf_token=None  # 可选,用于访问私有模型
+    hf_token=None,   # 可选,私有模型需要
+    llm_retry=True,  # 失败时自动调用 ducc/claude 修复脚本并重试
 )
 
-# 运行提取
 success = agent.extract_sample("bert-base-uncased")
+```
+
+### 并行批量抽取
+```bash
+# 从文件读取模型列表(每行一个 model_id,# 开头为注释)
+python graph_net/agent/parallel_extract.py --model-list models.txt
+
+# 从 HuggingFace Hub 按下载量抓取模型
+python graph_net/agent/parallel_extract.py --count 200 --task text-classification
 
-if success:
-    print("✅ Sample extracted successfully")
-else:
-    print("❌ Extraction failed")
+# 指定 GPU 和 workspace
+python graph_net/agent/parallel_extract.py \
+    --model-list models.txt \
+    --gpus 0,1,2,3 \
+    --workspace /data/graphnet_workspace \
+    --hf-token YOUR_TOKEN
+
+# 结果保存为 JSON(默认自动生成带时间戳的文件名)
+python graph_net/agent/parallel_extract.py --model-list models.txt --output result.json
+```
+
+`--gpus` 默认自动检测全部可用 GPU(读取 `CUDA_VISIBLE_DEVICES` 或 `nvidia-smi`)。
+
+## parallel_extract.py 详解
+
+`parallel_extract.py` 是面向批量场景的并行抽取脚本,适合一次性处理数百到数千个模型。
+
+### 工作原理
+
+- 所有待抽取的模型 ID 放入一个共享任务队列
+- 每张 GPU 启动一个独立的 worker 子进程(`multiprocessing` spawn 模式,CUDA 安全)
+- worker 空闲时主动从队列取任务,天然实现动态负载均衡
+- 每个 worker 内部使用独立的 `GraphNetAgent`,彼此隔离,互不影响
+
+### 命令行参数
+
+| 参数 | 默认值 | 说明 |
+|---|---|---|
+| `--model-list` | — | 模型列表文件路径,每行一个 model_id,`#` 开头为注释 |
+| `--count` | 100 | 未指定 `--model-list` 时,从 HuggingFace Hub 按下载量抓取的模型数量(需安装 `huggingface_hub`) |
+| `--task` | — | HuggingFace 任务类型过滤,如 `text-classification`、`image-classification`(与 `--count` 配合使用) |
+| `--gpus` | 自动检测 | 使用的 GPU 编号,逗号分隔,如 `0,1,2,3` |
+| `--workspace` | `$GRAPH_NET_EXTRACT_WORKSPACE` 或 `~/graphnet_workspace` | 工作目录根路径 |
+| `--hf-token` | — | HuggingFace API Token,私有或受限模型需要 |
+| `--output` | 自动生成 | 结果 JSON 文件路径,默认为 `parallel_extract_<时间戳>.json` |
+
+### 模型列表文件格式
+
+```
+# 文本模型
+bert-base-uncased
+google/flan-t5-base
+
+# 视觉模型
+openai/clip-vit-base-patch32
 ```
 
 ## 工作流程
 
-1. **Fetch**: 从 HuggingFace 下载模型
-2. **Analyze**: 解析 config.json 提取元数据
-3. **CodeGen**: 生成 run_model.py 脚本
-4. **Extract**: 执行脚本提取计算图
-5. **Deduplicate**: 检查是否与已有样本重复
-6. **Verify**: 验证样本完整性
-7. **Archive**: 保存 run_model.py 到样本目录
+1. **Fetch**: 从 HuggingFace 下载模型到本地缓存
+2. **Analyze**: 解析 `config.json` 提取输入形状、dtype、模型类型等元数据
+3. **CodeGen**: 根据元数据生成 `run_model.py` 抽取脚本
+4. **Extract**: 在子进程中执行脚本抽取计算图
+5. **LLM Retry**(可选):若抽取失败,调用 `ducc`/`claude -p` 修复脚本并最多重试 2 次
+6. **Deduplicate**: 基于 SHA-256 图哈希检查是否与已有样本重复
+7. **Verify**: 使用 ForwardVerifier 验证样本可 forward
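+
+第 6 步的去重思路大致如下(示意草图,并非实际实现;实际逻辑见 `GraphNetAgent._generate_graph_hash` 与 `is_duplicate_sample`):
+
+```python
+import hashlib
+from pathlib import Path
+
+def graph_hash(sample_dir: Path) -> str:
+    """对 model.py 的内容取 SHA-256,作为图结构指纹(示意)。"""
+    return hashlib.sha256((sample_dir / "model.py").read_bytes()).hexdigest()
+
+# 哈希与参考库中已有样本相同 → 视为重复,跳过后续验证
+```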
 
-## 测试
+## LLM Retry
 
-```bash
-# 运行所有测试
-pytest graph_net/agent/tests/ -v
+当模板生成的脚本执行失败时,若系统中存在 `ducc` 或 `claude` CLI,Agent 会自动将失败脚本、报错信息和模型 config 发给 LLM 进行修复,最多重试 2 次。
 
-# 运行实际模型测试(需要设置环境变量)
-TEST_REAL_RUN=1 pytest graph_net/agent/tests/test_real_run.py -v
+```python
+# 禁用 LLM retry
+agent = GraphNetAgent(llm_retry=False)
 ```
+
+LLM retry 需要 `ducc` 或 `claude` 在 `PATH` 中可用。
diff --git a/graph_net/agent/agent_usage.md b/graph_net/agent/agent_usage.md
new file mode 100644
index 0000000000..07f02ec6a2
--- /dev/null
+++ b/graph_net/agent/agent_usage.md
@@ -0,0 +1,184 @@
+# GraphNet Agent 使用指南
+
+自动从 HuggingFace 模型抽取计算图的 Agent 工具。
+
+---
+
+## 环境准备
+
+在 GraphNet 目录下运行即可,不需要安装。
+
+---
+
+## 快速开始
+
+```python
+from graph_net.agent import GraphNetAgent
+
+agent = GraphNetAgent()
+ok = agent.extract_sample("prajjwal1/bert-tiny")
+print("成功" if ok else "失败")
+```
+
+默认 workspace 为 `~/graphnet_workspace`(可用 `GRAPH_NET_EXTRACT_WORKSPACE` 覆盖),输出目录按 `组织_模型名` 命名,例如:
+`~/graphnet_workspace/prajjwal1_bert-tiny/`
+
+---
+
+## 初始化参数
+
+```python
+GraphNetAgent(
+    workspace = None,   # 工作目录根路径,默认 ~/graphnet_workspace
+    hf_token  = None,   # HuggingFace Token(私有模型需要)
+    llm_retry = True,   # 失败时调用 LLM 兜底修复
+)
+```
+
+| 参数 | 默认值 | 说明 |
+|------|--------|------|
+| `workspace` | `~/graphnet_workspace` | 工作目录,自动创建子目录结构 |
+| `hf_token` | `None` | HF access token,公开模型无需填写 |
+| `llm_retry` | `True` | 模板脚本失败后,调用 `ducc -p` 让 LLM 修复并重试,最多两次 |
+
+---
+
+## 批量抽取
+
+```python
+from graph_net.agent import GraphNetAgent
+
+agent = GraphNetAgent()
+
+models = [
+    "prajjwal1/bert-tiny",
+    "distilbert/distilgpt2",
+    "hf-internal-testing/tiny-random-ViTModel",
+    "hf-internal-testing/tiny-random-T5Model",
+    "openai/clip-vit-base-patch32",
+]
+
+results = {}
+for model_id in models:
+    results[model_id] = agent.extract_sample(model_id)
+
+# 打印汇总
+for mid, ok in results.items():
+    print(f"{'OK  ' if ok else 'FAIL'} {mid}")
+```
+
+---
+
+## 工作目录结构
+
+```
+~/graphnet_workspace/
+├── models/                      # HuggingFace 下载缓存(仅 config,不含权重)
+│   └── models--org--model-name/
+├── generated/                   # 自动生成的抽取脚本
+│   └── org_model-name/
+│       ├── run_model.py             # 模板生成脚本
+│       └── run_model_llm_<n>.py     # LLM 修复脚本(第 n 次失败后生成,n=1/2)
+├── org_model-name/              # 计算图输出(以 组织_模型名 命名)
+│   ├── model.py                     # 计算图结构
+│   ├── graph_net.json               # 图结构 JSON
+│   ├── input_meta.py                # 输入元信息
+│   ├── input_tensor_constraints.py  # 输入约束
+│   ├── weight_meta.py               # 权重元信息
+│   └── graph_hash.txt               # 图结构哈希(用于去重)
+├── samples/                     # 去重比对参考库
+└── logs/                        # 运行日志
+    └── agent_YYYYMMDD_HHMMSS.log
+```
+
+---
+
+## 抽取流程
+
+```
+HuggingFace model_id
+        │
+        ▼
+① 下载配置文件      仅下载 config.json 等配置,跳过权重文件(*.bin / *.safetensors
+        │           / *.tflite / *.mlmodel / *.onnx 等),模型参数随机初始化
+        ▼
+② 解析配置元数据    读取 config.json,推断 model_type / vocab_size / input_shapes
+        │
+        ▼
+③ 生成抽取脚本      模板生成 run_model.py,含随机输入构造 + graph_net.torch.extract 调用
+        │
+        ▼
+④ 子进程执行脚本    独立 Python 进程运行,注入 GRAPH_NET_EXTRACT_WORKSPACE 环境变量
+        │
+        ├─ 成功 ──────────────────────────────────────────────┐
+        │                                                     │
+        └─ 失败 → ⑤ LLM 兜底(llm_retry=True 且 ducc 可用)    │
+                     │                                        │
+                     ▼                                        │
+                  ducc -p "<prompt>"                          │
+                  生成 run_model_llm_<n>.py                   │
+                     │                                        │
+                     ▼                                        │
+                  子进程重试执行 ─────────────────────────────┤
+                                                              │
+                                                              ▼
+⑥ 生成 graph_hash.txt + 去重检查 + 验证样本(文件完整性与 eager forward)
+```
+
+---
+
+## LLM 兜底机制
+
+当模板脚本执行失败时,若满足以下条件则触发 LLM 兜底:
+
+- `llm_retry=True`(默认开启)
+- `ducc` 命令可用(在 PATH 中)
+
+LLM 收到的信息包括:失败脚本原文、报错信息、`config.json` 内容。
+LLM 必须遵守以下约束(写在 system prompt 里,合规脚本的形态见列表后的示例):
+
+1. 必须调用 `graph_net.torch.extract(name="...")(model).eval()`
+2. 只能使用 `torch`、`transformers`、`graph_net` 三个包
+3. 输入张量随机构造,无需真实数据
+4. 设备必须用 `torch.device("cuda" if torch.cuda.is_available() else "cpu")`
+5. 不下载模型权重
+
+每个模型**最多触发两次** LLM 兜底。第二次会把第一次修复后的脚本及其新的报错一并送给 LLM,方便它在上一轮基础上进一步修正。
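+
+下面是一个满足上述全部约束的脚本草图(仅为示意,以 BERT 类文本模型为例;其中模型目录和 name 取值均为假设,LLM 实际产出取决于具体模型与报错):
+
+```python
+import torch
+import graph_net.torch
+from transformers import AutoConfig, AutoModel
+
+# 约束 5:仅加载 config,随机初始化权重,不下载权重文件
+config = AutoConfig.from_pretrained("/path/to/model_dir", trust_remote_code=True)
+model = AutoModel.from_config(config)
+
+# 约束 4:固定的设备选择写法
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = model.to(device).eval()
+
+# 约束 3:随机构造输入,input_ids 上界严格小于 vocab_size
+vocab_upper = min(config.vocab_size - 1, 30000)
+inputs = {
+    "input_ids": torch.randint(0, vocab_upper, (1, 128), dtype=torch.long).to(device),
+    "attention_mask": torch.ones((1, 128), dtype=torch.long).to(device),
+}
+
+# 约束 1:固定的抽取调用格式(name 为 组织_模型名,dynamic=False)
+wrapped = graph_net.torch.extract(name="org_model-name", dynamic=False)(model).eval()
+with torch.no_grad():
+    wrapped(**inputs)
+```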
+
+---
+
+## 常见问题
+
+**Q:为什么某些模型下载很慢?**
+
+HuggingFace 上部分模型除 PyTorch 权重外还包含 CoreML (`.mlmodel`)、TFLite (`.tflite`)、ONNX (`.onnx`) 等格式的文件,这些文件体积大(数百 MB)。Agent 已在 `ignore_patterns` 中跳过这些格式,若遇到新格式可在 [huggingface_fetcher.py](model_fetcher/huggingface_fetcher.py) 里补充。
+
+**Q:抽取结果目录名规则是什么?**
+
+以 `组织_模型名` 命名(`/` 替换为 `_`),例如:
+- `prajjwal1/bert-tiny` → `prajjwal1_bert-tiny/`
+- `hf-internal-testing/tiny-random-ViTModel` → `hf-internal-testing_tiny-random-ViTModel/`
+
+**Q:关闭 LLM 兜底怎么做?**
+
+```python
+agent = GraphNetAgent(llm_retry=False)
+```
+
+**Q:如何使用私有模型?**
+
+```python
+import os
+agent = GraphNetAgent(hf_token=os.environ["HF_TOKEN"])
+```
+
+**Q:如何检查某次抽取是否成功?**
+
+`extract_sample()` 返回 `True` 表示成功,同时可以检查输出目录是否存在 6 个文件:
+`model.py`、`graph_net.json`、`input_meta.py`、`input_tensor_constraints.py`、
+`weight_meta.py`、`graph_hash.txt`。
diff --git a/graph_net/agent/code_generator/__init__.py b/graph_net/agent/code_generator/__init__.py
index 6971c23097..f590a8b28c 100644
--- a/graph_net/agent/code_generator/__init__.py
+++ b/graph_net/agent/code_generator/__init__.py
@@ -2,5 +2,6 @@
 
 from graph_net.agent.code_generator.base import BaseCodeGenerator
 from graph_net.agent.code_generator.template_generator import TemplateCodeGenerator
+from graph_net.agent.code_generator.llm_code_fixer import LLMCodeFixer
 
-__all__ = ["BaseCodeGenerator", "TemplateCodeGenerator"]
+__all__ = ["BaseCodeGenerator", "TemplateCodeGenerator", "LLMCodeFixer"]
diff --git a/graph_net/agent/code_generator/llm_code_fixer.py b/graph_net/agent/code_generator/llm_code_fixer.py
new file mode 100644
index 0000000000..2a56a6f9ab
--- /dev/null
+++ b/graph_net/agent/code_generator/llm_code_fixer.py
@@ -0,0 +1,332 @@
+"""LLM-based script fixer using ducc -p (Claude Code non-interactive mode)"""
+
+import json
+import logging
+import os
+import re
+import shutil
+import subprocess
+from pathlib import Path
+from typing import Optional
+
+from graph_net.agent.utils.exceptions import CodeGenError
+
+# Candidate binary names / paths to search for ducc CLI
+_DUCC_CANDIDATES = [
+    "ducc",
+    "claude",
+    "/usr/local/bin/ducc",
+    os.path.expanduser("~/.local/bin/ducc"),
+]
+
+_SYSTEM_PROMPT = """\
+你是 PyTorch / HuggingFace 模型计算图抽取专家。
+任务:修复一段失败的图抽取脚本,输出完整、可直接运行的 Python 脚本。
+
+## 【硬性约束 - 违反即输出无效】
+1. 抽取调用格式固定为:
+   graph_net.torch.extract(name="{name}", dynamic=False)(model).eval()(**inputs)
+   - name 值已指定,禁止修改;dynamic 必须为 False(Swin/F.pad 等动态模式会崩溃)
+2. 模型加载必须用随机权重(禁止下载权重文件):
+   config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+   model = AutoModel.from_config(config)  # 或对应任务类
+3. 设备选择固定写法:device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+4. 只允许使用 torch、transformers、graph_net 及 Python 标准库(os/pathlib/json 等)
+5. 
只输出代码块,格式:```python\\n...代码...\\n```,禁止输出任何说明文字 + +## 【输入构造规范 - 按 model_type 选择对应方案】 + +**文本类**(bert/roberta/electra/albert/distilbert/mpnet/xlm/deberta/camembert 等): + seq_len = 128 + vocab_upper = min(vocab_size - 1, 30000) # 严格小于 vocab_size,防止 embedding 越界 + input_ids = torch.randint(0, vocab_upper, (1, seq_len), dtype=torch.long).to(device) + attention_mask = torch.ones((1, seq_len), dtype=torch.long).to(device) + token_type_ids = torch.zeros((1, seq_len), dtype=torch.long).to(device) # 仅 BERT 系需要 + ⚠️ 绝对禁止用 vocab_size 本身作为 randint 上界 + +**视觉类**(vit/swin/convnext/deit/resnet/efficientnet/mobilenet 等): + num_channels = config 中读取,默认 3 + image_size = config 中读取(vision_config.image_size 或 image_size),默认 224 + pixel_values = torch.randn(1, num_channels, image_size, image_size).to(device) + ⚠️ 纯视觉模型只传 pixel_values,禁止传 input_ids + +**多模态类**(clip/blip/flava/align 等): + 同时传文本分支(input_ids + attention_mask)和视觉分支(pixel_values) + +**音频类**(wav2vec2/hubert/whisper/clap/unispeech 等): + - wav2vec2/hubert/unispeech:input_values = torch.randn(1, 16000).to(device) + - whisper:input_features = torch.randn(1, 80, 3000).to(device) + - clap:input_features=torch.randn(1, 1, 1001, 64).to(device), is_longer=torch.tensor([False]).to(device) + - 通用回退:input_features = torch.randn(1, 128, 64).to(device) + +**序列到序列类**(t5/bart/marian/pegasus/longt5 等): + input_ids = torch.randint(0, min(vocab_size-1, 1000), (1, 64), dtype=torch.long).to(device) + attention_mask = torch.ones((1, 64), dtype=torch.long).to(device) + decoder_input_ids = torch.randint(0, min(vocab_size-1, 1000), (1, 32), dtype=torch.long).to(device) + +## 【常见报错 → 修复方法】 +| 报错关键词 | 修复方法 | +|---|---| +| "You have to specify pixel_values" | 补充 pixel_values 输入 | +| "index out of range in self" / embedding 越界 | input_ids 上界 > vocab_size,改为 min(vocab_size-1, 30000) | +| "NoneType has no attribute" | 对应输入字段为 None,补充正确 tensor | +| "running_mean should contain X elements" | BatchNorm channel 维度不对,检查 input_features shape 的 channel 轴 | +| "size of tensor a must match tensor b" | 序列长度或 channel 不一致,统一固定值 | +| "Sizes of tensors must match except in dimension" | 注意 encoder/decoder 序列长度可以不同,不要强制相等 | +| "Expected input batch_size to match target batch_size" | batch size 统一为 1 | +| "sentencepiece" / "tiktoken" ImportError | 不使用 tokenizer,用 torch.randint 直接构造 input_ids | +| "PendingUnbackedSymbolNotFound" | 确认 dynamic=False(不要改为 True) | +| decoder_input_ids missing | Seq2Seq 模型需要同时传 input_ids 和 decoder_input_ids | +""" + + +_CONFIG_JSON_MAX_CHARS = 4096 + + +def _find_ducc() -> Optional[str]: + """Find ducc/claude binary, return full path or None.""" + for candidate in _DUCC_CANDIDATES: + found = shutil.which(candidate) + if found: + return found + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + return candidate + return None + + +def _extract_code_block(text: str) -> Optional[str]: + """Extract first ```python ... ``` code block from LLM output.""" + pattern = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL) + match = pattern.search(text) + if match: + return match.group(1).strip() + # Fallback: if entire output looks like python code + stripped = text.strip() + if stripped.startswith("import ") or stripped.startswith("from "): + return stripped + return None + + +class LLMCodeFixer: + """Fix a failed extraction script by calling ducc/claude -p (non-interactive).""" + + def __init__( + self, + timeout: int = 360, + model: Optional[str] = None, + ): + """ + Args: + timeout: Max seconds to wait for ducc response. + model: Override the LLM model (e.g. 
'sonnet', 'haiku'). + If None, uses whatever ducc default is configured. + """ + self.timeout = timeout + self.model = model + self.logger = logging.getLogger(self.__class__.__name__) + self._ducc_bin = _find_ducc() + if self._ducc_bin: + self.logger.info(f"LLMCodeFixer: using CLI at {self._ducc_bin}") + else: + self.logger.warning( + "LLMCodeFixer: ducc/claude binary not found; " + "LLM retry will be skipped. " + "Add ducc or claude to PATH." + ) + + @property + def available(self) -> bool: + return self._ducc_bin is not None + + def fix( + self, + script_path: Path, + error_msg: str, + model_dir: Path, + model_id: str, + output_dir: Path, + attempt: int = 1, + ) -> Path: + """ + Ask the LLM to fix a failed extraction script. + + Args: + script_path: Path to the (failed) script to fix + error_msg: Captured stderr / ExtractionError message + model_dir: Local model directory (contains config.json) + model_id: HuggingFace model ID (e.g. 'prajjwal1/bert-tiny') + output_dir: Directory where the fixed script should be written + attempt: Retry index (1 or 2), affects output filename + + Returns: + Path to the fixed script (run_model_llm_1.py / run_model_llm_2.py) + + Raises: + CodeGenError: If LLM call fails or returns no valid code + """ + if not self.available: + raise CodeGenError( + "ducc/claude binary not available; cannot perform LLM fix." + ) + + original_script = script_path.read_text(encoding="utf-8") + config_json = self._read_config(model_dir) + safe_name = model_id.replace("/", "_") + + prompt = self._build_prompt( + original_script=original_script, + error_msg=error_msg, + config_json=config_json, + model_id=model_id, + safe_name=safe_name, + model_dir=model_dir, + ) + + self.logger.info( + f"Calling LLM to fix script for {model_id} (attempt {attempt}) ..." 
+ ) + llm_output = self._call_ducc(prompt) + + code = _extract_code_block(llm_output) + if not code: + raise CodeGenError( + f"LLM response contained no Python code block.\n" + f"Response (first 500 chars):\n{llm_output[:500]}" + ) + + output_dir.mkdir(parents=True, exist_ok=True) + fixed_path = output_dir / f"run_model_llm_{attempt}.py" + fixed_path.write_text(code, encoding="utf-8") + self.logger.info(f"LLM-fixed script written to: {fixed_path}") + return fixed_path + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _build_prompt( + self, + original_script: str, + error_msg: str, + config_json: str, + model_id: str, + safe_name: str, + model_dir: Path, + ) -> str: + model_dir_str = str(model_dir).replace("\\", "/") + system = _SYSTEM_PROMPT.format(name=safe_name) + key_fields = self._extract_key_fields(model_dir) + return ( + f"{system}\n\n" + f"---\n\n" + f"## 当前任务\n\n" + f"### 模型信息\n" + f"- model_id: `{model_id}`\n" + f"- config_dir: `{model_dir_str}`\n" + f"- 关键配置字段(优先以此为准):\n```json\n{key_fields}\n```\n\n" + f"### config.json(完整参考)\n```json\n{config_json}\n```\n\n" + f"### 失败脚本\n```python\n{original_script}\n```\n\n" + f"### 错误信息\n```\n{error_msg}\n```\n\n" + f"### 输出要求\n" + f"直接输出修复后的完整脚本,用 ```python\\n...\\n``` 包裹,不附加任何说明:" + ) + + def _call_ducc(self, prompt: str) -> str: + """Invoke ducc -p '' and return stdout.""" + cmd = [ + self._ducc_bin, + "-p", + prompt, + ] + if self.model: + cmd += ["--model", self.model] + + # Inherit current env so ANTHROPIC_* vars are passed through + env = os.environ.copy() + # Ensure the binary's directory is in PATH (handles non-PATH installs) + bin_dir = str(Path(self._ducc_bin).parent) + if bin_dir not in env.get("PATH", ""): + env["PATH"] = f"{bin_dir}:{env.get('PATH', '')}" + + try: + result = subprocess.run( + cmd, + env=env, + capture_output=True, + text=True, + timeout=self.timeout, + ) + except subprocess.TimeoutExpired: + raise CodeGenError(f"ducc -p timed out after {self.timeout}s") + + if result.returncode != 0: + raise CodeGenError( + f"ducc -p exited with code {result.returncode}.\n" + f"stderr: {result.stderr[:500]}" + ) + + output = result.stdout.strip() + if not output: + raise CodeGenError("ducc -p returned empty output.") + + return output + + @staticmethod + def _read_config(model_dir: Path) -> str: + """Read config.json as a compact JSON string (max 4 KB).""" + config_path = model_dir / "config.json" + if not config_path.exists(): + return "{}" + try: + raw = json.loads(config_path.read_text(encoding="utf-8")) + text = json.dumps(raw, ensure_ascii=False, indent=2) + if len(text) > _CONFIG_JSON_MAX_CHARS: + text = text[:_CONFIG_JSON_MAX_CHARS] + "\n... 
(truncated)" + return text + except Exception: + return "{}" + + @staticmethod + def _extract_key_fields(model_dir: Path) -> str: + """从 config.json 提取对输入构造最关键的字段,方便 LLM 直接读取。""" + config_path = model_dir / "config.json" + if not config_path.exists(): + return "{}" + try: + cfg = json.loads(config_path.read_text(encoding="utf-8")) + except Exception: + return "{}" + keys = [ + "model_type", + "vocab_size", + "max_position_embeddings", + "image_size", + "num_channels", + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + # audio/multimodal + "audio_config", + "vision_config", + "text_config", + "patch_size", + "num_mel_bins", + "chunk_length", + ] + result = {k: cfg[k] for k in keys if k in cfg} + # 对嵌套 config 只取 model_type + for nested in ("audio_config", "vision_config", "text_config"): + if isinstance(result.get(nested), dict): + result[nested] = { + k: result[nested][k] + for k in ( + "model_type", + "vocab_size", + "image_size", + "num_channels", + "num_mel_bins", + "hidden_size", + ) + if k in result[nested] + } + return json.dumps(result, ensure_ascii=False) diff --git a/graph_net/agent/code_generator/template_generator.py b/graph_net/agent/code_generator/template_generator.py index 297a9a4e28..b2ec415f3f 100644 --- a/graph_net/agent/code_generator/template_generator.py +++ b/graph_net/agent/code_generator/template_generator.py @@ -59,6 +59,11 @@ def generate( except Exception as e: raise CodeGenError(f"Failed to generate code: {e}") from e + @staticmethod + def _model_short_name(model_id: str) -> str: + """Return 'org_model' name (replace '/' with '_')""" + return model_id.replace("/", "_") + def _generate_code(self, model_dir: Path, model_metadata: ModelMetadata) -> str: """Generate complete extraction script code string""" # Generate model loading code @@ -67,6 +72,8 @@ def _generate_code(self, model_dir: Path, model_metadata: ModelMetadata) -> str: # Generate input construction code input_code = self._generate_input_code(model_metadata) + short_name = self._model_short_name(model_metadata.model_id) + # Generate main code code = f"""import torch try: @@ -79,19 +86,19 @@ def _generate_code(self, model_dir: Path, model_metadata: ModelMetadata) -> str: def main(): # Load model {self._indent(load_code, 4)} - + # Prepare inputs {self._indent(input_code, 4)} - + # Extract graph device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device).eval() - + # Move inputs to same device as model inputs = {{k: v.to(device) for k, v in inputs.items()}} - - wrapped = graph_net.torch.extract(name="{model_metadata.model_id}", dynamic=True)(model).eval() - + + wrapped = graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval() + with torch.no_grad(): wrapped(**inputs) @@ -103,16 +110,14 @@ def main(): def _generate_model_loader( self, model_dir: Path, model_metadata: ModelMetadata ) -> str: - """Generate model loading code based on model type""" + """Generate model loading code — config only, random weights""" model_path = str(model_dir).replace("\\", "/") - if model_metadata.model_type in ["bert", "gpt", "t5", "roberta"]: - return f'model = AutoModel.from_pretrained("{model_path}")' - elif model_metadata.model_type in ["resnet", "vgg", "densenet"]: - return f"model = torchvision.models.{model_metadata.model_type}(pretrained=True)" - else: - # Generic loading - return f'model = AutoModel.from_pretrained("{model_path}")' + return ( + f"from transformers import AutoConfig\n" + f'_config = AutoConfig.from_pretrained("{model_path}", 
trust_remote_code=True)\n' + f"model = AutoModel.from_config(_config)" + ) def _generate_input_code(self, model_metadata: ModelMetadata) -> str: """Generate input tensor construction code based on model metadata""" diff --git a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py index 9312654f99..3f0a13f48e 100644 --- a/graph_net/agent/graph_extractor/subprocess_graph_extractor.py +++ b/graph_net/agent/graph_extractor/subprocess_graph_extractor.py @@ -2,6 +2,7 @@ import logging import os +import signal import subprocess import sys import time @@ -12,7 +13,7 @@ from graph_net.agent.utils.exceptions import ExtractionError # Constants -DEFAULT_TIMEOUT = 600 # 10 minutes for large models +DEFAULT_TIMEOUT = 1000 # ~17 minutes for large models OUTPUT_SEARCH_WINDOW = 600 # 10 minutes for finding recently created directories HASH_DIR_LENGTH = 40 # SHA1 hash length ERROR_MSG_MAX_LINES = 20 # Keep first and last N lines of error messages @@ -59,20 +60,38 @@ def extract(self, code_path: Path, model_id: str) -> Path: else: env["PYTHONPATH"] = str(graphnet_root) - # Run script in subprocess - result = subprocess.run( + # Ensure GRAPH_NET_EXTRACT_WORKSPACE points to our workspace + if "GRAPH_NET_EXTRACT_WORKSPACE" not in env: + env["GRAPH_NET_EXTRACT_WORKSPACE"] = str(self.workspace) + + # Run script in subprocess via Popen so we can kill on timeout + proc = subprocess.Popen( [sys.executable, str(code_path)], cwd=str(code_path.parent), env=env, - capture_output=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, - timeout=self.timeout, + # 用新进程组,方便整组 kill(避免遗留孙进程占显存) + start_new_session=True, ) + try: + stdout, stderr = proc.communicate(timeout=self.timeout) + except subprocess.TimeoutExpired: + # 先 kill 整个进程组,确保 GPU 显存释放 + try: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + except ProcessLookupError: + proc.kill() + proc.communicate() # 回收僵尸进程 + raise ExtractionError( + f"Script execution timed out after {self.timeout} seconds" + ) - if result.returncode != 0: - error_msg = self._format_error_message(result.stderr or result.stdout) + if proc.returncode != 0: + error_msg = self._format_error_message(stderr or stdout) raise ExtractionError( - f"Script execution failed with return code {result.returncode}.\n" + f"Script execution failed with return code {proc.returncode}.\n" f"Command: {sys.executable} {code_path}\n" f"Error output:\n{error_msg}" ) @@ -88,10 +107,8 @@ def extract(self, code_path: Path, model_id: str) -> Path: ) return output_dir - except subprocess.TimeoutExpired: - raise ExtractionError( - f"Script execution timed out after {self.timeout} seconds" - ) + except ExtractionError: + raise except Exception as e: raise ExtractionError(f"Failed to extract graph: {e}") from e @@ -118,6 +135,7 @@ def _find_output_dir_robust(self, model_id: str) -> Optional[Path]: self.logger.warning(f"Workspace path does not exist: {workspace_path}") return None + # Use 'org_model' naming to match the extract(name=...) 
convention safe_model_id = model_id.replace("/", "_") expected_dir = workspace_path / safe_model_id @@ -132,17 +150,12 @@ def _find_output_dir_robust(self, model_id: str) -> Optional[Path]: self.logger.info(f"Found output directory (retry): {expected_dir}") return expected_dir - # Strategy 3: Search for recently modified directories - recent_dir = self._find_recent_sample_dir(workspace_path) - if recent_dir: - return recent_dir - - # Strategy 4: Search by model_id pattern + # Strategy 3: Search by model_id pattern (org_model first) pattern_dir = self._find_dir_by_pattern(workspace_path, model_id, safe_model_id) if pattern_dir: return pattern_dir - # Strategy 5: Search for hash-named directories + # Strategy 4: Search for hash-named directories hash_dir = self._find_hash_named_dir(workspace_path) if hash_dir: return hash_dir @@ -155,41 +168,14 @@ def _get_workspace_path(self) -> Optional[Path]: workspace_env = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE") return Path(workspace_env) if workspace_env else self.workspace - def _find_recent_sample_dir(self, workspace_path: Path) -> Optional[Path]: - """Find most recently modified sample directory""" - current_time = time.time() - candidate_dirs = [] - - for item in workspace_path.iterdir(): - if not item.is_dir() or not self._is_valid_sample_dir(item): - continue - try: - mtime = item.stat().st_mtime - time_diff = current_time - mtime - if time_diff < OUTPUT_SEARCH_WINDOW: - candidate_dirs.append((item, mtime, time_diff)) - except (OSError, FileNotFoundError): - continue - - if candidate_dirs: - candidate_dirs.sort(key=lambda x: x[1], reverse=True) - most_recent = candidate_dirs[0][0] - time_diff = candidate_dirs[0][2] - self.logger.info( - f"Found recent directory: {most_recent} (modified {time_diff:.1f}s ago)" - ) - return most_recent - - return None - def _find_dir_by_pattern( self, workspace_path: Path, model_id: str, safe_model_id: str ) -> Optional[Path]: """Find directory matching model_id patterns""" patterns = [ - safe_model_id, + model_id.replace("/", "_"), # org_model (primary) model_id.replace("/", "-"), - model_id.replace("/", "_"), + model_id.split("/")[-1], # short name fallback ] for item in workspace_path.iterdir(): @@ -222,4 +208,8 @@ def _find_hash_named_dir(self, workspace_path: Path) -> Optional[Path]: def _is_valid_sample_dir(self, dir_path: Path) -> bool: """Check if a directory is a valid sample directory""" required_files = ["model.py", "graph_net.json"] - return all((dir_path / f).exists() for f in required_files) + # 单图:根目录下有文件 + if all((dir_path / f).exists() for f in required_files): + return True + # 多子图:subgraph_* 子目录下有文件 + return any(dir_path.glob("subgraph_*/model.py")) diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index e9e25ea33f..e51c1a2e34 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -1,6 +1,7 @@ """GraphNet Agent core implementation""" -import shutil +import json +import os from pathlib import Path from typing import Optional @@ -8,6 +9,7 @@ from graph_net.agent.metadata_analyzer import ConfigMetadataAnalyzer from graph_net.agent.code_generator import TemplateCodeGenerator +from graph_net.agent.code_generator.llm_code_fixer import LLMCodeFixer from graph_net.agent.graph_extractor import SubprocessGraphExtractor from graph_net.agent.model_fetcher import HFFetcher from graph_net.agent.utils.exceptions import ( @@ -18,7 +20,7 @@ ) from graph_net.agent.utils.logger import setup_logger from 
graph_net.agent.utils.workspace_manager import WorkspaceManager
-from graph_net.agent.sample_verifier import BasicSampleVerifier
+from graph_net.agent.sample_verifier import ForwardVerifier
 
 
 class GraphNetAgent:
@@ -26,16 +28,25 @@ class GraphNetAgent:
 
     def __init__(
         self,
-        workspace: str,
+        workspace: Optional[str] = None,
        hf_token: Optional[str] = None,
+        llm_retry: bool = True,
     ):
         """
         Initialize GraphNet Agent
 
         Args:
-            workspace: Workspace root directory
-            hf_token: HuggingFace API token (optional)
+            workspace: Workspace root directory. Defaults to
+                $GRAPH_NET_EXTRACT_WORKSPACE or ~/graphnet_workspace.
+            hf_token: HuggingFace API token (optional)
+            llm_retry: If True and ducc/claude CLI is available, retry failed
+                extractions up to 2 times with LLM-fixed scripts.
         """
+        if workspace is None:
+            workspace = os.environ.get(
+                "GRAPH_NET_EXTRACT_WORKSPACE",
+                os.path.expanduser("~/graphnet_workspace"),
+            )
         self.workspace = WorkspaceManager(workspace)
         self.logger = setup_logger(
             "GraphNetAgent",
@@ -46,17 +57,26 @@ def __init__(
         self.model_fetcher = HFFetcher(
             cache_dir=str(self.workspace.models_dir),
             token=hf_token,
+            max_retries=3,
+            retry_delay=5,
         )
         self.metadata_analyzer = ConfigMetadataAnalyzer()
         self.code_generator = TemplateCodeGenerator()
         self.graph_extractor = SubprocessGraphExtractor(
             workspace=str(self.workspace.workspace_root)
         )
-        self.sample_verifier = BasicSampleVerifier()
+        self.sample_verifier = ForwardVerifier()
+
+        # LLM fixer — only created when llm_retry is requested
+        self.llm_fixer: Optional[LLMCodeFixer] = LLMCodeFixer() if llm_retry else None
 
     def extract_sample(self, model_id: str) -> bool:
         """
-        Execute complete sample extraction pipeline from HuggingFace model ID
+        Execute complete sample extraction pipeline from HuggingFace model ID.
+
+        On first failure the LLM fixer (ducc -p) is invoked up to 2 times to
+        produce a repaired script. Each retry feeds the previous script and its
+        error back to the LLM for further refinement.
 
         Args:
             model_id: HuggingFace model ID (e.g., "bert-base-uncased")
@@ -70,9 +90,17 @@ def extract_sample(self, model_id: str) -> bool:
             model_dir = self._fetch_model(model_id)
             model_metadata = self._analyze_model(model_dir)
             script_path = self._generate_script(model_dir, model_metadata, model_id)
-            sample_dir = self._extract_graph(script_path, model_id)
+
+            # ── First attempt (template script) ──────────────────────────
+            try:
+                sample_dir = self._extract_graph(script_path, model_id)
+            except ExtractionError as first_err:
+                sample_dir = self._llm_retry(
+                    first_err, script_path, model_dir, model_id
+                )
 
             self._generate_graph_hash(sample_dir)
+            self._fix_model_name(sample_dir, model_id)
 
             if self.is_duplicate_sample(sample_dir):
                 self.logger.info("Duplicate sample detected, skipping verification")
@@ -82,7 +110,6 @@ def extract_sample(self, model_id: str) -> bool:
                 self.logger.error("Sample verification failed")
                 return False
 
-            self._archive_script(script_path, sample_dir)
-
             self.logger.info(f"Successfully extracted sample for {model_id}")
             return True
 
@@ -93,6 +120,55 @@ def extract_sample(self, model_id: str) -> bool:
             self.logger.error(f"Unexpected error for {model_id}: {e}", exc_info=True)
             return False
+
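+    # ------------------------------------------------------------------
+    # LLM 兜底:模板脚本失败后,最多两次调用 ducc/claude 修复并重试
+    # ------------------------------------------------------------------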
+    def _llm_retry(
+        self,
+        first_err: ExtractionError,
+        script_path: Path,
+        model_dir: Path,
+        model_id: str,
+    ) -> Path:
+        """
+        On extraction failure: ask the LLM to fix the script and retry, up to 2 times.
+        Each attempt feeds the previous script + its error back to the LLM.
+
+        Returns:
+            sample_dir: Path to the successfully extracted sample directory.
+
+        Raises:
+            ExtractionError: If LLM fix is unavailable or both attempts fail.
+        """
+        if self.llm_fixer is None or not self.llm_fixer.available:
+            self.logger.warning(
+                "LLM retry disabled or ducc not available; re-raising original error."
+            )
+            raise first_err
+
+        generated_dir = self.workspace.get_generated_dir(model_id)
+        err = first_err
+        current_script = script_path
+
+        for attempt in range(1, 3):  # attempt 1, 2
+            self.logger.warning(
+                f"Extraction failed (attempt {attempt}/2): {err}\n"
+                f"Invoking LLM to fix the script..."
+            )
+            fixed_path = self.llm_fixer.fix(
+                script_path=current_script,
+                error_msg=str(err),
+                model_dir=model_dir,
+                model_id=model_id,
+                output_dir=generated_dir,
+                attempt=attempt,
+            )
+            self.logger.info(f"Retrying extraction with LLM-fixed script: {fixed_path}")
+            try:
+                sample_dir = self._extract_graph(fixed_path, model_id)
+                return sample_dir
+            except ExtractionError as retry_err:
+                err = retry_err
+                current_script = fixed_path  # 第二次把上一次修复的脚本+新报错再喂给 LLM
+
+        raise err
+
     def _fetch_model(self, model_id: str) -> Path:
         """Download model from HuggingFace Hub"""
         self.logger.info(f"Fetching model: {model_id}")
@@ -112,6 +188,9 @@ def _analyze_model(self, model_dir: Path):
     def _generate_script(self, model_dir: Path, model_metadata, model_id: str) -> Path:
         """Generate run_model.py script based on metadata"""
         self.logger.info("Generating extraction script")
+        # Override model_id in metadata with the original HF model_id so that
+        # extract(name=...) uses the short model name, not a snapshot hash.
+        model_metadata.model_id = model_id
         generated_dir = self.workspace.get_generated_dir(model_id)
         script_path = self.code_generator.generate(
             model_dir, model_metadata, generated_dir
@@ -126,10 +205,21 @@ def _extract_graph(self, script_path: Path, model_id: str) -> Path:
         self.logger.info(f"Graph extracted to: {sample_dir}")
         return sample_dir
 
-    def _archive_script(self, script_path: Path, sample_dir: Path) -> None:
-        """Archive generated script to sample directory"""
-        self.logger.info("Archiving extraction script")
-        self.save_extraction_script(script_path, sample_dir)
+    def _fix_model_name(self, sample_dir: Path, model_id: str) -> None:
+        """将 graph_net.json 中的 model_name 修正为原始 HuggingFace model_id(org/model)"""
+        for json_path in [
+            sample_dir / "graph_net.json",
+            *sample_dir.glob("subgraph_*/graph_net.json"),
+        ]:
+            if not json_path.exists():
+                continue
+            try:
+                data = json.loads(json_path.read_text())
+                if data.get("model_name") != model_id:
+                    data["model_name"] = model_id
+                    json_path.write_text(json.dumps(data, indent=4))
+            except (OSError, json.JSONDecodeError) as e:
+                self.logger.warning(f"Failed to fix model_name in {json_path}: {e}")
 
     def _generate_graph_hash(self, sample_dir: Path) -> None:
         """Generate graph_hash.txt from model.py if it doesn't exist"""
@@ -180,14 +270,3 @@ def is_duplicate_sample(self, sample_dir: Path) -> bool:
         except (OSError, IOError) as e:
             self.logger.warning(f"Failed to check duplicate: {e}")
             return False
-
-    def save_extraction_script(self, script_path: Path, sample_dir: Path) -> bool:
-        """Save the generated extraction script to the sample directory"""
-        try:
-            target_path = sample_dir / "run_model.py"
-            shutil.copy(script_path, target_path)
-            self.logger.info(f"Script archived to: {target_path}")
-            return True
-        except (OSError, IOError, shutil.Error) as e:
-            self.logger.error(f"Failed to archive script: {e}")
-            return False
diff --git 
a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py index bc37fb71ef..7b8501f5c0 100644 --- a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py +++ b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py @@ -9,7 +9,11 @@ from graph_net.agent.utils.exceptions import AnalysisError -# Common embedding weight keys in different model architectures +# Cap sequence length to avoid OOM: attention is O(n²), graph extraction +# only needs a short sequence to trace the computation graph. +_MAX_SEQ_LEN = 128 +# Cap image size to avoid OOM on high-resolution configs. +_MAX_IMAGE_SIZE = 512 _EMBEDDING_WEIGHT_KEYS = [ "embeddings.word_embeddings.weight", "model.embed_tokens.weight", @@ -106,7 +110,10 @@ def _extract_input_info( # Common patterns for NLP models if "max_position_embeddings" in config or "vocab_size" in config: # NLP model (BERT, GPT, etc.) - max_length = config.get("max_position_embeddings", 512) + # Cap to _MAX_SEQ_LEN: large models set max_position_embeddings to + # 131072+ which causes OOM via O(n²) attention during graph tracing. + raw_len = config.get("max_position_embeddings", 512) + max_length = min(raw_len, _MAX_SEQ_LEN) batch_size = 1 input_shapes["input_ids"] = [batch_size, max_length] input_dtypes["input_ids"] = "int64" @@ -119,7 +126,11 @@ def _extract_input_info( # Common patterns for vision models elif "image_size" in config or "num_channels" in config: # Vision model (ResNet, ViT, etc.) - image_size = config.get("image_size", 224) + # image_size may be an int or a [H, W] list + raw_size = config.get("image_size", 224) + if isinstance(raw_size, (list, tuple)): + raw_size = raw_size[0] + image_size = min(int(raw_size), _MAX_IMAGE_SIZE) num_channels = config.get("num_channels", 3) batch_size = 1 input_shapes["pixel_values"] = [ diff --git a/graph_net/agent/model_fetcher/huggingface_fetcher.py b/graph_net/agent/model_fetcher/huggingface_fetcher.py index 024c6496cd..903e4984b6 100644 --- a/graph_net/agent/model_fetcher/huggingface_fetcher.py +++ b/graph_net/agent/model_fetcher/huggingface_fetcher.py @@ -1,5 +1,7 @@ """HuggingFace model fetcher implementation""" +import os +import time from pathlib import Path from typing import Optional @@ -11,22 +13,63 @@ from graph_net.agent.model_fetcher.base import BaseModelFetcher from graph_net.agent.utils.exceptions import ModelFetchError +# Network-related exceptions that are worth retrying +_RETRYABLE_ERRORS = ( + ConnectionError, + TimeoutError, + OSError, +) + +# Try to import httpx/huggingface_hub errors for more granular retry +try: + import httpx + + _RETRYABLE_ERRORS = _RETRYABLE_ERRORS + (httpx.ConnectTimeout, httpx.ReadTimeout) +except ImportError: + pass + +try: + from huggingface_hub.errors import LocalEntryNotFoundError + + _RETRYABLE_ERRORS = _RETRYABLE_ERRORS + (LocalEntryNotFoundError,) +except ImportError: + pass + class HFFetcher(BaseModelFetcher): """HuggingFace model fetcher using huggingface_hub""" - def __init__(self, cache_dir: Optional[str] = None, token: Optional[str] = None): + DEFAULT_MAX_RETRIES = 3 + DEFAULT_RETRY_DELAY = 5 # seconds, will be exponentially backed off + + def __init__( + self, + cache_dir: Optional[str] = None, + token: Optional[str] = None, + max_retries: int = DEFAULT_MAX_RETRIES, + retry_delay: float = DEFAULT_RETRY_DELAY, + endpoint: Optional[str] = None, + ): """ Args: - cache_dir: Directory to cache downloaded models - token: HuggingFace API token (optional, for private models) + cache_dir: 
Directory to cache downloaded models + token: HuggingFace API token (optional, for private models) + max_retries: Max retry attempts on network errors (default 3) + retry_delay: Initial delay between retries in seconds (default 5, exponential backoff) + endpoint: HuggingFace mirror endpoint (e.g., "https://hf-mirror.com"). + If not set, falls back to HF_ENDPOINT env var. """ self.cache_dir = Path(cache_dir) if cache_dir else None self.token = token + self.max_retries = max_retries + self.retry_delay = retry_delay + + # Resolve endpoint: explicit param > env var + self.endpoint = endpoint or os.environ.get("HF_ENDPOINT") def download(self, model_id: str) -> Path: """ - Download model from HuggingFace Hub + Download model from HuggingFace Hub with retry on network errors. Args: model_id: HuggingFace model ID (e.g., "bert-base-uncased") @@ -35,7 +78,7 @@ def download(self, model_id: str) -> Path: Path to local model directory Raises: - ModelFetchError: If download fails + ModelFetchError: If download fails after all retries """ if snapshot_download is None: raise ModelFetchError( @@ -43,13 +86,67 @@ def download(self, model_id: str) -> Path: "Please install it with: pip install huggingface_hub" ) - try: - # Use snapshot_download to get all model files - local_dir = snapshot_download( - repo_id=model_id, - cache_dir=str(self.cache_dir) if self.cache_dir else None, - token=self.token, - ) - return Path(local_dir) - except Exception as e: - raise ModelFetchError(f"Failed to download model {model_id}: {e}") from e + last_err = None + for attempt in range(1, self.max_retries + 1): + try: + # Set endpoint for this call if configured + if self.endpoint: + os.environ["HF_ENDPOINT"] = self.endpoint + + local_dir = snapshot_download( + repo_id=model_id, + cache_dir=str(self.cache_dir) if self.cache_dir else None, + token=self.token, + ignore_patterns=[ + "*.bin", + "*.safetensors", + "*.pt", + "*.pth", + "*.gguf", + "*.ot", + "*.zip", + "*.tflite", + "*.mlmodel", + "*.onnx", + "*.msgpack", + "flax_model*", + "tf_model*", + "rust_model*", + ], + ) + return Path(local_dir) + + except _RETRYABLE_ERRORS as e: + last_err = e + if attempt < self.max_retries: + delay = self.retry_delay * (2 ** (attempt - 1)) + # Check if the error message indicates a timeout — these are worth retrying + err_msg = str(e).lower() + is_timeout = any( + kw in err_msg + for kw in ("timeout", "timed out", "connection", "refused") + ) + if not is_timeout: + raise ModelFetchError( + f"Failed to download model {model_id}: {e}" + ) from e + + print( + f"[HFFetcher] Network error for {model_id} " + f"(attempt {attempt}/{self.max_retries}): {e}. " + f"Retrying in {delay}s..." + ) + time.sleep(delay) + else: + raise ModelFetchError( + f"Failed to download model {model_id} after {self.max_retries} retries: {e}" + ) from e + except Exception as e: + raise ModelFetchError( + f"Failed to download model {model_id}: {e}" + ) from e + + # Should not reach here, but just in case + raise ModelFetchError( + f"Failed to download model {model_id} after {self.max_retries} retries: {last_err}" + ) diff --git a/graph_net/agent/parallel_extract.py b/graph_net/agent/parallel_extract.py new file mode 100644 index 0000000000..957c25b313 --- /dev/null +++ b/graph_net/agent/parallel_extract.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 +""" +Parallel extraction script: dynamic scheduling, all models placed in a shared task queue, +each GPU worker picks up tasks when idle. 
+ +Usage examples: + # Load model list from file (one model_id per line) + python parallel_extract.py --model-list models.txt + + # Fetch 400 models from HuggingFace Hub + python parallel_extract.py --count 400 + + # Specify workspace and HF token + python parallel_extract.py --model-list models.txt \ + --workspace /data/graphnet_workspace \ + --hf-token YOUR_TOKEN + + # Specify GPUs to use (default: auto-detect all available GPUs) + python parallel_extract.py --model-list models.txt --gpus 0,1,2,3 +""" + +import argparse +import json +import multiprocessing +import os +import queue +import subprocess +import sys +import time +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + +# Ensure graph_net is importable +_SCRIPT_DIR = Path(__file__).resolve().parent +_GRAPHNET_ROOT = _SCRIPT_DIR.parent.parent # GraphNet/ +if str(_GRAPHNET_ROOT) not in sys.path: + sys.path.insert(0, str(_GRAPHNET_ROOT)) + +from graph_net.agent import GraphNetAgent # noqa: E402 + +try: + from huggingface_hub import list_models as _hf_list_models + + HUGGINGFACE_HUB_AVAILABLE = True +except ImportError: + HUGGINGFACE_HUB_AVAILABLE = False + + +def load_models_from_file(path: str) -> List[str]: + models = [] + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + models.append(line) + return models + + +def get_models_from_hf(task: Optional[str] = None, limit: int = 100) -> List[str]: + return [ + m.modelId + for m in _hf_list_models(task=task, limit=limit, sort="downloads", direction=-1) + ] + + +def _get_default_gpus() -> List[int]: + """Detect available GPU indices from environment or nvidia-smi.""" + cvd = os.getenv("CUDA_VISIBLE_DEVICES", "") + if cvd: + try: + return [int(g.strip()) for g in cvd.split(",") if g.strip()] + except ValueError: + pass + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + gpus = [ + int(x.strip()) for x in result.stdout.strip().split("\n") if x.strip() + ] + if gpus: + return gpus + except Exception: + pass + return [0] + + +DEFAULT_GPUS = _get_default_gpus() +DEFAULT_WORKSPACE = os.environ.get( + "GRAPH_NET_EXTRACT_WORKSPACE", + os.path.expanduser("~/graphnet_workspace"), +) + + +# --------------------------------------------------------------------------- +# Worker — runs in a dedicated subprocess, bound to one GPU +# --------------------------------------------------------------------------- + + +def _worker( + gpu_id: int, + task_queue: multiprocessing.Queue, + result_queue: multiprocessing.Queue, + workspace: str, + hf_token: Optional[str], + total: int, +) -> None: + """ + Worker function, runs in a dedicated subprocess bound to a single GPU. + Dynamically pulls tasks from task_queue and exits when the queue is empty. + + Args: + gpu_id: CUDA device index (e.g. 
2) + task_queue: Shared task queue; each item is a model_id string + result_queue: Queue for reporting results back to the main process + workspace: Root workspace directory path + hf_token: HuggingFace token (optional) + total: Total task count (used for logging only) + """ + # Bind GPU: subprocess only sees this card, internal code can use cuda:0 + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + # Pass workspace to the environment variable used by SubprocessGraphExtractor + os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = workspace + + print(f"[GPU {gpu_id}] Worker started", flush=True) + + try: + agent = GraphNetAgent(workspace=workspace, hf_token=hf_token, llm_retry=False) + except Exception as e: + print(f"[GPU {gpu_id}] Failed to initialize agent: {e}", flush=True) + # Drain queue and mark remaining tasks as failed to avoid blocking the main process + while True: + try: + mid = task_queue.get_nowait() + result_queue.put( + { + "gpu": gpu_id, + "model_id": mid, + "success": False, + "error": str(e), + "elapsed": 0.0, + } + ) + except queue.Empty: + break + return + + while True: + try: + model_id = task_queue.get_nowait() + except queue.Empty: + break + + print(f"[GPU {gpu_id}] Extracting: {model_id}", flush=True) + t0 = time.time() + try: + success = agent.extract_sample(model_id) + elapsed = time.time() - t0 + status = "OK" if success else "FAIL" + print(f"[GPU {gpu_id}] {status} {model_id} ({elapsed:.1f}s)", flush=True) + result_queue.put( + { + "gpu": gpu_id, + "model_id": model_id, + "success": success, + "elapsed": round(elapsed, 2), + "timestamp": datetime.now().isoformat(), + } + ) + except Exception as e: + elapsed = time.time() - t0 + print(f"[GPU {gpu_id}] ERROR {model_id}: {e} ({elapsed:.1f}s)", flush=True) + result_queue.put( + { + "gpu": gpu_id, + "model_id": model_id, + "success": False, + "error": str(e), + "elapsed": round(elapsed, 2), + "timestamp": datetime.now().isoformat(), + } + ) + + print(f"[GPU {gpu_id}] Worker finished (queue empty)", flush=True) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _save_results(results: Dict, output_file: str) -> None: + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + print(f"\n[SAVE] Results saved to: {output_path}") + + +def _print_summary(results: Dict) -> None: + details = results.get("details", []) + total = len(details) + success = sum(1 for d in details if d.get("success")) + failed = total - success + rate = (success / total * 100) if total else 0.0 + print("\n" + "=" * 60) + print("[SUMMARY] Parallel Extraction Summary") + print("=" * 60) + print(f" Total : {total}") + print(f" Success: {success}") + print(f" Failed : {failed}") + print(f" Rate : {rate:.2f}%") + # Per-GPU breakdown + gpu_stats: Dict[int, Dict] = {} + for d in details: + g = d.get("gpu", -1) + if g not in gpu_stats: + gpu_stats[g] = {"total": 0, "success": 0} + gpu_stats[g]["total"] += 1 + if d.get("success"): + gpu_stats[g]["success"] += 1 + print("\n Per-GPU:") + for g in sorted(gpu_stats): + gs = gpu_stats[g] + gr = (gs["success"] / gs["total"] * 100) if gs["total"] else 0.0 + print(f" GPU {g}: {gs['success']}/{gs['total']} ({gr:.1f}%)") + print("=" * 60) + + +# --------------------------------------------------------------------------- +# Main +# 
--------------------------------------------------------------------------- + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Parallel computation graph extraction from HuggingFace; one agent process per GPU" + ) + parser.add_argument( + "--model-list", + type=str, + default=None, + help="Path to model list file (one model_id per line, lines starting with # are comments)", + ) + parser.add_argument( + "--count", + type=int, + default=100, + help="Number of models to fetch from HuggingFace Hub (used when --model-list is not set, default 100)", + ) + parser.add_argument( + "--task", + type=str, + default=None, + help="HuggingFace task filter (e.g. text-classification)", + ) + parser.add_argument( + "--workspace", + type=str, + default=None, + help=f"Root workspace directory (default: {DEFAULT_WORKSPACE} or GRAPH_NET_EXTRACT_WORKSPACE env var)", + ) + parser.add_argument( + "--hf-token", + type=str, + default=None, + help="HuggingFace API Token (required for private models)", + ) + parser.add_argument( + "--gpus", + type=str, + default=",".join(str(g) for g in DEFAULT_GPUS), + help=f"Comma-separated GPU indices to use (default: {','.join(str(g) for g in DEFAULT_GPUS)})", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output JSON file path (default: auto-generated filename with timestamp)", + ) + + args = parser.parse_args() + + # --- Resolve workspace --- + workspace = ( + args.workspace or os.getenv("GRAPH_NET_EXTRACT_WORKSPACE") or DEFAULT_WORKSPACE + ) + print(f"[INFO] Workspace: {workspace}") + + # --- Parse GPU list --- + try: + gpus = [int(g.strip()) for g in args.gpus.split(",") if g.strip()] + except ValueError: + print(f"[ERROR] Invalid --gpus value: {args.gpus}") + return 1 + if not gpus: + print("[ERROR] No GPUs specified") + return 1 + print(f"[INFO] GPUs: {gpus}") + + # --- Load model list --- + if args.model_list: + model_ids = load_models_from_file(args.model_list) + elif HUGGINGFACE_HUB_AVAILABLE: + print( + f"[INFO] Fetching {args.count} models from HuggingFace Hub (task={args.task})..." 
+ ) + model_ids = get_models_from_hf(task=args.task, limit=args.count) + else: + print("[ERROR] No model list provided and huggingface_hub not available") + return 1 + + if not model_ids: + print("[ERROR] Empty model list, nothing to do") + return 1 + + print(f"[INFO] Total models: {len(model_ids)}, workers: {len(gpus)}") + + # --- Populate shared task queue --- + task_queue: multiprocessing.Queue = multiprocessing.Queue() + for mid in model_ids: + task_queue.put(mid) + + # --- Launch workers --- + result_queue: multiprocessing.Queue = multiprocessing.Queue() + processes = [] + + start_time = datetime.now() + print( + f"\n[START] {start_time.strftime('%Y-%m-%d %H:%M:%S')} — launching {len(gpus)} workers\n" + ) + + for gpu_id in gpus: + p = multiprocessing.Process( + target=_worker, + args=( + gpu_id, + task_queue, + result_queue, + workspace, + args.hf_token, + len(model_ids), + ), + name=f"worker-gpu{gpu_id}", + daemon=True, + ) + p.start() + processes.append(p) + + # --- Collect results --- + details = [] + total_expected = len(model_ids) + + while len(details) < total_expected: + try: + entry = result_queue.get(timeout=5) + details.append(entry) + done = len(details) + success_so_far = sum(1 for d in details if d.get("success")) + print( + f"[PROGRESS] {done}/{total_expected} done, " + f"success rate so far: {success_so_far/done*100:.1f}%", + flush=True, + ) + except Exception: + # Check if all workers are done + alive = [p for p in processes if p.is_alive()] + if not alive: + break + + # Wait for all workers to exit cleanly + for p in processes: + p.join(timeout=10) + + end_time = datetime.now() + elapsed_total = (end_time - start_time).total_seconds() + + # --- Build result summary --- + results = { + "start_time": start_time.isoformat(), + "end_time": end_time.isoformat(), + "elapsed_seconds": round(elapsed_total, 1), + "gpus": gpus, + "workspace": workspace, + "total": len(details), + "success": sum(1 for d in details if d.get("success")), + "failed": sum(1 for d in details if not d.get("success")), + "success_rate": 0.0, + "details": details, + } + if results["total"] > 0: + results["success_rate"] = round(results["success"] / results["total"] * 100, 2) + + # --- Save results --- + output_file = ( + args.output or f"parallel_extract_{start_time.strftime('%Y%m%d_%H%M%S')}.json" + ) + _save_results(results, output_file) + + # --- Print summary --- + _print_summary(results) + print(f"\n[DONE] Total elapsed: {elapsed_total:.0f}s") + + return 0 if results["success_rate"] > 0 else 1 + + +if __name__ == "__main__": + # Linux defaults to fork; explicitly use spawn to avoid CUDA fork issues + multiprocessing.set_start_method("spawn", force=True) + sys.exit(main()) diff --git a/graph_net/agent/sample_verifier/__init__.py b/graph_net/agent/sample_verifier/__init__.py index 9302aea0cd..22104c7957 100644 --- a/graph_net/agent/sample_verifier/__init__.py +++ b/graph_net/agent/sample_verifier/__init__.py @@ -2,5 +2,6 @@ from graph_net.agent.sample_verifier.base import BaseSampleVerifier from graph_net.agent.sample_verifier.basic_sample_verifier import BasicSampleVerifier +from graph_net.agent.sample_verifier.forward_verifier import ForwardVerifier -__all__ = ["BaseSampleVerifier", "BasicSampleVerifier"] +__all__ = ["BaseSampleVerifier", "BasicSampleVerifier", "ForwardVerifier"] diff --git a/graph_net/agent/sample_verifier/basic_sample_verifier.py b/graph_net/agent/sample_verifier/basic_sample_verifier.py index 4b639eec3f..70e50e20d0 100644 --- 
a/graph_net/agent/sample_verifier/basic_sample_verifier.py +++ b/graph_net/agent/sample_verifier/basic_sample_verifier.py @@ -1,5 +1,6 @@ """Basic sample verifier implementation""" +import json from pathlib import Path from graph_net.agent.sample_verifier.base import BaseSampleVerifier @@ -7,10 +8,12 @@ class BasicSampleVerifier(BaseSampleVerifier): - """Basic verifier that checks file existence and basic structure""" + """Basic verifier that checks file existence and basic structure. + + Supports both single-graph and multi-subgraph (subgraph_0/, subgraph_1/, …) layouts. + """ def __init__(self): - """Initialize basic verifier""" self.required_files = [ "model.py", "graph_net.json", @@ -19,32 +22,20 @@ def __init__(self): ] def verify(self, sample_dir: Path) -> bool: - """ - Verify sample validity - - Args: - sample_dir: Path to sample directory - - Returns: - True if sample is valid, False otherwise - """ try: - # Check required files exist - for filename in self.required_files: - file_path = sample_dir / filename - if not file_path.exists(): + subgraph_dirs = sorted(sample_dir.glob("subgraph_*/")) + targets = subgraph_dirs if subgraph_dirs else [sample_dir] + + for target in targets: + for filename in self.required_files: + if not (target / filename).exists(): + return False + try: + with open(target / "graph_net.json", "r") as f: + json.load(f) + except (json.JSONDecodeError, IOError): return False - # Check graph_net.json is valid JSON - json_path = sample_dir / "graph_net.json" - try: - import json - - with open(json_path, "r") as f: - json.load(f) - except (json.JSONDecodeError, IOError): - return False - return True except Exception as e: raise VerificationError(f"Verification failed: {e}") from e diff --git a/graph_net/agent/sample_verifier/forward_verifier.py b/graph_net/agent/sample_verifier/forward_verifier.py new file mode 100644 index 0000000000..5f8ba03235 --- /dev/null +++ b/graph_net/agent/sample_verifier/forward_verifier.py @@ -0,0 +1,104 @@ +"""Forward-pass verifier: eager execution of extracted GraphModule.""" + +import logging +import subprocess +import sys +from pathlib import Path + +from graph_net.agent.sample_verifier.base import BaseSampleVerifier +from graph_net.agent.sample_verifier.basic_sample_verifier import BasicSampleVerifier +from graph_net.agent.utils.exceptions import VerificationError + +# Inline eager runner — executed in a subprocess to isolate CUDA state. +# Loads GraphModule from model.py, reconstructs tensors from weight_meta.py, +# and runs a plain eager forward pass (no dynamo / compile). +_EAGER_RUNNER = """ +import sys, importlib.util, torch +from graph_net.torch import utils + +model_path = sys.argv[1] + +spec = importlib.util.spec_from_file_location("m", f"{model_path}/model.py") +mod = importlib.util.module_from_spec(spec) +spec.loader.exec_module(mod) +model = mod.GraphModule().eval() + +inputs_params = utils.load_converted_from_text(model_path) +state_dict = {k: utils.replay_tensor(v) for k, v in inputs_params["weight_info"].items()} + +with torch.no_grad(): + model(**state_dict) +""" + + +class ForwardVerifier(BaseSampleVerifier): + """ + Two-stage verifier: + 1. BasicSampleVerifier — file existence & JSON validity + 2. Eager forward pass — plain model(**inputs), no dynamo/compile + + For multi-subgraph samples (subgraph_0/, subgraph_1/, …) each subgraph + is verified independently; all must pass. 
+ """ + + def __init__(self, timeout: int = 300): + """ + Args: + timeout: seconds to wait for each forward-pass subprocess (default 5 min) + """ + self._basic = BasicSampleVerifier() + self.timeout = timeout + self.logger = logging.getLogger(self.__class__.__name__) + + def verify(self, sample_dir: Path) -> bool: + """ + Verify sample validity including a real eager forward pass. + + Args: + sample_dir: Path to extracted sample directory + + Returns: + True if all checks pass, False otherwise + """ + try: + # Stage 1: file structure check + if not self._basic.verify(sample_dir): + self.logger.warning(f"Basic verification failed: {sample_dir}") + return False + + # Stage 2: eager forward pass (per subgraph if multi-subgraph) + subgraph_dirs = sorted(sample_dir.glob("subgraph_*/")) + targets = subgraph_dirs if subgraph_dirs else [sample_dir] + + for target in targets: + if not self._run_forward(target): + return False + + return True + + except Exception as e: + raise VerificationError(f"Forward verification failed: {e}") from e + + def _run_forward(self, model_path: Path) -> bool: + """Run an eager forward pass on one model directory in a subprocess.""" + self.logger.info(f"Forward verify (eager): {model_path.name}") + try: + result = subprocess.run( + [sys.executable, "-c", _EAGER_RUNNER, str(model_path)], + capture_output=True, + text=True, + timeout=self.timeout, + ) + if result.returncode == 0: + self.logger.info(f"Forward verify OK: {model_path.name}") + return True + else: + self.logger.warning( + f"Forward verify FAIL: {model_path.name}\n{result.stderr[-2000:]}" + ) + return False + except subprocess.TimeoutExpired: + self.logger.warning( + f"Forward verify TIMEOUT ({self.timeout}s): {model_path.name}" + ) + return False diff --git a/graph_net/agent/tests/__init__.py b/graph_net/agent/tests/__init__.py deleted file mode 100644 index cc01f8e4a3..0000000000 --- a/graph_net/agent/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for Agent module""" diff --git a/graph_net/agent/tests/run_500_models_test.py b/graph_net/agent/tests/run_500_models_test.py deleted file mode 100755 index 561ae319b5..0000000000 --- a/graph_net/agent/tests/run_500_models_test.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python3 -"""Batch test script for model extraction success rate""" - -import argparse -import os -import sys -from datetime import datetime - -from graph_net.agent import GraphNetAgent -from graph_net.agent.tests.test_batch_success_rate import ( - get_models_from_hf, - run_batch_test, - HUGGINGFACE_HUB_AVAILABLE, -) - -# Task distribution ratios for mixed task testing -TEXT_CLASSIFICATION_RATIO = 0.4 -TEXT_GENERATION_RATIO = 0.4 -QUESTION_ANSWERING_RATIO = 0.2 - - -def main(): - """Run batch test""" - - parser = argparse.ArgumentParser( - description="Batch test model extraction success rate" - ) - parser.add_argument( - "--count", type=int, default=100, help="Number of models to test (default: 100)" - ) - parser.add_argument( - "--task", - type=str, - default=None, - help="HuggingFace task type (default: None, mixed tasks)", - ) - - args = parser.parse_args() - model_count = args.count - - print("=" * 70) - print(f"[START] GraphNet Agent Batch Test - {model_count} models") - print("=" * 70) - - workspace = os.getenv("GRAPH_NET_EXTRACT_WORKSPACE") - if not workspace: - print("\n[ERROR] GRAPH_NET_EXTRACT_WORKSPACE environment variable not set") - print("\nPlease set workspace:") - print(" export GRAPH_NET_EXTRACT_WORKSPACE=/path/to/workspace") - sys.exit(1) - - print(f"\n[INFO] 
-
-    if not HUGGINGFACE_HUB_AVAILABLE:
-        print("\n[ERROR] huggingface_hub not installed")
-        print("Please install: pip install huggingface_hub")
-        sys.exit(1)
-
-    print(f"\n[INFO] Fetching {model_count} models from HuggingFace Hub...")
-
-    if args.task:
-        print(f"  Task type: {args.task}")
-    else:
-        print("  Task type: Mixed NLP tasks")
-    print("  This may take some time...\n")
-
-    try:
-        model_ids = []
-
-        if args.task:
-            print(f"  - Fetching {args.task} models...")
-            models = get_models_from_hf(task=args.task, limit=model_count)
-            model_ids.extend(models)
-            print(f"    Fetched {len(models)} models")
-        else:
-            tasks = [
-                ("text-classification", int(model_count * TEXT_CLASSIFICATION_RATIO)),
-                ("text-generation", int(model_count * TEXT_GENERATION_RATIO)),
-                ("question-answering", int(model_count * QUESTION_ANSWERING_RATIO)),
-            ]
-
-            for task, limit in tasks:
-                print(f"  - Fetching {task} models...")
-                models = get_models_from_hf(task=task, limit=limit)
-                model_ids.extend(models)
-                print(f"    Fetched {len(models)} models")
-
-            model_ids = list(dict.fromkeys(model_ids))
-
-            if len(model_ids) < model_count:
-                print(f"  - Current: {len(model_ids)} models, fetching more...")
-                additional = get_models_from_hf(
-                    task=None, limit=model_count - len(model_ids)
-                )
-                model_ids.extend(additional)
-                model_ids = list(dict.fromkeys(model_ids))
-
-        model_ids = model_ids[:model_count]
-
-        print(f"\n[OK] Successfully fetched {len(model_ids)} models")
-
-    except Exception as e:
-        print(f"\n[ERROR] Failed to fetch model list: {e}")
-        import traceback
-
-        traceback.print_exc()
-        sys.exit(1)
-
-    print("\n[INIT] Initializing Agent...")
-    try:
-        agent = GraphNetAgent(workspace=workspace, hf_token=None)
-        print("[OK] Agent initialized successfully\n")
-    except Exception as e:
-        print(f"[ERROR] Failed to initialize Agent: {e}")
-        sys.exit(1)
-
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    output_file = f"batch_test_{model_count}_models_{timestamp}.json"
-
-    print("=" * 70)
-    print("[INFO] Starting batch test")
-    print(f"  Model count: {len(model_ids)}")
-    print(f"  Output file: {output_file}")
-    print("=" * 70)
-    print("\n[TIP] Notes:")
-    print("  - Test may take hours or longer")
-    print("  - Press Ctrl+C to interrupt anytime")
-    print("  - Results are saved to JSON file in real-time")
-    print("  - Recommended to run with screen or nohup")
-    print("\n" + "=" * 70 + "\n")
-
-    try:
-        results = run_batch_test(
-            agent=agent,
-            model_ids=model_ids,
-            output_file=output_file,
-        )
-
-        print("\n" + "=" * 70)
-        print("[DONE] Test completed")
-        print("=" * 70)
-        print(f"Total models: {results['total']}")
-        print(f"Success: {results['success']}")
-        print(f"Failed: {results['failed']}")
-        print(f"Success rate: {results['success_rate']:.2f}%")
-        print(f"Output file: {output_file}")
-        print("=" * 70)
-
-        return 0
-
-    except KeyboardInterrupt:
-        print("\n\n[WARNING] Test interrupted by user")
-        print(f"Partial results saved to: {output_file}")
-        return 1
-    except Exception as e:
-        print(f"\n[ERROR] Error during test: {e}")
-        import traceback
-
-        traceback.print_exc()
-        return 1
-
-
-if __name__ == "__main__":
-    sys.exit(main())
diff --git a/graph_net/agent/tests/test_batch_success_rate.py b/graph_net/agent/tests/test_batch_success_rate.py
deleted file mode 100644
index 1e15a6a1ad..0000000000
--- a/graph_net/agent/tests/test_batch_success_rate.py
+++ /dev/null
@@ -1,281 +0,0 @@
-"""Batch testing script for GraphNet Agent success rate statistics"""
-
-import argparse
-import json
-import os
-import sys
-import time
-from datetime import datetime
-from pathlib import Path
-from typing import Dict, List, Optional
-
-try:
-    from huggingface_hub import list_models
-
-    HUGGINGFACE_HUB_AVAILABLE = True
-except ImportError:
-    HUGGINGFACE_HUB_AVAILABLE = False
-
-from graph_net.agent import GraphNetAgent
-
-# Default test models (common and small models)
-DEFAULT_TEST_MODELS = [
-    "bert-base-uncased",
-    "distilbert-base-uncased",
-    "roberta-base",
-    "gpt2",
-    "t5-small",
-    "albert-base-v2",
-    "google/bert_uncased_L-2_H-128_A-2",
-    "google/t5-efficient-mini",
-]
-
-
-def get_models_from_hf(task: Optional[str] = None, limit: int = 100) -> List[str]:
-    """Get model list from HuggingFace Hub"""
-    if not HUGGINGFACE_HUB_AVAILABLE:
-        print("[WARNING] huggingface_hub not installed, cannot fetch models from Hub")
-        return []
-
-    try:
-        print(
-            f"[INFO] Fetching models from HuggingFace Hub (task={task}, limit={limit})..."
-        )
-
-        search_params = {
-            "sort": "downloads",
-            "direction": -1,  # descending order
-            "limit": limit,
-        }
-
-        if task:
-            search_params["task"] = task
-
-        models = list(list_models(**search_params))
-        model_ids = [model.id for model in models]
-
-        print(f"[OK] Fetched {len(model_ids)} models")
-        return model_ids
-
-    except Exception as e:
-        print(f"[ERROR] Failed to fetch model list: {e}")
-        return []
-
-
-def load_models_from_file(file_path: str) -> List[str]:
-    """Load model list from file (one model ID per line)"""
-    try:
-        with open(file_path, "r", encoding="utf-8") as f:
-            models = [
-                line.strip() for line in f if line.strip() and not line.startswith("#")
-            ]
-        print(f"[OK] Loaded {len(models)} models from file: {file_path}")
-        return models
-    except (OSError, IOError) as e:
-        print(f"[ERROR] Failed to load model list: {e}")
-        return []
-
-
-def run_batch_test(
-    agent: GraphNetAgent,
-    model_ids: List[str],
-    output_file: Optional[str] = None,
-) -> Dict:
-    """Run batch test and calculate success rate"""
-    results = {
-        "total": len(model_ids),
-        "success": 0,
-        "failed": 0,
-        "success_rate": 0.0,
-        "start_time": datetime.now().isoformat(),
-        "details": [],
-    }
-
-    print(f"\n{'='*60}")
-    print(f"[START] Starting batch test: {len(model_ids)} models")
-    print(f"{'='*60}\n")
-
-    for idx, model_id in enumerate(model_ids, 1):
-        print(f"\n[{idx}/{len(model_ids)}] Testing: {model_id}")
-        print("-" * 60)
-
-        start_time = time.time()
-        try:
-            success = agent.extract_sample(model_id)
-            elapsed = time.time() - start_time
-
-            if success:
-                results["success"] += 1
-                status = "[OK] Success"
-            else:
-                results["failed"] += 1
-                status = "[FAIL] Failed"
-
-            result_entry = {
-                "model_id": model_id,
-                "success": success,
-                "elapsed_time": round(elapsed, 2),
-                "timestamp": datetime.now().isoformat(),
-            }
-            results["details"].append(result_entry)
-
-            print(f"{status} (elapsed: {elapsed:.2f}s)")
-
-        except KeyboardInterrupt:
-            print("\n[WARNING] Test interrupted by user")
-            break
-        except Exception as e:
-            elapsed = time.time() - start_time
-            results["failed"] += 1
-            result_entry = {
-                "model_id": model_id,
-                "success": False,
-                "error": str(e),
-                "elapsed_time": round(elapsed, 2),
-                "timestamp": datetime.now().isoformat(),
-            }
-            results["details"].append(result_entry)
-            print(f"[ERROR] Exception: {e} (elapsed: {elapsed:.2f}s)")
-
-        # Show real-time statistics
-        current_success_rate = (results["success"] / idx) * 100
-        print(
-            f"[STATS] Current success rate: {current_success_rate:.2f}% ({results['success']}/{idx})"
-        )
-
-    results["end_time"] = datetime.now().isoformat()
-    results["success_rate"] = (
(results["success"] / results["total"]) * 100 if results["total"] > 0 else 0.0 - ) - - # Save results - if output_file: - _save_results(results, output_file) - - # Print final statistics - _print_statistics(results) - - return results - - -def _save_results(results: Dict, output_file: str) -> None: - """Save test results to JSON file""" - output_path = Path(output_file) - output_path.parent.mkdir(parents=True, exist_ok=True) - try: - with open(output_path, "w", encoding="utf-8") as f: - json.dump(results, f, indent=2, ensure_ascii=False) - print(f"\n[SAVE] Results saved to: {output_path}") - except (OSError, IOError) as e: - print(f"[WARNING] Failed to save results: {e}") - - -def _print_statistics(results: Dict) -> None: - """Print final test statistics""" - print(f"\n{'='*60}") - print("[SUMMARY] Test Summary") - print(f"{'='*60}") - print(f"Total models: {results['total']}") - print(f"Success: {results['success']}") - print(f"Failed: {results['failed']}") - print(f"Success rate: {results['success_rate']:.2f}%") - print(f"{'='*60}\n") - - -def main(): - parser = argparse.ArgumentParser( - description="Batch test GraphNet Agent success rate" - ) - parser.add_argument( - "--model-list-file", - type=str, - help="Model list file path (one model ID per line)", - ) - parser.add_argument( - "--count", - type=int, - default=10, - help="Number of models to fetch from HuggingFace Hub (default: 10)", - ) - parser.add_argument( - "--task", - type=str, - help="HuggingFace task type (e.g., text-classification, text-generation)", - ) - parser.add_argument( - "--workspace", - type=str, - default=None, - help="Workspace path (default: use GRAPH_NET_EXTRACT_WORKSPACE env var)", - ) - parser.add_argument( - "--hf-token", type=str, default=None, help="HuggingFace API Token (optional)" - ) - parser.add_argument( - "--output", - type=str, - default="batch_test_results.json", - help="Output file path for results (default: batch_test_results.json)", - ) - parser.add_argument( - "--use-default-models", - action="store_true", - help="Use predefined default test model list", - ) - - args = parser.parse_args() - - workspace = args.workspace or os.getenv("GRAPH_NET_EXTRACT_WORKSPACE") - if not workspace: - print("[ERROR] workspace not specified") - print( - " Use --workspace or set GRAPH_NET_EXTRACT_WORKSPACE environment variable" - ) - sys.exit(1) - - model_ids = _get_model_list(args) - if not model_ids: - print("[ERROR] no models to test") - sys.exit(1) - - agent = _init_agent(workspace, args.hf_token) - - results = run_batch_test( - agent=agent, - model_ids=model_ids, - output_file=args.output, - ) - - sys.exit(0 if results["success_rate"] > 0 else 1) - - -def _get_model_list(args: argparse.Namespace) -> List[str]: - """Get model list from various sources""" - if args.use_default_models: - print(f"[INFO] Using default model list ({len(DEFAULT_TEST_MODELS)} models)") - return DEFAULT_TEST_MODELS - - if args.model_list_file: - return load_models_from_file(args.model_list_file) - - if HUGGINGFACE_HUB_AVAILABLE: - return get_models_from_hf(task=args.task, limit=args.count) - - print("[WARNING] No model source specified, using default model list") - return DEFAULT_TEST_MODELS - - -def _init_agent(workspace: str, hf_token: Optional[str]) -> GraphNetAgent: - """Initialize GraphNetAgent""" - print(f"\n[INIT] Initializing Agent (workspace: {workspace})...") - try: - agent = GraphNetAgent(workspace=workspace, hf_token=hf_token) - print("[OK] Agent initialized successfully\n") - return agent - except Exception as e: - 
print(f"[ERROR] Failed to initialize Agent: {e}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/graph_net/agent/tests/test_code_generator.py b/graph_net/agent/tests/test_code_generator.py deleted file mode 100644 index 0035b5bb35..0000000000 --- a/graph_net/agent/tests/test_code_generator.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Tests for code generation""" - -import tempfile -import unittest -from pathlib import Path - -from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata -from graph_net.agent.code_generator.template_generator import TemplateCodeGenerator - - -class TestTemplateCodeGenerator(unittest.TestCase): - """Test TemplateCodeGenerator""" - - def setUp(self): - """Set up test environment""" - self.generator = TemplateCodeGenerator() - self.temp_dir = tempfile.mkdtemp() - - def test_generate_code(self): - """Test code generation""" - model_dir = Path(self.temp_dir) / "model" - model_dir.mkdir() - - meta = ModelMetadata( - model_id="bert-base-uncased", - input_shapes={"input_ids": [1, 128]}, - input_dtypes={"input_ids": "int64"}, - model_type="bert", - ) - - output_dir = Path(self.temp_dir) / "output" - script_path = self.generator.generate(model_dir, meta, output_dir) - - self.assertTrue(script_path.exists()) - self.assertEqual(script_path.name, "run_model.py") - - # Check code content - code = script_path.read_text() - self.assertIn("bert-base-uncased", code) - self.assertIn("input_ids", code) - self.assertIn("graph_net.torch.extract", code) - - def test_generate_model_loader(self): - """Test model loader generation""" - model_dir = Path(self.temp_dir) / "model" - model_dir.mkdir() - - # Test BERT model - meta_bert = ModelMetadata( - model_id="bert-base-uncased", - input_shapes={"input_ids": [1, 128]}, - input_dtypes={"input_ids": "int64"}, - model_type="bert", - ) - load_code = self.generator._generate_model_loader(model_dir, meta_bert) - self.assertIn("AutoModel.from_pretrained", load_code) - - def test_generate_input_code(self): - """Test input code generation""" - meta = ModelMetadata( - model_id="test", - input_shapes={ - "input_ids": [1, 128], - "attention_mask": [1, 128], - }, - input_dtypes={ - "input_ids": "int64", - "attention_mask": "int64", - }, - ) - - input_code = self.generator._generate_input_code(meta) - self.assertIn("input_ids", input_code) - self.assertIn("attention_mask", input_code) - self.assertIn("torch.randn", input_code) - - -if __name__ == "__main__": - unittest.main() diff --git a/graph_net/agent/tests/test_integration.py b/graph_net/agent/tests/test_integration.py deleted file mode 100644 index 4ecf79b611..0000000000 --- a/graph_net/agent/tests/test_integration.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Integration tests for Agent end-to-end workflow""" - -import tempfile -import unittest -from pathlib import Path -from unittest.mock import Mock, patch - -from graph_net.agent.graph_net_agent import GraphNetAgent - - -class TestAgentIntegration(unittest.TestCase): - """Test Agent end-to-end workflow""" - - def setUp(self): - """Set up test environment""" - self.temp_dir = tempfile.mkdtemp() - self.agent = GraphNetAgent(workspace=self.temp_dir) - - @patch("graph_net.agent.model_fetcher.huggingface_fetcher.snapshot_download") - def test_agent_initialization(self, mock_download): - """Test Agent can be initialized""" - self.assertIsNotNone(self.agent) - self.assertIsNotNone(self.agent.model_fetcher) - self.assertIsNotNone(self.agent.metadata_analyzer) - self.assertIsNotNone(self.agent.code_generator) - 
-        self.assertIsNotNone(self.agent.graph_extractor)
-        self.assertIsNotNone(self.agent.sample_verifier)
-
-    @patch("graph_net.agent.model_fetcher.huggingface_fetcher.snapshot_download")
-    @patch("graph_net.agent.graph_extractor.subprocess_graph_extractor.subprocess.run")
-    def test_full_workflow_mock(self, mock_subprocess, mock_download):
-        """Test full workflow with mocked dependencies"""
-        # Mock model download
-        mock_model_dir = Path(self.temp_dir) / "models" / "test_model"
-        mock_model_dir.mkdir(parents=True)
-        (mock_model_dir / "config.json").write_text(
-            '{"model_type": "bert", "max_position_embeddings": 512}'
-        )
-        mock_download.return_value = str(mock_model_dir)
-
-        # Mock subprocess execution (success)
-        mock_subprocess.return_value = Mock(returncode=0, stdout="", stderr="")
-
-        # Mock output directory
-        mock_output_dir = Path(self.temp_dir) / "workspace" / "test_model"
-        mock_output_dir.mkdir(parents=True)
-        (mock_output_dir / "model.py").write_text("class GraphModule: pass")
-        (mock_output_dir / "graph_net.json").write_text('{"framework": "torch"}')
-        (mock_output_dir / "input_meta.py").write_text("")
-        (mock_output_dir / "weight_meta.py").write_text("")
-        (mock_output_dir / "graph_hash.txt").write_text("test_hash")
-
-        # Mock extractor to return output_dir
-        self.agent.graph_extractor._find_output_dir_robust = Mock(
-            return_value=mock_output_dir
-        )
-
-        # Run agent
-        result = self.agent.extract_sample("test-model")
-
-        # Should succeed (with mocked dependencies)
-        # Note: This will likely fail at subprocess execution in real scenario
-        # but tests the workflow structure
-        self.assertIsInstance(result, bool)
-
-    def test_deduplicate_logic(self):
-        """Test deduplicate logic"""
-        # Create a sample directory with graph_hash
-        sample_dir = Path(self.temp_dir) / "sample"
-        sample_dir.mkdir()
-        (sample_dir / "graph_hash.txt").write_text("test_hash_123")
-
-        # Test duplicate check
-        result = self.agent.is_duplicate_sample(sample_dir)
-        # Should return False if no duplicate found
-        self.assertIsInstance(result, bool)
-
-    def test_archive_logic(self):
-        """Test archive logic"""
-        # Create sample directory
-        sample_dir = Path(self.temp_dir) / "sample"
-        sample_dir.mkdir()
-
-        # Create a test script
-        script_path = Path(self.temp_dir) / "test_script.py"
-        script_path.write_text("print('test')")
-
-        # Test script archiving
-        result = self.agent.save_extraction_script(script_path, sample_dir)
-        self.assertTrue(result)
-        self.assertTrue((sample_dir / "run_model.py").exists())
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/graph_net/agent/tests/test_model_metadata.py b/graph_net/agent/tests/test_model_metadata.py
deleted file mode 100644
index c9bc68cca6..0000000000
--- a/graph_net/agent/tests/test_model_metadata.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""Tests for ModelMetadata"""
-
-import unittest
-
-from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata
-
-
-class TestModelMetadata(unittest.TestCase):
-    """Test ModelMetadata data class"""
-
-    def test_basic_creation(self):
-        """Test basic metadata creation"""
-        meta = ModelMetadata(
-            model_id="bert-base-uncased",
-            input_shapes={"input_ids": [1, 128]},
-            input_dtypes={"input_ids": "int64"},
-        )
-        self.assertEqual(meta.model_id, "bert-base-uncased")
-        self.assertEqual(meta.input_shapes["input_ids"], [1, 128])
-        self.assertEqual(meta.input_dtypes["input_ids"], "int64")
-
-    def test_multiple_inputs(self):
-        """Test metadata with multiple inputs"""
-        meta = ModelMetadata(
-            model_id="test-model",
-            input_shapes={
- "input_ids": [1, 128], - "attention_mask": [1, 128], - }, - input_dtypes={ - "input_ids": "int64", - "attention_mask": "int64", - }, - ) - self.assertEqual(len(meta.input_shapes), 2) - self.assertEqual(len(meta.input_dtypes), 2) - - def test_model_type(self): - """Test metadata with model type""" - meta = ModelMetadata( - model_id="bert-base-uncased", - input_shapes={"input_ids": [1, 128]}, - input_dtypes={"input_ids": "int64"}, - model_type="bert", - ) - self.assertEqual(meta.model_type, "bert") - - def test_empty_input_shapes_raises_error(self): - """Test that empty input_shapes raises error""" - with self.assertRaises(ValueError): - ModelMetadata( - model_id="test", - input_shapes={}, - input_dtypes={"input_ids": "int64"}, - ) - - def test_empty_input_dtypes_raises_error(self): - """Test that empty input_dtypes raises error""" - with self.assertRaises(ValueError): - ModelMetadata( - model_id="test", - input_shapes={"input_ids": [1, 128]}, - input_dtypes={}, - ) - - def test_mismatched_keys_raises_error(self): - """Test that mismatched keys raise error""" - with self.assertRaises(ValueError): - ModelMetadata( - model_id="test", - input_shapes={"input_ids": [1, 128]}, - input_dtypes={"attention_mask": "int64"}, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/graph_net/agent/tests/test_utils.py b/graph_net/agent/tests/test_utils.py deleted file mode 100644 index ccaeaf6e3c..0000000000 --- a/graph_net/agent/tests/test_utils.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Tests for utility modules""" - -import tempfile -import unittest - -from graph_net.agent.utils.exceptions import ( - AgentError, - ModelFetchError, - AnalysisError, - CodeGenError, - ExtractionError, - VerificationError, -) -from graph_net.agent.utils.workspace_manager import WorkspaceManager - - -class TestExceptions(unittest.TestCase): - """Test exception classes""" - - def test_exception_hierarchy(self): - """Test exception inheritance""" - self.assertTrue(issubclass(ModelFetchError, AgentError)) - self.assertTrue(issubclass(AnalysisError, AgentError)) - self.assertTrue(issubclass(CodeGenError, AgentError)) - self.assertTrue(issubclass(ExtractionError, AgentError)) - self.assertTrue(issubclass(VerificationError, AgentError)) - - def test_exception_raising(self): - """Test exception can be raised""" - with self.assertRaises(ModelFetchError): - raise ModelFetchError("Test error") - - -class TestWorkspaceManager(unittest.TestCase): - """Test WorkspaceManager""" - - def setUp(self): - """Set up test workspace""" - self.temp_dir = tempfile.mkdtemp() - self.workspace = WorkspaceManager(self.temp_dir) - - def test_directory_creation(self): - """Test workspace directories are created""" - self.assertTrue(self.workspace.models_dir.exists()) - self.assertTrue(self.workspace.generated_dir.exists()) - self.assertTrue(self.workspace.samples_dir.exists()) - self.assertTrue(self.workspace.logs_dir.exists()) - - def test_get_model_dir(self): - """Test model directory path generation""" - model_id = "bert-base-uncased" - model_dir = self.workspace.get_model_dir(model_id) - self.assertEqual(model_dir.name, "bert-base-uncased") - self.assertEqual(model_dir.parent, self.workspace.models_dir) - - def test_get_generated_dir(self): - """Test generated directory path generation""" - model_id = "test/model" - gen_dir = self.workspace.get_generated_dir(model_id) - self.assertEqual(gen_dir.name, "test_model") - self.assertEqual(gen_dir.parent, self.workspace.generated_dir) - - def test_get_sample_dir(self): - """Test sample directory 
-        model_id = "resnet50"
-        sample_dir = self.workspace.get_sample_dir(model_id)
-        self.assertEqual(sample_dir.name, "resnet50")
-        self.assertEqual(sample_dir.parent, self.workspace.samples_dir)
-
-    def test_get_log_path(self):
-        """Test log path generation"""
-        model_id = "test-model"
-        log_path = self.workspace.get_log_path(model_id, "20240101_120000")
-        self.assertIn("test-model", log_path.name)
-        self.assertIn("20240101_120000", log_path.name)
-        self.assertEqual(log_path.parent, self.workspace.logs_dir)
-
-
-if __name__ == "__main__":
-    unittest.main()
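
With the pytest suite above removed, the patch leaves no in-tree example that drives the new verifiers directly. A minimal sketch of how `ForwardVerifier` and `BasicSampleVerifier` can be exercised by hand, using only the APIs visible in the diff; the sample path below is hypothetical and assumes a directory already produced by `GraphNetAgent.extract_sample`:

```python
from pathlib import Path

from graph_net.agent.sample_verifier import BasicSampleVerifier, ForwardVerifier

# Hypothetical path: an extracted sample directory created by GraphNetAgent
sample_dir = Path("/work/graphnet_workspace/prajjwal1_bert-tiny")

# Stage 1 only: required files exist and graph_net.json parses
print(BasicSampleVerifier().verify(sample_dir))

# Stage 1 + 2: additionally runs an eager forward pass per (sub)graph in a
# subprocess; timeout is per subprocess, in seconds (default 300)
print(ForwardVerifier(timeout=600).verify(sample_dir))
```

Both verifiers return a plain `bool` and raise `VerificationError` only on unexpected internal failures, so the calls above can be dropped into any batch loop without extra exception handling.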