# Evaluate puzzle localizer

In this notebook, we measure the performance of the puzzle localizer on the task of finding the keypoints (corners) of the Sudoku puzzle.

In [1]:
from pathlib import Path
from tqdm import tqdm
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from models.puzzle_localizer.model import HeuristicPuzzleLocalizer
from models.puzzle_localizer.dataset import get_data

In [2]:
# Instantiate the heuristic puzzle localizer with its default constructor arguments.
# The same instance is called on every image in the evaluation and debug cells below.
model = HeuristicPuzzleLocalizer()

In [3]:
# Evaluate on the validation split: it is the one to use for model selection.
# The "test" split is not needed here, because the end-to-end model can be evaluated on it instead.

split = "val"


def _iter_split_examples():
    """Return the examples of the selected `split` (looked up each time the dataset is iterated)."""
    return get_data(split)


dataset = tf.data.Dataset.from_generator(
    _iter_split_examples, output_types=(tf.float32, tf.float32)
)

In [8]:
# Per-example threshold on the MSE between the predicted and the actual keypoints:
# an example counts as correctly localized when its MSE is strictly below this value.

# The threshold was obtained by visual inspection. If the MSE is lower than this, the predicted
# keypoints look almost indistinguishable from the real ones and there is a high chance that the
# cells will be extracted correctly.

MSE_THRESHOLD = 9e-5

In [10]:
# Running totals over the whole split.
num_correct = 0
mse_total = 0
count = 0

for image, keypoints in tqdm(dataset):
    keypoints_pred = model(image)
    # Reduce the metric to a plain Python float so it can be accumulated and compared directly.
    example_mse = tf.keras.metrics.mse(keypoints, keypoints_pred).numpy().item()
    # An example is counted as correctly localized when its error is below the threshold.
    if example_mse < MSE_THRESHOLD:
        num_correct += 1
    mse_total += example_mse
    count += 1

# Turn the totals into per-example averages.
accuracy = num_correct / count
mse = mse_total / count

print(f"Accuracy: {accuracy:%}")
print(f"MSE: {mse}")

80it [04:22,  3.28s/it]

Accuracy: 98.750000%
MSE: 0.005074748279573526





## Debug

In [5]:
# Capture the dataset-level mean MSE computed by the evaluation cell, to use as a
# per-example threshold when searching for bad examples.
# NOTE(review): this must run while `mse` still holds the dataset average; if a later cell
# has rebound `mse`, re-run the evaluation cell first.

mse_threshold = mse

In [6]:
# Find examples where the MSE is worse than the dataset average captured in `mse_threshold`.
#
# The per-example error is stored under its own name (`example_mse`) so that the
# dataset-level `mse` computed by the evaluation cell is not overwritten. Otherwise,
# re-running the cell that captures `mse_threshold` would silently pick up the error of
# the last example instead of the average.

bad_examples = []

for image, keypoints in tqdm(dataset):
    keypoints_pred = model(image)
    # Reduce to a plain Python float, the same way the evaluation loop does.
    example_mse = tf.keras.metrics.mse(keypoints, keypoints_pred).numpy().item()
    if example_mse > mse_threshold:
        # Keep everything needed to inspect the failure later: input, target, prediction, error.
        bad_examples.append((image, keypoints, keypoints_pred, example_mse))

print(f"Found {len(bad_examples)} bad examples.")

80it [04:24,  3.31s/it]

Found 1 bad examples.





In [7]:
# Report the error of every flagged example; the MSE is the last field of each tuple.
for index, bad_example in enumerate(bad_examples):
    bad_example_mse = bad_example[-1]
    print(f"MSE on bad example {index:>2}: {bad_example_mse:.5f}")

MSE on bad example  0: 0.40563


In [15]:
# Pick the first flagged example for closer inspection.
# Its error is bound to `example_mse` rather than `mse`, so the dataset-level `mse`
# from the evaluation cell is not replaced by a single example's value.
image, keypoints, keypoints_pred, example_mse = bad_examples[0]

In [None]:
# Display the unmodified input image of the selected example.

fig, ax = plt.subplots()

ax.imshow(image)
ax.set_title("Original image")
# Tick marks carry no information for a photo, so hide both axes.
for axis in (ax.xaxis, ax.yaxis):
    axis.set_visible(False)

plt.show()

In [17]:
import PIL.Image
import PIL.ImageDraw

def draw_keypoints(image, keypoints, color):
    """Return a copy of `image` with the polygon through `keypoints` outlined in `color`.

    Args:
        image: Image array accepted by `tf.keras.preprocessing.image.array_to_img`.
        keypoints: Absolute pixel coordinates of the polygon vertices, as a flat
            `[x0, y0, x1, y1, ...]` tensor, array or list.
        color: Outline color understood by PIL, e.g. "limegreen".

    Returns:
        The annotated image converted back to an array and divided by 255.
    """
    pil_image = tf.keras.preprocessing.image.array_to_img(image)
    # Hand PIL a flat list of plain Python floats, so the drawing does not depend on
    # how PIL interprets tensor/array objects or on their dtype.
    polygon_xy = np.asarray(keypoints).astype(float).ravel().tolist()
    PIL.ImageDraw.Draw(pil_image).polygon(polygon_xy, outline=color, width=8)
    return tf.keras.preprocessing.image.img_to_array(pil_image) / 255.0

In [18]:
def relative_keypoints_to_absolute(image, keypoints):
    """Scale relative keypoints to pixel coordinates of `image`.

    Args:
        image: Image whose `shape[0]` scales the y coordinates and `shape[1]` the x coordinates.
        keypoints: Flat, interleaved coordinates `[x0, y0, x1, y1, ...]`
            (presumably fractions of the image size; confirm against the dataset).

    Returns:
        A flat tensor with the same interleaved layout, expressed in pixels.
    """
    height, width = image.shape[0], image.shape[1]
    # Even positions hold x values, odd positions hold y values.
    xs = keypoints[0::2] * width
    ys = keypoints[1::2] * height
    # Pair each x with its y, then flatten back to the interleaved layout.
    return tf.reshape(tf.stack([xs, ys], axis=-1), (-1,))

In [None]:
# Show image with predicted keypoints

# Reuse the prediction stored with the bad example instead of running the model again
# (a single call takes a few seconds), and keep the ground-truth `keypoints` intact by
# giving the pixel-space prediction its own name.
keypoints_pred_abs = relative_keypoints_to_absolute(image, keypoints_pred)
image_with_predicted_keypoints = draw_keypoints(image, keypoints_pred_abs, "limegreen")

fig, ax = plt.subplots()

ax.set_title("Image with predicted keypoints")
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
ax.imshow(image_with_predicted_keypoints)

plt.show()

In [None]:
# Convert the RGB image to grayscale; the result is reused by the thresholding step.
image_grayscale = tf.image.rgb_to_grayscale(image)

fig, ax = plt.subplots()

ax.imshow(image_grayscale, cmap="gray")
ax.set_title("Grayscale image")
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)

plt.show()

In [None]:
from models.puzzle_localizer.model import adaptive_threshold

# Apply the localizer's own thresholding step with the model's configured parameters.
# A leading axis is added before the call and the first result is taken back out
# (presumably `adaptive_threshold` works on batches of images; confirm in the model module).
image_batch = tf.expand_dims(image_grayscale, axis=0)
image_thresholded = adaptive_threshold(
    image_batch,
    blur_size=model.blur_size,
    threshold=model.threshold,
)[0]

fig, ax = plt.subplots()

ax.imshow(image_thresholded, cmap="gray")
ax.set_title("Thresholded image")
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)

plt.show()