# PieceMatcher Prototype

Prototype and validate the `PieceMatcher` logic, which scores pairs of puzzle-piece segments identified by `SegmentId` and retrieved through a `SheetManager`.


## Import


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from loguru import logger as lg
from rich import get_console
from rich import print as rprint
from rich.console import Console

# Workaround so rich output displays properly inside this Jupyter notebook:
# tell the global console returned by get_console() not to treat the session
# as Jupyter. Background: https://github.com/Textualize/rich/issues/3483
# To have rich render every cell's output instead, use `%load_ext rich`.
console: Console = get_console()
console.is_jupyter = False

In [None]:
from pydantic import BaseModel
from snap_fit.data_models.segment_id import SegmentId
from snap_fit.image.segment_matcher import SegmentMatcher
from snap_fit.puzzle.sheet_manager import SheetManager
from snap_fit.config.types import EdgePos
import itertools
from unittest.mock import MagicMock, patch

## Params and config


In [None]:
from snap_fit.params.snap_fit_params import get_snap_fit_paths

# Fetch the snap_fit path configuration and print it so the resolved
# locations can be checked before running the prototype below.
sf_paths = get_snap_fit_paths()
rprint(sf_paths)

## Develop and prototype


In [None]:
class MatchResult(BaseModel):
    """Outcome of comparing two segments.

    Attributes:
        seg_id1: First segment of the compared pair.
        seg_id2: Second segment of the compared pair.
        similarity: Score produced by the matcher; lower values indicate a
            better fit.
    """

    seg_id1: SegmentId
    seg_id2: SegmentId
    similarity: float

    @property
    def pair(self) -> frozenset[SegmentId]:
        """Order-independent key for this match, usable for symmetric lookup."""
        return frozenset((self.seg_id1, self.seg_id2))

    def get_other(self, seg_id: SegmentId) -> SegmentId:
        """Return the partner of ``seg_id`` within this match.

        Raises:
            ValueError: If ``seg_id`` is neither of the two matched segments.
        """
        is_first = seg_id == self.seg_id1
        if not is_first and seg_id != self.seg_id2:
            raise ValueError(f"SegmentId {seg_id} not in this match result")
        return self.seg_id2 if is_first else self.seg_id1


class PieceMatcher:
    """Match segments across puzzle pieces and cache the results.

    Each unordered pair of segments is scored at most once: results are cached
    under a ``frozenset`` key, so ``(a, b)`` and ``(b, a)`` share the same
    :class:`MatchResult`. Lower similarity means a better match.
    """

    # Similarity assigned when either segment cannot be loaded. It is large so
    # such pairs sort after every real match (lower similarity is better).
    MISSING_SIMILARITY: float = 1e6

    def __init__(self, manager: SheetManager):
        self.manager = manager
        # Every computed result, in insertion order; sorted by similarity on demand.
        self._results: list[MatchResult] = []
        # Symmetric cache: frozenset({id1, id2}) -> result.
        self._lookup: dict[frozenset[SegmentId], MatchResult] = {}

    def match_pair(self, id1: SegmentId, id2: SegmentId) -> MatchResult:
        """Score two segments, caching the result symmetrically.

        Args:
            id1: First segment to compare.
            id2: Second segment to compare.

        Returns:
            The cached result if this unordered pair was already scored,
            otherwise a newly computed (and now cached) result. If either
            segment cannot be fetched from the manager, the result carries
            ``MISSING_SIMILARITY``.
        """
        pair = frozenset({id1, id2})
        cached = self._lookup.get(pair)
        if cached is not None:
            return cached

        seg1 = self.manager.get_segment(id1)
        seg2 = self.manager.get_segment(id2)

        if seg1 is None or seg2 is None:
            lg.warning(f"Could not find segments for {id1} or {id2}")
            # NOTE(review): this placeholder is cached too, so the pair is not
            # re-scored even if the segment becomes available later.
            res = MatchResult(
                seg_id1=id1, seg_id2=id2, similarity=self.MISSING_SIMILARITY
            )
        else:
            similarity = SegmentMatcher(seg1, seg2).compute_similarity()
            res = MatchResult(seg_id1=id1, seg_id2=id2, similarity=float(similarity))

        self._results.append(res)
        self._lookup[pair] = res
        return res

    def _sort_results(self) -> None:
        """Sort cached results in place by ascending similarity (best first)."""
        self._results.sort(key=lambda res: res.similarity)

    def match_all(self) -> None:
        """Match every segment against all segments from other pieces.

        Pairs seen from both sides are scored only once thanks to the cache
        in :meth:`match_pair`. Results are sorted best-first afterwards.
        """
        all_ids = self.manager.get_segment_ids_all()
        lg.info(f"Matching {len(all_ids)} segments...")

        for id1 in all_ids:
            # The manager supplies candidate segments belonging to other pieces.
            for id2 in self.manager.get_segment_ids_other_pieces(id1):
                self.match_pair(id1, id2)

        self._sort_results()
        lg.info(f"Completed {len(self._results)} matches.")

    def get_top_matches(self, n: int = 10) -> list[MatchResult]:
        """Return the ``n`` best (lowest-similarity) matches.

        Results are re-sorted first, so the answer is correct even if
        :meth:`match_pair` was called after (or without) :meth:`match_all`.
        ``list.sort`` is linear on already-sorted input, so this is cheap when
        nothing changed.

        Args:
            n: Number of matches to return; ``n <= 0`` yields an empty list.
        """
        if n <= 0:
            return []
        self._sort_results()
        return self._results[:n]

    def get_matches_for_piece(self, sheet_id: str, piece_id: int) -> list[MatchResult]:
        """Return all cached matches with a segment on the given piece.

        Args:
            sheet_id: Sheet the piece belongs to.
            piece_id: Piece identifier within that sheet.

        Returns:
            Matching results in the current order of the internal result list.
        """

        def on_piece(seg_id: SegmentId) -> bool:
            return seg_id.sheet_id == sheet_id and seg_id.piece_id == piece_id

        return [
            res for res in self._results if on_piece(res.seg_id1) or on_piece(res.seg_id2)
        ]


In [None]:
# Test with mock data first to verify logic
from unittest.mock import MagicMock, patch

# Real SegmentId instances; only the SheetManager and SegmentMatcher are mocked.
id_a0_l = SegmentId(sheet_id="A", piece_id=0, edge_pos="left")
id_b1_l = SegmentId(sheet_id="B", piece_id=1, edge_pos="left")

# Mock SheetManager: every segment lookup succeeds, and each segment's only
# candidate from another piece is the other segment.
mock_manager = MagicMock(spec=SheetManager)
mock_manager.get_segment.return_value = MagicMock()  # Just something not None
mock_manager.get_segment_ids_all.return_value = [id_a0_l, id_b1_l]
mock_manager.get_segment_ids_other_pieces.side_effect = (
    lambda x: [id_b1_l] if x == id_a0_l else [id_a0_l]
)

matcher = PieceMatcher(mock_manager)

# Patch SegmentMatcher in the notebook's namespace, where PieceMatcher looks it up.
with patch("__main__.SegmentMatcher") as mock_seg_matcher:
    mock_seg_matcher.return_value.compute_similarity.return_value = 0.5

    res1 = matcher.match_pair(id_a0_l, id_b1_l)
    res2 = matcher.match_pair(id_b1_l, id_a0_l)

    rprint(f"Match 1: {res1}")
    rprint(f"Match 2: {res2}")

    # Verify symmetry
    assert res1 is res2
    rprint("Symmetry verified: res1 is res2")

    # Verify the mocked score propagated and the reversed pair was served from
    # the cache instead of constructing a second SegmentMatcher.
    assert res1.similarity == 0.5
    mock_seg_matcher.assert_called_once()

    # Verify lookup
    assert len(matcher._results) == 1
    assert len(matcher._lookup) == 1
    rprint(f"Lookup verified: results length is {len(matcher._results)}")


In [None]:
# Test match_all (reuses mock_manager from the previous cell)
matcher = PieceMatcher(mock_manager)

with patch("__main__.SegmentMatcher") as mock_seg_matcher:
    mock_seg_matcher.return_value.compute_similarity.return_value = 0.1
    matcher.match_all()

    rprint(f"Total results: {len(matcher._results)}")
    rprint(f"Top matches: {matcher.get_top_matches(2)}")

    # Both segments list each other as candidates, but the symmetric cache
    # means only one comparison is actually performed.
    assert len(matcher._results) == 1
    mock_seg_matcher.assert_called_once()
    assert matcher.get_top_matches(2)[0].similarity == 0.1
