Skip to content

Commit

Permalink
feat(engine): Implement basic conditional skip
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt committed Jun 6, 2024
1 parent 2754923 commit 8bd4139
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 38 deletions.
96 changes: 59 additions & 37 deletions tracecat/dsl/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Callable, Coroutine
from datetime import timedelta
from pathlib import Path
from typing import Any, Literal, Self, TypedDict
from typing import Annotated, Any, Literal, Self, TypedDict

import yaml
from temporalio import activity, workflow
Expand All @@ -19,9 +19,9 @@
from pydantic import BaseModel, ConfigDict, Field, model_validator

from tracecat.auth import Role
from tracecat.contexts import ctx_role, ctx_run
from tracecat.contexts import ctx_role, ctx_run, ctx_logger
from tracecat.registry import registry
from tracecat.templates.eval import eval_templated_object
from tracecat import templates


SLUG_PATTERN = r"^[a-z0-9_]+$"
Expand Down Expand Up @@ -109,6 +109,8 @@ class ActionStatement(BaseModel):
depends_on: list[str] = Field(default_factory=list)
"""Task dependencies"""

run_if: Annotated[str | None, Field(default=None), templates.TemplateValidator()]


class DSLContext(TypedDict):
INPUTS: dict[str, Any]
Expand Down Expand Up @@ -139,10 +141,11 @@ def __init__(
self.adj: dict[str, set[str]] = defaultdict(set)
self.indegrees: dict[str, int] = {}
self.queue: asyncio.Queue[str] = asyncio.Queue()
self.running_tasks: dict[str, asyncio.Task[None]] = {}
# self.running_tasks: dict[str, asyncio.Task[None]] = {}
self.completed_tasks: set[str] = set()

self.executor = activity_coro
self.logger = ctx_logger.get(logger)

for task in dsl.actions:
self.tasks[task.ref] = task
Expand All @@ -162,7 +165,7 @@ async def _dynamic_task(self, task_ref: str) -> None:
await self.executor(task)

self.completed_tasks.add(task_ref)
logger.info("Task completed", task_ref=task_ref, role=ctx_role.get())
self.logger.info("Task completed", task_ref=task_ref)

# Update the indegrees of the tasks
async with asyncio.TaskGroup() as tg:
Expand All @@ -183,7 +186,7 @@ async def dynamic_start(self) -> None:
continue

asyncio.create_task(self._dynamic_task(task_ref))
logger.info("All tasks completed")
self.logger.info("All tasks completed")

async def _static_task(self, task_ref: str) -> None:
raise NotImplementedError
Expand All @@ -210,45 +213,63 @@ async def run(self, args: DSLRunArgs) -> DSLContext:
wf_id=workflow.info().workflow_id,
wf_run_id=workflow.info().run_id,
)
self.logger = logger.bind(wf_id=self.run_ctx.wf_id, role=self.role)
ctx_logger.set(self.logger)

self.dsl = args.dsl
self.context = DSLContext(INPUTS=self.dsl.inputs, ACTIONS={})
self.dep_list = {task.ref: task.depends_on for task in self.dsl.actions}
logger.info(
"Running DSL task workflow", wf_id=self.run_ctx.wf_id, role=self.role
)
self.logger.info("Running DSL task workflow")

self.scheduler = DSLScheduler(activity_coro=self.execute_task, dsl=self.dsl)
await self.scheduler.start()

logger.info("DSL workflow completed", wf_id=self.run_ctx.wf_id)
self.logger.info("DSL workflow completed")
return self.context

async def execute_task(self, task: ActionStatement) -> None:
"""Purely execute a task and manage the results."""
# Resolve all templated arguments
logger.info(f"Executing task {task.ref}")
# Securely inject secrets into the task arguments
# 1. Find all secrets in the task arguments
# 2. Load the secrets
# 3. Inject the secrets into the task arguments using an enriched context
processed_args = eval_templated_object(task.args, operand=self.context)

# TODO: Set a retry policy for the activity
activity_result = await workflow.execute_activity(
"run_udf",
arg=UDFActionArgs(
type=task.action,
args=processed_args,
role=self.role,
context=self.run_ctx,
),
start_to_close_timeout=timedelta(minutes=1),
)
self.context["ACTIONS"][task.ref] = DSLNodeResult(
result=activity_result,
result_typename=type(activity_result).__name__,
)
"""Purely execute a task and manage the results.
Preflight checks
---------------
1. Evaluate `run_if` condition
2. Resolve all templated arguments
"""
# Evaluate the `run_if` condition
with self.logger.contextualize(task_ref=task.ref):
if task.run_if is not None:
expr = templates.TemplateExpression(task.run_if, operand=self.context)
self.logger.info("`run_if` condition", task_run_if=task.run_if)
if not bool(expr.result()):
self.logger.info("Task skipped")
return

self.logger.info("Executing task")
# TODO: Securely inject secrets into the task arguments
# 1. Find all secrets in the task arguments
# 2. Load the secrets
# 3. Inject the secrets into the task arguments using an enriched context

# Resolve all templated arguments
processed_args = templates.eval_templated_object(
task.args, operand=self.context
)

# TODO: Set a retry policy for the activity
activity_result = await workflow.execute_activity(
"run_udf",
arg=UDFActionArgs(
type=task.action,
args=processed_args,
role=self.role,
context=self.run_ctx,
),
start_to_close_timeout=timedelta(minutes=1),
)
self.context["ACTIONS"][task.ref] = DSLNodeResult(
result=activity_result,
result_typename=type(activity_result).__name__,
)


class UDFActionArgs(BaseModel):
Expand All @@ -267,20 +288,21 @@ def __new__(cls): # type: ignore
async def run_udf(action: UDFActionArgs) -> Any:
ctx_run.set(action.context)
ctx_role.set(action.role)
act_logger = logger.bind(wf_id=action.context.wf_id, role=action.role)
ctx_logger.set(act_logger)

udf = registry[action.type]
logger.info(
act_logger.info(
"Run udf",
type=action.type,
role=action.role,
is_async=udf.is_async,
args=action.args,
)
if udf.is_async:
result = await udf.fn(**action.args)
else:
result = await asyncio.to_thread(udf.fn, **action.args)
logger.info(f"Result: {result}")
act_logger.info("Result", result=result)
return result


Expand Down
4 changes: 3 additions & 1 deletion tracecat/templates/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .eval import eval_templated_object
from .expressions import TemplateExpression
from .validators import TemplateValidator

__all__ = ["TemplateValidator"]
__all__ = ["TemplateValidator", "TemplateExpression", "eval_templated_object"]

0 comments on commit 8bd4139

Please sign in to comment.