Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions src/pyrecest/utils/multisession_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ def solve_multisession_assignment( # pylint: disable=too-many-locals
left_nodes, right_nodes, edge_gains, adjusted_costs = _build_candidate_edges(
normalized_pairwise_costs,
session_sizes_map,
session_positions,
session_offsets,
start_cost=start_cost,
end_cost=end_cost,
Expand Down Expand Up @@ -378,6 +377,15 @@ def _validate_scalar_cost(name: str, value: float) -> None:
raise ValueError(f"{name} must be finite.")


def _normalize_session_index(session_idx: Any) -> int:
session_idx = int(session_idx)
if session_idx < 0:
raise ValueError(
f"Session indices must be non-negative, got {session_idx}."
)
return session_idx


def _normalize_pairwise_costs(
pairwise_costs: PairwiseCostsInput,
) -> dict[tuple[int, int], Any]:
Expand All @@ -388,7 +396,8 @@ def _normalize_pairwise_costs(
raise ValueError(
"Each pairwise-cost key must contain two session indices."
)
source_session, target_session = int(key[0]), int(key[1])
source_session = _normalize_session_index(key[0])
target_session = _normalize_session_index(key[1])
if source_session >= target_session:
raise ValueError(
"Pairwise-cost keys must satisfy source_session < target_session."
Expand All @@ -414,9 +423,9 @@ def _normalize_session_sizes(
if session_sizes is None:
return {}
if isinstance(session_sizes, Mapping):
normalized = {
int(session_idx): int(size) for session_idx, size in session_sizes.items()
}
normalized = {}
for session_idx, size in session_sizes.items():
normalized[_normalize_session_index(session_idx)] = int(size)
else:
normalized = {
session_idx: int(size) for session_idx, size in enumerate(session_sizes)
Expand Down Expand Up @@ -485,7 +494,6 @@ def _build_observation_index(
def _build_candidate_edges( # pylint: disable=too-many-arguments,too-many-locals
pairwise_costs: Mapping[tuple[int, int], Any],
session_sizes: Mapping[int, int],
session_positions: Mapping[int, int],
session_offsets: Mapping[int, int],
*,
start_cost: float,
Expand All @@ -506,9 +514,7 @@ def _build_candidate_edges( # pylint: disable=too-many-arguments,too-many-local
f"{cost_matrix.shape}, expected {expected_shape}."
)

source_position = session_positions[source_session]
target_position = session_positions[target_session]
gap = target_position - source_position - 1
gap = int(target_session) - int(source_session) - 1
if gap < 0:
raise ValueError(
"Session indices must define a forward-in-time edge ordering."
Expand Down Expand Up @@ -722,14 +728,13 @@ def _iter_track_items(track: TrackInput) -> list[Observation]:
if isinstance(track, Mapping):
items = list(track.items())
else:
items = [
(int(session_idx), int(detection_idx))
for session_idx, detection_idx in track
]
items.sort(key=lambda item: item[0])
return [
(int(session_idx), int(detection_idx)) for session_idx, detection_idx in items
items = list(track)
items = [
(_normalize_session_index(session_idx), int(detection_idx))
for session_idx, detection_idx in items
]
items.sort(key=lambda item: item[0])
return items


__all__ = [
Expand Down
49 changes: 49 additions & 0 deletions tests/test_multisession_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,23 @@ def test_cross_gap_linking_is_supported(self):
self.assertAlmostEqual(result.total_cost, 8.8)
self.assertEqual(result.matched_edges, [((0, 0), (2, 0), 0.8)])

@unittest.skipIf(
__backend_name__ == "jax",
reason="Not supported on this backend",
)
def test_gap_penalty_uses_numeric_session_indices_when_sizes_are_inferred(self):
result = solve_multisession_assignment(
{(0, 2): array([[0.3]], dtype=float)},
start_cost=4.0,
end_cost=4.0,
gap_penalty=0.5,
)

expected_tracks = [((0, 0), (2, 0))]
self.assertEqual(self._canonical_tracks(result.tracks), expected_tracks)
self.assertAlmostEqual(result.total_cost, 8.8)
self.assertEqual(result.matched_edges, [((0, 0), (2, 0), 0.8)])

@unittest.skipIf(
__backend_name__ == "jax",
reason="Not supported on this backend",
Expand Down Expand Up @@ -140,6 +157,38 @@ def test_rejects_inconsistent_session_sizes(self):
with self.assertRaises(ValueError):
solve_multisession_assignment(pairwise_costs)

@unittest.skipIf(
__backend_name__ == "jax",
reason="Not supported on this backend",
)
def test_rejects_negative_pairwise_session_indices(self):
with self.assertRaisesRegex(
ValueError,
"Session indices must be non-negative",
):
solve_multisession_assignment(
{(-1, 0): array([[0.1]], dtype=float)},
start_cost=1.0,
end_cost=1.0,
)

@unittest.skipIf(
__backend_name__ == "jax",
reason="Not supported on this backend",
)
def test_rejects_negative_explicit_session_indices(self):
with self.assertRaisesRegex(
ValueError,
"Session indices must be non-negative",
):
solve_multisession_assignment({}, session_sizes={-1: 1})

with self.assertRaisesRegex(
ValueError,
"Session indices must be non-negative",
):
tracks_to_session_labels([{-1: 0}])

@unittest.skipIf(
__backend_name__ == "jax",
reason="Not supported on this backend",
Expand Down
Loading