In [1]:
import os
import cv2
import imageio
import torch
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from cliport.utils import utils
from cliport import tasks
from cliport.dataset import RavensDataset
from cliport.environments.environment import Environment

pybullet build time: Nov 28 2023 23:45:17


export PATH=/path/to/specific/jupyter/bin:$PATH

In [2]:
# Experiment settings used by the cells below:
#   n_demos - number of stored demos to export
#   mode    - dataset split
#   task    - task name (a string; later cells read it via vcfg['task'])
n_demos, mode, task = 4, "val", 'put-block-in-bowl-seen-colors'

In [6]:
# Paths: the cliport repo checkout, and the parent directory that holds `data/`.
root_dir = '/home/a/acw799/cliport'
data_dir = '/home/a/acw799'
assets_root = os.path.join(root_dir, 'cliport/environments/assets/')

# Evaluation config, overridden with this notebook's data location, split and task.
config_file = 'eval.yaml'
vcfg = utils.load_hydra_config(os.path.join(root_dir, f'cliport/cfg/{config_file}'))
vcfg['data_dir'] = os.path.join(data_dir, 'data')
vcfg['mode'] = mode
vcfg['task'] = task
# Resolve against root_dir (like eval.yaml above) so loading does not depend on the
# kernel's current working directory.
vcfg['train_config'] = os.path.join(root_dir, 'cliport/cfg/inference.yaml')
tcfg = utils.load_hydra_config(vcfg['train_config'])

# Load the first n_demos stored demos of this task/split, without augmentation.
ds = RavensDataset(os.path.join(vcfg['data_dir'], f'{vcfg["task"]}-{vcfg["mode"]}'),
                   tcfg,
                   n_demos=n_demos,
                   augment=False)

Key 'dataset.in_shape' not found, using default value.


In [7]:
# Display the resolved dataset directory (last expression of the cell).
vcfg["data_dir"]

'/home/a/acw799/data'

In [8]:
# Create the simulation environment from the asset directory.
# disp / shared_memory / record_cfg are all turned off; hz is set to 480.
env = Environment(assets_root,
                  disp=False, shared_memory=False,
                  hz=480, record_cfg=False)

In [9]:
# Output folders for the exported images, derived from root_dir (set in the config
# cell) instead of repeating its absolute path here.
# NOTE(review): images go into an "unseen" folder although the task is
# 'put-block-in-bowl-seen-colors' — confirm this is the intended destination.
color_dir = os.path.join(root_dir, "data", "unseen", "color")
depth_dir = os.path.join(root_dir, "data", "unseen", "depth")

# Make sure both folders exist (no error if they already do).
os.makedirs(color_dir, exist_ok=True)
os.makedirs(depth_dir, exist_ok=True)

In [11]:
# Export the initial RGB and depth input of each loaded demo as PNG files.
task_name = vcfg['task']  # task-name string; constant across episodes, so read once
show_preview = False      # set True to plot the first exported sample and stop

for i in tqdm(range(n_demos), desc="saving color and depth"):
    episode, seed = ds.load(i)

    # Fresh task object per episode. It is named `task_obj` (not `task`) so the global
    # `task` string from the settings cell is not overwritten; otherwise re-running the
    # config cell would put a task object into vcfg['task'] and break the dataset path.
    task_obj = tasks.names[task_name]()
    task_obj.mode = mode

    # Reset the environment with the demo's stored seed.
    env.seed(seed)
    env.set_task(task_obj)
    obs = env.reset()
    info = env.info
    reward = 0

    # Build the model input; channels 0-2 are taken as color and channel 3 as depth.
    batch = ds.process_goal((obs, None, reward, info), perturb_params=None)
    img = np.asarray(batch['img'])  # already a NumPy array; no torch round-trip needed

    # Swap the first two axes of each image before saving.
    color = np.uint8(img[:, :, :3]).transpose(1, 0, 2)
    depth = img[:, :, 3].transpose(1, 0)

    color_filename = os.path.join(color_dir, f"{i:05}.png")
    depth_filename = os.path.join(depth_dir, f"{i:05}.png")

    # Save the color image.
    imageio.imwrite(color_filename, color)

    # Save depth as 8-bit grayscale, scaled by this image's own maximum (so gray levels
    # are not comparable across episodes). Guard the all-zero case, where
    # depth / depth.max() would yield NaNs, and clip so the uint8 cast cannot wrap.
    max_depth = depth.max()
    if max_depth > 0:
        normalized_depth = np.clip(255 * (depth / max_depth), 0, 255).astype(np.uint8)
    else:
        normalized_depth = np.zeros(depth.shape, dtype=np.uint8)
    imageio.imwrite(depth_filename, normalized_depth)

    # Optional visual check of the first exported sample.
    if show_preview:
        fig, axs = plt.subplots(1, 2, figsize=(13, 7))
        for ax, image, title in ((axs[0], color, 'Input RGB sample'),
                                 (axs[1], normalized_depth, 'Input Depth sample')):
            ax.imshow(image)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.set_title(title)
        plt.show()
        break

saving color and depth:: 100%|██████████| 4/4 [00:24<00:00,  6.17s/it]
