In [1]:
import os

from mlflow.recipes.steps.ingest import CustomDataset

os.chdir("../")
%pwd

  from google.protobuf import service as _service


'C:\\Users\\Legion\\OneDrive\\Desktop\\Paris-Saclay\\Learning\\AI\\badminton-pose-coach'

In [2]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class TrainingConfig:
    root_dir: Path
    trained_model_path: Path
    updated_base_model_path: Path
    training_data: Path
    checkpoint_dir: Path
    params_epochs: int
    params_batch_size: int
    params_device: str
    params_lr: float
    params_step_size: int
    params_gamma: float
    params_use_amp: bool


In [3]:
from badmintonPoseCoach.constants import *
from badmintonPoseCoach.utils.common import read_yaml, create_directories

In [4]:
class ConfigurationManager:
    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):

        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_training_config(self) -> TrainingConfig:
        prepare_base_model_config = self.config.prepare_base_model
        training_config = self.config.training
        params = self.params.training

        create_directories([training_config.root_dir])

        training_config = TrainingConfig(
            root_dir=Path(training_config.root_dir),
            trained_model_path = Path(training_config.trained_model_path),
            updated_base_model_path=Path(prepare_base_model_config.updated_base_model_path),
            training_data = Path(training_config.training_data),
            checkpoint_dir = Path(training_config.checkpoint_dir),
            params_epochs = params.epochs,
            params_batch_size = params.batch_size,
            params_device = params.device,
            params_step_size = params.step_size,
            params_lr = params.lr,
            params_gamma = params.gamma,
            params_use_amp = params.use_amp,
        )
        return training_config

In [5]:
import os
from pathlib import Path
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import pandas as pd
import json

In [6]:
class BadmintonPoseDataset(Dataset):
    """
    Dataset class that load json file and output a dataframe of keypoints
    """
    def __init__(self,
                 config: TrainingConfig,
                 seed: int = 42,
                 split: str = 'train',
                 split_ratio: tuple[float, float, float] = (0.8, 0.1, 0.1),
                 frame_format: str = 'auto',
                 num_joints: int = 17,):
        self.training_data = Path(config.training_data)
        self.frame_format = frame_format
        self.num_joints = num_joints

        class_dirs = sorted([d for d in self.training_data.iterdir() if d.is_dir()])
        self.class_names = [d.name for d in class_dirs]

        # list all files in data folder
        self.file_list = []
        for ci, d in enumerate(class_dirs):
            for p in sorted(d.rglob("*.json")):
                self.file_list.append((p, ci))

        # Train/val/test split
        g = torch.Generator().manual_seed(seed)
        per_class_idx = [[] for _ in self.class_names]
        for idx, (_p, ci) in enumerate(self.file_list):
            per_class_idx[ci].append(idx)
        for lst in per_class_idx:
            perm = torch.randperm(len(lst), generator=g).tolist()
            lst = [lst[i] for i in perm]

        def take_splits(idxs: list[int]) -> tuple[list[int], list[int], list[int]]:
            n = len(idxs)
            n_train = int(n * split_ratio[0])
            n_val = int(n * split_ratio[1])
            return idxs[:n_train], idxs[n_train:n_train+n_val], idxs[n_train+n_val:]

        split_map = {"train": 0, "val": 1, "valid": 1, "validation": 1, "test": 2}
        which = split_map[split]

        selected: list[int] = []
        for lst in per_class_idx:
            tr, va, te = take_splits(lst)
            selected.extend([tr, va, te][which])
        selected = sorted(selected)

        self.files: list[Path] = [self.file_list[i][0] for i in selected]
        self.labels: list[int] = [self.file_list[i][1] for i in selected]



    def __len__(self):
        return len(self.files)

    def __getitem__(self, index: int) -> tuple[torch.FloatTensor, int]:
        path = self.files[index]
        label = self.labels[index]
        with open(path, "r", encoding="utf-8") as f:
            obj = json.load(f)

        seq = obj.get("seq")
        if seq is None:
            raise ValueError(f"Missing 'seq' in {path}")

        pose = self._to_tensor_TxKx3(seq)

        return self.normalize_pose(pose, 720, 1280), label

    def _to_tensor_TxKx3(self, seq: any) -> torch.tensor:
        if self.frame_format in ("auto", "Kx3") and isinstance(seq, list) and len(seq) > 0 and isinstance(seq[0], list):
            sample = seq[0]
            if len(sample) > 0 and isinstance(sample[0], list):
                return torch.tensor(seq, dtype=torch.float32)
            else:
                if self.frame_format == "flat" and self.num_joints is not None:
                    K = int(self.num_joints)
                else:
                    flen = len(sample)
                    if flen % 3 != 0:
                        raise ValueError("Cannot infer num_keypoints")
                    K = flen // 3
                frames_Kx3 = []
                for fr in seq:
                    triplets = [fr[i:i+3] for i in range(0, len(fr), 3)]
                    frames_Kx3.append(triplets)
                return torch.tensor(frames_Kx3, dtype=torch.float32)

        if self.frame_format in ("auto", "flat") and isinstance(seq, list) and seq and isinstance(seq[0], (int,float)):
            if self.num_joints is None:
                raise ValueError("Need num_keypoints for flat seq")
            K = int(self.num_joints)
            T = len(seq) // (K*3)
            return torch.tensor(seq, dtype=torch.float32).view(T, K, 3)

        raise ValueError("Unsupported 'seq' structure")
    @staticmethod
    def normalize_pose(pose, W, H, method="skeleton"):
        # pose: (T,K,3)
        if method == "image":
            pose[...,0] /= W
            pose[...,1] /= H
        elif method == "skeleton":
            # pelvis = joint 11,12 in average
            pelvis = pose[:,[11,12],:2].mean(1, keepdims=True)
            pose[...,:2] -= pelvis
            # scale with the shoulder
            shoulder = pose[:,[5,6],:2].mean(1, keepdims=True)
            scale = (pose[:,5,:2]-pose[:,6,:2]).norm(dim=-1, keepdim=True).clamp(min=1e-6)
            pose[...,:2] /= scale[:,None,:]
        return pose

In [7]:
from torch.nn.utils.rnn import pack_padded_sequence
def pack_collate(batch: list[tuple[torch.tensor, int]]) -> dict[str, any]:
    """
    Collate function sử dụng torch.nn.utils.rnn.pack_padded_sequence.
    - pose: (T,K,3) với T có thể khác nhau
    - Trả về PackedSequence để dùng cho RNN.

    Returns:
      packed: PackedSequence chứa (T, K*3)
      lengths: chiều dài thực tế từng sample
      labels: (B,)
    """
    poses, labels = zip(*batch)
    lengths = torch.tensor([p.shape[0] for p in poses], dtype=torch.long)
    K = poses[0].shape[1]

    # Flatten (T,K,3) -> (T, K*3)
    flat_poses = [p.reshape(p.shape[0], K*3) for p in poses]
    padded = torch.nn.utils.rnn.pad_sequence(flat_poses, batch_first=True)

    packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
    labels_t = torch.tensor(labels, dtype=torch.long)
    return {"packed": packed, "lengths": lengths, "labels": labels_t}

In [8]:
import torch
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

class Trainer:
    def __init__(self, config: TrainingConfig,
                 train_loader: torch.utils.data.DataLoader,
                 val_loader: torch.utils.data.DataLoader,
                 test_loader: torch.utils.data.DataLoader,):
        self.config = config
        self.device = config.params_device

        self.model = torch.load(self.config.updated_base_model_path, weights_only=False).to(self.device)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=config.params_lr)
        self.scheduler = StepLR(self.optimizer, step_size=config.params_step_size, gamma=config.params_gamma)

        self.criterion = nn.CrossEntropyLoss()

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

         # AMP
        self.scaler = torch.amp.GradScaler("cuda", enabled=config.params_use_amp and self.device == "cuda")
        # Checkpoint dir
        Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)


    def save_model(self):
        torch.save(self.model, self.config.trained_model_path)

    def _step_batch(self, batch, train: bool = True) -> tuple[float, float]:
        packed = batch["packed"]
        labels = batch["labels"].to(self.device)
        packed = packed.to(self.device)

        with torch.amp.autocast("cuda", enabled=self.scaler.is_enabled()):

            logits = self.model(packed)

            loss = self.criterion(logits, labels)
        if train:
            self.optimizer.zero_grad(set_to_none=True)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        with torch.no_grad():
            preds = logits.argmax(dim=-1)
            acc = (preds == labels).float().mean().item()
        return loss.item(), acc

    def train_one_epoch(self, epoch: int) -> tuple[float, float]:
        self.model.train()
        total_loss, total_acc, n = 0.0, 0.0, 0
        for batch in tqdm(self.train_loader):
            loss, acc = self._step_batch(batch, train=True)
            total_loss += loss
            total_acc += acc
            n += 1
        avg_loss = total_loss / max(n, 1)
        avg_acc = total_acc / max(n, 1)
        return avg_loss, avg_acc

    @torch.no_grad()
    def evaluate(self, split: str = "val") -> tuple[float, float]:
        self.model.eval()
        loader = {"val": self.val_loader, "test": self.test_loader}[split]
        total_loss, total_acc, n = 0.0, 0.0, 0
        for batch in tqdm(loader):
            loss, acc = self._step_batch(batch, train=False)
            total_loss += loss
            total_acc += acc
            n += 1
        return total_loss / max(n, 1), total_acc / max(n, 1)

    def fit(self):
        best_val_acc = 0.0
        for epoch in range(1, self.config.params_epochs + 1):
            train_loss, train_acc = self.train_one_epoch(epoch)
            val_loss, val_acc = self.evaluate("val")
            self.scheduler.step()

            print(f"Epoch {epoch:03d}: train loss {train_loss:.4f}, acc {train_acc:.4f} | "
                  f"val loss {val_loss:.4f}, acc {val_acc:.4f}")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                ckpt_path = Path(self.config.checkpoint_dir) / f"best.pkl"
                torch.save({
                    "model_state": self.model.state_dict(),
                    "cfg": self.config.__dict__,
                    "val_acc": val_acc,
                }, ckpt_path)
                print(f"Saved checkpoint to {ckpt_path}")

        print(f"Best val acc: {best_val_acc:.4f}")

    @torch.no_grad()
    def test(self):
        ckpt_path = Path(self.config.checkpoint_dir) / "best.pkl"
        if ckpt_path.exists():
            state = torch.load(ckpt_path, map_location=self.device, weights_only=False)
            self.model.load_state_dict(state["model_state"])
            print(f"Loaded checkpoint from {ckpt_path}")
        test_loss, test_acc = self.evaluate("test")
        print(f"Test: loss {test_loss:.4f}, acc {test_acc:.4f}")


In [9]:
config = ConfigurationManager()
train_config = config.get_training_config()
train_data = BadmintonPoseDataset(train_config, split='train')
val_data = BadmintonPoseDataset(train_config, split='val')
test_data = BadmintonPoseDataset(train_config, split='test')

train_loader = DataLoader(train_data, batch_size=train_config.params_batch_size, collate_fn=pack_collate, shuffle=True)
val_loader = DataLoader(val_data, batch_size=train_config.params_batch_size, collate_fn=pack_collate, shuffle=True)
test_loader = DataLoader(test_data, batch_size=train_config.params_batch_size, collate_fn=pack_collate)

trainer = Trainer(train_config, train_loader, val_loader, test_loader)
trainer.fit()
trainer.test()


[2025-09-27 21:54:28,703: INFO: common: yaml file: config\config.yaml loaded successfully]
[2025-09-27 21:54:28,706: INFO: common: yaml file: params.yaml loaded successfully]
[2025-09-27 21:54:28,707: INFO: common: created directory at: artifacts]
[2025-09-27 21:54:28,709: INFO: common: created directory at: artifacts/training]


100%|██████████| 389/389 [00:09<00:00, 38.99it/s]


Epoch 001: train loss 1.8337, acc 0.4521 | val loss 1.8453, acc 0.4732
Saved checkpoint to artifacts\training\checkpoints\best.pt


100%|██████████| 389/389 [00:12<00:00, 31.49it/s]


Epoch 002: train loss 1.3952, acc 0.5833 | val loss 1.7189, acc 0.5077
Saved checkpoint to artifacts\training\checkpoints\best.pt


100%|██████████| 389/389 [00:11<00:00, 33.03it/s]


Epoch 003: train loss 1.2543, acc 0.6210 | val loss 1.6444, acc 0.5204
Saved checkpoint to artifacts\training\checkpoints\best.pt


100%|██████████| 389/389 [00:09<00:00, 41.37it/s]


Epoch 004: train loss 1.1735, acc 0.6380 | val loss 1.5576, acc 0.5255
Saved checkpoint to artifacts\training\checkpoints\best.pt


100%|██████████| 389/389 [00:09<00:00, 41.03it/s]


Epoch 005: train loss 1.1077, acc 0.6603 | val loss 1.5101, acc 0.5727
Saved checkpoint to artifacts\training\checkpoints\best.pt


100%|██████████| 389/389 [00:09<00:00, 41.22it/s]


Epoch 006: train loss 1.0346, acc 0.6746 | val loss 1.5714, acc 0.5599


100%|██████████| 389/389 [00:09<00:00, 42.08it/s]


Epoch 007: train loss 0.9779, acc 0.6910 | val loss 1.5739, acc 0.5663


100%|██████████| 389/389 [00:09<00:00, 41.87it/s]


Epoch 008: train loss 0.9141, acc 0.7109 | val loss 1.4350, acc 0.5995
Saved checkpoint to artifacts\training\checkpoints\best.pt


100%|██████████| 389/389 [00:09<00:00, 41.00it/s]


Epoch 009: train loss 0.8802, acc 0.7167 | val loss 1.4250, acc 0.5778


100%|██████████| 389/389 [00:09<00:00, 41.49it/s]


Epoch 010: train loss 0.8062, acc 0.7416 | val loss 1.7362, acc 0.5383
Best val acc: 0.5995
Loaded checkpoint from artifacts\training\checkpoints\best.pt
Test: loss 1.3006, acc 0.6221


In [10]:
import numpy as np

# 2 frame, 4 keypoints, 3 giá trị (x,y,score)
pose = np.array([
    [[10, 20, 0.9], [30, 40, 0.8], [50, 60, 0.7], [70, 80, 0.6]],
    [[15, 25, 0.95], [35, 45, 0.85], [55, 65, 0.75], [75, 85, 0.65]]
])  # shape (2,4,3)

print(pose[...,0])  # lấy toàn bộ x -> shape (2,4)
# [[10 30 50 70]
#  [15 35 55 75]]

print(pose[...,1])  # lấy toàn bộ y -> shape (2,4)
# [[20 40 60 80]
#  [25 45 65 85]]

[[         10          30          50          70]
 [         15          35          55          75]]
[[         20          40          60          80]
 [         25          45          65          85]]


In [11]:
def normalize_pose(pose, W, H, method="skeleton"):
    # pose: (T,K,3)
    if method == "image":
        pose[...,0] /= W
        pose[...,1] /= H
    elif method == "skeleton":
        # pelvis = joint 11,12 in average
        pelvis = pose[:,[11,12],:2].mean(1, keepdims=True)
        pose[...,:2] -= pelvis
        # scale with the shoulder
        shoulder = pose[:,[5,6],:2].mean(1, keepdims=True)
        scale = (pose[:,5,:2]-pose[:,6,:2]).norm(dim=-1, keepdim=True).clamp(min=1e-6)
        pose[...,:2] /= scale[:,None,:]
    return pose

#train_data[0][0]
normalize_pose(train_data[0][0], 720,1280, "skeleton")

tensor([[[-3.9153e-01, -1.6061e+00,  1.2644e-01],
         [-3.5748e-01, -1.6992e+00,  1.4010e-01],
         [-3.1302e-01, -1.7031e+00,  3.0042e-02],
         [-2.8901e-01, -1.6639e+00,  8.8703e-01],
         [ 1.0857e-01, -1.6435e+00,  3.2072e-01],
         [-5.6635e-01, -1.2585e+00,  9.9525e-01],
         [ 4.3139e-01, -1.1913e+00,  9.8801e-01],
         [-1.0255e+00, -1.0691e+00,  9.7379e-01],
         [ 7.3381e-01, -9.3401e-01,  8.5509e-01],
         [-1.1087e+00, -1.1081e+00,  8.9493e-01],
         [ 7.4719e-01, -1.0355e+00,  6.6659e-01],
         [-3.2455e-01, -2.1698e-02,  9.9826e-01],
         [ 3.2455e-01,  2.1698e-02,  9.9710e-01],
         [-5.1888e-01,  7.7028e-01,  9.9477e-01],
         [ 2.7007e-01,  8.2048e-01,  9.8932e-01],
         [-4.7282e-01,  1.6781e+00,  9.8510e-01],
         [ 2.9500e-01,  1.7270e+00,  9.7517e-01]],

        [[-3.4170e-01, -1.6471e+00,  1.1605e-01],
         [-3.0739e-01, -1.7378e+00,  1.1529e-01],
         [-2.6835e-01, -1.7437e+00,  3.0186e-02]