# Setup

In [None]:
import tensorflow as tf

# Sets the MUJOCO_GL environment variable to "egl" for this kernel
# (presumably so MuJoCo renders headlessly via EGL — confirm for your runtime).
%env MUJOCO_GL=egl

# limit jax and TF from consuming all GPU memory
# (this variable is read by JAX/XLA, so it is set before `import jax` below)
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

# Check if GPU is available
if tf.config.list_physical_devices('GPU'):
    print("TensorFlow is using the GPU")
else:
    print("TensorFlow is not using the GPU")


# Enable on-demand GPU memory growth for TensorFlow on every physical GPU.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices("GPU")
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        # The error is only printed, so the notebook keeps running without memory growth.
        print(e)


import jax
# Turn on 64-bit support in JAX before any arrays are created.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# Count the devices local to this process that JAX can see (any backend, not only GPUs).
# NOTE(review): the message ends with ":" but no device list is printed afterwards.
num_devices = jax.local_device_count()
print(f"Found {num_devices} JAX devices:")


In [None]:
import tensorflow as tf
import tensorflow_hub as hub

from matplotlib import pyplot as plt
import numpy as np

import cv2

In [None]:
# test the body model loads
# Constructing ForwardKinematics here is a smoke test: installation or asset
# problems surface now rather than after the slow pose-estimation steps.
# `fk` itself is not referenced by the later cells visible in this notebook.
from monocular_demos.biomechanics_mjx.forward_kinematics import ForwardKinematics

fk = ForwardKinematics()

# Run MeTRAbs-ACAE on the video

First, upload a video to your Colab environment, then select it with the next cells.

In [None]:
!ls

In [None]:
import os

# File extensions treated as videos; compared case-insensitively so that
# .mp4/.MP4/.mov/.MOV are all accepted.
VIDEO_EXTENSIONS = ('.mp4', '.mov')

# List candidate videos in the current directory. Matching on the actual file
# extension (rather than a substring of the name) avoids picking up unrelated
# files such as "mp4_notes.txt", and sorting keeps the menu numbering stable
# because os.listdir() returns entries in arbitrary order.
files = sorted(
    f for f in os.listdir()
    if os.path.splitext(f)[1].lower() in VIDEO_EXTENSIONS
)

assert len(files) > 0, "No videos uploaded"

if len(files) > 1:
    print("Available files:")
    for i, file in enumerate(files):
        print(f"{i+1}. {file}")

    # Prompt user for selection (1-based in the menu, 0-based internally)
    choice = int(input("Enter the number of the file to select: ")) - 1

    # Reject out-of-range answers explicitly; otherwise entering 0 would wrap
    # around to the last file through negative indexing.
    if not 0 <= choice < len(files):
        raise ValueError(f"Selection must be between 1 and {len(files)}")
    video_filepath = files[choice]

    print(f"You selected: {video_filepath}")

else:
    # Exactly one video present: use it without prompting.
    video_filepath = files[0]


In [None]:
# Load the pose-estimation model from TF Hub (takes about 3 minutes).
METRABS_URL = 'https://bit.ly/metrabs_l'
model = hub.load(METRABS_URL)

# The model supports many skeleton formats; we select one that is
# compatible with the gait transformer we will use below.
skeleton = 'bml_movi_87'

# Edges between joints and the joint names, both used for visualization below.
joint_edges = model.per_skeleton_joint_edges[skeleton].numpy()
joint_names_raw = model.per_skeleton_joint_names[skeleton].numpy()
joint_names = joint_names_raw.astype(str)

In [None]:
# A stray debugging reference to `frame_batch` was removed from this cell: the
# name is only bound by the inference loop in a later cell, so evaluating it
# here raised NameError when the notebook was run top to bottom on a fresh kernel.

In [None]:
from monocular_demos.utils import video_reader
from tqdm import tqdm

vid, n_frames = video_reader(video_filepath)

# Gather each batch's predictions per key and concatenate once at the end.
# Re-concatenating the running result inside the loop would copy every
# previously processed frame again on each batch.
batched_preds = {}

# NOTE(review): assumes n_frames is the total frame count (or None) and that
# each frame_batch supports len() along the frame axis — confirm against
# video_reader.
with tqdm(total=n_frames) as progress:
    for frame_batch in vid:
        pred = model.detect_poses_batched(frame_batch, skeleton=skeleton)

        for key in pred.keys():
            batched_preds.setdefault(key, []).append(pred[key])

        progress.update(len(frame_batch))

# Fail here with a clear message instead of later with an obscure error.
assert batched_preds, "No frames were read from the video"

# concatenate the ragged tensors along the batch (frame) axis for each element in the dictionary
accumulated = {
    key: tf.concat(values, axis=0) for key, values in batched_preds.items()
}

In [None]:
# Number of people detected in each frame
num_people = [p.shape[0] for p in accumulated['poses2d']]

# The rest of the notebook follows a single subject, so every frame must
# contain exactly one detection. Checking only that the count is constant
# would also accept e.g. two people in every frame (silently taking whichever
# comes first) or zero people in every frame (failing below with an IndexError).
bad_frames = [i for i, n in enumerate(num_people) if n != 1]
assert not bad_frames, (
    f"Expected exactly one person in every frame; {len(bad_frames)} of "
    f"{len(num_people)} frames differ (first at frame {bad_frames[0]} "
    f"with {num_people[bad_frames[0]]} detections)"
)

# then extract the information for that person
boxes = np.array([p[0] for p in accumulated['boxes']])
pose3d = np.array([p[0] for p in accumulated['poses3d']])
pose2d = np.array([p[0] for p in accumulated['poses2d']])


In [None]:
# Checkpoint the 3D keypoints to disk so they survive a notebook crash or a
# kernel restart.

# First detection of every frame, stacked into a single array.
first_person_poses3d = [frame_poses[0] for frame_poses in accumulated['poses3d']]
pose3d = np.array(first_person_poses3d)

# Saved positionally, so the array is stored under NumPy's default key.
with open('keypoints3d.npz', 'wb') as npz_file:
    np.savez(npz_file, pose3d)

# Exploration step: try to extract the knee angle over time

Example approach: take the cross product between limb segments

# Now compute kinematics end-to-end using a differentiable body model

This uses an implicit representation $f_\phi: t \rightarrow \theta \in \mathbb R^{40}$ (a function of time with learnable parameters $\phi$) whose output $\theta$ is then passed through the forward kinematic model to get the predicted 3D keypoints: $\mathcal M_\beta: \theta \rightarrow \mathbf y \in \mathbb R^{87 \times 3}$.

We minimize the mean squared difference between the predicted 3D keypoints and the detected 3D keypoints.

In [None]:
# Reload the 3D keypoints from keypoints3d.npz. 'arr_0' is the key NumPy
# assigns to an array passed positionally to np.savez.
with open('keypoints3d.npz', 'rb') as f:
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
    # can execute arbitrary code if the file is not trusted. It looks unnecessary
    # if the file only holds a plain numeric array — confirm and consider removing.
    pose3d = np.load(f, allow_pickle=True)['arr_0']

In [None]:
# Prefer the detections still held in memory. After a kernel restart (when the
# pose-estimation cells were skipped) `accumulated` does not exist, so keep
# the `pose3d` that was loaded from keypoints3d.npz instead of failing with a
# NameError.
if globals().get('accumulated') is not None:
    pose3d = np.array([p[0] for p in accumulated['poses3d']])

In [None]:
from jax import numpy as jnp

# Build the fitting dataset: re-orient the keypoints and convert them to metres.

# Swap the last two coordinate axes. Fancy indexing returns a copy, so the
# original `pose3d` array is left untouched by the steps below.
pose = pose3d[:, :, [0, 2, 1]]

# Flip the sign of the (new) third axis.
pose[..., 2] = -pose[..., 2]

# Scale by 1/1000 (presumably millimetres -> metres; confirm the detector's units).
pose = pose / 1000.0

# Shift each frame so that its minimum over keypoints is zero on every axis.
pose = pose - pose.min(axis=1, keepdims=True)

# One timestamp per frame, assuming a fixed frame rate of 30 frames per second.
FRAME_RATE = 30.0
timestamps = jnp.arange(pose.shape[0]) / FRAME_RATE

dataset = (timestamps, pose)

In [None]:
from jaxtyping import Integer, Float, Array, PRNGKeyArray
from typing import Tuple, Dict
from tqdm import trange
import equinox as eqx
import optax

from monocular_demos.biomechanics_mjx.monocular_trajectory import KineticsWrapper, get_default_wrapper

# Loss between the keypoints predicted by the model (implicit representation
# followed by forward kinematics) and the detected keypoints.

def loss(
    model: KineticsWrapper,
    x: Float[Array, "times"],
    y: Float[Array, "times keypoints 3"],
    site_offset_regularization = 1e-1
) -> Tuple[Float, Dict]:
    """Keypoint reconstruction loss with a penalty on the site offsets.

    Args:
        model: Wrapper that is called with the timestamps and exposes
            ``site_offsets``.
        x: Timestamps at which the model is evaluated.
        y: Detected 3D keypoints to match, one set per timestamp.
        site_offset_regularization: Weight applied to the sum of squared
            site offsets.

    Returns:
        ``(total, metrics)`` where ``total`` is the scalar being minimized and
        ``metrics`` is a dict with ``"loss"`` (the total) first, followed by
        ``"kp_err"`` (the keypoint term without the regularizer).
    """
    # Velocity and action outputs are skipped; only the kinematic state is used here.
    (state, _constraints, _next_states), (_ang, _vel, _action), _ = model(
        x,
        skip_vel=True,
        skip_action=True,
    )

    # Mean squared keypoint error, scaled by 100.
    # NOTE(review): the original comment described this scaling as "so in cm" —
    # confirm the intended units for a squared error.
    keypoint_error = jnp.mean(jnp.square(state.site_xpos - y)) * 100

    # Regularize the marker (site) offsets.
    offset_penalty = jnp.sum(jnp.square(model.site_offsets))
    total = keypoint_error + offset_penalty * site_offset_regularization

    # "loss" is deliberately the first key of the metrics dictionary.
    metrics = {"loss": total, "kp_err": keypoint_error}

    return total, metrics


@eqx.filter_jit
def step(model, opt_state, data, loss_grad, optimizer, **kwargs):
    """Run one optimization step and return the updated model and state.

    Args:
        model: Model being optimized.
        opt_state: Current optimizer state.
        data: ``(x, targets)`` pair forwarded to ``loss_grad`` as ``x`` and ``y``.
        loss_grad: Callable returning ``((value, metrics), grads)``.
        optimizer: Object whose ``update(grads, opt_state, model)`` yields the
            parameter updates and the next optimizer state.
        **kwargs: Extra keyword arguments forwarded to ``loss_grad``.

    Returns:
        ``(value, model, opt_state, metrics)`` after applying the updates.
    """
    inputs, targets = data

    (loss_value, metrics), grads = loss_grad(model, x=inputs, y=targets, **kwargs)
    updates, next_opt_state = optimizer.update(grads, opt_state, model)
    updated_model = eqx.apply_updates(model, updates)

    return loss_value, updated_model, next_opt_state, metrics


def fit_model(
    model: KineticsWrapper,
    dataset: Tuple,
    lr_end_value: float = 1e-8,
    lr_init_value: float = 1e-4,
    max_iters: int = 5000,
    clip_by_global_norm: float = 0.1,
):
    """Fit ``model`` to ``dataset`` by gradient descent on ``loss``.

    Args:
        model: Model to optimize.
        dataset: ``(x, targets)`` tuple passed unchanged to every ``step``.
        lr_end_value: Learning rate the decay schedule ends at.
        lr_init_value: Peak learning rate, used from the first step
            (the warmup length is zero).
        max_iters: Number of optimization steps.
        clip_by_global_norm: Threshold passed to ``optax.clip_by_global_norm``.

    Returns:
        ``(model, metrics)``: the optimized model and the metrics of the last
        step as plain Python floats (an empty dict if ``max_iters`` is 0).
    """
    # Work out the decay rate so the schedule goes from lr_init_value to
    # lr_end_value over the run. The max(1, ...) guards avoid a division by
    # zero when max_iters is smaller than transition_steps.
    transition_steps = 10
    n_transitions = max(1, max_iters // transition_steps)
    lr_decay_rate = (lr_end_value / lr_init_value) ** (1.0 / n_transitions)
    learning_rate = optax.warmup_exponential_decay_schedule(
        init_value=0,
        warmup_steps=0,
        peak_value=lr_init_value,
        end_value=lr_end_value,
        decay_rate=lr_decay_rate,
        transition_steps=transition_steps,
    )

    # NOTE(review): NaN zeroing and norm clipping are applied to the updates
    # produced by adamw, not to the raw gradients — confirm this order is intended.
    optimizer = optax.chain(
        optax.adamw(learning_rate=learning_rate, b1=0.8, weight_decay=1e-5),
        optax.zero_nans(),
        optax.clip_by_global_norm(clip_by_global_norm),
    )
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    loss_grad = eqx.filter_value_and_grad(loss, has_aux=True)

    # Coarse progress message roughly ten times per run; never zero, so the
    # modulo below is safe for small max_iters.
    report_every = max(1, max_iters // 10)

    # Defined up front so the function still returns cleanly when max_iters == 0.
    metrics = {}

    counter = trange(max_iters)
    for i in counter:

        val, model, opt_state, metrics = step(model, opt_state, dataset, loss_grad, optimizer)

        if i > 0 and i % report_every == 0:
            print(f"iter: {i} loss: {val}.")

        if i % 50 == 0:
            print(val, {k: v.item() for k, v in metrics.items()})

    # Always hand back plain floats, regardless of which iteration was last.
    metrics = {k: float(v) for k, v in metrics.items()}

    return model, metrics


# Build the default wrapper and fit it to the detected keypoints.
fkw = get_default_wrapper()
updated_model, metrics = fit_model(model=fkw, dataset=dataset)

# Now explore the results


In [None]:
# Evaluate the fitted model at the timestamps it was trained on.
(state, constraints, next_states), (ang, vel, action), _ = updated_model(dataset[0], skip_vel=True, skip_action=True)

# Mean squared error between the fitted and the detected keypoints. The
# comparison must use the keypoint positions themselves: subtracting
# `site_xpos.shape` would broadcast the shape tuple against the keypoints and
# yield a meaningless number. The value is printed because an expression that
# is not the last line of a cell is not displayed.
reconstruction_mse = jnp.mean((state.site_xpos - dataset[1]) ** 2)
print(f"Keypoint reconstruction MSE: {reconstruction_mse}")

# plot the knees (columns 9 and 16 of the angle trajectory)
plt.figure()
plt.plot(ang[:, [9, 16]])
plt.xlabel("frame")
plt.ylabel("joint angle")
plt.title("Knee angles");

In [None]:
# Import the visualization helpers from the same package the rest of this
# notebook uses; fall back to the original `body_models` path for
# environments where only that package provides them.
try:
    from monocular_demos.biomechanics_mjx.visualize import render_trajectory, jupyter_embed_video
except ImportError:
    from body_models.biomechanics_mjx.visualize import render_trajectory, jupyter_embed_video

# Render the fitted joint-angle trajectory to a video file and embed it below.
fn = 'reconstruction.mp4'
render_trajectory(ang, fn, xml_path=None)
HTML = jupyter_embed_video(fn)
HTML