From 69e03d6d80e2f21305c6885c7ba44bd45e9311ea Mon Sep 17 00:00:00 2001 From: "Andres D. Molins" Date: Fri, 8 Aug 2025 10:53:37 +0200 Subject: [PATCH 1/2] Feature: Implement credits payment method monitoring, same as PAYG. --- pyproject.toml | 3 +- src/aleph/vm/orchestrator/payment.py | 51 ++++++++++++++++++++++++++ src/aleph/vm/orchestrator/tasks.py | 53 +++++++++++++++++++++------ src/aleph/vm/pool.py | 12 +++--- tests/supervisor/test_checkpayment.py | 8 ++-- 5 files changed, 104 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index da518dcfe..5536f8065 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,8 @@ dependencies = [ "aioredis==1.3.1", "aiosqlite==0.19", "alembic==1.13.1", - "aleph-message~=1.0.1", + # "aleph-message~=1.0.1", + "aleph-message @ git+https://github.com/aleph-im/aleph-message@andres-feature-implement_credits_payment", "aleph-superfluid~=0.2.1", "dbus-python==1.3.2", "eth-account~=0.10", diff --git a/src/aleph/vm/orchestrator/payment.py b/src/aleph/vm/orchestrator/payment.py index f5a79bbca..a8bb4aed1 100644 --- a/src/aleph/vm/orchestrator/payment.py +++ b/src/aleph/vm/orchestrator/payment.py @@ -44,6 +44,32 @@ async def fetch_balance_of_address(address: str) -> Decimal: return resp_data["balance"] +async def fetch_credit_balance_of_address(address: str) -> Decimal: + """ + Get the balance of the user from the PyAleph API. + + API Endpoint: + GET /api/v0/addresses/{address}/balance + + For more details, see the PyAleph API documentation: + https://github.com/aleph-im/pyaleph/blob/master/src/aleph/web/controllers/routes.py#L62 + """ + + async with aiohttp.ClientSession() as session: + url = f"{settings.API_SERVER}/api/v0/addresses/{address}/credit_balance" + resp = await session.get(url) + + # Consider the balance as null if the address is not found + if resp.status == 404: + return Decimal(0) + + # Raise an error if the request failed + resp.raise_for_status() + + resp_data = await resp.json() + return resp_data["credits"] + + async def fetch_execution_flow_price(item_hash: ItemHash) -> Decimal: """Fetch the flow price of an execution from the reference API server.""" async with aiohttp.ClientSession() as session: @@ -85,6 +111,25 @@ async def fetch_execution_hold_price(item_hash: ItemHash) -> Decimal: return Decimal(required_hold) +async def fetch_execution_credit_price(item_hash: ItemHash) -> Decimal: + """Fetch the credit price of an execution from the reference API server.""" + async with aiohttp.ClientSession() as session: + url = f"{settings.API_SERVER}/api/v0/price/{item_hash}" + resp = await session.get(url) + # Raise an error if the request failed + resp.raise_for_status() + + resp_data = await resp.json() + required_credits: float = resp_data["required_credits"] # Field not defined yet on API side. + payment_type: str | None = resp_data["payment_type"] + + if payment_type not in (None, PaymentType.credit): + msg = f"Payment type {payment_type} is not supported" + raise ValueError(msg) + + return Decimal(required_credits) + + class InvalidAddressError(ValueError): """The blockchain address could not be parsed.""" @@ -137,6 +182,12 @@ async def compute_required_balance(executions: Iterable[VmExecution]) -> Decimal return sum(costs, Decimal(0)) +async def compute_required_credit_balance(executions: Iterable[VmExecution]) -> Decimal: + """Get the balance required for the resources of the user from the messages and the pricing aggregate.""" + costs = await asyncio.gather(*(fetch_execution_credit_price(execution.vm_hash) for execution in executions)) + return sum(costs, Decimal(0)) + + async def compute_required_flow(executions: Iterable[VmExecution]) -> Decimal: """Compute the flow required for a collection of executions, typically all executions from a specific address""" flows = await asyncio.gather(*(fetch_execution_flow_price(execution.vm_hash) for execution in executions)) diff --git a/src/aleph/vm/orchestrator/tasks.py b/src/aleph/vm/orchestrator/tasks.py index 803d3ca32..93250b239 100644 --- a/src/aleph/vm/orchestrator/tasks.py +++ b/src/aleph/vm/orchestrator/tasks.py @@ -31,8 +31,10 @@ from .messages import get_message_status from .payment import ( compute_required_balance, + compute_required_credit_balance, compute_required_flow, fetch_balance_of_address, + fetch_credit_balance_of_address, get_stream, ) from .pubsub import PubSub @@ -187,44 +189,71 @@ async def check_payment(pool: VmPool): pool.forget_vm(vm_hash) # Check if the balance held in the wallet is sufficient holder tier resources (Not do it yet) - for sender, chains in pool.get_executions_by_sender(payment_type=PaymentType.hold).items(): + for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.hold).items(): for chain, executions in chains.items(): executions = [execution for execution in executions if execution.is_confidential] if not executions: continue - balance = await fetch_balance_of_address(sender) + balance = await fetch_balance_of_address(execution_address) # Stop executions until the required balance is reached required_balance = await compute_required_balance(executions) - logger.debug(f"Required balance for Sender {sender} executions: {required_balance}, {executions}") + logger.debug( + f"Required balance for Sender {execution_address} executions: {required_balance}, {executions}" + ) # Stop executions until the required balance is reached while executions and balance < (required_balance + settings.PAYMENT_BUFFER): last_execution = executions.pop(-1) logger.debug(f"Stopping {last_execution} due to insufficient balance") await pool.stop_vm(last_execution.vm_hash) required_balance = await compute_required_balance(executions) + community_wallet = await get_community_wallet_address() if not community_wallet: logger.error("Monitor payment ERROR: No community wallet set. Cannot check community payment") + # Check if the credit balance held in the wallet is sufficient credit tier resources (Not do it yet) + for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.credit).items(): + for chain, executions in chains.items(): + executions = [execution for execution in executions] + if not executions: + continue + balance = await fetch_credit_balance_of_address(execution_address) + + # Stop executions until the required credits are reached + required_credits = await compute_required_credit_balance(executions) + logger.debug( + f"Required credit balance for Address {execution_address} executions: {required_credits}, {executions}" + ) + # Stop executions until the required credits are reached + while executions and balance < (required_credits + settings.PAYMENT_BUFFER): + last_execution = executions.pop(-1) + logger.debug(f"Stopping {last_execution} due to insufficient credit balance") + await pool.stop_vm(last_execution.vm_hash) + required_credits = await compute_required_credit_balance(executions) + # Check if the balance held in the wallet is sufficient stream tier resources - for sender, chains in pool.get_executions_by_sender(payment_type=PaymentType.superfluid).items(): + for execution_address, chains in pool.get_executions_by_address(payment_type=PaymentType.superfluid).items(): for chain, executions in chains.items(): try: - stream = await get_stream(sender=sender, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain) + stream = await get_stream( + sender=execution_address, receiver=settings.PAYMENT_RECEIVER_ADDRESS, chain=chain + ) logger.debug( - f"Stream flow from {sender} to {settings.PAYMENT_RECEIVER_ADDRESS} = {stream} {chain.value}" + f"Stream flow from {execution_address} to {settings.PAYMENT_RECEIVER_ADDRESS} = {stream} {chain.value}" ) except ValueError as error: - logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}") + logger.error(f"Error found getting stream for chain {chain} and sender {execution_address}: {error}") continue try: - community_stream = await get_stream(sender=sender, receiver=community_wallet, chain=chain) - logger.debug(f"Stream flow from {sender} to {community_wallet} (community) : {stream} {chain}") + community_stream = await get_stream(sender=execution_address, receiver=community_wallet, chain=chain) + logger.debug( + f"Stream flow from {execution_address} to {community_wallet} (community) : {stream} {chain}" + ) except ValueError as error: - logger.error(f"Error found getting stream for chain {chain} and sender {sender}: {error}") + logger.error(f"Error found getting stream for chain {chain} and sender {execution_address}: {error}") continue while executions: @@ -249,7 +278,7 @@ async def check_payment(pool: VmPool): ) required_community_stream = format_cost(required_stream * COMMUNITY_STREAM_RATIO) logger.debug( - f"Stream for senders {sender} {len(executions)} executions. CRN : {stream} / {required_crn_stream}." + f"Stream for senders {execution_address} {len(executions)} executions. CRN : {stream} / {required_crn_stream}." f"Community: {community_stream} / {required_community_stream}" ) # Can pay all executions @@ -259,7 +288,7 @@ async def check_payment(pool: VmPool): break # Stop executions until the required stream is reached last_execution = executions.pop(-1) - logger.info(f"Stopping {last_execution} of {sender} due to insufficient stream") + logger.info(f"Stopping {last_execution} of {execution_address} due to insufficient stream") await pool.stop_vm(last_execution.vm_hash) diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index d8aec10f6..0383e6703 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -379,9 +379,9 @@ def get_available_gpus(self) -> list[GpuDevice]: available_gpus.append(gpu) return available_gpus - def get_executions_by_sender(self, payment_type: PaymentType) -> dict[str, dict[str, list[VmExecution]]]: + def get_executions_by_address(self, payment_type: PaymentType) -> dict[str, dict[str, list[VmExecution]]]: """Return all executions of the given type, grouped by sender and by chain.""" - executions_by_sender: dict[str, dict[str, list[VmExecution]]] = {} + executions_by_address: dict[str, dict[str, list[VmExecution]]] = {} for vm_hash, execution in self.executions.items(): if execution.vm_hash in (settings.CHECK_FASTAPI_VM_ID, settings.LEGACY_CHECK_FASTAPI_VM_ID): # Ignore Diagnostic VM execution @@ -399,11 +399,11 @@ def get_executions_by_sender(self, payment_type: PaymentType) -> dict[str, dict[ else Payment(chain=Chain.ETH, type=PaymentType.hold) ) if execution_payment.type == payment_type: - sender = execution.message.address + address = execution.message.address chain = execution_payment.chain - executions_by_sender.setdefault(sender, {}) - executions_by_sender[sender].setdefault(chain, []).append(execution) - return executions_by_sender + executions_by_address.setdefault(address, {}) + executions_by_address[address].setdefault(chain, []).append(execution) + return executions_by_address def get_valid_reservation(self, resource) -> Reservation | None: if resource in self.reservations and self.reservations[resource].is_expired(): diff --git a/tests/supervisor/test_checkpayment.py b/tests/supervisor/test_checkpayment.py index b5d75dbf1..0c4ab6ac3 100644 --- a/tests/supervisor/test_checkpayment.py +++ b/tests/supervisor/test_checkpayment.py @@ -82,7 +82,7 @@ async def compute_required_flow(executions): pool.executions = {hash: execution} - executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid) + executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid) assert len(executions_by_sender) == 1 assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}} @@ -136,7 +136,7 @@ async def compute_required_flow(executions): pool.executions = {hash: execution} - executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid) + executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid) assert len(executions_by_sender) == 1 assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}} @@ -173,7 +173,7 @@ async def test_not_enough_flow(mocker, fake_instance_content): pool.executions = {hash: execution} - executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid) + executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid) assert len(executions_by_sender) == 1 assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}} @@ -217,7 +217,7 @@ async def get_stream(sender, receiver, chain): pool.executions = {hash: execution} - executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid) + executions_by_sender = pool.get_executions_by_address(payment_type=PaymentType.superfluid) assert len(executions_by_sender) == 1 assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}} From 3e3351f54916084060b6eaae8129566deffafc34 Mon Sep 17 00:00:00 2001 From: "Andres D. Molins" Date: Thu, 14 Aug 2025 12:31:07 +0200 Subject: [PATCH 2/2] Fix: Refactor to unify getting execution price method making it compatible with multiple payment methods. --- src/aleph/vm/orchestrator/payment.py | 70 +++++++-------------- src/aleph/vm/orchestrator/views/__init__.py | 6 +- 2 files changed, 26 insertions(+), 50 deletions(-) diff --git a/src/aleph/vm/orchestrator/payment.py b/src/aleph/vm/orchestrator/payment.py index a8bb4aed1..8b863451f 100644 --- a/src/aleph/vm/orchestrator/payment.py +++ b/src/aleph/vm/orchestrator/payment.py @@ -2,6 +2,7 @@ import logging from collections.abc import Iterable from decimal import Decimal +from typing import List import aiohttp from aleph_message.models import ItemHash, PaymentType @@ -70,48 +71,9 @@ async def fetch_credit_balance_of_address(address: str) -> Decimal: return resp_data["credits"] -async def fetch_execution_flow_price(item_hash: ItemHash) -> Decimal: - """Fetch the flow price of an execution from the reference API server.""" - async with aiohttp.ClientSession() as session: - url = f"{settings.API_SERVER}/api/v0/price/{item_hash}" - resp = await session.get(url) - # Raise an error if the request failed - resp.raise_for_status() - - resp_data = await resp.json() - required_flow: float = resp_data["required_tokens"] - payment_type: str | None = resp_data["payment_type"] - - if payment_type is None: - msg = "Payment type must be specified in the message" - raise ValueError(msg) - elif payment_type != PaymentType.superfluid: - msg = f"Payment type {payment_type} is not supported" - raise ValueError(msg) - - return Decimal(required_flow) - - -async def fetch_execution_hold_price(item_hash: ItemHash) -> Decimal: - """Fetch the hold price of an execution from the reference API server.""" - async with aiohttp.ClientSession() as session: - url = f"{settings.API_SERVER}/api/v0/price/{item_hash}" - resp = await session.get(url) - # Raise an error if the request failed - resp.raise_for_status() - - resp_data = await resp.json() - required_hold: float = resp_data["required_tokens"] - payment_type: str | None = resp_data["payment_type"] - - if payment_type not in (None, PaymentType.hold): - msg = f"Payment type {payment_type} is not supported" - raise ValueError(msg) - - return Decimal(required_hold) - - -async def fetch_execution_credit_price(item_hash: ItemHash) -> Decimal: +async def fetch_execution_price( + item_hash: ItemHash, allowed_payments: List[PaymentType], payment_type_required: bool = True +) -> Decimal: """Fetch the credit price of an execution from the reference API server.""" async with aiohttp.ClientSession() as session: url = f"{settings.API_SERVER}/api/v0/price/{item_hash}" @@ -123,10 +85,15 @@ async def fetch_execution_credit_price(item_hash: ItemHash) -> Decimal: required_credits: float = resp_data["required_credits"] # Field not defined yet on API side. payment_type: str | None = resp_data["payment_type"] - if payment_type not in (None, PaymentType.credit): - msg = f"Payment type {payment_type} is not supported" + if payment_type_required and payment_type is None: + msg = "Payment type must be specified in the message" raise ValueError(msg) + if payment_type: + if payment_type not in allowed_payments: + msg = f"Payment type {payment_type} is not supported" + raise ValueError(msg) + return Decimal(required_credits) @@ -178,17 +145,26 @@ async def get_stream(sender: str, receiver: str, chain: str) -> Decimal: async def compute_required_balance(executions: Iterable[VmExecution]) -> Decimal: """Get the balance required for the resources of the user from the messages and the pricing aggregate.""" - costs = await asyncio.gather(*(fetch_execution_hold_price(execution.vm_hash) for execution in executions)) + costs = await asyncio.gather( + *( + fetch_execution_price(execution.vm_hash, [PaymentType.hold], payment_type_required=False) + for execution in executions + ) + ) return sum(costs, Decimal(0)) async def compute_required_credit_balance(executions: Iterable[VmExecution]) -> Decimal: """Get the balance required for the resources of the user from the messages and the pricing aggregate.""" - costs = await asyncio.gather(*(fetch_execution_credit_price(execution.vm_hash) for execution in executions)) + costs = await asyncio.gather( + *(fetch_execution_price(execution.vm_hash, [PaymentType.credit]) for execution in executions) + ) return sum(costs, Decimal(0)) async def compute_required_flow(executions: Iterable[VmExecution]) -> Decimal: """Compute the flow required for a collection of executions, typically all executions from a specific address""" - flows = await asyncio.gather(*(fetch_execution_flow_price(execution.vm_hash) for execution in executions)) + flows = await asyncio.gather( + *(fetch_execution_price(execution.vm_hash, [PaymentType.superfluid]) for execution in executions) + ) return sum(flows, Decimal(0)) diff --git a/src/aleph/vm/orchestrator/views/__init__.py b/src/aleph/vm/orchestrator/views/__init__.py index 0cee1dfad..4447ed736 100644 --- a/src/aleph/vm/orchestrator/views/__init__.py +++ b/src/aleph/vm/orchestrator/views/__init__.py @@ -35,7 +35,7 @@ from aleph.vm.orchestrator.payment import ( InvalidAddressError, InvalidChainError, - fetch_execution_flow_price, + fetch_execution_price, get_stream, ) from aleph.vm.orchestrator.pubsub import PubSub @@ -570,7 +570,7 @@ async def notify_allocation(request: web.Request): if have_gpu: logger.debug(f"GPU Instance {item_hash} not using PAYG") user_balance = await payment.fetch_balance_of_address(message.sender) - hold_price = await payment.fetch_execution_hold_price(item_hash) + hold_price = await payment.fetch_execution_price(item_hash, [PaymentType.hold], False) logger.debug(f"Address {message.sender} Balance: {user_balance}, Price: {hold_price}") if hold_price > user_balance: return web.HTTPPaymentRequired( @@ -599,7 +599,7 @@ async def notify_allocation(request: web.Request): if not active_flow: raise web.HTTPPaymentRequired(reason="Empty payment stream for this instance") - required_flow: Decimal = await fetch_execution_flow_price(item_hash) + required_flow: Decimal = await fetch_execution_price(item_hash, [PaymentType.superfluid]) community_wallet = await get_community_wallet_address() required_crn_stream: Decimal required_community_stream: Decimal