In [1]:
import os

In [2]:
# Presumably this notebook lives one directory below the project root; step up
# so relative paths such as config/config.yaml resolve from the root.
# Not idempotent: every re-run of this cell moves up one more level.
os.chdir(os.pardir)

In [3]:
from dataclasses import dataclass
from pathlib import Path

In [4]:
from Condition2Cure.utils.helpers import *
from Condition2Cure.constants import *
from Condition2Cure.utils.execptions import *

In [5]:
@dataclass(frozen=True)
class ModelRegistryConfig:
    """Immutable settings for the model-registry promotion step.

    Instances are read-only (``frozen=True``); build a new one to change values.
    """

    # Name of the registered model whose versions are compared/promoted.
    model_name: str
    # Location of the JSON file holding the candidate model's metrics.
    metric_path: Path
    # Key inside that JSON file (and the MLflow run metrics) used for comparison.
    metric_key: str

In [None]:
class ConfigurationManager:
    """Load the project's YAML files and build typed per-stage config objects."""

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH,
        schema_filepath=SCHEMA_FILE_PATH,
    ):
        """Read the config, params and schema YAML files.

        Args:
            config_filepath: Path to the main pipeline config YAML.
            params_filepath: Path to the hyper-parameter YAML.
            schema_filepath: Path to the data-schema YAML.
        """
        # read_yaml returns an object supporting attribute access
        # (e.g. `self.config.artifacts_root` below).
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        self.schema = read_yaml(schema_filepath)

        # Ensure the artifacts root exists before any stage writes into it.
        create_directories([self.config.artifacts_root])

    def get_model_registry_config(self) -> ModelRegistryConfig:
        """Build the config for the model-registry stage from the `model_registry` section.

        Returns:
            ModelRegistryConfig with ``metric_path`` coerced to ``pathlib.Path``.
        """
        config = self.config.model_registry
        return ModelRegistryConfig(
            model_name=config.model_name,
            # The YAML value is a plain string; coerce it so the field really
            # holds the `Path` its annotation promises.
            metric_path=Path(config.metric_path),
            metric_key=config.metric_key,
        )

In [None]:
import json
from pathlib import Path
import mlflow
from mlflow.tracking import MlflowClient
from Condition2Cure import logger

class ModelRegistry:
    """Promote the latest Staging model version to Production when it scores better.

    The candidate (Staging) score is read from the local JSON metrics file at
    ``config.metric_path``; the incumbent (Production) score is read from the
    MLflow run that produced the current Production version.
    """

    def __init__(self, config: "ModelRegistryConfig", *, greater_is_better: bool = True):
        """Create a registry helper bound to one registered model.

        Args:
            config: Registered-model name, metrics-file path and metric key.
            greater_is_better: Direction of ``config.metric_key``. ``True`` (the
                default, matching the previous behavior) promotes on a higher
                score; pass ``False`` for loss/error-style metrics.
        """
        self.config = config
        self.greater_is_better = greater_is_better
        # Uses MLflow's default tracking/registry URI (e.g. from the environment).
        self.client = MlflowClient()

    def load_metric(self) -> float:
        """Return the candidate's ``metric_key`` value from the metrics JSON file.

        Raises:
            FileNotFoundError: If ``config.metric_path`` does not exist.
            KeyError: If the JSON file has no ``config.metric_key`` entry.
        """
        with open(self.config.metric_path, "r", encoding="utf-8") as f:
            metrics = json.load(f)
        try:
            value = metrics[self.config.metric_key]
        except KeyError:
            # Fail with a descriptive message instead of an opaque float(None) TypeError.
            raise KeyError(
                f"Metric '{self.config.metric_key}' not found in {self.config.metric_path}"
            ) from None
        return float(value)

    def get_latest_model_by_stage(self, stage: str):
        """Return the latest version of the model in ``stage``, or None if there is none."""
        # NOTE: model-version stages and `get_latest_versions` are deprecated in
        # recent MLflow releases (a deprecation warning is emitted here); consider
        # migrating to registered-model aliases.
        versions = self.client.get_latest_versions(name=self.config.model_name, stages=[stage])
        return versions[0] if versions else None

    def promote_model(self, version):
        """Move ``version`` to Production, archiving any version already there."""
        self.client.transition_model_version_stage(
            name=self.config.model_name,
            version=version,
            stage="Production",
            archive_existing_versions=True,
        )
        logger.info(f"Promoted version {version} to Production.")

    def _production_score(self):
        """Return the current Production version's metric as float, or None if unavailable."""
        prod_model = self.get_latest_model_by_stage("Production")
        if prod_model is None:
            return None
        prod_metrics = self.client.get_run(prod_model.run_id).data.metrics
        if self.config.metric_key not in prod_metrics:
            # Treat a missing metric as "no incumbent score" rather than 0, which
            # would be unbeatable for lower-is-better metrics.
            logger.warning(
                f"Production run {prod_model.run_id} has no '{self.config.metric_key}' metric."
            )
            return None
        return float(prod_metrics[self.config.metric_key])

    @staticmethod
    def _should_promote(new_score, prod_score, greater_is_better=True) -> bool:
        """Return True if ``new_score`` strictly improves on ``prod_score`` (None = no incumbent)."""
        if prod_score is None:
            return True
        if greater_is_better:
            return new_score > prod_score
        return new_score < prod_score

    def registry(self):
        """Promote the latest Staging version if it beats the current Production version.

        Only logs (no promotion) when there is no Staging version or when the
        candidate does not strictly improve on Production.
        """
        logger.info("Running model registry promotion check...")
        staging_model = self.get_latest_model_by_stage("Staging")
        if not staging_model:
            logger.warning("No staging model found.")
            return

        # NOTE(review): the candidate score comes from the local metrics file, not
        # from `staging_model`'s own MLflow run — confirm both describe the same model.
        new_score = self.load_metric()
        prod_score = self._production_score()

        logger.info(f"Staging {self.config.metric_key}: {new_score}")
        logger.info(f"Production {self.config.metric_key}: {prod_score}")

        if self._should_promote(new_score, prod_score, self.greater_is_better):
            self.promote_model(staging_model.version)
        else:
            logger.info("No promotion. Staging model is not better than Production.")

In [None]:
# `sys` was previously only available via star imports; import it explicitly so
# the except-handler below cannot itself fail with NameError and mask the error.
import sys

# Build the registry-stage config and run the Staging -> Production promotion check.
try:
    config = ConfigurationManager()
    model_registry_config = config.get_model_registry_config()
    registry = ModelRegistry(config=model_registry_config)
    registry.registry()
except Exception as e:
    # Re-raise as the project's exception type; `sys` is passed presumably so
    # CustomException can extract traceback details (defined elsewhere — confirm).
    raise CustomException(e, sys) from e

[2025-06-22 15:51:09,670: INFO: helpers: yaml file: config\config.yaml loaded successfully]
[2025-06-22 15:51:09,676: INFO: helpers: yaml file: config\params.yaml loaded successfully]
[2025-06-22 15:51:09,676: INFO: helpers: yaml file: config\schema.yaml loaded successfully]
[2025-06-22 15:51:09,676: INFO: helpers: created directory at: artifacts]
[2025-06-22 15:51:09,712: INFO: 1259520673: 🚀 Running model registry promotion check...]


  versions = self.client.get_latest_versions(name=self.config.model_name, stages=[stage])


