In [17]:
import pandas as pd

from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.metrics import root_mean_squared_error

import mlflow

from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

import xgboost as xgb

import pickle

In [10]:
# Silence pandas' chained-assignment warning for the whole session (pandas default: 'warn').
pd.set_option("mode.chained_assignment", None)

In [11]:
# Runs are recorded in a local SQLite file; browse them with:
#   mlflow ui --backend-store-uri sqlite:///mlflow.db
mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_experiment(experiment_name="nyc-taxi-experiment")

<Experiment: artifact_location='/Users/bastienwinant/Desktop/Projects/mlops-zoomcamp/02-experiment-tracking/mlruns/1', creation_time=1721809120839, experiment_id='1', last_update_time=1721809120839, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

## Data Processing

In [12]:
def import_data(url):
  """Read the parquet file located at `url` and return it as a DataFrame."""
  frame = pd.read_parquet(url)
  return frame

In [13]:
def process_data(df):
  """Turn raw trip records into the frame used for modelling.

  The input frame is left untouched; a new frame is returned.

  Parameters
  ----------
  df : pandas.DataFrame
    Must provide `lpep_pickup_datetime`, `lpep_dropoff_datetime`,
    `PULocationID`, `DOLocationID` and `trip_distance`.

  Returns
  -------
  pandas.DataFrame
    Columns, in order: `PULocationID`, `DOLocationID`, `PU_DO`,
    `trip_distance`, `duration`. Only trips whose duration lies in
    [1, 60] minutes are kept; the location ids are cast to `str` and
    `PU_DO` is the string "<PULocationID>_<DOLocationID>".
  """
  # Vectorised timedelta -> minutes (no per-row Python lambda).
  duration = (df.lpep_dropoff_datetime - df.lpep_pickup_datetime).dt.total_seconds() / 60

  # Rows with a missing timestamp give NaN here and fail both comparisons,
  # so they are dropped together with the out-of-range trips.
  keep = (duration >= 1) & (duration <= 60)

  categorical = ['PULocationID', 'DOLocationID']

  # Explicit copy: every assignment below targets a frame we own, so the
  # caller's data is never mutated and no chained assignment takes place.
  trips = df.loc[keep, categorical + ['trip_distance']].copy()
  trips['duration'] = duration[keep]

  trips[categorical] = trips[categorical].astype(str)
  trips['PU_DO'] = trips.PULocationID + '_' + trips.DOLocationID

  return trips[categorical + ['PU_DO', 'trip_distance', 'duration']]

In [14]:
def transform_data(df, dv=None):
  """Vectorize the predictors of `df` and extract the target.

  Parameters
  ----------
  df : pandas.DataFrame
    Must contain the columns `PU_DO`, `trip_distance` and `duration`.
  dv : optional
    An already fitted vectorizer exposing `transform`. When omitted
    (`None`), a new `DictVectorizer` is created and fitted on `df`.

  Returns
  -------
  tuple
    `(X, y, dv)`: the feature matrix as returned by the vectorizer, the
    `duration` values, and the vectorizer that produced `X`.
  """
  predictors = ['PU_DO', 'trip_distance']
  target = 'duration'

  df_dicts = df[predictors].to_dict(orient='records')

  # Compare against None explicitly: relying on the object's truthiness would
  # silently refit a new vectorizer if a supplied one ever evaluated as falsy.
  if dv is None:
    dv = DictVectorizer()
    X = dv.fit_transform(df_dicts)
  else:
    X = dv.transform(df_dicts)

  y = df[target].values

  return X, y, dv

In [15]:
def compute_error(X, y, model):
  """Return the root mean squared error of `model` on `(X, y)`.

  Parameters
  ----------
  X : feature matrix accepted by `model.predict`.
  y : true target values aligned with the rows of `X`.
  model : fitted estimator exposing `predict`.
  """
  preds = model.predict(X)

  # scikit-learn metrics are declared as (y_true, y_pred); keep that order so
  # the call stays correct if sample_weight / multioutput are ever passed.
  return root_mean_squared_error(y, preds)

In [16]:
# Green-taxi trip files: 2021-01 is used for training, 2021-02 for validation.
train_url = "https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-01.parquet"
val_url = "https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-02.parquet"

# The vectorizer is fitted on the training frame only ...
train_df = process_data(import_data(train_url))
X_train, y_train, dv = transform_data(train_df)

# ... and reused as-is to encode the validation frame.
val_df = process_data(import_data(val_url))
X_val, y_val, _ = transform_data(val_df, dv=dv)

## Manual Logging

In [9]:
# Fit a Lasso model and score it on the validation data.
alpha = 0.01
model = Lasso(alpha=alpha).fit(X_train, y_train)
rmse = compute_error(X_val, y_val, model)

# Record the run by hand: who ran it, which data and alpha, and the score.
with mlflow.start_run():
  mlflow.set_tag("developer", "Bastien Winant")
  mlflow.log_params(dict(train_data=train_url, val_data=val_url, alpha=alpha))
  mlflow.log_metric("rmse", rmse)

## Hyperparameter Tuning

In [19]:
# Wrap both splits in XGBoost's DMatrix container (features + labels).
train, valid = (
  xgb.DMatrix(features, label=target)
  for features, target in ((X_train, y_train), (X_val, y_val))
)

In [11]:
def objective(params):
  """Hyperopt objective: train one XGBoost model and report its validation RMSE.

  Each call is recorded as its own MLflow run (tag `model=xgboost`, the
  sampled `params`, and the `rmse` metric).

  Parameters
  ----------
  params : dict
    XGBoost training parameters sampled from the search space.

  Returns
  -------
  dict
    `{"loss": <validation rmse>, "status": STATUS_OK}` as expected by `fmin`.
  """
  # Open the run before training so the parameters are recorded up front and
  # an exception raised by xgb.train still leaves a run behind for inspection.
  with mlflow.start_run():
    mlflow.set_tag("model", "xgboost")
    mlflow.log_params(params)

    booster = xgb.train(
      params=params,
      dtrain=train,
      num_boost_round=1000,
      evals=[(valid, "validation")],
      early_stopping_rounds=50
    )

    # The native Booster.predict uses every trained round, including the ones
    # added after the best iteration while waiting for early stopping to
    # trigger. Restrict prediction to the best iteration so the reported loss
    # belongs to the model early stopping actually selected.
    preds = booster.predict(valid, iteration_range=(0, booster.best_iteration + 1))

    # scikit-learn metrics are declared as (y_true, y_pred).
    rmse = root_mean_squared_error(y_val, preds)
    mlflow.log_metric("rmse", rmse)

  return {"loss": rmse, "status": STATUS_OK}

In [12]:
# hp.loguniform(label, low, high) samples exp(uniform(low, high)), so the
# bounds below are natural-log exponents, not the parameter values themselves.
search_space = {
  # integer tree depth in [4, 100]
  'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
  # exp(-3) .. exp(0)  ~ 0.05 .. 1
  'learning_rate': hp.loguniform('learning_rate', -3, 0),
  # exp(-5) .. exp(-1) ~ 0.0067 .. 0.37
  'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
  # exp(-6) .. exp(-1) ~ 0.0025 .. 0.37
  'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
  # exp(-1) .. exp(3)  ~ 0.37 .. 20
  'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
  # 'reg:linear' is a deprecated alias of this objective; use the current name.
  'objective': 'reg:squarederror',
  'seed': 42
}

# Run 50 TPE-guided trials; every trial is logged to MLflow by `objective`.
best_result = fmin(
  fn=objective,
  space=search_space,
  algo=tpe.suggest,
  max_evals=50,
  trials=Trials()
)

  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:10.95707                          
[1]	validation-rmse:9.96230                           
[2]	validation-rmse:9.18036                           
[3]	validation-rmse:8.58192                           
[4]	validation-rmse:8.09857                           
[5]	validation-rmse:7.74662                           
[6]	validation-rmse:7.46771                           
[7]	validation-rmse:7.25924                           
[8]	validation-rmse:7.10225                           
[9]	validation-rmse:6.98494                           
[10]	validation-rmse:6.89293                          
[11]	validation-rmse:6.82250                          
[12]	validation-rmse:6.76023                          
[13]	validation-rmse:6.71796                          
[14]	validation-rmse:6.67723                          
[15]	validation-rmse:6.65204                          
[16]	validation-rmse:6.63042                          
[17]	validation-rmse:6.60816                          
[18]	valid




[0]	validation-rmse:6.81113                                                    
[1]	validation-rmse:6.71755                                                    
[2]	validation-rmse:6.69675                                                    
[3]	validation-rmse:6.68679                                                    
[4]	validation-rmse:6.67982                                                    
[5]	validation-rmse:6.66945                                                    
[6]	validation-rmse:6.66564                                                    
[7]	validation-rmse:6.66124                                                    
[8]	validation-rmse:6.65773                                                    
[9]	validation-rmse:6.65556                                                    
[10]	validation-rmse:6.65565                                                   
[11]	validation-rmse:6.65404                                                   
[12]	validation-rmse:6.65177            




[0]	validation-rmse:7.33400                                                    
[1]	validation-rmse:6.79245                                                    
[2]	validation-rmse:6.70559                                                    
[3]	validation-rmse:6.68207                                                    
[4]	validation-rmse:6.66576                                                    
[5]	validation-rmse:6.65872                                                    
[6]	validation-rmse:6.65516                                                    
[7]	validation-rmse:6.65113                                                    
[8]	validation-rmse:6.64680                                                    
[9]	validation-rmse:6.64428                                                    
[10]	validation-rmse:6.64231                                                   
[11]	validation-rmse:6.64034                                                   
[12]	validation-rmse:6.63634            




[0]	validation-rmse:6.80168                                                    
[1]	validation-rmse:6.59345                                                    
[2]	validation-rmse:6.55442                                                    
[3]	validation-rmse:6.53529                                                    
[4]	validation-rmse:6.52296                                                    
[5]	validation-rmse:6.50893                                                    
[6]	validation-rmse:6.49989                                                    
[7]	validation-rmse:6.49396                                                    
[8]	validation-rmse:6.48841                                                    
[9]	validation-rmse:6.48589                                                    
[10]	validation-rmse:6.47980                                                   
[11]	validation-rmse:6.47596                                                   
[12]	validation-rmse:6.47398            




[0]	validation-rmse:11.56814                                                   
[1]	validation-rmse:10.99070                                                   
[2]	validation-rmse:10.47490                                                   
[3]	validation-rmse:10.01667                                                   
[4]	validation-rmse:9.60921                                                    
[5]	validation-rmse:9.24880                                                    
[6]	validation-rmse:8.92919                                                    
[7]	validation-rmse:8.64817                                                    
[8]	validation-rmse:8.40012                                                    
[9]	validation-rmse:8.18317                                                    
[10]	validation-rmse:7.99232                                                   
[11]	validation-rmse:7.82612                                                   
[12]	validation-rmse:7.67997            




[0]	validation-rmse:10.87644                                                   
[1]	validation-rmse:9.83786                                                    
[2]	validation-rmse:9.03933                                                    
[3]	validation-rmse:8.43355                                                    
[4]	validation-rmse:7.98002                                                    
[5]	validation-rmse:7.64145                                                    
[6]	validation-rmse:7.38930                                                    
[7]	validation-rmse:7.20213                                                    
[8]	validation-rmse:7.06482                                                    
[9]	validation-rmse:6.96354                                                    
[10]	validation-rmse:6.88283                                                   
[11]	validation-rmse:6.82292                                                   
[12]	validation-rmse:6.78022            




[0]	validation-rmse:10.18991                                                   
[1]	validation-rmse:8.86785                                                    
[2]	validation-rmse:8.03064                                                    
[3]	validation-rmse:7.51558                                                    
[4]	validation-rmse:7.19990                                                    
[5]	validation-rmse:7.00716                                                    
[6]	validation-rmse:6.88769                                                    
[7]	validation-rmse:6.80465                                                    
[8]	validation-rmse:6.75278                                                    
[9]	validation-rmse:6.71757                                                    
[10]	validation-rmse:6.69207                                                   
[11]	validation-rmse:6.67228                                                   
[12]	validation-rmse:6.65597            




[0]	validation-rmse:11.09489                                                   
[1]	validation-rmse:10.18079                                                   
[2]	validation-rmse:9.44151                                                    
[3]	validation-rmse:8.85434                                                    
[4]	validation-rmse:8.37786                                                    
[5]	validation-rmse:8.00544                                                    
[6]	validation-rmse:7.70856                                                    
[7]	validation-rmse:7.48015                                                    
[8]	validation-rmse:7.29993                                                    
[9]	validation-rmse:7.15947                                                    
[10]	validation-rmse:7.04529                                                   
[11]	validation-rmse:6.95917                                                   
[12]	validation-rmse:6.88952            




[0]	validation-rmse:6.95015                                                    
[1]	validation-rmse:6.64166                                                    
[2]	validation-rmse:6.58851                                                    
[3]	validation-rmse:6.57316                                                    
[4]	validation-rmse:6.55397                                                    
[5]	validation-rmse:6.54437                                                    
[6]	validation-rmse:6.53775                                                    
[7]	validation-rmse:6.52358                                                    
[8]	validation-rmse:6.51955                                                    
[9]	validation-rmse:6.51728                                                    
[10]	validation-rmse:6.50999                                                   
[11]	validation-rmse:6.50566                                                   
[12]	validation-rmse:6.50221            




[0]	validation-rmse:11.76945                                                   
[1]	validation-rmse:11.35674                                                   
[2]	validation-rmse:10.97303                                                   
[3]	validation-rmse:10.61703                                                   
[4]	validation-rmse:10.28675                                                   
[5]	validation-rmse:9.98057                                                    
[6]	validation-rmse:9.69764                                                    
[7]	validation-rmse:9.43593                                                    
[8]	validation-rmse:9.19407                                                    
[9]	validation-rmse:8.97161                                                    
[10]	validation-rmse:8.76668                                                   
[11]	validation-rmse:8.57804                                                   
[12]	validation-rmse:8.40448            




[0]	validation-rmse:10.08356                                                    
[1]	validation-rmse:8.72239                                                     
[2]	validation-rmse:7.86665                                                     
[3]	validation-rmse:7.35192                                                     
[4]	validation-rmse:7.04658                                                     
[5]	validation-rmse:6.87956                                                     
[6]	validation-rmse:6.75991                                                     
[7]	validation-rmse:6.69050                                                     
[8]	validation-rmse:6.64336                                                     
[9]	validation-rmse:6.61024                                                     
[10]	validation-rmse:6.58906                                                    
[11]	validation-rmse:6.57528                                                    
[12]	validation-rmse:6.56311




[2]	validation-rmse:8.13680                                                     
[3]	validation-rmse:7.61290                                                     
[4]	validation-rmse:7.28509                                                     
[5]	validation-rmse:7.08385                                                     
[6]	validation-rmse:6.95556                                                     
[7]	validation-rmse:6.87373                                                     
[8]	validation-rmse:6.81993                                                     
[9]	validation-rmse:6.78496                                                     
[10]	validation-rmse:6.76054                                                    
[11]	validation-rmse:6.74223                                                    
[12]	validation-rmse:6.73099                                                    
[13]	validation-rmse:6.72483                                                    
[14]	validation-rmse:6.71718




[0]	validation-rmse:10.24661                                                    
[1]	validation-rmse:8.93848                                                     
[2]	validation-rmse:8.09335                                                     
[3]	validation-rmse:7.54947                                                     
[4]	validation-rmse:7.20930                                                     
[5]	validation-rmse:6.98844                                                     
[6]	validation-rmse:6.84181                                                     
[7]	validation-rmse:6.75444                                                     
[8]	validation-rmse:6.69291                                                     
[9]	validation-rmse:6.64571                                                     
[10]	validation-rmse:6.61640                                                    
[11]	validation-rmse:6.60005                                                    
[12]	validation-rmse:6.58249




[0]	validation-rmse:10.40461                                                    
[1]	validation-rmse:9.14906                                                     
[2]	validation-rmse:8.29696                                                     
[3]	validation-rmse:7.73262                                                     
[4]	validation-rmse:7.36233                                                     
[5]	validation-rmse:7.11952                                                     
[6]	validation-rmse:6.95713                                                     
[7]	validation-rmse:6.84725                                                     
[8]	validation-rmse:6.77178                                                     
[9]	validation-rmse:6.72012                                                     
[10]	validation-rmse:6.68482                                                    
[11]	validation-rmse:6.65843                                                    
[12]	validation-rmse:6.63848




[0]	validation-rmse:10.67526                                                    
[1]	validation-rmse:9.53298                                                     
[2]	validation-rmse:8.69926                                                     
[3]	validation-rmse:8.10009                                                     
[4]	validation-rmse:7.67473                                                     
[5]	validation-rmse:7.37522                                                     
[6]	validation-rmse:7.16378                                                     
[7]	validation-rmse:7.01212                                                     
[8]	validation-rmse:6.90350                                                     
[9]	validation-rmse:6.82631                                                     
[10]	validation-rmse:6.76810                                                    
[11]	validation-rmse:6.72735                                                    
[12]	validation-rmse:6.69512




[0]	validation-rmse:8.19533                                                     
[1]	validation-rmse:7.06807                                                     
[2]	validation-rmse:6.76123                                                     
[3]	validation-rmse:6.65277                                                     
[4]	validation-rmse:6.60831                                                     
[5]	validation-rmse:6.58590                                                     
[6]	validation-rmse:6.57212                                                     
[7]	validation-rmse:6.56800                                                     
[8]	validation-rmse:6.56086                                                     
[9]	validation-rmse:6.55703                                                     
[10]	validation-rmse:6.55217                                                    
[11]	validation-rmse:6.55205                                                    
[12]	validation-rmse:6.54767




[0]	validation-rmse:9.07158                                                     
[1]	validation-rmse:7.67394                                                     
[2]	validation-rmse:7.09604                                                     
[3]	validation-rmse:6.85229                                                     
[4]	validation-rmse:6.73868                                                     
[5]	validation-rmse:6.68676                                                     
[6]	validation-rmse:6.65803                                                     
[7]	validation-rmse:6.64070                                                     
[8]	validation-rmse:6.62848                                                     
[9]	validation-rmse:6.61816                                                     
[10]	validation-rmse:6.61129                                                    
[11]	validation-rmse:6.60840                                                    
[12]	validation-rmse:6.60579




[0]	validation-rmse:11.61771                                                    
[1]	validation-rmse:11.07948                                                    
[2]	validation-rmse:10.59363                                                    
[3]	validation-rmse:10.15611                                                    
[4]	validation-rmse:9.76281                                                     
[5]	validation-rmse:9.41046                                                     
[6]	validation-rmse:9.09447                                                     
[7]	validation-rmse:8.81268                                                     
[8]	validation-rmse:8.56155                                                     
[9]	validation-rmse:8.33768                                                     
[10]	validation-rmse:8.14017                                                    
[11]	validation-rmse:7.96380                                                    
[12]	validation-rmse:7.80744




[0]	validation-rmse:11.41874                                                    
[1]	validation-rmse:10.73022                                                    
[2]	validation-rmse:10.12687                                                    
[3]	validation-rmse:9.60976                                                     
[4]	validation-rmse:9.16056                                                     
[5]	validation-rmse:8.77564                                                     
[6]	validation-rmse:8.44233                                                     
[7]	validation-rmse:8.16160                                                     
[8]	validation-rmse:7.91736                                                     
[9]	validation-rmse:7.71625                                                     
[10]	validation-rmse:7.54245                                                    
[11]	validation-rmse:7.39598                                                    
[12]	validation-rmse:7.27251




[0]	validation-rmse:8.66346                                                     
[1]	validation-rmse:7.27140                                                     
[2]	validation-rmse:6.77639                                                     
[3]	validation-rmse:6.58581                                                     
[4]	validation-rmse:6.51007                                                     
[5]	validation-rmse:6.47049                                                     
[6]	validation-rmse:6.45122                                                     
[7]	validation-rmse:6.43618                                                     
[8]	validation-rmse:6.43077                                                     
[9]	validation-rmse:6.42580                                                     
[10]	validation-rmse:6.42127                                                    
[11]	validation-rmse:6.41756                                                    
[12]	validation-rmse:6.41320




[3]	validation-rmse:10.54233                                                    
[4]	validation-rmse:10.20427                                                    
[5]	validation-rmse:9.89427                                                     
[6]	validation-rmse:9.61037                                                     
[7]	validation-rmse:9.35059                                                     
[8]	validation-rmse:9.11234                                                     
[9]	validation-rmse:8.89595                                                     
[10]	validation-rmse:8.69762                                                    
[11]	validation-rmse:8.51742                                                    
[12]	validation-rmse:8.35368                                                    
[13]	validation-rmse:8.20500                                                    
[14]	validation-rmse:8.06943                                                    
[15]	validation-rmse:7.94640




[0]	validation-rmse:11.26359                                                    
[1]	validation-rmse:10.45936                                                    
[2]	validation-rmse:9.78058                                                     
[3]	validation-rmse:9.21281                                                     
[4]	validation-rmse:8.73921                                                     
[5]	validation-rmse:8.34693                                                     
[6]	validation-rmse:8.02322                                                     
[7]	validation-rmse:7.75786                                                     
[8]	validation-rmse:7.53910                                                     
[9]	validation-rmse:7.36011                                                     
[10]	validation-rmse:7.21515                                                    
[11]	validation-rmse:7.09569                                                    
[12]	validation-rmse:6.99841




[0]	validation-rmse:11.23974                                                     
[1]	validation-rmse:10.41998                                                     
[2]	validation-rmse:9.73072                                                      
[3]	validation-rmse:9.15778                                                      
[4]	validation-rmse:8.68300                                                      
[5]	validation-rmse:8.29199                                                      
[6]	validation-rmse:7.97028                                                      
[7]	validation-rmse:7.70896                                                      
[8]	validation-rmse:7.49362                                                      
[9]	validation-rmse:7.32038                                                      
[10]	validation-rmse:7.17758                                                     
[11]	validation-rmse:7.06149                                                     
[12]	validation-




[0]	validation-rmse:11.27766                                                    
[1]	validation-rmse:10.48304                                                    
[2]	validation-rmse:9.81146                                                     
[3]	validation-rmse:9.24809                                                     
[4]	validation-rmse:8.77597                                                     
[5]	validation-rmse:8.38366                                                     
[6]	validation-rmse:8.05873                                                     
[7]	validation-rmse:7.79101                                                     
[8]	validation-rmse:7.56972                                                     
[9]	validation-rmse:7.38882                                                     
[10]	validation-rmse:7.24043                                                    
[11]	validation-rmse:7.11486                                                    
[12]	validation-rmse:7.01351




[0]	validation-rmse:11.39233                                                    
[1]	validation-rmse:10.68015                                                    
[2]	validation-rmse:10.06273                                                    
[3]	validation-rmse:9.53248                                                     
[4]	validation-rmse:9.07568                                                     
[5]	validation-rmse:8.68729                                                     
[6]	validation-rmse:8.35615                                                     
[7]	validation-rmse:8.07508                                                     
[8]	validation-rmse:7.83855                                                     
[9]	validation-rmse:7.63779                                                     
[10]	validation-rmse:7.46688                                                    
[11]	validation-rmse:7.32390                                                    
[12]	validation-rmse:7.20512




[0]	validation-rmse:11.63625                                                    
[1]	validation-rmse:11.11268                                                    
[2]	validation-rmse:10.63875                                                    
[3]	validation-rmse:10.21020                                                    
[4]	validation-rmse:9.82334                                                     
[5]	validation-rmse:9.47578                                                     
[6]	validation-rmse:9.16215                                                     
[7]	validation-rmse:8.87978                                                     
[8]	validation-rmse:8.62789                                                     
[9]	validation-rmse:8.40325                                                     
[10]	validation-rmse:8.20240                                                    
[11]	validation-rmse:8.02374                                                    
[12]	validation-rmse:7.86369




[0]	validation-rmse:11.45286                                                    
[1]	validation-rmse:10.78809                                                    
[2]	validation-rmse:10.20606                                                    
[3]	validation-rmse:9.69892                                                     
[4]	validation-rmse:9.26078                                                     
[5]	validation-rmse:8.87964                                                     
[6]	validation-rmse:8.55741                                                     
[7]	validation-rmse:8.27337                                                     
[8]	validation-rmse:8.03233                                                     
[9]	validation-rmse:7.82324                                                     
[10]	validation-rmse:7.65000                                                    
[11]	validation-rmse:7.50049                                                    
[12]	validation-rmse:7.37389




[0]	validation-rmse:11.25796                                                    
[1]	validation-rmse:10.45151                                                    
[2]	validation-rmse:9.77497                                                     
[3]	validation-rmse:9.21069                                                     
[4]	validation-rmse:8.74517                                                     
[5]	validation-rmse:8.35766                                                     
[6]	validation-rmse:8.04330                                                     
[7]	validation-rmse:7.78465                                                     
[8]	validation-rmse:7.57232                                                     
[9]	validation-rmse:7.39835                                                     
[10]	validation-rmse:7.25788                                                    
[11]	validation-rmse:7.14267                                                    
[12]	validation-rmse:7.04916




[0]	validation-rmse:9.65322                                                     
[1]	validation-rmse:8.22797                                                     
[2]	validation-rmse:7.47861                                                     
[3]	validation-rmse:7.08143                                                     
[4]	validation-rmse:6.87714                                                     
[5]	validation-rmse:6.76750                                                     
[6]	validation-rmse:6.69465                                                     
[7]	validation-rmse:6.66013                                                     
[8]	validation-rmse:6.63799                                                     
[9]	validation-rmse:6.62145                                                     
[10]	validation-rmse:6.60912                                                    
[11]	validation-rmse:6.60014                                                    
[12]	validation-rmse:6.59737




[0]	validation-rmse:10.96484                                                    
[1]	validation-rmse:9.97174                                                     
[2]	validation-rmse:9.18571                                                     
[3]	validation-rmse:8.57036                                                     
[4]	validation-rmse:8.09386                                                     
[5]	validation-rmse:7.72729                                                     
[6]	validation-rmse:7.44486                                                     
[7]	validation-rmse:7.23247                                                     
[8]	validation-rmse:7.06998                                                     
[9]	validation-rmse:6.94540                                                     
[10]	validation-rmse:6.84768                                                    
[11]	validation-rmse:6.77532                                                    
[12]	validation-rmse:6.71611




[0]	validation-rmse:11.71256                                                    
[1]	validation-rmse:11.25278                                                    
[2]	validation-rmse:10.83142                                                    
[3]	validation-rmse:10.44416                                                    
[4]	validation-rmse:10.08957                                                    
[5]	validation-rmse:9.76245                                                     
[6]	validation-rmse:9.46670                                                     
[7]	validation-rmse:9.19558                                                     
[8]	validation-rmse:8.95176                                                     
[9]	validation-rmse:8.72357                                                     
[10]	validation-rmse:8.52374                                                    
[11]	validation-rmse:8.33437                                                    
[12]	validation-rmse:8.17013




[0]	validation-rmse:11.81659                                                     
[1]	validation-rmse:11.44563                                                     
[2]	validation-rmse:11.09902                                                     
[3]	validation-rmse:10.77548                                                     
[4]	validation-rmse:10.47384                                                     
[5]	validation-rmse:10.19277                                                     
[6]	validation-rmse:9.93119                                                      
[7]	validation-rmse:9.68787                                                      
[8]	validation-rmse:9.46183                                                      
[9]	validation-rmse:9.25205                                                      
[10]	validation-rmse:9.05752                                                     
[11]	validation-rmse:8.87733                                                     
[12]	validation-




[0]	validation-rmse:10.68920                                                    
[1]	validation-rmse:9.54916                                                     
[2]	validation-rmse:8.70755                                                     
[3]	validation-rmse:8.09873                                                     
[4]	validation-rmse:7.65959                                                     
[5]	validation-rmse:7.35096                                                     
[6]	validation-rmse:7.13077                                                     
[7]	validation-rmse:6.97186                                                     
[8]	validation-rmse:6.85992                                                     
[9]	validation-rmse:6.77872                                                     
[10]	validation-rmse:6.71904                                                    
[11]	validation-rmse:6.67470                                                    
[12]	validation-rmse:6.64045




[0]	validation-rmse:11.50170                                                    
[1]	validation-rmse:10.87357                                                    
[2]	validation-rmse:10.31888                                                    
[3]	validation-rmse:9.83284                                                     
[4]	validation-rmse:9.40598                                                     
[5]	validation-rmse:9.03509                                                     
[6]	validation-rmse:8.71132                                                     
[7]	validation-rmse:8.43018                                                     
[8]	validation-rmse:8.18699                                                     
[9]	validation-rmse:7.97707                                                     
[10]	validation-rmse:7.79629                                                    
[11]	validation-rmse:7.63995                                                    
[12]	validation-rmse:7.50454




[0]	validation-rmse:11.07246                                                     
[1]	validation-rmse:10.14316                                                     
[2]	validation-rmse:9.39296                                                      
[3]	validation-rmse:8.79173                                                      
[4]	validation-rmse:8.31370                                                      
[5]	validation-rmse:7.93513                                                      
[6]	validation-rmse:7.63822                                                      
[7]	validation-rmse:7.40593                                                      
[8]	validation-rmse:7.22430                                                      
[9]	validation-rmse:7.08189                                                      
[10]	validation-rmse:6.97246                                                     
[11]	validation-rmse:6.88613                                                     
[12]	validation-




[0]	validation-rmse:9.60329                                                      
[1]	validation-rmse:8.14858                                                      
[2]	validation-rmse:7.37361                                                      
[3]	validation-rmse:6.97761                                                      
[4]	validation-rmse:6.76760                                                      
[5]	validation-rmse:6.65345                                                      
[6]	validation-rmse:6.58720                                                      
[7]	validation-rmse:6.54898                                                      
[8]	validation-rmse:6.52224                                                      
[9]	validation-rmse:6.50345                                                      
[10]	validation-rmse:6.49111                                                     
[11]	validation-rmse:6.48127                                                     
[12]	validation-




[0]	validation-rmse:10.69007                                                     
[1]	validation-rmse:9.55402                                                      
[2]	validation-rmse:8.71780                                                      
[3]	validation-rmse:8.11239                                                      
[4]	validation-rmse:7.67924                                                      
[5]	validation-rmse:7.37428                                                      
[6]	validation-rmse:7.15185                                                      
[7]	validation-rmse:6.99739                                                      
[8]	validation-rmse:6.88404                                                      
[9]	validation-rmse:6.80223                                                      
[10]	validation-rmse:6.74195                                                     
[11]	validation-rmse:6.69835                                                     
[12]	validation-




[0]	validation-rmse:10.68863                                                    
[1]	validation-rmse:9.55719                                                     
[2]	validation-rmse:8.73186                                                     
[3]	validation-rmse:8.14020                                                     
[4]	validation-rmse:7.71815                                                     
[5]	validation-rmse:7.42406                                                     
[6]	validation-rmse:7.21191                                                     
[7]	validation-rmse:7.06581                                                     
[8]	validation-rmse:6.96041                                                     
[9]	validation-rmse:6.88410                                                     
[10]	validation-rmse:6.82959                                                    
[11]	validation-rmse:6.78621                                                    
[12]	validation-rmse:6.75760




[0]	validation-rmse:8.08443                                                     
[1]	validation-rmse:6.99555                                                     
[2]	validation-rmse:6.71814                                                     
[3]	validation-rmse:6.63808                                                     
[4]	validation-rmse:6.59980                                                     
[5]	validation-rmse:6.57892                                                     
[6]	validation-rmse:6.57139                                                     
[7]	validation-rmse:6.56400                                                     
[8]	validation-rmse:6.55803                                                     
[9]	validation-rmse:6.55175                                                     
[10]	validation-rmse:6.54628                                                    
[11]	validation-rmse:6.54237                                                    
[12]	validation-rmse:6.53206




[0]	validation-rmse:9.74096                                                     
[1]	validation-rmse:8.31295                                                     
[2]	validation-rmse:7.52879                                                     
[3]	validation-rmse:7.10130                                                     
[4]	validation-rmse:6.87158                                                     
[5]	validation-rmse:6.74042                                                     
[6]	validation-rmse:6.66556                                                     
[7]	validation-rmse:6.61940                                                     
[8]	validation-rmse:6.58766                                                     
[9]	validation-rmse:6.56827                                                     
[10]	validation-rmse:6.55370                                                    
[11]	validation-rmse:6.54707                                                    
[12]	validation-rmse:6.54124




[0]	validation-rmse:10.81963                                                    
[1]	validation-rmse:9.74865                                                     
[2]	validation-rmse:8.93926                                                     
[3]	validation-rmse:8.33118                                                     
[4]	validation-rmse:7.88190                                                     
[5]	validation-rmse:7.55100                                                     
[6]	validation-rmse:7.31085                                                     
[7]	validation-rmse:7.13165                                                     
[8]	validation-rmse:7.00186                                                     
[9]	validation-rmse:6.90495                                                     
[10]	validation-rmse:6.83013                                                    
[11]	validation-rmse:6.77666                                                    
[12]	validation-rmse:6.73472




[0]	validation-rmse:11.14327                                                    
[1]	validation-rmse:10.26196                                                    
[2]	validation-rmse:9.53798                                                     
[3]	validation-rmse:8.94808                                                     
[4]	validation-rmse:8.46977                                                     
[5]	validation-rmse:8.09765                                                     
[6]	validation-rmse:7.79097                                                     
[7]	validation-rmse:7.54613                                                     
[8]	validation-rmse:7.34796                                                     
[9]	validation-rmse:7.19524                                                     
[10]	validation-rmse:7.07250                                                    
[11]	validation-rmse:6.97741                                                    
[12]	validation-rmse:6.89905




[0]	validation-rmse:10.53069                                                    
[1]	validation-rmse:9.32437                                                     
[2]	validation-rmse:8.47536                                                     
[3]	validation-rmse:7.89386                                                     
[4]	validation-rmse:7.49437                                                     
[5]	validation-rmse:7.22167                                                     
[6]	validation-rmse:7.03833                                                     
[7]	validation-rmse:6.91420                                                     
[8]	validation-rmse:6.82476                                                     
[9]	validation-rmse:6.76573                                                     
[10]	validation-rmse:6.72255                                                    
[11]	validation-rmse:6.69032                                                    
[12]	validation-rmse:6.66784




[4]	validation-rmse:7.15741                                                     
[5]	validation-rmse:7.01116                                                     
[6]	validation-rmse:6.92772                                                     
[7]	validation-rmse:6.87664                                                     
[8]	validation-rmse:6.84727                                                     
[9]	validation-rmse:6.82977                                                     
[10]	validation-rmse:6.81324                                                    
[11]	validation-rmse:6.80228                                                    
[12]	validation-rmse:6.79676                                                    
[13]	validation-rmse:6.79362                                                    
[14]	validation-rmse:6.78970                                                    
[15]	validation-rmse:6.78742                                                    
[16]	validation-rmse:6.78073




[0]	validation-rmse:9.43359                                                     
[1]	validation-rmse:7.99790                                                     
[2]	validation-rmse:7.30279                                                     
[3]	validation-rmse:6.96929                                                     
[4]	validation-rmse:6.80194                                                     
[5]	validation-rmse:6.71809                                                     
[6]	validation-rmse:6.67040                                                     
[7]	validation-rmse:6.64266                                                     
[8]	validation-rmse:6.62715                                                     
[9]	validation-rmse:6.61675                                                     
[10]	validation-rmse:6.61141                                                    
[11]	validation-rmse:6.60689                                                    
[12]	validation-rmse:6.60372




[1]	validation-rmse:9.29081                                                     
[2]	validation-rmse:8.45061                                                     
[3]	validation-rmse:7.87999                                                     
[4]	validation-rmse:7.49243                                                     
[5]	validation-rmse:7.23421                                                     
[6]	validation-rmse:7.06098                                                     
[7]	validation-rmse:6.94418                                                     
[8]	validation-rmse:6.86465                                                     
[9]	validation-rmse:6.80980                                                     
[10]	validation-rmse:6.77042                                                    
[11]	validation-rmse:6.74108                                                    
[12]	validation-rmse:6.71869                                                    
[13]	validation-rmse:6.70511




[1]	validation-rmse:6.90470                                                     
[2]	validation-rmse:6.76620                                                     
[3]	validation-rmse:6.72533                                                     
[4]	validation-rmse:6.71501                                                     
[5]	validation-rmse:6.70756                                                     
[6]	validation-rmse:6.70279                                                     
[7]	validation-rmse:6.69649                                                     
[8]	validation-rmse:6.69418                                                     
[9]	validation-rmse:6.68905                                                     
[10]	validation-rmse:6.68453                                                    
[11]	validation-rmse:6.68187                                                    
[12]	validation-rmse:6.67994                                                    
[13]	validation-rmse:6.67453




[0]	validation-rmse:10.98429                                                    
[1]	validation-rmse:10.00167                                                    
[2]	validation-rmse:9.23028                                                     
[3]	validation-rmse:8.62000                                                     
[4]	validation-rmse:8.14962                                                     
[5]	validation-rmse:7.78706                                                     
[6]	validation-rmse:7.50252                                                     
[7]	validation-rmse:7.29138                                                     
[8]	validation-rmse:7.13215                                                     
[9]	validation-rmse:7.00355                                                     
[10]	validation-rmse:6.91387                                                    
[11]	validation-rmse:6.83377                                                    
[12]	validation-rmse:6.77174




[0]	validation-rmse:10.47184                                                    
[1]	validation-rmse:9.24432                                                     
[2]	validation-rmse:8.39171                                                     
[3]	validation-rmse:7.81120                                                     
[4]	validation-rmse:7.42169                                                     
[5]	validation-rmse:7.15281                                                     
[6]	validation-rmse:6.97936                                                     
[7]	validation-rmse:6.86273                                                     
[8]	validation-rmse:6.77997                                                     
[9]	validation-rmse:6.71580                                                     
[10]	validation-rmse:6.67380                                                    
[11]	validation-rmse:6.63896                                                    
[12]	validation-rmse:6.61795




[0]	validation-rmse:8.99704                                                     
[1]	validation-rmse:7.56930                                                     
[2]	validation-rmse:6.97964                                                     
[3]	validation-rmse:6.73805                                                     
[4]	validation-rmse:6.62770                                                     
[5]	validation-rmse:6.57419                                                     
[6]	validation-rmse:6.54093                                                     
[7]	validation-rmse:6.52603                                                     
[8]	validation-rmse:6.50884                                                     
[9]	validation-rmse:6.49772                                                     
[10]	validation-rmse:6.49335                                                    
[11]	validation-rmse:6.48879                                                    
[12]	validation-rmse:6.48318

## Autologging

In [13]:
# best-performing params retrieved from mlflow
# NOTE: the objective is spelled "reg:squarederror" here; "reg:linear" is the
# deprecated XGBoost alias for the same squared-error regression objective.
best_params = {
  "learning_rate": 0.21668937995954535,
  "max_depth": 18,
  "min_child_weight": 1.0625241915799823,
  "objective": "reg:squarederror",
  "reg_alpha": 0.014108586386588398,
  "reg_lambda": 0.007062245165893128,
  "seed": 42
}

In [14]:
# Switch on MLflow's XGBoost autologging first, so the training call below
# is tracked without any explicit logging code.
mlflow.xgboost.autolog()

# Refit with the tuned hyper-parameters. The validation set is the only
# entry in `evals`, so it is the one monitored for early stopping
# (at most 1000 rounds, stop after 50 rounds without improvement).
booster = xgb.train(
  best_params,
  train,
  num_boost_round=1000,
  early_stopping_rounds=50,
  evals=[(valid, "validation")],
)

2024/07/24 11:41:19 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '376b419a02484c1e9476f7984b11b842', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current xgboost workflow


[0]	validation-rmse:10.53069
[1]	validation-rmse:9.32437
[2]	validation-rmse:8.47536
[3]	validation-rmse:7.89386
[4]	validation-rmse:7.49437
[5]	validation-rmse:7.22167
[6]	validation-rmse:7.03833
[7]	validation-rmse:6.91420
[8]	validation-rmse:6.82476
[9]	validation-rmse:6.76573
[10]	validation-rmse:6.72255
[11]	validation-rmse:6.69032
[12]	validation-rmse:6.66784
[13]	validation-rmse:6.64784
[14]	validation-rmse:6.63216
[15]	validation-rmse:6.62061
[16]	validation-rmse:6.61052
[17]	validation-rmse:6.60383
[18]	validation-rmse:6.59924
[19]	validation-rmse:6.59532
[20]	validation-rmse:6.59425
[21]	validation-rmse:6.59283
[22]	validation-rmse:6.59114
[23]	validation-rmse:6.58899
[24]	validation-rmse:6.58693
[25]	validation-rmse:6.58533
[26]	validation-rmse:6.58384
[27]	validation-rmse:6.58115
[28]	validation-rmse:6.57882
[29]	validation-rmse:6.57593
[30]	validation-rmse:6.57317
[31]	validation-rmse:6.57063
[32]	validation-rmse:6.56899
[33]	validation-rmse:6.56818
[34]	validation-rmse:6.



## Manual Model Management

In [18]:
# Baseline Lasso model tracked by hand: fit, score on the validation split,
# pickle (vectorizer, model) to disk and attach that file to an MLflow run.
alpha = .1
lasso = Lasso(alpha=alpha)
lasso.fit(X=X_train, y=y_train)
preds = lasso.predict(X_val)
rmse = root_mean_squared_error(y_val, preds)  # (y_true, y_pred) order

# NOTE: open() does not create directories; models/ must already exist.
with open("models/lin_reg.bin", "wb") as f_out:
  pickle.dump((dv, lasso), f_out)

with mlflow.start_run():
  mlflow.set_tag("developer", "Bastien Winant")

  # record data provenance and the regularisation strength used above
  mlflow.log_params({
    "train_data": train_url,
    "val_data": val_url,
    "alpha": alpha
  })

  mlflow.log_metric("rmse", rmse)

  # save model as an artifact
  mlflow.log_artifact("models/lin_reg.bin", artifact_path="models_pickle")

## Automated Model Management

In [21]:
# This cell logs its run by hand, so autologging is switched off first:
# with mlflow.xgboost.autolog() still active, the xgb.train call below
# (made outside any active run) would open a separate, automatic run for
# the same model.
mlflow.xgboost.autolog(disable=True)

# best-performing params retrieved from mlflow
# NOTE: "reg:squarederror" replaces the deprecated alias "reg:linear";
# both name the same squared-error regression objective.
best_params = {
  "learning_rate": 0.21668937995954535,
  "max_depth": 18,
  "min_child_weight": 1.0625241915799823,
  "objective": "reg:squarederror",
  "reg_alpha": 0.014108586386588398,
  "reg_lambda": 0.007062245165893128,
  "seed": 42
}

# train with early stopping monitored on the validation set
booster = xgb.train(
  params=best_params,
  dtrain=train,
  num_boost_round=1000,
  evals=[(valid, "validation")],
  early_stopping_rounds=50
)

preds = booster.predict(valid)
rmse = root_mean_squared_error(y_val, preds)

# persist the fitted vectorizer so it can be logged next to the model
# NOTE: open() does not create directories; models/ must already exist.
with open("models/preprocessor.b", "wb") as f_out:
  pickle.dump(dv, f_out)

with mlflow.start_run():
  mlflow.log_params(best_params)
  mlflow.log_metric("rmse", rmse)

  # log the preprocessor as a plain artifact and the booster as an MLflow model
  mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")
  mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")



[0]	validation-rmse:10.53069
[1]	validation-rmse:9.32437
[2]	validation-rmse:8.47536
[3]	validation-rmse:7.89386
[4]	validation-rmse:7.49437
[5]	validation-rmse:7.22167
[6]	validation-rmse:7.03833
[7]	validation-rmse:6.91420
[8]	validation-rmse:6.82476
[9]	validation-rmse:6.76573
[10]	validation-rmse:6.72255
[11]	validation-rmse:6.69032
[12]	validation-rmse:6.66784
[13]	validation-rmse:6.64784
[14]	validation-rmse:6.63216
[15]	validation-rmse:6.62061
[16]	validation-rmse:6.61052
[17]	validation-rmse:6.60383
[18]	validation-rmse:6.59924
[19]	validation-rmse:6.59532
[20]	validation-rmse:6.59425
[21]	validation-rmse:6.59283
[22]	validation-rmse:6.59114
[23]	validation-rmse:6.58899
[24]	validation-rmse:6.58693
[25]	validation-rmse:6.58533
[26]	validation-rmse:6.58384
[27]	validation-rmse:6.58115
[28]	validation-rmse:6.57882
[29]	validation-rmse:6.57593
[30]	validation-rmse:6.57317
[31]	validation-rmse:6.57063
[32]	validation-rmse:6.56899
[33]	validation-rmse:6.56818
[34]	validation-rmse:6.



## Retrieve model and make predictions

In [23]:
# URI of the model stored under "models_mlflow" in a tracked run.
# NOTE(review): the run ID is hard-coded; presumably it is the run created
# by the model-logging cell above — confirm before re-running.
logged_model = "runs:/4cffa5e1dfc542ceb451d95d681c51a6/models_mlflow"

# Load the same artifact twice: once through the generic pyfunc flavour,
# once through the XGBoost flavour.
loaded_model = mlflow.pyfunc.load_model(model_uri=logged_model)
xgboost_model = mlflow.xgboost.load_model(model_uri=logged_model)

In [27]:
# Predict with the pyfunc-flavoured model on X_val.
# NOTE(review): X_val looks like the DictVectorizer output of transform_data
# (a feature matrix, not a pandas DataFrame) — confirm.
loaded_model.predict(X_val)

array([14.457083,  7.050064, 15.660235, ..., 13.59472 ,  6.457606,
        8.311016], dtype=float32)

In [24]:
# Predict with the XGBoost-flavoured model on `valid`, the same object that
# is passed as the eval set to xgb.train (presumably an xgb.DMatrix — confirm).
xgboost_model.predict(valid)

array([14.457083,  7.050064, 15.660235, ..., 13.59472 ,  6.457606,
        8.311016], dtype=float32)