Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions openhands-sdk/openhands/sdk/conversation/conversation_stats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from collections.abc import Callable

from pydantic import AliasChoices, BaseModel, Field, PrivateAttr

Expand Down Expand Up @@ -29,6 +30,7 @@ class ConversationStats(BaseModel):
)

_restored_usage_ids: set[str] = PrivateAttr(default_factory=set)
_on_stats_change: Callable[[], None] | None = PrivateAttr(default=None)

@property
def service_to_metrics(
Expand Down Expand Up @@ -83,11 +85,25 @@ def get_metrics_for_service(
)
return self.get_metrics_for_usage(service_id)

def set_on_stats_change(self, callback: Callable[[], None] | None) -> None:
"""Set a callback to be called when stats change.

Args:
callback: A function to call when stats are updated, or None to remove
"""
self._on_stats_change = callback

# Wire up the callback to all already-registered metrics
for metrics in self.usage_to_metrics.values():
metrics.set_on_change(callback)

def register_llm(self, event: RegistryEvent):
# Listen for LLM creations and track their metrics
llm = event.llm
usage_id = llm.usage_id

stats_changed = False

# Usage costs exist but have not been restored yet
if (
usage_id in self.usage_to_metrics
Expand All @@ -99,3 +115,15 @@ def register_llm(self, event: RegistryEvent):
# Usage is new, track its metrics
if usage_id not in self.usage_to_metrics and llm.metrics:
self.usage_to_metrics[usage_id] = llm.metrics
stats_changed = True

# Set up callback on the metrics object to get notified of updates
if llm.metrics and self._on_stats_change is not None:
llm.metrics.set_on_change(self._on_stats_change)

# Notify of stats change if callback is set and stats changed
if stats_changed and self._on_stats_change is not None:
try:
self._on_stats_change()
except Exception:
logger.exception("Stats change callback failed", exc_info=True)
22 changes: 22 additions & 0 deletions openhands-sdk/openhands/sdk/conversation/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,28 @@ def set_on_state_change(self, callback: ConversationCallbackType | None) -> None
or None to remove the callback
"""
self._on_state_change = callback
# Also set up stats change callback to notify when stats are mutated
if callback is not None:
self.stats.set_on_stats_change(self._notify_stats_change)
else:
self.stats.set_on_stats_change(None)

def _notify_stats_change(self) -> None:
    """Forward a stats update to the registered state-change callback.

    Does nothing when no state-change callback is set. Otherwise the current
    ``self.stats`` is dumped in JSON mode and delivered to the callback,
    wrapped in a ``ConversationStateUpdateEvent`` whose key is ``"stats"``.
    Any exception raised while importing, building or delivering the event
    is logged and suppressed, so the caller is never interrupted.
    """
    # Snapshot once so the None check and the call use the same object
    callback = self._on_state_change
    if callback is None:
        return

    try:
        # Imported inside the try block, so an import failure is logged too
        from openhands.sdk.event.conversation_state import (
            ConversationStateUpdateEvent,
        )

        stats_data = self.stats.model_dump(mode="json")
        callback(ConversationStateUpdateEvent(key="stats", value=stats_data))
    except Exception:
        # logger.exception already attaches the active traceback
        logger.exception("Stats change notification failed")

# ===== Base snapshot helpers (same FileStore usage you had) =====
def _save_base_state(self, fs: FileStore) -> None:
Expand Down
28 changes: 27 additions & 1 deletion openhands-sdk/openhands/sdk/llm/utils/metrics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import copy
import time
from collections.abc import Callable
from typing import final

from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator


class Cost(BaseModel):
Expand Down Expand Up @@ -111,6 +112,8 @@ class Metrics(MetricsSnapshot):
default_factory=list, description="List of token usage records"
)

_on_change: Callable[[], None] | None = PrivateAttr(default=None)

@field_validator("accumulated_cost")
@classmethod
def validate_accumulated_cost(cls, v: float) -> float:
Expand All @@ -133,6 +136,23 @@ def initialize_accumulated_token_usage(self) -> "Metrics":
)
return self

def set_on_change(self, callback: Callable[[], None] | None) -> None:
"""Set a callback to be called when metrics change.

Args:
callback: A function to call when metrics are updated, or None to remove
"""
self._on_change = callback

def _notify_change(self) -> None:
    """Invoke the registered change callback, if there is one.

    A failing callback never propagates to the caller: the exception is
    logged with its traceback and then suppressed, so a metrics update is
    not aborted by a faulty listener.
    """
    callback = self._on_change
    if callback is None:
        return

    try:
        callback()
    except Exception:
        # Stay best-effort, but leave a trace instead of hiding the failure
        import logging

        logging.getLogger(__name__).exception("Metrics change callback failed")

def get_snapshot(self) -> MetricsSnapshot:
"""Get a snapshot of the current metrics without the detailed lists."""
return MetricsSnapshot(
Expand All @@ -149,13 +169,15 @@ def add_cost(self, value: float) -> None:
raise ValueError("Added cost cannot be negative.")
self.accumulated_cost += value
self.costs.append(Cost(cost=value, model=self.model_name))
self._notify_change()

def add_response_latency(self, value: float, response_id: str) -> None:
    """Append one response latency record and signal that metrics changed.

    Args:
        value: Measured latency; negative values are stored as 0.0.
        response_id: Identifier of the response the latency belongs to.
    """
    clamped = max(0.0, value)
    record = ResponseLatency(
        latency=clamped, model=self.model_name, response_id=response_id
    )
    self.response_latencies.append(record)
    self._notify_change()

def add_token_usage(
self,
Expand Down Expand Up @@ -201,6 +223,8 @@ def add_token_usage(
else:
self.accumulated_token_usage = self.accumulated_token_usage + new_usage

self._notify_change()

def merge(self, other: "Metrics") -> None:
"""Merge 'other' metrics into this one."""
self.accumulated_cost += other.accumulated_cost
Expand All @@ -221,6 +245,8 @@ def merge(self, other: "Metrics") -> None:
self.accumulated_token_usage + other.accumulated_token_usage
)

self._notify_change()

def get(self) -> dict:
"""Return the metrics in a dictionary."""
return {
Expand Down
31 changes: 31 additions & 0 deletions tests/sdk/conversation/test_conversation_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,34 @@ def test_service_shims_expose_usage_data(conversation_stats):
restored = conversation_stats._restored_services

assert "legacy-service" in restored


def test_stats_change_callback_triggered():
    """Registering a new LLM fires the stats callback exactly once."""
    stats = ConversationStats()
    invocations = []
    stats.set_on_stats_change(lambda: invocations.append(True))

    # Build a fresh LLM and register it with the stats tracker
    with patch("openhands.sdk.llm.llm.litellm_completion"):
        new_llm = LLM(
            usage_id="test-service",
            model="gpt-4o",
            api_key=SecretStr("test_key"),
            num_retries=2,
            retry_min_wait=1,
            retry_max_wait=2,
        )
        registration = RegistryEvent(llm=new_llm)
        stats.register_llm(registration)

    # The first registration must have notified the listener once
    assert len(invocations) == 1

    # A repeated registration of the same LLM must stay silent
    stats.register_llm(registration)
    assert len(invocations) == 1  # unchanged: still 1, not 2
34 changes: 33 additions & 1 deletion tests/sdk/conversation/test_state_change_callback.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Tests for ConversationState callback mechanism."""

import uuid
from unittest.mock import patch

import pytest
from pydantic import SecretStr

from openhands.sdk import LLM, Agent
from openhands.sdk import LLM, Agent, RegistryEvent
from openhands.sdk.conversation.state import (
ConversationExecutionStatus,
ConversationState,
Expand Down Expand Up @@ -175,3 +176,34 @@ def callback(event: ConversationStateUpdateEvent):
assert len(callback_calls) == 1
assert callback_calls[0].key == "max_iterations"
assert callback_calls[0].value == 100


def test_stats_change_triggers_callback(state):
    """A stats change is delivered through the state change callback."""
    received = []

    def on_update(event: ConversationStateUpdateEvent):
        received.append(event)

    # Installing the state callback also wires up the stats callback
    state.set_on_state_change(on_update)

    # Registering a previously unseen LLM updates the stats
    with patch("openhands.sdk.llm.llm.litellm_completion"):
        new_llm = LLM(
            usage_id="new-service",
            model="gpt-4o",
            api_key=SecretStr("test_key"),
            num_retries=2,
            retry_min_wait=1,
            retry_max_wait=2,
        )
        registration = RegistryEvent(llm=new_llm)
        state.stats.register_llm(registration)

    # Exactly one "stats" event carrying the new usage entry is expected
    assert len(received) == 1
    stats_event = received[0]
    assert stats_event.key == "stats"
    assert isinstance(stats_event.value, dict)
    assert "usage_to_metrics" in stats_event.value
    assert "new-service" in stats_event.value["usage_to_metrics"]
Loading
Loading