Skip to content

Commit

Permalink
feat: Implement basic runner (#1)
Browse files Browse the repository at this point in the history
* feat(engine): Set runner endpoint default response class to orjson

* build: Remove polars

* feat(engine): Implement basic runner with workflow execution and llm + webhook + http actions

* test: Add pytest config
  • Loading branch information
daryllimyt committed Mar 3, 2024
1 parent 2fb9b08 commit 8a2ea57
Show file tree
Hide file tree
Showing 11 changed files with 915 additions and 6 deletions.
5 changes: 5 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
13 changes: 11 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ classifiers = [
dependencies = [
"fastapi",
"uvicorn",
"pydantic"
"pydantic==2.6.1",
"openai",
"tenacity",
]
dynamic = ["version"]

Expand All @@ -38,7 +40,7 @@ Repository = "https://github.com/TracecatHQ/tracecat"
[project.optional-dependencies]
runner = [
"orjson",
"polars"
"httpx[http2]",
]

[tool.hatch.version]
Expand Down Expand Up @@ -66,6 +68,13 @@ ignore = [
"W191", # indentation contains tabs
]

[tool.pytest.ini_options]
addopts = ["--strict-config", "--strict-markers", "-vvrP"]
xfail_strict = true
log_level = "INFO"
log_cli = true
log_cli_level = "INFO"

[tool.ruff.pyupgrade]
# Preserve types, even if a file imports `from __future__ import annotations`.
keep-runtime-typing = true
291 changes: 291 additions & 0 deletions tracecat/actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
from __future__ import annotations

import textwrap
from collections.abc import Awaitable, Callable
from typing import Any, Literal
from uuid import uuid4

import httpx
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential

from tracecat.config import MAX_RETRIES
from tracecat.llm import DEFAULT_MODEL_TYPE, ModelType, async_openai_call
from tracecat.logger import standard_logger
from tracecat.types import TemplatedField

logger = standard_logger(__name__)

# TODO: Add support for the rest of the Actions
ActionType = Literal[
"webhook",
"http_request",
"condition",
"llm",
"receive_email",
"send_email",
"transform",
]


class Action(BaseModel):
"""An action in a workflow graph.
An action is an instance of a Action with templated fields."""

id: str = Field(default_factory=lambda: uuid4().hex)
type: ActionType
title: str
tags: dict[str, Any] | None = None
# Templated variables to be replaced with actual values
# based on the results of the previous step
templated_fields: list[TemplatedField] = Field(default_factory=list)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Action:
action_type = data.pop("type")
action_cls = ACTION_FACTORY[action_type]
return action_cls(**data)


class ActionResult(BaseModel):
"""The result of an action."""

id: str = Field(default_factory=lambda: uuid4().hex)
action_id: str
action_title: str
data: dict[str, Any] = Field(default_factory=dict)
should_continue: bool = True


class WebhookAction(Action):
type: Literal["webhook"] = Field("webhook", frozen=True)

url: str | None = None
method: Literal["GET", "POST"] = "POST"


class HTTPRequestAction(Action):
type: Literal["http_request"] = Field("http_request", frozen=True)

url: str | None = None
method: Literal["GET", "POST", "PUT", "PATCH", "DELETE"] = "GET"
headers: dict[str, str] = Field(default_factory=dict)
payload: dict[str, Any] = Field(default_factory=dict)


class ConditionAction(Action):
type: Literal["condition"] = Field("condition", frozen=True)

# TODO: Replace placeholder
event: str | None = None


class LLMAction(Action):
"""
Represents an LLM action.
Attributes:
type (Literal["llm"]): The type of the action, which is always "llm".
instructions (str): The instructions for the LLM action.
system_context (str | None): The system context for the LLM action, if any.
model (ModelType): The model type for the LLM action.
response_schema (dict[str, Any] | None): The response schema for the LLM action, if any.
kwargs (dict[str, Any] | None): Additional keyword arguments for the LLM action, if any.
"""

type: Literal["llm"] = Field("llm", frozen=True)

instructions: str
system_context: str | None = None
model: ModelType = DEFAULT_MODEL_TYPE
response_schema: dict[str, Any] | None = None
kwargs: dict[str, Any] | None = None


ActionTrail = dict[str, ActionResult]
ActionSubclass = WebhookAction | HTTPRequestAction | ConditionAction | LLMAction


ACTION_FACTORY: dict[str, type[Action]] = {
"webhook": WebhookAction,
"http_request": HTTPRequestAction,
"condition": ConditionAction,
"llm": LLMAction,
}


async def run_action(
type: ActionType,
id: str,
title: str,
action_trail: dict[str, ActionResult],
templated_fields: list[TemplatedField],
tags: dict[str, Any] | None = None,
**action_kwargs: Any,
) -> ActionResult:
"""Run an action.
In this step we should populate the templated fields with actual values.
Each action should only receive the actual values it needs to run.
Actions
-------
- webhook: Forward the data in the POST body to the next node
- http_equest: Send an HTTP request to the specified URL, then parse the result.
- conditional: Conditional logic to trigger other actions based on the result of the previous action.
- llm: Apply a language model to the data.
- receive_email: Receive an email and parse the data.
- send_email: Send an email.
- transform: Apply a transformation to the data.
"""

logger.debug(f"Running action {title} with id {id} of type {type}.")
action_runner = ACTION_RUNNER_FACTORY[type]

# TODO: Populate the templated fields with actual values

try:
result = await action_runner(action_trail=action_trail, **action_kwargs)
except Exception as e:
logger.error(f"Error running action {title} with id {id}.", exc_info=e)
raise
return ActionResult(action_id=id, action_title=title, data=result)


async def run_webhook_action(
action_trail: ActionTrail, url: str, method: str
) -> dict[str, Any]:
"""Run a webhook action."""
logger.info("Perform webhook action")
logger.info(f"{url = }")
logger.info(f"{method = }")
return {"data": "test_webhook"}


@retry(
stop=stop_after_attempt(MAX_RETRIES),
wait=wait_exponential(multiplier=1, min=4, max=10),
)
async def run_http_request_action(
action_trail: ActionTrail,
url: str,
method: str,
headers: dict[str, str] | None,
payload: dict[str, str | bytes] | None,
) -> dict[str, Any]:
"""Run an HTTP request action."""
logger.info("Perform HTTP request action")
logger.info(f"{url = }")
logger.info(f"{method = }")
logger.info(f"{headers = }")
logger.info(f"{payload = }")

try:
async with httpx.AsyncClient(http2=True) as client:
response = await client.request(
method=method,
url=url,
headers=headers,
data=payload,
)
response.raise_for_status()
data: dict[str, Any] = response.json()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP request failed with status {e.response.status_code}.")
raise
return data


async def run_conditional_action(
action_trail: ActionTrail, event: str
) -> dict[str, Any]:
"""Run a conditional action."""
logger.debug(f"Run conditional event {event}.")
return {"data": "test_conditional"}


async def run_llm_action(
action_trail: ActionTrail,
instructions: str,
system_context: str | None = None,
model: ModelType = DEFAULT_MODEL_TYPE,
response_schema: dict[str, Any] | None = None,
kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Run an LLM action."""
logger.info("Perform LLM action")
logger.info(f"{instructions = }")
logger.info(f"{response_schema = }")

system_context = (
"You are an expert decision maker and instruction follower."
" You will be given JSON data as context to help you make a decision."
)
kwargs = kwargs or {}
if response_schema is None:
prompt = textwrap.dedent(
f"""
Your objective is the following: {instructions}
You have also been provided with the following JSON data of the previous task execution results.
The keys are the action ids and the values are the results of the actions.
```
{action_trail}
```
You must complete the objective using the past task execution data.
"""
)
logger.debug(f"Prompt: {prompt}")
text_response: str = await async_openai_call(
prompt=prompt,
model=model,
system_context=system_context,
response_format="text",
**kwargs,
)
return {"response": text_response}
else:
prompt = textwrap.dedent(
f"""
Your objective is the following: {instructions}
You have also been provided with the following JSON data of the previous task execution results:
```
{action_trail}
```
You must complete the objective using the past task execution data.
Create a `JSONDataResponse` according to the following pydantic model:
```
class JSONDataResponse(BaseModel):
{"\n".join(f"\t{k}: {v}" for k, v in response_schema.items())}
```
"""
)
logger.debug(f"Prompt: {prompt}")
json_response: dict[str, Any] = await async_openai_call(
prompt=prompt,
model=model,
system_context=system_context,
response_format="json_object",
**kwargs,
)
if "JSONDataResponse" in json_response:
inner_dict: dict[str, str] = json_response["JSONDataResponse"]
return inner_dict
return json_response


_ActionRunner = Callable[..., Awaitable[dict[str, Any]]]

ACTION_RUNNER_FACTORY: dict[ActionType, _ActionRunner] = {
"webhook": run_webhook_action,
"http_request": run_http_request_action,
"condition": run_conditional_action,
"llm": run_llm_action,
}
2 changes: 1 addition & 1 deletion tracecat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@


@app.get("/")
def root():
def root() -> dict[str, str]:
return {"message": "Hello world. I am the API."}
1 change: 1 addition & 0 deletions tracecat/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
MAX_RETRIES = 3
24 changes: 24 additions & 0 deletions tracecat/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
def find_entrypoint(graph: dict[str, list[str]]) -> str:
"""Find the entrypoint of a workflow.
This is the first node in the graph with no incoming edges.
"""

nodes = find_entrypoints(graph)

if len(nodes) == 0:
raise ValueError("No entrypoints found.")
return nodes.pop()


def find_entrypoints(graph: dict[str, list[str]]) -> list[str]:
"""Find the entrypoints of a graph.
This is the first node in the graph with no incoming edges.
"""

nodes = set(graph.keys())
for edges in graph.values():
nodes.difference_update(edges)

return list(nodes)
Loading

0 comments on commit 8a2ea57

Please sign in to comment.