In [1]:
from __future__ import annotations

import csv
import os
import pickle
from copy import deepcopy
from itertools import combinations
from typing import Any, TypeAlias

import numpy as np
import numpy.typing as npt
import ray
from sklearn.neighbors import KDTree
from tqdm import tqdm

from utils import preview_contents

In [2]:
Point: TypeAlias = npt.NDArray[np.floating]
PointArray: TypeAlias = npt.NDArray[np.floating]

Side: TypeAlias = dict[str, str | PointArray]
Piece: TypeAlias = dict[str, str | list[Side]]

# Implementation

## Load data

In [3]:
stage_02_out_path = os.path.join("out", "02", "out.pickle")

with open(stage_02_out_path, "rb") as f:
    DATA = pickle.load(f)

print("Loaded:")
preview_contents(DATA)

print("\nConverting...")
for piece in tqdm(DATA):
    for side in piece["sides"]:
        side["corners"] = np.array(side["corners"], dtype=np.float32)
        side["points"] = np.array(side["points"], dtype=np.float32)

print("\nConverted:")
preview_contents(DATA)

Loaded:
[{'name': <class 'str'>,
  'sides': [{'corners': ((<class 'float'>,),),
             'points': [(<class 'float'>,)],
             'type': <class 'str'>}]}]

Converting...


100%|██████████| 40/40 [00:00<00:00, 741.52it/s]


Converted:
[{'name': <class 'str'>,
  'sides': [{'corners': <class 'numpy.ndarray'>,
             'points': <class 'numpy.ndarray'>,
             'type': <class 'str'>}]}]





In [4]:
SAMPLE = [
    deepcopy(piece)
    for piece in DATA
    if piece["name"] in ["0001", "0002", "0003", "0004", "0005", "0006"]
]

# Side match score

Calculate the match score between two sets of points using the Iterative Closest Point (ICP) algorithm.
The match score represents the average distance between corresponding points in the source and destination sets.

**Note:**

The order of the source and destination arrays may affect the result.

**Inspired by:**
- https://nghiaho.com/?page_id=671
- https://github.com/ClayFlannigan/icp


In [5]:
def icp(src: np.ndarray, dst: np.ndarray) -> float:
    src_rows, src_dims = src.shape
    dst_rows, dst_dims = dst.shape

    # Should work for points with any dimensions
    n = max(src_dims, dst_dims)

    # Make points homogenious (this also ensures the same dimensionality for src and dst)
    src_h = np.zeros((src_rows, n + 1), dtype=np.float32)
    dst_h = np.zeros((dst_rows, n + 1), dtype=np.float32)
    src_h[:, n] = 1
    dst_h[:, n] = 1
    src_h[:, :src_dims] = src
    dst_h[:, :dst_dims] = dst

    prev_error = np.inf  # error history for early stop
    score = np.inf  # score (average error)

    tree = KDTree(dst_h, leaf_size=1, p=1)

    iterations = 1000
    for _ in range(iterations):
        # Guess which points in dst correspond to points in src based on shortest distance
        distances, idxs = tree.query(src_h, k=1)

        # Calculate error as average distance between points in dst and src
        error = np.mean(distances)

        # Update the score
        score = error if error < score else score

        # Early stop if there is no improvement or error gets worse
        if np.abs(prev_error - error) < 10e-6:
            break
        prev_error = error

        # Select subset of points from dst for transformations
        dst_s = dst_h[idxs.flatten()]

        # Translate to centroids
        src_centroid = np.mean(src_h, axis=0)
        dst_centroid = np.mean(dst_s, axis=0)
        src_c = src_h - src_centroid
        dst_c = dst_s - dst_centroid

        # Calculate rotation matrix
        h = src_c.T @ dst_c
        u, _, v = np.linalg.svd(h)
        r = v.T @ u.T

        # Special reflection case
        if np.linalg.det(r) < 0:
            v[-1, :] *= -1
            r = v.T @ u.T

        # Translation vector
        t = dst_centroid - r @ src_centroid.T

        # Apply transformations to src
        src_h = (r @ src_h.T).T + t
    else:
        # Final score update in case of no early stop
        distances, idxs = tree.query(src_h, k=1)
        error = np.mean(distances)
        score = error if error < score else score

    return score

In [6]:
# Sanity check

a = np.array([[0, 1], [1, 1], [2, 1], [3, 2]], dtype=float)
b = np.array([[1, 1], [2, 2], [2, 3], [2, 4]], dtype=float)
print("score:", icp(a, b), icp(b, a))

a = np.array([[0, 1], [1, 1], [2, 1], [3, 2]], dtype=float)
b = np.array([[2, -1], [2, 0], [2, 1], [1, 2]], dtype=float)
print("score:", icp(a, b), icp(b, a))

a = np.array([[0, 1], [1, 1], [2, 1], [3, 1]], dtype=float)
b = np.array([[1, 1], [2, 1], [3, 1], [4, 1]], dtype=float)
print("score:", icp(a, b), icp(b, a))

a = np.array([[0, 1], [1, 1], [2, 1], [3, 1]], dtype=float)
b = np.array([[0, 2], [1, 2], [2, 2], [3, 2]], dtype=float)
print("score:", icp(a, b), icp(b, a))

score: 0.75 0.3278737850487232
score: 1.25 0.608783908188343
score: 0.25 0.25
score: 0.0 0.0


In [7]:
def get_score(src: np.ndarray, dst: np.ndarray, reverse=True) -> float:
    # If checking hole/knob match, one has to be reversed before ICP score
    if reverse:
        src = -src
        src[:, 0] -= src[:, 0].min()
        src[:, 1] -= src[:, 1].min()

    # Make sure to try matching larger side to smaller (allows to include error caused by length difference)
    if len(dst) > len(src):
        dst, src = src, dst

    # Subsample source, to speed up calculations and cut margin to avoid noise at the ends of the side
    keep_percentage = 0.15
    subsampling_step = int(1 / keep_percentage)
    cut_margin = 15

    # TODO: Test cutting margin based on y corrdinates, not fixed number of points

    return icp(
        src[cut_margin:-cut_margin:subsampling_step], dst[cut_margin:-cut_margin]
    )

In [8]:
a = next(piece for piece in DATA if piece["name"] == "0001")
b = next(piece for piece in DATA if piece["name"] == "0002")

reverse = a is not b

for a_idx, side_a in enumerate(a["sides"]):
    for b_idx, side_b in enumerate(b["sides"]):
        print(
            f"{a_idx = } ({side_a['type']}), {b_idx = } ({side_b['type']});"
            f" score: {get_score(side_a['points'], side_b['points'], reverse):7.3f}"
        )

a_idx = 0 (edge), b_idx = 0 (knob); score: 142.321
a_idx = 0 (edge), b_idx = 1 (knob); score:  97.497
a_idx = 0 (edge), b_idx = 2 (edge); score:  25.628
a_idx = 0 (edge), b_idx = 3 (hole); score: 102.114
a_idx = 1 (edge), b_idx = 0 (knob); score: 136.931
a_idx = 1 (edge), b_idx = 1 (knob); score:  94.650
a_idx = 1 (edge), b_idx = 2 (edge); score:   7.387
a_idx = 1 (edge), b_idx = 3 (hole); score:  97.460
a_idx = 2 (knob), b_idx = 0 (knob); score:  39.949
a_idx = 2 (knob), b_idx = 1 (knob); score:  10.288
a_idx = 2 (knob), b_idx = 2 (edge); score:  96.826
a_idx = 2 (knob), b_idx = 3 (hole); score:  64.598
a_idx = 3 (knob), b_idx = 0 (knob); score:  50.261
a_idx = 3 (knob), b_idx = 1 (knob); score:  90.503
a_idx = 3 (knob), b_idx = 2 (edge); score:  97.162
a_idx = 3 (knob), b_idx = 3 (hole); score:   1.524


# Building comparison index

Compare all knobs with all holes to build an index with match scores. Scores are precomputed and stored to avoid recalculating them while solving the jigsaw.

Comparison index is a csv file with following columns: \
`piece_a_name`, `side_a_idx`, `side_a_type`, `piece_b_name`, `side_b_idx`, `side_b_type`, `score`

In [9]:
def build_index(pieces: list[Piece], index_file_path: str) -> None:
    if os.path.exists(index_file_path):
        os.remove(index_file_path)
    if not os.path.exists(os.path.dirname(index_file_path)):
        os.makedirs(os.path.dirname(index_file_path))

    print("Processing...")

    n_pieces = len(pieces)
    n_comparisons = 16 * (n_pieces - 1) * n_pieces // 2

    with open(index_file_path, "wt") as f, tqdm(total=n_comparisons) as pbar:
        writer = csv.writer(f)

        n = len(pieces)
        for i in range(0, n):
            for j in range(i + 1, n):
                a = pieces[i]
                b = pieces[j]

                for side_a_idx, side_a in enumerate(a["sides"]):
                    for side_b_idx, side_b in enumerate(b["sides"]):
                        if (side_a["type"] == "hole" and side_b["type"] == "knob") or (
                            side_a["type"] == "knob" and side_b["type"] == "hole"
                        ):
                            score = get_score(side_a["points"], side_b["points"])
                            writer.writerow(
                                [
                                    # fmt: off
                                    a['name'], side_a_idx, side_a['type'],
                                    b['name'], side_b_idx, side_b['type'],
                                    score
                                    # fmt: on
                                ]
                            )

                        pbar.update()

In [10]:
build_index(SAMPLE, os.path.join("out", "03", "sample_index.csv"))

Processing...


100%|██████████| 240/240 [00:01<00:00, 219.73it/s]


## Distributed with Ray

In [11]:
ray.init()

2024-01-05 21:45:03,201	INFO worker.py:1724 -- Started a local Ray instance.


0,1
Python version:,3.11.6
Ray version:,2.9.0


In [12]:
@ray.remote
def compare_pieces(piece_a: Piece, piece_b: Piece) -> list[str | int]:
    result = []
    for side_a_idx, side_a in enumerate(piece_a["sides"]):
        for side_b_idx, side_b in enumerate(piece_b["sides"]):
            if (side_a["type"] == "hole" and side_b["type"] == "knob") or (
                side_a["type"] == "knob" and side_b["type"] == "hole"
            ):
                result.append(
                    [
                        # fmt: off
                        piece_a["name"], side_a_idx, side_a["type"],
                        piece_b["name"], side_b_idx, side_b["type"],
                        get_score(side_a["points"], side_b["points"]),
                        # fmt: on
                    ]
                )
    return result


def build_index_ray(pieces: list[Piece], index_file_path: str) -> None:
    ray_data_ids = [ray.put(piece) for piece in pieces]
    tasks = [
        compare_pieces.remote(ray_id_piece_a, ray_id_piece_b)
        for ray_id_piece_a, ray_id_piece_b in combinations(ray_data_ids, 2)
    ]

    def yield_ray(tasks: list[ray.ObjectRef]) -> Any:
        while tasks:
            done, tasks = ray.wait(tasks)
            yield ray.get(done[0])

    with open(index_file_path, "wt") as f:
        writer = csv.writer(f)
        for result in tqdm(yield_ray(tasks), total=len(tasks)):
            writer.writerows(result)


build_index_ray(SAMPLE, os.path.join("out", "03", "sample_index_ray.csv"))

100%|██████████| 15/15 [00:00<00:00, 15.30it/s]


In [13]:
# Sanity check (compare single process vs ray implementation)
with (
    open(os.path.join("out", "03", "sample_index_ray.csv"), "rt") as index_mp,
    open(os.path.join("out", "03", "sample_index.csv"), "rt") as index,
):
    index_mp = set(index_mp.read().strip().split("\n"))
    index = set(index.read().strip().split("\n"))

    assert index_mp == index

In [None]:
ray.shutdown()

# Full dataset

In [None]:
ray.init()

In [14]:
build_index_ray(DATA, os.path.join("out", "03", "index.csv"))

100%|██████████| 780/780 [00:20<00:00, 38.29it/s]


In [15]:
ray.shutdown()