From a7ac7adad29112d1e40a4c8295302aa43f876d2d Mon Sep 17 00:00:00 2001 From: kadirpekel Date: Fri, 11 Apr 2025 17:37:40 +0200 Subject: [PATCH 01/62] ENG-1852: Hotfix, added fail-fast option (#488) --- .github/workflows/main.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index b939a3a0..faf380eb 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -12,6 +12,7 @@ jobs: setup-and-test: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: test-suite: [ 'tests/unit', From b0887bd19b05cd4c610711bb1089d9b6826ea8ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Wed, 16 Apr 2025 11:47:54 +0300 Subject: [PATCH 02/62] Fix: BUG-503 failed to update utility tool after get (#491) --- aixplain/factories/model_factory/utils.py | 15 +++++- tests/unit/utility_test.py | 58 ++++++++++++++++++++++- 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index c5a90367..c10e4369 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -10,6 +10,7 @@ from datetime import datetime from typing import Dict, Union, List, Optional, Tuple from urllib.parse import urljoin +import requests def create_model_from_response(response: Dict) -> Model: @@ -39,10 +40,13 @@ def create_model_from_response(response: Dict) -> Model: function_input_params, function_output_params = function.get_input_output_params() model_params = {param["name"]: param for param in response["params"]} + code = response.get("code", "") + inputs, temperature = [], None input_params, output_params = function_input_params, function_output_params ModelClass = Model + if function == Function.TEXT_GENERATION: ModelClass = LLM f = [p for p in response.get("params", []) if p["name"] == "temperature"] @@ -58,6 +62,15 @@ def create_model_from_response(response: Dict) -> 
Model: ] input_params = model_params + if not code: + if "version" in response and response["version"]: + version_link = response["version"]["id"] + if version_link: + version_content = requests.get(version_link).text + code = version_content + else: + raise Exception("Utility Model Error: Code not found") + created_at = None if "createdAt" in response and response["createdAt"]: created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00")) @@ -66,7 +79,7 @@ def create_model_from_response(response: Dict) -> Model: response["id"], response["name"], description=response.get("description", ""), - code=response.get("code", ""), + code=code if code else "", supplier=response["supplier"], api_key=response["api_key"], cost=response["pricing"], diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py index e678fa7a..9aabd831 100644 --- a/tests/unit/utility_test.py +++ b/tests/unit/utility_test.py @@ -7,7 +7,7 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput from aixplain.modules.model.utils import parse_code, parse_code_decorated -from unittest.mock import patch +from unittest.mock import patch, MagicMock import warnings @@ -507,3 +507,59 @@ def test_utility_model_status_after_deployment(): warnings.simplefilter("always") utility_model.validate() assert len(w) == 0 + + +def test_concat_strings(): + """Test the concat_strings function directly.""" + assert concat_strings("Hello, ", "World!") == "Hello, World!" + assert concat_strings("", "") == "" + assert concat_strings("123", "456") == "123456" + + +def concat_strings(str1: str, str2: str): + """Concatenates two strings. + + Args: + str1: The first string. + str2: The second string. + + Returns: + The concatenated string. 
+ """ + return str1 + str2 + + +@patch("aixplain.factories.ModelFactory") +def test_create_and_deploy_utility_model(mock_model_factory): + """Test creating and deploying a utility model with mocked backend requests.""" + # Mock the create_utility_model method + mock_utility_model = MagicMock() + mock_utility_model.id = "mock-utility-model-id" + mock_model_factory.create_utility_model.return_value = mock_utility_model + + # Mock the get method + mock_model = MagicMock() + mock_model_factory.get.return_value = mock_model + + # Create utility model + from aixplain.factories import ModelFactory + + utility_model = ModelFactory.create_utility_model( + name="concat_strings", + code=concat_strings, + ) + + # Assert create_utility_model was called with correct parameters + mock_model_factory.create_utility_model.assert_called_once() + + # Get the model with mocked id + model = ModelFactory.get(utility_model.id) + + # Assert get was called with the correct id + mock_model_factory.get.assert_called_once() + + # Deploy the model + model.deploy() + + # Assert deploy was called + mock_model.deploy.assert_called_once() From 89daf2f27a1b0e24a59221abadf1a7899d39a411 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Wed, 16 Apr 2025 19:26:29 +0300 Subject: [PATCH 03/62] Bug-503-fix Update (#493) * Fix: BUG-503 failed to update utility tool after getFix: BUG-503 failed to update utility tool after get * fix if not url in code --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> --- aixplain/factories/model_factory/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index c10e4369..1be2186c 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -66,8 +66,11 @@ def create_model_from_response(response: Dict) -> Model: if "version" in response and 
response["version"]: version_link = response["version"]["id"] if version_link: - version_content = requests.get(version_link).text - code = version_content + try: + version_content = requests.get(version_link).text + code = version_content + except Exception: + code = "" else: raise Exception("Utility Model Error: Code not found") From 6bf7d00510b338ac6c54c62d2081c882e5f5b51f Mon Sep 17 00:00:00 2001 From: Zaina Abu Shaban Date: Thu, 17 Apr 2025 00:58:28 +0300 Subject: [PATCH 04/62] ENG-1920: added sentry cred (#476) * added sentry cred * removed condition and added dependency * Fixes on sentry setting --------- Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> --- aixplain/utils/config.py | 14 +++++++++++--- pyproject.toml | 1 + 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/aixplain/utils/config.py b/aixplain/utils/config.py index 6dc2bb49..6cb15775 100644 --- a/aixplain/utils/config.py +++ b/aixplain/utils/config.py @@ -16,17 +16,17 @@ import os import logging +import sentry_sdk logger = logging.getLogger(__name__) BACKEND_URL = os.getenv("BACKEND_URL", "https://platform-api.aixplain.com") -MODELS_RUN_URL = os.getenv( - "MODELS_RUN_URL", "https://models.aixplain.com/api/v1/execute" -) +MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com/api/v1/execute") # GET THE API KEY FROM CMD TEAM_API_KEY = os.getenv("TEAM_API_KEY", "") AIXPLAIN_API_KEY = os.getenv("AIXPLAIN_API_KEY", "") +ENV = "dev" if "dev" in BACKEND_URL else "test" if "test" in BACKEND_URL else "prod" if not TEAM_API_KEY and not AIXPLAIN_API_KEY: raise Exception( @@ -47,3 +47,11 @@ MODEL_API_KEY = os.getenv("MODEL_API_KEY", "") LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") HF_TOKEN = os.getenv("HF_TOKEN", "") +SENTRY_DSN = os.getenv("SENTRY_DSN") + +if SENTRY_DSN: + sentry_sdk.init( + dsn=SENTRY_DSN, + environment=ENV, + send_default_pii=True, + ) diff --git a/pyproject.toml 
b/pyproject.toml index ef57b89f..422ea158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ "PyYAML>=6.0.1", "dataclasses-json>=0.5.2", "Jinja2==3.1.6", + "sentry-sdk>=1.0.0" ] [project.urls] From 9e2b03c7c9ac885e22d6e8dc0699747369b28f2b Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Thu, 17 Apr 2025 17:59:35 -0300 Subject: [PATCH 05/62] Read input variables correctly (#496) --- aixplain/modules/agent/__init__.py | 2 +- .../functional/agent/agent_functional_test.py | 26 +++++++++++++++++++ tests/unit/agent/agent_test.py | 2 +- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index ecba14d9..38b8a7ae 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -298,7 +298,7 @@ def run_async( headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} # build query - input_data = process_variables(query, data, parameters, self.description) + input_data = process_variables(query, data, parameters, self.instructions) payload = { "id": self.id, diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 4e9152bb..151c4a55 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -458,3 +458,29 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): os.remove("test.csv") os.remove("test.db") agent.delete() + + +@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) +def test_instructions(delete_agents_and_team_agents, AgentFactory): + assert delete_agents_and_team_agents + + agent = AgentFactory.create( + name="Test Agent", + description="Test description", + instructions="Always respond with '{magic_word}' does not matter what you are prompted for.", + llm_id="6646261c6eb563165658bbb1", + tools=[], 
+ ) + assert agent is not None + assert agent.status == AssetStatus.DRAFT + + agent = AgentFactory.get(agent.id) + assert agent is not None + response = agent.run(data={"magic_word": "aixplain", "query": "What is the capital of France?"}) + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "data" in response + assert response["data"]["session_id"] is not None + assert response["data"]["output"] is not None + assert "aixplain" in response["data"]["output"].lower() diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index b418172e..f767f0ec 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -399,7 +399,7 @@ def test_run_success(): def test_run_variable_error(): - agent = Agent("123", "Test Agent", "Translate the input data into {target_language}", "Test Agent Role") + agent = Agent("123", "Test Agent", "Agent description", "Translate the input data into {target_language}") with pytest.raises(Exception) as exc_info: agent.run_async(data={"query": "Hello, how are you?"}, output_format=OutputFormat.MARKDOWN) assert str(exc_info.value) == ( From 409a89315c2263e5aaa7a874091b9f7d1eadba2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Fri, 18 Apr 2025 01:17:53 +0300 Subject: [PATCH 06/62] Fix: file not deleted in test_sql_tool_with_csv test (#492) Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> --- .../functional/agent/agent_functional_test.py | 213 +++++++++--------- 1 file changed, 112 insertions(+), 101 deletions(-) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 151c4a55..2313234e 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -350,115 +350,126 @@ def test_specific_model_parameters_e2e(tool_config): @pytest.mark.parametrize("AgentFactory", 
[AgentFactory, v2.Agent]) def test_sql_tool(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents - import os - - # Create test SQLite database - with open("ftest.db", "w") as f: - f.write("") - - tool = AgentFactory.create_sql_tool( - description="Execute an SQL query and return the result", source="ftest.db", source_type="sqlite", enable_commit=True - ) - assert tool is not None - assert tool.description == "Execute an SQL query and return the result" - - agent = AgentFactory.create( - name="Teste", - description="You are a test agent that search for employee information in a database", - tools=[tool], - ) - assert agent is not None - - response = agent.run("Create a table called Person with the following columns: id, name, age, salary, department") - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - - response = agent.run("Insert the following data into the Person table: 1, Eve, 30, 50000, Sales") - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - - response = agent.run("What is the name of the employee with the highest salary?") - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "eve" in str(response["data"]["output"]).lower() + try: + import os + + # Create test SQLite database + with open("ftest.db", "w") as f: + f.write("") + + tool = AgentFactory.create_sql_tool( + description="Execute an SQL query and return the result", + source="ftest.db", + source_type="sqlite", + enable_commit=True, + ) + assert tool is not None + assert tool.description == "Execute an SQL query and return the result" - os.remove("ftest.db") - agent.delete() + agent = AgentFactory.create( + name="Teste", + description="You are a test agent that search for employee information in a database", + tools=[tool], + ) + assert agent is not None + + 
response = agent.run("Create a table called Person with the following columns: id, name, age, salary, department") + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + response = agent.run("Insert the following data into the Person table: 1, Eve, 30, 50000, Sales") + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + response = agent.run("What is the name of the employee with the highest salary?") + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "eve" in str(response["data"]["output"]).lower() + finally: + os.remove("ftest.db") + agent.delete() @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents + try: + import os + import pandas as pd + + # remove test.csv if it exists + if os.path.exists("test.csv"): + os.remove("test.csv") + + # remove test.db if it exists + if os.path.exists("test.db"): + os.remove("test.db") + + # Create a more comprehensive test dataset + df = pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "David", "Eve"], + "department": ["Sales", "IT", "Sales", "Marketing", "IT"], + "salary": [75000, 85000, 72000, 68000, 90000], + } + ) + df.to_csv("test.csv", index=False) - import pandas as pd - - # Create a more comprehensive test dataset - df = pd.DataFrame( - { - "id": [1, 2, 3, 4, 5], - "name": ["Alice", "Bob", "Charlie", "David", "Eve"], - "department": ["Sales", "IT", "Sales", "Marketing", "IT"], - "salary": [75000, 85000, 72000, 68000, 90000], - } - ) - df.to_csv("test.csv", index=False) - - # Create SQL tool from CSV - tool = AgentFactory.create_sql_tool( - description="Execute SQL queries on employee data", source="test.csv", source_type="csv", tables=["employees"] - ) - - # 
Verify tool setup - assert tool is not None - assert tool.description == "Execute SQL queries on employee data" - assert tool.database.endswith(".db") - assert tool.tables == ["employees"] - assert ( - tool.schema - == 'CREATE TABLE employees (\n "id" INTEGER, "name" TEXT, "department" TEXT, "salary" INTEGER\n )' # noqa: W503 - ) - assert not tool.enable_commit # must be False by default - - # Create an agent with the SQL tool - agent = AgentFactory.create( - name="SQL Query Agent", - description="I am an agent that helps query employee information from a database.", - instructions="Help users query employee information from the database. Use SQL queries to get the requested information.", - tools=[tool], - ) - assert agent is not None - - # Test 1: Basic SELECT query - response = agent.run("Who are all the employees in the Sales department?") - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "alice" in response["data"]["output"].lower() - assert "charlie" in response["data"]["output"].lower() - - # Test 2: Aggregation query - response = agent.run("What is the average salary in each department?") - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "sales" in response["data"]["output"].lower() - assert "it" in response["data"]["output"].lower() - assert "marketing" in response["data"]["output"].lower() - - # Test 3: Complex query with conditions - response = agent.run("Who is the highest paid employee in the IT department?") - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "eve" in response["data"]["output"].lower() - - import os - - # Cleanup - os.remove("test.csv") - os.remove("test.db") - agent.delete() + # Create SQL tool from CSV + tool = AgentFactory.create_sql_tool( + description="Execute SQL queries on employee data", source="test.csv", source_type="csv", tables=["employees"] + ) + # Verify tool setup + assert tool is 
not None + assert tool.description == "Execute SQL queries on employee data" + assert tool.database.endswith(".db") + assert tool.tables == ["employees"] + assert ( + tool.schema + == 'CREATE TABLE employees (\n "id" INTEGER, "name" TEXT, "department" TEXT, "salary" INTEGER\n )' # noqa: W503 + ) + assert not tool.enable_commit # must be False by default + + # Create an agent with the SQL tool + agent = AgentFactory.create( + name="SQL Query Agent", + description="I am an agent that helps query employee information from a database.", + instructions="Help users query employee information from the database. Use SQL queries to get the requested information.", + tools=[tool], + ) + assert agent is not None + + # Test 1: Basic SELECT query + response = agent.run("Who are all the employees in the Sales department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "alice" in response["data"]["output"].lower() + assert "charlie" in response["data"]["output"].lower() + + # Test 2: Aggregation query + response = agent.run("What is the average salary in each department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "sales" in response["data"]["output"].lower() + assert "it" in response["data"]["output"].lower() + assert "marketing" in response["data"]["output"].lower() + + # Test 3: Complex query with conditions + response = agent.run("Who is the highest paid employee in the IT department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "eve" in response["data"]["output"].lower() + + finally: + # Cleanup + os.remove("test.csv") + os.remove("test.db") + agent.delete() @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_instructions(delete_agents_and_team_agents, AgentFactory): From 3acfacd20cb2e87330979f0a412981ce29f90a2b Mon Sep 17 00:00:00 2001 From: Lucas Pavanelli 
<86805709+lucas-aixplain@users.noreply.github.com> Date: Thu, 17 Apr 2025 19:18:36 -0300 Subject: [PATCH 07/62] ENG-2007: Fix agent and team agent parametrized functional test - Sync with dev (#495) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix general assets and api key functional tests (#468) * Update return type so it works with python 3.8 (#390) * ENG-1557: model params are now available for designer assets (#387) * ENG-1559: Fixed designer tests (#388) * Use input api key to list models when given (#395) * BUG-375 new functional test regarding ensuring failure (#396) * Role 2 Instructions (#393) * added validate check when s3 link (#399) * MErge to prod (#340) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Merge dev to test (#107) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Development to Test (#109) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> * Merge to test (#111) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira 
<85182544+thiago-aixplain@users.noreply.github.com> * Update Finetuner functional tests (#112) * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated 
CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#118) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: 
mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro 
Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error 
Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira * Merge to test (#124) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: 
mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * 
Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us 
Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira * Update Finetuner hyperparameters (#125) * Update Finetuner hyperparameters * Change hyperparameters error message * Merge dev to test (#126) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: 
mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira 
<85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira * Update Finetuner hyperparameters (#125) * Update Finetuner hyperparameters * Change hyperparameters error message --------- Signed-off-by: 
mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira * Add new LLMs finetuner models (mistral and solar) (#128) * Merge dev to test (#129) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: 
mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: 
mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira * Update Finetuner hyperparameters (#125) * Update Finetuner hyperparameters * Change hyperparameters error message * Add new LLMs finetuner models (mistral and solar) (#128) --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira * Enabling dataset ID and model ID as parameters for finetuner creation (#131) Co-authored-by: Thiago Castro Ferreira * Fix supplier representation of a model (#132) * Fix supplier representation of a model * Fixing parameter 
typing --------- Co-authored-by: Thiago Castro Ferreira * Fixing indentation in documentation sample code (#134) Co-authored-by: Thiago Castro Ferreira * Merge to test (#135) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI 
Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: 
mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira * Update Finetuner hyperparameters (#125) * Update Finetuner hyperparameters * Change hyperparameters error message * Add new LLMs finetuner models (mistral and solar) (#128) * Enabling dataset ID and model ID as parameters for finetuner creation (#131) Co-authored-by: Thiago Castro Ferreira * Fix supplier representation of a model (#132) * Fix supplier representation of a model * Fixing parameter typing --------- Co-authored-by: Thiago Castro Ferreira * Fixing indentation in documentation sample code (#134) Co-authored-by: Thiago Castro Ferreira --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain 
<131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira * Update FineTune unit and functional tests (#136) * Merge dev to test (#137) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI 
Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo 
ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira * Update Finetuner hyperparameters (#125) * Update Finetuner hyperparameters * Change hyperparameters error message * Add new LLMs finetuner models (mistral and solar) (#128) * Enabling dataset ID and model ID as parameters for finetuner creation (#131) Co-authored-by: Thiago Castro Ferreira * Fix supplier representation of a model (#132) * Fix supplier representation of a model * Fixing parameter typing --------- Co-authored-by: Thiago Castro Ferreira * Fixing indentation in documentation sample code (#134) Co-authored-by: Thiago Castro Ferreira * Update FineTune unit and functional tests (#136) --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> 
Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira * Click fix (#140) * Merge to prod (#119) * Merge dev to test (#107) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Development to Test (#109) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> * Merge to test (#111) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing 
out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#118) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: 
mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com… * ENG-1851: Update test and main workflows to run unit tests (#472) * Update test and main workflows to run iunit tests * Update to install all test packages listed in pyproject.toml * Add 
pytest-mock to list of test dependencies * Update workflow names * Add continue on error to each test and fix typos * Fix team agent parametrized functional test (#470) * ENG-1852: Hotfix, added fail-fast option (#488) (#489) Co-authored-by: kadirpekel * Fix: file not deleted in test_sql_tool_with_csv test * Fix agent and team agent parametrized functional test * Update to use draft agent --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Signed-off-by: Michael Lam Signed-off-by: root Co-authored-by: kadirpekel Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: Ahmet Gündüz Co-authored-by: ikxplain <88332269+ikxplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira Co-authored-by: Shreyas Sharma <85180538+shreyasXplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira Co-authored-by: Hadi Nasrallah <87204330+hadi-aix@users.noreply.github.com> Co-authored-by: kadir pekel Co-authored-by: root Co-authored-by: Zaina Abu Shaban Co-authored-by: xainaz Co-authored-by: xainaz Co-authored-by: Lucas Pavanelli Co-authored-by: ikxplain Co-authored-by: OsujiCC Co-authored-by: Yunsu Kim Co-authored-by: Yunsu Kim Co-authored-by: Muhammad-Elmallah <145364766+Muhammad-Elmallah@users.noreply.github.com> Co-authored-by: ahmetgunduz --- tests/functional/agent/agent_functional_test.py | 3 ++- tests/functional/team_agent/team_agent_functional_test.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 2313234e..0026d65a 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -300,7 +300,8 @@ def 
test_update_tools_of_agent(run_input_map, delete_agents_and_team_agents, Age }, ], ) -def test_specific_model_parameters_e2e(tool_config): +def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agents): + assert delete_agents_and_team_agents """Test end-to-end agent execution with specific model parameters""" # Create tool based on config if tool_config["type"] == "search": diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index 3cabad8c..cb5f80a9 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -301,6 +301,7 @@ def test_team_agent_with_parameterized_agents(run_input_map, delete_agents_and_t llm_id=run_input_map["llm_id"], tools=[search_tool], ) + search_agent.deploy() # Create second agent with translation tool translation_function = Function.TRANSLATION @@ -318,7 +319,7 @@ def test_team_agent_with_parameterized_agents(run_input_map, delete_agents_and_t llm_id=run_input_map["llm_id"], tools=[translation_tool], ) - + translation_agent.deploy() team_agent = create_team_agent( TeamAgentFactory, [search_agent, translation_agent], run_input_map, use_mentalist=True, use_inspector=True ) From da7ec38ae161f198b2ac273aa5d92f734e5e0d95 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <50206820+basitanees@users.noreply.github.com> Date: Wed, 23 Apr 2025 12:24:26 +0300 Subject: [PATCH 08/62] ENG-1789: Add multiple index backbones support (#443) * add vectara support * add GraphRAG support * restructure and use Generics * convert to abstract class, fix param name for llm * update functional test for new param style * collection name param name changed * add backwards compatibility * Adding a warning for deprecation and some tests * add option for embedding size in embedding params * add ZeroEntropy support --------- Co-authored-by: Thiago Castro Ferreira --- aixplain/enums/__init__.py | 1 
+ aixplain/enums/embedding_model.py | 2 +- aixplain/enums/index_stores.py | 14 +++ aixplain/factories/index_factory.py | 50 --------- aixplain/factories/index_factory/__init__.py | 102 +++++++++++++++++++ aixplain/factories/index_factory/utils.py | 66 ++++++++++++ tests/functional/model/run_model_test.py | 92 ++++++++++------- tests/unit/index_model_test.py | 35 ++++++- 8 files changed, 273 insertions(+), 89 deletions(-) create mode 100644 aixplain/enums/index_stores.py delete mode 100644 aixplain/factories/index_factory.py create mode 100644 aixplain/factories/index_factory/__init__.py create mode 100644 aixplain/factories/index_factory/utils.py diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 4f0364e1..e80c03c6 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -18,3 +18,4 @@ from .database_source import DatabaseSourceType from .embedding_model import EmbeddingModel from .asset_status import AssetStatus +from .index_stores import IndexStores diff --git a/aixplain/enums/embedding_model.py b/aixplain/enums/embedding_model.py index dbd6b9c1..c52387b2 100644 --- a/aixplain/enums/embedding_model.py +++ b/aixplain/enums/embedding_model.py @@ -20,7 +20,7 @@ from enum import Enum -class EmbeddingModel(Enum): +class EmbeddingModel(str, Enum): SNOWFLAKE_ARCTIC_EMBED_M_LONG = "6658d40729985c2cf72f42ec" OPENAI_ADA002 = "6734c55df127847059324d9e" SNOWFLAKE_ARCTIC_EMBED_L_V2_0 = "678a4f8547f687504744960a" diff --git a/aixplain/enums/index_stores.py b/aixplain/enums/index_stores.py new file mode 100644 index 00000000..7cdfabb3 --- /dev/null +++ b/aixplain/enums/index_stores.py @@ -0,0 +1,14 @@ +from enum import Enum + + +class IndexStores(Enum): + AIR = {"name": "air", "id": "66eae6656eb56311f2595011"} + VECTARA = {"name": "vectara", "id": "655e20f46eb563062a1aa301"} + GRAPHRAG = {"name": "graphrag", "id": "67dd6d487cbf0a57cf4b72f3"} + ZERO_ENTROPY = {"name": "zeroentropy", "id": "6807949168e47e7844c1f0c5"} + + def 
__str__(self): + return self.value["name"] + + def get_model_id(self): + return self.value["id"] diff --git a/aixplain/factories/index_factory.py b/aixplain/factories/index_factory.py deleted file mode 100644 index 7588e583..00000000 --- a/aixplain/factories/index_factory.py +++ /dev/null @@ -1,50 +0,0 @@ -from aixplain.modules.model.index_model import IndexModel -from aixplain.factories import ModelFactory -from aixplain.enums import EmbeddingModel, Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier -from typing import Optional, Text, Union, List, Tuple - -AIR_MODEL_ID = "66eae6656eb56311f2595011" - - -class IndexFactory(ModelFactory): - @classmethod - def create( - cls, name: Text, description: Text, embedding_model: EmbeddingModel = EmbeddingModel.OPENAI_ADA002 - ) -> IndexModel: - """Create a new index collection""" - model = cls.get(AIR_MODEL_ID) - - data = {"data": name, "description": description, "model": embedding_model.value} - response = model.run(data=data) - if response.status == ResponseStatus.SUCCESS: - model_id = response.data - model = cls.get(model_id) - return model - - error_message = f"Index Factory Exception: {response.error_message}" - if error_message == "": - error_message = "Index Factory Exception: An error occurred while creating the index collection." 
- raise Exception(error_message) - - @classmethod - def list( - cls, - query: Optional[Text] = "", - suppliers: Optional[Union[Supplier, List[Supplier]]] = None, - ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, - sort_by: Optional[SortBy] = None, - sort_order: SortOrder = SortOrder.ASCENDING, - page_number: int = 0, - page_size: int = 20, - ) -> List[IndexModel]: - """List all indexes""" - return super().list( - function=Function.SEARCH, - query=query, - suppliers=suppliers, - ownership=ownership, - sort_by=sort_by, - sort_order=sort_order, - page_number=page_number, - page_size=page_size, - ) diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py new file mode 100644 index 00000000..74e3ee1a --- /dev/null +++ b/aixplain/factories/index_factory/__init__.py @@ -0,0 +1,102 @@ +__author__ = "aiXplain" + +""" +Copyright 2022 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Author: Abdul Basit Anees, Thiago Castro Ferreira, Zaina Abushaban +Date: December 26th 2024 +Description: + Index Factory Class +""" + +from aixplain.modules.model.index_model import IndexModel +from aixplain.factories import ModelFactory +from aixplain.enums import Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier, IndexStores, EmbeddingModel +from typing import Text, Union, List, Tuple, Optional, TypeVar, Generic +from aixplain.factories.index_factory.utils import BaseIndexParams + +T = TypeVar("T", bound=BaseIndexParams) + + +class IndexFactory(ModelFactory, Generic[T]): + @classmethod + def create( + cls, + name: Optional[Text] = None, + description: Optional[Text] = None, + embedding_model: EmbeddingModel = EmbeddingModel.OPENAI_ADA002, + params: Optional[T] = None, + **kwargs, + ) -> IndexModel: + """Create a new index collection""" + import warnings + + warnings.warn( + "name, description, and embedding_model will be deprecated in the next release. Please use params instead.", + DeprecationWarning, + ) + + model_id = IndexStores.AIR.get_model_id() + if params is not None: + model_id = params.id + data = params.to_dict() + assert ( + name is None and description is None + ), "Index Factory Exception: name, description, and embedding_model must not be provided when params is provided" + else: + assert ( + name is not None and description is not None and embedding_model is not None + ), "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + data = { + "data": name, + "description": description, + "model": embedding_model, + } + + model = cls.get(model_id) + + response = model.run(data=data) + if response.status == ResponseStatus.SUCCESS: + model_id = response.data + model = cls.get(model_id) + return model + + error_message = f"Index Factory Exception: {response.error_message}" + if response.error_message.strip() == "": + error_message = "Index Factory Exception: An error occurred while 
creating the index collection." + raise Exception(error_message) + + @classmethod + def list( + cls, + query: Optional[Text] = "", + suppliers: Optional[Union[Supplier, List[Supplier]]] = None, + ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, + sort_by: Optional[SortBy] = None, + sort_order: SortOrder = SortOrder.ASCENDING, + page_number: int = 0, + page_size: int = 20, + ) -> List[IndexModel]: + """List all indexes""" + return super().list( + function=Function.SEARCH, + query=query, + suppliers=suppliers, + ownership=ownership, + sort_by=sort_by, + sort_order=sort_order, + page_number=page_number, + page_size=page_size, + ) diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py new file mode 100644 index 00000000..83efd469 --- /dev/null +++ b/aixplain/factories/index_factory/utils.py @@ -0,0 +1,66 @@ +from pydantic import BaseModel, ConfigDict +from typing import Text, Optional, ClassVar, Dict +from aixplain.enums import IndexStores, EmbeddingModel +from abc import ABC, abstractmethod + + +class BaseIndexParams(BaseModel, ABC): + model_config = ConfigDict(use_enum_values=True) + name: Text + description: Optional[Text] = "" + + def to_dict(self): + data = self.model_dump(exclude_none=True) + data["data"] = data.pop("name") + return data + + @property + @abstractmethod + def id(self) -> str: + """Abstract property that must be implemented in subclasses.""" + pass + + +class BaseIndexParamsWithEmbeddingModel(BaseIndexParams, ABC): + embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 + embedding_size: Optional[int] = None + + def to_dict(self): + data = super().to_dict() + data["model"] = data.pop("embedding_model") + if data.get("embedding_size"): + data["additional_params"] = {"embedding_size": data.pop("embedding_size")} + return data + + +class VectaraParams(BaseIndexParams): + _id: ClassVar[str] = IndexStores.VECTARA.get_model_id() + + @property + def id(self) -> str: + 
return self._id + + +class ZeroEntropyParams(BaseIndexParams): + _id: ClassVar[str] = IndexStores.ZERO_ENTROPY.get_model_id() + + @property + def id(self) -> str: + return self._id + + +class AirParams(BaseIndexParamsWithEmbeddingModel): + _id: ClassVar[str] = IndexStores.AIR.get_model_id() + + @property + def id(self) -> str: + return self._id + + +class GraphRAGParams(BaseIndexParamsWithEmbeddingModel): + _id: ClassVar[str] = IndexStores.GRAPHRAG.get_model_id() + llm: Optional[Text] = None + + @property + def id(self) -> str: + return self._id diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 72590314..421feb89 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -8,6 +8,7 @@ from aixplain.modules import LLM from datetime import datetime, timedelta, timezone from pathlib import Path +from aixplain.factories.index_factory.utils import AirParams, VectaraParams, GraphRAGParams, ZeroEntropyParams def pytest_generate_tests(metafunc): @@ -16,7 +17,7 @@ def pytest_generate_tests(metafunc): models = ModelFactory.list(function=Function.TEXT_GENERATION)["results"] predefined_models = [] - for predefined_model in ["Groq Llama 3 70B", "Chat GPT 3.5", "GPT-4o"]: + for predefined_model in ["Groq Llama 3 70B", "GPT-4o"]: predefined_models.extend( [ m @@ -57,43 +58,27 @@ def test_run_async(): assert "teste" in response["data"].lower() -@pytest.mark.parametrize( - "embedding_model", - [ - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, id="Snowflake Arctic Embed M Long"), - pytest.param(EmbeddingModel.OPENAI_ADA002, id="OpenAI Ada 002"), - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, id="Snowflake Arctic Embed L v2.0"), - pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, id="Multilingual E5 Large"), - pytest.param(EmbeddingModel.BGE_M3, id="BGE M3"), - ], -) -def test_index_model(embedding_model): - from uuid import uuid4 +def 
run_index_model(index_model): from aixplain.modules.model.record import Record - from aixplain.factories import IndexFactory - for index in IndexFactory.list()["results"]: - index.delete() - - index_model = IndexFactory.create(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) - index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) - response = index_model.search("Hello") + index_model.upsert([Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})]) + response = index_model.search("Berlin") assert str(response.status) == "SUCCESS" - assert "world" in response.data.lower() + assert "germany" in response.data.lower() assert index_model.count() == 1 - index_model.upsert([Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={})]) - response = index_model.search("aiXplain") + index_model.upsert([Record(value="Ankara is the capital of Turkey.", value_type="text", uri="", id="1", attributes={})]) + response = index_model.search("Ankara") assert str(response.status) == "SUCCESS" - assert "aixplain" in response.data.lower() + assert "turkey" in response.data.lower() assert index_model.count() == 1 - index_model.upsert([Record(value="The world is great", value_type="text", uri="", id="2", attributes={})]) + index_model.upsert([Record(value="London is the capital of England.", value_type="text", uri="", id="2", attributes={})]) assert index_model.count() == 2 response = index_model.get_document("1") assert str(response.status) == "SUCCESS" - assert response.data == "Hello, aiXplain!" + assert response.data == "Berlin is the capital of Germany." 
assert index_model.count() == 2 response = index_model.delete_document("1") @@ -104,17 +89,43 @@ def test_index_model(embedding_model): @pytest.mark.parametrize( - "embedding_model", + "embedding_model,supplier_params", + [ + pytest.param(None, VectaraParams, id="VECTARA"), + pytest.param(None, ZeroEntropyParams, id="ZERO_ENTROPY"), + pytest.param(EmbeddingModel.OPENAI_ADA002, GraphRAGParams, id="GRAPHRAG"), + pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="AIR - OpenAI Ada 002"), + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, AirParams, id="AIR - Snowflake Arctic Embed M Long"), + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, AirParams, id="AIR - Snowflake Arctic Embed L v2.0"), + pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), + pytest.param(EmbeddingModel.BGE_M3, AirParams, id="AIR - BGE M3"), + ], +) +def test_index_model(embedding_model, supplier_params): + from uuid import uuid4 + from aixplain.factories import IndexFactory + + params = supplier_params(name=str(uuid4()), description=str(uuid4())) + if embedding_model is not None: + params = supplier_params(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) + + index_model = IndexFactory.create(params=params) + run_index_model(index_model) + + +@pytest.mark.parametrize( + "embedding_model,supplier_params", [ - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, id="Snowflake Arctic Embed M Long"), - pytest.param(EmbeddingModel.OPENAI_ADA002, id="OpenAI Ada 002"), - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, id="Snowflake Arctic Embed L v2.0"), - pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, id="Jina Clip v2 Multimodal"), - pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, id="Multilingual E5 Large"), - pytest.param(EmbeddingModel.BGE_M3, id="BGE M3"), + pytest.param(None, VectaraParams, id="VECTARA"), + pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="OpenAI 
Ada 002"), + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, AirParams, id="Snowflake Arctic Embed M Long"), + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, AirParams, id="Snowflake Arctic Embed L v2.0"), + pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, AirParams, id="Jina Clip v2 Multimodal"), + pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), + pytest.param(EmbeddingModel.BGE_M3, AirParams, id="BGE M3"), ], ) -def test_index_model_with_filter(embedding_model): +def test_index_model_with_filter(embedding_model, supplier_params): from uuid import uuid4 from aixplain.modules.model.record import Record from aixplain.factories import IndexFactory @@ -123,7 +134,11 @@ def test_index_model_with_filter(embedding_model): for index in IndexFactory.list()["results"]: index.delete() - index_model = IndexFactory.create(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) + params = supplier_params(name=str(uuid4()), description=str(uuid4())) + if embedding_model is not None: + params = supplier_params(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) + + index_model = IndexFactory.create(params=params) index_model.upsert([Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})]) index_model.upsert( [Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})] @@ -157,18 +172,21 @@ def test_llm_run_with_file(): assert "🤖" in response["data"], "Robot emoji should be present in the response" -def test_index_model_with_image(): +def test_index_model_air_with_image(): from aixplain.factories import IndexFactory from aixplain.modules.model.record import Record from uuid import uuid4 + from aixplain.factories.index_factory.utils import AirParams for index in IndexFactory.list()["results"]: index.delete() - index_model = IndexFactory.create( + params = AirParams( 
name=f"Image Index {uuid4()}", description="Index for images", embedding_model=EmbeddingModel.JINA_CLIP_V2_MULTIMODAL ) + index_model = IndexFactory.create(params=params) + records = [] # Building image records.append( diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 6f826f36..9929ccf4 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -1,5 +1,6 @@ import requests_mock from aixplain.enums import DataType, Function, ResponseStatus, StorageType, EmbeddingModel +from aixplain.factories.index_factory import IndexFactory from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse from aixplain.modules.model.index_model import IndexModel @@ -123,7 +124,10 @@ def test_count_success(): def test_get_document_success(): - mock_response = {"status": "SUCCESS", "data": {"value": "Sample document content 1", "value_type": "text", "id": 0, "uri": "", "attributes": {}}} + mock_response = { + "status": "SUCCESS", + "data": {"value": "Sample document content 1", "value_type": "text", "id": 0, "uri": "", "attributes": {}}, + } mock_documents = [Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={})] with requests_mock.Mocker() as mock: mock.post(execute_url, json=mock_response, status_code=200) @@ -134,6 +138,7 @@ def test_get_document_success(): assert isinstance(response, ModelResponse) assert response.status == ResponseStatus.SUCCESS + def test_delete_document_success(): mock_response = {"status": "SUCCESS"} mock_documents = [Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={})] @@ -209,3 +214,31 @@ def test_index_filter(): assert filter.field == "category" assert filter.value == "world" assert filter.operator == IndexFilterOperator.EQUALS + + +def test_index_factory_create_failure(): + from aixplain.factories.index_factory.utils import AirParams + + with pytest.raises(Exception) as e: + 
IndexFactory.create( + name="test", + description="test", + embedding_model=EmbeddingModel.OPENAI_ADA002, + params=AirParams(name="test", description="test", embedding_model=EmbeddingModel.OPENAI_ADA002), + ) + assert ( + str(e.value) + == "Index Factory Exception: name, description, and embedding_model must not be provided when params is provided" + ) + + with pytest.raises(Exception) as e: + IndexFactory.create(description="test") + assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + + with pytest.raises(Exception) as e: + IndexFactory.create(name="test") + assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + + with pytest.raises(Exception) as e: + IndexFactory.create(name="test", description="test", embedding_model=None) + assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" From d3607804cafbbbba654c50e5876cc11bf71ed7aa Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Thu, 24 Apr 2025 14:45:32 -0300 Subject: [PATCH 09/62] ENG-2049: rename air functions (#500) --- aixplain/modules/model/index_model.py | 24 ++++++++++++------------ tests/functional/model/run_model_test.py | 8 ++++---- tests/unit/index_model_test.py | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index efd35b51..8a5b18fc 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -149,35 +149,35 @@ def count(self) -> float: if response.status == "SUCCESS": return int(response.data) raise Exception(f"Failed to count documents: {response.error_message}") - - def get_document(self, document_id: Text) -> ModelResponse: + + def get_record(self, record_id: Text) -> ModelResponse: """ Get 
a document from the index. Args: - document_id (Text): ID of the document to retrieve. + record_id (Text): ID of the document to retrieve. Returns: ModelResponse: Response containing the retrieved document data. Raises: Exception: If document retrieval fails. - + Example: - >>> index_model.get_document("123") + >>> index_model.get_record("123") """ - data = {"action": "get_document", "data": document_id} + data = {"action": "get_document", "data": record_id} response = self.run(data=data) if response.status == "SUCCESS": return response - raise Exception(f"Failed to get document: {response.error_message}") + raise Exception(f"Failed to get record: {response.error_message}") - def delete_document(self, document_id: Text) -> ModelResponse: + def delete_record(self, record_id: Text) -> ModelResponse: """ Delete a document from the index. Args: - document_id (Text): ID of the document to delete. + record_id (Text): ID of the document to delete. Returns: ModelResponse: Response containing the deleted document data. @@ -186,10 +186,10 @@ def delete_document(self, document_id: Text) -> ModelResponse: Exception: If document deletion fails. 
Example: - >>> index_model.delete_document("123") + >>> index_model.delete_record("123") """ - data = {"action": "delete", "data": document_id} + data = {"action": "delete", "data": record_id} response = self.run(data=data) if response.status == "SUCCESS": return response - raise Exception(f"Failed to delete document: {response.error_message}") \ No newline at end of file + raise Exception(f"Failed to delete record: {response.error_message}") diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 421feb89..916d6077 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -76,12 +76,12 @@ def run_index_model(index_model): index_model.upsert([Record(value="London is the capital of England.", value_type="text", uri="", id="2", attributes={})]) assert index_model.count() == 2 - response = index_model.get_document("1") + response = index_model.get_record("1") assert str(response.status) == "SUCCESS" - assert response.data == "Berlin is the capital of Germany." + assert response.data == "Ankara is the capital of Turkey." 
assert index_model.count() == 2 - response = index_model.delete_document("1") + response = index_model.delete_record("1") assert str(response.status) == "SUCCESS" assert index_model.count() == 1 @@ -229,7 +229,7 @@ def test_index_model_air_with_image(): assert index_model.count() == 4 - response = index_model.get_document("2") + response = index_model.get_record("2") assert str(response.status) == "SUCCESS" second_record = response.details[0]["metadata"]["uri"] assert "hurricane" in second_record.lower() diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 9929ccf4..4b265dd3 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -133,7 +133,7 @@ def test_get_document_success(): mock.post(execute_url, json=mock_response, status_code=200) index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) index_model.upsert(mock_documents) - response = index_model.get_document(0) + response = index_model.get_record(0) assert isinstance(response, ModelResponse) assert response.status == ResponseStatus.SUCCESS @@ -147,7 +147,7 @@ def test_delete_document_success(): mock.post(execute_url, json=mock_response, status_code=200) index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) index_model.upsert(mock_documents) - response = index_model.delete_document("0") + response = index_model.delete_record("0") assert isinstance(response, ModelResponse) assert response.status == ResponseStatus.SUCCESS From 24e88da515d0f0c4739742d676cca83dc58a9a30 Mon Sep 17 00:00:00 2001 From: OsujiCC Date: Thu, 24 Apr 2025 15:24:46 -0400 Subject: [PATCH 10/62] ENG 1924: aixplain sdk new test cases for agents using utility and pipeline tools. 
(#490) * new functional test cases * Fixing functional tests for agent with utility and pipeline tools --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira --- .../functional/agent/agent_functional_test.py | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 0026d65a..0eecfa3f 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -472,6 +472,7 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): os.remove("test.db") agent.delete() + @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_instructions(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents @@ -496,3 +497,99 @@ def test_instructions(delete_agents_and_team_agents, AgentFactory): assert response["data"]["session_id"] is not None assert response["data"]["output"] is not None assert "aixplain" in response["data"]["output"].lower() + assert "eve" in response["data"]["output"].lower() + + import os + + # Cleanup + os.remove("test.csv") + os.remove("test.db") + agent.delete() + + +@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) +def test_agent_with_utility_tool(delete_agents_and_team_agents, AgentFactory): + from aixplain.enums import DataType + from aixplain.modules.model.utility_model import utility_tool, UtilityModelInput + + assert delete_agents_and_team_agents + + @utility_tool( + name="vowel_remover", + description="Remove all vowels from a given string", + inputs=[UtilityModelInput(name="text", description="String from which to remove vowels", type=DataType.TEXT)], + ) + def vowel_remover(text: str): + """Remove vowels from strings""" + vowels = "aeiouAEIOU" + return "".join([char for char in text if char not in vowels]) + + vowel_remover_ = 
ModelFactory.create_utility_model(name="vowel_remover", code=vowel_remover) + + @utility_tool( + name="concat_strings", + description="Concatenate two strings into one", + inputs=[ + UtilityModelInput(name="string1", description="First string to concatenate", type=DataType.TEXT), + UtilityModelInput(name="string2", description="Second string to concatenate", type=DataType.TEXT), + ], + ) + def concat_strings(string1: str, string2: str): + return string1 + string2 + + concat_strings_ = ModelFactory.create_utility_model(name="concat_strings", code=concat_strings) + + instructions = """You are a text processing agent equipped with two specialized tools: a Vowel Remover and a String Concatenator. Your task involves processing input text in two ways. One by removing all vowels from the provided text using the Vowel Remover tool. Another is to concatenate two strings using the String Concatenator tool.""" + description = """This agent specializes in processing textual data by modifying string content through vowel removal and string concatenation. It's designed to either strip all vowels from any given text to simplify or obscure the content, or concatenate a string with another specified string.""" + + agent = AgentFactory.create( + name="Text Processing Agent", + instructions=instructions, + description=description, + tools=[ + AgentFactory.create_model_tool(model=vowel_remover_.id), + AgentFactory.create_model_tool(model=concat_strings_.id), + ], + llm_id="6646261c6eb563165658bbb1", + ) + + result_vowel = agent.run("Remove all the vowels in this string: 'Hello'") + result_concat_text = agent.run("Concat these strings: String1 = 'Hello'; string2= 'World!'.") + + assert "hll" in result_vowel["data"]["output"].lower() + assert "helloworld!" 
in result_concat_text["data"]["output"].lower() + + +@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) +def test_agent_with_pipeline_tool(delete_agents_and_team_agents, AgentFactory): + from aixplain.factories.pipeline_factory import PipelineFactory + + assert delete_agents_and_team_agents + + pipeline = PipelineFactory.init("Hello Pipeline") + input_node = pipeline.input() + input_node.label = "TextInput" + middle_node = pipeline.asset(asset_id="6646261c6eb563165658bbb1") + middle_node.inputs.prompt.value = "Respond with 'Hello' regardless of the input text: " + input_node.link(middle_node, "input", "text") + middle_node.use_output("data") + pipeline.save() + pipeline.deploy() + + pipeline_agent = AgentFactory.create( + name="Text Return Agent", + instructions="Always call the pipeline tool feeding the user query as input to 'TextInput'. Return the output of the pipeline as the final response.", + description="Return the text given.", + tools=[ + AgentFactory.create_pipeline_tool( + pipeline=pipeline.id, description="You are a tool that responds users query with only 'Hello'." 
+ ), + ], + llm_id="6646261c6eb563165658bbb1", + ) + + answer = pipeline_agent.run("Who is the president of USA?") + pipeline.delete() + + assert "hello" in answer["data"]["output"].lower() + assert "hello pipeline" in answer["data"]["intermediate_steps"][0]["tool_steps"][0]["tool"].lower() From 044d8d8c5e7e8a95357381f8803e9b9416315d9d Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <50206820+basitanees@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:51:08 +0300 Subject: [PATCH 11/62] add pydantic requirement (#502) --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 422ea158..1cf10327 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,8 @@ dependencies = [ "PyYAML>=6.0.1", "dataclasses-json>=0.5.2", "Jinja2==3.1.6", - "sentry-sdk>=1.0.0" + "sentry-sdk>=1.0.0", + "pydantic>=2.10.6" ] [project.urls] From 4efdbc88d60a20f507c09befdcaa13030717a6fa Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Fri, 25 Apr 2025 10:21:21 -0300 Subject: [PATCH 12/62] ENG-1978: Adding instructions to teams (#485) * Adding instructions to teams * Update docstring --- .../factories/team_agent_factory/__init__.py | 5 ++- .../factories/team_agent_factory/utils.py | 1 + aixplain/modules/team_agent/__init__.py | 7 +++- .../team_agent/team_agent_functional_test.py | 42 +++++++++++++++++++ 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index c2611d6c..892d7ded 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -50,6 +50,7 @@ def create( num_inspectors: int = 1, inspector_targets: List[Union[InspectorTarget, Text]] = [InspectorTarget.STEPS], use_mentalist_and_inspector: bool = False, # TODO: remove this + instructions: Optional[Text] = None, ) -> TeamAgent: 
"""Create a new team agent in the platform. @@ -57,7 +58,7 @@ def create( name: The name of the team agent. agents: A list of agents to be added to the team. llm_id: The ID of the LLM to be used for the team agent. - description: The description of the team agent. + description: The description of the team agent to be displayed in the aiXplain platform. api_key: The API key to be used for the team agent. supplier: The supplier of the team agent. version: The version of the team agent. @@ -66,6 +67,7 @@ def create( num_inspectors: The number of inspectors to be used for each inspection. inspector_targets: Which stages to be inspected during an execution of the team agent. (steps, output) use_mentalist_and_inspector: Whether to use the mentalist and inspector agents. (legacy) + instructions: The instructions to guide the team agent (i.e. appended in the prompt of the team agent). Returns: A new team agent instance. @@ -139,6 +141,7 @@ def create( "supplier": supplier, "version": version, "status": "draft", + "role": instructions, } team_agent = build_team_agent(payload=payload, agents=agent_list, api_key=api_key) diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index debadff8..48268968 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -38,6 +38,7 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = name=payload.get("name", ""), agents=payload_agents, description=payload.get("description", ""), + instructions=payload.get("role", None), supplier=payload.get("teamId", None), version=payload.get("version", None), cost=payload.get("cost", None), diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index a913f223..3014adb1 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -88,6 +88,7 @@ def __init__( max_inspectors: int = 1, 
inspector_targets: List[InspectorTarget] = [InspectorTarget.STEPS], status: AssetStatus = AssetStatus.DRAFT, + instructions: Optional[Text] = None, **additional_info, ) -> None: """Create a FineTune with the necessary information. @@ -96,7 +97,7 @@ def __init__( id (Text): ID of the Team Agent name (Text): Name of the Team Agent agents (List[Agent]): List of agents that the Team Agent uses. - description (Text, optional): description of the Team Agent. Defaults to "". + description (Text, optional): The description of the team agent to be displayed in the aiXplain platform. Defaults to "". llm_id (Text, optional): large language model. Defaults to GPT-4o (6646261c6eb563165658bbb1). supplier (Text): Supplier of the Team Agent. version (Text): Version of the Team Agent. @@ -104,6 +105,7 @@ def __init__( api_key (str): The TEAM API key used for authentication. cost (Dict, optional): model price. Defaults to None. use_mentalist_and_inspector (bool): Use Mentalist and Inspector tools. Defaults to True. + instructions (Text, optional): The instructions to guide the team agent (i.e. appended in the prompt of the team agent). Defaults to None. 
""" super().__init__(id, name, description, api_key, supplier, version, cost=cost) self.additional_info = additional_info @@ -113,7 +115,7 @@ def __init__( self.use_inspector = use_inspector self.max_inspectors = max_inspectors self.inspector_targets = inspector_targets - + self.instructions = instructions if isinstance(status, str): try: status = AssetStatus(status) @@ -345,6 +347,7 @@ def to_dict(self) -> Dict: "supplier": self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier, "version": self.version, "status": self.status.value, + "role": self.instructions, } def _validate(self) -> None: diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index cb5f80a9..759f3920 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -592,3 +592,45 @@ def test_team_agent_with_multiple_inspectors(run_input_map, delete_agents_and_te verify_response_generator(steps, has_output_target=False) team_agent.delete() + + +def test_team_agent_with_instructions(delete_agents_and_team_agents): + assert delete_agents_and_team_agents + + agent_1 = AgentFactory.create( + name="Agent 1", + description="Translation agent", + tools=[AgentFactory.create_model_tool(function=Function.TRANSLATION, supplier=Supplier.MICROSOFT)], + llm_id="6646261c6eb563165658bbb1", + ) + + agent_2 = AgentFactory.create( + name="Agent 2", + description="Translation agent", + tools=[AgentFactory.create_model_tool(function=Function.TRANSLATION, supplier=Supplier.GOOGLE)], + llm_id="6646261c6eb563165658bbb1", + ) + + team_agent = TeamAgentFactory.create( + name="Team Agent", + agents=[agent_1, agent_2], + description="Team agent", + instructions="Use only 'Agent 2' to solve the tasks.", + llm_id="6646261c6eb563165658bbb1", + use_mentalist=True, + use_inspector=False, + ) + + response = team_agent.run(data="Translate 'cat' to 
Portuguese") + assert response.status == "SUCCESS" + assert "gato" in response.data["output"] + + mentalist_steps = eval(response.data["intermediate_steps"][0]["output"]) + + called_agents = set([step["worker"] for step in mentalist_steps]) + assert len(called_agents) == 1 + assert "Agent 2" in called_agents + + team_agent.delete() + agent_1.delete() + agent_2.delete() From ab2fcc544c8309d0a6ef9d8a008fa087d0462f7b Mon Sep 17 00:00:00 2001 From: kadirpekel Date: Fri, 25 Apr 2025 16:59:01 +0200 Subject: [PATCH 13/62] BUG-504: Merged paramMappings for the same link vectors (#499) --- .../modules/pipeline/designer/pipeline.py | 17 ++++++- tests/unit/designer_unit_test.py | 44 +++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/aixplain/modules/pipeline/designer/pipeline.py b/aixplain/modules/pipeline/designer/pipeline.py index 80248909..0985026d 100644 --- a/aixplain/modules/pipeline/designer/pipeline.py +++ b/aixplain/modules/pipeline/designer/pipeline.py @@ -68,9 +68,22 @@ def serialize(self) -> dict: :return: the pipeline as a dictionary """ + nodes = [node.serialize() for node in self.nodes] + links = [link.serialize() for link in self.links] + + # merge the params for links using the key `from_node` and `to_node` + merged_links = {} + for link in links: + key = (link['from'], link['to']) + if key not in merged_links: + merged_links[key] = link + else: + existing_link = merged_links[key] + existing_link['paramMapping'] += link['paramMapping'] + return { - "nodes": [node.serialize() for node in self.nodes], - "links": [link.serialize() for link in self.links], + "nodes": nodes, + "links": list(merged_links.values()), } def validate_nodes(self): diff --git a/tests/unit/designer_unit_test.py b/tests/unit/designer_unit_test.py index aaf45026..143d5485 100644 --- a/tests/unit/designer_unit_test.py +++ b/tests/unit/designer_unit_test.py @@ -815,3 +815,47 @@ def test_pipeline_special_prompt_validation(): def test_find_prompt_params(input, 
expected): print(input, expected) assert find_prompt_params(input) == expected + + +def test_pipeline_serialize(): + pipeline = DesignerPipeline() + + # Create nodes + class AssetNode(Node, LinkableMixin): + type: NodeType = NodeType.ASSET + + node1 = AssetNode(pipeline=pipeline) + node2 = AssetNode(pipeline=pipeline) + + # Create multiple parameters for each node + node1.outputs.create_param("output1", DataType.TEXT, "foo1") + node1.outputs.create_param("output2", DataType.TEXT, "foo2") + node2.inputs.create_param("input1", DataType.TEXT, "bar1") + node2.inputs.create_param("input2", DataType.TEXT, "bar2") + + # Create multiple links between the same nodes + link1 = Link( + from_node=node1, + to_node=node2, + from_param="output1", + to_param="input1", + ) + pipeline.add_link(link1) + + link2 = Link( + from_node=node1, + to_node=node2, + from_param="output2", + to_param="input2" + ) + pipeline.add_link(link2) + + serialized = pipeline.serialize() + + # Verify pipeline serialization + assert len(serialized["links"]) == 1 + assert len(serialized["links"][0]["paramMapping"]) == 2 + assert {"from": "output1", "to": "input1"} in serialized["links"][0]["paramMapping"] + assert {"from": "output2", "to": "input2"} in serialized["links"][0]["paramMapping"] + assert serialized["nodes"][0] == node1.serialize() + assert serialized["nodes"][1] == node2.serialize() From 268cc1a69ee2633664cf7f95c2f0bd2d9e0b49a2 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Mon, 28 Apr 2025 16:44:45 -0300 Subject: [PATCH 14/62] ENG-1836: Set name of tools on the SDK (#501) * Set model tool name * Setting tool names * Validate custom tools * Addressing comments * Name on SQL tool * Default name for SQL tool --- aixplain/factories/agent_factory/__init__.py | 16 ++- aixplain/factories/agent_factory/utils.py | 8 +- aixplain/modules/agent/__init__.py | 12 +- .../agent/tool/custom_python_code_tool.py | 12 +- 
aixplain/modules/agent/tool/model_tool.py | 31 +++++ aixplain/modules/agent/tool/pipeline_tool.py | 22 +++- .../agent/tool/python_interpreter_tool.py | 4 + aixplain/modules/agent/tool/sql_tool.py | 4 +- aixplain/v2/agent.py | 17 ++- .../functional/agent/agent_functional_test.py | 21 ++-- tests/unit/agent/agent_factory_utils_test.py | 44 +++++-- tests/unit/agent/agent_test.py | 112 ++++++++++++------ tests/unit/agent/model_tool_test.py | 13 +- tests/unit/agent/sql_tool_test.py | 39 +++--- 14 files changed, 257 insertions(+), 98 deletions(-) diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index bad94fb3..040bcd71 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -164,6 +164,7 @@ def create_model_tool( supplier: Optional[Union[Supplier, Text]] = None, description: Text = "", parameters: Optional[Dict] = None, + name: Optional[Text] = None, ) -> ModelTool: """Create a new model tool.""" if function is not None and isinstance(function, str): @@ -179,9 +180,11 @@ def create_model_tool( return ModelTool(function=function, supplier=supplier, model=model, description=description, parameters=parameters) @classmethod - def create_pipeline_tool(cls, description: Text, pipeline: Union[Pipeline, Text]) -> PipelineTool: + def create_pipeline_tool( + cls, description: Text, pipeline: Union[Pipeline, Text], name: Optional[Text] = None + ) -> PipelineTool: """Create a new pipeline tool.""" - return PipelineTool(description=description, pipeline=pipeline) + return PipelineTool(description=description, pipeline=pipeline, name=name) @classmethod def create_python_interpreter_tool(cls) -> PythonInterpreterTool: @@ -189,13 +192,16 @@ def create_python_interpreter_tool(cls) -> PythonInterpreterTool: return PythonInterpreterTool() @classmethod - def create_custom_python_code_tool(cls, code: Union[Text, Callable], description: Text = "") -> CustomPythonCodeTool: + def 
create_custom_python_code_tool( + cls, code: Union[Text, Callable], name: Text, description: Text = "" + ) -> CustomPythonCodeTool: """Create a new custom python code tool.""" - return CustomPythonCodeTool(description=description, code=code) + return CustomPythonCodeTool(name=name, description=description, code=code) @classmethod def create_sql_tool( cls, + name: Text, description: Text, source: str, source_type: Union[str, DatabaseSourceType], @@ -206,6 +212,7 @@ def create_sql_tool( """Create a new SQL tool Args: + name (Text): name of the tool description (Text): description of the database tool source (Union[Text, Dict]): database source - can be a connection string or dictionary with connection details source_type (Union[str, DatabaseSourceType]): type of source (postgresql, sqlite, csv) or DatabaseSourceType enum @@ -317,6 +324,7 @@ def create_sql_tool( # Create and return SQLTool return SQLTool( + name=name, description=description, database=database_path, schema=schema, diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 6dae6098..704d7fe7 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -60,6 +60,7 @@ def build_tool(tool: Dict): else: tool = PythonInterpreterTool() elif tool["type"] == "sql": + name = tool.get("name", "SQLTool") parameters = {parameter["name"]: parameter["value"] for parameter in tool.get("parameters", [])} database = parameters.get("database") schema = parameters.get("schema") @@ -67,7 +68,12 @@ def build_tool(tool: Dict): tables = tables.split(",") if tables is not None else None enable_commit = parameters.get("enable_commit", False) tool = SQLTool( - description=tool["description"], database=database, schema=schema, tables=tables, enable_commit=enable_commit + name=name, + description=tool["description"], + database=database, + schema=schema, + tables=tables, + enable_commit=enable_commit, ) else: raise ValueError("Agent 
Creation Error: Tool type not supported.") diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 38b8a7ae..227c0040 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -127,11 +127,21 @@ def _validate(self) -> None: assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." + tool_names = [] for tool in self.tools: + tool_name = None if isinstance(tool, Tool): - tool.validate() + tool_name = tool.name elif isinstance(tool, Model): assert not isinstance(tool, Agent), "Agent cannot contain another Agent." + tool_name = tool.name + tool_names.append(tool_name) + + if len(tool_names) != len(set(tool_names)): + duplicates = set([name for name in tool_names if tool_names.count(name) > 1]) + raise Exception( + f"Agent Creation Error - Duplicate tool names found: {', '.join(duplicates)}. Make sure all tool names are unique." + ) def validate(self, raise_exception: bool = False) -> bool: """Validate the Agent.""" diff --git a/aixplain/modules/agent/tool/custom_python_code_tool.py b/aixplain/modules/agent/tool/custom_python_code_tool.py index 511e4a8f..d544d2b1 100644 --- a/aixplain/modules/agent/tool/custom_python_code_tool.py +++ b/aixplain/modules/agent/tool/custom_python_code_tool.py @@ -21,7 +21,7 @@ Agentification Class """ -from typing import Text, Union, Callable +from typing import Text, Union, Callable, Optional from aixplain.modules.agent.tool import Tool import logging @@ -29,10 +29,13 @@ class CustomPythonCodeTool(Tool): """Custom Python Code Tool""" - def __init__(self, code: Union[Text, Callable], description: Text = "", **additional_info) -> None: + def __init__( + self, code: Union[Text, Callable], description: Text = "", name: Optional[Text] = None, **additional_info + ) -> None: """Custom Python Code Tool""" - super().__init__(name="Custom Python Code", description=description, **additional_info) + super().__init__(name=name or "", 
description=description, **additional_info) self.code = code + self.validate() def to_dict(self): return { @@ -64,3 +67,6 @@ def validate(self): ), "Custom Python Code Tool Error: Tool description is required" assert self.code and self.code.strip() != "", "Custom Python Code Tool Error: Code is required" assert self.name and self.name.strip() != "", "Custom Python Code Tool Error: Name is required" + + def __repr__(self) -> Text: + return f"CustomPythonCodeTool(name={self.name})" diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 29c68e0f..6a945a15 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -28,6 +28,30 @@ from aixplain.modules.model import Model +def set_tool_name(function: Function, supplier: Supplier = None, model: Model = None) -> Text: + """Sets the name of the tool based on the function, supplier, and model. + + Args: + function (Function): The function to be used in the tool. + supplier (Supplier): The supplier to be used in the tool. + model (Model): The model to be used in the tool. + + Returns: + Text: The name of the tool. + """ + function_name = function.value.lower().replace(" ", "_") + tool_name = f"{function_name}" + + if supplier is not None: + supplier_name = supplier.name.lower().replace(" ", "_") + tool_name += f"-{supplier_name}" + + if model is not None and supplier is not None: + model_name = model.name.lower().replace(" ", "_") + tool_name += f"-{model_name}" + return tool_name + + class ModelTool(Tool): """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. 
@@ -139,6 +163,8 @@ def validate(self) -> None: self.parameters = self.validate_parameters(self.parameters) + self.name = self.name if self.name else set_tool_name(self.function, self.supplier, self.model) + def get_parameters(self) -> Dict: return self.parameters @@ -187,3 +213,8 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) raise ValueError(f"Invalid parameters provided: {invalid_params}. Expected parameters are: {expected_param_names}") return received_parameters + + def __repr__(self) -> Text: + supplier_str = self.supplier.value if self.supplier is not None else None + model_str = self.model.id if self.model is not None else None + return f"ModelTool(name={self.name}, function={self.function}, supplier={supplier_str}, model={model_str})" diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index 13d3c46f..0de83916 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -50,9 +50,8 @@ def __init__( name = name or "" super().__init__(name=name, description=description, **additional_info) - if isinstance(pipeline, Pipeline): - pipeline = pipeline.id self.pipeline = pipeline + self.validate() def to_dict(self): return { @@ -65,7 +64,18 @@ def to_dict(self): def validate(self): from aixplain.factories.pipeline_factory import PipelineFactory - try: - PipelineFactory.get(self.pipeline, api_key=self.api_key) - except Exception: - raise Exception(f"Pipeline Tool Unavailable. Make sure Pipeline '{self.pipeline}' exists or you have access to it.") + if isinstance(self.pipeline, Pipeline): + pipeline_obj = self.pipeline + else: + try: + pipeline_obj = PipelineFactory.get(self.pipeline, api_key=self.api_key) + except Exception: + raise Exception( + f"Pipeline Tool Unavailable. Make sure Pipeline '{self.pipeline}' exists or you have access to it." 
+ ) + + if self.name.strip() == "": + self.name = pipeline_obj.name + + def __repr__(self) -> Text: + return f"PipelineTool(name={self.name}, pipeline={self.pipeline})" diff --git a/aixplain/modules/agent/tool/python_interpreter_tool.py b/aixplain/modules/agent/tool/python_interpreter_tool.py index d5947bff..42621c45 100644 --- a/aixplain/modules/agent/tool/python_interpreter_tool.py +++ b/aixplain/modules/agent/tool/python_interpreter_tool.py @@ -22,6 +22,7 @@ """ from aixplain.modules.agent.tool import Tool +from typing import Text class PythonInterpreterTool(Tool): @@ -41,3 +42,6 @@ def to_dict(self): def validate(self): pass + + def __repr__(self) -> Text: + return "PythonInterpreterTool()" diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index 81444db0..56cf116c 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -259,17 +259,18 @@ class SQLTool(Tool): def __init__( self, + name: Text, description: Text, database: Text, schema: Optional[Text] = None, tables: Optional[Union[List[Text], Text]] = None, enable_commit: bool = False, - name: Optional[Text] = None, **additional_info, ) -> None: """Tool to execute SQL query commands in an SQLite database. 
Args: + name (Text): name of the tool description (Text): description of the tool database (Text): database uri schema (Optional[Text]): database schema description @@ -277,7 +278,6 @@ def __init__( enable_commit (bool): enable to modify the database (optional) """ - name = name or "" super().__init__(name=name, description=description, **additional_info) self.database = database diff --git a/aixplain/v2/agent.py b/aixplain/v2/agent.py index af8c2d39..eadb155f 100644 --- a/aixplain/v2/agent.py +++ b/aixplain/v2/agent.py @@ -95,17 +95,22 @@ def create_model_tool( function: Union[Function, str] = None, supplier: Union["Supplier", str] = None, description: str = "", + name: Optional[str] = None, ): from aixplain.factories import AgentFactory - return AgentFactory.create_model_tool(model=model, function=function, supplier=supplier, description=description) + return AgentFactory.create_model_tool( + model=model, function=function, supplier=supplier, description=description, name=name + ) @classmethod - def create_pipeline_tool(cls, description: str, pipeline: Union["Pipeline", str]) -> "PipelineTool": + def create_pipeline_tool( + cls, description: str, pipeline: Union["Pipeline", str], name: Optional[str] = None + ) -> "PipelineTool": """Create a new pipeline tool.""" from aixplain.factories import AgentFactory - return AgentFactory.create_pipeline_tool(description=description, pipeline=pipeline) + return AgentFactory.create_pipeline_tool(description=description, pipeline=pipeline, name=name) @classmethod def create_python_interpreter_tool(cls) -> "PythonInterpreterTool": @@ -115,11 +120,13 @@ def create_python_interpreter_tool(cls) -> "PythonInterpreterTool": return AgentFactory.create_python_interpreter_tool() @classmethod - def create_custom_python_code_tool(cls, code: Union[str, Callable], description: str = "") -> "CustomPythonCodeTool": + def create_custom_python_code_tool( + cls, code: Union[str, Callable], name: str, description: str = "" + ) -> 
"CustomPythonCodeTool": """Create a new custom python code tool.""" from aixplain.factories import AgentFactory - return AgentFactory.create_custom_python_code_tool(code=code, description=description) + return AgentFactory.create_custom_python_code_tool(code=code, name=name, description=description) @classmethod def create_sql_tool( diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 0eecfa3f..ca3c4816 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -131,24 +131,27 @@ def test_python_interpreter_tool(delete_agents_and_team_agents, AgentFactory): def test_custom_code_tool(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents tool = AgentFactory.create_custom_python_code_tool( - description="Add two numbers", - code='def main(aaa: int, bbb: int) -> int:\n """Add two numbers"""\n return aaa + bbb', + description="Add two strings", + code='def main(aaa: str, bbb: str) -> str:\n """Add two strings"""\n return aaa + bbb', + name="Add Strings", ) assert tool is not None - assert tool.description == "Add two numbers" - assert tool.code == 'def main(aaa: int, bbb: int) -> int:\n """Add two numbers"""\n return aaa + bbb' + assert tool.description == "Add two strings" + assert tool.code == 'def main(aaa: str, bbb: str) -> str:\n """Add two strings"""\n return aaa + bbb' agent = AgentFactory.create( - name="Add Numbers Agent", - description="Add two numbers. Do not directly answer. Use the tool to add the numbers.", - instructions="Add two numbers. Do not directly answer. Use the tool to add the numbers.", + name="Add Strings Agent", + description="Add two strings. Do not directly answer. Use the tool to add the strings.", + instructions="Add two strings. Do not directly answer. Use the tool to add the strings.", tools=[tool], ) assert agent is not None - response = agent.run("How much is 12342 + 112312? 
Do not directly answer the question, call the tool.") + response = agent.run( + "What is the result of concatenating 'Hello' and 'World'? Do not directly answer the question, call the tool." + ) assert response is not None assert response["completed"] is True assert response["status"].lower() == "success" - assert "124654" in response["data"]["output"] + assert "HelloWorld" in response["data"]["output"] agent.delete() diff --git a/tests/unit/agent/agent_factory_utils_test.py b/tests/unit/agent/agent_factory_utils_test.py index 66a7a4e2..c0104f5c 100644 --- a/tests/unit/agent/agent_factory_utils_test.py +++ b/tests/unit/agent/agent_factory_utils_test.py @@ -3,6 +3,7 @@ from aixplain.factories.agent_factory.utils import build_tool, build_agent from aixplain.enums import Function, Supplier from aixplain.enums.asset_status import AssetStatus +from aixplain.modules.pipeline import Pipeline from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.agent.tool.pipeline_tool import PipelineTool from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool @@ -10,7 +11,7 @@ from aixplain.modules.agent.tool.sql_tool import SQLTool from aixplain.modules.agent import Agent from aixplain.modules.agent.agent_task import AgentTask -from aixplain.factories.model_factory import ModelFactory +from aixplain.factories import ModelFactory, PipelineFactory @pytest.fixture @@ -139,6 +140,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): pytest.param( { "type": "sql", + "name": "Test SQL", "description": "Test SQL", "parameters": [ {"name": "database", "value": "test_db"}, @@ -149,6 +151,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): }, SQLTool, { + "name": "Test SQL", "description": "Test SQL", "database": "test_db", "schema": "public", @@ -160,6 +163,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): pytest.param( { "type": "sql", + "name": "Test SQL", "description": "Test SQL with string 
enable_commit", "parameters": [ {"name": "database", "value": "test_db"}, @@ -170,6 +174,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): }, SQLTool, { + "name": "Test SQL", "description": "Test SQL with string enable_commit", "database": "test_db", "schema": "public", @@ -180,17 +185,28 @@ def test_build_tool_error_cases(tool_dict, expected_error): ), ], ) -def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock_model): +def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock_model, mocker): """Test successful tool creation with various configurations.""" - with patch.object(ModelFactory, "get", return_value=mock_model): - tool = build_tool(tool_dict) - assert isinstance(tool, expected_type) + mocker.patch.object(ModelFactory, "get", return_value=mock_model) + mocker.patch( + "aixplain.modules.model.utils.parse_code_decorated", + return_value=("print('Hello World')", [], "Test description", "test_name"), + ) + if tool_dict["type"] == "pipeline": + mocker.patch.object( + PipelineFactory, + "get", + return_value=Pipeline(id=tool_dict["assetId"], description=tool_dict["description"], name="Pipeline", api_key=""), + ) - for attr, value in expected_attrs.items(): - if attr == "model": - assert tool.model == mock_model - else: - assert getattr(tool, attr) == value + tool = build_tool(tool_dict) + assert isinstance(tool, expected_type) + + for attr, value in expected_attrs.items(): + if attr == "model": + assert tool.model == mock_model + else: + assert getattr(tool, attr) == value @pytest.mark.parametrize( @@ -263,8 +279,14 @@ def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock ), ], ) -def test_build_agent_success_cases(payload, expected_attrs, mock_tools): +def test_build_agent_success_cases(payload, expected_attrs, mock_tools, mocker): """Test successful agent creation with various configurations.""" + mocker.patch.object( + PipelineFactory, + "get", + 
return_value=Pipeline(id="test_pipeline", description="Test pipeline", name="Pipeline", api_key=""), + ) + agent = build_agent(payload, tools=mock_tools if "assets" not in payload else None) assert isinstance(agent, Agent) diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index f767f0ec..3a151b08 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -12,7 +12,7 @@ from aixplain.modules.agent.utils import process_variables from urllib.parse import urljoin from unittest.mock import patch -from aixplain.enums.function import Function +from aixplain.enums import Function, Supplier from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData @@ -81,7 +81,11 @@ def test_success_query_content(): assert response["url"] == ref_response["data"] -def test_invalid_pipelinetool(): +def test_invalid_pipelinetool(mocker): + mocker.patch( + "aixplain.factories.model_factory.ModelFactory.get", + return_value=Model(id="6646261c6eb563165658bbb1", name="Test Model", function=Function.TEXT_GENERATION), + ) with pytest.raises(Exception) as exc_info: AgentFactory.create( name="Test", @@ -155,6 +159,7 @@ def test_create_agent(mock_model_factory_get): "utility": "custom_python_code", "utilityCode": "def main(query: str) -> str:\n return 'Hello, how are you?'", "description": "Test Tool", + "name": "Test Tool", }, { "type": "utility", @@ -188,7 +193,9 @@ def test_create_agent(mock_model_factory_get): supplier=Supplier.OPENAI, function="text-generation", description="Test Tool" ), AgentFactory.create_custom_python_code_tool( - code="def main(query: str) -> str:\n return 'Hello, how are you?'", description="Test Tool" + code="def main(query: str) -> str:\n return 'Hello, how are you?'", + description="Test Tool", + name="Test Tool", ), AgentFactory.create_python_interpreter_tool(), ], @@ -462,7 +469,7 @@ def test_agent_multiple_tools_api_key(): 
AgentFactory.create_model_tool(function="text-generation"), AgentFactory.create_python_interpreter_tool(), AgentFactory.create_custom_python_code_tool( - code="def main(query: str) -> str:\n return 'Hello'", description="Test Tool" + code="def main(query: str) -> str:\n return 'Hello'", description="Test Tool", name="Test Tool" ), ] @@ -551,19 +558,28 @@ def test_agent_response(): assert response["data"]["output"] == "new_output" -def test_custom_python_code_tool_initialization(): +def test_custom_python_code_tool_initialization(mocker): """Test basic initialization of CustomPythonCodeTool""" + mocker.patch( + "aixplain.modules.model.utils.parse_code_decorated", + return_value=("def main(query: str) -> str:\n return 'Hello'", [], "Test description", "HelloWorld"), + ) + code = "def main(query: str) -> str:\n return 'Hello'" description = "Test description" - tool = CustomPythonCodeTool(code=code, description=description) + tool = CustomPythonCodeTool(code=code, description=description, name="HelloWorld") assert tool.code == code assert tool.description == description - assert tool.name == "Custom Python Code" + assert tool.name == "HelloWorld" -def test_custom_python_code_tool_to_dict(): +def test_custom_python_code_tool_to_dict(mocker): """Test the to_dict method of CustomPythonCodeTool""" + mocker.patch( + "aixplain.modules.model.utils.parse_code_decorated", + return_value=("def main(query: str) -> str:\n return 'Hello'", [], "Test description", "HelloWorld"), + ) code = "def main(query: str) -> str:\n return 'Hello'" description = "Test description" tool = CustomPythonCodeTool(code=code, description=description) @@ -594,22 +610,17 @@ def test_custom_python_code_tool_validation(): assert tool.name == "test_name" -def test_custom_python_code_tool_validation_missing_description(): +def test_custom_python_code_tool_validation_missing_description(mocker): """Test validation fails when description is missing""" - with patch( - 
"aixplain.modules.model.utils.parse_code", - return_value=( - "def main(query: str) -> str:\n return 'Hello'", # code - [], # inputs - None, # description - "test_name", # name - ), - ): - code = "def main(query: str) -> str:\n return 'Hello'" - tool = CustomPythonCodeTool(code=code) - with pytest.raises(AssertionError) as exc_info: - tool.validate() - assert str(exc_info.value) == "Custom Python Code Tool Error: Tool description is required" + mocker.patch( + "aixplain.modules.model.utils.parse_code_decorated", + return_value=("def main(query: str) -> str:\n return 'Hello'", [], "", "HelloWorld"), + ) + + code = "def main(query: str) -> str:\n return 'Hello'" + with pytest.raises(AssertionError) as exc_info: + CustomPythonCodeTool(code=code) + assert str(exc_info.value) == "Custom Python Code Tool Error: Tool description is required" def test_custom_python_code_tool_validation_missing_code(): @@ -619,22 +630,10 @@ def test_custom_python_code_tool_validation_missing_code(): return_value=("", [], "Parsed description", "test_name"), # code # inputs # description # name ): with pytest.raises(AssertionError) as exc_info: - tool = CustomPythonCodeTool(code="", description="Test description") - tool.validate() + CustomPythonCodeTool(code="", description="Test description") assert str(exc_info.value) == "Custom Python Code Tool Error: Code is required" -def test_custom_python_code_tool_with_callable(): - """Test CustomPythonCodeTool with a callable function""" - - def test_function(query: str) -> str: - return "Hello" - - tool = CustomPythonCodeTool(code=test_function, description="Test description") - assert callable(tool.code) - assert tool.description == "Test description" - - @patch("aixplain.factories.model_factory.ModelFactory.get") def test_create_agent_with_model_instance(mock_model_factory_get): from aixplain.enums import Supplier, Function @@ -957,3 +956,44 @@ def test_agent_response_repr(): # Most importantly, verify that 'status' is complete (not 'tatus') 
assert "status=" in repr_str # Should find complete field name + + +@pytest.mark.parametrize( + "function,supplier,model,expected_name", + [ + (Function.TRANSLATION, None, None, "translation"), + (Function.TEXT_GENERATION, Supplier.AIXPLAIN, None, "text-generation-aixplain"), + (Function.TEXT_GENERATION, Supplier.OPENAI, None, "text-generation-openai"), + ( + Function.TEXT_GENERATION, + Supplier.AIXPLAIN, + Model(id="123", name="Test Model"), + "text-generation-aixplain-test_model", + ), + ], +) +def test_set_tool_name(function, supplier, model, expected_name): + from aixplain.modules.agent.tool.model_tool import set_tool_name + + name = set_tool_name(function, supplier, model) + assert name == expected_name + + +def test_create_agent_with_duplicate_tool_names(mocker): + from aixplain.factories import AgentFactory + from aixplain.modules import Model + from aixplain.modules.agent.tool.model_tool import ModelTool + + mocker.patch( + "aixplain.factories.model_factory.ModelFactory.get", + return_value=Model(id="123", name="Test Model", function=Function.TEXT_GENERATION), + ) + + # Create a ModelTool with a specific name + tool1 = ModelTool(model="123", name="Test Model") + tool2 = ModelTool(model="123", name="Test Model") + with pytest.raises(Exception) as exc_info: + AgentFactory.create(name="Test Agent", description="Test Agent Description", tools=[tool1, tool2]) + assert "Agent Creation Error - Duplicate tool names found: Test Model. Make sure all tool names are unique." 
in str( + exc_info.value + ) diff --git a/tests/unit/agent/model_tool_test.py b/tests/unit/agent/model_tool_test.py index c29df945..869979e1 100644 --- a/tests/unit/agent/model_tool_test.py +++ b/tests/unit/agent/model_tool_test.py @@ -64,7 +64,7 @@ def test_init_with_supplier_dict(): with patch("aixplain.modules.agent.tool.model_tool.Supplier") as mock_supplier: # Create a mock for the Supplier enum - mock_supplier.AIXPLAIN = type("MockSupplier", (), {"value": mock_enum["AIXPLAIN"]})() + mock_supplier.AIXPLAIN = type("MockSupplier", (), {"value": mock_enum["AIXPLAIN"], "name": "aiXplain"})() mock_supplier.return_value = mock_supplier.AIXPLAIN tool = ModelTool(function=Function.TRANSLATION, supplier=supplier_dict) @@ -87,8 +87,11 @@ def test_init_validation_errors(error_case, expected_error, expected_message): error_case() -def test_to_dict(mock_model, mock_model_factory): +def test_to_dict(mocker, mock_model, mock_model_factory): mock_model_factory.get.return_value = mock_model + + mocker.patch("aixplain.modules.agent.tool.model_tool.set_tool_name", return_value="test_tool_name") + tool = ModelTool( model="test_model_id", description="Test description", @@ -98,7 +101,7 @@ def test_to_dict(mock_model, mock_model_factory): expected = { "function": mock_model.function.value, "type": "model", - "name": "", + "name": "test_tool_name", "description": "Test description", "supplier": mock_model.supplier.value["code"], "version": None, @@ -172,9 +175,9 @@ def test_validate_parameters(mocker, mock_model, params, expected_result, error_ "tool_name,expected_name", [ ("custom_tool", "custom_tool"), - ("", ""), # Test empty name + ("", "translation-aixplain-test_model"), # Test empty name ("translation_model", "translation_model"), - (None, ""), # Test None value should default to empty string + (None, "translation-aixplain-test_model"), # Test None value should default to empty string ], ) def test_tool_name(mock_model, mock_model_factory, tool_name, expected_name): diff --git 
a/tests/unit/agent/sql_tool_test.py b/tests/unit/agent/sql_tool_test.py index 12170a23..1a1dbb84 100644 --- a/tests/unit/agent/sql_tool_test.py +++ b/tests/unit/agent/sql_tool_test.py @@ -55,7 +55,7 @@ def test_create_sql_tool(mocker, tmp_path): # Test SQLite source type tool = AgentFactory.create_sql_tool( - description="Test", source=db_path, source_type="sqlite", schema="test", tables=["test", "test2"] + name="Test SQL", description="Test", source=db_path, source_type="sqlite", schema="test", tables=["test", "test2"] ) assert isinstance(tool, SQLTool) assert tool.description == "Test" @@ -67,7 +67,9 @@ def test_create_sql_tool(mocker, tmp_path): df = pd.DataFrame({"id": [1, 2, 3], "name": ["test1", "test2", "test3"]}) df.to_csv(csv_path, index=False) # Test CSV source type - csv_tool = AgentFactory.create_sql_tool(description="Test CSV", source=csv_path, source_type="csv", tables=["data"]) + csv_tool = AgentFactory.create_sql_tool( + name="Test CSV", description="Test CSV", source=csv_path, source_type="csv", tables=["data"] + ) assert isinstance(csv_tool, SQLTool) assert csv_tool.description == "Test CSV" assert csv_tool.database.endswith(".db") @@ -160,36 +162,37 @@ def test_sql_tool_validation_errors(tmp_path): # Test missing description with pytest.raises(SQLToolError, match="Description is required"): - tool = AgentFactory.create_sql_tool(description="", source=db_path, source_type="sqlite") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="", source=db_path, source_type="sqlite") tool.validate() # Test missing source with pytest.raises(SQLToolError, match="Source must be provided"): - tool = AgentFactory.create_sql_tool(description="Test", source="", source_type="sqlite") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source="", source_type="sqlite") tool.validate() # Test missing source_type with pytest.raises(TypeError, match="missing 1 required positional argument: 'source_type'"): - tool = 
AgentFactory.create_sql_tool(description="Test", source=db_path) + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=db_path) tool.validate() # Test invalid source type with pytest.raises(SQLToolError, match="Invalid source type"): - AgentFactory.create_sql_tool(description="Test", source=db_path, source_type="invalid") + AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=db_path, source_type="invalid") # Test non-existent SQLite database with pytest.raises(SQLToolError, match="Database .* does not exist"): - tool = AgentFactory.create_sql_tool(description="Test", source="nonexistent.db", source_type="sqlite") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source="nonexistent.db", source_type="sqlite") tool.validate() # Test non-existent CSV file with pytest.raises(SQLToolError, match="CSV file .* does not exist"): - tool = AgentFactory.create_sql_tool(description="Test", source="nonexistent.csv", source_type="csv") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source="nonexistent.csv", source_type="csv") tool.validate() # Test PostgreSQL (not supported) with pytest.raises(SQLToolError, match="PostgreSQL is not supported yet"): tool = AgentFactory.create_sql_tool( + name="Test SQL", description="Test", source="postgresql://user:pass@localhost/mydb", source_type="postgresql", @@ -210,7 +213,7 @@ def test_create_sql_tool_with_schema_inference(tmp_path, mocker): conn.close() # Create tool without schema and tables - tool = AgentFactory.create_sql_tool(description="Test", source=db_path, source_type="sqlite") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=db_path, source_type="sqlite") # Mock schema inference schema = "CREATE TABLE test (id INTEGER, name TEXT)" @@ -239,7 +242,7 @@ def test_create_sql_tool_from_csv_with_warnings(tmp_path, mocker): # Create tool and check for warnings with pytest.warns(UserWarning) as record: - 
tool = AgentFactory.create_sql_tool(description="Test", source=csv_path, source_type="csv") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=csv_path, source_type="csv") # Verify warnings about column name changes warning_messages = [str(w.message) for w in record] @@ -274,7 +277,9 @@ def test_create_sql_tool_from_csv(tmp_path): df.to_csv(csv_path, index=False) # Test successful creation - tool = AgentFactory.create_sql_tool(description="Test", source=csv_path, source_type="csv", tables=["test"]) + tool = AgentFactory.create_sql_tool( + name="Test SQL", description="Test", source=csv_path, source_type="csv", tables=["test"] + ) assert isinstance(tool, SQLTool) assert tool.description == "Test" assert tool.database.endswith(".db") @@ -301,7 +306,7 @@ def test_sql_tool_schema_inference(tmp_path): df.to_csv(csv_path, index=False) # Create tool without schema and tables - tool = AgentFactory.create_sql_tool(description="Test", source=csv_path, source_type="csv") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=csv_path, source_type="csv") try: tool.validate() @@ -324,15 +329,19 @@ def test_create_sql_tool_source_type_handling(tmp_path): conn.close() # Test with string input - tool_str = AgentFactory.create_sql_tool(description="Test", source=db_path, source_type="sqlite", schema="test") + tool_str = AgentFactory.create_sql_tool( + name="Test SQL", description="Test", source=db_path, source_type="sqlite", schema="test" + ) assert isinstance(tool_str, SQLTool) # Test with enum input tool_enum = AgentFactory.create_sql_tool( - description="Test", source=db_path, source_type=DatabaseSourceType.SQLITE, schema="test" + name="Test SQL", description="Test", source=db_path, source_type=DatabaseSourceType.SQLITE, schema="test" ) assert isinstance(tool_enum, SQLTool) # Test invalid type with pytest.raises(SQLToolError, match="Source type must be either a string or DatabaseSourceType enum, got "): - 
AgentFactory.create_sql_tool(description="Test", source=db_path, source_type=123, schema="test") # Invalid type \ No newline at end of file + AgentFactory.create_sql_tool( + name="Test SQL", description="Test", source=db_path, source_type=123, schema="test" + ) # Invalid type From 9796d197a19c05d6ccb49c12af128b65b13f5018 Mon Sep 17 00:00:00 2001 From: kadirpekel Date: Wed, 30 Apr 2025 14:47:21 +0200 Subject: [PATCH 15/62] Eng 2051 Improvements on CI flow (#509) * ENG-2051: improvements on CI workflow * Update main.yaml --- .github/workflows/main.yaml | 87 +++++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 18 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index faf380eb..9f1ac34a 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -9,29 +9,79 @@ on: workflow_dispatch: jobs: - setup-and-test: + test: runs-on: ubuntu-latest + timeout-minutes: 45 # Global timeout fallback strategy: fail-fast: false matrix: test-suite: [ - 'tests/unit', - 'tests/functional/file_asset', - 'tests/functional/data_asset', - 'tests/functional/benchmark', - 'tests/functional/model', - 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v1 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v2 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v1 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v2 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/designer_test.py', - 'tests/functional/pipelines/create_test.py', - 'tests/functional/finetune --sdk_version v1 --sdk_version_param FinetuneFactory', - 'tests/functional/finetune --sdk_version v2 --sdk_version_param FinetuneFactory', - 'tests/functional/general_assets', - 'tests/functional/apikey', - 'tests/functional/agent 
tests/functional/team_agent', + 'unit', + 'file_asset', + 'data_asset', + 'benchmark', + 'model', + 'pipeline_2.0_v1', + 'pipeline_2.0_v2', + 'pipeline_3.0_v1', + 'pipeline_3.0_v2', + 'pipeline_designer', + 'pipeline_create', + 'finetune_v1', + 'finetune_v2', + 'general_assets', + 'apikey', + 'agent_and_team_agent', ] + include: + - test-suite: 'unit' + path: 'tests/unit' + timeout: 45 # Tweaked timeout for each unit tests + - test-suite: 'file_asset' + path: 'tests/functional/file_asset' + timeout: 45 + - test-suite: 'data_asset' + path: 'tests/functional/data_asset' + timeout: 45 + - test-suite: 'benchmark' + path: 'tests/functional/benchmark' + timeout: 45 + - test-suite: 'model' + path: 'tests/functional/model' + timeout: 45 + - test-suite: 'pipeline_2.0_v1' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v1 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_2.0_v2' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v2 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_3.0_v1' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v1 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_3.0_v2' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v2 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_designer' + path: 'tests/functional/pipelines/designer_test.py' + timeout: 45 + - test-suite: 'pipeline_create' + path: 'tests/functional/pipelines/create_test.py' + timeout: 45 + - test-suite: 'finetune_v1' + path: 'tests/functional/finetune --sdk_version v1 --sdk_version_param FinetuneFactory' + timeout: 45 + - test-suite: 'finetune_v2' + path: 'tests/functional/finetune --sdk_version v2 --sdk_version_param FinetuneFactory' + timeout: 45 + - test-suite: 'general_assets' + path: 'tests/functional/general_assets' + timeout: 45 + - 
test-suite: 'apikey' + path: 'tests/functional/apikey' + timeout: 45 + - test-suite: 'agent_and_team_agent' + path: 'tests/functional/agent tests/functional/team_agent' + timeout: 45 steps: - name: Checkout repository uses: actions/checkout@v4 @@ -62,4 +112,5 @@ jobs: fi - name: Run Tests - run: python -m pytest ${{ matrix.test-suite}} \ No newline at end of file + timeout-minutes: ${{ matrix.timeout }} + run: python -m pytest ${{ matrix.path }} From f0837fca2b737d0491fd66932b8204d8d785dee8 Mon Sep 17 00:00:00 2001 From: Muhammad-Elmallah <145364766+Muhammad-Elmallah@users.noreply.github.com> Date: Mon, 5 May 2025 14:48:27 +0300 Subject: [PATCH 16/62] Add a finetuned version of BGE model (#512) * Add a finetuned version of BGE model * changin mode name --- aixplain/enums/embedding_model.py | 4 +--- tests/functional/model/run_model_test.py | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aixplain/enums/embedding_model.py b/aixplain/enums/embedding_model.py index c52387b2..31618580 100644 --- a/aixplain/enums/embedding_model.py +++ b/aixplain/enums/embedding_model.py @@ -27,9 +27,7 @@ class EmbeddingModel(str, Enum): JINA_CLIP_V2_MULTIMODAL = "67c5f705d8f6a65d6f74d732" MULTILINGUAL_E5_LARGE = "67efd0772a0a850afa045af3" BGE_M3 = "67f401032a0a850afa045b19" - - - + AIXPLAIN_LEGAL_EMBEDDINGS = "681254b668e47e7844c1f15a" def __str__(self): return self._value_ diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 916d6077..1aed18df 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -99,6 +99,7 @@ def run_index_model(index_model): pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, AirParams, id="AIR - Snowflake Arctic Embed L v2.0"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), pytest.param(EmbeddingModel.BGE_M3, AirParams, id="AIR - BGE M3"), + 
pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="AIR - aiXplain Legal Embeddings"), ], ) def test_index_model(embedding_model, supplier_params): @@ -123,6 +124,7 @@ def test_index_model(embedding_model, supplier_params): pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, AirParams, id="Jina Clip v2 Multimodal"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), pytest.param(EmbeddingModel.BGE_M3, AirParams, id="BGE M3"), + pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="aiXplain Legal Embeddings"), ], ) def test_index_model_with_filter(embedding_model, supplier_params): From 511bf5f570dcec7011b2ac011653b319d7018e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Tue, 6 May 2025 13:33:05 +0300 Subject: [PATCH 17/62] ENG-2055-Aixplain-SDK-Centralized-Error-Handling (#510) * move request_with_retry from file_utils to request_utils because of circular import * added get error from response status for model/pipeline calls * added error codes * Small refactor and error code inclusion --------- Co-authored-by: Thiago Castro Ferreira --- .pre-commit-config.yaml | 2 +- aixplain/enums/license.py | 2 +- aixplain/enums/supplier.py | 2 +- aixplain/exceptions/__init__.py | 116 ++++++++++ aixplain/exceptions/types.py | 217 ++++++++++++++++++ aixplain/factories/agent_factory/__init__.py | 2 +- aixplain/factories/api_key_factory.py | 2 +- aixplain/factories/benchmark_factory.py | 2 +- aixplain/factories/corpus_factory.py | 2 +- aixplain/factories/data_factory.py | 8 +- aixplain/factories/dataset_factory.py | 3 +- .../factories/finetune_factory/__init__.py | 4 +- aixplain/factories/metric_factory.py | 6 +- aixplain/factories/model_factory/__init__.py | 2 +- aixplain/factories/model_factory/utils.py | 2 +- .../factories/pipeline_factory/__init__.py | 2 +- .../factories/team_agent_factory/__init__.py | 2 +- aixplain/factories/wallet_factory.py | 2 +- 
aixplain/modules/agent/__init__.py | 2 +- aixplain/modules/api_key.py | 2 +- aixplain/modules/benchmark.py | 2 +- aixplain/modules/benchmark_job.py | 5 +- aixplain/modules/corpus.py | 4 +- aixplain/modules/dataset.py | 2 +- aixplain/modules/finetune/__init__.py | 2 +- aixplain/modules/model/__init__.py | 5 +- aixplain/modules/model/llm_model.py | 1 + aixplain/modules/model/response.py | 6 + aixplain/modules/model/utility_model.py | 2 +- aixplain/modules/model/utils.py | 21 +- aixplain/modules/pipeline/asset.py | 24 +- aixplain/modules/team_agent/__init__.py | 2 +- .../data_onboarding/onboard_functions.py | 5 +- aixplain/utils/file_utils.py | 21 +- tests/test_utils.py | 4 +- tests/unit/llm_test.py | 8 +- tests/unit/model_test.py | 8 +- tests/unit/pipeline_test.py | 132 +---------- 38 files changed, 410 insertions(+), 226 deletions(-) create mode 100644 aixplain/exceptions/__init__.py create mode 100644 aixplain/exceptions/types.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 456aba3b..2500c2c8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,4 +22,4 @@ repos: hooks: - id: flake8 args: # arguments to configure flake8 - - --ignore=E402,E501,E203 \ No newline at end of file + - --ignore=E402,E501,E203,W503 \ No newline at end of file diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index a860a539..566be092 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -25,8 +25,8 @@ from enum import Enum from urllib.parse import urljoin from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.utils.request_utils import _request_with_retry CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" diff --git a/aixplain/enums/supplier.py b/aixplain/enums/supplier.py index 26058bf5..18a3e81d 100644 --- a/aixplain/enums/supplier.py +++ b/aixplain/enums/supplier.py @@ -24,7 
+24,7 @@ import logging from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from enum import Enum from urllib.parse import urljoin import re diff --git a/aixplain/exceptions/__init__.py b/aixplain/exceptions/__init__.py new file mode 100644 index 00000000..5d645c62 --- /dev/null +++ b/aixplain/exceptions/__init__.py @@ -0,0 +1,116 @@ +""" +Error message registry for aiXplain SDK. + +This module maintains a centralized registry of error messages used throughout the aiXplain ecosystem. +It allows developers to look up existing error messages and reuse them instead of creating new ones. +""" + +from aixplain.exceptions.types import ( + AixplainBaseException, + AuthenticationError, + ValidationError, + ResourceError, + BillingError, + SupplierError, + NetworkError, + ServiceError, + InternalError, +) + + +def get_error_from_status_code(status_code: int, error_details: str = None) -> AixplainBaseException: + """ + Map HTTP status codes to appropriate exception types. + + Args: + status_code (int): The HTTP status code to map. + default_message (str, optional): The default message to use if no specific message is available. + + Returns: + AixplainBaseException: An exception of the appropriate type. + """ + try: + if isinstance(status_code, str): + status_code = int(status_code) + except Exception as e: + raise InternalError(f"Failed to get status code from {status_code}: {e}") from e + + error_details = f"Details: {error_details}" if error_details else "" + if status_code == 400: + return ValidationError( + message=f"Bad request: Please verify the request payload and ensure it is correct. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 401: + return AuthenticationError( + message=f"Unauthorized API key: Please verify the spelling of the API key and its current validity. 
{error_details}".strip(), + status_code=status_code, + ) + elif status_code == 402: + return BillingError( + message=f"Payment required: Please ensure you have enough credits to run this asset. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 403: + # 403 could be auth or resource, using ResourceError as a general 'forbidden' + return ResourceError( + message=f"Forbidden access: Please verify the API key and its current validity. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 404: + # Added 404 mapping + return ResourceError( + message=f"Resource not found: Please verify the spelling of the resource and its current availability. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 429: + # Using SupplierError for rate limiting as per your original function + return SupplierError( + message=f"Rate limit exceeded: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif status_code == 500: + return InternalError( + message=f"Internal server error: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif status_code == 503: + return ServiceError( + message=f"Service unavailable: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif status_code == 504: + return NetworkError( + message=f"Gateway timeout: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif 460 <= status_code < 470: + return ResourceError( + message=f"Subscription-related error: Please ensure that your subscription is active and has not expired. {error_details}".strip(), + status_code=status_code, + ) + elif 470 <= status_code < 480: + return BillingError( + message=f"Billing-related error: Please ensure you have enough credits to run this asset. 
{error_details}".strip(), + status_code=status_code, + ) + elif 480 <= status_code < 490: + return SupplierError( + message=f"Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access. {error_details}".strip(), + status_code=status_code, + ) + elif 490 <= status_code < 500: + return ValidationError( + message=f"Validation-related error: Please verify the request payload and ensure it is correct. {error_details}".strip(), + status_code=status_code, + ) + else: + # Catch-all for other client/server errors + category = "Client" if 400 <= status_code < 500 else "Server" + return InternalError( + message=f"Unspecified {category} Error (Status {status_code}) {error_details}".strip(), status_code=status_code + ) diff --git a/aixplain/exceptions/types.py b/aixplain/exceptions/types.py new file mode 100644 index 00000000..e41c7fd2 --- /dev/null +++ b/aixplain/exceptions/types.py @@ -0,0 +1,217 @@ +from enum import Enum +from typing import Optional, Dict, Any + + +class ErrorSeverity(str, Enum): + """Severity levels for errors.""" + + INFO = "info" # Informational, not an error + WARNING = "warning" # Warning, operation can continue + ERROR = "error" # Error, operation cannot continue + CRITICAL = "critical" # System stability might be compromised + + +class ErrorCategory(Enum): + """Categorizes errors by their domain.""" + + AUTHENTICATION = "authentication" # API keys, permissions + VALIDATION = "validation" # Input validation + RESOURCE = "resource" # Resource availability + BILLING = "billing" # Credits, payment + SUPPLIER = "supplier" # External supplier issues + NETWORK = "network" # Network connectivity + SERVICE = "service" # Service availability + INTERNAL = "internal" # Internal system errors + AGENT = "agent" # Agent-specific errors + UNKNOWN = "unknown" # Uncategorized errors + + +class ErrorCode(str, Enum): + """Standard error codes for aiXplain exceptions. 
+ + The format is AX-<CATEGORY>-<ID>, where <CATEGORY> is a short identifier + derived from the ErrorCategory (e.g., AUTH, VAL, RES) and <ID> is a + unique sequential number within that category, starting from 1000. + + How to Add a New Error Code: + 1. Identify the appropriate `ErrorCategory` for the new error. + 2. Determine the next available sequential ID within that category. + For example, if `AX-AUTH-1000` exists, the next authentication-specific + error could be `AX-AUTH-1001`. + 3. Define the new enum member using the format `AX-<CATEGORY>-<ID>`. + Use a concise abbreviation for the category (e.g., AUTH, VAL, RES, BIL, + SUP, NET, SVC, INT). + 4. Assign the string value (e.g., `"AX-AUTH-1001"`). + 5. Add a clear docstring explaining the specific condition that triggers + this error code. + 6. (Optional but recommended) Consider creating a more specific exception + class inheriting from the corresponding category exception (e.g., + `class InvalidApiKeyError(AuthenticationError): ...`) and assign the + new error code to it. + """ + + AX_AUTH_ERROR = "AX-AUTH-1000" # General authentication error. Use for issues like invalid API keys, insufficient permissions, or failed login attempts. + AX_VAL_ERROR = "AX-VAL-1000" # General validation error. Use when user-provided input fails validation checks (e.g., incorrect data type, missing required fields, invalid format. + AX_RES_ERROR = "AX-RES-1000" # General resource error. Use for issues related to accessing or managing resources, such as a requested model being unavailable or quota limits exceeded. + AX_BIL_ERROR = "AX-BIL-1000" # General billing error. Use for problems related to billing, payments, or credits (e.g., insufficient funds, expired subscription. + AX_SUP_ERROR = "AX-SUP-1000" # General supplier error. Use when an error originates from an external supplier or third-party service integrated with aiXplain. + AX_NET_ERROR = "AX-NET-1000" # General network error.
Use for issues related to network connectivity, such as timeouts, DNS resolution failures, or unreachable services. + AX_SVC_ERROR = "AX-SVC-1000" # General service error. Use when a specific aiXplain service or endpoint is unavailable or malfunctioning (e.g., service downtime, internal component failure. + AX_INT_ERROR = "AX-INT-1000" # General internal error. Use for unexpected server-side errors that are not covered by other categories. This often indicates a bug or an issue within the aiXplain platform itself. + + +class AixplainBaseException(Exception): + """Base exception class for all aiXplain exceptions.""" + + def __init__( + self, + message: str, + category: ErrorCategory = ErrorCategory.UNKNOWN, + severity: ErrorSeverity = ErrorSeverity.ERROR, + status_code: Optional[int] = None, + details: Optional[Dict[str, Any]] = None, + retry_recommended: bool = False, + error_code: Optional[ErrorCode] = None, + ): + self.message = message + self.category = category + self.severity = severity + self.status_code = status_code + self.details = details or {} + self.retry_recommended = retry_recommended + self.error_code = error_code + super().__init__(self.message) + + def __str__(self): + error_code_str = f" [{self.error_code}]" if self.error_code else "" + return f"{self.__class__.__name__}{error_code_str}: {self.message}" + + def to_dict(self) -> Dict[str, Any]: + """Convert exception to dictionary for serialization.""" + return { + "message": self.message, + "category": self.category.value, + "severity": self.severity.value, + "status_code": self.status_code, + "details": self.details, + "retry_recommended": self.retry_recommended, + "error_code": self.error_code.value if self.error_code else None, + } + + +class AuthenticationError(AixplainBaseException): + """Raised when authentication fails.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.AUTHENTICATION, + severity=ErrorSeverity.ERROR, + 
retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_AUTH_ERROR, + **kwargs, + ) + + +class ValidationError(AixplainBaseException): + """Raised when input validation fails.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.VALIDATION, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_VAL_ERROR, + **kwargs, + ) + + +class ResourceError(AixplainBaseException): + """Raised when a resource is unavailable.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.RESOURCE, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_RES_ERROR, + **kwargs, + ) + + +class BillingError(AixplainBaseException): + """Raised when there are billing issues.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.BILLING, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_BIL_ERROR, + **kwargs, + ) + + +class SupplierError(AixplainBaseException): + """Raised when there are issues with external suppliers.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.SUPPLIER, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", True), + error_code=ErrorCode.AX_SUP_ERROR, + **kwargs, + ) + + +class NetworkError(AixplainBaseException): + """Raised when there are network connectivity issues.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.NETWORK, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", True), + error_code=ErrorCode.AX_NET_ERROR, + **kwargs, + ) + + +class ServiceError(AixplainBaseException): + 
"""Raised when a service is unavailable.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.SERVICE, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", True), + error_code=ErrorCode.AX_SVC_ERROR, + **kwargs, + ) + + +class InternalError(AixplainBaseException): + """Raised when there is an internal system error.""" + + def __init__(self, message: str, **kwargs): + # Server errors (5xx) should generally be retryable + status_code = kwargs.get("status_code") + retry_recommended = kwargs.pop("retry_recommended", False) + if status_code and status_code in [500, 502, 503, 504]: + retry_recommended = True + + super().__init__( + message=message, + category=ErrorCategory.INTERNAL, + severity=ErrorSeverity.ERROR, + retry_recommended=retry_recommended, + error_code=ErrorCode.AX_INT_ERROR, + **kwargs, + ) diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 040bcd71..9440ff81 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -41,7 +41,7 @@ from aixplain.utils import config from typing import Callable, Dict, List, Optional, Text, Union -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from aixplain.enums import DatabaseSourceType diff --git a/aixplain/factories/api_key_factory.py b/aixplain/factories/api_key_factory.py index c719c26b..3d081e27 100644 --- a/aixplain/factories/api_key_factory.py +++ b/aixplain/factories/api_key_factory.py @@ -3,7 +3,7 @@ import aixplain.utils.config as config from datetime import datetime from typing import Text, List, Optional, Dict, Union -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.api_key import APIKey, APIKeyLimits, 
APIKeyUsageLimit diff --git a/aixplain/factories/benchmark_factory.py b/aixplain/factories/benchmark_factory.py index 3f643a3d..c37f17a8 100644 --- a/aixplain/factories/benchmark_factory.py +++ b/aixplain/factories/benchmark_factory.py @@ -32,7 +32,7 @@ from aixplain.factories.dataset_factory import DatasetFactory from aixplain.factories.model_factory import ModelFactory from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin diff --git a/aixplain/factories/corpus_factory.py b/aixplain/factories/corpus_factory.py index db7aa44e..9563ad14 100644 --- a/aixplain/factories/corpus_factory.py +++ b/aixplain/factories/corpus_factory.py @@ -38,7 +38,7 @@ from aixplain.enums.language import Language from aixplain.enums.license import License from aixplain.enums.privacy import Privacy -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.utils import config from pathlib import Path from tqdm import tqdm diff --git a/aixplain/factories/data_factory.py b/aixplain/factories/data_factory.py index 1879b321..65aa8a87 100644 --- a/aixplain/factories/data_factory.py +++ b/aixplain/factories/data_factory.py @@ -28,15 +28,11 @@ from aixplain.modules.data import Data from aixplain.enums.data_subtype import DataSubtype from aixplain.enums.data_type import DataType -from aixplain.enums.function import Function from aixplain.enums.language import Language -from aixplain.enums.license import License from aixplain.enums.privacy import Privacy -from aixplain.utils.file_utils import _request_with_retry -from aixplain.utils import config -from typing import Any, Dict, List, Text +from aixplain.utils.request_utils import _request_with_retry +from typing import Dict, Text from urllib.parse import urljoin -from uuid import uuid4 class DataFactory(AssetFactory): diff --git 
a/aixplain/factories/dataset_factory.py b/aixplain/factories/dataset_factory.py index ca9d993e..3b86b45e 100644 --- a/aixplain/factories/dataset_factory.py +++ b/aixplain/factories/dataset_factory.py @@ -41,7 +41,8 @@ from aixplain.enums.privacy import Privacy from aixplain.utils import config from aixplain.utils.convert_datatype_utils import dict_to_metadata -from aixplain.utils.file_utils import _request_with_retry, s3_to_csv +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.file_utils import s3_to_csv from aixplain.utils.validation_utils import dataset_onboarding_validation from pathlib import Path from tqdm import tqdm diff --git a/aixplain/factories/finetune_factory/__init__.py b/aixplain/factories/finetune_factory/__init__.py index 238d0d0c..b6006f0b 100644 --- a/aixplain/factories/finetune_factory/__init__.py +++ b/aixplain/factories/finetune_factory/__init__.py @@ -33,7 +33,7 @@ from aixplain.modules.dataset import Dataset from aixplain.modules.model import Model from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin @@ -98,7 +98,7 @@ def create( if prompt_template is not None: prompt_template = validate_prompt(prompt_template, dataset_list) try: - url = urljoin(cls.backend_url, f"sdk/finetune/cost-estimation") + url = urljoin(cls.backend_url, "sdk/finetune/cost-estimation") headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} payload = { "datasets": [ diff --git a/aixplain/factories/metric_factory.py b/aixplain/factories/metric_factory.py index 9f42fb3e..6279ffc1 100644 --- a/aixplain/factories/metric_factory.py +++ b/aixplain/factories/metric_factory.py @@ -22,14 +22,12 @@ """ import logging -import os from typing import List, Optional from aixplain.modules import Metric from aixplain.utils import config -from aixplain.utils.file_utils import 
_request_with_retry +from aixplain.utils.request_utils import _request_with_retry from typing import Dict, Text from urllib.parse import urljoin -from warnings import warn class MetricFactory: @@ -113,7 +111,7 @@ def list( List[Metric]: List of supported metrics """ try: - url = urljoin(cls.backend_url, f"sdk/metrics") + url = urljoin(cls.backend_url, "sdk/metrics") filter_params = {} if model_id is not None: filter_params["modelId"] = model_id diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index b39cc668..85c1ac4f 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -27,7 +27,7 @@ from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 1be2186c..418a32b9 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -6,7 +6,7 @@ from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput from aixplain.enums import DataType, Function, Language, OwnershipType, Supplier, SortBy, SortOrder, AssetStatus from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from datetime import datetime from typing import Dict, Union, List, Optional, Tuple from urllib.parse import urljoin diff --git a/aixplain/factories/pipeline_factory/__init__.py b/aixplain/factories/pipeline_factory/__init__.py index cfbfce54..ba164199 100644 --- a/aixplain/factories/pipeline_factory/__init__.py +++ 
b/aixplain/factories/pipeline_factory/__init__.py @@ -31,7 +31,7 @@ from aixplain.modules.model import Model from aixplain.modules.pipeline import Pipeline from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from warnings import warn diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index 892d7ded..6a1db846 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -31,7 +31,7 @@ from aixplain.modules.team_agent import TeamAgent, InspectorTarget from aixplain.utils import config from aixplain.factories.team_agent_factory.utils import build_team_agent -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class TeamAgentFactory: diff --git a/aixplain/factories/wallet_factory.py b/aixplain/factories/wallet_factory.py index 1591dc2e..2de28ec4 100644 --- a/aixplain/factories/wallet_factory.py +++ b/aixplain/factories/wallet_factory.py @@ -1,6 +1,6 @@ import aixplain.utils.config as config from aixplain.modules.wallet import Wallet -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry import logging from typing import Text diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 227c0040..b6920f5c 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -26,7 +26,7 @@ import time import traceback -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier from aixplain.enums.asset_status import AssetStatus diff --git a/aixplain/modules/api_key.py 
b/aixplain/modules/api_key.py index ae774c23..d606e106 100644 --- a/aixplain/modules/api_key.py +++ b/aixplain/modules/api_key.py @@ -1,6 +1,6 @@ import logging from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.modules import Model from datetime import datetime from typing import Dict, List, Optional, Text, Union diff --git a/aixplain/modules/benchmark.py b/aixplain/modules/benchmark.py index d76b2e62..6878becf 100644 --- a/aixplain/modules/benchmark.py +++ b/aixplain/modules/benchmark.py @@ -26,7 +26,7 @@ from aixplain.modules import Asset, Dataset, Metric, Model from aixplain.modules.benchmark_job import BenchmarkJob from urllib.parse import urljoin -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class Benchmark(Asset): diff --git a/aixplain/modules/benchmark_job.py b/aixplain/modules/benchmark_job.py index 29a33aa7..cd17c0e1 100644 --- a/aixplain/modules/benchmark_job.py +++ b/aixplain/modules/benchmark_job.py @@ -4,7 +4,8 @@ from urllib.parse import urljoin import pandas as pd from pathlib import Path -from aixplain.utils.file_utils import _request_with_retry, save_file +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.file_utils import save_file class BenchmarkJob: @@ -72,7 +73,7 @@ def download_results_as_csv(self, save_path: Optional[Text] = None, return_dataf logging.info(f"Downloading Benchmark Results: Status of downloading results for {self.id}: {resp}") if "reportUrl" not in resp or resp["reportUrl"] == "": logging.error( - f"Downloading Benchmark Results: Can't get download results as they aren't generated yet. Please wait for a while." + "Downloading Benchmark Results: Can't get download results as they aren't generated yet. Please wait for a while." 
) return None csv_url = resp["reportUrl"] diff --git a/aixplain/modules/corpus.py b/aixplain/modules/corpus.py index b65664b6..10101292 100644 --- a/aixplain/modules/corpus.py +++ b/aixplain/modules/corpus.py @@ -29,8 +29,8 @@ from aixplain.modules.asset import Asset from aixplain.modules.data import Data from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry -from typing import Any, List, Optional, Text +from aixplain.utils.request_utils import _request_with_retry +from typing import List, Optional, Text from urllib.parse import urljoin diff --git a/aixplain/modules/dataset.py b/aixplain/modules/dataset.py index fd79e9f3..85264013 100644 --- a/aixplain/modules/dataset.py +++ b/aixplain/modules/dataset.py @@ -30,7 +30,7 @@ from aixplain.modules.asset import Asset from aixplain.modules.data import Data from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from typing import Any, Dict, List, Optional, Text diff --git a/aixplain/modules/finetune/__init__.py b/aixplain/modules/finetune/__init__.py index fe2cb15c..15cc37a7 100644 --- a/aixplain/modules/finetune/__init__.py +++ b/aixplain/modules/finetune/__init__.py @@ -31,7 +31,7 @@ from aixplain.modules.model import Model from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class Finetune(Asset): diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index adedfcfb..4729dd42 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -28,7 +28,7 @@ from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config from urllib.parse import urljoin -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import 
_request_with_retry from typing import Union, Optional, Text, Dict from datetime import datetime from aixplain.modules.model.response import ModelResponse @@ -199,6 +199,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: status = ResponseStatus.FAILED else: status = ResponseStatus.IN_PROGRESS + logging.debug(f"Single Poll for Model: Status of polling for {name}: {resp}") return ModelResponse( status=resp.pop("status", status), @@ -209,6 +210,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: used_credits=resp.pop("usedCredits", 0), run_time=resp.pop("runTime", 0), usage=resp.pop("usage", None), + error_code=resp.get("error_code", None), **resp, ) except Exception as e: @@ -264,6 +266,7 @@ def run( used_credits=response.pop("usedCredits", 0), run_time=response.pop("runTime", 0), usage=response.pop("usage", None), + error_code=response.get("error_code", None), **response, ) diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index cf60d0a2..0b8e6cf8 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -166,6 +166,7 @@ def run( used_credits=response.pop("usedCredits", 0), run_time=response.pop("runTime", 0), usage=response.pop("usage", None), + error_code=response.get("error_code", None), **response, ) diff --git a/aixplain/modules/model/response.py b/aixplain/modules/model/response.py index ac9f8184..a0cf08f8 100644 --- a/aixplain/modules/model/response.py +++ b/aixplain/modules/model/response.py @@ -1,5 +1,6 @@ from typing import Text, Any, Optional, Dict, List, Union from aixplain.enums import ResponseStatus +from aixplain.exceptions.types import ErrorCode class ModelResponse: @@ -16,6 +17,7 @@ def __init__( run_time: float = 0.0, usage: Optional[Dict] = None, url: Optional[Text] = None, + error_code: Optional[ErrorCode] = None, **kwargs, ): self.status = status @@ -31,6 +33,7 @@ def __init__( self.run_time = run_time 
self.usage = usage self.url = url + self.error_code = error_code self.additional_fields = kwargs def __getitem__(self, key: Text) -> Any: @@ -82,6 +85,8 @@ def __repr__(self) -> str: fields.append(f"usage={self.usage}") if self.url: fields.append(f"url='{self.url}'") + if self.error_code: + fields.append(f"error_code='{self.error_code}'") if self.additional_fields: fields.extend([f"{k}={repr(v)}" for k, v in self.additional_fields.items()]) return f"ModelResponse({', '.join(fields)})" @@ -104,6 +109,7 @@ def to_dict(self) -> Dict[Text, Any]: "run_time": self.run_time, "usage": self.usage, "url": self.url, + "error_code": self.error_code, } if self.additional_fields: base_dict.update(self.additional_fields) diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 96454181..043571c2 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -24,7 +24,7 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.modules.model import Model from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.model.utils import parse_code_decorated from dataclasses import dataclass from typing import Callable, Union, Optional, List, Text, Dict diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 7ba42f2d..6f3f9319 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -4,6 +4,7 @@ import logging from aixplain.utils.file_utils import _request_with_retry from typing import Callable, Dict, List, Text, Tuple, Union, Optional +from aixplain.exceptions import get_error_from_status_code def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None): @@ -61,22 +62,12 @@ def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict: else: response = resp else: - resp = resp["error"] if 
isinstance(resp, dict) and "error" in resp else resp - if r.status_code == 401: - error = f"Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {resp}" - elif 460 <= r.status_code < 470: - error = f"Subscription-related error: Please ensure that your subscription is active and has not expired. Details: {resp}" - elif 470 <= r.status_code < 480: - error = f"Billing-related error: Please ensure you have enough credits to run this model. Details: {resp}" - elif 480 <= r.status_code < 490: - error = f"Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: {resp}" - elif 490 <= r.status_code < 500: - error = f"{resp}" - else: - status_code = str(r.status_code) - error = f"Status {status_code} - Unspecified error: {resp}" - response = {"status": "FAILED", "error_message": error, "completed": True} + error_details = resp["error"] if isinstance(resp, dict) and "error" in resp else resp + status_code = r.status_code + error = get_error_from_status_code(status_code, error_details) + logging.error(f"Error in request: {r.status_code}: {error}") + response = {"status": "FAILED", "error_message": error.message, "completed": True} return response diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 0642e6ee..63eac9b7 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -29,10 +29,11 @@ from aixplain.enums.response_status import ResponseStatus from aixplain.modules.asset import Asset from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin from aixplain.modules.pipeline.response import PipelineResponse +from aixplain.exceptions import get_error_from_status_code class Pipeline(Asset): @@ -410,33 +411,20 @@ def run_async( 
return res else: - if r.status_code == 401: - error = "Unauthorized API key: Please verify the spelling of the API key and its current validity." - elif 460 <= r.status_code < 470: - error = "Subscription-related error: Please ensure that your subscription is active and has not expired." - elif 470 <= r.status_code < 480: - error = "Billing-related error: Please ensure you have enough credits to run this pipeline. " - elif 480 <= r.status_code < 490: - error = "Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access." - elif 490 <= r.status_code < 500: - error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." - else: - status_code = str(r.status_code) - error = ( - f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." - ) + status_code = r.status_code + error = get_error_from_status_code(status_code) logging.error(f"Error in request for {name} (Pipeline ID '{self.id}') - {r.status_code}: {error}") if response_version == "v1": return { "status": "failed", - "error": error, + "error": error.message, "elapsed_time": None, **kwargs, } return PipelineResponse( status=ResponseStatus.FAILED, - error={"error": error, "status": "ERROR"}, + error={"error": error.message, "status": "ERROR"}, elapsed_time=None, **kwargs, ) diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 3014adb1..56f18633 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -41,7 +41,7 @@ from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.agent.utils import process_variables from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class InspectorTarget(str, Enum): diff --git 
a/aixplain/processes/data_onboarding/onboard_functions.py b/aixplain/processes/data_onboarding/onboard_functions.py index 01a3fe9b..09d1b153 100644 --- a/aixplain/processes/data_onboarding/onboard_functions.py +++ b/aixplain/processes/data_onboarding/onboard_functions.py @@ -18,7 +18,7 @@ from aixplain.modules.dataset import Dataset from aixplain.modules.file import File from aixplain.modules.metadata import MetaData -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Text, Union from urllib.parse import urljoin @@ -203,11 +203,10 @@ def build_payload_dataset( "description": dataset.description, "function": dataset.function.value, "onboardingErrorsPolicy": error_handler.value, - "tags": dataset.tags, + "tags": dataset.tags or tags, "privacy": dataset.privacy.value, "license": {"typeId": dataset.license.value}, "refData": ref_data, - "tags": tags, "data": [], "input": [], "hypotheses": [], diff --git a/aixplain/utils/file_utils.py b/aixplain/utils/file_utils.py index 58781dfb..f2bb55bc 100644 --- a/aixplain/utils/file_utils.py +++ b/aixplain/utils/file_utils.py @@ -20,10 +20,9 @@ import aixplain.utils.config as config from aixplain.enums.license import License - +from aixplain.utils.request_utils import _request_with_retry from collections import defaultdict from pathlib import Path -from requests.adapters import HTTPAdapter, Retry from typing import Any, Optional, Text, Union, Dict, List from uuid import uuid4 from urllib.parse import urljoin, urlparse @@ -52,24 +51,6 @@ def save_file(download_url: Text, download_file_path: Optional[Any] = None) -> A return download_file_path -def _request_with_retry(method: Text, url: Text, **params) -> requests.Response: - """Wrapper around requests with Session to retry in case it fails - - Args: - method (Text): HTTP method, such as 'GET' or 'HEAD'. 
- url (Text): The URL of the resource to fetch. - **params: Params to pass to request function - - Returns: - requests.Response: Response object of the request - """ - session = requests.Session() - retries = Retry(total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]) - session.mount("https://", HTTPAdapter(max_retries=retries)) - response = session.request(method=method.upper(), url=url, **params) - return response - - def download_data(url_link, local_filename=None): if local_filename is None: local_filename = url_link.split("/")[-1] diff --git a/tests/test_utils.py b/tests/test_utils.py index 264538d5..3d93fe25 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin import logging from aixplain.utils import config @@ -12,7 +12,7 @@ def delete_asset(model_id, api_key): def delete_service_account(api_key): - delete_url = urljoin(config.BACKEND_URL, f"sdk/ecr/logout") + delete_url = urljoin(config.BACKEND_URL, "sdk/ecr/logout") logging.debug(f"URL: {delete_url}") headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} _ = _request_with_retry("post", delete_url, headers=headers) diff --git a/tests/unit/llm_test.py b/tests/unit/llm_test.py index 753a8f7a..5d188470 100644 --- a/tests/unit/llm_test.py +++ b/tests/unit/llm_test.py @@ -24,17 +24,17 @@ ), ( 475, - "Billing-related error: Please ensure you have enough credits to run this model. Details: An unspecified error occurred while processing your request.", + "Billing-related error: Please ensure you have enough credits to run this asset. Details: An unspecified error occurred while processing your request.", ), ( 485, - "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. 
Details: An unspecified error occurred while processing your request.", + "Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access. Details: An unspecified error occurred while processing your request.", ), ( 495, - "An unspecified error occurred while processing your request.", + "Validation-related error: Please verify the request payload and ensure it is correct. Details: An unspecified error occurred while processing your request.", ), - (501, "Status 501 - Unspecified error: An unspecified error occurred while processing your request."), + (501, "Unspecified Server Error (Status 501) Details: An unspecified error occurred while processing your request."), ], ) def test_run_async_errors(status_code, error_message): diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 65708d0c..f9e1a379 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -121,17 +121,17 @@ def test_failed_poll(): ), ( 475, - "Billing-related error: Please ensure you have enough credits to run this model. Details: An unspecified error occurred while processing your request.", + "Billing-related error: Please ensure you have enough credits to run this asset. Details: An unspecified error occurred while processing your request.", ), ( 485, - "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: An unspecified error occurred while processing your request.", + "Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access. Details: An unspecified error occurred while processing your request.", ), ( 495, - "An unspecified error occurred while processing your request.", + "Validation-related error: Please verify the request payload and ensure it is correct. 
Details: An unspecified error occurred while processing your request.", ), - (501, "Status 501 - Unspecified error: An unspecified error occurred while processing your request."), + (501, "Unspecified Server Error (Status 501) Details: An unspecified error occurred while processing your request."), ], ) def test_run_async_errors(status_code, error_message): diff --git a/tests/unit/pipeline_test.py b/tests/unit/pipeline_test.py index 7df95691..5f08eec3 100644 --- a/tests/unit/pipeline_test.py +++ b/tests/unit/pipeline_test.py @@ -35,12 +35,8 @@ def test_create_pipeline(): headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} ref_response = {"id": "12345"} mock.post(url, headers=headers, json=ref_response) - ref_pipeline = Pipeline( - id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY - ) - hyp_pipeline = PipelineFactory.create( - pipeline={"nodes": []}, name="Pipeline Test" - ) + ref_pipeline = Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) + hyp_pipeline = PipelineFactory.create(pipeline={"nodes": []}, name="Pipeline Test") assert hyp_pipeline.id == ref_pipeline.id assert hyp_pipeline.name == ref_pipeline.name @@ -58,19 +54,19 @@ def test_create_pipeline(): ), ( 475, - "{'error': 'Billing-related error: Please ensure you have enough credits to run this pipeline. 
', 'status': 'ERROR'}", + "{'error': 'Billing-related error: Please ensure you have enough credits to run this asset.', 'status': 'ERROR'}", ), ( 485, - "{'error': 'Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access.', 'status': 'ERROR'}", + "{'error': 'Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access.', 'status': 'ERROR'}", ), ( 495, - "{'error': 'Validation-related error: Please ensure all required fields are provided and correctly formatted.', 'status': 'ERROR'}", + "{'error': 'Validation-related error: Please verify the request payload and ensure it is correct.', 'status': 'ERROR'}", ), ( 501, - "{'error': 'Status 501: Unspecified error: An unspecified error occurred while processing your request.', 'status': 'ERROR'}", + "{'error': 'Unspecified Server Error (Status 501)', 'status': 'ERROR'}", ), ], ) @@ -107,14 +103,9 @@ def test_list_pipelines_error_response(): mock.post(url, headers=headers, json=error_response, status_code=400) with pytest.raises(Exception) as excinfo: - PipelineFactory.list( - query=query, page_number=page_number, page_size=page_size - ) + PipelineFactory.list(query=query, page_number=page_number, page_size=page_size) - assert ( - "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" - in str(excinfo.value) - ) + assert "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" in str(excinfo.value) def test_get_pipeline_error_response(): @@ -132,112 +123,7 @@ def test_get_pipeline_error_response(): with pytest.raises(Exception) as excinfo: PipelineFactory.get(pipeline_id=pipeline_id) - assert ( - "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. 
Status Code: 404" - in str(excinfo.value) - ) - - -@pytest.fixture -def mock_pipeline(): - return Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) - - -def test_run_async_success(mock_pipeline): - with requests_mock.Mocker() as mock: - execute_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}" - ) - success_response = PipelineResponse( - status=ResponseStatus.SUCCESS, url=execute_url - ) - mock.post(execute_url, json=success_response.__dict__, status_code=200) - - response = mock_pipeline.run_async(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_run_sync_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}" - ) - execute_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}" - ) - success_response = PipelineResponse(status=ResponseStatus.SUCCESS, url=poll_url) - poll_response = PipelineResponse( - status=ResponseStatus.SUCCESS, data={"output": "poll_result"} - ) - mock.post(execute_url, json=success_response.__dict__, status_code=200) - mock.get(poll_url, json=poll_response.__dict__, status_code=200) - response = mock_pipeline.run(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_poll_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}" - ) - poll_response = PipelineResponse( - status=ResponseStatus.SUCCESS, data={"output": "poll_result"} - ) - mock.get(poll_url, json=poll_response.__dict__, status_code=200) - - response = mock_pipeline.poll(poll_url=poll_url) - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - assert response.data["output"] == 
"poll_result" - - -@pytest.fixture -def mock_pipeline(): - return Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) - - -def test_run_async_success(mock_pipeline): - with requests_mock.Mocker() as mock: - execute_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}") - success_response = PipelineResponse(status=ResponseStatus.SUCCESS, url=execute_url) - mock.post(execute_url, json=success_response.__dict__, status_code=200) - - response = mock_pipeline.run_async(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_run_sync_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}") - execute_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}") - success_response = {"status": "SUCCESS", "url": poll_url, "completed": True} - poll_response = {"status": "SUCCESS", "data": {"output": "poll_result"}, "completed": True} - mock.post(execute_url, json=success_response, status_code=200) - mock.get(poll_url, json=poll_response, status_code=200) - response = mock_pipeline.run(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_poll_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}") - poll_response = PipelineResponse(status=ResponseStatus.SUCCESS, data={"output": "poll_result"}) - mock.get(poll_url, json=poll_response.__dict__, status_code=200) - - response = mock_pipeline.poll(poll_url=poll_url) - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - assert response.data["output"] == "poll_result" + assert "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. 
Status Code: 404" in str(excinfo.value) @pytest.fixture From 0c4edf4fd6165212a5101ee20a01d5aa135e3842 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Tue, 6 May 2025 15:54:02 +0300 Subject: [PATCH 18/62] ENG-1862:Added status to Tools and deployment check for Agent and TeamAgent (#455) * Added status to Models, Pipelines * added status to Tools and deployment check for Agent and TeamAgent * DeployableMixin added * removed the deployement status check as it is moved to backend * for pipeline and model if id given check status added * model params set for model tool explicitly from object * BUG-499: Fix status update in class when update failed * Set to previous status if failed to update --- aixplain/enums/asset_status.py | 1 + aixplain/factories/model_factory/utils.py | 5 +- aixplain/modules/agent/__init__.py | 19 ++--- aixplain/modules/agent/tool/__init__.py | 3 + .../agent/tool/custom_python_code_tool.py | 6 ++ aixplain/modules/agent/tool/model_tool.py | 57 +++++++++++--- aixplain/modules/agent/tool/pipeline_tool.py | 10 ++- .../agent/tool/python_interpreter_tool.py | 2 + aixplain/modules/agent/tool/sql_tool.py | 3 +- aixplain/modules/mixins.py | 74 +++++++++++++++++++ aixplain/modules/model/__init__.py | 12 ++- aixplain/modules/model/utility_model.py | 13 ++-- aixplain/modules/pipeline/asset.py | 37 ++++++---- aixplain/modules/team_agent/__init__.py | 16 +--- tests/functional/model/run_model_test.py | 5 ++ tests/unit/agent/model_tool_test.py | 5 +- .../mock_responses/list_models_response.json | 11 ++- tests/unit/model_test.py | 8 +- tests/unit/team_agent_test.py | 25 ++++++- 19 files changed, 241 insertions(+), 71 deletions(-) create mode 100644 aixplain/modules/mixins.py diff --git a/aixplain/enums/asset_status.py b/aixplain/enums/asset_status.py index 994212fb..3d1e4323 100644 --- a/aixplain/enums/asset_status.py +++ b/aixplain/enums/asset_status.py @@ -43,3 +43,4 @@ class AssetStatus(Text, Enum): COMPLETED = "completed" CANCELING = 
"canceling" CANCELED = "canceled" + DEPRECATED_DRAFT = "deprecated_draft" diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 418a32b9..05a4b4c9 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -10,6 +10,7 @@ from datetime import datetime from typing import Dict, Union, List, Optional, Tuple from urllib.parse import urljoin +from aixplain.enums import AssetStatus import requests @@ -61,7 +62,6 @@ def create_model_from_response(response: Dict) -> Model: for param in response["params"] ] input_params = model_params - if not code: if "version" in response and response["version"]: version_link = response["version"]["id"] @@ -74,6 +74,7 @@ def create_model_from_response(response: Dict) -> Model: else: raise Exception("Utility Model Error: Code not found") + status = AssetStatus(response.get("status", AssetStatus.DRAFT.value)) created_at = None if "createdAt" in response and response["createdAt"]: created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00")) @@ -96,7 +97,7 @@ def create_model_from_response(response: Dict) -> Model: version=response["version"]["id"], inputs=inputs, temperature=temperature, - status=response.get("status", AssetStatus.DRAFT), + status=status, ) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index b6920f5c..8f19d3ec 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -26,26 +26,23 @@ import time import traceback -from aixplain.utils.request_utils import _request_with_retry -from aixplain.enums.function import Function -from aixplain.enums.supplier import Supplier -from aixplain.enums.asset_status import AssetStatus -from aixplain.enums.storage_type import StorageType +from aixplain.utils.file_utils import _request_with_retry +from aixplain.enums import Function, Supplier, AssetStatus, StorageType, ResponseStatus from 
aixplain.modules.model import Model from aixplain.modules.agent.agent_task import AgentTask from aixplain.modules.agent.output_format import OutputFormat from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData -from aixplain.enums import ResponseStatus from aixplain.modules.agent.utils import process_variables from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin from aixplain.utils import config +from aixplain.modules.mixins import DeployableMixin -class Agent(Model): +class Agent(Model, DeployableMixin[Tool]): """Advanced AI system capable of performing tasks by leveraging specialized software tools and resources from aiXplain marketplace. Attributes: @@ -212,6 +209,8 @@ def run( poll_url = response["url"] end = time.time() result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + # if result.status == ResponseStatus.FAILED: + # raise Exception("Model failed to run with error: " + result.error_message) result_data = result.get("data") or {} return AgentResponse( status=ResponseStatus.SUCCESS, @@ -419,11 +418,5 @@ def save(self) -> None: """Save the Agent.""" self.update() - def deploy(self) -> None: - assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed." - assert self.status != AssetStatus.ONBOARDED, "Agent is already deployed." 
- self.status = AssetStatus.ONBOARDED - self.update() - def __repr__(self): return f"Agent(id={self.id}, name={self.name}, function={self.function})" diff --git a/aixplain/modules/agent/tool/__init__.py b/aixplain/modules/agent/tool/__init__.py index aefa093a..93dc269d 100644 --- a/aixplain/modules/agent/tool/__init__.py +++ b/aixplain/modules/agent/tool/__init__.py @@ -23,6 +23,7 @@ from abc import ABC from typing import Optional, Text from aixplain.utils import config +from aixplain.enums import AssetStatus class Tool(ABC): @@ -40,6 +41,7 @@ def __init__( description: Text, version: Optional[Text] = None, api_key: Optional[Text] = config.TEAM_API_KEY, + status: Optional[AssetStatus] = AssetStatus.DRAFT, **additional_info, ) -> None: """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. @@ -55,6 +57,7 @@ def __init__( self.version = version self.api_key = api_key self.additional_info = additional_info + self.status = status def to_dict(self): """Converts the tool to a dictionary.""" diff --git a/aixplain/modules/agent/tool/custom_python_code_tool.py b/aixplain/modules/agent/tool/custom_python_code_tool.py index d544d2b1..9d6363c6 100644 --- a/aixplain/modules/agent/tool/custom_python_code_tool.py +++ b/aixplain/modules/agent/tool/custom_python_code_tool.py @@ -24,6 +24,7 @@ from typing import Text, Union, Callable, Optional from aixplain.modules.agent.tool import Tool import logging +from aixplain.enums import AssetStatus class CustomPythonCodeTool(Tool): @@ -35,6 +36,7 @@ def __init__( """Custom Python Code Tool""" super().__init__(name=name or "", description=description, **additional_info) self.code = code + self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool self.validate() def to_dict(self): @@ -67,6 +69,10 @@ def validate(self): ), "Custom Python Code Tool Error: Tool description is required" assert self.code and self.code.strip() != "", 
"Custom Python Code Tool Error: Code is required" assert self.name and self.name.strip() != "", "Custom Python Code Tool Error: Name is required" + assert self.status in [ + AssetStatus.DRAFT, + AssetStatus.ONBOARDED, + ], "Custom Python Code Tool Error: Status must be DRAFT or ONBOARDED" def __repr__(self) -> Text: return f"CustomPythonCodeTool(name={self.name})" diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 6a945a15..9b073a84 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -22,8 +22,7 @@ """ from typing import Optional, Union, Text, Dict, List -from aixplain.enums.function import Function -from aixplain.enums.supplier import Supplier +from aixplain.enums import AssetStatus, Function, Supplier from aixplain.modules.agent.tool import Tool from aixplain.modules.model import Model @@ -83,8 +82,26 @@ def __init__( """ name = name or "" super().__init__(name=name, description=description, **additional_info) + status = AssetStatus.ONBOARDED if model is None else AssetStatus.DRAFT + model_id = model # if None, Set id to None as default + self.model_object = None # Store the actual model object for parameter access + + if isinstance(model, Model): + model_id = model.id + status = model.status + self.model_object = model # Store the Model object + elif isinstance(model, Text): + # get model from id + try: + self.model_object = self._get_model(model) # Store the Model object + model_id = self.model_object.id + status = self.model_object.status + except Exception: + raise Exception(f"Model Tool Unavailable. 
Make sure Model '{model}' exists or you have access to it.") + self.supplier = supplier - self.model = model + self.model = model_id + self.status = status self.function = function self.parameters = parameters self.validate() @@ -109,6 +126,7 @@ def to_dict(self) -> Dict: "version": self.version if self.version else None, "assetId": self.model.id if self.model is not None and isinstance(self.model, Model) else self.model, "parameters": self.parameters, + "status": self.status, } def validate(self) -> None: @@ -123,7 +141,6 @@ def validate(self) -> None: - If the description is empty, it sets the description to the function description or the model description. """ from aixplain.enums import FunctionInputOutput - from aixplain.factories.model_factory import ModelFactory assert ( self.function is not None or self.model is not None @@ -145,7 +162,7 @@ def validate(self) -> None: if self.model is not None: if isinstance(self.model, Text) is True: try: - self.model = ModelFactory.get(self.model, api_key=self.api_key) + self.model = self._get_model() except Exception: raise Exception(f"Model Tool Unavailable. 
Make sure Model '{self.model}' exists or you have access to it.") self.function = self.model.function @@ -166,8 +183,22 @@ def validate(self) -> None: self.name = self.name if self.name else set_tool_name(self.function, self.supplier, self.model) def get_parameters(self) -> Dict: + # If parameters were not explicitly provided, get them from the model + if ( + self.parameters is None + and self.model_object is not None # noqa: W503 + and hasattr(self.model_object, "model_params") # noqa: W503 + and self.model_object.model_params is not None # noqa: W503 + ): + return self.model_object.model_params.to_list() return self.parameters + def _get_model(self, model_id: Text = None): + from aixplain.factories.model_factory import ModelFactory + + model_id = model_id or self.model + return ModelFactory.get(model_id, api_key=self.api_key) + def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) -> Optional[List[Dict]]: """Validates and formats the parameters for the tool. @@ -182,8 +213,12 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) """ if received_parameters is None: # Get default parameters if none provided - if self.model is not None and self.model.model_params is not None: - return self.model.model_params.to_list() + if ( + self.model_object is not None + and hasattr(self.model_object, "model_params") # noqa: W503 + and self.model_object.model_params is not None # noqa: W503 + ): + return self.model_object.model_params.to_list() elif self.function is not None: function_params = self.function.get_parameters() if function_params is not None: @@ -192,8 +227,12 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) # Get expected parameters expected_params = None - if self.model is not None and self.model.model_params is not None: - expected_params = self.model.model_params + if ( + self.model_object is not None + and hasattr(self.model_object, "model_params") # noqa: W503 + and 
self.model_object.model_params is not None # noqa: W503 + ): + expected_params = self.model_object.model_params elif self.function is not None: expected_params = self.function.get_parameters() diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index 0de83916..eccc80f7 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -24,6 +24,7 @@ from aixplain.modules.agent.tool import Tool from aixplain.modules.pipeline import Pipeline +from aixplain.enums import AssetStatus class PipelineTool(Tool): @@ -50,6 +51,7 @@ def __init__( name = name or "" super().__init__(name=name, description=description, **additional_info) + self.status = AssetStatus.DRAFT self.pipeline = pipeline self.validate() @@ -59,8 +61,12 @@ def to_dict(self): "name": self.name, "description": self.description, "type": "pipeline", + "status": self.status, } + def __repr__(self) -> Text: + return f"PipelineTool(name={self.name}, pipeline={self.pipeline})" + def validate(self): from aixplain.factories.pipeline_factory import PipelineFactory @@ -76,6 +82,4 @@ def validate(self): if self.name.strip() == "": self.name = pipeline_obj.name - - def __repr__(self) -> Text: - return f"PipelineTool(name={self.name}, pipeline={self.pipeline})" + self.status = pipeline_obj.status diff --git a/aixplain/modules/agent/tool/python_interpreter_tool.py b/aixplain/modules/agent/tool/python_interpreter_tool.py index 42621c45..5e1e45e5 100644 --- a/aixplain/modules/agent/tool/python_interpreter_tool.py +++ b/aixplain/modules/agent/tool/python_interpreter_tool.py @@ -22,6 +22,7 @@ """ from aixplain.modules.agent.tool import Tool +from aixplain.enums import AssetStatus from typing import Text @@ -32,6 +33,7 @@ def __init__(self, **additional_info) -> None: """Python Interpreter Tool""" description = "A Python shell. Use this to execute python commands. Input should be a valid python command." 
super().__init__(name="Python Interpreter", description=description, **additional_info) + self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool def to_dict(self): return { diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index 56cf116c..f2262ee9 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -27,7 +27,7 @@ import numpy as np from typing import Text, Optional, Dict, List, Union import sqlite3 - +from aixplain.enums import AssetStatus from aixplain.modules.agent.tool import Tool @@ -284,6 +284,7 @@ def __init__( self.schema = schema self.tables = tables if isinstance(tables, list) else [tables] if tables else None self.enable_commit = enable_commit + self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool def to_dict(self) -> Dict[str, Text]: return { diff --git a/aixplain/modules/mixins.py b/aixplain/modules/mixins.py new file mode 100644 index 00000000..0b402cf2 --- /dev/null +++ b/aixplain/modules/mixins.py @@ -0,0 +1,74 @@ +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Author: Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: November 25th 2024 +Description: + Mixins for common functionality across different asset types +""" +from abc import ABC +from typing import TypeVar, Generic +from aixplain.enums import AssetStatus + +T = TypeVar("T") + + +class DeployableMixin(ABC, Generic[T]): + """A mixin that provides common deployment-related functionality for assets. + + This mixin provides methods for: + 1. Filtering items that are not onboarded + 2. Validating if an asset is ready to be deployed + 3. Deploying an asset + + Classes that inherit from this mixin should: + 1. Implement _validate_deployment_readiness to call the parent implementation with their specific asset type + 2. Optionally override deploy() if they need special deployment handling + """ + + def _validate_deployment_readiness(self) -> None: + """Validate if the asset is ready to be deployed. + + Args: + asset_type (str): Type of asset being validated (e.g. "Agent", "Team Agent", "Pipeline") + items (Optional[List[T]], optional): List of items to validate (e.g. tools for Agent, agents for TeamAgent) + + Raises: + ValueError: If the asset is not ready to be deployed + """ + asset_type = self.__class__.__name__ + if self.status == AssetStatus.ONBOARDED: + raise ValueError(f"{asset_type} is already deployed.") + + if self.status != AssetStatus.DRAFT: + raise ValueError(f"{asset_type} must be in DRAFT status to be deployed.") + + def deploy(self) -> None: + """Deploy the asset. + + This method validates that the asset is ready to be deployed and updates its status to ONBOARDED. + Classes that need special deployment handling should override this method. 
+ + Raises: + ValueError: If the asset is not ready to be deployed + """ + self._validate_deployment_readiness() + previous_status = self.status + try: + self.status = AssetStatus.ONBOARDED + self.update() + except Exception as e: + self.status = previous_status + raise Exception(f"Error deploying because of backend error: {e}") from e diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 4729dd42..49cf4591 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -34,6 +34,7 @@ from aixplain.modules.model.response import ModelResponse from aixplain.enums.response_status import ResponseStatus from aixplain.modules.model.model_parameters import ModelParameters +from aixplain.enums import AssetStatus class Model(Asset): @@ -72,6 +73,7 @@ def __init__( input_params: Optional[Dict] = None, output_params: Optional[Dict] = None, model_params: Optional[Dict] = None, + status: Optional[AssetStatus] = AssetStatus.ONBOARDED, # default status for models is ONBOARDED **additional_info, ) -> None: """Model Init @@ -89,6 +91,7 @@ def __init__( input_params (Dict, optional): input parameters for the function. output_params (Dict, optional): output parameters for the function. model_params (Dict, optional): parameters for the function. + status (AssetStatus, optional): status of the model. Defaults to None. 
**additional_info: Any additional Model info to be saved """ super().__init__(id, name, description, supplier, version, cost=cost) @@ -102,6 +105,12 @@ def __init__( self.input_params = input_params self.output_params = output_params self.model_params = ModelParameters(model_params) if model_params else None + if isinstance(status, str): + try: + status = AssetStatus(status) + except Exception: + status = AssetStatus.ONBOARDED + self.status = status def to_dict(self) -> Dict: """Get the model info as a Dictionary @@ -119,6 +128,7 @@ def to_dict(self) -> Dict: "input_params": self.input_params, "output_params": self.output_params, "model_params": self.model_params.to_dict(), + "status": self.status, } def get_parameters(self) -> ModelParameters: @@ -309,7 +319,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): Returns: FinetuneStatus: The status of the FineTune model. """ - from aixplain.enums.asset_status import AssetStatus + from aixplain.enums import AssetStatus from aixplain.modules.finetune.status import FinetuneStatus headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 043571c2..1cbbf3da 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -21,7 +21,7 @@ import logging import warnings from aixplain.enums import Function, Supplier, DataType -from aixplain.enums.asset_status import AssetStatus +from aixplain.enums import AssetStatus from aixplain.modules.model import Model from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry @@ -29,6 +29,7 @@ from dataclasses import dataclass from typing import Callable, Union, Optional, List, Text, Dict from urllib.parse import urljoin +from aixplain.modules.mixins import DeployableMixin @dataclass @@ -88,7 +89,7 @@ def decorator(func): return decorator -class UtilityModel(Model): +class 
UtilityModel(Model, DeployableMixin): """Ready-to-use Utility Model. Note: Non-deployed utility models (status=DRAFT) will expire after 24 hours after creation. @@ -107,6 +108,7 @@ class UtilityModel(Model): function (Function, optional): model AI function. Defaults to None. is_subscribed (bool, optional): Is the user subscribed. Defaults to False. cost (Dict, optional): model price. Defaults to None. + status (AssetStatus, optional): status of the model. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Model info to be saved """ @@ -155,6 +157,7 @@ def __init__( function=function, is_subscribed=is_subscribed, api_key=api_key, + status=status, **additional_info, ) self.url = config.MODELS_RUN_URL @@ -274,9 +277,3 @@ def delete(self): message = f"Utility Model Deletion Error: {response}" logging.error(message) raise Exception(f"{message}") - - def deploy(self) -> None: - assert self.status == AssetStatus.DRAFT, "Utility Model must be in draft status to be deployed." - assert self.status != AssetStatus.ONBOARDED, "Utility Model is already deployed." - self.status = AssetStatus.ONBOARDED - self.update() diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 63eac9b7..cc234337 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -1,7 +1,7 @@ __author__ = "aiXplain" """ -Copyright 2022 The aiXplain SDK authors +Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,28 +15,28 @@ See the License for the specific language governing permissions and limitations under the License. 
-Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli -Date: September 1st 2022 +Author: Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: November 25th 2024 Description: - Pipeline Class + Pipeline Asset Class """ import time import json import os import logging -from aixplain.enums.asset_status import AssetStatus -from aixplain.enums.response_status import ResponseStatus -from aixplain.modules.asset import Asset +from aixplain.enums import AssetStatus, ResponseStatus +from aixplain.modules import Asset from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin from aixplain.modules.pipeline.response import PipelineResponse +from aixplain.modules.mixins import DeployableMixin from aixplain.exceptions import get_error_from_status_code -class Pipeline(Asset): +class Pipeline(Asset, DeployableMixin): """Representing a custom pipeline that was created on the aiXplain Platform Attributes: @@ -46,6 +46,7 @@ class Pipeline(Asset): url (Text, optional): running URL of platform. Defaults to config.BACKEND_URL. supplier (Text, optional): Pipeline supplier. Defaults to "aiXplain". version (Text, optional): version of the pipeline. Defaults to "1.0". + status (AssetStatus, optional): Pipeline status. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Pipeline info to be saved """ @@ -581,13 +582,23 @@ def save( raise Exception(e) def deploy(self, api_key: Optional[Text] = None) -> None: - """Deploy the Pipeline.""" - assert self.status == "draft", "Pipeline Deployment Error: Pipeline must be in draft status." - assert self.status != "onboarded", "Pipeline Deployment Error: Pipeline must be onboarded." + """Deploy the Pipeline. + This method overrides the deploy method in DeployableMixin to handle + Pipeline-specific deployment functionality. 
+ + Args: + api_key (Optional[Text], optional): Team API Key to deploy the Pipeline. Defaults to None. + """ + self._validate_deployment_readiness() pipeline = self.to_dict() - self.update(pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name) - self.status = AssetStatus.ONBOARDED + previous_status = self.status + try: + self.status = AssetStatus.ONBOARDED + self.update(pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name) + except Exception as e: + self.status = previous_status + raise Exception(f"Error deploying because of backend error: {e}") from e def __repr__(self): return f"Pipeline(id={self.id}, name={self.name})" diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 56f18633..449e9549 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -41,7 +41,8 @@ from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.agent.utils import process_variables from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.file_utils import _request_with_retry +from aixplain.modules.mixins import DeployableMixin class InspectorTarget(str, Enum): @@ -53,7 +54,7 @@ def __str__(self): return self._value_ -class TeamAgent(Model): +class TeamAgent(Model, DeployableMixin[Agent]): """Advanced AI system capable of using multiple agents to perform a variety of tasks. Attributes: @@ -416,14 +417,3 @@ def update(self) -> None: else: error_msg = f"Team Agent Update Error (HTTP {r.status_code}): {resp}" raise Exception(error_msg) - - def save(self) -> None: - """Save the Team Agent.""" - self.update() - - def deploy(self) -> None: - """Deploy the Team Agent.""" - assert self.status == AssetStatus.DRAFT, "Team Agent Deployment Error: Team Agent must be in draft status." 
- assert self.status != AssetStatus.ONBOARDED, "Team Agent Deployment Error: Team Agent must be onboarded." - self.status = AssetStatus.ONBOARDED - self.update() diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 1aed18df..25a28640 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -237,3 +237,8 @@ def test_index_model_air_with_image(): assert "hurricane" in second_record.lower() index_model.delete() + + import os + + if os.path.exists("hurricane.jpeg"): + os.remove("hurricane.jpeg") diff --git a/tests/unit/agent/model_tool_test.py b/tests/unit/agent/model_tool_test.py index 869979e1..5dfc736e 100644 --- a/tests/unit/agent/model_tool_test.py +++ b/tests/unit/agent/model_tool_test.py @@ -4,11 +4,10 @@ from unittest.mock import MagicMock from aixplain.modules.agent.tool.model_tool import ModelTool -from aixplain.enums.function import Function -from aixplain.enums.supplier import Supplier from aixplain.modules.model import Model from aixplain.modules.model.model_parameters import ModelParameters from aixplain.base.parameters import Parameter +from aixplain.enums import AssetStatus, Function, Supplier @pytest.fixture @@ -19,6 +18,7 @@ def mock_model(): model.supplier = Supplier.AIXPLAIN model.name = "Test Model" model.description = "Test Model Description" + model.status = AssetStatus.ONBOARDED model.model_params = ModelParameters( { "sourcelanguage": {"name": "sourcelanguage", "required": True}, @@ -107,6 +107,7 @@ def test_to_dict(mocker, mock_model, mock_model_factory): "version": None, "assetId": "test_model_id", "parameters": [{"name": "sourcelanguage", "value": "en"}, {"name": "targetlanguage", "value": "es"}], + "status": mock_model.status.value, } result = tool.to_dict() diff --git a/tests/unit/mock_responses/list_models_response.json b/tests/unit/mock_responses/list_models_response.json index de2cb7ba..c927ba0b 100644 --- 
a/tests/unit/mock_responses/list_models_response.json +++ b/tests/unit/mock_responses/list_models_response.json @@ -1,4 +1,3 @@ - { "total": 15, "pageTotal": 10, @@ -7,6 +6,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Transformerml6x6)", "serviceName": "Machine Translation", + "status": "onboarded", "supplier": { "id": "Nvidia", "name": "Nvidia" @@ -129,6 +129,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Transformer12x2)", "serviceName": "Machine Translation", + "status": "onboarded", "supplier": { "id": "Nvidia", "name": "Nvidia" @@ -250,6 +251,7 @@ { "id": "test_asset_id", "name": "Translate from English to French", + "status": "onboarded", "supplier": { "id": "eBay", "name": "eBay" @@ -376,6 +378,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Canada)", "serviceName": "Amazon Translate", + "status": "onboarded", "supplier": { "id": "AWS", "name": "AWS" @@ -507,6 +510,7 @@ "id": "test_asset_id", "name": "Translate from English to French", "serviceName": "Cloud Translation", + "status": "onboarded", "supplier": { "id": "Google", "name": "Google" @@ -632,6 +636,7 @@ { "id": "test_asset_id", "name": "Translate from English to French", + "status": "onboarded", "supplier": { "id": "ModernMT", "name": "ModernMT" @@ -758,6 +763,7 @@ "id": "test_asset_id", "name": "Translate from English to French", "serviceName": "Cognitive Services", + "status": "onboarded", "supplier": { "id": "Azure", "name": "Azure" @@ -884,6 +890,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Canada)", "serviceName": "Cognitive Services", + "status": "onboarded", "supplier": { "id": "Azure", "name": "Azure" @@ -1010,6 +1017,7 @@ { "id": "test_asset_id", "name": "Translate from English to French", + "status": "onboarded", "supplier": { "id": "AppTek", "name": "AppTek" @@ -1136,6 +1144,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Transformerml24x6)", "serviceName": "Machine 
Translation", + "status": "onboarded", "supplier": { "id": "Nvidia", "name": "Nvidia" diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index f9e1a379..943d501d 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -204,6 +204,7 @@ def test_get_model_from_ids(): { "id": "test-model-id-1", "name": "Test Model 1", + "status": "onboarded", "description": "Test Description 1", "function": {"id": "text-generation"}, "supplier": {"id": "aiXplain"}, @@ -214,6 +215,7 @@ def test_get_model_from_ids(): { "id": "test-model-id-2", "name": "Test Model 2", + "status": "onboarded", "description": "Test Description 2", "function": {"id": "text-generation"}, "supplier": {"id": "aiXplain"}, @@ -237,9 +239,9 @@ def test_list_models_error(): with pytest.raises(Exception) as excinfo: ModelFactory.list(model_ids=model_ids, function=Function.TEXT_GENERATION, api_key=config.AIXPLAIN_API_KEY) - assert ( - str(excinfo.value) - == "Cannot filter by function, suppliers, source languages, target languages, is finetunable, ownership, sort by when using model ids" + assert str(excinfo.value) == ( + "Cannot filter by function, suppliers, " + "source languages, target languages, is finetunable, ownership, sort by when using model ids" ) with pytest.raises(Exception) as excinfo: diff --git a/tests/unit/team_agent_test.py b/tests/unit/team_agent_test.py index 2b06043e..e84e2e34 100644 --- a/tests/unit/team_agent_test.py +++ b/tests/unit/team_agent_test.py @@ -1,7 +1,7 @@ import pytest import requests_mock from urllib.parse import urljoin -from unittest.mock import patch +from unittest.mock import patch, Mock from aixplain.enums.asset_status import AssetStatus from aixplain.factories import TeamAgentFactory @@ -181,7 +181,7 @@ def test_create_team_agent(mock_model_factory_get): "role": "Test Agent Role", "teamId": "123", "version": "1.0", - "status": "draft", + "status": "onboarded", "llmId": "6646261c6eb563165658bbb1", "pricing": {"currency": "USD", "value": 0.0}, 
"assets": [ @@ -502,3 +502,24 @@ def get_mock(agent_id): agent1 = next((agent for agent in team_agent.agents if agent.id == "agent1"), None) assert agent1 is not None assert agent1.tasks[0].dependencies[0].name == "Test Task 2" + + +def test_deploy_team_agent(): + # Create a mock agent with ONBOARDED status + mock_agent = Mock() + mock_agent.id = "agent-id" + mock_agent.name = "Test Agent" + mock_agent.status = AssetStatus.ONBOARDED + + # Create the team agent + team_agent = TeamAgent(id="team-agent-id", name="Test Team Agent", agents=[mock_agent], status=AssetStatus.DRAFT) + + # Mock the update method + team_agent.update = Mock() + + # Deploy the team agent + team_agent.deploy() + + # Verify that status was updated and update was called + assert team_agent.status == AssetStatus.ONBOARDED + team_agent.update.assert_called_once() From 84baed8b98f0802f222fc294c680ffeb1cd622c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Fri, 9 May 2025 11:37:59 +0300 Subject: [PATCH 19/62] Fix: BUG-543 SQL Tool upload db issue (#519) --- aixplain/factories/agent_factory/__init__.py | 5 +++-- aixplain/modules/agent/tool/sql_tool.py | 2 +- aixplain/v2/agent.py | 2 ++ .../functional/agent/agent_functional_test.py | 21 ++++++++++++++----- tests/unit/agent/agent_factory_utils_test.py | 13 ++++++++---- tests/unit/agent/sql_tool_test.py | 14 ++++++++----- 6 files changed, 40 insertions(+), 17 deletions(-) diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 9440ff81..156f1b54 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -177,7 +177,9 @@ def create_model_tool( supplier = supplier_ break assert isinstance(supplier, Supplier), f"Supplier {supplier} is not a valid supplier" - return ModelTool(function=function, supplier=supplier, model=model, description=description, parameters=parameters) + return ModelTool( + function=function, 
supplier=supplier, model=model, name=name, description=description, parameters=parameters + ) @classmethod def create_pipeline_tool( @@ -278,7 +280,6 @@ def create_sql_tool( base_name = os.path.splitext(os.path.basename(source))[0] db_path = os.path.join(os.path.dirname(source), f"{base_name}.db") table_name = tables[0] if tables else None - try: # Create database from CSV schema = create_database_from_csv(source, db_path, table_name) diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index f2262ee9..5b3f4178 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -285,6 +285,7 @@ def __init__( self.tables = tables if isinstance(tables, list) else [tables] if tables else None self.enable_commit = enable_commit self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool + self.validate() # to upload the database def to_dict(self) -> Dict[str, Text]: return { @@ -306,7 +307,6 @@ def validate(self): raise SQLToolError("Description is required") if not self.database: raise SQLToolError("Database must be provided") - # Handle database validation if not ( str(self.database).startswith("s3://") diff --git a/aixplain/v2/agent.py b/aixplain/v2/agent.py index eadb155f..a493f723 100644 --- a/aixplain/v2/agent.py +++ b/aixplain/v2/agent.py @@ -131,6 +131,7 @@ def create_custom_python_code_tool( @classmethod def create_sql_tool( cls, + name: str, description: str, source: str, source_type: str, @@ -172,6 +173,7 @@ def create_sql_tool( from aixplain.factories import AgentFactory return AgentFactory.create_sql_tool( + name=name, description=description, source=source, source_type=source_type, diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 73bab1fe..10569440 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -354,6 
+354,7 @@ def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agent @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents + agent = None try: import os @@ -362,6 +363,7 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): f.write("") tool = AgentFactory.create_sql_tool( + name="Teste", description="Execute an SQL query and return the result", source="ftest.db", source_type="sqlite", @@ -394,11 +396,13 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): assert "eve" in str(response["data"]["output"]).lower() finally: os.remove("ftest.db") - agent.delete() + if agent: + agent.delete() @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents + agent = None try: import os import pandas as pd @@ -424,7 +428,11 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): # Create SQL tool from CSV tool = AgentFactory.create_sql_tool( - description="Execute SQL queries on employee data", source="test.csv", source_type="csv", tables=["employees"] + name="CSV Tool Test", + description="Execute SQL queries on employee data", + source="test.csv", + source_type="csv", + tables=["employees"], ) # Verify tool setup @@ -470,9 +478,12 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): finally: # Cleanup - os.remove("test.csv") - os.remove("test.db") - agent.delete() + if agent: + agent.delete() + if os.path.exists("test.csv"): + os.remove("test.csv") + if os.path.exists("test.db"): + os.remove("test.db") @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) diff --git a/tests/unit/agent/agent_factory_utils_test.py b/tests/unit/agent/agent_factory_utils_test.py index c0104f5c..f1b4da70 100644 --- a/tests/unit/agent/agent_factory_utils_test.py +++ 
b/tests/unit/agent/agent_factory_utils_test.py @@ -12,6 +12,7 @@ from aixplain.modules.agent import Agent from aixplain.modules.agent.agent_task import AgentTask from aixplain.factories import ModelFactory, PipelineFactory +import os @pytest.fixture @@ -143,7 +144,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): "name": "Test SQL", "description": "Test SQL", "parameters": [ - {"name": "database", "value": "test_db"}, + {"name": "database", "value": "test_db.db"}, {"name": "schema", "value": "public"}, {"name": "tables", "value": "table1,table2"}, {"name": "enable_commit", "value": True}, @@ -153,7 +154,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): { "name": "Test SQL", "description": "Test SQL", - "database": "test_db", + "database": "test_db.db", "schema": "public", "tables": ["table1", "table2"], "enable_commit": True, @@ -166,7 +167,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): "name": "Test SQL", "description": "Test SQL with string enable_commit", "parameters": [ - {"name": "database", "value": "test_db"}, + {"name": "database", "value": "test_db.db"}, {"name": "schema", "value": "public"}, {"name": "tables", "value": "table1"}, {"name": "enable_commit", "value": True}, @@ -176,7 +177,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): { "name": "Test SQL", "description": "Test SQL with string enable_commit", - "database": "test_db", + "database": "test_db.db", "schema": "public", "tables": ["table1"], "enable_commit": True, @@ -192,6 +193,8 @@ def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock "aixplain.modules.model.utils.parse_code_decorated", return_value=("print('Hello World')", [], "Test description", "test_name"), ) + mocker.patch("os.path.exists", lambda path: True if path == "test_db.db" else os.path.exists(path)) + mocker.patch("aixplain.factories.file_factory.FileFactory.upload", return_value="s3://mocked-file-path/test_db.db") if tool_dict["type"] == 
"pipeline": mocker.patch.object( PipelineFactory, @@ -205,6 +208,8 @@ def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock for attr, value in expected_attrs.items(): if attr == "model": assert tool.model == mock_model + elif attr == "database" and value == "test_db.db": + assert getattr(tool, attr) == "s3://mocked-file-path/test_db.db" else: assert getattr(tool, attr) == value diff --git a/tests/unit/agent/sql_tool_test.py b/tests/unit/agent/sql_tool_test.py index 1a1dbb84..9f03a279 100644 --- a/tests/unit/agent/sql_tool_test.py +++ b/tests/unit/agent/sql_tool_test.py @@ -59,7 +59,9 @@ def test_create_sql_tool(mocker, tmp_path): ) assert isinstance(tool, SQLTool) assert tool.description == "Test" - assert tool.database == db_path + assert os.path.basename(db_path) in os.path.basename(tool.database) + assert tool.database.startswith("s3://") + assert tool.database.endswith(".db") assert tool.schema == "test" assert tool.tables == ["test", "test2"] @@ -78,7 +80,7 @@ def test_create_sql_tool(mocker, tmp_path): tool_dict = tool.to_dict() assert tool_dict["description"] == "Test" assert tool_dict["parameters"] == [ - {"name": "database", "value": db_path}, + {"name": "database", "value": tool.database}, {"name": "schema", "value": "test"}, {"name": "tables", "value": "test,test2"}, {"name": "enable_commit", "value": False}, @@ -89,7 +91,8 @@ def test_create_sql_tool(mocker, tmp_path): mocker.patch("os.path.exists", return_value=True) mocker.patch("aixplain.modules.agent.tool.sql_tool.get_table_schema", return_value="CREATE TABLE test (id INTEGER)") tool.validate() - assert tool.database == "s3://test.db" + assert tool.database.startswith("s3://") + assert tool.database.endswith(".db") def test_create_database_from_csv(tmp_path): @@ -225,7 +228,8 @@ def test_create_sql_tool_with_schema_inference(tmp_path, mocker): tool.validate() assert tool.schema == schema assert tool.tables == ["test"] - assert tool.database == "s3://test.db" + assert 
tool.database.startswith("s3://") + assert tool.database.endswith(".db") def test_create_sql_tool_from_csv_with_warnings(tmp_path, mocker): @@ -283,7 +287,7 @@ def test_create_sql_tool_from_csv(tmp_path): assert isinstance(tool, SQLTool) assert tool.description == "Test" assert tool.database.endswith(".db") - assert os.path.exists(tool.database) + assert tool.database.startswith("s3://") # Test schema and table inference during validation try: From de6a1c4739594d474e61ae9fbbccbe717ec5ddd5 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Fri, 9 May 2025 14:08:54 -0300 Subject: [PATCH 20/62] ENG-2028: model streaming (#506) * Enabling streaming * ModelStreamer * Adding debug logging messages * Streaming service and tests * Parse new JSON content * Remove full_content from model streamer --- aixplain/factories/model_factory/utils.py | 1 + aixplain/modules/model/__init__.py | 27 ++++++++++++++++-- aixplain/modules/model/llm_model.py | 13 ++++++--- .../modules/model/model_response_streamer.py | 28 +++++++++++++++++++ aixplain/modules/model/utils.py | 8 +++++- tests/functional/model/run_model_test.py | 21 ++++++++++++-- tests/unit/model_test.py | 20 +++++++++++-- 7 files changed, 106 insertions(+), 12 deletions(-) create mode 100644 aixplain/modules/model/model_response_streamer.py diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 589bdf68..a5f489be 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -98,6 +98,7 @@ def create_model_from_response(response: Dict) -> Model: version=response["version"]["id"], inputs=inputs, temperature=temperature, + supports_streaming=response.get("supportsStreaming", False), status=status, ) diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 49cf4591..92e48990 100644 --- a/aixplain/modules/model/__init__.py +++ 
b/aixplain/modules/model/__init__.py @@ -25,6 +25,7 @@ import traceback from aixplain.enums import Supplier, Function from aixplain.modules.asset import Asset +from aixplain.modules.model.model_response_streamer import ModelResponseStreamer from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config from urllib.parse import urljoin @@ -56,6 +57,7 @@ class Model(Asset): input_params (ModelParameters, optional): input parameters for the function. output_params (Dict, optional): output parameters for the function. model_params (ModelParameters, optional): parameters for the function. + supports_streaming (bool, optional): whether the model supports streaming. Defaults to False. """ def __init__( @@ -73,6 +75,7 @@ def __init__( input_params: Optional[Dict] = None, output_params: Optional[Dict] = None, model_params: Optional[Dict] = None, + supports_streaming: bool = False, status: Optional[AssetStatus] = AssetStatus.ONBOARDED, # default status for models is ONBOARDED **additional_info, ) -> None: @@ -91,6 +94,7 @@ def __init__( input_params (Dict, optional): input parameters for the function. output_params (Dict, optional): output parameters for the function. model_params (Dict, optional): parameters for the function. + supports_streaming (bool, optional): whether the model supports streaming. Defaults to False. status (AssetStatus, optional): status of the model. Defaults to None. 
**additional_info: Any additional Model info to be saved """ @@ -105,6 +109,7 @@ def __init__( self.input_params = input_params self.output_params = output_params self.model_params = ModelParameters(model_params) if model_params else None + self.supports_streaming = supports_streaming if isinstance(status, str): try: status = AssetStatus(status) @@ -232,6 +237,19 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: completed=False, ) + def run_stream( + self, + data: Union[Text, Dict], + parameters: Optional[Dict] = None, + ) -> ModelResponseStreamer: + assert self.supports_streaming, f"Model '{self.name} ({self.id})' does not support streaming" + payload = build_payload(data=data, parameters=parameters, stream=True) + url = f"{self.url}/{self.id}".replace("api/v1/execute", "api/v2/execute") + logging.debug(f"Model Run Stream: Start service for {url} - {payload}") + headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} + r = _request_with_retry("post", url, headers=headers, data=payload, stream=True) + return ModelResponseStreamer(r.iter_lines(decode_unicode=True)) + def run( self, data: Union[Text, Dict], @@ -239,7 +257,8 @@ def run( timeout: float = 300, parameters: Optional[Dict] = None, wait_time: float = 0.5, - ) -> ModelResponse: + stream: bool = False, + ) -> Union[ModelResponse, ModelResponseStreamer]: """Runs a model call. Args: @@ -248,10 +267,12 @@ def run( timeout (float, optional): total polling time. Defaults to 300. parameters (Dict, optional): optional parameters to the model. Defaults to None. wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. - + stream (bool, optional): whether the model supports streaming. Defaults to False. 
Returns: - Dict: parsed output from model + Union[ModelResponse, ModelStreamer]: parsed output from model """ + if stream: + return self.run_stream(data=data, parameters=parameters) start = time.time() payload = build_payload(data=data, parameters=parameters) url = f"{self.url}/{self.id}".replace("api/v1/execute", "api/v2/execute") diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index 0b8e6cf8..546be1b8 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -25,6 +25,7 @@ import traceback from aixplain.enums import Function, Supplier from aixplain.modules.model import Model +from aixplain.modules.model.model_response_streamer import ModelResponseStreamer from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config from typing import Union, Optional, List, Text, Dict @@ -108,7 +109,8 @@ def run( timeout: float = 300, parameters: Optional[Dict] = None, wait_time: float = 0.5, - ) -> ModelResponse: + stream: bool = False, + ) -> Union[ModelResponse, ModelResponseStreamer]: """Synchronously running a Large Language Model (LLM) model. Args: @@ -123,9 +125,9 @@ def run( timeout (float, optional): total polling time. Defaults to 300. parameters (Dict, optional): optional parameters to the model. Defaults to None. wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. - + stream (bool, optional): whether the model supports streaming. Defaults to False. 
Returns: - Dict: parsed output from model + Union[ModelResponse, ModelStreamer]: parsed output from model """ start = time.time() parameters = parameters or {} @@ -141,10 +143,13 @@ def run( parameters.setdefault("max_tokens", max_tokens) parameters.setdefault("top_p", top_p) + if stream: + return self.run_stream(data=data, parameters=parameters) + payload = build_payload(data=data, parameters=parameters) logging.info(payload) url = f"{self.url}/{self.id}".replace("/api/v1/execute", "/api/v2/execute") - logging.debug(f"Model Run Sync: Start service for {name} - {url}") + logging.debug(f"Model Run Sync: Start service for {name} - {url} - {payload}") response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) if response["status"] == "IN_PROGRESS": try: diff --git a/aixplain/modules/model/model_response_streamer.py b/aixplain/modules/model/model_response_streamer.py new file mode 100644 index 00000000..f1fef5f8 --- /dev/null +++ b/aixplain/modules/model/model_response_streamer.py @@ -0,0 +1,28 @@ +import json +from typing import Iterator + +from aixplain.modules.model.response import ModelResponse, ResponseStatus + + +class ModelResponseStreamer: + def __init__(self, iterator: Iterator): + self.iterator = iterator + self.status = ResponseStatus.IN_PROGRESS + + def __next__(self): + """ + Returns the next chunk of the response. 
+ """ + line = next(self.iterator).replace("data: ", "") + try: + data = json.loads(line) + except json.JSONDecodeError: + data = {"data": line} + content = data.get("data", "") + if content == "[DONE]": + self.status = ResponseStatus.SUCCESS + content = "" + return ModelResponse(status=self.status, data=content) + + def __iter__(self): + return self diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 6f3f9319..cc68a347 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -7,12 +7,17 @@ from aixplain.exceptions import get_error_from_status_code -def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None): +def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None, stream: Optional[bool] = None): from aixplain.factories import FileFactory if parameters is None: parameters = {} + if stream is not None: + if "options" not in parameters: + parameters["options"] = {} + parameters["options"]["stream"] = stream + data = FileFactory.to_link(data) if isinstance(data, dict): payload = data @@ -36,6 +41,7 @@ def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict: resp = "unspecified error" try: + logging.debug(f"Calling {url} with payload: {payload}") r = _request_with_retry("post", url, headers=headers, data=payload) resp = r.json() except Exception as e: diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index dd8cb04e..7fbb74d1 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -46,6 +46,25 @@ def test_llm_run(llm_model): assert response["status"] == "SUCCESS" +def test_llm_run_stream(): + """Testing LLMs with streaming""" + from aixplain.modules.model.response import ModelResponse, ResponseStatus + from aixplain.modules.model.model_response_streamer import ModelResponseStreamer + + llm_model = ModelFactory.get("669a63646eb56306647e1091") + + assert 
isinstance(llm_model, LLM) + response = llm_model.run( + data="This is a test prompt where I expect you to respond with the following phrase: 'This is a test response.'", + stream=True, + ) + assert isinstance(response, ModelResponseStreamer) + for chunk in response: + assert isinstance(chunk, ModelResponse) + assert chunk.data in "This is a test response." + assert response.status == ResponseStatus.SUCCESS + + def test_run_async(): """Testing Model Async""" model = ModelFactory.get("60ddef828d38c51c5885d491") @@ -100,7 +119,6 @@ def run_index_model(index_model): pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), pytest.param(EmbeddingModel.BGE_M3, AirParams, id="AIR - BGE M3"), pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="AIR - aiXplain Legal Embeddings"), - ], ) def test_index_model(embedding_model, supplier_params): @@ -126,7 +144,6 @@ def test_index_model(embedding_model, supplier_params): pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), pytest.param(EmbeddingModel.BGE_M3, AirParams, id="BGE M3"), pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="aiXplain Legal Embeddings"), - ], ) def test_index_model_with_filter(embedding_model, supplier_params): diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 943d501d..57ef17d0 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -25,8 +25,8 @@ from aixplain.factories import ModelFactory from aixplain.enums import Function from urllib.parse import urljoin -from aixplain.enums import ResponseStatus -from aixplain.modules.model.response import ModelResponse +from aixplain.modules.model.response import ModelResponse, ResponseStatus +from aixplain.modules.model.model_response_streamer import ModelResponseStreamer import pytest from unittest.mock import patch from aixplain.enums.asset_status import AssetStatus @@ -638,3 +638,19 @@ def 
test_empty_model_parameters_string(): """Test string representation of empty ModelParameters.""" params = ModelParameters({}) assert str(params) == "No parameters defined" + + +def test_model_response_streamer(): + """Test ModelResponseStreamer class.""" + streamer = ModelResponseStreamer(iter([])) + assert isinstance(streamer, ModelResponseStreamer) + assert streamer.status == ResponseStatus.IN_PROGRESS + + +def test_model_not_supports_streaming(mocker): + """Test ModelResponseStreamer class.""" + mocker.patch("aixplain.modules.model.utils.build_payload", return_value={"data": "test"}) + model = Model(id="test-id", name="Test Model", supports_streaming=False) + with pytest.raises(Exception) as excinfo: + model.run(data="test", stream=True) + assert f"Model '{model.name} ({model.id})' does not support streaming" in str(excinfo.value) From b909a5418357987d683a58490f2e4cc47b5fec32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Fri, 9 May 2025 20:32:54 +0300 Subject: [PATCH 21/62] BUG-542-Utility-Model-Update-Test-Failing-in-SDK (#515) * set utility models status to draft if updated * BUG-542: fix utility model tests --- aixplain/modules/model/utility_model.py | 1 - tests/functional/model/run_utility_model_test.py | 10 ++++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 1cbbf3da..f6b69ba2 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -237,7 +237,6 @@ def update(self): DeprecationWarning, stacklevel=2, ) - self.validate() url = urljoin(self.backend_url, f"sdk/utilities/{self.id}") headers = {"x-api-key": f"{self.api_key}", "Content-Type": "application/json"} diff --git a/tests/functional/model/run_utility_model_test.py b/tests/functional/model/run_utility_model_test.py index a2c67ca9..3b6616ae 100644 --- a/tests/functional/model/run_utility_model_test.py +++ 
b/tests/functional/model/run_utility_model_test.py @@ -272,7 +272,7 @@ def sum_numbers(num1: int, num2: int): # Verify updated state updated_model = ModelFactory.get(utility_model.id) - assert updated_model.status == AssetStatus.DRAFT + assert updated_model.status == AssetStatus.ONBOARDED assert updated_model.name == "sum_numbers_test" assert updated_model.description == "Updated to sum numbers utility" assert len(updated_model.inputs) == 2 @@ -298,17 +298,19 @@ def multiply_numbers(num1: int, num2: int): return num1 * num2 updated_model.code = multiply_numbers - assert updated_model.status == AssetStatus.DRAFT - updated_model.deploy() assert updated_model.status == AssetStatus.ONBOARDED + # the next line should raise an exception + with pytest.raises(Exception, match=".*UtilityModel is already deployed*"): + updated_model.deploy() updated_model.save() + assert updated_model.status == AssetStatus.ONBOARDED # Verify partial update final_model = ModelFactory.get(utility_model.id) assert final_model.name == "sum_numbers_test" assert final_model.description == "Updated to sum numbers utility" - assert final_model.status == AssetStatus.DRAFT + assert final_model.status == AssetStatus.ONBOARDED # Test final behavior with new function but same input field names response = final_model.run({"num1": 5, "num2": 7}) assert response.status == "SUCCESS" From 8aa3d3192ea8dbb312ba76fe6dffb204befe53e0 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Fri, 9 May 2025 16:25:10 -0300 Subject: [PATCH 22/62] ENG-2115 fixing agent tests (#522) * Fixing agent functional tests * Fixing team agent test * Fixing test_instructions of team agents --- tests/functional/agent/agent_functional_test.py | 11 ++--------- .../team_agent/team_agent_functional_test.py | 4 ++-- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 
10569440..c7c294fe 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -137,7 +137,7 @@ def test_custom_code_tool(delete_agents_and_team_agents, AgentFactory): ) assert tool is not None assert tool.description == "Add two strings" - assert tool.code == 'def main(aaa: str, bbb: str) -> str:\n """Add two strings"""\n return aaa + bbb' + assert tool.code.startswith("s3://") agent = AgentFactory.create( name="Add Strings Agent", description="Add two strings. Do not directly answer. Use the tool to add the strings.", @@ -399,6 +399,7 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): if agent: agent.delete() + @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents @@ -510,13 +511,6 @@ def test_instructions(delete_agents_and_team_agents, AgentFactory): assert response["data"]["session_id"] is not None assert response["data"]["output"] is not None assert "aixplain" in response["data"]["output"].lower() - assert "eve" in response["data"]["output"].lower() - - import os - - # Cleanup - os.remove("test.csv") - os.remove("test.db") agent.delete() @@ -606,4 +600,3 @@ def test_agent_with_pipeline_tool(delete_agents_and_team_agents, AgentFactory): assert "hello" in answer["data"]["output"].lower() assert "hello pipeline" in answer["data"]["intermediate_steps"][0]["tool_steps"][0]["tool"].lower() - diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index 759f3920..baa0a34b 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -625,9 +625,9 @@ def test_team_agent_with_instructions(delete_agents_and_team_agents): assert response.status == "SUCCESS" assert "gato" in response.data["output"] - mentalist_steps = 
eval(response.data["intermediate_steps"][0]["output"]) + mentalist_steps = [json.loads(step) for step in eval(response.data["intermediate_steps"][0]["output"])] - called_agents = set([step["worker"] for step in mentalist_steps]) + called_agents = set([step["agent"] for step in mentalist_steps]) assert len(called_agents) == 1 assert "Agent 2" in called_agents From 87d5791250bce5a9b895761037ed27d7ace8ddf8 Mon Sep 17 00:00:00 2001 From: Zaina Abu Shaban Date: Mon, 12 May 2025 22:25:51 +0300 Subject: [PATCH 23/62] ENG-1551 ai xplain sdk caching onboarded models pipelines and agents (#389) * caching for agents, pipelines and models * formatting * Agent Cache Class added * removed unused imports * made requested changes and added functional tests * changes * Made changes to model cache * Made changes to model cache * model changes * changes for functions, languages, licenses, agents, pipelines caching structure * removed agent and pipeline changes, made metadata changes * fixed unit test issue * fixed relative import --------- Co-authored-by: ahmetgunduz Co-authored-by: Thiago Castro Ferreira --- aixplain/enums/function.py | 92 +++++++--- aixplain/enums/language.py | 63 +++++-- aixplain/enums/license.py | 53 ++++-- aixplain/factories/agent_factory/__init__.py | 2 +- aixplain/factories/model_factory/__init__.py | 166 +++++++++++++----- .../factories/pipeline_factory/__init__.py | 2 +- aixplain/modules/agent/__init__.py | 3 +- aixplain/modules/agent/tool/model_tool.py | 2 +- aixplain/modules/model/__init__.py | 93 ++++++++-- aixplain/modules/pipeline/asset.py | 2 +- aixplain/modules/pipeline/default.py | 17 +- aixplain/utils/asset_cache.py | 140 +++++++++++++++ aixplain/utils/cache_utils.py | 35 ++-- tests/functional/model/run_model_test.py | 27 ++- .../model/run_utility_model_test.py | 1 + tests/functional/pipelines/create_test.py | 27 ++- tests/unit/utility_tool_decorator_test.py | 38 ++-- 17 files changed, 623 insertions(+), 140 deletions(-) create mode 100644 
aixplain/utils/asset_cache.py diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index a51f5301..462f2e38 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -20,26 +20,67 @@ Description: Function Enum """ - +import logging from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from enum import Enum from urllib.parse import urljoin -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.utils.asset_cache import AssetCache, CACHE_FOLDER from typing import Tuple, Dict from aixplain.base.parameters import BaseParameters, Parameter +import os CACHE_FILE = f"{CACHE_FOLDER}/functions.json" +LOCK_FILE = f"{CACHE_FILE}.lock" + +from dataclasses import dataclass, field +from typing import List, Optional, Dict, Any + +@dataclass +class FunctionMetadata: + id: str + name: str + description: Optional[str] = None + params: List[Dict[str, Any]] = field(default_factory=list) + output: List[Dict[str, Any]] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict: + return { + "id": self.id, + "name": self.name, + "description": self.description, + "params": self.params, + "output": self.output, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: dict): + return cls( + id=data.get("id"), + name=data.get("name"), + description=data.get("description"), + params=data.get("params", []), + output=data.get("output", []), + metadata={k: v for k, v in data.items() if k not in {"id", "name", "description", "params", "output"}}, + ) def load_functions(): api_key = config.TEAM_API_KEY backend_url = config.BACKEND_URL - resp = load_from_cache(CACHE_FILE) - if resp is None: - url = urljoin(backend_url, "sdk/functions") + os.makedirs(CACHE_FOLDER, exist_ok=True) + url = urljoin(backend_url, "sdk/functions") + cache = AssetCache(FunctionMetadata, cache_filename="functions") + if 
cache.has_valid_cache(): + logging.info("Loading functions from cache...") + function_objects = list(cache.store.data.values()) + else: + logging.info("Fetching functions from backend...") + url = urljoin(backend_url, "sdk/functions") headers = {"x-api-key": api_key, "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers) if not 200 <= r.status_code < 300: @@ -47,7 +88,9 @@ def load_functions(): f'Functions could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp) + results = resp.get("results") + function_objects = [FunctionMetadata.from_dict(f) for f in results] + cache.add_list(function_objects) class Function(str, Enum): def __new__(cls, value): @@ -63,8 +106,14 @@ def get_input_output_params(self) -> Tuple[Dict, Dict]: Tuple[Dict, Dict]: A tuple containing (input_params, output_params) """ function_io = FunctionInputOutput.get(self.value, None) - input_params = {param["code"]: param for param in function_io["spec"]["params"]} - output_params = {param["code"]: param for param in function_io["spec"]["output"]} + if function_io is None: + return {}, {} + input_params = { + param["code"]: param for param in function_io["spec"]["params"] + } + output_params = { + param["code"]: param for param in function_io["spec"]["output"] + } return input_params, output_params def get_parameters(self) -> "FunctionParameters": @@ -78,19 +127,18 @@ def get_parameters(self) -> "FunctionParameters": self._parameters = FunctionParameters(input_params) return self._parameters - functions = Function("Function", {w["id"].upper().replace("-", "_"): w["id"] for w in resp["items"]}) + functions = Function( + "Function", {f.id.upper().replace("-", "_"): f.id for f in function_objects} + ) functions_input_output = { - function["id"]: { - "input": { - input_data_object["dataType"] - 
for input_data_object in function["params"] - if input_data_object["required"] is True - }, - "output": {output_data_object["dataType"] for output_data_object in function["output"]}, - "spec": function, + f.id: { + "input": {p["dataType"] for p in f.params if p.get("required")}, + "output": {o["dataType"] for o in f.output}, + "spec": f.to_dict(), } - for function in resp["items"] + for f in function_objects } + return functions, functions_input_output @@ -105,7 +153,11 @@ def __init__(self, input_params: Dict): """ super().__init__() for param_code, param_config in input_params.items(): - self.parameters[param_code] = Parameter(name=param_code, required=param_config.get("required", False), value=None) + self.parameters[param_code] = Parameter( + name=param_code, + required=param_config.get("required", False), + value=None, + ) -Function, FunctionInputOutput = load_functions() +Function, FunctionInputOutput = load_functions() \ No newline at end of file diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index db66b2a1..a660024a 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -25,19 +25,50 @@ from urllib.parse import urljoin from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.utils.asset_cache import AssetCache +import logging +from dataclasses import dataclass, field +from typing import List, Dict, Any, Optional -CACHE_FILE = f"{CACHE_FOLDER}/languages.json" +@dataclass +class LanguageMetadata: + id: str + value: str + label: str + dialects: List[Dict[str, str]] = field(default_factory=list) + scripts: List[Any] = field(default_factory=list) + def to_dict(self) -> dict: + return { + "id": self.id, + "value": self.value, + "label": self.label, + "dialects": self.dialects, + "scripts": self.scripts, + } + + @classmethod + def from_dict(cls, data: dict): + return cls( + 
id=data.get("id"), + value=data.get("value"), + label=data.get("label"), + dialects=data.get("dialects", []), + scripts=data.get("scripts", []), + ) def load_languages(): - resp = load_from_cache(CACHE_FILE) - if resp is None: - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL - url = urljoin(backend_url, "sdk/languages") + url = urljoin(backend_url, "sdk/languages") + cache = AssetCache(LanguageMetadata, cache_filename="languages") + if cache.has_valid_cache(): + logging.info("Loading languages from cache...") + lang_entries = list(cache.store.data.values()) + else: + logging.info("Fetching languages from backend...") headers = {"x-api-key": api_key, "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers) if not 200 <= r.status_code < 300: @@ -45,18 +76,22 @@ def load_languages(): f'Languages could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. 
For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' ) resp = r.json() - save_to_cache(CACHE_FILE, resp) + lang_entries = [LanguageMetadata.from_dict(item) for item in resp] + cache.add_list(lang_entries) languages = {} - for w in resp: - language = w["value"] - language_label = "_".join(w["label"].split()) - languages[language_label] = {"language": language, "dialect": ""} - for dialect in w["dialects"]: + for entry in lang_entries: + language = entry.value + label = "_".join(entry.label.split()) + languages[label] = {"language": language, "dialect": ""} + for dialect in entry.dialects: dialect_label = "_".join(dialect["label"].split()).upper() dialect_value = dialect["value"] + languages[f"{label}_{dialect_label}"] = { + "language": language, + "dialect": dialect_value, + } - languages[language_label + "_" + dialect_label] = {"language": language, "dialect": dialect_value} return Enum("Language", languages, type=dict) diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index 566be092..da1623c4 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -25,31 +25,64 @@ from enum import Enum from urllib.parse import urljoin from aixplain.utils import config -from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.asset_cache import AssetCache, CACHE_FOLDER -CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" +from dataclasses import dataclass + +@dataclass +class LicenseMetadata: + id: str + name: str + description: str + url: str + allowCustomUrl: bool + + def to_dict(self) -> dict: + return { + "id": self.id, + "name": self.name, + "description": self.description, + "url": self.url, + "allowCustomUrl": self.allowCustomUrl, + } + + @classmethod + def from_dict(cls, data: dict): + return cls( + id=data.get("id"), + name=data.get("name"), + description=data.get("description"), + 
url=data.get("url"), + allowCustomUrl=data.get("allowCustomUrl", False), + ) def load_licenses(): - resp = load_from_cache(CACHE_FILE) try: - if resp is None: - api_key = config.TEAM_API_KEY - backend_url = config.BACKEND_URL + api_key = config.TEAM_API_KEY + backend_url = config.BACKEND_URL - url = urljoin(backend_url, "sdk/licenses") + url = urljoin(backend_url, "sdk/licenses") + cache = AssetCache(LicenseMetadata, cache_filename="licenses") + if cache.has_valid_cache(): + logging.info("Loading licenses from cache...") + license_objects = list(cache.store.data.values()) + else: + logging.info("Fetching licenses from backend...") headers = {"x-api-key": api_key, "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers) if not 200 <= r.status_code < 300: raise Exception( - f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)' + f'Licenses could not be loaded, probably due to the set API key (e.g. "{api_key}") is not valid. For help, please refer to the documentation.' 
) resp = r.json() - save_to_cache(CACHE_FILE, resp) - licenses = {"_".join(w["name"].split()): w["id"] for w in resp} + license_objects = [LicenseMetadata.from_dict(item) for item in resp] + cache.add_list(license_objects) + + licenses = {"_".join(lic.name.split()): lic.id for lic in license_objects} return Enum("License", licenses, type=str) except Exception: logging.exception("License Loading Error") diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 156f1b54..92c6ec3b 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -386,4 +386,4 @@ def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> Agent: if "message" in resp: msg = resp["message"] error_msg = f"Agent Get Error (HTTP {r.status_code}): {msg}" - raise Exception(error_msg) + raise Exception(error_msg) \ No newline at end of file diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index 85c1ac4f..2a24ced7 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -25,10 +25,19 @@ import logging from aixplain.modules.model import Model from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput -from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder +from aixplain.enums import ( + Function, + Language, + OwnershipType, + Supplier, + SortBy, + SortOrder, +) from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin +from aixplain.utils.asset_cache import AssetCache, CACHE_FOLDER +from aixplain.factories.model_factory.utils import create_model_from_response class ModelFactory: @@ -78,7 +87,9 @@ def create_utility_model( url = urljoin(cls.backend_url, "sdk/utilities") headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} try: - 
logging.info(f"Start service for POST Utility Model - {url} - {headers} - {payload}") + logging.info( + f"Start service for POST Utility Model - {url} - {headers} - {payload}" + ) r = _request_with_retry("post", url, headers=headers, json=payload) resp = r.json() except Exception as e: @@ -87,49 +98,78 @@ def create_utility_model( if 200 <= r.status_code < 300: utility_model.id = resp["id"] - logging.info(f"Utility Model Creation: Model {utility_model.id} instantiated.") + logging.info( + f"Utility Model Creation: Model {utility_model.id} instantiated." + ) return utility_model else: - error_message = ( - f"Utility Model Creation: Failed to create utility model. Status Code: {r.status_code}. Error: {resp}" - ) + error_message = f"Utility Model Creation: Failed to create utility model. Status Code: {r.status_code}. Error: {resp}" logging.error(error_message) raise Exception(error_message) @classmethod - def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: - """Create a 'Model' object from model id - - Args: - model_id (Text): Model ID of required model. - api_key (Optional[Text], optional): Model API key. Defaults to None. 
+ def get( + cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = True + ) -> Model: + """Create a 'Model' object from model id""" + cache = AssetCache(Model) + + if use_cache: + if cache.has_valid_cache(): + cached_model = cache.store.data.get(model_id) + if cached_model: + return cached_model + logging.info("Model not found in valid cache, fetching individually...") + model = cls._fetch_model_by_id(model_id, api_key) + cache.add(model) + return model + else: + try: + model_list_resp = cls.list(model_ids=None, api_key=api_key) + models = model_list_resp["results"] + cache.add_list(models) + for model in models: + if model.id == model_id: + return model + except Exception as e: + logging.error(f"Error fetching model list: {e}") + raise e + + logging.info("Fetching model directly without cache...") + model = cls._fetch_model_by_id(model_id, api_key) + cache.add(model) + return model - Returns: - Model: Created 'Model' object - """ + @classmethod + def _fetch_model_by_id( + cls, model_id: Text, api_key: Optional[Text] = None + ) -> Model: resp = None try: url = urljoin(cls.backend_url, f"sdk/models/{model_id}") - - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {api_key or config.TEAM_API_KEY}", + "Content-Type": "application/json", + } logging.info(f"Start service for GET Model - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - except Exception: - if resp is not None and "statusCode" in resp: + if resp and "statusCode" in resp: status_code = resp["statusCode"] - message = resp["message"] - message = f"Model Creation: Status {status_code} - {message}" + message = f"Model Creation: Status {status_code} - {resp['message']}" else: message = "Model Creation: Unspecified Error" logging.error(message) - raise Exception(f"{message}") + raise Exception(message) + if 200 <= r.status_code < 300: resp["api_key"] = 
config.TEAM_API_KEY if api_key is not None: resp["api_key"] = api_key - from aixplain.factories.model_factory.utils import create_model_from_response + from aixplain.factories.model_factory.utils import ( + create_model_from_response, + ) model = create_model_from_response(resp) logging.info(f"Model Creation: Model {model_id} instantiated.") @@ -186,7 +226,9 @@ def list( and ownership is None and sort_by is None ), "Cannot filter by function, suppliers, source languages, target languages, is finetunable, ownership, sort by when using model ids" - assert len(model_ids) <= page_size, "Page size must be greater than the number of model ids" + assert ( + len(model_ids) <= page_size + ), "Page size must be greater than the number of model ids" models, total = get_model_from_ids(model_ids, api_key), len(model_ids) else: from aixplain.factories.model_factory.utils import get_assets_from_page @@ -228,7 +270,10 @@ def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]: if api_key: headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} else: - headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "x-api-key": f"{config.TEAM_API_KEY}", + "Content-Type": "application/json", + } response = _request_with_retry("get", machines_url, headers=headers) response_dicts = json.loads(response.text) for dictionary in response_dicts: @@ -247,15 +292,23 @@ def list_gpus(cls, api_key: Optional[Text] = None) -> List[List[Text]]: """ gpu_url = urljoin(config.BACKEND_URL, "sdk/model-onboarding/gpus") if api_key: - headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {api_key}", + "Content-Type": "application/json", + } else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } response = 
_request_with_retry("get", gpu_url, headers=headers) response_list = json.loads(response.text) return response_list @classmethod - def list_functions(cls, verbose: Optional[bool] = False, api_key: Optional[Text] = None) -> List[Dict]: + def list_functions( + cls, verbose: Optional[bool] = False, api_key: Optional[Text] = None + ) -> List[Dict]: """Lists supported model functions on platform. Args: @@ -272,7 +325,10 @@ def list_functions(cls, verbose: Optional[bool] = False, api_key: Optional[Text] if api_key: headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} else: - headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "x-api-key": f"{config.TEAM_API_KEY}", + "Content-Type": "application/json", + } response = _request_with_retry("get", functions_url, headers=headers) response_dict = json.loads(response.text) if verbose: @@ -331,7 +387,10 @@ def create_asset_repo( if api_key: headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} else: - headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "x-api-key": f"{config.TEAM_API_KEY}", + "Content-Type": "application/json", + } payload = { "model": { @@ -347,7 +406,9 @@ def create_asset_repo( "onboardingParams": {}, } logging.debug(f"Body: {str(payload)}") - response = _request_with_retry("post", create_url, headers=headers, json=payload) + response = _request_with_retry( + "post", create_url, headers=headers, json=payload + ) assert response.status_code == 201 @@ -367,9 +428,15 @@ def asset_repo_login(cls, api_key: Optional[Text] = None) -> Dict: login_url = urljoin(config.BACKEND_URL, "sdk/ecr/login") logging.debug(f"URL: {login_url}") if api_key: - headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {api_key}", + "Content-Type": "application/json", + } else: - headers = {"Authorization": f"Token 
{config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } response = _request_with_retry("post", login_url, headers=headers) response_dict = json.loads(response.text) return response_dict @@ -399,10 +466,15 @@ def onboard_model( if api_key: headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} else: - headers = {"x-api-key": f"{config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "x-api-key": f"{config.TEAM_API_KEY}", + "Content-Type": "application/json", + } payload = {"image": image_tag, "sha": image_hash, "hostMachine": host_machine} logging.debug(f"Body: {str(payload)}") - response = _request_with_retry("post", onboard_url, headers=headers, json=payload) + response = _request_with_retry( + "post", onboard_url, headers=headers, json=payload + ) if response.status_code == 201: message = "Your onboarding request has been submitted to an aiXplain specialist for finalization. We will notify you when the process is completed." 
logging.info(message) @@ -432,9 +504,15 @@ def deploy_huggingface_model( supplier, model_name = hf_repo_id.split("/") deploy_url = urljoin(config.BACKEND_URL, "sdk/model-onboarding/onboard") if api_key: - headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {api_key}", + "Content-Type": "application/json", + } else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } body = { "model": { "name": name, @@ -458,7 +536,9 @@ def deploy_huggingface_model( return response_dicts @classmethod - def get_huggingface_model_status(cls, model_id: Text, api_key: Optional[Text] = None): + def get_huggingface_model_status( + cls, model_id: Text, api_key: Optional[Text] = None + ): """Gets the on-boarding status of a Hugging Face model with ID MODEL_ID. Args: @@ -469,9 +549,15 @@ def get_huggingface_model_status(cls, model_id: Text, api_key: Optional[Text] = """ status_url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") if api_key: - headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {api_key}", + "Content-Type": "application/json", + } else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {config.TEAM_API_KEY}", + "Content-Type": "application/json", + } response = _request_with_retry("get", status_url, headers=headers) logging.debug(response.text) response_dicts = json.loads(response.text) diff --git a/aixplain/factories/pipeline_factory/__init__.py b/aixplain/factories/pipeline_factory/__init__.py index ba164199..f2606ca2 100644 --- a/aixplain/factories/pipeline_factory/__init__.py +++ b/aixplain/factories/pipeline_factory/__init__.py @@ -313,4 +313,4 @@ def create( return 
Pipeline(response["id"], name, api_key) except Exception as e: - raise Exception(e) + raise Exception(e) \ No newline at end of file diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 8f19d3ec..f9510480 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -131,6 +131,7 @@ def _validate(self) -> None: tool_name = tool.name elif isinstance(tool, Model): assert not isinstance(tool, Agent), "Agent cannot contain another Agent." + tool_name = tool.name tool_names.append(tool_name) @@ -419,4 +420,4 @@ def save(self) -> None: self.update() def __repr__(self): - return f"Agent(id={self.id}, name={self.name}, function={self.function})" + return f"Agent(id={self.id}, name={self.name}, function={self.function})" \ No newline at end of file diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 9b073a84..8a0f9dc2 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -174,7 +174,7 @@ def validate(self) -> None: self.description = self.model.description elif self.function is not None: try: - self.description = FunctionInputOutput[self.function.value]["spec"]["metaData"]["description"] + self.description = FunctionInputOutput[self.function.value]["spec"]["description"] except Exception: self.description = "" diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 92e48990..3a27cae2 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -123,7 +123,9 @@ def to_dict(self) -> Dict: Returns: Dict: Model Information """ - clean_additional_info = {k: v for k, v in self.additional_info.items() if v is not None} + clean_additional_info = { + k: v for k, v in self.additional_info.items() if v is not None + } return { "id": self.id, "name": self.name, @@ -133,6 +135,7 @@ def to_dict(self) -> Dict: "input_params": self.input_params, 
"output_params": self.output_params, "model_params": self.model_params.to_dict(), + "function": self.function, "status": self.status, } @@ -148,7 +151,11 @@ def __repr__(self): return f"" def sync_poll( - self, poll_url: Text, name: Text = "model_process", wait_time: float = 0.5, timeout: float = 300 + self, + poll_url: Text, + name: Text = "model_process", + wait_time: float = 0.5, + timeout: float = 300, ) -> ModelResponse: """Keeps polling the platform to check whether an asynchronous call is done. @@ -179,15 +186,21 @@ def sync_poll( wait_time *= 1.1 except Exception as e: response_body = ModelResponse( - status=ResponseStatus.FAILED, completed=False, error_message="No response from the service." + status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", ) logging.error(f"Polling for Model: polling for {name}: {e}") break if response_body["completed"] is True: - logging.debug(f"Polling for Model: Final status of polling for {name}: {response_body}") + logging.debug( + f"Polling for Model: Final status of polling for {name}: {response_body}" + ) else: response_body = ModelResponse( - status=ResponseStatus.FAILED, completed=False, error_message="No response from the service." 
+ status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", ) logging.error( f"Polling for Model: Final status of polling for {name}: No response in {timeout} seconds - {response_body}" @@ -216,6 +229,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: status = ResponseStatus.IN_PROGRESS logging.debug(f"Single Poll for Model: Status of polling for {name}: {resp}") + return ModelResponse( status=resp.pop("status", status), data=resp.pop("data", ""), @@ -282,12 +296,18 @@ def run( try: poll_url = response["url"] end = time.time() - return self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + return self.sync_poll( + poll_url, name=name, timeout=timeout, wait_time=wait_time + ) except Exception as e: msg = f"Error in request for {name} - {traceback.format_exc()}" logging.error(f"Model Run: Error in running for {name}: {e}") end = time.time() - response = {"status": "FAILED", "error_message": msg, "runTime": end - start} + response = { + "status": "FAILED", + "error_message": msg, + "runTime": end - start, + } return ModelResponse( status=response.pop("status", ResponseStatus.FAILED), data=response.pop("data", ""), @@ -302,7 +322,10 @@ def run( ) def run_async( - self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = None + self, + data: Union[Text, Dict], + name: Text = "model_process", + parameters: Optional[Dict] = None, ) -> ModelResponse: """Runs asynchronously a model call. 
@@ -347,7 +370,9 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): resp = None try: url = urljoin(self.backend_url, f"sdk/finetune/{self.id}/ml-logs") - logging.info(f"Start service for GET Check FineTune status Model - {url} - {headers}") + logging.info( + f"Start service for GET Check FineTune status Model - {url} - {headers}" + ) r = _request_with_retry("get", url, headers=headers) resp = r.json() finetune_status = AssetStatus(resp["finetuneStatus"]) @@ -374,9 +399,21 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): status = FinetuneStatus( status=finetune_status, model_status=model_status, - epoch=float(log["epoch"]) if "epoch" in log and log["epoch"] is not None else None, - training_loss=float(log["trainLoss"]) if "trainLoss" in log and log["trainLoss"] is not None else None, - validation_loss=float(log["evalLoss"]) if "evalLoss" in log and log["evalLoss"] is not None else None, + epoch=( + float(log["epoch"]) + if "epoch" in log and log["epoch"] is not None + else None + ), + training_loss=( + float(log["trainLoss"]) + if "trainLoss" in log and log["trainLoss"] is not None + else None + ), + validation_loss=( + float(log["evalLoss"]) + if "evalLoss" in log and log["evalLoss"] is not None + else None + ), ) else: status = FinetuneStatus( @@ -384,7 +421,9 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): model_status=model_status, ) - logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status.status.value}.") + logging.info( + f"Response for GET Check FineTune status Model - Id {self.id} / Status {status.status.value}." 
+ ) return status except Exception: message = "" @@ -399,7 +438,10 @@ def delete(self) -> None: """Delete Model service""" try: url = urljoin(self.backend_url, f"sdk/models/{self.id}") - headers = {"Authorization": f"Token {self.api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Token {self.api_key}", + "Content-Type": "application/json", + } logging.info(f"Start service for DELETE Model - {url} - {headers}") r = _request_with_retry("delete", url, headers=headers) if r.status_code != 200: @@ -408,3 +450,26 @@ def delete(self) -> None: message = "Model Deletion Error: Make sure the model exists and you are the owner." logging.error(message) raise Exception(f"{message}") + + @classmethod + def from_dict(cls, data: Dict) -> "Model": + return cls( + id=data.get("id", ""), + name=data.get("name", ""), + description=data.get("description", ""), + api_key=data.get("api_key", config.TEAM_API_KEY), + supplier=data.get("supplier", "aiXplain"), + version=data.get("version", "1.0"), + function=Function(data.get("function")), + is_subscribed=data.get("is_subscribed", False), + cost=data.get("cost"), + created_at=( + datetime.fromisoformat(data["created_at"]) + if data.get("created_at") + else None + ), + input_params=data.get("input_params"), + output_params=data.get("output_params"), + model_params=data.get("model_params"), + **data.get("additional_info", {}), + ) diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index cc234337..adf74644 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -601,4 +601,4 @@ def deploy(self, api_key: Optional[Text] = None) -> None: raise Exception(f"Error deploying because of backend error: {e}") from e def __repr__(self): - return f"Pipeline(id={self.id}, name={self.name})" + return f"Pipeline(id={self.id}, name={self.name})" \ No newline at end of file diff --git a/aixplain/modules/pipeline/default.py b/aixplain/modules/pipeline/default.py 
index 41ae3c71..226deefd 100644 --- a/aixplain/modules/pipeline/default.py +++ b/aixplain/modules/pipeline/default.py @@ -1,5 +1,6 @@ from .asset import Pipeline as PipelineAsset from .designer import DesignerPipeline +from enum import Enum class DefaultPipeline(PipelineAsset, DesignerPipeline): @@ -13,4 +14,18 @@ def save(self, *args, **kwargs): super().save(*args, **kwargs) def to_dict(self) -> dict: - return self.serialize() + data = self.__dict__.copy() + + for key, value in data.items(): + if isinstance(value, Enum): + data[key] = value.value + + elif isinstance(value, list): + data[key] = [ + v.to_dict() if hasattr(v, "to_dict") else str(v) for v in value + ] + + elif hasattr(value, "to_dict"): + data[key] = value.to_dict() + + return data diff --git a/aixplain/utils/asset_cache.py b/aixplain/utils/asset_cache.py new file mode 100644 index 00000000..982b8d95 --- /dev/null +++ b/aixplain/utils/asset_cache.py @@ -0,0 +1,140 @@ +import os +import logging +import json +import time +from typing import Dict, Optional +from dataclasses import dataclass +from filelock import FileLock + +from aixplain.utils import config +from typing import TypeVar, Generic, Type +from typing import List + +logger = logging.getLogger(__name__) + + +T = TypeVar("T") + +# Constants +CACHE_FOLDER = ".cache" +DEFAULT_CACHE_EXPIRY = 86400 + + +@dataclass +class Store(Generic[T]): + data: Dict[str, T] + expiry: int + + +class AssetCache(Generic[T]): + """ + A modular caching system to handle different asset types (Models, Pipelines, Agents). 
+ """ + + def __init__( + self, + cls: Type[T], + cache_filename: Optional[str] = None, + ): + self.cls = cls + if cache_filename is None: + cache_filename = self.cls.__name__.lower() + + # create cache file and lock file name + self.cache_file = os.path.join(CACHE_FOLDER, f"{cache_filename}.json") + self.lock_file = os.path.join(CACHE_FOLDER, f"{cache_filename}.lock") + self.store = Store(data={}, expiry=self.compute_expiry()) + self.load() + + if not os.path.exists(self.cache_file): + self.save() + + def compute_expiry(self): + try: + expiry = int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) + except Exception as e: + logger.warning( + f"Failed to parse CACHE_EXPIRY_TIME: {e}, " + f"fallback to default value {DEFAULT_CACHE_EXPIRY}" + ) + # remove the CACHE_EXPIRY_TIME from the environment variables + del os.environ["CACHE_EXPIRY_TIME"] + expiry = DEFAULT_CACHE_EXPIRY + + return time.time() + int(expiry) + + def invalidate(self): + self.store = Store(data={}, expiry=self.compute_expiry()) + # delete cache file and lock file + if os.path.exists(self.cache_file): + os.remove(self.cache_file) + if os.path.exists(self.lock_file): + os.remove(self.lock_file) + + def load(self): + if not os.path.exists(self.cache_file): + self.invalidate() + return + + with FileLock(self.lock_file): + with open(self.cache_file, "r") as f: + try: + cache_data = json.load(f) + expiry = cache_data["expiry"] + raw_data = cache_data["data"] + parsed_data = { + k: self.cls.from_dict(v) for k, v in raw_data.items() + } + + self.store = Store(data=parsed_data, expiry=expiry) + + if self.store.expiry < time.time(): + logger.warning(f"Cache expired for {self.cls.__name__}") + self.invalidate() + + except Exception as e: + self.invalidate() + logger.warning(f"Failed to load cache data: {e}") + + if self.store.expiry < time.time(): + logger.warning( + f"Cache expired, invalidating cache for {self.cls.__name__}" + ) + self.invalidate() + return + + def save(self): + + 
os.makedirs(CACHE_FOLDER, exist_ok=True) + + with FileLock(self.lock_file): + with open(self.cache_file, "w") as f: + data_dict = {} + for asset_id, asset in self.store.data.items(): + try: + data_dict[asset_id] = asset.to_dict() + except Exception as e: + logger.error(f"Error serializing {asset_id}: {e}") + serializable_store = { + "expiry": self.store.expiry, + "data": data_dict, + } + + json.dump(serializable_store, f, indent=4) + + def get(self, asset_id: str) -> Optional[T]: + return self.store.data.get(asset_id) + + def add(self, asset: T): + self.store.data[asset.id] = asset + self.save() + + def add_list(self, assets: List[T]): + self.store.data = {asset.id: asset for asset in assets} + self.save() + + def get_all(self) -> List[T]: + return list(self.store.data.values()) + + def has_valid_cache(self) -> bool: + return self.store.expiry >= time.time() and bool(self.store.data) diff --git a/aixplain/utils/cache_utils.py b/aixplain/utils/cache_utils.py index 5a0eb6ae..01981701 100644 --- a/aixplain/utils/cache_utils.py +++ b/aixplain/utils/cache_utils.py @@ -2,26 +2,35 @@ import json import time import logging +from filelock import FileLock -CACHE_DURATION = 24 * 60 * 60 -CACHE_FOLDER = ".aixplain_cache" +CACHE_FOLDER = ".cache" +CACHE_FILE = f"{CACHE_FOLDER}/cache.json" +LOCK_FILE = f"{CACHE_FILE}.lock" +DEFAULT_CACHE_EXPIRY = 86400 -def save_to_cache(cache_file, data): +def get_cache_expiry(): + return int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) + + +def save_to_cache(cache_file, data, lock_file): try: os.makedirs(os.path.dirname(cache_file), exist_ok=True) - with open(cache_file, "w") as f: - json.dump({"timestamp": time.time(), "data": data}, f) + with FileLock(lock_file): + with open(cache_file, "w") as f: + json.dump({"timestamp": time.time(), "data": data}, f) except Exception as e: logging.error(f"Failed to save cache to {cache_file}: {e}") -def load_from_cache(cache_file): - if os.path.exists(cache_file) is True: - with open(cache_file, 
"r") as f: - cache_data = json.load(f) - if time.time() - cache_data["timestamp"] < CACHE_DURATION: - return cache_data["data"] - else: - return None +def load_from_cache(cache_file, lock_file): + if os.path.exists(cache_file): + with FileLock(lock_file): + with open(cache_file, "r") as f: + cache_data = json.load(f) + if time.time() - cache_data["timestamp"] < int(get_cache_expiry()): + return cache_data["data"] + else: + return None return None diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 7fbb74d1..bf6d9320 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -1,5 +1,7 @@ __author__ = "thiagocastroferreira" +import os +import json import pytest import requests @@ -8,9 +10,10 @@ from aixplain.modules import LLM from datetime import datetime, timedelta, timezone from pathlib import Path +from aixplain.utils.cache_utils import CACHE_FOLDER +from aixplain.modules.model import Model from aixplain.factories.index_factory.utils import AirParams, VectaraParams, GraphRAGParams, ZeroEntropyParams - def pytest_generate_tests(metafunc): if "llm_model" in metafunc.fixturenames: four_weeks_ago = datetime.now(timezone.utc) - timedelta(weeks=4) @@ -193,6 +196,28 @@ def test_llm_run_with_file(): assert "🤖" in response["data"], "Robot emoji should be present in the response" +def test_aixplain_model_cache_creation(): + """Ensure AssetCache is triggered and cache is created.""" + + cache_file = os.path.join(CACHE_FOLDER, "models.json") + + # Clean up cache before the test + if os.path.exists(cache_file): + os.remove(cache_file) + + # Instantiate the Model (replace this with a real model ID from your env) + model_id = "6239efa4822d7a13b8e20454" # Translate from Punjabi to Portuguese (Brazil) + _ = Model(id=model_id) + + # Assert the cache file was created + assert os.path.exists(cache_file), "Expected cache file was not created." 
+ + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + assert "data" in cache_data, "Cache file structure invalid - missing 'data' key." + assert any(m.get("id") == model_id for m in cache_data["data"]["items"]), "Instantiated model not found in cache." +======= def test_index_model_air_with_image(): from aixplain.factories import IndexFactory from aixplain.modules.model.record import Record diff --git a/tests/functional/model/run_utility_model_test.py b/tests/functional/model/run_utility_model_test.py index 3b6616ae..09e5ace2 100644 --- a/tests/functional/model/run_utility_model_test.py +++ b/tests/functional/model/run_utility_model_test.py @@ -4,6 +4,7 @@ import pytest + def test_run_utility_model(): utility_model = None try: diff --git a/tests/functional/pipelines/create_test.py b/tests/functional/pipelines/create_test.py index ae33f454..8bcb3b7d 100644 --- a/tests/functional/pipelines/create_test.py +++ b/tests/functional/pipelines/create_test.py @@ -15,7 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. """ - +import os +from aixplain.utils.cache_utils import CACHE_FOLDER +from aixplain.modules.pipeline import Pipeline import json import pytest from aixplain.factories import PipelineFactory @@ -75,3 +77,26 @@ def test_create_pipeline_wrong_path(PipelineFactory): with pytest.raises(Exception): PipelineFactory.create(name=pipeline_name, pipeline="/") + + + +@pytest.mark.parametrize("PipelineFactory", [PipelineFactory]) +def test_pipeline_cache_creation(PipelineFactory): + cache_file = os.path.join(CACHE_FOLDER, "pipelines.json") + if os.path.exists(cache_file): + os.remove(cache_file) + + pipeline_json = "tests/functional/pipelines/data/pipeline.json" + pipeline_name = str(uuid4()) + pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_json) + + assert os.path.exists(cache_file), "Pipeline cache file was not created!" 
+ + with open(cache_file, "r") as f: + cache_data = json.load(f) + + assert "data" in cache_data, "Cache format invalid, missing 'data'." + + pipeline.delete() + if os.path.exists(cache_file): + os.remove(cache_file) \ No newline at end of file diff --git a/tests/unit/utility_tool_decorator_test.py b/tests/unit/utility_tool_decorator_test.py index f9c87f02..c63aa5f6 100644 --- a/tests/unit/utility_tool_decorator_test.py +++ b/tests/unit/utility_tool_decorator_test.py @@ -3,16 +3,15 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.modules.model.utility_model import utility_tool, UtilityModelInput + def test_utility_tool_basic_decoration(): """Test basic decoration with minimal parameters""" - @utility_tool( - name="test_function", - description="Test function description" - ) + + @utility_tool(name="test_function", description="Test function description") def test_func(input_text: str) -> str: return input_text - assert hasattr(test_func, '_is_utility_tool') + assert hasattr(test_func, "_is_utility_tool") assert test_func._is_utility_tool is True assert test_func._tool_name == "test_function" assert test_func._tool_description == "Test function description" @@ -20,19 +19,20 @@ def test_func(input_text: str) -> str: assert test_func._tool_output_examples == "" assert test_func._tool_status == AssetStatus.DRAFT + def test_utility_tool_with_all_parameters(): """Test decoration with all optional parameters""" inputs = [ UtilityModelInput(name="text_input", type=DataType.TEXT, description="A text input"), - UtilityModelInput(name="num_input", type=DataType.NUMBER, description="A number input") + UtilityModelInput(name="num_input", type=DataType.NUMBER, description="A number input"), ] - + @utility_tool( name="full_test_function", description="Full test function description", inputs=inputs, output_examples="Example output: Hello World", - status=AssetStatus.ONBOARDED + status=AssetStatus.ONBOARDED, ) def test_func(text_input: str, num_input: int) -> 
str: return f"{text_input} {num_input}" @@ -45,32 +45,28 @@ def test_func(text_input: str, num_input: int) -> str: assert test_func._tool_output_examples == "Example output: Hello World" assert test_func._tool_status == AssetStatus.ONBOARDED + def test_utility_tool_function_still_callable(): """Test that decorated function remains callable""" - @utility_tool( - name="callable_test", - description="Test function callable" - ) + + @utility_tool(name="callable_test", description="Test function callable") def test_func(x: int, y: int) -> int: return x + y assert test_func(2, 3) == 5 assert test_func._is_utility_tool is True + def test_utility_tool_invalid_inputs(): """Test validation of invalid inputs""" with pytest.raises(ValueError): - @utility_tool( - name="", # Empty name should raise error - description="Test description" - ) + + @utility_tool(name="", description="Test description") # Empty name should raise error def test_func(): pass with pytest.raises(ValueError): - @utility_tool( - name="test_name", - description="" # Empty description should raise error - ) + + @utility_tool(name="test_name", description="") # Empty description should raise error def test_func(): - pass \ No newline at end of file + pass From 6f600d947149ed159c68dceb7bd3171901af1951 Mon Sep 17 00:00:00 2001 From: OsujiCC Date: Wed, 14 May 2025 16:52:41 +0100 Subject: [PATCH 24/62] Bug 531: standardize asset names (#521) * Standardize the asset names * debug unit test * remove double __repr__ * resolve conflicts --- aixplain/modules/agent/__init__.py | 2 +- aixplain/modules/model/__init__.py | 4 ++-- aixplain/modules/pipeline/asset.py | 2 +- aixplain/modules/team_agent/__init__.py | 3 +++ tests/unit/model_test.py | 4 ++-- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index f9510480..c534ba7f 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -420,4 +420,4 @@ def 
save(self) -> None: self.update() def __repr__(self): - return f"Agent(id={self.id}, name={self.name}, function={self.function})" \ No newline at end of file + return f"Agent: {self.name} (id={self.id})" diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 3a27cae2..04e39deb 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -146,9 +146,9 @@ def get_parameters(self) -> ModelParameters: def __repr__(self): try: - return f"" + return f"Model: {self.name} by {self.supplier['name']} (id={self.id})" except Exception: - return f"" + return f"Model: {self.name} by {self.supplier} (id={self.id})" def sync_poll( self, diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index adf74644..3f585315 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -601,4 +601,4 @@ def deploy(self, api_key: Optional[Text] = None) -> None: raise Exception(f"Error deploying because of backend error: {e}") from e def __repr__(self): - return f"Pipeline(id={self.id}, name={self.name})" \ No newline at end of file + return f"Pipeline: {self.name} (id={self.id})" diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 449e9549..2d3d22a5 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -417,3 +417,6 @@ def update(self) -> None: else: error_msg = f"Team Agent Update Error (HTTP {r.status_code}): {resp}" raise Exception(error_msg) + + def __repr__(self): + return f"TeamAgent: {self.name} (id={self.id})" \ No newline at end of file diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 57ef17d0..6c1a372f 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -435,11 +435,11 @@ def test_model_to_dict(): def test_model_repr(): # Test with supplier as dict model1 = Model(id="test-id", name="Test Model", supplier={"name": 
"aiXplain"}) - assert repr(model1).lower() == "".lower() + assert repr(model1).lower() == "model: test model by aixplain (id=test-id)".lower() # Test with supplier as string model2 = Model(id="test-id", name="Test Model", supplier="aiXplain") - assert str(model2).lower() == "".lower() + assert str(model2).lower() == "model: test model by aixplain (id=test-id)".lower() def test_poll_with_error(): From 85d040694b21f207cf91cc9122dc6dcf2c5cb4f4 Mon Sep 17 00:00:00 2001 From: Zaina Abu Shaban Date: Thu, 15 May 2025 17:03:21 +0300 Subject: [PATCH 25/62] fixed asset issue (#526) --- aixplain/factories/model_factory/__init__.py | 3 --- aixplain/utils/asset_cache.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index 2a24ced7..bf0c21b5 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -167,9 +167,6 @@ def _fetch_model_by_id( resp["api_key"] = config.TEAM_API_KEY if api_key is not None: resp["api_key"] = api_key - from aixplain.factories.model_factory.utils import ( - create_model_from_response, - ) model = create_model_from_response(resp) logging.info(f"Model Creation: Model {model_id} instantiated.") diff --git a/aixplain/utils/asset_cache.py b/aixplain/utils/asset_cache.py index 982b8d95..a2b8d09e 100644 --- a/aixplain/utils/asset_cache.py +++ b/aixplain/utils/asset_cache.py @@ -126,7 +126,7 @@ def get(self, asset_id: str) -> Optional[T]: return self.store.data.get(asset_id) def add(self, asset: T): - self.store.data[asset.id] = asset + self.store.data[asset.id] = asset.__dict__ self.save() def add_list(self, assets: List[T]): From 4d650c919663c49f386c811f1d2ba1c67cc912ba Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Thu, 15 May 2025 11:27:08 -0300 Subject: [PATCH 26/62] Make functional test more stable (#525) --- 
.../functional/agent/agent_functional_test.py | 39 +++++++++++-------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index f4cacfbb..aaf0cd54 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -286,21 +286,27 @@ def test_update_tools_of_agent(run_input_map, delete_agents_and_team_agents, Age @pytest.mark.parametrize( "tool_config", [ - { - "type": "search", - "model": "65c51c556eb563350f6e1bb1", - "query": "What is the weather in New York?", - "description": "Search tool with custom number of results", - "expected_tool_input": "'numResults': 5", - }, - { - "type": "translation", - "supplier": "Microsoft", - "function": "translation", - "query": "Translate: Olá, como vai você?", - "description": "Translation tool with target language", - "expected_tool_input": "targetlanguage", - }, + pytest.param( + { + "type": "search", + "model": "65c51c556eb563350f6e1bb1", + "query": "What is the weather in New York?", + "description": "Search tool with custom number of results", + "expected_tool_input": "'numResults': 5", + }, + id="search_tool", + ), + pytest.param( + { + "type": "translation", + "supplier": "Microsoft", + "function": "translation", + "query": "Translate: Olá, como vai você?", + "description": "Translation tool with target language", + "expected_tool_input": "targetlanguage", + }, + id="translation_tool", + ), ], ) def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agents): @@ -327,7 +333,7 @@ def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agent # Create and run agent agent = AgentFactory.create( name="Test Parameter Agent", - description="Test agent with parameterized tools", + description="Test agent with parameterized tools. 
You MUST use a tool for the tasks.", tools=[tool], llm_id="6626a3a8c8f1d089790cf5a2", # Using LLM ID from test data ) @@ -399,6 +405,7 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): if agent: agent.delete() + @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents From 91d686450e4eec052b12db1076f837a0e4cfd9e0 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Mon, 19 May 2025 16:28:48 -0300 Subject: [PATCH 27/62] ENG-2100: Enable JSON schema as output format (#513) * Enable JSON schema as output format * Expected output field * Add functional test * Fixing input parameters for team agent * Fix team agent test for JSON output --- aixplain/modules/agent/__init__.py | 23 ++++- aixplain/modules/agent/output_format.py | 1 + aixplain/modules/team_agent/__init__.py | 18 +++- .../functional/agent/agent_functional_test.py | 80 +++++++++++++++ .../team_agent/team_agent_functional_test.py | 97 +++++++++++++++++++ 5 files changed, 213 insertions(+), 6 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index c534ba7f..eaf628c6 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -35,6 +35,7 @@ from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.agent.utils import process_variables +from pydantic import BaseModel from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -169,6 +170,7 @@ def run( max_tokens: int = 2048, max_iterations: int = 10, output_format: OutputFormat = OutputFormat.TEXT, + expected_output: Optional[Union[BaseModel, Text, dict]] = None, ) -> AgentResponse: """Runs an agent call. 
@@ -184,7 +186,8 @@ def run( content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048. max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10. - output_format (ResponseFormat, optional): response format. Defaults to TEXT. + output_format (OutputFormat, optional): response format. Defaults to TEXT. + expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. Returns: Dict: parsed output from model """ @@ -202,6 +205,7 @@ def run( max_tokens=max_tokens, max_iterations=max_iterations, output_format=output_format, + expected_output=expected_output, ) if response["status"] == ResponseStatus.FAILED: end = time.time() @@ -254,6 +258,7 @@ def run_async( max_tokens: int = 2048, max_iterations: int = 10, output_format: OutputFormat = OutputFormat.TEXT, + expected_output: Optional[Union[BaseModel, Text, dict]] = None, ) -> AgentResponse: """Runs asynchronously an agent call. @@ -267,7 +272,8 @@ def run_async( content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048. max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10. - output_format (ResponseFormat, optional): response format. Defaults to TEXT. + output_format (OutputFormat, optional): response format. Defaults to TEXT. + expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. Returns: dict: polling URL in response """ @@ -276,6 +282,11 @@ def run_async( if not self.is_valid: raise Exception("Agent is not valid. 
Please validate the agent before running.") + if output_format == OutputFormat.JSON: + assert expected_output is not None and ( + issubclass(expected_output, BaseModel) or isinstance(expected_output, dict) + ), "Expected output must be a Pydantic BaseModel or a JSON object when output format is JSON." + assert data is not None or query is not None, "Either 'data' or 'query' must be provided." if data is not None: if isinstance(data, dict): @@ -310,6 +321,11 @@ def run_async( # build query input_data = process_variables(query, data, parameters, self.instructions) + if expected_output is not None and issubclass(expected_output, BaseModel): + expected_output = expected_output.model_json_schema() + if isinstance(output_format, OutputFormat): + output_format = output_format.value + payload = { "id": self.id, "query": input_data, @@ -318,7 +334,8 @@ def run_async( "executionParams": { "maxTokens": (parameters["max_tokens"] if "max_tokens" in parameters else max_tokens), "maxIterations": (parameters["max_iterations"] if "max_iterations" in parameters else max_iterations), - "outputFormat": output_format.value, + "outputFormat": output_format, + "expectedOutput": expected_output, }, } diff --git a/aixplain/modules/agent/output_format.py b/aixplain/modules/agent/output_format.py index 3a53e2f8..1e4984ad 100644 --- a/aixplain/modules/agent/output_format.py +++ b/aixplain/modules/agent/output_format.py @@ -28,3 +28,4 @@ class OutputFormat(Text, Enum): MARKDOWN = "markdown" TEXT = "text" + JSON = "json" diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 2d3d22a5..c2961c8a 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -43,6 +43,7 @@ from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry from aixplain.modules.mixins import DeployableMixin +from pydantic import BaseModel class InspectorTarget(str, Enum): @@ -139,6 +140,7 @@ def run( 
max_tokens: int = 2048, max_iterations: int = 30, output_format: OutputFormat = OutputFormat.TEXT, + expected_output: Optional[Union[BaseModel, Text, dict]] = None, ) -> AgentResponse: """Runs a team agent call. @@ -154,7 +156,8 @@ def run( content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048. max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30. - output_format (ResponseFormat, optional): response format. Defaults to TEXT. + output_format (OutputFormat, optional): response format. Defaults to TEXT. + expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. Returns: Dict: parsed output from model """ @@ -172,6 +175,7 @@ def run( max_tokens=max_tokens, max_iterations=max_iterations, output_format=output_format, + expected_output=expected_output, ) if response["status"] == ResponseStatus.FAILED: end = time.time() @@ -215,6 +219,7 @@ def run_async( max_tokens: int = 2048, max_iterations: int = 30, output_format: OutputFormat = OutputFormat.TEXT, + expected_output: Optional[Union[BaseModel, Text, dict]] = None, ) -> AgentResponse: """Runs asynchronously a Team Agent call. @@ -228,7 +233,8 @@ def run_async( content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048. max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30. - output_format (ResponseFormat, optional): response format. Defaults to TEXT. + output_format (OutputFormat, optional): response format. Defaults to TEXT. + expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. 
Returns: dict: polling URL in response """ @@ -273,6 +279,11 @@ def run_async( # build query input_data = process_variables(query, data, parameters, self.description) + if expected_output is not None and issubclass(expected_output, BaseModel): + expected_output = expected_output.model_json_schema() + if isinstance(output_format, OutputFormat): + output_format = output_format.value + payload = { "id": self.id, "query": input_data, @@ -281,7 +292,8 @@ def run_async( "executionParams": { "maxTokens": (parameters["max_tokens"] if "max_tokens" in parameters else max_tokens), "maxIterations": (parameters["max_iterations"] if "max_iterations" in parameters else max_iterations), - "outputFormat": output_format.value, + "outputFormat": output_format, + "expectedOutput": expected_output, }, } payload.update(parameters) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index aaf0cd54..16d6ba7b 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -606,3 +606,83 @@ def test_agent_with_pipeline_tool(delete_agents_and_team_agents, AgentFactory): assert "hello" in answer["data"]["output"].lower() assert "hello pipeline" in answer["data"]["intermediate_steps"][0]["tool_steps"][0]["tool"].lower() + + +def test_run_agent_with_expected_output(): + from pydantic import BaseModel + from typing import Optional, List + from aixplain.modules.agent import AgentResponse + from aixplain.modules.agent.output_format import OutputFormat + + class Person(BaseModel): + name: str + age: int + city: Optional[str] = None + + class Response(BaseModel): + result: List[Person] + + INSTRUCTIONS = """Answer questions based on the following context: + ++-----------------+-------+----------------+ +| Name | Age | City | ++=================+=======+================+ +| João Silva | 34 | São Paulo | ++-----------------+-------+----------------+ +| Maria Santos | 28 | Rio de Janeiro | 
++-----------------+-------+----------------+ +| Pedro Oliveira | 45 | | ++-----------------+-------+----------------+ +| Ana Costa | 19 | Recife | ++-----------------+-------+----------------+ +| Carlos Pereira | 52 | Belo Horizonte | ++-----------------+-------+----------------+ +| Beatriz Lima | 31 | | ++-----------------+-------+----------------+ +| Lucas Ferreira | 25 | Curitiba | ++-----------------+-------+----------------+ +| Julia Rodrigues | 41 | Salvador | ++-----------------+-------+----------------+ +| Miguel Almeida | 37 | | ++-----------------+-------+----------------+ +| Sofia Carvalho | 29 | Brasília | ++-----------------+-------+----------------+""" + + agent = AgentFactory.create( + name="Test Agent", + description="Test description", + instructions=INSTRUCTIONS, + llm_id="6646261c6eb563165658bbb1", + ) + # Run the agent + response = agent.run("Who have more than 30 years old?", output_format=OutputFormat.JSON, expected_output=Response) + + # Verify response basics + assert response is not None + assert isinstance(response, AgentResponse) + + try: + response_json = json.loads(response.data.output) + except Exception: + import re + + response_json = re.search(r"```json(.*?)```", response.data.output, re.DOTALL).group(1) + response_json = json.loads(response_json) + assert "result" in response_json + assert len(response_json["result"]) > 0 + + more_than_30_years_old = [ + "João Silva", + "Pedro Oliveira", + "Carlos Pereira", + "Beatriz Lima", + "Julia Rodrigues", + "Miguel Almeida", + "Sofia Carvalho", + ] + + for person in response_json["result"]: + assert "name" in person + assert "age" in person + assert "city" in person + assert person["name"] in more_than_30_years_old diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index baa0a34b..47d88365 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ 
b/tests/functional/team_agent/team_agent_functional_test.py @@ -634,3 +634,100 @@ def test_team_agent_with_instructions(delete_agents_and_team_agents): team_agent.delete() agent_1.delete() agent_2.delete() + + +def test_run_team_agent_with_expected_output(): + from pydantic import BaseModel + from typing import Optional, List + from aixplain.modules.agent import AgentResponse + from aixplain.modules.agent.output_format import OutputFormat + + class Person(BaseModel): + name: str + age: int + city: Optional[str] = None + + class Response(BaseModel): + result: List[Person] + + INSTRUCTIONS = """Answer questions based on the following context: + ++-----------------+-------+----------------+ +| Name | Age | City | ++=================+=======+================+ +| João Silva | 34 | São Paulo | ++-----------------+-------+----------------+ +| Maria Santos | 28 | Rio de Janeiro | ++-----------------+-------+----------------+ +| Pedro Oliveira | 45 | | ++-----------------+-------+----------------+ +| Ana Costa | 19 | Recife | ++-----------------+-------+----------------+ +| Carlos Pereira | 52 | Belo Horizonte | ++-----------------+-------+----------------+ +| Beatriz Lima | 31 | | ++-----------------+-------+----------------+ +| Lucas Ferreira | 25 | Curitiba | ++-----------------+-------+----------------+ +| Julia Rodrigues | 41 | Salvador | ++-----------------+-------+----------------+ +| Miguel Almeida | 37 | | ++-----------------+-------+----------------+ +| Sofia Carvalho | 29 | Brasília | ++-----------------+-------+----------------+""" + + agent = AgentFactory.create( + name="Test Agent", + description="Test description", + instructions=INSTRUCTIONS, + tasks=[ + AgentFactory.create_task( + name="Task 1", + description="Check table for information about people related to the query", + expected_output="A table with the following columns: Name, Age, City", + ) + ], + llm_id="6646261c6eb563165658bbb1", + ) + + team_agent = TeamAgentFactory.create( + name="Team Agent", + 
agents=[agent], + description="Team agent", + llm_id="6646261c6eb563165658bbb1", + use_mentalist=False, + use_inspector=False, + ) + + # Run the team agent + response = team_agent.run("Who have more than 30 years old?", output_format=OutputFormat.JSON, expected_output=Response) + + # Verify response basics + assert response is not None + assert isinstance(response, AgentResponse) + + try: + response_json = json.loads(response.data.output) + except Exception: + import re + + response_json = re.search(r"```json(.*?)```", response.data.output, re.DOTALL).group(1) + response_json = json.loads(response_json) + assert "result" in response_json + assert len(response_json["result"]) > 0 + + more_than_30_years_old = [ + "João Silva", + "Pedro Oliveira", + "Carlos Pereira", + "Beatriz Lima", + "Julia Rodrigues", + "Miguel Almeida", + "Sofia Carvalho", + ] + + for person in response_json["result"]: + assert "name" in person + assert "age" in person + assert "city" in person + assert person["name"] in more_than_30_years_old From cecd3a6f6742ddcc71be0061e87dbf18214efd15 Mon Sep 17 00:00:00 2001 From: Zaina Abu Shaban Date: Tue, 20 May 2025 21:49:04 +0300 Subject: [PATCH 28/62] Added serialize function for save (#529) * Added serialize function for save * fixed syntax error --- aixplain/utils/asset_cache.py | 17 ++++++++++++++++- tests/functional/model/run_model_test.py | 2 +- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/aixplain/utils/asset_cache.py b/aixplain/utils/asset_cache.py index a2b8d09e..3be2e9f6 100644 --- a/aixplain/utils/asset_cache.py +++ b/aixplain/utils/asset_cache.py @@ -112,7 +112,7 @@ def save(self): data_dict = {} for asset_id, asset in self.store.data.items(): try: - data_dict[asset_id] = asset.to_dict() + data_dict[asset_id] = serialize(asset) except Exception as e: logger.error(f"Error serializing {asset_id}: {e}") serializable_store = { @@ -138,3 +138,18 @@ def get_all(self) -> List[T]: def has_valid_cache(self) -> bool: return 
self.store.expiry >= time.time() and bool(self.store.data) + +def serialize(obj): + if isinstance(obj, (str, int, float, bool, type(None))): + return obj + elif isinstance(obj, (list, tuple, set)): + return [serialize(o) for o in obj] + elif isinstance(obj, dict): + return {str(k): serialize(v) for k, v in obj.items()} + elif hasattr(obj, "to_dict"): + return serialize(obj.to_dict()) + elif hasattr(obj, "__dict__"): + return serialize(vars(obj)) + else: + return str(obj) + diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index bf6d9320..95192770 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -217,7 +217,7 @@ def test_aixplain_model_cache_creation(): assert "data" in cache_data, "Cache file structure invalid - missing 'data' key." assert any(m.get("id") == model_id for m in cache_data["data"]["items"]), "Instantiated model not found in cache." -======= + def test_index_model_air_with_image(): from aixplain.factories import IndexFactory from aixplain.modules.model.record import Record From 6aa648f11d11145376659d3edb17744e5ec079f7 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <50206820+basitanees@users.noreply.github.com> Date: Wed, 21 May 2025 00:10:19 +0300 Subject: [PATCH 29/62] Add Embedding Params to Model (#534) --- aixplain/factories/model_factory/utils.py | 11 ++++++ aixplain/modules/model/__init__.py | 44 +++++------------------ aixplain/modules/model/index_model.py | 18 ++++++++++ tests/functional/model/run_model_test.py | 19 +++++++++- 4 files changed, 56 insertions(+), 36 deletions(-) diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index a5f489be..2a1c9269 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -36,6 +36,16 @@ def create_model_from_response(response: Dict) -> Model: if len(values) > 0: parameters[param["name"]] = values + 
additional_kwargs = {} + attributes = response.get("attributes", None) + if attributes: + embedding_model = next((item["code"] for item in attributes if item["name"] == "embeddingmodel"), None) + if embedding_model: + additional_kwargs["embedding_model"] = embedding_model + embedding_size = next((item["value"] for item in attributes if item["name"] == "embeddingSize"), None) + if embedding_size: + additional_kwargs["embedding_size"] = embedding_size + function_id = response["function"]["id"] function = Function(function_id) function_input_params, function_output_params = function.get_input_output_params() @@ -100,6 +110,7 @@ def create_model_from_response(response: Dict) -> Model: temperature=temperature, supports_streaming=response.get("supportsStreaming", False), status=status, + **additional_kwargs, ) diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 04e39deb..b00b1993 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -123,9 +123,7 @@ def to_dict(self) -> Dict: Returns: Dict: Model Information """ - clean_additional_info = { - k: v for k, v in self.additional_info.items() if v is not None - } + clean_additional_info = {k: v for k, v in self.additional_info.items() if v not in [None, [], {}]} return { "id": self.id, "name": self.name, @@ -193,9 +191,7 @@ def sync_poll( logging.error(f"Polling for Model: polling for {name}: {e}") break if response_body["completed"] is True: - logging.debug( - f"Polling for Model: Final status of polling for {name}: {response_body}" - ) + logging.debug(f"Polling for Model: Final status of polling for {name}: {response_body}") else: response_body = ModelResponse( status=ResponseStatus.FAILED, @@ -296,9 +292,7 @@ def run( try: poll_url = response["url"] end = time.time() - return self.sync_poll( - poll_url, name=name, timeout=timeout, wait_time=wait_time - ) + return self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) except 
Exception as e: msg = f"Error in request for {name} - {traceback.format_exc()}" logging.error(f"Model Run: Error in running for {name}: {e}") @@ -370,9 +364,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): resp = None try: url = urljoin(self.backend_url, f"sdk/finetune/{self.id}/ml-logs") - logging.info( - f"Start service for GET Check FineTune status Model - {url} - {headers}" - ) + logging.info(f"Start service for GET Check FineTune status Model - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() finetune_status = AssetStatus(resp["finetuneStatus"]) @@ -399,21 +391,9 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): status = FinetuneStatus( status=finetune_status, model_status=model_status, - epoch=( - float(log["epoch"]) - if "epoch" in log and log["epoch"] is not None - else None - ), - training_loss=( - float(log["trainLoss"]) - if "trainLoss" in log and log["trainLoss"] is not None - else None - ), - validation_loss=( - float(log["evalLoss"]) - if "evalLoss" in log and log["evalLoss"] is not None - else None - ), + epoch=(float(log["epoch"]) if "epoch" in log and log["epoch"] is not None else None), + training_loss=(float(log["trainLoss"]) if "trainLoss" in log and log["trainLoss"] is not None else None), + validation_loss=(float(log["evalLoss"]) if "evalLoss" in log and log["evalLoss"] is not None else None), ) else: status = FinetuneStatus( @@ -421,9 +401,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): model_status=model_status, ) - logging.info( - f"Response for GET Check FineTune status Model - Id {self.id} / Status {status.status.value}." 
- ) + logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status.status.value}.") return status except Exception: message = "" @@ -463,11 +441,7 @@ def from_dict(cls, data: Dict) -> "Model": function=Function(data.get("function")), is_subscribed=data.get("is_subscribed", False), cost=data.get("cost"), - created_at=( - datetime.fromisoformat(data["created_at"]) - if data.get("created_at") - else None - ), + created_at=(datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None), input_params=data.get("input_params"), output_params=data.get("output_params"), model_params=data.get("model_params"), diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index 8a5b18fc..f0292a9d 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -83,6 +83,24 @@ def __init__( self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL self.embedding_model = embedding_model + if embedding_model: + try: + from aixplain.factories import ModelFactory + + model = ModelFactory.get(embedding_model) + self.embedding_size = model.additional_info["embedding_size"] + except Exception as e: + import warnings + + warnings.warn(f"Failed to get embedding size for embedding model {embedding_model}: {e}") + self.embedding_size = None + + def to_dict(self) -> Dict: + data = super().to_dict() + data["embedding_model"] = self.embedding_model + data["embedding_size"] = self.embedding_size + data["collection_type"] = self.version.split("-", 1)[0] + return data def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: """Search for documents in the index diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 95192770..03f22cc4 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -14,6 +14,7 @@ from aixplain.modules.model import 
Model from aixplain.factories.index_factory.utils import AirParams, VectaraParams, GraphRAGParams, ZeroEntropyParams + def pytest_generate_tests(metafunc): if "llm_model" in metafunc.fixturenames: four_weeks_ago = datetime.now(timezone.utc) - timedelta(weeks=4) @@ -80,6 +81,21 @@ def test_run_async(): assert "teste" in response["data"].lower() +@pytest.mark.parametrize( + "embedding_model,supplier_params", + [pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="AIR - OpenAI Ada 002")], +) +def test_index_model_with_embedding(embedding_model, supplier_params): + from uuid import uuid4 + from aixplain.factories import IndexFactory + + params = supplier_params(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) + index_model = IndexFactory.create(params=params) + assert index_model.embedding_model == embedding_model + assert index_model.embedding_size == 1536, f"Embedding size mismatch for {embedding_model}" + index_model.delete() + + def run_index_model(index_model): from aixplain.modules.model.record import Record @@ -206,7 +222,7 @@ def test_aixplain_model_cache_creation(): os.remove(cache_file) # Instantiate the Model (replace this with a real model ID from your env) - model_id = "6239efa4822d7a13b8e20454" # Translate from Punjabi to Portuguese (Brazil) + model_id = "6239efa4822d7a13b8e20454" # Translate from Punjabi to Portuguese (Brazil) _ = Model(id=model_id) # Assert the cache file was created @@ -218,6 +234,7 @@ def test_aixplain_model_cache_creation(): assert "data" in cache_data, "Cache file structure invalid - missing 'data' key." assert any(m.get("id") == model_id for m in cache_data["data"]["items"]), "Instantiated model not found in cache." 
+ def test_index_model_air_with_image(): from aixplain.factories import IndexFactory from aixplain.modules.model.record import Record From ad4ddbf281ed7768353e94c30c4851ad5adc15f2 Mon Sep 17 00:00:00 2001 From: Muhammad-Elmallah <145364766+Muhammad-Elmallah@users.noreply.github.com> Date: Wed, 21 May 2025 21:14:56 +0300 Subject: [PATCH 30/62] Prod 1785 enable adding any embedding in ai r not just embedding model (#514) * enable calling aiR with any embedding model * Added validation step for embedding models * changeing BGE model_id * getting embedding size through modelfactory * getting embedding size through modelfactory * Update __init__.py --- .gitignore | 6 + aixplain/enums/embedding_model.py | 5 +- aixplain/factories/index_factory/__init__.py | 23 ++-- aixplain/factories/index_factory/utils.py | 26 +++- aixplain/modules/model/index_model.py | 15 ++- tests/functional/model/run_model_test.py | 120 +++++++------------ 6 files changed, 101 insertions(+), 94 deletions(-) diff --git a/.gitignore b/.gitignore index 304f04cb..a9a2c574 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,9 @@ dmypy.json # Vscode .vscode .DS_Store + +.env.dev +.env.prod +.env.test + +aixplain-dev/ diff --git a/aixplain/enums/embedding_model.py b/aixplain/enums/embedding_model.py index 31618580..4e77fa22 100644 --- a/aixplain/enums/embedding_model.py +++ b/aixplain/enums/embedding_model.py @@ -21,13 +21,10 @@ class EmbeddingModel(str, Enum): - SNOWFLAKE_ARCTIC_EMBED_M_LONG = "6658d40729985c2cf72f42ec" OPENAI_ADA002 = "6734c55df127847059324d9e" - SNOWFLAKE_ARCTIC_EMBED_L_V2_0 = "678a4f8547f687504744960a" JINA_CLIP_V2_MULTIMODAL = "67c5f705d8f6a65d6f74d732" MULTILINGUAL_E5_LARGE = "67efd0772a0a850afa045af3" - BGE_M3 = "67f401032a0a850afa045b19" - AIXPLAIN_LEGAL_EMBEDDINGS = "681254b668e47e7844c1f15a" + BGE_M3 = "67efd4f92a0a850afa045af7" def __str__(self): return self._value_ diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index 
74e3ee1a..a5691988 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -29,6 +29,14 @@ T = TypeVar("T", bound=BaseIndexParams) +import os +from aixplain.utils.file_utils import _request_with_retry +from urllib.parse import urljoin + +def validate_embedding_model(model_id) -> bool: + model = ModelFactory.get(model_id) + return model.function == Function.TEXT_EMBEDDING + class IndexFactory(ModelFactory, Generic[T]): @classmethod @@ -36,7 +44,7 @@ def create( cls, name: Optional[Text] = None, description: Optional[Text] = None, - embedding_model: EmbeddingModel = EmbeddingModel.OPENAI_ADA002, + embedding_model: Union[EmbeddingModel, str] = EmbeddingModel.OPENAI_ADA002, params: Optional[T] = None, **kwargs, ) -> IndexModel: @@ -59,14 +67,13 @@ def create( assert ( name is not None and description is not None and embedding_model is not None ), "Index Factory Exception: name, description, and embedding_model must be provided when params is not" - data = { - "data": name, - "description": description, - "model": embedding_model, - } - + if validate_embedding_model(embedding_model): + data = { + "data": name, + "description": description, + "model": embedding_model, + } model = cls.get(model_id) - response = model.run(data=data) if response.status == ResponseStatus.SUCCESS: model_id = response.data diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py index 83efd469..b09df3c0 100644 --- a/aixplain/factories/index_factory/utils.py +++ b/aixplain/factories/index_factory/utils.py @@ -1,7 +1,9 @@ -from pydantic import BaseModel, ConfigDict -from typing import Text, Optional, ClassVar, Dict +from pydantic import BaseModel, ConfigDict, field_validator +from typing import Text, Optional, ClassVar, Dict, Union from aixplain.enums import IndexStores, EmbeddingModel from abc import ABC, abstractmethod +from aixplain.factories import ModelFactory +from aixplain.enums import 
Function class BaseIndexParams(BaseModel, ABC): @@ -22,16 +24,29 @@ def id(self) -> str: class BaseIndexParamsWithEmbeddingModel(BaseIndexParams, ABC): - embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 + embedding_model: Optional[Union[EmbeddingModel, str]] = EmbeddingModel.OPENAI_ADA002 embedding_size: Optional[int] = None + @field_validator('embedding_model') + def validate_embedding_model(cls, model_id) -> bool: + model = ModelFactory.get(model_id) + if model.function == Function.TEXT_EMBEDDING: + return model_id + else: + raise ValueError("This is not an embedding model") + + def to_dict(self): data = super().to_dict() data["model"] = data.pop("embedding_model") + if data.get("embedding_size"): data["additional_params"] = {"embedding_size": data.pop("embedding_size")} return data + + + class VectaraParams(BaseIndexParams): _id: ClassVar[str] = IndexStores.VECTARA.get_model_id() @@ -64,3 +79,8 @@ class GraphRAGParams(BaseIndexParamsWithEmbeddingModel): @property def id(self) -> str: return self._id + + + + + diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index f0292a9d..15a1a474 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -7,6 +7,11 @@ from enum import Enum from typing import List +import os + +from urllib.parse import urljoin +from aixplain.utils.file_utils import _request_with_retry + class IndexFilterOperator(Enum): EQUALS = "==" @@ -37,6 +42,8 @@ def to_dict(self): } + + class IndexModel(Model): def __init__( self, @@ -49,7 +56,7 @@ def __init__( function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, - embedding_model: Optional[EmbeddingModel] = None, + embedding_model: Union[EmbeddingModel, str] = None, **additional_info, ) -> None: """Index Init @@ -64,7 +71,7 @@ def __init__( function (Function, optional): model AI function. Defaults to None. 
is_subscribed (bool, optional): Is the user subscribed. Defaults to False. cost (Dict, optional): model price. Defaults to None. - embedding_model (EmbeddingModel, optional): embedding model. Defaults to None. + embedding_model (Union[EmbeddingModel, str], optional): embedding model. Defaults to None. **additional_info: Any additional Model info to be saved """ assert function == Function.SEARCH, "Index only supports search function" @@ -102,6 +109,8 @@ def to_dict(self) -> Dict: data["collection_type"] = self.version.split("-", 1)[0] return data + + def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: """Search for documents in the index @@ -131,7 +140,7 @@ def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) - "data": query or uri, "dataType": value_type, "filters": [filter.to_dict() for filter in filters], - "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, + "payload": {"uri": uri, "value_type": value_type, "top_k": top_k} } return self.run(data=data) diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 03f22cc4..14495b53 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -1,7 +1,5 @@ __author__ = "thiagocastroferreira" -import os -import json import pytest import requests @@ -10,9 +8,12 @@ from aixplain.modules import LLM from datetime import datetime, timedelta, timezone from pathlib import Path -from aixplain.utils.cache_utils import CACHE_FOLDER -from aixplain.modules.model import Model from aixplain.factories.index_factory.utils import AirParams, VectaraParams, GraphRAGParams, ZeroEntropyParams +from aixplain.factories import IndexFactory +from aixplain.modules.model.record import Record +import time + + def pytest_generate_tests(metafunc): @@ -81,48 +82,22 @@ def test_run_async(): assert "teste" in response["data"].lower() -@pytest.mark.parametrize( - 
"embedding_model,supplier_params", - [pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="AIR - OpenAI Ada 002")], -) -def test_index_model_with_embedding(embedding_model, supplier_params): - from uuid import uuid4 - from aixplain.factories import IndexFactory - - params = supplier_params(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) - index_model = IndexFactory.create(params=params) - assert index_model.embedding_model == embedding_model - assert index_model.embedding_size == 1536, f"Embedding size mismatch for {embedding_model}" - index_model.delete() - - -def run_index_model(index_model): +def run_index_model(index_model, retries): from aixplain.modules.model.record import Record - index_model.upsert([Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})]) + + for _ in range(retries): + try: + index_model.upsert([Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})]) + break + except Exception as e: + time.sleep(180) + response = index_model.search("Berlin") assert str(response.status) == "SUCCESS" assert "germany" in response.data.lower() assert index_model.count() == 1 - index_model.upsert([Record(value="Ankara is the capital of Turkey.", value_type="text", uri="", id="1", attributes={})]) - response = index_model.search("Ankara") - assert str(response.status) == "SUCCESS" - assert "turkey" in response.data.lower() - assert index_model.count() == 1 - - index_model.upsert([Record(value="London is the capital of England.", value_type="text", uri="", id="2", attributes={})]) - assert index_model.count() == 2 - - response = index_model.get_record("1") - assert str(response.status) == "SUCCESS" - assert response.data == "Ankara is the capital of Turkey." 
- assert index_model.count() == 2 - - response = index_model.delete_record("1") - assert str(response.status) == "SUCCESS" - assert index_model.count() == 1 - index_model.delete() @@ -133,11 +108,10 @@ def run_index_model(index_model): pytest.param(None, ZeroEntropyParams, id="ZERO_ENTROPY"), pytest.param(EmbeddingModel.OPENAI_ADA002, GraphRAGParams, id="GRAPHRAG"), pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="AIR - OpenAI Ada 002"), - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, AirParams, id="AIR - Snowflake Arctic Embed M Long"), - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, AirParams, id="AIR - Snowflake Arctic Embed L v2.0"), + pytest.param("6658d40729985c2cf72f42ec", AirParams, id="AIR - Snowflake Arctic Embed M Long"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), - pytest.param(EmbeddingModel.BGE_M3, AirParams, id="AIR - BGE M3"), - pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="AIR - aiXplain Legal Embeddings"), + pytest.param("67efd4f92a0a850afa045af7", AirParams, id="AIR - BGE M3"), + pytest.param("681254b668e47e7844c1f15a", AirParams, id="AIR - aiXplain Legal Embeddings"), ], ) def test_index_model(embedding_model, supplier_params): @@ -146,23 +120,26 @@ def test_index_model(embedding_model, supplier_params): params = supplier_params(name=str(uuid4()), description=str(uuid4())) if embedding_model is not None: + print(f"Embedding Model : {embedding_model}") params = supplier_params(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) index_model = IndexFactory.create(params=params) - run_index_model(index_model) - + if embedding_model in [EmbeddingModel.MULTILINGUAL_E5_LARGE, EmbeddingModel.BGE_M3]: + retries = 3 + else: + retries = 1 + run_index_model(index_model, retries) @pytest.mark.parametrize( "embedding_model,supplier_params", [ pytest.param(None, VectaraParams, id="VECTARA"), 
pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="OpenAI Ada 002"), - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, AirParams, id="Snowflake Arctic Embed M Long"), - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, AirParams, id="Snowflake Arctic Embed L v2.0"), + pytest.param("6658d40729985c2cf72f42ec", AirParams, id="Snowflake Arctic Embed M Long"), pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, AirParams, id="Jina Clip v2 Multimodal"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), - pytest.param(EmbeddingModel.BGE_M3, AirParams, id="BGE M3"), - pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="aiXplain Legal Embeddings"), + pytest.param("67efd4f92a0a850afa045af7", AirParams, id="BGE M3"), + pytest.param("681254b668e47e7844c1f15a", AirParams, id="aiXplain Legal Embeddings"), ], ) def test_index_model_with_filter(embedding_model, supplier_params): @@ -179,10 +156,23 @@ def test_index_model_with_filter(embedding_model, supplier_params): params = supplier_params(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) index_model = IndexFactory.create(params=params) - index_model.upsert([Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})]) - index_model.upsert( - [Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})] - ) + if embedding_model in [EmbeddingModel.MULTILINGUAL_E5_LARGE, EmbeddingModel.BGE_M3]: + retries = 3 + else: + retries = 1 + for _ in range(retries): + try: + index_model.upsert([Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})]) + break + except Exception: + time.sleep(180) + for _ in range(retries): + try: + index_model.upsert([Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})]) + break + except Exception: + 
time.sleep(180) + assert index_model.count() == 2 response = index_model.search( "", filters=[IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS)] @@ -212,29 +202,6 @@ def test_llm_run_with_file(): assert "🤖" in response["data"], "Robot emoji should be present in the response" -def test_aixplain_model_cache_creation(): - """Ensure AssetCache is triggered and cache is created.""" - - cache_file = os.path.join(CACHE_FOLDER, "models.json") - - # Clean up cache before the test - if os.path.exists(cache_file): - os.remove(cache_file) - - # Instantiate the Model (replace this with a real model ID from your env) - model_id = "6239efa4822d7a13b8e20454" # Translate from Punjabi to Portuguese (Brazil) - _ = Model(id=model_id) - - # Assert the cache file was created - assert os.path.exists(cache_file), "Expected cache file was not created." - - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) - - assert "data" in cache_data, "Cache file structure invalid - missing 'data' key." - assert any(m.get("id") == model_id for m in cache_data["data"]["items"]), "Instantiated model not found in cache." 
- - def test_index_model_air_with_image(): from aixplain.factories import IndexFactory from aixplain.modules.model.record import Record @@ -278,6 +245,7 @@ def test_index_model_air_with_image(): index_model.upsert(records) + response = index_model.search("beach") assert str(response.status) == "SUCCESS" second_record = response.details[1]["metadata"]["uri"] From bb74b68e44d77b1ba315cf7949bdbf2656ac9c0f Mon Sep 17 00:00:00 2001 From: Zaina Abu Shaban Date: Thu, 22 May 2025 17:29:50 +0300 Subject: [PATCH 31/62] Changed cache default to false (#536) --- aixplain/factories/model_factory/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index bf0c21b5..64ab8d99 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -109,7 +109,7 @@ def create_utility_model( @classmethod def get( - cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = True + cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = False ) -> Model: """Create a 'Model' object from model id""" cache = AssetCache(Model) From b2a00e3dbd52157802610ea173fc2725339e200b Mon Sep 17 00:00:00 2001 From: Zaina Abu Shaban Date: Fri, 23 May 2025 03:46:33 +0300 Subject: [PATCH 32/62] added filelock (#538) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 1cf10327..e21333f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "Jinja2==3.1.6", "sentry-sdk>=1.0.0", "pydantic>=2.10.6" + "filelock>=3.0.0" ] [project.urls] From 239b56aef4dd1eccbed9795c867b49e8c9cac264 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Fri, 23 May 2025 09:47:10 -0300 Subject: [PATCH 33/62] Add filelock to requirements (#540) * added filelock * Fix typo in pyproject --------- Co-authored-by: 
xainaz --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e21333f0..0bc5ec6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ dependencies = [ "dataclasses-json>=0.5.2", "Jinja2==3.1.6", "sentry-sdk>=1.0.0", - "pydantic>=2.10.6" + "pydantic>=2.10.6", "filelock>=3.0.0" ] From 1f6b614d14b4f64b80dbac5fb075d3b6a091bc98 Mon Sep 17 00:00:00 2001 From: Zaina Abu Shaban Date: Fri, 23 May 2025 19:06:11 +0300 Subject: [PATCH 34/62] cache duration (#541) --- aixplain/utils/asset_cache.py | 8 ++++---- aixplain/utils/cache_utils.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/aixplain/utils/asset_cache.py b/aixplain/utils/asset_cache.py index 3be2e9f6..357b70ef 100644 --- a/aixplain/utils/asset_cache.py +++ b/aixplain/utils/asset_cache.py @@ -17,7 +17,7 @@ # Constants CACHE_FOLDER = ".cache" -DEFAULT_CACHE_EXPIRY = 86400 +CACHE_DURATION = 86400 @dataclass @@ -51,15 +51,15 @@ def __init__( def compute_expiry(self): try: - expiry = int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) + expiry = int(os.getenv("CACHE_EXPIRY_TIME", CACHE_DURATION)) except Exception as e: logger.warning( f"Failed to parse CACHE_EXPIRY_TIME: {e}, " - f"fallback to default value {DEFAULT_CACHE_EXPIRY}" + f"fallback to default value {CACHE_DURATION}" ) # remove the CACHE_EXPIRY_TIME from the environment variables del os.environ["CACHE_EXPIRY_TIME"] - expiry = DEFAULT_CACHE_EXPIRY + expiry = CACHE_DURATION return time.time() + int(expiry) diff --git a/aixplain/utils/cache_utils.py b/aixplain/utils/cache_utils.py index 01981701..fcfe1cb6 100644 --- a/aixplain/utils/cache_utils.py +++ b/aixplain/utils/cache_utils.py @@ -7,11 +7,11 @@ CACHE_FOLDER = ".cache" CACHE_FILE = f"{CACHE_FOLDER}/cache.json" LOCK_FILE = f"{CACHE_FILE}.lock" -DEFAULT_CACHE_EXPIRY = 86400 +CACHE_DURATION = 86400 def get_cache_expiry(): - return int(os.getenv("CACHE_EXPIRY_TIME", DEFAULT_CACHE_EXPIRY)) + return 
int(os.getenv("CACHE_EXPIRY_TIME", CACHE_DURATION)) def save_to_cache(cache_file, data, lock_file): From 14485e23bcc32ccecb1e3cf51bb9272a3a6e7fec Mon Sep 17 00:00:00 2001 From: Yunsu Kim Date: Mon, 26 May 2025 10:32:37 +0200 Subject: [PATCH 35/62] ErrorCode returns code in string (#542) --- aixplain/exceptions/types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aixplain/exceptions/types.py b/aixplain/exceptions/types.py index e41c7fd2..56c710b5 100644 --- a/aixplain/exceptions/types.py +++ b/aixplain/exceptions/types.py @@ -59,6 +59,9 @@ class inheriting from the corresponding category exception (e.g., AX_SVC_ERROR = "AX-SVC-1000" # General service error. Use when a specific aiXplain service or endpoint is unavailable or malfunctioning (e.g., service downtime, internal component failure. AX_INT_ERROR = "AX-INT-1000" # General internal error. Use for unexpected server-side errors that are not covered by other categories. This often indicates a bug or an issue within the aiXplain platform itself. 
+ def __str__(self): + return self.value + class AixplainBaseException(Exception): """Base exception class for all aiXplain exceptions.""" From 59be7cc3e697de46d9c7c8061d141ac7247f780f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Tue, 27 May 2025 14:58:20 +0300 Subject: [PATCH 36/62] ENG-2003 : Add LLM's to Agents as Object (#524) * Add llm params to agent and team agent * Add llm params to agent and team agent * ENG-2003 : added llm instance method and create payload from it * correct get params for llms * Teamagent llms conditioned for backward compatibility * added llm id * corrected team agent llm description --------- Co-authored-by: lucas-aixplain --- aixplain/factories/agent_factory/__init__.py | 47 +++++++- aixplain/factories/agent_factory/utils.py | 41 +++++++ .../factories/team_agent_factory/__init__.py | 102 +++++++++++++++++- .../factories/team_agent_factory/utils.py | 51 ++++++++- aixplain/modules/agent/__init__.py | 24 +++-- aixplain/modules/team_agent/__init__.py | 35 ++++-- aixplain/utils/llm_utils.py | 23 ++++ .../functional/agent/agent_functional_test.py | 30 ++++++ .../team_agent/team_agent_functional_test.py | 42 ++++++++ 9 files changed, 369 insertions(+), 26 deletions(-) create mode 100644 aixplain/utils/llm_utils.py diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 92c6ec3b..0190cb55 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -37,6 +37,7 @@ SQLTool, ) from aixplain.modules.model import Model +from aixplain.modules.model.llm_model import LLM from aixplain.modules.pipeline import Pipeline from aixplain.utils import config from typing import Callable, Dict, List, Optional, Text, Union @@ -44,6 +45,7 @@ from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from aixplain.enums import DatabaseSourceType +from aixplain.utils.llm_utils import 
get_llm_instance class AgentFactory: @@ -53,7 +55,8 @@ def create( name: Text, description: Text, instructions: Optional[Text] = None, - llm_id: Text = "669a63646eb56306647e1091", + llm: Optional[Union[LLM, Text]] = None, + llm_id: Optional[Text] = None, tools: List[Union[Tool, Model]] = [], api_key: Text = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", @@ -71,21 +74,37 @@ def create( name (Text): name of the agent description (Text): description of the agent role. instructions (Text): role of the agent. - llm_id (Text, optional): aiXplain ID of the large language model to be used as agent. Defaults to "669a63646eb56306647e1091" (GPT-4o mini). + llm (Optional[Union[LLM, Text]], optional): LLM instance to use as an object or as an ID. + llm_id (Optional[Text], optional): ID of LLM to use if no LLM instance provided. Defaults to None. tools (List[Union[Tool, Model]], optional): list of tool for the agent. Defaults to []. api_key (Text, optional): team/user API key. Defaults to config.TEAM_API_KEY. supplier (Union[Dict, Text, Supplier, int], optional): owner of the agent. Defaults to "aiXplain". version (Optional[Text], optional): version of the agent. Defaults to None. tasks (List[AgentTask], optional): list of tasks for the agent. Defaults to []. + Returns: Agent: created Agent """ + if llm is None and llm_id is not None: + llm = get_llm_instance(llm_id, api_key=api_key) + elif llm is None: + # Use default GPT-4o if no LLM specified + llm = get_llm_instance("669a63646eb56306647e1091", api_key=api_key) + + if instructions is None: + warnings.warn( + "Use `instructions` to define the **system prompt**. " + "Use `description` to provide a **short summary** of the agent for metadata and dashboard display. " + "Note: In upcoming releases, `instructions` will become a required parameter.", + UserWarning, + ) warnings.warn( - "Use `instructions` to define the **system prompt**. 
" - "Use `description` to provide a **short summary** of the agent for metadata and dashboard display. " - "Note: In upcoming releases, `instructions` will become a required parameter.", + "Use `llm` to define the large language model (aixplain.modules.model.llm_model.LLM) to be used as agent. " + "Use `llm_id` to provide the model ID of the large language model to be used as agent. " + "Note: In upcoming releases, `llm` will become a required parameter.", UserWarning, ) + from aixplain.factories.agent_factory.utils import build_agent agent = None @@ -124,7 +143,22 @@ def create( "llmId": llm_id, "status": "draft", "tasks": [task.to_dict() for task in tasks], + "tools": [], } + + if llm is not None: + llm = get_llm_instance(llm, api_key=api_key) if isinstance(llm, str) else llm + payload["tools"].append( + { + "type": "llm", + "description": "main", + "parameters": llm.get_parameters().to_list() if llm.get_parameters() else None, + } + ) + payload["llmId"] = llm.id + # Store the LLM object in payload to avoid recreating it + payload["llm"] = llm + agent = build_agent(payload=payload, tools=tools, api_key=api_key) agent.validate(raise_exception=True) response = "Unspecified error" @@ -136,6 +170,9 @@ def create( raise Exception("Agent Onboarding Error: Please contact the administrators.") if 200 <= r.status_code < 300: + # Preserve the LLM if it exists + if "llm" in payload: + response["llm"] = payload["llm"] agent = build_agent(payload=response, tools=tools, api_key=api_key) else: error_msg = f"Agent Onboarding Error: {response}" diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 704d7fe7..43540344 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -2,8 +2,10 @@ import logging import aixplain.utils.config as config +from aixplain.utils.llm_utils import get_llm_instance from aixplain.enums import Function, Supplier from aixplain.enums.asset_status import 
AssetStatus +from aixplain.modules.model.llm_model import LLM from aixplain.modules.agent import Agent from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_task import AgentTask @@ -81,6 +83,42 @@ def build_tool(tool: Dict): return tool +def build_llm(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> LLM: + """Build a LLM from a dictionary.""" + # Get LLM from tools if present + llm = None + # First check if we have the LLM object + if "llm" in payload: + llm = payload["llm"] + # Otherwise create from the parameters + elif "tools" in payload: + for tool in payload["tools"]: + if tool["type"] == "llm" and tool["description"] == "main": + + llm = get_llm_instance(payload["llmId"], api_key=api_key) + # Set parameters from the tool + if "parameters" in tool: + # Apply all parameters directly to the LLM properties + for param in tool["parameters"]: + param_name = param["name"] + param_value = param["value"] + # Apply any parameter that exists as an attribute on the LLM + if hasattr(llm, param_name): + setattr(llm, param_name, param_value) + + # Also set model_params for completeness + # Convert parameters list to dictionary format expected by ModelParameters + params_dict = {} + for param in tool["parameters"]: + params_dict[param["name"]] = {"required": False, "value": param["value"]} + # Create ModelParameters and set it on the LLM + from aixplain.modules.model.model_parameters import ModelParameters + + llm.model_params = ModelParameters(params_dict) + break + return llm + + def build_agent(payload: Dict, tools: List[Tool] = None, api_key: Text = config.TEAM_API_KEY) -> Agent: """Instantiate a new agent in the platform.""" tools_dict = payload["assets"] @@ -100,6 +138,8 @@ def build_agent(payload: Dict, tools: List[Tool] = None, api_key: Text = config. 
) continue + llm = build_llm(payload, api_key) + agent = Agent( id=payload["id"] if "id" in payload else "", name=payload.get("name", ""), @@ -110,6 +150,7 @@ def build_agent(payload: Dict, tools: List[Tool] = None, api_key: Text = config. version=payload.get("version", None), cost=payload.get("cost", None), llm_id=payload.get("llmId", GPT_4o_ID), + llm=llm, api_key=api_key, status=AssetStatus(payload["status"]), tasks=[ diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index 6a1db846..cfbbe622 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -32,6 +32,8 @@ from aixplain.utils import config from aixplain.factories.team_agent_factory.utils import build_team_agent from aixplain.utils.request_utils import _request_with_retry +from aixplain.modules.model.llm_model import LLM +from aixplain.utils.llm_utils import get_llm_instance class TeamAgentFactory: @@ -41,6 +43,10 @@ def create( name: Text, agents: List[Union[Text, Agent]], llm_id: Text = "669a63646eb56306647e1091", + llm: Optional[Union[LLM, Text]] = None, + supervisor_llm: Optional[Union[LLM, Text]] = None, + mentalist_llm: Optional[Union[LLM, Text]] = None, + inspector_llm: Optional[Union[LLM, Text]] = None, description: Text = "", api_key: Text = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", @@ -58,6 +64,10 @@ def create( name: The name of the team agent. agents: A list of agents to be added to the team. llm_id: The ID of the LLM to be used for the team agent. + llm (Optional[Union[LLM, Text]], optional): The LLM to be used for the team agent. + supervisor_llm (Optional[Union[LLM, Text]], optional): Main supervisor LLM. Defaults to None. + mentalist_llm (Optional[Union[LLM, Text]], optional): LLM for planning. Defaults to None. + inspector_llm (Optional[Union[LLM, Text]], optional): LLM for inspection. Defaults to None. 
description: The description of the team agent to be displayed in the aiXplain platform. api_key: The API key to be used for the team agent. supplier: The supplier of the team agent. @@ -114,6 +124,19 @@ def create( mentalist_llm_id = llm_id if use_mentalist else None inspector_llm_id = llm_id if use_inspector else None + # Set up LLMs + if llm is None: + llm = get_llm_instance(llm_id, api_key=api_key) + + if supervisor_llm is None: + supervisor_llm = get_llm_instance(llm_id, api_key=api_key) + + if use_mentalist and mentalist_llm is None: + mentalist_llm = get_llm_instance(mentalist_llm_id or "669a63646eb56306647e1091", api_key=api_key) + + if use_inspector and inspector_llm is None: + inspector_llm = get_llm_instance(inspector_llm_id or "669a63646eb56306647e1091", api_key=api_key) + team_agent = None url = urljoin(config.BACKEND_URL, "sdk/agent-communities") headers = {"x-api-key": api_key} @@ -132,19 +155,78 @@ def create( "agents": agent_payload_list, "links": [], "description": description, - "llmId": llm_id, - "supervisorId": llm_id, - "plannerId": mentalist_llm_id, - "inspectorId": inspector_llm_id, + "llmId": llm.id if llm else llm_id, + "supervisorId": supervisor_llm.id if supervisor_llm else llm_id, + "plannerId": mentalist_llm.id if mentalist_llm else mentalist_llm_id, + "inspectorId": inspector_llm.id if inspector_llm else inspector_llm_id, "maxInspectors": max_inspectors, "inspectorTargets": inspector_targets if use_inspector else [], "supplier": supplier, "version": version, "status": "draft", + "tools": [], "role": instructions, } - team_agent = build_team_agent(payload=payload, agents=agent_list, api_key=api_key) + # Add LLM tools to the payload + if llm is not None: + llm = get_llm_instance(llm, api_key=api_key) if isinstance(llm, str) else llm + payload["tools"].append( + { + "type": "llm", + "description": "main", + "parameters": llm.get_parameters().to_list() if llm.get_parameters() else None, + } + ) + + if supervisor_llm is not None: + 
supervisor_llm = ( + get_llm_instance(supervisor_llm, api_key=api_key) if isinstance(supervisor_llm, str) else supervisor_llm + ) + payload["tools"].append( + { + "type": "llm", + "description": "supervisor", + "parameters": supervisor_llm.get_parameters().to_list() if supervisor_llm.get_parameters() else None, + } + ) + + if mentalist_llm is not None: + mentalist_llm = ( + get_llm_instance(mentalist_llm, api_key=api_key) if isinstance(mentalist_llm, str) else mentalist_llm + ) + payload["tools"].append( + { + "type": "llm", + "description": "mentalist", + "parameters": mentalist_llm.get_parameters().to_list() if mentalist_llm.get_parameters() else None, + } + ) + + if inspector_llm is not None: + inspector_llm = ( + get_llm_instance(inspector_llm, api_key=api_key) if isinstance(inspector_llm, str) else inspector_llm + ) + payload["tools"].append( + { + "type": "llm", + "description": "inspector", + "parameters": inspector_llm.get_parameters().to_list() if inspector_llm.get_parameters() else None, + } + ) + + # Store the LLM objects directly in the payload for build_team_agent + internal_payload = payload.copy() + if llm is not None: + internal_payload["llm"] = llm + if supervisor_llm is not None: + internal_payload["supervisor_llm"] = supervisor_llm + if mentalist_llm is not None: + internal_payload["mentalist_llm"] = mentalist_llm + if inspector_llm is not None: + internal_payload["inspector_llm"] = inspector_llm + + team_agent = build_team_agent(payload=internal_payload, agents=agent_list, api_key=api_key) team_agent.validate(raise_exception=True) response = "Unspecified error" try: @@ -155,6 +237,16 @@ def create( raise Exception(e) if 200 <= r.status_code < 300: + # Preserve the LLM objects + if "llm" in internal_payload: + response["llm"] = internal_payload["llm"] + if "supervisor_llm" in internal_payload: + response["supervisor_llm"] = internal_payload["supervisor_llm"] + if "mentalist_llm" in internal_payload: + response["mentalist_llm"] = 
internal_payload["mentalist_llm"] + if "inspector_llm" in internal_payload: + response["inspector_llm"] = internal_payload["inspector_llm"] + team_agent = build_team_agent(payload=response, agents=agent_list, api_key=api_key) else: error_msg = f"{response}" diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index 48268968..ce382531 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -8,6 +8,9 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.modules.agent import Agent from aixplain.modules.team_agent import TeamAgent, InspectorTarget +from aixplain.factories.agent_factory import AgentFactory +from aixplain.factories.model_factory import ModelFactory +from aixplain.modules.model.model_parameters import ModelParameters GPT_4o_ID = "6646261c6eb563165658bbb1" @@ -15,8 +18,6 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = config.TEAM_API_KEY) -> TeamAgent: """Instantiate a new team agent in the platform.""" - from aixplain.factories.agent_factory import AgentFactory - agents_dict = payload["agents"] payload_agents = agents if payload_agents is None: @@ -33,6 +34,49 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = inspector_targets = [InspectorTarget(target.lower()) for target in payload.get("inspectorTargets", [])] + # Get LLMs from tools if present + supervisor_llm = None + mentalist_llm = None + inspector_llm = None + + # First check if we have direct LLM objects in the payload + if "supervisor_llm" in payload: + supervisor_llm = payload["supervisor_llm"] + if "mentalist_llm" in payload: + mentalist_llm = payload["mentalist_llm"] + if "inspector_llm" in payload: + inspector_llm = payload["inspector_llm"] + # Otherwise create from the parameters + elif "tools" in payload: + for tool in payload["tools"]: + if tool["type"] == "llm": + llm = 
ModelFactory.get(payload["llmId"], api_key=api_key) + # Set parameters from the tool + if "parameters" in tool: + # Apply all parameters directly to the LLM properties + for param in tool["parameters"]: + param_name = param["name"] + param_value = param["value"] + # Apply any parameter that exists as an attribute on the LLM + if hasattr(llm, param_name): + setattr(llm, param_name, param_value) + + # Also set model_params for completeness + # Convert parameters list to dictionary format expected by ModelParameters + params_dict = {} + for param in tool["parameters"]: + params_dict[param["name"]] = {"required": False, "value": param["value"]} + # Create ModelParameters and set it on the LLM + llm.model_params = ModelParameters(params_dict) + + # Assign LLM based on description + if tool["description"] == "supervisor": + supervisor_llm = llm + elif tool["description"] == "mentalist": + mentalist_llm = llm + elif tool["description"] == "inspector": + inspector_llm = llm + team_agent = TeamAgent( id=payload.get("id", ""), name=payload.get("name", ""), @@ -43,6 +87,9 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = version=payload.get("version", None), cost=payload.get("cost", None), llm_id=payload.get("llmId", GPT_4o_ID), + supervisor_llm=supervisor_llm, + mentalist_llm=mentalist_llm, + inspector_llm=inspector_llm, use_mentalist=True if payload.get("plannerId", None) is not None else False, use_inspector=True if payload.get("inspectorId", None) is not None else False, max_inspectors=payload.get("maxInspectors", 1), diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index eaf628c6..3d449956 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -38,6 +38,7 @@ from pydantic import BaseModel from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin +from aixplain.modules.model.llm_model import LLM from aixplain.utils import config from 
aixplain.modules.mixins import DeployableMixin @@ -70,6 +71,7 @@ def __init__( instructions: Text, tools: List[Union[Tool, Model]] = [], llm_id: Text = "6646261c6eb563165658bbb1", + llm: Optional[LLM] = None, api_key: Optional[Text] = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, @@ -86,7 +88,8 @@ def __init__( description (Text): description of the Agent. instructions (Text): role of the Agent. tools (List[Union[Tool, Model]]): List of tools that the Agent uses. - llm_id (Text, optional): large language model. Defaults to GPT-4o (6646261c6eb563165658bbb1). + llm_id (Text, optional): large language model ID. Defaults to GPT-4o (6646261c6eb563165658bbb1). + llm (LLM, optional): large language model object. Defaults to None. supplier (Text): Supplier of the Agent. version (Text): Version of the Agent. backend_url (str): URL of the backend. @@ -100,6 +103,7 @@ def __init__( for i, _ in enumerate(tools): self.tools[i].api_key = api_key self.llm_id = llm_id + self.llm = llm if isinstance(status, str): try: status = AssetStatus(status) @@ -111,17 +115,14 @@ def __init__( def _validate(self) -> None: """Validate the Agent.""" - from aixplain.factories.model_factory import ModelFactory + from aixplain.utils.llm_utils import get_llm_instance # validate name assert ( re.match(r"^[a-zA-Z0-9 \-\(\)]*$", self.name) is not None ), "Agent Creation Error: Agent name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed." - try: - llm = ModelFactory.get(self.llm_id, api_key=self.api_key) - except Exception: - raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.") + llm = get_llm_instance(self.llm_id, api_key=self.api_key) assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." 
@@ -370,9 +371,18 @@ def to_dict(self) -> Dict: "role": self.instructions, "supplier": (self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier), "version": self.version, - "llmId": self.llm_id, + "llmId": self.llm_id if self.llm is None else self.llm.id, "status": self.status.value, "tasks": [task.to_dict() for task in self.tasks], + "tools": [ + { + "type": "llm", + "description": "main", + "parameters": self.llm.get_parameters().to_list() if self.llm.get_parameters() else None, + } + ] + if self.llm is not None + else [], } def delete(self) -> None: diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index c2961c8a..0485b10b 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -41,7 +41,9 @@ from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.agent.utils import process_variables from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry +from aixplain.modules.model.llm_model import LLM +from aixplain.utils.llm_utils import get_llm_instance from aixplain.modules.mixins import DeployableMixin from pydantic import BaseModel @@ -81,6 +83,10 @@ def __init__( agents: List[Agent] = [], description: Text = "", llm_id: Text = "6646261c6eb563165658bbb1", + llm: Optional[LLM] = None, + supervisor_llm: Optional[LLM] = None, + mentalist_llm: Optional[LLM] = None, + inspector_llm: Optional[LLM] = None, api_key: Optional[Text] = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, @@ -101,6 +107,10 @@ def __init__( agents (List[Agent]): List of agents that the Team Agent uses. description (Text, optional): The description of the team agent to be displayed in the aiXplain platform. Defaults to "". llm_id (Text, optional): large language model. 
Defaults to GPT-4o (6646261c6eb563165658bbb1). + llm (LLM, optional): large language model object. Defaults to None. + supervisor_llm (LLM, optional): supervisor large language model object. Defaults to None. + mentalist_llm (LLM, optional): mentalist large language model object. Defaults to None. + inspector_llm (LLM, optional): inspector large language model object. Defaults to None. supplier (Text): Supplier of the Team Agent. version (Text): Version of the Team Agent. backend_url (str): URL of the backend. @@ -113,10 +123,14 @@ def __init__( self.additional_info = additional_info self.agents = agents self.llm_id = llm_id + self.llm = llm self.use_mentalist = use_mentalist self.use_inspector = use_inspector self.max_inspectors = max_inspectors self.inspector_targets = inspector_targets + self.supervisor_llm = supervisor_llm + self.mentalist_llm = mentalist_llm + self.inspector_llm = inspector_llm self.instructions = instructions if isinstance(status, str): try: @@ -343,6 +357,14 @@ def delete(self) -> None: raise Exception(f"{message}") def to_dict(self) -> Dict: + if self.use_mentalist: + planner_id = self.mentalist_llm.id if self.mentalist_llm else self.llm_id + else: + planner_id = None + if self.use_inspector: + inspector_id = self.inspector_llm.id if self.inspector_llm else self.llm_id + else: + inspector_id = None return { "id": self.id, "name": self.name, @@ -351,10 +373,10 @@ def to_dict(self) -> Dict: ], "links": [], "description": self.description, - "llmId": self.llm_id, - "supervisorId": self.llm_id, - "plannerId": self.llm_id if self.use_mentalist else None, - "inspectorId": self.llm_id if self.use_inspector else None, + "llmId": self.llm.id if self.llm else self.llm_id, + "supervisorId": self.supervisor_llm.id if self.supervisor_llm else self.llm_id, + "plannerId": planner_id, + "inspectorId": inspector_id, "maxInspectors": self.max_inspectors, "inspectorTargets": [target.value for target in self.inspector_targets], "supplier": 
self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier, @@ -365,7 +387,6 @@ def to_dict(self) -> Dict: def _validate(self) -> None: """Validate the Team.""" - from aixplain.factories.model_factory import ModelFactory # validate name assert ( @@ -373,7 +394,7 @@ def _validate(self) -> None: ), "Team Agent Creation Error: Team name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed." try: - llm = ModelFactory.get(self.llm_id) + llm = get_llm_instance(self.llm_id) assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." except Exception: raise Exception(f"Large Language Model with ID '{self.llm_id}' not found.") diff --git a/aixplain/utils/llm_utils.py b/aixplain/utils/llm_utils.py new file mode 100644 index 00000000..82fe3511 --- /dev/null +++ b/aixplain/utils/llm_utils.py @@ -0,0 +1,23 @@ +from typing import Optional, Text +from aixplain.factories.model_factory import ModelFactory +from aixplain.modules.model.llm_model import LLM + + +def get_llm_instance( + llm_id: Text, + api_key: Optional[Text] = None, +) -> LLM: + """Get an LLM instance with specific configuration. + + Args: + llm_id (Text): ID of the LLM model to use + api_key (Optional[Text], optional): API key to use. Defaults to None. 
+ + Returns: + LLM: Configured LLM instance + """ + try: + llm = ModelFactory.get(llm_id, api_key=api_key) + return llm + except Exception: + raise Exception(f"Large Language Model with ID '{llm_id}' not found.") diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 9538ea5e..2a45c2bf 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -606,6 +606,36 @@ def test_agent_with_pipeline_tool(delete_agents_and_team_agents, AgentFactory): assert "hello" in answer["data"]["output"].lower() assert "hello pipeline" in answer["data"]["intermediate_steps"][0]["tool_steps"][0]["tool"].lower() +@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) +def test_agent_llm_parameter_preservation(delete_agents_and_team_agents, AgentFactory): + """Test that LLM parameters like temperature are preserved when creating agents.""" + assert delete_agents_and_team_agents + + # Get an LLM instance and customize its temperature + llm = ModelFactory.get("671be4886eb56397e51f7541") # Anthropic Claude 3.5 Sonnet v1 + original_temperature = llm.temperature + custom_temperature = 0.1 + llm.temperature = custom_temperature + + # Create agent with the custom LLM + agent = AgentFactory.create( + name="LLM Parameter Test Agent", + description="An agent for testing LLM parameter preservation", + instructions="Testing LLM parameter preservation", + llm=llm, + ) + + # Verify that the temperature setting was preserved + assert agent.llm.temperature == custom_temperature + + # Verify that the agent's LLM is the same instance as the original + assert id(agent.llm) == id(llm) + + # Clean up + agent.delete() + + # Reset the LLM temperature to its original value + llm.temperature = original_temperature def test_run_agent_with_expected_output(): from pydantic import BaseModel diff --git a/tests/functional/team_agent/team_agent_functional_test.py 
b/tests/functional/team_agent/team_agent_functional_test.py index 47d88365..b526dcae 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -635,6 +635,48 @@ def test_team_agent_with_instructions(delete_agents_and_team_agents): agent_1.delete() agent_2.delete() +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_llm_parameter_preservation(delete_agents_and_team_agents, run_input_map, TeamAgentFactory): + """Test that LLM parameters like temperature are preserved for all LLM roles in team agents.""" + assert delete_agents_and_team_agents + + # Create a regular agent first + agents = create_agents_from_input_map(run_input_map, deploy=True) + + # Get LLM instances and customize their temperatures + supervisor_llm = ModelFactory.get("671be4886eb56397e51f7541") # Anthropic Claude 3.5 Sonnet v1 + mentalist_llm = ModelFactory.get("671be4886eb56397e51f7541") # Anthropic Claude 3.5 Sonnet v1 + inspector_llm = ModelFactory.get("671be4886eb56397e51f7541") # Anthropic Claude 3.5 Sonnet v1 + # Set custom temperatures + supervisor_llm.temperature = 0.1 + mentalist_llm.temperature = 0.3 + inspector_llm.temperature = 0.5 + + # Create a team agent with custom LLMs + team_agent = TeamAgentFactory.create( + name="LLM Parameter Test Team Agent", + agents=agents, + supervisor_llm=supervisor_llm, + mentalist_llm=mentalist_llm, + inspector_llm=inspector_llm, + llm_id="671be4886eb56397e51f7541", # Still required even with custom LLMs + description="A team agent for testing LLM parameter preservation", + use_mentalist=True, + use_inspector=True, + ) + + # Verify that temperature settings were preserved + assert team_agent.supervisor_llm.temperature == 0.1 + assert team_agent.mentalist_llm.temperature == 0.3 + assert team_agent.inspector_llm.temperature == 0.5 + + # Verify that the team agent's LLMs are the same instances as the originals + assert 
id(team_agent.supervisor_llm) == id(supervisor_llm) + assert id(team_agent.mentalist_llm) == id(mentalist_llm) + assert id(team_agent.inspector_llm) == id(inspector_llm) + + # Clean up + team_agent.delete() def test_run_team_agent_with_expected_output(): from pydantic import BaseModel From cdaaa81ab6d8f6ce51b12b9798473565eb541fea Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Tue, 27 May 2025 09:51:29 -0300 Subject: [PATCH 37/62] ENG-2105: persist sql data (#528) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Check for HTTP or s3 links on FileFactory * Return download link for temporary uploaded files * Persist database for SQL tool * Fixing some unit tests * NotImplementedError on deployment * Assertion on file factory and fix on agent functional test * forgotten line removed --------- Co-authored-by: ahmetgunduz Co-authored-by: Ahmet Gündüz --- aixplain/factories/file_factory.py | 27 ++++++++++++---- aixplain/modules/agent/tool/__init__.py | 3 ++ .../agent/tool/custom_python_code_tool.py | 5 +-- aixplain/modules/agent/tool/model_tool.py | 3 ++ aixplain/modules/agent/tool/pipeline_tool.py | 2 ++ .../agent/tool/python_interpreter_tool.py | 3 ++ aixplain/modules/agent/tool/sql_tool.py | 23 ++++++++++++-- aixplain/modules/mixins.py | 2 ++ .../data_onboarding/process_media_files.py | 12 ++++--- .../data_onboarding/process_text_files.py | 4 +-- aixplain/utils/file_utils.py | 31 +++++++++++++++---- .../functional/agent/agent_functional_test.py | 4 +-- .../functional/file_asset/file_create_test.py | 23 +++++++------- tests/functional/model/run_model_test.py | 24 ++++++++++++++ tests/unit/agent/agent_factory_utils_test.py | 8 ++--- tests/unit/agent/sql_tool_test.py | 13 ++++++-- 16 files changed, 146 insertions(+), 41 deletions(-) diff --git a/aixplain/factories/file_factory.py b/aixplain/factories/file_factory.py index b4701653..e1acebf2 100644 --- 
a/aixplain/factories/file_factory.py +++ b/aixplain/factories/file_factory.py @@ -39,7 +39,12 @@ class FileFactory: @classmethod def upload( - cls, local_path: Text, tags: Optional[List[Text]] = None, license: Optional[License] = None, is_temp: bool = True + cls, + local_path: Text, + tags: Optional[List[Text]] = None, + license: Optional[License] = None, + is_temp: bool = True, + return_download_link: bool = False, ) -> Text: """ Uploads a file to an S3 bucket. @@ -49,7 +54,7 @@ def upload( tags (List[Text], optional): tags of the file license (License, optional): the license for the file is_temp (bool): specify if the file that will be upload is a temporary file - + return_download_link (bool): specify if the function should return the download link of the file or the S3 path Returns: Text: The S3 path where the file was uploaded. @@ -57,6 +62,10 @@ def upload( FileNotFoundError: If the local file is not found. Exception: If the file size exceeds the maximum allowed size. """ + if is_temp is False: + assert ( + return_download_link is False + ), "File Upload Error: It is not allowed to return the download link for non-temporary files." 
if os.path.exists(local_path) is False: raise FileNotFoundError(f'File Upload Error: local file "{local_path}" not found.') # mime type format: {type}/{extension} @@ -86,9 +95,16 @@ def upload( ) if is_temp is False: - s3_path = upload_data(file_name=local_path, tags=tags, license=license, is_temp=is_temp, content_type=content_type) + s3_path = upload_data( + file_name=local_path, + tags=tags, + license=license, + is_temp=is_temp, + content_type=content_type, + return_download_link=return_download_link, + ) else: - s3_path = upload_data(file_name=local_path) + s3_path = upload_data(file_name=local_path, return_download_link=return_download_link) return s3_path @classmethod @@ -145,7 +161,6 @@ def create( tags (List[Text], optional): tags of the file license (License, optional): the license for the file is_temp (bool): specify if the file that will be upload is a temporary file - Returns: Text: The S3 path where the file was uploaded. @@ -156,4 +171,4 @@ def create( assert ( license is not None if is_temp is False else True ), "File Asset Creation Error: To upload a non-temporary file, you need to specify the `license`." 
- return cls.upload(local_path=local_path, tags=tags, license=license, is_temp=is_temp) + return cls.upload(local_path=local_path, tags=tags, license=license, is_temp=is_temp, return_download_link=is_temp) diff --git a/aixplain/modules/agent/tool/__init__.py b/aixplain/modules/agent/tool/__init__.py index 93dc269d..4b0ed0fc 100644 --- a/aixplain/modules/agent/tool/__init__.py +++ b/aixplain/modules/agent/tool/__init__.py @@ -65,3 +65,6 @@ def to_dict(self): def validate(self): raise NotImplementedError + + def deploy(self) -> None: + raise NotImplementedError diff --git a/aixplain/modules/agent/tool/custom_python_code_tool.py b/aixplain/modules/agent/tool/custom_python_code_tool.py index 8ec3ab9e..6433a408 100644 --- a/aixplain/modules/agent/tool/custom_python_code_tool.py +++ b/aixplain/modules/agent/tool/custom_python_code_tool.py @@ -75,7 +75,8 @@ def validate(self): AssetStatus.ONBOARDED, ], "Custom Python Code Tool Error: Status must be DRAFT or ONBOARDED" - - def __repr__(self) -> Text: return f"CustomPythonCodeTool(name={self.name})" + + def deploy(self): + pass diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 56310100..2f662ff6 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -258,3 +258,6 @@ def __repr__(self) -> Text: supplier_str = self.supplier.value if self.supplier is not None else None model_str = self.model.id if self.model is not None else None return f"ModelTool(name={self.name}, function={self.function}, supplier={supplier_str}, model={model_str})" + + def deploy(self): + pass diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index 728256d2..d6bd913e 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -85,3 +85,5 @@ def validate(self): self.name = pipeline_obj.name self.status = pipeline_obj.status + def deploy(self): + 
pass diff --git a/aixplain/modules/agent/tool/python_interpreter_tool.py b/aixplain/modules/agent/tool/python_interpreter_tool.py index 7bbb6e25..69212bf4 100644 --- a/aixplain/modules/agent/tool/python_interpreter_tool.py +++ b/aixplain/modules/agent/tool/python_interpreter_tool.py @@ -49,3 +49,6 @@ def validate(self): def __repr__(self) -> Text: return "PythonInterpreterTool()" + + def deploy(self): + pass diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index 5b3f4178..4cc36c68 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -285,7 +285,7 @@ def __init__( self.tables = tables if isinstance(tables, list) else [tables] if tables else None self.enable_commit = enable_commit self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool - self.validate() # to upload the database + self.validate() def to_dict(self) -> Dict[str, Text]: return { @@ -335,6 +335,25 @@ def validate(self): # Upload database try: - self.database = FileFactory.upload(local_path=self.database, is_temp=True) + self.database = FileFactory.create(local_path=self.database, is_temp=True) except Exception as e: raise SQLToolError(f"Failed to upload database: {str(e)}") + + def deploy(self) -> None: + import uuid + import requests + from pathlib import Path + from aixplain.factories.file_factory import FileFactory + from aixplain.enums import License + + # Generate unique filename with uuid4 + local_path = str(Path(f"{uuid.uuid4()}.db")) + + # Download database file + if str(self.database).startswith(("http://", "https://")): + response = requests.get(self.database) + response.raise_for_status() + with open(local_path, "wb") as f: + f.write(response.content) + self.database = FileFactory.create(local_path=local_path, is_temp=False, license=License.MIT) + os.remove(local_path) diff --git a/aixplain/modules/mixins.py b/aixplain/modules/mixins.py index 
0b402cf2..27c023c7 100644 --- a/aixplain/modules/mixins.py +++ b/aixplain/modules/mixins.py @@ -67,6 +67,8 @@ def deploy(self) -> None: self._validate_deployment_readiness() previous_status = self.status try: + if hasattr(self, "tools"): + [tool.deploy() for tool in self.tools] self.status = AssetStatus.ONBOARDED self.update() except Exception as e: diff --git a/aixplain/processes/data_onboarding/process_media_files.py b/aixplain/processes/data_onboarding/process_media_files.py index 1e007d85..f0200165 100644 --- a/aixplain/processes/data_onboarding/process_media_files.py +++ b/aixplain/processes/data_onboarding/process_media_files.py @@ -166,7 +166,9 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> # compress the folder compressed_folder = compress_folder(data_file_name) # upload zipped medias into s3 - s3_compressed_folder = upload_data(compressed_folder, content_type="application/x-tar", return_s3_link=True) + s3_compressed_folder = upload_data( + compressed_folder, content_type="application/x-tar", return_download_link=False + ) # update index files pointing the s3 link df["@SOURCE"] = s3_compressed_folder # remove media folder @@ -199,7 +201,9 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> end_column_idx = df.columns.to_list().index(end_column) df.to_csv(index_file_name, compression="gzip", index=False) - s3_link = upload_data(index_file_name, content_type="text/csv", content_encoding="gzip", return_s3_link=True) + s3_link = upload_data( + index_file_name, content_type="text/csv", content_encoding="gzip", return_download_link=False + ) files.append(File(path=s3_link, extension=FileType.CSV, compression="gzip")) # get data column index data_column_idx = df.columns.to_list().index(metadata.name) @@ -224,7 +228,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> # compress the folder compressed_folder = compress_folder(data_file_name) # upload zipped medias 
into s3 - s3_compressed_folder = upload_data(compressed_folder, content_type="application/x-tar", return_s3_link=True) + s3_compressed_folder = upload_data(compressed_folder, content_type="application/x-tar", return_download_link=False) # update index files pointing the s3 link df["@SOURCE"] = s3_compressed_folder # remove media folder @@ -257,7 +261,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 100) -> end_column_idx = df.columns.to_list().index(end_column) df.to_csv(index_file_name, compression="gzip", index=False) - s3_link = upload_data(index_file_name, content_type="text/csv", content_encoding="gzip", return_s3_link=True) + s3_link = upload_data(index_file_name, content_type="text/csv", content_encoding="gzip", return_download_link=False) files.append(File(path=s3_link, extension=FileType.CSV, compression="gzip")) # get data column index data_column_idx = df.columns.to_list().index(metadata.name) diff --git a/aixplain/processes/data_onboarding/process_text_files.py b/aixplain/processes/data_onboarding/process_text_files.py index 84df057f..e219b835 100644 --- a/aixplain/processes/data_onboarding/process_text_files.py +++ b/aixplain/processes/data_onboarding/process_text_files.py @@ -100,7 +100,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 1000) - start, end = idx - len(batch), idx df["@INDEX"] = range(start, end) df.to_csv(file_name, compression="gzip", index=False) - s3_link = upload_data(file_name, content_type="text/csv", content_encoding="gzip", return_s3_link=True) + s3_link = upload_data(file_name, content_type="text/csv", content_encoding="gzip", return_download_link=False) files.append(File(path=s3_link, extension=FileType.CSV, compression="gzip")) # get data column index data_column_idx = df.columns.to_list().index(metadata.name) @@ -114,7 +114,7 @@ def run(metadata: MetaData, paths: List, folder: Path, batch_size: int = 1000) - start, end = idx - len(batch), idx df["@INDEX"] = 
range(start, end) df.to_csv(file_name, compression="gzip", index=False) - s3_link = upload_data(file_name, content_type="text/csv", content_encoding="gzip", return_s3_link=True) + s3_link = upload_data(file_name, content_type="text/csv", content_encoding="gzip", return_download_link=False) files.append(File(path=s3_link, extension=FileType.CSV, compression="gzip")) # get data column index data_column_idx = df.columns.to_list().index(metadata.name) diff --git a/aixplain/utils/file_utils.py b/aixplain/utils/file_utils.py index f2bb55bc..3eddb95d 100644 --- a/aixplain/utils/file_utils.py +++ b/aixplain/utils/file_utils.py @@ -73,7 +73,7 @@ def upload_data( content_type: Text = "text/csv", content_encoding: Optional[Text] = None, nattempts: int = 2, - return_s3_link: bool = True, + return_download_link: bool = False, ): """Upload files to S3 with pre-signed URLs @@ -85,7 +85,7 @@ def upload_data( content_type (Text, optional): Type of content. Defaults to "text/csv". content_encoding (Text, optional): Content encoding. Defaults to None. nattempts (int, optional): Number of attempts for diminish the risk of exceptions. Defaults to 2. - return_s3_link (bool, optional): If True, the function will return the s3 link instead of the presigned url. Defaults to False. + return_download_link (bool, optional): If True, the function will return the download link instead of the presigned url. Defaults to False. 
Reference: https://python.plainenglish.io/upload-files-to-aws-s3-using-pre-signed-urls-in-python-d3c2fcab1b41 @@ -113,6 +113,7 @@ def upload_data( path = response["key"] # Upload data presigned_url = response["uploadUrl"] # pre-signed URL + download_link = response.get("downloadUrl", "") headers = {"Content-Type": content_type} if content_encoding is not None: headers["Content-Encoding"] = content_encoding @@ -123,17 +124,35 @@ def upload_data( # if the process fail, try one more if r.status_code != 200: if nattempts > 0: - return upload_data(file_name, content_type, content_encoding, nattempts - 1) + return upload_data( + file_name=file_name, + content_type=content_type, + tags=tags, + license=license, + is_temp=is_temp, + content_encoding=content_encoding, + nattempts=nattempts - 1, + return_download_link=return_download_link, + ) else: raise Exception("File Uploading Error: Failure on Uploading to S3.") - if return_s3_link: + if return_download_link is False: bucket_name = re.findall(r"https://(.*?).s3.amazonaws.com", presigned_url)[0] s3_link = f"s3://{bucket_name}/{path}" return s3_link - return presigned_url + return download_link except Exception: if nattempts > 0: - return upload_data(file_name, content_type, content_encoding, nattempts - 1) + return upload_data( + file_name=file_name, + content_type=content_type, + tags=tags, + license=license, + is_temp=is_temp, + content_encoding=content_encoding, + nattempts=nattempts - 1, + return_download_link=return_download_link, + ) else: raise Exception("File Uploading Error: Failure on Uploading to S3.") diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 2a45c2bf..94bbdd61 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -369,7 +369,7 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): f.write("") tool = AgentFactory.create_sql_tool( - name="Teste", + name="TestDB", 
description="Execute an SQL query and return the result", source="ftest.db", source_type="sqlite", @@ -444,7 +444,7 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): # Verify tool setup assert tool is not None assert tool.description == "Execute SQL queries on employee data" - assert tool.database.endswith(".db") + assert tool.database.split("?")[0].endswith(".db") assert tool.tables == ["employees"] assert ( tool.schema diff --git a/tests/functional/file_asset/file_create_test.py b/tests/functional/file_asset/file_create_test.py index 29678cc7..0eb06ad8 100644 --- a/tests/functional/file_asset/file_create_test.py +++ b/tests/functional/file_asset/file_create_test.py @@ -22,15 +22,16 @@ from aixplain import aixplain_v2 as v2 -@pytest.mark.parametrize("FileFactory", [FileFactory, v2.File]) -def test_file_create(FileFactory): +@pytest.mark.parametrize( + "FileFactory, is_temp, expected_link", + [ + (FileFactory, True, "http"), + (v2.File, True, "http"), + (FileFactory, False, "s3"), + (v2.File, False, "s3"), + ], +) +def test_file_create(FileFactory, is_temp, expected_link): upload_file = "tests/functional/file_asset/input/test.csv" - s3_link = FileFactory.create(local_path=upload_file, tags=["test1", "test2"], license=License.MIT, is_temp=False) - assert s3_link.startswith("s3") - - -@pytest.mark.parametrize("FileFactory", [FileFactory, v2.File]) -def test_file_create_temp(FileFactory): - upload_file = "tests/functional/file_asset/input/test.csv" - s3_link = FileFactory.create(local_path=upload_file, tags=["test1", "test2"], license=License.MIT, is_temp=True) - assert s3_link.startswith("s3") + s3_link = FileFactory.create(local_path=upload_file, tags=["test1", "test2"], license=License.MIT, is_temp=is_temp) + assert s3_link.startswith(expected_link) diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 14495b53..41526945 100644 --- a/tests/functional/model/run_model_test.py +++ 
b/tests/functional/model/run_model_test.py @@ -16,6 +16,7 @@ + def pytest_generate_tests(metafunc): if "llm_model" in metafunc.fixturenames: four_weeks_ago = datetime.now(timezone.utc) - timedelta(weeks=4) @@ -202,6 +203,29 @@ def test_llm_run_with_file(): assert "🤖" in response["data"], "Robot emoji should be present in the response" +def test_aixplain_model_cache_creation(): + """Ensure AssetCache is triggered and cache is created.""" + + cache_file = os.path.join(CACHE_FOLDER, "models.json") + + # Clean up cache before the test + if os.path.exists(cache_file): + os.remove(cache_file) + + # Instantiate the Model (replace this with a real model ID from your env) + model_id = "6239efa4822d7a13b8e20454" # Translate from Punjabi to Portuguese (Brazil) + _ = Model(id=model_id) + + # Assert the cache file was created + assert os.path.exists(cache_file), "Expected cache file was not created." + + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + assert "data" in cache_data, "Cache file structure invalid - missing 'data' key." + assert any(m.get("id") == model_id for m in cache_data["data"]["items"]), "Instantiated model not found in cache." 
+ + def test_index_model_air_with_image(): from aixplain.factories import IndexFactory from aixplain.modules.model.record import Record diff --git a/tests/unit/agent/agent_factory_utils_test.py b/tests/unit/agent/agent_factory_utils_test.py index e48c7a45..59486242 100644 --- a/tests/unit/agent/agent_factory_utils_test.py +++ b/tests/unit/agent/agent_factory_utils_test.py @@ -144,7 +144,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): "name": "Test SQL", "description": "Test SQL", "parameters": [ - {"name": "database", "value": "test_db.db"}, + {"name": "database", "value": "s3://test_db.db"}, {"name": "schema", "value": "public"}, {"name": "tables", "value": "table1,table2"}, {"name": "enable_commit", "value": True}, @@ -154,7 +154,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): { "name": "Test SQL", "description": "Test SQL", - "database": "test_db.db", + "database": "s3://test_db.db", "schema": "public", "tables": ["table1", "table2"], "enable_commit": True, @@ -167,7 +167,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): "name": "Test SQL", "description": "Test SQL with string enable_commit", "parameters": [ - {"name": "database", "value": "test_db.db"}, + {"name": "database", "value": "s3://test_db.db"}, {"name": "schema", "value": "public"}, {"name": "tables", "value": "table1"}, {"name": "enable_commit", "value": True}, @@ -177,7 +177,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): { "name": "Test SQL", "description": "Test SQL with string enable_commit", - "database": "test_db.db", + "database": "s3://test_db.db", "schema": "public", "tables": ["table1"], "enable_commit": True, diff --git a/tests/unit/agent/sql_tool_test.py b/tests/unit/agent/sql_tool_test.py index 9f03a279..b9128bb3 100644 --- a/tests/unit/agent/sql_tool_test.py +++ b/tests/unit/agent/sql_tool_test.py @@ -1,7 +1,7 @@ import os import pytest import pandas as pd -from aixplain.factories import AgentFactory +from 
aixplain.factories import AgentFactory, FileFactory from aixplain.enums import DatabaseSourceType from aixplain.modules.agent.tool.sql_tool import ( @@ -53,6 +53,8 @@ def test_create_sql_tool(mocker, tmp_path): conn.execute("CREATE TABLE test (id INTEGER, name TEXT)") conn.close() + mocker.patch.object(FileFactory, "upload", return_value="s3://test.db") + # Test SQLite source type tool = AgentFactory.create_sql_tool( name="Test SQL", description="Test", source=db_path, source_type="sqlite", schema="test", tables=["test", "test2"] @@ -215,6 +217,8 @@ def test_create_sql_tool_with_schema_inference(tmp_path, mocker): conn.execute("CREATE TABLE test (id INTEGER, name TEXT)") conn.close() + mocker.patch.object(FileFactory, "upload", return_value=db_path) + # Create tool without schema and tables tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=db_path, source_type="sqlite") @@ -274,12 +278,17 @@ def test_create_sql_tool_from_csv_with_warnings(tmp_path, mocker): os.remove(tool.database) -def test_create_sql_tool_from_csv(tmp_path): +def test_create_sql_tool_from_csv(tmp_path, mocker): # Create a temporary CSV file csv_path = os.path.join(tmp_path, "test.csv") df = pd.DataFrame({"id": [1, 2, 3], "name": ["test1", "test2", "test3"], "value": [1.1, 2.2, 3.3]}) df.to_csv(csv_path, index=False) + with open("test.db", "w") as f: + f.write("") + + mocker.patch.object(FileFactory, "upload", return_value="s3://test.db") + # Test successful creation tool = AgentFactory.create_sql_tool( name="Test SQL", description="Test", source=csv_path, source_type="csv", tables=["test"] From 77ff8d849f279e04d95b938ce4f6c6c592b5abff Mon Sep 17 00:00:00 2001 From: Yunsu Kim Date: Mon, 2 Jun 2025 16:15:21 +0200 Subject: [PATCH 38/62] Custom inspector interface (#484) * Copy from utility model module * First cursor revision for GuardrailModel * Add GuardrailPolicy * Add guard ID and config * Add GuardrailFactory * Add logic for creation with guard_instruction * Rename 
and move factory file * Add ModelWithParams * Add InspectorAgent * Remove guardrail model * Add name in Inspector * Draft for InspectorFactory * Remove validator for policy and add it for name * Support only Guardrail model * TODO comment * Move inspector.py to team_agent dir * Add inspectors to TeamAgent (remove use_inspector and max_inspectors) * Add inspectors to TeamAgentFactory.create() (remove use_inspector and num_inspectors) * Update inspector functional tests (deactivated for now) * Add InspectorAuto * Add create_auto * Add unit tests * Inspector can be created also from Model * Support also LLMs * Ensure Inspector classes are instantiated for compatibility with backend return format * Enable test_team_agent_with_steps_inspector * Enable test_team_agent_with_output_inspector * Enable test_team_agent_with_multiple_inspectors * Refactor team agent functional tests * Remove inspector_llm and add use_inspector in TeamAgent * Add critiques in AgentResponse * Remove inspector_llm in team agent functional tests * Add more inspector functional tests * Revert warn policy test --- .../factories/team_agent_factory/__init__.py | 70 ++-- .../team_agent_factory/inspector_factory.py | 112 +++++ .../factories/team_agent_factory/utils.py | 14 +- aixplain/modules/agent/agent_response_data.py | 7 +- aixplain/modules/agent/model_with_params.py | 51 +++ aixplain/modules/team_agent/__init__.py | 49 +-- aixplain/modules/team_agent/inspector.py | 76 ++++ .../team_agent/inspector_functional_test.py | 370 +++++++++++++++++ .../team_agent/team_agent_functional_test.py | 392 ++---------------- tests/functional/team_agent/test_utils.py | 78 ++++ tests/unit/team_agent/inspector_test.py | 144 +++++++ .../unit/{ => team_agent}/team_agent_test.py | 207 ++------- 12 files changed, 946 insertions(+), 624 deletions(-) create mode 100644 aixplain/factories/team_agent_factory/inspector_factory.py create mode 100644 aixplain/modules/agent/model_with_params.py create mode 100644 
aixplain/modules/team_agent/inspector.py create mode 100644 tests/functional/team_agent/inspector_functional_test.py create mode 100644 tests/functional/team_agent/test_utils.py create mode 100644 tests/unit/team_agent/inspector_test.py rename tests/unit/{ => team_agent}/team_agent_test.py (67%) diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index cfbbe622..77f9fad4 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -29,6 +29,7 @@ from aixplain.enums.supplier import Supplier from aixplain.modules.agent import Agent from aixplain.modules.team_agent import TeamAgent, InspectorTarget +from aixplain.modules.team_agent.inspector import Inspector from aixplain.utils import config from aixplain.factories.team_agent_factory.utils import build_team_agent from aixplain.utils.request_utils import _request_with_retry @@ -46,17 +47,15 @@ def create( llm: Optional[Union[LLM, Text]] = None, supervisor_llm: Optional[Union[LLM, Text]] = None, mentalist_llm: Optional[Union[LLM, Text]] = None, - inspector_llm: Optional[Union[LLM, Text]] = None, description: Text = "", api_key: Text = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, use_mentalist: bool = True, - use_inspector: bool = True, - num_inspectors: int = 1, + inspectors: List[Inspector] = [], inspector_targets: List[Union[InspectorTarget, Text]] = [InspectorTarget.STEPS], - use_mentalist_and_inspector: bool = False, # TODO: remove this instructions: Optional[Text] = None, + **kwargs, ) -> TeamAgent: """Create a new team agent in the platform. @@ -67,14 +66,12 @@ def create( llm (Optional[Union[LLM, Text]], optional): The LLM to be used for the team agent. supervisor_llm (Optional[Union[LLM, Text]], optional): Main supervisor LLM. Defaults to None. mentalist_llm (Optional[Union[LLM, Text]], optional): LLM for planning. 
Defaults to None. - inspector_llm (Optional[Union[LLM, Text]], optional): LLM for inspection. Defaults to None. description: The description of the team agent to be displayed in the aiXplain platform. api_key: The API key to be used for the team agent. supplier: The supplier of the team agent. version: The version of the team agent. use_mentalist: Whether to use the mentalist agent. - use_inspector: Whether to use the inspector agent. - num_inspectors: The number of inspectors to be used for each inspection. + inspectors: A list of inspectors to be added to the team. inspector_targets: Which stages to be inspected during an execution of the team agent. (steps, output) use_mentalist_and_inspector: Whether to use the mentalist and inspector agents. (legacy) instructions: The instructions to guide the team agent (i.e. appended in the prompt of the team agent). @@ -82,6 +79,16 @@ def create( Returns: A new team agent instance. """ + # legacy params + if "use_mentalist_and_inspector" in kwargs: + logging.warning( + "TeamAgent Onboarding Warning: use_mentalist_and_inspector is no longer supported. Use use_mentalist and inspectors instead." + ) + if "use_inspector" in kwargs: + logging.warning("TeamAgent Onboarding Warning: use_inspector is no longer supported. Use inspectors instead.") + if "num_inspectors" in kwargs: + logging.warning("TeamAgent Onboarding Warning: num_inspectors is no longer supported. Use inspectors instead.") + assert len(agents) > 0, "TeamAgent Onboarding Error: At least one agent must be provided." 
agent_list = [] for agent in agents: @@ -100,29 +107,21 @@ def create( assert isinstance(agent, Agent), "TeamAgent Onboarding Error: Agents must be instances of Agent class" agent_list.append(agent_obj) - # NOTE: backend expects max_inspectors (for "generated" inspectors) - max_inspectors = num_inspectors - - if use_inspector: + if inspectors: try: # convert to enum if string and check its validity inspector_targets = [InspectorTarget(target) for target in inspector_targets] except ValueError: - raise ValueError("TeamAgent Onboarding Error: Invalid inspector target. Valid targets are: steps, output") + raise ValueError( + f"TeamAgent Onboarding Error: Invalid inspector target. Valid targets are: {list(InspectorTarget)}" + ) if not use_mentalist: raise Exception("TeamAgent Onboarding Error: To use the Inspector agent, you must enable Mentalist.") - if max_inspectors < 1: - raise Exception( - "TeamAgent Onboarding Error: The number of inspectors must be greater than 0 when using the Inspector agent." 
- ) - - if use_mentalist_and_inspector: - mentalist_llm_id = llm_id - inspector_llm_id = llm_id else: - mentalist_llm_id = llm_id if use_mentalist else None - inspector_llm_id = llm_id if use_inspector else None + inspector_targets = [] + + mentalist_llm_id = llm_id if use_mentalist else None # Set up LLMs if llm is None: @@ -134,9 +133,6 @@ def create( if use_mentalist and mentalist_llm is None: mentalist_llm = get_llm_instance(mentalist_llm_id or "669a63646eb56306647e1091", api_key=api_key) - if use_inspector and inspector_llm is None: - inspector_llm = get_llm_instance(inspector_llm_id or "669a63646eb56306647e1091", api_key=api_key) - team_agent = None url = urljoin(config.BACKEND_URL, "sdk/agent-communities") headers = {"x-api-key": api_key} @@ -158,9 +154,8 @@ def create( "llmId": llm.id if llm else llm_id, "supervisorId": supervisor_llm.id if supervisor_llm else llm_id, "plannerId": mentalist_llm.id if mentalist_llm else mentalist_llm_id, - "inspectorId": inspector_llm.id if inspector_llm else inspector_llm_id, - "maxInspectors": max_inspectors, - "inspectorTargets": inspector_targets if use_inspector else [], + "inspectors": inspectors, + "inspectorTargets": inspector_targets, "supplier": supplier, "version": version, "status": "draft", @@ -203,18 +198,6 @@ def create( } ) - if inspector_llm is not None: - inspector_llm = ( - get_llm_instance(inspector_llm, api_key=api_key) if isinstance(inspector_llm, str) else inspector_llm - ) - payload["tools"].append( - { - "type": "llm", - "description": "inspector", - "parameters": inspector_llm.get_parameters().to_list() if inspector_llm.get_parameters() else None, - } - ) - # Store the LLM objects directly in the payload for build_team_agent internal_payload = payload.copy() if llm is not None: @@ -223,13 +206,14 @@ def create( internal_payload["supervisor_llm"] = supervisor_llm if mentalist_llm is not None: internal_payload["mentalist_llm"] = mentalist_llm - if inspector_llm is not None: - 
internal_payload["inspector_llm"] = inspector_llm team_agent = build_team_agent(payload=internal_payload, agents=agent_list, api_key=api_key) team_agent.validate(raise_exception=True) response = "Unspecified error" try: + payload["inspectors"] = [ + inspector.model_dump(by_alias=True) for inspector in inspectors + ] # convert Inspector object to dict logging.debug(f"Start service for POST Create TeamAgent - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("post", url, headers=headers, json=payload) response = r.json() @@ -244,8 +228,6 @@ def create( response["supervisor_llm"] = internal_payload["supervisor_llm"] if "mentalist_llm" in internal_payload: response["mentalist_llm"] = internal_payload["mentalist_llm"] - if "inspector_llm" in internal_payload: - response["inspector_llm"] = internal_payload["inspector_llm"] team_agent = build_team_agent(payload=response, agents=agent_list, api_key=api_key) else: diff --git a/aixplain/factories/team_agent_factory/inspector_factory.py b/aixplain/factories/team_agent_factory/inspector_factory.py new file mode 100644 index 00000000..0d68d1cd --- /dev/null +++ b/aixplain/factories/team_agent_factory/inspector_factory.py @@ -0,0 +1,112 @@ +"""Factory for inspectors. 
+ +Example usage: + +inspector = InspectorFactory.create_from_model( + name="my_inspector", + model_id="my_model", + model_config={"prompt": "Check if the data is safe to use."}, + policy=InspectorPolicy.ADAPTIVE, +) +""" + +import logging +from typing import Dict, Optional, Text, Union +from urllib.parse import urljoin + +from aixplain.enums.asset_status import AssetStatus +from aixplain.enums.function import Function +from aixplain.factories.model_factory.utils import create_model_from_response +from aixplain.modules.model import Model +from aixplain.modules.team_agent.inspector import Inspector, InspectorPolicy, InspectorAuto +from aixplain.utils import config +from aixplain.utils.file_utils import _request_with_retry + + +INSPECTOR_SUPPORTED_FUNCTIONS = [Function.GUARDRAILS, Function.TEXT_GENERATION] + + +class InspectorFactory: + """A class for creating an Inspector instance.""" + + @classmethod + def create_from_model( + cls, + name: Text, + model: Union[Text, Model], + model_config: Optional[Dict] = None, + policy: InspectorPolicy = InspectorPolicy.ADAPTIVE, # default: doing something dynamically + ) -> Inspector: + """Create a new inspector agent from an onboarded model. + + Args: + name: Name of the inspector agent. + model: Model or model ID to use for inspector. + model_config: Configuration for the inspector. Defaults to None. + policy: Action to take upon negative feedback (WARN/ABORT/ADAPTIVE). Defaults to ADAPTIVE. 
+ + Returns: + Inspector: The created inspector + """ + # fetch model if model ID is provided + if isinstance(model, Text): + model_id = model + try: + url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") + + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + logging.info(f"Start service for GET Model - {url} - {headers}") + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + except Exception: + raise ValueError(f"Inspector: Failed to get model with ID {model_id}") + + if 200 <= r.status_code < 300: + model = create_model_from_response(resp) + else: + error_message = ( + f"Inspector: Failed to get model with ID {model_id} (status code = {r.status_code})\nError: {resp}" + ) + logging.error(error_message) + raise Exception(error_message) + else: + model_id = model.id + + # check if the model is onboarded + if model.status != AssetStatus.ONBOARDED: + raise ValueError(f"Inspector: Model with ID {model_id} is not onboarded") + + # TODO: relax this constraint + if model.function not in INSPECTOR_SUPPORTED_FUNCTIONS: + raise ValueError( + f"Inspector: Only {', '.join([f.value for f in INSPECTOR_SUPPORTED_FUNCTIONS])} models are supported at the moment. Model with ID {model_id} is a {model.function} model" + ) + + return Inspector( + name=name, + model_id=model_id, + model_params=model_config, + policy=policy, + ) + + @classmethod + def create_auto( + cls, + auto: InspectorAuto, + name: Optional[Text] = None, + policy: InspectorPolicy = InspectorPolicy.ADAPTIVE, + ) -> Inspector: + """Create a new inspector agent from an automatically configured inspector. + + Args: + auto: The automatically configured inspector. + policy: Action to take upon negative feedback (WARN/ABORT/ADAPTIVE). Defaults to ADAPTIVE. + + Returns: + Inspector: The created inspector. 
+ """ + return Inspector( + name=name or auto.get_name(), + auto=auto, + policy=policy, + ) diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index ce382531..16b65641 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -8,6 +8,7 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.modules.agent import Agent from aixplain.modules.team_agent import TeamAgent, InspectorTarget +from aixplain.modules.team_agent.inspector import Inspector from aixplain.factories.agent_factory import AgentFactory from aixplain.factories.model_factory import ModelFactory from aixplain.modules.model.model_parameters import ModelParameters @@ -32,20 +33,21 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = ) continue + # Ensure custom classes are instantiated: for compatibility with backend return format + inspectors = [ + inspector if isinstance(inspector, Inspector) else Inspector(**inspector) for inspector in payload.get("inspectors", []) + ] inspector_targets = [InspectorTarget(target.lower()) for target in payload.get("inspectorTargets", [])] # Get LLMs from tools if present supervisor_llm = None mentalist_llm = None - inspector_llm = None # First check if we have direct LLM objects in the payload if "supervisor_llm" in payload: supervisor_llm = payload["supervisor_llm"] if "mentalist_llm" in payload: mentalist_llm = payload["mentalist_llm"] - if "inspector_llm" in payload: - inspector_llm = payload["inspector_llm"] # Otherwise create from the parameters elif "tools" in payload: for tool in payload["tools"]: @@ -74,8 +76,6 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = supervisor_llm = llm elif tool["description"] == "mentalist": mentalist_llm = llm - elif tool["description"] == "inspector": - inspector_llm = llm team_agent = TeamAgent( id=payload.get("id", ""), @@ -89,10 +89,8 @@ 
def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = llm_id=payload.get("llmId", GPT_4o_ID), supervisor_llm=supervisor_llm, mentalist_llm=mentalist_llm, - inspector_llm=inspector_llm, use_mentalist=True if payload.get("plannerId", None) is not None else False, - use_inspector=True if payload.get("inspectorId", None) is not None else False, - max_inspectors=payload.get("maxInspectors", 1), + inspectors=inspectors, inspector_targets=inspector_targets, api_key=api_key, status=AssetStatus(payload["status"]), diff --git a/aixplain/modules/agent/agent_response_data.py b/aixplain/modules/agent/agent_response_data.py index 0acd8b80..7a93b3aa 100644 --- a/aixplain/modules/agent/agent_response_data.py +++ b/aixplain/modules/agent/agent_response_data.py @@ -9,12 +9,14 @@ def __init__( session_id: str = "", intermediate_steps: Optional[List[Any]] = None, execution_stats: Optional[Dict[str, Any]] = None, + critiques: Optional[str] = None, ): self.input = input self.output = output self.session_id = session_id self.intermediate_steps = intermediate_steps or [] self.execution_stats = execution_stats + self.critiques = critiques or "" @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AgentResponseData": @@ -24,6 +26,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AgentResponseData": session_id=data.get("session_id", ""), intermediate_steps=data.get("intermediate_steps", []), execution_stats=data.get("executionStats"), + critiques=data.get("critiques", ""), ) def to_dict(self) -> Dict[str, Any]: @@ -34,6 +37,7 @@ def to_dict(self) -> Dict[str, Any]: "intermediate_steps": self.intermediate_steps, "executionStats": self.execution_stats, "execution_stats": self.execution_stats, + "critiques": self.critiques, } def __getitem__(self, key): @@ -52,7 +56,8 @@ def __repr__(self) -> str: f"output={self.output}, " f"session_id='{self.session_id}', " f"intermediate_steps={self.intermediate_steps}, " - f"execution_stats={self.execution_stats})" + 
f"execution_stats={self.execution_stats}, " + f"critiques='{self.critiques}')" ) def __contains__(self, key: Text) -> bool: diff --git a/aixplain/modules/agent/model_with_params.py b/aixplain/modules/agent/model_with_params.py new file mode 100644 index 00000000..431f8cd0 --- /dev/null +++ b/aixplain/modules/agent/model_with_params.py @@ -0,0 +1,51 @@ +"""A generic class that wraps a model with extra parameters. + +This is an abstract base class that must be extended by specific model wrappers. + +Example usage: + +class MyModel(ModelWithParams): + model_id: Text = "my_model" + extra_param: int = 10 + + @field_validator("extra_param") + def validate_extra_param(cls, v: int) -> int: + if v < 0: + raise ValueError("Extra parameter must be positive") + return v +""" + +from abc import ABC +from typing import Text + +from pydantic import BaseModel, ConfigDict, field_validator +from pydantic.alias_generators import to_camel + + +class ModelWithParams(BaseModel, ABC): + """A generic class that wraps a model with extra parameters. + + The extra parameters are not part of the model's input/output parameters. + This is an abstract base class that must be extended by specific model wrappers. + + Attributes: + model_id: The ID of the model to wrap. 
+ """ + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + model_id: Text + + @field_validator("model_id") + def validate_model_id(cls, v: Text) -> Text: + if not v or not v.strip(): + raise ValueError("Model ID is required") + return v + + def __new__(cls, *args, **kwargs): + if cls is ModelWithParams: + raise TypeError("ModelWithParams is an abstract base class and cannot be instantiated directly") + return super().__new__(cls) diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 0485b10b..510e82d1 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -40,6 +40,7 @@ from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.agent.utils import process_variables +from aixplain.modules.team_agent.inspector import Inspector from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.model.llm_model import LLM @@ -63,15 +64,16 @@ class TeamAgent(Model, DeployableMixin[Agent]): Attributes: id (Text): ID of the Team Agent name (Text): Name of the Team Agent - agents (List[Agent]): List of Agents that the Team Agent uses. + agents (List[Agent]): List of agents that the Team Agent uses. description (Text, optional): description of the Team Agent. Defaults to "". llm_id (Text, optional): large language model. Defaults to GPT-4o (6646261c6eb563165658bbb1). + api_key (str): The TEAM API key used for authentication. supplier (Text): Supplier of the Team Agent. version (Text): Version of the Team Agent. - backend_url (str): URL of the backend. - api_key (str): The TEAM API key used for authentication. cost (Dict, optional): model price. Defaults to None. - use_mentalist_and_inspector (bool): Use Mentalist and Inspector tools. Defaults to True. 
+ use_mentalist (bool): Use Mentalist agent for pre-planning. Defaults to True. + inspectors (List[Inspector]): List of inspectors that the team agent uses. + inspector_targets (List[InspectorTarget]): List of targets where the inspectors are applied. Defaults to [InspectorTarget.STEPS]. """ is_valid: bool @@ -86,51 +88,28 @@ def __init__( llm: Optional[LLM] = None, supervisor_llm: Optional[LLM] = None, mentalist_llm: Optional[LLM] = None, - inspector_llm: Optional[LLM] = None, api_key: Optional[Text] = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, cost: Optional[Dict] = None, use_mentalist: bool = True, - use_inspector: bool = True, - max_inspectors: int = 1, + inspectors: List[Inspector] = [], inspector_targets: List[InspectorTarget] = [InspectorTarget.STEPS], status: AssetStatus = AssetStatus.DRAFT, instructions: Optional[Text] = None, **additional_info, ) -> None: - """Create a FineTune with the necessary information. - - Args: - id (Text): ID of the Team Agent - name (Text): Name of the Team Agent - agents (List[Agent]): List of agents that the Team Agent uses. - description (Text, optional): The description of the team agent to be displayed in the aiXplain platform. Defaults to "". - llm_id (Text, optional): large language model. Defaults to GPT-4o (6646261c6eb563165658bbb1). - llm (LLM, optional): large language model object. Defaults to None. - supervisor_llm (LLM, optional): supervisor large language model object. Defaults to None. - mentalist_llm (LLM, optional): mentalist large language model object. Defaults to None. - inspector_llm (LLM, optional): inspector large language model object. Defaults to None. - supplier (Text): Supplier of the Team Agent. - version (Text): Version of the Team Agent. - backend_url (str): URL of the backend. - api_key (str): The TEAM API key used for authentication. - cost (Dict, optional): model price. Defaults to None. 
- use_mentalist_and_inspector (bool): Use Mentalist and Inspector tools. Defaults to True. - instructions (Text, optional): The instructions to guide the team agent (i.e. appended in the prompt of the team agent). Defaults to None. - """ super().__init__(id, name, description, api_key, supplier, version, cost=cost) self.additional_info = additional_info self.agents = agents self.llm_id = llm_id self.llm = llm self.use_mentalist = use_mentalist - self.use_inspector = use_inspector - self.max_inspectors = max_inspectors + self.inspectors = inspectors self.inspector_targets = inspector_targets + self.use_inspector = True if inspectors else False self.supervisor_llm = supervisor_llm self.mentalist_llm = mentalist_llm - self.inspector_llm = inspector_llm self.instructions = instructions if isinstance(status, str): try: @@ -208,6 +187,7 @@ def run( session_id=result_data.get("session_id"), intermediate_steps=result_data.get("intermediate_steps"), execution_stats=result_data.get("executionStats"), + critiques=result_data.get("critiques", ""), ), used_credits=result_data.get("usedCredits", 0.0), run_time=result_data.get("runTime", end - start), @@ -361,10 +341,6 @@ def to_dict(self) -> Dict: planner_id = self.mentalist_llm.id if self.mentalist_llm else self.llm_id else: planner_id = None - if self.use_inspector: - inspector_id = self.inspector_llm.id if self.inspector_llm else self.llm_id - else: - inspector_id = None return { "id": self.id, "name": self.name, @@ -376,8 +352,7 @@ def to_dict(self) -> Dict: "llmId": self.llm.id if self.llm else self.llm_id, "supervisorId": self.supervisor_llm.id if self.supervisor_llm else self.llm_id, "plannerId": planner_id, - "inspectorId": inspector_id, - "maxInspectors": self.max_inspectors, + "inspectors": [inspector.model_dump(by_alias=True) for inspector in self.inspectors], "inspectorTargets": [target.value for target in self.inspector_targets], "supplier": self.supplier.value["code"] if isinstance(self.supplier, Supplier) else 
self.supplier, "version": self.version, @@ -452,4 +427,4 @@ def update(self) -> None: raise Exception(error_msg) def __repr__(self): - return f"TeamAgent: {self.name} (id={self.id})" \ No newline at end of file + return f"TeamAgent: {self.name} (id={self.id})" diff --git a/aixplain/modules/team_agent/inspector.py b/aixplain/modules/team_agent/inspector.py new file mode 100644 index 00000000..ab4fb79f --- /dev/null +++ b/aixplain/modules/team_agent/inspector.py @@ -0,0 +1,76 @@ +"""Pre-defined agent for inspecting the data flow within a team agent. + +Example usage: + +inspector = Inspector( + name="my_inspector", + model_id="my_model", + model_params={"prompt": "Check if the data is safe to use."}, + policy=InspectorPolicy.ADAPTIVE +) + +team = TeamAgent( + name="team", + agents=agents, + description="team description", + llm_id="xyz", + use_mentalist=True, + inspectors=[inspector], +) +""" + +from enum import Enum +from typing import Dict, Optional, Text + +from pydantic import field_validator + +from aixplain.modules.agent.model_with_params import ModelWithParams + + +AUTO_DEFAULT_MODEL_ID = "67fd9e2bef0365783d06e2f0" # GPT-4.1 Nano + + +class InspectorAuto(str, Enum): + """A list of keywords for inspectors configured automatically in the backend.""" + + CORRECTNESS = "correctness" + + def get_name(self) -> Text: + return "inspector_" + self.value + + +class InspectorPolicy(str, Enum): + """Which action to take if the inspector gives negative feedback.""" + + WARN = "warn" # log only, continue execution + ABORT = "abort" # stop execution + ADAPTIVE = "adaptive" # adjust execution according to feedback + + +class Inspector(ModelWithParams): + """Pre-defined agent for inspecting the data flow within a team agent. + + The model should be onboarded before using it as an inspector. + + Attributes: + name: The name of the inspector. + model_id: The ID of the model to wrap. + model_params: The configuration for the model. + policy: The policy for the inspector.
Default is ADAPTIVE. + """ + + name: Text + model_params: Optional[Dict] = None + auto: Optional[InspectorAuto] = None + policy: InspectorPolicy = InspectorPolicy.ADAPTIVE + + def __init__(self, *args, **kwargs): + if kwargs.get("auto"): + kwargs["model_id"] = AUTO_DEFAULT_MODEL_ID + super().__init__(*args, **kwargs) + + @field_validator("name") + def validate_name(cls, v: Text) -> Text: + if v == "": + raise ValueError("name cannot be empty") + return v diff --git a/tests/functional/team_agent/inspector_functional_test.py b/tests/functional/team_agent/inspector_functional_test.py new file mode 100644 index 00000000..b415e719 --- /dev/null +++ b/tests/functional/team_agent/inspector_functional_test.py @@ -0,0 +1,370 @@ +""" +Functional tests for team agents with inspectors. +""" + +from dotenv import load_dotenv +from typing import Dict, List + +load_dotenv() + +import pytest + +from aixplain import aixplain_v2 as v2 +from aixplain.factories import AgentFactory, TeamAgentFactory +from aixplain.enums.asset_status import AssetStatus +from aixplain.modules.team_agent import InspectorTarget +from aixplain.modules.team_agent.inspector import Inspector, InspectorPolicy + +from tests.functional.team_agent.test_utils import ( + RUN_FILE, + read_data, + create_agents_from_input_map, + create_team_agent, + verify_response_generator, +) + + +@pytest.fixture(scope="function") +def delete_agents_and_team_agents(): + for team_agent in TeamAgentFactory.list()["results"]: + team_agent.delete() + for agent in AgentFactory.list()["results"]: + agent.delete() + + yield True + + for team_agent in TeamAgentFactory.list()["results"]: + team_agent.delete() + for agent in AgentFactory.list()["results"]: + agent.delete() + + +@pytest.fixture(scope="module", params=read_data(RUN_FILE)) +def run_input_map(request): + return request.param + + +def verify_inspector_steps(steps: Dict, inspector_names: List[str], inspector_targets: List[InspectorTarget]) -> None: + """Helper function to verify 
inspector steps""" + # Count occurrences of each inspector + inspector_counts = {} + for inspector_name in inspector_names: + inspector_steps = [step for step in steps if inspector_name.lower() in step.get("agent", "").lower()] + inspector_counts[inspector_name] = len(inspector_steps) + + # Verify all inspectors are present and have the same number of steps + assert len(inspector_counts) == len( + inspector_names + ), f"Expected {len(inspector_names)} inspectors, found {len(inspector_counts)}" + + if len(inspector_counts) > 0: + first_count = next(iter(inspector_counts.values())) + for inspector, count in inspector_counts.items(): + assert count > 0, f"Inspector {inspector} has no steps" + assert count == first_count, f"Inspector {inspector} has {count} steps, expected {first_count}" + print(f"Inspector {inspector} has {count} steps") + + # If OUTPUT is in inspector_targets, verify there are inspector steps after response generator + if InspectorTarget.OUTPUT in inspector_targets: + response_generator_steps = [step for step in steps if "response_generator" in step.get("agent", "").lower()] + assert len(response_generator_steps) == 1, "Expected exactly one response_generator step" + response_generator_index = steps.index(response_generator_steps[0]) + + inspector_steps_after = [ + step + for step in steps[response_generator_index + 1 :] + if any(inspector_name.lower() in step.get("agent", "").lower() for inspector_name in inspector_names) + ] + assert len(inspector_steps_after) > 0, "No inspector steps found after response generator step" + print(f"Found {len(inspector_steps_after)} inspector steps after response generator") + + # Verify inspector steps are the last steps + last_steps = steps[response_generator_index + 1 :] + assert all( + any(inspector_name.lower() in step.get("agent", "").lower() for inspector_name in inspector_names) + for step in last_steps + ), "Not all steps after response generator are inspector steps" + + 
+@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_warn_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with warn policy inspector that provides feedback but continues execution""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create inspector with warn policy + inspector = Inspector( + name="warn_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "Check if the steps are valid and provide feedback"}, + policy=InspectorPolicy.WARN, + ) + + # Create team agent with steps inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.STEPS], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + verify_inspector_steps(steps, ["warn_inspector"], [InspectorTarget.STEPS]) + verify_response_generator(steps) + + # Verify inspector runs and execution continues + inspector_steps = [step for step in steps if "warn_inspector" in step.get("agent", "").lower()] + assert len(inspector_steps) > 0, "Warn inspector should run at least once" + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_adaptive_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + 
"""Test team agent with adaptive inspector that runs multiple times""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create inspector with adaptive policy + inspector = Inspector( + name="adaptive_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "Check if the steps are valid and provide feedback for improvement"}, + policy=InspectorPolicy.ADAPTIVE, + ) + + # Create team agent with steps inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.STEPS], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + print(*steps, sep="\n") + verify_inspector_steps(steps, ["adaptive_inspector"], [InspectorTarget.STEPS]) + verify_response_generator(steps) + + # Verify inspector runs multiple times + inspector_steps = [step for step in steps if "adaptive_inspector" in step.get("agent", "").lower()] + assert len(inspector_steps) > 1, "Adaptive inspector should run more than once" + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_abort_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with abort inspector that stops execution on critique""" + assert delete_agents_and_team_agents + + agents = 
create_agents_from_input_map(run_input_map) + + # Create inspector with abort policy + inspector = Inspector( + name="abort_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "Always find issues and provide negative feedback"}, + policy=InspectorPolicy.ABORT, + ) + + # Create team agent with steps inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.STEPS], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "I couldn't provide an answer because the inspector detected issues" in response["data"]["output"] + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + verify_inspector_steps(steps, ["abort_inspector"], [InspectorTarget.STEPS]) + verify_response_generator(steps) + + # Verify response generator comes right after first inspector critique + inspector_steps = [step for step in steps if "abort_inspector" in step.get("agent", "").lower()] + assert len(inspector_steps) == 1, "Abort inspector should only run once" + response_generator_index = steps.index( + [step for step in steps if "response_generator" in step.get("agent", "").lower()][0] + ) + assert ( + response_generator_index == steps.index(inspector_steps[0]) + 1 + ), "Response generator should come right after inspector critique" + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def 
test_team_agent_with_output_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with output inspector that runs after response generator""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create inspector + inspector = Inspector( + name="output_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "Check if the output is valid and provide feedback"}, + policy=InspectorPolicy.WARN, + ) + + # Create team agent with output inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.OUTPUT], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + verify_inspector_steps(steps, ["output_inspector"], [InspectorTarget.OUTPUT]) + verify_response_generator(steps) + + # Verify critiques are in response data + assert "critiques" in response["data"] + assert response["data"]["critiques"], "No critiques found in response data" + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_multiple_inspector_targets(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with inspectors targeting both steps and output""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + 
# Create inspectors + steps_inspector = Inspector( + name="steps_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "Check if the steps are valid"}, + policy=InspectorPolicy.WARN, + ) + output_inspector = Inspector( + name="output_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "Check if the output is valid"}, + policy=InspectorPolicy.WARN, + ) + + # Create team agent with multiple inspectors + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[steps_inspector, output_inspector], + inspector_targets=[InspectorTarget.STEPS, InspectorTarget.OUTPUT], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + verify_inspector_steps(steps, ["steps_inspector", "output_inspector"], [InspectorTarget.STEPS, InspectorTarget.OUTPUT]) + verify_response_generator(steps) + + # Verify critiques are in response data + assert "critiques" in response["data"] + assert response["data"]["critiques"], "No critiques found in response data" + + team_agent.delete() diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index b526dcae..fc6b66af 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -17,24 +17,24 @@ """ import json from dotenv import load_dotenv +from uuid import uuid4 
load_dotenv() + +import pytest + +from aixplain import aixplain_v2 as v2 from aixplain.factories import AgentFactory, TeamAgentFactory, ModelFactory from aixplain.enums.asset_status import AssetStatus from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier -from aixplain.modules.team_agent import InspectorTarget -from copy import copy -from uuid import uuid4 -import pytest - -from aixplain import aixplain_v2 as v2 - -RUN_FILE = "tests/functional/team_agent/data/team_agent_test_end2end.json" - -def read_data(data_path): - return json.load(open(data_path, "r")) +from tests.functional.team_agent.test_utils import ( + RUN_FILE, + read_data, + create_agents_from_input_map, + create_team_agent, +) @pytest.fixture(scope="function") @@ -57,107 +57,14 @@ def run_input_map(request): return request.param -def create_agents_from_input_map(run_input_map, deploy=True): - """Helper function to create agents from input map""" - agents = [] - for agent in run_input_map["agents"]: - tools = [] - if "model_tools" in agent: - for tool in agent["model_tools"]: - tool_ = copy(tool) - for supplier in Supplier: - if tool["supplier"] is not None and tool["supplier"].lower() in [ - supplier.value["code"].lower(), - supplier.value["name"].lower(), - ]: - tool_["supplier"] = supplier - break - tools.append(AgentFactory.create_model_tool(**tool_)) - if "pipeline_tools" in agent: - for tool in agent["pipeline_tools"]: - tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"])) - - agent = AgentFactory.create( - name=agent["agent_name"], - description=agent["agent_name"], - instructions=agent["agent_name"], - llm_id=agent["llm_id"], - tools=tools, - ) - if deploy: - agent.deploy() - agents.append(agent) - - return agents - - -def create_team_agent( - factory, agents, run_input_map, use_mentalist=True, use_inspector=True, num_inspectors=1, inspector_targets=None -): - """Helper function to create a team 
agent""" - if inspector_targets is None: - inspector_targets = [InspectorTarget.STEPS] - - team_agent = factory.create( - name=run_input_map["team_agent_name"], - agents=agents, - description=run_input_map["team_agent_name"], - llm_id=run_input_map["llm_id"], - use_mentalist=use_mentalist, - use_inspector=use_inspector, - num_inspectors=num_inspectors, - inspector_targets=inspector_targets, - ) - - return team_agent - - -def verify_inspector_steps(steps, num_inspectors): - """Helper function to verify inspector steps""" - # Count occurrences of each inspector - inspector_counts = {} - for i in range(num_inspectors): - inspector_name = f"inspector_{i}" - inspector_steps = [step for step in steps if inspector_name.lower() in step.get("agent", "").lower()] - inspector_counts[inspector_name] = len(inspector_steps) - - # Verify all inspectors are present and have the same number of steps - assert len(inspector_counts) == num_inspectors, f"Expected {num_inspectors} inspectors, found {len(inspector_counts)}" - - if len(inspector_counts) > 0: - first_count = next(iter(inspector_counts.values())) - for inspector, count in inspector_counts.items(): - assert count > 0, f"Inspector {inspector} has no steps" - assert count == first_count, f"Inspector {inspector} has {count} steps, expected {first_count}" - print(f"Inspector {inspector} has {count} steps") - - return inspector_counts - - -def verify_response_generator(steps, has_output_target=False): - """Helper function to verify response generator step""" - response_generator_steps = [step for step in steps if "response_generator" in step.get("agent", "").lower()] - assert ( - len(response_generator_steps) == 1 - ), f"Expected exactly one response_generator step, found {len(response_generator_steps)}" - - response_generator_step = response_generator_steps[0] - - if has_output_target: - assert response_generator_step[ - "thought" - ], "Response generator thought is empty, but should contain inspector feedback because OUTPUT is 
in inspector_targets" - print(f"Response generator thought with OUTPUT target: {response_generator_step['thought']}") - - return response_generator_step - - @pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) def test_end2end(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): assert delete_agents_and_team_agents agents = create_agents_from_input_map(run_input_map) - team_agent = create_team_agent(TeamAgentFactory, agents, run_input_map, use_mentalist=True, use_inspector=True) + team_agent = create_team_agent( + TeamAgentFactory, agents, run_input_map, use_mentalist=True, inspectors=[], inspector_targets=None + ) assert team_agent is not None assert team_agent.status == AssetStatus.DRAFT @@ -188,7 +95,9 @@ def test_draft_team_agent_update(run_input_map, TeamAgentFactory): agent.delete() agents = create_agents_from_input_map(run_input_map, deploy=False) - team_agent = create_team_agent(TeamAgentFactory, agents, run_input_map, use_mentalist=True, use_inspector=True) + team_agent = create_team_agent( + TeamAgentFactory, agents, run_input_map, use_mentalist=True, inspectors=[], inspector_targets=None + ) team_agent_name = str(uuid4()).replace("-", "") team_agent.name = team_agent_name @@ -222,7 +131,9 @@ def test_add_remove_agents_from_team_agent(run_input_map, delete_agents_and_team assert delete_agents_and_team_agents agents = create_agents_from_input_map(run_input_map, deploy=False) - team_agent = create_team_agent(TeamAgentFactory, agents, run_input_map, use_mentalist=True, use_inspector=True) + team_agent = create_team_agent( + TeamAgentFactory, agents, run_input_map, use_mentalist=True, inspectors=[], inspector_targets=None + ) assert team_agent is not None assert team_agent.status == AssetStatus.DRAFT @@ -321,7 +232,12 @@ def test_team_agent_with_parameterized_agents(run_input_map, delete_agents_and_t ) translation_agent.deploy() team_agent = create_team_agent( - TeamAgentFactory, [search_agent, translation_agent], 
run_input_map, use_mentalist=True, use_inspector=True + TeamAgentFactory, + [search_agent, translation_agent], + run_input_map, + use_mentalist=True, + inspectors=[], + inspector_targets=None, ) # Deploy team agent @@ -347,253 +263,6 @@ def test_team_agent_with_parameterized_agents(run_input_map, delete_agents_and_t translation_agent.delete() -@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) -def test_team_agent_with_inspector_params(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): - """Test team agent with custom inspector parameters""" - assert delete_agents_and_team_agents - - agents = create_agents_from_input_map(run_input_map) - - # Create team agent with custom inspector parameters - num_inspectors = 2 - team_agent = create_team_agent( - TeamAgentFactory, - agents, - run_input_map, - use_mentalist=True, - use_inspector=True, - num_inspectors=num_inspectors, - inspector_targets=["steps", "output"], - ) - - assert team_agent is not None - assert team_agent.status == AssetStatus.DRAFT - assert team_agent.use_mentalist is True - assert team_agent.use_inspector is True - assert team_agent.max_inspectors == num_inspectors - assert len(team_agent.inspector_targets) == 2 - assert InspectorTarget.STEPS in team_agent.inspector_targets - assert InspectorTarget.OUTPUT in team_agent.inspector_targets - - # deploy team agent - team_agent.deploy() - team_agent = TeamAgentFactory.get(team_agent.id) - assert team_agent is not None - assert team_agent.status == AssetStatus.ONBOARDED - assert team_agent.max_inspectors == num_inspectors - assert len(team_agent.inspector_targets) == 2 - - # Run the team agent - response = team_agent.run(data=run_input_map["query"]) - - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "data" in response - assert response["data"]["session_id"] is not None - assert response["data"]["output"] is not None - - # Check if 
intermediate steps contain inspector outputs - if "intermediate_steps" in response["data"]: - steps = response["data"]["intermediate_steps"] - - # Verify inspector steps - verify_inspector_steps(steps, num_inspectors) - - # Verify response generator - verify_response_generator(steps, has_output_target=True) - - team_agent.delete() - - -@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) -def test_team_agent_update_inspector_params(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): - """Test updating inspector parameters for a team agent""" - assert delete_agents_and_team_agents - - agents = create_agents_from_input_map(run_input_map) - - # Create team agent with initial inspector parameters - team_agent = create_team_agent( - TeamAgentFactory, - agents, - run_input_map, - use_mentalist=True, - use_inspector=True, - num_inspectors=1, - inspector_targets=["steps"], - ) - - assert team_agent is not None - assert team_agent.status == AssetStatus.DRAFT - assert team_agent.max_inspectors == 1 - assert len(team_agent.inspector_targets) == 1 - assert team_agent.inspector_targets[0] == InspectorTarget.STEPS - - # Update inspector parameters - team_agent.max_inspectors = 3 - team_agent.inspector_targets = [InspectorTarget.STEPS, InspectorTarget.OUTPUT] - team_agent.update() - - # Get the updated team agent - updated_team_agent = TeamAgentFactory.get(team_agent.id) - assert updated_team_agent.max_inspectors == 3 - assert len(updated_team_agent.inspector_targets) == 2 - assert InspectorTarget.STEPS in updated_team_agent.inspector_targets - assert InspectorTarget.OUTPUT in updated_team_agent.inspector_targets - - team_agent.delete() - - -@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) -def test_team_agent_with_steps_only_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): - """Test team agent with inspector targeting only steps""" - assert delete_agents_and_team_agents - - agents = 
create_agents_from_input_map(run_input_map) - - # Create team agent with steps-only inspector - num_inspectors = 1 - team_agent = create_team_agent( - TeamAgentFactory, - agents, - run_input_map, - use_mentalist=True, - use_inspector=True, - num_inspectors=num_inspectors, - inspector_targets=["steps"], - ) - - assert team_agent is not None - assert team_agent.status == AssetStatus.DRAFT - assert team_agent.max_inspectors == num_inspectors - assert len(team_agent.inspector_targets) == 1 - assert team_agent.inspector_targets[0] == InspectorTarget.STEPS - - # deploy team agent - team_agent.deploy() - team_agent = TeamAgentFactory.get(team_agent.id) - assert team_agent is not None - assert team_agent.status == AssetStatus.ONBOARDED - - # Run the team agent - response = team_agent.run(data=run_input_map["query"]) - - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - - # Check for inspector steps - if "intermediate_steps" in response["data"]: - steps = response["data"]["intermediate_steps"] - - # Verify inspector steps - verify_inspector_steps(steps, num_inspectors) - - # Verify response generator - response_generator_step = verify_response_generator(steps, has_output_target=False) - print(f"Response generator thought (STEPS only): {response_generator_step.get('thought', '')}") - - team_agent.delete() - - -@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) -def test_team_agent_with_output_only_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): - """Test team agent with inspector targeting only output""" - assert delete_agents_and_team_agents - - agents = create_agents_from_input_map(run_input_map) - - # Create team agent with output-only inspector - num_inspectors = 1 - team_agent = create_team_agent( - TeamAgentFactory, - agents, - run_input_map, - use_mentalist=True, - use_inspector=True, - num_inspectors=num_inspectors, - 
inspector_targets=["output"], - ) - - assert team_agent is not None - assert team_agent.status == AssetStatus.DRAFT - assert team_agent.max_inspectors == num_inspectors - assert len(team_agent.inspector_targets) == 1 - assert team_agent.inspector_targets[0] == InspectorTarget.OUTPUT - - # deploy team agent - team_agent.deploy() - team_agent = TeamAgentFactory.get(team_agent.id) - assert team_agent is not None - assert team_agent.status == AssetStatus.ONBOARDED - - # Run the team agent - response = team_agent.run(data=run_input_map["query"]) - - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - - # Check for inspector steps - if "intermediate_steps" in response["data"]: - steps = response["data"]["intermediate_steps"] - - # Verify response generator with OUTPUT target - verify_response_generator(steps, has_output_target=True) - - team_agent.delete() - - -@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) -def test_team_agent_with_multiple_inspectors(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): - """Test team agent with multiple inspectors""" - assert delete_agents_and_team_agents - - agents = create_agents_from_input_map(run_input_map) - - # Create team agent with multiple inspectors - num_inspectors = 5 # Testing with 5 inspectors - team_agent = create_team_agent( - TeamAgentFactory, - agents, - run_input_map, - use_mentalist=True, - use_inspector=True, - num_inspectors=num_inspectors, - inspector_targets=["steps"], - ) - - assert team_agent is not None - assert team_agent.status == AssetStatus.DRAFT - assert team_agent.max_inspectors == num_inspectors - - # deploy team agent - team_agent.deploy() - team_agent = TeamAgentFactory.get(team_agent.id) - assert team_agent is not None - assert team_agent.status == AssetStatus.ONBOARDED - - # Run the team agent - response = team_agent.run(data=run_input_map["query"]) - - assert response is not None - assert 
response["completed"] is True - assert response["status"].lower() == "success" - - # Check for inspector steps - if "intermediate_steps" in response["data"]: - steps = response["data"]["intermediate_steps"] - - # Verify inspector steps - verify_inspector_steps(steps, num_inspectors) - - # Verify response generator - verify_response_generator(steps, has_output_target=False) - - team_agent.delete() - - def test_team_agent_with_instructions(delete_agents_and_team_agents): assert delete_agents_and_team_agents @@ -635,6 +304,7 @@ def test_team_agent_with_instructions(delete_agents_and_team_agents): agent_1.delete() agent_2.delete() + @pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) def test_team_agent_llm_parameter_preservation(delete_agents_and_team_agents, run_input_map, TeamAgentFactory): """Test that LLM parameters like temperature are preserved for all LLM roles in team agents.""" @@ -646,11 +316,10 @@ def test_team_agent_llm_parameter_preservation(delete_agents_and_team_agents, ru # Get LLM instances and customize their temperatures supervisor_llm = ModelFactory.get("671be4886eb56397e51f7541") # Anthropic Claude 3.5 Sonnet v1 mentalist_llm = ModelFactory.get("671be4886eb56397e51f7541") # Anthropic Claude 3.5 Sonnet v1 - inspector_llm = ModelFactory.get("671be4886eb56397e51f7541") # Anthropic Claude 3.5 Sonnet v1 + # Set custom temperatures supervisor_llm.temperature = 0.1 mentalist_llm.temperature = 0.3 - inspector_llm.temperature = 0.5 # Create a team agent with custom LLMs team_agent = TeamAgentFactory.create( @@ -658,26 +327,23 @@ def test_team_agent_llm_parameter_preservation(delete_agents_and_team_agents, ru agents=agents, supervisor_llm=supervisor_llm, mentalist_llm=mentalist_llm, - inspector_llm=inspector_llm, llm_id="671be4886eb56397e51f7541", # Still required even with custom LLMs description="A team agent for testing LLM parameter preservation", use_mentalist=True, - use_inspector=True, ) # Verify that temperature settings 
were preserved assert team_agent.supervisor_llm.temperature == 0.1 assert team_agent.mentalist_llm.temperature == 0.3 - assert team_agent.inspector_llm.temperature == 0.5 # Verify that the team agent's LLMs are the same instances as the originals assert id(team_agent.supervisor_llm) == id(supervisor_llm) assert id(team_agent.mentalist_llm) == id(mentalist_llm) - assert id(team_agent.inspector_llm) == id(inspector_llm) # Clean up team_agent.delete() + def test_run_team_agent_with_expected_output(): from pydantic import BaseModel from typing import Optional, List diff --git a/tests/functional/team_agent/test_utils.py b/tests/functional/team_agent/test_utils.py new file mode 100644 index 00000000..16e1b421 --- /dev/null +++ b/tests/functional/team_agent/test_utils.py @@ -0,0 +1,78 @@ +""" +Shared test utilities for team agent tests. +""" + +import json +from copy import copy +from typing import Dict + +from aixplain.factories import AgentFactory +from aixplain.enums.supplier import Supplier +from aixplain.modules.team_agent import InspectorTarget + + +RUN_FILE = "tests/functional/team_agent/data/team_agent_test_end2end.json" + + +def read_data(data_path): + return json.load(open(data_path, "r")) + + +def create_agents_from_input_map(run_input_map, deploy=True): + """Helper function to create agents from input map""" + agents = [] + for agent in run_input_map["agents"]: + tools = [] + if "model_tools" in agent: + for tool in agent["model_tools"]: + tool_ = copy(tool) + for supplier in Supplier: + if tool["supplier"] is not None and tool["supplier"].lower() in [ + supplier.value["code"].lower(), + supplier.value["name"].lower(), + ]: + tool_["supplier"] = supplier + break + tools.append(AgentFactory.create_model_tool(**tool_)) + if "pipeline_tools" in agent: + for tool in agent["pipeline_tools"]: + tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"])) + + agent = AgentFactory.create( + name=agent["agent_name"], + 
description=agent["agent_name"], + instructions=agent["agent_name"], + llm_id=agent["llm_id"], + tools=tools, + ) + if deploy: + agent.deploy() + agents.append(agent) + + return agents + + +def create_team_agent(factory, agents, run_input_map, use_mentalist=True, inspectors=[], inspector_targets=None): + """Helper function to create a team agent""" + if inspector_targets is None: + inspector_targets = [InspectorTarget.STEPS] + + team_agent = factory.create( + name=run_input_map["team_agent_name"], + agents=agents, + description=run_input_map["team_agent_name"], + llm_id=run_input_map["llm_id"], + use_mentalist=use_mentalist, + inspectors=inspectors, + inspector_targets=inspector_targets, + ) + + return team_agent + + +def verify_response_generator(steps: Dict) -> None: + """Helper function to verify response generator step""" + response_generator_steps = [step for step in steps if "response_generator" in step.get("agent", "").lower()] + assert ( + len(response_generator_steps) == 1 + ), f"Expected exactly one response_generator step, found {len(response_generator_steps)}" diff --git a/tests/unit/team_agent/inspector_test.py b/tests/unit/team_agent/inspector_test.py new file mode 100644 index 00000000..e21d6c53 --- /dev/null +++ b/tests/unit/team_agent/inspector_test.py @@ -0,0 +1,144 @@ +import pytest +from unittest.mock import patch, MagicMock +from aixplain.modules.team_agent.inspector import Inspector, InspectorPolicy, InspectorAuto, AUTO_DEFAULT_MODEL_ID +from aixplain.factories.team_agent_factory.inspector_factory import InspectorFactory +from aixplain.enums.function import Function +from aixplain.enums.asset_status import AssetStatus + +# Test data +INSPECTOR_CONFIG = { + "name": "test_inspector", + "model_id": "test_model_id", + "model_config": {"prompt": "Check if the data is safe to use."}, + "policy": InspectorPolicy.ADAPTIVE, +} + +MOCK_MODEL_RESPONSE = { + "id": "test_model_id", + "name": "test_model", + "description": "Test model description", + 
"createdAt": "2024-03-20T10:00:00Z", + "supplier": "test_supplier", + "pricing": {"per_token": 0.001}, + "version": {"id": "v1"}, + "params": [], + "attributes": [], + "api_key": "test_api_key", +} + + +def test_inspector_creation(): + """Test basic inspector creation with valid parameters""" + inspector = Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=INSPECTOR_CONFIG["policy"], + ) + + assert inspector.name == INSPECTOR_CONFIG["name"] + assert inspector.model_id == INSPECTOR_CONFIG["model_id"] + assert inspector.model_params == INSPECTOR_CONFIG["model_config"] + assert inspector.policy == INSPECTOR_CONFIG["policy"] + assert inspector.auto is None + + +def test_inspector_auto_creation(): + """Test inspector creation with auto configuration""" + inspector = Inspector(name="auto_inspector", auto=InspectorAuto.CORRECTNESS, policy=InspectorPolicy.WARN) + + assert inspector.name == "auto_inspector" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == InspectorPolicy.WARN + assert inspector.model_id == AUTO_DEFAULT_MODEL_ID + assert inspector.model_params is None + + +def test_inspector_name_validation(): + """Test inspector name validation""" + with pytest.raises(ValueError, match="name cannot be empty"): + Inspector(name="", model_id="test_model_id") + + +def test_inspector_factory_create_from_model(): + """Test creating inspector from model using factory""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.ONBOARDED.value, + "function": {"id": Function.GUARDRAILS.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + inspector = InspectorFactory.create_from_model( + name=INSPECTOR_CONFIG["name"], + model=INSPECTOR_CONFIG["model_id"], + 
model_config=INSPECTOR_CONFIG["model_config"], + policy=INSPECTOR_CONFIG["policy"], + ) + + assert inspector.name == INSPECTOR_CONFIG["name"] + assert inspector.model_id == INSPECTOR_CONFIG["model_id"] + assert inspector.model_params == INSPECTOR_CONFIG["model_config"] + assert inspector.policy == INSPECTOR_CONFIG["policy"] + + +def test_inspector_factory_create_from_model_invalid_status(): + """Test creating inspector from model with invalid status""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.DRAFT.value, + "function": {"id": Function.GUARDRAILS.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + with pytest.raises(ValueError, match="is not onboarded"): + InspectorFactory.create_from_model( + name=INSPECTOR_CONFIG["name"], + model=INSPECTOR_CONFIG["model_id"], + model_config=INSPECTOR_CONFIG["model_config"], + policy=INSPECTOR_CONFIG["policy"], + ) + + +def test_inspector_factory_create_from_model_invalid_function(): + """Test creating inspector from model with invalid function""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.ONBOARDED.value, + "function": {"id": Function.TRANSLATION.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + with pytest.raises(ValueError, match="models are supported"): + InspectorFactory.create_from_model( + name=INSPECTOR_CONFIG["name"], + model=INSPECTOR_CONFIG["model_id"], + model_config=INSPECTOR_CONFIG["model_config"], + policy=INSPECTOR_CONFIG["policy"], + ) + + +def test_inspector_factory_create_auto(): + """Test creating auto-configured inspector using factory""" + inspector = InspectorFactory.create_auto(auto=InspectorAuto.CORRECTNESS, 
name="custom_name", policy=InspectorPolicy.ABORT) + + assert inspector.name == "custom_name" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == InspectorPolicy.ABORT + assert inspector.model_id == AUTO_DEFAULT_MODEL_ID + assert inspector.model_params is None + + +def test_inspector_factory_create_auto_default_name(): + """Test creating auto-configured inspector with default name""" + inspector = InspectorFactory.create_auto(auto=InspectorAuto.CORRECTNESS) + + assert inspector.name == "inspector_correctness" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == InspectorPolicy.ADAPTIVE # default policy diff --git a/tests/unit/team_agent_test.py b/tests/unit/team_agent/team_agent_test.py similarity index 67% rename from tests/unit/team_agent_test.py rename to tests/unit/team_agent/team_agent_test.py index e84e2e34..5a154d0b 100644 --- a/tests/unit/team_agent_test.py +++ b/tests/unit/team_agent/team_agent_test.py @@ -8,6 +8,7 @@ from aixplain.factories import AgentFactory from aixplain.modules.agent import Agent from aixplain.modules.team_agent import TeamAgent, InspectorTarget +from aixplain.modules.team_agent.inspector import Inspector, InspectorPolicy from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.utils import config @@ -92,55 +93,36 @@ def test_to_dict(): description="Test Team Agent Description", llm_id="6646261c6eb563165658bbb1", use_mentalist=False, - use_inspector=False, + inspectors=[ + Inspector( + name="Test Inspector", + model_id="6646261c6eb563165658bbb1", + model_params={"prompt": "Test Prompt"}, + policy=InspectorPolicy.ADAPTIVE, + ) + ], + inspector_targets=[InspectorTarget.STEPS, InspectorTarget.OUTPUT], ) team_agent_dict = team_agent.to_dict() + assert team_agent_dict["id"] == "123" assert team_agent_dict["name"] == "Test Team Agent(-)" assert team_agent_dict["description"] == "Test Team Agent Description" assert team_agent_dict["llmId"] == 
"6646261c6eb563165658bbb1" assert team_agent_dict["supervisorId"] == "6646261c6eb563165658bbb1" - assert team_agent_dict["plannerId"] is None - assert team_agent_dict["inspectorId"] is None - assert len(team_agent_dict["agents"]) == 1 + assert team_agent_dict["agents"][0]["assetId"] == "" assert team_agent_dict["agents"][0]["number"] == 0 assert team_agent_dict["agents"][0]["type"] == "AGENT" assert team_agent_dict["agents"][0]["label"] == "AGENT" - -def test_to_dict_with_inspector_params(): - team_agent = TeamAgent( - id="123", - name="Test Team Agent(-)", - agents=[ - Agent( - id="", - name="Test Agent(-)", - description="Test Agent Description", - instructions="Test Agent Role", - llm_id="6646261c6eb563165658bbb1", - tools=[ModelTool(function="text-generation")], - ) - ], - description="Test Team Agent Description", - llm_id="6646261c6eb563165658bbb1", - use_mentalist=True, - use_inspector=True, - max_inspectors=2, - inspector_targets=[InspectorTarget.STEPS, InspectorTarget.OUTPUT], - ) - - team_agent_dict = team_agent.to_dict() - assert team_agent_dict["id"] == "123" - assert team_agent_dict["name"] == "Test Team Agent(-)" - assert team_agent_dict["description"] == "Test Team Agent Description" - assert team_agent_dict["llmId"] == "6646261c6eb563165658bbb1" - assert team_agent_dict["supervisorId"] == "6646261c6eb563165658bbb1" - assert team_agent_dict["plannerId"] == "6646261c6eb563165658bbb1" - assert team_agent_dict["inspectorId"] == "6646261c6eb563165658bbb1" - assert team_agent_dict["maxInspectors"] == 2 + assert team_agent_dict["plannerId"] is None + assert len(team_agent_dict["inspectors"]) == 1 + assert team_agent_dict["inspectors"][0]["name"] == "Test Inspector" + assert team_agent_dict["inspectors"][0]["modelId"] == "6646261c6eb563165658bbb1" + assert team_agent_dict["inspectors"][0]["modelParams"] == {"prompt": "Test Prompt"} + assert team_agent_dict["inspectors"][0]["policy"] == "adaptive" assert team_agent_dict["inspectorTargets"] == ["steps", 
"output"] assert len(team_agent_dict["agents"]) == 1 @@ -221,7 +203,6 @@ def test_create_team_agent(mock_model_factory_get): "agents": [{"assetId": "123", "type": "AGENT", "number": 0, "label": "AGENT"}], "links": [], "plannerId": "6646261c6eb563165658bbb1", - "inspectorId": "6646261c6eb563165658bbb1", "supervisorId": "6646261c6eb563165658bbb1", "createdAt": "2024-10-28T19:30:25.344Z", "updatedAt": "2024-10-28T19:30:25.344Z", @@ -234,14 +215,14 @@ def test_create_team_agent(mock_model_factory_get): llm_id="6646261c6eb563165658bbb1", description="TEST Multi agent", use_mentalist=True, - use_inspector=True, + # TODO: inspectors=[Inspector(name="Test Inspector", model_id="6646261c6eb563165658bbb1", model_params={"prompt": "Test Prompt"}, policy=InspectorPolicy.ADAPTIVE)], + # TODO: inspector_targets=[InspectorTarget.STEPS, InspectorTarget.OUTPUT], ) assert team_agent.id is not None assert team_agent.name == team_ref_response["name"] assert team_agent.description == team_ref_response["description"] assert team_agent.llm_id == team_ref_response["llmId"] assert team_agent.use_mentalist is True - assert team_agent.use_inspector is True assert team_agent.status == AssetStatus.DRAFT assert len(team_agent.agents) == 1 assert team_agent.agents[0].id == team_ref_response["agents"][0]["assetId"] @@ -269,114 +250,6 @@ def test_create_team_agent(mock_model_factory_get): assert team_agent.status.value == "onboarded" -@patch("aixplain.factories.model_factory.ModelFactory.get") -def test_create_team_agent_with_inspector_params(mock_model_factory_get): - from aixplain.modules import Model - from aixplain.enums import Function - - # Mock the model factory response - mock_model = Model( - id="6646261c6eb563165658bbb1", name="Test LLM", description="Test LLM Description", function=Function.TEXT_GENERATION - ) - mock_model_factory_get.return_value = mock_model - - with requests_mock.Mocker() as mock: - headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} - # 
MOCK GET LLM - url = urljoin(config.BACKEND_URL, "sdk/models/6646261c6eb563165658bbb1") - model_ref_response = { - "id": "6646261c6eb563165658bbb1", - "name": "Test LLM", - "description": "Test LLM Description", - "function": {"id": "text-generation"}, - "supplier": "openai", - "version": {"id": "1.0"}, - "status": "onboarded", - "pricing": {"currency": "USD", "value": 0.0}, - } - mock.get(url, headers=headers, json=model_ref_response) - - # AGENT MOCK CREATION - url = urljoin(config.BACKEND_URL, "sdk/agents") - ref_response = { - "id": "123", - "name": "Test Agent(-)", - "description": "Test Agent Description", - "role": "Test Agent Role", - "teamId": "123", - "version": "1.0", - "status": "draft", - "llmId": "6646261c6eb563165658bbb1", - "pricing": {"currency": "USD", "value": 0.0}, - "assets": [ - { - "type": "model", - "supplier": "openai", - "version": "1.0", - "assetId": "6646261c6eb563165658bbb1", - "function": "text-generation", - } - ], - } - mock.post(url, headers=headers, json=ref_response) - - agent = AgentFactory.create( - name="Test Agent(-)", - description="Test Agent Description", - instructions="Test Agent Role", - llm_id="6646261c6eb563165658bbb1", - tools=[ModelTool(model="6646261c6eb563165658bbb1")], - ) - - # AGENT MOCK GET - url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}") - mock.get(url, headers=headers, json=ref_response) - - # TEAM MOCK CREATION - url = urljoin(config.BACKEND_URL, "sdk/agent-communities") - team_ref_response = { - "id": "team_agent_123", - "name": "TEST Multi agent(-)", - "status": "draft", - "teamId": 645, - "description": "TEST Multi agent", - "llmId": "6646261c6eb563165658bbb1", - "assets": [], - "agents": [{"assetId": "123", "type": "AGENT", "number": 0, "label": "AGENT"}], - "links": [], - "plannerId": "6646261c6eb563165658bbb1", - "inspectorId": "6646261c6eb563165658bbb1", - "supervisorId": "6646261c6eb563165658bbb1", - "maxInspectors": 3, - "inspectorTargets": ["steps", "output"], - "createdAt": 
"2024-10-28T19:30:25.344Z", - "updatedAt": "2024-10-28T19:30:25.344Z", - } - mock.post(url, headers=headers, json=team_ref_response) - - team_agent = TeamAgentFactory.create( - name="TEST Multi agent(-)", - agents=[agent], - llm_id="6646261c6eb563165658bbb1", - description="TEST Multi agent", - use_mentalist=True, - use_inspector=True, - num_inspectors=3, - inspector_targets=[InspectorTarget.STEPS, InspectorTarget.OUTPUT], - ) - assert team_agent.id is not None - assert team_agent.name == team_ref_response["name"] - assert team_agent.description == team_ref_response["description"] - assert team_agent.llm_id == team_ref_response["llmId"] - assert team_agent.use_mentalist is True - assert team_agent.use_inspector is True - assert team_agent.max_inspectors == 3 - assert team_agent.inspector_targets == [InspectorTarget.STEPS, InspectorTarget.OUTPUT] - assert team_agent.status == AssetStatus.DRAFT - assert len(team_agent.agents) == 1 - assert team_agent.agents[0].id == team_ref_response["agents"][0]["assetId"] - - def test_fail_inspector_without_mentalist(): with pytest.raises(Exception) as exc_info: TeamAgentFactory.create( @@ -392,7 +265,14 @@ def test_fail_inspector_without_mentalist(): ) ], use_mentalist=False, - use_inspector=True, + inspectors=[ + Inspector( + name="Test Inspector", + model_id="6646261c6eb563165658bbb1", + model_params={"prompt": "Test Prompt"}, + policy=InspectorPolicy.ADAPTIVE, + ) + ], ) assert "you must enable Mentalist" in str(exc_info.value) @@ -413,33 +293,18 @@ def test_fail_invalid_inspector_target(): ) ], use_mentalist=True, - use_inspector=True, - inspector_targets=["invalid_target"], - ) - - assert "Invalid inspector target" in str(exc_info.value) - - -def test_fail_zero_inspectors(): - with pytest.raises(Exception) as exc_info: - TeamAgentFactory.create( - name="Test Team Agent(-)", - agents=[ - Agent( - id="123", - name="Test Agent(-)", - description="Test Agent Description", - instructions="Test Agent Role", - 
llm_id="6646261c6eb563165658bbb1", - tools=[ModelTool(function="text-generation")], + inspectors=[ + Inspector( + name="Test Inspector", + model_id="6646261c6eb563165658bbb1", + model_params={"prompt": "Test Prompt"}, + policy=InspectorPolicy.ADAPTIVE, ) ], - use_mentalist=True, - use_inspector=True, - num_inspectors=0, + inspector_targets=["invalid_target"], ) - assert "The number of inspectors must be greater than 0" in str(exc_info.value) + assert "Invalid inspector target" in str(exc_info.value) def test_build_team_agent(mocker): From f6422db7770dbfdcf43a1e8b61273c8add068599 Mon Sep 17 00:00:00 2001 From: Muhammad-Elmallah <145364766+Muhammad-Elmallah@users.noreply.github.com> Date: Mon, 2 Jun 2025 17:56:32 +0300 Subject: [PATCH 39/62] Eng 2040 ai r 2 add splitting features to the sdk (#546) * Adding Splitting Features of aiR to SDK * adding the splitting parameters in the additiona_param field * adding splitting feature * adding splitting feature * adding enumerator for the splitting options * adding enumerator for the splitting options * fixing agent unit test --- aixplain/enums/splitting_options.py | 35 +++++++++++ aixplain/modules/model/index_model.py | 43 +++++++++---- aixplain/modules/pipeline/designer/base.py | 25 ++------ aixplain/modules/pipeline/designer/nodes.py | 11 ++-- aixplain/v2/api_key.py | 4 +- aixplain/v2/client.py | 13 +--- aixplain/v2/enums.py | 4 +- aixplain/v2/model.py | 10 +-- aixplain/v2/resource.py | 24 ++----- aixplain/v2/script.py | 1 - tests/conftest.py | 22 ++----- .../data_asset/corpus_onboarding_test.py | 2 +- tests/functional/model/run_model_test.py | 63 ++++++++++++++----- tests/functional/pipelines/create_test.py | 1 + tests/functional/pipelines/designer_test.py | 28 +++------ tests/functional/pipelines/run_test.py | 29 +++------ tests/unit/agent/sql_tool_test.py | 2 +- tests/unit/index_model_test.py | 10 +++ tests/unit/v2/test_core.py | 15 +---- 19 files changed, 173 insertions(+), 169 deletions(-) create mode 100644 
aixplain/enums/splitting_options.py diff --git a/aixplain/enums/splitting_options.py b/aixplain/enums/splitting_options.py new file mode 100644 index 00000000..51dcaab3 --- /dev/null +++ b/aixplain/enums/splitting_options.py @@ -0,0 +1,35 @@ +__author__ = "aiXplain" + +""" +Copyright 2023 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: aiXplain team +Date: May 30th 2025 +Description: + Splitting Options Enum +""" + +from enum import Enum + + +class SplittingOptions(str, Enum): + WORD = "word" + SENTENCE = "sentence" + PASSAGE = "passage" + PAGE = "page" + LINE = "line" + + def __str__(self): + return self._value_ diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index 15a1a474..523b3f64 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -6,11 +6,7 @@ from aixplain.modules.model.record import Record from enum import Enum from typing import List - -import os - -from urllib.parse import urljoin -from aixplain.utils.file_utils import _request_with_retry +from aixplain.enums.splitting_options import SplittingOptions class IndexFilterOperator(Enum): @@ -42,6 +38,18 @@ def to_dict(self): } +class Splitter: + def __init__( + self, + split: bool = False, + split_by: SplittingOptions = SplittingOptions.WORD, + split_length: int = 1, + split_overlap: int = 0, + ): + self.split = split + self.split_by = split_by + self.split_length = split_length + self.split_overlap = 
split_overlap class IndexModel(Model): @@ -109,8 +117,6 @@ def to_dict(self) -> Dict: data["collection_type"] = self.version.split("-", 1)[0] return data - - def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: """Search for documents in the index @@ -140,21 +146,24 @@ def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) - "data": query or uri, "dataType": value_type, "filters": [filter.to_dict() for filter in filters], - "payload": {"uri": uri, "value_type": value_type, "top_k": top_k} + "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, } return self.run(data=data) - def upsert(self, documents: List[Record]) -> ModelResponse: + def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: """Upsert documents into the index Args: documents (List[Record]): List of documents to be upserted + splitter (Splitter, optional): Splitter to be applied. Defaults to None. Returns: ModelResponse: Response from the indexing service - Example: + Examples: index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})], splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0)) + Splitter in the above example is optional and can be used to split the documents into smaller chunks. 
""" # Validate documents for doc in documents: @@ -162,7 +171,19 @@ def upsert(self, documents: List[Record]) -> ModelResponse: # Convert documents to payloads payloads = [doc.to_dict() for doc in documents] # Build payload - data = {"action": "ingest", "data": payloads} + data = { + "action": "ingest", + "data": payloads, + } + if splitter and splitter.split: + data["additional_params"] = { + "splitter": { + "split": splitter.split, + "split_by": splitter.split_by, + "split_length": splitter.split_length, + "split_overlap": splitter.split_overlap, + } + } # Run the indexing service response = self.run(data=data) if response.status == ResponseStatus.SUCCESS: diff --git a/aixplain/modules/pipeline/designer/base.py b/aixplain/modules/pipeline/designer/base.py index 5d499a1c..6ff18b32 100644 --- a/aixplain/modules/pipeline/designer/base.py +++ b/aixplain/modules/pipeline/designer/base.py @@ -148,14 +148,10 @@ def __init__( to_param = to_param.code assert from_param in from_node.outputs, ( - "Invalid from param. " - "Make sure all input params are already linked accordingly" + "Invalid from param. " "Make sure all input params are already linked accordingly" ) - assert to_param in to_node.inputs, ( - "Invalid to param. " - "Make sure all output params are already linked accordingly" - ) + assert to_param in to_node.inputs, "Invalid to param. " "Make sure all output params are already linked accordingly" tp_instance = to_node.inputs[to_param] fp_instance = from_node.outputs[from_param] @@ -203,9 +199,7 @@ def validate(self): # Should we check for data type mismatch? 
if from_param.data_type and to_param.data_type: if from_param.data_type != to_param.data_type: - raise ValueError( - f"Data type mismatch between {from_param.data_type} and {to_param.data_type}" - ) # noqa + raise ValueError(f"Data type mismatch between {from_param.data_type} and {to_param.data_type}") # noqa def attach_to(self, pipeline: "DesignerPipeline"): """ @@ -258,9 +252,7 @@ def add_param(self, param: Param) -> None: if not hasattr(self, param.code): setattr(self, param.code, param) - def _create_param( - self, code: str, data_type: DataType = None, value: any = None - ) -> Param: + def _create_param(self, code: str, data_type: DataType = None, value: any = None) -> Param: raise NotImplementedError() def create_param( @@ -304,10 +296,7 @@ def special_prompt_handling(self, code: str, value: str) -> None: if not isinstance(self.node, AssetNode): return - if ( - not hasattr(self.node, "asset") - or self.node.asset.function != "text-generation" - ): + if not hasattr(self.node, "asset") or self.node.asset.function != "text-generation": return matches = find_prompt_params(value) @@ -361,9 +350,7 @@ def _create_param( class Outputs(ParamProxy): - def _create_param( - self, code: str, data_type: DataType = None, value: any = None - ) -> OutputParam: + def _create_param(self, code: str, data_type: DataType = None, value: any = None) -> OutputParam: return OutputParam(code=code, data_type=data_type, value=value) diff --git a/aixplain/modules/pipeline/designer/nodes.py b/aixplain/modules/pipeline/designer/nodes.py index 10c9d3eb..e81be81a 100644 --- a/aixplain/modules/pipeline/designer/nodes.py +++ b/aixplain/modules/pipeline/designer/nodes.py @@ -79,9 +79,7 @@ def populate_asset(self): if self.function: if self.asset.function.value != self.function: - raise ValueError( - f"Function {self.function} is not supported by asset {self.asset_id}" - ) + raise ValueError(f"Function {self.function} is not supported by asset {self.asset_id}") else: self.function = 
self.asset.function.value @@ -409,13 +407,14 @@ def link( link = super().link(to_node, from_param, to_param) if isinstance(from_param, str): - assert from_param in self.outputs, f"Decision node has no input param called {from_param}, node linking validation is broken, please report this issue." + assert ( + from_param in self.outputs + ), f"Decision node has no input param called {from_param}, node linking validation is broken, please report this issue." from_param = self.outputs[from_param] if from_param.code == "data": if not self.inputs.passthrough.link_: - raise ValueError("To able to infer data source, " - "passthrough input param should be linked first.") + raise ValueError("To able to infer data source, " "passthrough input param should be linked first.") # Infer data source from the passthrough node link.data_source_id = self.inputs.passthrough.link_.from_node.number diff --git a/aixplain/v2/api_key.py b/aixplain/v2/api_key.py index 048bc0da..5d23bac4 100644 --- a/aixplain/v2/api_key.py +++ b/aixplain/v2/api_key.py @@ -61,9 +61,7 @@ def update(cls, api_key: "APIKey") -> "APIKey": return APIKeyFactory.update(api_key) @classmethod - def get_usage_limits( - cls, api_key: Text = None, asset_id: Optional[Text] = None - ) -> List["APIKeyUsageLimit"]: + def get_usage_limits(cls, api_key: Text = None, asset_id: Optional[Text] = None) -> List["APIKeyUsageLimit"]: from aixplain.factories import APIKeyFactory from aixplain.utils import config diff --git a/aixplain/v2/client.py b/aixplain/v2/client.py index 78727a00..46b23224 100644 --- a/aixplain/v2/client.py +++ b/aixplain/v2/client.py @@ -10,9 +10,7 @@ DEFAULT_RETRY_STATUS_FORCELIST = [500, 502, 503, 504] -def create_retry_session( - total=None, backoff_factor=None, status_forcelist=None, **kwargs -): +def create_retry_session(total=None, backoff_factor=None, status_forcelist=None, **kwargs): """ Creates a requests.Session with a specified retry strategy. 
@@ -43,7 +41,6 @@ def create_retry_session( class AixplainClient: - def __init__( self, base_url: str, @@ -69,14 +66,10 @@ def __init__( self.aixplain_api_key = aixplain_api_key if not (self.aixplain_api_key or self.team_api_key): - raise ValueError( - "Either `aixplain_api_key` or `team_api_key` should be set" - ) + raise ValueError("Either `aixplain_api_key` or `team_api_key` should be set") if self.aixplain_api_key and self.team_api_key: - raise ValueError( - "Either `aixplain_api_key` or `team_api_key` should be set" - ) + raise ValueError("Either `aixplain_api_key` or `team_api_key` should be set") headers = {"Content-Type": "application/json"} if self.aixplain_api_key: diff --git a/aixplain/v2/enums.py b/aixplain/v2/enums.py index 34e13f07..2eb47e52 100644 --- a/aixplain/v2/enums.py +++ b/aixplain/v2/enums.py @@ -43,9 +43,7 @@ class Function(str, Enum): AUTO_MASK_GENERATION = "auto-mask-generation" DOCUMENT_IMAGE_PARSING = "document-image-parsing" ENTITY_LINKING = "entity-linking" - REFERENCELESS_TEXT_GENERATION_METRIC_DEFAULT = ( - "referenceless-text-generation-metric-default" - ) + REFERENCELESS_TEXT_GENERATION_METRIC_DEFAULT = "referenceless-text-generation-metric-default" FILL_TEXT_MASK = "fill-text-mask" SUBTITLING_TRANSLATION = "subtitling-translation" INSTANCE_SEGMENTATION = "instance-segmentation" diff --git a/aixplain/v2/model.py b/aixplain/v2/model.py index 5cc1d9f2..d04b8363 100644 --- a/aixplain/v2/model.py +++ b/aixplain/v2/model.py @@ -69,11 +69,7 @@ def create_utility_model( ) -> "Model": from aixplain.factories import ModelFactory - return Model( - ModelFactory.create_utility_model( - name, code, inputs, description, output_examples, api_key - ) - ) + return Model(ModelFactory.create_utility_model(name, code, inputs, description, output_examples, api_key)) @classmethod def list_host_machines(cls, api_key: str = None) -> List[str]: @@ -135,9 +131,7 @@ def onboard_model( ) -> dict: from aixplain.factories import ModelFactory - return 
ModelFactory.onboard_model( - model_id, image_tag, image_hash, host_machine=host_machine, api_key=api_key - ) + return ModelFactory.onboard_model(model_id, image_tag, image_hash, host_machine=host_machine, api_key=api_key) @classmethod def deploy_hugging_face_model( diff --git a/aixplain/v2/resource.py b/aixplain/v2/resource.py index 8c6d2c93..da74ef4b 100644 --- a/aixplain/v2/resource.py +++ b/aixplain/v2/resource.py @@ -72,9 +72,7 @@ def save(self): else: self._action("post", **self._obj) - def _action( - self, method: str = None, action_paths: List[str] = None, **kwargs - ) -> requests.Response: + def _action(self, method: str = None, action_paths: List[str] = None, **kwargs) -> requests.Response: """ Internal method to perform actions on the resource. @@ -92,9 +90,7 @@ def _action( 'id' attribute is missing. """ - assert getattr( - self, "RESOURCE_PATH" - ), "Subclasses of 'BaseResource' must specify 'RESOURCE_PATH'" + assert getattr(self, "RESOURCE_PATH"), "Subclasses of 'BaseResource' must specify 'RESOURCE_PATH'" if not self.id: raise ValueError("Action call requires an 'id' attribute") @@ -236,9 +232,7 @@ def list(cls: Type[R], **kwargs: Unpack[L]) -> Page[R]: Page[R]: Page of BaseResource instances """ - assert getattr( - cls, "RESOURCE_PATH" - ), "Subclasses of 'BaseResource' must specify 'RESOURCE_PATH'" + assert getattr(cls, "RESOURCE_PATH"), "Subclasses of 'BaseResource' must specify 'RESOURCE_PATH'" # TypedDict does not support default values, so we need to manually set them # Dataclasses might be a better fit, but we're using the TypedDict to ensure @@ -251,9 +245,7 @@ def list(cls: Type[R], **kwargs: Unpack[L]) -> Page[R]: filters = cls._populate_filters(params) paginate_path = cls._populate_path(cls.RESOURCE_PATH) print(paginate_path, filters) - response = cls.context.client.request( - cls.PAGINATE_METHOD, paginate_path, json=filters - ) + response = cls.context.client.request(cls.PAGINATE_METHOD, paginate_path, json=filters) return 
cls._build_page(response, **kwargs) @classmethod @@ -357,9 +349,7 @@ def get(cls: Type[R], id: Any, **kwargs: Unpack[G]) -> R: Raises: ValueError: If 'RESOURCE_PATH' is not defined by the subclass. """ - assert getattr( - cls, "RESOURCE_PATH" - ), "Subclasses of 'BaseResource' must specify 'RESOURCE_PATH'" + assert getattr(cls, "RESOURCE_PATH"), "Subclasses of 'BaseResource' must specify 'RESOURCE_PATH'" path = f"{cls.RESOURCE_PATH}/{id}" obj = cls.context.client.get_obj(path, **kwargs) @@ -380,9 +370,7 @@ def create(cls, *args, **kwargs: Unpack[C]) -> R: Returns: BaseResource: The created resource. """ - assert getattr( - cls, "RESOURCE_PATH" - ), "Subclasses of 'BaseResource' must specify 'RESOURCE_PATH'" + assert getattr(cls, "RESOURCE_PATH"), "Subclasses of 'BaseResource' must specify 'RESOURCE_PATH'" obj = cls.context.client.request("post", cls.RESOURCE_PATH, *args, **kwargs) return cls(obj) diff --git a/aixplain/v2/script.py b/aixplain/v2/script.py index 38f6b2b4..1ee7076a 100644 --- a/aixplain/v2/script.py +++ b/aixplain/v2/script.py @@ -2,7 +2,6 @@ class Script(BaseResource): - @classmethod def upload(cls, script_path: str) -> "Script": """Upload a script to the server. 
diff --git a/tests/conftest.py b/tests/conftest.py index 461e8146..68a4c787 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,9 +23,7 @@ def pytest_addoption(parser: pytest.Parser): # Here we're adding the options for the pipeline version and the sdk version parser.addoption(f"{PIPELINE_VERSION_ARG}", action="store", help="pipeline version") parser.addoption(f"{SDK_VERSION_ARG}", action="store", help="sdk version") - parser.addoption( - f"{SDK_VERSION_PARAM_ARG}", action="store", help="sdk version parameter" - ) + parser.addoption(f"{SDK_VERSION_PARAM_ARG}", action="store", help="sdk version parameter") def filter_items(items: list, param_name: str, predicate: Callable): @@ -39,9 +37,7 @@ def filter_items(items: list, param_name: str, predicate: Callable): items[:] = [ item for item in items - if hasattr(item, "callspec") - and param_name in item.callspec.params - and predicate(item.callspec.params[param_name]) + if hasattr(item, "callspec") and param_name in item.callspec.params and predicate(item.callspec.params[param_name]) ] @@ -77,18 +73,12 @@ def filter_sdk_version(items: list, sdk_version: str, sdk_param: str): from aixplain.v2.resource import BaseResource def predicate(param: Any): - return ( - issubclass(param, BaseResource) - if sdk_version == SDK_VERSION_V1 - else not issubclass(param, BaseResource) - ) + return issubclass(param, BaseResource) if sdk_version == SDK_VERSION_V1 else not issubclass(param, BaseResource) filter_items(items, sdk_param, predicate) -def pytest_collection_modifyitems( - session: pytest.Session, config: pytest.Config, items: list -): +def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config, items: list): """Modify the items based on the pipeline version and the SDK version. 
Args: @@ -108,7 +98,5 @@ def pytest_collection_modifyitems( if sdk_version: sdk_param = config.getoption(f"{SDK_VERSION_PARAM_ARG}") if not sdk_param: - raise ValueError( - f"{SDK_VERSION_PARAM_ARG} parameter is required when using {SDK_VERSION_ARG}" - ) + raise ValueError(f"{SDK_VERSION_PARAM_ARG} parameter is required when using {SDK_VERSION_ARG}") filter_sdk_version(items, sdk_version, sdk_param) diff --git a/tests/functional/data_asset/corpus_onboarding_test.py b/tests/functional/data_asset/corpus_onboarding_test.py index b8633209..ca681746 100644 --- a/tests/functional/data_asset/corpus_onboarding_test.py +++ b/tests/functional/data_asset/corpus_onboarding_test.py @@ -69,4 +69,4 @@ def test_corpus_listing(CorpusFactory): @pytest.mark.parametrize("CorpusFactory", [CorpusFactory, v2.Corpus]) def test_corpus_get_error(CorpusFactory): with pytest.raises(Exception): - response = CorpusFactory.get("131312") \ No newline at end of file + response = CorpusFactory.get("131312") diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 41526945..71838034 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -9,12 +9,8 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from aixplain.factories.index_factory.utils import AirParams, VectaraParams, GraphRAGParams, ZeroEntropyParams -from aixplain.factories import IndexFactory -from aixplain.modules.model.record import Record import time - - - +import os def pytest_generate_tests(metafunc): @@ -86,14 +82,15 @@ def test_run_async(): def run_index_model(index_model, retries): from aixplain.modules.model.record import Record - for _ in range(retries): try: - index_model.upsert([Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})]) + index_model.upsert( + [Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})] + ) 
break except Exception as e: time.sleep(180) - + response = index_model.search("Berlin") assert str(response.status) == "SUCCESS" assert "germany" in response.data.lower() @@ -112,7 +109,6 @@ def run_index_model(index_model, retries): pytest.param("6658d40729985c2cf72f42ec", AirParams, id="AIR - Snowflake Arctic Embed M Long"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), pytest.param("67efd4f92a0a850afa045af7", AirParams, id="AIR - BGE M3"), - pytest.param("681254b668e47e7844c1f15a", AirParams, id="AIR - aiXplain Legal Embeddings"), ], ) def test_index_model(embedding_model, supplier_params): @@ -131,6 +127,7 @@ def test_index_model(embedding_model, supplier_params): retries = 1 run_index_model(index_model, retries) + @pytest.mark.parametrize( "embedding_model,supplier_params", [ @@ -140,7 +137,6 @@ def test_index_model(embedding_model, supplier_params): pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, AirParams, id="Jina Clip v2 Multimodal"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), pytest.param("67efd4f92a0a850afa045af7", AirParams, id="BGE M3"), - pytest.param("681254b668e47e7844c1f15a", AirParams, id="aiXplain Legal Embeddings"), ], ) def test_index_model_with_filter(embedding_model, supplier_params): @@ -163,17 +159,21 @@ def test_index_model_with_filter(embedding_model, supplier_params): retries = 1 for _ in range(retries): try: - index_model.upsert([Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})]) + index_model.upsert( + [Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})] + ) break except Exception: time.sleep(180) for _ in range(retries): try: - index_model.upsert([Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})]) + index_model.upsert( + [Record(value="The world is great", value_type="text", 
uri="", id="2", attributes={"category": "world"})] + ) break except Exception: time.sleep(180) - + assert index_model.count() == 2 response = index_model.search( "", filters=[IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS)] @@ -269,7 +269,6 @@ def test_index_model_air_with_image(): index_model.upsert(records) - response = index_model.search("beach") assert str(response.status) == "SUCCESS" second_record = response.details[1]["metadata"]["uri"] @@ -291,7 +290,37 @@ def test_index_model_air_with_image(): index_model.delete() - import os - if os.path.exists("hurricane.jpeg"): - os.remove("hurricane.jpeg") +@pytest.mark.parametrize( + "embedding_model,supplier_params", + [ + pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="OpenAI Ada 002"), + pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, AirParams, id="Jina Clip v2 Multimodal"), + pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), + pytest.param(EmbeddingModel.BGE_M3, AirParams, id="BGE M3"), + ], +) +def test_index_model_air_with_splitter(embedding_model, supplier_params): + from aixplain.factories import IndexFactory + from aixplain.modules.model.record import Record + from uuid import uuid4 + from aixplain.modules.model.index_model import Splitter + from aixplain.enums.splitting_options import SplittingOptions + + for index in IndexFactory.list()["results"]: + index.delete() + + params = supplier_params( + name=f"Splitter Index {uuid4()}", description="Index for splitter", embedding_model=embedding_model + ) + index_model = IndexFactory.create(params=params) + index_model.upsert( + [Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})], + splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0), + ) + response = index_model.count() + assert response == 6 + response = index_model.search("berlin") + assert str(response.status) == "SUCCESS" + 
assert "berlin" in response.data.lower() + index_model.delete() diff --git a/tests/functional/pipelines/create_test.py b/tests/functional/pipelines/create_test.py index 8bcb3b7d..f1dac2c4 100644 --- a/tests/functional/pipelines/create_test.py +++ b/tests/functional/pipelines/create_test.py @@ -25,6 +25,7 @@ from uuid import uuid4 from aixplain import aixplain_v2 as v2 + @pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) def test_create_pipeline_from_json(PipelineFactory): pipeline_json = "tests/functional/pipelines/data/pipeline.json" diff --git a/tests/functional/pipelines/designer_test.py b/tests/functional/pipelines/designer_test.py index a2710b7e..0aa24355 100644 --- a/tests/functional/pipelines/designer_test.py +++ b/tests/functional/pipelines/designer_test.py @@ -119,9 +119,7 @@ def test_routing_pipeline(pipeline): pipeline.save() - output = pipeline.run( - "This is a sample text!", **{"batchmode": False, "version": "3.0"} - ) + output = pipeline.run("This is a sample text!", **{"batchmode": False, "version": "3.0"}) assert output["status"] == ResponseStatus.SUCCESS @@ -131,9 +129,7 @@ def test_scripting_pipeline(pipeline): input = pipeline.input() - segmentor = pipeline.speaker_diarization_audio( - asset_id=SPEAKER_DIARIZATION_AUDIO_ASSET - ) + segmentor = pipeline.speaker_diarization_audio(asset_id=SPEAKER_DIARIZATION_AUDIO_ASSET) script = pipeline.script(script_path="tests/functional/pipelines/data/script.py") script.inputs.create_param(code="transcripts", data_type=DataType.TEXT) @@ -201,9 +197,7 @@ def test_reconstructing_pipeline(pipeline): segmentor = pipeline.speaker_diarization_audio(asset_id="62fab6ecb39cca09ca5bc365") - speech_recognition = pipeline.speech_recognition( - asset_id="60ddefab8d38c51c5885ee38" - ) + speech_recognition = pipeline.speech_recognition(asset_id="60ddefab8d38c51c5885ee38") reconstructor = pipeline.text_reconstruction(asset_id="636cf7ab0f8ddf0db97929e4") @@ -234,25 +228,17 @@ def 
test_metric_pipeline(pipeline): reference_input_node = pipeline.input(label="ReferenceInput") # Instantiate the metric node - translation_metric_node = pipeline.text_generation_metric( - asset_id="639874ab506c987b1ae1acc6" - ) + translation_metric_node = pipeline.text_generation_metric(asset_id="639874ab506c987b1ae1acc6") # Instantiate output node score_output_node = pipeline.output() # Link the nodes - text_input_node.link( - translation_metric_node, from_param="input", to_param="hypotheses" - ) + text_input_node.link(translation_metric_node, from_param="input", to_param="hypotheses") - reference_input_node.link( - translation_metric_node, from_param="input", to_param="references" - ) + reference_input_node.link(translation_metric_node, from_param="input", to_param="references") - translation_metric_node.link( - score_output_node, from_param="data", to_param="output" - ) + translation_metric_node.link(score_output_node, from_param="data", to_param="output") translation_metric_node.inputs.score_identifier = "bleu" diff --git a/tests/functional/pipelines/run_test.py b/tests/functional/pipelines/run_test.py index 86b0d213..a3fcbfb1 100644 --- a/tests/functional/pipelines/run_test.py +++ b/tests/functional/pipelines/run_test.py @@ -54,9 +54,7 @@ def test_get_pipeline(PipelineFactory): def test_run_single_str(batchmode: bool, version: str): pipeline = PipelineFactory.list(query="SingleNodePipeline")["results"][0] - response = pipeline.run( - data="Translate this thing", batch_mode=batchmode, **{"version": version} - ) + response = pipeline.run(data="Translate this thing", batch_mode=batchmode, **{"version": version}) assert response["status"] == ResponseStatus.SUCCESS @@ -180,9 +178,7 @@ def test_run_multipipe_with_datasets(batchmode: bool, version: str, PipelineFact @pytest.mark.parametrize("version", ["2.0", "3.0"]) @pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) def test_run_segment_reconstruct(version: str, PipelineFactory): - pipeline = 
PipelineFactory.list( - query="Segmentation/Reconstruction Functional Test - DO NOT DELETE" - )["results"][0] + pipeline = PipelineFactory.list(query="Segmentation/Reconstruction Functional Test - DO NOT DELETE")["results"][0] response = pipeline.run( "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", **{"version": version}, @@ -212,9 +208,7 @@ def test_run_translation_metric(version: str, PipelineFactory): @pytest.mark.parametrize("version", ["2.0", "3.0"]) @pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) def test_run_metric(version: str, PipelineFactory): - pipeline = PipelineFactory.list(query="ASR Metric Functional Test - DO NOT DELETE")[ - "results" - ][0] + pipeline = PipelineFactory.list(query="ASR Metric Functional Test - DO NOT DELETE")["results"][0] response = pipeline.run( { "AudioInput": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", @@ -280,9 +274,7 @@ def test_run_decision(input_data: str, output_data: str, version: str, PipelineF @pytest.mark.parametrize("version", ["3.0"]) @pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) def test_run_script(version: str, PipelineFactory): - pipeline = PipelineFactory.list(query="Script Functional Test - DO NOT DELETE")[ - "results" - ][0] + pipeline = PipelineFactory.list(query="Script Functional Test - DO NOT DELETE")["results"][0] response = pipeline.run( "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", **{"version": version}, @@ -296,9 +288,7 @@ def test_run_script(version: str, PipelineFactory): @pytest.mark.parametrize("version", ["2.0", "3.0"]) @pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) def test_run_text_reconstruction(version: str, PipelineFactory): - pipeline = PipelineFactory.list(query="Text Reconstruction - DO NOT DELETE")[ - "results" - ][0] + pipeline = PipelineFactory.list(query="Text Reconstruction - DO NOT DELETE")["results"][0] 
response = pipeline.run("Segment A\nSegment B\nSegment C", **{"version": version}) assert response["status"] == ResponseStatus.SUCCESS @@ -316,9 +306,7 @@ def test_run_text_reconstruction(version: str, PipelineFactory): @pytest.mark.parametrize("version", ["3.0"]) @pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) def test_run_diarization(version: str, PipelineFactory): - pipeline = PipelineFactory.list( - query="Diarization ASR Functional Test - DO NOT DELETE" - )["results"][0] + pipeline = PipelineFactory.list(query="Diarization ASR Functional Test - DO NOT DELETE")["results"][0] response = pipeline.run( "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", **{"version": version}, @@ -329,12 +317,11 @@ def test_run_diarization(version: str, PipelineFactory): assert len(d["segments"]) > 0 assert d["segments"][0]["success"] is True + @pytest.mark.parametrize("version", ["3.0"]) @pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) def test_run_failure(version: str, PipelineFactory): - pipeline = PipelineFactory.list(query="Script Functional Test - DO NOT DELETE")[ - "results" - ][0] + pipeline = PipelineFactory.list(query="Script Functional Test - DO NOT DELETE")["results"][0] response = pipeline.run( "INCORRECT DATA", **{"version": version}, diff --git a/tests/unit/agent/sql_tool_test.py b/tests/unit/agent/sql_tool_test.py index b9128bb3..d68cd634 100644 --- a/tests/unit/agent/sql_tool_test.py +++ b/tests/unit/agent/sql_tool_test.py @@ -357,4 +357,4 @@ def test_create_sql_tool_source_type_handling(tmp_path): with pytest.raises(SQLToolError, match="Source type must be either a string or DatabaseSourceType enum, got "): AgentFactory.create_sql_tool( name="Test SQL", description="Test", source=db_path, source_type=123, schema="test" - ) # Invalid type + ) # Invalid type \ No newline at end of file diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 4b265dd3..8d5c3a74 
100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -242,3 +242,13 @@ def test_index_factory_create_failure(): with pytest.raises(Exception) as e: IndexFactory.create(name="test", description="test", embedding_model=None) assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + + +def test_index_model_splitter(): + from aixplain.modules.model.index_model import Splitter + + splitter = Splitter(split=True, split_by="sentence", split_length=100, split_overlap=0) + assert splitter.split == True + assert splitter.split_by == "sentence" + assert splitter.split_length == 100 + assert splitter.split_overlap == 0 diff --git a/tests/unit/v2/test_core.py b/tests/unit/v2/test_core.py index 19c20bb3..5cbb9257 100644 --- a/tests/unit/v2/test_core.py +++ b/tests/unit/v2/test_core.py @@ -14,18 +14,12 @@ def test_aixplain_instance(): aixplain = Aixplain(api_key="test") assert aixplain is not None assert aixplain.api_key == "test" - assert ( - aixplain.base_url == os.getenv("BACKEND_URL") - or "https://platform-api.aixplain.com" - ) + assert aixplain.base_url == os.getenv("BACKEND_URL") or "https://platform-api.aixplain.com" assert ( aixplain.pipeline_url == os.getenv("PIPELINES_RUN_URL") or "https://platform-api.aixplain.com/assets/pipeline/execution/run" ) - assert ( - aixplain.model_url == os.getenv("MODELS_RUN_URL") - or "https://models.aixplain.com/api/v1/execute" - ) + assert aixplain.model_url == os.getenv("MODELS_RUN_URL") or "https://models.aixplain.com/api/v1/execute" aixplain.init_env.assert_called_once() aixplain.init_client.assert_called_once() aixplain.init_resources.assert_called_once() @@ -42,10 +36,7 @@ def test_aixplain_init_env(): aixplain.init_env() assert mock_environ["TEAM_API_KEY"] == "test" assert mock_environ["BACKEND_URL"] == "https://platform-api.aixplain.com" - assert ( - mock_environ["PIPELINE_URL"] - == 
"https://platform-api.aixplain.com/assets/pipeline/execution/run" - ) + assert mock_environ["PIPELINE_URL"] == "https://platform-api.aixplain.com/assets/pipeline/execution/run" assert mock_environ["MODEL_URL"] == "https://models.aixplain.com/api/v1/execute" From 59e40d67be63c6bb340e2b45252af95fd3cdd43a Mon Sep 17 00:00:00 2001 From: Zaina Abu Shaban Date: Mon, 2 Jun 2025 18:32:50 +0300 Subject: [PATCH 40/62] pipeline to_dict change + fixed circular imports pipeline functional tests (#544) --- aixplain/factories/agent_factory/__init__.py | 2 +- aixplain/modules/pipeline/default.py | 17 ++--------------- aixplain/modules/team_agent/__init__.py | 3 ++- 3 files changed, 5 insertions(+), 17 deletions(-) diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 0190cb55..4325efe4 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -45,7 +45,6 @@ from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from aixplain.enums import DatabaseSourceType -from aixplain.utils.llm_utils import get_llm_instance class AgentFactory: @@ -85,6 +84,7 @@ def create( Returns: Agent: created Agent """ + from aixplain.utils.llm_utils import get_llm_instance if llm is None and llm_id is not None: llm = get_llm_instance(llm_id, api_key=api_key) elif llm is None: diff --git a/aixplain/modules/pipeline/default.py b/aixplain/modules/pipeline/default.py index 226deefd..0815c127 100644 --- a/aixplain/modules/pipeline/default.py +++ b/aixplain/modules/pipeline/default.py @@ -14,18 +14,5 @@ def save(self, *args, **kwargs): super().save(*args, **kwargs) def to_dict(self) -> dict: - data = self.__dict__.copy() - - for key, value in data.items(): - if isinstance(value, Enum): - data[key] = value.value - - elif isinstance(value, list): - data[key] = [ - v.to_dict() if hasattr(v, "to_dict") else str(v) for v in value - ] - - elif hasattr(value, "to_dict"): 
- data[key] = value.to_dict() - - return data + return self.serialize() + diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 510e82d1..75f0b0dd 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -44,7 +44,6 @@ from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.model.llm_model import LLM -from aixplain.utils.llm_utils import get_llm_instance from aixplain.modules.mixins import DeployableMixin from pydantic import BaseModel @@ -361,6 +360,8 @@ def to_dict(self) -> Dict: } def _validate(self) -> None: + from aixplain.utils.llm_utils import get_llm_instance + """Validate the Team.""" # validate name From c8aa11c905bb1c04c963474e4fdfc0bac6c9e96d Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Wed, 4 Jun 2025 15:30:56 -0300 Subject: [PATCH 41/62] ENG-1962: composio (#547) * Add function type * Introducing the connector class * Introducr Connection * get action and get action inputs in connection * Running a connection * Functional and unit tests * Introducing ToolFactory * Refactoring of ToolFactory * Model getter mixin * Model List Mixin and Connect repr changes * Connection as tools * Add code on dict repr of connection * Functional tests for action tools in agents/teams * Integrator factory * Changes asked by reviewers * ToolFactory test * CONNECTOR to INTEGRATION and ScriptModel to UtilityModel --- .github/workflows/main.yaml | 1 + aixplain/enums/__init__.py | 1 + aixplain/enums/function_type.py | 35 +++ aixplain/factories/__init__.py | 2 + aixplain/factories/agent_factory/__init__.py | 25 +- aixplain/factories/agent_factory/utils.py | 34 ++- aixplain/factories/index_factory/__init__.py | 7 +- aixplain/factories/integration_factory.py | 42 +++ aixplain/factories/model_factory/__init__.py | 189 +------------- 
.../model_factory/mixins/__init__.py | 4 + .../model_factory/mixins/model_getter.py | 77 ++++++ .../model_factory/mixins/model_list.py | 75 ++++++ aixplain/factories/model_factory/utils.py | 14 +- aixplain/factories/tool_factory.py | 155 +++++++++++ aixplain/modules/agent/__init__.py | 4 +- aixplain/modules/model/__init__.py | 6 +- aixplain/modules/model/connection.py | 129 ++++++++++ aixplain/modules/model/index_model.py | 4 +- aixplain/modules/model/integration.py | 161 ++++++++++++ aixplain/modules/model/llm_model.py | 6 +- aixplain/modules/model/utility_model.py | 18 +- .../functional/agent/agent_functional_test.py | 41 ++- .../model/run_connect_model_test.py | 30 +++ .../team_agent/team_agent_functional_test.py | 50 ++++ tests/unit/model_test.py | 241 +++++++++++++++++- 25 files changed, 1137 insertions(+), 214 deletions(-) create mode 100644 aixplain/enums/function_type.py create mode 100644 aixplain/factories/integration_factory.py create mode 100644 aixplain/factories/model_factory/mixins/__init__.py create mode 100644 aixplain/factories/model_factory/mixins/model_getter.py create mode 100644 aixplain/factories/model_factory/mixins/model_list.py create mode 100644 aixplain/factories/tool_factory.py create mode 100644 aixplain/modules/model/connection.py create mode 100644 aixplain/modules/model/integration.py create mode 100644 tests/functional/model/run_connect_model_test.py diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 193e5c7b..96ddff04 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -110,6 +110,7 @@ jobs: echo "MODELS_RUN_URL=https://test-models.aixplain.com/api/v1/execute" >> $GITHUB_ENV echo "PIPELINES_RUN_URL=https://test-platform-api.aixplain.com/assets/pipeline/execution/run" >> $GITHUB_ENV fi + echo "SLACK_TOKEN=${{ secrets.SLACK_TOKEN }}" >> $GITHUB_ENV - name: Run Tests timeout-minutes: ${{ matrix.timeout }} diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 
e80c03c6..725fdb90 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -19,3 +19,4 @@ from .embedding_model import EmbeddingModel from .asset_status import AssetStatus from .index_stores import IndexStores +from .function_type import FunctionType diff --git a/aixplain/enums/function_type.py b/aixplain/enums/function_type.py new file mode 100644 index 00000000..514ff992 --- /dev/null +++ b/aixplain/enums/function_type.py @@ -0,0 +1,35 @@ +__author__ = "aiXplain" + +""" +Copyright 2023 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Author: aiXplain team +Date: May 22th 2025 +Description: + Function Type Enum +""" + +from enum import Enum + + +class FunctionType(Enum): + AI = "ai" + SEGMENTOR = "segmentor" + RECONSTRUCTOR = "reconstructor" + UTILITY = "utility" + METRIC = "metric" + SEARCH = "search" + INTEGRATION = "connector" + CONNECTION = "connection" diff --git a/aixplain/factories/__init__.py b/aixplain/factories/__init__.py index f663a4eb..905a1c68 100644 --- a/aixplain/factories/__init__.py +++ b/aixplain/factories/__init__.py @@ -34,3 +34,5 @@ from .wallet_factory import WalletFactory from .api_key_factory import APIKeyFactory from .index_factory import IndexFactory +from .tool_factory import ToolFactory +from .integration_factory import IntegrationFactory diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 4325efe4..962061b0 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -85,6 +85,7 @@ def create( Agent: created Agent """ from aixplain.utils.llm_utils import get_llm_instance + if llm is None and llm_id is not None: llm = get_llm_instance(llm_id, api_key=api_key) elif llm is None: @@ -104,8 +105,7 @@ def create( "Note: In upcoming releases, `llm` will become a required parameter.", UserWarning, ) - - from aixplain.factories.agent_factory.utils import build_agent + from aixplain.factories.agent_factory.utils import build_agent, build_tool_payload agent = None url = urljoin(config.BACKEND_URL, "sdk/agents") @@ -118,24 +118,7 @@ def create( payload = { "name": name, - "assets": [ - tool.to_dict() - if isinstance(tool, Tool) - else { - "id": tool.id, - "name": tool.name, - "description": tool.description, - "supplier": tool.supplier.value["code"] if isinstance(tool.supplier, Supplier) else tool.supplier, - "parameters": tool.get_parameters().to_list() - if hasattr(tool, "get_parameters") and tool.get_parameters() is not None - else None, - "function": 
tool.function if hasattr(tool, "function") and tool.function is not None else None, - "type": "model", - "version": tool.version if hasattr(tool, "version") else None, - "assetId": tool.id, - } - for tool in tools - ], + "assets": [build_tool_payload(tool) for tool in tools], "description": description, "role": instructions or description, "supplier": supplier, @@ -423,4 +406,4 @@ def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> Agent: if "message" in resp: msg = resp["message"] error_msg = f"Agent Get Error (HTTP {r.status_code}): {msg}" - raise Exception(error_msg) \ No newline at end of file + raise Exception(error_msg) diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 43540344..8264c07e 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -14,12 +14,44 @@ from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool from aixplain.modules.agent.tool.sql_tool import SQLTool -from typing import Dict, Text, List +from aixplain.modules.model import Model +from aixplain.modules.model.connection import ConnectionTool +from typing import Dict, Text, List, Union from urllib.parse import urljoin GPT_4o_ID = "6646261c6eb563165658bbb1" +def build_tool_payload(tool: Union[Tool, Model]): + """Build a tool payload from a tool or model object. + + Args: + tool (Union[Tool, Model]): The tool or model object to build the payload from. + + Returns: + Dict: The tool payload. 
+ """ + if isinstance(tool, Tool): + return tool.to_dict() + else: + parameters = None + if isinstance(tool, ConnectionTool): + parameters = tool.get_parameters() + elif hasattr(tool, "get_parameters") and tool.get_parameters() is not None: + parameters = tool.get_parameters().to_list() + return { + "id": tool.id, + "name": tool.name, + "description": tool.description, + "supplier": tool.supplier.value["code"] if isinstance(tool.supplier, Supplier) else tool.supplier, + "parameters": parameters, + "function": tool.function if hasattr(tool, "function") and tool.function is not None else None, + "type": "model", + "version": tool.version if hasattr(tool, "version") else None, + "assetId": tool.id, + } + + def build_tool(tool: Dict): """Build a tool from a dictionary. diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index a5691988..189f9417 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -29,13 +29,10 @@ T = TypeVar("T", bound=BaseIndexParams) -import os -from aixplain.utils.file_utils import _request_with_retry -from urllib.parse import urljoin def validate_embedding_model(model_id) -> bool: - model = ModelFactory.get(model_id) - return model.function == Function.TEXT_EMBEDDING + model = ModelFactory.get(model_id) + return model.function == Function.TEXT_EMBEDDING class IndexFactory(ModelFactory, Generic[T]): diff --git a/aixplain/factories/integration_factory.py b/aixplain/factories/integration_factory.py new file mode 100644 index 00000000..2a4a397a --- /dev/null +++ b/aixplain/factories/integration_factory.py @@ -0,0 +1,42 @@ +__author__ = "thiagocastroferreira" + +import aixplain.utils.config as config +from aixplain.enums import Function, Supplier, SortBy, SortOrder, OwnershipType +from aixplain.factories.model_factory.mixins.model_getter import ModelGetterMixin +from aixplain.factories.model_factory.mixins.model_list import ModelListMixin 
+from aixplain.modules.model.integration import Integration +from typing import Optional, Text, Union, Tuple, List + + +class IntegrationFactory(ModelGetterMixin, ModelListMixin): + backend_url = config.BACKEND_URL + + @classmethod + def get(cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = False) -> Integration: + model = super().get(model_id=model_id, api_key=api_key) + assert isinstance(model, Integration), f"The provided ID ('{model_id}') is not from an integration model" + return model + + @classmethod + def list( + cls, + query: Optional[Text] = "", + suppliers: Optional[Union[Supplier, List[Supplier]]] = None, + ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, + sort_by: Optional[SortBy] = None, + sort_order: SortOrder = SortOrder.ASCENDING, + page_number: int = 0, + page_size: int = 20, + api_key: Optional[Text] = None, + ) -> List[Integration]: + return super().list( + function=Function.CONNECTOR, + query=query, + suppliers=suppliers, + ownership=ownership, + sort_by=sort_by, + sort_order=sort_order, + page_number=page_number, + page_size=page_size, + api_key=api_key, + ) diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index 64ab8d99..2d1370d7 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -20,27 +20,18 @@ Description: Model Factory Class """ -from typing import Callable, Dict, List, Optional, Text, Tuple, Union import json import logging -from aixplain.modules.model import Model from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput -from aixplain.enums import ( - Function, - Language, - OwnershipType, - Supplier, - SortBy, - SortOrder, -) +from aixplain.enums import Function from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin -from aixplain.utils.asset_cache import AssetCache, CACHE_FOLDER 
-from aixplain.factories.model_factory.utils import create_model_from_response +from aixplain.factories.model_factory.mixins import ModelGetterMixin, ModelListMixin +from typing import Callable, Dict, List, Optional, Text, Union -class ModelFactory: +class ModelFactory(ModelGetterMixin, ModelListMixin): """A static class for creating and exploring Model Objects. Attributes: @@ -87,9 +78,7 @@ def create_utility_model( url = urljoin(cls.backend_url, "sdk/utilities") headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"} try: - logging.info( - f"Start service for POST Utility Model - {url} - {headers} - {payload}" - ) + logging.info(f"Start service for POST Utility Model - {url} - {headers} - {payload}") r = _request_with_retry("post", url, headers=headers, json=payload) resp = r.json() except Exception as e: @@ -98,159 +87,15 @@ def create_utility_model( if 200 <= r.status_code < 300: utility_model.id = resp["id"] - logging.info( - f"Utility Model Creation: Model {utility_model.id} instantiated." - ) + logging.info(f"Utility Model Creation: Model {utility_model.id} instantiated.") return utility_model else: - error_message = f"Utility Model Creation: Failed to create utility model. Status Code: {r.status_code}. 
Error: {resp}" - logging.error(error_message) - raise Exception(error_message) - - @classmethod - def get( - cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = False - ) -> Model: - """Create a 'Model' object from model id""" - cache = AssetCache(Model) - - if use_cache: - if cache.has_valid_cache(): - cached_model = cache.store.data.get(model_id) - if cached_model: - return cached_model - logging.info("Model not found in valid cache, fetching individually...") - model = cls._fetch_model_by_id(model_id, api_key) - cache.add(model) - return model - else: - try: - model_list_resp = cls.list(model_ids=None, api_key=api_key) - models = model_list_resp["results"] - cache.add_list(models) - for model in models: - if model.id == model_id: - return model - except Exception as e: - logging.error(f"Error fetching model list: {e}") - raise e - - logging.info("Fetching model directly without cache...") - model = cls._fetch_model_by_id(model_id, api_key) - cache.add(model) - return model - - @classmethod - def _fetch_model_by_id( - cls, model_id: Text, api_key: Optional[Text] = None - ) -> Model: - resp = None - try: - url = urljoin(cls.backend_url, f"sdk/models/{model_id}") - headers = { - "Authorization": f"Token {api_key or config.TEAM_API_KEY}", - "Content-Type": "application/json", - } - logging.info(f"Start service for GET Model - {url} - {headers}") - r = _request_with_retry("get", url, headers=headers) - resp = r.json() - except Exception: - if resp and "statusCode" in resp: - status_code = resp["statusCode"] - message = f"Model Creation: Status {status_code} - {resp['message']}" - else: - message = "Model Creation: Unspecified Error" - logging.error(message) - raise Exception(message) - - if 200 <= r.status_code < 300: - resp["api_key"] = config.TEAM_API_KEY - if api_key is not None: - resp["api_key"] = api_key - - model = create_model_from_response(resp) - logging.info(f"Model Creation: Model {model_id} instantiated.") - return model - else: - 
error_message = f"Model GET Error: Failed to retrieve model {model_id}. Status Code: {r.status_code}. Error: {resp}" + error_message = ( + f"Utility Model Creation: Failed to create utility model. Status Code: {r.status_code}. Error: {resp}" + ) logging.error(error_message) raise Exception(error_message) - @classmethod - def list( - cls, - function: Optional[Function] = None, - query: Optional[Text] = "", - suppliers: Optional[Union[Supplier, List[Supplier]]] = None, - source_languages: Optional[Union[Language, List[Language]]] = None, - target_languages: Optional[Union[Language, List[Language]]] = None, - is_finetunable: Optional[bool] = None, - ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, - sort_by: Optional[SortBy] = None, - sort_order: SortOrder = SortOrder.ASCENDING, - page_number: int = 0, - page_size: int = 20, - model_ids: Optional[List[Text]] = None, - api_key: Optional[Text] = None, - ) -> List[Model]: - """Gets the first k given models based on the provided task and language filters - - Args: - function (Function): function filter. - source_languages (Optional[Union[Language, List[Language]]], optional): language filter of input data. Defaults to None. - target_languages (Optional[Union[Language, List[Language]]], optional): language filter of output data. Defaults to None. - is_finetunable (Optional[bool], optional): can be finetuned or not. Defaults to None. - ownership (Optional[Tuple[OwnershipType, List[OwnershipType]]], optional): Ownership filters (e.g. SUBSCRIBED, OWNER). Defaults to None. - sort_by (Optional[SortBy], optional): sort the retrived models by a specific attribute, - page_number (int, optional): page number. Defaults to 0. - page_size (int, optional): page size. Defaults to 20. - model_ids (Optional[List[Text]], optional): model ids to filter. Defaults to None. - api_key (Optional[Text], optional): Team API key. Defaults to None. 
- - Returns: - List[Model]: List of models based on given filters - """ - if model_ids is not None: - from aixplain.factories.model_factory.utils import get_model_from_ids - - assert len(model_ids) > 0, "Please provide at least one model id" - assert ( - function is None - and suppliers is None - and source_languages is None - and target_languages is None - and is_finetunable is None - and ownership is None - and sort_by is None - ), "Cannot filter by function, suppliers, source languages, target languages, is finetunable, ownership, sort by when using model ids" - assert ( - len(model_ids) <= page_size - ), "Page size must be greater than the number of model ids" - models, total = get_model_from_ids(model_ids, api_key), len(model_ids) - else: - from aixplain.factories.model_factory.utils import get_assets_from_page - - models, total = get_assets_from_page( - query, - page_number, - page_size, - function, - suppliers, - source_languages, - target_languages, - is_finetunable, - ownership, - sort_by, - sort_order, - api_key, - ) - return { - "results": models, - "page_total": min(page_size, len(models)), - "page_number": page_number, - "total": total, - } - @classmethod def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]: """Lists available hosting machines for model. @@ -303,9 +148,7 @@ def list_gpus(cls, api_key: Optional[Text] = None) -> List[List[Text]]: return response_list @classmethod - def list_functions( - cls, verbose: Optional[bool] = False, api_key: Optional[Text] = None - ) -> List[Dict]: + def list_functions(cls, verbose: Optional[bool] = False, api_key: Optional[Text] = None) -> List[Dict]: """Lists supported model functions on platform. 
Args: @@ -403,9 +246,7 @@ def create_asset_repo( "onboardingParams": {}, } logging.debug(f"Body: {str(payload)}") - response = _request_with_retry( - "post", create_url, headers=headers, json=payload - ) + response = _request_with_retry("post", create_url, headers=headers, json=payload) assert response.status_code == 201 @@ -469,9 +310,7 @@ def onboard_model( } payload = {"image": image_tag, "sha": image_hash, "hostMachine": host_machine} logging.debug(f"Body: {str(payload)}") - response = _request_with_retry( - "post", onboard_url, headers=headers, json=payload - ) + response = _request_with_retry("post", onboard_url, headers=headers, json=payload) if response.status_code == 201: message = "Your onboarding request has been submitted to an aiXplain specialist for finalization. We will notify you when the process is completed." logging.info(message) @@ -533,9 +372,7 @@ def deploy_huggingface_model( return response_dicts @classmethod - def get_huggingface_model_status( - cls, model_id: Text, api_key: Optional[Text] = None - ): + def get_huggingface_model_status(cls, model_id: Text, api_key: Optional[Text] = None): """Gets the on-boarding status of a Hugging Face model with ID MODEL_ID. 
Args: diff --git a/aixplain/factories/model_factory/mixins/__init__.py b/aixplain/factories/model_factory/mixins/__init__.py new file mode 100644 index 00000000..0d51850b --- /dev/null +++ b/aixplain/factories/model_factory/mixins/__init__.py @@ -0,0 +1,4 @@ +from aixplain.factories.model_factory.mixins.model_getter import ModelGetterMixin +from aixplain.factories.model_factory.mixins.model_list import ModelListMixin + +__all__ = ["ModelGetterMixin", "ModelListMixin"] diff --git a/aixplain/factories/model_factory/mixins/model_getter.py b/aixplain/factories/model_factory/mixins/model_getter.py new file mode 100644 index 00000000..79632146 --- /dev/null +++ b/aixplain/factories/model_factory/mixins/model_getter.py @@ -0,0 +1,77 @@ +import logging + +from aixplain.factories.model_factory.utils import create_model_from_response +from aixplain.modules.model import Model +from aixplain.utils import config +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.asset_cache import AssetCache +from urllib.parse import urljoin +from typing import Optional, Text + + +class ModelGetterMixin: + @classmethod + def get(cls, model_id: Text, api_key: Optional[Text] = None, use_cache: bool = False) -> Model: + """Create a 'Model' object from model id""" + model_id = model_id.replace("/", "%2F") + cache = AssetCache(Model) + + if use_cache: + if cache.has_valid_cache(): + cached_model = cache.store.data.get(model_id) + if cached_model: + return cached_model + logging.info("Model not found in valid cache, fetching individually...") + model = cls._fetch_model_by_id(model_id, api_key) + cache.add(model) + return model + else: + try: + model_list_resp = cls.list(model_ids=None, api_key=api_key) + models = model_list_resp["results"] + cache.add_list(models) + for model in models: + if model.id == model_id: + return model + except Exception as e: + logging.error(f"Error fetching model list: {e}") + raise e + + logging.info("Fetching model directly without 
cache...") + model = cls._fetch_model_by_id(model_id, api_key) + cache.add(model) + return model + + @classmethod + def _fetch_model_by_id(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: + resp = None + try: + url = urljoin(cls.backend_url, f"sdk/models/{model_id}") + headers = { + "Authorization": f"Token {api_key or config.TEAM_API_KEY}", + "Content-Type": "application/json", + } + logging.info(f"Start service for GET Model - {url} - {headers}") + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + except Exception: + if resp and "statusCode" in resp: + status_code = resp["statusCode"] + message = f"Model Creation: Status {status_code} - {resp['message']}" + else: + message = "Model Creation: Unspecified Error" + logging.error(message) + raise Exception(message) + + if 200 <= r.status_code < 300: + resp["api_key"] = config.TEAM_API_KEY + if api_key is not None: + resp["api_key"] = api_key + + model = create_model_from_response(resp) + logging.info(f"Model Creation: Model {model_id} instantiated.") + return model + else: + error_message = f"Model GET Error: Failed to retrieve model {model_id}. Status Code: {r.status_code}. 
Error: {resp}" + logging.error(error_message) + raise Exception(error_message) diff --git a/aixplain/factories/model_factory/mixins/model_list.py b/aixplain/factories/model_factory/mixins/model_list.py new file mode 100644 index 00000000..3a1ae5f9 --- /dev/null +++ b/aixplain/factories/model_factory/mixins/model_list.py @@ -0,0 +1,75 @@ +from typing import Optional, Union, List, Tuple, Text +from aixplain.factories.model_factory.utils import get_model_from_ids, get_assets_from_page +from aixplain.enums import Function, Language, OwnershipType, SortBy, SortOrder, Supplier +from aixplain.modules.model import Model + + +class ModelListMixin: + @classmethod + def list( + cls, + function: Optional[Function] = None, + query: Optional[Text] = "", + suppliers: Optional[Union[Supplier, List[Supplier]]] = None, + source_languages: Optional[Union[Language, List[Language]]] = None, + target_languages: Optional[Union[Language, List[Language]]] = None, + is_finetunable: Optional[bool] = None, + ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, + sort_by: Optional[SortBy] = None, + sort_order: SortOrder = SortOrder.ASCENDING, + page_number: int = 0, + page_size: int = 20, + model_ids: Optional[List[Text]] = None, + api_key: Optional[Text] = None, + ) -> List[Model]: + """Gets the first k given models based on the provided task and language filters + + Args: + function (Function): function filter. + source_languages (Optional[Union[Language, List[Language]]], optional): language filter of input data. Defaults to None. + target_languages (Optional[Union[Language, List[Language]]], optional): language filter of output data. Defaults to None. + is_finetunable (Optional[bool], optional): can be finetuned or not. Defaults to None. + ownership (Optional[Tuple[OwnershipType, List[OwnershipType]]], optional): Ownership filters (e.g. SUBSCRIBED, OWNER). Defaults to None. 
+ sort_by (Optional[SortBy], optional): sort the retrieved models by a specific attribute, + page_number (int, optional): page number. Defaults to 0. + page_size (int, optional): page size. Defaults to 20. + model_ids (Optional[List[Text]], optional): model ids to filter. Defaults to None. + api_key (Optional[Text], optional): Team API key. Defaults to None. + + Returns: + Dict: page of models matching the given filters, with keys `results` (List[Model]), `page_total`, `page_number` and `total` + """ + if model_ids is not None: + assert len(model_ids) > 0, "Please provide at least one model id" + assert ( + function is None + and suppliers is None + and source_languages is None + and target_languages is None + and is_finetunable is None + and ownership is None + and sort_by is None + ), "Cannot filter by function, suppliers, source languages, target languages, is finetunable, ownership, sort by when using model ids" + assert len(model_ids) <= page_size, "Page size must be greater than the number of model ids" + models, total = get_model_from_ids(model_ids, api_key), len(model_ids) + else: + models, total = get_assets_from_page( + query, + page_number, + page_size, + function, + suppliers, + source_languages, + target_languages, + is_finetunable, + ownership, + sort_by, + sort_order, + api_key, + ) + return { + "results": models, + "page_total": min(page_size, len(models)), + "page_number": page_number, + "total": total, + } diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 2a1c9269..ca3e1eec 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -3,14 +3,16 @@ from aixplain.modules.model import Model from aixplain.modules.model.llm_model import LLM from aixplain.modules.model.index_model import IndexModel -from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput -from aixplain.enums import DataType, Function, Language, OwnershipType, Supplier, SortBy, SortOrder, AssetStatus +from 
aixplain.modules.model.integration import Integration +from aixplain.modules.model.connection import ConnectionTool +from aixplain.modules.model.utility_model import UtilityModel +from aixplain.modules.model.utility_model import UtilityModelInput +from aixplain.enums import DataType, Function, FunctionType, Language, OwnershipType, Supplier, SortBy, SortOrder, AssetStatus from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from datetime import datetime from typing import Dict, Union, List, Optional, Tuple from urllib.parse import urljoin -from aixplain.enums import AssetStatus import requests @@ -48,6 +50,7 @@ def create_model_from_response(response: Dict) -> Model: function_id = response["function"]["id"] function = Function(function_id) + function_type = FunctionType(response.get("functionType", "ai")) function_input_params, function_output_params = function.get_input_output_params() model_params = {param["name"]: param for param in response["params"]} @@ -65,6 +68,10 @@ def create_model_from_response(response: Dict) -> Model: temperature = float(f[0]["defaultValues"][0]["value"]) elif function == Function.SEARCH: ModelClass = IndexModel + elif function_type == FunctionType.INTEGRATION: + ModelClass = Integration + elif function_type == FunctionType.CONNECTION: + ModelClass = ConnectionTool elif function == Function.UTILITIES: ModelClass = UtilityModel inputs = [ @@ -110,6 +117,7 @@ def create_model_from_response(response: Dict) -> Model: temperature=temperature, supports_streaming=response.get("supportsStreaming", False), status=status, + function_type=function_type, **additional_kwargs, ) diff --git a/aixplain/factories/tool_factory.py b/aixplain/factories/tool_factory.py new file mode 100644 index 00000000..9db2a9ea --- /dev/null +++ b/aixplain/factories/tool_factory.py @@ -0,0 +1,155 @@ +import warnings +from aixplain.enums import FunctionType +from aixplain.factories import ModelFactory +from 
aixplain.factories.model_factory.mixins import ModelGetterMixin, ModelListMixin +from aixplain.modules.model import Model +from aixplain.modules.model.index_model import IndexModel +from aixplain.modules.model.integration import Integration +from aixplain.modules.model.integration import BaseAuthenticationParams +from aixplain.factories.index_factory.utils import BaseIndexParams, AirParams, VectaraParams, ZeroEntropyParams, GraphRAGParams +from aixplain.enums.index_stores import IndexStores +from aixplain.modules.model.utility_model import BaseUtilityModelParams +from typing import Optional, Text, Union +from aixplain.enums import ResponseStatus +from aixplain.utils import config + + +class ToolFactory(ModelGetterMixin, ModelListMixin): + backend_url = config.BACKEND_URL + + @classmethod + def create( + cls, + integration: Optional[Union[Text, Model]] = None, + params: Optional[Union[BaseUtilityModelParams, BaseIndexParams, BaseAuthenticationParams]] = None, + **kwargs, + ) -> Model: + """Factory method to create indexes, script models and connections + + Examples: + Create a script model (option 1): + Option 1: + from aixplain.modules.model.utility_model import BaseUtilityModelParams + + def add(a: int, b: int) -> int: + return a + b + + params = BaseUtilityModelParams( + name="My Script Model", + description="My Script Model Description", + code=add + ) + tool = ToolFactory.create(params=params) + + Option 2: + def add(a: int, b: int) -> int: + \"\"\"Add two numbers\"\"\" + return a + b + + tool = ToolFactory.create( + name="My Script Model", + code=add + ) + + Create a search collection: + Option 1: + from aixplain.factories.index_factory.utils import AirParams + + params = AirParams( + name="My Search Collection", + description="My Search Collection Description" + ) + tool = ToolFactory.create(params=params) + + Option 2: + from aixplain.enums.index_stores import IndexStores + + tool = ToolFactory.create( + integration=IndexStores.VECTARA.get_model_id(), + 
+ name="My Search Collection", + description="My Search Collection Description" + ) + + Create a connector: + Option 1: + from aixplain.modules.model.integration import BearerAuthenticationParams + + params = BearerAuthenticationParams( + connector_id="my_connector_id", + token="my_token", + name="My Connection" + ) + tool = ToolFactory.create(params=params) + + Option 2: + tool = ToolFactory.create( + integration="my_connector_id", + name="My Connection", + token="my_token" + ) + + Args: + params: The parameters for the tool + Returns: + The created tool + """ + if params is None: + integration_model = None + if isinstance(integration, Text): + integration_model = cls.get(integration) + elif isinstance(integration, Model): + integration_model = integration + integration = integration_model.id + + assert ( + isinstance(integration_model, Integration) + or isinstance(integration_model, IndexModel) + or kwargs.get("code") is not None + ), "Please provide the proper integration (ConnectorModel, IndexModel or UtilityModel code) or params to create a model tool." + if isinstance(integration_model, Integration): + from aixplain.modules.model.integration import build_connector_params + + kwargs["connector_id"] = integration_model.id + params = build_connector_params(**kwargs) + elif isinstance(integration_model, IndexModel): + if IndexStores.AIR.get_model_id() == integration_model.id: + params = AirParams(**kwargs) + elif IndexStores.VECTARA.get_model_id() == integration_model.id: + params = VectaraParams(**kwargs) + elif IndexStores.ZERO_ENTROPY.get_model_id() == integration_model.id: + params = ZeroEntropyParams(**kwargs) + elif IndexStores.GRAPHRAG.get_model_id() == integration_model.id: + params = GraphRAGParams(**kwargs) + else: + raise ValueError( + f"ToolFactory Error: The index store '{integration_model.id} - {integration_model.name}' is not supported." 
+ ) + else: + params = BaseUtilityModelParams(**kwargs) + + if isinstance(params, BaseUtilityModelParams): + return ModelFactory.create_utility_model( + name=params.name, + description=params.description, + code=params.code, + ) + elif isinstance(params, BaseIndexParams): + from aixplain.factories import IndexFactory + + return IndexFactory.create(params=params) + elif isinstance(params, BaseAuthenticationParams): + assert params.connector_id is not None, "Please provide the ID of the service you want to connect to" + connector = cls.get(params.connector_id) + assert ( + connector.function_type == FunctionType.INTEGRATION + ), f"The model you are trying to connect ({connector.id}) to is not a connector." + response = connector.connect(params) + assert response.status == ResponseStatus.SUCCESS, f"Failed to connect to {connector.id} - {response.error_message}" + connection = cls.get(response.data["id"]) + if "redirectURL" in response.data: + warnings.warn( + f"Before using the tool, please visit the following URL to complete the connection: {response.data['redirectURL']}" + ) + return connection + else: + raise ValueError("ToolFactory Error: Invalid params") diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 3d449956..f4049adf 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -363,10 +363,12 @@ def run_async( ) def to_dict(self) -> Dict: + from aixplain.factories.agent_factory.utils import build_tool_payload + return { "id": self.id, "name": self.name, - "assets": [tool.to_dict() for tool in self.tools], + "assets": [build_tool_payload(tool) for tool in self.tools], "description": self.description, "role": self.instructions, "supplier": (self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier), diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index b00b1993..f20d4e20 100644 --- a/aixplain/modules/model/__init__.py +++ 
b/aixplain/modules/model/__init__.py @@ -23,7 +23,7 @@ import time import logging import traceback -from aixplain.enums import Supplier, Function +from aixplain.enums import Supplier, Function, FunctionType from aixplain.modules.asset import Asset from aixplain.modules.model.model_response_streamer import ModelResponseStreamer from aixplain.modules.model.utils import build_payload, call_run_endpoint @@ -58,6 +58,7 @@ class Model(Asset): output_params (Dict, optional): output parameters for the function. model_params (ModelParameters, optional): parameters for the function. supports_streaming (bool, optional): whether the model supports streaming. Defaults to False. + function_type (FunctionType, optional): type of the function. Defaults to FunctionType.AI. """ def __init__( @@ -77,6 +78,7 @@ def __init__( model_params: Optional[Dict] = None, supports_streaming: bool = False, status: Optional[AssetStatus] = AssetStatus.ONBOARDED, # default status for models is ONBOARDED + function_type: Optional[FunctionType] = FunctionType.AI, **additional_info, ) -> None: """Model Init @@ -96,6 +98,7 @@ def __init__( model_params (Dict, optional): parameters for the function. supports_streaming (bool, optional): whether the model supports streaming. Defaults to False. status (AssetStatus, optional): status of the model. Defaults to None. + function_type (FunctionType, optional): type of the function. Defaults to FunctionType.AI. 
**additional_info: Any additional Model info to be saved """ super().__init__(id, name, description, supplier, version, cost=cost) @@ -110,6 +113,7 @@ def __init__( self.output_params = output_params self.model_params = ModelParameters(model_params) if model_params else None self.supports_streaming = supports_streaming + self.function_type = function_type if isinstance(status, str): try: status = AssetStatus(status) diff --git a/aixplain/modules/model/connection.py b/aixplain/modules/model/connection.py new file mode 100644 index 00000000..6d98f140 --- /dev/null +++ b/aixplain/modules/model/connection.py @@ -0,0 +1,129 @@ +from aixplain.enums import Function, Supplier, FunctionType, ResponseStatus +from aixplain.modules.model import Model +from aixplain.utils import config +from typing import Text, Optional, Union, Dict, List + + +class ConnectAction: + name: Text + description: Text + code: Optional[Text] = None + inputs: Optional[Dict] = None + + def __init__(self, name: Text, description: Text, code: Optional[Text] = None, inputs: Optional[Dict] = None): + self.name = name + self.description = description + self.code = code + self.inputs = inputs + + def __repr__(self): + return f"Action(code={self.code}, name={self.name})" + + +class ConnectionTool(Model): + actions: List[ConnectAction] + action_scope: Optional[List[ConnectAction]] = None + + def __init__( + self, + id: Text, + name: Text, + description: Text = "", + api_key: Optional[Text] = None, + supplier: Union[Dict, Text, Supplier, int] = "aiXplain", + version: Optional[Text] = None, + function: Optional[Function] = None, + is_subscribed: bool = False, + cost: Optional[Dict] = None, + function_type: Optional[FunctionType] = FunctionType.CONNECTION, + **additional_info, + ) -> None: + """Connection Init + + Args: + id (Text): ID of the Model + name (Text): Name of the Model + description (Text, optional): description of the model. Defaults to "". + api_key (Text, optional): API key of the Model. 
Defaults to None. + supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". + version (Text, optional): version of the model. Defaults to None. + function (Function, optional): model AI function. Defaults to None. + is_subscribed (bool, optional): Is the user subscribed. Defaults to False. + cost (Dict, optional): model price. Defaults to None. + function_type (FunctionType, optional): type of the function; must be FunctionType.CONNECTION. Defaults to FunctionType.CONNECTION. + **additional_info: Any additional Model info to be saved + """ + assert function_type == FunctionType.CONNECTION, "Connection only supports connection function" + super().__init__( + id=id, + name=name, + description=description, + supplier=supplier, + version=version, + cost=cost, + function=function, + is_subscribed=is_subscribed, + api_key=api_key, + function_type=function_type, + **additional_info, + ) + self.url = config.MODELS_RUN_URL + self.backend_url = config.BACKEND_URL + self.actions = self._get_actions() + self.action_scope = None + + def _get_actions(self): + response = super().run({"action": "LIST_ACTIONS", "data": " "}) + if response.status == ResponseStatus.SUCCESS: + return [ + ConnectAction(name=action["displayName"], description=action["description"], code=action["name"]) + for action in response.data + ] + raise Exception( + f"It was not possible to get the actions for the connection {self.id}. 
Error {response.error_code}: {response.error_message}" + ) + + def get_action_inputs(self, action: Union[ConnectAction, Text]): + if action.inputs: + return action.inputs + + if isinstance(action, ConnectAction): + action = action.code + + response = super().run({"action": "LIST_INPUTS", "data": {"actions": [action]}}) + if response.status == ResponseStatus.SUCCESS: + try: + inputs = {inp["code"]: inp for inp in response.data[0]["inputs"]} + action_idx = next((i for i, a in enumerate(self.actions) if a.code == action), None) + if action_idx is not None: + self.actions[action_idx].inputs = inputs + return inputs + except Exception as e: + raise Exception(f"It was not possible to get the inputs for the action {action}. Error {e}") + + raise Exception( + f"It was not possible to get the inputs for the action {action}. Error {response.error_code}: {response.error_message}" + ) + + def run(self, action: Union[ConnectAction, Text], inputs: Dict): + if isinstance(action, ConnectAction): + action = action.code + return super().run({"action": action, "data": inputs}) + + def get_parameters(self) -> List[Dict]: + assert ( + self.action_scope is not None and len(self.action_scope) > 0 + ), f"Please set the scope of actions for the connection '{self.id}'." 
+ response = [ + { + "code": action.code, + "name": action.name, + "description": action.description, + "inputs": self.get_action_inputs(action), + } + for action in self.action_scope + ] + return response + + def __repr__(self): + return f"ConnectionTool(id={self.id}, name={self.name})" diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index 523b3f64..04ed3df1 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -1,4 +1,4 @@ -from aixplain.enums import EmbeddingModel, Function, Supplier, ResponseStatus, StorageType +from aixplain.enums import EmbeddingModel, Function, Supplier, ResponseStatus, StorageType, FunctionType from aixplain.modules.model import Model from aixplain.utils import config from aixplain.modules.model.response import ModelResponse @@ -65,6 +65,7 @@ def __init__( is_subscribed: bool = False, cost: Optional[Dict] = None, embedding_model: Union[EmbeddingModel, str] = None, + function_type: Optional[FunctionType] = FunctionType.SEARCH, **additional_info, ) -> None: """Index Init @@ -93,6 +94,7 @@ def __init__( function=function, is_subscribed=is_subscribed, api_key=api_key, + function_type=function_type, **additional_info, ) self.url = config.MODELS_RUN_URL diff --git a/aixplain/modules/model/integration.py b/aixplain/modules/model/integration.py new file mode 100644 index 00000000..65afc322 --- /dev/null +++ b/aixplain/modules/model/integration.py @@ -0,0 +1,161 @@ +import warnings +from aixplain.enums import Function, Supplier, FunctionType +from aixplain.modules.model import Model, ModelResponse +from aixplain.utils import config +from typing import Text, Optional, Union, Dict +from enum import Enum +from pydantic import BaseModel + + +class AuthenticationSchema(Enum): + BEARER = "BEARER_TOKEN" + OAUTH = "OAUTH" + OAUTH2 = "OAUTH2" + + +class BaseAuthenticationParams(BaseModel): + name: Optional[Text] = None + authentication_schema: AuthenticationSchema = 
AuthenticationSchema.OAUTH2 + connector_id: Optional[Text] = None + + +class BearerAuthenticationParams(BaseAuthenticationParams): + token: Text + authentication_schema: AuthenticationSchema = AuthenticationSchema.BEARER + + +class OAuthAuthenticationParams(BaseAuthenticationParams): + client_id: Text + client_secret: Text + authentication_schema: AuthenticationSchema = AuthenticationSchema.OAUTH + + +class OAuth2AuthenticationParams(BaseAuthenticationParams): + authentication_schema: AuthenticationSchema = AuthenticationSchema.OAUTH2 + + +def build_connector_params(**kwargs) -> BaseAuthenticationParams: + name = kwargs.get("name") + token = kwargs.get("token") + client_id = kwargs.get("client_id") + client_secret = kwargs.get("client_secret") + connector_id = kwargs.get("connector_id") + if token: + args = BearerAuthenticationParams(name=name, token=token, connector_id=connector_id) + elif client_id and client_secret: + args = OAuthAuthenticationParams(name=name, client_id=client_id, client_secret=client_secret, connector_id=connector_id) + else: + args = OAuth2AuthenticationParams(name=name, connector_id=connector_id) + return args + + +class Integration(Model): + def __init__( + self, + id: Text, + name: Text, + description: Text = "", + api_key: Optional[Text] = None, + supplier: Union[Dict, Text, Supplier, int] = "aiXplain", + version: Optional[Text] = None, + function: Optional[Function] = None, + is_subscribed: bool = False, + cost: Optional[Dict] = None, + function_type: Optional[FunctionType] = FunctionType.INTEGRATION, + **additional_info, + ) -> None: + """Integration Init + + Args: + id (Text): ID of the Model + name (Text): Name of the Model + description (Text, optional): description of the model. Defaults to "". + api_key (Text, optional): API key of the Model. Defaults to None. + supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". + version (Text, optional): version of the model. Defaults to "1.0". 
+ function (Function, optional): model AI function. Defaults to None. + is_subscribed (bool, optional): Is the user subscribed. Defaults to False. + cost (Dict, optional): model price. Defaults to None. + **additional_info: Any additional Model info to be saved + """ + assert function_type == FunctionType.INTEGRATION, "Integration only supports connector function" + super().__init__( + id=id, + name=name, + description=description, + supplier=supplier, + version=version, + cost=cost, + function=function, + is_subscribed=is_subscribed, + api_key=api_key, + function_type=function_type, + **additional_info, + ) + self.url = config.MODELS_RUN_URL + self.backend_url = config.BACKEND_URL + + def connect(self, args: Optional[BaseAuthenticationParams] = None, **kwargs) -> ModelResponse: + """Connect to the integration + + Examples: + - For Bearer Token Authentication: + >>> integration.connect(BearerAuthenticationSchema(name="My Connection", token="1234567890")) + >>> integration.connect(BearerAuthenticationSchema(token="1234567890")) + >>> integration.connect(token="1234567890") + - For OAuth Authentication: + >>> integration.connect(OAuthAuthenticationSchema(name="My Connection", client_id="1234567890", client_secret="1234567890")) + >>> integration.connect(OAuthAuthenticationSchema(client_id="1234567890", client_secret="1234567890")) + >>> integration.connect(client_id="1234567890", client_secret="1234567890") + - For OAuth2 Authentication: + >>> integration.connect(OAuth2AuthenticationSchema(name="My Connection")) + >>> integration.connect() + Make sure to click on the redirect url to complete the connection. 
+ + Returns: + id: Connection ID (retrieve it with ModelFactory.get(id)) + redirectUrl: Redirect URL to complete the connection (only for OAuth2) + """ + if args is None: + args = build_connector_params(**kwargs) + + authentication_schema = args.authentication_schema + if authentication_schema == AuthenticationSchema.BEARER: + return self.run( + { + "name": args.name, + "authScheme": authentication_schema.value, + "data": { + "token": args.token, + }, + } + ) + elif authentication_schema == AuthenticationSchema.OAUTH: + return self.run( + { + "name": args.name, + "authScheme": authentication_schema.value, + "data": { + "client_id": args.client_id, + "client_secret": args.client_secret, + }, + } + ) + elif authentication_schema == AuthenticationSchema.OAUTH2: + response = self.run( + { + "name": args.name, + "authScheme": authentication_schema.value, + } + ) + if "redirectURL" in response.data: + warnings.warn( + f"Before using the tool, please visit the following URL to complete the connection: {response.data['redirectURL']}" + ) + return response + + def __repr__(self): + try: + return f"Integration: {self.name} by {self.supplier['name']} (id={self.id})" + except Exception: + return f"Integration: {self.name} by {self.supplier} (id={self.id})" diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index 546be1b8..a1eeb95e 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -23,7 +23,7 @@ import time import logging import traceback -from aixplain.enums import Function, Supplier +from aixplain.enums import Function, Supplier, FunctionType from aixplain.modules.model import Model from aixplain.modules.model.model_response_streamer import ModelResponseStreamer from aixplain.modules.model.utils import build_payload, call_run_endpoint @@ -48,6 +48,7 @@ class LLM(Model): url (str): URL to run the model. backend_url (str): URL of the backend. pricing (Dict, optional): model price. Defaults to None. 
+ function_type (FunctionType, optional): type of the function. Defaults to FunctionType.AI. **additional_info: Any additional Model info to be saved """ @@ -63,6 +64,7 @@ def __init__( is_subscribed: bool = False, cost: Optional[Dict] = None, temperature: float = 0.001, + function_type: Optional[FunctionType] = FunctionType.AI, **additional_info, ) -> None: """LLM Init @@ -77,6 +79,7 @@ def __init__( function (Function, optional): model AI function. Defaults to None. is_subscribed (bool, optional): Is the user subscribed. Defaults to False. cost (Dict, optional): model price. Defaults to None. + function_type (FunctionType, optional): type of the function. Defaults to FunctionType.AI. **additional_info: Any additional Model info to be saved """ assert function == Function.TEXT_GENERATION, "LLM only supports large language models (i.e. text generation function)" @@ -90,6 +93,7 @@ def __init__( function=function, is_subscribed=is_subscribed, api_key=api_key, + function_type=function_type, **additional_info, ) self.url = config.MODELS_RUN_URL diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index f6b69ba2..7ad49673 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -20,7 +20,7 @@ """ import logging import warnings -from aixplain.enums import Function, Supplier, DataType +from aixplain.enums import Function, Supplier, DataType, FunctionType from aixplain.enums import AssetStatus from aixplain.modules.model import Model from aixplain.utils import config @@ -30,6 +30,13 @@ from typing import Callable, Union, Optional, List, Text, Dict from urllib.parse import urljoin from aixplain.modules.mixins import DeployableMixin +from pydantic import BaseModel + + +class BaseUtilityModelParams(BaseModel): + name: Text + code: Union[Text, Callable] + description: Optional[Text] = None @dataclass @@ -127,6 +134,7 @@ def __init__( is_subscribed: bool = False, cost: Optional[Dict] = None, 
status: AssetStatus = AssetStatus.DRAFT, + function_type: Optional[FunctionType] = FunctionType.UTILITY, **additional_info, ) -> None: """Utility Model Init @@ -144,6 +152,7 @@ def __init__( function (Function, optional): model AI function. Defaults to None. is_subscribed (bool, optional): Is the user subscribed. Defaults to False. cost (Dict, optional): model price. Defaults to None. + function_type (FunctionType, optional): type of the function. Defaults to FunctionType.UTILITY. **additional_info: Any additional Model info to be saved """ assert function == Function.UTILITIES, "Utility Model only supports 'utilities' function" @@ -158,6 +167,7 @@ def __init__( is_subscribed=is_subscribed, api_key=api_key, status=status, + function_type=function_type, **additional_info, ) self.url = config.MODELS_RUN_URL @@ -276,3 +286,9 @@ def delete(self): message = f"Utility Model Deletion Error: {response}" logging.error(message) raise Exception(f"{message}") + + def __repr__(self): + try: + return f"UtilityModel: {self.name} by {self.supplier['name']} (id={self.id})" + except Exception: + return f"UtilityModel: {self.name} by {self.supplier} (id={self.id})" diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 94bbdd61..5bfeddf7 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -17,6 +17,7 @@ """ import copy import json +import os from dotenv import load_dotenv load_dotenv() @@ -335,7 +336,7 @@ def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agent name="Test Parameter Agent", description="Test agent with parameterized tools. 
You MUST use a tool for the tasks.", tools=[tool], - llm_id="6626a3a8c8f1d089790cf5a2", # Using LLM ID from test data + llm_id="6646261c6eb563165658bbb1", # Using LLM ID from test data ) # Run agent @@ -405,6 +406,7 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): if agent: agent.delete() + @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents @@ -578,6 +580,9 @@ def test_agent_with_pipeline_tool(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents + for pipeline in PipelineFactory.list(query="Hello Pipeline")["results"]: + pipeline.delete() + pipeline = PipelineFactory.init("Hello Pipeline") input_node = pipeline.input() input_node.label = "TextInput" @@ -606,6 +611,7 @@ def test_agent_with_pipeline_tool(delete_agents_and_team_agents, AgentFactory): assert "hello" in answer["data"]["output"].lower() assert "hello pipeline" in answer["data"]["intermediate_steps"][0]["tool_steps"][0]["tool"].lower() + @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_agent_llm_parameter_preservation(delete_agents_and_team_agents, AgentFactory): """Test that LLM parameters like temperature are preserved when creating agents.""" @@ -637,6 +643,7 @@ def test_agent_llm_parameter_preservation(delete_agents_and_team_agents, AgentFa # Reset the LLM temperature to its original value llm.temperature = original_temperature + def test_run_agent_with_expected_output(): from pydantic import BaseModel from typing import Optional, List @@ -715,3 +722,35 @@ class Response(BaseModel): assert "age" in person assert "city" in person assert person["name"] in more_than_30_years_old + + +def test_agent_with_action_tool(): + from aixplain.modules.model.integration import AuthenticationSchema + + connector = ModelFactory.get("67eff5c0e05614297caeef98") + # connect + response = 
connector.connect(authentication_schema=AuthenticationSchema.BEARER, token=os.getenv("SLACK_TOKEN")) + connection_id = response.data["id"] + + connection = ModelFactory.get(connection_id) + connection.action_scope = [action for action in connection.actions if action.code == "SLACK_CHAT_POST_MESSAGE"] + + agent = AgentFactory.create( + name="Test Agent", + description="This agent is used to send messages to Slack", + instructions="You are a helpful assistant that can send messages to Slack.", + llm_id="669a63646eb56306647e1091", + tools=[ + connection, + AgentFactory.create_model_tool(model="6736411cf127849667606689"), + ], + ) + + response = agent.run( + "Send what is the capital of Finland on Slack to channel of #modelserving-alerts: 'C084G435LR5'. Add the name of the capital in the final answer." + ) + assert response is not None + assert response["status"].lower() == "success" + assert "helsinki" in response.data.output.lower() + assert "SLACK_CHAT_POST_MESSAGE" in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] + connection.delete() diff --git a/tests/functional/model/run_connect_model_test.py b/tests/functional/model/run_connect_model_test.py new file mode 100644 index 00000000..3f89b4ad --- /dev/null +++ b/tests/functional/model/run_connect_model_test.py @@ -0,0 +1,30 @@ +import os +from aixplain.enums import ResponseStatus +from aixplain.factories import ModelFactory +from aixplain.modules.model.integration import Integration, AuthenticationSchema +from aixplain.modules.model.connection import ConnectionTool + + +def test_run_connect_model(): + # get slack connector + connector = ModelFactory.get("67eff5c0e05614297caeef98") + + assert isinstance(connector, Integration) + assert connector.id == "67eff5c0e05614297caeef98" + assert connector.name == "Slack" + + response = connector.connect(authentication_schema=AuthenticationSchema.BEARER, token=os.getenv("SLACK_TOKEN")) + assert response.status == ResponseStatus.SUCCESS + assert 
"id" in response.data + connection_id = response.data["id"] + # get slack connection + connection = ModelFactory.get(connection_id) + assert isinstance(connection, ConnectionTool) + assert connection.id == connection_id + assert connection.actions is not None + + action = [action for action in connection.actions if action.code == "SLACK_CHAT_POST_MESSAGE"] + assert len(action) > 0 + action = action[0] + response = connection.run(action, {"text": "This is a test!", "channel": "C084G435LR5"}) + assert response.status == ResponseStatus.SUCCESS diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index fc6b66af..f3852f60 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -16,6 +16,7 @@ limitations under the License. """ import json +import os from dotenv import load_dotenv from uuid import uuid4 @@ -439,3 +440,52 @@ class Response(BaseModel): assert "age" in person assert "city" in person assert person["name"] in more_than_30_years_old + + +def test_team_agent_with_slack_connector(): + from aixplain.modules.model.integration import AuthenticationSchema + + connector = ModelFactory.get("67eff5c0e05614297caeef98") + # connect + response = connector.connect(authentication_schema=AuthenticationSchema.BEARER, token=os.getenv("SLACK_TOKEN")) + connection_id = response.data["id"] + + connection = ModelFactory.get(connection_id) + connection.action_scope = [action for action in connection.actions if action.code == "SLACK_CHAT_POST_MESSAGE"] + + agent = AgentFactory.create( + name="Test Agent", + description="This agent is used to send messages to Slack", + instructions="You are a helpful assistant that can answer questions based on a large knowledge base and send messages to Slack.", + llm_id="669a63646eb56306647e1091", + tasks=[ + AgentFactory.create_task( + name="Task 1", + description="Check knowledge base for 
information about the query and send the response to Slack", + expected_output="A message sent to Slack", + ) + ], + tools=[ + connection, + AgentFactory.create_model_tool(model="6736411cf127849667606689"), + ], + ) + + team_agent = TeamAgentFactory.create( + name="Team Agent", + agents=[agent], + description="Team agent", + llm_id="6646261c6eb563165658bbb1", + use_mentalist=False, + use_inspector=False, + ) + + response = team_agent.run( + "Send what is the capital of Senegal on Slack to channel of #modelserving-alerts: 'C084G435LR5'. Add the name of the capital in the final answer." + ) + assert response["status"].lower() == "success" + assert "dakar" in response.data.output.lower() + + team_agent.delete() + agent.delete() + connection.delete() diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 6c1a372f..9abc6fe0 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -23,7 +23,7 @@ from aixplain.modules import Model from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.factories import ModelFactory -from aixplain.enums import Function +from aixplain.enums import Function, FunctionType from urllib.parse import urljoin from aixplain.modules.model.response import ModelResponse, ResponseStatus from aixplain.modules.model.model_response_streamer import ModelResponseStreamer @@ -31,6 +31,11 @@ from unittest.mock import patch from aixplain.enums.asset_status import AssetStatus from aixplain.modules.model.model_parameters import ModelParameters +from aixplain.modules.model.llm_model import LLM +from aixplain.modules.model.index_model import IndexModel +from aixplain.modules.model.utility_model import UtilityModel +from aixplain.modules.model.integration import Integration, AuthenticationSchema +from aixplain.modules.model.connection import ConnectionTool, ConnectAction def test_build_payload(): @@ -165,7 +170,7 @@ def test_get_model_error_response(): def test_get_assets_from_page_error(): - from 
aixplain.factories.model_factory.utils import get_assets_from_page + from aixplain.factories.model_factory.mixins.model_list import get_assets_from_page with requests_mock.Mocker() as mock: query = "test-query" @@ -654,3 +659,235 @@ def test_model_not_supports_streaming(mocker): with pytest.raises(Exception) as excinfo: model.run(data="test", stream=True) assert f"Model '{model.name} ({model.id})' does not support streaming" in str(excinfo.value) + + +@pytest.mark.parametrize( + "payload, expected_model_class", + [ + ( + { + "id": "connector-id", + "name": "connector-name", + "function": {"id": "utilities"}, + "functionType": "connector", + "supplier": "aiXplain", + "api_key": "api_key", + "pricing": {"price": 10, "currency": "USD"}, + "params": {}, + "version": {"id": "1.0"}, + }, + Integration, + ), + ( + { + "id": "llm-id", + "name": "llm-name", + "function": {"id": "text-generation"}, + "functionType": "ai", + "supplier": "aiXplain", + "api_key": "api_key", + "pricing": {"price": 10, "currency": "USD"}, + "params": {}, + "version": {"id": "1.0"}, + }, + LLM, + ), + ( + { + "id": "index-id", + "name": "index-name", + "function": {"id": "search"}, + "functionType": "ai", + "supplier": "aiXplain", + "api_key": "api_key", + "pricing": {"price": 10, "currency": "USD"}, + "params": {}, + "version": {"id": "1.0"}, + }, + IndexModel, + ), + ( + { + "id": "utility-id", + "name": "utility-name", + "function": {"id": "utilities"}, + "functionType": "utility", + "supplier": "aiXplain", + "api_key": "api_key", + "pricing": {"price": 10, "currency": "USD"}, + "params": {}, + "version": {"id": "1.0"}, + }, + UtilityModel, + ), + ], +) +def test_create_model_from_response(payload, expected_model_class): + from aixplain.factories.model_factory.utils import create_model_from_response + from aixplain.enums import FunctionType + + model = create_model_from_response(payload) + assert isinstance(model, expected_model_class) + assert model.id == payload["id"] + assert model.name == 
payload["name"] + assert model.function == Function(payload["function"]["id"]) + assert model.function_type == FunctionType(payload["functionType"]) + assert model.api_key == payload["api_key"] + + +@pytest.mark.parametrize( + "authentication_schema, name, token, client_id, client_secret", + [ + (AuthenticationSchema.BEARER, "test-name", "test-token", None, None), + (AuthenticationSchema.OAUTH, "test-name", None, "test-client-id", "test-client-secret"), + ], +) +def test_connector_connect(mocker, authentication_schema, name, token, client_id, client_secret): + mocker.patch("aixplain.modules.model.integration.Integration.run", return_value={"id": "test-id"}) + connector = Integration( + id="connector-id", + name="connector-name", + function=Function.UTILITIES, + function_type=FunctionType.INTEGRATION, + supplier="aiXplain", + api_key="api_key", + version={"id": "1.0"}, + ) + response = connector.connect( + authentication_schema=authentication_schema, name=name, token=token, client_id=client_id, client_secret=client_secret + ) + assert response == {"id": "test-id"} + + +def test_connection_init_with_actions(mocker): + mocker.patch( + "aixplain.modules.model.Model.run", + side_effect=[ + ModelResponse( + status=ResponseStatus.SUCCESS, + data=[{"displayName": "test-name", "description": "test-description", "name": "test-code"}], + ), + ModelResponse( + status=ResponseStatus.SUCCESS, + data=[{"inputs": [{"code": "test-code", "name": "test-name", "description": "test-description"}]}], + ), + ], + ) + connection = ConnectionTool( + id="connection-id", + name="connection-name", + function=Function.UTILITIES, + function_type=FunctionType.CONNECTION, + supplier="aiXplain", + api_key="api_key", + version={"id": "1.0"}, + ) + assert connection.id == "connection-id" + assert connection.name == "connection-name" + assert connection.function == Function.UTILITIES + assert connection.function_type == FunctionType.CONNECTION + assert connection.api_key == "api_key" + assert 
connection.version == {"id": "1.0"} + assert connection.actions is not None + assert len(connection.actions) == 1 + assert connection.actions[0].name == "test-name" + assert connection.actions[0].description == "test-description" + assert connection.actions[0].code == "test-code" + + action = ConnectAction(code="test-code", name="test-name", description="test-description") + inputs = connection.get_action_inputs(action) + assert "test-code" in inputs + assert inputs["test-code"]["name"] == "test-name" + assert inputs["test-code"]["description"] == "test-description" + + +def test_tool_factory(mocker): + from aixplain.factories import ToolFactory + from aixplain.modules.model.utility_model import BaseUtilityModelParams + + # Utility Model + mocker.patch( + "aixplain.factories.model_factory.ModelFactory.create_utility_model", + return_value=UtilityModel( + id="test-id", + name="test-name", + function=Function.UTILITIES, + function_type=FunctionType.AI, + api_key="api_key", + version={"id": "1.0"}, + ), + ) + + def add(aaa: int, bbb: int) -> int: + return aaa + bbb + + params = BaseUtilityModelParams(name="My Script Model", description="My Script Model Description", code=add) + tool = ToolFactory.create(params=params) + assert isinstance(tool, UtilityModel) + assert tool.id == "test-id" + assert tool.name == "test-name" + assert tool.function == Function.UTILITIES + assert tool.function_type == FunctionType.AI + assert tool.api_key == "api_key" + assert tool.version == {"id": "1.0"} + + # Index Model + from aixplain.factories.index_factory.utils import AirParams + + params = AirParams(name="My Search Collection", description="My Search Collection Description") + mocker.patch( + "aixplain.factories.index_factory.IndexFactory.create", + return_value=IndexModel( + id="test-id", + name="test-name", + function=Function.SEARCH, + function_type=FunctionType.SEARCH, + api_key="api_key", + version={"id": "1.0"}, + ), + ) + tool = ToolFactory.create(params=params) + assert 
isinstance(tool, IndexModel) + assert tool.id == "test-id" + assert tool.name == "test-name" + assert tool.function == Function.SEARCH + assert tool.function_type == FunctionType.SEARCH + assert tool.api_key == "api_key" + assert tool.version == {"id": "1.0"} + + # Integration Model + mocker.patch("aixplain.modules.model.connection.ConnectionTool._get_actions", return_value=[]) + mocker.patch( + "aixplain.modules.model.integration.Integration.connect", + return_value=ModelResponse(status=ResponseStatus.SUCCESS, data={"id": "connection-id"}), + ) + + def get_mock(id): + if id == "67eff5c0e05614297caeef98": + return Integration( + id="67eff5c0e05614297caeef98", + name="test-name", + function=Function.UTILITIES, + function_type=FunctionType.INTEGRATION, + api_key="api_key", + version={"id": "1.0"}, + ) + elif id == "connection-id": + return ConnectionTool( + id="connection-id", + name="test-name", + function=Function.UTILITIES, + function_type=FunctionType.CONNECTION, + api_key="api_key", + version={"id": "1.0"}, + ) + + mocker.patch("aixplain.factories.tool_factory.ToolFactory.get", side_effect=get_mock) + tool = ToolFactory.create(integration="67eff5c0e05614297caeef98", name="My Connector 1234", token="slack-token") + assert isinstance(tool, ConnectionTool) + assert tool.id == "connection-id" + assert tool.name == "test-name" + assert tool.function == Function.UTILITIES + assert tool.function_type == FunctionType.CONNECTION + assert tool.api_key == "api_key" + assert tool.version == {"id": "1.0"} From b7a929d95fd4890b8ef47a305510a143e46010a3 Mon Sep 17 00:00:00 2001 From: Yunsu Kim Date: Fri, 6 Jun 2025 23:56:59 +0200 Subject: [PATCH 42/62] Add input target for inspectors (#548) * Add input inspector target * Add functional tests for inspector target 'input' --- aixplain/modules/team_agent/__init__.py | 2 +- .../team_agent/inspector_functional_test.py | 189 ++++++++++++++++++ 2 files changed, 190 insertions(+), 1 deletion(-) diff --git 
a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 75f0b0dd..adf0976b 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -49,7 +49,7 @@ class InspectorTarget(str, Enum): - # TODO: INPUT + INPUT = "input" STEPS = "steps" OUTPUT = "output" diff --git a/tests/functional/team_agent/inspector_functional_test.py b/tests/functional/team_agent/inspector_functional_test.py index b415e719..e1679dde 100644 --- a/tests/functional/team_agent/inspector_functional_test.py +++ b/tests/functional/team_agent/inspector_functional_test.py @@ -368,3 +368,192 @@ def test_team_agent_with_multiple_inspector_targets(run_input_map, delete_agents assert response["data"]["critiques"], "No critiques found in response data" team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_input_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with input inspector that runs before any steps are executed""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create inspector with warn policy + inspector = Inspector( + name="input_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "Check if the input is valid and provide feedback"}, + policy=InspectorPolicy.WARN, + ) + + # Create team agent with input inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.INPUT], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + 
assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + verify_inspector_steps(steps, ["input_inspector"], [InspectorTarget.INPUT]) + verify_response_generator(steps) + + # Verify inspector runs and execution continues + inspector_steps = [step for step in steps if "input_inspector" in step.get("agent", "").lower()] + assert len(inspector_steps) > 0, "Input inspector should run at least once" + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_input_abort_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with input inspector (ABORT policy): if critiques are non-empty, response_generator is called immediately after inspector.""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create inspector with abort policy + inspector = Inspector( + name="input_abort_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "Always find issues and provide negative feedback on input"}, + policy=InspectorPolicy.ABORT, + ) + + # Create team agent with input inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.INPUT], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + 
# Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + verify_inspector_steps(steps, ["input_abort_inspector"], [InspectorTarget.INPUT]) + verify_response_generator(steps) + + # Critiques should be present and non-empty in the inspector step's 'thought' + inspector_steps = [step for step in steps if "input_abort_inspector" in step.get("agent", "").lower()] + assert len(inspector_steps) == 1, "Input abort inspector should only run once" + inspector_thought = inspector_steps[0].get("thought", "") + assert inspector_thought, "No thought found in inspector step" + assert ( + "critique" in inspector_thought.lower() or len(inspector_thought.strip()) > 0 + ), "Inspector step's thought does not contain critique or is empty" + + # Inspector should run once, then response_generator should come right after + response_generator_index = next( + (i for i, step in enumerate(steps) if "response_generator" in step.get("agent", "").lower()), + None, + ) + assert response_generator_index is not None, "No response_generator step found" + assert ( + response_generator_index == steps.index(inspector_steps[0]) + 1 + ), "Response generator should come right after input abort inspector critique" + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_input_adaptive_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with input inspector (ADAPTIVE policy): query_manager step exists more than once and mentalist creates a plan for the revised query (output of the last query_manager).""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create inspector with adaptive policy + inspector = Inspector( + name="input_adaptive_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "If the input is not valid, suggest a revised query and 
critique."}, + policy=InspectorPolicy.ADAPTIVE, + ) + + # Create team agent with input inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.INPUT], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + verify_inspector_steps(steps, ["input_adaptive_inspector"], [InspectorTarget.INPUT]) + verify_response_generator(steps) + + # There should be more than one query_manager step + query_manager_steps = [step for step in steps if "query_manager" in step.get("agent", "").lower()] + assert len(query_manager_steps) > 1, "There should be more than one query_manager step for adaptive input inspector" + + # The last query_manager's output should be contained in the mentalist's input + last_query_manager = query_manager_steps[-1] + revised_query = last_query_manager.get("output", None) + assert revised_query, "No output found in the last query_manager step" + + # There must be only one mentalist step + mentalist_steps = [step for step in steps if "mentalist" in step.get("agent", "").lower()] + mentalist_input = mentalist_steps[0].get("input", None) + assert ( + mentalist_input and revised_query in mentalist_input + ), "The mentalist input does not contain the revised query from the last query_manager" + + team_agent.delete() From 6067f16648c62a9ead98f1d17947f962198fc1f9 Mon Sep 17 00:00:00 2001 From: kadirpekel Date: Tue, 
10 Jun 2025 19:31:29 +0200 Subject: [PATCH 43/62] ENG-1900 Refined agent deletion error messages (#543) --- aixplain/modules/agent/__init__.py | 54 +++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index f4049adf..34ae74ec 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -395,19 +395,63 @@ def delete(self) -> None: "x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json", } - logging.debug(f"Start service for DELETE Agent - {url} - {headers}") + logging.debug( + f"Start service for DELETE Agent - {url} - {headers}" + ) r = _request_with_retry("delete", url, headers=headers) - logging.debug(f"Result of request for DELETE Agent - {r.status_code}") + logging.debug( + f"Result of request for DELETE Agent - {r.status_code}" + ) if r.status_code != 200: raise Exception() except Exception: try: response_json = r.json() - message = f"Agent Deletion Error (HTTP {r.status_code}): {response_json.get('message', '').strip('{{}}')}." + error_message = response_json.get('message', '').strip('{{}}') + + if r.status_code == 403 and error_message == "err.agent_is_in_use": + # Get team agents that use this agent + from aixplain.factories.team_agent_factory import ( + TeamAgentFactory + ) + team_agents = TeamAgentFactory.list()["results"] + using_team_agents = [ + ta for ta in team_agents + if any(agent.id == self.id for agent in ta.agents) + ] + + if using_team_agents: + # Scenario 1: User has access to team agents + team_agent_ids = [ta.id for ta in using_team_agents] + message = ( + "Error: Agent cannot be deleted.\n" + "Reason: This agent is currently used by one or more " + "team agents.\n\n" + f"team_agent_id: {', '.join(team_agent_ids)}. " + "To proceed, remove the agent from all team agents " + "before deletion." 
+ ) + else: + # Scenario 2: User doesn't have access to team agents + message = ( + "Error: Agent cannot be deleted.\n" + "Reason: This agent is currently used by one or more " + "team agents.\n\n" + "One or more inaccessible team agents are " + "referencing it." + ) + else: + message = ( + f"Agent Deletion Error (HTTP {r.status_code}): " + f"{error_message}." + ) except ValueError: - message = f"Agent Deletion Error (HTTP {r.status_code}): There was an error in deleting the agent." + message = ( + f"Agent Deletion Error (HTTP {r.status_code}): " + "There was an error in deleting the agent." + ) logging.error(message) - raise Exception(f"{message}") + raise Exception(message) def update(self) -> None: """Update agent.""" From c570a70a13dfcd159b7276299ad86d1b2240f857 Mon Sep 17 00:00:00 2001 From: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Date: Mon, 16 Jun 2025 16:16:18 -0300 Subject: [PATCH 44/62] ENG-2271: Add code interpreter model ID (#550) * Add code interpreter model ID * Add CodeInterpreter to enums init --- aixplain/enums/__init__.py | 1 + aixplain/enums/code_interpeter.py | 10 ++++++++++ aixplain/modules/agent/tool/custom_python_code_tool.py | 3 +++ 3 files changed, 14 insertions(+) create mode 100644 aixplain/enums/code_interpeter.py diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 725fdb90..6162ed86 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -20,3 +20,4 @@ from .asset_status import AssetStatus from .index_stores import IndexStores from .function_type import FunctionType +from .code_interpeter import CodeInterpreterModel diff --git a/aixplain/enums/code_interpeter.py b/aixplain/enums/code_interpeter.py new file mode 100644 index 00000000..9f0a14c6 --- /dev/null +++ b/aixplain/enums/code_interpeter.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class CodeInterpreterModel(str, Enum): + """Code Interpreter Model IDs""" + + PYTHON_AZURE = "67476fa16eb563d00060ad62" + + def 
__str__(self): + return self._value_ diff --git a/aixplain/modules/agent/tool/custom_python_code_tool.py b/aixplain/modules/agent/tool/custom_python_code_tool.py index 6433a408..05715b2a 100644 --- a/aixplain/modules/agent/tool/custom_python_code_tool.py +++ b/aixplain/modules/agent/tool/custom_python_code_tool.py @@ -25,6 +25,7 @@ from aixplain.modules.agent.tool import Tool import logging from aixplain.enums import AssetStatus +from aixplain.enums.code_interpeter import CodeInterpreterModel class CustomPythonCodeTool(Tool): @@ -37,11 +38,13 @@ def __init__( super().__init__(name=name or "", description=description, **additional_info) self.code = code self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool + self.id = CodeInterpreterModel.PYTHON_AZURE self.validate() def to_dict(self): return { + "id": self.id, "name": self.name, "description": self.description, "type": "utility", From 80ddb8a90dcd71c2f07eb69b9527c5a5db226a93 Mon Sep 17 00:00:00 2001 From: Hadi Nasrallah <87204330+hadi-aix@users.noreply.github.com> Date: Wed, 18 Jun 2025 14:37:01 -0400 Subject: [PATCH 45/62] Update enums.py (#552) --- aixplain/modules/pipeline/designer/enums.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aixplain/modules/pipeline/designer/enums.py b/aixplain/modules/pipeline/designer/enums.py index fe4cbfed..d7c286f5 100644 --- a/aixplain/modules/pipeline/designer/enums.py +++ b/aixplain/modules/pipeline/designer/enums.py @@ -36,6 +36,9 @@ class FunctionType(str, Enum): RECONSTRUCTOR = "reconstructor" UTILITY = "utility" METRIC = "metric" + CONNECTOR = 'connector' + CONNECTION = 'connection' + MCPSERVER = 'mcpserver' class ParamType: From 39be1d05bf4fd24fbab6bf4047c4113f65a8d1c8 Mon Sep 17 00:00:00 2001 From: Hadi Nasrallah <87204330+hadi-aix@users.noreply.github.com> Date: Wed, 18 Jun 2025 14:50:05 -0400 Subject: [PATCH 46/62] Update enums.py (#554) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit * Update enums.py * Update enum --------- Co-authored-by: Ahmet Gündüz --- aixplain/enums/function_type.py | 1 + aixplain/modules/pipeline/designer/enums.py | 14 +------------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/aixplain/enums/function_type.py b/aixplain/enums/function_type.py index 514ff992..ae6f8e79 100644 --- a/aixplain/enums/function_type.py +++ b/aixplain/enums/function_type.py @@ -33,3 +33,4 @@ class FunctionType(Enum): SEARCH = "search" INTEGRATION = "connector" CONNECTION = "connection" + MCPSERVER = 'mcpserver' diff --git a/aixplain/modules/pipeline/designer/enums.py b/aixplain/modules/pipeline/designer/enums.py index d7c286f5..b733265d 100644 --- a/aixplain/modules/pipeline/designer/enums.py +++ b/aixplain/modules/pipeline/designer/enums.py @@ -1,5 +1,5 @@ from enum import Enum - +from aixplain.enums import FunctionType class RouteType(str, Enum): CHECK_TYPE = "checkType" @@ -29,18 +29,6 @@ class NodeType(str, Enum): class AssetType(str, Enum): MODEL = "MODEL" - -class FunctionType(str, Enum): - AI = "ai" - SEGMENTOR = "segmentor" - RECONSTRUCTOR = "reconstructor" - UTILITY = "utility" - METRIC = "metric" - CONNECTOR = 'connector' - CONNECTION = 'connection' - MCPSERVER = 'mcpserver' - - class ParamType: INPUT = "INPUT" OUTPUT = "OUTPUT" From 095f9b8ab0216a374d9d3dcbec4643e3b0f11213 Mon Sep 17 00:00:00 2001 From: Zaina Abu Shaban Date: Thu, 19 Jun 2025 20:51:58 +0300 Subject: [PATCH 47/62] removed pipeline cache test (#556) Co-authored-by: Zaina --- tests/functional/pipelines/create_test.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/tests/functional/pipelines/create_test.py b/tests/functional/pipelines/create_test.py index f1dac2c4..076a637f 100644 --- a/tests/functional/pipelines/create_test.py +++ b/tests/functional/pipelines/create_test.py @@ -16,7 +16,6 @@ limitations under the License. 
""" import os -from aixplain.utils.cache_utils import CACHE_FOLDER from aixplain.modules.pipeline import Pipeline import json import pytest @@ -79,25 +78,3 @@ def test_create_pipeline_wrong_path(PipelineFactory): with pytest.raises(Exception): PipelineFactory.create(name=pipeline_name, pipeline="/") - - -@pytest.mark.parametrize("PipelineFactory", [PipelineFactory]) -def test_pipeline_cache_creation(PipelineFactory): - cache_file = os.path.join(CACHE_FOLDER, "pipelines.json") - if os.path.exists(cache_file): - os.remove(cache_file) - - pipeline_json = "tests/functional/pipelines/data/pipeline.json" - pipeline_name = str(uuid4()) - pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_json) - - assert os.path.exists(cache_file), "Pipeline cache file was not created!" - - with open(cache_file, "r") as f: - cache_data = json.load(f) - - assert "data" in cache_data, "Cache format invalid, missing 'data'." - - pipeline.delete() - if os.path.exists(cache_file): - os.remove(cache_file) \ No newline at end of file From 01a7f9eb2fcb858f3a456e7de08d08fbb65cd5f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Fri, 20 Jun 2025 19:08:02 +0300 Subject: [PATCH 48/62] Functional test fix of agent in use error (#560) * function test fix of agent in use error * more fixes --- tests/functional/agent/agent_functional_test.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 08e6e095..2783ffc3 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -18,6 +18,7 @@ import copy import json import os +import re from dotenv import load_dotenv load_dotenv() @@ -234,7 +235,10 @@ def test_delete_agent_in_use(delete_agents_and_team_agents, AgentFactory): with pytest.raises(Exception) as exc_info: agent.delete() - assert str(exc_info.value) == "Agent 
Deletion Error (HTTP 403): err.agent_is_in_use." + assert re.match( + r"Error: Agent cannot be deleted\.\nReason: This agent is currently used by one or more team agents\.\n\nteam_agent_id: [a-f0-9]{24}\. To proceed, remove the agent from all team agents before deletion\.", + str(exc_info.value), + ) @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) @@ -302,7 +306,7 @@ def test_update_tools_of_agent(run_input_map, delete_agents_and_team_agents, Age "type": "translation", "supplier": "Microsoft", "function": "translation", - "query": "Translate: Olá, como vai você?", + "query": "Translate: 'Olá, como vai você?'", "description": "Translation tool with target language", "expected_tool_input": "targetlanguage", }, @@ -334,7 +338,7 @@ def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agent # Create and run agent agent = AgentFactory.create( name="Test Parameter Agent", - description="Test agent with parameterized tools. You MUST use a tool for the tasks.", + description="Test agent with parameterized tools. You MUST use a tool for the tasks. 
Do not directly answer the question.", tools=[tool], llm_id="6646261c6eb563165658bbb1", # Using LLM ID from test data ) @@ -351,8 +355,9 @@ def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agent # Verify tool was used in execution assert len(response["data"]["intermediate_steps"]) > 0 tool_used = False + for step in response["data"]["intermediate_steps"]: - if tool_config["expected_tool_input"] in step["tool_steps"][0]["input"]: + if len(step["tool_steps"]) > 0 and tool_config["expected_tool_input"] in step["tool_steps"][0]["input"]: tool_used = True break assert tool_used, "Tool was not used in execution" @@ -643,6 +648,7 @@ def test_agent_llm_parameter_preservation(delete_agents_and_team_agents, AgentFa # Reset the LLM temperature to its original value llm.temperature = original_temperature + def test_run_agent_with_expected_output(): from pydantic import BaseModel from typing import Optional, List @@ -753,4 +759,3 @@ def test_agent_with_action_tool(): assert "helsinki" in response.data.output.lower() assert "SLACK_CHAT_POST_MESSAGE" in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] connection.delete() - From 585444ae65b72b26d927bc52d68114783c1f77c0 Mon Sep 17 00:00:00 2001 From: kadirpekel Date: Fri, 20 Jun 2025 18:58:51 +0200 Subject: [PATCH 49/62] BUG-574 Fixed functionType type handling (#561) --- aixplain/modules/pipeline/designer/nodes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/aixplain/modules/pipeline/designer/nodes.py b/aixplain/modules/pipeline/designer/nodes.py index e81be81a..e92d5fe1 100644 --- a/aixplain/modules/pipeline/designer/nodes.py +++ b/aixplain/modules/pipeline/designer/nodes.py @@ -1,4 +1,5 @@ from typing import List, Union, Type, TYPE_CHECKING, Optional +from enum import Enum from aixplain.modules import Model from aixplain.enums import DataType, Function @@ -142,7 +143,11 @@ def serialize(self) -> dict: obj["supplier"] = self.supplier obj["version"] = 
self.version obj["assetType"] = self.assetType - obj["functionType"] = self.functionType + # Handle functionType as enum or string + if isinstance(self.functionType, Enum): + obj["functionType"] = self.functionType.value + else: + obj["functionType"] = self.functionType obj["type"] = self.type return obj From 1e846276db722837a21616f30cc622021fbe5aca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Tue, 24 Jun 2025 15:26:55 +0300 Subject: [PATCH 50/62] PROD-1833: Optional Instructions in Agents (#559) --- aixplain/factories/agent_factory/__init__.py | 9 +- aixplain/factories/agent_factory/utils.py | 2 +- aixplain/modules/agent/__init__.py | 34 +--- aixplain/modules/agent/utils.py | 6 +- tests/unit/agent/agent_test.py | 204 ++++++++++++++++++- tests/unit/agent/sql_tool_test.py | 4 +- 6 files changed, 214 insertions(+), 45 deletions(-) diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 53ef3e63..a00fe3bc 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -92,13 +92,6 @@ def create( # Use default GPT-4o if no LLM specified llm = get_llm_instance("669a63646eb56306647e1091", api_key=api_key) - if instructions is None: - warnings.warn( - "Use `instructions` to define the **system prompt**. " - "Use `description` to provide a **short summary** of the agent for metadata and dashboard display. " - "Note: In upcoming releases, `instructions` will become a required parameter.", - UserWarning, - ) warnings.warn( "Use `llm` to define the large language model (aixplain.modules.model.llm_model.LLM) to be used as agent. " "Use `llm_id` to provide the model ID of the large language model to be used as agent. 
" @@ -406,4 +399,4 @@ def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> Agent: if "message" in resp: msg = resp["message"] error_msg = f"Agent Get Error (HTTP {r.status_code}): {msg}" - raise Exception(error_msg) \ No newline at end of file + raise Exception(error_msg) diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 8264c07e..450da273 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -177,7 +177,7 @@ def build_agent(payload: Dict, tools: List[Tool] = None, api_key: Text = config. name=payload.get("name", ""), tools=payload_tools, description=payload.get("description", ""), - instructions=payload.get("role", ""), + instructions=payload.get("role"), supplier=payload.get("teamId", None), version=payload.get("version", None), cost=payload.get("cost", None), diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 34ae74ec..e3ca6f89 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -68,7 +68,7 @@ def __init__( id: Text, name: Text, description: Text, - instructions: Text, + instructions: Optional[Text] = None, tools: List[Union[Tool, Model]] = [], llm_id: Text = "6646261c6eb563165658bbb1", llm: Optional[LLM] = None, @@ -370,7 +370,7 @@ def to_dict(self) -> Dict: "name": self.name, "assets": [build_tool_payload(tool) for tool in self.tools], "description": self.description, - "role": self.instructions, + "role": self.instructions or self.description, "supplier": (self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier), "version": self.version, "llmId": self.llm_id if self.llm is None else self.llm.id, @@ -395,30 +395,22 @@ def delete(self) -> None: "x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json", } - logging.debug( - f"Start service for DELETE Agent - {url} - {headers}" - ) + logging.debug(f"Start service for DELETE 
Agent - {url} - {headers}") r = _request_with_retry("delete", url, headers=headers) - logging.debug( - f"Result of request for DELETE Agent - {r.status_code}" - ) + logging.debug(f"Result of request for DELETE Agent - {r.status_code}") if r.status_code != 200: raise Exception() except Exception: try: response_json = r.json() - error_message = response_json.get('message', '').strip('{{}}') + error_message = response_json.get("message", "").strip("{{}}") if r.status_code == 403 and error_message == "err.agent_is_in_use": # Get team agents that use this agent - from aixplain.factories.team_agent_factory import ( - TeamAgentFactory - ) + from aixplain.factories.team_agent_factory import TeamAgentFactory + team_agents = TeamAgentFactory.list()["results"] - using_team_agents = [ - ta for ta in team_agents - if any(agent.id == self.id for agent in ta.agents) - ] + using_team_agents = [ta for ta in team_agents if any(agent.id == self.id for agent in ta.agents)] if using_team_agents: # Scenario 1: User has access to team agents @@ -441,15 +433,9 @@ def delete(self) -> None: "referencing it." ) else: - message = ( - f"Agent Deletion Error (HTTP {r.status_code}): " - f"{error_message}." - ) + message = f"Agent Deletion Error (HTTP {r.status_code}): " f"{error_message}." except ValueError: - message = ( - f"Agent Deletion Error (HTTP {r.status_code}): " - "There was an error in deleting the agent." - ) + message = f"Agent Deletion Error (HTTP {r.status_code}): " "There was an error in deleting the agent." 
logging.error(message) raise Exception(message) diff --git a/aixplain/modules/agent/utils.py b/aixplain/modules/agent/utils.py index aba5bb1c..684c82db 100644 --- a/aixplain/modules/agent/utils.py +++ b/aixplain/modules/agent/utils.py @@ -2,7 +2,9 @@ import re -def process_variables(query: Union[Text, Dict], data: Union[Dict, Text], parameters: Dict, agent_description: Text) -> Text: +def process_variables( + query: Union[Text, Dict], data: Union[Dict, Text], parameters: Dict, agent_description: Union[Text, None] +) -> Text: from aixplain.factories.file_factory import FileFactory if isinstance(query, dict): @@ -13,7 +15,7 @@ def process_variables(query: Union[Text, Dict], data: Union[Dict, Text], paramet else: input_data = {"input": FileFactory.to_link(query)} - variables = re.findall(r"(?"): AgentFactory.create_sql_tool( name="Test SQL", description="Test", source=db_path, source_type=123, schema="test" - ) # Invalid type \ No newline at end of file + ) # Invalid type From ea51442a67debc1a5ef9781505b68ee3190cf466 Mon Sep 17 00:00:00 2001 From: Muhammad-Elmallah <145364766+Muhammad-Elmallah@users.noreply.github.com> Date: Wed, 25 Jun 2025 16:22:44 +0300 Subject: [PATCH 51/62] fixing the tests (#569) --- tests/functional/model/run_model_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 71838034..ab210109 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -106,7 +106,6 @@ def run_index_model(index_model, retries): pytest.param(None, ZeroEntropyParams, id="ZERO_ENTROPY"), pytest.param(EmbeddingModel.OPENAI_ADA002, GraphRAGParams, id="GRAPHRAG"), pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="AIR - OpenAI Ada 002"), - pytest.param("6658d40729985c2cf72f42ec", AirParams, id="AIR - Snowflake Arctic Embed M Long"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), 
pytest.param("67efd4f92a0a850afa045af7", AirParams, id="AIR - BGE M3"), ], @@ -133,7 +132,6 @@ def test_index_model(embedding_model, supplier_params): [ pytest.param(None, VectaraParams, id="VECTARA"), pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="OpenAI Ada 002"), - pytest.param("6658d40729985c2cf72f42ec", AirParams, id="Snowflake Arctic Embed M Long"), pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, AirParams, id="Jina Clip v2 Multimodal"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), pytest.param("67efd4f92a0a850afa045af7", AirParams, id="BGE M3"), From ba406e2b36b973ea44476ac7ecd0e5966213a30d Mon Sep 17 00:00:00 2001 From: Shreyas Sharma <85180538+shreyasXplain@users.noreply.github.com> Date: Thu, 26 Jun 2025 10:25:01 +0530 Subject: [PATCH 52/62] Introducing prompt benchmarking (#497) * Add base support for benchmarking models with config * bugFix: config normalization * TypoFix: add 's' to configuration * add display name in get_scores * add tests for prompt benchmark * uncomment first benchmark test --- aixplain/factories/benchmark_factory.py | 38 ++++++++++++++++--- aixplain/modules/benchmark_job.py | 5 +++ aixplain/modules/model/__init__.py | 11 ++++++ .../benchmark/benchmark_functional_test.py | 36 ++++++++++++------ .../data/benchmark_test_with_parameters.json | 22 +++++++++++ 5 files changed, 96 insertions(+), 16 deletions(-) create mode 100644 tests/functional/benchmark/data/benchmark_test_with_parameters.json diff --git a/aixplain/factories/benchmark_factory.py b/aixplain/factories/benchmark_factory.py index c37f17a8..1c4408ea 100644 --- a/aixplain/factories/benchmark_factory.py +++ b/aixplain/factories/benchmark_factory.py @@ -22,7 +22,7 @@ """ import logging -from typing import Dict, List, Text +from typing import Dict, List, Text, Any, Tuple import json from aixplain.enums.supplier import Supplier from aixplain.modules import Dataset, Metric, Model @@ -150,9 +150,9 @@ def 
_validate_create_benchmark_payload(cls, payload): if len(payload["datasets"]) != 1: raise Exception("Please use exactly one dataset") if len(payload["metrics"]) == 0: - raise Exception("Please use exactly one metric") - if len(payload["model"]) == 0: - raise Exception("Please use exactly one model") + raise Exception("Please use at least one metric") + if len(payload["model"]) == 0 and payload.get("models", None) is None: + raise Exception("Please use at least one model") clean_metrics_info = {} for metric_info in payload["metrics"]: metric_id = metric_info["id"] @@ -167,6 +167,31 @@ def _validate_create_benchmark_payload(cls, payload): {"id": metric_id, "configurations": metric_config} for metric_id, metric_config in clean_metrics_info.items() ] return payload + + @classmethod + def _reformat_model_list(cls, model_list: List[Model]) -> Tuple[List[Any], List[Any]]: + """Reformat the model list to be used in the create benchmark API + + Args: + model_list (List[Model]): List of models to be used in the benchmark + + Returns: + Tuple[List[Any], List[Any]]: Reformatted model lists + + """ + model_list_without_parms, model_list_with_parms = [], [] + for model in model_list: + if "displayName" in model.additional_info: + model_list_with_parms.append({"id": model.id, "displayName": model.additional_info["displayName"], "configurations": json.dumps(model.additional_info["configuration"])}) + else: + model_list_without_parms.append(model.id) + if len(model_list_with_parms) > 0: + if len(model_list_without_parms) > 0: + raise Exception("Please provide addditional info for all models or for none of the models") + else: + model_list_with_parms = None + return model_list_without_parms, model_list_with_parms + @classmethod def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], metric_list: List[Metric]) -> Benchmark: @@ -186,15 +211,18 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], try: url = 
urljoin(cls.backend_url, "sdk/benchmarks") headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + model_list_without_parms, model_list_with_parms = cls._reformat_model_list(model_list) payload = { "name": name, "datasets": [dataset.id for dataset in dataset_list], - "model": [model.id for model in model_list], "metrics": [{"id": metric.id, "configurations": metric.normalization_options} for metric in metric_list], + "model": model_list_without_parms, "shapScores": [], "humanEvaluationReport": False, "automodeTraining": False, } + if model_list_with_parms is not None: + payload["models"] = model_list_with_parms clean_payload = cls._validate_create_benchmark_payload(payload) payload = json.dumps(clean_payload) r = _request_with_retry("post", url, headers=headers, data=payload) diff --git a/aixplain/modules/benchmark_job.py b/aixplain/modules/benchmark_job.py index cd17c0e1..4ab0865b 100644 --- a/aixplain/modules/benchmark_job.py +++ b/aixplain/modules/benchmark_job.py @@ -3,6 +3,7 @@ from aixplain.utils import config from urllib.parse import urljoin import pandas as pd +import json from pathlib import Path from aixplain.utils.request_utils import _request_with_retry from aixplain.utils.file_utils import save_file @@ -109,6 +110,10 @@ def get_scores(self, return_simplified=True, return_as_dataframe=True): scores = {} for iteration_info in iterations: model_id = iteration_info["pipeline"] + pipeline_json = json.loads(iteration_info["pipelineJson"]) + if "benchmark" in pipeline_json: + model_id = pipeline_json["benchmark"]["displayName"] + model_info = { "creditsUsed": round(iteration_info.get("credits", 0), 5), "timeSpent": round(iteration_info.get("runtime", 0), 2), diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index f20d4e20..5c96ab63 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -432,6 +432,16 @@ def delete(self) -> None: message = "Model 
Deletion Error: Make sure the model exists and you are the owner." logging.error(message) raise Exception(f"{message}") + + def add_additional_info_for_benchmark(self, display_name: str, configuration: Dict) -> None: + """Add additional info for benchmark + + Args: + display_name (str): display name of the model + configuration (Dict): configuration of the model + """ + self.additional_info["displayName"] = display_name + self.additional_info["configuration"] = configuration @classmethod def from_dict(cls, data: Dict) -> "Model": @@ -451,3 +461,4 @@ def from_dict(cls, data: Dict) -> "Model": model_params=data.get("model_params"), **data.get("additional_info", {}), ) + diff --git a/tests/functional/benchmark/benchmark_functional_test.py b/tests/functional/benchmark/benchmark_functional_test.py index 93abd869..7c691a6f 100644 --- a/tests/functional/benchmark/benchmark_functional_test.py +++ b/tests/functional/benchmark/benchmark_functional_test.py @@ -11,9 +11,7 @@ from pathlib import Path import pytest - import logging - from aixplain import aixplain_v2 as v2 logger = logging.getLogger() @@ -22,6 +20,7 @@ TIMEOUT = 60 * 30 RUN_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_test_run_data.json")) MODULE_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_module_test_data.json")) +RUN_WITH_PARAMETERS_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_test_with_parameters.json")) def read_data(data_path): @@ -33,6 +32,11 @@ def run_input_map(request): return request.param +@pytest.fixture(scope="module", params=[(name, params) for name, params in read_data(RUN_WITH_PARAMETERS_FILE).items()]) +def run_with_parameters_input_map(request): + return request.param + + @pytest.fixture(scope="module", params=read_data(MODULE_FILE)) def module_input_map(request): return request.param @@ -79,12 +83,22 @@ def test_create_and_run(run_input_map, BenchmarkFactory): assert_correct_results(benchmark_job) -# def test_module(module_input_map): -# benchmark = 
BenchmarkFactory.get(module_input_map["benchmark_id"]) -# assert benchmark.id == module_input_map["benchmark_id"] -# benchmark_job = benchmark.job_list[0] -# assert benchmark_job.benchmark_id == module_input_map["benchmark_id"] -# job_status = benchmark_job.check_status() -# assert job_status in ["in_progress", "completed"] -# df = benchmark_job.download_results_as_csv(return_dataframe=True) -# assert type(df) is pd.DataFrame +@pytest.mark.parametrize("BenchmarkFactory", [BenchmarkFactory, v2.Benchmark]) +def test_create_and_run_with_parameters(run_with_parameters_input_map, BenchmarkFactory): + name, params = run_with_parameters_input_map + model_list = [] + for model_info in params["models_with_parameters"]: + model = ModelFactory.get(model_info["model_id"]) + model.add_additional_info_for_benchmark(display_name=model_info["display_name"], configuration=model_info["configuration"]) + model_list.append(model) + dataset_list = [DatasetFactory.list(query=dataset_name)["results"][0] for dataset_name in params["dataset_names"]] + metric_list = [MetricFactory.get(metric_id) for metric_id in params["metric_ids"]] + benchmark = BenchmarkFactory.create(f"SDK Benchmark Test With Parameters({name}) {uuid.uuid4()}", dataset_list, model_list, metric_list) + assert type(benchmark) is Benchmark, "Couldn't create benchmark" + benchmark_job = benchmark.start() + assert type(benchmark_job) is BenchmarkJob, "Couldn't start job" + assert is_job_finshed(benchmark_job), "Job did not finish in time" + assert_correct_results(benchmark_job) + + + diff --git a/tests/functional/benchmark/data/benchmark_test_with_parameters.json b/tests/functional/benchmark/data/benchmark_test_with_parameters.json new file mode 100644 index 00000000..287d3d9f --- /dev/null +++ b/tests/functional/benchmark/data/benchmark_test_with_parameters.json @@ -0,0 +1,22 @@ +{ + "Translation With LLMs": { + "models_with_parameters": [ + { + "model_id": "669a63646eb56306647e1091", + "display_name": "EnHi LLM", + 
"configuration": { + "prompt": "Translate the following text into Hindi." + } + }, + { + "model_id": "669a63646eb56306647e1091", + "display_name": "EnEs LLM", + "configuration": { + "prompt": "Translate the following text into Spanish." + } + } + ], + "dataset_names": ["EnHi SDK Test - Benchmark Dataset"], + "metric_ids": ["639874ab506c987b1ae1acc6", "6408942f166427039206d71e"] + } +} \ No newline at end of file From 3d6dc91a7ae85c5ab11ebdbf15ff87af025882e0 Mon Sep 17 00:00:00 2001 From: kadirpekel Date: Fri, 27 Jun 2025 14:15:54 +0200 Subject: [PATCH 53/62] ENG-2371 functional test for pipeline run_async (#570) --- tests/functional/pipelines/run_test.py | 29 ++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/functional/pipelines/run_test.py b/tests/functional/pipelines/run_test.py index a3fcbfb1..999bc2d0 100644 --- a/tests/functional/pipelines/run_test.py +++ b/tests/functional/pipelines/run_test.py @@ -328,3 +328,32 @@ def test_run_failure(version: str, PipelineFactory): ) assert response["status"] == ResponseStatus.FAILED + + +@pytest.mark.parametrize("version", ["2.0", "3.0"]) +@pytest.mark.parametrize("PipelineFactory", [PipelineFactory, v2.Pipeline]) +def test_run_async_simple(version: str, PipelineFactory): + """Test simple async pipeline execution with polling""" + pipeline = PipelineFactory.list(query="SingleNodePipeline")["results"][0] + + # Start async execution + response = pipeline.run_async( + data="Translate this simple text", + **{"version": version} + ) + + poll_url = response["url"] + import time + max_attempts = 30 + attempt = 0 + + while attempt < max_attempts: + poll_response = pipeline.poll(poll_url) + if hasattr(poll_response, 'completed') and poll_response.completed: + break + elif isinstance(poll_response, dict) and poll_response.get("completed", False): + break + time.sleep(1) + attempt += 1 + + assert poll_response.status == ResponseStatus.SUCCESS From a96dd07b29cb1f08fe8c11fe6e3c553d77343bb4 Mon Sep 17 
00:00:00 2001 From: Abdul Basit Anees <50206820+basitanees@users.noreply.github.com> Date: Tue, 8 Jul 2025 15:53:26 +0300 Subject: [PATCH 54/62] Add new methods to Air (#578) * add retrieve by filter and delete by date to air * add functional tests for retrieve records and delete by date * add documentation to functions --- aixplain/modules/model/index_model.py | 49 +++++++++++- tests/functional/model/run_model_test.py | 94 ++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index a240f3b2..31c91531 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -55,6 +55,7 @@ def __init__( self.split_length = split_length self.split_overlap = split_overlap + class IndexModel(Model): def __init__( self, @@ -151,7 +152,7 @@ def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) - "data": query or uri, "dataType": value_type, "filters": [filter.to_dict() for filter in filters], - "payload": {"uri": uri, "value_type": value_type, "top_k": top_k} + "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, } return self.run(data=data) @@ -246,3 +247,49 @@ def delete_record(self, record_id: Text) -> ModelResponse: if response.status == "SUCCESS": return response raise Exception(f"Failed to delete record: {response.error_message}") + + def retrieve_records_with_filter(self, filter: IndexFilter) -> ModelResponse: + """ + Retrieve records from the index that match the given filter. + + Args: + filter (IndexFilter): The filter criteria to apply when retrieving records. + + Returns: + ModelResponse: Response containing the retrieved records. + + Raises: + Exception: If retrieval fails. 
+ + Example: + >>> from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + >>> my_filter = IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS) + >>> index_model.retrieve_records_with_filter(my_filter) + """ + data = {"action": "retrieve_by_filter", "data": filter.to_dict()} + response = self.run(data=data) + if response.status == "SUCCESS": + return response + raise Exception(f"Failed to retrieve records with filter: {response.error_message}") + + def delete_records_by_date(self, date: float) -> ModelResponse: + """ + Delete records from the index that match the given date. + + Args: + date (float): The date (as a timestamp) to match records for deletion. + + Returns: + ModelResponse: Response containing the result of the deletion operation. + + Raises: + Exception: If deletion fails. + + Example: + >>> index_model.delete_records_by_date(1717708800) + """ + data = {"action": "delete_by_date", "data": date} + response = self.run(data=data) + if response.status == "SUCCESS": + return response + raise Exception(f"Failed to delete records by date: {response.error_message}") diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index ab210109..bf7415fb 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -322,3 +322,97 @@ def test_index_model_air_with_splitter(embedding_model, supplier_params): assert str(response.status) == "SUCCESS" assert "berlin" in response.data.lower() index_model.delete() + + +def _test_records(): + from aixplain.modules.model.record import Record + from aixplain.enums import DataType + + return [ + Record( + value="Artificial intelligence is transforming industries worldwide, from healthcare to finance.", + value_type=DataType.TEXT, + id="doc1", + uri="", + attributes={"category": "technology", "date": 1751464788}, + ), + Record( + value="The Mona Lisa, painted by Leonardo da Vinci, is one of the 
most famous artworks in history.", + value_type=DataType.TEXT, + id="doc2", + uri="", + attributes={"category": "art", "date": 1751464790}, + ), + Record( + value="Machine learning algorithms are being used to predict patient outcomes in hospitals.", + value_type=DataType.TEXT, + id="doc3", + uri="", + attributes={"category": "technology", "date": 1751464795}, + ), + Record( + value="The Earth orbits the Sun once every 365.25 days, creating the calendar year.", + value_type=DataType.TEXT, + id="doc4", + uri="", + attributes={"category": "science", "date": 1751464798}, + ), + Record( + value="Quantum computing promises to solve complex problems that are currently intractable for classical computers.", + value_type=DataType.TEXT, + id="doc5", + uri="", + attributes={"category": "technology", "date": 1751464801}, + ), + ] + + +@pytest.fixture(scope="function") +def setup_index_with_test_records(): + from aixplain.factories import IndexFactory + from aixplain.enums import EmbeddingModel + from aixplain.factories.index_factory.utils import AirParams + from uuid import uuid4 + import time + + # Clean up all existing indexes + for index in IndexFactory.list()["results"]: + index.delete() + + params = AirParams( + name=f"Test Index {uuid4()}", + description="Test index for filter/date tests", + embedding_model=EmbeddingModel.OPENAI_ADA002, + ) + index_model = IndexFactory.create(params=params) + records = _test_records() + + index_model.upsert(records) + + yield index_model + index_model.delete() + + +def test_retrieve_records_with_filter(setup_index_with_test_records): + from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + + index_model = setup_index_with_test_records + filter_ = IndexFilter(field="category", value="technology", operator=IndexFilterOperator.EQUALS) + response = index_model.retrieve_records_with_filter(filter_) + assert response.status == "SUCCESS" + assert len(response.details) == 3 + for item in response.details: + assert 
item["metadata"]["category"] == "technology" + + +def test_delete_records_by_date(setup_index_with_test_records): + from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + + index_model = setup_index_with_test_records + response = index_model.delete_records_by_date(1751464796) + assert response.status == "SUCCESS" + assert response.data == "2" # 2 records should remain + filter_all = IndexFilter(field="date", value=0, operator=IndexFilterOperator.GREATER_THAN) + response = index_model.retrieve_records_with_filter(filter_all) + assert response.status == "SUCCESS" + assert len(response.details) == 2 From ac8b5ac226a986210c1a33d4a8f80bd40610be51 Mon Sep 17 00:00:00 2001 From: Michael Lam <131073216+mikelam-us-aixplain@users.noreply.github.com> Date: Tue, 8 Jul 2025 07:39:30 -0700 Subject: [PATCH 55/62] Bumped model-interfaces to latest release (#221) Signed-off-by: mikelam-us-aixplain --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0bc5ec6a..fe0f7f00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ aixplain = "aixplain.cli_groups:run_cli" [project.optional-dependencies] model-builder = [ - "model-interfaces~=0.0.1" + "model-interfaces~=0.0.2" ] test = [ "pytest>=6.1.0", From 7e595190ed355c76cf7bb67faef659763c0f1e16 Mon Sep 17 00:00:00 2001 From: Abdelrahman El-Sheikh <139810675+elsheikhams99@users.noreply.github.com> Date: Sun, 13 Jul 2025 15:26:14 +0300 Subject: [PATCH 56/62] Re-implement the integration.py file (#581) * Re-implement the integration.py file * Add NO_AUTH, and rename OAUTH to OAUTH1 --- aixplain/modules/model/integration.py | 90 ++++++++++----------------- 1 file changed, 34 insertions(+), 56 deletions(-) diff --git a/aixplain/modules/model/integration.py b/aixplain/modules/model/integration.py index 65afc322..5fc6c0e6 100644 --- a/aixplain/modules/model/integration.py +++ b/aixplain/modules/model/integration.py @@ -5,48 +5,24 @@ from 
typing import Text, Optional, Union, Dict from enum import Enum from pydantic import BaseModel - +import json class AuthenticationSchema(Enum): - BEARER = "BEARER_TOKEN" - OAUTH = "OAUTH" + BEARER_TOKEN = "BEARER_TOKEN" + OAUTH1 = "OAUTH1" OAUTH2 = "OAUTH2" - - + API_KEY = "API_KEY" + BASIC = "BASIC" + NO_AUTH = "NO_AUTH" + class BaseAuthenticationParams(BaseModel): name: Optional[Text] = None - authentication_schema: AuthenticationSchema = AuthenticationSchema.OAUTH2 connector_id: Optional[Text] = None - -class BearerAuthenticationParams(BaseAuthenticationParams): - token: Text - authentication_schema: AuthenticationSchema = AuthenticationSchema.BEARER - - -class OAuthAuthenticationParams(BaseAuthenticationParams): - client_id: Text - client_secret: Text - authentication_schema: AuthenticationSchema = AuthenticationSchema.OAUTH - - -class OAuth2AuthenticationParams(BaseAuthenticationParams): - authentication_schema: AuthenticationSchema = AuthenticationSchema.OAUTH2 - - def build_connector_params(**kwargs) -> BaseAuthenticationParams: name = kwargs.get("name") - token = kwargs.get("token") - client_id = kwargs.get("client_id") - client_secret = kwargs.get("client_secret") connector_id = kwargs.get("connector_id") - if token: - args = BearerAuthenticationParams(name=name, token=token, connector_id=connector_id) - elif client_id and client_secret: - args = OAuthAuthenticationParams(name=name, client_id=client_id, client_secret=client_secret, connector_id=connector_id) - else: - args = OAuth2AuthenticationParams(name=name, connector_id=connector_id) - return args + return BaseAuthenticationParams(name=name, connector_id=connector_id) class Integration(Model): @@ -94,8 +70,10 @@ def __init__( ) self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL + self.authentication_methods = json.loads([item for item in additional_info['attributes'] if item['name'] == 'auth_schemes'][0]['code']) - def connect(self, args: Optional[BaseAuthenticationParams] = None, 
**kwargs) -> ModelResponse: + + def connect(self, authentication_schema: AuthenticationSchema, args: Optional[BaseAuthenticationParams] = None, data: Optional[Dict] = None, **kwargs) -> ModelResponse: """Connect to the integration Examples: @@ -119,29 +97,20 @@ def connect(self, args: Optional[BaseAuthenticationParams] = None, **kwargs) -> if args is None: args = build_connector_params(**kwargs) - authentication_schema = args.authentication_schema - if authentication_schema == AuthenticationSchema.BEARER: - return self.run( - { - "name": args.name, - "authScheme": authentication_schema.value, - "data": { - "token": args.token, - }, - } - ) - elif authentication_schema == AuthenticationSchema.OAUTH: - return self.run( - { - "name": args.name, - "authScheme": authentication_schema.value, - "data": { - "client_id": args.client_id, - "client_secret": args.client_secret, - }, - } - ) - elif authentication_schema == AuthenticationSchema.OAUTH2: + if authentication_schema.value not in self.authentication_methods: + raise ValueError(f"Authentication schema {authentication_schema.value} is not supported for this integration. Supported authentication methods: {self.authentication_methods}") + if data is None: + data = {} + required_params = json.loads([item for item in self.additional_info['attributes'] if item['name'] == authentication_schema.value + "-inputs"][0]['code']) + required_params_names = [param['name'] for param in required_params] + for param in required_params_names: + if param not in data: + if len(required_params_names) == 1: + raise ValueError(f"Parameter '{param}' is required for {self.name} {authentication_schema.value} authentication. Please provide the parameter in the data dictionary.") + else: + raise ValueError(f"Parameters {required_params_names} are required for {self.name} {authentication_schema.value} authentication. 
Please provide the parameters in the data dictionary.") + + if authentication_schema in [AuthenticationSchema.OAUTH2, AuthenticationSchema.OAUTH1, AuthenticationSchema.NO_AUTH]: response = self.run( { "name": args.name, @@ -153,6 +122,15 @@ def connect(self, args: Optional[BaseAuthenticationParams] = None, **kwargs) -> f"Before using the tool, please visit the following URL to complete the connection: {response.data['redirectURL']}" ) return response + else: + return self.run( + { + "name": args.name, + "authScheme": authentication_schema.value, + "data": data, + } + ) + def __repr__(self): try: From 5b6bcaaed7fb418a622abf8e181516c429cc020b Mon Sep 17 00:00:00 2001 From: Abdelrahman El-Sheikh <139810675+elsheikhams99@users.noreply.github.com> Date: Mon, 14 Jul 2025 15:25:57 +0300 Subject: [PATCH 57/62] Improve readability - fix OAUTH methods bug (#583) --- aixplain/modules/model/integration.py | 39 ++++++++++++++------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/aixplain/modules/model/integration.py b/aixplain/modules/model/integration.py index 5fc6c0e6..69e831d5 100644 --- a/aixplain/modules/model/integration.py +++ b/aixplain/modules/model/integration.py @@ -99,41 +99,42 @@ def connect(self, authentication_schema: AuthenticationSchema, args: Optional[Ba if authentication_schema.value not in self.authentication_methods: raise ValueError(f"Authentication schema {authentication_schema.value} is not supported for this integration. 
Supported authentication methods: {self.authentication_methods}") + if data is None: data = {} - required_params = json.loads([item for item in self.additional_info['attributes'] if item['name'] == authentication_schema.value + "-inputs"][0]['code']) - required_params_names = [param['name'] for param in required_params] - for param in required_params_names: - if param not in data: - if len(required_params_names) == 1: - raise ValueError(f"Parameter '{param}' is required for {self.name} {authentication_schema.value} authentication. Please provide the parameter in the data dictionary.") - else: - raise ValueError(f"Parameters {required_params_names} are required for {self.name} {authentication_schema.value} authentication. Please provide the parameters in the data dictionary.") - - if authentication_schema in [AuthenticationSchema.OAUTH2, AuthenticationSchema.OAUTH1, AuthenticationSchema.NO_AUTH]: - response = self.run( + + if authentication_schema not in [AuthenticationSchema.OAUTH2, AuthenticationSchema.OAUTH1, AuthenticationSchema.NO_AUTH]: + required_params = json.loads([item for item in self.additional_info['attributes'] if item['name'] == authentication_schema.value + "-inputs"][0]['code']) + required_params_names = [param['name'] for param in required_params] + for param in required_params_names: + if param not in data: + if len(required_params_names) == 1: + raise ValueError(f"Parameter '{param}' is required for {self.name} {authentication_schema.value} authentication. Please provide the parameter in the data dictionary.") + else: + raise ValueError(f"Parameters {required_params_names} are required for {self.name} {authentication_schema.value} authentication. 
Please provide the parameters in the data dictionary.") + return self.run( { "name": args.name, "authScheme": authentication_schema.value, + "data": data, } ) - if "redirectURL" in response.data: - warnings.warn( - f"Before using the tool, please visit the following URL to complete the connection: {response.data['redirectURL']}" - ) - return response else: - return self.run( + response = self.run( { "name": args.name, "authScheme": authentication_schema.value, - "data": data, } ) + if "redirectURL" in response.data: + warnings.warn( + f"Before using the tool, please visit the following URL to complete the connection: {response.data['redirectURL']}" + ) + return response def __repr__(self): try: return f"Integration: {self.name} by {self.supplier['name']} (id={self.id})" except Exception: - return f"Integration: {self.name} by {self.supplier} (id={self.id})" + return f"Integration: {self.name} by {self.supplier} (id={self.id})" \ No newline at end of file From 6ab59533f2b422cab78de77fc2f982c1458f4b58 Mon Sep 17 00:00:00 2001 From: Abdelrahman El-Sheikh <139810675+elsheikhams99@users.noreply.github.com> Date: Tue, 15 Jul 2025 16:39:56 +0300 Subject: [PATCH 58/62] Eng 2400 add tests, fix composio bugs (#585) * Fix bugs, and modify tests * Update tool_factory.py parameter desc --- aixplain/factories/model_factory/utils.py | 1 + aixplain/factories/tool_factory.py | 18 +++++++++++++++--- .../functional/agent/agent_functional_test.py | 8 ++++---- .../team_agent/team_agent_functional_test.py | 6 +++--- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index ca3e1eec..ed3409a3 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -118,6 +118,7 @@ def create_model_from_response(response: Dict) -> Model: supports_streaming=response.get("supportsStreaming", False), status=status, function_type=function_type, + 
attributes=attributes, **additional_kwargs, ) diff --git a/aixplain/factories/tool_factory.py b/aixplain/factories/tool_factory.py index 9db2a9ea..8f746422 100644 --- a/aixplain/factories/tool_factory.py +++ b/aixplain/factories/tool_factory.py @@ -4,12 +4,12 @@ from aixplain.factories.model_factory.mixins import ModelGetterMixin, ModelListMixin from aixplain.modules.model import Model from aixplain.modules.model.index_model import IndexModel -from aixplain.modules.model.integration import Integration +from aixplain.modules.model.integration import Integration, AuthenticationSchema from aixplain.modules.model.integration import BaseAuthenticationParams from aixplain.factories.index_factory.utils import BaseIndexParams, AirParams, VectaraParams, ZeroEntropyParams, GraphRAGParams from aixplain.enums.index_stores import IndexStores from aixplain.modules.model.utility_model import BaseUtilityModelParams -from typing import Optional, Text, Union +from typing import Optional, Text, Union, Dict from aixplain.enums import ResponseStatus from aixplain.utils import config @@ -22,6 +22,8 @@ def create( cls, integration: Optional[Union[Text, Model]] = None, params: Optional[Union[BaseUtilityModelParams, BaseIndexParams, BaseAuthenticationParams]] = None, + authentication_schema: Optional[AuthenticationSchema] = None, + data: Optional[Dict] = None, **kwargs, ) -> Model: """Factory method to create indexes, script models and connections @@ -143,7 +145,17 @@ def add(a: int, b: int) -> int: assert ( connector.function_type == FunctionType.INTEGRATION ), f"The model you are trying to connect ({connector.id}) to is not a connector." 
- response = connector.connect(params) + + assert authentication_schema is not None, "Please provide the authentication schema to use (authentication_schema parameter)" + assert isinstance(authentication_schema, AuthenticationSchema), "authentication_schema must be an instance of AuthenticationSchema" + + auth_data = data if data is not None else {} + if not auth_data: + for key, value in kwargs.items(): + if key not in ['name', 'connector_id']: + auth_data[key] = value + + response = connector.connect(authentication_schema, params, auth_data) assert response.status == ResponseStatus.SUCCESS, f"Failed to connect to {connector.id} - {response.error_message}" connection = cls.get(response.data["id"]) if "redirectURL" in response.data: diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 2783ffc3..69c9f002 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -732,13 +732,13 @@ class Response(BaseModel): def test_agent_with_action_tool(): from aixplain.modules.model.integration import AuthenticationSchema - connector = ModelFactory.get("67eff5c0e05614297caeef98") + connector = ModelFactory.get("686432941223092cb4294d3f") # connect - response = connector.connect(authentication_schema=AuthenticationSchema.BEARER, token=os.getenv("SLACK_TOKEN")) + response = connector.connect(authentication_schema=AuthenticationSchema.BEARER_TOKEN, data={"token": os.getenv("SLACK_TOKEN")}) connection_id = response.data["id"] connection = ModelFactory.get(connection_id) - connection.action_scope = [action for action in connection.actions if action.code == "SLACK_CHAT_POST_MESSAGE"] + connection.action_scope = [action for action in connection.actions if action.code == "SLACK_SENDS_A_MESSAGE_TO_A_SLACK_CHANNEL"] agent = AgentFactory.create( name="Test Agent", @@ -757,5 +757,5 @@ def test_agent_with_action_tool(): assert response is not None assert 
response["status"].lower() == "success" assert "helsinki" in response.data.output.lower() - assert "SLACK_CHAT_POST_MESSAGE" in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] + assert "SLACK_SENDS_A_MESSAGE_TO_A_SLACK_CHANNEL" in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] connection.delete() diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index f3852f60..5659bfa8 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -445,13 +445,13 @@ class Response(BaseModel): def test_team_agent_with_slack_connector(): from aixplain.modules.model.integration import AuthenticationSchema - connector = ModelFactory.get("67eff5c0e05614297caeef98") + connector = ModelFactory.get("686432941223092cb4294d3f") # connect - response = connector.connect(authentication_schema=AuthenticationSchema.BEARER, token=os.getenv("SLACK_TOKEN")) + response = connector.connect(authentication_schema=AuthenticationSchema.BEARER_TOKEN, data={"token": os.getenv("SLACK_TOKEN")}) connection_id = response.data["id"] connection = ModelFactory.get(connection_id) - connection.action_scope = [action for action in connection.actions if action.code == "SLACK_CHAT_POST_MESSAGE"] + connection.action_scope = [action for action in connection.actions if action.code == "SLACK_SENDS_A_MESSAGE_TO_A_SLACK_CHANNEL"] agent = AgentFactory.create( name="Test Agent", From c397aaacc81befb856e480665dc4dcee2e1ff774 Mon Sep 17 00:00:00 2001 From: Hadi Nasrallah <87204330+hadi-aix@users.noreply.github.com> Date: Tue, 15 Jul 2025 10:51:06 -0400 Subject: [PATCH 59/62] Add MCP Connection (#587) --- aixplain/enums/function_type.py | 1 + aixplain/factories/model_factory/utils.py | 2 +- aixplain/modules/model/connection.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git 
a/aixplain/enums/function_type.py b/aixplain/enums/function_type.py index ae6f8e79..b9655813 100644 --- a/aixplain/enums/function_type.py +++ b/aixplain/enums/function_type.py @@ -34,3 +34,4 @@ class FunctionType(Enum): INTEGRATION = "connector" CONNECTION = "connection" MCPSERVER = 'mcpserver' + MCPCONNECTION = "mcpconnection" diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index ed3409a3..5021ae18 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -70,7 +70,7 @@ def create_model_from_response(response: Dict) -> Model: ModelClass = IndexModel elif function_type == FunctionType.INTEGRATION: ModelClass = Integration - elif function_type == FunctionType.CONNECTION: + elif function_type == FunctionType.CONNECTION or function_type == FunctionType.MCPCONNECTION: ModelClass = ConnectionTool elif function == Function.UTILITIES: ModelClass = UtilityModel diff --git a/aixplain/modules/model/connection.py b/aixplain/modules/model/connection.py index 6d98f140..430289ea 100644 --- a/aixplain/modules/model/connection.py +++ b/aixplain/modules/model/connection.py @@ -53,7 +53,7 @@ def __init__( scope (Text, optional): action scope of the connection. Defaults to None. 
**additional_info: Any additional Model info to be saved """ - assert function_type == FunctionType.CONNECTION, "Connection only supports connection function" + assert function_type == FunctionType.CONNECTION or function_type == FunctionType.MCPCONNECTION, "Connection only supports connection function" super().__init__( id=id, name=name, From ba71caa2af462a88a406ccc36f49145fc9553f76 Mon Sep 17 00:00:00 2001 From: Abdelrahman El-Sheikh <139810675+elsheikhams99@users.noreply.github.com> Date: Wed, 16 Jul 2025 18:32:40 +0300 Subject: [PATCH 60/62] Fix integration tests in model_test.py (#590) --- tests/unit/model_test.py | 74 ++++++++++++++++++++++++++++++++++------ 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 9abc6fe0..94c7201c 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -34,7 +34,7 @@ from aixplain.modules.model.llm_model import LLM from aixplain.modules.model.index_model import IndexModel from aixplain.modules.model.utility_model import UtilityModel -from aixplain.modules.model.integration import Integration, AuthenticationSchema +from aixplain.modules.model.integration import Integration, AuthenticationSchema, build_connector_params from aixplain.modules.model.connection import ConnectionTool, ConnectAction @@ -675,6 +675,12 @@ def test_model_not_supports_streaming(mocker): "pricing": {"price": 10, "currency": "USD"}, "params": {}, "version": {"id": "1.0"}, + "attributes": [ + { + "name": "auth_schemes", + "code": '["BEARER_TOKEN", "API_KEY", "BASIC"]' + }, + ] }, Integration, ), @@ -736,14 +742,35 @@ def test_create_model_from_response(payload, expected_model_class): @pytest.mark.parametrize( - "authentication_schema, name, token, client_id, client_secret", + "authentication_schema, name, data", [ - (AuthenticationSchema.BEARER, "test-name", "test-token", None, None), - (AuthenticationSchema.OAUTH, "test-name", None, "test-client-id", "test-client-secret"), + 
(AuthenticationSchema.BEARER_TOKEN, "test-name", {"token": "test-token"}), + (AuthenticationSchema.API_KEY, "test-name", {"api_key": "test-api-key"}), + (AuthenticationSchema.BASIC, "test-name", {"username": "test-user", "password": "test-pass"}), ], ) -def test_connector_connect(mocker, authentication_schema, name, token, client_id, client_secret): +def test_connector_connect(mocker, authentication_schema, name, data): mocker.patch("aixplain.modules.model.integration.Integration.run", return_value={"id": "test-id"}) + additional_info = { + 'attributes': [ + { + 'name': 'auth_schemes', + 'code': '["BEARER_TOKEN", "API_KEY", "BASIC"]' + }, + { + 'name': 'BEARER_TOKEN-inputs', + 'code': '[{"name": "token"}]' + }, + { + 'name': 'API_KEY-inputs', + 'code': '[{"name": "api_key"}]' + }, + { + 'name': 'BASIC-inputs', + 'code': '[{"name": "username"}, {"name": "password"}]' + } + ] + } connector = Integration( id="connector-id", name="connector-name", @@ -752,11 +779,16 @@ def test_connector_connect(mocker, authentication_schema, name, token, client_id supplier="aiXplain", api_key="api_key", version={"id": "1.0"}, + **additional_info ) + args = build_connector_params(name=name) response = connector.connect( - authentication_schema=authentication_schema, name=name, token=token, client_id=client_id, client_secret=client_secret + authentication_schema=authentication_schema, + args=args, + data=data ) - assert response == {"id": "test-id"} + + assert response["id"] == "test-id" def test_connection_init_with_actions(mocker): @@ -863,14 +895,35 @@ def add(aaa: int, bbb: int) -> int: ) def get_mock(id): - if id == "67eff5c0e05614297caeef98": + additional_info = { + 'attributes': [ + { + 'name': 'auth_schemes', + 'code': '["BEARER_TOKEN", "API_KEY", "BASIC"]' + }, + { + 'name': 'BEARER_TOKEN-inputs', + 'code': '[{"name": "token"}]' + }, + { + 'name': 'API_KEY-inputs', + 'code': '[{"name": "api_key"}]' + }, + { + 'name': 'BASIC-inputs', + 'code': '[{"name": "username"}, {"name": 
"password"}]' + } + ] + } + if id == "686432941223092cb4294d3f": return Integration( - id="67eff5c0e05614297caeef98", + id="686432941223092cb4294d3f", name="test-name", function=Function.UTILITIES, function_type=FunctionType.INTEGRATION, api_key="api_key", version={"id": "1.0"}, + **additional_info ) elif id == "connection-id": return ConnectionTool( @@ -880,10 +933,11 @@ def get_mock(id): function_type=FunctionType.CONNECTION, api_key="api_key", version={"id": "1.0"}, + **additional_info ) mocker.patch("aixplain.factories.tool_factory.ToolFactory.get", side_effect=get_mock) - tool = ToolFactory.create(integration="67eff5c0e05614297caeef98", name="My Connector 1234", token="slack-token") + tool = ToolFactory.create(integration="686432941223092cb4294d3f", name="My Connector 1234", authentication_schema=AuthenticationSchema.BEARER_TOKEN, data={"token": "slack-token"}) assert isinstance(tool, ConnectionTool) assert tool.id == "connection-id" assert tool.name == "test-name" From 7ea5ec8491eea2371bf0d5ae3718231b81f8e2bf Mon Sep 17 00:00:00 2001 From: Hadi Nasrallah <87204330+hadi-aix@users.noreply.github.com> Date: Thu, 17 Jul 2025 09:13:27 -0400 Subject: [PATCH 61/62] Eng 2327 integrate mcp server (#592) * Add MCP connection * Add MCP connection --------- Co-authored-by: lucas-aixplain Co-authored-by: Abdelrahman El-Sheikh <139810675+elsheikhams99@users.noreply.github.com> --- aixplain/enums/function_type.py | 4 +- aixplain/factories/model_factory/utils.py | 3 + aixplain/modules/model/connection.py | 4 +- aixplain/modules/model/integration.py | 5 + aixplain/modules/model/mcp_connection.py | 102 ++++++++++++++++++ .../functional/agent/agent_functional_test.py | 31 ++++++ .../model/run_connect_model_test.py | 27 +++++ 7 files changed, 173 insertions(+), 3 deletions(-) create mode 100644 aixplain/modules/model/mcp_connection.py diff --git a/aixplain/enums/function_type.py b/aixplain/enums/function_type.py index b9655813..f09d87e2 100644 --- 
a/aixplain/enums/function_type.py +++ b/aixplain/enums/function_type.py @@ -33,5 +33,5 @@ class FunctionType(Enum): SEARCH = "search" INTEGRATION = "connector" CONNECTION = "connection" - MCPSERVER = 'mcpserver' - MCPCONNECTION = "mcpconnection" + MCP_CONNECTION = "mcpconnection" + MCPSERVER = "mcpserver" \ No newline at end of file diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 5021ae18..4e1f8f85 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -5,6 +5,7 @@ from aixplain.modules.model.index_model import IndexModel from aixplain.modules.model.integration import Integration from aixplain.modules.model.connection import ConnectionTool +from aixplain.modules.model.mcp_connection import MCPConnection from aixplain.modules.model.utility_model import UtilityModel from aixplain.modules.model.utility_model import UtilityModelInput from aixplain.enums import DataType, Function, FunctionType, Language, OwnershipType, Supplier, SortBy, SortOrder, AssetStatus @@ -72,6 +73,8 @@ def create_model_from_response(response: Dict) -> Model: ModelClass = Integration elif function_type == FunctionType.CONNECTION or function_type == FunctionType.MCPCONNECTION: ModelClass = ConnectionTool + elif function_type == FunctionType.MCP_CONNECTION: + ModelClass = MCPConnection elif function == Function.UTILITIES: ModelClass = UtilityModel inputs = [ diff --git a/aixplain/modules/model/connection.py b/aixplain/modules/model/connection.py index 430289ea..3e73574d 100644 --- a/aixplain/modules/model/connection.py +++ b/aixplain/modules/model/connection.py @@ -53,7 +53,9 @@ def __init__( scope (Text, optional): action scope of the connection. Defaults to None. 
**additional_info: Any additional Model info to be saved """ - assert function_type == FunctionType.CONNECTION or function_type == FunctionType.MCPCONNECTION, "Connection only supports connection function" + assert ( + function_type == FunctionType.CONNECTION or function_type == FunctionType.MCP_CONNECTION + ), "Connection only supports connection function" super().__init__( id=id, name=name, diff --git a/aixplain/modules/model/integration.py b/aixplain/modules/model/integration.py index 69e831d5..0e4c5ee5 100644 --- a/aixplain/modules/model/integration.py +++ b/aixplain/modules/model/integration.py @@ -94,6 +94,9 @@ def connect(self, authentication_schema: AuthenticationSchema, args: Optional[Ba id: Connection ID (retrieve it with ModelFactory.get(id)) redirectUrl: Redirect URL to complete the connection (only for OAuth2) """ + if self.id == "68549a33ba00e44f357896f1": + return self.run({"data": kwargs.get("data")}) + if args is None: args = build_connector_params(**kwargs) @@ -131,6 +134,8 @@ def connect(self, authentication_schema: AuthenticationSchema, args: Optional[Ba f"Before using the tool, please visit the following URL to complete the connection: {response.data['redirectURL']}" ) return response + else: + raise ValueError(f"Invalid authentication schema: {authentication_schema}") def __repr__(self): diff --git a/aixplain/modules/model/mcp_connection.py b/aixplain/modules/model/mcp_connection.py new file mode 100644 index 00000000..19155694 --- /dev/null +++ b/aixplain/modules/model/mcp_connection.py @@ -0,0 +1,102 @@ +from aixplain.enums import Function, Supplier, FunctionType, ResponseStatus +from aixplain.modules.model.connection import ConnectionTool +from aixplain.modules.model import Model +from typing import Text, Optional, Union, Dict, List + + +class ConnectAction: + name: Text + description: Text + code: Optional[Text] = None + inputs: Optional[Dict] = None + + def __init__(self, name: Text, description: Text, code: Optional[Text] = None, inputs: 
Optional[Dict] = None): + self.name = name + self.description = description + self.code = code + self.inputs = inputs + + def __repr__(self): + return f"Action(code={self.code}, name={self.name})" + + +class MCPConnection(ConnectionTool): + actions: List[ConnectAction] + action_scope: Optional[List[ConnectAction]] = None + + def __init__( + self, + id: Text, + name: Text, + description: Text = "", + api_key: Optional[Text] = None, + supplier: Union[Dict, Text, Supplier, int] = "aiXplain", + version: Optional[Text] = None, + function: Optional[Function] = None, + is_subscribed: bool = False, + cost: Optional[Dict] = None, + function_type: Optional[FunctionType] = FunctionType.CONNECTION, + **additional_info, + ) -> None: + """Connection Init + + Args: + id (Text): ID of the Model + name (Text): Name of the Model + description (Text, optional): description of the model. Defaults to "". + api_key (Text, optional): API key of the Model. Defaults to None. + supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". + version (Text, optional): version of the model. Defaults to "1.0". + function (Function, optional): model AI function. Defaults to None. + is_subscribed (bool, optional): Is the user subscribed. Defaults to False. + cost (Dict, optional): model price. Defaults to None. + scope (Text, optional): action scope of the connection. Defaults to None. 
+ **additional_info: Any additional Model info to be saved + """ + assert function_type == FunctionType.MCP_CONNECTION, "Connection only supports mcp connection function" + super().__init__( + id=id, + name=name, + description=description, + supplier=supplier, + version=version, + cost=cost, + function=function, + is_subscribed=is_subscribed, + api_key=api_key, + function_type=function_type, + **additional_info, + ) + + def _get_actions(self): + response = Model.run(self, {"action": "LIST_TOOLS", "data": " "}) + if response.status == ResponseStatus.SUCCESS: + return [ + ConnectAction(name=action["name"], description=action["description"], code=action["name"]) + for action in response.data + ] + raise Exception( + f"It was not possible to get the actions for the connection {self.id}. Error {response.error_code}: {response.error_message}" + ) + + def get_action_inputs(self, action: Union[ConnectAction, Text]): + if action.inputs: + return action.inputs + + if isinstance(action, ConnectAction): + action = action.code + + response = Model.run(self, {"action": "LIST_TOOLS", "data": {"actions": [action]}}) + if response.status == ResponseStatus.SUCCESS: + try: + inputs = {inp["code"]: inp for inp in response.data[0]["inputs"]} + action_idx = next((i for i, a in enumerate(self.actions) if a.code == action), None) + if action_idx is not None: + self.actions[action_idx].inputs = inputs + return inputs + except Exception as e: + raise Exception(f"It was not possible to get the inputs for the action {action}. Error {e}") + + raise Exception( + f"It was not possible to get the inputs for the action {action}. 
Error {response.error_code}: {response.error_message}" + ) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 69c9f002..db1f88db 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -759,3 +759,34 @@ def test_agent_with_action_tool(): assert "helsinki" in response.data.output.lower() assert "SLACK_SENDS_A_MESSAGE_TO_A_SLACK_CHANNEL" in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] connection.delete() + + +def test_agent_with_mcp_tool(): + connector = ModelFactory.get("68549a33ba00e44f357896f1") + # connect + response = connector.connect( + data="https://mcp.zapier.com/api/mcp/s/OTJiMjVlYjEtMGE4YS00OTVjLWIwMGYtZDJjOGVkNTc4NjFkOjI0MTNjNzg5LWZlNGMtNDZmNC05MDhmLWM0MGRlNDU4ZmU1NA==/mcp" + ) + connection_id = response.data["id"] + connection = ModelFactory.get(connection_id) + action_name = "SLACK_SEND_CHANNEL_MESSAGE".lower() + connection.action_scope = [action for action in connection.actions if action.code == action_name] + + agent = AgentFactory.create( + name="Test Agent", + description="This agent is used to send messages to Slack", + instructions="You are a helpful assistant that can send messages to Slack. You MUST use the tool to send the message.", + llm_id="669a63646eb56306647e1091", + tools=[ + connection, + ], + ) + + response = agent.run( + "Send what is the capital of Finland on Slack to channel of #modelserving-alerts-testing. Add the name of the capital in the final answer." 
+ ) + assert response is not None + assert response["status"].lower() == "success" + assert "helsinki" in response.data.output.lower() + assert action_name in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] + connection.delete() diff --git a/tests/functional/model/run_connect_model_test.py b/tests/functional/model/run_connect_model_test.py index 3f89b4ad..2b3124b7 100644 --- a/tests/functional/model/run_connect_model_test.py +++ b/tests/functional/model/run_connect_model_test.py @@ -3,6 +3,7 @@ from aixplain.factories import ModelFactory from aixplain.modules.model.integration import Integration, AuthenticationSchema from aixplain.modules.model.connection import ConnectionTool +from aixplain.modules.model.mcp_connection import MCPConnection def test_run_connect_model(): @@ -28,3 +29,29 @@ def test_run_connect_model(): action = action[0] response = connection.run(action, {"text": "This is a test!", "channel": "C084G435LR5"}) assert response.status == ResponseStatus.SUCCESS + + +def test_run_mcp_connect_model(): + # get slack connector + connector = ModelFactory.get("68549a33ba00e44f357896f1") + + assert isinstance(connector, Integration) + assert connector.id == "68549a33ba00e44f357896f1" + assert connector.name == "" + + url = "https://mcp.zapier.com/api/mcp/s/OTJiMjVlYjEtMGE4YS00OTVjLWIwMGYtZDJjOGVkNTc4NjFkOjI0MTNjNzg5LWZlNGMtNDZmNC05MDhmLWM0MGRlNDU4ZmU1NA==" + response = connector.connect(data=url) + assert response.status == ResponseStatus.SUCCESS + assert "id" in response.data + connection_id = response.data["id"] + # get slack connection + connection = ModelFactory.get(connection_id) + assert isinstance(connection, MCPConnection) + assert connection.id == connection_id + assert connection.actions is not None + + action = [action for action in connection.actions if action.code == "SLACK_CHAT_POST_MESSAGE"] + assert len(action) > 0 + action = action[0] + response = connection.run(action, {"text": "This is a test!", "channel": 
"C084G435LR5"}) + assert response.status == ResponseStatus.SUCCESS From 926927ce9bf717ec64d926edda937f74b5e4f3e4 Mon Sep 17 00:00:00 2001 From: Hadi Date: Thu, 17 Jul 2025 09:23:41 -0400 Subject: [PATCH 62/62] Update utils.py --- aixplain/factories/model_factory/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 4e1f8f85..9f32e51f 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -71,7 +71,7 @@ def create_model_from_response(response: Dict) -> Model: ModelClass = IndexModel elif function_type == FunctionType.INTEGRATION: ModelClass = Integration - elif function_type == FunctionType.CONNECTION or function_type == FunctionType.MCPCONNECTION: + elif function_type == FunctionType.CONNECTION : ModelClass = ConnectionTool elif function_type == FunctionType.MCP_CONNECTION: ModelClass = MCPConnection