Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix __hash__ and __eq__ for basix.ufl elements #718

Merged
merged 21 commits into from
Nov 1, 2023
Merged
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
72 changes: 51 additions & 21 deletions python/basix/ufl.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,12 @@ def __init__(self, element: _basix.finite_element.FiniteElement, gdim: _typing.O
"""Create a Basix element."""
if element.family == _basix.ElementFamily.custom:
self._is_custom = True
repr = f"custom Basix element ({_compute_signature(element)})"
repr = f"custom Basix element ({_compute_signature(element)}"
else:
self._is_custom = False
repr = (f"Basix element ({element.family.__name__}, {element.cell_type.__name__}, {element.degree}, "
f"{element.lagrange_variant.__name__}, {element.dpc_variant.__name__}, {element.discontinuous})")
f"{element.lagrange_variant.__name__}, {element.dpc_variant.__name__}, {element.discontinuous}")
repr = _repr_optional_args(repr, ("gdim", gdim))

super().__init__(
repr, element.cell_type.__name__, tuple(element.value_shape), element.degree,
Expand All @@ -364,7 +365,7 @@ def __init__(self, element: _basix.finite_element.FiniteElement, gdim: _typing.O

def __eq__(self, other) -> bool:
"""Check if two elements are equal."""
return isinstance(other, _BasixElement) and self.element == other.element
return isinstance(other, _BasixElement) and (self.element == other.element and self._gdim == other._gdim)

def __hash__(self) -> int:
"""Return a hash."""
Expand Down Expand Up @@ -567,13 +568,14 @@ def __init__(self, element: _ElementBase, component: int, gdim: _typing.Optional
"""Initialise the element."""
self.element = element
self.component = component
super().__init__(f"component element ({element._repr}, {component})",
element.cell_type.__name__, (1, ), element._degree, gdim=gdim)
repr = f"component element ({element._repr}, {component}"
repr = _repr_optional_args(repr, ("gdim", gdim))
super().__init__(repr, element.cell_type.__name__, (1, ), element._degree, gdim=gdim)

def __eq__(self, other) -> bool:
"""Check if two elements are equal."""
return (isinstance(other, _ComponentElement) and self.element == other.element
and self.component == other.component)
and self.component == other.component and self._gdim == other._gdim)

def __hash__(self) -> int:
"""Return a hash."""
Expand Down Expand Up @@ -754,13 +756,15 @@ def __init__(self, sub_elements: _typing.List[_ElementBase], gdim: _typing.Optio
else:
pullback = _MixedPullback(self)

super().__init__("mixed element (" + ", ".join(i._repr for i in sub_elements) + ")",
sub_elements[0].cell_type.__name__,
repr = "mixed element (" + ", ".join(i._repr for i in sub_elements)
repr = _repr_optional_args(repr, ("gdim", gdim))
super().__init__(repr, sub_elements[0].cell_type.__name__,
(sum(i.value_size for i in sub_elements), ), pullback=pullback, gdim=gdim)

def __eq__(self, other) -> bool:
"""Check if two elements are equal."""
if isinstance(other, _MixedElement) and len(self._sub_elements) == len(other._sub_elements):
if isinstance(other, _MixedElement) and (len(self._sub_elements) == len(other._sub_elements)
and self._gdim == other._gdim):
for i, j in zip(self._sub_elements, other._sub_elements):
if i != j:
return False
Expand Down Expand Up @@ -1000,13 +1004,10 @@ def __init__(self, sub_element: _ElementBase, shape: _typing.Tuple[int, ...],

repr = f"blocked element ({sub_element._repr}, {shape}"
if len(shape) == 2:
if symmetry:
repr += ", True"
else:
repr += ", False"
if gdim is not None:
repr += f", gdim={gdim}"
repr += ")"
_symm = ("symmetry", "True" if symmetry else "False")
else:
_symm = ("symmetry", None)
repr = _repr_optional_args(repr, _symm, ("gdim", gdim))

super().__init__(repr, sub_element.cell_type.__name__, shape,
sub_element._degree, sub_element._pullback, gdim=gdim)
Expand All @@ -1026,7 +1027,8 @@ def __eq__(self, other) -> bool:
"""Check if two elements are equal."""
return (
isinstance(other, _BlockedElement) and self._block_size == other._block_size
and self.block_shape == other.block_shape and self.sub_element == other.sub_element)
and self.block_shape == other.block_shape and self.sub_element == other.sub_element
and self._gdim == other._gdim)

def __hash__(self) -> int:
"""Return a hash."""
Expand Down Expand Up @@ -1274,8 +1276,12 @@ def basix_sobolev_space(self):

def __eq__(self, other) -> bool:
"""Check if two elements are equal."""
return isinstance(other, _QuadratureElement) and _np.allclose(self._points, other._points) and \
_np.allclose(self._weights, other._weights)
return isinstance(other, _QuadratureElement) and (
self._cell_type == other._cell_type
and self._pullback == other._pullback
and _np.allclose(self._points, other._points)
and _np.allclose(self._weights, other._weights)
)

def __hash__(self) -> int:
"""Return a hash."""
Expand Down Expand Up @@ -1437,7 +1443,7 @@ def __init__(self, cell: _basix.CellType, value_shape: _typing.Tuple[int, ...]):
self._cell_type = cell
tdim = len(_basix.topology(cell)) - 1

super().__init__(f"RealElement({element})", cell.__name__, value_shape, 0)
super().__init__(f"RealElement({cell.__name__}, {value_shape})", cell.__name__, value_shape, 0)

self._entity_counts = []
if tdim >= 1:
Expand All @@ -1450,7 +1456,8 @@ def __init__(self, cell: _basix.CellType, value_shape: _typing.Tuple[int, ...]):

def __eq__(self, other) -> bool:
"""Check if two elements are equal."""
return isinstance(other, _RealElement)
return isinstance(other, _RealElement) and (self._cell_type == other._cell_type
and self._value_shape == other._value_shape)

def __hash__(self) -> int:
"""Return a hash."""
Expand Down Expand Up @@ -1634,6 +1641,29 @@ def _compute_signature(element: _basix.finite_element.FiniteElement) -> str:
return signature


def _repr_optional_args(partial_repr: str, *args):
"""Augment an element `repr` by appending non-None optional arguments.

Args:
partial_repr: The initial `repr` of a finite element incorporating all required
arguments but not including any optional args
conpierce8 marked this conversation as resolved.
Show resolved Hide resolved
args: Sequence of tuples `(name: str, value: typing.Any)` where `name` is the name
of an optional argument to be including in the repr and `value` is its
conpierce8 marked this conversation as resolved.
Show resolved Hide resolved
value. All arguments for which `value is not None` will be appended to
`partial_repr`.

mscroggs marked this conversation as resolved.
Show resolved Hide resolved
Returns:
A string representation of a finite element
"""

repr = partial_repr
for name, value in args:
if value is not None:
repr += f", {name}={value}"
repr += ")"
return repr


@_functools.lru_cache()
def element(family: _typing.Union[_basix.ElementFamily, str], cell: _typing.Union[_basix.CellType, str], degree: int,
lagrange_variant: _basix.LagrangeVariant = _basix.LagrangeVariant.unset,
Expand Down