Skip to content
17 changes: 10 additions & 7 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
context: ServerCallContext | None = None,
) -> Task | None:
"""Default handler for 'tasks/get'."""
task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand Down Expand Up @@ -141,7 +141,7 @@

Attempts to cancel the task managed by the `AgentExecutor`.
"""
task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand All @@ -150,20 +150,21 @@
raise ServerError(
error=TaskNotCancelableError(
message=f'Task cannot be canceled - current state: {task.status.state}'
)
)

task_manager = TaskManager(
task_id=task.id,
context_id=task.context_id,
task_store=self.task_store,
initial_message=None,
context=context,
)
result_aggregator = ResultAggregator(task_manager)

queue = await self._queue_manager.tap(task.id)
if not queue:
queue = EventQueue()

Check notice on line 167 in src/a2a/server/request_handlers/default_request_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/default_request_handler.py (489-504)

await self.agent_executor.cancel(
RequestContext(
Expand Down Expand Up @@ -224,6 +225,7 @@
context_id=params.message.context_id,
task_store=self.task_store,
initial_message=params.message,
context=context,
)
task: Task | None = await task_manager.get_task()

Expand Down Expand Up @@ -424,7 +426,7 @@
if not self._push_config_store:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.task_id)
task: Task | None = await self.task_store.get(params.task_id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand All @@ -447,7 +449,7 @@
if not self._push_config_store:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand Down Expand Up @@ -476,7 +478,7 @@
Allows a client to re-attach to a running streaming task's event stream.
Requires the task and its queue to still be active.
"""
task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand All @@ -484,21 +486,22 @@
raise ServerError(
error=InvalidParamsError(
message=f'Task {task.id} is in terminal state: {task.status.state.value}'
)
)

task_manager = TaskManager(
task_id=task.id,
context_id=task.context_id,
task_store=self.task_store,
initial_message=None,
context=context,
)

result_aggregator = ResultAggregator(task_manager)

queue = await self._queue_manager.tap(task.id)
if not queue:
raise ServerError(error=TaskNotFoundError())

Check notice on line 504 in src/a2a/server/request_handlers/default_request_handler.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/request_handlers/default_request_handler.py (153-167)

consumer = EventConsumer(queue)
async for event in result_aggregator.consume_and_emit(consumer):
Expand All @@ -516,7 +519,7 @@
if not self._push_config_store:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand All @@ -543,7 +546,7 @@
if not self._push_config_store:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.id)
task: Task | None = await self.task_store.get(params.id, context)
if not task:
raise ServerError(error=TaskNotFoundError())

Expand Down
13 changes: 10 additions & 3 deletions src/a2a/server/tasks/database_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"or 'pip install a2a-sdk[sql]'"
) from e

from a2a.server.context import ServerCallContext
from a2a.server.models import Base, TaskModel, create_task_model
from a2a.server.tasks.task_store import TaskStore
from a2a.types import Task # Task is the Pydantic model
Expand Down Expand Up @@ -119,15 +120,19 @@ def _from_orm(self, task_model: TaskModel) -> Task:
# Pydantic's model_validate will parse the nested dicts/lists from JSON
return Task.model_validate(task_data_from_db)

async def save(self, task: Task) -> None:
async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
"""Saves or updates a task in the database."""
await self._ensure_initialized()
db_task = self._to_orm(task)
async with self.async_session_maker.begin() as session:
await session.merge(db_task)
logger.debug('Task %s saved/updated successfully.', task.id)

async def get(self, task_id: str) -> Task | None:
async def get(
self, task_id: str, context: ServerCallContext | None = None
) -> Task | None:
"""Retrieves a task from the database by ID."""
await self._ensure_initialized()
async with self.async_session_maker() as session:
Expand All @@ -142,7 +147,9 @@ async def get(self, task_id: str) -> Task | None:
logger.debug('Task %s not found in store.', task_id)
return None

async def delete(self, task_id: str) -> None:
async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
"""Deletes a task from the database by ID."""
await self._ensure_initialized()

Expand Down
13 changes: 10 additions & 3 deletions src/a2a/server/tasks/inmemory_task_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging

from a2a.server.context import ServerCallContext
from a2a.server.tasks.task_store import TaskStore
from a2a.types import Task

Expand All @@ -21,13 +22,17 @@ def __init__(self) -> None:
self.tasks: dict[str, Task] = {}
self.lock = asyncio.Lock()

async def save(self, task: Task) -> None:
async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
"""Saves or updates a task in the in-memory store."""
async with self.lock:
self.tasks[task.id] = task
logger.debug('Task %s saved successfully.', task.id)

async def get(self, task_id: str) -> Task | None:
async def get(
self, task_id: str, context: ServerCallContext | None = None
) -> Task | None:
"""Retrieves a task from the in-memory store by ID."""
async with self.lock:
logger.debug('Attempting to get task with id: %s', task_id)
Expand All @@ -38,7 +43,9 @@ async def get(self, task_id: str) -> Task | None:
logger.debug('Task %s not found in store.', task_id)
return task

async def delete(self, task_id: str) -> None:
async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
"""Deletes a task from the in-memory store by ID."""
async with self.lock:
logger.debug('Attempting to delete task with id: %s', task_id)
Expand Down
12 changes: 9 additions & 3 deletions src/a2a/server/tasks/task_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from a2a.server.context import ServerCallContext
from a2a.server.events.event_queue import Event
from a2a.server.tasks.task_store import TaskStore
from a2a.types import (
Expand Down Expand Up @@ -31,6 +32,7 @@ def __init__(
context_id: str | None,
task_store: TaskStore,
initial_message: Message | None,
context: ServerCallContext | None = None,
):
"""Initializes the TaskManager.

Expand All @@ -40,6 +42,7 @@ def __init__(
task_store: The `TaskStore` instance for persistence.
initial_message: The `Message` that initiated the task, if any.
Used when creating a new task object.
context: The `ServerCallContext` that this task is produced under.
"""
if task_id is not None and not (isinstance(task_id, str) and task_id):
raise ValueError('Task ID must be a non-empty string')
Expand All @@ -49,6 +52,7 @@ def __init__(
self.task_store = task_store
self._initial_message = initial_message
self._current_task: Task | None = None
self._call_context: ServerCallContext | None = context
logger.debug(
'TaskManager initialized with task_id: %s, context_id: %s',
task_id,
Expand All @@ -74,7 +78,9 @@ async def get_task(self) -> Task | None:
logger.debug(
'Attempting to get task from store with id: %s', self.task_id
)
self._current_task = await self.task_store.get(self.task_id)
self._current_task = await self.task_store.get(
self.task_id, self._call_context
)
if self._current_task:
logger.debug('Task %s retrieved successfully.', self.task_id)
else:
Expand Down Expand Up @@ -167,7 +173,7 @@ async def ensure_task(
logger.debug(
'Attempting to retrieve existing task with id: %s', self.task_id
)
task = await self.task_store.get(self.task_id)
task = await self.task_store.get(self.task_id, self._call_context)

if not task:
logger.info(
Expand Down Expand Up @@ -231,7 +237,7 @@ async def _save_task(self, task: Task) -> None:
task: The `Task` object to save.
"""
logger.debug('Saving task with id: %s', task.id)
await self.task_store.save(task)
await self.task_store.save(task, self._call_context)
self._current_task = task
if not self.task_id:
logger.info('New task created with id: %s', task.id)
Expand Down
13 changes: 10 additions & 3 deletions src/a2a/server/tasks/task_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod

from a2a.server.context import ServerCallContext
from a2a.types import Task


Expand All @@ -10,13 +11,19 @@ class TaskStore(ABC):
"""

@abstractmethod
async def save(self, task: Task) -> None:
async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
"""Saves or updates a task in the store."""

@abstractmethod
async def get(self, task_id: str) -> Task | None:
async def get(
self, task_id: str, context: ServerCallContext | None = None
) -> Task | None:
"""Retrieves a task from the store by ID."""

@abstractmethod
async def delete(self, task_id: str) -> None:
async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
"""Deletes a task from the store by ID."""
Loading
Loading