# GymPal Exercise Form Detection

This notebook contains the consolidated code for the GymPal project, which uses computer vision and pose estimation to detect and analyze exercise form.

## Installation

In [None]:
%pip install -r requirements.txt

## 1. Process Videos
Process raw exercise videos to standardize their length and frame rate.

In [None]:
import os
from moviepy.editor import VideoFileClip
from tqdm import tqdm


def process_video(input_path, output_path, target_duration=5, target_fps=24):
    """
    Trim a video to a fixed length and re-encode it at a fixed frame rate.

    Args:
        input_path: Path of the source video.
        output_path: Path the processed video is written to (libx264 codec).
        target_duration: Maximum length in seconds. Longer clips are cut to
            their first ``target_duration`` seconds; shorter clips are kept
            whole.
        target_fps: Frame rate of the written video.

    Returns:
        True if the video was written, False if any step failed. Errors are
        printed instead of raised so that a batch run can continue.
    """
    clip = None
    try:
        # Load the video
        clip = VideoFileClip(input_path)

        # Keep only the leading `target_duration` seconds
        if clip.duration > target_duration:
            clip = clip.subclip(0, target_duration)

        # Set the FPS
        clip = clip.set_fps(target_fps)

        # Write the processed video
        clip.write_videofile(output_path, codec="libx264", fps=target_fps)

        return True
    except Exception as e:
        print(f"Error processing {input_path}: {str(e)}")
        return False
    finally:
        # Release the clip's resources on the failure path as well; the
        # original only closed the clip when every step had succeeded.
        if clip is not None:
            try:
                clip.close()
            except Exception as e:
                print(f"Error closing {input_path}: {str(e)}")


def process_all_videos(input_dir, output_dir):
    """
    Run ``process_video`` on every video file found directly in ``input_dir``.

    Files are recognised by extension (.mp4, .avi, .mov, case-insensitive)
    and processed in sorted order. Each result is written to ``output_dir``
    under the same file name; the folder is created if it is missing.

    Args:
        input_dir: Folder containing the raw videos.
        output_dir: Folder receiving the processed videos.

    Returns:
        None. Progress and failures are reported on stdout.
    """
    video_extensions = (".mp4", ".avi", ".mov")

    # Lower-case the name before matching so that e.g. "CLIP.MP4" is included
    video_files = sorted(
        f for f in os.listdir(input_dir) if f.lower().endswith(video_extensions)
    )

    if not video_files:
        print(f"No video files found in {input_dir}")
        return

    print(f"Found {len(video_files)} videos to process")

    # Do not rely on the caller having created the destination folder
    os.makedirs(output_dir, exist_ok=True)

    failed = []
    for video_file in tqdm(video_files, desc="Processing videos"):
        input_path = os.path.join(input_dir, video_file)
        output_path = os.path.join(output_dir, video_file)
        # process_video reports success as a boolean; keep track of failures
        if not process_video(input_path, output_path):
            failed.append(video_file)

    if failed:
        print(f"Failed to process {len(failed)} video(s): {', '.join(failed)}")

    print("Video processing complete")

In [None]:
# Standardize every raw deadlift clip into a separate folder
input_dir, output_dir = "exercises/deadlift", "deadlift_processed"

# Ensure the destination folder is present, then process the clips
os.makedirs(output_dir, exist_ok=True)
process_all_videos(input_dir, output_dir)

## 2. Create Snapshots
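Interactive tool to extract and annotate key frames from videos. During playback, press `space` to pause/resume, `a`/`d` to jump back/forward one second, `t`/`b` to save the current frame as a "top"/"bottom" snapshot, and `q` to finish the current video.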

In [None]:
import os
import cv2
import numpy as np

# Size in pixels of the annotation window and of every frame shown in it
DISPLAY_WIDTH, DISPLAY_HEIGHT = 1280, 720
# Name of the OpenCV window used by the annotation screens
WINDOW_NAME = "Video Annotation"

# Folder the videos are read from, and folder the snapshots are written under
SOURCE_DIR, SNAPSHOT_DIR = "deadlift_processed", "deadlift_snapshots"
os.makedirs(SNAPSHOT_DIR, exist_ok=True)


def show_start_screen():
    """Show a black splash frame with a prompt and block until a key is pressed."""
    canvas_shape = (DISPLAY_HEIGHT, DISPLAY_WIDTH, 3)
    splash = np.zeros(canvas_shape, dtype=np.uint8)

    # Anchor the prompt at 30% of the width and 50% of the height
    text_origin = (int(DISPLAY_WIDTH * 0.3), int(DISPLAY_HEIGHT * 0.5))
    white = (255, 255, 255)
    font_scale, thickness = 1, 2
    cv2.putText(
        splash,
        "Press any key to start...",
        text_origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        white,
        thickness,
    )

    cv2.imshow(WINDOW_NAME, splash)
    # A delay of 0 means "no timeout": wait for a key press
    cv2.waitKey(0)


def save_snapshot(frame, video_name, frame_pos, phase=None):
    """
    Write ``frame`` as a JPEG into a per-video subfolder of SNAPSHOT_DIR.

    The subfolder is named after ``video_name`` without its extension and is
    created on demand. The file is named ``frame_<frame_pos>_<phase>.jpg``;
    the ``_<phase>`` part is left out when ``phase`` is falsy.
    """
    stem, _extension = os.path.splitext(video_name)
    target_dir = os.path.join(SNAPSHOT_DIR, stem)
    os.makedirs(target_dir, exist_ok=True)

    suffix = "" if not phase else f"_{phase}"
    target_path = os.path.join(target_dir, f"frame_{frame_pos}{suffix}.jpg")

    cv2.imwrite(target_path, frame)


def annotate_video(video_path):
    """
    Play a video in the annotation window and record key frames.

    Controls:
        space  pause / resume playback
        a      jump back one second
        d      jump forward one second
        t      mark the shown frame as "top" and save a snapshot
        b      mark the shown frame as "bottom" and save a snapshot
        q      stop annotating this video

    The same mark is not accepted twice in a row: after a "top", the next
    accepted mark must be a "bottom", and vice versa.

    Snapshots are saved from the frame as displayed, i.e. resized to
    DISPLAY_WIDTH x DISPLAY_HEIGHT. Frame numbers are the 0-based index of
    the displayed frame, so seeking the video to that number yields the same
    frame again.

    Args:
        video_path: Path of the video to annotate.

    Returns:
        A list of ``(frame_index, phase)`` tuples in the order they were
        marked, or None if the video could not be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error opening video file: {video_path}")
        cap.release()
        return None

    video_name = os.path.basename(video_path)
    # A reported rate of 0 would turn the one-second jumps into no-ops
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 1
    annotations = []
    last_key = None
    paused = False
    refresh = False  # set after a seek so the new frame is shown even if paused
    frame = None
    frame_pos = 0

    try:
        while True:
            if refresh or not paused:
                ret, raw = cap.read()
                if not ret:
                    break  # Stop if the video ends
                refresh = False

                # After a read, CAP_PROP_POS_FRAMES is the index of the NEXT
                # frame, so the frame just read is one less.
                frame_pos = max(0, int(cap.get(cv2.CAP_PROP_POS_FRAMES)) - 1)
                frame = cv2.resize(raw, (DISPLAY_WIDTH, DISPLAY_HEIGHT))

            cv2.imshow(WINDOW_NAME, frame)

            key = cv2.waitKey(10) & 0xFF  # Short delay keeps playback smooth

            if key == ord(" "):  # Pause/Play
                paused = not paused

            elif key == ord("a"):  # Rewind 1 second
                cap.set(cv2.CAP_PROP_POS_FRAMES, max(0, frame_pos - fps))
                refresh = True

            elif key == ord("d"):  # Fast-forward 1 second
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos + fps)
                refresh = True

            elif key == ord("t") and last_key != "t":  # Top annotation
                annotations.append((frame_pos, "top"))
                save_snapshot(frame, video_name, frame_pos, "top")
                last_key = "t"

            elif key == ord("b") and last_key != "b":  # Bottom annotation
                annotations.append((frame_pos, "bottom"))
                save_snapshot(frame, video_name, frame_pos, "bottom")
                last_key = "b"

            elif key == ord("q"):  # Quit
                break
    finally:
        # Release the capture even if displaying or saving a frame failed
        cap.release()

    return annotations

In [None]:
# Full paths of every .mp4 file found directly in SOURCE_DIR
video_files = [
    os.path.join(SOURCE_DIR, name)
    for name in os.listdir(SOURCE_DIR)
    if name.endswith(".mp4")
]

# Annotation results per video, keyed by the video's file name
all_annotations = dict()

def run_annotation():
    """
    Annotate the videos in ``video_files`` one at a time.

    After each video a prompt screen is shown: 'd' moves on to the next
    video, 'a' goes back one video (ignored on the first), 'q' stops, and
    any other key replays the current video. Non-empty annotation lists are
    stored in ``all_annotations`` under the video's file name.
    """
    total = len(video_files)

    if video_files:
        cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(WINDOW_NAME, DISPLAY_WIDTH, DISPLAY_HEIGHT)

        # Show splash screen
        show_start_screen()

    white = (255, 255, 255)
    index = 0
    while index < total:
        current = video_files[index]

        marks = annotate_video(current)
        if marks:
            all_annotations[os.path.basename(current)] = marks

        # Prompt screen: (text, x as fraction of width, y as fraction of height, font scale)
        prompt = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
        prompt_lines = (
            (f"Video {index + 1}/{total} completed", 0.3, 0.4, 1),
            (
                "Press 'd' for next video, 'a' for previous video, 'q' to quit",
                0.2,
                0.5,
                0.8,
            ),
        )
        for text, x_frac, y_frac, scale in prompt_lines:
            origin = (int(DISPLAY_WIDTH * x_frac), int(DISPLAY_HEIGHT * y_frac))
            cv2.putText(
                prompt, text, origin, cv2.FONT_HERSHEY_SIMPLEX, scale, white, 2
            )

        cv2.imshow(WINDOW_NAME, prompt)
        pressed = cv2.waitKey(0) & 0xFF

        if pressed == ord("q"):  # Quit
            break
        if pressed == ord("d"):  # Next video
            index += 1
        elif pressed == ord("a") and index > 0:  # Previous video
            index -= 1

    cv2.destroyAllWindows()
    print("\nAnnotation complete!")

# Run the annotation code when needed
# run_annotation()

## 3. Extract Keypoints
Extract pose keypoints from the annotated frames.

In [None]:
import os
import cv2
import csv
import mediapipe as mp
from tqdm import tqdm

# NOTE: this rebinds SOURCE_DIR, which the annotation cell sets to
# "deadlift_processed"; re-run that cell before annotating videos again.
SOURCE_DIR = "deadlift_snapshots"
# Created here, but the code visible in this notebook writes the keypoint CSV
# to the working directory rather than into this folder.
OUTPUT_DIR = "deadlift_keypoints"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Shortcuts to the MediaPipe pose solution and its drawing helpers
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

# Pose estimator shared by the keypoint extraction below. Snapshots are
# unrelated still images, hence static_image_mode=True.
pose = mp_pose.Pose(
    static_image_mode=True,
    min_detection_confidence=0.6,
)

# Joint name -> index into `results.pose_landmarks.landmark`.
# The insertion order defines the column order of the keypoint CSV, so it
# must stay identical to the KEYPOINTS mapping used at inference time.
KEYPOINTS = {
    "shoulder_left": 11,
    "shoulder_right": 12,
    "elbow_left": 13,
    "elbow_right": 14,
    "hip_left": 23,
    "hip_right": 24,
    "knee_left": 25,
    "knee_right": 26,
    "ankle_left": 27,
    "ankle_right": 28,
}

In [None]:
def extract_keypoints(csv_filename="deadlift_keypoints.csv"):
    """
    Run pose detection on every snapshot and write the keypoints to a CSV.

    Expects ``SOURCE_DIR/<video name>/frame_<n>_<label>.jpg``. Each image in
    which a pose is found yields one row: the video folder name, the frame
    number, the x/y/z of every landmark in ``KEYPOINTS`` (in that order) and
    the label taken from the file name. Entries that are not folders, files
    that do not match the naming pattern, unreadable images and images
    without a detected pose are skipped with a message.

    Args:
        csv_filename: Path of the CSV file to (over)write.
    """
    header = ["video_file", "frame_no"]
    for keypoint in KEYPOINTS:
        header.extend([f"{keypoint}_x", f"{keypoint}_y", f"{keypoint}_z"])
    header.append("label")  # Add label column

    # The context manager closes the file even if detection raises midway
    with open(csv_filename, "w", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(header)

        for video_name in tqdm(os.listdir(SOURCE_DIR), desc="Detecting keypoints"):
            video_folder = os.path.join(SOURCE_DIR, video_name)
            if not os.path.isdir(video_folder):
                continue  # stray file next to the per-video folders

            for image_file in os.listdir(video_folder):
                image_path = os.path.join(video_folder, image_file)

                # Expected name: frame_<n>_<label>.<ext>
                parts = image_file.split("_")
                if len(parts) != 3 or parts[0] != "frame":
                    print(f"Skipping file with unexpected name: {image_path}")
                    continue
                try:
                    frame_no = int(parts[1])
                except ValueError:
                    print(f"Skipping file with unexpected name: {image_path}")
                    continue
                label = parts[2].split(".")[0]

                # cv2.imread signals failure by returning None, not by raising
                image = cv2.imread(image_path)
                if image is None:
                    print(f"Could not read image: {image_path}")
                    continue
                frame_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

                results = pose.process(frame_rgb)
                if not results.pose_landmarks:
                    print(f"No pose landmarks detected in: {image_path}")
                    continue

                row = [video_name, frame_no]
                for idx in KEYPOINTS.values():
                    landmark = results.pose_landmarks.landmark[idx]
                    row.extend([landmark.x, landmark.y, landmark.z])
                row.append(label)
                csv_writer.writerow(row)

    print(f"Keypoints saved to {csv_filename}")

# Run keypoint extraction when needed
# extract_keypoints()

## 4. Augment Training Data
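Extract additional frames around snapshots to increase training data volume. Run this step before extracting keypoints (section 3), or re-run section 3 afterwards, so that the added frames are included in `deadlift_keypoints.csv`.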

In [None]:
import os
import cv2
from tqdm import tqdm

def _parse_snapshot_name(filename):
    """Return ``(frame_no, phase)`` for ``frame_<n>_<phase>.jpg``, else None."""
    parts = filename.split("_")
    if len(parts) != 3 or parts[0] != "frame":
        return None
    try:
        frame_no = int(parts[1])
    except ValueError:
        return None
    # Drop the extension to get 'top' or 'bottom'
    return frame_no, parts[2].split(".")[0]


def extract_frames_around_snapshots(source_dir, snapshots_dir, frame_range=2):
    """
    Extract the frames surrounding each snapshot and save them next to it.

    For every snapshot ``frame_<n>_<phase>.jpg`` of a video, the frames
    ``n - frame_range`` .. ``n + frame_range`` (excluding ``n`` itself and
    anything outside the video) are read from the processed video and saved
    as ``frame_<m>_<phase>.jpg`` in the same snapshot folder. Files that
    already exist are left untouched, so hand-annotated snapshots are never
    overwritten.

    NOTE: extracted frames use the same naming scheme as snapshots, so a
    second run treats them as snapshots too and widens the window further.

    Args:
        source_dir: Directory containing processed videos
        snapshots_dir: Directory containing snapshots organized by video name
        frame_range: Number of frames to extract before and after the snapshot frame
    """
    video_extensions = (".mp4", ".avi", ".mov")
    video_files = [
        f for f in os.listdir(source_dir) if f.lower().endswith(video_extensions)
    ]

    if not video_files:
        print(f"No video files found in {source_dir}")
        return

    print(f"Found {len(video_files)} videos to process")

    for video_file in tqdm(video_files, desc="Processing videos"):
        video_name = os.path.splitext(video_file)[0]
        video_path = os.path.join(source_dir, video_file)

        # Check if this video has snapshots
        video_snapshot_dir = os.path.join(snapshots_dir, video_name)
        if not os.path.exists(video_snapshot_dir):
            print(f"No snapshots found for {video_name}, skipping...")
            continue

        snapshot_files = [
            f for f in os.listdir(video_snapshot_dir) if f.endswith(".jpg")
        ]
        if not snapshot_files:
            print(f"No snapshot files found for {video_name}, skipping...")
            continue

        # Collect (frame number, phase) for every well-formed snapshot name
        frame_info = []
        for snapshot_file in snapshot_files:
            parsed = _parse_snapshot_name(snapshot_file)
            if parsed is None:
                print(f"Could not parse frame info from {snapshot_file}, skipping...")
                continue
            frame_info.append(parsed)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Could not open {video_path}, skipping...")
            cap.release()
            continue

        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            for frame_no, phase in frame_info:
                for offset in range(-frame_range, frame_range + 1):
                    # Offset 0 is the snapshot itself
                    if offset == 0:
                        continue

                    target_frame = frame_no + offset
                    if not 0 <= target_frame < total_frames:
                        continue

                    output_path = os.path.join(
                        video_snapshot_dir, f"frame_{target_frame}_{phase}.jpg"
                    )
                    # Keep existing files, in particular manual snapshots
                    if os.path.exists(output_path):
                        continue

                    # Seek to the specific frame and decode only that one
                    cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
                    ret, frame = cap.read()
                    if not ret:
                        print(
                            f"Could not read frame {target_frame} from {video_path}, skipping..."
                        )
                        continue

                    cv2.imwrite(output_path, frame)
        finally:
            # Release the video even if reading or writing a frame failed
            cap.release()

    print("Frame extraction complete!")

In [None]:
# Inputs of the augmentation step: the processed videos and their snapshots
source_dir, snapshots_dir = "deadlift_processed", "deadlift_snapshots"

# extract_frames_around_snapshots(source_dir, snapshots_dir)

## 5. Train Model
Train a machine learning model to classify poses using the extracted keypoints.

In [None]:
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report

def train_model():
    """
    Train and save a top/bottom pose classifier from ``deadlift_keypoints.csv``.

    Steps:
        1. Load the CSV, encode the label as 1 ("top") / 0 ("bottom") and
           drop rows with any other label.
        2. Split 80/20 into train and test and standardize the features with
           a scaler fitted on the training part only.
        3. Grid-search a random forest, an SVM and a k-NN classifier with
           5-fold cross-validation on the training part.
        4. Pick the model with the best cross-validation accuracy and report
           its performance on the held-out test part.
        5. Pickle the model and the scaler next to the notebook.

    NOTE(review): the split is row-wise, so frames of the same video
    (including near-duplicate neighbours added by augmentation) can land in
    both parts; a per-video split would give a more honest test score.

    Returns:
        ``(best_model, scaler)``.
    """
    # Load dataset
    data = pd.read_csv("deadlift_keypoints.csv")

    # Unknown labels map to NaN and would make fitting fail; drop them
    label_codes = data["label"].map({"top": 1, "bottom": 0})
    known = label_codes.notna()
    if not known.all():
        print(f"Dropping {int((~known).sum())} rows with unknown labels")

    # Exclude video_file, frame_no, and label. Plain arrays are used so the
    # scaler does not record column names; inference passes a bare list and
    # would otherwise trigger a feature-name warning on every frame.
    X = data.loc[known].iloc[:, 2:-1].to_numpy()
    y = label_codes[known].astype(int).to_numpy()

    # Split dataset
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Fit StandardScaler only on the training set
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Inference calls predict_proba, which SVC only offers when it is
    # constructed with probability=True.
    models = {
        "RandomForest": RandomForestClassifier(random_state=42),
        "SVM": SVC(probability=True),
        "KNN": KNeighborsClassifier(),
    }

    # Define parameter grids
    param_grids = {
        "RandomForest": {"n_estimators": [50, 100, 200], "max_depth": [None, 10, 20]},
        "SVM": {"C": [0.1, 1, 10], "kernel": ["linear", "rbf"]},
        "KNN": {"n_neighbors": [3, 5, 7]},
    }

    # Perform grid search with cross-validation
    best_models = {}
    cv_scores = {}
    for name, model in models.items():
        grid_search = GridSearchCV(
            model, param_grids[name], cv=5, scoring="accuracy", n_jobs=-1
        )
        grid_search.fit(X_train, y_train)
        best_models[name] = grid_search.best_estimator_
        cv_scores[name] = grid_search.best_score_
        print(f"Best {name} model: {grid_search.best_params_}")

    # Select on cross-validation accuracy, not on the test set: choosing the
    # winner by test accuracy would bias the report printed below.
    best_model_name = max(cv_scores, key=cv_scores.get)
    best_model = best_models[best_model_name]
    print(f"Selected best model: {best_model_name}")

    y_pred = best_model.predict(X_test)
    print(classification_report(y_test, y_pred))

    # Save the model to disk
    model_filename = "deadlift_classifier_model.pkl"
    with open(model_filename, "wb") as file:
        pickle.dump(best_model, file)

    # The scaler is saved too: inference must apply the same standardization
    scaler_filename = "deadlift_classifier_scaler.pkl"
    with open(scaler_filename, "wb") as file:
        pickle.dump(scaler, file)
    print(f"Model saved to {model_filename}")

    return best_model, scaler

# Train the model when needed
# best_model, scaler = train_model()

## 6. Inference
Run real-time inference using the trained model.

In [None]:
import mediapipe as mp
import numpy as np
import cv2
import pickle

def run_inference(source=0, threshold=0.8):
    """
    Classify the deadlift phase frame by frame and count repetitions.

    Loads the pickled scaler and classifier, reads frames from ``source``,
    runs pose detection, and classifies each frame with a detected pose as
    "top" or "bottom". A repetition is counted on every bottom -> top
    transition. The annotated frames are shown in a window until the source
    ends or 'q' is pressed.

    Args:
        source: Camera index (0 for the default webcam) or path of a video
            file, passed to ``cv2.VideoCapture``.
        threshold: Minimum predicted probability for a frame's phase to be
            accepted; less confident frames are shown as "-" and ignored
            for rep counting.
    """
    class_map = {0: "bottom", 1: "top"}

    # NOTE: unpickling can execute arbitrary code. Only load model/scaler
    # files that were produced locally by the training step.
    with open("deadlift_classifier_scaler.pkl", "rb") as f:
        scaler = pickle.load(f)

    with open("deadlift_classifier_model.pkl", "rb") as f:
        model = pickle.load(f)

    mp_utils = mp.solutions.drawing_utils
    mp_pose = mp.solutions.pose

    # Must list the same landmarks, in the same order, as used for training
    KEYPOINTS = {
        "shoulder_left": 11,
        "shoulder_right": 12,
        "elbow_left": 13,
        "elbow_right": 14,
        "hip_left": 23,
        "hip_right": 24,
        "knee_left": 25,
        "knee_right": 26,
        "ankle_left": 27,
        "ankle_right": 28,
    }

    reps = 0
    prev_phase = None
    cap = cv2.VideoCapture(source)

    try:
        if not cap.isOpened():
            print(f"Could not open video source: {source}")
            return

        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

        # The context manager closes the pose estimator when the loop ends
        with mp_pose.Pose(
            min_detection_confidence=0.7, min_tracking_confidence=0.6
        ) as pose:
            while True:
                ret, img = cap.read()
                if not ret:
                    break

                frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                results = pose.process(frame)

                phase_text = "-"
                if results.pose_landmarks:
                    landmarks = results.pose_landmarks.landmark
                    features = []
                    for idx in KEYPOINTS.values():
                        landmark = landmarks[idx]
                        features.extend([landmark.x, landmark.y, landmark.z])

                    scaled = scaler.transform([features])
                    probabilities = model.predict_proba(scaled)[0]

                    best = int(np.argmax(probabilities))
                    if probabilities[best] > threshold:
                        # predict_proba columns follow model.classes_, which
                        # need not equal the column index
                        curr_phase = class_map[int(model.classes_[best])]
                        phase_text = curr_phase
                        # Count a rep when moving from bottom to top
                        if prev_phase == "bottom" and curr_phase == "top":
                            reps += 1
                        prev_phase = curr_phase

                    mp_utils.draw_landmarks(
                        img, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
                    )

                # Anchor the counter relative to the right edge (x=1000 on a
                # 1280-wide frame) so it stays visible on narrower frames
                reps_x = max(20, img.shape[1] - 280)
                cv2.putText(img, f"Phase: {phase_text}", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 3, cv2.LINE_AA)
                cv2.putText(img, f"Reps: {reps}", (reps_x, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 3, cv2.LINE_AA)

                cv2.imshow("Deadlift Counter", img)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
    finally:
        # Free the camera/file and close the window even after an error
        cap.release()
        cv2.destroyAllWindows()

# Run inference when needed
# run_inference()