In [1]:
# Shell escape: print the interpreter version so the environment behind these outputs is recorded.
!python -V

Python 3.9.19


In [2]:
import pandas as pd

In [3]:
import pickle

In [4]:
import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error

In [6]:
import mlflow


# Point MLflow at the local SQLite backend, then select the experiment that the
# runs started in this notebook are recorded under.
TRACKING_URI = "sqlite:///mlflow.db"
EXPERIMENT_NAME = "nyc-taxi-experiment3"

mlflow.set_tracking_uri(TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

2024/05/30 22:11:27 INFO mlflow.tracking.fluent: Experiment with name 'nyc-taxi-experiment3' does not exist. Creating a new experiment.


<Experiment: artifact_location='/home/mlops/mlops-zoomcamp/02-experiment-tracking/mlruns/1', creation_time=1717107087030, experiment_id='1', last_update_time=1717107087030, lifecycle_stage='active', name='nyc-taxi-experiment3', tags={}>

In [7]:
def read_dataframe(filename):
    """Load a taxi-trip parquet file and prepare it for modelling.

    Parameters
    ----------
    filename : str or path-like
        Location of the parquet file; passed straight to ``pd.read_parquet``.

    Returns
    -------
    pandas.DataFrame
        The trips whose duration lies between 1 and 60 minutes (inclusive),
        with an added ``duration`` column expressed in minutes and the
        ``PULocationID`` / ``DOLocationID`` columns converted to strings.
    """
    df = pd.read_parquet(filename)

    dropoff = pd.to_datetime(df['tpep_dropoff_datetime'])
    pickup = pd.to_datetime(df['tpep_pickup_datetime'])
    df['tpep_dropoff_datetime'] = dropoff
    df['tpep_pickup_datetime'] = pickup

    # Vectorised timedelta -> minutes; avoids one Python-level call per row.
    df['duration'] = (dropoff - pickup).dt.total_seconds() / 60

    # Take an explicit copy of the filtered rows so the column assignment below
    # writes to an independent frame rather than to a slice of the full frame
    # (which is what raises SettingWithCopyWarning).
    in_range = (df['duration'] >= 1) & (df['duration'] <= 60)
    df = df[in_range].copy()

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)

    return df

In [8]:
# Shell escape: list the contents of the current working directory.
!ls -l

total 412
-rw-rw-r-- 1 mlops mlops   3734 May 19 22:35 README.md
drwxrwxr-x 2 mlops mlops   4096 May 23 04:59 data
-rw-rw-r-- 1 mlops mlops 139394 May 30 04:48 duration-prediction.ipynb
drwxrwxr-x 2 mlops mlops   4096 May 19 22:35 images
-rw-rw-r-- 1 mlops mlops   1165 May 19 22:35 meta.json
-rw-r--r-- 1 mlops mlops 217088 May 30 22:11 mlflow.db
-rw-rw-r-- 1 mlops mlops   4676 May 19 22:35 mlflow_on_aws.md
drwxrwxr-x 4 mlops mlops   4096 May 28 23:33 mlruns
-rw-rw-r-- 1 mlops mlops  15480 May 19 22:35 model-registry.ipynb
drwxrwxr-x 2 mlops mlops   4096 May 23 05:56 models
-rw-rw-r-- 1 mlops mlops     77 May 19 22:35 requirements.txt
drwxrwxr-x 2 mlops mlops   4096 May 19 22:35 running-mlflow-examples


In [9]:
# Load two monthly files: 2023-01 is used for fitting, 2023-02 for validation.
df_train, df_val = (
    read_dataframe(path)
    for path in (
        './data/yellow_tripdata_2023-01.parquet',
        './data/yellow_tripdata_2023-02.parquet',
    )
)

In [10]:
# Number of trips kept in the training and validation frames.
df_train.shape[0], df_val.shape[0]

(3009173, 2855951)

In [11]:
# Join the pickup and drop-off location IDs into one "<pickup>_<dropoff>" column on both frames.
for trips in (df_train, df_val):
    trips['PU_DO'] = trips['PULocationID'] + '_' + trips['DOLocationID']

In [12]:
# Feature set: the combined pickup/drop-off pair (categorical) plus the trip distance (numerical).
categorical = ['PU_DO']
numerical = ['trip_distance']
feature_columns = categorical + numerical

dv = DictVectorizer()

# Turn each row into a {column: value} record for the vectorizer.
train_dicts = df_train[feature_columns].to_dict(orient='records')
val_dicts = df_val[feature_columns].to_dict(orient='records')

# Fit the vectorizer on the training records only, then reuse it for validation.
X_train = dv.fit_transform(train_dicts)
X_val = dv.transform(val_dicts)

In [13]:
# Regression target: the ``duration`` column of each frame, as plain arrays.
target = 'duration'
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

In [14]:
# Baseline: ordinary least squares on the vectorized features.
lr = LinearRegression()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)

# Validation RMSE. The root of the MSE is taken explicitly instead of passing
# ``squared=False``, which scikit-learn has deprecated; this form works on
# every scikit-learn version.
mean_squared_error(y_val, y_pred) ** 0.5



5.247057977640596

In [17]:
# Persist the fitted vectorizer together with the model bound to ``lr`` as one pickled tuple.
model_bundle = (dv, lr)
with open('models/lin_reg.bin', mode='wb') as model_file:
    pickle.dump(model_bundle, model_file)

In [18]:
with mlflow.start_run():

    mlflow.set_tag("developer", "akash")

    # Log the files that were actually loaded for training and validation
    # (the data is read from parquet, not csv).
    mlflow.log_param("train-data-path", "./data/yellow_tripdata_2023-01.parquet")
    mlflow.log_param("valid-data-path", "./data/yellow_tripdata_2023-02.parquet")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)
    lr = Lasso(alpha=alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)
    # Root of the MSE taken explicitly; ``squared=False`` is deprecated in scikit-learn.
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    mlflow.log_metric("rmse", rmse)

    # Persist and log the model trained in *this* run, so the artifact matches
    # the logged alpha and rmse instead of a previously pickled model.
    lasso_path = "models/lasso.bin"
    with open(lasso_path, "wb") as f_out:
        pickle.dump((dv, lr), f_out)
    mlflow.log_artifact(local_path=lasso_path, artifact_path="models_pickle")



In [19]:
import xgboost as xgb

In [20]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [21]:
# Wrap the feature matrices and their labels in XGBoost's DMatrix container.
train, valid = (
    xgb.DMatrix(features, label=labels)
    for features, labels in ((X_train, y_train), (X_val, y_val))
)

In [22]:
def objective(params):
    """Train one XGBoost model for a hyperopt trial and report its validation RMSE.

    Each call opens its own MLflow run, tags it ``model=xgboost`` and logs
    ``params`` together with the resulting ``rmse`` metric.

    Parameters
    ----------
    params : dict
        Booster parameters; logged to MLflow and forwarded unchanged to
        ``xgb.train``.

    Returns
    -------
    dict
        ``{'loss': <validation RMSE>, 'status': STATUS_OK}``.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=10,
            evals=[(valid, 'validation')],
            early_stopping_rounds=5
        )
        y_pred = booster.predict(valid)
        # Root of the MSE taken explicitly; ``squared=False`` is deprecated in
        # scikit-learn, and this form works on every version.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [23]:
# Hyper-parameter search space sampled by hyperopt for each trial.
# Note: hp.loguniform bounds are given in log space.
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    # 'reg:linear' is the deprecated alias of this squared-error objective.
    'objective': 'reg:squarederror',
    'seed': 42
}

# Keep a handle on the Trials object so the per-trial results stay inspectable
# after the search finishes.
trials = Trials()

best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=trials
)

  0%|                    | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:5.52307                                     
[1]	validation-rmse:5.14990                                     
[2]	validation-rmse:5.10660                                     
[3]	validation-rmse:5.08924                                     
[4]	validation-rmse:5.08169                                     
[5]	validation-rmse:5.07563                                     
[6]	validation-rmse:5.07041                                     
[7]	validation-rmse:5.06490                                     
[8]	validation-rmse:5.05518                                     
[9]	validation-rmse:5.05067                                     
  2%| | 1/50 [00:14<11:41, 14.31s/trial, best loss: 5.0506676127





[0]	validation-rmse:9.42040                                     
[1]	validation-rmse:8.84809                                     
[2]	validation-rmse:8.34401                                     
[3]	validation-rmse:7.90139                                     
[4]	validation-rmse:7.51429                                     
[5]	validation-rmse:7.17739                                     
[6]	validation-rmse:6.88469                                     
[7]	validation-rmse:6.63151                                     
[8]	validation-rmse:6.41333                                     
[9]	validation-rmse:6.22577                                     
  4%| | 2/50 [00:27<11:04, 13.85s/trial, best loss: 5.0506676127





[0]	validation-rmse:7.09902                                     
[1]	validation-rmse:5.83959                                     
[2]	validation-rmse:5.35616                                     
[3]	validation-rmse:5.16953                                     
[4]	validation-rmse:5.09330                                     
[5]	validation-rmse:5.05765                                     
[6]	validation-rmse:5.03902                                     
[7]	validation-rmse:5.02770                                     
[8]	validation-rmse:5.01856                                     
[9]	validation-rmse:5.01468                                     
  6%| | 3/50 [00:43<11:36, 14.82s/trial, best loss: 5.0146789268





[0]	validation-rmse:5.13937                                     
[1]	validation-rmse:4.96450                                     
[2]	validation-rmse:4.94150                                     
[3]	validation-rmse:4.93253                                     
[4]	validation-rmse:4.92473                                     
[5]	validation-rmse:4.91714                                     
[6]	validation-rmse:4.91039                                     
[7]	validation-rmse:4.90081                                     
[8]	validation-rmse:4.89335                                     
[9]	validation-rmse:4.88767                                     
  8%| | 4/50 [01:03<12:58, 16.92s/trial, best loss: 4.8876712748





[0]	validation-rmse:9.51342                                     
[1]	validation-rmse:9.01281                                     
[2]	validation-rmse:8.56139                                     
[3]	validation-rmse:8.15542                                     
[4]	validation-rmse:7.79110                                     
[5]	validation-rmse:7.46512                                     
[6]	validation-rmse:7.17405                                     
[7]	validation-rmse:6.91486                                     
[8]	validation-rmse:6.68458                                     
[9]	validation-rmse:6.48082                                     
 10%| | 5/50 [01:25<14:02, 18.72s/trial, best loss: 4.8876712748





[0]	validation-rmse:7.46344                                     
[1]	validation-rmse:6.10603                                     
[2]	validation-rmse:5.44629                                     
[3]	validation-rmse:5.13605                                     
[4]	validation-rmse:4.98906                                     
[5]	validation-rmse:4.91227                                     
[6]	validation-rmse:4.86914                                     
[7]	validation-rmse:4.84342                                     
[8]	validation-rmse:4.82804                                     
[9]	validation-rmse:4.81351                                     
 12%| | 6/50 [02:09<20:01, 27.31s/trial, best loss: 4.8135089463





[0]	validation-rmse:8.99236                                     
[1]	validation-rmse:8.12201                                     
[2]	validation-rmse:7.42371                                     
[3]	validation-rmse:6.86970                                     
[4]	validation-rmse:6.43395                                     
[5]	validation-rmse:6.09550                                     
[6]	validation-rmse:5.83395                                     
[7]	validation-rmse:5.63224                                     
[8]	validation-rmse:5.47741                                     
[9]	validation-rmse:5.35952                                     
 14%|▏| 7/50 [02:33<18:40, 26.05s/trial, best loss: 4.8135089463





[0]	validation-rmse:6.54078                                     
[1]	validation-rmse:5.32758                                     
[2]	validation-rmse:4.95269                                     
[3]	validation-rmse:4.82741                                     
[4]	validation-rmse:4.77410                                     
[5]	validation-rmse:4.74942                                     
[6]	validation-rmse:4.73641                                     
[7]	validation-rmse:4.72629                                     
[8]	validation-rmse:4.72207                                     
[9]	validation-rmse:4.71914                                     
 16%|▏| 8/50 [03:42<27:51, 39.81s/trial, best loss: 4.7191407369





[0]	validation-rmse:5.02415                                     
[1]	validation-rmse:5.00937                                     
[2]	validation-rmse:4.99653                                     
[3]	validation-rmse:4.98457                                     
[4]	validation-rmse:4.97287                                     
[5]	validation-rmse:4.96382                                     
[6]	validation-rmse:4.95442                                     
[7]	validation-rmse:4.94546                                     
[8]	validation-rmse:4.93657                                     
[9]	validation-rmse:4.92856                                     
 18%|▏| 9/50 [04:01<22:49, 33.41s/trial, best loss: 4.7191407369





[0]	validation-rmse:7.48520                                     
[1]	validation-rmse:6.11384                                     
[2]	validation-rmse:5.43521                                     
[3]	validation-rmse:5.11041                                     
[4]	validation-rmse:4.95378                                     
[5]	validation-rmse:4.87249                                     
[6]	validation-rmse:4.82716                                     
[7]	validation-rmse:4.79859                                     
[8]	validation-rmse:4.78119                                     
[9]	validation-rmse:4.76740                                     
 20%|▏| 10/50 [04:47<24:45, 37.14s/trial, best loss: 4.719140736





[0]	validation-rmse:9.41197                                     
[1]	validation-rmse:8.82935                                     
[2]	validation-rmse:8.31351                                     
[3]	validation-rmse:7.85794                                     
[4]	validation-rmse:7.45706                                     
[5]	validation-rmse:7.10510                                     
[6]	validation-rmse:6.79712                                     
[7]	validation-rmse:6.52881                                     
[8]	validation-rmse:6.29598                                     
[9]	validation-rmse:6.09425                                     
 22%|▏| 11/50 [05:28<25:00, 38.47s/trial, best loss: 4.719140736





[0]	validation-rmse:9.29876                                     
[1]	validation-rmse:8.63019                                     
[2]	validation-rmse:8.05111                                     
[3]	validation-rmse:7.55200                                     
[4]	validation-rmse:7.12336                                     
[5]	validation-rmse:6.75767                                     
[6]	validation-rmse:6.44588                                     
[7]	validation-rmse:6.18237                                     
[8]	validation-rmse:5.95989                                     
[9]	validation-rmse:5.77233                                     
 24%|▏| 12/50 [06:22<27:22, 43.22s/trial, best loss: 4.719140736





[0]	validation-rmse:9.57974                                     
[1]	validation-rmse:9.13201                                     
[2]	validation-rmse:8.72181                                     
[3]	validation-rmse:8.34663                                     
[4]	validation-rmse:8.00385                                     
[5]	validation-rmse:7.69171                                     
[6]	validation-rmse:7.40745                                     
[7]	validation-rmse:7.14962                                     
[8]	validation-rmse:6.91504                                     
[9]	validation-rmse:6.70300                                     
 26%|▎| 13/50 [07:07<26:51, 43.56s/trial, best loss: 4.719140736





[0]	validation-rmse:9.59376                                     
[1]	validation-rmse:9.15875                                     
[2]	validation-rmse:8.76006                                     
[3]	validation-rmse:8.39525                                     
[4]	validation-rmse:8.06189                                     
[5]	validation-rmse:7.75803                                     
[6]	validation-rmse:7.48142                                     
[7]	validation-rmse:7.23010                                     
[8]	validation-rmse:7.00226                                     
[9]	validation-rmse:6.79571                                     
 28%|▎| 14/50 [07:31<22:41, 37.81s/trial, best loss: 4.719140736





[0]	validation-rmse:9.32130                                     
[1]	validation-rmse:8.66979                                     
[2]	validation-rmse:8.10290                                     
[3]	validation-rmse:7.61227                                     
[4]	validation-rmse:7.18909                                     
[5]	validation-rmse:6.82579                                     
[6]	validation-rmse:6.51485                                     
[7]	validation-rmse:6.25000                                     
[8]	validation-rmse:6.02505                                     
[9]	validation-rmse:5.83525                                     
 30%|▎| 15/50 [08:17<23:28, 40.25s/trial, best loss: 4.719140736





[0]	validation-rmse:9.36089                                     
[1]	validation-rmse:8.73839                                     
[2]	validation-rmse:8.19187                                     
[3]	validation-rmse:7.71398                                     
[4]	validation-rmse:7.29752                                     
[5]	validation-rmse:6.93584                                     
[6]	validation-rmse:6.62313                                     
[7]	validation-rmse:6.35386                                     
[8]	validation-rmse:6.12260                                     
[9]	validation-rmse:5.92392                                     
 32%|▎| 16/50 [09:14<25:37, 45.22s/trial, best loss: 4.719140736





[0]	validation-rmse:7.91570                                     
[1]	validation-rmse:6.59515                                     
[2]	validation-rmse:5.82122                                     
[3]	validation-rmse:5.38474                                     
[4]	validation-rmse:5.14016                                     
[5]	validation-rmse:5.00282                                     
[6]	validation-rmse:4.92329                                     
[7]	validation-rmse:4.87393                                     
[8]	validation-rmse:4.84187                                     
[9]	validation-rmse:4.81895                                     
 34%|▎| 17/50 [09:59<24:46, 45.04s/trial, best loss: 4.719140736





[0]	validation-rmse:9.23369                                     
[1]	validation-rmse:8.52708                                     
[2]	validation-rmse:7.93203                                     
[3]	validation-rmse:7.43404                                     
[4]	validation-rmse:7.01989                                     
[5]	validation-rmse:6.67784                                     
[6]	validation-rmse:6.39747                                     
[7]	validation-rmse:6.16848                                     
[8]	validation-rmse:5.98195                                     
[9]	validation-rmse:5.83046                                     
 36%|▎| 18/50 [10:11<18:50, 35.33s/trial, best loss: 4.719140736





[0]	validation-rmse:9.37992                                     
[1]	validation-rmse:8.77145                                     
[2]	validation-rmse:8.23433                                     
[3]	validation-rmse:7.76182                                     
[4]	validation-rmse:7.34794                                     
[5]	validation-rmse:6.98635                                     
[6]	validation-rmse:6.67221                                     
[7]	validation-rmse:6.39899                                     
[8]	validation-rmse:6.16329                                     
[9]	validation-rmse:5.95971                                     
 38%|▍| 19/50 [11:42<26:50, 51.95s/trial, best loss: 4.719140736





[0]	validation-rmse:6.97608                                     
[1]	validation-rmse:5.63068                                     
[2]	validation-rmse:5.09734                                     
[3]	validation-rmse:4.88924                                     
[4]	validation-rmse:4.79923                                     
[5]	validation-rmse:4.75705                                     
[6]	validation-rmse:4.73469                                     
[7]	validation-rmse:4.71927                                     
[8]	validation-rmse:4.70738                                     
[9]	validation-rmse:4.70043                                     
 40%|▍| 20/50 [12:42<27:07, 54.26s/trial, best loss: 4.700434010





[0]	validation-rmse:6.12953                                     
[1]	validation-rmse:5.09636                                     
[2]	validation-rmse:4.84705                                     
[3]	validation-rmse:4.77261                                     
[4]	validation-rmse:4.74230                                     
[5]	validation-rmse:4.72866                                     
[6]	validation-rmse:4.72024                                     
[7]	validation-rmse:4.71687                                     
[8]	validation-rmse:4.71322                                     
[9]	validation-rmse:4.70974                                     
 42%|▍| 21/50 [13:42<27:09, 56.18s/trial, best loss: 4.700434010





[0]	validation-rmse:5.81395                                     
[1]	validation-rmse:4.98442                                     
[2]	validation-rmse:4.82295                                     
[3]	validation-rmse:4.77653                                     
[4]	validation-rmse:4.75671                                     
[5]	validation-rmse:4.74535                                     
[6]	validation-rmse:4.74100                                     
[7]	validation-rmse:4.73733                                     
[8]	validation-rmse:4.73331                                     
[9]	validation-rmse:4.72976                                     
 44%|▍| 22/50 [14:33<25:30, 54.67s/trial, best loss: 4.700434010





[0]	validation-rmse:8.75868                                     
[1]	validation-rmse:7.74331                                     
[2]	validation-rmse:6.96692                                     
[3]	validation-rmse:6.38140                                     
[4]	validation-rmse:5.94589                                     
[5]	validation-rmse:5.62480                                     
[6]	validation-rmse:5.39156                                     
[7]	validation-rmse:5.22138                                     
[8]	validation-rmse:5.09757                                     
[9]	validation-rmse:5.00833                                     
 46%|▍| 23/50 [15:51<27:44, 61.66s/trial, best loss: 4.700434010





[0]	validation-rmse:8.27801                                     
[1]	validation-rmse:7.04948                                     
[2]	validation-rmse:6.23034                                     
[3]	validation-rmse:5.69770                                     
[4]	validation-rmse:5.35956                                     
[5]	validation-rmse:5.14676                                     
[6]	validation-rmse:5.01280                                     
[7]	validation-rmse:4.92609                                     
[8]	validation-rmse:4.86917                                     
[9]	validation-rmse:4.83055                                     
 48%|▍| 24/50 [17:08<28:41, 66.22s/trial, best loss: 4.700434010





[0]	validation-rmse:6.03910                                     
[1]	validation-rmse:5.06043                                     
[2]	validation-rmse:4.83933                                     
[3]	validation-rmse:4.77248                                     
[4]	validation-rmse:4.74648                                     
[5]	validation-rmse:4.73427                                     
[6]	validation-rmse:4.72888                                     
[7]	validation-rmse:4.72387                                     
[8]	validation-rmse:4.71847                                     
[9]	validation-rmse:4.71554                                     
 50%|▌| 25/50 [18:00<25:43, 61.73s/trial, best loss: 4.700434010





[0]	validation-rmse:8.31259                                     
[1]	validation-rmse:7.09334                                     
[2]	validation-rmse:6.26988                                     
[3]	validation-rmse:5.72764                                     
[4]	validation-rmse:5.37753                                     
[5]	validation-rmse:5.15305                                     
[6]	validation-rmse:5.01075                                     
[7]	validation-rmse:4.91775                                     
[8]	validation-rmse:4.85671                                     
[9]	validation-rmse:4.81491                                     
 52%|▌| 26/50 [19:22<27:08, 67.87s/trial, best loss: 4.700434010





[0]	validation-rmse:5.62402                                     
[1]	validation-rmse:4.94606                                     
[2]	validation-rmse:4.82902                                     
[3]	validation-rmse:4.79653                                     
[4]	validation-rmse:4.78340                                     
[5]	validation-rmse:4.77233                                     
[6]	validation-rmse:4.76773                                     
[7]	validation-rmse:4.76280                                     
[8]	validation-rmse:4.75890                                     
[9]	validation-rmse:4.75516                                     
 54%|▌| 27/50 [20:01<22:45, 59.39s/trial, best loss: 4.700434010





[0]	validation-rmse:6.59353                                     
[1]	validation-rmse:5.34910                                     
[2]	validation-rmse:4.94851                                     
[3]	validation-rmse:4.81275                                     
[4]	validation-rmse:4.75650                                     
[5]	validation-rmse:4.73302                                     
[6]	validation-rmse:4.71825                                     
[7]	validation-rmse:4.70746                                     
[8]	validation-rmse:4.70120                                     
[9]	validation-rmse:4.69798                                     
 56%|▌| 28/50 [21:06<22:20, 60.95s/trial, best loss: 4.697981213





[0]	validation-rmse:7.05058                                     
[1]	validation-rmse:5.69557                                     
[2]	validation-rmse:5.14373                                     
[3]	validation-rmse:4.92053                                     
[4]	validation-rmse:4.82413                                     
[5]	validation-rmse:4.77606                                     
[6]	validation-rmse:4.75074                                     
[7]	validation-rmse:4.73504                                     
[8]	validation-rmse:4.72433                                     
[9]	validation-rmse:4.71419                                     
 58%|▌| 29/50 [22:12<21:54, 62.62s/trial, best loss: 4.697981213





[0]	validation-rmse:5.20290                                     
[1]	validation-rmse:4.85142                                     
[2]	validation-rmse:4.80663                                     
[3]	validation-rmse:4.78858                                     
[4]	validation-rmse:4.78204                                     
[5]	validation-rmse:4.77577                                     
[6]	validation-rmse:4.77001                                     
[7]	validation-rmse:4.76459                                     
[8]	validation-rmse:4.75918                                     
[9]	validation-rmse:4.75411                                     
 60%|▌| 30/50 [22:49<18:14, 54.74s/trial, best loss: 4.697981213





[0]	validation-rmse:7.95147                                     
[1]	validation-rmse:6.64883                                     
[2]	validation-rmse:5.88113                                     
[3]	validation-rmse:5.44504                                     
[4]	validation-rmse:5.19980                                     
[5]	validation-rmse:5.06169                                     
[6]	validation-rmse:4.97904                                     
[7]	validation-rmse:4.92957                                     
[8]	validation-rmse:4.89685                                     
[9]	validation-rmse:4.87497                                     
 62%|▌| 31/50 [23:25<15:36, 49.29s/trial, best loss: 4.697981213





[0]	validation-rmse:8.67333                                     
[1]	validation-rmse:7.61205                                     
[2]	validation-rmse:6.81691                                     
[3]	validation-rmse:6.23089                                     
[4]	validation-rmse:5.80560                                     
[5]	validation-rmse:5.49909                                     
[6]	validation-rmse:5.28234                                     
[7]	validation-rmse:5.12834                                     
[8]	validation-rmse:5.01865                                     
[9]	validation-rmse:4.94035                                     
 64%|▋| 32/50 [24:37<16:45, 55.88s/trial, best loss: 4.697981213





[0]	validation-rmse:6.65049                                     
[1]	validation-rmse:5.39202                                     
[2]	validation-rmse:4.97465                                     
[3]	validation-rmse:4.82968                                     
[4]	validation-rmse:4.76888                                     
[5]	validation-rmse:4.73997                                     
[6]	validation-rmse:4.72351                                     
[7]	validation-rmse:4.71052                                     
[8]	validation-rmse:4.70511                                     
[9]	validation-rmse:4.69972                                     
 66%|▋| 33/50 [25:35<16:04, 56.75s/trial, best loss: 4.697981213





[0]	validation-rmse:6.48718                                     
[1]	validation-rmse:5.31519                                     
[2]	validation-rmse:4.97157                                     
[3]	validation-rmse:4.85871                                     
[4]	validation-rmse:4.81351                                     
[5]	validation-rmse:4.78776                                     
[6]	validation-rmse:4.77339                                     
[7]	validation-rmse:4.76376                                     
[8]	validation-rmse:4.76005                                     
[9]	validation-rmse:4.75597                                     
 68%|▋| 34/50 [26:17<13:57, 52.32s/trial, best loss: 4.697981213





[0]	validation-rmse:5.04644                                     
[1]	validation-rmse:4.93074                                     
[2]	validation-rmse:4.90885                                     
[3]	validation-rmse:4.90003                                     
[4]	validation-rmse:4.89145                                     
[5]	validation-rmse:4.88425                                     
[6]	validation-rmse:4.87747                                     
[7]	validation-rmse:4.87125                                     
[8]	validation-rmse:4.86509                                     
[9]	validation-rmse:4.85936                                     
 70%|▋| 35/50 [26:45<11:12, 44.85s/trial, best loss: 4.697981213





[0]	validation-rmse:7.97931                                     
[1]	validation-rmse:6.66082                                     
[2]	validation-rmse:5.86269                                     
[3]	validation-rmse:5.39498                                     
[4]	validation-rmse:5.12710                                     
[5]	validation-rmse:4.97168                                     
[6]	validation-rmse:4.88131                                     
[7]	validation-rmse:4.82510                                     
[8]	validation-rmse:4.78847                                     
[9]	validation-rmse:4.76390                                     
 72%|▋| 36/50 [27:52<12:00, 51.49s/trial, best loss: 4.697981213





[0]	validation-rmse:7.42806                                     
[1]	validation-rmse:6.04989                                     
[2]	validation-rmse:5.37969                                     
[3]	validation-rmse:5.06474                                     
[4]	validation-rmse:4.91637                                     
[5]	validation-rmse:4.83850                                     
[6]	validation-rmse:4.79536                                     
[7]	validation-rmse:4.76914                                     
[8]	validation-rmse:4.75333                                     
[9]	validation-rmse:4.74121                                     
 74%|▋| 37/50 [28:49<11:30, 53.11s/trial, best loss: 4.697981213





[0]	validation-rmse:5.44134                                     
[1]	validation-rmse:4.90106                                     
[2]	validation-rmse:4.81466                                     
[3]	validation-rmse:4.78820                                     
[4]	validation-rmse:4.77450                                     
[5]	validation-rmse:4.76912                                     
[6]	validation-rmse:4.76260                                     
[7]	validation-rmse:4.75728                                     
[8]	validation-rmse:4.75236                                     
[9]	validation-rmse:4.74707                                     
 76%|▊| 38/50 [29:27<09:43, 48.59s/trial, best loss: 4.697981213





[0]	validation-rmse:6.54451                                     
[1]	validation-rmse:5.32228                                     
[2]	validation-rmse:4.94071                                     
[3]	validation-rmse:4.81319                                     
[4]	validation-rmse:4.75991                                     
[5]	validation-rmse:4.73472                                     
[6]	validation-rmse:4.71822                                     
[7]	validation-rmse:4.70741                                     
[8]	validation-rmse:4.70018                                     
[9]	validation-rmse:4.69657                                     
 78%|▊| 39/50 [30:25<09:26, 51.50s/trial, best loss: 4.696573500





[0]	validation-rmse:7.67340                                     
[1]	validation-rmse:6.34019                                     
[2]	validation-rmse:5.64109                                     
[3]	validation-rmse:5.28871                                     
[4]	validation-rmse:5.10965                                     
[5]	validation-rmse:5.01415                                     
[6]	validation-rmse:4.96255                                     
[7]	validation-rmse:4.93291                                     
[8]	validation-rmse:4.91221                                     
[9]	validation-rmse:4.89863                                     
 80%|▊| 40/50 [30:52<07:21, 44.12s/trial, best loss: 4.696573500





[0]	validation-rmse:9.11379                                     
[1]	validation-rmse:8.32178                                     
[2]	validation-rmse:7.67034                                     
[3]	validation-rmse:7.13768                                     
[4]	validation-rmse:6.70671                                     
[5]	validation-rmse:6.35948                                     
[6]	validation-rmse:6.08219                                     
[7]	validation-rmse:5.86230                                     
[8]	validation-rmse:5.68762                                     
[9]	validation-rmse:5.55018                                     
 82%|▊| 41/50 [31:10<05:26, 36.26s/trial, best loss: 4.696573500





[0]	validation-rmse:7.16215                                     
[1]	validation-rmse:5.81314                                     
[2]	validation-rmse:5.24260                                     
[3]	validation-rmse:5.00448                                     
[4]	validation-rmse:4.89736                                     
[5]	validation-rmse:4.84500                                     
[6]	validation-rmse:4.81700                                     
[7]	validation-rmse:4.79922                                     
[8]	validation-rmse:4.78586                                     
[9]	validation-rmse:4.77794                                     
 84%|▊| 42/50 [31:57<05:17, 39.64s/trial, best loss: 4.696573500





[0]	validation-rmse:5.04495                                     
[1]	validation-rmse:4.82370                                     
[2]	validation-rmse:4.79254                                     
[3]	validation-rmse:4.78442                                     
[4]	validation-rmse:4.77719                                     
[5]	validation-rmse:4.77065                                     
[6]	validation-rmse:4.76429                                     
[7]	validation-rmse:4.75777                                     
[8]	validation-rmse:4.75262                                     
[9]	validation-rmse:4.74736                                     
 86%|▊| 43/50 [32:34<04:31, 38.72s/trial, best loss: 4.696573500





[0]	validation-rmse:8.61361                                     
[1]	validation-rmse:7.53323                                     
[2]	validation-rmse:6.74500                                     
[3]	validation-rmse:6.18118                                     
[4]	validation-rmse:5.78323                                     
[5]	validation-rmse:5.50714                                     
[6]	validation-rmse:5.31402                                     
[7]	validation-rmse:5.18081                                     
[8]	validation-rmse:5.08807                                     
[9]	validation-rmse:5.02204                                     
 88%|▉| 44/50 [33:09<03:45, 37.51s/trial, best loss: 4.696573500





[0]	validation-rmse:6.52451                                     
[1]	validation-rmse:5.31029                                     
[2]	validation-rmse:4.93583                                     
[3]	validation-rmse:4.81155                                     
[4]	validation-rmse:4.76080                                     
[5]	validation-rmse:4.73800                                     
[6]	validation-rmse:4.72128                                     
[7]	validation-rmse:4.71066                                     
[8]	validation-rmse:4.70394                                     
[9]	validation-rmse:4.70018                                     
 90%|▉| 45/50 [33:59<03:26, 41.38s/trial, best loss: 4.696573500





[0]	validation-rmse:8.10516                                     
[1]	validation-rmse:6.82426                                     
[2]	validation-rmse:6.01744                                     
[3]	validation-rmse:5.52544                                     
[4]	validation-rmse:5.23204                                     
[5]	validation-rmse:5.05647                                     
[6]	validation-rmse:4.95086                                     
[7]	validation-rmse:4.88350                                     
[8]	validation-rmse:4.83992                                     
[9]	validation-rmse:4.81008                                     
 92%|▉| 46/50 [34:51<02:57, 44.48s/trial, best loss: 4.696573500





[0]	validation-rmse:4.94494                                     
[1]	validation-rmse:4.89249                                     
[2]	validation-rmse:4.88036                                     
[3]	validation-rmse:4.86967                                     
[4]	validation-rmse:4.86063                                     
[5]	validation-rmse:4.85106                                     
[6]	validation-rmse:4.84353                                     
[7]	validation-rmse:4.83673                                     
[8]	validation-rmse:4.82901                                     
[9]	validation-rmse:4.82178                                     
 94%|▉| 47/50 [35:23<02:02, 40.95s/trial, best loss: 4.696573500





[0]	validation-rmse:7.55889                                     
[1]	validation-rmse:6.17724                                     
[2]	validation-rmse:5.46172                                     
[3]	validation-rmse:5.10677                                     
[4]	validation-rmse:4.93041                                     
[5]	validation-rmse:4.83914                                     
[6]	validation-rmse:4.78783                                     
[7]	validation-rmse:4.75801                                     
[8]	validation-rmse:4.73910                                     
[9]	validation-rmse:4.72470                                     
 96%|▉| 48/50 [36:29<01:36, 48.24s/trial, best loss: 4.696573500





[0]	validation-rmse:6.20736                                     
[1]	validation-rmse:5.22457                                     
[2]	validation-rmse:4.99091                                     
[3]	validation-rmse:4.91812                                     
[4]	validation-rmse:4.89308                                     
[5]	validation-rmse:4.87892                                     
[6]	validation-rmse:4.87067                                     
[7]	validation-rmse:4.86566                                     
[8]	validation-rmse:4.86085                                     
[9]	validation-rmse:4.85615                                     
 98%|▉| 49/50 [36:56<00:42, 42.11s/trial, best loss: 4.696573500





[0]	validation-rmse:8.91152                                     
[1]	validation-rmse:7.98792                                     
[2]	validation-rmse:7.25774                                     
[3]	validation-rmse:6.68769                                     
[4]	validation-rmse:6.24625                                     
[5]	validation-rmse:5.90886                                     
[6]	validation-rmse:5.65166                                     
[7]	validation-rmse:5.45801                                     
[8]	validation-rmse:5.31180                                     
[9]	validation-rmse:5.20192                                     
100%|█| 50/50 [37:37<00:00, 45.15s/trial, best loss: 4.696573500





In [24]:
# Switch XGBoost autologging off: the run in the next cell logs its params,
# metric, preprocessor artifact and model with explicit mlflow.log_* calls.
mlflow.xgboost.autolog(disable=True)

In [25]:
# Retrain XGBoost with the fixed `best_params` below and log everything
# (params, validation RMSE, fitted preprocessor, booster) to a single MLflow run.
with mlflow.start_run():

    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    best_params = {
        'learning_rate': 0.515363222104369,
        'max_depth': 94,
        'min_child_weight': 6.239676007741263,
        # 'reg:squarederror' is the current name of the squared-error
        # objective; 'reg:linear' is its deprecated alias.
        'objective': 'reg:squarederror',
        'reg_alpha': 0.08024205108205405,
        'reg_lambda': 0.002576667534447248,
        'seed': 42
    }

    mlflow.log_params(best_params)

    # Up to 100 boosting rounds; stop once validation RMSE has not improved
    # for 5 consecutive rounds.
    booster = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=100,
        evals=[(valid, 'validation')],
        early_stopping_rounds=5
    )

    y_pred = booster.predict(valid)
    # RMSE as sqrt(MSE): avoids the `squared=False` keyword, which newer
    # scikit-learn releases deprecate and then remove.
    # Equivalent for a single-target y (assumed here — duration in minutes).
    rmse = mean_squared_error(y_val, y_pred) ** 0.5
    mlflow.log_metric("rmse", rmse)

    # Persist the fitted DictVectorizer next to the model so that inference
    # can reproduce the exact feature encoding.
    with open("models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)
    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")



[0]	validation-rmse:6.54451
[1]	validation-rmse:5.32228
[2]	validation-rmse:4.94071
[3]	validation-rmse:4.81319
[4]	validation-rmse:4.75991
[5]	validation-rmse:4.73472
[6]	validation-rmse:4.71822
[7]	validation-rmse:4.70741
[8]	validation-rmse:4.70018
[9]	validation-rmse:4.69657
[10]	validation-rmse:4.69350
[11]	validation-rmse:4.69098
[12]	validation-rmse:4.68805
[13]	validation-rmse:4.68526
[14]	validation-rmse:4.68264
[15]	validation-rmse:4.67991
[16]	validation-rmse:4.67751
[17]	validation-rmse:4.67542
[18]	validation-rmse:4.67326
[19]	validation-rmse:4.67115
[20]	validation-rmse:4.66909
[21]	validation-rmse:4.66712
[22]	validation-rmse:4.66527
[23]	validation-rmse:4.66321
[24]	validation-rmse:4.66149
[25]	validation-rmse:4.65922
[26]	validation-rmse:4.65733
[27]	validation-rmse:4.65566
[28]	validation-rmse:4.65382
[29]	validation-rmse:4.65246
[30]	validation-rmse:4.65068
[31]	validation-rmse:4.64809
[32]	validation-rmse:4.64662
[33]	validation-rmse:4.64526
[34]	validation-rmse:4.6



In [30]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

# Enable MLflow autologging for scikit-learn estimators.
mlflow.sklearn.autolog()

# Fit each baseline regressor with default hyperparameters in its own MLflow
# run and record its validation RMSE.
for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run():

        # NOTE(review): read_dataframe() loads parquet files with tpep_* columns;
        # confirm these logged paths match the data actually used for X_train / X_val.
        mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
        mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        # Fixed seed (same value as the XGBoost run) so reruns give the same
        # model and metric; all four estimator classes accept `random_state`.
        mlmodel = model_class(random_state=42)
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        # RMSE as sqrt(MSE): avoids the `squared=False` keyword, which newer
        # scikit-learn releases deprecate and then remove.
        # Equivalent for a single-target y (assumed here).
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)
        



AssertionError: /home/mlops/anaconda3/envs/exp-tracking-env/lib/python3.9/distutils/core.py