In [None]:
# 0. Imports
from __future__ import annotations
import os, random, warnings, json
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm

In [10]:

# 1. Paths & device
SCRIPT_DIR = Path(os.getcwd()).resolve()
ROOT       = SCRIPT_DIR / "Data-REHAB24-6"
META_FILE  = ROOT / "Segmentation.xlsx"        # original Excel
JOINT_FILE = ROOT / "joints_names.txt"

DEVICE = (
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("cpu")
)
print("► Using device:", DEVICE)

# reproducibility
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

# joint names
JOINT_NAMES = [
    ln.split(":", 1)[1].strip()
    for ln in JOINT_FILE.read_text().splitlines()
]
N_JOINTS = len(JOINT_NAMES)

► Using device: mps


In [11]:
# 2. Natural‑language feedback
def english_feedback(err_vec: torch.Tensor | np.ndarray, tol: float = 5.0):
    tips: list[str] = []
    for j, deg in enumerate(err_vec):
        deg = float(deg)
        if abs(deg) <= tol:
            continue
        act = "straighten" if deg > 0 else "bend"
        tips.append(f"{act} your {JOINT_NAMES[j]} by ≈{abs(deg):.0f}°")
    return tips or ["Great form ✅"]

In [12]:
# 3. Dataset
class RehabDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        root: str | Path,
        meta: str | Path,
        frames: int = 16,
        camera: str = "Camera17",
        transform: transforms.Compose | None = None,
        split: str = "train"
    ):
        self.root   = Path(root)
        self.frames = frames
        self.camera = camera

        # load metadata
        ext = Path(meta).suffix.lower()
        if ext in [".xlsx", ".xls"]:
            df = pd.read_excel(meta, engine="openpyxl")
        else:
            df = pd.read_csv(meta, skipinitialspace=True, encoding="utf-8-sig")
        df.columns = df.columns.str.strip()

        # ensure err_0…err_25 exist (zero‑fill if missing)
        err_cols = [f"err_{i}" for i in range(N_JOINTS)]
        if not all(c in df.columns for c in err_cols):
            warnings.warn("err_0…err_25 not found in metadata → auto‑filling zeros")
            for c in err_cols:
                df[c] = 0.0
        self.err_cols = err_cols

        # subject‑wise split
        vids = sorted(df["video_id"].unique())
        random.shuffle(vids)
        a, b = int(0.7 * len(vids)), int(0.85 * len(vids))
        keep = vids[:a] if split=="train" else vids[a:b] if split=="val" else vids[b:]
        self.meta = df[df.video_id.isin(keep)].reset_index(drop=True)

        # transforms
        self.transform = transform or transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize([.485,.456,.406],[.229,.224,.225])
        ])

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx: int):
        row = self.meta.iloc[idx]
        mp4 = (
            self.root / f"videos/Ex{row.exercise_id}"
            / f"{row.video_id}-Camera17-30fps.mp4"
        )
        cap = cv2.VideoCapture(str(mp4))
        if not cap.isOpened():
            warnings.warn(f"Cannot open {mp4}; skipping to next")
            return self.__getitem__((idx+1) % len(self))

        tot = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frames = torch.linspace(row.first_frame, min(row.last_frame, tot-1), self.frames).long()
        imgs = []
        for f in frames:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(f))
            ok, img = cap.read()
            if not ok:
                warnings.warn(f"Bad frame {f} in {mp4}; skipping")
                return self.__getitem__((idx+1) % len(self))
            imgs.append(self.transform(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
        cap.release()

        clip  = torch.stack(imgs)  # (T,C,H,W)
        label = torch.tensor(row.correctness, dtype=torch.long)

        # ❗️ Force float32 array to avoid object dtype
        err_vals = row[self.err_cols].to_numpy(dtype=np.float32)
        err      = torch.tensor(err_vals, dtype=torch.float32)

        return clip, label, err

In [13]:

# 4. Model
class CNNEncoder(nn.Module):
    def __init__(self, embed: int = 512):
        super().__init__()
        base = resnet18(weights="IMAGENET1K_V1")
        self.backbone = nn.Sequential(*list(base.children())[:-1])
        self.proj     = nn.Linear(512, embed)

    def forward(self, x):
        B, T = x.shape[:2]
        feats = self.backbone(x.view(B*T, *x.shape[2:])).flatten(1)
        return self.proj(feats).view(B, T, -1)

class PoseQualityNet(nn.Module):
    def __init__(self, embed: int = 512, hidden: int = 256):
        super().__init__()
        self.cnn   = CNNEncoder(embed)
        self.lstm  = nn.LSTM(embed, hidden, 2, batch_first=True, bidirectional=True)
        dim = hidden * 2
        self.cls_head = nn.Linear(dim, 2)
        self.err_head = nn.Linear(dim, N_JOINTS)

    def forward(self, clip):
        feats, _ = self.lstm(self.cnn(clip))
        g = feats.mean(1)
        return self.cls_head(g), self.err_head(g)

In [14]:
# 5. Training routine
def train_epochs(
    epochs: int = 30,
    batch: int = 4,
    lr: float = 1e-4,
    ckpt_file: str | Path = "pose_quality_best.pt"
):
    train_ds = RehabDataset(ROOT, META_FILE, split="train")
    val_ds   = RehabDataset(ROOT, META_FILE, split="val")
    train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True)
    val_dl   = DataLoader(val_ds,   batch_size=batch, shuffle=False)

    model    = PoseQualityNet().to(DEVICE)
    loss_cls = nn.CrossEntropyLoss()
    loss_err = nn.SmoothL1Loss()
    opt      = Adam(model.parameters(), lr)
    best_f1  = 0.0

    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        for clip, y, err in tqdm(train_dl, desc=f"Epoch {epoch:02d}"):
            clip, y, err = clip.to(DEVICE), y.to(DEVICE), err.to(DEVICE)
            opt.zero_grad()
            logits, err_hat = model(clip)
            loss = loss_cls(logits, y) + 0.1 * loss_err(err_hat, err)
            loss.backward()
            opt.step()
            total_loss += loss.item() * y.size(0)
        print(f"  ↳ train loss: {total_loss/len(train_ds):.4f}")

        # validation
        model.eval()
        y_true, y_pred, errs = [], [], []
        with torch.no_grad():
            for clip, y, err in val_dl:
                logits, err_hat = model(clip.to(DEVICE))
                y_true += y.tolist()
                y_pred += logits.argmax(1).cpu().tolist()
                errs    += [torch.abs(err_hat.cpu() - err).mean(1)]
        acc = accuracy_score(y_true, y_pred)
        f1  = f1_score(y_true, y_pred)
        mae = torch.cat(errs).mean().item()
        print(f"  ↳ val acc {acc:.3f}, F1 {f1:.3f}, MAE° {mae:.2f}")

        if f1 > best_f1:
            best_f1 = f1
            torch.save({"state": model.state_dict()}, ckpt_file)
            print("  ✓ saved new best model")



In [15]:
# 6. run training
train_epochs(epochs=2, batch=16, lr=1e-4, ckpt_file="pose_quality_best.pt")

Epoch 01: 100%|██████████| 46/46 [14:03<00:00, 18.35s/it]


  ↳ train loss: 0.6563
  ↳ val acc 0.651, F1 0.538, MAE° 0.02
  ✓ saved new best model


Epoch 02: 100%|██████████| 46/46 [14:46<00:00, 19.27s/it]


  ↳ train loss: 0.4827
  ↳ val acc 0.709, F1 0.769, MAE° 0.01
  ✓ saved new best model


In [16]:
# 7. Inference & Feedback 

from pathlib import Path
import torch
import json
import cv2
from torchvision import transforms

# make sure PoseQualityNet & english_feedback are still in scope from Cell 1

DEVICE = (
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("cpu")
)

@torch.inference_mode()
def predict_clip(
    clip_path: str | Path,
    model_pt: str | Path = "pose_quality_best.pt",
    frames: int = 16
):
    ck = torch.load(model_pt, map_location=DEVICE)
    net = PoseQualityNet().to(DEVICE).eval()
    net.load_state_dict(ck["state"])

    cap = cv2.VideoCapture(str(clip_path))
    if not cap.isOpened():
        raise IOError(f"Cannot open {clip_path}")
    tot = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    idxs = torch.linspace(0, max(tot-1,0), frames).long()

    tfm = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([.485,.456,.406],[.229,.224,.225])
    ])
    frames_tensor = []
    for f in idxs:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(f))
        ok, img = cap.read()
        if not ok:
            raise RuntimeError(f"Bad frame {f}")
        frames_tensor.append(tfm(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
    cap.release()

    clip = torch.stack(frames_tensor).unsqueeze(0).to(DEVICE)
    logits, err_hat = net(clip)
    prob = torch.softmax(logits,1)[0,1].item()
    tips = english_feedback(err_hat.squeeze().cpu())
    return {"correctness_prob": prob, "feedback": tips}

# run inference
clip_path = Path("Data-REHAB24-6/videos/Ex1/sample.mp4")
out = predict_clip(clip_path, model_pt="pose_quality_best.pt", frames=16)

print(json.dumps(out, indent=2))
print(f"\n✅ Correctness: {out['correctness_prob']:.2%}")
print("📝 Feedback:")
for t in out["feedback"]:
    print("  -", t)


  ck = torch.load(model_pt, map_location=DEVICE)
OpenCV: Couldn't read video stream from file "Data-REHAB24-6/videos/Ex1/sample.mp4"


OSError: Cannot open Data-REHAB24-6/videos/Ex1/sample.mp4