-
Notifications
You must be signed in to change notification settings - Fork 42
/
customizers.py
206 lines (165 loc) · 6.36 KB
/
customizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
"""Functions to customize workflow steps."""
from __future__ import annotations
from copy import deepcopy
from functools import partial
from typing import TYPE_CHECKING, Literal
from quacc.utils.dicts import recursive_dict_merge
if TYPE_CHECKING:
from typing import Any, Callable
def strip_decorator(func: Callable) -> Callable:
    """
    Strip the decorators from a function.

    How a decorated object is unwrapped depends on the workflow engine named by
    `quacc.SETTINGS.WORKFLOW_ENGINE`. Engine-specific classes are imported inside
    the matching branch, so this function only imports modules belonging to the
    selected engine. If the engine is not one of the handled values, or `func`
    is not recognized as decorated for that engine, `func` is returned unchanged.

    Parameters
    ----------
    func
        The function to strip decorators from.

    Returns
    -------
    Callable
        The function with all decorators removed.
    """
    from quacc import SETTINGS

    if SETTINGS.WORKFLOW_ENGINE == "covalent":
        from covalent._workflow.lattice import Lattice

        # Objects with an `electron_object` attribute (presumably Covalent
        # electrons) expose the undecorated function on `.function`.
        if hasattr(func, "electron_object"):
            func = func.electron_object.function

        # A lattice keeps its workflow function in serialized form, so it must
        # be deserialized to recover the plain callable.
        if isinstance(func, Lattice):
            func = func.workflow_function.get_deserialized()

    elif SETTINGS.WORKFLOW_ENGINE == "dask":
        from dask.delayed import Delayed

        from quacc.wflow_tools.decorators import Delayed_

        # quacc's own `Delayed_` wrapper keeps the wrapped callable on `.func`.
        if isinstance(func, Delayed_):
            func = func.func

        # A Dask `Delayed` keeps the wrapped callable on `__wrapped__`.
        if isinstance(func, Delayed):
            func = func.__wrapped__
            if hasattr(func, "__wrapped__"):
                # Needed for custom `@subflow` decorator
                func = func.__wrapped__

    elif SETTINGS.WORKFLOW_ENGINE == "jobflow":
        # Jobflow-decorated callables keep the undecorated function on `.original`.
        if hasattr(func, "original"):
            func = func.original

    elif SETTINGS.WORKFLOW_ENGINE == "parsl":
        from parsl.app.python import PythonApp

        if isinstance(func, PythonApp):
            func = func.func

    elif SETTINGS.WORKFLOW_ENGINE == "prefect":
        from prefect import Flow as PrefectFlow
        from prefect import Task

        # Prefect tasks and flows both expose the undecorated function on `.fn`.
        if isinstance(func, (Task, PrefectFlow)):
            func = func.fn
        # Otherwise fall back to `__wrapped__`; presumably set via
        # `functools.wraps` by quacc's own wrappers — confirm.
        elif hasattr(func, "__wrapped__"):
            func = func.__wrapped__

    elif SETTINGS.WORKFLOW_ENGINE == "redun":
        from redun import Task

        if isinstance(func, Task):
            func = func.func

    return func
def redecorate(func: Callable, decorator: Callable | None) -> Callable:
    """
    Replace whatever decorator `func` carries with `decorator`.

    Parameters
    ----------
    func
        A function that may already be wrapped by a workflow-engine decorator.
    decorator
        The decorator to apply to the bare function. When `None`, the bare
        function is returned with no decorator at all.

    Returns
    -------
    Callable
        The bare function wrapped by `decorator`, or the bare function itself
        when `decorator` is `None`.
    """
    bare_func = strip_decorator(func)
    if decorator is None:
        return bare_func
    return decorator(bare_func)
def update_parameters(
    func: Callable,
    params: dict[str, Any],
    decorator: Literal["job", "flow", "subflow"] | None = "job",
) -> Callable:
    """
    Update the parameters of a (potentially decorated) function.

    The parameters are bound with `functools.partial`. When the workflow engine
    is Dask and `decorator` is set, the function is first stripped of its
    decorators. The parameters are then bound to the bare function, and the
    result is re-decorated with the quacc decorator named by `decorator`. For
    every other engine (or when `decorator` is falsy), the partial wraps `func`
    as-is.

    Parameters
    ----------
    func
        The function to update.
    params
        The parameters and associated values to update.
    decorator
        The quacc decorator associated with `func`: one of `"job"`, `"flow"`,
        or `"subflow"`. Only consulted when the workflow engine is Dask.

    Returns
    -------
    Callable
        The updated function.

    Raises
    ------
    ValueError
        If the workflow engine is Dask and `decorator` is a string other than
        `"job"`, `"flow"`, or `"subflow"`.
    """
    from quacc import SETTINGS, flow, job, subflow

    if decorator and SETTINGS.WORKFLOW_ENGINE == "dask":
        # NOTE(review): only Dask re-decorates the partial here; presumably a
        # bare `partial` around a Dask-delayed object would not be scheduled
        # as a task. Confirm against the Dask integration.
        if isinstance(decorator, str):
            named_decorators = {"job": job, "flow": flow, "subflow": subflow}
            # Validate before stripping so a typo fails with a clear message
            # instead of "'str' object is not callable".
            if decorator not in named_decorators:
                raise ValueError(
                    f"Invalid decorator: {decorator!r}. "
                    f"Valid options are: {list(named_decorators)}"
                )
            decorator = named_decorators[decorator]

        func = strip_decorator(func)
        return decorator(partial(func, **params))

    return partial(func, **params)
def customize_funcs(
    names: list[str] | str,
    funcs: list[Callable] | Callable,
    param_defaults: dict[str, dict[str, Any]] | None = None,
    param_swaps: dict[str, dict[str, Any]] | None = None,
    decorators: dict[str, Callable | None] | None = None,
) -> tuple[Callable, ...] | Callable:
    """
    Customize a set of functions with decorators and common parameters.

    Parameters
    ----------
    names
        The names of the functions to customize, in the order they should be returned.
    funcs
        The functions to customize, in the order they are described in `names`.
    param_defaults
        Default parameters to apply to each function. The keys of this dictionary correspond
        to the strings in `names`. If the key `"all"` is present, it will be applied to all
        functions. If the value is `None`, no custom parameters will be applied to that function.
    param_swaps
        User-overrides of parameters to apply to each function. The keys of this dictionary correspond
        to the strings in `names`. If the key `"all"` is present, it will be applied to all
        functions. If the value is `None`, no custom parameters will be applied to that function.
    decorators
        Custom decorators to apply to each function. The keys of this dictionary correspond
        to the strings in `names`. If the key `"all"` is present, it will be applied to all
        functions. If a value is `None`, the function is stripped of its decorators so that
        no decorator is applied to that function.

    Returns
    -------
    tuple[Callable, ...] | Callable
        The customized functions, returned in the same order as provided in `funcs`.
        A single function is returned directly rather than as a 1-tuple.

    Raises
    ------
    ValueError
        If `"all"` is used as a function name, if `names` and `funcs` differ in
        length, or if the decorator/parameter dictionaries contain keys that are
        neither in `names` nor `"all"`.
    """
    parameters = recursive_dict_merge(param_defaults, param_swaps)
    decorators = decorators or {}

    if not isinstance(names, (list, tuple)):
        names = [names]
    if not isinstance(funcs, (list, tuple)):
        funcs = [funcs]

    if "all" in names:
        raise ValueError("Invalid function name: 'all' is a reserved name.")
    if len(names) != len(funcs):
        raise ValueError(
            f"Got {len(names)} names but {len(funcs)} functions; "
            "each function needs exactly one name."
        )
    if bad_decorator_keys := [k for k in decorators if k not in names and k != "all"]:
        raise ValueError(
            f"Invalid decorator keys: {bad_decorator_keys}. Valid keys are: {names}"
        )
    if bad_parameter_keys := [k for k in parameters if k not in names and k != "all"]:
        raise ValueError(
            f"Invalid parameter keys: {bad_parameter_keys}. Valid keys are: {names}"
        )

    updated_funcs = []
    for name, func in zip(names, funcs):
        # Operate on a deep copy rather than on the caller's object.
        func_ = deepcopy(func)

        # Use membership rather than truthiness so that an explicit `None`
        # value reaches `redecorate`, which then strips the decorators as
        # documented. A truthiness check would silently keep the original
        # decorator.
        if "all" in decorators:
            func_ = redecorate(func_, decorators["all"])
        if name in decorators:
            func_ = redecorate(func_, decorators[name])

        # A `None` (or empty) parameter set means "no custom parameters",
        # so truthiness is the intended test here.
        if params := parameters.get("all"):
            func_ = update_parameters(func_, params)
        if params := parameters.get(name):
            func_ = update_parameters(func_, params)

        updated_funcs.append(func_)

    return updated_funcs[0] if len(updated_funcs) == 1 else tuple(updated_funcs)