Skip to content

Commit

Permalink
ENH: warn if wrong reference subsystem is chosen
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Apr 30, 2024
1 parent cd87116 commit 2f7d948
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
19 changes: 17 additions & 2 deletions src/ampform_dpd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import lru_cache
from itertools import product
from typing import Literal, Protocol
from warnings import warn

import sympy as sp
from ampform.kinematics.phasespace import compute_third_mandelstam
Expand All @@ -22,6 +23,8 @@
Particle,
ThreeBodyDecay,
ThreeBodyDecayChain,
_get_decay_description, # pyright:ignore[reportPrivateUsage]
_get_subsystem_ids, # pyright:ignore[reportPrivateUsage]
get_decay_product_ids,
to_particle,
)
Expand Down Expand Up @@ -81,6 +84,7 @@ def formulate(
reference_subsystem: FinalStateID = 1,
cleanup_summations: bool = False,
) -> AmplitudeModel:
_check_reference_subsystems(self.decay, reference_subsystem)
helicity_symbols: tuple[sp.Symbol, sp.Symbol, sp.Symbol, sp.Symbol] = (
sp.symbols("lambda:4", rational=True)
)
Expand Down Expand Up @@ -234,6 +238,7 @@ def formulate_aligned_amplitude(
λ3: sp.Rational | sp.Symbol,
reference_subsystem: FinalStateID = 1,
) -> tuple[PoolSum, dict[sp.Symbol, sp.Expr]]:
_check_reference_subsystems(self.decay, reference_subsystem)
wigner_generator = _AlignmentWignerGenerator(reference_subsystem)
_λ0, _λ1, _λ2, _λ3 = sp.symbols(R"\lambda_(0:4)^{\prime}", rational=True)
j0, j1, j2, j3 = (self.decay.states[i].spin for i in sorted(self.decay.states))
Expand All @@ -255,8 +260,18 @@ def formulate_aligned_amplitude(
return amp_expr, wigner_generator.angle_definitions # type:ignore[return-value]


def _get_subsystem_ids(decay: ThreeBodyDecay) -> list[FinalStateID]:
return sorted({chain.spectator.index for chain in decay.chains})
def _check_reference_subsystems(
decay: ThreeBodyDecay, reference_subsystem: FinalStateID
) -> None:
subsystem_ids = _get_subsystem_ids(decay)
if reference_subsystem not in subsystem_ids:
decay_description = _get_decay_description(decay)
subsystems = ", ".join(sorted(str(i) for i in _get_subsystem_ids(decay)))
msg = (
f"Decay {decay_description} only has subsystems {subsystems}. Are you"
f" sure you want to use subsystem {reference_subsystem} as reference?"
)
warn(msg, category=UserWarning)


def _create_coupling_symbol(
Expand Down
24 changes: 13 additions & 11 deletions src/ampform_dpd/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,23 @@ def find_chain(self, resonance_name: str) -> ThreeBodyDecayChain:
def get_subsystem(self, subsystem_id: FinalStateID) -> ThreeBodyDecay:
filtered_chains = [c for c in self.chains if c.spectator.index == subsystem_id]
if not filtered_chains:
initial_state = self.initial_state.name
final_state = ", ".join(
f"{i}: {s.name}" for i, s in self.final_state.items()
)
subsystems = ", ".join(
sorted({str(c.spectator.index) for c in self.chains})
)
msg = (
f"Three-body decay {initial_state}{final_state} only has subsystems"
f"{subsystems}, not {subsystem_id}"
)
decay_description = _get_decay_description(self)
subsystems = ", ".join(sorted(str(i) for i in _get_subsystem_ids(self)))
msg = f"Decay {decay_description} only has subsystems {subsystems}, not {subsystem_id}"
warn(msg, category=UserWarning)
return ThreeBodyDecay(self.states, filtered_chains)


def _get_decay_description(decay: ThreeBodyDecay) -> str:
initial_state = decay.initial_state.name
final_state = ", ".join(f"{i}: {s.name}" for i, s in decay.final_state.items())
return f"{initial_state}{final_state}"


def _get_subsystem_ids(decay: ThreeBodyDecay) -> set[FinalStateID]:
return {c.spectator.index for c in decay.chains}


def get_decay_product_ids(
spectator_id: FinalStateID,
) -> tuple[FinalStateID, FinalStateID]:
Expand Down

0 comments on commit 2f7d948

Please sign in to comment.