/
pw.py
237 lines (196 loc) · 9.93 KB
/
pw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# -*- coding: utf-8 -*-
"""Protocol definitions for workflow input generation."""
import json
import os
from copy import deepcopy
def _load_pseudo_metadata(filename):
"""Load from the current folder a json file containing metadata (incl.
suggested cutoffs) for a library of pseudopotentials.
"""
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)) as handle:
return json.load(handle)
def _get_all_protocol_modifiers():
    """Return the information on all possible modifiers for all known protocols.

    It is a function so that the JSON files with the pseudopotential metadata are only
    loaded lazily. Every call builds a brand-new dictionary.

    :return: a dictionary mapping each protocol name (``'theos-ht-1.0'`` and ``'testing'``)
        to a dictionary with the keys ``'pseudo'``, ``'pseudo_default'``, ``'parameters'``
        and ``'parameters_default'``.
    """
    protocols = {
        'theos-ht-1.0': {
            'pseudo': {
                # SSSP Efficiency & Precision v1.0, see https://www.materialscloud.org/archive/2018.0001/v2
                'SSSP-efficiency-1.0': _load_pseudo_metadata('sssp_efficiency_1.0.json'),
                'SSSP-precision-1.0': _load_pseudo_metadata('sssp_precision_1.0.json'),
                # SSSP Efficiency & Precision v1.1, see https://www.materialscloud.org/archive/2018.0001/v3
                'SSSP-efficiency-1.1': _load_pseudo_metadata('sssp_efficiency_1.1.json'),
                'SSSP-precision-1.1': _load_pseudo_metadata('sssp_precision_1.1.json'),
            },
            'pseudo_default': 'SSSP-efficiency-1.1',
            'parameters': {
                'fast': {
                    'kpoints_mesh_offset': [0., 0., 0.],
                    'kpoints_mesh_density': 0.2,
                    'kpoints_distance_for_bands': 0.02,
                    'convergence_threshold_per_atom': 2.E-06,
                    'smearing': 'marzari-vanderbilt',
                    'degauss': 0.02,
                    'occupations': 'smearing',
                    'meta_convergence': False,
                    'volume_convergence': 0.01,
                    'tstress': True,
                    'tprnfor': True,
                    'num_bands_factor': None,  # number of bands wrt number of occupied bands
                },
                'default': {
                    'kpoints_mesh_offset': [0., 0., 0.],
                    'kpoints_mesh_density': 0.2,
                    'kpoints_distance_for_bands': 0.01,
                    'convergence_threshold_per_atom': 1.E-10,
                    'smearing': 'marzari-vanderbilt',
                    'degauss': 0.02,
                    'occupations': 'smearing',
                    'meta_convergence': True,
                    'volume_convergence': 0.01,
                    'tstress': True,
                    'tprnfor': True,
                    'num_bands_factor': None,  # number of bands wrt number of occupied bands
                },
            },
            'parameters_default': 'default'
        }
    }

    theos_parameters = protocols['theos-ht-1.0']['parameters']
    # 'scdm' is 'default' with a fixed bands factor. It must be an independent copy:
    # a plain assignment would alias the two dicts, and setting ``num_bands_factor``
    # below would then silently change the 'default' parameters as well.
    theos_parameters['scdm'] = deepcopy(theos_parameters['default'])
    theos_parameters['scdm']['num_bands_factor'] = 3.0

    # A protocol for testing purposes: larger 'kpoints_mesh_density' for 'fast'
    # (presumably a k-point spacing, i.e. a coarser mesh) and halved cutoffs
    # for the SSSP-efficiency-1.1 library; it defaults to the 'fast' parameters.
    testing = deepcopy(protocols['theos-ht-1.0'])
    testing['parameters']['fast']['kpoints_mesh_density'] = 0.3
    testing['parameters_default'] = 'fast'
    for element_data in testing['pseudo']['SSSP-efficiency-1.1'].values():
        element_data['cutoff'] = element_data['cutoff'] / 2
    protocols['testing'] = testing

    return protocols
class ProtocolManager:
    """A class to manage calculation protocols.

    A protocol groups named sets of calculation parameters (the ``'parameters'``
    modifiers) and named pseudopotential libraries (the ``'pseudo'`` modifiers),
    optionally with a default name for each category (``'parameters_default'``
    and ``'pseudo_default'``).
    """

    def __init__(self, name):
        """Initialize a protocol instance.

        :param name: the protocol name, e.g. ``'theos-ht-1.0'`` or ``'testing'``.
        :raises ValueError: if no protocol with this name is defined.
        """
        self.name = name
        all_protocols = _get_all_protocol_modifiers()
        # Only the lookup belongs in the ``try``: a KeyError raised while building the
        # protocols must not be reported as an unknown protocol name.
        try:
            self.modifiers = all_protocols[name]
        except KeyError:
            raise ValueError("Unknown protocol '{}'".format(name))

    def get_protocol_data(self, modifiers=None):
        """Return the full info on the specific protocol, using the (optional) modifiers.

        :param modifiers: should be a dictionary with (optional) keys 'parameters' and 'pseudo', and
            whose value is the modifier name for that category.
            If the key-value pair is not specified, the default for the protocol will be used.
            In this case, if no default is specified, a ValueError is thrown.
        :return: a new dictionary containing a deep copy of the selected parameters plus a
            ``'pseudo_data'`` key with the selected pseudopotential data (not copied).
            Modifying the parameters in the result does not alter this manager.
        :raises ValueError: if a modifier name is missing and has no default, if 'custom' is used
            without 'pseudo_data', or if ``modifiers`` contains unknown keys.

        .. note:: If you pass 'custom' as the modifier name for 'pseudo',
            then you have to pass an additional key, called 'pseudo_data', that will be
            used to populate the output.
        """
        if modifiers is None:
            modifiers = {}
        # Pop the known keys from a copy so the caller's dictionary is left untouched;
        # whatever remains at the end is an unknown modifier.
        modifiers_copy = modifiers.copy()
        parameters_modifier_name = modifiers_copy.pop('parameters', self.get_default_parameters_modifier_name())
        pseudo_modifier_name = modifiers_copy.pop('pseudo', self.get_default_pseudo_modifier_name())

        if parameters_modifier_name is None:
            raise ValueError(
                "You did not specify a modifier name for 'parameters', but no default "
                "modifier name exists for protocol '{}'.".format(self.name)
            )
        if pseudo_modifier_name is None:
            raise ValueError(
                "You did not specify a modifier name for 'pseudo', but no default "
                "modifier name exists for protocol '{}'.".format(self.name)
            )

        if pseudo_modifier_name == 'custom':
            try:
                pseudo_data = modifiers_copy.pop('pseudo_data')
            except KeyError:
                raise ValueError(
                    "You specified 'custom' as a modifier name for 'pseudo', but you did not provide "
                    "a 'pseudo_data' key."
                )
        else:
            pseudo_data = self.get_pseudo_data(pseudo_modifier_name)

        # Check that there are no unknown modifiers
        if modifiers_copy:
            raise ValueError('Unknown modifiers specified: {}'.format(','.join(sorted(modifiers_copy))))

        # Deep-copy: ``get_parameters_data`` returns the dict stored in this manager, and adding
        # 'pseudo_data' to it (or letting callers mutate it) would corrupt the protocol definition.
        retdata = deepcopy(self.get_parameters_data(parameters_modifier_name))
        retdata['pseudo_data'] = pseudo_data
        return retdata

    def get_parameters_modifier_names(self):
        """Get all valid parameters modifier names."""
        return list(self.modifiers['parameters'].keys())

    def get_default_parameters_modifier_name(self):  # pylint: disable=invalid-name
        """Return the default parameter modifier name (or None if no default is specified)."""
        return self.modifiers.get('parameters_default', None)

    def get_parameters_data(self, modifier_name):
        """Given a parameter modifier name, return a dictionary of data associated to it.

        The dictionary stored in this manager is returned as is (not a copy).

        :raises KeyError: if ``modifier_name`` is not a known parameters modifier.
        """
        return self.modifiers['parameters'][modifier_name]

    def get_pseudo_modifier_names(self):
        """Get all valid pseudopotential modifier names."""
        return list(self.modifiers['pseudo'].keys())

    def get_default_pseudo_modifier_name(self):  # pylint: disable=invalid-name
        """Return the default pseudopotential modifier name (or None if no default is specified)."""
        return self.modifiers.get('pseudo_default', None)

    def get_pseudo_data(self, modifier_name):
        """Given a pseudo modifier name, return the ``pseudo_data`` associated to it.

        :raises KeyError: if ``modifier_name`` is not a known pseudo modifier.
        """
        return self.modifiers['pseudo'][modifier_name]

    def check_pseudos(self, modifier_name=None, pseudo_data=None):
        """Given a pseudo modifier name, checks which pseudos exist in the DB.

        Each pseudo is looked up by the ``'md5'`` value of its entry in the pseudo data.

        :param modifier_name: the name of the modifier. Leave to None to use the default one.
        :param pseudo_data: should be passed only if modifier_name == 'custom'
        :return: a dictionary with three keys:

            - ``missing``: a set of element names that are not in the DB
            - ``found``: a dictionary with key-value: ``{element_name: uuid}`` for the pseudos that were found
            - ``mismatch``: a dictionary with key-value: ``{element_name: [list-of-elements-found]}`` for those
              pseudos for which one (or more) pseudos were found with the same MD5, but associated to different
              elements (listed in `list-of-elements-found`)
        :raises ValueError: if no modifier name is given and there is no default, or if ``pseudo_data``
            is passed/omitted inconsistently with ``modifier_name == 'custom'``.
        """
        from aiida.orm import QueryBuilder
        from aiida.plugins import DataFactory

        UpfData = DataFactory('upf')

        if modifier_name is None:
            modifier_name = self.get_default_pseudo_modifier_name()
            if modifier_name is None:
                raise ValueError(
                    'You did not specify a modifier name, but no default '
                    "modifier name exists for protocol '{}'.".format(self.name)
                )

        if modifier_name == 'custom':
            if pseudo_data is None:
                raise ValueError("You chose 'custom' as modifier_name, but did not provide a " 'pseudo_data!')
        else:
            if pseudo_data is not None:
                raise ValueError("You passed a pseudo_data, but the modifier name is not 'custom'!")
            pseudo_data = self.get_pseudo_data(modifier_name)

        missing = set()  # No pseudo with this MD5 in the DB
        found = {}  # Pseudo with this MD5 found, with the right element
        mismatch = {}  # Pseudo(s) with this MD5 found, but only for other elements

        for element, this_pseudo_data in pseudo_data.items():
            builder = QueryBuilder()
            builder.append(
                UpfData, filters={'attributes.md5': this_pseudo_data['md5']}, project=['uuid', 'attributes.element']
            )
            matches = builder.all()
            if not matches:
                missing.add(element)
                continue

            mismatched_elements = []
            for this_uuid, this_element in matches:
                if this_element == element:
                    found[element] = this_uuid
                    break
                mismatched_elements.append(this_element)
            else:
                # No match had the expected element: record which elements were found instead.
                mismatch[element] = mismatched_elements

        return {'missing': missing, 'found': found, 'mismatch': mismatch}
if __name__ == '__main__':
    # Manual smoke check: report which pseudopotentials of the default library are
    # present in the database, then print the default protocol data.
    manager = ProtocolManager('theos-ht-1.0')
    pseudo_report = manager.check_pseudos()
    print(pseudo_report)
    default_protocol_data = manager.get_protocol_data()
    print(default_protocol_data)