# Motion Fields for Interactive Character Animation

by Yongjoon Lee, Kevin Wampler, Gilbert Bernstein, Jovan Popovic, Zoran Popovic  
2010

Notebook by Jerome Eippers, 2025

In [None]:
%matplotlib widget
import pickle
from dataclasses import dataclass, field
from random import randrange
import numpy as np
from ipywidgets import widgets, interact
from matplotlib import pyplot as plt

import torch
import torch.nn as nn

import ipyanimlab as lab

viewer = lab.Viewer(move_speed=5, width=1280, height=720)


In [None]:
# load the character
character = viewer.import_usd_asset('AnimLabSimpleMale.usd')

In [None]:
displacement_asset = viewer.import_usd_asset('../../meshes/displacement.usd')

In [None]:
# One mapper retargets every imported clip onto the character skeleton.
animmap = lab.AnimMapper(
    character,
    keep_translation=False,
    root_motion=True,
    match_effectors=True,
    local_offsets={'Hips': [0, 2, 0]},
)

# The position of a clip in this tuple is its index in `animations`; later
# cells address clips by index, so the order must stay as is
# (walk clips first, then the four run clips at indices 12-15).
_clip_names = (
    'walk1_subject1',
    'walk1_subject2',
    'walk1_subject5',
    'walk2_subject1',
    'walk2_subject3',
    'walk2_subject4',
    'walk3_subject1',
    'walk3_subject2',
    'walk3_subject3',
    'walk3_subject4',
    'walk3_subject5',
    'walk4_subject1',
    'run1_subject2',
    'run1_subject5',
    'run2_subject1',
    'run2_subject4',
)

animations = [
    lab.import_bvh(f'../../resources/lafan1/bvh/{clip_name}.bvh', anim_mapper=animmap)
    for clip_name in _clip_names
]

# Skeleton description shared by all clips, taken from the first one.
bones = animations[0].bones
parents = animations[0].parents
bone_count = character.bone_count()

In [None]:
# Per-clip foot contact data, index-aligned with `animations`.
# Entry [0] is later read as the left-foot contacts and [1] as the right-foot
# contacts (see add_states).
left_toe_id = bones.index('LeftToe')
right_toe_id = bones.index('RightToe')

contacts = []
for clip in animations:
    _, fk_positions = lab.utils.quat_fk(clip.quats, clip.pos, parents)
    # NOTE(review): 0.1 is presumably the contact detection threshold - confirm against lab.utils.
    contacts.append(lab.utils.extract_feet_contacts(fk_positions, left_toe_id, right_toe_id, 0.1))
    

In [None]:
def render(frame, index=0):
    """Draw one frame of an imported clip.

    Args:
        frame: frame number; clamped to the last frame of the selected clip.
        index: position of the clip in `animations`.
    """
    clip = animations[index]
    last_frame = clip.quats.shape[0] - 1
    frame = min(frame, last_frame)

    frame_quats = clip.quats[frame, ...]
    frame_pos = clip.pos[frame, ...]
    xforms = lab.utils.quat_to_mat(frame_quats, frame_pos)

    # shadow pass, centred on the first bone position
    viewer.set_shadow_poi(frame_pos[0])
    viewer.begin_shadow()
    viewer.draw(character, xforms)
    viewer.end_shadow()

    # colour pass
    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, xforms)
    viewer.end_display()

    # skeleton overlay, drawn with depth testing disabled
    viewer.disable(depth_test=True)
    viewer.draw_axis(character.world_skeleton_xforms(xforms), 5)
    viewer.draw_lines(character.world_skeleton_lines(xforms))

    viewer.execute_commands()


interact(
    render,
    frame=lab.Timeline(max=5000),
    index=widgets.IntSlider(max=len(animations) - 1),
)
viewer

## Motion states and notation

A **pose** is written as
$$
x = (x_{\text{root}}, x_{\text{hips}}, p_0, p_1, \dots, p_n),
$$
where $x_{\text{root}}, x_{\text{hips}}\in\mathbb{R}^3$ are the root and hips positions and $p_i$ are unit quaternions for the joint orientations.

Given two successive poses $x$ and $x'$ we define a finite-difference **velocity**
$$
v = x' \ominus x
= \big( x'_{\text{root}} - x_{\text{root}},\; x'_{\text{hips}} - x_{\text{hips}},\;p'_0 p_0^{-1},\; p'_1 p_1^{-1},\; \dots,\; p'_n p_n^{-1} \big).
$$

We add a velocity to a pose with the operator 
$$
    x' = x \oplus v.
$$

---
A **motion state** is
$$
m = (x, v).
$$

In [None]:
# Per-bone positions of the first frame of the first clip. pose_to_qp copies
# this array and overwrites rows 0 and 1 with the pose's root and hips.
default_skeleton_p = (animations[0].pos[0,...]).copy()
# The first row is zeroed out.
default_skeleton_p[0] = 0

# Packed pose layout: row 0 = root xyz, row 1 = hips xyz, rows 2.. = one
# 4-component quaternion per bone (the 4th column of rows 0 and 1 is unused).
POSESHAPE = (bone_count + 2, 4)

class PoseData:
    """Unpacked fields of a packed pose (or velocity) array.

    Holds whatever objects it is given; pose_unpack passes slices of the
    packed array, so the fields alias that array.
    """
    def __init__(self, root, hips, quats):
        # xyz of the root entry
        self.root = root
        # xyz of the hips entry
        self.hips = hips
        # per-bone quaternions
        self.quats = quats

def pose_pack(root, hips, quats):
    """Assemble root, hips and quaternions into one float32 array of POSESHAPE.

    Row 0 receives `root` and row 1 receives `hips` in their first three
    columns; the remaining rows receive `quats`.
    """
    packed = np.zeros(POSESHAPE, dtype=np.float32)
    packed[0, :3], packed[1, :3] = root, hips
    packed[2:, :] = quats
    return packed

def pose_unpack(pose):
    """Split a packed pose into a PoseData.

    Indexing is relative to the last two axes, so leading batch axes are
    carried through. The fields are slices of `pose`, not copies.
    """
    return PoseData(
        root=pose[..., 0, :3],
        hips=pose[..., 1, :3],
        quats=pose[..., 2:, :],
    )

def pose_add(x, v):
    """Apply the velocity `v` to the pose `x` and return the packed result.

    The root position comes from combining the (root quaternion, root
    position) pairs of `x` and `v` with lab.utils.qp_mul, hips are summed,
    and quaternions are multiplied (pose first) then normalized.
    """
    pose, delta = pose_unpack(x), pose_unpack(v)

    pose_root = (pose.quats[0], pose.root)
    delta_root = (delta.quats[0], delta.root)
    _, new_root = lab.utils.qp_mul(pose_root, delta_root)

    new_hips = pose.hips + delta.hips
    new_quats = lab.utils.normalize(lab.utils.quat_mul(pose.quats, delta.quats))

    return pose_pack(new_root, new_hips, new_quats)

def pose_subtract(a, b):
    """Return the packed velocity that takes pose `b` to pose `a`.

    Counterpart of pose_add: the root uses the inverse of b's (quaternion,
    position) pair, hips are subtracted, and each quaternion is
    inverse(b) * a, normalized and negated where its first component is
    negative.
    """
    target, origin = pose_unpack(a), pose_unpack(b)

    origin_root_inv = lab.utils.qp_inv((origin.quats[0], origin.root))
    _, root = lab.utils.qp_mul(origin_root_inv, (target.quats[0], target.root))

    hips = target.hips - origin.hips

    relative = lab.utils.quat_mul(lab.utils.quat_inv(origin.quats), target.quats)
    relative = lab.utils.normalize(relative)
    # keep the first quaternion component non-negative
    negative_first = relative[:, 0] < 0
    relative[negative_first, :] *= -1

    return pose_pack(root, hips, relative)

def pose_lerp(a, b, t):
    """Interpolate between packed poses `a` (t = 0) and `b` (t = 1).

    Root and hips are blended linearly; quaternions go through
    lab.utils.quat_slerp and are normalized afterwards.
    """
    start, end = pose_unpack(a), pose_unpack(b)
    keep = 1.0 - t

    root = keep * start.root + t * end.root
    hips = keep * start.hips + t * end.hips
    quats = lab.utils.normalize(lab.utils.quat_slerp(start.quats, end.quats, t))

    return pose_pack(root, hips, quats)

def pose_blend(states, weights):
    """Weighted combination of a stack of packed poses along axis 0.

    Args:
        states: stacked packed poses, one per entry of `weights`.
        weights: 1-D array of blend weights.

    Root, hips and quaternion components are each summed with the weights;
    the summed quaternions are normalized.
    """
    stacked = pose_unpack(states)
    per_vector = weights[:, np.newaxis]
    per_quat = weights[:, np.newaxis, np.newaxis]

    root = (stacked.root * per_vector).sum(axis=0)
    hips = (stacked.hips * per_vector).sum(axis=0)
    quats = lab.utils.normalize((stacked.quats * per_quat).sum(axis=0))

    return pose_pack(root, hips, quats)

def pose_to_qp(a):
    """Expand a packed pose into a (quaternions, positions) pair.

    Positions start from a copy of `default_skeleton_p`, with rows 0 and 1
    replaced by the pose's root and hips. Both returned arrays are copies.
    """
    pose = pose_unpack(a)
    positions = default_skeleton_p.copy()
    positions[0, :], positions[1, :] = pose.root, pose.hips
    return pose.quats.copy(), positions

## Build the states

In [None]:
# Motion-state database, preallocated for up to 50000 entries and trimmed to
# `states_count` once all clips have been added.
# x: packed pose, v: velocity to the next frame, y: velocity of the frame after.
states_x = np.zeros([50000, bone_count+2, 4], dtype=np.float32)
states_v = np.zeros([50000, bone_count+2, 4], dtype=np.float32)
states_y = np.zeros([50000, bone_count+2, 4], dtype=np.float32)
# contact flags per state: column 0 = left, column 1 = right
states_c = np.zeros([50000, 2], dtype=np.bool)
# number of rows filled so far
states_count = 0

def add_states_ex(quats, pos, lcontact, rcontact):
    """Append one motion state per frame triple (i, i+1, i+2) to the database.

    For frame i the stored pose has its root position zeroed and its first
    quaternion reset to [1, 0, 0, 0]; the stored velocities are
    frame(i+1) - frame(i) and frame(i+2) - frame(i+1).

    Args:
        quats: per-frame bone quaternions, indexed [frame, ...].
        pos: per-frame positions, indexed [frame, row, xyz]; rows 0 and 1
            are used as root and hips.
        lcontact: per-frame left contact flags.
        rcontact: per-frame right contact flags.
    """
    global states_count

    def packed_frame(frame_id):
        # copies keep the packed pose independent of the clip arrays
        return pose_pack(
            pos[frame_id, 0, :].copy(),
            pos[frame_id, 1, :].copy(),
            quats[frame_id, ...].copy(),
        )

    for i in range(quats.shape[0] - 2):
        frame_0, frame_1, frame_2 = (packed_frame(i + offset) for offset in range(3))

        velocity = pose_subtract(frame_1, frame_0)
        next_velocity = pose_subtract(frame_2, frame_1)

        # strip the root position and first quaternion from the stored pose
        anchored = pose_unpack(frame_0)
        anchored.root[:] = 0
        anchored.quats[0, :] = [1, 0, 0, 0]

        slot = states_count
        states_x[slot, :, :] = pose_pack(anchored.root, anchored.hips, anchored.quats)
        states_v[slot, :, :] = velocity
        states_y[slot, :, :] = next_velocity

        states_c[slot, 0] = lcontact[i]
        states_c[slot, 1] = rcontact[i]

        states_count = slot + 1

def add_states(anim_id, timings):
    """Add the frames selected by `timings` from clip `anim_id` to the database.

    Args:
        anim_id: index into `animations` / `contacts`.
        timings: index (e.g. a slice) applied to the clip's frame axis.
    """
    clip = animations[anim_id]
    left_contacts, right_contacts = contacts[anim_id][0], contacts[anim_id][1]
    add_states_ex(
        clip.quats[timings],
        clip.pos[timings],
        left_contacts[timings],
        right_contacts[timings],
    )


# Walk : frames 100-2799 of animations[2] (walk1_subject5)
add_states(2, slice(100,2800))
# States with an index below this value come from the walk clip; the rest
# come from the run clips added below.
end_of_walk_ids = states_count

# Jog : ranges taken from animations[15], [14] and [13]
# (run2_subject4, run2_subject1, run1_subject5)
add_states(15, slice(1200,1800))
add_states(15, slice(3450,3860))
add_states(14, slice(180,800))
add_states(13, slice(200,2300))
    
# Trim the preallocated buffers to the rows that were actually filled.
states_x = states_x[:states_count, ...]
states_v = states_v[:states_count, ...]
states_y = states_y[:states_count, ...]
states_c = states_c[:states_count, ...]
# Notebook cell output: total number of states.
states_count

In [None]:
def render(frame, next_state=False):
    """Draw database state `frame`, optionally advanced by its stored velocity.

    Args:
        frame: index into the state database.
        next_state: when true, show states_x[frame] with states_v[frame] applied.
    """

    controller_orient = np.array([1,0,0,0], dtype=np.float32)       
    # show the contact flags of the selected state
    display(states_c[frame])

    state = states_x[frame].copy()
    if next_state:
        state = pose_add(state, states_v[frame])
    q, p = pose_to_qp(state)

    # quaternion about the second (y) axis built from the x/z components of the root position
    angle = np.atan2(p[0, 0], p[0, 2])
    controller_orient[0] = np.cos(angle/2)
    controller_orient[2] = np.sin(angle/2)
        
    a = lab.utils.quat_to_mat(q, p)
    viewer.set_shadow_poi(p[0])
    
    # shadow pass
    viewer.begin_shadow()
    viewer.draw(character, a)
    viewer.end_shadow()
    
    # colour pass
    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, a)

    # marker placed at the root with the root quaternion
    d = lab.utils.quat_to_mat(q[0], p[0])
    viewer.draw(displacement_asset, d)

    # second marker at the root, oriented by the angle computed above
    d = lab.utils.quat_to_mat(controller_orient, p[0])
    viewer.draw(displacement_asset, d)
    
    viewer.end_display()

    # skeleton overlay, drawn with depth testing disabled
    viewer.disable(depth_test=True)
   
    viewer.draw_axis(character.world_skeleton_xforms(a), 5)
    viewer.draw_lines(character.world_skeleton_lines(a))
    
    viewer.execute_commands()
    
interact(
    render, 
    frame=lab.Timeline(max=states_count-1)
)
viewer

## Similarity metric and similarity weights

Similarity between motion states is computed here from the positions of each bone and their velocities (unlike the paper, which uses rotation information).

Given the $k$ nearest neighbours $\{m_i\}$ of $m$, we use inverse-distance weights (normalized) for interpolation:
$$
w_i \;=\; \frac{\dfrac{1}{d(m,m_i)^2}}{\sum_j \dfrac{1}{d(m,m_j)^2}}.
$$
These $w_i$ are called *similarity weights* and are used both for interpolation and as the passive action.


In [None]:
# Feature layout per state: bone_count rows of weighted positions followed by
# bone_count rows of position deltas, 3 components each (see build_distance_metric).
FEATURE_SHAPE = (bone_count*2, 3)
# Earlier weight sets, kept for reference:
# metric_weights = np.array([0, .3, .1, .1, .01, .01, .01,  .01, .01, .01, .01,  .01, .01, .01, .01,  .2, .5, 1, 1,  .2, .5, 1, 1 ], dtype=np.float32)
# metric_velocity_weights = np.array([1, .8, .5, .5, .5, .01, .01,  .01, .01, .01, .9,  .01, .01, .01, .9,  1.2, 1.5, 2, 0,  1.2, 1.5, 2, 0 ], dtype=np.float32)
# metric_weights = np.ones([bone_count], dtype=np.float32)
# metric_velocity_weights = np.ones([bone_count], dtype=np.float32)
# One weight per bone (23 entries, so bone_count must be 23 for the
# broadcast in build_distance_metric to work).
metric_weights = np.array([0, .3, .1, .1, .5, .01, .01,  .01, .01, .01, .01,  .01, .01, .01, .01,  .2, .5, 1, 1,  .2, .5, 1, 1 ], dtype=np.float32)
# NOTE(review): not referenced by build_distance_metric or any other code
# visible here - confirm whether the velocity half of the feature was meant
# to be weighted with it.
metric_velocity_weights = np.array([1, .8, .5, .1, .1, .5, .1,  0,0,0,0, 0,0,0,0,  1.2, 1.5, 2, 0,  1.2, 1.5, 2, 0 ], dtype=np.float32)

def build_distance_metric(x, v):
    """Build the similarity feature for the motion state (x, v).

    The pose's first quaternion is replaced by [1, 0, 0, 0] before forward
    kinematics, and bone offsets come from `default_skeleton_p`. The result
    stacks the per-bone positions scaled by `metric_weights` and 0.8 on top
    of the per-bone position change after applying `v`
    (unweighted: `metric_velocity_weights` is not applied here).

    Returns:
        Array with twice as many rows as the forward-kinematics output.
    """
    pose = pose_unpack(x.copy())
    pose.quats[0] = [1, 0, 0, 0]
    _, positions_now = lab.utils.quat_fk(pose.quats, default_skeleton_p, parents)

    repacked = pose_pack(pose.root, pose.hips, pose.quats)
    advanced = pose_unpack(pose_add(repacked, v))
    _, positions_next = lab.utils.quat_fk(advanced.quats, default_skeleton_p, parents)

    weighted_positions = positions_now * metric_weights[:, np.newaxis] * .8
    position_change = positions_next - positions_now
    return np.concatenate([weighted_positions, position_change])

In [None]:
# Similarity feature of every database state, one row block per state.
metric_matrix = np.zeros([states_count, FEATURE_SHAPE[0], 3], dtype=np.float32)
for row, (state_pose, state_velocity) in enumerate(zip(states_x, states_v)):
    metric_matrix[row, ...] = build_distance_metric(state_pose, state_velocity)

### Simple representation of the motion field

This is purely for visually explaining the concept. It does not really represent anything.

In [None]:
from umap import UMAP

# metric_matrix: [N_samples x N_features] matrix
# Fit a 3-component UMAP on the flattened per-state features; `reducer` and
# `embedding` are reused by the interactive demo further down.
reducer = UMAP(n_components=3, metric='euclidean', n_neighbors = 80)
embedding = reducer.fit_transform(metric_matrix.reshape(-1, FEATURE_SHAPE[0]*3))

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')

# one point per database state
ax.cla()
ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2], s=5, alpha=0.5)

### Use Pytorch for knn

Running in parallel on the GPU, we can simply do a brute-force search through all the poses to find the k nearest.

In [None]:
# Database features on the GPU with a leading broadcast axis:
# shape (1, state_count, feature_count, 3).
toch_knn_features = torch.from_numpy(metric_matrix).to('cuda').unsqueeze(0)

In [None]:
def get_nns_by_vector(vector, k, chunk_size=128):
    """Brute-force k-nearest-neighbour search against `toch_knn_features`.

    The distance between a query and a database state is the sum, over all
    feature rows, of the Euclidean norm of the per-row 3-D difference.

    Args:
        vector: numpy array of shape (query_count, feature_count, 3).
        k: number of neighbours returned per query.
        chunk_size: maximum number of queries compared with the database at
            once. The intermediate difference tensor has shape
            (chunk, state_count, feature_count, 3), so querying the whole
            database in a single pass can need tens of gigabytes; chunking
            bounds that without changing the per-query computation.

    Returns:
        (knn_indices, knn_distances): numpy arrays of shape (query_count, k),
        ordered from nearest to farthest.

    Raises:
        ValueError: if `chunk_size` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")

    features = toch_knn_features  # shape: (1, state_count, feature_count, 3)
    # follow the device of the database features instead of assuming one
    query = torch.from_numpy(vector).to(features.device)

    index_chunks = []
    distance_chunks = []
    # max(..., 1) keeps a single pass for an empty query so the result is (0, k)
    for start in range(0, max(query.shape[0], 1), chunk_size):
        q_exp = query[start:start + chunk_size].unsqueeze(1)  # (chunk, 1, feature_count, 3)

        diff = features - q_exp                        # (chunk, state_count, feature_count, 3)
        point_distances = torch.norm(diff, dim=3)      # (chunk, state_count, feature_count)
        sum_distances = torch.sum(point_distances, dim=2)  # (chunk, state_count)

        topk = torch.topk(sum_distances, k=k, largest=False)
        index_chunks.append(topk.indices)
        distance_chunks.append(topk.values)

    knn_indices = torch.cat(index_chunks, dim=0).cpu().numpy()      # (query_count, k)
    knn_distances = torch.cat(distance_chunks, dim=0).cpu().numpy()  # (query_count, k)

    return knn_indices, knn_distances

## Integration function and drift correction

An action $a=[a_1,\dots,a_k]$ is a convex combination over the neighborhood. The basic integration computes the next pose/velocity by blending the neighbour velocities and next-frame velocities:
$$
I(m,a) \;=\; \Big(x \oplus \sum_{i=1}^k a_i v_i,\; \sum_{i=1}^k a_i y_i\Big),
$$
where $y_i$ denotes the "next" velocity stored at neighbour $m_i$.

To prevent drift into regions with little data, a small drift-correction toward the nearest database state $\bar m=(\bar x,\bar v)$ is blended in with strength $\delta$ (typical $\delta\!=\!0.1$):
$$
v' = (1-\delta)\Big(\sum_{i=1}^k a_i v_i\Big) \oplus \delta\big((\bar x\oplus\bar v)\ominus x\big),
\qquad
y' = (1-\delta)\Big(\sum_{i=1}^k a_i y_i\Big)\oplus \delta\,\bar y.
$$

In [None]:
def get_k_neighbors(current_x, current_v, k=15):
    """Find the k nearest database states of one motion state.

    Returns:
        (indices, weights): database indices and inverse-squared-distance
        weights normalized to sum to 1.
    """
    query = build_distance_metric(current_x, current_v)[np.newaxis, ...]
    batch_indices, batch_distances = get_nns_by_vector(query, k)
    indices, distances = batch_indices[0], batch_distances[0]

    # the small constant avoids a division by zero for an exact match
    weights = 1.0 / (distances**2 + 1e-8)
    weights /= np.sum(weights)
    return indices, weights

def get_batched_k_neighbors(metric_vector, k=15):
    """Batched variant of get_k_neighbors taking precomputed feature vectors.

    Args:
        metric_vector: stacked features, one query per leading index.
        k: number of neighbours per query.

    Returns:
        (indices, weights): per-query database indices and
        inverse-squared-distance weights, each row normalized to sum to 1.
    """
    indices, distances = get_nns_by_vector(metric_vector, k)
    inverse_squared = 1.0 / (distances**2 + 1e-8)
    row_totals = np.sum(inverse_squared, axis=1, keepdims=True)
    return indices, inverse_squared / row_totals

def compute_v_to_reach_state(current_x, state_id):
    """Return the velocity that takes `current_x` onto the successor of a database state.

    Database poses are stored with a zero root position and a
    [1, 0, 0, 0] root quaternion (see add_states_ex), so the successor
    `states_x[state_id] + states_v[state_id]` carries the root motion
    relative to the state's own root. That root motion is attached to the
    root of `current_x` with the same (quaternion, position) composition
    pose_add uses, so that pose_subtract, its counterpart, recovers the
    stored root displacement unchanged. Adding the two root positions
    directly leaves the displacement expressed in the wrong frame as soon as
    the current root quaternion is not the identity.

    Args:
        current_x: packed pose of the character.
        state_id: index of the database state to be pulled towards.

    Returns:
        Packed velocity, in the same convention as `states_v`.
    """
    local_next = pose_unpack(pose_add(states_x[state_id], states_v[state_id]))
    current = pose_unpack(current_x)

    # NOTE(review): relies on lab.utils.qp_mul composing (quat, pos) pairs the
    # way pose_add / pose_subtract assume - confirm against ipyanimlab.
    _, root = lab.utils.qp_mul(
        (current.quats[0], current.root),
        (local_next.quats[0], local_next.root),
    )

    # local_next comes from a fresh pose_add result, so it is safe to edit in place
    quats = local_next.quats
    quats[0] = lab.utils.quat_mul(current.quats[0], local_next.quats[0])

    target = pose_pack(root, local_next.hips, quats)
    return pose_subtract(target, current_x)

def compute_new_state(current_x, indices, weights, tug_ratio=.1):
    """Integrate one step from `current_x` using a weighted neighbourhood.

    The neighbour with the largest weight acts as the anchor: the blended
    velocity is pulled towards the velocity that reaches the anchor's
    successor, and the blended next-velocity towards the anchor's own, both
    by `tug_ratio`.

    Args:
        current_x: packed pose of the character.
        indices: database indices of the neighbours.
        weights: blend weight per neighbour.
        tug_ratio: interpolation factor towards the anchor (0 = none).

    Returns:
        (next pose, next velocity), both packed.
    """
    anchor_id = indices[np.argmax(weights)]
    tug_v = compute_v_to_reach_state(current_x, anchor_id)

    mixed_v = pose_blend(states_v[indices, ...], weights)
    mixed_y = pose_blend(states_y[indices, ...], weights)

    step_v = pose_lerp(mixed_v, tug_v, tug_ratio)
    next_v = pose_lerp(mixed_y, states_y[anchor_id], tug_ratio)

    return pose_add(current_x, step_v), next_v

In [None]:
last_indice = 0

# Controller state mutated in place by render(): the pose starts from database
# state 0 and the velocity from database state 100.
current_x = states_x[last_indice, ...].copy()
current_v = states_v[100, ...].copy()

# Figure reused by render() when plot_map is enabled.
umap_fig = plt.figure(figsize=(5, 5))
umap_ax = umap_fig.add_subplot(111, projection='3d', computed_zorder=False)

def render(frame, ratio=.1, select=0, on_spot=False, plot_map=False):
    """Advance the character by one integration step and draw it with its neighbours.

    Args:
        frame: value of the play widget; not read, each change triggers one step.
        ratio: tug ratio passed to compute_new_state.
        select: neighbour slot whose weight is forced to 1 before
            renormalizing; any value below 0 keeps the similarity weights.
        on_spot: draw the character at a fixed position instead of its root position.
        plot_map: also plot the current state and its neighbours in the UMAP embedding.
    """
    global last_indice
    
    # compute the n closest
    indices, weights = get_k_neighbors(current_x, current_v, 15)

    # display helper: embed the state before it is advanced
    current_pose_umap= None
    if plot_map:
        current_pose_umap = reducer.transform(build_distance_metric(current_x, current_v).flatten().reshape(1, -1))[0]

    # emphasize one neighbour
    if select > -1:
        weights[select] = 1.0
        weights /= np.sum(weights)

    last_indice = indices[0]
    
    # integrate in place; the second returned value becomes the new velocity
    current_x[...], current_v[...] = compute_new_state(current_x, indices, weights, ratio)

    q, p = pose_to_qp(current_x)
    
    display((indices, weights))

    if on_spot:
        p[0] = [0, 0, 200]
 
    a = lab.utils.quat_to_mat(q, p)
    viewer.set_shadow_poi(p[0])

    character.materials()[0].set_albedo(np.array([.4, .4, .8], dtype=np.float32))
    
    viewer.begin_shadow()
    viewer.draw(character, a)
    viewer.end_shadow()
    
    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, a)

    # draw the 15 neighbours side by side, shaded by their weight
    for i in range(15):
        q, p = pose_to_qp(states_x[indices[i]])
        e = lab.utils.quat_to_mat(q, p)
        character.materials()[0].set_albedo( np.ones(3, dtype=np.float32)* float(weights[i]) )
        # spread the neighbours out, 40 units apart
        e[0, 0, 3] = i * 40
        viewer.draw(character, e)
    
    viewer.end_display()

    viewer.disable(depth_test=True)

    viewer.draw_axis(character.world_skeleton_xforms(a), 5)
    viewer.draw_lines(character.world_skeleton_lines(a))
    
    viewer.execute_commands()

    # Redraw the matplotlib figure
    if plot_map:
        umap_ax.cla()
        
        # Current state (red star)
        
    
        umap_ax.scatter(current_pose_umap[0], current_pose_umap[1], current_pose_umap[2],
                      c='red', s=200, marker='*', edgecolors='black', 
                      linewidth=1, label='Current pose', zorder=10)
    
        # Neighbours, coloured by weight
        neighbor_points = embedding[indices, :]
        umap_ax.scatter(neighbor_points[:, 0], neighbor_points[:, 1], neighbor_points[:, 2], 
                       c=weights, cmap='viridis', s=100, 
                       edgecolors='blue', linewidth=0.5, 
                       label='Neighbors', zorder=5)

        # All database points (light gray)
        umap_ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2],
                  c='lightgray', s=20, alpha=0.6, label='Database poses')
    
interact(
    render, 
    frame=widgets.Play(interval=1000./30.), #lab.Timeline(max=100),
    ratio=widgets.FloatSlider(min=0, max=1.0, value=.1),
    select=widgets.IntSlider(min=-1, max=14, value=-1)
)

viewer

## Greedy action selection from k-NN

At each motion state $m$, we consider a set of discrete candidate actions.  
Each action is constructed from the similarity weights of the $k$ nearest neighbors.



**Step 1: Base weights**  
From the k-NN search, compute similarity weights
$$
w_i = \frac{1/d(m,m_i)^2}{\sum_{j=1}^k 1/d(m,m_j)^2}, \qquad i=1,\dots,k.
$$

These form the *passive action* (a convex combination of neighbors).



**Step 2: Action generation**  
To explore different transitions, we derive one candidate action per neighbor:

- For neighbor $i$, define raw weights
  $$
  \tilde{a}_j =
  \begin{cases}
    1, & j=i, \\
    w_j, & j \neq i.
  \end{cases}
  $$

- Renormalize to obtain a valid action vector
  $$
  a_j = \frac{\tilde{a}_j}{\sum_{\ell=1}^k \tilde{a}_\ell}.
  $$

Thus each candidate action $a^{(i)}$ emphasizes a single neighbor while still respecting the similarity distribution.



**Step 3: Greedy selection**  
For greedy control (without planning), we evaluate the immediate reward (or similarity) of each candidate $a^{(i)}$ and pick
$$
a^\star = \arg\max_{i} \; R\big(m, a^{(i)}\big).
$$


In [None]:
# Controller widget for device index 0; the render functions below read its
# axes and buttons.
gamepad = widgets.Controller(index=0)
gamepad

In [None]:
# Neighbourhood size, and therefore the number of candidate actions per step.
K_NEIGHBORS = 15

In [None]:
def signed_angle(desired: np.ndarray, current: np.ndarray, up=np.array([0,1,0])):
    """Signed angle, in radians, that turns `current` onto `desired` on the ground plane.

    Both vectors are projected onto the ground by dropping their second (y)
    component and then normalized. The magnitude is the angle between the
    projections; the sign is that of cross(current, desired) . up.

    Exactly opposite directions have a vanishing cross product; they are
    reported as a positive half turn instead of 0, so that facing away from
    the desired direction is never mistaken for facing it. Vectors without a
    ground component still yield 0.

    Args:
        desired: 3-vector of the target direction (any numeric dtype).
        current: 3-vector of the present direction (any numeric dtype).
        up: axis that defines the positive rotation sense.

    Returns:
        The signed angle in [-pi, pi]. The inputs are left untouched.
    """
    # project to ground, working on float copies so integer inputs also work
    d = np.array(desired, dtype=np.float64)
    c = np.array(current, dtype=np.float64)
    d[1] = 0.0
    c[1] = 0.0
    d /= np.linalg.norm(d) + 1e-12
    c /= np.linalg.norm(c) + 1e-12

    dot = np.clip(np.dot(d, c), -1.0, 1.0)
    # np.arccos rather than np.acos: the latter only exists from NumPy 2.0 on
    angle = np.arccos(dot)  # magnitude

    # sign from cross product, c -> d
    sign = np.sign(np.dot(np.cross(c, d), up))
    if sign == 0.0 and dot < 0.0:
        # anti-parallel: no preferred side, report the half turn itself
        return angle
    return angle * sign

In [None]:
def action_reward(desired_direction, current_x, next_x):
    """Score a candidate next pose by how well it faces `desired_direction`.

    The facing direction is the vector [0, 0, 1] rotated by the first
    quaternion of `next_x`. `current_x` is accepted but not used.

    Returns:
        Minus the absolute signed angle to `desired_direction`
        (0 is best, more negative is worse).
    """
    forward = np.array([0, 0, 1], dtype=np.float32)
    facing = lab.utils.quat_mul_vec(pose_unpack(next_x).quats[0], forward)
    return -abs(signed_angle(desired_direction, facing))

In [None]:
# Tug ratio used by the greedy controller.
TUG_RATIO = .1
last_indice = 400

# Controller state mutated in place by render(), started from database state 400.
current_x = states_x[last_indice, ...].copy()
current_v = states_v[last_indice, ...].copy()

def render(frame, on_spot=False):
    """Run one greedy control step driven by the gamepad and draw the result.

    Args:
        frame: value of the play widget; not read, each change triggers one step.
        on_spot: draw the character at a fixed position instead of its root position.
    """

    # desired heading from the stick; stays at [1,0,0,0] inside the dead zone
    controller_orient = np.array([1,0,0,0], dtype=np.float32)
    posx = gamepad.axes[0].value 
    posz = -gamepad.axes[1].value 
    if np.abs(posx) > 0.001 or np.abs(posz) > 0.001:
        # NOTE(review): argument order is (posz, posx) here while the state
        # browser uses (x, z) - confirm this matches the intended stick mapping.
        angle = np.atan2(posz, posx)
        controller_orient[0] = np.cos(angle/2)
        controller_orient[2] = np.sin(angle/2)

    desired_direction = lab.utils.quat_mul_vec(controller_orient, np.array([0,0,1], dtype=np.float32))
    
    # compute the n closest
    indices, weights = get_k_neighbors(current_x, current_v, K_NEIGHBORS)
    #display(indices)

    # one candidate action per neighbour: force its weight to 1 and renormalize
    rewards = np.zeros(K_NEIGHBORS)
    for n_idx in range(K_NEIGHBORS):

        w = weights.copy()
        w[n_idx] = 1.0
        w /= np.sum(w)
        nx, nv = compute_new_state(current_x, indices, w, TUG_RATIO)

        reward = action_reward(desired_direction, current_x, nx)

        # store the rewards
        rewards[n_idx] = reward

    # apply the best
    display(rewards)
    best_i = np.argmax(rewards)
    weights[best_i] = 1.0
    weights /= np.sum(weights)
    current_x[...], current_v[...] = compute_new_state(current_x, indices, weights, TUG_RATIO)
    
    q, p = pose_to_qp(current_x)

    if on_spot:
        p[0] = [0, 0, 200]
 
    a = lab.utils.quat_to_mat(q, p)
    viewer.set_shadow_poi(p[0])

    character.materials()[0].set_albedo(np.array([.4, .4, .8], dtype=np.float32))
    
    viewer.begin_shadow()
    viewer.draw(character, a)
    viewer.end_shadow()
    
    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, a)

    # marker showing the desired heading at the character position
    d = lab.utils.quat_to_mat(controller_orient, p[0])
    viewer.draw(displacement_asset, d)

    # neighbours side by side, shaded by weight; the chosen one in green
    for i in range(K_NEIGHBORS):
      
        q, p = pose_to_qp(states_x[indices[i]])
        a = lab.utils.quat_to_mat(q, p)
        character.materials()[0].set_albedo(np.ones(3, dtype=np.float32) * float(weights[i]))
        if i == best_i:
            character.materials()[0].set_albedo(np.array([0,1,0], dtype=np.float32))
        a[0, 0, 3] = i * 40
        viewer.draw(character, a)
    
    viewer.end_display()

    viewer.disable(depth_test=True)

    viewer.execute_commands()
    
interact(
    render, 
    frame=widgets.Play(interval=1000./30.), #lab.Timeline(max=100),
)
viewer

# Using the value function

The value function provides a smooth estimate of the future reward from any possible state.  



**What is a state?**  
A state in the motion field is not only the physical **motion state** $m = (x,v)$ (pose + velocity), but also includes the **task goal parameters** $\theta_T$.  
Thus, a full task state is
$$
s = (m, \theta_T).
$$

- $x$: joint configuration (root position + joint orientations)  
- $v$: finite-difference velocity  
- $\theta_T$: high-level control parameters (e.g. desired direction, target location)



**How is the value function stored?**  
- We store one scalar $V(s)$ for every database state $s$.  
- These stored values form anchor points scattered throughout the motion field.  
- Because of the dense motion data and interpolation, the value function behaves as if it were continuous, even though it is only sampled at discrete states.



**How do we read future values?**  
When we take an action $a$ from a state $s$:

1. Integrate one step forward to obtain the predicted next state:
   $$
   s' = I_s(s,a).
   $$

2. Perform a k-NN search around $s'$ to find the closest database states $\{s_j\}$.

3. Compute similarity weights:
   $$
   \omega_j = \frac{1/d(s',s_j)^2}{\sum_{\ell} 1/d(s',s_\ell)^2}.
   $$

4. Interpolate the stored values:
   $$
   V(s') \;\approx\; \sum_j \omega_j \, V(s_j).
   $$



**Key idea**  
Even though we only store values at the discrete database states, k-NN interpolation with similarity weights makes the value function **continuous across the motion field**.  
This means we can evaluate the long-term utility of any new, unseen state by smoothly blending the values of its nearest neighbors.


---
### Fitted value function training


**Step 1: Initialize values**  
For each database motion state $s$, we store a value
$$
V(s) \in \mathbb{R}.
$$
Initially, all values are set to zero.


**Step 2: Action rollout via k-NN**  
Compute the reward for each of the $k$ actions.


**Step 3: Bellman update**  
For each database state $s$, update its value using the Bellman equation:
$$
V(s) \;\leftarrow\; \max_{a \in A(s)} \Big[ R(s,a) \;+\; \gamma \sum_j \omega_j V(s_j') \Big],
$$
where
- $R(s,a)$ is the immediate reward for taking action $a$ at state $s$,
- $\gamma \in (0,1)$ is the discount factor.


**Step 4: Iterate to convergence**  
Repeat the update across all states until the value function stabilizes.  
The result is a smooth value landscape over the motion field, enabling policies that anticipate future rewards rather than acting greedily.


In [None]:
%%time 

progress_output = widgets.Output(layout={'border': '1px solid black'})
display(progress_output)

# precompute all the states field transitions
K_NEIGHBORS = 15
TUG_RATIO_LEARNING = .1

# For every database state and every candidate action: the pose and velocity
# reached after one integration step.
all_states_actions_states_x = np.zeros((states_count, K_NEIGHBORS, bone_count+2, 4), dtype=np.float32)
all_states_actions_states_v = np.zeros((states_count, K_NEIGHBORS, bone_count+2, 4), dtype=np.float32)

# For every (state, action): the neighbours of the reached state and their
# similarity weights, used to interpolate the value function.
# int32 rather than int16: the state buffers are sized for up to 50000 states,
# and indices above 32767 would silently wrap around in int16.
all_states_actions_value_function_indices = np.zeros((states_count, K_NEIGHBORS, K_NEIGHBORS), dtype=np.int32)
all_states_actions_value_function_weights = np.zeros((states_count, K_NEIGHBORS, K_NEIGHBORS), dtype=np.float32)

def _build_precomputed_tables():
    """Fill the four transition tables above and pickle them to disk.

    For each database state the K_NEIGHBORS candidate actions are integrated
    with compute_new_state; the neighbours of each resulting state are then
    looked up in one batched query. Progress is reported in `progress_output`.
    The tables are written to 'motion_fields_precomputed_all_states_tables.dat'.
    """

    with progress_output:

        display("get all future indices and weights")
        batched_query = np.zeros([K_NEIGHBORS, FEATURE_SHAPE[0], 3], dtype=np.float32)
        # neighbourhood of every database state, in one batched search
        future_indices, future_weights = get_batched_k_neighbors(metric_matrix, K_NEIGHBORS)

        for i in range(states_count):

            progress_output.clear_output()
            display(f"state {i+1} / {states_count}")

            indices, weights = future_indices[i], future_weights[i]

            for n_idx in range(K_NEIGHBORS):
                # candidate action: force one neighbour's weight to 1 and renormalize
                w = weights.copy()
                w[n_idx] = 1.0
                w /= np.sum(w)

                # compute the next state, and see how much it rotates
                all_states_actions_states_x[i, n_idx, ...], all_states_actions_states_v[i, n_idx, ...] = compute_new_state(states_x[i], indices, w, TUG_RATIO_LEARNING)
                batched_query[n_idx, ...] = build_distance_metric(all_states_actions_states_x[i, n_idx, ...], all_states_actions_states_v[i, n_idx, ...])

            # find the neighbours of every reached state
            all_states_actions_value_function_indices[i, :, :], all_states_actions_value_function_weights[i, :, :] = get_batched_k_neighbors(batched_query, K_NEIGHBORS)


    with open('motion_fields_precomputed_all_states_tables.dat', 'wb') as f:
        pickle.dump((all_states_actions_states_x, all_states_actions_states_v, all_states_actions_value_function_indices, all_states_actions_value_function_weights), f)

# Uncomment to rebuild the tables (needed whenever the state database changes):
#_build_precomputed_tables()
# NOTE: pickle.load can execute arbitrary code contained in the file; only load
# a table file that was produced locally by _build_precomputed_tables().
with open('motion_fields_precomputed_all_states_tables.dat', 'rb') as f:
    (all_states_actions_states_x, all_states_actions_states_v, all_states_actions_value_function_indices, all_states_actions_value_function_weights) = pickle.load(f)

In [None]:
# all the task goals
# The task parameter is an angle sampled on a periodic grid of `theta_count`
# values starting at -pi; the endpoint +pi is dropped because it duplicates -pi.

theta_count = 17
thetas = np.linspace(-np.pi, np.pi, theta_count+1)[:theta_count]
# angular distance between two consecutive grid values
theta_spacing = 2*np.pi / theta_count

In [None]:
# train the value function using Pytorch

# number of value-iteration sweeps
EPOCH = 300
# discount factor
gamma = .99
pi = torch.pi

def _train(is_walking, factor):
    """Run value iteration over all (database state, theta) pairs on the GPU.

    Args:
        is_walking: when true, states before `end_of_walk_ids` receive the
            bonus; otherwise the states from `end_of_walk_ids` on receive it.
        factor: size of that bonus, added to the reward.

    Returns:
        (V, scores): V is a numpy array of shape (states_count, theta_count);
        scores is a tensor of shape (EPOCH, 3) holding the min, max and mean
        absolute change of V at each epoch.
    """
    # output
    scores = torch.zeros([EPOCH, 3])
    V = torch.zeros([states_count, theta_count], device='cuda')
    
    # transfer to GPU
    # row 2 of a packed pose is its first quaternion
    tensor_all_states_action_delta_root_quat = torch.tensor(all_states_actions_states_x[:, :, 2, :], device='cuda')
    tensor_all_states_actions_value_function_indices = torch.tensor(all_states_actions_value_function_indices, device='cuda')
    tensor_all_states_actions_value_function_weights = torch.tensor(all_states_actions_value_function_weights, device='cuda')
    theta_grid = torch.tensor(thetas, device='cuda')

    # compute delta orientation
    # NOTE(review): presumably the heading change about the vertical axis,
    # assuming (w, x, y, z) component order - confirm.
    delta_orientation = torch.atan2(2. * tensor_all_states_action_delta_root_quat[..., 0] * tensor_all_states_action_delta_root_quat[..., 2], 1.0 - 2. * tensor_all_states_action_delta_root_quat[...,2]**2.)

    # unused later
    del tensor_all_states_action_delta_root_quat

    # flag locomotion types
    state_score = torch.zeros([states_count], device='cuda')
    if is_walking:
        state_score[:end_of_walk_ids] = 1
    else:
        state_score[end_of_walk_ids:] = 1
    
    
    # compute rewards
    expand_theta_grid = theta_grid.view(1, 1, theta_count) #per state, per action, each theta
    
    expand_delta_orientation = delta_orientation.view(states_count, K_NEIGHBORS, 1)
    next_orientation = expand_theta_grid + expand_delta_orientation  # broadcast S, A, T
    # wrap into [-pi, pi)
    next_orientation = (next_orientation + pi) % (2.0*pi) - pi
    
    # smaller remaining angle is better; states of the requested gait get the bonus
    rewards = -torch.abs(next_orientation) + state_score.unsqueeze(-1).unsqueeze(-1) * factor

    # compute accessors for interpolations
    step = (2*pi)/theta_count
    coord = (next_orientation +pi) / step                          # (S,A,T)
    i0 = torch.floor(coord).to(torch.long) % theta_count           # lower index
    i1 = (i0 + 1) % theta_count                                    # upper index
    alpha  = coord - torch.floor(coord)                            # weight to upper (in [0,1))
    
    # repeat the theta indices for each of the K_NEIGHBORS neighbours of the reached state
    i0_exp = i0.unsqueeze(2).expand(-1, -1, K_NEIGHBORS, -1)
    i1_exp = i1.unsqueeze(2).expand(-1, -1, K_NEIGHBORS, -1)

    # train
    for epoch in range(EPOCH):
        # values of the neighbours of every reached state, for every theta
        V_neighbors = V[tensor_all_states_actions_value_function_indices.to(torch.long)]
        
        V0 = torch.gather(V_neighbors, 3, i0_exp)
        V1 = torch.gather(V_neighbors, 3, i1_exp)
        
        # Task interpolation then motion interpolation:
        V_task = (1.0 - alpha).unsqueeze(2) * V0 + alpha.unsqueeze(2) * V1  # (S,A,J,Tθ)
        
        exp_next_V = (V_task * tensor_all_states_actions_value_function_weights.unsqueeze(-1)).sum(dim=2)     # (S,A,Tθ)
        
        # Bellman backup, then max over actions:
        Q = rewards + gamma * exp_next_V                            # (S,A,Tθ)
        V_prime = Q.max(dim=1).values                                 # (S,Tθ)
        
        # convergence statistics: how much V changed in this sweep
        bellman_residual = torch.abs(V-V_prime)
        scores[epoch, 0] = bellman_residual.min()
        scores[epoch, 1] = bellman_residual.max()
        scores[epoch, 2] = bellman_residual.mean()
        V = V_prime
    return V.cpu().numpy(), scores
    
# `scores` is overwritten by the second call, so only the jog run's
# convergence statistics remain available afterwards.
value_function_walk, scores = _train(True, 1)
value_function_jog, scores = _train(False, 1)

In [None]:
# Convergence plot: per-epoch min / max / mean change of the value function
# for the scores currently held in `scores`.
x = np.arange(EPOCH)
score_min, score_max, score_mean = scores[:, 0], scores[:, 1], scores[:, 2]

fig, ax = plt.subplots(1, 1, figsize=(10, 4), sharex=True)

ax.plot(x, score_mean, label='Mean', color='blue')
ax.fill_between(x, score_min, score_max, color='lightblue', alpha=0.4, label='Min-Max Range')
ax.set_title('Scores over Epoch (Mean and Range)')
ax.set_ylabel('Value')
ax.legend()
ax.grid(True)

plt.tight_layout()
plt.show()

### Motion Fields
---

In [None]:
last_indice = 0


# Controller state, started from the first database state.
current_x = states_x[last_indice, ...].copy()
current_v = states_v[last_indice, ...].copy()
q, p = pose_to_qp(current_x)
q, p = lab.utils.quat_fk(q, p, parents)
# One entry per foot, left then right. The right entry must use 'RightToe';
# indexing 'LeftToe' twice made both entries track the left foot.
toes_positions = p[[bones.index('LeftToe'), bones.index('RightToe')], ...]
toes_ratio = [0, 0]


def compute_theta_blend_factors(theta_prime):
    """Locate an angle on the periodic theta grid.

    The angle is first wrapped onto one full turn measured from the first
    grid value (-pi), so any real input is accepted. The cell index and the
    blend factor are then taken from the same grid coordinate, which keeps
    the factor inside [0, 1) even when rounding places the angle a hair
    below a grid value. This matches the lower index / upper index / alpha
    computation used when training the value function.

    Args:
        theta_prime: angle in radians (scalar).

    Returns:
        (lower_id, upper_id, t): indices of the two surrounding grid values
        and the interpolation factor towards `upper_id`.
    """
    # offset from the first grid value, wrapped into [0, 2*pi)
    wrapped = (float(theta_prime) + np.pi) % (2 * np.pi)
    coord = wrapped / theta_spacing
    cell = int(np.floor(coord))

    lower_id = cell % theta_count
    upper_id = (lower_id + 1) % theta_count
    t = coord - cell

    return lower_id, upper_id, t


def render(frame, on_spot=False, display_debug=False):
    """Advance the motion-field controller by one step and draw the result.

    Reads the gamepad, evaluates one candidate action per neighbour with the
    selected value function, applies the best one to the global state, fixes
    the feet with IK and submits the draw commands to the viewer.

    Mutates the module-level ``current_x``, ``current_v``, ``toes_positions``
    and ``toes_ratio`` in place.

    Args:
        frame: Value of the driving widget. It is not read here; a change of
            this value is what makes ``interact`` call the function again.
        on_spot: When true, the root position is overwritten with
            ``[0, 0, 200]`` before drawing.
        display_debug: When true, the ``K_NEIGHBORS`` neighbour poses are
            drawn as well, shaded by their blend weight.
    """
    # --- Controller input -------------------------------------------------
    # Identity rotation unless the stick leaves the dead zone; otherwise a
    # rotation by the stick angle (components 0 and 2 of the quaternion,
    # presumably w and y in a (w, x, y, z) layout).
    controller_dir = np.array([1,0,0,0], dtype=np.float32)
    posx = gamepad.axes[0].value
    posz = -gamepad.axes[1].value
    if np.abs(posx) > 0.001 or np.abs(posz) > 0.001:
        # np.arctan2 is available in every NumPy release, whereas the
        # np.atan2 alias only exists from NumPy 2.0 onwards.
        angle = np.arctan2(posz, posx)
        controller_dir[0] = np.cos(angle/2)
        controller_dir[2] = np.sin(angle/2)

    # Button 1 switches from the walk value function to the jog one and
    # changes the colour of the direction marker accordingly.
    color = np.array([1,1,0], dtype=np.float32)
    value_function = value_function_walk
    if gamepad.buttons[1].value > 0.5:
        color = np.array([1,0,1], dtype=np.float32)
        value_function = value_function_jog

    desired_direction = lab.utils.quat_mul_vec(controller_dir, np.array([0,0,1], dtype=np.float32))

    # --- Perturbation ("push") while button 0 is held -----------------------
    if gamepad.buttons[0].value > 0.5:
        rot = lab.utils.angle_axis_to_quat(.2, np.array([0,0,-1]))
        # Bone quaternions are stored two rows after the bone index.
        spine_rows = [bones.index(name) + 2 for name in ('Spine', 'Spine1', 'Spine2')]

        for row in spine_rows:
            current_v[row] = lab.utils.quat_mul(rot, current_v[row])
        current_v[1, :3] += lab.utils.quat_mul_vec(current_v[2], np.array([0,0,-10]))

        for row in spine_rows:
            current_x[row] = lab.utils.quat_mul(rot, current_x[row])
        # NOTE(review): the pose offset below is rotated by current_v[2]
        # rather than current_x[2]; this looks like it may be a copy-paste
        # slip, left unchanged pending confirmation.
        current_x[1, :3] += lab.utils.quat_mul_vec(current_v[2], np.array([0,0,-10]))

    # --- Candidate actions --------------------------------------------------
    # Nearest stored states to the current one, with their blend weights.
    indices, weights = get_k_neighbors(current_x, current_v, K_NEIGHBORS)

    # One candidate action per neighbour: favour that neighbour, integrate
    # one step, and record the resulting state features and the signed angle
    # between the resulting facing direction and the desired direction.
    batched_query = np.zeros([K_NEIGHBORS, FEATURE_SHAPE[0], 3], dtype=np.float32)
    batched_theta_prime = np.zeros([K_NEIGHBORS], dtype=np.float32)
    for n_idx in range(K_NEIGHBORS):
        w = weights.copy()
        w[n_idx] = 1.0
        w /= np.sum(w)

        x, v = compute_new_state(current_x, indices, w, TUG_RATIO_LEARNING)
        n = pose_unpack(x)
        batched_query[n_idx, ...] = build_distance_metric(x, v)

        orientation = lab.utils.quat_mul_vec(n.quats[0], np.array([0,0,1], dtype=np.float32))
        batched_theta_prime[n_idx] = signed_angle(orientation, desired_direction)

    next_indices, next_weights = get_batched_k_neighbors(batched_query, K_NEIGHBORS)

    # Value of each candidate: interpolate the value function over the
    # candidate's own neighbours, then between the two surrounding
    # direction bins.
    future_rewards = np.zeros(K_NEIGHBORS)
    for n_idx in range(K_NEIGHBORS):
        d_lower_id, d_upper_id, d_t = compute_theta_blend_factors(batched_theta_prime[n_idx])

        dl = np.sum(value_function[next_indices[n_idx], d_lower_id] * next_weights[n_idx])
        du = np.sum(value_function[next_indices[n_idx], d_upper_id] * next_weights[n_idx])

        future_rewards[n_idx] = (1.0-d_t) * dl + d_t * du

    # --- Apply the best action ----------------------------------------------
    best_i = np.argmax(future_rewards)
    weights[best_i] = 1.0
    weights /= np.sum(weights)
    current_x[...], current_v[...] = compute_new_state(current_x, indices, weights, .1)

    # Blended per-foot contact value of the neighbours (index 0 is used for
    # the left foot below, index 1 for the right foot).
    lock = np.sum(states_c[indices] * weights[:, None], axis=0)

    # --- Foot locking ---------------------------------------------------------
    q, p = pose_to_qp(current_x)
    gq, gp = lab.utils.quat_fk(q, p, parents)

    id_toes = [bones.index('LeftToe'), bones.index('RightToe')]
    id_feet = [bones.index('LeftFoot'), bones.index('RightFoot')]

    for side, toe_id in enumerate(id_toes):
        if lock[side] < 0.9:
            # Foot is free: ease the stored toe target toward the animated
            # toe position, reaching it after five steps.
            toes_ratio[side] = min(1.0, toes_ratio[side] + 0.2)
            toes_positions[side, ...] = (
                toes_positions[side, ...] * (1.0-toes_ratio[side])
                + gp[toe_id] * toes_ratio[side]
            )
        else:
            # Foot is in contact: keep the stored toe target and restart
            # the ease-in for the next release.
            toes_ratio[side] = 0

    # Rebuild the foot targets from the toe targets, keeping the animated
    # foot-to-toe offset, and solve the legs toward them.
    relative_foot = lab.utils.qp_mul(lab.utils.qp_inv((gq[id_toes], gp[id_toes])), ((gq[id_feet], gp[id_feet])))
    new_foot = lab.utils.qp_mul((gq[id_toes], toes_positions), relative_foot)

    # limb_ik is called with a leading axis of length one, removed again below.
    q, p = lab.utils.limb_ik(q[None,...], p[None,...], parents, bones, new_foot[0][None,...], new_foot[1][None,...])
    q = q[0]
    p = p[0]

    if on_spot:
        p[0] = [0, 0, 200]

    # --- Drawing --------------------------------------------------------------
    a = lab.utils.quat_to_mat(q, p)
    viewer.set_shadow_poi(p[0])

    character.materials()[0].set_albedo(np.array([.4, .4, .8], dtype=np.float32))
    displacement_asset.materials()[0].set_albedo(color)

    viewer.begin_shadow()
    viewer.draw(character, a)
    viewer.end_shadow()

    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, a)

    # Direction marker placed at the root, oriented by the controller.
    d = lab.utils.quat_to_mat(controller_dir, p[0])
    viewer.draw(displacement_asset, d)

    if display_debug:
        # Draw every neighbour pose, grey-shaded by its blend weight and
        # spread out by overwriting entry [0, 3] of its first matrix.
        for i in range(K_NEIGHBORS):
            neighbor_q, neighbor_p = pose_to_qp(states_x[indices[i]])
            neighbor_mats = lab.utils.quat_to_mat(neighbor_q, neighbor_p)
            character.materials()[0].set_albedo(np.ones(3, dtype=np.float32) * float(weights[i]))
            neighbor_mats[0, 0, 3] = i * 40
            viewer.draw(character, neighbor_mats)

    viewer.end_display()

    viewer.disable(depth_test=True)

    viewer.execute_commands()
    
# Hook `render` up to a Play widget: `interact` passes the widget's value as
# `frame` and calls `render` again each time that value changes. The Play
# interval is 1000/5 = 200 (milliseconds, per the ipywidgets convention).
interact(
    render, 
    frame=widgets.Play(interval=1000./5.), #lab.Timeline(max=100),
)
# Last expression of the cell, so the notebook shows the viewer widget.
viewer