Skip to content

Commit

Permalink
Adds ability for action to specify its own source code
Browse files Browse the repository at this point in the history
This enables people to build classes that have different sources (E.G.
function wrappers)
  • Loading branch information
elijahbenizzy committed May 22, 2024
1 parent a295c07 commit e10db73
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 6 deletions.
15 changes: 15 additions & 0 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ def single_step(self) -> bool:
def streaming(self) -> bool:
return False

def get_source(self) -> str:
"""Returns the source code of the action. This will default to
the source code of the class in which the action is implemented,
but can be overwritten." Override if you want debugging/tracking
to display a different source"""
return inspect.getsource(self.__class__)

def __repr__(self):
read_repr = ", ".join(self.reads) if self.reads else "{}"
write_repr = ", ".join(self.writes) if self.writes else "{}"
Expand Down Expand Up @@ -539,6 +546,10 @@ def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]:
def is_async(self) -> bool:
return inspect.iscoroutinefunction(self._fn)

def get_source(self) -> str:
"""Return the source of the code for this action."""
return inspect.getsource(self._fn)


StreamType = Tuple[dict, Optional[State]]

Expand Down Expand Up @@ -984,6 +995,10 @@ def fn(self) -> Union[StreamingFn, StreamingFnAsync]:
def is_async(self) -> bool:
return inspect.isasyncgenfunction(self._fn)

def get_source(self) -> str:
"""Return the source of the code for this action"""
return inspect.getsource(self._fn)


def _validate_action_function(fn: Callable):
"""Validates that an action has the signature: (state: State) -> tuple[dict, State]
Expand Down
7 changes: 1 addition & 6 deletions burr/tracking/common/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import datetime
import inspect
from typing import Any, Dict, List, Optional, Union

from pydantic import field_serializer

from burr.common import types as burr_types
from burr.core import Action
from burr.core.action import FunctionBasedAction, FunctionBasedStreamingAction
from burr.core.application import ApplicationGraph, Transition
from burr.integrations.base import require_plugin

Expand Down Expand Up @@ -52,10 +50,7 @@ def from_action(action: Action) -> "ActionModel":
:param action: Action to create the model from
:return:
"""
if isinstance(action, (FunctionBasedAction, FunctionBasedStreamingAction)):
code = inspect.getsource(action.fn)
else:
code = inspect.getsource(action.__class__)
code = action.get_source() # delegate to the action
optional_inputs, required_inputs = action.optional_and_required_inputs
return ActionModel(
name=action.name,
Expand Down
29 changes: 29 additions & 0 deletions tests/tracking/test_common_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from burr.core import Action, State
from burr.tracking.common.models import ActionModel


class ActionWithCustomSource(Action):
def __init__(self):
super().__init__()

@property
def reads(self) -> list[str]:
return []

def run(self, state: State, **run_kwargs) -> dict:
return {}

@property
def writes(self) -> list[str]:
return []

def update(self, result: dict, state: State) -> State:
return state

def get_source(self) -> str:
return "custom source code"


def test_action_with_custom_source():
model = ActionModel.from_action(ActionWithCustomSource().with_name("foo"))
assert model.code == "custom source code"

0 comments on commit e10db73

Please sign in to comment.