In [17]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression, Lasso, Ridge
from sklearn.metrics import root_mean_squared_error
import pickle
from pathlib import Path

DATA_FOLDER = Path("../../../data")
MODEL_FOLDER = Path("../models")
MLFLOW_DB_PATH = Path("..")

In [18]:
import mlflow

mlflow.set_tracking_uri(f"sqlite:///{MLFLOW_DB_PATH}/mlflow.db")
mlflow.set_experiment("nyc-taxi-experiment")

<Experiment: artifact_location=('/Users/alessandro.arlandini/Documents/General DS '
 'learning/DataTalksClub/mlops-zoomcamp/ale-mlops/notebooks/03-training/experiment_tracking/mlruns/1'), creation_time=1713518478408, experiment_id='1', last_update_time=1713518478408, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [19]:
df = pd.read_parquet(DATA_FOLDER / "green_tripdata_2021-01.parquet")
df.dtypes

VendorID                          int64
lpep_pickup_datetime     datetime64[us]
lpep_dropoff_datetime    datetime64[us]
store_and_fwd_flag               object
RatecodeID                      float64
PULocationID                      int64
DOLocationID                      int64
passenger_count                 float64
trip_distance                   float64
fare_amount                     float64
extra                           float64
mta_tax                         float64
tip_amount                      float64
tolls_amount                    float64
ehail_fee                        object
improvement_surcharge           float64
total_amount                    float64
payment_type                    float64
trip_type                       float64
congestion_surcharge            float64
dtype: object

In [5]:
df["duration"] = (
    df.lpep_dropoff_datetime - df.lpep_pickup_datetime
).dt.total_seconds() / 60

In [6]:
df.duration.describe(percentiles=[0.95, 0.98, 0.99])

count    76518.000000
mean        19.927896
std         59.338594
min          0.000000
50%         13.883333
95%         44.000000
98%         56.000000
99%         67.158167
max       1439.600000
Name: duration, dtype: float64

In [7]:
((df.duration >= 1) & (df.duration <= 60)).mean()

0.9658903787344154

In [8]:
df = df[(df.duration >= 1) & (df.duration <= 60)].copy()  # .copy() avoids a SettingWithCopyWarning in later cells
df.shape

(73908, 21)

In [9]:
categorical = ["PULocationID", "DOLocationID"]
numerical = ["trip_distance"]
df[categorical] = df[categorical].astype(str)

In [10]:
dv = DictVectorizer()

train_dicts = df[categorical + numerical].to_dict(orient="records")
X_train = dv.fit_transform(train_dicts)

In [11]:
dv.feature_names_

['DOLocationID=1',
 'DOLocationID=10',
 'DOLocationID=100',
 'DOLocationID=101',
 'DOLocationID=102',
 'DOLocationID=106',
 'DOLocationID=107',
 'DOLocationID=108',
 'DOLocationID=109',
 'DOLocationID=11',
 'DOLocationID=111',
 'DOLocationID=112',
 'DOLocationID=113',
 'DOLocationID=114',
 'DOLocationID=115',
 'DOLocationID=116',
 'DOLocationID=117',
 'DOLocationID=118',
 'DOLocationID=119',
 'DOLocationID=12',
 'DOLocationID=120',
 'DOLocationID=121',
 'DOLocationID=122',
 'DOLocationID=123',
 'DOLocationID=124',
 'DOLocationID=125',
 'DOLocationID=126',
 'DOLocationID=127',
 'DOLocationID=128',
 'DOLocationID=129',
 'DOLocationID=13',
 'DOLocationID=130',
 'DOLocationID=131',
 'DOLocationID=132',
 'DOLocationID=133',
 'DOLocationID=134',
 'DOLocationID=135',
 'DOLocationID=136',
 'DOLocationID=137',
 'DOLocationID=138',
 'DOLocationID=139',
 'DOLocationID=14',
 'DOLocationID=140',
 'DOLocationID=141',
 'DOLocationID=142',
 'DOLocationID=143',
 'DOLocationID=144',
 'DOLocationID=145',

DictVectorizer turns a list of dictionaries (one per row) into a feature matrix, in this case a sparse matrix. String-valued fields are one-hot encoded, while numerical fields are passed through unchanged. If we had not converted `PULocationID` and `DOLocationID` to strings, they would have been treated as numerical fields, which would be wrong for our application: location numbers are only IDs, so arithmetic or ordering on them has no meaning.
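A minimal sketch on two hand-written records (the location IDs and distances are made up for illustration) shows both behaviours: string values become one-hot columns named `field=value`, while numeric values stay as a single column.

```python
from sklearn.feature_extraction import DictVectorizer

# Hypothetical trip records: location IDs as strings, distance as a float
records_str = [
    {"PULocationID": "43", "trip_distance": 1.5},
    {"PULocationID": "151", "trip_distance": 2.0},
]
dv_str = DictVectorizer(sparse=False)  # dense output, easier to inspect
X_str = dv_str.fit_transform(records_str)
print(dv_str.feature_names_)  # ['PULocationID=151', 'PULocationID=43', 'trip_distance']
print(X_str)                  # one-hot columns for the IDs, raw value for the distance

# Same records with the IDs left as integers: they become one numeric column
records_int = [
    {"PULocationID": 43, "trip_distance": 1.5},
    {"PULocationID": 151, "trip_distance": 2.0},
]
dv_int = DictVectorizer(sparse=False)
X_int = dv_int.fit_transform(records_int)
print(dv_int.feature_names_)  # ['PULocationID', 'trip_distance']
print(X_int)
```

Feature names are sorted as strings by default, which is why `PULocationID=151` comes before `PULocationID=43`.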

By checking the values stored in the resulting matrix, one can see that, except for the last column (`trip_distance`), all values are 0s and 1s, consistent with one-hot encoding. The feature names are also consistent with one-hot encoding of the `DOLocationID` and `PULocationID` fields, as every distinct value produces its own column.

In [12]:
np.unique(X_train.toarray()[:, :-1].flatten())

array([0., 1.])
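The `[:, :-1]` slice works because `DictVectorizer` sorts feature names alphabetically and uppercase letters sort before lowercase ones, so `trip_distance` ends up after all the `DOLocationID=…` and `PULocationID=…` columns. The same check can be done without converting the whole matrix to a dense array, by looking only at the non-zero values stored in the sparse slice. A sketch on three hypothetical records:

```python
import numpy as np
from sklearn.feature_extraction import DictVectorizer

# Hypothetical records mirroring the notebook's feature layout
demo_records = [
    {"PULocationID": "43", "DOLocationID": "151", "trip_distance": 1.5},
    {"PULocationID": "74", "DOLocationID": "43", "trip_distance": 3.2},
    {"PULocationID": "43", "DOLocationID": "74", "trip_distance": 0.8},
]
dv_demo = DictVectorizer()  # sparse (scipy CSR) output by default
X_demo = dv_demo.fit_transform(demo_records)

# Uppercase names sort before lowercase ones, so the numeric column is last
print(dv_demo.feature_names_[-1])  # trip_distance

one_hot = X_demo[:, :-1]  # slicing keeps the matrix sparse
# .data holds only the explicitly stored (non-zero) entries
print(np.unique(one_hot.data))  # [1.]
# Each row has exactly one PU indicator and one DO indicator set
print(np.asarray(one_hot.sum(axis=1)).ravel())  # [2. 2. 2.]
```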

In [13]:
target = "duration"
y_train = df[target].values

In [14]:
lm = LinearRegression()
lm.fit(X_train, y_train)

In [15]:
y_pred = lm.predict(X_train)

In [16]:
root_mean_squared_error(y_train, y_pred)

9.838799799829625

The model predictions have an RMSE of about 9.8 minutes. We can probably do better, but this is still measured on the training data, and we want scoring metrics on validation data too. Since we will have to repeat pretty much the same steps for the validation set, we put everything together in functions for convenience.
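RMSE is the square root of the mean squared difference between observed and predicted values, so it is expressed in the same unit as the target (minutes here). A quick check by hand with NumPy on made-up durations:

```python
import numpy as np

# Hypothetical observed and predicted trip durations, in minutes
demo_true = np.array([10.0, 20.0, 30.0])
demo_pred = np.array([12.0, 18.0, 33.0])

errors = demo_pred - demo_true   # [2, -2, 3]
mse = np.mean(errors ** 2)       # (4 + 4 + 9) / 3 = 17 / 3
rmse_demo = np.sqrt(mse)
print(round(float(rmse_demo), 4))  # 2.3805
```

scikit-learn's `root_mean_squared_error` (available since scikit-learn 1.4) computes the same quantity.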

We are building _preprocessing and training pipelines._

In [21]:
def read_dataframe(filename):
    categorical = ["PULocationID", "DOLocationID"]

    df = pd.read_parquet(filename)

    df["duration"] = (
        df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    ).dt.total_seconds() / 60
    # .copy() makes df an independent frame, avoiding a SettingWithCopyWarning on the next line
    df = df[(df.duration >= 1) & (df.duration <= 60)].copy()
    df[categorical] = df[categorical].astype(str)

    return df
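To check the filtering and type conversion without the parquet files, the same steps can be run on a tiny in-memory DataFrame. `prepare_trips` is a hypothetical helper that mirrors the steps of `read_dataframe` minus the file read, and the timestamps are invented.

```python
import pandas as pd

def prepare_trips(df):
    """Hypothetical helper: same transformations as read_dataframe, on an existing DataFrame."""
    categorical = ["PULocationID", "DOLocationID"]
    df = df.copy()
    df["duration"] = (
        df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    ).dt.total_seconds() / 60
    # Keep trips lasting between 1 and 60 minutes
    df = df[(df.duration >= 1) & (df.duration <= 60)].copy()
    df[categorical] = df[categorical].astype(str)
    return df

pickup = pd.Timestamp("2021-01-01 08:00")
raw = pd.DataFrame({
    "lpep_pickup_datetime": [pickup] * 3,
    # 30 seconds, 10 minutes and 90 minutes after pickup
    "lpep_dropoff_datetime": [
        pickup + pd.Timedelta(seconds=30),
        pickup + pd.Timedelta(minutes=10),
        pickup + pd.Timedelta(minutes=90),
    ],
    "PULocationID": [43, 74, 151],
    "DOLocationID": [151, 43, 74],
})

clean = prepare_trips(raw)
print(clean[["PULocationID", "DOLocationID", "duration"]])  # only the 10-minute trip survives
```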

In [22]:
# Read training and validation dataframes
df_train = read_dataframe(DATA_FOLDER / "green_tripdata_2021-01.parquet")
df_val = read_dataframe(DATA_FOLDER / "green_tripdata_2021-02.parquet")
df_train.shape, df_val.shape

((73908, 21), (61921, 21))

In [19]:
# Engineer an interaction feature.
# Interaction features combine two (or more) features into one. The idea is that the combination
# of values can carry information that the individual values do not.
# Here we create an interaction feature from PU and DO: instead of capturing the separate marginal
# effects of "pick up here" and "drop off there", we try to capture the information carried by
# "pick up here AND drop off there".
# Note that this replaces two categorical features with one, but after one-hot encoding the number
# of columns usually grows, because there are more distinct (PU, DO) pairs than distinct locations.
interaction = True

# In this case it suffices to concatenate the strings to create a new categorical feature
if interaction:
    df_train["PU_DO"] = df_train["PULocationID"] + "_" + df_train["DOLocationID"]
    df_val["PU_DO"] = df_val["PULocationID"] + "_" + df_val["DOLocationID"]
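A small sketch of how the interaction feature changes the encoded dimension: with three hypothetical pickup zones and three dropoff zones, encoding them separately yields 3 + 3 = 6 columns, while the `PU_DO` pair can take up to 3 × 3 = 9 values, each with its own column.

```python
from itertools import product

from sklearn.feature_extraction import DictVectorizer

zones = ["43", "74", "151"]  # hypothetical location IDs, already as strings
# One record per (pickup, dropoff) combination
pair_demo = [{"PULocationID": pu, "DOLocationID": do} for pu, do in product(zones, zones)]

# Separate features: one column per distinct PU value plus one per distinct DO value
dv_separate = DictVectorizer()
X_separate = dv_separate.fit_transform(pair_demo)

# Interaction feature: concatenate the two IDs into a single categorical value
pu_do_records = [{"PU_DO": r["PULocationID"] + "_" + r["DOLocationID"]} for r in pair_demo]
dv_pair = DictVectorizer()
X_pair = dv_pair.fit_transform(pu_do_records)

print(X_separate.shape[1], X_pair.shape[1])  # 6 9
```

In real data not every pair is observed, so the growth is smaller than the full product; the linear model then learns one coefficient per observed pair instead of one per zone.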

In [20]:
# Use DictVectorizer to encode categorical features and obtain training and validation matrices
if interaction:
    # When we use the interaction features, we drop the individual features.
    categorical = ["PU_DO"]
else:
    categorical = ["PULocationID", "DOLocationID"]
numerical = ["trip_distance"]

dv = DictVectorizer()

train_dicts = df_train[categorical + numerical].to_dict(orient="records")
X_train = dv.fit_transform(train_dicts)

val_dicts = df_val[categorical + numerical].to_dict(orient="records")
X_val = dv.transform(val_dicts)
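The vectorizer is fitted on the training data only and then reused on the validation data with `transform`. A consequence worth knowing: a category that never appeared in training (for instance a new `PU_DO` pair) is silently ignored, so it contributes an all-zero encoding. A sketch with hypothetical records:

```python
from sklearn.feature_extraction import DictVectorizer

# Hypothetical training and validation records
demo_train_dicts = [
    {"PU_DO": "43_151", "trip_distance": 1.5},
    {"PU_DO": "74_43", "trip_distance": 3.0},
]
demo_val_dicts = [{"PU_DO": "151_74", "trip_distance": 2.0}]  # pair never seen in training

dv_split = DictVectorizer(sparse=False)
X_demo_train = dv_split.fit_transform(demo_train_dicts)  # learns the vocabulary
X_demo_val = dv_split.transform(demo_val_dicts)          # reuses it; the unseen pair is dropped

print(dv_split.feature_names_)  # ['PU_DO=43_151', 'PU_DO=74_43', 'trip_distance']
print(X_demo_val)               # only trip_distance is non-zero
```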

In [21]:
# Obtain training and validation targets
target = "duration"
y_train = df_train[target]
y_val = df_val[target]

In [22]:
# Train and validate a linear regression model
lm = LinearRegression()
lm.fit(X_train, y_train)

y_pred = lm.predict(X_val)
root_mean_squared_error(y_val, y_pred)

7.758715204520257

In [None]:
# Train and validate a Lasso model with mlflow logging
with mlflow.start_run():
    mlflow.set_tag("developer", "alessandro")
    mlflow.log_param("train-data-path", DATA_FOLDER / "green_tripdata_2021-01.parquet")
    mlflow.log_param("valid-data-path", DATA_FOLDER / "green_tripdata_2021-02.parquet")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)

    ls = Lasso(alpha)
    ls.fit(X_train, y_train)

    y_pred = ls.predict(X_val)
    rmse = root_mean_squared_error(y_val, y_pred)
    mlflow.log_metric("rmse", rmse)

It is common to change the code and try different configurations and parameters while experimenting. When doing that, it's important to record history to allow for proper comparisons, reproducibility and observability.

One simple way of recording what we are doing when training models is to manually log the desired quantities with mlflow. Here we use the context manager `mlflow.start_run`, so that everything logged inside the `with` block belongs to a single run of the experiment. Then we track:
* The name of the developer in a tag
* The hyperparameter `alpha`: this is a parameter that we want to tune, so it's important to keep track of the values we tried.
* The path to the training and validation data: this is a very simplified method of data versioning. We can do much better than this, but it's a start in pointing to the data used for the run.
* The resulting RMSE. This is the error metric that we obtain from the model, and it's essential to track it in order to compare runs and find the best one (the one with the lowest error).

In [30]:
import xgboost as xgb

from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [32]:
# We use xgboost's DMatrix data structure to hold the training and validation data
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

In [32]:
def objective(params):
    """Take a set of parameters, train xgboost using those parameters and return
    the RMSE loss obtained with that model."""

    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=100,
            evals=[(valid, "validation")],
            early_stopping_rounds=15,
        )
        y_pred = booster.predict(valid)
        rmse = root_mean_squared_error(y_val, y_pred)
        mlflow.log_metric("rmse", rmse)

    return {"loss": rmse, "status": STATUS_OK}

In [33]:
# This is the search space that we want to explore with our hyperparameter search.
# Hyperopt does not enumerate every combination (continuous ranges make that impossible): at each of
# the max_evals iterations, the TPE algorithm samples a new point from this space, guided by past results.
search_space = {
    "max_depth": scope.int(hp.quniform("max_depth", 4, 20, 1)),
    # loguniform(label, low, high) returns exp(u) with u drawn uniformly from [low, high],
    # so the logarithm of the value is uniformly distributed.
    # Here log(learning_rate) lies in [-3, 0], i.e. learning_rate lies in [exp(-3), 1] ≈ [0.05, 1].
    "learning_rate": hp.loguniform("learning_rate", -3, 0),
    "reg_alpha": hp.loguniform("reg_alpha", -5, -1),
    "reg_lambda": hp.loguniform("reg_lambda", -6, -1),
    "min_child_weight": hp.loguniform("min_child_weight", -1, 3),
    # "reg:linear" is a deprecated alias of "reg:squarederror"
    "objective": "reg:squarederror",
    "seed": 42,
}

# fmin minimizes a function over a domain: it evaluates the function on inputs drawn from the search
# space and returns the input (here, the set of hyperparameters) with the lowest value found.
# By viewing our model as a function (params) -> (loss), we can use fmin to find the set of
# hyperparameters that minimizes the validation loss, thus finding the best model.
best_result = fmin(
    # We wrote the "objective" function to take a set of hyperparameters and return the loss
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=40,
    # storage for iterations and results
    trials=Trials()
)

[0]	validation-rmse:11.75673                          
[1]	validation-rmse:11.33466                          
[2]	validation-rmse:10.94414                          
[3]	validation-rmse:10.58347                          
[4]	validation-rmse:10.25118                          
[5]	validation-rmse:9.94511                           
[6]	validation-rmse:9.66379                           
[7]	validation-rmse:9.40561                           
[8]	validation-rmse:9.16844                           
[9]	validation-rmse:8.95131                           
[10]	validation-rmse:8.75266                          
[11]	validation-rmse:8.57070                          
[12]	validation-rmse:8.40437                          
[13]	validation-rmse:8.25221                          
[14]	validation-rmse:8.11396                          
[15]	validation-rmse:7.98707                          
[16]	validation-rmse:7.87161                          
[17]	validation-rmse:7.76634                          
[18]	validation-rmse:7.67016                          
[19]	validation-rmse:7.58359                          
[20]	validation-rmse:7.50406                          




[2]	validation-rmse:9.09541                                                    
[3]	validation-rmse:8.49709                                                    
[4]	validation-rmse:8.04644                                                    
[5]	validation-rmse:7.71006                                                    
[6]	validation-rmse:7.45886                                                    
[7]	validation-rmse:7.27343                                                    
[8]	validation-rmse:7.13699                                                    
[9]	validation-rmse:7.03357                                                    
[10]	validation-rmse:6.95608                                                   
[11]	validation-rmse:6.89827                                                   
[12]	validation-rmse:6.85468                                                   
[13]	validation-rmse:6.82133                                                   
[14]	validation-rmse:6.79526            




[1]	validation-rmse:8.91050                                                    
[2]	validation-rmse:8.07339                                                    
[3]	validation-rmse:7.55109                                                    
[4]	validation-rmse:7.22741                                                    
[5]	validation-rmse:7.02734                                                    
[6]	validation-rmse:6.89884                                                    
[7]	validation-rmse:6.82018                                                    
[8]	validation-rmse:6.76804                                                    
[9]	validation-rmse:6.73511                                                    
[10]	validation-rmse:6.71189                                                   
[11]	validation-rmse:6.69468                                                   
[12]	validation-rmse:6.68351                                                   
[13]	validation-rmse:6.67432            




[3]	validation-rmse:9.59319                                                     
[4]	validation-rmse:9.14726                                                     
[5]	validation-rmse:8.76644                                                     
[6]	validation-rmse:8.44241                                                     
[7]	validation-rmse:8.16699                                                     
[8]	validation-rmse:7.93409                                                     
[9]	validation-rmse:7.73767                                                     
[10]	validation-rmse:7.57233                                                    
[11]	validation-rmse:7.43303                                                    
[12]	validation-rmse:7.31556                                                    
[13]	validation-rmse:7.21633                                                    
[14]	validation-rmse:7.13312                                                    
[15]	validation-rmse:7.06284




[2]	validation-rmse:10.74558                                                    
[3]	validation-rmse:10.34103                                                    
[4]	validation-rmse:9.97349                                                     
[5]	validation-rmse:9.64196                                                     
[6]	validation-rmse:9.34099                                                     
[7]	validation-rmse:9.06954                                                     
[8]	validation-rmse:8.82469                                                     
[9]	validation-rmse:8.60591                                                     
[10]	validation-rmse:8.40782                                                    
[11]	validation-rmse:8.22895                                                    
[12]	validation-rmse:8.06904                                                    
[13]	validation-rmse:7.92709                                                    
[14]	validation-rmse:7.79841




[6]	validation-rmse:7.58369                                                     
[7]	validation-rmse:7.38116                                                     
[8]	validation-rmse:7.22660                                                     
[9]	validation-rmse:7.10921                                                     
[10]	validation-rmse:7.01993                                                    
[11]	validation-rmse:6.95261                                                    
[12]	validation-rmse:6.89904                                                    
[13]	validation-rmse:6.85780                                                    
[14]	validation-rmse:6.82672                                                    
[15]	validation-rmse:6.80134                                                    
[16]	validation-rmse:6.78129                                                    
[17]	validation-rmse:6.76507                                                    
[18]	validation-rmse:6.75042




[4]	validation-rmse:7.05072                                                     
[5]	validation-rmse:6.91609                                                     
[6]	validation-rmse:6.83806                                                     
[7]	validation-rmse:6.79234                                                     
[8]	validation-rmse:6.76358                                                     
[9]	validation-rmse:6.74339                                                     
[10]	validation-rmse:6.72970                                                    
[11]	validation-rmse:6.71983                                                    
[12]	validation-rmse:6.71365                                                    
[13]	validation-rmse:6.70803                                                    
[14]	validation-rmse:6.70591                                                    
[15]	validation-rmse:6.70390                                                    
[16]	validation-rmse:6.70209




[3]	validation-rmse:7.81056                                                     
[4]	validation-rmse:7.43700                                                     
[5]	validation-rmse:7.19230                                                     
[6]	validation-rmse:7.03150                                                     
[7]	validation-rmse:6.92593                                                     
[8]	validation-rmse:6.85332                                                     
[9]	validation-rmse:6.80338                                                     
[10]	validation-rmse:6.76767                                                    
[11]	validation-rmse:6.74375                                                    
[12]	validation-rmse:6.72451                                                    
[13]	validation-rmse:6.70992                                                    
[14]	validation-rmse:6.69823                                                    
[15]	validation-rmse:6.69030




[4]	validation-rmse:9.64996                                                     
[5]	validation-rmse:9.29277                                                     
[6]	validation-rmse:8.97688                                                     
[7]	validation-rmse:8.69811                                                     
[8]	validation-rmse:8.45205                                                     
[9]	validation-rmse:8.23520                                                     
[10]	validation-rmse:8.04476                                                    
[11]	validation-rmse:7.87748                                                    
[12]	validation-rmse:7.73086                                                    
[13]	validation-rmse:7.60303                                                    
[14]	validation-rmse:7.49047                                                    
[15]	validation-rmse:7.39173                                                    
[16]	validation-rmse:7.30594




[6]	validation-rmse:6.65412                                                     
[7]	validation-rmse:6.64761                                                     
[8]	validation-rmse:6.64394                                                     
[9]	validation-rmse:6.63807                                                     
[10]	validation-rmse:6.63260                                                    
[11]	validation-rmse:6.62789                                                    
[12]	validation-rmse:6.62281                                                    
[13]	validation-rmse:6.62101                                                    
[14]	validation-rmse:6.61699                                                    
[15]	validation-rmse:6.61510                                                    
[16]	validation-rmse:6.61339                                                    
[17]	validation-rmse:6.60982                                                    
[18]	validation-rmse:6.60834




[9]	validation-rmse:6.70536                                                     
[10]	validation-rmse:6.70397                                                    
[11]	validation-rmse:6.70332                                                    
[12]	validation-rmse:6.69837                                                    
[13]	validation-rmse:6.69594                                                    
[14]	validation-rmse:6.69116                                                    
[15]	validation-rmse:6.68833                                                    
[16]	validation-rmse:6.68529                                                    
[17]	validation-rmse:6.68280                                                    
[18]	validation-rmse:6.67982                                                    
[19]	validation-rmse:6.67475                                                    
[20]	validation-rmse:6.66715                                                    
[21]	validation-rmse:6.66544




[7]	validation-rmse:8.96586                                                     
[8]	validation-rmse:8.72144                                                     
[9]	validation-rmse:8.50301                                                     
[10]	validation-rmse:8.30773                                                    
[11]	validation-rmse:8.13397                                                    
[12]	validation-rmse:7.98004                                                    
[13]	validation-rmse:7.84304                                                    
[14]	validation-rmse:7.72127                                                    
[15]	validation-rmse:7.61319                                                    
[16]	validation-rmse:7.51709                                                    
[17]	validation-rmse:7.43182                                                    
[18]	validation-rmse:7.35638                                                    
[19]	validation-rmse:7.28964




[9]	validation-rmse:8.90803                                                     
[10]	validation-rmse:8.71167                                                    
[11]	validation-rmse:8.53350                                                    
[12]	validation-rmse:8.37179                                                    
[13]	validation-rmse:8.22448                                                    
[14]	validation-rmse:8.09060                                                    
[15]	validation-rmse:7.96938                                                    
[16]	validation-rmse:7.85966                                                    
[17]	validation-rmse:7.75980                                                    
[18]	validation-rmse:7.67024                                                    
[19]	validation-rmse:7.58840                                                    
[20]	validation-rmse:7.51562                                                    
[21]	validation-rmse:7.44865




[2]	validation-rmse:8.60644                                                     
[3]	validation-rmse:8.01837                                                     
[4]	validation-rmse:7.60741                                                     
[5]	validation-rmse:7.32331                                                     
[6]	validation-rmse:7.12575                                                     
[7]	validation-rmse:6.99068                                                     
[8]	validation-rmse:6.89636                                                     
[9]	validation-rmse:6.82604                                                     
[10]	validation-rmse:6.77669                                                    
[11]	validation-rmse:6.73826                                                    
[12]	validation-rmse:6.71410                                                    
[13]	validation-rmse:6.69446                                                    
[14]	validation-rmse:6.68061




[1]	validation-rmse:6.99527                                                     
[2]	validation-rmse:6.72939                                                     
[3]	validation-rmse:6.64648                                                     
[4]	validation-rmse:6.61474                                                     
[5]	validation-rmse:6.60094                                                     
[6]	validation-rmse:6.59290                                                     
[7]	validation-rmse:6.58843                                                     
[8]	validation-rmse:6.58489                                                     
[9]	validation-rmse:6.58089                                                     
[10]	validation-rmse:6.57624                                                    
[11]	validation-rmse:6.56856                                                    
[12]	validation-rmse:6.56203                                                    
[13]	validation-rmse:6.55749




[10]	validation-rmse:6.77106                                                    
[11]	validation-rmse:6.76558                                                    
[12]	validation-rmse:6.76218                                                    
[13]	validation-rmse:6.76072                                                    
[14]	validation-rmse:6.75435                                                    
[15]	validation-rmse:6.74977                                                    
[16]	validation-rmse:6.74614                                                    
[17]	validation-rmse:6.74353                                                    
[18]	validation-rmse:6.73854                                                    
[19]	validation-rmse:6.73754                                                    
[20]	validation-rmse:6.73575                                                    
[21]	validation-rmse:6.73516                                                    
[22]	validation-rmse:6.73351




[9]	validation-rmse:7.41937                                                     
[10]	validation-rmse:7.29456                                                    
[11]	validation-rmse:7.19450                                                    
[12]	validation-rmse:7.11402                                                    
[13]	validation-rmse:7.04893                                                    
[14]	validation-rmse:6.99637                                                    
[15]	validation-rmse:6.95441                                                    
[16]	validation-rmse:6.92004                                                    
[17]	validation-rmse:6.89225                                                    
[18]	validation-rmse:6.86921                                                    
[19]	validation-rmse:6.85104                                                    
[20]	validation-rmse:6.83475                                                    
[21]	validation-rmse:6.82129




[2]	validation-rmse:9.44229                                                     
[3]	validation-rmse:8.85586                                                     
[4]	validation-rmse:8.38933                                                     
[5]	validation-rmse:8.02214                                                     
[6]	validation-rmse:7.73297                                                     
[7]	validation-rmse:7.50504                                                     
[8]	validation-rmse:7.32841                                                     
[9]	validation-rmse:7.19224                                                     
[10]	validation-rmse:7.08428                                                    
[11]	validation-rmse:6.99901                                                    
[12]	validation-rmse:6.93237                                                    
[13]	validation-rmse:6.87854                                                    
[14]	validation-rmse:6.83651




[10]	validation-rmse:6.81573                                                    
[11]	validation-rmse:6.81361                                                    
[12]	validation-rmse:6.80507                                                    
[13]	validation-rmse:6.80290                                                    
[14]	validation-rmse:6.80014                                                    
[15]	validation-rmse:6.79775                                                    
[16]	validation-rmse:6.79563                                                    
[17]	validation-rmse:6.79264                                                    
[18]	validation-rmse:6.78879                                                    
[19]	validation-rmse:6.78760                                                    
[20]	validation-rmse:6.78558                                                    
[21]	validation-rmse:6.78430                                                    
[22]	validation-rmse:6.77739




[3]	validation-rmse:6.70747                                                     
[4]	validation-rmse:6.70225                                                     
[5]	validation-rmse:6.69685                                                     
[6]	validation-rmse:6.69026                                                     
[7]	validation-rmse:6.68136                                                     
[8]	validation-rmse:6.67833                                                     
[9]	validation-rmse:6.66979                                                     
[10]	validation-rmse:6.66212                                                    
[11]	validation-rmse:6.65432                                                    
[12]	validation-rmse:6.64646                                                    
[13]	validation-rmse:6.64796                                                    
[14]	validation-rmse:6.64596                                                    
[15]	validation-rmse:6.64319




[1]	validation-rmse:7.26550                                                     
[2]	validation-rmse:6.84725                                                     
[3]	validation-rmse:6.69771                                                     
[4]	validation-rmse:6.63830                                                     
[5]	validation-rmse:6.60829                                                     
[6]	validation-rmse:6.59484                                                     
[7]	validation-rmse:6.58523                                                     
[8]	validation-rmse:6.58163                                                     
[9]	validation-rmse:6.57663                                                     
[10]	validation-rmse:6.57032                                                    
[11]	validation-rmse:6.56602                                                    
[12]	validation-rmse:6.56095                                                    
[13]	validation-rmse:6.55505




[1]	validation-rmse:7.30380                                                   
[2]	validation-rmse:6.86613                                                   
[3]	validation-rmse:6.70861                                                   
[4]	validation-rmse:6.64395                                                   
[5]	validation-rmse:6.61216                                                   
[6]	validation-rmse:6.59827                                                   
[7]	validation-rmse:6.58701                                                   
[8]	validation-rmse:6.58217                                                   
[9]	validation-rmse:6.57713                                                   
[10]	validation-rmse:6.57305                                                  
[11]	validation-rmse:6.56675                                                  
[12]	validation-rmse:6.56334                                                  
[13]	validation-rmse:6.55621                        




[1]	validation-rmse:7.42636                                                   
[2]	validation-rmse:6.92851                                                   
[3]	validation-rmse:6.74078                                                   
[4]	validation-rmse:6.65747                                                   
[5]	validation-rmse:6.61644                                                   
[6]	validation-rmse:6.59792                                                   
[7]	validation-rmse:6.58650                                                   
[8]	validation-rmse:6.58271                                                   
[9]	validation-rmse:6.57362                                                   
[10]	validation-rmse:6.56907                                                  
[11]	validation-rmse:6.56200                                                  
[12]	validation-rmse:6.55933                                                  
[13]	validation-rmse:6.55365                        




[1]	validation-rmse:7.44935                                                   
[2]	validation-rmse:6.94958                                                   
[3]	validation-rmse:6.75618                                                   
[4]	validation-rmse:6.67863                                                   
[5]	validation-rmse:6.64240                                                   
[6]	validation-rmse:6.62089                                                   
[7]	validation-rmse:6.60606                                                   
[8]	validation-rmse:6.60491                                                   
[9]	validation-rmse:6.60155                                                   
[10]	validation-rmse:6.59312                                                  
[11]	validation-rmse:6.58982                                                  
[12]	validation-rmse:6.58623                                                  
[13]	validation-rmse:6.58104                        




[1]	validation-rmse:7.66609                                                   
[2]	validation-rmse:7.10222                                                   
[3]	validation-rmse:6.87206                                                   
[4]	validation-rmse:6.76708                                                   
[5]	validation-rmse:6.71980                                                   
[6]	validation-rmse:6.69627                                                   
[7]	validation-rmse:6.68094                                                   
[8]	validation-rmse:6.67024                                                   
[9]	validation-rmse:6.66438                                                   
[10]	validation-rmse:6.66048                                                  
[11]	validation-rmse:6.65719                                                  
[12]	validation-rmse:6.65511                                                  
[13]	validation-rmse:6.65269                        




[1]	validation-rmse:6.87496                                                   
[2]	validation-rmse:6.67857                                                   
[3]	validation-rmse:6.62456                                                   
[4]	validation-rmse:6.60289                                                   
[5]	validation-rmse:6.59332                                                   
[6]	validation-rmse:6.57983                                                   
[7]	validation-rmse:6.57435                                                   
[8]	validation-rmse:6.56953                                                   
[9]	validation-rmse:6.56221                                                   
[10]	validation-rmse:6.55691                                                  
[11]	validation-rmse:6.55156                                                  
[12]	validation-rmse:6.54857                                                  
[13]	validation-rmse:6.54374                        




[2]	validation-rmse:6.71457                                                     
[3]	validation-rmse:6.65902                                                     
[4]	validation-rmse:6.63712                                                     
[5]	validation-rmse:6.63090                                                     
[6]	validation-rmse:6.61934                                                     
[7]	validation-rmse:6.61324                                                     
[8]	validation-rmse:6.60734                                                     
[9]	validation-rmse:6.60271                                                     
[10]	validation-rmse:6.59891                                                    
[11]	validation-rmse:6.59411                                                    
[12]	validation-rmse:6.59023                                                    
[13]	validation-rmse:6.58352                                                    
[14]	validation-rmse:6.58009




[1]	validation-rmse:8.15931                                                     
[2]	validation-rmse:7.41883                                                     
[3]	validation-rmse:7.04421                                                     
[4]	validation-rmse:6.85125                                                     
[5]	validation-rmse:6.75018                                                     
[6]	validation-rmse:6.69393                                                     
[7]	validation-rmse:6.66103                                                     
[8]	validation-rmse:6.63875                                                     
[9]	validation-rmse:6.62209                                                     
[10]	validation-rmse:6.61075                                                    
[11]	validation-rmse:6.60227                                                    
[12]	validation-rmse:6.59757                                                    
[13]	validation-rmse:6.59389




[1]	validation-rmse:6.87058                                                     
[2]	validation-rmse:6.68854                                                     
[3]	validation-rmse:6.63853                                                     
[4]	validation-rmse:6.62128                                                     
[5]	validation-rmse:6.61154                                                     
[6]	validation-rmse:6.60062                                                     
[7]	validation-rmse:6.59626                                                     
[8]	validation-rmse:6.59319                                                     
[9]	validation-rmse:6.58831                                                     
[10]	validation-rmse:6.58335                                                    
[11]	validation-rmse:6.57819                                                    
[12]	validation-rmse:6.57421                                                    
[13]	validation-rmse:6.56960




[5]	validation-rmse:6.71133                                                     
[6]	validation-rmse:6.70410                                                     
[7]	validation-rmse:6.70376                                                     
[8]	validation-rmse:6.69987                                                     
[9]	validation-rmse:6.69737                                                     
[10]	validation-rmse:6.69111                                                    
[11]	validation-rmse:6.68574                                                    
[12]	validation-rmse:6.68043                                                    
[13]	validation-rmse:6.67510                                                    
[14]	validation-rmse:6.67220                                                    
[15]	validation-rmse:6.66636                                                    
[16]	validation-rmse:6.66341                                                    
[17]	validation-rmse:6.65700




[1]	validation-rmse:7.99805                                                     
[2]	validation-rmse:7.32016                                                     
[3]	validation-rmse:6.99108                                                     
[4]	validation-rmse:6.82923                                                     
[5]	validation-rmse:6.75144                                                     
[6]	validation-rmse:6.70832                                                     
[7]	validation-rmse:6.68318                                                     
[8]	validation-rmse:6.66736                                                     
[9]	validation-rmse:6.65837                                                     
[10]	validation-rmse:6.64773                                                    
[11]	validation-rmse:6.64506                                                    
[12]	validation-rmse:6.64218                                                    
[13]	validation-rmse:6.63654




[1]	validation-rmse:8.54934                                                     
[2]	validation-rmse:7.73035                                                     
[3]	validation-rmse:7.26054                                                     
[4]	validation-rmse:6.99435                                                     
[5]	validation-rmse:6.84083                                                     
[6]	validation-rmse:6.75058                                                     
[7]	validation-rmse:6.69162                                                     
[8]	validation-rmse:6.65362                                                     
[9]	validation-rmse:6.63026                                                     
[10]	validation-rmse:6.61189                                                    
[11]	validation-rmse:6.59686                                                    
[12]	validation-rmse:6.58923                                                    
[13]	validation-rmse:6.58286




[1]	validation-rmse:6.73795                                                     
[2]	validation-rmse:6.65901                                                     
[3]	validation-rmse:6.63583                                                     
[4]	validation-rmse:6.62809                                                     
[5]	validation-rmse:6.62139                                                     
[6]	validation-rmse:6.61507                                                     
[7]	validation-rmse:6.60311                                                     
[8]	validation-rmse:6.59505                                                     
[9]	validation-rmse:6.59221                                                     
[10]	validation-rmse:6.58702                                                    
[11]	validation-rmse:6.58283                                                    
[12]	validation-rmse:6.57815                                                    
[13]	validation-rmse:6.57192




[2]	validation-rmse:6.70379                                                     
[3]	validation-rmse:6.67906                                                     
[4]	validation-rmse:6.67543                                                     
[5]	validation-rmse:6.66840                                                     
[6]	validation-rmse:6.66298                                                     
[7]	validation-rmse:6.65132                                                     
[8]	validation-rmse:6.64918                                                     
[9]	validation-rmse:6.64167                                                     
[10]	validation-rmse:6.63428                                                    
[11]	validation-rmse:6.62756                                                    
[12]	validation-rmse:6.62061                                                    
[13]	validation-rmse:6.61673                                                    
[14]	validation-rmse:6.61438




[2]	validation-rmse:8.54068                                                     
[3]	validation-rmse:7.95462                                                     
[4]	validation-rmse:7.55100                                                     
[5]	validation-rmse:7.27483                                                     
[6]	validation-rmse:7.08630                                                     
[7]	validation-rmse:6.95856                                                     
[8]	validation-rmse:6.86738                                                     
[9]	validation-rmse:6.80209                                                     
[10]	validation-rmse:6.75533                                                    
[11]	validation-rmse:6.72253                                                    
[12]	validation-rmse:6.69947                                                    
[13]	validation-rmse:6.68126                                                    
[14]	validation-rmse:6.66662




[1]	validation-rmse:6.75025                                                     
[2]	validation-rmse:6.72276                                                     
[3]	validation-rmse:6.71482                                                     
[4]	validation-rmse:6.71140                                                     
[5]	validation-rmse:6.70145                                                     
[6]	validation-rmse:6.70093                                                     
[7]	validation-rmse:6.69821                                                     
[8]	validation-rmse:6.69480                                                     
[9]	validation-rmse:6.68780                                                     
[10]	validation-rmse:6.67989                                                    
[11]	validation-rmse:6.67540                                                    
[12]	validation-rmse:6.67242                                                    
[13]	validation-rmse:6.66972




[3]	validation-rmse:6.72269                                                     
[4]	validation-rmse:6.68022                                                     
[5]	validation-rmse:6.65597                                                     
[6]	validation-rmse:6.64924                                                     
[7]	validation-rmse:6.64129                                                     
[8]	validation-rmse:6.63524                                                     
[9]	validation-rmse:6.62865                                                     
[10]	validation-rmse:6.62555                                                    
[11]	validation-rmse:6.62127                                                    
[12]	validation-rmse:6.61806                                                    
[13]	validation-rmse:6.61426                                                    
[14]	validation-rmse:6.61081                                                    
[15]	validation-rmse:6.60853




[1]	validation-rmse:10.63236                                                    
[2]	validation-rmse:10.00499                                                    
[3]	validation-rmse:9.46983                                                     
[4]	validation-rmse:9.01274                                                     
[5]	validation-rmse:8.62726                                                     
[6]	validation-rmse:8.30292                                                     
[7]	validation-rmse:8.03034                                                     
[8]	validation-rmse:7.80268                                                     
[9]	validation-rmse:7.61021                                                     
[10]	validation-rmse:7.45183                                                    
[11]	validation-rmse:7.31781                                                    
[12]	validation-rmse:7.20633                                                    
[13]	validation-rmse:7.11434




[3]	validation-rmse:10.76076                                                    
[4]	validation-rmse:10.45624                                                    
[5]	validation-rmse:10.17231                                                    
[6]	validation-rmse:9.90808                                                     
[7]	validation-rmse:9.66230                                                     
[8]	validation-rmse:9.43364                                                     
[9]	validation-rmse:9.22154                                                     
[10]	validation-rmse:9.02530                                                    
[11]	validation-rmse:8.84326                                                    
[12]	validation-rmse:8.67464                                                    
[13]	validation-rmse:8.51902                                                    
[14]	validation-rmse:8.37511                                                    
[15]	validation-rmse:8.24226




[3]	validation-rmse:6.91707                                                     
[4]	validation-rmse:6.78676                                                     
[5]	validation-rmse:6.72106                                                     
[6]	validation-rmse:6.68491                                                     
[7]	validation-rmse:6.66770                                                     
[8]	validation-rmse:6.65334                                                     
[9]	validation-rmse:6.64257                                                     
[10]	validation-rmse:6.63878                                                    
[11]	validation-rmse:6.63440                                                    
[12]	validation-rmse:6.63002                                                    
[13]	validation-rmse:6.62577                                                    
[14]	validation-rmse:6.62303                                                    
[15]	validation-rmse:6.62118
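Each log block above is one hyperopt trial. The next cell picks the configuration with the lowest validation RMSE, which the author read off the mlflow UI. The selection rule itself is just a minimum over the trials. Here is a minimal stdlib sketch of it, assuming each trial is available as a record with its parameters and metrics. The helper `best_trial` and the dummy records are hypothetical, not values from this experiment.

```python
from typing import Any


def best_trial(trials: list[dict[str, Any]], metric: str = "rmse") -> dict[str, Any]:
    """Return the trial with the lowest value of `metric` (lower is better for RMSE)."""
    if not trials:
        raise ValueError("no trials to choose from")
    return min(trials, key=lambda t: t["metrics"][metric])


# Dummy records for illustration only; real ones would come from the
# tracking backend (the runs listed in the mlflow UI).
trials = [
    {"params": {"max_depth": 20}, "metrics": {"rmse": 6.5}},
    {"params": {"max_depth": 5}, "metrics": {"rmse": 8.2}},
    {"params": {"max_depth": 30}, "metrics": {"rmse": 6.6}},
]
print(best_trial(trials)["params"])  # {'max_depth': 20}
```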

In [35]:
# Training of the best model as inferred from the mlflow ui, by choosing the model that minimizes the RMSE loss
params = {
    "learning_rate": 0.6424019772458974,
    "max_depth": 20,
    "min_child_weight": 2.2694144028711833,
    # "reg:squarederror" is the current name of the deprecated "reg:linear" objective
    "objective": "reg:squarederror",
    "reg_alpha": 0.025551415216516424,
    "reg_lambda": 0.009147735459264332,
    "seed": 42,
}

# Enable autologging, which automatically logs the parameters, per-round metrics and trained model of the run
mlflow.xgboost.autolog()

booster = xgb.train(
    params=params,
    dtrain=train,
    num_boost_round=100,
    evals=[(valid, "validation")],
    early_stopping_rounds=15,
)

2024/04/19 16:27:51 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'b718c906f65f418fb57e0bc5c2b58959', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current xgboost workflow


[0]	validation-rmse:7.79040




[1]	validation-rmse:6.87496
[2]	validation-rmse:6.67857
[3]	validation-rmse:6.62456
[4]	validation-rmse:6.60289
[5]	validation-rmse:6.59332
[6]	validation-rmse:6.57983
[7]	validation-rmse:6.57435
[8]	validation-rmse:6.56953
[9]	validation-rmse:6.56221
[10]	validation-rmse:6.55691
[11]	validation-rmse:6.55156
[12]	validation-rmse:6.54857
[13]	validation-rmse:6.54374
[14]	validation-rmse:6.53648
[15]	validation-rmse:6.53111
[16]	validation-rmse:6.52934
[17]	validation-rmse:6.52656
[18]	validation-rmse:6.52193
[19]	validation-rmse:6.51718
[20]	validation-rmse:6.51076
[21]	validation-rmse:6.50723
[22]	validation-rmse:6.50659
[23]	validation-rmse:6.50324
[24]	validation-rmse:6.49680
[25]	validation-rmse:6.49462
[26]	validation-rmse:6.49068
[27]	validation-rmse:6.48778
[28]	validation-rmse:6.48296
[29]	validation-rmse:6.48014
[30]	validation-rmse:6.47863
[31]	validation-rmse:6.47687
[32]	validation-rmse:6.47484
[33]	validation-rmse:6.47093
[34]	validation-rmse:6.46951
[35]	validation-rmse:6.



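With `early_stopping_rounds=15`, XGBoost stops boosting once the validation RMSE has not improved for 15 consecutive rounds and remembers the best round. `num_boost_round=100` is only an upper bound. The rule can be replayed on a list of per-round scores. The helper `early_stopping` below is a hypothetical sketch of that rule (strict improvement, lower is better), not XGBoost's internal callback.

```python
def early_stopping(scores: list[float], patience: int) -> tuple[int, int]:
    """Replay the early-stopping rule on per-round validation scores (lower is better).

    Returns (best_round, stop_round): the round with the lowest score, and the
    round at which training stops (the last round if patience never runs out).
    """
    best_round, best_score = 0, float("inf")
    for rnd, score in enumerate(scores):
        if score < best_score:
            # Strict improvement: reset the patience window.
            best_round, best_score = rnd, score
        elif rnd - best_round >= patience:
            # `patience` consecutive rounds without improvement.
            return best_round, rnd
    return best_round, len(scores) - 1


# Dummy scores: the best value is at round 2, then three rounds without improvement.
print(early_stopping([7.8, 6.9, 6.7, 6.8, 6.75, 6.72], patience=3))  # (2, 5)
```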
### Model management

Logging parameters and metrics tracks what happens during model training, but not the trained model itself. The model can be tracked in mlflow too, which has the following benefits compared to folder-based logging (saving model files to a folder by hand):
* automated instead of manual (less error-prone)
* reliable versioning
* model lineage (users can retrieve how the model was trained: parameters, dataset, etc.)
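To make these benefits concrete, consider what folder-based logging would need in order to offer versioning and lineage. Every run needs its own ID and folder, and the parameters and dataset must be stored next to the pickled model. The stdlib sketch below uses hypothetical helpers `save_run`/`load_run` and does not show how mlflow stores runs internally. It only shows the bookkeeping that mlflow automates.

```python
import json
import pickle
import tempfile
import uuid
from pathlib import Path


def save_run(root: Path, model: object, params: dict, dataset: str) -> str:
    """Store a pickled model plus the metadata needed to trace how it was trained.

    Every call gets a fresh run ID and folder, so earlier versions are never overwritten.
    """
    run_id = uuid.uuid4().hex
    run_dir = root / run_id
    run_dir.mkdir(parents=True)
    with open(run_dir / "model.pkl", "wb") as f_out:
        pickle.dump(model, f_out)
    lineage = {"params": params, "dataset": dataset}
    (run_dir / "lineage.json").write_text(json.dumps(lineage))
    return run_id


def load_run(root: Path, run_id: str) -> tuple[object, dict]:
    """Return the model and its lineage metadata for a given run ID."""
    run_dir = root / run_id
    with open(run_dir / "model.pkl", "rb") as f_in:
        model = pickle.load(f_in)
    lineage = json.loads((run_dir / "lineage.json").read_text())
    return model, lineage


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Plain dicts stand in for trained models here.
    run_a = save_run(root, {"coef": [0.5]}, {"max_depth": 20}, "green_tripdata_2021-01.parquet")
    run_b = save_run(root, {"coef": [0.7]}, {"max_depth": 30}, "green_tripdata_2021-01.parquet")
    model_a, lineage_a = load_run(root, run_a)
    print(model_a, lineage_a["params"])  # {'coef': [0.5]} {'max_depth': 20}
```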

In [None]:
# First approach at logging the model itself in mlflow: logging as an artifact
mlflow.log_artifact(MODEL_FOLDER / 'lin_reg.bin', artifact_path='models_pickle')

In [35]:
# Second approach: log the model inside of a run, using the appropriate mlflow API

with mlflow.start_run():
    # Training of the best model as inferred from the mlflow ui, by choosing the model that minimizes the RMSE loss
    params = {
        "learning_rate": 0.6424019772458974,
        "max_depth": 20,
        "min_child_weight": 2.2694144028711833,
        # "reg:squarederror" is the current name of the deprecated "reg:linear" objective
        "objective": "reg:squarederror",
        "reg_alpha": 0.025551415216516424,
        "reg_lambda": 0.009147735459264332,
        "seed": 42,
    }
    
    mlflow.log_params(params)


    booster = xgb.train(
        params=params,
        dtrain=train,
        num_boost_round=100,
        evals=[(valid, "validation")],
        early_stopping_rounds=15,
    )
    
    # Logging it as a model using the mlflow API for the model framework (xgboost)
    mlflow.xgboost.log_model(booster, artifact_path='models_mlflow')
    
    # Logging the preprocessor as an artifact
    # First we save it as pickle, and then we log it by referencing the path
    with open(MODEL_FOLDER / 'preprocessor.b', 'wb') as f_out:
        pickle.dump(dv, f_out)

    mlflow.log_artifact(MODEL_FOLDER / "preprocessor.b", artifact_path="preprocessor")

[0]	validation-rmse:7.79040




[1]	validation-rmse:6.87496
[2]	validation-rmse:6.67857
[3]	validation-rmse:6.62456
[4]	validation-rmse:6.60289
[5]	validation-rmse:6.59332
[6]	validation-rmse:6.57983
[7]	validation-rmse:6.57435
[8]	validation-rmse:6.56953
[9]	validation-rmse:6.56221
[10]	validation-rmse:6.55691
[11]	validation-rmse:6.55156
[12]	validation-rmse:6.54857
[13]	validation-rmse:6.54374
[14]	validation-rmse:6.53648
[15]	validation-rmse:6.53111
[16]	validation-rmse:6.52934
[17]	validation-rmse:6.52656
[18]	validation-rmse:6.52193
[19]	validation-rmse:6.51718
[20]	validation-rmse:6.51076
[21]	validation-rmse:6.50723
[22]	validation-rmse:6.50659
[23]	validation-rmse:6.50324
[24]	validation-rmse:6.49680
[25]	validation-rmse:6.49462
[26]	validation-rmse:6.49068
[27]	validation-rmse:6.48778
[28]	validation-rmse:6.48296
[29]	validation-rmse:6.48014
[30]	validation-rmse:6.47863
[31]	validation-rmse:6.47687
[32]	validation-rmse:6.47484
[33]	validation-rmse:6.47093
[34]	validation-rmse:6.46951
[35]	validation-rmse:6.



In the code above we also logged the `DictVectorizer`. It is a preprocessor, not part of the model: it converts the raw records into the feature matrix that the model is trained on.

Later, when we want predictions for new raw data, that data must be preprocessed exactly as the training data was, with the same fitted vectorizer (same features, same column order). Fitting a fresh vectorizer on the new data would produce a different feature layout, so it is essential to log the fitted preprocessor for future use. Since it is not an XGBoost model, we log it as a generic artifact.
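The point about reusing the fitted preprocessor can be shown without sklearn. `TinyDictVectorizer` below is a hypothetical, simplified stand-in for `DictVectorizer`. It one-hot encodes string values as `key=value`, keeps numeric values, sorts the feature names, and ignores features unseen at fit time. A pickle round-trip, roughly what logging and reloading the artifact amounts to, keeps the learned layout. Re-fitting on new data does not. The example records are made up.

```python
import pickle


class TinyDictVectorizer:
    """Simplified stand-in for sklearn's DictVectorizer (illustration only)."""

    @staticmethod
    def _name(key: str, value) -> str:
        # Strings become one-hot features "key=value"; numbers keep the key.
        return f"{key}={value}" if isinstance(value, str) else key

    def fit(self, records: list[dict]) -> "TinyDictVectorizer":
        names = {self._name(k, v) for rec in records for k, v in rec.items()}
        self.feature_names_ = sorted(names)
        self.index_ = {name: i for i, name in enumerate(self.feature_names_)}
        return self

    def transform(self, records: list[dict]) -> list[list[float]]:
        rows = []
        for rec in records:
            row = [0.0] * len(self.feature_names_)
            for key, value in rec.items():
                name = self._name(key, value)
                if name in self.index_:  # features unseen at fit time are dropped
                    row[self.index_[name]] = 1.0 if isinstance(value, str) else float(value)
            rows.append(row)
        return rows


train_records = [
    {"PULocationID": "43", "trip_distance": 1.2},
    {"PULocationID": "166", "trip_distance": 3.4},
]
dv = TinyDictVectorizer().fit(train_records)

# Pickle round-trip, as when the preprocessor is logged and later reloaded.
dv_loaded = pickle.loads(pickle.dumps(dv))

new_records = [
    {"PULocationID": "43", "trip_distance": 2.0},
    {"PULocationID": "999", "trip_distance": 5.0},
]
print(dv_loaded.feature_names_)  # ['PULocationID=166', 'PULocationID=43', 'trip_distance']
print(dv_loaded.transform(new_records))  # [[0.0, 1.0, 2.0], [0.0, 0.0, 5.0]]

# Re-fitting on the new data gives a different column layout, which the model was never trained on.
print(TinyDictVectorizer().fit(new_records).feature_names_)
```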

To make predictions, we load the model back from mlflow. For our model there are two "flavours" we can use: `pyfunc`, which loads it as a generic Python function exposing a `predict` method, and `xgboost`, which loads it as a native xgboost object (giving back the original `Booster`).
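The trade-off between the two flavours can be illustrated without mlflow. A generic wrapper gives every model the same `predict` entry point, whatever library it comes from. The native object keeps its library-specific API and attributes. The classes below are hypothetical stand-ins that illustrate the idea. They are not mlflow's implementation.

```python
from typing import Any, Callable


class NativeDurationModel:
    """Hypothetical library-specific model with its own, non-standard API."""

    def __init__(self, minutes_per_mile: float) -> None:
        self.minutes_per_mile = minutes_per_mile

    def score_rows(self, distances: list[float]) -> list[float]:
        return [self.minutes_per_mile * d for d in distances]


class GenericModel:
    """Uniform wrapper in the spirit of a pyfunc model: callers only use .predict()."""

    def __init__(self, native: Any, predict_fn: Callable[[Any, list], list]) -> None:
        self._native = native          # the original library object
        self._predict_fn = predict_fn  # how to call it, supplied per flavour

    def predict(self, data: list) -> list:
        return self._predict_fn(self._native, data)


native = NativeDurationModel(minutes_per_mile=3.0)
generic = GenericModel(native, lambda model, rows: model.score_rows(rows))

print(generic.predict([1.0, 2.5]))  # [3.0, 7.5]  (same call for any wrapped model)
print(native.minutes_per_mile)      # 3.0  (native access to library-specific attributes)
```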

In [36]:
# Load model as a PyFuncModel.
logged_model = 'runs:/0ebd7485b47c43ab9b53ab72548bccde/models_mlflow'

loaded_model = mlflow.pyfunc.load_model(logged_model)

loaded_model



mlflow.pyfunc.loaded_model:
  artifact_path: models_mlflow
  flavor: mlflow.xgboost
  run_id: 0ebd7485b47c43ab9b53ab72548bccde

In [38]:
# Load the same model with the xgboost flavour, which returns the native Booster object
xgboost_model = mlflow.xgboost.load_model(logged_model)

xgboost_model



<xgboost.core.Booster at 0x3515d8950>

In [40]:
y_pred = xgboost_model.predict(valid)
y_pred[:10]

array([15.175811 ,  7.352664 , 17.564875 , 24.57404  ,  9.3976755,
       17.19716  , 10.984803 ,  8.377045 ,  8.945719 , 19.10443  ],
      dtype=float32)
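These predictions are trip durations in minutes, matching the `duration` target computed earlier. To score them against the true validation durations, one would compute the RMSE, the same metric XGBoost reports as `validation-rmse`. The notebook imports sklearn's `root_mean_squared_error` for this. The formula itself fits in a short stdlib sketch. It assumes the true durations are available as a plain list, and the `rmse` helper is hypothetical.

```python
import math


def rmse(y_true: list[float], y_pred: list[float]) -> float:
    """Root mean squared error: sqrt(mean((y_true - y_pred) ** 2))."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("inputs must be non-empty and of the same length")
    squared = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    return math.sqrt(squared / len(y_true))


print(rmse([10.0, 20.0], [13.0, 16.0]))  # sqrt((9 + 16) / 2) ≈ 3.5355
```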