Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: formulate amplitudes for existing subsystems only #127

Merged
merged 5 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/comparison/d2kkk.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
"simplify_latex_rendering()\n",
"logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n",
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n",
"warnings.simplefilter(\"ignore\")\n",
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
" logging.getLogger(\"ampform.sympy\").setLevel(logging.ERROR)\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/comparison/jpsi2phipipi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
"simplify_latex_rendering()\n",
"logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n",
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n",
"warnings.simplefilter(\"ignore\")\n",
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
" logging.getLogger(\"ampform.sympy\").setLevel(logging.ERROR)\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/comparison/jpsi2pipipi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
"simplify_latex_rendering()\n",
"logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n",
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n",
"warnings.simplefilter(\"ignore\")\n",
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
" logging.getLogger(\"ampform.sympy\").setLevel(logging.ERROR)\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/jpsi2ksp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
"\n",
"simplify_latex_rendering()\n",
"logging.getLogger(\"absl\").setLevel(logging.ERROR) # mute JAX\n",
"warnings.simplefilter(\"ignore\")\n",
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
"\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
Expand Down Expand Up @@ -191,7 +191,7 @@
" model_builder.dynamics_choices.register_builder(\n",
" chain, formulate_breit_wigner_with_form_factor\n",
" )\n",
"model = model_builder.formulate(reference_subsystem=1)\n",
"model = model_builder.formulate(reference_subsystem=2)\n",
"model.intensity"
]
},
Expand Down Expand Up @@ -230,7 +230,7 @@
},
"outputs": [],
"source": [
"Latex(aslatex({k: v for k, v in model.amplitudes.items() if v}))"
"Latex(aslatex(model.amplitudes))"
]
},
{
Expand Down
5 changes: 1 addition & 4 deletions docs/lc2pkpi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
"source": [
"from __future__ import annotations\n",
"\n",
"import warnings\n",
"\n",
"import graphviz\n",
"import qrules\n",
"import sympy as sp\n",
Expand All @@ -44,8 +42,7 @@
"from ampform_dpd.dynamics.builder import create_mass_symbol, get_mandelstam_s\n",
"from ampform_dpd.io import as_markdown_table, aslatex, simplify_latex_rendering\n",
"\n",
"simplify_latex_rendering()\n",
"warnings.simplefilter(\"ignore\")"
"simplify_latex_rendering()"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/xib2pkk.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"\n",
"simplify_latex_rendering()\n",
"logging.getLogger(\"absl\").setLevel(logging.ERROR) # mute JAX\n",
"warnings.simplefilter(\"ignore\")\n",
"warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
"\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
Expand Down Expand Up @@ -256,7 +256,7 @@
},
"outputs": [],
"source": [
"Latex(aslatex({k: v for k, v in model.amplitudes.items() if v}))"
"Latex(aslatex(model.amplitudes))"
]
},
{
Expand Down
44 changes: 28 additions & 16 deletions src/ampform_dpd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import lru_cache
from itertools import product
from typing import Literal, Protocol
from warnings import warn

import sympy as sp
from ampform.kinematics.phasespace import compute_third_mandelstam
Expand All @@ -22,6 +23,8 @@
Particle,
ThreeBodyDecay,
ThreeBodyDecayChain,
_get_decay_description, # pyright:ignore[reportPrivateUsage]
_get_subsystem_ids, # pyright:ignore[reportPrivateUsage]
get_decay_product_ids,
to_particle,
)
Expand Down Expand Up @@ -81,6 +84,7 @@ def formulate(
reference_subsystem: FinalStateID = 1,
cleanup_summations: bool = False,
) -> AmplitudeModel:
_check_reference_subsystems(self.decay, reference_subsystem)
helicity_symbols: tuple[sp.Symbol, sp.Symbol, sp.Symbol, sp.Symbol] = (
sp.symbols("lambda:4", rational=True)
)
Expand All @@ -93,7 +97,7 @@ def formulate(
parameter_defaults = {}
args: tuple[sp.Rational, sp.Rational, sp.Rational, sp.Rational]
for args in product(*allowed_helicities.values()): # type:ignore[assignment]
for sub_system in (1, 2, 3):
for sub_system in _get_subsystem_ids(self.decay):
chain_model = self.formulate_subsystem_amplitude(*args, sub_system) # type:ignore[arg-type]
amplitude_definitions.update(chain_model.amplitudes)
angle_definitions.update(chain_model.variables)
Expand Down Expand Up @@ -234,26 +238,20 @@ def formulate_aligned_amplitude(
λ3: sp.Rational | sp.Symbol,
reference_subsystem: FinalStateID = 1,
) -> tuple[PoolSum, dict[sp.Symbol, sp.Expr]]:
_check_reference_subsystems(self.decay, reference_subsystem)
wigner_generator = _AlignmentWignerGenerator(reference_subsystem)
_λ0, _λ1, _λ2, _λ3 = sp.symbols(R"\lambda_(0:4)^{\prime}", rational=True)
j0, j1, j2, j3 = (self.decay.states[i].spin for i in sorted(self.decay.states))
A = _generate_amplitude_index_bases()
amp_expr = PoolSum(
A[1][_λ0, _λ1, _λ2, _λ3]
* wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=1)
* wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=1)
* wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=1)
* wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=1)
+ A[2][_λ0, _λ1, _λ2, _λ3]
* wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=2)
* wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=2)
* wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=2)
* wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=2)
+ A[3][_λ0, _λ1, _λ2, _λ3]
* wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=3)
* wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=3)
* wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=3)
* wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=3),
sum(
A[k][_λ0, _λ1, _λ2, _λ3]
* wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=k)
* wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=k)
* wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=k)
* wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=k)
for k in _get_subsystem_ids(self.decay)
),
(_λ0, create_spin_range(j0)),
(_λ1, create_spin_range(j1)),
(_λ2, create_spin_range(j2)),
Expand All @@ -262,6 +260,20 @@ def formulate_aligned_amplitude(
return amp_expr, wigner_generator.angle_definitions # type:ignore[return-value]


def _check_reference_subsystems(
decay: ThreeBodyDecay, reference_subsystem: FinalStateID
) -> None:
subsystem_ids = _get_subsystem_ids(decay)
if reference_subsystem not in subsystem_ids:
decay_description = _get_decay_description(decay)
subsystems = ", ".join(sorted(str(i) for i in _get_subsystem_ids(decay)))
msg = (
f"Decay {decay_description} only has subsystems {subsystems}. Are you"
f" sure you want to use subsystem {reference_subsystem} as reference?"
)
warn(msg, category=UserWarning)


def _create_coupling_symbol(
helicity_coupling: bool,
resonance: Str,
Expand Down
26 changes: 17 additions & 9 deletions src/ampform_dpd/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import lru_cache
from textwrap import dedent
from typing import TYPE_CHECKING, Generic, Literal, TypeVar, overload
from warnings import warn

from attrs import field, frozen
from attrs.validators import instance_of
Expand All @@ -14,7 +15,6 @@
if TYPE_CHECKING:
import sympy as sp


InitialStateID = Literal[0]
"""ID for the initial state particle in a three-body decay."""
FinalStateID = Literal[1, 2, 3]
Expand Down Expand Up @@ -117,17 +117,25 @@ def find_chain(self, resonance_name: str) -> ThreeBodyDecayChain:
raise KeyError(msg)

def get_subsystem(self, subsystem_id: FinalStateID) -> ThreeBodyDecay:
child1_id, child2_id = get_decay_product_ids(subsystem_id)
child1 = self.final_state[child1_id]
child2 = self.final_state[child2_id]
filtered_chains = [
chain
for chain in self.chains
if chain.decay_products in {(child1, child2), (child2, child1)}
]
filtered_chains = [c for c in self.chains if c.spectator.index == subsystem_id]
if not filtered_chains:
decay_description = _get_decay_description(self)
subsystems = ", ".join(sorted(str(i) for i in _get_subsystem_ids(self)))
msg = f"Decay {decay_description} only has subsystems {subsystems}, not {subsystem_id}"
warn(msg, category=UserWarning)
return ThreeBodyDecay(self.states, filtered_chains)


def _get_decay_description(decay: ThreeBodyDecay) -> str:
initial_state = decay.initial_state.name
final_state = ", ".join(f"{i}: {s.name}" for i, s in decay.final_state.items())
return f"{initial_state} → {final_state}"


def _get_subsystem_ids(decay: ThreeBodyDecay) -> set[FinalStateID]:
return {c.spectator.index for c in decay.chains}


def get_decay_product_ids(
spectator_id: FinalStateID,
) -> tuple[FinalStateID, FinalStateID]:
Expand Down
Loading