# convert tf record to png

In [None]:
# use bucket as example
import os
import tensorflow as tf
import numpy as np
import pandas as pd
from tqdm import tqdm
import imageio.v2 as imageio  

from utils import DATA_DIR

# Source TFRecord shards and the root folder that receives the extracted PNG frames.
tfrecord_dir = DATA_DIR / "tfds_datasets/bucket_dex_art_dataset/1.0.0"
tfrecord_dir = str(tfrecord_dir)

base_save_dir = DATA_DIR / "planning_datasets/bucket_dex_art_dataset/dexart_all_bucket_png"
base_save_dir = str(base_save_dir)
# Observation keys to export; each one gets its own sub-folder under base_save_dir.
camera_views = [
    "bucket_viz"
]

os.makedirs(base_save_dir, exist_ok=True)

# Only the train-split shards are read.
tfrecord_files = sorted([
    os.path.join(tfrecord_dir, f)
    for f in os.listdir(tfrecord_dir)
    if f.startswith("dex_art_dataset-train.tfrecord")
])

for view in camera_views:
    print(f"\nstart export: {view}")
    save_dir = os.path.join(base_save_dir, view)
    os.makedirs(save_dir, exist_ok=True)

    # episode_id -> number of encoded frames found for this view
    frame_counts = {}

    def parse_example(example_proto):
        """Parse one serialized episode into its encoded frames and source file path."""
        feature_description = {
            f"steps/observation/{view}": tf.io.VarLenFeature(tf.string),
            "episode_metadata/file_path": tf.io.FixedLenFeature([], tf.string),
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    for tfrecord_path in tqdm(tfrecord_files, desc=f"处理 TFRecord - {view}"):
        raw_dataset = tf.data.TFRecordDataset(tfrecord_path)

        for raw_record in raw_dataset:
            example = parse_example(raw_record)
            episode_path = example["episode_metadata/file_path"].numpy().decode()
            # The episode folder is named after the source file without its extension.
            episode_id = os.path.splitext(os.path.basename(episode_path))[0]

            try:
                frames = tf.sparse.to_dense(example[f"steps/observation/{view}"]).numpy()
            except Exception as e:
                # Catch Exception (not a bare except) so Ctrl-C still interrupts the export,
                # and report why the episode was skipped.
                print(f"[skip] {episode_id} missing {view}: {e}")
                continue

            if len(frames) == 0:
                continue

            frame_counts[episode_id] = len(frames)
            print(f"[frame count] {episode_id} has {len(frames)} frames")

            episode_dir = os.path.join(save_dir, episode_id)
            os.makedirs(episode_dir, exist_ok=True)

            for i, img_bytes in enumerate(frames):
                try:
                    img = tf.io.decode_png(img_bytes).numpy()
                    if img.shape[-1] == 4:
                        img = img[:, :, :3]  # drop the alpha channel
                    # NOTE(review): 3-digit zero padding; with 1000+ frames the names
                    # would no longer sort lexicographically.
                    frame_path = os.path.join(episode_dir, f"frame_{i:03d}.png")
                    imageio.imwrite(frame_path, img)
                except Exception as e:
                    # A single bad frame is reported and skipped; the episode continues.
                    print(f"[skip frame] {episode_id} frame {i} decode failed: {e}")

    # Per-view summary of frame counts, also written next to the frames as CSV.
    frame_values = list(frame_counts.values())
    if frame_values:
        print(f"\nframe count statistics - {view}")
        print(f"total video count: {len(frame_values)}")
        print(f"average frame count: {np.mean(frame_values):.2f}")
        print(f"max frame count: {np.max(frame_values)}")
        print(f"min frame count: {np.min(frame_values)}")

        df = pd.DataFrame({
            "episode_id": list(frame_counts.keys()),
            "frame_count": frame_values
        })
        csv_path = os.path.join(save_dir, f"{view}_frame_counts.csv")
        df.to_csv(csv_path, index=False)
        print(f"frame count statistics saved to: {csv_path}")
    else:
        print(f"no valid frame data found - {view}")

print("\nfinished")


# gripper position

###  gpos for bucket

In [None]:

import os
import json
import numpy as np
import tensorflow as tf
import transforms3d
from tqdm import tqdm
import re
import shutil
import glob
import imageio.v2 as imageio
import cv2

from utils import DATA_DIR

# Inputs: the TFRecord shards and the PNG frames exported for the "bucket_viz" view.
tfrecord_dir = str(DATA_DIR / "tfds_datasets/bucket_dex_art_dataset/1.0.0")
video_dir = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/dexart_all_bucket_png/bucket_viz")

# Outputs: annotated check videos and the per-episode pixel tracks.
save_video_dir = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/bucket/vis_gpos")
save_json_path = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/bucket/gripper_positions.json")

# Points are projected at CAMERA_RES and then rescaled to TARGET_RES pixels.
TARGET_RES = 224
CAMERA_RES = 1000
SCALE = TARGET_RES / CAMERA_RES


# Start from an empty video folder: remove whatever is already in it.
os.makedirs(save_video_dir, exist_ok=True)
for filename in os.listdir(save_video_dir):
    file_path = os.path.join(save_video_dir, filename)
    try:
        is_file_or_link = os.path.isfile(file_path) or os.path.islink(file_path)
        if is_file_or_link:
            os.unlink(file_path)
        elif os.path.isdir(file_path):
            shutil.rmtree(file_path)
    except Exception as e:
        # Best effort: report the entry that could not be removed and keep going.
        print(f'delete {file_path} failed: {e}')

def get_bucket_extrinsics():
    """Return the fixed camera rotation matrix and position as ``(R, t)``.

    ``R`` is the 3x3 matrix of the quaternion built from the Euler angles
    (pi/2, pi, 0); ``t`` is the position (0, 1, 0.5). The projection code in
    this cell applies them as ``R @ (p - t)``.
    """
    camera_position = np.array([0, 1, 0.5])
    camera_quat = transforms3d.euler.euler2quat(np.pi / 2, np.pi, 0)
    rotation = transforms3d.quaternions.quat2mat(camera_quat)
    return rotation, camera_position

def get_bucket_intrinsics(fov_deg=69.4):
    """Build a 3x3 pinhole intrinsic matrix for a square CAMERA_RES image.

    Args:
        fov_deg: field of view in degrees used to derive the focal length.

    Returns:
        ``[[f, 0, c], [0, f, c], [0, 0, 1]]`` with ``f = CAMERA_RES / (2 tan(fov/2))``
        and ``c = CAMERA_RES / 2``.
    """
    half_fov = np.deg2rad(fov_deg) / 2
    focal = CAMERA_RES / (2 * np.tan(half_fov))
    center = CAMERA_RES / 2
    return np.array([
        [focal, 0, center],
        [0, focal, center],
        [0, 0, 1],
    ])

def project_gripper_to_image(pos_world, K, R, t):
    """Project a 3D point to integer pixel coordinates with a pinhole model.

    Args:
        pos_world: length-3 point.
        K: 3x3 intrinsic matrix.
        R: 3x3 rotation applied to ``pos_world - t``.
        t: length-3 offset subtracted before rotating.

    Returns:
        ``[u, v]`` rounded to ints, or ``[-1, -1]`` when the rotated point has
        non-positive depth (third component <= 0).
    """
    cam_x, cam_y, depth = R @ (pos_world - t)
    if depth <= 0:
        # Not in front of the camera: report the sentinel.
        return [-1, -1]
    pixel_u = (K[0, 0] * cam_x / depth) + K[0, 2]
    pixel_v = (K[1, 1] * cam_y / depth) + K[1, 2]
    return [int(round(pixel_u)), int(round(pixel_v))]

def parse_single_example(example_proto):
    """Parse one serialized episode into its file path and flat float32 state values.

    Returns the dict produced by ``tf.io.parse_single_example``:
    ``episode_metadata/file_path`` is a scalar string and ``steps/state`` is a
    sparse float32 tensor (variable length).
    """
    episode_features = {
        "episode_metadata/file_path": tf.io.FixedLenFeature([], tf.string),
        "steps/state": tf.io.VarLenFeature(tf.float32),
    }
    return tf.io.parse_single_example(example_proto, features=episode_features)

def run():
    """Project gripper positions to pixels, render check videos, and save the tracks.

    For every episode in the TFRecords under ``tfrecord_dir``, the flat
    ``steps/state`` values are reshaped to rows of 33 and columns 28:31 are
    used as the gripper xyz position. Each position is projected with the
    camera model of this cell and rescaled to ``TARGET_RES``; points that fail
    to project, or land at negative pixel coordinates, are stored as
    ``[-1, -1]``. The tracks are written to ``save_json_path`` keyed by the
    numeric episode id, and an annotated mp4 per episode is written to
    ``save_video_dir`` when its PNG frames exist under ``video_dir``.
    """
    K = get_bucket_intrinsics()
    R, t = get_bucket_extrinsics()
    uv_scaled_dict = {}

    # The world origin projects to the same pixel in every frame, so compute it once.
    origin_uv = project_gripper_to_image(np.array([0, 0, 0]), K, R, t)
    origin_uv_scaled = None
    if origin_uv[0] >= 0 and origin_uv[1] >= 0:
        origin_uv_scaled = (int(origin_uv[0] * SCALE), int(origin_uv[1] * SCALE))

    tfrecord_files = sorted([
        os.path.join(tfrecord_dir, f)
        for f in os.listdir(tfrecord_dir)
        if f.endswith(".tfrecord") or ".tfrecord-" in f
    ])

    for tf_file in tqdm(tfrecord_files, desc="processing tfrecords"):
        dataset = tf.data.TFRecordDataset(tf_file)
        for raw_record in dataset:
            example = parse_single_example(raw_record)
            file_path = example["episode_metadata/file_path"].numpy().decode()
            state_seq = tf.sparse.to_dense(example["steps/state"]).numpy()

            if len(state_seq) == 0:
                continue

            try:
                states = state_seq.reshape(-1, 33)
            except ValueError:
                # The flat state length is not a multiple of the expected row width.
                print(f"[skip] {file_path}: state length {len(state_seq)} is not a multiple of 33")
                continue

            gripper_pos_seq = states[:, 28:31]
            uv_full = [project_gripper_to_image(p, K, R, t) for p in gripper_pos_seq]
            uv_scaled = [[int(u * SCALE), int(v * SCALE)] if u >= 0 and v >= 0 else [-1, -1] for u, v in uv_full]

            match = re.search(r"episode_(\d+)_combined\.npz$", file_path)
            if not match:
                continue
            episode_id = match.group(1)
            uv_scaled_dict[episode_id] = uv_scaled

            frame_dir = os.path.join(video_dir, f"episode_{episode_id}_combined")

            if not os.path.exists(frame_dir):
                print(f"[skip] missing frame folder: {frame_dir}")
                continue

            frame_paths = sorted(glob.glob(os.path.join(frame_dir, "frame_*.png")))
            if len(frame_paths) == 0:
                print(f"[skip] no image frames: {frame_dir}")
                continue

            fps = 10
            print(f"[align check] {episode_id}: state frame count={states.shape[0]}, image frame count={len(frame_paths)}")

            out_path = os.path.join(save_video_dir, f"{episode_id}_vis.mp4")
            out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (TARGET_RES, TARGET_RES))
            try:
                # zip stops at the shorter of frames / states, so extra frames are ignored.
                for frame_path, (u, v), pos_xyz in zip(frame_paths, uv_scaled, gripper_pos_seq):
                    frame = imageio.imread(frame_path)
                    frame = cv2.resize(frame, (TARGET_RES, TARGET_RES))
                    if frame.ndim == 3:
                        # imageio yields RGB but cv2.VideoWriter expects BGR.
                        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                    # Clamp so the marker is always drawn inside the image.
                    u = max(0, min(u, TARGET_RES - 1))
                    v = max(0, min(v, TARGET_RES - 1))
                    cv2.circle(frame, (u, v), radius=4, color=(0, 0, 255), thickness=-1)

                    text = f"x={pos_xyz[0]:.3f}, y={pos_xyz[1]:.3f}, z={pos_xyz[2]:.3f}"
                    cv2.putText(frame, text, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 255), 1)

                    if origin_uv_scaled is not None:
                        cv2.circle(frame, origin_uv_scaled, radius=4, color=(0, 255, 0), thickness=-1)
                        cv2.putText(frame, "origin", (origin_uv_scaled[0] + 5, origin_uv_scaled[1] - 5),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)

                    out.write(frame)
            finally:
                # Release the writer even if a frame fails to load or draw.
                out.release()
            print(f"[save video] {out_path}")

    with open(save_json_path, "w") as f:
        json.dump(uv_scaled_dict, f, indent=2)
    print(f"[save json] {save_json_path}")

run()


### (Optional) If you are just running the bucket example, please skip this cell. Run it only when you need the gripper_position for the other three DexArt objects, as their data dimensions differ from those of the bucket.

In [None]:

import os
import json
import numpy as np
import tensorflow as tf
import transforms3d
from tqdm import tqdm
import re
import shutil
import glob
import imageio.v2 as imageio
import cv2

from utils import DATA_DIR

# Inputs: the TFRecord shards and the exported PNG frames.
# NOTE(review): these paths still point at the bucket dataset although this cell is
# described as being for the other DexArt objects — presumably edited per object; confirm.
tfrecord_dir = str(DATA_DIR / "tfds_datasets/bucket_dex_art_dataset/1.0.0")
video_dir = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/dexart_all_bucket_png/bucket_viz")

# Outputs: annotated check videos and the per-episode pixel tracks.
save_video_dir = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/results_gpos/bucket/vis_gpos")
save_json_path = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/results_gpos/bucket/gripper_positions.json")

# Points are projected at CAMERA_RES and then rescaled to TARGET_RES pixels.
TARGET_RES = 224
CAMERA_RES = 1000
SCALE = TARGET_RES / CAMERA_RES

# Start from an empty video folder: remove whatever is already in it.
os.makedirs(save_video_dir, exist_ok=True)
for filename in os.listdir(save_video_dir):
    file_path = os.path.join(save_video_dir, filename)
    try:
        is_file_or_link = os.path.isfile(file_path) or os.path.islink(file_path)
        if is_file_or_link:
            os.unlink(file_path)
        elif os.path.isdir(file_path):
            shutil.rmtree(file_path)
    except Exception as e:
        # Best effort: report the entry that could not be removed and keep going.
        print(f'delete {file_path} failed: {e}')



def get_bucket_extrinsics():
    """Return the world-to-camera rotation and the camera position as ``(R, t)``.

    The quaternion built from Euler angles (pi/3, pi, 0) is converted to a
    rotation matrix and transposed before being returned; ``t`` is the
    position (0, 0.8, 0.5).
    """
    # === faucet_viz2 ===
    # NOTE(review): the function is still named "bucket" although the values are
    # labelled faucet_viz2 — confirm they match the object being processed.
    camera_position = np.array([0, 0.8, 0.5])
    camera_quat = transforms3d.euler.euler2quat(np.pi / 3, np.pi, 0)

    # The quaternion gives camera-to-world; its transpose is the inverse rotation.
    cam_to_world = transforms3d.quaternions.quat2mat(camera_quat)
    world_to_cam = cam_to_world.T
    return world_to_cam, camera_position


def get_bucket_intrinsics(fov_deg=69.4):
    """Build a 3x3 pinhole intrinsic matrix for a square CAMERA_RES image.

    Args:
        fov_deg: field of view in degrees used to derive the focal length.

    Returns:
        ``[[f, 0, c], [0, f, c], [0, 0, 1]]`` with ``f = CAMERA_RES / (2 tan(fov/2))``
        and ``c = CAMERA_RES / 2``.
    """
    half_fov = np.deg2rad(fov_deg) / 2
    focal = CAMERA_RES / (2 * np.tan(half_fov))
    center = CAMERA_RES / 2
    return np.array([
        [focal, 0, center],
        [0, focal, center],
        [0, 0, 1],
    ])

def project_gripper_to_image(pos_world, K, R, t):
    """Project a 3D point to integer pixel coordinates with a pinhole model.

    Args:
        pos_world: length-3 point.
        K: 3x3 intrinsic matrix.
        R: 3x3 rotation applied to ``pos_world - t``.
        t: length-3 offset subtracted before rotating.

    Returns:
        ``[u, v]`` rounded to ints, or ``[-1, -1]`` when the rotated point has
        non-positive depth (third component <= 0).
    """
    cam_x, cam_y, depth = R @ (pos_world - t)
    if depth <= 0:
        # Not in front of the camera: report the sentinel.
        return [-1, -1]
    pixel_u = (K[0, 0] * cam_x / depth) + K[0, 2]
    pixel_v = (K[1, 1] * cam_y / depth) + K[1, 2]
    return [int(round(pixel_u)), int(round(pixel_v))]

def parse_single_example(example_proto):
    """Parse one serialized episode into its file path and flat float32 state values.

    Returns the dict produced by ``tf.io.parse_single_example``:
    ``episode_metadata/file_path`` is a scalar string and ``steps/state`` is a
    sparse float32 tensor (variable length).
    """
    episode_features = {
        "episode_metadata/file_path": tf.io.FixedLenFeature([], tf.string),
        "steps/state": tf.io.VarLenFeature(tf.float32),
    }
    return tf.io.parse_single_example(example_proto, features=episode_features)

def run():
    """Project gripper positions to pixels, render check videos, and save the tracks.

    For every episode in the TFRecords under ``tfrecord_dir``, the flat
    ``steps/state`` values are reshaped to rows of 32 and columns -4:-1 are
    used as the gripper xyz position. Each position is projected with the
    camera model of this cell and rescaled to ``TARGET_RES``; points that fail
    to project, or land at negative pixel coordinates, are stored as
    ``[-1, -1]``. The tracks are written to ``save_json_path`` keyed by the
    numeric episode id, and an annotated mp4 per episode is written to
    ``save_video_dir`` when its PNG frames exist under ``video_dir``.
    """
    K = get_bucket_intrinsics()
    R, t = get_bucket_extrinsics()
    uv_scaled_dict = {}

    # The world origin projects to the same pixel in every frame, so compute it once.
    origin_uv = project_gripper_to_image(np.array([0, 0, 0]), K, R, t)
    origin_uv_scaled = None
    if origin_uv[0] >= 0 and origin_uv[1] >= 0:
        origin_uv_scaled = (int(origin_uv[0] * SCALE), int(origin_uv[1] * SCALE))

    tfrecord_files = sorted([
        os.path.join(tfrecord_dir, f)
        for f in os.listdir(tfrecord_dir)
        if f.endswith(".tfrecord") or ".tfrecord-" in f
    ])

    for tf_file in tqdm(tfrecord_files, desc="处理TFRecords"):
        dataset = tf.data.TFRecordDataset(tf_file)
        for raw_record in dataset:
            example = parse_single_example(raw_record)
            file_path = example["episode_metadata/file_path"].numpy().decode()
            state_seq = tf.sparse.to_dense(example["steps/state"]).numpy()

            if len(state_seq) == 0:
                continue

            try:
                states = state_seq.reshape(-1, 32)
            except ValueError:
                # The flat state length is not a multiple of the expected row width.
                print(f"[skip] {file_path}: state length {len(state_seq)} is not a multiple of 32")
                continue

            gripper_pos_seq = states[:, -4:-1]
            uv_full = [project_gripper_to_image(p, K, R, t) for p in gripper_pos_seq]
            uv_scaled = [[int(u * SCALE), int(v * SCALE)] if u >= 0 and v >= 0 else [-1, -1] for u, v in uv_full]

            match = re.search(r"episode_(\d+)_combined\.npz$", file_path)
            if not match:
                continue
            episode_id = match.group(1)
            uv_scaled_dict[episode_id] = uv_scaled

            frame_dir = os.path.join(video_dir, f"episode_{episode_id}_combined")

            if not os.path.exists(frame_dir):
                print(f"[skip] missing frame folder: {frame_dir}")
                continue

            frame_paths = sorted(glob.glob(os.path.join(frame_dir, "frame_*.png")))
            if len(frame_paths) == 0:
                print(f"[skip] no image frames: {frame_dir}")
                continue

            fps = 10
            print(f"[align check] {episode_id}: state frame count={states.shape[0]}, image frame count={len(frame_paths)}")

            out_path = os.path.join(save_video_dir, f"{episode_id}_vis.mp4")
            out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (TARGET_RES, TARGET_RES))
            try:
                # zip stops at the shorter of frames / states, so extra frames are ignored.
                for frame_path, (u, v), pos_xyz in zip(frame_paths, uv_scaled, gripper_pos_seq):
                    frame = imageio.imread(frame_path)
                    frame = cv2.resize(frame, (TARGET_RES, TARGET_RES))
                    if frame.ndim == 3:
                        # imageio yields RGB but cv2.VideoWriter expects BGR.
                        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                    # Clamp so the marker is always drawn inside the image.
                    u = max(0, min(u, TARGET_RES - 1))
                    v = max(0, min(v, TARGET_RES - 1))
                    cv2.circle(frame, (u, v), radius=4, color=(0, 0, 255), thickness=-1)

                    text = f"x={pos_xyz[0]:.3f}, y={pos_xyz[1]:.3f}, z={pos_xyz[2]:.3f}"
                    cv2.putText(frame, text, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 255), 1)

                    if origin_uv_scaled is not None:
                        cv2.circle(frame, origin_uv_scaled, radius=4, color=(0, 255, 0), thickness=-1)
                        cv2.putText(frame, "origin", (origin_uv_scaled[0] + 5, origin_uv_scaled[1] - 5),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)

                    out.write(frame)
            finally:
                # Release the writer even if a frame fails to load or draw.
                out.release()
            print(f"[save video] {out_path}")

    with open(save_json_path, "w") as f:
        json.dump(uv_scaled_dict, f, indent=2)
    print(f"[save json] {save_json_path}")

run()


# move primitive

#### move primitive for bucket

In [None]:
import os
import tensorflow as tf
import numpy as np
import json
from tqdm import tqdm

# Expose only CUDA device 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# ==== paths ====
from utils import DATA_DIR

tfrecord_dir = str(DATA_DIR / "tfds_datasets/bucket_dex_art_dataset/1.0.0")

save_dir = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/results_move_primitives/bucket")
os.makedirs(save_dir, exist_ok=True)

# Only the train-split shards are read, in sorted order.
tfrecord_files = sorted(
    os.path.join(tfrecord_dir, shard_name)
    for shard_name in os.listdir(tfrecord_dir)
    if shard_name.startswith("dex_art_dataset-train.tfrecord")
)

def describe_move(move_vec):
    """Turn a per-axis direction vector into a text primitive.

    Args:
        move_vec: three values, each -1, 0 or 1, for the x, y and z axes.

    Returns:
        ``"stop"`` when every axis is 0, otherwise ``"move "`` followed by the
        space-separated direction words of the non-zero axes in x, y, z order.
    """
    axis_words = (
        {-1: "backward", 0: None, 1: "forward"},   # x
        {-1: "right",    0: None, 1: "left"},      # y
        {-1: "down",     0: None, 1: "up"},        # z
    )

    words = []
    for axis, lookup in enumerate(axis_words):
        word = lookup[move_vec[axis]]
        if word is not None:
            words.append(word)

    if not words:
        return "stop"
    return "move " + " ".join(words)


def classify_movement(move, threshold=0.003):
    """Classify the net displacement of a window of positions.

    Args:
        move: array of positions; only the first and last rows are used.
        threshold: per-axis magnitude a component must exceed to count as motion.

    Returns:
        ``(text, move_vec)`` where ``move_vec`` holds -1/0/1 for the first
        three axes and ``text`` is ``describe_move(move_vec)``.
    """
    diff = move[-1] - move[0]

    # Cap the L1 norm of the displacement at 3 * threshold before thresholding.
    l1_norm = np.sum(np.abs(diff[:3]))
    if l1_norm > 3 * threshold:
        diff[:3] *= 3 * threshold / l1_norm

    positive = diff[:3] > threshold
    negative = diff[:3] < -threshold
    move_vec = 1 * positive - 1 * negative
    return describe_move(move_vec), move_vec


def get_move_primitives_from_states(states):
    """Label every state of an episode with a move primitive.

    State ``i`` is classified from the window ``states[i : i + 4]`` (shorter
    near the end of the episode). The last state has no following state, so it
    repeats the label of the one before it.

    Args:
        states: array of positions, one row per step.

    Returns:
        A list with one ``(text, move_vec)`` pair per state. An empty input
        gives ``[]``; a single state gives one ``"stop"`` primitive instead of
        raising ``IndexError``.
    """
    num_states = len(states)
    if num_states == 0:
        return []
    if num_states == 1:
        # No displacement can be measured from a single state.
        return [("stop", np.zeros(3, dtype=int))]

    move_trajs = [states[i : i + 4] for i in range(num_states - 1)]
    primitives = [classify_movement(move) for move in move_trajs]
    primitives.append(primitives[-1])  # pad the final frame
    return primitives

def parse_single_example(example_proto):
    """Parse one serialized episode into its file path and flat float32 state values."""
    episode_features = {
        "episode_metadata/file_path": tf.io.FixedLenFeature([], tf.string),
        # Note: must be parsed as float32; the caller casts to float64 after reading.
        "steps/state": tf.io.VarLenFeature(tf.float32),
    }
    return tf.io.parse_single_example(example_proto, features=episode_features)

# file_path -> list of move texts (one per step), and the per-episode state arrays.
all_move_texts = {}
all_states_list = []

for tfrecord_path in tqdm(tfrecord_files, desc="reading tfrecords"):
    raw_dataset = tf.data.TFRecordDataset(tfrecord_path)

    for raw_record in raw_dataset:
        try:
            example = parse_single_example(raw_record)

            file_path = example["episode_metadata/file_path"].numpy().decode()
            save_key = file_path

            state_seq = tf.sparse.to_dense(example["steps/state"]).numpy().astype(np.float64)

            if len(state_seq) == 0:
                print(f"[skip] {save_key} no state")
                continue

            # NOTE(review): the bucket gripper-position cell reshapes the same feature
            # to 33 columns; confirm which row width is correct for the bucket data.
            states = state_seq.reshape(-1, 32)
            states_for_move = states[:, 28:31]

            move_primitives = get_move_primitives_from_states(states_for_move)
            move_text_list = [move_text for move_text, move_vec in move_primitives]

            all_move_texts[save_key] = move_text_list
            all_states_list.append(states)

            print(f"[done] {save_key}: {states.shape[0]} frames")
        except Exception as e:
            # One bad record is reported and skipped; the rest of the shard continues.
            print(f"[error] {tfrecord_path}: {e}")

# ==== save ====
if not all_states_list:
    # np.concatenate on an empty list fails with an unhelpful message; say what happened.
    raise ValueError(f"no episode states were parsed from {tfrecord_dir}; nothing to save")

all_states_concat = np.concatenate(all_states_list, axis=0)
np.save(os.path.join(save_dir, "all_states.npy"), all_states_concat)

with open(os.path.join(save_dir, "all_moves_list.json"), "w") as f:
    json.dump(all_move_texts, f, indent=2)

print("finished（file_path -> move list）")


### move primitive for other three dexart objects

In [None]:
import os
import tensorflow as tf
import numpy as np
import json
from tqdm import tqdm

# Expose only CUDA device 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# ==== paths ====
from utils import DATA_DIR

# NOTE(review): both paths still point at the bucket dataset and output folder —
# presumably edited per object; confirm before running.
tfrecord_dir = str(DATA_DIR / "tfds_datasets/bucket_dex_art_dataset/1.0.0")
save_dir = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/results_move_primitives/bucket")
os.makedirs(save_dir, exist_ok=True)

# Only the train-split shards are read, in sorted order.
tfrecord_files = sorted(
    os.path.join(tfrecord_dir, shard_name)
    for shard_name in os.listdir(tfrecord_dir)
    if shard_name.startswith("dex_art_dataset-train.tfrecord")
)

def describe_move(move_vec):
    """Turn a per-axis direction vector into a text primitive.

    Args:
        move_vec: three values, each -1, 0 or 1, for the x, y and z axes.

    Returns:
        ``"stop"`` when every axis is 0, otherwise ``"move "`` followed by the
        space-separated direction words of the non-zero axes in x, y, z order.
    """
    axis_words = (
        {-1: "backward", 0: None, 1: "forward"},   # x
        {-1: "right",    0: None, 1: "left"},      # y
        {-1: "down",     0: None, 1: "up"},        # z
    )

    words = []
    for axis, lookup in enumerate(axis_words):
        word = lookup[move_vec[axis]]
        if word is not None:
            words.append(word)

    if not words:
        return "stop"
    return "move " + " ".join(words)


def classify_movement(move, threshold=0.003):
    """Classify the net displacement of a window of positions.

    Args:
        move: array of positions; only the first and last rows are used.
        threshold: per-axis magnitude a component must exceed to count as motion.

    Returns:
        ``(text, move_vec)`` where ``move_vec`` holds -1/0/1 for the first
        three axes and ``text`` is ``describe_move(move_vec)``.
    """
    diff = move[-1] - move[0]

    # Cap the L1 norm of the displacement at 3 * threshold before thresholding.
    l1_norm = np.sum(np.abs(diff[:3]))
    if l1_norm > 3 * threshold:
        diff[:3] *= 3 * threshold / l1_norm

    positive = diff[:3] > threshold
    negative = diff[:3] < -threshold
    move_vec = 1 * positive - 1 * negative
    return describe_move(move_vec), move_vec


def get_move_primitives_from_states(states):
    """Label every state of an episode with a move primitive.

    State ``i`` is classified from the window ``states[i : i + 4]`` (shorter
    near the end of the episode). The last state has no following state, so it
    repeats the label of the one before it.

    Args:
        states: array of positions, one row per step.

    Returns:
        A list with one ``(text, move_vec)`` pair per state. An empty input
        gives ``[]``; a single state gives one ``"stop"`` primitive instead of
        raising ``IndexError``.
    """
    num_states = len(states)
    if num_states == 0:
        return []
    if num_states == 1:
        # No displacement can be measured from a single state.
        return [("stop", np.zeros(3, dtype=int))]

    move_trajs = [states[i : i + 4] for i in range(num_states - 1)]
    primitives = [classify_movement(move) for move in move_trajs]
    primitives.append(primitives[-1])  # pad the final frame
    return primitives

def parse_single_example(example_proto):
    """Parse one serialized episode into its file path and flat float32 state values."""
    episode_features = {
        "episode_metadata/file_path": tf.io.FixedLenFeature([], tf.string),
        # Note: must be parsed as float32; the caller casts to float64 after reading.
        "steps/state": tf.io.VarLenFeature(tf.float32),
    }
    return tf.io.parse_single_example(example_proto, features=episode_features)

# file_path -> list of move texts (one per step), and the per-episode state arrays.
all_move_texts = {}
all_states_list = []

for tfrecord_path in tqdm(tfrecord_files, desc="reading tfrecords"):
    raw_dataset = tf.data.TFRecordDataset(tfrecord_path)

    for raw_record in raw_dataset:
        try:
            example = parse_single_example(raw_record)

            file_path = example["episode_metadata/file_path"].numpy().decode()
            save_key = file_path

            state_seq = tf.sparse.to_dense(example["steps/state"]).numpy().astype(np.float64)

            if len(state_seq) == 0:
                print(f"[skip] {save_key} no state")
                continue

            # Rows of 32 values; the three columns before the last one are the position.
            states = state_seq.reshape(-1, 32)
            states_for_move = states[:, -4:-1]

            move_primitives = get_move_primitives_from_states(states_for_move)
            move_text_list = [move_text for move_text, move_vec in move_primitives]

            all_move_texts[save_key] = move_text_list
            all_states_list.append(states)

            print(f"[done] {save_key}: {states.shape[0]} frames")
        except Exception as e:
            # One bad record is reported and skipped; the rest of the shard continues.
            print(f"[error] {tfrecord_path}: {e}")

if not all_states_list:
    # np.concatenate on an empty list fails with an unhelpful message; say what happened.
    raise ValueError(f"no episode states were parsed from {tfrecord_dir}; nothing to save")

all_states_concat = np.concatenate(all_states_list, axis=0)
np.save(os.path.join(save_dir, "all_states.npy"), all_states_concat)

with open(os.path.join(save_dir, "all_moves_list.json"), "w") as f:
    json.dump(all_move_texts, f, indent=2)

print("finished（file_path -> move list）")


### post process for move primitive

In [None]:
#Continue to process the extracted move primitive (construct the file name key and move to save to the data_middle folder, and then convert key to pure numbers)
import json
import os

from utils import DATA_DIR

# input path
# Input: the file_path -> move list mapping saved as all_moves_list.json.
INPUT_PATH = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/results_move_primitives/bucket/all_moves_list.json")
# Intermediate output: the same mapping, written without indentation.
OUTPUT_PATH = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/results_move_primitives/bucket/raw_primitives.json")
# Directory that receives the final primitives.json.
output_dir = str(DATA_DIR / "planning_datasets/bucket_dex_art_dataset/results_move_primitives/bucket")

with open(INPUT_PATH, "r") as f:
    raw_moves = json.load(f)

# Keys are carried over as-is (only formatted to str); values are untouched.
formatted_moves = {}
for episode_id_str, move_list in raw_moves.items():
    formatted_moves[f"{episode_id_str}"] = move_list

os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
with open(OUTPUT_PATH, "w") as f:
    json.dump(formatted_moves, f)

print(f"Saved rawprimitives.json to {OUTPUT_PATH}")

import json
import re

def convert_keys(input_path, output_path):
    """Re-key a JSON mapping from episode file paths to bare episode ids.

    Keys ending in ``episode_<digits>_combined.npz`` are replaced by the digit
    string; any other key is reported with a warning and dropped. The result
    is written to ``output_path`` with two-space indentation.

    Args:
        input_path: JSON file holding a dict keyed by file path.
        output_path: destination JSON file.
    """
    episode_pattern = re.compile(r"episode_(\d+)_combined\.npz$")

    with open(input_path, "r") as f:
        source = json.load(f)

    converted_data = {}
    for key, value in source.items():
        found = episode_pattern.search(key)
        if found is None:
            print(f"[WARN] Key format not matched: {key}")
            continue
        converted_data[found.group(1)] = value

    with open(output_path, "w") as f:
        json.dump(converted_data, f, indent=2)
    print(f"Saved converted primitives file to {output_path}")




# Re-key raw_primitives.json by bare episode id and write the result as primitives.json.
convert_keys(OUTPUT_PATH, f"{output_dir}/primitives.json")



#### visualize move primitive (optional)

In [None]:
import os
import json
from glob import glob
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

# Object whose primitives are visualised; it selects the input and output folders.
obj = 'bucket'

from utils import DATA_DIR

json_path = str(DATA_DIR / f"planning_datasets/bucket_dex_art_dataset/results_move_primitives/{obj}/primitives.json")
image_root = str(DATA_DIR / f"planning_datasets/bucket_dex_art_dataset/dexart_all_{obj}_png/{obj}_viz")
save_root = str(DATA_DIR / f"planning_datasets/bucket_dex_art_dataset/results_move_primitives/a_check_move_gpos/{obj}")
os.makedirs(save_root, exist_ok=True)

with open(json_path, "r") as f:
    primitives = json.load(f)

# Stamp each frame of every episode with its move text and save a copy under save_root.
for episode_id, moves in tqdm(primitives.items()):
    episode_name = f"episode_{episode_id}_combined"
    input_folder = os.path.join(image_root, episode_name)
    output_folder = os.path.join(save_root, episode_name)
    os.makedirs(output_folder, exist_ok=True)

    png_files = sorted(glob(os.path.join(input_folder, "*.png")))

    # Frames and moves are paired by position, so the counts must agree.
    if len(png_files) != len(moves):
        print(f"[Mismatch] Episode {episode_id}: {len(moves)} moves vs {len(png_files)} PNGs")
        continue

    for png_path, move in zip(png_files, moves):
        image = cv2.imread(png_path)

        if image is None:
            print(f"[Error] Cannot read image: {png_path}")
            continue

        cv2.putText(image, move, (30, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA)

        save_path = os.path.join(output_folder, os.path.basename(png_path))
        cv2.imwrite(save_path, image)

print("finished")


# scene description

 First run `dexart/scense_description/scripts/generate_descriptions_dexart.py` to generate the descriptions for the object.
 
 Then run the following cell to convert them into the final format.

In [None]:
import json
import os
# Object whose description file is re-keyed.
obj = 'bucket'

from utils import DATA_DIR

input_path = str(DATA_DIR / f"planning_datasets/bucket_dex_art_dataset/results_descriptions/{obj}/results_0.json")
output_path = str(DATA_DIR / f"planning_datasets/bucket_dex_art_dataset/results_descriptions/{obj}/descriptions_{obj}_fixed.json")

with open(input_path, "r") as f:
    data = json.load(f)

# Each entry is re-keyed by the third-from-last "_"-separated token of its original key.
# NOTE(review): presumably that token is the episode number — confirm against the key format.
new_data = {key.split("_")[-3]: value for key, value in data.items()}

with open(output_path, "w") as f:
    json.dump(new_data, f, indent=2)

print(f"finished, saved to {output_path}")

# bounding box 

1. Install the environment from [Grounded-SAM-2](https://github.com/IDEA-Research/Grounded-SAM-2).
2. Run `third_party/Grounded-SAM-2/point_object_by_hand.py` and click on the target object to get its anchor point.
3. Set `INPUT_DIR` to the folder containing the PNG frames extracted from the TFRecords.
4. Run the following script to get the raw bounding box data.
```bash
cd third_party/Grounded-SAM-2
conda activate grounded_sam2
python bbox_dexart.py
```
5. Run the following cell to extract the raw bounding box data to JSON.

In [None]:
import os
import json
from tqdm import tqdm

from utils import DATA_DIR

# Folder holding one sub-folder per episode, each with a video_boxes.json.
INPUT_DIR = DATA_DIR / "planning_datasets/bucket_dex_art_dataset/results_bbox/bucket"
INPUT_DIR = str(INPUT_DIR)

OUTPUT_PATH = f"{INPUT_DIR}/bboxes.json"

# Every box is relabelled with the name of the input folder. basename/normpath
# works with any path separator and tolerates a trailing slash.
TARGET_NAME = os.path.basename(os.path.normpath(INPUT_DIR))

# episode id -> list (one entry per frame) of [[TARGET_NAME, bbox]] or []
bboxes_all = {}

for episode_folder in tqdm(sorted(os.listdir(INPUT_DIR)), desc="Processing episodes"):
    episode_path = os.path.join(INPUT_DIR, episode_folder)
    if not os.path.isdir(episode_path):
        # Plain files (e.g. the bboxes.json written by an earlier run) are not episodes.
        continue

    video_boxes_path = os.path.join(episode_path, "video_boxes.json")
    if not os.path.isfile(video_boxes_path):
        print(f"Missing video_boxes.json in {episode_folder}, skipped.")
        continue

    with open(video_boxes_path, "r") as f:
        frame_box_data = json.load(f)

    # Order frames by the number formed from the digits in each key.
    frame_keys = sorted(frame_box_data.keys(), key=lambda x: int(''.join(filter(str.isdigit, x))))
    episode_boxes = []

    for frame_key in frame_keys:
        # Keep well-formed [name, bbox] pairs, replacing the detected name with TARGET_NAME.
        frame_boxes = [
            [TARGET_NAME, detection[1]]
            for detection in frame_box_data.get(frame_key, [])
            if isinstance(detection, list) and len(detection) == 2
        ]
        # Only the first box of each frame is kept (an empty list when there is none).
        episode_boxes.append(frame_boxes[:1])

    episode_id = episode_folder.replace("episode_", "").replace("_combined", "")
    bboxes_all[episode_id] = episode_boxes

os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

with open(OUTPUT_PATH, "w") as f:
    json.dump(bboxes_all, f)

print(f"Saved bboxes.json to {OUTPUT_PATH}")


# CoT

1. Run `dexart/CoT/cot_code/batch_generate_plan_subtasks.sh`.
2. Run `dexart/CoT/cot_code/batch_filter_plan_subtasks.sh`.

### combine

In [None]:
# If this cell fails because gripper_positions.json is missing, rename gpos.json to gripper_positions.json.
import argparse
import json
from tqdm import tqdm
import os

from utils import DATA_DIR

# NOTE(review): INPUT_DIR is not used below; every file is read from --data_dir instead.
INPUT_DIR = DATA_DIR / "planning_datasets/bucket_dex_art_dataset/results_bbox/bucket"


if __name__ == "__main__":
    # Merge bboxes, gripper positions and move primitives into the filtered reasoning file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--libero_task_suite", type=str, default="bucket")
    # NOTE(review): absolute, machine-specific default; point it at the folder that
    # holds bboxes.json, gripper_positions.json, primitives.json and the reasoning file.
    parser.add_argument("--data_dir", type=str, default="/data/lyd/embodied-CoT/data_results/data_cot_results/data_middle/dexart/results_cot")
    # Parse an empty argv so the notebook kernel's own arguments are ignored.
    args = parser.parse_args([])

    suite_dir = os.path.join(args.data_dir, args.libero_task_suite)

    bboxes_file_path = os.path.join(suite_dir, "bboxes.json")
    with open(bboxes_file_path, "r") as f:
        bboxes = json.load(f)

    gripper_positions_file_path = os.path.join(suite_dir, "gripper_positions.json")
    with open(gripper_positions_file_path, "r") as f:
        gripper_positions = json.load(f)

    primitives_file_path = os.path.join(suite_dir, "primitives.json")
    with open(primitives_file_path, "r") as f:
        primitives = json.load(f)

    reasonings_file_path = os.path.join(suite_dir, "filtered_reasoning_h10.json")
    with open(reasonings_file_path, "r") as f:
        reasonings = json.load(f)

    for file_path in tqdm(reasonings.keys(), desc="Merging"):
        if file_path not in bboxes:
            print(f"File path {file_path} not found in bboxes")
            continue
        if file_path not in gripper_positions:
            print(f"File path {file_path} not found in gripper_positions")
            continue
        if file_path not in primitives:
            print(f"File path {file_path} not found in primitives")
            continue
        bbox = bboxes[file_path]
        gripper_position = gripper_positions[file_path]
        primitive = primitives[file_path]

        try:
            reasoning_steps = reasonings[file_path]["0"]["reasoning"]
        except (KeyError, TypeError) as e:
            # The entry lacks the expected "0" -> "reasoning" structure.
            print(e)
            continue

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        lengths = (len(bbox), len(gripper_position), len(primitive), len(reasoning_steps))
        if len(set(lengths)) != 1:
            print(f"Length mismatch for {file_path}: {lengths[0]}, {lengths[1]}, {lengths[2]}, {lengths[3]}")
            continue

        reasonings[file_path]["0"]["features"].update(
            {
                "bboxes": bbox,
                "gripper_position": gripper_position,
                "move_primitive": primitive
            }
        )

    target_dir = os.path.join(suite_dir, "data_merged")
    os.makedirs(target_dir, exist_ok=True)
    print(f"Saving to {target_dir}")
    target_file_path = os.path.join(target_dir, f"reasoning_{args.libero_task_suite}.json")

    with open(target_file_path, "w") as f:
        json.dump(reasonings, f)
