In [None]:
import numpy as np
import tensorflow as tf
from pathlib import Path

import matplotlib.patches as patches
import matplotlib.pyplot as plt

from watch_recognition.data_preprocessing import load_keypoints_data
from watch_recognition.models import build_backbone

%matplotlib inline


# Number of training epochs; passed to model.fit as `epochs`.
EPOCHS = 2
# Spatial size shared by build_backbone, the Keras Input layer and the
# `target_size` used when loading the standalone test image.
image_size = (224, 224)

In [None]:
# Backbone built by the project-local helper; it is later called on the Keras
# Input tensor and its output feeds the 1x1 convolution head.
# NOTE(review): architecture and weights are decided inside build_backbone,
# which is not visible from this notebook - confirm there.
base_model = build_backbone(image_size)

In [None]:
# Assemble the keypoint model: backbone features followed by a 1x1 convolution
# that emits four sigmoid-activated heatmap channels.
keypoint_head = tf.keras.layers.Conv2D(
    filters=4,
    kernel_size=1,
    strides=1,
    padding="same",
    activation="sigmoid",
)

inputs = tf.keras.Input(shape=(image_size[0], image_size[1], 3))
x = base_model(inputs)
output = keypoint_head(x)

model = tf.keras.models.Model(inputs=inputs, outputs=output)

In [None]:
# Print the layer-by-layer summary of the assembled model.
model.summary()


In [None]:
# Output shape with the leading (batch) axis dropped; the same value is handed
# to load_keypoints_data as `model_output_shape`.
model.output.shape[1:]

In [None]:
# Both splits are read from the same tags file; targets are generated to match
# the model's output shape.
# NOTE(review): the first call relies on load_keypoints_data's default `split`
# - presumably the training split; confirm in watch_recognition.
tags_csv = Path("../data/watch-points/tags.csv")
target_shape = model.output.shape[1:]

X, y = load_keypoints_data(tags_csv, model_output_shape=target_shape)
X_val, y_val = load_keypoints_data(
    tags_csv,
    split='validation',
    model_output_shape=target_shape,
)
X.shape, y.shape

In [None]:
# Binary cross-entropy matches the sigmoid activation on the output layer;
# the optimizer is Adam selected by name.
model.compile(loss='binary_crossentropy', optimizer='adam')

In [None]:
# Fit on (X, y) for EPOCHS epochs, reporting metrics on (X_val, y_val).
model.fit(
    x=X,
    y=y,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
)

In [None]:

def _peak_xy(heatmap):
    """Return the (x, y) grid cell, i.e. (column, row), of the heatmap maximum."""
    row, col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    return np.array([col, row])


def _turn_fraction(reference, hand):
    """Return the fraction of a full turn between two (x, y) vectors.

    With y growing downwards (image coordinates) the turn is measured
    clockwise from ``reference`` to ``hand``. The result lies in [0, 1) up to
    floating-point rounding, which callers guard against with a modulo.
    """
    angle = np.arctan2(reference[0], reference[1]) - np.arctan2(hand[0], hand[1])
    return (angle / (2 * np.pi)) % 1.0


def run_on_image_debug(model, image):
    """Plot every predicted heatmap for ``image`` and print the time read from them.

    One figure with three panels is shown per output channel: the heatmap,
    the transposed heatmap, and the input image overlaid with the output grid
    and a red rectangle on the cell where that channel peaks.

    The time is then derived from the peaks of channels 0-3, which are treated
    as the centre, hour-hand, minute-hand and top-of-dial keypoints.

    Args:
        model: callable that accepts a batch of images and returns an object
            with ``.numpy()`` yielding an array indexed as
            (batch, row, column, channel) with at least four channels.
        image: numpy array indexed as (row, column, channel).

    Returns:
        None. The reading is printed as ``H:MM``.
    """
    predicted = model(np.expand_dims(image, 0)).numpy()
    heatmaps = predicted[0]
    grid_rows, grid_cols = heatmaps.shape[:2]
    # Size of one output cell in input pixels, kept separate per axis so that
    # non-square images and grids are drawn correctly.
    cell_h = image.shape[0] / grid_rows
    cell_w = image.shape[1] / grid_cols

    for i in range(heatmaps.shape[2]):
        heatmap = heatmaps[:, :, i]
        fig, ax = plt.subplots(1, 3)

        # imshow's extent is [left, right, bottom, top]: x spans columns and
        # y spans rows, with the origin at the top-left corner.
        ax[0].imshow(heatmap, extent=[0, grid_cols, grid_rows, 0])
        # Second panel keeps the transposed view of the original debug layout.
        ax[1].imshow(heatmap.T, extent=[0, grid_rows, grid_cols, 0])
        ax[2].imshow(image.astype('uint8'),
                     extent=[0, image.shape[1], image.shape[0], 0])

        # Overlay the output grid: vertical lines per column, horizontal per row.
        for j in range(grid_cols):
            ax[2].axvline(j * cell_w)
        for j in range(grid_rows):
            ax[2].axhline(j * cell_h)

        peak_x, peak_y = _peak_xy(heatmap)
        rect_pred = patches.Rectangle((peak_x * cell_w, peak_y * cell_h),
                                      cell_w, cell_h,
                                      linewidth=1, edgecolor='r', facecolor='red')
        ax[2].add_patch(rect_pred)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Hand and reference directions are expressed relative to the dial centre.
    center = _peak_xy(heatmaps[:, :, 0])
    hour = _peak_xy(heatmaps[:, :, 1]) - center
    minute = _peak_xy(heatmaps[:, :, 2]) - center
    top = _peak_xy(heatmaps[:, :, 3]) - center

    # Floor instead of round: an hour hand three quarters of the way from
    # 3 to 4 must still read as 3, and minutes must never print as 60.
    read_hour = int(_turn_fraction(top, hour) * 12) % 12
    read_minute = int(_turn_fraction(top, minute) * 60) % 60
    print(f"{read_hour}:{read_minute:02d}")

In [None]:
# Visual sanity check on the first sample of X.
run_on_image_debug(model, X[0])


In [None]:
# Run the debug visualisation on a standalone example photo, resized to the
# model's input size.
path = Path("../example_data/test-image-2.jpg")
# color_mode is passed by keyword: in TF 2.x Keras the second positional
# parameter of load_img is the deprecated `grayscale` flag, so a positional
# "rgb" would not bind to the colour mode there.
test_image = tf.keras.preprocessing.image.load_img(
    path, color_mode="rgb", target_size=image_size, interpolation="bicubic"
)
test_image_np = tf.keras.preprocessing.image.img_to_array(test_image)
run_on_image_debug(model, test_image_np)
