diff --git a/.gitignore b/.gitignore index 843c6556..ad7c16c8 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,5 @@ dmypy.json # Vscode .vscode +.DS_Store + diff --git a/aixplain/factories/__init__.py b/aixplain/factories/__init__.py index 7b876899..70361e77 100644 --- a/aixplain/factories/__init__.py +++ b/aixplain/factories/__init__.py @@ -30,3 +30,4 @@ from .model_factory import ModelFactory from .pipeline_factory import PipelineFactory from .finetune_factory import FinetuneFactory +from .wallet_factory import WalletFactory diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 6076eef6..4de4e582 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -65,6 +65,7 @@ def create( tool_payload = [] for tool in tools: if isinstance(tool, ModelTool): + tool.validate() tool_payload.append( { "function": tool.function.value if tool.function is not None else None, @@ -76,6 +77,7 @@ def create( } ) elif isinstance(tool, PipelineTool): + tool.validate() tool_payload.append( { "assetId": tool.pipeline, diff --git a/aixplain/factories/wallet_factory.py b/aixplain/factories/wallet_factory.py new file mode 100644 index 00000000..59ec7c14 --- /dev/null +++ b/aixplain/factories/wallet_factory.py @@ -0,0 +1,26 @@ +import aixplain.utils.config as config +from aixplain.modules.wallet import Wallet +from aixplain.utils.file_utils import _request_with_retry +import logging + + +class WalletFactory: + aixplain_key = config.AIXPLAIN_API_KEY + backend_url = config.BACKEND_URL + + @classmethod + def get(cls) -> Wallet: + """Get wallet information""" + try: + resp = None + # Check for code 200, other code will be caught when trying to return a Wallet object + url = f"{cls.backend_url}/sdk/billing/wallet" + + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + logging.info(f"Start fetching billing information from - {url} - {headers}") + headers = {"Content-Type": "application/json", "x-api-key": config.TEAM_API_KEY} + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + return Wallet(total_balance=resp["totalBalance"], reserved_balance=resp["reservedBalance"]) + except Exception as e: + raise Exception(f"Failed to get the wallet credit information. Error: {str(e)}") diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index e15a8bea..c88f1ee0 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -66,9 +66,8 @@ def __init__( if model is not None: if isinstance(model, Text) is True: - from aixplain.factories.model_factory import ModelFactory - - model = ModelFactory.get(model) + self.model = model + model = self.validate() function = model.function if isinstance(model.supplier, Supplier): supplier = model.supplier @@ -76,3 +75,14 @@ def __init__( self.supplier = supplier self.model = model self.function = function + + def validate(self) -> Model: + from aixplain.factories.model_factory import ModelFactory + + try: + model = None + if self.model is not None: + model = ModelFactory.get(self.model) + return model + except Exception: + raise Exception(f"Model Tool Unavailable. Make sure Model '{self.model}' exists or you have access to it.") diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index a517b198..5ad2915a 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -50,3 +50,11 @@ def __init__( if isinstance(pipeline, Pipeline): pipeline = pipeline.id self.pipeline = pipeline + + def validate(self): + from aixplain.factories.pipeline_factory import PipelineFactory + + try: + PipelineFactory.get(self.pipeline) + except Exception: + raise Exception(f"Pipeline Tool Unavailable. Make sure Pipeline '{self.pipeline}' exists or you have access to it.") diff --git a/aixplain/modules/wallet.py b/aixplain/modules/wallet.py new file mode 100644 index 00000000..d7c63524 --- /dev/null +++ b/aixplain/modules/wallet.py @@ -0,0 +1,34 @@ +__author__ = "aixplain" + +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: aiXplain Team +Date: August 20th 2024 +Description: + Wallet Class +""" + + +class Wallet: + def __init__(self, total_balance: float, reserved_balance: float): + """Create a Wallet with the necessary information + + Args: + total_balance (float): total credit balance + reserved_balance (float): reserved credit balance + """ + self.total_balance = total_balance + self.reserved_balance = reserved_balance diff --git a/tests/unit/agent_test.py b/tests/unit/agent_test.py index 680fc21a..18c92fa3 100644 --- a/tests/unit/agent_test.py +++ b/tests/unit/agent_test.py @@ -2,6 +2,8 @@ import requests_mock from aixplain.modules import Agent from aixplain.utils import config +from aixplain.factories import AgentFactory +from aixplain.modules.agent import PipelineTool, ModelTool def test_fail_no_data_query(): @@ -61,3 +63,17 @@ def test_sucess_query_content(): response = agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input1": "Hello, how are you?"}) assert response["status"] == ref_response["status"] assert response["url"] == ref_response["data"] + + +def test_invalid_pipelinetool(): + with pytest.raises(Exception) as exc_info: + AgentFactory.create( + name="Test", tools=[PipelineTool(pipeline="309851793", description="Test")], llm_id="6646261c6eb563165658bbb1" + ) + assert str(exc_info.value) == "Pipeline Tool Unavailable. Make sure Pipeline '309851793' exists or you have access to it." + + +def test_invalid_modeltool(): + with pytest.raises(Exception) as exc_info: + AgentFactory.create(name="Test", tools=[ModelTool(model="309851793")], llm_id="6646261c6eb563165658bbb1") + assert str(exc_info.value) == "Model Tool Unavailable. Make sure Model '309851793' exists or you have access to it." diff --git a/tests/unit/wallet_test.py b/tests/unit/wallet_test.py new file mode 100644 index 00000000..48ee19ab --- /dev/null +++ b/tests/unit/wallet_test.py @@ -0,0 +1,16 @@ +__author__ = "aixplain" + +from aixplain.factories import WalletFactory +import aixplain.utils.config as config +import requests_mock + + +def test_wallet_service(): + with requests_mock.Mocker() as mock: + url = f"{config.BACKEND_URL}/sdk/billing/wallet" + headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} + ref_response = {"totalBalance": 5, "reservedBalance": "0"} + mock.get(url, headers=headers, json=ref_response) + wallet = WalletFactory.get() + assert wallet.total_balance == ref_response["totalBalance"] + assert wallet.reserved_balance == ref_response["reservedBalance"]