In [None]:
import itertools  # 반복문과 조합 관련 유틸리티
import os  # 운영체제 파일 경로 등 처리
import os.path as osp  # 경로(join, split 등) 처리 단축
import time  # 시간 측정 및 대기
from collections import OrderedDict, defaultdict  # 순서 있는 dict, 기본값 딕셔너리
from datetime import datetime  # 날짜/시간 객체
from tqdm import tqdm  # 진행 바 표시

import h5py  # HDF5 파일 입출력

import torch  # PyTorch 텐서 연산
import collections  # 컬렉션 관련 타입
import tensorflow_datasets as tfds  # TFDS 데이터셋 로드
import numpy as np  # 수치 연산
import tkinter as tk  # 간단한 GUI (언어 입력 등)
from tkinter import simpledialog  # 간단한 입력 다이얼로그
from PIL import Image, ImageTk  # 이미지 처리 및 표시
import argparse  # 명령행 인자 파싱

from IPython.display import display, Image as IPyImage

# === Global directory settings ===
BASE_ROOT_DIR = "/home/parkjeongsu/TinyVLA"
# Where show_gif() writes preview GIFs.
GIF_SAVE_DIR = os.path.join(BASE_ROOT_DIR, "droid_image")
# Where plot_smooth_action() writes its PNG figures.
SMOOTH_ACTION_FIG_DIR = os.path.join(BASE_ROOT_DIR, "droid_traj", "smooth_action_results")
# Join a *relative* component: os.path.join discards every earlier argument when a
# later one is absolute, so the previous spelling ignored BASE_ROOT_DIR entirely
# (it only produced the same path by coincidence).
DATASET_DIR = os.path.join(BASE_ROOT_DIR, "Droid")
# Root directory for the converted HDF5 episodes.
OUTPUT_H5_DIR = os.path.join(DATASET_DIR, "droid_with_lang")

def get_image_list_np(img_rgb_dir_path, remove_index_list):
    """
    Load all images in a directory, in sorted filename order, as one numpy array.

    Args:
        img_rgb_dir_path: Directory to read. Every selected entry is opened with PIL
            and converted to RGB, so the directory must contain only image files.
        remove_index_list: Indices into the *sorted* filename list to skip. Any
            iterable of ints works; ``None`` is treated as "skip nothing".

    Returns:
        ``np.array`` of the per-frame ``(H, W, 3)`` arrays, i.e. ``(N, H, W, 3)``
        when all frames share one size. An empty selection gives shape ``(0,)``.
    """
    # Build the skip set once; `idx in some_list` would be a linear scan per frame.
    skip_indices = set() if remove_index_list is None else set(remove_index_list)

    cur_camera_rgb_list = []
    for idx, img_name in enumerate(sorted(os.listdir(img_rgb_dir_path))):
        if idx in skip_indices:
            continue

        img_path = os.path.join(img_rgb_dir_path, img_name)
        # The `with` block closes the file handle PIL keeps open after Image.open();
        # convert() returns a new, fully loaded image, so nothing reads the file later.
        with Image.open(img_path) as img_frame:
            cur_camera_rgb_list.append(np.array(img_frame.convert('RGB')))

    cur_camera_rgb_np = np.array(cur_camera_rgb_list)
    print('+++++++++++++++')
    print(f"img_rgb_dir_path: {img_rgb_dir_path}")
    print(f'cur_camera_rgb_np size: {cur_camera_rgb_np.shape}')

    return cur_camera_rgb_np


def plot_smooth_action(traj_act_xyz_np, fig_name):
    """
    Plot the first three columns of a trajectory and save the figure as a PNG.

    Args:
        traj_act_xyz_np: 2-D array indexed as ``[t, i]``. Columns 0, 1 and 2 are
            plotted against the time index in three side-by-side subplots titled
            "x", "y" and "z".
        fig_name: File stem. The image is written to
            ``SMOOTH_ACTION_FIG_DIR/<fig_name>.png``, and the directory is created
            if it does not exist.

    Returns:
        Path of the saved PNG.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(12, 4))
    for i, axis_label in enumerate(["x", "y", "z"]):
        plt.subplot(1, 3, i + 1)
        plt.plot(range(traj_act_xyz_np.shape[0]), traj_act_xyz_np[:, i], label='cur_action')
        plt.title(axis_label)
        plt.legend()
    # NOTE(review): the caption mentions predicted vs. target actions, but only one
    # series ('cur_action') is plotted — confirm the intended caption.
    plt.suptitle("Differences between predicted and target actions_traj")
    plt.tight_layout()

    os.makedirs(SMOOTH_ACTION_FIG_DIR, exist_ok=True)
    figure_path = os.path.join(SMOOTH_ACTION_FIG_DIR, f"{fig_name}.png")
    plt.savefig(figure_path)
    # close() releases the figure. clf() only cleared it and left it registered with
    # pyplot, so repeated calls kept accumulating open figures.
    plt.close(fig)
    return figure_path


def print_h5_structure(group, indent=0):
    """
    Recursively print the member hierarchy of an HDF5 group.

    Every member first prints a ``name:`` line. A sub-group then prints a ``Group:``
    line and is descended into with two more spaces of indentation. A dataset prints
    its shape and dtype. Any other member is reported as ``Unknown item``.
    """
    pad = " " * indent
    for member_name in group:
        member = group[member_name]
        print(f"{pad}name: {member_name}")
        if isinstance(member, h5py.Group):
            print(f"{pad}Group: {member_name}")
            print_h5_structure(member, indent + 2)
            continue
        if isinstance(member, h5py.Dataset):
            print(f"{pad}Dataset: {member_name} (Shape: {member.shape}, Dtype: {member.dtype})")
            continue
        print(f"{pad}Unknown item: {member_name}")


def print_dict_structure(cur_dict, indent=0):
    """
    Recursively print the layout of a nested Python dict.

    Every key first prints a ``name:`` line. Nested dicts then print a ``Dict:`` line
    and are descended into with two more spaces of indentation. numpy arrays print
    their shape and dtype. Any other value is reported as ``Unknown item``.
    """
    pad = " " * indent
    for key, value in cur_dict.items():
        print(f"{pad}name: {key}")
        if isinstance(value, dict):
            print(f"{pad}Dict: {key}")
            print_dict_structure(value, indent + 2)
            continue
        if isinstance(value, np.ndarray):
            print(f"{pad}Array: {key} (Shape: {value.shape}, Dtype: {value.dtype})")
            continue
        print(f"{pad}Unknown item: {key}")


def to_numpy(x):
    """
    Convert every torch.Tensor inside a (possibly nested) dict/list/tuple to numpy.

    Tensors are detached from the autograd graph and, if they live on a CUDA device,
    copied to the CPU before ``.numpy()``. numpy arrays and ``None`` pass through
    unchanged, as does any other leaf type. Container types are preserved as
    described by ``recursive_dict_list_tuple_apply``.
    """
    def _tensor_to_np(tensor):
        detached = tensor.detach()
        if detached.is_cuda:
            detached = detached.cpu()
        return detached.numpy()

    def _passthrough(value):
        return value

    leaf_handlers = {
        torch.Tensor: _tensor_to_np,
        np.ndarray: _passthrough,
        type(None): _passthrough,
    }
    return recursive_dict_list_tuple_apply(x, leaf_handlers)


def recursive_dict_list_tuple_apply(x, type_func_dict):
    """
    Apply type-specific functions to every leaf of a nested dict/list/tuple.

    Args:
        x: Arbitrary nesting of dicts, lists and tuples.
        type_func_dict: Maps a leaf type to the function applied to leaves of that
            type. The first entry (in dict order) whose type matches via
            ``isinstance`` wins. Container types (list, tuple, dict) may not be keys;
            this is enforced with ``assert``.

    Returns:
        A new structure of the same shape. OrderedDicts stay OrderedDicts, other dict
        (sub)classes become plain dicts, tuples become plain tuples and lists stay
        lists. Leaves with no matching handler are returned unchanged.
    """
    for container_type in (list, tuple, dict):
        assert (container_type not in type_func_dict)

    # OrderedDict is a dict subclass, so this single check covers both.
    if isinstance(x, dict):
        mapped = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {}
        for key, value in x.items():
            mapped[key] = recursive_dict_list_tuple_apply(value, type_func_dict)
        return mapped

    if isinstance(x, (list, tuple)):
        items = [recursive_dict_list_tuple_apply(item, type_func_dict) for item in x]
        if isinstance(x, tuple):
            return tuple(items)
        return items

    for leaf_type, handler in type_func_dict.items():
        if isinstance(x, leaf_type):
            return handler(x)
    return x  # no handler registered for this leaf type


def matrix_to_rotation_6d(matrix):
    """
    Encode rotation matrices in the 6D representation (first two rows, flattened).

    Args:
        matrix: Tensor whose last two dims index matrix rows and columns. It is
            presumably shaped ``(..., 3, 3)``; the reshape below requires the first
            two rows to hold exactly 6 values.

    Returns:
        Tensor of shape ``matrix.shape[:-2] + (6,)``. It is a copy, not a view,
        because of ``clone()``.
    """
    leading_shape = matrix.shape[:-2]
    first_two_rows = matrix[..., :2, :].clone()
    return first_two_rows.reshape(*leading_shape, 6)

def _axis_angle_rotation(axis, angle):
    """
    Build rotation matrices about a single coordinate axis.

    Args:
        axis: One of "X", "Y" or "Z".
        angle: Tensor of angles in radians (passed to torch.cos / torch.sin), any shape.

    Returns:
        Tensor of shape ``angle.shape + (3, 3)`` holding the right-handed rotation
        about ``axis`` for each angle.

    Raises:
        ValueError: If ``axis`` is not "X", "Y" or "Z".
    """
    c = torch.cos(angle)
    s = torch.sin(angle)
    o = torch.ones_like(angle)
    z = torch.zeros_like(angle)

    if axis not in ("X", "Y", "Z"):
        raise ValueError("Axis must be X, Y, or Z")

    # Row-major entries of each 3x3 matrix.
    row_major = {
        "X": (o, z, z, z, c, -s, z, s, c),
        "Y": (c, z, s, z, o, z, -s, z, c),
        "Z": (c, -s, z, s, c, z, z, z, o),
    }[axis]
    return torch.stack(row_major, -1).reshape(angle.shape + (3, 3))

def euler_angles_to_matrix(euler_angles, convention):
    """
    Convert Euler angles to rotation matrices by composing three single-axis rotations.

    Args:
        euler_angles: Tensor whose last dim has size 3. Element ``i`` is the angle
            about axis ``convention[i]``, in radians (it is fed to torch.cos/sin).
        convention: Three letters from "X"/"Y"/"Z", e.g. "XYZ".

    Returns:
        Tensor of shape ``euler_angles.shape[:-1] + (3, 3)`` equal to
        ``R(convention[0]) @ R(convention[1]) @ R(convention[2])``.

    Raises:
        ValueError: If ``convention`` is not exactly three letters, the last dim of
            ``euler_angles`` is not 3, or a letter is not X/Y/Z.
    """
    # Without these checks zip() silently truncated a longer angle vector, and a
    # short convention failed later with an opaque IndexError.
    if len(convention) != 3:
        raise ValueError(f"Convention must have 3 letters, got {convention!r}.")
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError(
            f"Invalid input euler angles shape {tuple(euler_angles.shape)}; last dim must be 3."
        )
    matrices = [
        _axis_angle_rotation(axis, angle)
        for axis, angle in zip(convention, torch.unbind(euler_angles, -1))
    ]
    return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])

def euler_angles_to_rot_6d(euler_angles, convention="XYZ"):
    """
    Convert Euler angles straight to the 6D rotation representation.

    Equivalent to ``matrix_to_rotation_6d(euler_angles_to_matrix(euler_angles, convention))``.
    With the default "XYZ" the matrix is composed as ``Rx @ Ry @ Rz``.

    NOTE(review): confirm this matches the convention the source angles were
    produced with (e.g. extrinsic "xyz" would compose as ``Rz @ Ry @ Rx``).
    """
    rotation = euler_angles_to_matrix(euler_angles, convention)
    return matrix_to_rotation_6d(rotation)

def show_gif(images, save_dir=GIF_SAVE_DIR, fps=15, filename='result.gif'):
    """
    Save a sequence of PIL images as a looping GIF and display it inline in IPython.

    Args:
        images: Non-empty sequence of PIL.Image frames. The first frame is the GIF
            base and the rest are appended.
        save_dir: Output directory, created if missing.
        fps: Playback rate. Each frame lasts ``int(1000 / fps)`` ms (the default 15
            gives 66 ms, the previous hard-coded value).
        filename: Name of the GIF inside ``save_dir``.

    Returns:
        Path of the written GIF.

    Raises:
        ValueError: If ``images`` is empty or ``fps`` is not positive.
    """
    if len(images) == 0:
        raise ValueError("show_gif() needs at least one frame")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, filename)
    images[0].save(path, save_all=True, append_images=images[1:], duration=int(1000 / fps), loop=0)
    display(IPyImage(filename=path))
    return path




def convert_h5py2np_dict(group, state_np_dict, indent=0):
    """
    Recursively copy an HDF5 group into nested dicts of numpy arrays, in place.

    A ``name:`` line is printed for every member.
    - Sub-groups become nested dicts inside ``state_np_dict`` and are filled
      recursively with two more spaces of indentation.
    - Datasets are read fully into memory with ``item[...]`` and their shape and
      dtype are printed.
    - Any other member object is stored as is.

    Returns None; results are written into ``state_np_dict``.
    """
    pad = " " * indent
    for key in group:
        node = group[key]
        print(f"{pad}name: {key}")
        if isinstance(node, h5py.Group):
            child = {}
            state_np_dict[key] = child
            convert_h5py2np_dict(node, child, indent + 2)
        elif isinstance(node, h5py.Dataset):
            state_np_dict[key] = node[...]
            print(f"{pad}Dataset: {key} (Shape: {node.shape}, Dtype: {node.dtype})")
        else:
            state_np_dict[key] = node


def print_name(name):
    """Print ``name`` (converted with ``str``) on its own line."""
    print(name, end="\n")

def generate_h5(obs_replay, action_replay, cfg, total_traj_cnt, act_root_dir_path, edit_flag):
    """
    Write one episode to ``<act_root_dir_path>/episode_<total_traj_cnt>.hdf5``.

    File layout (T = len(obs_replay['qpos'])):
        attrs['sim']                      True
        /observations/images/<cam>        (T, cam_height, cam_width, 3) uint8, one chunk
                                          per frame, for each cfg['camera_names'] entry
                                          plus 'wrist'
        /observations/qpos, /qvel         (T, cfg['qpos_dim']), h5py default float dtype
        /action                           (T, cfg['action_dim']), h5py default float dtype
        /is_edited                        (1,) holding ``edit_flag``
        /language_raw                     (1,) UTF-8 bytes of cfg['lang_intrs']

    Args:
        obs_replay: Dict with 'qpos', 'qvel' and 'images' (a dict keyed by camera
            name, which must include 'wrist').
        action_replay: Array written to /action.
        cfg: Dict with camera_names, cam_height, cam_width, qpos_dim, action_dim
            and lang_intrs.
        total_traj_cnt: Episode index used in the file name.
        act_root_dir_path: Output directory (must already exist).
        edit_flag: Value stored in /is_edited.

    Raises:
        ValueError: If the episode has no timesteps (h5py cannot chunk a
            zero-length dataset with a 1-frame chunk).
    """
    data_dict = {
        '/observations/qpos': obs_replay['qpos'],
        '/observations/qvel': obs_replay['qvel'],
        '/action': action_replay,
        'is_edited': np.array(edit_flag),
    }
    for cam_name in cfg['camera_names']:
        data_dict[f'/observations/images/{cam_name}'] = obs_replay['images'][cam_name]
    data_dict['/observations/images/wrist'] = obs_replay['images']['wrist']

    max_timesteps = len(data_dict['/observations/qpos'])
    if max_timesteps == 0:
        raise ValueError(f"episode_{total_traj_cnt} has no timesteps; nothing to write")

    frame_shape = (cfg['cam_height'], cfg['cam_width'], 3)
    dataset_path = os.path.join(act_root_dir_path, f'episode_{total_traj_cnt}')
    with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024**2*2) as root:
        root.attrs['sim'] = True
        obs = root.create_group('observations')
        image = obs.create_group('images')
        for cam_name in list(cfg['camera_names']) + ['wrist']:
            image.create_dataset(cam_name, (max_timesteps,) + frame_shape,
                                 dtype='uint8', chunks=(1,) + frame_shape)
        obs.create_dataset('qpos', (max_timesteps, cfg['qpos_dim']))
        obs.create_dataset('qvel', (max_timesteps, cfg['qpos_dim']))
        root.create_dataset('action', (max_timesteps, cfg['action_dim']))
        root.create_dataset('is_edited', (1,))
        # np.string_ was removed in NumPy 2.0; np.bytes_ is the same type on 1.x.
        # Encode explicitly as UTF-8 so a hand-typed non-ASCII instruction does not
        # fail the implicit ASCII encode.
        lang_bytes = np.bytes_(cfg['lang_intrs'].encode('utf-8'))
        root.create_dataset("language_raw", data=[lang_bytes])
        for name, array in data_dict.items():
            root[name][...] = array

# === Settings ===
# Source TFDS data directory and dataset name to convert.
args_src_root = DATASET_DIR
dataset_name = "droid_100"
# Export configuration consumed by generate_h5. 'lang_intrs' is only a default;
# the conversion loop overwrites it for every episode.
cfg = dict(
    task_name="droid_1dot7t_lang",
    camera_names=["left", "right"],
    cam_height=180,
    cam_width=320,
    state_dim=7,
    qpos_dim=10,
    action_dim=10,
    lang_intrs="close the lid of the box",
)
# Output layout: OUTPUT_H5_DIR/<task_name>_succ_t0001_s-0-0/episode_<i>.hdf5
act_target_root = OUTPUT_H5_DIR
act_root_dir_name = f'{cfg["task_name"]}_succ_t0001_s-0-0'
act_root_dir_path = os.path.join(act_target_root, act_root_dir_name)
os.makedirs(act_target_root, exist_ok=True)
os.makedirs(act_root_dir_path, exist_ok=True)

# === Load dataset ===
ds = tfds.load(dataset_name, data_dir=args_src_root, split="train")
# Running episode index; it also names the output files.
total_traj_cnt = 0

# === Conversion loop ===
# Episode i of the TFDS split is written to episode_<i>.hdf5. total_traj_cnt advances
# for every episode (converted, already on disk, or skipped as empty), so file indices
# stay aligned with dataset order and re-running the cell resumes where it stopped.
for episode in tqdm(ds):
    save_path = os.path.join(act_root_dir_path, f'episode_{total_traj_cnt}.hdf5')
    if os.path.exists(save_path):
        total_traj_cnt += 1
        continue

    cur_actions_dict = {}
    cur_obs_image = {'1': [], '2': []}  # frames from exterior_image_1_left / exterior_image_2_left
    cur_obs_wrist_image = []
    cur_obs_gripper_pos = []
    cur_obs_joint_state = []
    cur_obs_ee_pos = []
    cur_actions = []
    edit_flag = 0
    # Reset per episode so an instruction from a previous episode can never be reused.
    raw_lang = ""

    for idx, step in enumerate(episode['steps']):
        if idx == 0:
            cur_actions_dict = {k: [] for k in step['action_dict'].keys()}

        # Use the first instruction variant with at least 4 characters. If none
        # qualifies, the episode is flagged for manual labelling below.
        l1 = step['language_instruction'].numpy().decode('utf-8')
        l2 = step['language_instruction_2'].numpy().decode('utf-8')
        l3 = step['language_instruction_3'].numpy().decode('utf-8')
        raw_lang = l1 if len(l1) >= 4 else (l2 if len(l2) >= 4 else l3)
        if len(raw_lang) < 4:
            edit_flag = 1

        cur_actions.append(step['action'].numpy()[:-1])
        cur_obs_image['1'].append(step['observation']['exterior_image_1_left'].numpy())
        cur_obs_image['2'].append(step['observation']['exterior_image_2_left'].numpy())
        cur_obs_wrist_image.append(step['observation']['wrist_image_left'].numpy())
        cur_obs_gripper_pos.append(step['observation']['gripper_position'].numpy())
        cur_obs_joint_state.append(step['observation']['joint_position'].numpy())
        # NOTE(review): only the first two cartesian components are kept. Together with
        # the joint and gripper entries this must total cfg['qpos_dim']; confirm that
        # [:2] (rather than [:3]) is intended.
        cur_obs_ee_pos.append(step['observation']['cartesian_position'].numpy()[:2])

        for k in cur_actions_dict:
            cur_actions_dict[k].append(step['action_dict'][k].numpy())

    if not cur_obs_wrist_image:
        # An episode with no steps has nothing to stack; the slicing below would fail.
        # Skip it, but keep the index aligned with the dataset order.
        print(f"[WARN] episode_{total_traj_cnt} has no steps; skipped")
        total_traj_cnt += 1
        continue

    if edit_flag:
        left_imgs = np.array(cur_obs_image['1'])
        right_imgs = np.array(cur_obs_image['2'])
        wrist_imgs = np.array(cur_obs_wrist_image)

        # Trim all three camera streams to the shortest length.
        min_len = min(len(left_imgs), len(right_imgs), len(wrist_imgs))
        left_imgs = left_imgs[:min_len]
        right_imgs = right_imgs[:min_len]
        wrist_imgs = wrist_imgs[:min_len]

        # Check dtype and shape before concatenating the images.
        print(f"[DEBUG] left: {left_imgs.shape}, right: {right_imgs.shape}, wrist: {wrist_imgs.shape}")
        print(f"[DEBUG] dtype check: {left_imgs.dtype}, {right_imgs.dtype}, {wrist_imgs.dtype}")

        # Best-effort preview: place the three views side by side (axis=2 is width).
        # A failure here must not stop the conversion.
        try:
            all_images_np = np.concatenate((left_imgs, right_imgs, wrist_imgs), axis=2)  # (T, H, W*3, 3)
            all_images = [Image.fromarray(each.astype(np.uint8)) for each in all_images_np]
            show_gif(all_images)
        except Exception as e:
            print("❌ GIF 생성 실패:", e)

        # Ask until a non-blank instruction is entered. input() raises EOFError when
        # stdin is closed, so this cannot spin forever.
        raw_lang = ""
        while not raw_lang:
            raw_lang = input("please write a language instruction:").strip()

    traj_len = min(len(cur_obs_image['1']), len(cur_obs_image['2']), len(cur_obs_wrist_image))

    # Action = [position (3), 6D rotation from the Euler angles in [3:6] (6), gripper].
    in_action = np.array(cur_actions_dict['cartesian_position'])
    in_pos = in_action[:, :3]
    in_rot = in_action[:, 3:6]
    rot_6d = euler_angles_to_rot_6d(torch.from_numpy(in_rot)).numpy()
    # Assumes each gripper_position entry is 1-D (so this is (T, 1)); TODO confirm.
    gripper = np.array(cur_actions_dict['gripper_position'])
    traj_actions = np.concatenate((in_pos, rot_6d, gripper), axis=-1)[:traj_len]

    traj_qpos = np.concatenate((np.array(cur_obs_joint_state),
                                np.array(cur_obs_gripper_pos),
                                np.array(cur_obs_ee_pos)), axis=-1)[:traj_len]
    traj_qvel = np.zeros_like(traj_qpos)  # velocities are not exported; zero placeholder

    obs_replay = {
        'qpos': traj_qpos,
        'qvel': traj_qvel,
        'images': {
            'left': np.array(cur_obs_image['1'])[:traj_len],
            'right': np.array(cur_obs_image['2'])[:traj_len],
            'wrist': np.array(cur_obs_wrist_image)[:traj_len]
        }
    }
    cfg['lang_intrs'] = raw_lang
    generate_h5(obs_replay, traj_actions, cfg, total_traj_cnt, act_root_dir_path, edit_flag)
    total_traj_cnt += 1

2025-08-08 16:10:22.334079: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-08 16:10:22.358281: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-08-08 16:10:22.358308: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-08-08 16:10:22.359169: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-08-08 16:10:22.363602: I tensorflow/core/platform/cpu_feature_guar