Skip to content

Commit

Permalink
Streaming fixes for tracking/hooks
Browse files Browse the repository at this point in the history
This is still a little messy -- we have some duplicated code for calling
out to hooks/incrementing sequence IDs. That said, it will work for now.
It is also tested.
  • Loading branch information
elijahbenizzy committed Mar 11, 2024
1 parent d9cd5dc commit 17ac40a
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 54 deletions.
4 changes: 4 additions & 0 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,10 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction":
def inputs(self) -> list[str]:
return _get_inputs(self._bound_params, self._fn)

@property
def fn(self) -> Callable:
return self._fn


def _validate_action_function(fn: Callable):
"""Validates that an action has the signature: (state: State) -> Tuple[dict, State]
Expand Down
18 changes: 18 additions & 0 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,15 @@ def streaming_response(state: State, prompt: str) -> Generator[dict, None, Tuple
# For context, this is specifically for the case in which you want to have
# multiple terminal points with a unified API, where some are streaming, and some are not.
if next_action.name in halt_before and next_action.name not in halt_after:
self._adapter_set.call_all_lifecycle_hooks_sync(
"post_run_step",
action=next_action,
state=self._state,
result=None,
sequence_id=self.sequence_id,
exception=None,
)
self._increment_sequence_id()
return next_action, StreamingResultContainer.pass_through(
results=results, final_state=state
)
Expand Down Expand Up @@ -781,6 +790,15 @@ def callback(
# In this case we are halting at a non-streaming condition
# This is allowed as we want to maintain a more consistent API
action, result, state = self._step(inputs=inputs, _run_hooks=False)
self._adapter_set.call_all_lifecycle_hooks_sync(
"post_run_step",
action=next_action,
state=self._state,
result=result,
sequence_id=self.sequence_id,
exception=None,
)
self._increment_sequence_id()
return action, StreamingResultContainer.pass_through(
results=result, final_state=state
)
Expand Down
4 changes: 2 additions & 2 deletions burr/tracking/common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import field_serializer

from burr.core import Action
from burr.core.action import FunctionBasedAction
from burr.core.action import FunctionBasedAction, FunctionBasedStreamingAction
from burr.core.application import ApplicationGraph, Transition
from burr.integrations.base import require_plugin

Expand Down Expand Up @@ -39,7 +39,7 @@ def from_action(action: Action) -> "ActionModel":
:param action: Action to create the model from
:return:
"""
if isinstance(action, FunctionBasedAction):
if isinstance(action, (FunctionBasedAction, FunctionBasedStreamingAction)):
code = inspect.getsource(action.fn)
else:
code = inspect.getsource(action.__class__)
Expand Down
113 changes: 61 additions & 52 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
internal,
)
from burr.lifecycle.base import PostApplicationCreateHook
from burr.lifecycle.internal import LifecycleAdapterSet


class PassedInAction(Action):
Expand Down Expand Up @@ -111,6 +112,32 @@ async def run(self, state: State, **run_kwargs) -> dict:
)


class ActionTracker(PreRunStepHook, PostRunStepHook):
def __init__(self):
self.pre_called = []
self.post_called = []

def pre_run_step(self, action: Action, **future_kwargs):
self.pre_called.append(action.name)

def post_run_step(self, action: Action, **future_kwargs):
self.post_called.append(action.name)


class ActionTrackerAsync(PreRunStepHookAsync, PostRunStepHookAsync):
def __init__(self):
self.pre_called = []
self.post_called = []

async def pre_run_step(self, *, action: Action, **future_kwargs):
await asyncio.sleep(0.0001)
self.pre_called.append(action.name)

async def post_run_step(self, *, action: Action, **future_kwargs):
await asyncio.sleep(0.0001)
self.post_called.append(action.name)


async def _counter_update_async(state: State, additional_increment: int = 0) -> dict:
await asyncio.sleep(0.0001) # just so we can make this *truly* async
# does not matter, but more accurately simulates an async function
Expand Down Expand Up @@ -857,6 +884,7 @@ async def test_app_a_run_async_and_sync():


def test_stream_result_halt_after():
action_tracker = ActionTracker()
counter_action = base_streaming_counter.with_name("counter")
counter_action_2 = base_streaming_counter.with_name("counter_2")
app = Application(
Expand All @@ -866,16 +894,22 @@ def test_stream_result_halt_after():
],
state=State({"count": 0}),
initial_step="counter",
adapter_set=LifecycleAdapterSet(action_tracker),
)
action, streaming_container = app.stream_result(halt_after=["counter_2"])
results = list(streaming_container)
assert len(results) == 10
result, state = streaming_container.get()
assert result["count"] == state["count"] == 2
assert state["tracker"] == [1, 2]
assert len(action_tracker.pre_called) == 2
assert len(action_tracker.post_called) == 2
assert set(action_tracker.pre_called) == {"counter", "counter_2"}
assert set(action_tracker.post_called) == {"counter", "counter_2"}


def test_stream_result_halt_after_single_step():
action_tracker = ActionTracker()
counter_action = base_streaming_single_step_counter.with_name("counter")
counter_action_2 = base_streaming_single_step_counter.with_name("counter_2")
app = Application(
Expand All @@ -885,18 +919,24 @@ def test_stream_result_halt_after_single_step():
],
state=State({"count": 0}),
initial_step="counter",
adapter_set=LifecycleAdapterSet(action_tracker),
)
action, streaming_container = app.stream_result(halt_after=["counter_2"])
results = list(streaming_container)
assert len(results) == 10
result, state = streaming_container.get()
assert result["count"] == state["count"] == 2
assert state["tracker"] == [1, 2]
assert len(action_tracker.pre_called) == 2
assert len(action_tracker.post_called) == 2
assert set(action_tracker.pre_called) == {"counter", "counter_2"}
assert set(action_tracker.post_called) == {"counter", "counter_2"}


def test_stream_result_halt_after_run_through_final_streaming():
"""Tests what happens when we have an app that runs through non-streaming
results before hitting a final streaming result specified by halt_after"""
action_tracker = ActionTracker()
counter_non_streaming = base_counter_action.with_name("counter_non_streaming")
counter_streaming = base_streaming_single_step_counter.with_name("counter_streaming")

Expand All @@ -908,15 +948,21 @@ def test_stream_result_halt_after_run_through_final_streaming():
],
state=State({"count": 0}),
initial_step="counter_non_streaming",
adapter_set=LifecycleAdapterSet(action_tracker),
)
action, streaming_container = app.stream_result(halt_after=["counter_streaming"])
results = list(streaming_container)
assert len(results) == 10
result, state = streaming_container.get()
assert result["count"] == state["count"] == 11
assert len(action_tracker.pre_called) == 11
assert len(action_tracker.post_called) == 11
assert set(action_tracker.pre_called) == {"counter_streaming", "counter_non_streaming"}
assert set(action_tracker.post_called) == {"counter_streaming", "counter_non_streaming"}


def test_stream_result_halt_after_run_through_final_non_streaming():
action_tracker = ActionTracker()
counter_non_streaming = base_counter_action.with_name("counter_non_streaming")
counter_final_non_streaming = base_counter_action.with_name("counter_final_non_streaming")

Expand All @@ -928,12 +974,23 @@ def test_stream_result_halt_after_run_through_final_non_streaming():
],
state=State({"count": 0}),
initial_step="counter_non_streaming",
adapter_set=LifecycleAdapterSet(action_tracker),
)
action, streaming_container = app.stream_result(halt_after=["counter_final_non_streaming"])
results = list(streaming_container)
assert len(results) == 0 # nothing to steram
result, state = streaming_container.get()
assert result["count"] == state["count"] == 11
assert len(action_tracker.pre_called) == 11
assert len(action_tracker.post_called) == 11
assert set(action_tracker.pre_called) == {
"counter_non_streaming",
"counter_final_non_streaming",
}
assert set(action_tracker.post_called) == {
"counter_non_streaming",
"counter_final_non_streaming",
}


def test_stream_result_halt_before():
Expand Down Expand Up @@ -1080,17 +1137,6 @@ def test_application_builder_unset():


def test_application_run_step_hooks_sync():
class ActionTracker(PreRunStepHook, PostRunStepHook):
def __init__(self):
self.pre_called = []
self.post_called = []

def pre_run_step(self, *, action: Action, **future_kwargs):
self.pre_called.append(action.name)

def post_run_step(self, *, action: Action, **future_kwargs):
self.post_called.append(action.name)

tracker = ActionTracker()
counter_action = base_counter_action.with_name("counter")
result_action = Result("count").with_name("result")
Expand All @@ -1114,19 +1160,6 @@ def post_run_step(self, *, action: Action, **future_kwargs):


async def test_application_run_step_hooks_async():
class ActionTrackerAsync(PreRunStepHookAsync, PostRunStepHookAsync):
def __init__(self):
self.pre_called = []
self.post_called = []

async def pre_run_step(self, *, action: Action, **future_kwargs):
await asyncio.sleep(0.0001)
self.pre_called.append(action.name)

async def post_run_step(self, *, action: Action, **future_kwargs):
await asyncio.sleep(0.0001)
self.post_called.append(action.name)

tracker = ActionTrackerAsync()
counter_action = base_counter_action.with_name("counter")
result_action = Result("count").with_name("result")
Expand All @@ -1148,30 +1181,6 @@ async def post_run_step(self, *, action: Action, **future_kwargs):


async def test_application_run_step_runs_hooks():
class ActionTracker(PreRunStepHook, PostRunStepHook):
def __init__(self):
self.pre_called_count = 0
self.post_called_count = 0

def pre_run_step(self, **future_kwargs):
self.pre_called_count += 1

def post_run_step(self, **future_kwargs):
self.post_called_count += 1

class ActionTrackerAsync(PreRunStepHookAsync, PostRunStepHookAsync):
def __init__(self):
self.pre_called_count = 0
self.post_called_count = 0

async def pre_run_step(self, **future_kwargs):
await asyncio.sleep(0.0001)
self.pre_called_count += 1

async def post_run_step(self, **future_kwargs):
await asyncio.sleep(0.0001)
self.post_called_count += 1

hooks = [ActionTracker(), ActionTrackerAsync()]

counter_action = base_counter_action.with_name("counter")
Expand All @@ -1185,10 +1194,10 @@ async def post_run_step(self, **future_kwargs):
adapter_set=internal.LifecycleAdapterSet(*hooks),
)
await app.astep()
assert hooks[0].pre_called_count == 1
assert hooks[0].post_called_count == 1
assert hooks[1].pre_called_count == 1
assert hooks[1].post_called_count == 1
assert len(hooks[0].pre_called) == 1
assert len(hooks[0].post_called) == 1
assert len(hooks[1].pre_called) == 1
assert len(hooks[1].post_called) == 1


def test_application_post_application_create_hook():
Expand Down

0 comments on commit 17ac40a

Please sign in to comment.