Commit

refactor(engine): Replace standard_logger
daryllimyt committed May 8, 2024
1 parent f31bc17 commit 35b56bb
Showing 21 changed files with 76 additions and 112 deletions.
10 changes: 4 additions & 6 deletions tracecat/api/app.py
@@ -42,7 +42,7 @@
     create_vdb_conn,
     initialize_db,
 )
-from tracecat.logging import LoggerFactory
+from tracecat.logging import Logger
 
 # TODO: Clean up API params / response "zoo"
 # lots of repetition and inconsistency
@@ -83,6 +83,7 @@
 )
 from tracecat.types.cases import Case, CaseMetrics
 
+logger = Logger("api")
 engine: Engine
 rabbitmq_channel_pool: Pool[Channel]
 
@@ -99,8 +100,7 @@ async def lifespan(app: FastAPI):
 def create_app(**kwargs) -> FastAPI:
     global logger
     app = FastAPI(**kwargs)
-    app.logger = LoggerFactory.make_logger(name="api.server")
-    logger = LoggerFactory.make_logger(name="api")
+    app.logger = logger
     return app
 
 
@@ -136,7 +136,7 @@ def create_app(**kwargs) -> FastAPI:
     app.add_middleware(RequestLoggingMiddleware)
 
     # TODO: Check TRACECAT__APP_ENV to set methods and headers
-    logger.bind(env=TRACECAT__APP_ENV, origins=cors_origins_kwargs).info("App started")
+    logger.bind(env=TRACECAT__APP_ENV, origins=cors_origins_kwargs).warning("App started")
 
 
 # Catch-all exception handler to prevent stack traces from leaking
@@ -1291,7 +1291,6 @@ async def streaming_autofill_case_fields(
     3. Complete the fields
     """
-    logger.info(f"Received: {cases = }, {role = }, {fields = }")
     fields_set = set(fields)
 
     if not fields_set:
@@ -1462,7 +1461,6 @@ def get_secret(
     Support access for both user and service roles."""
 
-    logger.info(f"Role: {role}")
     with Session(engine) as session:
         # Check if secret exists
         statement = (
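Note on the new API: `Logger(name)` plus chained `.bind(...)` matches loguru's interface, so `tracecat.logging` is presumably a thin wrapper over a shared loguru logger. The module itself is not part of this diff; a minimal sketch of what it might expose, assuming loguru:

```python
# Hypothetical sketch of tracecat/logging.py (not shown in this commit).
# Assumes the package wraps loguru: Logger(name) returns a child logger with
# the component name attached as structured context via bind().
from loguru import logger


def Logger(name: str):
    """Return a loguru logger with `name` bound as an extra field."""
    return logger.bind(name=name)
```

Under that assumption, components that want a named logger call `Logger("api")`, `Logger("db")`, and so on, while modules that only log ad hoc import the shared `logger` directly.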
42 changes: 20 additions & 22 deletions tracecat/api/completions.py
@@ -8,10 +8,9 @@
 from slugify import slugify
 
 from tracecat.llm import async_openai_call
-from tracecat.logging import standard_logger
+from tracecat.logging import logger
 from tracecat.types.cases import Case
 
-logger = standard_logger(__name__)
 T = TypeVar("T", bound=BaseModel)
 
 
@@ -246,29 +245,28 @@ async def stream_case_completions(
         output_cls=CaseMissingFieldsResponse,
         field_cons=field_cons,
     )
     # logger.info(f"🧠 Starting case completions for %d cases... {system_context =}")
 
-    logger.critical(system_context)
-    logger.critical(f"{field_cons =}")
-    logger.info("🧠 Starting case completions for %d cases...", len(cases))
+    logger.bind(system_context=system_context).debug("System context")
 
     async def task(case: Case) -> str:
         prompt = f"""Case JSON Object: ```\n{case.model_dump_json()}\n```"""
-        logger.info(f"🧠 Starting case completion for case {case.id}...")
-        response: dict[str, str] = await async_openai_call(
-            prompt=prompt,
-            model="gpt-4-turbo-preview",
-            system_context=system_context,
-            response_format="json_object",
-            max_tokens=200,
-        )
-        # We might have to perform additional matching / postprocessing here
-        # Depending on what we return.
-        result = CaseCompletionResponse.model_validate(
-            {"id": case.id, "response": response}
-        )
-        # await asyncio.sleep(random.uniform(1, 10))
-        logger.info(f"🧠 Completed case completion for case {case.id}")
-        return result.model_dump_json()
+        with logger.contextualize(case_id=case.id):
+            logger.info("🧠 Starting case completion")
+            response: dict[str, str] = await async_openai_call(
+                prompt=prompt,
+                model="gpt-4-turbo-preview",
+                system_context=system_context,
+                response_format="json_object",
+                max_tokens=200,
+            )
+            # We might have to perform additional matching / postprocessing here
+            # Depending on what we return.
+            result = CaseCompletionResponse.model_validate(
+                {"id": case.id, "response": response}
+            )
+            # await asyncio.sleep(random.uniform(1, 10))
+            logger.info("🧠 Completed case completion")
+            return result.model_dump_json()
 
     tasks = [task(case) for case in cases]
 
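The move from f-string logs to `logger.contextualize(case_id=...)` is idiomatic loguru: `bind()` returns a new logger object carrying extra fields, while `contextualize()` temporarily pushes fields into the ambient context so every log call inside the block inherits them. A standalone illustration of the two mechanisms:

```python
from loguru import logger

# bind() attaches fields to one logger object; only calls made through
# that object carry the extra context.
case_logger = logger.bind(case_id="case-123")
case_logger.info("Only this logger carries case_id")

# contextualize() scopes fields to a block: every log emitted inside it,
# including logs from nested helper functions, inherits case_id.
with logger.contextualize(case_id="case-123"):
    logger.info("Starting case completion")  # extra includes case_id
```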
4 changes: 1 addition & 3 deletions tracecat/auth.py
@@ -20,9 +20,7 @@
 import tracecat.config as cfg
 from tracecat.config import TRACECAT__API_URL, TRACECAT__RUNNER_URL
 from tracecat.contexts import ctx_session_role
-from tracecat.logging import standard_logger
-
-logger = standard_logger(__name__)
+from tracecat.logging import logger
 
 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
 api_key_header_scheme = APIKeyHeader(name="X-API-Key", auto_error=False)
10 changes: 5 additions & 5 deletions tracecat/concurrency.py
@@ -6,18 +6,18 @@
 
 from tracecat.auth import Role
 from tracecat.contexts import ctx_session_role
-from tracecat.logging import standard_logger
-
-logger = standard_logger(__name__)
+from tracecat.logging import Logger
 
 _P = ParamSpec("_P")
 
+logger = Logger("executor.cloudpickle")
+
 
 def _run_serialized_fn(serialized_wrapped_fn: bytes, role: Role, /, *args, **kwargs):
     # NOTE: This is not the raw function - it is still wrapped by the `wrapper` decorator
     wrapped_fn: Callable[_P, Any] = cloudpickle.loads(serialized_wrapped_fn)
     ctx_session_role.set(role)
-    logger.debug(f"{role=}")
+    logger.bind(role=role).debug("Running serialized function")
     kwargs["__role"] = role
     res = wrapped_fn(*args, **kwargs)
     return res
@@ -28,6 +28,6 @@ class CloudpickleProcessPoolExecutor(ProcessPoolExecutor):
     def submit(self, fn: Callable[_P, Any], /, *args, **kwargs):
         # We need to pass the role to the function running in the child process
         role = ctx_session_role.get()
-        logger.debug(f"{role=}")
+        logger.bind(role=role).debug("Serializing function")
         serialized_fn = cloudpickle.dumps(fn)
         return super().submit(_run_serialized_fn, serialized_fn, role, *args, **kwargs)
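Design note: `contextvars` state does not propagate into `ProcessPoolExecutor` workers, which is why the executor ships the role alongside the pickled function and re-sets `ctx_session_role` inside `_run_serialized_fn`. cloudpickle is needed because the wrapped function is a closure, which the stdlib pickler rejects; a minimal demonstration:

```python
import pickle

import cloudpickle


def make_closure(offset: int):
    def add(x: int) -> int:
        return x + offset  # captures `offset` from the enclosing scope

    return add


fn = make_closure(10)
payload = cloudpickle.dumps(fn)  # cloudpickle serializes closures by value
assert cloudpickle.loads(payload)(1) == 11

try:
    pickle.dumps(fn)  # the stdlib pickler refuses local/closure functions
except (pickle.PicklingError, AttributeError):
    pass
```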
19 changes: 0 additions & 19 deletions tracecat/config.py
@@ -22,22 +22,3 @@
     "tracecat-api",
     "tracecat-scheduler",
 ]
-
-
-_DEFAULT_LOG_FORMAT = (
-    "<level>{level: <8}</level>"
-    " [<cyan>{time:YYYY-MM-DD HH:mm:ss.SSS}</cyan>]"
-    " [<green>{process}</green>][<magenta>{thread}</magenta>]"
-    " <light-red>{name}</light-red>:<light-red>{function}</light-red>"
-    " - <level>{message}</level> | {extra}"
-)
-
-LOG_CONFIG = {
-    "logger": {
-        "path": "/var/lib/tracecat/logs",
-        "level": "info",
-        "rotation": "20 days",
-        "retention": "1 months",
-        "format": _DEFAULT_LOG_FORMAT,
-    },
-}
7 changes: 7 additions & 0 deletions tracecat/contexts.py
@@ -9,10 +9,17 @@
 from aio_pika.pool import Pool
 
 from tracecat.auth import Role
+from tracecat.runner.actions import ActionRun
+from tracecat.runner.workflows import Workflow
+from tracecat.types.workflow import WorkflowRunContext
 
 
 ctx_session_role: ContextVar[Role] = ContextVar("session_role", default=None)
+ctx_workflow: ContextVar[Workflow] = ContextVar("workflow", default=None)
+ctx_workflow_run: ContextVar[WorkflowRunContext] = ContextVar(
+    "workflow_run", default=None
+)
+ctx_action_run: ContextVar[ActionRun] = ContextVar("action_run", default=None)
 ctx_mq_channel_pool: ContextVar[Pool[Channel]] = ContextVar(
     "mq_channel_pool", default=None
 )
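The new ContextVars follow the same pattern as `ctx_session_role`: set once near the top of a workflow or action run, read anywhere below it without threading parameters through every call. Because each asyncio task snapshots the current context at creation, concurrent runs keep their values isolated. A standalone sketch (the `run_id` variable is hypothetical, not from this diff):

```python
import asyncio
from contextvars import ContextVar

ctx_run_id: ContextVar[str | None] = ContextVar("run_id", default=None)


async def do_work() -> str | None:
    # Sees the value set in this task's context, with no parameter passing.
    return ctx_run_id.get()


async def run(run_id: str) -> str | None:
    ctx_run_id.set(run_id)
    return await do_work()


async def main() -> None:
    # Each task gets its own context snapshot, so the ids never interleave.
    results = await asyncio.gather(run("run-1"), run("run-2"))
    assert results == ["run-1", "run-2"]


asyncio.run(main())
```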
4 changes: 2 additions & 2 deletions tracecat/db.py
@@ -27,13 +27,13 @@
     TRACECAT_DIR,
 )
 from tracecat.labels.mitre import get_mitre_tactics_techniques
-from tracecat.logging import standard_logger
+from tracecat.logging import Logger
 from tracecat.types.secrets import SECRET_FACTORY, SecretBase, SecretKeyValue
 
 if TYPE_CHECKING:
     from tracecat.integrations import IntegrationSpec
 
-logger = standard_logger("db")
+logger = Logger("db")
 
 
 STORAGE_PATH = TRACECAT_DIR / "storage"
4 changes: 2 additions & 2 deletions tracecat/etl/aws_cloudtrail.py
@@ -20,9 +20,9 @@
 
 from tracecat.config import TRACECAT__TRIAGE_DIR
 from tracecat.etl.aws_s3 import list_objects_under_prefix
-from tracecat.logging import standard_logger
+from tracecat.logging import Logger
 
-logger = standard_logger("runner.aws_cloudtrail")
+logger = Logger("runner.aws_cloudtrail")
 
 # Supress botocore info logs
 logging.getLogger("botocore").setLevel(logging.CRITICAL)
14 changes: 7 additions & 7 deletions tracecat/integrations/_registry.py
@@ -16,14 +16,12 @@
     get_integration_key,
     get_integration_platform,
 )
-from tracecat.logging import standard_logger
+from tracecat.logging import logger
 from tracecat.secrets import batch_get_secrets
 
 if TYPE_CHECKING:
     from tracecat.db import Secret
 
-logger = standard_logger(__name__)
-
 
 class Registry:
     """Singleton class to store and manage all registered integrations.
@@ -80,11 +78,11 @@ def wrapper(*args, **kwargs):
             try:
                 role = kwargs.pop("__role", None)
                 # Get secrets from the secrets API
-                self._logger = standard_logger(f"{__name__}[PID={os.getpid()}]")
-                self._logger.info("Executing %r in subprocess", key)
+                self._logger = logger.bind(pid=os.getpid())
+                self._logger.bind(key=key).info("Executing in subprocess")
 
                 if secrets:
-                    self._logger.info(f"Pulling secrets: {secrets!r}")
+                    self._logger.bind(secrets=secrets).info("Pull secrets")
                     _secrets = self._get_secrets(role=role, secret_names=secrets)
                     self._set_secrets(_secrets)
 
def _get_secrets(self, role: Role, secret_names: list[str]) -> list[Secret]:
"""Retrieve secrets from the secrets API."""

self._logger.info(f"Getting secrets {secret_names!r}")
self._logger.opt(lazy=True).debug(
"Getting secrets {secret_names}", secret_names=lambda: secret_names
)
return asyncio.run(batch_get_secrets(role, secret_names))

def _set_secrets(self, secrets: list[Secret]):
Expand Down
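The `_get_secrets` change shows loguru's lazy evaluation: under `.opt(lazy=True)`, keyword arguments that are callables are invoked only if a sink actually accepts the DEBUG level. Here the lambda merely returns a list, so the benefit is small; the pattern pays off when the logged value is expensive to build:

```python
from loguru import logger


def expensive_dump() -> str:
    # Imagine serializing a large object graph here.
    return "..."


# The lambda runs only when a DEBUG-level sink is active; in production
# (INFO and above) the expensive work is skipped entirely.
logger.opt(lazy=True).debug("State: {state}", state=lambda: expensive_dump())
```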
4 changes: 0 additions & 4 deletions tracecat/integrations/aws_cloudtrail.py
@@ -14,10 +14,6 @@
 from tracecat.etl.aws_cloudtrail import load_cloudtrail_logs
 from tracecat.etl.query_builder import pl_sql_query
 from tracecat.integrations._registry import registry
-from tracecat.logging import standard_logger
-
-logger = standard_logger(__name__)
-
 
 AWS_REGION_NAMES = Literal[
     "us-east-1",  # US East (N. Virginia)
4 changes: 0 additions & 4 deletions tracecat/integrations/datadog.py
@@ -19,10 +19,6 @@
 import polars.selectors as cs
 
 from tracecat.integrations._registry import registry
-from tracecat.logging import standard_logger
-
-logger = standard_logger(__name__)
-
 
 DD_REGION_TO_URL = {
     "ap1": "https://api.ap1.datadoghq.com",
4 changes: 0 additions & 4 deletions tracecat/integrations/virustotal.py
@@ -16,10 +16,6 @@
 import httpx
 
 from tracecat.integrations._registry import registry
-from tracecat.logging import standard_logger
-
-logger = standard_logger(__name__)
-
 
 VT_BASE_URL = "https://www.virustotal.com/api/v3/"
 
6 changes: 2 additions & 4 deletions tracecat/llm.py
@@ -10,9 +10,7 @@
 from tenacity import retry, stop_after_attempt, wait_exponential
 
 from tracecat.config import LLM_MAX_RETRIES
-from tracecat.logging import standard_logger
-
-logger = standard_logger(__name__)
+from tracecat.logging import logger
 
 ModelType = Literal[
     "gpt-4-turbo-preview",
@@ -78,7 +76,7 @@ def parse_choice(choice: Choice) -> str | dict[str, Any]:
         **kwargs,
     )
     # TODO: Should track these metrics
-    logger.info(f"🧠 {response.usage}")
+    logger.bind(usage=response.usage).info("🧠 Usage")
     if stream:
         return response
 
6 changes: 3 additions & 3 deletions tracecat/messaging/common.py
@@ -10,9 +10,9 @@
 from aio_pika.abc import AbstractRobustConnection
 from aio_pika.pool import Pool
 
-from tracecat.logging import standard_logger
+from tracecat.logging import Logger
 
-logger = standard_logger(__name__)
+logger = Logger("rabbitmq.common")
 RABBITMQ_URI = os.environ.get("RABBITMQ_URI", "amqp://guest:guest@localhost/")
 RABBITMQ_USER = os.environ.get("RABBITMQ_USER", "guest")
 RABBIMQ_PASS = os.environ.get("RABBITMQ_PASS", "guest")
@@ -37,7 +37,7 @@ async def get_connection() -> AbstractRobustConnection:
     if uri.scheme not in ("amqps", "amqp"):
         raise ValueError(f"Unsupported RabbitMQ URI scheme: {uri.scheme}")
 
-    logger.info(f"Connecting to RabbitMQ at {RABBITMQ_URI}")
+    logger.info("Connecting to RabbitMQ")
     return await aio_pika.connect_robust(
         ssl=uri.scheme == "amqps",
         login=uri.username or RABBITMQ_USER,
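The log-line change in `get_connection` looks security-motivated: AMQP URIs commonly embed credentials (`amqp://user:pass@host/`), so logging `RABBITMQ_URI` verbatim risked leaking secrets into log sinks. If the destination host is still wanted in the log, binding individual fields keeps the sensitive parts out (a sketch, not from the diff):

```python
from urllib.parse import urlparse

from loguru import logger

RABBITMQ_URI = "amqp://guest:guest@localhost/"

uri = urlparse(RABBITMQ_URI)
# Log the host and scheme, never the userinfo portion of the URI.
logger.bind(host=uri.hostname, scheme=uri.scheme).info("Connecting to RabbitMQ")
```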
8 changes: 5 additions & 3 deletions tracecat/messaging/consumer.py
@@ -13,10 +13,10 @@
     wait_exponential,
 )
 
-from tracecat.logging import standard_logger
+from tracecat.logging import Logger
 from tracecat.messaging.common import RABBITMQ_RUNNER_EVENTS_EXCHANGE
 
-logger = standard_logger(__name__)
+logger = Logger("rabbitmq.consumer")
 
 
 @asynccontextmanager
@@ -32,7 +32,9 @@ async def prepare_queue(*, channel: Channel, exchange: str, routing_keys: list[str]):
     )
 
     # Declaring random queue
-    queue = await channel.declare_queue(durable=True, exclusive=True)
+    queue = await channel.declare_queue(
+        durable=True, exclusive=True, auto_delete=True
+    )
     for routing_key in routing_keys:
         await queue.bind(ex, routing_key=routing_key)
     yield queue
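On the `declare_queue` change: with no name supplied, RabbitMQ generates a server-named queue. `exclusive=True` already ties the queue to the declaring connection and deletes it when that connection closes; `auto_delete=True` additionally removes it once the last consumer unsubscribes, so transient per-consumer queues are cleaned up promptly. A sketch of the declaration in aio_pika (connection details assumed):

```python
import aio_pika


async def transient_queue(url: str = "amqp://guest:guest@localhost/"):
    connection = await aio_pika.connect_robust(url)
    channel = await connection.channel()
    # Server-named queue: exclusive ties it to this connection, auto_delete
    # drops it when the last consumer goes away.
    queue = await channel.declare_queue(
        durable=True, exclusive=True, auto_delete=True
    )
    return connection, queue
```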
8 changes: 5 additions & 3 deletions tracecat/messaging/producer.py
@@ -6,10 +6,10 @@
 from aio_pika import Channel, DeliveryMode, ExchangeType, Message
 from aio_pika.pool import Pool
 
-from tracecat.logging import standard_logger
+from tracecat.logging import Logger
 from tracecat.messaging.common import RABBITMQ_RUNNER_EVENTS_EXCHANGE
 
-logger = standard_logger(__name__)
+logger = Logger("rabbitmq.producer")
 
 
 async def event_producer(
@@ -24,7 +24,9 @@ async def event_producer(
 
     for routing_key in routing_keys:
         await ex.publish(message, routing_key=routing_key)
-        logger.debug(f" [x] {routing_key = } Sent {message.body!r}")
+        logger.bind(routing_key=routing_key, body=message.body).debug(
+            "Published message"
+        )
 
 
 async def publish(
2 changes: 1 addition & 1 deletion tracecat/middleware/request.py
@@ -20,6 +20,6 @@ async def dispatch(self, request: Request, call_next):
             path=request.url.path,
             params=request_params,
             body=request_body,
-        ).opt(lazy=True).debug("Request")
+        ).debug("Request")
 
         return await call_next(request)
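Dropping `.opt(lazy=True)` here is consistent with the lazy-logging semantics noted above: laziness defers only callable keyword arguments supplied to the log call itself. The path, params, and body passed to `bind(...)` are already materialized by the time they are bound, so the `opt(lazy=True)` deferred nothing and removing it changes no behavior:

```python
from loguru import logger

body = b"..."  # already read from the request; computed eagerly either way
logger.bind(body=body).debug("Request")  # lazy=True would add nothing here
```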
(Diffs for the remaining 4 of the 21 changed files did not load on this page.)