Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Heartbeat manager for sending periodic heartbeats to SyftHub marketplaces."""
"""Endpoint heartbeat manager for sending periodic endpoint health to SyftHub marketplaces."""

from __future__ import annotations

import asyncio
import secrets
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from uuid import UUID

from loguru import logger
Expand All @@ -17,6 +17,19 @@
if TYPE_CHECKING:
from syft_space.components.marketplaces.repository import MarketplaceRepository
from syft_space.components.settings.repository import SettingsRepository
from syft_space.components.tenants.entities import Tenant


@runtime_checkable
class EndpointHealthChecker(Protocol):
    """Interface for checking published endpoint health.

    Satisfied structurally by EndpointHandler — no import needed.

    The heartbeat manager depends only on this protocol, so any object that
    exposes a matching ``get_published_endpoint_health`` coroutine can be
    injected as the health checker. Because the class is decorated with
    ``runtime_checkable``, ``isinstance`` checks against it are allowed;
    such checks only verify that the method exists, not its signature.
    """

    # Coroutine returning one dict per published endpoint of ``tenant``.
    # The dict schema is defined by the implementer and is not visible here.
    # NOTE(review): ``health_timeout`` is presumably a per-check timeout in
    # seconds — confirm against the implementer. The ``...`` default only
    # signals that implementers provide their own default value.
    async def get_published_endpoint_health(
        self, tenant: Tenant, health_timeout: float = ...
    ) -> list[dict[str, Any]]: ...


@dataclass
Expand All @@ -29,15 +42,20 @@ class HeartbeatState:
consecutive_failures: int = field(default=0)


class HeartbeatManager(LifecycleService):
"""Manages periodic heartbeats to SyftHub marketplaces with adaptive backoff.
class EndpointHeartbeatManager(LifecycleService):
"""Manages periodic endpoint health reporting to SyftHub marketplaces.

Periodically checks health of all published endpoints and reports their
status to SyftHub via a non-destructive health API. Replaces the domain-level
heartbeat with endpoint-level health reporting that also serves as a domain
liveness signal via TTL.

Features:
- Exponential backoff: Starts frequent, increases interval on success
- Jitter: 10-15% randomness to prevent thundering herd
- Failure recovery: Resets to aggressive (small) interval on failure
- Failure backoff: After repeated failures, backs off instead of aggressive retry
- Multi-marketplace support: Sends heartbeats to all active marketplaces
- Multi-marketplace support: Sends health to all active marketplaces
- TTL > interval: TTL is 3x the interval to tolerate missed heartbeats
"""

Expand All @@ -53,17 +71,20 @@ class HeartbeatManager(LifecycleService):

def __init__(
self,
health_checker: EndpointHealthChecker,
marketplace_repository: MarketplaceRepository,
settings_repository: SettingsRepository,
enabled: bool = True,
) -> None:
"""Initialize the heartbeat manager.
"""Initialize the endpoint heartbeat manager.

Args:
health_checker: Provider for endpoint health checks (e.g. EndpointHandler)
marketplace_repository: Repository for accessing marketplace credentials
settings_repository: Repository for accessing public_url setting
enabled: Whether heartbeat manager is enabled
enabled: Whether endpoint heartbeat manager is enabled
"""
self._health_checker = health_checker
self._marketplace_repository = marketplace_repository
self._settings_repository = settings_repository
self._enabled = enabled
Expand All @@ -77,46 +98,50 @@ def __init__(
# Async primitives - initialized in startup()
self._shutdown_event: asyncio.Event | None = None
self._heartbeat_task: asyncio.Task | None = None
self._tenant: Tenant | None = None
self._tenant_id: UUID | None = None

def set_tenant(self, tenant: Tenant) -> None:
    """Set the tenant for endpoint health queries.

    Stores both the tenant entity and its ID: the entity is handed to the
    health checker when collecting endpoint health, while the ID is used
    to look up the tenant's active marketplaces.

    Args:
        tenant: Default tenant
    """
    self._tenant = tenant
    # Keep the ID alongside the entity so marketplace lookups keep working.
    self._tenant_id = tenant.id

async def startup(self) -> None:
    """Start the endpoint heartbeat manager.

    Does nothing when the manager is disabled or when a previously started
    heartbeat task is still running. Otherwise creates the shutdown event
    and schedules the heartbeat loop as a background task.
    """
    if not self._enabled:
        logger.info("Endpoint heartbeat manager is disabled")
        return

    # Check if already running
    # Just a safety check to prevent multiple startup calls
    if self._heartbeat_task is not None and not self._heartbeat_task.done():
        logger.warning(
            "Endpoint heartbeat manager already running, skipping startup"
        )
        return

    logger.info("Starting endpoint heartbeat manager...")

    # Initialize async primitives
    self._shutdown_event = asyncio.Event()

    # Start background heartbeat loop
    self._heartbeat_task = asyncio.create_task(
        self._heartbeat_loop(), name="EndpointHeartbeatManager"
    )

    logger.info("Endpoint heartbeat manager started")

async def shutdown(self) -> None:
"""Shutdown the heartbeat manager gracefully."""
"""Shutdown the endpoint heartbeat manager gracefully."""
if not self._enabled:
return

logger.info("Shutting down heartbeat manager...")
logger.info("Shutting down endpoint heartbeat manager...")

if self._shutdown_event:
self._shutdown_event.set()
Expand All @@ -132,24 +157,28 @@ async def shutdown(self) -> None:
pass

self._states.clear()
logger.info("Heartbeat manager shutdown complete")
logger.info("Endpoint heartbeat manager shutdown complete")

async def _heartbeat_loop(self) -> None:
"""Main heartbeat loop - waits for public_url, then sends heartbeats."""
logger.info("Heartbeat loop started, waiting for public_url to be set...")
"""Main heartbeat loop - waits for public_url, then sends endpoint health."""
logger.info(
"Endpoint heartbeat loop started, waiting for public_url to be set..."
)

# Wait for public_url to be set before starting heartbeats
if not await self._wait_for_public_url():
logger.info("Heartbeat loop stopped (shutdown during public_url wait)")
logger.info(
"Endpoint heartbeat loop stopped (shutdown during public_url wait)"
)
return

logger.info("Public URL available, beginning heartbeat cycle")
logger.info("Public URL available, beginning endpoint heartbeat cycle")
self._had_public_url = True

while not self._shutdown_event.is_set():
try:
# Send heartbeats to all marketplaces
await self._send_heartbeats_to_all()
# Send endpoint heartbeats to all marketplaces
await self._send_endpoint_heartbeats_to_all()

# Calculate sleep time (minimum interval across all marketplaces)
sleep_time = self._get_minimum_interval()
Expand All @@ -166,7 +195,7 @@ async def _heartbeat_loop(self) -> None:
except asyncio.CancelledError:
break
except Exception as e:
logger.exception(f"Unexpected error in heartbeat loop: {e}")
logger.exception(f"Unexpected error in endpoint heartbeat loop: {e}")
# Wait before retrying
try:
await asyncio.wait_for(
Expand All @@ -176,7 +205,7 @@ async def _heartbeat_loop(self) -> None:
except asyncio.TimeoutError:
pass

logger.info("Heartbeat loop stopped")
logger.info("Endpoint heartbeat loop stopped")

async def _wait_for_public_url(self) -> bool:
"""Wait for public_url to be set in settings.
Expand Down Expand Up @@ -212,18 +241,20 @@ def _get_minimum_interval(self) -> float:
return self.INITIAL_INTERVAL
return min(state.current_interval for state in self._states.values())

async def _send_heartbeats_to_all(self) -> None:
"""Send heartbeats to all active marketplaces concurrently."""
if not self._tenant_id:
logger.warning("Tenant ID not set, skipping heartbeat")
async def _send_endpoint_heartbeats_to_all(self) -> None:
"""Send endpoint health to all active marketplaces concurrently."""
if not self._tenant:
logger.warning("Tenant not set, skipping endpoint heartbeat")
return

# Get public URL from settings
public_url = await self._settings_repository.get_public_url()
if not public_url:
# URL was removed - reset all states
if self._had_public_url and self._states:
logger.info("Public URL removed, resetting heartbeat intervals")
logger.info(
"Public URL removed, resetting endpoint heartbeat intervals"
)
for state in self._states.values():
state.current_interval = self.INITIAL_INTERVAL
state.consecutive_successes = 0
Expand All @@ -235,7 +266,7 @@ async def _send_heartbeats_to_all(self) -> None:
# Get all active marketplaces
marketplaces = await self._marketplace_repository.get_active(self._tenant_id)
if not marketplaces:
logger.debug("No active marketplaces, skipping heartbeat")
logger.debug("No active marketplaces, skipping endpoint heartbeat")
return

# Prune stale marketplace states (for removed marketplaces)
Expand All @@ -245,29 +276,41 @@ async def _send_heartbeats_to_all(self) -> None:
del self._states[stale_id]
logger.debug(f"Pruned stale heartbeat state for marketplace {stale_id}")

# Send heartbeats concurrently
# Check health of all published endpoints (once, shared across marketplaces)
endpoint_health = await self._health_checker.get_published_endpoint_health(
self._tenant
)
if not endpoint_health:
logger.debug("No published endpoints, skipping endpoint heartbeat")
return

# Send to each marketplace concurrently
tasks = [
self._send_heartbeat_to_marketplace(marketplace, public_url)
self._send_endpoint_heartbeat_to_marketplace(
marketplace, public_url, endpoint_health
)
for marketplace in marketplaces
]
await asyncio.gather(*tasks, return_exceptions=True)

async def _send_heartbeat_to_marketplace(
async def _send_endpoint_heartbeat_to_marketplace(
self,
marketplace: Marketplace,
public_url: str,
endpoint_health: list[dict[str, Any]],
) -> None:
"""Send heartbeat to a single marketplace.
"""Send endpoint health to a single marketplace.

Args:
marketplace: Marketplace entity with credentials
public_url: Public URL to send in heartbeat
public_url: Domain's public URL
endpoint_health: List of endpoint health statuses
"""
# Validate credentials
if not marketplace.email or not marketplace.password:
logger.debug(
f"Marketplace {marketplace.name} missing credentials, "
"skipping heartbeat"
"skipping endpoint heartbeat"
)
return

Expand All @@ -279,26 +322,32 @@ async def _send_heartbeat_to_marketplace(
await client.login(
username=marketplace.email, password=marketplace.password
)
await client.send_heartbeat(public_url, ttl)
await client.update_endpoint_health(
endpoint_health=endpoint_health,
ttl_seconds=ttl,
public_url=public_url,
)

self._update_state_on_success(state)
logger.info(
f"Heartbeat sent to {marketplace.name}: "
f"ttl={ttl}s, next_interval={state.current_interval:.1f}s, "
f"Endpoint heartbeat sent to {marketplace.name}: "
f"{len(endpoint_health)} endpoints, ttl={ttl}s, "
f"next_interval={state.current_interval:.1f}s, "
f"successes={state.consecutive_successes}"
)

except SyftHubError as e:
self._update_state_on_failure(state)
logger.warning(
f"Heartbeat to {marketplace.name} failed: {e.message} "
f"Endpoint heartbeat to {marketplace.name} failed: {e.message} "
f"(failures={state.consecutive_failures}, "
f"next_interval={state.current_interval:.1f}s)"
)
except Exception as e:
self._update_state_on_failure(state)
logger.exception(
f"Unexpected error sending heartbeat to {marketplace.name}: {e}"
f"Unexpected error sending endpoint heartbeat to "
f"{marketplace.name}: {e}"
)

def _get_or_create_state(self, marketplace_id: UUID) -> HeartbeatState:
Expand Down
Loading