diff --git a/.gitignore b/.gitignore index f5eea2f3..c5bd7570 100644 --- a/.gitignore +++ b/.gitignore @@ -54,4 +54,7 @@ docs/_build/ .tox/ # Coverage -cov_data/ \ No newline at end of file +cov_data/ + +#Custom +test.py \ No newline at end of file diff --git a/rocketry/_base.py b/rocketry/_base.py index ff1ebf14..e5a678a6 100644 --- a/rocketry/_base.py +++ b/rocketry/_base.py @@ -1,8 +1,12 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, ClassVar +from pydantic.dataclasses import dataclass, Field +from pydantic import BaseModel if TYPE_CHECKING: from rocketry import Session class RedBase: """Baseclass for all Rocketry classes""" - session: 'Session' = None + + # Commented this out for now as it was causing issues with the new pydantic implementation + session: 'Session' diff --git a/rocketry/_setup.py b/rocketry/_setup.py index ac681b40..80f0e499 100644 --- a/rocketry/_setup.py +++ b/rocketry/_setup.py @@ -3,7 +3,7 @@ from rocketry.session import Session, Config from rocketry.parse import add_condition_parser from rocketry.conds import true, false -from rocketry.tasks import CommandTask, FuncTask, CodeTask +from rocketry.tasks import CommandTask, FuncTask, CodeTask, _DummyTask from rocketry.tasks.maintain import ShutDown, Restart from rocketry.conditions.meta import _FuncTaskCondWrapper @@ -23,14 +23,24 @@ def _setup_defaults(): cls_tasks = ( Task, FuncTask, CommandTask, CodeTask, - ShutDown, Restart, + ShutDown, Restart, _DummyTask, _FuncTaskCondWrapper ) for cls_task in cls_tasks: - cls_task.update_forward_refs(Session=Session, BaseCondition=BaseCondition) + #cls_task.update_forward_refs(Session=Session, BaseCondition=BaseCondition) + cls_task.model_rebuild( + force=True, + _types_namespace={"Session": Session, "BaseCondition": BaseCondition}, + _parent_namespace_depth=4 + ) - Config.update_forward_refs(BaseCondition=BaseCondition) + # Config.update_forward_refs(BaseCondition=BaseCondition) + Config.model_rebuild( + force=True, + _types_namespace={"Session": Session, "BaseCondition": BaseCondition}, + _parent_namespace_depth=4 + ) #Session.update_forward_refs( # Task=Task, Parameters=Parameters, Scheduler=Scheduler #) diff --git a/rocketry/conditions/meta.py b/rocketry/conditions/meta.py index d12c5856..37e3942b 100644 --- a/rocketry/conditions/meta.py +++ b/rocketry/conditions/meta.py @@ -1,6 +1,6 @@ import copy -from typing import Callable, Optional, Pattern, Union +from typing import Callable, ClassVar, Optional, Pattern, Union from pydantic import Field from rocketry.args import Session diff --git a/rocketry/core/task.py b/rocketry/core/task.py index a3bd0b34..1921e205 100644 --- a/rocketry/core/task.py +++ b/rocketry/core/task.py @@ -15,12 +15,13 @@ import threading from queue import Empty from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Dict, Type, Union, Tuple, Optional +from typing_extensions import Annotated try: from typing import Literal except ImportError: # pragma: no cover from typing_extensions import Literal -from pydantic import BaseModel, Field, PrivateAttr, validator +from pydantic import BaseModel, Field, PrivateAttr, ConfigDict, field_validator, field_serializer from rocketry._base import RedBase from rocketry.core.condition import BaseCondition, AlwaysFalse, All @@ -94,7 +95,7 @@ def is_async(self) -> bool: def is_thread(self) -> bool: return isinstance(self.task, threading.Thread) -class Task(RedBase, BaseModel): +class Task(BaseModel, RedBase): """Base class for Tasks. A task can be a function, command or other procedure that @@ -192,42 +193,37 @@ class Task(RedBase, BaseModel): ... return ... """ - class Config: - arbitrary_types_allowed= True - underscore_attrs_are_private = True - validate_assignment = True - json_encoders = { - Parameters: lambda v: v.to_json(), - 'BaseCondition': lambda v: str(v), - FunctionType: lambda v: v.__name__, - 'Session': lambda v: id(v), - } - + model_config = ConfigDict( + arbitrary_types_allowed= True, + validate_assignment = True, + extra='allow', + ) - session: 'Session' = Field() + session: 'Session' = Field(default=None, validate_default=False) + # Class permanent: bool = False # Whether the task is not meant to finish (Ie. RestAPI) _actions: ClassVar[Tuple] = ("run", "fail", "success", "inaction", "terminate", None, "crash") fmt_log_message: str = r"Task '{task}' status: '{action}'" - daemon: Optional[bool] + daemon: Optional[bool] = None batches: List[Parameters] = Field( default_factory=list, description="Run batches (parameters). If not empty, run is triggered regardless of starting condition" ) # Instance - name: Optional[str] = Field(description="Name of the task. Must be unique") - description: Optional[str] = Field(description="Description of the task for documentation") - logger_name: Optional[str] = Field(description="Logger name to be used in logging the task records") - execution: Optional[Literal['main', 'async', 'thread', 'process']] + name: Optional[str] = Field(description="Name of the task. Must be unique", default=None) + description: Optional[str] = Field(description="Description of the task for documentation", default=None) + logger_name: Optional[str] = Field(description="Logger name to be used in logging the task records", default="rocketry.task") + execution: Optional[Literal['main', 'async', 'thread', 'process']] = None priority: int = 0 disabled: bool = False force_run: bool = False force_termination: bool = False - status: Optional[Literal['run', 'fail', 'success', 'terminate', 'inaction', 'crash']] = Field(description="Latest status of the task") - timeout: Optional[datetime.timedelta] + status: Optional[Literal['run', 'fail', 'success', 'terminate', 'inaction', 'crash']] = Field(description="Latest status of the task", default=None) + timeout: Optional[datetime.timedelta] = None parameters: Parameters = Parameters() @@ -237,7 +233,7 @@ class Config: multilaunch: Optional[bool] = None on_startup: bool = False on_shutdown: bool = False - func_run_id: Callable = None + func_run_id: Union[Callable, None] = None _last_run: Optional[float] _last_success: Optional[float] @@ -252,29 +248,29 @@ class Config: _mark_running = False - @validator('start_cond', pre=True) + @field_validator('start_cond', mode="before") def parse_start_cond(cls, value, values): from rocketry.parse.condition import parse_condition - session = values['session'] + session = values.data['session'] if isinstance(value, str): value = parse_condition(value, session=session) elif value is None: value = AlwaysFalse() return copy(value) - @validator('end_cond', pre=True) + @field_validator('end_cond', mode="before") def parse_end_cond(cls, value, values): from rocketry.parse.condition import parse_condition - session = values['session'] + session = values.data['session'] if isinstance(value, str): value = parse_condition(value, session=session) elif value is None: value = AlwaysFalse() return copy(value) - @validator('logger_name', pre=True, always=True) + @field_validator('logger_name', mode="before") def parse_logger_name(cls, value, values): - session = values['session'] + session = values.data['session'] if isinstance(value, str): logger_name = value @@ -287,7 +283,7 @@ def parse_logger_name(cls, value, values): raise ValueError(f"Logger name must start with '{basename}' as session finds loggers with names") return logger_name - @validator('timeout', pre=True, always=True) + @field_validator('timeout', mode="before") def parse_timeout(cls, value, values): if value == "never": return datetime.timedelta.max @@ -296,6 +292,22 @@ def parse_timeout(cls, value, values): if value is not None: return to_timedelta(value) return value + + @field_serializer("parameters", when_used="json") + def ser_parameters(self, parameters): + return parameters.to_json() + + @field_serializer("start_cond", when_used="json") + def ser_start_cond(self, start_cond): + return str(start_cond) + + @field_serializer("end_cond", when_used="json") + def ser_end_cond(self, end_cond): + return str(end_cond) + + @field_serializer("session", when_used="json", check_fields=False) + def ser_session(self, session): + return id(session) @property def logger(self): @@ -339,9 +351,9 @@ def _get_name(self, name=None, **kwargs): return self.get_default_name(**kwargs) return name - @validator('name', pre=True) + @field_validator('name', mode="before") def parse_name(cls, value, values): - session = values['session'] + session = values.data['session'] on_exists = session.config.task_pre_exist name_exists = value in session if name_exists: @@ -359,9 +371,9 @@ def parse_name(cls, value, values): return name return value - @validator('name', pre=False) + @field_validator('name', mode="after") def validate_name(cls, value, values): - session = values['session'] + session = values.data['session'] on_exists = session.config.task_pre_exist name_exists = value in session @@ -371,17 +383,17 @@ def validate_name(cls, value, values): raise ValueError(f"Task name '{value}' already exists. Please pick another") return value - @validator('parameters', pre=True) + @field_validator('parameters', mode="before") def parse_parameters(cls, value): if isinstance(value, Parameters): return value return Parameters(value) - @validator('force_run', pre=False) + @field_validator('force_run', mode="after") def parse_force_run(cls, value, values): if value: warnings.warn("Attribute 'force_run' is deprecated. Please use method set_running() instead", DeprecationWarning) - values['batches'].append(Parameters()) + values.data['batches'].append(Parameters()) return value def __hash__(self): @@ -731,7 +743,6 @@ def run_as_process(self, params:Parameters, direct_params:Parameters, task_run:T self._run_stack.append(task_run) self._mark_running = True # needed in pickling - process.start() self._mark_running = False @@ -751,7 +762,6 @@ def _run_as_process(self, params:Parameters, direct_params:Parameters, task_run, # in the actual multiprocessing's process. We only add QueueHandler to the # logger (with multiprocessing.Queue as queue) so that all the logging # records end up in the main process to be logged properly. - basename = self.logger_name # handler = logging.handlers.QueueHandler(queue) handler = QueueHandler(queue) @@ -1294,8 +1304,8 @@ def __getstate__(self): #state['__dict__'] = state['__dict__'].copy() # remove unpicklable - state['__private_attribute_values__'] = state['__private_attribute_values__'].copy() - priv_attrs = state['__private_attribute_values__'] + state['__pydantic_private__'] = state['__pydantic_private__'].copy() + priv_attrs = state['__pydantic_private__'] priv_attrs['_lock'] = None priv_attrs['_process'] = None priv_attrs['_thread'] = None @@ -1404,5 +1414,5 @@ def json(self, **kwargs): if 'exclude' not in kwargs: kwargs['exclude'] = set() kwargs['exclude'].update({'session'}) - d = super().json(**kwargs) + d = super().model_dump_json(**kwargs) return d \ No newline at end of file diff --git a/rocketry/log/log_record.py b/rocketry/log/log_record.py index da321e99..eb9be88a 100644 --- a/rocketry/log/log_record.py +++ b/rocketry/log/log_record.py @@ -1,6 +1,6 @@ import datetime from typing import Optional -from pydantic import BaseModel, Field, validator +from pydantic import field_validator, BaseModel, Field from rocketry.pybox.time import to_datetime, to_timedelta @@ -38,36 +38,39 @@ class LogRecord(MinimalRecord): class TaskLogRecord(MinimalRecord): - start: Optional[datetime.datetime] - end: Optional[datetime.datetime] - runtime: Optional[datetime.timedelta] + start: Optional[datetime.datetime] = None + end: Optional[datetime.datetime] = None + runtime: Optional[datetime.timedelta] = None message: str - exc_text: Optional[str] + exc_text: Optional[str] = None - @validator("start", pre=True) + @field_validator("start", mode="before") + @classmethod def format_start(cls, value): if value is not None: value = to_datetime(value) return value - @validator("end", pre=True) + @field_validator("end", mode="before") + @classmethod def format_end(cls, value): if value is not None: value = to_datetime(value) return value - @validator("runtime", pre=True) + @field_validator("runtime", mode="before") + @classmethod def format_runtime(cls, value): if value is not None: value = to_timedelta(value) return value class MinimalRunRecord(MinimalRecord): - run_id: Optional[str] + run_id: Optional[str] = None class RunRecord(LogRecord): - run_id: Optional[str] + run_id: Optional[str] = None class TaskRunRecord(TaskLogRecord): - run_id: Optional[str] + run_id: Optional[str] = None diff --git a/rocketry/session.py b/rocketry/session.py index 8c315d3c..fb65e604 100644 --- a/rocketry/session.py +++ b/rocketry/session.py @@ -13,7 +13,7 @@ from itertools import chain from typing import TYPE_CHECKING, Callable, ClassVar, Iterable, Dict, List, Optional, Set, Tuple, Type, Union -from pydantic import BaseModel, root_validator, validator +from pydantic import field_validator, model_validator, ConfigDict, BaseModel, validator from rocketry.pybox.time import to_timedelta from rocketry.log.defaults import create_default_handler from rocketry._base import RedBase @@ -40,9 +40,11 @@ class Config(BaseModel): - class Config: - validate_assignment = True - arbitrary_types_allowed = True + model_config = ConfigDict( + validate_assignment=True, + arbitrary_types_allowed=True, + validate_default=True + ) # Fields use_instance_naming: bool = False @@ -62,27 +64,29 @@ class Config: multilaunch: bool = False func_run_id: Callable = uuid - max_process_count = cpu_count() + max_process_count:int = cpu_count() tasks_as_daemon: bool = True restarting: str = 'replace' instant_shutdown: bool = False timeout: datetime.timedelta = datetime.timedelta(minutes=30) shut_cond: Optional['BaseCondition'] = None - cls_lock: Type = threading.Lock + cls_lock: Callable = threading.Lock param_materialize:Literal['pre', 'post'] = 'post' timezone: Optional[datetime.tzinfo] = None - time_func: Callable = None + time_func: Union[Callable, None] = None + - @validator('execution', pre=True, always=True) + @field_validator('execution', mode="before") def parse_task_execution(cls, value): if value is None: return 'async' return value - @validator('shut_cond', pre=True) + @field_validator('shut_cond', mode="before") + @classmethod def parse_shut_cond(cls, value): from rocketry.parse import parse_condition from rocketry.conditions import AlwaysFalse @@ -90,7 +94,8 @@ def parse_shut_cond(cls, value): return AlwaysFalse() return parse_condition(value) - @validator('timeout', pre=True, always=True) + + @field_validator('timeout', mode="before") def parse_timeout(cls, value): if isinstance(value, str): return to_timedelta(value) @@ -107,7 +112,8 @@ def task_execution(self): ) return self.execution - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_deprecated(cls, values): if 'task_execution' in values: warnings.warn( @@ -168,8 +174,7 @@ class Session(RedBase): """ config: Config - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) tasks: Set['Task'] hooks: Hooks @@ -531,8 +536,9 @@ def __getstate__(self): state["_cond_cache"] = None state["_cond_parsers"] = None state["session"] = None - #state["parameters"] = None + # state["parameters"] = None state['scheduler'] = None + state['returns'] = None return state def _copy_pickle(self): @@ -543,7 +549,10 @@ def _copy_pickle(self): new_self = copy(self) for attr in unpicklable: setattr(new_self, attr, None) - new_self.config = self.config.copy(exclude=unpicklable_conf) + + data = self.config.model_dump(exclude=unpicklable_conf, round_trip=True) + copied = self.config.model_validate(data) + new_self.config = copied return new_self @property diff --git a/rocketry/tasks/__init__.py b/rocketry/tasks/__init__.py index 372997b8..c761e344 100644 --- a/rocketry/tasks/__init__.py +++ b/rocketry/tasks/__init__.py @@ -1,4 +1,5 @@ from .func import FuncTask from .code import CodeTask from .command import CommandTask +from ._dummy import _DummyTask from . import maintain diff --git a/rocketry/tasks/_dummy.py b/rocketry/tasks/_dummy.py new file mode 100644 index 00000000..50a267f9 --- /dev/null +++ b/rocketry/tasks/_dummy.py @@ -0,0 +1,10 @@ +from rocketry.core import Task +class _DummyTask(Task): + """ + Not used within core application. Only used in UnitTests + DummyTask which inherits task and performs forward_refs + to allow for use inside unit tests. Provides basic implementation + of Task classs and overwrites required abstractmethods. + """ + def execute(self, *args, **kwargs): + return diff --git a/rocketry/tasks/command.py b/rocketry/tasks/command.py index d58a45c6..2748823e 100644 --- a/rocketry/tasks/command.py +++ b/rocketry/tasks/command.py @@ -7,7 +7,7 @@ except ImportError: # pragma: no cover from typing_extensions import Literal -from pydantic import Field, validator +from pydantic import Field, field_validator from rocketry.core.parameters.parameters import Parameters from rocketry.core.task import Task @@ -43,9 +43,9 @@ class CommandTask(Task): command: Union[str, List[str]] shell: bool = False - cwd: Optional[str] + cwd: Optional[str] = None kwds_popen: dict = {} - argform: Optional[Literal['-', '--', 'short', 'long']] = Field(description="Whether the arguments are turned as short or long form command line arguments") + argform: Optional[Literal['-', '--', 'short', 'long']] = Field(description="Whether the arguments are turned as short or long form command line arguments", default=None) def get_kwargs_popen(self) -> dict: kwargs = { @@ -58,7 +58,7 @@ def get_kwargs_popen(self) -> dict: kwargs.update(self.kwds_popen) return kwargs - @validator('argform') + @field_validator('argform') def parse_argform(cls, value): return { "long": "--", diff --git a/rocketry/tasks/func.py b/rocketry/tasks/func.py index 101912f0..13307dcc 100644 --- a/rocketry/tasks/func.py +++ b/rocketry/tasks/func.py @@ -5,7 +5,8 @@ from typing import Callable, List, Optional import warnings -from pydantic import Field, PrivateAttr, validator +from pydantic import Field, PrivateAttr, field_validator, field_serializer +from pydantic.main import _object_setattr from rocketry.core.task import Task from rocketry.core.parameters import Parameters @@ -128,9 +129,9 @@ def wrapper(*args, **kwargs): def my_task_func(): ... """ - func: Optional[Callable] = Field(description="Executed function") + func: Optional[Callable] = Field(description="Executed function", default=None) - path: Optional[Path] = Field(description="Path to the script that is executed") + path: Optional[Path] = Field(description="Path to the script that is executed", default = None) func_name: Optional[str] = Field(default="main", description="Name of the function in given path. Pass path as well") cache: bool = False @@ -143,16 +144,17 @@ def my_task_func(): def delayed(self): return self._is_delayed - @validator('path') + + @field_validator('path') def validate_path(cls, value: Path, values): - name = values['name'] + name = values.data['name'] if value is not None and not value.is_file(): warnings.warn(f"Path {value} does not exists. Task '{name}' may fail.") return value - @validator("func") + @field_validator("func") def validate_func(cls, value, values): - execution = values.get('execution') + execution = values.data.get('execution') func = value if execution == "process" and getattr(func, "__name__", None) == "": @@ -161,10 +163,17 @@ def validate_func(cls, value, values): "The function must be pickleable if task's execution is 'process'. " ) return value + + @field_serializer("func", when_used="json") + def ser_func(self, func): + return func.__name__ + def __init__(self, func=None, **kwargs): only_func_set = func is not None and not kwargs no_func_set = func is None and kwargs.get('path') is None + _object_setattr(self, "__pydantic_extra__", {}) + _object_setattr(self, "__pydantic_private__", None) if no_func_set: # FuncTask was probably called like: # @FuncTask(...) @@ -184,6 +193,7 @@ def __init__(self, func=None, **kwargs): # the execution to else than process # as it's obvious it would not work. kwargs["execution"] = "thread" + super().__init__(func=func, **kwargs) self._set_descr(is_delayed=func is None) diff --git a/rocketry/test/condition/test_meta.py b/rocketry/test/condition/test_meta.py index 96e4e4d8..44ccabb2 100644 --- a/rocketry/test/condition/test_meta.py +++ b/rocketry/test/condition/test_meta.py @@ -15,6 +15,7 @@ def is_foo(status): return True return False + @pytest.mark.parametrize("execution", ["main", "thread", "process"]) def test_taskcond_true(session, execution): assert session._cond_cache == {} @@ -30,7 +31,7 @@ def test_taskcond_true(session, execution): session.config.shut_cond = (TaskStarted(task="a task") >= 2) | ~SchedulerStarted(period="past 5 seconds") session.start() - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) history_task = [ rec for rec in records @@ -73,7 +74,7 @@ def test_taskcond_false(session, execution): session.config.shut_cond = SchedulerCycles() >= 3 session.start() - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) history_task = [ rec for rec in records diff --git a/rocketry/test/schedule/test_core.py b/rocketry/test/schedule/test_core.py index f44c411a..516800ac 100644 --- a/rocketry/test/schedule/test_core.py +++ b/rocketry/test/schedule/test_core.py @@ -85,6 +85,7 @@ def test_scheduler_shut_cond(session): @pytest.mark.parametrize("execution", ["main", "async", "thread", "process"]) @pytest.mark.parametrize("func", [pytest.param(create_line_to_file, id="sync"), pytest.param(create_line_to_file_async, id="async")]) def test_task_execution(tmpdir, execution, func, session): + with tmpdir.as_cwd(): # To be confident the scheduler won't lie to us # we test the task execution with a job that has @@ -92,7 +93,6 @@ def test_task_execution(tmpdir, execution, func, session): FuncTask(func, name="add line to file", start_cond=AlwaysTrue(), execution=execution, session=session) session.config.shut_cond = (TaskStarted(task="add line to file") >= 3) | ~SchedulerStarted(period=TimeDelta("5 second")) - session.start() # Sometimes in CI the task may end up to be started only twice thus we tolerate slightly with open("work.txt", "r", encoding="utf-8") as file: @@ -131,6 +131,7 @@ def test_task_log(tmpdir, execution, task_func, run_count, fail_count, success_c """ # Set session (and logging) + print("test_task_log") session = Session(config={"debug": True, "silence_task_logging": False, "execution": "process"}) rocketry.session = session session.set_as_default() @@ -159,7 +160,7 @@ def test_task_log(tmpdir, execution, task_func, run_count, fail_count, success_c for record in history: is_tasl_log = isinstance(record, TaskLogRecord) if not isinstance(record, dict): - record = record.dict() + record = record.model_dump() assert record["task_name"] == "mytask" assert isinstance(record["created"], float) assert isinstance(record["start"], datetime.datetime if is_tasl_log else float) @@ -188,6 +189,7 @@ def test_task_log(tmpdir, execution, task_func, run_count, fail_count, success_c @pytest.mark.parametrize("func_type", ["sync", "async"]) @pytest.mark.parametrize("execution", ["main", "thread", "process"]) def test_task_status(session, execution, func_type, mode): + print("test_task_status") session.config.force_status_from_logs = mode == "use logs" task_success = FuncTask( @@ -272,6 +274,7 @@ def test_task_status(session, execution, func_type, mode): @pytest.mark.parametrize("execution", ["main", "thread", "process"]) def test_task_disabled(tmpdir, execution, session): + print("test_task_disabled") with tmpdir.as_cwd(): task = FuncTask( @@ -293,6 +296,7 @@ def test_task_disabled(tmpdir, execution, session): @pytest.mark.parametrize("execution", ["main", "thread", "process"]) def test_priority(execution, session): + print("test_task_priority") session.config.max_process_count = 4 task_1 = FuncTask(run_succeeding, name="1", priority=100, start_cond=AlwaysTrue(), execution=execution, session=session) task_3 = FuncTask(run_failing, name="3", priority=10, start_cond=AlwaysTrue(), execution=execution, session=session) @@ -315,6 +319,7 @@ def test_priority(execution, session): @pytest.mark.parametrize("execution", ["main", "thread", "process"]) def test_pass_params_as_global(execution, session): + print("test_pass_params_as_global") # thread-Parameters has been observed to fail rarely task = FuncTask(run_with_param, name="parametrized", start_cond=AlwaysTrue(), execution=execution, session=session) diff --git a/rocketry/test/schedule/test_failure.py b/rocketry/test/schedule/test_failure.py index 28ec9488..1a97636b 100644 --- a/rocketry/test/schedule/test_failure.py +++ b/rocketry/test/schedule/test_failure.py @@ -69,7 +69,7 @@ def test_param_failure(tmpdir, execution, session, fail_in): session.start() assert task.status == "fail" - records = list(map(lambda d: d.dict(exclude={'created'}), task.logger.get_records())) + records = list(map(lambda d: d.model_dump(exclude={'created'}), task.logger.get_records())) assert [{"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "fail"}] == records @pytest.mark.parametrize( @@ -93,7 +93,7 @@ def test_session_param_failure(tmpdir, execution, session, fail_in): session.start() assert task.status == "fail" - records = list(map(lambda d: d.dict(exclude={'created'}), task.logger.get_records())) + records = list(map(lambda d: d.model_dump(exclude={'created'}), task.logger.get_records())) assert [{"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "fail"}] == records diff --git a/rocketry/test/session/test_logs.py b/rocketry/test/session/test_logs.py index 70974972..6622c80d 100644 --- a/rocketry/test/session/test_logs.py +++ b/rocketry/test/session/test_logs.py @@ -3,7 +3,7 @@ import datetime import logging from typing import Optional -from pydantic import root_validator, validator +from pydantic import field_validator, model_validator import pytest @@ -33,23 +33,25 @@ def do_fail(): raise RuntimeError("Oops") class CustomRecord(MinimalRecord): - timestamp: Optional[datetime.datetime] - start: Optional[datetime.datetime] - end: Optional[datetime.datetime] - runtime: Optional[datetime.timedelta] + timestamp: Optional[datetime.datetime] = None + start: Optional[datetime.datetime] = None + end: Optional[datetime.datetime] = None + runtime: Optional[datetime.timedelta] = None message: str - @validator("start", pre=True) + @field_validator("start", mode="before") + @classmethod def parse_start(cls, value): if value is not None: return datetime.datetime.fromtimestamp(value) - @validator("end", pre=True) + @field_validator("end", mode="before") + @classmethod def parse_end(cls, value): if value is not None: return datetime.datetime.fromtimestamp(value) - @root_validator + @model_validator(mode="before") def validate_timestamp(cls, values): values['timestamp'] = datetime.datetime.fromtimestamp(values['created']) return values @@ -258,7 +260,7 @@ def test_get_logs_params(tmpdir, mock_pydatetime, mock_time, query, expected, se logs = list(logs) assert len(expected) == len(logs) - logs = list(map(lambda e: e.dict(), logs)) + logs = list(map(lambda e: e.model_dump(), logs)) for e, a in zip(expected, logs): #assert e.keys() <= a.keys() # Check all expected items in actual (actual can contain extra) diff --git a/rocketry/test/task/code/test_construct.py b/rocketry/test/task/code/test_construct.py index 4125371f..3a4439f3 100644 --- a/rocketry/test/task/code/test_construct.py +++ b/rocketry/test/task/code/test_construct.py @@ -80,6 +80,6 @@ def main(): assert task.status == 'fail' - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) record_fail = [r for r in records if r['action'] == 'fail'][0] assert 'File "", line 5, in \n File "", line 3, in main\nRuntimeError: Failed' in record_fail['exc_text'] diff --git a/rocketry/test/task/command/test_run.py b/rocketry/test/task/command/test_run.py index c9205d35..ebbb9238 100644 --- a/rocketry/test/task/command/test_run.py +++ b/rocketry/test/task/command/test_run.py @@ -72,13 +72,15 @@ def test_fail_command(tmpdir, execution, session): wait_till_task_finish(task) - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert "fail" == task.status err = records[1]["exc_text"].strip().replace('\r', '') if sys.version_info >= (3, 8): - expected = "OSError: Failed running command (2): \nunknown option --not_an_arg\nusage: python [option] ... [-c cmd | -m mod | file | -] [arg] ...\nTry `python -h' for more information." - assert err.endswith(expected) + # Somethings the file path in before 'python' changing endswith to two in statements instead + assert "OSError: Failed running command (2): \nunknown option --not_an_arg\nusage:" in err and "python [option] ... [-c cmd | -m mod | file | -] [arg] ...\nTry `python -h' for more information." in err + # expected = "OSError: Failed running command (2): \nunknown option --not_an_arg\nusage: python [option] ... [-c cmd | -m mod | file | -] [arg] ...\nTry `python -h' for more information." + # assert err.endswith(expected) else: assert err.endswith("Try `python -h' for more information.") assert "OSError: Failed running command (2)" in err diff --git a/rocketry/test/task/func/test_export.py b/rocketry/test/task/func/test_export.py index 2b9e767c..87b92e17 100644 --- a/rocketry/test/task/func/test_export.py +++ b/rocketry/test/task/func/test_export.py @@ -4,9 +4,9 @@ def test_to_dict(session): task1 = FuncTask(func=lambda: None, name="task 1", start_cond="every 10 seconds", session=session) task2 = FuncTask(func=lambda: None, name="task 2", start_cond="after task 'task 1'", session=session) - task1.dict() - task2.dict() + task1.model_dump() + task2.model_dump() - task1.json() - task2.json() + task1.model_dump_json() + task2.model_dump_json() pass diff --git a/rocketry/test/task/func/test_logging.py b/rocketry/test/task/func/test_logging.py index 3025cb56..9543d9c8 100644 --- a/rocketry/test/task/func/test_logging.py +++ b/rocketry/test/task/func/test_logging.py @@ -150,7 +150,7 @@ def create_record(action, task_name): records = session.get_task_log() records = [ - record.dict(exclude={"created"}) + record.model_dump(exclude={"created"}) for record in records ] assert [ @@ -275,7 +275,7 @@ def test_action_start(tmpdir, method, session): task.log_running() getattr(task, method)() - records = list(map(lambda e: e.dict(), session.get_task_log())) + records = list(map(lambda e: e.model_dump(), session.get_task_log())) assert len(records) == 2 # First should not have "end" diff --git a/rocketry/test/task/func/test_run.py b/rocketry/test/task/func/test_run.py index 3411d85b..04281c33 100644 --- a/rocketry/test/task/func/test_run.py +++ b/rocketry/test/task/func/test_run.py @@ -73,7 +73,7 @@ def test_run(task_func, expected_outcome, exc_cls, execution, session): assert task.status == expected_outcome - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": expected_outcome}, @@ -119,7 +119,7 @@ def test_run_async(task_func, expected_outcome, execution, session): assert task.status == expected_outcome - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": expected_outcome}, @@ -261,7 +261,7 @@ def test_parametrization_runtime(session): task(params={"integer": 1, "string": "X", "optional_float": 1.1, "extra_parameter": "Should not be passed"}) - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, @@ -279,7 +279,7 @@ def test_parametrization_local(session): task() - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, @@ -297,7 +297,7 @@ def test_parametrization_kwargs(session): task() - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, diff --git a/rocketry/test/task/func/test_run_delayed.py b/rocketry/test/task/func/test_run_delayed.py index 904fe2f2..27a8bdf2 100644 --- a/rocketry/test/task/func/test_run_delayed.py +++ b/rocketry/test/task/func/test_run_delayed.py @@ -47,7 +47,7 @@ def test_run(tmpdir, script_files, script_path, expected_outcome, exc_cls, execu assert task.status == expected_outcome - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": expected_outcome}, @@ -72,7 +72,7 @@ def myfunc(): ) task() - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, @@ -102,7 +102,7 @@ def main(): ) task() - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, @@ -136,7 +136,7 @@ def main(): ) task() - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, @@ -166,7 +166,7 @@ def main(val_5, optional=None): ) task(params={"val_5":5}) - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, @@ -199,7 +199,7 @@ def main(val_5, optional=None): ) task(params={"val_5":5}) - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, @@ -219,7 +219,7 @@ def test_parametrization_runtime(tmpdir, script_files, session): task(params={"integer": 1, "string": "X", "optional_float": 1.1, "extra_parameter": "Should not be passed"}) - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, @@ -239,7 +239,7 @@ def test_parametrization_local(tmpdir, script_files, session): task() - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, @@ -259,7 +259,7 @@ def test_parametrization_kwargs(tmpdir, script_files, session): task() - records = list(map(lambda e: e.dict(exclude={'created'}), session.get_task_log())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), session.get_task_log())) assert [ {"task_name": "a task", "action": "run"}, {"task_name": "a task", "action": "success"}, diff --git a/rocketry/test/task/misc/test_restart.py b/rocketry/test/task/misc/test_restart.py index f4fc8982..713d83cd 100644 --- a/rocketry/test/task/misc/test_restart.py +++ b/rocketry/test/task/misc/test_restart.py @@ -36,7 +36,7 @@ def test_scheduler_restart(tmpdir, session): cont = f.read() assert "StartedShutStartedShut" == cont - records = list(map(lambda e: e.dict(exclude={'created'}), task.logger.get_records())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), task.logger.get_records())) assert 1 == len([record for record in records if record["action"] == "run"]) assert 1 == len([record for record in records if record["action"] == "success"]) diff --git a/rocketry/test/task/misc/test_shutdown.py b/rocketry/test/task/misc/test_shutdown.py index c02675a4..9bdb2d3d 100644 --- a/rocketry/test/task/misc/test_shutdown.py +++ b/rocketry/test/task/misc/test_shutdown.py @@ -33,6 +33,6 @@ def test_scheduler_shutdown(tmpdir, session): cont = f.read() assert "StartedShut" == cont - records = list(map(lambda e: e.dict(exclude={'created'}), task.logger.get_records())) + records = list(map(lambda e: e.model_dump(exclude={'created'}), task.logger.get_records())) assert 1 == len([record for record in records if record["action"] == "run"]) assert 1 == len([record for record in records if record["action"] == "success"]) diff --git a/rocketry/test/task/test_core.py b/rocketry/test/task/test_core.py index 8612eb4f..497c28e1 100644 --- a/rocketry/test/task/test_core.py +++ b/rocketry/test/task/test_core.py @@ -2,9 +2,12 @@ import logging import pickle from textwrap import dedent +from typing import ClassVar, Generic, Any +from pydantic import Field, BaseModel import pytest from rocketry.args.builtin import Return from rocketry.core import Task as BaseTask +from rocketry.tasks import _DummyTask from rocketry.core.condition.base import AlwaysFalse from rocketry.args import Arg, Session, Task from rocketry.exc import TaskLoggingError @@ -12,42 +15,38 @@ from rocketry import Session as SessionClass from rocketry.testing.log import create_task_record -class DummyTask(BaseTask): - - def execute(self, *args, **kwargs): - return - def test_defaults(session): - task = DummyTask(name="mytest", session=session) + + task = _DummyTask(name="mytest", session=session) assert task.name == "mytest" assert isinstance(task.start_cond, AlwaysFalse) assert isinstance(task.end_cond, AlwaysFalse) def test_defaults_no_session(session): with pytest.warns(UserWarning): - task = DummyTask(name="mytest") + task = _DummyTask(name="mytest") assert task.session is not session assert isinstance(task.session, SessionClass) assert task.session.tasks == {task} def test_set_timeout(session): - task = DummyTask(timeout="1 hour 20 min", session=session, name="1") + task = _DummyTask(session=session, timeout="1 hour 20 min", name="1") assert task.timeout == datetime.timedelta(hours=1, minutes=20) - task = DummyTask(timeout=datetime.timedelta(hours=1, minutes=20), session=session, name="2") + task = _DummyTask(timeout=datetime.timedelta(hours=1, minutes=20), session=session, name="2") assert task.timeout == datetime.timedelta(hours=1, minutes=20) - task = DummyTask(timeout=20, session=session, name="3") + task = _DummyTask(timeout=20, session=session, name="3") assert task.timeout == datetime.timedelta(seconds=20) def test_delete(session): - task = DummyTask(name="mytest", session=session) + task = _DummyTask(name="mytest", session=session) assert session.tasks == {task} task.delete() assert session.tasks == set() def test_set_invalid_status(session): - task = DummyTask(name="mytest", session=session) + task = _DummyTask(name="mytest", session=session) with pytest.raises(ValueError): task.status = "not valid" @@ -58,7 +57,7 @@ def emit(self, record): raise RuntimeError("Oops") logging.getLogger("rocketry.task").handlers.insert(0, MyHandler()) - task = DummyTask(name="mytest", session=session) + task = _DummyTask(name="mytest", session=session) for func in (task.log_crash, task.log_failure, task.log_success, task.log_inaction, task.log_termination): with pytest.raises(TaskLoggingError): func() @@ -69,13 +68,13 @@ def emit(self, record): task.log_record(record) # Used by process logging def test_pickle(session): - task_1 = DummyTask(name="mytest", session=session) + task_1 = _DummyTask(name="mytest", session=session) pkl_obj = pickle.dumps(task_1) task_2 = pickle.loads(pkl_obj) assert task_1.name == task_2.name def test_crash(session): - task = DummyTask(name="mytest", session=session) + task = _DummyTask(name="mytest", session=session) task.set_cached() task.log_running() assert task.status == "run" @@ -83,7 +82,7 @@ def test_crash(session): task.delete() # Recreating and now should log crash - task = DummyTask(name="mytest", session=session) + task = _DummyTask(name="mytest", session=session) task.set_cached() assert task.status == "crash" assert task.last_crash @@ -92,7 +91,7 @@ def test_crash(session): assert [ {'action': 'run', 'task_name': 'mytest'}, {'action': 'crash', 'task_name': 'mytest'} - ] == [log.dict(exclude={'created'}) for log in logs] + ] == [log.model_dump(exclude={'created'}) for log in logs] def test_json(session): session.parameters['x'] = 5 @@ -100,7 +99,7 @@ def test_json(session): repo.add(MinimalRecord(task_name="mytest", action="run", created=1640988000)) repo.add(MinimalRecord(task_name="mytest", action="success", created=1640988060)) - task = DummyTask(name="mytest", parameters={ + task = _DummyTask(name="mytest", parameters={ "arg_2": Arg("x"), "arg_2": Return("another"), "session": Session(), @@ -108,7 +107,10 @@ def test_json(session): "another_task": Task('another') }, session=session) task.set_cached() - j = task.json(indent=4) + # Deleting session from this test. Session is a random ID each time + # With pydantic changes it includes it in serialization + delattr(task, "session") + j = task.model_dump_json(indent=4) dt_run = datetime.datetime.fromtimestamp(1640988000) dt_success = datetime.datetime.fromtimestamp(1640988060) diff --git a/rocketry/test/test_hooks.py b/rocketry/test/test_hooks.py index 9d16d534..20ac2d22 100644 --- a/rocketry/test/test_hooks.py +++ b/rocketry/test/test_hooks.py @@ -1,5 +1,6 @@ from functools import partial from textwrap import dedent +from typing import ClassVar import sys import pytest @@ -7,7 +8,7 @@ from rocketry.core import Task, Scheduler from rocketry.session import Session -from rocketry.tasks import FuncTask +from rocketry.tasks import FuncTask, _DummyTask from rocketry.conditions import SchedulerCycles from rocketry.conds import true @@ -25,28 +26,24 @@ def test_task_init(session): @session.hook_task_init() def myhook(task=TaskArg()): timeline.append("Function hook called") - assert isinstance(task, DummyTask) + assert isinstance(task, _DummyTask) assert not hasattr(task, "name") # Should not yet have created this attr @session.hook_task_init() def mygenerhook(task=TaskArg()): timeline.append("Generator hook called (pre)") - assert isinstance(task, DummyTask) + assert isinstance(task, _DummyTask) assert not hasattr(task, "name") # Should not yet have created this attr yield assert hasattr(task, "session") # Should now have it timeline.append("Generator hook called (post)") - class DummyTask(Task): - - def execute(self, *args, **kwargs): - return assert session.hooks.task_init == [myhook, mygenerhook] # The func is in different namespace thus different timeline.append("Main") - mytask = DummyTask(name="dummy", session=session) + mytask = _DummyTask(name="dummy", session=session) assert timeline == [ "Main", "Function hook called", diff --git a/rocketry/utils/dependencies.py b/rocketry/utils/dependencies.py index cc012ed6..f7b7c594 100644 --- a/rocketry/utils/dependencies.py +++ b/rocketry/utils/dependencies.py @@ -1,6 +1,6 @@ from typing import List, Optional, Union -from pydantic import BaseModel +from pydantic import ConfigDict, BaseModel from rocketry.conditions import Any, All, DependFinish, DependSuccess from rocketry.conditions.task import DependFailure @@ -43,8 +43,7 @@ def __repr__(self): return f'Link({self.parent.name!r}, {self.child.name!r}, relation={getattr(self.relation, "__name__", None)}, type={getattr(self.type, "__name__", None)})' class Dependencies(BaseModel): - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) session: Session diff --git a/tox.ini b/tox.ini index 5710bf09..b6eb9dc0 100644 --- a/tox.ini +++ b/tox.ini @@ -63,4 +63,4 @@ deps = twine # install_command = pip install --upgrade build commands = python setup.py bdist_wheel sdist - twine upload -r testpypi dist/* \ No newline at end of file + twine upload -r testpypi dist/*