Skip to content

Commit

Permalink
Adds testing for app forking/tracking
Browse files Browse the repository at this point in the history
This tests that we store + serialize parent pointers. There are a few
edge-cases we should probably handle (nulling out/tracking the sequence
ID, for example), but this is the 80/20 for now.
  • Loading branch information
elijahbenizzy committed May 18, 2024
1 parent 075c257 commit b780f16
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
1 change: 1 addition & 0 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -1964,3 +1964,4 @@ def tests_application_builder_initialize_fork_app_id_happy_pth():
)
assert app.uid != old_app_id
assert app.state == State({"count": 5, "__PRIOR_STEP": "counter", "__SEQUENCE_ID": 5})
assert app.parent_pointer.app_id == old_app_id
75 changes: 73 additions & 2 deletions tests/tracking/test_local_tracking_client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
import os
import uuid
from typing import Tuple
from typing import Literal, Optional, Tuple

import pytest

import burr
from burr.core import Result, State, action, default, expr
from burr.core import Application, ApplicationBuilder, Result, State, action, default, expr
from burr.core.persistence import BaseStatePersister, PersistedStateData
from burr.tracking import LocalTrackingClient
from burr.tracking.client import _allowed_project_name
from burr.tracking.common.models import (
ApplicationMetadataModel,
ApplicationModel,
BeginEntryModel,
BeginSpanModel,
Expand Down Expand Up @@ -134,3 +136,72 @@ def test_application_tracks_end_to_end_broken(tmpdir: str):
)
def test__allowed_project_name(input_string, on_windows, expected_result):
assert _allowed_project_name(input_string, on_windows) == expected_result


class DummyPersister(BaseStatePersister):
"""Dummy persistor."""

def load(
self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs
) -> Optional[PersistedStateData]:
return PersistedStateData(
partition_key="user123",
app_id="123",
sequence_id=5,
position="counter",
state=State({"count": 5}),
created_at="",
status="completed",
)

def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
return ["123"]

def save(
self,
partition_key: Optional[str],
app_id: str,
sequence_id: int,
position: str,
state: State,
status: Literal["completed", "failed"],
**kwargs,
):
return


def test_persister_tracks_parent(tmpdir):
result = Result("count").with_name("result")
old_app_id = "old"
new_app_id = "new"
log_dir = os.path.join(tmpdir, "tracking")
results_dir = os.path.join(log_dir, "test_persister_tracks_parent", new_app_id)
project_name = "test_persister_tracks_parent"
app: Application = (
ApplicationBuilder()
.with_actions(counter, result)
.with_transitions(("counter", "result", default))
.initialize_from(
DummyPersister(),
resume_at_next_action=True,
default_state={},
default_entrypoint="counter",
fork_from_app_id=old_app_id,
fork_from_partition_key="user123",
fork_from_sequence_id=5,
)
.with_identifiers(app_id=new_app_id, partition_key="user123")
.with_tracker(project=project_name, tracker="local", params={"storage_dir": log_dir})
.build()
)
app.run(halt_after=["result"])
assert os.path.exists(
graph_output := os.path.join(results_dir, LocalTrackingClient.METADATA_FILENAME)
)
with open(graph_output) as f:
metadata = json.load(f)
metadata_parsed = ApplicationMetadataModel.model_validate(metadata)
assert metadata_parsed.partition_key == "user123"
assert metadata_parsed.parent_pointer.app_id == old_app_id
assert metadata_parsed.parent_pointer.sequence_id == 5
assert metadata_parsed.parent_pointer.partition_key == "user123"

0 comments on commit b780f16

Please sign in to comment.