46 changes: 42 additions & 4 deletions burr/core/parallelism.py
@@ -220,12 +220,18 @@ def tasks(state: State, context: ApplicationContext) -> Generator[SubGraphTask,
query_llm.bind(model="o1").with_name("o1_answer"),
query_llm.bind(model="claude").with_name("claude_answer"),
]
# Route the application_id through self.sub_application_id so
# subclasses can customize sub-app cache/resume behavior --
# e.g. salt it with a fresh value per invocation so a cascading
# state_initializer never hits a prior checkpoint. The default
# hook is a stable deterministic hash, which is what most users
# want (it's load-bearing for retry-on-failure resume).
key = f"{prompt}:{action.name}" # any stable key you choose
yield SubGraphTask(
action=action, # can be a RunnableGraph as well
state=state.update(prompt=prompt),
inputs={},
# stable hash -- up to you to ensure uniqueness
application_id=hashlib.sha256(context.application_id + action.name + prompt).hexdigest(),
application_id=self.sub_application_id(key, state, context),
# a few other parameters we might add -- see advanced usage -- failure conditions, etc...
)

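For example, a minimal sketch of the per-invocation salt described in the comment above (the subclass name ``FreshEachRun`` is hypothetical, and it assumes the usual ``MapStates`` hooks are implemented elsewhere):

import uuid

from burr.core import ApplicationContext, State
from burr.core.parallelism import MapStates


class FreshEachRun(MapStates):
    def sub_application_id(self, key: str, state: State, context: ApplicationContext) -> str:
        # Salt with a fresh uuid so a cascading state_initializer never
        # finds a prior checkpoint -- every invocation runs from scratch.
        # This deliberately gives up the default resume-on-rebuild guarantee.
        return f"{context.app_id}:{key}:{uuid.uuid4()}"

    # states/action/reduce/reads/writes implemented as usual
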
@@ -316,6 +322,35 @@ def is_async(self) -> bool:
"""
return False

def sub_application_id(self, key: str, state: State, context: ApplicationContext) -> str:
"""Compute the application_id for a sub-task.

Default: deterministic hash of (parent_app_id, key) -- stable across parent
rebuilds, which is what enables sub-app checkpoint resume on crash recovery.
If the parent application is rebuilt (e.g. as part of retry-on-failure or
a resume-from-persistence flow), each sub-task gets the same id it had
before, so a cascading state initializer can find its prior checkpoint
and pick up where it left off.

Override to customize cache/resume behavior:

- Fresh execution per invocation: salt with something that advances
per-call (e.g. ``context.sequence_id``, a uuid, your own counter).
This is the workaround for `#761 <https://github.com/apache/burr/issues/761>`_,
where a cascading state initializer combined with deterministic sub-app
ids causes a parallel action to replay stale sub-app state on every
invocation instead of running fresh. Note that opting into per-invocation
ids gives up the resume-on-rebuild guarantee above.
- Pin to a business key: derive from ``state`` contents (e.g. a record
id) so re-runs against the same logical input reuse the same sub-app.

:param key: Per-task key (unique within this invocation of the parent action).
:param state: State that will be passed to the sub-task.
:param context: Parent application context.
:return: Application id to use for the sub-task.
"""
return _stable_app_id_hash(context.app_id, key)

@property
def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:
"""Inputs from this -- if you want to override you'll want to call super()
@@ -509,7 +544,7 @@ def _create_task(key: str, action: Action, substate: State) -> SubGraphTask:
graph=RunnableGraph.create(action),
inputs=inputs,
state=substate,
application_id=_stable_app_id_hash(context.app_id, key),
application_id=self.sub_application_id(key, substate, context),
tracker=tracker,
state_persister=state_persister,
state_initializer=state_initializer,
@@ -518,7 +553,10 @@ def _tasks() -> Generator[SubGraphTask, None, None]:
def _tasks() -> Generator[SubGraphTask, None, None]:
for i, action in enumerate(self.actions(state, context, inputs)):
for j, substate in enumerate(self.states(state, context, inputs)):
key = f"{i}-{j}" # this is a stable hash for now but will not handle caching
# Per-task key is stable across rebuilds. The actual sub-app id is
# computed via ``sub_application_id``; override that hook to opt
# into per-invocation ids (see issue #761).
key = f"{i}-{j}"
yield _create_task(key, action, substate)

async def _atasks() -> AsyncGenerator[SubGraphTask, None]:
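The docstring's other override pattern -- pinning to a business key -- looks like this as a sketch (the ``record_id`` state field is hypothetical, and ``_stable_app_id_hash`` is the diff's private helper, reused here purely for illustration):

from burr.core import ApplicationContext, State
from burr.core.parallelism import MapStates, _stable_app_id_hash


class PinnedToRecord(MapStates):
    def sub_application_id(self, key: str, state: State, context: ApplicationContext) -> str:
        # Any parent run that maps over the same record resumes the same
        # sub-app, regardless of the parent's own application_id.
        return _stable_app_id_hash("record", str(state["record_id"]))

    # states/action/reduce/reads/writes implemented as usual
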
94 changes: 93 additions & 1 deletion tests/core/test_parallelism.py
@@ -45,7 +45,12 @@
_cascade_adapter,
map_reduce_action,
)
from burr.core.persistence import BaseStateLoader, BaseStateSaver, PersistedStateData
from burr.core.persistence import (
BaseStateLoader,
BaseStateSaver,
InMemoryPersister,
PersistedStateData,
)
from burr.tracking.base import SyncTrackingClient
from burr.visibility import ActionSpan

@@ -1227,3 +1232,90 @@ def reads(self) -> list[str]:
assert task.state_initializer is not None
assert task.tracker is not None
assert task.state_persister is task.state_initializer # This ensures they're the same


def test_sub_application_id_override_enables_fresh_execution_with_cascading_initializer():
"""Regression test for #761.

With a cascading state initializer + the default deterministic sub-app id,
re-invoking the parallel action on the same parent reuses prior sub-app
state via the initializer, so per-invocation work does not actually re-run.
Overriding ``sub_application_id`` to salt with a per-invocation value
restores fresh execution while leaving the default (resume-on-rebuild)
behavior alone for everyone else.
"""
invocation_count = {"n": 0}
shared_persister = InMemoryPersister()

@old_action(reads=["input_number"], writes=["output_number", "invocation"])
def record_invocation(state: State) -> State:
invocation_count["n"] += 1
return state.update(output_number=state["input_number"], invocation=invocation_count["n"])

class SaltedMapStates(MapStates):
call_index = 0

def states(
self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
) -> Generator[State, None, None]:
for input_number in state["input_numbers_in_state"]:
yield state.update(input_number=input_number)

def action(
self, state: State, inputs: Dict[str, Any]
) -> Union[Action, Callable, RunnableGraph]:
return record_invocation

def reduce(self, state: State, states: Generator[State, None, None]) -> State:
return state.update(invocations=[output_state["invocation"] for output_state in states])

# Pin sub-app persistence to the shared persister. This mirrors the
# #761 setup where a cascading initializer makes sub-apps resume.
def state_initializer(self, **kwargs):
return shared_persister

def state_persister(self, **kwargs):
return shared_persister

def sub_application_id(self, key: str, state: State, context: ApplicationContext) -> str:
# Per-invocation salt -- each top-level run gets fresh sub-app ids.
return f"{context.app_id}:{key}:call-{type(self).call_index}"

@property
def writes(self) -> list[str]:
return ["invocations"]

@property
def reads(self) -> list[str]:
return ["input_numbers_in_state"]

def _build_and_run():
app = (
ApplicationBuilder()
.with_actions(
initial=Input("input_numbers_in_state"),
map_action=SaltedMapStates(),
final=Result("invocations"),
)
.with_transitions(("initial", "map_action"), ("map_action", "final"))
.with_entrypoint("initial")
.with_identifiers(app_id="parent-app-761")
.build()
)
_, _, state = app.run(halt_after=["final"], inputs={"input_numbers_in_state": [1, 2, 3]})
return state

# Three independent parent invocations against the same parent app_id and
# the same sub-app persister. With the override, every sub-task should
# actually execute on every run (no stale-replay caching).
SaltedMapStates.call_index = 0
_build_and_run()
SaltedMapStates.call_index = 1
_build_and_run()
SaltedMapStates.call_index = 2
final_state = _build_and_run()

# 3 inputs * 3 invocations = 9 actual executions.
assert invocation_count["n"] == 9
# The latest run's invocations all come from the most recent counter range.
assert all(inv > 6 for inv in final_state["invocations"])
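For contrast with the salted test above, a quick sketch (not part of this diff) of the default determinism guarantee -- it assumes ``_stable_app_id_hash`` stays importable from ``burr.core.parallelism``:

from burr.core.parallelism import _stable_app_id_hash


def test_default_sub_app_id_is_deterministic():
    # Same (parent_app_id, key) pair -> same sub-app id across rebuilds,
    # per the resume-on-rebuild contract in sub_application_id's docstring.
    assert _stable_app_id_hash("parent-app-761", "0-0") == _stable_app_id_hash("parent-app-761", "0-0")
    # Distinct keys must map to distinct sub-app ids.
    assert _stable_app_id_hash("parent-app-761", "0-1") != _stable_app_id_hash("parent-app-761", "0-0")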