In [7]:
 import os
 import pickle
 import click
 import mlflow
import numpy as np

In [3]:
 from mlflow.entities import ViewType
 from mlflow.tracking import MlflowClient
 from sklearn.ensemble import RandomForestRegressor
 from sklearn.metrics import root_mean_squared_error
from sklearn.metrics import mean_squared_error

In [4]:
HPO_EXPERIMENT_NAME = "random-forest-hyperopt"
EXPERIMENT_NAME    = "random-forest-best-models"
RF_PARAMS          = ['max_depth', 'n_estimators', 'min_samples_split', 'min_samples_leaf', 'random_state']


In [5]:
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment(EXPERIMENT_NAME)
mlflow.sklearn.autolog()


2025/06/22 23:04:49 INFO mlflow.tracking.fluent: Experiment with name 'random-forest-best-models' does not exist. Creating a new experiment.


In [6]:
def load_pickle(filename):
    with open(filename, "rb") as f_in:
        return pickle.load(f_in)


In [12]:
def train_and_log_model(data_path, params):
    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_val,   y_val   = load_pickle(os.path.join(data_path, "val.pkl"))
    X_test,  y_test  = load_pickle(os.path.join(data_path, "test.pkl"))

    with mlflow.start_run():
        new_params = {}
        for param in RF_PARAMS:
            new_params[param] = int(params[param])

        rf = RandomForestRegressor(**new_params)
        rf.fit(X_train, y_train)
        # Evaluate model on the validation and test sets
        mse_val = mean_squared_error(y_val, rf.predict(X_val))
        val_rmse = float(np.sqrt(mse_val))
        mlflow.log_metric("val_rmse", val_rmse)
        mse_test = mean_squared_error(y_test, rf.predict(X_test))
        test_rmse = float(np.sqrt(mse_val))
        mlflow.log_metric("test_rmse", test_rmse)


In [13]:
def run_register_model(data_path: str = "./output", top_n: int = 5):

    client = MlflowClient()

    # Retrieve the top_n model runs and log the models
    experiment = client.get_experiment_by_name(HPO_EXPERIMENT_NAME)
    runs = client.search_runs(
        experiment_ids=experiment.experiment_id,
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=top_n,
        order_by=["metrics.rmse ASC"]
    )
    for run in runs:
        train_and_log_model(data_path=data_path, params=run.data.params)
        
    # Select the model with the lowest test RMSE
    exp_best = client.get_experiment_by_name(EXPERIMENT_NAME)
    best_runs = client.search_runs(
        experiment_ids=[exp_best.experiment_id],
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=1,
        order_by=["metrics.test_rmse ASC"]
    )
    best_run = best_runs[0]
    

    # Register the best model
    # mlflow.register_model( ... )
    model_uri = f"runs:/{best_run.info.run_id}/model"
    model_name = "RandomForestTaxi"
    mlflow.register_model(model_uri=model_uri, name=model_name)

    print(
        f"✅ Registered model '{model_name}' from run {best_run.info.run_id} "
        f"with test_rmse={best_run.data.metrics['test_rmse']}"
    )

if __name__ == '__main__':
    run_register_model()

🏃 View run stately-squid-534 at: http://127.0.0.1:5000/#/experiments/802108489478494946/runs/fcdf1543ff034fd1924697a3681005f8
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/802108489478494946
🏃 View run exultant-jay-24 at: http://127.0.0.1:5000/#/experiments/802108489478494946/runs/d5b1a5c3108946a3bf5f55484d6ec695
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/802108489478494946
🏃 View run brawny-croc-423 at: http://127.0.0.1:5000/#/experiments/802108489478494946/runs/64e169d7996c4635a3e5114c778b32e8
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/802108489478494946
🏃 View run flawless-gull-817 at: http://127.0.0.1:5000/#/experiments/802108489478494946/runs/8cfa324bbddf4a62bae848e99a764a00
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/802108489478494946
🏃 View run blushing-boar-361 at: http://127.0.0.1:5000/#/experiments/802108489478494946/runs/a3cd7f13b13b481a82c1282499e22f28
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/80210848

Successfully registered model 'RandomForestTaxi'.
2025/06/22 23:29:30 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: RandomForestTaxi, version 1


✅ Registered model 'RandomForestTaxi' from run fcdf1543ff034fd1924697a3681005f8 with test_rmse=5.335419588556921


Created version '1' of model 'RandomForestTaxi'.
