Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/SMOODEV-950-cb-rate-state-change.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@smooai/fetch': patch
---

SMOODEV-950: Circuit breaker — rate-based detection + `on_state_change` callback in Rust/Python/Go. Adds `failure_rate_threshold` + `sliding_window_size` for rate-based tripping (Python, Rust) and an `on_state_change` callback that fires on every state transition (Python, Rust, Go-builder). Mirrors the TS `failureRateThreshold` + `onStateChange` surface.
14 changes: 14 additions & 0 deletions go/fetch/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,20 @@ func (b *ClientBuilder) WithCircuitBreaker(name string, opts *CircuitBreakerOpti
return b
}

// WithCircuitBreakerStateChange registers a state-change callback on the
// configured circuit breaker. If WithCircuitBreaker has not been called, a
// fresh CircuitBreakerOptions is created so the callback has somewhere to live.
//
// This exposes the underlying sony/gobreaker `OnStateChange` at the builder
// level (mirrors the SMOODEV-950 onStateChange parity surface).
func (b *ClientBuilder) WithCircuitBreakerStateChange(fn func(name string, from, to CircuitBreakerState)) *ClientBuilder {
if b.circuitBreakerOpts == nil {
b.circuitBreakerOpts = &CircuitBreakerOptions{}
}
b.circuitBreakerOpts.OnStateChange = fn
return b
}

// WithHooks sets lifecycle hooks for the client.
func (b *ClientBuilder) WithHooks(hooks *LifecycleHooks) *ClientBuilder {
b.hooks = hooks
Expand Down
45 changes: 45 additions & 0 deletions go/fetch/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package fetch
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -344,3 +345,47 @@ func TestClientBuilder_WithContainerOptions_NilFieldsLeaveUnchanged(t *testing.T
t.Error("expected rate-limit retry to remain unset")
}
}

// TestClientBuilder_WithCircuitBreakerStateChange verifies that the dedicated
// builder helper exposes sony/gobreaker's OnStateChange callback at the
// top-level builder API (SMOODEV-950).
func TestClientBuilder_WithCircuitBreakerStateChange(t *testing.T) {
type stateChange struct{ from, to CircuitBreakerState }
var observed []stateChange
hook := func(_ string, from, to CircuitBreakerState) {
observed = append(observed, stateChange{from, to})
}

client := NewClientBuilder().
WithCircuitBreaker("rate-state-cb", &CircuitBreakerOptions{
MaxRequests: 1,
Timeout: 50 * time.Millisecond,
ReadyToTrip: func(c CircuitBreakerCounts) bool {
return c.ConsecutiveFailures >= 2
},
}).
WithCircuitBreakerStateChange(hook).
Build()

if client.circuitBreaker == nil {
t.Fatal("expected circuit breaker to be configured")
}

// Drive 2 consecutive failures → breaker should trip to open.
testErr := errors.New("boom")
for i := 0; i < 2; i++ {
_, _ = client.circuitBreaker.Execute(context.Background(), func(_ context.Context) (any, error) {
return nil, testErr
})
}
if got := client.circuitBreaker.State(); got != CircuitBreakerStateOpen {
t.Fatalf("expected breaker to be open, got %d", got)
}
if len(observed) == 0 {
t.Fatal("expected at least one state-change callback invocation")
}
// The first transition should be Closed → Open.
if observed[0].from != CircuitBreakerStateClosed || observed[0].to != CircuitBreakerStateOpen {
t.Errorf("expected Closed→Open as first transition, got %d→%d", observed[0].from, observed[0].to)
}
}
2 changes: 2 additions & 0 deletions python/src/smooai_fetch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from smooai_fetch._types import (
AuthTokenProvider,
CircuitBreakerOptions,
CircuitStateChangeCallback,
FetchContainerOptions,
FetchOptions,
LifecycleHooks,
Expand All @@ -69,6 +70,7 @@
# Types
"AuthTokenProvider",
"CircuitBreakerOptions",
"CircuitStateChangeCallback",
"FetchContainerOptions",
"FetchOptions",
"LifecycleHooks",
Expand Down
55 changes: 45 additions & 10 deletions python/src/smooai_fetch/_circuit_breaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import asyncio
import time
from collections import deque
from collections.abc import Awaitable, Callable
from enum import Enum
from typing import TypeVar
Expand All @@ -30,8 +31,9 @@ class CircuitBreaker:
"""An async-compatible circuit breaker.

State transitions:
- CLOSED: Normal operation. Failures are counted.
- OPEN: Requests are rejected immediately. After timeout, transitions to HALF_OPEN.
- CLOSED: Normal operation. Failures are counted (and optionally tracked over a
sliding window when `failure_rate_threshold` is set).
- OPEN: Requests are rejected immediately. After `timeout`, transitions to HALF_OPEN.
- HALF_OPEN: A limited number of requests are allowed through. If they succeed
(reaching success_threshold), transitions to CLOSED. If one fails, transitions
back to OPEN.
Expand All @@ -41,11 +43,16 @@ def __init__(self, options: CircuitBreakerOptions) -> None:
self._failure_threshold = options.failure_threshold
self._success_threshold = options.success_threshold
self._timeout = options.timeout # seconds
self._failure_rate_threshold = options.failure_rate_threshold
self._sliding_window_size = max(1, options.sliding_window_size)
self._on_state_change = options.on_state_change

self._state = CircuitState.CLOSED
self._failure_count = 0
self._success_count = 0
self._last_failure_time: float | None = None
# Recent outcomes window: True = success, False = failure.
self._outcomes: deque[bool] = deque(maxlen=self._sliding_window_size)
self._lock = asyncio.Lock()

@property
Expand Down Expand Up @@ -93,37 +100,65 @@ async def call(self, func: Callable[..., Awaitable[T]]) -> T:

return result

def _transition(self, target: CircuitState) -> None:
"""Move to `target` and fire the on_state_change callback (if registered)."""
if self._state is target:
return
previous = self._state
self._state = target
if self._on_state_change is not None:
# Errors in user callbacks should not break the breaker's bookkeeping.
try:
self._on_state_change(previous.value, target.value)
except Exception:
pass

def _get_state(self) -> CircuitState:
"""Get the current state, potentially transitioning OPEN -> HALF_OPEN."""
if self._state == CircuitState.OPEN and self._last_failure_time is not None:
elapsed = time.monotonic() - self._last_failure_time
if elapsed >= self._timeout:
self._state = CircuitState.HALF_OPEN
self._transition(CircuitState.HALF_OPEN)
self._success_count = 0
return CircuitState.HALF_OPEN
return self._state

def _record_success(self) -> None:
"""Record a successful call."""
self._outcomes.append(True)
if self._state == CircuitState.HALF_OPEN:
self._success_count += 1
if self._success_count >= self._success_threshold:
self._state = CircuitState.CLOSED
self._transition(CircuitState.CLOSED)
self._failure_count = 0
self._success_count = 0
self._outcomes.clear()
elif self._state == CircuitState.CLOSED:
# Reset failure count on success in closed state
# Reset failure count on success in closed state.
self._failure_count = 0

def _record_failure(self) -> None:
"""Record a failed call."""
self._outcomes.append(False)
if self._state == CircuitState.HALF_OPEN:
# Any failure in half-open goes back to open
self._state = CircuitState.OPEN
# Any failure in half-open goes back to open.
self._transition(CircuitState.OPEN)
self._last_failure_time = time.monotonic()
self._success_count = 0
elif self._state == CircuitState.CLOSED:
return
if self._state == CircuitState.CLOSED:
self._failure_count += 1
if self._failure_count >= self._failure_threshold:
self._state = CircuitState.OPEN
# Rate-based detection: when a threshold is configured and enough
# samples have been observed, evaluate the failure ratio over the
# sliding window.
if self._failure_rate_threshold is not None and len(self._outcomes) >= self._failure_threshold:
failures = sum(1 for ok in self._outcomes if not ok)
rate = failures / len(self._outcomes)
if rate >= self._failure_rate_threshold:
self._transition(CircuitState.OPEN)
self._last_failure_time = time.monotonic()
return
# Count-based detection: trip when consecutive failures reach the threshold.
if self._failure_rate_threshold is None and self._failure_count >= self._failure_threshold:
self._transition(CircuitState.OPEN)
self._last_failure_time = time.monotonic()
36 changes: 35 additions & 1 deletion python/src/smooai_fetch/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,19 +153,53 @@ class RateLimitOptions:
"""Duration of the sliding window in milliseconds."""


CircuitStateChangeCallback = Callable[[str, str], None]
"""Callback invoked when the circuit breaker transitions between states.

Receives `(from_state, to_state)` where each value is one of `"closed"`,
`"open"`, or `"half-open"`. Mirrors the Go port's `OnStateChange` callback.
"""


@dataclass
class CircuitBreakerOptions:
"""Configuration options for circuit breaker behavior."""

failure_threshold: int = 5
"""Number of failures before the circuit opens."""
"""Number of failures before the circuit opens.

Used when `failure_rate_threshold` is None (the default). With a rate-based
threshold this still acts as the minimum sample count before the rate
evaluation kicks in.
"""

success_threshold: int = 2
"""Number of successes in half-open state to close the circuit."""

timeout: float = 30.0
"""Seconds to wait before transitioning from open to half-open."""

failure_rate_threshold: float | None = None
"""Optional failure rate (0.0–1.0) over a sliding window that trips the breaker.

When set, the breaker tracks the most recent `sliding_window_size` outcomes
and trips when the failure ratio meets or exceeds this threshold (after
`failure_threshold` minimum samples have been observed). Mirrors the TS
`failureRateThreshold` setting.
"""

sliding_window_size: int = 10
"""Number of recent outcomes to retain for rate-based detection.

Only consulted when `failure_rate_threshold` is set.
"""

on_state_change: CircuitStateChangeCallback | None = None
"""Optional callback invoked when the breaker transitions between states.

Receives `(from_state, to_state)`. Useful for telemetry / alerting.
"""


# Rate-limit-specific retry options share the same shape as the main RetryOptions.
# This mirrors the Go port (`type RateLimitRetryOptions = RetryOptions`) and the
Expand Down
107 changes: 107 additions & 0 deletions python/tests/test_circuit_breaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,113 @@ async def fail():
assert "Circuit breaker is open" in str(exc_info.value)


async def test_on_state_change_fires_on_open_and_half_open_and_closed():
"""`on_state_change` callback fires on each transition."""
transitions: list[tuple[str, str]] = []

cb = CircuitBreaker(
CircuitBreakerOptions(
failure_threshold=2,
success_threshold=1,
timeout=0.1,
on_state_change=lambda fr, to: transitions.append((fr, to)),
)
)

async def fail():
raise ValueError("fail")

async def success():
return "ok"

# Closed → Open after threshold
for _ in range(2):
with pytest.raises(ValueError):
await cb.call(fail)
assert ("closed", "open") in transitions

# Wait for half-open transition triggered on next call.
await asyncio.sleep(0.15)
result = await cb.call(success)
assert result == "ok"

# Should have seen: closed→open, open→half-open, half-open→closed.
kinds = transitions
assert ("closed", "open") in kinds
assert ("open", "half-open") in kinds
assert ("half-open", "closed") in kinds


async def test_rate_based_threshold_trips_on_failure_rate():
"""`failure_rate_threshold` trips the breaker when the failure rate in the window crosses the threshold."""
cb = CircuitBreaker(
CircuitBreakerOptions(
failure_threshold=4, # minimum sample count before rate eval kicks in
success_threshold=1,
timeout=10.0,
failure_rate_threshold=0.7, # 70% failure rate to trip
sliding_window_size=10,
)
)

async def fail():
raise ValueError("fail")

async def success():
return "ok"

# 3 successes followed by 1 failure = 1/4 = 25% — below threshold, stays closed.
for _ in range(3):
await cb.call(success)
with pytest.raises(ValueError):
await cb.call(fail)
assert cb.state == "closed"

# 4 more failures: window now 3 ok / 5 fail = 5/8 = 62.5% — still below 70%.
for _ in range(4):
with pytest.raises(ValueError):
await cb.call(fail)
assert cb.state == "closed"

# 2 more failures pushes failure rate over 70%.
# window: 3 ok / 7 fail = 7/10 = 70% — trips.
with pytest.raises(ValueError):
await cb.call(fail)
# Window is now full (10) at 6 fail / 3 ok / 1 fail = 7 fail; 7/10 = 70%.
# If this didn't trip yet, one more failure definitely will.
if cb.state != "open":
with pytest.raises(ValueError):
await cb.call(fail)
assert cb.state == "open"


async def test_rate_threshold_respects_minimum_samples():
"""Below the minimum sample count (`failure_threshold`), the rate evaluation is suppressed."""
cb = CircuitBreaker(
CircuitBreakerOptions(
failure_threshold=5,
success_threshold=1,
timeout=10.0,
failure_rate_threshold=0.5,
sliding_window_size=10,
)
)

async def fail():
raise ValueError("fail")

# 4 consecutive failures (below the 5-sample minimum) → still closed.
for _ in range(4):
with pytest.raises(ValueError):
await cb.call(fail)
assert cb.state == "closed"

# 5th failure: 5/5 = 100% → trips.
with pytest.raises(ValueError):
await cb.call(fail)
assert cb.state == "open"


async def test_success_does_not_open():
"""Test that successful calls do not affect the circuit breaker."""
cb = CircuitBreaker(
Expand Down
Loading
Loading