In [None]:

# =============================================
# 1. 依赖库导入及bag2lerobot数据转换相关代码定义
# =============================================
import dataclasses
from pathlib import Path
import shutil
from typing import Literal
import sys
from unittest.mock import patch
from lerobot.scripts.train import train,init_logging
import os
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
import torch
import tqdm
import json
from kuavo_1convert.common.kuavo_dataset import (
    KuavoRosbagReader,
    DEFAULT_JOINT_NAMES_LIST,
    DEFAULT_LEG_JOINT_NAMES,
    DEFAULT_ARM_JOINT_NAMES,
    DEFAULT_HEAD_JOINT_NAMES,
    DEFAULT_CAMERA_NAMES,
    DEFAULT_JOINT_NAMES,
    FPS,
    CONTROL_HAND_SIDE,
    ROBOT_SLICE,
    DEX_SLICE,
)
# Conversion settings. Stays None until the conversion cell loads the JSON
# config file; populate_dataset reads it to decide how each frame is laid out.
config = None
@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    """Immutable set of options that create_empty_dataset forwards to ``LeRobotDataset.create``."""
    # Forwarded as ``use_videos``.
    use_videos: bool = True
    # Forwarded as ``tolerance_s``.
    tolerance_s: float = 0.0001
    # Forwarded as ``image_writer_processes``.
    image_writer_processes: int = 10
    # Forwarded as ``image_writer_threads``.
    image_writer_threads: int = 5
    # Forwarded as ``video_backend``; None is passed through unchanged.
    video_backend: str | None = None

# Shared default instance used as the default argument of the functions below.
DEFAULT_DATASET_CONFIG = DatasetConfig()

# Camera names used when reading images out of the bag data
def get_cameras(bag_data: dict) -> list[str]:
    """Return the configured camera names as a new list.

    ``bag_data`` is accepted for interface compatibility but is not inspected;
    the names are taken from ``DEFAULT_CAMERA_NAMES``.

    ROS topics noted by the original author for reference:
        /camera/color/camera_info           : sensor_msgs/CameraInfo
        /camera/color/image_raw             : sensor_msgs/Image
        /camera/depth/camera_info           : sensor_msgs/CameraInfo
        /camera/depth/image_rect_raw        : sensor_msgs/Image
    """
    return list(DEFAULT_CAMERA_NAMES)

### Create an empty, structured dataset
def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
    root: str,
) -> LeRobotDataset:
    """Declare the feature schema and create an empty ``LeRobotDataset``.

    Args:
        repo_id: Dataset identifier; any existing ``LEROBOT_HOME / repo_id``
            directory is removed before creation.
        robot_type: Passed through to ``LeRobotDataset.create``.
        mode: Used as the ``dtype`` of every camera feature.
        has_velocity: Add an ``observation.velocity`` feature.
        has_effort: Add an ``observation.effort`` feature.
        dataset_config: Writer/video options forwarded to ``LeRobotDataset.create``.
        root: Passed through as the dataset root.

    Returns:
        The dataset object returned by ``LeRobotDataset.create``.
    """
    # Joint and camera names are read from the module globals at call time,
    # so overrides made after import are picked up.
    joint_names = DEFAULT_JOINT_NAMES_LIST
    # TODO: auto detect cameras
    camera_names = DEFAULT_CAMERA_NAMES
    vector_shape = (len(joint_names),)

    # State and action share the same layout: one float32 per joint.
    features = {
        key: {
            "dtype": "float32",
            "shape": vector_shape,
            "names": {"motors": joint_names},
        }
        for key in ("observation.state", "action")
    }

    # Optional per-joint signals; these use a nested-list "names" entry.
    optional_vectors = (
        ("observation.velocity", has_velocity),
        ("observation.effort", has_effort),
    )
    for key, enabled in optional_vectors:
        if enabled:
            features[key] = {
                "dtype": "float32",
                "shape": vector_shape,
                "names": [joint_names],
            }

    # Cameras whose name contains 'depth' are declared single-channel,
    # all others as 3-channel, both at 480x640.
    for cam in camera_names:
        is_depth = 'depth' in cam
        features[f"observation.images.{cam}"] = {
            "dtype": mode,
            "shape": (480, 640) if is_depth else (3, 480, 640),
            "names": ["height", "width"] if is_depth else ["channels", "height", "width"],
        }

    # Drop any previous dataset stored under LEROBOT_HOME for this repo id.
    stale_dir = Path(LEROBOT_HOME / repo_id)
    if stale_dir.exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    creation_kwargs = dict(
        repo_id=repo_id,
        fps=FPS,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
        root=root,
    )
    return LeRobotDataset.create(**creation_kwargs)
### Collect the image data per camera
def load_raw_images_per_camera(bag_data: dict) -> dict[str, np.ndarray]:
    """Stack the ``'data'`` entry of every message of each camera into one array.

    The camera names come from ``get_cameras``; each name is used as a key
    into ``bag_data`` and as a key of the returned mapping.
    """
    return {
        cam_name: np.array([message['data'] for message in bag_data[cam_name]])
        for cam_name in get_cameras(bag_data)
    }
### Load joint states, actions and images of one episode
def load_raw_episode_data(
    ep_path: Path,
) -> tuple[
    dict[str, np.ndarray],
    np.ndarray,
    np.ndarray,
    np.ndarray | None,
    np.ndarray | None,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
]:
    """Read one rosbag and return its per-frame data as NumPy arrays.

    Args:
        ep_path: Path of the rosbag, handed to ``KuavoRosbagReader.process_rosbag``.

    Returns:
        A 9-tuple ``(imgs_per_cam, state, action, velocity, effort, claw_state,
        claw_action, qiangnao_state, qiangnao_action)``:

        * ``imgs_per_cam`` maps camera name to the stacked image array.
        * ``state`` / ``action`` are float32 arrays built from the
          ``'observation.state'`` / ``'action'`` entries.
        * ``velocity`` and ``effort`` are always ``None`` (not extracted).
        * the claw and qiangnao arrays are float64 arrays built from the
          ``'observation.claw'``, ``'action.claw'``, ``'observation.qiangnao'``
          and ``'action.qiangnao'`` entries.
    """
    bag_data = KuavoRosbagReader().process_rosbag(ep_path)

    def stack(key: str, dtype) -> np.ndarray:
        # One array row per message, taken from the message's 'data' entry.
        return np.array([msg['data'] for msg in bag_data[key]], dtype=dtype)

    state = stack('observation.state', np.float32)
    action = stack('action', np.float32)
    claw_state = stack('observation.claw', np.float64)
    claw_action = stack('action.claw', np.float64)
    qiangnao_state = stack('observation.qiangnao', np.float64)
    qiangnao_action = stack('action.qiangnao', np.float64)

    # Velocity and effort are not read from the bag; callers treat None as "absent".
    velocity = None
    effort = None

    imgs_per_cam = load_raw_images_per_camera(bag_data)

    return (
        imgs_per_cam,
        state,
        action,
        velocity,
        effort,
        claw_state,
        claw_action,
        qiangnao_state,
        qiangnao_action,
    )
# Print the dimensions of every field in a frame
def diagnose_frame_data(data):
    """Print name, shape, dtype and Python type of each entry of ``data``.

    ``data`` is a mapping whose values expose ``.shape`` and ``.dtype``.
    A 40-character dashed line is printed after each entry.
    """
    divider = "-" * 40
    for field_name in data:
        value = data[field_name]
        kind = type(value).__name__
        print(f"Field: {field_name}")
        print(f"  Shape    : {value.shape}")
        print(f"  Dtype    : {value.dtype}")
        print(f"  Type     : {kind}")
        print(divider)

# Assemble the joint vector of a single frame (helper for populate_dataset)
def _assemble_joint_vector(
    frame_idx: int,
    robot: np.ndarray,
    claw: np.ndarray,
    dex: np.ndarray,
) -> np.ndarray:
    """Concatenate robot-joint and end-effector values of one frame.

    The layout is selected by the module-level ``config``:

    * ``only_half_up_body`` + ``use_lejuclaw``:
      robot[12:19], claw[0], robot[19:26], claw[1]
    * ``only_half_up_body`` + ``use_qiangnao``:
      for each side selected by ``CONTROL_HAND_SIDE`` ("left", "right" or
      "both"), the ``ROBOT_SLICE`` range of ``robot`` followed by the
      ``DEX_SLICE`` range of ``dex``
    * full body + ``use_lejuclaw``:
      robot[0:19], claw[0], robot[19:26], claw[1], robot[26:28]
    * full body + ``use_qiangnao``:
      robot[0:19], dex[0:6], robot[19:26], dex[6:12], robot[26:28]

    Only the arrays needed by the selected layout are indexed, so the unused
    end-effector array may be empty.

    Returns:
        A 1-D float32 array.

    Raises:
        ValueError: if neither end effector is enabled in ``config`` or
            ``CONTROL_HAND_SIDE`` is not one of the supported values.
    """
    row = robot[frame_idx]

    if config['only_half_up_body']:
        if config['use_lejuclaw']:
            parts = [
                row[12:19],
                [claw[frame_idx, 0]],
                row[19:26],
                [claw[frame_idx, 1]],
            ]
        elif config['use_qiangnao']:
            side_indices = {"left": (0,), "right": (1,), "both": (0, 1)}.get(CONTROL_HAND_SIDE)
            if side_indices is None:
                raise ValueError(f"Unsupported CONTROL_HAND_SIDE: {CONTROL_HAND_SIDE!r}")
            parts = []
            for side in side_indices:
                parts.append(row[ROBOT_SLICE[side][0]:ROBOT_SLICE[side][-1]])
                parts.append(dex[frame_idx, DEX_SLICE[side][0]:DEX_SLICE[side][-1]])
        else:
            raise ValueError("config must enable either 'use_lejuclaw' or 'use_qiangnao'")
    else:
        if config['use_lejuclaw']:
            # The right claw goes after the right-arm joints so the vector
            # follows the joint-name order: legs, left arm, left claw,
            # right arm, right claw, head.
            parts = [
                row[0:19],
                [claw[frame_idx, 0]],
                row[19:26],
                [claw[frame_idx, 1]],
                row[26:28],
            ]
        elif config['use_qiangnao']:
            parts = [
                row[0:19],
                dex[frame_idx, 0:6],
                row[19:26],
                dex[frame_idx, 6:12],
                row[26:28],
            ]
        else:
            raise ValueError("config must enable either 'use_lejuclaw' or 'use_qiangnao'")

    return np.concatenate([np.asarray(part, dtype=np.float32) for part in parts], axis=0)


# Build the frames of every episode and write them into the dataset
def populate_dataset(
    dataset: LeRobotDataset,
    bag_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    """Convert the selected rosbags into episodes of ``dataset``.

    Args:
        dataset: Target dataset; ``add_frame`` is called once per frame and
            ``save_episode(task=task)`` once per bag.
        bag_files: Rosbag paths, indexed by the values in ``episodes``.
        task: Task label stored with every episode.
        episodes: Indices into ``bag_files`` to convert; ``None`` converts all.

    Returns:
        The same ``dataset`` object.

    Raises:
        ValueError: if ``config`` selects no supported end effector (see
            ``_assemble_joint_vector``).
    """
    if episodes is None:
        episodes = range(len(bag_files))

    for ep_idx in tqdm.tqdm(episodes):
        # All topics are read; a topic missing from the bag is expected to
        # yield an empty array.
        (
            imgs_per_cam,
            state,
            action,
            velocity,
            effort,
            claw_state,
            claw_action,
            qiangnao_state,
            qiangnao_action,
        ) = load_raw_episode_data(bag_files[ep_idx])

        # Binarize the dexterous-hand values: above 50 becomes 1, otherwise 0.
        qiangnao_state = np.where(qiangnao_state > 50, 1, 0)
        qiangnao_action = np.where(qiangnao_action > 50, 1, 0)

        # The number of frames is taken from the robot state array.
        for i in range(state.shape[0]):
            output_state = _assemble_joint_vector(i, state, claw_state, qiangnao_state)
            output_action = _assemble_joint_vector(i, action, claw_action, qiangnao_action)

            frame = {
                "observation.state": torch.from_numpy(output_state).type(torch.float32),
                "action": torch.from_numpy(output_action).type(torch.float32),
            }

            for camera, img_array in imgs_per_cam.items():
                frame[f"observation.images.{camera}"] = img_array[i]

            if velocity is not None:
                frame["observation.velocity"] = velocity[i]
            if effort is not None:
                frame["observation.effort"] = effort[i]

            # diagnose_frame_data(frame)
            dataset.add_frame(frame)
        dataset.save_episode(task=task)
    return dataset

# Convert a directory of rosbags into a LeRobot dataset
def port_kuavo_rosbag(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = False,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "video",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
    root: str,
    n: int | None = None,
):
    """Create a dataset from the rosbags found in ``raw_dir`` and consolidate it.

    Args:
        raw_dir: Directory handed to ``KuavoRosbagReader.list_bag_files``.
        repo_id: Dataset identifier; an existing ``LEROBOT_HOME / repo_id``
            directory is deleted first.
        raw_repo_id: Not used by this function.
        task: Task label stored with every episode.
        episodes: Indices of the bag files to convert; ``None`` converts all.
        push_to_hub: Not used by this function.
        is_mobile: Not used by this function.
        mode: Camera feature dtype passed to ``create_empty_dataset``.
        dataset_config: Options passed to ``create_empty_dataset``.
        root: Dataset root passed to ``create_empty_dataset``.
        n: When a positive int, that many bag files are drawn at random
            (without replacement) before conversion.

    Returns:
        None.
    """
    # Remove a previous dataset stored under LEROBOT_HOME for this repo id.
    previous_dataset_dir = LEROBOT_HOME / repo_id
    if previous_dataset_dir.exists():
        shutil.rmtree(previous_dataset_dir)

    bag_files = KuavoRosbagReader().list_bag_files(raw_dir)

    # Optionally keep only a random subset of n bag files.
    wants_subset = isinstance(n, int) and n > 0
    if wants_subset:
        chosen = np.random.choice(len(bag_files), n, replace=False)
        bag_files = [bag_files[idx] for idx in chosen]

    empty_dataset = create_empty_dataset(
        repo_id,
        robot_type="kuavo4pro",
        mode=mode,
        has_effort=False,
        has_velocity=False,
        dataset_config=dataset_config,
        root=root,
    )
    filled_dataset = populate_dataset(
        empty_dataset,
        bag_files,
        task=task,
        episodes=episodes,
    )
    filled_dataset.consolidate()
    


In [None]:
# =============================================
# 2. Run the bag -> LeRobot dataset conversion
# =============================================
##########  Raw dataset directory (raw_dir)  ##########
raw_dir = "/home/lejurobot/tjy/Task20_doorknob/testbag"
##########  Version tag used in the generated lerobot output path    ##########
version = "v0"
##########  Optional: number of bags to sample from raw_dir; None converts all of them ##########
n = None
# The task name (and therefore the repo id) is the last path component of raw_dir.
task_name = os.path.basename(raw_dir)
repo_id = f'lerobot/{task_name}'
# Output location: <raw_dir>/../<version>/lerobot
lerobot_dir = os.path.join(raw_dir, "../", version, "lerobot")

# Start from a clean output directory.
if os.path.exists(lerobot_dir):
    shutil.rmtree(lerobot_dir)

# NOTE: "__file__" is a string literal here, so current_dir resolves to the
# current working directory and the config path is relative to it.
current_dir = os.path.dirname(os.path.abspath("__file__"))
config_file = os.path.join(current_dir, "kuavo_1convert/config/claw.json")
# Load the conversion config and derive the joint-name lists from it.
with open(config_file, 'r') as f:
    # Rebinds the module-level `config` that populate_dataset reads.
    config = json.load(f)
    if config['only_half_up_body']:
        # Half body: only arm joints plus the end-effector joints are named.
        if config['use_lejuclaw']:
            DEFAULT_ARM_JOINT_NAMES =[
            "zarm_l1_link", "zarm_l2_link", "zarm_l3_link", "zarm_l4_link", "zarm_l5_link", "zarm_l6_link", "zarm_l7_link","left_claw",
            "zarm_r1_link", "zarm_r2_link", "zarm_r3_link", "zarm_r4_link", "zarm_r5_link", "zarm_r6_link", "zarm_r7_link","right_claw",
            ]
        elif config['use_qiangnao']:
            DEFAULT_ARM_JOINT_NAMES =[
            "zarm_l1_link", "zarm_l2_link", "zarm_l3_link", "zarm_l4_link", "zarm_l5_link", "zarm_l6_link", "zarm_l7_link","left_qiangnao_1",
            "left_qiangnao_2",
            "left_qiangnao_3","left_qiangnao_4","left_qiangnao_5","left_qiangnao_6",
            "zarm_r1_link", "zarm_r2_link", "zarm_r3_link", "zarm_r4_link", "zarm_r5_link", "zarm_r6_link", "zarm_r7_link","right_qiangnao_1",
            "right_qiangnao_2",
            "right_qiangnao_3","right_qiangnao_4","right_qiangnao_5","right_qiangnao_6",
            ]                
            # Index ranges into the 26-name list above, derived from ROBOT_SLICE /
            # DEX_SLICE with fixed offsets; the list is then reduced to the names
            # those ranges select.
            # NOTE(review): the offsets presumably assume the arm joints start at
            # robot index 12 and that each hand has six dex joints - confirm
            # against kuavo_dataset.
            arm_slice = [
                (ROBOT_SLICE[0][0] - 12, ROBOT_SLICE[0][-1] - 12),(DEX_SLICE[0][0] + 7, DEX_SLICE[0][-1] + 7), 
                (ROBOT_SLICE[1][0] - 12 + 6, ROBOT_SLICE[1][-1] - 12 + 6), (DEX_SLICE[1][0] + 14, DEX_SLICE[1][-1] + 14)
                ]
            DEFAULT_ARM_JOINT_NAMES = [DEFAULT_ARM_JOINT_NAMES[k] for l, r in arm_slice for k in range(l, r)]  

        # If neither end effector is enabled, the imported DEFAULT_ARM_JOINT_NAMES is used as is.
        DEFAULT_JOINT_NAMES_LIST = DEFAULT_ARM_JOINT_NAMES
    else:
        # Full body: leg + arm (with end-effector joints) + head names.
        if config['use_lejuclaw']:
            DEFAULT_ARM_JOINT_NAMES=[
            "zarm_l1_link", "zarm_l2_link", "zarm_l3_link", "zarm_l4_link", "zarm_l5_link", "zarm_l6_link", "zarm_l7_link","left_claw",
            "zarm_r1_link", "zarm_r2_link", "zarm_r3_link", "zarm_r4_link", "zarm_r5_link", "zarm_r6_link", "zarm_r7_link","right_claw",
            ]
        elif config['use_qiangnao']:
            DEFAULT_ARM_JOINT_NAMES=[
            "zarm_l1_link", "zarm_l2_link", "zarm_l3_link", "zarm_l4_link", "zarm_l5_link", "zarm_l6_link", "zarm_l7_link","left_qiangnao_1","left_qiangnao_2","left_qiangnao_3","left_qiangnao_4","left_qiangnao_5","left_qiangnao_6",
            "zarm_r1_link", "zarm_r2_link", "zarm_r3_link", "zarm_r4_link", "zarm_r5_link", "zarm_r6_link", "zarm_r7_link","right_qiangnao_1","right_qiangnao_2","right_qiangnao_3","right_qiangnao_4","right_qiangnao_5","right_qiangnao_6",
            ]           
        DEFAULT_JOINT_NAMES_LIST = DEFAULT_LEG_JOINT_NAMES + DEFAULT_ARM_JOINT_NAMES + DEFAULT_HEAD_JOINT_NAMES
    # ----------------------------------------------------------------
    # config["record_topics_qiangnao"][7]="/control_robot_hand_position_action"
    # ----------------------------------------------------------------
# Rebuild the name mapping from the (possibly overridden) lists so that later
# use does not conflict with the values imported at the top.
DEFAULT_JOINT_NAMES = {
    "full_joint_names": DEFAULT_LEG_JOINT_NAMES + DEFAULT_ARM_JOINT_NAMES + DEFAULT_HEAD_JOINT_NAMES,
    "leg_joint_names": DEFAULT_LEG_JOINT_NAMES,
    "arm_joint_names": DEFAULT_ARM_JOINT_NAMES,
    "head_joint_names": DEFAULT_HEAD_JOINT_NAMES,
}
### Start the conversion
port_kuavo_rosbag(raw_dir, repo_id, root=lerobot_dir, n=n)

In [None]:
# =============================================
# 3. Train the model
# =============================================
# sys.argv is only replaced inside the `with` block below; patch.object
# restores the original value on exit.
original_argv = sys.argv
# Command-line overrides handed to the lerobot training entry point,
# flattened into the usual [flag, value, flag, value, ...] form.
args = [
    token
    for option in (
        ("--dataset.repo_id", repo_id),
        ("--policy.type", "act"),
        ("--device", "cuda"),
        ("--dataset.local_files_only", "true"),
        ("--dataset.root", lerobot_dir),
    )
    for token in option
]
with patch.object(sys, "argv", ["train.py", *args]):
    init_logging()
    train()

# The final trained model is written under the kuavo_il path (repo_idoutput)