diff --git a/aixplain/factories/wallet_factory.py b/aixplain/factories/wallet_factory.py index 59ec7c14..b0a55b65 100644 --- a/aixplain/factories/wallet_factory.py +++ b/aixplain/factories/wallet_factory.py @@ -2,6 +2,7 @@ from aixplain.modules.wallet import Wallet from aixplain.utils.file_utils import _request_with_retry import logging +from typing import Text class WalletFactory: @@ -9,16 +10,14 @@ class WalletFactory: backend_url = config.BACKEND_URL @classmethod - def get(cls) -> Wallet: + def get(cls, api_key: Text = config.TEAM_API_KEY) -> Wallet: """Get wallet information""" try: resp = None - # Check for code 200, other code will be caught when trying to return a Wallet object url = f"{cls.backend_url}/sdk/billing/wallet" - - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} logging.info(f"Start fetching billing information from - {url} - {headers}") - headers = {"Content-Type": "application/json", "x-api-key": config.TEAM_API_KEY} + headers = {"Content-Type": "application/json", "x-api-key": api_key} r = _request_with_retry("get", url, headers=headers) resp = r.json() return Wallet(total_balance=resp["totalBalance"], reserved_balance=resp["reservedBalance"]) diff --git a/aixplain/factories/wallet_factoy.py b/aixplain/factories/wallet_factoy.py deleted file mode 100644 index 59ec7c14..00000000 --- a/aixplain/factories/wallet_factoy.py +++ /dev/null @@ -1,26 +0,0 @@ -import aixplain.utils.config as config -from aixplain.modules.wallet import Wallet -from aixplain.utils.file_utils import _request_with_retry -import logging - - -class WalletFactory: - aixplain_key = config.AIXPLAIN_API_KEY - backend_url = config.BACKEND_URL - - @classmethod - def get(cls) -> Wallet: - """Get wallet information""" - try: - resp = None - # Check for code 200, other code will be caught when trying to return a Wallet object - url = f"{cls.backend_url}/sdk/billing/wallet" - - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - logging.info(f"Start fetching billing information from - {url} - {headers}") - headers = {"Content-Type": "application/json", "x-api-key": config.TEAM_API_KEY} - r = _request_with_retry("get", url, headers=headers) - resp = r.json() - return Wallet(total_balance=resp["totalBalance"], reserved_balance=resp["reservedBalance"]) - except Exception as e: - raise Exception(f"Failed to get the wallet credit information. Error: {str(e)}") diff --git a/tests/unit/wallet_test.py b/tests/unit/wallet_test.py index 48ee19ab..16561dba 100644 --- a/tests/unit/wallet_test.py +++ b/tests/unit/wallet_test.py @@ -11,6 +11,6 @@ def test_wallet_service(): headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} ref_response = {"totalBalance": 5, "reservedBalance": "0"} mock.get(url, headers=headers, json=ref_response) - wallet = WalletFactory.get() + wallet = WalletFactory.get(config.AIXPLAIN_API_KEY) assert wallet.total_balance == ref_response["totalBalance"] assert wallet.reserved_balance == ref_response["reservedBalance"]