Skip to content

Commit

Permalink
feat(engine): Update Secret schema
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt committed Apr 12, 2024
1 parent e27bd0c commit 440bcaa
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 29 deletions.
10 changes: 2 additions & 8 deletions frontend/src/lib/secrets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,11 @@ import { getAuthenticatedClient } from "@/lib/api"

export async function createSecret(
maybeSession: Session | null,
name: string,
value: string
secret: Secret
) {
try {
console.log("Creating secret", name, value)
const client = getAuthenticatedClient(maybeSession)
const data = {
name,
value,
}
await client.put("/secrets", JSON.stringify(data), {
await client.put("/secrets", JSON.stringify(secret), {
headers: {
"Content-Type": "application/json",
},
Expand Down
21 changes: 18 additions & 3 deletions frontend/src/types/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,33 @@ export const caseCompletionUpdateSchema = z.object({
})
export type CaseCompletionUpdate = z.infer<typeof caseCompletionUpdateSchema>

export const secretTypes = ["custom", "token", "oauth2"] as const
export type SecretType = (typeof secretTypes)[number]

const keyValueSchema = z.object({
key: z.string().min(1, "Please enter a key."),
value: z.string().min(1, "Please enter a value."),
})

const snakeCaseRegex = /^[a-z]+(_[a-z]+)*$/
export const secretSchema = z.object({
id: z.string().min(1).optional(),
name: z.string().min(1, "Please enter a secret name."),
value: z.string().min(1, "Please enter the secret value."),
type: z.enum(secretTypes),
name: z
.string()
.min(1, "Please enter a secret name.")
.regex(snakeCaseRegex, "Secret name must be snake case."),
description: z.string().max(255).nullable(),
// Can take different types of secrets
keys: z.array(keyValueSchema),
})

export type Secret = z.infer<typeof secretSchema>

export const integrationSchema = z.object({
id: z.string(),
name: z.string(),
description: z.string(),
description: z.string().nullable(),
docstring: z.string(),
parameters: stringToJSONSchema,
platform: z.enum(integrationPlatforms),
Expand Down
20 changes: 16 additions & 4 deletions tracecat/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,7 +1350,13 @@ def list_secrets(
result = session.exec(statement)
secrets = result.all()
return [
SecretResponse(id=secret.id, name=secret.name, value=secret.key)
SecretResponse(
id=secret.id,
type=secret.type,
name=secret.name,
description=secret.description,
keys=secret.keys or [],
)
for secret in secrets
]

Expand Down Expand Up @@ -1400,8 +1406,14 @@ def create_secret(
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="Secret already exists"
)
new_secret = Secret(owner_id=role.user_id, name=params.name)
new_secret.key = params.value # Set and encrypt the key
new_secret = Secret(
owner_id=role.user_id,
name=params.name,
type=params.type,
description=params.description,
tags=params.tags,
)
new_secret.keys = params.keys # Set and encrypt the key

session.add(new_secret)
session.commit()
Expand All @@ -1427,7 +1439,7 @@ def update_secret(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Secret does not exist"
)
secret.key = params.value # Set and encrypt the key
secret.keys = params.keys # Set and encrypt the key
session.add(secret)
session.commit()
session.refresh(secret)
Expand Down
45 changes: 33 additions & 12 deletions tracecat/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from datetime import datetime
from pathlib import Path
from typing import Self
from typing import Any, Self
from uuid import uuid4

import lancedb
Expand All @@ -11,7 +11,7 @@
from croniter import croniter
from pydantic import computed_field, field_validator
from slugify import slugify
from sqlalchemy import TIMESTAMP, Column, Engine, ForeignKey, String, text
from sqlalchemy import JSON, TIMESTAMP, Column, Engine, ForeignKey, String, text
from sqlmodel import (
Field,
Relationship,
Expand All @@ -23,14 +23,15 @@
)

from tracecat import auth
from tracecat.auth import decrypt, encrypt
from tracecat.auth import decrypt_object, encrypt_object
from tracecat.config import (
TRACECAT__APP_ENV,
TRACECAT__RUNNER_URL,
)
from tracecat.integrations import IntegrationSpec, registry
from tracecat.labels.mitre import get_mitre_tactics_techniques
from tracecat.logger import standard_logger
from tracecat.types.secrets import SECRET_FACTORY, SecretBase, SecretKeyValue

logger = standard_logger("db")

Expand Down Expand Up @@ -89,23 +90,43 @@ class Resource(SQLModel):


class Secret(Resource, table=True):
"""Secret model.
A secret can contain an arbitrary number of keys.
e.g.
"""

id: str | None = Field(default_factory=lambda: uuid4().hex, primary_key=True)
name: str | None = Field(default=None, max_length=255, index=True, nullable=True)
encrypted_secret: bytes | None = Field(default=None, nullable=True)
type: str # "custom", "token", "oauth2"
name: str = Field(..., max_length=255, index=True, nullable=False)
description: str | None = Field(default=None, max_length=255)
# We store this object as encrypted bytes, but first validate that it's the correct type
encrypted_keys: bytes | None = Field(default=None, nullable=True)
tags: dict[str, str] | None = Field(sa_column=Column(JSON))
owner_id: str = Field(
sa_column=Column(String, ForeignKey("user.id", ondelete="CASCADE"))
)
owner: User | None = Relationship(back_populates="secrets")

def _validate_obj(self, value: dict[str, Any]) -> SecretBase:
if self.type not in SECRET_FACTORY:
raise ValueError(f"Invalid secret type {self.type!r}")
return SECRET_FACTORY[self.type].model_validate(value)

@property
def key(self) -> str | None:
if not self.encrypted_api_key:
def keys(self) -> list[SecretKeyValue] | None:
if not self.encrypted_keys:
return None
return decrypt(self.encrypted_api_key)

@key.setter
def key(self, value: str) -> None:
self.encrypted_secret = encrypt(value)
obj = decrypt_object(self.encrypted_keys)
kv = self._validate_obj(obj)
return [SecretKeyValue(key=k, value=v) for k, v in kv.model_dump().items()]

@keys.setter
def keys(self, value: list[SecretKeyValue]) -> None:
# Convert to dict
kv = {item.key: item.value for item in value}
self._validate_obj(kv)
self.encrypted_keys = encrypt_object(kv)


class Editor(SQLModel, table=True):
Expand Down
15 changes: 13 additions & 2 deletions tracecat/types/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from tracecat.db import ActionRun, WorkflowRun
from tracecat.types.actions import ActionType
from tracecat.types.secrets import SecretKeyValue

# TODO: Consistent API design
# Action and Workflow create / update params
Expand Down Expand Up @@ -185,8 +186,16 @@ class CreateUserParams(BaseModel):


class CreateSecretParams(BaseModel):
# Secret types
# ------------
# - Custom: Arbitrary user-defined types
# - Token: A token, e.g. API Key, JWT Token (TBC)
# - OAuth2: OAuth2 Client Credentials (TBC)
type: Literal["custom"] # Support other types later
name: str
value: str
description: str | None = None
keys: list[SecretKeyValue]
tags: dict[str, str] | None = None


UpdateSecretParams = CreateSecretParams
Expand Down Expand Up @@ -251,5 +260,7 @@ class CopyWorkflowParams(BaseModel):

class SecretResponse(BaseModel):
id: str
type: Literal["custom"] # Support other types later
name: str
value: str
description: str | None = None
keys: list[SecretKeyValue]
33 changes: 33 additions & 0 deletions tracecat/types/secrets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from pydantic import BaseModel, ConfigDict


class SecretKeyValue(BaseModel):
key: str
value: str


class SecretBase(BaseModel):
pass


class CustomSecret(SecretBase):
model_config = ConfigDict(extra="allow")


# class TokenSecret(SecretBase):
# token: str


# class OAuth2Secret(SecretBase):
# client_id: str
# client_secret: str
# redirect_uri: str


SecretVariant = CustomSecret # | TokenSecret | OAuth2Secret

SECRET_FACTORY: dict[str, type[SecretBase]] = {
"custom": CustomSecret,
# "token": TokenSecret,
# "oauth2": OAuth2Secret,
}

0 comments on commit 440bcaa

Please sign in to comment.