In [2]:
from collections import deque

import cv2
import joblib
import mediapipe as mp
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report

# Set EVALUATION_MODE to True to compare against a ground-truth CSV.
# When enabled, the CSV below is loaded and the per-window predictions are
# scored against it (accuracy + classification report) after the run.
EVALUATION_MODE = False
# Recorded squat clip (raw string so the Windows backslashes stay literal).
VIDEO_PATH = r"C:\Users\Nour_Azab\Downloads\squat_31.mp4"
# Ground-truth labels; the script reads its 'frame' and 'phase' columns.
GROUND_TRUTH_CSV_PATH = r"C:\Users\Nour_Azab\Downloads\pose_landmarks31.csv"

# Model weights (a state dict loaded with torch.load) and the joblib-pickled
# feature scaler / phase label encoder.
MODEL_PATH = "transformer_multi_task.pth"
SCALER_PATH = "pose_scaler.pkl"
ENCODER_PATH = "phase_encoder.pkl"

# Number of most recent frames fed to the model per prediction (also the
# model's seq_len).
WINDOW_SIZE = 30
# Values per frame: (x, y, visibility) for every pose landmark.
# 99 = 33 landmarks * 3 -- assumes MediaPipe Pose emits 33 landmarks; TODO confirm.
NUM_FEATURES = 99
# A window whose reconstruction MSE is below this value is reported as
# "Good Form"; at or above it the overlay shows "Check Form!".
RECONSTRUCTION_ERROR_THRESHOLD = 0.1 # Tune this for the trained model / data


class TransformerAutoencoderClassifier(nn.Module):
    """Transformer with one encoder feeding two heads.

    ``forward`` returns a reconstruction of the input sequence (produced by
    the decoder attending over the encoder output) together with class logits
    computed from the flattened encoder output.

    Args:
        num_features: Size of each per-step feature vector.
        seq_len: Number of steps in every input sequence; it fixes the width
            of the classifier's first linear layer.
        d_model: Embedding width used by the encoder and decoder.
        nhead: Number of attention heads.
        num_layers: Number of stacked layers in the encoder and in the decoder.
        num_classes: Number of logits produced by the classification head.
    """

    def __init__(self, num_features, seq_len, d_model=128, nhead=8, num_layers=3, num_classes=4):
        super().__init__()
        self.seq_len = seq_len

        # Lift raw per-step features into the transformer's embedding space.
        self.input_proj = nn.Linear(num_features, d_model)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers,
        )

        # Map decoder output back to the original feature width.
        self.output_proj = nn.Linear(d_model, num_features)

        # The classifier consumes the whole encoded sequence at once.
        flattened_width = d_model * seq_len
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_width, d_model),
            nn.ReLU(),
            nn.Linear(d_model, num_classes),
        )

    def forward(self, x):
        """Return ``(reconstruction, class_logits)`` for the batch ``x``."""
        embedded = self.input_proj(x)
        memory = self.encoder(embedded)
        # The decoder takes the embedded input as its target sequence and the
        # encoder output as memory.
        decoded = self.decoder(embedded, memory)
        return self.output_proj(decoded), self.classifier(memory)

# Load Model, Preprocessors, and Ground Truth Data if present
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): joblib.load and torch.load unpickle their input, which can run
# arbitrary code -- only point these paths at files you trust.
scaler = joblib.load(SCALER_PATH)
label_encoder = joblib.load(ENCODER_PATH)
num_classes = len(label_encoder.classes_)

model = TransformerAutoencoderClassifier(num_features=NUM_FEATURES, seq_len=WINDOW_SIZE, num_classes=num_classes).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()  # inference only: disables dropout

print("Model and preprocessors loaded successfully")

# Load ground truth data if in evaluation mode
df_true = None
# Always defined so later lookups never hit a NameError, even when evaluation
# is off or gets disabled below.
true_labels_map = {}
if EVALUATION_MODE:
    try:
        df_true = pd.read_csv(GROUND_TRUTH_CSV_PATH)
    except FileNotFoundError:
        print(f"Ground truth CSV not found at {GROUND_TRUTH_CSV_PATH}. Disabling evaluation mode.")
        EVALUATION_MODE = False
    else:
        # A CSV without the expected columns is treated like a missing file:
        # fall back to plain inference instead of crashing with a KeyError.
        missing_columns = {"frame", "phase"} - set(df_true.columns)
        if missing_columns:
            print(f"Ground truth CSV is missing column(s) {sorted(missing_columns)}. Disabling evaluation mode.")
            EVALUATION_MODE = False
        else:
            # Create a quick lookup dictionary: frame_number -> phase_label
            true_labels_map = dict(zip(df_true['frame'], df_true['phase']))
            print("Loaded ground truth data")

# Real-time Inference & Evaluation Pipeline
#
# Reads frames, extracts pose landmarks with MediaPipe, keeps a sliding window
# of the latest WINDOW_SIZE feature vectors, and for every full window predicts
# the phase and computes the reconstruction error shown on the overlay.

mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)

# Ground-truth labels are looked up by frame number, which can only line up
# with the recorded clip; outside evaluation mode keep streaming from camera 0.
video_source = VIDEO_PATH if EVALUATION_MODE else 0
cap = cv2.VideoCapture(video_source)

# Buffers and trackers
# Only the newest WINDOW_SIZE frames are ever used, so a bounded deque keeps
# memory flat during long camera sessions (old frames are dropped automatically).
pose_data_buffer = deque(maxlen=WINDOW_SIZE)
frame_count = 0
current_phase = "Waiting..."
reconstruction_error = 0.0
form_feedback = ""
y_true, y_pred = [], [] # For final accuracy calculation

try:
    if not cap.isOpened():
        # Report the source that was actually opened.
        raise IOError(f"Cannot open video source: {video_source}")

    while True:
        success, frame = cap.read()
        if not success:
            break

        frame_count += 1
        image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = pose.process(image_rgb)

        # Feature Extraction
        # Default to an all-zero vector; it is used when no pose is detected
        # or when the landmark count does not match NUM_FEATURES.
        frame_features = [0.0] * NUM_FEATURES
        if results.pose_landmarks:
            # Get the frame's dimensions for pixel coordinate conversion
            h, w, _ = frame.shape

            # Draw skeleton on the frame
            mp_drawing.draw_landmarks(
                frame,
                results.pose_landmarks,
                mp_pose.POSE_CONNECTIONS,
                mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2),
                mp_drawing.DrawingSpec(color=(0, 0, 255), thickness=2, circle_radius=2)
            )

            # Extract landmarks as (pixel x, pixel y, visibility) triples
            landmarks = []
            for lm in results.pose_landmarks.landmark:
                landmarks.extend([int(lm.x * w), int(lm.y * h), lm.visibility])

            if len(landmarks) == NUM_FEATURES:
                frame_features = landmarks

        pose_data_buffer.append(frame_features)

        # Run the model once the buffer holds a full window
        if len(pose_data_buffer) == WINDOW_SIZE:
            window = np.array(pose_data_buffer)
            scaled_window = scaler.transform(window)
            window_tensor = torch.tensor(scaled_window, dtype=torch.float32).unsqueeze(0).to(device)

            with torch.no_grad():
                recon_out, class_out = model(window_tensor)
                pred_index = torch.argmax(class_out, dim=1).item()
                current_phase = label_encoder.inverse_transform([pred_index])[0]

                # Reconstruction error against the scaled input window
                reconstruction_error = F.mse_loss(recon_out, window_tensor).item()
                form_feedback = "Good Form" if reconstruction_error < RECONSTRUCTION_ERROR_THRESHOLD else "Check Form!"

            # If in evaluation mode, store true and predicted labels
            if EVALUATION_MODE:
                # The label corresponds to the center frame of the window.
                # NOTE(review): assumes the CSV's 'frame' column counts frames
                # the same way as frame_count (starting at 1) -- confirm.
                center_frame_idx = frame_count - (WINDOW_SIZE // 2)
                true_phase = true_labels_map.get(center_frame_idx)

                # Test for "missing" explicitly: a truthiness check would also
                # drop a valid falsy label such as 0, and a NaN (blank CSV
                # cell) would make label_encoder.transform raise.
                if true_phase is not None and not pd.isna(true_phase):
                    true_label_encoded = label_encoder.transform([true_phase])[0]
                    y_true.append(true_label_encoded)
                    y_pred.append(pred_index)

        # Feedback panel background
        cv2.rectangle(frame, (5, 5), (300, 110), (0, 0, 0), -1)

        # Display Phase
        cv2.putText(frame, f"Phase: {current_phase}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        # Display Reconstruction Error and Form Feedback
        feedback_color = (0, 255, 0) if form_feedback == "Good Form" else (0, 0, 255)
        cv2.putText(frame, f"Error: {reconstruction_error:.4f}", (10, 65),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        cv2.putText(frame, f"Feedback: {form_feedback}", (10, 95),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, feedback_color, 2)

        cv2.imshow("Squat Analysis", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the capture device, windows and pose graph even if a frame
    # raised, so the camera is not left locked.
    cap.release()
    cv2.destroyAllWindows()
    pose.close()

# Final report: printed only when evaluation was enabled and at least one
# window could be matched to a ground-truth label.
if EVALUATION_MODE and y_true:
    print("\n Test Evaluation ")
    accuracy = accuracy_score(y_true, y_pred)
    print(f"Overall Accuracy: {accuracy:.4f}")
    print("\nClassification Report:")
    # Pin the label set to every class the encoder knows. Without `labels`,
    # classification_report infers it from y_true/y_pred and raises ValueError
    # when the clip does not contain all classes named in target_names.
    all_label_ids = list(range(len(label_encoder.classes_)))
    class_names = [str(name) for name in label_encoder.classes_]
    print(classification_report(y_true, y_pred, labels=all_label_ids, target_names=class_names))
else:
    print("\nInference complete. No evaluation was performed.")

Model and preprocessors loaded successfully





Inference complete. No evaluation was performed.
