Skip to content

Commit

Permalink
Implement work around to partially intitialise pydantic class variabl…
Browse files Browse the repository at this point in the history
…es to allow setting of self when setting _delayed_kwargs:O
  • Loading branch information
Jypear committed Jul 24, 2023
1 parent bbb5a37 commit 31f821f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions rocketry/tasks/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Callable, List, Optional
import warnings

from pydantic import Field, PrivateAttr, field_validator, field_serializer
from pydantic import Field, PrivateAttr, field_validator, field_serializer, BaseModel

from rocketry.core.task import Task
from rocketry.core.parameters import Parameters
Expand Down Expand Up @@ -128,7 +128,7 @@ def wrapper(*args, **kwargs):
def my_task_func():
...
"""
func: Optional[Callable] = Field(description="Executed function")
func: Optional[Callable] = Field(description="Executed function", default=None)

path: Optional[Path] = Field(description="Path to the script that is executed", default = None)
func_name: Optional[str] = Field(default="main", description="Name of the function in given path. Pass path as well")
Expand Down Expand Up @@ -171,6 +171,9 @@ def ser_func(self, func):
def __init__(self, func=None, **kwargs):
only_func_set = func is not None and not kwargs
no_func_set = func is None and kwargs.get('path') is None
from pydantic.main import _object_setattr
_object_setattr(self, "__pydantic_extra__", {})
_object_setattr(self, "__pydantic_private__", None)
if no_func_set:
# FuncTask was probably called like:
# @FuncTask(...)
Expand All @@ -179,7 +182,7 @@ def __init__(self, func=None, **kwargs):
# We initiate the class lazily by creating
# almost empty shell class that is populated
# in next __call__ (which should occur immediately)
FuncTask._delayed_kwargs = kwargs
self._delayed_kwargs = kwargs
return
if only_func_set:
# Most likely called as:
Expand All @@ -199,7 +202,7 @@ def __call__(self, *args, **kwargs):
func = args[0]
super().__init__(func=func, **self._delayed_kwargs)
self._set_descr(is_delayed=False)
FuncTask._delayed_kwargs = {}
self._delayed_kwargs = {}

# Note that we must return the function or
# we are in deep shit with multiprocessing
Expand Down

0 comments on commit 31f821f

Please sign in to comment.