Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ability for action to specify its own source code #195

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ def single_step(self) -> bool:
def streaming(self) -> bool:
return False

def get_source(self) -> str:
"""Returns the source code of the action. This will default to
the source code of the class in which the action is implemented,
but can be overwritten." Override if you want debugging/tracking
to display a different source"""
return inspect.getsource(self.__class__)

def __repr__(self):
read_repr = ", ".join(self.reads) if self.reads else "{}"
write_repr = ", ".join(self.writes) if self.writes else "{}"
Expand Down Expand Up @@ -539,6 +546,10 @@ def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]:
def is_async(self) -> bool:
return inspect.iscoroutinefunction(self._fn)

def get_source(self) -> str:
"""Return the source of the code for this action."""
return inspect.getsource(self._fn)


StreamType = Tuple[dict, Optional[State]]

Expand Down Expand Up @@ -984,6 +995,10 @@ def fn(self) -> Union[StreamingFn, StreamingFnAsync]:
def is_async(self) -> bool:
return inspect.isasyncgenfunction(self._fn)

def get_source(self) -> str:
"""Return the source of the code for this action"""
return inspect.getsource(self._fn)


def _validate_action_function(fn: Callable):
"""Validates that an action has the signature: (state: State) -> tuple[dict, State]
Expand Down
7 changes: 1 addition & 6 deletions burr/tracking/common/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import datetime
import inspect
from typing import Any, Dict, List, Optional, Union

from pydantic import field_serializer

from burr.common import types as burr_types
from burr.core import Action
from burr.core.action import FunctionBasedAction, FunctionBasedStreamingAction
from burr.core.application import ApplicationGraph, Transition
from burr.integrations.base import require_plugin

Expand Down Expand Up @@ -52,10 +50,7 @@ def from_action(action: Action) -> "ActionModel":
:param action: Action to create the model from
:return:
"""
if isinstance(action, (FunctionBasedAction, FunctionBasedStreamingAction)):
code = inspect.getsource(action.fn)
else:
code = inspect.getsource(action.__class__)
code = action.get_source() # delegate to the action
optional_inputs, required_inputs = action.optional_and_required_inputs
return ActionModel(
name=action.name,
Expand Down
29 changes: 29 additions & 0 deletions tests/tracking/test_common_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from burr.core import Action, State
from burr.tracking.common.models import ActionModel


class ActionWithCustomSource(Action):
def __init__(self):
super().__init__()

@property
def reads(self) -> list[str]:
return []

def run(self, state: State, **run_kwargs) -> dict:
return {}

@property
def writes(self) -> list[str]:
return []

def update(self, result: dict, state: State) -> State:
return state

def get_source(self) -> str:
return "custom source code"


def test_action_with_custom_source():
model = ActionModel.from_action(ActionWithCustomSource().with_name("foo"))
assert model.code == "custom source code"
Loading