In [1]:
!python -V

Python 3.9.15


In [2]:
import pandas as pd

In [3]:
import pickle

In [4]:
import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error

In [9]:
import mlflow


# Store run metadata in a local SQLite database next to the notebook, and
# group every run from this notebook under one experiment (MLflow creates the
# experiment on first use if it does not exist yet).
tracking_uri = "sqlite:///mlflow.db"
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(experiment_name="nyc-taxi-experiment")

2024/05/21 23:04:36 INFO mlflow.tracking.fluent: Experiment with name 'nyc-taxi-experiment' does not exist. Creating a new experiment.


<Experiment: artifact_location='/Users/weishanhe/my_github/mlops-zoomcamp/02-experiment-tracking/mlruns/1', creation_time=1716347076037, experiment_id='1', last_update_time=1716347076037, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [19]:
def read_dataframe(filename):
    """Load a green-taxi trip parquet file and prepare it for modelling.

    Steps:
      1. Parse the pickup/dropoff columns as datetimes.
      2. Add a ``duration`` column holding the trip length in minutes.
      3. Keep only trips lasting between 1 and 60 minutes (inclusive).
      4. Cast the pickup/dropoff location IDs to strings so they can be
         treated as categorical features.

    Parameters
    ----------
    filename : str or path-like
        Location of the parquet file, passed straight to ``pd.read_parquet``.

    Returns
    -------
    pandas.DataFrame
        An independent (copied) frame containing the filtered trips, with the
        original index labels of the surviving rows preserved.
    """
    df = pd.read_parquet(filename)

    df['lpep_dropoff_datetime'] = pd.to_datetime(df['lpep_dropoff_datetime'])
    df['lpep_pickup_datetime'] = pd.to_datetime(df['lpep_pickup_datetime'])

    # Vectorized timedelta -> minutes conversion (no per-row Python lambda).
    trip_time = df['lpep_dropoff_datetime'] - df['lpep_pickup_datetime']
    df['duration'] = trip_time.dt.total_seconds() / 60

    # Take an explicit copy: columns are assigned on the filtered frame both
    # below and by callers, and assigning onto a slice of the original frame
    # triggers SettingWithCopyWarning / ambiguous write semantics.
    df = df[(df['duration'] >= 1) & (df['duration'] <= 60)].copy()

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)

    return df

In [20]:
# January 2021 trips are used for training, February 2021 for validation.
df_train, df_val = (
    read_dataframe(parquet_path)
    for parquet_path in (
        './data/green_tripdata_2021-01.parquet',
        './data/green_tripdata_2021-02.parquet',
    )
)

In [22]:
# Number of trips left in each split after filtering.
df_train.shape[0], df_val.shape[0]

(73908, 61921)

In [23]:
# Combine pickup and dropoff zone into a single "<pickup>_<dropoff>" route
# feature on both splits.
for trips in (df_train, df_val):
    trips['PU_DO'] = trips['PULocationID'] + '_' + trips['DOLocationID']

In [24]:
# Model inputs: the combined pickup/dropoff route (categorical) and the trip
# distance (numerical).
categorical = ['PU_DO']
numerical = ['trip_distance']
feature_columns = categorical + numerical

# One dict per trip, which is the input format DictVectorizer expects.
train_dicts = df_train[feature_columns].to_dict(orient='records')
val_dicts = df_val[feature_columns].to_dict(orient='records')

# Fit the vectorizer on the training split only and reuse the fitted encoding
# for validation, so both matrices share the same feature columns.
dv = DictVectorizer()
X_train = dv.fit_transform(train_dicts)
X_val = dv.transform(val_dicts)

In [25]:
# Regression target: trip duration in minutes, as plain NumPy arrays.
target = 'duration'
y_train, y_val = df_train[target].to_numpy(), df_val[target].to_numpy()

In [26]:
# Baseline: ordinary least squares on the vectorized features.
lr = LinearRegression()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)

# Validation RMSE. Taking the square root of the MSE explicitly avoids the
# `squared=False` keyword, which scikit-learn deprecated in 1.4 and removed
# in 1.6, so this cell keeps working across scikit-learn versions.
mean_squared_error(y_val, y_pred) ** 0.5



7.758715208946364

In [28]:
# Persist the fitted vectorizer together with the model: both are needed to
# turn raw trip records into predictions later.
with open('models/lin_reg.bin', mode='wb') as model_file:
    pickle.dump((dv, lr), model_file)

In [31]:
# Track one Lasso training run in MLflow: tags, data provenance,
# hyperparameters, validation RMSE and the pickled model.
with mlflow.start_run():

    mlflow.set_tag("developer", "cristian")

    mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.parquet")
    mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.parquet")

    alpha = 0.01
    mlflow.log_param("alpha", alpha)
    lr = Lasso(alpha=alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)
    # sqrt(MSE) instead of `squared=False`, which scikit-learn deprecated in
    # 1.4 and removed in 1.6.
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    mlflow.log_metric("rmse", rmse)

    # Pickle the model trained in THIS run before logging it. Logging
    # "models/lin_reg.bin" here would attach the earlier LinearRegression
    # baseline to a run whose params and metric describe a Lasso model.
    lasso_path = "models/lasso.bin"
    with open(lasso_path, "wb") as f_out:
        pickle.dump((dv, lr), f_out)

    mlflow.log_artifact(local_path=lasso_path, artifact_path="models_pickle")



In [32]:
import xgboost as xgb

In [33]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [34]:
# Wrap the feature matrices and targets in XGBoost's DMatrix container.
train = xgb.DMatrix(data=X_train, label=y_train)
valid = xgb.DMatrix(data=X_val, label=y_val)

In [35]:
def objective(params):
    """Train one XGBoost model for a hyperopt trial and report its validation RMSE.

    Each call opens its own MLflow run, tags it as an xgboost model, logs the
    sampled hyperparameters and the resulting validation RMSE.

    Parameters
    ----------
    params : dict
        XGBoost training parameters for this trial; passed unchanged to
        ``xgb.train`` and to ``mlflow.log_params``.

    Returns
    -------
    dict
        ``{'loss': <validation RMSE>, 'status': STATUS_OK}``, the result
        format consumed by hyperopt's ``fmin``.

    Notes
    -----
    Reads the notebook-level ``train`` and ``valid`` DMatrix objects and the
    ``y_val`` target array.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            # Stop once the validation metric has not improved for 50 rounds.
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        # sqrt(MSE) instead of `squared=False`, which scikit-learn deprecated
        # in 1.4 and removed in 1.6.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [36]:
# Hyperparameter search space for XGBoost. `hp.loguniform(label, low, high)`
# samples exp(uniform(low, high)), so the bounds below are given in log space.
search_space = {
    # Integer tree depth between 4 and 100.
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),  # [exp(-3), exp(0)] ~ [0.05, 1]
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),  # [exp(-5), exp(-1)] ~ [0.0067, 0.37]
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),  # [exp(-6), exp(-1)] ~ [0.0025, 0.37]
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),  # [exp(-1), exp(3)] ~ [0.37, 20]
    # 'reg:linear' is a deprecated alias; 'reg:squarederror' is the
    # current name for the same squared-error regression objective.
    'objective': 'reg:squarederror',
    'seed': 42
}

# Run 50 TPE-guided trials; each trial calls `objective`, which logs its own
# MLflow run. `best_result` holds the best sampled hyperparameter values.
# NOTE(review): no `rstate` is passed, so the sampled trials are not
# reproducible between executions — consider seeding hyperopt as well.
best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

[0]	validation-rmse:8.77362                           
  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[1]	validation-rmse:7.41847                           
[2]	validation-rmse:6.92566                           
[3]	validation-rmse:6.73735                           
[4]	validation-rmse:6.65530                           
[5]	validation-rmse:6.61400                           
[6]	validation-rmse:6.59376                           
[7]	validation-rmse:6.58409                           
[8]	validation-rmse:6.57972                           
[9]	validation-rmse:6.57148                           
[10]	validation-rmse:6.56788                          
[11]	validation-rmse:6.56165                          
[12]	validation-rmse:6.55783                          
[13]	validation-rmse:6.55430                          
[14]	validation-rmse:6.55148                          
[15]	validation-rmse:6.54394                          
[16]	validation-rmse:6.54034                          
[17]	validation-rmse:6.53686                          
[18]	validation-rmse:6.53114                          
[19]	valid





[0]	validation-rmse:11.62635                                                    
[1]	validation-rmse:11.09542                                                    
[2]	validation-rmse:10.61662                                                    
[3]	validation-rmse:10.18503                                                    
[4]	validation-rmse:9.79628                                                     
[5]	validation-rmse:9.44875                                                     
[6]	validation-rmse:9.13577                                                     
[7]	validation-rmse:8.85573                                                     
[8]	validation-rmse:8.60896                                                     
[9]	validation-rmse:8.38698                                                     
[10]	validation-rmse:8.19095                                                    
[11]	validation-rmse:8.01326                                                    
[12]	validation-rmse:7.85999





[0]	validation-rmse:6.62924                                                     
[1]	validation-rmse:6.57259                                                     
[2]	validation-rmse:6.56982                                                     
[3]	validation-rmse:6.56166                                                     
[4]	validation-rmse:6.55130                                                     
[5]	validation-rmse:6.54182                                                     
[6]	validation-rmse:6.52252                                                     
[7]	validation-rmse:6.51587                                                     
[8]	validation-rmse:6.51013                                                     
[9]	validation-rmse:6.50956                                                     
[10]	validation-rmse:6.50444                                                    
[11]	validation-rmse:6.50012                                                    
[12]	validation-rmse:6.49427





[3]	validation-rmse:7.69975                                                     
[4]	validation-rmse:7.35679                                                     
[5]	validation-rmse:7.13909                                                     
[6]	validation-rmse:6.99870                                                     
[7]	validation-rmse:6.90664                                                     
[8]	validation-rmse:6.84589                                                     
[9]	validation-rmse:6.80734                                                     
[10]	validation-rmse:6.77985                                                    
[11]	validation-rmse:6.76234                                                    
[12]	validation-rmse:6.75008                                                    
[13]	validation-rmse:6.73930                                                    
[14]	validation-rmse:6.73242                                                    
[15]	validation-rmse:6.72667





[0]	validation-rmse:7.24425                                                     
[1]	validation-rmse:6.66354                                                     
[2]	validation-rmse:6.56423                                                     
[3]	validation-rmse:6.53752                                                     
[4]	validation-rmse:6.52349                                                     
[5]	validation-rmse:6.51839                                                     
[6]	validation-rmse:6.50681                                                     
[7]	validation-rmse:6.50164                                                     
[8]	validation-rmse:6.49320                                                     
[9]	validation-rmse:6.48793                                                     
[10]	validation-rmse:6.48191                                                    
[11]	validation-rmse:6.47748                                                    
[12]	validation-rmse:6.47107





[0]	validation-rmse:10.10366                                                    
[1]	validation-rmse:8.74568                                                     
[2]	validation-rmse:7.89354                                                     
[3]	validation-rmse:7.37796                                                     
[4]	validation-rmse:7.06458                                                     
[5]	validation-rmse:6.87371                                                     
[6]	validation-rmse:6.75584                                                     
[7]	validation-rmse:6.68034                                                     
[8]	validation-rmse:6.63189                                                     
[9]	validation-rmse:6.59855                                                     
[10]	validation-rmse:6.57467                                                    
[11]	validation-rmse:6.55780                                                    
[12]	validation-rmse:6.54446





[0]	validation-rmse:11.19680                                                    
[1]	validation-rmse:10.35101                                                    
[2]	validation-rmse:9.64826                                                     
[3]	validation-rmse:9.06861                                                     
[4]	validation-rmse:8.59472                                                     
[5]	validation-rmse:8.21338                                                     
[6]	validation-rmse:7.89565                                                     
[7]	validation-rmse:7.64659                                                     
[8]	validation-rmse:7.44445                                                     
[9]	validation-rmse:7.28359                                                     
[10]	validation-rmse:7.15391                                                    
[11]	validation-rmse:7.05068                                                    
[12]	validation-rmse:6.96806





[0]	validation-rmse:7.96044                                                     
[1]	validation-rmse:6.95347                                                     
[2]	validation-rmse:6.70610                                                     
[3]	validation-rmse:6.62341                                                     
[4]	validation-rmse:6.58802                                                     
[5]	validation-rmse:6.57446                                                     
[6]	validation-rmse:6.56648                                                     
[7]	validation-rmse:6.56122                                                     
[8]	validation-rmse:6.55377                                                     
[9]	validation-rmse:6.54620                                                     
[10]	validation-rmse:6.54086                                                    
[11]	validation-rmse:6.53526                                                    
[12]	validation-rmse:6.53342





[0]	validation-rmse:11.08889                                                    
[1]	validation-rmse:10.16740                                                    
[2]	validation-rmse:9.41901                                                     
[3]	validation-rmse:8.81600                                                     
[4]	validation-rmse:8.33371                                                     
[5]	validation-rmse:7.95010                                                     
[6]	validation-rmse:7.64756                                                     
[7]	validation-rmse:7.40796                                                     
[8]	validation-rmse:7.22084                                                     
[9]	validation-rmse:7.07327                                                     
[10]	validation-rmse:6.95741                                                    
[11]	validation-rmse:6.86629                                                    
[12]	validation-rmse:6.79390





[0]	validation-rmse:7.73093                                                     
[1]	validation-rmse:6.77778                                                     
[2]	validation-rmse:6.56756                                                     
[3]	validation-rmse:6.50485                                                     
[4]	validation-rmse:6.47293                                                     
[5]	validation-rmse:6.45762                                                     
[6]	validation-rmse:6.44891                                                     
[7]	validation-rmse:6.44199                                                     
[8]	validation-rmse:6.43523                                                     
[9]	validation-rmse:6.42971                                                     
[10]	validation-rmse:6.42239                                                    
[11]	validation-rmse:6.41643                                                    
[12]	validation-rmse:6.41305





[1]	validation-rmse:11.21987                                                     
[2]	validation-rmse:10.78592                                                     
[3]	validation-rmse:10.38983                                                     
[4]	validation-rmse:10.02882                                                     
[5]	validation-rmse:9.70031                                                      
[6]	validation-rmse:9.40188                                                      
[7]	validation-rmse:9.13162                                                      
[8]	validation-rmse:8.88660                                                      
[9]	validation-rmse:8.66552                                                      
[10]	validation-rmse:8.46556                                                     
[11]	validation-rmse:8.28557                                                     
[12]	validation-rmse:8.12330                                                     
[13]	validation-





[1]	validation-rmse:10.97700                                                     
[2]	validation-rmse:10.45806                                                     
[3]	validation-rmse:9.99778                                                      
[4]	validation-rmse:9.58988                                                      
[5]	validation-rmse:9.22982                                                      
[6]	validation-rmse:8.91278                                                      
[7]	validation-rmse:8.63440                                                      
[8]	validation-rmse:8.39024                                                      
[9]	validation-rmse:8.17662                                                      
[10]	validation-rmse:7.99027                                                     
[11]	validation-rmse:7.82788                                                     
[12]	validation-rmse:7.68568                                                     
[13]	validation-





[1]	validation-rmse:9.53996                                                      
[2]	validation-rmse:8.70600                                                      
[3]	validation-rmse:8.10570                                                      
[4]	validation-rmse:7.67774                                                      
[5]	validation-rmse:7.37737                                                      
[6]	validation-rmse:7.16742                                                      
[7]	validation-rmse:7.01735                                                      
[8]	validation-rmse:6.91203                                                      
[9]	validation-rmse:6.83677                                                      
[10]	validation-rmse:6.78102                                                     
[11]	validation-rmse:6.73982                                                     
[12]	validation-rmse:6.71014                                                     
[13]	validation-





[1]	validation-rmse:6.67583                                                      
[2]	validation-rmse:6.61947                                                      
[3]	validation-rmse:6.60049                                                      
[4]	validation-rmse:6.59073                                                      
[5]	validation-rmse:6.58334                                                      
[6]	validation-rmse:6.57485                                                      
[7]	validation-rmse:6.56760                                                      
[8]	validation-rmse:6.56410                                                      
[9]	validation-rmse:6.55867                                                      
[10]	validation-rmse:6.55274                                                     
[11]	validation-rmse:6.54349                                                     
[12]	validation-rmse:6.53463                                                     
[13]	validation-





[0]	validation-rmse:11.73292                                                     
[1]	validation-rmse:11.28975                                                     
[2]	validation-rmse:10.88146                                                     
[3]	validation-rmse:10.50592                                                     
[4]	validation-rmse:10.16067                                                     
[5]	validation-rmse:9.84368                                                      
[6]	validation-rmse:9.55293                                                      
[7]	validation-rmse:9.28700                                                      
[8]	validation-rmse:9.04392                                                      
[9]	validation-rmse:8.82222                                                      
[10]	validation-rmse:8.61982                                                     
[11]	validation-rmse:8.43638                                                     
[12]	validation-





[0]	validation-rmse:11.64783                                                     
[1]	validation-rmse:11.13350                                                     
[2]	validation-rmse:10.66645                                                     
[3]	validation-rmse:10.24338                                                     
[4]	validation-rmse:9.86045                                                      
[5]	validation-rmse:9.51356                                                      
[6]	validation-rmse:9.20123                                                      
[7]	validation-rmse:8.92062                                                      
[8]	validation-rmse:8.66795                                                      
[9]	validation-rmse:8.44145                                                      
[10]	validation-rmse:8.23856                                                     
[11]	validation-rmse:8.05713                                                     
[12]	validation-





[0]	validation-rmse:9.44922                                                      
[1]	validation-rmse:8.02242                                                      
[2]	validation-rmse:7.32574                                                      
[3]	validation-rmse:6.99370                                                      
[4]	validation-rmse:6.82752                                                      
[5]	validation-rmse:6.73944                                                      
[6]	validation-rmse:6.69214                                                      
[7]	validation-rmse:6.66308                                                      
[8]	validation-rmse:6.64226                                                      
[9]	validation-rmse:6.63203                                                      
[10]	validation-rmse:6.62071                                                     
[11]	validation-rmse:6.60941                                                     
[12]	validation-





[1]	validation-rmse:10.70559                                                     
[2]	validation-rmse:10.10097                                                     
[3]	validation-rmse:9.58035                                                      
[4]	validation-rmse:9.13449                                                      
[5]	validation-rmse:8.75461                                                      
[6]	validation-rmse:8.43122                                                      
[7]	validation-rmse:8.15761                                                      
[8]	validation-rmse:7.92628                                                      
[9]	validation-rmse:7.73138                                                      
[10]	validation-rmse:7.56747                                                     
[11]	validation-rmse:7.42995                                                     
[12]	validation-rmse:7.31460                                                     
[13]	validation-





[0]	validation-rmse:6.97999
[1]	validation-rmse:6.68010                                                      
[2]	validation-rmse:6.63805                                                      
[3]	validation-rmse:6.62614                                                      
[4]	validation-rmse:6.61668                                                      
[5]	validation-rmse:6.60924                                                      
[6]	validation-rmse:6.60128                                                      
[7]	validation-rmse:6.59387                                                      
[8]	validation-rmse:6.58859                                                      
[9]	validation-rmse:6.58315                                                      
[10]	validation-rmse:6.57813                                                     
[11]	validation-rmse:6.57476                                                     
[12]	validation-rmse:6.57102                                          





[1]	validation-rmse:7.19103                                                      
[2]	validation-rmse:6.88592                                                      
[3]	validation-rmse:6.78804                                                      
[4]	validation-rmse:6.73603                                                      
[5]	validation-rmse:6.70822                                                      
[6]	validation-rmse:6.69517                                                      
[7]	validation-rmse:6.68665                                                      
[8]	validation-rmse:6.68327                                                      
[9]	validation-rmse:6.68235                                                      
[10]	validation-rmse:6.68129                                                     
[11]	validation-rmse:6.67948                                                     
[12]	validation-rmse:6.67904                                                     
[13]	validation-





[6]	validation-rmse:6.78765                                                      
[7]	validation-rmse:6.76144                                                      
[8]	validation-rmse:6.74429                                                      
[9]	validation-rmse:6.73623                                                      
[10]	validation-rmse:6.73194                                                     
[11]	validation-rmse:6.72620                                                     
[12]	validation-rmse:6.72199                                                     
[13]	validation-rmse:6.71919                                                     
[14]	validation-rmse:6.71223                                                     
[15]	validation-rmse:6.71032                                                     
[16]	validation-rmse:6.70811                                                     
[17]	validation-rmse:6.70610                                                     
[18]	validation-





[3]	validation-rmse:6.96672                                                      
[4]	validation-rmse:6.82995                                                      
[5]	validation-rmse:6.76199                                                      
[6]	validation-rmse:6.72639                                                      
[7]	validation-rmse:6.70789                                                      
[8]	validation-rmse:6.69342                                                      
[9]	validation-rmse:6.68485                                                      
[10]	validation-rmse:6.68414                                                     
[11]	validation-rmse:6.67997                                                     
[12]	validation-rmse:6.67751                                                     
[13]	validation-rmse:6.67206                                                     
[14]	validation-rmse:6.66877                                                     
[15]	validation-





[1]	validation-rmse:7.74330                                                     
[2]	validation-rmse:7.14838                                                     
[3]	validation-rmse:6.89670                                                     
[4]	validation-rmse:6.78123                                                     
[5]	validation-rmse:6.73071                                                     
[6]	validation-rmse:6.70083                                                     
[7]	validation-rmse:6.68518                                                     
[8]	validation-rmse:6.67526                                                     
[9]	validation-rmse:6.66602                                                     
[10]	validation-rmse:6.66304                                                    
[11]	validation-rmse:6.65965                                                    
[12]	validation-rmse:6.65671                                                    
[13]	validation-rmse:6.65512





[6]	validation-rmse:6.81506                                                     
[7]	validation-rmse:6.80251                                                     
[8]	validation-rmse:6.78432                                                     
[9]	validation-rmse:6.77752                                                     
[10]	validation-rmse:6.77439                                                    
[11]	validation-rmse:6.77233                                                    
[12]	validation-rmse:6.76761                                                    
[13]	validation-rmse:6.76633                                                    
[14]	validation-rmse:6.76314                                                    
[15]	validation-rmse:6.76156                                                    
[16]	validation-rmse:6.75934                                                    
[17]	validation-rmse:6.75665                                                    
[18]	validation-rmse:6.75531





[1]	validation-rmse:8.54802                                                     
[2]	validation-rmse:7.73351                                                     
[3]	validation-rmse:7.26821                                                     
[4]	validation-rmse:7.00839                                                     
[5]	validation-rmse:6.85612                                                     
[6]	validation-rmse:6.76671                                                     
[7]	validation-rmse:6.71199                                                     
[8]	validation-rmse:6.67568                                                     
[9]	validation-rmse:6.65295                                                     
[10]	validation-rmse:6.63980                                                    
[11]	validation-rmse:6.62716                                                    
[12]	validation-rmse:6.61775                                                    
[13]	validation-rmse:6.61180





[0]	validation-rmse:10.74263                                                    
[1]	validation-rmse:9.63338                                                     
[2]	validation-rmse:8.80667                                                     
[3]	validation-rmse:8.20536                                                     
[4]	validation-rmse:7.76579                                                     
[5]	validation-rmse:7.44673                                                     
[6]	validation-rmse:7.21836                                                     
[7]	validation-rmse:7.05195                                                     
[8]	validation-rmse:6.93555                                                     
[9]	validation-rmse:6.85090                                                     
[10]	validation-rmse:6.77871                                                    
[11]	validation-rmse:6.72797                                                    
[12]	validation-rmse:6.69072





[0]	validation-rmse:8.52891                                                     
[1]	validation-rmse:7.21094                                                     
[2]	validation-rmse:6.77377                                                     
[3]	validation-rmse:6.61053                                                     
[4]	validation-rmse:6.54139                                                     
[5]	validation-rmse:6.50964                                                     
[6]	validation-rmse:6.49307                                                     
[7]	validation-rmse:6.48364                                                     
[8]	validation-rmse:6.47631                                                     
[9]	validation-rmse:6.47122                                                     
[10]	validation-rmse:6.46447                                                    
[11]	validation-rmse:6.45954                                                    
[12]	validation-rmse:6.45477





[2]	validation-rmse:7.51385                                                     
[3]	validation-rmse:7.12312                                                     
[4]	validation-rmse:6.92179                                                     
[5]	validation-rmse:6.81121                                                     
[6]	validation-rmse:6.74985                                                     
[7]	validation-rmse:6.71528                                                     
[8]	validation-rmse:6.68982                                                     
[9]	validation-rmse:6.67262                                                     
[10]	validation-rmse:6.66266                                                    
[11]	validation-rmse:6.65617                                                    
[12]	validation-rmse:6.65446                                                    
[13]	validation-rmse:6.65033                                                    
[14]	validation-rmse:6.64680





[1]	validation-rmse:8.37372                                                    
[2]	validation-rmse:7.60565                                                    
[3]	validation-rmse:7.19085                                                    
[4]	validation-rmse:6.96902                                                    
[5]	validation-rmse:6.84476                                                    
[6]	validation-rmse:6.77866                                                    
[7]	validation-rmse:6.73969                                                    
[8]	validation-rmse:6.71148                                                    
[9]	validation-rmse:6.69143                                                    
[10]	validation-rmse:6.67960                                                   
[11]	validation-rmse:6.67215                                                   
[12]	validation-rmse:6.67159                                                   
[13]	validation-rmse:6.66610            





[7]	validation-rmse:7.36453                                                    
[8]	validation-rmse:7.23277                                                    
[9]	validation-rmse:7.13342                                                    
[10]	validation-rmse:7.05852                                                   
[11]	validation-rmse:7.00309                                                   
[12]	validation-rmse:6.96200                                                   
[13]	validation-rmse:6.93221                                                   
[14]	validation-rmse:6.90623                                                   
[15]	validation-rmse:6.88659                                                   
[16]	validation-rmse:6.86875                                                   
[17]	validation-rmse:6.85654                                                   
[18]	validation-rmse:6.84746                                                   
[19]	validation-rmse:6.83918            





[0]	validation-rmse:10.33242                                                   
[1]	validation-rmse:9.04927                                                    
[2]	validation-rmse:8.19696                                                    
[3]	validation-rmse:7.64245                                                    
[4]	validation-rmse:7.28466                                                    
[5]	validation-rmse:7.05437                                                    
[6]	validation-rmse:6.90662                                                    
[7]	validation-rmse:6.80815                                                    
[8]	validation-rmse:6.74007                                                    
[9]	validation-rmse:6.69272                                                    
[10]	validation-rmse:6.66011                                                   
[11]	validation-rmse:6.63777                                                   
[12]	validation-rmse:6.61992            





[1]	validation-rmse:9.23019                                                     
[2]	validation-rmse:8.37526                                                     
[3]	validation-rmse:7.79761                                                     
[4]	validation-rmse:7.41027                                                     
[5]	validation-rmse:7.15264                                                     
[6]	validation-rmse:6.98034                                                     
[7]	validation-rmse:6.86663                                                     
[8]	validation-rmse:6.78806                                                     
[9]	validation-rmse:6.73141                                                     
[10]	validation-rmse:6.69252                                                    
[11]	validation-rmse:6.66358                                                    
[12]	validation-rmse:6.64209                                                    
[13]	validation-rmse:6.62592





[0]	validation-rmse:11.18990                                                    
[1]	validation-rmse:10.33579                                                    
[2]	validation-rmse:9.62749                                                     
[3]	validation-rmse:9.04243                                                     
[4]	validation-rmse:8.56436                                                     
[5]	validation-rmse:8.17528                                                     
[6]	validation-rmse:7.86122                                                     
[7]	validation-rmse:7.60696                                                     
[8]	validation-rmse:7.40292                                                     
[9]	validation-rmse:7.23871                                                     
[10]	validation-rmse:7.10661                                                    
[11]	validation-rmse:6.99959                                                    
[12]	validation-rmse:6.91329





[0]	validation-rmse:10.95195                                                    
[1]	validation-rmse:9.95420                                                     
[2]	validation-rmse:9.17110                                                     
[3]	validation-rmse:8.56402                                                     
[4]	validation-rmse:8.09976                                                     
[5]	validation-rmse:7.74186                                                     
[6]	validation-rmse:7.47321                                                     
[7]	validation-rmse:7.26668                                                     
[8]	validation-rmse:7.11222                                                     
[9]	validation-rmse:6.99362                                                     
[10]	validation-rmse:6.90080                                                    
[11]	validation-rmse:6.83121                                                    
[12]	validation-rmse:6.77630





[0]	validation-rmse:11.36826                                                    
[1]	validation-rmse:10.64000                                                    
[2]	validation-rmse:10.01471                                                    
[3]	validation-rmse:9.47913                                                     
[4]	validation-rmse:9.02332                                                     
[5]	validation-rmse:8.63776                                                     
[6]	validation-rmse:8.31271                                                     
[7]	validation-rmse:8.03582                                                     
[8]	validation-rmse:7.80409                                                     
[9]	validation-rmse:7.61327                                                     
[10]	validation-rmse:7.44779                                                    
[11]	validation-rmse:7.31310                                                    
[12]	validation-rmse:7.19905





[0]	validation-rmse:11.78932                                                    
[1]	validation-rmse:11.39472                                                    
[2]	validation-rmse:11.02673                                                    
[3]	validation-rmse:10.68497                                                    
[4]	validation-rmse:10.36782                                                    
[5]	validation-rmse:10.07306                                                    
[6]	validation-rmse:9.79911                                                     
[7]	validation-rmse:9.54680                                                     
[8]	validation-rmse:9.31229                                                     
[9]	validation-rmse:9.09607                                                     
[10]	validation-rmse:8.89589                                                    
[11]	validation-rmse:8.71243                                                    
[12]	validation-rmse:8.54176





[0]	validation-rmse:10.20965                                                    
[1]	validation-rmse:8.87866                                                     
[2]	validation-rmse:8.01805                                                     
[3]	validation-rmse:7.47246                                                     
[4]	validation-rmse:7.13123                                                     
[5]	validation-rmse:6.92075                                                     
[6]	validation-rmse:6.78403                                                     
[7]	validation-rmse:6.69735                                                     
[8]	validation-rmse:6.64144                                                     
[9]	validation-rmse:6.60221                                                     
[10]	validation-rmse:6.57533                                                    
[11]	validation-rmse:6.55339                                                    
[12]	validation-rmse:6.53696





[0]	validation-rmse:11.02357                                                    
[1]	validation-rmse:10.06528                                                    
[2]	validation-rmse:9.30202                                                     
[3]	validation-rmse:8.70450                                                     
[4]	validation-rmse:8.23382                                                     
[5]	validation-rmse:7.86679                                                     
[6]	validation-rmse:7.57558                                                     
[7]	validation-rmse:7.36092                                                     
[8]	validation-rmse:7.19255                                                     
[9]	validation-rmse:7.05943                                                     
[10]	validation-rmse:6.95541                                                    
[11]	validation-rmse:6.87642                                                    
[12]	validation-rmse:6.81329





[0]	validation-rmse:11.32915                                                    
[1]	validation-rmse:10.56977                                                    
[2]	validation-rmse:9.92016                                                     
[3]	validation-rmse:9.36831                                                     
[4]	validation-rmse:8.90108                                                     
[5]	validation-rmse:8.50788                                                     
[6]	validation-rmse:8.17709                                                     
[7]	validation-rmse:7.90046                                                     
[8]	validation-rmse:7.67049                                                     
[9]	validation-rmse:7.47904                                                     
[10]	validation-rmse:7.31919                                                    
[11]	validation-rmse:7.18714                                                    
[12]	validation-rmse:7.07748





[0]	validation-rmse:11.52726                                                    
[1]	validation-rmse:10.91919                                                    
[2]	validation-rmse:10.37852                                                    
[3]	validation-rmse:9.90340                                                     
[4]	validation-rmse:9.48224                                                     
[5]	validation-rmse:9.11476                                                     
[6]	validation-rmse:8.79006                                                     
[7]	validation-rmse:8.50946                                                     
[8]	validation-rmse:8.26286                                                     
[9]	validation-rmse:8.05061                                                     
[10]	validation-rmse:7.86292                                                    
[11]	validation-rmse:7.69822                                                    
[12]	validation-rmse:7.55790





[0]	validation-rmse:10.45640                                                    
[1]	validation-rmse:9.19751                                                     
[2]	validation-rmse:8.32128                                                     
[3]	validation-rmse:7.71506                                                     
[4]	validation-rmse:7.30825                                                     
[5]	validation-rmse:7.03240                                                     
[6]	validation-rmse:6.84962                                                     
[7]	validation-rmse:6.72378                                                     
[8]	validation-rmse:6.63852                                                     
[9]	validation-rmse:6.57891                                                     
[10]	validation-rmse:6.53723                                                    
[11]	validation-rmse:6.50590                                                    
[12]	validation-rmse:6.48295





[0]	validation-rmse:10.94741                                                    
[1]	validation-rmse:9.94176                                                     
[2]	validation-rmse:9.15117                                                     
[3]	validation-rmse:8.53556                                                     
[4]	validation-rmse:8.06103                                                     
[5]	validation-rmse:7.69927                                                     
[6]	validation-rmse:7.42439                                                     
[7]	validation-rmse:7.21441                                                     
[8]	validation-rmse:7.05169                                                     
[9]	validation-rmse:6.92851                                                     
[10]	validation-rmse:6.83492                                                    
[11]	validation-rmse:6.76335                                                    
[12]	validation-rmse:6.70753





[0]	validation-rmse:11.21315                                                    
[1]	validation-rmse:10.37747                                                    
[2]	validation-rmse:9.68155                                                     
[3]	validation-rmse:9.10660                                                     
[4]	validation-rmse:8.63275                                                     
[5]	validation-rmse:8.24775                                                     
[6]	validation-rmse:7.93387                                                     
[7]	validation-rmse:7.67893                                                     
[8]	validation-rmse:7.47405                                                     
[9]	validation-rmse:7.30726                                                     
[10]	validation-rmse:7.17366                                                    
[11]	validation-rmse:7.06565                                                    
[12]	validation-rmse:6.97710





[0]	validation-rmse:10.07800                                                    
[1]	validation-rmse:8.72338                                                     
[2]	validation-rmse:7.89630                                                     
[3]	validation-rmse:7.38809                                                     
[4]	validation-rmse:7.09157                                                     
[5]	validation-rmse:6.91179                                                     
[6]	validation-rmse:6.79381                                                     
[7]	validation-rmse:6.72494                                                     
[8]	validation-rmse:6.67584                                                     
[9]	validation-rmse:6.64665                                                     
[10]	validation-rmse:6.62408                                                    
[11]	validation-rmse:6.60713                                                    
[12]	validation-rmse:6.59387





[0]	validation-rmse:11.51991                                                    
[1]	validation-rmse:10.90372                                                    
[2]	validation-rmse:10.35781                                                    
[3]	validation-rmse:9.87614                                                     
[4]	validation-rmse:9.45132                                                     
[5]	validation-rmse:9.07858                                                     
[6]	validation-rmse:8.75213                                                     
[7]	validation-rmse:8.46699                                                     
[8]	validation-rmse:8.21843                                                     
[9]	validation-rmse:8.00208                                                     
[10]	validation-rmse:7.81344                                                    
[11]	validation-rmse:7.64993                                                    
[12]	validation-rmse:7.50826





[0]	validation-rmse:10.66595                                                    
[1]	validation-rmse:9.51904                                                     
[2]	validation-rmse:8.68523                                                     
[3]	validation-rmse:8.08520                                                     
[4]	validation-rmse:7.65359                                                     
[5]	validation-rmse:7.35165                                                     
[6]	validation-rmse:7.13844                                                     
[7]	validation-rmse:6.98578                                                     
[8]	validation-rmse:6.87707                                                     
[9]	validation-rmse:6.79901                                                     
[10]	validation-rmse:6.73694                                                    
[11]	validation-rmse:6.69636                                                    
[12]	validation-rmse:6.66124





[1]	validation-rmse:9.17506                                                     
[2]	validation-rmse:8.32793                                                     
[3]	validation-rmse:7.75963                                                     
[4]	validation-rmse:7.38658                                                     
[5]	validation-rmse:7.14197                                                     
[6]	validation-rmse:6.98145                                                     
[7]	validation-rmse:6.87299                                                     
[8]	validation-rmse:6.80048                                                     
[9]	validation-rmse:6.74906                                                     
[10]	validation-rmse:6.71411                                                    
[11]	validation-rmse:6.68912                                                    
[12]	validation-rmse:6.67102                                                    
[13]	validation-rmse:6.65635





[0]	validation-rmse:6.66857                                                     
[1]	validation-rmse:6.57446                                                     
[2]	validation-rmse:6.57107                                                     
[3]	validation-rmse:6.55956                                                     
[4]	validation-rmse:6.54364                                                     
[5]	validation-rmse:6.53919                                                     
[6]	validation-rmse:6.53240                                                     
[7]	validation-rmse:6.52165                                                     
[8]	validation-rmse:6.51091                                                     
[9]	validation-rmse:6.50512                                                     
[10]	validation-rmse:6.50709                                                    
[11]	validation-rmse:6.50252                                                    
[12]	validation-rmse:6.49595





[0]	validation-rmse:7.35441                                                     
[1]	validation-rmse:6.66522                                                     
[2]	validation-rmse:6.54051                                                     
[3]	validation-rmse:6.49485                                                     
[4]	validation-rmse:6.47543                                                     
[5]	validation-rmse:6.46528                                                     
[6]	validation-rmse:6.45826                                                     
[7]	validation-rmse:6.45239                                                     
[8]	validation-rmse:6.44554                                                     
[9]	validation-rmse:6.44074                                                     
[10]	validation-rmse:6.43348                                                    
[11]	validation-rmse:6.43261                                                    
[12]	validation-rmse:6.43083





[0]	validation-rmse:11.10595                                                    
[1]	validation-rmse:10.20216                                                    
[2]	validation-rmse:9.46894                                                     
[3]	validation-rmse:8.87834                                                     
[4]	validation-rmse:8.40753                                                     
[5]	validation-rmse:8.03202                                                     
[6]	validation-rmse:7.73733                                                     
[7]	validation-rmse:7.50137                                                     
[8]	validation-rmse:7.31847                                                     
[9]	validation-rmse:7.17722                                                     
[10]	validation-rmse:7.06060                                                    
[11]	validation-rmse:6.97109                                                    
[12]	validation-rmse:6.89774




In [38]:
# Disable MLflow's XGBoost autologging; the training cell that follows logs its
# params, metric and artifacts explicitly.
mlflow.xgboost.autolog(disable=True)

In [37]:
# Local import: only needed to make sure the artifact directory exists.
import os

# Train the final XGBoost model inside a single MLflow run and log, by hand,
# its hyperparameters, validation RMSE, the fitted DictVectorizer and the model.
with mlflow.start_run():

    # XGBoost's native data containers for the train / validation splits.
    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    # NOTE(review): presumably the best trial of the hyperparameter search
    # whose logs appear above -- confirm against the tracking UI.
    best_params = {
        'learning_rate': 0.09585355369315604,
        'max_depth': 30,
        'min_child_weight': 1.060597050922164,
        # 'reg:linear' is the deprecated alias of this objective.
        'objective': 'reg:squarederror',
        'reg_alpha': 0.018060244040060163,
        'reg_lambda': 0.011658731377413597,
        'seed': 42
    }

    mlflow.log_params(best_params)

    # Up to 1000 rounds; stops once validation RMSE has not improved for 50 rounds.
    booster = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, 'validation')],
        early_stopping_rounds=50
    )

    y_pred = booster.predict(valid)
    # `squared=False` is deprecated in scikit-learn 1.4 and removed in 1.6;
    # taking the square root of the MSE works on every version.
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    mlflow.log_metric("rmse", rmse)

    # Persist the fitted DictVectorizer next to the model so inference can
    # reproduce the feature encoding. Create the directory first so a fresh
    # checkout does not fail with FileNotFoundError.
    os.makedirs("models", exist_ok=True)
    with open("models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)
    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")



[0]	validation-rmse:11.44482
[1]	validation-rmse:10.77202
[2]	validation-rmse:10.18363
[3]	validation-rmse:9.67396
[4]	validation-rmse:9.23166
[5]	validation-rmse:8.84808
[6]	validation-rmse:8.51883
[7]	validation-rmse:8.23597
[8]	validation-rmse:7.99320
[9]	validation-rmse:7.78709
[10]	validation-rmse:7.61022
[11]	validation-rmse:7.45952
[12]	validation-rmse:7.33049
[13]	validation-rmse:7.22098
[14]	validation-rmse:7.12713
[15]	validation-rmse:7.04752
[16]	validation-rmse:6.98005
[17]	validation-rmse:6.92232
[18]	validation-rmse:6.87112
[19]	validation-rmse:6.82740
[20]	validation-rmse:6.78995
[21]	validation-rmse:6.75792
[22]	validation-rmse:6.72994
[23]	validation-rmse:6.70547
[24]	validation-rmse:6.68390
[25]	validation-rmse:6.66421
[26]	validation-rmse:6.64806
[27]	validation-rmse:6.63280
[28]	validation-rmse:6.61924
[29]	validation-rmse:6.60773
[30]	validation-rmse:6.59777
[31]	validation-rmse:6.58875
[32]	validation-rmse:6.58107
[33]	validation-rmse:6.57217
[34]	validation-rmse:



In [23]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

# Enable MLflow autologging for scikit-learn estimators.
mlflow.sklearn.autolog()

# Fit each estimator class with its default settings, one MLflow run per class,
# and log the validation RMSE under the same "rmse" key used by the other runs.
for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run():

        # The notebook reads the parquet files (see read_dataframe calls),
        # so log those paths rather than non-existent CSV ones.
        mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.parquet")
        mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.parquet")
        # Requires models/preprocessor.b to have been written by the XGBoost training cell.
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        # `squared=False` is deprecated in scikit-learn 1.4 and removed in 1.6;
        # the square root of the MSE is equivalent on every version.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)
        



In [40]:
# URI of the "models_mlflow" artifact of one specific run.
# NOTE(review): the run id is hard-coded, so this only resolves against a
# tracking store that contains that run.
logged_model = 'runs:/95a9e0645b4548f4a5bac4a30539a699/models_mlflow'

# Load model as a PyFuncModel.
loaded_model = mlflow.pyfunc.load_model(logged_model)



In [41]:
# Display the wrapper's repr (artifact path, flavor and run id).
loaded_model

mlflow.pyfunc.loaded_model:
  artifact_path: models_mlflow
  flavor: mlflow.xgboost
  run_id: 95a9e0645b4548f4a5bac4a30539a699

In [42]:
# Load the same artifact through the XGBoost flavor instead of the generic pyfunc wrapper.
xgboost_model = mlflow.xgboost.load_model(logged_model)



In [43]:
# Display the loaded object; the recorded output shows an xgboost.core.Booster.
xgboost_model

<xgboost.core.Booster at 0x33adda970>

In [44]:
# `valid` is the validation DMatrix built inside the XGBoost training cell,
# so that cell must have been executed in this kernel first.
y_pred = xgboost_model.predict(valid)

In [45]:
# Peek at the first ten predictions.
y_pred[:10]

array([14.782765 ,  7.184751 , 15.971323 , 24.328938 ,  9.559302 ,
       17.115105 , 11.6522455,  8.688133 ,  8.962229 , 18.982166 ],
      dtype=float32)

In [46]:
from mlflow.tracking import MlflowClient

# Same SQLite URI that was passed to mlflow.set_tracking_uri at the top of the notebook.
MLFLOW_TRACKING_URI = "sqlite:///mlflow.db"

# Client used by the following cells for run search and model-registry calls.
client = MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)

In [54]:
from mlflow.entities import ViewType

# Up to five active runs of experiment "1" whose logged "rmse" metric is
# below 6.8, ordered from lowest to highest rmse.
runs = client.search_runs(
    # The API documents a list of experiment IDs; pass one explicitly instead
    # of relying on a bare string being accepted.
    experiment_ids=["1"],
    filter_string="metrics.rmse < 6.8",
    run_view_type=ViewType.ACTIVE_ONLY,
    max_results=5,
    order_by=["metrics.rmse ASC"],
)

In [55]:
# Print the id and rmse (four decimals) of every run returned by the search.
for run in runs:
    print(f"run id: {run.info.run_id}, rmse: {run.data.metrics['rmse']:.4f}")

run id: d786efd9650646dd8ff936eaf5b8c9fe, rmse: 6.3095
run id: 1b235c38732f408d8864e750bc02fb11, rmse: 6.3104
run id: aeec77ccda164cec98b0b60639db3aa5, rmse: 6.3118
run id: b1f3eab4b34a4cb0abb1f61400eb67a1, rmse: 6.3138
run id: 0c99ae2f785e48ee8faa877664ed11c0, rmse: 6.3144


In [56]:
import mlflow

# Point the module-level mlflow API at the same URI the `client` object uses.
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

In [58]:
# Run with the lowest rmse in the search output printed above.
run_id = "d786efd9650646dd8ff936eaf5b8c9fe"
# The model is taken from the run's "model" artifact directory.
model_uri = f"runs:/{run_id}/model"
# Per the recorded output, the name already existed, so this call added a new version.
mlflow.register_model(model_uri = model_uri, name = "nyc-taxi-regressor")

Registered model 'nyc-taxi-regressor' already exists. Creating a new version of this model...
Created version '3' of model 'nyc-taxi-regressor'.


<ModelVersion: aliases=[], creation_timestamp=1716606852607, current_stage='None', description=None, last_updated_timestamp=1716606852607, name='nyc-taxi-regressor', run_id='d786efd9650646dd8ff936eaf5b8c9fe', run_link=None, source='/Users/weishanhe/my_github/mlops-zoomcamp/02-experiment-tracking/mlruns/1/d786efd9650646dd8ff936eaf5b8c9fe/artifacts/model', status='READY', status_message=None, tags={}, user_id=None, version=3>

In [61]:
# List the latest registered version(s) of the model with their current stage.
# This cell only inspects the registry; the stage change happens in a later cell.
# NOTE(review): the recorded output echoes the get_latest_versions call in a
# warning -- presumably MLflow's deprecation of model stages; confirm.
model_name = "nyc-taxi-regressor"
latest_versions = client.get_latest_versions(name = model_name)

for version in latest_versions:
    print(f"version: {version.version}, stage: {version.current_stage}")

version: 3, stage: None


  latest_versions = client.get_latest_versions(name = model_name)


In [63]:
# Move version 3 of the registered model (the version reported by the listing
# above) to the "Staging" stage. Both values are reused by the description update below.
# NOTE(review): the recorded output echoes this call in a warning -- presumably
# MLflow's deprecation of model stages; confirm before building on this API.
model_version = 3
new_stage = "Staging"
client.transition_model_version_stage(
    name = model_name,
    version = model_version,
    stage = new_stage,
    archive_existing_versions = False
)

  client.transition_model_version_stage(


<ModelVersion: aliases=[], creation_timestamp=1716606852607, current_stage='Staging', description=None, last_updated_timestamp=1716607290565, name='nyc-taxi-regressor', run_id='d786efd9650646dd8ff936eaf5b8c9fe', run_link=None, source='/Users/weishanhe/my_github/mlops-zoomcamp/02-experiment-tracking/mlruns/1/d786efd9650646dd8ff936eaf5b8c9fe/artifacts/model', status='READY', status_message=None, tags={}, user_id=None, version=3>

In [64]:
from datetime import datetime

# Record on the model version when it was moved to its new stage
# (local date, formatted YYYY-MM-DD).
date = datetime.today().strftime('%Y-%m-%d')
client.update_model_version(
    name=model_name,
    # Use the same variable the description interpolates, so the updated
    # version and the text describing it cannot drift apart.
    version=model_version,
    description=f"The model version {model_version} was transitioned to {new_stage} on {date}"
)

<ModelVersion: aliases=[], creation_timestamp=1716606852607, current_stage='Staging', description='The model version 3 was transitioned to Staging on 2024-05-24', last_updated_timestamp=1716607359386, name='nyc-taxi-regressor', run_id='d786efd9650646dd8ff936eaf5b8c9fe', run_link=None, source='/Users/weishanhe/my_github/mlops-zoomcamp/02-experiment-tracking/mlruns/1/d786efd9650646dd8ff936eaf5b8c9fe/artifacts/model', status='READY', status_message=None, tags={}, user_id=None, version=3>