Skip to content

Commit

Permalink
FIX: support sympy.Indexed in ParameterValues (#336)
Browse files Browse the repository at this point in the history
* DOC: explain motivation for `ParameterValues`
* MAINT: test `ParameterValues` with `sympy.Indexed`
  • Loading branch information
redeboer committed Dec 5, 2022
1 parent 9ccaaf5 commit 113e039
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 27 deletions.
38 changes: 20 additions & 18 deletions src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ def _order_amplitudes(
)


def _to_parameter_values(
mapping: Mapping[sp.Symbol, ParameterValue]
) -> ParameterValues:
def _to_parameter_values(mapping: Mapping[sp.Basic, ParameterValue]) -> ParameterValues:
return ParameterValues(mapping)


Expand All @@ -133,7 +131,7 @@ class HelicityModel: # noqa: R701
parameter_defaults: ParameterValues = field(converter=_to_parameter_values)
"""A mapping of suggested parameter values.
Keys are `~sympy.core.symbol.Symbol` instances from the main :attr:`expression` that
Keys are `~sympy.core.basic.Basic` instances from the main :attr:`expression` that
should be interpreted as parameters (as opposed to `kinematic_variables`). The
symbols are ordered alphabetically by name with natural sort order
(:func:`.natural_sorting`). Values have been extracted from the input
Expand Down Expand Up @@ -206,7 +204,7 @@ def rename_symbols( # noqa: R701
for amp, expr in self.amplitudes.items()
},
parameter_defaults={
symbol_mapping.get(par, par): value
symbol_mapping.get(par, par): value # type: ignore[call-overload]
for par, value in self.parameter_defaults.items()
},
components={
Expand Down Expand Up @@ -265,6 +263,10 @@ def sum_components(self, components: Iterable[str]) -> sp.Expr: # noqa: R701
class ParameterValues(abc.Mapping):
"""Ordered mapping to `ParameterValue` with convenient getter and setter.
This class makes it possible to search through a mapping of :mod:`sympy` symbols to
their values (a "parameter mapping") by symbol name or by index in the (ordered)
dictionary.
>>> a, b, c = sp.symbols("a b c")
>>> parameters = ParameterValues({a: 0.0, b: 1+1j, c: -2})
>>> parameters[a]
Expand All @@ -284,7 +286,7 @@ class ParameterValues(abc.Mapping):
.. automethod:: __setitem__
"""

def __init__(self, parameters: Mapping[sp.Symbol, ParameterValue]) -> None:
def __init__(self, parameters: Mapping[sp.Basic, ParameterValue]) -> None:
self.__parameters = dict(parameters)

def __repr__(self) -> str:
Expand All @@ -305,35 +307,35 @@ def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None:
p.breakable()
p.text("})")

def __getitem__(self, key: sp.Symbol | int | str) -> ParameterValue:
def __getitem__(self, key: sp.Basic | int | str) -> ParameterValue:
par = self._get_parameter(key)
return self.__parameters[par]

def __setitem__(self, key: sp.Symbol | int | str, value: ParameterValue) -> None:
def __setitem__(self, key: sp.Basic | int | str, value: ParameterValue) -> None:
par = self._get_parameter(key)
self.__parameters[par] = value

@singledispatchmethod
def _get_parameter(self, key: sp.Symbol | int | str) -> sp.Symbol:
def _get_parameter(self, key: sp.Basic | int | str) -> sp.Basic:
raise KeyError( # no TypeError because of sympy.core.expr.Expr.xreplace
f"Cannot find parameter for key type {type(key).__name__}"
)

@_get_parameter.register(sp.Symbol)
def _(self, par: sp.Symbol) -> sp.Symbol:
@_get_parameter.register(sp.Basic)
def _(self, par: sp.Basic) -> sp.Basic:
if par not in self.__parameters:
raise KeyError(f"{type(self).__name__} has no parameter {par}")
return par

@_get_parameter.register(str)
def _(self, name: str) -> sp.Symbol:
def _(self, name: str) -> sp.Basic:
for parameter in self.__parameters:
if parameter.name == name:
if str(parameter) == name:
return parameter
raise KeyError(f"No parameter available with name {name}")

@_get_parameter.register(int)
def _(self, key: int) -> sp.Symbol:
def _(self, key: int) -> sp.Basic:
for i, parameter in enumerate(self.__parameters):
if i == key:
return parameter
Expand All @@ -345,13 +347,13 @@ def _(self, key: int) -> sp.Symbol:
def __len__(self) -> int:
return len(self.__parameters)

def __iter__(self) -> Iterator[sp.Symbol]:
def __iter__(self) -> Iterator[sp.Basic]:
return iter(self.__parameters)

def items(self) -> ItemsView[sp.Symbol, ParameterValue]:
def items(self) -> ItemsView[sp.Basic, ParameterValue]:
return self.__parameters.items()

def keys(self) -> KeysView[sp.Symbol]:
def keys(self) -> KeysView[sp.Basic]:
return self.__parameters.keys()

def values(self) -> ValuesView[ParameterValue]:
Expand All @@ -364,7 +366,7 @@ def values(self) -> ValuesView[ParameterValue]:

@define
class _HelicityModelIngredients:
parameter_defaults: dict[sp.Symbol, ParameterValue] = field(factory=dict)
parameter_defaults: dict[sp.Basic, ParameterValue] = field(factory=dict)
amplitudes: dict[sp.Indexed, sp.Expr] = field(factory=dict)
components: dict[str, sp.Expr] = field(factory=dict)
kinematic_variables: dict[sp.Symbol, sp.Expr] = field(factory=dict)
Expand Down
16 changes: 9 additions & 7 deletions tests/helicity/test_helicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def test_use_helicity_couplings(self, reaction: ReactionInfo):
builder.use_helicity_couplings = True
coupling_model = builder.formulate()

coefficient_names = {p.name for p in coeff_model.parameter_defaults}
coupling_names = {p.name for p in coupling_model.parameter_defaults}
coefficient_names = {str(p) for p in coeff_model.parameter_defaults}
coupling_names = {str(p) for p in coupling_model.parameter_defaults}
if reaction.formalism == "canonical-helicity":
assert len(coefficient_names) == 4
assert coefficient_names == {
Expand Down Expand Up @@ -217,18 +217,18 @@ def test_rename_all_parameters_with_stable_final_state( # noqa: R701
builder.stable_final_state_ids = set(reaction.final_state)
original_model = builder.formulate()
renames = {
par.name: Rf"{{{par.name}}}_\mathrm{{renamed}}"
str(par): Rf"{{{str(par)}}}_\mathrm{{renamed}}"
for par in original_model.parameter_defaults
}
new_model = original_model.rename_symbols(renames)
for par in original_model.parameter_defaults:
if par.name.startswith("m_") and par.name[-1] in {"0", "1", "2"}:
if str(par).startswith("m_") and str(par)[-1] in {"0", "1", "2"}:
continue
assert par not in new_model.parameter_defaults
for par in new_model.parameter_defaults:
if par.name.startswith("m_") and par.name[-1] in {"0", "1", "2"}:
if str(par).startswith("m_") and str(par)[-1] in {"0", "1", "2"}:
continue
assert par.name.endswith(R"_\mathrm{renamed}")
assert str(par).endswith(R"_\mathrm{renamed}")

def test_rename_variables(self, amplitude_model: tuple[str, HelicityModel]):
_, model = amplitude_model
Expand Down Expand Up @@ -329,7 +329,9 @@ def test_amplitudes(self, formalism: str):
class TestParameterValues:
@pytest.mark.parametrize("subs_method", ["subs", "xreplace"])
def test_subs_xreplace(self, subs_method: str):
a, b, x, y = sp.symbols("a b x y")
base = sp.IndexedBase("b")
a, x, y = sp.symbols("a x y")
b: sp.Indexed = base[1, 2]
expr: sp.Expr = a * x + b * y
parameters = ParameterValues({a: 2, b: -3})
if subs_method == "subs":
Expand Down
4 changes: 2 additions & 2 deletions tests/helicity/test_naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_coefficient_names( # noqa: R701

def get_coefficients(model: HelicityModel) -> list[str]:
return [
symbol.name
str(symbol)
for symbol in model.parameter_defaults
if symbol.name.startswith("C_")
if str(symbol).startswith("C_")
]

0 comments on commit 113e039

Please sign in to comment.