# Near-optimal Character Animation with Continuous Control

by Adrien Treuille, Yongjoon Lee, Zoran Popovic,
2007

Notebook by Jerome Eippers, 2025

In [None]:
%matplotlib widget
import pickle
from dataclasses import dataclass, field
from random import randrange
import numpy as np
from ipywidgets import widgets, interact
from matplotlib import pyplot as plt
from matplotlib import colors as plt_color
import ipyanimlab as lab
from scipy.optimize import linprog

# Shared viewer widget; every render() cell in this notebook draws into it.
viewer = lab.Viewer(move_speed=5, width=1280, height=720)

## Load

In [None]:
# load the character
character = viewer.import_usd_asset('AnimLabSimpleMale.usd')

# Add heel/ball helper bones parented to each foot. Their indices are used
# below as foot-contact tags.
# NOTE(review): arguments look like (name, quaternion, offset, parent) with
# an identity quaternion -- confirm against the ipyanimlab API.
character.add_bone('LeftHeel', np.array([1,0,0,0]), np.array([9.2,0,-12]), 'LeftFoot')
character.add_bone('LeftBall', np.array([1,0,0,0]), np.array([14.5,0,8.22]), 'LeftFoot')
character.add_bone('RightHeel', np.array([1,0,0,0]), np.array([-9.2,0,-12]), 'RightFoot')
character.add_bone('RightBall', np.array([1,0,0,0]), np.array([-14.5,0,8.22]), 'RightFoot')

left_heel = character.bone_index('LeftHeel')
left_ball = character.bone_index('LeftBall')
right_heel = character.bone_index('RightHeel')
right_ball = character.bone_index('RightBall')
left_foot = character.bone_index('LeftFoot')
right_foot = character.bone_index('RightFoot')
left_toe = character.bone_index('LeftToe')
right_toe = character.bone_index('RightToe')

# Order: [left heel, left ball, right heel, right ball].
# np.intp is NumPy's native index type; the previous int8 could not hold a
# bone index above 127.
foottag_indices = np.asarray([left_heel, left_ball, right_heel, right_ball], dtype=np.intp)
print(foottag_indices)

In [None]:
# Mesh drawn by the gamepad-driven render cells, oriented by the controller quaternion.
direction = viewer.import_usd_asset('../../meshes/displacement.usd')

In [None]:
# Flat marker mesh drawn at each foot-contact constraint: ten rim vertices
# plus a centre vertex (index 10), joined into five triangles.
# NOTE(review): each vertex carries 10 floats -- presumably position, normal
# and RGBA colour; confirm against viewer.create_asset.
target = viewer.create_asset(
    vertices = np.asarray([[9.37, 0.58, -6.81, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], [3.58, 0.58, -11.01, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], [-3.58, 0.58, -11.01, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], [-9.37, 0.58, -6.81, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], [-11.58, 0.58, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], [-9.37, 0.58, 6.81, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], [-3.58, 0.58, 11.01, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], [3.58, 0.58, 11.01, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], [9.37, 0.58, 6.81, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], [11.58, 0.58, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.58, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]], dtype=np.float32),
    indices = np.asarray([[1, 2, 10], [3, 4, 10], [5, 6, 10], [7, 8, 10], [9, 0, 10]], dtype=np.int16)
)

In [None]:
import pickle
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load animation data that comes from a trusted source.
with open('near_optimal_character_animations.dat', 'rb') as f:
    # Each entry of all_clips is later unpacked as (quats, pos, l_ranges, r_ranges).
    (bones, parents, all_clips) = pickle.load(f)
    
# Bone count of the loaded character (includes the helper bones added above).
bone_count = character.bone_count()

In [None]:
def render(frame):
    """Draw one frame of the first raw animation in ``all_clips``.

    ``frame`` indexes the first axis of the clip's quaternion and position
    arrays.
    """
    raw_quats = all_clips[0][0]
    raw_positions = all_clips[0][1]
    frame_q = raw_quats[frame, ...]
    frame_p = raw_positions[frame, ...]

    pose = lab.utils.quat_to_mat(frame_q, frame_p)
    viewer.set_shadow_poi(frame_p[0])

    # shadow pass
    viewer.begin_shadow()
    viewer.draw(character, pose)
    viewer.end_shadow()

    # colour pass
    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, pose)
    viewer.end_display()

    # skeleton overlay, drawn with the depth test disabled
    viewer.disable(depth_test=True)
    viewer.draw_axis(character.world_skeleton_xforms(pose), 5)
    viewer.draw_lines(character.world_skeleton_lines(pose))

    viewer.execute_commands()


# The timeline spans every frame of the first raw animation.
interact(
    render,
    frame=lab.Timeline(max=all_clips[0][0].shape[0]-1)
)
viewer

## Motion Model

Precompute all the Clips, their timings, and the contact constraints

In [None]:
# Number of frames reserved for every step clip.
MAX_STEP_LEN = 70
# Left-step clips come first in the clip tables, right-step clips after them.
LEFT_CLIP_COUNT = sum(l_ranges.shape[0] for _, _, l_ranges, _ in all_clips)
RIGHT_CLIP_COUNT = sum(r_ranges.shape[0] for _, _, _, r_ranges in all_clips)
CLIP_COUNT = LEFT_CLIP_COUNT + RIGHT_CLIP_COUNT
CLIP_COUNT

In [None]:
# Per-clip local rotations, initialised to the identity quaternion (1, 0, 0, 0).
clips_q = np.zeros((CLIP_COUNT, MAX_STEP_LEN, bone_count, 4), dtype=np.float32)
clips_q[..., 0] = 1
# Per-clip local positions, initialised to zero.
clips_p = np.zeros((CLIP_COUNT, MAX_STEP_LEN, bone_count, 3), dtype=np.float32)
# Columns 0-3: range markers relative to the clip start; column 4: clip start
# frame in the source animation.
clips_timings = np.zeros((CLIP_COUNT, 5), dtype=np.uint32)

In [None]:
def compute_root(q, p):
    """Replace the root orientation with a yaw-only rotation and re-root the clip.

    The heading is taken from the Hips bone: its global rotation is applied
    to the (0, 1, 0) axis and the X/Z components of the result give the yaw.
    Bone 0's global rotation is then replaced by a pure rotation about Y with
    that yaw. Local transforms are recomputed with ``quat_ik`` so the other
    bones keep their global pose. Finally the root track is expressed
    relative to its own first frame.

    Args:
        q: local quaternions, indexed as [frame, bone, 4].
        p: local positions, indexed as [frame, bone, 3].

    Returns:
        The re-rooted ``(q, p)`` local transforms.
    """
    hips = character.bone_index('Hips')
    g_q, g_p = lab.utils.quat_fk(q, p, parents)
    v = lab.utils.quat_mul_vec(g_q[:, hips, :], np.array([0,1,0]))
    # np.arctan2 is available in every NumPy release (np.atan2 only in >= 2.0).
    angle = np.arctan2(v[:, 0], v[:, 2])
    # Clear all four components first: leftover x/z values would make the
    # result a non-unit quaternion that is not a pure yaw.
    g_q[:, 0, :] = 0
    g_q[:, 0, 0] = np.cos(angle/2)
    g_q[:, 0, 2] = np.sin(angle/2)
    q, p = lab.utils.quat_ik(g_q, g_p, parents)
    # Express every root transform relative to the root of frame 0.
    q[:, 0, :], p[:, 0, :] = lab.utils.qp_mul(lab.utils.qp_inv((q[0:1, 0, :], p[0:1, 0, :])), (q[:, 0, :], p[:, 0, :]))
    return q, p

def compute_clip(quats, pos, ranges):
    """Cut one step out of a source animation and pad it to MAX_STEP_LEN frames.

    Frames ``ranges[0]:ranges[4]`` are copied into fixed-size buffers. Unused
    trailing frames keep identity rotations and zero positions. The root
    track is made relative to the first frame, then ``compute_root`` is
    applied.

    Returns:
        ``(q, p)`` as returned by ``compute_root``.
    """
    first, last = ranges[0], ranges[4]
    src_q = quats[first:last, ...].copy()
    src_p = pos[first:last, ...].copy()

    # Fixed-size buffers: identity quaternions and zero positions.
    q = np.zeros((MAX_STEP_LEN, bone_count, 4), dtype=np.float32)
    q[..., 0] = 1
    p = np.zeros((MAX_STEP_LEN, bone_count, 3), dtype=np.float32)

    q[:src_q.shape[0], ...] = src_q
    p[:src_p.shape[0], ...] = src_p

    # Move the root track into the frame of its first sample.
    inv_q, inv_p = lab.utils.qp_inv((q[0, 0], p[0, 0]))
    rooted = lab.utils.qp_mul(
        (inv_q[np.newaxis, ...], inv_p[np.newaxis, ...]),
        (q[:, 0], p[:, 0])
    )
    q[:, 0], p[:, 0] = rooted

    return compute_root(q, p)

In [None]:
left_clip_count = 0
right_clip_count = 0

# Left-step clips fill slots [0, LEFT_CLIP_COUNT); right-step clips fill the
# slots after them.
for clip in all_clips:
    quats, pos, l_ranges, r_ranges = clip

    for ranges in l_ranges:
        slot = left_clip_count
        left_anim = compute_clip(quats, pos, ranges)
        clips_q[slot, ...], clips_p[slot, ...] = left_anim
        # column 4 keeps the absolute start; columns 0-3 are relative to it
        clips_timings[slot, 4] = ranges[0]
        clips_timings[slot, :4] = ranges[1:] - ranges[0]
        left_clip_count += 1

    for ranges in r_ranges:
        slot = right_clip_count + LEFT_CLIP_COUNT
        right_anim = compute_clip(quats, pos, ranges)
        clips_q[slot, ...], clips_p[slot, ...] = right_anim
        clips_timings[slot, 4] = ranges[0]
        clips_timings[slot, :4] = ranges[1:] - ranges[0]
        right_clip_count += 1

In [None]:
def compute_constraint_qp(gpos, frame, foot_id, toe_id):
    """Build the ground-plane transform of a foot contact.

    Args:
        gpos: global bone positions, indexed as [frame, bone, xyz].
        frame: frame at which the contact is sampled.
        foot_id: bone whose X/Z position becomes the contact position.
        toe_id: bone that defines the heading together with ``foot_id``.

    Returns:
        ``(q, p)`` as float32 arrays: a quaternion (w, x, y, z) rotating about
        Y by the foot-to-toe heading, and a position with Y set to zero.
    """
    anchor = gpos[frame, foot_id, :]
    heading = gpos[frame, toe_id, :] - anchor
    half_yaw = np.arctan2(heading[0], heading[2]) / 2
    q = np.array([np.cos(half_yaw), 0, np.sin(half_yaw), 0], dtype=np.float32)
    p = np.array([anchor[0], 0, anchor[2]], dtype=np.float32)
    return q, p

# Two contact constraints per clip: slot 0 is sampled at timing column 0,
# slot 1 at timing column 2.
clips_constraints_q = np.zeros((clips_q.shape[0], 2, 4), dtype=np.float32)
clips_constraints_p = np.zeros((clips_q.shape[0], 2, 3), dtype=np.float32)

left_tags = (foottag_indices[0], foottag_indices[1])
right_tags = (foottag_indices[2], foottag_indices[3])

for i in range(clips_q.shape[0]):
    _, gpos = lab.utils.quat_fk(clips_q[i, ...], clips_p[i, ...], parents)
    # Left clips use the left foot for slot 0 and the right foot for slot 1;
    # right clips use the opposite order.
    if i < LEFT_CLIP_COUNT:
        first_tags, second_tags = left_tags, right_tags
    else:
        first_tags, second_tags = right_tags, left_tags

    for slot, timing_column, tags in ((0, 0, first_tags), (1, 2, second_tags)):
        clips_constraints_q[i, slot, :], clips_constraints_p[i, slot, :] = compute_constraint_qp(
            gpos,
            clips_timings[i, timing_column],
            tags[0],
            tags[1]
        )

In [None]:
def render(frame, clip_id=0):
    """Draw one frame of a precomputed step clip with its two contact markers.

    Args:
        frame: frame index inside the clip (0 .. MAX_STEP_LEN - 1).
        clip_id: index into the precomputed clip tables.
    """
    q = clips_q[clip_id, frame, :].copy()
    p = clips_p[clip_id, frame, :].copy()

    a = lab.utils.quat_to_mat(q, p)
    viewer.set_shadow_poi(p[0])

    viewer.begin_shadow()
    viewer.draw(character, a)
    viewer.end_shadow()

    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, a)

    # One matrix per contact constraint of this clip. The identity
    # placeholder that used to precede this line was dead code, overwritten
    # immediately.
    contacts_matrices = lab.utils.quat_to_mat(clips_constraints_q[clip_id, ...], clips_constraints_p[clip_id, ...])

    viewer.draw(target, contacts_matrices)

    viewer.end_display()

    # overlays are drawn with the depth test disabled
    viewer.disable(depth_test=True)

    viewer.draw_axis(contacts_matrices, 20)
    viewer.draw_axis(character.world_skeleton_xforms(a), 5)
    viewer.draw_lines(character.world_skeleton_lines(a))

    viewer.execute_commands()
    # show the timing row of the selected clip under the viewer
    display(clips_timings[clip_id])

interact(
    render,
    frame=lab.Timeline(max=MAX_STEP_LEN-1),
    clip_id=widgets.IntSlider(max=CLIP_COUNT-1)
)
viewer

## Control

In [None]:
class ClipPlayer:
    """Playback state for one step clip, holding its own copy of the clip data."""

    def __init__(self, clip_id):
        self.clip_id = clip_id
        # -1 means "not started"; the first tick() moves to frame 0
        self.frame = -1
        # blend window, filled in by align_to_out()
        self.start_at_out_frame = 0
        self.start_clip_frame = 0
        self.blend_in_frame_count = 0
        # private copies, so aligning never touches the shared clip tables
        self.quaternions = clips_q[clip_id, :, :].copy()
        self.positions = clips_p[clip_id, :, :].copy()

    def align_to_out(self, out_clip):
        """Place this clip after ``out_clip`` and compute the blend window.

        The blend runs for ``pre + post`` frames around the shared contact:
        ``pre`` frames before outgoing timing column 2 / incoming column 0,
        and ``post`` frames after it. Each is clamped to what both clips
        provide. The root track is then moved so that this clip's first
        contact constraint coincides with the outgoing clip's second one.
        """
        out_t = clips_timings[out_clip.clip_id]
        in_t = clips_timings[self.clip_id]

        pre = min(out_t[2] - out_t[1], in_t[0])
        post = min(out_t[3] - out_t[2], in_t[1] - in_t[0])

        self.start_at_out_frame = out_t[2] - pre
        self.start_clip_frame = in_t[0] - pre
        self.blend_in_frame_count = pre + post
        self.frame = self.start_clip_frame

        # Target transform: outgoing clip's frame-0 root combined with its
        # second contact constraint.
        target_q, target_p = lab.utils.qp_mul(
            (out_clip.quaternions[0, 0], out_clip.positions[0, 0]),
            (clips_constraints_q[out_clip.clip_id, 1], clips_constraints_p[out_clip.clip_id, 1]),
        )
        # First cancel this clip's own first constraint, then apply the target.
        own_inv_q, own_inv_p = lab.utils.qp_inv(
            (clips_constraints_q[self.clip_id, 0], clips_constraints_p[self.clip_id, 0])
        )
        for xform_q, xform_p in ((own_inv_q, own_inv_p), (target_q, target_p)):
            self.quaternions[:, 0], self.positions[:, 0] = lab.utils.qp_mul(
                (xform_q[np.newaxis, ...], xform_p[np.newaxis, ...]),
                (self.quaternions[:, 0], self.positions[:, 0])
            )

    def tick(self, forced_frame=None):
        """Advance one frame, or jump to ``forced_frame`` when it is given.

        Without a forced frame, playback stops at the last frame of the clip
        (timing column 3 minus one).
        """
        if forced_frame is not None:
            self.frame = forced_frame
            return
        last_frame = clips_timings[self.clip_id, 3] - 1
        if self.frame < last_frame:
            self.frame += 1

        
class Player:
    """Plays step clips back to back, cross-fading from the current clip to the next.

    Attributes:
        current_clip: ClipPlayer being played, or None before the first clip.
        next_clip: ClipPlayer queued for blending in, or None.
        quaternions, positions: local pose produced by the last tick().
        last_clip_position: root position at the time of the last
            set_next_clip() call.
    """

    def __init__(self):
        self.current_clip = None
        self.next_clip = None
        self.quaternions = np.array([1,0,0,0], dtype=np.float32)[np.newaxis,...].repeat(bone_count, axis=0)
        self.positions = np.zeros([bone_count, 3], dtype=np.float32)
        self.last_clip_position = np.zeros([3], dtype=np.float32)

    def set_next_clip(self, clip_id):
        """Start ``clip_id`` immediately if nothing plays, else queue and align it."""
        if self.current_clip is None:
            self.current_clip = ClipPlayer(clip_id)
        else:
            self.next_clip = ClipPlayer(clip_id)
            self.next_clip.align_to_out(self.current_clip)
        # Copy: self.positions may be a view into a clip buffer, and a stored
        # view would silently track that buffer.
        self.last_clip_position = self.positions[0, :].copy()

    def tick(self):
        """Advance playback by one frame and update the output pose."""
        current = self.current_clip
        if current is None:
            return

        current.tick()
        self.quaternions = current.quaternions[current.frame]
        self.positions = current.positions[current.frame]

        incoming = self.next_clip
        if incoming is None or current.frame < incoming.start_at_out_frame:
            return

        # Plain ints avoid unsigned wrap-around: clip timings are stored as uint32.
        elapsed = int(current.frame) - int(incoming.start_at_out_frame)
        incoming.tick(int(incoming.start_clip_frame) + elapsed)

        # Blend weight in [0, 1]. A zero-length blend window switches at once
        # instead of dividing by zero.
        blend_len = int(incoming.blend_in_frame_count)
        t = elapsed / blend_len if blend_len > 0 else 1.0

        self.quaternions = lab.utils.quat_slerp(current.quaternions[current.frame], incoming.quaternions[incoming.frame], t)
        self.positions = (1.0-t) * current.positions[current.frame] + (t) * incoming.positions[incoming.frame]

        # Hand over once the blend is done, the incoming clip reached its
        # timing column 1, or the outgoing clip ran out of frames.
        blend_done = t >= .99
        incoming_ready = incoming.frame >= clips_timings[incoming.clip_id, 1]
        outgoing_finished = current.frame >= clips_timings[current.clip_id, 3] - 1
        if blend_done or incoming_ready or outgoing_finished:
            self.current_clip = incoming
            self.next_clip = None
                    

In [None]:
player = Player()
player.set_next_clip(0)

def render(frame):
    """Advance the player one tick, chaining random clips of alternating feet.

    ``frame`` is not read; the timeline only triggers the redraw.
    """
    # Queue a random clip from the opposite foot's range when nothing is queued.
    if player.next_clip is None:
        if player.current_clip.clip_id < LEFT_CLIP_COUNT:
            candidate = randrange(LEFT_CLIP_COUNT, CLIP_COUNT)
        else:
            candidate = randrange(0, LEFT_CLIP_COUNT)
        player.set_next_clip(candidate)

    player.tick()

    pose = lab.utils.quat_to_mat(player.quaternions, player.positions)
    viewer.set_shadow_poi(player.positions[0])

    viewer.begin_shadow()
    viewer.draw(character, pose)
    viewer.end_shadow()

    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, pose)
    viewer.end_display()

    viewer.disable(depth_test=True)

    # current clip skeleton in red
    current = player.current_clip
    overlay = lab.utils.quat_to_mat(current.quaternions[current.frame], current.positions[current.frame])
    viewer.draw_lines(character.world_skeleton_lines(overlay), np.array([1,0,0], dtype=np.float32))
    viewer.draw_axis(character.world_skeleton_xforms(overlay), 5)

    # queued clip skeleton in green
    incoming = player.next_clip
    if incoming is not None:
        f = max(incoming.frame, 0)
        overlay = lab.utils.quat_to_mat(incoming.quaternions[f], incoming.positions[f])
        viewer.draw_lines(character.world_skeleton_lines(overlay), np.array([0,1,0], dtype=np.float32))
        viewer.draw_axis(character.world_skeleton_xforms(overlay), 5)

    viewer.execute_commands()

interact(
    render,
    frame=lab.Timeline(max=MAX_STEP_LEN-1)
)
viewer

### Cost

In [None]:
%%time

# Output widget that shows progress of the long-running cells.
progress_output = widgets.Output(layout={'border': '1px solid black'})
display(progress_output)

# Pairwise tables indexed as [outgoing clip, incoming clip].
clip_total = clips_q.shape[0]
physics_costs = np.zeros((clip_total, clip_total))
delta_theta = np.zeros((clip_total, clip_total))
delta_x = np.zeros((clip_total, clip_total))
delta_z = np.zeros((clip_total, clip_total))

# Bones compared when scoring how well two clips overlap.
bone_filter = [
    bones.index(bone_name)
    for bone_name in (
        "Hips", "Spine2",
        "LeftUpLeg", "LeftLeg", "LeftFoot",
        "RightUpLeg", "RightLeg", "RightFoot",
    )
]

def pre_compute_transitions_costs():
    """Fill the pairwise transition tables and save them to disk.

    For every ordered pair (i, j), clip j is aligned after clip i, then:

    * ``physics_costs[i, j]`` is the distance between the two clips over the
      blend window: per frame, the summed Euclidean distance of the bones in
      ``bone_filter``, averaged over the window.
    * ``delta_x``, ``delta_z`` and ``delta_theta`` are clip j's root position
      (scaled by .01) and yaw at its timing column 1, in the frame of clip
      i's root at clip i's timing column 1.

    Costs are divided by their mean. Same-foot transitions, and pairs with an
    empty blend window, are then set to 10000. The result is pickled to
    'near_optimal_character_animation_transitions.dat'.
    """
    clip_count = clips_q.shape[0]
    # pairs whose blend window is empty and so have no frames to compare
    no_blend = np.zeros((clip_count, clip_count), dtype=bool)

    for i in range(clip_count):
        with progress_output:
            progress_output.clear_output()
            display(f"clip {i} / {clips_q.shape[0]}")

        # Re-root the outgoing clip so its root at timing column 1 is the origin.
        a = ClipPlayer(i)
        pivot = int(clips_timings[i, 1])
        a.quaternions[:, 0], a.positions[:, 0] = lab.utils.qp_mul(
            lab.utils.qp_inv((a.quaternions[pivot:pivot+1, 0], a.positions[pivot:pivot+1, 0])),
            (a.quaternions[:, 0], a.positions[:, 0])
        )
        aq, ap = lab.utils.quat_fk(a.quaternions, a.positions, parents)

        for j in range(clip_count):

            b = ClipPlayer(j)
            b.align_to_out(a)
            bq, bp = lab.utils.quat_fk(b.quaternions, b.positions, parents)

            # compute cost over the whole blend window at once
            blend_len = int(b.blend_in_frame_count)
            if blend_len > 0:
                out_start = int(b.start_at_out_frame)
                in_start = int(b.start_clip_frame)
                diff = (
                    ap[out_start:out_start + blend_len][:, bone_filter]
                    - bp[in_start:in_start + blend_len][:, bone_filter]
                )
                physics_costs[i, j] = np.linalg.norm(diff, axis=-1).sum() / blend_len
            else:
                # dividing by zero here would turn the normalisation mean into NaN
                physics_costs[i, j] = 0.0
                no_blend[i, j] = True

            # compute deltas
            end = int(clips_timings[j, 1])
            delta_x[i, j] = bp[end, 0, 0] * .01
            delta_z[i, j] = bp[end, 0, 2] * .01
            # yaw of a (w, x, y, z) quaternion rotating about Y
            w, y = bq[end, 0, 0], bq[end, 0, 2]
            delta_theta[i, j] = np.arctan2(2 * w * y, 1.0 - (2 * y * y))

    # normalize (ignoring pairs that had nothing to compare)
    physics_costs[:, :] /= physics_costs[~no_blend].mean()
    physics_costs[no_blend] = 10000

    # discard the impossible transitions between feet
    physics_costs[0:LEFT_CLIP_COUNT, 0:LEFT_CLIP_COUNT] = 10000
    physics_costs[LEFT_CLIP_COUNT:, LEFT_CLIP_COUNT:] = 10000

    with open('near_optimal_character_animation_transitions.dat', 'wb') as f:
        pickle.dump((physics_costs, delta_theta, delta_x, delta_z), f)

pre_compute_transitions_costs()

In [None]:
player = Player()
player.set_next_clip(0)

def render(frame):
    """Advance the player one tick, always queuing the cheapest transition.

    ``frame`` is not read; the timeline only triggers the redraw.
    """
    if player.next_clip is None:
        cheapest = np.argmin(physics_costs[player.current_clip.clip_id, :])
        player.set_next_clip(int(cheapest))

    player.tick()

    pose = lab.utils.quat_to_mat(player.quaternions, player.positions)
    viewer.set_shadow_poi(player.positions[0])

    viewer.begin_shadow()
    viewer.draw(character, pose)
    viewer.end_shadow()

    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, pose)
    viewer.end_display()

    viewer.disable(depth_test=True)

    # current clip skeleton in red
    current = player.current_clip
    overlay = lab.utils.quat_to_mat(current.quaternions[current.frame], current.positions[current.frame])
    viewer.draw_lines(character.world_skeleton_lines(overlay), np.array([1,0,0], dtype=np.float32))
    viewer.draw_axis(character.world_skeleton_xforms(overlay))

    incoming = player.next_clip
    if incoming is None:
        display((current.clip_id))
    else:
        # queued clip skeleton in green, plus the (from, to, cost) triple
        f = max(incoming.frame, 0)
        overlay = lab.utils.quat_to_mat(incoming.quaternions[f], incoming.positions[f])
        viewer.draw_lines(character.world_skeleton_lines(overlay), np.array([0,1,0], dtype=np.float32))

        display((current.clip_id, incoming.clip_id, physics_costs[current.clip_id, incoming.clip_id]))

    viewer.execute_commands()

interact(
    render,
    frame=lab.Timeline(max=MAX_STEP_LEN-1)
)
viewer

### Transition Function

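The transition function maps the current state $X = (C, x, z, \theta)$ and the chosen next clip $C'$ to the next state $X'$. It is given by: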

$$
f(X, C') = X' = \begin{bmatrix}
C' \\
x + \cos(\theta)\Delta x + \sin(\theta)\Delta z \\
z - \sin(\theta)\Delta x + \cos(\theta)\Delta z \\
\theta + \Delta \theta \\
\end{bmatrix}
$$


In [None]:
def transition (clip_id, x, z, theta, next_clip_id):
    """Apply the tabulated displacement of ``clip_id -> next_clip_id`` to a state.

    Args:
        clip_id: current clip index.
        x, z: current position.
        theta: current heading, in radians.
        next_clip_id: index of the next clip. It may be an int, an index
            array, or ``...`` to evaluate every candidate clip at once.

    Returns:
        ``(next_clip_id, x_prime, z_prime, theta_prime)``. ``theta_prime`` is
        wrapped once into [-pi, pi], which is enough when both ``theta`` and
        the tabulated delta lie in that interval.
    """
    dx = delta_x[clip_id, next_clip_id]
    dz = delta_z[clip_id, next_clip_id]
    sin_t = np.sin(theta)
    cos_t = np.cos(theta)

    x_prime = x + sin_t * dz + cos_t * dx
    z_prime = z + cos_t * dz - sin_t * dx

    # np.where also works for a scalar next_clip_id, where masked item
    # assignment on a NumPy scalar raises TypeError.
    theta_prime = np.asarray(theta + delta_theta[clip_id, next_clip_id])
    theta_prime = np.where(theta_prime < -np.pi, theta_prime + np.pi*2, theta_prime)
    theta_prime = np.where(theta_prime > np.pi, theta_prime - np.pi*2, theta_prime)
    return next_clip_id, x_prime, z_prime, theta_prime

## Greedy policy

In [None]:
def direction_cost (x, z, x_primes, z_primes):
    """Penalise steps that do not head along +Z.

    The cost is the absolute heading angle of the displacement
    ``(x_primes - x, z_primes - z)``, measured from the +Z axis. It is
    tripled wherever ``z_primes`` is below 0.1.

    Args:
        x, z: start position.
        x_primes, z_primes: candidate end positions (scalars or arrays).

    Returns:
        An ndarray of costs with the broadcast shape of the inputs. The
        inputs are not modified.
    """
    z_primes = np.asarray(z_primes)
    # np.arctan2 exists in every NumPy release (np.atan2 only in >= 2.0)
    cost = np.abs(np.arctan2(x_primes - x, z_primes - z))
    # np.where instead of masked assignment so scalar inputs work too
    return np.where(z_primes < .1, cost * 3, cost)

In [None]:
def greedy_policy(start_clip, x, theta, deviation_factor, physic_factor, direction_factor, debug=None):
    """Pick the next clip with the lowest immediate cost (no look-ahead).

    Every candidate clip is evaluated from state ``(start_clip, x, 0, theta)``.
    The cost is the weighted sum of lateral deviation ``|x'|``, the
    precomputed physics cost and the direction cost.

    Args:
        start_clip: clip currently playing.
        x: lateral offset of the character.
        theta: heading of the character, in radians.
        deviation_factor, physic_factor, direction_factor: cost weights.
        debug: optional dict, filled with the terms of the chosen transition.

    Returns:
        ``(picked_clip, (x_prime, z_prime, theta_prime, cost))`` for the
        chosen clip.
    """
    # `...` as the next clip evaluates all candidates in one call
    _, x_prime, z_prime, theta_prime = transition(start_clip, x, 0, theta, ...)

    deviation_c = deviation_factor * np.abs(x_prime)
    physic_c = physic_factor * physics_costs[start_clip, :]
    direction_c = direction_factor * direction_cost(x, 0, x_prime, z_prime)

    cost = deviation_c + physic_c + direction_c

    best = int(np.argmin(cost))

    if debug is not None:
        debug.update({
            "clip": start_clip,
            "x": x,
            "theta": theta,
            "picked_clip": best,
            "x_prime": x_prime[best],
            "z_prime": z_prime[best],
            "theta_prime": theta_prime[best],
            "cost": cost[best],
            "deviation_c": deviation_c[best],
            "physic_c": physic_c[best],
            "direction_c": direction_c[best],
        })

    outcome = (x_prime[best], z_prime[best], theta_prime[best], cost[best])
    return best, outcome

In [None]:
# Gamepad widget; the render cells below read axes 0 and 1 to get the wanted direction.
gamepad = widgets.Controller(index=0)
gamepad

In [None]:
player = Player()
player.set_next_clip(0)

debug = {}

def render(frame, deviation_factor=1.0, physic_factor=.8, direction_factor=1.0):
    """Advance the player one tick, steering with the gamepad via the greedy policy.

    Args:
        frame: not read; the timeline only triggers the redraw.
        deviation_factor, physic_factor, direction_factor: weights forwarded
            to ``greedy_policy``.
    """
    global debug

    # Wanted direction as a rotation about Y; identity while the stick is centred.
    controller_orient = np.array([1,0,0,0], dtype=np.float32)
    posx = gamepad.axes[0].value
    posz = -gamepad.axes[1].value
    if np.abs(posx) > 0.001 or np.abs(posz) > 0.001:
        # np.arctan2 exists in every NumPy release (np.atan2 only in >= 2.0)
        angle = np.arctan2(posz, posx)
        controller_orient[0] = np.cos(angle/2)
        controller_orient[2] = np.sin(angle/2)

    # Pick the next clip once the current one has reached its timing column 1.
    if player.next_clip is None and clips_timings[player.current_clip.clip_id, 1] <= player.current_clip.frame:
        f = clips_timings[player.current_clip.clip_id, 1]

        # heading of the character relative to the wanted direction
        q = lab.utils.quat_mul(lab.utils.quat_inv(controller_orient), player.current_clip.quaternions[f, 0, :])
        theta = np.arctan2(
            2 * q[0] * q[2],
            1.0 - (2 * q[2] * q[2])
        )

        # This cell always passes a zero lateral offset. The value previously
        # read from the root position was overwritten before use.
        x = 0

        next_clip_id, _ = greedy_policy(
            player.current_clip.clip_id,
            x,
            theta,
            deviation_factor,
            physic_factor,
            direction_factor,
            debug
        )
        player.set_next_clip(next_clip_id)

    player.tick()

    q = player.quaternions
    p = player.positions

    a = lab.utils.quat_to_mat(q, p)
    viewer.set_shadow_poi(p[0])

    viewer.begin_shadow()
    viewer.draw(character, a)
    viewer.end_shadow()

    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, a)

    # direction mesh at the character root, oriented by the controller
    d = lab.utils.quat_to_mat(controller_orient, p[0])
    viewer.draw(direction, d)

    viewer.end_display()

    viewer.disable(depth_test=True)

    # current clip skeleton in red, queued clip in green
    a = lab.utils.quat_to_mat(player.current_clip.quaternions[player.current_clip.frame], player.current_clip.positions[player.current_clip.frame])
    viewer.draw_lines(character.world_skeleton_lines(a), np.array([1,0,0], dtype=np.float32))
    viewer.draw_axis(character.world_skeleton_xforms(a))
    if player.next_clip is not None:
        f = max(player.next_clip.frame, 0)
        a = lab.utils.quat_to_mat(player.next_clip.quaternions[f], player.next_clip.positions[f])
        viewer.draw_lines(character.world_skeleton_lines(a), np.array([0,1,0], dtype=np.float32))

    viewer.execute_commands()

    #display(debug)
    
# Slider names must match render()'s parameters exactly. The previous
# `physics_factor` and `torso_factor` keywords matched none of them
# (render takes `physic_factor` and has no torso term), so those sliders
# were not wired to it.
interact(
    render,
    frame=lab.Timeline(max=MAX_STEP_LEN-1),
    deviation_factor=widgets.FloatSlider(1, min=0, max=1),
    # starts at render()'s own default for this weight
    physic_factor=widgets.FloatSlider(.8, min=0, max=1),
    direction_factor=widgets.FloatSlider(1, min=0, max=1),
)
viewer

## Optimal policy
#### Recursive Formulation

The recursive formulation of the cost is:

$$
V(X) = c_s(X) + c_t(X, X') + \alpha V(X')
$$

#### Value Function and Optimal Policy


Using the recursive relationship, we derive the Bellman equation for the value function:

$$
V(X) = c_s(X) + \min_{C' \in \mathcal{C}} \left[ c_t(X, X') + \alpha V(X') \right]
$$

The optimal policy $ \Pi^* $ then selects the action $ C' $ that minimizes the right-hand side:

$$
\Pi^*(X) = \arg\min_{C' \in \mathcal{C}} \left[ c_t(X, X') + \alpha V(X') \right]
$$


In [None]:
def compute_basis(x, theta):
    """Evaluate the nine polynomial basis functions of the value function.

    The basis is every product ``theta**i * x**j`` for i, j in {0, 1, 2},
    ordered with the power of ``theta`` varying slowest.

    Args:
        x: lateral offset, scalar or 1-D array.
        theta: heading, same shape as ``x``.

    Returns:
        Array of shape (9,) for scalar inputs, or (N, 9) for inputs of length N.
    """
    x_sq = x**2
    theta_sq = theta**2
    terms = [
        np.ones_like(x), x, x_sq,
        theta, theta*x, theta*x_sq,
        theta_sq, theta_sq*x, theta_sq*x_sq,
    ]
    return np.transpose(np.stack(terms))

# Number of basis functions per clip.
Basis_count = compute_basis(0, 0).shape[0]
Basis_count

In [None]:
def optimal_policy(coefficients, alpha, start_clip, x, theta, physic_factor=.9, direction_factor=.99, debug=None):
    """Pick the next clip that minimises immediate cost plus discounted future cost.

    Args:
        coefficients: value-function coefficients, one row of basis weights
            per clip.
        alpha: discount applied to the estimated future cost.
        start_clip: clip currently playing.
        x: lateral offset of the character.
        theta: heading of the character, in radians.
        physic_factor, direction_factor: weights of the immediate cost terms.
        debug: optional dict, filled with the terms of the chosen transition.

    Returns:
        ``(picked_clip, (x_prime, z_prime, theta_prime, cost))``, where
        ``cost`` is the immediate cost only (future cost excluded).
    """
    # `...` as the next clip evaluates all candidates in one call
    _, x_prime, z_prime, theta_prime = transition(start_clip, x, 0, theta, ...)

    physic_c = physic_factor * physics_costs[start_clip, :]
    direction_c = direction_factor * direction_cost(x, 0, x_prime, z_prime)
    cost = physic_c + direction_c

    # V(X') for every candidate: row-wise dot product of each clip's
    # coefficients with the basis evaluated at the state that clip leads to.
    next_basis = compute_basis(x_prime, theta_prime)
    future_cost = alpha * np.einsum('ij,ij->i', coefficients, next_basis)
    total_cost = cost + future_cost

    best = int(np.argmin(total_cost))

    if debug is not None:
        debug.update({
            "clip": start_clip,
            "x": x,
            "theta": theta,
            "picked_clip": best,
            "x_prime": x_prime[best],
            "z_prime": z_prime[best],
            "theta_prime": theta_prime[best],
            "cost": cost[best],
            "total_cost": total_cost[best],
            "physic_c": physic_c[best],
            "direction_c": direction_c[best],
            "future_cost": future_cost[best],
        })

    outcome = (x_prime[best], z_prime[best], theta_prime[best], cost[best])
    return best, outcome

### Learning the Value Function
#### Approximating the Value Function with Linear Programming

Maximize the value function $ V(X) $ across all sampled states $ X \in \bar{S} $:

$$
\max_{V} \sum_{X \in \bar{S}} V(X)  
$$
$$
V(X) \leq c_s(X) + c_t(X, X') + \alpha V(X') \quad \forall (X, X') \in L
$$

where $ \bar{S} $ is the set of sampled states.


The maximization is subject to the Bellman inequality above, which must hold for every transition $ (X, X') $ in the set of sampled transitions $ L $.


#### Basis Function Approximation

For continuous state spaces, computing the exact value function  $ V(X) $ is impractical. Instead, we approximate the value function using a set of basis functions.

We define a set of basis functions $ \Phi = [\phi_1, \phi_2, \dots, \phi_n] $, where each $ \phi_i : S \to \mathbb{R} $ is a function that can be evaluated in closed form. The value function $ V(X) $ is approximated as a linear combination of these basis functions:

$$
V(X) \approx r_1 \phi_1(X) + r_2 \phi_2(X) + \dots + r_n \phi_n(X) = \Phi(X) \cdot r
$$

where:
- $ r = [r_1, r_2, \dots, r_n] $ is the vector of coefficients to be determined.


Using the approximation in the original Bellman inequality, we get:

$$
\Phi(X) \cdot r \leq c_s(X) + c_t(X, X') + \alpha \Phi(X') \cdot r
$$


$$
\Phi(X) \cdot r - \alpha \Phi(X') \cdot r \leq c_s(X) + c_t(X, X')
$$


$$
\left( \Phi(X) - \alpha \Phi(X') \right) \cdot r \leq c_s(X) + c_t(X, X')
$$


#### Tabulation of coefficients for each clip

$$
\left( \Phi_c(X) - \alpha \Phi_{c'}(X') \right) \cdot r \leq c_s(X) + c_t(X, X')
$$

In [None]:
# Regular grid of sampled states: every clip x 5 lateral offsets x 5 headings.
sample_x = np.linspace(-2, 2, 5)
sample_theta = np.linspace(-np.pi, np.pi, 5)
samples = [
    (clip_index, offset, heading)
    for clip_index in range(clips_q.shape[0])
    for offset in sample_x
    for heading in sample_theta
]
len(samples)

In [None]:
%%time
display(progress_output)

def learn_locomotion_value_function(max_iter, alpha=0.95, stop_constraint_criteria = 2, deviation_factor=1, physic_factor=.7, direction_factor=1.0):
    """Fit per-clip value-function coefficients by iterated linear programming.

    Each epoch:
      1. runs ``optimal_policy`` with the current coefficients on every
         sample and records transitions not seen before;
      2. rebuilds one Bellman inequality per recorded transition,
         ``(Phi(X) - alpha * Phi(X')) . r <= cost + |x| * deviation_factor``;
      3. maximises the summed value over all samples under those
         inequalities, with the value constrained to be non-negative at
         every sample.

    Args:
        max_iter: maximum number of epochs.
        alpha: discount factor.
        stop_constraint_criteria: stop once an epoch adds fewer new
            transitions than this.
        deviation_factor: weight of the ``|x|`` state cost.
        physic_factor, direction_factor: forwarded to ``optimal_policy``.

    Returns:
        Array of shape (clip count, Basis_count) with the last successfully
        solved coefficients (all zeros if no solve succeeded).
    """
    clip_count = clips_q.shape[0]
    transitions = {}
    coefficients = np.zeros([clip_count, Basis_count])

    # Objective and positivity constraints depend only on the samples, so
    # they are built once instead of once per epoch.
    c = np.zeros_like(coefficients)
    positivity_rows = []
    for clip_id, x, theta in samples:
        basis = compute_basis(x, theta)
        c[clip_id, :] += basis
        row = np.zeros_like(coefficients)
        row[clip_id, :] = -basis  # -V(X) <= 0, i.e. V(X) >= 0
        positivity_rows.append(row.flatten())
    # linprog minimises, so negate to maximise the summed value
    c = -c.flatten()
    positivity_rhs = [0] * len(positivity_rows)

    with progress_output:

        progress_output.clear_output()

        for epoch in range(max_iter):
            progress_output.clear_output()
            display(f"epoch : [{epoch}], {len(transitions)} constraints")

            display("evaluating constraints")
            added_constraints = 0

            for i, (clip_id, x, theta) in enumerate(samples):
                target_clip, new_state = optimal_policy(coefficients, alpha, clip_id, x, theta, physic_factor, direction_factor)
                if (i, target_clip) not in transitions:
                    transitions[(i, target_clip)] = new_state
                    added_constraints += 1

            if added_constraints < stop_constraint_criteria :
                display(f"only {added_constraints} new constrains, we stop")
                break

            display(f"{added_constraints} new constrains")

            display("build constraints")
            A = []
            b = []

            for (sample_id, next_clip_id), new_state in transitions.items():
                clip_id, x, theta = samples[sample_id]
                new_x, _, new_theta, cost = new_state

                lhs = np.zeros_like(coefficients)
                lhs[clip_id, :] += compute_basis(x, theta)
                lhs[next_clip_id, :] -= alpha * compute_basis(new_x, new_theta)

                A.append(lhs.flatten())
                b.append(cost + np.abs(x) * deviation_factor)

            # make sure the cost function can only be positive
            A.extend(positivity_rows)
            b.extend(positivity_rhs)

            ######
            display("solving")

            # bounds=(None, None): coefficients are free in sign
            solution = linprog(c, A_ub=A, b_ub=b, bounds=(None, None))

            if solution.success:
                coefficients = solution.x.reshape(clip_count, Basis_count)

                display("solved")

            else :
                display("failed solve")
                display(solution)
                break

    return coefficients

coefficients_forward_0 = learn_locomotion_value_function(15)

In [None]:
# Clip whose learned value function is plotted.
clip_id = 3

# 40 x 40 grid over lateral offset (vx) and heading (vy).
vx, vy = np.meshgrid(np.linspace(-2, 2, 40), np.linspace(-3, 3, 40))
# Evaluate V = Phi(x, theta) . r for the whole grid in one matrix product
# instead of 1600 separate Python-level dot products.
z = (compute_basis(vx.ravel(), vy.ravel()) @ coefficients_forward_0[clip_id]).reshape(vx.shape)

import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(vx.flatten(), vy.flatten(), z.flatten(), label="left")
ax.set_xlabel("x")
ax.set_ylabel("theta")
ax.set_zlabel("V")
plt.show()

In [None]:
# Fresh player for the optimal-policy demo, starting on clip 0.
player = Player()
player.set_next_clip(0)
# Declared global by the render cell below; its debug hook is currently commented out.
debug_struct = {}

In [None]:
# Wanted position of the character; render() drags it along with the root.
controller_position = np.array([0,0,0], dtype=np.float32)

def render(frame, draw_debug=False):
    """Advance the player one tick, steering with the gamepad via the learned policy.

    Args:
        frame: not read; the timeline only triggers the redraw.
        draw_debug: when True, overlay the current (red) and queued (green)
            clip skeletons.
    """
    global debug_struct, controller_position

    # Wanted direction as a rotation about Y; identity while the stick is centred.
    controller_orient = np.array([1,0,0,0], dtype=np.float32)
    posx = gamepad.axes[0].value
    posz = -gamepad.axes[1].value
    if np.abs(posx) > 0.001 or np.abs(posz) > 0.001:
        # np.arctan2 exists in every NumPy release (np.atan2 only in >= 2.0)
        angle = np.arctan2(posz, posx)
        controller_orient[0] = np.cos(angle/2)
        controller_orient[2] = np.sin(angle/2)

    #play = gamepad.buttons[7].value > 0.1
    play = True

    # Pick the next clip once the current one has reached its timing column 1.
    if player.next_clip is None and clips_timings[player.current_clip.clip_id, 1] <= player.current_clip.frame:
        f = clips_timings[player.current_clip.clip_id, 1]

        # heading of the character relative to the wanted direction
        q = lab.utils.quat_mul(lab.utils.quat_inv(controller_orient), player.current_clip.quaternions[f, 0, :])
        theta = np.arctan2(
            2 * q[0] * q[2],
            1.0 - (2 * q[2] * q[2])
        )

        # lateral drift: offset from the wanted position along the controller's X axis
        v = lab.utils.quat_mul_vec(controller_orient, np.array([1,0,0], dtype=np.float32))
        x = np.dot(v, player.current_clip.positions[f, 0, :] - controller_position) * 0.01

        next_clip_id, _ = optimal_policy(
            coefficients_forward_0,
            0.95,
            player.current_clip.clip_id,
            x,
            theta,
            physic_factor=.99, direction_factor=1.,
            #debug = debug_struct
        )
        player.set_next_clip(next_clip_id)

    if play:
        player.tick()
        # we move the position along the direction of the wanted displacement
        # this allows us to push a x value of drift to the controller
        # we then slowly converge the wanted position to the actual position to avoid massive drift
        v = lab.utils.quat_mul_vec(controller_orient, np.array([0,0,1], dtype=np.float32))
        dist = np.dot(v, player.positions[0] - controller_position)
        controller_position += v * dist
        controller_position = (controller_position*.90 + player.positions[0]*.10)

    # if draw_debug :
    #     display(debug_struct)

    q = player.quaternions
    p = player.positions

    a = lab.utils.quat_to_mat(q, p)
    viewer.set_shadow_poi(p[0])

    viewer.begin_shadow()
    viewer.draw(character, a)
    viewer.end_shadow()

    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, a)
    # direction mesh at the character root, oriented by the controller
    d = lab.utils.quat_to_mat(controller_orient, p[0])
    viewer.draw(direction, d)

    viewer.end_display()

    viewer.disable(depth_test=True)

    if draw_debug:
        a = lab.utils.quat_to_mat(player.current_clip.quaternions[player.current_clip.frame], player.current_clip.positions[player.current_clip.frame])
        viewer.draw_lines(character.world_skeleton_lines(a), np.array([1,0,0], dtype=np.float32))
        viewer.draw_axis(character.world_skeleton_xforms(a))
        if player.next_clip is not None:
            f = max(player.next_clip.frame, 0)
            a = lab.utils.quat_to_mat(player.next_clip.quaternions[f], player.next_clip.positions[f])
            viewer.draw_lines(character.world_skeleton_lines(a), np.array([0,1,0], dtype=np.float32))

    viewer.execute_commands()

interact(
    render,
    frame=lab.Timeline(max=MAX_STEP_LEN-1)
)
viewer