Skip to content

Commit

Permalink
Merge pull request #60 from Miksus/dev/async
Browse files Browse the repository at this point in the history
ENH: Add async tasks
  • Loading branch information
Miksus committed Jul 23, 2022
2 parents 3f54835 + 617de60 commit 91790ee
Show file tree
Hide file tree
Showing 12 changed files with 349 additions and 100 deletions.
6 changes: 6 additions & 0 deletions rocketry/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def run(self, debug=False):
self.session.set_as_default()
self.session.start()

async def serve(self, debug=False):
"Run the scheduler"
self.session.config.debug = debug
self.session.set_as_default()
await self.session.serve()

def cond(self, syntax: Union[str, Pattern, List[Union[str, Pattern]]]=None):
"Create a condition (decorator)"
return FuncCond(syntax=syntax, session=self.session, decor_return_func=False)
Expand Down
85 changes: 51 additions & 34 deletions rocketry/core/schedule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import asyncio
from multiprocessing import cpu_count
import multiprocessing
from typing import TYPE_CHECKING, Callable, Optional, Union
Expand Down Expand Up @@ -101,6 +102,12 @@ def tasks(self):
return sorted(tasks, key=lambda task: getattr(task, "priority", 0), reverse=True)

def __call__(self):
return self.run()

def run(self):
return asyncio.run(self.serve())

async def serve(self):
"""Start and run the scheduler. Will block till the end of the scheduling
session."""
# Unsetting some flags
Expand All @@ -111,16 +118,16 @@ def __call__(self):
self.is_alive = True
exception = None
try:
self.startup()
await self.startup()

while not self.check_shut_cond(self.session.config.shut_cond):
await self._hibernate()
if self._flag_shutdown.is_set():
break
elif self._flag_restart.is_set():
raise SchedulerRestart()

self._hibernate()
self.run_cycle()
await self.run_cycle()

# self.maintain()
except SystemExit as exc:
Expand All @@ -146,9 +153,9 @@ def __call__(self):
else:
self.logger.info('Purpose completed. Shutting down...', extra={"action": "shutdown"})
finally:
self.shut_down(exception=exception)
await self.shut_down(exception=exception)

def run_cycle(self):
async def run_cycle(self):
"""Run one round of tasks.
Each task is inspected and in case their starting condition
Expand All @@ -173,15 +180,15 @@ def run_cycle(self):
pass
elif self._flag_enabled.is_set() and self.is_task_runnable(task):
# Run the actual task
self.run_task(task)
await self.run_task(task)
# Reset force_run as a run has forced
task.force_run = False
elif self.is_timeouted(task):
# Terminate the task
self.terminate_task(task, reason="timeout")
await self.terminate_task(task, reason="timeout")
elif self.is_out_of_condition(task):
# Terminate the task
self.terminate_task(task)
await self.terminate_task(task)

# Running hooks
hooker.postrun()
Expand All @@ -202,12 +209,12 @@ def check_task_cond(self, task:Task):
raise
return False

def run_task(self, task:Task, *args, **kwargs):
async def run_task(self, task:Task, *args, **kwargs):
"""Run a given task"""
start_time = datetime.datetime.fromtimestamp(time.time())

try:
task(log_queue=self._log_queue)
await task.start_async(log_queue=self._log_queue)
except (SchedulerRestart, SchedulerExit) as exc:
raise
except Exception as exc:
Expand All @@ -217,13 +224,13 @@ def run_task(self, task:Task, *args, **kwargs):
exception = None
status = "success"

def terminate_all(self, reason:str=None):
async def terminate_all(self, reason:str=None):
"""Terminate all running tasks."""
for task in self.tasks:
if task.is_alive():
self.terminate_task(task, reason=reason)
await self.terminate_task(task, reason=reason)

def terminate_task(self, task, reason=None):
async def terminate_task(self, task, reason=None):
"""Terminate a given task."""
self.logger.debug(f"Terminating task '{task.name}'")
is_threaded = hasattr(task, "_thread")
Expand All @@ -242,6 +249,12 @@ def terminate_task(self, task, reason=None):

# Resetting attr force_termination
task.force_termination = False
elif task.is_alive_as_async():
task._async_task.cancel()
try:
await task._async_task
except asyncio.CancelledError:
task.log_termination()
else:
# The process/thread probably just died after the check
pass
Expand Down Expand Up @@ -283,6 +296,9 @@ def is_task_runnable(self, task:Task):
elif execution == "thread":
is_not_running = not task.is_alive()
return is_not_running and is_condition
elif execution == "async":
is_not_running = not task.is_alive()
return is_not_running and is_condition
else:
raise NotImplementedError(task.execution)

Expand Down Expand Up @@ -336,13 +352,13 @@ def handle_logs(self):

task.log_record(record)

def _hibernate(self):
async def _hibernate(self):
"""Go to sleep and wake up when next task can be executed."""
delay = self.session.config.cycle_sleep
if delay is not None:
time.sleep(delay)
await asyncio.sleep(delay)

def startup(self):
async def startup(self):
"""Start up the scheduler.
Starting up includes setting up attributes and
Expand All @@ -363,7 +379,7 @@ def startup(self):
task.force_run = True

if self.is_task_runnable(task):
self.run_task(task)
await self.run_task(task)

hooker.postrun()
self.logger.info(f"Setup complete.")
Expand All @@ -378,40 +394,43 @@ def n_alive(self) -> int:
"""Count of tasks that are alive."""
return sum(task.is_alive() for task in self.tasks)

def _shut_down_tasks(self, traceback=None, exception=None):
async def _shut_down_tasks(self, traceback=None, exception=None):
non_fatal_excs = (SchedulerRestart,) # Exceptions that are allowed to have graceful exit
wait_for_finish = not self.session.config.instant_shutdown and (exception is None or isinstance(exception, non_fatal_excs))
if wait_for_finish:
try:
# Gracefully shut down (allow remaining tasks to finish)
while self.n_alive:
#time.sleep(self.min_sleep)

await self._hibernate() # This is the time async tasks can continue

self.handle_logs()
for task in self.tasks:
if task.permanent_task:
# Would never "finish" anyways
self.terminate_task(task)
await self.terminate_task(task)
elif self.is_timeouted(task):
# Terminate the task
self.terminate_task(task, reason="timeout")
await self.terminate_task(task, reason="timeout")
elif self.is_out_of_condition(task):
# Terminate the task
self.terminate_task(task)
await self.terminate_task(task)
except Exception as exc:
# Fuck it, terminate all
self._shut_down_tasks(exception=exc)
await self._shut_down_tasks(exception=exc)
return
else:
self.handle_logs()
else:
self.terminate_all(reason="shutdown")
await self.terminate_all(reason="shutdown")

def wait_task_alive(self):
async def wait_task_alive(self):
"""Wait till all, especially threading tasks, are finished."""
while self.n_alive > 0:
time.sleep(0.005)
await self._hibernate()

def shut_down(self, traceback=None, exception=None):
async def shut_down(self, traceback=None, exception=None):
"""Shut down the scheduler.
Shutting down includes running tasks that have
Expand Down Expand Up @@ -440,13 +459,12 @@ def shut_down(self, traceback=None, exception=None):
task.force_run = True

if self.is_task_runnable(task):
self.run_task(task)
await self.run_task(task)

self.logger.info(f"Shutting down tasks...")
self._shut_down_tasks(traceback, exception)
await self._shut_down_tasks(traceback, exception)

if not self.session.config.instant_shutdown:
self.wait_task_alive() # Wait till all tasks' threads and processes are dead
await self.wait_task_alive() # Wait till all tasks' threads and processes are dead

# Running hooks
hooker.postrun()
Expand All @@ -456,14 +474,13 @@ def shut_down(self, traceback=None, exception=None):
if isinstance(exception, SchedulerRestart):
# Clean up finished, restart is finally
# possible
self._restart()
await self._restart()

def _restart(self):
async def _restart(self):
"""Restart the scheduler by creating a new process
on the temporary run script where the scheduler's is
process is started.
"""
# TODO
# https://stackoverflow.com/a/35874988
self.logger.debug(f"Restarting...", extra={"action": "restart"})
python = sys.executable
Expand All @@ -485,7 +502,7 @@ def _restart(self):
elif restarting == "recall":
# Mostly useful for testing.
# Restart by calling the self.__call__ again
return self()
await asyncio.create_task(self.serve())
else:
raise ValueError(f"Invalid restaring: {restarting}")

Expand Down
43 changes: 34 additions & 9 deletions rocketry/core/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@

import asyncio
import inspect
from pickle import PicklingError
import sys
import time
Expand Down Expand Up @@ -163,7 +165,7 @@ class Config:
name: Optional[str] = Field(description="Name of the task. Must be unique")
description: Optional[str] = Field(description="Description of the task for documentation")
logger_name: Optional[str] = Field(description="Logger name to be used in logging the task records")
execution: Optional[Literal['main', 'thread', 'process']]
execution: Optional[Literal['main', 'async', 'thread', 'process']]
priority: int = 0
disabled: bool = False
force_run: bool = False
Expand All @@ -189,6 +191,7 @@ class Config:
_thread: threading.Thread = None
_thread_terminate: threading.Event = PrivateAttr(default_factory=threading.Event)
_lock: Optional[threading.Lock] = PrivateAttr(default_factory=threading.Lock)
_async_task: Optional[asyncio.Task] = PrivateAttr(default=None)

_mark_running = False

Expand Down Expand Up @@ -231,6 +234,8 @@ def parse_logger_name(cls, value, values):
def parse_timeout(cls, value, values):
if value == "never":
return datetime.timedelta.max
elif isinstance(value, (float, int)):
return to_timedelta(value, unit="s")
elif value is not None:
return to_timedelta(value)
else:
Expand Down Expand Up @@ -315,7 +320,14 @@ def parse_parameters(cls, value):
def __hash__(self):
return id(self)

def __call__(self, params:Union[dict, Parameters]=None, **kwargs):
def __call__(self, *args, **kwargs):
"Run sync"
self.start(*args, **kwargs)

def start(self, *args, **kwargs):
return asyncio.run(self.start_async(*args, **kwargs))

async def start_async(self, params:Union[dict, Parameters]=None, **kwargs):
"""Execute the task. Creates a new process
(if execution='process'), a new thread
(if execution='thread') or blocks and
Expand Down Expand Up @@ -350,9 +362,13 @@ def __call__(self, params:Union[dict, Parameters]=None, **kwargs):
try:
params = self.get_extra_params(params)
# Run the actual task
if execution == "main":
if execution in ("main", "async"):
direct_params = self.parameters
self._run_as_main(params=params, direct_params=direct_params, execution="main", **kwargs)
async_task = asyncio.create_task(self._run_as_async(params=params, direct_params=direct_params, execution=execution, **kwargs))
if execution == "main":
await async_task
else:
self._async_task = async_task
if _IS_WINDOWS:
#! TODO: This probably is now solved
# There is an annoying bug (?) in Windows:
Expand Down Expand Up @@ -408,7 +424,10 @@ def is_runnable(self):
def run_as_main(self, params:Parameters):
return self._run_as_main(params, self.parameters)

def _run_as_main(self, params:Parameters, direct_params:Parameters, execution=None, **kwargs):
def _run_as_main(self, **kwargs):
return asyncio.run(self._run_as_async(**kwargs))

async def _run_as_async(self, params:Parameters, direct_params:Parameters, execution=None, **kwargs):
"""Run the task on the current thread and process"""
#self.logger.info(f'Running {self.name}', extra={"action": "run"})

Expand All @@ -431,10 +450,13 @@ def _run_as_main(self, params:Parameters, direct_params:Parameters, execution=No
params = Parameters(params) | Parameters(direct_params)
params = params.materialize(task=self)

if execution == 'main':
if execution in ('main', 'async'):
self.log_running()
try:
output = self.execute(**params)
if inspect.iscoroutinefunction(self.execute):
output = await self.execute(**params)
else:
output = self.execute(**params)

# NOTE: we process success here in case the process_success
# fails (therefore task fails)
Expand All @@ -456,7 +478,7 @@ def _run_as_main(self, params:Parameters, direct_params:Parameters, execution=No
status = "inaction"
exc_info = sys.exc_info()

except TaskTerminationException:
except (TaskTerminationException, asyncio.CancelledError):
# Task was terminated and the task's function
# did listen to that.
self.log_termination()
Expand Down Expand Up @@ -753,7 +775,10 @@ def get_default_name(self, **kwargs):

def is_alive(self) -> bool:
"""Whether the task is alive: check if the task has a live process or thread."""
return self.is_alive_as_thread() or self.is_alive_as_process()
return self.is_alive_as_async() or self.is_alive_as_thread() or self.is_alive_as_process()

def is_alive_as_async(self) -> bool:
return self._async_task is not None and not self._async_task.done()

def is_alive_as_thread(self) -> bool:
"""Whether the task has a live thread."""
Expand Down
Loading

0 comments on commit 91790ee

Please sign in to comment.