In [1]:
import pandas as pd

In [2]:
import pickle

In [3]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.metrics import mean_squared_error

In [4]:
def read_dataframe(filename, min_duration=1, max_duration=60):
    """Load a green-taxi trip parquet file and derive the model's features.

    Parameters
    ----------
    filename : str or path-like
        Local path or URL, passed straight to ``pd.read_parquet``.
    min_duration, max_duration : float, default 1 and 60
        Inclusive bounds, in minutes, on trip duration; rows outside the
        range (or with a missing duration) are dropped.

    Returns
    -------
    pandas.DataFrame
        The filtered rows (original index kept, not reset) with:
        ``duration`` -- dropoff minus pickup time, in minutes (float);
        ``PULocationID`` / ``DOLocationID`` -- cast to ``str``;
        ``PU_DO`` -- ``"<PULocationID>_<DOLocationID>"`` pair key.
    """
    df = pd.read_parquet(filename)

    # Vectorised timedelta -> minutes; same result as a per-row
    # `.apply(lambda td: td.total_seconds() / 60)` but without the Python loop.
    df['duration'] = (
        df['lpep_dropoff_datetime'] - df['lpep_pickup_datetime']
    ).dt.total_seconds() / 60

    # `.copy()` so the column assignments below modify an independent frame,
    # not a slice of the parquet frame (avoids SettingWithCopyWarning).
    df = df[df['duration'].between(min_duration, max_duration)].copy()

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)

    # Combined pickup/dropoff key, e.g. "43_151".
    df['PU_DO'] = df['PULocationID'] + '_' + df['DOLocationID']

    return df

In [5]:
# NYC TLC green-taxi trips: January 2021 for training, February 2021 for validation.
TRAIN_URL = 'https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-01.parquet'
VAL_URL = 'https://d37ci6vzurychx.cloudfront.net/trip-data/green_tripdata_2021-02.parquet'

df_train = read_dataframe(TRAIN_URL)
df_val = read_dataframe(VAL_URL)

In [6]:
# Feature spec: the combined pickup/dropoff key plus trip distance.
# Alternative feature set: 'PULocationID' and 'DOLocationID' as separate categoricals.
categorical = ['PU_DO']
numerical = ['trip_distance']
feature_columns = categorical + numerical

# Vocabulary is fitted on the training split only; validation reuses it via transform().
dv = DictVectorizer()

train_dicts = df_train[feature_columns].to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

val_dicts = df_val[feature_columns].to_dict(orient='records')
X_val = dv.transform(val_dicts)

In [7]:
# Regression target: the trip duration in minutes computed by read_dataframe.
target = 'duration'
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

In [8]:
import xgboost as xgb

In [9]:
from pathlib import Path

In [10]:
# Directory (relative to the working directory) for locally saved artifacts,
# such as the pickled DictVectorizer. exist_ok=True keeps re-runs harmless.
models_folder = Path('models')
models_folder.mkdir(parents=False, exist_ok=True)

In [11]:
import mlflow

# mlflow.set_tracking_uri("sqlite:////workspaces/MLOps/training/mlflow.db")
# mlflow.set_experiment("test03")

2025/07/17 17:33:19 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/07/17 17:33:19 INFO mlflow.store.db.utils: Updating database tables
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
2025/07/17 17:33:19 INFO mlflow.tracking.fluent: Experiment with name 'test03' does not exist. Creating a new experiment.


<Experiment: artifact_location='/workspaces/MLOps/training/mlruns/4', creation_time=1752773599685, experiment_id='4', last_update_time=1752773599685, lifecycle_stage='active', name='test03', tags={}>

In [13]:
import mlflow
import xgboost as xgb
import pickle
from sklearn.metrics import mean_squared_error

import os

# Tracking backend: honour MLFLOW_TRACKING_URI when set, otherwise fall back to
# the original SQLite store. The fallback is an absolute, machine-specific path.
MLFLOW_TRACKING_URI = os.environ.get(
    "MLFLOW_TRACKING_URI", "sqlite:////workspaces/MLOps/training/mlflow.db"
)
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment("test03")

LEARNING_RATES = [0.01, 0.1, 0.2]  # one MLflow run per learning rate
NUM_BOOST_ROUND = 30
# NOTE(review): with patience >= NUM_BOOST_ROUND early stopping can never fire;
# it only takes effect if NUM_BOOST_ROUND is raised well above this value.
EARLY_STOPPING_ROUNDS = 50

# Loop-invariant inputs, built once instead of once per run.
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

# The fitted DictVectorizer is the same for every run: serialise it once and
# attach the same file to each run below.
preprocessor_path = models_folder / "preprocessor.b"
with open(preprocessor_path, "wb") as f_out:
    pickle.dump(dv, f_out)

for lr in LEARNING_RATES:
    with mlflow.start_run():
        params = {
            'learning_rate': lr,
            'max_depth': 30,
            'min_child_weight': 1.06,
            # 'reg:linear' is a deprecated alias of 'reg:squarederror'.
            'objective': 'reg:squarederror',
            'reg_alpha': 0.018,
            'reg_lambda': 0.011,
            'seed': 42,
        }
        mlflow.log_params(params)

        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=NUM_BOOST_ROUND,
            evals=[(valid, 'validation')],
            early_stopping_rounds=EARLY_STOPPING_ROUNDS,
        )

        # xgb.train keeps every round it trained; score with the best iteration
        # so the metric stays correct once early stopping can actually trigger.
        y_pred = booster.predict(
            valid, iteration_range=(0, booster.best_iteration + 1)
        )
        # RMSE taken explicitly: mean_squared_error's `squared=False` was
        # deprecated in scikit-learn 1.4 and removed in 1.6.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)

        # Log preprocessor and model
        mlflow.log_artifact(str(preprocessor_path), artifact_path="preprocessor")
        mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")




[0]	validation-rmse:12.13190
[1]	validation-rmse:12.05174
[2]	validation-rmse:11.97262
[3]	validation-rmse:11.89453
[4]	validation-rmse:11.81743
[5]	validation-rmse:11.74133
[6]	validation-rmse:11.66627
[7]	validation-rmse:11.59210
[8]	validation-rmse:11.51902
[9]	validation-rmse:11.44677
[10]	validation-rmse:11.37554
[11]	validation-rmse:11.30515
[12]	validation-rmse:11.23571
[13]	validation-rmse:11.16721
[14]	validation-rmse:11.09963
[15]	validation-rmse:11.03296
[16]	validation-rmse:10.96721
[17]	validation-rmse:10.90224
[18]	validation-rmse:10.83825
[19]	validation-rmse:10.77504
[20]	validation-rmse:10.71262
[21]	validation-rmse:10.65113
[22]	validation-rmse:10.59047
[23]	validation-rmse:10.53064
[24]	validation-rmse:10.47168
[25]	validation-rmse:10.41351
[26]	validation-rmse:10.35608
[27]	validation-rmse:10.29946
[28]	validation-rmse:10.24370
[29]	validation-rmse:10.18871




[0]	validation-rmse:11.41214
[1]	validation-rmse:10.71494
[2]	validation-rmse:10.10951
[3]	validation-rmse:9.58848
[4]	validation-rmse:9.13814
[5]	validation-rmse:8.75229
[6]	validation-rmse:8.42253
[7]	validation-rmse:8.14331
[8]	validation-rmse:7.90431
[9]	validation-rmse:7.70214
[10]	validation-rmse:7.53185
[11]	validation-rmse:7.38625
[12]	validation-rmse:7.26296
[13]	validation-rmse:7.15897
[14]	validation-rmse:7.07113
[15]	validation-rmse:6.99730
[16]	validation-rmse:6.93339
[17]	validation-rmse:6.88000
[18]	validation-rmse:6.83341
[19]	validation-rmse:6.79391
[20]	validation-rmse:6.75888
[21]	validation-rmse:6.72947
[22]	validation-rmse:6.70372
[23]	validation-rmse:6.68119
[24]	validation-rmse:6.66175
[25]	validation-rmse:6.64508
[26]	validation-rmse:6.63008
[27]	validation-rmse:6.61629
[28]	validation-rmse:6.60422
[29]	validation-rmse:6.59343




[0]	validation-rmse:10.63945
[1]	validation-rmse:9.47628
[2]	validation-rmse:8.63028
[3]	validation-rmse:8.02696
[4]	validation-rmse:7.60037
[5]	validation-rmse:7.29939
[6]	validation-rmse:7.08604
[7]	validation-rmse:6.93732
[8]	validation-rmse:6.83199
[9]	validation-rmse:6.75769
[10]	validation-rmse:6.70059
[11]	validation-rmse:6.65816
[12]	validation-rmse:6.62545
[13]	validation-rmse:6.60130
[14]	validation-rmse:6.58387
[15]	validation-rmse:6.56888
[16]	validation-rmse:6.55571
[17]	validation-rmse:6.54750
[18]	validation-rmse:6.53968
[19]	validation-rmse:6.53460
[20]	validation-rmse:6.52952
[21]	validation-rmse:6.52637
[22]	validation-rmse:6.52392
[23]	validation-rmse:6.52301
[24]	validation-rmse:6.52088
[25]	validation-rmse:6.51809
[26]	validation-rmse:6.51682
[27]	validation-rmse:6.51466
[28]	validation-rmse:6.51249
[29]	validation-rmse:6.51077


