diff --git a/src/aleph/vm/orchestrator/tasks.py b/src/aleph/vm/orchestrator/tasks.py index 803d3ca32..3f62d23c7 100644 --- a/src/aleph/vm/orchestrator/tasks.py +++ b/src/aleph/vm/orchestrator/tasks.py @@ -181,7 +181,7 @@ async def check_payment(pool: VmPool): if vm_hash == settings.FAKE_INSTANCE_ID: continue message_status = await get_message_status(vm_hash) - if message_status != MessageStatus.PROCESSED: + if message_status != MessageStatus.PROCESSED and message_status != MessageStatus.REMOVING: logger.debug(f"Stopping {vm_hash} execution due to {message_status} message status") await pool.stop_vm(vm_hash) pool.forget_vm(vm_hash) diff --git a/tests/supervisor/test_checkpayment.py b/tests/supervisor/test_checkpayment.py index b5d75dbf1..0617c7e6c 100644 --- a/tests/supervisor/test_checkpayment.py +++ b/tests/supervisor/test_checkpayment.py @@ -224,3 +224,77 @@ async def get_stream(sender, receiver, chain): await check_payment(pool=pool) execution.stop.assert_called_with() + + +@pytest.mark.asyncio +async def test_message_removing_status(mocker, fake_instance_content): + mocker.patch.object(settings, "ALLOW_VM_NETWORKING", False) + mocker.patch.object(settings, "PAYMENT_RECEIVER_ADDRESS", "0xD39C335404a78E0BDCf6D50F29B86EFd57924288") + + pool = VmPool() + mock_community_wallet_address = "0x23C7A99d7AbebeD245d044685F1893aeA4b5Da90" + + mocker.patch("aleph.vm.orchestrator.tasks.get_stream", return_value=400, autospec=True) + mocker.patch("aleph.vm.orchestrator.tasks.get_community_wallet_address", return_value=mock_community_wallet_address) + mocker.patch("aleph.vm.orchestrator.tasks.get_message_status", return_value=MessageStatus.REMOVING) + mocker.patch("aleph.vm.orchestrator.tasks.compute_required_flow", return_value=5) + message = InstanceContent.model_validate(fake_instance_content) + + mocker.patch.object(VmExecution, "is_running", new=True) + mocker.patch.object(VmExecution, "stop", new=mocker.AsyncMock(return_value=False)) + hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadece" + execution = VmExecution( + vm_hash=hash, + message=message, + original=message, + persistent=False, + snapshot_manager=None, + systemd_manager=None, + ) + + pool.executions = {hash: execution} + + executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid) + assert len(executions_by_sender) == 1 + assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}} + + await check_payment(pool=pool) + + execution.stop.assert_not_called() + + +@pytest.mark.asyncio +async def test_removed_message_status(mocker, fake_instance_content): + mocker.patch.object(settings, "ALLOW_VM_NETWORKING", False) + mocker.patch.object(settings, "PAYMENT_RECEIVER_ADDRESS", "0xD39C335404a78E0BDCf6D50F29B86EFd57924288") + + pool = VmPool() + mock_community_wallet_address = "0x23C7A99d7AbebeD245d044685F1893aeA4b5Da90" + + mocker.patch("aleph.vm.orchestrator.tasks.get_stream", return_value=400, autospec=True) + mocker.patch("aleph.vm.orchestrator.tasks.get_community_wallet_address", return_value=mock_community_wallet_address) + mocker.patch("aleph.vm.orchestrator.tasks.get_message_status", return_value=MessageStatus.REMOVED) + mocker.patch("aleph.vm.orchestrator.tasks.compute_required_flow", return_value=5) + message = InstanceContent.model_validate(fake_instance_content) + + mocker.patch.object(VmExecution, "is_running", new=True) + mocker.patch.object(VmExecution, "stop", new=mocker.AsyncMock(return_value=False)) + hash = "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadece" + execution = VmExecution( + vm_hash=hash, + message=message, + original=message, + persistent=False, + snapshot_manager=None, + systemd_manager=None, + ) + + pool.executions = {hash: execution} + + executions_by_sender = pool.get_executions_by_sender(payment_type=PaymentType.superfluid) + assert len(executions_by_sender) == 1 + assert executions_by_sender == {"0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9": {Chain.BASE: [execution]}} + + await check_payment(pool=pool) + + execution.stop.assert_called_with()