Skip to content

Commit

Permalink
feat(engine): Dynamically register UDFs in DSLActivities
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt committed Jun 16, 2024
1 parent 8766765 commit 7b2934b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 21 deletions.
15 changes: 5 additions & 10 deletions tests/unit/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@
from tracecat.contexts import ctx_role
from tracecat.dsl.common import DSLInput, get_temporal_client
from tracecat.dsl.worker import new_sandbox_runner
from tracecat.dsl.workflow import (
DSLContext,
DSLRunArgs,
DSLWorkflow,
dsl_activities,
)
from tracecat.dsl.workflow import DSLActivities, DSLContext, DSLRunArgs, DSLWorkflow
from tracecat.expressions import ExprContext
from tracecat.identifiers.resource import ResourcePrefix
from tracecat.types.exceptions import TracecatExpressionError
Expand Down Expand Up @@ -128,7 +123,7 @@ async def test_workflow_can_run_from_yaml(
async with Worker(
client,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
activities=dsl_activities,
activities=DSLActivities.load(),
workflows=[DSLWorkflow],
workflow_runner=new_sandbox_runner(),
):
Expand Down Expand Up @@ -193,7 +188,7 @@ async def test_workflow_ordering_is_correct(
async with Worker(
client,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
activities=dsl_activities,
activities=DSLActivities.load(),
workflows=[DSLWorkflow],
workflow_runner=new_sandbox_runner(),
):
Expand Down Expand Up @@ -262,7 +257,7 @@ async def test_workflow_completes_and_correct(
async with Worker(
client,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
activities=dsl_activities,
activities=DSLActivities.load(),
workflows=[DSLWorkflow],
workflow_runner=new_sandbox_runner(),
):
Expand Down Expand Up @@ -292,7 +287,7 @@ async def test_conditional_execution_fails(
async with Worker(
client,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
activities=dsl_activities,
activities=DSLActivities.load(),
workflows=[DSLWorkflow],
workflow_runner=new_sandbox_runner(),
):
Expand Down
5 changes: 3 additions & 2 deletions tracecat/dsl/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# are safe for workflow use
with workflow.unsafe.imports_passed_through():
from tracecat.dsl.common import get_temporal_client
from tracecat.dsl.workflow import DSLWorkflow, dsl_activities
from tracecat.dsl.workflow import DSLActivities, DSLWorkflow
from tracecat.registry import registry


Expand Down Expand Up @@ -52,10 +52,11 @@ async def main() -> None:
client = await get_temporal_client()

# Run a worker for the activities and workflow
DSLActivities.init()
async with Worker(
client,
task_queue=os.environ.get("TEMPORAL__CLUSTER_QUEUE", "tracecat-task-queue"),
activities=dsl_activities,
activities=DSLActivities.load(),
workflows=[DSLWorkflow],
workflow_runner=new_sandbox_runner(),
):
Expand Down
54 changes: 45 additions & 9 deletions tracecat/dsl/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ async def execute_task(self, task: ActionStatement) -> None:
self.logger.info("Executing task")
# TODO: Set a retry policy for the activity
activity_result = await workflow.execute_activity(
"run_udf",
_udf_key_to_activity_name(task.action),
arg=UDFActionInput(
task=task,
role=self.role,
Expand Down Expand Up @@ -267,12 +267,49 @@ class UDFActionInput(BaseModel):
run_context: RunContext


def _udf_key_to_activity_name(key: str) -> str:
return key.replace(".", "__")


class DSLActivities:
"""Container for all UDFs registered in the registry."""

def __new__(cls): # type: ignore
raise RuntimeError("This class should not be instantiated")

@staticmethod
@activity.defn
@classmethod
def init(cls):
"""Create activity methods from the UDF registry and attach them to DSLActivities."""
global registry
for key in registry.keys:
# path.to.method_name -> path__to__method_name
method_name = _udf_key_to_activity_name(key)

async def async_wrapper(input: UDFActionInput):
# loop = asyncio.get_event_loop()
# loop.set_task_factory(asyncio.eager_task_factory)
return await cls.run_udf(input)

fn = activity.defn(name=method_name)(async_wrapper)
setattr(cls, method_name, staticmethod(fn))

return cls

@classmethod
def get_activities(cls) -> list[Callable[[UDFActionInput], Any]]:
"""Get all loaded UDFs in the class."""
return [
getattr(cls, method_name)
for method_name in dir(cls)
if hasattr(getattr(cls, method_name), "__temporal_activity_definition")
]

@classmethod
def load(cls) -> list[Callable[[UDFActionInput], Any]]:
"""Load and return all UDFs in the class."""
cls.init()
return cls.get_activities()

async def run_udf(input: UDFActionInput) -> Any:
ctx_run.set(input.run_context)
ctx_role.set(input.role)
Expand Down Expand Up @@ -385,9 +422,8 @@ def patch_object(obj: dict[str, Any], *, path: str, value: Any, sep: str = ".")
obj[leaf] = value


# Dynamically register all static methods as activities
dsl_activities = [
getattr(DSLActivities, method_name)
for method_name in dir(DSLActivities)
if hasattr(getattr(DSLActivities, method_name), "__temporal_activity_definition")
]
if __name__ == "__main__":
print(DSLActivities.load())
registry.init()
DSLActivities.init()
print(DSLActivities.load())

0 comments on commit 7b2934b

Please sign in to comment.