# Precomputing Avatar Behavior From Human Motion Data

by Jehee Lee and Kang Hoon Lee
Siggraph 2004

Notebook by Jerome Eippers, 2024

In [None]:
%matplotlib widget
from dataclasses import dataclass, field
import numpy as np
from ipywidgets import widgets, interact
from random import randrange
import pickle

import ipyanimlab as lab

# Viewer instance used by every render callback of this notebook.
viewer = lab.Viewer(move_speed=5, width=1280, height=720)

## Motion Graph

### Load Motion Graph Data

We will load the motion graph data from the Motion Graph notebook.  
It contains the animation, which frames are part of the animation, and the local minima used to create the transitions.

We will also load the character and add the extra foot contact bones, so we can quickly compute a simple foot locking mechanism.

In [None]:
# load the character

character = viewer.import_usd_asset('AnimLabSimpleMale.usd')

# Extra heel / ball bones parented to each foot. They are used below to detect
# the foot contacts and to drive the foot locking.
# NOTE(review): the arguments look like (name, rotation quaternion, offset from
# the parent, parent name) - confirm against the ipyanimlab documentation.
character.add_bone('LeftHeel', np.array([1,0,0,0]), np.array([9.2,0,-12]), 'LeftFoot')
character.add_bone('LeftBall', np.array([1,0,0,0]), np.array([14.5,0,8.22]), 'LeftFoot')
character.add_bone('RightHeel', np.array([1,0,0,0]), np.array([-9.2,0,-12]), 'RightFoot')
character.add_bone('RightBall', np.array([1,0,0,0]), np.array([-14.5,0,8.22]), 'RightFoot')

left_heel = character.bone_index('LeftHeel')
left_ball = character.bone_index('LeftBall')
right_heel = character.bone_index('RightHeel')
right_ball = character.bone_index('RightBall')
left_foot = character.bone_index('LeftFoot')
right_foot = character.bone_index('RightFoot')

# Bone index of each foot tag, in the same order as the columns of foot_tags:
# left heel, left ball, right heel, right ball.
# A native index dtype is used on purpose: int8 cannot hold a bone index above 127.
foottag_indices = np.asarray([left_heel, left_ball, right_heel, right_ball], dtype=np.intp)
print(foottag_indices)

In [None]:
# create a sphere for debugging purpose
# (the policy demos further down draw it at the position of the target)

sphere = viewer.create_sphere(radius=10)

In [None]:
# load the pre computed motiongraph data

# NOTE: unpickling can execute arbitrary code, only load a file that was
# written by the Motion Graph notebook (or that comes from a trusted source).
with open('motion_graph_walking_rawdata.dat', 'rb') as f:
    motion_graph_rawdata = pickle.load(f)

# the file stores a 4-tuple
animation, window_size, animation_frame_validity, animation_local_minima = motion_graph_rawdata

# run the animation through a mapper built for our character
# (presumably this retargets it onto the character skeleton, extra bones included)
animmap = lab.AnimMapper(character)
animation = animmap(animation)

bone_count = character.bone_count()
frame_count = animation.quats.shape[0]

In [None]:
# compute the foot contacts

gquats, gpos = lab.utils.quat_fk(animation.quats, animation.pos, animation.parents)

# one boolean column per contact bone : left heel, left ball, right heel, right ball
left_contacts, right_contacts = lab.utils.extract_feet_contacts(
    gpos, [left_heel, left_ball], [right_heel, right_ball], 0.04
)
foot_tags = np.zeros([frame_count, 4], dtype=np.bool_)
foot_tags[:, :2] = left_contacts
foot_tags[:, 2:] = right_contacts

# smooth the signal a little : a frame is switched on when the frames `gap` before and
# `gap` after it are both on (first with a gap of 2 frames, then with a gap of 1 frame).
# The frames are processed in order and in place, so a frame that was just switched on
# is seen as a contact by the frames processed after it.
for gap in (2, 1):
    for frame in range(gap, frame_count-gap):
        foot_tags[frame] |= foot_tags[frame-gap] & foot_tags[frame+gap]


In [None]:
def render(frame):
    """Draw the pose stored at `frame` of the loaded animation, with its skeleton on top."""
    local_pos = animation.pos[frame]
    local_quats = animation.quats[frame]
    pose = lab.utils.quat_to_mat(local_quats, local_pos)

    # the shadow follows the first bone of the character
    viewer.set_shadow_poi(local_pos[0])

    # shadow pass
    viewer.begin_shadow()
    viewer.draw(character, pose)
    viewer.end_shadow()

    # display pass
    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, pose)
    viewer.end_display()

    # skeleton overlay (bone axes and bone lines)
    viewer.disable(depth_test=True)
    viewer.draw_axis(character.world_skeleton_xforms(pose), 5)
    viewer.draw_lines(character.world_skeleton_lines(pose))

    viewer.execute_commands()

# scrub through every frame of the animation
interact(
    render,
    frame=lab.Timeline(max=frame_count-1)
)
viewer

### Create the Motion Graph States

This time, instead of a list of states (each spanning several frames) and edges (connecting frames),  
we will have:

* a list of states : a single frame where a choice can be made.
* for each state, a list of actions: each action is a list of frames to play to reach the next state.

In [None]:
@dataclass
class State:
    """A state in the graph: a single frame where a choice of action can be made."""
    # animation frame this state sits on
    frame: int = 0
    # actions that can be started from this state
    actions: list["Action"] = field(default_factory=list)
    # actions ending on this state (scratch data, only filled while the graph is built)
    reversed_actions: list["Action"] = field(default_factory=list)
    
@dataclass
class Action:
    """A directed action: a run of animation frames played to reach the next state."""
    # frame just before the first played frame (the played frames are start+1 .. end)
    start: int = 0
    # last played frame, which is also the frame of the state reached by the action
    end: int = 0
    # True for a transition, which is blended with the motion played just before it
    blend: bool = False
    # set once the action is known to lead to a state where a choice can be made
    is_valid_action: bool = False
    # index in the states list of the state reached by the action
    next_state_id: int = 0
    
# prepare the states data
# (one candidate state per frame ; a valid frame can always continue to the following frame)
states = []
for i in range(frame_count):
    states.append(State(frame=i))
    if animation_frame_validity[i]:
        states[i].actions.append(Action(start=i, end=i+1))
        
# add the transitions
# (a local minimum at (i, j) lets frame i jump to frame j, playing the window of frames that ends on j)
horizontal, vertical = np.where(animation_local_minima == 1)
for i, j in zip(horizontal, vertical):
    if animation_frame_validity[i] and animation_frame_validity[j]:
        states[i].actions.append(Action(start=j-window_size, end=j, blend=True))

# starting from an action we check if the targetted state has only one action
# in that case we extend the action to target directly the next next state, ... and so on
def collapse_action(action):
    """Extend `action` over every single-action state that follows it.

    Returns True (and flags the action as valid) when the walk stops on a state
    offering several actions, False when it stops on a state without any action.
    Must only be used while `states` still holds one state per frame.
    """
    end = action.end
    while len(states[end].actions) == 1:
        end = states[end].actions[0].end
    if len(states[end].actions) > 1:
        action.end = end
        action.is_valid_action = True
        return True
    return False

# collapse actions
# (flood fill from the first state that has actions : only what can be reached from it gets validated)
stack = next((list(state.actions) for state in states if len(state.actions) > 0 ))
while stack:
    action = stack.pop()
    if not action.is_valid_action and collapse_action(action):
        stack += states[action.end].actions

# compute reverse actions
for state in states:
    for action in state.actions:
        states[action.end].reversed_actions.append(action)
        
# delete actions that are not valid, and all the states that are unreachable
# (the lists are rebuilt in one pass instead of calling remove() while iterating)
for state in states:
    state.actions[:] = [action for action in state.actions if action.is_valid_action]
states = [state for state in states if state.actions]

# collapse all states that have only one action left that is not a transition
# (visited from the last frame to the first one, so a chain of such states is collapsed in one pass)
kept_states = []
for state in reversed(states):
    if len(state.actions) == 1 and state.actions[0].start == state.frame:
        for action in state.reversed_actions:
            action.end = state.actions[0].end
    else:
        kept_states.append(state)
states = kept_states[::-1]
        
# clear the reverse
for state in states:
    state.reversed_actions.clear()
    
# sort the states by frame
# (done before the indices are resolved, so next_state_id always matches the final order)
states.sort(key=lambda n : n.frame)

# frame -> index in the final list of states
state_index_by_frame = {state.frame: index for index, state in enumerate(states)}

def find_state_index(frame):
    """Return the index in `states` of the state sitting on `frame`, or None if there is none.

    The lookup table is built once, right after the graph is finalized.
    """
    return state_index_by_frame.get(int(frame))

# NOTE(review): an action whose target state was deleted above ends up with
# next_state_id set to None.
for state in states:
    for action in state.actions:
        action.next_state_id = find_state_index(action.end)
        
state_count = len(states)

display(state_count)

### Precompute root motion

For each action we can precompute the root motion relative to its first frame.  
This will be useful when playing the animation and also for the training process.

In [None]:
# compute trajectories of displacement for each actions
# (only the root bone of the local animation is needed, no forward kinematics)

for state_id, state in enumerate(states):
    for action in state.actions:
        # the frames coming from np.where are numpy integers, keep plain ints from now on
        action.start = int(action.start)
        action.end = int(action.end)
        
        if action.end <= action.start:
            raise ValueError(f"state {state_id} action {action}")
        
        # the action plays the frames start+1 .. end (end included)
        start_frame, end_frame = action.start+1, action.end+1
                
        # bring anim relative to first frame
        # (the reference is the root of the frame just before the first played frame)
        boq, bop = animation.quats[start_frame-1,0,:], animation.pos[start_frame-1,0,:]
        boqi, bopi = lab.utils.qp_inv((boq, bop))
        bq, bp = animation.quats[start_frame:end_frame,0,:].copy(), animation.pos[start_frame:end_frame,0,:].copy()
        bq, bp = lab.utils.qp_mul((boqi[np.newaxis,...], bopi[np.newaxis,...]),(bq, bp))
    
        # root of every played frame, relative to the root at the start of the action
        action.trajectory_quats = bq
        action.trajectory_pos = bp
        
        # root of every played frame, relative to the root of the previous frame
        # (the first entry is already relative to the start of the action)
        action.local_trajectory_pos = action.trajectory_pos.copy()
        action.local_trajectory_quats = action.trajectory_quats.copy()
        
        action.local_trajectory_quats[1:], action.local_trajectory_pos[1:] = lab.utils.qp_mul(
            lab.utils.qp_inv((action.trajectory_quats[:-1], action.trajectory_pos[:-1])),
            (action.trajectory_quats[1:], action.trajectory_pos[1:])
        )

### Animation Player

A simple player that can take an animation and play it frame after frame every time we tick it.

In [None]:
class FootLock:
    """Simple foot locking: keeps the toe where it touched the ground while its contact tag is on.

    After each call to compute(), ankle_pos / ankle_quat hold the target to give
    to the ankle so that the toe stays in place (the target is blended in when
    the contact starts and blended out when it ends).
    """

    def __init__(self, toe_id, ankle_boneid):
        """
        toe_id: column of foot_tags (and entry of foottag_indices) of the contact bone to lock.
        ankle_boneid: index of the bone that receives the corrected target.
        """
        self.toe_id = toe_id
        self.ankle_boneid = ankle_boneid
        self.ankle_pos = np.zeros([3], dtype=np.float32)
        self.ankle_quat = np.array([1,0,0,0], dtype=np.float32)
        self.toe_lock = False
        # position where the toe got locked
        self.toe_position = np.zeros([3], dtype=np.float32)
        # last offset measured between the locked position and the animated toe
        self.toe_error = np.zeros([3], dtype=np.float32)
        # blend counter : counts up when the lock starts, counts down after it is released
        self.blend = 0
        
    def compute(self, positions, quaternions, frame):
        """Update ankle_pos / ankle_quat for the given local pose.

        positions, quaternions: local pose of the character.
        frame: animation frame whose contact tags drive the lock.
        """
        gq, gp = lab.utils.quat_fk(quaternions, positions, animation.parents)
        toe_bone = foottag_indices[self.toe_id]
        
        if foot_tags[frame, self.toe_id]:
            if not self.toe_lock:
                self.toe_lock = True
                # copy the position : a plain index is a view on gp, and writing the
                # height through it would also change gp (cancelling the error below)
                self.toe_position = gp[toe_bone].copy()
                # lock the toe at height 0
                self.toe_position[1] = 0
                self.blend = 0
            
            self.toe_error = self.toe_position - gp[toe_bone]
            
            # ease the correction in over the first frames of the contact
            t = 1.0
            if self.blend < 3:
                t = float(self.blend) / 4.0
                self.blend += 1
            self.ankle_pos = gp[self.ankle_boneid] + self.toe_error * t
            self.ankle_quat = gq[self.ankle_boneid]
            
        else:
            # the contact just ended : start to blend the last correction out
            if self.toe_lock:
                self.blend = 5
            self.toe_lock = False
            self.ankle_pos = gp[self.ankle_boneid]
            self.ankle_quat = gq[self.ankle_boneid]
            
            if self.blend > 0:
                t = float(self.blend) / 6.0
                self.ankle_pos = self.ankle_pos + self.toe_error * t
                self.blend -= 1
            

class AnimPlayer:
    """Plays the actions of the graph, one animation frame per tick.

    The caller gives an action with set_action() whenever current_action is None,
    then calls tick() once per displayed frame. The resulting local pose is
    available in positions / quaternions.
    """

    def __init__(self, start_state):
        """start_state: index in `states` of the state the player starts from."""
        self.current_action = None
        # offset, from the start of the action, of the next frame to play
        self.current_action_framei = 1
        self.positions = np.zeros([bone_count, 3], dtype=np.float32)
        self.quaternions = np.array([1,0,0,0], dtype=np.float32)[np.newaxis,...].repeat(bone_count, axis=0)
        self.last_played_frame = states[start_state].frame
        # (source frame, number of blended frames so far) while a transition is blended, else None
        self.transition = None
        # state reached at the end of the current action
        self.next_state_id = start_state
        self.left_foot = FootLock(1, left_foot)
        self.right_foot = FootLock(3, right_foot)
        
    def tick(self):
        """Play the next frame of the current action.

        Returns True while the current action still has frames to play, False
        when a new action must be given (or when there is no action at all).
        """
        if self.current_action is not None:
            # root of the previous pose, the root motion of the action is applied on top of it
            last_root_q = self.quaternions[0].copy()
            last_root_p = self.positions[0].copy()
            
            self.last_played_frame = self.current_action.start + self.current_action_framei
            self.positions = animation.pos[self.last_played_frame, :, :].copy()
            self.quaternions = animation.quats[self.last_played_frame, :, :].copy()
            
            
            if self.transition is not None:
                # keep playing the motion we come from and blend it out over window_size frames
                from_frame, from_frame_i = self.transition
                from_frame_i += 1
                from_frame += 1
                if from_frame_i <= window_size:
                    # smoothstep weight of the new motion
                    t = float(from_frame_i)/float(window_size+1)
                    t = -2.0*t**3 + 3*t**2
                    
                    self.positions = (t) * self.positions + (1.0-t) * animation.pos[from_frame, :, :]
                    self.quaternions = lab.utils.quat_slerp(animation.quats[from_frame, :, :], self.quaternions, t)
                    
                    self.transition = (from_frame, from_frame_i)
                else:
                    self.transition = None
                    
            
            # the root does not come from the animation frame : it is the previous root
            # moved by the precomputed frame to frame displacement of the action
            self.quaternions[0], self.positions[0] = lab.utils.qp_mul(
                (last_root_q, last_root_p),
                (self.current_action.local_trajectory_quats[self.current_action_framei-1], self.current_action.local_trajectory_pos[self.current_action_framei-1])
            )
            
            # the action is over once its end frame has been played
            self.current_action_framei += 1
            if self.current_action.start + self.current_action_framei > self.current_action.end :
                self.current_action = None
                self.current_action_framei = 0
                
            # foot locking : compute the ankle targets, then solve the legs for them
            self.left_foot.compute(self.positions, self.quaternions, self.last_played_frame)
            self.right_foot.compute(self.positions, self.quaternions, self.last_played_frame)

            self.quaternions, self.positions = lab.utils.limb_ik(
                self.quaternions[np.newaxis, ...], self.positions[np.newaxis, ...], animation.parents, animation.bones,
                np.array([self.left_foot.ankle_quat, self.right_foot.ankle_quat])[np.newaxis, ...],
                np.array([self.left_foot.ankle_pos, self.right_foot.ankle_pos])[np.newaxis, ...],
            )
            # limb_ik is called with a batch of a single pose
            self.quaternions, self.positions = self.quaternions[0], self.positions[0]
                
                
        return self.current_action is not None
    
    def set_action(self, action):
        """Start to play `action` on the next tick.

        Passing None clears the current action and keeps next_state_id unchanged.
        """
        self.current_action = action
        self.current_action_framei = 1
        self.transition = None
        if action is None:
            return
        self.next_state_id = action.next_state_id
        if action.blend:
            # blend from the motion that follows the last played frame
            self.transition = (self.last_played_frame, 0)
    

In [None]:
player = AnimPlayer(0)

def render(frame):
    """Advance the player by one frame, following a random action at every state, and draw it."""

    # the previous action is over : pick any action of the state we reached
    if player.current_action is None:
        candidates = states[player.next_state_id].actions
        player.set_action(candidates[randrange(0, len(candidates))])
                
    player.tick()
    display((player.last_played_frame, player.transition))
    
    pose = lab.utils.quat_to_mat(player.quaternions, player.positions)
    viewer.set_shadow_poi(player.positions[0])
    
    # shadow pass
    viewer.begin_shadow()
    viewer.draw(character, pose)
    viewer.end_shadow()
    
    # display pass
    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, pose)
    viewer.end_display()

    viewer.disable(depth_test=True)
    
    viewer.execute_commands()
    
# the timeline value itself is not used by render : each call just advances the player
interact(
    render,
    frame=lab.Timeline(max=100)
)
viewer

## Control Policy

### The reward function $R(s, e, a)$ 

is given by:

$$
R(s, e, a) = \max_t \left( \gamma^t \exp\left(-\frac{\|p(t) - p_d\|}{10}\right) \right),
$$

Where:

* $s$: The current state of the avatar.
* $e$: The target's state.
* $a$: The action taken by the avatar.
* $p(t)$: The avatar's position trajectory during the action $a$.
* $p_d$: The target's position in the polar grid.
* $\gamma \in (0, 1)$: The discount factor, which penalizes delayed rewards.
* $\|\cdot\|$: The Euclidean norm, so $\|p(t) - p_d\|$ is the distance between $p(t)$ and $p_d$.


In [None]:
# largest number of actions offered by a state, and longest action (in frames) ;
# both are used to size the tables of the policy (0 when there is nothing to measure)
max_action_count = max([0] + [len(state.actions) for state in states])
max_action_length = max([0] + [action.end - action.start for state in states for action in state.actions])
display(max_action_count)
display(max_action_length)

In [None]:
# Candidate target positions on the ground plane (y = 0) : the origin, then 6
# concentric rings. Ring number r has a radius of 36*r + 20 and holds 8 + r
# evenly spaced positions.
rings = [np.zeros([1, 3])]
for ring in range(6):
    # the last sample of linspace is the full turn, which duplicates angle 0
    angles = np.linspace(0, np.pi*2, 9+ring)[:-1]
    radius = 36*ring + 20
    ring_points = np.stack([np.cos(angles), np.zeros_like(angles), np.sin(angles)], axis=1)
    rings.append(ring_points * radius)
    
target_positions = np.concatenate(rings)
target_positions_count = len(target_positions)
display(target_positions_count)

In [None]:
# discount applied to the reward depending on how late in the action it is obtained
time_discount = .98
time_discounts = np.array([time_discount**t for t in range(max_action_length)])

# one entry per (state, target position, action) ; the entries of the actions
# that a state does not have keep the value -1
immediate_rewards = np.full([state_count, target_positions_count, max_action_count], -1.0)
next_states = np.full([state_count, target_positions_count, max_action_count], -1, dtype=np.int32)

for i in range(state_count):
    for a, action in enumerate(states[i].actions):
        length = action.end - action.start
        discounts = time_discounts[:length]
        
        # grid of target positions attached to the root at the end of the action,
        # expressed relative to the root at the start of the action.
        # It does not depend on the target, so it is computed once per action.
        end_grid = lab.utils.quat_mul_vec(action.trajectory_quats[-1][np.newaxis,:], target_positions) + action.trajectory_pos[-1]
        
        # offset of the first entry of the state reached by the action in the flattened tables
        next_state_offset = action.next_state_id*target_positions_count
        
        for j in range(target_positions_count):
            # best discounted closeness to the target over the frames of the action
            v = action.trajectory_pos - target_positions[j]
            immediate_rewards[i,j,a] = np.max(np.exp( -np.sqrt(np.sum(v*v, axis=1)) / 10.0 ) * discounts)
            
            # find the closest target_positions after we moved
            spositions = end_grid - target_positions[j]
            target_state_id = np.sum(spositions*spositions, axis=1).argmin()
                        
            next_states[i,j,a] = next_state_offset + target_state_id

# flatten (state, target position) into a single index : state * target_positions_count + target position
immediate_rewards = immediate_rewards.reshape(-1, max_action_count)
next_states = next_states.reshape(-1, max_action_count)

In [None]:
# controller widget for the gamepad at index 0 ; the render callbacks below read
# its axes 0 and 1 to place the target (and one of them reads its button 0)
gamepad = widgets.Controller(index=0)
gamepad

In [None]:
player = AnimPlayer(0)

# target position and root transform picked at the last decision, kept for the debug drawing
player.selected_position = 0
player.selected_qp = ()

def render(frame):
    """Advance the player by one frame with the greedy policy (best immediate reward) and draw it."""

    # target on the ground, driven by two axes of the gamepad
    posx = -gamepad.axes[1].value * 500
    posz = gamepad.axes[0].value * 500
    
    if player.current_action is None:
        root_quat, root_pos = player.quaternions[0], player.positions[0]
        
        # grid position (attached to the root of the character) that is the closest to the target
        spositions = lab.utils.quat_mul_vec(root_quat[np.newaxis,:], target_positions) + root_pos - np.array([posx,0,posz])
        target_state_id = np.sum(spositions*spositions, axis=1).argmin()
        
        # action with the best immediate reward for this (state, target position)
        table_row = player.next_state_id * target_positions_count + target_state_id
        action_id = immediate_rewards[table_row, :].argmax()

        player.selected_position = target_state_id
        player.selected_qp = (root_quat, root_pos)
        
        player.set_action(states[player.next_state_id].actions[action_id])
                
    player.tick()
    display((player.last_played_frame, player.transition))
    
    pose = lab.utils.quat_to_mat(player.quaternions, player.positions)
    viewer.set_shadow_poi(player.positions[0])
    
    # shadow pass
    viewer.begin_shadow()
    viewer.draw(character, pose)
    viewer.end_shadow()
    
    # the target is drawn as a red sphere
    target_matrix = np.eye(4, dtype=np.float32)
    target_matrix[0, 3] = posx
    target_matrix[2, 3] = posz
    
    # display pass
    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, pose)
    sphere.materials()[0].set_albedo(np.array([1,0,0], dtype=np.float32))
    viewer.draw(sphere, target_matrix)
    viewer.end_display()

    viewer.disable(depth_test=True)

    # debug : the grid as it was at the last decision, the selected position drawn bigger
    states_matrices = np.tile(np.eye(4, dtype=np.float32), (target_positions.shape[0], 1, 1))
    states_matrices[:, :3, 3] = lab.utils.quat_mul_vec(player.selected_qp[0][np.newaxis,:], target_positions) + player.selected_qp[1]
    viewer.draw_axis(states_matrices, 2)
    viewer.draw_axis(states_matrices[player.selected_position][np.newaxis, ...], 10)
    
    # debug : root trajectory of the action being played
    if player.current_action is not None:
        lines = lab.utils.quat_mul_vec(player.selected_qp[0][np.newaxis,:], player.current_action.trajectory_pos) + player.selected_qp[1]
        # every inner point is doubled to get one (start, end) pair per segment
        viewer.draw_lines(lines.repeat(2, axis=0)[1:-1].astype(np.float32))
        
    viewer.execute_commands()
    
interact(
    render,
    frame=lab.Timeline(max=100)
)
viewer

### Markov Decision Process (MDP)

The Markov property ensures the next state depends only on the current state and action, not on the history.


An MDP is defined by the tuple $(S, A, R, \gamma, T)$, where:
- $S$: The set of states $s,e$.
- $A$: The set of actions.
- $R(s, e, a)$: The reward function, which gives the immediate reward for taking action $a$ in state $s,e$.
- $\gamma \in [0, 1]$: The discount factor, which balances the importance of future rewards.
- $T(s, e, a) \to s', e'$: The transition function that determines the next state $s', e'$ after taking action $a$ in state $s, e$.

The goal in an MDP is to find a **policy** (a mapping from states to actions) that **maximizes the expected cumulative reward** starting from an initial state $s_0$. The cumulative reward is:

$$
G = \sum_{t=0}^\infty \gamma^t R(s_t, e_t, a_t),
$$

where:
- $s_t, e_t$: The state at time $t$.
- $a_t$: The action taken at time $t$.


### Value Function

The **Value Function** quantifies the cumulative reward of starting in state $(s, e)$ and following the optimal policy thereafter. It is recursively defined as:

$$
V(s, e) := \max_a \left[ R(s, e, a) + \gamma^t V(s', e') \right],
$$

where:
- $R(s, e, a)$: The immediate reward for taking action $a$ in state $(s, e)$.
- $s', e'$: The resulting avatar and target states after taking action $a$.
- $t$: The duration of the action.
- $\gamma^t$: Discounts future rewards, encouraging quicker actions.



### Bellman Update Rule

The **Bellman Update Rule** refines the value function iteratively, ensuring optimality for every state-action pair:

In my case I will actually use a value function **$V(s, e, a)$**

Steps:
1. **Sample** a state-action pair $(s, e, a)$.
2. **Compute the reward $R(s, e, a)$** for the current action.
3. **Estimate the value** of the next state: $\max_{a'} V(s', e', a')$.
4. **Update $V(s, e, a)$** based on the maximum cumulative reward: $V(s, e, a) \leftarrow R(s, e, a) + \gamma^t \max_{a'} V(s', e', a')$.



In [None]:
out_value_function = widgets.Output(layout={'border': '1px solid black'})
display(out_value_function)

future_reward_discount = .97
# start from the immediate rewards (the entries of the missing actions stay at -1)
value_function = immediate_rewards.copy()

last_total_reward = 0
for epoch in range(1000):
    for _ in range(30000):
        # sample a (state, target position, action) entry
        state_id = randrange(0, state_count)
        target_id = randrange(0, target_positions_count)
        state_space_id = state_id * target_positions_count + target_id
        candidates = states[state_id].actions
        action_id = randrange(0, len(candidates))
        
        # duration of the action, in frames
        length = candidates[action_id].end - candidates[action_id].start

        # best value available once the action is over
        next_state = next_states[state_space_id, action_id]
        next_max = value_function[next_state].max()

        # immediate reward + discounted future
        value_function[state_space_id, action_id] = immediate_rewards[state_space_id, action_id] + next_max * future_reward_discount**length

    # progress report, replaced at every epoch
    total_reward = value_function.sum()
    with out_value_function:
        out_value_function.clear_output()
        display(f"epoch {epoch} :: total {total_reward}, difference with previous epoch {total_reward - last_total_reward}")
    
    # stop when the sum of the table changed by less than 1 during the epoch
    converged = np.abs(total_reward - last_total_reward) < 1
    last_total_reward = total_reward
    if converged:
        break

In [None]:
player = AnimPlayer(0)

# target position and root transform picked at the last decision, kept for the debug drawing
player.selected_position = 0
player.selected_qp = ()

def render(frame):
    """Advance the player by one frame with the value function policy and draw it.

    While button 0 of the gamepad is pressed, the greedy policy (best immediate
    reward) is used instead.
    """

    # target on the ground, driven by two axes of the gamepad
    posx = -gamepad.axes[1].value * 500
    posz = gamepad.axes[0].value * 500
    is_immediate = gamepad.buttons[0].value > 0.2
    
    if player.current_action is None:
        root_quat, root_pos = player.quaternions[0], player.positions[0]
        
        # grid position (attached to the root of the character) that is the closest to the target
        spositions = lab.utils.quat_mul_vec(root_quat[np.newaxis,:], target_positions) + root_pos - np.array([posx,0,posz])
        target_state_id = np.sum(spositions*spositions, axis=1).argmin()
        
        # best action according to the selected table
        policy_table = immediate_rewards if is_immediate else value_function
        table_row = player.next_state_id * target_positions_count + target_state_id
        action_id = policy_table[table_row, :].argmax()

        player.selected_position = target_state_id
        player.selected_qp = (root_quat, root_pos)
        
        player.set_action(states[player.next_state_id].actions[action_id])
                
    player.tick()
    display((player.last_played_frame, player.transition, is_immediate))
    
    pose = lab.utils.quat_to_mat(player.quaternions, player.positions)
    viewer.set_shadow_poi(player.positions[0])
    
    # shadow pass
    viewer.begin_shadow()
    viewer.draw(character, pose)
    viewer.end_shadow()
    
    # the target is drawn as a sphere : red with the greedy policy, cyan otherwise
    target_color = [1,0,0] if is_immediate else [0,1,1]
    target_matrix = np.eye(4, dtype=np.float32)
    target_matrix[0, 3] = posx
    target_matrix[2, 3] = posz
    
    # display pass
    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, pose)
    sphere.materials()[0].set_albedo(np.array(target_color, dtype=np.float32))
    viewer.draw(sphere, target_matrix)
    viewer.end_display()

    viewer.disable(depth_test=True)

    # debug : the grid as it was at the last decision, the selected position drawn bigger
    states_matrices = np.tile(np.eye(4, dtype=np.float32), (target_positions.shape[0], 1, 1))
    states_matrices[:, :3, 3] = lab.utils.quat_mul_vec(player.selected_qp[0][np.newaxis,:], target_positions) + player.selected_qp[1]
    viewer.draw_axis(states_matrices, 2)
    viewer.draw_axis(states_matrices[player.selected_position][np.newaxis, ...], 10)
    
    # debug : root trajectory of the action being played
    if player.current_action is not None:
        lines = lab.utils.quat_mul_vec(player.selected_qp[0][np.newaxis,:], player.current_action.trajectory_pos) + player.selected_qp[1]
        # every inner point is doubled to get one (start, end) pair per segment
        viewer.draw_lines(lines.repeat(2, axis=0)[1:-1].astype(np.float32))
        
    viewer.execute_commands()
    
interact(
    render,
    frame=lab.Timeline(max=100)
)
viewer

In [None]:
# reuse the output widget of the previous training for the progress report
display(out_value_function)
out_value_function.clear_output()

future_reward_discount = .97
# this policy ignores the target : one entry per (state, action)
value_function_static = np.zeros([state_count, max_action_count])

last_total_reward = 0
for epoch in range(1000):
    for _ in range(30000):
        # sample a (state, action) entry
        state_id = randrange(0, state_count)
        candidates = states[state_id].actions
        action_id = randrange(0, len(candidates))
        action = candidates[action_id]

        # duration of the action, in frames
        length = action.end - action.start

        # an action is rewarded when it starts and ends strictly between the frames 950 and 1200
        reward = 1 if (950 < action.start < 1200 and 950 < action.end < 1200) else 0

        # best value available once the action is over
        next_max = value_function_static[action.next_state_id].max()

        # immediate reward + discounted future
        value_function_static[state_id, action_id] = reward + next_max * future_reward_discount**length

    # progress report, replaced at every epoch
    total_reward = value_function_static.sum()
    with out_value_function:
        out_value_function.clear_output()
        display(f"epoch {epoch} :: total {total_reward}, difference with previous epoch {total_reward - last_total_reward}")
    
    # stop when the sum of the table changed by less than .1 during the epoch
    converged = np.abs(total_reward - last_total_reward) < .1
    last_total_reward = total_reward
    if converged:
        break

In [None]:
player = AnimPlayer(0)

# target position and root transform picked at the last decision, kept for the debug drawing
player.selected_position = 0
player.selected_qp = ()
# 0 : go to the target (value_function), 1 : target reached (value_function_static)
player.selected_mode = 0

def render(frame):
    """Advance the player by one frame and draw it.

    The character goes to the target with value_function, switches to
    value_function_static once it is close to the target, and goes back to
    value_function when the target gets far away again.
    """

    # target on the ground, driven by two axes of the gamepad
    posx = -gamepad.axes[1].value * 500
    posz = gamepad.axes[0].value * 500
    
    if player.current_action is None:
        root_quat, root_pos = player.quaternions[0], player.positions[0]
        target = np.array([posx,0,posz])

        # switch of mode with an hysteresis : on under a squared distance of 2500, off over 10000
        spos = root_pos - target
        squared_distance = np.sum(spos*spos)
        if player.selected_mode == 0 and squared_distance < 2500:
            player.selected_mode = 1
        elif player.selected_mode == 1 and squared_distance > 10000:
            player.selected_mode = 0

        if player.selected_mode == 0:
            # grid position (attached to the root of the character) that is the closest to the target
            spositions = lab.utils.quat_mul_vec(root_quat[np.newaxis,:], target_positions) + root_pos - target
            target_state_id = np.sum(spositions*spositions, axis=1).argmin()
            
            table_row = player.next_state_id * target_positions_count + target_state_id
            action_id = value_function[table_row, :].argmax()
        else:
            # the static policy does not use the target
            target_state_id = 0
            action_id = value_function_static[player.next_state_id, :].argmax()

        player.selected_position = target_state_id
        player.selected_qp = (root_quat, root_pos)
        
        player.set_action(states[player.next_state_id].actions[action_id])
                
    player.tick()
    display((player.last_played_frame, player.transition, player.selected_mode))
    
    pose = lab.utils.quat_to_mat(player.quaternions, player.positions)
    viewer.set_shadow_poi(player.positions[0])
    
    # shadow pass
    viewer.begin_shadow()
    viewer.draw(character, pose)
    viewer.end_shadow()
    
    # the target is drawn as a sphere : green once it is reached, cyan otherwise
    target_color = [0,1,0] if player.selected_mode else [0,1,1]
    target_matrix = np.eye(4, dtype=np.float32)
    target_matrix[0, 3] = posx
    target_matrix[2, 3] = posz
    
    # display pass
    viewer.begin_display()
    viewer.draw_ground()
    viewer.draw(character, pose)
    sphere.materials()[0].set_albedo(np.array(target_color, dtype=np.float32))
    viewer.draw(sphere, target_matrix)
    viewer.end_display()

    viewer.disable(depth_test=True)

    # debug drawing, only while going to the target
    if not player.selected_mode:
        # the grid as it was at the last decision, the selected position drawn bigger
        states_matrices = np.tile(np.eye(4, dtype=np.float32), (target_positions.shape[0], 1, 1))
        states_matrices[:, :3, 3] = lab.utils.quat_mul_vec(player.selected_qp[0][np.newaxis,:], target_positions) + player.selected_qp[1]
        viewer.draw_axis(states_matrices, 2)
        viewer.draw_axis(states_matrices[player.selected_position][np.newaxis, ...], 10)
        
        # root trajectory of the action being played
        if player.current_action is not None:
            lines = lab.utils.quat_mul_vec(player.selected_qp[0][np.newaxis,:], player.current_action.trajectory_pos) + player.selected_qp[1]
            # every inner point is doubled to get one (start, end) pair per segment
            viewer.draw_lines(lines.repeat(2, axis=0)[1:-1].astype(np.float32))
        
    viewer.execute_commands()
    
interact(
    render,
    frame=lab.Timeline(max=100)
)
viewer