In [1]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from PIL import Image
from tqdm import tqdm
from matplotlib import font_manager as fm

font_path = "/home2/qrchen/GillSans.ttc"
font_prop = fm.FontProperties(fname=font_path)

2025-07-23 14:43:42.197805: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753253022.211983 1657444 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753253022.216230 1657444 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753253022.228332 1657444 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753253022.228344 1657444 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753253022.228346 1657444 computation_placer.cc:177] computation placer alr

## 倒入pakage 并且define 函数

In [3]:
import numpy as np
import tensorflow_datasets as tfds
from PIL import Image
from IPython import display
import os
import matplotlib.pyplot as plt
import pandas as pd

df = pd.read_csv('/home2/qrchen/embodied-datasets/metadata.csv')

TRAIN_SPLIT = 'train[:10]'  
FRAME_SKIP_DEFAULT = 5
FRAME_SKIP_LARGE = 10  # 当帧数>100时使用
LARGE_FRAME_THRESHOLD = 100
THIRD_PERSON_FRAME_THRESHOLD = 1000
RESIZE_FACTOR = 1
MAX_EPISODES = 10

DATASET_CONFIG = {
    "dataset_name_in_csv": "RoboSet",
    "language_field": "language_instruction",
    "observation_field": "observation",
    "image_fields": ["image_left","image_right","image_top","image_wrist"],
    "depth_fields": []
}

# 从df 中找到 DATASET_CONFIG 中 dataset_name_in_csv 对应的 行，得到其中 nickname 的值
dataset_name_in_csv = DATASET_CONFIG["dataset_name_in_csv"]
dataset = df[df['Datasets'] == dataset_name_in_csv]
dataset = dataset['NickName'].item()
OUTPUT_DIR = f"{dataset}/sample"


def dataset2path(dataset_name):
  if dataset_name == 'robo_net':
    version = '1.0.0'
  elif dataset_name == 'language_table' or dataset_name == 'robo_set':
    version = '0.0.1'
  else:
    version = '0.1.0'
  return f'gs://gresearch/robotics/{dataset_name}/{version}'

# ============ 工具函数 ============
def depth_to_color_img(depth):
    """将depth转为彩色图像"""
    d = depth.copy()
    d = (d - np.nanmin(d)) / (np.nanmax(d) - np.nanmin(d) + 1e-8)
    cm = plt.get_cmap('jet')
    colored = cm(d)[:, :, :3]  # 只要RGB，不要alpha
    colored = (colored * 255).astype(np.uint8)
    return colored

def as_gif(images, path="temp.gif", resize_factor=0.5):
    """生成GIF文件"""
    if resize_factor != 1.0:
        resized_images = []
        for img in images:
            width, height = img.size
            new_size = (int(width * resize_factor), int(height * resize_factor))
            resized_images.append(img.resize(new_size, Image.Resampling.LANCZOS))
        images = resized_images
    
    images[0].save(path, save_all=True, append_images=images[1:], duration=int(1000/15), loop=0)
    gif_bytes = open(path,"rb").read()
    return gif_bytes

def get_language_instruction(episode, config):
    """提取语言指令"""
    for step in episode["steps"]:
        lang_inst = step[config["language_field"]].numpy()
        if isinstance(lang_inst, bytes):
            lang_inst = lang_inst.decode('utf-8')
        return lang_inst
    return ""

def process_single_step(obs, config):
    """处理单帧数据，返回拼接的图像和深度图像"""
    # 提取RGB图像
    rgb_images = []
    for field in config["image_fields"]:
        # 处理嵌套图像路径
        if isinstance(field, tuple):
            field_name, sub_field_name = field
            rgb_images.append(obs[field_name][sub_field_name].numpy())
        else:
            rgb_images.append(obs[field].numpy())
    concat_rgb = np.concatenate(rgb_images, axis=1)
    
    # 提取并处理深度图像
    depth_images = []
    if len(config["depth_fields"]) > 0:
      for field in config["depth_fields"]:
        depth_img = obs[field].numpy()
        color_depth = depth_to_color_img(depth_img)
        depth_images.append(color_depth)
    else:
      depth_images.append(np.zeros_like(concat_rgb))
    concat_depth = np.concatenate(depth_images, axis=1) 
    
    return Image.fromarray(concat_rgb), Image.fromarray(concat_depth)

def process_episode(episode, episode_idx, config):
    """处理单个episode"""
    # 获取语言指令
    lang_inst = get_language_instruction(episode, config)
    if lang_inst == "":
        return None
    
    print(f"Language Instruction: {lang_inst}")
    
    # 动态调整帧抽取
    total_frames = len(list(episode['steps']))
    frame_skip = FRAME_SKIP_LARGE if total_frames > LARGE_FRAME_THRESHOLD else FRAME_SKIP_DEFAULT
    print(f"当前episode总帧数: {total_frames}, 抽取帧数: {frame_skip}")
    
    # 收集图像
    rgb_images = []
    depth_images = []
    
    for step_idx, step in enumerate(episode["steps"]):
        if step_idx % frame_skip == 0:
            obs = step[config["observation_field"]]
            rgb_img, depth_img = process_single_step(obs, config)
            rgb_images.append(rgb_img)
            # depth_images.append(depth_img)
    
    # 生成文件名
    safe_filename = lang_inst.replace(" ", "_").replace(".", "")
    rgb_path = f"{OUTPUT_DIR}/{safe_filename}_{episode_idx}_image.gif"
    depth_path = f"{OUTPUT_DIR}/{safe_filename}_{episode_idx}_depth.gif"
    
    # 保存GIF
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    display.Image(as_gif(rgb_images, rgb_path, RESIZE_FACTOR))
    # display.Image(as_gif(depth_images, depth_path, RESIZE_FACTOR))
    print(f"已生成GIF文件: {rgb_path}, {depth_path}")
    
    return {
        "total_frames": total_frames,
        "processed_frames": len(rgb_images),
        "rgb_path": rgb_path,
        "depth_path": depth_path
    }

In [4]:
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from collections import defaultdict, Counter
import pprint

nlp = spacy.load("en_core_web_sm")

## Read and build data from dataset 

In [6]:

b = tfds.builder_from_directory(builder_dir=dataset2path(dataset))
ds = b.as_dataset(split=TRAIN_SPLIT)  # 用100的缓冲区用于打乱，取其中的10条

2025-07-23 14:53:07.552253: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753253587.562552 1664277 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753253587.566926 1664277 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753253587.579807 1664277 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753253587.579820 1664277 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753253587.579822 1664277 computation_placer.cc:177] computation placer alr

## 生产GIF 图片，旧版本，先不要使用

In [5]:
# count = 0
# print("开始读取数据...")
# for episode_idx, episode in enumerate(ds):  # type: ignore
#     print(f"\n=== Episode {episode_idx + 1} ===")
    
#     result = process_episode(episode, episode_idx, DATASET_CONFIG)
#     if result is None:
#         continue 
#     count += 1
#     if count >= 2:
#         break

## 统计数据

In [None]:

import spacy
from collections import Counter
from tqdm import tqdm
import numpy as np
import tensorflow_datasets as tfds

# 1. 统计动词和动词短语
frame_num = []
task_list = []
def extract_verbs_and_phrases_from_tasks(task_list):
    """
    从英语任务列表中提取动词和动词短语，并统计它们的出现次数。
    """
    nlp = spacy.load("en_core_web_sm")
    verbs_and_phrases = []
    for task in task_list:
        doc = nlp(task)
        for token in doc:
            if token.pos_ == "VERB":
                verbs_and_phrases.append(token.text.lower())
                phrase = [token.text.lower()]
                for child in token.children:
                    if child.pos_ in ["PART", "ADV"]:
                        phrase.append(child.text.lower())
                if len(phrase) > 1:
                    verbs_and_phrases.append(" ".join(phrase))
    counts = Counter(verbs_and_phrases)
    return dict(counts)

# 2. 统计帧数和收集instruction

def get_instructions_and_frame_stats(ds):
    task_list = []
    frame_num = []
    for episode_idx, episode in tqdm(enumerate(ds)):
        step_0 = list(episode['steps'])[0]
        if 'natural_language_instruction' in step_0['observation']:
            task = step_0['observation']['natural_language_instruction'].numpy().decode('utf-8')
        elif 'language_instruction' in step_0:
            task = step_0['language_instruction'].numpy().decode('utf-8')
        else:
            instruction_bytes = step_0["observation"]["instruction"]
            instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
            task = tf.strings.split(instruction_encoded, "\x00")[0].numpy().decode('utf-8')
        task_list.append(task)
        frame_num.append(len(list(episode['steps'])))
    return task_list, frame_num

task_list, frame_num = get_instructions_and_frame_stats(ds)

# 3. 统计帧数均值和标准差
# 每一个episode统计帧数均值和标准差
mean_frames = np.mean(frame_num) if frame_num else 0
std_frames = np.std(frame_num) if frame_num else 0
print("--------------------------------")
print(f"Mean frames per episode: {mean_frames:.2f}")
print(f"Standard deviation of frames: {std_frames:.2f}")
print(f"总episode数: {len(frame_num)}，总instruction数: {len(task_list)}，去重后instruction数: {len(set(task_list))}")
# 统计动词和动词短语
# verb_stats = extract_verbs_and_phrases_from_tasks(task_list)
# print("动词和动词短语统计:")
# for k, v in sorted(verb_stats.items(), key=lambda x: -x[1]):
#     print(f"{k}: {v}")
# print(f"总动词和动词短语数: {len(verb_stats)}")
print("--------------------------------")
print(f"所有instruction的数量:{len(task_list)}")
print(f"所有的instruction: {task_list}")
print(f"所有的instruction去重后数: {len(set(task_list))}")
print(f"所有的instruction去重后:")
for i in set(task_list):
    print(i)
print(Counter(task_list))

0it [00:00, ?it/s]2025-07-23 14:53:16.587894: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:387] The default buffer size is 262144, which is overridden by the user specified `buffer_size` of 8388608
2025-07-23 14:53:25.717272: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-07-23 14:53:25.764380: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
1it [00:09,  9.19s/it]2025-07-23 14:53:25.834676: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
3it [00:09,  2.44s/it]2025-07-23 14:53:25.979203: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
6it [00:13,  2.01s/it]2025-07-23 14:53:30.402633: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with

--------------------------------
Mean frames per episode: 94.00
Standard deviation of frames: 18.00
总episode数: 10，总instruction数: 10，去重后instruction数: 7
--------------------------------
所有instruction的数量:10
所有的instruction: ['Drag mug right to left.', 'Pick up the butter.', 'Drag mug backwards.', 'Drag mug forwards.', 'Flap open oven.', 'Drag strainer right to left.', 'Drag strainer forwards.', 'Drag mug backwards.', 'Drag mug backwards.', 'Drag strainer forwards.']
所有的instruction去重后数: 7
所有的instruction去重后:
Drag strainer forwards.
Drag mug forwards.
Flap open oven.
Drag strainer right to left.
Pick up the butter.
Drag mug backwards.
Drag mug right to left.
Counter({'Drag mug backwards.': 3, 'Drag strainer forwards.': 2, 'Drag mug right to left.': 1, 'Pick up the butter.': 1, 'Drag mug forwards.': 1, 'Flap open oven.': 1, 'Drag strainer right to left.': 1})





## 按照MAX_TASKS 个task 生成 gif 图片，并且保存在对应的samples/路径下

In [None]:
from collections import defaultdict
from PIL import Image
from PIL.ImageQt import rgb 
save_root = f"/home2/qrchen/embodied-datasets/Trajectories/{dataset}"
os.makedirs(os.path.join(save_root, "samples"), exist_ok=True)


target_fps = 1
fps=10
frame_skip = 10

# 每个task 最多的 trajectory 数
MAX_TRAJECTORIES_PER_TASK = 2
# 最多生成多少个task，如果一共有100 个task，那么我们只会输出 MAX_TASKS 个task
MAX_TASKS = 5


# 1. 统计每个任务的轨迹数
file_name_list = [] # 每个task 的 trajectory 数
task_counters = defaultdict(list) # 每个task 的 trajectory 数

# 2. 遍历每个轨迹，统计每个任务的轨迹数
for episode_idx, episode in tqdm(enumerate(ds)):

    # 2.1 如果任务数超过最大任务数，则跳过
    if len(task_counters.keys()) > MAX_TASKS:
        break

    task = ''
    step_0 = list(episode['steps'])[0]# 获取轨迹的第一个step
    # 2.2 获取任务 遍历 嵌套字典
    if 'natural_language_instruction' in step_0['observation']: 
        task = step_0['observation']['natural_language_instruction'].numpy().decode('utf-8')
    elif 'language_instruction' in step_0:
        task = step_0['language_instruction'].numpy().decode('utf-8')
    else:
        instruction_bytes = step_0["observation"]["instruction"]
        instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
        task = tf.strings.split(instruction_encoded, "\x00")[0].numpy().decode('utf-8')

    # 2.3 如果任务为空，则跳过  
    if not len(task):
        continue


    # 2.4 如果任务数超过最大任务数，则跳过
    if task in task_counters and len(task_counters[task]) >= MAX_TRAJECTORIES_PER_TASK:
        continue
    

    rgb_images = []
    for step_index ,step in enumerate(episode['steps']):
        if step_index % frame_skip == 0:
            obs = step[DATASET_CONFIG["observation_field"]]
            rgb_img, depth_img = process_single_step(obs, DATASET_CONFIG) #这里掉用了 process_single_step 函数，这个函数在 debug-xintong.ipynb 中
            rgb_images.append(rgb_img)

    
    if task not in task_counters or len(task_counters[task]) < MAX_TRAJECTORIES_PER_TASK:
        if rgb_images:
            task_filename = task.replace(" ", "_").replace(".", "")
            
            current_count = len(task_counters[task])
            gif_path= os.path.join(save_root, "samples", task_filename, f"{current_count}.gif")
            os.makedirs(os.path.dirname(gif_path), exist_ok=True)
            task_counters[task].append(f"{current_count}.gif")
            display.Image(as_gif(rgb_images, gif_path, RESIZE_FACTOR))


task_counters

5it [00:13,  1.72s/it]2025-07-23 14:37:49.658274: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
8it [00:16,  2.03s/it]


defaultdict(list,
            {'Drag strainer forwards.': ['0.gif', '1.gif'],
             'Drag strainer right to left.': ['0.gif'],
             'Drag mug forwards.': ['0.gif'],
             'Flap open oven.': ['0.gif'],
             'Drag mug backwards.': ['0.gif', '1.gif'],
             'Drag mug right to left.': ['0.gif']})

## Modify javascript

In [None]:
from jinja2 import Template


with open('/home2/qrchen/embodied-datasets/Templates/script.js', 'r', encoding='utf-8') as f:
    js_template = Template(f.read())

filled_js = js_template.render(
    gif_paths=pprint.pformat(dict(task_counters))
)

with open(os.path.join(save_root, 'script.js'), 'w', encoding='utf-8') as f:
    f.write(filled_js)

## Modify html

In [None]:
import pandas as pd
df = pd.read_csv('/home2/qrchen/embodied-datasets/metadata.csv')

with open("/home2/qrchen/embodied-datasets/Templates/index.html", "r", encoding="utf-8") as f:
    html_content = f.read()

dataset_name_in_csv = DATASET_CONFIG["dataset_name_in_csv"]
metadata = df[df['Datasets'] == dataset_name_in_csv] #得到这一行的数据



episodes = metadata['#Trajectories'].item().replace('\n', '<br>')
contents = f"""
<h3>
Basic Information
</h3>
<p>
    <span class="highlight-label">#Tasks:</span> {metadata['#Tasks'].item()}
    <br>
    <span class="highlight-label">#Scenes:</span> {metadata['#Scenes'].item()}
    <br>
    <span class="highlight-label">#Episodes:</span> {episodes}
    <br>
    <span class="highlight-label">Avg Frames per episode:</span> {metadata['Avg. frames/ trajectory'].item()}
    <br>
    <span class="highlight-label">Instruction:</span> {metadata['Language instructions'].item()}
</p>
<br>
<br>
<h3>
How to modify?
</h3>
<p>
{metadata['How to modify it to align with our overall objectives?'].item()}
</p>
"""

html_content = html_content.replace("==TITLE==", metadata["Datasets"].item())
html_content = html_content.replace("==Contents==", contents)

# 保存到新路径
with open(os.path.join(save_root, 'index.html'), 'w', encoding='utf-8') as f:
    f.write(html_content)