-
Notifications
You must be signed in to change notification settings - Fork 41
/
dicts.py
254 lines (207 loc) · 6.34 KB
/
dicts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""Utility functions for dealing with dictionaries."""
from __future__ import annotations
import logging
from collections.abc import MutableMapping
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING
from monty.json import jsanitize
from monty.serialization import dumpfn
from quacc.wflow_tools.db import results_to_db
if TYPE_CHECKING:
    # Imported only for static type checking; the annotations in this module are
    # strings (``from __future__ import annotations``), so these are not needed at runtime.
    from typing import Any

    from maggma.stores import Store

# Module-level logger; _recursive_dict_pair_merge uses it to report overwritten keys
# when called with verbose=True.
LOGGER = logging.getLogger(__name__)
class Remove:
    """
    Sentinel marking a dictionary entry for deletion in quacc.

    Use the class object itself (``Remove``, not ``Remove()``) as a value. The
    merge utilities in this module compare values against it by identity and drop
    the matching entries. `None` is not used for this purpose because it is a
    legitimate value for many keyword arguments.
    """

    def __init__(self) -> None:
        # The class is only ever used as a marker object; creating instances is a mistake.
        message = "Remove is a sentinel class and should not be instantiated."
        raise NotImplementedError(message)
def recursive_dict_merge(
    *dicts: MutableMapping[str, Any] | None,
    remove_trigger: Any = Remove,
    verbose: bool = False,
) -> MutableMapping[str, Any]:
    """
    Recursively merge several dictionaries, taking the latter in the list as higher
    preference. Also removes any entries that have a value of `remove_trigger` from the
    final dictionary. If a `None` is provided, it is assumed to be `{}`.

    This function should be used instead of the | operator when merging nested dictionaries,
    e.g. `{"a": {"b": 1}} | {"a": {"c": 2}}` will return `{"a": {"c": 2}}` whereas
    `recursive_dict_merge({"a": {"b": 1}}, {"a": {"c": 2}})` will return `{"a": {"b": 1, "c": 2}}`.

    Parameters
    ----------
    *dicts
        Dictionaries to merge, in increasing order of precedence. Any of them may be
        `None`. A single dictionary yields a cleaned copy of it; no dictionaries
        yields `{}`.
    remove_trigger
        Value that triggers removal of the entry (compared by identity).
    verbose
        Whether to log warnings when overwriting keys

    Returns
    -------
    MutableMapping[str, Any]
        Merged dictionary
    """
    if not dicts:
        return {}

    merged = dicts[0]
    # Fold each later dictionary into the accumulator. `_recursive_dict_pair_merge`
    # copies its first argument before writing into it, so no extra copy of the
    # accumulator is needed between steps. With a single input there is nothing to
    # merge, but pairing it with `None` still turns `None` into `{}` and returns a
    # copy instead of raising on an unbound result.
    for other in dicts[1:] or (None,):
        merged = _recursive_dict_pair_merge(merged, other, verbose=verbose)

    return remove_dict_entries(merged, remove_trigger=remove_trigger)
def _recursive_dict_pair_merge(
    dict1: MutableMapping[str, Any] | None,
    dict2: MutableMapping[str, Any] | None,
    verbose: bool = False,
) -> MutableMapping[str, Any]:
    """
    Merge `dict2` into a copy of `dict1`, descending into nested mappings.

    Values from `dict2` win on conflicts. The exception is when both sides hold a
    mapping for the same key: those two mappings are merged recursively instead.
    A `None` argument is treated as `{}`. The merge is performed on a copy of
    `dict1` made with `safe_dict_copy`.

    Parameters
    ----------
    dict1
        Lower-priority dictionary (may be `None`).
    dict2
        Higher-priority dictionary (may be `None`).
    verbose
        If True, log a warning whenever an existing key is overwritten rather than
        recursively merged.

    Returns
    -------
    dict
        The merged dictionary.
    """

    def _as_mapping(candidate):
        # Non-empty inputs pass through untouched. An empty mapping is swapped for a
        # fresh instance of its own class, and `None` becomes a plain dict.
        if candidate:
            return candidate
        if candidate is None:
            return {}
        return candidate.__class__()

    lower = _as_mapping(dict1)
    higher = _as_mapping(dict2)

    result = safe_dict_copy(lower)
    for key, incoming in higher.items():
        if key not in result:
            result[key] = incoming
            continue

        current = result[key]
        if isinstance(current, MutableMapping) and isinstance(incoming, MutableMapping):
            # Both sides are mappings: merge them rather than replacing wholesale.
            result[key] = _recursive_dict_pair_merge(current, incoming, verbose=verbose)
            continue

        result[key] = incoming
        if verbose:
            LOGGER.warning(f"Overwriting key '{key}' to: '{result[key]}'")

    return result
def safe_dict_copy(d: dict) -> dict:
    """
    Copy a dictionary, preferring a deep copy.

    If `copy.deepcopy` raises for any reason (for example, a value that cannot be
    deep-copied), fall back to the dictionary's own shallow `.copy()`.

    Parameters
    ----------
    d
        Dictionary to copy

    Returns
    -------
    dict
        A deep copy of `d`, or a shallow copy if deep copying failed
    """
    try:
        duplicate = deepcopy(d)
    except Exception:
        # Best-effort fallback: nested values end up shared with the original.
        duplicate = d.copy()
    return duplicate
def remove_dict_entries(
    start_dict: dict[str, Any], remove_trigger: Any
) -> dict[str, Any]:
    """
    Recursively drop every mapping entry whose value is `remove_trigger`.

    Mappings are rebuilt as plain dicts without the matching entries, and lists are
    rebuilt element by element so that mappings nested inside them are cleaned too.
    Matching uses identity (`is`). List elements that are themselves the trigger
    are kept; only mapping entries are removed. Any other object is returned as-is.

    Parameters
    ----------
    start_dict
        Dictionary (or nested value) to clean
    remove_trigger
        Value that triggers removal of the entry

    Returns
    -------
    dict
        Cleaned dictionary
    """
    if isinstance(start_dict, MutableMapping):
        cleaned = {}
        for key, value in start_dict.items():
            if value is remove_trigger:
                continue
            cleaned[key] = remove_dict_entries(value, remove_trigger)
        return cleaned

    if isinstance(start_dict, list):
        return [remove_dict_entries(item, remove_trigger) for item in start_dict]

    return start_dict
def sort_dict(start_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Return a new dict whose keys are in sorted order, recursing into nested mappings.

    Values that are mappings are sorted the same way. All other values, including
    lists, are carried over unchanged.

    Parameters
    ----------
    start_dict
        Dictionary to sort

    Returns
    -------
    dict
        Sorted dictionary
    """
    ordered: dict[str, Any] = {}
    for key, value in sorted(start_dict.items()):
        if isinstance(value, MutableMapping):
            value = sort_dict(value)
        ordered[key] = value
    return ordered
def clean_dict(start_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Tidy a task document dictionary.

    First every entry whose value is `None` is removed, recursively. Then the
    result's keys are sorted, recursively.

    Parameters
    ----------
    start_dict
        Dictionary to clean

    Returns
    -------
    dict
        Cleaned dictionary
    """
    without_none = remove_dict_entries(start_dict, None)
    return sort_dict(without_none)
def finalize_dict(
    task_doc: dict,
    directory: str | Path | None,
    gzip_file: bool = True,
    store: Store | None = None,
) -> dict:
    """
    Finalize a schema by cleaning it and storing it in a database and/or file.

    The task document is cleaned with `clean_dict`: entries whose value is `None`
    are removed and keys are sorted. That same cleaned document is written to disk,
    inserted into the database, and returned, so all three are consistent.

    Parameters
    ----------
    task_doc
        Dictionary representation of the task document.
    directory
        Directory where the results file is stored. If falsy, no file is written.
    gzip_file
        Whether to gzip the results file.
    store
        Maggma Store object to store the results in. If falsy, nothing is stored.

    Returns
    -------
    dict
        Cleaned task document

    Raises
    ------
    ValueError
        If the string form of `directory` contains "tmp-quacc" (a temporary directory).
    """
    cleaned_task_doc = clean_dict(task_doc)

    if directory:
        # Guard against writing results into a scratch directory; this is a plain
        # substring check on the path.
        if "tmp-quacc" in str(directory):
            raise ValueError("The directory should not be a temporary directory.")
        sanitized_schema = jsanitize(
            cleaned_task_doc, enum_values=True, recursive_msonable=True
        )
        filename = "quacc_results.json.gz" if gzip_file else "quacc_results.json"
        dumpfn(sanitized_schema, Path(directory, filename), fmt="json", indent=4)

    if store:
        # Store the cleaned document (not the raw input) so the database record
        # matches both the results file and the returned value.
        results_to_db(store, cleaned_task_doc)

    return cleaned_task_doc