From eeaab6f40239cb3162504b3122cf783f71615f8b Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 30 Apr 2024 12:23:32 +0200 Subject: [PATCH] ENH: formulate amplitudes for existing subsystems only (#127) * DOC: catch only `RuntimeWarning`s * ENH: raise warning if subsystem does not exist * ENH: warn if wrong reference subsystem is chosen * MAINT: simplify `get_subsystem()` implementation --- docs/comparison/d2kkk.ipynb | 2 +- docs/comparison/jpsi2phipipi.ipynb | 2 +- docs/comparison/jpsi2pipipi.ipynb | 2 +- docs/jpsi2ksp.ipynb | 6 ++-- docs/lc2pkpi.ipynb | 5 +--- docs/xib2pkk.ipynb | 4 +-- src/ampform_dpd/__init__.py | 44 +++++++++++++++++++----------- src/ampform_dpd/decay.py | 26 ++++++++++++------ 8 files changed, 54 insertions(+), 37 deletions(-) diff --git a/docs/comparison/d2kkk.ipynb b/docs/comparison/d2kkk.ipynb index c119cac1..548f3da1 100644 --- a/docs/comparison/d2kkk.ipynb +++ b/docs/comparison/d2kkk.ipynb @@ -84,7 +84,7 @@ "simplify_latex_rendering()\n", "logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n", - "warnings.simplefilter(\"ignore\")\n", + "warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n", "NO_TQDM = \"EXECUTE_NB\" in os.environ\n", "if NO_TQDM:\n", " logging.getLogger(\"ampform.sympy\").setLevel(logging.ERROR)\n", diff --git a/docs/comparison/jpsi2phipipi.ipynb b/docs/comparison/jpsi2phipipi.ipynb index b812df55..b0c5d3ba 100644 --- a/docs/comparison/jpsi2phipipi.ipynb +++ b/docs/comparison/jpsi2phipipi.ipynb @@ -84,7 +84,7 @@ "simplify_latex_rendering()\n", "logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n", - "warnings.simplefilter(\"ignore\")\n", + "warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n", "NO_TQDM = \"EXECUTE_NB\" in os.environ\n", "if NO_TQDM:\n", " logging.getLogger(\"ampform.sympy\").setLevel(logging.ERROR)\n", diff --git a/docs/comparison/jpsi2pipipi.ipynb b/docs/comparison/jpsi2pipipi.ipynb index b3a381b7..418c6e8d 100644 --- a/docs/comparison/jpsi2pipipi.ipynb +++ b/docs/comparison/jpsi2pipipi.ipynb @@ -84,7 +84,7 @@ "simplify_latex_rendering()\n", "logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n", - "warnings.simplefilter(\"ignore\")\n", + "warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n", "NO_TQDM = \"EXECUTE_NB\" in os.environ\n", "if NO_TQDM:\n", " logging.getLogger(\"ampform.sympy\").setLevel(logging.ERROR)\n", diff --git a/docs/jpsi2ksp.ipynb b/docs/jpsi2ksp.ipynb index a968a98a..984f99e7 100644 --- a/docs/jpsi2ksp.ipynb +++ b/docs/jpsi2ksp.ipynb @@ -58,7 +58,7 @@ "\n", "simplify_latex_rendering()\n", "logging.getLogger(\"absl\").setLevel(logging.ERROR) # mute JAX\n", - "warnings.simplefilter(\"ignore\")\n", + "warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n", "\n", "NO_TQDM = \"EXECUTE_NB\" in os.environ\n", "if NO_TQDM:\n", @@ -191,7 +191,7 @@ " model_builder.dynamics_choices.register_builder(\n", " chain, formulate_breit_wigner_with_form_factor\n", " )\n", - "model = model_builder.formulate(reference_subsystem=1)\n", + "model = model_builder.formulate(reference_subsystem=2)\n", "model.intensity" ] }, @@ -230,7 +230,7 @@ }, "outputs": [], "source": [ - "Latex(aslatex({k: v for k, v in model.amplitudes.items() if v}))" + "Latex(aslatex(model.amplitudes))" ] }, { diff --git a/docs/lc2pkpi.ipynb b/docs/lc2pkpi.ipynb index 569bcde0..39e098ed 100644 --- a/docs/lc2pkpi.ipynb +++ b/docs/lc2pkpi.ipynb @@ -26,8 +26,6 @@ "source": [ "from __future__ import annotations\n", "\n", - "import warnings\n", - "\n", "import graphviz\n", "import qrules\n", "import sympy as sp\n", @@ -44,8 +42,7 @@ "from ampform_dpd.dynamics.builder import create_mass_symbol, get_mandelstam_s\n", "from ampform_dpd.io import as_markdown_table, aslatex, simplify_latex_rendering\n", "\n", - "simplify_latex_rendering()\n", - "warnings.simplefilter(\"ignore\")" + "simplify_latex_rendering()" ] }, { diff --git a/docs/xib2pkk.ipynb b/docs/xib2pkk.ipynb index 27e197ca..e25629ed 100644 --- a/docs/xib2pkk.ipynb +++ b/docs/xib2pkk.ipynb @@ -63,7 +63,7 @@ "\n", "simplify_latex_rendering()\n", "logging.getLogger(\"absl\").setLevel(logging.ERROR) # mute JAX\n", - "warnings.simplefilter(\"ignore\")\n", + "warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n", "\n", "NO_TQDM = \"EXECUTE_NB\" in os.environ\n", "if NO_TQDM:\n", @@ -256,7 +256,7 @@ }, "outputs": [], "source": [ - "Latex(aslatex({k: v for k, v in model.amplitudes.items() if v}))" + "Latex(aslatex(model.amplitudes))" ] }, { diff --git a/src/ampform_dpd/__init__.py b/src/ampform_dpd/__init__.py index 4c6b5cfb..a5da0790 100644 --- a/src/ampform_dpd/__init__.py +++ b/src/ampform_dpd/__init__.py @@ -5,6 +5,7 @@ from functools import lru_cache from itertools import product from typing import Literal, Protocol +from warnings import warn import sympy as sp from ampform.kinematics.phasespace import compute_third_mandelstam @@ -22,6 +23,8 @@ Particle, ThreeBodyDecay, ThreeBodyDecayChain, + _get_decay_description, # pyright:ignore[reportPrivateUsage] + _get_subsystem_ids, # pyright:ignore[reportPrivateUsage] get_decay_product_ids, to_particle, ) @@ -81,6 +84,7 @@ def formulate( reference_subsystem: FinalStateID = 1, cleanup_summations: bool = False, ) -> AmplitudeModel: + _check_reference_subsystems(self.decay, reference_subsystem) helicity_symbols: tuple[sp.Symbol, sp.Symbol, sp.Symbol, sp.Symbol] = ( sp.symbols("lambda:4", rational=True) ) @@ -93,7 +97,7 @@ def formulate( parameter_defaults = {} args: tuple[sp.Rational, sp.Rational, sp.Rational, sp.Rational] for args in product(*allowed_helicities.values()): # type:ignore[assignment] - for sub_system in (1, 2, 3): + for sub_system in _get_subsystem_ids(self.decay): chain_model = self.formulate_subsystem_amplitude(*args, sub_system) # type:ignore[arg-type] amplitude_definitions.update(chain_model.amplitudes) angle_definitions.update(chain_model.variables) @@ -234,26 +238,20 @@ def formulate_aligned_amplitude( λ3: sp.Rational | sp.Symbol, reference_subsystem: FinalStateID = 1, ) -> tuple[PoolSum, dict[sp.Symbol, sp.Expr]]: + _check_reference_subsystems(self.decay, reference_subsystem) wigner_generator = _AlignmentWignerGenerator(reference_subsystem) _λ0, _λ1, _λ2, _λ3 = sp.symbols(R"\lambda_(0:4)^{\prime}", rational=True) j0, j1, j2, j3 = (self.decay.states[i].spin for i in sorted(self.decay.states)) A = _generate_amplitude_index_bases() amp_expr = PoolSum( - A[1][_λ0, _λ1, _λ2, _λ3] - * wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=1) - * wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=1) - * wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=1) - * wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=1) - + A[2][_λ0, _λ1, _λ2, _λ3] - * wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=2) - * wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=2) - * wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=2) - * wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=2) - + A[3][_λ0, _λ1, _λ2, _λ3] - * wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=3) - * wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=3) - * wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=3) - * wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=3), + sum( + A[k][_λ0, _λ1, _λ2, _λ3] + * wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=k) + * wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=k) + * wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=k) + * wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=k) + for k in _get_subsystem_ids(self.decay) + ), (_λ0, create_spin_range(j0)), (_λ1, create_spin_range(j1)), (_λ2, create_spin_range(j2)), @@ -262,6 +260,20 @@ def formulate_aligned_amplitude( return amp_expr, wigner_generator.angle_definitions # type:ignore[return-value] +def _check_reference_subsystems( + decay: ThreeBodyDecay, reference_subsystem: FinalStateID +) -> None: + subsystem_ids = _get_subsystem_ids(decay) + if reference_subsystem not in subsystem_ids: + decay_description = _get_decay_description(decay) + subsystems = ", ".join(sorted(str(i) for i in _get_subsystem_ids(decay))) + msg = ( + f"Decay {decay_description} only has subsystems {subsystems}. Are you" + f" sure you want to use subsystem {reference_subsystem} as reference?" + ) + warn(msg, category=UserWarning) + + def _create_coupling_symbol( helicity_coupling: bool, resonance: Str, diff --git a/src/ampform_dpd/decay.py b/src/ampform_dpd/decay.py index 0141f9ef..d545a5e1 100644 --- a/src/ampform_dpd/decay.py +++ b/src/ampform_dpd/decay.py @@ -5,6 +5,7 @@ from functools import lru_cache from textwrap import dedent from typing import TYPE_CHECKING, Generic, Literal, TypeVar, overload +from warnings import warn from attrs import field, frozen from attrs.validators import instance_of @@ -14,7 +15,6 @@ if TYPE_CHECKING: import sympy as sp - InitialStateID = Literal[0] """ID for the initial state particle in a three-body decay.""" FinalStateID = Literal[1, 2, 3] @@ -117,17 +117,25 @@ def find_chain(self, resonance_name: str) -> ThreeBodyDecayChain: raise KeyError(msg) def get_subsystem(self, subsystem_id: FinalStateID) -> ThreeBodyDecay: - child1_id, child2_id = get_decay_product_ids(subsystem_id) - child1 = self.final_state[child1_id] - child2 = self.final_state[child2_id] - filtered_chains = [ - chain - for chain in self.chains - if chain.decay_products in {(child1, child2), (child2, child1)} - ] + filtered_chains = [c for c in self.chains if c.spectator.index == subsystem_id] + if not filtered_chains: + decay_description = _get_decay_description(self) + subsystems = ", ".join(sorted(str(i) for i in _get_subsystem_ids(self))) + msg = f"Decay {decay_description} only has subsystems {subsystems}, not {subsystem_id}" + warn(msg, category=UserWarning) return ThreeBodyDecay(self.states, filtered_chains) +def _get_decay_description(decay: ThreeBodyDecay) -> str: + initial_state = decay.initial_state.name + final_state = ", ".join(f"{i}: {s.name}" for i, s in decay.final_state.items()) + return f"{initial_state} → {final_state}" + + +def _get_subsystem_ids(decay: ThreeBodyDecay) -> set[FinalStateID]: + return {c.spectator.index for c in decay.chains} + + def get_decay_product_ids( spectator_id: FinalStateID, ) -> tuple[FinalStateID, FinalStateID]: