# Tessera: Generate Embeddings from LeRobot Datasets

This notebook generates CLIP embeddings from LeRobot v3.0 datasets for visualization in [Tessera](https://github.com/arpitg1304/tessera).

**Features:**
- Generate CLIP embeddings from any LeRobot dataset on HuggingFace
- Optional thumbnail and GIF previews for hover visualization
- Configurable embedding modes (`single`: the episode's middle frame; `start_end`: first- and last-frame embeddings concatenated)
- Export to HDF5 format compatible with Tessera

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pVsTizT8Ec1iST0tyNDzhgh4h6YUtZUh)

## 1. Install Dependencies

In [None]:
# Install required packages
!pip install -q torch torchvision
!pip install -q git+https://github.com/openai/CLIP.git
!pip install -q h5py pillow pandas pyarrow av huggingface_hub tqdm

In [None]:
# Verify GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
# Query CUDA once and reuse the answer for both the report and the GPU name.
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

Configure your embedding generation settings below. In Colab, the `# @param` annotations render as form fields; elsewhere, edit the values directly in the cell.

In [None]:
# =============================================================================
# CONFIGURATION - Modify these values or use the interactive widgets below
# =============================================================================
# In Colab, every line ending in a `# @param` annotation is rendered as a form
# field, so keep those trailing annotations intact when editing values.
# The next cell resolves these settings: it applies the quality presets,
# replaces DEVICE="auto" with a concrete device and builds OUTPUT_FILENAME.

# Dataset settings
# DATASET_NAME: HuggingFace dataset repo id, passed to snapshot_download.
DATASET_NAME = "lerobot/pusht"  # @param {type:"string"}
# MAX_EPISODES: keep only the first N episodes (sorted by episode_index).
MAX_EPISODES = 0  # @param {type:"integer"} 0 = all episodes

# Device settings
# "auto" picks "cuda" when torch reports a GPU, otherwise "cpu".
DEVICE = "auto"  # @param ["auto", "cuda", "cpu"]

# Embedding mode
# "single": embed the episode's middle frame.
# "start_end": embed the first and last frames and concatenate the two vectors.
EMBEDDING_MODE = "start_end"  # @param ["single", "start_end"]

# Thumbnail settings
# Quality selects an entry of THUMBNAIL_PRESETS (defined in the next cell).
GENERATE_THUMBNAILS = True  # @param {type:"boolean"}
THUMBNAIL_QUALITY = "medium"  # @param ["low", "medium", "high"]

# GIF settings
# Quality selects an entry of GIF_PRESETS (defined in the next cell).
GENERATE_GIFS = True  # @param {type:"boolean"}
GIF_QUALITY = "medium"  # @param ["low", "medium", "high"]

# Output settings
# Empty -> "<dataset>_embeddings[_gifs|_thumbs].h5" is generated in the next cell.
OUTPUT_FILENAME = ""  # @param {type:"string"} Leave empty for auto-generated name

In [None]:
# Quality presets: `size` is the (width, height) passed to PIL's resize,
# `quality` is the JPEG quality, and `fps` / `max_frames` control GIF playback
# speed and the number of frames sampled per episode.
THUMBNAIL_PRESETS = {
    "low": {"size": (64, 64), "quality": 60},
    "medium": {"size": (128, 128), "quality": 80},
    "high": {"size": (192, 192), "quality": 90}
}

GIF_PRESETS = {
    "low": {"size": (64, 64), "fps": 6, "max_frames": 8},
    "medium": {"size": (128, 128), "fps": 8, "max_frames": 16},
    "high": {"size": (192, 192), "fps": 10, "max_frames": 24}
}

EMBEDDING_MODES = ("single", "start_end")


def _require_choice(setting: str, value, allowed) -> None:
    """Raise a readable ValueError when a config value is not one of `allowed`."""
    if value not in allowed:
        raise ValueError(f"{setting} must be one of {list(allowed)}, got {value!r}")


# Validate up front so a typo fails here instead of deep inside the embedding loop.
_require_choice("THUMBNAIL_QUALITY", THUMBNAIL_QUALITY, THUMBNAIL_PRESETS)
_require_choice("GIF_QUALITY", GIF_QUALITY, GIF_PRESETS)
_require_choice("EMBEDDING_MODE", EMBEDDING_MODE, EMBEDDING_MODES)

# Apply presets
thumb_config = THUMBNAIL_PRESETS[THUMBNAIL_QUALITY]
gif_config = GIF_PRESETS[GIF_QUALITY]

# Auto-detect device; fail early if CUDA was requested explicitly but is unusable.
if DEVICE == "auto":
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
elif DEVICE.startswith("cuda") and not torch.cuda.is_available():
    raise RuntimeError(
        f"DEVICE is set to {DEVICE!r} but no CUDA GPU is available; use 'auto' or 'cpu'."
    )

# Generate output filename if not specified
if not OUTPUT_FILENAME:
    dataset_short_name = DATASET_NAME.split("/")[-1]
    suffix = "_gifs" if GENERATE_GIFS else ("_thumbs" if GENERATE_THUMBNAILS else "")
    OUTPUT_FILENAME = f"{dataset_short_name}_embeddings{suffix}.h5"

print("Configuration:")
print(f"  Dataset: {DATASET_NAME}")
print(f"  Device: {DEVICE}")
print(f"  Embedding mode: {EMBEDDING_MODE}")
print(f"  Thumbnails: {GENERATE_THUMBNAILS} ({THUMBNAIL_QUALITY})")
print(f"  GIFs: {GENERATE_GIFS} ({GIF_QUALITY})")
print(f"  Output: {OUTPUT_FILENAME}")

## 3. Download Dataset

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path
import json

# Download the dataset repo; snapshot_download returns the local directory it filled.
print(f"Downloading {DATASET_NAME}...")
local_dir = f"./datasets/{DATASET_NAME.replace('/', '_')}"
dataset_path = Path(
    snapshot_download(repo_id=DATASET_NAME, repo_type="dataset", local_dir=local_dir)
)
print(f"Downloaded to: {dataset_path}")

# meta/info.json holds dataset-level counts, fps and (in newer formats) features.
info_path = dataset_path / "meta" / "info.json"
with info_path.open() as info_file:
    dataset_info = json.load(info_file)

print("\nDataset info:")
for label, info_key in (
    ("Total episodes", "total_episodes"),
    ("Total frames", "total_frames"),
    ("FPS", "fps"),
):
    print(f"  {label}: {dataset_info.get(info_key, 'unknown')}")
print(f"  Video keys: {list(dataset_info.get('video_keys', []))}")

## 4. Load CLIP Model

In [None]:
import clip

# Image encoder used for all embeddings; clip.load also returns the matching preprocessing transform.
CLIP_MODEL_NAME = "ViT-B/32"

print(f"Loading CLIP model on {DEVICE}...")
clip_model, clip_preprocess = clip.load(CLIP_MODEL_NAME, device=DEVICE)
clip_model = clip_model.eval()  # eval() returns the module itself; switches to inference mode
print("CLIP model loaded!")

## 5. Helper Functions

In [None]:
import io
import av
import numpy as np
from PIL import Image
from tqdm.auto import tqdm

def extract_frame_from_video(video_path: str, frame_idx: int) -> np.ndarray:
    """Extract a single RGB frame from a video file.

    Args:
        video_path: Path to the video (str or Path).
        frame_idx: Zero-based index of the frame to return, counted from the
            first frame of the file.

    Returns:
        The frame as an ``rgb24`` ndarray produced by PyAV, or ``None`` if the
        requested frame could not be decoded.
    """
    # The context manager closes the container even if decoding raises.
    with av.open(str(video_path)) as container:
        stream = container.streams.video[0]
        rate = stream.average_rate
        time_base = stream.time_base

        if frame_idx <= 0 or not rate or not time_base:
            # First frame requested, or no timing info to seek with:
            # decode from the beginning and count frames.
            for position, frame in enumerate(container.decode(stream)):
                if position >= frame_idx:
                    return frame.to_ndarray(format='rgb24')
            return None

        # Work in the stream's own timestamp units; rate/time_base are
        # Fractions, so this arithmetic stays exact.
        pts_per_frame = 1 / (rate * time_base)
        target_pts = (stream.start_time or 0) + frame_idx * pts_per_frame

        # seek() lands on a keyframe at or before the target, NOT on the
        # target itself, so counting decoded frames from 0 after a seek is
        # wrong. Instead decode forward and return the first frame whose
        # timestamp reaches the target (half a frame of slack absorbs pts
        # rounding in the container).
        container.seek(int(target_pts), stream=stream)
        for frame in container.decode(stream):
            if frame.pts is not None and frame.pts >= target_pts - pts_per_frame / 2:
                return frame.to_ndarray(format='rgb24')
        return None


def extract_frames_for_gif(video_path: str, start_frame: int, end_frame: int,
                           max_frames: int = 16) -> list:
    """Extract up to ``max_frames`` evenly spaced frames from ``[start_frame, end_frame)``.

    Args:
        video_path: Path to the video (str or Path).
        start_frame: First frame index of the range (inclusive).
        end_frame: End of the range (exclusive).
        max_frames: Maximum number of frames to return.

    Returns:
        A list of ``rgb24`` ndarrays in decode order. It is empty when the
        range is empty, ``max_frames <= 0``, or nothing could be decoded.
    """
    total_frames = end_frame - start_frame
    if total_frames <= 0 or max_frames <= 0:
        return []
    if total_frames <= max_frames:
        wanted = set(range(start_frame, end_frame))
    else:
        wanted = set(np.linspace(start_frame, end_frame - 1, max_frames, dtype=int).tolist())
    last_wanted = max(wanted)

    frames = []
    with av.open(str(video_path)) as container:
        stream = container.streams.video[0]
        rate = stream.average_rate
        time_base = stream.time_base
        use_timestamps = bool(rate and time_base)

        if use_timestamps:
            pts_per_frame = 1 / (rate * time_base)
            origin_pts = stream.start_time or 0
            # Jump near the episode instead of decoding the whole video from frame 0
            # (the video may hold many episodes back to back).
            if start_frame > 0:
                container.seek(int(origin_pts + start_frame * pts_per_frame), stream=stream)

        for position, frame in enumerate(container.decode(stream)):
            if not use_timestamps:
                # No seek happened, so decode position == frame index.
                frame_no = position
            elif frame.pts is None:
                continue  # cannot place this frame on the timeline
            else:
                frame_no = round((frame.pts - origin_pts) / pts_per_frame)

            if frame_no > last_wanted:
                break
            if frame_no in wanted:
                wanted.discard(frame_no)  # guard against duplicate timestamps
                frames.append(frame.to_ndarray(format='rgb24'))
    return frames


def create_thumbnail(image: np.ndarray, size: tuple, quality: int) -> bytes:
    """Encode ``image`` as a resized JPEG thumbnail.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (RGB frames here).
        size: Target (width, height) in pixels; the aspect ratio is not preserved.
        quality: JPEG quality passed to Pillow.

    Returns:
        The encoded JPEG bytes.
    """
    out = io.BytesIO()
    resized = Image.fromarray(image).resize(size, Image.LANCZOS)
    resized.save(out, format='JPEG', quality=quality)
    return out.getvalue()


def create_gif(frames: list, size: tuple, fps: int) -> bytes:
    """Encode ``frames`` as an infinitely looping animated GIF.

    Args:
        frames: Non-empty list of arrays accepted by ``PIL.Image.fromarray``.
        size: Target (width, height) every frame is resized to.
        fps: Playback rate; each frame is shown for ``int(1000 / fps)`` ms.

    Returns:
        The encoded GIF bytes.
    """
    resized = []
    for arr in frames:
        resized.append(Image.fromarray(arr).resize(size, Image.LANCZOS))

    frame_ms = int(1000 / fps)
    out = io.BytesIO()
    head, tail = resized[0], resized[1:]
    head.save(
        out,
        format='GIF',
        save_all=True,          # write an animation, not a single image
        append_images=tail,
        duration=frame_ms,
        loop=0,                 # 0 = loop forever
    )
    return out.getvalue()


def get_clip_embedding(image: np.ndarray) -> np.ndarray:
    """Return the L2-normalised CLIP image embedding of ``image`` as a 1-D array.

    Uses the notebook-level ``clip_model``, ``clip_preprocess`` and ``DEVICE``.
    """
    batch = clip_preprocess(Image.fromarray(image)).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        feats = clip_model.encode_image(batch)
        unit = feats / feats.norm(dim=-1, keepdim=True)

    return unit.cpu().numpy().flatten()

print("Helper functions loaded!")

## 6. Parse Dataset Structure

In [None]:
import pandas as pd
from pathlib import Path

# Find video key to use - handle different LeRobot info.json formats
video_keys = list(dataset_info.get('video_keys', []))

# If video_keys not present, extract from features (newer format)
if not video_keys and 'features' in dataset_info:
    feature_specs = dataset_info['features']
    video_keys = [k for k, v in feature_specs.items()
                  if isinstance(v, dict) and v.get('dtype') == 'video']
    print(f"Extracted video keys from features: {video_keys}")

# Preferred camera views, in priority order; otherwise the first video key wins.
preferred_keys = [
    'observation.images.top',
    'observation.image',
    'observation.images.wrist_image_left',
    'observation.images.exterior_image_1_left',
    'observation.images.exterior_image_2_left'
]

video_key = next((k for k in preferred_keys if k in video_keys), None)
if video_key is None and video_keys:
    video_key = video_keys[0]

print(f"Available video keys: {video_keys}")
print(f"Using video key: {video_key}")


def _read_parquets(directory: Path, what: str) -> pd.DataFrame:
    """Concatenate every parquet file found (recursively) under `directory`."""
    paths = sorted(directory.glob("**/*.parquet"))
    if not paths:
        raise FileNotFoundError(f"No {what} parquet files found under {directory}")
    return pd.concat([pd.read_parquet(p) for p in paths], ignore_index=True)


# Load episode info
episode_df = (
    _read_parquets(dataset_path / "meta" / "episodes", "episode metadata")
    .sort_values('episode_index')
    .reset_index(drop=True)
)

# Load frame data to get episode boundaries
frame_df = (
    _read_parquets(dataset_path / "data", "frame data")
    .sort_values(['episode_index', 'frame_index'])
    .reset_index(drop=True)
)

# Get episode boundaries as frame positions inside the (concatenated) video.
# `frame_index` restarts at 0 in every episode, so using it made every episode
# read the first episode's frames. Prefer the dataset-wide `index` column;
# otherwise assume episodes are stored back to back in episode order and
# derive offsets from cumulative lengths.
# NOTE(review): both are exact only when a single video file holds every
# episode for this camera - confirm for chunked multi-file datasets.
frames_by_episode = frame_df.groupby('episode_index', sort=True)
lengths = frames_by_episode.size()
if 'index' in frame_df.columns:
    starts = frames_by_episode['index'].min()
    ends = frames_by_episode['index'].max() + 1
else:
    ends = lengths.cumsum()
    starts = ends - lengths

episode_info = (
    pd.DataFrame({'start_frame': starts, 'end_frame': ends, 'length': lengths})
    .loc[lambda d: d.index.isin(episode_df['episode_index'])]
    .rename_axis('episode_index')
    .reset_index()
)
if episode_info.empty:
    raise ValueError("No episodes with frame data were found in this dataset.")
total_episodes = len(episode_info)

# Apply max episodes limit
if MAX_EPISODES > 0 and MAX_EPISODES < total_episodes:
    episode_info = episode_info.head(MAX_EPISODES)
    print(f"Limited to {MAX_EPISODES} episodes (out of {total_episodes})")
else:
    print(f"Processing all {total_episodes} episodes")

print(f"\nEpisode stats:")
print(f"  Min length: {episode_info['length'].min()} frames")
print(f"  Max length: {episode_info['length'].max()} frames")
print(f"  Mean length: {episode_info['length'].mean():.1f} frames")

## 7. Find Video Files

In [None]:
# Find video files for the selected video key
if video_key is None:
    raise ValueError(
        f"No video keys found in dataset! "
        f"Available keys in info.json: {list(dataset_info.keys())}\n"
        f"This dataset may not contain video data."
    )

# Candidate layouts, tried in order: the key split into nested folders,
# then the key used verbatim as one folder name.
videos_base = dataset_path / "videos"
candidate_dirs = [videos_base / video_key.replace('.', '/'), videos_base / video_key]
video_dir = next((d for d in candidate_dirs if d.exists()), None)

if video_dir is None:
    if not videos_base.exists():
        raise ValueError(f"No 'videos' directory found in dataset at {dataset_path}")
    # Show a few of the folders that do exist to help pick the right key.
    print(f"Video key '{video_key}' not found. Available directories:")
    for entry in list(videos_base.glob("*"))[:5]:
        print(f"  {entry.name}")
    raise ValueError(f"Could not find video directory for key: {video_key}")

video_files = sorted(video_dir.glob("**/*.mp4"))
if not video_files:
    raise ValueError(f"No .mp4 files found in {video_dir}")

print(f"Found {len(video_files)} video file(s) in {video_dir.name}")
for vf in video_files[:3]:
    print(f"  {vf.name}")
remaining = len(video_files) - 3
if remaining > 0:
    print(f"  ... and {remaining} more")

## 8. Generate Embeddings

In [None]:
# Per-episode result buffers; the processing loop appends one entry per
# successfully embedded episode. Preview buffers are None when disabled.
embeddings, episode_ids = [], []
thumbnails = [] if GENERATE_THUMBNAILS else None
gifs = [] if GENERATE_GIFS else None
metadata = {'episode_length': []}

# LeRobot v3 chunks videos (e.g. chunk-000/file-000.mp4 may hold frames for
# several episodes). Only the first file is used, which is exact only when the
# dataset ships a single video file for this camera.
video_path = video_files[0]
if len(video_files) == 1:
    print(f"Using single video file: {video_path.name}")
else:
    # Simplified - may need adjustment for your dataset
    print(f"Warning: Multiple video files found, using first: {video_path.name}")

summary_lines = [
    "\nGenerating embeddings...",
    f"  Mode: {EMBEDDING_MODE}",
    f"  Thumbnails: {GENERATE_THUMBNAILS}",
    f"  GIFs: {GENERATE_GIFS}",
    "",
]
print(*summary_lines, sep="\n")

In [None]:
# Process each episode.
# Fail fast on an unknown mode; otherwise every episode would error out
# individually (and silently) inside the loop.
if EMBEDDING_MODE not in ("single", "start_end"):
    raise ValueError(
        f"Unknown EMBEDDING_MODE {EMBEDDING_MODE!r}; expected 'single' or 'start_end'"
    )

# Start from empty buffers so re-running this cell never appends duplicates.
embeddings = []
episode_ids = []
thumbnails = [] if GENERATE_THUMBNAILS else None
gifs = [] if GENERATE_GIFS else None
metadata = {'episode_length': []}

for _, row in tqdm(episode_info.iterrows(), total=len(episode_info), desc="Processing episodes"):
    ep_idx = row['episode_index']

    try:
        start_frame = int(row['start_frame'])
        end_frame = int(row['end_frame'])
        ep_length = int(row['length'])

        # Calculate frame indices
        middle_frame = start_frame + ep_length // 2
        last_frame = end_frame - 1

        # Extract frames and generate embedding based on mode
        if EMBEDDING_MODE == "single":
            frame = extract_frame_from_video(video_path, middle_frame)
            if frame is None:
                print(f"Warning: Could not extract frame for episode {ep_idx}")
                continue
            embedding = get_clip_embedding(frame)
        else:  # "start_end"
            start_frame_img = extract_frame_from_video(video_path, start_frame)
            end_frame_img = extract_frame_from_video(video_path, last_frame)
            if start_frame_img is None or end_frame_img is None:
                print(f"Warning: Could not extract frames for episode {ep_idx}")
                continue
            embedding = np.concatenate([
                get_clip_embedding(start_frame_img),
                get_clip_embedding(end_frame_img),
            ])
            frame = start_frame_img  # Use start frame for thumbnail

        # Build previews BEFORE recording anything, so a failure here cannot
        # leave the embedding stored without its matching preview.
        thumb = None
        if GENERATE_THUMBNAILS:
            thumb_bytes = create_thumbnail(frame, thumb_config['size'], thumb_config['quality'])
            thumb = np.frombuffer(thumb_bytes, dtype=np.uint8)

        gif = None
        if GENERATE_GIFS:
            gif_frames = extract_frames_for_gif(
                video_path, start_frame, end_frame, gif_config['max_frames']
            )
            if gif_frames:
                gif_bytes = create_gif(gif_frames, gif_config['size'], gif_config['fps'])
                gif = np.frombuffer(gif_bytes, dtype=np.uint8)
            else:
                # Empty placeholder keeps the GIF list row-aligned with embeddings.
                gif = np.array([], dtype=np.uint8)

    except Exception as e:
        # Best-effort: skip this episode but keep processing the rest.
        print(f"Error processing episode {ep_idx}: {e}")
        continue

    # Commit every per-episode output together so all lists stay index-aligned.
    embeddings.append(embedding)
    episode_ids.append(f"episode_{int(ep_idx):05d}")
    metadata['episode_length'].append(ep_length)
    if thumbnails is not None:
        thumbnails.append(thumb)
    if gifs is not None:
        gifs.append(gif)

print(f"\nProcessed {len(embeddings)} episodes successfully!")

## 9. Save to HDF5

In [None]:
import h5py

# Convert to arrays
if not embeddings:
    raise RuntimeError(
        "No embeddings were generated, so there is nothing to save. "
        "Check the warnings printed by the embedding step."
    )
embeddings_array = np.asarray(embeddings, dtype=np.float32)
n_rows = len(embeddings_array)

# Every per-episode array must line up row-for-row with `embeddings`; a
# mismatch would silently attach previews/metadata to the wrong episode.
row_aligned = {
    'episode_ids': episode_ids,
    'metadata/episode_length': metadata['episode_length'],
    'thumbnails': thumbnails if GENERATE_THUMBNAILS else None,
    'gifs': gifs if GENERATE_GIFS else None,
}
for name, values in row_aligned.items():
    if values is not None and len(values) != n_rows:
        raise ValueError(f"{name} has {len(values)} entries but there are {n_rows} embeddings")

print(f"Saving to {OUTPUT_FILENAME}...")
print(f"  Episodes: {len(embeddings)}")
print(f"  Embedding dimension: {embeddings_array.shape[1]}")


def _write_blobs(h5file, name, blobs):
    """Store a list of uint8 arrays as a variable-length dataset; return total size in MB."""
    ds = h5file.create_dataset(name, (len(blobs),), dtype=h5py.vlen_dtype(np.uint8))
    for i, blob in enumerate(blobs):
        ds[i] = blob
    return sum(len(b) for b in blobs) / 1024 / 1024


# Explicit variable-length UTF-8 string type for text datasets.
text_dtype = h5py.string_dtype()

with h5py.File(OUTPUT_FILENAME, 'w') as f:
    # Required: embeddings and episode_ids
    f.create_dataset('embeddings', data=embeddings_array, compression='gzip')
    f.create_dataset('episode_ids', data=episode_ids, dtype=text_dtype)

    # Metadata
    meta_group = f.create_group('metadata')
    meta_group.create_dataset('episode_length', data=metadata['episode_length'])

    # Add dataset name as metadata
    dataset_labels = [DATASET_NAME.split('/')[-1]] * n_rows
    meta_group.create_dataset('dataset', data=dataset_labels, dtype=text_dtype)

    # Optional: thumbnails
    if GENERATE_THUMBNAILS and thumbnails:
        thumb_size_mb = _write_blobs(f, 'thumbnails', thumbnails)
        print(f"  Thumbnails: {len(thumbnails)} ({thumb_size_mb:.2f} MB)")

    # Optional: GIFs
    if GENERATE_GIFS and gifs:
        gif_size_mb = _write_blobs(f, 'gifs', gifs)
        print(f"  GIFs: {len(gifs)} ({gif_size_mb:.2f} MB)")

# Report file size
file_size_mb = Path(OUTPUT_FILENAME).stat().st_size / 1024 / 1024
print(f"\nOutput file: {OUTPUT_FILENAME}")
print(f"File size: {file_size_mb:.2f} MB")

## 10. Verify Output

In [None]:
# Re-open the written file read-only and summarise what it contains.
print("Verifying output file...\n")


def print_structure(name, obj):
    """visititems callback: one line per dataset (shape, dtype) or group (trailing '/')."""
    if not isinstance(obj, h5py.Dataset):
        print(f"  {name}/")
        return
    print(f"  {name}: {obj.shape} {obj.dtype}")


with h5py.File(OUTPUT_FILENAME, 'r') as h5:
    print("File structure:")
    h5.visititems(print_structure)

    print(f"\nEmbeddings shape: {h5['embeddings'].shape}")
    print(f"Episode IDs: {len(h5['episode_ids'])}")

    for optional_key, label in (('thumbnails', 'Thumbnails'), ('gifs', 'GIFs')):
        if optional_key in h5:
            print(f"{label}: {len(h5[optional_key])}")

banner = "=" * 50
print("\n" + banner)
print("SUCCESS! Your embedding file is ready.")
print(banner)

## 11. Download File

Run the cell below to download your embedding file.

In [None]:
# Download the file (works in Google Colab)
def _download_via_colab(path):
    """Trigger a browser download of `path`; raises ImportError outside Colab."""
    from google.colab import files
    files.download(path)
    print(f"Downloading {path}...")


try:
    _download_via_colab(OUTPUT_FILENAME)
except ImportError:
    print(f"Not running in Colab. File saved to: {OUTPUT_FILENAME}")
    print(f"\nYou can download it manually or upload directly to Tessera.")

## 12. Upload to Tessera (Optional)

You can upload the file directly to Tessera from this notebook: set `UPLOAD_TO_TESSERA = True` in the cell below and run it.

In [None]:
# Optional: Upload directly to Tessera
TESSERA_HOST = "https://tessera.vlastudio.cloud"  # @param {type:"string"}
UPLOAD_TO_TESSERA = False  # @param {type:"boolean"}
# Seconds to wait on connect/read; without a timeout a stalled connection hangs the cell forever.
UPLOAD_TIMEOUT_S = 600

if UPLOAD_TO_TESSERA:
    import requests

    # Strip a trailing slash so URLs don't end up as "...//api/upload".
    tessera_base = TESSERA_HOST.rstrip('/')
    print(f"Uploading to {tessera_base}...")

    try:
        with open(OUTPUT_FILENAME, 'rb') as h5_file:
            response = requests.post(
                f"{tessera_base}/api/upload",
                # Send only the file name, not any local directory components.
                files={'file': (Path(OUTPUT_FILENAME).name, h5_file, 'application/x-hdf5')},
                timeout=UPLOAD_TIMEOUT_S,
            )
    except requests.RequestException as e:
        print(f"Upload failed: {e}")
    else:
        if response.ok:
            try:
                result = response.json()
            except ValueError:
                result = {}
            project_id = result.get('project_id') if isinstance(result, dict) else None
            print("\nUpload successful!")
            if project_id:
                print(f"View your embeddings at: {tessera_base}/project/{project_id}")
            else:
                print("The server did not return a project_id. Response:")
                print(response.text)
        else:
            print(f"Upload failed: {response.status_code}")
            print(response.text)
else:
    print("Upload skipped. Set UPLOAD_TO_TESSERA = True to upload.")
    print(f"\nOr upload manually at: {TESSERA_HOST}")

---

## Next Steps

1. **Upload to Tessera**: Drag and drop your `.h5` file at [tessera.vlastudio.cloud](https://tessera.vlastudio.cloud)
2. **Explore**: Use the interactive scatter plot to explore your embeddings
3. **Sample**: Select diverse episodes using K-means or stratified sampling
4. **Export**: Download episode IDs for your training pipeline

For more information, visit the [Tessera GitHub repository](https://github.com/arpitg1304/tessera).