In [None]:
import PIL.Image
import numpy as np
from typing import Union
from glob import glob
import os
import cv2
import skimage.morphology
from skimage import measure
from skimage import io, filters
import scipy
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN
import imutils

# [IAPR][iapr]: Project


**Group ID:** 14

**Author 1 (sciper):** David Rüegg (218512)  
**Author 2 (sciper):** Yacine Derder (301994)   
**Author 3 (sciper):** Elsa Pariat (301964)   

**Release date:** 27.04.2023


## Important notes

The assignments are designed to teach practical implementation of the topics presented during class as well as preparation for the final project, which is a practical project which ties together the topics of the course. 

As such, in the lab assignments/final project, unless otherwise specified, you may, if you choose, use external functions from image processing/ML libraries like opencv and sklearn as long as there is sufficient explanation in the lab report. For example, you do not need to implement your own edge detector, etc.

**! Before handing back the notebook !** rerun it from scratch with `Kernel` > `Restart & Run All`


[iapr]: https://github.com/LTS5/iapr

---
## 0. Introduction

In this project, you will be working on solving tiling puzzles using image analysis and pattern recognition techniques. Tiling puzzles are a classic type of puzzle game that consists of fitting together pieces of a given shape (in this case square) to form a complete image. The goal of this project is to develop an algorithm that can automatically reconstruct tiling puzzles from a single input image. 

---

## 1. Data

### Input data
To achieve your task, you will be given images that look like this:


![train_00.png](data_project/project_description/train_00.png)

### Example puzzle content
Examples of solved puzzles. 
Solution 1
<img src="data_project/project_description/solution_example.png" width="512"/>
Solution 2
<img src="data_project/project_description/solution_example2.jpg" width="512"/>


### 1.1. Image layout

- The input for the program will be a single image with a size of __2000x2000 pixels__, containing the pieces of the tiling puzzles randomly placed in it. The puzzle sizes are __3x3, 3x4, or 4x4__. 
    - __You are guaranteed to always have the exact number of pieces for each puzzle__ 
        - For each puzzle you are always expected to find exactly 9, 12 or 16 pieces.
        - If you find something else, either you are missing pieces or you added incorrect pieces to the puzzle.

- The puzzle pieces are square-shaped with dimensions of 128x128 pixels (before rotation). 

- The input image will contain pieces from __two or three (but never four)__ different tiling puzzles, as well as some __extra pieces (outliers)__ that do not belong to any of the puzzles.


## 2. Tasks (Total 20 points) 


The project aims to:
1) Segment the puzzle pieces from the background (recover the pieces of 128x128 pixels)   \[ __5 points__ \] 

2) Extract features of interest from puzzle pieces images \[ __5 points__ \]   

3) Cluster puzzle pieces to identify which puzzle they belong, and identify outliers.  \[ __5 points__ \]   

4) Solve tiling puzzle (find the rotations and translations to correctly allocate the puzzle pieces in a 3x3, 3x4 or 4x4 array.) \[ __5 points__ \]   

##### The images used for the puzzles have self-repeating patterns or textures, which ensures that all puzzle pieces contain more or less the same features regardless of how they were cut. 




### 1.2. Output solution pieces.

For each input image, the output solution will include N images with the solved puzzles, where N is the number of puzzles in the input image, and M images of 128x128 pixels, one per outlier piece, where M is the number of outliers. Each of the N puzzle images contains the solution to one of the puzzles in the input. 


-  Example input:  train_05.png

- Example solution:
        -solution_05_00.png solution_05_01.png solution_05_02.png 
        -outlier_05_00.png outlier_05_01.png outlier_05_02.png ...

- Example input:  train_07.png
- Example solution:
        -solution_07_00.png solution_07_01.png 
        -outlier_07_00.png outlier_07_01.png outlier_07_02.png ...


__Watch out!__ output resolution should always be like this:  

| Puzzle pieces | Pixel dimensions | Pixel dimensions |
|---|---|---|
| 3x3 | 384x384 | 3(128)x3(128) |
| 3x4 | 384x512 | 3(128)x4(128) |
| 4x4 | 512x512 | 4(128)x4(128) |
| 1x1 (outlier) | 128x128 | (1)128x(1)128 |





__The order of the solutions (and their rotations) is not a problem for the grading__




For a 3x3 puzzle, the output solution will be a final image of resolution (128·3)x(128·3) = 384x384, with each piece correctly placed in its corresponding location in the 3x3 array. Similarly, if the puzzle consists of 3x4 or 4x4 pieces, the output solution will be an image of resolution (128·3)x(128·4) = 384x512 or (128·4)x(128·4) = 512x512.



### 1.3 Data folder Structure

You can download the data for the project here: [download data](https://drive.google.com/drive/folders/1k3xTH0ZhpqZb3xcZ6wsOSjLzxBNYabg3?usp=share_link)

```
data_project
│
└─── project_description
│    │    example_input.png      # example input images
│    │    example_textures1.png      # example input images
│    │    example_textures2.png      # example input images
│    └─── ultimate_test.jpg   # If it works on that image, you would probably end up with a good score
│
└─── train
│    │    train_00.png        # Train image 00
│    │    ...
│    │    train_16.png        # Train image 16
│    └─── train_labels.csv    # Ground truth of the train set
|    
└────train_solution
│    │    solution_00_00.png        # Solution puzzle 1 from Train image 00
│    │    solution_00_01.png        # Solution puzzle 2 from Train image 00
│    │    solution_00_02.png        # Solution Puzzle 3 from Train image 00
│    │    outlier_00_00.png         # outlier     from Train image 00
│    │    outlier_00_01.png         # outlier     from Train image 00
│    │    outlier_00_03.png         # outlier     from Train image 00
│    │    ...
│    │    solution_15_00.png        # Solution puzzle 1 from Train image 15
│    │    solution_15_01.png        # Solution puzzle 2 from Train image 15
│    │    outlier_15_00.png         # outlier     from Train image 15
│    └─── outlier_15_01.png         # outlier     from Train image 15
│
└─── test
     │    test_00.png         # Test image 00 (day of the exam only)
     │    ...
     └─── test_xx.png             # Test image xx (day of the exam only)
```



## 3. Evaluation

**Before the exam**
   - Create a zipped folder named **groupid_xx.zip** that you upload on moodle (xx being your group number).
   - Include a **runnable** code (Jupyter Notebook and external files) and your presentation in the zip folder.
   
**The day of the exam**
   - You will be given a **new folder** (test folder) with a few images, but **no ground truth** (no solutions).
   - We will ask you to run your pipeline in **real time** and to send us the predictions you obtain, saved with the provided function **save_results**. 
   - On our side, we will compute the performance of your classification algorithm. 
   - To evaluate your method, we will use the **evaluate_solution** function presented below. To understand how the provided functions work, please read the documentation of the functions in **utils.py**.
   - **Please make sure your function returns the proper data format to avoid points penalty on the day of the exam**. 
---


## 4. Your code

In [None]:
## load images
import os 
from PIL import Image


import numpy as np
import matplotlib.pyplot as plt


In [None]:

def load_input_image(image_index, folder="train2", path="data_project", prefix="train"):
    """Load the puzzle image ``<path>/<folder>/<prefix>_<XX>.png`` as an RGB array.

    Parameters
    ----------
    image_index : int
        Index of the image; zero-padded to two digits in the file name.
    folder : str
        Sub-folder of ``path`` that contains the images.
    path : str
        Root data directory.
    prefix : str
        File-name prefix. Defaults to ``"train"``; use ``"test"`` for images
        named ``test_XX.png``.

    Returns
    -------
    np.ndarray
        The image converted to PIL "RGB" mode, as an (H, W, 3) array.
    """
    filename = "{}_{}.png".format(prefix, str(image_index).zfill(2))
    # The context manager releases the file handle once the pixels are copied.
    with Image.open(os.path.join(path, folder, filename)) as im:
        return np.array(im.convert('RGB'))

def save_solution_puzzles(image_index , solved_puzzles, outliers  , folder ="train2" , path = "data_project"  ,group_id = 0):
    """Save solved puzzles and outlier pieces as PNG files.

    Files are written to ``<path>/<folder>_solution_<group_id>`` (group id
    zero-padded to two digits) as ``solution_<XX>_<YY>.png`` and
    ``outlier_<XX>_<YY>.png``, where ``XX`` is ``image_index`` and ``YY`` the
    position in the corresponding list.

    Parameters
    ----------
    image_index : int
        Index of the input image the results belong to.
    solved_puzzles, outliers : sequence of np.ndarray
        Images accepted by ``PIL.Image.fromarray`` (e.g. uint8 RGB arrays).
    folder, path : str
        Used to build the output directory name.
    group_id : int or str
        Group number appended to the output directory name.
    """
    path_solution = os.path.join(path, folder + "_solution_{}".format(str(group_id).zfill(2)))
    # makedirs also creates missing parent directories and tolerates re-runs.
    os.makedirs(path_solution, exist_ok=True)

    print(path_solution)
    image_tag = str(image_index).zfill(2)
    for kind, images in (("solution", solved_puzzles), ("outlier", outliers)):
        for i, img in enumerate(images):
            filename = os.path.join(path_solution, "{}_{}_{}.png".format(kind, image_tag, str(i).zfill(2)))
            Image.fromarray(img).save(filename)


In [None]:
def solve_and_export_puzzles_image(image_index , folder = "train2" , path = "data_project"  , group_id = "00"):
    """
    Wrapper function to load an image, solve it and save the solution.

    NOTE: the solving step is still the template placeholder: it produces two
    random 512x512 "puzzles" and three random 128x128 "outliers".

    Parameters
    ----------
    image_index : int
        Index number of the image in the dataset.
    folder, path : str
        Location of the input images; also used for the output folder.
    group_id : int or str
        Group number used to name the output folder.

    Returns
    -------
    image_loaded : np.ndarray
        The input image.
    solved_puzzles : list of np.ndarray
        Solved puzzle images (uint8).
    outlier_images : list of np.ndarray
        Outlier pieces (uint8).
    """
    image_loaded = load_input_image(image_index, folder=folder, path=path)

    ## call functions to solve image_loaded
    solved_puzzles = [(np.random.rand(512, 512, 3) * 255).astype(np.uint8) for _ in range(2)]
    outlier_images = [(np.random.rand(128, 128, 3) * 255).astype(np.uint8) for _ in range(3)]

    # Forward ``path`` too, so results are written next to the input data.
    save_solution_puzzles(image_index, solved_puzzles, outlier_images,
                          folder=folder, path=path, group_id=group_id)

    return image_loaded, solved_puzzles, outlier_images

im, sol , out = solve_and_export_puzzles_image(6 , group_id = 6)

In [None]:
# Group number; save_solution_puzzles writes to "<folder>_solution_14".
group_id = 14
# Evaluate selected images: uncomment to solve and save every image listed in games_id.
#games_id = [6,10]  # to evaluate two images
#
#for i in games_id :
#    
#    print("solving " , i)
#    # Saving results
#    solve_and_export_puzzles_image(i , group_id = group_id)
  


## Evaluation metrics

The evaluation metrics will be released in the following days. 


## Gabor filters

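$$ gb(x,y) = \exp \left( -\frac{1}{2} \left( \frac{x_{\theta}^2}{\sigma^2} + \frac{y_{\theta}^2}{(\Gamma\sigma)^2} \right) \right) \cos \left( \frac{2 \pi}{\lambda} x_{\theta} + \psi \right) $$

where $x_{\theta} = x\cos\theta + y\sin\theta$ and $y_{\theta} = -x\sin\theta + y\cos\theta$ are the coordinates rotated by the filter orientation $\theta$, $\sigma$ is the width of the Gaussian envelope, $\Gamma$ its spatial aspect ratio, $\lambda$ the wavelength and $\psi$ the phase offset of the cosine carrier.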


In [None]:
# Plot all input images

# Count only the input images (train_XX.png). Other files in the folder, such as a
# labels CSV, would otherwise inflate the count. The path is built portably and
# matches the default folder of load_input_image.
# Assumes the images are numbered contiguously from 00.
dir_path = os.path.join("data_project", "train2")
nb_train_samples = len(glob(os.path.join(dir_path, "train_*.png")))

n_cols = 4
n_rows = max(1, int(np.ceil(nb_train_samples / n_cols)))
fig, ax = plt.subplots(n_rows, n_cols, figsize=(15, 10 * n_rows / 3), squeeze=False)
for k, axis in enumerate(ax.flat):
    if k < nb_train_samples:
        axis.imshow(load_input_image(k))
        axis.set_title(f"train_{k:02d}")
    axis.axis("off")

plt.show()

In [None]:
piece_dim = 128  # side length, in pixels, of an (unrotated) puzzle piece


def _fill_holes(mask, dtype):
    """Fill holes smaller than 10000 px in ``mask`` and return it as 0/255 of ``dtype``."""
    filled = skimage.morphology.remove_small_holes(mask.astype(bool), area_threshold=10000)
    return filled.astype(dtype) * 255


def preprocess(img, th_val=75, trans="rgb", c_k_size=20, o_k_size=5):
    """Binary foreground mask (255 = puzzle piece, 0 = background) of ``img``.

    Every pixel is described in the colour space selected by the first three
    characters of ``trans``: ``"rgb"``, ``"hsv"``, or ``"all"`` (RGB and HSV
    stacked). Its Euclidean distance to the image-wide mean colour is then
    thresholded. The mean is presumably dominated by the background, so distant
    pixels are treated as pieces. The raw mask is cleaned by hole filling, a
    closing, and a dilate/fill/erode pass, and specks under 50 px are removed.

    Parameters
    ----------
    img : ndarray
        (H, W, 3) colour image.
    th_val : float
        Distance threshold. Ignored when ``trans`` ends with ``"esh"``
        (e.g. ``"hsv + auto_thresh"``), in which case Otsu's threshold is used.
    trans : str
        Colour-space selector, optionally followed by ``" + auto_thresh"``.
    c_k_size : int
        Side of the square kernel used for the morphological closing.
    o_k_size : int
        Side of the square kernel used for the dilate/erode pass.

    Returns
    -------
    ndarray of uint8, shape (H, W)

    Raises
    ------
    Exception
        If ``trans`` does not start with ``"rgb"``, ``"hsv"`` or ``"all"``.
    """
    space = trans[:3]
    if space not in ("rgb", "hsv", "all"):
        raise Exception("Invalid Transformation")

    if space == "rgb":
        features = img
    else:
        # NOTE(review): load_input_image returns RGB, but this uses the BGR->HSV
        # conversion, so hue is computed with R and B swapped. Left unchanged
        # because th_val was tuned with this behaviour.
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        features = hsv if space == "hsv" else np.concatenate((img, hsv), axis=2)

    # Distance of each pixel's feature vector to the mean over the whole image.
    pix_dist = np.linalg.norm(features - np.mean(features, axis=(0, 1)), axis=2)

    if trans.endswith("esh"):
        th_val = filters.threshold_otsu(pix_dist.flatten())

    mask = _fill_holes(pix_dist > th_val, float)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((c_k_size, c_k_size), np.uint8))
    mask = _fill_holes(mask, float)

    # Dilate, fill holes, then erode back with the same small kernel.
    small_kernel = np.ones((o_k_size, o_k_size), np.uint8)
    mask = cv2.dilate(mask, small_kernel, iterations=1)
    mask = _fill_holes(mask, np.uint8)
    mask = cv2.erode(mask, small_kernel, iterations=1)

    # Drop isolated specks smaller than 50 px.
    return skimage.morphology.remove_small_objects(mask.astype(bool), min_size=50).astype(np.uint8) * 255

def comb_transf(img, transforms=None, th_val=80):
    """Foreground mask obtained by majority vote over several ``preprocess`` variants.

    Parameters
    ----------
    img : ndarray
        (H, W, 3) input image, as returned by ``load_input_image``.
    transforms : list of str, optional
        ``trans`` values passed to ``preprocess``. Defaults to the RGB, HSV and
        RGB+HSV distances, each with the fixed and the Otsu threshold.
    th_val : float
        Fixed threshold passed to every variant (ignored by the Otsu ones).

    Returns
    -------
    ndarray of float, shape (H, W)
        255 where a strict majority of the variants marks the pixel as
        foreground, 0 elsewhere. ``np.round`` rounds an exact tie (0.5) to 0.

    Raises
    ------
    ValueError
        If ``transforms`` is empty.
    """
    if transforms is None:
        transforms = ["rgb", "rgb + auto_thresh", "hsv", "hsv + auto_thresh", "all", "all + auto_thresh"]
    if len(transforms) == 0:
        raise ValueError("comb_transf needs at least one transformation")

    votes = np.stack([preprocess(img, trans=t, th_val=th_val) for t in transforms], axis=2)
    # Mean of the 0/255 votes scaled to [0, 1] and rounded -> majority decision.
    return np.round(np.mean(votes, axis=2) / 255) * 255

def pad_image(image, target_shape):
    """Centre-crop and/or zero-pad ``image`` to exactly ``target_shape``.

    Each spatial dimension larger than the target is centre-cropped. Each
    dimension smaller than the target is centred in a zero canvas of
    ``image.dtype``. Works for 2-D and 3-D arrays; trailing (channel)
    dimensions must already match ``target_shape``.

    Parameters
    ----------
    image : ndarray
        Image of shape (H, W) or (H, W, C).
    target_shape : sequence of int
        Desired output shape, e.g. ``[128, 128, 3]``.

    Returns
    -------
    ndarray of shape ``target_shape``
    """
    target_h, target_w = target_shape[0], target_shape[1]

    # Crop successively: the width crop must operate on the already
    # height-cropped array, not on the original image.
    cropped = image
    if cropped.shape[0] > target_h:
        offset = (cropped.shape[0] - target_h) // 2
        cropped = cropped[offset:offset + target_h]
    if cropped.shape[1] > target_w:
        offset = (cropped.shape[1] - target_w) // 2
        cropped = cropped[:, offset:offset + target_w]

    padded_image = np.zeros(target_shape, dtype=image.dtype)
    pad_top = (target_h - cropped.shape[0]) // 2
    pad_left = (target_w - cropped.shape[1]) // 2
    # Use the cropped size (not the input size) for the destination slice.
    padded_image[pad_top:pad_top + cropped.shape[0], pad_left:pad_left + cropped.shape[1]] = cropped

    return padded_image

def opt_crop(piece_bin, piece_original):
    """Cut a ``piece_dim`` x ``piece_dim`` patch out of a roughly cropped piece.

    Parameters
    ----------
    piece_bin : ndarray, 2-D
        Mask of the rotated, roughly cropped piece (non-zero = piece).
    piece_original : ndarray, (h, w, 3)
        Colour image cropped with exactly the same window as ``piece_bin``.

    Returns
    -------
    ndarray of shape (piece_dim, piece_dim, 3)
        - ``piece_original`` unchanged if it already has the right size;
        - otherwise, if both sides are at least ``piece_dim``, the window with
          the largest mask sum over all valid positions (first one in
          row-major order on ties);
        - otherwise the result of ``pad_image`` (centre crop / zero pad).
    """
    d = piece_dim
    h, w = piece_bin.shape[:2]
    if h == d and w == d:
        return piece_original
    if h < d or w < d:
        # At least one side is too short to slide a full window.
        return pad_image(piece_original, [d, d, 3])

    # Summed-area table with a leading zero row/column. The sum of the window
    # starting at (i, j) is then a 4-term difference, evaluated here for every
    # start position 0..h-d and 0..w-d (both bounds inclusive).
    sat = np.zeros((h + 1, w + 1))
    sat[1:, 1:] = np.cumsum(np.cumsum(piece_bin, axis=0, dtype=np.float64), axis=1)
    scores = sat[d:, d:] - sat[:-d, d:] - sat[d:, :-d] + sat[:-d, :-d]

    i, j = np.unravel_index(np.argmax(scores), scores.shape)
    return piece_original[i:i + d, j:j + d, :]

def _show_debug(title, image, size=(800, 800)):
    """Show ``image`` in a blocking OpenCV window (debugging aid).

    The image is resized to ``size`` first; pass ``size=None`` to show it as is.
    Returns after a key press, once the window is closed.
    """
    if size is not None:
        image = cv2.resize(image, size)
    cv2.imshow(title, image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def _padded_bbox(mask, min_size):
    """Bounding box ``(x, y, w, h)`` of the non-zero pixels of ``mask``, enlarged.

    A side shorter than ``min_size`` is grown to ``int(min_size)``. The top-left
    corner moves up/left by half the deficit, clamped at 0. This gives a margin
    of error around the object for the subsequent rotation and crop.
    """
    x, y, w, h = cv2.boundingRect(cv2.findNonZero(mask))
    x = max(0, x - max(0, int(np.ceil((min_size - w) / 2))))
    y = max(0, y - max(0, int(np.ceil((min_size - h) / 2))))
    return x, y, max(w, int(min_size)), max(h, int(min_size))


def isolate_pieces(original, plot=False):
    """Segment, straighten and crop every puzzle piece of an input image.

    Steps
    -----
    1. Foreground mask by majority vote of several thresholdings (``comb_transf``).
    2. Contours of the mask (``cv2.RETR_TREE``: outer and inner contours).
    3. Contours whose centroids lie within 75 px are merged (DBSCAN,
       ``min_samples=1``), so fragments of one piece become one contour.
    4. Merged contours with area <= 0.3 * piece_dim**2 are dropped as noise.
       Those with area >= 1.5 * piece_dim**2 are treated as touching pieces and
       split with a 25x25 morphological opening.
    5. Each remaining contour is cropped with a margin, rotated by its
       ``cv2.minAreaRect`` angle and cut to piece_dim x piece_dim by ``opt_crop``.

    Parameters
    ----------
    original : ndarray
        Input image as returned by ``load_input_image`` ((H, W, 3) RGB).
    plot : bool
        If True, show intermediate results in blocking OpenCV windows.

    Returns
    -------
    pieces : ndarray, shape (n_pieces, piece_dim, piece_dim, 3), uint8
    segmented_mask : ndarray, uint8
        Mask (255 = piece) of the contours that were kept. Both outputs are
        empty/zero when nothing is segmented.
    """
    im_prep = comb_transf(original).astype(np.uint8)
    if plot:
        _show_debug('Thresholded image', im_prep)
    contours, _ = cv2.findContours(im_prep, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)

    if len(contours) == 0:
        # Nothing segmented (DBSCAN cannot be fitted on zero samples).
        return np.zeros([0, piece_dim, piece_dim, 3], dtype=np.uint8), np.zeros_like(im_prep)

    # Centroid of each contour. Zero-area contours (lines / single points) are
    # placed at (0, 0), the value p_pos is initialised with.
    p_pos = np.zeros([len(contours), 2])
    for idx, contour in enumerate(contours):
        moments = cv2.moments(contour)
        if moments['m00'] != 0:
            p_pos[idx, :] = [moments['m10'] / moments['m00'], moments['m01'] / moments['m00']]
        elif plot:
            # Debug display only when requested: it blocks until a key press.
            canvas = np.zeros_like(im_prep)
            cv2.fillPoly(canvas, [contour], color=(255))
            _show_debug('test', canvas)

    # Merge contours that are very close together.
    db = DBSCAN(eps=75, min_samples=1).fit(p_pos)

    min_area = piece_dim**2 * 0.3  # at or below: noise
    max_area = piece_dim**2 * 1.5  # at or above: several pieces stuck together
    merged_contours = []
    stuck_pieces = []
    for lab in range(np.max(db.labels_) + 1):
        members = np.where(db.labels_ == lab)[0]
        merged = np.concatenate([contours[m] for m in members], axis=0)
        area = cv2.contourArea(merged)
        # Plain list filtering: np.delete on an object array of contours
        # collapses to a multi-dimensional array when all contours have the
        # same length (e.g. a single contour).
        if area >= max_area:
            stuck_pieces.append(merged)
        elif area > min_area:
            merged_contours.append(merged)

    # Split stuck pieces with a morphological opening.
    open_kernel = np.ones((25, 25), np.uint8)
    for idx, stuck_piece in enumerate(stuck_pieces):
        canvas = np.zeros_like(im_prep)
        cv2.fillPoly(canvas, [stuck_piece], color=(255))
        if plot:
            _show_debug(f'stuck_piece {idx+1}', canvas)

        canvas = cv2.morphologyEx(canvas, cv2.MORPH_OPEN, open_kernel)
        rebuilt_conts, _ = cv2.findContours(canvas, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        for cont in rebuilt_conts:
            merged_contours.append(cont)
            if plot:
                cont_canvas = np.zeros_like(im_prep)
                cv2.fillPoly(cont_canvas, [cont], color=(255))
                _show_debug('rebuilt_cont', cont_canvas)

    segmented_mask = np.zeros_like(im_prep)
    if merged_contours:
        cv2.fillPoly(segmented_mask, merged_contours, color=(255))
    if plot:
        _show_debug('Post-filtering', segmented_mask)

    # Crop and rotate each piece.
    pieces = np.zeros([len(merged_contours), piece_dim, piece_dim, 3], dtype=np.uint8)
    for idx, contour in enumerate(merged_contours):
        canvas = np.zeros_like(im_prep)
        cv2.fillPoly(canvas, [contour], color=(255))

        # Generous crop (>= 1.5 * piece_dim) so the rotation does not cut the piece.
        x1, y1, w1, h1 = _padded_bbox(canvas, piece_dim * 1.5)
        crop_canvas = canvas[y1:y1 + h1, x1:x1 + w1]

        # Align the piece with the image axes.
        angle = cv2.minAreaRect(contour)[2]
        rotated = imutils.rotate(crop_canvas, angle=angle)

        x2, y2, w2, h2 = _padded_bbox(rotated, piece_dim)
        final = rotated[y2:y2 + h2, x2:x2 + w2]

        # Apply the same crop / rotation / crop to the colour image.
        rotated_full = imutils.rotate(original[y1:y1 + h1, x1:x1 + w1], angle=angle)
        final_full = rotated_full[y2:y2 + h2, x2:x2 + w2]
        pieces[idx] = opt_crop(final, final_full).astype(np.uint8)

        if plot:
            _show_debug('Post-filtering', canvas)
            _show_debug('Post-filtering', pieces[idx], size=None)

    return pieces, segmented_mask

In [None]:
class PuzzleSolver():
    """Greedy border-matching solver for square-tile puzzles.

    Pieces are placed one at a time on an unbounded integer grid. Piece 0, in
    orientation 0, is fixed at location (0, 0). At every step, the open side
    ``[piece, orientation, side]`` whose best unplaced partner has the lowest
    border dissimilarity is selected. The partner is placed beyond that side if
    the grid cell is free; otherwise the side is discarded.

    Conventions
    -----------
    * orientation ``o`` (0..3): the piece rotated by ``90 * o`` degrees
      counter-clockwise (``np.rot90(piece, k=o)``);
    * side (0..3): top, right, bottom, left;
    * ``solution`` entries: ``[[piece, orientation], [row, col]]``.

    Parameters
    ----------
    pieces : ndarray, shape (n, d, d, channels)
        Square pieces, e.g. the output of ``isolate_pieces``.
    margin : int
        Pixels skipped from each edge when reading the border profiles
        (presumably to avoid edge pixels affected by segmentation/rotation).
    """

    # Grid offset (row, col) of the neighbour beyond each side.
    _SIDE_OFFSETS = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}

    def __init__(self, pieces, margin=3):
        # Side length of the square pieces (128 for isolate_pieces output).
        self.piece_dim = int(pieces.shape[1])
        self.pieces = pieces
        self.margin = margin
        self.already_placed = []  # indices of pieces already on the grid
        self.to_check = []        # open sides [piece, orientation, side]
        self.solution = []        # [[piece, orientation], [row, col]]
        # Scratch canvas used by print_piece when no origin is given.
        self.canvas = np.zeros([1500, 1500, 3], dtype=np.uint8)

        # Get borders
        self.get_borders()

    def get_borders(self):
        """Pre-compute the border profile of every piece, orientation and side.

        Sets ``self.borders`` to a float64 array of shape
        ``(n_pieces, 4 orientations, 4 sides, piece_dim, channels)``. Rotations
        are exact quarter turns (``np.rot90``). Borders are read ``margin``
        pixels inside the edge: top/bottom rows left to right, right/left
        columns top to bottom.
        """
        d, m = self.piece_dim, self.margin
        # (n, 4, d, d, c): every piece in its four orientations.
        rotated = np.stack([np.rot90(self.pieces, k=o, axes=(1, 2)) for o in range(4)], axis=1)
        self.borders = np.stack(
            [
                rotated[:, :, m, :, :],          # top row
                rotated[:, :, :, d - 1 - m, :],  # right column
                rotated[:, :, d - 1 - m, :, :],  # bottom row
                rotated[:, :, :, m, :],          # left column
            ],
            axis=2,
        ).astype(np.float64)  # float so differences of uint8 values cannot wrap

    def best_match(self, ref_idx):
        """Find the unplaced piece/orientation that best fits against one side.

        Parameters
        ----------
        ref_idx : sequence of 3 ints
            ``[piece, orientation, side]`` of the reference side.

        Returns
        -------
        (ndarray, float)
            ``[piece, orientation, side]`` of the best partner, whose side is
            the one facing the reference (``(ref side + 2) % 4``), and the sum
            of absolute border differences. If there is no candidate, the
            score is ``inf`` and the index ``[0, 0, 0]``.
        """
        interest_side = (ref_idx[2] + 2) % 4
        ref_border = self.borders[ref_idx[0], ref_idx[1], ref_idx[2]]
        comp = np.full(self.borders.shape[:3], np.inf)
        candidates = [i for i in range(self.borders.shape[0])
                      if i != ref_idx[0] and i not in self.already_placed]
        if candidates:
            facing = self.borders[:, :, interest_side]  # (n, 4, d, c)
            comp_side = comp[:, :, interest_side]        # view into comp
            # L1 distance for every candidate piece in all 4 orientations.
            comp_side[candidates] = np.abs(facing[candidates] - ref_border).sum(axis=(2, 3))

        min_pos = np.array(np.unravel_index(np.argmin(comp), comp.shape))
        return min_pos, np.min(comp)

    def try_all(self):
        """Best-match score of every open side, ``inf`` everywhere else.

        Returns an array shaped like ``self.borders.shape[:3]``.
        """
        values_arr = np.full(self.borders.shape[:3], np.inf)
        for piece, orientation, side in self.to_check:
            _, value = self.best_match(np.array([piece, orientation, side]))
            values_arr[piece, orientation, side] = value
        return values_arr

    def get_sides(self, piece_idx):
        """The four ``[piece, orientation, side]`` entries of a placed piece."""
        return [[piece_idx[0], piece_idx[1], side] for side in range(4)]

    def get_loc(self, ref_piece):
        """Grid location ``[row, col]`` beyond side ``ref_piece[2]`` of a placed piece.

        Raises
        ------
        ValueError
            If ``[ref_piece[0], ref_piece[1]]`` is not in the solution.
        """
        for placed, loc in self.solution:
            if placed == [ref_piece[0], ref_piece[1]]:
                d_row, d_col = self._SIDE_OFFSETS[int(ref_piece[2])]
                return [loc[0] + d_row, loc[1] + d_col]
        raise ValueError(f"piece {ref_piece[0]} in orientation {ref_piece[1]} is not placed")

    def check_loc(self, match_loc):
        """True if no piece of the solution occupies ``match_loc``."""
        target = [match_loc[0], match_loc[1]]
        return all(loc != target for _, loc in self.solution)

    def solve(self):
        """Greedily place every piece; the result is stored in ``self.solution``.

        All solver state is reset first, so calling ``solve`` again recomputes
        the same solution instead of appending to the previous one.
        """
        self.solution = []
        self.already_placed = []
        self.to_check = []
        if len(self.pieces) == 0:
            return

        # Piece 0 in orientation 0 anchors the grid at (0, 0).
        p0_idx = [0, 0]
        self.solution.append([p0_idx, [0, 0]])
        self.already_placed.append(p0_idx[0])
        self.to_check.extend(self.get_sides(p0_idx))

        while len(self.already_placed) < len(self.pieces):
            # Open side with the globally best partner.
            values_arr = self.try_all()
            ref_piece = np.array(np.unravel_index(np.argmin(values_arr), values_arr.shape))
            match_piece, _ = self.best_match(ref_piece)
            match_loc = self.get_loc(ref_piece)

            if self.check_loc(match_loc):
                self.solution.append([[match_piece[0], match_piece[1]], match_loc])
                self.already_placed.append(match_piece[0])
                self.to_check.extend(self.get_sides(match_piece))
                # Both sides of the new joint are now closed.
                self.to_check.remove(list(ref_piece))
                self.to_check.remove(list(match_piece))
            else:
                # Edge leads to a taken place, remove from check list
                self.to_check.remove(list(ref_piece))

    def print_piece(self, piece_idx, position, origin=None):
        """Draw one piece on ``self.canvas``.

        Parameters
        ----------
        piece_idx : [piece, orientation]
        position : [row, col] grid location
        origin : pixel (row, col) of grid cell (0, 0), optional
            Defaults to centring cell (0, 0) on the canvas.
        """
        d = self.piece_dim
        if origin is None:
            origin = np.array([self.canvas.shape[0] / 2, self.canvas.shape[1] / 2], dtype=int) - d // 2
        r0 = origin[0] + position[0] * d
        c0 = origin[1] + position[1] * d
        self.canvas[r0:r0 + d, c0:c0 + d, :] = np.rot90(self.pieces[piece_idx[0]], k=piece_idx[1])

    def assemble(self):
        """Render ``self.solution`` into an image cropped exactly to its bounding grid.

        Returns
        -------
        ndarray of uint8, shape (rows * piece_dim, cols * piece_dim, 3)
            Empty grid cells are black. Returns a (0, 0, 3) array when there
            is no solution.
        """
        d = self.piece_dim
        if not self.solution:
            return np.zeros([0, 0, 3], dtype=np.uint8)
        locs = np.array([[int(v) for v in loc] for _, loc in self.solution])
        top_left = locs.min(axis=0)
        n_rows, n_cols = locs.max(axis=0) - top_left + 1
        image = np.zeros([n_rows * d, n_cols * d, 3], dtype=np.uint8)
        for (piece_idx, _), (row, col) in zip(self.solution, locs - top_left):
            image[row * d:(row + 1) * d, col * d:(col + 1) * d] = np.rot90(self.pieces[piece_idx[0]], k=piece_idx[1])
        return image

    def display(self):
        """Assemble the solution into ``self.canvas`` and show it with matplotlib."""
        self.canvas = self.assemble()
        plt.imshow(self.canvas)

    def save(self):
        """Not implemented yet (TODO: write the assembled solution to disk)."""
        pass
        

In [None]:
# Tests
# Set to False to skip the test cells below. They open blocking OpenCV windows
# (cv2.imshow + cv2.waitKey(0)) that wait for a key press.
do_tests = True

In [None]:
if do_tests:
    # Test thresholding: show an input image and its majority-vote mask.
    original = load_input_image(5)
    im_prep = comb_transf(original)

    # load_input_image returns RGB, while OpenCV windows expect BGR.
    resized = cv2.resize(cv2.cvtColor(original, cv2.COLOR_RGB2BGR), (800,800))
    cv2.imshow('Original image, RGB', resized)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    resized = cv2.resize(im_prep, (800,800))
    cv2.imshow('Preprocessed Image', resized)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [None]:
if do_tests:
    # Test full preprocessing: show one input image and its segmentation mask.
    original = load_input_image(1)
    # load_input_image returns RGB, while OpenCV windows expect BGR.
    resized = cv2.resize(cv2.cvtColor(original, cv2.COLOR_RGB2BGR), (800,800))
    cv2.imshow('Original image, RGB', resized)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    pieces, seg_mask = isolate_pieces(original, plot=False)

    resized = cv2.resize(seg_mask, (800,800))
    cv2.imshow('Segmentation mask', resized)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [None]:
if do_tests:
    # Smoke test: run the full segmentation on every training image and report
    # how many pieces were found. Each total should equal the sum of the puzzle
    # sizes (9, 12 or 16 each) plus the number of outliers.
    n_pieces_found = {}
    for i in range(nb_train_samples):
        original = load_input_image(i)
        pieces, seg_mask = isolate_pieces(original, plot=False)
        n_pieces_found[i] = len(pieces)
    print(n_pieces_found)
        

In [None]:
if do_tests:
    # Puzzle assembly on image 3, using nine pieces selected by hand as
    # belonging to the same puzzle.
    original = load_input_image(3)
    pieces, _ = isolate_pieces(original, plot=False)
    p1_sel = np.array([0, 1, 2, 5, 11, 12, 13, 16, 17])
    p1 = pieces[p1_sel]

    myPuzzle = PuzzleSolver(p1, margin=2)
    myPuzzle.solve()
    myPuzzle.display()

In [None]:
if do_tests:
    # Puzzle assembly on image 4: segment the image, then keep the nine pieces
    # selected by hand as one puzzle.
    original = load_input_image(4)
    pieces, _ = isolate_pieces(original, plot=False)
    p2_sel = np.array([2, 4, 6, 10, 12, 13, 14, 17, 19])
    p2 = pieces[p2_sel]

    myPuzzle = PuzzleSolver(p2, margin=2)
    myPuzzle.solve()
    myPuzzle.display()