diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index e54fa13b..c5794f45 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -112,6 +112,7 @@ async def consume_all(self) -> AsyncGenerator[Event]: TaskState.failed, TaskState.rejected, TaskState.unknown, + TaskState.input_required ) ) ) diff --git a/tests/server/events/test_event_consumer.py b/tests/server/events/test_event_consumer.py index 3fd0e58d..8b596607 100644 --- a/tests/server/events/test_event_consumer.py +++ b/tests/server/events/test_event_consumer.py @@ -222,3 +222,22 @@ async def mock_dequeue() -> Any: assert len(consumed_events) == 1 assert consumed_events[0] == events[0] assert mock_event_queue.task_done.call_count == 1 + +@pytest.mark.asyncio +async def test_consume_task_input_required( + event_consumer: MagicMock, + mock_event_queue: MagicMock, +): + task = Task(**MINIMAL_TASK) + task.status = TaskStatus(state=TaskState.input_required) + + async def mock_dequeue() -> Any: + return task + + mock_event_queue.dequeue_event = mock_dequeue + consumed_events: list[Any] = [] + #consumer should terminate on input_required task + async for event in event_consumer.consume_all(): + consumed_events.append(event) + assert len(consumed_events) == 1 + assert consumed_events[0] == task \ No newline at end of file