/
dpd.py
198 lines (158 loc) · 6.68 KB
/
dpd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""Spin alignment with Dalitz-Plot Decomposition.
See :cite:`mikhasenkoDalitzplotDecompositionThreebody2020`.
"""
from __future__ import annotations
import sys
from functools import lru_cache, singledispatch
from typing import TYPE_CHECKING, TypeVar
import attrs
import sympy as sp
from attrs import define, field
from attrs.validators import in_
from qrules.topology import Topology
from qrules.transition import ReactionInfo, StateTransition
from sympy.physics.quantum.spin import Rotation as Wigner
from ampform._qrules import get_qrules_version
from ampform.helicity.align import SpinAlignment
from ampform.helicity.decay import (
get_outer_state_ids,
get_spectator_id,
group_by_topology,
)
from ampform.helicity.naming import create_amplitude_base, create_spin_projection_symbol
from ampform.kinematics.angles import formulate_zeta_angle
from ampform.sympy import PoolSum
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if TYPE_CHECKING:
from sympy.physics.quantum.spin import WignerD
if get_qrules_version() < (0, 10):
from qrules.transition import ( # type: ignore[attr-defined]
StateTransitionCollection,
)
@define
class DalitzPlotDecomposition(SpinAlignment):
    """Align amplitudes with the Dalitz-plot decomposition method.

    See :cite:`mikhasenkoDalitzplotDecompositionThreebody2020`.
    """

    # Restricted by the attrs validator to 1, 2, or 3. It is forwarded
    # unchanged to the amplitude formulation as the reference subsystem.
    reference_subsystem: Literal[1, 2, 3] = field(validator=in_({1, 2, 3}))

    def formulate_amplitude(self, reaction: ReactionInfo) -> sp.Expr:
        """Return the aligned amplitude expression for :code:`reaction`."""
        amplitude, _ = _formulate_aligned_amplitude(
            reaction, self.reference_subsystem
        )
        return amplitude

    def define_symbols(self, reaction: ReactionInfo) -> dict[sp.Symbol, sp.Expr]:
        """Return the angle definitions that appear in the aligned amplitude.

        The returned mapping is a fresh `dict` on every call.
        """
        _, definitions = _formulate_aligned_amplitude(
            reaction, self.reference_subsystem
        )
        # The helper is memoized with lru_cache, so the dict it returns is
        # shared between all calls with the same arguments. Hand out a shallow
        # copy so that callers mutating the result cannot corrupt the cache.
        return dict(definitions)
@lru_cache(maxsize=None)
def _formulate_aligned_amplitude(  # noqa: PLR0914
    reaction: ReactionInfo, reference_subsystem: Literal[1, 2, 3]
) -> tuple[sp.Expr, dict[sp.Symbol, sp.Expr]]:
    """Formulate the aligned amplitude and collect its angle definitions.

    For each topology group of the reaction, the amplitude base is indexed
    with four summation helicities and multiplied by one Wigner factor per
    outer state. The sum over all topologies is wrapped in a `PoolSum` over
    the summation helicities.

    Returns:
        A tuple of the summed amplitude expression and the mapping of angle
        symbols to their expressions, as recorded by the Wigner generator.
        Because of :func:`functools.lru_cache`, repeated calls with equal
        arguments return the very same objects.
    """
    wigner_generator = _DPDAlignmentWignerGenerator(reference_subsystem)
    outer_state_ids = get_outer_state_ids(reaction)
    λ0, λ1, λ2, λ3 = map(  # noqa: PLC2401
        create_spin_projection_symbol, outer_state_ids
    )
    # NOTE(review): the name pattern ends in a bare "^" -- confirm that this
    # is the intended symbol name.
    _λ0, _λ1, _λ2, _λ3 = sp.symbols(R"\lambda_(:4)^", rational=True)  # noqa: PLC2401
    # Spins are read from the first transition only.
    reference_states = reaction.transitions[0].states
    j0, j1, j2, j3 = (
        sp.Rational(reference_states[i].particle.spin) for i in outer_state_ids
    )

    terms: list[sp.Mul] = []
    for topology in group_by_topology(reaction.transitions):
        spectator_id = get_spectator_id(topology)
        term = create_amplitude_base(topology)[_λ0, _λ1, _λ2, _λ3]
        # State 0 takes (external, summed) helicities; states 1-3 take
        # (summed, external). The generator is called in state order so that
        # its angle definitions are recorded in that order.
        term *= wigner_generator(j0, λ0, _λ0, 0, spectator_id)
        term *= wigner_generator(j1, _λ1, λ1, 1, spectator_id)
        term *= wigner_generator(j2, _λ2, λ2, 2, spectator_id)
        term *= wigner_generator(j3, _λ3, λ3, 3, spectator_id)
        terms.append(term)

    # The helicity mapping is indexed with the literal IDs 0-3 here.
    # Presumably these coincide with outer_state_ids -- verify.
    helicities = _collect_outer_state_helicities(reaction)
    summation_ranges = [
        (_λ0, helicities[0]),
        (_λ1, helicities[1]),
        (_λ2, helicities[2]),
        (_λ3, helicities[3]),
    ]
    summed_amplitude = PoolSum(sp.Add(*terms), *summation_ranges)
    return summed_amplitude, wigner_generator.angle_definitions
class _DPDAlignmentWignerGenerator:
    """Callable that creates Wigner factors and records the angles it used.

    Every angle symbol produced while generating a factor is stored, together
    with its expression, in :code:`angle_definitions`.
    """

    def __init__(self, reference_subsystem: Literal[1, 2, 3] = 1) -> None:
        self.reference_subsystem = reference_subsystem
        # Filled lazily by __call__: angle symbol -> defining expression.
        self.angle_definitions: dict[sp.Symbol, sp.Expr] = {}

    def __call__(
        self,
        j: sp.Rational | sp.Symbol,
        m: sp.Rational | sp.Symbol,
        m_prime: sp.Rational | sp.Symbol,
        rotated_state: Literal[0, 1, 2, 3],
        aligned_subsystem: Literal[1, 2, 3],
    ) -> sp.Rational | WignerD:
        """Return the Wigner factor for one state, or 1 if its spin is zero.

        For non-zero :code:`j`, the angle obtained from
        `formulate_zeta_angle` is registered in :code:`angle_definitions`
        before the factor is returned.
        """
        if j == 0:
            # No angle is formulated or recorded for spin zero.
            return sp.Rational(1)
        angle_symbol, angle_expression = formulate_zeta_angle(
            rotated_state,
            aligned_subsystem,
            self.reference_subsystem,
        )
        self.angle_definitions[angle_symbol] = angle_expression
        return Wigner.d(j, m, m_prime, angle_symbol)
# Type variable used in the signature of relabel_edge_ids(). It is defined in
# two variants because StateTransitionCollection is only imported from qrules
# when the installed qrules version is below 0.10.
if get_qrules_version() < (0, 10):
    T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology)
    """Allowed types for :func:`relabel_edge_ids`."""
else:
    # The suppressions below silence the type checkers' complaints about
    # defining T a second time.
    T = TypeVar(  # type: ignore[misc] # pyright: ignore[reportConstantRedefinition]
        "T", ReactionInfo, StateTransition, Topology
    )
    """Allowed types for :func:`relabel_edge_ids`."""
@singledispatch
def relabel_edge_ids(obj: T) -> T:  # type: ignore[reportInvalidTypeForm]
    """Relabel the edge IDs of :code:`obj`.

    This is the single-dispatch fallback: it is only reached for types that
    have no registered implementation.

    Raises:
        NotImplementedError: always, naming the type of :code:`obj`.
    """
    message = f"Cannot relabel edge IDs of a {type(obj).__name__}"
    raise NotImplementedError(message)
@relabel_edge_ids.register(ReactionInfo)
def _(obj: ReactionInfo) -> ReactionInfo:  # type: ignore[misc]
    """Build a new `ReactionInfo` whose transitions have relabeled edge IDs.

    Which constructor keyword is used depends on the installed qrules
    version: :code:`transition_groups` below 0.10, :code:`transitions`
    otherwise. The formalism is carried over unchanged.
    """
    if get_qrules_version() >= (0, 10):
        relabeled_transitions = [relabel_edge_ids(t) for t in obj.transitions]
        # no attrs.evolve() in order to call __attrs_post_init__()
        return ReactionInfo(
            transitions=relabeled_transitions,
            formalism=obj.formalism,
        )
    relabeled_groups = [
        relabel_edge_ids(group)
        for group in obj.transition_groups  # type: ignore[attr-defined]
    ]
    return ReactionInfo(  # type: ignore[call-arg]
        transition_groups=relabeled_groups,
        formalism=obj.formalism,
    )
# StateTransitionCollection is only imported for qrules < 0.10, so its
# relabel implementation is defined and registered under the same condition.
if get_qrules_version() < (0, 10):

    def __relabel_stc(obj: StateTransitionCollection) -> StateTransitionCollection:  # type: ignore[misc]
        """Relabel every transition and wrap them in a new collection."""
        relabeled_transitions = [relabel_edge_ids(item) for item in obj.transitions]
        return StateTransitionCollection(relabeled_transitions)

    relabel_edge_ids.register(StateTransitionCollection)(__relabel_stc)
def __relabel_st(obj: StateTransition) -> StateTransition:  # type: ignore[misc]
    """Return an evolved transition with relabeled topology and state keys.

    The topology is relabeled through `relabel_edge_ids`, and each key of
    :code:`obj.states` is replaced by its image under the default mapping.
    The state values themselves are kept as they are.
    """
    id_map = __get_default_relabel_mapping()
    new_topology = relabel_edge_ids(obj.topology)
    new_states = {
        id_map[edge_id]: state for edge_id, state in obj.states.items()
    }
    return attrs.evolve(obj, topology=new_topology, states=new_states)
# Register __relabel_st with the dispatcher. Which class it is registered on
# depends on the installed qrules version.
if get_qrules_version() >= (0, 10):
    # FrozenTransition is imported here rather than at the top of the file;
    # presumably it does not exist in older qrules versions -- verify.
    from qrules.topology import FrozenTransition

    relabel_edge_ids.register(FrozenTransition)(__relabel_st)
else:
    relabel_edge_ids.register(StateTransition)(__relabel_st)
@relabel_edge_ids.register(Topology)
def _(obj: Topology) -> Topology:  # type: ignore[misc]
    """Return :code:`obj.relabel_edges()` applied with the default mapping."""
    return obj.relabel_edges(__get_default_relabel_mapping())
def __get_default_relabel_mapping() -> dict[int, int]:
    """Return the mapping that shifts each of the edge IDs -1 to 3 up by one."""
    old_ids = range(-1, 4)
    new_ids = range(5)
    return dict(zip(old_ids, new_ids))
def _collect_outer_state_helicities(
    reaction: ReactionInfo,
) -> dict[int, list[sp.Rational]]:
    """Collect the spin projections of each outer state over all transitions.

    Returns:
        A mapping from outer state ID to the sorted list of distinct spin
        projections (as `sympy.Rational`) found for that state in
        :code:`reaction.transitions`.
    """
    helicities_per_state: dict[int, list[sp.Rational]] = {}
    for state_id in get_outer_state_ids(reaction):
        # A set removes duplicates before the values are sorted.
        distinct_projections = {
            sp.Rational(transition.states[state_id].spin_projection)
            for transition in reaction.transitions
        }
        helicities_per_state[state_id] = sorted(distinct_projections)
    return helicities_per_state