Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion graph_net/agent/code_generator/llm_code_fixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,27 @@
attention_mask = torch.ones((1, 64), dtype=torch.long).to(device)
decoder_input_ids = torch.randint(0, min(vocab_size-1, 1000), (1, 32), dtype=torch.long).to(device)

**MoE 类**(mixtral/qwen2_moe/deepseek_v2/dbrx/olmoe 等):
- 架构上仍是文本模型,输入与文本类完全相同(input_ids + attention_mask)
- 加载同样用 AutoModel.from_config(config),无需任何特殊处理
⚠️ vocab_size 通常很大(32000+),严格用 min(vocab_size-1, 30000) 作为 randint 上界
关键 config 字段:num_local_experts(Mixtral)/ num_experts(Qwen2-MoE)/ n_routed_experts(DeepSeek)

**扩散模型类**(UNet2DConditionModel / DiT / stable-diffusion / SDXL 等):
from diffusers import UNet2DConditionModel
_config = UNet2DConditionModel.load_config(model_dir)
model = UNet2DConditionModel.from_config(_config)
# 从 config 读取关键维度
in_channels = _config.get("in_channels", 4)
sample_size = _config.get("sample_size", 64)
cross_attention_dim = _config.get("cross_attention_dim", 768)
sample = torch.randn(1, in_channels, sample_size, sample_size).to(device)
timestep = torch.tensor([1]).to(device)
encoder_hidden_states = torch.randn(1, 77, cross_attention_dim).to(device)
# 调用必须用位置参数,不能 **inputs
wrapped(sample, timestep, encoder_hidden_states)
⚠️ dynamic 必须为 False;调用格式固定为位置参数,禁止用 **inputs dict 展开

## 【常见报错 → 修复方法】
| 报错关键词 | 修复方法 |
|---|---|
Expand All @@ -77,6 +98,9 @@
| "sentencepiece" / "tiktoken" ImportError | 不使用 tokenizer,用 torch.randint 直接构造 input_ids |
| "PendingUnbackedSymbolNotFound" | 确认 dynamic=False(不要改为 True) |
| decoder_input_ids missing | Seq2Seq 模型需要同时传 input_ids 和 decoder_input_ids |
| "encoder_hidden_states" required(UNet) | 扩散模型必须以位置参数传入 encoder_hidden_states,不能省略 |
| UNet sample/timestep 形状错误 | 检查 in_channels/sample_size/cross_attention_dim 是否从 config 正确读取 |
| MoE expert 路由 RuntimeError | 输入格式与普通文本模型相同,通常是 vocab 越界,检查 randint 上界是否 < vocab_size |
"""


Expand Down Expand Up @@ -312,9 +336,29 @@ def _extract_key_fields(model_dir: Path) -> str:
"patch_size",
"num_mel_bins",
"chunk_length",
# MoE routing (field names vary across models)
"num_local_experts",
"num_experts_per_tok",
"num_experts",
"n_routed_experts",
"moe_intermediate_size",
"num_shared_experts",
# Diffusion / UNet
"in_channels",
"sample_size",
"cross_attention_dim",
"layers_per_block",
# Seq2Seq
"is_encoder_decoder",
"decoder_start_token_id",
# GQA (Llama/Mistral family)
"num_key_value_heads",
# Audio
"feature_size",
"sample_rate",
]
result = {k: cfg[k] for k in keys if k in cfg}
# 对嵌套 config 只取 model_type
# 对嵌套 config 只取关键字段
for nested in ("audio_config", "vision_config", "text_config"):
if isinstance(result.get(nested), dict):
result[nested] = {
Expand All @@ -326,6 +370,10 @@ def _extract_key_fields(model_dir: Path) -> str:
"num_channels",
"num_mel_bins",
"hidden_size",
"num_local_experts",
"num_experts",
"n_routed_experts",
"sample_rate",
)
if k in result[nested]
}
Expand Down
106 changes: 92 additions & 14 deletions graph_net/agent/code_generator/template_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,19 @@ def _model_short_name(model_id: str) -> str:
return model_id.replace("/", "_")

def _generate_code(self, model_dir: Path, model_metadata: ModelMetadata) -> str:
"""Generate complete extraction script code string"""
# Generate model loading code
load_code = self._generate_model_loader(model_dir, model_metadata)
"""Generate complete extraction script code string."""
if model_metadata.architecture_type == "diffusion":
return self._generate_diffusion_code(model_dir, model_metadata)
return self._generate_standard_code(model_dir, model_metadata)

# Generate input construction code
def _generate_standard_code(
self, model_dir: Path, model_metadata: ModelMetadata
) -> str:
"""Generate standard (transformers-based) extraction script."""
load_code = self._generate_model_loader(model_dir, model_metadata)
input_code = self._generate_input_code(model_metadata)

short_name = self._model_short_name(model_metadata.model_id)

# Generate main code
code = f"""import torch
try:
from transformers import AutoModel
Expand Down Expand Up @@ -102,6 +105,48 @@ def main():
with torch.no_grad():
wrapped(**inputs)

if __name__ == "__main__":
main()
"""
return code

def _generate_diffusion_code(
self, model_dir: Path, model_metadata: ModelMetadata
) -> str:
"""Generate extraction script for diffusion models (diffusers UNet)."""
load_code = self._generate_model_loader(model_dir, model_metadata)
input_code = self._generate_input_code(model_metadata)
short_name = self._model_short_name(model_metadata.model_id)

# Diffusion model forward takes positional args, not **inputs dict
code = f"""import torch
try:
from diffusers import UNet2DConditionModel
except ImportError:
raise ImportError("diffusers is required. Install with: pip install diffusers")

import graph_net

def main():
# Load model
{self._indent(load_code, 4)}

# Prepare inputs
{self._indent(input_code, 4)}

# Extract graph
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

sample = inputs["sample"].to(device)
timestep = inputs["timestep"].to(device)
encoder_hidden_states = inputs["encoder_hidden_states"].to(device)

wrapped = graph_net.torch.extract(name="{short_name}", dynamic=False)(model).eval()

with torch.no_grad():
wrapped(sample, timestep, encoder_hidden_states)

if __name__ == "__main__":
main()
"""
Expand All @@ -110,14 +155,47 @@ def main():
def _generate_model_loader(
self, model_dir: Path, model_metadata: ModelMetadata
) -> str:
"""Generate model loading code — config only, random weights"""
"""Generate model loading code based on architecture type."""
model_path = str(model_dir).replace("\\", "/")

return (
f"from transformers import AutoConfig\n"
f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n'
f"model = AutoModel.from_config(_config)"
)
arch = model_metadata.architecture_type

if arch == "seq2seq":
return (
f"from transformers import AutoConfig, AutoModelForSeq2SeqLM\n"
f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n'
f"model = AutoModelForSeq2SeqLM.from_config(_config)"
)
elif arch == "diffusion":
return (
f"from diffusers import UNet2DConditionModel\n"
f'_config = UNet2DConditionModel.load_config("{model_path}")\n'
f"model = UNet2DConditionModel.from_config(_config)"
)
else:
# text, moe, vision, multimodal, audio, None → AutoModel
# If model_type is not present in config.json (e.g. prajjwal1/bert-tiny),
# inject the inferred model_type so AutoConfig can resolve the class.
model_type = model_metadata.model_type
if model_type:
return (
f"import json as _json, os as _os, tempfile as _tmp\n"
f"from transformers import AutoConfig, AutoModel\n"
f'_raw = _json.load(open(_os.path.join("{model_path}", "config.json")))\n'
f'if "model_type" not in _raw:\n'
f' _raw["model_type"] = "{model_type}"\n'
f" _td = _tmp.mkdtemp()\n"
f' _json.dump(_raw, open(_os.path.join(_td, "config.json"), "w"))\n'
f" _config = AutoConfig.from_pretrained(_td, trust_remote_code=True)\n"
f"else:\n"
f' _config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n'
f"model = AutoModel.from_config(_config)"
)
else:
return (
f"from transformers import AutoConfig, AutoModel\n"
f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n'
f"model = AutoModel.from_config(_config)"
)

def _generate_input_code(self, model_metadata: ModelMetadata) -> str:
"""Generate input tensor construction code based on model metadata"""
Expand All @@ -129,7 +207,7 @@ def _generate_input_code(self, model_metadata: ModelMetadata) -> str:
shape_tuple = f"({', '.join(map(str, shape))})"

if dtype == "int64":
if "input_ids" in name.lower():
if "input_ids" in name.lower() or "decoder_input_ids" in name.lower():
safe_vocab_size = self._calculate_safe_vocab_size(model_metadata)
lines.append(
f'inputs["{name}"] = torch.randint(0, {safe_vocab_size}, {shape_tuple}, dtype={torch_dtype})'
Expand Down
26 changes: 26 additions & 0 deletions graph_net/agent/graph_net_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def extract_sample(self, model_id: str) -> ExtractionStatus:
self.logger.info(f"Starting extraction for model: {model_id}")

model_dir = self._fetch_model(model_id)
model_dir = self._resolve_model_dir(model_dir)
model_metadata = self._analyze_model(model_dir)
script_path = self._generate_script(model_dir, model_metadata, model_id)

Expand Down Expand Up @@ -199,6 +200,31 @@ def _fetch_model(self, model_id: str) -> Path:
self.logger.info(f"Model downloaded to: {model_dir}")
return model_dir

def _resolve_model_dir(self, model_dir: Path) -> Path:
"""
For diffusers pipeline repos (identified by model_index.json at root),
resolve to the UNet subdirectory which contains the actual UNet config.
Returns model_dir unchanged for non-pipeline repos.
"""
model_index = model_dir / "model_index.json"
if not model_index.exists():
return model_dir

# It's a diffusers pipeline — find the unet subdirectory
unet_dir = model_dir / "unet"
if unet_dir.is_dir() and (unet_dir / "config.json").exists():
self.logger.info(
f"Detected diffusers pipeline; using UNet subdir: {unet_dir}"
)
return unet_dir

# Pipeline without unet/ (e.g., image-to-image or non-SD pipeline)
self.logger.warning(
f"Diffusers pipeline detected but no unet/ subdir found in {model_dir}; "
"proceeding with root dir."
)
return model_dir

def _analyze_model(self, model_dir: Path):
"""Analyze model configuration to extract metadata"""
self.logger.info("Analyzing model configuration")
Expand Down
Loading
Loading