Skip to content

Commit

Permalink
feat+refactor(engine): Add ResourcePrefix + rename gen_id to id_facto…
Browse files Browse the repository at this point in the history
…ry + distinguish between workflow execution and run
  • Loading branch information
daryllimyt committed Jun 16, 2024
1 parent 72e35a1 commit deffd29
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 52 deletions.
27 changes: 16 additions & 11 deletions tests/unit/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest
import yaml
from loguru import logger
from slugify import slugify
from temporalio.common import RetryPolicy
from temporalio.worker import Worker

Expand All @@ -29,6 +30,7 @@
dsl_activities,
)
from tracecat.expressions import ExprContext
from tracecat.identifiers.resource import ResourcePrefix
from tracecat.types.exceptions import TracecatExpressionError

DATA_PATH = Path(__file__).parent.parent.joinpath("data/workflows")
Expand All @@ -39,9 +41,12 @@
TEST_WF_ID = "wf-00000000000000000000000000000000"


def gen_test_run_id(name: str) -> str:
run_id = TEST_WF_ID + ":run-" + name
return run_id
def generate_test_exec_id(name: str) -> str:
return (
TEST_WF_ID
+ f":{ResourcePrefix.WORKFLOW_EXECUTION}-"
+ slugify(name, separator="_")
)


@pytest.fixture
Expand Down Expand Up @@ -117,7 +122,7 @@ async def test_workflow_can_run_from_yaml(
dsl, temporal_cluster, mock_registry, auth_sandbox
):
test_name = f"test_workflow_can_run_from_yaml-{dsl.title}"
run_id = gen_test_run_id(test_name)
wf_exec_id = generate_test_exec_id(test_name)
client = await get_temporal_client()
# Run workflow
async with Worker(
Expand All @@ -130,7 +135,7 @@ async def test_workflow_can_run_from_yaml(
result = await client.execute_workflow(
DSLWorkflow.run,
DSLRunArgs(dsl=dsl, role=ctx_role.get(), wf_id=TEST_WF_ID),
id=run_id,
id=wf_exec_id,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
retry_policy=RetryPolicy(maximum_attempts=1),
)
Expand Down Expand Up @@ -182,7 +187,7 @@ async def test_workflow_ordering_is_correct(
# Connect client

test_name = f"test_workflow_ordering_is_correct-{dsl.title}"
run_id = gen_test_run_id(test_name)
wf_exec_id = generate_test_exec_id(test_name)
client = await get_temporal_client()
# Run a worker for the activities and workflow
async with Worker(
Expand All @@ -195,7 +200,7 @@ async def test_workflow_ordering_is_correct(
result = await client.execute_workflow(
DSLWorkflow.run,
DSLRunArgs(dsl=dsl, role=ctx_role.get(), wf_id=TEST_WF_ID),
id=run_id,
id=wf_exec_id,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
retry_policy=RetryPolicy(maximum_attempts=1),
)
Expand Down Expand Up @@ -250,7 +255,7 @@ async def test_workflow_completes_and_correct(
):
dsl, expected = dsl_with_expected
test_name = f"test_correctness_execution-{dsl.title}"
run_id = gen_test_run_id(test_name)
wf_exec_id = generate_test_exec_id(test_name)

client = await get_temporal_client()
# Run a worker for the activities and workflow
Expand All @@ -264,7 +269,7 @@ async def test_workflow_completes_and_correct(
result = await client.execute_workflow(
DSLWorkflow.run,
DSLRunArgs(dsl=dsl, role=ctx_role.get(), wf_id=TEST_WF_ID),
id=run_id,
id=wf_exec_id,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
retry_policy=RetryPolicy(maximum_attempts=1),
)
Expand All @@ -282,7 +287,7 @@ async def test_conditional_execution_fails(
dsl, temporal_cluster, mock_registry, auth_sandbox
):
test_name = f"test_conditional_execution-{dsl.title}"
run_id = gen_test_run_id(test_name)
wf_exec_id = generate_test_exec_id(test_name)
client = await get_temporal_client()
async with Worker(
client,
Expand All @@ -298,7 +303,7 @@ async def test_conditional_execution_fails(
await client.execute_workflow(
DSLWorkflow.run,
DSLRunArgs(dsl=dsl, role=ctx_role.get(), wf_id=TEST_WF_ID),
id=run_id,
id=wf_exec_id,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
retry_policy=RetryPolicy(
maximum_attempts=0,
Expand Down
7 changes: 5 additions & 2 deletions tracecat/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@

from pydantic import BaseModel

from tracecat import identifiers

if TYPE_CHECKING:
from tracecat.auth.credentials import Role


class RunContext(BaseModel):
wf_id: str
wf_run_id: str
wf_id: identifiers.workflow.WorkflowID
wf_exec_id: identifiers.workflow.WorkflowExecutionID
wf_run_id: identifiers.workflow.WorkflowRunID


ctx_run: ContextVar[RunContext] = ContextVar("run", default=None)
Expand Down
28 changes: 14 additions & 14 deletions tracecat/db/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tracecat import config, registry
from tracecat.auth.credentials import compute_hash, decrypt_object, encrypt_object
from tracecat.dsl.common import DSLInput
from tracecat.identifiers import action, gen_id
from tracecat.identifiers import action, id_factory
from tracecat.types.secrets import SECRET_FACTORY, SecretBase, SecretKeyValue

DEFAULT_CASE_ACTIONS = [
Expand Down Expand Up @@ -49,7 +49,7 @@ class Resource(SQLModel):
class User(Resource, table=True):
# The id is also the JWT 'sub' claim
id: str = Field(
default_factory=gen_id("user"), nullable=False, unique=True, index=True
default_factory=id_factory("user"), nullable=False, unique=True, index=True
)
tier: str = "free" # "free" or "premium"
settings: str | None = None # JSON-serialized String of settings
Expand All @@ -67,7 +67,7 @@ class User(Resource, table=True):

class Secret(Resource, table=True):
id: str = Field(
default_factory=gen_id("secret"), nullable=False, unique=True, index=True
default_factory=id_factory("secret"), nullable=False, unique=True, index=True
)
type: str = "custom" # "custom", "token", "oauth2"
name: str = Field(
Expand Down Expand Up @@ -111,7 +111,7 @@ def keys(self, value: list[SecretKeyValue]) -> None:

class CaseAction(Resource, table=True):
id: str = Field(
default_factory=gen_id("case-act"), nullable=False, unique=True, index=True
default_factory=id_factory("case-act"), nullable=False, unique=True, index=True
)
tag: str
value: str
Expand All @@ -121,7 +121,7 @@ class CaseAction(Resource, table=True):

class CaseContext(Resource, table=True):
id: str = Field(
default_factory=gen_id("case-ctx"), nullable=False, unique=True, index=True
default_factory=id_factory("case-ctx"), nullable=False, unique=True, index=True
)
tag: str
value: str
Expand All @@ -131,7 +131,7 @@ class CaseContext(Resource, table=True):

class CaseEvent(Resource, table=True):
id: str = Field(
default_factory=gen_id("case-evt"), nullable=False, unique=True, index=True
default_factory=id_factory("case-evt"), nullable=False, unique=True, index=True
)
type: str # The CaseEvent type
workflow_id: str
Expand All @@ -157,7 +157,7 @@ class UDFSpec(Resource, table=True):
"""

id: str = Field(
default_factory=gen_id("udf"), nullable=False, unique=True, index=True
default_factory=id_factory("udf"), nullable=False, unique=True, index=True
)
description: str
namespace: str
Expand Down Expand Up @@ -206,7 +206,7 @@ class WorkflowDefinition(Resource, table=True):

# Metadata
id: str = Field(
default_factory=gen_id("wf-defn"), nullable=False, unique=True, index=True
default_factory=id_factory("wf-defn"), nullable=False, unique=True, index=True
)
version: int = Field(..., index=True, description="DSL spec version")
workflow_id: str = Field(
Expand Down Expand Up @@ -234,7 +234,7 @@ class Workflow(Resource, table=True):
"""

id: str = Field(
default_factory=gen_id("wf"), nullable=False, unique=True, index=True
default_factory=id_factory("wf"), nullable=False, unique=True, index=True
)
title: str
description: str
Expand Down Expand Up @@ -276,7 +276,7 @@ class Workflow(Resource, table=True):

class WorkflowRun(Resource, table=True):
id: str = Field(
default_factory=gen_id("wf-run"), nullable=False, unique=True, index=True
default_factory=id_factory("wf-run"), nullable=False, unique=True, index=True
)
status: str = "pending" # "online" or "offline"
workflow_id: str | None = Field(foreign_key="workflow.id")
Expand All @@ -286,7 +286,7 @@ class WorkflowRun(Resource, table=True):

class Webhook(Resource, table=True):
id: str = Field(
default_factory=gen_id("wh"), nullable=False, unique=True, index=True
default_factory=id_factory("wh"), nullable=False, unique=True, index=True
)
status: str = "offline" # "online" or "offline"
method: str = "POST"
Expand All @@ -311,7 +311,7 @@ def url(self) -> str:

class Schedule(Resource, table=True):
id: str = Field(
default_factory=gen_id("sch"), nullable=False, unique=True, index=True
default_factory=id_factory("sch"), nullable=False, unique=True, index=True
)
status: str = "offline" # "online" or "offline"
cron: str
Expand Down Expand Up @@ -339,7 +339,7 @@ class Action(Resource, table=True):
"""The workspace action state."""

id: str = Field(
default_factory=gen_id("act"), nullable=False, unique=True, index=True
default_factory=id_factory("act"), nullable=False, unique=True, index=True
)
type: str = Field(..., description="The action type, i.e. UDF key")
title: str
Expand Down Expand Up @@ -367,7 +367,7 @@ def ref(self) -> str:

class ActionRun(Resource, table=True):
id: str = Field(
default_factory=gen_id("act-run"), nullable=False, unique=True, index=True
default_factory=id_factory("act-run"), nullable=False, unique=True, index=True
)
status: str = "pending" # "online" or "offline"
# TODO: This allows action/action_id to be None, which may be undesirable.
Expand Down
8 changes: 5 additions & 3 deletions tracecat/dsl/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ async def dispatch_workflow(
) -> DispatchResult:
# Connect client
role = ctx_role.get()
wf_run_id = identifiers.workflow.run_id(wf_id)
logger.info(f"Executing DSL workflow: {dsl.title}", role=role, wf_run_id=wf_run_id)
wf_exec_id = identifiers.workflow.exec_id(wf_id)
logger.info(
f"Executing DSL workflow: {dsl.title}", role=role, wf_exec_id=wf_exec_id
)
client = await get_temporal_client()
# Run workflow
result = await client.execute_workflow(
DSLWorkflow.run,
DSLRunArgs(dsl=dsl, role=role, wf_id=wf_id),
id=wf_run_id,
id=wf_exec_id,
task_queue=os.environ.get("TEMPORAL__CLUSTER_QUEUE", "tracecat-task-queue"),
**kwargs,
)
Expand Down
9 changes: 7 additions & 2 deletions tracecat/dsl/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,16 @@ class DSLWorkflow:
async def run(self, args: DSLRunArgs) -> DSLContext:
# Setup
self.role = args.role
self.tracecat_wf_id = args.wf_id
# Temporal workflow execution ID == Tracecat workflow run ID
self.tracecat_wf_run_id = workflow.info().workflow_id

self.run_ctx = RunContext(
wf_id=workflow.info().workflow_id,
wf_id=args.wf_id,
wf_exec_id=workflow.info().workflow_id,
wf_run_id=workflow.info().run_id,
)
self.logger = logger.bind(wf_id=self.run_ctx.wf_id, role=self.role)
self.logger = logger.bind(run_ctx=self.run_ctx, role=self.role)
ctx_logger.set(self.logger)

self.dsl = args.dsl
Expand Down
8 changes: 4 additions & 4 deletions tracecat/identifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@

from tracecat.identifiers import action, workflow
from tracecat.identifiers.action import ActionID, ActionKey, ActionRef
from tracecat.identifiers.resource import gen_id, gen_resource_id
from tracecat.identifiers.workflow import WorkflowID, WorkflowRunID
from tracecat.identifiers.resource import id_factory
from tracecat.identifiers.workflow import WorkflowExecutionID, WorkflowID, WorkflowRunID

__all__ = [
"ActionID",
"ActionKey",
"ActionRef",
"WorkflowID",
"WorkflowExecutionID",
"WorkflowRunID",
"gen_resource_id",
"gen_id",
"id_factory",
"action",
"workflow",
]
4 changes: 3 additions & 1 deletion tracecat/identifiers/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pydantic import StringConstraints
from slugify import slugify

from tracecat.identifiers.resource import ResourcePrefix

ActionID = Annotated[str, StringConstraints(pattern=r"act-[0-9a-f]{32}")]
"""A unique ID for an action. e.g. 'act-77932a0b140a4465a1a25a5c95edcfb8'"""

Expand All @@ -22,4 +24,4 @@ def ref(text: str) -> ActionRef:

def key(workflow_id: str, action_ref: str) -> ActionKey:
"""Identifier key for an action, using the workflow ID and action ref."""
return f"act:{workflow_id}:{action_ref}"
return f"{ResourcePrefix.ACTION}:{workflow_id}:{action_ref}"
47 changes: 44 additions & 3 deletions tracecat/identifiers/resource.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Resource identifiers generation utilities."""

from __future__ import annotations

from collections.abc import Callable
from enum import StrEnum
from typing import Annotated
from uuid import uuid4

Expand All @@ -10,16 +13,54 @@
"""Resource identifier pattern. e.g. 'wf-77932a0b140a4465a1a25a5c95edcfb8'"""


def gen_resource_id(prefix: str, *, sep: str = "-") -> ResourceID:
def generate_resource_id(prefix: ResourcePrefix, *, sep: str = "-") -> ResourceID:
"""Generate a short unique identifier with a prefix."""

return prefix + sep + uuid4().hex


def gen_id(prefix: str, *, sep: str = "-") -> Callable[[str], ResourceID]:
def id_factory(
prefix: ResourcePrefix, *, sep: str = "-"
) -> Callable[[str], ResourceID]:
"""Factory function to generate a short unique identifier with a prefix."""

# Assert that the prefix is a valid resource class identifier.
if prefix not in ResourcePrefix:
raise ValueError(f"Invalid resource class identifier: {prefix!r}")

def wrapper() -> ResourceID:
return gen_resource_id(prefix, sep=sep)
return generate_resource_id(prefix, sep=sep)

return wrapper


class ResourcePrefix(StrEnum):
"""Resource class identifier."""

ACTION = "act"
ACTION_RUN = "act-run" # TODO: Unused
UDF = "udf"
WORKFLOW = "wf"
WORKFLOW_EXECUTION = "exec"
WORKFLOW_DEFN = "wf-defn"
WORKFLOW_RUN = "wf-run" # TODO: Unused
WEBHOOK = "wh"
SCHEDULE = "sch"
SECRET = "secret"
USER = "user"
ORG = "org"
CASE = "case"
CASE_ACTION = "case-act"
CASE_EVENT = "case-evt"
CASE_CONTEXT = "case-ctx"

def factory(self) -> Callable[[], ResourceID]:
"""Generate a unique ID with this prefix."""

return id_factory(self)


if __name__ == "__main__":
print(ResourcePrefix.WORKFLOW.factory()())
print(id_factory("act")())
print(id_factory("fails!"))
Loading

0 comments on commit deffd29

Please sign in to comment.