In [2]:
# Location of the raw data files, relative to the notebook's working directory.
# NOTE(review): no visible cell reads this name; load_dataset() spells out
# 'data/raw' itself.
dataFilesAreIn = "data/raw"

import os
import sys


In [3]:
import os
import torch
import numpy as np
import pickle
import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mmrnet.dataset.mmrnet_data import MMRKeypointData
# Widget and display helpers used by create_browser() for the interactive sample slider

import ipywidgets as widgets
from IPython.display import display, clear_output

# Options handed to MMRKeypointData (as `mmr_dataset_config`) by load_dataset().
# The three split fractions add up to 1.0.
# NOTE(review): the exact meaning of 'stacks' and 'zero_padding' is defined by
# MMRKeypointData, which is not visible here - confirm against that class.
dataset_config = dict(
    seed=20,
    train_split=0.8,
    val_split=0.1,
    test_split=0.1,
    stacks=3,
    zero_padding='per_data_point',
)

# Load dataset only once
def load_dataset(forced_rewrite=True, partition='train'):
    """Build an ``MMRKeypointData`` dataset, printing diagnostics on failure.

    Args:
        forced_rewrite: Value stored under ``'forced_rewrite'`` in the config
            passed to ``MMRKeypointData``. Defaults to ``True`` (the value the
            notebook always used). Pass ``False`` after a first successful run;
            presumably the dataset class then reuses its processed cache
            instead of re-processing - confirm against ``MMRKeypointData``.
        partition: Partition name forwarded to ``MMRKeypointData``
            (``'train'`` by default, as before).

    Returns:
        The constructed dataset, or ``None`` if construction raised. In the
        failure case the error and a listing of ``data/raw`` are printed so a
        missing-data problem is easy to spot.
    """
    print("Creating dataset...")

    # Make sure the directory for the processed pickle exists.
    os.makedirs(os.path.dirname('data/processed/mmr_kp/data.pkl'), exist_ok=True)

    # Work on a copy so the module-level dataset_config is never mutated.
    config = dataset_config.copy()
    config['forced_rewrite'] = forced_rewrite

    try:
        dataset = MMRKeypointData(
            root='data/mmr_kp',
            partition=partition,
            mmr_dataset_config=config
        )
        print(f"Dataset loaded successfully with {len(dataset)} samples")
        return dataset
    except Exception as e:
        # Deliberately broad: this is a best-effort notebook helper that reports
        # the problem and returns None rather than aborting the cell.
        print(f"Error loading dataset: {e}")

        # Show what is (or is not) in the raw data directory to aid debugging.
        raw_path = 'data/raw'
        if os.path.exists(raw_path):
            print(f"\nRaw data directory exists. Contents of {raw_path}:")
            files = os.listdir(raw_path)
            for f in files[:10]:
                print(f"  - {f}")
            if len(files) > 10:
                print(f"  - ... and {len(files)-10} more files")
        else:
            print(f"\nRaw data directory {raw_path} does not exist")

        return None

# Visualize a sample from the dataset
def visualize_sample(dataset, idx=None):
    """Plot one sample: point cloud, keypoints and a stick-figure skeleton in 3D.

    Args:
        dataset: Indexable collection whose items unpack into
            ``(point_cloud, keypoints)``. Both are indexed as 2-D arrays whose
            first three columns are used as x, y, z. Torch tensors are
            converted to NumPy first.
        idx: Index of the sample to show. When ``None`` a random index is drawn.

    Returns:
        ``(point_cloud, keypoints)`` as NumPy arrays, or ``None`` when
        ``dataset`` is ``None`` or empty (a message is printed instead).
    """
    if dataset is None:
        print("No dataset available to visualize")
        return
    if len(dataset) == 0:
        # random.randint(0, -1) would raise ValueError on an empty dataset.
        print("Dataset is empty; nothing to visualize")
        return

    # Pick a random sample if no index was given.
    if idx is None:
        idx = random.randint(0, len(dataset) - 1)

    point_cloud, keypoints = dataset[idx]
    point_cloud = _sample_to_numpy(point_cloud)
    keypoints = _sample_to_numpy(keypoints)

    print(f"Sample index: {idx}")
    print(f"Point cloud shape: {point_cloud.shape}")
    print(f"Keypoints shape: {keypoints.shape}")

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(
        point_cloud[:, 0],
        point_cloud[:, 1],
        point_cloud[:, 2],
        c='blue',
        s=5,
        alpha=0.5,
        label='Point Cloud'
    )

    ax.scatter(
        keypoints[:, 0],
        keypoints[:, 1],
        keypoints[:, 2],
        c='red',
        s=100,
        marker='*',
        label='Keypoints'
    )

    # Skeleton edges as pairs of keypoint row indices. The joint names in the
    # comments are the original author's; the ordering presumably follows
    # MMRKeypointData.kp9_names - TODO confirm.
    kp_connections = [
        (0, 1),  # RIGHT_SHOULDER to RIGHT_ELBOW
        (2, 3),  # LEFT_SHOULDER to LEFT_ELBOW
        (4, 5),  # RIGHT_HIP to RIGHT_KNEE
        (6, 7),  # LEFT_HIP to LEFT_KNEE
        (0, 4),  # RIGHT_SHOULDER to RIGHT_HIP
        (2, 6),  # LEFT_SHOULDER to LEFT_HIP
        (0, 2),  # RIGHT_SHOULDER to LEFT_SHOULDER
        (4, 6),  # RIGHT_HIP to LEFT_HIP
        (8, 0),  # HEAD to RIGHT_SHOULDER
        (8, 2),  # HEAD to LEFT_SHOULDER
    ]

    for start, end in kp_connections:
        ax.plot(
            [keypoints[start, 0], keypoints[end, 0]],
            [keypoints[start, 1], keypoints[end, 1]],
            [keypoints[start, 2], keypoints[end, 2]],
            color='green',
            linestyle='-',
            linewidth=2,
            alpha=0.7
        )

    # Label each keypoint with its name (underscores shown as spaces).
    for name, point in zip(MMRKeypointData.kp9_names, keypoints):
        ax.text(
            point[0],
            point[1],
            point[2],
            name.replace('_', ' '),
            fontsize=8,
            color='black'
        )

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(f'Point Cloud and Keypoints (Sample {idx})')
    ax.legend()

    # Same span on every axis so the body is not distorted; the limits cover
    # the keypoints as well as the cloud.
    x_lim, y_lim, z_lim = _cubic_limits(point_cloud, keypoints)
    ax.set_xlim(*x_lim)
    ax.set_ylim(*y_lim)
    ax.set_zlim(*z_lim)

    plt.tight_layout()
    plt.show()

    return point_cloud, keypoints


def _sample_to_numpy(values):
    """Return ``values`` as a NumPy array; tensors are detached and moved to CPU first."""
    if isinstance(values, torch.Tensor):
        # Plain .numpy() raises for CUDA tensors and tensors that require grad.
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _cubic_limits(*clouds):
    """Return [(low, high)] * 3 for x, y, z: an equal-sided box around all ``clouds``."""
    xyz = np.concatenate([np.asarray(cloud)[:, :3] for cloud in clouds], axis=0)
    lows = xyz.min(axis=0)
    highs = xyz.max(axis=0)
    centers = (highs + lows) * 0.5
    half_span = float((highs - lows).max()) / 2.0
    if half_span == 0:
        # All points coincide; fall back to a unit half-span so low != high.
        half_span = 1.0
    return [(float(c) - half_span, float(c) + half_span) for c in centers]

def create_browser(dataset):
    """Show a slider that lets the user step through dataset samples interactively.

    The slider stays in place above the plot; each change re-renders the
    selected sample (via ``visualize_sample``) inside a dedicated output area.

    Args:
        dataset: Dataset accepted by ``visualize_sample``. ``None`` or an empty
            dataset only prints a message.

    Returns:
        None. The widgets are displayed as a side effect.
    """
    if dataset is None:
        print("No dataset available to browse")
        return
    if len(dataset) == 0:
        # A slider with max = -1 < min = 0 cannot be constructed.
        print("Dataset is empty; nothing to browse")
        return

    slider = widgets.IntSlider(
        min=0,
        max=len(dataset) - 1,
        step=1,
        value=0,
        description='Sample:',
        # Only fire once the user releases the slider, not on every step.
        continuous_update=False,
        layout=widgets.Layout(width='600px')
    )

    # Render into an Output widget so only the plot is cleared on each change.
    # Clearing the whole cell and re-displaying the slider moved it from above
    # the plot to below it and created a new slider view every time.
    plot_area = widgets.Output()

    def render(idx):
        with plot_area:
            clear_output(wait=True)
            visualize_sample(dataset, idx)

    def update_plot(change):
        render(change.new)

    slider.observe(update_plot, names='value')

    display(widgets.VBox([slider, plot_area]))
    render(slider.value)





In [4]:
# Build the dataset once for the rest of the notebook; load_dataset() returns
# None if construction fails. The run recorded below spent roughly 2.5 minutes
# in preprocessing and produced 436047 samples.
dataset = load_dataset()


Creating dataset...
Transforming keypoints ...


100%|██████████| 545059/545059 [00:12<00:00, 44344.04it/s]


Transforming keypoints done
Stacking and padding frames...


100%|██████████| 545059/545059 [02:33<00:00, 3555.64it/s]


Stacking and padding frames done
Dataset loaded successfully with 436047 samples


In [5]:
# Animation helper: renders a run of consecutive dataset samples as a matplotlib animation.

def create_animation(dataset, start_idx=0, num_frames=10, interval=200, save_path=None):
    """
    Create an animation of a sequence of samples on the same axis.

    Each frame clears the axes and redraws the point cloud, the keypoints and
    the skeleton of one sample, using axis limits computed once over the whole
    sequence so the view does not jump between frames.

    Args:
        dataset: Indexable collection whose items unpack into
            ``(point_cloud, keypoints)``; the first three columns of each are
            used as x, y, z. Torch tensors are converted to NumPy.
        start_idx: The starting index
        num_frames: Number of consecutive frames to animate (must be > 0).
            The sequence is truncated at the end of the dataset.
        interval: Time between frames in milliseconds
        save_path: Path to save the animation (optional); written with the
            'pillow' writer.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object, or ``None`` when the
        arguments are invalid (a message is printed). Keep a reference to the
        returned object for as long as the animation should stay alive.
    """
    if dataset is None:
        print("No dataset available to visualize")
        return

    total_samples = len(dataset)
    if total_samples == 0:
        print("Dataset is empty; nothing to animate")
        return
    if start_idx < 0 or start_idx >= total_samples:
        print(f"Start index out of range. Valid range: 0-{total_samples-1}")
        return
    if num_frames <= 0:
        # An empty frame list would crash in np.concatenate below.
        print("num_frames must be a positive integer")
        return

    # Imported locally on purpose: the notebook's last cell rebinds the global
    # name `animation` to the object returned here, which would clobber a
    # module-level import of the same name.
    import matplotlib.animation as animation

    end_idx = min(start_idx + num_frames, total_samples)
    frame_indices = list(range(start_idx, end_idx))

    print(f"Creating animation for samples {start_idx} to {end_idx-1}")

    # Pre-load every frame as a pair of NumPy arrays.
    frames = []
    for idx in frame_indices:
        point_cloud, keypoints = dataset[idx]
        frames.append(tuple(
            # Plain .numpy() raises for CUDA tensors and tensors requiring grad.
            part.detach().cpu().numpy() if isinstance(part, torch.Tensor) else np.asarray(part)
            for part in (point_cloud, keypoints)
        ))

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Skeleton edges as pairs of keypoint row indices. The joint names in the
    # comments are the original author's; the ordering presumably follows
    # MMRKeypointData.kp9_names - TODO confirm.
    kp_connections = [
        (0, 1),  # RIGHT_SHOULDER to RIGHT_ELBOW
        (2, 3),  # LEFT_SHOULDER to LEFT_ELBOW
        (4, 5),  # RIGHT_HIP to RIGHT_KNEE
        (6, 7),  # LEFT_HIP to LEFT_KNEE
        (0, 4),  # RIGHT_SHOULDER to RIGHT_HIP
        (2, 6),  # LEFT_SHOULDER to LEFT_HIP
        (0, 2),  # RIGHT_SHOULDER to LEFT_SHOULDER
        (4, 6),  # RIGHT_HIP to LEFT_HIP
        (8, 0),  # HEAD to RIGHT_SHOULDER
        (8, 2),  # HEAD to LEFT_SHOULDER
    ]

    # Global, equal-sided bounds over all frames (clouds and keypoints) keep
    # the view consistent. Only the xyz columns are considered.
    all_xyz = np.concatenate(
        [arr[:, :3] for frame in frames for arr in frame], axis=0)
    lows = all_xyz.min(axis=0)
    highs = all_xyz.max(axis=0)
    half_span = float((highs - lows).max()) / 2.0
    if half_span == 0:
        # All points coincide; avoid identical low/high limits.
        half_span = 1.0
    x_lim, y_lim, z_lim = [
        (float(c) - half_span, float(c) + half_span) for c in (highs + lows) / 2.0
    ]

    def update(frame_num):
        """Redraw the axes from scratch for frame ``frame_num``."""
        ax.clear()

        point_cloud, keypoints = frames[frame_num]

        ax.scatter(point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2],
                   c='blue', s=5, alpha=0.3, label='Point Cloud')
        ax.scatter(keypoints[:, 0], keypoints[:, 1], keypoints[:, 2],
                   c='red', s=100, marker='*', label='Keypoints')

        for start, end in kp_connections:
            ax.plot([keypoints[start, 0], keypoints[end, 0]],
                    [keypoints[start, 1], keypoints[end, 1]],
                    [keypoints[start, 2], keypoints[end, 2]],
                    color='green', linestyle='-', linewidth=2, alpha=0.7)

        ax.text2D(0.02, 0.98, f"Frame: {frame_indices[frame_num]}",
                  transform=ax.transAxes, fontsize=14)

        # ax.clear() resets limits, labels and legend, so restore them.
        ax.set_xlim(*x_lim)
        ax.set_ylim(*y_lim)
        ax.set_zlim(*z_lim)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title('Point Cloud and Keypoints Animation')
        ax.legend()

        # Nothing to hand back: blitting is disabled.
        return []

    # No init_func: update() rebuilds the whole axes, so the first frame
    # doubles as the initial draw.
    ani = animation.FuncAnimation(
        fig, update, frames=len(frame_indices),
        interval=interval, blit=False)

    # Apply the layout before saving so the written file gets it too.
    fig.tight_layout()

    if save_path:
        # Clamp so interval == 0 cannot divide by zero and interval > 1000
        # cannot yield an fps of 0.
        fps = max(1, 1000 // max(1, int(interval)))
        ani.save(save_path, writer='pillow', fps=fps)
        print(f"Animation saved to {save_path}")

    # Close only this figure so the notebook does not also render it statically.
    plt.close(fig)

    print(f"Animation created with {len(frame_indices)} frames")
    print(f"To view additional sequences: create_animation(dataset, start_idx={end_idx}, num_frames={num_frames})")

    return ani

In [9]:
# Number of samples in the dataset returned by load_dataset().
len(dataset)

436047

In [None]:
# Animate a randomly chosen run of consecutive samples and save it as a GIF.
# NOTE: the start index is not seeded, so every run animates a different sequence.
NUM_FRAMES = 30

# Keep the start far enough from the end that a full NUM_FRAMES-long clip fits;
# create_animation() would otherwise silently truncate the sequence.
# The misspelled name is kept in case other cells reference it.
ramdonnumber = random.randint(0, max(0, len(dataset) - NUM_FRAMES))
animation = create_animation(dataset, start_idx=ramdonnumber, num_frames=NUM_FRAMES, interval=200,
                            save_path="point_cloud_animation.gif")


Creating animation for samples 291141 to 291150
Animation saved to point_cloud_animation.gif
Animation created with 10 frames
To view additional sequences: create_animation(dataset, start_idx=291151, num_frames=10)
