Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add rocketry.args.TerminationFlag #49

Merged
merged 5 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rocketry/args/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@

from .builtin import Arg, FuncArg, Return, Session, Task
from .builtin import Arg, FuncArg, Return, Session, Task, TerminationFlag
from .secret import Private
12 changes: 11 additions & 1 deletion rocketry/args/builtin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from typing import Any, Callable
import warnings

from rocketry.core.parameters import BaseArgument
from rocketry.core.utils import filter_keyword_args
Expand Down Expand Up @@ -55,6 +56,7 @@ def get_value(self, task=None, **kwargs) -> Any:
else:
return task.session[self.name]


class Return(BaseArgument):
"""A return argument

Expand Down Expand Up @@ -169,4 +171,12 @@ def __call__(self, **kwargs):

def __repr__(self):
cls_name = type(self).__name__
return f'{cls_name}({self.func.__name__})'
return f'{cls_name}({self.func.__name__})'

class TerminationFlag(BaseArgument):

def get_value(self, task=None, session=None, **kwargs) -> Any:
execution = task.execution
if execution in ("process", "main"):
warnings.warn(f"Passing termination flag to task with 'execution_type={execution}''. Flag cannot be used.")
return task._thread_terminate
4 changes: 4 additions & 0 deletions rocketry/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,10 @@ def delete(self):
Overried if needed additional cleaning."""
self.session.tasks.remove(self)

def terminate(self):
"Terminate this task"
self.force_termination = True

def _get_hooks(self, name:str):
return getattr(self.session.hooks, name)

Expand Down
53 changes: 51 additions & 2 deletions rocketry/test/schedule/test_params.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@

import threading
from time import sleep
import pytest
import rocketry
from rocketry.conditions import DependSuccess
from rocketry.exc import TaskTerminationException


from rocketry.tasks import FuncTask
from rocketry.time import TimeDelta
from rocketry.conditions import SchedulerCycles, SchedulerStarted, TaskStarted, AlwaysFalse, AlwaysTrue

from rocketry.args import Arg, Return, Session, Task, FuncArg #, Param, Session
from rocketry.args import Arg, Return, Session, Task, FuncArg, TerminationFlag #, Param, Session

# Example functions
# -----------------
Expand Down Expand Up @@ -46,6 +49,25 @@ def run_with_session(arg=Session()):
assert isinstance(arg, rocketry.Session)
assert arg.parameters['my_arg'] == 'some session value'

def run_with_termination_flag(flag=TerminationFlag(), task=Task()):
if task.execution == "process":
return
assert isinstance(flag, threading.Event), f"Flag incorrect type: {type(flag)}"
assert not flag.is_set()

if task.execution == "main":
return

waited = 0
while True:
if flag.is_set():
raise TaskTerminationException("Flag raised")

sleep(0.001)
waited += 0.001
if waited > 1:
raise RuntimeError("Did not terminate")

# Tests
# -----

Expand Down Expand Up @@ -119,4 +141,31 @@ def test_task_as_arg(execution, session):
logger = task.logger
assert 1 == logger.filter_by(action="run").count()
assert 1 == logger.filter_by(action="success").count()
assert 0 == logger.filter_by(action="fail").count()
assert 0 == logger.filter_by(action="fail").count()

@pytest.mark.parametrize("execution", ["main", "thread", "process"])
def test_termination_flag_as_arg(execution, session):
if execution == "process":
pytest.skip("For some reason CI fails on process. Termination flag should not be used with process tasks anyways.")

task = FuncTask(func=run_with_termination_flag, name="my_task", start_cond=AlwaysTrue(), execution=execution, session=session)
task.terminate()

@FuncTask(name="terminator", execution="main", start_cond="task 'my_task' has started")
def task_terminate(session=Session()):
session["my_task"].terminate()

session.config.shut_cond = (TaskStarted(task="my_task") >= 1) | ~SchedulerStarted(period=TimeDelta("2 seconds"))

if execution in ("main", "process"):
with pytest.warns(UserWarning):
session.start()
else:
session.start()

logger = task.logger
assert 1 == logger.filter_by(action="run").count()
if execution == "thread":
assert 1 == logger.filter_by(action="terminate").count()
else:
assert 1 == logger.filter_by(action="success").count()