# Extract data corresponding to text caption

## 1. Text captions are stored in the 'texts' directory, unzipped from the 'texts.zip' file. It should look like:

texts\
├── 000000.txt\
├── 000001.txt\
├── ...\
├── 014615.txt\
├── M000000.txt\
├── M000001.txt\
├── ...\
└── M014615.txt\

where for example 000000.txt should contain:

```
a man kicks something or someone with his left leg.#a/DET man/NOUN kick/VERB something/PRON or/CCONJ someone/PRON with/ADP his/DET left/ADJ leg/NOUN#0.0#0.0
the standing person kicks with their left foot before going back to their original stance.#the/DET stand/VERB person/NOUN kick/VERB with/ADP their/DET left/ADJ foot/NOUN before/ADP go/VERB back/ADV to/ADP their/DET original/ADJ stance/NOUN#0.0#0.0
a man kicks with something or someone with his left leg.#a/DET man/NOUN kick/VERB with/ADP something/PRON or/CCONJ someone/PRON with/ADP his/DET left/ADJ leg/NOUN#0.0#0.0
he is flying kick with his left leg#he/PRON is/AUX fly/VERB kick/NOUN with/ADP his/DET left/ADJ leg/NOUN#0.0#0.0
```

and M000000.txt should contain:

```
a man kicks something or someone with his right leg.#a/DET man/NOUN kick/VERB something/PRON or/CCONJ someone/PRON with/ADP his/DET right/ADJ leg/NOUN#0.0#0.0
the standing person kicks with their right foot before going back to their original stance.#the/DET stand/VERB person/NOUN kick/VERB with/ADP their/DET right/ADJ foot/NOUN before/ADP go/VERB back/ADV to/ADP their/DET original/ADJ stance/NOUN#0.0#0.0
a man kicks with something or someone with his right leg.#a/DET man/NOUN kick/VERB with/ADP something/PRON or/CCONJ someone/PRON with/ADP his/DET right/ADJ leg/NOUN#0.0#0.0
he is flying kick with his right leg#he/PRON is/AUX fly/VERB kick/NOUN with/ADP his/DET right/ADJ leg/NOUN#0.0#0.0
```

## 2. By default, we assume the AMASS dataset is located in the directory 'data/AMASS/AMASS_Complete', which should look like the following under 'data':


AMASS\
└─&nbsp;AMASS_Complete\
&nbsp;&nbsp;&nbsp;&nbsp;├─&nbsp;ACCAD\
&nbsp;&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;├─&nbsp;Female1General_c3d\
&nbsp;&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;├─&nbsp;A1 - Stand_poses.npz\
&nbsp;&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;├─&nbsp;...\
&nbsp;&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;└─&nbsp;A15 - skip to stand_poses.npz\
&nbsp;&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;├─&nbsp;...\
&nbsp;&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;└─&nbsp;s011\
&nbsp;&nbsp;&nbsp;&nbsp;├─&nbsp;BioMotionLab_NTroje\
&nbsp;&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;└─&nbsp;...\
&nbsp;&nbsp;&nbsp;&nbsp;├─&nbsp;...\
&nbsp;&nbsp;&nbsp;&nbsp;├─&nbsp;Transitions_mocap\
&nbsp;&nbsp;&nbsp;&nbsp;│&nbsp;&nbsp;&nbsp;└─&nbsp;...\
&nbsp;&nbsp;&nbsp;&nbsp;└─&nbsp;LICENSE.txt

In [None]:
import pandas as pd
from tqdm import tqdm
import numpy as np
# Load the csv file that maps each HumanML3D motion to its AMASS source file
# and its start/end frame.
index_path = '../../resources/index.csv'
index_file = pd.read_csv(index_path)

# Maximum number of index rows to process (kept at the previous debug limit of
# 100). Set to None to process the whole index file.
max_index_rows = 100
num_rows = len(index_file) if max_index_rows is None else min(max_index_rows, len(index_file))

# Read data in the order of the index file
humanml3d_data = {}
for i in tqdm(range(num_rows)):
    row = index_file.loc[i]
    source_path = row['source_path']

    # Skip the data of the humanact12 dataset, since it is not contained in the
    # standard AMASS dataset. Checked before splitting the path because those
    # paths may not have enough components for the indexing below.
    if 'humanact12' in source_path:
        continue

    # TODO: set the dataset_name_index in the path according to the actual path in the 'index.csv' file
    dataset_name_index = 4
    names_in_path = source_path.split('/')
    dataset_name = names_in_path[dataset_name_index]
    subject_name = names_in_path[dataset_name_index + 1]
    action_name = names_in_path[dataset_name_index + 2]
    # action_name[:-4] drops the 4-character file extension (e.g. ".npz").
    # This name should correspond to that in "amass_copycat_occlusion_v3.pkl"
    # except for the '0-' prefix.
    save_name = f"{dataset_name}_{subject_name}_{action_name[:-4]}"

    # Load the archive once and close its file handle; dict() reads every array eagerly.
    with np.load(f"../../{source_path}") as npz:
        data = dict(npz)
    data['start_frame'] = row['start_frame']
    data['end_frame'] = row['end_frame']
    data['new_name'] = row['new_name']

    humanml3d_data[save_name] = data

# Convert the HumanML3D dataset into a training dataset

## 1. 'data/occlusion/amass_copycat_occlusion_v3.pkl' is available [here](https://drive.google.com/uc?id=1uzFkT2s_zVdnAohPWHOLFcyRDq372Fmc).

In [2]:
import joblib
import torch
import torch.nn.functional as F
from torch import Tensor
from phc.smpllib.smpl_parser import SMPL_Parser
from phc.utils.transform_utils import convert_aa_to_orth6d
from phc.utils.flags import flags


# Sequences listed in this file are treated as occluded and excluded later.
occlusion_path = '../../data/occlusion/amass_copycat_occlusion_v3.pkl'
occlusion_data = joblib.load(occlusion_path)

# One SMPL parser per gender, all built with use_pca=False and create_transl=False.
smpl_model_dir = "../../data/smpl"
smpl_parser_n, smpl_parser_m, smpl_parser_f = (
    SMPL_Parser(model_path=smpl_model_dir, gender=g, use_pca=False, create_transl=False)
    for g in ("neutral", "male", "female")
)

# Output frame rate (Hz) of the processed sequences.
target_frequency = 30
flags.debug = False

# Frame rate (Hz) used to rescale the HumanML3D start/end frame indices.
humanml3d_frequency = 20

# Per-dataset offsets (given in seconds, stored in frames at target_frequency)
# added to the HumanML3D start/end frames of matching sequences.
_skip_seconds = {
    'Eyes_Japan_Dataset': 3,
    'MPI_HDM05': 3,
    'TotalCapture': 1,
    'MPI_Limits': 1,
    'Transitions_mocap': 0.5,
}
skip_frames = {dataset: int(seconds * target_frequency) for dataset, seconds in _skip_seconds.items()}

In [3]:

# functions to process data
def fix_height_smpl_vanilla(pose_aa, th_trans, th_betas, gender):
    """
    Shift a translation sequence vertically so the body rests on the ground plane z = 0.

    Only the first frame is evaluated: its lowest SMPL vertex height is
    subtracted from the z-translation of every frame. The shift is applied
    unconditionally (both when the body is below and above the ground), so
    afterwards the lowest vertex of frame 0 lies exactly at z = 0. No other
    filtering is done.

    Args:
        pose_aa: axis-angle poses indexed by frame along dim 0.
        th_trans: root translations indexed by frame along dim 0, z in column 2.
            Modified in place and also returned.
        th_betas: body shape coefficients passed to the SMPL parser.
        gender: "neutral", "male" or "female", as str, bytes, or a 0-d numpy array.

    Returns:
        The (in-place modified) ``th_trans``.

    Raises:
        Exception: if ``gender`` is not one of the supported values.
    """
    # gender may arrive as a 0-d numpy array (e.g. read from an .npz), possibly holding bytes.
    if isinstance(gender, np.ndarray):
        gender = gender.item()
    if isinstance(gender, bytes):
        gender = gender.decode("utf-8")

    # NOTE(review): these are module globals; smpl_parser_n is rebound by a later
    # cell, so re-running this after that cell uses the re-created parser.
    parsers = {"neutral": smpl_parser_n, "male": smpl_parser_m, "female": smpl_parser_f}
    if gender not in parsers:
        raise Exception(f"Gender Not Supported!! (got {gender!r})")
    smpl_parser = parsers[gender]

    # Only the first frame is used to find the ground offset.
    verts, _ = smpl_parser.get_joints_verts(pose_aa[0:1], th_betas.repeat((1, 1)), th_trans=th_trans[0:1])
    ground_position = torch.min(verts[:, :, 2])

    th_trans[:, 2] -= ground_position
    return th_trans

def process_data_dict(data_dict):
    """
    Convert loaded AMASS/HumanML3D entries into per-sequence training samples.

    For every entry of ``data_dict`` the sequence is:
      1. dropped if "0-" + key is listed in ``occlusion_data``;
      2. downsampled from ``mocap_framerate`` towards ``target_frequency`` by an
         integer frame stride (at least 1);
      3. cropped to the HumanML3D [start_frame, end_frame) window, rescaled from
         ``humanml3d_frequency`` to ``target_frequency`` and shifted by the
         first matching offset in ``skip_frames``;
      4. dropped if shorter than 10 frames;
      5. reduced to the first 66 pose values plus 6 zeros (SMPL instead of
         SMPL-H hands), height-corrected with ``fix_height_smpl_vanilla`` and
         converted to the 6D rotation representation.

    Args:
        data_dict: mapping name -> dict with keys "new_name", "betas", "gender",
            "mocap_framerate", "poses", "trans", "start_frame", "end_frame".

    Returns:
        dict keyed by each entry's "new_name", with "pose_aa", "pose_6d",
        "trans", "beta", "seq_name" and "gender". The names of dropped
        sequences are printed.
    """
    amass_res = {}
    removed_k = []  # names of sequences dropped as occluded or too short
    for key, v in tqdm(data_dict.items()):
        seq_name = "0-" + key  # naming convention of the occlusion file
        if seq_name in occlusion_data:
            removed_k.append(seq_name)
            continue

        new_name = v["new_name"]
        betas = v["betas"]
        gender = v["gender"]

        # Integer stride downsampling; never let the stride reach 0 (which would
        # make the slice raise) when the source framerate is below the target.
        skip = max(1, int(v["mocap_framerate"] / target_frequency))
        amass_pose = v["poses"][::skip]
        amass_trans = v["trans"][::skip]

        # Segment the sequence according to the (rescaled) start and end frame.
        start_frame = int(v["start_frame"] * target_frequency / humanml3d_frequency)
        end_frame = int(v["end_frame"] * target_frequency / humanml3d_frequency)
        for dataset_key, offset in skip_frames.items():  # skip the first few frames if needed
            if dataset_key in seq_name:
                start_frame += offset
                end_frame += offset
                break
        amass_pose = amass_pose[start_frame:end_frame]
        amass_trans = amass_trans[start_frame:end_frame]

        # If the sequence is too short, we skip it.
        seq_length = amass_pose.shape[0]
        if seq_length < 10:
            removed_k.append(seq_name)
            continue

        with torch.no_grad():
            # We use SMPL and not SMPL-H: keep the first 66 values, zero the hand joints.
            amass_pose = np.concatenate([amass_pose[:, :66], np.zeros((seq_length, 6))], axis=1)

            pose_aa = torch.tensor(amass_pose)
            trans = torch.tensor(amass_trans)
            betas_t = torch.from_numpy(betas)

            trans = fix_height_smpl_vanilla(
                pose_aa=pose_aa,
                th_betas=betas_t,
                th_trans=trans,
                gender=gender,
            )

            # pose_aa is already a tensor: copy it explicitly instead of
            # torch.tensor(tensor), which copies too but emits a UserWarning.
            pose_seq_6d = convert_aa_to_orth6d(pose_aa.clone()).reshape(seq_length, -1, 6)

            amass_res[new_name] = {
                "pose_aa": pose_aa.numpy(),
                "pose_6d": pose_seq_6d.numpy(),
                "trans": trans.numpy(),
                "beta": betas_t.numpy(),
                "seq_name": seq_name,
                "gender": gender,
            }

        if flags.debug and len(amass_res) > 10:
            break
    print(removed_k)
    return amass_res


In [None]:
# Turn the loaded HumanML3D/AMASS entries into processed training sequences.
train_data = process_data_dict(data_dict=humanml3d_data)

# Retarget the training data to match the robot structure

## 1. 'data/g1/optimized_shape_scale_g1.pkl' is generated by fit_robot_shape_g1.py

In [None]:

from copy import deepcopy
import os
import sys
from turtle import left

from sympy import jacobi
sys.path.append(os.getcwd())
from networkx import dorogovtsev_goltsev_mendes_graph
from phc.utils.torch_g1_humanoid_batch import Humanoid_Batch, G1_ROTATION_AXIS
from scipy.spatial.transform import Rotation as sRot
from torch.autograd import Variable
from phc.smpllib.smpl_parser import (
    SMPL_Parser,
    SMPL_BONE_ORDER_NAMES, 
)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda", index=0) if torch.cuda.is_available() else torch.device("cpu")

# SMPL shape and scale fitted to the robot (produced by fit_robot_shape_g1.py).
optimized_shape_scale = joblib.load("../../data/g1/optimized_shape_scale_g1.pkl")
shape_new, optimized_scale = (optimized_shape_scale[field].to(device) for field in ("shape", "scale"))

# NOTE: this rebinds the neutral parser created in an earlier cell (that one was
# built with use_pca=False, create_transl=False and never moved to `device`).
smpl_parser_n = SMPL_Parser(model_path="../../data/smpl", gender="neutral")
smpl_parser_n.to(device)

# Batched forward kinematics of the robot.
robot_fk = Humanoid_Batch(device=device)

# Per-DOF rotation axes; multiplied by the DOF angles to build axis-angle poses.
robot_rotation_axis = G1_ROTATION_AXIS.to(device)


# Limb segment names shared by the link and joint name lists; the order of
# both lists below must match the robot model's link/DOF order.
_LEG_PARTS = ("hip_pitch", "hip_roll", "hip_yaw", "knee", "ankle_pitch", "ankle_roll")
_ARM_PARTS = ("shoulder_pitch", "shoulder_roll", "shoulder_yaw", "elbow",
              "wrist_roll", "wrist_pitch", "wrist_yaw")
_SIDES = ("left", "right")


def _limb_names(parts, suffix):
    """Return "<side>_<part>_<suffix>" for the left limb followed by the right limb."""
    return [f"{side}_{part}_{suffix}" for side in _SIDES for part in parts]


# 30 links: root, legs, waist/torso chain, arms.
robot_link_names = (
    ["pelvis"]
    + _limb_names(_LEG_PARTS, "link")
    + ["waist_yaw_link", "waist_roll_link", "torso_link"]
    + _limb_names(_ARM_PARTS, "link")
)

# 29 actuated joints: legs, waist, arms (no root).
robot_joint_names = (
    _limb_names(_LEG_PARTS, "joint")
    + ["waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint"]
    + _limb_names(_ARM_PARTS, "joint")
)

# (robot link, SMPL joint) keypoint correspondences used as IK position targets.
_keypoint_pairs = [
    ("pelvis", "Pelvis"),
    ("left_hip_pitch_link", "L_Hip"),
    ("left_knee_link", "L_Knee"),
    ("left_ankle_roll_link", "L_Ankle"),
    ("right_hip_pitch_link", "R_Hip"),
    ("right_knee_link", "R_Knee"),
    ("right_ankle_roll_link", "R_Ankle"),
    ("left_shoulder_roll_link", "L_Shoulder"),
    ("left_elbow_link", "L_Elbow"),
    ("left_wrist_yaw_link", "L_Wrist"),
    ("right_shoulder_roll_link", "R_Shoulder"),
    ("right_elbow_link", "R_Elbow"),
    ("right_wrist_yaw_link", "R_Wrist"),
]
robot_link_pick = [robot_link for robot_link, _ in _keypoint_pairs]
smpl_link_pick = [smpl_joint for _, smpl_joint in _keypoint_pairs]

# Joints forced back to 0 after every accepted IK update.
locked_joints = [
    "waist_roll_joint",
    "waist_pitch_joint",
    # "left_wrist_pitch_joint",
    # "right_wrist_pitch_joint",
]

# Wrist joints whose driven (child) links serve as the robot hands for the
# orientation targets.
hands_link = ["left_wrist_yaw_joint", "right_wrist_yaw_joint"]

robot_link_pick_idx = [robot_link_names.index(j) for j in robot_link_pick]
smpl_link_pick_idx = [SMPL_BONE_ORDER_NAMES.index(j) for j in smpl_link_pick]
locked_joints_idx = [robot_joint_names.index(j) for j in locked_joints]
# hands_link_idx indexes the *link* axis of the FK output (ordered like
# robot_link_names, i.e. with 'pelvis' first), so it is looked up in
# robot_link_names via each joint's child link. Using the joint index directly
# is off by one and selects the wrist *pitch* links instead of the hands.
hands_link_idx = [robot_link_names.index(j.replace("_joint", "_link")) for j in hands_link]

# Name -> position lookup tables for robot joints, robot links and SMPL bones.
dict_joint_name_index = {joint: pos for pos, joint in enumerate(robot_joint_names)}
dict_link_name_index = {link: pos for pos, link in enumerate(robot_link_names)}
dict_smpl_link_name_index = {bone: pos for pos, bone in enumerate(SMPL_BONE_ORDER_NAMES)}

# Selected subsets of links and joints.
# NOTE(review): not referenced in the visible part of this file.
list_selected_links = ["pelvis"] + [
    f"{side}_{part}"
    for side in ("left", "right")
    for part in ("shoulder_roll_link", "elbow_link", "wrist_yaw_link", "ankle_roll_link")
]
list_selected_joints = ["waist_yaw_joint"] + [
    f"{side}_{part}_joint"
    for side in ("left", "right")
    for part in ("shoulder_pitch", "shoulder_roll", "shoulder_yaw", "elbow",
                 "wrist_roll", "wrist_pitch", "wrist_yaw")
]

# Optional initial DOF values applied before IK, keyed by joint name
# (e.g. {"left_elbow_joint": 1.57}). Empty: every DOF starts at 0.
initial_joint_position_dict = {}

# helper functions

# get hand orientation
def compute_hand_global_orientations(smpl_parser: SMPL_Parser, pose: torch.Tensor) -> tuple:
    """
    Return the global rotation matrices of the SMPL left and right hands, re-expressed in the robot hand frame.

    Args:
        smpl_parser (SMPL_Parser): parser providing ``get_global_orientations``.
        pose (torch.Tensor): shape (batch_size, J*3)

    Returns:
        left_hand_rotmat (torch.Tensor): global rotation matrix of the left hand, shape (batch_size, 3, 3)
        right_hand_rotmat (torch.Tensor): global rotation matrix of the right hand, shape (batch_size, 3, 3)
    """
    global_rotmats = smpl_parser.get_global_orientations(pose)  # (B, J, 3, 3)

    # Constant right-multiplied corrections that align each SMPL hand frame with
    # the corresponding robot hand frame (both matrices have determinant +1).
    align_left = torch.tensor([[0, 1, 0], [1, 0, 0], [0, 0, -1]], device=device).float()
    align_right = torch.tensor([[0, -1, 0], [-1, 0, 0], [0, 0, -1]], device=device).float()

    left_hand = global_rotmats[:, SMPL_BONE_ORDER_NAMES.index("L_Hand")] @ align_left
    right_hand = global_rotmats[:, SMPL_BONE_ORDER_NAMES.index("R_Hand")] @ align_right
    return left_hand, right_hand

def vee(skew_tensor):
    """
    Map batched skew-symmetric 3x3 matrices to their 3-vectors (inverse of the hat map).

    For [[0, -c, b], [c, 0, -a], [-b, a, 0]] the result is (a, b, c).

    Args:
        skew_tensor (torch.Tensor): [batch_size, num_links, 3, 3]

    Returns:
        torch.Tensor: [batch_size, num_links, 3]
    """
    # (row, col) entries that hold the x, y and z components.
    components = ((2, 1), (0, 2), (1, 0))
    return torch.stack([skew_tensor[:, :, row, col] for row, col in components], dim=2)

# retarget data
def compute_batch_jacobian(var_dof_pose, const_root_trans, const_root_pose, epsilon=1e-4):
    """
    Compute finite-difference Jacobians of the picked link positions and the hand orientations.

    Each DOF is perturbed by ``epsilon`` (for all frames at once), and one extra
    forward-kinematics pass is run per DOF (forward differences).

    Args:
        var_dof_pose (torch.tensor): [1, num_frames, len(robot_joint_names), 1]
        const_root_trans (torch.tensor): [num_frames, 3]
        const_root_pose (torch.tensor): [num_frames, 3]
        epsilon (float, optional): Perturbation size. Defaults to 1e-4.

    Returns:
        tuple:
            - jacobian_positions (torch.tensor): [num_frames, 3 * len(robot_link_pick_idx), num_joints]
            - jacobian_orientations (torch.tensor): [num_frames, 6, num_joints]; rows 0-2 belong
              to the first entry of hands_link_idx and rows 3-5 to the second. The sign matches
              the orientation residual vee(0.5 * (R_d^T R - R^T R_d)) used in compute_batch_diff.
    """
    _, num_frames, num_joints, _ = var_dof_pose.shape
    num_keypoints = len(robot_link_pick_idx)
    root_pose = const_root_pose[None, :, None]
    root_trans = const_root_trans[None, :]

    def _forward(dof_pose):
        # FK for the given DOFs -> (picked link positions [1, N, L, 3], hand rotations [1, N, 2, 3, 3]).
        pose_aa = torch.cat([root_pose, robot_rotation_axis * dof_pose], dim=2)
        fk = robot_fk.fk_batch(pose_aa, root_trans)
        return (fk['global_translation'][:, :, robot_link_pick_idx],
                fk['global_rotation_mat'][:, :, hands_link_idx])

    base_positions, base_rotations = _forward(var_dof_pose)

    jacobian_positions = torch.zeros(num_frames, 3 * num_keypoints, num_joints, device=device)
    jacobian_orientations = torch.zeros(num_frames, 6, num_joints, device=device)

    for j in range(num_joints):
        # Perturb joint j by epsilon in every frame.
        perturbed_dof = var_dof_pose.clone()
        perturbed_dof[0, :, j, 0] += epsilon
        perturbed_positions, perturbed_rotations = _forward(perturbed_dof)

        # Position differences
        delta_positions = (perturbed_positions - base_positions) / epsilon  # [1, N, L, 3]
        jacobian_positions[:, :, j] = delta_positions[0].reshape(num_frames, 3 * num_keypoints)

        # Orientation differences. With R_p = R_b exp(eps [w]x):
        #   0.5 * (R_b^T R_p - R_p^T R_b) / eps ~= [w]x,
        # which is how the residual of compute_batch_diff changes under the same
        # perturbation. (The reversed operand order yields -w, i.e. a Jacobian
        # with the wrong sign for that residual.)
        delta_rot = 0.5 * (base_rotations.transpose(3, 4) @ perturbed_rotations
                           - perturbed_rotations.transpose(3, 4) @ base_rotations) / epsilon  # [1, N, 2, 3, 3]
        jacobian_orientations[:, :, j] = vee(delta_rot.squeeze(0)).reshape(num_frames, 6)  # [N, 2, 3] -> [N, 6]

    return jacobian_positions, jacobian_orientations
# -----------------------------------
def compute_batch_diff(const_dof_pose, const_root_trans, const_root_pose, const_smpl_positions,
                       left_hand_rot_desired, right_hand_rot_desired):
    """
    Compute the residuals between the robot FK and the retargeting targets.

    The position residual is robot link position minus target position. The
    orientation residual per hand is vee(0.5 * (R_d^T R - R^T R_d)), the
    skew-symmetric part of R_d^T R, which is ~e for R = R_d exp([e]x) with small e.

    Args:
        const_dof_pose (torch.tensor): [batch_size, num_frames, len(robot_joint_names), 1]
        const_root_trans (torch.tensor): [num_frames, 3]
        const_root_pose (torch.tensor): [num_frames, 3]
        const_smpl_positions (torch.tensor): [num_frames, len(smpl_link_pick_idx), 3]
        left_hand_rot_desired (torch.Tensor): [num_frames, 3, 3]
        right_hand_rot_desired (torch.Tensor): [num_frames, 3, 3]

    Returns:
        tuple:
            - positions_diff (torch.tensor): [num_frames, len(robot_link_pick_idx), 3] (first batch entry)
            - orientations_diff (torch.tensor): [num_frames, 6] (3 for each hand, left first)
    """
    # Forward kinematics
    pose_aa = torch.cat([const_root_pose[None, :, None], robot_rotation_axis * const_dof_pose], dim=2).to(device)
    fk_return = robot_fk.fk_batch(pose_aa, const_root_trans[None, ])

    # Position residuals of the picked links; keep the first batch entry.
    robot_link_positions = fk_return['global_translation'][:, :, robot_link_pick_idx]  # [B, N, L, 3]
    positions_diff = (robot_link_positions - const_smpl_positions)[0]  # [N, L, 3]

    # Actual hand orientations from FK.
    robot_hand_rot = fk_return['global_rotation_mat'][:, :, hands_link_idx]  # [B, N, 2, 3, 3]
    left_hand_rot = robot_hand_rot[:, :, 0]  # [B, N, 3, 3]
    right_hand_rot = robot_hand_rot[:, :, 1]  # [B, N, 3, 3]

    def _rotation_residual(actual, desired):
        # vee of the skew-symmetric part of desired^T @ actual -> [B, N, 3]
        desired = desired.unsqueeze(0)
        return vee(0.5 * (desired.transpose(2, 3) @ actual - actual.transpose(2, 3) @ desired))

    orientations_diff = torch.cat(
        [_rotation_residual(left_hand_rot, left_hand_rot_desired),
         _rotation_residual(right_hand_rot, right_hand_rot_desired)],
        dim=2,
    ).squeeze(0)  # [N, 6]

    return positions_diff, orientations_diff
# -----------------------------------
# Square identity sized to the robot DOF count, allocated once on `device`.
robot_joint_names_length = len(robot_joint_names)
identity_matrix = torch.eye(robot_joint_names_length, robot_joint_names_length, device=device)  # [num_joints, num_joints]

def compute_batch_LM_step(
    jacobian_mat,
    diff,
    const_last_dof_pos,
    lambda_val=1e-3,
    smooth_weight=1e-2
):
    """
    Compute the batch Levenberg-Marquardt step for updating DOF positions with smoothness constraints.

    The step s (the caller applies ``dof - s``) solves, jointly over all frames,

        (J^T J + lambda_val * I + smooth_weight * (L kron I)) s
            = J^T r + smooth_weight * (L kron I) p

    where r = ``diff``, p = ``const_last_dof_pos`` and L is the Laplacian of the
    frame chain (number of neighbours on the diagonal, -1 between consecutive
    frames). This is the damped Gauss-Newton step for

        1/2 sum_i ||r_i - J_i s_i||^2 + lambda_val/2 sum_i ||s_i||^2
        + smooth_weight/2 sum_{i>=1} ||(p_i - s_i) - (p_{i-1} - s_{i-1})||^2.

    A single frame has no neighbours and therefore no smoothness term.

    Args:
        jacobian_mat (torch.Tensor): [num_frames, dim_diff, num_joints]
        diff (torch.Tensor): [num_frames, dim_diff]
        const_last_dof_pos (torch.Tensor): [batch_size, num_frames, num_joints, 1]; only batch entry 0 is used.
        lambda_val (float, optional): Damping factor for LM inverse kinematic. Defaults to 1e-3.
        smooth_weight (float, optional): Weight for penalizing large changes in joint angles. Defaults to 1e-2.

    Returns:
        torch.Tensor: [num_frames, num_joints, 1]
    """
    num_frames, _, num_joints = jacobian_mat.shape
    dev, dtype = jacobian_mat.device, jacobian_mat.dtype
    eye = torch.eye(num_joints, device=dev, dtype=dtype)

    jacobian_t = jacobian_mat.transpose(1, 2)  # [N, J, D]
    JTJ = torch.bmm(jacobian_t, jacobian_mat)  # [N, J, J]

    # Temporal neighbours per frame: 1 at the ends, 2 in the middle, 0 for a single frame.
    neighbours = torch.zeros(num_frames, device=dev, dtype=dtype)
    neighbours[1:] += 1
    neighbours[:-1] += 1

    # Diagonal blocks: J^T J + damping + smoothness.
    diag_blocks = JTJ + lambda_val * eye + smooth_weight * neighbours[:, None, None] * eye

    # Block-tridiagonal Hessian, stored densely: [N*J, N*J].
    # NOTE(review): memory and solve cost grow with (N*J)^2 and (N*J)^3.
    H = torch.block_diag(*diag_blocks)
    off_diag = -smooth_weight * eye
    for i in range(1, num_frames):
        cur = slice(i * num_joints, (i + 1) * num_joints)
        prev = slice((i - 1) * num_joints, i * num_joints)
        H[cur, prev] = off_diag
        H[prev, cur] = off_diag

    # Data term J^T r: [N, J]
    JT_diff = torch.bmm(jacobian_t, diff.unsqueeze(-1)).squeeze(-1)

    # Smoothness term w * L p, written via consecutive-frame differences so it
    # is well defined for any number of frames (empty for a single frame).
    p = const_last_dof_pos[0].squeeze(-1)  # [N, J]
    frame_delta = p[1:] - p[:-1]  # [N-1, J]
    smooth_grad = torch.zeros_like(JT_diff)
    smooth_grad[1:] += smooth_weight * frame_delta
    smooth_grad[:-1] -= smooth_weight * frame_delta

    grad_flat = (JT_diff + smooth_grad).reshape(-1, 1)  # [N*J, 1]
    step_flat = torch.linalg.solve(H, grad_flat)  # [N*J, 1]
    return step_flat.reshape(num_frames, num_joints, 1)
# -----------------------------------
# Levenberg-Marquardt damping schedule.
initial_lambda = 1e-1
lambda_increase_factor = 10     # multiplies lambda when a proposed step increases the loss
lambda_decrease_factor = 0.5    # multiplies lambda when a proposed step is accepted
max_lambda, min_lambda = 1e3, 1e-2
# NOTE(review): the loop below passes a fixed lambda_val=1e-3 to
# compute_batch_LM_step, so this adaptive lambda does not affect the step.

# Weight of the temporal smoothness term that penalises joint-angle changes between frames.
smooth_weight = 2e-2

# Retargeted results, keyed by sequence name.
retarget_data = dict()
pbar = tqdm(train_data.keys())
for data_key in pbar:
    # translation
    trans = torch.from_numpy(train_data[data_key]['trans']).float().to(device)
    N = trans.shape[0]
    pose_aa_walk = torch.from_numpy(np.concatenate((train_data[data_key]['pose_aa'][:, :66], np.zeros((N, 6))), axis = -1)).float().to(device)
    # get joints, verts, and offset
    verts, joints = smpl_parser_n.get_joints_verts(pose_aa_walk, torch.zeros((1, 10)).to(device), trans)
    offset = joints[:, 0] - trans
    root_trans_offset = trans * optimized_scale + offset
    # get root rotation
    gt_root_rot = torch.from_numpy((sRot.from_rotvec(pose_aa_walk.cpu().numpy()[:, :3]) * sRot.from_quat([0.5, 0.5, 0.5, 0.5]).inv()).as_rotvec()).float().to(device)
    # prepare for iteration
    dof_pos = torch.zeros((1, N, 29, 1)).to(device)
    # set initial joint position
    for joint_name, joint_pos in initial_joint_position_dict.items():
        dof_pos[:, :, dict_joint_name_index[joint_name], 0] = joint_pos
    # iterational ik optimization
    last_loss = 1e10 # set last loss to be large enough for LM update
    lambda_val = initial_lambda
    # get target keypoint positions and hand orientations
    # positions
    verts, joints = smpl_parser_n.get_joints_verts(pose_aa_walk, shape_new, trans)
    joints = (joints - offset.unsqueeze(1)) * optimized_scale + offset.unsqueeze(1)
    # hand orientations
    left_hand_rotmat, right_hand_rotmat = compute_hand_global_orientations(smpl_parser_n, pose_aa_walk) # shape: [N, 3, 3]
    for iteration in range(5):
        # Compute Jacobians
        jacobian_pos, jacobian_ori = compute_batch_jacobian(dof_pos, root_trans_offset, gt_root_rot)
        # Compute differences
        pos_diff, ori_diff = compute_batch_diff(dof_pos, root_trans_offset, gt_root_rot, joints[:, smpl_link_pick_idx], 
                                               left_hand_rotmat, right_hand_rotmat)
        
        # Concatenate diffs and Jacobians
        diff = torch.cat([pos_diff.reshape(pos_diff.shape[0], -1), ori_diff], dim=1)  # [N, 3L + 6]
        jacobian = torch.cat([jacobian_pos, jacobian_ori], dim=1)  # [N, 3L + 6, num_joints]
        if iteration < 5:
            diff = pos_diff.reshape(pos_diff.shape[0], -1)  # [N, 3L]
            jacobian = jacobian_pos
        # diff = ori_diff  # [N, 6]
        # jacobian = jacobian_ori
        
        # Compute loss
        loss = diff.norm(dim=1).mean()  # You might want to weight position and orientation errors differently
        
        # Compute LM step
        step = compute_batch_LM_step(jacobian, diff, dof_pos, lambda_val=1e-3, smooth_weight=smooth_weight)
        
        # Update DOF positions
        propose_dof_pos = (dof_pos - step).clone().clamp_(robot_fk.joints_range[:, 0, None], robot_fk.joints_range[:, 1, None])
        pos_diff_updated, ori_diff_updated = compute_batch_diff(propose_dof_pos, root_trans_offset, gt_root_rot, joints[:, smpl_link_pick_idx], 
                                                                left_hand_rotmat, right_hand_rotmat)
        loss_updated = torch.cat([pos_diff_updated.reshape(pos_diff_updated.shape[0], -1), ori_diff_updated], dim=1).norm(dim=1).mean()
        
        if loss_updated > last_loss:
            lambda_val *= lambda_increase_factor
            lambda_val = min(lambda_val, max_lambda)
        else:
            lambda_val *= lambda_decrease_factor
            lambda_val = max(lambda_val, min_lambda)
            last_loss = loss_updated
            dof_pos = propose_dof_pos.clone()
            # dof_pos.data.clamp_(robot_fk.joints_range[:, 0, None], robot_fk.joints_range[:, 1, None])
            dof_pos[:, :, locked_joints_idx, 0] = 0.0
    # clamp after optimization
    dof_pos.data.clamp_(robot_fk.joints_range[:, 0, None], robot_fk.joints_range[:, 1, None])
    # get new pose and calculate fk
    pose_aa_new = torch.cat([gt_root_rot[None, :, None], robot_rotation_axis * dof_pos], axis = 2)
    fk_return = robot_fk.fk_batch(pose_aa_new, root_trans_offset[None, ])
    # save retargeted data
    root_trans_offset_dump = root_trans_offset.clone()
    global_translation = fk_return.global_translation.clone().squeeze(0)
    # decrese the height by height_correction
    height_correction = global_translation[..., 2].min().item() - 0.015    # 0.015 is the height of the robot foot
    root_trans_offset_dump[..., 2] -= height_correction
    global_translation[..., 2] -= height_correction
    mocap_global_translation = joints.clone()
    mocap_global_translation[..., 2] -= height_correction
    # read the captions
    def read_captions_from_file(file_path):
        with open(file_path, 'r') as f:
            lines = f.readlines()
        
        # Each line contains a caption with its annotations, split by newline and remove empty lines
        captions = [line.strip() for line in lines if line.strip()]
        
        # Extract just the first part of each line (the raw caption) before the first #
        raw_captions = [caption.split('#')[0] for caption in captions]
        
        return raw_captions
    caption_path = f"../../data/texts/{data_key.replace('.npz', '.txt')}"
    captions = read_captions_from_file(caption_path)
    
    retarget_data[data_key] = {
            "root_trans_offset": root_trans_offset_dump.squeeze().cpu().detach().numpy(),
            "dof": dof_pos.squeeze().cpu().detach().numpy(),
            "root_rot": sRot.from_rotvec(gt_root_rot.cpu().numpy()).as_quat(),
            "global_translation": global_translation.cpu().detach(),
            "mocap_global_translation": mocap_global_translation.cpu().detach(),
            "captions": captions,
            "left_hand_rotmat": left_hand_rotmat.cpu().detach().numpy(),
            "right_hand_rotmat": right_hand_rotmat.cpu().detach().numpy(),
    }
    
    # mirror the data
    mirror_caption_path = f"../../data/texts/M{data_key.replace('.npz', '.txt')}"
    mirror_caption = read_captions_from_file(mirror_caption_path)
    
    mirror_root_trans = root_trans_offset_dump.clone()
    # mirror_pose_aa = pose_aa_new.clone()
    mirror_dof = dof_pos.clone()
    mirror_gt_root_rot = gt_root_rot.clone()
    mirror_global_translation = global_translation.clone()
    mirror_mocap_global_translation = mocap_global_translation.clone()
    mirror_left_hand_rotmat = left_hand_rotmat.clone()
    mirror_right_hand_rotmat = right_hand_rotmat.clone()
    
    # inverse the y axis in translation, yaw angle in rotation and dof position of symmetric joints
    # Invert the Y-axis in translations
    mirror_root_trans[:, 1] = -mirror_root_trans[:, 1]
    mirror_global_translation[:, :, 1] = -mirror_global_translation[:, :, 1]
    mirror_mocap_global_translation[:, :, 1] = -mirror_mocap_global_translation[:, :, 1]

    # Adjust root rotation
    mirror_gt_root_rot = sRot.from_rotvec(mirror_gt_root_rot.cpu().numpy()).as_matrix()
    M = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
    mirror_gt_root_rot = M @ mirror_gt_root_rot @ M.T
    mirror_gt_root_rot = sRot.from_matrix(mirror_gt_root_rot).as_quat()
    
    mirror_left_hand_rotmat = M @ mirror_left_hand_rotmat.cpu().numpy() @ M.T
    mirror_right_hand_rotmat = M @ mirror_right_hand_rotmat.cpu().numpy() @ M.T
    
    
    # Define symmetric link pairs for swapping
    symmetric_links = [
        ('left_hip_pitch_link', 'right_hip_pitch_link'),
        ('left_hip_roll_link', 'right_hip_roll_link'),
        ('left_hip_yaw_link', 'right_hip_yaw_link'),
        ('left_knee_link', 'right_knee_link'),
        ('left_ankle_pitch_link', 'right_ankle_pitch_link'),
        ('left_ankle_roll_link', 'right_ankle_roll_link'),
        ('left_shoulder_pitch_link', 'right_shoulder_pitch_link'),
        ('left_shoulder_roll_link', 'right_shoulder_roll_link'),
        ('left_shoulder_yaw_link', 'right_shoulder_yaw_link'),
        ('left_elbow_link', 'right_elbow_link'),
        ('left_wrist_roll_link', 'right_wrist_roll_link'),
        ('left_wrist_pitch_link', 'right_wrist_pitch_link'),
        ('left_wrist_yaw_link', 'right_wrist_yaw_link'),
    ]

    # Swap the symmetric links' global translations
    for left_link, right_link in symmetric_links:
        left_idx = dict_link_name_index[left_link]
        right_idx = dict_link_name_index[right_link]

        # Swap global translations
        temp_trans = mirror_global_translation[:, left_idx, :].clone()
        mirror_global_translation[:, left_idx, :] = mirror_global_translation[:, right_idx, :]
        mirror_global_translation[:, right_idx, :] = temp_trans
        
    symmetric_smpl_links = [
        ('L_Hip', 'R_Hip'),
        ('L_Knee', 'R_Knee'),
        ('L_Ankle', 'R_Ankle'),
        ('L_Shoulder', 'R_Shoulder'),
        ('L_Elbow', 'R_Elbow'),
        ('L_Wrist', 'R_Wrist'),
    ]
    
    # Swap the symmetric links' global translations
    for left_link, right_link in symmetric_smpl_links:
        left_idx = dict_smpl_link_name_index[left_link]
        right_idx = dict_smpl_link_name_index[right_link]

        # Swap mocap global translations
        temp_mocap_trans = mirror_mocap_global_translation[:, left_idx, :].clone()
        mirror_mocap_global_translation[:, left_idx, :] = mirror_mocap_global_translation[:, right_idx, :]
        mirror_mocap_global_translation[:, right_idx, :] = temp_mocap_trans

    # Define symmetric joint pairs for swapping
    symmetric_joints = [
        ('left_shoulder_pitch_joint', 'right_shoulder_pitch_joint'),
        ('left_shoulder_roll_joint', 'right_shoulder_roll_joint'),
        ('left_shoulder_yaw_joint', 'right_shoulder_yaw_joint'),
        ('left_elbow_joint', 'right_elbow_joint'),
        ('left_wrist_roll_joint', 'right_wrist_roll_joint'),
        ('left_wrist_pitch_joint', 'right_wrist_pitch_joint'),
        ('left_wrist_yaw_joint', 'right_wrist_yaw_joint'),
        {'left_hip_pitch_joint', 'right_hip_pitch_joint'},
        {'left_hip_roll_joint', 'right_hip_roll_joint'},
        {'left_hip_yaw_joint', 'right_hip_yaw_joint'},
        {'left_knee_joint', 'right_knee_joint'},
        {'left_ankle_pitch_joint', 'right_ankle_pitch_joint'},
        {'left_ankle_roll_joint', 'right_ankle_roll_joint'},
    ]

    # Swap the DOF positions and negate yaw angles for symmetric joints
    for left_joint, right_joint in symmetric_joints:
        left_idx = dict_joint_name_index[left_joint]
        right_idx = dict_joint_name_index[right_joint]

        # Store the left joint DOF temporarily
        temp = mirror_dof[:, :, left_idx, 0].clone()

        # Swap DOF positions between left and right joints
        mirror_dof[:, :, left_idx, 0] = mirror_dof[:, :, right_idx, 0]
        mirror_dof[:, :, right_idx, 0] = temp

        # Negate the yaw angles for both joints if they have yaw
        if 'yaw_joint' in left_joint or 'roll_joint' in left_joint:
            mirror_dof[:, :, left_idx, 0] = -mirror_dof[:, :, left_idx, 0]
            mirror_dof[:, :, right_idx, 0] = -mirror_dof[:, :, right_idx, 0]

    # Swap the left and right hand rotation matrices
    mirror_left_hand_rotmat, mirror_right_hand_rotmat = deepcopy(mirror_right_hand_rotmat), deepcopy(mirror_left_hand_rotmat)
    #
    # Save the mirrored data with a prefixed key 'M'
    retarget_data['M' + data_key] = {
        "root_trans_offset": mirror_root_trans.squeeze().cpu().detach().numpy(),
        "dof": mirror_dof.squeeze().cpu().detach().numpy(),
        "root_rot": mirror_gt_root_rot,
        "global_translation": mirror_global_translation.cpu().detach(),
        "mocap_global_translation": mirror_mocap_global_translation.cpu().detach(),
        "captions": mirror_caption,
        "left_hand_rotmat": mirror_left_hand_rotmat,
        "right_hand_rotmat": mirror_right_hand_rotmat,
    }
    
    
# configuration information
retarget_data['config'] = {
    "frequency": target_frequency,
    "dict_joint_name_index": dict_joint_name_index,
    "dict_link_name_index": dict_link_name_index,
    "list_selected_links": list_selected_links,
    "list_selected_joints": list_selected_joints,
    "left_foot_name": "left_ankle_roll_link",
    "right_foot_name": "right_ankle_roll_link",
}

# save retargeted data; the file name records the number of keys in train_data
import os  # local import keeps this block self-contained

output_path = f'../../data/g1/humanml3d_train_retargeted_wholebody_{len(train_data.keys())}.pkl'
# Create the output directory if it is missing so a long retargeting run is
# not lost to a FileNotFoundError at the very last step.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
joblib.dump(retarget_data, output_path)