
fix: tool workflow bugs#4962

Merged
shaohuzhang1 merged 1 commit into v2 from pr@v2@fix_tool_workflow
Mar 27, 2026

Conversation

@shaohuzhang1 (Contributor)

fix: tool workflow bugs


f2c-ci-robot bot commented Mar 27, 2026

Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected. Please follow our release note process to remove it.
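For reference, Prow's release-note convention is a fenced block in the PR description; a minimal example (the note text below is illustrative, not taken from this PR):

````markdown
```release-note
Fix tool workflow bugs in the v2 branch.
```
````

Adding such a block and re-running the bot removes the do-not-merge label.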


Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes-sigs/prow repository.


f2c-ci-robot bot commented Mar 27, 2026

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment


def get_source_type(self):
    return "TOOL"

Contributor Author:

Code Review Summary:

Functionality and Structure:

  • The class ToolProcess inherits from several abstract classes (WorkflowHandler, DataTransferBase, WorkFlowPostHandler) and appears to handle tool-based workflow processes.

Potential Issues:

  1. Incomplete Input Field Handling: The get_input() method is incomplete; it should filter the input fields based on the configurations retrieved from the base_node.
  2. Thread Management: The use of a thread with ThreadPoolExecutor can lead to complexity and potential deadlocks if not properly managed.

Suggestions for Optimization and Improvements:

  1. Complete Input Field Handling:

    def get_input(self):
        """
        获取用户输入并根据定义的字段进行过滤
        @return: 过滤后的用户输入
        """
        input_field_list = self.get_input_field_list()
        return {f['field']: self.params.get(f['field']) for f in input_field_list if f.get('required')}
  2. Simplify Thread Handling (if unnecessary):

    • If you need parallel execution, use concurrent.futures.ThreadPoolExecutor directly rather than inheriting handler-specific classes; this provides flexibility without extra abstraction layers.
  3. Error Handling:

    • Add error handling within methods to manage exceptions gracefully, especially when dealing with database operations or network calls.
  4. Logging:

    • Implement logging at appropriate points to track the flow of data and identify any issues during execution.
  5. Configuration Validation:

    • Ensure that the configuration properties used for nodes are validated early in the process to prevent runtime errors related to missing required inputs.
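Suggestion 2 above, using concurrent.futures directly, can be sketched as follows; `run_tools_in_parallel` and the zero-argument callables are hypothetical stand-ins for the actual tool work, not part of the PR:

```python
from concurrent.futures import ThreadPoolExecutor


def run_tools_in_parallel(tool_fns, max_workers=4):
    # Hypothetical helper: execute each zero-argument callable in a worker
    # thread and collect results in submission order.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(fn) for fn in tool_fns]
        return [future.result() for future in futures]


print(run_tools_in_parallel([lambda: 1, lambda: 2, lambda: 3]))  # [1, 2, 3]
```

Because the `with` block joins the executor on exit, this avoids the dangling-thread and deadlock concerns raised in "Thread Management" above.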

By addressing these aspects, you can make the code more robust, efficient, and easier to maintain. Here’s an example of how the improved version might look:

import time
from concurrent.futures import ThreadPoolExecutor

from django.db import close_old_connections
from .utils import get_language  # Assuming utils module exists and contains get_language function

class ToolProcess(WorkflowHandler, DataTransferBase, WorkFlowPostHandler):
    def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostHandler):
        super().__init__()
        self.flow = flow
        self.params = params
        self.work_flow_post_handler = work_flow_post_handler

    def get_params_serializer_class(self):
        return ToolFlowParamsSerializer

    def run(self):
        self.context['start_time'] = time.time()
        close_old_connections()

        language = get_language()
        if self.params.get('stream'):
            return self.run_stream(self.start_node, None, language)
        return self.run_block(language)

    def stream(self):
        close_old_connections()
        language = get_language()
        # ... rest of the streaming logic ...

    def get_base_node(self):
        return self.flow.get_node('tool-base-node')

    def get_input_field_list(self):
        base_node = self.get_base_node()
        return base_node.properties.get("user_input_field_list", [])

    def get_output_field_list(self):
        base_node = self.get_base_node()
        return base_node.properties.get("user_output_field_list", [])

    def get_input(self):
        '''
        Fetch the user input and filter it down to the configured fields.
        :return: the filtered user input
        '''
        input_field_list = self.get_input_field_list()
        filters = {}
        for field_config in input_field_list:
            field_name = field_config.get('field')
            required = field_config.get('required', False)  # Default value is False
            if field_name and required:
                filters[field_name] = self.params.get(field_name)
        return filters

    # Additional methods...

This refactored version includes a complete implementation of the get_input() method and removes the thread-handling functionality, which was potentially confusing; error handling and logging can be layered on top as outlined above.
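As a standalone illustration of the filtering behaviour in get_input(), the same logic can be extracted as a pure function (the field configs below are made up for the example):

```python
def filter_required_inputs(params, input_field_list):
    # Same filtering logic as get_input(), pulled out as a pure function
    # so it can be exercised without the workflow classes.
    return {
        f['field']: params.get(f['field'])
        for f in input_field_list
        if f.get('required')
    }


field_configs = [
    {'field': 'city', 'required': True},
    {'field': 'note', 'required': False},
]
print(filter_required_inputs({'city': 'Berlin', 'note': 'hi'}, field_configs))
# {'city': 'Berlin'}
```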

@shaohuzhang1 shaohuzhang1 merged commit 4c69884 into v2 Mar 27, 2026
3 checks passed
@shaohuzhang1 shaohuzhang1 deleted the pr@v2@fix_tool_workflow branch March 27, 2026 08:05
tools = get_tools(self, tool_ids, workspace_id)
if tool_ids and len(tool_ids) > 0:  # If tool IDs are present, convert them to MCP
    self.context['tool_ids'] = tool_ids
    for tool_id in tool_ids:
Contributor Author:

There is no syntax error in the snippet above, but there are several areas where it can be optimized:

  1. Repeated Function Calls:
    get_tools is called repeatedly without memoization when retrieving tools by ID and workspace ID, which can increase computation time if the same data is fetched more than once.

  2. Hardcoded Directives:
    You have hardcoded strings like "workflow" in comments, which should ideally come from a localization module. These constants should be translated into the appropriate language to maintain consistency with user-facing messages.

  3. Redundant Imports:
    While not strictly causing errors, unused imports such as the translation utilities (gettext) should be removed if they are not referenced anywhere.

  4. Undefined Variables:
    Some variables are used but never defined (e.g., _, qv). If these are intended as temporaries or placeholders, consider eliminating them.
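The localization point above can be illustrated with Python's stdlib gettext, used here as a simplified stand-in for whatever i18n layer the project actually uses:

```python
import gettext

# With no translation catalogs installed, gettext falls back to the
# original string, so wrapping user-facing text is safe to adopt early.
_ = gettext.gettext

label = _("workflow")
print(label)  # workflow
```

Once real catalogs are installed (e.g. via gettext.translation), the same `_` calls pick up translated strings without further code changes.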

Here's an optimized version of the code with comments explaining changes:

from functools import reduce
from typing import Dict

import uuid_utils.compat as uuid
from django.db.models import QuerySet, OuterRef, Subquery
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from application.flow.tools import Reasoning, mcp_response_generator
from application.models import Application, ApplicationApiKey, ApplicationAccessToken
from common.utils.rsa_util import rsa_long_decrypt
from common.utils.shared_resource_auth import filter_authorized_ids
from common.utils.tool_code import ToolExecutor
from models_provider.models import Model
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
from tools.models import Tool, ToolWorkflowVersion, ToolType
from pydantic import BaseModel, Field, create_model
from langchain_core.tools import StructuredTool  # used by get_tools below


# Remove non-localized strings from comments
def build_schema(fields: dict):
    return ...


def get_workflow_args(tool, qv):
    schema_dict = {"properties": {}, "required": []}
    fields = {f["name"]: f["type"] for f in tool.fields if f["name"] != ''}
    required_fields = {f["name"] for f in tool.fields if f.get("required")}

    def add_field_to_schema(field, type_name):
        schema_dict["properties"][field] = {
            "oneOf": [
                {"type": type_name},
                {"anyOf": [{"type": "string"}, {"type": "null"}]}
            ]
        }
        if field in required_fields:
            schema_dict["required"].append(field)

    ...
    
    
def get_workflow_func(instance, tool, qv, workspace_id):
    # ... rest of the function logic remains the same ...

        
def get_tools(node, tool_ids, workspace_id):
    """
    Fetches and returns tools based on their IDs and the given workspace ID.
    Caches results on the node instead of querying repeatedly.
    """
    cache_key = (
        node.__class__.__name__,
        str(workspace_id),
        tuple(sorted(set(tool_ids)))
    )
    cached_results = getattr(node, '_cached_tools', None)
    if cached_results is None:
        cached_results = node._cached_tools = {}

    if cache_key in cached_results:
        return cached_results[cache_key]

    local_queryset = QuerySet(Tool).filter(
        id__in=tool_ids,
        workspace_id=workspace_id,
        tool_type=ToolType.WORKFLOW
    )

    # Annotate each tool with its latest workflow version
    latest_subquery = ToolWorkflowVersion.objects.filter(
        tool=OuterRef("pk")
    ).order_by("-version").values("id")[:1]

    joined_queryset = local_queryset.annotate(latest=Subquery(latest_subquery))

    results = [StructuredTool(**data) for data in joined_queryset.values()]

    # Populate the cache here for future use
    cached_results[cache_key] = results

    return results
        
def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids):
    ...

Remember to manage the cache carefully, depending on how frequently similar queries are performed.
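One lightweight way to get the memoization described above is functools.lru_cache keyed on hashable arguments; `fetch_tools` and its return value here are hypothetical placeholders for the real database query, not part of the PR:

```python
from functools import lru_cache


@lru_cache(maxsize=128)
def fetch_tools(workspace_id, tool_ids):
    # tool_ids must be hashable (e.g. a sorted tuple), so callers should
    # convert lists before calling. The body is a placeholder for the
    # real database lookup.
    return tuple(f"{workspace_id}:{tool_id}" for tool_id in tool_ids)


tools = fetch_tools("ws-1", tuple(sorted(["t2", "t1"])))
print(tools)  # ('ws-1:t1', 'ws-1:t2')
```

Repeated calls with the same arguments are then served from the cache; `fetch_tools.cache_info()` reports hits and misses, and `fetch_tools.cache_clear()` resets it when the underlying data changes.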


def upload_application_file(self, file):
    application = self.workflow_manage.work_flow_post_handler.chat_info.application
    chat_id = self.workflow_params.get('chat_id')
Contributor Author:

  1. Code Style and Convention: The naming conventions and spacing around operators can be improved for clarity.

  2. Repeated Imports: FileSerializer is imported twice; the duplicate import should be removed.

  3. Function Names and Descriptions:

    • Consider renaming certain functions to better reflect their purpose (e.g., upload_knowledge_file_to_tool, prepare_upload_meta) to clarify their functionality.
  4. Logical Checks:

    • The conditionals checking self.workflow_manage.flow.workflow_mode could benefit from being encapsulated into separate methods to enhance readability and maintainability.
  5. Optimization Suggestions:

    • If both tool_id and chat_id are necessary in each uploaded file metadata, they should only be added if required. Otherwise, you might want to handle this lazily based on the specific context where the files are created.
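Point 5 above, adding tool_id/chat_id only when they are actually needed, can be sketched as follows; `build_file_meta` is a hypothetical helper for illustration, not part of the PR:

```python
def build_file_meta(debug=False, tool_id=None, chat_id=None):
    # Hypothetical helper: include identifiers in the metadata dict only
    # when they are actually set, so unrelated upload paths stay clean.
    meta = {'debug': debug}
    if tool_id is not None:
        meta['tool_id'] = tool_id
    if chat_id is not None:
        meta['chat_id'] = chat_id
    return meta


print(build_file_meta(tool_id='tool-123'))
# {'debug': False, 'tool_id': 'tool-123'}
```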

Here's a revised version of the code with some optimizations and style tweaks:

import logging

from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from common.utils.common import bytes_to_uploaded_file
from knowledge.models import FileSourceType

# Ensure FileSerializer is imported once at the beginning
from oss.serializers.file import FileSerializer


class BaseImageGenerateNode(IImageGenerateNode):
    log = logging.getLogger(__name__)

    def __init__(self, workflow_params=None, workflow_manage=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.workflow_params = workflow_params
        self.workflow_manage = workflow_manage

    @staticmethod
    def upload_base_file(source_type, source_id, file_obj):
        meta = {
            'debug': False,
            # additional metadata that may vary depending on the type
        }
        
        # Assuming a method exists in your codebase to construct file details correctly
        file_details = FileSerializer.serialize(
            obj=file_obj,
            meta=meta,
            source_type=source_type,
            source_id=source_id
        )
        
        return FileSerializer.create(file_details).url
    
    def upload_file(self, file):        
        self.log.info("Uploading image")
        mode = self.workflow_manage.flow.workflow_mode
        
        if mode in (WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP):
            return self.upload_knowledge_file(file)
        
        elif mode in (WorkflowMode.TOOL, WorkflowMode.TOOL_LOOP):
            return self.upload_tool_file(file)
        
        else:
            raise ValueError(f"Unsupported workflow mode: {mode}")
    
    def upload_knowledge_file(self, file):
        file_url = self.upload_base_file(FileSourceType.KNOWLEDGE.value, None, file)
        return file_url

    def upload_tool_file(self, file):
        try:
            tool_id = self.workflow_params['tool_id']
            file_url = self.upload_base_file(FileSourceType.TOOL.value, tool_id, file)
            return file_url
        except KeyError:
            raise LookupError("Tool ID not found in workflow parameters")

Summary Changes:

  • Removed duplicate import lines.
  • Renamed classes and variables for better descriptive names.
  • Encapsulated logic related to uploading base files within a static method.
  • Added error handling for missing tool_id.

Feel free to adjust the implementation of private methods like serialize and create as needed for your project's requirements!
