In [1]:
!python -V

Python 3.9.19


In [None]:
# mlflow ui --backend-store-uri sqlite:///mlflow.db

In [2]:
import pandas as pd

In [3]:
import pickle

In [4]:
import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error

In [6]:
import mlflow


mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_experiment("nyc-taxi-experiment")

<Experiment: artifact_location='/workspaces/mlops-zoomcamp/02-experiment-tracking/mlruns/1', creation_time=1718809930065, experiment_id='1', last_update_time=1718809930065, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [7]:
def read_dataframe(filename):
    if filename.endswith('.csv'):
        df = pd.read_csv(filename)
        df.lpep_dropoff_datetime = pd.to_datetime(df.lpep_dropoff_datetime)
        df.lpep_pickup_datetime = pd.to_datetime(df.lpep_pickup_datetime)
    elif filename.endswith('.parquet'):
        df = pd.read_parquet(filename)

    df['duration'] = df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    df.duration = df.duration.apply(lambda td: td.total_seconds() / 60)

    df = df[(df.duration >= 1) & (df.duration <= 60)]

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    
    return df

In [9]:
df_train = read_dataframe('https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2024-01.parquet')
df_val = read_dataframe('https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2024-02.parquet')

In [10]:
len(df_train), len(df_val)

(54373, 51497)

In [11]:
df_train['PU_DO'] = df_train['PULocationID'] + '_' + df_train['DOLocationID']
df_val['PU_DO'] = df_val['PULocationID'] + '_' + df_val['DOLocationID']

In [12]:
categorical = ['PU_DO'] #'PULocationID', 'DOLocationID']
numerical = ['trip_distance']

dv = DictVectorizer()

train_dicts = df_train[categorical + numerical].to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

val_dicts = df_val[categorical + numerical].to_dict(orient='records')
X_val = dv.transform(val_dicts)

In [13]:
target = 'duration'
y_train = df_train[target].values
y_val = df_val[target].values

In [14]:
lr = LinearRegression()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)

mean_squared_error(y_val, y_pred, squared=False)



5.9947992164797

In [19]:
with open('models/lin_reg.bin', 'wb') as f_out:
    pickle.dump((dv, lr), f_out)

In [15]:
with mlflow.start_run():

    mlflow.set_tag("developer", "cristian")

    mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
    mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)
    lr = Lasso(alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    mlflow.log_artifact(local_path="models/lin_reg.bin", artifact_path="models_pickle")



In [16]:
import xgboost as xgb

In [17]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [18]:
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

In [19]:
def objective(params):
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [25]:
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:linear',
    'seed': 42
}

In [26]:
best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

[0]	validation-rmse:6.83465                           
[1]	validation-rmse:5.93499                           
  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[2]	validation-rmse:5.61749                           
[3]	validation-rmse:5.50422                           
[4]	validation-rmse:5.45300                           
[5]	validation-rmse:5.43174                           
[6]	validation-rmse:5.42036                           
[7]	validation-rmse:5.41480                           
[8]	validation-rmse:5.41183                           
[9]	validation-rmse:5.40905                           
[10]	validation-rmse:5.40652                          
[11]	validation-rmse:5.40190                          
[12]	validation-rmse:5.39667                          
[13]	validation-rmse:5.39464                          
[14]	validation-rmse:5.39119                          
[15]	validation-rmse:5.39131                          
[16]	validation-rmse:5.38688                          
[17]	validation-rmse:5.38453                          
[18]	validation-rmse:5.38056                          
[19]	validation-rmse:5.37865                          
[20]	valid





[0]	validation-rmse:5.73920                                                    
[1]	validation-rmse:5.43849                                                    
[2]	validation-rmse:5.41168                                                    
[3]	validation-rmse:5.40225                                                    
[4]	validation-rmse:5.39862                                                    
[5]	validation-rmse:5.39912                                                    
[6]	validation-rmse:5.40038                                                    
[7]	validation-rmse:5.39660                                                    
[8]	validation-rmse:5.39163                                                    
[9]	validation-rmse:5.39122                                                    
[10]	validation-rmse:5.39059                                                   
[11]	validation-rmse:5.37241                                                   
[12]	validation-rmse:5.37322            





[0]	validation-rmse:7.64064                                                    
[1]	validation-rmse:6.69570                                                    
[2]	validation-rmse:6.11129                                                    
[3]	validation-rmse:5.76347                                                    
[4]	validation-rmse:5.55839                                                    
[5]	validation-rmse:5.44078                                                    
[6]	validation-rmse:5.37259                                                    
[7]	validation-rmse:5.33362                                                    
[8]	validation-rmse:5.30945                                                    
[9]	validation-rmse:5.29302                                                    
[10]	validation-rmse:5.28214                                                   
[11]	validation-rmse:5.27507                                                   
[12]	validation-rmse:5.26862            





[0]	validation-rmse:7.04831                                                    
[1]	validation-rmse:6.08971                                                    
[2]	validation-rmse:5.64982                                                    
[3]	validation-rmse:5.44841                                                    
[4]	validation-rmse:5.38574                                                    
[5]	validation-rmse:5.35264                                                    
[6]	validation-rmse:5.33773                                                    
[7]	validation-rmse:5.32526                                                    
[8]	validation-rmse:5.31839                                                    
[9]	validation-rmse:5.31179                                                    
[10]	validation-rmse:5.30970                                                   
[11]	validation-rmse:5.30097                                                   
[12]	validation-rmse:5.29491            





[0]	validation-rmse:8.44623                                                    
[1]	validation-rmse:7.88203                                                    
[2]	validation-rmse:7.40722                                                    
[3]	validation-rmse:7.00897                                                    
[4]	validation-rmse:6.67877                                                    
[5]	validation-rmse:6.40385                                                    
[6]	validation-rmse:6.17794                                                    
[7]	validation-rmse:5.99212                                                    
[8]	validation-rmse:5.84170                                                    
[9]	validation-rmse:5.71953                                                    
[10]	validation-rmse:5.61970                                                   
[11]	validation-rmse:5.54196                                                   
[12]	validation-rmse:5.47833            





[0]	validation-rmse:8.34411                                                    
[1]	validation-rmse:7.72462                                                    
[2]	validation-rmse:7.22607                                                    
[3]	validation-rmse:6.82549                                                    
[4]	validation-rmse:6.51176                                                    
[5]	validation-rmse:6.26427                                                    
[6]	validation-rmse:6.07623                                                    
[7]	validation-rmse:5.92873                                                    
[8]	validation-rmse:5.80685                                                    
[9]	validation-rmse:5.71788                                                    
[10]	validation-rmse:5.64502                                                   
[11]	validation-rmse:5.59420                                                   
[12]	validation-rmse:5.55198            





[0]	validation-rmse:8.47795                                                    
[1]	validation-rmse:7.94449                                                    
[2]	validation-rmse:7.49737                                                    
[3]	validation-rmse:7.12542                                                    
[4]	validation-rmse:6.81722                                                    
[5]	validation-rmse:6.56386                                                    
[6]	validation-rmse:6.35748                                                    
[7]	validation-rmse:6.18851                                                    
[8]	validation-rmse:6.05159                                                    
[9]	validation-rmse:5.94074                                                    
[10]	validation-rmse:5.85126                                                   
[11]	validation-rmse:5.77867                                                   
[12]	validation-rmse:5.71971            





[0]	validation-rmse:8.79191                                                    
[1]	validation-rmse:8.49645                                                    
[2]	validation-rmse:8.22558                                                    
[3]	validation-rmse:7.97618                                                    
[4]	validation-rmse:7.74668                                                    
[5]	validation-rmse:7.52991                                                    
[6]	validation-rmse:7.33542                                                    
[7]	validation-rmse:7.15895                                                    
[8]	validation-rmse:6.99126                                                    
[9]	validation-rmse:6.84365                                                    
[10]	validation-rmse:6.70727                                                   
[11]	validation-rmse:6.58196                                                   
[12]	validation-rmse:6.46238            





[0]	validation-rmse:8.33268                                                    
[1]	validation-rmse:7.70449                                                    
[2]	validation-rmse:7.20221                                                    
[3]	validation-rmse:6.80442                                                    
[4]	validation-rmse:6.49304                                                    
[5]	validation-rmse:6.25109                                                    
[6]	validation-rmse:6.06416                                                    
[7]	validation-rmse:5.92127                                                    
[8]	validation-rmse:5.81185                                                    
[9]	validation-rmse:5.72840                                                    
[10]	validation-rmse:5.66439                                                   
[11]	validation-rmse:5.61582                                                   
[12]	validation-rmse:5.57822            





[0]	validation-rmse:7.58253                                                    
[1]	validation-rmse:6.62844                                                    
[2]	validation-rmse:6.05661                                                    
[3]	validation-rmse:5.72726                                                    
[4]	validation-rmse:5.54158                                                    
[5]	validation-rmse:5.43721                                                    
[6]	validation-rmse:5.37929                                                    
[7]	validation-rmse:5.34581                                                    
[8]	validation-rmse:5.32602                                                    
[9]	validation-rmse:5.30950                                                    
[10]	validation-rmse:5.29936                                                   
[11]	validation-rmse:5.28959                                                   
[12]	validation-rmse:5.28375            





[0]	validation-rmse:6.53679                                                     
[1]	validation-rmse:5.62967                                                     
[2]	validation-rmse:5.34045                                                     
[3]	validation-rmse:5.25465                                                     
[4]	validation-rmse:5.21888                                                     
[5]	validation-rmse:5.20148                                                     
[6]	validation-rmse:5.18884                                                     
[7]	validation-rmse:5.18643                                                     
[8]	validation-rmse:5.18648                                                     
[9]	validation-rmse:5.18526                                                     
[10]	validation-rmse:5.18607                                                    
[11]	validation-rmse:5.18518                                                    
[12]	validation-rmse:5.18530





[0]	validation-rmse:8.53298                                                     
[1]	validation-rmse:8.03126                                                     
[2]	validation-rmse:7.59819                                                     
[3]	validation-rmse:7.22587                                                     
[4]	validation-rmse:6.90582                                                     
[5]	validation-rmse:6.63399                                                     
[6]	validation-rmse:6.40283                                                     
[7]	validation-rmse:6.20795                                                     
[8]	validation-rmse:6.04423                                                     
[9]	validation-rmse:5.90705                                                     
 22%|██▏       | 11/50 [07:34<27:10, 41.81s/trial, best loss: 5.138968589677634]

In [20]:
mlflow.xgboost.autolog(disable=True)

In [21]:
with mlflow.start_run():
    
    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    best_params = {
        'learning_rate': 0.09585355369315604,
        'max_depth': 30,
        'min_child_weight': 1.060597050922164,
        'objective': 'reg:linear',
        'reg_alpha': 0.018060244040060163,
        'reg_lambda': 0.011658731377413597,
        'seed': 42
    }

    mlflow.log_params(best_params)

    booster = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, 'validation')],
        early_stopping_rounds=50
    )

    y_pred = booster.predict(valid)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    with open("models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)
    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")



[0]	validation-rmse:8.55121
[1]	validation-rmse:8.06400
[2]	validation-rmse:7.64347
[3]	validation-rmse:7.28077
[4]	validation-rmse:6.96868
[5]	validation-rmse:6.70403
[6]	validation-rmse:6.47713
[7]	validation-rmse:6.28420
[8]	validation-rmse:6.12122
[9]	validation-rmse:5.98447
[10]	validation-rmse:5.87103
[11]	validation-rmse:5.77302
[12]	validation-rmse:5.68988
[13]	validation-rmse:5.62238
[14]	validation-rmse:5.56459
[15]	validation-rmse:5.51506
[16]	validation-rmse:5.47349
[17]	validation-rmse:5.43747
[18]	validation-rmse:5.40878
[19]	validation-rmse:5.38302
[20]	validation-rmse:5.36147
[21]	validation-rmse:5.34029
[22]	validation-rmse:5.32216
[23]	validation-rmse:5.30792
[24]	validation-rmse:5.29485
[25]	validation-rmse:5.28297
[26]	validation-rmse:5.27297
[27]	validation-rmse:5.26478
[28]	validation-rmse:5.25749
[29]	validation-rmse:5.25086
[30]	validation-rmse:5.24386
[31]	validation-rmse:5.23724
[32]	validation-rmse:5.23377
[33]	validation-rmse:5.23046
[34]	validation-rmse:5.2



In [23]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

mlflow.sklearn.autolog()

for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run():

        mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
        mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)
        

