From e85aac0074c19bac0ae82e18c492113ff4d13f93 Mon Sep 17 00:00:00 2001 From: Alexandre Tresallet Date: Thu, 23 Apr 2026 17:05:40 +0200 Subject: [PATCH] feat(task): integrate TaskSettings for configurable task management parameters feat(client): add top-level gRPC client settings and channel configuration feat(client_settings): Update client settings structure and improve task management configurations feat(client_settings): Update client and server configuration settings in .env.example and related files Signed-off-by: Alexandre --- .env.exemple | 289 +++++++++++++----- .gitignore | 3 + .../core/task_manager/base_task_manager.py | 68 ++--- src/digitalkin/grpc_servers/_base_server.py | 5 +- src/digitalkin/grpc_servers/module_server.py | 2 +- .../grpc_servers/module_servicer.py | 15 +- .../grpc_servers/utils/grpc_client_wrapper.py | 53 ++-- src/digitalkin/models/grpc_servers/models.py | 195 +----------- .../models/settings/client/__init__.py | 1 + .../models/settings/client/channel.py | 27 ++ .../models/settings/client/client.py | 34 +++ .../models/settings/client/grpc_client.py | 58 ++++ .../models/settings/client/retry_policy.py | 92 ++++++ src/digitalkin/models/settings/client/task.py | 47 +++ .../models/settings/server/channel.py | 6 +- .../server/{grpc.py => grpc_server.py} | 2 +- .../models/settings/server/server.py | 22 +- .../models/settings/utils/channel.py | 65 +--- .../models/settings/utils/grpc_base.py | 60 ++++ .../models/settings/utils/models.py | 91 ++++++ .../communication/grpc_communication.py | 5 - .../task_manager/grpc_task_manager.py | 51 ++-- tests/core/test_base_task_manager.py | 260 ++++++++-------- tests/grpc_server/test_module_service.py | 5 +- .../utils/test_grpc_client_wrapper.py | 9 +- tests/grpc_server/utils/test_models.py | 8 - tests/services/cost/test_cost_stress.py | 7 +- tests/services/cost/test_grpc_cost.py | 3 +- .../filesystem/test_grpc_filesystem.py | 3 +- tests/services/registry/test_grpc_registry.py | 5 +- 
tests/services/setup/test_grpc_setup.py | 3 +- tests/services/storage/test_grpc_storage.py | 5 +- .../task_manager/test_grpc_task_manager.py | 3 +- .../user_profile/test_grpc_user_profile.py | 5 +- 34 files changed, 923 insertions(+), 584 deletions(-) create mode 100644 src/digitalkin/models/settings/client/__init__.py create mode 100644 src/digitalkin/models/settings/client/channel.py create mode 100644 src/digitalkin/models/settings/client/client.py create mode 100644 src/digitalkin/models/settings/client/grpc_client.py create mode 100644 src/digitalkin/models/settings/client/retry_policy.py create mode 100644 src/digitalkin/models/settings/client/task.py rename src/digitalkin/models/settings/server/{grpc.py => grpc_server.py} (98%) create mode 100644 src/digitalkin/models/settings/utils/grpc_base.py create mode 100644 src/digitalkin/models/settings/utils/models.py diff --git a/.env.exemple b/.env.exemple index aadffa21..3d5983ab 100644 --- a/.env.exemple +++ b/.env.exemple @@ -1,129 +1,278 @@ -######################################### -# Server Settings for Archetype Module -######################################### +# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +# ┃ SERVER SETTINGS ┃ +# ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ -# ══ Channel Settings ════════════════════════════════════════════════════════════════ # +# Enable health check service +SERVER_HEALTH_CHECK=True -# Host address on which the module gRPC server listens. -# Default: [::] (all IPv6 & IPv4 interfaces) +# Enable reflection for the server +SERVER_REFLECTION=True + +# Maximum number of RPCs handled in parallel by the server. +SERVER_MAX_CONCURRENT_RPCS=2000 + +# Maximum number of workers for sync mode +SERVER_MAX_WORKERS=10 + +# Number of workers in the server thread pool. 
+SERVER_THREAD_POOL_WORKERS=4 + +# Enable asyncio inspector +SERVER_ASYNCIO_INSPECTOR_STATE=False + +# Port for asyncio inspector +SERVER_ASYNCIO_INSPECTOR_PORT=8765 + +# Maximum number of cached setups in ModuleServicer +SERVER_SETUP_CACHE_MAX=100 + +# Timeout for completion in ModuleServicer +SERVER_COMPLETION_TIMEOUT=300.0 + + +# ═════════════════════════════ Server Channel Settings ══════════════════════════════ # + +# Host address to bind the client to SERVER_CHANNEL_HOST=[::] -# TCP port for the module gRPC server. -# Default: 50055 +# Port to listen on SERVER_CHANNEL_PORT=50055 -# Execution mode of the server. Possible values: async | sync -# Default: async -SERVER_CHANNEL_CONTROL_FLOW=async +# Client/Server operation mode (sync/async) +SERVER_CHANNEL_COMMUNICATION_MODE=async -# Security mode of the server. Possible values: insecure | secure -# Default: insecure +# Security mode (secure/insecure) SERVER_CHANNEL_SECURITY=insecure -# Enable mutual TLS for incoming clients: "true" or "false" -# Default: false -SERVER_CHANNEL_MTLS=false +# Enable mutual TLS +SERVER_CHANNEL_MTLS=False -#SERVER_CHANNEL_CREDENTIALS__KEY_PATH=server.key -#SERVER_CHANNEL_CREDENTIALS__CERT_PATH=cert.key -#SERVER_CHANNEL_CREDENTIALS__ROOT_CERT_PATH=ca.crt +# gRPC compression algorithm +SERVER_CHANNEL_COMPRESSION=gzip -# Hostname or IP address that the server advertises to clients for connection. -# Default: digitalkin-ada-archetype +# (Optional)Public hostname/IP sent to registry for discovery. Falls back to host if not set. SERVER_CHANNEL_ADVERTISE_HOST= -# ══ gRPC Settings ════════════════════════════════════════════════════════════════════ # -# Compression algorithm for gRPC messages. 
Possible values: gzip, deflate, snappy, zstd, or none -# Default: gzip +#TODO +## ════════════════════════════ Server Credentils Settings ════════════════════════════ # +# +## Path to credentials files +#SERVER_CHANNEL_CREDENTIAL_PATH=/credentials +# +## Path to the private key +#SERVER_CHANNEL_CREDENTIAL_KEY=server.key +# +## Path to the certificate +#SERVER_CHANNEL_CREDENTIAL_CERT=server.crt +# +## (Optional) Path to CA if MTLS is active +#SERVER_CHANNEL_CREDENTIAL_CA=ca.crt + + +# ═══════════════════════════════ Server Grpc Settings ═══════════════════════════════ # + +# gRPC compression algorithm SERVER_GRPC_COMPRESSION=gzip -# ── Option Grpc ───────────────────────────────────────────────────────────────────── # -# Time (in milliseconds) after which a keepalive ping is sent if the connection is idle. -# Default: 120000 (2 minutes) +# Interval for server keepalive pings. SERVER_GRPC_OPTIONS_KEEPALIVE_TIME=120000 -# Time (in milliseconds) the server waits for a keepalive ping ack before closing the connection. -# Default: 20000 (20 seconds) +# Timeout for server keepalive pings. SERVER_GRPC_OPTIONS_KEEPALIVE_TIMEOUT=20000 -# Minimum time (in milliseconds) between client pings. -# Default: 10000 (10 seconds) +# Minimum interval between HTTP/2 pings on the server side. SERVER_GRPC_OPTIONS_MIN_PING_INTERVAL=10000 -# Maximum message size (in bytes) the server can receive. -# Default: 4194304 (4MB) -SERVER_GRPC_OPTIONS_MAX_RECEIVE_MESSAGE_LENGTH=4194304 +# Maximum message size the server can receive, in bytes. +SERVER_GRPC_OPTIONS_MAX_RECEIVE_MESSAGE_LENGTH=104857600 -# Maximum message size (in bytes) the server can send. -# Default: 4194304 (4MB) -SERVER_GRPC_OPTIONS_MAX_SEND_MESSAGE_LENGTH=4194304 +# Maximum message size the server can send, in bytes. +SERVER_GRPC_OPTIONS_MAX_SEND_MESSAGE_LENGTH=104857600 -# Maximum number of pings the server allows without receiving any data from the client. 
-# Default: 0 (unlimited) +# Maximum number of pings the server allows without receiving any data. Setting to 0 allows unlimited pings, which is important for long-running streams. SERVER_GRPC_OPTIONS_MAX_PINGS_WITHOUT_DATA=0 -# Allow keepalive pings when there are no active calls. "true" or "false" -# Default: true -SERVER_GRPC_OPTIONS_KEEPALIVE_PERMIT_WITHOUT_CALLS=true +# Allow clients to send keepalive pings even when there are no active RPCs. This is important for keeping connections alive through proxies and detecting dead clients. +SERVER_GRPC_OPTIONS_KEEPALIVE_PERMIT_WITHOUT_CALLS=True + + +# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +# ┃ CLIENT SETTINGS ┃ +# ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + +# Maximum number of retries for queries +CLIENT_QUERY_MAX_RETRIES=2 -######################################### -# gRPC Client Configuration Services Provider -######################################### +# Base backoff time in milliseconds for query retries +CLIENT_QUERY_BACKOFF_BASE_MS=50 -# Host address on which the services provider gRPC server listens. -# Default: [::] (all IPv6 & IPv4 interfaces) +# Timeout in seconds for queries +CLIENT_QUERY_TIMEOUT=30 + + +# ═════════════════════════════ Client Channel Settings ══════════════════════════════ # + +# Host address on which the services provider gRPC server listens (all IPv6 & IPv4 interfaces). SERVICES_PROVIDER_URL=[::] # TCP port for the services provider gRPC server. -# Default: 50151 SERVICES_PROVIDER_PORT=50151 # Execution mode of the services provider. Possible values: async | sync -# Default: async SERVICES_PROVIDER_MODE=async # Security mode of the services provider. 
Possible values: insecure | secure -# Default: insecure SERVICES_PROVIDER_SECURITY=insecure # Enable mutual TLS for outgoing calls: "true" or "false" -# Default: false SERVICES_PROVIDER_MTLS=false -GRPC_DNS_RESOLVER="native" +# Security mode (secure/insecure) +CLIENT_CHANNEL_SECURITY=insecure -############################################# -# Other Configuration -############################################# +# Enable mutual TLS +CLIENT_CHANNEL_MTLS=False -MODULE_ID_TOOLKIT_RAG="modules:1" -OPENAI_API_KEY=sk-xxxx +#TODO +## ════════════════════════════ Client Credentils Settings ════════════════════════════ # +# +## Path to credentials files +#CLIENT_CHANNEL_CREDENTIAL_PATH=/credentials +# +## Path to the private key +#CLIENT_CHANNEL_CREDENTIAL_KEY=client.key +# +## Path to the certificate +#CLIENT_CHANNEL_CREDENTIAL_CERT=client.crt +# +## (Optional) Path to CA if MTLS is active +#CLIENT_CHANNEL_CREDENTIAL_CA=ca.crt + + +# ═══════════════════════════════ Client Grpc Settings ═══════════════════════════════ # + +# Maximum message size the server can receive, in bytes. +CLIENT_GRPC_MAX_RECEIVE_MESSAGE_LENGHT=104857600 + +# Maximum message size the server can send, in bytes. +CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH=104857600 + +# Interval for server keepalive pings. +CLIENT_GRPC_KEEPALIVE_TIME=120000 + +# Timeout for server keepalive pings. +CLIENT_GRPC_KEEPALIVE_TIMEOUT=20000 + +# Allow clients to send keepalive pings even when there are no active RPCs. This is important for keeping connections alive through proxies and detecting dead clients. 
+CLIENT_GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS=True + +# Minimum time between DNS resolutions in milliseconds +CLIENT_GRPC_DNS_RESOLUTION_TIME=500 + +# Initial reconnect backoff time in milliseconds +CLIENT_GRPC_INITIAL_RECONNECT_TIME=1000 + +# Maximum reconnect backoff time in milliseconds +CLIENT_GRPC_MAX_RECONNECT_TIME=10000 -######################################## -# Certificate Settings (CERTIFICATE_) -######################################## +# Minimum reconnect backoff time in milliseconds +CLIENT_GRPC_MIN_RECONNECT_TIME=500 + +# Minimum time between HTTP/2 pings in milliseconds +CLIENT_GRPC_MIN_PING_INTERVAL_TIME=30000 + +# Enable gRPC retries +CLIENT_GRPC_ENABLE_RETRIES=True + + +# ═══════════════════════════ Client Retry Policy Settings ═══════════════════════════ # + +# Maximum retry attempts including the original call +CLIENT_RETRY_POLICY_MAX_ATTEMPTS=5 + +# Initial backoff duration (e.g., '0.1s') +CLIENT_RETRY_POLICY_INITIAL_BACKOFF=0.1 + +# Maximum backoff duration (e.g., '10s') +CLIENT_RETRY_POLICY_MAX_BACKOFF=10 + +# Multiplier for exponential backoff +CLIENT_RETRY_POLICY_BACKOFF_MULTIPLIER=2.0 + + +# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +# ┃ TASK SETTINGS ┃ +# ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + +# Maximum number of concurrent tasks allowed. +TASK_MAX_CONCURRENT_TASKS=100 + +# Maximum time (in seconds) to wait for a task to complete before timing out. +TASK_WAIT_TIMEOUT=30 + +# Maximum time (in seconds) to wait for a stream to drain before forcing closure. +TASK_STREAM_DRAIN_TIMEOUT=60 + +# Maximum number of tasks that can be queued before new tasks are rejected. +TASK_MAX_QUEUED_TASKS=50 + +# Maximum time (in seconds) to wait for a task to be admitted before timing out. +TASK_ADMISSION_TIMEOUT=5 + +# Maximum time (in seconds) to wait for a queue slot to become available before timing out. 
+TASK_QUEUE_SLOT_TIMEOUT=600 + +# Interval for flushing signals +TASK_SIGNAL_FLUSH_INTERVAL=0.1 + +# Maximum batch size for signals +TASK_SIGNAL_MAX_BATCH_SIZE=50 + +# Number of retries for sending signals +TASK_SIGNAL_MAX_RETRIES=3 + +# Backoff in ms for sending signals +TASK_SIGNAL_SEND_BACKOFF_MS=100.0 + +# Interval for polling signals +TASK_SIGNAL_POLL_INTERVAL=1.0 + +# Initial interval for polling signals +TASK_SIGNAL_INITIAL_POLL_INTERVAL=0.1 + +# gRPC timeout for task manager operations +TASK_GRPC_TIMEOUT=30.0 + +# Timeout for polling operations +TASK_POLL_TIMEOUT=1.0 + +# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +# ┃ Certificate Settings ┃ +# ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ # Directory where server/client certificates are stored -# (default: "/certificates" / local: "./certs") CERTIFICATE_CERT_VOLUME=/certificates # Directory where registry CA certificates are stored -# (default: "/certificates" / local: "./certs") CERTIFICATE_REGISTRY_CERT_VOLUME=/certificates # (Optional) Directory for services_provider client.key, client.crt, and ca.crt -# If unset, defaults to CERTIFICATE_CERT_VOLUME CERTIFICATE_SERVICES_PROVIDER_CERT_VOLUME=/certificates -######################################## -# Langfuse Tracing Configuration -######################################## -# Langfuse credentials for OTEL tracing -# Note: Authentication is handled programmatically using Base64-encoded Basic auth -# (public_key:secret_key), not via OTEL_EXPORTER_OTLP_HEADERS environment variable. 
+# ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +# ┃ Other Configuration ┃ +# ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + +# ═════════════════════════════════ Archetype Config ═════════════════════════════════ # + +MODULE_ID_TOOLKIT_RAG="modules:1" +LANGFUSE_BASE_URL=https://langfuse.xxx.com + +# ═════════════════════════════════════ Api Keys ═════════════════════════════════════ # + +OPENAI_API_KEY=sk-xxxx LANGFUSE_SECRET_KEY=sk-lf-xxxx LANGFUSE_PUBLIC_KEY=pk-lf-xxxx -LANGFUSE_BASE_URL=https://langfuse.staging.digitalkin.com diff --git a/.gitignore b/.gitignore index c1b69d49..848514e0 100644 --- a/.gitignore +++ b/.gitignore @@ -190,5 +190,8 @@ requirements.txt certs/ .report.json docker-compose.override.yml +AGENTS.md +docker/init_pycharm_helpers.sh +.ai/ CLAUDE.md \ No newline at end of file diff --git a/src/digitalkin/core/task_manager/base_task_manager.py b/src/digitalkin/core/task_manager/base_task_manager.py index c2b4e1f0..9c292ece 100644 --- a/src/digitalkin/core/task_manager/base_task_manager.py +++ b/src/digitalkin/core/task_manager/base_task_manager.py @@ -2,17 +2,17 @@ import asyncio import contextlib -import os import types from abc import ABC, abstractmethod from collections.abc import Coroutine -from typing import Any +from typing import Any, ClassVar from typing_extensions import Self from digitalkin.core.task_manager.task_session import TaskSession from digitalkin.logger import logger from digitalkin.models.core.task_monitor import CancellationReason, SignalMessage, SignalType +from digitalkin.models.settings.client.task import TaskSettings from digitalkin.modules._base_module import BaseModule @@ -26,9 +26,9 @@ class BaseTaskManager(ABC): tasks: dict[str, asyncio.Task] tasks_sessions: dict[str, TaskSession] default_timeout: float - _max_concurrent_tasks: int _shutdown_event: asyncio.Event _tasks_lock: asyncio.Lock + _task_settings: ClassVar[TaskSettings] = TaskSettings() def __init__(self, default_timeout: 
float = 300.0) -> None: """Initialize task manager properties. @@ -41,40 +41,36 @@ def __init__(self, default_timeout: float = 300.0) -> None: self.default_timeout = default_timeout self._shutdown_event = asyncio.Event() self._tasks_lock = asyncio.Lock() - self._max_concurrent_tasks = int(os.environ.get("DIGITALKIN_MAX_CONCURRENT_TASKS", "100")) - self._task_slot = asyncio.Semaphore(self._max_concurrent_tasks) + self._task_slot = asyncio.Semaphore(self._task_settings.max_concurrent_tasks) self._active_slots = 0 - self._task_wait_timeout = float(os.environ.get("DIGITALKIN_TASK_WAIT_TIMEOUT", "30")) - self._stream_drain_timeout = float(os.environ.get("DIGITALKIN_STREAM_DRAIN_TIMEOUT", "60.0")) self._cleanup_tasks: set[asyncio.Task] = set() # Admission queue: allows tasks to wait for a slot instead of being rejected. # Total in-system capacity = max_concurrent + max_queued. - self._max_queued_tasks = int(os.environ.get("DIGITALKIN_MAX_QUEUED_TASKS", "0")) - self._admission_timeout = float(os.environ.get("DIGITALKIN_ADMISSION_TIMEOUT", "5.0")) - self._queue_slot_timeout = float(os.environ.get("DIGITALKIN_QUEUE_SLOT_TIMEOUT", "600.0")) - self._system_gate = asyncio.Semaphore(self._max_concurrent_tasks + self._max_queued_tasks) + self._system_gate = asyncio.Semaphore( + self._task_settings.max_concurrent_tasks + self._task_settings.max_queued_tasks + ) self._waiting_count = 0 logger.info( "%s initialized (max_concurrent_tasks=%d, max_queued=%d, default_timeout=%.1fs)", self.__class__.__name__, - self._max_concurrent_tasks, - self._max_queued_tasks, + self._task_settings.max_concurrent_tasks, + self._task_settings.max_queued_tasks, default_timeout, ) @property def max_concurrent_tasks(self) -> int: """Maximum number of concurrent tasks.""" - return self._max_concurrent_tasks + return self._task_settings.max_concurrent_tasks @max_concurrent_tasks.setter def max_concurrent_tasks(self, value: int) -> None: - self._max_concurrent_tasks = value + 
self._task_settings.max_concurrent_tasks = value self._task_slot = asyncio.Semaphore(value) self._active_slots = 0 - self._system_gate = asyncio.Semaphore(value + self._max_queued_tasks) + self._system_gate = asyncio.Semaphore(value + self._task_settings.max_queued_tasks) @property def task_count(self) -> int: @@ -127,7 +123,7 @@ async def _cleanup_task(self, task_id: str, mission_id: str) -> None: finally: self._active_slots -= 1 # Safe: no await between read/write (single-threaded asyncio) self._task_slot.release() - if self._max_queued_tasks > 0: + if self._task_settings.max_queued_tasks > 0: self._system_gate.release() logger.debug( "Task cleaned up (%d remaining)", @@ -177,7 +173,7 @@ async def _acquire_task_slot(self, coro: Coroutine[Any, Any, None]) -> None: Raises: RuntimeError: If the system is at full capacity. """ - if self._max_queued_tasks > 0: + if self._task_settings.max_queued_tasks > 0: await self._acquire_with_queue(coro) else: await self._acquire_direct(coro) @@ -189,19 +185,22 @@ async def _acquire_direct(self, coro: Coroutine[Any, Any, None]) -> None: RuntimeError: If no slot becomes available within the timeout. 
""" try: - await asyncio.wait_for(self._task_slot.acquire(), timeout=self._task_wait_timeout) + await asyncio.wait_for(self._task_slot.acquire(), timeout=self._task_settings.wait_timeout) except asyncio.TimeoutError: coro.close() - msg = f"Maximum concurrent tasks ({self.max_concurrent_tasks}) reached, waited {self._task_wait_timeout}s" + msg = ( + f"Maximum concurrent tasks ({self.max_concurrent_tasks}) reached, " + f"waited {self._task_settings.wait_timeout}s" + ) raise RuntimeError(msg) from None self._active_slots += 1 # Safe: no await between read/write (single-threaded asyncio) - available = self._max_concurrent_tasks - self._active_slots - if available < self._max_concurrent_tasks * 2 // 10: + available = self._task_settings.max_concurrent_tasks - self._active_slots + if available < self._task_settings.max_concurrent_tasks * 2 // 10: logger.warning( "Task slot capacity low: %d/%d available", available, - self._max_concurrent_tasks, + self._task_settings.max_concurrent_tasks, ) async def _acquire_with_queue(self, coro: Coroutine[Any, Any, None]) -> None: @@ -210,15 +209,16 @@ async def _acquire_with_queue(self, coro: Coroutine[Any, Any, None]) -> None: Raises: RuntimeError: If the system is at full capacity. 
""" - total_capacity = self._max_concurrent_tasks + self._max_queued_tasks + total_capacity = self._task_settings.max_concurrent_tasks + self._task_settings.max_queued_tasks # Phase 1: Admit into system (fast reject if completely overloaded) try: - await asyncio.wait_for(self._system_gate.acquire(), timeout=self._admission_timeout) + await asyncio.wait_for(self._system_gate.acquire(), timeout=self._task_settings.admission_timeout) except asyncio.TimeoutError: coro.close() msg = ( - f"System at full capacity ({total_capacity} tasks admitted), rejected after {self._admission_timeout}s" + f"System at full capacity ({total_capacity} tasks admitted), " + f"rejected after {self._task_settings.admission_timeout}s" ) raise RuntimeError(msg) from None @@ -229,14 +229,14 @@ async def _acquire_with_queue(self, coro: Coroutine[Any, Any, None]) -> None: "Task queued for execution (%d waiting, %d/%d slots busy)", self._waiting_count, self._active_slots, - self._max_concurrent_tasks, + self._task_settings.max_concurrent_tasks, ) try: - await asyncio.wait_for(self._task_slot.acquire(), timeout=self._queue_slot_timeout) + await asyncio.wait_for(self._task_slot.acquire(), timeout=self._task_settings.queue_slot_timeout) except asyncio.TimeoutError: self._system_gate.release() coro.close() - msg = f"Queued task waited {self._queue_slot_timeout}s for execution slot, giving up" + msg = f"Queued task waited {self._task_settings.queue_slot_timeout}s for execution slot, giving up" raise RuntimeError(msg) from None except BaseException: self._system_gate.release() @@ -246,12 +246,12 @@ async def _acquire_with_queue(self, coro: Coroutine[Any, Any, None]) -> None: self._waiting_count -= 1 self._active_slots += 1 # Safe: no await between read/write (single-threaded asyncio) - available = self._max_concurrent_tasks - self._active_slots - if available < self._max_concurrent_tasks * 2 // 10: + available = self._task_settings.max_concurrent_tasks - self._active_slots + if available < 
self._task_settings.max_concurrent_tasks * 2 // 10: logger.warning( "Task slot capacity low: %d/%d available", available, - self._max_concurrent_tasks, + self._task_settings.max_concurrent_tasks, ) def _create_session( @@ -312,7 +312,7 @@ async def _deferred_cleanup(self, task_id: str, mission_id: str) -> None: return try: - await asyncio.wait_for(session._stream_closed.wait(), timeout=self._stream_drain_timeout) # noqa: SLF001 + await asyncio.wait_for(session._stream_closed.wait(), timeout=self._task_settings.stream_drain_timeout) # noqa: SLF001 except asyncio.TimeoutError: logger.warning( "Stream drain timeout, proceeding with cleanup", @@ -601,7 +601,7 @@ async def shutdown(self, mission_id: str, timeout: float = 30.0) -> None: # Await any deferred cleanup tasks if self._cleanup_tasks: - await asyncio.gather(*self._cleanup_tasks, return_exceptions=True) + await asyncio.gather(*list(self._cleanup_tasks), return_exceptions=True) self._cleanup_tasks.clear() logger.info( diff --git a/src/digitalkin/grpc_servers/_base_server.py b/src/digitalkin/grpc_servers/_base_server.py index 78e01709..211ee6b1 100644 --- a/src/digitalkin/grpc_servers/_base_server.py +++ b/src/digitalkin/grpc_servers/_base_server.py @@ -416,12 +416,11 @@ async def start_async(self) -> None: raise ServerStateError(msg) from e # Start asyncio-inspector if enabled - if os.environ.get("DIGITALKIN_ASYNCIO_INSPECTOR", "").lower() == "true": + if self._server_settings.asyncio_inspector_state: try: from digitalkin.core.profiling.asyncio_monitor import AsyncioMonitor - port = int(os.environ.get("DIGITALKIN_ASYNCIO_INSPECTOR_PORT", "8765")) - self._asyncio_monitor = AsyncioMonitor(port=port) + self._asyncio_monitor = AsyncioMonitor(port=self._server_settings.asyncio_inspector_port) await self._asyncio_monitor.start() except Exception: logger.exception("Failed to start asyncio-inspector") diff --git a/src/digitalkin/grpc_servers/module_server.py b/src/digitalkin/grpc_servers/module_server.py index 
f474f584..b84f9039 100644 --- a/src/digitalkin/grpc_servers/module_server.py +++ b/src/digitalkin/grpc_servers/module_server.py @@ -66,7 +66,7 @@ def _register_servicers(self) -> None: raise RuntimeError(msg) logger.debug("Registering module servicer for %s", self.module_class.__name__) - self.module_servicer = ModuleServicer(self.module_class) + self.module_servicer = ModuleServicer(self.module_class, self._server_settings) self.register_servicer( self.module_servicer, module_service_pb2_grpc.add_ModuleServiceServicer_to_server, diff --git a/src/digitalkin/grpc_servers/module_servicer.py b/src/digitalkin/grpc_servers/module_servicer.py index 6480c610..f7f20712 100644 --- a/src/digitalkin/grpc_servers/module_servicer.py +++ b/src/digitalkin/grpc_servers/module_servicer.py @@ -1,7 +1,6 @@ """Module servicer implementation for DigitalKin.""" import asyncio -import os from argparse import ArgumentParser, Namespace from collections.abc import AsyncGenerator from typing import Any, cast @@ -11,7 +10,6 @@ information_pb2, lifecycle_pb2, module_service_pb2_grpc, - monitoring_pb2, ) from google.protobuf import json_format, struct_pb2 from pydantic import ValidationError @@ -20,7 +18,8 @@ from digitalkin.grpc_servers.utils.exceptions import ServerError, ServicerError from digitalkin.logger import logger from digitalkin.models.core.job_manager_models import JobManagerMode -from digitalkin.models.module.module import ModuleCodeModel, ModuleStatus +from digitalkin.models.module.module import ModuleCodeModel +from digitalkin.models.settings.server.server import ServerSettings from digitalkin.modules._base_module import BaseModule from digitalkin.services.registry import GrpcRegistry, RegistryStrategy from digitalkin.services.services_models import ServicesMode @@ -68,14 +67,16 @@ def _add_parser_args(self, parser: ArgumentParser) -> None: help="Define Module job manager configurations for load balancing", ) - def __init__(self, module_class: type[BaseModule]) -> None: + def 
__init__(self, module_class: type[BaseModule], server_settings: ServerSettings) -> None: """Initialize the module servicer. Args: module_class: The module type to serve. + server_settings: Settings for all server params """ super().__init__() module_class.discover() + self._server_settings: ServerSettings = server_settings self.module_class = module_class job_manager_class = self.args.job_manager_mode.get_manager_class() self.job_manager = job_manager_class(module_class, self.args.services_mode) @@ -87,9 +88,7 @@ def __init__(self, module_class: type[BaseModule]) -> None: ) self.setup = GrpcSetup() if self.args.services_mode == ServicesMode.REMOTE else DefaultSetup() self._setup_cache: dict[str, SetupVersionData] = {} - self._setup_cache_max = int(os.environ.get("DIGITALKIN_SETUP_CACHE_MAX", "100")) self._setup_inflight: dict[str, asyncio.Future[SetupVersionData]] = {} - self._completion_timeout = float(os.environ.get("DIGITALKIN_COMPLETION_TIMEOUT", "300.0")) async def shutdown(self) -> None: """Release servicer-level resources (GrpcSetup channel, registry cache).""" @@ -130,7 +129,7 @@ def _get_registry(self) -> RegistryStrategy | None: def _cache_setup(self, setup_id: str, version_data: SetupVersionData) -> None: """Cache setup version data, evicting oldest entry if at capacity.""" - if len(self._setup_cache) >= self._setup_cache_max: + if len(self._setup_cache) >= self._server_settings.setup_cache_max: oldest_key = next(iter(self._setup_cache)) del self._setup_cache[oldest_key] self._setup_cache[setup_id] = version_data @@ -596,7 +595,7 @@ async def StartModule( # noqa: C901, PLR0911, PLR0912, PLR0915 break finally: try: - completion_timeout = self._completion_timeout + completion_timeout = self._server_settings.completion_timeout await asyncio.wait_for( self.job_manager.wait_for_completion(job_id), timeout=completion_timeout, diff --git a/src/digitalkin/grpc_servers/utils/grpc_client_wrapper.py b/src/digitalkin/grpc_servers/utils/grpc_client_wrapper.py index 
cfc1e0e6..308134f9 100644 --- a/src/digitalkin/grpc_servers/utils/grpc_client_wrapper.py +++ b/src/digitalkin/grpc_servers/utils/grpc_client_wrapper.py @@ -2,7 +2,6 @@ import asyncio import logging -import os from pathlib import Path from typing import Any, ClassVar @@ -12,6 +11,7 @@ from digitalkin.grpc_servers.utils.exceptions import ServerError from digitalkin.logger import logger from digitalkin.models.grpc_servers.models import ClientConfig +from digitalkin.models.settings.client.client import ClientSettings from digitalkin.models.settings.utils.channel import SecurityMode @@ -32,34 +32,36 @@ class GrpcClientWrapper: _channel_cache_key: str | None = None _channel_cache: ClassVar[dict[str, grpc.aio.Channel]] = {} _ref_counts: ClassVar[dict[str, int]] = {} + _client_settings: ClassVar[ClientSettings] = ClientSettings() _RETRYABLE_CODES: ClassVar[set[grpc.StatusCode]] = { grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.INTERNAL, grpc.StatusCode.DEADLINE_EXCEEDED, } - _QUERY_MAX_RETRIES: ClassVar[int] = int(os.environ.get("DIGITALKIN_GRPC_QUERY_MAX_RETRIES", "2")) - _QUERY_BACKOFF_BASE_MS: ClassVar[float] = float(os.environ.get("DIGITALKIN_GRPC_QUERY_BACKOFF_BASE_MS", "50")) - _QUERY_DEFAULT_TIMEOUT: ClassVar[float] = float(os.environ.get("DIGITALKIN_GRPC_QUERY_TIMEOUT", "30")) @staticmethod - def _build_channel_credentials(config: ClientConfig) -> grpc.ChannelCredentials | None: + def _build_channel_credentials(settings: ClientSettings) -> grpc.ChannelCredentials | None: """Build SSL channel credentials from config if secure mode. Args: - config: Client configuration with security and credential settings. + settings: Client configuration with security and credential settings. Returns: Channel credentials for secure mode, None for insecure. 
""" - if config.security != SecurityMode.SECURE or config.credentials is None: + if ( + settings.channel.security != SecurityMode.SECURE + or settings.channel.credentials is None + or settings.channel.credentials.root_cert_path is None + ): return None - root_certificates = Path(config.credentials.root_cert_path).read_bytes() + root_certificates = settings.channel.credentials.root_cert_path.read_bytes() private_key = None certificate_chain = None - if config.credentials.client_cert_path is not None and config.credentials.client_key_path is not None: - private_key = Path(config.credentials.client_key_path).read_bytes() - certificate_chain = Path(config.credentials.client_cert_path).read_bytes() + if settings.channel.credentials.cert_path is not None and settings.channel.credentials.key_path is not None: + private_key = Path(settings.channel.credentials.key_path).read_bytes() + certificate_chain = Path(settings.channel.credentials.cert_path).read_bytes() return grpc.ssl_channel_credentials( root_certificates=root_certificates, certificate_chain=certificate_chain, @@ -78,7 +80,10 @@ def _init_channel(self, config: ClientConfig) -> grpc.aio.Channel: Returns: An async gRPC channel (may be shared with other instances). 
""" - cache_key = f"{config.address}:{config.security.value}:{config.compression.value}" + cache_key = ( + f"{config.address}:{self._client_settings.channel.security}:" + f"{self._client_settings.channel.compression.value}" + ) if cache_key in GrpcClientWrapper._channel_cache: GrpcClientWrapper._ref_counts[cache_key] += 1 channel = GrpcClientWrapper._channel_cache[cache_key] @@ -86,15 +91,15 @@ def _init_channel(self, config: ClientConfig) -> grpc.aio.Channel: self._channel_cache_key = cache_key return channel - credentials = self._build_channel_credentials(config) - grpc_compression = config.compression.to_grpc() + credentials = self._build_channel_credentials(self._client_settings) + grpc_compression = self._client_settings.channel.compression.to_grpc() if credentials is not None: channel = grpc.aio.secure_channel( - config.address, credentials, options=config.grpc_options, compression=grpc_compression + config.address, credentials, options=self._client_settings.grpc.options, compression=grpc_compression ) else: channel = grpc.aio.insecure_channel( - config.address, options=config.grpc_options, compression=grpc_compression + config.address, options=self._client_settings.grpc.options, compression=grpc_compression ) GrpcClientWrapper._channel_cache[cache_key] = channel GrpcClientWrapper._ref_counts[cache_key] = 1 @@ -195,12 +200,14 @@ async def exec_grpc_query( Raises: ServerError: gRPC error with status code and details for caller to handle. 
""" - effective_timeout = timeout if timeout is not None else self._QUERY_DEFAULT_TIMEOUT - max_retries = self._QUERY_MAX_RETRIES - backoff_delays = tuple(self._QUERY_BACKOFF_BASE_MS / 1000 * (2**i) for i in range(max_retries)) + effective_timeout = timeout if timeout is not None else self._client_settings.query_timeout + backoff_delays = tuple( + self._client_settings.query_backoff_base_ms / 1000 * (2**i) + for i in range(self._client_settings.query_max_retries) + ) last_error: grpc.RpcError | None = None - for attempt in range(max_retries + 1): + for attempt in range(self._client_settings.query_max_retries + 1): if attempt > 0: await asyncio.sleep(backoff_delays[attempt - 1]) @@ -209,7 +216,7 @@ async def exec_grpc_query( response = await getattr(self.stub, query_endpoint)(request, timeout=effective_timeout) except grpc.RpcError as e: last_error = e - if e.code() not in self._RETRYABLE_CODES or attempt == max_retries: + if e.code() not in self._RETRYABLE_CODES or attempt == self._client_settings.query_max_retries: break logger.warning( "gRPC transient error on %s.%s [%s] (attempt %d/%d), retrying in %.0fms", @@ -217,7 +224,7 @@ async def exec_grpc_query( query_endpoint, e.code().name, attempt + 1, - max_retries + 1, + self._client_settings.query_max_retries + 1, backoff_delays[attempt] * 1000, ) else: @@ -229,7 +236,7 @@ async def exec_grpc_query( status_code = last_error.code().name details = last_error.details() retried = last_error.code() in self._RETRYABLE_CODES - suffix = f" (after {max_retries + 1} attempts)" if retried else "" + suffix = f" (after {self._client_settings.query_max_retries + 1} attempts)" if retried else "" log_level = logging.DEBUG if last_error.code() == grpc.StatusCode.NOT_FOUND else logging.ERROR logger.log( diff --git a/src/digitalkin/models/grpc_servers/models.py b/src/digitalkin/models/grpc_servers/models.py index 2f33ee04..8d39c544 100644 --- a/src/digitalkin/models/grpc_servers/models.py +++ 
b/src/digitalkin/models/grpc_servers/models.py @@ -1,137 +1,13 @@ """Data models for gRPC server configurations.""" -import os -from enum import Enum -from pathlib import Path from typing import Any -import grpc -from pydantic import BaseModel, Field, ValidationInfo, field_validator +from pydantic import BaseModel, Field, field_validator -from digitalkin.grpc_servers.utils.exceptions import ConfigurationError, SecurityError +from digitalkin.grpc_servers.utils.exceptions import ConfigurationError from digitalkin.models.settings.utils.channel import ControlFlow, SecurityMode -class GrpcCompression(str, Enum): - """gRPC compression algorithm. - - Attributes: - NONE: No compression - GZIP: Gzip compression - DEFLATE: Deflate compression - """ - - NONE = "none" - GZIP = "gzip" - DEFLATE = "deflate" - - def to_grpc(self) -> grpc.Compression: - """Convert to grpc.Compression enum. - - Returns: - The corresponding grpc.Compression value. - """ - match self: - case GrpcCompression.NONE: - return grpc.Compression.NoCompression - case GrpcCompression.GZIP: - return grpc.Compression.Gzip - case GrpcCompression.DEFLATE: - return grpc.Compression.Deflate - - -class RetryPolicy(BaseModel): - """gRPC retry policy configuration for resilient connections. 
- - Attributes: - max_attempts: Maximum retry attempts including the original call - initial_backoff: Initial backoff duration (e.g., "0.1s") - max_backoff: Maximum backoff duration (e.g., "10s") - backoff_multiplier: Multiplier for exponential backoff - retryable_status_codes: gRPC status codes that trigger retry - """ - - max_attempts: int = Field( - default_factory=lambda: int(os.environ.get("DIGITALKIN_GRPC_RETRY_MAX_ATTEMPTS", "5")), - ge=1, - le=10, - description="Maximum retry attempts including the original call", - ) - initial_backoff: str = Field( - default_factory=lambda: os.environ.get("DIGITALKIN_GRPC_RETRY_INITIAL_BACKOFF", "0.1s"), - description="Initial backoff duration (e.g., '0.1s')", - ) - max_backoff: str = Field( - default_factory=lambda: os.environ.get("DIGITALKIN_GRPC_RETRY_MAX_BACKOFF", "10s"), - description="Maximum backoff duration (e.g., '10s')", - ) - backoff_multiplier: float = Field( - default_factory=lambda: float(os.environ.get("DIGITALKIN_GRPC_RETRY_BACKOFF_MULTIPLIER", "2.0")), - ge=1.0, - description="Multiplier for exponential backoff", - ) - retryable_status_codes: list[str] = Field( - default_factory=lambda: ["UNAVAILABLE", "RESOURCE_EXHAUSTED", "DEADLINE_EXCEEDED"], - description="gRPC status codes that trigger retry", - ) - - model_config = {"extra": "forbid", "frozen": True} - - def to_service_config_json(self) -> str: - """Serialize to gRPC service config JSON string. - - Returns: - JSON string for grpc.service_config channel option. - """ - codes = "[" + ",".join(f'"{c}"' for c in self.retryable_status_codes) + "]" - return ( - f'{{"methodConfig":[{{"name":[{{}}],"retryPolicy":{{"maxAttempts":{self.max_attempts},' - f'"initialBackoff":"{self.initial_backoff}","maxBackoff":"{self.max_backoff}",' - f'"backoffMultiplier":{self.backoff_multiplier},"retryableStatusCodes":{codes}}}}}]}}' - ) - - -class ClientCredentials(BaseModel): - """Model for client credentials in secure mode. 
- - Attributes: - root_cert_path: path to the root certificate - client_key_path: Path to the client private key - client_cert_path: Path to the client certificate - """ - - root_cert_path: Path = Field(..., description="Path to the root certificate") - client_key_path: Path | None = Field(None, description="Path to the client private key | mTLS enable") - client_cert_path: Path | None = Field(None, description="Path to the client certificate | mTLS enable") - - # Enable __slots__ for memory efficiency - model_config = { - "extra": "forbid", - "arbitrary_types_allowed": True, - "validate_assignment": True, - "frozen": True, - } - - @field_validator("client_key_path", "client_cert_path", "root_cert_path") - @classmethod - def check_path_exists(cls, v: Path | None) -> Path | None: - """Validate that the file path exists. - - Args: - v: Path to validate - - Returns: - The validated path - - Raises: - SecurityError: If the path does not exist - """ - if v is not None and not v.exists(): - msg = f"File not found: {v}" - raise SecurityError(msg) - return v - - class ChannelConfig(BaseModel): """Base configuration for gRPC channels. 
@@ -195,69 +71,8 @@ class ClientConfig(ChannelConfig): port: Port to listen on mode: Client operation mode (sync/async) security: Security mode (secure/insecure) - credentials: Client credentials for secure mode - channel_options: Additional channel options - retry_policy: Retry policy for failed RPCs - compression: gRPC compression algorithm for channel-level compression """ - credentials: ClientCredentials | None = Field(None, description="Client credentials for secure mode") - retry_policy: RetryPolicy = Field(default_factory=RetryPolicy, description="Retry policy for failed RPCs") - compression: GrpcCompression = Field(GrpcCompression.GZIP, description="gRPC compression algorithm") - channel_options: list[tuple[str, Any]] = Field( - default_factory=lambda: [ - ("grpc.max_receive_message_length", 100 * 1024 * 1024), - ("grpc.max_send_message_length", 100 * 1024 * 1024), - # === DNS Re-resolution (Critical for Container Environments) === - ( - "grpc.dns_min_time_between_resolutions_ms", - int(os.environ.get("DIGITALKIN_GRPC_DNS_RESOLUTION_MS", "500")), - ), - ("grpc.initial_reconnect_backoff_ms", int(os.environ.get("DIGITALKIN_GRPC_INITIAL_RECONNECT_MS", "1000"))), - ("grpc.max_reconnect_backoff_ms", int(os.environ.get("DIGITALKIN_GRPC_MAX_RECONNECT_MS", "10000"))), - ("grpc.min_reconnect_backoff_ms", int(os.environ.get("DIGITALKIN_GRPC_MIN_RECONNECT_MS", "500"))), - # === Keepalive Settings (Detect Dead Connections) === - ("grpc.keepalive_time_ms", int(os.environ.get("DIGITALKIN_GRPC_KEEPALIVE_TIME_MS", "60000"))), - ("grpc.keepalive_timeout_ms", int(os.environ.get("DIGITALKIN_GRPC_KEEPALIVE_TIMEOUT_MS", "20000"))), - ("grpc.keepalive_permit_without_calls", True), - ( - "grpc.http2.min_time_between_pings_ms", - int(os.environ.get("DIGITALKIN_GRPC_MIN_PING_INTERVAL_MS", "30000")), - ), - # === Retry Configuration === - ("grpc.enable_retries", 1), - ], - description="Resilient gRPC channel options with DNS re-resolution, keepalive, and retries", - ) - - 
@field_validator("credentials") - @classmethod - def validate_credentials(cls, v: ClientCredentials | None, info: ValidationInfo) -> ClientCredentials | None: - """Validate that credentials are provided when in secure mode. - - Args: - v: The credentials value - info: ValidationInfo containing other field values - - Returns: - The validated credentials - - Raises: - ConfigurationError: If credentials are missing in secure mode - """ - # Access security mode from the info.data dictionary - security = info.data.get("security") - - if security == SecurityMode.SECURE and v is None: - msg = "Credentials must be provided when using secure mode" - raise ConfigurationError(msg) - return v - - @property - def grpc_options(self) -> list[tuple[str, Any]]: - """Get channel options with retry policy service config. - - Returns: - Full list of gRPC channel options. - """ - return [*self.channel_options, ("grpc.service_config", self.retry_policy.to_service_config_json())] + def __init__(self, /, **data: Any) -> None: + """Client config constructor.""" + super().__init__(**data) diff --git a/src/digitalkin/models/settings/client/__init__.py b/src/digitalkin/models/settings/client/__init__.py new file mode 100644 index 00000000..063b5dce --- /dev/null +++ b/src/digitalkin/models/settings/client/__init__.py @@ -0,0 +1 @@ +"""client settings packages.""" diff --git a/src/digitalkin/models/settings/client/channel.py b/src/digitalkin/models/settings/client/channel.py new file mode 100644 index 00000000..ac64d972 --- /dev/null +++ b/src/digitalkin/models/settings/client/channel.py @@ -0,0 +1,27 @@ +"""Client Channel settings.""" + +from typing import Any + +from pydantic_settings import SettingsConfigDict + +from digitalkin.models.settings.utils.channel import BaseChannelSettings + + +class ClientChannelSettings(BaseChannelSettings): + """Client channel settings.""" + + model_config = SettingsConfigDict( + env_prefix="CLIENT_CHANNEL_", + env_nested_delimiter="__", + extra="forbid", + 
arbitrary_types_allowed=True, + validate_assignment=True, + ) + + # ── Options ───────────────────────────────────────────────────────────────────── # + + # ── Functions ─────────────────────────────────────────────────────────────────── # + + def __init__(self, **values: Any) -> None: + """Default constructor.""" + super().__init__(**values) diff --git a/src/digitalkin/models/settings/client/client.py b/src/digitalkin/models/settings/client/client.py new file mode 100644 index 00000000..3b527f52 --- /dev/null +++ b/src/digitalkin/models/settings/client/client.py @@ -0,0 +1,34 @@ +"""Top-level gRPC client settings.""" + +# use channel +from typing import Any + +from pydantic import Field, NonNegativeFloat, NonNegativeInt +from pydantic_settings import BaseSettings, SettingsConfigDict + +from digitalkin.models.settings.client.channel import ClientChannelSettings +from digitalkin.models.settings.client.grpc_client import ClientGrpcSettings +from digitalkin.models.settings.client.task import TaskSettings + + +class ClientSettings(BaseSettings): + """Top-level gRPC client settings.""" + + model_config = SettingsConfigDict(env_prefix="CLIENT_", case_sensitive=False) + + # ── Options ───────────────────────────────────────────────────────────────────── # + + channel: ClientChannelSettings = Field(default_factory=ClientChannelSettings) + grpc: ClientGrpcSettings = Field(default_factory=ClientGrpcSettings) + task: TaskSettings = Field(default_factory=TaskSettings) + query_max_retries: NonNegativeInt = Field(default=2, description="Maximum number of retries for queries") + query_backoff_base_ms: NonNegativeFloat = Field( + default=50, description="Base backoff time in milliseconds for query retries" + ) + query_timeout: NonNegativeFloat = Field(default=30, description="Timeout in seconds for queries") + + # ── Functions ─────────────────────────────────────────────────────────────────── # + + def __init__(self, **values: Any) -> None: + """Initialize client 
settings.""" + super().__init__(**values) diff --git a/src/digitalkin/models/settings/client/grpc_client.py b/src/digitalkin/models/settings/client/grpc_client.py new file mode 100644 index 00000000..d5934366 --- /dev/null +++ b/src/digitalkin/models/settings/client/grpc_client.py @@ -0,0 +1,58 @@ +"""gRPC client settings for the SDK.""" + +from typing import Any + +from pydantic import Field, NonNegativeInt +from pydantic_settings import SettingsConfigDict + +from digitalkin.models.settings.client.retry_policy import ClientRetryPolicySettings +from digitalkin.models.settings.utils.grpc_base import BaseGrpcSettings + + +class ClientGrpcSettings(BaseGrpcSettings): + """gRPC tuning settings for SDK clients.""" + + model_config = SettingsConfigDict( + env_prefix="CLIENT_GRPC_", + env_nested_delimiter="__", + extra="forbid", + arbitrary_types_allowed=True, + validate_assignment=True, + ) + + # ── Options ───────────────────────────────────────────────────────────────────── # + + dns_resolution_time: NonNegativeInt = Field( + default=500, description="Minimum time between DNS resolutions in milliseconds" + ) + initial_reconnect_time: NonNegativeInt = Field( + default=1000, description="Initial reconnect backoff time in milliseconds" + ) + max_reconnect_time: NonNegativeInt = Field( + default=10000, description="Maximum reconnect backoff time in milliseconds" + ) + min_reconnect_time: NonNegativeInt = Field( + default=500, description="Minimum reconnect backoff time in milliseconds" + ) + min_ping_interval_time: NonNegativeInt = Field( + default=30000, description="Minimum time between HTTP/2 pings in milliseconds" + ) + enable_retries: bool = Field(default=True, description="Enable gRPC retries") + retry_policy: ClientRetryPolicySettings = Field(default_factory=ClientRetryPolicySettings) + + @property + def _specific_options(self) -> list[tuple[str, Any]]: + """Return client specific gRPC options. + + Returns: + List of tuples containing client specific gRPC options. 
+ """ + return [ + ("grpc.dns_min_time_between_resolutions_ms", self.dns_resolution_time), + ("grpc.initial_reconnect_backoff_ms", self.initial_reconnect_time), + ("grpc.max_reconnect_backoff_ms", self.max_reconnect_time), + ("grpc.min_reconnect_backoff_ms", self.min_reconnect_time), + ("grpc.http2.min_time_between_pings_ms", self.min_ping_interval_time), + ("grpc.enable_retries", int(self.enable_retries)), + self.retry_policy.to_grpc_option(), + ] diff --git a/src/digitalkin/models/settings/client/retry_policy.py b/src/digitalkin/models/settings/client/retry_policy.py new file mode 100644 index 00000000..228bd1aa --- /dev/null +++ b/src/digitalkin/models/settings/client/retry_policy.py @@ -0,0 +1,92 @@ +"""Retry policy settings for gRPC clients.""" + +import json +from typing import Any + +from grpc import StatusCode +from pydantic import Field, NonNegativeFloat, NonNegativeInt +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class ClientRetryPolicySettings(BaseSettings): + """Retry policy settings used to build gRPC service config.""" + + model_config = SettingsConfigDict( + env_prefix="CLIENT_RETRY_POLICY_", + case_sensitive=False, + frozen=True, + extra="forbid", + ) + + # ── Options ───────────────────────────────────────────────────────────────────── # + + max_attempts: NonNegativeInt = Field( + default=5, + ge=1, + le=10, + description="Maximum retry attempts including the original call", + ) + initial_backoff: NonNegativeFloat = Field( + default=0.1, + description="Initial backoff duration (e.g., '0.1s')", + ) + max_backoff: NonNegativeInt = Field( + default=10, + description="Maximum backoff duration (e.g., '10s')", + ) + backoff_multiplier: NonNegativeFloat = Field( + default=2.0, + ge=1.0, + description="Multiplier for exponential backoff", + ) + + # ── Functions ─────────────────────────────────────────────────────────────────── # + + def to_grpc_service_config(self) -> dict[str, Any]: + """Build gRPC service config dictionary for 
retries. + + Returns: + Service config dictionary compatible with ``grpc.service_config``. + """ + return { + "methodConfig": [ + { + "name": [{}], + "retryPolicy": { + "maxAttempts": self.max_attempts, + "initialBackoff": f"{self.initial_backoff:g}s", + "maxBackoff": f"{self.max_backoff}s", + "backoffMultiplier": self.backoff_multiplier, + "retryableStatusCodes": [ + str(StatusCode.UNAVAILABLE.name), + str(StatusCode.RESOURCE_EXHAUSTED.name), + str(StatusCode.DEADLINE_EXCEEDED.name), + ], + }, + } + ] + } + + def to_grpc_service_config_json(self) -> str: + """Serialize retry policy as ``grpc.service_config`` JSON string. + + Returns: + Compact JSON string for gRPC channel option value. + """ + return json.dumps(self.to_grpc_service_config(), separators=(",", ":")) + + def to_grpc_option(self) -> tuple[str, str]: + """Build the gRPC channel option tuple for retry policy. + + Returns: + Tuple usable directly in gRPC channel options. + """ + return "grpc.service_config", self.to_grpc_service_config_json() + + def to_service_config_json(self) -> str: + """Backward-compatible alias for legacy retry policy naming. + + Returns: + JSON string for gRPC service config. 
+ """ + return self.to_grpc_service_config_json() diff --git a/src/digitalkin/models/settings/client/task.py b/src/digitalkin/models/settings/client/task.py new file mode 100644 index 00000000..1866b3c9 --- /dev/null +++ b/src/digitalkin/models/settings/client/task.py @@ -0,0 +1,47 @@ +"""Settings for task manager.""" + +from typing import Any + +from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class TaskSettings(BaseSettings): + """Task settings.""" + + model_config = SettingsConfigDict(env_prefix="TASK_", case_sensitive=False) + + # ── Options ───────────────────────────────────────────────────────────────────── # + + max_concurrent_tasks: PositiveInt = Field(default=100, description="Maximum number of concurrent tasks allowed.") + wait_timeout: PositiveFloat = Field( + default=30, description="Maximum time (in seconds) to wait for a task to complete before timing out." + ) + stream_drain_timeout: PositiveInt = Field( + default=60, description="Maximum time (in seconds) to wait for a stream to drain before forcing closure." + ) + max_queued_tasks: NonNegativeInt = Field( + default=50, description="Maximum number of tasks that can be queued before new tasks are rejected." + ) + admission_timeout: PositiveInt = Field( + default=5, description="Maximum time (in seconds) to wait for a task to be admitted before timing out." 
+ ) + queue_slot_timeout: PositiveInt = Field( + default=600, + description="Maximum time (in seconds) to wait for a queue slot to become available before timing out.", + ) + + signal_flush_interval: PositiveFloat = Field(default=0.1, description="Interval for flushing signals") + signal_max_batch_size: PositiveInt = Field(default=50, description="Maximum batch size for signals") + signal_max_retries: NonNegativeInt = Field(default=3, description="Number of retries for sending signals") + signal_send_backoff_ms: PositiveFloat = Field(default=100.0, description="Backoff in ms for sending signals") + signal_poll_interval: PositiveFloat = Field(default=1.0, description="Interval for polling signals") + signal_initial_poll_interval: PositiveFloat = Field(default=0.1, description="Initial interval for polling signals") + grpc_timeout: PositiveFloat = Field(default=30.0, description="gRPC timeout for task manager operations") + poll_timeout: PositiveFloat = Field(default=1.0, description="Timeout for polling operations") + + # ── Functions ─────────────────────────────────────────────────────────────────── # + + def __init__(self, **values: Any) -> None: + """Default constructor.""" + super().__init__(**values) diff --git a/src/digitalkin/models/settings/server/channel.py b/src/digitalkin/models/settings/server/channel.py index d002fc35..be89b996 100644 --- a/src/digitalkin/models/settings/server/channel.py +++ b/src/digitalkin/models/settings/server/channel.py @@ -25,11 +25,13 @@ class ServerChannelSettings(BaseChannelSettings): validate_assignment=True, ) + # ── Options ───────────────────────────────────────────────────────────────────── # + advertise_host: str | None = Field( - None, description="Public hostname/IP sent to registry for discovery. Falls back to host if not set." + default=None, description="Public hostname/IP sent to registry for discovery. Falls back to host if not set." 
) - database_url: str | None = Field(None, description="Database URL for registry data storage") + # ── Functions ─────────────────────────────────────────────────────────────────── # def __init__(self, **values: Any) -> None: """Initialize ServerChannelSettings with default credentials if not provided.""" diff --git a/src/digitalkin/models/settings/server/grpc.py b/src/digitalkin/models/settings/server/grpc_server.py similarity index 98% rename from src/digitalkin/models/settings/server/grpc.py rename to src/digitalkin/models/settings/server/grpc_server.py index c42057c0..001559cb 100644 --- a/src/digitalkin/models/settings/server/grpc.py +++ b/src/digitalkin/models/settings/server/grpc_server.py @@ -5,7 +5,7 @@ from pydantic import Field, NonNegativeFloat from pydantic_settings import BaseSettings, SettingsConfigDict -from digitalkin.models.grpc_servers.models import GrpcCompression +from digitalkin.models.settings.utils.models import GrpcCompression class GrpcServerSettings(BaseSettings): diff --git a/src/digitalkin/models/settings/server/server.py b/src/digitalkin/models/settings/server/server.py index 23a9fb82..aaa86f12 100644 --- a/src/digitalkin/models/settings/server/server.py +++ b/src/digitalkin/models/settings/server/server.py @@ -15,7 +15,7 @@ class ServerSettings(BaseSettings): Attributes: channel (ServerChannelSettings): Settings for the server channel. - grpc (GrpcServerSettings): Settings for the gRPC server. + grpc (GrpcServerSettings): Settings for the gRPC server. health_check (bool): Whether to enable the health check service. reflection (bool): Whether to enable reflection for the server. 
max_concurrent_rpcs (NonNegativeInt): Maximum number of RPCs handled in parallel by the server. @@ -26,21 +26,29 @@ class ServerSettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="SERVER_", case_sensitive=False) - channel: ServerChannelSettings = Field(default_factory=ServerChannelSettings) + # ── Options ───────────────────────────────────────────────────────────────────── # + channel: ServerChannelSettings = Field(default_factory=ServerChannelSettings) grpc: GrpcServerSettings = Field(default_factory=GrpcServerSettings) - health_check: bool = Field(default=True, description="Enable health check service") reflection: bool = Field(default=True, description="Enable reflection for the server") max_concurrent_rpcs: NonNegativeInt = Field( - (os.cpu_count() or 1) * 200, + default=(os.cpu_count() or 1) * 200, description="Maximum number of RPCs handled in parallel by the server.", ) - max_workers: NonNegativeInt = Field(10, description="Maximum number of workers for sync mode") + max_workers: NonNegativeInt = Field(default=10, description="Maximum number of workers for sync mode") thread_pool_workers: NonNegativeInt = Field( - min(4, os.cpu_count() or 1), + default=min(4, os.cpu_count() or 1), description="Number of workers in the server thread pool.", ) + asyncio_inspector_state: bool = Field(default=False, description="Enable asyncio inspector") + asyncio_inspector_port: int = Field(default=8765, description="Port for asyncio inspector") + setup_cache_max: NonNegativeInt = Field( + default=100, description="Maximum number of cached setups in ModuleServicer" + ) + completion_timeout: float = Field(default=300.0, description="Timeout for completion in ModuleServicer") + + # ── Functions ─────────────────────────────────────────────────────────────────── # def __init__(self, **values: Any) -> None: """Initialize the ServerSettings instance.""" diff --git a/src/digitalkin/models/settings/utils/channel.py b/src/digitalkin/models/settings/utils/channel.py 
index fb7ae873..f0da799b 100644 --- a/src/digitalkin/models/settings/utils/channel.py +++ b/src/digitalkin/models/settings/utils/channel.py @@ -1,66 +1,12 @@ """This file define channelBase for grpc config.""" -from enum import Enum -from pathlib import Path from typing import Any -from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, field_validator, model_validator +from pydantic import Field, NonNegativeInt, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict -from digitalkin.grpc_servers.utils.exceptions import ConfigurationError, SecurityError - - -class ControlFlow(str, Enum): - """Enum for server operation mode.""" - - SYNC = "sync" - ASYNC = "async" - - -class SecurityMode(str, Enum): - """Enum for server security mode.""" - - SECURE = "secure" - INSECURE = "insecure" - - -class Credentials(BaseModel): - """Model for server credentials in secure mode. - - Attributes: - key_path: Path to the server private key - cert_path: Path to the server certificate - root_cert_path: Optional path to the root certificate - """ - - model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True, validate_assignment=True, frozen=True) - - key_path: Path | None = Field(default=None, description="Path to the private key") - cert_path: Path | None = Field(default=None, description="Path to the certificate") - root_cert_path: Path | None = Field(default=None, description="Path to the root certificate") - - def __init__(self, /, **data: Any) -> None: - """Initialize the Credentials model.""" - super().__init__(**data) - - @field_validator("key_path", "cert_path", "root_cert_path") - @classmethod - def check_path_exists(cls, v: Path | None) -> Path | None: - """Validate that the file path exists. 
- - Args: - v: Path to validate - - Returns: - The validated path - - Raises: - SecurityError: If the path does not exist - """ - if v is not None and not v.exists(): - msg = f"File not found: {v}" - raise SecurityError(msg) - return v +from digitalkin.grpc_servers.utils.exceptions import ConfigurationError +from digitalkin.models.settings.utils.models import ControlFlow, Credentials, GrpcCompression, SecurityMode class BaseChannelSettings(BaseSettings): @@ -68,12 +14,17 @@ class BaseChannelSettings(BaseSettings): model_config = SettingsConfigDict(extra="forbid", arbitrary_types_allowed=True, validate_assignment=True) + # ── Options ───────────────────────────────────────────────────────────────────── # + host: str = Field("[::]", description="Host address to bind the client to") port: NonNegativeInt = Field(50055, description="Port to listen on") communication_mode: ControlFlow = Field(ControlFlow.ASYNC, description="Client/Server operation mode (sync/async)") credentials: Credentials | None = Field(None, description="Client credentials for secure mode") security: SecurityMode = Field(SecurityMode.INSECURE, description="Security mode (secure/insecure)") mtls: bool = Field(default=False, description="Enable mutual TLS") + compression: GrpcCompression = Field(GrpcCompression.GZIP, description="gRPC compression algorithm") + + # ── Functions ─────────────────────────────────────────────────────────────────── # def __init__(self, **values: Any) -> None: """Initialize the BaseChannelSettings model.""" diff --git a/src/digitalkin/models/settings/utils/grpc_base.py b/src/digitalkin/models/settings/utils/grpc_base.py new file mode 100644 index 00000000..9ac71213 --- /dev/null +++ b/src/digitalkin/models/settings/utils/grpc_base.py @@ -0,0 +1,60 @@ +"""Shared gRPC settings models.""" + +from typing import Any + +from pydantic import Field, NonNegativeInt +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class BaseGrpcSettings(BaseSettings): + """Base 
settings shared by client and server gRPC configurations.""" + + model_config = SettingsConfigDict(extra="forbid", arbitrary_types_allowed=True, validate_assignment=True) + + # ── Options ───────────────────────────────────────────────────────────────────── # + + max_receive_message_lenght: NonNegativeInt = Field( + default=100 * 1024 * 1024, description="Maximum message size the server can receive, in bytes." + ) + max_send_message_length: NonNegativeInt = Field( + default=100 * 1024 * 1024, description="Maximum message size the server can send, in bytes." + ) + keepalive_time: NonNegativeInt = Field(default=120000, description="Interval for server keepalive pings.") + keepalive_timeout: NonNegativeInt = Field(default=20000, description="Timeout for server keepalive pings.") + keepalive_permit_without_calls: bool = Field( + default=True, + description="Allow clients to send keepalive pings even when there are no active RPCs. " + "This is important for keeping connections " + "alive through proxies and detecting dead clients.", + ) + + # ── Functions ─────────────────────────────────────────────────────────────────── # + + def __init__(self, **values: Any) -> None: + """Default constructor.""" + super().__init__(**values) + + @property + def options(self) -> list[tuple[str, Any]]: + """Convert settings to gRPC options format. + + Returns: + List of tuples containing gRPC options and their corresponding values. + """ + return [ + ("grpc.max_receive_message_length", self.max_receive_message_lenght), + ("grpc.max_send_message_length", self.max_send_message_length), + ("grpc.keepalive_time_ms", self.keepalive_time), + ("grpc.keepalive_timeout_ms", self.keepalive_timeout), + ("grpc.keepalive_permit_without_calls", self.keepalive_permit_without_calls), + *self._specific_options, + ] + + @property + def _specific_options(self) -> list[tuple[str, Any]]: + """Return settings specific to a gRPC side. + + Returns: + List of tuples containing side-specific gRPC options. 
+ """ + return [] diff --git a/src/digitalkin/models/settings/utils/models.py b/src/digitalkin/models/settings/utils/models.py new file mode 100644 index 00000000..2a8a8b64 --- /dev/null +++ b/src/digitalkin/models/settings/utils/models.py @@ -0,0 +1,91 @@ +"""Models for settings.""" + +from enum import Enum +from pathlib import Path +from typing import Any + +import grpc +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from digitalkin.grpc_servers.utils.exceptions import SecurityError + + +class GrpcCompression(str, Enum): + """gRPC compression algorithm. + + Attributes: + NONE: No compression + GZIP: Gzip compression + DEFLATE: Deflate compression + """ + + NONE = "none" + GZIP = "gzip" + DEFLATE = "deflate" + + def to_grpc(self) -> grpc.Compression: + """Convert to grpc.Compression enum. + + Returns: + The corresponding grpc.Compression value. + """ + match self: + case GrpcCompression.NONE: + return grpc.Compression.NoCompression + case GrpcCompression.GZIP: + return grpc.Compression.Gzip + case GrpcCompression.DEFLATE: + return grpc.Compression.Deflate + + +class ControlFlow(str, Enum): + """Enum for server operation mode.""" + + SYNC = "sync" + ASYNC = "async" + + +class SecurityMode(str, Enum): + """Enum for server security mode.""" + + SECURE = "secure" + INSECURE = "insecure" + + +class Credentials(BaseModel): + """Model for server credentials in secure mode. 
+ + Attributes: + key_path: Path to the server private key + cert_path: Path to the server certificate + root_cert_path: Optional path to the root certificate + """ + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True, validate_assignment=True, frozen=True) + + key_path: Path | None = Field(default=None, description="Path to the private key") + cert_path: Path | None = Field(default=None, description="Path to the certificate") + root_cert_path: Path | None = Field(default=None, description="Path to the root certificate") + + def __init__(self, /, **data: Any) -> None: + """Initialize the Credentials model.""" + super().__init__(**data) + + @field_validator("key_path", "cert_path", "root_cert_path") + @classmethod + def check_path_exists(cls, v: Path | None) -> Path | None: + """Validate that the file path exists. + + Args: + v: Path to validate + + Returns: + The validated path + + Raises: + SecurityError: If the path does not exist + """ + if v is not None and not v.exists(): + msg = f"File not found: {v}" + raise SecurityError(msg) + return v diff --git a/src/digitalkin/services/communication/grpc_communication.py b/src/digitalkin/services/communication/grpc_communication.py index 7b558f06..d654eeb1 100644 --- a/src/digitalkin/services/communication/grpc_communication.py +++ b/src/digitalkin/services/communication/grpc_communication.py @@ -68,11 +68,6 @@ def _get_or_create_channel(self, module_address: str, module_port: int) -> grpc. 
config = ClientConfig( host=module_address, port=module_port, - mode=self.client_config.mode, - security=self.client_config.security, - credentials=self.client_config.credentials, - compression=self.client_config.compression, - channel_options=self.client_config.channel_options, ) channel = self._init_channel(config) if self._channel_cache_key is not None: diff --git a/src/digitalkin/services/task_manager/grpc_task_manager.py b/src/digitalkin/services/task_manager/grpc_task_manager.py index 68ffec18..373dbd41 100644 --- a/src/digitalkin/services/task_manager/grpc_task_manager.py +++ b/src/digitalkin/services/task_manager/grpc_task_manager.py @@ -4,7 +4,6 @@ import asyncio import contextlib -import os import random import uuid from collections.abc import Awaitable, Callable @@ -30,6 +29,7 @@ from collections.abc import AsyncGenerator from digitalkin.models.grpc_servers.models import ClientConfig + from digitalkin.models.settings.client.task import TaskSettings _PollFn = Callable[[list[str]], Awaitable[list[task_manager_message_pb2.Task]]] @@ -149,18 +149,17 @@ def __init__( self._task_queues: dict[str, asyncio.Queue[task_manager_message_pb2.Task | None]] = {} self._last_seen_ts: dict[str, tuple[int, int]] = {} - def register(self, task_id: str) -> asyncio.Queue[task_manager_message_pb2.Task | None]: + def register(self, task_id: str, max_queue_size: int = 512) -> asyncio.Queue[task_manager_message_pb2.Task | None]: """Register a task_id for polling. Returns queue for signal delivery. Args: task_id: Unique task identifier. + max_queue_size: Maximum size of the signal queue. Returns: asyncio.Queue[task_manager_message_pb2.Task | None]: Queue for signal delivery. 
""" - queue: asyncio.Queue[task_manager_message_pb2.Task | None] = asyncio.Queue( - maxsize=int(os.environ.get("DIGITALKIN_SIGNAL_QUEUE_SIZE", "512")) - ) + queue: asyncio.Queue[task_manager_message_pb2.Task | None] = asyncio.Queue(maxsize=max_queue_size) self._task_queues[task_id] = queue if self._task is None or self._task.done(): # Recreate stop_event in the current event loop (the old one may belong to a closed loop) @@ -292,31 +291,28 @@ class _SharedSendBuffer(_SharedChannelResource): _instances: ClassVar[dict[str, _SharedSendBuffer]] = {} @classmethod - def get_or_create(cls, key: str, stub: Any, grpc_timeout: float) -> _SharedSendBuffer: + def get_or_create(cls, key: str, stub: Any, settings: TaskSettings) -> _SharedSendBuffer: """Get existing buffer for this channel key or create a new one. Args: key: Unique channel identifier. stub: gRPC stub for SendSignals calls. - grpc_timeout: Seconds before the RPC times out. + settings: Settings for tasks and shared channel ressources Returns: _SharedSendBuffer: Shared buffer for this channel. """ if key not in cls._instances: - cls._instances[key] = cls(stub, grpc_timeout) + cls._instances[key] = cls(stub, settings) inst = cls._instances[key] inst._refcount += 1 # noqa: SLF001 return inst - def __init__(self, stub: Any, grpc_timeout: float) -> None: + def __init__(self, stub: Any, settings: TaskSettings) -> None: super().__init__() self._stub = stub - self._grpc_timeout = grpc_timeout - self._flush_interval = float(os.environ.get("DIGITALKIN_SIGNAL_FLUSH_INTERVAL", "0.1")) - self._max_batch_size = int(os.environ.get("DIGITALKIN_SIGNAL_MAX_BATCH_SIZE", "50")) - self._max_retries = int(os.environ.get("DIGITALKIN_SIGNAL_SEND_RETRIES", "3")) - self._backoff_base = float(os.environ.get("DIGITALKIN_SIGNAL_SEND_BACKOFF_MS", "100")) / 1000 + self._grpc_timeout = settings.grpc_timeout + self._task_settings = settings # List of (proto, future) pairs pending a flush. Swapped atomically in _flush(). 
self._pending: list[tuple[task_manager_message_pb2.Task, asyncio.Future[bool]]] = [] @@ -335,7 +331,7 @@ async def send(self, task_proto: task_manager_message_pb2.Task) -> bool: future: asyncio.Future[bool] = asyncio.get_running_loop().create_future() self._pending.append((task_proto, future)) - if len(self._pending) >= self._max_batch_size: + if len(self._pending) >= self._task_settings.signal_max_batch_size: # Batch full — flush immediately without waiting for the timer. await self._flush() elif self._task is None or self._task.done(): @@ -350,7 +346,7 @@ async def _flush_after_interval(self) -> None: stop_event = self._stop_event try: stop_wait = asyncio.create_task(stop_event.wait()) - done, _ = await asyncio.wait([stop_wait], timeout=self._flush_interval) + done, _ = await asyncio.wait([stop_wait], timeout=self._task_settings.signal_flush_interval) if not done: stop_wait.cancel() await self._flush() @@ -375,7 +371,7 @@ async def _flush(self) -> None: futures = [f for _, f in batch] exc: Exception | None = None - for attempt in range(1 + self._max_retries): + for attempt in range(1 + self._task_settings.signal_max_retries): exc = None try: req = task_manager_dto_pb2.SendSignalsRequest(tasks=task_protos) @@ -385,13 +381,13 @@ async def _flush(self) -> None: break # Server rejected — not retryable break # Success except grpc.aio.AioRpcError as e: - if e.code() in _RETRYABLE_CODES and attempt < self._max_retries: - delay = self._backoff_base * (2**attempt) + if e.code() in _RETRYABLE_CODES and attempt < self._task_settings.signal_max_retries: + delay = (self._task_settings.signal_send_backoff_ms / 1000) * (2**attempt) jitter = random.uniform(0, delay * 0.5) # noqa: S311 logger.warning( "SendSignals attempt %d/%d failed (%s), retrying in %.0fms", attempt + 1, - 1 + self._max_retries, + 1 + self._task_settings.signal_max_retries, e.code().name, (delay + jitter) * 1000, ) @@ -440,8 +436,8 @@ def __init__( setup_version_id: str, # noqa: ARG002 client_config: ClientConfig, *, -
poll_interval: float = float(os.environ.get("DIGITALKIN_SIGNAL_POLL_INTERVAL", "1.0")), - initial_poll_interval: float = float(os.environ.get("DIGITALKIN_SIGNAL_INITIAL_POLL_INTERVAL", "0.1")), + poll_interval: float | None = None, + initial_poll_interval: float | None = None, ) -> None: """Initialize with client config. @@ -463,13 +459,12 @@ def __init__( ) raise ImportError(msg) channel = self._init_channel(client_config) + self._task_settings = self._client_settings.task self.stub = task_manager_service_pb2_grpc.TaskManagerServiceStub(channel) self._subscriptions = {} self._sub_task_ids = {} - self._poll_interval = poll_interval - self._initial_poll_interval = initial_poll_interval - self._grpc_timeout = float(os.environ.get("DIGITALKIN_GRPC_TIMEOUT", "30")) - self._poll_timeout = float(os.environ.get("DIGITALKIN_POLL_TIMEOUT", "1")) + self._poll_interval = poll_interval or self._task_settings.signal_poll_interval + self._initial_poll_interval = initial_poll_interval or self._task_settings.signal_initial_poll_interval # Lazy buffer: created on first send_signal to ensure correct event loop and stub self._send_buffer_key = self._channel_cache_key or "default" self._send_buffer_acquired = False @@ -570,7 +565,7 @@ async def send_signal(self, task_id: str, data: dict[str, Any]) -> dict[str, Any self._send_buffer_acquired = True buffer = None if buffer is None: - buffer = _SharedSendBuffer.get_or_create(self._send_buffer_key, self.stub, self._grpc_timeout) + buffer = _SharedSendBuffer.get_or_create(self._send_buffer_key, self.stub, self._task_settings) await buffer.send(self._signal_to_task_proto(signal)) logger.info("SendSignals: task_id=%s action=%s", task_id, signal.action.value) return data @@ -588,7 +583,7 @@ async def _get_signals(self, task_ids: list[str]) -> list[task_manager_message_p resp = await self.poll_grpc( "GetSignals", task_manager_dto_pb2.GetSignalsRequest(task_ids=task_ids), - timeout=self._poll_timeout, + timeout=self._task_settings.poll_timeout, 
) return list(resp.tasks) if resp is not None else [] except Exception: diff --git a/tests/core/test_base_task_manager.py b/tests/core/test_base_task_manager.py index 05626332..7642ebf0 100644 --- a/tests/core/test_base_task_manager.py +++ b/tests/core/test_base_task_manager.py @@ -23,6 +23,7 @@ from digitalkin.core.task_manager.base_task_manager import BaseTaskManager from digitalkin.core.task_manager.task_session import TaskSession from digitalkin.models.core.task_monitor import CancellationReason +from digitalkin.models.settings.client.task import TaskSettings from digitalkin.modules._base_module import BaseModule from digitalkin.services.task_manager.task_manager_strategy import TaskManagerStrategy @@ -39,11 +40,11 @@ class ConcreteTaskManager(BaseTaskManager): """Minimal concrete implementation for testing BaseTaskManager.""" async def create_task( - self, - task_id: str, - mission_id: str, - module: BaseModule, - coro: Coroutine[Any, Any, None], + self, + task_id: str, + mission_id: str, + module: BaseModule, + coro: Coroutine[Any, Any, None], ) -> None: """Create task using base validation and session creation. 
@@ -106,14 +107,6 @@ async def mock_base_module(mock_signal_service: Mock) -> Mock: return module -@pytest_asyncio.fixture -async def task_manager() -> ConcreteTaskManager: - """Standard concrete task manager for testing.""" - mgr = ConcreteTaskManager(default_timeout=2.0) - mgr.max_concurrent_tasks = 10 - return mgr - - @pytest_asyncio.fixture async def mock_task_session(mock_signal_service: Mock) -> Mock: """Mock TaskSession with expected attributes and async methods.""" @@ -155,10 +148,11 @@ def test_default_params(self) -> None: assert mgr.default_timeout == 300.0 assert mgr.max_concurrent_tasks == 100 - def test_custom_params(self) -> None: + def test_custom_params(self, monkeypatch: pytest.MonkeyPatch) -> None: """Test custom parameter values.""" + monkeypatch.setenv("TASK_MAX_CONCURRENT_TASKS", "50") + BaseTaskManager._task_settings = TaskSettings() mgr = ConcreteTaskManager(default_timeout=5.0) - mgr.max_concurrent_tasks = 50 assert mgr.default_timeout == 5.0 assert mgr.max_concurrent_tasks == 50 @@ -173,9 +167,10 @@ class TestValidation: @pytest.mark.asyncio async def test_duplicate_task_id_raises( - self, task_manager: ConcreteTaskManager, mock_base_module: Mock, + self, mock_base_module: Mock, ) -> None: """Test that duplicate task_id raises ValueError.""" + task_manager = ConcreteTaskManager() async def work(): await asyncio.sleep(1) @@ -186,14 +181,16 @@ async def work(): await task_manager.create_task("dup", "missions:test", mock_base_module, work()) @pytest.mark.asyncio - async def test_max_concurrent_tasks_raises(self, mock_base_module: Mock) -> None: + async def test_max_concurrent_tasks_raises(self, mock_base_module: Mock, monkeypatch: pytest.MonkeyPatch) -> None: """Test that exceeding max tasks raises RuntimeError after wait timeout.""" + monkeypatch.setenv("TASK_MAX_CONCURRENT_TASKS", "2") + monkeypatch.setenv("TASK_MAX_QUEUED_TASKS", "0") + monkeypatch.setenv("TASK_WAIT_TIMEOUT", "0.1") + BaseTaskManager._task_settings = TaskSettings() mgr = 
ConcreteTaskManager(default_timeout=1.0) - mgr.max_concurrent_tasks = 2 - mgr._task_wait_timeout = 0.1 async def work(): - await asyncio.sleep(1) + await asyncio.sleep(0.01) await mgr.create_task("t1", "missions:test", mock_base_module, work()) await mgr.create_task("t2", "missions:test", mock_base_module, work()) @@ -203,29 +200,32 @@ async def work(): @pytest.mark.asyncio async def test_duplicate_closes_coroutine( - self, task_manager: ConcreteTaskManager, mock_base_module: Mock, + self, mock_base_module: Mock, ) -> None: """Test that duplicate validation closes the rejected coroutine.""" + mgr = ConcreteTaskManager(default_timeout=2.0) async def work(): await asyncio.sleep(1) - await task_manager.create_task("dup2", "missions:test", mock_base_module, work()) + await mgr.create_task("dup2", "missions:test", mock_base_module, work()) coro = work() with pytest.raises(ValueError): - await task_manager.create_task("dup2", "missions:test", mock_base_module, coro) + await mgr.create_task("dup2", "missions:test", mock_base_module, coro) # Coroutine should be closed with pytest.raises((StopIteration, RuntimeError)): await coro @pytest.mark.asyncio - async def test_max_tasks_closes_coroutine(self, mock_base_module: Mock) -> None: + async def test_max_tasks_closes_coroutine(self, mock_base_module: Mock, monkeypatch: pytest.MonkeyPatch) -> None: """Test that max tasks validation closes the rejected coroutine.""" + monkeypatch.setenv("TASK_MAX_CONCURRENT_TASKS", "1") + monkeypatch.setenv("TASK_MAX_QUEUED_TASKS", "0") + monkeypatch.setenv("TASK_WAIT_TIMEOUT", "0.1") + BaseTaskManager._task_settings = TaskSettings() mgr = ConcreteTaskManager() - mgr.max_concurrent_tasks = 1 - mgr._task_wait_timeout = 0.1 async def work(): await asyncio.sleep(1) @@ -250,22 +250,24 @@ class TestSessionCreation: @pytest.mark.asyncio async def test_create_session_registers( - self, task_manager: ConcreteTaskManager, mock_base_module: Mock, + self, mock_base_module: Mock, ) -> None: """Test 
_create_session registers session in tasks_sessions.""" - session = task_manager._create_session("t1", "missions:test", mock_base_module) + mgr = ConcreteTaskManager(default_timeout=2.0) + session = mgr._create_session("t1", "missions:test", mock_base_module) - assert "t1" in task_manager.tasks_sessions + assert "t1" in mgr.tasks_sessions assert isinstance(session, TaskSession) assert session.task_id == "t1" assert session.mission_id == "missions:test" @pytest.mark.asyncio async def test_create_session_pending_status( - self, task_manager: ConcreteTaskManager, mock_base_module: Mock, + self, mock_base_module: Mock, ) -> None: """Test _create_session creates session with pending status.""" - session = task_manager._create_session("t2", "missions:test", mock_base_module) + mgr = ConcreteTaskManager(default_timeout=2.0) + session = mgr._create_session("t2", "missions:test", mock_base_module) assert session.status == "pending" @@ -279,59 +281,64 @@ class TestCleanup: @pytest.mark.asyncio async def test_cleanup_removes_session( - self, task_manager: ConcreteTaskManager, mock_task_session: Mock, + self, mock_task_session: Mock, ) -> None: """Test cleanup removes session from tracking.""" - task_manager.tasks_sessions["t1"] = mock_task_session + mgr = ConcreteTaskManager(default_timeout=2.0) + mgr.tasks_sessions["t1"] = mock_task_session - await task_manager._cleanup_task("t1", "missions:test") + await mgr._cleanup_task("t1", "missions:test") - assert "t1" not in task_manager.tasks_sessions + assert "t1" not in mgr.tasks_sessions @pytest.mark.asyncio async def test_cleanup_calls_session_cleanup( - self, task_manager: ConcreteTaskManager, mock_task_session: Mock, + self, mock_task_session: Mock, ) -> None: """Test cleanup calls session.cleanup().""" - task_manager.tasks_sessions["t1"] = mock_task_session + mgr = ConcreteTaskManager(default_timeout=2.0) + mgr.tasks_sessions["t1"] = mock_task_session - await task_manager._cleanup_task("t1", "missions:test") + await 
mgr._cleanup_task("t1", "missions:test") mock_task_session.cleanup.assert_called_once() @pytest.mark.asyncio async def test_cleanup_removes_task( - self, task_manager: ConcreteTaskManager, mock_task_session: Mock, + self, mock_task_session: Mock, ) -> None: """Test cleanup removes task from tasks dict.""" - task_manager.tasks["t1"] = asyncio.create_task(asyncio.sleep(10)) - task_manager.tasks_sessions["t1"] = mock_task_session + mgr = ConcreteTaskManager(default_timeout=2.0) + mgr.tasks["t1"] = asyncio.create_task(asyncio.sleep(10)) + mgr.tasks_sessions["t1"] = mock_task_session - await task_manager._cleanup_task("t1", "missions:test") + await mgr._cleanup_task("t1", "missions:test") - assert "t1" not in task_manager.tasks + assert "t1" not in mgr.tasks @pytest.mark.asyncio async def test_cleanup_handles_missing_session( - self, task_manager: ConcreteTaskManager, + self, ) -> None: """Test cleanup handles non-existent session gracefully.""" # Should not raise - await task_manager._cleanup_task("nonexistent", "missions:test") + mgr = ConcreteTaskManager(default_timeout=2.0) + await mgr._cleanup_task("nonexistent", "missions:test") @pytest.mark.asyncio async def test_cleanup_handles_session_cleanup_failure( - self, task_manager: ConcreteTaskManager, mock_task_session: Mock, + self, mock_task_session: Mock, ) -> None: """Test cleanup continues even if session.cleanup() fails.""" mock_task_session.cleanup = AsyncMock(side_effect=RuntimeError("cleanup failed")) - task_manager.tasks_sessions["t1"] = mock_task_session + mgr = ConcreteTaskManager(default_timeout=2.0) + mgr.tasks_sessions["t1"] = mock_task_session # Should not raise - await task_manager._cleanup_task("t1", "missions:test") + await mgr._cleanup_task("t1", "missions:test") # Session should still be removed - assert "t1" not in task_manager.tasks_sessions + assert "t1" not in mgr.tasks_sessions # ============================================================================ @@ -344,53 +351,52 @@ class 
TestSignalSending: @pytest.mark.asyncio async def test_send_signal_success( - self, - task_manager: ConcreteTaskManager, - mock_task_session: Mock, - mock_signal_service: Mock, + self, + mock_task_session: Mock, + mock_signal_service: Mock, ) -> None: """Test successful signal sending.""" - task_manager.tasks_sessions["t1"] = mock_task_session + mgr = ConcreteTaskManager(default_timeout=2.0) + mgr.tasks_sessions["t1"] = mock_task_session - result = await task_manager.send_signal("t1", "missions:test", "cancel", {}) + result = await mgr.send_signal("t1", "missions:test", "cancel", {}) assert result is True mock_signal_service.send_signal.assert_called_once() @pytest.mark.asyncio - async def test_send_signal_unknown_task( - self, task_manager: ConcreteTaskManager, - ) -> None: + async def test_send_signal_unknown_task(self) -> None: """Test signal to non-existent task returns False.""" - result = await task_manager.send_signal("unknown", "missions:test", "cancel", {}) + mgr = ConcreteTaskManager(default_timeout=2.0) + result = await mgr.send_signal("unknown", "missions:test", "cancel", {}) assert result is False @pytest.mark.asyncio async def test_send_signal_includes_action( - self, - task_manager: ConcreteTaskManager, - mock_task_session: Mock, - mock_signal_service: Mock, + self, + mock_task_session: Mock, + mock_signal_service: Mock, ) -> None: """Test signal includes correct action type.""" - task_manager.tasks_sessions["t1"] = mock_task_session + mgr = ConcreteTaskManager(default_timeout=2.0) + mgr.tasks_sessions["t1"] = mock_task_session - await task_manager.send_signal("t1", "missions:test", "cancel", {}) + await mgr.send_signal("t1", "missions:test", "cancel", {}) call_data = mock_signal_service.send_signal.call_args[0][1] assert call_data["action"] == "cancel" @pytest.mark.asyncio async def test_send_signal_with_payload( - self, - task_manager: ConcreteTaskManager, - mock_task_session: Mock, - mock_signal_service: Mock, + self, + mock_task_session: Mock, + 
mock_signal_service: Mock, ) -> None: """Test signal includes payload.""" - task_manager.tasks_sessions["t1"] = mock_task_session + mgr = ConcreteTaskManager(default_timeout=2.0) + mgr.tasks_sessions["t1"] = mock_task_session - await task_manager.send_signal("t1", "missions:test", "start", {"key": "value"}) + await mgr.send_signal("t1", "missions:test", "start", {"key": "value"}) call_data = mock_signal_service.send_signal.call_args[0][1] assert call_data["payload"] == {"key": "value"} @@ -405,27 +411,27 @@ class TestCancelTask: """Tests for cancel_task.""" @pytest.mark.asyncio - async def test_cancel_nonexistent_task( - self, task_manager: ConcreteTaskManager, - ) -> None: + async def test_cancel_nonexistent_task(self) -> None: """Test cancelling a task that doesn't exist returns True.""" - result = await task_manager.cancel_task("nonexistent", "missions:test") + mgr = ConcreteTaskManager(default_timeout=2.0) + result = await mgr.cancel_task("nonexistent", "missions:test") assert result is True @pytest.mark.asyncio async def test_cancel_completed_task( - self, - task_manager: ConcreteTaskManager, - mock_task_session: Mock, + self, + + mock_task_session: Mock, ) -> None: """Test cancelling an already completed task.""" + mgr = ConcreteTaskManager(default_timeout=2.0) done_task = asyncio.create_task(asyncio.sleep(0)) await done_task # Let it complete - task_manager.tasks["t1"] = done_task - task_manager.tasks_sessions["t1"] = mock_task_session + mgr.tasks["t1"] = done_task + mgr.tasks_sessions["t1"] = mock_task_session - result = await task_manager.cancel_task("t1", "missions:test") + result = await mgr.cancel_task("t1", "missions:test") assert result is True @@ -439,22 +445,24 @@ class TestCancelAllTasks: @pytest.mark.asyncio async def test_cancel_all_no_tasks( - self, task_manager: ConcreteTaskManager, + self, ) -> None: """Test cancel_all_tasks with no tasks.""" - results = await task_manager.cancel_all_tasks("missions:test") + mgr = 
ConcreteTaskManager(default_timeout=2.0) + results = await mgr.cancel_all_tasks("missions:test") assert results == {} @pytest.mark.asyncio async def test_cancel_all_cancels_all_running( - self, - task_manager: ConcreteTaskManager, - mock_task_session: Mock, + self, + + mock_task_session: Mock, ) -> None: """Test cancel_all_tasks cancels all running tasks.""" + mgr = ConcreteTaskManager(default_timeout=2.0) for i in range(3): task = asyncio.create_task(asyncio.sleep(10)) - task_manager.tasks[f"t{i}"] = task + mgr.tasks[f"t{i}"] = task session = Mock(spec=TaskSession) session.status = "running" session.cancellation_reason = CancellationReason.UNKNOWN @@ -462,9 +470,9 @@ async def test_cancel_all_cancels_all_running( session.setup_id = "setup:test" session.setup_version_id = "setup_version:test" session.cleanup = AsyncMock() - task_manager.tasks_sessions[f"t{i}"] = session + mgr.tasks_sessions[f"t{i}"] = session - results = await task_manager.cancel_all_tasks("missions:test", timeout=0.5) + results = await mgr.cancel_all_tasks("missions:test", timeout=0.5) assert len(results) == 3 @@ -479,33 +487,36 @@ class TestShutdown: @pytest.mark.asyncio async def test_shutdown_sets_event( - self, task_manager: ConcreteTaskManager, + self, ) -> None: """Test shutdown sets the shutdown event.""" - assert not task_manager._shutdown_event.is_set() - await task_manager.shutdown("missions:test") - assert task_manager._shutdown_event.is_set() + mgr = ConcreteTaskManager(default_timeout=2.0) + assert not mgr._shutdown_event.is_set() + await mgr.shutdown("missions:test") + assert mgr._shutdown_event.is_set() @pytest.mark.asyncio async def test_shutdown_idempotent( - self, task_manager: ConcreteTaskManager, + self, ) -> None: """Test shutdown can be called multiple times safely.""" - await task_manager.shutdown("missions:test") - await task_manager.shutdown("missions:test") - assert task_manager._shutdown_event.is_set() + mgr = ConcreteTaskManager(default_timeout=2.0) + await 
mgr.shutdown("missions:test") + await mgr.shutdown("missions:test") + assert mgr._shutdown_event.is_set() @pytest.mark.asyncio async def test_shutdown_marks_sessions_with_reason( - self, - task_manager: ConcreteTaskManager, - mock_task_session: Mock, + self, + + mock_task_session: Mock, ) -> None: """Test shutdown marks sessions with SHUTDOWN reason.""" + mgr = ConcreteTaskManager(default_timeout=2.0) mock_task_session.cancellation_reason = CancellationReason.UNKNOWN - task_manager.tasks_sessions["t1"] = mock_task_session + mgr.tasks_sessions["t1"] = mock_task_session - await task_manager.shutdown("missions:test") + await mgr.shutdown("missions:test") assert mock_task_session.cancellation_reason == CancellationReason.SHUTDOWN @@ -520,43 +531,51 @@ class TestProperties: @pytest.mark.asyncio async def test_task_count_empty( - self, task_manager: ConcreteTaskManager, + self, ) -> None: """Test task_count is 0 when empty.""" - assert task_manager.task_count == 0 + mgr = ConcreteTaskManager(default_timeout=2.0) + assert mgr.task_count == 0 @pytest.mark.asyncio async def test_task_count_active( - self, - task_manager: ConcreteTaskManager, - mock_base_module: Mock, + self, + mock_base_module: Mock, + monkeypatch: pytest.MonkeyPatch ) -> None: """Test task_count counts pending and running sessions.""" + monkeypatch.setenv("TASK_MAX_CONCURRENT_TASKS", "2") + monkeypatch.setenv("TASK_MAX_QUEUED_TASKS", "0") + BaseTaskManager._task_settings = TaskSettings() + mgr = ConcreteTaskManager(default_timeout=2.0) + async def work(): await asyncio.sleep(1) - await task_manager.create_task("t1", "missions:test", mock_base_module, work()) - assert task_manager.task_count == 1 + await mgr.create_task("t1", "missions:test", mock_base_module, work()) + assert mgr.task_count == 1 - await task_manager.create_task("t2", "missions:test", mock_base_module, work()) - assert task_manager.task_count == 2 + await mgr.create_task("t2", "missions:test", mock_base_module, work()) + assert 
mgr.task_count == 2 @pytest.mark.asyncio async def test_running_tasks_empty( - self, task_manager: ConcreteTaskManager, + self, ) -> None: """Test running_tasks is empty when no tasks.""" - assert len(task_manager.running_tasks) == 0 + mgr = ConcreteTaskManager(default_timeout=2.0) + assert len(mgr.running_tasks) == 0 @pytest.mark.asyncio async def test_running_tasks_tracks_active( - self, task_manager: ConcreteTaskManager, + self, ) -> None: """Test running_tasks returns IDs of active tasks.""" + mgr = ConcreteTaskManager(default_timeout=2.0) task = asyncio.create_task(asyncio.sleep(10)) - task_manager.tasks["t1"] = task + mgr.tasks["t1"] = task - assert "t1" in task_manager.running_tasks + assert "t1" in mgr.running_tasks task.cancel() with contextlib.suppress(asyncio.CancelledError): @@ -596,11 +615,14 @@ class TestTasksLock: """Tests for _tasks_lock preventing TOCTOU race conditions.""" @pytest.mark.asyncio - async def test_concurrent_create_respects_max(self, mock_base_module: Mock) -> None: + async def test_concurrent_create_respects_max(self, mock_base_module: Mock, monkeypatch: pytest.MonkeyPatch) -> None: """Test that concurrent creates don't exceed max_concurrent_tasks.""" + monkeypatch.setenv("TASK_MAX_CONCURRENT_TASKS", "3") + monkeypatch.setenv("TASK_MAX_QUEUED_TASKS", "0") + monkeypatch.setenv("TASK_WAIT_TIMEOUT", "0.1") + BaseTaskManager._task_settings = TaskSettings() + mgr = ConcreteTaskManager() - mgr.max_concurrent_tasks = 3 - mgr._task_wait_timeout = 0.1 async def work(): await asyncio.sleep(1) diff --git a/tests/grpc_server/test_module_service.py b/tests/grpc_server/test_module_service.py index 65b94ed7..3851a793 100644 --- a/tests/grpc_server/test_module_service.py +++ b/tests/grpc_server/test_module_service.py @@ -14,15 +14,15 @@ from agentic_mesh_protocol.module.v1 import ( information_pb2, lifecycle_pb2, - monitoring_pb2, ) from agentic_mesh_protocol.setup.v1 import setup_pb2 from google.protobuf import json_format, struct_pb2 +from 
tests.fixtures.grpc_fixtures import FakeContext from digitalkin.core.job_manager.base_job_manager import BaseJobManager from digitalkin.grpc_servers.module_servicer import ModuleServicer +from digitalkin.models.settings.server.server import ServerSettings from digitalkin.modules._base_module import BaseModule -from tests.fixtures.grpc_fixtures import FakeContext # Mock Module Class for testing @@ -118,6 +118,7 @@ def module_servicer(mock_job_manager, mock_setup_strategy): servicer._setup_cache_max = 100 servicer._setup_inflight: dict[str, asyncio.Future] = {} servicer._completion_timeout = 300.0 + servicer._server_settings = ServerSettings() return servicer diff --git a/tests/grpc_server/utils/test_grpc_client_wrapper.py b/tests/grpc_server/utils/test_grpc_client_wrapper.py index 86a228f2..8c1619f1 100644 --- a/tests/grpc_server/utils/test_grpc_client_wrapper.py +++ b/tests/grpc_server/utils/test_grpc_client_wrapper.py @@ -5,7 +5,7 @@ import pytest from digitalkin.grpc_servers.utils.grpc_client_wrapper import GrpcClientWrapper -from digitalkin.models.grpc_servers.models import ClientConfig, GrpcCompression +from digitalkin.models.grpc_servers.models import ClientConfig from digitalkin.models.settings.utils.channel import ControlFlow, SecurityMode @@ -26,7 +26,6 @@ def _make_config(host: str = "localhost", port: int = 50051) -> ClientConfig: port=port, mode=ControlFlow.ASYNC, security=SecurityMode.INSECURE, - compression=GrpcCompression.GZIP, ) @@ -52,7 +51,7 @@ def test_same_config_reuses_channel(self, mock_insecure_channel: MagicMock) -> N mock_insecure_channel.assert_called_once() - cache_key = f"{config.address}:{config.security.value}:{config.compression.value}" + cache_key = wrapper_a._channel_cache_key if GrpcClientWrapper._ref_counts[cache_key] != 2: pytest.fail(f"Expected ref_count=2, got {GrpcClientWrapper._ref_counts[cache_key]}") @@ -100,7 +99,7 @@ async def test_close_one_user_keeps_channel_alive(self, mock_insecure_channel: M 
fake_channel.close.assert_not_called() - cache_key = f"{config.address}:{config.security.value}:{config.compression.value}" + cache_key = wrapper_a._channel_cache_key if cache_key not in GrpcClientWrapper._channel_cache: pytest.fail("Channel should still be in cache with one remaining ref") if GrpcClientWrapper._ref_counts[cache_key] != 1: @@ -124,7 +123,7 @@ async def test_close_last_user_closes_channel(self, mock_insecure_channel: Magic fake_channel.close.assert_awaited_once() - cache_key = f"{config.address}:{config.security.value}:{config.compression.value}" + cache_key = wrapper_a._channel_cache_key if cache_key in GrpcClientWrapper._channel_cache: pytest.fail("Channel should be removed from cache after last ref closed") if cache_key in GrpcClientWrapper._ref_counts: diff --git a/tests/grpc_server/utils/test_models.py b/tests/grpc_server/utils/test_models.py index 227c9ab2..fec55c10 100644 --- a/tests/grpc_server/utils/test_models.py +++ b/tests/grpc_server/utils/test_models.py @@ -251,11 +251,3 @@ def test_module_server_config(self, monkeypatch: pytest.MonkeyPatch) -> None: expected_advertise_host = "digitalkin-test-archetype-server" assert config.channel.advertise_host == expected_advertise_host - - def test_registry_server_config(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Test RegistryServerConfig specific properties.""" - monkeypatch.setenv("SERVER_CHANNEL_DATABASE_URL", "sqlite:///registry.db") - config = ServerSettings() - expected_database_url = "sqlite:///registry.db" - - assert config.channel.database_url == expected_database_url diff --git a/tests/services/cost/test_cost_stress.py b/tests/services/cost/test_cost_stress.py index cc76be5e..974ef71b 100644 --- a/tests/services/cost/test_cost_stress.py +++ b/tests/services/cost/test_cost_stress.py @@ -22,6 +22,9 @@ import grpc_testing import pytest from agentic_mesh_protocol.cost.v1 import cost_service_pb2, cost_service_pb2_grpc +from tests.fixtures.grpc_fixtures import AsyncStubWrapper, 
FakeContext +from tests.fixtures.stress_reporter import StressReporter +from tests.services.cost.mock_cost_servicer import MockCostServicer from digitalkin.models.grpc_servers.models import ClientConfig from digitalkin.models.services.cost import AmountLimit, CostTypeEnum, QuantityLimit @@ -29,9 +32,6 @@ from digitalkin.services.cost.cost_strategy import CostConfig, CostServiceError from digitalkin.services.cost.default_cost import DefaultCost from digitalkin.services.cost.grpc_cost import GrpcCost -from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext -from tests.fixtures.stress_reporter import StressReporter -from tests.services.cost.mock_cost_servicer import MockCostServicer # Set timeout for stress tests pytestmark = pytest.mark.timeout(60) @@ -121,7 +121,6 @@ def grpc_client( port=50051, mode=ControlFlow.ASYNC, security=SecurityMode.INSECURE, - credentials=None, ) client = GrpcCost( diff --git a/tests/services/cost/test_grpc_cost.py b/tests/services/cost/test_grpc_cost.py index 6e538359..51859b12 100644 --- a/tests/services/cost/test_grpc_cost.py +++ b/tests/services/cost/test_grpc_cost.py @@ -13,6 +13,7 @@ import grpc_testing import pytest from agentic_mesh_protocol.cost.v1 import cost_service_pb2, cost_service_pb2_grpc +from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext from digitalkin.grpc_servers.utils.exceptions import ServerError from digitalkin.models.grpc_servers.models import ClientConfig @@ -20,7 +21,6 @@ from digitalkin.services.cost.cost_strategy import CostConfig, CostData, CostServiceError, CostType from digitalkin.services.cost.grpc_cost import GrpcCost from mock_cost_servicer import MockCostServicer -from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext service_instance = MockCostServicer() service_name = cost_service_pb2.DESCRIPTOR.services_by_name["CostService"] @@ -136,7 +136,6 @@ def client(test_channel: grpc_testing.Channel, cost_config: dict[str, CostConfig port=50051, 
mode=ControlFlow.ASYNC, security=SecurityMode.INSECURE, - credentials=None, ) mission_id = "mission_test" diff --git a/tests/services/filesystem/test_grpc_filesystem.py b/tests/services/filesystem/test_grpc_filesystem.py index 57cfb5ba..0e9b539e 100644 --- a/tests/services/filesystem/test_grpc_filesystem.py +++ b/tests/services/filesystem/test_grpc_filesystem.py @@ -16,6 +16,7 @@ ) from google.protobuf import struct_pb2 from grpc.framework.foundation import logging_pool +from tests.fixtures.grpc_fixtures import FakeContext from digitalkin.models.grpc_servers.models import ClientConfig from digitalkin.models.settings.utils.channel import ControlFlow, SecurityMode @@ -27,7 +28,6 @@ ) from digitalkin.services.filesystem.grpc_filesystem import GrpcFilesystem from mock_filesystem_servicer import MockFilesystemServicer -from tests.fixtures.grpc_fixtures import FakeContext service_instance = MockFilesystemServicer() service_name = filesystem_service_pb2.DESCRIPTOR.services_by_name["FilesystemService"] @@ -80,7 +80,6 @@ def client(test_channel: grpc_testing.Channel) -> GrpcFilesystem: port=50151, mode=ControlFlow.ASYNC, security=SecurityMode.INSECURE, - credentials=None, ) mission_id = "test_mission" diff --git a/tests/services/registry/test_grpc_registry.py b/tests/services/registry/test_grpc_registry.py index 2e4269e2..f9884c5d 100644 --- a/tests/services/registry/test_grpc_registry.py +++ b/tests/services/registry/test_grpc_registry.py @@ -20,6 +20,8 @@ registry_service_pb2, registry_service_pb2_grpc, ) +from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext +from tests.services.registry.mock_registry_servicer import MockRegistryServicer from digitalkin.models.grpc_servers.models import ClientConfig from digitalkin.models.services.registry import RegistryModuleStatus, RegistryModuleType @@ -28,8 +30,6 @@ RegistryServiceError, ) from digitalkin.services.registry.grpc_registry import GrpcRegistry -from tests.fixtures.grpc_fixtures import AsyncStubWrapper, 
FakeContext -from tests.services.registry.mock_registry_servicer import MockRegistryServicer # Set timeout for all tests in this file (20 seconds) pytestmark = pytest.mark.timeout(20) @@ -88,7 +88,6 @@ def dummy_client_config() -> ClientConfig: port=50052, mode=ControlFlow.ASYNC, security=SecurityMode.INSECURE, - credentials=None, ) diff --git a/tests/services/setup/test_grpc_setup.py b/tests/services/setup/test_grpc_setup.py index 8345908b..c4a87a35 100644 --- a/tests/services/setup/test_grpc_setup.py +++ b/tests/services/setup/test_grpc_setup.py @@ -15,13 +15,13 @@ setup_service_pb2_grpc, ) from freezegun import freeze_time +from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext from digitalkin.models.grpc_servers.models import ClientConfig from digitalkin.models.settings.utils.channel import ControlFlow, SecurityMode from digitalkin.services.setup.grpc_setup import GrpcSetup from digitalkin.services.setup.setup_strategy import SetupData, SetupVersionData from mock_setup_servicer import MockSetupServicer -from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext service_instance = MockSetupServicer() service_name = setup_service_pb2.DESCRIPTOR.services_by_name["SetupService"] @@ -77,7 +77,6 @@ def client(test_channel: grpc_testing.Channel) -> GrpcSetup: port=50151, mode=ControlFlow.ASYNC, security=SecurityMode.INSECURE, - credentials=None, ) client = GrpcSetup() # emulate real instance diff --git a/tests/services/storage/test_grpc_storage.py b/tests/services/storage/test_grpc_storage.py index e485dfe1..b200a551 100644 --- a/tests/services/storage/test_grpc_storage.py +++ b/tests/services/storage/test_grpc_storage.py @@ -15,12 +15,12 @@ import pytest from agentic_mesh_protocol.storage.v1 import data_pb2, storage_service_pb2, storage_service_pb2_grpc from pydantic import BaseModel, Field +from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext +from tests.services.storage.mock_storage_servicer import MockStorageServicer 
from digitalkin.models.grpc_servers.models import ClientConfig from digitalkin.services.storage.grpc_storage import GrpcStorage from digitalkin.services.storage.storage_strategy import DataType, StorageServiceError -from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext -from tests.services.storage.mock_storage_servicer import MockStorageServicer # Set timeout for all tests in this file (20 seconds) pytestmark = pytest.mark.timeout(20) @@ -126,7 +126,6 @@ def dummy_client_config() -> ClientConfig: port=50051, mode=ControlFlow.ASYNC, security=SecurityMode.INSECURE, - credentials=None, ) diff --git a/tests/services/task_manager/test_grpc_task_manager.py b/tests/services/task_manager/test_grpc_task_manager.py index 5af8d098..8c46f004 100644 --- a/tests/services/task_manager/test_grpc_task_manager.py +++ b/tests/services/task_manager/test_grpc_task_manager.py @@ -21,13 +21,13 @@ ) from google.protobuf.struct_pb2 import Struct from google.protobuf.timestamp_pb2 import Timestamp +from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext from digitalkin.models.core.task_monitor import CancellationReason, SignalMessage, SignalType from digitalkin.models.grpc_servers.models import ClientConfig from digitalkin.models.settings.utils.channel import ControlFlow, SecurityMode from digitalkin.services.task_manager.grpc_task_manager import GrpcTaskManager, _SharedPoller, _SharedSendBuffer from mock_task_manager_servicer import MockTaskManagerServicer -from tests.fixtures.grpc_fixtures import AsyncStubWrapper, FakeContext # Set timeout for all tests in this file (30 seconds) pytestmark = pytest.mark.timeout(30) @@ -104,7 +104,6 @@ def client(test_channel: grpc_testing.Channel) -> GrpcTaskManager: port=50051, mode=ControlFlow.ASYNC, security=SecurityMode.INSECURE, - credentials=None, ) client = GrpcTaskManager( diff --git a/tests/services/user_profile/test_grpc_user_profile.py b/tests/services/user_profile/test_grpc_user_profile.py index 
5cb117e3..1128ddc4 100644 --- a/tests/services/user_profile/test_grpc_user_profile.py +++ b/tests/services/user_profile/test_grpc_user_profile.py @@ -20,11 +20,11 @@ user_profile_service_pb2, user_profile_service_pb2_grpc, ) +from tests.fixtures.grpc_fixtures import FakeContext +from tests.services.user_profile.mock_user_profile_servicer import MockUserProfileServicer from digitalkin.models.grpc_servers.models import ClientConfig from digitalkin.services.user_profile.grpc_user_profile import GrpcUserProfile -from tests.fixtures.grpc_fixtures import FakeContext -from tests.services.user_profile.mock_user_profile_servicer import MockUserProfileServicer # Set timeout for all tests in this file (20 seconds) pytestmark = pytest.mark.timeout(20) @@ -106,7 +106,6 @@ def dummy_client_config() -> ClientConfig: port=50051, mode=ControlFlow.ASYNC, security=SecurityMode.INSECURE, - credentials=None, )