diff --git a/src/qrules/topology.py b/src/qrules/topology.py index bb69981b..98634daf 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -777,7 +777,7 @@ def intermediate_states(self) -> dict[int, EdgeType]: def filter_states(self, edge_ids: Iterable[int]) -> dict[int, EdgeType]: """Filter `states` by a selection of :code:`edge_ids`.""" - return {i: self.states[i] for i in edge_ids} + return {i: self.states[i] for i in edge_ids if i in self.states} @implement_pretty_repr diff --git a/tests/unit/test_transition.py b/tests/unit/test_transition.py index 72b00924..9f75d501 100644 --- a/tests/unit/test_transition.py +++ b/tests/unit/test_transition.py @@ -24,6 +24,23 @@ NAMESPACE_WITH_FRACTIONS["Fraction"] = Fraction +class TestMutableTransition: + def test_intermediate_states(self): + stm = StateTransitionManager( + initial_state=[("J/psi(1S)", [-1, +1])], + final_state=["K0", "Sigma+", "p~"], + allowed_intermediate_particles=["N(1700)", "Sigma(1750)"], + formalism="helicity", + mass_conservation_factor=0, + ) + stm.set_allowed_interaction_types([InteractionType.STRONG, InteractionType.EM]) + problem_sets = stm.create_problem_sets() + some_problem_set = problem_sets[3600.0][0] + assert set(some_problem_set.initial_facts.initial_states) == {-1} + assert set(some_problem_set.initial_facts.final_states) == {0, 1, 2} + assert set(some_problem_set.initial_facts.intermediate_states) == set() + + class TestReactionInfo: def test_properties(self, reaction: ReactionInfo): assert reaction.initial_state[-1].name == "J/psi(1S)"