diff --git a/aixplain/factories/wallet_factory.py b/aixplain/factories/wallet_factory.py index b0a55b65..b36000f1 100644 --- a/aixplain/factories/wallet_factory.py +++ b/aixplain/factories/wallet_factory.py @@ -20,6 +20,9 @@ def get(cls, api_key: Text = config.TEAM_API_KEY) -> Wallet: headers = {"Content-Type": "application/json", "x-api-key": api_key} r = _request_with_retry("get", url, headers=headers) resp = r.json() - return Wallet(total_balance=resp["totalBalance"], reserved_balance=resp["reservedBalance"]) + total_balance = float(resp.get("totalBalance", 0.0)) + reserved_balance = float(resp.get("reservedBalance", 0.0)) + + return Wallet(total_balance=total_balance, reserved_balance=reserved_balance) except Exception as e: raise Exception(f"Failed to get the wallet credit information. Error: {str(e)}") diff --git a/aixplain/modules/wallet.py b/aixplain/modules/wallet.py index d7c63524..d61b04ee 100644 --- a/aixplain/modules/wallet.py +++ b/aixplain/modules/wallet.py @@ -24,11 +24,12 @@ class Wallet: def __init__(self, total_balance: float, reserved_balance: float): - """Create a Wallet with the necessary information - + """ Args: - total_balance (float): total credit balance - reserved_balance (float): reserved credit balance + total_balance (float) + reserved_balance (float) + available_balance (float) """ self.total_balance = total_balance self.reserved_balance = reserved_balance + self.available_balance = total_balance-reserved_balance diff --git a/tests/unit/wallet_test.py b/tests/unit/wallet_test.py index 16561dba..50acbbdb 100644 --- a/tests/unit/wallet_test.py +++ b/tests/unit/wallet_test.py @@ -11,6 +11,7 @@ def test_wallet_service(): headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} ref_response = {"totalBalance": 5, "reservedBalance": "0"} mock.get(url, headers=headers, json=ref_response) - wallet = WalletFactory.get(config.AIXPLAIN_API_KEY) - assert wallet.total_balance == ref_response["totalBalance"] - assert wallet.reserved_balance == ref_response["reservedBalance"] + wallet = WalletFactory.get() + assert wallet.total_balance == float(ref_response["totalBalance"]) + assert wallet.reserved_balance == float(ref_response["reservedBalance"]) + assert wallet.available_balance == float(ref_response["totalBalance"]) - float(ref_response["reservedBalance"])