# Project Setup for Colab and Kaggle

This notebook was automatically bundled for cloud execution (Google Colab or Kaggle). Run the cell below first to reconstruct the project's file structure and install its dependencies.

In [None]:
# =========================================================
# CLOUD ENVIRONMENT SETUP (AUTO-GENERATED)
# =========================================================
import os
import sys
from pathlib import Path

# Runtime detection flags (plain bools) used to gate the cloud-only setup below.
# Colab is recognised by the `google.colab` package already being loaded in
# this interpreter; Kaggle by the KAGGLE_KERNEL_RUN_TYPE environment variable
# being present (its value is not inspected).
IN_COLAB = 'google.colab' in sys.modules.keys()
IN_KAGGLE = os.environ.get('KAGGLE_KERNEL_RUN_TYPE') is not None

if IN_COLAB or IN_KAGGLE:
    print("Running in Cloud Environment")
    
    # Write supporting files
    FILES = {
        'config.py': "from pathlib import Path\nfrom dataclasses import dataclass\nfrom typing import Optional, List, Tuple\n\n@dataclass\nclass Config:\n    device: str = 'cuda'\n    seed: int = 42\n    dataset_path: Path = Path('./dataset/humanml3d-subset')\n    output_path: Path = Path('./generation')\n    checkpoint_dir: Path = Path('./checkpoints')\n    motion_dim: int = 263\n    num_joints: int = 22\n    joint_dim: int = 3\n    max_motion_length: int = 200\n    fps: int = 20\n    feature_dims: tuple[slice, ...] = (slice(0, 4), slice(4, 67), slice(193, 259), slice(259, 263))\n    dataset_name: str = 't2m'\n    unit_length: int = 4\n    hidden_dim: int = 512\n    num_encoder_layers: int = 3\n    dropout: float = 0.1\n    bidirectional_gru: bool = False\n    num_flow_layers: int = 12\n    flow_hidden_dim: int = 512\n    num_timesteps: int = 1000\n    batch_size: int = 64\n    learning_rate: float = 0.0001\n    num_epochs: int = 100\n    weight_decay: float = 1e-05\n    gradient_clip: float = 1.0\n    warmup_steps: int = 1000\n    lr_decay: float = 0.95\n    lr_decay_epoch: int = 10\n    flow_loss_weight: float = 1.0\n    context_loss_weight: float = 0.1\n    num_inference_steps: int = 50\n    guidance_scale: float = 1.0\n    num_workers: int = 4\n    pin_memory: bool = True\n    log_interval: int = 100\n    save_interval: int = 5\n    eval_interval: int = 1\n    num_eval_samples: int = 100\n    eval_batch_size: int = 32\n\n    def __post_init__(self):\n        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)\n        self.output_path.mkdir(parents=True, exist_ok=True)\n        self.dataset_path.mkdir(parents=True, exist_ok=True)\n\n    @property\n    def context_encoder_output_dim(self) -> int:\n        return self.hidden_dim\n\n    def to_dict(self) -> dict:\n        return {k: str(v) if isinstance(v, Path) else v for k, v in self.__dict__.items()}",
        'models.py': 'import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Optional, List\n\nclass AutoregressiveContextEncoder(nn.Module):\n\n    def __init__(self, input_dim: int=263, hidden_dim: int=512, num_layers: int=3, dropout: float=0.1, max_seq_length: int=196, bidirectional: bool=False):\n        super().__init__()\n        self.input_dim = input_dim\n        self.hidden_dim = hidden_dim\n        self.num_layers = num_layers\n        self.max_seq_length = max_seq_length\n        self.bidirectional = bidirectional\n        self.input_projection = nn.Linear(input_dim, hidden_dim)\n        self.gru = nn.GRU(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0.0, bidirectional=bidirectional)\n        self.text_encoder = None\n        gru_output_dim = hidden_dim * 2 if bidirectional else hidden_dim\n        self.output_projection = nn.Linear(gru_output_dim, hidden_dim)\n\n    @property\n    def output_dim(self) -> int:\n        return self.hidden_dim\n\n    def forward(self, motion: torch.Tensor, text: Optional[List[str]]=None, mask: Optional[torch.Tensor]=None) -> torch.Tensor:\n        batch_size, seq_len, _ = motion.shape\n        x = self.input_projection(motion)\n        if text is not None:\n            pass\n        if mask is not None:\n            lengths = mask.sum(dim=1).cpu()\n            x_packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)\n            output_packed, hidden = self.gru(x_packed)\n            output, _ = nn.utils.rnn.pad_packed_sequence(output_packed, batch_first=True, total_length=seq_len)\n        else:\n            output, hidden = self.gru(x)\n        context = self.output_projection(output)\n        return context\n\nclass FlowMatchingNetwork(nn.Module):\n\n    def __init__(self, context_dim: int=512, motion_dim: int=263, hidden_dim: int=512, num_layers: int=12, dropout: 
float=0.1, num_timesteps: int=1000):\n        super().__init__()\n        self.context_dim = context_dim\n        self.motion_dim = motion_dim\n        self.hidden_dim = hidden_dim\n        self.num_timesteps = num_timesteps\n        self.time_embedding = nn.Sequential(nn.Linear(1, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim))\n        self.context_projection = nn.Linear(context_dim, hidden_dim)\n        layers = []\n        for i in range(num_layers):\n            layers.append(FlowMatchingLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, motion_dim=motion_dim if i == num_layers - 1 else hidden_dim, dropout=dropout))\n        self.flow_layers = nn.ModuleList(layers)\n        self.output_projection = nn.Linear(hidden_dim, motion_dim)\n\n    def forward(self, context: torch.Tensor, motion: Optional[torch.Tensor]=None, timestep: Optional[torch.Tensor]=None) -> torch.Tensor:\n        batch_size, seq_len, _ = context.shape\n        x = self.context_projection(context)\n        if timestep is None:\n            timestep = torch.rand(batch_size, device=context.device)\n        t_emb = self.time_embedding(timestep.unsqueeze(-1))\n        t_emb = t_emb.unsqueeze(1).expand(-1, seq_len, -1)\n        x = x + t_emb\n        for layer in self.flow_layers:\n            x = layer(x, motion if motion is not None else None)\n        output = self.output_projection(x)\n        return output\n\nclass FlowMatchingLayer(nn.Module):\n\n    def __init__(self, input_dim: int, hidden_dim: int, motion_dim: int, dropout: float=0.1):\n        super().__init__()\n        self.layer = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, motion_dim))\n\n    def forward(self, x: torch.Tensor, motion: Optional[torch.Tensor]=None) -> torch.Tensor:\n        output = self.layer(x)\n        if motion is not None and motion.shape[-1] == output.shape[-1]:\n            output = output + motion\n        return output',
        'requirements.txt': "# Core ML dependencies\ntorch>=1.9.0\ntorchvision>=0.10.0\nnumpy>=1.21.0\nscipy>=1.7.0\n\n# Data processing\npandas>=1.3.0\n\n# Visualization\nmatplotlib>=3.4.0\nseaborn>=0.11.0\n\n# # Jupyter\n# jupyter>=1.0.0\n# ipykernel>=6.0.0\n# notebook>=6.4.0\n\n# Utilities\ntqdm>=4.62.0\ngdown>=4.4.0\npathlib2>=2.3.6; python_version < '3.4'\n\n# Optional: For HumanML3D dataset compatibility\n# (Add specific HumanML3D dependencies if available)\n# humanml3d>=0.1.0\n\n# Optional: For text encoding (if using CLIP)\n# git+https://github.com/openai/CLIP.git\n\n# Optional: For SMPL models (if needed for visualization)\n# smplx>=0.1.28\n\n# Optional: For advanced visualization\n# plotly>=5.0.0\n# opencv-python>=4.5.0\n",
        'utils/utils.py': "from .dataset import Text2MotionDataset, create_dataloader, load_sample\nfrom .motion_utils import DATASET_CONFIGS, get_dataset_config, feature_to_joints, joints_to_feature, extract_features, recover_from_ric\nfrom .visualization import plot_3d_motion, visualize_motion, compare_motions\nfrom .bvh_utils import joints_to_bvh, save_bvh, save_joints, validate_bvh\nfrom .metrics import compute_metrics\n__all__ = ['Text2MotionDataset', 'create_dataloader', 'load_sample', 'DATASET_CONFIGS', 'get_dataset_config', 'feature_to_joints', 'joints_to_feature', 'extract_features', 'recover_from_ric', 'plot_3d_motion', 'visualize_motion', 'compare_motions', 'joints_to_bvh', 'save_bvh', 'save_joints', 'validate_bvh', 'compute_metrics']",
        'utils/dataset.py': "import numpy as np\nfrom os.path import join as pjoin\nimport random\nfrom tqdm import tqdm\nfrom torch.utils.data import Dataset, DataLoader\nfrom pathlib import Path\nfrom typing import List, Dict, Any, Optional\nfrom config import Config\nfrom .motion_utils import extract_feature_subset\n\nclass Text2MotionDataset(Dataset):\n\n    def __init__(self, config: Config, mean: np.ndarray, std: np.ndarray, split: str='train', feature_dims: tuple[slice, ...] | None=None):\n        self.config = config\n        self.feature_dims = feature_dims if feature_dims is not None else config.feature_dims\n        self.max_length = 20\n        self.pointer = 0\n        self.max_motion_length = config.max_motion_length\n        min_motion_len = 40 if config.dataset_name == 't2m' else 24\n        motion_dir = config.dataset_path / 'new_joint_vecs'\n        joints_dir = config.dataset_path / 'new_joints'\n        text_dir = config.dataset_path / 'texts'\n        split_file = config.dataset_path / f'{split}.txt'\n        data_dict = {}\n        id_list = []\n        with open(str(split_file), 'r', encoding='utf-8') as f:\n            for line in f.readlines():\n                id_list.append(line.strip())\n        new_name_list = []\n        length_list = []\n        for name in tqdm(id_list):\n            try:\n                motion = np.load(pjoin(str(motion_dir), name + '.npy'))\n                joints = np.load(pjoin(str(joints_dir), name + '.npy'))\n                if len(motion) < min_motion_len or len(motion) >= 200:\n                    continue\n                text_data = []\n                flag = False\n                with open(pjoin(str(text_dir), name + '.txt'), 'r', encoding='utf-8') as f:\n                    for line in f.readlines():\n                        text_dict: Dict[str, Optional[Any]] = {}\n                        line_split = line.strip().split('#')\n                        caption = line_split[0]\n                        
tokens = line_split[1].split(' ')\n                        f_tag = float(line_split[2])\n                        to_tag = float(line_split[3])\n                        f_tag = 0.0 if np.isnan(f_tag) else f_tag\n                        to_tag = 0.0 if np.isnan(to_tag) else to_tag\n                        text_dict['caption'] = caption\n                        text_dict['tokens'] = tokens\n                        if f_tag == 0.0 and to_tag == 0.0:\n                            flag = True\n                            text_data.append(text_dict)\n                        else:\n                            try:\n                                n_motion = motion[int(f_tag * 20):int(to_tag * 20)]\n                                if len(n_motion) < min_motion_len or len(n_motion) >= 200:\n                                    continue\n                                new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name\n                                while new_name in data_dict:\n                                    new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name\n                                n_joints = joints[int(f_tag * 20):int(to_tag * 20)]\n                                data_dict[new_name] = {'motion': n_motion, 'joints': n_joints, 'length': len(n_motion), 'text': [text_dict]}\n                                new_name_list.append(new_name)\n                                length_list.append(len(n_motion))\n                            except:\n                                print(line_split)\n                                print(line_split[2], line_split[3], f_tag, to_tag, name)\n                if flag:\n                    data_dict[name] = {'motion': motion, 'joints': joints, 'length': len(motion), 'text': text_data}\n                    new_name_list.append(name)\n                    length_list.append(len(motion))\n            except Exception as e:\n                pass\n        name_list, length_list = (new_name_list, 
length_list)\n        self.mean = mean\n        self.std = std\n        self.length_arr = np.array(length_list)\n        self.data_dict = data_dict\n        self.name_list = name_list\n\n    def inv_transform(self, data):\n        return data * self.std + self.mean\n\n    def __len__(self):\n        return len(self.data_dict) - self.pointer\n\n    def __getitem__(self, item):\n        idx = self.pointer + item\n        data = self.data_dict[self.name_list[idx]]\n        motion, joints, m_length, text_list = (data['motion'], data['joints'], data['length'], data['text'])\n        text_data = random.choice(text_list)\n        caption, tokens = (text_data['caption'], text_data['tokens'])\n        if self.config.unit_length < 10:\n            coin2 = np.random.choice(['single', 'single', 'double'])\n        else:\n            coin2 = 'single'\n        if coin2 == 'double':\n            m_length = (m_length // self.config.unit_length - 1) * self.config.unit_length\n        elif coin2 == 'single':\n            m_length = m_length // self.config.unit_length * self.config.unit_length\n        idx = random.randint(0, len(motion) - m_length)\n        motion = motion[idx:idx + m_length]\n        joints = joints[idx:idx + m_length]\n        motion = (motion - self.mean) / self.std\n        if m_length < self.max_motion_length:\n            motion = np.concatenate([motion, np.zeros((self.max_motion_length - m_length, motion.shape[1]))], axis=0)\n            joints = np.concatenate([joints, np.zeros((self.max_motion_length - m_length, joints.shape[1], joints.shape[2]))], axis=0)\n        input_features = extract_feature_subset(motion, self.feature_dims)\n        return (caption, input_features, motion, joints, m_length)\n\n    def reset_min_len(self, length):\n        assert length <= self.max_motion_length\n        self.pointer = np.searchsorted(self.length_arr, length)\n        print('Pointer Pointing at %d' % self.pointer)\n\ndef create_dataloader(config: Config, split: 
str='train', shuffle: bool=True) -> DataLoader:\n    mean_path = config.dataset_path / 'Mean.npy'\n    std_path = config.dataset_path / 'Std.npy'\n    if not mean_path.exists() or not std_path.exists():\n        raise FileNotFoundError(f'Mean.npy and/or Std.npy not found in {config.dataset_path}. Please ensure Mean.npy and Std.npy exist in the dataset directory.')\n    mean = np.load(mean_path)\n    std = np.load(std_path)\n    dataset_obj = Text2MotionDataset(config, mean, std, split, feature_dims=config.feature_dims)\n    return DataLoader(dataset_obj, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_workers, pin_memory=config.pin_memory)\n\ndef load_sample(dataset_path: Path, file_id: str) -> Dict[str, Optional[Any]]:\n    features_path = dataset_path / 'new_joint_vecs' / f'{file_id}.npy'\n    joints_path = dataset_path / 'new_joints' / f'{file_id}.npy'\n    text_path = dataset_path / 'texts' / f'{file_id}.txt'\n    data: Dict[str, Optional[Any]] = {'file_id': file_id}\n    if features_path.exists():\n        data['features'] = np.load(features_path)\n    else:\n        print(f'Warning: Features not found for {file_id}')\n        data['features'] = None\n    if joints_path.exists():\n        data['joints'] = np.load(joints_path)\n    else:\n        print(f'Warning: Joints not found for {file_id}')\n        data['joints'] = None\n    if text_path.exists():\n        with open(text_path, 'r') as f:\n            descriptions = [line.strip().split('#')[0] for line in f.readlines()]\n            data['text'] = descriptions[0] if descriptions else ''\n    else:\n        data['text'] = ''\n    return data",
        'utils/motion_utils.py': "import numpy as np\nimport torch\nfrom typing import List, Tuple, Dict, Any\nfrom .skeleton import Skeleton\nfrom .quaternion import qrot_np, qfix, quaternion_to_cont6d_np, qmul_np, qinv_np, qrot, qinv\nkit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]\nkit_raw_offsets = np.array([[0, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [0, -1, 0], [0, -1, 0], [-1, 0, 0], [0, -1, 0], [0, -1, 0], [1, 0, 0], [0, -1, 0], [0, -1, 0], [0, 0, 1], [0, 0, 1], [-1, 0, 0], [0, -1, 0], [0, -1, 0], [0, 0, 1], [0, 0, 1]])\nt2m_raw_offsets = np.array([[0, 0, 0], [1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, -1, 0], [0, 1, 0], [0, -1, 0], [0, -1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0], [1, 0, 0], [-1, 0, 0], [0, 0, 1], [0, -1, 0], [0, -1, 0], [0, -1, 0], [0, -1, 0], [0, -1, 0], [0, -1, 0]])\nt2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]\nt2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]\nt2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]\nkit_tgt_skel_id = '03950'\nt2m_tgt_skel_id = '000021'\n\ndef extract_features(positions: np.ndarray, feet_thre: float, n_raw_offsets: np.ndarray, kinematic_chain: list, face_joint_indx: list, fid_r: list, fid_l: list) -> np.ndarray:\n    global_positions = positions.copy()\n\n    def foot_detect(positions, thres):\n        velfactor = np.array([thres, thres])\n        feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2\n        feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2\n        feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2\n        feet_l = (feet_l_x + feet_l_y + feet_l_z < velfactor).astype(np.float64)\n        feet_r_x = (positions[1:, fid_r, 0] - 
positions[:-1, fid_r, 0]) ** 2\n        feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2\n        feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2\n        feet_r = (feet_r_x + feet_r_y + feet_r_z < velfactor).astype(np.float64)\n        return (feet_l, feet_r)\n    feet_l, feet_r = foot_detect(positions, feet_thre)\n\n    def get_cont6d_params(positions):\n        skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')\n        quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)\n        cont_6d_params = quaternion_to_cont6d_np(quat_params)\n        r_rot = quat_params[:, 0].copy()\n        velocity = (positions[1:, 0] - positions[:-1, 0]).copy()\n        velocity = qrot_np(r_rot[1:], velocity)\n        r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))\n        return (cont_6d_params, r_velocity, velocity, r_rot)\n\n    def get_rifke(positions, r_rot):\n        positions = positions.copy()\n        positions[..., 0] -= positions[:, 0:1, 0]\n        positions[..., 2] -= positions[:, 0:1, 2]\n        positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)\n        return positions\n    cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)\n    positions = get_rifke(positions, r_rot)\n    root_y = positions[:, 0, 1:2]\n    r_velocity = np.arcsin(r_velocity[:, 2:3])\n    l_velocity = velocity[:, [0, 2]]\n    root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)\n    rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)\n    ric_data = positions[:, 1:].reshape(len(positions), -1)\n    local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1), global_positions[1:] - global_positions[:-1])\n    local_vel = local_vel.reshape(len(local_vel), -1)\n    data = root_data\n    data = np.concatenate([data, ric_data[:-1]], axis=-1)\n    data = np.concatenate([data, rot_data[:-1]], axis=-1)\n    
data = np.concatenate([data, local_vel], axis=-1)\n    data = np.concatenate([data, feet_l, feet_r], axis=-1)\n    return data\n\ndef recover_root_rot_pos(data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n    rot_vel = data[..., 0]\n    r_rot_ang = torch.zeros_like(rot_vel).to(data.device)\n    r_rot_ang[..., 1:] = rot_vel[..., :-1]\n    r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)\n    r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)\n    r_rot_quat[..., 0] = torch.cos(r_rot_ang)\n    r_rot_quat[..., 2] = torch.sin(r_rot_ang)\n    r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)\n    r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]\n    r_pos = qrot(qinv(r_rot_quat), r_pos)\n    r_pos = torch.cumsum(r_pos, dim=-2)\n    r_pos[..., 1] = data[..., 3]\n    return (r_rot_quat, r_pos)\n\ndef recover_from_ric(data: torch.Tensor, joints_num: int) -> torch.Tensor:\n    r_rot_quat, r_pos = recover_root_rot_pos(data)\n    positions = data[..., 4:(joints_num - 1) * 3 + 4]\n    positions = positions.view(positions.shape[:-1] + (-1, 3))\n    positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)\n    positions[..., 0] += r_pos[..., 0:1]\n    positions[..., 2] += r_pos[..., 2:3]\n    positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)\n    return positions\nDATASET_CONFIGS = {'t2m': {'name': 'HumanML3D', 'num_joints': 22, 'raw_offsets': t2m_raw_offsets, 'kinematic_chain': t2m_kinematic_chain, 'face_joint_indx': [2, 1, 17, 16], 'fid_r': [8, 11], 'fid_l': [7, 10]}, 'kit': {'name': 'KIT', 'num_joints': 21, 'raw_offsets': kit_raw_offsets, 'kinematic_chain': kit_kinematic_chain, 'face_joint_indx': [11, 16, 5, 8], 'fid_r': [14, 15], 'fid_l': [19, 20]}}\n\ndef get_dataset_config(dataset_type: str='t2m') -> Dict[str, Any]:\n    if dataset_type not in DATASET_CONFIGS:\n        raise ValueError(f'Unknown dataset_type: {dataset_type}. 
Available: {list(DATASET_CONFIGS.keys())}')\n    return DATASET_CONFIGS[dataset_type]\n\ndef feature_to_joints(motion_features: torch.Tensor, dataset_type: str='t2m') -> torch.Tensor:\n    if not isinstance(motion_features, torch.Tensor):\n        motion_features = torch.FloatTensor(motion_features)\n    config = get_dataset_config(dataset_type)\n    joints = recover_from_ric(motion_features, config['num_joints'])\n    return joints\n\ndef joints_to_feature(joint_positions: torch.Tensor, dataset_type: str='t2m', feet_thre: float=0.002) -> np.ndarray:\n    if isinstance(joint_positions, torch.Tensor):\n        joint_positions_np = joint_positions.cpu().numpy()\n    else:\n        joint_positions_np = joint_positions\n    config = get_dataset_config(dataset_type)\n    features = extract_features(joint_positions_np, feet_thre=feet_thre, n_raw_offsets=config['raw_offsets'], kinematic_chain=config['kinematic_chain'], face_joint_indx=config['face_joint_indx'], fid_r=config['fid_r'], fid_l=config['fid_l'])\n    return features\n\ndef extract_feature_subset(features: np.ndarray, dimensions: tuple[slice, ...]) -> np.ndarray:\n    subset_list = []\n    for dim in dimensions:\n        subset_list.append(features[:, dim])\n    return np.concatenate(subset_list, axis=-1)",
        'utils/visualization.py': "import numpy as np\nimport matplotlib.pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom pathlib import Path\nfrom typing import Optional, Any\nfrom .motion_utils import t2m_kinematic_chain\n\ndef plot_3d_motion(motion: np.ndarray, fps: float=20, radius: float=1.0, title: str='Motion Visualization', follow_root: bool=False) -> FuncAnimation:\n    fig = plt.figure(figsize=(8, 8))\n    ax = fig.add_subplot(111, projection='3d')\n    ax.view_init(elev=15, azim=-70)\n    colors = ['#2980b9', '#c0392b', '#27ae60', '#f39c12', '#8e44ad']\n    lines = [ax.plot([], [], [], color=colors[i % len(colors)], marker='o', ms=2, lw=2)[0] for i in range(len(t2m_kinematic_chain))]\n    ax.set_xlabel('X (Side)')\n    ax.set_ylabel('Z (Forward)')\n    ax.set_zlabel('Y (Height)')\n    ax.set_title(title)\n    pos_min = motion.min(axis=(0, 1))\n    pos_max = motion.max(axis=(0, 1))\n\n    def update(frame):\n        root = motion[frame, 0, :]\n        if follow_root:\n            ax.set_xlim3d([root[0] - radius, root[0] + radius])\n            ax.set_ylim3d([root[2] - radius, root[2] + radius])\n            ax.set_zlim3d([pos_min[1], pos_max[1] + radius * 0.5])\n        else:\n            ax.set_xlim3d([pos_min[0] - radius, pos_max[0] + radius])\n            ax.set_ylim3d([pos_min[2] - radius, pos_max[2] + radius])\n            ax.set_zlim3d([pos_min[1], pos_max[1] + radius * 0.5])\n        for i, c_indices in enumerate(t2m_kinematic_chain):\n            joints = motion[frame, c_indices, :]\n            lines[i].set_data(joints[:, 0], joints[:, 2])\n            lines[i].set_3d_properties(joints[:, 1])\n        return lines\n    ani = FuncAnimation(fig, update, frames=len(motion), interval=1000 / fps, blit=False)\n    plt.close()\n    return ani\n\ndef visualize_motion(joint_positions: np.ndarray, ground_truth: Optional[np.ndarray]=None, title: str='Motion Visualization', save_path: Optional[Path]=None, fps: float=20, skip_frames: int=1, 
notebook: bool=True) -> Any:\n    fps = fps / skip_frames\n    ani = plot_3d_motion(joint_positions[::skip_frames], fps=fps, title=title)\n    if save_path:\n        save_path.parent.mkdir(parents=True, exist_ok=True)\n        ani.save(str(save_path), writer='ffmpeg', fps=fps)\n        print(f'Saved animation to {save_path}')\n    if notebook:\n        from IPython.display import HTML\n        return HTML(ani.to_html5_video())\n    return ani\n\ndef compare_motions(generated_joints: np.ndarray, ground_truth_joints: np.ndarray, save_path: Optional[Path]=None) -> None:\n    visualize_motion(generated_joints, ground_truth=ground_truth_joints, title='Generated vs Ground Truth', save_path=save_path)",
        'utils/bvh_utils.py': 'import numpy as np\nfrom pathlib import Path\nfrom typing import Dict, Any, Optional\n\ndef _get_default_skeleton_hierarchy() -> Dict:\n    parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]\n    joint_names = [\'pelvis\', \'left_hip\', \'right_hip\', \'spine1\', \'left_knee\', \'right_knee\', \'spine2\', \'left_ankle\', \'right_ankle\', \'spine3\', \'left_foot\', \'right_foot\', \'neck\', \'left_collar\', \'right_collar\', \'head\', \'left_shoulder\', \'right_shoulder\', \'left_elbow\', \'right_elbow\', \'left_wrist\', \'right_wrist\']\n    hierarchy = {}\n    for i, name in enumerate(joint_names):\n        p_idx = parents[i]\n        p_name = joint_names[p_idx] if p_idx != -1 else \'root\'\n        if p_name not in hierarchy:\n            hierarchy[p_name] = {\'children\': []}\n        hierarchy[p_name][\'children\'].append(name)\n        if name not in hierarchy:\n            hierarchy[name] = {\'children\': []}\n    return hierarchy\n\ndef joints_to_bvh(joint_positions: np.ndarray, fps: int=20, skeleton_template: Optional[Dict]=None) -> Dict[str, Any]:\n    nframe, num_joints, _ = joint_positions.shape\n    bvh_data = {\'hierarchy\': skeleton_template or _get_default_skeleton_hierarchy(), \'motion\': {\'frames\': nframe, \'fps\': fps, \'data\': joint_positions.tolist()}}\n    print(f\'TODO: Implement proper joints_to_bvh conversion\')\n    print(f\'Input: {joint_positions.shape} -> BVH format\')\n    return bvh_data\n\ndef save_bvh(bvh_data: Dict[str, Any], output_path: Path) -> None:\n    output_path.parent.mkdir(parents=True, exist_ok=True)\n    with open(output_path, \'w\') as f:\n        f.write(\'HIERARCHY\\n\')\n        f.write(\'ROOT root\\n\')\n        f.write(\'{\\n\')\n        f.write(\'  OFFSET 0.0 0.0 0.0\\n\')\n        f.write(\'  CHANNELS 6 Xposition Yposition Zposition Zrotation Xrotation Yrotation\\n\')\n        f.write(\'}\\n\')\n        f.write(\'MOTION\\n\')\n        
f.write(f"Frames: {bvh_data[\'motion\'][\'frames\']}\\n")\n        f.write(f"Frame Time: {1.0 / bvh_data[\'motion\'][\'fps\']:.6f}\\n")\n    print(f\'TODO: Implement complete BVH file writing\')\n    print(f\'Saved BVH to {output_path}\')\n\ndef save_joints(joint_positions: np.ndarray, output_path: Path) -> None:\n    output_path.parent.mkdir(parents=True, exist_ok=True)\n    np.save(output_path, joint_positions)\n    print(f\'Saved joints to {output_path}\')\n\ndef validate_bvh(bvh_path: Path) -> bool:\n    if not bvh_path.exists():\n        return False\n    try:\n        with open(bvh_path, \'r\') as f:\n            content = f.read()\n            if \'HIERARCHY\' in content and \'MOTION\' in content:\n                return True\n    except Exception:\n        return False\n    return False',
        'utils/metrics.py': "import numpy as np\nfrom typing import List, Dict\n\ndef compute_metrics(generated_joints: List[np.ndarray], ground_truth_joints: List[np.ndarray], generated_texts: List[str], gt_texts: List[str]) -> Dict[str, float]:\n    metrics = {'fid': 0.0, 'diversity': 0.0, 'r_precision': 0.0, 'mm_dist': 0.0}\n    print('TODO: Implement evaluation metrics computation')\n    return metrics",
        'utils/quaternion.py': "import torch\nimport numpy as np\n_EPS4 = np.finfo(float).eps * 4.0\n_FLOAT_EPS = np.finfo(np.float64).eps\n\ndef qinv(q):\n    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'\n    mask = torch.ones_like(q)\n    mask[..., 1:] = -mask[..., 1:]\n    return q * mask\n\ndef qinv_np(q):\n    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'\n    return qinv(torch.from_numpy(q).float()).numpy()\n\ndef qnormalize(q):\n    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'\n    return q / torch.norm(q, dim=-1, keepdim=True)\n\ndef qmul(q, r):\n    assert q.shape[-1] == 4\n    assert r.shape[-1] == 4\n    original_shape = q.shape\n    terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))\n    w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]\n    x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]\n    y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]\n    z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]\n    return torch.stack((w, x, y, z), dim=1).view(original_shape)\n\ndef qrot(q, v):\n    assert q.shape[-1] == 4\n    assert v.shape[-1] == 3\n    assert q.shape[:-1] == v.shape[:-1]\n    original_shape = list(v.shape)\n    q = q.contiguous().view(-1, 4)\n    v = v.contiguous().view(-1, 3)\n    qvec = q[:, 1:]\n    uv = torch.cross(qvec, v, dim=1)\n    uuv = torch.cross(qvec, uv, dim=1)\n    return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)\n\ndef qeuler(q, order, epsilon=0, deg=True):\n    assert q.shape[-1] == 4\n    original_shape = list(q.shape)\n    original_shape[-1] = 3\n    q = q.view(-1, 4)\n    q0 = q[:, 0]\n    q1 = q[:, 1]\n    q2 = q[:, 2]\n    q3 = q[:, 3]\n    if order == 'xyz':\n        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))\n        y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))\n        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * 
(q2 * q2 + q3 * q3))\n    elif order == 'yzx':\n        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))\n        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))\n        z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))\n    elif order == 'zxy':\n        x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))\n        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))\n        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))\n    elif order == 'xzy':\n        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))\n        y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))\n        z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))\n    elif order == 'yxz':\n        x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))\n        y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))\n        z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))\n    elif order == 'zyx':\n        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))\n        y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))\n        z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))\n    else:\n        raise\n    if deg:\n        return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi\n    else:\n        return torch.stack((x, y, z), dim=1).view(original_shape)\n\ndef qmul_np(q, r):\n    q = torch.from_numpy(q).contiguous().float()\n    r = torch.from_numpy(r).contiguous().float()\n    return qmul(q, r).numpy()\n\ndef qrot_np(q, v):\n    q = torch.from_numpy(q).contiguous().float()\n    v = torch.from_numpy(v).contiguous().float()\n    return qrot(q, v).numpy()\n\ndef qeuler_np(q, order, epsilon=0, use_gpu=False):\n    if use_gpu:\n        q = 
torch.from_numpy(q).cuda().float()\n        return qeuler(q, order, epsilon).cpu().numpy()\n    else:\n        q = torch.from_numpy(q).contiguous().float()\n        return qeuler(q, order, epsilon).numpy()\n\ndef qfix(q):\n    assert len(q.shape) == 3\n    assert q.shape[-1] == 4\n    result = q.copy()\n    dot_products = np.sum(q[1:] * q[:-1], axis=2)\n    mask = dot_products < 0\n    mask = (np.cumsum(mask, axis=0) % 2).astype(bool)\n    result[1:][mask] *= -1\n    return result\n\ndef euler2quat(e, order, deg=True):\n    assert e.shape[-1] == 3\n    original_shape = list(e.shape)\n    original_shape[-1] = 4\n    e = e.view(-1, 3)\n    if deg:\n        e = e * np.pi / 180.0\n    x = e[:, 0]\n    y = e[:, 1]\n    z = e[:, 2]\n    rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)\n    ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)\n    rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)\n    result = None\n    for coord in order:\n        if coord == 'x':\n            r = rx\n        elif coord == 'y':\n            r = ry\n        elif coord == 'z':\n            r = rz\n        else:\n            raise\n        if result is None:\n            result = r\n        else:\n            result = qmul(result, r)\n    if order in ['xyz', 'yzx', 'zxy']:\n        result *= -1\n    return result.view(original_shape)\n\ndef expmap_to_quaternion(e):\n    assert e.shape[-1] == 3\n    original_shape = list(e.shape)\n    original_shape[-1] = 4\n    e = e.reshape(-1, 3)\n    theta = np.linalg.norm(e, axis=1).reshape(-1, 1)\n    w = np.cos(0.5 * theta).reshape(-1, 1)\n    xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e\n    return np.concatenate((w, xyz), axis=1).reshape(original_shape)\n\ndef euler_to_quaternion(e, order):\n    assert e.shape[-1] == 3\n    original_shape = list(e.shape)\n    original_shape[-1] = 4\n    e = 
e.reshape(-1, 3)\n    x = e[:, 0]\n    y = e[:, 1]\n    z = e[:, 2]\n    rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)\n    ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)\n    rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)\n    result = None\n    for coord in order:\n        if coord == 'x':\n            r = rx\n        elif coord == 'y':\n            r = ry\n        elif coord == 'z':\n            r = rz\n        else:\n            raise\n        if result is None:\n            result = r\n        else:\n            result = qmul_np(result, r)\n    if order in ['xyz', 'yzx', 'zxy']:\n        result *= -1\n    return result.reshape(original_shape)\n\ndef quaternion_to_matrix(quaternions):\n    r, i, j, k = torch.unbind(quaternions, -1)\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n    o = torch.stack((1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j)), -1)\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n\ndef quaternion_to_matrix_np(quaternions):\n    q = torch.from_numpy(quaternions).contiguous().float()\n    return quaternion_to_matrix(q).numpy()\n\ndef quaternion_to_cont6d_np(quaternions):\n    rotation_mat = quaternion_to_matrix_np(quaternions)\n    cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)\n    return cont_6d\n\ndef quaternion_to_cont6d(quaternions):\n    rotation_mat = quaternion_to_matrix(quaternions)\n    cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)\n    return cont_6d\n\ndef cont6d_to_matrix(cont6d):\n    assert cont6d.shape[-1] == 6, 'The last dimension must be 6'\n    x_raw = cont6d[..., 0:3]\n    y_raw = cont6d[..., 3:6]\n    x = x_raw / torch.norm(x_raw, 
dim=-1, keepdim=True)\n    z = torch.cross(x, y_raw, dim=-1)\n    z = z / torch.norm(z, dim=-1, keepdim=True)\n    y = torch.cross(z, x, dim=-1)\n    x = x[..., None]\n    y = y[..., None]\n    z = z[..., None]\n    mat = torch.cat([x, y, z], dim=-1)\n    return mat\n\ndef cont6d_to_matrix_np(cont6d):\n    q = torch.from_numpy(cont6d).contiguous().float()\n    return cont6d_to_matrix(q).numpy()\n\ndef qpow(q0, t, dtype=torch.float):\n    q0 = qnormalize(q0)\n    theta0 = torch.acos(q0[..., 0])\n    mask = (theta0 <= 1e-09) * (theta0 >= -1e-09)\n    theta0 = (1 - mask) * theta0 + mask * 1e-09\n    v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)\n    if isinstance(t, torch.Tensor):\n        q = torch.zeros(t.shape + q0.shape)\n        theta = t.view(-1, 1) * theta0.view(1, -1)\n    else:\n        q = torch.zeros(q0.shape)\n        theta = t * theta0\n    q[..., 0] = torch.cos(theta)\n    q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)\n    return q.to(dtype)\n\ndef qslerp(q0, q1, t):\n    q0 = qnormalize(q0)\n    q1 = qnormalize(q1)\n    q_ = qpow(qmul(q1, qinv(q0)), t)\n    return qmul(q_, q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())\n\ndef qbetween(v0, v1):\n    assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'\n    assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'\n    v = torch.cross(v0, v1)\n    w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, keepdim=True)\n    return qnormalize(torch.cat([w, v], dim=-1))\n\ndef qbetween_np(v0, v1):\n    assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'\n    assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'\n    v0 = torch.from_numpy(v0).float()\n    v1 = torch.from_numpy(v1).float()\n    return qbetween(v0, v1).numpy()\n\ndef lerp(p0, p1, t):\n    if not isinstance(t, torch.Tensor):\n        t = torch.Tensor([t])\n    new_shape = t.shape + p0.shape\n    new_view_t = 
t.shape + torch.Size([1] * len(p0.shape))\n    new_view_p = torch.Size([1] * len(t.shape)) + p0.shape\n    p0 = p0.view(new_view_p).expand(new_shape)\n    p1 = p1.view(new_view_p).expand(new_shape)\n    t = t.view(new_view_t).expand(new_shape)\n    return p0 + t * (p1 - p0)",
        'utils/skeleton.py': "from .quaternion import *\nimport scipy.ndimage.filters as filters\n\nclass Skeleton(object):\n\n    def __init__(self, offset, kinematic_tree, device):\n        self.device = device\n        self._raw_offset_np = offset.numpy()\n        self._raw_offset = offset.clone().detach().to(device).float()\n        self._kinematic_tree = kinematic_tree\n        self._offset = None\n        self._parents = [0] * len(self._raw_offset)\n        self._parents[0] = -1\n        for chain in self._kinematic_tree:\n            for j in range(1, len(chain)):\n                self._parents[chain[j]] = chain[j - 1]\n\n    def njoints(self):\n        return len(self._raw_offset)\n\n    def offset(self):\n        return self._offset\n\n    def set_offset(self, offsets):\n        self._offset = offsets.clone().detach().to(self.device).float()\n\n    def kinematic_tree(self):\n        return self._kinematic_tree\n\n    def parents(self):\n        return self._parents\n\n    def get_offsets_joints_batch(self, joints):\n        assert len(joints.shape) == 3\n        _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()\n        for i in range(1, self._raw_offset.shape[0]):\n            _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]\n        self._offset = _offsets.detach()\n        return _offsets\n\n    def get_offsets_joints(self, joints):\n        assert len(joints.shape) == 2\n        _offsets = self._raw_offset.clone()\n        for i in range(1, self._raw_offset.shape[0]):\n            _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]\n        self._offset = _offsets.detach()\n        return _offsets\n\n    def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):\n        assert len(face_joint_idx) == 4\n        'Get Forward Direction'\n        l_hip, r_hip, sdr_r, sdr_l = face_joint_idx\n        across1 = joints[:, 
r_hip] - joints[:, l_hip]\n        across2 = joints[:, sdr_r] - joints[:, sdr_l]\n        across = across1 + across2\n        across = across / np.sqrt((across ** 2).sum(axis=-1))[:, np.newaxis]\n        forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)\n        if smooth_forward:\n            forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')\n        forward = forward / np.sqrt((forward ** 2).sum(axis=-1))[..., np.newaxis]\n        'Get Root Rotation'\n        target = np.array([[0, 0, 1]]).repeat(len(forward), axis=0)\n        root_quat = qbetween_np(forward, target)\n        'Inverse Kinematics'\n        quat_params = np.zeros(joints.shape[:-1] + (4,))\n        root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])\n        quat_params[:, 0] = root_quat\n        for chain in self._kinematic_tree:\n            R = root_quat\n            for j in range(len(chain) - 1):\n                u = self._raw_offset_np[chain[j + 1]][np.newaxis, ...].repeat(len(joints), axis=0)\n                v = joints[:, chain[j + 1]] - joints[:, chain[j]]\n                v = v / np.sqrt((v ** 2).sum(axis=-1))[:, np.newaxis]\n                rot_u_v = qbetween_np(u, v)\n                R_loc = qmul_np(qinv_np(R), rot_u_v)\n                quat_params[:, chain[j + 1], :] = R_loc\n                R = qmul_np(R, R_loc)\n        return quat_params\n\n    def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):\n        if skel_joints is not None:\n            offsets = self.get_offsets_joints_batch(skel_joints)\n        if len(self._offset.shape) == 2:\n            offsets = self._offset.expand(quat_params.shape[0], -1, -1)\n        joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)\n        joints[:, 0] = root_pos\n        for chain in self._kinematic_tree:\n            if do_root_R:\n                R = quat_params[:, 0]\n            else:\n                R = torch.tensor([[1.0, 0.0, 0.0, 
0.0]]).expand(len(quat_params), -1).detach().to(self.device)\n            for i in range(1, len(chain)):\n                R = qmul(R, quat_params[:, chain[i]])\n                offset_vec = offsets[:, chain[i]]\n                joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i - 1]]\n        return joints\n\n    def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):\n        if skel_joints is not None:\n            skel_joints = torch.from_numpy(skel_joints)\n            offsets = self.get_offsets_joints_batch(skel_joints)\n        if len(self._offset.shape) == 2:\n            offsets = self._offset.expand(quat_params.shape[0], -1, -1)\n        offsets = offsets.numpy()\n        joints = np.zeros(quat_params.shape[:-1] + (3,))\n        joints[:, 0] = root_pos\n        for chain in self._kinematic_tree:\n            if do_root_R:\n                R = quat_params[:, 0]\n            else:\n                R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)\n            for i in range(1, len(chain)):\n                R = qmul_np(R, quat_params[:, chain[i]])\n                offset_vec = offsets[:, chain[i]]\n                joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]\n        return joints\n\n    def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):\n        if skel_joints is not None:\n            skel_joints = torch.from_numpy(skel_joints)\n            offsets = self.get_offsets_joints_batch(skel_joints)\n        if len(self._offset.shape) == 2:\n            offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)\n        offsets = offsets.numpy()\n        joints = np.zeros(cont6d_params.shape[:-1] + (3,))\n        joints[:, 0] = root_pos\n        for chain in self._kinematic_tree:\n            if do_root_R:\n                matR = cont6d_to_matrix_np(cont6d_params[:, 0])\n            else:\n                matR = 
np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)\n            for i in range(1, len(chain)):\n                matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))\n                offset_vec = offsets[:, chain[i]][..., np.newaxis]\n                joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i - 1]]\n        return joints\n\n    def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):\n        if skel_joints is not None:\n            offsets = self.get_offsets_joints_batch(skel_joints)\n        if len(self._offset.shape) == 2:\n            offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)\n        joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)\n        joints[..., 0, :] = root_pos\n        for chain in self._kinematic_tree:\n            if do_root_R:\n                matR = cont6d_to_matrix(cont6d_params[:, 0])\n            else:\n                matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)\n            for i in range(1, len(chain)):\n                matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))\n                offset_vec = offsets[:, chain[i]].unsqueeze(-1)\n                joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i - 1]]\n        return joints",
    }
    
    for filepath, content in FILES.items():
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(content)
        print(f'Created {filepath}')
    
    # Install dependencies
    print("Installing dependencies (this may take a minute)...")
    %pip install -r requirements.txt
    
    # Copy dataset
    print("Copying dataset...")
    !apt -qq install rclone && rclone copy /kaggle/input/ /kaggle/working/dataset/ --transfers 16 --checkers 16 --progress --ignore-existing -q
    
    print("Setup Complete!")
else:
    print("Running locally. No setup needed.")


# Human Motion Animation Generation Pipeline

This notebook implements a pipeline for generating human motion animations using:

- **Autoregressive Context Encoder**: Encodes motion context sequentially
- **Flow Matching Network**: Generates motion sequences using flow matching

**Compatible with MoMask input/output format:**

- Input: HumanML3D dim-263 feature vectors
- Output: Joint positions (nframe, 22, 3) → BVH files


In [None]:
# Imports
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys
from config import Config

# Add project root to path
sys.path.append(str(Path.cwd()))

from config import Config
from models import AutoregressiveContextEncoder, FlowMatchingNetwork
from utils.utils import (
    load_sample,
    create_dataloader,
    feature_to_joints,
    joints_to_bvh,
    save_bvh,
    save_joints,
    compute_metrics,
    visualize_motion,
    extract_features,
    recover_from_ric,
)

# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load configuration. Config defaults to device='cuda'; pass the device that
# was actually detected so anything reading config.device stays consistent
# with `device` (e.g. on CPU-only machines). dataset_path is passed to the
# constructor so __post_init__ creates the intended directory.
config = Config(device=device.type, dataset_path=Path("./dataset/humanml3d-subset"))

In [None]:
# Preview one raw sample: read the first id listed in all.txt, load its joint
# positions and caption, and animate it.
with open(config.dataset_path / "all.txt", "r", encoding="utf-8") as f:
    # Strip newlines and drop blank lines so a blank line cannot yield an
    # empty id (which would point at ".npy").
    all_ids = [line.strip() for line in f if line.strip()]

if not all_ids:
    raise ValueError(f"No sample ids found in {config.dataset_path / 'all.txt'}")

file_id = all_ids[0]
data_path = config.dataset_path / "new_joints" / f"{file_id}.npy"
motion_data = np.load(data_path)
text_path = config.dataset_path / "texts" / f"{file_id}.txt"
with open(text_path, "r", encoding="utf-8") as f:
    text = f.read()

# Visualize at the frame rate configured for the dataset (config.fps).
ani = visualize_motion(motion_data, title=f"{file_id}.npy", fps=config.fps, skip_frames=2)
ani

## Step 1: Data Preparation (HumanML3D)

Load and preprocess the HumanML3D dataset, where each motion frame is represented as a 263-dimensional feature vector.


In [None]:
# Create dataloaders using the modular utilities
print("Creating dataloaders...")

train_loader = create_dataloader(config, split="train", shuffle=True)
val_loader = create_dataloader(config, split="val", shuffle=False)

print(f"Train dataloader created with batch_size={config.batch_size}")
print(f"Val dataloader created with batch_size={config.batch_size}")

# Pull exactly one batch to verify shapes. next(iter(...)) replaces a
# for/enumerate loop whose `break` always fired on the first batch, and the
# None default makes an empty loader visible instead of printing nothing.
print("\nSample batch from train_loader:")
first_batch = next(iter(train_loader), None)
if first_batch is None:
    print("  train_loader produced no batches -- check the dataset split.")
else:
    captions, input_features, motions, joints, lengths = first_batch
    print("Batch 0:")
    print(f"  Captions (list): {len(captions)} samples")
    print(
        f"  Input features shape: {input_features.shape}"
    )  # (batch_size, max_length, feature_dim)
    print(f"  Motions shape: {motions.shape}")  # (batch_size, max_length, 263)
    print(f"  Joints shape: {joints.shape}")  # (batch_size, max_length, 22, 3)
    print(f"  Lengths shape: {lengths.shape}")  # (batch_size,)

    # Sample caption and motion
    print(f"  Sample caption: '{captions[0]}'")
    print(f"  Sample input feature shape: {input_features[0].shape}")
    print(f"  Sample motion shape: {motions[0].shape}")
    print(f"  Sample joints shape: {joints[0].shape}")
    print(f"  Sample length: {lengths[0].item()}")

## Step 2: Autoregressive Context Encoder

Initialize and test the autoregressive context encoder model.


In [None]:
# Build the autoregressive context encoder from the shared config and move it
# onto the selected device.
context_encoder = AutoregressiveContextEncoder(
    input_dim=config.motion_dim,  # 263-dim HumanML3D features
    hidden_dim=config.hidden_dim,
    num_layers=config.num_encoder_layers,
    max_seq_length=config.max_motion_length,
    bidirectional=config.bidirectional_gru,
).to(device)

num_context_params = sum(param.numel() for param in context_encoder.parameters())
print(f"Context Encoder parameters: {num_context_params:,}")

# Smoke-test the forward pass: random features paired with one fixed caption
# repeated for every sample in the batch.
sample_batch_motion = torch.randn(
    config.batch_size, config.max_motion_length, config.motion_dim
).to(device)
sample_batch_text = config.batch_size * ["A person is walking"]

with torch.no_grad():
    context_output = context_encoder(sample_batch_motion, sample_batch_text)
print(f"Context encoder output shape: {context_output.shape}")

## Step 3: Flow Matching Network

Initialize and test the flow matching network model.


In [None]:
# Build the flow matching network. Its hidden width comes from the
# flow-specific `flow_hidden_dim` setting rather than the context encoder's
# `hidden_dim`, so the two networks can be sized independently.
flow_matching_net = FlowMatchingNetwork(
    context_dim=context_encoder.output_dim,
    motion_dim=config.motion_dim,  # 263
    hidden_dim=config.flow_hidden_dim,
    num_layers=config.num_flow_layers,
).to(device)

print(
    f"Flow Matching Network parameters: {sum(p.numel() for p in flow_matching_net.parameters()):,}"
)

# Smoke-test the forward pass with the context produced by the encoder cell.
with torch.no_grad():
    flow_output = flow_matching_net(context_output, sample_batch_motion)
    print(f"Flow matching output shape: {flow_output.shape}")

## Step 4: Training Loop

Set up the optimizers, the loss function, and the training and validation loops, saving a checkpoint whenever the validation loss improves.


In [None]:
# One Adam optimizer per network. weight_decay is taken from the config so
# that the regularization strength set there actually takes effect.
optimizer_context = torch.optim.Adam(
    context_encoder.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)

optimizer_flow = torch.optim.Adam(
    flow_matching_net.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)


# TODO: Define loss functions
def compute_loss(predicted_motion, target_motion, context_output, lengths=None):
    """
    Compute training loss for flow matching.

    Currently a plain mean-squared error between predicted and target motion
    (a placeholder for a proper flow-matching objective); ``context_output``
    is accepted for API stability but not used yet.

    Args:
        predicted_motion: Generated motion from flow matching (batch, seq_len, 263)
        target_motion: Ground truth motion (batch, seq_len, 263)
        context_output: Context from autoregressive encoder (unused)
        lengths: Optional per-sample valid frame counts, shape (batch,).
            When given, frames at index >= length (padding) are excluded and
            the error is averaged over valid frames only. When None, every
            element is averaged, exactly like ``nn.MSELoss()``.

    Returns:
        loss: Scalar loss tensor
    """
    if lengths is None:
        return nn.functional.mse_loss(predicted_motion, target_motion)

    device = predicted_motion.device
    lengths = torch.as_tensor(lengths, device=device)
    frame_idx = torch.arange(predicted_motion.shape[1], device=device)
    # (batch, seq_len) mask: 1 for real frames, 0 for padding.
    mask = (frame_idx[None, :] < lengths[:, None]).to(predicted_motion.dtype)
    # Mean over the feature axis first, then a masked mean over frames; this
    # equals the element-wise MSE restricted to valid frames.
    per_frame_err = (predicted_motion - target_motion).pow(2).mean(dim=-1)
    # clamp_min avoids 0/0 when every sample has length 0.
    return (per_frame_err * mask).sum() / mask.sum().clamp_min(1.0)


# TODO: Training loop
def train_epoch(
    model_context,
    model_flow,
    train_loader,
    optimizer_context,
    optimizer_flow,
    device,
    *,
    loss_fn=None,
    grad_clip=None,
    log_interval=100,
):
    """
    Train both networks for one epoch.

    Args:
        model_context: Context encoder, called as ``model_context(motion, captions)``.
        model_flow: Flow network, called as ``model_flow(context, motion)``.
        train_loader: Iterable of ``(captions, input_features, motions, joints,
            lengths)`` batches, the layout yielded by ``create_dataloader``.
        optimizer_context: Optimizer for ``model_context``.
        optimizer_flow: Optimizer for ``model_flow``.
        device: Device the motion tensors are moved to.
        loss_fn: Callable ``(predicted, target, context) -> loss``; defaults to
            this notebook's ``compute_loss``.
        grad_clip: If not None, max global gradient norm across both models.
        log_interval: Print the batch loss every this many batches (0 disables).

    Returns:
        Mean training loss over the epoch as a float, or NaN when the loader
        yields no batches.
    """
    if loss_fn is None:
        loss_fn = compute_loss

    model_context.train()
    model_flow.train()

    total_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(train_loader):
        # The loader yields 5-tuples; only captions and motions are used here.
        captions, _input_features, motions, _joints, _lengths = batch
        motion = motions.to(device)

        # 1. Encode context, 2. flow matching, 3. loss.
        context = model_context(motion, captions)
        predicted_motion = model_flow(context, motion)
        loss = loss_fn(predicted_motion, motion, context)

        optimizer_context.zero_grad()
        optimizer_flow.zero_grad()
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(
                [*model_context.parameters(), *model_flow.parameters()], grad_clip
            )
        optimizer_context.step()
        optimizer_flow.step()

        total_loss += loss.item()
        num_batches += 1

        if log_interval and batch_idx % log_interval == 0:
            print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

    if num_batches == 0:
        # Avoid ZeroDivisionError; NaN never compares as an improvement.
        return float("nan")
    return total_loss / num_batches


# TODO: Validation loop
def validate(model_context, model_flow, val_loader, device, *, loss_fn=None):
    """
    Evaluate both networks on a validation loader without updating weights.

    Args:
        model_context: Context encoder, called as ``model_context(motion, captions)``.
        model_flow: Flow network, called as ``model_flow(context, motion)``.
        val_loader: Iterable of ``(captions, input_features, motions, joints,
            lengths)`` batches, the layout yielded by ``create_dataloader``.
        device: Device the motion tensors are moved to.
        loss_fn: Callable ``(predicted, target, context) -> loss``; defaults to
            this notebook's ``compute_loss``.

    Returns:
        Mean validation loss as a float, or NaN when the loader yields no
        batches (so it never registers as a new best).
    """
    if loss_fn is None:
        loss_fn = compute_loss

    model_context.eval()
    model_flow.eval()

    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for captions, _input_features, motions, _joints, _lengths in val_loader:
            motion = motions.to(device)

            context = model_context(motion, captions)
            predicted_motion = model_flow(context, motion)
            loss = loss_fn(predicted_motion, motion, context)

            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches


# TODO: Training loop with checkpointing
num_epochs = config.num_epochs
best_val_loss = float("inf")
best_checkpoint_path = None  # path of the most recent improved checkpoint

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Train
    train_loss = train_epoch(
        context_encoder,
        flow_matching_net,
        train_loader,
        optimizer_context,
        optimizer_flow,
        device,
    )
    print(f"Train Loss: {train_loss:.4f}")

    # Validate
    val_loss = validate(context_encoder, flow_matching_net, val_loader, device)
    print(f"Val Loss: {val_loss:.4f}")

    # Write a new checkpoint only when validation loss improves. Optimizer
    # states are included so training can be resumed from the file, not just
    # used for inference.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_checkpoint_path = config.checkpoint_dir / f"best_model_epoch_{epoch+1}.pt"
        torch.save(
            {
                "context_encoder": context_encoder.state_dict(),
                "flow_matching_net": flow_matching_net.state_dict(),
                "optimizer_context": optimizer_context.state_dict(),
                "optimizer_flow": optimizer_flow.state_dict(),
                "epoch": epoch,  # 0-based; the filename uses epoch + 1
                "val_loss": val_loss,
            },
            best_checkpoint_path,
        )
        print(f"Saved best model (val_loss: {val_loss:.4f})")

## Step 5: Inference / Generation

Load trained models and generate motion sequences.


In [None]:
import re

# Set this to a specific checkpoint file to override the automatic choice.
checkpoint_path = None


def _latest_best_checkpoint(checkpoint_dir):
    """Return the ``best_model_epoch_<N>.pt`` file with the largest ``N``.

    The training loop writes a new ``best_model_epoch_<N>.pt`` only when the
    validation loss improves, so within one run the highest epoch number is
    the best model. Stale files from earlier runs in the same directory are
    not distinguished -- clear the directory or set ``checkpoint_path``.

    Raises:
        FileNotFoundError: if no matching checkpoint exists.
    """
    pattern = re.compile(r"best_model_epoch_(\d+)\.pt")
    candidates = []
    for path in Path(checkpoint_dir).glob("best_model_epoch_*.pt"):
        match = pattern.fullmatch(path.name)
        if match is not None:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        raise FileNotFoundError(
            f"No best_model_epoch_*.pt checkpoints found in {checkpoint_dir}"
        )
    return max(candidates, key=lambda item: item[0])[1]


if checkpoint_path is None:
    checkpoint_path = _latest_best_checkpoint(config.checkpoint_dir)
print(f"Loading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)

context_encoder.load_state_dict(checkpoint["context_encoder"])
flow_matching_net.load_state_dict(checkpoint["flow_matching_net"])

context_encoder.eval()
flow_matching_net.eval()

print("Models loaded successfully")


# TODO: Generate motion sequences
def generate_motion(
    model_context, model_flow, text_prompt, motion_length=None, device="cuda"
):
    """
    Produce a dim-263 feature sequence for a text prompt.

    Placeholder implementation: both models are switched to eval mode, but
    the output is standard-normal noise; neither the prompt nor the models
    influence the values yet.

    Args:
        model_context: Trained context encoder
        model_flow: Trained flow matching network
        text_prompt: Text description of desired motion (currently unused)
        motion_length: Number of frames; None means config.max_motion_length
        device: Device the sample is moved to before returning to the CPU

    Returns:
        NumPy array of shape (motion_length, config.motion_dim)
    """
    for module in (model_context, model_flow):
        module.eval()

    num_frames = config.max_motion_length if motion_length is None else motion_length

    with torch.no_grad():
        # TODO: encode the prompt into an initial context, then generate
        # autoregressively with the flow matching network.
        # Noise is drawn on the CPU generator and then moved to `device`.
        noise = torch.randn(num_frames, config.motion_dim)
        generated_motion = noise.to(device)

    return generated_motion.cpu().numpy()


# Prompts fed to the (currently placeholder) generator.
text_prompts = [
    "A person is walking forward",
    "A person is running on a treadmill",
    "A person is dancing",
]

# One dim-263 feature sequence per prompt, generated in prompt order.
generated_motions = []
for text in text_prompts:
    motion = generate_motion(
        model_context=context_encoder,
        model_flow=flow_matching_net,
        text_prompt=text,
        device=device,
    )
    generated_motions.append(motion)
    print(f"Generated motion for: '{text}' - Shape: {motion.shape}")

# Decode each feature sequence into joint positions via feature_to_joints
# (expected (nframe, 22, 3) per the original note).
generated_joints = []
for motion in generated_motions:
    joints = feature_to_joints(motion)
    generated_joints.append(joints)
    print(f"Converted to joints - Shape: {joints.shape}")

## Step 6: Post-processing

Convert generated motions to BVH format and save files.


In [None]:
def validate_bvh(bvh_path):
    """Return True if ``bvh_path`` has the basic layout of a BVH file.

    Structural check only: the file must start with ``HIERARCHY``, declare a
    ``ROOT`` joint with balanced braces, and contain a ``MOTION`` section
    whose ``Frames: N`` / ``Frame Time: T`` header is followed by exactly N
    non-empty frame lines. Per-frame channel counts are not checked.
    Unreadable or non-UTF-8 files are reported as invalid.
    """
    try:
        # utf-8-sig tolerates a leading byte-order mark.
        raw_lines = Path(bvh_path).read_text(encoding="utf-8-sig").splitlines()
    except (OSError, UnicodeDecodeError):
        return False
    lines = [line.strip() for line in raw_lines if line.strip()]
    if not lines or lines[0] != "HIERARCHY" or "MOTION" not in lines:
        return False

    motion_idx = lines.index("MOTION")
    hierarchy = lines[1:motion_idx]
    if not hierarchy or not hierarchy[0].startswith("ROOT"):
        return False
    depth = 0
    for line in hierarchy:
        depth += line.count("{") - line.count("}")
        if depth < 0:
            return False
    if depth != 0:
        return False

    header = lines[motion_idx + 1 : motion_idx + 3]
    if (
        len(header) != 2
        or not header[0].startswith("Frames:")
        or not header[1].startswith("Frame Time:")
    ):
        return False
    try:
        num_frames = int(header[0].split(":", 1)[1])
        float(header[1].split(":", 1)[1])
    except ValueError:
        return False
    return len(lines) - (motion_idx + 3) == num_frames


# Output layout: <output_path>/experiment_1/{joints,animation}
output_dir = Path(config.output_path) / "experiment_1"
joints_dir = output_dir / "joints"
animation_dir = output_dir / "animation"

joints_dir.mkdir(parents=True, exist_ok=True)
animation_dir.mkdir(parents=True, exist_ok=True)

# Save each generated motion as raw joints (.npy) and as a BVH animation.
for idx, (joints, text) in enumerate(zip(generated_joints, text_prompts)):
    # Save joint positions as numpy file
    joints_file = joints_dir / f"motion_{idx:04d}.npy"
    np.save(joints_file, joints)
    print(f"Saved joints to {joints_file}")

    # Convert to BVH format
    bvh_data = joints_to_bvh(joints)

    # Save BVH file
    bvh_file = animation_dir / f"motion_{idx:04d}.bvh"
    save_bvh(bvh_data, bvh_file)
    print(f"Saved BVH to {bvh_file}")

    # Structural sanity check of the written file.
    is_valid = validate_bvh(bvh_file)
    print(f"BVH validation: {'Valid' if is_valid else 'Invalid'}")

print(f"\nAll outputs saved to {output_dir}")

## Step 7: Evaluation

Compute evaluation metrics and visualize generated motions.


In [None]:
import itertools


def _iter_val_samples(loader):
    """Yield ``(caption, motion)`` pairs one sample at a time from ``loader``.

    ``loader`` yields ``(captions, input_features, motions, joints, lengths)``
    batches, the layout unpacked in the dataloader check cell. Each motion
    is trimmed to its valid length (dropping padding frames) and converted
    to a NumPy array, the same kind of input ``feature_to_joints`` receives
    for generated motions.
    """
    for captions_b, _features_b, motions_b, _joints_b, lengths_b in loader:
        for caption, motion, length in zip(captions_b, motions_b, lengths_b):
            yield caption, np.asarray(motion[: int(length)])


# Ground-truth samples for comparison, drawn from the validation loader.
# islice stops after `num_val_samples` without consuming the whole loader.
num_val_samples = 10
val_motions = []
val_texts = []
for text, motion in itertools.islice(_iter_val_samples(val_loader), num_val_samples):
    val_motions.append(motion)
    val_texts.append(text)

# Convert validation motions to joints
val_joints = [feature_to_joints(motion) for motion in val_motions]


# TODO: Compute evaluation metrics
def evaluate_generated_motions(
    generated_joints, ground_truth_joints, generated_texts, gt_texts
):
    """
    Return evaluation metrics for generated motions (placeholder).

    Intended metrics:
    - FID (Fréchet Inception Distance) - motion quality
    - Diversity - motion variety
    - R-Precision - text-motion alignment

    Every metric is currently reported as 0.0 and none of the arguments are
    inspected.
    """
    # TODO: replace the zeros with real metric computations.
    metric_names = ("fid", "diversity", "r_precision")
    return dict.fromkeys(metric_names, 0.0)


# Score generated motions against the validation ground truth.
metrics = compute_metrics(generated_joints, val_joints, text_prompts, val_texts)
print("\nEvaluation Metrics:")
for metric_name, value in metrics.items():
    print(f"  {metric_name}: {value:.4f}")

# Save a visualization of each generated motion.
for idx, (joints, text) in enumerate(zip(generated_joints, text_prompts)):
    print(f"\nVisualizing motion {idx+1}: '{text}'")
    visualize_motion(
        joints, title=text, save_path=animation_dir / f"vis_motion_{idx:04d}.png"
    )

# Side-by-side comparisons for up to three samples. The count is bounded by
# the ground-truth lists too: indexing val_joints / val_texts past their end
# would raise IndexError when fewer validation samples were collected.
print("\nComparing generated vs ground truth:")
num_comparisons = min(3, len(generated_joints), len(val_joints), len(val_texts))
for idx in range(num_comparisons):
    print(f"\nSample {idx+1}:")
    print(f"  Generated: '{text_prompts[idx]}'")
    print(f"  Ground Truth: '{val_texts[idx]}'")

    # Visualize comparison
    visualize_motion(
        generated_joints[idx],
        ground_truth=val_joints[idx],
        title=f"Generated vs GT - {idx+1}",
        save_path=animation_dir / f"comparison_{idx:04d}.png",
    )

print("\nEvaluation complete!")