-
Notifications
You must be signed in to change notification settings - Fork 44
/
aiida_utils.py
222 lines (171 loc) · 7.18 KB
/
aiida_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""Utilities for working with aiida in general"""
from functools import wraps
import numpy as np
from packaging import version
def load_dbenv_if_not_loaded(**kwargs):
    """
    Load the aiida database environment unless it is already loaded.

    :param kwargs: forwarded unchanged to ``aiida.backends.utils.load_dbenv``.
    """
    from aiida.backends.utils import load_dbenv, is_dbenv_loaded
    # Guard so that repeated calls are no-ops once the environment is loaded.
    if not is_dbenv_loaded():
        load_dbenv(**kwargs)
def dbenv(function):
    """
    Decorate ``function`` so that the aiida dbenv is loaded before each call.

    All positional and keyword arguments are passed through unchanged, and the
    wrapped function's metadata is preserved via ``functools.wraps``.
    """
    def _with_dbenv(*args, **kwargs):
        """Ensure the dbenv is loaded, then delegate to the wrapped function."""
        load_dbenv_if_not_loaded()
        return function(*args, **kwargs)

    return wraps(function)(_with_dbenv)
def get_data_node(data_type, *args, **kwargs):
    """
    Instantiate the aiida data class registered under ``data_type``.

    :param data_type: name resolved to a class through ``get_data_class``.
    :param args: positional arguments for the data class constructor.
    :param kwargs: keyword arguments for the data class constructor.
    :return: the newly constructed data node.
    """
    node_cls = get_data_class(data_type)
    return node_cls(*args, **kwargs)
@dbenv
def get_data_class(data_type):
    """
    Provide access to the orm.data classes with deferred dbenv loading.

    Compatibility: also provide access to the ``orm.data.base`` members
    (the names in ``BASIC_DATA_TYPES``), which are loadable through the
    DataFactory as of 1.0.0-alpha only.

    :param data_type: the DataFactory entry point name, e.g. ``'parameter'``.
    :return: the corresponding data class.
    :raises MissingPluginError: if ``data_type`` is unknown to the DataFactory
        and is not one of ``BASIC_DATA_TYPES``.
    """
    from aiida.orm import DataFactory
    from aiida.common.exceptions import MissingPluginError
    try:
        return DataFactory(data_type)
    except MissingPluginError:
        if data_type not in BASIC_DATA_TYPES:
            # Bare raise keeps the original exception and traceback intact.
            raise
        # Older aiida: the basic types live in aiida.orm.data.base instead.
        return get_basic_data_pre_1_0(data_type)
# Basic data types that the DataFactory only exposes from aiida 1.0.0-alpha on;
# for older versions get_data_class falls back to get_basic_data_pre_1_0.
BASIC_DATA_TYPES = {'bool', 'float', 'int', 'list', 'str'}
@dbenv
def get_basic_data_pre_1_0(data_type):
    """
    Return a basic data class from ``aiida.orm.data.base`` (pre-1.0 aiida).

    :param data_type: lower-case type name such as ``'int'``; it is
        capitalized to obtain the class name (``'int'`` -> ``Int``).
    :return: the matching class from ``aiida.orm.data.base``.
    """
    from aiida.orm.data import base as base_data
    class_name = data_type.capitalize()
    return getattr(base_data, class_name)
@dbenv
def backend_obj_users():
    """
    Check whether this aiida version accesses users through a backend object.

    The legacy API exposes ``aiida.backends.utils.get_automatic_user``; when it
    cannot be imported, users are assumed to live on the backend object.

    :return: True if ``get_automatic_user`` cannot be imported, else False.
    """
    try:
        from aiida.backends.utils import get_automatic_user  # pylint: disable=unused-variable,no-name-in-module
    except ImportError:
        return True
    return False
@dbenv
def get_current_user():
    """
    Return the automatic (current) aiida user.

    Works both with the backend-object API and with the legacy
    ``get_automatic_user`` helper of aiida-core <= 0.12.1.
    """
    if not backend_obj_users():
        # Legacy API: a module-level helper returns the user directly.
        from aiida.backends.utils import get_automatic_user  # pylint: disable=no-name-in-module
        return get_automatic_user()
    from aiida.orm.backend import construct_backend  # pylint: disable=no-name-in-module
    return construct_backend().users.get_automatic_user()
def builder_interface(calc_cls):
    """
    Tell whether ``calc_cls`` supports the process-builder interface.

    :param calc_cls: a calculation/process class (any object is accepted).
    :return: True if ``calc_cls`` has a ``get_builder`` attribute, else False.
    """
    return hasattr(calc_cls, 'get_builder')
def copy_parameter(old_parameter):
    """
    Build a new ParameterData from the dictionary of ``old_parameter``.

    :param old_parameter: a node providing ``get_dict()``.
    :return: a new ``'parameter'`` data node created with that dictionary.
    """
    params_dict = old_parameter.get_dict()
    return get_data_node('parameter', dict=params_dict)
def displaced_structure(structure, displacement, entry):
    """
    Clone ``structure`` and displace one site in the clone.

    The displacement is applied with ``displace_position`` to the object
    returned by ``structure.clone()``, which is then returned.
    """
    new_structure = structure.clone()
    displace_position(new_structure, displacement, entry)
    return new_structure
def compressed_structure(structure, volume_change):
    """
    Clone ``structure`` and rescale the cell of the clone.

    The scaling is applied with ``compress_cell`` to the object returned by
    ``structure.clone()``, which is then returned.
    """
    new_structure = structure.clone()
    compress_cell(new_structure, volume_change)
    return new_structure
def displace_position(structure, displacement, entry):
    """
    Displace a single site position of a StructureData in place.

    :param structure: the structure to modify; all site positions are written
        back through ``reset_sites_positions``.
    :param displacement: added (numpy broadcasting) to the selected position.
    :param entry: 1-based index of the site to displace.
    :raises IndexError: if ``entry`` is not between 1 and the number of sites.
    """
    positions = [site.position for site in structure.sites]
    # entry is 1-based; without this check entry <= 0 would silently wrap
    # around and displace a site counted from the end of the list.
    if not 1 <= entry <= len(positions):
        raise IndexError('entry must be between 1 and {}, got {}'.format(len(positions), entry))
    index = entry - 1
    new_position = np.asarray(positions[index]) + displacement
    positions[index] = tuple(new_position.tolist())
    structure.reset_sites_positions(positions)
def compress_cell(structure, volume_change):
    """
    Rescale the unit cell of ``structure`` in place by a uniform factor.

    Every component of the cell matrix is multiplied by ``volume_change``.
    Despite the name, this is therefore a *linear* scale factor: for a 3x3
    cell the volume changes by ``volume_change ** 3``. Factors below 1 shrink
    the cell, factors above 1 expand it. Atomic positions are not touched here.

    :param structure: object with a ``cell`` attribute and a ``reset_cell``
        method (e.g. a StructureData).
    :param volume_change: factor broadcast against the cell matrix (normally
        a scalar).
    """
    scaled_cell = np.asarray(structure.cell) * volume_change
    structure.reset_cell(scaled_cell.tolist())
def aiida_version():
    """Return the installed aiida-core version, parsed for comparisons."""
    from aiida import __version__ as installed
    parsed = version.parse(installed)
    return parsed
def cmp_version(string):
    """
    Parse a version string into an object comparable with ``aiida_version()``.

    :param string: a version string such as ``'1.0.0a1'``.
    """
    parsed = version.parse(string)
    return parsed
def cmp_load_verdi_data():
    """
    Load the ``verdi data`` click command group for any aiida version since 0.11.

    The group has lived in different modules across aiida versions, so each
    known location is tried in turn and the first successful import is used.

    :return: the ``verdi data`` command group.
    :raises ImportError: if none of the known locations provides it; the
        message lists every individual import error.
    """
    verdi_data = None
    import_errors = []
    try:
        from aiida.cmdline.commands import data_cmd as verdi_data
    except ImportError as err:
        import_errors.append(err)
    # Test against None, not truthiness: a successfully imported object must
    # never be treated as "not found".
    if verdi_data is None:
        try:
            from aiida.cmdline.commands import verdi_data
        except ImportError as err:
            import_errors.append(err)
    if verdi_data is None:
        try:
            from aiida.cmdline.commands.cmd_data import verdi_data
        except ImportError as err:
            import_errors.append(err)
    if verdi_data is None:
        err_messages = '\n'.join(' * {}'.format(exc) for exc in import_errors)
        raise ImportError('The verdi data base command group could not be found:\n' + err_messages)
    return verdi_data
@dbenv
def create_authinfo(computer, store=False):
    """
    Create an authinfo entry that lets the current user use ``computer``.

    Backwards compatible down to aiida 0.11: when ``aiida.orm.backend`` offers
    ``construct_backend`` the backend's ``authinfos`` collection is used;
    otherwise the backend-specific ``DbAuthInfo`` model is built directly.

    :param computer: the computer to authorize the current user for.
    :param store: if True, persist the new authinfo before returning it.
    :return: the new authinfo. On the legacy path it is None when the
        configured backend is neither Django nor SQLAlchemy (in which case
        ``store=True`` fails on the None value).
    """
    from aiida.orm import backend as orm_backend

    if hasattr(orm_backend, 'construct_backend'):
        new_authinfo = orm_backend.construct_backend().authinfos.create(computer=computer, user=get_current_user())
        if store:
            new_authinfo.store()
        return new_authinfo

    # Legacy path: instantiate the ORM model of the configured backend.
    from aiida.backends.settings import BACKEND
    from aiida.backends.profile import BACKEND_SQLA, BACKEND_DJANGO
    new_authinfo = None
    if BACKEND == BACKEND_DJANGO:
        from aiida.backends.djsite.db.models import DbAuthInfo
        new_authinfo = DbAuthInfo(dbcomputer=computer.dbcomputer, aiidauser=get_current_user())
    elif BACKEND == BACKEND_SQLA:
        from aiida.backends.sqlalchemy.models.authinfo import DbAuthInfo
        from aiida.backends.sqlalchemy import get_scoped_session
        # Result unused; called for its side effect (presumably ensuring a
        # scoped session exists before the model is built - confirm).
        get_scoped_session()
        new_authinfo = DbAuthInfo(dbcomputer=computer.dbcomputer, aiidauser=get_current_user())
    if store:
        new_authinfo.save()
    return new_authinfo
@dbenv
def cmp_get_authinfo(computer):
    """
    Get the existing authinfo for ``computer`` and the current user.

    If ``computer`` provides ``get_authinfo`` (newer aiida), its result - or
    whatever it raises for a missing entry - is passed through unchanged. On
    the legacy Django/SQLAlchemy backends the ``DbAuthInfo`` model is queried
    directly and None is returned when no entry matches.

    :param computer: the computer to look up.
    :return: the authinfo, or None if it does not exist (legacy backends) or
        the configured legacy backend is not recognized.
    """
    if hasattr(computer, 'get_authinfo'):
        return computer.get_authinfo(get_current_user())

    from aiida.backends.settings import BACKEND
    from aiida.backends.profile import BACKEND_SQLA, BACKEND_DJANGO
    if BACKEND == BACKEND_DJANGO:
        from aiida.backends.djsite.db.models import DbAuthInfo
        try:
            return DbAuthInfo.objects.get(dbcomputer=computer.dbcomputer, aiidauser=get_current_user())  # pylint: disable=no-member
        except DbAuthInfo.DoesNotExist:  # pylint: disable=no-member
            # Honour the documented "or None" contract instead of raising.
            return None
    if BACKEND == BACKEND_SQLA:
        from aiida.backends.sqlalchemy.models.authinfo import DbAuthInfo
        from aiida.backends.sqlalchemy import get_scoped_session
        session = get_scoped_session()
        # .first() executes the query and yields an authinfo or None; the
        # previous code returned the unexecuted Query object itself.
        return session.query(DbAuthInfo).filter(DbAuthInfo.dbcomputer == computer.dbcomputer).filter(
            DbAuthInfo.aiidauser == get_current_user()).first()
    return None
@dbenv
def cmp_get_transport(computer):
    """
    Return a transport for ``computer``, compatible across aiida versions.

    Uses ``computer.get_transport()`` when available; otherwise obtains the
    transport from the current user's authinfo for the computer.

    :param computer: the computer to connect to.
    :return: the transport object.
    :raises NotExistent: if the legacy lookup finds no authinfo.
    """
    if hasattr(computer, 'get_transport'):
        return computer.get_transport()
    authinfo = cmp_get_authinfo(computer)
    if authinfo is None:
        # Fail with a clear message instead of an AttributeError on None.
        from aiida.common.exceptions import NotExistent
        raise NotExistent('No authinfo found for computer {} and the current user'.format(computer))
    return authinfo.get_transport()