In [1]:
from __future__ import annotations as _annotations

from dataclasses import dataclass
from pathlib import Path

from pydantic_graph import BaseNode, End, Graph, GraphRunContext
from pydantic_graph.persistence.file import FileStatePersistence

%load_ext autoreload
%autoreload 2

In [2]:
@dataclass
class CountDownState:
    """Mutable state shared by the nodes of the count-down graph."""

    # Current counter value: CountUp adds one to it, CountDown subtracts one
    # per visit and ends the run once it is <= 0.
    counter: int


@dataclass
class CountDown(BaseNode[CountDownState, None, int]):
    """Graph node that lowers the shared counter by one on each visit."""

    async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
        """Take one count-down step.

        While the counter is still positive it is decremented in place and a
        new ``CountDown`` node is scheduled. Once the counter is zero or
        negative the run finishes, with the current counter value as the
        result.
        """
        shared = ctx.state
        if shared.counter > 0:
            shared.counter -= 1
            return CountDown()
        # Nothing left to count: finish and report the value we stopped at.
        return End(shared.counter)


@dataclass
class CountUp(BaseNode[CountDownState, None, int]):
    """Graph node that raises the shared counter by one, then hands over to ``CountDown``."""

    # The return annotation lists only ``CountDown`` because this node never
    # ends the run itself; pydantic-graph derives a node's outgoing edges from
    # this annotation, so advertising ``End[int]`` here would declare an edge
    # that cannot be taken.
    async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown:
        """Increment ``ctx.state.counter`` in place and continue with a ``CountDown`` node."""
        ctx.state.counter += 1
        return CountDown()


# Graph made of the two node classes defined above.
# NOTE(review): auto_instrument=False presumably turns off the library's
# automatic tracing of runs — confirm against the pydantic_graph docs.
count_down_graph = Graph(
    nodes=[CountDown, CountUp],
    auto_instrument=False,
)


In [6]:
# Label for this experiment; here it is only used to build the snapshot file name.
run_id = "run_2"
# File-backed persistence at count_down_<run_id>.json. The path is relative,
# so it resolves against the kernel's current working directory.
persistence = FileStatePersistence(Path(f"count_down_{run_id}.json"))
# Starting state. The nodes modify `counter` through ctx.state, so re-run this
# cell before re-running the graph when a fresh counter of 5 is wanted.
# NOTE(review): assumes the graph hands this same object to the nodes — confirm.
state = CountDownState(counter=5)
# Disabled alternative entry point, kept for reference.
# NOTE(review): the draft of this line said `Coun()`, which is not a defined
# node; `CountUp()` is assumed because the run cell starts from it — confirm.
# await count_down_graph.initialize(CountUp(), state=state, persistence=persistence)

In [10]:
# Step through the graph starting from CountUp and print every item the run yields.
# NOTE(review): `state` is the object built in the setup cell and the nodes
# change its counter, so re-running only this cell presumably continues from
# the leftover value instead of the initial one — confirm before relying on it.
async with count_down_graph.iter(CountUp(), state=state, persistence=persistence) as run:
    async for node in run:
        print(node)

CountDown()
CountDown()
End(data=0)


In [11]:
(await persistence.load_all())


[NodeSnapshot(state=CountDownState(counter=5), node=CountUp(), start_ts=datetime.datetime(2025, 3, 18, 5, 53, 33, 242302, tzinfo=TzInfo(UTC)), duration=6.062036845833063e-06, status='success', kind='node', id='CountUp:9787ab7269894ba99fbb86f04ef5a618'),
 NodeSnapshot(state=CountDownState(counter=6), node=CountDown(), start_ts=datetime.datetime(2025, 3, 18, 5, 53, 33, 244172, tzinfo=TzInfo(UTC)), duration=4.243047442287207e-06, status='success', kind='node', id='CountDown:14690e0fee4947eaa10d364591c758c1'),
 NodeSnapshot(state=CountDownState(counter=5), node=CountDown(), start_ts=datetime.datetime(2025, 3, 18, 5, 53, 33, 245710, tzinfo=TzInfo(UTC)), duration=3.29798785969615e-06, status='success', kind='node', id='CountDown:975a587e35284ee1815f62bef25e3662'),
 NodeSnapshot(state=CountDownState(counter=4), node=CountDown(), start_ts=datetime.datetime(2025, 3, 18, 5, 53, 33, 247361, tzinfo=TzInfo(UTC)), duration=2.5129993446171284e-06, status='success', kind='node', id='CountDown:448d33f6