In [1]:
import zarr 
# !pip install zarr[jupyter]
# !pip install plotly

In [5]:
# Zarr basics: build a small in-memory hierarchy, save it to a zip file, reload it.

import numpy as np

# create
mem_store = zarr.MemoryStore()
root = zarr.group(store=mem_store)
data = root.require_group("data", overwrite=False)
cam_data = data.require_dataset(name="cam", shape=(50, 224, 224, 3), chunks=(1, 224, 224, 3), compressor=None, dtype=np.uint8)
eff_pos = data.require_dataset(name="eff_pos", shape=(50, 3), chunks=(25, 3), compressor=None, dtype=np.float32)

# modify
cam_data[0, 0, 0, 0] = 5
eff_pos[0, 0] = 1.0
print(root.tree())

# save
output = "tmp.zip"
# Copy the whole store rather than each array under data/ one by one: the
# per-array copy skips the ".zgroup" metadata of "/" and "data", so "data" is
# not recognised as a group when the zip is reopened.
with zarr.ZipStore(output, mode="w") as zip_store:
    zarr.copy_store(source=root.store, dest=zip_store, if_exists="replace")

# load -- everything must be read inside the with-block; the store closes on exit.
with zarr.ZipStore(output, mode="r") as zip_store:
    root_x = zarr.open_group(zip_store, mode="r")
    print(root_x.tree())
    # Round-trip sanity check on a value written above.
    assert root_x["data/eff_pos"][0, 0] == 1.0

Tree(nodes=(Node(disabled=True, name='/', nodes=(Node(disabled=True, name='data', nodes=(Node(disabled=True, i…

In [2]:
import os
import pathlib
import sys

import zarr

home_dir = str(pathlib.Path.home())
# The notebook is assumed to sit three directories below the repo root (hence
# .parent x3). os.chdir() below moves the kernel into the repo, so on a re-run
# the cwd already IS the repo root; climbing three more levels would then point
# at the wrong directory. Presumably the repo root contains the
# `diffusion_policy` package (it is imported from there below) -- confirm.
cwd = pathlib.Path.cwd()
repo_path = cwd if (cwd / "diffusion_policy").is_dir() else cwd.parent.parent.parent
repo_dir = str(repo_path)
# Avoid piling up duplicate sys.path entries on every re-run.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)
os.chdir(repo_dir)
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.codecs.imagecodecs_numcodecs import register_codecs, JpegXl

# Register the image codecs before opening the dataset.
register_codecs()
img_compressor = JpegXl(level=99, numthreads=1)

dataset_path = os.path.join(home_dir, "example_demo_session/dataset.zarr.zip")
print(dataset_path)
# load into memory store
with zarr.ZipStore(dataset_path, mode='r') as zip_store:
    replay_buffer = ReplayBuffer.copy_from_store(
        src_store=zip_store,
        store=zarr.MemoryStore()
    )

/home/ubuntu/example_demo_session/dataset.zarr.zip


In [3]:
# Display the replay buffer; its repr is the zarr tree of the data/ arrays and meta/episode_ends.
replay_buffer

/
 ├── data
 │   ├── camera0_rgb (2315, 224, 224, 3) uint8
 │   ├── robot0_demo_end_pose (2315, 6) float64
 │   ├── robot0_demo_start_pose (2315, 6) float64
 │   ├── robot0_eef_pos (2315, 3) float32
 │   ├── robot0_eef_rot_axis_angle (2315, 3) float32
 │   └── robot0_gripper_width (2315, 1) float32
 └── meta
     └── episode_ends (5,) int64

In [5]:
# Which camera and robot (gripper) streams of the replay buffer to visualise.
cam_idx = 0
gripper_idx = 0

# Array names in replay_buffer.data follow "<device><index>_<field>".
rgb_key = "camera{}_rgb".format(cam_idx)
gripper_pos_key = "robot{}_eef_pos".format(gripper_idx)
gripper_rot_axis_angle_key = "robot{}_eef_rot_axis_angle".format(gripper_idx)
gripper_width_key = "robot{}_gripper_width".format(gripper_idx)


[0.08381866] [0.05877358]


In [8]:
def rot_axis_angle_to_rot_mat(rot_axis_angle):
    """Convert an axis-angle (rotation) vector to a 3x3 rotation matrix.

    The vector's direction is the rotation axis and its norm the angle in
    radians. The returned R rotates counter-clockwise (right-hand rule) about
    that axis: ``R @ v`` is the rotated ``v``, and column j of R is the image
    of the j-th basis vector.

    Args:
        rot_axis_angle: length-3 array-like rotation vector.

    Returns:
        (3, 3) float ndarray; the identity for a zero vector.
    """
    rotvec = np.asarray(rot_axis_angle, dtype=float).reshape(3)
    theta = np.linalg.norm(rotvec)
    if theta == 0:
        return np.eye(3)
    axis = rotvec / theta
    # Unit quaternion (a, b, c, d) = (cos(theta/2), axis * sin(theta/2)).
    # The vector part must NOT be negated here: combined with the matrix
    # below, a negated vector part yields the transpose, i.e. a rotation by
    # -theta.
    a = np.cos(theta / 2)
    b, c, d = axis * np.sin(theta / 2)
    return np.array([
        [a * a + b * b - c * c - d * d, 2 * (b * c - a * d), 2 * (b * d + a * c)],
        [2 * (b * c + a * d), a * a - b * b + c * c - d * d, 2 * (c * d - a * b)],
        [2 * (b * d - a * c), 2 * (c * d + a * b), a * a - b * b - c * c + d * d]
    ])

In [9]:
# https://plotly.com/python/animations/
# go.Figure(data, layout, frames)
# 1. data must be a list of dicts referred to as "traces".
# 2. layout must be a dict containing attributes that control the positioning and configuration of non-data-related parts of the figure, such as dimensions and margins, titles, axes, legends, and controls (updatemenus / sliders).
# 3. frames must be a list of dicts that define sequential frames in an animated plot. Each frame contains its own data attribute as well as other parameters. Animations are usually triggered and controlled via controls defined in layout.sliders and/or layout.updatemenus 
# data = [go.Scatter3d(x=xyz[:, 0], y=xyz[:, 1], z=xyz[:, 2], mode="lines+markers"), ...]
# go.Layout(xaxis=..., yaxis=..., updatemenus=...)
# frames = [go.Frame(), go.Frame(), ....]
# Along with data and layout, frames can be added as a key in a figure object. The frames key points to a list of figures, each of which will be cycled through when animation is triggered. Each frame can contain a data key, which is a list of traces that will be drawn on the plot. The data key can also

In [12]:
import numpy as np
import plotly.graph_objects as go
import plotly

# RGB endpoints for encoding min-max normalised gripper width:
# narrowest sample -> blue, widest sample -> red.
_GRIPPER_MIN_RGB = np.array([0, 0, 255])
_GRIPPER_MAX_RGB = np.array([255, 0, 0])


def _gripper_width_colors(gripper_width):
    """Return one plotly ``'rgb(r,g,b)'`` color string per gripper-width sample.

    Widths are flattened (so (N,) and (N, 1) inputs both work), min-max
    normalised over the input, and blended linearly from blue (narrowest) to
    red (widest). A constant width maps every sample to blue instead of
    dividing by zero.
    """
    widths = np.asarray(gripper_width, dtype=float).reshape(-1)
    if widths.size == 0:
        return []
    span = widths.max() - widths.min()
    t = (widths - widths.min()) / span if span > 0 else np.zeros_like(widths)
    rgb = np.outer(t, _GRIPPER_MAX_RGB) + np.outer(1.0 - t, _GRIPPER_MIN_RGB)
    # plotly interprets a bare numeric array as per-point colorscale values,
    # not as an RGB triple, so emit explicit CSS color strings.
    return [f"rgb({r:.0f},{g:.0f},{b:.0f})" for r, g, b in rgb]


def _apply_tcp_scene_layout(fig):
    """Apply the 3-D scene settings (axis titles, 1:1:1 aspect, camera) and title shared by the TCP plots."""
    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='manual',
            aspectratio=dict(x=1, y=1, z=1),
            camera=dict(
                up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0.5, z=0.15),
                eye=dict(x=0, y=-1.5, z=0.15),
            ),
        ),
        title='TCP Trajectory',
    )


def create_tcp_trajectory_animation(pos, gripper_width, output_file, fps=100):
    """Write an HTML animation of a TCP marker moving along its trajectory.

    Args:
        pos: (N, 3) array of TCP positions.
        gripper_width: N gripper widths (shape (N,) or (N, 1)); each frame's
            marker is colored by its normalised width (blue narrow -> red wide).
        output_file: path of the HTML file to write.
        fps: playback rate in frames per second.

    Returns:
        The plotly Figure that was written.

    Raises:
        ValueError: if ``pos`` is empty.
    """
    pos = np.asarray(pos)
    if len(pos) == 0:
        raise ValueError("pos must contain at least one position")
    colors = _gripper_width_colors(gripper_width)

    def marker_trace(i):
        # Single-point trace showing the TCP at sample i.
        return go.Scatter3d(
            x=[pos[i, 0]],
            y=[pos[i, 1]],
            z=[pos[i, 2]],
            mode='markers',
            marker=dict(color=colors[i], size=10),
            name='TCP [color is gripper width]',
            showlegend=True,
        )

    fig = go.Figure()
    # Trace 0 is the moving marker: frames without explicit `traces` update the
    # figure's traces in order, so each frame's single trace replaces trace 0.
    fig.add_trace(marker_trace(0))
    # Trace 1 is the full trajectory and stays fixed during the animation.
    fig.add_trace(go.Scatter3d(
        x=pos[:, 0],
        y=pos[:, 1],
        z=pos[:, 2],
        mode='lines',
        line=dict(color='black', width=2),
        name='TCP Trajectory',
        showlegend=False,
    ))
    _apply_tcp_scene_layout(fig)

    fig.frames = [go.Frame(data=[marker_trace(i)]) for i in range(len(pos))]
    # Plotly animation durations are in milliseconds.
    frame_control = {
        "frame": {
            "duration": 1000 / fps,
            "redraw": True,
        }
    }
    fig.update_layout(updatemenus=[dict(type='buttons', showactive=True, buttons=[dict(label='Play', method='animate', args=[None, frame_control])])])

    print(f"Saving to {output_file}")
    fig.write_html(output_file)

    return fig


def create_tcp_trajectory(pos, rot_axis_angle, gripper_width, output_file):
    """Write a static HTML plot of a TCP trajectory with orientation axes.

    Each position is a small marker colored by normalised gripper width (blue
    narrow -> red wide). From each position, short segments of length
    ``vec_scale`` show columns 0 (red) and 1 (green) of the rotation matrix
    built from ``rot_axis_angle``.

    Args:
        pos: (N, 3) array of TCP positions.
        rot_axis_angle: N axis-angle rotation vectors (length 3 each).
        gripper_width: N gripper widths (shape (N,) or (N, 1)).
        output_file: path of the HTML file to write.

    Returns:
        The plotly Figure that was written.

    Raises:
        ValueError: if ``pos`` is empty.
    """
    pos = np.asarray(pos, dtype=float)
    if len(pos) == 0:
        raise ValueError("pos must contain at least one position")

    # Length of the drawn orientation axes, in the units of pos.
    vec_scale = 0.02
    axis_colors = ["red", "green", "blue"]
    # Only the first two rotation-matrix columns are drawn; use range(3) to add the third.
    drawn_axes = range(2)

    fig = go.Figure()
    # One marker trace with per-point colors instead of one trace per sample,
    # which made the HTML very large and slow for long episodes.
    fig.add_trace(go.Scatter3d(
        x=pos[:, 0],
        y=pos[:, 1],
        z=pos[:, 2],
        mode='markers',
        marker=dict(color=_gripper_width_colors(gripper_width), size=2),
        name='TCP Trajectory',
        showlegend=False,
    ))

    rotmats = np.array([rot_axis_angle_to_rot_mat(r) for r in rot_axis_angle]).reshape(-1, 3, 3)
    for j in drawn_axes:
        ends = pos + rotmats[:, :, j] * vec_scale
        # One trace per axis; None entries break the polyline between segments.
        xs, ys, zs = [], [], []
        for start, end in zip(pos, ends):
            xs += [start[0], end[0], None]
            ys += [start[1], end[1], None]
            zs += [start[2], end[2], None]
        fig.add_trace(go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode='lines',
            line=dict(color=axis_colors[j], width=1),
            showlegend=False,
        ))

    _apply_tcp_scene_layout(fig)

    print(f"Saving to {output_file}")
    fig.write_html(output_file)

    return fig
    
# Plot a single demonstration (episode `ii`) from the replay buffer.
ii = 0
episode_ends = replay_buffer.meta.episode_ends[:]
# episode_ends[k] is used as the exclusive slice end of episode k, so episode k
# starts where episode k-1 ended (episode 0 starts at 0). Hard-coding
# start_idx = 0 was only correct for ii == 0.
start_idx = int(episode_ends[ii - 1]) if ii > 0 else 0
end_idx = int(episode_ends[ii])

rgb = replay_buffer.data[rgb_key][start_idx:end_idx]
pos = replay_buffer.data[gripper_pos_key][start_idx:end_idx]
rot_axis_angle = replay_buffer.data[gripper_rot_axis_angle_key][start_idx:end_idx]
gripper_width = replay_buffer.data[gripper_width_key][start_idx:end_idx]
print(max(gripper_width), min(gripper_width))
fig_static = create_tcp_trajectory(pos, rot_axis_angle, gripper_width, 'tcp_trajectory_single_demo_static.html')
fig_animation = create_tcp_trajectory_animation(pos, gripper_width, 'tcp_trajectory_single_demo_animation.html')

Saving to tcp_trajectory_single_demo_static.html
Saving to tcp_trajectory_single_demo_animation.html


In [None]:
# !cp tcp_trajectory_single_demo_animation.html /nfs/jchen/test/umi/visualizations/
# !cp tcp_trajectory_single_demo_static.html /nfs/jchen/test/umi/visualizations/

In [17]:
# for each episode, select a color, and plot its trajectory.
def create_tcp_trajectory_all_episodes(replay_buffer, output_file, pos_key=None):
    """Plot every episode's TCP position trajectory as its own colored line.

    Args:
        replay_buffer: buffer exposing ``meta.episode_ends`` (per-episode
            exclusive end index into the data arrays) and ``data[pos_key]``
            (presumably (N, 3) positions -- columns 0..2 are plotted as x, y, z).
        output_file: path of the HTML file to write.
        pos_key: key of the position array in ``replay_buffer.data``; defaults
            to the notebook-level ``gripper_pos_key``.

    Returns:
        The plotly Figure that was written.
    """
    if pos_key is None:
        pos_key = gripper_pos_key
    colors = plotly.colors.qualitative.Dark24

    episode_ends = np.asarray(replay_buffer.meta.episode_ends[:])
    # Each episode starts where the previous one ended; the first starts at 0.
    episode_starts = np.concatenate(([0], episode_ends[:-1]))
    # Read the positions once instead of once per episode.
    all_pos = replay_buffer.data[pos_key][:]

    fig = go.Figure()
    for episode_idx, (start_idx, end_idx) in enumerate(zip(episode_starts, episode_ends)):
        pos = all_pos[start_idx:end_idx]
        fig.add_trace(
            go.Scatter3d(
                x=pos[:, 0],
                y=pos[:, 1],
                z=pos[:, 2],
                mode='lines',
                # Cycle the palette so more than len(colors) episodes do not IndexError.
                line=dict(color=colors[episode_idx % len(colors)], width=2),
                name=f'Episode {episode_idx}',
                showlegend=True
            )
        )
    fig.update_layout(
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
        title='TCP Trajectory (all episodes)',
    )

    print(f"Saving to {output_file}")
    fig.write_html(output_file)
    return fig

create_tcp_trajectory_all_episodes(replay_buffer, 'tcp_trajectory_all_episodes.html')

Saving to tcp_trajectory_all_episodes.html


In [18]:
# Publish the all-episodes plot to the shared visualizations folder.
# The destination can be overridden via UMI_VIS_DIR rather than being
# hard-coded, and shutil.copy raises on failure, whereas a failing `!cp`
# only prints to stderr and the notebook carries on.
import os
import shutil

vis_dir = os.environ.get("UMI_VIS_DIR", "/nfs/jchen/test/umi/visualizations/")
shutil.copy("tcp_trajectory_all_episodes.html", vis_dir)

In [None]:
# you can plot multiple traces
# 1. Plot all demos in the same figure? Open question: are they all expressed in the same coordinate system?

# Question: is the data from different cameras expressed in the same coordinate system? As a first check, plot all the camera data in the same figure.
# Question: what is the camera frame rate? What is the sampling frequency of the motion (pose) data? What about the actions?

In [None]:
#TODO 
# 1. upload the html to s3, with corresponding video? can gripper size?
# 2. Visualize the normalized gripper width?