diff --git a/tests/unit/test_workflows.py b/tests/unit/test_workflows.py index f777afa6..1710759a 100644 --- a/tests/unit/test_workflows.py +++ b/tests/unit/test_workflows.py @@ -23,12 +23,7 @@ from tracecat.contexts import ctx_role from tracecat.dsl.common import DSLInput, get_temporal_client from tracecat.dsl.worker import new_sandbox_runner -from tracecat.dsl.workflow import ( - DSLContext, - DSLRunArgs, - DSLWorkflow, - dsl_activities, -) +from tracecat.dsl.workflow import DSLActivities, DSLContext, DSLRunArgs, DSLWorkflow from tracecat.expressions import ExprContext from tracecat.identifiers.resource import ResourcePrefix from tracecat.types.exceptions import TracecatExpressionError @@ -128,7 +123,7 @@ async def test_workflow_can_run_from_yaml( async with Worker( client, task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"], - activities=dsl_activities, + activities=DSLActivities.load(), workflows=[DSLWorkflow], workflow_runner=new_sandbox_runner(), ): @@ -193,7 +188,7 @@ async def test_workflow_ordering_is_correct( async with Worker( client, task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"], - activities=dsl_activities, + activities=DSLActivities.load(), workflows=[DSLWorkflow], workflow_runner=new_sandbox_runner(), ): @@ -262,7 +257,7 @@ async def test_workflow_completes_and_correct( async with Worker( client, task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"], - activities=dsl_activities, + activities=DSLActivities.load(), workflows=[DSLWorkflow], workflow_runner=new_sandbox_runner(), ): @@ -292,7 +287,7 @@ async def test_conditional_execution_fails( async with Worker( client, task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"], - activities=dsl_activities, + activities=DSLActivities.load(), workflows=[DSLWorkflow], workflow_runner=new_sandbox_runner(), ): diff --git a/tracecat/dsl/worker.py b/tracecat/dsl/worker.py index 749b37f1..7ad64d7f 100644 --- a/tracecat/dsl/worker.py +++ b/tracecat/dsl/worker.py @@ -14,7 +14,7 @@ # are safe for workflow use with workflow.unsafe.imports_passed_through(): from tracecat.dsl.common import get_temporal_client - from tracecat.dsl.workflow import DSLWorkflow, dsl_activities + from tracecat.dsl.workflow import DSLActivities, DSLWorkflow from tracecat.registry import registry @@ -52,10 +52,11 @@ async def main() -> None: client = await get_temporal_client() # Run a worker for the activities and workflow + DSLActivities.init() async with Worker( client, task_queue=os.environ.get("TEMPORAL__CLUSTER_QUEUE", "tracecat-task-queue"), - activities=dsl_activities, + activities=DSLActivities.load(), workflows=[DSLWorkflow], workflow_runner=new_sandbox_runner(), ): diff --git a/tracecat/dsl/workflow.py b/tracecat/dsl/workflow.py index a7822b84..dd3d1bfa 100644 --- a/tracecat/dsl/workflow.py +++ b/tracecat/dsl/workflow.py @@ -229,7 +229,7 @@ async def execute_task(self, task: ActionStatement) -> None: self.logger.info("Executing task") # TODO: Set a retry policy for the activity activity_result = await workflow.execute_activity( - "run_udf", + _udf_key_to_activity_name(task.action), arg=UDFActionInput( task=task, role=self.role, @@ -267,12 +267,49 @@ class UDFActionInput(BaseModel): run_context: RunContext +def _udf_key_to_activity_name(key: str) -> str: + return key.replace(".", "__") + + class DSLActivities: + """Container for all UDFs registered in the registry.""" + def __new__(cls): # type: ignore raise RuntimeError("This class should not be instantiated") - @staticmethod - @activity.defn + @classmethod + def init(cls): + """Create activity methods from the UDF registry and attach them to DSLActivities.""" + global registry + for key in registry.keys: + # path.to.method_name -> path__to__method_name + method_name = _udf_key_to_activity_name(key) + + async def async_wrapper(input: UDFActionInput): + # loop = asyncio.get_event_loop() + # loop.set_task_factory(asyncio.eager_task_factory) + return await cls.run_udf(input) + + fn = activity.defn(name=method_name)(async_wrapper) + setattr(cls, method_name, staticmethod(fn)) + + return cls + + @classmethod + def get_activities(cls) -> list[Callable[[UDFActionInput], Any]]: + """Get all loaded UDFs in the class.""" + return [ + getattr(cls, method_name) + for method_name in dir(cls) + if hasattr(getattr(cls, method_name), "__temporal_activity_definition") + ] + + @classmethod + def load(cls) -> list[Callable[[UDFActionInput], Any]]: + """Load and return all UDFs in the class.""" + cls.init() + return cls.get_activities() + async def run_udf(input: UDFActionInput) -> Any: ctx_run.set(input.run_context) ctx_role.set(input.role) @@ -385,9 +422,8 @@ def patch_object(obj: dict[str, Any], *, path: str, value: Any, sep: str = ".") obj[leaf] = value -# Dynamically register all static methods as activities -dsl_activities = [ - getattr(DSLActivities, method_name) - for method_name in dir(DSLActivities) - if hasattr(getattr(DSLActivities, method_name), "__temporal_activity_definition") -] +if __name__ == "__main__": + print(DSLActivities.load()) + registry.init() + DSLActivities.init() + print(DSLActivities.load())