# 1. Imports

In [None]:
import os
import cv2
import torch
import numpy as np
import torch.nn.functional as F
import mediapipe as mp
from collections import deque
from pathlib import Path
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torch.optim import Adam
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from tqdm import tqdm
import math
from torch.utils.data import Dataset 

# 2. Keypoint Extraction from Training Videos with MediaPipe  
Before training a model for tasks like pose estimation, exercise recognition, or movement analysis, the first step is to extract relevant features from the data. In this case, the data consists of training videos, and the relevant features are pose keypoints—3D coordinates representing specific body parts (like shoulders, elbows, knees, etc.).  

This code uses MediaPipe Pose to extract keypoints from training videos. These keypoints are critical for model training as they represent the underlying body movements that the model will need to learn. Below is a breakdown of the process:

## 2.1 Keypoint extraction and saving as numpy files

In [2]:
# # Keypoint extraction from Training Videos With Mediapipe

# # 0. Quiet TensorFlow/absl
# os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# # 1. MediaPipe Pose setup
# pose = mp.solutions.pose.Pose(
#     static_image_mode=False,
#     model_complexity=2,
#     enable_segmentation=False,
#     min_detection_confidence=0.5,
#     min_tracking_confidence=0.5
# )

# # 2. Paths
# VIDEO_ROOT = Path("Data-REHAB24-6/videos")
# OUT_ROOT   = Path("Data-REHAB24-6/mp_keypoints")
# OUT_ROOT.mkdir(exist_ok=True)

# # 3. Worker
# def process_video(vid_path: Path):
#     rel     = vid_path.parent.name            # e.g. "Ex1"
#     out_dir = OUT_ROOT / rel
#     out_dir.mkdir(parents=True, exist_ok=True)
#     out_file = out_dir / f"{vid_path.stem}-mp.npy"

#     print(f"\n→ Processing: {vid_path.name}")
#     print(f"   From:      {vid_path.parent}")
#     print(f"   To folder: {out_dir}")

#     cap    = cv2.VideoCapture(str(vid_path))
#     frames = []
#     count  = 0

#     while True:
#         ret, frame = cap.read()
#         if not ret:
#             break
#         count += 1

#         img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#         res = pose.process(img)
#         lm  = res.pose_world_landmarks.landmark if res.pose_world_landmarks else []

#         if lm:
#             pts = [(p.x, p.y, p.z) for p in lm]
#         else:
#             pts = [(0.0, 0.0, 0.0)] * 33

#         frames.append(pts)

#     cap.release()

#     arr = np.array(frames, dtype=np.float32)
#     np.save(out_file, arr)

#     print(f"✔ Saved: {out_file}  (frames={count}, shape={arr.shape})")

# # 4. Run — only Ex1 through Ex5
# for i in range(1, 6):
#     ex_dir = VIDEO_ROOT / f"Ex{i}"
#     if not ex_dir.is_dir():
#         print(f"⚠️  Skipping missing folder: {ex_dir}")
#         continue

#     for vid in sorted(ex_dir.glob("*.mp4")):
#         try:
#             process_video(vid)
#         except Exception as e:
#             print(f"✘ Failed processing {vid.name}: {e}")

# print("\nAll requested videos processed.")


## 2.2 Inspecting at a numpy file containing 3D keypoints

In [3]:
# ─── edit this to your target file ────────────────────────────────────────────
file_path = Path("Data-REHAB24-6/mp_keypoints/Ex6/PM_008-Camera17-30fps-mp.npy")
# ────────────────────────────────────────────────────────────────────────────────

# load
arr = np.load(file_path)

# basic info
print(f"Loaded: {file_path}")
print(f" dtype: {arr.dtype}")
print(f" shape: {arr.shape}  (frames × landmarks × coords)\n")

# show first few frames
n_show = min(3, arr.shape[0])
for i in range(n_show):
    print(f"Frame #{i:03d} (33×3):")
    print(arr[i])
    print(f"  → first landmark: {tuple(arr[i,0])}\n")

# overall statistics
print("Overall coordinate stats:")
for idx, name in enumerate(("x", "y", "z")):
    data = arr[..., idx]
    print(f"  {name}: min={data.min():.3f}, max={data.max():.3f}, mean={data.mean():.3f}")


Loaded: Data-REHAB24-6/mp_keypoints/Ex6/PM_008-Camera17-30fps-mp.npy
 dtype: float32
 shape: (5191, 33, 3)  (frames × landmarks × coords)

Frame #000 (33×3):
[[-0.03379065 -0.6112996  -0.22619084]
 [-0.02455677 -0.628223   -0.2275483 ]
 [-0.0257254  -0.6304051  -0.21732828]
 [-0.02547118 -0.6302204  -0.21820049]
 [-0.02778288 -0.6363151  -0.2358914 ]
 [-0.02677054 -0.63463634 -0.24706481]
 [-0.02263773 -0.61979294 -0.2281486 ]
 [ 0.02648694 -0.6184615  -0.16859515]
 [-0.03559308 -0.56201506 -0.14808732]
 [ 0.0030987  -0.59678274 -0.19447449]
 [-0.01712019 -0.55977213 -0.21580447]
 [ 0.1250833  -0.49333623 -0.08702794]
 [-0.05722423 -0.53108674 -0.02576461]
 [ 0.15178505 -0.51440114 -0.09251688]
 [-0.17247145 -0.4994753  -0.05377672]
 [ 0.14430666 -0.54461473 -0.03090633]
 [-0.32217076 -0.5845331  -0.08351779]
 [ 0.1352469  -0.5789426  -0.02217976]
 [-0.33411983 -0.62724733 -0.13406767]
 [ 0.11857966 -0.58959824 -0.04421883]
 [-0.30619952 -0.65674794 -0.14033641]
 [ 0.13962431 -0.533594

# 3. Data Preparation for Pose Error Analysis  (Sliding Window-Based Pose Error Calculation for Video segments)

This script processes keypoint data from videos to compute errors in body joint angles relative to "ideal" angles, and then generates sliding windows of these errors. This is done to prepare features that will later be used for model training in tasks like exercise recognition or pose correction.

What is done: For each exercise in the dataset, the script calculates the "ideal" joint angles by selecting the middle frame of each correct repetition (based on the correctness label). It calculates the angles between the joints defined in JOINT_TRIPLETS for each of these frames.  

The median angle is then computed for each joint across the correct repetitions of the exercise. These median values represent the "ideal" angles the model should aim for in perfect executions of the exercise.  

In the next step, the script generates sliding windows of angular errors, calculated as the difference between the observed angles in the video and the ideal angles. These windows contain temporal sequences of error data that are used as features for model training. The windowed data is then saved in a CSV file, which includes additional metadata such as exercise ID, repetition number, and frame indices.  

Why it’s done: Calculating the ideal angles for each exercise provides a reference for identifying errors during subsequent video frames. These ideal angles will serve as the baseline for assessing whether a movement is performed correctly or incorrectly. By generating and saving the sliding window data, the script prepares the dataset for supervised learning, allowing the model to analyze temporal error patterns over a series of frames. This windowed approach is crucial for the model to learn dynamic movements and classify whether exercises are performed correctly based on the computed joint angles. The segmentation data in the CSV file offers a structured and labeled dataset, which helps in efficient training and evaluation of the model.  


In [4]:
# # 1. helpers --------------------------------------------------
# def angle_between(a,b,c):
#     BA = a-b; BC = c-b
#     cosθ = np.dot(BA,BC)/(np.linalg.norm(BA)*np.linalg.norm(BC))
#     return math.degrees(math.acos(np.clip(cosθ,-1,1)))

# PoseLandmark = mp.solutions.pose.PoseLandmark
# JOINT_TRIPLETS = {
#     "LEFT_ELBOW":   (PoseLandmark.LEFT_SHOULDER.value,
#                      PoseLandmark.LEFT_ELBOW.value,
#                      PoseLandmark.LEFT_WRIST.value),
#     "RIGHT_ELBOW":  (PoseLandmark.RIGHT_SHOULDER.value,
#                      PoseLandmark.RIGHT_ELBOW.value,
#                      PoseLandmark.RIGHT_WRIST.value),
#     "LEFT_SHOULDER":  (PoseLandmark.LEFT_ELBOW.value,
#                        PoseLandmark.LEFT_SHOULDER.value,
#                        PoseLandmark.LEFT_HIP.value),
#     "RIGHT_SHOULDER": (PoseLandmark.RIGHT_ELBOW.value,
#                        PoseLandmark.RIGHT_SHOULDER.value,
#                        PoseLandmark.RIGHT_HIP.value),
#     "LEFT_HIP":   (PoseLandmark.LEFT_SHOULDER.value,
#                    PoseLandmark.LEFT_HIP.value,
#                    PoseLandmark.LEFT_KNEE.value),
#     "RIGHT_HIP":  (PoseLandmark.RIGHT_SHOULDER.value,
#                    PoseLandmark.RIGHT_HIP.value,
#                    PoseLandmark.RIGHT_KNEE.value),
#     "LEFT_KNEE":  (PoseLandmark.LEFT_HIP.value,
#                   PoseLandmark.LEFT_KNEE.value,
#                   PoseLandmark.LEFT_ANKLE.value),
#     "RIGHT_KNEE": (PoseLandmark.RIGHT_HIP.value,
#                   PoseLandmark.RIGHT_KNEE.value,
#                   PoseLandmark.RIGHT_ANKLE.value),
#     "SPINE": (
#        PoseLandmark.LEFT_HIP.value,       
#        PoseLandmark.LEFT_SHOULDER.value,   
#        PoseLandmark.RIGHT_SHOULDER.value   
#     ),
#     "HEAD": (
#        PoseLandmark.LEFT_SHOULDER.value,
#        PoseLandmark.NOSE.value,
#        PoseLandmark.RIGHT_SHOULDER.value
#     ),
# }
# ERR_JOINTS = list(JOINT_TRIPLETS.keys())
# N_ERR = len(ERR_JOINTS)  # 10

# # 2. load original metadata & keypoints -----------------------
# DATA_ROOT    = Path("Data-REHAB24-6")
# KEYPT_ROOT   = DATA_ROOT/"mp_keypoints"
# META_ORIG    = DATA_ROOT/"Segmentation.xlsx"
# df           = pd.read_excel(META_ORIG, engine="openpyxl")
# df.columns   = df.columns.str.strip()

# # 3. compute ideal_angles on correct reps ----------------------
# ideal_angles = {}
# correct = df[df.correctness==1]
# for ex in correct.exercise_id.unique():
#     all_ang = {jn:[] for jn in ERR_JOINTS}
#     for _,r in correct[correct.exercise_id==ex].iterrows():
#         vid, f0, f1 = r.video_id, int(r.first_frame), int(r.last_frame)
#         files = list((KEYPT_ROOT/f"Ex{ex}").glob(f"{vid}-Camera17*-mp.npy"))
#         if not files: continue
#         arr = np.load(files[0])
#         seg = arr[f0:f1] if f1>f0 else arr[f0:]
#         if len(seg)==0: continue
#         mid = len(seg)//2
#         frm = seg[mid]
#         for jn in ERR_JOINTS:
#             ia,ib,ic = JOINT_TRIPLETS[jn]
#             ang = angle_between(frm[ia,:2],frm[ib,:2],frm[ic,:2])
#             all_ang[jn].append(ang)
#     # median
#     ideal_angles[ex] = {jn:float(np.median(all_ang[jn])) for jn in all_ang if all_ang[jn]}

# # 4. slide windows & write rows --------------------------------
# WINDOW, STRIDE = 16, 8
# rows = []
# for _,r in df.iterrows():
#     vid, ex, f0, f1 = r.video_id, int(r.exercise_id), int(r.first_frame), int(r.last_frame)
#     files = list((KEYPT_ROOT/f"Ex{ex}").glob(f"{vid}-Camera17*-mp.npy"))
#     if not files: continue
#     arr = np.load(files[0])                # (F,33,3)
#     seg = arr[f0:f1] if f1>f0 else arr[f0:]
#     if len(seg)<WINDOW: continue

#     # per-frame errors
#     pf_err = {jn:[] for jn in ERR_JOINTS}
#     for frm in seg:
#         for jn in ERR_JOINTS:
#             ia,ib,ic = JOINT_TRIPLETS[jn]
#             ang = angle_between(frm[ia,:2],frm[ib,:2],frm[ic,:2])
#             pf_err[jn].append(ang - ideal_angles[ex].get(jn,ang))

#     # slide
#     for start in range(0, len(seg)-WINDOW+1, STRIDE):
#         w = np.array([ pf_err[jn][start:start+WINDOW] for jn in ERR_JOINTS ])  # (10,WINDOW)
#         mean_err = w.mean(axis=1)
#         row = {
#             "video_id":vid,
#             "exercise_id":ex,
#             "repetition_number":r.repetition_number,
#             "window_start": f0+start,
#             "window_end":   f0+start+WINDOW,
#             "correctness":  r.correctness
#         }
#         for i,jn in enumerate(ERR_JOINTS):
#             row[f"err_{i}"] = float(mean_err[i])
#         rows.append(row)

# win_df = pd.DataFrame(rows)
# win_df.to_csv(DATA_ROOT/"Segmentation_windows.csv", index=False)
# print("Wrote", len(win_df), "windows to Segmentation_windows.csv")


# 4. Traning Setup

## 4.1 Paths & device

In [5]:

SCRIPT_DIR    = Path().resolve()
DATA_ROOT     = SCRIPT_DIR/"Data-REHAB24-6"
WIN_CSV       = DATA_ROOT/"Segmentation_windows.csv"
KEYPT_ROOT    = DATA_ROOT/"mp_keypoints"

DEVICE = (
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("cpu")
)
print("► Using device:", DEVICE)

► Using device: mps


## 4.2 Joint names & count

In [6]:
PoseLandmark = mp.solutions.pose.PoseLandmark

# Then:
JOINT_NAMES = [lm.name for lm in PoseLandmark]
N_JOINTS    = len(JOINT_NAMES)  # should be 33

print(f"JOINT_NAMES: {JOINT_NAMES}")
print(f"N_JOINTS: {N_JOINTS}")

#  Exerciseses (Ex1…Ex6)
NUM_EXERCISES = 6
CKPT_FILE     = "kp_pose_quality_windows_ex.pt"  

ERR_JOINTS   = [
  "LEFT_ELBOW","RIGHT_ELBOW",
  "LEFT_SHOULDER","RIGHT_SHOULDER",
  "LEFT_HIP","RIGHT_HIP",
  "LEFT_KNEE","RIGHT_KNEE",
  "SPINE","HEAD",
]
N_ERR = len(ERR_JOINTS)   # 10
ERR_COLS = [f"err_{i}" for i in range(N_ERR)]


JOINT_NAMES: ['NOSE', 'LEFT_EYE_INNER', 'LEFT_EYE', 'LEFT_EYE_OUTER', 'RIGHT_EYE_INNER', 'RIGHT_EYE', 'RIGHT_EYE_OUTER', 'LEFT_EAR', 'RIGHT_EAR', 'MOUTH_LEFT', 'MOUTH_RIGHT', 'LEFT_SHOULDER', 'RIGHT_SHOULDER', 'LEFT_ELBOW', 'RIGHT_ELBOW', 'LEFT_WRIST', 'RIGHT_WRIST', 'LEFT_PINKY', 'RIGHT_PINKY', 'LEFT_INDEX', 'RIGHT_INDEX', 'LEFT_THUMB', 'RIGHT_THUMB', 'LEFT_HIP', 'RIGHT_HIP', 'LEFT_KNEE', 'RIGHT_KNEE', 'LEFT_ANKLE', 'RIGHT_ANKLE', 'LEFT_HEEL', 'RIGHT_HEEL', 'LEFT_FOOT_INDEX', 'RIGHT_FOOT_INDEX']
N_JOINTS: 33


## 4.3 Dataset class definition

The KeypointWindowDataset class loads and processes pose keypoint data from videos for model training. It reads a CSV file containing metadata, including video IDs, exercise IDs, frame indices, and pre-calculated joint angle errors (ranging from 0 to 9). The data is sorted based on video ID, repetition number, and window start. For each sample, it loads the corresponding keypoint data (in .npy format), extracts a segment of frames based on the start and end indices, reshapes the keypoints into a 2D array, and converts them into a PyTorch tensor. It also retrieves the correctness label and the pre-calculated error values, which are stored in tensors. This class efficiently loads and processes the data in batches for training tasks like exercise recognition, where both pose keypoints and error features are used for supervised learning.

In [7]:

class KeypointWindowDataset(Dataset):
    def __init__(self, csv_file: Path, keypt_root: Path):
        df = pd.read_csv(csv_file)
        df = df.sort_values(["video_id","repetition_number","window_start"])
        self.rows = df.to_dict("records")
        self.keypt_root = keypt_root

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i: int):
        r   = self.rows[i]
        ex  = int(r["exercise_id"]) - 1      # zero‐based [0..NUM_EXERCISES-1]
        vid = r["video_id"]
        f0, f1 = int(r["window_start"]), int(r["window_end"])

        # load keypoints
        arr = np.load(
            next((self.keypt_root/f"Ex{ex+1}").glob(f"{vid}-Camera17*-mp.npy"))
        )  # shape (F,33,3)

        seg = arr[f0:f1]            # (T, 33, 3)
        seg = seg.reshape(len(seg), -1)  # (T, 99)
        seq = torch.from_numpy(seg).float()

        label = torch.tensor(r["correctness"], dtype=torch.long)
        err   = torch.tensor([r[f"err_{j}"] for j in range(N_ERR)],
                             dtype=torch.float32)

        return seq, label, err, ex

## 4.4 Model definitions   

1. KeypointEncoder Class: Feature Extraction  
The KeypointEncoder class is responsible for extracting feature representations from the input keypoint data. It uses two 1D convolutional layers (conv1 and conv2) to process the input sequence of keypoints. The input tensor, which represents keypoints for each frame in a video, is passed through these convolutional layers after being reshaped to fit the 1D convolution. Each convolution layer is followed by a ReLU activation function to introduce non-linearity. The final step of the encoder involves an adaptive average pooling (pool), which reduces the feature map to a single value per feature channel. This results in a compact representation of the keypoint sequence, which is then passed forward for further processing.  

2. PoseQualityNetKP Class: Overview  
The PoseQualityNetKP class is the main model used for pose quality assessment. It integrates the KeypointEncoder to process the raw keypoint data and extracts meaningful features. The model then uses an LSTM (Long Short-Term Memory) network to learn the temporal dependencies between the keypoint sequences. The LSTM consists of two bidirectional layers, allowing the model to capture information from both past and future frames in the sequence. The LSTM outputs a sequence of hidden states, which are averaged across the time dimension to produce a fixed-size feature vector representing the entire sequence of frames. This vector, along with the exercise embedding, is used to make predictions.  

3. Exercise Embedding and Final Layers  
In addition to the keypoint features, the model incorporates an exercise-specific embedding to capture the variations between different exercises. The ex_emb layer processes the one-hot encoded exercise ID into a dense representation. This embedding is passed through a small multi-layer perceptron (MLP) that reduces the embedding size, enabling the model to focus on the most important characteristics of each exercise. The concatenation of the temporal features from the LSTM and the exercise embedding forms the final input to the classification and error prediction heads. These final heads, cls_head and err_head, are fully connected layers that output the classification of the exercise and the pose errors, respectively.  

In [8]:
# 5. Model definitions
class KeypointEncoder(nn.Module):
    def __init__(self, in_dim:int, embed:int=512):
        super().__init__()
        self.conv1 = nn.Conv1d(in_dim, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(128, embed, kernel_size=3, padding=1)
        self.pool  = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, D); treat as (B, D, 1) for Conv1d
        x = x.unsqueeze(2)                 # → (B, D, 1)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return self.pool(x).squeeze(-1)    # → (B, embed)

class PoseQualityNetKP(nn.Module):
    def __init__(self,
                 in_dim: int,
                 num_ex: int,
                 hidden: int = 256,
                 ex_emb: int = 64):
        super().__init__()
        # keypoint feature extractor
        self.encoder = KeypointEncoder(in_dim)

        # sequence model
        self.lstm = nn.LSTM(
            input_size=512,
            hidden_size=hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        feat_dim = hidden * 2

        # exercise embedding MLP
        self.ex_emb = nn.Sequential(
            nn.Linear(num_ex, ex_emb),
            nn.ReLU(),
            nn.Linear(ex_emb, ex_emb)
        )

        # final heads
        self.cls_head = nn.Linear(feat_dim + ex_emb, 2)
        self.err_head = nn.Linear(feat_dim + ex_emb, N_ERR)

    def forward(self,
                seq:     torch.Tensor,  # (B, T, D)
                ex_1hot: torch.Tensor   # (B, num_ex)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # 1) keypoint → sequence feats
        # encode each frame
        B,T,_ = seq.shape
        feats = torch.stack([
            self.encoder(seq[:,t]) for t in range(T)
        ], dim=1)                                # (B, T, 512)
        out, _ = self.lstm(feats)                # (B, T, 2*hidden)
        g = out.mean(1)                          # (B, 2*hidden)

        # 2) exercise embed
        ex_e = self.ex_emb(ex_1hot)              # (B, ex_emb)

        # 3) concat and heads
        h = torch.cat([g, ex_e], dim=1)          # (B, feat_dim+ex_emb)
        return self.cls_head(h), self.err_head(h)


# 5. Model Training  

The PoseQualityNetKP model is designed for pose quality assessment, utilizing keypoint data to evaluate exercise movements. The model consists of a Keypoint Encoder to extract meaningful features from raw keypoint sequences, followed by an LSTM layer for learning temporal dependencies across frames. Additionally, the model incorporates an exercise-specific embedding layer to capture the variations between different exercises, and two final output heads for classification and error prediction.  

During the training process, the model receives keypoint data sequences and exercise one-hot encodings as inputs. The keypoint sequences are first processed by the KeypointEncoder, which uses 1D convolutions to extract feature representations from each frame in the sequence. These frame-level features are then passed through a bidirectional LSTM to capture both past and future context in the sequence. The resulting features are aggregated by averaging the LSTM output over time, producing a fixed-size feature vector for each input sequence.  

The exercise information is encoded through the exercise embedding MLP, which maps the one-hot encoded exercise IDs to a dense embedding representation. The temporal features from the LSTM and the exercise embeddings are concatenated together and passed through two separate heads: the classification head (cls_head) to predict the correctness of the exercise repetition (correct/incorrect), and the error head (err_head) to predict the pose errors for each joint, based on the ideal joint angles.  

During training, the model optimizes two loss functions:

- Cross-entropy loss (loss_cls): This loss is used to train the model to classify the exercise as correct or incorrect.  

- Smooth L1 loss (loss_err): This loss is used to predict the joint angle errors, aiming to minimize the difference between predicted and actual errors.  

The model is trained using the Adam optimizer with a learning rate of 1e-4. The training process iterates over batches of data, updating the model parameters to minimize the combined loss. During each epoch, the model's performance is validated using a validation set, where metrics such as accuracy, F1 score, precision, recall, and mean absolute error (MAE) are calculated. The model with the best F1 score is saved as the final trained model.  

This model training framework prepares the system for accurate pose quality assessment and error prediction, allowing for fine-tuned classification of exercise correctness and detailed joint error analysis.  

In [None]:
def train_epochs(
    csv_file:  str   = str(WIN_CSV),
    keypt_root:str   = str(KEYPT_ROOT),
    num_ex:    int   = NUM_EXERCISES,
    epochs:    int   = 30,
    batch:     int   = 16,
    lr:        float = 1e-4,
    ckpt_file: str   = CKPT_FILE
):
    # Build dataset and split
    ds  = KeypointWindowDataset(Path(csv_file), Path(keypt_root))
    N   = len(ds)
    idx = np.arange(N); np.random.shuffle(idx)
    c1, c2 = int(0.7*N), int(0.85*N)
    train_idx, val_idx = idx[:c1], idx[c1:c2]

    train_dl = DataLoader(Subset(ds, train_idx), batch_size=batch, shuffle=True)
    val_dl   = DataLoader(Subset(ds, val_idx),   batch_size=batch, shuffle=False)

    # Infer input dimension
    sample_seq, _, _, _ = ds[0]
    in_dim = sample_seq.shape[-1]

    # Build model
    model    = PoseQualityNetKP(in_dim, num_ex).to(DEVICE)
    loss_cls = nn.CrossEntropyLoss()
    loss_err = nn.SmoothL1Loss()
    opt      = Adam(model.parameters(), lr)

    best_f1 = 0.0
    for epoch in range(1, epochs+1):
        # -- train --
        model.train()
        tot_loss = 0.0
        for seq, y, err, ex in tqdm(train_dl, desc=f"Epoch {epoch:02d}"):
            seq, y, err, ex = [x.to(DEVICE) for x in (seq, y, err, ex)]
            # Build one-hot encoding for exercise
            ex_1hot = F.one_hot(ex, num_ex).float()

            opt.zero_grad()
            logits, err_hat = model(seq, ex_1hot)
            loss = loss_cls(logits, y) + 0.1 * loss_err(err_hat, err)
            loss.backward()
            opt.step()

            tot_loss += loss.item() * y.size(0)
        print(f"  ↳ train loss: {tot_loss/len(train_idx):.4f}")

        # -- validation --
        model.eval()
        y_true, y_pred, errs, precision, recall = [], [], [], [], []
        with torch.no_grad():
            for seq, y, err, ex in val_dl:
                seq, y, err, ex = [x.to(DEVICE) for x in (seq, y, err, ex)]
                ex_1hot = F.one_hot(ex, num_ex).float()
                logits, err_hat = model(seq, ex_1hot)

                y_true += y.cpu().tolist()
                y_pred += logits.argmax(1).cpu().tolist()
                errs    += [(err_hat - err.to(DEVICE)).abs().mean(1)]
                
                # Precision and Recall
                precision += [precision_score(y.cpu(), logits.argmax(1).cpu(), average='weighted')]
                recall += [recall_score(y.cpu(), logits.argmax(1).cpu(), average='weighted')]

        # Calculate metrics
        acc = accuracy_score(y_true, y_pred)
        f1  = f1_score(y_true, y_pred, average='weighted')
        mae = torch.cat(errs).mean().item()
        precision_mean = np.mean(precision)
        recall_mean = np.mean(recall)

        print(f"  ↳ val acc {acc:.3f}, Precision {precision_mean:.3f}, Recall {recall_mean:.3f}, F1 {f1:.3f}, MAE° {mae:.2f}")

         # Save if improved
        if f1 > best_f1:
            best_f1 = f1
            # Save full model
            torch.save(model, ckpt_file)
            # Save state_dict separately
            state_dict_path = Path(ckpt_file).with_suffix('.pth')
            torch.save(model.state_dict(), state_dict_path)
            print(f"  ✓ Saved best model to {ckpt_file} and state_dict to {state_dict_path} (F1 {f1:.3f})")

train_epochs(epochs=80, batch=16, lr=1e-4)


Epoch 01: 100%|██████████| 629/629 [00:13<00:00, 47.89it/s]


  ↳ train loss: 2.7361


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


  ↳ val acc 0.585, Precision 0.617, Recall 0.585, F1 0.559, MAE° 19.60
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.559)


Epoch 02: 100%|██████████| 629/629 [00:12<00:00, 49.66it/s]


  ↳ train loss: 2.4140
  ↳ val acc 0.633, Precision 0.674, Recall 0.633, F1 0.633, MAE° 17.05
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.633)


Epoch 03: 100%|██████████| 629/629 [00:12<00:00, 49.03it/s]


  ↳ train loss: 2.1987
  ↳ val acc 0.637, Precision 0.670, Recall 0.637, F1 0.635, MAE° 15.78
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.635)


Epoch 04: 100%|██████████| 629/629 [00:12<00:00, 48.86it/s]


  ↳ train loss: 2.0483
  ↳ val acc 0.695, Precision 0.725, Recall 0.695, F1 0.693, MAE° 14.72
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.693)


Epoch 05: 100%|██████████| 629/629 [00:12<00:00, 48.44it/s]


  ↳ train loss: 1.9148
  ↳ val acc 0.725, Precision 0.765, Recall 0.725, F1 0.717, MAE° 13.80
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.717)


Epoch 06: 100%|██████████| 629/629 [00:13<00:00, 48.34it/s]


  ↳ train loss: 1.7716
  ↳ val acc 0.733, Precision 0.772, Recall 0.733, F1 0.725, MAE° 12.92
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.725)


Epoch 07: 100%|██████████| 629/629 [00:12<00:00, 48.39it/s]


  ↳ train loss: 1.6460
  ↳ val acc 0.748, Precision 0.782, Recall 0.748, F1 0.743, MAE° 11.81
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.743)


Epoch 08: 100%|██████████| 629/629 [00:12<00:00, 48.40it/s]


  ↳ train loss: 1.5368
  ↳ val acc 0.776, Precision 0.801, Recall 0.776, F1 0.777, MAE° 10.89
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.777)


Epoch 09: 100%|██████████| 629/629 [00:13<00:00, 48.15it/s]


  ↳ train loss: 1.4314
  ↳ val acc 0.786, Precision 0.830, Recall 0.786, F1 0.779, MAE° 10.18
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.779)


Epoch 10: 100%|██████████| 629/629 [00:13<00:00, 47.94it/s]


  ↳ train loss: 1.3509
  ↳ val acc 0.802, Precision 0.823, Recall 0.802, F1 0.801, MAE° 9.57
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.801)


Epoch 11: 100%|██████████| 629/629 [07:16<00:00,  1.44it/s]  


  ↳ train loss: 1.2768


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


  ↳ val acc 0.740, Precision 0.806, Recall 0.740, F1 0.723, MAE° 9.12


Epoch 12: 100%|██████████| 629/629 [00:12<00:00, 50.25it/s]


  ↳ train loss: 1.2156
  ↳ val acc 0.809, Precision 0.840, Recall 0.809, F1 0.807, MAE° 8.57
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.807)


Epoch 13: 100%|██████████| 629/629 [00:12<00:00, 49.53it/s]


  ↳ train loss: 1.1549
  ↳ val acc 0.800, Precision 0.821, Recall 0.800, F1 0.799, MAE° 8.19


Epoch 14: 100%|██████████| 629/629 [00:12<00:00, 49.51it/s]


  ↳ train loss: 1.1053
  ↳ val acc 0.830, Precision 0.851, Recall 0.830, F1 0.829, MAE° 7.74
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.829)


Epoch 15: 100%|██████████| 629/629 [00:12<00:00, 49.73it/s]


  ↳ train loss: 1.0451
  ↳ val acc 0.843, Precision 0.862, Recall 0.843, F1 0.843, MAE° 7.32
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.843)


Epoch 16: 100%|██████████| 629/629 [00:12<00:00, 49.95it/s]


  ↳ train loss: 1.0001
  ↳ val acc 0.831, Precision 0.849, Recall 0.831, F1 0.831, MAE° 7.08


Epoch 17: 100%|██████████| 629/629 [00:12<00:00, 49.09it/s]


  ↳ train loss: 0.9678
  ↳ val acc 0.833, Precision 0.854, Recall 0.832, F1 0.831, MAE° 6.75


Epoch 18: 100%|██████████| 629/629 [00:12<00:00, 49.35it/s]


  ↳ train loss: 0.9310
  ↳ val acc 0.830, Precision 0.846, Recall 0.829, F1 0.830, MAE° 6.57


Epoch 19: 100%|██████████| 629/629 [00:12<00:00, 49.10it/s]


  ↳ train loss: 0.8960
  ↳ val acc 0.793, Precision 0.851, Recall 0.792, F1 0.783, MAE° 6.31


Epoch 20: 100%|██████████| 629/629 [00:12<00:00, 48.72it/s]


  ↳ train loss: 0.8619
  ↳ val acc 0.828, Precision 0.854, Recall 0.828, F1 0.826, MAE° 6.16


Epoch 21: 100%|██████████| 629/629 [00:12<00:00, 48.46it/s]


  ↳ train loss: 0.8339
  ↳ val acc 0.873, Precision 0.887, Recall 0.873, F1 0.873, MAE° 5.86
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.873)


Epoch 22: 100%|██████████| 629/629 [00:13<00:00, 48.34it/s]


  ↳ train loss: 0.8148
  ↳ val acc 0.856, Precision 0.871, Recall 0.855, F1 0.855, MAE° 5.72


Epoch 23: 100%|██████████| 629/629 [00:13<00:00, 48.22it/s]


  ↳ train loss: 0.7936
  ↳ val acc 0.873, Precision 0.887, Recall 0.873, F1 0.873, MAE° 5.67
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.873)


Epoch 24: 100%|██████████| 629/629 [00:13<00:00, 48.20it/s]


  ↳ train loss: 0.7775
  ↳ val acc 0.864, Precision 0.878, Recall 0.864, F1 0.864, MAE° 5.46


Epoch 25: 100%|██████████| 629/629 [00:13<00:00, 48.27it/s]


  ↳ train loss: 0.7534
  ↳ val acc 0.875, Precision 0.890, Recall 0.874, F1 0.874, MAE° 5.34
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.874)


Epoch 26: 100%|██████████| 629/629 [00:13<00:00, 47.89it/s]


  ↳ train loss: 0.7243
  ↳ val acc 0.861, Precision 0.876, Recall 0.861, F1 0.861, MAE° 5.26


Epoch 27: 100%|██████████| 629/629 [00:12<00:00, 48.50it/s]


  ↳ train loss: 0.7258
  ↳ val acc 0.852, Precision 0.871, Recall 0.851, F1 0.850, MAE° 5.14


Epoch 28: 100%|██████████| 629/629 [00:13<00:00, 48.03it/s]


  ↳ train loss: 0.7086
  ↳ val acc 0.872, Precision 0.894, Recall 0.872, F1 0.871, MAE° 5.04


Epoch 29: 100%|██████████| 629/629 [00:13<00:00, 47.98it/s]


  ↳ train loss: 0.6898
  ↳ val acc 0.830, Precision 0.871, Recall 0.829, F1 0.824, MAE° 5.11


Epoch 30: 100%|██████████| 629/629 [00:13<00:00, 48.07it/s]


  ↳ train loss: 0.6800
  ↳ val acc 0.854, Precision 0.882, Recall 0.854, F1 0.852, MAE° 4.91


Epoch 31: 100%|██████████| 629/629 [00:13<00:00, 47.92it/s]


  ↳ train loss: 0.6585
  ↳ val acc 0.874, Precision 0.885, Recall 0.874, F1 0.874, MAE° 4.83


Epoch 32: 100%|██████████| 629/629 [00:13<00:00, 48.20it/s]


  ↳ train loss: 0.6433
  ↳ val acc 0.895, Precision 0.904, Recall 0.895, F1 0.895, MAE° 4.74
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.895)


Epoch 33: 100%|██████████| 629/629 [00:13<00:00, 48.10it/s]


  ↳ train loss: 0.6344
  ↳ val acc 0.881, Precision 0.894, Recall 0.880, F1 0.880, MAE° 4.72


Epoch 34: 100%|██████████| 629/629 [00:12<00:00, 48.50it/s]


  ↳ train loss: 0.6187
  ↳ val acc 0.888, Precision 0.902, Recall 0.887, F1 0.887, MAE° 4.71


Epoch 35: 100%|██████████| 629/629 [00:13<00:00, 48.36it/s]


  ↳ train loss: 0.6199
  ↳ val acc 0.872, Precision 0.888, Recall 0.872, F1 0.873, MAE° 4.52


Epoch 36: 100%|██████████| 629/629 [00:13<00:00, 48.34it/s]


  ↳ train loss: 0.5904
  ↳ val acc 0.880, Precision 0.899, Recall 0.879, F1 0.879, MAE° 4.43


Epoch 37: 100%|██████████| 629/629 [00:13<00:00, 48.37it/s]


  ↳ train loss: 0.5880
  ↳ val acc 0.882, Precision 0.896, Recall 0.882, F1 0.882, MAE° 4.31


Epoch 38: 100%|██████████| 629/629 [00:12<00:00, 48.48it/s]


  ↳ train loss: 0.5849
  ↳ val acc 0.900, Precision 0.910, Recall 0.899, F1 0.900, MAE° 4.44
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.900)


Epoch 39: 100%|██████████| 629/629 [00:13<00:00, 48.38it/s]


  ↳ train loss: 0.5647
  ↳ val acc 0.871, Precision 0.895, Recall 0.871, F1 0.869, MAE° 4.34


Epoch 40: 100%|██████████| 629/629 [00:12<00:00, 48.44it/s]


  ↳ train loss: 0.5531
  ↳ val acc 0.874, Precision 0.888, Recall 0.873, F1 0.874, MAE° 4.37


Epoch 41: 100%|██████████| 629/629 [00:13<00:00, 48.34it/s]


  ↳ train loss: 0.5498
  ↳ val acc 0.892, Precision 0.903, Recall 0.892, F1 0.892, MAE° 4.04


Epoch 42: 100%|██████████| 629/629 [00:13<00:00, 48.38it/s]


  ↳ train loss: 0.5352
  ↳ val acc 0.894, Precision 0.906, Recall 0.894, F1 0.894, MAE° 4.12


Epoch 43: 100%|██████████| 629/629 [23:23<00:00,  2.23s/it]  


  ↳ train loss: 0.5291
  ↳ val acc 0.906, Precision 0.915, Recall 0.905, F1 0.906, MAE° 3.91
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.906)


Epoch 44: 100%|██████████| 629/629 [00:12<00:00, 50.78it/s]


  ↳ train loss: 0.5164
  ↳ val acc 0.906, Precision 0.916, Recall 0.905, F1 0.906, MAE° 4.05
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.906)


Epoch 45: 100%|██████████| 629/629 [00:12<00:00, 50.47it/s]


  ↳ train loss: 0.5228
  ↳ val acc 0.891, Precision 0.903, Recall 0.890, F1 0.891, MAE° 4.04


Epoch 46: 100%|██████████| 629/629 [00:12<00:00, 49.90it/s]


  ↳ train loss: 0.5060
  ↳ val acc 0.894, Precision 0.910, Recall 0.894, F1 0.893, MAE° 3.85


Epoch 47: 100%|██████████| 629/629 [00:12<00:00, 49.65it/s]


  ↳ train loss: 0.5018
  ↳ val acc 0.911, Precision 0.923, Recall 0.911, F1 0.911, MAE° 3.74
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.911)


Epoch 48: 100%|██████████| 629/629 [00:12<00:00, 49.77it/s]


  ↳ train loss: 0.4978
  ↳ val acc 0.898, Precision 0.911, Recall 0.898, F1 0.897, MAE° 3.89


Epoch 49: 100%|██████████| 629/629 [00:12<00:00, 49.34it/s]


  ↳ train loss: 0.4728
  ↳ val acc 0.914, Precision 0.924, Recall 0.914, F1 0.914, MAE° 3.66
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.914)


Epoch 50: 100%|██████████| 629/629 [00:12<00:00, 49.13it/s]


  ↳ train loss: 0.4796
  ↳ val acc 0.903, Precision 0.916, Recall 0.903, F1 0.903, MAE° 3.68


Epoch 51: 100%|██████████| 629/629 [00:12<00:00, 49.33it/s]


  ↳ train loss: 0.4860
  ↳ val acc 0.904, Precision 0.917, Recall 0.904, F1 0.904, MAE° 3.71


Epoch 52: 100%|██████████| 629/629 [00:12<00:00, 49.54it/s]


  ↳ train loss: 0.4629
  ↳ val acc 0.900, Precision 0.913, Recall 0.900, F1 0.899, MAE° 3.60


Epoch 53: 100%|██████████| 629/629 [00:12<00:00, 48.74it/s]


  ↳ train loss: 0.4489
  ↳ val acc 0.914, Precision 0.924, Recall 0.913, F1 0.914, MAE° 3.60


Epoch 54: 100%|██████████| 629/629 [00:12<00:00, 48.93it/s]


  ↳ train loss: 0.4368
  ↳ val acc 0.896, Precision 0.914, Recall 0.896, F1 0.895, MAE° 3.59


Epoch 55: 100%|██████████| 629/629 [00:12<00:00, 49.05it/s]


  ↳ train loss: 0.4656
  ↳ val acc 0.900, Precision 0.910, Recall 0.900, F1 0.900, MAE° 3.53


Epoch 56: 100%|██████████| 629/629 [00:12<00:00, 49.20it/s]


  ↳ train loss: 0.4455


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


  ↳ val acc 0.891, Precision 0.909, Recall 0.891, F1 0.890, MAE° 3.78


Epoch 57: 100%|██████████| 629/629 [00:12<00:00, 48.67it/s]


  ↳ train loss: 0.4448
  ↳ val acc 0.907, Precision 0.919, Recall 0.906, F1 0.906, MAE° 3.55


Epoch 58: 100%|██████████| 629/629 [00:12<00:00, 48.99it/s]


  ↳ train loss: 0.4340
  ↳ val acc 0.915, Precision 0.925, Recall 0.915, F1 0.915, MAE° 3.54
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.915)


Epoch 59: 100%|██████████| 629/629 [00:12<00:00, 48.67it/s]


  ↳ train loss: 0.4173
  ↳ val acc 0.905, Precision 0.922, Recall 0.905, F1 0.905, MAE° 3.39


Epoch 60: 100%|██████████| 629/629 [00:12<00:00, 48.47it/s]


  ↳ train loss: 0.4160
  ↳ val acc 0.894, Precision 0.910, Recall 0.894, F1 0.893, MAE° 3.34


Epoch 61: 100%|██████████| 629/629 [00:13<00:00, 47.63it/s]


  ↳ train loss: 0.4161
  ↳ val acc 0.909, Precision 0.919, Recall 0.909, F1 0.909, MAE° 3.30


Epoch 62: 100%|██████████| 629/629 [00:13<00:00, 47.71it/s]


  ↳ train loss: 0.4166
  ↳ val acc 0.909, Precision 0.922, Recall 0.908, F1 0.908, MAE° 3.37


Epoch 63: 100%|██████████| 629/629 [00:14<00:00, 44.21it/s]


  ↳ train loss: 0.3986
  ↳ val acc 0.873, Precision 0.893, Recall 0.873, F1 0.872, MAE° 3.46


Epoch 64: 100%|██████████| 629/629 [00:13<00:00, 47.20it/s]


  ↳ train loss: 0.4125
  ↳ val acc 0.919, Precision 0.928, Recall 0.919, F1 0.919, MAE° 3.32
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.919)


Epoch 65: 100%|██████████| 629/629 [00:13<00:00, 45.80it/s]


  ↳ train loss: 0.3953
  ↳ val acc 0.898, Precision 0.913, Recall 0.898, F1 0.899, MAE° 3.32


Epoch 66: 100%|██████████| 629/629 [00:13<00:00, 45.86it/s]


  ↳ train loss: 0.3854
  ↳ val acc 0.910, Precision 0.922, Recall 0.910, F1 0.910, MAE° 3.23


Epoch 67: 100%|██████████| 629/629 [00:13<00:00, 47.97it/s]


  ↳ train loss: 0.3750
  ↳ val acc 0.918, Precision 0.926, Recall 0.918, F1 0.918, MAE° 3.15


Epoch 68: 100%|██████████| 629/629 [00:13<00:00, 48.06it/s]


  ↳ train loss: 0.3737
  ↳ val acc 0.906, Precision 0.918, Recall 0.906, F1 0.906, MAE° 3.28


Epoch 69: 100%|██████████| 629/629 [00:13<00:00, 47.64it/s]


  ↳ train loss: 0.3784
  ↳ val acc 0.910, Precision 0.920, Recall 0.910, F1 0.910, MAE° 3.18


Epoch 70: 100%|██████████| 629/629 [00:13<00:00, 48.15it/s]


  ↳ train loss: 0.3718
  ↳ val acc 0.895, Precision 0.909, Recall 0.895, F1 0.895, MAE° 3.20


Epoch 71: 100%|██████████| 629/629 [00:12<00:00, 48.52it/s]


  ↳ train loss: 0.3616
  ↳ val acc 0.913, Precision 0.926, Recall 0.913, F1 0.913, MAE° 3.12


Epoch 72: 100%|██████████| 629/629 [00:13<00:00, 46.72it/s]


  ↳ train loss: 0.3624
  ↳ val acc 0.922, Precision 0.927, Recall 0.921, F1 0.922, MAE° 2.99
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.922)


Epoch 73: 100%|██████████| 629/629 [00:13<00:00, 47.78it/s]


  ↳ train loss: 0.3632
  ↳ val acc 0.918, Precision 0.926, Recall 0.918, F1 0.918, MAE° 3.31


Epoch 74: 100%|██████████| 629/629 [00:13<00:00, 47.85it/s]


  ↳ train loss: 0.3699
  ↳ val acc 0.919, Precision 0.926, Recall 0.919, F1 0.919, MAE° 3.06


Epoch 75: 100%|██████████| 629/629 [00:13<00:00, 47.07it/s]


  ↳ train loss: 0.3454
  ↳ val acc 0.924, Precision 0.934, Recall 0.924, F1 0.924, MAE° 3.06
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.924)


Epoch 76: 100%|██████████| 629/629 [00:13<00:00, 46.43it/s]


  ↳ train loss: 0.3536
  ↳ val acc 0.920, Precision 0.932, Recall 0.920, F1 0.919, MAE° 3.09


Epoch 77: 100%|██████████| 629/629 [00:13<00:00, 47.53it/s]


  ↳ train loss: 0.3404
  ↳ val acc 0.918, Precision 0.927, Recall 0.918, F1 0.918, MAE° 3.13


Epoch 78: 100%|██████████| 629/629 [00:13<00:00, 47.90it/s]


  ↳ train loss: 0.3417
  ↳ val acc 0.928, Precision 0.936, Recall 0.928, F1 0.928, MAE° 2.97
  ✓ saved new best model to kp_pose_quality_windows_ex.pt  (F1 0.928)


Epoch 79: 100%|██████████| 629/629 [00:13<00:00, 47.66it/s]


  ↳ train loss: 0.3464
  ↳ val acc 0.924, Precision 0.932, Recall 0.924, F1 0.924, MAE° 3.01


Epoch 80: 100%|██████████| 629/629 [00:13<00:00, 47.91it/s]


  ↳ train loss: 0.3417
  ↳ val acc 0.917, Precision 0.926, Recall 0.917, F1 0.916, MAE° 2.94


# 6. Inference Tesing - Correctness and feedback on Live Videos / Recorded Videos 


This script provides a real-time system for exercise pose analysis and feedback, utilizing keypoint data from recorded videos or live camera streams. The system is built around a pre-trained model (PoseQualityNetKP) that classifies the correctness of an exercise and provides detailed feedback on joint angle errors for each frame of the video.

- Model Setup and Loading: The model is first loaded from a checkpoint file (kp_pose_quality_windows_ex1.pt). The model consists of a keypoint encoder, a sequence model (LSTM), and two output heads for classification (correct/incorrect) and error prediction (joint angle errors).

- Exercise Selection: The user is prompted to select an exercise from a predefined list, and the corresponding exercise ID is used to guide the feedback generation.

- MediaPipe Pose Estimation: The script uses MediaPipe to extract world keypoints (3D pose landmarks) from each frame of the video or camera feed. These keypoints are passed to the model for pose evaluation.

- Pose Inference and Feedback:
    - The keypoints from each frame are processed and buffered to maintain a sequence of frames, which are then fed into the model.
    - The model predicts the correctness of the exercise and calculates the joint angle errors for each frame.
    - If the exercise is classified as incorrect, the model suggests corrections based on the joint with the largest error. The suggestion is displayed on the video feed for 3 seconds.
    - Wrong exercise detection: If the model detects significant errors across multiple joints, it flags a wrong exercise and displays a warning message.

- Visual Feedback:
    - The system draws pose landmarks on the video feed using MediaPipe's drawing utilities for visual feedback.
    - Feedback about the exercise's correctness and suggestions for improvement are overlaid on the video.
    - If joint errors exceed a certain threshold, the model displays detailed feedback on specific joints, highlighting areas where the user should focus on improving.

- User Interaction: The video feed (from a file or camera) is processed frame by frame, and real-time feedback is provided to the user, allowing them to see their performance and improve their exercise form.

This system serves as a real-time feedback tool for users performing exercises, enabling them to receive immediate corrective suggestions based on pose analysis. It is useful for applications like personal training, physical therapy, and fitness monitoring.

In [None]:
# Attention: Re-run necessary imports and definitions above if running standalone

# 1. Paths & device (adjust if needed)
SCRIPT_DIR    = Path().resolve()
DATA_ROOT     = SCRIPT_DIR/"Data-REHAB24-6" # Make sure this path is correct

DEVICE = (
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("cpu")
)
print("► Using device:", DEVICE)

# 2. Joint names & count
PoseLandmark = mp.solutions.pose.PoseLandmark
JOINT_NAMES = [lm.name for lm in PoseLandmark]
N_JOINTS    = len(JOINT_NAMES)  # should be 33

#  Exerciseses (Ex1…Ex6)
NUM_EXERCISES = 6
CKPT_FILE     = "kp_pose_quality_windows_ex1.pt" # Check if this file exists

ERR_JOINTS   = [
  "LEFT_ELBOW","RIGHT_ELBOW",
  "LEFT_SHOULDER","RIGHT_SHOULDER",
  "LEFT_HIP","RIGHT_HIP",
  "LEFT_KNEE","RIGHT_KNEE",
  "SPINE","HEAD", # Make sure these match the training order
]
N_ERR = len(ERR_JOINTS)   # 10
ERR_COLS = [f"err_{i}" for i in range(N_ERR)]

# 5. Model definitions (Using the ORIGINAL definition from training)
class KeypointEncoder(nn.Module):
    # --- Restored ORIGINAL definition ---
    def __init__(self, in_dim:int, embed:int=512):
        super().__init__()
        self.conv1 = nn.Conv1d(in_dim, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(128, embed, kernel_size=3, padding=1)
        self.pool  = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, D); treat as (B, D, 1) for Conv1d
        # This encoder is designed to process features of a SINGLE frame (B, D)
        if x.dim() == 2:
            x = x.unsqueeze(2)                 # → (B, D, 1)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.pool(x).squeeze(-1)    # → (B, embed)


class PoseQualityNetKP(nn.Module):
    # --- Keep the PoseQualityNetKP class definition as in the original code ---
    def __init__(self,
                 in_dim: int, # Should be 99 (33*3)
                 num_ex: int,
                 hidden: int = 256,
                 ex_emb: int = 64,
                 embed: int = 512): # Added embed dim to match encoder
        super().__init__()
        # keypoint feature extractor (Uses the restored original encoder)
        self.encoder = KeypointEncoder(in_dim, embed=embed)

        # sequence model
        self.lstm = nn.LSTM(
            input_size=embed, # Use embed dim here
            hidden_size=hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        feat_dim = hidden * 2

        # exercise embedding MLP
        self.ex_emb = nn.Sequential(
            nn.Linear(num_ex, ex_emb),
            nn.ReLU(),
            nn.Linear(ex_emb, ex_emb)
        )

        # final heads
        self.cls_head = nn.Linear(feat_dim + ex_emb, 2) # 2 classes: incorrect, correct
        self.err_head = nn.Linear(feat_dim + ex_emb, N_ERR)

    def forward(self,
                seq:     torch.Tensor,  # (B, T, D) where D=99
                ex_1hot: torch.Tensor   # (B, num_ex)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # 1) keypoint → sequence feats
        # encode each frame
        B, T, D = seq.shape # Now this should work
        # Process sequence frame by frame using the encoder
        frame_embeddings = []
        for t in range(T):
            frame_data = seq[:, t, :] # Get data for frame t: (B, D)
            frame_embedding = self.encoder(frame_data) # Output: (B, embed)
            frame_embeddings.append(frame_embedding)

        feats = torch.stack(frame_embeddings, dim=1) # (B, T, embed)

        # 2) sequence model (LSTM)
        out, _ = self.lstm(feats)                # (B, T, 2*hidden)
        # Aggregate LSTM outputs (e.g., mean pooling over time)
        g = out.mean(dim=1)                      # (B, 2*hidden)

        # 3) exercise embed
        ex_e = self.ex_emb(ex_1hot)              # (B, ex_emb)

        # 4) concat and heads
        h = torch.cat([g, ex_e], dim=1)          # (B, feat_dim + ex_emb)
        logits = self.cls_head(h)                # (B, 2)
        err_hat = self.err_head(h)               # (B, N_ERR)

        return logits, err_hat
# --- End of re-included definitions ---


# Load model
if not Path(CKPT_FILE).exists():
    print(f"Error: Checkpoint file not found at {CKPT_FILE}")
    exit()

print(f"Loading model from {CKPT_FILE}...")
# Load the state dict if you have the class defined, or the whole model if saved that way
# Assuming the whole model was saved with torch.save(model, ckpt_file)
infer_model = torch.load(CKPT_FILE, map_location=DEVICE)
infer_model.eval()
print("✅ Model loaded.")

# Exercise map
EXERCISE_MAP = {
    1: "Arm abduction",
    2: "Arm VW",
    3: "Push-ups",
    4: "Leg abduction",
    5: "Leg lunge",
    6: "Squats"
}
NUM_EXERCISES = len(EXERCISE_MAP) # Ensure consistency

# Ask user for exercise ID
while True:
    try:
        exercise_id_str = input(f"Enter the exercise ID you're performing (1-{len(EXERCISE_MAP)}): ")
        exercise_id = int(exercise_id_str)
        if 1 <= exercise_id <= len(EXERCISE_MAP):
            exercise_name = EXERCISE_MAP[exercise_id]
            break
        else:
            print(f"Invalid ID. Please enter a number between 1 and {len(EXERCISE_MAP)}.")
    except ValueError:
        print("Invalid input. Please enter a number.")

# MediaPipe Pose Setup --- MODIFIED ---
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    static_image_mode=False,
    model_complexity=2, # <<< MATCHED training complexity
    enable_segmentation=False, # Keep false if not used
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5
)
mp_drawing = mp.solutions.drawing_utils

# Keypoints extraction function --- MODIFIED ---
def extract_keypoints(frame):
    img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img_rgb.flags.writeable = False
    result = pose.process(img_rgb)
    img_rgb.flags.writeable = True

    keypoints = None
    landmarks_for_drawing = None # Still useful to draw image landmarks

    # <<< USE pose_world_landmarks >>>
    if result.pose_world_landmarks:
        world_landmarks = result.pose_world_landmarks.landmark
        # Store x, y, z world coordinates
        keypoints = np.array([(lm.x, lm.y, lm.z) for lm in world_landmarks], dtype=np.float32)

    # Get image landmarks separately for drawing (optional but helpful)
    if result.pose_landmarks:
         landmarks_for_drawing = result.pose_landmarks

    # Return world keypoints (for model) and image landmarks (for drawing)
    return keypoints, landmarks_for_drawing

# Inference parameters
SEQUENCE_LENGTH = 16 # Match the buffer size used in inference
IN_DIM = N_JOINTS * 3 # 33 * 3 = 99

# Inference for correctness (model prediction and feedback)
def infer_and_feedback(model, video_path, selected_ex_id, selected_ex_name):
    # Check if a video path is provided or use the camera
    if video_path is None:
        cap = cv2.VideoCapture(0)  # 0 refers to the default camera (webcam)
        if not cap.isOpened():
            print("Error: Could not access the camera.")
            return
        print("Using camera feed...")
    else:
        cap = cv2.VideoCapture(video_path)  # Use the provided video path
        if not cap.isOpened():
            print(f"Error: Could not open video file {video_path}")
            return
        print(f"Using video: {video_path}")

    keypoints_buffer = deque(maxlen=SEQUENCE_LENGTH)
    feedback = "Initializing..."
    err_values = np.zeros(N_ERR)  # Store last error values
    predicted_class = 0  # Default to incorrect initially
    suggestion = ""  # Variable to store the suggestion
    suggestion_time = 0  # Timer to track suggestion display duration (in frames)
    SUGGESTION_DURATION = 3 * 30  # 3 seconds (assuming 30 fps, so 3 * 30 frames)
    
    wrong_exercise_detected = False  # Flag to track if a wrong exercise is detected
    wrong_exercise_time = 0  # Timer to track the duration of wrong exercise detection

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            print("End of video or cannot read frame.")
            break

        # --- Get WORLD keypoints for model, IMAGE landmarks for drawing ---
        world_keypoints, image_landmarks_for_drawing = extract_keypoints(frame)

        # --- Draw IMAGE landmarks if detected ---
        if image_landmarks_for_drawing:
            mp_drawing.draw_landmarks(
                frame,
                image_landmarks_for_drawing,  # Use image landmarks here
                mp_pose.POSE_CONNECTIONS,
                landmark_drawing_spec=mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
                connection_drawing_spec=mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)
            )

        # --- Use WORLD keypoints for the model ---
        if world_keypoints is not None:
            keypoints_buffer.append(world_keypoints)  # Add WORLD keypoints to buffer

            # Check if buffer is full
            if len(keypoints_buffer) == SEQUENCE_LENGTH:
                # Prepare sequence for model using world keypoints
                keypoints_array = np.array(keypoints_buffer, dtype=np.float32)  # (16, 33, 3) world coords
                keypoints_flat = keypoints_array.reshape(SEQUENCE_LENGTH, -1)  # (16, 99) world coords
                seq = torch.tensor(keypoints_flat, dtype=torch.float32).unsqueeze(0).to(DEVICE)  # (1, 16, 99)

                ex_tensor = torch.tensor([selected_ex_id - 1], device=DEVICE)
                ex_1hot = F.one_hot(ex_tensor, num_classes=NUM_EXERCISES).float()

                model.eval()
                with torch.no_grad():
                    logits, err_hat = model(seq, ex_1hot)
                    predicted_class = logits.argmax(1).item()  # 0: incorrect, 1: correct
                    err_values = err_hat.squeeze().cpu().numpy()

                feedback = "Correct" if predicted_class == 1 else "Incorrect"

                # If the posture is incorrect, find the joint with the largest error and suggest correction
                if predicted_class == 0:
                    # Find the joint with the maximum error
                    max_error_idx = np.argmax(np.abs(err_values))  # Get index of the joint with the largest error
                    joint_with_error = ERR_JOINTS[max_error_idx]
                    max_error = err_values[max_error_idx]
                    suggestion = f"Correct your {joint_with_error.replace('_', ' ')} (Error: {max_error:+.2f}°)"
                    suggestion_time = SUGGESTION_DURATION  # Reset suggestion timer for 3 seconds

                # Detect wrong exercise based on high error across multiple joints for a sustained time
                error_threshold = 10  # Adjust as needed (threshold for detecting wrong exercise)
                wrong_joints = np.sum(np.abs(err_values) > error_threshold)  # Count joints with high error

                if wrong_joints >= 4:  # For example, detect wrong exercise if 4 or more joints are off
                    wrong_exercise_detected = True
                    wrong_exercise_time += 1
                    if wrong_exercise_time > 90:  # 3 seconds of high error across multiple joints (3*30 fps)
                        # Show a large warning on the center of the screen
                        frame_height, frame_width = frame.shape[:2]  # Get the frame dimensions
                        text = "WRONG EXERCISE DETECTED!"
                        text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0]
                        text_x = (frame_width - text_size[0]) // 2  # Calculate X position to center text
                        text_y = (frame_height + text_size[1]) // 2  # Calculate Y position to center text
                        cv2.putText(frame, text, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 3, cv2.LINE_AA)
                else:
                    wrong_exercise_detected = False
                    wrong_exercise_time = 0  # Reset if exercise is correct

            else:
                feedback = f"Collecting frames... {len(keypoints_buffer)}/{SEQUENCE_LENGTH}"

        else:
            # Handle case where no world keypoints are detected
            feedback = "No pose detected"
            keypoints_buffer.clear()  # Clear buffer if detection lost
            predicted_class = 0  # Reset prediction if pose lost
            suggestion = ""  # Reset suggestion when no pose detected
            suggestion_time = 0  # Reset suggestion timer

        # Display feedback on frame
        cv2.putText(frame, f"Exercise: {selected_ex_name}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)
        feedback_color = (0, 255, 0) if predicted_class == 1 and len(keypoints_buffer) == SEQUENCE_LENGTH else (0, 0, 255)
        cv2.putText(frame, f"Feedback: {feedback}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, feedback_color, 2, cv2.LINE_AA)

        # Show the suggestion for 3 seconds (timer logic)
        if suggestion_time > 0:
            cv2.putText(frame, suggestion, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2, cv2.LINE_AA)
            suggestion_time -= 1  # Decrease timer on each frame

        # Display specific joint feedback based on error values
        feedback_y_start = 130
        error_threshold = 0.15  # Adjust as needed

        if len(keypoints_buffer) == SEQUENCE_LENGTH:  # Only show errors if inference ran
            for i, joint_name in enumerate(ERR_JOINTS):
                err_val = err_values[i]
                if abs(err_val) > error_threshold:  # Use abs() as error can be +/-
                    color = (0, 165, 255)  # Orange/Yellow for warning
                    text = f"{joint_name}: Check ({err_val:+.2f})"
                    cv2.putText(frame, text, (10, feedback_y_start + i * 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        cv2.imshow('Pose Estimation Feedback', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
    pose.close()


  infer_model = torch.load(CKPT_FILE, map_location=DEVICE)


► Using device: mps
Loading model from kp_pose_quality_windows_ex1.pt...
✅ Model loaded.


I0000 00:00:1745045005.890300 1145256 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.3), renderer: Apple M4 Max


INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


Attention: Set Vedio path =None to do live inferencewith laptop camera

In [None]:
# --- Main execution part ---
# VIDEO_PATH = "Data-REHAB24-6/Videos/Ex1/PM_001-Camera17-30fps.mp4"
#VIDEO_PATH = "Data-REHAB24-6/Videos/Ex3/PM_010-Camera17-30fps.mp4"
# --- Main execution part ---
VIDEO_PATH = None  # This is where you set the video path (None for camera feed)

# Check if VIDEO_PATH is None or invalid
if VIDEO_PATH is None:
    print("Using camera feed...")
    cap = cv2.VideoCapture(0)  # 0 refers to the default camera (webcam)
    if not cap.isOpened():
        print("Error: Could not access the camera.")
        exit()
else:
    cap = cv2.VideoCapture(VIDEO_PATH)  # Use the provided video path
    if not cap.isOpened():
        print(f"Error: Could not open video file {VIDEO_PATH}")
        exit()

# Run the inference and feedback function
infer_and_feedback(infer_model, VIDEO_PATH, exercise_id, exercise_name)

W0000 00:00:1745045005.949570 1211346 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1745045005.990816 1211346 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


Using camera feed...
Using camera feed...


W0000 00:00:1745045006.586156 1211345 landmark_projection_calculator.cc:186] Using NORM_RECT without IMAGE_DIMENSIONS is only supported for the square ROI. Provide IMAGE_DIMENSIONS or use PROJECTION_MATRIX.


: 