# Trajectory Data Inspector

This notebook allows you to inspect and visualize trajectory data collected by the `TrajectoryDataCollector`.

Features:
- Load and explore episode data
- Visualize joint trajectories and Cartesian paths
- View RGB and depth images from the trajectory
- Analyze synchronization statistics
- Export trajectory summaries

In [None]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
from typing import Dict, List, Any, Optional
import ipywidgets as widgets
from IPython.display import display, HTML
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

# Set up plotting style
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline

## Configuration and Data Loading

In [None]:
class TrajectoryInspector:
    """Load recorded episodes from disk and derive summaries from them.

    An episode is a sub-directory of ``data_directory`` that contains an
    ``episode_data.json`` file. At most one episode is held in memory at a
    time; ``current_episode`` and ``episode_data`` always describe the same
    episode.
    """

    def __init__(self, data_directory: str = "rl_training_data"):
        # Root folder holding one sub-directory per episode.
        self.data_directory = Path(data_directory)
        # Identifier of the loaded episode; None until a load succeeds.
        self.current_episode = None
        # Parsed JSON of the loaded episode; None until a load succeeds.
        self.episode_data = None

    def list_episodes(self) -> List[str]:
        """Return the sorted names of all episode directories.

        A directory counts as an episode only if it contains
        ``episode_data.json``. A missing data directory is reported on
        stdout and yields an empty list.
        """
        if not self.data_directory.exists():
            print(f"Data directory {self.data_directory} does not exist.")
            return []

        return sorted(
            entry.name
            for entry in self.data_directory.iterdir()
            if entry.is_dir() and (entry / "episode_data.json").exists()
        )

    def load_episode(self, episode_id: str) -> bool:
        """Load one episode from disk and print a short overview.

        Args:
            episode_id: Name of the episode sub-directory.

        Returns:
            True when the episode was loaded; False when the file is missing
            or cannot be read, parsed or summarised. On failure the
            previously loaded episode (if any) is left untouched.
        """
        episode_path = self.data_directory / episode_id / "episode_data.json"

        if not episode_path.exists():
            print(f"Episode data file not found: {episode_path}")
            return False

        try:
            with open(episode_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Format the overview before touching instance state, so that a
            # malformed file can never leave `episode_data` pointing at one
            # episode while `current_episode` still names another.
            overview = [
                f"Loaded episode: {episode_id}",
                f"Input message: {data.get('input_message', 'N/A')}",
                f"Duration: {data.get('duration', 0):.2f} seconds",
                f"Trajectory points: {len(data.get('trajectory_data', []))}",
            ]
        except Exception as e:
            # Deliberately broad: this method reports problems instead of
            # raising, whatever is wrong with the file.
            print(f"Error loading episode data: {e}")
            return False

        self.episode_data = data
        self.current_episode = episode_id
        print("\n".join(overview))
        return True

    def get_trajectory_summary(self) -> Dict[str, Any]:
        """Compute summary statistics for the loaded episode.

        Returns:
            An empty dict when no episode is loaded or it has no trajectory
            points. Otherwise a dict with the keys ``total_points``,
            ``duration``, ``frequency`` (points per unit of duration, 0 when
            the duration is not positive), ``unique_prompts``, ``prompts``
            (distinct non-empty prompts in order of first appearance),
            ``joint_ranges`` (max - min per joint) and ``total_distance``
            (summed straight-line distance between consecutive positions).
        """
        if not self.episode_data:
            return {}

        trajectory_data = self.episode_data.get('trajectory_data', [])
        if not trajectory_data:
            return {}

        total_points = len(trajectory_data)
        duration = self.episode_data.get('duration', 0)

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # summary is identical from run to run (a set has no stable order).
        prompts = [point.get('prompt') for point in trajectory_data if point.get('prompt')]
        unique_prompts = list(dict.fromkeys(prompts))

        # Transpose the per-point joint vectors into one series per joint.
        joint_positions = [point['observations']['joint_state'] for point in trajectory_data]
        joint_ranges = [max(series) - min(series) for series in zip(*joint_positions)]

        # Path length: sum of Euclidean distances between consecutive points.
        positions = [
            point['observations']['cartesian_position']['position']
            for point in trajectory_data
        ]
        total_distance = 0.0
        for previous, current in zip(positions, positions[1:]):
            dx = current['x'] - previous['x']
            dy = current['y'] - previous['y']
            dz = current['z'] - previous['z']
            total_distance += (dx**2 + dy**2 + dz**2)**0.5

        return {
            'total_points': total_points,
            'duration': duration,
            'frequency': total_points / duration if duration > 0 else 0,
            'unique_prompts': len(unique_prompts),
            'prompts': unique_prompts,
            'joint_ranges': joint_ranges,
            'total_distance': total_distance
        }

    def decode_image(self, hex_string: str, image_type: str = 'rgb') -> Optional[np.ndarray]:
        """Decode a hex-encoded, compressed image into a numpy array.

        Args:
            hex_string: Hexadecimal text of the encoded image bytes.
            image_type: ``'rgb'`` decodes as a colour image and converts it
                from OpenCV's BGR order to RGB; any other value (e.g.
                ``'depth'``) decodes the image unchanged.

        Returns:
            The decoded image, or None when decoding fails (the reason is
            printed).
        """
        try:
            image_bytes = bytes.fromhex(hex_string)
            nparr = np.frombuffer(image_bytes, np.uint8)

            is_rgb = image_type == 'rgb'
            flags = cv2.IMREAD_COLOR if is_rgb else cv2.IMREAD_UNCHANGED
            image = cv2.imdecode(nparr, flags)

            # imdecode signals undecodable data by returning None instead of
            # raising; report it the same way for RGB and depth frames.
            if image is None:
                raise ValueError("data could not be decoded as an image")

            if is_rgb:
                # Convert BGR to RGB for matplotlib
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            return image
        except Exception as e:
            print(f"Error decoding {image_type} image: {e}")
            return None

# Initialize inspector
# The data location can be overridden without editing the notebook by setting
# the TRAJECTORY_DATA_DIR environment variable; when it is unset the original
# path is used, so existing setups behave exactly as before.
DATA_DIRECTORY = os.environ.get(
    "TRAJECTORY_DATA_DIR",
    "/Users/wiktorjurasz/Projects/aera/rl_training_data",
)
inspector = TrajectoryInspector(DATA_DIRECTORY)
episodes = inspector.list_episodes()
print(f"Found {len(episodes)} episodes: {episodes}")

## Episode Selection and Loading

In [None]:
# Interactive episode selector: a dropdown plus a button that loads the
# chosen episode and prints its summary into a dedicated output area.
if not episodes:
    print("No episodes found. Make sure the data directory exists and contains episode data.")
else:
    episode_dropdown = widgets.Dropdown(
        options=episodes,
        value=episodes[0],
        description='Episode:',
        disabled=False,
    )

    load_button = widgets.Button(
        description='Load Episode',
        disabled=False,
        button_style='success',
        tooltip='Load the selected episode',
    )

    output = widgets.Output()

    def on_load_clicked(b):
        """Load the selected episode and print its summary statistics."""
        with output:
            output.clear_output()
            if not inspector.load_episode(episode_dropdown.value):
                return

            stats = inspector.get_trajectory_summary()
            print("\n=== Episode Summary ===")
            for name, value in stats.items():
                # Floats are shown with three decimals, both on their own
                # and inside the list of joint ranges.
                if name == 'joint_ranges':
                    shown = [f'{r:.3f}' for r in value]
                elif isinstance(value, float):
                    shown = f"{value:.3f}"
                else:
                    shown = value
                print(f"{name}: {shown}")

    load_button.on_click(on_load_clicked)

    display(widgets.HBox([episode_dropdown, load_button]))
    display(output)

## Trajectory Visualization

In [None]:
def plot_joint_trajectories():
    """Draw one stacked subplot per joint showing its position over time.

    Uses the episode currently loaded in the module-level ``inspector``.
    Prints a message and returns early when no episode is loaded or the
    episode has no trajectory points. Time is shown relative to the first
    sample's timestamp.
    """
    if not inspector.episode_data:
        print("No episode loaded. Please load an episode first.")
        return

    samples = inspector.episode_data.get('trajectory_data', [])
    if not samples:
        print("No trajectory data available.")
        return

    stamps = [sample['observations']['timestamp'] for sample in samples]
    joints = np.array([sample['observations']['joint_state'] for sample in samples])

    # Time axis measured from the first sample.
    origin = stamps[0]
    elapsed = [stamp - origin for stamp in stamps]

    joint_count = joints.shape[1]

    # squeeze=False always yields a 2-D grid of axes, so the single-joint
    # case needs no special handling.
    fig, axis_grid = plt.subplots(
        joint_count, 1, figsize=(12, 2*joint_count), sharex=True, squeeze=False
    )
    column = axis_grid[:, 0]

    for number, (axis, series) in enumerate(zip(column, joints.T), start=1):
        axis.plot(elapsed, series, linewidth=2, label=f'Joint {number}')
        axis.set_ylabel(f'Joint {number}\n(radians)')
        axis.grid(True, alpha=0.3)
        axis.legend()

    column[-1].set_xlabel('Time (seconds)')
    plt.suptitle(f'Joint Trajectories - Episode: {inspector.current_episode}', fontsize=14)
    plt.tight_layout()
    plt.show()

def plot_cartesian_trajectory():
    """Draw the recorded Cartesian positions as a 3D path.

    Uses the episode currently loaded in the module-level ``inspector``.
    Prints a message and returns early when no episode is loaded or the
    episode has no trajectory points. Segments are coloured along the
    viridis colormap from first to last; the first and last points are
    marked, and all three axes share the same limits.
    """
    if not inspector.episode_data:
        print("No episode loaded. Please load an episode first.")
        return

    samples = inspector.episode_data.get('trajectory_data', [])
    if not samples:
        print("No trajectory data available.")
        return

    # One row per sample, columns x / y / z.
    path = np.array([
        [location['x'], location['y'], location['z']]
        for location in (
            sample['observations']['cartesian_position']['position']
            for sample in samples
        )
    ])

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Each segment gets its own colour so progress along the path is visible.
    palette = plt.cm.viridis(np.linspace(0, 1, len(path)))
    for begin, finish, shade in zip(path[:-1], path[1:], palette):
        ax.plot([begin[0], finish[0]],
                [begin[1], finish[1]],
                [begin[2], finish[2]],
                color=shade, linewidth=2)

    first, last = path[0], path[-1]
    ax.scatter(first[0], first[1], first[2],
               color='green', s=100, label='Start', marker='o')
    ax.scatter(last[0], last[1], last[2],
               color='red', s=100, label='End', marker='s')

    ax.set_xlabel('X (meters)')
    ax.set_ylabel('Y (meters)')
    ax.set_zlabel('Z (meters)')
    ax.set_title(f'3D Cartesian Trajectory - Episode: {inspector.current_episode}')
    ax.legend()

    # Same limits on every axis: the overall min and max across x, y and z.
    upper = path.max()
    lower = path.min()
    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_limits([lower, upper])

    plt.show()

# Create buttons for plotting
joint_plot_button = widgets.Button(description='Plot Joint Trajectories', button_style='info')
cartesian_plot_button = widgets.Button(description='Plot 3D Cartesian Path', button_style='info')

# Prints and figures produced inside a button callback are captured in an
# Output widget so they are rendered below the buttons (same approach as the
# episode selector). Each click replaces the previous plot.
plot_output = widgets.Output()


def _show_plot(plot_function):
    """Run ``plot_function`` with its output routed into ``plot_output``."""
    with plot_output:
        plot_output.clear_output(wait=True)
        plot_function()


joint_plot_button.on_click(lambda b: _show_plot(plot_joint_trajectories))
cartesian_plot_button.on_click(lambda b: _show_plot(plot_cartesian_trajectory))

display(widgets.HBox([joint_plot_button, cartesian_plot_button]))
display(plot_output)

## Image Viewer

In [None]:
def create_image_viewer():
    """Display an interactive viewer for the RGB and depth images of an episode.

    Uses the episode currently loaded in the module-level ``inspector``.
    Prints a message and returns early when no episode is loaded or the
    episode has no trajectory points. Otherwise shows a frame slider, an
    image-type selector and an output area that is redrawn whenever either
    control changes.
    """
    if not inspector.episode_data:
        print("No episode loaded. Please load an episode first.")
        return

    trajectory_data = inspector.episode_data.get('trajectory_data', [])
    if not trajectory_data:
        print("No trajectory data available.")
        return

    # Create slider for frame selection
    frame_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(trajectory_data)-1,
        step=1,
        description='Frame:',
        continuous_update=False
    )

    # Create image type selector
    image_type = widgets.RadioButtons(
        options=['RGB', 'Depth', 'Both'],
        value='Both',
        description='Image Type:',
        disabled=False
    )

    # Output widget for images
    image_output = widgets.Output()

    def update_images(change=None):
        """Redraw the frame information and images for the current controls."""
        with image_output:
            image_output.clear_output(wait=True)

            frame_idx = frame_slider.value
            data_point = trajectory_data[frame_idx]
            observations = data_point['observations']

            timestamp = observations['timestamp']
            prompt = data_point.get('prompt', 'N/A')

            print(f"Frame {frame_idx}/{len(trajectory_data)-1}")
            print(f"Timestamp: {timestamp:.3f}")
            print(f"Prompt: {prompt}")
            print(f"Is First: {data_point.get('is_first', False)}")
            print(f"Is Last: {data_point.get('is_last', False)}")

            # Decode first and lay out the figure afterwards, so the number
            # of panels matches the images that are actually available.
            # Each panel is (title, image, colormap, colorbar label).
            selection = image_type.value
            panels = []

            rgb_hex = observations.get('rgb_image')
            if selection in ('RGB', 'Both') and rgb_hex:
                rgb_image = inspector.decode_image(rgb_hex, 'rgb')
                if rgb_image is not None:
                    panels.append((f'RGB Image - Frame {frame_idx}', rgb_image, None, None))

            depth_hex = observations.get('depth_image')
            if selection in ('Depth', 'Both') and depth_hex:
                depth_image = inspector.decode_image(depth_hex, 'depth')
                if depth_image is not None:
                    panels.append((f'Depth Image - Frame {frame_idx}', depth_image, 'viridis', 'Depth'))

            if not panels:
                # Avoid creating and showing an empty figure.
                print("No image data available for this frame.")
                return

            figsize = (15, 6) if len(panels) == 2 else (8, 6)
            fig, axis_grid = plt.subplots(1, len(panels), figsize=figsize, squeeze=False)
            for axis, (title, image, cmap, colorbar_label) in zip(axis_grid[0], panels):
                rendered = axis.imshow(image, cmap=cmap)
                axis.set_title(title)
                axis.axis('off')
                if colorbar_label:
                    fig.colorbar(rendered, ax=axis, label=colorbar_label)

            fig.tight_layout()
            plt.show()

    # Connect widgets to update function
    frame_slider.observe(update_images, names='value')
    image_type.observe(update_images, names='value')

    # Display widgets
    controls = widgets.VBox([frame_slider, image_type])
    display(controls)
    display(image_output)

    # Show initial images
    update_images()

# Create button to launch image viewer
image_viewer_button = widgets.Button(
    description='Launch Image Viewer',
    button_style='warning',
    tooltip='View RGB and depth images from the trajectory'
)

# The viewer is built inside an Output widget so that everything the callback
# displays is rendered below the button, and launching again replaces the
# previous viewer instead of stacking a new one underneath it.
image_viewer_output = widgets.Output()


def _launch_image_viewer(button):
    """Rebuild the image viewer inside ``image_viewer_output``."""
    with image_viewer_output:
        image_viewer_output.clear_output(wait=True)
        create_image_viewer()


image_viewer_button.on_click(_launch_image_viewer)
display(image_viewer_button)
display(image_viewer_output)

## Synchronization Analysis

In [None]:
def analyze_synchronization():
    """Print the synchronization statistics of the loaded episode and chart them.

    Uses the episode currently loaded in the module-level ``inspector``.
    Prints a message and returns early when no episode is loaded or it has
    no ``synchronization_stats``. Otherwise prints the tolerance and failure
    counts and draws one bar chart (mean / max / min discrepancy) for each
    of the RGB, depth and pose streams.
    """
    if not inspector.episode_data:
        print("No episode loaded. Please load an episode first.")
        return

    stats = inspector.episode_data.get('synchronization_stats', {})
    if not stats:
        print("No synchronization statistics available.")
        return

    print("=== Synchronization Statistics ===")
    print(f"Sync tolerance used: {stats.get('sync_tolerance_used', 'N/A')} seconds")
    print(f"Total failed syncs: {stats.get('total_failed_syncs', 0)}")

    failures = stats.get('failed_syncs_by_type', {})
    if failures:
        print("\nFailed syncs by type:")
        for stream, failed in failures.items():
            print(f"  {stream}: {failed}")

    # (stats key, name used in the title, name used in the "no data" note)
    streams = (
        ('rgb_sync', 'RGB', 'RGB'),
        ('depth_sync', 'Depth', 'depth'),
        ('pose_sync', 'Pose', 'pose'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    for axis, (key, title_name, note_name) in zip(axes, streams):
        stream_stats = stats.get(key, {})
        if stream_stats.get('count', 0) > 0:
            heights = [
                stream_stats.get('mean_discrepancy', 0),
                stream_stats.get('max_discrepancy', 0),
                stream_stats.get('min_discrepancy', 0),
            ]
            axis.bar(['Mean', 'Max', 'Min'], heights)
            axis.set_title(f'{title_name} Sync Discrepancies\n({stream_stats.get("count", 0)} syncs)')
            axis.set_ylabel('Discrepancy (seconds)')
        else:
            # No samples for this stream: show a centred note instead of bars.
            axis.text(0.5, 0.5, f'No {note_name} sync data',
                      ha='center', va='center', transform=axis.transAxes)
            axis.set_title(f'{title_name} Sync Discrepancies')

    plt.tight_layout()
    plt.show()

# Create button for synchronization analysis
sync_button = widgets.Button(
    description='Analyze Synchronization',
    button_style='success',
    tooltip='Analyze data synchronization statistics'
)

# Prints and figures produced inside the button callback are captured in an
# Output widget so they are rendered below the button (same approach as the
# episode selector). Each click replaces the previous report.
sync_output = widgets.Output()


def _run_sync_analysis(button):
    """Run the synchronization analysis inside ``sync_output``."""
    with sync_output:
        sync_output.clear_output(wait=True)
        analyze_synchronization()


sync_button.on_click(_run_sync_analysis)
display(sync_button)
display(sync_output)

## Data Export and Summary

In [None]:
def _trajectory_point_to_row(frame, point):
    """Flatten one trajectory point into a single flat row dictionary."""
    obs = point['observations']

    row = {
        'frame': frame,
        'timestamp': obs['timestamp'],
        'prompt': point.get('prompt', ''),
        'is_first': point.get('is_first', False),
        'is_last': point.get('is_last', False),
        'is_terminal': point.get('is_terminal', False),
        'reward': point.get('default_reward', 0.0),
    }

    # One column per joint / gripper value, numbered from 1.
    for j, pos in enumerate(obs['joint_state'], start=1):
        row[f'joint_{j}_pos'] = pos
    for j, pos in enumerate(obs['gripper_state'], start=1):
        row[f'gripper_{j}_pos'] = pos

    pose = obs['cartesian_position']
    cart_pos = pose['position']
    row['cart_x'] = cart_pos['x']
    row['cart_y'] = cart_pos['y']
    row['cart_z'] = cart_pos['z']

    cart_ori = pose['orientation']
    row['cart_qx'] = cart_ori['x']
    row['cart_qy'] = cart_ori['y']
    row['cart_qz'] = cart_ori['z']
    row['cart_qw'] = cart_ori['w']

    return row


def export_trajectory_summary():
    """Export the loaded episode's trajectory to a CSV file.

    Uses the episode currently loaded in the module-level ``inspector``.
    Writes ``trajectory_summary_<episode>.csv`` to the current working
    directory with one row per trajectory point, then prints the file name
    and shape and displays the first rows.

    The ``action`` entry of a point is not exported, so points without an
    ``action`` key are accepted.

    Returns:
        The exported DataFrame, or None when no episode is loaded or the
        episode has no trajectory points.
    """
    if not inspector.episode_data:
        print("No episode loaded. Please load an episode first.")
        return

    trajectory_data = inspector.episode_data.get('trajectory_data', [])
    if not trajectory_data:
        print("No trajectory data available.")
        return

    # Create DataFrame with key trajectory information
    df = pd.DataFrame(
        [_trajectory_point_to_row(i, point) for i, point in enumerate(trajectory_data)]
    )

    # Save to CSV
    output_file = f"trajectory_summary_{inspector.current_episode}.csv"
    df.to_csv(output_file, index=False)

    print(f"Exported trajectory summary to: {output_file}")
    print(f"DataFrame shape: {df.shape}")
    print("\nFirst few rows:")
    display(df.head())

    return df

def show_detailed_summary():
    """Print a detailed report about the loaded episode.

    Uses the episode currently loaded in the module-level ``inspector``.
    Prints a message and returns early when no episode is loaded. The report
    covers the episode header fields, the computed trajectory summary, the
    per-joint movement ranges and, when present, the image metadata.
    """
    episode = inspector.episode_data
    if not episode:
        print("No episode loaded. Please load an episode first.")
        return

    stats = inspector.get_trajectory_summary()

    # Header fields stored with the episode, then the computed statistics.
    print(f"=== Detailed Summary for Episode: {inspector.current_episode} ===")
    print(f"Input Message: {episode.get('input_message', 'N/A')}")
    print(f"Start Time: {episode.get('start_time', 'N/A')}")
    print(f"End Time: {episode.get('end_time', 'N/A')}")
    print(f"Duration: {stats.get('duration', 0):.3f} seconds")
    print(f"Total Data Points: {stats.get('total_points', 0)}")
    print(f"Average Frequency: {stats.get('frequency', 0):.2f} Hz")
    print(f"Unique Prompts: {stats.get('unique_prompts', 0)}")
    print(f"Prompts: {stats.get('prompts', [])}")
    print(f"Total Cartesian Distance: {stats.get('total_distance', 0):.4f} meters")

    spans = stats.get('joint_ranges', [])
    if spans:
        print("\nJoint Movement Ranges (radians):")
        for number, span in enumerate(spans, start=1):
            print(f"  Joint {number}: {span:.4f}")

    image_info = episode.get('metadata', {})
    if image_info:
        width = image_info.get('image_width', 'N/A')
        height = image_info.get('image_height', 'N/A')
        print("\nImage Metadata:")
        print(f"  Image Size: {width} x {height}")
        print(f"  RGB Encoding: {image_info.get('rgb_encoding', 'N/A')}")
        print(f"  Depth Encoding: {image_info.get('depth_encoding', 'N/A')}")

# Create buttons for export and summary
export_button = widgets.Button(
    description='Export to CSV',
    button_style='warning',
    tooltip='Export trajectory data to CSV file'
)

summary_button = widgets.Button(
    description='Show Detailed Summary',
    button_style='info',
    tooltip='Show detailed episode summary'
)

# Prints and tables produced inside the button callbacks are captured in an
# Output widget so they are rendered below the buttons (same approach as the
# episode selector). Each click replaces the previous result.
report_output = widgets.Output()


def _show_report(report_function):
    """Run ``report_function`` with its output routed into ``report_output``."""
    with report_output:
        report_output.clear_output(wait=True)
        report_function()


export_button.on_click(lambda b: _show_report(export_trajectory_summary))
summary_button.on_click(lambda b: _show_report(show_detailed_summary))

display(widgets.HBox([summary_button, export_button]))
display(report_output)

## Quick Analysis Cell

Use this cell for quick custom analysis and exploration:

In [None]:
# Quick analysis cell - customize as needed
# Works on the episode currently loaded in `inspector`: plots every gripper
# value over time, then lists the moments at which the prompt changes.
if not inspector.episode_data:
    print("Load an episode first to run custom analysis.")
else:
    trajectory_data = inspector.episode_data.get('trajectory_data', [])

    if trajectory_data:
        timestamps = [point['observations']['timestamp'] for point in trajectory_data]
        gripper_states = [point['observations']['gripper_state'] for point in trajectory_data]

        # Time axis measured from the first sample.
        start_time = timestamps[0]
        relative_times = [stamp - start_time for stamp in timestamps]

        # Example: Plot gripper state over time
        if gripper_states and len(gripper_states[0]) > 0:
            gripper_array = np.array(gripper_states)
            gripper_joint_count = gripper_array.shape[1]

            plt.figure(figsize=(12, 4))
            for joint_number in range(1, gripper_joint_count + 1):
                plt.plot(
                    relative_times,
                    gripper_array[:, joint_number - 1],
                    label=f'Gripper Joint {joint_number}',
                    linewidth=2,
                )

            plt.xlabel('Time (seconds)')
            plt.ylabel('Gripper Position')
            plt.title('Gripper State Over Time')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.show()

        # Example: Show prompt transitions
        prompts = [point.get('prompt', 'None') for point in trajectory_data]
        unique_prompts = []
        prompt_changes = []

        # Record a prompt (and the time it starts) whenever it differs from
        # the one on the previous sample.
        current_prompt = None
        for moment, prompt in zip(relative_times, prompts):
            if prompt == current_prompt:
                continue
            unique_prompts.append(prompt)
            prompt_changes.append(moment)
            current_prompt = prompt

        if len(unique_prompts) > 1:
            print("\nPrompt Transitions:")
            for prompt, moment in zip(unique_prompts, prompt_changes):
                print(f"  {moment:.2f}s: {prompt}")