In [None]:
import cv2 as cv
import numpy as np
import scipy
from matplotlib import pyplot as plt
import enum
from functools import reduce
from typing import Self
from scipy.spatial import KDTree
from sklearn.preprocessing import Normalizer
import random
import skimage

## Utility Functions

In [None]:
def image_to_pieces(im, R=3, C=3):
    """Cut a 3-D image array into ``R * C`` equally sized tiles.

    Args:
        im: array whose shape unpacks as ``(rows, cols, depth)``.
        R: number of tile rows.
        C: number of tile columns.

    Returns:
        A list of ``R * C`` slices of ``im`` in row-major order (first all
        tiles of the top row, left to right). Each tile is
        ``rows // R`` by ``cols // C``; pixels beyond the last full tile are
        not included in any tile.
    """
    n_rows, n_cols, _depth = im.shape
    tile_h = n_rows // R
    tile_w = n_cols // C

    return [
        im[top:top + tile_h, left:left + tile_w]
        for top in (tile_h * i for i in range(R))
        for left in (tile_w * j for j in range(C))
    ]

def pieces_to_image(pieces, r, c):
    """Assemble a row-major list of tiles into a single image.

    Args:
        pieces: sequence of at least ``r * c`` arrays; consecutive groups of
            ``c`` tiles form one image row.
        r: number of tile rows.
        c: number of tile columns.

    Returns:
        The array obtained by joining each group of ``c`` tiles side by side
        (axis 1) and stacking the ``r`` resulting strips vertically (axis 0).
    """
    # Iterate over the r rows (not the c columns): each row consumes c tiles.
    rows = [
        np.concatenate(pieces[c * i : c * (i + 1)], axis=1)
        for i in range(r)
    ]
    return np.concatenate(rows, axis=0)

def gray_to_pieces(im, R=3, C=3):
    """Cut a 2-D (single channel) image array into ``R * C`` equal tiles.

    Args:
        im: array whose shape unpacks as ``(rows, cols)``.
        R: number of tile rows.
        C: number of tile columns.

    Returns:
        A list of ``R * C`` slices of ``im`` in row-major order, each
        ``rows // R`` by ``cols // C``. Pixels past the last full tile are
        left out.
    """
    n_rows, n_cols = im.shape
    tile_h = n_rows // R
    tile_w = n_cols // C

    # Pre-compute the row and column spans, then take their cross product.
    row_spans = [slice(tile_h * i, tile_h * (i + 1)) for i in range(R)]
    col_spans = [slice(tile_w * j, tile_w * (j + 1)) for j in range(C)]

    return [im[rows, cols] for rows in row_spans for cols in col_spans]

def get_piece_sides(piece: np.ndarray) -> (
        (np.ndarray, np.ndarray), (np.ndarray, np.ndarray)
    ):
    """Extract the four border lines of a piece.

    Returns a pair of pairs: ``(top, bottom)`` are the first and last rows of
    ``piece``; ``(left, right)`` are its first and last columns.
    """
    top, bottom = piece[0], piece[-1]
    left, right = piece[:, 0], piece[:, -1]
    return (top, bottom), (left, right)


In [None]:
class PieceSide(enum.IntEnum):
    TOP = 0
    BOTTOM = 1
    LEFT = 2
    RIGHT = 3

    def inverse(self) -> Self:
        return PieceSide(
            self + 1 if self % 2 == 0 
            else self - 1
        )

In [None]:
class PieceFeatureExtractorSidesPooling:
    """Describe each side of a piece by pooled statistics of its border line.

    Every border line (see ``get_piece_sides``) is split into ``k`` nearly
    equal segments for each ``k`` in ``sizes``; every reducer is applied to
    every segment and the results are collected into one feature vector per
    side.

    Args:
        reducers: callables mapping a 1-D segment to a scalar.
        sizes: segment counts to use when splitting a border line.
    """

    def __init__(self, reducers=(np.max, np.min), sizes=(1,)):
        # Copy the arguments so that instances never share (or expose) the
        # default containers.
        self.reducers = list(reducers)
        self.sizes = list(sizes)
        self.sides = [[] for _ in range(4)]

    def transform(self, pieces):
        """Compute side features for every piece.

        Args:
            pieces: iterable of 2-D arrays (one channel per piece).

        Returns:
            A list with one array per ``PieceSide`` (in enum order); row ``i``
            of each array is the feature vector of that side of piece ``i``.
        """
        features = [self._piece_to_features(piece) for piece in pieces]
        return [
            np.array([piece_features[side] for piece_features in features])
            for side in PieceSide
        ]

    def _piece_to_features(self, piece):
        """Return four feature lists for ``piece``: top, bottom, left, right."""
        (top, bottom), (left, right) = get_piece_sides(piece)

        piece_features = []
        for side in (top, bottom, left, right):
            side_features = []
            for k in self.sizes:
                # The split depends only on the side and k, not on the reducer.
                sub_sides = np.array_split(side, k)
                for reducer in self.reducers:
                    side_features.extend(
                        reducer(sub_side) for sub_side in sub_sides
                    )
            piece_features.append(side_features)

        return piece_features

In [None]:
class PieceMatcher:
    """Nearest-neighbour index over the side features of a set of pieces.

    One KD-tree is built per ``PieceSide``; ``kdtree[side]`` indexes the
    ``side`` feature vectors of all pieces.
    """

    def __init__(self):
        self.kdtree = {}

    def fit(self, pieces_features):
        """Build one KD-tree per side.

        Args:
            pieces_features: indexable by ``PieceSide``; entry ``side`` is a
                2-D array with one feature row per piece.

        Returns:
            ``self``.
        """
        self._pieces_features = pieces_features
        self._pieces_len = len(pieces_features[0])

        for side in PieceSide:
            self.kdtree[side] = KDTree(self._pieces_features[side])

        return self

    def get_nearest_neighbor(self, idx, side):
        """Find the piece that best continues piece ``idx`` across ``side``.

        The ``side`` features of piece ``idx`` are compared against the
        opposite-side features of all pieces (a piece placed above another
        touches it with its BOTTOM edge, and so on). The piece itself is
        skipped whenever another candidate exists.

        This is the method ``measure_acc_1_neigh`` calls on its matcher.

        Args:
            idx: index of the query piece.
            side: the ``PieceSide`` of the query piece to extend.

        Returns:
            ``(distance, neighbor_index)``.
        """
        side = PieceSide(side)
        query = self._pieces_features[side][idx]

        # Ask for two candidates so that a self-match can be discarded.
        k = min(2, self._pieces_len)
        dists, indices = self.kdtree[side.inverse()].query(query, k=k)
        dists, indices = np.atleast_1d(dists), np.atleast_1d(indices)

        for dist, neighbor in zip(dists, indices):
            if neighbor != idx:
                return float(dist), int(neighbor)
        return float(dists[0]), int(indices[0])



## A _home-made_ accuracy function

In [None]:
def measure_acc_1_neigh(matcher, image_pieces_shape):
    """Fraction of true neighbour relations found by the matcher's 1-NN lookup.

    Piece ids are treated as row-major grid positions (id ``i`` is at row
    ``i // c``, column ``i % c``), so the matcher is expected to have been
    fitted on the pieces in their original, unshuffled order.

    Args:
        matcher: object exposing ``get_nearest_neighbor(idx, side)`` whose
            result holds the neighbour index at position 1.
        image_pieces_shape: ``(r, c)`` number of piece rows and columns.

    Returns:
        Number of correct (piece, side) lookups divided by the number of
        existing (piece, side) neighbour relations; ``nan`` when the grid has
        no neighbour relations at all (a 1x1 grid).
    """
    r, c = image_pieces_shape
    ids = np.arange(r * c)
    id_grid = ids.reshape(r, c)
    lc_rem = id_grid[:, :-1].reshape(-1)  # all ids except the last  column
    fc_rem = id_grid[:, 1:].reshape(-1)   # all ids except the first column
    # A grid row holds c ids, so removing one row means removing c entries.
    fr_rem = ids[c:].reshape(-1, c).T.reshape(-1)   # all ids except the first row
    lr_rem = ids[:-c].reshape(-1, c).T.reshape(-1)  # all ids except the last  row

    # (piece, expected neighbour) pairs, listed in PieceSide order.
    pairs = [
        zip(fr_rem[::-1], lr_rem[::-1]),  # to top
        zip(lr_rem, fr_rem),              # to bottom

        zip(fc_rem[::-1], lc_rem[-1::-1]),  # to left
        zip(lc_rem, fc_rem),              # to right
    ]

    hits = 0
    for side, side_pairs in zip(PieceSide, pairs):
        hits += sum(
            1
            for idx, nidx in side_pairs
            if matcher.get_nearest_neighbor(idx, side)[1] == nidx
        )

    total = 2 * r * (c - 1) + 2 * c * (r - 1)
    if total == 0:
        return float("nan")
    return hits / total

In [None]:

def fill(grid, g, idx, ridx, shape, visited, inds):
    """Recursively place piece ``idx`` and the pieces linked to it into ``grid``.

    Args:
        grid: flat integer array of cell -> piece id, with ``-1`` meaning
            empty; modified in place.
        g: per-piece list of four ``(distance, neighbour)`` entries, indexed
            in the order top, bottom, left, right; ``neighbour`` is ``None``
            when that side has no link.
        idx: id of the piece to place.
        ridx: flat position in ``grid`` at which to place ``idx``.
        shape: ``(r, c)``; only ``c`` is used, as the distance between
            vertically adjacent cells in the flat ``grid``.
        visited: per-piece flags marking pieces already placed; modified in
            place.
        inds: per-piece flat position assigned so far (``-1`` when none);
            modified in place.

    NOTE(review): target positions are computed without bounds or row-wrap
    checks; a negative position would index ``grid`` from its end.
    """
    if visited[idx]: 
        return

    r, c = shape
    # Flat positions of the cells above/below (one row of c cells away) and
    # left/right of the current cell -- same order as the entries of g[idx].
    pr, nr = ridx - c, ridx + c
    pc, nc = ridx - 1, ridx + 1

    visited[idx] = True
    grid[ridx] = idx

    # First pass: reserve a position for every linked neighbour that has none
    # yet, before any recursion can assign it from another direction.
    for i, mov in enumerate([pr, nr, pc, nc]):
        if g[idx][i][1] != None and inds[g[idx][i][1]] == -1:
            inds[g[idx][i][1]] = mov

    # Second pass: recurse into each linked neighbour whose reserved cell is
    # still empty.
    for i, mov in enumerate([pr, nr, pc, nc]):
        if g[idx][i][1] != None and grid[inds[g[idx][i][1]]] == -1:
            fill(grid, g, g[idx][i][1], inds[g[idx][i][1]], shape, visited, inds)


def make_grid(g, shape):
    """Turn the neighbour links in ``g`` into an ordered list of piece ids.

    A root piece is chosen, placed in the middle of an oversized flat buffer,
    and ``fill`` spreads out from it along the links. The buffer is then
    compacted by dropping its empty cells.

    Args:
        g: per-piece list of four ``(distance, neighbour)`` entries, with
            ``(0, None)`` marking a side without a link.
        shape: ``(r, c)`` number of piece rows and columns.

    Returns:
        1-D ``int32`` array of the placed piece ids, in increasing flat-buffer
        position. Only pieces reached from the root appear in it.
    """
    r, c = shape
    max_idx = r * c

    # Flat buffer four times the puzzle size; -1 marks an empty cell.
    grid = np.full(max_idx * 4, -1, dtype=np.int32)
    # Flat position assigned to each piece; -1 while unassigned.
    indices = np.full(max_idx, -1, dtype=np.int32)
    visited = [False] * max_idx

    # Root score: (sum of link distances + 1) * 4 / (number of links + 1);
    # the piece with the smallest score becomes the root.
    g_shrinked = [[p for p in v if p != (0, None)] for v in g]
    g_len = np.array([len(links) for links in g_shrinked]) + 1
    g_sum = np.array([sum(p[0] for p in v) for v in g]) + 1
    g_prt = g_sum * 4 / g_len
    root = g_prt.argsort()[0]

    indices[root] = max_idx
    grid[max_idx] = root
    fill(grid, g, root, max_idx, shape, visited, indices)

    return grid[grid != -1]

In [None]:
im = cv.imread("Images/home.jpg")
# cv.imread reports a missing/unreadable file by returning None instead of
# raising, so fail here with a clear message.
if im is None:
    raise FileNotFoundError('Could not read image "Images/home.jpg"')
height, width = 13, 13

# NOTE(review): `im` is in BGR order here, so these three channels are
# blue/green/red despite their names; they are not used below.
h, s, v = cv.split(im)
pieces = image_to_pieces(im, height, width)

random.shuffle(pieces)
pieces_f = {}
# for i, c in enumerate('bgr'):
#     pieces_f[c] = np.array([cv.split(piece)[i] / 255. for piece in pieces])

# Per-piece hue and saturation channels, each divided by a scaling constant.
# NOTE(review): for 8-bit images OpenCV stores S in [0, 255], so dividing by
# 100 does not map saturation into [0, 1] -- confirm this weighting is wanted.
pieces_hsv = np.array([cv.cvtColor(piece, cv.COLOR_BGR2HSV) for piece in pieces])
for i, (c, maxval) in enumerate(zip('hs', [180, 100])):
    pieces_f[c] = np.array([cv.split(piece)[i] / maxval for piece in pieces_hsv])

# Side features per channel: every border line is split into 11 segments and
# described by the max and min of each segment.
pieces_features = []
for p in pieces_f.values():
    features = PieceFeatureExtractorSidesPooling([np.max, np.min], sizes=[11]).transform(p)
    pieces_features.append(features)

# Join the per-channel feature vectors along the feature axis.
features = np.concatenate(pieces_features, axis=2)
matcher = PieceMatcher().fit(features)

In [None]:
from queue import PriorityQueue

def solve(matcher, features):
    """Link pieces whose facing sides are each other's nearest neighbour.

    For the TOP and LEFT sides, a pair ``(p0, p1)`` is a candidate when
    ``p1``'s opposite side is the closest match to ``p0``'s side *and*
    ``p0``'s side is the closest match to ``p1``'s opposite side. Candidates
    are accepted in order of increasing distance.

    Args:
        matcher: fitted matcher exposing ``kdtree[side]``.
        features: indexable by ``PieceSide``; entry ``side`` holds one feature
            row per piece.

    Returns:
        A list with one entry per piece; each entry is a list of four
        ``(distance, neighbour)`` tuples indexed by ``PieceSide``, where
        ``(0, None)`` means no link on that side.
    """
    # Size the result from the input rather than from notebook globals.
    n_pieces = len(features[0])
    g = [[(0, None)] * 4 for _ in range(n_pieces)]
    pq = PriorityQueue()

    for side in (PieceSide.TOP, PieceSide.LEFT):
        iside = side.inverse()
        # forward[j]: piece whose `side` edge is closest to piece j's `iside` edge.
        _, forward = matcher.kdtree[side].query(features[iside], k=1)
        # backward[i]: piece whose `iside` edge is closest to piece i's `side`
        # edge, with the corresponding distance in back_dist[i].
        back_dist, backward = matcher.kdtree[iside].query(features[side], k=1)
        for i0, n0 in enumerate(backward):
            if i0 == n0:
                continue
            if forward[n0] == i0:
                # Queue the distance of this very edge pair (i0.side, n0.iside).
                pq.put((back_dist[i0], i0, n0, side))

    while not pq.empty():
        dist, p0, p1, side = pq.get()
        iside = side.inverse()
        if g[p0][side] != (0, None) or g[p1][iside] != (0, None):
            print(f"Warning {p0},{side} and {p1},{iside}: {g[p0][side]}, {g[p1][iside]}")
            continue
        g[p0][side] = (dist, p1)
        g[p1][iside] = (dist, p0)

    return g

# Build the neighbour links, lay the pieces out, and reorder them accordingly.
g = solve(matcher, features)
# print(g)
grid = make_grid(g, (height, width))
pieces_ordered = [pieces[i] for i in grid]

# Show the ordered pieces on a height x width grid of subplots
# (OpenCV images are BGR; matplotlib expects RGB).
# plt.figure(figsize=(16, 16))
for i, p in enumerate(pieces_ordered):
    plt.subplot(height, width, i + 1)
    plt.imshow(cv.cvtColor(p, cv.COLOR_BGR2RGB))
    plt.axis(False)

plt.show()

In [None]:
# from scipy.spatial import KDTree

class PieceMatcher_:
    """Nearest-neighbour lookup over whole-piece feature vectors.

    NOTE: a class with the same name is defined again in a later cell and
    replaces this one when the notebook is run top to bottom.
    """

    def __init__(self):
        """Create an unfitted matcher; ``fit`` must be called before ``get``."""

    def fit(self, pieces_features):
        """Index ``pieces_features`` (one row per piece) and return ``self``."""
        self._pieces_features = pieces_features
        self.kdtree = KDTree(pieces_features)
        return self

    def get(self, piece_features):
        """Return the KD-tree query result for ``piece_features``.

        The result is the ``(distance, index)`` pair of the closest indexed
        piece.
        """
        return self.kdtree.query(piece_features)


In [None]:
class PieceFeatureExtractorAbstract:
    
    def __init__(self, /, *, features=None):
        self._features = features

    def __add__(self, other: Self) -> Self:
        features = np.concatenate(
            [self.features, other.features], 
            axis=1)

        return PieceFeatureExtractorAbstract(features=features)

    @property
    def features(self):
        return self._features

    def __getitem__(self, key):
        return self.features[key]


class PieceFeatureExtractorSIFT(PieceFeatureExtractorAbstract):
    """Per-piece features pooled from SIFT keypoint attributes.

    For every detected keypoint the triple ``(angle, response, octave)`` is
    collected; each reducer then collapses a piece's keypoints into one
    3-vector, and the vectors of all reducers are joined column-wise.
    """

    def __init__(self, /, *, reducers=[np.max, np.mean]): 
        super().__init__()
        self._sift = cv.SIFT_create()
        self._reducers = reducers

    def fit(self, pieces): 
        """Detect keypoints in ``pieces`` and store the pooled features.

        Returns ``self``; the result is available through ``features``.
        """
        keypoints_per_piece = self._sift.detect(pieces, None)
        self.pieces_sift = [
            np.array([[kp.angle, kp.response, kp.octave] for kp in keypoints])
            for keypoints in keypoints_per_piece
        ]

        def pool(reducer, keypoint_rows):
            # A piece without keypoints contributes an all-zero vector.
            if keypoint_rows.size == 0:
                return np.zeros(3)
            return reducer(keypoint_rows, axis=0)

        per_reducer = [
            np.array([pool(reducer, rows) for rows in self.pieces_sift])
            for reducer in self._reducers
        ]

        self._features = np.concatenate(per_reducer, axis=1)
        return self


class PieceFeatureExtractor2DPooling(PieceFeatureExtractorAbstract):
    """Per-piece features obtained by block-wise pooling of the piece pixels.

    Args:
        reducers: functions passed to ``skimage.measure.block_reduce`` as the
            pooling function; one set of features is produced per reducer.
        block_size: block shape handed to ``block_reduce``. Defaults to the
            previously hard-coded ``(91, 91)``.
    """

    def __init__(self, /, *, reducers=(np.max, np.mean), block_size=(91, 91)):
        super().__init__()
        # Copy so the instance never shares the default container.
        self._reducers = list(reducers)
        self._block_size = tuple(block_size)

    def fit(self, pieces):
        """Pool every piece with every reducer and store the result.

        Args:
            pieces: equally shaped piece arrays.

        Returns:
            ``self``; ``features`` then has one row per piece, holding the
            flattened pooled values of all reducers side by side.
        """
        pieces = np.array(pieces)

        features = []
        for reducer in self._reducers:
            pooled = np.array([
                skimage.measure.block_reduce(piece, self._block_size, reducer)
                for piece in pieces
            ])
            features.append(pooled.reshape(pieces.shape[0], -1))

        self._features = np.concatenate(features, axis=1)
        return self


class PieceFeatureExtractorSidesPooling_(PieceFeatureExtractorAbstract):
    """Side-pooling feature extractor in the ``fit`` / ``features`` style.

    TODO: work in progress.

    Every border line of a piece is split into ``k`` segments for each ``k``
    in ``sizes`` and every reducer is applied to every segment.

    Args:
        reducers: callables mapping a 1-D segment to a scalar.
        sizes: segment counts to use when splitting a border line.
    """

    def __init__(self, reducers=(np.max, np.min), sizes=(1,)):
        super().__init__()
        # Copy so instances never share the default containers.
        self.reducers = list(reducers)
        self.sizes = list(sizes)
        self.sides = [[] for _ in range(4)]

    def fit(self, pieces):
        """Compute and store the per-side feature arrays.

        After fitting, ``features`` is a list with one array per ``PieceSide``
        (in enum order); row ``i`` of each array belongs to piece ``i``. This
        is the regrouped result the method builds -- previously it was built
        and then discarded in favour of the raw per-piece nested list.

        Returns:
            ``self``.
        """
        per_piece = [self._piece_to_features(piece) for piece in pieces]
        self._features = [
            np.array([piece_features[side] for piece_features in per_piece])
            for side in PieceSide
        ]
        return self

    def _piece_to_features(self, piece):
        """Return four feature lists for ``piece``: top, bottom, left, right."""
        (top, bottom), (left, right) = get_piece_sides(piece)

        piece_features = []
        for side in (top, bottom, left, right):
            side_features = []
            for k in self.sizes:
                # The split depends only on the side and k, not on the reducer.
                sub_sides = np.array_split(side, k)
                for reducer in self.reducers:
                    side_features.extend(
                        reducer(sub_side) for sub_side in sub_sides
                    )
            piece_features.append(side_features)

        return piece_features

In [None]:
from scipy.spatial import KDTree

class PieceMatcher_:
    """KD-tree backed lookup of the closest known piece for a feature vector."""

    def __init__(self):
        """Nothing to set up until ``fit`` is called."""

    def fit(self, pieces_features):
        """Build the KD-tree over ``pieces_features`` (one row per piece).

        Returns ``self`` so calls can be chained.
        """
        self._pieces_features = pieces_features
        self.kdtree = KDTree(pieces_features)
        return self

    def get(self, piece_features):
        """Return ``(distance, index)`` of the indexed piece closest to the query."""
        nearest = self.kdtree.query(piece_features)
        distance, index = nearest
        return distance, index


In [None]:
class PuzzleSolver:
    """Common base class for the puzzle solvers; defines no behaviour itself."""

class PuzzleSolverRectHint(PuzzleSolver):
    """Solve a shuffled rectangular puzzle using the target image as a hint.

    ``fit`` indexes the features of the target image's pieces; ``solve`` cuts
    a shuffled image into pieces and moves each one to the grid slot of the
    target piece it resembles most.
    """

    def __init__(self):
        pass

    def _extract_features(self, im):
        """Return the combined SIFT + 2D-pooling features of ``im``'s gray pieces."""
        grayed = cv.cvtColor(im, cv.COLOR_BGR2GRAY)
        gray_pieces = gray_to_pieces(grayed, *self._pieces_shape)
        return (PieceFeatureExtractorSIFT().fit(gray_pieces) +
                PieceFeatureExtractor2DPooling().fit(gray_pieces))

    def fit(self, target_im, pieces_shape):
        """Index the pieces of the target image.

        Args:
            target_im: BGR image of the solved puzzle.
            pieces_shape: ``(rows, cols)`` of the piece grid.

        Returns:
            ``self``, like the other ``fit`` methods in this notebook.
        """
        self._pieces_shape = pieces_shape
        features = self._extract_features(target_im)
        self._matcher = PieceMatcher_().fit(features.features)
        return self

    def solve(self, im):
        """Rearrange the pieces of ``im`` to resemble the fitted target.

        Args:
            im: BGR image made of the shuffled pieces.

        Returns:
            The reassembled image. When several pieces are matched to the same
            slot, the closest one keeps it and the others are placed into the
            slots left empty, so a complete image is always produced.
        """
        rows, cols = self._pieces_shape[0], self._pieces_shape[1]
        pieces = image_to_pieces(im, *self._pieces_shape)
        features = self._extract_features(im)

        n_slots = rows * cols
        slot_piece = [None] * n_slots   # piece index placed in each slot
        slot_dist = [None] * n_slots    # match distance of that piece
        displaced = []                  # pieces that lost a slot collision

        for i, f in enumerate(features.features):
            dist, j = self._matcher.get(f)
            if slot_piece[j] is None:
                slot_piece[j], slot_dist[j] = i, dist
            elif dist < slot_dist[j]:
                displaced.append(slot_piece[j])
                slot_piece[j], slot_dist[j] = i, dist
            else:
                displaced.append(i)

        # Without this, an unfilled slot would stay None and make the final
        # concatenation fail.
        empty_slots = [slot for slot in range(n_slots) if slot_piece[slot] is None]
        for slot, i in zip(empty_slots, displaced):
            slot_piece[slot] = i

        res_im = [pieces[i] for i in slot_piece]
        return pieces_to_image(res_im, rows, cols)


In [None]:
target_im = cv.imread('Images/nature.jpg')
# cv.imread reports a missing/unreadable file by returning None instead of
# raising, so fail here with a clear message.
if target_im is None:
    raise FileNotFoundError("Could not read image 'Images/nature.jpg'")

# Cut the target into r x c pieces, shuffle them and glue them back together
# to obtain the scrambled input image.
r, c = 50, 50
pieces = image_to_pieces(target_im, r, c)
random.shuffle(pieces)
im = pieces_to_image(pieces, r, c)

solver = PuzzleSolverRectHint()
solver.fit(target_im, (r, c))
res_im = solver.solve(im)

# Input, hint and result side by side (OpenCV images are BGR; matplotlib
# expects RGB).
plt.figure(figsize=(14, 24))
plt.subplot(1, 3, 1)
plt.title('Input')
plt.axis(False)
plt.tight_layout()
plt.imshow(cv.cvtColor(im, cv.COLOR_BGR2RGB))

plt.subplot(1, 3, 2)
plt.axis(False)
plt.title('Target (hint)')
plt.tight_layout()
plt.imshow(cv.cvtColor(target_im, cv.COLOR_BGR2RGB))

plt.subplot(1, 3, 3)
plt.axis(False)
plt.title('Result')
plt.tight_layout()
plt.imshow(cv.cvtColor(res_im, cv.COLOR_BGR2RGB))