# kernel.py
from parcels.codegenerator import KernelGenerator, LoopGenerator
from parcels.compiler import get_cache_dir
from parcels.kernels.error import ErrorCode, recovery_map as recovery_base_map
from parcels.field import FieldSamplingError
from parcels.loggers import logger
from parcels.kernels.advection import AdvectionRK4_3D
from os import path, remove
import numpy as np
import numpy.ctypeslib as npct
import time
from ctypes import c_int, c_float, c_double, c_void_p, byref
import _ctypes
from sys import platform
from ast import parse, FunctionDef, Module
import inspect
from copy import deepcopy
import re
from hashlib import md5
import math # noqa
import random # noqa
__all__ = ['Kernel']
re_indent = re.compile(r"^(\s+)")
def fix_indentation(string):
"""Fix indentation to allow in-lined kernel definitions"""
lines = string.split('\n')
indent = re_indent.match(lines[0])
if indent:
lines = [l.replace(indent.groups()[0], '', 1) for l in lines]
return "\n".join(lines)
class Kernel(object):
    """Kernel object that encapsulates auto-generated code.

    :arg fieldset: FieldSet object providing the field information
    :arg ptype: PType object for the kernel particle
    :arg pyfunc: (optional) Python function from which to derive the kernel
    :arg funcname: (optional) name of the kernel function
    :arg funccode: (optional) source code of the kernel function
    :arg py_ast: (optional) parsed AST of the kernel function
    :arg funcvars: (optional) list of local variable names used by the kernel
    :arg c_include: (optional) C code, or a path to a C file, to inline into
         the generated kernel code (JIT mode only)

    Note: A Kernel is either created from a compiled <function ...> object
    or the necessary information (funcname, funccode, funcvars) is provided.
    The py_ast argument may be derived from the code string, but for
    concatenation, the merged AST plus the new header definition is required.
    """

    def __init__(self, fieldset, ptype, pyfunc=None, funcname=None,
                 funccode=None, py_ast=None, funcvars=None, c_include=""):
        self.fieldset = fieldset
        self.ptype = ptype

        # Derive meta information from pyfunc, if not given explicitly
        self.funcname = funcname or pyfunc.__name__
        if pyfunc is AdvectionRK4_3D:
            logger.warning_once('Note that positive vertical velocity is assumed DOWNWARD by AdvectionRK4_3D')
        if funcvars is not None:
            self.funcvars = funcvars
        elif hasattr(pyfunc, '__code__'):
            self.funcvars = list(pyfunc.__code__.co_varnames)
        else:
            self.funcvars = None
        self.funccode = funccode or inspect.getsource(pyfunc.__code__)
        # Parse AST if it is not provided explicitly
        self.py_ast = py_ast or parse(fix_indentation(self.funccode)).body[0]
        if pyfunc is None:
            # Extract user context by inspecting the call stack, so that the
            # recompiled function can resolve names from the caller's module
            stack = inspect.stack()
            try:
                user_ctx = stack[-1][0].f_globals
                user_ctx['math'] = globals()['math']
                user_ctx['random'] = globals()['random']
                user_ctx['ErrorCode'] = globals()['ErrorCode']
            except Exception:
                # A bare `except:` here would also swallow KeyboardInterrupt
                # and SystemExit; fall back to this module's globals instead
                logger.warning("Could not access user context when merging kernels")
                user_ctx = globals()
            finally:
                del stack  # Remove cyclic references
            # Compile and generate Python function from AST.
            # type_ignores is a required Module field on Python >= 3.8; on
            # older versions the extra keyword is harmlessly stored.
            py_mod = Module(body=[self.py_ast], type_ignores=[])
            exec(compile(py_mod, "<ast>", "exec"), user_ctx)
            self.pyfunc = user_ctx[self.funcname]
        else:
            self.pyfunc = pyfunc
        self.name = "%s%s" % (ptype.name, self.funcname)

        # Generate the kernel function and add the outer loop (JIT mode only)
        if self.ptype.uses_jit:
            kernelgen = KernelGenerator(fieldset, ptype)
            kernel_ccode = kernelgen.generate(deepcopy(self.py_ast),
                                              self.funcvars)
            # generate() populates field_args as a side effect
            self.field_args = kernelgen.field_args
            if 'UV' in self.field_args:
                # Expand a composite UV field into its component fields
                fieldset = self.field_args['UV'].fieldset
                for f in ['U', 'V', 'cosU', 'sinU', 'cosV', 'sinV']:
                    if f not in self.field_args:
                        try:
                            self.field_args[f] = getattr(fieldset, f)
                        except AttributeError:
                            # Rotation fields are not present on all fieldsets
                            continue
                del self.field_args['UV']
            self.const_args = kernelgen.const_args
            loopgen = LoopGenerator(fieldset, ptype)
            if path.isfile(c_include):
                with open(c_include, 'r') as f:
                    c_include_str = f.read()
            else:
                c_include_str = c_include
            self.ccode = loopgen.generate(self.funcname, self.field_args, self.const_args,
                                          kernel_ccode, c_include_str)

            basename = path.join(get_cache_dir(), self._cache_key)
            self.src_file = "%s.c" % basename
            self.lib_file = "%s.%s" % (basename, 'dll' if platform == 'win32' else 'so')
            self.log_file = "%s.log" % basename
        self._lib = None

    def _unload_lib(self):
        """Free the handle of the loaded dynamic library, if any.

        Uses getattr so this is safe to call from __del__ even when
        __init__ did not complete (i.e. _lib was never assigned).
        """
        if getattr(self, '_lib', None) is not None:
            if platform == 'win32':
                _ctypes.FreeLibrary(self._lib._handle)
            else:
                _ctypes.dlclose(self._lib._handle)
            del self._lib
            self._lib = None

    def _remove_lib_files(self):
        """Delete the generated source, library and log files, if present."""
        if path.isfile(getattr(self, 'lib_file', '')):
            for fname in [self.src_file, self.lib_file, self.log_file]:
                # Guard each file individually; e.g. the log file may be
                # missing even when the library file exists
                if path.isfile(fname):
                    remove(fname)

    def __del__(self):
        # Clean up the in-memory dynamic linked library. This is not really
        # necessary, as these programs are not that large, but with the random
        # naming scheme (required on Windows OS'es to deal with updates to a
        # Parcels' kernel) the cache would otherwise fill with stale files.
        self._unload_lib()
        self._remove_lib_files()

    @property
    def _cache_key(self):
        """Return a unique, time-salted md5 key identifying this kernel build."""
        field_keys = "-".join(["%s:%s" % (name, field.units.__class__.__name__)
                               for name, field in self.field_args.items()])
        # Salting with the current time forces a fresh file name per build
        key = self.name + self.ptype._cache_key + field_keys + ('TIME:%f' % time.time())
        return md5(key.encode('utf-8')).hexdigest()

    def remove_lib(self):
        """Unload the current dynamic library and generate fresh file names.

        Pulling new names is necessary on a Windows machine, because Python's
        ctypes does not deal well with re-loading an updated dynamic linked
        library under the same path on this OS.
        """
        self._unload_lib()
        self._remove_lib_files()
        basename = path.join(get_cache_dir(), self._cache_key)
        self.src_file = "%s.c" % basename
        self.lib_file = "%s.%s" % (basename, 'dll' if platform == 'win32' else 'so')
        self.log_file = "%s.log" % basename

    def compile(self, compiler):
        """Write kernel code to file and compile it.

        :param compiler: compiler object providing a
            ``compile(src, lib, log)`` method
        """
        with open(self.src_file, 'w') as f:
            f.write(self.ccode)
        compiler.compile(self.src_file, self.lib_file, self.log_file)
        logger.info("Compiled %s ==> %s" % (self.name, self.lib_file))

    def load_lib(self):
        """Load the compiled shared library and bind its particle loop."""
        self._lib = npct.load_library(self.lib_file, '.')
        self._function = self._lib.particle_loop

    def execute_jit(self, pset, endtime, dt):
        """Invokes JIT engine to perform the core update loop"""
        for g in pset.fieldset.gridset.grids:
            g.cstruct = None  # This forces the grids to be re-pointed from Python to C
        # Make a copy of the transposed array to enforce
        # C-contiguous memory layout for JIT mode.
        for f in self.field_args.values():
            if not f.data.flags.c_contiguous:
                f.data = f.data.copy()
        for g in pset.fieldset.gridset.grids:
            if not g.depth.flags.c_contiguous:
                g.depth = g.depth.copy()
            if not g.lon.flags.c_contiguous:
                g.lon = g.lon.copy()
            if not g.lat.flags.c_contiguous:
                g.lat = g.lat.copy()
        fargs = [byref(f.ctypes_struct) for f in self.field_args.values()]
        fargs += [c_float(f) for f in self.const_args.values()]
        particle_data = pset._particle_data.ctypes.data_as(c_void_p)
        self._function(c_int(len(pset)), particle_data,
                       c_double(endtime), c_float(dt), *fargs)

    def execute_python(self, pset, endtime, dt):
        """Performs the core update loop via Python"""
        sign_dt = np.sign(dt)
        for p in pset.particles:
            # Don't execute particles that aren't started yet
            sign_end_part = np.sign(endtime - p.time)
            if (sign_end_part != sign_dt) and (dt != 0):
                continue

            # Compute min/max dt for first timestep
            dt_pos = min(abs(p.dt), abs(endtime - p.time))
            while dt_pos > 1e-6 or dt == 0:
                try:
                    res = self.pyfunc(p, pset.fieldset, p.time, sign_dt * dt_pos)
                except FieldSamplingError as fse:
                    res = ErrorCode.ErrorOutOfBounds
                    p.exception = fse
                except Exception as e:
                    res = ErrorCode.Error
                    p.exception = e

                # Update particle state for explicit returns
                if res is not None:
                    p.state = res

                # Handle particle time and time loop
                if res is None or res == ErrorCode.Success:
                    # Update time and repeat
                    p.time += sign_dt * dt_pos
                    dt_pos = min(abs(p.dt), abs(endtime - p.time))
                    if dt == 0:
                        break
                    continue
                elif res == ErrorCode.Repeat:
                    # Try again without time update
                    dt_pos = min(abs(p.dt), abs(endtime - p.time))
                    continue
                else:
                    break  # Failure - stop time loop

    def execute(self, pset, endtime, dt, recovery=None, output_file=None):
        """Execute this Kernel over a ParticleSet for several timesteps.

        :param pset: ParticleSet to advance
        :param endtime: end time of the integration window
        :param dt: timestep (sign determines integration direction)
        :param recovery: optional dict mapping ErrorCode to recovery kernels,
            merged over the default recovery map
        :param output_file: optional output file used to write deleted
            particles before removal
        """
        def remove_deleted(pset):
            """Utility to remove all particles that signalled deletion"""
            indices = [i for i, p in enumerate(pset.particles)
                       if p.state in [ErrorCode.Delete]]
            if len(indices) > 0 and output_file is not None:
                output_file.write(pset[indices], endtime, deleted_only=True)
            pset.remove(indices)

        if recovery is None:
            recovery = {}
        recovery_map = recovery_base_map.copy()
        recovery_map.update(recovery)

        # Execute the kernel over the particle set
        if self.ptype.uses_jit:
            self.execute_jit(pset, endtime, dt)
        else:
            self.execute_python(pset, endtime, dt)

        # Remove all particles that signalled deletion
        remove_deleted(pset)

        # Identify particles that threw errors
        error_particles = [p for p in pset.particles
                           if p.state not in [ErrorCode.Success, ErrorCode.Repeat]]
        while len(error_particles) > 0:
            # Apply recovery kernel
            for p in error_particles:
                recovery_kernel = recovery_map[p.state]
                p.state = ErrorCode.Success
                recovery_kernel(p, self.fieldset, p.time, dt)

            # Remove all particles that signalled deletion
            remove_deleted(pset)

            # Execute core loop again to continue interrupted particles
            if self.ptype.uses_jit:
                self.execute_jit(pset, endtime, dt)
            else:
                self.execute_python(pset, endtime, dt)

            error_particles = [p for p in pset.particles
                               if p.state not in [ErrorCode.Success, ErrorCode.Repeat]]

    def merge(self, kernel):
        """Concatenate this kernel with another, returning a new Kernel.

        The merged kernel's AST is this kernel's body followed by the other's,
        under a combined function name.
        """
        funcname = self.funcname + kernel.funcname
        func_ast = FunctionDef(name=funcname, args=self.py_ast.args,
                               body=self.py_ast.body + kernel.py_ast.body,
                               decorator_list=[], lineno=1, col_offset=0)
        return Kernel(self.fieldset, self.ptype, pyfunc=None,
                      funcname=funcname, funccode=self.funccode + kernel.funccode,
                      py_ast=func_ast, funcvars=self.funcvars + kernel.funcvars)

    def __add__(self, kernel):
        if not isinstance(kernel, Kernel):
            kernel = Kernel(self.fieldset, self.ptype, pyfunc=kernel)
        return self.merge(kernel)

    def __radd__(self, kernel):
        if not isinstance(kernel, Kernel):
            kernel = Kernel(self.fieldset, self.ptype, pyfunc=kernel)
        return kernel.merge(self)