In [1]:
import os
import pickle
import click

from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import root_mean_squared_error

# Script has been modified to use MLflow autologging.
import mlflow

In [2]:
def load_pickle(filename: str):
    """Deserialize and return the object stored in a pickle file.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The object that was pickled into ``filename``.

    Warning:
        ``pickle.load`` can execute arbitrary code while unpickling, so only
        call this on files from a trusted source.
    """
    with open(filename, mode="rb") as handle:
        loaded = pickle.load(handle)
    return loaded

In [3]:
def run_train(data_path: str, max_depth: int = 10, random_state: int = 0) -> float:
    """Train a random-forest regressor and return its validation RMSE.

    Args:
        data_path: Directory containing ``train.pkl`` and ``val.pkl``; each
            pickle must hold an ``(X, y)`` pair.
        max_depth: Maximum tree depth passed to ``RandomForestRegressor``.
        random_state: Seed passed to ``RandomForestRegressor`` so repeated
            runs produce the same model.

    Returns:
        Root mean squared error of the model's predictions on the validation set.
    """
    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_val, y_val = load_pickle(os.path.join(data_path, "val.pkl"))

    rf = RandomForestRegressor(max_depth=max_depth, random_state=random_state)
    rf.fit(X_train, y_train)
    y_pred = rf.predict(X_val)

    # root_mean_squared_error replaces mean_squared_error(..., squared=False);
    # the `squared` argument is deprecated in recent scikit-learn releases.
    return root_mean_squared_error(y_val, y_pred)

In [4]:
# Configure MLflow tracking, enable autologging, and train the model inside a run.
# NOTE: when launching this code via the `mlflow run` CLI instead of interactively,
# use `mlflow run --env-manager local` to skip environment creation; the older
# `--no-conda` flag was removed in MLflow 2.x.
TRACKING_URI = "sqlite:///mlflow.db"
EXPERIMENT_NAME = "nyc-green-taxi-exp"
DATA_PATH = "./output"

mlflow.set_tracking_uri(TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

# Dataset logging is disabled because the training inputs loaded from the
# pickles are not pandas DataFrames (NumPy arrays per the original author).
# NOTE(review): if the features came from a vectorizer they may be SciPy sparse
# matrices instead — confirm before re-enabling dataset logging.
mlflow.autolog(log_datasets=False)
with mlflow.start_run():
    mlflow.set_tag("developer", "wylie")
    rmse = run_train(DATA_PATH)
    # Record the validation RMSE returned by run_train explicitly under "rmse".
    mlflow.log_metric("rmse", rmse)

2024/05/26 06:05:42 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2024/05/26 06:05:43 INFO mlflow.store.db.utils: Updating database tables
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Running upgrade  -> 451aebb31d03, add metric step
INFO  [alembic.runtime.migration] Running upgrade 451aebb31d03 -> 90e64c465722, migrate user column to tags
INFO  [alembic.runtime.migration] Running upgrade 90e64c465722 -> 181f10493468, allow nulls for metric values
INFO  [alembic.runtime.migration] Running upgrade 181f10493468 -> df50e92ffc5e, Add Experiment Tags Table
INFO  [alembic.runtime.migration] Running upgrade df50e92ffc5e -> 7ac759974ad8, Update run tags with larger limit
INFO  [alembic.runtime.migration] Running upgrade 7ac759974ad8 -> 89d4b8295536, create latest metrics table
INFO  [89d4b8295536_create_latest_metrics_table_py] Migration complete!
INFO  