Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for PydanticV2 #212

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,7 @@ docs/_build/
.tox/

# Coverage
cov_data/
cov_data/

#Custom
test.py
8 changes: 6 additions & 2 deletions rocketry/_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, ClassVar
from pydantic.dataclasses import dataclass, Field
from pydantic import BaseModel

if TYPE_CHECKING:
from rocketry import Session

class RedBase:
"""Baseclass for all Rocketry classes"""
session: 'Session' = None

# Commented this out for now as it was causing issues with the new pydantic implementation
session: 'Session'
18 changes: 14 additions & 4 deletions rocketry/_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from rocketry.session import Session, Config
from rocketry.parse import add_condition_parser
from rocketry.conds import true, false
from rocketry.tasks import CommandTask, FuncTask, CodeTask
from rocketry.tasks import CommandTask, FuncTask, CodeTask, _DummyTask
from rocketry.tasks.maintain import ShutDown, Restart

from rocketry.conditions.meta import _FuncTaskCondWrapper
Expand All @@ -23,14 +23,24 @@ def _setup_defaults():
cls_tasks = (
Task,
FuncTask, CommandTask, CodeTask,
ShutDown, Restart,
ShutDown, Restart, _DummyTask,

_FuncTaskCondWrapper
)
for cls_task in cls_tasks:
cls_task.update_forward_refs(Session=Session, BaseCondition=BaseCondition)
#cls_task.update_forward_refs(Session=Session, BaseCondition=BaseCondition)
cls_task.model_rebuild(
force=True,
_types_namespace={"Session": Session, "BaseCondition": BaseCondition},
_parent_namespace_depth=4
)

Config.update_forward_refs(BaseCondition=BaseCondition)
# Config.update_forward_refs(BaseCondition=BaseCondition)
Config.model_rebuild(
force=True,
_types_namespace={"Session": Session, "BaseCondition": BaseCondition},
_parent_namespace_depth=4
)
#Session.update_forward_refs(
# Task=Task, Parameters=Parameters, Scheduler=Scheduler
#)
2 changes: 1 addition & 1 deletion rocketry/conditions/meta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import copy
from typing import Callable, Optional, Pattern, Union
from typing import Callable, ClassVar, Optional, Pattern, Union

from pydantic import Field
from rocketry.args import Session
Expand Down
92 changes: 51 additions & 41 deletions rocketry/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import threading
from queue import Empty
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Dict, Type, Union, Tuple, Optional
from typing_extensions import Annotated
try:
from typing import Literal
except ImportError: # pragma: no cover
from typing_extensions import Literal

from pydantic import BaseModel, Field, PrivateAttr, validator
from pydantic import BaseModel, Field, PrivateAttr, ConfigDict, field_validator, field_serializer

from rocketry._base import RedBase
from rocketry.core.condition import BaseCondition, AlwaysFalse, All
Expand Down Expand Up @@ -94,7 +95,7 @@ def is_async(self) -> bool:
def is_thread(self) -> bool:
return isinstance(self.task, threading.Thread)

class Task(RedBase, BaseModel):
class Task(BaseModel, RedBase):
"""Base class for Tasks.

A task can be a function, command or other procedure that
Expand Down Expand Up @@ -192,42 +193,37 @@ class Task(RedBase, BaseModel):
... return ...

"""
class Config:
arbitrary_types_allowed= True
underscore_attrs_are_private = True
validate_assignment = True
json_encoders = {
Parameters: lambda v: v.to_json(),
'BaseCondition': lambda v: str(v),
FunctionType: lambda v: v.__name__,
'Session': lambda v: id(v),
}

model_config = ConfigDict(
arbitrary_types_allowed= True,
validate_assignment = True,
extra='allow',
)

session: 'Session' = Field()
session: 'Session' = Field(default=None, validate_default=False)


# Class
permanent: bool = False # Whether the task is not meant to finish (Ie. RestAPI)
_actions: ClassVar[Tuple] = ("run", "fail", "success", "inaction", "terminate", None, "crash")
fmt_log_message: str = r"Task '{task}' status: '{action}'"

daemon: Optional[bool]
daemon: Optional[bool] = None
batches: List[Parameters] = Field(
default_factory=list,
description="Run batches (parameters). If not empty, run is triggered regardless of starting condition"
)

# Instance
name: Optional[str] = Field(description="Name of the task. Must be unique")
description: Optional[str] = Field(description="Description of the task for documentation")
logger_name: Optional[str] = Field(description="Logger name to be used in logging the task records")
execution: Optional[Literal['main', 'async', 'thread', 'process']]
name: Optional[str] = Field(description="Name of the task. Must be unique", default=None)
description: Optional[str] = Field(description="Description of the task for documentation", default=None)
logger_name: Optional[str] = Field(description="Logger name to be used in logging the task records", default="rocketry.task")
execution: Optional[Literal['main', 'async', 'thread', 'process']] = None
priority: int = 0
disabled: bool = False
force_run: bool = False
force_termination: bool = False
status: Optional[Literal['run', 'fail', 'success', 'terminate', 'inaction', 'crash']] = Field(description="Latest status of the task")
timeout: Optional[datetime.timedelta]
status: Optional[Literal['run', 'fail', 'success', 'terminate', 'inaction', 'crash']] = Field(description="Latest status of the task", default=None)
timeout: Optional[datetime.timedelta] = None

parameters: Parameters = Parameters()

Expand All @@ -237,7 +233,7 @@ class Config:
multilaunch: Optional[bool] = None
on_startup: bool = False
on_shutdown: bool = False
func_run_id: Callable = None
func_run_id: Union[Callable, None] = None

_last_run: Optional[float]
_last_success: Optional[float]
Expand All @@ -252,29 +248,29 @@ class Config:

_mark_running = False

@validator('start_cond', pre=True)
@field_validator('start_cond', mode="before")
def parse_start_cond(cls, value, values):
from rocketry.parse.condition import parse_condition
session = values['session']
session = values.data['session']
if isinstance(value, str):
value = parse_condition(value, session=session)
elif value is None:
value = AlwaysFalse()
return copy(value)

@validator('end_cond', pre=True)
@field_validator('end_cond', mode="before")
def parse_end_cond(cls, value, values):
from rocketry.parse.condition import parse_condition
session = values['session']
session = values.data['session']
if isinstance(value, str):
value = parse_condition(value, session=session)
elif value is None:
value = AlwaysFalse()
return copy(value)

@validator('logger_name', pre=True, always=True)
@field_validator('logger_name', mode="before")
def parse_logger_name(cls, value, values):
session = values['session']
session = values.data['session']

if isinstance(value, str):
logger_name = value
Expand All @@ -287,7 +283,7 @@ def parse_logger_name(cls, value, values):
raise ValueError(f"Logger name must start with '{basename}' as session finds loggers with names")
return logger_name

@validator('timeout', pre=True, always=True)
@field_validator('timeout', mode="before")
def parse_timeout(cls, value, values):
if value == "never":
return datetime.timedelta.max
Expand All @@ -296,6 +292,22 @@ def parse_timeout(cls, value, values):
if value is not None:
return to_timedelta(value)
return value

@field_serializer("parameters", when_used="json")
def ser_parameters(self, parameters):
return parameters.to_json()

@field_serializer("start_cond", when_used="json")
def ser_start_cond(self, start_cond):
return str(start_cond)

@field_serializer("end_cond", when_used="json")
def ser_end_cond(self, end_cond):
return str(end_cond)

@field_serializer("session", when_used="json", check_fields=False)
def ser_session(self, session):
return id(session)

@property
def logger(self):
Expand Down Expand Up @@ -339,9 +351,9 @@ def _get_name(self, name=None, **kwargs):
return self.get_default_name(**kwargs)
return name

@validator('name', pre=True)
@field_validator('name', mode="before")
def parse_name(cls, value, values):
session = values['session']
session = values.data['session']
on_exists = session.config.task_pre_exist
name_exists = value in session
if name_exists:
Expand All @@ -359,9 +371,9 @@ def parse_name(cls, value, values):
return name
return value

@validator('name', pre=False)
@field_validator('name', mode="after")
def validate_name(cls, value, values):
session = values['session']
session = values.data['session']
on_exists = session.config.task_pre_exist
name_exists = value in session

Expand All @@ -371,17 +383,17 @@ def validate_name(cls, value, values):
raise ValueError(f"Task name '{value}' already exists. Please pick another")
return value

@validator('parameters', pre=True)
@field_validator('parameters', mode="before")
def parse_parameters(cls, value):
if isinstance(value, Parameters):
return value
return Parameters(value)

@validator('force_run', pre=False)
@field_validator('force_run', mode="after")
def parse_force_run(cls, value, values):
if value:
warnings.warn("Attribute 'force_run' is deprecated. Please use method set_running() instead", DeprecationWarning)
values['batches'].append(Parameters())
values.data['batches'].append(Parameters())
return value

def __hash__(self):
Expand Down Expand Up @@ -731,7 +743,6 @@ def run_as_process(self, params:Parameters, direct_params:Parameters, task_run:T

self._run_stack.append(task_run)
self._mark_running = True # needed in pickling

process.start()
self._mark_running = False

Expand All @@ -751,7 +762,6 @@ def _run_as_process(self, params:Parameters, direct_params:Parameters, task_run,
# in the actual multiprocessing's process. We only add QueueHandler to the
# logger (with multiprocessing.Queue as queue) so that all the logging
# records end up in the main process to be logged properly.

basename = self.logger_name
# handler = logging.handlers.QueueHandler(queue)
handler = QueueHandler(queue)
Expand Down Expand Up @@ -1294,8 +1304,8 @@ def __getstate__(self):
#state['__dict__'] = state['__dict__'].copy()

# remove unpicklable
state['__private_attribute_values__'] = state['__private_attribute_values__'].copy()
priv_attrs = state['__private_attribute_values__']
state['__pydantic_private__'] = state['__pydantic_private__'].copy()
priv_attrs = state['__pydantic_private__']
Jypear marked this conversation as resolved.
Show resolved Hide resolved
priv_attrs['_lock'] = None
priv_attrs['_process'] = None
priv_attrs['_thread'] = None
Expand Down Expand Up @@ -1404,5 +1414,5 @@ def json(self, **kwargs):
if 'exclude' not in kwargs:
kwargs['exclude'] = set()
kwargs['exclude'].update({'session'})
d = super().json(**kwargs)
d = super().model_dump_json(**kwargs)
return d
25 changes: 14 additions & 11 deletions rocketry/log/log_record.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
from typing import Optional
from pydantic import BaseModel, Field, validator
from pydantic import field_validator, BaseModel, Field

from rocketry.pybox.time import to_datetime, to_timedelta

Expand Down Expand Up @@ -38,36 +38,39 @@ class LogRecord(MinimalRecord):

class TaskLogRecord(MinimalRecord):

start: Optional[datetime.datetime]
end: Optional[datetime.datetime]
runtime: Optional[datetime.timedelta]
start: Optional[datetime.datetime] = None
end: Optional[datetime.datetime] = None
runtime: Optional[datetime.timedelta] = None

message: str
exc_text: Optional[str]
exc_text: Optional[str] = None

@validator("start", pre=True)
@field_validator("start", mode="before")
@classmethod
def format_start(cls, value):
if value is not None:
value = to_datetime(value)
return value

@validator("end", pre=True)
@field_validator("end", mode="before")
@classmethod
def format_end(cls, value):
if value is not None:
value = to_datetime(value)
return value

@validator("runtime", pre=True)
@field_validator("runtime", mode="before")
@classmethod
def format_runtime(cls, value):
if value is not None:
value = to_timedelta(value)
return value

class MinimalRunRecord(MinimalRecord):
run_id: Optional[str]
run_id: Optional[str] = None

class RunRecord(LogRecord):
run_id: Optional[str]
run_id: Optional[str] = None

class TaskRunRecord(TaskLogRecord):
run_id: Optional[str]
run_id: Optional[str] = None
Loading