In [None]:
import cv2 as cv

import numpy as np
import numpy.typing as npt
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.collections as mc
import matplotlib.patches as mp

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from functools import total_ordering
from numpy.random import default_rng
from pathlib import Path
from queue import Queue, PriorityQueue
from typing import *

# Seeded generator used for shuffling (see `shuffle_grid`), so that
# Restart & Run All reproduces the same shuffled puzzle every time.
rng = default_rng(seed=5843)

## Utilities

In [None]:
def fix_color(src: npt.NDArray) -> npt.NDArray:
    """Convert an OpenCV BGR image to RGB channel order (what matplotlib expects)."""
    rgb = cv.cvtColor(src, cv.COLOR_BGR2RGB)
    return rgb

In [None]:
def plot_image(
        img: npt.NDArray, cmap: str | None = None, *,
        title: str | None = None, format: int = cv.COLOR_BGR2RGB
    ):
    """Show a single image on its own matplotlib figure, without axis ticks.

    Args:
        img: image to display. Multi-channel images are converted with
            `cv.cvtColor(img, format)` unless `cmap` is given; 2-D
            (single-channel) images are shown with a colormap instead.
        cmap: colormap to use. When given, `img` is shown as-is (no color
            conversion) with a fixed 0..255 value range.
        title: optional axes title.
        format: OpenCV color-conversion code applied to multi-channel images
            (the default converts OpenCV's BGR order to matplotlib's RGB).
    """
    ax: plt.Axes
    _, ax = plt.subplots(dpi=140)

    ax.set(xticks=[], yticks=[])

    if title is not None:
        ax.set_title(title)

    # A 2-D array has no channels to convert; cv.cvtColor with a 3-channel code
    # such as the default BGR2RGB rejects it, so fall back to a gray colormap.
    if cmap is None and img.ndim == 2:
        cmap = 'gray'

    if cmap is None:
        ax.imshow(cv.cvtColor(img, format))
    else:
        ax.imshow(img, cmap, vmin=0, vmax=255)

In [None]:
def plot_images(images: Sequence[tuple[str, npt.NDArray] | None], columns: int = 3, *,
                cmap: str | None = None, title: str | None = None, format = cv.COLOR_BGR2RGB,
                cell_size: tuple[float, float]=(4., 3.), dpi=120) -> None:
    """Plot several images in a grid of subplots.

    Args:
        images: (subtitle, image) pairs, or None to leave a grid cell empty.
            Any iterable is accepted; it is materialised once.
        columns: number of grid columns (>= 1); the row count is derived from it.
        cmap: colormap for 2-D (single-channel) images; 'gray' when None.
        title: optional figure-level title.
        format: OpenCV color-conversion code applied to multi-channel images.
        cell_size: (width, height) of one grid cell, in inches.
        dpi: figure resolution.

    Raises:
        ValueError: if `columns` is smaller than 1.
    """
    if columns < 1:
        raise ValueError(f'columns must be >= 1, got {columns}')

    images = list(images)
    # Ceiling division: just enough rows for every entry, and at least one.
    rows = max(1, (len(images) + columns - 1) // columns)
    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid.
    fig, axs = plt.subplots(rows, columns, figsize=(columns * cell_size[0], rows * cell_size[1]),
                            layout='constrained', dpi=dpi, squeeze=False)

    if title is not None:
        fig.suptitle(title)

    ax: plt.Axes
    # Hide every cell first; only cells that receive an image become visible.
    for ax in axs.flat:
        ax.set_visible(False)

    for ax, entry in zip(axs.flat, images):
        if entry is None:
            continue
        subtitle, img = entry

        ax.set(xticks=[], yticks=[], title=subtitle, visible=True)
        if img.ndim == 2:
            # Single channel: nothing to color-convert, render with a colormap.
            ax.imshow(img, cmap=cmap if cmap is not None else 'gray', vmin=0, vmax=255)
        else:
            ax.imshow(cv.cvtColor(img, format))

# Split, Shuffle and Merge

In [None]:
image_name = 'uni-1.jpg'
image_path = '../images/' + image_name
image = cv.imread(image_path)
# cv.imread does not raise on a missing or unreadable file; it returns None,
# which would otherwise only surface later as a confusing error downstream.
if image is None:
    raise FileNotFoundError(f'Could not read image: {image_path}')
plot_image(image, title=image_name)

In [None]:
def split_grid(img: npt.NDArray, rows = 2, columns = 2, *, name: str | None = None) -> tuple[npt.NDArray, npt.NDArray]:
    rows_steps = np.linspace(0, img.shape[0], rows + 1, dtype=np.uint) # height
    columns_steps = np.linspace(0, img.shape[1], columns + 1, dtype=np.uint) # width

    cells = np.empty((rows, columns), dtype=object)
    for i in range(rows):
        for j in range(columns):
            cells[i, j] = img[
                rows_steps[i]:rows_steps[i+1],
                columns_steps[j]:columns_steps[j+1],
            ]
    
    prefix = '' if name is None else f'{name} • '
    cells_coords = np.stack(np.meshgrid(np.arange(columns), np.arange(rows)))
    cells_names = np.apply_along_axis(
        lambda x: f'{prefix}({x[1]:3d}, {x[0]:3d})',
        0, cells_coords)
    
    return cells, cells_names

cells, cells_names = split_grid(image, rows=2, columns=3, name=image_name)
(cells.shape, cells_names.shape)

In [None]:
plot_images(
    zip(cells_names.flat, cells.flat),
    columns=cells.shape[1],
    cell_size=(4., 4.5),
    dpi=80,
)

In [None]:
def shuffle_grid(cells: npt.NDArray) -> npt.NDArray:
    cells_shuffled = cells.flatten()
    rng.shuffle(cells_shuffled)
    return cells_shuffled.reshape(cells.shape)

cells_shuffled = shuffle_grid(cells)

plot_images(
    zip(cells_names.flat, cells_shuffled.flat),
    columns=cells_shuffled.shape[1],
    title='After shuffling',
    cell_size=(4., 4.5),
    dpi=80,
)

In [None]:
def merge_grid(cells: npt.NDArray) -> npt.NDArray:
    """Reassemble a 2-D grid of image tiles into a single image.

    Each row of `cells` is joined side by side (axis 1), then the resulting
    strips are stacked top to bottom (axis 0). Tiles must have compatible
    shapes, as produced by `split_grid`.
    """
    strips = [np.concatenate(row, 1) for row in cells]
    return np.concatenate(strips)

image_shuffled = merge_grid(cells_shuffled)
shuffled_title = f'"{image_name}" after split, shuffle and merge'
plot_image(image_shuffled, title=shuffled_title)

In [None]:
# Save the shuffled puzzle. cv.imwrite can signal failure (e.g. a missing
# output directory) only through its boolean result, so create the directory
# first and check the result explicitly.
output_dir = Path('out')
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / f'{Path(image_name).stem}-shuffled.jpg'
if not cv.imwrite(str(output_path), image_shuffled):
    raise OSError(f'Could not write image: {output_path}')
output_path

# Coherence Heuristic    

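We'll use the `Lab` color space because it was designed to be perceptually uniform: Euclidean distances between Lab colors approximate the color differences a human perceives, which makes them a good measure of how well two tile edges match.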

In [None]:
def climb_down(img, depth = 1):
    """Downsample `img` `depth` times with cv.pyrDown (Gaussian pyramid).

    Each step blurs and roughly halves both spatial dimensions. A
    non-positive `depth` returns `img` itself, unchanged.
    """
    while depth > 0:
        img = cv.pyrDown(img)
        depth -= 1
    return img

In [None]:
def climb_down_grid(cells: npt.NDArray, depth=1) -> npt.NDArray:
    """Apply `climb_down` to every tile of a 2-D grid of images.

    Returns a new grid of the same shape; `cells` itself is not modified.
    """
    downsampled = np.empty_like(cells)
    for index in np.ndindex(cells.shape[0], cells.shape[1]):
        downsampled[index] = climb_down(cells[index], depth)
    return downsampled
    

In [None]:
image_lab = cv.cvtColor(image, cv.COLOR_BGR2Lab)
cells, cells_names = split_grid(image_lab, 2, 3)
cells = climb_down_grid(cells, depth=3)
plot_images(
    zip(cells_names.flat, cells.flat), cells.shape[1],
    format=cv.COLOR_Lab2RGB, cell_size=(4, 4.5), dpi=80,
)

In [None]:
target_vs_neighbor = [
    ('Target Cell', cells[1, 1]),
    ('Actual Neighbor', cells[1, 2]),
]
plot_images(target_vs_neighbor, format=cv.COLOR_Lab2RGB, cell_size=(4, 4.5))

In [None]:
class Direction(Enum):
    """The four grid directions, used as keys of the edge/opposite lookup tables.

    NOTE: each value is a 1-tuple (`(0,)`, `(1,)`, ...), not a bare int; code
    elsewhere reads the integer index as `dir.value[0]`.
    """
    UP = (0,)
    DOWN = (1,)
    LEFT = (2,)
    RIGHT = (3,)

In [None]:
# Every direction in declaration order (UP, DOWN, LEFT, RIGHT). This order must
# match the index in each Direction value, because the precomputed coherence
# table is built in this order and indexed by `dir.value[0]`.
directions: list[Direction] = list(Direction)

In [None]:
# For each direction, the index selecting that border of an image:
# first/last row for UP/DOWN, first/last column for LEFT/RIGHT.
directions_edges: dict[Direction, tuple[slice | int, slice | int]] = {
    Direction.UP: (0, slice(None)),
    Direction.DOWN: (-1, slice(None)),
    Direction.LEFT: (slice(None), 0),
    Direction.RIGHT: (slice(None), -1),
}

def get_edge(img: npt.NDArray, dir: Direction) -> npt.NDArray:
    """Return the one-pixel-wide border of `img` on side `dir` (a view, not a copy)."""
    row_index, column_index = directions_edges[dir]
    return img[row_index, column_index]

# Sanity check: the UP edge is exactly the first pixel row. Asserted rather than
# just displayed, so a regression fails loudly under Restart & Run All.
assert np.array_equal(get_edge(cells[0, 0], Direction.UP), cells[0, 0][0, :])

In [None]:
# Each direction mapped to the one pointing the other way. Used to pick the
# neighbor's facing edge (compute_edges_coherence) and to flip the direction
# when swapping arguments in Puzzle.get_coherence.
directions_opposites: dict[Direction, Direction] = {
    direction: opposite
    for first, second in [(Direction.UP, Direction.DOWN), (Direction.LEFT, Direction.RIGHT)]
    for direction, opposite in ((first, second), (second, first))
}

directions_opposites[Direction.LEFT]

In [None]:
def euclidean_distance(a: npt.NDArray, b: npt.NDArray) -> npt.NDArray:
    """Per-pixel Euclidean distance between two color arrays, over the last axis.

    Inputs are promoted to float64 first. The images passed here typically come
    from cv.imread / cv.cvtColor as uint8, and on unsigned integer arrays
    `a - b` and `np.square` wrap around modulo 2**bits, which yields wrong
    distances (e.g. 0 - 20 -> 236, 236**2 -> 144 in uint8).

    Args:
        a, b: arrays of the same (or broadcastable) shape whose last axis holds
            the color channels.

    Returns:
        float64 array with the last axis reduced.
    """
    difference = np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64)
    return np.sqrt(np.sum(np.square(difference), axis=-1))

def compute_edges_coherence(base: npt.NDArray, neighbor: npt.NDArray, dir: Direction) -> float:
    """Mean color distance across the seam between `base` and a tile on its `dir` side.

    Compares `base`'s edge facing `dir` with `neighbor`'s edge facing back
    (`directions_opposites[dir]`), pixel by pixel. Lower means a smoother seam.
    The two edges must have the same (or broadcastable) shape.
    """
    own_edge = get_edge(base, dir)
    facing_edge = get_edge(neighbor, directions_opposites[dir])
    pixel_distances = euclidean_distance(own_edge, facing_edge)
    return np.average(pixel_distances, -1)

compute_edges_coherence(cells[1, 1], cells[1, 2], Direction.RIGHT)

In [None]:
def plot_edges_scores(cells: npt.NDArray, target: npt.NDArray, dir: Direction, *, columns=3, title: str | None = None) -> None:
    """Plot every tile of `cells`, ordered by seam distance to `target` along `dir`.

    Each subplot is titled with its `compute_edges_coherence` distance, best
    (smallest) first; equal distances keep their `cells.flat` order. Tiles are
    displayed with a Lab -> RGB conversion.
    """
    candidates = list(cells.flat)
    distances = [compute_edges_coherence(target, cell, dir) for cell in candidates]
    order = sorted(range(len(candidates)), key=distances.__getitem__)

    plot_images(
        [(f'$dist={distances[k]:.2f}$', candidates[k]) for k in order], columns,
        title=title, cell_size=(4., 4.5), dpi=80,
        format=cv.COLOR_Lab2RGB,
    )

plot_edges_scores(cells, cells[1, 1], Direction.RIGHT, title='Edge Distances for cells[1, 1] with dir=RIGHT')

## Experimentation

In [None]:
image_lab = cv.cvtColor(image, cv.COLOR_BGR2Lab)
cells, cells_names = split_grid(image_lab, 4, 6)
cells = climb_down_grid(cells, depth=3)
plot_images(zip(cells_names.flat, cells.flat), columns=cells.shape[1],
            format=cv.COLOR_Lab2RGB, cell_size=(4, 4.5), dpi=80)

In [None]:
plot_edges_scores(cells, target=cells[0, 1], dir=Direction.LEFT, columns=6)

In [None]:
def find_best_neighbor(cells: npt.NDArray, target: npt.NDArray, dir: Direction) -> tuple[float, npt.NDArray]:
    """Return the tile that best continues `target` on its `dir` side.

    The target tile itself (matched by identity, as stored in `cells`) is
    skipped: a tile cannot be its own neighbor. Among equal distances, the
    first candidate in `cells.flat` order wins.

    Returns:
        (distance, tile) with the smallest `compute_edges_coherence` distance.

    Raises:
        ValueError: if `cells` contains no tile other than `target`.
    """
    scored = (
        (compute_edges_coherence(target, cell, dir), cell)
        for cell in cells.flat
        if cell is not target
    )
    # min() is a single O(n) pass and returns the first of several equal minima.
    best = min(scored, key=lambda item: item[0], default=None)
    if best is None:
        raise ValueError('cells contains no candidate other than the target')
    return best

def plot_neighbors(
        cells: npt.NDArray, target: tuple[int, int], *,
        title: str | None = None, cell_size=(4., 4.5), dpi=80,
        format=cv.COLOR_BGR2RGB,
    ) -> None:
    """Plot `cells[target]` surrounded by its best-matching tile in each direction.

    3x3 layout: the target in the middle, the best UP/LEFT/RIGHT/DOWN
    candidates (from `find_best_neighbor`) around it, corners left empty.
    """
    target_cell = cells[target]

    def best(direction: Direction) -> tuple[str, npt.NDArray]:
        distance, candidate = find_best_neighbor(cells, target_cell, direction)
        return f'$dist = {distance:.2f}$', candidate

    if not title:
        title = f'Neighbors for cell[{target[0]}, {target[1]}]'

    up = best(Direction.UP)
    left = best(Direction.LEFT)
    right = best(Direction.RIGHT)
    down = best(Direction.DOWN)

    layout = [
        None,  up,                       None,
        left,  ('Target', target_cell),  right,
        None,  down,                     None,
    ]
    plot_images(layout, 3, title=title, cell_size=cell_size, dpi=dpi, format=format)

plot_neighbors(cells, (1, 1), format=cv.COLOR_Lab2RGB)

In [None]:
plot_neighbors(cells, target=(2, 3), format=cv.COLOR_Lab2RGB)

# Puzzle Grid Solver

## Problem Representation in a State Graph

In [None]:
cells, cells_names = split_grid(cv.cvtColor(image, cv.COLOR_BGR2Lab), 5, 5)
cells = climb_down_grid(cells, 3)
cells = shuffle_grid(cells)
# plot_images expects the number of *columns*: that is shape[1] (shape[0] is
# the row count, which only matches for square grids).
plot_images(zip(cells_names.flat, cells.flat), cells.shape[1], format=cv.COLOR_Lab2RGB)

In [None]:
class Puzzle:
    """A grid puzzle: its tiles plus precomputed seam costs between every pair.

    Tiles are identified by their position in the input grid in row-major
    order, i.e. an int cell id in `[0, rows * columns)`.
    """
    shape: tuple[int, int]
    rows: int
    columns: int

    cells: list[npt.NDArray]          # tile images, indexed by cell id
    indices: list[tuple[int, int]]    # (row, column) of each cell id in the input grid

    __max_cell_id: int

    # Precomputed coherence scores between cells at different edges, stored as
    # a lower triangle: __coherence[a][b][k] (b <= a) is
    # compute_edges_coherence(cells[a], cells[b], directions[k]).
    __coherence: list[list[list[float]]]

    def __init__(self, cells: npt.NDArray):
        """Build a puzzle from a 2-D grid (e.g. object array) of tile images."""
        self.shape = cells.shape[:2]
        self.rows, self.columns = self.shape

        self.cells = list(cells.flat)
        self.__max_cell_id = len(self.cells) - 1
        self.indices = [(i, j) for i in range(self.rows) for j in range(self.columns)]
        self.__coherence = Puzzle.compute_coherence(self.cells)


    def get_coherence(self, cell_a: int, cell_b: int, dir: Direction) -> float:
        """Seam distance when `cell_b` sits on the `dir` side of `cell_a`.

        Only the lower triangle (b <= a) is stored. The other case is answered
        by swapping the cells and flipping the direction, since the seam between
        a's `dir` edge and b's opposite edge is the same pair of edges.

        Raises:
            IndexError: if either cell id is outside `[0, max cell id]`.
        """
        for label, cell in (('cell_a', cell_a), ('cell_b', cell_b)):
            if cell > self.__max_cell_id or cell < 0:
                raise IndexError(f'{label} is out of range ({self.__max_cell_id}): {cell}')

        if cell_b > cell_a:
            cell_a, cell_b, dir = cell_b, cell_a, directions_opposites[dir]

        # Direction values are 1-tuples, hence `.value[0]` for the table index.
        return self.__coherence[cell_a][cell_b][dir.value[0]]


    @staticmethod
    def compute_coherence(cells: list[npt.NDArray]) -> list[list[list[float]]]:
        """Lower-triangular table of `compute_edges_coherence` for every pair and direction.

        Entry [i][j][k] (j <= i) compares cells[i] with cells[j] along directions[k].
        """
        return [
            [
                [compute_edges_coherence(cells[i], cells[j], dir) for dir in directions]
                for j in range(i + 1)
            ]
            for i in range(len(cells))
        ]


@total_ordering
class State:
    """A partial placement of puzzle tiles on an unbounded integer grid.

    The first tile goes to (0, 0); every later tile must go on a free cell next
    to an occupied one, and the occupied bounding box may not grow beyond the
    puzzle's shape. `apply` never mutates `self`; it returns a new state.

    Equality, hashing and ordering use a canonical key built from `toarray()`,
    so two states are equal when they hold the same layout, regardless of the
    order in which the tiles were placed or where (0, 0) ended up.
    """
    puzzle: Puzzle

    shape: tuple[int, int]   # (rows, columns) of the occupied bounding box
    rows: int
    columns: int

    # The bounding box is half-open: rows in [min_row, max_row),
    # columns in [min_column, max_column).
    min_row: int
    max_row: int

    min_column: int
    max_column: int

    cells: dict[int, dict[int, int]]        # cells[row][column] -> placed cell id
    available_cells: set[int]               # cell ids not placed yet
    free_neighbors: set[tuple[int, int]]    # empty coordinates a tile may be placed on next

    coherence: float                        # sum of seam distances over all placed adjacencies
    __key: tuple[tuple[int, int], bytes]    # canonical layout: (shape, toarray() bytes)
    __hash: int

    @property
    def actions(self) -> Iterator[tuple[tuple[int, int], int]]: # (coords, cell_id)
        """Lazily generate every legal (coords, cell_id) placement from this state."""
        for neighbor in self.free_neighbors:
            for cell_id in self.available_cells:
                yield neighbor, cell_id


    @staticmethod
    def create_initial_state(puzzle: Puzzle) -> Self:
        """Empty state for `puzzle`: nothing placed, only (0, 0) is free."""
        result = State()
        result.puzzle = puzzle

        result.shape = (0, 0)
        result.rows, result.columns = 0, 0

        result.min_row, result.max_row = 0, 0
        result.min_column, result.max_column = 0, 0

        result.cells = dict()
        result.available_cells = set(range(len(puzzle.cells)))
        result.free_neighbors = set([ (0, 0) ])

        result.coherence = .0
        result.__update_hash()

        return result


    def is_empty(self) -> bool:
        """True if no tile has been placed yet."""
        return self.rows == 0 and self.columns == 0


    def is_complete(self) -> bool:
        """True once every tile is placed and the layout has the puzzle's shape."""
        return self.shape == self.puzzle.shape and len(self.available_cells) == 0


    def get_cell_id(self, coords: tuple[int, int]) -> int:
        """Cell id placed at `coords` = (row, column), or -1 if that position is empty."""
        if not coords[0] in self.cells: return -1
        if not coords[1] in self.cells[coords[0]]: return -1
        return self.cells[coords[0]][coords[1]]


    def get_cell_coherence(self, coords: tuple[int, int], cell_id: int) -> tuple[float, int]:
        """Seam cost of putting `cell_id` at `coords`, against the tiles already placed.

        Returns:
            (sum of seam distances to the occupied 4-neighbors, number of such neighbors)
        """
        coherence = .0
        neighbors = 0

        for neighbor, direction in [
            ( (coords[0] - 1, coords[1]),  Direction.UP    ),
            ( (coords[0] + 1, coords[1]),  Direction.DOWN  ),
            ( (coords[0], coords[1] - 1),  Direction.LEFT  ),
            ( (coords[0], coords[1] + 1),  Direction.RIGHT ),
        ]:
            other_cell_id = self.get_cell_id(neighbor)
            if other_cell_id == -1: continue

            coherence += self.puzzle.get_coherence(cell_id, other_cell_id, direction)
            neighbors += 1

        return coherence, neighbors


    def copy(self) -> Self:
        """Independent copy: containers are copied, the puzzle is shared."""
        result = State()
        result.puzzle = self.puzzle

        result.shape = self.shape
        result.rows, result.columns = self.rows, self.columns

        result.min_row, result.max_row = self.min_row, self.max_row
        result.min_column, result.max_column = self.min_column, self.max_column

        # Inner dicts only hold ints, so copying each of them is a full deep copy
        # (and much cheaper than copy.deepcopy on every search step).
        result.cells = {row: columns.copy() for row, columns in self.cells.items()}
        result.available_cells = self.available_cells.copy()
        result.free_neighbors = self.free_neighbors.copy()

        result.coherence = self.coherence
        result.__key = self.__key
        result.__hash = self.__hash

        return result


    def toarray(self) -> npt.NDArray:
        """Layout as a (rows, columns) int32 array of cell ids, -1 for empty slots.

        Array position (r, c) holds the tile at grid coordinates
        (min_row + r, min_column + c), so the result does not depend on the
        order in which tiles were placed nor on where (0, 0) ended up.
        """
        result = np.full(self.shape, -1, dtype=np.int32)

        for row, columns in self.cells.items():
            for column, cell_id in columns.items():
                result[row - self.min_row, column - self.min_column] = cell_id

        return result


    def plot(
            self, *,
            cell_size: tuple[float, float] = (4., 4.5),
            format: int = cv.COLOR_BGR2RGB,
            title: str = 'State',
            dpi: int = 80,
        ) -> None:
        """Plot the occupied bounding box, leaving empty positions blank."""

        def generate_sequence():
            for i in range(self.min_row, self.min_row + self.rows):
                for j in range(self.min_column, self.min_column + self.columns):
                    if not i in self.cells:
                        yield None
                    elif not j in self.cells[i]:
                        yield None
                    else:
                        yield (f'({i:2d}, {j:2d})', self.puzzle.cells[self.cells[i][j]])

        plot_images(
            generate_sequence(), self.columns,
            title=title, cell_size=cell_size, format=format, dpi=dpi,
        )


    def apply(self, coords: tuple[int, int], cell_id: int) -> Self:
        """Return a new state with `cell_id` placed at `coords`; `self` is unchanged.

        Raises:
            ValueError: if `coords` is not a free neighbor, `cell_id` is already
                placed, or the placement would exceed the puzzle's shape.
        """
        if not coords in self.free_neighbors:
            raise ValueError('The coordinates are not of a free cell with neighbors.')

        if not cell_id in self.available_cells:
            raise ValueError('The requested image is not available.')

        if self.__check_overflow(coords):
            raise ValueError('Cell out of bounds.')

        result = self.copy()

        result.available_cells.remove(cell_id)
        result.free_neighbors.remove(coords)

        result.__set_cell_id(coords, cell_id)
        result.__update_bounds(coords)

        # Once a dimension reaches the puzzle's size, prune free cells beyond it.
        if result.rows != self.rows or result.columns != self.columns:
            result.__check_neighbors()

        result.__add_free_neighbors(coords)

        # Measured on `self`, i.e. against the tiles placed before this one.
        result.coherence = self.coherence + self.get_cell_coherence(coords, cell_id)[0]
        result.__update_hash()

        return result


    def __set_cell_id(self, coords: tuple[int, int], cell_id: int) -> None:
        """Record `cell_id` at `coords`, creating the row dict if needed."""
        if not coords[0] in self.cells:
            self.cells[coords[0]] = dict()
        self.cells[coords[0]][coords[1]] = cell_id


    def __update_bounds(self, coords: tuple[int, int]) -> None:
        """Grow the bounding box to include `coords`.

        `coords` is adjacent to the box (it comes from `free_neighbors`), so each
        dimension grows by at most one.
        """
        row, column = coords

        if row < self.min_row:
            self.min_row = row
            self.rows += 1

        if row >= self.max_row:
            self.max_row = row + 1
            self.rows += 1

        if column < self.min_column:
            self.min_column = column
            self.columns += 1

        if column >= self.max_column:
            self.max_column = column + 1
            self.columns += 1

        self.shape = (self.rows, self.columns)


    def __add_free_neighbors(self, coords: tuple[int, int]) -> None:
        """Mark the empty, in-bounds 4-neighbors of `coords` as free."""
        for neighbor in [
            (coords[0] - 1, coords[1]),
            (coords[0] + 1, coords[1]),
            (coords[0], coords[1] - 1),
            (coords[0], coords[1] + 1),
        ]:
            if self.__check_overflow(neighbor):
                continue

            if self.get_cell_id(neighbor) == -1:
                self.free_neighbors.add(neighbor)


    def __check_neighbors(self) -> None:
        """Drop free neighbors that now fall outside the allowed bounding box."""
        to_remove = list(filter(self.__check_overflow, self.free_neighbors))

        for item in to_remove:
            self.free_neighbors.remove(item)


    def __check_overflow(self, coords: tuple[int, int]) -> bool:
        """True if `coords` lies outside the box along a dimension already at the puzzle's size."""
        row, column = coords

        if self.rows == self.puzzle.rows:
            if row < self.min_row or row >= self.max_row:
                return True

        if self.columns == self.puzzle.columns:
            if column < self.min_column or column >= self.max_column:
                return True

        return False


    def __update_hash(self) -> None:
        # The shape is part of the key because tobytes() alone cannot tell e.g.
        # a 2x2 layout from a 1x4 one with the same cell ids.
        self.__key = (self.shape, self.toarray().tobytes())
        self.__hash = hash(self.__key)


    def __hash__(self) -> int:
        return self.__hash


    def __eq__(self, other) -> bool:
        # Compare the full canonical layouts, not just their hashes.
        if not isinstance(other, State):
            return NotImplemented
        return self.__key == other.__key


    def __lt__(self, other) -> bool:
        # Arbitrary but deterministic total order consistent with __eq__
        # (total_ordering derives <=, >, >= from it); unlike hash(), it does
        # not vary between interpreter runs.
        if not isinstance(other, State):
            return NotImplemented
        return self.__key < other.__key
    


puzzle = Puzzle(cells)

# Small demo: place three tiles by hand and inspect the resulting layout.
state = State.create_initial_state(puzzle)
for coords, cell_id in [((0, 0), 0), ((0, 1), 1), ((-1, 0), 2)]:
    state = state.apply(coords, cell_id)

state.toarray()

In [None]:
state.plot(format=cv.COLOR_Lab2RGB, cell_size=(1., 1.125))

## Search Algorithm

In [None]:
@dataclass
class Solution:
    """A complete state yielded by `search_solutions`, with search statistics."""

    state: State      # the complete puzzle state
    heuristic: float  # priority score the state had when it was dequeued
    visited: int      # number of distinct states seen so far
    queue: int        # states still waiting in the frontier at that moment


def search_solutions(puzzle: Puzzle,  heuristic: Callable[[State], float] = lambda _: .0) -> Iterator[Solution]:
    """Best-first search over placements, lazily yielding every complete state found.

    States are expanded in increasing `heuristic(state)` order; each distinct
    state (per `State.__eq__`) is enqueued at most once.

    Args:
        puzzle: the puzzle to solve.
        heuristic: priority of a state, lower is expanded first. With the
            default (constant 0) the search expands states in FIFO order.

    Yields:
        A `Solution` for each complete state, in the order it is dequeued.
    """
    # Entries are (score, insertion counter, state). The counter breaks score
    # ties in FIFO order, so State objects are never compared with each other;
    # this keeps the search order deterministic across runs.
    queue: Queue[tuple[float, int, State]] = PriorityQueue()
    visited: set[State] = set()
    pushed = 0

    initial_state = State.create_initial_state(puzzle)

    queue.put((.0, pushed, initial_state))
    visited.add(initial_state)

    del initial_state

    while not queue.empty():
        score, _, state = queue.get()

        if state.is_complete():
            yield Solution(state, score, len(visited), queue.qsize())

        for action in state.actions:
            child = state.apply(*action)
            if child in visited:
                continue

            pushed += 1
            queue.put((heuristic(child), pushed, child), block=False)
            visited.add(child)


def heuristic(state: State) -> float:
    """Greedy priority: accumulated seam distance plus 1000 per tile still unplaced.

    The per-tile penalty is presumably meant to dominate the coherence term,
    so that states with more tiles placed are expanded first.
    """
    remaining = len(state.available_cells)
    return state.coherence + remaining * 1000

In [None]:
# Enumerating *every* solution keeps searching after the first one and can take
# a very long time, so it is disabled by default; set the flag to True to run it.
ENUMERATE_ALL_SOLUTIONS = False

if ENUMERATE_ALL_SOLUTIONS:
    total = 0
    for solution in search_solutions(puzzle, heuristic):
        total += 1
        print(f'{total:3d}). Solution {hash(solution.state):21d} • Heuristic ({solution.heuristic}) • Visited ( {solution.visited:3d} ) • Queue ( {solution.queue:3d} )')

    print('Total solutions:', total)

In [None]:
solution = next(search_solutions(puzzle, heuristic))
solution_title = (
    f'Solution #{hash(solution.state)} • Heuristic ({solution.heuristic}) '
    f'• Visited ({solution.visited}) • Queue ({solution.queue})'
)
solution.state.plot(title=solution_title, cell_size=(4, 3), format=cv.COLOR_Lab2RGB)

# Miscellaneous

In [None]:
# Unit step for each direction, in (row, column) units (e.g. UP is one row up).
directions_vectors: dict[Direction, npt.NDArray] = {
    direction: np.array(step, dtype=np.int8)
    for direction, step in [
        (Direction.UP, (-1, 0)),
        (Direction.DOWN, (1, 0)),
        (Direction.LEFT, (0, -1)),
        (Direction.RIGHT, (0, 1)),
    ]
}