Skip to content

Commit

Permalink
ENH: formulate amplitudes for existing subsystems only (#127)
Browse files Browse the repository at this point in the history
* DOC: catch only `RuntimeWarning`s
* ENH: raise warning if subsystem does not exist
* ENH: warn if wrong reference subsystem is chosen
* MAINT: simplify `get_subsystem()` implementation
  • Loading branch information
redeboer committed Apr 30, 2024
1 parent b5404b6 commit eeaab6f
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 37 deletions.
2 changes: 1 addition & 1 deletion docs/comparison/d2kkk.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
"simplify_latex_rendering()\n",
"logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n",
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n",
"warnings.simplefilter(\"ignore\")\n",
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
" logging.getLogger(\"ampform.sympy\").setLevel(logging.ERROR)\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/comparison/jpsi2phipipi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
"simplify_latex_rendering()\n",
"logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n",
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n",
"warnings.simplefilter(\"ignore\")\n",
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
" logging.getLogger(\"ampform.sympy\").setLevel(logging.ERROR)\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/comparison/jpsi2pipipi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
"simplify_latex_rendering()\n",
"logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n",
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n",
"warnings.simplefilter(\"ignore\")\n",
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
" logging.getLogger(\"ampform.sympy\").setLevel(logging.ERROR)\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/jpsi2ksp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
"\n",
"simplify_latex_rendering()\n",
"logging.getLogger(\"absl\").setLevel(logging.ERROR) # mute JAX\n",
"warnings.simplefilter(\"ignore\")\n",
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
"\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
Expand Down Expand Up @@ -191,7 +191,7 @@
" model_builder.dynamics_choices.register_builder(\n",
" chain, formulate_breit_wigner_with_form_factor\n",
" )\n",
"model = model_builder.formulate(reference_subsystem=1)\n",
"model = model_builder.formulate(reference_subsystem=2)\n",
"model.intensity"
]
},
Expand Down Expand Up @@ -230,7 +230,7 @@
},
"outputs": [],
"source": [
"Latex(aslatex({k: v for k, v in model.amplitudes.items() if v}))"
"Latex(aslatex(model.amplitudes))"
]
},
{
Expand Down
5 changes: 1 addition & 4 deletions docs/lc2pkpi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
"source": [
"from __future__ import annotations\n",
"\n",
"import warnings\n",
"\n",
"import graphviz\n",
"import qrules\n",
"import sympy as sp\n",
Expand All @@ -44,8 +42,7 @@
"from ampform_dpd.dynamics.builder import create_mass_symbol, get_mandelstam_s\n",
"from ampform_dpd.io import as_markdown_table, aslatex, simplify_latex_rendering\n",
"\n",
"simplify_latex_rendering()\n",
"warnings.simplefilter(\"ignore\")"
"simplify_latex_rendering()"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/xib2pkk.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"\n",
"simplify_latex_rendering()\n",
"logging.getLogger(\"absl\").setLevel(logging.ERROR) # mute JAX\n",
"warnings.simplefilter(\"ignore\")\n",
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
"\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
Expand Down Expand Up @@ -256,7 +256,7 @@
},
"outputs": [],
"source": [
"Latex(aslatex({k: v for k, v in model.amplitudes.items() if v}))"
"Latex(aslatex(model.amplitudes))"
]
},
{
Expand Down
44 changes: 28 additions & 16 deletions src/ampform_dpd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import lru_cache
from itertools import product
from typing import Literal, Protocol
from warnings import warn

import sympy as sp
from ampform.kinematics.phasespace import compute_third_mandelstam
Expand All @@ -22,6 +23,8 @@
Particle,
ThreeBodyDecay,
ThreeBodyDecayChain,
_get_decay_description, # pyright:ignore[reportPrivateUsage]
_get_subsystem_ids, # pyright:ignore[reportPrivateUsage]
get_decay_product_ids,
to_particle,
)
Expand Down Expand Up @@ -81,6 +84,7 @@ def formulate(
reference_subsystem: FinalStateID = 1,
cleanup_summations: bool = False,
) -> AmplitudeModel:
_check_reference_subsystems(self.decay, reference_subsystem)
helicity_symbols: tuple[sp.Symbol, sp.Symbol, sp.Symbol, sp.Symbol] = (
sp.symbols("lambda:4", rational=True)
)
Expand All @@ -93,7 +97,7 @@ def formulate(
parameter_defaults = {}
args: tuple[sp.Rational, sp.Rational, sp.Rational, sp.Rational]
for args in product(*allowed_helicities.values()): # type:ignore[assignment]
for sub_system in (1, 2, 3):
for sub_system in _get_subsystem_ids(self.decay):
chain_model = self.formulate_subsystem_amplitude(*args, sub_system) # type:ignore[arg-type]
amplitude_definitions.update(chain_model.amplitudes)
angle_definitions.update(chain_model.variables)
Expand Down Expand Up @@ -234,26 +238,20 @@ def formulate_aligned_amplitude(
λ3: sp.Rational | sp.Symbol,
reference_subsystem: FinalStateID = 1,
) -> tuple[PoolSum, dict[sp.Symbol, sp.Expr]]:
_check_reference_subsystems(self.decay, reference_subsystem)
wigner_generator = _AlignmentWignerGenerator(reference_subsystem)
_λ0, _λ1, _λ2, _λ3 = sp.symbols(R"\lambda_(0:4)^{\prime}", rational=True)
j0, j1, j2, j3 = (self.decay.states[i].spin for i in sorted(self.decay.states))
A = _generate_amplitude_index_bases()
amp_expr = PoolSum(
A[1][_λ0, _λ1, _λ2, _λ3]
* wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=1)
* wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=1)
* wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=1)
* wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=1)
+ A[2][_λ0, _λ1, _λ2, _λ3]
* wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=2)
* wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=2)
* wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=2)
* wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=2)
+ A[3][_λ0, _λ1, _λ2, _λ3]
* wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=3)
* wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=3)
* wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=3)
* wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=3),
sum(
A[k][_λ0, _λ1, _λ2, _λ3]
* wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=k)
* wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=k)
* wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=k)
* wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=k)
for k in _get_subsystem_ids(self.decay)
),
(_λ0, create_spin_range(j0)),
(_λ1, create_spin_range(j1)),
(_λ2, create_spin_range(j2)),
Expand All @@ -262,6 +260,20 @@ def formulate_aligned_amplitude(
return amp_expr, wigner_generator.angle_definitions # type:ignore[return-value]


def _check_reference_subsystems(
decay: ThreeBodyDecay, reference_subsystem: FinalStateID
) -> None:
subsystem_ids = _get_subsystem_ids(decay)
if reference_subsystem not in subsystem_ids:
decay_description = _get_decay_description(decay)
subsystems = ", ".join(sorted(str(i) for i in _get_subsystem_ids(decay)))
msg = (
f"Decay {decay_description} only has subsystems {subsystems}. Are you"
f" sure you want to use subsystem {reference_subsystem} as reference?"
)
warn(msg, category=UserWarning)


def _create_coupling_symbol(
helicity_coupling: bool,
resonance: Str,
Expand Down
26 changes: 17 additions & 9 deletions src/ampform_dpd/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import lru_cache
from textwrap import dedent
from typing import TYPE_CHECKING, Generic, Literal, TypeVar, overload
from warnings import warn

from attrs import field, frozen
from attrs.validators import instance_of
Expand All @@ -14,7 +15,6 @@
if TYPE_CHECKING:
import sympy as sp


InitialStateID = Literal[0]
"""ID for the initial state particle in a three-body decay."""
FinalStateID = Literal[1, 2, 3]
Expand Down Expand Up @@ -117,17 +117,25 @@ def find_chain(self, resonance_name: str) -> ThreeBodyDecayChain:
raise KeyError(msg)

def get_subsystem(self, subsystem_id: FinalStateID) -> ThreeBodyDecay:
child1_id, child2_id = get_decay_product_ids(subsystem_id)
child1 = self.final_state[child1_id]
child2 = self.final_state[child2_id]
filtered_chains = [
chain
for chain in self.chains
if chain.decay_products in {(child1, child2), (child2, child1)}
]
filtered_chains = [c for c in self.chains if c.spectator.index == subsystem_id]
if not filtered_chains:
decay_description = _get_decay_description(self)
subsystems = ", ".join(sorted(str(i) for i in _get_subsystem_ids(self)))
msg = f"Decay {decay_description} only has subsystems {subsystems}, not {subsystem_id}"
warn(msg, category=UserWarning)
return ThreeBodyDecay(self.states, filtered_chains)


def _get_decay_description(decay: ThreeBodyDecay) -> str:
initial_state = decay.initial_state.name
final_state = ", ".join(f"{i}: {s.name}" for i, s in decay.final_state.items())
return f"{initial_state}{final_state}"


def _get_subsystem_ids(decay: ThreeBodyDecay) -> set[FinalStateID]:
return {c.spectator.index for c in decay.chains}


def get_decay_product_ids(
spectator_id: FinalStateID,
) -> tuple[FinalStateID, FinalStateID]:
Expand Down

0 comments on commit eeaab6f

Please sign in to comment.