In [11]:
# 0. Imports
from __future__ import annotations
import os, random, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm

import mediapipe as mp

# Replace this failing line:
# from mediapipe.solutions.pose import PoseLandmark

# With:
PoseLandmark = mp.solutions.pose.PoseLandmark

# Then:
JOINT_NAMES = [lm.name for lm in PoseLandmark]
N_JOINTS    = len(JOINT_NAMES)  # should be 33

print(f"JOINT_NAMES: {JOINT_NAMES}")
print(f"N_JOINTS: {N_JOINTS}")

JOINT_NAMES: ['NOSE', 'LEFT_EYE_INNER', 'LEFT_EYE', 'LEFT_EYE_OUTER', 'RIGHT_EYE_INNER', 'RIGHT_EYE', 'RIGHT_EYE_OUTER', 'LEFT_EAR', 'RIGHT_EAR', 'MOUTH_LEFT', 'MOUTH_RIGHT', 'LEFT_SHOULDER', 'RIGHT_SHOULDER', 'LEFT_ELBOW', 'RIGHT_ELBOW', 'LEFT_WRIST', 'RIGHT_WRIST', 'LEFT_PINKY', 'RIGHT_PINKY', 'LEFT_INDEX', 'RIGHT_INDEX', 'LEFT_THUMB', 'RIGHT_THUMB', 'LEFT_HIP', 'RIGHT_HIP', 'LEFT_KNEE', 'RIGHT_KNEE', 'LEFT_ANKLE', 'RIGHT_ANKLE', 'LEFT_HEEL', 'RIGHT_HEEL', 'LEFT_FOOT_INDEX', 'RIGHT_FOOT_INDEX']
N_JOINTS: 33


In [None]:
# 1. Paths & device
SCRIPT_DIR     = Path(os.getcwd()).resolve()
META_FILE      = SCRIPT_DIR / "Data-REHAB24-6" / "Segmentation.xlsx"
KEYPOINT_ROOT  = SCRIPT_DIR / "Data-REHAB24-6" / "mp_keypoints"

DEVICE = (
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("cpu")
)
print("► Using device:", DEVICE)

► Using device: mps


In [13]:
# 2. Joint names & count
JOINT_NAMES = [jl.name for jl in PoseLandmark]
N_JOINTS    = len(JOINT_NAMES)   # 33
# save them for reference
names_file = KEYPOINT_ROOT / "mediapipe_joint_names.txt"
KEYPOINT_ROOT.mkdir(exist_ok=True)
with open(names_file, "w") as f:
    for i, name in enumerate(JOINT_NAMES):
        f.write(f"{i}: {name}\n")
print("✔ Saved MediaPipe joint names to", names_file)

✔ Saved MediaPipe joint names to /Users/jithinkrishnan/Documents/Study/IS06 /MVP/RehabApp/model-training-scripts/Data-REHAB24-6/mp_keypoints/mediapipe_joint_names.txt


In [14]:
# 3. Feedback helper
import math

# friendly names for a few key joints (you can extend this)
FRIENDLY = {
    "LEFT_ELBOW":     "left elbow",
    "RIGHT_ELBOW":    "right elbow",
    "LEFT_SHOULDER":  "left shoulder",
    "RIGHT_SHOULDER": "right shoulder",
    "LEFT_KNEE":      "left knee",
    "RIGHT_KNEE":     "right knee",
    "SPINE":          "spine",
    "NOSE":           "head",
    # etc…
}

# for each body part keyword, tuple = (verb if err<0, verb if err>0)
ACTION_VERBS = {
    "elbow":    ("flex",        "extend"),
    "shoulder": ("rotate back", "rotate forward"),
    "knee":     ("bend",        "straighten"),
    "hip":      ("lower",       "lift"),
    "spine":    ("arch",        "tuck"),
    # fallback will be ("move", "move")
}

def magnitude_adverb(deg: float) -> str:
    a = abs(deg)
    if a < 5:   return "slightly"
    if a < 15:  return "noticeably"
    return        "significantly"

def get_action(part: str, deg: float) -> str:
    part = part.lower()
    for key, (neg, pos) in ACTION_VERBS.items():
        if key in part:
            return neg if deg < 0 else pos
    return "move"  # generic fallback

def english_feedback(err_vec: list[float], tol: float = 5.0) -> list[str]:
    # pick joints with error > tol
    joints = [(j, deg) for j, deg in enumerate(err_vec) if abs(deg) > tol]
    if not joints:
        return ["✅ Perfect form! Hold that posture."]

    # sort by largest error first, keep top 3
    joints = sorted(joints, key=lambda x: abs(x[1]), reverse=True)[:3]

    tips = []
    for j, deg in joints:
        name = JOINT_NAMES[j]
        part = FRIENDLY.get(name, name.lower().replace("_", " "))
        adv  = magnitude_adverb(deg).capitalize()
        verb = get_action(part, deg)
        tips.append(f"{adv} {verb} your {part} by {abs(deg):.1f}°")
    return tips

# — example usage —
# err_vec = [0, -12.3, 3.1, 20.5, …]  # degrees off for each JOINT_NAMES entry
# print(english_feedback(err_vec))


In [15]:
# 4. Dataset
class KeypointDataset(Dataset):
    def __init__(self,
                 keypt_root: Path,
                 meta:       Path,
                 frames:     int = 16,
                 split:      str = "train"):
        # load metadata (xlsx or csv)
        ext = meta.suffix.lower()
        if ext in (".xlsx", ".xls"):
            df = pd.read_excel(meta, engine="openpyxl")
        else:
            df = pd.read_csv(meta, skipinitialspace=True)
        df.columns = df.columns.str.strip()

        # ensure error‑columns exist
        err_cols = [f"err_{i}" for i in range(N_JOINTS)]
        if not all(c in df.columns for c in err_cols):
            warnings.warn("err_0…err_NJOINTS not found → filling zeros")
            for c in err_cols:
                df[c] = 0.0

        # split video_ids
        vids = sorted(df["video_id"].unique())
        random.shuffle(vids)
        a, b = int(0.7*len(vids)), int(0.85*len(vids))
        keep = (
            vids[:a]   if split=="train" else
            vids[a:b]  if split=="val"   else
            vids[b:]
        )
        self.meta       = df[df.video_id.isin(keep)].reset_index(drop=True)
        self.err_cols   = err_cols
        self.frames     = frames
        self.keypt_root = keypt_root

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx: int):
        row    = self.meta.iloc[idx]
        ex_dir = self.keypt_root / f"Ex{row.exercise_id}"

        # pick Camera17 file by globbing
        pattern = f"{row.video_id}-Camera17*-mp.npy"
        files = list(ex_dir.glob(pattern))
        if not files:
            raise FileNotFoundError(f"No Camera17 MP file matching {pattern} in {ex_dir}")
        if len(files) > 1:
            # unlikely, but guard against duplicates
            files = sorted(files)
        mp_file = files[0]

        # load & reshape
        arr = np.load(mp_file)        # (F, N_JOINTS, 3)
        F, L, C = arr.shape
        assert C == 3 and L == N_JOINTS, f"Expected (F,{N_JOINTS},3), got {arr.shape}"
        arr = arr.reshape(F, L*C)     # → (F, N_JOINTS*3)

        # sample T frames
        tot  = arr.shape[0]
        idxs = torch.linspace(0, tot-1, self.frames).long()
        seq  = torch.from_numpy(arr[idxs]).float()  # (T, N_JOINTS*3)

        # labels & errors
        label    = torch.tensor(row.correctness, dtype=torch.long)
        err_vals = row[self.err_cols].to_numpy(dtype=np.float32)
        err      = torch.tensor(err_vals, dtype=torch.float32)

        return seq, label, err


In [16]:
# 5. Model definitions
class KeypointEncoder(nn.Module):
    def __init__(self, in_dim: int, embed: int = 512):
        super().__init__()
        self.conv1 = nn.Conv1d(in_dim, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(128, embed, kernel_size=3, padding=1)
        self.pool  = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        #x = x.transpose(1,2)               # (B, D, T)
        # x is (B, D); treat it as (B, D, 1) so Conv1d can run over that “length=1”
        x = x.unsqueeze(2)                # → (B, D, 1)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return self.pool(x).squeeze(-1)    # (B, embed)

class PoseQualityNetKP(nn.Module):
    def __init__(self, in_dim: int, hidden: int = 256):
        super().__init__()
        self.encoder  = KeypointEncoder(in_dim)
        self.lstm     = nn.LSTM(
            input_size=512,
            hidden_size=hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        dim = hidden*2
        self.cls_head = nn.Linear(dim, 2)
        self.err_head = nn.Linear(dim, N_JOINTS)

    def forward(self, seq):
        # seq: (B, T, D)
        feats = torch.stack([
            self.encoder(seq[:, t]) for t in range(seq.size(1))
        ], dim=1)                         # (B, T, 512)
        out, _ = self.lstm(feats)         # (B, T, 2*hidden)
        g = out.mean(1)                   # (B, 2*hidden)
        return self.cls_head(g), self.err_head(g)

In [17]:
# 6. Training loop
def train_epochs(
    epochs:    int = 30,
    batch:     int = 16,
    lr:        float = 1e-4,
    ckpt_file: str = "kp_pose_quality_mp.pt"
):
    train_ds = KeypointDataset(KEYPOINT_ROOT, META_FILE, split="train")
    val_ds   = KeypointDataset(KEYPOINT_ROOT, META_FILE, split="val")
    train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True)
    val_dl   = DataLoader(val_ds,   batch_size=batch, shuffle=False)

    # infer input dim
    sample_seq, _, _ = train_ds[0]
    in_dim = sample_seq.shape[-1]  # 33*3 = 99

    model    = PoseQualityNetKP(in_dim).to(DEVICE)
    loss_cls = nn.CrossEntropyLoss()
    loss_err = nn.SmoothL1Loss()
    opt      = Adam(model.parameters(), lr)
    best_f1  = 0.0

    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        for seq, y, err in tqdm(train_dl, desc=f"Epoch {epoch:02d}"):
            seq, y, err = seq.to(DEVICE), y.to(DEVICE), err.to(DEVICE)
            opt.zero_grad()
            logits, err_hat = model(seq)
            loss = loss_cls(logits, y) + 0.1*loss_err(err_hat, err)
            loss.backward()
            opt.step()
            total_loss += loss.item() * y.size(0)
        print(f"  ↳ train loss: {total_loss/len(train_ds):.4f}")

        # validation
        model.eval()
        y_true, y_pred, errs = [], [], []
        with torch.no_grad():
            for seq, y, err in val_dl:
                logits, err_hat = model(seq.to(DEVICE))
                y_true += y.tolist()
                y_pred += logits.argmax(1).cpu().tolist()
                errs    += [torch.abs(err_hat.cpu() - err).mean(1)]
        acc = accuracy_score(y_true, y_pred)
        f1  = f1_score(y_true, y_pred)
        mae = torch.cat(errs).mean().item()
        print(f"  ↳ val acc {acc:.3f}, F1 {f1:.3f}, MAE° {mae:.2f}")

        if f1 > best_f1:
            best_f1 = f1

            # package checkpoint with everything needed to rebuild & infer
            checkpoint = {
                "state_dict": model.state_dict(),
                "input_dim":  in_dim,
                "hidden":     model.lstm.hidden_size,
                "frames":     train_ds.frames,
                "epoch":      epoch,
                "best_f1":    best_f1
            }

            torch.save(checkpoint, ckpt_file)
            print(f"  ✓ saved new best model to {ckpt_file}  (epoch {epoch}, F1 {f1:.3f})")

In [18]:
# 8. Run training
train_epochs(epochs=50, batch=16, lr=1e-4)


Epoch 01: 100%|██████████| 46/46 [00:01<00:00, 36.69it/s]


  ↳ train loss: 0.6925
  ↳ val acc 0.486, F1 0.654, MAE° 0.00
  ✓ saved new best model to kp_pose_quality_mp.pt  (epoch 1, F1 0.654)


Epoch 02: 100%|██████████| 46/46 [00:00<00:00, 50.56it/s]


  ↳ train loss: 0.6935
  ↳ val acc 0.486, F1 0.654, MAE° 0.00


Epoch 03: 100%|██████████| 46/46 [00:00<00:00, 46.31it/s]


  ↳ train loss: 0.6919
  ↳ val acc 0.486, F1 0.654, MAE° 0.00


Epoch 04: 100%|██████████| 46/46 [00:01<00:00, 45.28it/s]


  ↳ train loss: 0.6913
  ↳ val acc 0.486, F1 0.654, MAE° 0.00


Epoch 05: 100%|██████████| 46/46 [00:01<00:00, 45.46it/s]


  ↳ train loss: 0.6895
  ↳ val acc 0.486, F1 0.654, MAE° 0.00


Epoch 06: 100%|██████████| 46/46 [00:01<00:00, 45.99it/s]


  ↳ train loss: 0.6836
  ↳ val acc 0.549, F1 0.557, MAE° 0.00


Epoch 07: 100%|██████████| 46/46 [00:00<00:00, 48.87it/s]


  ↳ train loss: 0.6748
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 08: 100%|██████████| 46/46 [00:00<00:00, 49.64it/s]


  ↳ train loss: 0.6776
  ↳ val acc 0.549, F1 0.621, MAE° 0.00


Epoch 09: 100%|██████████| 46/46 [00:00<00:00, 49.84it/s]


  ↳ train loss: 0.6708
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 10: 100%|██████████| 46/46 [00:00<00:00, 50.14it/s]


  ↳ train loss: 0.6716
  ↳ val acc 0.549, F1 0.621, MAE° 0.00


Epoch 11: 100%|██████████| 46/46 [00:00<00:00, 50.42it/s]


  ↳ train loss: 0.6671
  ↳ val acc 0.549, F1 0.602, MAE° 0.00


Epoch 12: 100%|██████████| 46/46 [00:00<00:00, 49.84it/s]


  ↳ train loss: 0.6693
  ↳ val acc 0.549, F1 0.602, MAE° 0.00


Epoch 13: 100%|██████████| 46/46 [00:00<00:00, 49.74it/s]


  ↳ train loss: 0.6668
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 14: 100%|██████████| 46/46 [00:00<00:00, 49.29it/s]


  ↳ train loss: 0.6697
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 15: 100%|██████████| 46/46 [00:00<00:00, 49.87it/s]


  ↳ train loss: 0.6671
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 16: 100%|██████████| 46/46 [00:00<00:00, 50.11it/s]


  ↳ train loss: 0.6692
  ↳ val acc 0.549, F1 0.557, MAE° 0.00


Epoch 17: 100%|██████████| 46/46 [00:00<00:00, 49.77it/s]


  ↳ train loss: 0.6661
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 18: 100%|██████████| 46/46 [00:00<00:00, 49.97it/s]


  ↳ train loss: 0.6669
  ↳ val acc 0.549, F1 0.557, MAE° 0.00


Epoch 19: 100%|██████████| 46/46 [00:00<00:00, 50.62it/s]


  ↳ train loss: 0.6670
  ↳ val acc 0.549, F1 0.557, MAE° 0.00


Epoch 20: 100%|██████████| 46/46 [00:00<00:00, 50.41it/s]


  ↳ train loss: 0.6659
  ↳ val acc 0.549, F1 0.581, MAE° 0.00


Epoch 21: 100%|██████████| 46/46 [00:00<00:00, 50.90it/s]


  ↳ train loss: 0.6646
  ↳ val acc 0.549, F1 0.557, MAE° 0.00


Epoch 22: 100%|██████████| 46/46 [00:00<00:00, 50.17it/s]


  ↳ train loss: 0.6643
  ↳ val acc 0.549, F1 0.557, MAE° 0.00


Epoch 23: 100%|██████████| 46/46 [00:00<00:00, 49.94it/s]


  ↳ train loss: 0.6668
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 24: 100%|██████████| 46/46 [00:00<00:00, 49.61it/s]


  ↳ train loss: 0.6637
  ↳ val acc 0.671, F1 0.607, MAE° 0.00


Epoch 25: 100%|██████████| 46/46 [00:00<00:00, 50.07it/s]


  ↳ train loss: 0.6673
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 26: 100%|██████████| 46/46 [00:00<00:00, 50.56it/s]


  ↳ train loss: 0.6647
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 27: 100%|██████████| 46/46 [00:00<00:00, 49.88it/s]


  ↳ train loss: 0.6635
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 28: 100%|██████████| 46/46 [00:00<00:00, 50.32it/s]


  ↳ train loss: 0.6647
  ↳ val acc 0.549, F1 0.557, MAE° 0.00


Epoch 29: 100%|██████████| 46/46 [00:00<00:00, 50.62it/s]


  ↳ train loss: 0.6633
  ↳ val acc 0.549, F1 0.557, MAE° 0.00


Epoch 30: 100%|██████████| 46/46 [00:00<00:00, 50.65it/s]


  ↳ train loss: 0.6635
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 31: 100%|██████████| 46/46 [00:00<00:00, 50.30it/s]


  ↳ train loss: 0.6627
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 32: 100%|██████████| 46/46 [00:00<00:00, 50.95it/s]


  ↳ train loss: 0.6611
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 33: 100%|██████████| 46/46 [00:00<00:00, 49.88it/s]


  ↳ train loss: 0.6613
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 34: 100%|██████████| 46/46 [00:00<00:00, 47.47it/s]


  ↳ train loss: 0.6644
  ↳ val acc 0.543, F1 0.652, MAE° 0.00


Epoch 35: 100%|██████████| 46/46 [00:00<00:00, 49.19it/s]


  ↳ train loss: 0.6613
  ↳ val acc 0.549, F1 0.557, MAE° 0.00


Epoch 36: 100%|██████████| 46/46 [00:00<00:00, 48.77it/s]


  ↳ train loss: 0.6622
  ↳ val acc 0.549, F1 0.581, MAE° 0.00


Epoch 37: 100%|██████████| 46/46 [00:00<00:00, 49.42it/s]


  ↳ train loss: 0.6615
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 38: 100%|██████████| 46/46 [00:00<00:00, 48.94it/s]


  ↳ train loss: 0.6611
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 39: 100%|██████████| 46/46 [00:00<00:00, 48.89it/s]


  ↳ train loss: 0.6597
  ↳ val acc 0.543, F1 0.680, MAE° 0.00
  ✓ saved new best model to kp_pose_quality_mp.pt  (epoch 39, F1 0.680)


Epoch 40: 100%|██████████| 46/46 [00:00<00:00, 48.90it/s]


  ↳ train loss: 0.6649
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 41: 100%|██████████| 46/46 [00:00<00:00, 49.94it/s]


  ↳ train loss: 0.6594
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 42: 100%|██████████| 46/46 [00:00<00:00, 49.87it/s]


  ↳ train loss: 0.6603
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 43: 100%|██████████| 46/46 [00:00<00:00, 49.46it/s]


  ↳ train loss: 0.6619
  ↳ val acc 0.549, F1 0.557, MAE° 0.00


Epoch 44: 100%|██████████| 46/46 [00:00<00:00, 49.65it/s]


  ↳ train loss: 0.6600
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 45: 100%|██████████| 46/46 [00:00<00:00, 49.64it/s]


  ↳ train loss: 0.6599
  ↳ val acc 0.543, F1 0.680, MAE° 0.00


Epoch 46: 100%|██████████| 46/46 [00:00<00:00, 48.24it/s]


  ↳ train loss: 0.6600
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 47: 100%|██████████| 46/46 [00:00<00:00, 48.87it/s]


  ↳ train loss: 0.6580
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 48: 100%|██████████| 46/46 [00:00<00:00, 49.58it/s]


  ↳ train loss: 0.6623
  ↳ val acc 0.543, F1 0.652, MAE° 0.00


Epoch 49: 100%|██████████| 46/46 [00:00<00:00, 49.93it/s]


  ↳ train loss: 0.6586
  ↳ val acc 0.549, F1 0.530, MAE° 0.00


Epoch 50: 100%|██████████| 46/46 [00:00<00:00, 49.96it/s]


  ↳ train loss: 0.6618
  ↳ val acc 0.543, F1 0.578, MAE° 0.00


In [19]:
import cv2
import numpy as np
import torch
from collections import deque
import mediapipe as mp

# MediaPipe setup
mp_drawing = mp.solutions.drawing_utils
mp_styles  = mp.solutions.drawing_styles
mp_pose    = mp.solutions.pose.Pose(
    static_image_mode=False,
    model_complexity=2,
    enable_segmentation=False,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5
)

# Load your feedback helper
# from your_feedback_cell import english_feedback

# Rebuild model (use weights_only=True to suppress the pickle warning)
ckpt = torch.load('kp_pose_quality_mp.pt', map_location=DEVICE, weights_only=True)
in_dim, hidden, frames = ckpt["input_dim"], ckpt["hidden"], ckpt["frames"]
infer_model = PoseQualityNetKP(in_dim, hidden).to(DEVICE)
infer_model.load_state_dict(ckpt["state_dict"])
infer_model.eval()
print("✅ Model loaded. Starting live inference…")

cap = cv2.VideoCapture(0)
buffer = deque(maxlen=frames)
frame_count = 0

print("▶ Press ESC to exit.")
while True:
    ret, frame = cap.read()
    frame_count += 1
    if not ret:
        print("⚠️  Failed to grab frame, exiting.")
        break

    # Prepare image for MediaPipe
    h, w = frame.shape[:2]
    rgb = cv2.cvtColor(cv2.flip(frame, 1), cv2.COLOR_BGR2RGB)
    rgb.flags.writeable = False
    results = mp_pose.process(rgb)  # IMAGE_DIMENSIONS warning can be ignored
    rgb.flags.writeable = True

    img = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    if results.pose_landmarks:
        # Draw skeleton
        mp_drawing.draw_landmarks(
            img, results.pose_landmarks, mp.solutions.pose.POSE_CONNECTIONS,
            landmark_drawing_spec=mp_styles.get_default_pose_landmarks_style()
        )
        # Extract and buffer keypoints
        arr = np.array([[lm.x, lm.y, lm.z]
                        for lm in results.pose_landmarks.landmark],
                       dtype=np.float32).reshape(-1)
        buffer.append(arr)
        print(f"[Frame {frame_count}] Buffered {len(buffer)}/{frames}")
    else:
        print(f"[Frame {frame_count}] No landmarks")

    # When buffer is full, run inference
    if len(buffer) == frames:
        print("▶ Buffer full — running inference")
        seq = torch.tensor(np.stack(buffer), dtype=torch.float32).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            logits, err_hat = infer_model(seq)
        # Debug print raw outputs
        print("   raw logits:", logits.cpu().numpy())
        print("   raw err (first 5):", err_hat.squeeze(0).cpu().numpy()[:5])

        pred = logits.argmax(1).item()
        tips = english_feedback(err_hat.squeeze(0).cpu().numpy())
        print(f"   🎯 pred={pred}, tips={tips}")

        # Overlay tips on image
        color = (0,255,0) if pred == 1 else (0,0,255)
        for i, tip in enumerate(tips):
            cv2.putText(img, tip, (10, 30 + 30*i),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA)

        # Clear buffer so next window is fresh
        buffer.clear()

    else:
        # Show buffering progress
        cv2.putText(img,
                    f"Buffering keypoints: {len(buffer)}/{frames}",
                    (10,30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8,
                    (200,200,0), 2, cv2.LINE_AA)

    cv2.imshow("Live Pose Feedback", img)
    if cv2.waitKey(1) & 0xFF == 27:
        print("❎ ESC pressed, exiting.")
        break

cap.release()
cv2.destroyAllWindows()


I0000 00:00:1744977798.436367  716137 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.3), renderer: Apple M4 Max
W0000 00:00:1744977798.487191  734105 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1744977798.509922  734111 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


✅ Model loaded. Starting live inference…
▶ Press ESC to exit.
[Frame 1] No landmarks
[Frame 2] Buffered 1/16
[Frame 3] Buffered 2/16
[Frame 4] Buffered 3/16
[Frame 5] Buffered 4/16
[Frame 6] Buffered 5/16
[Frame 7] Buffered 6/16
[Frame 8] Buffered 7/16
[Frame 9] Buffered 8/16
[Frame 10] Buffered 9/16
[Frame 11] Buffered 10/16
[Frame 12] Buffered 11/16
[Frame 13] Buffered 12/16
[Frame 14] Buffered 13/16
[Frame 15] Buffered 14/16
[Frame 16] Buffered 15/16
[Frame 17] Buffered 16/16
▶ Buffer full — running inference
   raw logits: [[-1.0874856  1.0430459]]
   raw err (first 5): [-0.00043773  0.00259034 -0.00691423  0.01132228  0.00314515]
   🎯 pred=1, tips=['✅ Perfect form! Hold that posture.']
[Frame 18] Buffered 1/16
[Frame 19] Buffered 2/16
[Frame 20] Buffered 3/16
[Frame 21] Buffered 4/16
[Frame 22] Buffered 5/16
[Frame 23] Buffered 6/16
[Frame 24] Buffered 7/16
❎ ESC pressed, exiting.
