In [12]:
# Import Libraries
import os

import mlflow
import pandas as pd
import xgboost
from hyperopt import STATUS_OK, Trials, fmin, hp, tpe
from hyperopt.pyll import scope
from sklearn.feature_extraction import DictVectorizer
from sklearn.metrics import mean_squared_error


In [13]:
# defining a function to quickly read and prepare data
def read_dataframe(filename: str) -> pd.DataFrame:
    """Load one month of green-taxi trips and derive the model features.

    Parameters
    ----------
    filename : str
        Path to a ``.csv`` or ``.parquet`` file containing at least the columns
        ``lpep_pickup_datetime``, ``lpep_dropoff_datetime``, ``PULocationID``,
        ``DOLocationID`` and ``VendorID``.

    Returns
    -------
    pandas.DataFrame
        Only trips whose duration is between 1 and 60 minutes (inclusive), with
        added columns ``duration`` (minutes, float), ``hour``, ``dayofweek`` and
        ``PU_DO``. ``PULocationID``, ``DOLocationID``, ``hour``, ``dayofweek``
        and ``VendorID`` are converted to ``str``.

    Raises
    ------
    ValueError
        If ``filename`` ends with neither ``.csv`` nor ``.parquet``.
    """
    if filename.endswith('.csv'):
        df = pd.read_csv(filename)
        # CSV carries no datetime type, so the timestamps must be parsed.
        df['lpep_dropoff_datetime'] = pd.to_datetime(df['lpep_dropoff_datetime'])
        df['lpep_pickup_datetime'] = pd.to_datetime(df['lpep_pickup_datetime'])
    elif filename.endswith('.parquet'):
        df = pd.read_parquet(filename)
    else:
        # Fail fast with a clear message instead of an UnboundLocalError on `df`.
        raise ValueError(
            f"unsupported file type (expected .csv or .parquet): {filename!r}"
        )

    # Trip duration in minutes (vectorized; stays float even for empty input).
    df['duration'] = (
        df['lpep_dropoff_datetime'] - df['lpep_pickup_datetime']
    ).dt.total_seconds() / 60
    # Keep plausible trips only. `.copy()` makes the frame independent so the
    # column assignments below do not write into a filtered view.
    df = df[df['duration'].between(1, 60)].copy()

    df['hour'] = df['lpep_pickup_datetime'].dt.hour
    df['dayofweek'] = df['lpep_pickup_datetime'].dt.day_of_week

    categorical = ['PULocationID', 'DOLocationID', 'hour', 'dayofweek', 'VendorID']
    df[categorical] = df[categorical].astype(str)
    # Pickup/dropoff pair as a single categorical feature.
    df['PU_DO'] = df['PULocationID'] + '_' + df['DOLocationID']

    return df

In [14]:
# Reading and preparing the dataset: January 2021 trips for training,
# February 2021 trips for validation.
# The data directory can be overridden with the NYC_TAXI_DATA_DIR environment
# variable; the default is the location this notebook was written against.
DATA_DIR = os.environ.get('NYC_TAXI_DATA_DIR', '/home/ubuntu/data')
df_train = read_dataframe(os.path.join(DATA_DIR, 'green_tripdata_2021-01.parquet'))
df_val = read_dataframe(os.path.join(DATA_DIR, 'green_tripdata_2021-02.parquet'))

In [15]:
# Store MLflow runs in a local SQLite database and group them under one
# experiment; the Experiment object is the cell's displayed output.
mlflow.set_tracking_uri(uri="sqlite:///mlflow.db")
mlflow.set_experiment(experiment_name="nyc-taxi-experiment-xgboost")

<Experiment: artifact_location='./mlruns/2', experiment_id='2', lifecycle_stage='active', name='nyc-taxi-experiment-xgboost', tags={}>

In [16]:
# Feature definitions. The combined pickup/dropoff pair 'PU_DO' is used in
# place of the separate 'PULocationID' / 'DOLocationID' columns.
categorical = ['PU_DO', 'hour', 'dayofweek', 'VendorID']
numerical = ['trip_distance']

# One dict per trip, restricted to the model features.
train_dicts = df_train.loc[:, categorical + numerical].to_dict('records')
val_dicts = df_val.loc[:, categorical + numerical].to_dict('records')

# One-hot encode the string-valued features; the vocabulary is learned from
# the training month only and reused unchanged for validation.
dv = DictVectorizer()
X_train = dv.fit_transform(train_dicts)
X_val = dv.transform(val_dicts)

In [17]:
# Regression target: trip duration in minutes, extracted as NumPy arrays.
target = 'duration'
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

In [18]:
# Wrap the encoded feature matrices and their targets in XGBoost's native
# DMatrix format for xgboost.train.
training_data = xgboost.DMatrix(data=X_train, label=y_train)
valid = xgboost.DMatrix(data=X_val, label=y_val)

In [19]:
def objective(params):
    """Hyperopt objective: train one XGBoost model and log it to MLflow.

    Each call opens its own MLflow run, logs ``params`` and the validation
    RMSE, and returns that RMSE as the loss for ``fmin`` to minimise.

    Reads the notebook globals ``training_data`` and ``valid`` (DMatrix
    objects) and ``y_val`` (validation targets).

    Parameters
    ----------
    params : dict
        XGBoost booster parameters sampled from ``search_space``.

    Returns
    -------
    dict
        ``{'loss': rmse, 'status': STATUS_OK}`` in the shape hyperopt expects.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        model = xgboost.train(
            params=params,
            dtrain=training_data,
            num_boost_round=100,
            evals=[(valid, "validation")],
            early_stopping_rounds=20,
        )

        # xgboost.train returns the booster from the *last* round, which with
        # early stopping can be up to `early_stopping_rounds` past the best one.
        # Predict with the best iteration so the loss matches early stopping.
        y_pred = model.predict(valid, iteration_range=(0, model.best_iteration + 1))
        # sqrt of MSE rather than `squared=False`, which newer scikit-learn removed.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

# Hyperopt search space. The hp.* entries are sampled per trial; 'objective'
# and 'seed' are constants passed through to every trial unchanged.
search_space = dict(
    # quniform yields floats; scope.int casts the sample to an int.
    max_depth=scope.int(hp.quniform('max_depth', 4, 100, 1)),
    # loguniform(label, low, high) samples exp(uniform(low, high)).
    learning_rate=hp.loguniform('learning_rate', -3, 0),
    reg_alpha=hp.loguniform('reg_alpha', -6, 1),
    reg_lambda=hp.loguniform('reg_lambda', -5, 1),
    min_child_weight=hp.loguniform('min_child_weight', -1, 3),
    objective='reg:squarederror',
    seed=42,
)

# Run 10 TPE trials; every trial logs its own MLflow run via `objective`.
# NOTE(review): no `rstate` is passed, so the search itself is unseeded and
# the sampled trials differ between runs.
best_result = fmin(
    objective,
    search_space,
    algo=tpe.suggest,
    max_evals=10,
    trials=Trials(),
)



[0]	validation-rmse:17.31678                          
[1]	validation-rmse:14.37341                          
[2]	validation-rmse:12.16280                          
[3]	validation-rmse:10.51843                          
[4]	validation-rmse:9.31764                           
[5]	validation-rmse:8.45665                           
[6]	validation-rmse:7.84277                           
[7]	validation-rmse:7.40453                           
[8]	validation-rmse:7.09856                           
[9]	validation-rmse:6.87677                           
[10]	validation-rmse:6.71735                          
[11]	validation-rmse:6.60471                          
[12]	validation-rmse:6.52214                          
[13]	validation-rmse:6.46069                          
[14]	validation-rmse:6.41460                          
[15]	validation-rmse:6.37906                          
[16]	validation-rmse:6.35008                          
[17]	validation-rmse:6.32678                          
[18]	valid




[0]	validation-rmse:7.34827                                                    
[1]	validation-rmse:6.54618                                                    
[2]	validation-rmse:6.42620                                                    
[3]	validation-rmse:6.39757                                                    
[4]	validation-rmse:6.38266                                                    
[5]	validation-rmse:6.38618                                                    
[6]	validation-rmse:6.38596                                                    
[7]	validation-rmse:6.38475                                                    
[8]	validation-rmse:6.38062                                                    
[9]	validation-rmse:6.38657                                                    
[10]	validation-rmse:6.38254                                                   
[11]	validation-rmse:6.37863                                                   
[12]	validation-rmse:6.38932            




[0]	validation-rmse:19.20773                                                   
[1]	validation-rmse:17.46417                                                   
[2]	validation-rmse:15.93766                                                   
[3]	validation-rmse:14.60359                                                   
[4]	validation-rmse:13.44112                                                   
[5]	validation-rmse:12.43203                                                   
[6]	validation-rmse:11.55767                                                   
[7]	validation-rmse:10.80172                                                   
[8]	validation-rmse:10.15214                                                   
[9]	validation-rmse:9.59562                                                    
[10]	validation-rmse:9.11916                                                   
[11]	validation-rmse:8.71340                                                   
[12]	validation-rmse:8.36946            




[0]	validation-rmse:8.62725                                                    
[1]	validation-rmse:6.78572                                                    
[2]	validation-rmse:6.48929                                                    
[3]	validation-rmse:6.42641                                                    
[4]	validation-rmse:6.40215                                                    
[5]	validation-rmse:6.39256                                                    
[6]	validation-rmse:6.37903                                                    
[7]	validation-rmse:6.37444                                                    
[8]	validation-rmse:6.36829                                                    
[9]	validation-rmse:6.36300                                                    
[10]	validation-rmse:6.35492                                                   
[11]	validation-rmse:6.35154                                                   
[12]	validation-rmse:6.34622            




[0]	validation-rmse:18.36626                                                   
[1]	validation-rmse:16.03515                                                   
[2]	validation-rmse:14.11483                                                   
[3]	validation-rmse:12.54686                                                   
[4]	validation-rmse:11.27825                                                   
[5]	validation-rmse:10.26131                                                   
[6]	validation-rmse:9.44664                                                    
[7]	validation-rmse:8.80405                                                    
[8]	validation-rmse:8.29962                                                    
[9]	validation-rmse:7.90501                                                    
[10]	validation-rmse:7.59787                                                   
[11]	validation-rmse:7.36263                                                   
[12]	validation-rmse:7.17842            




[0]	validation-rmse:19.62117                                                   
[1]	validation-rmse:18.20047                                                   
[2]	validation-rmse:16.91475                                                   
[3]	validation-rmse:15.76079                                                   
[4]	validation-rmse:14.72091                                                   
[5]	validation-rmse:13.79117                                                   
[6]	validation-rmse:12.95442                                                   
[7]	validation-rmse:12.20784                                                   
[8]	validation-rmse:11.54252                                                   
[9]	validation-rmse:10.95046                                                   
[10]	validation-rmse:10.42518                                                  
[11]	validation-rmse:9.95960                                                   
[12]	validation-rmse:9.54681            




[0]	validation-rmse:19.25269                                                   
[1]	validation-rmse:17.54113                                                   
[2]	validation-rmse:16.03833                                                   
[3]	validation-rmse:14.72004                                                   
[4]	validation-rmse:13.56582                                                   
[5]	validation-rmse:12.56202                                                   
[6]	validation-rmse:11.69103                                                   
[7]	validation-rmse:10.93755                                                   
[8]	validation-rmse:10.28931                                                   
[9]	validation-rmse:9.73445                                                    
[10]	validation-rmse:9.25875                                                   
[11]	validation-rmse:8.85342                                                   
[12]	validation-rmse:8.50843            




[0]	validation-rmse:19.60984                                                   
[1]	validation-rmse:18.17878                                                   
[2]	validation-rmse:16.88674                                                   
[3]	validation-rmse:15.72397                                                   
[4]	validation-rmse:14.67767                                                   
[5]	validation-rmse:13.73781                                                   
[6]	validation-rmse:12.89434                                                   
[7]	validation-rmse:12.13907                                                   
[8]	validation-rmse:11.46739                                                   
[9]	validation-rmse:10.86853                                                   
[10]	validation-rmse:10.33688                                                  
[11]	validation-rmse:9.86540                                                   
[12]	validation-rmse:9.44923            




[0]	validation-rmse:7.25945                                                    
[1]	validation-rmse:6.60580                                                    
[2]	validation-rmse:6.57500                                                    
[3]	validation-rmse:6.55986                                                    
[4]	validation-rmse:6.56211                                                    
[5]	validation-rmse:6.55089                                                    
[6]	validation-rmse:6.54384                                                    
[7]	validation-rmse:6.53896                                                    
[8]	validation-rmse:6.53375                                                    
[9]	validation-rmse:6.53017                                                    
[10]	validation-rmse:6.52945                                                   
[11]	validation-rmse:6.52279                                                   
[12]	validation-rmse:6.51841            




[0]	validation-rmse:14.80888                                                   
[1]	validation-rmse:11.00809                                                   
[2]	validation-rmse:8.85570                                                    
[3]	validation-rmse:7.70539                                                    
[4]	validation-rmse:7.10528                                                    
[5]	validation-rmse:6.79279                                                    
[6]	validation-rmse:6.61087                                                    
[7]	validation-rmse:6.51709                                                    
[8]	validation-rmse:6.46186                                                    
[9]	validation-rmse:6.43121                                                    
[10]	validation-rmse:6.41486                                                   
[11]	validation-rmse:6.40712                                                   
[12]	validation-rmse:6.40336            




100%|██████████| 10/10 [07:38<00:00, 45.87s/trial, best loss: 6.151784238279151]


In [20]:
# Train the best model.
# fmin returns the raw sampled label values: max_depth comes back as a float
# from quniform, and the constant entries of search_space are absent, so
# restore both before reusing the dict as booster params.
best_result['max_depth'] = int(best_result['max_depth'])
best_result['objective'] = 'reg:squarederror'
best_result['seed'] = 42

# Manual logging below; mlflow.xgboost.autolog() is an alternative.
with mlflow.start_run():
    mlflow.set_tag("model", "xgboost")
    mlflow.log_params(best_result)
    model = xgboost.train(
        params=best_result,
        dtrain=training_data,
        num_boost_round=200,
        evals=[(valid, "validation")],
        early_stopping_rounds=50,
    )
    # The returned booster includes rounds past the best one; evaluate only
    # up to best_iteration so the logged RMSE matches early stopping.
    y_pred = model.predict(valid, iteration_range=(0, model.best_iteration + 1))
    # sqrt of MSE rather than `squared=False`, which newer scikit-learn removed.
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    mlflow.log_metric("rmse", rmse)


[0]	validation-rmse:17.31678
[1]	validation-rmse:14.37341
[2]	validation-rmse:12.16280
[3]	validation-rmse:10.51843
[4]	validation-rmse:9.31764
[5]	validation-rmse:8.45665
[6]	validation-rmse:7.84277
[7]	validation-rmse:7.40453
[8]	validation-rmse:7.09856
[9]	validation-rmse:6.87677
[10]	validation-rmse:6.71735
[11]	validation-rmse:6.60471
[12]	validation-rmse:6.52214
[13]	validation-rmse:6.46069
[14]	validation-rmse:6.41460
[15]	validation-rmse:6.37906
[16]	validation-rmse:6.35008
[17]	validation-rmse:6.32678
[18]	validation-rmse:6.30806
[19]	validation-rmse:6.29658
[20]	validation-rmse:6.28151
[21]	validation-rmse:6.27430
[22]	validation-rmse:6.26650
[23]	validation-rmse:6.26118
[24]	validation-rmse:6.25430
[25]	validation-rmse:6.25110
[26]	validation-rmse:6.24359
[27]	validation-rmse:6.24000
[28]	validation-rmse:6.23765
[29]	validation-rmse:6.23510
[30]	validation-rmse:6.23283
[31]	validation-rmse:6.23070
[32]	validation-rmse:6.22903
[33]	validation-rmse:6.22688
[34]	validation-rmse

