Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Unpack user-defined unit aliases in operations #74

Merged
merged 5 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions qexpy/utils/units.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Functions for parsing and constructing unit strings"""

# pylint: disable=too-few-public-methods
# pylint: disable=too-few-public-methods,protected-access

from __future__ import annotations

Expand Down Expand Up @@ -92,9 +92,29 @@ def __setitem__(self, key, value):
def update(self, __m, **_):
raise TypeError("Unit does not support item assignment.")

def _unpack(self):
"""Recursively unpacks user-defined aliases for compound units"""

try:
result = {}
for unit, exp in self.items():
if unit in _registered_units:
unpacked = _registered_units[unit]._unpack()
for tok, val in unpacked.items():
result[tok] = result.get(tok, 0) + val * exp
else:
result[unit] = result.get(unit, 0) + exp
return Unit(result)
except RecursionError as e:
raise RecursionError(
"Unable to derive units for the result of this operation, there is likely "
"circular reference in your custom unit definitions."
) from e

def __add__(self, other: dict) -> Unit:
assert isinstance(other, Unit)
if self and other and self != other:
_self, _other = self._unpack(), other._unpack()
if _self and _other and _self != _other:
warnings.warn("Adding two quantities with mismatching units!")
return Unit({})
return Unit(dict(self.items())) if self else Unit(dict(other.items()))
Expand All @@ -103,7 +123,8 @@ def __add__(self, other: dict) -> Unit:

def __sub__(self, other):
assert isinstance(other, Unit)
if self and other and self != other:
_self, _other = self._unpack(), other._unpack()
if _self and _other and _self != _other:
warnings.warn("Subtracting two quantities with mismatching units!")
return Unit({})
return Unit(dict(self.items())) if self else Unit(dict(other.items()))
Expand All @@ -112,25 +133,33 @@ def __sub__(self, other):

def __mul__(self, other):
assert isinstance(other, Unit)
if self and not other:
return Unit(dict(self.items()))
if not self and other:
return Unit(dict(other.items()))
result = {}
for unit, exp in self.items():
_self, _other = self._unpack(), other._unpack()
for unit, exp in _self.items():
result[unit] = exp
for unit, exp in other.items():
for unit, exp in _other.items():
result[unit] = result.get(unit, 0) + exp

result = {name: exp for name, exp in result.items() if exp != 0}
return Unit(result)

__rmul__ = __mul__

def __truediv__(self, other):
assert isinstance(other, Unit)
if self and not other:
return Unit(dict(self.items()))
if not self and other:
return Unit({k: -v for k, v in other.items()})
result = {}
for unit, exp in self.items():
_self, _other = self._unpack(), other._unpack()
for unit, exp in _self.items():
result[unit] = exp
for unit, exp in other.items():
for unit, exp in _other.items():
result[unit] = result.get(unit, 0) - exp

result = {name: exp for name, exp in result.items() if exp != 0}
return Unit(result)

Expand Down
45 changes: 45 additions & 0 deletions tests/utils/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
}

STRINGS_TO_UNITS = [
("", {}),
("kg*m^2/s^2", PREDEFINED["joule"]),
("kg^1m^2s^-2", PREDEFINED["joule"]),
("kg/(m*s^2)", PREDEFINED["pascal"]),
Expand Down Expand Up @@ -107,6 +108,18 @@ def test_define_unit(self):
q.clear_unit_definitions()
assert q.utils.units._registered_units == {}

def test_define_unit_circular_reference(self):
"""Tests that an easy-to-understand error is raised with circular unit definitions"""

q.define_unit("A", "B*C")
q.define_unit("B", "X*V/A")

unit_1 = Unit({"A": 1})
unit_2 = Unit({"B": 1})

with pytest.raises(RecursionError, match="Unable to derive units"):
_ = unit_1 + unit_2


class TestUnitOperations:
"""Tests for unit operations"""
Expand Down Expand Up @@ -140,3 +153,35 @@ def test_unit_exponentiation(self):
unit = Unit({"kg": 1, "m": 2, "s": -2})
assert unit**2 == Unit({"kg": 2, "m": 4, "s": -4})
assert unit**0.5 == Unit({"kg": 0.5, "m": 1, "s": -1})

def test_unpack_unit(self):
"""Tests that pre-defined units are unpacked when possible"""

q.define_unit("F", "C^2/(N*m)")
q.define_unit("N", "kg*m/s^2")

unit_q = Unit({"C": 1})
unit_r = Unit({"m": 1})
unit_eps = Unit({"F": 1, "m": -1})

res = unit_q * unit_q / unit_eps / unit_r**2
assert res == Unit({"kg": 1, "m": 1, "s": -2})
assert str(res) == "N"

q.clear_unit_definitions()

def test_unit_not_unpacked_if_unnecessary(self):
"""Tests that pre-defined units are not unpacked when not necessary"""

q.define_unit("N", "kg*m/s^2")

unit_1 = Unit({"N": 1})

assert unit_1 + Unit({}) == Unit({"N": 1})
assert unit_1 - Unit({}) == Unit({"N": 1})
assert unit_1 * Unit({}) == Unit({"N": 1})
assert Unit({}) * unit_1 == Unit({"N": 1})
assert unit_1 / Unit({}) == Unit({"N": 1})
assert Unit({}) / unit_1 == Unit({"N": -1})

q.clear_unit_definitions()
Loading