Skip to content

Commit

Permalink
FIX: support subs/xreplace in EnergyDependentWidth (#330)
Browse files Browse the repository at this point in the history
* FIX: implement custom `_eval_subs()`
* FIX: implement custom `_xreplace()`
* MAINT: add test for `EnergyDependentWidth.subs()` and `xreplace()`
* MAINT: improve `round_nested()` function
  • Loading branch information
redeboer committed Sep 28, 2022
1 parent 23b5ddc commit 033b6fd
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 4 deletions.
38 changes: 38 additions & 0 deletions src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import sympy as sp
from sympy.core.basic import _aresame
from sympy.printing.latex import LatexPrinter

from ampform.sympy import (
Expand Down Expand Up @@ -221,6 +222,43 @@ def _latex(self, printer: LatexPrinter, *args) -> str:
name = Rf"\Gamma{subscript}" if self._name is None else self._name
return Rf"{name}\left({s}\right)"

def _eval_subs(self, old, new):
# https://github.com/ComPWA/sympy/blob/bd0cf9a/sympy/core/basic.py#L1074-L1104
hit = False
new_args = list(self.args)
for i, arg in enumerate(self.args):
if not hasattr(arg, "_eval_subs"):
continue
arg = arg._subs(old, new)
if not _aresame(arg, new_args[i]):
hit = True
new_args[i] = arg
if hit:
# pylint: disable=no-value-for-parameter
return self.func(*new_args, self.phsp_factor, self._name)
return self

def _xreplace(self, rule):
# https://github.com/sympy/sympy/blob/bd0cf9a/sympy/core/basic.py#L1190-L1210
if self in rule:
return rule[self], True
if rule:
new_args = []
hit = False
for a in self.args:
_xreplace = getattr(a, "_xreplace", None)
if _xreplace is not None:
a_xr = _xreplace(rule)
new_args.append(a_xr[0])
hit |= a_xr[1]
else:
new_args.append(a)
new_args = tuple(new_args)
if hit:
# pylint: disable=no-value-for-parameter
return self.func(*new_args, self.phsp_factor, self._name), True
return self, False


def relativistic_breit_wigner(s, mass0, gamma0) -> sp.Expr:
"""Relativistic Breit-Wigner lineshape.
Expand Down
40 changes: 36 additions & 4 deletions tests/dynamics/test_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,35 @@ def test_init():
assert width.phsp_factor is EqualMassPhaseSpaceFactor
assert width._name == "Gamma_1"

@pytest.mark.parametrize("method", ["subs", "xreplace"])
def test_doit_and_subs(self, method: str):
s, m0, w0, m_a, m_b = sp.symbols("s m0 Gamma0 m_a m_b", nonnegative=True)
parameters = {
m0: 1.44,
w0: 0.35,
m_a: 0.938,
m_b: 0.548,
}
width = EnergyDependentWidth(
s=s,
mass0=m0,
gamma0=w0,
m_a=m_a,
m_b=m_a,
angular_momentum=0,
meson_radius=1,
phsp_factor=PhaseSpaceFactorSWave,
)
subs_first = round_nested(_subs(width, parameters, method).doit(), n_decimals=3)
doit_first = round_nested(_subs(width.doit(), parameters, method), n_decimals=3)
subs_first = round_nested(subs_first, n_decimals=3)
doit_first = round_nested(doit_first, n_decimals=3)
assert str(subs_first) == str(doit_first)


def _subs(obj: sp.Basic, replacements: dict, method) -> sp.Expr:
return getattr(obj, method)(replacements)


def test_generate( # pylint: disable=too-many-locals
amplitude_model: tuple[str, HelicityModel],
Expand Down Expand Up @@ -171,11 +200,14 @@ def test_relativistic_breit_wigner_with_ff_phsp_factor(func):


def round_nested(expression: sp.Expr, n_decimals: int) -> sp.Expr:
no_sqrt_expr = expression
for node in sp.preorder_traversal(expression):
if node.free_symbols:
continue
if isinstance(node, (float, sp.Float)):
expression = expression.subs(node, round(node, n_decimals))
if isinstance(node, sp.Pow) and node.args[1] == 1 / 2:
expression = expression.subs(node, round(node.n(), n_decimals))
return expression
no_sqrt_expr = no_sqrt_expr.xreplace({node: node.n()})
rounded_expr = no_sqrt_expr
for node in sp.preorder_traversal(no_sqrt_expr):
if isinstance(node, (float, sp.Float)):
rounded_expr = rounded_expr.xreplace({node: round(node, n_decimals)})
return rounded_expr

0 comments on commit 033b6fd

Please sign in to comment.