In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Rectangle

from blazeface.dataset import input_dataset, anchors, target_encoder, prediction_decoder, utils
from blazeface.model import losses

# Load data

In [None]:
# Only a small slice of the 300W-LP training split is loaded so this exploratory
# analysis stays quick; widen the slice here if more examples are needed.
TRAIN_SPLIT = "train[:10%]"
data_train, info = input_dataset.load_the300w_lp(split=TRAIN_SPLIT)

In [None]:
# Pull a single example to inspect the feature dictionary's keys.
# next(iter(...)) fails loudly on an empty dataset, whereas the `for ...: break`
# idiom left the variable undefined (or silently stale from a previous run).
example = next(iter(data_train.take(1)))
example.keys()

# Visualize raw labels

In [None]:
def visualize_landmarks(sample, ax, landmarks=None):
    """Scatter normalized 2-D landmarks over an image drawn on ``ax``.

    Landmark columns are treated as normalized (x, y): x is scaled by the image
    width and y by the image height, so points land in pixel coordinates.

    Args:
        sample (dict): must contain key 'image' (indexed as height x width x ...)
            and, unless ``landmarks`` is given, 'landmarks_2d' (N x 2).
        ax (matplotlib.axes.Axes): axes to draw on.
        landmarks (array-like, optional): if given, it overrides
            'landmarks_2d' in ``sample``.

    Returns:
        The artist returned by ``ax.scatter``.
    """
    img = sample['image']
    if landmarks is None:
        landmarks = sample['landmarks_2d']
    # np.asarray accepts eager tf tensors as well as numpy arrays.
    landmarks = np.asarray(landmarks)
    # Axis 0 of the image is height and axis 1 is width: x must scale with the
    # width and y with the height. The previous version had these swapped, which
    # is only correct for square images.
    height, width = img.shape[0], img.shape[1]
    return ax.scatter(landmarks[:, 0] * width, landmarks[:, 1] * height, alpha=0.6, s=2, c='red')

In [None]:
# 5x5 grid of raw images with their full 2-D landmark sets overlaid.
n_rows, n_cols = 5, 5
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15))
axes = np.ravel(axes)

for ax, sample in zip(axes, data_train.take(n_rows * n_cols)):
    ax.imshow(sample['image'])
    visualize_landmarks(sample, ax)

# Visualize preprocessed labels (face bounding boxes and reduced landmarks used as BlazeFace targets)

In [None]:
def visualize_bbox(sample, ax):
    """Draw the face bounding box derived from a sample's landmarks.

    The box is produced by ``input_dataset.landmarks_to_bboxes`` from
    ``sample['landmarks_2d']``. Its (x1, y1, x2, y2) corners are scaled to pixels
    (x by image width, y by image height) and drawn as an unfilled green rectangle.

    Args:
        sample (dict): must contain keys 'image' and 'landmarks_2d'.
        ax (matplotlib.axes.Axes): axes to draw on.

    Returns:
        matplotlib.patches.Rectangle: the patch added to ``ax``.
    """
    img = sample['image']
    height, width = img.shape[0], img.shape[1]
    # Read the landmarks from the `sample` argument. The previous version read the
    # notebook-global `x`, so it drew the wrong box whenever the caller's loop
    # variable had a different name or `x` was stale.
    x1, y1, x2, y2 = input_dataset.landmarks_to_bboxes(sample['landmarks_2d']).numpy()
    rect = Rectangle(
        (x1 * width, y1 * height),
        (x2 - x1) * width,
        (y2 - y1) * height,
        fc="None",
        ec='green',
    )
    ax.add_patch(rect)
    return rect

In [None]:
# Same 5x5 grid, now showing the landmark-derived face box (green) and the
# reduced landmark set (red) that the dataset pipeline uses as targets.
n_rows, n_cols = 5, 5
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15))
axes = np.ravel(axes)

# NOTE(review): the loop variable is deliberately kept as `x`; one version of
# visualize_bbox reads the notebook-global `x` instead of its argument.
for ax, x in zip(axes, data_train.take(n_rows * n_cols)):
    ax.imshow(x['image'])
    visualize_bbox(x, ax)
    reduced_landmarks = input_dataset.reduce_landmarks(x['landmarks_2d']).numpy()
    visualize_landmarks(x, ax, landmarks=reduced_landmarks)

# Visualize anchors

In [None]:
# Generate the fixed anchor set used by the target encoder and prediction decoder.
# NOTE(review): later cells iterate rows as 4 values, pass them to
# utils.xywh_to_xyxy and scale them by image size, so rows are presumably
# (center_x, center_y, w, h) normalized to [0, 1] — confirm against
# anchors.generate_anchors.
all_anchors = anchors.generate_anchors()

In [None]:
# Draw every anchor as an outline, cycling through the CSS4 palette.
fig, ax = plt.subplots(figsize=(10, 10))
# Hoisted out of the loop: the original rebuilt the color key list for every anchor.
anchor_palette = list(mcolors.CSS4_COLORS.values())

for i, loc in enumerate(all_anchors):
    # NOTE(review): rows are unpacked as (center_x, center_y, w, h), matching the
    # utils.xywh_to_xyxy interpretation used when plotting positive anchors.
    # The previous version named the first two values (y1, x1), transposing the layout.
    cx, cy, w, h = loc
    c = anchor_palette[i % len(anchor_palette)]
    rect = Rectangle((cx - w / 2, cy - h / 2), w, h, fc="None", ec=c, alpha=0.9, lw=0.5)
    ax.add_patch(rect)

# add_patch does not request an autoscale, so fit the view to the anchors explicitly.
# Then flip y so the layout reads like image coordinates (origin top-left, y down).
ax.autoscale_view()
ax.invert_yaxis()
ax.set_aspect("equal")
ax.set_title(f"{len(all_anchors)} anchors");

# Create input dataset

In [None]:
def _with_bbox_and_reduced_landmarks(img, lmarks):
    # Derive the face box and the reduced landmark set from the full landmark array.
    return img, input_dataset.landmarks_to_bboxes(lmarks), input_dataset.reduce_landmarks(lmarks)


# Batched pipeline yielding (image, bbox, reduced landmarks) per example.
# Upstream maps turn each example into an (image, landmarks) pair, which is
# the signature the last map expects.
ds = (
    data_train
    .map(input_dataset.unpack_dct)
    .map(input_dataset.preprocess_image)
    .map(_with_bbox_and_reduced_landmarks)
    .batch(12)
)

In [None]:
# Take the first batch. next(iter(...)) raises immediately on an empty dataset
# instead of leaving `sample_batch` undefined or stale from an earlier run.
sample_batch = next(iter(ds.take(1)))

[element.shape for element in sample_batch]

In [None]:
# Encode the batch's boxes (sample_batch[1]) and landmarks (sample_batch[2])
# against the anchors, converting both targets to numpy arrays.
deltas, labels = (
    target.numpy()
    for target in target_encoder.calculate_targets(all_anchors, sample_batch[1], sample_batch[2])
)
deltas.shape, labels.shape

In [None]:
# Decode the encoded targets back through the prediction decoder — presumably a
# round-trip check that encoder and decoder agree (decoded values should match
# the ground truth at positive anchors).
dec = prediction_decoder.get_bboxes_and_landmarks_from_deltas(all_anchors, deltas)
dec.shape

In [None]:
# Decoded outputs at the positive anchors of the first image in the batch.
# Index the rank-3 labels directly (labels[image, anchor, 0], as the plotting cell
# does) instead of reshaping to a hard-coded (8, 896). The dataset is batched with
# 12, so that reshape cannot succeed.
np.asarray(dec)[0][labels[0, :, 0] > 0]

In [None]:
# labels is indexed as labels[i, :, 0] elsewhere, so it is presumably
# (batch, num_anchors, channels) — confirm against target_encoder.calculate_targets.
labels.shape

In [None]:
def draw_xyxy_box(ax, box, img_shape, **rect_kwargs):
    """Draw a normalized (x1, y1, x2, y2) box on ``ax`` in pixel coordinates.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on.
        box: four corner values; x is scaled by the image width, y by its height.
        img_shape: image shape whose first two entries are (height, width).
        **rect_kwargs: forwarded to matplotlib.patches.Rectangle.

    Returns:
        matplotlib.patches.Rectangle: the added (unfilled) patch.
    """
    height, width = float(img_shape[0]), float(img_shape[1])
    # float() also unwraps scalar tf tensors into plain Python numbers.
    x1, y1, x2, y2 = (float(v) for v in box)
    rect = Rectangle((x1 * width, y1 * height), (x2 - x1) * width, (y2 - y1) * height, fc="None", **rect_kwargs)
    ax.add_patch(rect)
    return rect


# Overlay the ground-truth face box (thick green) and every anchor the target
# encoder marked as positive for one image of the batch.
fig, ax = plt.subplots(figsize=(6, 6))
# Cycle through the palette: indexing a fixed-size color dict with 5*ci raised
# KeyError once an image had more than ~29 positive anchors.
palette = list(mcolors.CSS4_COLORS.values())

sample_idx = 5  # image within `sample_batch` to inspect; must be < batch size
img = sample_batch[0][sample_idx]
bbox = sample_batch[1][sample_idx]
ax.imshow(img)

draw_xyxy_box(ax, bbox[0], img.shape, ec='green', lw=3)

positive_mask = labels[sample_idx, :, 0].astype(bool)
for ci, pos_anchor in enumerate(all_anchors[positive_mask]):
    draw_xyxy_box(ax, utils.xywh_to_xyxy(pos_anchor), img.shape, ec=palette[(5 * ci) % len(palette)])

In [None]:
# Sanity-check both losses: regression on the true deltas vs. a noisy copy, and
# classification on the true labels vs. uniform random scores.
tf.random.set_seed(0)  # make the random "predictions" (and the losses shown) reproducible

class_loss = losses.ClassLoss()
reg_loss = losses.RegressionLoss()

noisy_deltas = deltas + tf.random.normal(deltas.shape, 0, 0.5, dtype=tf.float32)
# tf.random.uniform already returns float32; the extra tf.cast was a no-op.
random_scores = tf.random.uniform(labels.shape, 0, 1, dtype=tf.float32)

# Display both values. Previously only the last expression (the class loss) was
# shown, and the regression loss was computed and discarded.
reg_loss_value = reg_loss(deltas, noisy_deltas)
class_loss_value = class_loss(labels, random_scores)
reg_loss_value, class_loss_value