Skip to content

Commit

Permalink
feat(engine): Upgrade commit workflow to propagate changes to Workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt committed Jun 15, 2024
1 parent f3a0b62 commit 820ff66
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 132 deletions.
161 changes: 98 additions & 63 deletions tracecat/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from fastapi.params import Body
from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic_core import ValidationError
from sqlalchemy import Engine, or_
from sqlalchemy.exc import NoResultFound
from sqlalchemy import Engine, delete, or_
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlmodel import Session, select

from tracecat import config
Expand Down Expand Up @@ -92,7 +92,6 @@
UpdateUserParams,
UpdateWorkflowParams,
UpsertWebhookParams,
UpsertWorkflowDefinitionParams,
WebhookResponse,
WorkflowMetadataResponse,
WorkflowResponse,
Expand All @@ -101,6 +100,7 @@
)
from tracecat.types.cases import Case, CaseMetrics
from tracecat.types.exceptions import TracecatException, TracecatValidationError
from tracecat.utils import action_key

engine: Engine

Expand Down Expand Up @@ -179,8 +179,8 @@ def create_app(**kwargs) -> FastAPI:
@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
logger.error(
"Unexpected error: {!s}",
exc,
"Unexpected error",
exc=exc,
role=ctx_role.get(),
params=request.query_params,
path=request.url.path,
Expand Down Expand Up @@ -558,12 +558,10 @@ def commit_workflow(
This deploys the workflow and updates its version. If a YAML file is provided, it will override the workflow in the database."""

with Session(engine) as session:
if yaml_file:
# Uploaded YAML file overrides the workflow in the database
dsl = DSLInput.from_yaml(yaml_file.file)
logger.info("Commiting workflow from yaml file", role=role)
else:
# Committing from YAML (i.e. attaching yaml) will override the workflow definition in the database

with Session(engine) as session, logger.contextualize(role=role):
try:
# Grab workflow and actions from tables
statement = select(Workflow).where(
Workflow.owner_id == role.user_id,
Expand All @@ -576,13 +574,86 @@ def commit_workflow(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found"
) from e

# Hydrate actions
_ = workflow.actions
if yaml_file:
# Uploaded YAML file overrides the workflow in the database
dsl = DSLInput.from_yaml(yaml_file.file)
logger.info("Commiting workflow from yaml file")
else:
# Convert the workflow into a WorkflowDefinition
dsl = converters.workflow_to_dsl(workflow)
logger.info("Commiting workflow from database")
# Phase 1: Commit
defn = _create_wf_definition(session, role, workflow_id, dsl)
# Phase 2: Backpropagate
new_graph = converters.dsl_to_graph(workflow, dsl)

# Replace Actions
del_stmt = delete(Action).where(
Action.workflow_id == workflow_id, Action.owner_id == role.user_id
)
session.exec(del_stmt)
logger.info(result)

session.flush() # Ensure deletions are flushed
session.refresh(workflow)

for act_stmt in dsl.actions:
new_action = Action(
id=action_key(workflow_id, act_stmt.ref),
owner_id=role.user_id,
workflow_id=workflow_id,
type=act_stmt.action,
inputs=act_stmt.args,
title=act_stmt.title,
description=act_stmt.description,
)
session.add(new_action)

# Update Workflow
workflow.object = new_graph.model_dump(by_alias=True)
workflow.version = defn.version
workflow.title = dsl.title
workflow.description = dsl.description

session.add(workflow)
session.add(defn)
session.commit()
session.refresh(workflow)
session.refresh(defn)

except SQLAlchemyError as e:
session.rollback()
logger.opt(exception=e).error("Error committing workflow", error=e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while committing the workflow.",
) from e

# Convert the workflow into a WorkflowDefinition
dsl = converters.workflow_to_dsl(workflow)
_upsert_workflow_definition(session, role, workflow_id, dsl)

def _create_wf_definition(
session: Session, role: Role, workflow_id: str, dsl: DSLInput
) -> WorkflowDefinition:
statement = (
select(WorkflowDefinition)
.where(
WorkflowDefinition.owner_id == role.user_id,
WorkflowDefinition.workflow_id == workflow_id,
)
.order_by(WorkflowDefinition.version.desc())
)
result = session.exec(statement)
latest_defn = result.first()

version = latest_defn.version + 1 if latest_defn else 1
defn = WorkflowDefinition(
owner_id=role.user_id,
workflow_id=workflow_id,
content=dsl.model_dump(),
version=version,
)
return defn


# ----- Workflow Definitions ----- #
Expand Down Expand Up @@ -635,54 +706,6 @@ def get_workflow_definition(
) from e


@app.post(
"/workflows/{workflow_id}/definition",
status_code=status.HTTP_204_NO_CONTENT,
tags=["workflows"],
)
def upsert_workflow_definition(
role: Annotated[Role, Depends(authenticate_user_or_service)],
workflow_id: str,
params: UpsertWorkflowDefinitionParams,
) -> None:
"""Upsert a workflow definition."""

with Session(engine) as session:
_upsert_workflow_definition(session, role, workflow_id, params.content)


def _upsert_workflow_definition(
session: Session, role: Role, workflow_id: str, content: DSLInput
) -> WorkflowDefinition:
statement = (
select(WorkflowDefinition)
.where(
WorkflowDefinition.owner_id == role.user_id,
WorkflowDefinition.workflow_id == workflow_id,
)
.order_by(WorkflowDefinition.version.desc())
)
result = session.exec(statement)
latest_defn = result.first()

version = latest_defn.version + 1 if latest_defn else 1
defn = WorkflowDefinition(
owner_id=role.user_id,
workflow_id=workflow_id,
content=content.model_dump(),
version=version,
)
session.add(defn)
session.commit()
session.refresh(defn)

workflow = defn.workflow # Hydrate relationship attr
workflow.version = version
session.add(workflow)
session.commit()
session.refresh(workflow)


# ----- Workflow Runs ----- #


Expand Down Expand Up @@ -1045,6 +1068,18 @@ def create_action(
title=params.title,
description="", # Default to empty string
)
# Check if a clashing action ref exists
statement = select(Action).where(
Action.owner_id == role.user_id,
Action.workflow_id == action.workflow_id,
Action.ref == action.ref,
)
if session.exec(statement).first():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Action ref already exists in the workflow",
)

session.add(action)
session.commit()
session.refresh(action)
Expand Down
21 changes: 6 additions & 15 deletions tracecat/db/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@

from datetime import datetime
from typing import Any, Self
from uuid import uuid4

import pyarrow as pa
from pydantic import computed_field, field_validator
from slugify import slugify
from sqlalchemy import JSON, TIMESTAMP, Column, ForeignKey, String, text
from sqlmodel import Field, Relationship, SQLModel

from tracecat import config, registry
from tracecat.auth.credentials import compute_hash, decrypt_object, encrypt_object
from tracecat.dsl.common import DSLInput
from tracecat.types.secrets import SECRET_FACTORY, SecretBase, SecretKeyValue
from tracecat.utils import action_key, gen_id, get_ref

DEFAULT_CASE_ACTIONS = [
"Active compromise",
Expand All @@ -25,15 +24,6 @@
]


def gen_id(prefix: str):
separator = "-"

def wrapper():
return prefix + separator + uuid4().hex

return wrapper


class Resource(SQLModel):
"""Base class for all resources in the system."""

Expand Down Expand Up @@ -351,7 +341,7 @@ class Action(Resource, table=True):
id: str = Field(
default_factory=gen_id("act"), nullable=False, unique=True, index=True
)
type: str
type: str = Field(..., description="The action type, i.e. UDF key")
title: str
description: str
status: str = "offline" # "online" or "offline"
Expand All @@ -366,12 +356,13 @@ class Action(Resource, table=True):
@computed_field
@property
def key(self) -> str:
slug = slugify(self.title, separator="_")
return f"{self.id}.{slug}"
"""Workflow-relative key for an Action."""
return action_key(self.workflow_id, self.id)

@property
def ref(self) -> str:
return slugify(self.title, separator="_")
"""Slugified title of the action. Used for references."""
return get_ref(self.title)


class ActionRun(Resource, table=True):
Expand Down
46 changes: 34 additions & 12 deletions tracecat/dsl/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,47 @@ class DSLError(ValueError):


class ActionStatement(BaseModel):
ref: str = Field(pattern=SLUG_PATTERN)
"""Unique reference for the task"""
id: str | None = Field(
default=None,
exclude=True,
description=(
"The action ID. If this is populated means there is a corresponding action"
"in the database `Action` table."
),
)

action: str = Field(pattern=ACTION_TYPE_PATTERN)
"""Namespaced action type"""
ref: str = Field(pattern=SLUG_PATTERN, description="Unique reference for the task")

args: dict[str, Any] = Field(default_factory=dict)
"""Arguments for the action"""
description: str = ""

depends_on: list[str] = Field(default_factory=list)
"""Task dependencies"""
action: str = Field(
pattern=ACTION_TYPE_PATTERN, description="Action type / UDF key"
)

run_if: Annotated[str | None, Field(default=None), TemplateValidator()]
"""Condition to run the task"""
args: dict[str, Any] = Field(
default_factory=dict, description="Arguments for the action"
)

depends_on: list[str] = Field(default_factory=list, description="Task dependencies")

run_if: Annotated[
str | None,
Field(default=None, description="Condition to run the task"),
TemplateValidator(),
]

for_each: Annotated[
str | list[str] | None, Field(default=None), TemplateValidator()
str | list[str] | None,
Field(
default=None,
description="Iterate over a list of items and run the task for each item.",
),
TemplateValidator(),
]
"""Run the task over an iterable"""

@property
def title(self) -> str:
return self.ref.capitalize().replace("_", " ")


class DSLConfig(BaseModel):
Expand Down
Loading

0 comments on commit 820ff66

Please sign in to comment.