diff --git a/openhands-sdk/openhands/sdk/conversation/conversation_stats.py b/openhands-sdk/openhands/sdk/conversation/conversation_stats.py
index ca4132d44..25154899e 100644
--- a/openhands-sdk/openhands/sdk/conversation/conversation_stats.py
+++ b/openhands-sdk/openhands/sdk/conversation/conversation_stats.py
@@ -1,4 +1,5 @@
 import warnings
+from collections.abc import Callable

 from pydantic import AliasChoices, BaseModel, Field, PrivateAttr

@@ -29,6 +30,7 @@ class ConversationStats(BaseModel):
     )

     _restored_usage_ids: set[str] = PrivateAttr(default_factory=set)
+    _on_stats_change: Callable[[], None] | None = PrivateAttr(default=None)

     @property
     def service_to_metrics(
@@ -83,11 +85,25 @@ def get_metrics_for_service(
         )
         return self.get_metrics_for_usage(service_id)

+    def set_on_stats_change(self, callback: Callable[[], None] | None) -> None:
+        """Set a callback to be called when stats change.
+
+        Args:
+            callback: A function to call when stats are updated, or None to remove it
+        """
+        self._on_stats_change = callback
+
+        # Wire up the callback to all already-registered metrics
+        for metrics in self.usage_to_metrics.values():
+            metrics.set_on_change(callback)
+
     def register_llm(self, event: RegistryEvent):
         # Listen for LLM creations and track their metrics
         llm = event.llm
         usage_id = llm.usage_id

+        stats_changed = False
+
         # Usage costs exist but have not been restored yet
         if (
             usage_id in self.usage_to_metrics
@@ -99,3 +115,15 @@ def register_llm(self, event: RegistryEvent):
         # Usage is new, track its metrics
         if usage_id not in self.usage_to_metrics and llm.metrics:
             self.usage_to_metrics[usage_id] = llm.metrics
+            stats_changed = True
+
+        # Set up callback on the metrics object to get notified of updates
+        if llm.metrics and self._on_stats_change is not None:
+            llm.metrics.set_on_change(self._on_stats_change)
+
+        # Notify of stats change if callback is set and stats changed
+        if stats_changed and self._on_stats_change is not None:
+            try:
+                self._on_stats_change()
+            except Exception:
+                logger.exception("Stats change callback failed")
diff --git a/openhands-sdk/openhands/sdk/conversation/state.py b/openhands-sdk/openhands/sdk/conversation/state.py
index 437a420c0..c57609acd 100644
--- a/openhands-sdk/openhands/sdk/conversation/state.py
+++ b/openhands-sdk/openhands/sdk/conversation/state.py
@@ -132,6 +132,28 @@ def set_on_state_change(self, callback: ConversationCallbackType | None) -> None
             or None to remove the callback
         """
         self._on_state_change = callback
+        # Also set up the stats change callback to notify when stats are mutated
+        if callback is not None:
+            self.stats.set_on_stats_change(self._notify_stats_change)
+        else:
+            self.stats.set_on_stats_change(None)
+
+    def _notify_stats_change(self) -> None:
+        """Notify the state change callback about a stats update."""
+        if self._on_state_change is not None:
+            try:
+                from openhands.sdk.event.conversation_state import (
+                    ConversationStateUpdateEvent,
+                )
+
+                # Create a ConversationStateUpdateEvent with the updated stats
+                stats_data = self.stats.model_dump(mode="json")
+                state_update_event = ConversationStateUpdateEvent(
+                    key="stats", value=stats_data
+                )
+                self._on_state_change(state_update_event)
+            except Exception:
+                logger.exception("Stats change notification failed")

     # ===== Base snapshot helpers (same FileStore usage you had) =====
     def _save_base_state(self, fs: FileStore) -> None:
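Taken together, the two hunks above complete an observer chain: `Metrics` fires the zero-arg callback that `ConversationStats` wires in, which is `ConversationState._notify_stats_change`, which in turn emits a `ConversationStateUpdateEvent` with `key="stats"` to whatever callback the host registered. A minimal consumer-side sketch of subscribing to that stream, assuming only the API introduced in this patch (the `attach_stats_stream` helper and queue-based fan-out are illustrative, not part of the SDK):

```python
import queue

from openhands.sdk.conversation.state import ConversationState
from openhands.sdk.event.conversation_state import ConversationStateUpdateEvent


def attach_stats_stream(state: ConversationState) -> queue.Queue[dict]:
    """Expose a state's stats updates as a thread-safe queue of JSON dicts."""
    updates: queue.Queue[dict] = queue.Queue()

    def on_state_change(event) -> None:
        # Forward only stats updates; other state keys are ignored here.
        if isinstance(event, ConversationStateUpdateEvent) and event.key == "stats":
            updates.put(event.value)

    state.set_on_state_change(on_state_change)
    return updates
```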
diff --git a/openhands-sdk/openhands/sdk/llm/utils/metrics.py b/openhands-sdk/openhands/sdk/llm/utils/metrics.py
index ffa283b69..f82c5418c 100644
--- a/openhands-sdk/openhands/sdk/llm/utils/metrics.py
+++ b/openhands-sdk/openhands/sdk/llm/utils/metrics.py
@@ -1,8 +1,9 @@
 import copy
 import time
+from collections.abc import Callable
 from typing import final

-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator


 class Cost(BaseModel):
@@ -111,6 +112,8 @@ class Metrics(MetricsSnapshot):
         default_factory=list, description="List of token usage records"
     )

+    _on_change: Callable[[], None] | None = PrivateAttr(default=None)
+
    @field_validator("accumulated_cost")
    @classmethod
    def validate_accumulated_cost(cls, v: float) -> float:
@@ -133,6 +136,23 @@ def initialize_accumulated_token_usage(self) -> "Metrics":
             )
         return self

+    def set_on_change(self, callback: Callable[[], None] | None) -> None:
+        """Set a callback to be called when metrics change.
+
+        Args:
+            callback: A function to call when metrics are updated, or None to remove it
+        """
+        self._on_change = callback
+
+    def _notify_change(self) -> None:
+        """Notify the callback that metrics have changed."""
+        if self._on_change is not None:
+            try:
+                self._on_change()
+            except Exception:
+                # Avoid breaking metrics updates if callback fails
+                pass
+
     def get_snapshot(self) -> MetricsSnapshot:
         """Get a snapshot of the current metrics without the detailed lists."""
         return MetricsSnapshot(
@@ -149,6 +169,7 @@ def add_cost(self, value: float) -> None:
             raise ValueError("Added cost cannot be negative.")
         self.accumulated_cost += value
         self.costs.append(Cost(cost=value, model=self.model_name))
+        self._notify_change()

     def add_response_latency(self, value: float, response_id: str) -> None:
         self.response_latencies.append(
@@ -156,6 +177,7 @@ def add_response_latency(self, value: float, response_id: str) -> None:
                 latency=max(0.0, value), model=self.model_name, response_id=response_id
             )
         )
+        self._notify_change()

     def add_token_usage(
         self,
@@ -201,6 +223,8 @@ def add_token_usage(
         else:
             self.accumulated_token_usage = self.accumulated_token_usage + new_usage

+        self._notify_change()
+
     def merge(self, other: "Metrics") -> None:
         """Merge 'other' metrics into this one."""
         self.accumulated_cost += other.accumulated_cost
@@ -221,6 +245,8 @@ def merge(self, other: "Metrics") -> None:
             self.accumulated_token_usage + other.accumulated_token_usage
         )

+        self._notify_change()
+
     def get(self) -> dict:
         """Return the metrics in a dictionary."""
         return {
diff --git a/tests/sdk/conversation/test_conversation_stats.py b/tests/sdk/conversation/test_conversation_stats.py
index 220046cf9..8e15cd08a 100644
--- a/tests/sdk/conversation/test_conversation_stats.py
+++ b/tests/sdk/conversation/test_conversation_stats.py
@@ -382,3 +382,34 @@ def test_service_shims_expose_usage_data(conversation_stats):

     restored = conversation_stats._restored_services
     assert "legacy-service" in restored
+
+
+def test_stats_change_callback_triggered():
+    """Test that stats change callback is triggered when new LLM is registered."""
+    stats = ConversationStats()
+    callback_called = []
+
+    def callback():
+        callback_called.append(True)
+
+    stats.set_on_stats_change(callback)
+
+    # Create and register a new LLM
+    with patch("openhands.sdk.llm.llm.litellm_completion"):
+        llm = LLM(
+            usage_id="test-service",
+            model="gpt-4o",
+            api_key=SecretStr("test_key"),
+            num_retries=2,
+            retry_min_wait=1,
+            retry_max_wait=2,
+        )
+        event = RegistryEvent(llm=llm)
+        stats.register_llm(event)
+
+    # Verify callback was called
+    assert len(callback_called) == 1
+
+    # Register the same LLM again - callback should not be called
+    stats.register_llm(event)
+    assert len(callback_called) == 1  # Still 1, not 2
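One consequence of notifying on every mutation is that a single agent step can fire the callback several times in quick succession (`add_cost` plus `add_token_usage` per LLM call, as the integration test below demonstrates), and each state-level notification serializes the full stats object. A hypothetical consumer-side debounce helper, plain stdlib and not part of this patch, that coalesces such bursts before, say, re-rendering a UI:

```python
import threading
from collections.abc import Callable


def debounce(fn: Callable[[], None], interval: float = 0.25) -> Callable[[], None]:
    """Collapse bursts of zero-arg calls within `interval` seconds into one
    trailing invocation of `fn` (delivered on a timer thread)."""
    timer: threading.Timer | None = None
    lock = threading.Lock()

    def wrapper() -> None:
        nonlocal timer
        with lock:
            # Restart the timer on every call; only the last one fires.
            if timer is not None:
                timer.cancel()
            timer = threading.Timer(interval, fn)
            timer.daemon = True
            timer.start()

    return wrapper


# e.g. stats.set_on_stats_change(debounce(push_stats_to_ui))
```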
diff --git a/tests/sdk/conversation/test_state_change_callback.py b/tests/sdk/conversation/test_state_change_callback.py
index 9443a9429..377f45428 100644
--- a/tests/sdk/conversation/test_state_change_callback.py
+++ b/tests/sdk/conversation/test_state_change_callback.py
@@ -1,11 +1,12 @@
 """Tests for ConversationState callback mechanism."""

 import uuid
+from unittest.mock import patch

 import pytest
 from pydantic import SecretStr

-from openhands.sdk import LLM, Agent
+from openhands.sdk import LLM, Agent, RegistryEvent
 from openhands.sdk.conversation.state import (
     ConversationExecutionStatus,
     ConversationState,
@@ -175,3 +176,34 @@ def callback(event: ConversationStateUpdateEvent):
     assert len(callback_calls) == 1
     assert callback_calls[0].key == "max_iterations"
     assert callback_calls[0].value == 100
+
+
+def test_stats_change_triggers_callback(state):
+    """Test that stats changes trigger the state change callback."""
+    callback_calls = []
+
+    def callback(event: ConversationStateUpdateEvent):
+        callback_calls.append(event)
+
+    # Set the callback - this also sets up stats callback
+    state.set_on_state_change(callback)
+
+    # Register a new LLM which will update stats
+    with patch("openhands.sdk.llm.llm.litellm_completion"):
+        llm = LLM(
+            usage_id="new-service",
+            model="gpt-4o",
+            api_key=SecretStr("test_key"),
+            num_retries=2,
+            retry_min_wait=1,
+            retry_max_wait=2,
+        )
+        event = RegistryEvent(llm=llm)
+        state.stats.register_llm(event)
+
+    # Verify callback was called for stats change
+    assert len(callback_calls) == 1
+    assert callback_calls[0].key == "stats"
+    assert isinstance(callback_calls[0].value, dict)
+    assert "usage_to_metrics" in callback_calls[0].value
+    assert "new-service" in callback_calls[0].value["usage_to_metrics"]
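The final assertions above pin down the wire format: the event's `value` is `stats.model_dump(mode="json")`. A sketch of what a `key="stats"` payload looks like, with values taken from the integration test below and unexercised `Metrics` fields elided (illustrative only; real dumps include every field):

```python
# Illustrative shape of a ConversationStateUpdateEvent(key="stats") value;
# real payloads also carry costs, response_latencies, token_usages, etc.
example_stats_payload = {
    "usage_to_metrics": {
        "test-service": {
            "model_name": "gpt-4o",
            "accumulated_cost": 0.05,
            "accumulated_token_usage": {
                "prompt_tokens": 500,
                "completion_tokens": 200,
                # cache_read_tokens, cache_write_tokens, context_window, ...
            },
        }
    }
}
```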
+ """ + state_change_events = [] + + def state_callback(event: Event): + if isinstance(event, ConversationStateUpdateEvent): + state_change_events.append(event) + + # Set up the state change callback + state.set_on_state_change(state_callback) + + # Register an LLM (simulating agent initialization) + with patch("openhands.sdk.llm.llm.litellm_completion"): + llm = LLM( + usage_id="test-service", + model="gpt-4o", + api_key=SecretStr("test_key"), + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + event = RegistryEvent(llm=llm) + state.stats.register_llm(event) + + # Clear the initial registration event + state_change_events.clear() + + # Simulate LLM usage during conversation execution + # This is what happens when the agent makes LLM calls + llm.metrics.add_cost(0.05) + llm.metrics.add_token_usage( + prompt_tokens=500, + completion_tokens=200, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=8000, + response_id="resp1", + ) + + # Verify that state change events were generated + assert len(state_change_events) == 2 # One for cost, one for token usage + + # Verify the events are stats updates + for event in state_change_events: + assert event.key == "stats" + assert isinstance(event.value, dict) + assert "usage_to_metrics" in event.value + + # Verify stats contain the updated costs + final_stats = state_change_events[-1].value + assert "test-service" in final_stats["usage_to_metrics"] + service_metrics = final_stats["usage_to_metrics"]["test-service"] + assert service_metrics["accumulated_cost"] == 0.05 + assert service_metrics["accumulated_token_usage"]["prompt_tokens"] == 500 + assert service_metrics["accumulated_token_usage"]["completion_tokens"] == 200 + + +def test_multiple_llms_metrics_updates_all_trigger_events(state): + """Test that metrics updates from multiple LLMs all trigger state events.""" + state_change_events = [] + + def state_callback(event: Event): + if isinstance(event, ConversationStateUpdateEvent): + state_change_events.append(event) + + state.set_on_state_change(state_callback) + + # Register two LLMs + with patch("openhands.sdk.llm.llm.litellm_completion"): + llm1 = LLM( + usage_id="service-1", + model="gpt-4o", + api_key=SecretStr("test_key"), + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + llm2 = LLM( + usage_id="service-2", + model="claude-3", + api_key=SecretStr("test_key"), + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + + state.stats.register_llm(RegistryEvent(llm=llm1)) + state.stats.register_llm(RegistryEvent(llm=llm2)) + + state_change_events.clear() + + # Simulate updates from both LLMs + llm1.metrics.add_cost(0.05) + llm2.metrics.add_cost(0.03) + + # Both updates should trigger events + assert len(state_change_events) == 2 + + # Both should be stats events + for event in state_change_events: + assert event.key == "stats" + + +def test_callback_removal_stops_stats_streaming(state): + """Test that removing the callback stops stats streaming.""" + state_change_events = [] + + def state_callback(event: Event): + if isinstance(event, ConversationStateUpdateEvent): + state_change_events.append(event) + + state.set_on_state_change(state_callback) + + with patch("openhands.sdk.llm.llm.litellm_completion"): + llm = LLM( + usage_id="test-service", + model="gpt-4o", + api_key=SecretStr("test_key"), + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + state.stats.register_llm(RegistryEvent(llm=llm)) + + state_change_events.clear() + + # Remove the callback + state.set_on_state_change(None) + + # Update 
metrics + llm.metrics.add_cost(0.05) + + # No events should be generated + assert len(state_change_events) == 0 diff --git a/tests/sdk/llm/test_metrics_callback.py b/tests/sdk/llm/test_metrics_callback.py new file mode 100644 index 000000000..bc9b49113 --- /dev/null +++ b/tests/sdk/llm/test_metrics_callback.py @@ -0,0 +1,140 @@ +"""Tests for Metrics callback mechanism.""" + +from openhands.sdk.llm.utils.metrics import Metrics + + +def test_metrics_change_callback_on_add_cost(): + """Test that callback is triggered when cost is added.""" + metrics = Metrics(model_name="gpt-4") + callback_calls = [] + + def callback(): + callback_calls.append(True) + + metrics.set_on_change(callback) + + # Add cost - should trigger callback + metrics.add_cost(0.05) + + assert len(callback_calls) == 1 + + +def test_metrics_change_callback_on_add_token_usage(): + """Test that callback is triggered when token usage is added.""" + metrics = Metrics(model_name="gpt-4") + callback_calls = [] + + def callback(): + callback_calls.append(True) + + metrics.set_on_change(callback) + + # Add token usage - should trigger callback + metrics.add_token_usage( + prompt_tokens=100, + completion_tokens=50, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=8000, + response_id="resp1", + ) + + assert len(callback_calls) == 1 + + +def test_metrics_change_callback_on_multiple_updates(): + """Test that callback is triggered for multiple updates.""" + metrics = Metrics(model_name="gpt-4") + callback_calls = [] + + def callback(): + callback_calls.append(True) + + metrics.set_on_change(callback) + + # Make multiple updates + metrics.add_cost(0.05) + metrics.add_token_usage( + prompt_tokens=100, + completion_tokens=50, + cache_read_tokens=0, + cache_write_tokens=0, + context_window=8000, + response_id="resp1", + ) + metrics.add_cost(0.02) + + assert len(callback_calls) == 3 + + +def test_metrics_callback_can_be_cleared(): + """Test that callback can be removed by setting to None.""" + metrics = Metrics(model_name="gpt-4") + callback_calls = [] + + def callback(): + callback_calls.append(True) + + # Set and then clear the callback + metrics.set_on_change(callback) + metrics.set_on_change(None) + + # Add cost - callback should not be called + metrics.add_cost(0.05) + + assert len(callback_calls) == 0 + + +def test_metrics_callback_exception_does_not_break_update(): + """Test that exceptions in callback don't prevent metrics updates.""" + + def bad_callback(): + raise ValueError("Callback error") + + metrics = Metrics(model_name="gpt-4") + metrics.set_on_change(bad_callback) + + # Add cost - should not raise despite callback error + metrics.add_cost(0.05) + + # Verify metric was still updated + assert metrics.accumulated_cost == 0.05 + + +def test_metrics_merge_triggers_callback(): + """Test that merge operation triggers callback.""" + metrics1 = Metrics(model_name="gpt-4") + metrics2 = Metrics(model_name="gpt-4") + callback_calls = [] + + def callback(): + callback_calls.append(True) + + metrics1.set_on_change(callback) + + # Add some costs to metrics2 + metrics2.add_cost(0.03) + + # Clear callback calls from previous operations + callback_calls.clear() + + # Merge - should trigger callback + metrics1.merge(metrics2) + + assert len(callback_calls) == 1 + + +def test_metrics_add_response_latency_triggers_callback(): + """Test that adding response latency triggers callback.""" + metrics = Metrics(model_name="gpt-4") + callback_calls = [] + + def callback(): + callback_calls.append(True) + + metrics.set_on_change(callback) + 
+ # Add response latency - should trigger callback + metrics.add_response_latency(1.5, "resp1") + + assert len(callback_calls) == 1
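Putting the pieces together, here is a quick smoke run of the streaming path, mirroring the integration-test fixture above. This is a sketch under the same assumptions the fixture makes (it pokes the private `_fs` and `_autosave_enabled` attributes, and the paths, keys, and model names are placeholders), not a supported setup recipe:

```python
import uuid

from pydantic import SecretStr

from openhands.sdk import LLM, Agent, RegistryEvent
from openhands.sdk.conversation.state import ConversationState
from openhands.sdk.event.conversation_state import ConversationStateUpdateEvent
from openhands.sdk.io import InMemoryFileStore
from openhands.sdk.workspace import LocalWorkspace

llm = LLM(model="gpt-4o", api_key=SecretStr("demo-key"), usage_id="demo-llm")
state = ConversationState(
    id=uuid.uuid4(),
    workspace=LocalWorkspace(working_dir="/tmp/demo"),
    persistence_dir="/tmp/demo/.state",
    agent=Agent(llm=llm),
)
state._fs = InMemoryFileStore()  # private attrs, as in the fixture above
state._autosave_enabled = True

seen: list[dict] = []


def on_change(event) -> None:
    # Collect only the stats updates introduced by this patch.
    if isinstance(event, ConversationStateUpdateEvent) and event.key == "stats":
        seen.append(event.value)


state.set_on_state_change(on_change)
state.stats.register_llm(RegistryEvent(llm=llm))  # ensure metrics are tracked/wired

llm.metrics.add_cost(0.01)  # Metrics -> ConversationStats -> state callback
print(seen[-1]["usage_to_metrics"]["demo-llm"]["accumulated_cost"])  # 0.01
```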