In [1]:
import pandas as pd
import numpy as np

In [2]:
## !pip install pyarrow

In [2]:
import pickle

In [3]:
import mlflow

In [4]:
mlflow.set_tracking_uri('sqlite:///mlflow.db')
mlflow.set_experiment('nyc-experiment')

2024/06/28 08:06:31 INFO mlflow.tracking.fluent: Experiment with name 'nyc-experiment' does not exist. Creating a new experiment.


<Experiment: artifact_location='/workspaces/MLOps-zoomcamp/02-experiment-tracking/mlruns/1', creation_time=1719561991637, experiment_id='1', last_update_time=1719561991637, lifecycle_stage='active', name='nyc-experiment', tags={}>

In [5]:
df = pd.read_parquet('https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2023-01.parquet')

In [6]:
import seaborn as sns

from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression

from sklearn.metrics import mean_squared_error

In [7]:
df.shape

(3066766, 19)

In [8]:
def read_dataframe(filename):
    if filename.endswith('.csv'):
        df = pd.read_csv(filename)

        df.tpep_dropoff_datetime = pd.to_datetime(df.tpep_dropoff_datetime)
        df.tpep_pickup_datetime = pd.to_datetime(df.tpep_pickup_datetime)
    elif filename.endswith('.parquet'):
        df = pd.read_parquet(filename)

    df['duration'] = df.tpep_dropoff_datetime - df.tpep_pickup_datetime
    df.duration = df.duration.apply(lambda td: td.total_seconds() / 60)

    df = df[(df.duration >= 1) & (df.duration <= 60)]

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    
    return df

In [9]:
df_train = read_dataframe('https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2023-01.parquet')
df_val = read_dataframe('https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2023-02.parquet')

In [10]:
len(df_train), len(df_val)

(3009173, 2855951)

# Reduced the size of train and test data due to buffer size issue and hence the rmse value may vary

In [11]:
df_train = df_train.iloc[:20000] 
df_val = df_val.iloc[:20000] 

In [12]:
len(df_train), len(df_val)

(20000, 20000)

In [13]:
df_train['PU_DO'] = df_train['PULocationID'] + '_' + df_train['DOLocationID']
df_val['PU_DO'] = df_val['PULocationID'] + '_' + df_val['DOLocationID']

In [14]:
categorical = ['PU_DO'] #'PULocationID', 'DOLocationID']
numerical = ['trip_distance']

In [15]:
dv = DictVectorizer()

In [16]:
train_dicts = df_train[categorical + numerical].to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

In [17]:
X_train.shape

(20000, 4506)

In [18]:
target = 'duration'
y_train = df_train[target].values
y_val = df_val[target].values

In [19]:
val_dicts = df_val[categorical + numerical].to_dict(orient='records')
X_val = dv.transform(val_dicts)

Training the model

In [20]:
lr = LinearRegression()
lr.fit(X_train, y_train)

Evaluating the model

In [21]:
y_pred = lr.predict(X_val)

rmse_val = mean_squared_error(y_val, y_pred, squared=False)
rmse_val



6.1952743368852845

In [22]:
from sklearn.linear_model import Lasso

In [23]:
with open('models/lin_reg.bin', 'wb') as f_out:
    pickle.dump((dv, lr), f_out)

In [25]:
lr = Lasso(0.1)
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)
rmse = mean_squared_error(y_val, y_pred, squared=False)



In [26]:
rmse

5.737033794695259

In [28]:
with mlflow.start_run():

    mlflow.set_tag("developer", "Bhavana")

    #mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
    #mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)
    lr = Lasso(alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    mlflow.log_artifact(local_path="models/lin_reg.bin", artifact_path="models_pickle")



In [29]:
import xgboost as xgb

In [30]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [31]:
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

In [32]:
def objective(params):
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

In [33]:
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:linear',
    'seed': 42
}

best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

  0%|                                                              | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:7.56356                                                                               
[1]	validation-rmse:6.46786                                                                               
[2]	validation-rmse:5.94130                                                                               
[3]	validation-rmse:5.69510                                                                               
[4]	validation-rmse:5.59209                                                                               
[5]	validation-rmse:5.54693                                                                               
[6]	validation-rmse:5.53263                                                                               
[7]	validation-rmse:5.52593                                                                               
[8]	validation-rmse:5.53347                                                                               
[9]	validation-rmse:5.54653          





[0]	validation-rmse:8.00736                                                                               
[1]	validation-rmse:6.97799                                                                               
[2]	validation-rmse:6.36243                                                                               
[3]	validation-rmse:6.01288                                                                               
[4]	validation-rmse:5.81695                                                                               
[5]	validation-rmse:5.71555                                                                               
[6]	validation-rmse:5.66286                                                                               
[7]	validation-rmse:5.64510                                                                               
[8]	validation-rmse:5.64315                                                                               
[9]	validation-rmse:5.65418          





[0]	validation-rmse:7.60818                                                                               
[1]	validation-rmse:6.53038                                                                               
[2]	validation-rmse:6.01778                                                                               
[3]	validation-rmse:5.78218                                                                               
[4]	validation-rmse:5.68949                                                                               
[5]	validation-rmse:5.67236                                                                               
[6]	validation-rmse:5.68337                                                                               
[7]	validation-rmse:5.71425                                                                               
[8]	validation-rmse:5.75064                                                                               
[9]	validation-rmse:5.76208          





[0]	validation-rmse:8.14936                                                                               
[1]	validation-rmse:7.23246                                                                               
[2]	validation-rmse:6.71743                                                                               
[3]	validation-rmse:6.41685                                                                               
[4]	validation-rmse:6.26248                                                                               
[5]	validation-rmse:6.17331                                                                               
[6]	validation-rmse:6.14091                                                                               
[7]	validation-rmse:6.12442                                                                               
[8]	validation-rmse:6.11411                                                                               
[9]	validation-rmse:6.11724          





[0]	validation-rmse:9.00647                                                                               
[1]	validation-rmse:8.46489                                                                               
[2]	validation-rmse:8.00980                                                                               
[3]	validation-rmse:7.64210                                                                               
[4]	validation-rmse:7.32937                                                                               
[5]	validation-rmse:7.07999                                                                               
[6]	validation-rmse:6.87326                                                                               
[7]	validation-rmse:6.71096                                                                               
[8]	validation-rmse:6.57354                                                                               
[9]	validation-rmse:6.46321          





[1]	validation-rmse:8.86959                                                                               
[2]	validation-rmse:8.53012                                                                               
[3]	validation-rmse:8.21877                                                                               
[4]	validation-rmse:7.94108                                                                               
[5]	validation-rmse:7.68642                                                                               
[6]	validation-rmse:7.45480                                                                               
[7]	validation-rmse:7.24803                                                                               
[8]	validation-rmse:7.05930                                                                               
[9]	validation-rmse:6.89861                                                                               
[10]	validation-rmse:6.74686         





[0]	validation-rmse:9.32613                                                                               
[1]	validation-rmse:9.02587                                                                               
[2]	validation-rmse:8.74842                                                                               
[3]	validation-rmse:8.49601                                                                               
[4]	validation-rmse:8.26221                                                                               
[5]	validation-rmse:8.04975                                                                               
[6]	validation-rmse:7.85181                                                                               
[7]	validation-rmse:7.67733                                                                               
[8]	validation-rmse:7.51446                                                                               
[9]	validation-rmse:7.36235          





[0]	validation-rmse:7.43988                                                                               
[1]	validation-rmse:6.37616                                                                               
[2]	validation-rmse:5.91896                                                                               
[3]	validation-rmse:5.73428                                                                               
[4]	validation-rmse:5.66651                                                                               
[5]	validation-rmse:5.66959                                                                               
[6]	validation-rmse:5.70382                                                                               
[7]	validation-rmse:5.72926                                                                               
[8]	validation-rmse:5.72955                                                                               
[9]	validation-rmse:5.73745          





[0]	validation-rmse:8.95656                                                                               
[1]	validation-rmse:8.36342                                                                               
[2]	validation-rmse:7.86164                                                                               
[3]	validation-rmse:7.43959                                                                               
[4]	validation-rmse:7.08845                                                                               
[5]	validation-rmse:6.79575                                                                               
[6]	validation-rmse:6.55378                                                                               
[7]	validation-rmse:6.35486                                                                               
[8]	validation-rmse:6.19143                                                                               
[9]	validation-rmse:6.05938          





[0]	validation-rmse:6.16803                                                                               
[1]	validation-rmse:6.06927                                                                               
[2]	validation-rmse:6.08759                                                                               
[3]	validation-rmse:6.08715                                                                               
[4]	validation-rmse:6.09183                                                                               
[5]	validation-rmse:6.12691                                                                               
[6]	validation-rmse:6.12641                                                                               
[7]	validation-rmse:6.13084                                                                               
[8]	validation-rmse:6.13386                                                                               
[9]	validation-rmse:6.13677          





[0]	validation-rmse:9.23646                                                                               
[1]	validation-rmse:8.85647                                                                               
[2]	validation-rmse:8.51050                                                                               
[3]	validation-rmse:8.19554                                                                               
[4]	validation-rmse:7.90965                                                                               
[5]	validation-rmse:7.65070                                                                               
[6]	validation-rmse:7.41678                                                                               
[7]	validation-rmse:7.20598                                                                               
[8]	validation-rmse:7.01610                                                                               
[9]	validation-rmse:6.84535          





[0]	validation-rmse:8.79238                                                                               
[1]	validation-rmse:8.08893                                                                               
[2]	validation-rmse:7.52259                                                                               
[3]	validation-rmse:7.06986                                                                               
[4]	validation-rmse:6.71312                                                                               
[5]	validation-rmse:6.43185                                                                               
[6]	validation-rmse:6.21404                                                                               
[7]	validation-rmse:6.04642                                                                               
[8]	validation-rmse:5.91493                                                                               
[9]	validation-rmse:5.81415          





[0]	validation-rmse:8.40232                                                                               
[1]	validation-rmse:7.54823                                                                               
[2]	validation-rmse:6.97206                                                                               
[3]	validation-rmse:6.59671                                                                               
[4]	validation-rmse:6.33967                                                                               
[5]	validation-rmse:6.18290                                                                               
[6]	validation-rmse:6.08191                                                                               
[7]	validation-rmse:6.00224                                                                               
[8]	validation-rmse:5.96617                                                                               
[9]	validation-rmse:5.92020          





[0]	validation-rmse:7.19223                                                                               
[1]	validation-rmse:6.15878                                                                               
[2]	validation-rmse:5.79527                                                                               
[3]	validation-rmse:5.69420                                                                               
[4]	validation-rmse:5.68912                                                                               
[5]	validation-rmse:5.71539                                                                               
[6]	validation-rmse:5.72560                                                                               
[7]	validation-rmse:5.74658                                                                               
[8]	validation-rmse:5.78014                                                                               
[9]	validation-rmse:5.78428          





[0]	validation-rmse:8.67205                                                                               
[1]	validation-rmse:7.90389                                                                               
[2]	validation-rmse:7.30503                                                                               
[3]	validation-rmse:6.85288                                                                               
[4]	validation-rmse:6.50909                                                                               
[5]	validation-rmse:6.24665                                                                               
[6]	validation-rmse:6.05127                                                                               
[7]	validation-rmse:5.91081                                                                               
[8]	validation-rmse:5.80825                                                                               
[9]	validation-rmse:5.73533          





[0]	validation-rmse:6.99905                                                                               
[1]	validation-rmse:5.97711                                                                               
[2]	validation-rmse:5.61921                                                                               
[3]	validation-rmse:5.49772                                                                               
[4]	validation-rmse:5.45190                                                                               
[5]	validation-rmse:5.43563                                                                               
[6]	validation-rmse:5.42482                                                                               
[7]	validation-rmse:5.41956                                                                               
[8]	validation-rmse:5.42153                                                                               
[9]	validation-rmse:5.42060          





[0]	validation-rmse:6.26803                                                                               
[1]	validation-rmse:5.70692                                                                               
[2]	validation-rmse:5.66491                                                                               
[3]	validation-rmse:5.72308                                                                               
[4]	validation-rmse:5.72563                                                                               
[5]	validation-rmse:5.76355                                                                               
[6]	validation-rmse:5.76800                                                                               
[7]	validation-rmse:5.77248                                                                               
[8]	validation-rmse:5.78001                                                                               
[9]	validation-rmse:5.77998          





[0]	validation-rmse:5.77647                                                                               
[1]	validation-rmse:5.56857                                                                               
[2]	validation-rmse:5.57727                                                                               
[3]	validation-rmse:5.57716                                                                               
[4]	validation-rmse:5.57756                                                                               
[5]	validation-rmse:5.56205                                                                               
[6]	validation-rmse:5.55681                                                                               
[7]	validation-rmse:5.55527                                                                               
[8]	validation-rmse:5.56045                                                                               
[9]	validation-rmse:5.56759          




[0]	validation-rmse:8.78011                                                                               
[1]	validation-rmse:8.07053                                                                               
                                                                                                          




[2]	validation-rmse:7.49906
[3]	validation-rmse:7.04285                                                                               
[4]	validation-rmse:6.68194                                                                               
[5]	validation-rmse:6.39878                                                                               
[6]	validation-rmse:6.17819                                                                               
[7]	validation-rmse:6.00692                                                                               
[8]	validation-rmse:5.87445                                                                               
[9]	validation-rmse:5.77295                                                                               
[10]	validation-rmse:5.69508                                                                              
[11]	validation-rmse:5.63505                                                                              
[12]	vali





[0]	validation-rmse:6.94796                                                                               
[1]	validation-rmse:6.00275                                                                               
[2]	validation-rmse:5.72395                                                                               
[3]	validation-rmse:5.65509                                                                               
[4]	validation-rmse:5.64244                                                                               
[5]	validation-rmse:5.63648                                                                               
[6]	validation-rmse:5.62087                                                                               
[7]	validation-rmse:5.61775                                                                               
[8]	validation-rmse:5.61789                                                                               
[9]	validation-rmse:5.61764          





[0]	validation-rmse:8.46896                                                                               
[1]	validation-rmse:7.58985                                                                               
[2]	validation-rmse:6.94747                                                                               
[3]	validation-rmse:6.48655                                                                               
[4]	validation-rmse:6.16268                                                                               
[5]	validation-rmse:5.93754                                                                               
[6]	validation-rmse:5.78167                                                                               
[7]	validation-rmse:5.67408                                                                               
[8]	validation-rmse:5.59978                                                                               
[9]	validation-rmse:5.54812          





[0]	validation-rmse:8.49417                                                                               
[1]	validation-rmse:7.62709                                                                               
[2]	validation-rmse:6.98932                                                                               
[3]	validation-rmse:6.52853                                                                               
[4]	validation-rmse:6.19984                                                                               
[5]	validation-rmse:5.96800                                                                               
[6]	validation-rmse:5.80642                                                                               
[7]	validation-rmse:5.69383                                                                               
[8]	validation-rmse:5.61512                                                                               
[9]	validation-rmse:5.56044          





[0]	validation-rmse:8.48517                                                                               
[1]	validation-rmse:7.61287                                                                               
[2]	validation-rmse:6.97208                                                                               
[3]	validation-rmse:6.51259                                                                               
[4]	validation-rmse:6.18749                                                                               
[5]	validation-rmse:5.95919                                                                               
[6]	validation-rmse:5.79986                                                                               
[7]	validation-rmse:5.69039                                                                               
[8]	validation-rmse:5.61361                                                                               
[9]	validation-rmse:5.56086          





[0]	validation-rmse:8.99601                                                                               
[1]	validation-rmse:8.43120                                                                               
[2]	validation-rmse:7.94758                                                                               
[3]	validation-rmse:7.53445                                                                               
[4]	validation-rmse:7.18498                                                                               
[5]	validation-rmse:6.88977                                                                               
[6]	validation-rmse:6.64231                                                                               
[7]	validation-rmse:6.43556                                                                               
[8]	validation-rmse:6.26366                                                                               
[9]	validation-rmse:6.12080          





[0]	validation-rmse:9.15640                                                                               
[1]	validation-rmse:8.71158                                                                               
[2]	validation-rmse:8.31439                                                                               
[3]	validation-rmse:7.96069                                                                               
[4]	validation-rmse:7.64652                                                                               
[5]	validation-rmse:7.36764                                                                               
[6]	validation-rmse:7.12128                                                                               
[7]	validation-rmse:6.90426                                                                               
[8]	validation-rmse:6.71453                                                                               
[9]	validation-rmse:6.54850          





[0]	validation-rmse:7.98326                                                                               
[1]	validation-rmse:6.93359                                                                               
[2]	validation-rmse:6.30197                                                                               
[3]	validation-rmse:5.93522                                                                               
[4]	validation-rmse:5.72671                                                                               
[5]	validation-rmse:5.61408                                                                               
[6]	validation-rmse:5.55189                                                                               
[7]	validation-rmse:5.51738                                                                               
[8]	validation-rmse:5.49609                                                                               
[9]	validation-rmse:5.47957          





[0]	validation-rmse:8.46687                                                                               
[1]	validation-rmse:7.58621                                                                               
[2]	validation-rmse:6.94554                                                                               
[3]	validation-rmse:6.48690                                                                               
[4]	validation-rmse:6.16311                                                                               
[5]	validation-rmse:5.93731                                                                               
[6]	validation-rmse:5.78138                                                                               
[7]	validation-rmse:5.67435                                                                               
[8]	validation-rmse:5.60036                                                                               
[9]	validation-rmse:5.55058          





[0]	validation-rmse:8.65679                                                                               
[1]	validation-rmse:7.87505                                                                               
[2]	validation-rmse:7.26745                                                                               
[3]	validation-rmse:6.80221                                                                               
[4]	validation-rmse:6.45178                                                                               
[5]	validation-rmse:6.18703                                                                               
[6]	validation-rmse:5.99307                                                                               
[7]	validation-rmse:5.84963                                                                               
[8]	validation-rmse:5.74371                                                                               
[9]	validation-rmse:5.66588          





[0]	validation-rmse:7.73972                                                                               
[1]	validation-rmse:6.64431                                                                               
[2]	validation-rmse:6.06233                                                                               
[3]	validation-rmse:5.76479                                                                               
[4]	validation-rmse:5.61273                                                                               
[5]	validation-rmse:5.53080                                                                               
[6]	validation-rmse:5.49016                                                                               
[7]	validation-rmse:5.47022                                                                               
[8]	validation-rmse:5.46399                                                                               
[9]	validation-rmse:5.47398          





[0]	validation-rmse:9.14651                                                                               
[1]	validation-rmse:8.69416                                                                               
[2]	validation-rmse:8.29106                                                                               
[3]	validation-rmse:7.93292                                                                               
[4]	validation-rmse:7.61574                                                                               
[5]	validation-rmse:7.33567                                                                               
[6]	validation-rmse:7.08903                                                                               
[7]	validation-rmse:6.87223                                                                               
[8]	validation-rmse:6.68247                                                                               
[9]	validation-rmse:6.51684          





[0]	validation-rmse:8.22865                                                                               
[1]	validation-rmse:7.25027                                                                               
[2]	validation-rmse:6.60071                                                                               
[3]	validation-rmse:6.18004                                                                               
[4]	validation-rmse:5.91160                                                                               
[5]	validation-rmse:5.74033                                                                               
[6]	validation-rmse:5.63011                                                                               
[7]	validation-rmse:5.56913                                                                               
[8]	validation-rmse:5.53425                                                                               
[9]	validation-rmse:5.50795          





[0]	validation-rmse:6.52182                                                                               
[1]	validation-rmse:5.69748                                                                               
[2]	validation-rmse:5.51116                                                                               
[3]	validation-rmse:5.47256                                                                               
[4]	validation-rmse:5.45645                                                                               
[5]	validation-rmse:5.45283                                                                               
[6]	validation-rmse:5.45368                                                                               
[7]	validation-rmse:5.45463                                                                               
[8]	validation-rmse:5.45303                                                                               
[9]	validation-rmse:5.45481          





[0]	validation-rmse:7.82248                                                                               
[1]	validation-rmse:6.74288                                                                               
[2]	validation-rmse:6.13749                                                                               
[3]	validation-rmse:5.80589                                                                               
[4]	validation-rmse:5.63356                                                                               
[5]	validation-rmse:5.54278                                                                               
[6]	validation-rmse:5.49179                                                                               
[7]	validation-rmse:5.46696                                                                               
[8]	validation-rmse:5.45077                                                                               
[9]	validation-rmse:5.44467          





[0]	validation-rmse:9.06867                                                                               
[1]	validation-rmse:8.55671                                                                               
[2]	validation-rmse:8.10922                                                                               
[3]	validation-rmse:7.72019                                                                               
[4]	validation-rmse:7.38357                                                                               
[5]	validation-rmse:7.09217                                                                               
[6]	validation-rmse:6.84194                                                                               
[7]	validation-rmse:6.62765                                                                               
[8]	validation-rmse:6.44471                                                                               
[9]	validation-rmse:6.28902          





[0]	validation-rmse:9.33177                                                                               
[1]	validation-rmse:9.03247                                                                               
[2]	validation-rmse:8.75366                                                                               
[3]	validation-rmse:8.49403                                                                               
[4]	validation-rmse:8.25220                                                                               
[5]	validation-rmse:8.02781                                                                               
[6]	validation-rmse:7.81940                                                                               
[7]	validation-rmse:7.62651                                                                               
[8]	validation-rmse:7.44746                                                                               
[9]	validation-rmse:7.28171          





[0]	validation-rmse:9.07468                                                                               
[1]	validation-rmse:8.56735                                                                               
[2]	validation-rmse:8.12304                                                                               
[3]	validation-rmse:7.73705                                                                               
[4]	validation-rmse:7.40219                                                                               
[5]	validation-rmse:7.11183                                                                               
[6]	validation-rmse:6.86230                                                                               
[7]	validation-rmse:6.64869                                                                               
[8]	validation-rmse:6.46580                                                                               
[9]	validation-rmse:6.30973          




[0]	validation-rmse:9.26188                                                                               
[1]	validation-rmse:8.90289                                                                               
[2]	validation-rmse:8.57365                                                                               
 72%|█████████████████████████▉          | 36/50 [04:25<01:43,  7.42s/trial, best loss: 5.423936449051383]




[3]	validation-rmse:8.27227                                                                               
[4]	validation-rmse:7.99687                                                                               
[5]	validation-rmse:7.74534                                                                               
[6]	validation-rmse:7.51634                                                                               
[7]	validation-rmse:7.30809                                                                               
[8]	validation-rmse:7.11911                                                                               
[9]	validation-rmse:6.94880                                                                               
[10]	validation-rmse:6.79463                                                                              
[11]	validation-rmse:6.65554                                                                              
[12]	validation-rmse:6.53044         





[0]	validation-rmse:8.87638                                                                               
[1]	validation-rmse:8.22840                                                                               
[2]	validation-rmse:7.69075                                                                               
[3]	validation-rmse:7.24970                                                                               
[4]	validation-rmse:6.88943                                                                               
[5]	validation-rmse:6.59752                                                                               
[6]	validation-rmse:6.36268                                                                               
[7]	validation-rmse:6.17548                                                                               
[8]	validation-rmse:6.02644                                                                               
[9]	validation-rmse:5.90633          




[0]	validation-rmse:9.12507                                                                               
[1]	validation-rmse:8.65596                                                                               
[2]	validation-rmse:8.24021                                                                               
 76%|███████████████████████████▎        | 38/50 [04:45<01:47,  8.97s/trial, best loss: 5.423936449051383]




[3]	validation-rmse:7.87277                                                                               
[4]	validation-rmse:7.54882                                                                               
[5]	validation-rmse:7.26497                                                                               
[6]	validation-rmse:7.01671                                                                               
[7]	validation-rmse:6.80022                                                                               
[8]	validation-rmse:6.61127                                                                               
[9]	validation-rmse:6.44778                                                                               
[10]	validation-rmse:6.30588                                                                              
[11]	validation-rmse:6.18275                                                                              
[12]	validation-rmse:6.07703         





[0]	validation-rmse:9.02582                                                                               
[1]	validation-rmse:8.48431                                                                               
[2]	validation-rmse:8.01829                                                                               
[3]	validation-rmse:7.61953                                                                               
[4]	validation-rmse:7.28188                                                                               
[5]	validation-rmse:6.99458                                                                               
[6]	validation-rmse:6.75021                                                                               
[7]	validation-rmse:6.54562                                                                               
[8]	validation-rmse:6.37531                                                                               
[9]	validation-rmse:6.23341          





[0]	validation-rmse:7.43477                                                                               
[1]	validation-rmse:6.33869                                                                               
[2]	validation-rmse:5.84311                                                                               
[3]	validation-rmse:5.63042                                                                               
[4]	validation-rmse:5.54118                                                                               
[5]	validation-rmse:5.50567                                                                               
[6]	validation-rmse:5.48212                                                                               
[7]	validation-rmse:5.48403                                                                               
[8]	validation-rmse:5.48975                                                                               
[9]	validation-rmse:5.49019          





[0]	validation-rmse:8.90999                                                                               
[1]	validation-rmse:8.29182                                                                               
[2]	validation-rmse:7.79433                                                                               
[3]	validation-rmse:7.39720                                                                               
[4]	validation-rmse:7.08295                                                                               
[5]	validation-rmse:6.82813                                                                               
[6]	validation-rmse:6.64114                                                                               
[7]	validation-rmse:6.49390                                                                               
[8]	validation-rmse:6.36922                                                                               
[9]	validation-rmse:6.28924          





[0]	validation-rmse:9.28458                                                                               
[1]	validation-rmse:8.94453                                                                               
[2]	validation-rmse:8.63136                                                                               
[3]	validation-rmse:8.34265                                                                               
[4]	validation-rmse:8.07767                                                                               
[5]	validation-rmse:7.83424                                                                               
[6]	validation-rmse:7.61143                                                                               
[7]	validation-rmse:7.40726                                                                               
[8]	validation-rmse:7.22099                                                                               
[9]	validation-rmse:7.05071          





[0]	validation-rmse:9.26606                                                                               
[1]	validation-rmse:8.91087                                                                               
[2]	validation-rmse:8.58463                                                                               
[3]	validation-rmse:8.28534                                                                               
[4]	validation-rmse:8.01135                                                                               
[5]	validation-rmse:7.76159                                                                               
[6]	validation-rmse:7.53370                                                                               
[7]	validation-rmse:7.32601                                                                               
[8]	validation-rmse:7.13743                                                                               
[9]	validation-rmse:6.96667          




[0]	validation-rmse:5.51211                                                                               
[1]	validation-rmse:5.57988                                                                               
[2]	validation-rmse:5.58710                                                                               
[3]	validation-rmse:5.59431                                                                               
[4]	validation-rmse:5.59749                                                                               
[5]	validation-rmse:5.60432                                                                               
[6]	validation-rmse:5.60270                                                                               
 88%|███████████████████████████████▋    | 44/50 [05:40<00:59, 10.00s/trial, best loss: 5.412425829765289]




[7]	validation-rmse:5.60211                                                                               
[8]	validation-rmse:5.60505                                                                               
[9]	validation-rmse:5.60501                                                                               
[10]	validation-rmse:5.60648                                                                              
[11]	validation-rmse:5.60819                                                                              
[12]	validation-rmse:5.62973                                                                              
[13]	validation-rmse:5.63395                                                                              
[14]	validation-rmse:5.63203                                                                              
[15]	validation-rmse:5.63115                                                                              
[16]	validation-rmse:5.63840         





[0]	validation-rmse:9.20867                                                                               
[1]	validation-rmse:8.80561                                                                               
[2]	validation-rmse:8.44121                                                                               
[3]	validation-rmse:8.11166                                                                               
[4]	validation-rmse:7.81432                                                                               
[5]	validation-rmse:7.54743                                                                               
[6]	validation-rmse:7.30785                                                                               
[7]	validation-rmse:7.09371                                                                               
[8]	validation-rmse:6.90243                                                                               
[9]	validation-rmse:6.73251          





[0]	validation-rmse:9.29891                                                                               
[1]	validation-rmse:8.97120                                                                               
[2]	validation-rmse:8.66806                                                                               
[3]	validation-rmse:8.38813                                                                               
[4]	validation-rmse:8.12992                                                                               
[5]	validation-rmse:7.89211                                                                               
[6]	validation-rmse:7.67336                                                                               
[7]	validation-rmse:7.47242                                                                               
[8]	validation-rmse:7.28810                                                                               
[9]	validation-rmse:7.11920          





[0]	validation-rmse:6.14708                                                                               
[1]	validation-rmse:5.54843                                                                               
[2]	validation-rmse:5.45784                                                                               
[3]	validation-rmse:5.44489                                                                               
[4]	validation-rmse:5.44456                                                                               
[5]	validation-rmse:5.45201                                                                               
[6]	validation-rmse:5.45166                                                                               
[7]	validation-rmse:5.45478                                                                               
[8]	validation-rmse:5.45800                                                                               
[9]	validation-rmse:5.45915          





[0]	validation-rmse:5.45753                                                                               
[1]	validation-rmse:5.46709                                                                               
[2]	validation-rmse:5.46030                                                                               
[3]	validation-rmse:5.47503                                                                               
[4]	validation-rmse:5.47725                                                                               
[5]	validation-rmse:5.48470                                                                               
[6]	validation-rmse:5.48061                                                                               
[7]	validation-rmse:5.48274                                                                               
[8]	validation-rmse:5.48177                                                                               
[9]	validation-rmse:5.48300          




[0]	validation-rmse:8.64962                                                                               
[1]	validation-rmse:7.86343                                                                               
[2]	validation-rmse:7.25496                                                                               
 98%|███████████████████████████████████▎| 49/50 [06:02<00:05,  5.12s/trial, best loss: 5.412425829765289]




[3]	validation-rmse:6.78991                                                                               
[4]	validation-rmse:6.43866                                                                               
[5]	validation-rmse:6.17614                                                                               
[6]	validation-rmse:5.98044                                                                               
[7]	validation-rmse:5.83713                                                                               
[8]	validation-rmse:5.73086                                                                               
[9]	validation-rmse:5.65265                                                                               
[10]	validation-rmse:5.59507                                                                              
[11]	validation-rmse:5.55263                                                                              
[12]	validation-rmse:5.52097         




Get the best rmse params from mlflow ui and replace the params values in search space and run the mlflow run

In [39]:
with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        best_params = {
            'learning_rate' : 0.7993188205339578,
            'max_depth' : 73,
            'min_child_weight' : 0.5694797340885676,
            'objective' : 'reg:linear',
            'reg_alpha' : 0.06481429405630112,
            'reg_lambda' : 0.20149788479279268,
            'seed' : 42
        }
        
        mlflow.log_params(best_params)
        booster = xgb.train(
            params=best_params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

        with open('models/preprocessor.b', 'wb') as f_out:
            pickle.dump(dv, f_out)
        mlflow.log_artifact('models/preprocessor.b', artifact_path='preprocessor')
    
        mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")



[0]	validation-rmse:6.16803
[1]	validation-rmse:6.06927
[2]	validation-rmse:6.08759
[3]	validation-rmse:6.08715
[4]	validation-rmse:6.09183
[5]	validation-rmse:6.12691
[6]	validation-rmse:6.12641
[7]	validation-rmse:6.13084
[8]	validation-rmse:6.13386
[9]	validation-rmse:6.13677
[10]	validation-rmse:6.13767
[11]	validation-rmse:6.13780
[12]	validation-rmse:6.13121
[13]	validation-rmse:6.13153
[14]	validation-rmse:6.13282
[15]	validation-rmse:6.13256
[16]	validation-rmse:6.13093
[17]	validation-rmse:6.13379
[18]	validation-rmse:6.13583
[19]	validation-rmse:6.13697
[20]	validation-rmse:6.14021
[21]	validation-rmse:6.14780
[22]	validation-rmse:6.15560
[23]	validation-rmse:6.15871
[24]	validation-rmse:6.16125
[25]	validation-rmse:6.16244
[26]	validation-rmse:6.16508
[27]	validation-rmse:6.16788
[28]	validation-rmse:6.17605
[29]	validation-rmse:6.17858
[30]	validation-rmse:6.18407
[31]	validation-rmse:6.18994
[32]	validation-rmse:6.19217
[33]	validation-rmse:6.19261
[34]	validation-rmse:6.1



In [42]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import LinearSVR

#mlflow.sklearn.autolog()

for model_class in (RandomForestRegressor, LinearSVR):

    with mlflow.start_run():

        with open('models/preprocessor.b', 'wb') as f_out:
            pickle.dump(dv, f_out)
        mlflow.log_artifact('models/preprocessor.b', artifact_path='preprocessor')

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

