diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 4de4e582..134b3560 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -34,7 +34,7 @@ from aixplain.utils import config from typing import Dict, List, Optional, Text, Union -from aixplain.factories.agent_factory.utils import build_agent +from aixplain.factories.agent_factory.utils import build_agent, validate_llm from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin @@ -50,8 +50,30 @@ def create( api_key: Text = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, + use_mentalist_and_inspector: bool = False, ) -> Agent: - """Create a new agent in the platform.""" + """Create a new agent in the platform. + + Args: + name (Text): name of the agent + llm_id (Text): aiXplain ID of the large language model to be used as agent. + tools (List[Tool], optional): list of tool for the agent. Defaults to []. + description (Text, optional): description of the agent role. Defaults to "". + api_key (Text, optional): team/user API key. Defaults to config.TEAM_API_KEY. + supplier (Union[Dict, Text, Supplier, int], optional): owner of the agent. Defaults to "aiXplain". + version (Optional[Text], optional): version of the agent. Defaults to None. + use_mentalist_and_inspector (bool, optional): flag to enable mentalist and inspector agents (which only works when a supervisor is enabled). Defaults to False. + + Returns: + Agent: created Agent + """ + # validate LLM ID + validate_llm(llm_id) + + orchestrator_llm_id, mentalist_and_inspector_llm_id = llm_id, None + if use_mentalist_and_inspector is True: + mentalist_and_inspector_llm_id = llm_id + try: agent = None url = urljoin(config.BACKEND_URL, "sdk/agents") @@ -94,9 +116,10 @@ def create( "description": description, "supplier": supplier, "version": version, + "llmId": llm_id, + "supervisorId": orchestrator_llm_id, + "plannerId": mentalist_and_inspector_llm_id, } - if llm_id is not None: - payload["llmId"] = llm_id logging.info(f"Start service for POST Create Agent - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("post", url, headers=headers, json=payload) diff --git a/aixplain/factories/agent_factory/utils.py b/aixplain/factories/agent_factory/utils.py index 4b314ef7..6aed75ae 100644 --- a/aixplain/factories/agent_factory/utils.py +++ b/aixplain/factories/agent_factory/utils.py @@ -47,3 +47,13 @@ def build_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Agent: ) agent.url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") return agent + + +def validate_llm(model_id: Text) -> None: + from aixplain.factories.model_factory import ModelFactory + + try: + llm = ModelFactory.get(model_id) + assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model." + except Exception: + raise Exception(f"Large Language Model with ID '{model_id}' not found.") diff --git a/aixplain/factories/wallet_factory.py b/aixplain/factories/wallet_factory.py index 59ec7c14..b36000f1 100644 --- a/aixplain/factories/wallet_factory.py +++ b/aixplain/factories/wallet_factory.py @@ -2,6 +2,7 @@ from aixplain.modules.wallet import Wallet from aixplain.utils.file_utils import _request_with_retry import logging +from typing import Text class WalletFactory: @@ -9,18 +10,19 @@ class WalletFactory: backend_url = config.BACKEND_URL @classmethod - def get(cls) -> Wallet: + def get(cls, api_key: Text = config.TEAM_API_KEY) -> Wallet: """Get wallet information""" try: resp = None - # Check for code 200, other code will be caught when trying to return a Wallet object url = f"{cls.backend_url}/sdk/billing/wallet" - - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} logging.info(f"Start fetching billing information from - {url} - {headers}") - headers = {"Content-Type": "application/json", "x-api-key": config.TEAM_API_KEY} + headers = {"Content-Type": "application/json", "x-api-key": api_key} r = _request_with_retry("get", url, headers=headers) resp = r.json() - return Wallet(total_balance=resp["totalBalance"], reserved_balance=resp["reservedBalance"]) + total_balance = float(resp.get("totalBalance", 0.0)) + reserved_balance = float(resp.get("reservedBalance", 0.0)) + + return Wallet(total_balance=total_balance, reserved_balance=reserved_balance) except Exception as e: raise Exception(f"Failed to get the wallet credit information. Error: {str(e)}") diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 4be40225..8fcd80d2 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -239,11 +239,27 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param resp = None try: - resp = r.json() - logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - - poll_url = resp["data"] - response = {"status": "IN_PROGRESS", "url": poll_url} + if 200 <= r.status_code < 300: + resp = r.json() + logging.info(f"Result of request for {name} - {r.status_code} - {resp}") + poll_url = resp["data"] + response = {"status": "IN_PROGRESS", "url": poll_url} + else: + if r.status_code == 401: + error = "Unauthorized API key: Please verify the spelling of the API key and its current validity." + elif 460 <= r.status_code < 470: + error = "Subscription-related error: Please ensure that your subscription is active and has not expired." + elif 470 <= r.status_code < 480: + error = "Billing-related error: Please ensure you have enough credits to run this model. " + elif 480 <= r.status_code < 490: + error = "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access." + elif 490 <= r.status_code < 500: + error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." + else: + status_code = str(r.status_code) + error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + response = {"status": "FAILED", "error_message": error} + logging.error(f"Error in request for {name} - {r.status_code}: {error}") except Exception: response = {"status": "FAILED"} msg = f"Error in request for {name} - {traceback.format_exc()}" diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index 5c5c4140..c595d207 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -214,11 +214,27 @@ def run_async( resp = None try: - resp = r.json() - logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - - poll_url = resp["data"] - response = {"status": "IN_PROGRESS", "url": poll_url} + if 200 <= r.status_code < 300: + resp = r.json() + logging.info(f"Result of request for {name} - {r.status_code} - {resp}") + poll_url = resp["data"] + response = {"status": "IN_PROGRESS", "url": poll_url} + else: + if r.status_code == 401: + error = "Unauthorized API key: Please verify the spelling of the API key and its current validity." + elif 460 <= r.status_code < 470: + error = "Subscription-related error: Please ensure that your subscription is active and has not expired." + elif 470 <= r.status_code < 480: + error = "Billing-related error: Please ensure you have enough credits to run this model. " + elif 480 <= r.status_code < 490: + error = "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access." + elif 490 <= r.status_code < 500: + error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." + else: + status_code = str(r.status_code) + error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + response = {"status": "FAILED", "error_message": error} + logging.error(f"Error in request for {name} - {r.status_code}: {error}") except Exception: response = {"status": "FAILED"} msg = f"Error in request for {name} - {traceback.format_exc()}" diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index ad7cfa1b..860a08a5 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -323,11 +323,27 @@ def run_async( resp = None try: - resp = r.json() - logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - - poll_url = resp["url"] - response = {"status": "IN_PROGRESS", "url": poll_url} + if 200 <= r.status_code < 300: + resp = r.json() + logging.info(f"Result of request for {name} - {r.status_code} - {resp}") + poll_url = resp["url"] + response = {"status": "IN_PROGRESS", "url": poll_url} + else: + if r.status_code == 401: + error = "Unauthorized API key: Please verify the spelling of the API key and its current validity." + elif 460 <= r.status_code < 470: + error = "Subscription-related error: Please ensure that your subscription is active and has not expired." + elif 470 <= r.status_code < 480: + error = "Billing-related error: Please ensure you have enough credits to run this pipeline. " + elif 480 <= r.status_code < 490: + error = "Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access." + elif 490 <= r.status_code < 500: + error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." + else: + status_code = str(r.status_code) + error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." + response = {"status": "FAILED", "error_message": error} + logging.error(f"Error in request for {name} - {r.status_code}: {error}") except Exception: response = {"status": "FAILED"} if resp is not None: diff --git a/aixplain/modules/wallet.py b/aixplain/modules/wallet.py index d7c63524..2b2b1cd4 100644 --- a/aixplain/modules/wallet.py +++ b/aixplain/modules/wallet.py @@ -24,11 +24,12 @@ class Wallet: def __init__(self, total_balance: float, reserved_balance: float): - """Create a Wallet with the necessary information - + """ Args: total_balance (float): total credit balance reserved_balance (float): reserved credit balance + available_balance (float): available balance (total - credit) """ self.total_balance = total_balance self.reserved_balance = reserved_balance + self.available_balance = total_balance-reserved_balance \ No newline at end of file diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index f6ff0408..0acdb5be 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -75,3 +75,9 @@ def test_list_agents(): assert "results" in agents agents_result = agents["results"] assert type(agents_result) is list + + +def test_fail_non_existent_llm(): + with pytest.raises(Exception) as exc_info: + AgentFactory.create(name="Test Agent", llm_id="non_existent_llm", tools=[]) + assert str(exc_info.value) == "Large Language Model with ID 'non_existent_llm' not found." diff --git a/tests/unit/agent_test.py b/tests/unit/agent_test.py index 18c92fa3..8a619011 100644 --- a/tests/unit/agent_test.py +++ b/tests/unit/agent_test.py @@ -4,6 +4,7 @@ from aixplain.utils import config from aixplain.factories import AgentFactory from aixplain.modules.agent import PipelineTool, ModelTool +from urllib.parse import urljoin def test_fail_no_data_query(): @@ -77,3 +78,57 @@ def test_invalid_modeltool(): with pytest.raises(Exception) as exc_info: AgentFactory.create(name="Test", tools=[ModelTool(model="309851793")], llm_id="6646261c6eb563165658bbb1") assert str(exc_info.value) == "Model Tool Unavailable. Make sure Model '309851793' exists or you have access to it." + + +def test_create_agent(): + from aixplain.enums import Supplier + + with requests_mock.Mocker() as mock: + url = urljoin(config.BACKEND_URL, "sdk/agents") + headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} + + ref_response = { + "id": "123", + "name": "Test Agent", + "description": "Test Agent Description", + "teamId": "123", + "version": "1.0", + "status": "onboarded", + "llmId": "6646261c6eb563165658bbb1", + "pricing": {"currency": "USD", "value": 0.0}, + "assets": [ + { + "type": "model", + "supplier": "openai", + "version": "1.0", + "assetId": "6646261c6eb563165658bbb1", + "function": "text-generation", + } + ], + } + mock.post(url, headers=headers, json=ref_response) + + url = urljoin(config.BACKEND_URL, "sdk/models/6646261c6eb563165658bbb1") + model_ref_response = { + "id": "6646261c6eb563165658bbb1", + "name": "Test LLM", + "description": "Test LLM Description", + "function": {"id": "text-generation"}, + "supplier": "openai", + "version": {"id": "1.0"}, + "status": "onboarded", + "pricing": {"currency": "USD", "value": 0.0}, + } + mock.get(url, headers=headers, json=model_ref_response) + + agent = AgentFactory.create( + name="Test Agent", + description="Test Agent Description", + llm_id="6646261c6eb563165658bbb1", + tools=[AgentFactory.create_model_tool(supplier=Supplier.OPENAI, function="text-generation")], + ) + + assert agent.name == ref_response["name"] + assert agent.description == ref_response["description"] + assert agent.llm_id == ref_response["llmId"] + assert agent.tools[0].function.value == ref_response["assets"][0]["function"] diff --git a/tests/unit/llm_test.py b/tests/unit/llm_test.py new file mode 100644 index 00000000..430fc338 --- /dev/null +++ b/tests/unit/llm_test.py @@ -0,0 +1,36 @@ + +from dotenv import load_dotenv +from urllib.parse import urljoin +import requests_mock +from aixplain.enums import Function + +load_dotenv() +from aixplain.utils import config +from aixplain.modules import LLM + +import pytest + +@pytest.mark.parametrize( + "status_code,error_message", + [ + (401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."), + (465,"Subscription-related error: Please ensure that your subscription is active and has not expired."), + (475,"Billing-related error: Please ensure you have enough credits to run this model. "), + (485, "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."), + (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), + (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), + + ], +) + +def test_run_async_errors(status_code, error_message): + base_url = config.MODELS_RUN_URL + llm_id = "llm-id" + execute_url = urljoin(base_url, f"execute/{llm_id}") + + with requests_mock.Mocker() as mock: + mock.post(execute_url, status_code=status_code) + test_llm = LLM(id=llm_id, name="Test llm",url=base_url, function=Function.TEXT_GENERATION) + response = test_llm.run_async(data="input_data") + assert response["status"] == "FAILED" + assert response["error_message"] == error_message \ No newline at end of file diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 269c821e..cd6f7a5a 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -17,10 +17,11 @@ """ from dotenv import load_dotenv +from urllib.parse import urljoin +import requests_mock load_dotenv() import re -import requests_mock from aixplain.utils import config from aixplain.modules import Model @@ -57,3 +58,29 @@ def test_failed_poll(): assert hyp_response["error"] == ref_response["error"] assert hyp_response["supplierError"] == ref_response["supplierError"] assert hyp_response["status"] == "FAILED" + + +@pytest.mark.parametrize( + "status_code,error_message", + [ + (401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."), + (465,"Subscription-related error: Please ensure that your subscription is active and has not expired."), + (475,"Billing-related error: Please ensure you have enough credits to run this model. "), + (485, "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."), + (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), + (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), + + ], +) + +def test_run_async_errors(status_code, error_message): + base_url = config.MODELS_RUN_URL + model_id = "model-id" + execute_url = urljoin(base_url, f"execute/{model_id}") + + with requests_mock.Mocker() as mock: + mock.post(execute_url, status_code=status_code) + test_model = Model(id=model_id, name="Test Model",url=base_url) + response = test_model.run_async(data="input_data") + assert response["status"] == "FAILED" + assert response["error_message"] == error_message \ No newline at end of file diff --git a/tests/unit/pipeline_test.py b/tests/unit/pipeline_test.py index e983a298..d3c1c725 100644 --- a/tests/unit/pipeline_test.py +++ b/tests/unit/pipeline_test.py @@ -17,6 +17,7 @@ """ from dotenv import load_dotenv +import pytest load_dotenv() import requests_mock @@ -36,3 +37,28 @@ def test_create_pipeline(): hyp_pipeline = PipelineFactory.create(pipeline={"nodes": []}, name="Pipeline Test") assert hyp_pipeline.id == ref_pipeline.id assert hyp_pipeline.name == ref_pipeline.name + +@pytest.mark.parametrize( + "status_code,error_message", + [ + (401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."), + (465,"Subscription-related error: Please ensure that your subscription is active and has not expired."), + (475,"Billing-related error: Please ensure you have enough credits to run this pipeline. "), + (485, "Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access."), + (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), + (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), + + ], +) + +def test_run_async_errors(status_code, error_message): + base_url = config.BACKEND_URL + pipeline_id = "pipeline_id" + execute_url = f"{base_url}/assets/pipeline/execution/run/{pipeline_id}" + + with requests_mock.Mocker() as mock: + mock.post(execute_url, status_code=status_code) + test_pipeline = Pipeline(id=pipeline_id, api_key=config.TEAM_API_KEY, name="Test Pipeline", url=base_url) + response = test_pipeline.run_async(data="input_data") + assert response["status"] == "FAILED" + assert response["error_message"] == error_message \ No newline at end of file diff --git a/tests/unit/wallet_test.py b/tests/unit/wallet_test.py index 48ee19ab..50acbbdb 100644 --- a/tests/unit/wallet_test.py +++ b/tests/unit/wallet_test.py @@ -12,5 +12,6 @@ def test_wallet_service(): ref_response = {"totalBalance": 5, "reservedBalance": "0"} mock.get(url, headers=headers, json=ref_response) wallet = WalletFactory.get() - assert wallet.total_balance == ref_response["totalBalance"] - assert wallet.reserved_balance == ref_response["reservedBalance"] + assert wallet.total_balance == float(ref_response["totalBalance"]) + assert wallet.reserved_balance == float(ref_response["reservedBalance"]) + assert wallet.available_balance == float(ref_response["totalBalance"]) - float(ref_response["reservedBalance"])