Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Update session #79

Merged
merged 10 commits into from
Aug 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/versions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ Version history
- ``2.3.0``

- Add: Cron style scheduling
- Add: Task groups (``Grouper``) to support bigger applications
- Add: New condition, ``TaskRunnable``
- Add: New methods to session (``remove_task`` & ``create_task``)
- Add: ``always`` time period
- Fix: Various bugs related to ``Any``, ``All`` and ``StaticInterval`` time periods
- Fix: Integers as start and end in time periods
- Upd: Now time periods are immutable
- Upd: Now if session is not specified, tasks create new one.

- ``2.2.0``

Expand Down
2 changes: 1 addition & 1 deletion rocketry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
session = Session()
session.set_as_default()

from .application import Rocketry
from .application import Rocketry, Grouper

from . import _version
__version__ = _version.get_versions()['version']
52 changes: 31 additions & 21 deletions rocketry/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,35 @@
from rocketry import Session

class _AppMixin:

session: Session

def task(self, start_cond=None, name=None, *, command=None, path=None, **kwargs):
def task(self, start_cond=None, name=None, **kwargs):
"Create a task"

kwargs['session'] = self.session
kwargs['start_cond'] = start_cond
kwargs['name'] = name

if command is not None:
return CommandTask(command=command, **kwargs)
elif path is not None:
# Non-wrapped FuncTask
return FuncTask(path=path, **kwargs)
else:
return FuncTask(name_include_module=False, _name_template='{func_name}', **kwargs)
return self.session.create_task(start_cond=start_cond, name=name, **kwargs)

def param(self, name:Optional[str]=None):
"Set one session parameter (decorator)"
return FuncParam(name, session=self.session)

def cond(self, syntax: Union[str, Pattern, List[Union[str, Pattern]]]=None):
"Create a condition (decorator)"
return FuncCond(syntax=syntax, session=self.session, decor_return_func=False)

def params(self, **kwargs):
"Set session parameters"
self.session.parameters.update(kwargs)

def include_grouper(self, group:'Grouper'):
for task in group.session.tasks:
if group.prefix:
task.name = group.prefix + task.name
if group.start_cond is not None:
task.start_cond = task.start_cond & group.start_cond
task.execution = group.execution if task.execution is None else task.execution

self.session.add_task(task)
self.session.parameters.update(group.session.parameters)

class Rocketry(_AppMixin):
"""Rocketry scheduling application"""
Expand Down Expand Up @@ -66,14 +75,6 @@ async def serve(self, debug=False):
self.session.set_as_default()
await self.session.serve()

def cond(self, syntax: Union[str, Pattern, List[Union[str, Pattern]]]=None):
"Create a condition (decorator)"
return FuncCond(syntax=syntax, session=self.session, decor_return_func=False)

def params(self, **kwargs):
"Set session parameters"
self.session.parameters.update(kwargs)

def set_logger(self):
warnings.warn((
"set_logger is deprecated and will be removed in the future. "
Expand Down Expand Up @@ -103,3 +104,12 @@ def _get_repo(self, repo:str):
return CSVFileRepo(filename=filepath, model=LogRecord)
else:
raise NotImplementedError(f"Repo creation for {repo} not implemented")

class Grouper(_AppMixin):

def __init__(self, prefix:str=None, start_cond=None, execution=None):
self.prefix = prefix
self.start_cond = start_cond
self.execution = execution

self.session = Session()
2 changes: 1 addition & 1 deletion rocketry/core/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def handle_logs(self):
break
else:
self.logger.debug(f"Inserting record for '{record.task_name}' ({record.action})")
task = self.session.get_task(record.task_name)
task = self.session[record.task_name]
if record.action == "fail":
# There is a caveat in logging
# https://github.com/python/cpython/blame/fad6af2744c0b022568f7f4a8afc93fed056d4db/Lib/logging/handlers.py#L1383
Expand Down
17 changes: 11 additions & 6 deletions rocketry/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@

_IS_WINDOWS = platform.system()

def _create_session():
# To avoid circular imports
from rocketry import Session
return Session()

class Task(RedBase, BaseModel):
"""Base class for Tasks.

Expand Down Expand Up @@ -116,8 +121,7 @@ class Task(RedBase, BaseModel):
Logger of the task. Typically not needed
to be set.
session : rocketry.session.Session, optional
Session the task is binded to,
by default default session
Session the task is binded to.


Attributes
Expand Down Expand Up @@ -253,7 +257,8 @@ def __init__(self, **kwargs):
hooker.prerun(self)

if kwargs.get("session") is None:
kwargs['session'] = self.session
warnings.warn("Task's session not defined. Creating new.", UserWarning)
kwargs['session'] = _create_session()
kwargs['name'] = self._get_name(**kwargs)

super().__init__(**kwargs)
Expand Down Expand Up @@ -811,7 +816,7 @@ def _lock_to_run_log(self, log_queue):
else:

#self.logger.debug(f"Inserting record for '{record.task_name}' ({record.action})")
task = self.session.get_task(record.task_name)
task = self.session[record.task_name]
task.log_record(record)

action = record.action
Expand Down Expand Up @@ -1065,13 +1070,13 @@ def period(self) -> TimePeriod:
session = self.session

if isinstance(cond, (TaskSucceeded, TaskFinished)):
if session.get_task(cond.kwargs["task"]) is self:
if session[cond.kwargs["task"]] is self:
return cond.period

elif isinstance(cond, All):
task_periods = []
for sub_stmt in cond:
if isinstance(sub_stmt, (TaskFinished, TaskFinished)) and session.get_task(sub_stmt.kwargs["task"]) is self:
if isinstance(sub_stmt, (TaskFinished, TaskFinished)) and session[sub_stmt.kwargs["task"]] is self:
task_periods.append(sub_stmt.period)
if task_periods:
return AllTime(*task_periods)
Expand Down
29 changes: 28 additions & 1 deletion rocketry/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,32 @@ def get_tasks(self) -> list:
return self.tasks

def get_task(self, task):
#! TODO: Do we need this?
warnings.warn((
"Method get_task will be removed in the future version."
"Please use instead: session['task name']"
), DeprecationWarning)
return self[task]

def get_cond_parsers(self):
"Used by the actual string condition parser"
return self._cond_parsers

def create_task(self, *, command=None, path=None, **kwargs):
"Create a task and put it to the session"

# To avoid circular imports
from rocketry.tasks import CommandTask, FuncTask

kwargs['session'] = self

if command is not None:
return CommandTask(command=command, **kwargs)
elif path is not None:
# Non-wrapped FuncTask
return FuncTask(path=path, **kwargs)
else:
return FuncTask(name_include_module=False, _name_template='{func_name}', **kwargs)

def add_task(self, task: 'Task'):
"Add the task to the session"
if_exists = self.config.task_pre_exist
Expand All @@ -359,6 +378,14 @@ def add_task(self, task: 'Task'):
raise KeyError(f"Task '{task.name}' already exists")
else:
self.tasks.add(task)

# Adding the session to the task
task.session = self

def remove_task(self, task: Union['Task', str]):
if isinstance(task, str):
task = self[task]
self.session.tasks.remove(task)

def task_exists(self, task: 'Task'):
warnings.warn((
Expand Down
45 changes: 44 additions & 1 deletion rocketry/test/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,36 @@ def do_daily(arg=Arg('arg_3')):
logger = app.session['do_daily'].logger
assert logger.filter_by(action="success").count() == 1

def test_nested_args_from_func_arg():
set_logging_defaults()

# Creating app
app = Rocketry(config={'task_execution': 'main'})

@app.param('arg_1')
def my_arg_1():
return 'arg 1'

def my_func_2(arg=Arg('arg_1')):
assert arg == "arg 1"
return 'arg 2'

def my_func_3(arg_1=Arg('arg_1'), arg_2=FuncArg(my_func_2)):
assert arg_1 == "arg 1"
assert arg_2 == "arg 2"
return 'arg 3'

# Creating a task to test this
@app.task(true)
def do_daily(arg=FuncArg(my_func_3)):
...
assert arg == "arg 3"

app.session.config.shut_cond = TaskStarted(task='do_daily')
app.run()
logger = app.session['do_daily'].logger
assert logger.filter_by(action="success").count() == 1

def test_arg_ref():
set_logging_defaults()

Expand Down Expand Up @@ -203,4 +233,17 @@ def do_never(arg_1):
task_example = session['never done']
assert task_example.execution == 'process'
assert task_example.name == 'never done'
assert dict(task_example.parameters) == {'arg_1': 'something'}
assert dict(task_example.parameters) == {'arg_1': 'something'}


def test_task_name():
set_logging_defaults()

app = Rocketry(config={'task_execution': 'main'})

@app.task()
def do_func():
...
return 'return value'

assert app.session[do_func].name == "do_func"
Loading