Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions backend/app/api/v1/endpoints/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ def list_pipelines(request: Request):
def create_pipeline(request: Request, payload: PipelineIn):
try:
logger.info(f"Request: {request.method} {request.url.path}, Pipeline name: {payload.name}")
pipeline = _PIPELINE_REGISTRY.create_pipeline(payload)
pipeline_in_data = payload.model_dump()

operators = pipeline_in_data.get("config", {}).get("operators", [])
for op in operators:
op["params"] = _PIPELINE_REGISTRY.parse_frontend_params(op.get("params", []))

pipeline = _PIPELINE_REGISTRY.create_pipeline(pipeline_in_data)
return created(pipeline)
except ValueError as e:
logger.error(f"Invalid pipeline configuration: {str(e)}", exc_info=True)
Expand All @@ -52,7 +58,13 @@ def get_pipeline(pipeline_id: str):
@router.put("/{pipeline_id}", response_model=ApiResponse[PipelineOut], operation_id="update_pipeline", summary="更新指定的Pipeline")
def update_pipeline(pipeline_id: str, payload: PipelineIn):
try:
updated_pipeline = _PIPELINE_REGISTRY.update_pipeline(pipeline_id, payload)
pipeline_in_data = payload.model_dump()

operators = pipeline_in_data.get("config", {}).get("operators", [])
for op in operators:
op["params"] = _PIPELINE_REGISTRY.parse_frontend_params(op.get("params", []))

updated_pipeline = _PIPELINE_REGISTRY.update_pipeline(pipeline_id, pipeline_in_data)
return ok(updated_pipeline)
except ValueError as e:
logger.error(f"Failed to update pipeline: {str(e)}")
Expand Down Expand Up @@ -80,6 +92,10 @@ async def execute_pipeline(request: Request, payload: PipelineExecutionRequest,
try:
logger.info(f"Request: {request.method} {request.url.path}")

if payload.config:
for op in payload.config.operators:
op.params = _PIPELINE_REGISTRY.parse_frontend_params(op.params)

# 调用服务层开始执行
execution_id, pipeline_config, initial_result = _PIPELINE_REGISTRY.start_execution(
pipeline_id=payload.pipeline_id,
Expand Down
2 changes: 1 addition & 1 deletion backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class Settings(BaseSettings):
ENV: str = "dev"
DATA_REGISTRY: str = "data/registry.yaml"
TASK_REGISTRY: str = "data/task_registry.yaml"
PIPELINE_REGISTRY: str = "data/pipeline_registry.json"
PIPELINE_REGISTRY: str = "data/pipeline_registry.yaml"
SERVING_REGISTRY: str = "data/serving_registry.yaml"
DataFlow_CORE_DIR: str = "data/dataflow_core"
OPS_JSON_PATH: str = "resources/ops.json"
Expand Down
48 changes: 24 additions & 24 deletions backend/app/schemas/pipelines.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from enum import Enum
from typing import List, Dict, Any, Optional, Union
from pydantic import BaseModel, Field, field_validator

from app.schemas.operator import OperatorDetailSchema
from dataflow.utils.storage import FileStorage

class Pipeline(str, Enum):
"""Pipeline类型枚举"""
Expand Down Expand Up @@ -32,33 +33,32 @@ class ExecutionStatus(str, Enum):
failed = "failed"


class PipelineOperator(BaseModel):
class PipelineOperator(BaseModel): # 画布上的pipeline类
"""Pipeline算子模型"""
name: str = Field(..., description="算子名称")
params: Dict[str, Any] = Field(default_factory=dict, description="算子参数配置")

@field_validator('name')
def validate_operator_name(cls, v: str) -> str:
"""验证算子名称格式"""
if not v.replace('_', '').isalnum():
raise ValueError('Operator name can only contain letters, numbers and underscores')
# 后续可以补充从可用算子集中验证算子名称是否存在
return v

location: tuple[int, int] = Field(default=(0, 0), description="算子在画布上的位置, 包含x和y两个坐标值")
# @field_validator('name')
# def validate_operator_name(cls, v: str) -> str:
# """验证算子名称格式"""
# if not v.replace('_', '').isalnum():
# raise ValueError('Operator name can only contain letters, numbers and underscores')
# # 后续可以补充从可用算子集中验证算子名称是否存在
# return v

class PipelineConfig(BaseModel):
"""Pipeline配置模型"""
file_path: str = Field(..., description="Pipeline文件路径")
input_dataset: str = Field(..., description="输入数据集ID")
# 用 list 的顺序代表算子执行顺序
operators: List[PipelineOperator] = Field(default_factory=list, description="算子执行序列")
run_config: Dict[str, Any] = Field(default_factory=dict, description="运行时配置参数")

@field_validator('operators')
def validate_operators(cls, v: List[PipelineOperator]) -> List[PipelineOperator]:
"""确保至少有一个算子"""
if not v:
raise ValueError('Pipeline must have at least one operator')
return v
# @field_validator('operators')
# def validate_operators(cls, v: List[PipelineOperator]) -> List[PipelineOperator]:
# """确保至少有一个算子"""
# if not v:
# raise ValueError('Pipeline must have at least one operator')
# return v


class PipelineIn(BaseModel):
Expand Down Expand Up @@ -87,12 +87,12 @@ class PipelineExecutionRequest(BaseModel):
pipeline_id: Optional[str] = Field(None, description="预定义Pipeline ID")
config: Optional[PipelineConfig] = Field(None, description="自定义Pipeline配置")

@field_validator('pipeline_id', 'config')
def validate_at_least_one(cls, v, info):
"""确保至少提供pipeline_id或config之一"""
if info.data.get('pipeline_id') is None and info.data.get('config') is None:
raise ValueError('Either pipeline_id or config must be provided')
return v
# @field_validator('pipeline_id', 'config')
# def validate_at_least_one(cls, v, info):
# """确保至少提供pipeline_id或config之一"""
# if info.data.get('pipeline_id') is None and info.data.get('config') is None:
# raise ValueError('Either pipeline_id or config must be provided')
# return v


class PipelineExecutionResult(BaseModel):
Expand Down
Loading