In [10]:
!python -V

Python 3.9.18


In [11]:
import pandas as pd

In [12]:
import pickle

In [13]:
import seaborn as sns
import matplotlib.pyplot as plt

In [14]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error

In [15]:
import mlflow


# Keep run metadata in a local SQLite file and group every run made by this
# notebook under a single named experiment.
TRACKING_URI = "sqlite:///mlflow.db"
EXPERIMENT_NAME = "nyc-taxi-experiment"

mlflow.set_tracking_uri(TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

<Experiment: artifact_location='/Users/yunbo/Documents/GitHub/mlops-zoomcamp/02-experiment-tracking/mlruns/1', creation_time=1710747964467, experiment_id='1', last_update_time=1710747964467, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [18]:
def read_dataframe(filename):
    """Load a trip-record file and prepare it for duration modelling.

    Args:
        filename: Path (``str`` or path-like) to a ``.csv`` or ``.parquet``
            file that has ``lpep_pickup_datetime``, ``lpep_dropoff_datetime``,
            ``PULocationID`` and ``DOLocationID`` columns.

    Returns:
        A new DataFrame with an added ``duration`` column (trip length in
        minutes), restricted to trips lasting between 1 and 60 minutes
        inclusive, and with the two location-ID columns cast to ``str``.

    Raises:
        ValueError: If the file extension is neither ``.csv`` nor ``.parquet``.
    """
    path = str(filename)

    if path.endswith('.csv'):
        df = pd.read_csv(path)

        # CSV carries timestamps as text, so they must be parsed explicitly.
        df['lpep_dropoff_datetime'] = pd.to_datetime(df['lpep_dropoff_datetime'])
        df['lpep_pickup_datetime'] = pd.to_datetime(df['lpep_pickup_datetime'])
    elif path.endswith('.parquet'):
        df = pd.read_parquet(path)
    else:
        # Without this branch `df` would be unbound and fail with a confusing
        # UnboundLocalError further down.
        raise ValueError(
            f"Unsupported file type for {filename!r}; expected a .csv or .parquet file"
        )

    # Vectorized timedelta -> minutes conversion (no per-row Python lambda).
    trip_time = df['lpep_dropoff_datetime'] - df['lpep_pickup_datetime']
    df['duration'] = trip_time.dt.total_seconds() / 60

    # .copy() detaches the filtered rows from the parent frame, so the column
    # assignments below (and in callers) do not raise SettingWithCopyWarning.
    df = df[(df['duration'] >= 1) & (df['duration'] <= 60)].copy()

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)

    return df

In [19]:
# January 2021 trips are used for fitting, February 2021 trips for validation.
train_path = './data/green_tripdata_2021-01.parquet'
val_path = './data/green_tripdata_2021-02.parquet'

df_train, df_val = (read_dataframe(path) for path in (train_path, val_path))

In [20]:
# Row counts of the training and validation frames returned by read_dataframe.
tuple(len(frame) for frame in (df_train, df_val))

(73908, 61921)

In [21]:
# Join pickup and drop-off location IDs into one "<pickup>_<dropoff>" feature,
# applied identically to both splits.
for frame in (df_train, df_val):
    frame['PU_DO'] = frame['PULocationID'] + '_' + frame['DOLocationID']

In [22]:
# Feature set: the combined pickup/drop-off pair plus the trip distance.
categorical = ['PU_DO']
numerical = ['trip_distance']
feature_columns = categorical + numerical

dv = DictVectorizer()

# One dict per row, as DictVectorizer expects.
train_dicts = df_train[feature_columns].to_dict(orient='records')
val_dicts = df_val[feature_columns].to_dict(orient='records')

# The vectorizer is fitted on the training rows only; validation rows are
# encoded with that same fitted mapping.
X_train = dv.fit_transform(train_dicts)
X_val = dv.transform(val_dicts)

In [23]:
# Regression target: the duration column (minutes) computed in read_dataframe.
target = 'duration'
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

In [24]:
# Baseline: ordinary least squares on the vectorized features.
lr = LinearRegression().fit(X_train, y_train)
y_pred = lr.predict(X_val)

# Validation RMSE (squared=False returns the root of the mean squared error).
mean_squared_error(y_true=y_val, y_pred=y_pred, squared=False)

7.758715204520257

In [26]:
from pathlib import Path

# Pickle the fitted DictVectorizer and the model together as a (dv, lr) tuple.
model_path = Path('models/lin_reg.bin')

# Create models/ if it is missing; opening the file for writing would
# otherwise raise FileNotFoundError on a fresh checkout.
model_path.parent.mkdir(parents=True, exist_ok=True)

with model_path.open('wb') as f_out:
    pickle.dump((dv, lr), f_out)

In [28]:
# Track one Lasso experiment as an MLflow run: tag, parameters, metric, artifact.
with mlflow.start_run():

    mlflow.set_tag("developer", "cristian")

    # Log the files that were actually read into df_train / df_val. These are
    # the parquet files; the previously logged ".csv" names did not match.
    mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.parquet")
    mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.parquet")

    # L1 regularisation strength for the Lasso model.
    alpha = 0.01
    mlflow.log_param("alpha", alpha)
    lr = Lasso(alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    # NOTE(review): models/lin_reg.bin was pickled earlier from the
    # LinearRegression model, not from the Lasso fitted in this run, so the
    # artifact does not correspond to the logged alpha/rmse. Confirm whether
    # the Lasso model should be pickled and logged here instead.
    mlflow.log_artifact(local_path="models/lin_reg.bin", artifact_path="models_pickle")

In [29]:
import xgboost as xgb

In [30]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [31]:
# Wrap the feature matrices and targets in XGBoost's DMatrix container.
train = xgb.DMatrix(data=X_train, label=y_train)
valid = xgb.DMatrix(data=X_val, label=y_val)

In [35]:
def objective(params):
    """Train one XGBoost candidate and report its validation RMSE.

    Every call is recorded as its own MLflow run, tagged ``model=xgboost``,
    with ``params`` logged as run parameters and the RMSE logged as the
    ``rmse`` metric.

    Args:
        params: Parameter dict, forwarded unchanged to ``xgb.train``.

    Returns:
        ``{'loss': <validation RMSE>, 'status': STATUS_OK}``.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)

        # Evaluate on the validation DMatrix each round; train for at most
        # 1000 rounds with an early-stopping patience of 50 rounds.
        watchlist = [(valid, 'validation')]
        model = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=watchlist,
            early_stopping_rounds=50,
        )

        predictions = model.predict(valid)
        validation_rmse = mean_squared_error(y_val, predictions, squared=False)
        mlflow.log_metric("rmse", validation_rmse)

    return {'loss': validation_rmse, 'status': STATUS_OK}

In [36]:
# Hyperparameter search space sampled by hyperopt for each call to objective().
# hp.loguniform(label, low, high) draws exp(uniform(low, high)), so the bounds
# below are natural-log exponents; approximate value ranges are noted inline.
search_space = {
    # Integer tree depth in [4, 100] (quniform yields floats, scope.int casts).
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),        # ~[0.05, 1.0]
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),               # ~[0.0067, 0.37]
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),             # ~[0.0025, 0.37]
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),  # ~[0.37, 20.1]
    # 'reg:linear' is the deprecated name for squared-error regression; use
    # the current objective name so the search keeps working on newer XGBoost.
    'objective': 'reg:squarederror',
    'seed': 42
}

# Minimise the validation RMSE returned by objective() over 50 TPE-guided trials.
best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:10.63051                          
[1]	validation-rmse:9.45642                           
[2]	validation-rmse:8.59687                           
[3]	validation-rmse:7.98297                           
[4]	validation-rmse:7.54633                           
[5]	validation-rmse:7.24210                           
[6]	validation-rmse:7.02564                           
[7]	validation-rmse:6.87547                           
[8]	validation-rmse:6.76593                           
[9]	validation-rmse:6.68959                           
[10]	validation-rmse:6.63341                          
[11]	validation-rmse:6.59121                          
[12]	validation-rmse:6.55977                          
[13]	validation-rmse:6.53660                          
[14]	validation-rmse:6.51821                          
[15]	validation-rmse:6.50425                          
[16]	validation-rmse:6.49322                          
[17]	validation-rmse:6.48295                          
[18]	valid




[3]	validation-rmse:10.60851                                                   
[4]	validation-rmse:10.28052                                                   
[5]	validation-rmse:9.97817                                                    
[6]	validation-rmse:9.69973                                                    
[7]	validation-rmse:9.44371                                                    
[8]	validation-rmse:9.20863                                                    
[9]	validation-rmse:8.99302                                                    
[10]	validation-rmse:8.79549                                                   
[11]	validation-rmse:8.61455                                                   
[12]	validation-rmse:8.44907                                                   
[13]	validation-rmse:8.29791                                                   
[14]	validation-rmse:8.16007                                                   
[15]	validation-rmse:8.03444            




[8]	validation-rmse:6.72536                                                    
[9]	validation-rmse:6.72128                                                    
[10]	validation-rmse:6.71548                                                   
[11]	validation-rmse:6.71046                                                   
[12]	validation-rmse:6.70314                                                   
[13]	validation-rmse:6.70091                                                   
[14]	validation-rmse:6.69785                                                   
[15]	validation-rmse:6.69150                                                   
[16]	validation-rmse:6.68911                                                   
[17]	validation-rmse:6.68773                                                   
[18]	validation-rmse:6.68540                                                   
[19]	validation-rmse:6.68282                                                   
[20]	validation-rmse:6.67721            




[1]	validation-rmse:7.67971                                                    
[2]	validation-rmse:7.07241                                                    
[3]	validation-rmse:6.81220                                                    
[4]	validation-rmse:6.69027                                                    
[5]	validation-rmse:6.62628                                                    
[6]	validation-rmse:6.59395                                                    
[7]	validation-rmse:6.57110                                                    
[8]	validation-rmse:6.56093                                                    
[9]	validation-rmse:6.55019                                                    
[10]	validation-rmse:6.54508                                                   
[11]	validation-rmse:6.54162                                                   
[12]	validation-rmse:6.53772                                                   
[13]	validation-rmse:6.53166            




[7]	validation-rmse:6.72823                                                    
[8]	validation-rmse:6.71759                                                    
[9]	validation-rmse:6.71395                                                    
[10]	validation-rmse:6.70633                                                   
[11]	validation-rmse:6.70541                                                   
[12]	validation-rmse:6.70396                                                   
[13]	validation-rmse:6.69792                                                   
[14]	validation-rmse:6.69423                                                   
[15]	validation-rmse:6.69140                                                   
[16]	validation-rmse:6.68446                                                   
[17]	validation-rmse:6.67990                                                   
[18]	validation-rmse:6.67539                                                   
[19]	validation-rmse:6.67207            




[0]	validation-rmse:10.69473                                                   
[1]	validation-rmse:9.56330                                                    
[2]	validation-rmse:8.73273                                                    
[3]	validation-rmse:8.13072                                                    
[4]	validation-rmse:7.70129                                                    
[5]	validation-rmse:7.39566                                                    
[6]	validation-rmse:7.18145                                                    
[7]	validation-rmse:7.02181                                                    
[8]	validation-rmse:6.91188                                                    
[9]	validation-rmse:6.83077                                                    
[10]	validation-rmse:6.77280                                                   
[11]	validation-rmse:6.72608                                                   
[12]	validation-rmse:6.69328            




[0]	validation-rmse:10.43339                                                   
[1]	validation-rmse:9.17463                                                    
[2]	validation-rmse:8.30585                                                    
[3]	validation-rmse:7.71549                                                    
[4]	validation-rmse:7.32177                                                    
[5]	validation-rmse:7.06237                                                    
[6]	validation-rmse:6.88724                                                    
[7]	validation-rmse:6.76912                                                    
[8]	validation-rmse:6.68805                                                    
[9]	validation-rmse:6.63157                                                    
[10]	validation-rmse:6.59091                                                   
[11]	validation-rmse:6.56175                                                   
[12]	validation-rmse:6.53880            




[0]	validation-rmse:8.38692                                                    
[1]	validation-rmse:7.17488                                                    
[2]	validation-rmse:6.81594                                                    
[3]	validation-rmse:6.69035                                                    
[4]	validation-rmse:6.63324                                                    
[5]	validation-rmse:6.60646                                                    
[6]	validation-rmse:6.59425                                                    
[7]	validation-rmse:6.58854                                                    
[8]	validation-rmse:6.58100                                                    
[9]	validation-rmse:6.57542                                                    
[10]	validation-rmse:6.57056                                                   
[11]	validation-rmse:6.56684                                                   
[12]	validation-rmse:6.56177            




[7]	validation-rmse:6.79150                                                    
[8]	validation-rmse:6.75814                                                    
[9]	validation-rmse:6.73645                                                    
[10]	validation-rmse:6.72232                                                   
[11]	validation-rmse:6.71547                                                   
[12]	validation-rmse:6.70824                                                   
[13]	validation-rmse:6.70173                                                   
[14]	validation-rmse:6.69818                                                   
[15]	validation-rmse:6.69594                                                   
[16]	validation-rmse:6.69410                                                   
[17]	validation-rmse:6.69109                                                   
[18]	validation-rmse:6.68900                                                   
[19]	validation-rmse:6.68623            




[1]	validation-rmse:7.12255                                                    
[2]	validation-rmse:6.76473                                                    
[3]	validation-rmse:6.64796                                                    
[4]	validation-rmse:6.60297                                                    
[5]	validation-rmse:6.58087                                                    
[6]	validation-rmse:6.56468                                                    
[7]	validation-rmse:6.55520                                                    
[8]	validation-rmse:6.54768                                                    
[9]	validation-rmse:6.54236                                                    
[10]	validation-rmse:6.53822                                                   
[11]	validation-rmse:6.53212                                                   
[12]	validation-rmse:6.52874                                                   
[13]	validation-rmse:6.52475            




[0]	validation-rmse:6.88402                                                     
[1]	validation-rmse:6.78912                                                     
[2]	validation-rmse:6.76778                                                     
[3]	validation-rmse:6.76235                                                     
[4]	validation-rmse:6.75719                                                     
[5]	validation-rmse:6.74959                                                     
[6]	validation-rmse:6.74791                                                     
[7]	validation-rmse:6.75040                                                     
[8]	validation-rmse:6.74224                                                     
[9]	validation-rmse:6.74044                                                     
[10]	validation-rmse:6.73653                                                    
[11]	validation-rmse:6.74734                                                    
[12]	validation-rmse:6.74249




[2]	validation-rmse:6.73215                                                     
[3]	validation-rmse:6.71274                                                     
[4]	validation-rmse:6.70626                                                     
[5]	validation-rmse:6.70380                                                     
[6]	validation-rmse:6.70268                                                     
[7]	validation-rmse:6.69631                                                     
[8]	validation-rmse:6.69142                                                     
[9]	validation-rmse:6.68596                                                     
[10]	validation-rmse:6.68397                                                    
[11]	validation-rmse:6.68176                                                    
[12]	validation-rmse:6.67626                                                    
[13]	validation-rmse:6.67153                                                    
[14]	validation-rmse:6.66978




[0]	validation-rmse:7.92838                                                     
[1]	validation-rmse:6.96090                                                     
[2]	validation-rmse:6.74902                                                     
[3]	validation-rmse:6.69025                                                     
[4]	validation-rmse:6.66236                                                     
[5]	validation-rmse:6.64665                                                     
[6]	validation-rmse:6.64149                                                     
[7]	validation-rmse:6.63552                                                     
[8]	validation-rmse:6.62972                                                     
[9]	validation-rmse:6.62293                                                     
[10]	validation-rmse:6.61755                                                    
[11]	validation-rmse:6.61299                                                    
[12]	validation-rmse:6.60623




[2]	validation-rmse:6.67687                                                     
[3]	validation-rmse:6.65380                                                     
[4]	validation-rmse:6.64390                                                     
[5]	validation-rmse:6.63859                                                     
[6]	validation-rmse:6.63355                                                     
[7]	validation-rmse:6.62972                                                     
[8]	validation-rmse:6.62660                                                     
[9]	validation-rmse:6.62060                                                     
[10]	validation-rmse:6.61704                                                    
[11]	validation-rmse:6.60816                                                    
[12]	validation-rmse:6.60627                                                    
[13]	validation-rmse:6.60482                                                    
[14]	validation-rmse:6.59614




[1]	validation-rmse:8.70980                                                     
[2]	validation-rmse:7.88565                                                     
[3]	validation-rmse:7.39815                                                     
[4]	validation-rmse:7.11229                                                     
[5]	validation-rmse:6.94305                                                     
[6]	validation-rmse:6.83649                                                     
[7]	validation-rmse:6.76854                                                     
[8]	validation-rmse:6.72570                                                     
[9]	validation-rmse:6.69562                                                     
[10]	validation-rmse:6.67241                                                    
[11]	validation-rmse:6.65569                                                    
[12]	validation-rmse:6.64460                                                    
[13]	validation-rmse:6.63443




[1]	validation-rmse:6.65068                                                     
[2]	validation-rmse:6.61345                                                     
[3]	validation-rmse:6.60630                                                     
[4]	validation-rmse:6.59666                                                     
[5]	validation-rmse:6.58360                                                     
[6]	validation-rmse:6.57682                                                     
[7]	validation-rmse:6.56759                                                     
[8]	validation-rmse:6.56170                                                     
[9]	validation-rmse:6.55671                                                     
[10]	validation-rmse:6.53750                                                    
[11]	validation-rmse:6.53329                                                    
[12]	validation-rmse:6.53044                                                    
[13]	validation-rmse:6.52743




[0]	validation-rmse:9.75680                                                     
[1]	validation-rmse:8.33099                                                     
[2]	validation-rmse:7.53757                                                     
[3]	validation-rmse:7.10787                                                     
[4]	validation-rmse:6.86588                                                     
[5]	validation-rmse:6.73556                                                     
[6]	validation-rmse:6.65631                                                     
[7]	validation-rmse:6.60741                                                     
[8]	validation-rmse:6.58001                                                     
[9]	validation-rmse:6.55298                                                     
[10]	validation-rmse:6.53697                                                    
[11]	validation-rmse:6.52697                                                    
[12]	validation-rmse:6.52184




[2]	validation-rmse:8.17172                                                     
[3]	validation-rmse:7.62530                                                     
[4]	validation-rmse:7.27478                                                     
[5]	validation-rmse:7.05142                                                     
[6]	validation-rmse:6.90855                                                     
[7]	validation-rmse:6.81512                                                     
[8]	validation-rmse:6.75185                                                     
[9]	validation-rmse:6.71037                                                     
[10]	validation-rmse:6.67980                                                    
[11]	validation-rmse:6.65824                                                    
[12]	validation-rmse:6.64146                                                    
[13]	validation-rmse:6.62823                                                    
[14]	validation-rmse:6.61868




[0]	validation-rmse:7.94813                                                     
[1]	validation-rmse:6.97161                                                     
[2]	validation-rmse:6.74763                                                     
[3]	validation-rmse:6.67292                                                     
[4]	validation-rmse:6.64752                                                     
[5]	validation-rmse:6.63408                                                     
[6]	validation-rmse:6.62271                                                     
[7]	validation-rmse:6.61294                                                     
[8]	validation-rmse:6.60653                                                     
[9]	validation-rmse:6.59742                                                     
[10]	validation-rmse:6.59212                                                    
[11]	validation-rmse:6.58981                                                    
[12]	validation-rmse:6.58744




[0]	validation-rmse:10.87795                                                    
[1]	validation-rmse:9.84094                                                     
[2]	validation-rmse:9.04337                                                     
[3]	validation-rmse:8.43105                                                     
[4]	validation-rmse:7.96564                                                     
[5]	validation-rmse:7.62182                                                     
[6]	validation-rmse:7.36196                                                     
[7]	validation-rmse:7.16759                                                     
[8]	validation-rmse:7.02512                                                     
[9]	validation-rmse:6.91871                                                     
[10]	validation-rmse:6.83713                                                    
[11]	validation-rmse:6.77356                                                    
[12]	validation-rmse:6.72201




[0]	validation-rmse:11.28055                                                    
[1]	validation-rmse:10.48706                                                    
[2]	validation-rmse:9.81640                                                     
[3]	validation-rmse:9.25292                                                     
[4]	validation-rmse:8.77872                                                     
[5]	validation-rmse:8.38541                                                     
[6]	validation-rmse:8.05802                                                     
[7]	validation-rmse:7.78932                                                     
[8]	validation-rmse:7.56436                                                     
[9]	validation-rmse:7.38109                                                     
[10]	validation-rmse:7.23066                                                    
[11]	validation-rmse:7.10681                                                    
[12]	validation-rmse:7.00336




[0]	validation-rmse:11.28771                                                    
[1]	validation-rmse:10.49895                                                    
[2]	validation-rmse:9.83150                                                     
[3]	validation-rmse:9.26950                                                     
[4]	validation-rmse:8.79425                                                     
[5]	validation-rmse:8.40093                                                     
[6]	validation-rmse:8.07360                                                     
[7]	validation-rmse:7.80268                                                     
[8]	validation-rmse:7.57790                                                     
[9]	validation-rmse:7.39457                                                     
[10]	validation-rmse:7.24138                                                    
[11]	validation-rmse:7.11499                                                    
[12]	validation-rmse:7.01161




[0]	validation-rmse:11.30643                                                    
[1]	validation-rmse:10.53024                                                    
[2]	validation-rmse:9.86963                                                     
[3]	validation-rmse:9.31018                                                     
[4]	validation-rmse:8.83667                                                     
[5]	validation-rmse:8.43778                                                     
[6]	validation-rmse:8.10503                                                     
[7]	validation-rmse:7.82836                                                     
[8]	validation-rmse:7.59954                                                     
[9]	validation-rmse:7.40706                                                     
[10]	validation-rmse:7.24897                                                    
[11]	validation-rmse:7.11982                                                    
[12]	validation-rmse:7.01172




[0]	validation-rmse:11.59975                                                    
[1]	validation-rmse:11.04669                                                    
[2]	validation-rmse:10.54767                                                    
[3]	validation-rmse:10.09975                                                    
[4]	validation-rmse:9.69849                                                     
[5]	validation-rmse:9.33841                                                     
[6]	validation-rmse:9.01793                                                     
[7]	validation-rmse:8.73180                                                     
[8]	validation-rmse:8.47822                                                     
[9]	validation-rmse:8.25246                                                     
[10]	validation-rmse:8.05213                                                    
[11]	validation-rmse:7.87540                                                    
[12]	validation-rmse:7.71935




[0]	validation-rmse:11.30308                                                    
[1]	validation-rmse:10.52611                                                    
[2]	validation-rmse:9.86527                                                     
[3]	validation-rmse:9.30736                                                     
[4]	validation-rmse:8.83393                                                     
[5]	validation-rmse:8.44111                                                     
[6]	validation-rmse:8.11063                                                     
[7]	validation-rmse:7.83675                                                     
[8]	validation-rmse:7.61192                                                     
[9]	validation-rmse:7.42377                                                     
[10]	validation-rmse:7.26914                                                    
[11]	validation-rmse:7.13924                                                    
[12]	validation-rmse:7.03265




[0]	validation-rmse:11.20038                                                    
[1]	validation-rmse:10.35802                                                    
[2]	validation-rmse:9.65753                                                     
[3]	validation-rmse:9.08075                                                     
[4]	validation-rmse:8.60682                                                     
[5]	validation-rmse:8.21852                                                     
[6]	validation-rmse:7.90357                                                     
[7]	validation-rmse:7.65714                                                     
[8]	validation-rmse:7.45212                                                     
[9]	validation-rmse:7.28777                                                     
[10]	validation-rmse:7.15448                                                    
[11]	validation-rmse:7.04655                                                    
[12]	validation-rmse:6.95973




[0]	validation-rmse:11.58034                                                    
[1]	validation-rmse:11.01168                                                    
[2]	validation-rmse:10.50170                                                    
[3]	validation-rmse:10.04527                                                    
[4]	validation-rmse:9.63826                                                     
[5]	validation-rmse:9.27575                                                     
[6]	validation-rmse:8.95353                                                     
[7]	validation-rmse:8.66809                                                     
[8]	validation-rmse:8.41665                                                     
[9]	validation-rmse:8.19379                                                     
[10]	validation-rmse:7.99691                                                    
[11]	validation-rmse:7.82382                                                    
[12]	validation-rmse:7.67126




[0]	validation-rmse:11.05137                                                    
[1]	validation-rmse:10.10913                                                    
[2]	validation-rmse:9.34970                                                     
[3]	validation-rmse:8.74364                                                     
[4]	validation-rmse:8.26462                                                     
[5]	validation-rmse:7.88657                                                     
[6]	validation-rmse:7.59111                                                     
[7]	validation-rmse:7.36110                                                     
[8]	validation-rmse:7.18079                                                     
[9]	validation-rmse:7.04007                                                     
[10]	validation-rmse:6.92888                                                    
[11]	validation-rmse:6.84260                                                    
[12]	validation-rmse:6.77346




[0]	validation-rmse:11.80368                                                    
[1]	validation-rmse:11.42058                                                    
[2]	validation-rmse:11.06217                                                    
[3]	validation-rmse:10.72714                                                    
[4]	validation-rmse:10.41543                                                    
[5]	validation-rmse:10.12446                                                    
[6]	validation-rmse:9.85322                                                     
[7]	validation-rmse:9.60118                                                     
[8]	validation-rmse:9.36580                                                     
[9]	validation-rmse:9.14774                                                     
[10]	validation-rmse:8.94531                                                    
[11]	validation-rmse:8.75684                                                    
[12]	validation-rmse:8.58214




[0]	validation-rmse:11.50963                                                    
[1]	validation-rmse:10.88463                                                    
[2]	validation-rmse:10.33097                                                    
[3]	validation-rmse:9.84231                                                     
[4]	validation-rmse:9.41146                                                     
[5]	validation-rmse:9.03435                                                     
[6]	validation-rmse:8.70366                                                     
[7]	validation-rmse:8.41458                                                     
[8]	validation-rmse:8.16273                                                     
[9]	validation-rmse:7.94394                                                     
[10]	validation-rmse:7.75442                                                    
[11]	validation-rmse:7.59003                                                    
[12]	validation-rmse:7.44757




[1]	validation-rmse:11.23186                                                    
[2]	validation-rmse:10.80173                                                    
[3]	validation-rmse:10.40864                                                    
[4]	validation-rmse:10.04928                                                    
[5]	validation-rmse:9.72131                                                     
[6]	validation-rmse:9.42259                                                     
[7]	validation-rmse:9.15127                                                     
[8]	validation-rmse:8.90494                                                     
[9]	validation-rmse:8.68152                                                     
[10]	validation-rmse:8.47980                                                    
[11]	validation-rmse:8.29735                                                    
[12]	validation-rmse:8.13275                                                    
[13]	validation-rmse:7.98416




[0]	validation-rmse:11.43192                                                    
[1]	validation-rmse:10.74952                                                    
[2]	validation-rmse:10.15350                                                    
[3]	validation-rmse:9.63587                                                     
[4]	validation-rmse:9.18886                                                     
[5]	validation-rmse:8.80226                                                     
[6]	validation-rmse:8.47028                                                     
[7]	validation-rmse:8.18465                                                     
[8]	validation-rmse:7.94526                                                     
[9]	validation-rmse:7.73810                                                     
[10]	validation-rmse:7.56017                                                    
[11]	validation-rmse:7.41098                                                    
[12]	validation-rmse:7.28298




[0]	validation-rmse:10.96746                                                    
[1]	validation-rmse:9.97694                                                     
[2]	validation-rmse:9.20302                                                     
[3]	validation-rmse:8.59901                                                     
[4]	validation-rmse:8.12957                                                     
[5]	validation-rmse:7.76279                                                     
[6]	validation-rmse:7.48463                                                     
[7]	validation-rmse:7.27738                                                     
[8]	validation-rmse:7.12350                                                     
[9]	validation-rmse:6.99208                                                     
[10]	validation-rmse:6.89891                                                    
[11]	validation-rmse:6.82982                                                    
[12]	validation-rmse:6.77417




[0]	validation-rmse:11.70601                                                    
[1]	validation-rmse:11.23932                                                    
[2]	validation-rmse:10.81033                                                    
[3]	validation-rmse:10.41704                                                    
[4]	validation-rmse:10.05686                                                    
[5]	validation-rmse:9.72707                                                     
[6]	validation-rmse:9.42636                                                     
[7]	validation-rmse:9.15197                                                     
[8]	validation-rmse:8.90247                                                     
[9]	validation-rmse:8.67505                                                     
[10]	validation-rmse:8.46841                                                    
[11]	validation-rmse:8.28151                                                    
[12]	validation-rmse:8.11164




[0]	validation-rmse:10.72703                                                    
[1]	validation-rmse:9.61197                                                     
[2]	validation-rmse:8.77950                                                     
[3]	validation-rmse:8.17889                                                     
[4]	validation-rmse:7.74255                                                     
[5]	validation-rmse:7.42442                                                     
[6]	validation-rmse:7.19527                                                     
[7]	validation-rmse:7.03441                                                     
[8]	validation-rmse:6.91617                                                     
[9]	validation-rmse:6.83024                                                     
[10]	validation-rmse:6.76446                                                    
[11]	validation-rmse:6.71518                                                    
[12]	validation-rmse:6.68045




[2]	validation-rmse:11.04534                                                    
[3]	validation-rmse:10.70903                                                    
[4]	validation-rmse:10.39674                                                    
[5]	validation-rmse:10.10707                                                    
[6]	validation-rmse:9.83867                                                     
[7]	validation-rmse:9.59032                                                     
[8]	validation-rmse:9.36040                                                     
[9]	validation-rmse:9.14829                                                     
[10]	validation-rmse:8.95237                                                    
[11]	validation-rmse:8.77181                                                    
[12]	validation-rmse:8.60555                                                    
[13]	validation-rmse:8.45262                                                    
[14]	validation-rmse:8.31206




[2]	validation-rmse:9.42931                                                     
[3]	validation-rmse:8.83603                                                     
[4]	validation-rmse:8.36268                                                     
[5]	validation-rmse:7.98834                                                     
[6]	validation-rmse:7.69402                                                     
[7]	validation-rmse:7.46276                                                     
[8]	validation-rmse:7.28170                                                     
[9]	validation-rmse:7.14038                                                     
[10]	validation-rmse:7.03000                                                    
[11]	validation-rmse:6.94350                                                    
[12]	validation-rmse:6.87520                                                    
[13]	validation-rmse:6.82162                                                    
[14]	validation-rmse:6.77766




[0]	validation-rmse:9.29625                                                     
[1]	validation-rmse:7.84095                                                     
[2]	validation-rmse:7.16379                                                     
[3]	validation-rmse:6.84425                                                     
[4]	validation-rmse:6.69210                                                     
[5]	validation-rmse:6.61225                                                     
[6]	validation-rmse:6.56589                                                     
[7]	validation-rmse:6.53812                                                     
[8]	validation-rmse:6.51976                                                     
[9]	validation-rmse:6.51003                                                     
[10]	validation-rmse:6.50275                                                    
[11]	validation-rmse:6.50237                                                    
[12]	validation-rmse:6.49590




[0]	validation-rmse:11.46751                                                    
[1]	validation-rmse:10.80979                                                    
[2]	validation-rmse:10.23236                                                    
[3]	validation-rmse:9.72599                                                     
[4]	validation-rmse:9.28394                                                     
[5]	validation-rmse:8.89910                                                     
[6]	validation-rmse:8.56551                                                     
[7]	validation-rmse:8.27690                                                     
[8]	validation-rmse:8.02692                                                     
[9]	validation-rmse:7.81258                                                     
[10]	validation-rmse:7.62781                                                    
[11]	validation-rmse:7.46871                                                    
[12]	validation-rmse:7.33306




[0]	validation-rmse:11.64211                                                    
[1]	validation-rmse:11.12234                                                    
[2]	validation-rmse:10.65056                                                    
[3]	validation-rmse:10.22311                                                    
[4]	validation-rmse:9.83708                                                     
[5]	validation-rmse:9.48816                                                     
[6]	validation-rmse:9.17336                                                     
[7]	validation-rmse:8.88994                                                     
[8]	validation-rmse:8.63663                                                     
[9]	validation-rmse:8.40887                                                     
[10]	validation-rmse:8.20458                                                    
[11]	validation-rmse:8.02350                                                    
[12]	validation-rmse:7.86099




[0]	validation-rmse:10.61518                                                    
[1]	validation-rmse:9.45016                                                     
[2]	validation-rmse:8.60538                                                     
[3]	validation-rmse:8.00082                                                     
[4]	validation-rmse:7.57715                                                     
[5]	validation-rmse:7.28045                                                     
[6]	validation-rmse:7.07706                                                     
[7]	validation-rmse:6.93865                                                     
[8]	validation-rmse:6.83677                                                     
[9]	validation-rmse:6.76017                                                     
[10]	validation-rmse:6.70555                                                    
[11]	validation-rmse:6.66689                                                    
[12]	validation-rmse:6.63451




[0]	validation-rmse:11.35579                                                    
[1]	validation-rmse:10.62040                                                    
[2]	validation-rmse:9.99010                                                     
[3]	validation-rmse:9.44897                                                     
[4]	validation-rmse:8.99416                                                     
[5]	validation-rmse:8.60758                                                     
[6]	validation-rmse:8.28127                                                     
[7]	validation-rmse:8.00979                                                     
[8]	validation-rmse:7.77767                                                     
[9]	validation-rmse:7.58293                                                     
[10]	validation-rmse:7.42671                                                    
[11]	validation-rmse:7.28877                                                    
[12]	validation-rmse:7.17863




[1]	validation-rmse:10.30850                                                    
[2]	validation-rmse:9.59755                                                     
[3]	validation-rmse:9.01625                                                     
[4]	validation-rmse:8.54525                                                     
[5]	validation-rmse:8.16453                                                     
[6]	validation-rmse:7.85865                                                     
[7]	validation-rmse:7.61398                                                     
[8]	validation-rmse:7.41954                                                     
[9]	validation-rmse:7.26345                                                     
[10]	validation-rmse:7.13853                                                    
[11]	validation-rmse:7.03695                                                    
[12]	validation-rmse:6.95556                                                    
[13]	validation-rmse:6.88703




[0]	validation-rmse:10.42181                                                    
[1]	validation-rmse:9.15442                                                     
[2]	validation-rmse:8.27881                                                     
[3]	validation-rmse:7.68692                                                     
[4]	validation-rmse:7.29048                                                     
[5]	validation-rmse:7.02757                                                     
[6]	validation-rmse:6.85147                                                     
[7]	validation-rmse:6.73439                                                     
[8]	validation-rmse:6.65297                                                     
[9]	validation-rmse:6.59611                                                     
[10]	validation-rmse:6.55495                                                    
[11]	validation-rmse:6.52587                                                    
[12]	validation-rmse:6.50360




[0]	validation-rmse:10.86879                                                    
[1]	validation-rmse:9.81500                                                     
[2]	validation-rmse:8.99998                                                     
[3]	validation-rmse:8.37601                                                     
[4]	validation-rmse:7.90152                                                     
[5]	validation-rmse:7.54613                                                     
[6]	validation-rmse:7.28010                                                     
[7]	validation-rmse:7.08288                                                     
[8]	validation-rmse:6.93444                                                     
[9]	validation-rmse:6.82412                                                     
[10]	validation-rmse:6.73810                                                    
[11]	validation-rmse:6.67174                                                    
[12]	validation-rmse:6.62220




[0]	validation-rmse:9.43320                                                     
[1]	validation-rmse:7.97009                                                     
[2]	validation-rmse:7.25048                                                     
[3]	validation-rmse:6.89829                                                     
[4]	validation-rmse:6.72200                                                     
[5]	validation-rmse:6.62587                                                     
[6]	validation-rmse:6.57163                                                     
[7]	validation-rmse:6.53651                                                     
[8]	validation-rmse:6.51460                                                     
[9]	validation-rmse:6.49945                                                     
[10]	validation-rmse:6.49142                                                    
[11]	validation-rmse:6.48558                                                    
[12]	validation-rmse:6.48175




[1]	validation-rmse:8.89250                                                     
[2]	validation-rmse:8.05350                                                     
[3]	validation-rmse:7.53176                                                     
[4]	validation-rmse:7.20975                                                     
[5]	validation-rmse:7.00988                                                     
[6]	validation-rmse:6.87959                                                     
[7]	validation-rmse:6.79947                                                     
[8]	validation-rmse:6.74201                                                     
[9]	validation-rmse:6.70706                                                     
[10]	validation-rmse:6.68432                                                    
[11]	validation-rmse:6.66764                                                    
[12]	validation-rmse:6.65613                                                    
[13]	validation-rmse:6.64758




[6]	validation-rmse:9.73418                                                     
[7]	validation-rmse:9.48019                                                     
[8]	validation-rmse:9.24687                                                     
[9]	validation-rmse:9.03301                                                     
[10]	validation-rmse:8.83618                                                    
[11]	validation-rmse:8.65583                                                    
[12]	validation-rmse:8.49108                                                    
[13]	validation-rmse:8.34009                                                    
[14]	validation-rmse:8.20200                                                    
[15]	validation-rmse:8.07619                                                    
[16]	validation-rmse:7.96154                                                    
[17]	validation-rmse:7.85692                                                    
[18]	validation-rmse:7.76209




[3]	validation-rmse:6.85794                                                     
[4]	validation-rmse:6.78228                                                     
[5]	validation-rmse:6.74262                                                     
[6]	validation-rmse:6.72489                                                     
[7]	validation-rmse:6.71350                                                     
[8]	validation-rmse:6.70751                                                     
[9]	validation-rmse:6.70151                                                     
[10]	validation-rmse:6.69909                                                    
[11]	validation-rmse:6.69666                                                    
[12]	validation-rmse:6.69341                                                    
[13]	validation-rmse:6.68919                                                    
[14]	validation-rmse:6.68598                                                    
[15]	validation-rmse:6.68087




[0]	validation-rmse:10.77945                                                    
[1]	validation-rmse:9.68425                                                     
[2]	validation-rmse:8.85775                                                     
[3]	validation-rmse:8.24207                                                     
[4]	validation-rmse:7.78946                                                     
[5]	validation-rmse:7.45855                                                     
[6]	validation-rmse:7.21766                                                     
[7]	validation-rmse:7.04008                                                     
[8]	validation-rmse:6.91188                                                     
[9]	validation-rmse:6.81653                                                     
[10]	validation-rmse:6.74697                                                    
[11]	validation-rmse:6.69572                                                    
[12]	validation-rmse:6.65369

In [39]:
# Hyperparameters for the final XGBoost model (presumably the best trial from
# the tuning runs logged above -- TODO confirm against the MLflow UI).
best_params = {
    'learning_rate': 0.11558775676492819,
    'max_depth': 49,
    'min_child_weight': 1.3913919084738615,
    # 'reg:linear' is a deprecated alias; 'reg:squarederror' is the current
    # name for the same squared-error regression objective.
    'objective': 'reg:squarederror',
    'reg_alpha': 0.35940520308672946,
    'reg_lambda': 0.09603619965585856,
    'seed': 42
}

# Let MLflow record the training call automatically. When no run is active,
# autologging opens a run of its own for the xgb.train call below.
mlflow.xgboost.autolog()

# Train for at most 1000 boosting rounds, stopping early once the validation
# metric has not improved for 50 consecutive rounds.
booster = xgb.train(
    params=best_params,
    dtrain=train,
    num_boost_round=1000,
    evals=[(valid, 'validation')],
    early_stopping_rounds=50
)


2024/03/18 10:49:45 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '41f3039d3ed5403f9d30a0186672b6fb', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current xgboost workflow


[0]	validation-rmse:11.28055
[1]	validation-rmse:10.48706
[2]	validation-rmse:9.81640
[3]	validation-rmse:9.25292
[4]	validation-rmse:8.77872
[5]	validation-rmse:8.38541
[6]	validation-rmse:8.05802
[7]	validation-rmse:7.78932
[8]	validation-rmse:7.56436
[9]	validation-rmse:7.38109
[10]	validation-rmse:7.23066
[11]	validation-rmse:7.10681
[12]	validation-rmse:7.00336
[13]	validation-rmse:6.91772
[14]	validation-rmse:6.84847
[15]	validation-rmse:6.78996
[16]	validation-rmse:6.74116
[17]	validation-rmse:6.70106
[18]	validation-rmse:6.66667
[19]	validation-rmse:6.63853
[20]	validation-rmse:6.61374
[21]	validation-rmse:6.59305
[22]	validation-rmse:6.57537
[23]	validation-rmse:6.55982
[24]	validation-rmse:6.54666
[25]	validation-rmse:6.53387
[26]	validation-rmse:6.52463
[27]	validation-rmse:6.51517
[28]	validation-rmse:6.50701
[29]	validation-rmse:6.50019
[30]	validation-rmse:6.49398
[31]	validation-rmse:6.48864
[32]	validation-rmse:6.48330
[33]	validation-rmse:6.47881
[34]	validation-rmse:6



In [38]:
# Everything in this run is logged by hand. Switch XGBoost autologging off so
# that a previously executed mlflow.xgboost.autolog() does not additionally
# log params, metrics and a second copy of the model into the same run.
mlflow.xgboost.autolog(disable=True)

with mlflow.start_run():

    # Wrap the feature matrices and targets in XGBoost's DMatrix container.
    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    best_params = {
        'learning_rate': 0.11558775676492819,
        'max_depth': 49,
        'min_child_weight': 1.3913919084738615,
        # 'reg:linear' is a deprecated alias; 'reg:squarederror' is the
        # current name for the same squared-error regression objective.
        'objective': 'reg:squarederror',
        'reg_alpha': 0.35940520308672946,
        'reg_lambda': 0.09603619965585856,
        'seed': 42
    }

    mlflow.log_params(best_params)

    # Train for at most 1000 boosting rounds, stopping early once the
    # validation metric has not improved for 50 consecutive rounds.
    booster = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, 'validation')],
        early_stopping_rounds=50
    )

    y_pred = booster.predict(valid)
    # Take the square root explicitly instead of passing `squared=False`:
    # that flag is deprecated in scikit-learn 1.4 and removed in 1.6, while
    # this form works on every version.
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    mlflow.log_metric("rmse", rmse)

    # Persist the fitted DictVectorizer next to the model so the same feature
    # encoding can be re-applied at inference time.
    with open("models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)
    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")



[0]	validation-rmse:11.28055
[1]	validation-rmse:10.48706
[2]	validation-rmse:9.81640
[3]	validation-rmse:9.25292
[4]	validation-rmse:8.77872
[5]	validation-rmse:8.38541
[6]	validation-rmse:8.05802
[7]	validation-rmse:7.78932
[8]	validation-rmse:7.56436
[9]	validation-rmse:7.38109
[10]	validation-rmse:7.23066
[11]	validation-rmse:7.10681
[12]	validation-rmse:7.00336
[13]	validation-rmse:6.91772
[14]	validation-rmse:6.84847
[15]	validation-rmse:6.78996
[16]	validation-rmse:6.74116
[17]	validation-rmse:6.70106
[18]	validation-rmse:6.66667
[19]	validation-rmse:6.63853
[20]	validation-rmse:6.61374
[21]	validation-rmse:6.59305
[22]	validation-rmse:6.57537
[23]	validation-rmse:6.55982
[24]	validation-rmse:6.54666
[25]	validation-rmse:6.53387
[26]	validation-rmse:6.52463
[27]	validation-rmse:6.51517
[28]	validation-rmse:6.50701
[29]	validation-rmse:6.50019
[30]	validation-rmse:6.49398
[31]	validation-rmse:6.48864
[32]	validation-rmse:6.48330
[33]	validation-rmse:6.47881
[34]	validation-rmse:6



In [None]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

# Autologging records each estimator's parameters and fitted model; the
# validation RMSE and data lineage are logged explicitly inside the loop.
mlflow.sklearn.autolog()

# One MLflow run per estimator class, each fitted with default hyperparameters
# on the same train/validation matrices.
for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run():

        # Record the files the data was actually read from: the notebook loads
        # the Parquet files, not CSVs.
        mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.parquet")
        mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.parquet")
        # Requires models/preprocessor.b to already exist on disk (it is
        # written by the manual XGBoost logging cell).
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        # Take the square root explicitly instead of passing `squared=False`:
        # that flag is deprecated in scikit-learn 1.4 and removed in 1.6.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)
        



In [40]:
# URI of the model artifact stored by the autologged xgb.train run.
# NOTE: the run ID is specific to this tracking database; re-running the
# training cell creates a new run with a different ID.
logged_model = 'runs:/41f3039d3ed5403f9d30a0186672b6fb/model'

# Load the model in its native XGBoost flavor (not as a generic PyFuncModel),
# so it can predict directly on a DMatrix.
loaded_model = mlflow.xgboost.load_model(logged_model)

# Predict on the validation DMatrix and preview the first 20 predictions.
y = loaded_model.predict(valid)
y[:20]



array([14.293808 ,  6.8635406, 14.469265 , 24.46155  ,  9.345045 ,
       17.146893 , 10.557908 ,  8.365253 ,  9.519408 , 18.608387 ,
       11.337758 , 16.273603 ,  8.86146  ,  5.844951 , 17.037682 ,
       10.301903 , 10.891236 ,  9.453916 ,  5.4378595,  8.083388 ],
      dtype=float32)