In [1]:
import os
import pickle
import click
import mlflow

from mlflow.entities import ViewType
from mlflow.tracking import MlflowClient
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

In [2]:
HPO_EXPERIMENT_NAME = "random-forest-hyperopt"
EXPERIMENT_NAME = "random-forest-best-models"
# Hyperparameters copied from each HPO run when retraining; all are cast to int.
RF_PARAMS = ['max_depth', 'n_estimators', 'min_samples_split', 'min_samples_leaf', 'random_state']

# Tracking server: honour MLFLOW_TRACKING_URI when it is set, otherwise fall back
# to the local server this notebook was developed against.
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:5002")
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
# Creates the experiment if needed and makes it the target of subsequent start_run() calls.
mlflow.set_experiment(EXPERIMENT_NAME)
# Autologging logs each fitted estimator as a run artifact; the registration step
# below expects it at MLflow's default autolog path, runs:/<run_id>/model.
mlflow.sklearn.autolog()

In [3]:
def load_pickle(filename):
    """Deserialize and return the object stored in the pickle file ``filename``.

    Note: unpickling can execute arbitrary code, so only pass files produced by
    a trusted step of this pipeline.
    """
    with open(filename, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj

In [4]:
def train_and_log_model(data_path, params):
    """Retrain a RandomForestRegressor with ``params`` and log its val/test RMSE to MLflow.

    Each call opens a new MLflow run named "RandomForest" in the active experiment.

    Args:
        data_path: Directory containing ``train.pkl``, ``val.pkl`` and ``test.pkl``,
            each holding an ``(X, y)`` pair.
        params: Mapping containing every key in ``RF_PARAMS``. Values may be strings
            (MLflow stores run params as strings) and are cast to ``int``.
    """
    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_val, y_val = load_pickle(os.path.join(data_path, "val.pkl"))
    X_test, y_test = load_pickle(os.path.join(data_path, "test.pkl"))

    with mlflow.start_run(run_name="RandomForest"):
        # RandomForestRegressor needs ints, but the HPO params arrive as strings.
        rf_params = {name: int(params[name]) for name in RF_PARAMS}

        rf = RandomForestRegressor(**rf_params)
        rf.fit(X_train, y_train)

        # Evaluate the model on the validation and test sets. RMSE is sqrt(MSE);
        # it is computed explicitly because `squared=False` was deprecated in
        # scikit-learn 1.4 and later removed.
        val_rmse = mean_squared_error(y_val, rf.predict(X_val)) ** 0.5
        mlflow.log_metric("val_rmse", val_rmse)
        test_rmse = mean_squared_error(y_test, rf.predict(X_test)) ** 0.5
        mlflow.log_metric("test_rmse", test_rmse)

In [8]:
def run_register_model(data_path: str = "./output", top_n: int = 10,
                       model_name: str = "random-forest-best-model"):
    """Retrain the top HPO configurations and register the best resulting model.

    1. Reads the ``top_n`` active runs of ``HPO_EXPERIMENT_NAME``, ordered by
       ``metrics.rmse`` ascending.
    2. Retrains each one via ``train_and_log_model``, which logs to the active
       experiment (``EXPERIMENT_NAME``).
    3. Registers the run in ``EXPERIMENT_NAME`` with the lowest ``test_rmse``
       as a new version of ``model_name``.

    Args:
        data_path: Directory containing ``train.pkl``, ``val.pkl`` and ``test.pkl``.
        top_n: Number of HPO runs to retrain.
        model_name: Registered-model name to create or add a version to.

    Raises:
        ValueError: If either experiment does not exist, or no run with a
            ``test_rmse`` metric is found.
    """
    client = MlflowClient()

    # Retrieve the top_n HPO runs and retrain/log a model for each.
    # NOTE(review): assumes the HPO runs log a metric named "rmse" — confirm.
    hpo_experiment = client.get_experiment_by_name(HPO_EXPERIMENT_NAME)
    if hpo_experiment is None:
        raise ValueError(f"MLflow experiment {HPO_EXPERIMENT_NAME!r} not found")
    hpo_runs = client.search_runs(
        experiment_ids=[hpo_experiment.experiment_id],
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=top_n,
        order_by=["metrics.rmse ASC"],
    )
    for run in hpo_runs:
        train_and_log_model(data_path=data_path, params=run.data.params)

    # Select the model with the lowest test RMSE. train_and_log_model logs
    # "test_rmse" (not "rmse"), so order on that metric; ordering on a metric
    # these runs never log would make the pick effectively arbitrary.
    # The search covers the whole experiment, so runs from earlier executions compete too.
    best_experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
    if best_experiment is None:
        raise ValueError(f"MLflow experiment {EXPERIMENT_NAME!r} not found")
    best_runs = client.search_runs(
        experiment_ids=[best_experiment.experiment_id],
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=1,
        order_by=["metrics.test_rmse ASC"],
    )
    if not best_runs or "test_rmse" not in best_runs[0].data.metrics:
        raise ValueError(
            f"No run with a 'test_rmse' metric found in experiment {EXPERIMENT_NAME!r}"
        )
    best_run = best_runs[0]

    # Register the best model (logged by sklearn autologging under "model").
    run_id = best_run.info.run_id
    model_uri = f"runs:/{run_id}/model"
    mlflow.register_model(model_uri=model_uri, name=model_name)


# Inside a notebook kernel __name__ is '__main__', so this runs when the cell executes.
if __name__ == '__main__':
    run_register_model()



🏃 View run RandomForest at: http://127.0.0.1:5002/#/experiments/3/runs/bf248f1b030b493e88a3797bafd41757
🧪 View experiment at: http://127.0.0.1:5002/#/experiments/3




🏃 View run RandomForest at: http://127.0.0.1:5002/#/experiments/3/runs/8cc2b60c9ad74f84b009066a0f1bc156
🧪 View experiment at: http://127.0.0.1:5002/#/experiments/3




🏃 View run RandomForest at: http://127.0.0.1:5002/#/experiments/3/runs/ed8cac66aad240e5bdf6df6dd72546fe
🧪 View experiment at: http://127.0.0.1:5002/#/experiments/3




🏃 View run RandomForest at: http://127.0.0.1:5002/#/experiments/3/runs/d68d17d1a8154fb5a74b16b578fab9e3
🧪 View experiment at: http://127.0.0.1:5002/#/experiments/3




🏃 View run RandomForest at: http://127.0.0.1:5002/#/experiments/3/runs/ea94adedd07147feac24a03f055235a7
🧪 View experiment at: http://127.0.0.1:5002/#/experiments/3




🏃 View run RandomForest at: http://127.0.0.1:5002/#/experiments/3/runs/76e9f4e49aea45fbaf9f865c3900748e
🧪 View experiment at: http://127.0.0.1:5002/#/experiments/3




🏃 View run RandomForest at: http://127.0.0.1:5002/#/experiments/3/runs/d961bbd3f1a24d63a061a067dc0cf76b
🧪 View experiment at: http://127.0.0.1:5002/#/experiments/3




🏃 View run RandomForest at: http://127.0.0.1:5002/#/experiments/3/runs/493da101c2774a3a9fcb3d91aed029ff
🧪 View experiment at: http://127.0.0.1:5002/#/experiments/3




🏃 View run RandomForest at: http://127.0.0.1:5002/#/experiments/3/runs/92b5822d31b147aca43207573eb75236
🧪 View experiment at: http://127.0.0.1:5002/#/experiments/3


Registered model 'random-forest-best-model' already exists. Creating a new version of this model...
2025/05/10 00:28:16 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: random-forest-best-model, version 5


🏃 View run RandomForest at: http://127.0.0.1:5002/#/experiments/3/runs/cf0ae4831ae143368499ddf932bfcb20
🧪 View experiment at: http://127.0.0.1:5002/#/experiments/3


Created version '5' of model 'random-forest-best-model'.


In [None]:
# NOTE(review): runs an external script (register_model.py, not shown here) in a
# subprocess. If it mirrors the cells above, it will train and register another
# batch of models on top of the ones already created — confirm this is intended.
! python register_model.py