Skip to content

Commit

Permalink
Merge branch 'master' into c_interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
erikvansebille committed Jul 20, 2018
2 parents 456e1ad + 3e62b9c commit 915d1b5
Show file tree
Hide file tree
Showing 11 changed files with 519 additions and 63 deletions.
102 changes: 99 additions & 3 deletions parcels/codegenerator.py
@@ -1,4 +1,5 @@
from parcels.field import Field, VectorField
from parcels.fieldset import FieldList, VectorFieldList
from parcels.loggers import logger
import ast
import cgen as c
Expand All @@ -19,6 +20,12 @@ def __getattr__(self, attr):
if isinstance(getattr(self.obj, attr), Field):
return FieldNode(getattr(self.obj, attr),
ccode="%s->%s" % (self.ccode, attr))
elif isinstance(getattr(self.obj, attr), VectorFieldList):
return VectorFieldListNode(getattr(self.obj, attr),
ccode="%s->%s" % (self.ccode, attr))
elif isinstance(getattr(self.obj, attr), FieldList) or isinstance(getattr(self.obj, attr), list):
return FieldListNode(getattr(self.obj, attr),
ccode="%s->%s" % (self.ccode, attr))
elif isinstance(getattr(self.obj, attr), VectorField):
return VectorFieldNode(getattr(self.obj, attr),
ccode="%s->%s" % (self.ccode, attr))
Expand Down Expand Up @@ -53,6 +60,32 @@ def __init__(self, field, args, var, var2, var3):
self.var3 = var3 # third variable for UVW interpolation


class FieldListNode(IntrinsicNode):
def __getitem__(self, attr):
return FieldListEvalNode(self.obj, attr)


class FieldListEvalNode(IntrinsicNode):
def __init__(self, fields, args, var):
self.fields = fields
self.args = args
self.var = var # the variable in which the interpolated field is written


class VectorFieldListNode(IntrinsicNode):
def __getitem__(self, attr):
return VectorFieldListEvalNode(self.obj, attr)


class VectorFieldListEvalNode(IntrinsicNode):
def __init__(self, field, args, var, var2, var3):
self.field = field
self.args = args
self.var = var # the variable in which the interpolated field is written
self.var2 = var2 # second variable for UV interpolation
self.var3 = var3 # third variable for UVW interpolation


class ConstNode(IntrinsicNode):
def __getitem__(self, attr):
return attr
Expand Down Expand Up @@ -177,7 +210,24 @@ def visit_Subscript(self, node):

# If we encounter field evaluation we replace it with a
# temporary variable and put the evaluation call on the stack.
if isinstance(node.value, FieldNode):
if isinstance(node.value, FieldListNode):
tmp = [self.get_tmp() for _ in node.value.obj]
# Insert placeholder node for field eval ...
self.stmt_stack += [FieldListEvalNode(node.value, node.slice, tmp)]
# .. and return the name of the temporary that will be populated
return ast.Name(id='+'.join(tmp))
elif isinstance(node.value, VectorFieldListNode):
tmp = [self.get_tmp() for _ in node.value.obj.U]
tmp2 = [self.get_tmp() for _ in node.value.obj.U]
tmp3 = [self.get_tmp() if node.value.obj.W else None for _ in node.value.obj.U]
# Insert placeholder node for field eval ...
self.stmt_stack += [VectorFieldListEvalNode(node.value, node.slice, tmp, tmp2, tmp3)]
# .. and return the name of the temporary that will be populated
if all(tmp3):
return ast.Tuple([ast.Name(id='+'.join(tmp)), ast.Name(id='+'.join(tmp2)), ast.Name(id='+'.join(tmp3))], ast.Load())
else:
return ast.Tuple([ast.Name(id='+'.join(tmp)), ast.Name(id='+'.join(tmp2))], ast.Load())
elif isinstance(node.value, FieldNode):
tmp = self.get_tmp()
# Insert placeholder node for field eval ...
self.stmt_stack += [FieldEvalNode(node.value, node.slice, tmp)]
Expand Down Expand Up @@ -260,8 +310,6 @@ def __init__(self, fieldset, ptype):
self.fieldset = fieldset
self.ptype = ptype
self.field_args = OrderedDict()
# Hack alert: JIT requires U field to update fieldset indexes
self.field_args['U'] = fieldset.U
self.vector_field_args = OrderedDict()
self.const_args = OrderedDict()

Expand Down Expand Up @@ -567,10 +615,25 @@ def visit_FieldNode(self, node):
"""Record intrinsic fields used in kernel"""
self.field_args[node.obj.name] = node.obj

def visit_FieldListNode(self, node):
"""Record intrinsic fields used in kernel"""
for fld in node.obj:
self.field_args[fld.name] = fld

def visit_VectorFieldNode(self, node):
"""Record intrinsic fields used in kernel"""
self.vector_field_args[node.obj.name] = node.obj

def visit_VectorFieldListNode(self, node):
"""Record intrinsic fields used in kernel"""
for fld in node.obj.U:
self.field_args[fld.name] = fld
for fld in node.obj.V:
self.field_args[fld.name] = fld
if hasattr(node.obj, 'W') and node.obj.W:
for fld in node.obj.W:
self.field_args[fld.name] = fld

def visit_ConstNode(self, node):
self.const_args[node.ccode] = node.obj

Expand Down Expand Up @@ -603,6 +666,39 @@ def visit_VectorFieldEvalNode(self, node):
node.ccode = c.Block([c.Assign("err", ccode_eval),
conv_stat, c.Statement("CHECKERROR(err)")])

def visit_FieldListEvalNode(self, node):
self.visit(node.fields)
self.visit(node.args)
cstat = []
for fld, var in zip(node.fields.obj, node.var):
ccode_eval = fld.ccode_eval(var, *node.args.ccode)
ccode_conv = fld.ccode_convert(*node.args.ccode)
conv_stat = c.Statement("%s *= %s" % (var, ccode_conv))
cstat += [c.Assign("err", ccode_eval), conv_stat, c.Statement("CHECKERROR(err)")]
node.ccode = c.Block(cstat)

def visit_VectorFieldListEvalNode(self, node):
self.visit(node.field)
self.visit(node.args)
cstat = []
if node.field.obj.W:
Wlist = node.field.obj.W
else:
Wlist = [None] * len(node.field.obj.U)
for U, V, W, var, var2, var3 in zip(node.field.obj.U, node.field.obj.V, Wlist, node.var, node.var2, node.var3):
vfld = VectorField(node.field.obj.name, U, V, W)
ccode_eval = vfld.ccode_eval(var, var2, var3, U, V, W, *node.args.ccode)
ccode_conv1 = U.ccode_convert(*node.args.ccode)
ccode_conv2 = V.ccode_convert(*node.args.ccode)
statements = [c.Statement("%s *= %s" % (var, ccode_conv1)),
c.Statement("%s *= %s" % (var2, ccode_conv2))]
if var3:
ccode_conv3 = W.ccode_convert(*node.args.ccode)
statements.append(c.Statement("%s *= %s" % (var3, ccode_conv3)))
conv_stat = c.Block(statements)
cstat += [c.Assign("err", ccode_eval), conv_stat, c.Statement("CHECKERROR(err)")]
node.ccode = c.Block(cstat)

def visit_Return(self, node):
self.visit(node.value)
node.ccode = c.Statement('return %s' % node.value.ccode)
Expand Down
4 changes: 4 additions & 0 deletions parcels/examples/example_moving_eddies.py
Expand Up @@ -152,6 +152,10 @@ def test_moving_eddies_file(fieldsetfile, mode):
def test_periodic_and_computeTimeChunk_eddies(mode):
filename = path.join(path.dirname(__file__), 'MovingEddies_data', 'moving_eddies')
fieldset = FieldSet.from_parcels(filename)
fieldset.add_constant('halo_west', fieldset.U.grid.lon[0])
fieldset.add_constant('halo_east', fieldset.U.grid.lon[-1])
fieldset.add_constant('halo_south', fieldset.U.grid.lat[0])
fieldset.add_constant('halo_north', fieldset.U.grid.lat[-1])
fieldset.add_periodic_halo(zonal=True, meridional=True)
pset = ParticleSet.from_list(fieldset=fieldset,
pclass=ptype[mode],
Expand Down

0 comments on commit 915d1b5

Please sign in to comment.