From a7ac7adad29112d1e40a4c8295302aa43f876d2d Mon Sep 17 00:00:00 2001 From: kadirpekel Date: Fri, 11 Apr 2025 17:37:40 +0200 Subject: [PATCH 01/22] ENG-1852: Hotfix, added fail-fast option (#488) --- .github/workflows/main.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index b939a3a0..faf380eb 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -12,6 +12,7 @@ jobs: setup-and-test: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: test-suite: [ 'tests/unit', From b0887bd19b05cd4c610711bb1089d9b6826ea8ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Wed, 16 Apr 2025 11:47:54 +0300 Subject: [PATCH 02/22] Fix: BUG-503 failed to update utility tool after get (#491) --- aixplain/factories/model_factory/utils.py | 15 +++++- tests/unit/utility_test.py | 58 ++++++++++++++++++++++- 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index c5a90367..c10e4369 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -10,6 +10,7 @@ from datetime import datetime from typing import Dict, Union, List, Optional, Tuple from urllib.parse import urljoin +import requests def create_model_from_response(response: Dict) -> Model: @@ -39,10 +40,13 @@ def create_model_from_response(response: Dict) -> Model: function_input_params, function_output_params = function.get_input_output_params() model_params = {param["name"]: param for param in response["params"]} + code = response.get("code", "") + inputs, temperature = [], None input_params, output_params = function_input_params, function_output_params ModelClass = Model + if function == Function.TEXT_GENERATION: ModelClass = LLM f = [p for p in response.get("params", []) if p["name"] == "temperature"] @@ -58,6 +62,15 @@ def create_model_from_response(response: Dict) -> Model: ] input_params = model_params + if not code: + if "version" in response and response["version"]: + version_link = response["version"]["id"] + if version_link: + version_content = requests.get(version_link).text + code = version_content + else: + raise Exception("Utility Model Error: Code not found") + created_at = None if "createdAt" in response and response["createdAt"]: created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00")) @@ -66,7 +79,7 @@ def create_model_from_response(response: Dict) -> Model: response["id"], response["name"], description=response.get("description", ""), - code=response.get("code", ""), + code=code if code else "", supplier=response["supplier"], api_key=response["api_key"], cost=response["pricing"], diff --git a/tests/unit/utility_test.py b/tests/unit/utility_test.py index e678fa7a..9aabd831 100644 --- a/tests/unit/utility_test.py +++ b/tests/unit/utility_test.py @@ -7,7 +7,7 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput from aixplain.modules.model.utils import parse_code, parse_code_decorated -from unittest.mock import patch +from unittest.mock import patch, MagicMock import warnings @@ -507,3 +507,59 @@ def test_utility_model_status_after_deployment(): warnings.simplefilter("always") utility_model.validate() assert len(w) == 0 + + +def test_concat_strings(): + """Test the concat_strings function directly.""" + assert concat_strings("Hello, ", "World!") == "Hello, World!" + assert concat_strings("", "") == "" + assert concat_strings("123", "456") == "123456" + + +def concat_strings(str1: str, str2: str): + """Concatenates two strings. + + Args: + str1: The first string. + str2: The second string. + + Returns: + The concatenated string. + """ + return str1 + str2 + + +@patch("aixplain.factories.ModelFactory") +def test_create_and_deploy_utility_model(mock_model_factory): + """Test creating and deploying a utility model with mocked backend requests.""" + # Mock the create_utility_model method + mock_utility_model = MagicMock() + mock_utility_model.id = "mock-utility-model-id" + mock_model_factory.create_utility_model.return_value = mock_utility_model + + # Mock the get method + mock_model = MagicMock() + mock_model_factory.get.return_value = mock_model + + # Create utility model + from aixplain.factories import ModelFactory + + utility_model = ModelFactory.create_utility_model( + name="concat_strings", + code=concat_strings, + ) + + # Assert create_utility_model was called with correct parameters + mock_model_factory.create_utility_model.assert_called_once() + + # Get the model with mocked id + model = ModelFactory.get(utility_model.id) + + # Assert get was called with the correct id + mock_model_factory.get.assert_called_once() + + # Deploy the model + model.deploy() + + # Assert deploy was called + mock_model.deploy.assert_called_once() From 89daf2f27a1b0e24a59221abadf1a7899d39a411 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Wed, 16 Apr 2025 19:26:29 +0300 Subject: [PATCH 03/22] Bug-503-fix Update (#493) * Fix: BUG-503 failed to update utility tool after getFix: BUG-503 failed to update utility tool after get * fix if not url in code --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> --- aixplain/factories/model_factory/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index c10e4369..1be2186c 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -66,8 +66,11 @@ def create_model_from_response(response: Dict) -> Model: if "version" in response and response["version"]: version_link = response["version"]["id"] if version_link: - version_content = requests.get(version_link).text - code = version_content + try: + version_content = requests.get(version_link).text + code = version_content + except Exception: + code = "" else: raise Exception("Utility Model Error: Code not found") From 6bf7d00510b338ac6c54c62d2081c882e5f5b51f Mon Sep 17 00:00:00 2001 From: Zaina Abu Shaban Date: Thu, 17 Apr 2025 00:58:28 +0300 Subject: [PATCH 04/22] ENG-1920: added sentry cred (#476) * added sentry cred * removed condition and added dependency * Fixes on sentry setting --------- Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> --- aixplain/utils/config.py | 14 +++++++++++--- pyproject.toml | 1 + 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/aixplain/utils/config.py b/aixplain/utils/config.py index 6dc2bb49..6cb15775 100644 --- a/aixplain/utils/config.py +++ b/aixplain/utils/config.py @@ -16,17 +16,17 @@ import os import logging +import sentry_sdk logger = logging.getLogger(__name__) BACKEND_URL = os.getenv("BACKEND_URL", "https://platform-api.aixplain.com") -MODELS_RUN_URL = os.getenv( - "MODELS_RUN_URL", "https://models.aixplain.com/api/v1/execute" -) +MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com/api/v1/execute") # GET THE API KEY FROM CMD TEAM_API_KEY = os.getenv("TEAM_API_KEY", "") AIXPLAIN_API_KEY = os.getenv("AIXPLAIN_API_KEY", "") +ENV = "dev" if "dev" in BACKEND_URL else "test" if "test" in BACKEND_URL else "prod" if not TEAM_API_KEY and not AIXPLAIN_API_KEY: raise Exception( @@ -47,3 +47,11 @@ MODEL_API_KEY = os.getenv("MODEL_API_KEY", "") LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") HF_TOKEN = os.getenv("HF_TOKEN", "") +SENTRY_DSN = os.getenv("SENTRY_DSN") + +if SENTRY_DSN: + sentry_sdk.init( + dsn=SENTRY_DSN, + environment=ENV, + send_default_pii=True, + ) diff --git a/pyproject.toml b/pyproject.toml index ef57b89f..422ea158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ "PyYAML>=6.0.1", "dataclasses-json>=0.5.2", "Jinja2==3.1.6", + "sentry-sdk>=1.0.0" ] [project.urls] From 9e2b03c7c9ac885e22d6e8dc0699747369b28f2b Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Thu, 17 Apr 2025 17:59:35 -0300 Subject: [PATCH 05/22] Read input variables correctly (#496) --- aixplain/modules/agent/__init__.py | 2 +- .../functional/agent/agent_functional_test.py | 26 +++++++++++++++++++ tests/unit/agent/agent_test.py | 2 +- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index ecba14d9..38b8a7ae 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -298,7 +298,7 @@ def run_async( headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} # build query - input_data = process_variables(query, data, parameters, self.description) + input_data = process_variables(query, data, parameters, self.instructions) payload = { "id": self.id, diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 4e9152bb..151c4a55 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -458,3 +458,29 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): os.remove("test.csv") os.remove("test.db") agent.delete() + + +@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) +def test_instructions(delete_agents_and_team_agents, AgentFactory): + assert delete_agents_and_team_agents + + agent = AgentFactory.create( + name="Test Agent", + description="Test description", + instructions="Always respond with '{magic_word}' does not matter what you are prompted for.", + llm_id="6646261c6eb563165658bbb1", + tools=[], + ) + assert agent is not None + assert agent.status == AssetStatus.DRAFT + + agent = AgentFactory.get(agent.id) + assert agent is not None + response = agent.run(data={"magic_word": "aixplain", "query": "What is the capital of France?"}) + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "data" in response + assert response["data"]["session_id"] is not None + assert response["data"]["output"] is not None + assert "aixplain" in response["data"]["output"].lower() diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index b418172e..f767f0ec 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -399,7 +399,7 @@ def test_run_success(): def test_run_variable_error(): - agent = Agent("123", "Test Agent", "Translate the input data into {target_language}", "Test Agent Role") + agent = Agent("123", "Test Agent", "Agent description", "Translate the input data into {target_language}") with pytest.raises(Exception) as exc_info: agent.run_async(data={"query": "Hello, how are you?"}, output_format=OutputFormat.MARKDOWN) assert str(exc_info.value) == ( From 409a89315c2263e5aaa7a874091b9f7d1eadba2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Fri, 18 Apr 2025 01:17:53 +0300 Subject: [PATCH 06/22] Fix: file not deleted in test_sql_tool_with_csv test (#492) Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> --- .../functional/agent/agent_functional_test.py | 213 +++++++++--------- 1 file changed, 112 insertions(+), 101 deletions(-) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 151c4a55..2313234e 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -350,115 +350,126 @@ def test_specific_model_parameters_e2e(tool_config): @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents - import os - - # Create test SQLite database - with open("ftest.db", "w") as f: - f.write("") - - tool = AgentFactory.create_sql_tool( - description="Execute an SQL query and return the result", source="ftest.db", source_type="sqlite", enable_commit=True - ) - assert tool is not None - assert tool.description == "Execute an SQL query and return the result" - - agent = AgentFactory.create( - name="Teste", - description="You are a test agent that search for employee information in a database", - tools=[tool], - ) - assert agent is not None - - response = agent.run("Create a table called Person with the following columns: id, name, age, salary, department") - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - - response = agent.run("Insert the following data into the Person table: 1, Eve, 30, 50000, Sales") - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - - response = agent.run("What is the name of the employee with the highest salary?") - assert response is not None - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "eve" in str(response["data"]["output"]).lower() + try: + import os + + # Create test SQLite database + with open("ftest.db", "w") as f: + f.write("") + + tool = AgentFactory.create_sql_tool( + description="Execute an SQL query and return the result", + source="ftest.db", + source_type="sqlite", + enable_commit=True, + ) + assert tool is not None + assert tool.description == "Execute an SQL query and return the result" - os.remove("ftest.db") - agent.delete() + agent = AgentFactory.create( + name="Teste", + description="You are a test agent that search for employee information in a database", + tools=[tool], + ) + assert agent is not None + + response = agent.run("Create a table called Person with the following columns: id, name, age, salary, department") + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + response = agent.run("Insert the following data into the Person table: 1, Eve, 30, 50000, Sales") + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + response = agent.run("What is the name of the employee with the highest salary?") + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "eve" in str(response["data"]["output"]).lower() + finally: + os.remove("ftest.db") + agent.delete() @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents + try: + import os + import pandas as pd + + # remove test.csv if it exists + if os.path.exists("test.csv"): + os.remove("test.csv") + + # remove test.db if it exists + if os.path.exists("test.db"): + os.remove("test.db") + + # Create a more comprehensive test dataset + df = pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "David", "Eve"], + "department": ["Sales", "IT", "Sales", "Marketing", "IT"], + "salary": [75000, 85000, 72000, 68000, 90000], + } + ) + df.to_csv("test.csv", index=False) - import pandas as pd - - # Create a more comprehensive test dataset - df = pd.DataFrame( - { - "id": [1, 2, 3, 4, 5], - "name": ["Alice", "Bob", "Charlie", "David", "Eve"], - "department": ["Sales", "IT", "Sales", "Marketing", "IT"], - "salary": [75000, 85000, 72000, 68000, 90000], - } - ) - df.to_csv("test.csv", index=False) - - # Create SQL tool from CSV - tool = AgentFactory.create_sql_tool( - description="Execute SQL queries on employee data", source="test.csv", source_type="csv", tables=["employees"] - ) - - # Verify tool setup - assert tool is not None - assert tool.description == "Execute SQL queries on employee data" - assert tool.database.endswith(".db") - assert tool.tables == ["employees"] - assert ( - tool.schema - == 'CREATE TABLE employees (\n "id" INTEGER, "name" TEXT, "department" TEXT, "salary" INTEGER\n )' # noqa: W503 - ) - assert not tool.enable_commit # must be False by default - - # Create an agent with the SQL tool - agent = AgentFactory.create( - name="SQL Query Agent", - description="I am an agent that helps query employee information from a database.", - instructions="Help users query employee information from the database. Use SQL queries to get the requested information.", - tools=[tool], - ) - assert agent is not None - - # Test 1: Basic SELECT query - response = agent.run("Who are all the employees in the Sales department?") - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "alice" in response["data"]["output"].lower() - assert "charlie" in response["data"]["output"].lower() - - # Test 2: Aggregation query - response = agent.run("What is the average salary in each department?") - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "sales" in response["data"]["output"].lower() - assert "it" in response["data"]["output"].lower() - assert "marketing" in response["data"]["output"].lower() - - # Test 3: Complex query with conditions - response = agent.run("Who is the highest paid employee in the IT department?") - assert response["completed"] is True - assert response["status"].lower() == "success" - assert "eve" in response["data"]["output"].lower() - - import os - - # Cleanup - os.remove("test.csv") - os.remove("test.db") - agent.delete() + # Create SQL tool from CSV + tool = AgentFactory.create_sql_tool( + description="Execute SQL queries on employee data", source="test.csv", source_type="csv", tables=["employees"] + ) + # Verify tool setup + assert tool is not None + assert tool.description == "Execute SQL queries on employee data" + assert tool.database.endswith(".db") + assert tool.tables == ["employees"] + assert ( + tool.schema + == 'CREATE TABLE employees (\n "id" INTEGER, "name" TEXT, "department" TEXT, "salary" INTEGER\n )' # noqa: W503 + ) + assert not tool.enable_commit # must be False by default + + # Create an agent with the SQL tool + agent = AgentFactory.create( + name="SQL Query Agent", + description="I am an agent that helps query employee information from a database.", + instructions="Help users query employee information from the database. Use SQL queries to get the requested information.", + tools=[tool], + ) + assert agent is not None + + # Test 1: Basic SELECT query + response = agent.run("Who are all the employees in the Sales department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "alice" in response["data"]["output"].lower() + assert "charlie" in response["data"]["output"].lower() + + # Test 2: Aggregation query + response = agent.run("What is the average salary in each department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "sales" in response["data"]["output"].lower() + assert "it" in response["data"]["output"].lower() + assert "marketing" in response["data"]["output"].lower() + + # Test 3: Complex query with conditions + response = agent.run("Who is the highest paid employee in the IT department?") + assert response["completed"] is True + assert response["status"].lower() == "success" + assert "eve" in response["data"]["output"].lower() + + finally: + # Cleanup + os.remove("test.csv") + os.remove("test.db") + agent.delete() @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_instructions(delete_agents_and_team_agents, AgentFactory): From 3acfacd20cb2e87330979f0a412981ce29f90a2b Mon Sep 17 00:00:00 2001 From: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Date: Thu, 17 Apr 2025 19:18:36 -0300 Subject: [PATCH 07/22] ENG-2007: Fix agent and team agent parametrized functional test - Sync with dev (#495) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix general assets and api key functional tests (#468) * Update return type so it works with python 3.8 (#390) * ENG-1557: model params are now available for designer assets (#387) * ENG-1559: Fixed designer tests (#388) * Use input api key to list models when given (#395) * BUG-375 new functional test regarding ensuring failure (#396) * Role 2 Instructions (#393) * added validate check when s3 link (#399) * MErge to prod (#340) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Merge dev to test (#107) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Development to Test (#109) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> * Merge to test (#111) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Update Finetuner functional tests (#112) * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#118) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira * Merge to test (#124) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira * Update Finetuner hyperparameters (#125) * Update Finetuner hyperparameters * Change hyperparameters error message * Merge dev to test (#126) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira * Update Finetuner hyperparameters (#125) * Update Finetuner hyperparameters * Change hyperparameters error message --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira * Add new LLMs finetuner models (mistral and solar) (#128) * Merge dev to test (#129) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira * Update Finetuner hyperparameters (#125) * Update Finetuner hyperparameters * Change hyperparameters error message * Add new LLMs finetuner models (mistral and solar) (#128) --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira * Enabling dataset ID and model ID as parameters for finetuner creation (#131) Co-authored-by: Thiago Castro Ferreira * Fix supplier representation of a model (#132) * Fix supplier representation of a model * Fixing parameter typing --------- Co-authored-by: Thiago Castro Ferreira * Fixing indentation in documentation sample code (#134) Co-authored-by: Thiago Castro Ferreira * Merge to test (#135) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira * Update Finetuner hyperparameters (#125) * Update Finetuner hyperparameters * Change hyperparameters error message * Add new LLMs finetuner models (mistral and solar) (#128) * Enabling dataset ID and model ID as parameters for finetuner creation (#131) Co-authored-by: Thiago Castro Ferreira * Fix supplier representation of a model (#132) * Fix supplier representation of a model * Fixing parameter typing --------- Co-authored-by: Thiago Castro Ferreira * Fixing indentation in documentation sample code (#134) Co-authored-by: Thiago Castro Ferreira --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira * Update FineTune unit and functional tests (#136) * Merge dev to test (#137) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#117) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Do not download textual URLs (#120) * Do not download textual URLs * Treat as string --------- Co-authored-by: Thiago Castro Ferreira * Enable api key parameter in data asset creation (#122) Co-authored-by: Thiago Castro Ferreira * Update Finetuner hyperparameters (#125) * Update Finetuner hyperparameters * Change hyperparameters error message * Add new LLMs finetuner models (mistral and solar) (#128) * Enabling dataset ID and model ID as parameters for finetuner creation (#131) Co-authored-by: Thiago Castro Ferreira * Fix supplier representation of a model (#132) * Fix supplier representation of a model * Fixing parameter typing --------- Co-authored-by: Thiago Castro Ferreira * Fixing indentation in documentation sample code (#134) Co-authored-by: Thiago Castro Ferreira * Update FineTune unit and functional tests (#136) --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira * Click fix (#140) * Merge to prod (#119) * Merge dev to test (#107) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Development to Test (#109) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> * Merge to test (#111) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) --------- Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Merge dev to test (#113) * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Update Finetuner functional tests (#112) --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> * Hf deployment test (#114) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Hf deployment test (#118) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Create bounds for FineTune hyperparameters (#103) * Test bound to hyperparameters * Update finetune llm hyperparameters * Remove option to use PEFT, always on use now * Fixing pipeline general asset test (#106) * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain * Update Finetuner functional tests (#112) * Hf deployment test (#115) * Started adding Hugging Face deployment to aiXplain SDK Signed-off-by: mikelam-us-aixplain * Added model status function to SDK Signed-off-by: mikelam-us-aixplain * Updating Signed-off-by: mikelam-us-aixplain * Updated CLI Signed-off-by: mikelam-us * Adding CLI Signed-off-by: mikelam-us-aixplain * Corrected request error Signed-off-by: mikelam-us-aixplain * Clearing out unnecessary information in return Signed-off-by: mikelam-us-aixplain * Adding status Signed-off-by: mikelam-us-aixplain * Simplifying status Signed-off-by: mikelam-us-aixplain * Adding tests and correcting tokens Signed-off-by: mikelam-us-aixplain * Added bad repo ID test Signed-off-by: mikelam-us-aixplain * Finished rough draft of tests Signed-off-by: mikelam-us-aixplain * Adding tests Signed-off-by: mikelam-us-aixplain * Fixing hf token Signed-off-by: mikelam-us-aixplain * Adding hf token Signed-off-by: mikelam-us-aixplain * Correcting first test Signed-off-by: mikelam-us-aixplain * Testing Signed-off-by: mikelam-us-aixplain * Adding config Signed-off-by: mikelam-us-aixplain * Added user doc Signed-off-by: mikelam-us-aixplain * Added gated model test Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us * Adding HF token Signed-off-by: mikelam-us-aixplain --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Co-authored-by: Lucas Pavanelli <86805709+lucas-aixplain@users.noreply.github.com… * ENG-1851: Update test and main workflows to run unit tests (#472) * Update test and main workflows to run iunit tests * Update to install all test packages listed in pyproject.toml * Add pytest-mock to list of test dependencies * Update workflow names * Add continue on error to each test and fix typos * Fix team agent parametrized functional test (#470) * ENG-1852: Hotfix, added fail-fast option (#488) (#489) Co-authored-by: kadirpekel * Fix: file not deleted in test_sql_tool_with_csv test * Fix agent and team agent parametrized functional test * Update to use draft agent --------- Signed-off-by: mikelam-us-aixplain Signed-off-by: mikelam-us Signed-off-by: Michael Lam Signed-off-by: root Co-authored-by: kadirpekel Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: Ahmet Gündüz Co-authored-by: ikxplain <88332269+ikxplain@users.noreply.github.com> Co-authored-by: mikelam-us-aixplain <131073216+mikelam-us-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira Co-authored-by: Shreyas Sharma <85180538+shreyasXplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira Co-authored-by: Thiago Castro Ferreira Co-authored-by: Hadi Nasrallah <87204330+hadi-aix@users.noreply.github.com> Co-authored-by: kadir pekel Co-authored-by: root Co-authored-by: Zaina Abu Shaban Co-authored-by: xainaz Co-authored-by: xainaz Co-authored-by: Lucas Pavanelli Co-authored-by: ikxplain Co-authored-by: OsujiCC Co-authored-by: Yunsu Kim Co-authored-by: Yunsu Kim Co-authored-by: Muhammad-Elmallah <145364766+Muhammad-Elmallah@users.noreply.github.com> Co-authored-by: ahmetgunduz --- tests/functional/agent/agent_functional_test.py | 3 ++- tests/functional/team_agent/team_agent_functional_test.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 2313234e..0026d65a 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -300,7 +300,8 @@ def test_update_tools_of_agent(run_input_map, delete_agents_and_team_agents, Age }, ], ) -def test_specific_model_parameters_e2e(tool_config): +def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agents): + assert delete_agents_and_team_agents """Test end-to-end agent execution with specific model parameters""" # Create tool based on config if tool_config["type"] == "search": diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index 3cabad8c..cb5f80a9 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -301,6 +301,7 @@ def test_team_agent_with_parameterized_agents(run_input_map, delete_agents_and_t llm_id=run_input_map["llm_id"], tools=[search_tool], ) + search_agent.deploy() # Create second agent with translation tool translation_function = Function.TRANSLATION @@ -318,7 +319,7 @@ def test_team_agent_with_parameterized_agents(run_input_map, delete_agents_and_t llm_id=run_input_map["llm_id"], tools=[translation_tool], ) - + translation_agent.deploy() team_agent = create_team_agent( TeamAgentFactory, [search_agent, translation_agent], run_input_map, use_mentalist=True, use_inspector=True ) From da7ec38ae161f198b2ac273aa5d92f734e5e0d95 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <50206820+basitanees@users.noreply.github.com> Date: Wed, 23 Apr 2025 12:24:26 +0300 Subject: [PATCH 08/22] ENG-1789: Add multiple index backbones support (#443) * add vectara support * add GraphRAG support * restructure and use Generics * convert to abstract class, fix param name for llm * update functional test for new param style * collection name param name changed * add backwards compatibility * Adding a warning for deprecation and some tests * add option for embedding size in embedding params * add ZeroEntropy support --------- Co-authored-by: Thiago Castro Ferreira --- aixplain/enums/__init__.py | 1 + aixplain/enums/embedding_model.py | 2 +- aixplain/enums/index_stores.py | 14 +++ aixplain/factories/index_factory.py | 50 --------- aixplain/factories/index_factory/__init__.py | 102 +++++++++++++++++++ aixplain/factories/index_factory/utils.py | 66 ++++++++++++ tests/functional/model/run_model_test.py | 92 ++++++++++------- tests/unit/index_model_test.py | 35 ++++++- 8 files changed, 273 insertions(+), 89 deletions(-) create mode 100644 aixplain/enums/index_stores.py delete mode 100644 aixplain/factories/index_factory.py create mode 100644 aixplain/factories/index_factory/__init__.py create mode 100644 aixplain/factories/index_factory/utils.py diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 4f0364e1..e80c03c6 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -18,3 +18,4 @@ from .database_source import DatabaseSourceType from .embedding_model import EmbeddingModel from .asset_status import AssetStatus +from .index_stores import IndexStores diff --git a/aixplain/enums/embedding_model.py b/aixplain/enums/embedding_model.py index dbd6b9c1..c52387b2 100644 --- a/aixplain/enums/embedding_model.py +++ b/aixplain/enums/embedding_model.py @@ -20,7 +20,7 @@ from enum import Enum -class EmbeddingModel(Enum): +class EmbeddingModel(str, Enum): SNOWFLAKE_ARCTIC_EMBED_M_LONG = "6658d40729985c2cf72f42ec" OPENAI_ADA002 = "6734c55df127847059324d9e" SNOWFLAKE_ARCTIC_EMBED_L_V2_0 = "678a4f8547f687504744960a" diff --git a/aixplain/enums/index_stores.py b/aixplain/enums/index_stores.py new file mode 100644 index 00000000..7cdfabb3 --- /dev/null +++ b/aixplain/enums/index_stores.py @@ -0,0 +1,14 @@ +from enum import Enum + + +class IndexStores(Enum): + AIR = {"name": "air", "id": "66eae6656eb56311f2595011"} + VECTARA = {"name": "vectara", "id": "655e20f46eb563062a1aa301"} + GRAPHRAG = {"name": "graphrag", "id": "67dd6d487cbf0a57cf4b72f3"} + ZERO_ENTROPY = {"name": "zeroentropy", "id": "6807949168e47e7844c1f0c5"} + + def __str__(self): + return self.value["name"] + + def get_model_id(self): + return self.value["id"] diff --git a/aixplain/factories/index_factory.py b/aixplain/factories/index_factory.py deleted file mode 100644 index 7588e583..00000000 --- a/aixplain/factories/index_factory.py +++ /dev/null @@ -1,50 +0,0 @@ -from aixplain.modules.model.index_model import IndexModel -from aixplain.factories import ModelFactory -from aixplain.enums import EmbeddingModel, Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier -from typing import Optional, Text, Union, List, Tuple - -AIR_MODEL_ID = "66eae6656eb56311f2595011" - - -class IndexFactory(ModelFactory): - @classmethod - def create( - cls, name: Text, description: Text, embedding_model: EmbeddingModel = EmbeddingModel.OPENAI_ADA002 - ) -> IndexModel: - """Create a new index collection""" - model = cls.get(AIR_MODEL_ID) - - data = {"data": name, "description": description, "model": embedding_model.value} - response = model.run(data=data) - if response.status == ResponseStatus.SUCCESS: - model_id = response.data - model = cls.get(model_id) - return model - - error_message = f"Index Factory Exception: {response.error_message}" - if error_message == "": - error_message = "Index Factory Exception: An error occurred while creating the index collection." - raise Exception(error_message) - - @classmethod - def list( - cls, - query: Optional[Text] = "", - suppliers: Optional[Union[Supplier, List[Supplier]]] = None, - ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, - sort_by: Optional[SortBy] = None, - sort_order: SortOrder = SortOrder.ASCENDING, - page_number: int = 0, - page_size: int = 20, - ) -> List[IndexModel]: - """List all indexes""" - return super().list( - function=Function.SEARCH, - query=query, - suppliers=suppliers, - ownership=ownership, - sort_by=sort_by, - sort_order=sort_order, - page_number=page_number, - page_size=page_size, - ) diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py new file mode 100644 index 00000000..74e3ee1a --- /dev/null +++ b/aixplain/factories/index_factory/__init__.py @@ -0,0 +1,102 @@ +__author__ = "aiXplain" + +""" +Copyright 2022 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: Abdul Basit Anees, Thiago Castro Ferreira, Zaina Abushaban +Date: December 26th 2024 +Description: + Index Factory Class +""" + +from aixplain.modules.model.index_model import IndexModel +from aixplain.factories import ModelFactory +from aixplain.enums import Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier, IndexStores, EmbeddingModel +from typing import Text, Union, List, Tuple, Optional, TypeVar, Generic +from aixplain.factories.index_factory.utils import BaseIndexParams + +T = TypeVar("T", bound=BaseIndexParams) + + +class IndexFactory(ModelFactory, Generic[T]): + @classmethod + def create( + cls, + name: Optional[Text] = None, + description: Optional[Text] = None, + embedding_model: EmbeddingModel = EmbeddingModel.OPENAI_ADA002, + params: Optional[T] = None, + **kwargs, + ) -> IndexModel: + """Create a new index collection""" + import warnings + + warnings.warn( + "name, description, and embedding_model will be deprecated in the next release. Please use params instead.", + DeprecationWarning, + ) + + model_id = IndexStores.AIR.get_model_id() + if params is not None: + model_id = params.id + data = params.to_dict() + assert ( + name is None and description is None + ), "Index Factory Exception: name, description, and embedding_model must not be provided when params is provided" + else: + assert ( + name is not None and description is not None and embedding_model is not None + ), "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + data = { + "data": name, + "description": description, + "model": embedding_model, + } + + model = cls.get(model_id) + + response = model.run(data=data) + if response.status == ResponseStatus.SUCCESS: + model_id = response.data + model = cls.get(model_id) + return model + + error_message = f"Index Factory Exception: {response.error_message}" + if response.error_message.strip() == "": + error_message = "Index Factory Exception: An error occurred while creating the index collection." + raise Exception(error_message) + + @classmethod + def list( + cls, + query: Optional[Text] = "", + suppliers: Optional[Union[Supplier, List[Supplier]]] = None, + ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None, + sort_by: Optional[SortBy] = None, + sort_order: SortOrder = SortOrder.ASCENDING, + page_number: int = 0, + page_size: int = 20, + ) -> List[IndexModel]: + """List all indexes""" + return super().list( + function=Function.SEARCH, + query=query, + suppliers=suppliers, + ownership=ownership, + sort_by=sort_by, + sort_order=sort_order, + page_number=page_number, + page_size=page_size, + ) diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py new file mode 100644 index 00000000..83efd469 --- /dev/null +++ b/aixplain/factories/index_factory/utils.py @@ -0,0 +1,66 @@ +from pydantic import BaseModel, ConfigDict +from typing import Text, Optional, ClassVar, Dict +from aixplain.enums import IndexStores, EmbeddingModel +from abc import ABC, abstractmethod + + +class BaseIndexParams(BaseModel, ABC): + model_config = ConfigDict(use_enum_values=True) + name: Text + description: Optional[Text] = "" + + def to_dict(self): + data = self.model_dump(exclude_none=True) + data["data"] = data.pop("name") + return data + + @property + @abstractmethod + def id(self) -> str: + """Abstract property that must be implemented in subclasses.""" + pass + + +class BaseIndexParamsWithEmbeddingModel(BaseIndexParams, ABC): + embedding_model: Optional[EmbeddingModel] = EmbeddingModel.OPENAI_ADA002 + embedding_size: Optional[int] = None + + def to_dict(self): + data = super().to_dict() + data["model"] = data.pop("embedding_model") + if data.get("embedding_size"): + data["additional_params"] = {"embedding_size": data.pop("embedding_size")} + return data + + +class VectaraParams(BaseIndexParams): + _id: ClassVar[str] = IndexStores.VECTARA.get_model_id() + + @property + def id(self) -> str: + return self._id + + +class ZeroEntropyParams(BaseIndexParams): + _id: ClassVar[str] = IndexStores.ZERO_ENTROPY.get_model_id() + + @property + def id(self) -> str: + return self._id + + +class AirParams(BaseIndexParamsWithEmbeddingModel): + _id: ClassVar[str] = IndexStores.AIR.get_model_id() + + @property + def id(self) -> str: + return self._id + + +class GraphRAGParams(BaseIndexParamsWithEmbeddingModel): + _id: ClassVar[str] = IndexStores.GRAPHRAG.get_model_id() + llm: Optional[Text] = None + + @property + def id(self) -> str: + return self._id diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 72590314..421feb89 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -8,6 +8,7 @@ from aixplain.modules import LLM from datetime import datetime, timedelta, timezone from pathlib import Path +from aixplain.factories.index_factory.utils import AirParams, VectaraParams, GraphRAGParams, ZeroEntropyParams def pytest_generate_tests(metafunc): @@ -16,7 +17,7 @@ def pytest_generate_tests(metafunc): models = ModelFactory.list(function=Function.TEXT_GENERATION)["results"] predefined_models = [] - for predefined_model in ["Groq Llama 3 70B", "Chat GPT 3.5", "GPT-4o"]: + for predefined_model in ["Groq Llama 3 70B", "GPT-4o"]: predefined_models.extend( [ m @@ -57,43 +58,27 @@ def test_run_async(): assert "teste" in response["data"].lower() -@pytest.mark.parametrize( - "embedding_model", - [ - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, id="Snowflake Arctic Embed M Long"), - pytest.param(EmbeddingModel.OPENAI_ADA002, id="OpenAI Ada 002"), - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, id="Snowflake Arctic Embed L v2.0"), - pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, id="Multilingual E5 Large"), - pytest.param(EmbeddingModel.BGE_M3, id="BGE M3"), - ], -) -def test_index_model(embedding_model): - from uuid import uuid4 +def run_index_model(index_model): from aixplain.modules.model.record import Record - from aixplain.factories import IndexFactory - for index in IndexFactory.list()["results"]: - index.delete() - - index_model = IndexFactory.create(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) - index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) - response = index_model.search("Hello") + index_model.upsert([Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})]) + response = index_model.search("Berlin") assert str(response.status) == "SUCCESS" - assert "world" in response.data.lower() + assert "germany" in response.data.lower() assert index_model.count() == 1 - index_model.upsert([Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={})]) - response = index_model.search("aiXplain") + index_model.upsert([Record(value="Ankara is the capital of Turkey.", value_type="text", uri="", id="1", attributes={})]) + response = index_model.search("Ankara") assert str(response.status) == "SUCCESS" - assert "aixplain" in response.data.lower() + assert "turkey" in response.data.lower() assert index_model.count() == 1 - index_model.upsert([Record(value="The world is great", value_type="text", uri="", id="2", attributes={})]) + index_model.upsert([Record(value="London is the capital of England.", value_type="text", uri="", id="2", attributes={})]) assert index_model.count() == 2 response = index_model.get_document("1") assert str(response.status) == "SUCCESS" - assert response.data == "Hello, aiXplain!" + assert response.data == "Berlin is the capital of Germany." assert index_model.count() == 2 response = index_model.delete_document("1") @@ -104,17 +89,43 @@ def test_index_model(embedding_model): @pytest.mark.parametrize( - "embedding_model", + "embedding_model,supplier_params", + [ + pytest.param(None, VectaraParams, id="VECTARA"), + pytest.param(None, ZeroEntropyParams, id="ZERO_ENTROPY"), + pytest.param(EmbeddingModel.OPENAI_ADA002, GraphRAGParams, id="GRAPHRAG"), + pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="AIR - OpenAI Ada 002"), + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, AirParams, id="AIR - Snowflake Arctic Embed M Long"), + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, AirParams, id="AIR - Snowflake Arctic Embed L v2.0"), + pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), + pytest.param(EmbeddingModel.BGE_M3, AirParams, id="AIR - BGE M3"), + ], +) +def test_index_model(embedding_model, supplier_params): + from uuid import uuid4 + from aixplain.factories import IndexFactory + + params = supplier_params(name=str(uuid4()), description=str(uuid4())) + if embedding_model is not None: + params = supplier_params(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) + + index_model = IndexFactory.create(params=params) + run_index_model(index_model) + + +@pytest.mark.parametrize( + "embedding_model,supplier_params", [ - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, id="Snowflake Arctic Embed M Long"), - pytest.param(EmbeddingModel.OPENAI_ADA002, id="OpenAI Ada 002"), - pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, id="Snowflake Arctic Embed L v2.0"), - pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, id="Jina Clip v2 Multimodal"), - pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, id="Multilingual E5 Large"), - pytest.param(EmbeddingModel.BGE_M3, id="BGE M3"), + pytest.param(None, VectaraParams, id="VECTARA"), + pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="OpenAI Ada 002"), + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_M_LONG, AirParams, id="Snowflake Arctic Embed M Long"), + pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, AirParams, id="Snowflake Arctic Embed L v2.0"), + pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, AirParams, id="Jina Clip v2 Multimodal"), + pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), + pytest.param(EmbeddingModel.BGE_M3, AirParams, id="BGE M3"), ], ) -def test_index_model_with_filter(embedding_model): +def test_index_model_with_filter(embedding_model, supplier_params): from uuid import uuid4 from aixplain.modules.model.record import Record from aixplain.factories import IndexFactory @@ -123,7 +134,11 @@ def test_index_model_with_filter(embedding_model): for index in IndexFactory.list()["results"]: index.delete() - index_model = IndexFactory.create(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) + params = supplier_params(name=str(uuid4()), description=str(uuid4())) + if embedding_model is not None: + params = supplier_params(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) + + index_model = IndexFactory.create(params=params) index_model.upsert([Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})]) index_model.upsert( [Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})] @@ -157,18 +172,21 @@ def test_llm_run_with_file(): assert "🤖" in response["data"], "Robot emoji should be present in the response" -def test_index_model_with_image(): +def test_index_model_air_with_image(): from aixplain.factories import IndexFactory from aixplain.modules.model.record import Record from uuid import uuid4 + from aixplain.factories.index_factory.utils import AirParams for index in IndexFactory.list()["results"]: index.delete() - index_model = IndexFactory.create( + params = AirParams( name=f"Image Index {uuid4()}", description="Index for images", embedding_model=EmbeddingModel.JINA_CLIP_V2_MULTIMODAL ) + index_model = IndexFactory.create(params=params) + records = [] # Building image records.append( diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 6f826f36..9929ccf4 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -1,5 +1,6 @@ import requests_mock from aixplain.enums import DataType, Function, ResponseStatus, StorageType, EmbeddingModel +from aixplain.factories.index_factory import IndexFactory from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse from aixplain.modules.model.index_model import IndexModel @@ -123,7 +124,10 @@ def test_count_success(): def test_get_document_success(): - mock_response = {"status": "SUCCESS", "data": {"value": "Sample document content 1", "value_type": "text", "id": 0, "uri": "", "attributes": {}}} + mock_response = { + "status": "SUCCESS", + "data": {"value": "Sample document content 1", "value_type": "text", "id": 0, "uri": "", "attributes": {}}, + } mock_documents = [Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={})] with requests_mock.Mocker() as mock: mock.post(execute_url, json=mock_response, status_code=200) @@ -134,6 +138,7 @@ def test_get_document_success(): assert isinstance(response, ModelResponse) assert response.status == ResponseStatus.SUCCESS + def test_delete_document_success(): mock_response = {"status": "SUCCESS"} mock_documents = [Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={})] @@ -209,3 +214,31 @@ def test_index_filter(): assert filter.field == "category" assert filter.value == "world" assert filter.operator == IndexFilterOperator.EQUALS + + +def test_index_factory_create_failure(): + from aixplain.factories.index_factory.utils import AirParams + + with pytest.raises(Exception) as e: + IndexFactory.create( + name="test", + description="test", + embedding_model=EmbeddingModel.OPENAI_ADA002, + params=AirParams(name="test", description="test", embedding_model=EmbeddingModel.OPENAI_ADA002), + ) + assert ( + str(e.value) + == "Index Factory Exception: name, description, and embedding_model must not be provided when params is provided" + ) + + with pytest.raises(Exception) as e: + IndexFactory.create(description="test") + assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + + with pytest.raises(Exception) as e: + IndexFactory.create(name="test") + assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + + with pytest.raises(Exception) as e: + IndexFactory.create(name="test", description="test", embedding_model=None) + assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" From d3607804cafbbbba654c50e5876cc11bf71ed7aa Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Thu, 24 Apr 2025 14:45:32 -0300 Subject: [PATCH 09/22] ENG-2049: rename air functions (#500) --- aixplain/modules/model/index_model.py | 24 ++++++++++++------------ tests/functional/model/run_model_test.py | 8 ++++---- tests/unit/index_model_test.py | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index efd35b51..8a5b18fc 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -149,35 +149,35 @@ def count(self) -> float: if response.status == "SUCCESS": return int(response.data) raise Exception(f"Failed to count documents: {response.error_message}") - - def get_document(self, document_id: Text) -> ModelResponse: + + def get_record(self, record_id: Text) -> ModelResponse: """ Get a document from the index. Args: - document_id (Text): ID of the document to retrieve. + record_id (Text): ID of the document to retrieve. Returns: ModelResponse: Response containing the retrieved document data. Raises: Exception: If document retrieval fails. - + Example: - >>> index_model.get_document("123") + >>> index_model.get_record("123") """ - data = {"action": "get_document", "data": document_id} + data = {"action": "get_document", "data": record_id} response = self.run(data=data) if response.status == "SUCCESS": return response - raise Exception(f"Failed to get document: {response.error_message}") + raise Exception(f"Failed to get record: {response.error_message}") - def delete_document(self, document_id: Text) -> ModelResponse: + def delete_record(self, record_id: Text) -> ModelResponse: """ Delete a document from the index. Args: - document_id (Text): ID of the document to delete. + record_id (Text): ID of the document to delete. Returns: ModelResponse: Response containing the deleted document data. @@ -186,10 +186,10 @@ def delete_document(self, document_id: Text) -> ModelResponse: Exception: If document deletion fails. Example: - >>> index_model.delete_document("123") + >>> index_model.delete_record("123") """ - data = {"action": "delete", "data": document_id} + data = {"action": "delete", "data": record_id} response = self.run(data=data) if response.status == "SUCCESS": return response - raise Exception(f"Failed to delete document: {response.error_message}") \ No newline at end of file + raise Exception(f"Failed to delete record: {response.error_message}") diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 421feb89..916d6077 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -76,12 +76,12 @@ def run_index_model(index_model): index_model.upsert([Record(value="London is the capital of England.", value_type="text", uri="", id="2", attributes={})]) assert index_model.count() == 2 - response = index_model.get_document("1") + response = index_model.get_record("1") assert str(response.status) == "SUCCESS" - assert response.data == "Berlin is the capital of Germany." + assert response.data == "Ankara is the capital of Turkey." assert index_model.count() == 2 - response = index_model.delete_document("1") + response = index_model.delete_record("1") assert str(response.status) == "SUCCESS" assert index_model.count() == 1 @@ -229,7 +229,7 @@ def test_index_model_air_with_image(): assert index_model.count() == 4 - response = index_model.get_document("2") + response = index_model.get_record("2") assert str(response.status) == "SUCCESS" second_record = response.details[0]["metadata"]["uri"] assert "hurricane" in second_record.lower() diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 9929ccf4..4b265dd3 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -133,7 +133,7 @@ def test_get_document_success(): mock.post(execute_url, json=mock_response, status_code=200) index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) index_model.upsert(mock_documents) - response = index_model.get_document(0) + response = index_model.get_record(0) assert isinstance(response, ModelResponse) assert response.status == ResponseStatus.SUCCESS @@ -147,7 +147,7 @@ def test_delete_document_success(): mock.post(execute_url, json=mock_response, status_code=200) index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) index_model.upsert(mock_documents) - response = index_model.delete_document("0") + response = index_model.delete_record("0") assert isinstance(response, ModelResponse) assert response.status == ResponseStatus.SUCCESS From 24e88da515d0f0c4739742d676cca83dc58a9a30 Mon Sep 17 00:00:00 2001 From: OsujiCC Date: Thu, 24 Apr 2025 15:24:46 -0400 Subject: [PATCH 10/22] ENG 1924: aixplain sdk new test cases for agents using utility and pipeline tools. (#490) * new functional test cases * Fixing functional tests for agent with utility and pipeline tools --------- Co-authored-by: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Co-authored-by: Thiago Castro Ferreira --- .../functional/agent/agent_functional_test.py | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 0026d65a..0eecfa3f 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -472,6 +472,7 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): os.remove("test.db") agent.delete() + @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_instructions(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents @@ -496,3 +497,99 @@ def test_instructions(delete_agents_and_team_agents, AgentFactory): assert response["data"]["session_id"] is not None assert response["data"]["output"] is not None assert "aixplain" in response["data"]["output"].lower() + assert "eve" in response["data"]["output"].lower() + + import os + + # Cleanup + os.remove("test.csv") + os.remove("test.db") + agent.delete() + + +@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) +def test_agent_with_utility_tool(delete_agents_and_team_agents, AgentFactory): + from aixplain.enums import DataType + from aixplain.modules.model.utility_model import utility_tool, UtilityModelInput + + assert delete_agents_and_team_agents + + @utility_tool( + name="vowel_remover", + description="Remove all vowels from a given string", + inputs=[UtilityModelInput(name="text", description="String from which to remove vowels", type=DataType.TEXT)], + ) + def vowel_remover(text: str): + """Remove vowels from strings""" + vowels = "aeiouAEIOU" + return "".join([char for char in text if char not in vowels]) + + vowel_remover_ = ModelFactory.create_utility_model(name="vowel_remover", code=vowel_remover) + + @utility_tool( + name="concat_strings", + description="Concatenate two strings into one", + inputs=[ + UtilityModelInput(name="string1", description="First string to concatenate", type=DataType.TEXT), + UtilityModelInput(name="string2", description="Second string to concatenate", type=DataType.TEXT), + ], + ) + def concat_strings(string1: str, string2: str): + return string1 + string2 + + concat_strings_ = ModelFactory.create_utility_model(name="concat_strings", code=concat_strings) + + instructions = """You are a text processing agent equipped with two specialized tools: a Vowel Remover and a String Concatenator. Your task involves processing input text in two ways. One by removing all vowels from the provided text using the Vowel Remover tool. Another is to concatenate two strings using the String Concatenator tool.""" + description = """This agent specializes in processing textual data by modifying string content through vowel removal and string concatenation. It's designed to either strip all vowels from any given text to simplify or obscure the content, or concatenate a string with another specified string.""" + + agent = AgentFactory.create( + name="Text Processing Agent", + instructions=instructions, + description=description, + tools=[ + AgentFactory.create_model_tool(model=vowel_remover_.id), + AgentFactory.create_model_tool(model=concat_strings_.id), + ], + llm_id="6646261c6eb563165658bbb1", + ) + + result_vowel = agent.run("Remove all the vowels in this string: 'Hello'") + result_concat_text = agent.run("Concat these strings: String1 = 'Hello'; string2= 'World!'.") + + assert "hll" in result_vowel["data"]["output"].lower() + assert "helloworld!" in result_concat_text["data"]["output"].lower() + + +@pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) +def test_agent_with_pipeline_tool(delete_agents_and_team_agents, AgentFactory): + from aixplain.factories.pipeline_factory import PipelineFactory + + assert delete_agents_and_team_agents + + pipeline = PipelineFactory.init("Hello Pipeline") + input_node = pipeline.input() + input_node.label = "TextInput" + middle_node = pipeline.asset(asset_id="6646261c6eb563165658bbb1") + middle_node.inputs.prompt.value = "Respond with 'Hello' regardless of the input text: " + input_node.link(middle_node, "input", "text") + middle_node.use_output("data") + pipeline.save() + pipeline.deploy() + + pipeline_agent = AgentFactory.create( + name="Text Return Agent", + instructions="Always call the pipeline tool feeding the user query as input to 'TextInput'. Return the output of the pipeline as the final response.", + description="Return the text given.", + tools=[ + AgentFactory.create_pipeline_tool( + pipeline=pipeline.id, description="You are a tool that responds users query with only 'Hello'." + ), + ], + llm_id="6646261c6eb563165658bbb1", + ) + + answer = pipeline_agent.run("Who is the president of USA?") + pipeline.delete() + + assert "hello" in answer["data"]["output"].lower() + assert "hello pipeline" in answer["data"]["intermediate_steps"][0]["tool_steps"][0]["tool"].lower() From 044d8d8c5e7e8a95357381f8803e9b9416315d9d Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees <50206820+basitanees@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:51:08 +0300 Subject: [PATCH 11/22] add pydantic requirement (#502) --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 422ea158..1cf10327 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,8 @@ dependencies = [ "PyYAML>=6.0.1", "dataclasses-json>=0.5.2", "Jinja2==3.1.6", - "sentry-sdk>=1.0.0" + "sentry-sdk>=1.0.0", + "pydantic>=2.10.6" ] [project.urls] From 4efdbc88d60a20f507c09befdcaa13030717a6fa Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Fri, 25 Apr 2025 10:21:21 -0300 Subject: [PATCH 12/22] ENG-1978: Adding instructions to teams (#485) * Adding instructions to teams * Update docstring --- .../factories/team_agent_factory/__init__.py | 5 ++- .../factories/team_agent_factory/utils.py | 1 + aixplain/modules/team_agent/__init__.py | 7 +++- .../team_agent/team_agent_functional_test.py | 42 +++++++++++++++++++ 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index c2611d6c..892d7ded 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -50,6 +50,7 @@ def create( num_inspectors: int = 1, inspector_targets: List[Union[InspectorTarget, Text]] = [InspectorTarget.STEPS], use_mentalist_and_inspector: bool = False, # TODO: remove this + instructions: Optional[Text] = None, ) -> TeamAgent: """Create a new team agent in the platform. @@ -57,7 +58,7 @@ def create( name: The name of the team agent. agents: A list of agents to be added to the team. llm_id: The ID of the LLM to be used for the team agent. - description: The description of the team agent. + description: The description of the team agent to be displayed in the aiXplain platform. api_key: The API key to be used for the team agent. supplier: The supplier of the team agent. version: The version of the team agent. @@ -66,6 +67,7 @@ def create( num_inspectors: The number of inspectors to be used for each inspection. inspector_targets: Which stages to be inspected during an execution of the team agent. (steps, output) use_mentalist_and_inspector: Whether to use the mentalist and inspector agents. (legacy) + instructions: The instructions to guide the team agent (i.e. appended in the prompt of the team agent). Returns: A new team agent instance. @@ -139,6 +141,7 @@ def create( "supplier": supplier, "version": version, "status": "draft", + "role": instructions, } team_agent = build_team_agent(payload=payload, agents=agent_list, api_key=api_key) diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index debadff8..48268968 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -38,6 +38,7 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = name=payload.get("name", ""), agents=payload_agents, description=payload.get("description", ""), + instructions=payload.get("role", None), supplier=payload.get("teamId", None), version=payload.get("version", None), cost=payload.get("cost", None), diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index a913f223..3014adb1 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -88,6 +88,7 @@ def __init__( max_inspectors: int = 1, inspector_targets: List[InspectorTarget] = [InspectorTarget.STEPS], status: AssetStatus = AssetStatus.DRAFT, + instructions: Optional[Text] = None, **additional_info, ) -> None: """Create a FineTune with the necessary information. @@ -96,7 +97,7 @@ def __init__( id (Text): ID of the Team Agent name (Text): Name of the Team Agent agents (List[Agent]): List of agents that the Team Agent uses. - description (Text, optional): description of the Team Agent. Defaults to "". + description (Text, optional): The description of the team agent to be displayed in the aiXplain platform. Defaults to "". llm_id (Text, optional): large language model. Defaults to GPT-4o (6646261c6eb563165658bbb1). supplier (Text): Supplier of the Team Agent. version (Text): Version of the Team Agent. @@ -104,6 +105,7 @@ def __init__( api_key (str): The TEAM API key used for authentication. cost (Dict, optional): model price. Defaults to None. use_mentalist_and_inspector (bool): Use Mentalist and Inspector tools. Defaults to True. + instructions (Text, optional): The instructions to guide the team agent (i.e. appended in the prompt of the team agent). Defaults to None. """ super().__init__(id, name, description, api_key, supplier, version, cost=cost) self.additional_info = additional_info @@ -113,7 +115,7 @@ def __init__( self.use_inspector = use_inspector self.max_inspectors = max_inspectors self.inspector_targets = inspector_targets - + self.instructions = instructions if isinstance(status, str): try: status = AssetStatus(status) @@ -345,6 +347,7 @@ def to_dict(self) -> Dict: "supplier": self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier, "version": self.version, "status": self.status.value, + "role": self.instructions, } def _validate(self) -> None: diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index cb5f80a9..759f3920 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -592,3 +592,45 @@ def test_team_agent_with_multiple_inspectors(run_input_map, delete_agents_and_te verify_response_generator(steps, has_output_target=False) team_agent.delete() + + +def test_team_agent_with_instructions(delete_agents_and_team_agents): + assert delete_agents_and_team_agents + + agent_1 = AgentFactory.create( + name="Agent 1", + description="Translation agent", + tools=[AgentFactory.create_model_tool(function=Function.TRANSLATION, supplier=Supplier.MICROSOFT)], + llm_id="6646261c6eb563165658bbb1", + ) + + agent_2 = AgentFactory.create( + name="Agent 2", + description="Translation agent", + tools=[AgentFactory.create_model_tool(function=Function.TRANSLATION, supplier=Supplier.GOOGLE)], + llm_id="6646261c6eb563165658bbb1", + ) + + team_agent = TeamAgentFactory.create( + name="Team Agent", + agents=[agent_1, agent_2], + description="Team agent", + instructions="Use only 'Agent 2' to solve the tasks.", + llm_id="6646261c6eb563165658bbb1", + use_mentalist=True, + use_inspector=False, + ) + + response = team_agent.run(data="Translate 'cat' to Portuguese") + assert response.status == "SUCCESS" + assert "gato" in response.data["output"] + + mentalist_steps = eval(response.data["intermediate_steps"][0]["output"]) + + called_agents = set([step["worker"] for step in mentalist_steps]) + assert len(called_agents) == 1 + assert "Agent 2" in called_agents + + team_agent.delete() + agent_1.delete() + agent_2.delete() From ab2fcc544c8309d0a6ef9d8a008fa087d0462f7b Mon Sep 17 00:00:00 2001 From: kadirpekel Date: Fri, 25 Apr 2025 16:59:01 +0200 Subject: [PATCH 13/22] BUG-504: Merged paramMappings for the same link vectors (#499) --- .../modules/pipeline/designer/pipeline.py | 17 ++++++- tests/unit/designer_unit_test.py | 44 +++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/aixplain/modules/pipeline/designer/pipeline.py b/aixplain/modules/pipeline/designer/pipeline.py index 80248909..0985026d 100644 --- a/aixplain/modules/pipeline/designer/pipeline.py +++ b/aixplain/modules/pipeline/designer/pipeline.py @@ -68,9 +68,22 @@ def serialize(self) -> dict: :return: the pipeline as a dictionary """ + nodes = [node.serialize() for node in self.nodes] + links = [link.serialize() for link in self.links] + + # merge the params for links using the key `from_node` and `to_node` + merged_links = {} + for link in links: + key = (link['from'], link['to']) + if key not in merged_links: + merged_links[key] = link + else: + existing_link = merged_links[key] + existing_link['paramMapping'] += link['paramMapping'] + return { - "nodes": [node.serialize() for node in self.nodes], - "links": [link.serialize() for link in self.links], + "nodes": nodes, + "links": list(merged_links.values()), } def validate_nodes(self): diff --git a/tests/unit/designer_unit_test.py b/tests/unit/designer_unit_test.py index aaf45026..143d5485 100644 --- a/tests/unit/designer_unit_test.py +++ b/tests/unit/designer_unit_test.py @@ -815,3 +815,47 @@ def test_pipeline_special_prompt_validation(): def test_find_prompt_params(input, expected): print(input, expected) assert find_prompt_params(input) == expected + + +def test_pipeline_serialize(): + pipeline = DesignerPipeline() + + # Create nodes + class AssetNode(Node, LinkableMixin): + type: NodeType = NodeType.ASSET + + node1 = AssetNode(pipeline=pipeline) + node2 = AssetNode(pipeline=pipeline) + + # Create multiple parameters for each node + node1.outputs.create_param("output1", DataType.TEXT, "foo1") + node1.outputs.create_param("output2", DataType.TEXT, "foo2") + node2.inputs.create_param("input1", DataType.TEXT, "bar1") + node2.inputs.create_param("input2", DataType.TEXT, "bar2") + + # Create multiple links between the same nodes + link1 = Link( + from_node=node1, + to_node=node2, + from_param="output1", + to_param="input1", + ) + pipeline.add_link(link1) + + link2 = Link( + from_node=node1, + to_node=node2, + from_param="output2", + to_param="input2" + ) + pipeline.add_link(link2) + + serialized = pipeline.serialize() + + # Verify pipeline serialization + assert len(serialized["links"]) == 1 + assert len(serialized["links"][0]["paramMapping"]) == 2 + assert {"from": "output1", "to": "input1"} in serialized["links"][0]["paramMapping"] + assert {"from": "output2", "to": "input2"} in serialized["links"][0]["paramMapping"] + assert serialized["nodes"][0] == node1.serialize() + assert serialized["nodes"][1] == node2.serialize() From 268cc1a69ee2633664cf7f95c2f0bd2d9e0b49a2 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Mon, 28 Apr 2025 16:44:45 -0300 Subject: [PATCH 14/22] ENG-1836: Set name of tools on the SDK (#501) * Set model tool name * Setting tool names * Validate custom tools * Addressing comments * Name on SQL tool * Default name for SQL tool --- aixplain/factories/agent_factory/__init__.py | 16 ++- aixplain/factories/agent_factory/utils.py | 8 +- aixplain/modules/agent/__init__.py | 12 +- .../agent/tool/custom_python_code_tool.py | 12 +- aixplain/modules/agent/tool/model_tool.py | 31 +++++ aixplain/modules/agent/tool/pipeline_tool.py | 22 +++- .../agent/tool/python_interpreter_tool.py | 4 + aixplain/modules/agent/tool/sql_tool.py | 4 +- aixplain/v2/agent.py | 17 ++- .../functional/agent/agent_functional_test.py | 21 ++-- tests/unit/agent/agent_factory_utils_test.py | 44 +++++-- tests/unit/agent/agent_test.py | 112 ++++++++++++------ tests/unit/agent/model_tool_test.py | 13 +- tests/unit/agent/sql_tool_test.py | 39 +++--- 14 files changed, 257 insertions(+), 98 deletions(-) diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index bad94fb3..040bcd71 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -164,6 +164,7 @@ def create_model_tool( supplier: Optional[Union[Supplier, Text]] = None, description: Text = "", parameters: Optional[Dict] = None, + name: Optional[Text] = None, ) -> ModelTool: """Create a new model tool.""" if function is not None and isinstance(function, str): @@ -179,9 +180,11 @@ def create_model_tool( return ModelTool(function=function, supplier=supplier, model=model, description=description, parameters=parameters) @classmethod - def create_pipeline_tool(cls, description: Text, pipeline: Union[Pipeline, Text]) -> PipelineTool: + def create_pipeline_tool( + cls, description: Text, pipeline: Union[Pipeline, Text], name: Optional[Text] = None + ) -> PipelineTool: """Create a new pipeline tool.""" - return PipelineTool(description=description, pipeline=pipeline) + return PipelineTool(description=description, pipeline=pipeline, name=name) @classmethod def create_python_interpreter_tool(cls) -> PythonInterpreterTool: @@ -189,13 +192,16 @@ def create_python_interpreter_tool(cls) -> PythonInterpreterTool: return PythonInterpreterTool() @classmethod - def create_custom_python_code_tool(cls, code: Union[Text, Callable], description: Text = "") -> CustomPythonCodeTool: + def create_custom_python_code_tool( + cls, code: Union[Text, Callable], name: Text, description: Text = "" + ) -> CustomPythonCodeTool: """Create a new custom python code tool.""" - return CustomPythonCodeTool(description=description, code=code) + return CustomPythonCodeTool(name=name, description=description, code=code) @classmethod def create_sql_tool( cls, + name: Text, description: Text, source: str, source_type: Union[str, DatabaseSourceType], @@ -206,6 +212,7 @@ def create_sql_tool( """Create a new SQL tool Args: + name (Text): name of the tool description (Text): description of the database tool source (Union[Text, Dict]): database source - can be a connection string or dictionary with connection details source_type (Union[str, DatabaseSourceType]): type of source (postgresql, sqlite, csv) or DatabaseSourceType enum @@ -317,6 +324,7 @@ def create_sql_tool( # Create and return SQLTool return SQLTool( + name=name, description=description, database=database_path, schema=schema, diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 6dae6098..704d7fe7 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -60,6 +60,7 @@ def build_tool(tool: Dict): else: tool = PythonInterpreterTool() elif tool["type"] == "sql": + name = tool.get("name", "SQLTool") parameters = {parameter["name"]: parameter["value"] for parameter in tool.get("parameters", [])} database = parameters.get("database") schema = parameters.get("schema") @@ -67,7 +68,12 @@ def build_tool(tool: Dict): tables = tables.split(",") if tables is not None else None enable_commit = parameters.get("enable_commit", False) tool = SQLTool( - description=tool["description"], database=database, schema=schema, tables=tables, enable_commit=enable_commit + name=name, + description=tool["description"], + database=database, + schema=schema, + tables=tables, + enable_commit=enable_commit, ) else: raise ValueError("Agent Creation Error: Tool type not supported.") diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 38b8a7ae..227c0040 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -127,11 +127,21 @@ def _validate(self) -> None: assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." + tool_names = [] for tool in self.tools: + tool_name = None if isinstance(tool, Tool): - tool.validate() + tool_name = tool.name elif isinstance(tool, Model): assert not isinstance(tool, Agent), "Agent cannot contain another Agent." + tool_name = tool.name + tool_names.append(tool_name) + + if len(tool_names) != len(set(tool_names)): + duplicates = set([name for name in tool_names if tool_names.count(name) > 1]) + raise Exception( + f"Agent Creation Error - Duplicate tool names found: {', '.join(duplicates)}. Make sure all tool names are unique." + ) def validate(self, raise_exception: bool = False) -> bool: """Validate the Agent.""" diff --git a/aixplain/modules/agent/tool/custom_python_code_tool.py b/aixplain/modules/agent/tool/custom_python_code_tool.py index 511e4a8f..d544d2b1 100644 --- a/aixplain/modules/agent/tool/custom_python_code_tool.py +++ b/aixplain/modules/agent/tool/custom_python_code_tool.py @@ -21,7 +21,7 @@ Agentification Class """ -from typing import Text, Union, Callable +from typing import Text, Union, Callable, Optional from aixplain.modules.agent.tool import Tool import logging @@ -29,10 +29,13 @@ class CustomPythonCodeTool(Tool): """Custom Python Code Tool""" - def __init__(self, code: Union[Text, Callable], description: Text = "", **additional_info) -> None: + def __init__( + self, code: Union[Text, Callable], description: Text = "", name: Optional[Text] = None, **additional_info + ) -> None: """Custom Python Code Tool""" - super().__init__(name="Custom Python Code", description=description, **additional_info) + super().__init__(name=name or "", description=description, **additional_info) self.code = code + self.validate() def to_dict(self): return { @@ -64,3 +67,6 @@ def validate(self): ), "Custom Python Code Tool Error: Tool description is required" assert self.code and self.code.strip() != "", "Custom Python Code Tool Error: Code is required" assert self.name and self.name.strip() != "", "Custom Python Code Tool Error: Name is required" + + def __repr__(self) -> Text: + return f"CustomPythonCodeTool(name={self.name})" diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 29c68e0f..6a945a15 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -28,6 +28,30 @@ from aixplain.modules.model import Model +def set_tool_name(function: Function, supplier: Supplier = None, model: Model = None) -> Text: + """Sets the name of the tool based on the function, supplier, and model. + + Args: + function (Function): The function to be used in the tool. + supplier (Supplier): The supplier to be used in the tool. + model (Model): The model to be used in the tool. + + Returns: + Text: The name of the tool. + """ + function_name = function.value.lower().replace(" ", "_") + tool_name = f"{function_name}" + + if supplier is not None: + supplier_name = supplier.name.lower().replace(" ", "_") + tool_name += f"-{supplier_name}" + + if model is not None and supplier is not None: + model_name = model.name.lower().replace(" ", "_") + tool_name += f"-{model_name}" + return tool_name + + class ModelTool(Tool): """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. @@ -139,6 +163,8 @@ def validate(self) -> None: self.parameters = self.validate_parameters(self.parameters) + self.name = self.name if self.name else set_tool_name(self.function, self.supplier, self.model) + def get_parameters(self) -> Dict: return self.parameters @@ -187,3 +213,8 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) raise ValueError(f"Invalid parameters provided: {invalid_params}. Expected parameters are: {expected_param_names}") return received_parameters + + def __repr__(self) -> Text: + supplier_str = self.supplier.value if self.supplier is not None else None + model_str = self.model.id if self.model is not None else None + return f"ModelTool(name={self.name}, function={self.function}, supplier={supplier_str}, model={model_str})" diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index 13d3c46f..0de83916 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -50,9 +50,8 @@ def __init__( name = name or "" super().__init__(name=name, description=description, **additional_info) - if isinstance(pipeline, Pipeline): - pipeline = pipeline.id self.pipeline = pipeline + self.validate() def to_dict(self): return { @@ -65,7 +64,18 @@ def to_dict(self): def validate(self): from aixplain.factories.pipeline_factory import PipelineFactory - try: - PipelineFactory.get(self.pipeline, api_key=self.api_key) - except Exception: - raise Exception(f"Pipeline Tool Unavailable. Make sure Pipeline '{self.pipeline}' exists or you have access to it.") + if isinstance(self.pipeline, Pipeline): + pipeline_obj = self.pipeline + else: + try: + pipeline_obj = PipelineFactory.get(self.pipeline, api_key=self.api_key) + except Exception: + raise Exception( + f"Pipeline Tool Unavailable. Make sure Pipeline '{self.pipeline}' exists or you have access to it." + ) + + if self.name.strip() == "": + self.name = pipeline_obj.name + + def __repr__(self) -> Text: + return f"PipelineTool(name={self.name}, pipeline={self.pipeline})" diff --git a/aixplain/modules/agent/tool/python_interpreter_tool.py b/aixplain/modules/agent/tool/python_interpreter_tool.py index d5947bff..42621c45 100644 --- a/aixplain/modules/agent/tool/python_interpreter_tool.py +++ b/aixplain/modules/agent/tool/python_interpreter_tool.py @@ -22,6 +22,7 @@ """ from aixplain.modules.agent.tool import Tool +from typing import Text class PythonInterpreterTool(Tool): @@ -41,3 +42,6 @@ def to_dict(self): def validate(self): pass + + def __repr__(self) -> Text: + return "PythonInterpreterTool()" diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index 81444db0..56cf116c 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -259,17 +259,18 @@ class SQLTool(Tool): def __init__( self, + name: Text, description: Text, database: Text, schema: Optional[Text] = None, tables: Optional[Union[List[Text], Text]] = None, enable_commit: bool = False, - name: Optional[Text] = None, **additional_info, ) -> None: """Tool to execute SQL query commands in an SQLite database. Args: + name (Text): name of the tool description (Text): description of the tool database (Text): database uri schema (Optional[Text]): database schema description @@ -277,7 +278,6 @@ def __init__( enable_commit (bool): enable to modify the database (optional) """ - name = name or "" super().__init__(name=name, description=description, **additional_info) self.database = database diff --git a/aixplain/v2/agent.py b/aixplain/v2/agent.py index af8c2d39..eadb155f 100644 --- a/aixplain/v2/agent.py +++ b/aixplain/v2/agent.py @@ -95,17 +95,22 @@ def create_model_tool( function: Union[Function, str] = None, supplier: Union["Supplier", str] = None, description: str = "", + name: Optional[str] = None, ): from aixplain.factories import AgentFactory - return AgentFactory.create_model_tool(model=model, function=function, supplier=supplier, description=description) + return AgentFactory.create_model_tool( + model=model, function=function, supplier=supplier, description=description, name=name + ) @classmethod - def create_pipeline_tool(cls, description: str, pipeline: Union["Pipeline", str]) -> "PipelineTool": + def create_pipeline_tool( + cls, description: str, pipeline: Union["Pipeline", str], name: Optional[str] = None + ) -> "PipelineTool": """Create a new pipeline tool.""" from aixplain.factories import AgentFactory - return AgentFactory.create_pipeline_tool(description=description, pipeline=pipeline) + return AgentFactory.create_pipeline_tool(description=description, pipeline=pipeline, name=name) @classmethod def create_python_interpreter_tool(cls) -> "PythonInterpreterTool": @@ -115,11 +120,13 @@ def create_python_interpreter_tool(cls) -> "PythonInterpreterTool": return AgentFactory.create_python_interpreter_tool() @classmethod - def create_custom_python_code_tool(cls, code: Union[str, Callable], description: str = "") -> "CustomPythonCodeTool": + def create_custom_python_code_tool( + cls, code: Union[str, Callable], name: str, description: str = "" + ) -> "CustomPythonCodeTool": """Create a new custom python code tool.""" from aixplain.factories import AgentFactory - return AgentFactory.create_custom_python_code_tool(code=code, description=description) + return AgentFactory.create_custom_python_code_tool(code=code, name=name, description=description) @classmethod def create_sql_tool( diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 0eecfa3f..ca3c4816 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -131,24 +131,27 @@ def test_python_interpreter_tool(delete_agents_and_team_agents, AgentFactory): def test_custom_code_tool(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents tool = AgentFactory.create_custom_python_code_tool( - description="Add two numbers", - code='def main(aaa: int, bbb: int) -> int:\n """Add two numbers"""\n return aaa + bbb', + description="Add two strings", + code='def main(aaa: str, bbb: str) -> str:\n """Add two strings"""\n return aaa + bbb', + name="Add Strings", ) assert tool is not None - assert tool.description == "Add two numbers" - assert tool.code == 'def main(aaa: int, bbb: int) -> int:\n """Add two numbers"""\n return aaa + bbb' + assert tool.description == "Add two strings" + assert tool.code == 'def main(aaa: str, bbb: str) -> str:\n """Add two strings"""\n return aaa + bbb' agent = AgentFactory.create( - name="Add Numbers Agent", - description="Add two numbers. Do not directly answer. Use the tool to add the numbers.", - instructions="Add two numbers. Do not directly answer. Use the tool to add the numbers.", + name="Add Strings Agent", + description="Add two strings. Do not directly answer. Use the tool to add the strings.", + instructions="Add two strings. Do not directly answer. Use the tool to add the strings.", tools=[tool], ) assert agent is not None - response = agent.run("How much is 12342 + 112312? Do not directly answer the question, call the tool.") + response = agent.run( + "What is the result of concatenating 'Hello' and 'World'? Do not directly answer the question, call the tool." + ) assert response is not None assert response["completed"] is True assert response["status"].lower() == "success" - assert "124654" in response["data"]["output"] + assert "HelloWorld" in response["data"]["output"] agent.delete() diff --git a/tests/unit/agent/agent_factory_utils_test.py b/tests/unit/agent/agent_factory_utils_test.py index 66a7a4e2..c0104f5c 100644 --- a/tests/unit/agent/agent_factory_utils_test.py +++ b/tests/unit/agent/agent_factory_utils_test.py @@ -3,6 +3,7 @@ from aixplain.factories.agent_factory.utils import build_tool, build_agent from aixplain.enums import Function, Supplier from aixplain.enums.asset_status import AssetStatus +from aixplain.modules.pipeline import Pipeline from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.agent.tool.pipeline_tool import PipelineTool from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool @@ -10,7 +11,7 @@ from aixplain.modules.agent.tool.sql_tool import SQLTool from aixplain.modules.agent import Agent from aixplain.modules.agent.agent_task import AgentTask -from aixplain.factories.model_factory import ModelFactory +from aixplain.factories import ModelFactory, PipelineFactory @pytest.fixture @@ -139,6 +140,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): pytest.param( { "type": "sql", + "name": "Test SQL", "description": "Test SQL", "parameters": [ {"name": "database", "value": "test_db"}, @@ -149,6 +151,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): }, SQLTool, { + "name": "Test SQL", "description": "Test SQL", "database": "test_db", "schema": "public", @@ -160,6 +163,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): pytest.param( { "type": "sql", + "name": "Test SQL", "description": "Test SQL with string enable_commit", "parameters": [ {"name": "database", "value": "test_db"}, @@ -170,6 +174,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): }, SQLTool, { + "name": "Test SQL", "description": "Test SQL with string enable_commit", "database": "test_db", "schema": "public", @@ -180,17 +185,28 @@ def test_build_tool_error_cases(tool_dict, expected_error): ), ], ) -def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock_model): +def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock_model, mocker): """Test successful tool creation with various configurations.""" - with patch.object(ModelFactory, "get", return_value=mock_model): - tool = build_tool(tool_dict) - assert isinstance(tool, expected_type) + mocker.patch.object(ModelFactory, "get", return_value=mock_model) + mocker.patch( + "aixplain.modules.model.utils.parse_code_decorated", + return_value=("print('Hello World')", [], "Test description", "test_name"), + ) + if tool_dict["type"] == "pipeline": + mocker.patch.object( + PipelineFactory, + "get", + return_value=Pipeline(id=tool_dict["assetId"], description=tool_dict["description"], name="Pipeline", api_key=""), + ) - for attr, value in expected_attrs.items(): - if attr == "model": - assert tool.model == mock_model - else: - assert getattr(tool, attr) == value + tool = build_tool(tool_dict) + assert isinstance(tool, expected_type) + + for attr, value in expected_attrs.items(): + if attr == "model": + assert tool.model == mock_model + else: + assert getattr(tool, attr) == value @pytest.mark.parametrize( @@ -263,8 +279,14 @@ def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock ), ], ) -def test_build_agent_success_cases(payload, expected_attrs, mock_tools): +def test_build_agent_success_cases(payload, expected_attrs, mock_tools, mocker): """Test successful agent creation with various configurations.""" + mocker.patch.object( + PipelineFactory, + "get", + return_value=Pipeline(id="test_pipeline", description="Test pipeline", name="Pipeline", api_key=""), + ) + agent = build_agent(payload, tools=mock_tools if "assets" not in payload else None) assert isinstance(agent, Agent) diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index f767f0ec..3a151b08 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -12,7 +12,7 @@ from aixplain.modules.agent.utils import process_variables from urllib.parse import urljoin from unittest.mock import patch -from aixplain.enums.function import Function +from aixplain.enums import Function, Supplier from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData @@ -81,7 +81,11 @@ def test_success_query_content(): assert response["url"] == ref_response["data"] -def test_invalid_pipelinetool(): +def test_invalid_pipelinetool(mocker): + mocker.patch( + "aixplain.factories.model_factory.ModelFactory.get", + return_value=Model(id="6646261c6eb563165658bbb1", name="Test Model", function=Function.TEXT_GENERATION), + ) with pytest.raises(Exception) as exc_info: AgentFactory.create( name="Test", @@ -155,6 +159,7 @@ def test_create_agent(mock_model_factory_get): "utility": "custom_python_code", "utilityCode": "def main(query: str) -> str:\n return 'Hello, how are you?'", "description": "Test Tool", + "name": "Test Tool", }, { "type": "utility", @@ -188,7 +193,9 @@ def test_create_agent(mock_model_factory_get): supplier=Supplier.OPENAI, function="text-generation", description="Test Tool" ), AgentFactory.create_custom_python_code_tool( - code="def main(query: str) -> str:\n return 'Hello, how are you?'", description="Test Tool" + code="def main(query: str) -> str:\n return 'Hello, how are you?'", + description="Test Tool", + name="Test Tool", ), AgentFactory.create_python_interpreter_tool(), ], @@ -462,7 +469,7 @@ def test_agent_multiple_tools_api_key(): AgentFactory.create_model_tool(function="text-generation"), AgentFactory.create_python_interpreter_tool(), AgentFactory.create_custom_python_code_tool( - code="def main(query: str) -> str:\n return 'Hello'", description="Test Tool" + code="def main(query: str) -> str:\n return 'Hello'", description="Test Tool", name="Test Tool" ), ] @@ -551,19 +558,28 @@ def test_agent_response(): assert response["data"]["output"] == "new_output" -def test_custom_python_code_tool_initialization(): +def test_custom_python_code_tool_initialization(mocker): """Test basic initialization of CustomPythonCodeTool""" + mocker.patch( + "aixplain.modules.model.utils.parse_code_decorated", + return_value=("def main(query: str) -> str:\n return 'Hello'", [], "Test description", "HelloWorld"), + ) + code = "def main(query: str) -> str:\n return 'Hello'" description = "Test description" - tool = CustomPythonCodeTool(code=code, description=description) + tool = CustomPythonCodeTool(code=code, description=description, name="HelloWorld") assert tool.code == code assert tool.description == description - assert tool.name == "Custom Python Code" + assert tool.name == "HelloWorld" -def test_custom_python_code_tool_to_dict(): +def test_custom_python_code_tool_to_dict(mocker): """Test the to_dict method of CustomPythonCodeTool""" + mocker.patch( + "aixplain.modules.model.utils.parse_code_decorated", + return_value=("def main(query: str) -> str:\n return 'Hello'", [], "Test description", "HelloWorld"), + ) code = "def main(query: str) -> str:\n return 'Hello'" description = "Test description" tool = CustomPythonCodeTool(code=code, description=description) @@ -594,22 +610,17 @@ def test_custom_python_code_tool_validation(): assert tool.name == "test_name" -def test_custom_python_code_tool_validation_missing_description(): +def test_custom_python_code_tool_validation_missing_description(mocker): """Test validation fails when description is missing""" - with patch( - "aixplain.modules.model.utils.parse_code", - return_value=( - "def main(query: str) -> str:\n return 'Hello'", # code - [], # inputs - None, # description - "test_name", # name - ), - ): - code = "def main(query: str) -> str:\n return 'Hello'" - tool = CustomPythonCodeTool(code=code) - with pytest.raises(AssertionError) as exc_info: - tool.validate() - assert str(exc_info.value) == "Custom Python Code Tool Error: Tool description is required" + mocker.patch( + "aixplain.modules.model.utils.parse_code_decorated", + return_value=("def main(query: str) -> str:\n return 'Hello'", [], "", "HelloWorld"), + ) + + code = "def main(query: str) -> str:\n return 'Hello'" + with pytest.raises(AssertionError) as exc_info: + CustomPythonCodeTool(code=code) + assert str(exc_info.value) == "Custom Python Code Tool Error: Tool description is required" def test_custom_python_code_tool_validation_missing_code(): @@ -619,22 +630,10 @@ def test_custom_python_code_tool_validation_missing_code(): return_value=("", [], "Parsed description", "test_name"), # code # inputs # description # name ): with pytest.raises(AssertionError) as exc_info: - tool = CustomPythonCodeTool(code="", description="Test description") - tool.validate() + CustomPythonCodeTool(code="", description="Test description") assert str(exc_info.value) == "Custom Python Code Tool Error: Code is required" -def test_custom_python_code_tool_with_callable(): - """Test CustomPythonCodeTool with a callable function""" - - def test_function(query: str) -> str: - return "Hello" - - tool = CustomPythonCodeTool(code=test_function, description="Test description") - assert callable(tool.code) - assert tool.description == "Test description" - - @patch("aixplain.factories.model_factory.ModelFactory.get") def test_create_agent_with_model_instance(mock_model_factory_get): from aixplain.enums import Supplier, Function @@ -957,3 +956,44 @@ def test_agent_response_repr(): # Most importantly, verify that 'status' is complete (not 'tatus') assert "status=" in repr_str # Should find complete field name + + +@pytest.mark.parametrize( + "function,supplier,model,expected_name", + [ + (Function.TRANSLATION, None, None, "translation"), + (Function.TEXT_GENERATION, Supplier.AIXPLAIN, None, "text-generation-aixplain"), + (Function.TEXT_GENERATION, Supplier.OPENAI, None, "text-generation-openai"), + ( + Function.TEXT_GENERATION, + Supplier.AIXPLAIN, + Model(id="123", name="Test Model"), + "text-generation-aixplain-test_model", + ), + ], +) +def test_set_tool_name(function, supplier, model, expected_name): + from aixplain.modules.agent.tool.model_tool import set_tool_name + + name = set_tool_name(function, supplier, model) + assert name == expected_name + + +def test_create_agent_with_duplicate_tool_names(mocker): + from aixplain.factories import AgentFactory + from aixplain.modules import Model + from aixplain.modules.agent.tool.model_tool import ModelTool + + mocker.patch( + "aixplain.factories.model_factory.ModelFactory.get", + return_value=Model(id="123", name="Test Model", function=Function.TEXT_GENERATION), + ) + + # Create a ModelTool with a specific name + tool1 = ModelTool(model="123", name="Test Model") + tool2 = ModelTool(model="123", name="Test Model") + with pytest.raises(Exception) as exc_info: + AgentFactory.create(name="Test Agent", description="Test Agent Description", tools=[tool1, tool2]) + assert "Agent Creation Error - Duplicate tool names found: Test Model. Make sure all tool names are unique." in str( + exc_info.value + ) diff --git a/tests/unit/agent/model_tool_test.py b/tests/unit/agent/model_tool_test.py index c29df945..869979e1 100644 --- a/tests/unit/agent/model_tool_test.py +++ b/tests/unit/agent/model_tool_test.py @@ -64,7 +64,7 @@ def test_init_with_supplier_dict(): with patch("aixplain.modules.agent.tool.model_tool.Supplier") as mock_supplier: # Create a mock for the Supplier enum - mock_supplier.AIXPLAIN = type("MockSupplier", (), {"value": mock_enum["AIXPLAIN"]})() + mock_supplier.AIXPLAIN = type("MockSupplier", (), {"value": mock_enum["AIXPLAIN"], "name": "aiXplain"})() mock_supplier.return_value = mock_supplier.AIXPLAIN tool = ModelTool(function=Function.TRANSLATION, supplier=supplier_dict) @@ -87,8 +87,11 @@ def test_init_validation_errors(error_case, expected_error, expected_message): error_case() -def test_to_dict(mock_model, mock_model_factory): +def test_to_dict(mocker, mock_model, mock_model_factory): mock_model_factory.get.return_value = mock_model + + mocker.patch("aixplain.modules.agent.tool.model_tool.set_tool_name", return_value="test_tool_name") + tool = ModelTool( model="test_model_id", description="Test description", @@ -98,7 +101,7 @@ def test_to_dict(mock_model, mock_model_factory): expected = { "function": mock_model.function.value, "type": "model", - "name": "", + "name": "test_tool_name", "description": "Test description", "supplier": mock_model.supplier.value["code"], "version": None, @@ -172,9 +175,9 @@ def test_validate_parameters(mocker, mock_model, params, expected_result, error_ "tool_name,expected_name", [ ("custom_tool", "custom_tool"), - ("", ""), # Test empty name + ("", "translation-aixplain-test_model"), # Test empty name ("translation_model", "translation_model"), - (None, ""), # Test None value should default to empty string + (None, "translation-aixplain-test_model"), # Test None value should default to empty string ], ) def test_tool_name(mock_model, mock_model_factory, tool_name, expected_name): diff --git a/tests/unit/agent/sql_tool_test.py b/tests/unit/agent/sql_tool_test.py index 12170a23..1a1dbb84 100644 --- a/tests/unit/agent/sql_tool_test.py +++ b/tests/unit/agent/sql_tool_test.py @@ -55,7 +55,7 @@ def test_create_sql_tool(mocker, tmp_path): # Test SQLite source type tool = AgentFactory.create_sql_tool( - description="Test", source=db_path, source_type="sqlite", schema="test", tables=["test", "test2"] + name="Test SQL", description="Test", source=db_path, source_type="sqlite", schema="test", tables=["test", "test2"] ) assert isinstance(tool, SQLTool) assert tool.description == "Test" @@ -67,7 +67,9 @@ def test_create_sql_tool(mocker, tmp_path): df = pd.DataFrame({"id": [1, 2, 3], "name": ["test1", "test2", "test3"]}) df.to_csv(csv_path, index=False) # Test CSV source type - csv_tool = AgentFactory.create_sql_tool(description="Test CSV", source=csv_path, source_type="csv", tables=["data"]) + csv_tool = AgentFactory.create_sql_tool( + name="Test CSV", description="Test CSV", source=csv_path, source_type="csv", tables=["data"] + ) assert isinstance(csv_tool, SQLTool) assert csv_tool.description == "Test CSV" assert csv_tool.database.endswith(".db") @@ -160,36 +162,37 @@ def test_sql_tool_validation_errors(tmp_path): # Test missing description with pytest.raises(SQLToolError, match="Description is required"): - tool = AgentFactory.create_sql_tool(description="", source=db_path, source_type="sqlite") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="", source=db_path, source_type="sqlite") tool.validate() # Test missing source with pytest.raises(SQLToolError, match="Source must be provided"): - tool = AgentFactory.create_sql_tool(description="Test", source="", source_type="sqlite") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source="", source_type="sqlite") tool.validate() # Test missing source_type with pytest.raises(TypeError, match="missing 1 required positional argument: 'source_type'"): - tool = AgentFactory.create_sql_tool(description="Test", source=db_path) + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=db_path) tool.validate() # Test invalid source type with pytest.raises(SQLToolError, match="Invalid source type"): - AgentFactory.create_sql_tool(description="Test", source=db_path, source_type="invalid") + AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=db_path, source_type="invalid") # Test non-existent SQLite database with pytest.raises(SQLToolError, match="Database .* does not exist"): - tool = AgentFactory.create_sql_tool(description="Test", source="nonexistent.db", source_type="sqlite") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source="nonexistent.db", source_type="sqlite") tool.validate() # Test non-existent CSV file with pytest.raises(SQLToolError, match="CSV file .* does not exist"): - tool = AgentFactory.create_sql_tool(description="Test", source="nonexistent.csv", source_type="csv") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source="nonexistent.csv", source_type="csv") tool.validate() # Test PostgreSQL (not supported) with pytest.raises(SQLToolError, match="PostgreSQL is not supported yet"): tool = AgentFactory.create_sql_tool( + name="Test SQL", description="Test", source="postgresql://user:pass@localhost/mydb", source_type="postgresql", @@ -210,7 +213,7 @@ def test_create_sql_tool_with_schema_inference(tmp_path, mocker): conn.close() # Create tool without schema and tables - tool = AgentFactory.create_sql_tool(description="Test", source=db_path, source_type="sqlite") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=db_path, source_type="sqlite") # Mock schema inference schema = "CREATE TABLE test (id INTEGER, name TEXT)" @@ -239,7 +242,7 @@ def test_create_sql_tool_from_csv_with_warnings(tmp_path, mocker): # Create tool and check for warnings with pytest.warns(UserWarning) as record: - tool = AgentFactory.create_sql_tool(description="Test", source=csv_path, source_type="csv") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=csv_path, source_type="csv") # Verify warnings about column name changes warning_messages = [str(w.message) for w in record] @@ -274,7 +277,9 @@ def test_create_sql_tool_from_csv(tmp_path): df.to_csv(csv_path, index=False) # Test successful creation - tool = AgentFactory.create_sql_tool(description="Test", source=csv_path, source_type="csv", tables=["test"]) + tool = AgentFactory.create_sql_tool( + name="Test SQL", description="Test", source=csv_path, source_type="csv", tables=["test"] + ) assert isinstance(tool, SQLTool) assert tool.description == "Test" assert tool.database.endswith(".db") @@ -301,7 +306,7 @@ def test_sql_tool_schema_inference(tmp_path): df.to_csv(csv_path, index=False) # Create tool without schema and tables - tool = AgentFactory.create_sql_tool(description="Test", source=csv_path, source_type="csv") + tool = AgentFactory.create_sql_tool(name="Test SQL", description="Test", source=csv_path, source_type="csv") try: tool.validate() @@ -324,15 +329,19 @@ def test_create_sql_tool_source_type_handling(tmp_path): conn.close() # Test with string input - tool_str = AgentFactory.create_sql_tool(description="Test", source=db_path, source_type="sqlite", schema="test") + tool_str = AgentFactory.create_sql_tool( + name="Test SQL", description="Test", source=db_path, source_type="sqlite", schema="test" + ) assert isinstance(tool_str, SQLTool) # Test with enum input tool_enum = AgentFactory.create_sql_tool( - description="Test", source=db_path, source_type=DatabaseSourceType.SQLITE, schema="test" + name="Test SQL", description="Test", source=db_path, source_type=DatabaseSourceType.SQLITE, schema="test" ) assert isinstance(tool_enum, SQLTool) # Test invalid type with pytest.raises(SQLToolError, match="Source type must be either a string or DatabaseSourceType enum, got "): - AgentFactory.create_sql_tool(description="Test", source=db_path, source_type=123, schema="test") # Invalid type \ No newline at end of file + AgentFactory.create_sql_tool( + name="Test SQL", description="Test", source=db_path, source_type=123, schema="test" + ) # Invalid type From 9796d197a19c05d6ccb49c12af128b65b13f5018 Mon Sep 17 00:00:00 2001 From: kadirpekel Date: Wed, 30 Apr 2025 14:47:21 +0200 Subject: [PATCH 15/22] Eng 2051 Improvements on CI flow (#509) * ENG-2051: improvements on CI workflow * Update main.yaml --- .github/workflows/main.yaml | 87 +++++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 18 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index faf380eb..9f1ac34a 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -9,29 +9,79 @@ on: workflow_dispatch: jobs: - setup-and-test: + test: runs-on: ubuntu-latest + timeout-minutes: 45 # Global timeout fallback strategy: fail-fast: false matrix: test-suite: [ - 'tests/unit', - 'tests/functional/file_asset', - 'tests/functional/data_asset', - 'tests/functional/benchmark', - 'tests/functional/model', - 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v1 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v2 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v1 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v2 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/designer_test.py', - 'tests/functional/pipelines/create_test.py', - 'tests/functional/finetune --sdk_version v1 --sdk_version_param FinetuneFactory', - 'tests/functional/finetune --sdk_version v2 --sdk_version_param FinetuneFactory', - 'tests/functional/general_assets', - 'tests/functional/apikey', - 'tests/functional/agent tests/functional/team_agent', + 'unit', + 'file_asset', + 'data_asset', + 'benchmark', + 'model', + 'pipeline_2.0_v1', + 'pipeline_2.0_v2', + 'pipeline_3.0_v1', + 'pipeline_3.0_v2', + 'pipeline_designer', + 'pipeline_create', + 'finetune_v1', + 'finetune_v2', + 'general_assets', + 'apikey', + 'agent_and_team_agent', ] + include: + - test-suite: 'unit' + path: 'tests/unit' + timeout: 45 # Tweaked timeout for each unit tests + - test-suite: 'file_asset' + path: 'tests/functional/file_asset' + timeout: 45 + - test-suite: 'data_asset' + path: 'tests/functional/data_asset' + timeout: 45 + - test-suite: 'benchmark' + path: 'tests/functional/benchmark' + timeout: 45 + - test-suite: 'model' + path: 'tests/functional/model' + timeout: 45 + - test-suite: 'pipeline_2.0_v1' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v1 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_2.0_v2' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v2 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_3.0_v1' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v1 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_3.0_v2' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v2 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_designer' + path: 'tests/functional/pipelines/designer_test.py' + timeout: 45 + - test-suite: 'pipeline_create' + path: 'tests/functional/pipelines/create_test.py' + timeout: 45 + - test-suite: 'finetune_v1' + path: 'tests/functional/finetune --sdk_version v1 --sdk_version_param FinetuneFactory' + timeout: 45 + - test-suite: 'finetune_v2' + path: 'tests/functional/finetune --sdk_version v2 --sdk_version_param FinetuneFactory' + timeout: 45 + - test-suite: 'general_assets' + path: 'tests/functional/general_assets' + timeout: 45 + - test-suite: 'apikey' + path: 'tests/functional/apikey' + timeout: 45 + - test-suite: 'agent_and_team_agent' + path: 'tests/functional/agent tests/functional/team_agent' + timeout: 45 steps: - name: Checkout repository uses: actions/checkout@v4 @@ -62,4 +112,5 @@ jobs: fi - name: Run Tests - run: python -m pytest ${{ matrix.test-suite}} \ No newline at end of file + timeout-minutes: ${{ matrix.timeout }} + run: python -m pytest ${{ matrix.path }} From f0837fca2b737d0491fd66932b8204d8d785dee8 Mon Sep 17 00:00:00 2001 From: Muhammad-Elmallah <145364766+Muhammad-Elmallah@users.noreply.github.com> Date: Mon, 5 May 2025 14:48:27 +0300 Subject: [PATCH 16/22] Add a finetuned version of BGE model (#512) * Add a finetuned version of BGE model * changin mode name --- aixplain/enums/embedding_model.py | 4 +--- tests/functional/model/run_model_test.py | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aixplain/enums/embedding_model.py b/aixplain/enums/embedding_model.py index c52387b2..31618580 100644 --- a/aixplain/enums/embedding_model.py +++ b/aixplain/enums/embedding_model.py @@ -27,9 +27,7 @@ class EmbeddingModel(str, Enum): JINA_CLIP_V2_MULTIMODAL = "67c5f705d8f6a65d6f74d732" MULTILINGUAL_E5_LARGE = "67efd0772a0a850afa045af3" BGE_M3 = "67f401032a0a850afa045b19" - - - + AIXPLAIN_LEGAL_EMBEDDINGS = "681254b668e47e7844c1f15a" def __str__(self): return self._value_ diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 916d6077..1aed18df 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -99,6 +99,7 @@ def run_index_model(index_model): pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, AirParams, id="AIR - Snowflake Arctic Embed L v2.0"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), pytest.param(EmbeddingModel.BGE_M3, AirParams, id="AIR - BGE M3"), + pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="AIR - aiXplain Legal Embeddings"), ], ) def test_index_model(embedding_model, supplier_params): @@ -123,6 +124,7 @@ def test_index_model(embedding_model, supplier_params): pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, AirParams, id="Jina Clip v2 Multimodal"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), pytest.param(EmbeddingModel.BGE_M3, AirParams, id="BGE M3"), + pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="aiXplain Legal Embeddings"), ], ) def test_index_model_with_filter(embedding_model, supplier_params): From 511bf5f570dcec7011b2ac011653b319d7018e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Tue, 6 May 2025 13:33:05 +0300 Subject: [PATCH 17/22] ENG-2055-Aixplain-SDK-Centralized-Error-Handling (#510) * move request_with_retry from file_utils to request_utils because of circular import * added get error from response status for model/pipeline calls * added error codes * Small refactor and error code inclusion --------- Co-authored-by: Thiago Castro Ferreira --- .pre-commit-config.yaml | 2 +- aixplain/enums/license.py | 2 +- aixplain/enums/supplier.py | 2 +- aixplain/exceptions/__init__.py | 116 ++++++++++ aixplain/exceptions/types.py | 217 ++++++++++++++++++ aixplain/factories/agent_factory/__init__.py | 2 +- aixplain/factories/api_key_factory.py | 2 +- aixplain/factories/benchmark_factory.py | 2 +- aixplain/factories/corpus_factory.py | 2 +- aixplain/factories/data_factory.py | 8 +- aixplain/factories/dataset_factory.py | 3 +- .../factories/finetune_factory/__init__.py | 4 +- aixplain/factories/metric_factory.py | 6 +- aixplain/factories/model_factory/__init__.py | 2 +- aixplain/factories/model_factory/utils.py | 2 +- .../factories/pipeline_factory/__init__.py | 2 +- .../factories/team_agent_factory/__init__.py | 2 +- aixplain/factories/wallet_factory.py | 2 +- aixplain/modules/agent/__init__.py | 2 +- aixplain/modules/api_key.py | 2 +- aixplain/modules/benchmark.py | 2 +- aixplain/modules/benchmark_job.py | 5 +- aixplain/modules/corpus.py | 4 +- aixplain/modules/dataset.py | 2 +- aixplain/modules/finetune/__init__.py | 2 +- aixplain/modules/model/__init__.py | 5 +- aixplain/modules/model/llm_model.py | 1 + aixplain/modules/model/response.py | 6 + aixplain/modules/model/utility_model.py | 2 +- aixplain/modules/model/utils.py | 21 +- aixplain/modules/pipeline/asset.py | 24 +- aixplain/modules/team_agent/__init__.py | 2 +- .../data_onboarding/onboard_functions.py | 5 +- aixplain/utils/file_utils.py | 21 +- tests/test_utils.py | 4 +- tests/unit/llm_test.py | 8 +- tests/unit/model_test.py | 8 +- tests/unit/pipeline_test.py | 132 +---------- 38 files changed, 410 insertions(+), 226 deletions(-) create mode 100644 aixplain/exceptions/__init__.py create mode 100644 aixplain/exceptions/types.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 456aba3b..2500c2c8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,4 +22,4 @@ repos: hooks: - id: flake8 args: # arguments to configure flake8 - - --ignore=E402,E501,E203 \ No newline at end of file + - --ignore=E402,E501,E203,W503 \ No newline at end of file diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index a860a539..566be092 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -25,8 +25,8 @@ from enum import Enum from urllib.parse import urljoin from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.utils.request_utils import _request_with_retry CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" diff --git a/aixplain/enums/supplier.py b/aixplain/enums/supplier.py index 26058bf5..18a3e81d 100644 --- a/aixplain/enums/supplier.py +++ b/aixplain/enums/supplier.py @@ -24,7 +24,7 @@ import logging from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from enum import Enum from urllib.parse import urljoin import re diff --git a/aixplain/exceptions/__init__.py b/aixplain/exceptions/__init__.py new file mode 100644 index 00000000..5d645c62 --- /dev/null +++ b/aixplain/exceptions/__init__.py @@ -0,0 +1,116 @@ +""" +Error message registry for aiXplain SDK. + +This module maintains a centralized registry of error messages used throughout the aiXplain ecosystem. +It allows developers to look up existing error messages and reuse them instead of creating new ones. +""" + +from aixplain.exceptions.types import ( + AixplainBaseException, + AuthenticationError, + ValidationError, + ResourceError, + BillingError, + SupplierError, + NetworkError, + ServiceError, + InternalError, +) + + +def get_error_from_status_code(status_code: int, error_details: str = None) -> AixplainBaseException: + """ + Map HTTP status codes to appropriate exception types. + + Args: + status_code (int): The HTTP status code to map. + default_message (str, optional): The default message to use if no specific message is available. + + Returns: + AixplainBaseException: An exception of the appropriate type. + """ + try: + if isinstance(status_code, str): + status_code = int(status_code) + except Exception as e: + raise InternalError(f"Failed to get status code from {status_code}: {e}") from e + + error_details = f"Details: {error_details}" if error_details else "" + if status_code == 400: + return ValidationError( + message=f"Bad request: Please verify the request payload and ensure it is correct. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 401: + return AuthenticationError( + message=f"Unauthorized API key: Please verify the spelling of the API key and its current validity. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 402: + return BillingError( + message=f"Payment required: Please ensure you have enough credits to run this asset. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 403: + # 403 could be auth or resource, using ResourceError as a general 'forbidden' + return ResourceError( + message=f"Forbidden access: Please verify the API key and its current validity. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 404: + # Added 404 mapping + return ResourceError( + message=f"Resource not found: Please verify the spelling of the resource and its current availability. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 429: + # Using SupplierError for rate limiting as per your original function + return SupplierError( + message=f"Rate limit exceeded: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif status_code == 500: + return InternalError( + message=f"Internal server error: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif status_code == 503: + return ServiceError( + message=f"Service unavailable: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif status_code == 504: + return NetworkError( + message=f"Gateway timeout: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif 460 <= status_code < 470: + return ResourceError( + message=f"Subscription-related error: Please ensure that your subscription is active and has not expired. {error_details}".strip(), + status_code=status_code, + ) + elif 470 <= status_code < 480: + return BillingError( + message=f"Billing-related error: Please ensure you have enough credits to run this asset. {error_details}".strip(), + status_code=status_code, + ) + elif 480 <= status_code < 490: + return SupplierError( + message=f"Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access. {error_details}".strip(), + status_code=status_code, + ) + elif 490 <= status_code < 500: + return ValidationError( + message=f"Validation-related error: Please verify the request payload and ensure it is correct. {error_details}".strip(), + status_code=status_code, + ) + else: + # Catch-all for other client/server errors + category = "Client" if 400 <= status_code < 500 else "Server" + return InternalError( + message=f"Unspecified {category} Error (Status {status_code}) {error_details}".strip(), status_code=status_code + ) diff --git a/aixplain/exceptions/types.py b/aixplain/exceptions/types.py new file mode 100644 index 00000000..e41c7fd2 --- /dev/null +++ b/aixplain/exceptions/types.py @@ -0,0 +1,217 @@ +from enum import Enum +from typing import Optional, Dict, Any + + +class ErrorSeverity(str, Enum): + """Severity levels for errors.""" + + INFO = "info" # Informational, not an error + WARNING = "warning" # Warning, operation can continue + ERROR = "error" # Error, operation cannot continue + CRITICAL = "critical" # System stability might be compromised + + +class ErrorCategory(Enum): + """Categorizes errors by their domain.""" + + AUTHENTICATION = "authentication" # API keys, permissions + VALIDATION = "validation" # Input validation + RESOURCE = "resource" # Resource availability + BILLING = "billing" # Credits, payment + SUPPLIER = "supplier" # External supplier issues + NETWORK = "network" # Network connectivity + SERVICE = "service" # Service availability + INTERNAL = "internal" # Internal system errors + AGENT = "agent" # Agent-specific errors + UNKNOWN = "unknown" # Uncategorized errors + + +class ErrorCode(str, Enum): + """Standard error codes for aiXplain exceptions. + + The format is AX--, where is a short identifier + derived from the ErrorCategory (e.g., AUTH, VAL, RES) and is a + unique sequential number within that category, starting from 1000. + + How to Add a New Error Code: + 1. Identify the appropriate `ErrorCategory` for the new error. + 2. Determine the next available sequential ID within that category. + For example, if `AX-AUTH-1000` exists, the next authentication-specific + error could be `AX-AUTH-1001`. + 3. Define the new enum member using the format `AX--`. + Use a concise abbreviation for the category (e.g., AUTH, VAL, RES, BIL, + SUP, NET, SVC, INT). + 4. Assign the string value (e.g., `"AX-AUTH-1001"`). + 5. Add a clear docstring explaining the specific condition that triggers + this error code. + 6. (Optional but recommended) Consider creating a more specific exception + class inheriting from the corresponding category exception (e.g., + `class InvalidApiKeyError(AuthenticationError): ...`) and assign the + new error code to it. + """ + + AX_AUTH_ERROR = "AX-AUTH-1000" # General authentication error. Use for issues like invalid API keys, insufficient permissions, or failed login attempts. + AX_VAL_ERROR = "AX-VAL-1000" # General validation error. Use when user-provided input fails validation checks (e.g., incorrect data type, missing required fields, invalid format. + AX_RES_ERROR = "AX-RES-1000" # General resource error. Use for issues related to accessing or managing resources, such as a requested model being unavailable or quota limits exceeded. + AX_BIL_ERROR = "AX-BIL-1000" # General billing error. Use for problems related to billing, payments, or credits (e.g., insufficient funds, expired subscription. + AX_SUP_ERROR = "AX-SUP-1000" # General supplier error. Use when an error originates from an external supplier or third-party service integrated with aiXplain. + AX_NET_ERROR = "AX-NET-1000" # General network error. Use for issues related to network connectivity, such as timeouts, DNS resolution failures, or unreachable services. + AX_SVC_ERROR = "AX-SVC-1000" # General service error. Use when a specific aiXplain service or endpoint is unavailable or malfunctioning (e.g., service downtime, internal component failure. + AX_INT_ERROR = "AX-INT-1000" # General internal error. Use for unexpected server-side errors that are not covered by other categories. This often indicates a bug or an issue within the aiXplain platform itself. + + +class AixplainBaseException(Exception): + """Base exception class for all aiXplain exceptions.""" + + def __init__( + self, + message: str, + category: ErrorCategory = ErrorCategory.UNKNOWN, + severity: ErrorSeverity = ErrorSeverity.ERROR, + status_code: Optional[int] = None, + details: Optional[Dict[str, Any]] = None, + retry_recommended: bool = False, + error_code: Optional[ErrorCode] = None, + ): + self.message = message + self.category = category + self.severity = severity + self.status_code = status_code + self.details = details or {} + self.retry_recommended = retry_recommended + self.error_code = error_code + super().__init__(self.message) + + def __str__(self): + error_code_str = f" [{self.error_code}]" if self.error_code else "" + return f"{self.__class__.__name__}{error_code_str}: {self.message}" + + def to_dict(self) -> Dict[str, Any]: + """Convert exception to dictionary for serialization.""" + return { + "message": self.message, + "category": self.category.value, + "severity": self.severity.value, + "status_code": self.status_code, + "details": self.details, + "retry_recommended": self.retry_recommended, + "error_code": self.error_code.value if self.error_code else None, + } + + +class AuthenticationError(AixplainBaseException): + """Raised when authentication fails.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.AUTHENTICATION, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_AUTH_ERROR, + **kwargs, + ) + + +class ValidationError(AixplainBaseException): + """Raised when input validation fails.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.VALIDATION, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_VAL_ERROR, + **kwargs, + ) + + +class ResourceError(AixplainBaseException): + """Raised when a resource is unavailable.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.RESOURCE, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_RES_ERROR, + **kwargs, + ) + + +class BillingError(AixplainBaseException): + """Raised when there are billing issues.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.BILLING, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_BIL_ERROR, + **kwargs, + ) + + +class SupplierError(AixplainBaseException): + """Raised when there are issues with external suppliers.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.SUPPLIER, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", True), + error_code=ErrorCode.AX_SUP_ERROR, + **kwargs, + ) + + +class NetworkError(AixplainBaseException): + """Raised when there are network connectivity issues.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.NETWORK, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", True), + error_code=ErrorCode.AX_NET_ERROR, + **kwargs, + ) + + +class ServiceError(AixplainBaseException): + """Raised when a service is unavailable.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.SERVICE, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", True), + error_code=ErrorCode.AX_SVC_ERROR, + **kwargs, + ) + + +class InternalError(AixplainBaseException): + """Raised when there is an internal system error.""" + + def __init__(self, message: str, **kwargs): + # Server errors (5xx) should generally be retryable + status_code = kwargs.get("status_code") + retry_recommended = kwargs.pop("retry_recommended", False) + if status_code and status_code in [500, 502, 503, 504]: + retry_recommended = True + + super().__init__( + message=message, + category=ErrorCategory.INTERNAL, + severity=ErrorSeverity.ERROR, + retry_recommended=retry_recommended, + error_code=ErrorCode.AX_INT_ERROR, + **kwargs, + ) diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 040bcd71..9440ff81 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -41,7 +41,7 @@ from aixplain.utils import config from typing import Callable, Dict, List, Optional, Text, Union -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from aixplain.enums import DatabaseSourceType diff --git a/aixplain/factories/api_key_factory.py b/aixplain/factories/api_key_factory.py index c719c26b..3d081e27 100644 --- a/aixplain/factories/api_key_factory.py +++ b/aixplain/factories/api_key_factory.py @@ -3,7 +3,7 @@ import aixplain.utils.config as config from datetime import datetime from typing import Text, List, Optional, Dict, Union -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.api_key import APIKey, APIKeyLimits, APIKeyUsageLimit diff --git a/aixplain/factories/benchmark_factory.py b/aixplain/factories/benchmark_factory.py index 3f643a3d..c37f17a8 100644 --- a/aixplain/factories/benchmark_factory.py +++ b/aixplain/factories/benchmark_factory.py @@ -32,7 +32,7 @@ from aixplain.factories.dataset_factory import DatasetFactory from aixplain.factories.model_factory import ModelFactory from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin diff --git a/aixplain/factories/corpus_factory.py b/aixplain/factories/corpus_factory.py index db7aa44e..9563ad14 100644 --- a/aixplain/factories/corpus_factory.py +++ b/aixplain/factories/corpus_factory.py @@ -38,7 +38,7 @@ from aixplain.enums.language import Language from aixplain.enums.license import License from aixplain.enums.privacy import Privacy -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.utils import config from pathlib import Path from tqdm import tqdm diff --git a/aixplain/factories/data_factory.py b/aixplain/factories/data_factory.py index 1879b321..65aa8a87 100644 --- a/aixplain/factories/data_factory.py +++ b/aixplain/factories/data_factory.py @@ -28,15 +28,11 @@ from aixplain.modules.data import Data from aixplain.enums.data_subtype import DataSubtype from aixplain.enums.data_type import DataType -from aixplain.enums.function import Function from aixplain.enums.language import Language -from aixplain.enums.license import License from aixplain.enums.privacy import Privacy -from aixplain.utils.file_utils import _request_with_retry -from aixplain.utils import config -from typing import Any, Dict, List, Text +from aixplain.utils.request_utils import _request_with_retry +from typing import Dict, Text from urllib.parse import urljoin -from uuid import uuid4 class DataFactory(AssetFactory): diff --git a/aixplain/factories/dataset_factory.py b/aixplain/factories/dataset_factory.py index ca9d993e..3b86b45e 100644 --- a/aixplain/factories/dataset_factory.py +++ b/aixplain/factories/dataset_factory.py @@ -41,7 +41,8 @@ from aixplain.enums.privacy import Privacy from aixplain.utils import config from aixplain.utils.convert_datatype_utils import dict_to_metadata -from aixplain.utils.file_utils import _request_with_retry, s3_to_csv +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.file_utils import s3_to_csv from aixplain.utils.validation_utils import dataset_onboarding_validation from pathlib import Path from tqdm import tqdm diff --git a/aixplain/factories/finetune_factory/__init__.py b/aixplain/factories/finetune_factory/__init__.py index 238d0d0c..b6006f0b 100644 --- a/aixplain/factories/finetune_factory/__init__.py +++ b/aixplain/factories/finetune_factory/__init__.py @@ -33,7 +33,7 @@ from aixplain.modules.dataset import Dataset from aixplain.modules.model import Model from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin @@ -98,7 +98,7 @@ def create( if prompt_template is not None: prompt_template = validate_prompt(prompt_template, dataset_list) try: - url = urljoin(cls.backend_url, f"sdk/finetune/cost-estimation") + url = urljoin(cls.backend_url, "sdk/finetune/cost-estimation") headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} payload = { "datasets": [ diff --git a/aixplain/factories/metric_factory.py b/aixplain/factories/metric_factory.py index 9f42fb3e..6279ffc1 100644 --- a/aixplain/factories/metric_factory.py +++ b/aixplain/factories/metric_factory.py @@ -22,14 +22,12 @@ """ import logging -import os from typing import List, Optional from aixplain.modules import Metric from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from typing import Dict, Text from urllib.parse import urljoin -from warnings import warn class MetricFactory: @@ -113,7 +111,7 @@ def list( List[Metric]: List of supported metrics """ try: - url = urljoin(cls.backend_url, f"sdk/metrics") + url = urljoin(cls.backend_url, "sdk/metrics") filter_params = {} if model_id is not None: filter_params["modelId"] = model_id diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index b39cc668..85c1ac4f 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -27,7 +27,7 @@ from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 1be2186c..418a32b9 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -6,7 +6,7 @@ from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput from aixplain.enums import DataType, Function, Language, OwnershipType, Supplier, SortBy, SortOrder, AssetStatus from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from datetime import datetime from typing import Dict, Union, List, Optional, Tuple from urllib.parse import urljoin diff --git a/aixplain/factories/pipeline_factory/__init__.py b/aixplain/factories/pipeline_factory/__init__.py index cfbfce54..ba164199 100644 --- a/aixplain/factories/pipeline_factory/__init__.py +++ b/aixplain/factories/pipeline_factory/__init__.py @@ -31,7 +31,7 @@ from aixplain.modules.model import Model from aixplain.modules.pipeline import Pipeline from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from warnings import warn diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index 892d7ded..6a1db846 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -31,7 +31,7 @@ from aixplain.modules.team_agent import TeamAgent, InspectorTarget from aixplain.utils import config from aixplain.factories.team_agent_factory.utils import build_team_agent -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class TeamAgentFactory: diff --git a/aixplain/factories/wallet_factory.py b/aixplain/factories/wallet_factory.py index 1591dc2e..2de28ec4 100644 --- a/aixplain/factories/wallet_factory.py +++ b/aixplain/factories/wallet_factory.py @@ -1,6 +1,6 @@ import aixplain.utils.config as config from aixplain.modules.wallet import Wallet -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry import logging from typing import Text diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 227c0040..b6920f5c 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -26,7 +26,7 @@ import time import traceback -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier from aixplain.enums.asset_status import AssetStatus diff --git a/aixplain/modules/api_key.py b/aixplain/modules/api_key.py index ae774c23..d606e106 100644 --- a/aixplain/modules/api_key.py +++ b/aixplain/modules/api_key.py @@ -1,6 +1,6 @@ import logging from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.modules import Model from datetime import datetime from typing import Dict, List, Optional, Text, Union diff --git a/aixplain/modules/benchmark.py b/aixplain/modules/benchmark.py index d76b2e62..6878becf 100644 --- a/aixplain/modules/benchmark.py +++ b/aixplain/modules/benchmark.py @@ -26,7 +26,7 @@ from aixplain.modules import Asset, Dataset, Metric, Model from aixplain.modules.benchmark_job import BenchmarkJob from urllib.parse import urljoin -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class Benchmark(Asset): diff --git a/aixplain/modules/benchmark_job.py b/aixplain/modules/benchmark_job.py index 29a33aa7..cd17c0e1 100644 --- a/aixplain/modules/benchmark_job.py +++ b/aixplain/modules/benchmark_job.py @@ -4,7 +4,8 @@ from urllib.parse import urljoin import pandas as pd from pathlib import Path -from aixplain.utils.file_utils import _request_with_retry, save_file +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.file_utils import save_file class BenchmarkJob: @@ -72,7 +73,7 @@ def download_results_as_csv(self, save_path: Optional[Text] = None, return_dataf logging.info(f"Downloading Benchmark Results: Status of downloading results for {self.id}: {resp}") if "reportUrl" not in resp or resp["reportUrl"] == "": logging.error( - f"Downloading Benchmark Results: Can't get download results as they aren't generated yet. Please wait for a while." + "Downloading Benchmark Results: Can't get download results as they aren't generated yet. Please wait for a while." ) return None csv_url = resp["reportUrl"] diff --git a/aixplain/modules/corpus.py b/aixplain/modules/corpus.py index b65664b6..10101292 100644 --- a/aixplain/modules/corpus.py +++ b/aixplain/modules/corpus.py @@ -29,8 +29,8 @@ from aixplain.modules.asset import Asset from aixplain.modules.data import Data from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry -from typing import Any, List, Optional, Text +from aixplain.utils.request_utils import _request_with_retry +from typing import List, Optional, Text from urllib.parse import urljoin diff --git a/aixplain/modules/dataset.py b/aixplain/modules/dataset.py index fd79e9f3..85264013 100644 --- a/aixplain/modules/dataset.py +++ b/aixplain/modules/dataset.py @@ -30,7 +30,7 @@ from aixplain.modules.asset import Asset from aixplain.modules.data import Data from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from typing import Any, Dict, List, Optional, Text diff --git a/aixplain/modules/finetune/__init__.py b/aixplain/modules/finetune/__init__.py index fe2cb15c..15cc37a7 100644 --- a/aixplain/modules/finetune/__init__.py +++ b/aixplain/modules/finetune/__init__.py @@ -31,7 +31,7 @@ from aixplain.modules.model import Model from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class Finetune(Asset): diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index adedfcfb..4729dd42 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -28,7 +28,7 @@ from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config from urllib.parse import urljoin -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from typing import Union, Optional, Text, Dict from datetime import datetime from aixplain.modules.model.response import ModelResponse @@ -199,6 +199,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: status = ResponseStatus.FAILED else: status = ResponseStatus.IN_PROGRESS + logging.debug(f"Single Poll for Model: Status of polling for {name}: {resp}") return ModelResponse( status=resp.pop("status", status), @@ -209,6 +210,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: used_credits=resp.pop("usedCredits", 0), run_time=resp.pop("runTime", 0), usage=resp.pop("usage", None), + error_code=resp.get("error_code", None), **resp, ) except Exception as e: @@ -264,6 +266,7 @@ def run( used_credits=response.pop("usedCredits", 0), run_time=response.pop("runTime", 0), usage=response.pop("usage", None), + error_code=response.get("error_code", None), **response, ) diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index cf60d0a2..0b8e6cf8 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -166,6 +166,7 @@ def run( used_credits=response.pop("usedCredits", 0), run_time=response.pop("runTime", 0), usage=response.pop("usage", None), + error_code=response.get("error_code", None), **response, ) diff --git a/aixplain/modules/model/response.py b/aixplain/modules/model/response.py index ac9f8184..a0cf08f8 100644 --- a/aixplain/modules/model/response.py +++ b/aixplain/modules/model/response.py @@ -1,5 +1,6 @@ from typing import Text, Any, Optional, Dict, List, Union from aixplain.enums import ResponseStatus +from aixplain.exceptions.types import ErrorCode class ModelResponse: @@ -16,6 +17,7 @@ def __init__( run_time: float = 0.0, usage: Optional[Dict] = None, url: Optional[Text] = None, + error_code: Optional[ErrorCode] = None, **kwargs, ): self.status = status @@ -31,6 +33,7 @@ def __init__( self.run_time = run_time self.usage = usage self.url = url + self.error_code = error_code self.additional_fields = kwargs def __getitem__(self, key: Text) -> Any: @@ -82,6 +85,8 @@ def __repr__(self) -> str: fields.append(f"usage={self.usage}") if self.url: fields.append(f"url='{self.url}'") + if self.error_code: + fields.append(f"error_code='{self.error_code}'") if self.additional_fields: fields.extend([f"{k}={repr(v)}" for k, v in self.additional_fields.items()]) return f"ModelResponse({', '.join(fields)})" @@ -104,6 +109,7 @@ def to_dict(self) -> Dict[Text, Any]: "run_time": self.run_time, "usage": self.usage, "url": self.url, + "error_code": self.error_code, } if self.additional_fields: base_dict.update(self.additional_fields) diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 96454181..043571c2 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -24,7 +24,7 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.modules.model import Model from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.model.utils import parse_code_decorated from dataclasses import dataclass from typing import Callable, Union, Optional, List, Text, Dict diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 7ba42f2d..6f3f9319 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -4,6 +4,7 @@ import logging from aixplain.utils.file_utils import _request_with_retry from typing import Callable, Dict, List, Text, Tuple, Union, Optional +from aixplain.exceptions import get_error_from_status_code def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None): @@ -61,22 +62,12 @@ def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict: else: response = resp else: - resp = resp["error"] if isinstance(resp, dict) and "error" in resp else resp - if r.status_code == 401: - error = f"Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {resp}" - elif 460 <= r.status_code < 470: - error = f"Subscription-related error: Please ensure that your subscription is active and has not expired. Details: {resp}" - elif 470 <= r.status_code < 480: - error = f"Billing-related error: Please ensure you have enough credits to run this model. Details: {resp}" - elif 480 <= r.status_code < 490: - error = f"Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: {resp}" - elif 490 <= r.status_code < 500: - error = f"{resp}" - else: - status_code = str(r.status_code) - error = f"Status {status_code} - Unspecified error: {resp}" - response = {"status": "FAILED", "error_message": error, "completed": True} + error_details = resp["error"] if isinstance(resp, dict) and "error" in resp else resp + status_code = r.status_code + error = get_error_from_status_code(status_code, error_details) + logging.error(f"Error in request: {r.status_code}: {error}") + response = {"status": "FAILED", "error_message": error.message, "completed": True} return response diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 0642e6ee..63eac9b7 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -29,10 +29,11 @@ from aixplain.enums.response_status import ResponseStatus from aixplain.modules.asset import Asset from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin from aixplain.modules.pipeline.response import PipelineResponse +from aixplain.exceptions import get_error_from_status_code class Pipeline(Asset): @@ -410,33 +411,20 @@ def run_async( return res else: - if r.status_code == 401: - error = "Unauthorized API key: Please verify the spelling of the API key and its current validity." - elif 460 <= r.status_code < 470: - error = "Subscription-related error: Please ensure that your subscription is active and has not expired." - elif 470 <= r.status_code < 480: - error = "Billing-related error: Please ensure you have enough credits to run this pipeline. " - elif 480 <= r.status_code < 490: - error = "Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access." - elif 490 <= r.status_code < 500: - error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." - else: - status_code = str(r.status_code) - error = ( - f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." - ) + status_code = r.status_code + error = get_error_from_status_code(status_code) logging.error(f"Error in request for {name} (Pipeline ID '{self.id}') - {r.status_code}: {error}") if response_version == "v1": return { "status": "failed", - "error": error, + "error": error.message, "elapsed_time": None, **kwargs, } return PipelineResponse( status=ResponseStatus.FAILED, - error={"error": error, "status": "ERROR"}, + error={"error": error.message, "status": "ERROR"}, elapsed_time=None, **kwargs, ) diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 3014adb1..56f18633 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -41,7 +41,7 @@ from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.agent.utils import process_variables from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class InspectorTarget(str, Enum): diff --git a/aixplain/processes/data_onboarding/onboard_functions.py b/aixplain/processes/data_onboarding/onboard_functions.py index 01a3fe9b..09d1b153 100644 --- a/aixplain/processes/data_onboarding/onboard_functions.py +++ b/aixplain/processes/data_onboarding/onboard_functions.py @@ -18,7 +18,7 @@ from aixplain.modules.dataset import Dataset from aixplain.modules.file import File from aixplain.modules.metadata import MetaData -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Text, Union from urllib.parse import urljoin @@ -203,11 +203,10 @@ def build_payload_dataset( "description": dataset.description, "function": dataset.function.value, "onboardingErrorsPolicy": error_handler.value, - "tags": dataset.tags, + "tags": dataset.tags or tags, "privacy": dataset.privacy.value, "license": {"typeId": dataset.license.value}, "refData": ref_data, - "tags": tags, "data": [], "input": [], "hypotheses": [], diff --git a/aixplain/utils/file_utils.py b/aixplain/utils/file_utils.py index 58781dfb..f2bb55bc 100644 --- a/aixplain/utils/file_utils.py +++ b/aixplain/utils/file_utils.py @@ -20,10 +20,9 @@ import aixplain.utils.config as config from aixplain.enums.license import License - +from aixplain.utils.request_utils import _request_with_retry from collections import defaultdict from pathlib import Path -from requests.adapters import HTTPAdapter, Retry from typing import Any, Optional, Text, Union, Dict, List from uuid import uuid4 from urllib.parse import urljoin, urlparse @@ -52,24 +51,6 @@ def save_file(download_url: Text, download_file_path: Optional[Any] = None) -> A return download_file_path -def _request_with_retry(method: Text, url: Text, **params) -> requests.Response: - """Wrapper around requests with Session to retry in case it fails - - Args: - method (Text): HTTP method, such as 'GET' or 'HEAD'. - url (Text): The URL of the resource to fetch. - **params: Params to pass to request function - - Returns: - requests.Response: Response object of the request - """ - session = requests.Session() - retries = Retry(total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]) - session.mount("https://", HTTPAdapter(max_retries=retries)) - response = session.request(method=method.upper(), url=url, **params) - return response - - def download_data(url_link, local_filename=None): if local_filename is None: local_filename = url_link.split("/")[-1] diff --git a/tests/test_utils.py b/tests/test_utils.py index 264538d5..3d93fe25 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin import logging from aixplain.utils import config @@ -12,7 +12,7 @@ def delete_asset(model_id, api_key): def delete_service_account(api_key): - delete_url = urljoin(config.BACKEND_URL, f"sdk/ecr/logout") + delete_url = urljoin(config.BACKEND_URL, "sdk/ecr/logout") logging.debug(f"URL: {delete_url}") headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} _ = _request_with_retry("post", delete_url, headers=headers) diff --git a/tests/unit/llm_test.py b/tests/unit/llm_test.py index 753a8f7a..5d188470 100644 --- a/tests/unit/llm_test.py +++ b/tests/unit/llm_test.py @@ -24,17 +24,17 @@ ), ( 475, - "Billing-related error: Please ensure you have enough credits to run this model. Details: An unspecified error occurred while processing your request.", + "Billing-related error: Please ensure you have enough credits to run this asset. Details: An unspecified error occurred while processing your request.", ), ( 485, - "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: An unspecified error occurred while processing your request.", + "Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access. Details: An unspecified error occurred while processing your request.", ), ( 495, - "An unspecified error occurred while processing your request.", + "Validation-related error: Please verify the request payload and ensure it is correct. Details: An unspecified error occurred while processing your request.", ), - (501, "Status 501 - Unspecified error: An unspecified error occurred while processing your request."), + (501, "Unspecified Server Error (Status 501) Details: An unspecified error occurred while processing your request."), ], ) def test_run_async_errors(status_code, error_message): diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 65708d0c..f9e1a379 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -121,17 +121,17 @@ def test_failed_poll(): ), ( 475, - "Billing-related error: Please ensure you have enough credits to run this model. Details: An unspecified error occurred while processing your request.", + "Billing-related error: Please ensure you have enough credits to run this asset. Details: An unspecified error occurred while processing your request.", ), ( 485, - "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: An unspecified error occurred while processing your request.", + "Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access. Details: An unspecified error occurred while processing your request.", ), ( 495, - "An unspecified error occurred while processing your request.", + "Validation-related error: Please verify the request payload and ensure it is correct. Details: An unspecified error occurred while processing your request.", ), - (501, "Status 501 - Unspecified error: An unspecified error occurred while processing your request."), + (501, "Unspecified Server Error (Status 501) Details: An unspecified error occurred while processing your request."), ], ) def test_run_async_errors(status_code, error_message): diff --git a/tests/unit/pipeline_test.py b/tests/unit/pipeline_test.py index 7df95691..5f08eec3 100644 --- a/tests/unit/pipeline_test.py +++ b/tests/unit/pipeline_test.py @@ -35,12 +35,8 @@ def test_create_pipeline(): headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} ref_response = {"id": "12345"} mock.post(url, headers=headers, json=ref_response) - ref_pipeline = Pipeline( - id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY - ) - hyp_pipeline = PipelineFactory.create( - pipeline={"nodes": []}, name="Pipeline Test" - ) + ref_pipeline = Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) + hyp_pipeline = PipelineFactory.create(pipeline={"nodes": []}, name="Pipeline Test") assert hyp_pipeline.id == ref_pipeline.id assert hyp_pipeline.name == ref_pipeline.name @@ -58,19 +54,19 @@ def test_create_pipeline(): ), ( 475, - "{'error': 'Billing-related error: Please ensure you have enough credits to run this pipeline. ', 'status': 'ERROR'}", + "{'error': 'Billing-related error: Please ensure you have enough credits to run this asset.', 'status': 'ERROR'}", ), ( 485, - "{'error': 'Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access.', 'status': 'ERROR'}", + "{'error': 'Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access.', 'status': 'ERROR'}", ), ( 495, - "{'error': 'Validation-related error: Please ensure all required fields are provided and correctly formatted.', 'status': 'ERROR'}", + "{'error': 'Validation-related error: Please verify the request payload and ensure it is correct.', 'status': 'ERROR'}", ), ( 501, - "{'error': 'Status 501: Unspecified error: An unspecified error occurred while processing your request.', 'status': 'ERROR'}", + "{'error': 'Unspecified Server Error (Status 501)', 'status': 'ERROR'}", ), ], ) @@ -107,14 +103,9 @@ def test_list_pipelines_error_response(): mock.post(url, headers=headers, json=error_response, status_code=400) with pytest.raises(Exception) as excinfo: - PipelineFactory.list( - query=query, page_number=page_number, page_size=page_size - ) + PipelineFactory.list(query=query, page_number=page_number, page_size=page_size) - assert ( - "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" - in str(excinfo.value) - ) + assert "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" in str(excinfo.value) def test_get_pipeline_error_response(): @@ -132,112 +123,7 @@ def test_get_pipeline_error_response(): with pytest.raises(Exception) as excinfo: PipelineFactory.get(pipeline_id=pipeline_id) - assert ( - "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" - in str(excinfo.value) - ) - - -@pytest.fixture -def mock_pipeline(): - return Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) - - -def test_run_async_success(mock_pipeline): - with requests_mock.Mocker() as mock: - execute_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}" - ) - success_response = PipelineResponse( - status=ResponseStatus.SUCCESS, url=execute_url - ) - mock.post(execute_url, json=success_response.__dict__, status_code=200) - - response = mock_pipeline.run_async(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_run_sync_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}" - ) - execute_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}" - ) - success_response = PipelineResponse(status=ResponseStatus.SUCCESS, url=poll_url) - poll_response = PipelineResponse( - status=ResponseStatus.SUCCESS, data={"output": "poll_result"} - ) - mock.post(execute_url, json=success_response.__dict__, status_code=200) - mock.get(poll_url, json=poll_response.__dict__, status_code=200) - response = mock_pipeline.run(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_poll_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}" - ) - poll_response = PipelineResponse( - status=ResponseStatus.SUCCESS, data={"output": "poll_result"} - ) - mock.get(poll_url, json=poll_response.__dict__, status_code=200) - - response = mock_pipeline.poll(poll_url=poll_url) - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - assert response.data["output"] == "poll_result" - - -@pytest.fixture -def mock_pipeline(): - return Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) - - -def test_run_async_success(mock_pipeline): - with requests_mock.Mocker() as mock: - execute_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}") - success_response = PipelineResponse(status=ResponseStatus.SUCCESS, url=execute_url) - mock.post(execute_url, json=success_response.__dict__, status_code=200) - - response = mock_pipeline.run_async(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_run_sync_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}") - execute_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}") - success_response = {"status": "SUCCESS", "url": poll_url, "completed": True} - poll_response = {"status": "SUCCESS", "data": {"output": "poll_result"}, "completed": True} - mock.post(execute_url, json=success_response, status_code=200) - mock.get(poll_url, json=poll_response, status_code=200) - response = mock_pipeline.run(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_poll_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}") - poll_response = PipelineResponse(status=ResponseStatus.SUCCESS, data={"output": "poll_result"}) - mock.get(poll_url, json=poll_response.__dict__, status_code=200) - - response = mock_pipeline.poll(poll_url=poll_url) - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - assert response.data["output"] == "poll_result" + assert "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" in str(excinfo.value) @pytest.fixture From 0c4edf4fd6165212a5101ee20a01d5aa135e3842 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Tue, 6 May 2025 15:54:02 +0300 Subject: [PATCH 18/22] ENG-1862:Added status to Tools and deployment check for Agent and TeamAgent (#455) * Added status to Models, Pipelines * added status to Tools and deployment check for Agent and TeamAgent * DeployableMixin added * removed the deployement status check as it is moved to backend * for pipeline and model if id given check status added * model params set for model tool explicitly from object * BUG-499: Fix status update in class when update failed * Set to previous status if failed to update --- aixplain/enums/asset_status.py | 1 + aixplain/factories/model_factory/utils.py | 5 +- aixplain/modules/agent/__init__.py | 19 ++--- aixplain/modules/agent/tool/__init__.py | 3 + .../agent/tool/custom_python_code_tool.py | 6 ++ aixplain/modules/agent/tool/model_tool.py | 57 +++++++++++--- aixplain/modules/agent/tool/pipeline_tool.py | 10 ++- .../agent/tool/python_interpreter_tool.py | 2 + aixplain/modules/agent/tool/sql_tool.py | 3 +- aixplain/modules/mixins.py | 74 +++++++++++++++++++ aixplain/modules/model/__init__.py | 12 ++- aixplain/modules/model/utility_model.py | 13 ++-- aixplain/modules/pipeline/asset.py | 37 ++++++---- aixplain/modules/team_agent/__init__.py | 16 +--- tests/functional/model/run_model_test.py | 5 ++ tests/unit/agent/model_tool_test.py | 5 +- .../mock_responses/list_models_response.json | 11 ++- tests/unit/model_test.py | 8 +- tests/unit/team_agent_test.py | 25 ++++++- 19 files changed, 241 insertions(+), 71 deletions(-) create mode 100644 aixplain/modules/mixins.py diff --git a/aixplain/enums/asset_status.py b/aixplain/enums/asset_status.py index 994212fb..3d1e4323 100644 --- a/aixplain/enums/asset_status.py +++ b/aixplain/enums/asset_status.py @@ -43,3 +43,4 @@ class AssetStatus(Text, Enum): COMPLETED = "completed" CANCELING = "canceling" CANCELED = "canceled" + DEPRECATED_DRAFT = "deprecated_draft" diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 418a32b9..05a4b4c9 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -10,6 +10,7 @@ from datetime import datetime from typing import Dict, Union, List, Optional, Tuple from urllib.parse import urljoin +from aixplain.enums import AssetStatus import requests @@ -61,7 +62,6 @@ def create_model_from_response(response: Dict) -> Model: for param in response["params"] ] input_params = model_params - if not code: if "version" in response and response["version"]: version_link = response["version"]["id"] @@ -74,6 +74,7 @@ def create_model_from_response(response: Dict) -> Model: else: raise Exception("Utility Model Error: Code not found") + status = AssetStatus(response.get("status", AssetStatus.DRAFT.value)) created_at = None if "createdAt" in response and response["createdAt"]: created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00")) @@ -96,7 +97,7 @@ def create_model_from_response(response: Dict) -> Model: version=response["version"]["id"], inputs=inputs, temperature=temperature, - status=response.get("status", AssetStatus.DRAFT), + status=status, ) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index b6920f5c..8f19d3ec 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -26,26 +26,23 @@ import time import traceback -from aixplain.utils.request_utils import _request_with_retry -from aixplain.enums.function import Function -from aixplain.enums.supplier import Supplier -from aixplain.enums.asset_status import AssetStatus -from aixplain.enums.storage_type import StorageType +from aixplain.utils.file_utils import _request_with_retry +from aixplain.enums import Function, Supplier, AssetStatus, StorageType, ResponseStatus from aixplain.modules.model import Model from aixplain.modules.agent.agent_task import AgentTask from aixplain.modules.agent.output_format import OutputFormat from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData -from aixplain.enums import ResponseStatus from aixplain.modules.agent.utils import process_variables from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin from aixplain.utils import config +from aixplain.modules.mixins import DeployableMixin -class Agent(Model): +class Agent(Model, DeployableMixin[Tool]): """Advanced AI system capable of performing tasks by leveraging specialized software tools and resources from aiXplain marketplace. Attributes: @@ -212,6 +209,8 @@ def run( poll_url = response["url"] end = time.time() result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + # if result.status == ResponseStatus.FAILED: + # raise Exception("Model failed to run with error: " + result.error_message) result_data = result.get("data") or {} return AgentResponse( status=ResponseStatus.SUCCESS, @@ -419,11 +418,5 @@ def save(self) -> None: """Save the Agent.""" self.update() - def deploy(self) -> None: - assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed." - assert self.status != AssetStatus.ONBOARDED, "Agent is already deployed." - self.status = AssetStatus.ONBOARDED - self.update() - def __repr__(self): return f"Agent(id={self.id}, name={self.name}, function={self.function})" diff --git a/aixplain/modules/agent/tool/__init__.py b/aixplain/modules/agent/tool/__init__.py index aefa093a..93dc269d 100644 --- a/aixplain/modules/agent/tool/__init__.py +++ b/aixplain/modules/agent/tool/__init__.py @@ -23,6 +23,7 @@ from abc import ABC from typing import Optional, Text from aixplain.utils import config +from aixplain.enums import AssetStatus class Tool(ABC): @@ -40,6 +41,7 @@ def __init__( description: Text, version: Optional[Text] = None, api_key: Optional[Text] = config.TEAM_API_KEY, + status: Optional[AssetStatus] = AssetStatus.DRAFT, **additional_info, ) -> None: """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. @@ -55,6 +57,7 @@ def __init__( self.version = version self.api_key = api_key self.additional_info = additional_info + self.status = status def to_dict(self): """Converts the tool to a dictionary.""" diff --git a/aixplain/modules/agent/tool/custom_python_code_tool.py b/aixplain/modules/agent/tool/custom_python_code_tool.py index d544d2b1..9d6363c6 100644 --- a/aixplain/modules/agent/tool/custom_python_code_tool.py +++ b/aixplain/modules/agent/tool/custom_python_code_tool.py @@ -24,6 +24,7 @@ from typing import Text, Union, Callable, Optional from aixplain.modules.agent.tool import Tool import logging +from aixplain.enums import AssetStatus class CustomPythonCodeTool(Tool): @@ -35,6 +36,7 @@ def __init__( """Custom Python Code Tool""" super().__init__(name=name or "", description=description, **additional_info) self.code = code + self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool self.validate() def to_dict(self): @@ -67,6 +69,10 @@ def validate(self): ), "Custom Python Code Tool Error: Tool description is required" assert self.code and self.code.strip() != "", "Custom Python Code Tool Error: Code is required" assert self.name and self.name.strip() != "", "Custom Python Code Tool Error: Name is required" + assert self.status in [ + AssetStatus.DRAFT, + AssetStatus.ONBOARDED, + ], "Custom Python Code Tool Error: Status must be DRAFT or ONBOARDED" def __repr__(self) -> Text: return f"CustomPythonCodeTool(name={self.name})" diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 6a945a15..9b073a84 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -22,8 +22,7 @@ """ from typing import Optional, Union, Text, Dict, List -from aixplain.enums.function import Function -from aixplain.enums.supplier import Supplier +from aixplain.enums import AssetStatus, Function, Supplier from aixplain.modules.agent.tool import Tool from aixplain.modules.model import Model @@ -83,8 +82,26 @@ def __init__( """ name = name or "" super().__init__(name=name, description=description, **additional_info) + status = AssetStatus.ONBOARDED if model is None else AssetStatus.DRAFT + model_id = model # if None, Set id to None as default + self.model_object = None # Store the actual model object for parameter access + + if isinstance(model, Model): + model_id = model.id + status = model.status + self.model_object = model # Store the Model object + elif isinstance(model, Text): + # get model from id + try: + self.model_object = self._get_model(model) # Store the Model object + model_id = self.model_object.id + status = self.model_object.status + except Exception: + raise Exception(f"Model Tool Unavailable. Make sure Model '{model}' exists or you have access to it.") + self.supplier = supplier - self.model = model + self.model = model_id + self.status = status self.function = function self.parameters = parameters self.validate() @@ -109,6 +126,7 @@ def to_dict(self) -> Dict: "version": self.version if self.version else None, "assetId": self.model.id if self.model is not None and isinstance(self.model, Model) else self.model, "parameters": self.parameters, + "status": self.status, } def validate(self) -> None: @@ -123,7 +141,6 @@ def validate(self) -> None: - If the description is empty, it sets the description to the function description or the model description. """ from aixplain.enums import FunctionInputOutput - from aixplain.factories.model_factory import ModelFactory assert ( self.function is not None or self.model is not None @@ -145,7 +162,7 @@ def validate(self) -> None: if self.model is not None: if isinstance(self.model, Text) is True: try: - self.model = ModelFactory.get(self.model, api_key=self.api_key) + self.model = self._get_model() except Exception: raise Exception(f"Model Tool Unavailable. Make sure Model '{self.model}' exists or you have access to it.") self.function = self.model.function @@ -166,8 +183,22 @@ def validate(self) -> None: self.name = self.name if self.name else set_tool_name(self.function, self.supplier, self.model) def get_parameters(self) -> Dict: + # If parameters were not explicitly provided, get them from the model + if ( + self.parameters is None + and self.model_object is not None # noqa: W503 + and hasattr(self.model_object, "model_params") # noqa: W503 + and self.model_object.model_params is not None # noqa: W503 + ): + return self.model_object.model_params.to_list() return self.parameters + def _get_model(self, model_id: Text = None): + from aixplain.factories.model_factory import ModelFactory + + model_id = model_id or self.model + return ModelFactory.get(model_id, api_key=self.api_key) + def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) -> Optional[List[Dict]]: """Validates and formats the parameters for the tool. @@ -182,8 +213,12 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) """ if received_parameters is None: # Get default parameters if none provided - if self.model is not None and self.model.model_params is not None: - return self.model.model_params.to_list() + if ( + self.model_object is not None + and hasattr(self.model_object, "model_params") # noqa: W503 + and self.model_object.model_params is not None # noqa: W503 + ): + return self.model_object.model_params.to_list() elif self.function is not None: function_params = self.function.get_parameters() if function_params is not None: @@ -192,8 +227,12 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) # Get expected parameters expected_params = None - if self.model is not None and self.model.model_params is not None: - expected_params = self.model.model_params + if ( + self.model_object is not None + and hasattr(self.model_object, "model_params") # noqa: W503 + and self.model_object.model_params is not None # noqa: W503 + ): + expected_params = self.model_object.model_params elif self.function is not None: expected_params = self.function.get_parameters() diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index 0de83916..eccc80f7 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -24,6 +24,7 @@ from aixplain.modules.agent.tool import Tool from aixplain.modules.pipeline import Pipeline +from aixplain.enums import AssetStatus class PipelineTool(Tool): @@ -50,6 +51,7 @@ def __init__( name = name or "" super().__init__(name=name, description=description, **additional_info) + self.status = AssetStatus.DRAFT self.pipeline = pipeline self.validate() @@ -59,8 +61,12 @@ def to_dict(self): "name": self.name, "description": self.description, "type": "pipeline", + "status": self.status, } + def __repr__(self) -> Text: + return f"PipelineTool(name={self.name}, pipeline={self.pipeline})" + def validate(self): from aixplain.factories.pipeline_factory import PipelineFactory @@ -76,6 +82,4 @@ def validate(self): if self.name.strip() == "": self.name = pipeline_obj.name - - def __repr__(self) -> Text: - return f"PipelineTool(name={self.name}, pipeline={self.pipeline})" + self.status = pipeline_obj.status diff --git a/aixplain/modules/agent/tool/python_interpreter_tool.py b/aixplain/modules/agent/tool/python_interpreter_tool.py index 42621c45..5e1e45e5 100644 --- a/aixplain/modules/agent/tool/python_interpreter_tool.py +++ b/aixplain/modules/agent/tool/python_interpreter_tool.py @@ -22,6 +22,7 @@ """ from aixplain.modules.agent.tool import Tool +from aixplain.enums import AssetStatus from typing import Text @@ -32,6 +33,7 @@ def __init__(self, **additional_info) -> None: """Python Interpreter Tool""" description = "A Python shell. Use this to execute python commands. Input should be a valid python command." super().__init__(name="Python Interpreter", description=description, **additional_info) + self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool def to_dict(self): return { diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index 56cf116c..f2262ee9 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -27,7 +27,7 @@ import numpy as np from typing import Text, Optional, Dict, List, Union import sqlite3 - +from aixplain.enums import AssetStatus from aixplain.modules.agent.tool import Tool @@ -284,6 +284,7 @@ def __init__( self.schema = schema self.tables = tables if isinstance(tables, list) else [tables] if tables else None self.enable_commit = enable_commit + self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool def to_dict(self) -> Dict[str, Text]: return { diff --git a/aixplain/modules/mixins.py b/aixplain/modules/mixins.py new file mode 100644 index 00000000..0b402cf2 --- /dev/null +++ b/aixplain/modules/mixins.py @@ -0,0 +1,74 @@ +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: November 25th 2024 +Description: + Mixins for common functionality across different asset types +""" +from abc import ABC +from typing import TypeVar, Generic +from aixplain.enums import AssetStatus + +T = TypeVar("T") + + +class DeployableMixin(ABC, Generic[T]): + """A mixin that provides common deployment-related functionality for assets. + + This mixin provides methods for: + 1. Filtering items that are not onboarded + 2. Validating if an asset is ready to be deployed + 3. Deploying an asset + + Classes that inherit from this mixin should: + 1. Implement _validate_deployment_readiness to call the parent implementation with their specific asset type + 2. Optionally override deploy() if they need special deployment handling + """ + + def _validate_deployment_readiness(self) -> None: + """Validate if the asset is ready to be deployed. + + Args: + asset_type (str): Type of asset being validated (e.g. "Agent", "Team Agent", "Pipeline") + items (Optional[List[T]], optional): List of items to validate (e.g. tools for Agent, agents for TeamAgent) + + Raises: + ValueError: If the asset is not ready to be deployed + """ + asset_type = self.__class__.__name__ + if self.status == AssetStatus.ONBOARDED: + raise ValueError(f"{asset_type} is already deployed.") + + if self.status != AssetStatus.DRAFT: + raise ValueError(f"{asset_type} must be in DRAFT status to be deployed.") + + def deploy(self) -> None: + """Deploy the asset. + + This method validates that the asset is ready to be deployed and updates its status to ONBOARDED. + Classes that need special deployment handling should override this method. + + Raises: + ValueError: If the asset is not ready to be deployed + """ + self._validate_deployment_readiness() + previous_status = self.status + try: + self.status = AssetStatus.ONBOARDED + self.update() + except Exception as e: + self.status = previous_status + raise Exception(f"Error deploying because of backend error: {e}") from e diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 4729dd42..49cf4591 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -34,6 +34,7 @@ from aixplain.modules.model.response import ModelResponse from aixplain.enums.response_status import ResponseStatus from aixplain.modules.model.model_parameters import ModelParameters +from aixplain.enums import AssetStatus class Model(Asset): @@ -72,6 +73,7 @@ def __init__( input_params: Optional[Dict] = None, output_params: Optional[Dict] = None, model_params: Optional[Dict] = None, + status: Optional[AssetStatus] = AssetStatus.ONBOARDED, # default status for models is ONBOARDED **additional_info, ) -> None: """Model Init @@ -89,6 +91,7 @@ def __init__( input_params (Dict, optional): input parameters for the function. output_params (Dict, optional): output parameters for the function. model_params (Dict, optional): parameters for the function. + status (AssetStatus, optional): status of the model. Defaults to None. **additional_info: Any additional Model info to be saved """ super().__init__(id, name, description, supplier, version, cost=cost) @@ -102,6 +105,12 @@ def __init__( self.input_params = input_params self.output_params = output_params self.model_params = ModelParameters(model_params) if model_params else None + if isinstance(status, str): + try: + status = AssetStatus(status) + except Exception: + status = AssetStatus.ONBOARDED + self.status = status def to_dict(self) -> Dict: """Get the model info as a Dictionary @@ -119,6 +128,7 @@ def to_dict(self) -> Dict: "input_params": self.input_params, "output_params": self.output_params, "model_params": self.model_params.to_dict(), + "status": self.status, } def get_parameters(self) -> ModelParameters: @@ -309,7 +319,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): Returns: FinetuneStatus: The status of the FineTune model. """ - from aixplain.enums.asset_status import AssetStatus + from aixplain.enums import AssetStatus from aixplain.modules.finetune.status import FinetuneStatus headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 043571c2..1cbbf3da 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -21,7 +21,7 @@ import logging import warnings from aixplain.enums import Function, Supplier, DataType -from aixplain.enums.asset_status import AssetStatus +from aixplain.enums import AssetStatus from aixplain.modules.model import Model from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry @@ -29,6 +29,7 @@ from dataclasses import dataclass from typing import Callable, Union, Optional, List, Text, Dict from urllib.parse import urljoin +from aixplain.modules.mixins import DeployableMixin @dataclass @@ -88,7 +89,7 @@ def decorator(func): return decorator -class UtilityModel(Model): +class UtilityModel(Model, DeployableMixin): """Ready-to-use Utility Model. Note: Non-deployed utility models (status=DRAFT) will expire after 24 hours after creation. @@ -107,6 +108,7 @@ class UtilityModel(Model): function (Function, optional): model AI function. Defaults to None. is_subscribed (bool, optional): Is the user subscribed. Defaults to False. cost (Dict, optional): model price. Defaults to None. + status (AssetStatus, optional): status of the model. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Model info to be saved """ @@ -155,6 +157,7 @@ def __init__( function=function, is_subscribed=is_subscribed, api_key=api_key, + status=status, **additional_info, ) self.url = config.MODELS_RUN_URL @@ -274,9 +277,3 @@ def delete(self): message = f"Utility Model Deletion Error: {response}" logging.error(message) raise Exception(f"{message}") - - def deploy(self) -> None: - assert self.status == AssetStatus.DRAFT, "Utility Model must be in draft status to be deployed." - assert self.status != AssetStatus.ONBOARDED, "Utility Model is already deployed." - self.status = AssetStatus.ONBOARDED - self.update() diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 63eac9b7..cc234337 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -1,7 +1,7 @@ __author__ = "aiXplain" """ -Copyright 2022 The aiXplain SDK authors +Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,28 +15,28 @@ See the License for the specific language governing permissions and limitations under the License. -Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli -Date: September 1st 2022 +Author: Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: November 25th 2024 Description: - Pipeline Class + Pipeline Asset Class """ import time import json import os import logging -from aixplain.enums.asset_status import AssetStatus -from aixplain.enums.response_status import ResponseStatus -from aixplain.modules.asset import Asset +from aixplain.enums import AssetStatus, ResponseStatus +from aixplain.modules import Asset from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin from aixplain.modules.pipeline.response import PipelineResponse +from aixplain.modules.mixins import DeployableMixin from aixplain.exceptions import get_error_from_status_code -class Pipeline(Asset): +class Pipeline(Asset, DeployableMixin): """Representing a custom pipeline that was created on the aiXplain Platform Attributes: @@ -46,6 +46,7 @@ class Pipeline(Asset): url (Text, optional): running URL of platform. Defaults to config.BACKEND_URL. supplier (Text, optional): Pipeline supplier. Defaults to "aiXplain". version (Text, optional): version of the pipeline. Defaults to "1.0". + status (AssetStatus, optional): Pipeline status. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Pipeline info to be saved """ @@ -581,13 +582,23 @@ def save( raise Exception(e) def deploy(self, api_key: Optional[Text] = None) -> None: - """Deploy the Pipeline.""" - assert self.status == "draft", "Pipeline Deployment Error: Pipeline must be in draft status." - assert self.status != "onboarded", "Pipeline Deployment Error: Pipeline must be onboarded." + """Deploy the Pipeline. + This method overrides the deploy method in DeployableMixin to handle + Pipeline-specific deployment functionality. + + Args: + api_key (Optional[Text], optional): Team API Key to deploy the Pipeline. Defaults to None. + """ + self._validate_deployment_readiness() pipeline = self.to_dict() - self.update(pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name) - self.status = AssetStatus.ONBOARDED + previous_status = self.status + try: + self.status = AssetStatus.ONBOARDED + self.update(pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name) + except Exception as e: + self.status = previous_status + raise Exception(f"Error deploying because of backend error: {e}") from e def __repr__(self): return f"Pipeline(id={self.id}, name={self.name})" diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 56f18633..449e9549 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -41,7 +41,8 @@ from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.agent.utils import process_variables from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.file_utils import _request_with_retry +from aixplain.modules.mixins import DeployableMixin class InspectorTarget(str, Enum): @@ -53,7 +54,7 @@ def __str__(self): return self._value_ -class TeamAgent(Model): +class TeamAgent(Model, DeployableMixin[Agent]): """Advanced AI system capable of using multiple agents to perform a variety of tasks. Attributes: @@ -416,14 +417,3 @@ def update(self) -> None: else: error_msg = f"Team Agent Update Error (HTTP {r.status_code}): {resp}" raise Exception(error_msg) - - def save(self) -> None: - """Save the Team Agent.""" - self.update() - - def deploy(self) -> None: - """Deploy the Team Agent.""" - assert self.status == AssetStatus.DRAFT, "Team Agent Deployment Error: Team Agent must be in draft status." - assert self.status != AssetStatus.ONBOARDED, "Team Agent Deployment Error: Team Agent must be onboarded." - self.status = AssetStatus.ONBOARDED - self.update() diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 1aed18df..25a28640 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -237,3 +237,8 @@ def test_index_model_air_with_image(): assert "hurricane" in second_record.lower() index_model.delete() + + import os + + if os.path.exists("hurricane.jpeg"): + os.remove("hurricane.jpeg") diff --git a/tests/unit/agent/model_tool_test.py b/tests/unit/agent/model_tool_test.py index 869979e1..5dfc736e 100644 --- a/tests/unit/agent/model_tool_test.py +++ b/tests/unit/agent/model_tool_test.py @@ -4,11 +4,10 @@ from unittest.mock import MagicMock from aixplain.modules.agent.tool.model_tool import ModelTool -from aixplain.enums.function import Function -from aixplain.enums.supplier import Supplier from aixplain.modules.model import Model from aixplain.modules.model.model_parameters import ModelParameters from aixplain.base.parameters import Parameter +from aixplain.enums import AssetStatus, Function, Supplier @pytest.fixture @@ -19,6 +18,7 @@ def mock_model(): model.supplier = Supplier.AIXPLAIN model.name = "Test Model" model.description = "Test Model Description" + model.status = AssetStatus.ONBOARDED model.model_params = ModelParameters( { "sourcelanguage": {"name": "sourcelanguage", "required": True}, @@ -107,6 +107,7 @@ def test_to_dict(mocker, mock_model, mock_model_factory): "version": None, "assetId": "test_model_id", "parameters": [{"name": "sourcelanguage", "value": "en"}, {"name": "targetlanguage", "value": "es"}], + "status": mock_model.status.value, } result = tool.to_dict() diff --git a/tests/unit/mock_responses/list_models_response.json b/tests/unit/mock_responses/list_models_response.json index de2cb7ba..c927ba0b 100644 --- a/tests/unit/mock_responses/list_models_response.json +++ b/tests/unit/mock_responses/list_models_response.json @@ -1,4 +1,3 @@ - { "total": 15, "pageTotal": 10, @@ -7,6 +6,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Transformerml6x6)", "serviceName": "Machine Translation", + "status": "onboarded", "supplier": { "id": "Nvidia", "name": "Nvidia" @@ -129,6 +129,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Transformer12x2)", "serviceName": "Machine Translation", + "status": "onboarded", "supplier": { "id": "Nvidia", "name": "Nvidia" @@ -250,6 +251,7 @@ { "id": "test_asset_id", "name": "Translate from English to French", + "status": "onboarded", "supplier": { "id": "eBay", "name": "eBay" @@ -376,6 +378,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Canada)", "serviceName": "Amazon Translate", + "status": "onboarded", "supplier": { "id": "AWS", "name": "AWS" @@ -507,6 +510,7 @@ "id": "test_asset_id", "name": "Translate from English to French", "serviceName": "Cloud Translation", + "status": "onboarded", "supplier": { "id": "Google", "name": "Google" @@ -632,6 +636,7 @@ { "id": "test_asset_id", "name": "Translate from English to French", + "status": "onboarded", "supplier": { "id": "ModernMT", "name": "ModernMT" @@ -758,6 +763,7 @@ "id": "test_asset_id", "name": "Translate from English to French", "serviceName": "Cognitive Services", + "status": "onboarded", "supplier": { "id": "Azure", "name": "Azure" @@ -884,6 +890,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Canada)", "serviceName": "Cognitive Services", + "status": "onboarded", "supplier": { "id": "Azure", "name": "Azure" @@ -1010,6 +1017,7 @@ { "id": "test_asset_id", "name": "Translate from English to French", + "status": "onboarded", "supplier": { "id": "AppTek", "name": "AppTek" @@ -1136,6 +1144,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Transformerml24x6)", "serviceName": "Machine Translation", + "status": "onboarded", "supplier": { "id": "Nvidia", "name": "Nvidia" diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index f9e1a379..943d501d 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -204,6 +204,7 @@ def test_get_model_from_ids(): { "id": "test-model-id-1", "name": "Test Model 1", + "status": "onboarded", "description": "Test Description 1", "function": {"id": "text-generation"}, "supplier": {"id": "aiXplain"}, @@ -214,6 +215,7 @@ def test_get_model_from_ids(): { "id": "test-model-id-2", "name": "Test Model 2", + "status": "onboarded", "description": "Test Description 2", "function": {"id": "text-generation"}, "supplier": {"id": "aiXplain"}, @@ -237,9 +239,9 @@ def test_list_models_error(): with pytest.raises(Exception) as excinfo: ModelFactory.list(model_ids=model_ids, function=Function.TEXT_GENERATION, api_key=config.AIXPLAIN_API_KEY) - assert ( - str(excinfo.value) - == "Cannot filter by function, suppliers, source languages, target languages, is finetunable, ownership, sort by when using model ids" + assert str(excinfo.value) == ( + "Cannot filter by function, suppliers, " + "source languages, target languages, is finetunable, ownership, sort by when using model ids" ) with pytest.raises(Exception) as excinfo: diff --git a/tests/unit/team_agent_test.py b/tests/unit/team_agent_test.py index 2b06043e..e84e2e34 100644 --- a/tests/unit/team_agent_test.py +++ b/tests/unit/team_agent_test.py @@ -1,7 +1,7 @@ import pytest import requests_mock from urllib.parse import urljoin -from unittest.mock import patch +from unittest.mock import patch, Mock from aixplain.enums.asset_status import AssetStatus from aixplain.factories import TeamAgentFactory @@ -181,7 +181,7 @@ def test_create_team_agent(mock_model_factory_get): "role": "Test Agent Role", "teamId": "123", "version": "1.0", - "status": "draft", + "status": "onboarded", "llmId": "6646261c6eb563165658bbb1", "pricing": {"currency": "USD", "value": 0.0}, "assets": [ @@ -502,3 +502,24 @@ def get_mock(agent_id): agent1 = next((agent for agent in team_agent.agents if agent.id == "agent1"), None) assert agent1 is not None assert agent1.tasks[0].dependencies[0].name == "Test Task 2" + + +def test_deploy_team_agent(): + # Create a mock agent with ONBOARDED status + mock_agent = Mock() + mock_agent.id = "agent-id" + mock_agent.name = "Test Agent" + mock_agent.status = AssetStatus.ONBOARDED + + # Create the team agent + team_agent = TeamAgent(id="team-agent-id", name="Test Team Agent", agents=[mock_agent], status=AssetStatus.DRAFT) + + # Mock the update method + team_agent.update = Mock() + + # Deploy the team agent + team_agent.deploy() + + # Verify that status was updated and update was called + assert team_agent.status == AssetStatus.ONBOARDED + team_agent.update.assert_called_once() From 84baed8b98f0802f222fc294c680ffeb1cd622c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Fri, 9 May 2025 11:37:59 +0300 Subject: [PATCH 19/22] Fix: BUG-543 SQL Tool upload db issue (#519) --- aixplain/factories/agent_factory/__init__.py | 5 +++-- aixplain/modules/agent/tool/sql_tool.py | 2 +- aixplain/v2/agent.py | 2 ++ .../functional/agent/agent_functional_test.py | 21 ++++++++++++++----- tests/unit/agent/agent_factory_utils_test.py | 13 ++++++++---- tests/unit/agent/sql_tool_test.py | 14 ++++++++----- 6 files changed, 40 insertions(+), 17 deletions(-) diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 9440ff81..156f1b54 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -177,7 +177,9 @@ def create_model_tool( supplier = supplier_ break assert isinstance(supplier, Supplier), f"Supplier {supplier} is not a valid supplier" - return ModelTool(function=function, supplier=supplier, model=model, description=description, parameters=parameters) + return ModelTool( + function=function, supplier=supplier, model=model, name=name, description=description, parameters=parameters + ) @classmethod def create_pipeline_tool( @@ -278,7 +280,6 @@ def create_sql_tool( base_name = os.path.splitext(os.path.basename(source))[0] db_path = os.path.join(os.path.dirname(source), f"{base_name}.db") table_name = tables[0] if tables else None - try: # Create database from CSV schema = create_database_from_csv(source, db_path, table_name) diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index f2262ee9..5b3f4178 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -285,6 +285,7 @@ def __init__( self.tables = tables if isinstance(tables, list) else [tables] if tables else None self.enable_commit = enable_commit self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool + self.validate() # to upload the database def to_dict(self) -> Dict[str, Text]: return { @@ -306,7 +307,6 @@ def validate(self): raise SQLToolError("Description is required") if not self.database: raise SQLToolError("Database must be provided") - # Handle database validation if not ( str(self.database).startswith("s3://") diff --git a/aixplain/v2/agent.py b/aixplain/v2/agent.py index eadb155f..a493f723 100644 --- a/aixplain/v2/agent.py +++ b/aixplain/v2/agent.py @@ -131,6 +131,7 @@ def create_custom_python_code_tool( @classmethod def create_sql_tool( cls, + name: str, description: str, source: str, source_type: str, @@ -172,6 +173,7 @@ def create_sql_tool( from aixplain.factories import AgentFactory return AgentFactory.create_sql_tool( + name=name, description=description, source=source, source_type=source_type, diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 73bab1fe..10569440 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -354,6 +354,7 @@ def test_specific_model_parameters_e2e(tool_config, delete_agents_and_team_agent @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents + agent = None try: import os @@ -362,6 +363,7 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): f.write("") tool = AgentFactory.create_sql_tool( + name="Teste", description="Execute an SQL query and return the result", source="ftest.db", source_type="sqlite", @@ -394,11 +396,13 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): assert "eve" in str(response["data"]["output"]).lower() finally: os.remove("ftest.db") - agent.delete() + if agent: + agent.delete() @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents + agent = None try: import os import pandas as pd @@ -424,7 +428,11 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): # Create SQL tool from CSV tool = AgentFactory.create_sql_tool( - description="Execute SQL queries on employee data", source="test.csv", source_type="csv", tables=["employees"] + name="CSV Tool Test", + description="Execute SQL queries on employee data", + source="test.csv", + source_type="csv", + tables=["employees"], ) # Verify tool setup @@ -470,9 +478,12 @@ def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): finally: # Cleanup - os.remove("test.csv") - os.remove("test.db") - agent.delete() + if agent: + agent.delete() + if os.path.exists("test.csv"): + os.remove("test.csv") + if os.path.exists("test.db"): + os.remove("test.db") @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) diff --git a/tests/unit/agent/agent_factory_utils_test.py b/tests/unit/agent/agent_factory_utils_test.py index c0104f5c..f1b4da70 100644 --- a/tests/unit/agent/agent_factory_utils_test.py +++ b/tests/unit/agent/agent_factory_utils_test.py @@ -12,6 +12,7 @@ from aixplain.modules.agent import Agent from aixplain.modules.agent.agent_task import AgentTask from aixplain.factories import ModelFactory, PipelineFactory +import os @pytest.fixture @@ -143,7 +144,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): "name": "Test SQL", "description": "Test SQL", "parameters": [ - {"name": "database", "value": "test_db"}, + {"name": "database", "value": "test_db.db"}, {"name": "schema", "value": "public"}, {"name": "tables", "value": "table1,table2"}, {"name": "enable_commit", "value": True}, @@ -153,7 +154,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): { "name": "Test SQL", "description": "Test SQL", - "database": "test_db", + "database": "test_db.db", "schema": "public", "tables": ["table1", "table2"], "enable_commit": True, @@ -166,7 +167,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): "name": "Test SQL", "description": "Test SQL with string enable_commit", "parameters": [ - {"name": "database", "value": "test_db"}, + {"name": "database", "value": "test_db.db"}, {"name": "schema", "value": "public"}, {"name": "tables", "value": "table1"}, {"name": "enable_commit", "value": True}, @@ -176,7 +177,7 @@ def test_build_tool_error_cases(tool_dict, expected_error): { "name": "Test SQL", "description": "Test SQL with string enable_commit", - "database": "test_db", + "database": "test_db.db", "schema": "public", "tables": ["table1"], "enable_commit": True, @@ -192,6 +193,8 @@ def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock "aixplain.modules.model.utils.parse_code_decorated", return_value=("print('Hello World')", [], "Test description", "test_name"), ) + mocker.patch("os.path.exists", lambda path: True if path == "test_db.db" else os.path.exists(path)) + mocker.patch("aixplain.factories.file_factory.FileFactory.upload", return_value="s3://mocked-file-path/test_db.db") if tool_dict["type"] == "pipeline": mocker.patch.object( PipelineFactory, @@ -205,6 +208,8 @@ def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock for attr, value in expected_attrs.items(): if attr == "model": assert tool.model == mock_model + elif attr == "database" and value == "test_db.db": + assert getattr(tool, attr) == "s3://mocked-file-path/test_db.db" else: assert getattr(tool, attr) == value diff --git a/tests/unit/agent/sql_tool_test.py b/tests/unit/agent/sql_tool_test.py index 1a1dbb84..9f03a279 100644 --- a/tests/unit/agent/sql_tool_test.py +++ b/tests/unit/agent/sql_tool_test.py @@ -59,7 +59,9 @@ def test_create_sql_tool(mocker, tmp_path): ) assert isinstance(tool, SQLTool) assert tool.description == "Test" - assert tool.database == db_path + assert os.path.basename(db_path) in os.path.basename(tool.database) + assert tool.database.startswith("s3://") + assert tool.database.endswith(".db") assert tool.schema == "test" assert tool.tables == ["test", "test2"] @@ -78,7 +80,7 @@ def test_create_sql_tool(mocker, tmp_path): tool_dict = tool.to_dict() assert tool_dict["description"] == "Test" assert tool_dict["parameters"] == [ - {"name": "database", "value": db_path}, + {"name": "database", "value": tool.database}, {"name": "schema", "value": "test"}, {"name": "tables", "value": "test,test2"}, {"name": "enable_commit", "value": False}, @@ -89,7 +91,8 @@ def test_create_sql_tool(mocker, tmp_path): mocker.patch("os.path.exists", return_value=True) mocker.patch("aixplain.modules.agent.tool.sql_tool.get_table_schema", return_value="CREATE TABLE test (id INTEGER)") tool.validate() - assert tool.database == "s3://test.db" + assert tool.database.startswith("s3://") + assert tool.database.endswith(".db") def test_create_database_from_csv(tmp_path): @@ -225,7 +228,8 @@ def test_create_sql_tool_with_schema_inference(tmp_path, mocker): tool.validate() assert tool.schema == schema assert tool.tables == ["test"] - assert tool.database == "s3://test.db" + assert tool.database.startswith("s3://") + assert tool.database.endswith(".db") def test_create_sql_tool_from_csv_with_warnings(tmp_path, mocker): @@ -283,7 +287,7 @@ def test_create_sql_tool_from_csv(tmp_path): assert isinstance(tool, SQLTool) assert tool.description == "Test" assert tool.database.endswith(".db") - assert os.path.exists(tool.database) + assert tool.database.startswith("s3://") # Test schema and table inference during validation try: From de6a1c4739594d474e61ae9fbbccbe717ec5ddd5 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Fri, 9 May 2025 14:08:54 -0300 Subject: [PATCH 20/22] ENG-2028: model streaming (#506) * Enabling streaming * ModelStreamer * Adding debug logging messages * Streaming service and tests * Parse new JSON content * Remove full_content from model streamer --- aixplain/factories/model_factory/utils.py | 1 + aixplain/modules/model/__init__.py | 27 ++++++++++++++++-- aixplain/modules/model/llm_model.py | 13 ++++++--- .../modules/model/model_response_streamer.py | 28 +++++++++++++++++++ aixplain/modules/model/utils.py | 8 +++++- tests/functional/model/run_model_test.py | 21 ++++++++++++-- tests/unit/model_test.py | 20 +++++++++++-- 7 files changed, 106 insertions(+), 12 deletions(-) create mode 100644 aixplain/modules/model/model_response_streamer.py diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 589bdf68..a5f489be 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -98,6 +98,7 @@ def create_model_from_response(response: Dict) -> Model: version=response["version"]["id"], inputs=inputs, temperature=temperature, + supports_streaming=response.get("supportsStreaming", False), status=status, ) diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 49cf4591..92e48990 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -25,6 +25,7 @@ import traceback from aixplain.enums import Supplier, Function from aixplain.modules.asset import Asset +from aixplain.modules.model.model_response_streamer import ModelResponseStreamer from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config from urllib.parse import urljoin @@ -56,6 +57,7 @@ class Model(Asset): input_params (ModelParameters, optional): input parameters for the function. output_params (Dict, optional): output parameters for the function. model_params (ModelParameters, optional): parameters for the function. + supports_streaming (bool, optional): whether the model supports streaming. Defaults to False. """ def __init__( @@ -73,6 +75,7 @@ def __init__( input_params: Optional[Dict] = None, output_params: Optional[Dict] = None, model_params: Optional[Dict] = None, + supports_streaming: bool = False, status: Optional[AssetStatus] = AssetStatus.ONBOARDED, # default status for models is ONBOARDED **additional_info, ) -> None: @@ -91,6 +94,7 @@ def __init__( input_params (Dict, optional): input parameters for the function. output_params (Dict, optional): output parameters for the function. model_params (Dict, optional): parameters for the function. + supports_streaming (bool, optional): whether the model supports streaming. Defaults to False. status (AssetStatus, optional): status of the model. Defaults to None. **additional_info: Any additional Model info to be saved """ @@ -105,6 +109,7 @@ def __init__( self.input_params = input_params self.output_params = output_params self.model_params = ModelParameters(model_params) if model_params else None + self.supports_streaming = supports_streaming if isinstance(status, str): try: status = AssetStatus(status) @@ -232,6 +237,19 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: completed=False, ) + def run_stream( + self, + data: Union[Text, Dict], + parameters: Optional[Dict] = None, + ) -> ModelResponseStreamer: + assert self.supports_streaming, f"Model '{self.name} ({self.id})' does not support streaming" + payload = build_payload(data=data, parameters=parameters, stream=True) + url = f"{self.url}/{self.id}".replace("api/v1/execute", "api/v2/execute") + logging.debug(f"Model Run Stream: Start service for {url} - {payload}") + headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} + r = _request_with_retry("post", url, headers=headers, data=payload, stream=True) + return ModelResponseStreamer(r.iter_lines(decode_unicode=True)) + def run( self, data: Union[Text, Dict], @@ -239,7 +257,8 @@ def run( timeout: float = 300, parameters: Optional[Dict] = None, wait_time: float = 0.5, - ) -> ModelResponse: + stream: bool = False, + ) -> Union[ModelResponse, ModelResponseStreamer]: """Runs a model call. Args: @@ -248,10 +267,12 @@ def run( timeout (float, optional): total polling time. Defaults to 300. parameters (Dict, optional): optional parameters to the model. Defaults to None. wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. - + stream (bool, optional): whether the model supports streaming. Defaults to False. Returns: - Dict: parsed output from model + Union[ModelResponse, ModelStreamer]: parsed output from model """ + if stream: + return self.run_stream(data=data, parameters=parameters) start = time.time() payload = build_payload(data=data, parameters=parameters) url = f"{self.url}/{self.id}".replace("api/v1/execute", "api/v2/execute") diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index 0b8e6cf8..546be1b8 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -25,6 +25,7 @@ import traceback from aixplain.enums import Function, Supplier from aixplain.modules.model import Model +from aixplain.modules.model.model_response_streamer import ModelResponseStreamer from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config from typing import Union, Optional, List, Text, Dict @@ -108,7 +109,8 @@ def run( timeout: float = 300, parameters: Optional[Dict] = None, wait_time: float = 0.5, - ) -> ModelResponse: + stream: bool = False, + ) -> Union[ModelResponse, ModelResponseStreamer]: """Synchronously running a Large Language Model (LLM) model. Args: @@ -123,9 +125,9 @@ def run( timeout (float, optional): total polling time. Defaults to 300. parameters (Dict, optional): optional parameters to the model. Defaults to None. wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. - + stream (bool, optional): whether the model supports streaming. Defaults to False. Returns: - Dict: parsed output from model + Union[ModelResponse, ModelStreamer]: parsed output from model """ start = time.time() parameters = parameters or {} @@ -141,10 +143,13 @@ def run( parameters.setdefault("max_tokens", max_tokens) parameters.setdefault("top_p", top_p) + if stream: + return self.run_stream(data=data, parameters=parameters) + payload = build_payload(data=data, parameters=parameters) logging.info(payload) url = f"{self.url}/{self.id}".replace("/api/v1/execute", "/api/v2/execute") - logging.debug(f"Model Run Sync: Start service for {name} - {url}") + logging.debug(f"Model Run Sync: Start service for {name} - {url} - {payload}") response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) if response["status"] == "IN_PROGRESS": try: diff --git a/aixplain/modules/model/model_response_streamer.py b/aixplain/modules/model/model_response_streamer.py new file mode 100644 index 00000000..f1fef5f8 --- /dev/null +++ b/aixplain/modules/model/model_response_streamer.py @@ -0,0 +1,28 @@ +import json +from typing import Iterator + +from aixplain.modules.model.response import ModelResponse, ResponseStatus + + +class ModelResponseStreamer: + def __init__(self, iterator: Iterator): + self.iterator = iterator + self.status = ResponseStatus.IN_PROGRESS + + def __next__(self): + """ + Returns the next chunk of the response. + """ + line = next(self.iterator).replace("data: ", "") + try: + data = json.loads(line) + except json.JSONDecodeError: + data = {"data": line} + content = data.get("data", "") + if content == "[DONE]": + self.status = ResponseStatus.SUCCESS + content = "" + return ModelResponse(status=self.status, data=content) + + def __iter__(self): + return self diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 6f3f9319..cc68a347 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -7,12 +7,17 @@ from aixplain.exceptions import get_error_from_status_code -def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None): +def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None, stream: Optional[bool] = None): from aixplain.factories import FileFactory if parameters is None: parameters = {} + if stream is not None: + if "options" not in parameters: + parameters["options"] = {} + parameters["options"]["stream"] = stream + data = FileFactory.to_link(data) if isinstance(data, dict): payload = data @@ -36,6 +41,7 @@ def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict: resp = "unspecified error" try: + logging.debug(f"Calling {url} with payload: {payload}") r = _request_with_retry("post", url, headers=headers, data=payload) resp = r.json() except Exception as e: diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index dd8cb04e..7fbb74d1 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -46,6 +46,25 @@ def test_llm_run(llm_model): assert response["status"] == "SUCCESS" +def test_llm_run_stream(): + """Testing LLMs with streaming""" + from aixplain.modules.model.response import ModelResponse, ResponseStatus + from aixplain.modules.model.model_response_streamer import ModelResponseStreamer + + llm_model = ModelFactory.get("669a63646eb56306647e1091") + + assert isinstance(llm_model, LLM) + response = llm_model.run( + data="This is a test prompt where I expect you to respond with the following phrase: 'This is a test response.'", + stream=True, + ) + assert isinstance(response, ModelResponseStreamer) + for chunk in response: + assert isinstance(chunk, ModelResponse) + assert chunk.data in "This is a test response." + assert response.status == ResponseStatus.SUCCESS + + def test_run_async(): """Testing Model Async""" model = ModelFactory.get("60ddef828d38c51c5885d491") @@ -100,7 +119,6 @@ def run_index_model(index_model): pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), pytest.param(EmbeddingModel.BGE_M3, AirParams, id="AIR - BGE M3"), pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="AIR - aiXplain Legal Embeddings"), - ], ) def test_index_model(embedding_model, supplier_params): @@ -126,7 +144,6 @@ def test_index_model(embedding_model, supplier_params): pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), pytest.param(EmbeddingModel.BGE_M3, AirParams, id="BGE M3"), pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="aiXplain Legal Embeddings"), - ], ) def test_index_model_with_filter(embedding_model, supplier_params): diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 943d501d..57ef17d0 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -25,8 +25,8 @@ from aixplain.factories import ModelFactory from aixplain.enums import Function from urllib.parse import urljoin -from aixplain.enums import ResponseStatus -from aixplain.modules.model.response import ModelResponse +from aixplain.modules.model.response import ModelResponse, ResponseStatus +from aixplain.modules.model.model_response_streamer import ModelResponseStreamer import pytest from unittest.mock import patch from aixplain.enums.asset_status import AssetStatus @@ -638,3 +638,19 @@ def test_empty_model_parameters_string(): """Test string representation of empty ModelParameters.""" params = ModelParameters({}) assert str(params) == "No parameters defined" + + +def test_model_response_streamer(): + """Test ModelResponseStreamer class.""" + streamer = ModelResponseStreamer(iter([])) + assert isinstance(streamer, ModelResponseStreamer) + assert streamer.status == ResponseStatus.IN_PROGRESS + + +def test_model_not_supports_streaming(mocker): + """Test ModelResponseStreamer class.""" + mocker.patch("aixplain.modules.model.utils.build_payload", return_value={"data": "test"}) + model = Model(id="test-id", name="Test Model", supports_streaming=False) + with pytest.raises(Exception) as excinfo: + model.run(data="test", stream=True) + assert f"Model '{model.name} ({model.id})' does not support streaming" in str(excinfo.value) From b909a5418357987d683a58490f2e4cc47b5fec32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20G=C3=BCnd=C3=BCz?= Date: Fri, 9 May 2025 20:32:54 +0300 Subject: [PATCH 21/22] BUG-542-Utility-Model-Update-Test-Failing-in-SDK (#515) * set utility models status to draft if updated * BUG-542: fix utility model tests --- aixplain/modules/model/utility_model.py | 1 - tests/functional/model/run_utility_model_test.py | 10 ++++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 1cbbf3da..f6b69ba2 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -237,7 +237,6 @@ def update(self): DeprecationWarning, stacklevel=2, ) - self.validate() url = urljoin(self.backend_url, f"sdk/utilities/{self.id}") headers = {"x-api-key": f"{self.api_key}", "Content-Type": "application/json"} diff --git a/tests/functional/model/run_utility_model_test.py b/tests/functional/model/run_utility_model_test.py index a2c67ca9..3b6616ae 100644 --- a/tests/functional/model/run_utility_model_test.py +++ b/tests/functional/model/run_utility_model_test.py @@ -272,7 +272,7 @@ def sum_numbers(num1: int, num2: int): # Verify updated state updated_model = ModelFactory.get(utility_model.id) - assert updated_model.status == AssetStatus.DRAFT + assert updated_model.status == AssetStatus.ONBOARDED assert updated_model.name == "sum_numbers_test" assert updated_model.description == "Updated to sum numbers utility" assert len(updated_model.inputs) == 2 @@ -298,17 +298,19 @@ def multiply_numbers(num1: int, num2: int): return num1 * num2 updated_model.code = multiply_numbers - assert updated_model.status == AssetStatus.DRAFT - updated_model.deploy() assert updated_model.status == AssetStatus.ONBOARDED + # the next line should raise an exception + with pytest.raises(Exception, match=".*UtilityModel is already deployed*"): + updated_model.deploy() updated_model.save() + assert updated_model.status == AssetStatus.ONBOARDED # Verify partial update final_model = ModelFactory.get(utility_model.id) assert final_model.name == "sum_numbers_test" assert final_model.description == "Updated to sum numbers utility" - assert final_model.status == AssetStatus.DRAFT + assert final_model.status == AssetStatus.ONBOARDED # Test final behavior with new function but same input field names response = final_model.run({"num1": 5, "num2": 7}) assert response.status == "SUCCESS" From 8aa3d3192ea8dbb312ba76fe6dffb204befe53e0 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira <85182544+thiago-aixplain@users.noreply.github.com> Date: Fri, 9 May 2025 16:25:10 -0300 Subject: [PATCH 22/22] ENG-2115 fixing agent tests (#522) * Fixing agent functional tests * Fixing team agent test * Fixing test_instructions of team agents --- tests/functional/agent/agent_functional_test.py | 11 ++--------- .../team_agent/team_agent_functional_test.py | 4 ++-- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 10569440..c7c294fe 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -137,7 +137,7 @@ def test_custom_code_tool(delete_agents_and_team_agents, AgentFactory): ) assert tool is not None assert tool.description == "Add two strings" - assert tool.code == 'def main(aaa: str, bbb: str) -> str:\n """Add two strings"""\n return aaa + bbb' + assert tool.code.startswith("s3://") agent = AgentFactory.create( name="Add Strings Agent", description="Add two strings. Do not directly answer. Use the tool to add the strings.", @@ -399,6 +399,7 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): if agent: agent.delete() + @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents @@ -510,13 +511,6 @@ def test_instructions(delete_agents_and_team_agents, AgentFactory): assert response["data"]["session_id"] is not None assert response["data"]["output"] is not None assert "aixplain" in response["data"]["output"].lower() - assert "eve" in response["data"]["output"].lower() - - import os - - # Cleanup - os.remove("test.csv") - os.remove("test.db") agent.delete() @@ -606,4 +600,3 @@ def test_agent_with_pipeline_tool(delete_agents_and_team_agents, AgentFactory): assert "hello" in answer["data"]["output"].lower() assert "hello pipeline" in answer["data"]["intermediate_steps"][0]["tool_steps"][0]["tool"].lower() - diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index 759f3920..baa0a34b 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -625,9 +625,9 @@ def test_team_agent_with_instructions(delete_agents_and_team_agents): assert response.status == "SUCCESS" assert "gato" in response.data["output"] - mentalist_steps = eval(response.data["intermediate_steps"][0]["output"]) + mentalist_steps = [json.loads(step) for step in eval(response.data["intermediate_steps"][0]["output"])] - called_agents = set([step["worker"] for step in mentalist_steps]) + called_agents = set([step["agent"] for step in mentalist_steps]) assert len(called_agents) == 1 assert "Agent 2" in called_agents