Below are links to resources related to this notebook:
* [W&B Project](https://wandb.ai/vincenttu/torch_vs_tf_talmo_lab?workspace=user-vincenttu)
* [GitHub](https://github.com/alckasoc/sleap_keypoint_tf_torch)



# Install SLEAP
Don't forget to set **Runtime** -> **Change runtime type...** -> **GPU** as the accelerator.

In [None]:
!pip install sleap -qqq
!pip install nvidia-ml-py3 -qqq

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.4/64.4 MB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m71.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m54.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.5/60.5 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.9/84.9 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m69.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.9/43.9 MB[0m [31m17.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os
import gc
import random
import time

import nvidia_smi

import numpy as np
import tensorflow as tf

import sleap

# Print the versions of SLEAP and its core dependencies to document the runtime environment.
sleap.versions()

SLEAP: 1.3.1
TensorFlow: 2.8.4
Numpy: 1.22.4
Python: 3.10.12
OS: Linux-5.15.107+-x86_64-with-glibc2.31


# Utils

In [None]:
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed applied to Python's `random` module, the
            `PYTHONHASHSEED` environment variable, NumPy and TensorFlow/Keras.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Apply the same seed to each library-level generator in turn.
    seeders = (
        random.seed,
        np.random.seed,
        tf.random.set_seed,
        tf.keras.utils.set_random_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def get_vram():
    """Report the memory usage of every NVIDIA GPU visible through NVML.

    Returns:
        A string with one line per device giving its index, name, percentage of
        free memory, and the total/free/used memory in GiB. Lines are joined
        with newlines; the string is empty when no device is found.
    """
    bytes_per_gib = 1024 ** 3
    reports = []

    nvidia_smi.nvmlInit()
    try:
        for i in range(nvidia_smi.nvmlDeviceGetCount()):
            handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
            info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
            # Collect one report per device instead of overwriting a single variable,
            # so multi-GPU machines are fully reported and zero GPUs does not crash.
            reports.append(
                "Device {}: {}, Memory : ({:.2f}% free): {} (total), {} (free), {} (used)".format(
                    i,
                    nvidia_smi.nvmlDeviceGetName(handle),
                    100 * info.free / info.total,
                    info.total / bytes_per_gib,
                    info.free / bytes_per_gib,
                    info.used / bytes_per_gib,
                )
            )
    finally:
        # Always release NVML, even if one of the queries above raised.
        nvidia_smi.nvmlShutdown()

    return "\n".join(reports)

def get_param_count(model):
    """Count the parameters of a Keras-style model.

    Args:
        model: Object exposing `trainable_weights` and `non_trainable_weights`,
            each an iterable of variables with a `get_shape()` method.

    Returns:
        A tuple `(trainable, non_trainable, total)` of parameter counts.
    """
    def _count(weights):
        """Sum the element counts of every variable in `weights`."""
        sizes = [np.prod(w.get_shape()) for w in weights]
        return np.sum(sizes)

    n_trainable = _count(model.trainable_weights)
    n_frozen = _count(model.non_trainable_weights)
    return n_trainable, n_frozen, n_trainable + n_frozen

In [None]:
# Fix all random seeds once, before any data pipeline or model is created.
seed = 42
seed_everything(seed)

# Download training data

In [None]:
# Download the training label package to labels.slp, then list the working directory.
!curl -L --output labels.slp https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/tracking_split2/train.pkg.slp
!ls -lah

# Download the validation label package to val_labels.slp, then list the directory again.
!curl -L --output val_labels.slp https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/tracking_split2/val.pkg.slp
!ls -lah

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  619M  100  619M    0     0  47.1M      0  0:00:13  0:00:13 --:--:-- 65.7M
total 620M
drwxr-xr-x 1 root root 4.0K Jun 26 17:21 .
drwxr-xr-x 1 root root 4.0K Jun 26 17:18 ..
drwxr-xr-x 4 root root 4.0K Jun 23 13:40 .config
-rw-r--r-- 1 root root 620M Jun 26 17:21 labels.slp
drwxr-xr-x 1 root root 4.0K Jun 23 13:41 sample_data
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 77.2M  100 77.2M    0     0  22.8M      0  0:00:03  0:00:03 --:--:-- 22.8M
total 697M
drwxr-xr-x 1 root root 4.0K Jun 26 17:21 .
drwxr-xr-x 1 root root 4.0K Jun 26 17:18 ..
drwxr-xr-x 4 root root 4.0K Jun 23 13:40 .config
-rw-r--r-- 1 root root 620M Jun 26 17:21 labels.slp
drwxr-xr-x 1 root root 4.0K Jun 23 13:41 sample_data
-rw-r--r-- 1 root root

# Load the training data

In [None]:
# A SLEAP Labels file (.slp) can bundle the images together with the labeled instances
# and other project metadata. Load it and keep only the user-labeled data in one step.
labels = sleap.load_file("labels.slp").with_user_labels_only()
labels.describe()

Skeleton: Skeleton(description=None, nodes=[head, thorax, abdomen, wingL, wingR, forelegL4, forelegR4, midlegL4, midlegR4, hindlegL4, hindlegR4, eyeL, eyeR], edges=[thorax->head, thorax->abdomen, thorax->wingL, thorax->wingR, thorax->forelegL4, thorax->forelegR4, thorax->midlegL4, thorax->midlegR4, thorax->hindlegL4, thorax->hindlegR4, head->eyeL, head->eyeR], symmetries=[forelegL4<->forelegR4, hindlegL4<->hindlegR4, wingL<->wingR, midlegL4<->midlegR4, eyeL<->eyeR])
Videos: ['labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp']
Frames (user/predicted): 1,600/0
Instances (user/predicted): 3,200/0
Tracks: [Track(spawned_on=0, na

In [None]:
# Load the validation labels the same way, again keeping only the user-labeled data.
val_labels = sleap.load_file("val_labels.slp").with_user_labels_only()
val_labels.describe()

Skeleton: Skeleton(description=None, nodes=[head, thorax, abdomen, wingL, wingR, forelegL4, forelegR4, midlegL4, midlegR4, hindlegL4, hindlegR4, eyeL, eyeR], edges=[thorax->head, thorax->abdomen, thorax->wingL, thorax->wingR, thorax->forelegL4, thorax->forelegR4, thorax->midlegL4, thorax->midlegR4, thorax->hindlegL4, thorax->hindlegR4, head->eyeL, head->eyeR], symmetries=[midlegL4<->midlegR4, forelegL4<->forelegR4, eyeL<->eyeR, hindlegL4<->hindlegR4, wingL<->wingR])
Videos: ['val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp'

In [None]:
# Labels are list-like containers whose elements are LabeledFrames
print(f"Number of labels: {len(labels)}")

# Take the first labeled frame; the bare expression below displays its repr.
labeled_frame = labels[0]
labeled_frame

Number of labels: 1600


LabeledFrame(video=HDF5Video('labels.slp'), frame_idx=166050, instances=2)

In [None]:
# LabeledFrames are containers for instances that were labeled in a single frame.
# Index 0 selects the first instance of the frame.
instance = labeled_frame[0]
instance

Instance(video=Video(filename=labels.slp, shape=(66, 1024, 1024, 1), backend=HDF5Video), frame_idx=166050, points=[head: (491.6, 187.7), thorax: (474.4, 224.8), abdomen: (459.9, 262.2), wingL: (448.3, 271.7), wingR: (452.1, 273.5), forelegL4: (478.5, 175.9), forelegR4: (499.9, 177.9), midlegL4: (440.6, 216.4), midlegR4: (510.1, 242.7), hindlegL4: (437.2, 234.3), hindlegR4: (490.9, 266.7), eyeL: (477.5, 193.2), eyeR: (498.4, 201.2)], track=Track(spawned_on=0, name='female'))

In [None]:
# They can be converted to numpy arrays where each row corresponds to the coordinates
# of a different body part:
# NOTE(review): a row can come back as NaN even though the instance repr lists
# coordinates for that node — presumably the point is flagged as not visible; confirm.
pts = instance.numpy()
pts

rec.array([[491.58118169, 187.72078779],
           [474.3603939 , 224.80196948],
           [459.90098474, 262.16236338],
           [448.26137864, 271.72078779],
           [452.08118169, 273.54059084],
           [478.5       , 175.90098474],
           [499.94157558, 177.90098474],
           [440.58118169, 216.3603939 ],
           [510.12177253, 242.72078779],
           [         nan,          nan],
           [490.90098474, 266.72078779],
           [477.54059084, 193.16236338],
           [498.40098474, 201.18019695]],
          dtype=float64)

# Setup training data generation

In [None]:
# Initialize a pipeline from the labels.
# NOTE(review): `labels` was already filtered with `with_user_labels_only()` when it was
# loaded, so the second call here looks redundant — confirm before removing.
p = labels.with_user_labels_only().to_pipeline()

# This pipeline will output dictionaries with tensors containing frame data:
p.describe()

         image: type=EagerTensor, shape=(1024, 1024, 1), dtype=tf.uint8, device=/job:localhost/replica:0/task:0/device:CPU:0
raw_image_size: type=EagerTensor, shape=(3,), dtype=tf.int32, device=/job:localhost/replica:0/task:0/device:CPU:0
   example_ind: type=EagerTensor, shape=(), dtype=tf.int64, device=/job:localhost/replica:0/task:0/device:CPU:0
     video_ind: type=EagerTensor, shape=(), dtype=tf.int32, device=/job:localhost/replica:0/task:0/device:CPU:0
     frame_ind: type=EagerTensor, shape=(), dtype=tf.int64, device=/job:localhost/replica:0/task:0/device:CPU:0
         scale: type=EagerTensor, shape=(2,), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
     instances: type=EagerTensor, shape=(2, 13, 2), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
 skeleton_inds: type=EagerTensor, shape=(2,), dtype=tf.int32, device=/job:localhost/replica:0/task:0/device:CPU:0
    track_inds: type=EagerTensor, shape=(2,), dtype=tf.int32, device=/job

In [None]:
# Assemble the training pipeline for the centered-instance model: random rotation,
# normalization, thorax-anchored centroids, 160x160 crops, confidence maps and batching.
rotation_config = sleap.pipelines.AugmentationConfig(
    rotate=True, rotation_min_angle=-180, rotation_max_angle=180
)
train_transformers = [
    sleap.pipelines.ImgaugAugmenter.from_config(rotation_config),
    sleap.pipelines.Normalizer(),
    sleap.pipelines.InstanceCentroidFinder(
        center_on_anchor_part=True,
        anchor_part_names="thorax",
        skeletons=labels.skeletons,
    ),
    sleap.pipelines.InstanceCropper(crop_width=160, crop_height=160),
    sleap.pipelines.InstanceConfidenceMapGenerator(sigma=1.5, output_stride=2),
    sleap.pipelines.Batcher(batch_size=4, drop_remainder=True),
]

# Start from the raw label provider and append each transformer in order.
p = labels.with_user_labels_only().to_pipeline()
for transformer in train_transformers:
    p += transformer
p.describe()

          instance_image: type=EagerTensor, shape=(4, 160, 160, 1), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
                    bbox: type=EagerTensor, shape=(4, 4), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
         center_instance: type=EagerTensor, shape=(4, 13, 2), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
     center_instance_ind: type=EagerTensor, shape=(4, 1), dtype=tf.int32, device=/job:localhost/replica:0/task:0/device:CPU:0
               track_ind: type=EagerTensor, shape=(4, 1), dtype=tf.int32, device=/job:localhost/replica:0/task:0/device:CPU:0
           all_instances: type=EagerTensor, shape=(4, 2, 13, 2), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
                centroid: type=EagerTensor, shape=(4, 2), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
       full_image_height: type=EagerTensor, shape=(4, 1), dtype=tf.int32, device=/job:l

In [None]:
# Let's build our validation pipeline from the held-out validation labels (`val_labels`),
# not from the training labels, so validation loss and metrics measure generalization.
# Note, we didn't include the augmentations.
val_p = val_labels.with_user_labels_only().to_pipeline()
val_p += sleap.pipelines.Normalizer()
val_p += sleap.pipelines.InstanceCentroidFinder(center_on_anchor_part=True, anchor_part_names="thorax", skeletons=val_labels.skeletons)
val_p += sleap.pipelines.InstanceCropper(crop_width=160, crop_height=160)
val_p += sleap.pipelines.InstanceConfidenceMapGenerator(sigma=1.5, output_stride=2)
val_p += sleap.pipelines.Batcher(batch_size=4, drop_remainder=True)
val_p.describe()

          instance_image: type=EagerTensor, shape=(4, 160, 160, 1), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
                    bbox: type=EagerTensor, shape=(4, 4), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
         center_instance: type=EagerTensor, shape=(4, 13, 2), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
     center_instance_ind: type=EagerTensor, shape=(4, 1), dtype=tf.int32, device=/job:localhost/replica:0/task:0/device:CPU:0
               track_ind: type=EagerTensor, shape=(4, 1), dtype=tf.int32, device=/job:localhost/replica:0/task:0/device:CPU:0
           all_instances: type=EagerTensor, shape=(4, 2, 13, 2), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
                centroid: type=EagerTensor, shape=(4, 2), dtype=tf.float32, device=/job:localhost/replica:0/task:0/device:CPU:0
       full_image_height: type=EagerTensor, shape=(4, 1), dtype=tf.int32, device=/job:l

# Setting up a neural network model

In [None]:
# Instantiate the backbone builder.
unet = sleap.nn.architectures.unet.UNet(filters=32, filters_rate=1.5, down_blocks=4, up_blocks=3, up_interpolate=True)

# Create the input layer; (160, 160, 1) matches the per-sample shape of `instance_image`
# produced by the 160x160 instance crops in the pipelines above.
x_in = tf.keras.layers.Input((160, 160, 1))

# Create the feature extractor backbone.
x_features, x_intermediate = unet.make_backbone(x_in)

# Do a 1x1 conv with linear activation to remap activations to the number of channels in
# the confidence maps: 13, one per skeleton node.
x_confmaps = tf.keras.layers.Conv2D(filters=13, kernel_size=1, strides=1, padding="same")(x_features)

# Create a Model that links the whole graph
model = tf.keras.Model(x_in, x_confmaps)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 160, 160, 1  0           []                               
                                )]                                                                
                                                                                                  
 stack0_enc0_conv0 (Conv2D)     (None, 160, 160, 32  320         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 stack0_enc0_act0_relu (Activat  (None, 160, 160, 32  0          ['stack0_enc0_conv0[0][0]']      
 ion)                           )                                                             

# Train the model

In [None]:
# Set up the optimizer and loss function: Adam with a 1e-4 learning rate and the mean
# squared error between the target and predicted confidence maps.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
loss_fn = tf.keras.losses.MeanSquaredError()

# Define a "training step" function. This does the forward/backward passes and applies
# the gradients to update the model weights.
@tf.function
def train_step(ex, model, optimizer, loss_fn):
    """Run one optimization step on a single batch.

    Args:
        ex: Batch dictionary providing the "instance_image" inputs and the
            "instance_confidence_maps" targets.
        model: Keras model mapping images to predicted confidence maps.
        optimizer: Keras optimizer used to apply the gradients.
        loss_fn: Callable `loss_fn(y_true, y_pred)` returning the batch loss.

    Returns:
        The loss of this batch, computed before the weight update.
    """
    with tf.GradientTape() as tape:
        # training=True so that any layer with train-specific behavior (e.g. dropout,
        # batch normalization) runs in training mode during the forward pass.
        predicted_confmaps = model(ex["instance_image"], training=True)
        loss = loss_fn(ex["instance_confidence_maps"], predicted_confmaps)

    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    return loss

@tf.function
def val_step(ex, model, loss_fn):
    """Compute the loss of a single batch without updating the model.

    Args:
        ex: Batch dictionary providing the "instance_image" inputs and the
            "instance_confidence_maps" targets.
        model: Keras model mapping images to predicted confidence maps.
        loss_fn: Callable `loss_fn(y_true, y_pred)` returning the batch loss.

    Returns:
        The loss of this batch.
    """
    # training=False makes inference mode explicit for the validation forward pass.
    predicted_confmaps = model(ex["instance_image"], training=False)
    loss = loss_fn(ex["instance_confidence_maps"], predicted_confmaps)

    return loss

In [None]:
# Rebuild the model, optimizer and loss so this cell starts training from freshly
# initialized weights each time it is run.
unet = sleap.nn.architectures.unet.UNet(filters=32, filters_rate=1.5, down_blocks=4, up_blocks=3, up_interpolate=True)
x_in = tf.keras.layers.Input((160, 160, 1))
x_features, x_intermediate = unet.make_backbone(x_in)
x_confmaps = tf.keras.layers.Conv2D(filters=13, kernel_size=1, strides=1, padding="same")(x_features)
model = tf.keras.Model(x_in, x_confmaps)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
loss_fn = tf.keras.losses.MeanSquaredError()

# Training loop, go!
epochs = 3
for epoch in range(epochs):
    # --- Training pass: one sweep over the (augmented) training pipeline. ---
    start_time = time.time()
    train_loss = 0
    num_train_batches = 0
    for step, ex in enumerate(p.make_dataset()):
        loss = train_step(ex, model, optimizer, loss_fn)

        if step % 100 == 0:
            print(f"Epoch {epoch:03d} | Step {step:03d} | loss = {loss:.5f}")

        train_loss += loss
        num_train_batches += 1

    # Average over the batches actually seen; max(..., 1) avoids dividing by zero
    # (and relying on an undefined `step`) if the dataset yields nothing.
    train_loss /= max(num_train_batches, 1)
    train_time = time.time() - start_time
    print(f"TRAIN: --- {train_time:.2f} seconds | mean loss = {float(train_loss):.5f} ---")

    # --- Validation pass: loss only, no weight updates. ---
    start_time = time.time()
    val_loss = 0
    num_val_batches = 0
    for step, ex in enumerate(val_p.make_dataset()):
        loss = val_step(ex, model, loss_fn)
        val_loss += loss
        num_val_batches += 1

    val_loss /= max(num_val_batches, 1)
    val_time = time.time() - start_time
    print(f"VAL: --- {val_time:.2f} seconds | mean loss = {float(val_loss):.5f} ---")

Epoch 000 | Step 000 | loss = 0.00120
Epoch 000 | Step 100 | loss = 0.00094
Epoch 000 | Step 200 | loss = 0.00104
Epoch 000 | Step 300 | loss = 0.00103
Epoch 000 | Step 400 | loss = 0.00100
Epoch 000 | Step 500 | loss = 0.00098
Epoch 000 | Step 600 | loss = 0.00097
Epoch 000 | Step 700 | loss = 0.00092
TRAIN: --- 69.75044798851013s seconds ---
VAL: --- 42.28296756744385s seconds ---
Epoch 001 | Step 000 | loss = 0.00084
Epoch 001 | Step 100 | loss = 0.00076
Epoch 001 | Step 200 | loss = 0.00070
Epoch 001 | Step 300 | loss = 0.00066
Epoch 001 | Step 400 | loss = 0.00068
Epoch 001 | Step 500 | loss = 0.00066
Epoch 001 | Step 600 | loss = 0.00082
Epoch 001 | Step 700 | loss = 0.00064
TRAIN: --- 61.75702166557312s seconds ---
VAL: --- 37.68329906463623s seconds ---
Epoch 002 | Step 000 | loss = 0.00060
Epoch 002 | Step 100 | loss = 0.00065
Epoch 002 | Step 200 | loss = 0.00047
Epoch 002 | Step 300 | loss = 0.00047
Epoch 002 | Step 400 | loss = 0.00059
Epoch 002 | Step 500 | loss = 0.00052


# Inference & Evaluation

## Utility Functions

In [None]:
def make_grid_vectors(
    image_height: int, image_width: int, output_stride: int = 1):
    """Build the x- and y-sampling vectors of a strided grid over an image.

    Args:
        image_height: Height of the image in pixels.
        image_width: Width of the image in pixels.
        output_stride: Step between consecutive grid samples.

    Returns:
        A tuple `(xv, yv)` of 1-D arrays holding the sampled x-coordinates
        (`0, stride, ...` below `image_width`) and y-coordinates (below
        `image_height`).
    """
    xv, yv = (
        np.arange(0, extent, step=output_stride)
        for extent in (image_width, image_height)
    )
    return xv, yv

In [None]:
# 1. find instance peaks
def find_peaks(cms, xv, yv):
    """Estimate keypoint locations from confidence maps via integral regression.

    Each peak is the confidence-weighted mean of the grid coordinates, i.e. the
    center of mass of the corresponding confidence map.

    Args:
        cms: A batch of confidence maps of shape (batch_size, height, width, n_points).
        xv: X-sampling vector of shape (grid_width,).
        yv: Y-sampling vector of shape (grid_height,).

    Returns:
        A set of estimated peaks of shape (batch_size, n_points, 2), ordered (x, y).

    Notes:
        This function can also accept confidence maps of shape (height, width, n_points)
        and returns peaks as (n_points, 2).
    """
    squeeze_batch = cms.ndim == 3
    if squeeze_batch:
        # Temporarily add a batch axis so the same code path handles both cases.
        cms = np.expand_dims(cms, axis=0)

    n_batch, n_points = cms.shape[0], cms.shape[-1]

    def _integrate(maps):
        """Sum `maps` over the height and width axes, giving (batch_size, n_points)."""
        return maps.reshape(n_batch, -1, n_points).sum(axis=1)

    # Total mass of each map, used to normalize the weighted coordinate sums.
    total = _integrate(cms)

    # Weight the grid coordinates by the confidence values and normalize.
    x = _integrate(cms * xv.reshape(1, 1, -1, 1)) / total
    y = _integrate(cms * yv.reshape(1, -1, 1, 1)) / total

    peaks = np.stack([x, y], axis=-1)
    return peaks.squeeze(axis=0) if squeeze_batch else peaks

def clip_peaks(peaks, min_coord=0, max_coord=160):
    """Invalidate peaks that fall outside the image bounds.

    Args:
        peaks: Float array of peak coordinates whose last axis holds the
            coordinates of one point, e.g. (batch_size, n_points, 2) or
            (n_points, 2). A NumPy array is modified in place.
        min_coord: Smallest valid coordinate value (inclusive).
        max_coord: Largest valid coordinate value (inclusive). Defaults to 160,
            the crop size used throughout this notebook.

    Returns:
        The peaks array, with every point that has at least one coordinate
        outside `[min_coord, max_coord]` replaced by NaNs. Points that are
        already NaN stay NaN.
    """
    peaks = np.asarray(peaks)

    # A point is kept only if all of its coordinates are within bounds; NaN
    # coordinates compare as False and are therefore treated as out of bounds.
    in_bounds = np.all((peaks >= min_coord) & (peaks <= max_coord), axis=-1)
    peaks[~in_bounds] = np.nan

    return peaks

In [None]:
def compute_instance_area(points: np.ndarray) -> np.ndarray:
    """Compute the bounding-box area spanned by each set of keypoints.

    Args:
        points: Coordinates of shape (n_nodes, n_dims) for one instance or
            (n_instances, n_nodes, n_dims) for several. NaN points are ignored
            when finding the bounding box.

    Returns:
        An array with one area per instance (length 1 for a single instance).
    """
    # Promote a single instance to a batch of one.
    pts = np.expand_dims(points, axis=0) if points.ndim == 2 else points

    lower = np.nanmin(pts, axis=-2)
    upper = np.nanmax(pts, axis=-2)
    extent = upper - lower

    return extent.prod(axis=-1)



def compute_oks(
    points_gt: np.ndarray,
    points_pr: np.ndarray,
    scale = None,
    stddev: float = 0.025,
) -> np.ndarray:
    """Compute the object keypoints similarity between sets of points.

    Args:
        points_gt: Ground truth instances of shape (n_gt, n_nodes, n_ed),
            where n_nodes is the number of body parts/keypoint types, and n_ed
            is the number of Euclidean dimensions (typically 2 or 3). Keypoints
            that are missing/not visible should be represented as NaNs. A
            single instance of shape (n_nodes, n_ed) is also accepted.
        points_pr: Predicted instances of shape (n_pr, n_nodes, n_ed), or a
            single instance of shape (n_nodes, n_ed).
        scale: Size scaling factor to use when weighing the scores, typically
            the area of the bounding box of the instance (in pixels). This
            should be of the length n_gt. If a scalar is provided, the same
            number is used for all ground truth instances. If set to None, the
            bounding box area of the ground truth instances will be calculated.
        stddev: The standard deviation associated with the spread in the
            localization accuracy of each node/keypoint type. This should be of
            the length n_nodes. "Easier" keypoint types will have lower values
            to reflect the smaller spread expected in localizing it. If a scalar
            is provided, the same value is used for every node.

    Returns:
        The object keypoints similarity between every pair of ground truth and
        predicted instance, a numpy array of shape (n_gt, n_pr) in the range
        of [0, 1.0], with 1.0 denoting a perfect match. The entries of a ground
        truth instance with no visible keypoints are NaN (0 / 0).

    Notes:
        It's important to set the stddev appropriately when accounting for the
        difficulty of each keypoint type. For reference, the median value for
        all keypoint types in COCO is 0.072. The "easiest" keypoint is the left
        eye, with stddev of 0.025, since it is easy to precisely locate the
        eyes when labeling. The "hardest" keypoint is the left hip, with stddev
        of 0.107, since it's hard to locate the left hip bone without external
        anatomical features and since it is often occluded by clothing.

        The implementation here is based off of the descriptions in:
        Ronch & Perona. "Benchmarking and Error Diagnosis in Multi-Instance Pose
        Estimation." ICCV (2017).
    """
    points_gt = np.asarray(points_gt)
    points_pr = np.asarray(points_pr)

    # Promote single instances to batches of one.
    if points_gt.ndim == 2:
        points_gt = np.expand_dims(points_gt, axis=0)
    if points_pr.ndim == 2:
        points_pr = np.expand_dims(points_pr, axis=0)

    if scale is None:
        scale = compute_instance_area(points_gt)

    n_gt, n_nodes, n_ed = points_gt.shape  # n_ed = 2 or 3 (euclidean dimensions)
    n_pr = points_pr.shape[0]

    # If scalar scale was provided, use the same for each ground truth instance.
    if np.isscalar(scale):
        scale = np.full(n_gt, scale)

    # If scalar standard deviation was provided, use the same for each node.
    if np.isscalar(stddev):
        stddev = np.full(n_nodes, stddev)

    # Compute displacement between each pair.
    displacement = np.reshape(points_gt, (n_gt, 1, n_nodes, n_ed)) - np.reshape(
        points_pr, (1, n_pr, n_nodes, n_ed)
    )
    assert displacement.shape == (n_gt, n_pr, n_nodes, n_ed)

    # Convert to pairwise squared Euclidean distances.
    distance = (displacement ** 2).sum(axis=-1)  # (n_gt, n_pr, n_nodes)
    assert distance.shape == (n_gt, n_pr, n_nodes)

    # Compute the normalization factor per keypoint.
    spread_factor = (2 * stddev) ** 2
    scale_factor = 2 * (scale + np.spacing(1))
    normalization_factor = np.reshape(spread_factor, (1, 1, n_nodes)) * np.reshape(
        scale_factor, (n_gt, 1, 1)
    )
    assert normalization_factor.shape == (n_gt, 1, n_nodes)

    # Since a "miss" is considered as KS < 0.5, we'll set the
    # distances for predicted points that are missing to inf.
    missing_pr = np.any(np.isnan(points_pr), axis=-1)  # (n_pr, n_nodes)
    assert missing_pr.shape == (n_pr, n_nodes)
    distance[:, missing_pr] = np.inf

    # Compute the keypoint similarity as per the top of Eq. 1.
    ks = np.exp(-(distance / normalization_factor))  # (n_gt, n_pr, n_nodes)
    assert ks.shape == (n_gt, n_pr, n_nodes)

    # Set the KS for missing ground truth points to 0.
    # This is equivalent to the visibility delta function of the bottom
    # of Eq. 1.
    missing_gt = np.any(np.isnan(points_gt), axis=-1)  # (n_gt, n_nodes)
    assert missing_gt.shape == (n_gt, n_nodes)
    # np.where broadcasts the (n_gt, 1, n_nodes) mask across all predictions; using
    # it directly as a boolean index only works when n_pr == 1 and raises otherwise.
    ks = np.where(np.expand_dims(missing_gt, axis=1), 0.0, ks)
    assert ks.shape == (n_gt, n_pr, n_nodes)

    # Compute the OKS.
    n_visible_gt = np.sum(
        (~missing_gt).astype("float64"), axis=-1, keepdims=True
    )  # (n_gt, 1)
    oks = np.sum(ks, axis=-1) / n_visible_gt
    assert oks.shape == (n_gt, n_pr)

    return oks

In [None]:
def get_match_scores(all_y_preds, all_y_gt, stddev=0.25):
    """Score each prediction against its ground truth instance with OKS.

    Predictions and ground truth are paired by position: the i-th prediction is
    compared only with the i-th ground truth instance.

    Args:
        all_y_preds: Sequence of predicted instances, one entry per instance.
        all_y_gt: Sequence of ground truth instances, aligned with `all_y_preds`.
        stddev: Keypoint spread passed through to `compute_oks`.
            NOTE(review): this default (0.25) is ten times `compute_oks`'s own
            default (0.025) — confirm that is intended.

    Returns:
        A list with one OKS value per (ground truth, prediction) pair.
    """
    return [
        compute_oks(gt, pred, stddev=stddev)[0][0]
        for gt, pred in zip(all_y_gt, all_y_preds)
    ]

def evaluate(
    match_scores,
    num_positive_pairs,
    num_false_negatives,
    recall_thresholds,
    match_score_thresholds,
):
    """Compute VOC-style precision/recall metrics from per-instance match scores.

    Args:
        match_scores: Sequence (list or array) with one match score, e.g. OKS,
            per matched (ground truth, prediction) pair. Scores are processed in
            the order given; they are not re-sorted by detection confidence.
        num_positive_pairs: Number of ground truth instances that were matched
            to a prediction.
        num_false_negatives: Number of ground truth instances with no match.
        recall_thresholds: Recall levels at which precision is sampled.
        match_score_thresholds: Match score thresholds above (or at) which a
            pair counts as a true positive.

    Returns:
        A tuple `(precisions, AP, mAP, mAR)`:
            precisions: Array of shape (n_match_score_thresholds,
                n_recall_thresholds) with the precision at each recall level.
            AP: Average precision per match score threshold.
            mAP: Mean of `precisions` over all thresholds.
            mAR: Mean over match score thresholds of the final recall.
    """
    # Work on an array so element-wise comparisons are valid even when the scores
    # arrive as a plain list and the thresholds are plain Python floats.
    match_scores = np.asarray(match_scores, dtype="float64")
    recall_thresholds = np.asarray(recall_thresholds)

    precisions = []
    recalls = []

    npig = num_positive_pairs + num_false_negatives  # total number of GT instances

    for match_score_threshold in match_score_thresholds:

        tp = np.cumsum(match_scores >= match_score_threshold)
        fp = np.cumsum(match_scores < match_score_threshold)

        rc = tp / npig
        pr = tp / (fp + tp + np.spacing(1))

        # Best recall at this threshold; 0 when there are no scores at all.
        recall = rc[-1] if rc.size else 0.0

        # Make precision non-increasing: each entry becomes the maximum of itself
        # and every later entry (reverse cumulative maximum).
        if pr.size:
            pr = np.maximum.accumulate(pr[::-1])[::-1]

        # Find best precision at each recall threshold. Recall levels that are
        # never reached keep a precision of 0.
        rc_inds = np.searchsorted(rc, recall_thresholds, side="left")
        precision = np.zeros(rc_inds.shape)
        is_valid_rc_ind = rc_inds < len(pr)
        precision[is_valid_rc_ind] = pr[rc_inds[is_valid_rc_ind]]

        precisions.append(precision)
        recalls.append(recall)

    precisions = np.array(precisions)
    recalls = np.array(recalls)

    # AP = average precision over the fixed set of recall thresholds.
    AP = precisions.mean(axis=1)

    mAP = precisions.mean()  # mAP = mean over all match score thresholds
    mAR = recalls.mean()  # mAR = mean over all match score thresholds

    return precisions, AP, mAP, mAR

## Evaluation Code

In [None]:
# Get all y_preds and y_gt.
# The grid vectors use output_stride=2 to match the stride of the confidence maps
# generated by the pipelines, so peak coordinates are expressed in crop pixels.
xv, yv = make_grid_vectors(
    image_height=160,
    image_width=160,
    output_stride=2
)

# Run the model batch by batch over the validation pipeline. `step` is not used.
all_y_preds, all_y_gt = [], []
for step, ex in enumerate(val_p.make_dataset()):
  # Predicted peaks: integral regression on the predicted maps, then out-of-crop
  # peaks are replaced with NaN.
  y_preds = model(ex["instance_image"])
  y_preds = find_peaks(y_preds.numpy(), xv, yv)
  y_preds = clip_peaks(y_preds)
  # Ground truth peaks are recovered from the target confidence maps with the same
  # integral regression, rather than read from the labeled coordinates.
  y_gt = find_peaks(ex["instance_confidence_maps"].numpy(), xv, yv)

  all_y_preds.append(y_preds)
  all_y_gt.append(y_gt)

# Merge the per-batch arrays along the batch axis: one entry per instance.
all_y_preds = np.concatenate(all_y_preds, axis=0)
all_y_gt = np.concatenate(all_y_gt, axis=0)

In [None]:

# Define thresholds.
match_score_thresholds = np.linspace(0.5, 0.95, 10)
recall_thresholds = np.linspace(0, 1, 101)

# Get evaluation metrics.
match_scores = get_match_scores(all_y_preds, all_y_gt)

# Every ground truth instance is paired with exactly one prediction here, so the number
# of positive pairs is the number of scores and there are no false negatives. Deriving
# the count from the data keeps recall correct whatever the size of the evaluated set.
num_positive_pairs = len(match_scores)
precisions, AP, mAP, mAR = evaluate(match_scores, num_positive_pairs, 0, recall_thresholds, match_score_thresholds)

In [None]:
# Display the mean AP and the per-threshold AP (one value per match score threshold).
mAP, AP

(0.8722750563003294,
 array([0.99721919, 0.99610686, 0.99242706, 0.98310167, 0.97586516,
        0.96117723, 0.93709047, 0.87867787, 0.73393265, 0.2671524 ]))

In [None]:
# Scratch check: list the ten match score thresholds used for the evaluation above,
# so each AP value can be read against its threshold.
import numpy as np

np.linspace(0.5, 0.95, 10)

array([0.5 , 0.55, 0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 0.9 , 0.95])