diff --git a/src/pyrecest/utils/multisession_assignment.py b/src/pyrecest/utils/multisession_assignment.py index b164d65fc..446522191 100644 --- a/src/pyrecest/utils/multisession_assignment.py +++ b/src/pyrecest/utils/multisession_assignment.py @@ -210,7 +210,6 @@ def solve_multisession_assignment( # pylint: disable=too-many-locals left_nodes, right_nodes, edge_gains, adjusted_costs = _build_candidate_edges( normalized_pairwise_costs, session_sizes_map, - session_positions, session_offsets, start_cost=start_cost, end_cost=end_cost, @@ -378,6 +377,15 @@ def _validate_scalar_cost(name: str, value: float) -> None: raise ValueError(f"{name} must be finite.") +def _normalize_session_index(session_idx: Any) -> int: + session_idx = int(session_idx) + if session_idx < 0: + raise ValueError( + f"Session indices must be non-negative, got {session_idx}." + ) + return session_idx + + def _normalize_pairwise_costs( pairwise_costs: PairwiseCostsInput, ) -> dict[tuple[int, int], Any]: @@ -388,7 +396,8 @@ def _normalize_pairwise_costs( raise ValueError( "Each pairwise-cost key must contain two session indices." ) - source_session, target_session = int(key[0]), int(key[1]) + source_session = _normalize_session_index(key[0]) + target_session = _normalize_session_index(key[1]) if source_session >= target_session: raise ValueError( "Pairwise-cost keys must satisfy source_session < target_session." @@ -414,9 +423,9 @@ def _normalize_session_sizes( if session_sizes is None: return {} if isinstance(session_sizes, Mapping): - normalized = { - int(session_idx): int(size) for session_idx, size in session_sizes.items() - } + normalized = {} + for session_idx, size in session_sizes.items(): + normalized[_normalize_session_index(session_idx)] = int(size) else: normalized = { session_idx: int(size) for session_idx, size in enumerate(session_sizes) @@ -485,7 +494,6 @@ def _build_observation_index( def _build_candidate_edges( # pylint: disable=too-many-arguments,too-many-locals pairwise_costs: Mapping[tuple[int, int], Any], session_sizes: Mapping[int, int], - session_positions: Mapping[int, int], session_offsets: Mapping[int, int], *, start_cost: float, @@ -506,9 +514,7 @@ def _build_candidate_edges( # pylint: disable=too-many-arguments,too-many-local f"{cost_matrix.shape}, expected {expected_shape}." ) - source_position = session_positions[source_session] - target_position = session_positions[target_session] - gap = target_position - source_position - 1 + gap = int(target_session) - int(source_session) - 1 if gap < 0: raise ValueError( "Session indices must define a forward-in-time edge ordering." @@ -722,14 +728,13 @@ def _iter_track_items(track: TrackInput) -> list[Observation]: if isinstance(track, Mapping): items = list(track.items()) else: - items = [ - (int(session_idx), int(detection_idx)) - for session_idx, detection_idx in track - ] - items.sort(key=lambda item: item[0]) - return [ - (int(session_idx), int(detection_idx)) for session_idx, detection_idx in items + items = list(track) + items = [ + (_normalize_session_index(session_idx), int(detection_idx)) + for session_idx, detection_idx in items ] + items.sort(key=lambda item: item[0]) + return items __all__ = [ diff --git a/tests/test_multisession_assignment.py b/tests/test_multisession_assignment.py index 4b62fc802..a1f1933ea 100644 --- a/tests/test_multisession_assignment.py +++ b/tests/test_multisession_assignment.py @@ -80,6 +80,23 @@ def test_cross_gap_linking_is_supported(self): self.assertAlmostEqual(result.total_cost, 8.8) self.assertEqual(result.matched_edges, [((0, 0), (2, 0), 0.8)]) + @unittest.skipIf( + __backend_name__ == "jax", + reason="Not supported on this backend", + ) + def test_gap_penalty_uses_numeric_session_indices_when_sizes_are_inferred(self): + result = solve_multisession_assignment( + {(0, 2): array([[0.3]], dtype=float)}, + start_cost=4.0, + end_cost=4.0, + gap_penalty=0.5, + ) + + expected_tracks = [((0, 0), (2, 0))] + self.assertEqual(self._canonical_tracks(result.tracks), expected_tracks) + self.assertAlmostEqual(result.total_cost, 8.8) + self.assertEqual(result.matched_edges, [((0, 0), (2, 0), 0.8)]) + @unittest.skipIf( __backend_name__ == "jax", reason="Not supported on this backend", @@ -140,6 +157,38 @@ def test_rejects_inconsistent_session_sizes(self): with self.assertRaises(ValueError): solve_multisession_assignment(pairwise_costs) + @unittest.skipIf( + __backend_name__ == "jax", + reason="Not supported on this backend", + ) + def test_rejects_negative_pairwise_session_indices(self): + with self.assertRaisesRegex( + ValueError, + "Session indices must be non-negative", + ): + solve_multisession_assignment( + {(-1, 0): array([[0.1]], dtype=float)}, + start_cost=1.0, + end_cost=1.0, + ) + + @unittest.skipIf( + __backend_name__ == "jax", + reason="Not supported on this backend", + ) + def test_rejects_negative_explicit_session_indices(self): + with self.assertRaisesRegex( + ValueError, + "Session indices must be non-negative", + ): + solve_multisession_assignment({}, session_sizes={-1: 1}) + + with self.assertRaisesRegex( + ValueError, + "Session indices must be non-negative", + ): + tracks_to_session_labels([{-1: 0}]) + @unittest.skipIf( __backend_name__ == "jax", reason="Not supported on this backend",