diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py
index 857fed33..8a5d046e 100644
--- a/burr/core/parallelism.py
+++ b/burr/core/parallelism.py
@@ -220,12 +220,18 @@ def tasks(state: State, context: ApplicationContext) -> Generator[SubGraphTask,
             query_llm.bind(model="o1").with_name("o1_answer"),
             query_llm.bind(model="claude").with_name("claude_answer"),
         ]
+        # Route the application_id through self.sub_application_id so
+        # subclasses can customize the sub-app cache/resume behavior --
+        # e.g. salt with a fresh value per invocation to bypass a
+        # cascading state_initializer's checkpoint hits. The default
+        # hook is a stable deterministic hash and is what most users
+        # want (it's load-bearing for retry-on-failure resume).
+        key = f"{prompt}:{action.name}"  # any stable key you choose
         yield SubGraphTask(
             action=action,  # can be a RunnableGraph as well
             state=state.update(prompt=prompt),
             inputs={},
-            # stable hash -- up to you to ensure uniqueness
-            application_id=hashlib.sha256(context.application_id + action.name + prompt).hexdigest(),
+            application_id=self.sub_application_id(key, state, context),
             # a few other parameters we might add -- see advanced usage -- failure conditions, etc...
         )
@@ -316,6 +322,35 @@ def is_async(self) -> bool:
         """
         return False
 
+    def sub_application_id(self, key: str, state: State, context: ApplicationContext) -> str:
+        """Compute the application_id for a sub-task.
+
+        Default: deterministic hash of (parent_app_id, key) -- stable across parent
+        rebuilds, which is what enables sub-app checkpoint resume on crash recovery.
+        If the parent application is rebuilt (e.g. as part of retry-on-failure or
+        a resume-from-persistence flow), each sub-task gets the same id it had
+        before, so a cascading state initializer can find its prior checkpoint
+        and pick up where it left off.
+
+        Override to customize cache/resume behavior:
+
+        - Fresh execution per invocation: salt with something that advances
+          per-call (e.g. ``context.sequence_id``, a uuid, your own counter).
+          This is the workaround for `#761 `_,
+          where a cascading state initializer combined with deterministic sub-app
+          ids causes a parallel action to replay stale sub-app state on every
+          invocation instead of running fresh. Note that opting into per-invocation
+          ids gives up the resume-on-rebuild guarantee above.
+        - Pin to a business key: derive from ``state`` contents (e.g. a record
+          id) so re-runs against the same logical input reuse the same sub-app.
+
+        :param key: Per-task key (unique within this invocation of the parent action).
+        :param state: State that will be passed to the sub-task.
+        :param context: Parent application context.
+        :return: Application id to use for the sub-task.
+        """
+        return _stable_app_id_hash(context.app_id, key)
+
     @property
     def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:
         """Inputs from this -- if you want to override you'll want to call super()
@@ -509,7 +544,7 @@ def _create_task(key: str, action: Action, substate: State) -> SubGraphTask:
             graph=RunnableGraph.create(action),
             inputs=inputs,
             state=substate,
-            application_id=_stable_app_id_hash(context.app_id, key),
+            application_id=self.sub_application_id(key, substate, context),
             tracker=tracker,
             state_persister=state_persister,
             state_initializer=state_initializer,
@@ -518,7 +553,10 @@ def _tasks() -> Generator[SubGraphTask, None, None]:
            for i, action in enumerate(self.actions(state, context, inputs)):
                for j, substate in enumerate(self.states(state, context, inputs)):
-                    key = f"{i}-{j}"  # this is a stable hash for now but will not handle caching
+                    # Per-task key is stable across rebuilds. The actual sub-app id is
+                    # computed via ``sub_application_id``; override that hook to opt
+                    # into per-invocation ids (see issue #761).
+                    key = f"{i}-{j}"
                    yield _create_task(key, action, substate)
 
        async def _atasks() -> AsyncGenerator[SubGraphTask, None]:
diff --git a/tests/core/test_parallelism.py b/tests/core/test_parallelism.py
index 25d37cc2..3931eeac 100644
--- a/tests/core/test_parallelism.py
+++ b/tests/core/test_parallelism.py
@@ -45,7 +45,12 @@
     _cascade_adapter,
     map_reduce_action,
 )
-from burr.core.persistence import BaseStateLoader, BaseStateSaver, PersistedStateData
+from burr.core.persistence import (
+    BaseStateLoader,
+    BaseStateSaver,
+    InMemoryPersister,
+    PersistedStateData,
+)
 from burr.tracking.base import SyncTrackingClient
 from burr.visibility import ActionSpan
@@ -1227,3 +1232,90 @@ def reads(self) -> list[str]:
     assert task.state_initializer is not None
     assert task.tracker is not None
     assert task.state_persister is task.state_initializer  # This ensures they're the same
+
+
+def test_sub_application_id_override_enables_fresh_execution_with_cascading_initializer():
+    """Regression test for #761.
+
+    With a cascading state initializer + the default deterministic sub-app id,
+    re-invoking the parallel action on the same parent reuses prior sub-app
+    state via the initializer, so per-invocation work does not actually re-run.
+    Overriding ``sub_application_id`` to salt with a per-invocation value
+    restores fresh execution while leaving the default (resume-on-rebuild)
+    behavior alone for everyone else.
+    """
+    invocation_count = {"n": 0}
+    shared_persister = InMemoryPersister()
+
+    @old_action(reads=["input_number"], writes=["output_number", "invocation"])
+    def record_invocation(state: State) -> State:
+        invocation_count["n"] += 1
+        return state.update(output_number=state["input_number"], invocation=invocation_count["n"])
+
+    class SaltedMapStates(MapStates):
+        call_index = 0
+
+        def states(
+            self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
+        ) -> Generator[State, None, None]:
+            for input_number in state["input_numbers_in_state"]:
+                yield state.update(input_number=input_number)
+
+        def action(
+            self, state: State, inputs: Dict[str, Any]
+        ) -> Union[Action, Callable, RunnableGraph]:
+            return record_invocation
+
+        def reduce(self, state: State, states: Generator[State, None, None]) -> State:
+            return state.update(invocations=[output_state["invocation"] for output_state in states])
+
+        # Pin sub-app persistence to the shared persister. This mirrors the
+        # #761 setup where a cascading initializer makes sub-apps resume.
+        def state_initializer(self, **kwargs):
+            return shared_persister
+
+        def state_persister(self, **kwargs):
+            return shared_persister
+
+        def sub_application_id(self, key: str, state: State, context: ApplicationContext) -> str:
+            # Per-invocation salt -- each top-level run gets fresh sub-app ids.
+            return f"{context.app_id}:{key}:call-{type(self).call_index}"
+
+        @property
+        def writes(self) -> list[str]:
+            return ["invocations"]
+
+        @property
+        def reads(self) -> list[str]:
+            return ["input_numbers_in_state"]
+
+    def _build_and_run():
+        app = (
+            ApplicationBuilder()
+            .with_actions(
+                initial=Input("input_numbers_in_state"),
+                map_action=SaltedMapStates(),
+                final=Result("invocations"),
+            )
+            .with_transitions(("initial", "map_action"), ("map_action", "final"))
+            .with_entrypoint("initial")
+            .with_identifiers(app_id="parent-app-761")
+            .build()
+        )
+        _, _, state = app.run(halt_after=["final"], inputs={"input_numbers_in_state": [1, 2, 3]})
+        return state
+
+    # Three independent parent invocations against the same parent app_id and
+    # the same sub-app persister. With the override, every sub-task should
+    # actually execute on every run (no stale-replay caching).
+    SaltedMapStates.call_index = 0
+    _build_and_run()
+    SaltedMapStates.call_index = 1
+    _build_and_run()
+    SaltedMapStates.call_index = 2
+    final_state = _build_and_run()
+
+    # 3 inputs * 3 invocations = 9 actual executions.
+    assert invocation_count["n"] == 9
+    # The latest run's invocations all come from the most recent counter range.
+    assert all(inv > 6 for inv in final_state["invocations"])
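
Usage sketch (illustrative, assuming the hook lands as in the patch above): a downstream MapStates subclass can opt into per-invocation sub-app ids by overriding the new hook and salting with ``context.sequence_id``, the value the docstring suggests. The class name below is hypothetical, and note that salting per invocation gives up the default's resume-on-rebuild behavior.

from burr.core import ApplicationContext, State
from burr.core.parallelism import MapStates


class FreshRunMapStates(MapStates):
    # Hypothetical subclass name -- shown only to illustrate the override.
    # Salting with the parent's sequence_id means every invocation of the
    # parallel action produces new sub-app ids, so a cascading state
    # initializer finds no prior checkpoint and each sub-app runs fresh
    # (the #761 scenario).
    def sub_application_id(self, key: str, state: State, context: ApplicationContext) -> str:
        return f"{context.app_id}:{context.sequence_id}:{key}"

    # states()/action()/reduce()/reads/writes are implemented as in any
    # other MapStates subclass; they are omitted from this sketch.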