diff --git a/pyproject.toml b/pyproject.toml index da518dcfe..5536f8065 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,8 @@ dependencies = [ "aioredis==1.3.1", "aiosqlite==0.19", "alembic==1.13.1", - "aleph-message~=1.0.1", + # "aleph-message~=1.0.1", + "aleph-message @ git+https://github.com/aleph-im/aleph-message@andres-feature-implement_credits_payment", "aleph-superfluid~=0.2.1", "dbus-python==1.3.2", "eth-account~=0.10", diff --git a/src/aleph/vm/orchestrator/payment.py b/src/aleph/vm/orchestrator/payment.py index f5a79bbca..8b863451f 100644 --- a/src/aleph/vm/orchestrator/payment.py +++ b/src/aleph/vm/orchestrator/payment.py @@ -2,6 +2,7 @@ import logging from collections.abc import Iterable from decimal import Decimal +from typing import List import aiohttp from aleph_message.models import ItemHash, PaymentType @@ -44,30 +45,36 @@ async def fetch_balance_of_address(address: str) -> Decimal: return resp_data["balance"] -async def fetch_execution_flow_price(item_hash: ItemHash) -> Decimal: - """Fetch the flow price of an execution from the reference API server.""" +async def fetch_credit_balance_of_address(address: str) -> Decimal: + """ + Get the balance of the user from the PyAleph API. + + API Endpoint: + GET /api/v0/addresses/{address}/balance + + For more details, see the PyAleph API documentation: + https://github.com/aleph-im/pyaleph/blob/master/src/aleph/web/controllers/routes.py#L62 + """ + async with aiohttp.ClientSession() as session: - url = f"{settings.API_SERVER}/api/v0/price/{item_hash}" + url = f"{settings.API_SERVER}/api/v0/addresses/{address}/credit_balance" resp = await session.get(url) + + # Consider the balance as null if the address is not found + if resp.status == 404: + return Decimal(0) + # Raise an error if the request failed resp.raise_for_status() resp_data = await resp.json() - required_flow: float = resp_data["required_tokens"] - payment_type: str | None = resp_data["payment_type"] - - if payment_type is None: - msg = "Payment type must be specified in the message" - raise ValueError(msg) - elif payment_type != PaymentType.superfluid: - msg = f"Payment type {payment_type} is not supported" - raise ValueError(msg) - - return Decimal(required_flow) + return resp_data["credits"] -async def fetch_execution_hold_price(item_hash: ItemHash) -> Decimal: - """Fetch the hold price of an execution from the reference API server.""" +async def fetch_execution_price( + item_hash: ItemHash, allowed_payments: List[PaymentType], payment_type_required: bool = True +) -> Decimal: + """Fetch the credit price of an execution from the reference API server.""" async with aiohttp.ClientSession() as session: url = f"{settings.API_SERVER}/api/v0/price/{item_hash}" resp = await session.get(url) @@ -75,14 +82,19 @@ async def fetch_execution_hold_price(item_hash: ItemHash) -> Decimal: resp.raise_for_status() resp_data = await resp.json() - required_hold: float = resp_data["required_tokens"] + required_credits: float = resp_data["required_credits"] # Field not defined yet on API side. payment_type: str | None = resp_data["payment_type"] - if payment_type not in (None, PaymentType.hold): - msg = f"Payment type {payment_type} is not supported" + if payment_type_required and payment_type is None: + msg = "Payment type must be specified in the message" raise ValueError(msg) - return Decimal(required_hold) + if payment_type: + if payment_type not in allowed_payments: + msg = f"Payment type {payment_type} is not supported" + raise ValueError(msg) + + return Decimal(required_credits) class InvalidAddressError(ValueError): @@ -133,11 +145,26 @@ async def get_stream(sender: str, receiver: str, chain: str) -> Decimal: async def compute_required_balance(executions: Iterable[VmExecution]) -> Decimal: """Get the balance required for the resources of the user from the messages and the pricing aggregate.""" - costs = await asyncio.gather(*(fetch_execution_hold_price(execution.vm_hash) for execution in executions)) + costs = await asyncio.gather( + *( + fetch_execution_price(execution.vm_hash, [PaymentType.hold], payment_type_required=False) + for execution in executions + ) + ) + return sum(costs, Decimal(0)) + + +async def compute_required_credit_balance(executions: Iterable[VmExecution]) -> Decimal: + """Get the balance required for the resources of the user from the messages and the pricing aggregate.""" + costs = await asyncio.gather( + *(fetch_execution_price(execution.vm_hash, [PaymentType.credit]) for execution in executions) + ) return sum(costs, Decimal(0)) async def compute_required_flow(executions: Iterable[VmExecution]) -> Decimal: """Compute the flow required for a collection of executions, typically all executions from a specific address""" - flows = await asyncio.gather(*(fetch_execution_flow_price(execution.vm_hash) for execution in executions)) + flows = await asyncio.gather( + *(fetch_execution_price(execution.vm_hash, [PaymentType.superfluid]) for execution in executions) + ) return sum(flows, Decimal(0)) diff --git a/src/aleph/vm/orchestrator/tasks.py b/src/aleph/vm/orchestrator/tasks.py index 803d3ca32..93250b239 100644 --- a/src/aleph/vm/orchestrator/tasks.py +++ b/src/aleph/vm/orchestrator/tasks.py @@ -31,8 +31,10 @@ from .messages import get_message_status from .payment import ( compute_required_balance, + compute_required_credit_balance, compute_required_flow, fetch_balance_of_address, + fetch_credit_balance_of_address, get_stream, ) from .pubsub import PubSub @@ -187,44 +189,71 @@ async def check_payment(pool: VmPool): pool.forget_vm(vm_hash) # Check if the balance held in the wallet is sufficient holder tier resources (Not do it yet) - for sender, chains in pool.get_executions_by_sender(payment_type=PaymentType.hold).items(): + for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.hold).items(): for chain, executions in chains.items(): executions = [execution for execution in executions if execution.is_confidential] if not executions: continue - balance = await fetch_balance_of_address(sender) + balance = await fetch_balance_of_address(execution_address) # Stop executions until the required balance is reached required_balance = await compute_required_balance(executions) - logger.debug(f"Required balance for Sender {sender} executions: {required_balance}, {executions}") + logger.debug( + f"Required balance for Sender {execution_address} executions: {required_balance}, {executions}" + ) # Stop executions until the required balance is reached while executions and balance < (required_balance + settings.PAYMENT_BUFFER): last_execution = executions.pop(-1) logger.debug(f"Stopping {last_execution} due to insufficient balance") await pool.stop_vm(last_execution.vm_hash) required_balance = await compute_required_balance(executions) + community_wallet = await get_community_wallet_address() if not community_wallet: logger.error("Monitor payment ERROR: No community wallet set. Cannot check community payment") + # Check if the credit balance held in the wallet is sufficient credit tier resources (Not do it yet) + for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.credit).items(): + for chain, executions in chains.items(): + executions = [execution for execution in executions] + if not executions: + continue + balance = await fetch_credit_balance_of_address(execution_address) + + # Stop executions until the required credits are reached + required_credits = await compute_required_credit_balance(executions) + logger.debug( + f"Required credit balance for Address {execution_address} executions: {required_credits}, {executions}" + ) + # Stop executions until the required credits are reached + while executions and balance < (required_credits + settings.PAYMENT_BUFFER): + last_execution = executions.pop(-1) + logger.debug(f"Stopping {last_execution} due to insufficient credit balance") + await pool.stop_vm(last_execution.vm_hash) + required_credits = await compute_required_credit_balance(executions) + # Check if the balance held in the wallet is sufficient stream tier resources - for sender, chains in pool.get_executions_by_sender(payment_type=PaymentType.superfluid).items(): + for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.superfluid).items(): for chain, executions in chains.items(): try: - stream = await get_stream(sender=sender, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain) + stream = await get_stream( + sender=execution_address, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain + ) logger.debug( - f"Stream flow from {sender} to {settings.PAYMENT_RECEIVER_ADDRESS} = {stream} {chain.value}" + f"Stream flow from {execution_address} to {settings.PAYMENT_RECEIVER_ADDRESS} = {stream} {chain.value}" ) except ValueError as error: - logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}") + logger.error(f"Error found getting stream for chain {chain} and sender {execution_address}: {error}") continue try: - community_stream = await get_stream(sender=sender, receiver=community_wallet, chain=chain) - logger.debug(f"Stream flow from {sender} to {community_wallet} (community) : {stream} {chain}") + community_stream = await get_stream(sender=execution_address, receiver=community_wallet, chain=chain) + logger.debug( + f"Stream flow from {execution_address} to {community_wallet} (community) : {stream} {chain}" + ) except ValueError as error: - logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}") + logger.error(f"Error found getting stream for chain {chain} and sender {execution_address}: {error}") continue while executions: @@ -249,7 +278,7 @@ async def check_payment(pool: VmPool): ) required_community_stream = format_cost(required_stream * COMMUNITY_STREAM_RATIO) logger.debug( - f"Stream for senders {sender} {len(executions)} executions. CRN : {stream} / {required_crn_stream}." + f"Stream for senders {execution_address} {len(executions)} executions. CRN : {stream} / {required_crn_stream}." f"Community: {community_stream} / {required_community_stream}" ) # Can pay all executions @@ -259,7 +288,7 @@ async def check_payment(pool: VmPool): break # Stop executions until the required stream is reached last_execution = executions.pop(-1) - logger.info(f"Stopping {last_execution} of {sender} due to insufficient stream") + logger.info(f"Stopping {last_execution} of {execution_address} due to insufficient stream") await pool.stop_vm(last_execution.vm_hash) diff --git a/src/aleph/vm/orchestrator/views/__init__.py b/src/aleph/vm/orchestrator/views/__init__.py index 9d81a2df5..a964b72ac 100644 --- a/src/aleph/vm/orchestrator/views/__init__.py +++ b/src/aleph/vm/orchestrator/views/__init__.py @@ -35,7 +35,7 @@ from aleph.vm.orchestrator.payment import ( InvalidAddressError, InvalidChainError, - fetch_execution_flow_price, + fetch_execution_price, get_stream, ) from aleph.vm.orchestrator.pubsub import PubSub @@ -577,7 +577,7 @@ async def notify_allocation(request: web.Request): if have_gpu: logger.debug(f"GPU Instance {item_hash} not using PAYG") user_balance = await payment.fetch_balance_of_address(message.sender) - hold_price = await payment.fetch_execution_hold_price(item_hash) + hold_price = await payment.fetch_execution_price(item_hash, [PaymentType.hold], False) logger.debug(f"Address {message.sender} Balance: {user_balance}, Price: {hold_price}") if hold_price > user_balance: return web.HTTPPaymentRequired( @@ -606,7 +606,7 @@ async def notify_allocation(request: web.Request): if not active_flow: raise web.HTTPPaymentRequired(reason="Empty payment stream for this instance") - required_flow: Decimal = await fetch_execution_flow_price(item_hash) + required_flow: Decimal = await fetch_execution_price(item_hash, [PaymentType.superfluid]) community_wallet = await get_community_wallet_address() required_crn_stream: Decimal required_community_stream: Decimal diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index d8aec10f6..0383e6703 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -379,9 +379,9 @@ def get_available_gpus(self) -> list[GpuDevice]: available_gpus.append(gpu) return available_gpus - def get_executions_by_sender(self, payment_type: PaymentType) -> dict[str, dict[str, list[VmExecution]]]: + def get_executions_by_address(self, payment_type: PaymentType) -> dict[str, dict[str, list[VmExecution]]]: """Return all executions of the given type, grouped by sender and by chain.""" - executions_by_sender: dict[str, dict[str, list[VmExecution]]] = {} + executions_by_address: dict[str, dict[str, list[VmExecution]]] = {} for vm_hash, execution in self.executions.items(): if execution.vm_hash in (settings.CHECK_FASTAPI_VM_ID, settings.LEGACY_CHECK_FASTAPI_VM_ID): # Ignore Diagnostic VM execution @@ -399,11 +399,11 @@ def get_executions_by_sender(self, payment_type: PaymentType) -> dict[str, dict[ else Payment(chain=Chain.ETH, type=PaymentType.hold) ) if execution_payment.type == payment_type: - sender = execution.message.address + address = execution.message.address chain = execution_payment.chain - executions_by_sender.setdefault(sender, {}) - executions_by_sender[sender].setdefault(chain, []).append(execution) - return executions_by_sender + executions_by_address.setdefault(address, {}) + executions_by_address[address].setdefault(chain, []).append(execution) + return executions_by_address def get_valid_reservation(self, resource) -> Reservation | None: if resource in self.reservations and self.reservations[resource].is_expired(): diff --git a/tests/supervisor/test_checkpayment.py b/tests/supervisor/test_checkpayment.py index b5d75dbf1..0c4ab6ac3 100644 --- a/tests/supervisor/test_checkpayment.py +++ b/tests/supervisor/test_checkpayment.py @@ -82,7 +82,7 @@ async def compute_required_flow(executions): pool.executions = {hash: execution} - executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid) + executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid) assert len(executions_by_sender) == 1 assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}} @@ -136,7 +136,7 @@ async def compute_required_flow(executions): pool.executions = {hash: execution} - executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid) + executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid) assert len(executions_by_sender) == 1 assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}} @@ -173,7 +173,7 @@ async def test_not_enough_flow(mocker, fake_instance_content): pool.executions = {hash: execution} - executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid) + executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid) assert len(executions_by_sender) == 1 assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}} @@ -217,7 +217,7 @@ async def get_stream(sender, receiver, chain): pool.executions = {hash: execution} - executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid) + executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid) assert len(executions_by_sender) == 1 assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}}