ENH: Add async tasks #60

Merged · 17 commits · Jul 23, 2022
6 changes: 6 additions & 0 deletions rocketry/application.py
@@ -60,6 +60,12 @@ def run(self, debug=False):
self.session.set_as_default()
self.session.start()

async def serve(self, debug=False):
"Run the scheduler"
self.session.config.debug = debug
self.session.set_as_default()
await self.session.serve()

def cond(self, syntax: Union[str, Pattern, List[Union[str, Pattern]]]=None):
"Create a condition (decorator)"
return FuncCond(syntax=syntax, session=self.session, decor_return_func=False)
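The new `serve()` coroutine is the async counterpart of `run()`: it lets Rocketry be awaited inside an already-running event loop instead of owning the loop via `asyncio.run`. A minimal usage sketch (the `side_job` coroutine and the task body are placeholders, not part of this PR):

    import asyncio
    from rocketry import Rocketry

    app = Rocketry()

    @app.task("every 10 seconds")
    def do_things():
        ...

    async def side_job():
        # Hypothetical: any other coroutine sharing the same event loop
        while True:
            await asyncio.sleep(60)

    async def main():
        # app.serve() blocks (asynchronously) until the scheduler shuts down,
        # so run it concurrently with the rest of the application
        await asyncio.gather(app.serve(), side_job())

    if __name__ == "__main__":
        asyncio.run(main())
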
85 changes: 51 additions & 34 deletions rocketry/core/schedule.py
@@ -1,4 +1,5 @@

import asyncio
from multiprocessing import cpu_count
import multiprocessing
from typing import TYPE_CHECKING, Callable, Optional, Union
@@ -101,6 +102,12 @@ def tasks(self):
return sorted(tasks, key=lambda task: getattr(task, "priority", 0), reverse=True)

def __call__(self):
return self.run()

def run(self):
return asyncio.run(self.serve())

async def serve(self):
"""Start and run the scheduler. Will block till the end of the scheduling
session."""
# Unsetting some flags
@@ -111,16 +118,16 @@ def __call__(self):
self.is_alive = True
exception = None
try:
self.startup()
await self.startup()

while not self.check_shut_cond(self.session.config.shut_cond):
await self._hibernate()
if self._flag_shutdown.is_set():
break
elif self._flag_restart.is_set():
raise SchedulerRestart()

self._hibernate()
self.run_cycle()
await self.run_cycle()

# self.maintain()
except SystemExit as exc:
@@ -146,9 +153,9 @@ def __call__(self):
else:
self.logger.info('Purpose completed. Shutting down...', extra={"action": "shutdown"})
finally:
self.shut_down(exception=exception)
await self.shut_down(exception=exception)

def run_cycle(self):
async def run_cycle(self):
"""Run one round of tasks.

Each task is inspected and in case their starting condition
@@ -173,15 +180,15 @@ def run_cycle(self):
pass
elif self._flag_enabled.is_set() and self.is_task_runnable(task):
# Run the actual task
self.run_task(task)
await self.run_task(task)
# Reset force_run as a run has forced
task.force_run = False
elif self.is_timeouted(task):
# Terminate the task
self.terminate_task(task, reason="timeout")
await self.terminate_task(task, reason="timeout")
elif self.is_out_of_condition(task):
# Terminate the task
self.terminate_task(task)
await self.terminate_task(task)

# Running hooks
hooker.postrun()
@@ -202,12 +209,12 @@ def check_task_cond(self, task:Task):
raise
return False

def run_task(self, task:Task, *args, **kwargs):
async def run_task(self, task:Task, *args, **kwargs):
"""Run a given task"""
start_time = datetime.datetime.fromtimestamp(time.time())

try:
task(log_queue=self._log_queue)
await task.start_async(log_queue=self._log_queue)
except (SchedulerRestart, SchedulerExit) as exc:
raise
except Exception as exc:
@@ -217,13 +224,13 @@ def run_task(self, task:Task, *args, **kwargs):
exception = None
status = "success"

def terminate_all(self, reason:str=None):
async def terminate_all(self, reason:str=None):
"""Terminate all running tasks."""
for task in self.tasks:
if task.is_alive():
self.terminate_task(task, reason=reason)
await self.terminate_task(task, reason=reason)

def terminate_task(self, task, reason=None):
async def terminate_task(self, task, reason=None):
"""Terminate a given task."""
self.logger.debug(f"Terminating task '{task.name}'")
is_threaded = hasattr(task, "_thread")
@@ -242,6 +249,12 @@ def terminate_task(self, task, reason=None):

# Resetting attr force_termination
task.force_termination = False
elif task.is_alive_as_async():
task._async_task.cancel()
try:
await task._async_task
except asyncio.CancelledError:
task.log_termination()
else:
# The process/thread probably just died after the check
pass
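
The new `is_alive_as_async` branch above terminates an async task by cancelling its `asyncio.Task` and awaiting it; the resulting `asyncio.CancelledError` signals that cancellation completed, which is then logged as a termination. A standalone sketch of that pattern (names are illustrative only):

    import asyncio

    async def long_running():
        # Stands in for an async task body
        await asyncio.sleep(3600)

    async def main():
        task = asyncio.create_task(long_running())
        await asyncio.sleep(0)   # give the task a chance to start
        task.cancel()            # request termination
        try:
            await task           # wait until the cancellation takes effect
        except asyncio.CancelledError:
            print("terminated")  # corresponds to task.log_termination() above

    asyncio.run(main())
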
@@ -283,6 +296,9 @@ def is_task_runnable(self, task:Task):
elif execution == "thread":
is_not_running = not task.is_alive()
return is_not_running and is_condition
elif execution == "async":
is_not_running = not task.is_alive()
return is_not_running and is_condition
else:
raise NotImplementedError(task.execution)

@@ -336,13 +352,13 @@ def handle_logs(self):

task.log_record(record)

def _hibernate(self):
async def _hibernate(self):
"""Go to sleep and wake up when next task can be executed."""
delay = self.session.config.cycle_sleep
if delay is not None:
time.sleep(delay)
await asyncio.sleep(delay)

def startup(self):
async def startup(self):
"""Start up the scheduler.

Starting up includes setting up attributes and
@@ -363,7 +379,7 @@ def startup(self):
task.force_run = True

if self.is_task_runnable(task):
self.run_task(task)
await self.run_task(task)

hooker.postrun()
self.logger.info(f"Setup complete.")
@@ -378,40 +394,43 @@ def n_alive(self) -> int:
"""Count of tasks that are alive."""
return sum(task.is_alive() for task in self.tasks)

def _shut_down_tasks(self, traceback=None, exception=None):
async def _shut_down_tasks(self, traceback=None, exception=None):
non_fatal_excs = (SchedulerRestart,) # Exceptions that are allowed to have graceful exit
wait_for_finish = not self.session.config.instant_shutdown and (exception is None or isinstance(exception, non_fatal_excs))
if wait_for_finish:
try:
# Gracefully shut down (allow remaining tasks to finish)
while self.n_alive:
#time.sleep(self.min_sleep)

await self._hibernate() # This is the time async tasks can continue

self.handle_logs()
for task in self.tasks:
if task.permanent_task:
# Would never "finish" anyways
self.terminate_task(task)
await self.terminate_task(task)
elif self.is_timeouted(task):
# Terminate the task
self.terminate_task(task, reason="timeout")
await self.terminate_task(task, reason="timeout")
elif self.is_out_of_condition(task):
# Terminate the task
self.terminate_task(task)
await self.terminate_task(task)
except Exception as exc:
# Fuck it, terminate all
self._shut_down_tasks(exception=exc)
await self._shut_down_tasks(exception=exc)
return
else:
self.handle_logs()
else:
self.terminate_all(reason="shutdown")
await self.terminate_all(reason="shutdown")

def wait_task_alive(self):
async def wait_task_alive(self):
"""Wait till all, especially threading tasks, are finished."""
while self.n_alive > 0:
time.sleep(0.005)
await self._hibernate()

def shut_down(self, traceback=None, exception=None):
async def shut_down(self, traceback=None, exception=None):
"""Shut down the scheduler.

Shutting down includes running tasks that have
@@ -440,13 +459,12 @@ def shut_down(self, traceback=None, exception=None):
task.force_run = True

if self.is_task_runnable(task):
self.run_task(task)
await self.run_task(task)

self.logger.info(f"Shutting down tasks...")
self._shut_down_tasks(traceback, exception)
await self._shut_down_tasks(traceback, exception)

if not self.session.config.instant_shutdown:
self.wait_task_alive() # Wait till all tasks' threads and processes are dead
await self.wait_task_alive() # Wait till all tasks' threads and processes are dead

# Running hooks
hooker.postrun()
@@ -456,14 +474,13 @@ def shut_down(self, traceback=None, exception=None):
if isinstance(exception, SchedulerRestart):
# Clean up finished, restart is finally
# possible
self._restart()
await self._restart()

def _restart(self):
async def _restart(self):
"""Restart the scheduler by creating a new process
on the temporary run script where the scheduler's is
process is started.
"""
# TODO
# https://stackoverflow.com/a/35874988
self.logger.debug(f"Restarting...", extra={"action": "restart"})
python = sys.executable
Expand All @@ -485,7 +502,7 @@ def _restart(self):
elif restarting == "recall":
# Mostly useful for testing.
# Restart by calling the self.__call__ again
return self()
await asyncio.create_task(self.serve())
else:
raise ValueError(f"Invalid restaring: {restarting}")

43 changes: 34 additions & 9 deletions rocketry/core/task.py
@@ -1,4 +1,6 @@

import asyncio
import inspect
from pickle import PicklingError
import sys
import time
@@ -163,7 +165,7 @@ class Config:
name: Optional[str] = Field(description="Name of the task. Must be unique")
description: Optional[str] = Field(description="Description of the task for documentation")
logger_name: Optional[str] = Field(description="Logger name to be used in logging the task records")
execution: Optional[Literal['main', 'thread', 'process']]
execution: Optional[Literal['main', 'async', 'thread', 'process']]
priority: int = 0
disabled: bool = False
force_run: bool = False
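
With 'async' added to the allowed `execution` values, a task body can be a native coroutine that runs on the scheduler's own event loop. A minimal sketch, assuming the `app.task(...)` decorator forwards `execution` to this field:

    import asyncio
    from rocketry import Rocketry

    app = Rocketry()

    # execution="async" keeps the task on the scheduler's event loop,
    # so the body should await rather than block
    @app.task("every 5 seconds", execution="async")
    async def fetch_data():
        await asyncio.sleep(1)

    app.run()
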
@@ -189,6 +191,7 @@ class Config:
_thread: threading.Thread = None
_thread_terminate: threading.Event = PrivateAttr(default_factory=threading.Event)
_lock: Optional[threading.Lock] = PrivateAttr(default_factory=threading.Lock)
_async_task: Optional[asyncio.Task] = PrivateAttr(default=None)

_mark_running = False

@@ -231,6 +234,8 @@ def parse_logger_name(cls, value, values):
def parse_timeout(cls, value, values):
if value == "never":
return datetime.timedelta.max
elif isinstance(value, (float, int)):
return to_timedelta(value, unit="s")
elif value is not None:
return to_timedelta(value)
else:
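
The added branch lets `timeout` be given as a plain number, interpreted as seconds, alongside strings and timedeltas. A rough illustration using only the standard library (`to_timedelta` is Rocketry's own helper and is not reproduced here):

    import datetime

    def parse_timeout_sketch(value):
        # Illustrative only: mirrors the order of the validator above
        if value == "never":
            return datetime.timedelta.max
        if isinstance(value, (float, int)):
            return datetime.timedelta(seconds=value)  # new: bare numbers mean seconds
        return value  # strings/timedeltas go through to_timedelta in the real validator

    print(parse_timeout_sketch(90))  # 0:01:30
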
@@ -315,7 +320,14 @@ def parse_parameters(cls, value):
def __hash__(self):
return id(self)

def __call__(self, params:Union[dict, Parameters]=None, **kwargs):
def __call__(self, *args, **kwargs):
"Run sync"
self.start(*args, **kwargs)

def start(self, *args, **kwargs):
return asyncio.run(self.start_async(*args, **kwargs))

async def start_async(self, params:Union[dict, Parameters]=None, **kwargs):
"""Execute the task. Creates a new process
(if execution='process'), a new thread
(if execution='thread') or blocks and
@@ -350,9 +362,13 @@ def __call__(self, params:Union[dict, Parameters]=None, **kwargs):
try:
params = self.get_extra_params(params)
# Run the actual task
if execution == "main":
if execution in ("main", "async"):
direct_params = self.parameters
self._run_as_main(params=params, direct_params=direct_params, execution="main", **kwargs)
async_task = asyncio.create_task(self._run_as_async(params=params, direct_params=direct_params, execution=execution, **kwargs))
if execution == "main":
await async_task
else:
self._async_task = async_task
if _IS_WINDOWS:
#! TODO: This probably is now solved
# There is an annoying bug (?) in Windows:
@@ -408,7 +424,10 @@ def is_runnable(self):
def run_as_main(self, params:Parameters):
return self._run_as_main(params, self.parameters)

def _run_as_main(self, params:Parameters, direct_params:Parameters, execution=None, **kwargs):
def _run_as_main(self, **kwargs):
return asyncio.run(self._run_as_async(**kwargs))

async def _run_as_async(self, params:Parameters, direct_params:Parameters, execution=None, **kwargs):
"""Run the task on the current thread and process"""
#self.logger.info(f'Running {self.name}', extra={"action": "run"})

@@ -431,10 +450,13 @@ def _run_as_main(self, params:Parameters, direct_params:Parameters, execution=None, **kwargs):
params = Parameters(params) | Parameters(direct_params)
params = params.materialize(task=self)

if execution == 'main':
if execution in ('main', 'async'):
self.log_running()
try:
output = self.execute(**params)
if inspect.iscoroutinefunction(self.execute):
output = await self.execute(**params)
else:
output = self.execute(**params)

# NOTE: we process success here in case the process_success
# fails (therefore task fails)
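
The `inspect.iscoroutinefunction` check above is what lets one code path execute both plain and async task functions: coroutine functions are awaited, regular callables are called directly. A standalone sketch of that dispatch (function names are illustrative):

    import asyncio
    import inspect

    async def call_maybe_async(func, **params):
        # Await native coroutine functions, call plain functions directly
        if inspect.iscoroutinefunction(func):
            return await func(**params)
        return func(**params)

    async def async_job(x):
        await asyncio.sleep(0)
        return x * 2

    def sync_job(x):
        return x + 1

    async def main():
        print(await call_maybe_async(async_job, x=3))  # 6
        print(await call_maybe_async(sync_job, x=3))   # 4

    asyncio.run(main())
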
@@ -456,7 +478,7 @@ def _run_as_main(self, params:Parameters, direct_params:Parameters, execution=None, **kwargs):
status = "inaction"
exc_info = sys.exc_info()

except TaskTerminationException:
except (TaskTerminationException, asyncio.CancelledError):
# Task was terminated and the task's function
# did listen to that.
self.log_termination()
@@ -753,7 +775,10 @@ def get_default_name(self, **kwargs):

def is_alive(self) -> bool:
"""Whether the task is alive: check if the task has a live process or thread."""
return self.is_alive_as_thread() or self.is_alive_as_process()
return self.is_alive_as_async() or self.is_alive_as_thread() or self.is_alive_as_process()

def is_alive_as_async(self) -> bool:
return self._async_task is not None and not self._async_task.done()

def is_alive_as_thread(self) -> bool:
"""Whether the task has a live thread."""