Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AsyncOperation: support for running asynchronous jobs while the agent is running #111

Merged
merged 1 commit into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions skyvern/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class Settings(BaseSettings):
# browser settings
BROWSER_LOCALE: str = "en-US"
BROWSER_TIMEZONE: str = "America/New_York"
BROWSER_WIDTH: int = 1920
BROWSER_HEIGHT: int = 1080

#####################
# LLM Configuration #
Expand Down
19 changes: 19 additions & 0 deletions skyvern/forge/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import requests
import structlog
from playwright._impl._errors import TargetClosedError
from playwright.async_api import Page

from skyvern import analytics
from skyvern.exceptions import (
Expand All @@ -17,6 +18,7 @@
TaskNotFound,
)
from skyvern.forge import app
from skyvern.forge.async_operations import AgentPhase, AsyncOperationPool
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.agent import Agent
from skyvern.forge.sdk.artifact.models import ArtifactType
Expand Down Expand Up @@ -64,6 +66,7 @@ def __init__(self) -> None:
long_running_task_warning_ratio=SettingsManager.get_settings().LONG_RUNNING_TASK_WARNING_RATIO,
debug_mode=SettingsManager.get_settings().DEBUG_MODE,
)
self.async_operation_pool = AsyncOperationPool()

async def validate_step_execution(
self,
Expand Down Expand Up @@ -193,6 +196,12 @@ async def create_task(self, task_request: TaskRequest, organization_id: str | No
)
return task

def register_async_operations(self, organization: Organization, task: Task, page: Page) -> None:
    """
    Build the async operations for a task and hand them to the operation pool.

    Does nothing when no ``app.generate_async_operations`` hook is configured.

    :param organization: organization passed through to the generator hook
    :param task: task whose ``task_id`` keys the operations in the pool
    :param page: playwright page passed through to the generator hook
    """
    if app.generate_async_operations:
        # the hook decides which operations (if any) apply to this task/page
        generated = app.generate_async_operations(organization, task, page)
        self.async_operation_pool.add_operations(task.task_id, generated)

async def execute_step(
self,
organization: Organization,
Expand All @@ -208,6 +217,10 @@ async def execute_step(
# Check some conditions before executing the step, throw an exception if the step can't be executed
await self.validate_step_execution(task, step)
step, browser_state, detailed_output = await self._initialize_execution_state(task, step, workflow_run)

if browser_state.page:
self.register_async_operations(organization, task, browser_state.page)

step, detailed_output = await self.agent_step(task, step, browser_state, organization=organization)
task = await self.update_task_errors_from_detailed_output(task, detailed_output)
retry = False
Expand All @@ -226,6 +239,7 @@ async def execute_step(
api_key=api_key,
close_browser_on_completion=close_browser_on_completion,
)
await self.async_operation_pool.remove_task(task.task_id)
return step, detailed_output, None
elif step.status == StepStatus.completed:
# TODO (kerem): keep the task object uptodate at all times so that send_task_response can just use it
Expand Down Expand Up @@ -332,6 +346,7 @@ async def agent_step(
json_response = None
actions: list[Action]
if task.navigation_goal:
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
json_response = await app.LLM_API_HANDLER(
prompt=extract_action_prompt,
step=step,
Expand Down Expand Up @@ -403,6 +418,7 @@ async def agent_step(
break
web_action_element_ids.add(action.element_id)

self.async_operation_pool.run_operation(task.task_id, AgentPhase.action)
results = await ActionHandler.handle_action(scraped_page, task, step, browser_state, action)
detailed_agent_step_output.actions_and_results[action_idx] = (action, results)
# wait random time between actions to avoid detection
Expand Down Expand Up @@ -559,6 +575,9 @@ async def _build_and_record_step_prompt(
step: Step,
browser_state: BrowserState,
) -> tuple[ScrapedPage, str]:
# start the async tasks while running scrape_website
self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)

# Scrape the web page and get the screenshot and the elements
scraped_page = await scrape_website(
browser_state,
Expand Down
8 changes: 8 additions & 0 deletions skyvern/forge/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from typing import Callable

from ddtrace import tracer
from ddtrace.filters import FilterRequestsOnUrl
from playwright.async_api import Page

from skyvern.forge.agent import ForgeAgent
from skyvern.forge.async_operations import AsyncOperation
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.artifact.manager import ArtifactManager
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
from skyvern.forge.sdk.db.client import AgentDB
from skyvern.forge.sdk.forge_log import setup_logger
from skyvern.forge.sdk.models import Organization
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
from skyvern.forge.sdk.workflow.service import WorkflowService
Expand All @@ -31,6 +37,8 @@
LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
WORKFLOW_SERVICE = WorkflowService()
generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None

agent = ForgeAgent()

app = agent.get_agent_app()
132 changes: 132 additions & 0 deletions skyvern/forge/async_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import asyncio
from enum import StrEnum

import structlog
from playwright.async_api import Page

LOG = structlog.get_logger()


class AgentPhase(StrEnum):
"""
Phase of agent when async execution events are happening
"""

action = "action"
scrape = "scrape"
llm = "llm"


VALID_AGENT_PHASES = [phase.value for phase in AgentPhase]


class AsyncOperation:
    """
    Base class for work that runs in the background, as an asyncio task, while the agent
    is busy with the task itself.

    Subclasses override ``execute``; the base implementation does nothing.

    Examples:
    - collect info based on the html/DOM and send data to your server
    """

    def __init__(self, task_id: str, operation_type: str, agent_phase: AgentPhase, page: Page) -> None:
        """
        :param task_id: task_id of the task
        :param operation_type: custom type of the operation; stored as ``self.type``.
            there will only be up to one aio task running per operation_type
        :param agent_phase: AgentPhase during which the operation is meant to be started
        :param page: playwright page for the task
        """
        # playwright page could be used by the operation to take actions
        self.page = page

        self.task_id = task_id
        self.type = operation_type
        self.agent_phase = agent_phase

        # the most recently scheduled asyncio task, None until run() is called
        self.aio_task: asyncio.Task | None = None

    async def execute(self) -> None:
        """Do the actual work of the operation. No-op here; meant to be overridden."""
        return None

    def run(self) -> asyncio.Task | None:
        """
        Schedule ``execute`` as an asyncio task.

        :return: the new asyncio task, or None when the previous one has not finished yet
        """
        previous = self.aio_task
        if previous is None or previous.done():
            self.aio_task = asyncio.create_task(self.execute())
            return self.aio_task

        # the previous run is still in flight: never start a second one
        LOG.warning(
            "Task already running",
            task_id=self.task_id,
            operation_type=self.type,
            agent_phase=self.agent_phase,
        )
        return None


class AsyncOperationPool:
_operations: dict[str, dict[AgentPhase, AsyncOperation]] = {} # task_id: {agent_phase: operation}

# use _aio_tasks to ensure we're only execution one aio task for the same operation_type
_aio_tasks: dict[str, dict[str, asyncio.Task]] = {} # task_id: {operation_type: aio_task}

def _add_operation(self, task_id: str, operation: AsyncOperation) -> None:
if operation.agent_phase not in VALID_AGENT_PHASES:
raise ValueError(f"operation's agent phase {operation.agent_phase} is not valid")
if task_id not in self._operations:
self._operations[task_id] = {}
self._operations[task_id][operation.agent_phase] = operation

def add_operations(self, task_id: str, operations: list[AsyncOperation]) -> None:
if task_id in self._operations:
# already exists
return
for operation in operations:
self._add_operation(task_id, operation)

def _get_operation(self, task_id: str, operation_type: AgentPhase) -> AsyncOperation | None:
return self._operations.get(task_id, {}).get(operation_type, None)

def remove_operations(self, task_id: str) -> None:
if task_id in self._operations:
del self._operations[task_id]

def get_aio_tasks(self, task_id: str) -> list[asyncio.Task]:
"""
Get all the running/pending aio tasks for the given task_id
"""
return [aio_task for aio_task in self._aio_tasks.get(task_id, {}).values() if not aio_task.done()]

def run_operation(self, task_id: str, agent_phase: AgentPhase) -> None:
# get the operation from the pool
operation = self._get_operation(task_id, agent_phase)
if operation is None:
return

# if found, initialize the operation if it's the first time running the aio task
operation_type = operation.type
if task_id not in self._aio_tasks:
self._aio_tasks[task_id] = {}

# if the aio task is already running, don't run it again
aio_task: asyncio.Task | None = None
if operation_type in self._aio_tasks[task_id]:
aio_task = self._aio_tasks[task_id][operation_type]
if not aio_task.done():
LOG.info(
f"aio task already running",
task_id=task_id,
operation_type=operation_type,
agent_phase=agent_phase,
)
return

# run the operation if the aio task is not running
aio_task = operation.run()
if aio_task:
self._aio_tasks[task_id][operation_type] = aio_task

async def remove_task(self, task_id: str) -> None:
try:
async with asyncio.timeout(30):
await asyncio.gather(*[aio_task for aio_task in self.get_aio_tasks(task_id) if not aio_task.done()])
except asyncio.TimeoutError:
LOG.error(f"Timeout (30s) while waiting for pending async tasks for task_id={task_id}", task_id=task_id)

self.remove_operations(task_id)
3 changes: 2 additions & 1 deletion skyvern/webeye/browser_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from playwright.async_api import BrowserContext, Error, Page, Playwright, async_playwright
from pydantic import BaseModel

from skyvern.config import settings
from skyvern.exceptions import (
FailedToNavigateToUrl,
FailedToTakeScreenshot,
Expand Down Expand Up @@ -61,7 +62,7 @@ def build_browser_args() -> dict[str, Any]:
],
"record_har_path": har_dir,
"record_video_dir": video_dir,
"viewport": {"width": 1920, "height": 1080},
"viewport": {"width": settings.BROWSER_WIDTH, "height": settings.BROWSER_HEIGHT},
}

@staticmethod
Expand Down
Loading