Skip to content

Commit

Permalink
update organization API (#480)
Browse files Browse the repository at this point in the history
  • Loading branch information
ykeremy committed Jun 17, 2024
1 parent af81fb7 commit 10612f0
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 3 deletions.
5 changes: 5 additions & 0 deletions skyvern/forge/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from skyvern.forge import app as forge_app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.routes.agent_protocol import base_router
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.scheduler import SCHEDULER
Expand Down Expand Up @@ -75,6 +76,10 @@ def start_scheduler() -> None:

LOG.info("Server startup complete. Skyvern is now online")

@app.exception_handler(NotFoundError)
async def handle_not_found_error(request: Request, exc: NotFoundError) -> Response:
return Response(status_code=status.HTTP_404_NOT_FOUND)

@app.exception_handler(SkyvernHTTPException)
async def handle_skyvern_http_exception(request: Request, exc: SkyvernHTTPException) -> JSONResponse:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.message})
Expand Down
26 changes: 26 additions & 0 deletions skyvern/forge/sdk/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,32 @@ async def create_organization(

return convert_to_organization(org)

async def update_organization(
self,
organization_id: str,
organization_name: str | None = None,
webhook_callback_url: str | None = None,
max_steps_per_run: int | None = None,
max_retries_per_step: int | None = None,
) -> Organization:
async with self.Session() as session:
organization = (
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
).first()
if not organization:
raise NotFoundError
if organization_name:
organization.organization_name = organization_name
if webhook_callback_url:
organization.webhook_callback_url = webhook_callback_url
if max_steps_per_run:
organization.max_steps_per_run = max_steps_per_run
if max_retries_per_step:
organization.max_retries_per_step = max_retries_per_step
await session.commit()
await session.refresh(organization)
return Organization.model_validate(organization)

async def get_valid_org_auth_token(
self,
organization_id: str,
Expand Down
4 changes: 2 additions & 2 deletions skyvern/forge/sdk/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class OrganizationModel(Base):
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime,
onupdate=datetime.datetime.utcnow,
nullable=False,
)

Expand All @@ -133,7 +133,7 @@ class OrganizationAuthTokenModel(Base):
modified_at = Column(
DateTime,
default=datetime.datetime.utcnow,
onupdate=datetime.datetime,
onupdate=datetime.datetime.utcnow,
nullable=False,
)
deleted_at = Column(DateTime, nullable=True)
Expand Down
4 changes: 3 additions & 1 deletion skyvern/forge/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from enum import StrEnum

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.webeye.actions.actions import ActionType
Expand Down Expand Up @@ -118,6 +118,8 @@ def is_terminated(self) -> bool:


class Organization(BaseModel):
model_config = ConfigDict(from_attributes=True)

organization_id: str
organization_name: str
webhook_callback_url: str | None = None
Expand Down
16 changes: 16 additions & 0 deletions skyvern/forge/sdk/routes/agent_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.models import Organization, Step
from skyvern.forge.sdk.schemas.organizations import OrganizationUpdate
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
from skyvern.forge.sdk.schemas.tasks import (
CreateTaskResponse,
Expand Down Expand Up @@ -693,3 +694,18 @@ async def generate_task(
except LLMProviderError:
LOG.error("Failed to generate task", exc_info=True)
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")


@base_router.put("/organizations/", include_in_schema=False)
@base_router.put("/organizations")
async def update_organization(
org_update: OrganizationUpdate,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Organization:
return await app.DATABASE.update_organization(
current_org.organization_id,
organization_name=org_update.organization_name,
webhook_callback_url=org_update.webhook_callback_url,
max_steps_per_run=org_update.max_steps_per_run,
max_retries_per_step=org_update.max_retries_per_step,
)
8 changes: 8 additions & 0 deletions skyvern/forge/sdk/schemas/organizations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseModel


class OrganizationUpdate(BaseModel):
organization_name: str | None = None
webhook_callback_url: str | None = None
max_steps_per_run: int | None = None
max_retries_per_step: int | None = None

0 comments on commit 10612f0

Please sign in to comment.