# 1. Imports

In [4]:
import os
import cv2
import torch
import numpy as np
import torch.nn.functional as F
import mediapipe as mp
from collections import deque
from pathlib import Path
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torch.optim import Adam
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from tqdm import tqdm
import math
from torch.utils.data import Dataset 

# 2. Keypoint Extraction from Training Videos with MediaPipe  
Before training a model for tasks like pose estimation, exercise recognition, or movement analysis, the first step is to extract relevant features from the data. In this case, the data consists of training videos, and the relevant features are pose keypoints—3D coordinates representing specific body parts (like shoulders, elbows, knees, etc.).  

This code uses MediaPipe Pose to extract keypoints from training videos. These keypoints are critical for model training as they represent the underlying body movements that the model will need to learn. Below is a breakdown of the process:

## 2.1 Keypoint extraction and saving as numpy files

In [None]:
# # Keypoint extraction from Training Videos With Mediapipe

# # 0. Quiet TensorFlow/absl
# os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# # 1. MediaPipe Pose setup
# pose = mp.solutions.pose.Pose(
#     static_image_mode=False,
#     model_complexity=2,
#     enable_segmentation=False,
#     min_detection_confidence=0.5,
#     min_tracking_confidence=0.5
# )

# # 2. Paths
# VIDEO_ROOT = Path("Data-REHAB24-6/videos")
# OUT_ROOT   = Path("Data-REHAB24-6/mp_keypoints")
# OUT_ROOT.mkdir(exist_ok=True)

# # 3. Worker
# def process_video(vid_path: Path):
#     rel     = vid_path.parent.name            # e.g. "Ex1"
#     out_dir = OUT_ROOT / rel
#     out_dir.mkdir(parents=True, exist_ok=True)
#     out_file = out_dir / f"{vid_path.stem}-mp.npy"

#     print(f"\n→ Processing: {vid_path.name}")
#     print(f"   From:      {vid_path.parent}")
#     print(f"   To folder: {out_dir}")

#     cap    = cv2.VideoCapture(str(vid_path))
#     frames = []
#     count  = 0

#     while True:
#         ret, frame = cap.read()
#         if not ret:
#             break
#         count += 1

#         img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#         res = pose.process(img)
#         lm  = res.pose_world_landmarks.landmark if res.pose_world_landmarks else []

#         if lm:
#             pts = [(p.x, p.y, p.z) for p in lm]
#         else:
#             pts = [(0.0, 0.0, 0.0)] * 33

#         frames.append(pts)

#     cap.release()

#     arr = np.array(frames, dtype=np.float32)
#     np.save(out_file, arr)

#     print(f"✔ Saved: {out_file}  (frames={count}, shape={arr.shape})")

# # 4. Run — only Ex1 through Ex5
# for i in range(1, 6):
#     ex_dir = VIDEO_ROOT / f"Ex{i}"
#     if not ex_dir.is_dir():
#         print(f"⚠️  Skipping missing folder: {ex_dir}")
#         continue

#     for vid in sorted(ex_dir.glob("*.mp4")):
#         try:
#             process_video(vid)
#         except Exception as e:
#             print(f"✘ Failed processing {vid.name}: {e}")

# print("\nAll requested videos processed.")


## 2.2 Inspecting at a numpy file containing 3D keypoints

In [6]:
# ─── edit this to your target file ────────────────────────────────────────────
file_path = Path("Data-REHAB24-6/mp_keypoints/Ex6/PM_008-Camera17-30fps-mp.npy")
# ────────────────────────────────────────────────────────────────────────────────

# load
arr = np.load(file_path)

# basic info
print(f"Loaded: {file_path}")
print(f" dtype: {arr.dtype}")
print(f" shape: {arr.shape}  (frames × landmarks × coords)\n")

# show first few frames
n_show = min(3, arr.shape[0])
for i in range(n_show):
    print(f"Frame #{i:03d} (33×3):")
    print(arr[i])
    print(f"  → first landmark: {tuple(arr[i,0])}\n")

# overall statistics
print("Overall coordinate stats:")
for idx, name in enumerate(("x", "y", "z")):
    data = arr[..., idx]
    print(f"  {name}: min={data.min():.3f}, max={data.max():.3f}, mean={data.mean():.3f}")


Loaded: Data-REHAB24-6/mp_keypoints/Ex6/PM_008-Camera17-30fps-mp.npy
 dtype: float32
 shape: (5191, 33, 3)  (frames × landmarks × coords)

Frame #000 (33×3):
[[-0.03379065 -0.6112996  -0.22619084]
 [-0.02455677 -0.628223   -0.2275483 ]
 [-0.0257254  -0.6304051  -0.21732828]
 [-0.02547118 -0.6302204  -0.21820049]
 [-0.02778288 -0.6363151  -0.2358914 ]
 [-0.02677054 -0.63463634 -0.24706481]
 [-0.02263773 -0.61979294 -0.2281486 ]
 [ 0.02648694 -0.6184615  -0.16859515]
 [-0.03559308 -0.56201506 -0.14808732]
 [ 0.0030987  -0.59678274 -0.19447449]
 [-0.01712019 -0.55977213 -0.21580447]
 [ 0.1250833  -0.49333623 -0.08702794]
 [-0.05722423 -0.53108674 -0.02576461]
 [ 0.15178505 -0.51440114 -0.09251688]
 [-0.17247145 -0.4994753  -0.05377672]
 [ 0.14430666 -0.54461473 -0.03090633]
 [-0.32217076 -0.5845331  -0.08351779]
 [ 0.1352469  -0.5789426  -0.02217976]
 [-0.33411983 -0.62724733 -0.13406767]
 [ 0.11857966 -0.58959824 -0.04421883]
 [-0.30619952 -0.65674794 -0.14033641]
 [ 0.13962431 -0.533594

# 3. Data Preparation for Pose Error Analysis  

## 3.1 Change the correctness of half profile vidoes to incorrect 

This change is being made because the half-profile camera angle is considered insufficient or unreliable for accurately determining if the exercise repetition was performed correctly according to the established standards. Furthermore, for model inference, the intention is to use only the front-facing view. Therefore, this action serves as a data cleaning and preparation step to improve the quality and consistency of the data that will be used for subsequent analysis or modeling, ensuring that only reliably assessed repetitions are marked as 'correct'.

In [7]:
# # --- Configuration ---
# DATA_ROOT = Path("Data-REHAB24-6") # Adjust if your data folder has a different relative path
# ORIGINAL_FILENAME = "Segmentation_original.xlsx"
# # --- CHANGE HERE: Update the output filename extension ---
# NEW_FILENAME = "Segmentation.xlsx" # Saving as .xlsx

# FILE_ORIG = DATA_ROOT / ORIGINAL_FILENAME
# FILE_DEST = DATA_ROOT / NEW_FILENAME

# # --- Processing ---
# if not DATA_ROOT.is_dir():
#     print(f"Error: Data directory not found at '{DATA_ROOT.resolve()}'")
# elif not FILE_ORIG.is_file():
#     print(f"Error: Original file not found at '{FILE_ORIG.resolve()}'")
# else:
#     print(f"Loading original file: {FILE_ORIG}")
#     try:
#         # Load the Excel file
#         df = pd.read_excel(FILE_ORIG)
#         print("Original file loaded successfully.")
#         print(f"Original shape: {df.shape}")

#         # Identify rows where 'cam17_orientation' is 'half-profile'
#         condition = df['cam17_orientation'] == 'half-profile'
#         num_rows_to_change = condition.sum()
#         print(f"Found {num_rows_to_change} rows where 'cam17_orientation' is 'half-profile'.")

#         if num_rows_to_change > 0:
#             # Change 'correctness' to 0 for matched rows
#             print("Updating 'correctness' column to 0 for matched rows...")
#             df.loc[condition, 'correctness'] = 0
#             print("Update complete.")
#         else:
#             print("No rows matched the condition. 'correctness' column remains unchanged.")

#         # Save the modified DataFrame to a new Excel file (.xlsx)
#         print(f"Saving modified data to: {FILE_DEST}")
#         try:
#             # Saving to .xlsx uses openpyxl engine by default (pip install openpyxl if needed)
#             df.to_excel(FILE_DEST, index=False)
#             print(f"Successfully saved modified data to '{FILE_DEST.resolve()}'")
#         except Exception as save_error:
#             # --- CHANGE HERE: Updated error message for .xlsx ---
#             print(f"Error saving file to {FILE_DEST}: {save_error}")
#             print("Saving to .xlsx format typically requires the 'openpyxl' package. Try: pip install openpyxl")

#     except FileNotFoundError:
#          print(f"Error: Make sure the file exists at {FILE_ORIG.resolve()}")
#     except KeyError as e:
#          print(f"Error: Column not found - {e}. Please check column names in '{ORIGINAL_FILENAME}'.")
#          print(f"Available columns are: {df.columns.tolist()}")
#     except Exception as e:
#         print(f"An unexpected error occurred during processing: {e}")

# # --- Verification ---
# if FILE_DEST.is_file():
#     print("\nVerifying the saved file...")
#     try:
#         # --- CHANGE HERE: Reading the .xlsx file for verification ---
#         df_new = pd.read_excel(FILE_DEST)
#         print(f"Loaded new file shape: {df_new.shape}")
#         check_condition = df_new['cam17_orientation'] == 'half-profile'
#         incorrect_rows = df_new.loc[check_condition & (df_new['correctness'] != 0)]
#         if incorrect_rows.empty:
#             print("Verification successful: All 'half-profile' rows have 'correctness' set to 0.")
#         else:
#             print("Verification FAILED: Some 'half-profile' rows still have 'correctness' != 0.")
#             print(incorrect_rows)
#     except Exception as e:
#         print(f"Error during verification: {e}")
# else:
#      print(f"\nCould not verify as the destination file '{FILE_DEST}' was not found or not saved.")

## 3.2 Sliding Window-Based Pose Error Calculation for Video segments  
This script processes keypoint data from videos to compute errors in body joint angles relative to "ideal" angles, and then generates sliding windows of these errors. This is done to prepare features that will later be used for model training in tasks like exercise recognition or pose correction.

What is done: For each exercise in the dataset, the script calculates the "ideal" joint angles by selecting the middle frame of each correct repetition (based on the correctness label). It calculates the angles between the joints defined in JOINT_TRIPLETS for each of these frames.  We use 17 out of the 33 MediaPipe landmarks for the ideal joint angle calculation. The used mediapipe landmarks are: NOSE, LEFT_SHOULDER, RIGHT_SHOULDER, LEFT_ELBOW, RIGHT_ELBOW, LEFT_WRIST, RIGHT_WRIST, LEFT_INDEX, RIGHT_INDEX, LEFT_HIP, RIGHT_HIP, LEFT_KNEE, RIGHT_KNEE, LEFT_ANKLE, RIGHT_ANKLE, LEFT_FOOT_INDEX, and RIGHT_FOOT_INDEX. These cover the major joints and end‑effectors (shoulders through wrists and hips through ankles, plus the spine/head via the nose) needed to compute all our angle‑based error metrics for the six rehab exercises. The 16 unused landmarks are all the fine‑grain facial points (inner/outer eyes, ears, mouth corners), the pinky and thumb tips, and the heel points. Since our focus is on gross limb alignment (arm and leg joint planes) rather than facial expression, finger articulation, or detailed foot posture, those landmarks don’t contribute to correcting the targeted movements and so are omitted.

The median angle is then computed for each joint across the correct repetitions of the exercise. These median values represent the "ideal" angles the model should aim for in perfect executions of the exercise.  

In the next step, the script generates sliding windows of angular errors, calculated as the difference between the observed angles in the video and the ideal angles. These windows contain temporal sequences of error data that are used as features for model training. The windowed data is then saved in a CSV file, which includes additional metadata such as exercise ID, repetition number, and frame indices.  

Why it’s done: Calculating the ideal angles for each exercise provides a reference for identifying errors during subsequent video frames. These ideal angles will serve as the baseline for assessing whether a movement is performed correctly or incorrectly. By generating and saving the sliding window data, the script prepares the dataset for supervised learning, allowing the model to analyze temporal error patterns over a series of frames. This windowed approach is crucial for the model to learn dynamic movements and classify whether exercises are performed correctly based on the computed joint angles. The segmentation data in the CSV file offers a structured and labeled dataset, which helps in efficient training and evaluation of the model.  


In [None]:
# # 1. helpers --------------------------------------------------
# def angle_between(a,b,c):
#     BA = a-b; BC = c-b
#     cosθ = np.dot(BA,BC)/(np.linalg.norm(BA)*np.linalg.norm(BC))
#     return math.degrees(math.acos(np.clip(cosθ,-1,1)))

# PoseLandmark = mp.solutions.pose.PoseLandmark
# JOINT_TRIPLETS = {
#     "LEFT_ELBOW":   (PoseLandmark.LEFT_SHOULDER.value,
#                      PoseLandmark.LEFT_ELBOW.value,
#                      PoseLandmark.LEFT_WRIST.value),
#     "RIGHT_ELBOW":  (PoseLandmark.RIGHT_SHOULDER.value,
#                      PoseLandmark.RIGHT_ELBOW.value,
#                      PoseLandmark.RIGHT_WRIST.value),
#     "LEFT_SHOULDER":  (PoseLandmark.LEFT_ELBOW.value,
#                        PoseLandmark.LEFT_SHOULDER.value,
#                        PoseLandmark.LEFT_HIP.value),
#     "RIGHT_SHOULDER": (PoseLandmark.RIGHT_ELBOW.value,
#                        PoseLandmark.RIGHT_SHOULDER.value,
#                        PoseLandmark.RIGHT_HIP.value),
#     "LEFT_HIP":   (PoseLandmark.LEFT_SHOULDER.value,
#                    PoseLandmark.LEFT_HIP.value,
#                    PoseLandmark.LEFT_KNEE.value),
#     "RIGHT_HIP":  (PoseLandmark.RIGHT_SHOULDER.value,
#                    PoseLandmark.RIGHT_HIP.value,
#                    PoseLandmark.RIGHT_KNEE.value),
#     "LEFT_KNEE":  (PoseLandmark.LEFT_HIP.value,
#                   PoseLandmark.LEFT_KNEE.value,
#                   PoseLandmark.LEFT_ANKLE.value),
#     "RIGHT_KNEE": (PoseLandmark.RIGHT_HIP.value,
#                   PoseLandmark.RIGHT_KNEE.value,
#                   PoseLandmark.RIGHT_ANKLE.value),
#     "SPINE": (
#        PoseLandmark.LEFT_HIP.value,       
#        PoseLandmark.LEFT_SHOULDER.value,   
#        PoseLandmark.RIGHT_SHOULDER.value   
#     ),
#     "HEAD": (
#        PoseLandmark.LEFT_SHOULDER.value,
#        PoseLandmark.NOSE.value,
#        PoseLandmark.RIGHT_SHOULDER.value
#     ),
#      "LEFT_WRIST":  (
#         PoseLandmark.LEFT_ELBOW.value,
#         PoseLandmark.LEFT_WRIST.value,
#         PoseLandmark.LEFT_INDEX.value
#     ),
#     "RIGHT_WRIST": (
#         PoseLandmark.RIGHT_ELBOW.value,
#         PoseLandmark.RIGHT_WRIST.value,
#         PoseLandmark.RIGHT_INDEX.value
#     ),
#     "LEFT_ANKLE":  (
#         PoseLandmark.LEFT_KNEE.value,
#         PoseLandmark.LEFT_ANKLE.value,
#         PoseLandmark.LEFT_FOOT_INDEX.value
#     ),
#     "RIGHT_ANKLE": (
#         PoseLandmark.RIGHT_KNEE.value,
#         PoseLandmark.RIGHT_ANKLE.value,
#         PoseLandmark.RIGHT_FOOT_INDEX.value
#     ),
    
# }
# ERR_JOINTS = list(JOINT_TRIPLETS.keys())
# N_ERR = len(ERR_JOINTS)  # 14

# # 2. load original metadata & keypoints -----------------------
# DATA_ROOT    = Path("Data-REHAB24-6")
# KEYPT_ROOT   = DATA_ROOT/"mp_keypoints"
# META_ORIG    = DATA_ROOT/"Segmentation.xlsx"
# df           = pd.read_excel(META_ORIG, engine="openpyxl")
# df.columns   = df.columns.str.strip()

# # 3. compute ideal_angles on correct reps ----------------------
# ideal_angles = {}
# correct = df[df.correctness==1]
# for ex in correct.exercise_id.unique():
#     all_ang = {jn:[] for jn in ERR_JOINTS}
#     for _,r in correct[correct.exercise_id==ex].iterrows():
#         vid, f0, f1 = r.video_id, int(r.first_frame), int(r.last_frame)
#         files = list((KEYPT_ROOT/f"Ex{ex}").glob(f"{vid}-Camera17*-mp.npy"))
#         if not files: continue
#         arr = np.load(files[0])
#         seg = arr[f0:f1] if f1>f0 else arr[f0:]
#         if len(seg)==0: continue
#         mid = len(seg)//2
#         frm = seg[mid]
#         for jn in ERR_JOINTS:
#             ia,ib,ic = JOINT_TRIPLETS[jn]
#             ang = angle_between(frm[ia,:2],frm[ib,:2],frm[ic,:2])
#             all_ang[jn].append(ang)
#     # median
#     ideal_angles[ex] = {jn:float(np.median(all_ang[jn])) for jn in all_ang if all_ang[jn]}

# # 4. slide windows & write rows --------------------------------
# WINDOW, STRIDE = 16, 8
# rows = []
# for _,r in df.iterrows():
#     vid, ex, f0, f1 = r.video_id, int(r.exercise_id), int(r.first_frame), int(r.last_frame)
#     files = list((KEYPT_ROOT/f"Ex{ex}").glob(f"{vid}-Camera17*-mp.npy"))
#     if not files: continue
#     arr = np.load(files[0])                # (F,33,3)
#     seg = arr[f0:f1] if f1>f0 else arr[f0:]
#     if len(seg)<WINDOW: continue

#     # per-frame errors
#     pf_err = {jn:[] for jn in ERR_JOINTS}
#     for frm in seg:
#         for jn in ERR_JOINTS:
#             ia,ib,ic = JOINT_TRIPLETS[jn]
#             ang = angle_between(frm[ia,:2],frm[ib,:2],frm[ic,:2])
#             pf_err[jn].append(ang - ideal_angles[ex].get(jn,ang))

#     # slide
#     for start in range(0, len(seg)-WINDOW+1, STRIDE):
#         w = np.array([ pf_err[jn][start:start+WINDOW] for jn in ERR_JOINTS ])  # (10,WINDOW)
#         mean_err = w.mean(axis=1)
#         row = {
#             "video_id":vid,
#             "exercise_id":ex,
#             "repetition_number":r.repetition_number,
#             "window_start": f0+start,
#             "window_end":   f0+start+WINDOW,
#             "correctness":  r.correctness
#         }
#         for i,jn in enumerate(ERR_JOINTS):
#             row[f"err_{i}"] = float(mean_err[i])
#         rows.append(row)

# win_df = pd.DataFrame(rows)
# win_df.to_csv(DATA_ROOT/"Segmentation_windows.csv", index=False)
# print("Wrote", len(win_df), "windows to Segmentation_windows.csv")


# 4. Traning Setup

## 4.1 Paths & device

In [9]:

SCRIPT_DIR    = Path().resolve()
DATA_ROOT     = SCRIPT_DIR/"Data-REHAB24-6"
WIN_CSV       = DATA_ROOT/"Segmentation_windows.csv"
KEYPT_ROOT    = DATA_ROOT/"mp_keypoints"

DEVICE = (
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("cpu")
)
print("► Using device:", DEVICE)

► Using device: mps


## 4.2 Analysing the dataset

In [10]:
# --- Specify the file to analyze ---
# Use the file that was generated in the previous step (after modification)
DATA_FILENAME = "Segmentation.xlsx"
FILE_TO_ANALYZE = DATA_ROOT / DATA_FILENAME

print(f"Attempting to load data from: {FILE_TO_ANALYZE}")

# --- Check if file exists and process ---
if not FILE_TO_ANALYZE.is_file():
    print(f"Error: The file '{FILE_TO_ANALYZE.name}' was not found in '{DATA_ROOT}'.")
    print("Please ensure the previous script ran successfully and saved the file.")
else:
    try:
        # Load the dataframe
        df = pd.read_excel(FILE_TO_ANALYZE)
        print(f"Successfully loaded '{FILE_TO_ANALYZE.name}'. Shape: {df.shape}")

        # Check if the 'correctness' column exists
        if 'correctness' not in df.columns:
            print(f"Error: 'correctness' column not found in the dataframe.")
            print(f"Available columns are: {df.columns.tolist()}")
        else:
            print("\nAnalyzing 'correctness' column distribution...")

            # Get the counts for each value in the 'correctness' column
            class_counts = df['correctness'].value_counts()

            # --- Report the counts ---
            print("\nClass Counts:")
            print(class_counts)

            # Provide a more descriptive output
            correct_count = class_counts.get(1, 0) # Get count for value 1, default to 0 if not present
            wrong_count = class_counts.get(0, 0)   # Get count for value 0, default to 0 if not present
            total_count = correct_count + wrong_count # Or use len(df) if there are only 0s and 1s

            print(f"\nNumber of 'Correct' (1) instances: {correct_count}")
            print(f"Number of 'Wrong' (0) instances:   {wrong_count}")
            print(f"Total analyzed instances:        {total_count}") # Good sanity check

            # --- Assess Balance ---
            if total_count > 0:
                correct_percentage = (correct_count / total_count) * 100
                wrong_percentage = (wrong_count / total_count) * 100
                print(f"\nPercentage 'Correct': {correct_percentage:.2f}%")
                print(f"Percentage 'Wrong':   {wrong_percentage:.2f}%")

                # Basic balance assessment (you can adjust the threshold)
                if abs(correct_percentage - wrong_percentage) < 20: # e.g., less than 20% difference (60/40 split)
                    print("\nThe dataset appears relatively balanced based on this threshold.")
                elif abs(correct_percentage - wrong_percentage) < 40: # e.g., less than 40% difference (70/30 split)
                     print("\nThe dataset shows moderate imbalance.")
                else:
                    print("\nWARNING: The dataset appears significantly imbalanced.")
            else:
                print("\nCannot assess balance: No instances found in the 'correctness' column.")

    except Exception as e:
        print(f"\nAn error occurred during file loading or analysis: {e}")

Attempting to load data from: /Users/jithinkrishnan/Documents/Study/IS06 /MVP/RehabApp/model-training-scripts/Data-REHAB24-6/Segmentation.xlsx
Successfully loaded 'Segmentation.xlsx'. Shape: (1072, 13)

Analyzing 'correctness' column distribution...

Class Counts:
correctness
0    751
1    321
Name: count, dtype: int64

Number of 'Correct' (1) instances: 321
Number of 'Wrong' (0) instances:   751
Total analyzed instances:        1072

Percentage 'Correct': 29.94%
Percentage 'Wrong':   70.06%



### Note: Balacing of the dataset will be done duting model training

## 4.3 Joint names setup

In [11]:
PoseLandmark = mp.solutions.pose.PoseLandmark

# Then:
JOINT_NAMES = [lm.name for lm in PoseLandmark]
N_JOINTS    = len(JOINT_NAMES)  # should be 33

print(f"JOINT_NAMES: {JOINT_NAMES}")
print(f"N_JOINTS: {N_JOINTS}")

#  Exerciseses (Ex1…Ex6)
NUM_EXERCISES = 6
CKPT_FILE     = "kp_pose_quality_windows_ex.pt"  

ERR_JOINTS   = [
  "LEFT_ELBOW","RIGHT_ELBOW",
  "LEFT_SHOULDER","RIGHT_SHOULDER",
  "LEFT_HIP","RIGHT_HIP",
  "LEFT_KNEE","RIGHT_KNEE",
  "SPINE","HEAD",
  "LEFT_WRIST", "RIGHT_WRIST",
  "LEFT_ANKLE", "RIGHT_ANKLE"
]
N_ERR = len(ERR_JOINTS)   # 14
ERR_COLS = [f"err_{i}" for i in range(N_ERR)]


JOINT_NAMES: ['NOSE', 'LEFT_EYE_INNER', 'LEFT_EYE', 'LEFT_EYE_OUTER', 'RIGHT_EYE_INNER', 'RIGHT_EYE', 'RIGHT_EYE_OUTER', 'LEFT_EAR', 'RIGHT_EAR', 'MOUTH_LEFT', 'MOUTH_RIGHT', 'LEFT_SHOULDER', 'RIGHT_SHOULDER', 'LEFT_ELBOW', 'RIGHT_ELBOW', 'LEFT_WRIST', 'RIGHT_WRIST', 'LEFT_PINKY', 'RIGHT_PINKY', 'LEFT_INDEX', 'RIGHT_INDEX', 'LEFT_THUMB', 'RIGHT_THUMB', 'LEFT_HIP', 'RIGHT_HIP', 'LEFT_KNEE', 'RIGHT_KNEE', 'LEFT_ANKLE', 'RIGHT_ANKLE', 'LEFT_HEEL', 'RIGHT_HEEL', 'LEFT_FOOT_INDEX', 'RIGHT_FOOT_INDEX']
N_JOINTS: 33


## 4.4 Dataset class definition

The KeypointWindowDataset class loads and processes pose keypoint data from videos for model training. It reads a CSV file containing metadata, including video IDs, exercise IDs, frame indices, and pre-calculated joint angle errors (ranging from 0 to 9). The data is sorted based on video ID, repetition number, and window start. For each sample, it loads the corresponding keypoint data (in .npy format), extracts a segment of frames based on the start and end indices, reshapes the keypoints into a 2D array, and converts them into a PyTorch tensor. It also retrieves the correctness label and the pre-calculated error values, which are stored in tensors. This class efficiently loads and processes the data in batches for training tasks like exercise recognition, where both pose keypoints and error features are used for supervised learning.

In [12]:

class KeypointWindowDataset(Dataset):
    def __init__(self, csv_file: Path, keypt_root: Path):
        df = pd.read_csv(csv_file)
        df = df.sort_values(["video_id","repetition_number","window_start"])
        self.rows = df.to_dict("records")
        self.keypt_root = keypt_root

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i: int):
        r   = self.rows[i]
        ex  = int(r["exercise_id"]) - 1      # zero‐based [0..NUM_EXERCISES-1]
        vid = r["video_id"]
        f0, f1 = int(r["window_start"]), int(r["window_end"])

        # load keypoints
        arr = np.load(
            next((self.keypt_root/f"Ex{ex+1}").glob(f"{vid}-Camera17*-mp.npy"))
        )  # shape (F,33,3)

        seg = arr[f0:f1]            # (T, 33, 3)
        seg = seg.reshape(len(seg), -1)  # (T, 99)
        seq = torch.from_numpy(seg).float()

        label = torch.tensor(r["correctness"], dtype=torch.long)
        err   = torch.tensor([r[f"err_{j}"] for j in range(N_ERR)],
                             dtype=torch.float32)

        return seq, label, err, ex

## 4.4 Model definitions   

1. KeypointEncoder Class: Feature Extraction  
The KeypointEncoder class is responsible for extracting feature representations from the input keypoint data. It uses two 1D convolutional layers (conv1 and conv2) to process the input sequence of keypoints. The input tensor, which represents keypoints for each frame in a video, is passed through these convolutional layers after being reshaped to fit the 1D convolution. Each convolution layer is followed by a ReLU activation function to introduce non-linearity. The final step of the encoder involves an adaptive average pooling (pool), which reduces the feature map to a single value per feature channel. This results in a compact representation of the keypoint sequence, which is then passed forward for further processing.  

2. PoseQualityNetKP Class: Overview  
The PoseQualityNetKP class is the main model used for pose quality assessment. It integrates the KeypointEncoder to process the raw keypoint data and extracts meaningful features. The model then uses an LSTM (Long Short-Term Memory) network to learn the temporal dependencies between the keypoint sequences. The LSTM consists of two bidirectional layers, allowing the model to capture information from both past and future frames in the sequence. The LSTM outputs a sequence of hidden states, which are averaged across the time dimension to produce a fixed-size feature vector representing the entire sequence of frames. This vector, along with the exercise embedding, is used to make predictions.  

3. Exercise Embedding and Final Layers  
In addition to the keypoint features, the model incorporates an exercise-specific embedding to capture the variations between different exercises. The ex_emb layer processes the one-hot encoded exercise ID into a dense representation. This embedding is passed through a small multi-layer perceptron (MLP) that reduces the embedding size, enabling the model to focus on the most important characteristics of each exercise. The concatenation of the temporal features from the LSTM and the exercise embedding forms the final input to the classification and error prediction heads. These final heads, cls_head and err_head, are fully connected layers that output the classification of the exercise and the pose errors, respectively.  

In [13]:
# 5. Model definitions
class KeypointEncoder(nn.Module):
    def __init__(self, in_dim:int, embed:int=512):
        super().__init__()
        self.conv1 = nn.Conv1d(in_dim, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(128, embed, kernel_size=3, padding=1)
        self.pool  = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, D); treat as (B, D, 1) for Conv1d
        x = x.unsqueeze(2)                 # → (B, D, 1)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return self.pool(x).squeeze(-1)    # → (B, embed)

class PoseQualityNetKP(nn.Module):
    def __init__(self,
                 in_dim: int,
                 num_ex: int,
                 hidden: int = 256,
                 ex_emb: int = 64):
        super().__init__()
        # keypoint feature extractor
        self.encoder = KeypointEncoder(in_dim)

        # sequence model
        self.lstm = nn.LSTM(
            input_size=512,
            hidden_size=hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        feat_dim = hidden * 2

        # exercise embedding MLP
        self.ex_emb = nn.Sequential(
            nn.Linear(num_ex, ex_emb),
            nn.ReLU(),
            nn.Linear(ex_emb, ex_emb)
        )

        # final heads
        self.cls_head = nn.Linear(feat_dim + ex_emb, 2)
        self.err_head = nn.Linear(feat_dim + ex_emb, N_ERR)

    def forward(self,
                seq:     torch.Tensor,  # (B, T, D)
                ex_1hot: torch.Tensor   # (B, num_ex)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # 1) keypoint → sequence feats
        # encode each frame
        B,T,_ = seq.shape
        feats = torch.stack([
            self.encoder(seq[:,t]) for t in range(T)
        ], dim=1)                                # (B, T, 512)
        out, _ = self.lstm(feats)                # (B, T, 2*hidden)
        g = out.mean(1)                          # (B, 2*hidden)

        # 2) exercise embed
        ex_e = self.ex_emb(ex_1hot)              # (B, ex_emb)

        # 3) concat and heads
        h = torch.cat([g, ex_e], dim=1)          # (B, feat_dim+ex_emb)
        return self.cls_head(h), self.err_head(h)


# 5. Model Training  

The PoseQualityNetKP model is engineered for pose quality assessment, leveraging keypoint data to evaluate the correctness of exercise movements. It comprises a KeypointEncoder to extract robust features from keypoint sequences, a bidirectional LSTM to model temporal relationships across frames, an exercise-specific embedding layer to account for exercise variations, and two output heads for classification and joint error prediction.

### Model Architecture and Processing

The model processes input keypoint sequences and one-hot encoded exercise IDs. The KeypointEncoder applies 1D convolutional layers to each frame of the keypoint sequence, generating per-frame feature representations. These features are fed into a bidirectional LSTM, which captures both forward and backward temporal dependencies, producing a sequence of contextualized features. The LSTM outputs are averaged over time to yield a fixed-size feature vector for the entire sequence.

Simultaneously, the exercise-specific embedding MLP transforms the one-hot encoded exercise IDs into a dense embedding, capturing exercise-specific characteristics. The temporal features from the LSTM and the exercise embeddings are concatenated and passed to two heads:  
- Classification Head (cls_head): Predicts whether the exercise repetition is "Correct" (1) or "Wrong" (0).
- Error Head (err_head): Estimates pose errors for key joints by predicting deviations from ideal joint angles.

### Training Process

The model is trained to optimize two loss functions:
- Cross-Entropy Loss (loss_cls): Used for binary classification of exercise correctness. To address the significant class imbalance (70.06% "Wrong" vs. 29.94% "Correct"), the loss is weighted inversely proportional to class frequencies, assigning higher penalties to misclassifications of the minority "Correct" class.
- Smooth L1 Loss (loss_err): Minimizes the difference between predicted and ground-truth joint angle errors, weighted by a factor of 0.1 to balance its contribution relative to the classification loss.

The training data is processed using the KeypointWindowDataset, which loads keypoint sequences and labels from a CSV file and keypoint directory. To further mitigate class imbalance, the training DataLoader employs a WeightedRandomSampler to oversample the minority "Correct" class, ensuring balanced exposure to both classes during training. The validation set remains unsampled to provide unbiased performance metrics.

The Adam optimizer, with a learning rate of 1e-4, updates model parameters to minimize the combined loss. Training proceeds over multiple epochs, with each epoch processing batches of data. During validation, the model is evaluated on a separate validation set using metrics such as accuracy, weighted F1-score, per-class precision and recall (for both "Wrong" and "Correct" classes), and mean absolute error (MAE) for joint error predictions. Per-class metrics help monitor performance on the minority "Correct" class, which is critical due to the imbalance.

### Model Selection and Saving

The model with the highest weighted F1-score on the validation set is saved, along with its state dictionary, ensuring the best-performing model is retained for deployment. This training framework, enhanced with class-weighted loss and oversampling, enables robust pose quality assessment, improving classification accuracy for both classes and providing precise joint error analysis for exercise feedback.  

In [14]:

def train_epochs(
    csv_file: str = str(WIN_CSV),
    keypt_root: str = str(KEYPT_ROOT),
    num_ex: int = NUM_EXERCISES,
    epochs: int = 30,
    batch: int = 16,
    lr: float = 1e-4,
    ckpt_file: str = CKPT_FILE
):
    # Build dataset and split
    ds = KeypointWindowDataset(Path(csv_file), Path(keypt_root))
    N = len(ds)
    idx = np.arange(N)
    np.random.shuffle(idx)
    c1, c2 = int(0.7 * N), int(0.85 * N)
    train_idx, val_idx = idx[:c1], idx[c1:c2]

    # --- Oversampling the Minority Class ---
    # Get labels for training indices
    train_labels = [ds.rows[i]["correctness"] for i in train_idx]
    # Compute class weights for oversampling
    class_counts = np.bincount(train_labels)  # [count_0, count_1]
    num_samples = len(train_labels)
    weights = np.zeros(num_samples)
    for i, label in enumerate(train_labels):
        weights[i] = num_samples / (len(class_counts) * class_counts[label])  # Inverse frequency
    sampler = WeightedRandomSampler(weights=weights, num_samples=num_samples, replacement=True)

    # DataLoaders with oversampling for training
    train_dl = DataLoader(
        Subset(ds, train_idx),
        batch_size=batch,
        sampler=sampler  # Use sampler instead of shuffle
    )
    val_dl = DataLoader(
        Subset(ds, val_idx),
        batch_size=batch,
        shuffle=False
    )

    # Infer input dimension
    sample_seq, _, _, _ = ds[0]
    in_dim = sample_seq.shape[-1]

    # Build model
    model = PoseQualityNetKP(in_dim, num_ex).to(DEVICE)
    
    # --- Class-Weighted Loss ---
    # Assign higher weight to minority class (1: Correct)
    class_weights = torch.tensor([1.0, class_counts[0] / class_counts[1]], dtype=torch.float32).to(DEVICE)  # Weight for [0, 1]
    loss_cls = nn.CrossEntropyLoss(weight=class_weights)
    loss_err = nn.SmoothL1Loss()
    opt = Adam(model.parameters(), lr)

    best_f1 = 0.0
    for epoch in range(1, epochs + 1):
        # -- Train --
        model.train()
        tot_loss = 0.0
        for seq, y, err, ex in tqdm(train_dl, desc=f"Epoch {epoch:02d}"):
            seq, y, err, ex = [x.to(DEVICE) for x in (seq, y, err, ex)]
            ex_1hot = F.one_hot(ex, num_ex).float()

            opt.zero_grad()
            logits, err_hat = model(seq, ex_1hot)
            loss = loss_cls(logits, y) + 0.1 * loss_err(err_hat, err)
            loss.backward()
            opt.step()

            tot_loss += loss.item() * y.size(0)
        print(f"  ↳ train loss: {tot_loss/len(train_idx):.4f}")

        # -- Validation --
        model.eval()
        y_true, y_pred, errs = [], [], []
        precision, recall = [], []
        with torch.no_grad():
            for seq, y, err, ex in val_dl:
                seq, y, err, ex = [x.to(DEVICE) for x in (seq, y, err, ex)]
                ex_1hot = F.one_hot(ex, num_ex).float()
                logits, err_hat = model(seq, ex_1hot)

                y_true += y.cpu().tolist()
                y_pred += logits.argmax(1).cpu().tolist()
                errs += [(err_hat - err.to(DEVICE)).abs().mean(1)]
                
                # Per-class precision and recall
                precision += [precision_score(y.cpu(), logits.argmax(1).cpu(), average=None, labels=[0, 1])]
                recall += [recall_score(y.cpu(), logits.argmax(1).cpu(), average=None, labels=[0, 1])]

        # Calculate metrics
        acc = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average='weighted')
        mae = torch.cat(errs).mean().item()
        precision_mean = np.mean(precision, axis=0)  # [prec_0, prec_1]
        recall_mean = np.mean(recall, axis=0)       # [recall_0, recall_1]

        print(f"  ↳ val acc {acc:.3f}, Precision (0,1) {precision_mean[0]:.3f},{precision_mean[1]:.3f}, "
              f"Recall (0,1) {recall_mean[0]:.3f},{recall_mean[1]:.3f}, F1 {f1:.3f}, MAE° {mae:.2f}")

        # Save if improved
        if f1 > best_f1:
            best_f1 = f1
            torch.save(model, ckpt_file)
            state_dict_path = Path(ckpt_file).with_suffix('.pth')
            torch.save(model.state_dict(), state_dict_path)
            print(f"  ✓ Saved best model to {ckpt_file} and state_dict to {state_dict_path} (F1 {f1:.3f})")

train_epochs(epochs=40, batch=16, lr=1e-4)

Epoch 01: 100%|██████████| 629/629 [00:13<00:00, 47.21it/s]


  ↳ train loss: 2.2208
  ↳ val acc 0.740, Precision (0,1) 1.000,0.543, Recall (0,1) 0.621,1.000, F1 0.748, MAE° 17.67
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.748)


Epoch 02: 100%|██████████| 629/629 [00:12<00:00, 49.20it/s]


  ↳ train loss: 1.9653
  ↳ val acc 0.745, Precision (0,1) 1.000,0.547, Recall (0,1) 0.628,1.000, F1 0.753, MAE° 17.13
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.753)


Epoch 03: 100%|██████████| 629/629 [00:12<00:00, 48.71it/s]


  ↳ train loss: 1.9391
  ↳ val acc 0.769, Precision (0,1) 0.994,0.572, Recall (0,1) 0.668,0.989, F1 0.776, MAE° 16.87
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.776)


Epoch 04: 100%|██████████| 629/629 [00:13<00:00, 48.37it/s]


  ↳ train loss: 1.8953
  ↳ val acc 0.781, Precision (0,1) 0.979,0.585, Recall (0,1) 0.696,0.961, F1 0.789, MAE° 16.48
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.789)


Epoch 05: 100%|██████████| 629/629 [00:13<00:00, 47.10it/s]


  ↳ train loss: 1.8416
  ↳ val acc 0.790, Precision (0,1) 0.998,0.595, Recall (0,1) 0.697,0.997, F1 0.798, MAE° 15.92
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.798)


Epoch 06: 100%|██████████| 629/629 [00:13<00:00, 47.53it/s]


  ↳ train loss: 1.7538
  ↳ val acc 0.794, Precision (0,1) 1.000,0.599, Recall (0,1) 0.699,1.000, F1 0.801, MAE° 15.09
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.801)


Epoch 07: 100%|██████████| 629/629 [00:12<00:00, 49.23it/s]


  ↳ train loss: 1.6621
  ↳ val acc 0.783, Precision (0,1) 1.000,0.587, Recall (0,1) 0.683,1.000, F1 0.790, MAE° 14.20


Epoch 08: 100%|██████████| 629/629 [00:12<00:00, 49.45it/s]


  ↳ train loss: 1.5865
  ↳ val acc 0.798, Precision (0,1) 1.000,0.608, Recall (0,1) 0.706,1.000, F1 0.805, MAE° 13.24
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.805)


Epoch 09: 100%|██████████| 629/629 [00:12<00:00, 48.76it/s]


  ↳ train loss: 1.4740
  ↳ val acc 0.815, Precision (0,1) 0.996,0.628, Recall (0,1) 0.732,0.994, F1 0.821, MAE° 12.34
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.821)


Epoch 10: 100%|██████████| 629/629 [00:12<00:00, 49.03it/s]


  ↳ train loss: 1.4107
  ↳ val acc 0.826, Precision (0,1) 0.998,0.641, Recall (0,1) 0.748,0.996, F1 0.832, MAE° 11.88
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.832)


Epoch 11: 100%|██████████| 629/629 [00:12<00:00, 48.91it/s]


  ↳ train loss: 1.3231
  ↳ val acc 0.848, Precision (0,1) 0.997,0.671, Recall (0,1) 0.781,0.996, F1 0.854, MAE° 11.10
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.854)


Epoch 12: 100%|██████████| 629/629 [00:12<00:00, 48.74it/s]


  ↳ train loss: 1.2539
  ↳ val acc 0.855, Precision (0,1) 0.996,0.683, Recall (0,1) 0.790,0.989, F1 0.860, MAE° 10.55
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.860)


Epoch 13: 100%|██████████| 629/629 [00:12<00:00, 48.68it/s]


  ↳ train loss: 1.1801
  ↳ val acc 0.854, Precision (0,1) 0.995,0.679, Recall (0,1) 0.791,0.990, F1 0.859, MAE° 10.07


Epoch 14: 100%|██████████| 629/629 [00:12<00:00, 48.93it/s]


  ↳ train loss: 1.1489
  ↳ val acc 0.846, Precision (0,1) 0.995,0.669, Recall (0,1) 0.777,0.990, F1 0.851, MAE° 9.63


Epoch 15: 100%|██████████| 629/629 [00:12<00:00, 48.72it/s]


  ↳ train loss: 1.0958
  ↳ val acc 0.873, Precision (0,1) 0.995,0.712, Recall (0,1) 0.817,0.989, F1 0.877, MAE° 9.32
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.877)


Epoch 16: 100%|██████████| 629/629 [04:08<00:00,  2.53it/s]  


  ↳ train loss: 1.0485
  ↳ val acc 0.869, Precision (0,1) 0.983,0.712, Recall (0,1) 0.822,0.968, F1 0.873, MAE° 9.12


Epoch 17: 100%|██████████| 629/629 [00:12<00:00, 49.36it/s]


  ↳ train loss: 1.0329
  ↳ val acc 0.882, Precision (0,1) 0.997,0.726, Recall (0,1) 0.828,0.991, F1 0.885, MAE° 8.83
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.885)


Epoch 18: 100%|██████████| 629/629 [00:12<00:00, 49.52it/s]


  ↳ train loss: 0.9677
  ↳ val acc 0.859, Precision (0,1) 0.989,0.691, Recall (0,1) 0.802,0.977, F1 0.864, MAE° 8.56


Epoch 19: 100%|██████████| 629/629 [00:12<00:00, 49.44it/s]


  ↳ train loss: 0.9704
  ↳ val acc 0.886, Precision (0,1) 0.997,0.727, Recall (0,1) 0.835,0.997, F1 0.889, MAE° 8.42
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.889)


Epoch 20: 100%|██████████| 629/629 [00:12<00:00, 49.51it/s]


  ↳ train loss: 0.9361
  ↳ val acc 0.879, Precision (0,1) 0.997,0.715, Recall (0,1) 0.826,0.992, F1 0.883, MAE° 8.12


Epoch 21: 100%|██████████| 629/629 [00:12<00:00, 49.13it/s]


  ↳ train loss: 0.9014
  ↳ val acc 0.892, Precision (0,1) 0.996,0.742, Recall (0,1) 0.846,0.992, F1 0.895, MAE° 7.99
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.895)


Epoch 22: 100%|██████████| 629/629 [00:13<00:00, 47.00it/s]


  ↳ train loss: 0.8963
  ↳ val acc 0.864, Precision (0,1) 0.997,0.690, Recall (0,1) 0.802,0.997, F1 0.868, MAE° 7.75


Epoch 23: 100%|██████████| 629/629 [00:13<00:00, 47.69it/s]


  ↳ train loss: 0.8753
  ↳ val acc 0.903, Precision (0,1) 0.993,0.763, Recall (0,1) 0.861,0.990, F1 0.905, MAE° 7.67
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.905)


Epoch 24: 100%|██████████| 629/629 [00:13<00:00, 48.38it/s]


  ↳ train loss: 0.8426
  ↳ val acc 0.898, Precision (0,1) 0.986,0.761, Recall (0,1) 0.864,0.970, F1 0.901, MAE° 7.46


Epoch 25: 100%|██████████| 629/629 [00:12<00:00, 48.83it/s]


  ↳ train loss: 0.8273
  ↳ val acc 0.904, Precision (0,1) 0.984,0.772, Recall (0,1) 0.873,0.973, F1 0.907, MAE° 7.46
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.907)


Epoch 26: 100%|██████████| 629/629 [00:13<00:00, 48.37it/s]


  ↳ train loss: 0.8173
  ↳ val acc 0.896, Precision (0,1) 0.998,0.745, Recall (0,1) 0.847,0.997, F1 0.899, MAE° 7.42


Epoch 27: 100%|██████████| 629/629 [00:12<00:00, 48.65it/s]


  ↳ train loss: 0.7942
  ↳ val acc 0.888, Precision (0,1) 0.954,0.771, Recall (0,1) 0.877,0.909, F1 0.890, MAE° 7.04


Epoch 28: 100%|██████████| 629/629 [00:12<00:00, 48.51it/s]


  ↳ train loss: 0.7759
  ↳ val acc 0.908, Precision (0,1) 0.997,0.772, Recall (0,1) 0.870,0.989, F1 0.911, MAE° 6.81
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.911)


Epoch 29: 100%|██████████| 629/629 [00:13<00:00, 48.25it/s]


  ↳ train loss: 0.7535
  ↳ val acc 0.907, Precision (0,1) 0.992,0.771, Recall (0,1) 0.869,0.983, F1 0.909, MAE° 6.79


Epoch 30: 100%|██████████| 629/629 [00:12<00:00, 48.71it/s]


  ↳ train loss: 0.7475
  ↳ val acc 0.920, Precision (0,1) 0.995,0.796, Recall (0,1) 0.886,0.990, F1 0.922, MAE° 6.70
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.922)


Epoch 31: 100%|██████████| 629/629 [00:12<00:00, 48.63it/s]


  ↳ train loss: 0.7286
  ↳ val acc 0.911, Precision (0,1) 0.989,0.781, Recall (0,1) 0.879,0.977, F1 0.913, MAE° 6.46


Epoch 32: 100%|██████████| 629/629 [00:13<00:00, 48.08it/s]


  ↳ train loss: 0.7259
  ↳ val acc 0.924, Precision (0,1) 0.985,0.815, Recall (0,1) 0.900,0.978, F1 0.925, MAE° 6.41
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.925)


Epoch 33: 100%|██████████| 629/629 [00:12<00:00, 48.42it/s]


  ↳ train loss: 0.7035
  ↳ val acc 0.920, Precision (0,1) 0.991,0.803, Recall (0,1) 0.892,0.980, F1 0.922, MAE° 6.20


Epoch 34: 100%|██████████| 629/629 [00:13<00:00, 47.95it/s]


  ↳ train loss: 0.6906
  ↳ val acc 0.895, Precision (0,1) 1.000,0.741, Recall (0,1) 0.845,1.000, F1 0.898, MAE° 6.12


Epoch 35: 100%|██████████| 629/629 [00:13<00:00, 47.90it/s]


  ↳ train loss: 0.6813
  ↳ val acc 0.875, Precision (0,1) 0.999,0.709, Recall (0,1) 0.817,0.998, F1 0.879, MAE° 6.09


Epoch 36: 100%|██████████| 629/629 [00:13<00:00, 47.84it/s]


  ↳ train loss: 0.6688
  ↳ val acc 0.915, Precision (0,1) 0.991,0.783, Recall (0,1) 0.884,0.983, F1 0.917, MAE° 5.96


Epoch 37: 100%|██████████| 629/629 [00:12<00:00, 48.44it/s]


  ↳ train loss: 0.6476
  ↳ val acc 0.918, Precision (0,1) 0.997,0.791, Recall (0,1) 0.883,0.994, F1 0.920, MAE° 5.82


Epoch 38: 100%|██████████| 629/629 [00:13<00:00, 47.64it/s]


  ↳ train loss: 0.6446
  ↳ val acc 0.924, Precision (0,1) 0.990,0.808, Recall (0,1) 0.898,0.977, F1 0.926, MAE° 5.72
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.926)


Epoch 39: 100%|██████████| 629/629 [00:13<00:00, 47.51it/s]


  ↳ train loss: 0.6409
  ↳ val acc 0.929, Precision (0,1) 0.992,0.815, Recall (0,1) 0.903,0.981, F1 0.930, MAE° 5.73
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.930)


Epoch 40: 100%|██████████| 629/629 [00:13<00:00, 47.77it/s]


  ↳ train loss: 0.6150
  ↳ val acc 0.931, Precision (0,1) 0.991,0.826, Recall (0,1) 0.908,0.983, F1 0.933, MAE° 5.63
  ✓ Saved best model to kp_pose_quality_windows_ex.pt and state_dict to kp_pose_quality_windows_ex.pth (F1 0.933)


# 6. Inference Tesing - Correctness and feedback on Live Videos / Recorded Videos 


### How It Works

This script performs real-time pose quality assessment for physical exercises by integrating MediaPipe's Pose estimation with a custom deep learning model, `PoseQualityNetKP`, implemented in PyTorch. The process begins with capturing video input (either from a file or live webcam) using OpenCV (`cv2`). MediaPipe Pose processes each frame to detect 33 human pose landmarks, extracting their 3D coordinates (x, y, z) if visibility exceeds a threshold (0.7). These keypoints are collected into a buffer of 16 frames (`SEQUENCE_LENGTH`) to form a sequence of shape `(1, 16, 99)` (batch, time, 33 joints × 3 coordinates). This sequence is fed into `PoseQualityNetKP`, which consists of a `KeypointEncoder` (using 1D convolutions to encode per-frame keypoints into 512-dimensional embeddings) followed by a bidirectional LSTM (with 2 layers, 256 hidden units per direction) to capture temporal dynamics, producing a 512-dimensional sequence feature (`2 × hidden`). The model also embeds the exercise type (via one-hot encoding and an MLP) and concatenates it with the sequence feature, yielding a final representation. This is passed through two heads: a classification head (`cls_head`) to predict "Correct" or "Incorrect" pose (binary classification) and an error head (`err_head`) to estimate angular deviations for 14 predefined joints (e.g., `LEFT_ELBOW`, `RIGHT_KNEE`). The script ensures robust inference by checking the visibility of required landmarks for each exercise (e.g., special handling for push-ups with `PUSHUP_REQUIRED_LANDMARKS`), halting analysis if key joints are not visible (visibility < 0.7), and displaying messages like "Adjust posture" or "No pose detected" to guide the user.  

The feedback logic is designed to provide actionable insights by leveraging the model's outputs in a structured manner. Once the model classifies a pose as "Incorrect" (predicted class 0), the script examines the angular deviation predictions from the error head (`err_head`), which outputs a vector of 14 values representing the mean angular deviation for each joint in `ERR_JOINTS`. It identifies the joint with the largest absolute deviation using `np.argmax`, but only flags it for correction if the deviation exceeds a threshold degree (`ERROR_WARNING_THRESHOLD`) (Currently set as 5). The suggestion is then formatted as a string (e.g., "Check LEFT_ELBOW (Dev: +3.2°)") and displayed on the video feed for a duration of 90 frames (`SUGGESTION_DURATION_FRAMES = 3 * 30`), ensuring the user has sufficient time to notice and act on the feedback. If the pose is "Correct" (predicted class 1), no joint-specific feedback is shown, and the buffer is maintained to continue monitoring. Additional feedback states, such as "Analysing 5/16" during buffer filling or "World landmarks missing" if 3D coordinates are unavailable, are also displayed to keep the user informed, with all text overlaid on the video using OpenCV in distinct colors (e.g., green for "Correct", red for "Incorrect", orange for warnings) to enhance clarity.

### How to Run (Live or from File)

1. Set Up Environment: Ensure Python is installed, then install required packages by running pip install torch torchvision torchaudio mediapipe opencv-python numpy.  
2. Prepare Files: Place the script in a directory with the Data-REHAB24-6 folder containing video files (e.g., Videos/Ex1/PM_000-Camera17-30fps.mp4) and the checkpoint file kp_pose_quality_windows_ex.pt in the same directory.  
3. Run the Script: Execute the script in a terminal or Jupyter notebook using python pose_inference.py.  
4. Select Exercise: When prompted, enter an exercise ID (1-6) to choose from the available exercises (e.g., 1 for "Arm abduction").
5. Choose Video Source: Enter 0 for file-based inference or 1 for live webcam. If choosing a file, enter the video ID (e.g., 000 for VIDEO_000).
6. View Feedback: The script will process the video, display pose landmarks, and provide real-time feedback ("Correct", "Incorrect", or joint-specific suggestions). Press q to exit.
7. Review Output: Check the terminal for logs and ensure resources are released upon completion.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import mediapipe as mp
import cv2
import numpy as np
from pathlib import Path
from collections import deque

# --- Configuration ---
SCRIPT_DIR = Path().resolve()
DATA_ROOT = SCRIPT_DIR / "Data-REHAB24-6"
CKPT_FILE = SCRIPT_DIR / "kp_pose_quality_windows_ex.pt"

DEVICE = (
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("cpu")
)
print(f"► Using device: {DEVICE}")

PoseLandmark = mp.solutions.pose.PoseLandmark
JOINT_NAMES = [lm.name for lm in PoseLandmark]
N_JOINTS = len(JOINT_NAMES)

EXERCISE_MAP = {
    1: "Arm abduction", 2: "Arm VW", 3: "Push-ups",
    4: "Leg abduction", 5: "Leg lunge", 6: "Squats"
}
NUM_EXERCISES = len(EXERCISE_MAP)

JOINT_TRIPLETS = {
    "LEFT_ELBOW":   (PoseLandmark.LEFT_SHOULDER, PoseLandmark.LEFT_ELBOW, PoseLandmark.LEFT_WRIST),
    "RIGHT_ELBOW":  (PoseLandmark.RIGHT_SHOULDER, PoseLandmark.RIGHT_ELBOW, PoseLandmark.RIGHT_WRIST),
    "LEFT_SHOULDER":  (PoseLandmark.LEFT_ELBOW, PoseLandmark.LEFT_SHOULDER, PoseLandmark.LEFT_HIP),
    "RIGHT_SHOULDER": (PoseLandmark.RIGHT_ELBOW, PoseLandmark.RIGHT_SHOULDER, PoseLandmark.RIGHT_HIP),
    "LEFT_HIP":   (PoseLandmark.LEFT_SHOULDER, PoseLandmark.LEFT_HIP, PoseLandmark.LEFT_KNEE),
    "RIGHT_HIP":  (PoseLandmark.RIGHT_SHOULDER, PoseLandmark.RIGHT_HIP, PoseLandmark.RIGHT_KNEE),
    "LEFT_KNEE":  (PoseLandmark.LEFT_HIP, PoseLandmark.LEFT_KNEE, PoseLandmark.LEFT_ANKLE),
    "RIGHT_KNEE": (PoseLandmark.RIGHT_HIP, PoseLandmark.RIGHT_KNEE, PoseLandmark.RIGHT_ANKLE),
    "SPINE":      (PoseLandmark.LEFT_HIP, PoseLandmark.LEFT_SHOULDER, PoseLandmark.RIGHT_SHOULDER),
    "HEAD":       (PoseLandmark.LEFT_SHOULDER, PoseLandmark.NOSE, PoseLandmark.RIGHT_SHOULDER),
    "LEFT_WRIST": (PoseLandmark.LEFT_ELBOW, PoseLandmark.LEFT_WRIST, PoseLandmark.LEFT_INDEX),
    "RIGHT_WRIST":(PoseLandmark.RIGHT_ELBOW, PoseLandmark.RIGHT_WRIST, PoseLandmark.RIGHT_INDEX),
    "LEFT_ANKLE": (PoseLandmark.LEFT_KNEE, PoseLandmark.LEFT_ANKLE, PoseLandmark.LEFT_FOOT_INDEX),
    "RIGHT_ANKLE":(PoseLandmark.RIGHT_KNEE, PoseLandmark.RIGHT_ANKLE, PoseLandmark.RIGHT_FOOT_INDEX),
}
ERR_JOINTS = list(JOINT_TRIPLETS.keys())
N_ERR = len(ERR_JOINTS)

REQUIRED_LANDMARK_INDICES_FOR_ERRORS = set()
for joint_name, landmarks in JOINT_TRIPLETS.items():
    for landmark in landmarks:
        REQUIRED_LANDMARK_INDICES_FOR_ERRORS.add(landmark.value)
print(f"ℹ️ Indices required for error angle visibility check: {sorted(list(REQUIRED_LANDMARK_INDICES_FOR_ERRORS))}")
N_REQUIRED_LANDMARKS = len(REQUIRED_LANDMARK_INDICES_FOR_ERRORS)

PUSHUP_JOINT_TRIPLETS = { 
    "RIGHT_ELBOW": (PoseLandmark.RIGHT_SHOULDER, PoseLandmark.RIGHT_ELBOW, PoseLandmark.RIGHT_WRIST),
    "RIGHT_SHOULDER": (PoseLandmark.RIGHT_ELBOW, PoseLandmark.RIGHT_SHOULDER, PoseLandmark.RIGHT_HIP),
    "SPINE": (PoseLandmark.LEFT_HIP, PoseLandmark.LEFT_SHOULDER, PoseLandmark.RIGHT_SHOULDER),
}
PUSHUP_REQUIRED_LANDMARKS = set()
for landmarks in PUSHUP_JOINT_TRIPLETS.values():
    for landmark in landmarks:
        PUSHUP_REQUIRED_LANDMARKS.add(landmark.value)

VISIBILITY_THRESHOLD = 0.8

# --- Model Definitions ---
# 5. Model definitions
class KeypointEncoder(nn.Module):
    def __init__(self, in_dim:int, embed:int=512):
        super().__init__()
        self.conv1 = nn.Conv1d(in_dim, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(128, embed, kernel_size=3, padding=1)
        self.pool  = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, D); treat as (B, D, 1) for Conv1d
        x = x.unsqueeze(2)                 # → (B, D, 1)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return self.pool(x).squeeze(-1)    # → (B, embed)

class PoseQualityNetKP(nn.Module):
    def __init__(self,
                 in_dim: int,
                 num_ex: int,
                 hidden: int = 256,
                 ex_emb: int = 64):
        super().__init__()
        # keypoint feature extractor
        self.encoder = KeypointEncoder(in_dim)

        # sequence model
        self.lstm = nn.LSTM(
            input_size=512,
            hidden_size=hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        feat_dim = hidden * 2

        # exercise embedding MLP
        self.ex_emb = nn.Sequential(
            nn.Linear(num_ex, ex_emb),
            nn.ReLU(),
            nn.Linear(ex_emb, ex_emb)
        )

        # final heads
        self.cls_head = nn.Linear(feat_dim + ex_emb, 2)
        self.err_head = nn.Linear(feat_dim + ex_emb, N_ERR)

    def forward(self,
                seq:     torch.Tensor,  # (B, T, D)
                ex_1hot: torch.Tensor   # (B, num_ex)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # 1) keypoint → sequence feats
        # encode each frame
        B,T,_ = seq.shape
        feats = torch.stack([
            self.encoder(seq[:,t]) for t in range(T)
        ], dim=1)                                # (B, T, 512)
        out, _ = self.lstm(feats)                # (B, T, 2*hidden)
        g = out.mean(1)                          # (B, 2*hidden)

        # 2) exercise embed
        ex_e = self.ex_emb(ex_1hot)              # (B, ex_emb)

        # 3) concat and heads
        h = torch.cat([g, ex_e], dim=1)          # (B, feat_dim+ex_emb)
        return self.cls_head(h), self.err_head(h)


# --- User Input for Exercise ---
while True:
    print("\nAvailable Exercises:")
    for id, name in EXERCISE_MAP.items(): print(f"  {id}: {name}")
    try:
        exercise_id_str = input(f"► Enter the exercise ID (1-{NUM_EXERCISES}): ")
        exercise_id = int(exercise_id_str)
        if 1 <= exercise_id <= NUM_EXERCISES:
            exercise_name = EXERCISE_MAP[exercise_id]
            print(f"Selected: {exercise_id} - {exercise_name}")
            break
        else:
            print(f"Invalid ID.")
    except ValueError:
        print("Invalid input.")
    except EOFError:
        print("\nCancelled.")
        exit()

# --- MediaPipe Pose Setup ---
print("⏳ Initializing MediaPipe Pose...")
try:
    mp_pose = mp.solutions.pose
    pose = mp_pose.Pose(
        static_image_mode=False, model_complexity=2, enable_segmentation=False,
        min_detection_confidence=0.8, min_tracking_confidence=0.8
    )
    mp_drawing = mp.solutions.drawing_utils
    print("✅ MediaPipe Pose initialized.")
except Exception as e:
    print(f"❌ Error initializing MediaPipe Pose: {e}")
    exit()

# --- Helper Function: Extract Keypoints ---
def extract_keypoints(frame, exercise_id):
    img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img_rgb.flags.writeable = False
    result = pose.process(img_rgb)
    img_rgb.flags.writeable = True

    world_keypoints = None
    image_landmarks_for_drawing = None
    all_required_visible = False

    required_landmarks = (
        PUSHUP_REQUIRED_LANDMARKS if exercise_id == 3 else REQUIRED_LANDMARK_INDICES_FOR_ERRORS
    )
    n_required = len(required_landmarks)

    if result.pose_landmarks:
        image_landmarks_for_drawing = result.pose_landmarks
        num_visible_required = sum(
            1 for index in required_landmarks
            if index < len(image_landmarks_for_drawing.landmark)
            and image_landmarks_for_drawing.landmark[index].visibility >= VISIBILITY_THRESHOLD
        )
        all_required_visible = num_visible_required == n_required
        if all_required_visible and result.pose_world_landmarks:
            world_landmarks = result.pose_world_landmarks.landmark
            world_keypoints = np.array([(lm.x, lm.y, lm.z) for lm in world_landmarks], dtype=np.float32)

    return world_keypoints, image_landmarks_for_drawing, all_required_visible

# --- Inference Parameters ---
SEQUENCE_LENGTH = 16
IN_DIM = N_JOINTS * 3

# --- Core Inference and Feedback Function ---
def infer_and_feedback(model, video_source, selected_ex_id, selected_ex_name):
    if isinstance(video_source, str) and not Path(video_source).exists():
        print(f"❌ Error: Video file not found: {video_source}")
        return
    try:
        cap = cv2.VideoCapture(video_source)
        if not cap.isOpened():
            print(f"❌ Error: Could not open video source '{video_source}'")
            return
        # --- New: Set and verify webcam frame rate ---
        if isinstance(video_source, int):  # Webcam
            cap.set(cv2.CAP_PROP_FPS, 30)  # Attempt to set to 30 FPS
            webcam_fps = cap.get(cv2.CAP_PROP_FPS) or 30  # Fallback to 30 if not available
            print(f"Webcam FPS: {webcam_fps}")
            target_fps = 30
            frame_interval = max(1, round(webcam_fps / target_fps))  # Frames to skip
            print(f"Frame interval: {frame_interval} (processing every {frame_interval}th frame)")
        else:
            frame_interval = 1  # No skipping for video files (assumed 30 FPS)
        frame_count = 0  # Track frame count for skipping
    except Exception as e:
        print(f"❌ Error opening video source: {e}")
        return

    source_name = "camera" if isinstance(video_source, int) else Path(video_source).name
    print(f"🚀 Starting feedback loop using: {source_name} (Press 'q' to quit)")

    keypoints_buffer = deque(maxlen=SEQUENCE_LENGTH)
    feedback = "Initializing..."
    err_values = np.zeros(N_ERR)
    predicted_class = 0
    suggestion = ""
    suggestion_time = 0
    SUGGESTION_DURATION_FRAMES = 3 * 30

    ERROR_WARNING_THRESHOLD = 10 #Minimum error angle for joint warning display

    FONT_FACE = cv2.FONT_HERSHEY_SIMPLEX
    FONT_SCALE_INFO = 1
    FONT_SCALE_FEEDBACK = 1
    FONT_SCALE_SUGGESTION = 1
    FONT_SCALE_WARNING = 1
    FONT_THICKNESS = 2
    JOINT_WARNING_SPACING = 30

    window_name = 'Pose Estimation Feedback'
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            print("\nℹ️ End of video stream.")
            break

        frame_count += 1
        if frame_count % frame_interval != 0:
            continue  # Skip this frame
        
        world_keypoints, image_landmarks_for_drawing, all_required_visible = extract_keypoints(frame, selected_ex_id)

        if image_landmarks_for_drawing:
            mp_drawing.draw_landmarks(
                frame, image_landmarks_for_drawing, mp_pose.POSE_CONNECTIONS,
                landmark_drawing_spec=mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
                connection_drawing_spec=mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)
            )

        current_feedback_state = "OK"
        if image_landmarks_for_drawing is None:
            current_feedback_state = "NO_POSE"
        elif not all_required_visible:
            current_feedback_state = "ADJUST_POSTURE"
        elif world_keypoints is None and all_required_visible:
            current_feedback_state = "NO_WORLD_LANDMARKS"

        if current_feedback_state == "NO_POSE":
            feedback = "No pose detected"
            keypoints_buffer.clear()
            predicted_class = 0
            suggestion = ""
            suggestion_time = 0
        elif current_feedback_state == "ADJUST_POSTURE":
            feedback = "Adjust posture"
            keypoints_buffer.clear()
            predicted_class = 0
            suggestion = ""
            suggestion_time = 0
        elif current_feedback_state == "NO_WORLD_LANDMARKS":
            feedback = "World landmarks missing"
            keypoints_buffer.clear()
            predicted_class = 0
            suggestion = ""
            suggestion_time = 0
        else:
            keypoints_flat_frame = world_keypoints.flatten()
            keypoints_buffer.append(keypoints_flat_frame)

            if len(keypoints_buffer) == SEQUENCE_LENGTH:
                seq_np = np.array(keypoints_buffer, dtype=np.float32)
                seq = torch.tensor(seq_np, dtype=torch.float32).unsqueeze(0).to(DEVICE)
                ex_tensor = torch.tensor([selected_ex_id - 1], device=DEVICE)
                ex_1hot = F.one_hot(ex_tensor, num_classes=NUM_EXERCISES).float()

                model.eval()
                with torch.no_grad():
                    logits, err_hat = model(seq, ex_1hot)
                    predicted_class = logits.argmax(1).item()
                    err_values = err_hat.squeeze().cpu().numpy()

                feedback = "Correct" if predicted_class == 1 else "Incorrect"

                if predicted_class == 0:
                    if np.any(np.abs(err_values) > 1e-3):
                        max_error_idx = np.argmax(np.abs(err_values))
                        joint_with_error = ERR_JOINTS[max_error_idx]
                        max_error = err_values[max_error_idx]
                        suggestion = f"Check {joint_with_error.replace('_', ' ')} (Dev: {max_error:+.1f}°)"
                        suggestion_time = SUGGESTION_DURATION_FRAMES
                    else:
                        suggestion = "Check Form"
                        suggestion_time = SUGGESTION_DURATION_FRAMES
                else:
                    suggestion = ""
                    suggestion_time = 0
            else:
                feedback = f"Analysing {len(keypoints_buffer)}/{SEQUENCE_LENGTH}"
                suggestion = ""
                suggestion_time = 0
                predicted_class = 0

        info_y = 40
        cv2.putText(frame, f"{selected_ex_name}", (15, info_y), FONT_FACE,
                    FONT_SCALE_INFO, (255, 255, 255), FONT_THICKNESS, cv2.LINE_AA)

        feedback_y = info_y + 50
        if current_feedback_state == "OK" and len(keypoints_buffer) == SEQUENCE_LENGTH:
            feedback_color = (0, 255, 0) if predicted_class == 1 else (0, 0, 255)
        elif current_feedback_state == "ADJUST_POSTURE" or current_feedback_state == "NO_WORLD_LANDMARKS":
            feedback_color = (0, 165, 255)
        else:
            feedback_color = (0, 0, 255)

        cv2.putText(frame, f"{feedback}", (15, feedback_y), FONT_FACE,
                    FONT_SCALE_FEEDBACK, feedback_color, FONT_THICKNESS + 1, cv2.LINE_AA)

        suggestion_y = feedback_y + 50
        if suggestion_time > 0:
            cv2.putText(frame, suggestion, (15, suggestion_y), FONT_FACE,
                        FONT_SCALE_SUGGESTION, (0, 255, 255), FONT_THICKNESS, cv2.LINE_AA)
            suggestion_time -= 1
            feedback_y_start = suggestion_y + 50
        else:
            feedback_y_start = feedback_y + 50

        joint_warning_count = 0
        if current_feedback_state == "OK" and len(keypoints_buffer) == SEQUENCE_LENGTH:
            sorted_error_indices = np.argsort(np.abs(err_values))[::-1]
            for idx in sorted_error_indices:
                warning_y = feedback_y_start + joint_warning_count * JOINT_WARNING_SPACING
                if joint_warning_count >= 4:
                    cv2.putText(frame, "...", (15, warning_y), FONT_FACE, FONT_SCALE_WARNING, (0, 165, 255), FONT_THICKNESS, cv2.LINE_AA)
                    break
                joint_name = ERR_JOINTS[idx]
                err_val = err_values[idx]
                if abs(err_val) > ERROR_WARNING_THRESHOLD:
                    color = (0, 165, 255)
                    text = f"{joint_name}: {err_val:+.0f}°"
                    cv2.putText(frame, text, (15, warning_y), FONT_FACE,
                                FONT_SCALE_WARNING, color, FONT_THICKNESS, cv2.LINE_AA)
                    joint_warning_count += 1

        cv2.imshow(window_name, frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            print("\n'q' pressed. Exiting.")
            break

    cap.release()
    cv2.destroyAllWindows()
    pose.close()
    print("✅ Resources released.")

# --- Main Execution ---
if __name__ == "__main__":
    # Ex1 videos
    VIDEO_000 = DATA_ROOT / "Videos" / "Ex1" / "PM_000-Camera17-30fps.mp4"
    VIDEO_001 = DATA_ROOT / "Videos" / "Ex1" / "PM_001-Camera17-30fps.mp4"
    VIDEO_002 = DATA_ROOT / "Videos" / "Ex1" / "PM_002-Camera17-30fps.mp4"
    VIDEO_012 = DATA_ROOT / "Videos" / "Ex1" / "PM_012-Camera17-30fps.mp4"
    VIDEO_016 = DATA_ROOT / "Videos" / "Ex1" / "PM_016-Camera17-30fps.mp4"
    VIDEO_023 = DATA_ROOT / "Videos" / "Ex1" / "PM_023-Camera17-30fps.mp4"
    VIDEO_024 = DATA_ROOT / "Videos" / "Ex1" / "PM_024-Camera17-30fps.mp4"
    VIDEO_032 = DATA_ROOT / "Videos" / "Ex1" / "PM_032-Camera17-30fps.mp4"
    VIDEO_039 = DATA_ROOT / "Videos" / "Ex1" / "PM_039-Camera17-30fps.mp4"
    VIDEO_100 = DATA_ROOT / "Videos" / "Ex1" / "PM_100-Camera17-30fps.mp4"
    VIDEO_109 = DATA_ROOT / "Videos" / "Ex1" / "PM_109-Camera17-30fps.mp4"
    VIDEO_114 = DATA_ROOT / "Videos" / "Ex1" / "PM_114-Camera17-30fps.mp4"
    VIDEO_122 = DATA_ROOT / "Videos" / "Ex1" / "PM_122-Camera17-30fps.mp4"

    # Ex2 videos
    VIDEO_003 = DATA_ROOT / "Videos" / "Ex2" / "PM_003-Camera17-30fps.mp4"
    VIDEO_004 = DATA_ROOT / "Videos" / "Ex2" / "PM_004-Camera17-30fps.mp4"
    VIDEO_013 = DATA_ROOT / "Videos" / "Ex2" / "PM_013-Camera17-30fps.mp4"
    VIDEO_014 = DATA_ROOT / "Videos" / "Ex2" / "PM_014-Camera17-30fps.mp4"
    VIDEO_025 = DATA_ROOT / "Videos" / "Ex2" / "PM_025-Camera17-30fps.mp4"
    VIDEO_026 = DATA_ROOT / "Videos" / "Ex2" / "PM_026-Camera17-30fps.mp4"
    VIDEO_033 = DATA_ROOT / "Videos" / "Ex2" / "PM_033-Camera17-30fps.mp4"
    VIDEO_040 = DATA_ROOT / "Videos" / "Ex2" / "PM_040-Camera17-30fps.mp4"
    VIDEO_102 = DATA_ROOT / "Videos" / "Ex2" / "PM_102-Camera17-30fps.mp4"
    VIDEO_110 = DATA_ROOT / "Videos" / "Ex2" / "PM_110-Camera17-30fps.mp4"
    VIDEO_115 = DATA_ROOT / "Videos" / "Ex2" / "PM_115-Camera17-30fps.mp4"
    VIDEO_123 = DATA_ROOT / "Videos" / "Ex2" / "PM_123-Camera17-30fps.mp4"

    # Ex3 videos
    VIDEO_010 = DATA_ROOT / "Videos" / "Ex3" / "PM_010-Camera17-30fps.mp4"
    VIDEO_011 = DATA_ROOT / "Videos" / "Ex3" / "PM_011-Camera17-30fps.mp4"
    VIDEO_030 = DATA_ROOT / "Videos" / "Ex3" / "PM_030-Camera17-30fps.mp4"
    VIDEO_031 = DATA_ROOT / "Videos" / "Ex3" / "PM_031-Camera17-30fps.mp4"
    VIDEO_044 = DATA_ROOT / "Videos" / "Ex3" / "PM_044-Camera17-30fps.mp4"
    VIDEO_045 = DATA_ROOT / "Videos" / "Ex3" / "PM_045-Camera17-30fps.mp4"
    VIDEO_107 = DATA_ROOT / "Videos" / "Ex3" / "PM_107-Camera17-30fps.mp4"
    VIDEO_108 = DATA_ROOT / "Videos" / "Ex3" / "PM_108-Camera17-30fps.mp4"
    VIDEO_119 = DATA_ROOT / "Videos" / "Ex3" / "PM_119-Camera17-30fps.mp4"
    VIDEO_121 = DATA_ROOT / "Videos" / "Ex3" / "PM_121-Camera17-30fps.mp4"

    # Ex4 videos
    VIDEO_005 = DATA_ROOT / "Videos" / "Ex4" / "PM_005-Camera17-30fps.mp4"
    VIDEO_006 = DATA_ROOT / "Videos" / "Ex4" / "PM_006-Camera17-30fps.mp4"
    VIDEO_018 = DATA_ROOT / "Videos" / "Ex4" / "PM_018-Camera17-30fps.mp4"
    VIDEO_020 = DATA_ROOT / "Videos" / "Ex4" / "PM_020-Camera17-30fps.mp4"
    VIDEO_027 = DATA_ROOT / "Videos" / "Ex4" / "PM_027-Camera17-30fps.mp4"
    VIDEO_034 = DATA_ROOT / "Videos" / "Ex4" / "PM_034-Camera17-30fps.mp4"
    VIDEO_035 = DATA_ROOT / "Videos" / "Ex4" / "PM_035-Camera17-30fps.mp4"
    VIDEO_041 = DATA_ROOT / "Videos" / "Ex4" / "PM_041-Camera17-30fps.mp4"
    VIDEO_103 = DATA_ROOT / "Videos" / "Ex4" / "PM_103-Camera17-30fps.mp4"
    VIDEO_111 = DATA_ROOT / "Videos" / "Ex4" / "PM_111-Camera17-30fps.mp4"
    VIDEO_116 = DATA_ROOT / "Videos" / "Ex4" / "PM_116-Camera17-30fps.mp4"
    VIDEO_124 = DATA_ROOT / "Videos" / "Ex4" / "PM_124-Camera17-30fps.mp4"

    # Ex5 videos
    VIDEO_021 = DATA_ROOT / "Videos" / "Ex5" / "PM_021-Camera17-30fps.mp4"
    VIDEO_028 = DATA_ROOT / "Videos" / "Ex5" / "PM_028-Camera17-30fps.mp4"
    VIDEO_037 = DATA_ROOT / "Videos" / "Ex5" / "PM_037-Camera17-30fps.mp4"
    VIDEO_042 = DATA_ROOT / "Videos" / "Ex5" / "PM_042-Camera17-30fps.mp4"
    VIDEO_104 = DATA_ROOT / "Videos" / "Ex5" / "PM_104-Camera17-30fps.mp4"
    VIDEO_112 = DATA_ROOT / "Videos" / "Ex5" / "PM_112-Camera17-30fps.mp4"
    VIDEO_117a = DATA_ROOT / "Videos" / "Ex5" / "PM_117a-Camera17-30fps.mp4"
    VIDEO_117b = DATA_ROOT / "Videos" / "Ex5" / "PM_117b-Camera17-30fps.mp4"
    VIDEO_125 = DATA_ROOT / "Videos" / "Ex5" / "PM_125-Camera17-30fps.mp4"

    # Ex6 videos
    VIDEO_008 = DATA_ROOT / "Videos" / "Ex6" / "PM_008-Camera17-30fps.mp4"
    VIDEO_022 = DATA_ROOT / "Videos" / "Ex6" / "PM_022-Camera17-30fps.mp4"
    VIDEO_029 = DATA_ROOT / "Videos" / "Ex6" / "PM_029-Camera17-30fps.mp4"
    VIDEO_038 = DATA_ROOT / "Videos" / "Ex6" / "PM_038-Camera17-30fps.mp4"
    VIDEO_043 = DATA_ROOT / "Videos" / "Ex6" / "PM_043-Camera17-30fps.mp4"
    VIDEO_105 = DATA_ROOT / "Videos" / "Ex6" / "PM_105-Camera17-30fps.mp4"
    VIDEO_113 = DATA_ROOT / "Videos" / "Ex6" / "PM_113-Camera17-30fps.mp4"
    VIDEO_118 = DATA_ROOT / "Videos" / "Ex6" / "PM_118-Camera17-30fps.mp4"
    VIDEO_126 = DATA_ROOT / "Videos" / "Ex6" / "PM_126-Camera17-30fps.mp4"

    # Prompt for video source selection
    while True:
        print("\nPerform inference from:")
        print("  0: File")
        print("  1: Laptop camera (live)")
        try:
            source_choice = input("► Enter choice (0 for file, 1 for laptop camera): ")
            source_choice = int(source_choice)
            if source_choice in [0, 1]:
                break
            else:
                print("Invalid choice. Please enter 0 or 1.")
        except ValueError:
            print("Invalid input. Please enter 0 or 1.")
        except EOFError:
            print("\nCancelled.")
            exit()

    # Handle video source based on choice
    if source_choice == 0:  # File-based inference
        while True:
            try:
                video_id = input("► Enter video ID (e.g., '000'): ")
                video_var_name = f"VIDEO_{video_id}"
                if video_var_name in globals():
                    VIDEO_PATH = globals()[video_var_name]
                    if VIDEO_PATH.exists():
                        VIDEO_SOURCE = str(VIDEO_PATH)
                        print(f"Selected video: {VIDEO_PATH.name}")
                        break
                    else:
                        print(f"Video file not found: {VIDEO_PATH}")
                else:
                    print(f"Video ID '{video_id}' not found. Please enter a valid ID.")
            except EOFError:
                print("\nCancelled.")
                exit()
    else:  # Live camera inference
        VIDEO_SOURCE = 0  # Webcam
        print("Using laptop camera for live inference.")

    # --- Model Loading ---
    if not CKPT_FILE.exists():
        print(f"❌ Error: Checkpoint file not found at {CKPT_FILE}")
        exit()
    print(f"⏳ Loading model from {CKPT_FILE}...")
    try:
        infer_model = torch.load(CKPT_FILE, map_location=DEVICE)
        if isinstance(infer_model, dict):
            print("Loaded state_dict, instantiating model...")
            model_state_dict = infer_model
            IN_DIM = N_JOINTS * 3
            infer_model = PoseQualityNetKP(in_dim=IN_DIM, num_ex=NUM_EXERCISES).to(DEVICE)
            infer_model.load_state_dict(model_state_dict)
            print("Model instantiated and state_dict loaded.")
        elif not isinstance(infer_model, nn.Module):
            raise TypeError(f"Loaded object is not a nn.Module or state_dict: {type(infer_model)}")
        infer_model.eval()
        print("✅ Model loaded successfully and set to eval mode.")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        exit()

    try:
        infer_and_feedback(infer_model, VIDEO_SOURCE, exercise_id, exercise_name)
    except NameError as e:
        print(f"❌ NameError: {e}")
    except Exception as e:
        print(f"❌ An unexpected error occurred: {e}")
    print("\nScript finished.")

► Using device: mps
ℹ️ Indices required for error angle visibility check: [0, 11, 12, 13, 14, 15, 16, 19, 20, 23, 24, 25, 26, 27, 28, 31, 32]

Available Exercises:
  1: Arm abduction
  2: Arm VW
  3: Push-ups
  4: Leg abduction
  5: Leg lunge
  6: Squats


I0000 00:00:1745192644.032706  612322 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.3), renderer: Apple M4 Max


Selected: 1 - Arm abduction
⏳ Initializing MediaPipe Pose...
✅ MediaPipe Pose initialized.

Perform inference from:
  0: File
  1: Laptop camera (live)


W0000 00:00:1745192644.094704  617701 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1745192644.115260  617701 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


Using laptop camera for live inference.
⏳ Loading model from /Users/jithinkrishnan/Documents/Study/IS06 /MVP/RehabApp/model-training-scripts/kp_pose_quality_windows_ex.pt...
✅ Model loaded successfully and set to eval mode.


  infer_model = torch.load(CKPT_FILE, map_location=DEVICE)


Webcam FPS: 30.0
Frame interval: 1 (processing every 1th frame)
🚀 Starting feedback loop using: camera (Press 'q' to quit)

'q' pressed. Exiting.
✅ Resources released.

Script finished.
