In [32]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
# pip install plotly numpy pandas nbformat
# ---------------------------------------------------------
# 1. DEFINE SKELETON CONNECTIVITY
# ---------------------------------------------------------

def _hand_connections(side):
    """Return the bone segments of one hand, listed thumb first, pinky last.

    Args:
        side: Joint-name prefix, 'left' or 'right'.

    Every finger is a chain rooted at '<side>_wrist'. The thumb ('first') runs
    cmc -> mcp -> ip -> distal; the index, middle, ring and pinky fingers
    ('second' .. 'fifth') run mcp -> pip -> dip -> distal. That gives four
    segments per finger, 20 in total.
    """
    finger_layout = [('first', ('cmc', 'mcp', 'ip', 'distal'))]
    finger_layout += [(ordinal, ('mcp', 'pip', 'dip', 'distal'))
                      for ordinal in ('second', 'third', 'fourth', 'fifth')]
    segments = []
    for ordinal, knuckles in finger_layout:
        chain = [f'{side}_wrist'] + [f'{side}_{ordinal}_finger_{k}' for k in knuckles]
        # Consecutive joints along the chain form one bone each.
        segments.extend(zip(chain[:-1], chain[1:]))
    return segments


# Bones drawn between joints, as (start_joint, end_joint) name pairs.
# Names are the column prefixes of the data ("<joint>_x", "<joint>_y", "<joint>_z").
SKELETON_CONNECTIONS = [
    # TORSO
    ('neck', 'nose'), ('neck', 'mid_hip'),
    ('mid_hip', 'left_hip'), ('mid_hip', 'right_hip'),
    ('left_shoulder', 'right_shoulder'),
    ('neck', 'left_shoulder'), ('neck', 'right_shoulder'),

    # HEAD
    ('nose', 'left_eye'), ('nose', 'right_eye'),
    ('left_eye', 'left_ear'), ('right_eye', 'right_ear'),

    # ARMS
    ('left_shoulder', 'left_elbow'), ('left_elbow', 'left_wrist'),
    ('right_shoulder', 'right_elbow'), ('right_elbow', 'right_wrist'),

    # LEGS (hip -> knee -> ankle, then the three foot points off the ankle)
    ('left_hip', 'left_knee'), ('left_knee', 'left_ankle'),
    ('left_ankle', 'left_heel'), ('left_ankle', 'left_big_toe'), ('left_ankle', 'left_small_toe'),
    ('right_hip', 'right_knee'), ('right_knee', 'right_ankle'),
    ('right_ankle', 'right_heel'), ('right_ankle', 'right_big_toe'), ('right_ankle', 'right_small_toe'),

    # HANDS (detailed): left hand first, then its mirror on the right
    *_hand_connections('left'),
    *_hand_connections('right'),
]

def parse_coordinate(val):
    """
    Convert one keypoint coordinate cell into a 1-D float array (one value per frame).

    String cells may wrap the values in square brackets and separate them with
    commas and/or any whitespace (spaces, tabs, newlines), e.g. "[1.0, 2.0]" or
    "[1.0 2.0\n 3.0]". Any other value (list, tuple, array, scalar) is converted
    with ``np.asarray(..., dtype=float)``.

    Raises:
        ValueError: If a string cell contains a token that is not a number.
            ``np.fromstring`` would instead stop at the first bad token and
            silently return a truncated array, corrupting the frame count.
    """
    if isinstance(val, str):
        # Brackets and commas become whitespace; split() then handles any run of
        # spaces/tabs/newlines and ignores leading/trailing separators.
        tokens = val.replace('[', ' ').replace(']', ' ').replace(',', ' ').split()
        return np.array(tokens, dtype=float)
    return np.asarray(val, dtype=float)

def _parse_motion(row_data, joints, n_frames):
    """Parse every joint's X/Y/Z columns into one (n_frames, n_joints, 3) float array.

    A joint whose columns are missing, or whose values cannot be parsed or
    broadcast to ``n_frames`` entries, is reported with a printed warning and left
    entirely NaN. A NaN joint is neither drawn nor counted in the axis limits,
    whereas a zero-filled one would sit at the origin and stretch the axes.
    """
    motion = np.full((n_frames, len(joints), 3), np.nan)
    for idx, name in enumerate(joints):
        # Fill a scratch buffer first so a failure on Y or Z cannot leave a
        # half-written joint behind.
        joint_xyz = np.empty((n_frames, 3))
        try:
            for axis_idx, axis in enumerate('xyz'):
                joint_xyz[:, axis_idx] = parse_coordinate(row_data[f"{name}_{axis}"])
        except KeyError:
            print(f"Warning: Missing data for {name}")
            continue
        except ValueError as exc:
            print(f"Warning: Could not parse data for {name}: {exc}")
            continue
        motion[:, idx, :] = joint_xyz
    return motion


def _horizontal_ranges(motion_data, padding=1.0):
    """Return square, padded X and Y axis ranges centred on the whole motion.

    Both axes get the same half-width (half the wider of the X/Y extents plus
    ``padding``), so horizontal distances are not stretched. NaN entries are ignored.
    """
    x_lo, x_hi = np.nanmin(motion_data[:, :, 0]), np.nanmax(motion_data[:, :, 0])
    y_lo, y_hi = np.nanmin(motion_data[:, :, 1]), np.nanmax(motion_data[:, :, 1])
    half_width = max(x_hi - x_lo, y_hi - y_lo) / 2.0 + padding
    mid_x, mid_y = (x_lo + x_hi) / 2, (y_lo + y_hi) / 2
    return ([mid_x - half_width, mid_x + half_width],
            [mid_y - half_width, mid_y + half_width])


def _vertical_range(motion_data, joint_to_idx, headroom=2.0):
    """Return the fixed Z range: lowest heel point to highest wrist point + ``headroom``.

    Falls back to the global Z extent (plus ``headroom``) when a wrist or heel
    joint is absent or has no finite values. NaN entries are ignored.
    """
    z = motion_data[:, :, 2]
    anchors = ('left_wrist', 'right_wrist', 'left_heel', 'right_heel')
    if all(name in joint_to_idx for name in anchors):
        wrists_z = z[:, [joint_to_idx['left_wrist'], joint_to_idx['right_wrist']]]
        heels_z = z[:, [joint_to_idx['left_heel'], joint_to_idx['right_heel']]]
        if np.isfinite(wrists_z).any() and np.isfinite(heels_z).any():
            # Floor: lowest heel; ceiling: highest wrist plus headroom.
            return [np.nanmin(heels_z), np.nanmax(wrists_z) + headroom]
    print("Specific joints not found for scaling, using global bounds.")
    return [np.nanmin(z), np.nanmax(z) + headroom]


def _skeleton_segments(points, connections):
    """Flatten bones into x/y/z lists for a single Plotly ``lines`` trace.

    Args:
        points: One frame of the motion array, indexed ``points[joint_index, axis]``
            with axis 0/1/2 = X/Y/Z.
        connections: Iterable of ``(start_index, end_index)`` joint-index pairs.

    Returns:
        ``(xs, ys, zs)`` lists in which every bone contributes ``start, end, None``.
        The ``None`` breaks the polyline, so all bones fit in one trace.
    """
    xs, ys, zs = [], [], []
    for start, end in connections:
        xs.extend([points[start, 0], points[end, 0], None])
        ys.extend([points[start, 1], points[end, 1], None])
        zs.extend([points[start, 2], points[end, 2], None])
    return xs, ys, zs


def _skeleton_traces(points, connections, joints):
    """Build the two traces for one frame: [0] joint markers, [1] skeleton lines."""
    xs, ys, zs = _skeleton_segments(points, connections)
    return [
        go.Scatter3d(x=points[:, 0], y=points[:, 1], z=points[:, 2], mode='markers',
                     text=joints, hoverinfo='text',
                     marker=dict(size=4, color='blue'), name='Joints'),
        go.Scatter3d(x=xs, y=ys, z=zs, mode='lines',
                     line=dict(color='red', width=2), name='Skeleton'),
    ]


def visualize_shot(row_data, fps=60, frame_step=1):
    """
    Visualizes a single shot (row) from the dataset as an animated 3D skeleton.

    Joints are discovered from column names ending in ``_x``. Each joint is read
    from its ``_x``/``_y``/``_z`` columns via ``parse_coordinate``. The frame count
    is taken from the first joint's X column.

    Args:
        row_data: A pandas Series or dictionary containing the columns.
        fps: Frames per second of the recording. Each animation frame is shown
            for ``1000 * frame_step / fps`` milliseconds.
        frame_step: Animate every ``frame_step``-th frame (e.g. 2 for smoother
            playback on slow machines). The slider lists exactly the animated frames.

    Raises:
        ValueError: If ``frame_step`` < 1, no ``*_x`` columns exist, or no finite
            coordinate could be parsed (e.g. zero frames).
    """
    if frame_step < 1:
        raise ValueError(f"frame_step must be a positive integer, got {frame_step!r}")

    # 1. Identify all unique joints from column names ending in '_x'.
    joints = sorted({key[:-2] for key in row_data.keys() if key.endswith('_x')})
    if not joints:
        raise ValueError("row_data has no '<joint>_x' columns to visualize.")
    joint_to_idx = {name: i for i, name in enumerate(joints)}

    # 2. Extract data into a structured format: (n_frames, n_joints, 3).
    n_frames = len(parse_coordinate(row_data[f"{joints[0]}_x"]))
    print(f'Number of frames:{n_frames}')
    print(f"Parsing data for {len(joints)} joints over {n_frames} frames...")
    motion_data = _parse_motion(row_data, joints, n_frames)
    if not np.isfinite(motion_data).any():
        raise ValueError("No finite joint coordinates could be parsed from row_data.")

    # 3. Bones as (start, end) joint indices; bones touching an absent joint are skipped.
    connections_indices = [
        (joint_to_idx[start], joint_to_idx[end])
        for start, end in SKELETON_CONNECTIONS
        if start in joint_to_idx and end in joint_to_idx
    ]

    # 4. FIXED AXIS CALCULATION (prevents bouncing): limits span the whole shot.
    x_range, y_range = _horizontal_ranges(motion_data)
    z_range = _vertical_range(motion_data, joint_to_idx)
    # With aspectmode='manual', a 1:1:1 box would stretch Z whenever its span differs
    # from the shared X/Y span. Scale the Z side so a unit is equally long on every axis.
    # The X/Y span is always >= 2 * padding, so the division is safe.
    z_aspect = abs(z_range[1] - z_range[0]) / (x_range[1] - x_range[0]) or 1.0

    # 5. Animation frames. With frame_step > 1 each shown frame stands for
    # frame_step captured frames, so its duration scales to keep real-time speed.
    frame_indices = range(0, n_frames, frame_step)
    frame_duration_ms = 1000 * frame_step / fps
    frames = [
        go.Frame(data=_skeleton_traces(motion_data[t], connections_indices, joints),
                 name=str(t))
        for t in frame_indices
    ]

    fig = go.Figure(
        data=_skeleton_traces(motion_data[0], connections_indices, joints),
        layout=go.Layout(
            title="Skeletal Motion",
            scene=dict(
                xaxis=dict(range=x_range, title='X', autorange=False),
                yaxis=dict(range=y_range, title='Y', autorange=False),
                zaxis=dict(range=z_range, title='Z', autorange=False),
                aspectmode='manual',  # keep the box shape fixed across frames
                aspectratio=dict(x=1, y=1, z=z_aspect),
            ),
            updatemenus=[{
                'type': 'buttons',
                'buttons': [{
                    'label': 'Play',
                    'method': 'animate',
                    'args': [None, {
                        'frame': {'duration': frame_duration_ms, 'redraw': True},
                        'fromcurrent': True,
                        'transition': {'duration': 0}
                    }]
                }, {
                    'label': 'Pause',
                    'method': 'animate',
                    'args': [[], {
                        'frame': {'duration': 0, 'redraw': False},
                        'mode': 'immediate',
                        'transition': {'duration': 0}
                    }]
                }]
            }],
            sliders=[{
                'currentvalue': {'prefix': 'Frame: '},
                # Only frames that were actually built, so every step resolves.
                'steps': [
                    {
                        'method': 'animate',
                        'label': str(k),
                        'args': [[str(k)], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate'}]
                    } for k in frame_indices
                ]
            }]
        ),
        frames=frames
    )

    fig.show()


In [22]:

# ---------------------------------------------------------
# RUNNING ON 'train.csv'
# ---------------------------------------------------------

# 1. Load the dataset
# Set N_ROWS to a small number (e.g. 10) for a quick test load;
# None reads the whole file.
N_ROWS = None
print("Loading train.csv...")
df = pd.read_csv('data/train.csv', nrows=N_ROWS)

Loading train.csv...


In [33]:


# 2. Select a specific shot (row)
# Change shot_index (a row position in df, used with .iloc) to view a different shot.
shot_index = 1
row_data = df.iloc[shot_index]

print(f"Visualizing shot at index {shot_index}...")

# 3. Call the visualizer
visualize_shot(row_data)

Visualizing shot at index 1...
Number of frames:240
Parsing data for 69 joints over 240 frames...
