In [None]:
# Environment: Apple M1 Max, macOS 15.0, Python 3.10
!pip install imageio
!pip install ipython
!pip install matplotlib
!pip install opencv-python
!pip install tensorflow
!pip install tensorflow-hub
!pip install tensorflow-metal

In [None]:
import os
import resource
import sys
import time

import cv2
# Some modules to display an animation using imageio
import numpy as np
import tensorflow as tf
from PIL import Image
# Import matplotlib libraries
from matplotlib import pyplot as plt
from tensorflow.python.framework.ops import EagerTensor

In [None]:
# Report how many GPU devices TensorFlow can see in this environment.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# Load the PoseNet TFLite model and allocate buffers for its input/output tensors.
# NOTE(review): recent TensorFlow releases deprecate tf.lite.Interpreter in favor of
# LiteRT (ai_edge_litert) — check against the installed TF version.
model_path = "./posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite"
interpreter = tf.lite.Interpreter(model_path)
interpreter.allocate_tensors()

In [None]:
# Metadata dicts (e.g. 'index', 'shape') describing the model's input and output tensors.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

In [None]:
# Show the model's input tensor metadata (a list of dicts) as the cell output.
input_details

In [None]:
# Show the model's output tensor metadata (a list of dicts) as the cell output.
output_details

In [None]:
# Model input side length, taken from index 1 of the input shape (assumed
# [batch, height, width, channels]; the model file name suggests a square
# 257x257 input — TODO confirm).
input_shape = input_details[0]['shape']
input_size = input_shape[1]

In [None]:
#@title Helper functions for visualization

# Dictionary that maps from joint names to keypoint indices.
# Each joint's index is simply its position in the tuple below.
KEYPOINT_DICT = {
    joint_name: joint_index
    for joint_index, joint_name in enumerate((
        'nose',
        'left_eye',
        'right_eye',
        'left_ear',
        'right_ear',
        'left_shoulder',
        'right_shoulder',
        'left_elbow',
        'right_elbow',
        'left_wrist',
        'right_wrist',
        'left_hip',
        'right_hip',
        'left_knee',
        'right_knee',
        'left_ankle',
        'right_ankle',
    ))
}

# Maps bones (pairs of keypoint indices) to a matplotlib-style color code:
# 'm' (magenta) for left-side bones, 'c' (cyan) for right-side bones and
# 'y' (yellow) for the shoulder/hip cross-bars. The codes are turned into RGB
# through `color_map`; insertion order is the order in which bones are drawn.
KEYPOINT_EDGE_INDS_TO_COLOR = dict([
    ((0, 1), 'm'),
    ((0, 2), 'c'),
    ((1, 3), 'm'),
    ((2, 4), 'c'),
    ((0, 5), 'm'),
    ((0, 6), 'c'),
    ((5, 7), 'm'),
    ((7, 9), 'm'),
    ((6, 8), 'c'),
    ((8, 10), 'c'),
    ((5, 6), 'y'),
    ((5, 11), 'm'),
    ((6, 12), 'c'),
    ((11, 12), 'y'),
    ((11, 13), 'm'),
    ((13, 15), 'm'),
    ((12, 14), 'c'),
    ((14, 16), 'c'),
])

# RGB triples for the color codes used in KEYPOINT_EDGE_INDS_TO_COLOR
# (the drawing happens on the RGB image, before conversion to BGR for saving).
color_map = dict(
    c=(0, 191, 191),
    m=(191, 0, 191),
    y=(191, 191, 0),
)


def save_image_with_prediction(idx,
                               raw_image,
                               keypoints_with_scores,
                               keypoint_threshold=0.11,
                               output_dir="./output"):
    """Draw the predicted skeleton on a copy of `raw_image` and save it as a PNG.

    Args:
        idx: Integer used for the file name (`{idx:08d}.png`).
        raw_image: RGB image of shape (height, width, 3) — a tensor exposing
            `.numpy()` or anything `np.array` accepts. It is copied, never modified.
        keypoints_with_scores: Array indexed as [0, 0, k, (y, x, score)], where
            y/x are relative to the square, padded model input.
        keypoint_threshold: A keypoint (and any bone touching it) is drawn only
            when its score is strictly greater than this value.
        output_dir: Directory the PNG is written to; created if missing.

    Returns:
        Path of the written PNG (`os.path.join(output_dir, f"{idx:08d}.png")`).

    Raises:
        OSError: if OpenCV fails to write the file (cv2.imwrite only returns False).
    """
    # Private, writable copies: OpenCV draws in place.
    image = np.array(raw_image.numpy() if hasattr(raw_image, "numpy") else raw_image)
    keypoints = np.array(keypoints_with_scores, dtype=np.float64)

    raw_height, raw_width = image.shape[:2]
    longest_side = max(raw_height, raw_width)

    # The padded square model input corresponds to `longest_side` original pixels.
    keypoints[..., :2] *= longest_side

    # Remove the symmetric padding that was added along the shorter side.
    if raw_height > raw_width:
        keypoints[..., 1] -= (longest_side - raw_width) // 2
    elif raw_height < raw_width:
        keypoints[..., 0] -= (longest_side - raw_height) // 2

    # Round to the nearest pixel (not truncate); OpenCV wants plain Python ints.
    kpts_y = np.rint(keypoints[0, 0, :, 0]).astype(int).tolist()
    kpts_x = np.rint(keypoints[0, 0, :, 1]).astype(int).tolist()
    kpts_scores = keypoints[0, 0, :, 2]

    # Bones: draw an edge only when both endpoints are confident enough.
    line_thickness = max(longest_side // 300, 1)
    for (start, end), color in KEYPOINT_EDGE_INDS_TO_COLOR.items():
        if kpts_scores[start] > keypoint_threshold and kpts_scores[end] > keypoint_threshold:
            cv2.line(image, (kpts_x[start], kpts_y[start]), (kpts_x[end], kpts_y[end]),
                     color_map[color], thickness=line_thickness)

    # Joints: filled circles on top of the bones.
    radius = max(longest_side // 150, 2)
    for x, y, score in zip(kpts_x, kpts_y, kpts_scores):
        if score > keypoint_threshold:
            cv2.circle(image, (x, y), radius=radius, color=(255, 20, 147), thickness=-1)

    # The image is RGB; cv2.imwrite expects BGR.
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{idx:08d}.png")
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(output_path, bgr_image):
        raise OSError(f"cv2.imwrite failed to write {output_path}")

    return output_path

### Load Dataset

In [None]:
# Collect every JPEG/PNG under the dataset directory as an RGB uint8 tensor.
dataset_root_dir = "./dataset"
image_decoders = {
    '.jpg': tf.image.decode_jpeg,
    '.jpeg': tf.image.decode_jpeg,
    '.png': tf.image.decode_png,
}

raw_images = []
for dirpath, dirnames, filenames in os.walk(dataset_root_dir):
    # Sorting in place also controls the order in which os.walk descends, so the
    # image order (and the indices later used for output file names) is deterministic.
    dirnames.sort()
    filenames.sort()

    for filename in filenames:
        file_extension = os.path.splitext(filename)[1].lower()
        decode = image_decoders.get(file_extension)
        if decode is None:
            # Skip non-image files (e.g. .DS_Store) without reading them from disk.
            continue

        filepath = os.path.join(dirpath, filename)
        # channels=3 converts grayscale and RGBA images to RGB, so every tensor
        # has exactly 3 channels (slicing [..., :3] left grayscale images at 1).
        image = decode(tf.io.read_file(filepath), channels=3)
        raw_images.append(image)

### Run Inference

In [None]:
def run_inference(image: EagerTensor):
    """Run the PoseNet interpreter on one RGB image and return its raw outputs.

    The image is letterboxed (resized with padding, aspect ratio preserved) to the
    model's input height/width, normalized, and cast to the model's input dtype.

    Args:
        image: RGB image tensor of shape (height, width, 3) with values in 0..255.

    Returns:
        Tuple (heatmaps, offsets, forward_displacements, backward_displacements)
        of NumPy arrays read from the interpreter's output tensors 0..3.
    """
    input_detail = input_details[0]
    _, input_height, input_width, _ = input_detail['shape']

    # Resize and pad the image to keep the aspect ratio and fit the expected size.
    input_image = tf.expand_dims(image, axis=0)
    input_image = tf.image.resize_with_pad(input_image, input_height, input_width).numpy()

    if input_detail['dtype'] == np.float32:
        # The float MobileNet PoseNet model expects pixels scaled to [-1, 1]
        # ((x - 127.5) / 127.5), not [0, 1].
        input_image = (input_image - 127.5) / 127.5
    # set_tensor rejects arrays whose dtype differs from the model input.
    input_image = input_image.astype(input_detail['dtype'])

    # Run model inference.
    interpreter.set_tensor(input_detail['index'], input_image)
    interpreter.invoke()

    # Shapes below are for the 257x257 model (presumably output stride 32 -> 9x9 grid).
    heatmaps = interpreter.get_tensor(output_details[0]['index'])  # (1, 9, 9, 17)
    offsets = interpreter.get_tensor(output_details[1]['index'])  # (1, 9, 9, 34)
    forward_displacements = interpreter.get_tensor(output_details[2]['index'])  # (1, 9, 9, 32)
    backward_displacements = interpreter.get_tensor(output_details[3]['index'])  # (1, 9, 9, 32)

    return heatmaps, offsets, forward_displacements, backward_displacements

In [None]:
def raw_output_to_coords(heatmaps, offsets, forward_displacements, backward_displacements):
    """Decode single-pose keypoints from PoseNet's raw outputs.

    For every keypoint, the heatmap cell with the highest logit is located and
    refined with the matching offset vector. Only `heatmaps` and `offsets` are
    used; the displacement arrays are accepted so that the tuple returned by
    `run_inference` can be unpacked straight into this function.

    Args:
        heatmaps: Array of shape (1, H, W, K) with per-keypoint logits.
        offsets: Array of shape (1, H, W, 2K); channel k is the y offset and
            channel K + k the x offset of keypoint k, in model-input pixels.
        forward_displacements: Unused.
        backward_displacements: Unused.

    Returns:
        Array of shape (K, 3) with rows (y, x, score). y/x are relative to the
        model input (0..1 spans the padded square); score = sigmoid(logit).
        Relies on the module-level `input_size`.
    """
    # Reference: https://raw.githubusercontent.com/joonb14/TFLitePoseEstimation/refs/heads/main/pose%20estimation.ipynb
    _, height, width, num_keypoints = heatmaps.shape
    keypoint_ids = np.arange(num_keypoints)

    # Grid cell of the maximum for every keypoint at once. Flattening each (H, W)
    # map row-major matches np.argmax on that map, including first-max tie-breaking.
    flat_argmax = heatmaps[0].reshape(height * width, num_keypoints).argmax(axis=0)
    rows, cols = np.unravel_index(flat_argmax, (height, width))

    y_offsets = offsets[0, rows, cols, keypoint_ids]
    x_offsets = offsets[0, rows, cols, keypoint_ids + num_keypoints]

    # Grid position scaled to input pixels, refined by the per-cell offset.
    y_coords = rows / (height - 1) * input_size + y_offsets
    x_coords = cols / (width - 1) * input_size + x_offsets

    # Logistic sigmoid 1 / (1 + e^(-x)): large logits map to scores near 1.
    logits = heatmaps[0, rows, cols, keypoint_ids]
    confidence_scores = 1 / (1 + np.exp(-logits))

    return np.stack([y_coords / input_size, x_coords / input_size, confidence_scores], axis=1)

In [None]:
# Time inference + decoding over the whole dataset. perf_counter is monotonic and
# high-resolution; time.time() is a wall clock that can jump (e.g. NTP adjustments).
start_time = time.perf_counter()
results = [raw_output_to_coords(*run_inference(image)) for image in raw_images]
end_time = time.perf_counter()

print("Total time spent:", end_time - start_time)

In [None]:
# Peak resident set size of this kernel process so far. `ru_maxrss` is reported
# in bytes on macOS but in kilobytes on Linux, so normalize to bytes first.
peak_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
memory_usage = peak_rss if sys.platform == "darwin" else peak_rss * 1024  # bytes
print(f"Memory usage: {memory_usage / 1024 ** 3:.2f} GB")

In [None]:
def show_and_save(image_idx, keypoint_threshold=0):
    """Draw and save the prediction for `raw_images[image_idx]`, then plot it.

    The annotated PNG is written by `save_image_with_prediction` and re-read for
    display in one reusable, named figure. The figure is cleared on every call, so
    looping over the dataset neither stacks images on one axes nor piles up figures.

    Args:
        image_idx: Index into the module-level `raw_images` and `results` lists.
        keypoint_threshold: Score threshold forwarded to
            `save_image_with_prediction`; the default 0 draws every keypoint
            whose score is above 0.
    """
    output_path = save_image_with_prediction(image_idx,
                                             raw_images[image_idx],
                                             # Add the two leading axes the helper indexes as [0, 0].
                                             np.array([[results[image_idx]]]),
                                             keypoint_threshold=keypoint_threshold)

    fig = plt.figure("pose_prediction", clear=True)
    ax = fig.subplots()
    # The context manager closes the file PIL keeps open for lazy loading;
    # imshow converts the image to an array before the file is closed.
    with Image.open(output_path) as image:
        ax.imshow(image)
    ax.axis('off')  # Hide the axis

In [None]:
# Make sure the output directory exists, then draw/save every prediction and
# release all matplotlib figures afterwards.
os.makedirs("./output", exist_ok=True)

image_indices = range(len(raw_images))
for image_idx in image_indices:
    show_and_save(image_idx)

plt.close('all')