diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 3dcd6ee0..faf380eb 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -1,18 +1,37 @@ -name: Run Tests on Production Environment +name: Run Tests on: push: branches: - main - pull_request: - branches: - - main + # any branch other than main, will use the test key + - test workflow_dispatch: jobs: - test: + setup-and-test: runs-on: ubuntu-latest - continue-on-error: true + strategy: + fail-fast: false + matrix: + test-suite: [ + 'tests/unit', + 'tests/functional/file_asset', + 'tests/functional/data_asset', + 'tests/functional/benchmark', + 'tests/functional/model', + 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v1 --sdk_version_param PipelineFactory', + 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v2 --sdk_version_param PipelineFactory', + 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v1 --sdk_version_param PipelineFactory', + 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v2 --sdk_version_param PipelineFactory', + 'tests/functional/pipelines/designer_test.py', + 'tests/functional/pipelines/create_test.py', + 'tests/functional/finetune --sdk_version v1 --sdk_version_param FinetuneFactory', + 'tests/functional/finetune --sdk_version v2 --sdk_version_param FinetuneFactory', + 'tests/functional/general_assets', + 'tests/functional/apikey', + 'tests/functional/agent tests/functional/team_agent', + ] steps: - name: Checkout repository uses: actions/checkout@v4 @@ -21,54 +40,26 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.8" - + cache: 'pip' + - name: Install dependencies run: | python -m pip install --upgrade pip pip install ".[test]" + - name: Set environment variables run: | - echo "TEAM_API_KEY=${{ secrets.TEAM_API_KEY_PROD }}" >> $GITHUB_ENV - echo "BACKEND_URL=https://platform-api.aixplain.com" >> $GITHUB_ENV - echo "MODELS_RUN_URL=https://models.aixplain.com/api/v1/execute" >> $GITHUB_ENV - echo "PIPELINES_RUN_URL=https://platform-api.aixplain.com/assets/pipeline/execution/run" >> $GITHUB_ENV - - - name: Run Unit Tests - continue-on-error: true - run: python -m pytest tests/unit - - - name: Run General Assets - continue-on-error: true - run: python -m pytest tests/functional/general_assets - - - name: Run File Asset - continue-on-error: true - run: python -m pytest tests/functional/file_asset - - - name: Run Agent - continue-on-error: true - run: python -m pytest tests/functional/agent - - - name: Run Team Agent - continue-on-error: true - run: python -m pytest tests/functional/team_agent - - - name: Run Data - continue-on-error: true - run: python -m pytest tests/functional/data_asset - - - name: Run Benchmark - continue-on-error: true - run: python -m pytest tests/functional/benchmark - - - name: Run Pipelines - continue-on-error: true - run: python -m pytest tests/functional/pipelines - - - name: Run Api Key - continue-on-error: true - run: python -m pytest tests/functional/apikey - - - name: Run Finetuner - continue-on-error: true - run: python -m pytest tests/functional/finetune + if [ "${{ github.ref_name }}" = "refs/heads/main" ]; then + echo "TEAM_API_KEY=${{ secrets.TEAM_API_KEY_PROD }}" >> $GITHUB_ENV + echo "BACKEND_URL=https://platform-api.aixplain.com" >> $GITHUB_ENV + echo "MODELS_RUN_URL=https://models.aixplain.com/api/v1/execute" >> $GITHUB_ENV + echo "PIPELINES_RUN_URL=https://platform-api.aixplain.com/assets/pipeline/execution/run" >> $GITHUB_ENV + else + echo "TEAM_API_KEY=${{ secrets.TEAM_API_KEY }}" >> $GITHUB_ENV + echo "BACKEND_URL=https://test-platform-api.aixplain.com" >> $GITHUB_ENV + echo "MODELS_RUN_URL=https://test-models.aixplain.com/api/v1/execute" >> $GITHUB_ENV + echo "PIPELINES_RUN_URL=https://test-platform-api.aixplain.com/assets/pipeline/execution/run" >> $GITHUB_ENV + fi + + - name: Run Tests + run: python -m pytest ${{ matrix.test-suite}} \ No newline at end of file diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml deleted file mode 100644 index 2e5b8614..00000000 --- a/.github/workflows/test.yaml +++ /dev/null @@ -1,75 +0,0 @@ -name: Run Tests on Test Environment - -on: - push: - branches: - - test - pull_request: - branches: - - test - workflow_dispatch: - -jobs: - test: - runs-on: ubuntu-latest - continue-on-error: true - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install ".[test]" - - - name: Set environment variables - run: | - echo "TEAM_API_KEY=${{ secrets.TEAM_API_KEY }}" >> $GITHUB_ENV - echo "BACKEND_URL=https://test-platform-api.aixplain.com" >> $GITHUB_ENV - echo "MODELS_RUN_URL=https://test-models.aixplain.com/api/v1/execute" >> $GITHUB_ENV - echo "PIPELINES_RUN_URL=https://test-platform-api.aixplain.com/assets/pipeline/execution/run" >> $GITHUB_ENV - - - name: Run Unit Tests - continue-on-error: true - run: python -m pytest tests/unit - - - name: Run General Assets - continue-on-error: true - run: python -m pytest tests/functional/general_assets - - - name: Run File Asset - continue-on-error: true - run: python -m pytest tests/functional/file_asset - - - name: Run Agent - continue-on-error: true - run: python -m pytest tests/functional/agent - - - name: Run Team Agent - continue-on-error: true - run: python -m pytest tests/functional/team_agent - - - name: Run Data - continue-on-error: true - run: python -m pytest tests/functional/data_asset - - - name: Run Benchmark - continue-on-error: true - run: python -m pytest tests/functional/benchmark - - - name: Run Pipelines - continue-on-error: true - run: python -m pytest tests/functional/pipelines - - - name: Run Api Key - continue-on-error: true - run: python -m pytest tests/functional/apikey - - - name: Run Finetuner - continue-on-error: true - run: python -m pytest tests/functional/finetune \ No newline at end of file diff --git a/aixplain/enums/embedding_model.py b/aixplain/enums/embedding_model.py index 8769e3dd..dbd6b9c1 100644 --- a/aixplain/enums/embedding_model.py +++ b/aixplain/enums/embedding_model.py @@ -25,6 +25,11 @@ class EmbeddingModel(Enum): OPENAI_ADA002 = "6734c55df127847059324d9e" SNOWFLAKE_ARCTIC_EMBED_L_V2_0 = "678a4f8547f687504744960a" JINA_CLIP_V2_MULTIMODAL = "67c5f705d8f6a65d6f74d732" + MULTILINGUAL_E5_LARGE = "67efd0772a0a850afa045af3" + BGE_M3 = "67f401032a0a850afa045b19" + + + def __str__(self): return self._value_ diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 9532bf72..bad94fb3 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -81,9 +81,9 @@ def create( Agent: created Agent """ warnings.warn( - "The 'instructions' parameter was recently added and serves the same purpose as 'description' did previously: set the role of the agent as a system prompt. " - "The 'description' parameter is still required and should be used to set a short summary of the agent's purpose. " - "For the next releases, the 'instructions' parameter will be required.", + "Use `instructions` to define the **system prompt**. " + "Use `description` to provide a **short summary** of the agent for metadata and dashboard display. " + "Note: In upcoming releases, `instructions` will become a required parameter.", UserWarning, ) from aixplain.factories.agent_factory.utils import build_agent diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index d64ab773..6dae6098 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -37,8 +37,15 @@ def build_tool(tool: Dict): ]: supplier = supplier_ break + assert "function" in tool, "Function is required for model tools" + function_name = tool.get("function") + try: + function = Function(function_name) + except ValueError: + valid_functions = [func.value for func in Function] + raise ValueError(f"Function {function_name} is not a valid function. The valid functions are: {valid_functions}") tool = ModelTool( - function=Function(tool.get("function", None)), + function=function, supplier=supplier, version=tool["version"], model=tool["assetId"], @@ -63,7 +70,7 @@ def build_tool(tool: Dict): description=tool["description"], database=database, schema=schema, tables=tables, enable_commit=enable_commit ) else: - raise Exception("Agent Creation Error: Tool type not supported.") + raise ValueError("Agent Creation Error: Tool type not supported.") return tool @@ -77,6 +84,9 @@ def build_agent(payload: Dict, tools: List[Tool] = None, api_key: Text = config. for tool in tools_dict: try: payload_tools.append(build_tool(tool)) + except (ValueError, AssertionError) as e: + logging.warning(str(e)) + continue except Exception: logging.warning( f"Tool {tool['assetId']} is not available. Make sure it exists or you have access to it. " diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 468a0eb9..c5a90367 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -4,7 +4,7 @@ from aixplain.modules.model.llm_model import LLM from aixplain.modules.model.index_model import IndexModel from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput -from aixplain.enums import DataType, Function, Language, OwnershipType, Supplier, SortBy, SortOrder +from aixplain.enums import DataType, Function, Language, OwnershipType, Supplier, SortBy, SortOrder, AssetStatus from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry from datetime import datetime @@ -80,6 +80,7 @@ def create_model_from_response(response: Dict) -> Model: version=response["version"]["id"], inputs=inputs, temperature=temperature, + status=response.get("status", AssetStatus.DRAFT), ) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index c0209313..ecba14d9 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -125,17 +125,13 @@ def _validate(self) -> None: except Exception: raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.") - assert ( - llm.function == Function.TEXT_GENERATION - ), "Large Language Model must be a text generation model." + assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." for tool in self.tools: if isinstance(tool, Tool): tool.validate() elif isinstance(tool, Model): - assert not isinstance( - tool, Agent - ), "Agent cannot contain another Agent." + assert not isinstance(tool, Agent), "Agent cannot contain another Agent." def validate(self, raise_exception: bool = False) -> bool: """Validate the Agent.""" @@ -148,9 +144,7 @@ def validate(self, raise_exception: bool = False) -> bool: raise e else: logging.warning(f"Agent Validation Error: {e}") - logging.warning( - "You won't be able to run the Agent until the issues are handled manually." - ) + logging.warning("You won't be able to run the Agent until the issues are handled manually.") return self.is_valid def run( @@ -207,10 +201,8 @@ def run( return response poll_url = response["url"] end = time.time() - result = self.sync_poll( - poll_url, name=name, timeout=timeout, wait_time=wait_time - ) - result_data = result.data + result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + result_data = result.get("data") or {} return AgentResponse( status=ResponseStatus.SUCCESS, completed=True, @@ -272,18 +264,12 @@ def run_async( from aixplain.factories.file_factory import FileFactory if not self.is_valid: - raise Exception( - "Agent is not valid. Please validate the agent before running." - ) + raise Exception("Agent is not valid. Please validate the agent before running.") - assert ( - data is not None or query is not None - ), "Either 'data' or 'query' must be provided." + assert data is not None or query is not None, "Either 'data' or 'query' must be provided." if data is not None: if isinstance(data, dict): - assert ( - "query" in data and data["query"] is not None - ), "When providing a dictionary, 'query' must be provided." + assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided." query = data.get("query") if session_id is None: session_id = data.get("session_id") @@ -296,9 +282,7 @@ def run_async( # process content inputs if content is not None: - assert ( - FileFactory.check_storage_type(query) == StorageType.TEXT - ), "When providing 'content', query must be text." + assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text." if isinstance(content, list): assert len(content) <= 3, "The maximum number of content inputs is 3." @@ -307,9 +291,7 @@ def run_async( query += f"\n{input_link}" elif isinstance(content, dict): for key, value in content.items(): - assert ( - "{{" + key + "}}" in query - ), f"Key '{key}' not found in query." + assert "{{" + key + "}}" in query, f"Key '{key}' not found in query." value = FileFactory.to_link(value) query = query.replace("{{" + key + "}}", f"'{value}'") @@ -324,16 +306,8 @@ def run_async( "sessionId": session_id, "history": history, "executionParams": { - "maxTokens": ( - parameters["max_tokens"] - if "max_tokens" in parameters - else max_tokens - ), - "maxIterations": ( - parameters["max_iterations"] - if "max_iterations" in parameters - else max_iterations - ), + "maxTokens": (parameters["max_tokens"] if "max_tokens" in parameters else max_tokens), + "maxIterations": (parameters["max_iterations"] if "max_iterations" in parameters else max_iterations), "outputFormat": output_format.value, }, } @@ -367,11 +341,7 @@ def to_dict(self) -> Dict: "assets": [tool.to_dict() for tool in self.tools], "description": self.description, "role": self.instructions, - "supplier": ( - self.supplier.value["code"] - if isinstance(self.supplier, Supplier) - else self.supplier - ), + "supplier": (self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier), "version": self.version, "llmId": self.llm_id, "status": self.status.value, @@ -409,8 +379,7 @@ def update(self) -> None: stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " - "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " "Please use save() instead.", DeprecationWarning, stacklevel=2, ) @@ -422,9 +391,7 @@ def update(self) -> None: payload = self.to_dict() - logging.debug( - f"Start service for PUT Update Agent - {url} - {headers} - {json.dumps(payload)}" - ) + logging.debug(f"Start service for PUT Update Agent - {url} - {headers} - {json.dumps(payload)}") resp = "No specified error." try: r = _request_with_retry("put", url, headers=headers, json=payload) @@ -443,9 +410,7 @@ def save(self) -> None: self.update() def deploy(self) -> None: - assert ( - self.status == AssetStatus.DRAFT - ), "Agent must be in draft status to be deployed." + assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed." assert self.status != AssetStatus.ONBOARDED, "Agent is already deployed." self.status = AssetStatus.ONBOARDED self.update() diff --git a/aixplain/modules/agent/agent_response_data.py b/aixplain/modules/agent/agent_response_data.py index 6040be0c..0acd8b80 100644 --- a/aixplain/modules/agent/agent_response_data.py +++ b/aixplain/modules/agent/agent_response_data.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Text class AgentResponseData: @@ -54,3 +54,10 @@ def __repr__(self) -> str: f"intermediate_steps={self.intermediate_steps}, " f"execution_stats={self.execution_stats})" ) + + def __contains__(self, key: Text) -> bool: + try: + self[key] + return True + except KeyError: + return False diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index ba16317a..29c68e0f 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -53,44 +53,17 @@ def __init__( function (Optional[Union[Function, Text]]): task that the tool performs. Defaults to None. supplier (Optional[Union[Dict, Supplier]]): Preferred supplier to perform the task. Defaults to None. Defaults to None. model (Optional[Union[Text, Model]]): Model function. Defaults to None. + name (Optional[Text]): Name of the tool. Defaults to None. description (Text): Description of the tool. Defaults to "". parameters (Optional[Dict]): Parameters of the tool. Defaults to None. """ - assert ( - function is not None or model is not None - ), "Agent Creation Error: Either function or model must be provided when instantiating a tool." - name = name or "" super().__init__(name=name, description=description, **additional_info) - if function is not None: - if isinstance(function, str): - function = Function(function) - assert ( - function is None or function is not Function.UTILITIES or model is not None - ), "Agent Creation Error: Utility function must be used with an associated model." - - try: - if isinstance(supplier, dict): - supplier = Supplier(supplier) - except Exception: - supplier = None - - self.model_object = None - if model is not None: - if isinstance(model, Text) is True: - self.model = model - model = self.validate() - self.model_object = model - else: - self.model_object = model - function = model.function - if isinstance(model.supplier, Supplier): - supplier = model.supplier - model = model.id self.supplier = supplier self.model = model self.function = function - self.parameters = self.validate_parameters(parameters) + self.parameters = parameters + self.validate() def to_dict(self) -> Dict: """Converts the tool to a dictionary.""" @@ -110,23 +83,61 @@ def to_dict(self) -> Dict: "description": self.description, "supplier": supplier, "version": self.version if self.version else None, - "assetId": self.model, + "assetId": self.model.id if self.model is not None and isinstance(self.model, Model) else self.model, "parameters": self.parameters, } - def validate(self) -> Model: + def validate(self) -> None: + """ + Validates the tool. + Notes: + - Checks if the tool has a function or model. + - If the function is a string, it converts it to a Function enum. + - Checks if the function is a utility function and if it has an associated model. + - Validates the supplier. + - Validates the model. + - If the description is empty, it sets the description to the function description or the model description. + """ + from aixplain.enums import FunctionInputOutput from aixplain.factories.model_factory import ModelFactory - if self.model_object is not None: - return self.model_object + assert ( + self.function is not None or self.model is not None + ), "Agent Creation Error: Either function or model must be provided when instantiating a tool." + + if self.function is not None: + if isinstance(self.function, str): + self.function = Function(self.function) + assert ( + self.function is None or self.function is not Function.UTILITIES or self.model is not None + ), "Agent Creation Error: Utility function must be used with an associated model." try: - model = None - if self.model is not None: - model = ModelFactory.get(self.model, api_key=self.api_key) - return model + if isinstance(self.supplier, dict): + self.supplier = Supplier(self.supplier) except Exception: - raise Exception(f"Model Tool Unavailable. Make sure Model '{self.model}' exists or you have access to it.") + self.supplier = None + + if self.model is not None: + if isinstance(self.model, Text) is True: + try: + self.model = ModelFactory.get(self.model, api_key=self.api_key) + except Exception: + raise Exception(f"Model Tool Unavailable. Make sure Model '{self.model}' exists or you have access to it.") + self.function = self.model.function + if isinstance(self.model.supplier, Supplier): + self.supplier = self.model.supplier + + if self.description == "": + if self.model is not None: + self.description = self.model.description + elif self.function is not None: + try: + self.description = FunctionInputOutput[self.function.value]["spec"]["metaData"]["description"] + except Exception: + self.description = "" + + self.parameters = self.validate_parameters(self.parameters) def get_parameters(self) -> Dict: return self.parameters @@ -145,8 +156,8 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) """ if received_parameters is None: # Get default parameters if none provided - if self.model_object is not None and self.model_object.model_params is not None: - return self.model_object.model_params.to_list() + if self.model is not None and self.model.model_params is not None: + return self.model.model_params.to_list() elif self.function is not None: function_params = self.function.get_parameters() if function_params is not None: @@ -155,8 +166,8 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) # Get expected parameters expected_params = None - if self.model_object is not None and self.model_object.model_params is not None: - expected_params = self.model_object.model_params + if self.model is not None and self.model.model_params is not None: + expected_params = self.model.model_params elif self.function is not None: expected_params = self.function.get_parameters() diff --git a/aixplain/modules/agent/tool/python_interpreter_tool.py b/aixplain/modules/agent/tool/python_interpreter_tool.py index 2d1daa30..d5947bff 100644 --- a/aixplain/modules/agent/tool/python_interpreter_tool.py +++ b/aixplain/modules/agent/tool/python_interpreter_tool.py @@ -29,11 +29,12 @@ class PythonInterpreterTool(Tool): def __init__(self, **additional_info) -> None: """Python Interpreter Tool""" - super().__init__(name="Python Interpreter", description="", **additional_info) + description = "A Python shell. Use this to execute python commands. Input should be a valid python command." + super().__init__(name="Python Interpreter", description=description, **additional_info) def to_dict(self): return { - "description": "", + "description": self.description, "type": "utility", "utility": "custom_python_code", } diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 6fa7d9a5..96454181 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -180,18 +180,15 @@ def validate(self): """Validate the Utility Model.""" description = None name = None - inputs = [] # check if the model exists and if the code is strring with s3:// # if not, parse the code and update the description and inputs and do the validation # if yes, just do the validation on the description and inputs if not (self._model_exists() and str(self.code).startswith("s3://")): - self.code, inputs, description, name = parse_code_decorated(self.code) + self.code, self.inputs, description, name = parse_code_decorated(self.code) if self.name is None: self.name = name if self.description is None: self.description = description - if len(self.inputs) == 0: - self.inputs = inputs for input in self.inputs: input.validate() else: diff --git a/aixplain/modules/pipeline/designer/base.py b/aixplain/modules/pipeline/designer/base.py index 063accad..5d499a1c 100644 --- a/aixplain/modules/pipeline/designer/base.py +++ b/aixplain/modules/pipeline/designer/base.py @@ -1,4 +1,3 @@ -import re from typing import ( List, Union, @@ -94,8 +93,6 @@ def back_link(self, from_param: "Param") -> "Param": assert from_param.param_type == ParamType.OUTPUT, "Invalid param type" assert self.code in self.node.inputs, "Param not registered as input" link = from_param.node.link(self.node, from_param, self) - self.link_ = link - from_param.link_ = link return link def serialize(self) -> dict: @@ -131,6 +128,7 @@ class Link(Serializable): to_node: "Node" from_param: str to_param: str + data_source_id: Optional[str] = None pipeline: Optional["DesignerPipeline"] = None @@ -140,6 +138,7 @@ def __init__( to_node: "Node", from_param: Union[Param, str], to_param: Union[Param, str], + data_source_id: Optional[str] = None, pipeline: "DesignerPipeline" = None, ): @@ -149,28 +148,28 @@ def __init__( to_param = to_param.code assert from_param in from_node.outputs, ( - "Invalid from param. " "Make sure all input params are already linked accordingly" + "Invalid from param. " + "Make sure all input params are already linked accordingly" ) + assert to_param in to_node.inputs, ( + "Invalid to param. " + "Make sure all output params are already linked accordingly" + ) + + tp_instance = to_node.inputs[to_param] fp_instance = from_node.outputs[from_param] - from .nodes import Decision - - if isinstance(to_node, Decision) and to_param == to_node.inputs.passthrough.code: - if from_param not in to_node.outputs: - to_node.outputs.create_param( - from_param, - fp_instance.data_type, - is_required=fp_instance.is_required, - ) - else: - to_node.outputs[from_param].data_type = fp_instance.data_type assert to_param in to_node.inputs, "Invalid to param" + tp_instance.link_ = self + fp_instance.link_ = self + self.from_node = from_node self.to_node = to_node self.from_param = from_param self.to_param = to_param + self.data_source_id = data_source_id if pipeline: self.attach_to(pipeline) @@ -204,7 +203,9 @@ def validate(self): # Should we check for data type mismatch? if from_param.data_type and to_param.data_type: if from_param.data_type != to_param.data_type: - raise ValueError(f"Data type mismatch between {from_param.data_type} and {to_param.data_type}") # noqa + raise ValueError( + f"Data type mismatch between {from_param.data_type} and {to_param.data_type}" + ) # noqa def attach_to(self, pipeline: "DesignerPipeline"): """ @@ -224,15 +225,17 @@ def attach_to(self, pipeline: "DesignerPipeline"): def serialize(self) -> dict: assert self.from_node.number is not None, "From node number not set" assert self.to_node.number is not None, "To node number not set" + param_mapping = { + "from": self.from_param, + "to": self.to_param, + } + if self.data_source_id: + param_mapping["dataSourceId"] = self.data_source_id + return { "from": self.from_node.number, "to": self.to_node.number, - "paramMapping": [ - { - "from": self.from_param, - "to": self.to_param, - } - ], + "paramMapping": [param_mapping], } @@ -255,7 +258,9 @@ def add_param(self, param: Param) -> None: if not hasattr(self, param.code): setattr(self, param.code, param) - def _create_param(self, code: str, data_type: DataType = None, value: any = None) -> Param: + def _create_param( + self, code: str, data_type: DataType = None, value: any = None + ) -> Param: raise NotImplementedError() def create_param( @@ -299,7 +304,10 @@ def special_prompt_handling(self, code: str, value: str) -> None: if not isinstance(self.node, AssetNode): return - if not hasattr(self.node, "asset") or self.node.asset.function != "text-generation": + if ( + not hasattr(self.node, "asset") + or self.node.asset.function != "text-generation" + ): return matches = find_prompt_params(value) @@ -353,7 +361,9 @@ def _create_param( class Outputs(ParamProxy): - def _create_param(self, code: str, data_type: DataType = None, value: any = None) -> OutputParam: + def _create_param( + self, code: str, data_type: DataType = None, value: any = None + ) -> OutputParam: return OutputParam(code=code, data_type=data_type, value=value) diff --git a/aixplain/modules/pipeline/designer/nodes.py b/aixplain/modules/pipeline/designer/nodes.py index c8558846..10c9d3eb 100644 --- a/aixplain/modules/pipeline/designer/nodes.py +++ b/aixplain/modules/pipeline/designer/nodes.py @@ -378,11 +378,11 @@ def __init__(self, node: Node): class DecisionOutputs(Outputs): - input: OutputParam = None + data: OutputParam = None def __init__(self, node: Node): super().__init__(node) - self.input = self.create_param("input") + self.data = self.create_param("data") class Decision(Node[DecisionInputs, DecisionOutputs], LinkableMixin): @@ -407,7 +407,25 @@ def link( to_param: Union[str, Param], ) -> Link: link = super().link(to_node, from_param, to_param) - self.outputs.input.data_type = self.inputs.passthrough.data_type + + if isinstance(from_param, str): + assert from_param in self.outputs, f"Decision node has no input param called {from_param}, node linking validation is broken, please report this issue." + from_param = self.outputs[from_param] + + if from_param.code == "data": + if not self.inputs.passthrough.link_: + raise ValueError("To able to infer data source, " + "passthrough input param should be linked first.") + + # Infer data source from the passthrough node + link.data_source_id = self.inputs.passthrough.link_.from_node.number + + # Infer data type from the passthrough node + ref_param_code = self.inputs.passthrough.link_.from_param + ref_node = self.inputs.passthrough.link_.from_node + ref_param = ref_node.outputs[ref_param_code] + from_param.data_type = ref_param.data_type + return link def serialize(self) -> dict: diff --git a/aixplain/modules/pipeline/designer/pipeline.py b/aixplain/modules/pipeline/designer/pipeline.py index bf9c74a6..80248909 100644 --- a/aixplain/modules/pipeline/designer/pipeline.py +++ b/aixplain/modules/pipeline/designer/pipeline.py @@ -213,29 +213,7 @@ def auto_infer(self): nodes based on the data types of the connected nodes. """ for link in self.links: - from_node = self.get_node(link.from_node) - to_node = self.get_node(link.to_node) - if not from_node or not to_node: - continue # will be handled by the validation - for param in link.param_mapping: - from_param = from_node.outputs[param.from_param] - to_param = to_node.inputs[param.to_param] - if not from_param or not to_param: - continue # will be handled by the validation - # if one of the data types is missing, infer the other one - dataType = from_param.data_type or to_param.data_type - from_param.data_type = dataType - to_param.data_type = dataType - - def infer_data_type(node): - from .nodes import Input, Output - - if isinstance(node, Input) or isinstance(node, Output): - if dataType and dataType not in node.data_types: - node.data_types.append(dataType) - - infer_data_type(self) - infer_data_type(to_node) + link.auto_infer() def asset(self, asset_id: str, *args, asset_class: Type[T] = AssetNode, **kwargs) -> T: """ diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 87c78159..a913f223 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -38,6 +38,7 @@ from aixplain.modules.model import Model from aixplain.modules.agent import Agent, OutputFormat from aixplain.modules.agent.agent_response import AgentResponse +from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.agent.utils import process_variables from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry @@ -135,7 +136,7 @@ def run( max_tokens: int = 2048, max_iterations: int = 30, output_format: OutputFormat = OutputFormat.TEXT, - ) -> Dict: + ) -> AgentResponse: """Runs a team agent call. Args: @@ -155,6 +156,7 @@ def run( Dict: parsed output from model """ start = time.time() + result_data = {} try: response = self.run_async( data=data, @@ -168,16 +170,27 @@ def run( max_iterations=max_iterations, output_format=output_format, ) - if response["status"] == "FAILED": + if response["status"] == ResponseStatus.FAILED: end = time.time() response["elapsed_time"] = end - start return response poll_url = response["url"] end = time.time() - response = self.sync_poll( - poll_url, name=name, timeout=timeout, wait_time=wait_time + result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + result_data = result.data + return AgentResponse( + status=ResponseStatus.SUCCESS, + completed=True, + data=AgentResponseData( + input=result_data.get("input"), + output=result_data.get("output"), + session_id=result_data.get("session_id"), + intermediate_steps=result_data.get("intermediate_steps"), + execution_stats=result_data.get("executionStats"), + ), + used_credits=result_data.get("usedCredits", 0.0), + run_time=result_data.get("runTime", end - start), ) - return response except Exception as e: logging.error(f"Team Agent Run: Error in running for {name}: {e}") end = time.time() @@ -199,7 +212,7 @@ def run_async( max_tokens: int = 2048, max_iterations: int = 30, output_format: OutputFormat = OutputFormat.TEXT, - ) -> Dict: + ) -> AgentResponse: """Runs asynchronously a Team Agent call. Args: @@ -224,9 +237,7 @@ def run_async( assert data is not None or query is not None, "Either 'data' or 'query' must be provided." if data is not None: if isinstance(data, dict): - assert ( - "query" in data and data["query"] is not None - ), "When providing a dictionary, 'query' must be provided." + assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided." if session_id is None: session_id = data.pop("session_id", None) if history is None: @@ -240,8 +251,7 @@ def run_async( # process content inputs if content is not None: assert ( - isinstance(query, str) - and FileFactory.check_storage_type(query) == StorageType.TEXT + isinstance(query, str) and FileFactory.check_storage_type(query) == StorageType.TEXT ), "When providing 'content', query must be text." if isinstance(content, list): @@ -251,9 +261,7 @@ def run_async( query += f"\n{input_link}" elif isinstance(content, dict): for key, value in content.items(): - assert ( - "{{" + key + "}}" in query - ), f"Key '{key}' not found in query." + assert "{{" + key + "}}" in query, f"Key '{key}' not found in query." value = FileFactory.to_link(value) query = query.replace("{{" + key + "}}", f"'{value}'") @@ -277,9 +285,7 @@ def run_async( payload = json.dumps(payload) r = _request_with_retry("post", self.url, headers=headers, data=payload) - logging.info( - f"Team Agent Run Async: Start service for {name} - {self.url} - {payload} - {headers}" - ) + logging.info(f"Team Agent Run Async: Start service for {name} - {self.url} - {payload} - {headers}") resp = None try: @@ -287,14 +293,20 @@ def run_async( logging.info(f"Result of request for {name} - {r.status_code} - {resp}") poll_url = resp["data"] - response = {"status": "IN_PROGRESS", "url": poll_url} + return AgentResponse( + status=ResponseStatus.IN_PROGRESS, + url=poll_url, + data=AgentResponseData(input=input_data), + run_time=0.0, + used_credits=0.0, + ) except Exception: - response = {"status": "FAILED"} msg = f"Error in request for {name} - {traceback.format_exc()}" logging.error(f"Team Agent Run Async: Error in running for {name}: {resp}") - if resp is not None: - response["error"] = msg - return response + return AgentResponse( + status=ResponseStatus.FAILED, + error=msg, + ) def delete(self) -> None: """Delete Corpus service""" @@ -309,7 +321,9 @@ def delete(self) -> None: if r.status_code != 200: raise Exception() except Exception: - message = f"Team Agent Deletion Error (HTTP {r.status_code}): Make sure the Team Agent exists and you are the owner." + message = ( + f"Team Agent Deletion Error (HTTP {r.status_code}): Make sure the Team Agent exists and you are the owner." + ) logging.error(message) raise Exception(f"{message}") @@ -318,8 +332,7 @@ def to_dict(self) -> Dict: "id": self.id, "name": self.name, "agents": [ - {"assetId": agent.id, "number": idx, "type": "AGENT", "label": "AGENT"} - for idx, agent in enumerate(self.agents) + {"assetId": agent.id, "number": idx, "type": "AGENT", "label": "AGENT"} for idx, agent in enumerate(self.agents) ], "links": [], "description": self.description, @@ -345,9 +358,7 @@ def _validate(self) -> None: try: llm = ModelFactory.get(self.llm_id) - assert ( - llm.function == Function.TEXT_GENERATION - ), "Large Language Model must be a text generation model." + assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." except Exception: raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.") @@ -377,8 +388,7 @@ def update(self) -> None: stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " - "Please use save() instead.", + "update() is deprecated and will be removed in a future version. " "Please use save() instead.", DeprecationWarning, stacklevel=2, ) @@ -396,9 +406,7 @@ def update(self) -> None: r = _request_with_retry("put", url, headers=headers, json=payload) resp = r.json() except Exception: - raise Exception( - "Team Agent Update Error: Please contact the administrators." - ) + raise Exception("Team Agent Update Error: Please contact the administrators.") if 200 <= r.status_code < 300: return build_team_agent(resp) @@ -412,11 +420,7 @@ def save(self) -> None: def deploy(self) -> None: """Deploy the Team Agent.""" - assert ( - self.status == AssetStatus.DRAFT - ), "Team Agent Deployment Error: Team Agent must be in draft status." - assert ( - self.status != AssetStatus.ONBOARDED - ), "Team Agent Deployment Error: Team Agent must be onboarded." + assert self.status == AssetStatus.DRAFT, "Team Agent Deployment Error: Team Agent must be in draft status." + assert self.status != AssetStatus.ONBOARDED, "Team Agent Deployment Error: Team Agent must be onboarded." self.status = AssetStatus.ONBOARDED self.update() diff --git a/pyproject.toml b/pyproject.toml index 971f15d8..ef57b89f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ "click>=7.1.2", "PyYAML>=6.0.1", "dataclasses-json>=0.5.2", - "Jinja2==3.1.4", + "Jinja2==3.1.6", ] [project.urls] diff --git a/tests/conftest.py b/tests/conftest.py index a17177b8..461e8146 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,114 @@ +import pytest +from typing import Any, Callable + from dotenv import load_dotenv # Load environment variables once for all tests load_dotenv(override=True) + +SDK_VERSION_ARG = "--sdk_version" +SDK_VERSION_PARAM_ARG = "--sdk_version_param" +PIPELINE_VERSION_ARG = "--pipeline_version" + +SDK_VERSION_V1 = "v1" +SDK_VERSION_V2 = "v2" +SDK_VERSIONS = [SDK_VERSION_V1, SDK_VERSION_V2] + +PIPELINE_VERSION_2_0 = "2.0" +PIPELINE_VERSION_3_0 = "3.0" +PIPELINE_VERSIONS = [PIPELINE_VERSION_2_0, PIPELINE_VERSION_3_0] + + +def pytest_addoption(parser: pytest.Parser): + # Here we're adding the options for the pipeline version and the sdk version + parser.addoption(f"{PIPELINE_VERSION_ARG}", action="store", help="pipeline version") + parser.addoption(f"{SDK_VERSION_ARG}", action="store", help="sdk version") + parser.addoption( + f"{SDK_VERSION_PARAM_ARG}", action="store", help="sdk version parameter" + ) + + +def filter_items(items: list, param_name: str, predicate: Callable): + """Filter the items based on the parameter name and the predicate. + + Args: + items (list): The list of items to filter. + param_name (str): The parameter name to filter by. + predicate (callable): The predicate to filter by. + """ + items[:] = [ + item + for item in items + if hasattr(item, "callspec") + and param_name in item.callspec.params + and predicate(item.callspec.params[param_name]) + ] + + +def filter_pipeline_version(items: list, pipeline_version: str): + """Filter the items based on the pipeline version. + + Args: + items (list): The list of items to filter. + pipeline_version (str): The pipeline version to filter by. + + Raises: + ValueError: If the pipeline version is invalid. + """ + if pipeline_version not in PIPELINE_VERSIONS: + raise ValueError(f"Invalid pipeline version: {pipeline_version}") + + filter_items(items, "version", lambda version: version == pipeline_version) + + +def filter_sdk_version(items: list, sdk_version: str, sdk_param: str): + """Filter the items based on the SDK version. + + Args: + items (list): The list of items to filter. + sdk_version (str): The SDK version to filter by. + + Raises: + ValueError: If the SDK version is invalid. + """ + if sdk_version not in SDK_VERSIONS: + raise ValueError(f"Invalid SDK version: {sdk_version}") + + from aixplain.v2.resource import BaseResource + + def predicate(param: Any): + return ( + issubclass(param, BaseResource) + if sdk_version == SDK_VERSION_V1 + else not issubclass(param, BaseResource) + ) + + filter_items(items, sdk_param, predicate) + + +def pytest_collection_modifyitems( + session: pytest.Session, config: pytest.Config, items: list +): + """Modify the items based on the pipeline version and the SDK version. + + Args: + session (pytest.Session): The pytest session. + config (pytest.Config): The pytest config. + items (list): The list of items to modify. + + Raises: + ValueError: If the pipeline version or the SDK version is invalid. + """ + pipeline_version = config.getoption(f"{PIPELINE_VERSION_ARG}") + sdk_version = config.getoption(f"{SDK_VERSION_ARG}") + + if pipeline_version: + filter_pipeline_version(items, pipeline_version) + + if sdk_version: + sdk_param = config.getoption(f"{SDK_VERSION_PARAM_ARG}") + if not sdk_param: + raise ValueError( + f"{SDK_VERSION_PARAM_ARG} parameter is required when using {SDK_VERSION_ARG}" + ) + filter_sdk_version(items, sdk_version, sdk_param) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 1a2cbcb4..5df56d13 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -107,7 +107,7 @@ def test_python_interpreter_tool(delete_agents_and_team_agents, AgentFactory): tool = AgentFactory.create_python_interpreter_tool() assert tool is not None assert tool.name == "Python Interpreter" - assert tool.description == "" + assert tool.description == "A Python shell. Use this to execute python commands. Input should be a valid python command." agent = AgentFactory.create( name="Python Developer", @@ -300,7 +300,8 @@ def test_update_tools_of_agent(run_input_map, delete_agents_and_team_agents, Age }, ], ) -def test_specific_model_parameters_e2e(tool_config): +def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agents): + assert delete_agents_and_team_agents """Test end-to-end agent execution with specific model parameters""" # Create tool based on config if tool_config["type"] == "search": @@ -350,111 +351,123 @@ def test_specific_model_parameters_e2e(tool_config): @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents - import os - - # Create test SQLite database - with open("ftest.db", "w") as f: - f.write("") - - tool = AgentFactory.create_sql_tool( - description="Execute an SQL query and return the result", source="ftest.db", source_type="sqlite", enable_commit=True - ) - assert tool is not None - assert tool.description == "Execute an SQL query and return the result" - - agent = AgentFactory.create( - name="Teste", - description="You are a test agent that search for employee information in a database", - tools=[tool], - ) - assert agent is not None - - response = agent.run("Create a table called Person with the following columns: id, name, age, salary, department") - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - - response = agent.run("Insert the following data into the Person table: 1, Eve, 30, 50000, Sales") - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - - response = agent.run("What is the name of the employee with the highest salary?") - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "eve" in str(response["data"]["output"]).lower() + try: + import os + + # Create test SQLite database + with open("ftest.db", "w") as f: + f.write("") + + tool = AgentFactory.create_sql_tool( + description="Execute an SQL query and return the result", + source="ftest.db", + source_type="sqlite", + enable_commit=True, + ) + assert tool is not None + assert tool.description == "Execute an SQL query and return the result" - os.remove("ftest.db") - agent.delete() + agent = AgentFactory.create( + name="Teste", + description="You are a test agent that search for employee information in a database", + tools=[tool], + ) + assert agent is not None + + response = agent.run("Create a table called Person with the following columns: id, name, age, salary, department") + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + response = agent.run("Insert the following data into the Person table: 1, Eve, 30, 50000, Sales") + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + response = agent.run("What is the name of the employee with the highest salary?") + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "eve" in str(response["data"]["output"]).lower() + finally: + os.remove("ftest.db") + agent.delete() @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents + try: + import os + import pandas as pd + + # remove test.csv if it exists + if os.path.exists("test.csv"): + os.remove("test.csv") + + # remove test.db if it exists + if os.path.exists("test.db"): + os.remove("test.db") + + # Create a more comprehensive test dataset + df = pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "David", "Eve"], + "department": ["Sales", "IT", "Sales", "Marketing", "IT"], + "salary": [75000, 85000, 72000, 68000, 90000], + } + ) + df.to_csv("test.csv", index=False) - import pandas as pd - - # Create a more comprehensive test dataset - df = pd.DataFrame( - { - "id": [1, 2, 3, 4, 5], - "name": ["Alice", "Bob", "Charlie", "David", "Eve"], - "department": ["Sales", "IT", "Sales", "Marketing", "IT"], - "salary": [75000, 85000, 72000, 68000, 90000], - } - ) - df.to_csv("test.csv", index=False) - - # Create SQL tool from CSV - tool = AgentFactory.create_sql_tool( - description="Execute SQL queries on employee data", source="test.csv", source_type="csv", tables=["employees"] - ) - - # Verify tool setup - assert tool is not None - assert tool.description == "Execute SQL queries on employee data" - assert tool.database.endswith(".db") - assert tool.tables == ["employees"] - assert ( - tool.schema - == 'CREATE TABLE employees (\n "id" INTEGER, "name" TEXT, "department" TEXT, "salary" INTEGER\n )' # noqa: W503 - ) - assert not tool.enable_commit # must be False by default - - # Create an agent with the SQL tool - agent = AgentFactory.create( - name="SQL Query Agent", - description="I am an agent that helps query employee information from a database.", - instructions="Help users query employee information from the database. Use SQL queries to get the requested information.", - tools=[tool], - ) - assert agent is not None - - # Test 1: Basic SELECT query - response = agent.run("Who are all the employees in the Sales department?") - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "alice" in response["data"]["output"].lower() - assert "charlie" in response["data"]["output"].lower() - - # Test 2: Aggregation query - response = agent.run("What is the average salary in each department?") - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "sales" in response["data"]["output"].lower() - assert "it" in response["data"]["output"].lower() - assert "marketing" in response["data"]["output"].lower() - - # Test 3: Complex query with conditions - response = agent.run("Who is the highest paid employee in the IT department?") - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "eve" in response["data"]["output"].lower() - - import os + # Create SQL tool from CSV + tool = AgentFactory.create_sql_tool( + description="Execute SQL queries on employee data", source="test.csv", source_type="csv", tables=["employees"] + ) - # Cleanup - os.remove("test.csv") - os.remove("test.db") - agent.delete() + # Verify tool setup + assert tool is not None + assert tool.description == "Execute SQL queries on employee data" + assert tool.database.endswith(".db") + assert tool.tables == ["employees"] + assert ( + tool.schema + == 'CREATE TABLE employees (\n "id" INTEGER, "name" TEXT, "department" TEXT, "salary" INTEGER\n )' # noqa: W503 + ) + assert not tool.enable_commit # must be False by default + + # Create an agent with the SQL tool + agent = AgentFactory.create( + name="SQL Query Agent", + description="I am an agent that helps query employee information from a database.", + instructions="Help users query employee information from the database. Use SQL queries to get the requested information.", + tools=[tool], + ) + assert agent is not None + + # Test 1: Basic SELECT query + response = agent.run("Who are all the employees in the Sales department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "alice" in response["data"]["output"].lower() + assert "charlie" in response["data"]["output"].lower() + + # Test 2: Aggregation query + response = agent.run("What is the average salary in each department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "sales" in response["data"]["output"].lower() + assert "it" in response["data"]["output"].lower() + assert "marketing" in response["data"]["output"].lower() + + # Test 3: Complex query with conditions + response = agent.run("Who is the highest paid employee in the IT department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "eve" in response["data"]["output"].lower() + + finally: + # Cleanup + os.remove("test.csv") + os.remove("test.db") + agent.delete() diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index b0b49933..72590314 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -63,6 +63,8 @@ def test_run_async(): pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, id="Snowflake Arctic Embed M Long"), pytest.param(EmbeddingModel.OPENAI_ADA002, id="OpenAI Ada 002"), pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, id="Snowflake Arctic Embed L v2.0"), + pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, id="Multilingual E5 Large"), + pytest.param(EmbeddingModel.BGE_M3, id="BGE M3"), ], ) def test_index_model(embedding_model): @@ -108,6 +110,8 @@ def test_index_model(embedding_model): pytest.param(EmbeddingModel.OPENAI_ADA002, id="OpenAI Ada 002"), pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, id="Snowflake Arctic Embed L v2.0"), pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, id="Jina Clip v2 Multimodal"), + pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, id="Multilingual E5 Large"), + pytest.param(EmbeddingModel.BGE_M3, id="BGE M3"), ], ) def test_index_model_with_filter(embedding_model): diff --git a/tests/functional/model/run_utility_model_test.py b/tests/functional/model/run_utility_model_test.py index 4c05c554..a2c67ca9 100644 --- a/tests/functional/model/run_utility_model_test.py +++ b/tests/functional/model/run_utility_model_test.py @@ -8,7 +8,7 @@ def test_run_utility_model(): utility_model = None try: inputs = [ - UtilityModelInput(name="inputA", description="input A is the only input", type=DataType.TEXT), + UtilityModelInput(name="inputA", description="The inputA input is a text", type=DataType.TEXT), ] output_description = "An example is 'test'" @@ -208,3 +208,130 @@ def get_user_location(dummy_input: str, dummy_input2: str) -> str: utility_model.delete() if utility_model_duplicate: utility_model_duplicate.delete() + + +def test_utility_model_update(): + utility_model = None + updated_model = None + final_model = None + try: + # Define initial model with string concatenation + def concat_strings(str1: str, str2: str): + """Concatenates two strings. + + Args: + str1: The first string. + str2: The second string. + + Returns: + The concatenated string. + """ + return str1 + str2 + + # Create and deploy the utility model + utility_model = ModelFactory.create_utility_model( + name="concat_strings_test", + description="Initial string concatenation utility", + code=concat_strings, + ) + assert utility_model.status == AssetStatus.DRAFT + utility_model.deploy() + + assert utility_model.status == AssetStatus.ONBOARDED + + # Verify initial state + assert utility_model.name == "concat_strings_test" + assert utility_model.description == "Initial string concatenation utility" + assert len(utility_model.inputs) == 2 + assert utility_model.inputs[0].name == "str1" + assert utility_model.inputs[1].name == "str2" + + # Test initial behavior + response = utility_model.run({"str1": "Hello, ", "str2": "World!"}) + assert response.status == "SUCCESS" + assert response.data == "Hello, World!" + + # Define new function with different signature + def sum_numbers(num1: int, num2: int): + """Sums two numbers. + + Args: + num1: The first number. + num2: The second number. + + Returns: + The sum of the two numbers. + """ + return num1 + num2 + + # Update model with new name, description, and code + utility_model.name = "sum_numbers_test" + utility_model.description = "Updated to sum numbers utility" + utility_model.code = sum_numbers + utility_model.save() + + # Verify updated state + updated_model = ModelFactory.get(utility_model.id) + assert updated_model.status == AssetStatus.DRAFT + assert updated_model.name == "sum_numbers_test" + assert updated_model.description == "Updated to sum numbers utility" + assert len(updated_model.inputs) == 2 + assert updated_model.inputs[0].name == "num1" + assert updated_model.inputs[1].name == "num2" + + # Test updated behavior with new function + response = updated_model.run({"num1": 5, "num2": 7}) + assert response.status == "SUCCESS" + assert response.data == "12" + + # Test partial update - only update code, keeping name and description + def multiply_numbers(num1: int, num2: int): + """Multiplies two numbers. + + Args: + num1: The first number. + num2: The second number. + + Returns: + The product of the two numbers. + """ + return num1 * num2 + + updated_model.code = multiply_numbers + assert updated_model.status == AssetStatus.DRAFT + updated_model.deploy() + assert updated_model.status == AssetStatus.ONBOARDED + + updated_model.save() + + # Verify partial update + final_model = ModelFactory.get(utility_model.id) + assert final_model.name == "sum_numbers_test" + assert final_model.description == "Updated to sum numbers utility" + assert final_model.status == AssetStatus.DRAFT + # Test final behavior with new function but same input field names + response = final_model.run({"num1": 5, "num2": 7}) + assert response.status == "SUCCESS" + assert response.data == "35" + + finally: + if utility_model: + utility_model.delete() + if updated_model: + updated_model.delete() + if final_model: + final_model.delete() + + +def test_model_tool_creation(): + from aixplain.factories import AgentFactory + import warnings + + # Capture warnings during the create_model_tool call + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered + warnings.simplefilter("always") + # Create the model tool + AgentFactory.create_model_tool(model="6736411cf127849667606689") # Tavily Search + # Check that no warnings were raised + assert len(w) == 0, f"Warning was raised when calling create_model_tool: {[warning.message for warning in w]}" diff --git a/tests/functional/pipelines/designer_test.py b/tests/functional/pipelines/designer_test.py index 73a67996..a2710b7e 100644 --- a/tests/functional/pipelines/designer_test.py +++ b/tests/functional/pipelines/designer_test.py @@ -184,11 +184,10 @@ def test_decision_pipeline(pipeline): input.outputs.input.link(sentiment_analysis.inputs.text) sentiment_analysis.outputs.data.link(decision_node.inputs.comparison) input.outputs.input.link(decision_node.inputs.passthrough) - decision_node.outputs.input.link(positive_output.inputs.output) - decision_node.outputs.input.link(negative_output.inputs.output) + decision_node.outputs.data.link(positive_output.inputs.output) + decision_node.outputs.data.link(negative_output.inputs.output) pipeline.save() - output = pipeline.run( "I feel so bad today!", version="3.0", diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index 3cabad8c..cb5f80a9 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -301,6 +301,7 @@ def test_team_agent_with_parameterized_agents(run_input_map, delete_agents_and_t llm_id=run_input_map["llm_id"], tools=[search_tool], ) + search_agent.deploy() # Create second agent with translation tool translation_function = Function.TRANSLATION @@ -318,7 +319,7 @@ def test_team_agent_with_parameterized_agents(run_input_map, delete_agents_and_t llm_id=run_input_map["llm_id"], tools=[translation_tool], ) - + translation_agent.deploy() team_agent = create_team_agent( TeamAgentFactory, [search_agent, translation_agent], run_input_map, use_mentalist=True, use_inspector=True ) diff --git a/tests/unit/agent/agent_factory_utils_test.py b/tests/unit/agent/agent_factory_utils_test.py new file mode 100644 index 00000000..66a7a4e2 --- /dev/null +++ b/tests/unit/agent/agent_factory_utils_test.py @@ -0,0 +1,355 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from aixplain.factories.agent_factory.utils import build_tool, build_agent +from aixplain.enums import Function, Supplier +from aixplain.enums.asset_status import AssetStatus +from aixplain.modules.agent.tool.model_tool import ModelTool +from aixplain.modules.agent.tool.pipeline_tool import PipelineTool +from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool +from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool +from aixplain.modules.agent.tool.sql_tool import SQLTool +from aixplain.modules.agent import Agent +from aixplain.modules.agent.agent_task import AgentTask +from aixplain.factories.model_factory import ModelFactory + + +@pytest.fixture +def mock_model(): + """Create a mock model for testing.""" + mock = MagicMock() + mock.id = "test_model" + mock.function = Function.SPEECH_RECOGNITION + mock.supplier = Supplier.AIXPLAIN + mock.description = "Test model" + mock.model_params = None + mock.to_dict.return_value = { + "id": "test_model", + "function": "speech-recognition", + "supplier": "aixplain", + "description": "Test model", + } + return mock + + +@pytest.fixture(autouse=True) +def mock_model_factory(mock_model): + """Mock the ModelFactory.get method for all tests.""" + with patch.object(ModelFactory, "get", return_value=mock_model) as mock: + yield mock + + +@pytest.fixture +def mock_tools(): + """Create mock tools for testing.""" + return [Mock(spec=ModelTool), Mock(spec=PipelineTool)] + + +@pytest.mark.parametrize( + "tool_dict,expected_error", + [ + pytest.param( + {"type": "model", "supplier": "aixplain", "version": "1.0", "assetId": "test_model", "description": "Test model"}, + "Function is required for model tools", + id="missing_function", + ), + pytest.param( + { + "type": "model", + "supplier": "aixplain", + "version": "1.0", + "assetId": "test_model", + "description": "Test model", + "function": "invalid_function", + }, + "Function invalid_function is not a valid function", + id="invalid_function", + ), + pytest.param( + {"type": "invalid_type", "description": "Test tool"}, + "Agent Creation Error: Tool type not supported", + id="invalid_tool_type", + ), + ], +) +def test_build_tool_error_cases(tool_dict, expected_error): + """Test various error cases when building tools.""" + with pytest.raises(Exception) as exc_info: + build_tool(tool_dict) + assert expected_error in str(exc_info.value) + + +@pytest.mark.parametrize( + "tool_dict,expected_type,expected_attrs", + [ + pytest.param( + { + "type": "model", + "supplier": "aixplain", + "version": "1.0", + "assetId": "test_model", + "description": "Test model", + "function": "speech-recognition", + }, + ModelTool, + { + "function": Function.SPEECH_RECOGNITION, + "supplier": Supplier.AIXPLAIN, + "version": "1.0", + "model": "test_model", + "description": "Test model", + }, + id="model_tool_basic", + ), + pytest.param( + { + "type": "model", + "supplier": "aixplain", + "version": "1.0", + "assetId": "test_model", + "description": "Test model", + "function": "speech-recognition", + "parameters": [{"name": "language", "value": "en"}], + }, + ModelTool, + { + "function": Function.SPEECH_RECOGNITION, + "supplier": Supplier.AIXPLAIN, + "version": "1.0", + "model": "test_model", + "description": "Test model", + "parameters": [{"name": "language", "value": "en"}], + }, + id="model_tool_with_params", + ), + pytest.param( + {"type": "pipeline", "description": "Test pipeline", "assetId": "test_pipeline"}, + PipelineTool, + {"description": "Test pipeline", "pipeline": "test_pipeline"}, + id="pipeline_tool", + ), + pytest.param( + {"type": "utility", "description": "Test utility", "utilityCode": "print('Hello World')"}, + CustomPythonCodeTool, + {"description": "Test utility", "code": "print('Hello World')"}, + id="custom_python_tool", + ), + pytest.param( + {"type": "utility", "description": "Test utility"}, PythonInterpreterTool, {}, id="python_interpreter_tool" + ), + pytest.param( + { + "type": "sql", + "description": "Test SQL", + "parameters": [ + {"name": "database", "value": "test_db"}, + {"name": "schema", "value": "public"}, + {"name": "tables", "value": "table1,table2"}, + {"name": "enable_commit", "value": True}, + ], + }, + SQLTool, + { + "description": "Test SQL", + "database": "test_db", + "schema": "public", + "tables": ["table1", "table2"], + "enable_commit": True, + }, + id="sql_tool_boolean_commit", + ), + pytest.param( + { + "type": "sql", + "description": "Test SQL with string enable_commit", + "parameters": [ + {"name": "database", "value": "test_db"}, + {"name": "schema", "value": "public"}, + {"name": "tables", "value": "table1"}, + {"name": "enable_commit", "value": True}, + ], + }, + SQLTool, + { + "description": "Test SQL with string enable_commit", + "database": "test_db", + "schema": "public", + "tables": ["table1"], + "enable_commit": True, + }, + id="sql_tool_string_commit", + ), + ], +) +def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock_model): + """Test successful tool creation with various configurations.""" + with patch.object(ModelFactory, "get", return_value=mock_model): + tool = build_tool(tool_dict) + assert isinstance(tool, expected_type) + + for attr, value in expected_attrs.items(): + if attr == "model": + assert tool.model == mock_model + else: + assert getattr(tool, attr) == value + + +@pytest.mark.parametrize( + "payload,expected_attrs", + [ + pytest.param( + { + "id": "test_agent", + "name": "Test Agent", + "description": "Test Description", + "role": "Test Instructions", + "teamId": "test_team", + "version": "1.0", + "cost": 10.0, + "status": "onboarded", + "assets": [], + "tasks": [ + { + "name": "Task 1", + "description": "Task 1 Description", + "expectedOutput": "Expected Output 1", + "dependencies": ["dep1", "dep2"], + } + ], + }, + { + "id": "test_agent", + "name": "Test Agent", + "description": "Test Description", + "instructions": "Test Instructions", + "supplier": "test_team", + "version": "1.0", + "cost": 10.0, + "status": AssetStatus.ONBOARDED, + "tasks": [ + { + "name": "Task 1", + "description": "Task 1 Description", + "expected_output": "Expected Output 1", + "dependencies": ["dep1", "dep2"], + } + ], + }, + id="agent_with_tasks", + ), + pytest.param( + { + "id": "test_agent", + "name": "Test Agent", + "status": "onboarded", + "assets": [ + { + "type": "model", + "supplier": "aixplain", + "version": "1.0", + "assetId": "test_model", + "description": "Test model", + "function": "speech-recognition", + }, + {"type": "pipeline", "description": "Test pipeline", "assetId": "test_pipeline"}, + ], + }, + { + "id": "test_agent", + "name": "Test Agent", + "status": AssetStatus.ONBOARDED, + "tools": [{"type": ModelTool}, {"type": PipelineTool}], + }, + id="agent_with_tools", + ), + ], +) +def test_build_agent_success_cases(payload, expected_attrs, mock_tools): + """Test successful agent creation with various configurations.""" + agent = build_agent(payload, tools=mock_tools if "assets" not in payload else None) + assert isinstance(agent, Agent) + + for attr, value in expected_attrs.items(): + if attr == "tasks": + assert len(agent.tasks) == len(value) + for task, expected_task in zip(agent.tasks, value): + assert isinstance(task, AgentTask) + for task_attr, task_value in expected_task.items(): + assert getattr(task, task_attr) == task_value + elif attr == "tools": + assert len(agent.tools) == len(value) + for tool, expected_tool in zip(agent.tools, value): + assert isinstance(tool, expected_tool["type"]) + else: + assert getattr(agent, attr) == value + + +@pytest.mark.parametrize( + "payload,expected_error", + [ + pytest.param( + { + "id": "test_agent", + "name": "Test Agent", + "status": "onboarded", + "assets": [{"type": "invalid_type", "description": "Test tool", "assetId": "invalid_asset"}], + }, + "Agent Creation Error: Tool type not supported", + id="invalid_tool_type", + ), + pytest.param( + { + "id": "test_agent", + "name": "Test Agent", + "status": "onboarded", + "assets": [ + { + "type": "model", + "supplier": "aixplain", + "version": "1.0", + "assetId": "test_model", + "description": "Test model", + "function": "invalid_function", + } + ], + }, + "Function invalid_function is not a valid function", + id="invalid_function", + ), + pytest.param( + { + "id": "test_agent", + "name": "Test Agent", + "status": "onboarded", + "assets": [ + { + "type": "model", + "supplier": "aixplain", + "version": "1.0", + "assetId": "test_model", + "description": "Test model", + } + ], + }, + "Function is required for model tools", + id="missing_function", + ), + pytest.param( + { + "id": "test_agent", + "name": "Test Agent", + "status": "onboarded", + "assets": [{"type": "model", "assetId": "test_model", "function": "speech-recognition"}], + }, + "Tool test_model is not available. Make sure it exists or you have access to it. If you think this is an error, please contact the administrators.", + id="generic_error", + ), + ], +) +def test_build_agent_with_invalid_tool(payload, expected_error): + """Test that building an agent with an invalid tool handles the error gracefully.""" + with patch("logging.warning") as mock_warning: + agent = build_agent(payload) + assert isinstance(agent, Agent) + assert len(agent.tools) == 0 + mock_warning.assert_called_once() + assert expected_error in mock_warning.call_args[0][0] diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index 8afda296..b418172e 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -2,7 +2,7 @@ import requests_mock from aixplain.factories import AgentFactory from aixplain.enums.asset_status import AssetStatus -from aixplain.modules import Agent +from aixplain.modules import Agent, Model from aixplain.modules.agent import OutputFormat from aixplain.utils import config from aixplain.modules.agent.tool.pipeline_tool import PipelineTool @@ -93,12 +93,6 @@ def test_invalid_pipelinetool(): assert str(exc_info.value) == "Pipeline Tool Unavailable. Make sure Pipeline '309851793' exists or you have access to it." -def test_invalid_modeltool(): - with pytest.raises(Exception) as exc_info: - AgentFactory.create(name="Test", tools=[ModelTool(model="309851793")], llm_id="6646261c6eb563165658bbb1") - assert str(exc_info.value) == "Model Tool Unavailable. Make sure Model '309851793' exists or you have access to it." - - def test_invalid_llm_id(): with pytest.raises(Exception) as exc_info: AgentFactory.create(name="Test", description="", instructions="", tools=[], llm_id="123") @@ -117,7 +111,6 @@ def test_invalid_agent_name(): @patch("aixplain.factories.model_factory.ModelFactory.get") def test_create_agent(mock_model_factory_get): from aixplain.enums import Supplier, Function - from aixplain.modules import Model # Mock the model factory response mock_model = Model( @@ -166,7 +159,7 @@ def test_create_agent(mock_model_factory_get): { "type": "utility", "utility": "custom_python_code", - "description": "", + "description": "A Python shell. Use this to execute python commands. Input should be a valid python command.", }, ], } @@ -240,7 +233,6 @@ def test_to_dict(): @patch("aixplain.factories.model_factory.ModelFactory.get") def test_update_success(mock_model_factory_get): - from aixplain.modules import Model from aixplain.enums import Function # Mock the model factory response @@ -313,7 +305,6 @@ def test_update_success(mock_model_factory_get): @patch("aixplain.factories.model_factory.ModelFactory.get") def test_save_success(mock_model_factory_get): - from aixplain.modules import Model from aixplain.enums import Function # Mock the model factory response @@ -644,11 +635,9 @@ def test_function(query: str) -> str: assert tool.description == "Test description" -@patch("aixplain.modules.agent.tool.model_tool.ModelTool.validate", autospec=True) @patch("aixplain.factories.model_factory.ModelFactory.get") -def test_create_agent_with_model_instance(mock_model_factory_get, mock_validate): +def test_create_agent_with_model_instance(mock_model_factory_get): from aixplain.enums import Supplier, Function - from aixplain.modules import Model from aixplain.modules.model.model_parameters import ModelParameters # Create model parameters @@ -665,9 +654,6 @@ def test_create_agent_with_model_instance(mock_model_factory_get, mock_validate) model_params=model_params, ) - # Mock the validate method to return the model instance - mock_validate.return_value = model_tool - # Mock the LLM model factory response llm_model = Model( id="6646261c6eb563165658bbb1", @@ -676,7 +662,15 @@ def test_create_agent_with_model_instance(mock_model_factory_get, mock_validate) function=Function.TEXT_GENERATION, model_params=model_params, ) - mock_model_factory_get.return_value = llm_model + + def validate_side_effect(model_id, *args, **kwargs): + if model_id == "model123": + return model_tool + elif model_id == "6646261c6eb563165658bbb1": + return llm_model + return None + + mock_model_factory_get.side_effect = validate_side_effect with requests_mock.Mocker() as mock: url = urljoin(config.BACKEND_URL, "sdk/agents") @@ -732,6 +726,7 @@ def test_create_agent_with_model_instance(mock_model_factory_get, mock_validate) # Verify the tool was converted correctly tool = agent.tools[0] assert isinstance(tool, Model) + assert tool.id == "model123" assert tool.name == model_tool.name assert tool.function == model_tool.function assert tool.supplier == model_tool.supplier @@ -740,9 +735,8 @@ def test_create_agent_with_model_instance(mock_model_factory_get, mock_validate) assert not tool.model_params.parameters["max_tokens"].required -@patch("aixplain.modules.agent.tool.model_tool.ModelTool.validate", autospec=True) @patch("aixplain.factories.model_factory.ModelFactory.get") -def test_create_agent_with_mixed_tools(mock_model_factory_get, mock_validate): +def test_create_agent_with_mixed_tools(mock_model_factory_get): from aixplain.enums import Supplier, Function from aixplain.modules import Model from aixplain.modules.model.model_parameters import ModelParameters @@ -774,15 +768,26 @@ def test_create_agent_with_mixed_tools(mock_model_factory_get, mock_validate): model_params=classification_params, ) + # Mock the LLM model factory response + llm_model = Model( + id="6646261c6eb563165658bbb1", + name="Test LLM", + description="Test LLM Description", + function=Function.TEXT_GENERATION, + model_params=text_gen_params, + ) + # Mock the validate method to return different models based on the model ID - def validate_side_effect(self, *args, **kwargs): - if self.model == "model123": + def validate_side_effect(model_id, *args, **kwargs): + if model_id == "model123": return model_tool - elif self.model == "openai-model": + elif model_id == "openai-model": return openai_model + elif model_id == "6646261c6eb563165658bbb1": + return llm_model return None - mock_validate.side_effect = validate_side_effect + mock_model_factory_get.side_effect = validate_side_effect # Create a regular ModelTool instance regular_tool = AgentFactory.create_model_tool( @@ -792,16 +797,6 @@ def validate_side_effect(self, *args, **kwargs): description="Regular Tool", ) - # Mock the LLM model factory response - llm_model = Model( - id="6646261c6eb563165658bbb1", - name="Test LLM", - description="Test LLM Description", - function=Function.TEXT_GENERATION, - model_params=text_gen_params, - ) - mock_model_factory_get.return_value = llm_model - with requests_mock.Mocker() as mock: url = urljoin(config.BACKEND_URL, "sdk/agents") headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} @@ -861,9 +856,10 @@ def validate_side_effect(self, *args, **kwargs): assert agent.description == ref_response["description"] assert len(agent.tools) == 2 - # Verify the first tool (Model) + # Verify the first tool (Model instance converted to ModelTool) tool1 = agent.tools[0] assert isinstance(tool1, Model) + assert tool1.id == "model123" assert tool1.name == model_tool.name assert tool1.function == model_tool.function assert tool1.supplier == model_tool.supplier @@ -874,13 +870,13 @@ def validate_side_effect(self, *args, **kwargs): # Verify the second tool (regular ModelTool) tool2 = agent.tools[1] assert isinstance(tool2, ModelTool) - assert tool2.model == "openai-model" + assert tool2.model.id == "openai-model" assert tool2.function == Function.TEXT_CLASSIFICATION assert tool2.supplier == Supplier.OPENAI - assert isinstance(tool2.model_object, Model) - assert isinstance(tool2.model_object.model_params, ModelParameters) - assert tool2.model_object.model_params.parameters["threshold"].required - assert tool2.model_object.model_params.parameters["labels"].required + assert isinstance(tool2.model, Model) + assert isinstance(tool2.model.model_params, ModelParameters) + assert tool2.model.model_params.parameters["threshold"].required + assert tool2.model.model_params.parameters["labels"].required @pytest.mark.parametrize( diff --git a/tests/unit/agent/model_tool_test.py b/tests/unit/agent/model_tool_test.py index 6c5cf8fb..c29df945 100644 --- a/tests/unit/agent/model_tool_test.py +++ b/tests/unit/agent/model_tool_test.py @@ -18,6 +18,7 @@ def mock_model(): model.function = Function.TRANSLATION model.supplier = Supplier.AIXPLAIN model.name = "Test Model" + model.description = "Test Model Description" model.model_params = ModelParameters( { "sourcelanguage": {"name": "sourcelanguage", "required": True}, @@ -51,9 +52,10 @@ def test_init_with_model(mock_model, mock_model_factory): mock_model_factory.get.return_value = mock_model tool = ModelTool(model="test_model_id") assert tool.function == Function.TRANSLATION - assert tool.model == "test_model_id" + assert tool.model.id == "test_model_id" assert tool.supplier == Supplier.AIXPLAIN - assert tool.model_object == mock_model + assert tool.model == mock_model + assert tool.description == "Test Model Description" def test_init_with_supplier_dict(): @@ -112,22 +114,12 @@ def test_to_dict(mock_model, mock_model_factory): def test_validate(mock_model, mock_model_factory, model_exists): if model_exists: mock_model_factory.get.return_value = mock_model - with patch.object(ModelTool, "__init__", return_value=None): - tool = ModelTool() - tool.model = "test_model_id" - tool.api_key = None - tool.model_object = None - validated_model = tool.validate() - assert validated_model == mock_model + tool = ModelTool(model="test_model_id", api_key=None) + assert tool.model == mock_model else: mock_model_factory.get.side_effect = Exception("Model not found") - with patch.object(ModelTool, "__init__", return_value=None): - tool = ModelTool() - tool.model = "nonexistent_model" - tool.api_key = None - tool.model_object = None - with pytest.raises(Exception, match="Model Tool Unavailable"): - tool.validate() + with pytest.raises(Exception, match="Model Tool Unavailable"): + tool = ModelTool(model="nonexistent_model", api_key=None) def test_get_parameters(): @@ -154,28 +146,26 @@ def test_get_parameters(): (None, None, False, None), ], ) -def test_validate_parameters(mock_model, params, expected_result, error_expected, error_message): - with patch.object(ModelTool, "__init__", return_value=None): - tool = ModelTool() - tool.model_object = mock_model - tool.function = Function.TRANSLATION - - # Mock the model parameters - mock_params = MagicMock() - mock_params.parameters = { - "sourcelanguage": Parameter(name="sourcelanguage", required=True), - "targetlanguage": Parameter(name="targetlanguage", required=True), - } - # Mock the to_list method to return None when no parameters are set - mock_params.to_list.return_value = None - mock_model.model_params = mock_params +def test_validate_parameters(mocker, mock_model, params, expected_result, error_expected, error_message): + mocker.patch("aixplain.factories.model_factory.ModelFactory.get", return_value=mock_model) + tool = ModelTool(model=mock_model.id, function=Function.TRANSLATION) + + # Mock the model parameters + mock_params = MagicMock() + mock_params.parameters = { + "sourcelanguage": Parameter(name="sourcelanguage", required=True), + "targetlanguage": Parameter(name="targetlanguage", required=True), + } + # Mock the to_list method to return None when no parameters are set + mock_params.to_list.return_value = None + mock_model.model_params = mock_params - if error_expected: - with pytest.raises(ValueError, match=re.escape(error_message)): - tool.validate_parameters(params) - else: - result = tool.validate_parameters(params) - assert result == expected_result + if error_expected: + with pytest.raises(ValueError, match=re.escape(error_message)): + tool.validate_parameters(params) + else: + result = tool.validate_parameters(params) + assert result == expected_result @pytest.mark.parametrize( @@ -194,3 +184,36 @@ def test_tool_name(mock_model, mock_model_factory, tool_name, expected_name): # Verify name appears correctly in dictionary representation tool_dict = tool.to_dict() assert tool_dict["name"] == expected_name + + +def test_invalid_modeltool(mocker): + mocker.patch("aixplain.factories.model_factory.ModelFactory.get", side_effect=Exception()) + with pytest.raises(Exception) as exc_info: + model_tool = ModelTool(model="309851793") + model_tool.validate() + assert str(exc_info.value) == "Model Tool Unavailable. Make sure Model '309851793' exists or you have access to it." + + +def test_validate_model_tool_with_function(): + model_tool = ModelTool(function="text-generation") + assert model_tool.function == Function.TEXT_GENERATION + assert model_tool.description != "" + + +def test_validate_model_tool_with_model(mocker): + mocker.patch( + "aixplain.factories.model_factory.ModelFactory.get", + return_value=Model( + id="309851793", name="Test Model", description="Test Model Description", function=Function.TEXT_GENERATION + ), + ) + model_tool = ModelTool(model="309851793", function=Function.TRANSLATION) + assert model_tool.model.id == "309851793" + assert model_tool.function == Function.TEXT_GENERATION + assert model_tool.description != "" + + +def test_validate_model_tool_without_function_or_model(): + with pytest.raises(Exception) as exc_info: + ModelTool() + assert str(exc_info.value) == "Agent Creation Error: Either function or model must be provided when instantiating a tool." diff --git a/tests/unit/designer_unit_test.py b/tests/unit/designer_unit_test.py index 3b4d93bc..aaf45026 100644 --- a/tests/unit/designer_unit_test.py +++ b/tests/unit/designer_unit_test.py @@ -715,6 +715,33 @@ class AssetNode(Node): mock_attach_to.assert_called_once_with(pipeline) +def test_pipeline_decision_node_passthrough_linking(): + from aixplain.modules.pipeline.designer import Route + from aixplain.modules.pipeline.designer.nodes import Input, Output, Decision + + input_node = Input() + output_node = Output() + + decision_node = Decision(routes=[Mock(spec=Route)]) + + # Decision node "passthrough" param should be linked first to infer output param "data" + with pytest.raises(ValueError): + decision_node.link(output_node, from_param="data", to_param="output") + + # Link the "passthrough" param to the asset node + input_node.outputs.input.link(decision_node.inputs.passthrough) + + # Now we can link the "data" param to the asset node + decision_node.link(output_node, from_param="data", to_param="output") + + assert decision_node.outputs.data.link_ is not None + assert decision_node.outputs.data.link_.from_node == decision_node + assert decision_node.outputs.data.link_.to_node == output_node + assert decision_node.outputs.data.link_.to_param == "output" + assert decision_node.outputs.data.data_type == input_node.outputs.input.data_type + assert decision_node.outputs.data.link_.data_source_id == input_node.number + + def test_pipeline_special_prompt_validation(): from aixplain.modules.pipeline.designer.nodes import AssetNode