diff --git a/tracecat/contexts.py b/tracecat/contexts.py
index fb54efb3..1aeed5de 100644
--- a/tracecat/contexts.py
+++ b/tracecat/contexts.py
@@ -15,6 +15,7 @@
 ctx_session_role: ContextVar[Role] = ContextVar("session_role", default=None)
+# TODO: Deprecate this contextvar
 ctx_workflow: ContextVar[Workflow] = ContextVar("workflow", default=None)
 ctx_workflow_run: ContextVar[WorkflowRunContext] = ContextVar(
     "workflow_run", default=None
diff --git a/tracecat/runner/actions.py b/tracecat/runner/actions.py
index 0e96020d..d60aad83 100644
--- a/tracecat/runner/actions.py
+++ b/tracecat/runner/actions.py
@@ -1,4 +1,4 @@
-"""Actions to be executed as part of a workflow.
+"""Actions to be executed as part of a workflow.
 
 Action
@@ -35,7 +35,6 @@
 from __future__ import annotations
 
 import asyncio
-import logging
 import random
 from collections.abc import Awaitable, Callable, Iterable
 from enum import StrEnum, auto
@@ -49,11 +48,11 @@
 from tracecat.concurrency import CloudpickleProcessPoolExecutor
 from tracecat.config import HTTP_MAX_RETRIES
-from tracecat.contexts import ctx_logger, ctx_session_role
+from tracecat.contexts import ctx_action_run, ctx_session_role, ctx_workflow_run
 from tracecat.db import create_vdb_conn
 from tracecat.integrations import registry
 from tracecat.llm import DEFAULT_MODEL_TYPE, ModelType, async_openai_call
-from tracecat.logging import standard_logger
+from tracecat.logging import logger
 from tracecat.runner.condition import ConditionRuleValidator, ConditionRuleVariant
 from tracecat.runner.events import (
     emit_create_action_run_event,
@@ -94,10 +93,6 @@
 ACTION_RUN_ID_PREFIX = "ar"
 
 
-def _get_logger() -> logging.Logger:
-    return ctx_logger.get() or standard_logger(__name__)
-
-
 def action_key_to_id(action_key: str) -> str:
     return action_key.split(".")[0]
 
@@ -155,17 +150,17 @@ def id(self) -> str:
     def action_id(self) -> str:
         return action_key_to_id(self.action_key)
 
-    def downstream_dependencies(self, workflow: Workflow, action_key: str) -> list[str]:
+    def deps_downstream(self, workflow: Workflow) -> list[str]:
         downstream_deps_ar_ids = [
             get_action_run_id(self.workflow_run_id, k)
-            for k in workflow.adj_list[action_key]
+            for k in workflow.adj_list[self.action_key]
         ]
         return downstream_deps_ar_ids
 
-    def upstream_dependencies(self, workflow: Workflow, action_key: str) -> list[str]:
+    def deps_upstream(self, workflow: Workflow) -> list[str]:
         upstream_deps_ar_ids = [
             get_action_run_id(self.workflow_run_id, k)
-            for k in workflow.action_dependencies[action_key]
+            for k in workflow.action_dependencies[self.action_key]
         ]
         return upstream_deps_ar_ids
 
@@ -386,7 +381,6 @@ async def _wait_for_dependencies(
 async def start_action_run(
     action_run: ActionRun,
     # Shared data structures
-    workflow_ref: Workflow,
     ready_jobs_queue: asyncio.Queue[ActionRun],
     running_jobs_store: dict[str, asyncio.Task[None]],
     action_result_store: dict[str, ActionTrail],
@@ -394,110 +388,109 @@
     # Dynamic data
     pending_timeout: float | None = None,
 ) -> None:
-    logger = _get_logger()
-    try:
-        await emit_create_action_run_event(action_run)
-        ar_id = action_run.id
-        action_key = action_run.action_key
-        upstream_deps_ar_ids = action_run.upstream_dependencies(
-            workflow=workflow_ref, action_key=action_key
-        )
-        logger.debug(
-            f"Action run {ar_id} waiting for dependencies {upstream_deps_ar_ids}."
- ) + ctx_action_run.set(action_run) + workflow = ctx_workflow_run.get().workflow + ar_id = action_run.id + run_status: RunStatus = "failure" + with logger.contextualize(ar_id=ar_id): + try: + await emit_create_action_run_event() + action_key = action_run.action_key + upstream_deps_ar_ids = action_run.deps_upstream(workflow=workflow) + logger.bind(deps=upstream_deps_ar_ids).debug("Waiting for dependencies") + + error_msg: str | None = None + result: ActionRunResult | None = None + await asyncio.wait_for( + _wait_for_dependencies(upstream_deps_ar_ids, action_run_status_store), + timeout=pending_timeout, + ) - run_status: RunStatus = "success" - error_msg: str | None = None - result: ActionRunResult | None = None - await asyncio.wait_for( - _wait_for_dependencies(upstream_deps_ar_ids, action_run_status_store), - timeout=pending_timeout, - ) + action_trail = _get_dependencies_results( + upstream_deps_ar_ids, action_result_store + ) - action_trail = _get_dependencies_results( - upstream_deps_ar_ids, action_result_store - ) + logger.opt(lazy=True).debug( + "Running action. Trail {trail}", trail=lambda: list(action_trail.keys()) + ) + action_run_status_store[ar_id] = ActionRunStatus.RUNNING + action_ref = workflow.actions[action_key] + await emit_update_action_run_event(status="running") + + # Every single 'run_xxx_action' function should return a dict + # This dict always contains a key 'output' with the direct result of the action + # The dict may contain additional keys for metadata or other information + # Dunder keys should are only used for carrying certain execution context information + # - __should_continue__: A boolean that indicates whether the workflow should continue + # - output_type: The type of the output + # We keep them in the result for debugging purposes, for now + result = await run_action( + action_trail=action_trail, + action_run_kwargs=action_run.run_kwargs, + **action_ref.model_dump(), + ) - logger.debug(f"Running action {ar_id!r}. Trail {action_trail.keys()}.") - action_run_status_store[ar_id] = ActionRunStatus.RUNNING - action_ref = workflow_ref.actions[action_key] - await emit_update_action_run_event(action_run, status="running") - - # Every single 'run_xxx_action' function should return a dict - # This dict always contains a key 'output' with the direct result of the action - # The dict may contain additional keys for metadata or other information - # Dunder keys should are only used for carrying certain execution context information - # - __should_continue__: A boolean that indicates whether the workflow should continue - # - output_type: The type of the output - # We keep them in the result for debugging purposes, for now - result = await run_action( - action_run_id=action_run.id, - workflow_id=workflow_ref.id, - action_trail=action_trail, - action_run_kwargs=action_run.run_kwargs, - **action_ref.model_dump(), + # Mark the action as completed + action_run_status_store[action_run.id] = ActionRunStatus.SUCCESS + + # Store the result in the action result store. + # Every action has its own result and the trail of actions that led to it. 
+ # The schema is { : , ...} + action_trail = action_trail | {ar_id: result} + action_result_store[ar_id] = action_trail + run_status = "success" + logger.bind(trail=action_trail).debug("Action run completed") + + except TimeoutError as e: + error_msg = "Action run timed out waiting for dependencies" + logger.bind(upstream_deps=upstream_deps_ar_ids).error(error_msg, exc_info=e) + run_status = "failure" + except asyncio.CancelledError as e: + error_msg = "Action run was cancelled." + logger.warning(error_msg, exc_info=e) + run_status = "canceled" + except Exception as e: + error_msg = f"Action run failed with error: {e}." + logger.error(error_msg, exc_info=e) + run_status = "failure" + finally: + if action_run_status_store[ar_id] != ActionRunStatus.SUCCESS: + # Exception was raised before the action was marked as successful + action_run_status_store[ar_id] = ActionRunStatus.FAILURE + + running_jobs_store.pop(ar_id, None) + + await emit_update_action_run_event( + status=run_status, error_msg=error_msg, result=result ) - # Mark the action as completed - action_run_status_store[action_run.id] = ActionRunStatus.SUCCESS - - # Store the result in the action result store. - # Every action has its own result and the trail of actions that led to it. - # The schema is { : , ...} - action_trail = action_trail | {ar_id: result} - action_result_store[ar_id] = action_trail - logger.debug(f"Action run {ar_id!r} completed with trail: {action_trail}.") - - except TimeoutError as e: - error_msg = f"Action run {ar_id} timed out waiting for dependencies {upstream_deps_ar_ids}." - logger.error(error_msg, exc_info=e) - run_status = "failure" - except asyncio.CancelledError as e: - error_msg = f"Action run {ar_id!r} was cancelled." - logger.warning(error_msg, exc_info=e) - run_status = "canceled" - except Exception as e: - error_msg = f"Action run {ar_id!r} failed with error: {e}." 
- logger.error(error_msg, exc_info=e) - run_status = "failure" - finally: - if action_run_status_store[ar_id] != ActionRunStatus.SUCCESS: - # Exception was raised before the action was marked as successful - action_run_status_store[ar_id] = ActionRunStatus.FAILURE - - running_jobs_store.pop(ar_id, None) - - await emit_update_action_run_event( - action_run, status=run_status, error_msg=error_msg, result=result - ) - - # Handle downstream dependencies - if run_status != "success": - logger.warning(f"Action run {ar_id!r} stopping due to failure.") - return - logger.debug(f"Remaining action runs: {running_jobs_store.keys()}") - if not result.should_continue: - logger.info(f"Action run {ar_id!r} stopping due to stop signal.") - return - try: - downstream_deps_ar_ids = action_run.downstream_dependencies( - workflow=workflow_ref, action_key=action_key + # Handle downstream dependencies + if run_status != "success": + logger.warning("Action run stopping due to failure.") + return + if not result.should_continue: + logger.info("Action run received stop signal.") + return + logger.opt(lazy=True).debug( + "Remaining action runs {ars}", ars=lambda: list(running_jobs_store.keys()) ) - # Broadcast the results to the next actions and enqueue them - for next_ar_id in downstream_deps_ar_ids: - if next_ar_id not in action_run_status_store: - action_run_status_store[next_ar_id] = ActionRunStatus.QUEUED - ready_jobs_queue.put_nowait( - ActionRun( - workflow_run_id=action_run.workflow_run_id, - action_key=parse_action_run_id(next_ar_id, "action_key"), + try: + downstream_deps_ar_ids = action_run.deps_downstream(workflow=workflow) + # Broadcast the results to the next actions and enqueue them + for next_ar_id in downstream_deps_ar_ids: + if next_ar_id not in action_run_status_store: + action_run_status_store[next_ar_id] = ActionRunStatus.QUEUED + ready_jobs_queue.put_nowait( + ActionRun( + workflow_run_id=action_run.workflow_run_id, + action_key=parse_action_run_id(next_ar_id, "action_key"), + ) ) - ) - except Exception as e: - logger.error( - f"Action run {ar_id!r} failed to broadcast results to downstream dependencies.", - exc_info=e, - ) + except Exception as e: + logger.error( + "Action run failed to broadcast results to downstream dependencies.", + exc_info=e, + ) async def run_webhook_action( @@ -507,13 +500,12 @@ async def run_webhook_action( action_run_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: """Run a webhook action.""" - logger = _get_logger() - logger.debug("Perform webhook action") - logger.debug(f"{url = }") - logger.debug(f"{method = }") - # The payload provided to the webhook action in the HTTP request action_run_kwargs = action_run_kwargs or {} - logger.debug(f"{action_run_kwargs = }") + logger.bind( + url=url, + method=method, + ar_kwargs=action_run_kwargs, + ).debug("Perform webhook action") # TODO: Perform whitelist/filter step here using the url and method return { "output": action_run_kwargs, @@ -555,12 +547,12 @@ async def run_http_request_action( action_run_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: """Run an HTTP request action.""" - logger = _get_logger() - logger.debug("Perform HTTP request action") - logger.debug(f"{url = }") - logger.debug(f"{method = }") - logger.debug(f"{headers = }") - logger.debug(f"{payload = }") + logger.bind( + url=url, + method=method, + headers=headers, + payload=payload, + ).debug("Perform HTTP request action") try: async with httpx.AsyncClient() as client: @@ -584,8 +576,7 @@ async def run_conditional_action( 
action_run_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: """Run a conditional action.""" - logger = _get_logger() - logger.debug(f"Run conditional rules {condition_rules}.") + logger.bind(rules=condition_rules).debug("Perform conditional rules action") rule = ConditionRuleValidator.validate_python(condition_rules) rule_match = rule.evaluate() return { @@ -608,16 +599,15 @@ async def run_llm_action( action_run_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: """Run an LLM action.""" - logger = _get_logger() - logger.debug("Perform LLM action") - logger.debug(f"{message = }") - logger.debug(f"{response_schema = }") + logger.bind( + message=message, + response_schema=response_schema, + ).debug("Perform LLM action") llm_kwargs = llm_kwargs or {} # TODO(perf): Avoid re-creating the task fields object if possible validated_task_fields = TaskFields.from_dict(task_fields) - logger.debug(f"{type(validated_task_fields) = }") if response_schema is None: system_context = get_system_context( @@ -663,12 +653,12 @@ async def run_send_email_action( action_run_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: """Run a send email action.""" - logger = _get_logger() - logger.debug("Perform send email action") - logger.debug(f"{sender = }") - logger.debug(f"{recipients = }") - logger.debug(f"{subject = }") - logger.debug(f"{body = }") + logger.bind( + sender=sender, + recipients=recipients, + subject=subject, + body=body, + ).debug("Perform send email action") if provider == "resend": email_provider = ResendMailProvider( @@ -695,7 +685,7 @@ async def run_send_email_action( await email_provider.send() except httpx.HTTPError as exc: msg = "Failed to post email to provider" - logger.error(msg, exc_info=exc) + logger.opt(exception=exc).error(msg, exc_info=exc) email_response = { "status": "error", "message": msg, @@ -750,7 +740,6 @@ async def run_open_case_action( # Common action_run_kwargs: dict[str, Any] | None = None, ) -> dict[str, str | dict[str, str] | None]: - logger = _get_logger() db = create_vdb_conn() tbl = db.open_table("cases") role = ctx_session_role.get() @@ -770,7 +759,7 @@ async def run_open_case_action( suppression=suppression, tags=tags, ) - logger.info(f"Sinking case: {case = }") + logger.opt(lazy=True).debug("Sinking case {case}", case=lambda: case.model_dump()) try: await asyncio.to_thread(tbl.add, [case.flatten()]) except Exception as e: @@ -787,10 +776,10 @@ async def run_integration_action( action_run_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: """Run an integration action.""" - logger = _get_logger() - logger.debug("Perform integration action") - logger.debug(f"{qualname = }") - logger.debug(f"{params = }") + logger.bind( + qualname=qualname, + params=params, + ).debug("Perform integration action") params = params or {} @@ -809,8 +798,6 @@ async def run_integration_action( async def run_action( type: ActionType, - action_run_id: str, - workflow_id: str, key: str, title: str, action_trail: dict[str, ActionRunResult], @@ -833,21 +820,19 @@ async def run_action( - transform: Apply a transformation to the data. 
""" - logger = _get_logger() - logger.debug(f"{"*" * 10} Running action {"*" * 10}") - logger.debug(f"{key = }") - logger.debug(f"{title = }") - logger.debug(f"{type = }") - logger.debug(f"{action_run_kwargs = }") - logger.debug(f"{action_kwargs = }") - logger.debug(f"{"*" * 20}") + logger.bind( + key=key, + title=title, + type=type, + action_run_kwargs=action_run_kwargs, + action_kwargs=action_kwargs, + ).info("Running action") action_runner = _ACTION_RUNNER_FACTORY[type] action_trail_json = { result.action_slug: result.output for result in action_trail.values() } - logger.debug(f"Before template eval: {action_trail_json = }") action_kwargs_with_secrets = await evaluate_templated_secrets( templated_fields=action_kwargs ) @@ -860,11 +845,13 @@ async def run_action( processed_action_kwargs.update(action_trail=action_trail) elif type == "open_case": - processed_action_kwargs.update( - action_run_id=action_run_id, workflow_id=workflow_id - ) + ar_id = ctx_action_run.get().id + workflow = ctx_workflow_run.get().workflow + processed_action_kwargs.update(action_run_id=ar_id, workflow_id=workflow.id) - logger.debug(f"{processed_action_kwargs = }") + logger.bind(processed_action_kwargs=processed_action_kwargs).debug( + "Finish processing action kwargs" + ) try: # The return value from each action runner call should be more or less what @@ -875,7 +862,7 @@ async def run_action( **processed_action_kwargs, ) except Exception as e: - logger.error(f"Error running action {title} with key {key}.", exc_info=e) + logger.bind(key=key).error("Error running action", exc_info=e) raise # Leave dunder keys inside as a form of execution context diff --git a/tracecat/runner/app.py b/tracecat/runner/app.py index 73ea08b3..82bbf0c9 100644 --- a/tracecat/runner/app.py +++ b/tracecat/runner/app.py @@ -53,12 +53,11 @@ from tracecat.auth import AuthenticatedAPIClient, Role, authenticate_service from tracecat.config import TRACECAT__API_URL, TRACECAT__APP_ENV from tracecat.contexts import ( - ctx_logger, ctx_mq_channel_pool, ctx_session_role, - ctx_workflow, + ctx_workflow_run, ) -from tracecat.logging import LoggerFactory, standard_logger +from tracecat.logging import logger from tracecat.messaging import use_channel_pool from tracecat.middleware import RequestLoggingMiddleware from tracecat.runner.actions import ( @@ -79,6 +78,7 @@ StartWorkflowResponse, WorkflowResponse, ) +from tracecat.types.workflow import WorkflowRunContext rabbitmq_channel_pool: Pool[Channel] @@ -94,8 +94,7 @@ async def lifespan(app: FastAPI): def create_app(**kwargs) -> FastAPI: global logger app = FastAPI(**kwargs) - app.logger = LoggerFactory.make_logger(name="runner.server") - logger = LoggerFactory.make_logger(name="runner") + app.logger = logger return app @@ -131,7 +130,9 @@ def create_app(**kwargs) -> FastAPI: app.add_middleware(RequestLoggingMiddleware) # TODO: Check TRACECAT__APP_ENV to set methods and headers -logger.bind(env=TRACECAT__APP_ENV, origins=cors_origins_kwargs).info("App started") +logger.bind(env=TRACECAT__APP_ENV, origins=cors_origins_kwargs).warning( + "Runner started" +) class RunnerStatus(StrEnum): @@ -140,9 +141,6 @@ class RunnerStatus(StrEnum): SHUTTING_DOWN = auto() -runner_status: RunnerStatus = RunnerStatus.RUNNING - - # Dynamic data action_result_store: dict[str, ActionTrail] = {} action_run_status_store: dict[str, ActionRunStatus] = {} @@ -157,7 +155,7 @@ async def get_workflow(workflow_id: str) -> Workflow: response = await client.get(f"/workflows/{workflow_id}") response.raise_for_status() except HTTPException 
as e: - logger.error(e.detail) + logger.opt(exception=e).error(e.detail) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An error occurred while fetching the workflow.", @@ -183,6 +181,8 @@ async def valid_workflow(workflow_id: str) -> str: # Catch-all exception handler to prevent stack traces from leaking @app.exception_handler(Exception) async def custom_exception_handler(request: Request, exc: Exception): + role = ctx_session_role.get() + logger.opt(exception=exc).bind(role=role).error("An unexpected error occurred.") return ORJSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": "An unexpected error occurred. Please try again later."}, @@ -296,30 +296,32 @@ async def webhook( - Spawn a new process to handle the event. - Store the process in a queue. """ - logger.info(f"Received webhook with entrypoint {webhook_metadata.action_key}") - logger.debug(f"{payload =}") + logger.bind(entrypoint=webhook_metadata.action_key).info("Webhook hit") + logger.bind(payload=payload).debug("Webhook payload") user_id = webhook_metadata.owner_id # If we are here this should be set role = Role(type="service", service_id="tracecat-runner", user_id=user_id) ctx_session_role.set(role) - logger.info(f"Set session role context for {role}") - workflow_id = webhook_metadata.workflow_id - workflow_response = await get_workflow(workflow_id) - if workflow_response.status == "offline": - return StartWorkflowResponse( - status="error", message="Workflow offline", id=workflow_id - ) + with logger.contextualize(role=role): + logger.info("Triggering workflow from webhook") + workflow_id = webhook_metadata.workflow_id + workflow_response = await get_workflow(workflow_id) + if workflow_response.status == "offline": + logger.error("Workflow offline") + return StartWorkflowResponse( + status="error", message="Workflow offline", id=workflow_id + ) - # This data refers to the webhook specific data - response = await start_workflow( - role=role, - workflow_id=workflow_id, - start_workflow_params=StartWorkflowParams( - entrypoint_key=webhook_metadata.action_key, entrypoint_payload=payload - ), - background_tasks=background_tasks, - ) - return response + # This data refers to the webhook specific data + response = await start_workflow( + role=role, + workflow_id=workflow_id, + start_workflow_params=StartWorkflowParams( + entrypoint_key=webhook_metadata.action_key, entrypoint_payload=payload + ), + background_tasks=background_tasks, + ) + return response @app.post("/workflows/{workflow_id}") @@ -400,81 +402,74 @@ async def run_workflow( associate the worker with a specific workflow. - The `start_workflow` function can then just directly enqueue the first action. 
""" - # TODO: Move some of these into ContextVars - workflow_run_id = uuid4().hex - run_logger = standard_logger(f"wfr-{workflow_run_id}") - ctx_logger.set(run_logger) - try: - await emit_create_workflow_run_event( - workflow_id=workflow_id, workflow_run_id=workflow_run_id - ) - workflow = await get_workflow(workflow_id) - run_logger.info(f"Set workflow context for user {workflow.owner_id}") - ctx_workflow.set(workflow) - - # Initial state - ready_jobs_queue.put_nowait( - ActionRun( - workflow_run_id=workflow_run_id, - run_kwargs=entrypoint_payload, - action_key=entrypoint_key, + workflow_run_id = f"wfr_{uuid4().hex}" + role = ctx_session_role.get() + run_status: RunStatus = "failure" + + with logger.contextualize( + user_id=role.user_id, workflow_id=workflow_id, wfr_id=workflow_run_id + ): + run_logger = logger.bind(tag="runner.queue") + try: + workflow = await get_workflow(workflow_id) + run_logger.info("Set workflow context") + run_context = WorkflowRunContext( + workflow=workflow, workflow_run_id=workflow_run_id, status="pending" ) - ) - - run_status: RunStatus = "success" - - await emit_update_workflow_run_event( - workflow_id=workflow_id, - workflow_run_id=workflow_run_id, - status="running", - ) - while ( - not ready_jobs_queue.empty() or running_jobs_store - ) and runner_status == RunnerStatus.RUNNING: - try: - action_run = await asyncio.wait_for(ready_jobs_queue.get(), timeout=3) - except TimeoutError: - continue - # Defensive: Deduplicate tasks - if ( - action_run.id in running_jobs_store - or action_run.id in action_result_store - ): - run_logger.debug( - f"Action {action_run.id!r} already running or completed. Skipping." + ctx_workflow_run.set(run_context) + await emit_create_workflow_run_event() + + # Initial state + ready_jobs_queue.put_nowait( + ActionRun( + workflow_run_id=workflow_run_id, + run_kwargs=entrypoint_payload, + action_key=entrypoint_key, ) - continue - - run_logger.info( - f"{workflow.actions[action_run.action_key].__class__.__name__} {action_run.id!r} ready. Running." ) - action_run_status_store[action_run.id] = ActionRunStatus.PENDING - # Schedule a new action run - running_jobs_store[action_run.id] = asyncio.create_task( - start_action_run( - action_run=action_run, - workflow_ref=workflow, - ready_jobs_queue=ready_jobs_queue, - running_jobs_store=running_jobs_store, - action_result_store=action_result_store, - action_run_status_store=action_run_status_store, - pending_timeout=120, + + await emit_update_workflow_run_event(status="running") + while not ready_jobs_queue.empty() or running_jobs_store: + try: + action_run = await asyncio.wait_for( + ready_jobs_queue.get(), timeout=3 + ) + except TimeoutError: + continue + # Defensive: Deduplicate tasks + if ( + action_run.id in running_jobs_store + or action_run.id in action_result_store + ): + run_logger.debug( + f"Action {action_run.id!r} already running or completed. Skipping." 
+ ) + continue + + run_logger.bind(ar_id=action_run.id).info("Creating action run task") + action_run_status_store[action_run.id] = ActionRunStatus.PENDING + # Schedule a new action run + running_jobs_store[action_run.id] = asyncio.create_task( + start_action_run( + action_run=action_run, + ready_jobs_queue=ready_jobs_queue, + running_jobs_store=running_jobs_store, + action_result_store=action_result_store, + action_run_status_store=action_run_status_store, + pending_timeout=120, + ) ) - ) - run_logger.info("Workflow completed.") - except asyncio.CancelledError: - run_logger.warning("Workflow was canceled.", exc_info=True) - run_status = "canceled" - except Exception as e: - run_logger.error(f"Workflow failed: {e}", exc_info=True) - run_status = "failure" - finally: - run_logger.info("Shutting down running tasks") - for running_task in running_jobs_store.values(): - running_task.cancel() - - # TODO: Update this to update with status 'failure' if any action fails - await emit_update_workflow_run_event( - workflow_id=workflow_id, workflow_run_id=workflow_run_id, status=run_status - ) + run_status = "success" + run_logger.info("Workflow completed.") + except asyncio.CancelledError: + logger.warning("Workflow was canceled.", exc_info=True) + run_status = "canceled" + except Exception as e: + logger.error(f"Workflow failed: {e}", exc_info=True) + finally: + logger.info("Shutting down running tasks") + for running_task in running_jobs_store.values(): + running_task.cancel() + # TODO: Update this to update with status 'failure' if any action fails + await emit_update_workflow_run_event(status=run_status) diff --git a/tracecat/runner/events.py b/tracecat/runner/events.py index ae5080d4..c9acf6d0 100644 --- a/tracecat/runner/events.py +++ b/tracecat/runner/events.py @@ -13,25 +13,31 @@ from typing import TYPE_CHECKING from tracecat.auth import AuthenticatedAPIClient -from tracecat.contexts import ctx_mq_channel_pool, ctx_session_role +from tracecat.contexts import ( + ctx_action_run, + ctx_mq_channel_pool, + ctx_session_role, + ctx_workflow_run, +) from tracecat.db import ActionRun as ActionRunEvent from tracecat.db import WorkflowRun as WorkflowRunEvent -from tracecat.logging import standard_logger +from tracecat.logging import logger from tracecat.messaging import publish from tracecat.types.api import RunStatus if TYPE_CHECKING: - from tracecat.runner.actions import ActionRun, ActionRunResult -logger = standard_logger(__name__) + from tracecat.runner.actions import ActionRunResult ## Workflow Run Events -async def emit_create_workflow_run_event( - *, workflow_id: str, workflow_run_id: str -) -> None: +async def emit_create_workflow_run_event() -> None: """Create a workflow run.""" role = ctx_session_role.get() + wfr = ctx_workflow_run.get() + workflow_run_id = wfr.workflow_run_id + workflow_id = wfr.workflow.id + time_now = datetime.now(UTC) event = WorkflowRunEvent( id=workflow_run_id, @@ -55,17 +61,17 @@ async def emit_create_workflow_run_event( routing_keys=[role.user_id], payload={"type": "workflow_run", **event.model_dump()}, ) - logger.info(f"Emitted create workflow run event: {workflow_id=}") + logger.bind(name="events.create_wfr", role=role, workflow_id=workflow_id).debug( + "Emitted event" + ) -async def emit_update_workflow_run_event( - *, - workflow_id: str, - workflow_run_id: str, - status: RunStatus, -) -> None: +async def emit_update_workflow_run_event(*, status: RunStatus) -> None: """Update a workflow run.""" role = ctx_session_role.get() + wfr = ctx_workflow_run.get() + 
workflow_id = wfr.workflow.id + workflow_run_id = wfr.workflow_run_id time_now = datetime.now(UTC) event = WorkflowRunEvent( @@ -92,14 +98,21 @@ async def emit_update_workflow_run_event( routing_keys=[role.user_id], payload={"type": "workflow_run", **event.model_dump()}, ) - logger.info(f"Emitted update workflow run event: {workflow_run_id=}, {status=}") + logger.bind( + name="events.update_wfr", + role=role, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + status=status, + ).debug("Emitted event") ## Action Run Events -async def emit_create_action_run_event(action_run: ActionRun) -> None: +async def emit_create_action_run_event() -> None: """Create a workflow run.""" + action_run = ctx_action_run.get() action_id = action_run.action_id role = ctx_session_role.get() @@ -127,17 +140,21 @@ async def emit_create_action_run_event(action_run: ActionRun) -> None: routing_keys=[role.user_id], payload={"type": "action_run", **event.model_dump()}, ) - logger.info(f"Emitted create action run event: {action_run.id=}") + logger.bind( + name="events.create_ar", + action_id=action_id, + role=role, + ).debug("Emitted event") async def emit_update_action_run_event( - action_run: ActionRun, *, status: RunStatus, error_msg: str | None = None, result: ActionRunResult | None = None, ) -> None: """Update a workflow run.""" + action_run = ctx_action_run.get() action_id = action_run.action_id role = ctx_session_role.get() @@ -172,4 +189,10 @@ async def emit_update_action_run_event( routing_keys=[role.user_id], payload={"type": "action_run", **event.model_dump()}, ) - logger.info(f"Emitted update action run event: {action_run.id=}, {status=}.") + logger.bind( + name="events.update_ar", + role=role, + action_id=action_id, + action_run_id=action_run.id, + status=status, + ).debug("Emitted event") diff --git a/tracecat/types/workflow.py b/tracecat/types/workflow.py new file mode 100644 index 00000000..aac480f7 --- /dev/null +++ b/tracecat/types/workflow.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + +from tracecat.runner.workflows import Workflow +from tracecat.types.api import RunStatus + + +class WorkflowRunContext(BaseModel): + workflow_run_id: str + workflow: Workflow + status: RunStatus
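
Note on the ContextVar pattern introduced in this diff: instead of threading `workflow_ref`, `action_run_id`, and `workflow_id` through function signatures, the run context is stashed in `ctx_workflow_run` / `ctx_action_run` and read back by helpers such as the event emitters in `tracecat/runner/events.py`. The snippet below is a minimal, self-contained sketch of that pattern, not the tracecat implementation; the `RunContext` dataclass, ids, and `emit_event` helper are hypothetical stand-ins.

```python
import asyncio
from contextvars import ContextVar
from dataclasses import dataclass


@dataclass
class RunContext:  # stand-in for WorkflowRunContext / ActionRun
    workflow_run_id: str
    workflow_id: str


ctx_run: ContextVar[RunContext | None] = ContextVar("run", default=None)


async def emit_event(status: str) -> None:
    # Reads the ambient run context instead of taking ids as parameters,
    # mirroring the new emit_*_event signatures in this diff.
    run = ctx_run.get()
    print(f"event wfr={run.workflow_run_id} wf={run.workflow_id} status={status}")


async def run_one(workflow_id: str, n: int) -> None:
    # Each asyncio task runs in its own copy of the context, so concurrent
    # workflow runs do not overwrite each other's ContextVar values.
    ctx_run.set(RunContext(workflow_run_id=f"wfr_{n}", workflow_id=workflow_id))
    await emit_event("running")
    await asyncio.sleep(0)
    await emit_event("success")


async def main() -> None:
    await asyncio.gather(run_one("wf_a", 1), run_one("wf_b", 2))


asyncio.run(main())
```

Because `asyncio.create_task` snapshots the current context, the per-run values set here stay isolated per task, which is what makes it safe to drop the explicit parameters.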
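
Note on the new `tracecat/types/workflow.py` module: `WorkflowRunContext` is the pydantic model stored in `ctx_workflow_run` before the first action is enqueued. The sketch below is self-contained for illustration; the `RunStatus` literal values are inferred from the statuses used elsewhere in this diff, and the `Workflow` class is a stand-in, not the real `tracecat.types.api` / `tracecat.runner.workflows` definitions.

```python
from typing import Literal

from pydantic import BaseModel

# Assumed status values, based on "pending"/"running"/"success"/"failure"/"canceled"
# appearing in this diff; the real RunStatus lives in tracecat.types.api.
RunStatus = Literal["pending", "running", "success", "failure", "canceled"]


class Workflow(BaseModel):  # stand-in for tracecat.runner.workflows.Workflow
    id: str
    title: str


class WorkflowRunContext(BaseModel):
    workflow_run_id: str
    workflow: Workflow
    status: RunStatus


run_context = WorkflowRunContext(
    workflow_run_id="wfr_1234abcd",  # hypothetical id, matching the "wfr_<hex>" format
    workflow=Workflow(id="wf_1", title="Example workflow"),
    status="pending",
)
print(run_context.model_dump())
```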
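
Note on the logging changes: the per-module `standard_logger` / `_get_logger` pattern is replaced by one shared `logger` whose `.bind()`, `.contextualize()`, and `.opt(lazy=True)` calls match loguru's API (an assumption based on those call signatures; the real object lives in `tracecat.logging`). A minimal sketch of the three calls, using loguru directly:

```python
import sys

from loguru import logger

# Show the bound/contextualized fields alongside each message.
logger.remove()
logger.add(sys.stderr, format="{time:HH:mm:ss} | {level} | {message} | {extra}")


def start_action_run(ar_id: str) -> None:
    # contextualize() attaches fields to every record emitted in this scope.
    with logger.contextualize(ar_id=ar_id):
        # bind() attaches per-message structured fields.
        logger.bind(deps=["ar_1", "ar_2"]).debug("Waiting for dependencies")
        # opt(lazy=True) defers the callable; it only runs if DEBUG is actually emitted.
        logger.opt(lazy=True).debug("Trail {trail}", trail=lambda: ["webhook", "llm"])
        logger.info("Action run completed")


start_action_run("ar_123")
```

Run-scoped fields go in `contextualize()`, per-message fields in `bind()`, and anything expensive to compute is deferred with `opt(lazy=True)`, which is why the old eager f-string debug lines could be dropped.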
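
Note on the scheduling loop in `run_workflow` / `start_action_run`: actions are pulled from a ready queue, each action run waits for its upstream results, and on success it enqueues its downstream dependencies from the workflow's adjacency list. The sketch below shows that shape with a hypothetical three-action graph and simplified stores; it polls for dependencies instead of the status store plus `asyncio.wait_for(..., timeout=pending_timeout)` used in the diff.

```python
import asyncio

# Hypothetical three-action graph (not a real tracecat workflow):
# adjacency list (downstream edges) and its reverse (upstream dependencies).
ADJ = {"webhook": ["enrich"], "enrich": ["notify"], "notify": []}
DEPS = {"webhook": [], "enrich": ["webhook"], "notify": ["enrich"]}

results: dict[str, str] = {}                 # action key -> output
running: dict[str, asyncio.Task[None]] = {}  # in-flight action runs
ready: asyncio.Queue[str] = asyncio.Queue()  # actions ready to schedule


async def start_action(key: str) -> None:
    # Wait until every upstream result exists.
    while not all(dep in results for dep in DEPS[key]):
        await asyncio.sleep(0.01)
    results[key] = f"{key}-output"
    running.pop(key, None)
    # Broadcast to downstream dependencies and enqueue them.
    for nxt in ADJ[key]:
        if nxt not in results and nxt not in running:
            ready.put_nowait(nxt)


async def run_workflow(entrypoint: str) -> None:
    ready.put_nowait(entrypoint)
    # Keep scheduling while anything is queued or still running.
    while not ready.empty() or running:
        try:
            key = await asyncio.wait_for(ready.get(), timeout=0.1)
        except TimeoutError:  # Python 3.11+: wait_for raises the builtin TimeoutError
            continue
        if key in running or key in results:
            continue  # defensive deduplication, as in the diff
        running[key] = asyncio.create_task(start_action(key))
    print(results)


asyncio.run(run_workflow("webhook"))
```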