In [1]:
import agents_env
from agents_env.agent_replay_motion import HumanoidReplay
from utils.SimpleConverter import SimpleConverter
from utils.util_data import *
from copy import deepcopy

In [2]:

from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap
import jax.random
from jax import lax

Here I will have my agent

In [3]:
from agents_env.agent_env_template import HumanoidDiff
from some_math.rotation6D import quaternion_to_rotation_6d
from some_math.math_utils_jax import *
from agents_env.losses import *

In [61]:
class HumanoidEnvTrain(HumanoidDiff):
    def __init__(self, reference_trajectory_qpos, 
                 reference_trajectory_qvel, 
                 duration_trajectory,
                 dict_duration, 
                 model_path,
                 kp_gains,
                 kd_gains,
                 reference_x_pos,
                 reference_x_rot,
                 **kwargs):
        super().__init__(reference_trajectory_qpos, 
                         reference_trajectory_qvel, 
                         duration_trajectory, 
                         dict_duration, 
                         model_path, **kwargs)
        self.kp__gains = kp_gains
        self.kd__gains = kd_gains
        #the row lenght is the number of frames of the motion
        #thus it is the lenght, it will be the same with qvel
        self.rollout_lenght = reference_trajectory_qpos.shape[0]
        #I want to save another instance of the model for the refernece 
        self.sys_reference = deepcopy(self.sys)
        self.reference_x_pos = reference_x_pos
        self.reference_x_rot = reference_x_rot

        #for now this will be hardcaode
        self.rot_weight =  0.5
        self.vel_weight =  0.01
        self.ang_weight =  0.01
        self.reward_scaling= 0.02
        #self.err_threshold = 0.4
        self.err_threshold = 2.0
        #for now it will be the same size
        self.cycle_len = reference_trajectory_qpos.shape[0]
  
    
    #set pd callback
    def set_pd_callback(self,pd_control):
        self.pd_function = pd_control
 
    
    
    def _demo_replay(self, state,ref_data)-> State:
        global_pos_state = state.x.pos
        global_pos_ref = ref_data.x.pos
        #jax.debug.print("pos state: {}",global_pos_state)
        #jax.debug.print("pos ref: {}",global_pos_ref)
        error = loss_l2_pos(global_pos_state, global_pos_ref)
        #jax.debug.print("error: {}",error)
        
        replay = jp.where(error > self.err_threshold, jp.float32(1), jp.float32(0))
        #jax.debug.print("replay: {}",replay)
        #replace the ref_state data
          # Define the true and false branch functions for lax.cond
        def true_fun(_):
            # Update state to reference state and maintain step index
            return ref_data
            

        def false_fun(_):
            # Return the original state with updated metrics
            return state
        # Use lax.cond to conditionally switch between states
        new_data = lax.cond(replay == 1, true_fun, false_fun, None)
        
        return new_data,replay
        

    
    def reset(self, rng: jp.ndarray) -> State:
        
        #set this as zero
        reward, done, zero = jp.zeros(3)
        #random state initialization (RSI)
        new_step_idx = jax.random.randint(rng, shape=(), minval=0, maxval=self.rollout_lenght)
        data = self.get_reference_state(new_step_idx)
        metrics = {'step_index': new_step_idx, 'pose_error': zero, 'fall': zero}
        obs = self._get_obs(data, new_step_idx)
        
        state = State(data, obs, reward, done, metrics)
        
        #update the replay with 0 index
        state.metrics.update(replay=jp.zeros(1)[0])
        
        return state
    
    
    
    def _get_obs(self, data: mjx.Data, step_idx: jp.ndarray)-> jp.ndarray:
          
        current_step_inx =  jp.asarray(step_idx, dtype=jp.int32)
        #we take out the first index that is the world pos
        current_xpos = data.xpos[1:]
        current_xrot = data.xquat[1:]
        #get rid pf the first index that is the root, we just want
        #pos relative to the root, thus the root will become zero
        relative_pos = (current_xpos - current_xpos[0])[1:].ravel()
        
        #this is already in quat form
        current_qpos_root = data.qpos[3:7]
        #qpos of the joins this are scale values, since they are hinge joints
        current_qpos_joints = data.qpos[7:] 
        #now I will convert them into quaterions
        #first the joints that are onedofs, thus axis angle to quaterion
        hinge_quat = self.hinge_to_quat(current_qpos_joints)
        #now get a 13x4 quaterion, were we combine links of 3DOFS
        
        #this is still not working I will get back to it later
        #local_quat = self.local_quat(current_qpos_root,current_xrot,hinge_quat,self.one_dofs_joints_idx,self.link_types_array_without_root)
        
        #jax.debug.print("local quat: {}",local_quat)
        
        #now we convert it to a 6D matrix representation
        local_rot_6D= quaternion_to_rotation_6d(hinge_quat).ravel()
        #remeber for now we have the linear vel of the root
        linear_vel = data.qvel[0:3]
        angular_vel = data.qvel[3:]
        
        #get the phi value I will do that later
        phi = ( current_step_inx% self.cycle_len) / self.cycle_len
        phi = jp.asarray(phi)
        #in theory it is mutiable if we do concatenate [] instead of
        #() since, one is a list and the other a tuple
        return jp.concatenate([relative_pos,local_rot_6D,linear_vel,angular_vel,phi[None]])
   
    
    
    #this will grab a reference state from the reference trajectory
    #dont confuse with the state of the main agent
    def set_new_ref_state(self,step_index):
        #grab the current qpos and qvel of the reference
        
        ref_qp = self.reference_trajectory_qpos[step_index]
        ref_qv = self.reference_trajectory_qvel[step_index]
        
        data = mjx.make_data(self.sys_reference)
        data = data.replace(qpos=ref_qp, qvel=ref_qv)
        
        self.new_ref_data = mjx.forward(self.sys_reference, data)
    
    def set_ref_state_pipeline(self,step_index):
        ref_qp = self.reference_trajectory_qpos[step_index]
        ref_qv = self.reference_trajectory_qvel[step_index]
        #now I will return a state depending on the index and the reference trajectory
        return self._pipeline.init(self.sys_reference, ref_qp, ref_qv, self._debug)
        
    
    
    def get_com_reference(self):
        #here I want to calculate the center of mass
        #we only want the root
        return self.new_ref_data.subtree_com[1]
        
    
    def get_reference_state(self,step_index):
        ref_qp = self.reference_trajectory_qpos[step_index]
        ref_qv = self.reference_trajectory_qvel[step_index]
        #now I will return a state depending on the index and the reference trajectory
        return self.pipeline_init(ref_qp,ref_qv)
        
    #just with a custom target but not selected joints
    def step(self, state: State, action: jp.ndarray,
                                           custom_target,time) -> State:
        
        initial_idx = state.metrics['step_index']
        current_step_inx =  jp.asarray(initial_idx, dtype=jp.int32)
        
        #advance the simulation in the dummy data,
        #not sure if I should keep this maybe I dont need to
        # but this will match with the current step index       
        #self.set_new_ref_state(current_step_inx)
        current_state_ref = self.set_ref_state_pipeline(current_step_inx)
        
        
       
        #jax.debug.print("idx: {}",current_step_inx)
        
        #current qpos and qvel for the torque    
        qpos = state.pipeline_state.q
        qvel = state.pipeline_state.qd
        #this will be modified by a one without custom target   
        torque = self.pd_function(custom_target,self.sys,state,qpos,qvel,
                                 self.kp__gains,self.kd__gains,time,self.sys.dt) 
        
        data = self.pipeline_step(state.pipeline_state,torque)
        
        
    
        #first get the values, first value
        global_pos_state = data.x.pos
        global_rot_state = quaternion_to_rotation_6d(data.x.rot)
        global_vel_state = data.xd.vel
        global_ang_state = data.xd.ang
        
        # jax.debug.print("pos state: {}",global_pos_state)
        # jax.debug.print("state rot: {}",data.x.rot)
        # jax.debug.print("state vel: {}",global_vel_state)
        # jax.debug.print("state ang: {}",global_ang_state)
        
        
        
        #now for the reference trajectory
        global_pos_ref = self.reference_x_pos[current_step_inx]
        global_rot_ref = quaternion_to_rotation_6d(self.reference_x_rot[current_step_inx])
        global_vel_ref = current_state_ref.xd.vel
        global_ang_ref = current_state_ref.xd.ang
        
        
        # jax.debug.print("pos ref: {}",global_pos_ref)
        # jax.debug.print("ref rot: {}",self.reference_x_rot[current_step_inx])
        # jax.debug.print("ref vel: {}",global_vel_ref)
        # jax.debug.print("ref ang: {}",global_ang_ref)
        
        
        
        reward = -1 * (mse_pos(global_pos_state, global_pos_ref) +
               self.rot_weight * mse_rot(global_rot_state, global_rot_ref) +
               self.vel_weight * mse_vel(global_vel_state, global_vel_ref) +
               self.ang_weight * mse_ang(global_ang_state, global_ang_ref)
               ) * self.reward_scaling
        
        #jax.debug.print("rewards: {}",reward)
        
        #here I will do the fall
        #on the z axis
        fall = jp.where(data.qpos[2] < 0.2, jp.float32(1), jp.float32(0))
        fall = jp.where(data.qpos[2] > 1.7, jp.float32(1), fall)
        
        jax.debug.print("fall: {}",fall)
        jax.debug.print("qpos: {}",data.qpos[0:3])
        
        #here the demoreplay
        new_data,replay=self._demo_replay(data,current_state_ref)

        #jax.debug.print("new data: {}",new_data.x.pos)
        
        #get the observations
        obs = self._get_obs(new_data, current_step_inx)
        
        #done =1.0
        #increment the step index to know in which episode and wrap
        next_step_index = (state.metrics['step_index'] + 1) % self.rollout_lenght
        # jax.debug.print("idx: {}",current_step_inx)
        # jax.debug.print("next idx: {}",next_step_index )
        
        state.metrics.update(
            step_index=next_step_index,
            pose_error=loss_l2_relpos(global_pos_state, global_pos_ref),
            fall=fall,
            replay=replay
        )
        
        
        return state.replace(
            pipeline_state=new_data, obs=obs, reward=reward, done=state.metrics['fall']
        )
        
        
        
    

In [62]:
#since we are going to use custom trajectory, we will set up the initial
#and the end time
t_init = 1
t_end = 3

In [63]:
#this is just dummy data to initialize the agent
trajectory = SimpleConverter('motions/humanoid3d_punch.txt')
trajectory.load_mocap()
model_path = 'models/final_humanoid_no_gravity.xml'

data_mocap_matrix = jp.asarray(trajectory.data)
data_pos_mocap = jp.asarray(trajectory.data_pos)
data_vel_mocap = jp.asarray(trajectory.data_vel)

data_dict_mocap = trajectory.duration_dict

data_xpos_mocap=jp.asarray(trajectory.data_xpos)
data_xrot_mocap=jp.asarray(trajectory.data_xrot)


In [64]:
#get th kp and kd for the agent
kp,kd = generate_kp_kd_gains()
print(kp)
print(kd)

[1000 1000 1000  100  100  100  400  400  400  300  400  400  400  300
  500  500  500  500  400  400  400  500  500  500  500  400  400  400]
[100 100 100  10  10  10  40  40  40  30  40  40  40  30  50  50  50  50
  40  40  40  50  50  50  50  40  40  40]


In [65]:
envs.register_environment('humanoidEnvMimic', HumanoidEnvTrain)
env_name = 'humanoidEnvMimic'
env = envs.get_environment(env_name=env_name,
                           reference_trajectory_qpos=data_pos_mocap,
                           reference_trajectory_qvel = data_vel_mocap,
                            duration_trajectory=trajectory.total_time,
                            dict_duration= data_dict_mocap,
                           model_path=model_path,
                           kp_gains = kp,
                           kd_gains = kd,
                           reference_x_pos=data_xpos_mocap,
                           reference_x_rot=data_xrot_mocap)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)


#select the pd_control
from agents_env.pds_controllers_agents import stable_pd_controller_custom_trajectory
env.set_pd_callback(stable_pd_controller_custom_trajectory)

In [51]:
from some_math.math_utils import generate_trajectory,compute_cubic_trajectory,start_trajectories

In [66]:
trajec_dict = dict()


a_jnt_right =get_actuator_indx(env.sys.mj_model,'right_shoulder','Y')
a_jnt_left =get_actuator_indx(env.sys.mj_model,'left_shoulder','X')
a_jnt_left_elbow = get_actuator_indx(env.sys.mj_model,'left_elbow','X')
a_jnt_right_elbow = get_actuator_indx(env.sys.mj_model,'right_elbow','X')

right_knee = get_actuator_indx(env.sys.mj_model,'right_knee','X')
left_knee = get_actuator_indx(env.sys.mj_model,'left_knee','X')

trajec_dict[a_jnt_right] = generate_trajectory(t_init,t_end, 0, -1.5)
#trajectory left
trajec_dict[a_jnt_left] = generate_trajectory(t_init, t_end, 0, 1.5)
# #left elbow   
trajec_dict[a_jnt_left_elbow]= generate_trajectory(t_init, t_end, 0, 1.5)
# #right elbow
trajec_dict[a_jnt_right_elbow] = generate_trajectory(t_init, t_end, 0, 1.5)

trajec_dict[right_knee] = generate_trajectory(t_init, t_end, 0, 0)

trajec_dict[left_knee] = generate_trajectory(t_init, t_end, 0, 0)


start_trajec = start_trajectories(trajec_dict)

In [67]:
print(a_jnt_right)
print(a_jnt_left)
print(a_jnt_left_elbow)
print(a_jnt_right_elbow)
print(right_knee)
print(left_knee)


7
10
13
9
17
24


In [68]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]
# grab a 500 steps
n_steps = 500

In [56]:
phi = ( 1% env.rollout_lenght) / env.rollout_lenght
phi = jp.asarray(phi)
print(env.rollout_lenght)
print(phi)
print(phi[None].shape)

65
0.015384615
(1,)


In [69]:
print(env.dt)
print(env.sys.dt)

0.016
0.002


In [70]:

for i in range(n_steps):
    
    
    ctrl = -0.1 * jp.ones(env.sys.nu)
    #time
    time = state.pipeline_state.time
    
    time = jp.clip(time, t_init, t_end)
    
    print('time: ',time)
         
    state = jit_step(state, ctrl,start_trajec,time)
    
    rollout.append(state.pipeline_state)

time:  1.0
fall: 0.0
qpos: [1.0407541  0.92413276 0.7685817 ]
time:  1.0
fall: 0.0
qpos: [1.0356866  0.9319589  0.77845985]
time:  1.0
fall: 0.0
qpos: [1.0294065 0.941135  0.7918113]
time:  1.0
fall: 0.0
qpos: [1.0230837  0.94932425 0.8085291 ]
time:  1.0
fall: 0.0
qpos: [1.0169721  0.956836   0.82723665]
time:  1.0
fall: 0.0
qpos: [1.0111412  0.9635479  0.84665287]
time:  1.0
fall: 0.0
qpos: [1.005376   0.96633524 0.8662445 ]
time:  1.0
fall: 0.0
qpos: [1.0002903 0.9693727 0.8854053]
time:  1.0
fall: 0.0
qpos: [0.9949942  0.97356105 0.9040123 ]
time:  1.0
fall: 0.0
qpos: [0.9917163 0.9761592 0.9221831]
time:  1.0
fall: 0.0
qpos: [0.9890413  0.97828275 0.9399611 ]
time:  1.0
fall: 0.0
qpos: [0.9868898  0.98127824 0.9574367 ]
time:  1.0
fall: 0.0
qpos: [0.9846418 0.9847676 0.9745837]
time:  1.0
fall: 0.0
qpos: [0.98147845 0.9875302  0.9914959 ]
time:  1.0
fall: 0.0
qpos: [0.97794145 0.99059635 1.0079379 ]
time:  1.0
fall: 0.0
qpos: [0.9750581  0.99493784 1.0238324 ]
time:  1.0
fall: 0.0

In [71]:
media.show_video(env.render(rollout, camera='back'), fps=1.0/env.dt)

0
This browser does not support the video tag.


In [100]:
state.pipeline_state.xpos

Array([[0.        , 0.        , 0.        ],
       [1.0449276 , 0.9201753 , 0.762677  ],
       [1.0468391 , 0.9241122 , 0.99878746],
       [1.0308832 , 1.004719  , 1.2070577 ],
       [1.2122096 , 1.0383474 , 1.229011  ],
       [1.4756391 , 1.1030678 , 1.272881  ],
       [0.8589965 , 0.9421486 , 1.2391825 ],
       [0.73543   , 0.7370127 , 1.1044271 ],
       [1.1253372 , 0.9473576 , 0.7615728 ],
       [1.2434226 , 1.2149004 , 0.45796478],
       [1.2533716 , 1.0272739 , 0.09369731],
       [0.96451795, 0.89299303, 0.76378125],
       [0.75786364, 0.86453366, 0.39746842],
       [0.75479364, 0.6369034 , 0.05663341]], dtype=float32)

In [138]:
state.pipeline_state.x.rot

Array([[-0.9988232 , -0.0112808 , -0.04714939, -0.00136157],
       [-0.9760239 ,  0.00921053, -0.21670713,  0.01817523],
       [-0.954311  ,  0.05125172, -0.01494027, -0.29400754],
       [-0.7900792 ,  0.11564598, -0.52131987, -0.30104205],
       [-0.90522087, -0.15832108,  0.27672157, -0.2809533 ],
       [-0.9292731 , -0.30408332, -0.00979375,  0.20949635],
       [-0.8230528 , -0.16857538,  0.43154195,  0.32853878],
       [-0.9441449 ,  0.07370024,  0.21859391,  0.23531947],
       [-0.89839023, -0.07485342, -0.36343905,  0.23495519],
       [-0.90135527,  0.1235581 , -0.23688422,  0.34084862],
       [-0.90527534, -0.10094515,  0.3854131 , -0.14745612],
       [-0.97898144, -0.01743659, -0.09829593, -0.177846  ],
       [-0.97810173, -0.02084978, -0.10072432, -0.18093319]],      dtype=float32)

Testing how to get the global rotations

In [176]:


# Remove the first character ('f') and convert the rest to a list of integers
link_types_numbers = [int(char) for char in env.sys.link_types[1:]]

# Convert the list of integers to a JAX array
link_types_array_without_root = jp.array(link_types_numbers)


In [107]:
env.sys.jnt_axis

Array([[ 0.,  0.,  1.],
       [ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.],
       [ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.],
       [ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.],
       [ 0., -1.,  0.],
       [ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.],
       [ 0., -1.,  0.],
       [ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.],
       [ 0., -1.,  0.],
       [ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.],
       [ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.],
       [ 0., -1.,  0.],
       [ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.]], dtype=float32)

In [31]:
#save the root qpos
current_pos_rot = state.pipeline_state.qpos[0:3]
current_root_rot = state.pipeline_state.qpos[3:7]
current_joints = state.pipeline_state.qpos[7:]

print(state.pipeline_state.qpos)



[ 6.97175786e-03  5.95309353e-03  7.52526999e-01 -9.98823225e-01
 -1.12807984e-02 -4.71493937e-02 -1.36156718e-03 -3.17845419e-02
  3.42894375e-01 -3.98943871e-02 -8.93776119e-02 -4.40016717e-01
  5.91198862e-01 -6.19465292e-01  5.02453387e-01  7.96520531e-01
  1.75985932e+00  7.12900817e-01 -4.34377730e-01 -1.01669244e-01
  9.86902952e-01 -3.00477624e-01 -4.77567464e-01 -5.72591901e-01
 -1.22387028e+00 -2.82287776e-01 -2.97072321e-01 -3.64673197e-01
  4.22074705e-01 -8.18987012e-01  5.31627297e-01 -1.00513864e+00
  6.95715472e-03  3.84540367e-03  6.93053333e-03]


In [32]:
#now get the indices for the one joint index


joints_one_dofs_indices=jp.array([env.right_elbow_joint,env.left_elbow_joint,env.right_knee_joint,env.left_knee_joint])
print(joints_one_dofs_indices)


[ 9 13 17 24]


In [35]:
def axis_angle_to_quat(axis: jax.Array, angle: jax.Array) -> jax.Array:
  """Provides a quaternion that describes rotating around axis by angle.

  Args:
    axis: (3,) axis (x,y,z)
    angle: () float angle to rotate by

  Returns:
    A quaternion that rotates around axis by angle
  """
  s, c = jp.sin(angle * 0.5), jp.cos(angle * 0.5)
  return jp.insert(axis * s, 0, c)


In [36]:
def quat_mul(u: jax.Array, v: jax.Array) -> jax.Array:
  """Multiplies two quaternions.

  Args:
    u: (4,) quaternion (w,x,y,z)
    v: (4,) quaternion (w,x,y,z)

  Returns:
    A quaternion u * v.
  """
  return jp.array([
      u[0] * v[0] - u[1] * v[1] - u[2] * v[2] - u[3] * v[3],
      u[0] * v[1] + u[1] * v[0] + u[2] * v[3] - u[3] * v[2],
      u[0] * v[2] - u[1] * v[3] + u[2] * v[0] + u[3] * v[1],
      u[0] * v[3] + u[1] * v[2] - u[2] * v[1] + u[3] * v[0],
  ])


In [40]:
#this works properly
right_elbow = current_joints[9]
right_elbow_axis = env.sys.jnt_axis[10]

right_elbow_qloc = axis_angle_to_quat(right_elbow_axis,right_elbow )

quat = quat_mul(state.pipeline_state.x.rot[3], right_elbow_qloc)


quat




Array([-0.90522087, -0.15832108,  0.27672157, -0.2809533 ], dtype=float32)

In [113]:
def compute_quaternion_for_joint(joint_axis,joint_angle):
    
    joint_quat = axis_angle_to_quat(joint_axis,joint_angle)
    return joint_quat

In [150]:

def combine_quaterions_joint_3DOFS(quats):
    # Combine rotations in XYZ order, ensure the multiplication reflects this
    #first element is w the real number
    combined_quat = quat_mul(quat_mul(quats[0], quats[1]),quats[2])
    return combined_quat


In [235]:
#now do the same but with vmap
#jnt_axis without free joint
axis_hinge = env.sys.jnt_axis[1:]

vmap_compute_quaternion = jax.vmap(compute_quaternion_for_joint, in_axes=(0, 0))

hinge_quat = vmap_compute_quaternion(axis_hinge,current_joints) 

print(hinge_quat)
print(hinge_quat.shape)

[[ 0.9998737  -0.0158916  -0.         -0.        ]
 [ 0.98533887  0.          0.17060849  0.        ]
 [ 0.99980104 -0.         -0.         -0.01994587]
 [ 0.9990016  -0.04467393 -0.         -0.        ]
 [ 0.97589564 -0.         -0.21823777 -0.        ]
 [ 0.95662767  0.          0.          0.29131332]
 [ 0.9524151  -0.30480403 -0.         -0.        ]
 [ 0.9686082   0.          0.24859233  0.        ]
 [ 0.9217371   0.          0.          0.38781536]
 [ 0.63720536  0.         -0.7706941   0.        ]
 [ 0.93714136  0.34894997  0.          0.        ]
 [ 0.97650707 -0.         -0.21548538 -0.        ]
 [ 0.9987082  -0.         -0.         -0.05081273]
 [ 0.8807033   0.         -0.47366843  0.        ]
 [ 0.9887354  -0.14967425 -0.         -0.        ]
 [ 0.97162634 -0.         -0.23652105 -0.        ]
 [ 0.95929646 -0.         -0.         -0.28240088]
 [ 0.8185379  -0.          0.5744525  -0.        ]
 [ 0.99005574 -0.14067572 -0.         -0.        ]
 [ 0.98898876 -0.         -0.14

In [234]:
# Mask to filter out the quaternions that are not to be combined in triples
mask = jp.ones(hinge_quat.shape[0], dtype=bool)
mask = mask.at[joints_one_dofs_indices].set(False)



#index 0 since is a tuple and we want the indices that are index 0
combinable_indices = jp.where(mask)[0]
non_combinable_indices = jp.where(~mask)[0]

no_grouped_quaterions = hinge_quat[non_combinable_indices]


#select and store these indices
# Assuming the remaining quaternions are multiple of three
# Reshape the array to (-1, 3, 4) where 3 is the number of quaternions to be combined and 4 is the quaternion dimension
grouped_quaternions = hinge_quat[combinable_indices].reshape(-1, 3, 4)
#this will be applied on the first axs
vmap_combined_quat = jax.vmap(combine_quaterions_joint_3DOFS)

quat_combined = vmap_combined_quat(grouped_quaternions)

#there are 13 links -12, since we will merge the root at the end the shape 1 is 4 for the quat
quat_loc_all_joints = jp.zeros((state.pipeline_state.x.rot.shape[0]-1,state.pipeline_state.x.rot.shape[1]))

# Create a mask where each position is True if the corresponding link type is 3
link_types_mask = link_types_array_without_root == 3

filter_out_jnt_type_idx = jp.where(link_types_mask)

filter_out_jnt_type_one_dofs = jp.where(~link_types_mask)

quat_loc_all_joints = quat_loc_all_joints.at[filter_out_jnt_type_idx].set(quat_combined)

quat_loc_all_joints = quat_loc_all_joints.at[filter_out_jnt_type_one_dofs].set(no_grouped_quaterions)

quat_loc_all_joints = jp.concatenate([current_root_rot.reshape(1,-1),quat_loc_all_joints],axis=0)


print(quat_loc_all_joints)

# #testing it it works the global chest

# global_quat_chest = quat_mul(state.pipeline_state.x.rot[0],quat_loc_all_joints[1])

# print(global_quat_chest)
#now to 6D rotatio
sixD_matrix = quaternion_to_rotation_6d(quat_loc_all_joints)

print(sixD_matrix)
print(sixD_matrix.shape)

[[-0.9988232  -0.0112808  -0.04714939 -0.00136157]
 [ 0.9849643  -0.019058    0.17024067 -0.02236166]
 [ 0.9297966  -0.10521828 -0.19586344  0.29333425]
 [ 0.8797037  -0.18030933  0.33273026  0.28792447]
 [ 0.63720536  0.         -0.7706941   0.        ]
 [ 0.91012216  0.3505731  -0.18436486 -0.12159649]
 [ 0.8807033   0.         -0.47366843  0.        ]
 [ 0.9315755  -0.0734667  -0.26540676 -0.23733708]
 [ 0.8185379  -0.          0.5744525  -0.        ]
 [ 0.9666973  -0.11025239 -0.16931759 -0.15707439]
 [ 0.88737756  0.0831369  -0.4261168   0.15515728]
 [ 0.8763479  -0.          0.48167875 -0.        ]
 [ 0.99998605  0.00348521  0.00191062  0.00347192]]
[[ 0.99555016 -0.00165616  0.09421854  0.0037837   0.9997418  -0.02240665]
 [ 0.9410361   0.03756199  0.3362144  -0.05053978  0.9982735   0.02992918]
 [ 0.75118506 -0.5042655  -0.42595455  0.5866992   0.80576825  0.08075629]
 [ 0.6127801  -0.6265652   0.48157722  0.38658774  0.76917607  0.50883996]
 [-0.18793869 -0.         -0.9821808

In [148]:
def update_rotation(local_angle, local_axis):
    # Convert the joint angle and axis to a quaternion
    local_rotation = axis_angle_to_quat(local_axis, local_angle)
    # Apply the local rotation to the parent rotation
    return local_rotation  

# Compute local rotations correctly applying local transformation first
chest_x_qloc = update_rotation(current_joints[0], env.sys.jnt_axis[1])
chest_y_qloc = update_rotation(current_joints[1], env.sys.jnt_axis[2])
chest_z_qloc = update_rotation(current_joints[2], env.sys.jnt_axis[3])

# Combine rotations in XYZ order, ensure the multiplication reflects this
combined_quat = quat_mul(quat_mul(chest_x_qloc, chest_y_qloc), chest_z_qloc)


print(combined_quat)


global_quat_chest = quat_mul(state.pipeline_state.x.rot[0],combined_quat)
print(global_quat_chest)

[ 0.9849643  -0.019058    0.17024067 -0.02236166]
[-0.976024    0.00921053 -0.21670713  0.01817523]


0.5384615384615384
