In [2]:
import sys

sys.path.append("/scr/aliang80/robot_pref/")
from utils.data import load_tensordict
from pathlib import Path
import pickle
import numpy as np
import imageio
import matplotlib.pyplot as plt

# Dataset file to load, and the segment length used to pick the matching DTW matrix file.
data_path = "/scr2/shared/pref/datasets/robomimic/lift/lift_mg_image_dense.pt"
segment_length = 64

data = load_tensordict(data_path)

# The DTW matrix pickle sits in the same directory as the dataset file and
# carries the segment length in its name.
dtw_matrix_file = Path(data_path).parent / f"dtw_matrix_{segment_length}.pkl"

# NOTE(review): pickle.load can execute arbitrary code -- only load files from a trusted source.
# The context manager closes the file handle as soon as the pickle has been read.
with open(dtw_matrix_file, "rb") as dtw_file:
    dtw_matrix, segment_start_end = pickle.load(dtw_file)
print(dtw_matrix.shape)

Fields: ['action', 'episode', 'image', 'obs', 'reward']
(4500, 4500)


In [7]:
segment_idx = 2100

# Find the top-k most similar segments by DTW distance, excluding the query segment itself.
top_k = 10
ranked = np.argsort(dtw_matrix[segment_idx])
# Remove the query segment *before* truncating, so the result never contains more
# than top_k neighbours even when the query is not among the first top_k + 1
# entries of its own row (e.g. ties or a non-zero diagonal).
similar_segments = list(ranked[ranked != segment_idx][:top_k])

# Show which segments were selected
print(similar_segments)

# Slice the image frames of every selected segment out of the dataset, move each
# slice to host memory as a NumPy array, and stack them along a new leading axis.
segments = np.stack(
    [
        data["image"][segment_start_end[idx][0] : segment_start_end[idx][1]]
        .cpu()
        .numpy()
        for idx in similar_segments
    ],
    axis=0,
)  # Shape: (N, T, H, W, C)

# Unpack the stacked array's dimensions:
# number of segments, number of timesteps, then height, width, channels.
N, T, H, W, C = segments.shape

# For each timestep, place the N frames side by side along the width axis
# (each row has shape (H, N*W, C)), then stack the rows over time.
grid_frames = np.stack(
    [np.concatenate(segments[:, step], axis=1) for step in range(T)],
    axis=0,
)  # Shape: (T, H, N*W, C)
print(f"Grid frames shape: {grid_frames.shape}")

# Write the grid frames out as a video file
imageio.mimsave("similar_segments_grid.mp4", grid_frames, fps=5)

# Display the video
from IPython.display import HTML
from base64 import b64encode


def show_video(video_path, width=1000):
    """Embed an MP4 file in the notebook as an inline HTML5 video player.

    The file is read fully into memory and base64-encoded into a ``data:`` URL,
    so the returned markup does not reference the file on disk.

    Args:
        video_path: Path of the MP4 file to embed.
        width: Value of the ``<video>`` element's ``width`` attribute
            (defaults to 1000, the previously hard-coded value).

    Returns:
        The result of wrapping the ``<video>`` markup in ``HTML(...)``.
    """
    # The context manager closes the file handle as soon as the bytes are read.
    with open(video_path, "rb") as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML(f"""
    <video width={width} controls>
        <source src="{data_url}" type="video/mp4">
    </video>
    """)


# Embed the saved grid video; as the cell's last expression, the returned HTML object is displayed.
show_video("similar_segments_grid.mp4")



[2592, 2451, 1932, 2634, 2253, 1803, 2337, 2496, 1890, 2202]
Grid frames shape: (64, 84, 840, 3)
