In [None]:
# Re-import edited modules (e.g. the local `blazeface` package) automatically before
# each cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Rectangle

from blazeface.constants import *
from blazeface.dataset import input_dataset, anchors, target_encoder, prediction_decoder, utils
from blazeface.model import blazeface, losses

# Load trained model

In [None]:
CHECKPOINT_PATH = '../data/experiments/20221031_run000/checkpoints/weights-16.hdf5'

# `load_model` returns the network together with the anchors that its raw
# outputs are decoded against further down.
model, all_anchors = blazeface.load_model(CHECKPOINT_PATH)

# Load data

In [None]:
# NOTE(review): `train[10%:]` selects the LAST 90% of the TFDS train split, not a
# "small subset" as this cell originally claimed. Confirm whether these images were
# held out from training before treating the results below as validation results.
ANALYSIS_SPLIT = "train[10%:]"

# The 3x3 grid at the end of the notebook (and the `x[1]` plot) index up to 9
# images of a single batch; with batch_size=1 they indexed past the end of it.
ANALYSIS_BATCH_SIZE = 9

data_validation, info = input_dataset.load_the300w_lp(split=ANALYSIS_SPLIT)

ds_validation = input_dataset.create_images_dataset(data_validation, batch_size=ANALYSIS_BATCH_SIZE)

# Inference

In [None]:
# Keep the 5th batch (or the last one, if the dataset is shorter) as the working
# example. Unpacking fails loudly on an empty dataset instead of silently reusing
# an `x` left over from an earlier run, as the `for ...: pass` loop would.
*_, x = ds_validation.take(5)

In [None]:
# Run the network on the batch. The result is read below by key: 'labels'
# (presumably per-anchor face scores — they are thresholded at 0.9 later) and
# 'deltas' (decoded against `all_anchors` into box + landmark coordinates).
predictions = model.predict(x)

In [None]:
# Turn the regression deltas into normalized box + landmark coordinates, then
# display the label and coordinate shapes as a sanity check.
pred_coordinates = prediction_decoder.get_bboxes_and_landmarks_from_deltas(
    all_anchors,
    predictions['deltas'],
)

(predictions['labels'].shape, pred_coordinates.shape)

In [None]:
# Scores as float32 before they are fed to the suppression step.
pred_scores = tf.cast(predictions['labels'], dtype=tf.float32)

In [None]:
# Weighted suppression (presumably a weighted-NMS variant) over the anchors of
# the first image in the batch.
weighted_suppressed_data = prediction_decoder.weighted_suppression(
    pred_scores[0],
    pred_coordinates[0],
)

In [None]:
# Columns 0-3 hold the box; every column after that is a flattened landmark coordinate.
weighted_bboxes, weighted_landmarks = (
    weighted_suppressed_data[..., :4],
    weighted_suppressed_data[..., 4:],
)

In [None]:
# Regroup the flat landmark columns as (detection, landmark, xy), then scale both
# landmarks and boxes from normalized units to IMG_SIZE x IMG_SIZE pixels.
weighted_landmarks = tf.reshape(weighted_landmarks, (-1, N_LANDMARKS, 2))
denormalized_landmarks = utils.denormalize_landmarks(weighted_landmarks, IMG_SIZE, IMG_SIZE)
denormalized_bboxes = utils.denormalize_bboxes(weighted_bboxes, IMG_SIZE, IMG_SIZE)

In [None]:
# Overlay the suppressed boxes and landmarks on the first image of the batch.
fig, ax = plt.subplots(figsize=(7, 7))
ax.imshow(x[0])

for bbox in denormalized_bboxes:
    # Plain Python floats for matplotlib (tf.split used to hand over 1-element tensors).
    x1, y1, x2, y2 = (float(v) for v in np.asarray(bbox))
    width, height = x2 - x1, y2 - y1
    # Degenerate rows (presumably padding from the suppression step) are skipped.
    if width <= 0 or height <= 0:
        continue
    ax.add_patch(Rectangle((x1, y1), width, height, fc="None", ec='green', lw=2, alpha=0.7))

for landmark in denormalized_landmarks:
    # Rows whose coordinates are all non-positive are treated as padding.
    if tf.reduce_max(landmark) <= 0:
        continue
    ax.scatter(landmark[:, 0], landmark[:, 1], alpha=0.9, s=20)

ax.set_title("Weighted-suppression output, image 0");

In [None]:
# Every anchor in the whole batch whose score exceeds 0.9, flattened into one
# row per anchor, then split into box columns and (landmark, xy) landmark arrays.
confident_mask = predictions['labels'][:, :, 0] > 0.9
temp = pred_coordinates[confident_mask]
temp_bboxes, temp_lmarks = temp[..., 0:4], temp[..., 4:]
temp_lmarks = tf.reshape(temp_lmarks, (-1, N_LANDMARKS, 2))

In [None]:
# Landmarks of the first confident (score > 0.9) anchor predicted for image 0.
# Anchors are selected from this image only, so the overlay matches the picture
# even for multi-image batches. Pixel scaling uses the batch's real width
# (x.shape[2]) and height (x.shape[1]) instead of a hard-coded 128.
image_index = 0
confident = predictions['labels'][image_index, :, 0] > 0.9
image_coords = tf.boolean_mask(pred_coordinates[image_index], confident)
image_lmarks = utils.denormalize_landmarks(
    tf.reshape(image_coords[:, 4:], (-1, N_LANDMARKS, 2)), x.shape[2], x.shape[1]
)

fig, ax = plt.subplots()
ax.imshow(x[image_index])
if image_lmarks.shape[0] > 0:
    ax.scatter(image_lmarks[0, :, 0], image_lmarks[0, :, 1])
else:
    # Without this guard, `[0]` raised IndexError when nothing passed the threshold.
    ax.set_title(f"No anchor above 0.9 for image {image_index}")

In [None]:
# Same overlay for the second image of the batch. Landmarks come from image 1's
# own anchors (previously the first confident anchor overall was used, which
# belongs to image 0). The cell is skipped gracefully when the batch holds a
# single image.
image_index = 1
fig, ax = plt.subplots()
if image_index < x.shape[0]:
    confident = predictions['labels'][image_index, :, 0] > 0.9
    image_coords = tf.boolean_mask(pred_coordinates[image_index], confident)
    image_lmarks = utils.denormalize_landmarks(
        tf.reshape(image_coords[:, 4:], (-1, N_LANDMARKS, 2)), x.shape[2], x.shape[1]
    )
    ax.imshow(x[image_index])
    if image_lmarks.shape[0] > 0:
        ax.scatter(image_lmarks[0, :, 0], image_lmarks[0, :, 1])
    else:
        ax.set_title(f"No anchor above 0.9 for image {image_index}")
else:
    ax.set_title(f"Batch holds only {x.shape[0]} image(s); no image {image_index}")

In [None]:
# Landmarks of the first confident anchor in normalized [0, 1] coordinates, drawn
# with image-style axes (y grows downward). Each row of `temp` holds 4 box values
# followed by alternating x, y values for every landmark.
fig, ax = plt.subplots()
if temp.shape[0] > 0:
    ax.scatter(temp[0, 4::2], temp[0, 5::2])
else:
    # Avoid IndexError when no anchor passed the score threshold.
    ax.set_title("No anchor above the score threshold")
# ylim=(1, 0) is the same as set_ylim(0, 1) followed by invert_yaxis().
ax.set(xlim=(0, 1), ylim=(1, 0), xlabel="x (normalized)", ylabel="y (normalized)");

In [None]:
# Suppression + denormalization for image 0, with scores cast to float32 as in the
# `pred_scores` path above. Pixel scaling uses the batch's actual width
# (x.shape[2]) and height (x.shape[1]).
image_scores = tf.cast(predictions['labels'][0], tf.float32)
weighted_suppressed_data = prediction_decoder.weighted_suppression(image_scores, pred_coordinates[0])

weighted_bboxes = weighted_suppressed_data[..., 0:4]
weighted_landmarks = weighted_suppressed_data[..., 4:]

denormalized_bboxes = utils.denormalize_bboxes(weighted_bboxes, x.shape[2], x.shape[1])
weighted_landmarks = tf.reshape(weighted_landmarks, (-1, N_LANDMARKS, 2))
denormalized_landmarks = utils.denormalize_landmarks(weighted_landmarks, x.shape[2], x.shape[1])

In [None]:
# Grid of the batch: each panel shows one image with its suppressed boxes (green)
# and landmarks (red). Panels beyond the batch size are left blank, so smaller
# batches no longer raise IndexError.
n_rows = 3
n_cols = 3
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15))
axes = np.ravel(axes)

n_images = min(int(x.shape[0]), axes.size)
img_width, img_height = x.shape[2], x.shape[1]

for i, ax in enumerate(axes):
    if i >= n_images:
        ax.axis('off')
        continue
    ax.imshow(x[i])

    # Scores cast to float32, consistent with `pred_scores` above.
    scores = tf.cast(predictions['labels'][i], tf.float32)
    suppressed = prediction_decoder.weighted_suppression(scores, pred_coordinates[i])
    bboxes = np.asarray(utils.denormalize_bboxes(suppressed[..., 0:4], img_width, img_height))
    landmarks = np.asarray(utils.denormalize_landmarks(
        tf.reshape(suppressed[..., 4:], (-1, N_LANDMARKS, 2)), img_width, img_height))

    # Draw every non-degenerate box, as in the single-image plot. This also
    # avoids IndexError on empty output.
    for x1, y1, x2, y2 in bboxes:
        if x2 - x1 <= 0 or y2 - y1 <= 0:
            continue
        ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1, fc="None", ec='green', lw=2))

    # `landmarks` is (detection, landmark, xy). Indexing the LAST axis yields the
    # x/y of every landmark; `[:, 0]` / `[:, 1]` wrongly paired landmark 0 with
    # landmark 1. Rows that are all non-positive are skipped as padding.
    visible = landmarks.max(axis=(1, 2)) > 0
    ax.scatter(landmarks[visible][..., 0], landmarks[visible][..., 1], alpha=0.6, s=3, c='red')