In [None]:
%load_ext autoreload
%autoreload 2
from datetime import datetime
from functools import partial
from pathlib import Path

import matplotlib.pyplot as plt
import tensorflow as tf

from watch_recognition.data_preprocessing import load_keypoints_data
from watch_recognition.models import build_backbone
from watch_recognition.reports import log_distances, run_on_image_debug, generate_report_for_keypoints

plt.style.use("dark_background")
%matplotlib inline


EPOCHS = 150
image_size = (224, 224)

In [None]:
# Feature-extraction backbone that the keypoint head below is stacked on.
# NOTE(review): the training cell labels runs 'efficientnetb0' — confirm that
# build_backbone actually returns that architecture.
base_model = build_backbone(image_size)

In [None]:
# Assemble the keypoint model: the shared backbone followed by a 1x1
# convolution that maps the backbone feature map to 4 sigmoid-activated
# output channels (presumably one map per keypoint — TODO confirm against
# load_keypoints_data).
inputs = tf.keras.Input(shape=(image_size[0], image_size[1], 3))
features = base_model(inputs)
# Optional regularisation experiment: tf.keras.layers.Dropout(0.3)(features)
keypoint_head = tf.keras.layers.Conv2D(
    filters=4,
    kernel_size=1,
    strides=1,
    padding="same",
    activation="sigmoid",
)
output = keypoint_head(features)

model = tf.keras.models.Model(inputs=inputs, outputs=output)

In [None]:
# Print a layer-by-layer overview of the assembled model.
model.summary()


In [None]:
# Per-sample output shape (leading batch dimension dropped); the data-loading
# cells below pass this same value as `model_output_shape`.
model.output.shape[1:]

In [None]:
# Load the annotated watch keypoints using load_keypoints_data's default split
# (presumably training — confirm), with targets generated at the model's
# per-sample output shape.
train_tags_csv = Path("../data/watch-points/tags.csv")
X, y, y_labels = load_keypoints_data(
    train_tags_csv,
    model_output_shape=model.output.shape[1:],
)
(X.shape, y.shape, y_labels.shape)

In [None]:
# Load the validation split from the same tags file, with targets generated
# at the model's per-sample output shape, just like the training data.
val_output_shape = model.output.shape[1:]
X_val, y_val, y_val_labels = load_keypoints_data(
    Path("../data/watch-points/tags.csv"),
    split='validation',
    model_output_shape=val_output_shape,
)
(X_val.shape, y_val.shape, y_val_labels.shape)

In [None]:

# Compile and train the keypoint model. Keras metrics go to TensorBoard via the
# TensorBoard callback; per-epoch distance metrics are written by
# `log_distances` through two extra summary writers (train / validation).
# NOTE: re-running this cell recompiles (fresh optimizer state) but continues
# from the model's current weights — restart the kernel for a clean run.
optimizer = tf.keras.optimizers.Adam(1e-2)
model.compile(loss='binary_crossentropy', optimizer=optimizer)

start = datetime.now()
TYPE = 'keypoint'
MODEL_NAME = 'efficientnetb0'
logdir = f"tensorboard_logs/{TYPE}/{MODEL_NAME}/run_{start.timestamp()}"
print(logdir)
file_writer_distance_metrics_train = tf.summary.create_file_writer(logdir + "/train")
file_writer_distance_metrics_validation = tf.summary.create_file_writer(logdir + "/validation")

try:
    history = model.fit(
        X, y,
        epochs=EPOCHS,
        validation_data=(X_val, y_val),
        callbacks=[
            tf.keras.callbacks.TensorBoard(
                log_dir=logdir,
                update_freq="epoch",
            ),
            # Safe to re-enable: the timing summary below uses the number of
            # epochs actually run, not EPOCHS.
            # tf.keras.callbacks.EarlyStopping(
            #     monitor="val_loss",
            #     restore_best_weights=True,
            #     patience=10,
            # ),
            tf.keras.callbacks.LambdaCallback(
                on_epoch_end=partial(log_distances, X=X, y=y,
                                     file_writer=file_writer_distance_metrics_train, model=model)),
            tf.keras.callbacks.LambdaCallback(
                on_epoch_end=partial(log_distances, X=X_val, y=y_val,
                                     file_writer=file_writer_distance_metrics_validation, model=model)),
        ]
    )
finally:
    # Flush and release the custom writers even if training is interrupted,
    # so the distance metrics logged so far reach disk.
    file_writer_distance_metrics_train.close()
    file_writer_distance_metrics_validation.close()

# total_seconds() covers the whole duration; `.seconds` is only the seconds
# component of the timedelta and silently drops whole days.
elapsed = (datetime.now() - start).total_seconds()
# Average over the epochs actually run (fewer than EPOCHS with early stopping).
epochs_run = max(len(history.epoch), 1)
print(f"total training time: {elapsed / 60:.2f} minutes, average: {elapsed / 60 / epochs_run:.2f} minutes/epoch")

In [None]:
# Debug view of the trained model on the first training sample, passing its
# target y[0] as well (presumably shown alongside the prediction — confirm).
run_on_image_debug(model, X[0], y[0], show_grid=False)


In [None]:
# Load an unseen example photo, resized to the model input size, as a float
# array for inference.
path = Path("../example_data/test-image-2.jpg")
# color_mode must be passed by keyword: the second positional parameter of
# load_img is the deprecated `grayscale` flag on Keras 2.x, so a positional
# "rgb" can silently load the image as single-channel grayscale.
test_image = tf.keras.preprocessing.image.load_img(
    path, color_mode="rgb", target_size=image_size, interpolation="bicubic",
)
test_image_np = tf.keras.preprocessing.image.img_to_array(test_image)
# The model input is (height, width, 3); fail fast instead of deep inside predict.
assert test_image_np.shape == (*image_size, 3), test_image_np.shape

In [None]:
# Debug view of the trained model on the external test image (no ground-truth
# target is passed for this image).
run_on_image_debug(model, test_image_np, show_grid=False)

In [None]:
# Keypoint report on the training set; show_top_n_errors=5 presumably limits
# the display to the 5 worst samples.
# NOTE(review): the third positional argument is None — confirm its meaning in
# generate_report_for_keypoints.
generate_report_for_keypoints(model, X, None, y_labels, show_top_n_errors=5)

In [None]:
# Same keypoint report on the validation set, for comparison with the training
# report.
# NOTE(review): the third positional argument is None — confirm its meaning in
# generate_report_for_keypoints.
generate_report_for_keypoints(model, X_val, None, y_val_labels, show_top_n_errors=5)