In [2]:
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn import datasets


* 'schema_extra' has been renamed to 'json_schema_extra'


In [7]:
mlflow.get_tracking_uri()

'http://localhost:5000'

In [4]:
mlflow.set_tracking_uri("http://localhost:5000")

In [5]:
mlflow.get_tracking_uri()

'http://localhost:5000'

In [8]:
mlflow.set_experiment("iris") # creates an experiment if it doesn't exist

2024/01/17 09:29:54 INFO mlflow.tracking.fluent: Experiment with name 'iris' does not exist. Creating a new experiment.


<Experiment: artifact_location='mlflow-artifacts:/251967034232254676', creation_time=1705480194253, experiment_id='251967034232254676', last_update_time=1705480194253, lifecycle_stage='active', name='iris', tags={}>

In [9]:
iris = datasets.load_iris()
x = iris.data[:, 2:]
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=7)

In [11]:
import joblib

with mlflow.start_run(run_name="Iris RF Experiment") as run:
    
    # Add parameters for tuning
    num_estimators = 100
    mlflow.log_param("num_estimators", num_estimators)

    # Train the model
    rf = RandomForestRegressor(n_estimators=num_estimators)
    rf.fit(X_train, y_train)
    predictions = rf.predict(X_test)

    # Save the model artifact using joblib.dump
    model_path = "random-forest-model.joblib"
    joblib.dump(rf, model_path)
    mlflow.log_artifact(model_path, "model")

    # Log model performance 
    mse = mean_squared_error(y_test, predictions)
    mlflow.log_metric("mse", mse)
    print("  MSE: %f" % mse)

    run_id = run.info.run_uuid
    experiment_id = run.info.experiment_id
    
    # End the MLflow run
    mlflow.end_run()

    # Print artifact URI and run ID
    print(mlflow.get_artifact_uri())
    print("Run ID: %s" % run_id)

  MSE: 0.084680
mlflow-artifacts:/251967034232254676/aa150082d14e47c1a48516df6bda01b5/artifacts
Run ID: a350fca35ac54d63a3472a375584d6b6


In [12]:
model_path = "random-forest-model.joblib"

# Chargemenet du modèle
loaded_model = joblib.load(model_path)

predictions = loaded_model.predict(X_test)
