Skip to content
This repository has been archived by the owner on Feb 21, 2022. It is now read-only.

Commit

Permalink
Merged in miklos1/fix-defaultdict (pull request #37)
Browse files Browse the repository at this point in the history
Do not expose defaultdict in entity_dofs()

Approved-by: Lawrence Mitchell <wence@gmx.li>
  • Loading branch information
miklos1 committed Jun 2, 2017
2 parents d29555a + 8a66d43 commit d4812f9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion FIAT/enriched.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, *elements):

# set up entity_ids - for each geometric entity, just concatenate
# the entities of the constituent elements
entity_ids = concatenate_entity_dofs(elements)
entity_ids = concatenate_entity_dofs(ref_el, elements)

# set up dual basis - just concatenation
nodes = list(chain.from_iterable(e.dual_basis() for e in elements))
Expand Down
15 changes: 8 additions & 7 deletions FIAT/mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
# along with FIAT. If not, see <http://www.gnu.org/licenses/>.

from __future__ import absolute_import, print_function, division
from six import iteritems
from six.moves import map

import numpy

from collections import defaultdict
from operator import add
from functools import partial

Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(self, elements, ref_el=None):
# expect them to be. :(
nodes = [L for e in elements for L in e.dual_basis()]

entity_dofs = concatenate_entity_dofs(elements)
entity_dofs = concatenate_entity_dofs(ref_el, elements)

dual = DualSet(nodes, ref_el, entity_dofs)
super(MixedElement, self).__init__(ref_el, dual, None, mapping=None)
Expand Down Expand Up @@ -90,7 +90,7 @@ def tabulate(self, order, points, entity=None):
for i, e in enumerate(self.elements()):
table = e.tabulate(order, points, entity)

for d, tab in table.items():
for d, tab in iteritems(table):
try:
arr = output[d]
except KeyError:
Expand All @@ -109,15 +109,16 @@ def is_nodal(self):
return all(e.is_nodal() for e in self._elements)


def concatenate_entity_dofs(elements):
def concatenate_entity_dofs(ref_el, elements):
"""Combine the entity_dofs from a list of elements into a combined
entity_dof containing the information for the concatenated DoFs of
all the elements."""
entity_dofs = defaultdict(partial(defaultdict, list))
entity_dofs = {dim: {i: [] for i in entities}
for dim, entities in iteritems(ref_el.get_topology())}
offsets = numpy.cumsum([0] + list(e.space_dimension()
for e in elements), dtype=int)
for i, d in enumerate(e.entity_dofs() for e in elements):
for dim, dofs in d.items():
for ent, off in dofs.items():
for dim, dofs in iteritems(d):
for ent, off in iteritems(dofs):
entity_dofs[dim][ent] += list(map(partial(add, offsets[i]), off))
return entity_dofs

0 comments on commit d4812f9

Please sign in to comment.