diff --git a/alembic/versions/2026_04_09_0001-add_is_admin_to_users.py b/alembic/versions/2026_04_09_0001-add_is_admin_to_users.py new file mode 100644 index 0000000..a0b4ac4 --- /dev/null +++ b/alembic/versions/2026_04_09_0001-add_is_admin_to_users.py @@ -0,0 +1,26 @@ +"""add is_admin column to users table + +Revision ID: add_is_admin_to_users +Revises: add_communication_networks +Create Date: 2026-04-09 00:01:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "add_is_admin_to_users" +down_revision = "add_communication_networks" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "users", + sa.Column("is_admin", sa.Boolean(), nullable=False, server_default="false"), + ) + + +def downgrade() -> None: + op.drop_column("users", "is_admin") diff --git a/alembic/versions/2026_04_09_0002-add_halt_codes_table.py b/alembic/versions/2026_04_09_0002-add_halt_codes_table.py new file mode 100644 index 0000000..daeb287 --- /dev/null +++ b/alembic/versions/2026_04_09_0002-add_halt_codes_table.py @@ -0,0 +1,36 @@ +"""add halt_codes table for distributed kill switch + +Revision ID: add_halt_codes +Revises: add_is_admin_to_users +Create Date: 2026-04-09 00:02:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + +# revision identifiers, used by Alembic. 
+revision = "add_halt_codes" +down_revision = "add_is_admin_to_users" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "halt_codes", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column("code_hash", sa.String(), nullable=False), + sa.Column("label", sa.String(), nullable=False), + sa.Column("trustee_name", sa.String(), nullable=False), + sa.Column("trustee_email", sa.String(), nullable=True), + sa.Column("is_master", sa.Boolean(), nullable=False, server_default="false"), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("created_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + +def downgrade() -> None: + op.drop_table("halt_codes") diff --git a/src/core/admin_auth.py b/src/core/admin_auth.py new file mode 100644 index 0000000..9ec5123 --- /dev/null +++ b/src/core/admin_auth.py @@ -0,0 +1,20 @@ +"""Admin authentication dependency.""" + +from fastapi import Depends + +from src.core.auth import get_current_user +from src.exceptions import ForbiddenException +from src.models.auth import User + + +async def get_admin_user( + current_user: User = Depends(get_current_user), +) -> User: + """Require the current user to be an admin. + + Wraps get_current_user and raises 403 if the user does not have + the is_admin flag set. 
+ """ + if not getattr(current_user, "is_admin", False): + raise ForbiddenException("Admin access required") + return current_user diff --git a/src/core/settings.py b/src/core/settings.py index 097e7cc..3495d30 100644 --- a/src/core/settings.py +++ b/src/core/settings.py @@ -103,6 +103,10 @@ class Settings(BaseSettings): NETWORK_CALLBACK_TIMEOUT_SECONDS: int = 30 NETWORK_MESSAGE_DELIVERY_MAX_RETRIES: int = 3 + # ── Safety & Governance ───────────────────────────────────────────── + SAFETY_CHECK_ENABLED: bool = True + AGENT_STATUS_CACHE_TTL: int = 300 # seconds to cache agent active status in Redis + # ── Economy settings (from agent-economy) ────────────────────────── ECONOMY_WELCOME_BONUS_CREDITS: int = 500 ECONOMY_CREDIT_PACKAGES: list[dict] = [ diff --git a/src/exceptions.py b/src/exceptions.py index 70a6836..940bab7 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -12,6 +12,8 @@ "ValidationException", "DatabaseException", "RateLimitException", + "PlatformHaltedException", + "AgentDisabledException", ] @@ -91,3 +93,18 @@ class RateLimitException(BaseCustomException): def __init__(self, detail: str = "Rate limit exceeded. Please try again later"): super().__init__(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=detail) + + +# Safety & Governance Exceptions +class PlatformHaltedException(BaseCustomException): + """Exception raised when the platform is in emergency halt mode""" + + def __init__(self, detail: str = "Platform is in emergency halt mode. 
All agent operations are suspended."): + super().__init__(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=detail) + + +class AgentDisabledException(BaseCustomException): + """Exception raised when a disabled agent is invoked""" + + def __init__(self, detail: str = "Agent has been disabled by an administrator"): + super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail) diff --git a/src/main.py b/src/main.py index b82e095..e67ef58 100644 --- a/src/main.py +++ b/src/main.py @@ -30,6 +30,8 @@ from src.routes.message import router as message_router from src.routes.registry import router as registry_router from src.routes.task import router as task_router +from src.routes.admin import router as admin_router +from src.routes.safety import router as safety_router from src.mcp_app import create_mcp_app # Workflow routers (from agent-os) @@ -232,6 +234,10 @@ async def handle_workflow_exception(_request: Request, exc: WorkflowAppException app.include_router(invocation_log_router) app.include_router(task_router) +# ── Admin / Safety routers ─────────────────────────────────────────── +app.include_router(admin_router, tags=["Admin"]) +app.include_router(safety_router, tags=["Safety"]) + # ── Workflow routers (from agent-os) ───────────────────────────────── app.include_router(workflow_router, prefix="/workflows", tags=["Workflows"]) app.include_router(execution_router, tags=["Executions"]) diff --git a/src/models/__init__.py b/src/models/__init__.py index 78b58a1..fe6882e 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -12,6 +12,7 @@ from src.models.message import Message from src.models.registry import Agent, AgentCredential, AgentRating from src.models.task import Task +from src.models.halt_code import HaltCode # Workflow models (from agent-os) from src.workflow.models.entities import ( # noqa: F401 @@ -47,6 +48,7 @@ "AgentCredential", "InvocationLog", "Task", + "HaltCode", # Workflow "WorkflowDefinition", "WorkflowExecution", diff 
--git a/src/models/auth.py b/src/models/auth.py index 816be30..cb577da 100644 --- a/src/models/auth.py +++ b/src/models/auth.py @@ -22,6 +22,7 @@ class User(BaseModel): last_name: Column[str] = Column(String, nullable=True) phone_number: Column[str] = Column(String, nullable=True, unique=True) is_active: Column[bool] = Column(Boolean, default=True, nullable=False) + is_admin: Column[bool] = Column(Boolean, default=False, nullable=False) # Relationships api_keys = relationship("ApiKey", back_populates="user", cascade="all, delete-orphan") diff --git a/src/models/halt_code.py b/src/models/halt_code.py new file mode 100644 index 0000000..f9260a9 --- /dev/null +++ b/src/models/halt_code.py @@ -0,0 +1,25 @@ +"""Halt code model — distributed kill switch codes for trustees.""" + +from typing import Optional +from uuid import UUID + +from sqlalchemy import Boolean, Column, ForeignKey, String +from sqlalchemy.dialects.postgresql import UUID as PostgresUUID + +from .base import BaseModel + + +class HaltCode(BaseModel): + """A halt code held by a trustee who can stop the platform.""" + + __tablename__: str = "halt_codes" + + code_hash: Column[str] = Column(String, nullable=False) + label: Column[str] = Column(String, nullable=False) + trustee_name: Column[str] = Column(String, nullable=False) + trustee_email: Column[Optional[str]] = Column(String, nullable=True) + is_master: Column[bool] = Column(Boolean, default=False, nullable=False) + is_active: Column[bool] = Column(Boolean, default=True, nullable=False) + created_by: Column[UUID] = Column( + PostgresUUID, ForeignKey("users.id"), nullable=False + ) diff --git a/src/network/a2a/routes.py b/src/network/a2a/routes.py index e4330de..d16cb60 100644 --- a/src/network/a2a/routes.py +++ b/src/network/a2a/routes.py @@ -79,6 +79,10 @@ async def a2a_task_send( from src.network.utils.context_manager import NetworkContextManager from src.database import get_redis + # Safety check: reject if platform is in emergency halt + from 
src.services.safety import check_platform_halt + await check_platform_halt() + params = data.params task_data = params.get("task", {}) network_id = params.get("network_id") diff --git a/src/network/services/channels.py b/src/network/services/channels.py index 01f6591..09a4025 100644 --- a/src/network/services/channels.py +++ b/src/network/services/channels.py @@ -258,6 +258,10 @@ async def handle_callback( This is the key to bidirectionality: external agents POST to their reply_url and this method records the message in the network. """ + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_platform_halt + await check_platform_halt() + sender = await self.repo.get_participant(participant_id) if not sender or sender.network_id != network_id: raise NotFoundException("Participant") @@ -322,6 +326,10 @@ async def _validate_communication( sender_id: UUID, recipient_id: UUID, ) -> tuple[NetworkParticipant, NetworkParticipant, CommunicationNetwork]: + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_agent_active, check_platform_halt + await check_platform_halt() + network = await self.repo.get_network(network_id) if not network: raise NotFoundException("Network") @@ -340,6 +348,12 @@ async def _validate_communication( if recipient.status != ParticipantStatus.active: raise BadRequestException("Recipient is not active") + # Safety check: verify linked agents are still active + if sender.agent_id: + await check_agent_active(sender.agent_id) + if recipient.agent_id: + await check_agent_active(recipient.agent_id) + return sender, recipient, network async def _record_message( diff --git a/src/routes/admin.py b/src/routes/admin.py new file mode 100644 index 0000000..d73e1be --- /dev/null +++ b/src/routes/admin.py @@ -0,0 +1,182 @@ +"""Admin routes — platform governance and agent kill switch. + +All endpoints require admin privileges via the get_admin_user dependency. 
+""" + +from typing import Optional +from uuid import UUID + +from fastapi import APIRouter, Depends +from pydantic import BaseModel, Field +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.admin_auth import get_admin_user +from src.database import get_db +from src.exceptions import NotFoundException +from src.models.auth import User +from src.models.registry import Agent +from src.services import safety + +router = APIRouter(prefix="/admin", tags=["Admin"]) + + +# ── Request/Response schemas ──────────────────────────────────────── + + +class KillAgentRequest(BaseModel): + reason: str = Field(..., min_length=1, description="Why is this agent being disabled?") + + +class HaltPlatformRequest(BaseModel): + reason: str = Field(..., min_length=1, description="Why is the platform being halted?") + + +class AgentStatusResponse(BaseModel): + agent_id: str + agent_uuid: UUID + name: str + is_active: bool + owner_id: UUID + + +class PlatformStatusResponse(BaseModel): + halted: bool + reason: Optional[str] = None + halted_by: Optional[str] = None + redis_available: bool + disabled_agent_count: int = 0 + + +# ── Agent kill switch ─────────────────────────────────────────────── + + +@router.post("/agents/{agent_uuid}/kill") +async def kill_agent( + agent_uuid: UUID, + body: KillAgentRequest, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Force-disable an agent. 
Sets is_active=False and caches in Redis.""" + result = await session.execute(select(Agent).where(Agent.id == agent_uuid)) + agent = result.scalar_one_or_none() + if not agent: + raise NotFoundException("Agent") + + agent.is_active = False + await session.commit() + + # Cache kill in Redis for fast rejection + await safety.kill_agent(agent_uuid) + + return { + "success": True, + "agent_id": agent.agent_id, + "agent_uuid": str(agent_uuid), + "reason": body.reason, + "killed_by": str(admin.id), + } + + +@router.post("/agents/{agent_uuid}/reactivate") +async def reactivate_agent( + agent_uuid: UUID, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Re-enable a previously disabled agent.""" + result = await session.execute(select(Agent).where(Agent.id == agent_uuid)) + agent = result.scalar_one_or_none() + if not agent: + raise NotFoundException("Agent") + + agent.is_active = True + await session.commit() + + # Clear Redis kill cache + await safety.reactivate_agent(agent_uuid) + + return { + "success": True, + "agent_id": agent.agent_id, + "agent_uuid": str(agent_uuid), + "reactivated_by": str(admin.id), + } + + +@router.get("/agents/disabled", response_model=list[AgentStatusResponse]) +async def list_disabled_agents( + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """List all currently disabled agents.""" + result = await session.execute( + select(Agent).where(Agent.is_active == False).order_by(Agent.updated_at.desc()) # noqa: E712 + ) + agents = result.scalars().all() + + return [ + AgentStatusResponse( + agent_id=a.agent_id, + agent_uuid=a.id, + name=a.name, + is_active=a.is_active, + owner_id=a.user_id, + ) + for a in agents + ] + + +# ── Platform halt ─────────────────────────────────────────────────── + + +@router.post("/platform/halt") +async def halt_platform( + body: HaltPlatformRequest, + admin: User = Depends(get_admin_user), +): + """Emergency halt — suspend all agent 
operations platform-wide.""" + await safety.halt_platform(body.reason, admin.id) + return { + "success": True, + "halted": True, + "reason": body.reason, + "halted_by": str(admin.id), + } + + +@router.post("/platform/resume") +async def resume_platform( + admin: User = Depends(get_admin_user), +): + """Resume platform operations after an emergency halt.""" + await safety.resume_platform(admin.id) + return { + "success": True, + "halted": False, + "resumed_by": str(admin.id), + } + + +@router.get("/platform/status", response_model=PlatformStatusResponse) +async def get_platform_status( + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Get current platform safety status.""" + status = await safety.get_platform_status() + + # Count disabled agents + result = await session.execute( + select(func.count()).select_from(Agent).where(Agent.is_active == False) # noqa: E712 + ) + disabled_count = result.scalar() or 0 + + return PlatformStatusResponse( + halted=status.get("halted", False), + reason=status.get("reason"), + halted_by=status.get("halted_by"), + redis_available=status.get("redis_available", False), + disabled_agent_count=disabled_count, + ) diff --git a/src/routes/safety.py b/src/routes/safety.py new file mode 100644 index 0000000..726ae08 --- /dev/null +++ b/src/routes/safety.py @@ -0,0 +1,204 @@ +"""Safety routes — public halt endpoint and admin halt code management. + +The halt endpoint is intentionally PUBLIC (no JWT required). The code +itself is the authentication. This is by design: it should be easy to +stop the platform, harder to restart it. 
+""" + +import secrets +from typing import Optional +from uuid import UUID + +import bcrypt +from fastapi import APIRouter, Depends +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.admin_auth import get_admin_user +from src.database import get_db +from src.models.auth import User +from src.models.halt_code import HaltCode +from src.services import safety + +router = APIRouter(prefix="/safety", tags=["Safety"]) + + +# ── Schemas ───────────────────────────────────────────────────────── + + +class HaltRequest(BaseModel): + code: str = Field(..., min_length=1, description="Halt code issued to a trustee") + reason: Optional[str] = Field(default=None, description="Optional reason for halting") + + +class HaltResponse(BaseModel): + halted: bool + trustee: str + reason: Optional[str] = None + message: str + + +class CreateHaltCodeRequest(BaseModel): + trustee_name: str = Field(..., min_length=1) + trustee_email: Optional[str] = None + label: str = Field(..., min_length=1, description="Human-readable label, e.g. 'Guardian - Europe'") + is_master: bool = Field(default=False) + + +class CreateHaltCodeResponse(BaseModel): + id: UUID + label: str + trustee_name: str + is_master: bool + code: str = Field(description="The plaintext code — shown ONCE, never stored") + + +class HaltCodeListItem(BaseModel): + id: UUID + label: str + trustee_name: str + trustee_email: Optional[str] + is_master: bool + is_active: bool + + +class PlatformStatusPublic(BaseModel): + halted: bool + message: str + + +# ── Public endpoints (no auth) ────────────────────────────────────── + + +@router.get("/status", response_model=PlatformStatusPublic) +async def get_public_status(): + """Public platform status — anyone can check if the platform is halted.""" + status = await safety.get_platform_status() + halted = status.get("halted", False) + return PlatformStatusPublic( + halted=halted, + message="Platform is halted. 
All agent operations are suspended." if halted + else "Platform is operational.", + ) + + +@router.post("/halt", response_model=HaltResponse) +async def halt_with_code( + body: HaltRequest, + session: AsyncSession = Depends(get_db), +): + """Halt the platform using a trustee code. + + This endpoint is PUBLIC — no JWT required. The halt code is the + authentication. By design, stopping the platform should be easy. + Restarting requires admin authentication. + """ + # Find all active halt codes and check against each + result = await session.execute( + select(HaltCode).where(HaltCode.is_active == True) # noqa: E712 + ) + halt_codes = result.scalars().all() + + matched_code = None + for hc in halt_codes: + if bcrypt.checkpw(body.code.encode("utf-8"), hc.code_hash.encode("utf-8")): + matched_code = hc + break + + if not matched_code: + from src.exceptions import ForbiddenException + raise ForbiddenException("Invalid halt code") + + reason = body.reason or f"Halted by trustee: {matched_code.trustee_name}" + await safety.halt_platform(reason, matched_code.created_by) + + return HaltResponse( + halted=True, + trustee=matched_code.trustee_name, + reason=reason, + message="Platform has been halted. All agent operations are suspended.", + ) + + +# ── Admin endpoints (manage halt codes) ───────────────────────────── + + +@router.post("/codes", response_model=CreateHaltCodeResponse) +async def create_halt_code( + body: CreateHaltCodeRequest, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Create a new halt code for a trustee. 
The plaintext code is returned + ONCE and never stored — only its bcrypt hash is persisted.""" + # Generate a secure random code: 4 groups of 4 hex chars + raw_code = "-".join( + secrets.token_hex(2).upper() for _ in range(4) + ) + + code_hash = bcrypt.hashpw( + raw_code.encode("utf-8"), bcrypt.gensalt() + ).decode("utf-8") + + halt_code = HaltCode( + code_hash=code_hash, + label=body.label, + trustee_name=body.trustee_name, + trustee_email=body.trustee_email, + is_master=body.is_master, + created_by=admin.id, + ) + session.add(halt_code) + await session.commit() + await session.refresh(halt_code) + + return CreateHaltCodeResponse( + id=halt_code.id, + label=body.label, + trustee_name=body.trustee_name, + is_master=body.is_master, + code=raw_code, + ) + + +@router.get("/codes", response_model=list[HaltCodeListItem]) +async def list_halt_codes( + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """List all halt codes (without the actual codes — those are never stored).""" + result = await session.execute( + select(HaltCode).order_by(HaltCode.created_at.desc()) + ) + codes = result.scalars().all() + return [ + HaltCodeListItem( + id=c.id, + label=c.label, + trustee_name=c.trustee_name, + trustee_email=c.trustee_email, + is_master=c.is_master, + is_active=c.is_active, + ) + for c in codes + ] + + +@router.delete("/codes/{code_id}") +async def revoke_halt_code( + code_id: UUID, + admin: User = Depends(get_admin_user), + session: AsyncSession = Depends(get_db), +): + """Revoke a halt code — it can no longer be used to halt the platform.""" + result = await session.execute(select(HaltCode).where(HaltCode.id == code_id)) + halt_code = result.scalar_one_or_none() + if not halt_code: + from src.exceptions import NotFoundException + raise NotFoundException("Halt code") + + halt_code.is_active = False + await session.commit() + + return {"success": True, "revoked": str(code_id)} diff --git a/src/schemas/registry.py b/src/schemas/registry.py
index d566e24..9b5d1c6 100644 --- a/src/schemas/registry.py +++ b/src/schemas/registry.py @@ -115,6 +115,7 @@ class AgentUpdate(BaseModel): base_price: Optional[float] = None pricing_enabled: Optional[bool] = None supports_streaming: Optional[bool] = None + is_active: Optional[bool] = None @field_validator("auth_type") @classmethod diff --git a/src/services/broker.py b/src/services/broker.py index 2f2e634..de8b81c 100644 --- a/src/services/broker.py +++ b/src/services/broker.py @@ -88,6 +88,10 @@ async def invoke_agent( """ start_time = time.time() + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_platform_halt + await check_platform_halt() + # Resolve conversation_id and message_id from request if not passed conv_id = conversation_id or invoke_request.conversation_id msg_id = message_id or invoke_request.message_id @@ -532,13 +536,25 @@ async def invoke_agent_stream( async generator of dicts. Otherwise, falls back to a normal invocation and returns an InvokeResponse. 
""" + # Safety check: reject if platform is in emergency halt + from src.services.safety import check_platform_halt + await check_platform_halt() + # Resolve agent to check streaming support agent = await self.registry_repository.get_agent_by_agent_id( invoke_request.agent_id ) - supports_streaming = ( - getattr(agent, "supports_streaming", False) if agent else False - ) + + # Check agent exists and is active (fixes gap where streaming path skipped this) + if not agent or not agent.is_active: + return InvokeResponse( + success=False, + error="Agent not found or inactive", + latency_ms=0, + status_code=404, + ) + + supports_streaming = getattr(agent, "supports_streaming", False) if not supports_streaming: # Fall back to normal invocation diff --git a/src/services/registry.py b/src/services/registry.py index ace1c91..4686a27 100644 --- a/src/services/registry.py +++ b/src/services/registry.py @@ -271,6 +271,8 @@ async def update_agent( agent.base_price = update.base_price if update.pricing_enabled is not None: agent.pricing_enabled = update.pricing_enabled + if update.is_active is not None: + agent.is_active = update.is_active # Regenerate embedding if enhance: diff --git a/src/services/safety.py b/src/services/safety.py new file mode 100644 index 0000000..4f87576 --- /dev/null +++ b/src/services/safety.py @@ -0,0 +1,145 @@ +"""Safety service — platform halt, agent kill switch, and safety checks. + +Provides the central "off switch" for the Intuno platform. All communication +chokepoints (broker, channels, A2A) call into this service before processing. 
+""" + +import logging +from typing import Optional +from uuid import UUID + +from src.core.redis_client import get_redis +from src.core.settings import settings +from src.exceptions import AgentDisabledException, PlatformHaltedException + +logger = logging.getLogger(__name__) + +# Redis key constants +EMERGENCY_HALT_KEY = "platform:emergency_halt" +EMERGENCY_HALT_REASON_KEY = "platform:emergency_halt:reason" +EMERGENCY_HALT_ACTOR_KEY = "platform:emergency_halt:actor" +AGENT_STATUS_PREFIX = "agent:status:" + + +async def check_platform_halt() -> None: + """Raise PlatformHaltedException if the platform is in emergency halt. + + This is designed to be called at every communication chokepoint. + Costs one Redis GET per call (plus a second GET for the reason when halted). + Fails open if Redis is unavailable or errors (consistent with rate limiter pattern). + """ + if not settings.SAFETY_CHECK_ENABLED: + return + + redis = await get_redis() + if not redis: + return # Fail open: if Redis is down, allow requests + + try: + halted = await redis.get(EMERGENCY_HALT_KEY) + if halted == "1": + reason = await redis.get(EMERGENCY_HALT_REASON_KEY) + detail = "Platform is in emergency halt mode." + if reason: + detail += f" Reason: {reason}" + raise PlatformHaltedException(detail) + except PlatformHaltedException: + raise + except Exception as e: + logger.warning("Safety check (platform halt) failed: %s", e) + + +async def check_agent_active(agent_id: UUID) -> None: + """Raise AgentDisabledException if the agent has been killed/deactivated. + + Checks only the Redis kill marker; a no-op if Redis is unavailable. + The authoritative is_active check remains in the broker/service layer + via the DB — this adds a fast-path rejection for killed agents.
+ """ + if not settings.SAFETY_CHECK_ENABLED: + return + + redis = await get_redis() + if not redis: + return # Fail open + + try: + key = f"{AGENT_STATUS_PREFIX}{agent_id}" + cached = await redis.get(key) + if cached == "0": + raise AgentDisabledException() + except AgentDisabledException: + raise + except Exception as e: + logger.warning("Safety check (agent status) failed: %s", e) + + +async def halt_platform(reason: str, actor_id: UUID) -> None: + """Activate emergency halt — all agent operations will be rejected.""" + redis = await get_redis() + if not redis: + raise RuntimeError("Redis is required for platform halt") + + await redis.set(EMERGENCY_HALT_KEY, "1") + await redis.set(EMERGENCY_HALT_REASON_KEY, reason) + await redis.set(EMERGENCY_HALT_ACTOR_KEY, str(actor_id)) + logger.critical( + "PLATFORM HALT activated by user %s. Reason: %s", + actor_id, + reason, + ) + + +async def resume_platform(actor_id: UUID) -> None: + """Deactivate emergency halt — resume normal operations.""" + redis = await get_redis() + if not redis: + raise RuntimeError("Redis is required for platform resume") + + await redis.delete(EMERGENCY_HALT_KEY) + await redis.delete(EMERGENCY_HALT_REASON_KEY) + await redis.delete(EMERGENCY_HALT_ACTOR_KEY) + logger.critical("PLATFORM HALT deactivated by user %s", actor_id) + + +async def kill_agent(agent_id: UUID) -> None: + """Cache agent as killed in Redis (marker expires after AGENT_STATUS_CACHE_TTL seconds) for fast rejection at chokepoints.""" + redis = await get_redis() + if not redis: + return + + key = f"{AGENT_STATUS_PREFIX}{agent_id}" + await redis.set(key, "0", ex=settings.AGENT_STATUS_CACHE_TTL) + logger.warning("Agent %s killed (cached in Redis)", agent_id) + + +async def reactivate_agent(agent_id: UUID) -> None: + """Remove killed status from Redis cache.""" + redis = await get_redis() + if not redis: + return + + key = f"{AGENT_STATUS_PREFIX}{agent_id}" + await redis.delete(key) + logger.info("Agent %s reactivated (Redis cache cleared)", agent_id) + + +async def
get_platform_status() -> dict: + """Get current platform safety status.""" + redis = await get_redis() + if not redis: + return {"halted": False, "redis_available": False} + + try: + halted = await redis.get(EMERGENCY_HALT_KEY) + reason = await redis.get(EMERGENCY_HALT_REASON_KEY) + actor = await redis.get(EMERGENCY_HALT_ACTOR_KEY) + return { + "halted": halted == "1", + "reason": reason, + "halted_by": actor, + "redis_available": True, + } + except Exception as e: + logger.warning("Failed to get platform status: %s", e) + return {"halted": False, "redis_available": False, "error": str(e)} diff --git a/src/workflow/utils/resolver.py b/src/workflow/utils/resolver.py index 7876a98..790d482 100644 --- a/src/workflow/utils/resolver.py +++ b/src/workflow/utils/resolver.py @@ -50,6 +50,10 @@ async def resolve( if cache_key in self._cache: return self._cache[cache_key] + # Safety check: reject if platform is halted + from src.services.safety import check_platform_halt + await check_platform_halt() + if ref.startswith(SEARCH_PREFIX): query = ref[len(SEARCH_PREFIX):].strip() target = await self._discover(query, exclude_ids or []) @@ -133,6 +137,10 @@ async def _discover( for agent, _distance in results: if agent.agent_id in cb_excluded: continue + if not agent.is_active: + logger.info("Skipping agent '%s' — inactive", agent.agent_id) + cb_excluded.append(agent.agent_id) + continue available = await self._circuit_breaker.is_available(agent.agent_id) if not available: logger.info(