Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions backend/app/api/v1/endpoints/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List, Dict, Any
from loguru import logger as log

# --- 1. 导入所有需要的 Schema 和响应封装 ---
# --- 1. Import required Schemas and response wrappers ---
from app.schemas.operator import (
OperatorSchema,
OperatorDetailSchema,
Expand All @@ -14,7 +14,7 @@
from app.api.v1.resp import ok
from app.api.v1.envelope import ApiResponse

# --- 2. 导入服务层 ---
# --- 2. Import service layer ---
from app.services.operator_registry import OPS_JSON_PATH
from app.core.container import container

Expand Down Expand Up @@ -79,7 +79,7 @@ def get_operator_detail_by_name(op_name: str, lang: str = "zh"):
- Then match name in all buckets and return.
"""
try:
# 确保缓存存在
# Ensure cache exists
ops_json_path = OPS_JSON_PATH.with_suffix(f'.{lang}.json')
if not ops_json_path.exists():
log.info("ops.json cache file not found, triggering automatic operator scan and generation...")
Expand All @@ -88,7 +88,7 @@ def get_operator_detail_by_name(op_name: str, lang: str = "zh"):
with open(ops_json_path, "r", encoding="utf-8") as f:
ops_data = json.load(f)

# 在所有 bucket 中查找指定算子
# Look up the operator in all buckets
for bucket_name, items in ops_data.items():
if not isinstance(items, list):
continue
Expand All @@ -98,7 +98,7 @@ def get_operator_detail_by_name(op_name: str, lang: str = "zh"):
if op.get("name") == op_name:
return ok(op)

# 未找到
# Not found
raise HTTPException(status_code=404, detail=f"Operator '{op_name}' not found")

except json.JSONDecodeError as e:
Expand Down
108 changes: 52 additions & 56 deletions backend/app/services/operator_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@
from dataflow.utils.registry import OPERATOR_REGISTRY, PROMPT_REGISTRY
from app.core.config import settings

# --- 1. 路径定义 ---
# __file__ : .../backend/app/services/operator_registry.py
# .parent.parent.parent : .../backend
# --- 1. Path definitions ---
# __file__ is: .../backend/app/services/operator_registry.py
# .parent.parent.parent is: .../backend
BACKEND_DIR = Path(__file__).parent.parent.parent
OPS_JSON_PATH = BACKEND_DIR / settings.OPS_JSON_PATH
RESOURCE_DIR = OPS_JSON_PATH.parent


# --- 2. 私有辅助函数 (模块内部实现) ---
# --- 2. Private helper functions (module-internal) ---

def _safe_json_val(val: Any) -> Any:
"""
inspect.Parameter.empty 和其他非序列化值转为 JSON 安全的值
Convert inspect.Parameter.empty and other non-serializable values to JSON-safe values.
"""
if val is inspect.Parameter.empty:
return None # JSON 中,"无默认值" 用 null 表示
return None # In JSON, "no default" is represented as null

if isinstance(val, type) or callable(val):
return str(val)
Expand All @@ -33,19 +33,19 @@ def _safe_json_val(val: Any) -> Any:
return str(val)

def _param_to_dict(p: inspect.Parameter) -> Dict[str, Any]:
    """Describe an ``inspect.Parameter`` as a plain, JSON-serializable dict.

    The result has three keys: ``name`` (the parameter name),
    ``default_value`` (the default passed through ``_safe_json_val``) and
    ``kind`` (the name of the parameter-kind enum member).
    """
    default_value = _safe_json_val(p.default)
    # e.g. POSITIONAL_OR_KEYWORD / VAR_POSITIONAL / KEYWORD_ONLY ...
    kind_name = p.kind.name
    return dict(name=p.name, default_value=default_value, kind=kind_name)

def _get_method_params(
method: Any, skip_first_self: bool = False
) -> List[Dict[str, Any]]:
"""
提取方法形参,转换为列表。
skip_first_self=True 时会丢掉第一个 'self' 参数。
Extract method parameters and convert to a list.
When skip_first_self=True, the first 'self' parameter is dropped.
"""
try:
sig = inspect.signature(method)
Expand All @@ -56,68 +56,66 @@ def _get_method_params(
except Exception as e:
log.warning(f"Error getting method {method} parameters: {e}")
return []



def _call_get_desc_static(cls: type, lang: str = "zh") -> str | None:
"""
仅当类的 get_desc 被显式声明为 @staticmethod 时才调用。
如果 get_desc 返回一个列表,则自动用换行符拼接成字符串。
Only call when the class's get_desc is explicitly declared as @staticmethod.
If get_desc returns a list, join it with newlines into a string.
"""
func_obj = cls.__dict__.get("get_desc")
if not isinstance(func_obj, staticmethod):
return "N/A ( staticmethod)"
return "N/A (not staticmethod)"

fn = func_obj.__func__
params = list(inspect.signature(fn).parameters)
try:
# --- 变更开始 ---
result: Any = None
if params == ["lang"]:
result = fn(lang)
elif params == ["self", "lang"]:
result = fn(None, lang)
else:
# 签名不匹配
return "N/A (签名不匹配)"
return "N/A (signature mismatch)"

# 核心修复:检查返回类型
if isinstance(result, list):
return "\n".join(str(item) for item in result) # 将列表拼接为字符串
return "\n".join(str(item) for item in result)
elif result:
return str(result) # 确保返回的是字符串
# --- 变更结束 ---
return str(result)

except Exception as e:
log.warning(f"Failed to call {cls.__name__}.get_desc: {e}")

return "N/A (Call failed)"


def _gather_single_operator(
op_name: str, cls: type, node_index: int, lang: str = "zh"
) -> Tuple[str, Dict[str, Any]]:
"""
收集单个算子的全部详细信息,用于生成缓存。
Gather full details for a single operator, used for cache generation.
"""
# 1) 分类(大类 category,用于 ops.json 的顶层 key)
# 1) Category (top-level key for ops.json)
category = "unknown"
if hasattr(cls, "__module__"):
parts = cls.__module__.split(".")
if len(parts) >= 3 and parts[0] == "dataflow" and parts[1] == "operators":
category = parts[2]

# 2) 描述 (使用 staticmethod 逻辑)
# 2) Description (using staticmethod logic)
description = _call_get_desc_static(cls, lang=lang) or ""

# 3) 简化信息里也有的 type(三级分类)和 allowed_prompts
# 3) Type (three-level classification) and allowed_prompts
op_type_category = OPERATOR_REGISTRY.get_type_of_objects().get(op_name, "Unknown/Unknown")
# 新格式: ['dataflow', 'operators', 'core_text', 'generate', 'text2qa_generator']
# [0]=dataflow前缀, [1]=operators, [2]=level_1大类, [3]=level_2类型, [4]=具体名称
type1 = op_type_category[2] if len(op_type_category) > 2 else "Unknown" # level_1: 算子大类(如 core_text)
type2 = op_type_category[3] if len(op_type_category) > 3 else "Unknown" # level_2: 算子类型(如 generate/filter/eval)
# Format: ['dataflow', 'operators', 'core_text', 'generate', 'text2qa_generator']
# [0]=dataflow prefix, [1]=operators, [2]=level_1 category, [3]=level_2 type, [4]=concrete name
type1 = op_type_category[2] if len(op_type_category) > 2 else "Unknown"
type2 = op_type_category[3] if len(op_type_category) > 3 else "Unknown"

allowed_prompt_templates = getattr(cls, "ALLOWED_PROMPTS", [])
allowed_prompt_templates = [prompt_name.__name__ for prompt_name in allowed_prompt_templates]

# 4) command 形参
# 4) Command parameters
init_params = _get_method_params(getattr(cls, "__init__", None), skip_first_self=True)
run_params = _get_method_params(getattr(cls, "run", None), skip_first_self=True)

Expand All @@ -141,12 +139,12 @@ def _gather_single_operator(
return category, info


# --- 3. 公共服务类 ---
# --- 3. Public service class ---

class OperatorRegistry:
"""
封装所有算子(Operator)相关的业务逻辑,
包括加载、实时查询和生成缓存。
Encapsulates all operator-related business logic,
including loading, live queries, and cache generation.
"""
def __init__(self):
self._op_registry = OPERATOR_REGISTRY
Expand All @@ -164,19 +162,18 @@ def __init__(self):


def get_op_list(self, lang: str = "zh") -> list[dict]:
"""获取简化的算子列表 (实时计算),用于前端列表展示。"""
"""Get simplified operator list (computed on demand) for frontend listing."""

op_list: list[dict] = []
for op_name, op_cls in self.op_obj_map.items():
# 类型信息,三级分类
# Type info, three-level classification
op_type_category = self.op_to_type.get(op_name, "Unknown/Unknown")

# 新格式: ['dataflow', 'operators', 'core_text', 'generate', 'text2qa_generator']
# [0]=dataflow前缀, [1]=operators, [2]=level_1大类, [3]=level_2类型, [4]=具体名称
type1 = op_type_category[2] if len(op_type_category) > 2 else "Unknown" # level_1: 算子大类(如 core_text)
type2 = op_type_category[3] if len(op_type_category) > 3 else "Unknown" # level_2: 算子类型(如 generate/filter/eval)
# Format: ['dataflow', 'operators', 'core_text', 'generate', 'text2qa_generator']
type1 = op_type_category[2] if len(op_type_category) > 2 else "Unknown"
type2 = op_type_category[3] if len(op_type_category) > 3 else "Unknown"

# 描述
# Description
if hasattr(op_cls, "get_desc") and callable(op_cls.get_desc):
desc = op_cls.get_desc(lang=lang)
else:
Expand All @@ -187,7 +184,7 @@ def get_op_list(self, lang: str = "zh") -> list[dict]:
allowed_prompt_templates = getattr(op_cls, "ALLOWED_PROMPTS", [])
allowed_prompt_templates = [prompt_name.__name__ for prompt_name in allowed_prompt_templates]

# get parameter info in .run()(这里只保留简要信息,不展开参数细节)
# Parameter info from .run() (brief only, no full param details)
op_info = {
"name": op_name,
"type": {
Expand All @@ -204,18 +201,18 @@ def get_op_list(self, lang: str = "zh") -> list[dict]:

def dump_ops_to_json(self, lang: str = "zh") -> Dict[str, List[Dict[str, Any]]]:
"""
执行一次完整的算子扫描,包含详细参数,并写入 ops_{lang}.json 缓存文件。
这是一个耗时操作。
Run a full operator scan with detailed params and write to ops_{lang}.json cache.
This is a heavy operation.
"""
log.info(f"开始扫描算子 (dump_ops_to_json),生成 {OPS_JSON_PATH.with_suffix(f'.{lang}.json')} ...")
log.info(f"Scanning operators (dump_ops_to_json), generating {OPS_JSON_PATH.with_suffix(f'.{lang}.json')} ...")

all_ops: Dict[str, List[Dict[str, Any]]] = {}
default_bucket: List[Dict[str, Any]] = []

idx = 1
# 使用 __init__ 中加载好的 op_obj_map
# Use op_obj_map loaded in __init__
for op_name, cls in self.op_obj_map.items():
# 调用私有辅助函数 _gather_single_operator
# Call private helper _gather_single_operator
category, info = _gather_single_operator(op_name, cls, idx, lang=lang)

all_ops.setdefault(category, []).append(info)
Expand All @@ -224,25 +221,24 @@ def dump_ops_to_json(self, lang: str = "zh") -> Dict[str, List[Dict[str, Any]]]:

all_ops["Default"] = default_bucket

# 确保目录存在
# Ensure directory exists
RESOURCE_DIR.mkdir(parents=True, exist_ok=True)
try:
with open(OPS_JSON_PATH.with_suffix(f'.{lang}.json'), "w", encoding="utf-8") as f:
json.dump(all_ops, f, ensure_ascii=False, indent=2)
log.info(f"算子信息已成功写入 {OPS_JSON_PATH.with_suffix(f'.{lang}.json')} ({len(default_bucket)} )")
log.info(f"Operator info written to {OPS_JSON_PATH.with_suffix(f'.{lang}.json')} ({len(default_bucket)} operators)")
except Exception as e:
log.error(f"写入 {OPS_JSON_PATH.with_suffix(f'.{lang}.json')} 失败: {e}")
raise # 抛出异常,让 API 层捕获
log.error(f"Failed to write {OPS_JSON_PATH.with_suffix(f'.{lang}.json')}: {e}")
raise

return all_ops

def get_op_details(self, op_name: str, lang: str = "zh") -> Optional[Dict[str, Any]]:
    """Return the detailed info dict for one operator, or ``None`` if it is not registered.

    The operator class is looked up in ``self.op_obj_map`` and the details are
    built by the module-level ``_gather_single_operator`` helper.
    """
    op_cls = self.op_obj_map.get(op_name)
    if op_cls:
        # The helper requires a node_index; -1 is passed for this
        # single-operator lookup. The returned category is not used here.
        _category, details = _gather_single_operator(op_name, op_cls, -1, lang=lang)
        return details
    return None