Skip to content

Commit

Permalink
Updates chatbot in-app demo to use new persistence capability
Browse files Browse the repository at this point in the history
This also chagnes the tracker to enable that.
  • Loading branch information
elijahbenizzy committed Mar 19, 2024
1 parent ba7cb8a commit 2eca07d
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 13 deletions.
55 changes: 53 additions & 2 deletions burr/tracking/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(

self.f = None
self.storage_dir = LocalTrackingClient.get_storage_path(project, storage_dir)
self.project_id = project

@classmethod
def get_storage_path(cls, project, storage_dir):
Expand Down Expand Up @@ -118,7 +119,9 @@ def load_state(
sequence_id: int = -1,
storage_dir: str = DEFAULT_STORAGE_DIR,
) -> tuple[dict, str]:
"""Function to load state from what the tracking client got.
"""THis is deprecated and will be removed when we migrate over demos. Do not use! Instead use
the persistence API :py:class:`initialize_from <burr.core.application.ApplicationBuilder.initialize_from>`
to load state.
It defaults to loading the last state, but you can supply a sequence number.
Expand Down Expand Up @@ -293,7 +296,55 @@ def load(
self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None
) -> Optional[PersistedStateData]:
# TODO:
pass
if app_id is None:
return # no application ID
if sequence_id is None:
sequence_id = -1 # get the last one
path = os.path.join(self.storage_dir, app_id, self.LOG_FILENAME)
if not os.path.exists(path):
raise ValueError(
f"No logs found for {self.project_id}/{app_id} under {self.storage_dir}"
)
with open(path, "r") as f:
json_lines = f.readlines()
# load as JSON
json_lines = [json.loads(js_line) for js_line in json_lines]
# filter to only end_entry
json_lines = [js_line for js_line in json_lines if js_line["type"] == "end_entry"]
try:
line = json_lines[sequence_id]
except IndexError:
raise ValueError(
f"Sequence number {sequence_id} not found for {self.project_id}/{app_id}."
)
# check sequence number matches if non-negative; will break if either is None.
line_seq = int(line["sequence_id"])
if -1 < sequence_id != line_seq:
logger.warning(
f"Sequence number mismatch. For {self.project_id}/{app_id}: "
f"actual:{line_seq} != expected:{sequence_id}"
)
# get the prior state
prior_state = line["state"]
position = line["action"]
# delete internally stuff. We can't loop over the keys and delete them in the same loop
to_delete = []
for key in prior_state.keys():
# remove any internal "__" state
if key.startswith("__"):
to_delete.append(key)
for key in to_delete:
del prior_state[key]
prior_state["__SEQUENCE_ID"] = line_seq # add the sequence id back
return {
"partition_key": partition_key,
"app_id": app_id,
"sequence_id": line_seq,
"position": position,
"state": State(prior_state),
"created_at": datetime.datetime.fromtimestamp(os.path.getctime(path)).isoformat(),
"status": "success" if line["exception"] is None else "failed",
}


# TODO -- implement async version
Expand Down
10 changes: 2 additions & 8 deletions burr/tracking/server/examples/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from fastapi import FastAPI
from starlette.requests import Request

from burr.core import Application, State
from burr.core import Application
from burr.examples.gpt import application as chat_application
from burr.tracking import LocalTrackingClient


class ChatItem(pydantic.BaseModel):
Expand All @@ -18,12 +17,7 @@ class ChatItem(pydantic.BaseModel):

@functools.lru_cache(maxsize=128)
def _get_application(project_id: str, app_id: str) -> Application:
app = chat_application.application(use_hamilton=False, app_id=app_id, project_id="demo:chatbot")
if LocalTrackingClient.app_log_exists(project_id, app_id):
state, _ = LocalTrackingClient.load_state(project_id, app_id) # TODO -- handle entrypoint
app.update_state(
State(state)
) # TODO -- handle the entrypoint -- this will always reset to prompt
app = chat_application.application(use_hamilton=False, app_id=app_id, project_id=project_id)
return app


Expand Down
15 changes: 12 additions & 3 deletions examples/gpt/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from burr.core.action import action
from burr.integrations.hamilton import Hamilton, append_state, from_state, update_state
from burr.lifecycle import LifecycleAdapter
from burr.tracking import LocalTrackingClient

MODES = {
"answer_question": "text",
Expand Down Expand Up @@ -149,6 +150,9 @@ def base_application(
):
if hooks is None:
hooks = []
# we're initializing above so we can load from this as well
# we could also use `with_tracker("local", project=project_id, params={"storage_dir": storage_dir})`
tracker = LocalTrackingClient(project=project_id, storage_dir=storage_dir)
return (
ApplicationBuilder()
.with_actions(
Expand All @@ -166,8 +170,6 @@ def base_application(
prompt_for_more=prompt_for_more,
response=response,
)
.with_entrypoint("prompt")
.with_state(chat_history=[])
.with_transitions(
("prompt", "check_openai_key", default),
("check_openai_key", "check_safety", when(has_openai_key=True)),
Expand All @@ -184,8 +186,15 @@ def base_application(
),
("response", "prompt", default),
)
# initializes from the tracking log if it does not already exist
.initialize_from(
tracker,
resume_at_next_action=False, # always resume from entrypoint in the case of failure
default_state={"chat_history": []},
default_entrypoint="prompt",
)
.with_hooks(*hooks)
.with_tracker("local", project=project_id, params={"storage_dir": storage_dir})
.with_tracker(tracker)
.with_identifiers(app_id=app_id)
.build()
)
Expand Down

0 comments on commit 2eca07d

Please sign in to comment.