diff --git a/.github/workflows/semantic-pr.yml b/.github/workflows/semantic-pr.yml index 49e2017c..dc9fae55 100644 --- a/.github/workflows/semantic-pr.yml +++ b/.github/workflows/semantic-pr.yml @@ -13,7 +13,7 @@ on: jobs: semantic-pr: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest timeout-minutes: 1 steps: - name: Semantic pull-request diff --git a/ai21/clients/common/maestro/run.py b/ai21/clients/common/maestro/run.py index b729c835..f7e7e20a 100644 --- a/ai21/clients/common/maestro/run.py +++ b/ai21/clients/common/maestro/run.py @@ -66,7 +66,7 @@ def retrieve(self, run_id: str) -> RunResponse: pass @abstractmethod - def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse: + def poll_for_status(self, *, run_id: str, poll_interval_sec: float, poll_timeout_sec: float) -> RunResponse: pass @abstractmethod diff --git a/ai21/clients/studio/resources/maestro/run.py b/ai21/clients/studio/resources/maestro/run.py index 5a49176d..c8b7f690 100644 --- a/ai21/clients/studio/resources/maestro/run.py +++ b/ai21/clients/studio/resources/maestro/run.py @@ -53,7 +53,7 @@ def retrieve( ) -> RunResponse: return self._get(path=f"/{self._module_name}/{run_id}", response_cls=RunResponse) - def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse: + def poll_for_status(self, *, run_id: str, poll_interval_sec: float, poll_timeout_sec: float) -> RunResponse: start_time = time.time() while True: @@ -62,10 +62,10 @@ def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: f if run.status in TERMINATED_RUN_STATUSES: return run - if (time.time() - start_time) >= poll_timeout: + if (time.time() - start_time) >= poll_timeout_sec: return run - time.sleep(poll_interval) + time.sleep(poll_interval_sec) def create_and_poll( self, @@ -92,7 +92,9 @@ def create_and_poll( **kwargs, ) - return self._poll_for_status(run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec) + return self.poll_for_status( + run_id=run.id, poll_interval_sec=poll_interval_sec, poll_timeout_sec=poll_timeout_sec + ) class AsyncMaestroRun(AsyncStudioResource, BaseMaestroRun): @@ -127,7 +129,7 @@ async def retrieve( ) -> RunResponse: return await self._get(path=f"/{self._module_name}/{run_id}", response_cls=RunResponse) - async def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse: + async def poll_for_status(self, *, run_id: str, poll_interval_sec: float, poll_timeout_sec: float) -> RunResponse: start_time = time.time() while True: @@ -136,10 +138,10 @@ async def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_time if run.status in TERMINATED_RUN_STATUSES: return run - if (time.time() - start_time) >= poll_timeout: + if (time.time() - start_time) >= poll_timeout_sec: return run - await asyncio.sleep(poll_interval) + await asyncio.sleep(poll_interval_sec) async def create_and_poll( self, @@ -166,6 +168,6 @@ async def create_and_poll( **kwargs, ) - return await self._poll_for_status( - run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec + return await self.poll_for_status( + run_id=run.id, poll_interval_sec=poll_interval_sec, poll_timeout_sec=poll_timeout_sec )