In [1]:
import os
# Step one directory up so the relative paths used by later cells are resolved
# from the parent directory. NOTE: this is a *relative* move, so it is not
# idempotent - re-running this cell without restarting the kernel climbs one
# more level each time.
os.chdir("../")
%pwd

'C:\\Users\\Legion\\OneDrive\\Desktop\\Paris-Saclay\\Learning\\AI\\badminton-pose-coach'

In [21]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)  # frozen: fields cannot be reassigned once an instance is created
class PrepareBaseModelConfig:
    """Immutable settings for the "prepare base model" stage.

    Attributes:
        root_dir: Directory holding this stage's artifacts.
        updated_base_model_path: File the base model's state dict is written to.
        params_model_name: Name of the architecture to build; this notebook
            compares it against the string ``"gru"``.
    """

    root_dir: Path
    updated_base_model_path: Path
    # A model *name* is text (compared to "gru" downstream), not a number.
    params_model_name: str

@dataclass(frozen=True)
class ModelConfig:
    """Immutable hyper-parameters used to build the sequence model.

    Attributes:
        params_model_name: Name of the architecture (text, e.g. ``"gru"``).
        params_num_classes: Number of output classes of the final linear layer.
        params_hidden: Hidden size of the recurrent layer.
        params_layers: Number of stacked recurrent layers.
        params_dropout: Dropout probability handed to the recurrent layer.
        params_num_joints: Joints per frame (presumably the ``J`` axis of the
            model input - confirm against the data pipeline).
        params_channel: Values per joint (presumably the ``C`` axis of the
            model input - confirm against the data pipeline).
        params_bidirectional: Whether the recurrent layer runs in both directions.
    """

    # A model name is a string, not a float.
    params_model_name: str
    params_num_classes: int
    params_hidden: int
    params_layers: int
    # Dropout is a probability in [0, 1], so it is a float rather than an int.
    params_dropout: float
    params_num_joints: int
    params_channel: int
    params_bidirectional: bool

In [22]:
from badmintonPoseCoach.constants import *
from badmintonPoseCoach.utils.common import read_yaml, create_directories

In [66]:
class ConfigurationManager:
    """Loads the YAML config and params files and builds typed stage configs."""

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        """Read both YAML files and create the artifacts root directory.

        Args:
            config_filepath: Path of the pipeline config YAML.
            params_filepath: Path of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Build the config of the "prepare base model" stage.

        Creates the stage's root directory as a side effect.

        Returns:
            PrepareBaseModelConfig filled from the ``prepare_base_model``
            sections of the config and params files.
        """
        config = self.config.prepare_base_model
        params = self.params.prepare_base_model

        create_directories([config.root_dir])

        return PrepareBaseModelConfig(
            root_dir=Path(config.root_dir),
            updated_base_model_path=Path(config.updated_base_model_path),
            params_model_name=params.model_name,
        )

    def get_model_config(self) -> ModelConfig:
        """Build the hyper-parameter config of the selected model.

        Returns:
            ModelConfig filled from the sub-section of ``prepare_base_model``
            that matches the selected model name.

        Raises:
            ValueError: If ``prepare_base_model.model_name`` is not a
                supported architecture (only ``"gru"`` is handled here).
        """
        stage_params = self.params.prepare_base_model
        model_name = stage_params.model_name

        # Fail with a clear message instead of dereferencing None further down.
        if model_name != "gru":
            raise ValueError(
                f"Unsupported model_name {model_name!r} in params file; expected 'gru'."
            )

        params = stage_params.gru
        return ModelConfig(
            params_hidden=params.hidden,
            # NOTE(review): the name is read from the model-specific sub-section,
            # not from prepare_base_model itself - confirm params.yaml defines it there.
            params_model_name=params.model_name,
            params_num_classes=params.num_classes,
            params_layers=params.layers,
            params_dropout=params.dropout,
            params_num_joints=params.num_joints,
            params_channel=params.channel,
            params_bidirectional=params.bidirectional,
        )


In [67]:
import os
from pathlib import Path
import torch.nn as nn


In [71]:
class GRUModel(nn.Module):
    """GRU sequence classifier: flattened per-frame features -> class logits.

    The input is a 4-D tensor ``[B, T, C, J]``; the last two axes are flattened
    to one feature vector per time step, run through a (optionally
    bidirectional, optionally stacked) GRU, and the output at the last time
    step is projected to ``params_num_classes`` logits.
    """

    def __init__(self, config: "ModelConfig"):
        """Build the recurrent layer and the classification head.

        Args:
            config: Object exposing the ``params_*`` hyper-parameters
                (hidden, layers, dropout, num_joints, channel, bidirectional,
                num_classes).
        """
        super().__init__()
        self.config = config

        # forward() reshapes [B, T, C, J] to [B, T, C*J], so every time step
        # carries channel * num_joints features - that is the GRU input size.
        input_size = config.params_channel * config.params_num_joints
        num_directions = 2 if config.params_bidirectional else 1
        # nn.GRU only applies dropout *between* stacked layers; with a single
        # layer a non-zero value has no effect and just triggers a warning.
        dropout = config.params_dropout if config.params_layers > 1 else 0.0

        # No trailing comma here: the attribute must be the module itself (not a
        # tuple) so it is registered as a submodule, shows up in state_dict()
        # and is callable in forward().
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=config.params_hidden,
            num_layers=config.params_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=config.params_bidirectional,
        )
        self.fc = nn.Linear(
            config.params_hidden * num_directions,
            config.params_num_classes,
        )

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: Tensor of shape ``[B, T, C, J]``.

        Returns:
            Tensor of shape ``[B, params_num_classes]`` with the class logits.
        """
        # x: [B,T,C,J] -> [B,T,C*J]
        B, T, C, J = x.shape
        x = x.reshape(B, T, C * J)
        out, _ = self.gru(x)
        # Use the GRU output at the final time step as the sequence summary.
        last = out[:, -1, :]
        return self.fc(last)

In [72]:
import torch
class PrepareBaseModel:
    """Builds the untrained base model and saves its state dict to disk."""

    def __init__(self, prepare_base_model_config: PrepareBaseModelConfig, model_config: ModelConfig):
        """Store the stage and model configs; no model is built yet.

        Args:
            prepare_base_model_config: Stage settings (model name, output path).
            model_config: Hyper-parameters passed to the model constructor.
        """
        self.model = None
        self.prepare_base_model_config = prepare_base_model_config
        self.model_config = model_config

    def get_base_model(self):
        """Instantiate the configured model, save it and return it.

        Returns:
            The freshly built model (also kept on ``self.model``).

        Raises:
            ValueError: If the configured model name is not ``"gru"``, the only
                architecture handled here. Previously this case returned
                ``None`` without saving anything.
        """
        model_name = self.prepare_base_model_config.params_model_name
        if model_name != "gru":
            raise ValueError(
                f"Unsupported model name {model_name!r}; expected 'gru'."
            )

        self.model = GRUModel(config=self.model_config)
        self.save_model(self.prepare_base_model_config.updated_base_model_path, self.model)
        return self.model

    @staticmethod
    def save_model(path, model):
        """Write ``model.state_dict()`` to ``path`` with ``torch.save``."""
        torch.save(model.state_dict(), path)

In [74]:
# Run the stage end to end: load the configs, build the model, save its weights.
# The former `try: ... except Exception as e: raise e` wrapper only re-raised the
# same exception, so it is dropped; errors still propagate unchanged.
config = ConfigurationManager()
prepare_base_model_config = config.get_prepare_base_model_config()
model_config = config.get_model_config()
prepare_base_model = PrepareBaseModel(prepare_base_model_config, model_config)
# Trailing ';' keeps the returned model's repr out of the cell output.
prepare_base_model.get_base_model();

[2025-09-18 18:46:26,794: INFO: common: yaml file: config\config.yaml loaded successfully]
[2025-09-18 18:46:26,799: INFO: common: yaml file: params.yaml loaded successfully]
[2025-09-18 18:46:26,801: INFO: common: created directory at: artifacts]
[2025-09-18 18:46:26,802: INFO: common: created directory at: artifacts/prepare_base_model]
