In [11]:
import os
import cv2
import numpy as np
import mediapipe as mp
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import math
import torch

In [12]:
# Only every FRAME_SKIP-th frame of each video is run through pose estimation (1 = every frame).
FRAME_SKIP = 3

# Source videos (*.mp4) are read from VIDEO_DIR; one graph-data pickle per video is written to OUTPUT_DIR.
VIDEO_DIR = "SportsData/"
OUTPUT_DIR = "GraphData/"

# Make sure the destination exists before any cell tries to write into it (no-op if already present).
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [13]:
# MediaPipe Pose estimator, created once and reused for every frame of every video below.
# static_image_mode=False runs it in video/tracking mode with landmark smoothing enabled;
# model_complexity=2 selects the most accurate (and slowest) model variant.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
pose = mp_pose.Pose(
    model_complexity=2,
    static_image_mode=False,
    smooth_landmarks=True,
    min_tracking_confidence=0.5,
    min_detection_confidence=0.5,
)

In [14]:
# Skeleton edges used when building the pose graph: each entry is a pair of
# mp_pose.PoseLandmark members, grouped as torso, arms and legs.
# NOTE(review): every landmark becomes a graph node, but only the landmarks named
# here receive edges, so the remaining points (e.g. face, finger and foot points)
# end up as isolated nodes — confirm that is intended for the GNN.
POSE_CONNECTIONS = [
    (mp_pose.PoseLandmark[start_name], mp_pose.PoseLandmark[end_name])
    for start_name, end_name in (
        # Torso
        ("LEFT_SHOULDER", "RIGHT_SHOULDER"),
        ("LEFT_SHOULDER", "LEFT_HIP"),
        ("RIGHT_SHOULDER", "RIGHT_HIP"),
        ("LEFT_HIP", "RIGHT_HIP"),
        # Arms
        ("LEFT_SHOULDER", "LEFT_ELBOW"),
        ("RIGHT_SHOULDER", "RIGHT_ELBOW"),
        ("LEFT_ELBOW", "LEFT_WRIST"),
        ("RIGHT_ELBOW", "RIGHT_WRIST"),
        # Legs
        ("LEFT_HIP", "LEFT_KNEE"),
        ("RIGHT_HIP", "RIGHT_KNEE"),
        ("LEFT_KNEE", "LEFT_ANKLE"),
        ("RIGHT_KNEE", "RIGHT_ANKLE"),
    )
]

In [15]:
# Function to calculate angle between three points
def calculate_angle(a, b, c):
    """
    Calculate the angle ABC, i.e. the angle at vertex b between rays b->a and b->c.

    Args:
        a: first point, array-like of coordinates (e.g. [x, y, z])
        b: middle point (vertex of the angle), same length as a
        c: end point, same length as a
    Returns:
        Angle in degrees in [0, 180]. Returns NaN when a or c coincides with b,
        because the angle is undefined for a zero-length vector.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    c = np.asarray(c, dtype=float)

    # Vectors from the vertex to the two outer points
    ba = a - b
    bc = c - b

    norm_product = np.linalg.norm(ba) * np.linalg.norm(bc)
    if norm_product == 0:
        # Degenerate input: make the NaN explicit instead of relying on a 0/0
        # division (which also emits a RuntimeWarning).
        return float("nan")

    cosine_angle = np.dot(ba, bc) / norm_product
    # Floating-point error can push the cosine slightly outside arccos' domain.
    cosine_angle = np.clip(cosine_angle, -1.0, 1.0)

    return np.degrees(np.arccos(cosine_angle))

In [16]:
# Define the key angles to calculate
def calculate_key_angles(landmarks):
    """
    Compute eight joint angles (in degrees) from a sequence of pose landmarks.

    Each angle is taken at a vertex landmark between two neighbouring
    landmarks, using their (x, y, z) coordinates.

    Args:
        landmarks: indexable sequence of MediaPipe landmarks exposing .x, .y, .z,
            indexed by mp_pose.PoseLandmark values.
    Returns:
        Dict with keys 'left_elbow', 'right_elbow', 'left_shoulder',
        'right_shoulder', 'left_knee', 'right_knee', 'left_hip', 'right_hip'
        (in that insertion order), or an empty dict when the sequence is too
        short to contain every landmark these angles need.
    """
    # NOTE(review): if these are normalized image coordinates, x and y are scaled
    # by different dimensions on non-square frames, which skews the angles — confirm.
    lm = mp_pose.PoseLandmark
    count = len(landmarks)

    # Coordinates for every enum member whose index exists in `landmarks`.
    coords = {
        member: [landmarks[member.value].x, landmarks[member.value].y, landmarks[member.value].z]
        for member in lm
        if member.value < count
    }

    # (output key, first point, vertex, end point)
    angle_specs = (
        ('left_elbow', lm.LEFT_SHOULDER, lm.LEFT_ELBOW, lm.LEFT_WRIST),
        ('right_elbow', lm.RIGHT_SHOULDER, lm.RIGHT_ELBOW, lm.RIGHT_WRIST),
        ('left_shoulder', lm.LEFT_HIP, lm.LEFT_SHOULDER, lm.LEFT_ELBOW),
        ('right_shoulder', lm.RIGHT_HIP, lm.RIGHT_SHOULDER, lm.RIGHT_ELBOW),
        ('left_knee', lm.LEFT_HIP, lm.LEFT_KNEE, lm.LEFT_ANKLE),
        ('right_knee', lm.RIGHT_HIP, lm.RIGHT_KNEE, lm.RIGHT_ANKLE),
        ('left_hip', lm.LEFT_SHOULDER, lm.LEFT_HIP, lm.LEFT_KNEE),
        ('right_hip', lm.RIGHT_SHOULDER, lm.RIGHT_HIP, lm.RIGHT_KNEE),
    )

    # Give up (empty result) if any landmark referenced by the specs is absent.
    needed = {point for _, *points in angle_specs for point in points}
    if any(point not in coords for point in needed):
        return {}

    return {
        name: calculate_angle(coords[first], coords[vertex], coords[last])
        for name, first, vertex, last in angle_specs
    }

In [17]:
# Function to create a graph from landmarks
def create_graph_from_landmarks(landmarks, angles, time_point, source_video, label):
    """
    Build a pose graph and its tensor features for a single frame.

    Nodes are landmark indices carrying x/y/z/visibility attributes. Edges
    follow POSE_CONNECTIONS (only when both endpoints exist) and are weighted by
    the Euclidean distance between the endpoints' (x, y, z) coordinates.

    Args:
        landmarks: sequence of landmarks exposing .x, .y, .z, .visibility
        angles: dict of joint angles in degrees (may be empty)
        time_point: time of the frame in the video
        source_video: name of the source video
        label: 1 for hit, 0 for miss (-1 when unknown)
    Returns:
        dict with:
          'graph'         networkx.Graph described above
          'node_features' float tensor [num_nodes, 4] = x, y, z, visibility
          'edge_index'    long tensor [2, E]; each undirected edge appears in both directions
          'edge_attr'     float tensor [E, 1] of edge distances
          'angles'        float tensor of angle values ordered by sorted key,
                          or 8 zeros when `angles` is empty
          'time', 'source_video', 'label'  passed through unchanged
    """
    n_landmarks = len(landmarks)
    graph = nx.Graph()

    for node_id, point in enumerate(landmarks):
        graph.add_node(node_id, x=point.x, y=point.y, z=point.z, visibility=point.visibility)

    for start_lm, end_lm in POSE_CONNECTIONS:
        u, v = start_lm.value, end_lm.value
        if u >= n_landmarks or v >= n_landmarks:
            continue
        p_u = np.array([landmarks[u].x, landmarks[u].y, landmarks[u].z])
        p_v = np.array([landmarks[v].x, landmarks[v].y, landmarks[v].z])
        graph.add_edge(u, v, weight=np.linalg.norm(p_v - p_u))

    # Node feature matrix for the GNN; node ids double as row indices.
    feature_keys = ('x', 'y', 'z', 'visibility')
    node_features = torch.zeros(graph.number_of_nodes(), len(feature_keys))
    for node_id, attrs in graph.nodes(data=True):
        for col, key in enumerate(feature_keys):
            node_features[node_id, col] = attrs[key]

    # PyTorch Geometric style edge list: store both directions of every edge.
    directed_pairs = []
    distances = []
    for u, v, weight in graph.edges(data='weight'):
        directed_pairs.extend(([u, v], [v, u]))
        distances.extend(([weight], [weight]))

    if not directed_pairs:
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 1), dtype=torch.float)
    else:
        edge_index = torch.tensor(directed_pairs, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(distances, dtype=torch.float)

    if angles:
        angle_tensor = torch.tensor([angles[name] for name in sorted(angles)], dtype=torch.float)
    else:
        # Placeholder with the same length as calculate_key_angles' eight angles.
        angle_tensor = torch.zeros(8, dtype=torch.float)

    return {
        'graph': graph,
        'node_features': node_features,
        'edge_index': edge_index,
        'edge_attr': edge_attr,
        'angles': angle_tensor,
        'time': time_point,
        'source_video': source_video,
        'label': label
    }

In [18]:
# Extract one pose graph per sampled frame of a single video.
def process_video(video_path):
    """
    Run pose estimation on every FRAME_SKIP-th frame of a video and build graph data.

    The label is inferred from the file name (case-insensitive): contains 'hit' -> 1,
    otherwise contains 'miss' -> 0, otherwise -1 with a printed warning.

    Args:
        video_path: path to the video file.
    Returns:
        (all_graph_data, video_name): a list with one create_graph_from_landmarks
        dict per sampled frame in which a pose was detected, and the file's base
        name. The list is empty if the video cannot be opened or no pose is found.
    """
    video_name = os.path.basename(video_path)
    print(f"Processing {video_name} with frame skip = {FRAME_SKIP}...")

    # Determine label from video name
    lowered_name = video_name.lower()
    if 'hit' in lowered_name:
        label = 1  # Hit
    elif 'miss' in lowered_name:
        label = 0  # Miss
    else:
        label = -1  # Unknown
        print(f"Warning: Could not determine label for {video_name}, defaulting to -1")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        cap.release()
        return [], video_name

    # `pose` is one tracking-mode estimator shared by all videos; clear its
    # tracking/smoothing state so the previous video cannot bleed into this one.
    # (Guarded in case the installed MediaPipe build lacks reset().)
    if hasattr(pose, 'reset'):
        pose.reset()

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        # OpenCV reports 0 when the container has no usable FPS metadata.
        print(f"Warning: FPS unavailable for {video_name}; time values will be NaN")
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Frames actually sampled = ceil(total_frames / FRAME_SKIP); None when the
    # frame count is unknown so tqdm shows an open-ended bar.
    expected_samples = -(-total_frames // FRAME_SKIP) if total_frames > 0 else None

    all_graph_data = []
    frame_count = 0
    processed_count = 0

    try:
        with tqdm(total=expected_samples) as pbar:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % FRAME_SKIP == 0:
                    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    results = pose.process(rgb_frame)
                    time_point = frame_count / fps if fps > 0 else float('nan')

                    if results.pose_landmarks:
                        landmarks = results.pose_landmarks.landmark
                        angles = calculate_key_angles(landmarks)
                        graph_data = create_graph_from_landmarks(
                            landmarks, angles, time_point, video_name, label
                        )
                        all_graph_data.append(graph_data)
                        processed_count += 1

                    pbar.update(1)

                frame_count += 1
    finally:
        # Release the capture even if pose estimation or graph building raises.
        cap.release()

    print(f"Processed {processed_count} frames out of {frame_count} total frames from {video_name}")

    return all_graph_data, video_name


In [19]:
# Process every video in VIDEO_DIR and write one pickle per video to OUTPUT_DIR.
def process_all_videos():
    """
    Process all .mp4 videos in VIDEO_DIR and save each video's graph data separately.

    Videos are processed in sorted file-name order so runs are reproducible
    (os.listdir order is arbitrary). The extension check is case-insensitive,
    and only regular files are considered. Each video's list of frame dicts is
    pickled to OUTPUT_DIR/<base name>_graph_data.pkl. Videos that yield no graph
    data are skipped with a warning.

    Returns:
        None.
    """
    video_files = sorted(
        f for f in os.listdir(VIDEO_DIR)
        if f.lower().endswith('.mp4') and os.path.isfile(os.path.join(VIDEO_DIR, f))
    )

    if not video_files:
        print(f"No mp4 files found in {VIDEO_DIR}")
        return

    print(f"Found {len(video_files)} videos to process")

    for video_file in video_files:
        video_path = os.path.join(VIDEO_DIR, video_file)

        video_data, video_name = process_video(video_path)

        if not video_data:
            print(f"Warning: No graph data extracted from {video_name}")
            continue

        # Derive the output file name from the video name without its extension
        base_name = os.path.splitext(video_name)[0]
        output_file = os.path.join(OUTPUT_DIR, f"{base_name}_graph_data.pkl")

        # NOTE: these pickles hold torch tensors and networkx graphs; consumers
        # should only pickle.load files they produced themselves (pickle can run code).
        with open(output_file, 'wb') as f:
            pickle.dump(video_data, f)

        print(f"Saved graph data for {video_name} to {output_file}")

    print("Finished processing all videos.")


In [20]:
# Run the full extraction over VIDEO_DIR: prints per-video progress and writes one graph-data pickle per video into OUTPUT_DIR.
process_all_videos()

Found 36 videos to process
Processing Clip10Hit.mp4 with frame skip = 3...


  0%|          | 0/24 [00:00<?, ?it/s]

100%|██████████| 24/24 [00:01<00:00, 19.05it/s]


Processed 0 frames out of 70 total frames from Clip10Hit.mp4
Processing Clip11Hit.mp4 with frame skip = 3...


100%|██████████| 25/25 [00:03<00:00,  6.25it/s]


Processed 25 frames out of 73 total frames from Clip11Hit.mp4
Saved graph data for Clip11Hit.mp4 to GraphData/Clip11Hit_graph_data.pkl
Processing Clip12Hit.mp4 with frame skip = 3...


100%|██████████| 23/23 [00:03<00:00,  6.89it/s]


Processed 23 frames out of 68 total frames from Clip12Hit.mp4
Saved graph data for Clip12Hit.mp4 to GraphData/Clip12Hit_graph_data.pkl
Processing Clip13Hit.mp4 with frame skip = 3...


100%|██████████| 19/19 [00:02<00:00,  7.18it/s]


Processed 18 frames out of 56 total frames from Clip13Hit.mp4
Saved graph data for Clip13Hit.mp4 to GraphData/Clip13Hit_graph_data.pkl
Processing Clip14Miss.mp4 with frame skip = 3...


100%|██████████| 23/23 [00:01<00:00, 13.34it/s]


Processed 1 frames out of 68 total frames from Clip14Miss.mp4
Saved graph data for Clip14Miss.mp4 to GraphData/Clip14Miss_graph_data.pkl
Processing Clip15Hit.mp4 with frame skip = 3...


100%|██████████| 18/18 [00:02<00:00,  7.42it/s]


Processed 18 frames out of 53 total frames from Clip15Hit.mp4
Saved graph data for Clip15Hit.mp4 to GraphData/Clip15Hit_graph_data.pkl
Processing Clip16Hit.mp4 with frame skip = 3...


100%|██████████| 24/24 [00:03<00:00,  7.39it/s]


Processed 24 frames out of 70 total frames from Clip16Hit.mp4
Saved graph data for Clip16Hit.mp4 to GraphData/Clip16Hit_graph_data.pkl
Processing Clip17Hit.mp4 with frame skip = 3...


100%|██████████| 16/16 [00:02<00:00,  7.50it/s]


Processed 16 frames out of 46 total frames from Clip17Hit.mp4
Saved graph data for Clip17Hit.mp4 to GraphData/Clip17Hit_graph_data.pkl
Processing Clip18Hit.mp4 with frame skip = 3...


 95%|█████████▌| 21/22 [00:02<00:00,  7.30it/s]


Processed 21 frames out of 63 total frames from Clip18Hit.mp4
Saved graph data for Clip18Hit.mp4 to GraphData/Clip18Hit_graph_data.pkl
Processing Clip19Hit.mp4 with frame skip = 3...


 95%|█████████▌| 21/22 [00:02<00:00,  7.69it/s]


Processed 21 frames out of 63 total frames from Clip19Hit.mp4
Saved graph data for Clip19Hit.mp4 to GraphData/Clip19Hit_graph_data.pkl
Processing Clip1Hit.mp4 with frame skip = 3...


 94%|█████████▍| 17/18 [00:02<00:00,  7.25it/s]


Processed 15 frames out of 51 total frames from Clip1Hit.mp4
Saved graph data for Clip1Hit.mp4 to GraphData/Clip1Hit_graph_data.pkl
Processing Clip20Miss.mp4 with frame skip = 3...


 96%|█████████▋| 26/27 [00:03<00:00,  7.77it/s]


Processed 26 frames out of 78 total frames from Clip20Miss.mp4
Saved graph data for Clip20Miss.mp4 to GraphData/Clip20Miss_graph_data.pkl
Processing Clip21Hit.mp4 with frame skip = 3...


100%|██████████| 22/22 [00:02<00:00,  7.72it/s]


Processed 22 frames out of 65 total frames from Clip21Hit.mp4
Saved graph data for Clip21Hit.mp4 to GraphData/Clip21Hit_graph_data.pkl
Processing Clip22Hit.mp4 with frame skip = 3...


 97%|█████████▋| 30/31 [00:04<00:00,  6.45it/s]


Processed 30 frames out of 90 total frames from Clip22Hit.mp4
Saved graph data for Clip22Hit.mp4 to GraphData/Clip22Hit_graph_data.pkl
Processing Clip23Hit.mp4 with frame skip = 3...


100%|██████████| 22/22 [00:03<00:00,  7.16it/s]


Processed 22 frames out of 65 total frames from Clip23Hit.mp4
Saved graph data for Clip23Hit.mp4 to GraphData/Clip23Hit_graph_data.pkl
Processing Clip24Hit.mp4 with frame skip = 3...


 94%|█████████▍| 17/18 [00:02<00:00,  7.46it/s]


Processed 17 frames out of 51 total frames from Clip24Hit.mp4
Saved graph data for Clip24Hit.mp4 to GraphData/Clip24Hit_graph_data.pkl
Processing Clip25Miss.mp4 with frame skip = 3...


100%|██████████| 17/17 [00:02<00:00,  7.45it/s]


Processed 15 frames out of 50 total frames from Clip25Miss.mp4
Saved graph data for Clip25Miss.mp4 to GraphData/Clip25Miss_graph_data.pkl
Processing Clip26Hit.mp4 with frame skip = 3...


100%|██████████| 17/17 [00:02<00:00,  7.43it/s]


Processed 17 frames out of 49 total frames from Clip26Hit.mp4
Saved graph data for Clip26Hit.mp4 to GraphData/Clip26Hit_graph_data.pkl
Processing Clip27Hit.mp4 with frame skip = 3...


100%|██████████| 21/21 [00:02<00:00,  7.83it/s]


Processed 21 frames out of 61 total frames from Clip27Hit.mp4
Saved graph data for Clip27Hit.mp4 to GraphData/Clip27Hit_graph_data.pkl
Processing Clip28Miss.mp4 with frame skip = 3...


100%|██████████| 20/20 [00:02<00:00,  7.72it/s]


Processed 20 frames out of 58 total frames from Clip28Miss.mp4
Saved graph data for Clip28Miss.mp4 to GraphData/Clip28Miss_graph_data.pkl
Processing Clip29Hit.mp4 with frame skip = 3...


 95%|█████████▍| 18/19 [00:02<00:00,  7.61it/s]


Processed 18 frames out of 54 total frames from Clip29Hit.mp4
Saved graph data for Clip29Hit.mp4 to GraphData/Clip29Hit_graph_data.pkl
Processing Clip2Hit.mp4 with frame skip = 3...


 96%|█████████▌| 25/26 [00:03<00:00,  7.20it/s]


Processed 25 frames out of 75 total frames from Clip2Hit.mp4
Saved graph data for Clip2Hit.mp4 to GraphData/Clip2Hit_graph_data.pkl
Processing Clip30Miss.mp4 with frame skip = 3...


100%|██████████| 16/16 [00:02<00:00,  7.16it/s]


Processed 16 frames out of 47 total frames from Clip30Miss.mp4
Saved graph data for Clip30Miss.mp4 to GraphData/Clip30Miss_graph_data.pkl
Processing Clip31Hit.mp4 with frame skip = 3...


 96%|█████████▌| 22/23 [00:03<00:00,  6.95it/s]


Processed 22 frames out of 66 total frames from Clip31Hit.mp4
Saved graph data for Clip31Hit.mp4 to GraphData/Clip31Hit_graph_data.pkl
Processing Clip32Hit.mp4 with frame skip = 3...


100%|██████████| 29/29 [00:03<00:00,  7.33it/s]


Processed 29 frames out of 86 total frames from Clip32Hit.mp4
Saved graph data for Clip32Hit.mp4 to GraphData/Clip32Hit_graph_data.pkl
Processing Clip33Hit.mp4 with frame skip = 3...


100%|██████████| 12/12 [00:01<00:00,  6.04it/s]


Processed 12 frames out of 34 total frames from Clip33Hit.mp4
Saved graph data for Clip33Hit.mp4 to GraphData/Clip33Hit_graph_data.pkl
Processing Clip34Hit.mp4 with frame skip = 3...


100%|██████████| 25/25 [00:03<00:00,  7.88it/s]


Processed 21 frames out of 73 total frames from Clip34Hit.mp4
Saved graph data for Clip34Hit.mp4 to GraphData/Clip34Hit_graph_data.pkl
Processing Clip35Miss.mp4 with frame skip = 3...


100%|██████████| 23/23 [00:03<00:00,  7.14it/s]


Processed 20 frames out of 68 total frames from Clip35Miss.mp4
Saved graph data for Clip35Miss.mp4 to GraphData/Clip35Miss_graph_data.pkl
Processing Clip36Hit.mp4 with frame skip = 3...


 97%|█████████▋| 29/30 [00:03<00:00,  9.62it/s]


Processed 8 frames out of 87 total frames from Clip36Hit.mp4
Saved graph data for Clip36Hit.mp4 to GraphData/Clip36Hit_graph_data.pkl
Processing Clip3Hit.mp4 with frame skip = 3...


100%|██████████| 26/26 [00:03<00:00,  6.95it/s]


Processed 23 frames out of 76 total frames from Clip3Hit.mp4
Saved graph data for Clip3Hit.mp4 to GraphData/Clip3Hit_graph_data.pkl
Processing Clip4Hit.mp4 with frame skip = 3...


100%|██████████| 27/27 [00:03<00:00,  7.97it/s]


Processed 18 frames out of 79 total frames from Clip4Hit.mp4
Saved graph data for Clip4Hit.mp4 to GraphData/Clip4Hit_graph_data.pkl
Processing Clip5Hit.mp4 with frame skip = 3...


100%|██████████| 23/23 [00:03<00:00,  5.87it/s]


Processed 9 frames out of 67 total frames from Clip5Hit.mp4
Saved graph data for Clip5Hit.mp4 to GraphData/Clip5Hit_graph_data.pkl
Processing Clip6Hit.mp4 with frame skip = 3...


100%|██████████| 28/28 [00:02<00:00,  9.65it/s]


Processed 20 frames out of 83 total frames from Clip6Hit.mp4
Saved graph data for Clip6Hit.mp4 to GraphData/Clip6Hit_graph_data.pkl
Processing Clip7Miss.mp4 with frame skip = 3...


100%|██████████| 20/20 [00:02<00:00,  7.62it/s]


Processed 18 frames out of 59 total frames from Clip7Miss.mp4
Saved graph data for Clip7Miss.mp4 to GraphData/Clip7Miss_graph_data.pkl
Processing Clip8Hit.mp4 with frame skip = 3...


 96%|█████████▌| 25/26 [00:03<00:00,  7.81it/s]


Processed 23 frames out of 75 total frames from Clip8Hit.mp4
Saved graph data for Clip8Hit.mp4 to GraphData/Clip8Hit_graph_data.pkl
Processing Clip9Miss.mp4 with frame skip = 3...


100%|██████████| 14/14 [00:01<00:00,  7.49it/s]

Processed 14 frames out of 41 total frames from Clip9Miss.mp4
Saved graph data for Clip9Miss.mp4 to GraphData/Clip9Miss_graph_data.pkl
Finished processing all videos.



