Skip to content

Commit

Permalink
Merged in martinal/topic-nested-mixed-split (pull request #58)
Browse files Browse the repository at this point in the history
Attempt at simple extension of split to nested mixed spaces
  • Loading branch information
Martin Sandve Alnæs committed Oct 12, 2016
2 parents eca3520 + 8b655dd commit b1c77f9
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 87 deletions.
41 changes: 30 additions & 11 deletions test/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ def test_split(self):
r = TensorElement("CG", cell, 1, symmetry={(1, 0): (0, 1)}, shape=(d, d))
m = MixedElement(f, v, w, t, s, r)

# Shapes of all these functions are correct:
# Check that shapes of all these functions are correct:
assert () == Coefficient(f).ufl_shape
self.assertEqual((d,), Coefficient(v).ufl_shape)
self.assertEqual((d+1,), Coefficient(w).ufl_shape)
self.assertEqual((d, d), Coefficient(t).ufl_shape)
self.assertEqual((d, d), Coefficient(s).ufl_shape)
self.assertEqual((d, d), Coefficient(r).ufl_shape)
self.assertEqual((3*d*d + 2*d + 2,), Coefficient(m).ufl_shape) # sum of value sizes, not accounting for symmetries
assert (d,) == Coefficient(v).ufl_shape
assert (d+1,) == Coefficient(w).ufl_shape
assert (d, d) == Coefficient(t).ufl_shape
assert (d, d) == Coefficient(s).ufl_shape
assert (d, d) == Coefficient(r).ufl_shape
# sum of value sizes, not accounting for symmetries:
assert (3*d*d + 2*d + 2,) == Coefficient(m).ufl_shape

# Shapes of subelements are reproduced:
g = Coefficient(m)
Expand All @@ -38,10 +39,28 @@ def test_split(self):
s -= product(g2.ufl_shape)
assert s == 0

# TODO: Should functions on mixed elements (vector+vector) be able to have tensor shape instead of vector shape? Think Marie wants this for BDM+BDM?
# Mixed elements of non-scalar subelements are flattened
v2 = MixedElement(v, v)
m2 = MixedElement(t, t)
# assert d == 2
# self.assertEqual((2,2), Coefficient(v2).ufl_shape)
self.assertEqual((d+d,), Coefficient(v2).ufl_shape)
self.assertEqual((2*d*d,), Coefficient(m2).ufl_shape)
# assert (2,2) == Coefficient(v2).ufl_shape
assert (d+d,) == Coefficient(v2).ufl_shape
assert (2*d*d,) == Coefficient(m2).ufl_shape

# Split twice on nested mixed elements gets
# the innermost scalar subcomponents
t = TestFunction(f*v)
assert split(t) == (t[0], as_vector((t[1], t[2])))
assert split(split(t)[1]) == (t[1], t[2])
t = TestFunction(f*(f*v))
assert split(t) == (t[0], as_vector((t[1], t[2], t[3])))
assert split(split(t)[1]) == (t[1], as_vector((t[2], t[3])))
t = TestFunction((v*f)*(f*v))
assert split(t) == (as_vector((t[0], t[1], t[2])),
as_vector((t[3], t[4], t[5])))
assert split(split(t)[0]) == (as_vector((t[0], t[1])), t[2])
assert split(split(t)[1]) == (t[3], as_vector((t[4], t[5])))
assert split(split(split(t)[0])[0]) == (t[0], t[1])
assert split(split(split(t)[0])[1]) == (t[2],)
assert split(split(split(t)[1])[0]) == (t[3],)
assert split(split(split(t)[1])[1]) == (t[4], t[5])
145 changes: 69 additions & 76 deletions ufl/split_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,108 +24,101 @@

from ufl.log import error
from ufl.utils.sequences import product
from ufl.utils.dicts import EmptyDict
from ufl.finiteelement import MixedElement, TensorElement
from ufl.tensors import as_vector, as_matrix
from ufl.tensors import as_vector, as_matrix, ListTensor
from ufl.indexed import Indexed
from ufl.permutation import compute_indices
from ufl.utils.indexflattening import flatten_multiindex, shape_to_strides


def split(v):
"""UFL operator: If v is a Coefficient or Argument in a mixed space, returns
a tuple with the function components corresponding to the subelements."""

# Default range is all of v
begin = 0
end = None

if isinstance(v, Indexed):
# Special case: split previous output of split again
# Consistent with simple element, just return function in a tuple
return (v,)

elif isinstance(v, ListTensor):
# Special case: split previous output of split again
ops = v.ufl_operands
if all(isinstance(comp, Indexed) for comp in ops):
args = [comp.ufl_operands[0] for comp in ops]
if all(args[0] == args[i] for i in range(1, len(args))):
# Get innermost terminal here and its element
v = args[0]
# Get relevant range of v components
begin, = ops[0].ufl_operands[1]
end, = ops[-1].ufl_operands[1]
begin = int(begin)
end = int(end) + 1
else:
error("Don't know how to split %s." % (v,))
else:
error("Don't know how to split %s." % (v,))

# Special case: simple element, just return function in a tuple
element = v.ufl_element()
if not isinstance(element, MixedElement):
assert end is None
return (v,)

if isinstance(element, TensorElement):
s = element.symmetry()
if s:
# FIXME: How should this be defined? Should we return one
# subfunction for each value component or only for those
# not mapped to another? I think split should ignore the
# symmetry.
if element.symmetry():
error("Split not implemented for symmetric tensor elements.")

# Compute value size
value_size = product(element.value_shape())
actual_value_size = value_size
if len(v.ufl_shape) != 1:
error("Don't know how to split tensor valued mixed functions without flattened index space.")

# Extract sub coefficient
offset = 0
# Compute value size and set default range end
value_size = product(element.value_shape())
if end is None:
end = value_size
else:
# Recursively dive into mixedelement in to subelement
# corresponding to beginning of range
j = begin
while True:
sub_i, j = element.extract_subelement_component(j)
element = element.sub_elements()[sub_i]
# Then break when we find the subelement that covers the whole range
if product(element.value_shape()) == (end - begin):
break

# Build expressions representing the subfunction of v for each subelement
offset = begin
sub_functions = []
for i, e in enumerate(element.sub_elements()):
# Get shape, size, indices, and v components
# corresponding to subelement value
shape = e.value_shape()
strides = shape_to_strides(shape)
rank = len(shape)
sub_size = product(shape)
subindices = [flatten_multiindex(c, strides)
for c in compute_indices(shape)]
components = [v[k + offset] for k in subindices]

# Shape components into same shape as subelement
if rank == 0:
# This subelement is a scalar, always maps to a single
# value
subv = v[offset]
offset += 1

elif rank == 1:
# This subelement is a vector, always maps to a sequence of values
sub_size, = shape
components = [v[j] for j in range(offset, offset + sub_size)]
subv, = components
elif rank <= 1:
subv = as_vector(components)
offset += sub_size

elif rank == 2:
# This subelement is a tensor, possibly with symmetries,
# slightly more complicated...

# Size of this subvalue
sub_size = product(shape)

# If this subelement is a symmetric element, subtract
# symmetric components
s = None
if isinstance(e, TensorElement):
s = e.symmetry()
s = s or EmptyDict
# If we do this, we must fix the size computation in
# MixedElement.__init__ as well
# actual_value_size -= len(s)
# sub_size -= len(s)
# print s
# Build list of lists of value components
components = []
for ii in range(shape[0]):
row = []
for jj in range(shape[1]):
# Map component (i,j) through symmetry mapping
c = (ii, jj)
c = s.get(c, c)
i, j = c
# Extract component c of this subvalue from global tensor v
if len(v.ufl_shape) == 1:
# Mapping into a flattened vector
k = offset + i*shape[1] + j
component = v[k]
elif len(v.ufl_shape) == 2:
# Mapping into a concatenated tensor (is this
# a figment of my imagination?)
error("Not implemented.")
row_offset, col_offset = 0, 0 # TODO
k = (row_offset + i, col_offset + j)
component = v[k]
row.append(component)
components.append(row)

# Make a matrix of the components
subv = as_matrix(components)
offset += sub_size

subv = as_matrix([components[i*shape[1]: (i+1)*shape[1]]
for i in range(shape[0])])
else:
# TODO: Handle rank > 2? Or is there such a thing?
error("Don't know how to split functions with sub functions of rank %d (yet)." % rank)
# for indices in compute_indices(shape):
# #k = offset + sum(i*s for (i,s) in izip(indices, shape[1:] + (1,)))
# vs.append(v[indices])
error("Don't know how to split functions with sub functions of rank %d." % rank)

offset += sub_size
sub_functions.append(subv)

if actual_value_size != offset:
error("Logic breach in function splitting.")
if end != offset:
error("Function splitting failed to extract components for whole intended range. Something is wrong.")

return tuple(sub_functions)

0 comments on commit b1c77f9

Please sign in to comment.