Skip to content

Commit

Permalink
Merge pull request #49 from Miksus/add/terminate
Browse files Browse the repository at this point in the history
ENH: Add rocketry.args.TerminationFlag
  • Loading branch information
Miksus committed Jul 12, 2022
2 parents 05e205c + 717b4c4 commit 32a6ee9
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 4 deletions.
2 changes: 1 addition & 1 deletion rocketry/args/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@

from .builtin import Arg, FuncArg, Return, Session, Task
from .builtin import Arg, FuncArg, Return, Session, Task, TerminationFlag
from .secret import Private
12 changes: 11 additions & 1 deletion rocketry/args/builtin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from typing import Any, Callable
import warnings

from rocketry.core.parameters import BaseArgument
from rocketry.core.utils import filter_keyword_args
Expand Down Expand Up @@ -55,6 +56,7 @@ def get_value(self, task=None, **kwargs) -> Any:
else:
return task.session[self.name]


class Return(BaseArgument):
"""A return argument
Expand Down Expand Up @@ -169,4 +171,12 @@ def __call__(self, **kwargs):

def __repr__(self):
cls_name = type(self).__name__
return f'{cls_name}({self.func.__name__})'
return f'{cls_name}({self.func.__name__})'

class TerminationFlag(BaseArgument):

def get_value(self, task=None, session=None, **kwargs) -> Any:
execution = task.execution
if execution in ("process", "main"):
warnings.warn(f"Passing termination flag to task with 'execution_type={execution}''. Flag cannot be used.")
return task._thread_terminate
4 changes: 4 additions & 0 deletions rocketry/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,10 @@ def delete(self):
Overried if needed additional cleaning."""
self.session.tasks.remove(self)

def terminate(self):
"Terminate this task"
self.force_termination = True

def _get_hooks(self, name:str):
return getattr(self.session.hooks, name)

Expand Down
53 changes: 51 additions & 2 deletions rocketry/test/schedule/test_params.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@

import threading
from time import sleep
import pytest
import rocketry
from rocketry.conditions import DependSuccess
from rocketry.exc import TaskTerminationException


from rocketry.tasks import FuncTask
from rocketry.time import TimeDelta
from rocketry.conditions import SchedulerCycles, SchedulerStarted, TaskStarted, AlwaysFalse, AlwaysTrue

from rocketry.args import Arg, Return, Session, Task, FuncArg #, Param, Session
from rocketry.args import Arg, Return, Session, Task, FuncArg, TerminationFlag #, Param, Session

# Example functions
# -----------------
Expand Down Expand Up @@ -46,6 +49,25 @@ def run_with_session(arg=Session()):
assert isinstance(arg, rocketry.Session)
assert arg.parameters['my_arg'] == 'some session value'

def run_with_termination_flag(flag=TerminationFlag(), task=Task()):
if task.execution == "process":
return
assert isinstance(flag, threading.Event), f"Flag incorrect type: {type(flag)}"
assert not flag.is_set()

if task.execution == "main":
return

waited = 0
while True:
if flag.is_set():
raise TaskTerminationException("Flag raised")

sleep(0.001)
waited += 0.001
if waited > 1:
raise RuntimeError("Did not terminate")

# Tests
# -----

Expand Down Expand Up @@ -119,4 +141,31 @@ def test_task_as_arg(execution, session):
logger = task.logger
assert 1 == logger.filter_by(action="run").count()
assert 1 == logger.filter_by(action="success").count()
assert 0 == logger.filter_by(action="fail").count()
assert 0 == logger.filter_by(action="fail").count()

@pytest.mark.parametrize("execution", ["main", "thread", "process"])
def test_termination_flag_as_arg(execution, session):
if execution == "process":
pytest.skip("For some reason CI fails on process. Termination flag should not be used with process tasks anyways.")

task = FuncTask(func=run_with_termination_flag, name="my_task", start_cond=AlwaysTrue(), execution=execution, session=session)
task.terminate()

@FuncTask(name="terminator", execution="main", start_cond="task 'my_task' has started")
def task_terminate(session=Session()):
session["my_task"].terminate()

session.config.shut_cond = (TaskStarted(task="my_task") >= 1) | ~SchedulerStarted(period=TimeDelta("2 seconds"))

if execution in ("main", "process"):
with pytest.warns(UserWarning):
session.start()
else:
session.start()

logger = task.logger
assert 1 == logger.filter_by(action="run").count()
if execution == "thread":
assert 1 == logger.filter_by(action="terminate").count()
else:
assert 1 == logger.filter_by(action="success").count()

0 comments on commit 32a6ee9

Please sign in to comment.