Skip to content
104 changes: 100 additions & 4 deletions apps/backend/agents/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from typing import Any

from claude_agent_sdk import ClaudeSDKClient
from core.circuit_breaker import CircuitBreaker
from core.error_classifier import ErrorClassifier
from core.memory_monitor import MemoryMonitor, MemoryPressure, SessionBounds
from core.token_stats import PhaseTokenStats, PhaseType, TaskTokenStats
from debug import (
debug,
Expand Down Expand Up @@ -64,6 +67,13 @@

logger = logging.getLogger(__name__)

# Module-level resilience singletons (shared across sessions)
_memory_monitor = MemoryMonitor()
_GC_MESSAGE_INTERVAL = 50 # Run GC check every N messages
_api_circuit_breaker = CircuitBreaker(
name="sdk_api", failure_threshold=3, recovery_timeout=60.0
)


# ============================================================================
# Conversation History Tracking
Expand Down Expand Up @@ -908,6 +918,29 @@ async def run_agent_session(
"session", "Created conversation round", round_number=current_round.round_number
)

# Session-scoped error classifier (avoids cross-session state leaking)
error_classifier = ErrorClassifier()

# Check memory pressure before starting
pressure = _memory_monitor.check_pressure()
if pressure == MemoryPressure.CRITICAL:
msg = "Cannot start session: memory pressure is CRITICAL"
debug_error("session", msg, usage_mb=_memory_monitor.get_usage_mb())
if task_logger:
task_logger.log_error(msg, phase)
return "error", msg, None, decision_tracker

# Check circuit breaker
if not _api_circuit_breaker.can_execute():
msg = (
f"API circuit breaker is OPEN ({_api_circuit_breaker.name}). "
"Too many consecutive failures — waiting for recovery."
)
debug_error("session", msg)
if task_logger:
task_logger.log_error(msg, phase)
return "error", msg, None, decision_tracker

try:
# Send the query
debug("session", "Sending query to Claude SDK...")
Expand All @@ -926,6 +959,19 @@ async def run_agent_session(
msg_type=msg_type,
)

# Session bounds safety check
if SessionBounds.check(current_round.round_number, message_count):
reason = SessionBounds.reason(current_round.round_number, message_count)
debug_error("session", reason)
if task_logger:
task_logger.log_error(reason, phase)
_memory_monitor.maybe_gc()
return "error", reason, None, decision_tracker

# Periodic GC under memory pressure
if message_count % _GC_MESSAGE_INTERVAL == 0:
_memory_monitor.maybe_gc()

# Handle AssistantMessage (text and tool use)
if msg_type == "AssistantMessage" and hasattr(msg, "content"):
for block in msg.content:
Expand Down Expand Up @@ -1083,6 +1129,26 @@ async def run_agent_session(

print("\n" + "-" * 70 + "\n")

# Record successful API interaction
_api_circuit_breaker.record_success()

# Check response for error signals (auth failures, stuck loops, etc.)
classified = error_classifier.classify_response(response_text)
if classified and classified.is_fatal:
error_msg = classified.message
debug_error(
"session",
f"Fatal error detected in response: [{classified.category.value}] {error_msg}",
)
print(f"\n[{classified.category.value.upper()}] {error_msg}")
if classified.action_hint:
print(f" Action: {classified.action_hint}")
if task_logger:
task_logger.log_error(
f"[{classified.category.value.upper()}] {error_msg}", phase
)
return "error", error_msg, None, decision_tracker

# Extract usage metadata from Claude SDK client
usage_metadata = None
try:
Expand Down Expand Up @@ -1181,22 +1247,37 @@ async def run_agent_session(
return "continue", response_text, usage_metadata, decision_tracker

except Exception as e:
# Classify the exception for structured error reporting
classified = error_classifier.classify_exception(e)
_api_circuit_breaker.record_failure(e)

debug_error(
"session",
f"Session error: {e}",
f"Session error [{classified.category.value}]: {e}",
exception_type=type(e).__name__,
is_fatal=classified.is_fatal,
is_retryable=classified.is_retryable,
message_count=message_count,
tool_count=tool_count,
)
print(f"Error during agent session: {e}")

# Print structured error message matching frontend AUTH_FAILURE_PATTERNS
error_msg = classified.message
print(f"\n[{classified.category.value.upper()}] {error_msg}")
if classified.action_hint:
print(f" Action: {classified.action_hint}")

if task_logger:
task_logger.log_error(f"Session error: {e}", phase)
task_logger.log_error(
f"[{classified.category.value.upper()}] {error_msg}", phase
)

# Save conversation history even on error for debugging
try:
conversation_history.save()
except Exception as save_err:
logger.debug(f"Failed to save conversation history after error: {save_err}")
return "error", str(e), None, decision_tracker
return "error", error_msg, None, decision_tracker


def _assess_and_record_failure(
Expand Down Expand Up @@ -1280,6 +1361,21 @@ async def run_agent_session_isolated(
subtask_id=subtask_id,
)

# Pre-checks: fail fast if system is unhealthy (matches run_agent_session)
pressure = _memory_monitor.check_pressure()
if pressure == MemoryPressure.CRITICAL:
msg = "Cannot start isolated session: memory pressure is CRITICAL"
debug_error("session", msg, usage_mb=_memory_monitor.get_usage_mb())
return "error", msg, None

if not _api_circuit_breaker.can_execute():
msg = (
f"API circuit breaker is OPEN ({_api_circuit_breaker.name}). "
"Too many consecutive failures — waiting for recovery."
)
debug_error("session", msg)
return "error", msg, None

# Initialize recovery manager for automatic crash recovery
recovery_manager = RecoveryManager(spec_dir=spec_dir, project_dir=project_dir)

Expand Down
13 changes: 13 additions & 0 deletions apps/backend/core/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,19 @@ def _trigger_login_windows() -> bool:
return False


def emit_rate_limit_marker(reset_time: str | None = None) -> None:
"""Print a structured marker that the frontend can detect for rate-limit handling.

The frontend ``rate-limit-detector.ts`` scans process output for rate-limit
patterns. This function prints a canonical marker line so the detection is
reliable regardless of the upstream error format.
"""
parts = ["[RATE_LIMITED]"]
if reset_time:
parts.append(f"reset_time={reset_time}")
print(" ".join(parts), flush=True)


def ensure_authenticated() -> str:
"""
Ensure the user is authenticated, prompting for login if needed.
Expand Down
108 changes: 108 additions & 0 deletions apps/backend/core/circuit_breaker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
Circuit Breaker
===============

Implements the circuit breaker pattern to prevent repeated calls to a
failing service. After *failure_threshold* consecutive failures the
circuit opens and all subsequent calls are rejected until
*recovery_timeout* seconds have elapsed, at which point it enters a
half-open state allowing one probe call.
"""

import logging
import time
from enum import Enum

logger = logging.getLogger(__name__)


class CircuitState(Enum):
"""Circuit breaker states."""

CLOSED = "closed"
OPEN = "open"
HALF_OPEN = "half_open"


class CircuitBreaker:
"""Simple circuit breaker for API calls."""

def __init__(
self,
name: str,
failure_threshold: int = 3,
recovery_timeout: float = 60.0,
) -> None:
self.name = name
self._failure_threshold = failure_threshold
self._recovery_timeout = recovery_timeout

self._failure_count: int = 0
self._last_failure_time: float = 0.0
self._state = CircuitState.CLOSED

@property
def state(self) -> CircuitState:
"""Return the current effective state (may transition from OPEN → HALF_OPEN)."""
if self._state == CircuitState.OPEN:
elapsed = time.monotonic() - self._last_failure_time
if elapsed >= self._recovery_timeout:
self._state = CircuitState.HALF_OPEN
logger.info(
"Circuit breaker '%s' transitioned to HALF_OPEN after %.1fs",
self.name,
elapsed,
)
return self._state

def can_execute(self) -> bool:
"""Return True if the circuit allows a call to proceed."""
current = self.state
if current == CircuitState.CLOSED:
return True
if current == CircuitState.HALF_OPEN:
return True # Allow one probe call
return False # OPEN — reject

def record_success(self) -> None:
"""Record a successful call — resets the breaker to CLOSED."""
if self._state != CircuitState.CLOSED:
logger.info(
"Circuit breaker '%s' recovered → CLOSED",
self.name,
)
self._failure_count = 0
self._state = CircuitState.CLOSED

def record_failure(self, error: Exception | None = None) -> None:
"""Record a failed call — may trip the breaker to OPEN."""
self._failure_count += 1
self._last_failure_time = time.monotonic()

if self._state == CircuitState.HALF_OPEN:
# Probe failed — reopen
self._state = CircuitState.OPEN
logger.warning(
"Circuit breaker '%s' probe failed → OPEN (error: %s)",
self.name,
error,
)
elif self._failure_count >= self._failure_threshold:
self._state = CircuitState.OPEN
logger.warning(
"Circuit breaker '%s' tripped → OPEN after %d failures (error: %s)",
self.name,
self._failure_count,
error,
)

def reset(self) -> None:
"""Manually reset the breaker to CLOSED."""
self._failure_count = 0
self._state = CircuitState.CLOSED

def __repr__(self) -> str:
return (
f"CircuitBreaker(name={self.name!r}, state={self.state.value}, "
f"failures={self._failure_count}/{self._failure_threshold})"
)
Loading
Loading