In [1]:
# Record the Python interpreter version this notebook was executed with.
!python -V

Python 3.9.23


In [2]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from sklearn.metrics import root_mean_squared_error
import pickle

In [3]:
# incorporate mlflow
import mlflow

# Track runs in a local SQLite backend and group them under one named experiment.
TRACKING_URI = "sqlite:///mlflow.db"
EXPERIMENT_NAME = "nyc-taxi-experiment"

mlflow.set_tracking_uri(TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

2025/06/23 20:53:50 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/06/23 20:53:50 INFO mlflow.store.db.utils: Updating database tables
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.


<Experiment: artifact_location='/workspaces/mlops-zoomcamp/02-experiment_tracking/mlruns/1', creation_time=1750216104664, experiment_id='1', last_update_time=1750216104664, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [4]:
# compile to a function -- read the data

def read_dataframe(filename):
    """Load a green-taxi trip parquet file and prepare it for modelling.

    Adds ``diff`` (dropoff time minus pickup time) and ``duration`` (that
    difference in minutes), keeps only trips whose duration is between 1 and
    60 minutes inclusive, and casts the pickup/dropoff location IDs to strings
    so they are treated as categorical features (e.g. one-hot by DictVectorizer).

    Args:
        filename: path to a parquet file containing at least the
            ``lpep_pickup_datetime``, ``lpep_dropoff_datetime``,
            ``PULocationID`` and ``DOLocationID`` columns.

    Returns:
        The filtered DataFrame; original index labels are preserved.
    """
    trips = pd.read_parquet(filename)

    elapsed = trips.lpep_dropoff_datetime - trips.lpep_pickup_datetime
    trips = trips.assign(diff=elapsed, duration=elapsed.dt.total_seconds() / 60)

    # Drop implausible trips: shorter than a minute or longer than an hour.
    trips = trips[trips.duration.between(1, 60)]

    # Location IDs are labels, not quantities -> cast to str for categorical encoding.
    location_cols = ['PULocationID', 'DOLocationID']
    return trips.astype({col: 'str' for col in location_cols})

In [5]:
df_train = read_dataframe("data/green_tripdata_2021-01.parquet")
df_val = read_dataframe("data/green_tripdata_2021-02.parquet")

In [6]:
tuple(len(frame) for frame in (df_train, df_val))  # (training rows, validation rows)

(73908, 61921)

In [6]:
# Combined pickup/dropoff feature: the string "<PULocationID>_<DOLocationID>"
# lets the model learn a per-route effect rather than two independent ones.
for frame in (df_train, df_val):
    frame['PU_DO'] = frame['PULocationID'] + '_' + frame['DOLocationID']

In [7]:
categorical = ['PU_DO']  # earlier variant: ['PULocationID', 'DOLocationID']
numerical = ['trip_distance']
feature_columns = categorical + numerical

dv = DictVectorizer()

# Learn the feature vocabulary from the training month only...
train_dicts = df_train[feature_columns].to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

# ...and reuse it unchanged for validation (transform, never fit_transform),
# so both matrices share the same columns.
val_dicts = df_val[feature_columns].to_dict(orient='records')
X_val = dv.transform(val_dicts)

In [8]:
target = 'duration'
# Trip duration in minutes is the regression target for both splits.
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

In [10]:
# Baseline: ordinary least squares, scored on the validation month.
lr = LinearRegression()
y_val_pred = lr.fit(X_train, y_train).predict(X_val)  # fit() returns the estimator itself

root_mean_squared_error(y_val, y_val_pred)

7.758715209663881

In [11]:
# Ridge (L2-regularised linear regression) for comparison with the baseline.
rd = Ridge(alpha=10)
rd.fit(X_train, y_train)

y_pred, y_val_pred = rd.predict(X_train), rd.predict(X_val)

for label, y_true, y_hat in (
    ('Training RMSE is', y_train, y_pred),
    ('Validation RMSE is', y_val, y_val_pred),
):
    print(label, root_mean_squared_error(y_true, y_hat))

Training RMSE is 7.6660279773989375
Validation RMSE is 8.846837413677452


#### Save the linear regression (baseline) model together with its DictVectorizer

In [12]:
import os

# Persist the fitted vectorizer and baseline model as one pickled tuple so the
# pair can be reloaded together. Create the target directory first: on a fresh
# checkout `models/` may not exist and open() would raise FileNotFoundError.
os.makedirs('models', exist_ok=True)
with open('models/lin_reg.bin', 'wb') as f_out:  # 'wb': pickle writes bytes
    pickle.dump((dv, lr), f_out)

In [10]:
# Baseline linear regression, tracked as an MLflow run.
#
# To try Lasso instead, use e.g. `lr = Lasso(alpha=0.01)` and
# `mlflow.log_param("alpha", 0.01)`. alpha is the shrinkage penalty:
# alpha = 0 is equivalent to plain linear regression, larger values shrink more.
import os

with mlflow.start_run():
    mlflow.set_tag("developer", "CS")
    mlflow.set_tag("model", "linear_regression")
    # data provenance for this run
    mlflow.log_param("train-data-path", "data/green_tripdata_2021-01.parquet")
    mlflow.log_param("valid-data-path", "data/green_tripdata_2021-02.parquet")

    lr = LinearRegression()
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_train)
    y_val_pred = lr.predict(X_val)

    rmse_train = root_mean_squared_error(y_train, y_pred)
    rmse_val = root_mean_squared_error(y_val, y_val_pred)

    mlflow.log_metrics({
        "rmse_train": rmse_train,
        "rmse_val": rmse_val,
    })

    # Re-pickle the model fitted in *this* run before logging it, so the artifact
    # matches the metrics above. The file on disk may otherwise come from an
    # earlier cell, or be missing entirely.
    os.makedirs("models", exist_ok=True)
    with open("models/lin_reg.bin", "wb") as f_out:
        pickle.dump((dv, lr), f_out)

    # local_path: the pickled (dv, lr) on disk
    # artifact_path: the folder name under this run's artifacts in MLflow
    mlflow.log_artifact(local_path="models/lin_reg.bin", artifact_path="models_pickle")



In [11]:
# Must run before xgboost is imported; it avoided kernel crashes on import here.
# NOTE(review): KMP_DUPLICATE_LIB_OK presumably lets the OpenMP runtime tolerate
# a second copy of its library being loaded -- a workaround, not a fix; confirm.
import os
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

In [10]:
# try xgboost
import xgboost as xgb

from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [11]:
# Turn on MLflow autologging for XGBoost. Call it before creating the DMatrix
# datasets or training a model, so the later xgb.train calls get logged.
mlflow.xgboost.autolog()

In [12]:
# Wrap the sparse feature matrices and targets in XGBoost's native data container.
train, valid = (
    xgb.DMatrix(features, label=labels)
    for features, labels in ((X_train, y_train), (X_val, y_val))
)

In [None]:
def objective(params):
    """Hyperopt objective: train one XGBoost model and return its validation RMSE.

    Each call becomes its own MLflow run, tagged ``model=xgboost``, with the
    sampled hyperparameters and the train/validation RMSE logged. Reads the
    ``train``/``valid`` DMatrix objects and the ``y_train``/``y_val`` targets
    defined in earlier cells.

    Args:
        params: dict of XGBoost training parameters sampled from ``search_space``.

    Returns:
        dict with ``'loss'`` (validation RMSE, a float that hyperopt minimises)
        and ``'status'`` (``STATUS_OK``).
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, "validation")],
            early_stopping_rounds=10,  # instead of 50, which is used by the instructor
        )
        # xgb.train returns the booster from the *last* round, and early stopping
        # keeps training `early_stopping_rounds` past the best one. Score with the
        # trees up to the best iteration so the loss reflects the best model.
        best_rounds = (0, booster.best_iteration + 1)
        y_pred = booster.predict(train, iteration_range=best_rounds)
        y_val_pred = booster.predict(valid, iteration_range=best_rounds)

        rmse_train = root_mean_squared_error(y_train, y_pred)
        rmse_val = root_mean_squared_error(y_val, y_val_pred)

        mlflow.log_metrics({
            "rmse_train": rmse_train,
            "rmse_val": rmse_val,
        })
    # hyperopt minimises 'loss' (a number); 'status' must be one of its STATUS_* values
    return {'loss': rmse_val, 'status': STATUS_OK}

# Hyperparameter search space for TPE.
# hp.loguniform(label, low, high) samples exp(uniform(low, high)).
search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),  # [exp(-3), exp(0)] ~ [0.05, 1]
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:squarederror',  # 'reg:linear' is deprecated in XGBoost
    'seed': 42,
}

# Keep the Trials object so every evaluated configuration and its loss can be
# inspected after the search (e.g. trials.best_trial).
trials = Trials()
best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=trials,
)

  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:7.16223                           
[1]	validation-rmse:6.67195                           
[2]	validation-rmse:6.59432                           
[3]	validation-rmse:6.56282                           
[4]	validation-rmse:6.55057                           
[5]	validation-rmse:6.54149                           
[6]	validation-rmse:6.52071                           
[7]	validation-rmse:6.51534                           
[8]	validation-rmse:6.50995                           
[9]	validation-rmse:6.50652                           
[10]	validation-rmse:6.50377                          
[11]	validation-rmse:6.49813                          
[12]	validation-rmse:6.49549                          
[13]	validation-rmse:6.49213                          
[14]	validation-rmse:6.49045                          
[15]	validation-rmse:6.48916                          
[16]	validation-rmse:6.48721                          
[17]	validation-rmse:6.48440                          
[18]	valid




[0]	validation-rmse:9.83771                                                     
[1]	validation-rmse:8.43321                                                     
[2]	validation-rmse:7.63906                                                     
[3]	validation-rmse:7.20078                                                     
[4]	validation-rmse:6.95922                                                     
[5]	validation-rmse:6.82344                                                     
[6]	validation-rmse:6.73585                                                     
[7]	validation-rmse:6.68452                                                     
[8]	validation-rmse:6.65080                                                     
[9]	validation-rmse:6.62794                                                     
[10]	validation-rmse:6.61318                                                    
[11]	validation-rmse:6.60104                                                    
[12]	validation-rmse:6.58849




[0]	validation-rmse:9.83313                                                       
[1]	validation-rmse:8.41097                                                       
[2]	validation-rmse:7.59236                                                       
[3]	validation-rmse:7.13499                                                       
[4]	validation-rmse:6.88155                                                       
[5]	validation-rmse:6.73748                                                       
[6]	validation-rmse:6.65117                                                       
[7]	validation-rmse:6.59787                                                       
[8]	validation-rmse:6.56470                                                       
[9]	validation-rmse:6.54224                                                       
[10]	validation-rmse:6.52511                                                      
[11]	validation-rmse:6.51372                                                      
[12]




[1]	validation-rmse:10.69931                                                     
[2]	validation-rmse:10.09482                                                     
[3]	validation-rmse:9.57626                                                      
[4]	validation-rmse:9.13359                                                      
[5]	validation-rmse:8.75628                                                      
[6]	validation-rmse:8.43595                                                      
[7]	validation-rmse:8.16567                                                      
[8]	validation-rmse:7.93773                                                      
[9]	validation-rmse:7.74615                                                      
[10]	validation-rmse:7.58528                                                     
[11]	validation-rmse:7.45011                                                     
[12]	validation-rmse:7.33666                                                     
[13]	validation-




[0]	validation-rmse:10.80497                                                     
[1]	validation-rmse:9.71761                                                      
[2]	validation-rmse:8.88752                                                      
[3]	validation-rmse:8.26066                                                      
[4]	validation-rmse:7.79624                                                      
[5]	validation-rmse:7.45109                                                      
[6]	validation-rmse:7.19744                                                      
[7]	validation-rmse:7.01303                                                      
[8]	validation-rmse:6.87644                                                      
[9]	validation-rmse:6.77639                                                      
[10]	validation-rmse:6.69786                                                     
[11]	validation-rmse:6.63969                                                     
[12]	validation-




[0]	validation-rmse:8.76214                                                       
[1]	validation-rmse:7.35923                                                       
[2]	validation-rmse:6.82793                                                       
[3]	validation-rmse:6.62056                                                       
[4]	validation-rmse:6.53444                                                       
[5]	validation-rmse:6.49185                                                       
[6]	validation-rmse:6.46809                                                       
[7]	validation-rmse:6.45228                                                       
[8]	validation-rmse:6.44501                                                       
[9]	validation-rmse:6.44099                                                       
[10]	validation-rmse:6.43629                                                      
[11]	validation-rmse:6.43209                                                      
[12]




[0]	validation-rmse:6.67143                                                      
[1]	validation-rmse:6.59780                                                      
[2]	validation-rmse:6.59056                                                      
[3]	validation-rmse:6.57901                                                      
[4]	validation-rmse:6.57283                                                      
[5]	validation-rmse:6.56585                                                      
[6]	validation-rmse:6.55226                                                      
[7]	validation-rmse:6.54965                                                      
[8]	validation-rmse:6.53192                                                      
[9]	validation-rmse:6.52575                                                      
[10]	validation-rmse:6.51947                                                     
[11]	validation-rmse:6.51048                                                     
[12]	validation-




[0]	validation-rmse:11.21168                                                   
[1]	validation-rmse:10.37733                                                   
[2]	validation-rmse:9.68724                                                    
[3]	validation-rmse:9.12009                                                    
[4]	validation-rmse:8.65737                                                    
[5]	validation-rmse:8.28193                                                    
[6]	validation-rmse:7.97846                                                    
[7]	validation-rmse:7.73330                                                    
[8]	validation-rmse:7.53682                                                    
[9]	validation-rmse:7.37978                                                    
[10]	validation-rmse:7.25337                                                   
[11]	validation-rmse:7.15256                                                   
[12]	validation-rmse:7.07209            




[0]	validation-rmse:9.21928                                                       
[1]	validation-rmse:7.78316                                                       
[2]	validation-rmse:7.13138                                                       
[3]	validation-rmse:6.84294                                                       
[4]	validation-rmse:6.70839                                                       
[5]	validation-rmse:6.63858                                                       
[6]	validation-rmse:6.59719                                                       
[7]	validation-rmse:6.57452                                                       
[8]	validation-rmse:6.56144                                                       
[9]	validation-rmse:6.55128                                                       
[10]	validation-rmse:6.54750                                                      
[11]	validation-rmse:6.54126                                                      
[12]




[0]	validation-rmse:9.80485                                                       
[1]	validation-rmse:8.41460                                                       
[2]	validation-rmse:7.64948                                                       
[3]	validation-rmse:7.23671                                                       
[4]	validation-rmse:7.01668                                                       
[5]	validation-rmse:6.89746                                                       
[6]	validation-rmse:6.82396                                                       
[7]	validation-rmse:6.78087                                                       
[8]	validation-rmse:6.74889                                                       
[9]	validation-rmse:6.73112                                                       
[10]	validation-rmse:6.71486                                                      
[11]	validation-rmse:6.70127                                                      
[12]

KeyboardInterrupt: 

In [15]:
# Retrain one configuration picked from the hyperopt runs. Choosing the "best"
# model weighs the metric (RMSE here) against training time, model size
# (max_depth), etc.
#
# No explicit log_* calls here: this relies on mlflow.xgboost.autolog(), enabled
# in an earlier cell (xgboost is one of the libraries MLflow can autolog).

params = {
    'learning_rate': 0.20472169880371677,
    'max_depth': 17,
    'min_child_weight': 1.2402611720043835,
    'objective': 'reg:squarederror',  # replaces the deprecated 'reg:linear'
    'reg_alpha': 0.28567896734700793,
    'reg_lambda': 0.004264404814393109,
    'seed': 42,
    'eval_metric': 'rmse',
    'verbosity': 1,
}

# An explicit start_run() makes sure the run's artifacts and metadata are recorded.
with mlflow.start_run():
    booster = xgb.train(
        params=params,
        dtrain=train,
        # NOTE(review): original note said "triggers crashes" here -- presumably
        # larger round counts crashed the kernel; confirm before raising it.
        num_boost_round=100,
        evals=[(valid, "validation")],
        early_stopping_rounds=10,  # instead of 50, which is used by the instructor
    )
    


[0]	validation-rmse:10.62105
[1]	validation-rmse:9.45454
[2]	validation-rmse:8.61244
[3]	validation-rmse:8.02161
[4]	validation-rmse:7.60613
[5]	validation-rmse:7.31430
[6]	validation-rmse:7.11506
[7]	validation-rmse:6.97578
[8]	validation-rmse:6.87711
[9]	validation-rmse:6.80659
[10]	validation-rmse:6.75482
[11]	validation-rmse:6.71870
[12]	validation-rmse:6.68975
[13]	validation-rmse:6.67024
[14]	validation-rmse:6.65272
[15]	validation-rmse:6.63910
[16]	validation-rmse:6.62823
[17]	validation-rmse:6.62004
[18]	validation-rmse:6.61336
[19]	validation-rmse:6.60866
[20]	validation-rmse:6.60602
[21]	validation-rmse:6.60323
[22]	validation-rmse:6.60192
[23]	validation-rmse:6.60053
[24]	validation-rmse:6.59835
[25]	validation-rmse:6.59609
[26]	validation-rmse:6.59387
[27]	validation-rmse:6.59214
[28]	validation-rmse:6.59088
[29]	validation-rmse:6.58932
[30]	validation-rmse:6.58731
[31]	validation-rmse:6.58478
[32]	validation-rmse:6.58333
[33]	validation-rmse:6.58150
[34]	validation-rmse:6.

