Skip to content

Commit

Permalink
add validate browser (#481)
Browse files Browse the repository at this point in the history
  • Loading branch information
LawyZheng committed Jun 17, 2024
1 parent 10612f0 commit df2c55b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
50 changes: 45 additions & 5 deletions skyvern/webeye/browser_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import uuid
from datetime import datetime
from typing import Any, Awaitable, Protocol
from typing import Any, Awaitable, Callable, Protocol

import structlog
from playwright._impl._errors import TimeoutError
Expand All @@ -21,6 +21,7 @@
UnknownErrorWhileCreatingBrowserContext,
)
from skyvern.forge.sdk.core.skyvern_context import current
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
from skyvern.forge.sdk.settings_manager import SettingsManager

LOG = structlog.get_logger()
Expand All @@ -34,6 +35,7 @@ def __call__(

class BrowserContextFactory:
_creators: dict[str, BrowserContextCreator] = {}
_validator: Callable[[Page], Awaitable[bool]] | None = None

@staticmethod
def get_subdir() -> str:
Expand Down Expand Up @@ -101,6 +103,16 @@ async def create_browser_context(
except Exception as e:
raise UnknownErrorWhileCreatingBrowserContext(browser_type, e) from e

@classmethod
def set_validate_browser_context(cls, validator: Callable[[Page], Awaitable[bool]]) -> None:
cls._validator = validator

@classmethod
async def validate_browser_context(cls, page: Page) -> bool:
if cls._validator is None:
return True
return await cls._validator(page)


class BrowserArtifacts(BaseModel):
video_path: str | None = None
Expand Down Expand Up @@ -155,7 +167,12 @@ async def _close_all_other_pages(self) -> None:
if page != self.page:
await page.close()

async def check_and_fix_state(self, url: str | None = None) -> None:
async def check_and_fix_state(
self,
url: str | None = None,
proxy_location: ProxyLocation | None = None,
task_id: str | None = None,
) -> None:
if self.pw is None:
LOG.info("Starting playwright")
self.pw = await async_playwright().start()
Expand All @@ -165,7 +182,12 @@ async def check_and_fix_state(self, url: str | None = None) -> None:
(
browser_context,
browser_artifacts,
) = await BrowserContextFactory.create_browser_context(self.pw, url=url)
) = await BrowserContextFactory.create_browser_context(
self.pw,
url=url,
proxy_location=proxy_location,
task_id=task_id,
)
self.browser_context = browser_context
self.browser_artifacts = browser_artifacts
LOG.info("browser context is created")
Expand Down Expand Up @@ -216,9 +238,27 @@ async def check_and_fix_state(self, url: str | None = None) -> None:
if self.browser_artifacts.video_path is None:
self.browser_artifacts.video_path = await self.page.video.path() if self.page and self.page.video else None

async def get_or_create_page(self, url: str | None = None) -> Page:
await self.check_and_fix_state(url)
async def get_or_create_page(
self,
url: str | None = None,
proxy_location: ProxyLocation | None = None,
task_id: str | None = None,
) -> Page:
if self.page is not None:
return self.page

await self.check_and_fix_state(url=url, proxy_location=proxy_location, task_id=task_id)
assert self.page is not None

if not await BrowserContextFactory.validate_browser_context(self.page):
await self._close_all_other_pages()
if self.browser_context is not None:
await self.browser_context.close()
self.browser_context = None
self.page = None
await self.check_and_fix_state(url=url, proxy_location=proxy_location, task_id=task_id)
assert self.page is not None

return self.page

async def close(self, close_browser_on_completion: bool = True) -> None:
Expand Down
4 changes: 2 additions & 2 deletions skyvern/webeye/browser_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def get_or_create_for_task(self, task: Task) -> BrowserState:

# The URL here is only used when creating a new page, and not when using an existing page.
# This will make sure browser_state.page is not None.
await browser_state.get_or_create_page(task.url)
await browser_state.get_or_create_page(url=task.url, proxy_location=task.proxy_location, task_id=task.task_id)

self.pages[task.task_id] = browser_state
if task.workflow_run_id:
Expand All @@ -78,7 +78,7 @@ async def get_or_create_for_workflow_run(self, workflow_run: WorkflowRun, url: s

# The URL here is only used when creating a new page, and not when using an existing page.
# This will make sure browser_state.page is not None.
await browser_state.get_or_create_page(url)
await browser_state.get_or_create_page(url=url, proxy_location=workflow_run.proxy_location)

self.pages[workflow_run.workflow_run_id] = browser_state
return browser_state
Expand Down

0 comments on commit df2c55b

Please sign in to comment.