In [1]:
!python -V

Python 3.9.12


In [16]:
import os
import sys
# Make modules that live one level above the working directory importable
# by putting that directory at the front of the module search path.
parent_dir = os.path.abspath(os.pardir)
sys.path[:0] = [parent_dir]

In [20]:
import pickle

from env import *
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [35]:
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error, root_mean_squared_error

In [25]:
import mlflow
import mlflow.experiments


# Send every run from this notebook to the tracking server at this address,
# grouped under a single named experiment.
mlflow.set_tracking_uri(uri="http://127.0.0.1:5000")
mlflow.set_experiment(experiment_name="nyc-taxi-experiment")

# Echo the URI actually in effect so a misconfiguration is visible right away.
print(f"tracking URI: '{mlflow.get_tracking_uri()}'")


tracking URI: 'http://127.0.0.1:5000'


In [27]:
def read_dataframe(filename):
    """Load a trip-data parquet file and prepare it for duration modelling.

    Parameters
    ----------
    filename : str or path-like
        Location handed straight to ``pd.read_parquet``.

    Returns
    -------
    pandas.DataFrame
        An independent frame holding only the trips whose ``duration``
        (drop-off minus pick-up, in minutes) lies between 1 and 60 inclusive,
        with ``PULocationID`` and ``DOLocationID`` converted to strings.
    """
    df = pd.read_parquet(filename)

    df["lpep_dropoff_datetime"] = pd.to_datetime(df["lpep_dropoff_datetime"])
    df["lpep_pickup_datetime"] = pd.to_datetime(df["lpep_pickup_datetime"])

    # Vectorised timedelta -> minutes conversion. It yields a float column even
    # for a frame with no rows, so the numeric filter below always applies.
    elapsed = df["lpep_dropoff_datetime"] - df["lpep_pickup_datetime"]
    df["duration"] = elapsed.dt.total_seconds() / 60

    # .copy() detaches the filtered rows from the original frame, so the column
    # assignment below does not write into a slice (SettingWithCopyWarning).
    df = df[(df["duration"] >= 1) & (df["duration"] <= 60)].copy()

    # Location IDs are identifiers, not quantities: store them as strings.
    categorical = ["PULocationID", "DOLocationID"]
    df[categorical] = df[categorical].astype(str)

    return df

In [28]:
# January 2023 trips are used for training, February 2023 trips for validation.
df_train, df_val = map(
    read_dataframe,
    (
        "../data/hw2/train/green_tripdata_2023-01.parquet",
        "../data/hw2/train/green_tripdata_2023-02.parquet",
    ),
)

In [29]:
# Number of rows kept in the training and validation frames.
df_train.shape[0], df_val.shape[0]

(65946, 62574)

In [30]:
# Combine pick-up and drop-off zone into a single route feature, e.g. "74_75",
# applying the same transformation to both splits.
for frame in (df_train, df_val):
    frame["PU_DO"] = frame["PULocationID"] + "_" + frame["DOLocationID"]

In [31]:
# Features fed to the models: the combined route category plus trip distance.
categorical = ["PU_DO"]
numerical = ["trip_distance"]
feature_columns = categorical + numerical

# One dict per row, which is the input format DictVectorizer expects.
train_dicts = df_train[feature_columns].to_dict(orient="records")
val_dicts = df_val[feature_columns].to_dict(orient="records")

# Fit the vectorizer on the training rows only, then reuse it for validation.
dv = DictVectorizer()
X_train = dv.fit_transform(train_dicts)
X_val = dv.transform(val_dicts)

In [32]:
# Regression target: the "duration" column of each split, as plain arrays.
target = "duration"
y_train, y_val = (frame[target].values for frame in (df_train, df_val))

In [36]:
# Baseline: ordinary least squares on the vectorized features.
lr = LinearRegression().fit(X_train, y_train)

# Score the baseline on the validation split; the RMSE is the cell's output.
y_pred = lr.predict(X_val)
root_mean_squared_error(y_val, y_pred)

6.03727552054262

In [38]:
# Persist the fitted vectorizer together with the model so that inference can
# reproduce exactly the same feature encoding.
model_bundle = (dv, lr)
with open("../models/hw2/lin_reg.bin", mode="wb") as f_out:
    pickle.dump(model_bundle, f_out)

In [40]:
# Track one Lasso experiment in MLflow: parameters, validation RMSE and the
# pickled (vectorizer, model) pair produced by this very run.
with mlflow.start_run():

    mlflow.set_tag("developer", "Joses")

    mlflow.log_param("train-data-path", "../data/hw2/train/green_tripdata_2023-01.parquet")
    mlflow.log_param("valid-data-path", "../data/hw2/train/green_tripdata_2023-02.parquet")

    alpha = 0.1
    mlflow.log_param("alpha", alpha)
    lr = Lasso(alpha=alpha)
    lr.fit(X_train, y_train)

    y_pred = lr.predict(X_val)
    rmse = root_mean_squared_error(y_val, y_pred)
    mlflow.log_metric("rmse", rmse)

    # Store the model trained in THIS run. Logging "lin_reg.bin" here would
    # attach the earlier LinearRegression baseline to a run whose parameters
    # and metric describe a Lasso model.
    lasso_model_path = "../models/hw2/lasso.bin"
    with open(lasso_model_path, "wb") as f_out:
        pickle.dump((dv, lr), f_out)

    mlflow.log_artifact(local_path=lasso_model_path, artifact_path="models_pickle")

In [42]:
import xgboost as xgb

In [43]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [44]:
# Wrap both splits in XGBoost's own data container, labels attached.
train, valid = (
    xgb.DMatrix(features, label=labels)
    for features, labels in ((X_train, y_train), (X_val, y_val))
)

In [47]:
def objective(params):
    """Hyperopt objective: train one XGBoost booster and report validation RMSE.

    Every call opens its own MLflow run, tags it as an xgboost model, and logs
    the given hyperparameters plus the resulting ``rmse`` metric. Training uses
    the notebook-level ``train`` / ``valid`` matrices, up to 1000 boosting
    rounds, and early stopping with a patience of 50 rounds on ``valid``.

    Parameters
    ----------
    params : dict
        Hyperparameters passed unchanged to ``xgb.train``.

    Returns
    -------
    dict
        ``{"loss": <validation RMSE>, "status": STATUS_OK}``.
    """
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)

        watchlist = [(valid, "validation")]
        model = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=watchlist,
            early_stopping_rounds=50,
        )

        val_rmse = root_mean_squared_error(y_val, model.predict(valid))
        mlflow.log_metric("rmse", val_rmse)

    return {"loss": val_rmse, "status": STATUS_OK}

In [48]:
# Hyperparameter search space sampled by hyperopt for the XGBoost model.
# hp.loguniform(label, low, high) draws exp(uniform(low, high)), so the bounds
# below are given in log space.
search_space = {
    # integer tree depth between 4 and 100
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    # exp(-3) .. exp(0), i.e. roughly 0.05 .. 1
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    # exp(-5) .. exp(-1), i.e. roughly 0.0067 .. 0.37
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    # exp(-6) .. exp(-1), i.e. roughly 0.0025 .. 0.37
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    # exp(-1) .. exp(3), i.e. roughly 0.37 .. 20
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    # 'reg:linear' is the deprecated alias of this squared-error objective
    'objective': 'reg:squarederror',
    'seed': 42
}

# Run 50 evaluations of `objective`, with TPE proposing each new candidate
# from the results of the earlier ones; a fresh Trials object records them.
best_result = fmin(
    objective,
    search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials(),
)

  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[0]	validation-rmse:8.71629                           
[1]	validation-rmse:8.18760                           
[2]	validation-rmse:7.73905                           
[3]	validation-rmse:7.34636                           
[4]	validation-rmse:7.02017                           
[5]	validation-rmse:6.74074                           
[6]	validation-rmse:6.49738                           
[7]	validation-rmse:6.29834                           
[8]	validation-rmse:6.13069                           
[9]	validation-rmse:5.98214                           
[10]	validation-rmse:5.87214                          
[11]	validation-rmse:5.76980                          
[12]	validation-rmse:5.69545                          
[13]	validation-rmse:5.62729                          
[14]	validation-rmse:5.57344                          
[15]	validation-rmse:5.52257                          
[16]	validation-rmse:5.48541                          
[17]	validation-rmse:5.45495                          
[18]	valid




[0]	validation-rmse:8.82028                                                    
[1]	validation-rmse:8.37392                                                    
[2]	validation-rmse:7.97776                                                    
[3]	validation-rmse:7.62776                                                    
[4]	validation-rmse:7.31912                                                    
[5]	validation-rmse:7.04860                                                    
[6]	validation-rmse:6.81212                                                    
[7]	validation-rmse:6.60563                                                    
[8]	validation-rmse:6.42547                                                    
[9]	validation-rmse:6.26973                                                    
[10]	validation-rmse:6.13487                                                   
[11]	validation-rmse:6.01847                                                   
[12]	validation-rmse:5.91803            




[0]	validation-rmse:8.36339                                                    
[1]	validation-rmse:7.60609                                                    
[2]	validation-rmse:7.01532                                                    
[3]	validation-rmse:6.56128                                                    
[4]	validation-rmse:6.21254                                                    
[5]	validation-rmse:5.95070                                                    
[6]	validation-rmse:5.75039                                                    
[7]	validation-rmse:5.60725                                                    
[8]	validation-rmse:5.49821                                                    
[9]	validation-rmse:5.41704                                                    
[10]	validation-rmse:5.35754                                                   
[11]	validation-rmse:5.31345                                                   
[12]	validation-rmse:5.27889            




[0]	validation-rmse:8.90452                                                    
[1]	validation-rmse:8.52744                                                    
[2]	validation-rmse:8.18418                                                    
[3]	validation-rmse:7.87405                                                    
[4]	validation-rmse:7.59401                                                    
[5]	validation-rmse:7.33774                                                    
[6]	validation-rmse:7.11347                                                    
[7]	validation-rmse:6.90589                                                    
[8]	validation-rmse:6.72571                                                    
[9]	validation-rmse:6.56452                                                    
[10]	validation-rmse:6.41428                                                   
[11]	validation-rmse:6.28666                                                   
[12]	validation-rmse:6.17462            




[7]	validation-rmse:5.47513                                                    
[8]	validation-rmse:5.46376                                                    
[9]	validation-rmse:5.46082                                                    
[10]	validation-rmse:5.45452                                                   
[11]	validation-rmse:5.45200                                                   
[12]	validation-rmse:5.44626                                                   
[13]	validation-rmse:5.44188                                                   
[14]	validation-rmse:5.43832                                                   
[15]	validation-rmse:5.43244                                                   
[16]	validation-rmse:5.42786                                                   
[17]	validation-rmse:5.42489                                                   
[18]	validation-rmse:5.42148                                                   
[19]	validation-rmse:5.41847            




[0]	validation-rmse:8.09827                                                    
[1]	validation-rmse:7.21978                                                    
[2]	validation-rmse:6.60581                                                    
[3]	validation-rmse:6.17273                                                    
[4]	validation-rmse:5.88309                                                    
[5]	validation-rmse:5.70402                                                    
[6]	validation-rmse:5.57249                                                    
[7]	validation-rmse:5.47768                                                    
[8]	validation-rmse:5.41869                                                    
[9]	validation-rmse:5.38040                                                    
[10]	validation-rmse:5.35203                                                   
[11]	validation-rmse:5.33267                                                   
[12]	validation-rmse:5.31219            




[0]	validation-rmse:6.48428                                                    
[1]	validation-rmse:5.62674                                                    
[2]	validation-rmse:5.43800                                                    
[3]	validation-rmse:5.38497                                                    
[4]	validation-rmse:5.34337                                                    
[5]	validation-rmse:5.33888                                                    
[6]	validation-rmse:5.33554                                                    
[7]	validation-rmse:5.33260                                                    
[8]	validation-rmse:5.33039                                                    
[9]	validation-rmse:5.32856                                                    
[10]	validation-rmse:5.32579                                                   
[11]	validation-rmse:5.32335                                                   
[12]	validation-rmse:5.32309            




[0]	validation-rmse:8.57759                                                    
[1]	validation-rmse:7.95380                                                    
[2]	validation-rmse:7.43494                                                    
[3]	validation-rmse:7.00760                                                    
[4]	validation-rmse:6.65535                                                    
[5]	validation-rmse:6.36898                                                    
[6]	validation-rmse:6.13741                                                    
[7]	validation-rmse:5.95071                                                    
[8]	validation-rmse:5.79842                                                    
[9]	validation-rmse:5.67677                                                    
[10]	validation-rmse:5.58008                                                   
[11]	validation-rmse:5.50211                                                   
[12]	validation-rmse:5.43834            




[0]	validation-rmse:8.59331                                                     
[1]	validation-rmse:7.97735                                                     
[2]	validation-rmse:7.46309                                                     
[3]	validation-rmse:7.03479                                                     
[4]	validation-rmse:6.68147                                                     
[5]	validation-rmse:6.39058                                                     
[6]	validation-rmse:6.15640                                                     
[7]	validation-rmse:5.96282                                                     
[8]	validation-rmse:5.81295                                                     
[9]	validation-rmse:5.68477                                                     
[10]	validation-rmse:5.58278                                                    
[11]	validation-rmse:5.50655                                                    
[12]	validation-rmse:5.44084




[0]	validation-rmse:8.74158                                                     
[1]	validation-rmse:8.23627                                                     
[2]	validation-rmse:7.79862                                                     
[3]	validation-rmse:7.42077                                                     
[4]	validation-rmse:7.09679                                                     
[5]	validation-rmse:6.81965                                                     
[6]	validation-rmse:6.58340                                                     
[7]	validation-rmse:6.38236                                                     
[8]	validation-rmse:6.21255                                                     
[9]	validation-rmse:6.06903                                                     
[10]	validation-rmse:5.94931                                                    
[11]	validation-rmse:5.84670                                                    
[12]	validation-rmse:5.76104




[0]	validation-rmse:8.34722                                                      
[1]	validation-rmse:7.58013                                                      
[2]	validation-rmse:6.98526                                                      
[3]	validation-rmse:6.52898                                                      
[4]	validation-rmse:6.18214                                                      
[5]	validation-rmse:5.92595                                                      
[6]	validation-rmse:5.72921                                                      
[7]	validation-rmse:5.59059                                                      
[8]	validation-rmse:5.48379                                                      
[9]	validation-rmse:5.41004                                                      
[10]	validation-rmse:5.35143                                                     
[11]	validation-rmse:5.30510                                                     
[12]	validation-




[0]	validation-rmse:8.54488                                                      
[1]	validation-rmse:7.90959                                                      
[2]	validation-rmse:7.38380                                                      
[3]	validation-rmse:6.94716                                                      
[4]	validation-rmse:6.60508                                                      
[5]	validation-rmse:6.33596                                                      
[6]	validation-rmse:6.11475                                                      
[7]	validation-rmse:5.93707                                                      
[8]	validation-rmse:5.80324                                                      
[9]	validation-rmse:5.70028                                                      
[10]	validation-rmse:5.61012                                                     
[11]	validation-rmse:5.54822                                                     
[12]	validation-




[0]	validation-rmse:6.46954                                                      
[1]	validation-rmse:5.66495                                                      
[2]	validation-rmse:5.45150                                                      
[3]	validation-rmse:5.36878                                                      
[4]	validation-rmse:5.33638                                                      
[5]	validation-rmse:5.31897                                                      
[6]	validation-rmse:5.30949                                                      
[7]	validation-rmse:5.29865                                                      
[8]	validation-rmse:5.28638                                                      
[9]	validation-rmse:5.28470                                                      
[10]	validation-rmse:5.27821                                                     
[11]	validation-rmse:5.27635                                                     
[12]	validation-




[0]	validation-rmse:8.95571                                                      
[1]	validation-rmse:8.61951                                                      
[2]	validation-rmse:8.31163                                                      
[3]	validation-rmse:8.03014                                                      
[4]	validation-rmse:7.77319                                                      
[5]	validation-rmse:7.53912                                                      
[6]	validation-rmse:7.32596                                                      
[7]	validation-rmse:7.13229                                                      
[8]	validation-rmse:6.95642                                                      
[9]	validation-rmse:6.79706                                                      
[10]	validation-rmse:6.65372                                                     
[11]	validation-rmse:6.52415                                                     
[12]	validation-




[0]	validation-rmse:8.97799                                                      
[1]	validation-rmse:8.66031                                                      
[2]	validation-rmse:8.36749                                                      
[3]	validation-rmse:8.09802                                                      
[4]	validation-rmse:7.85032                                                      
[5]	validation-rmse:7.62301                                                      
[6]	validation-rmse:7.41483                                                      
[7]	validation-rmse:7.22448                                                      
[8]	validation-rmse:7.05059                                                      
[9]	validation-rmse:6.89197                                                      
[10]	validation-rmse:6.74728                                                     
[11]	validation-rmse:6.61556                                                     
[12]	validation-




[0]	validation-rmse:8.68422                                                      
[1]	validation-rmse:8.13377                                                      
[2]	validation-rmse:7.66118                                                      
[3]	validation-rmse:7.25939                                                      
[4]	validation-rmse:6.91586                                                      
[5]	validation-rmse:6.62449                                                      
[6]	validation-rmse:6.38210                                                      
[7]	validation-rmse:6.17698                                                      
[8]	validation-rmse:6.00931                                                      
[9]	validation-rmse:5.86522                                                      
[10]	validation-rmse:5.74419                                                     
[11]	validation-rmse:5.64696                                                     
[12]	validation-




[0]	validation-rmse:8.19003                                                      
[1]	validation-rmse:7.34531                                                      
[2]	validation-rmse:6.72772                                                      
[3]	validation-rmse:6.28491                                                      
[4]	validation-rmse:5.97218                                                      
[5]	validation-rmse:5.75482                                                      
[6]	validation-rmse:5.60418                                                      
[7]	validation-rmse:5.49858                                                      
[8]	validation-rmse:5.42197                                                      
[9]	validation-rmse:5.36882                                                      
[10]	validation-rmse:5.33040                                                     
[11]	validation-rmse:5.30422                                                     
[12]	validation-




[0]	validation-rmse:8.85541                                                      
[1]	validation-rmse:8.43495                                                      
[2]	validation-rmse:8.05722                                                      
[3]	validation-rmse:7.71983                                                      
[4]	validation-rmse:7.41784                                                      
[5]	validation-rmse:7.15182                                                      
[6]	validation-rmse:6.91038                                                      
[7]	validation-rmse:6.70049                                                      
[8]	validation-rmse:6.51372                                                      
[9]	validation-rmse:6.35030                                                      
[10]	validation-rmse:6.20460                                                     
[11]	validation-rmse:6.07573                                                     
[12]	validation-




[11]	validation-rmse:5.50488                                                     
[12]	validation-rmse:5.50072                                                     
[13]	validation-rmse:5.49646                                                     
[14]	validation-rmse:5.49269                                                     
[15]	validation-rmse:5.48622                                                     
[16]	validation-rmse:5.48216                                                     
[17]	validation-rmse:5.48065                                                     
[18]	validation-rmse:5.47605                                                     
[19]	validation-rmse:5.47092                                                     
[20]	validation-rmse:5.46623                                                     
[21]	validation-rmse:5.46316                                                     
[22]	validation-rmse:5.46070                                                     
[23]	validation-




[0]	validation-rmse:7.56317                                                   
[1]	validation-rmse:6.54187                                                   
[2]	validation-rmse:5.97211                                                   
[3]	validation-rmse:5.66024                                                   
[4]	validation-rmse:5.49291                                                   
[5]	validation-rmse:5.39644                                                   
[6]	validation-rmse:5.34037                                                   
[7]	validation-rmse:5.30676                                                   
[8]	validation-rmse:5.29139                                                   
[9]	validation-rmse:5.27051                                                   
[10]	validation-rmse:5.26168                                                  
[11]	validation-rmse:5.25721                                                  
[12]	validation-rmse:5.25573                        




[10]	validation-rmse:5.49214                                                  
[11]	validation-rmse:5.49107                                                  
[12]	validation-rmse:5.48279                                                  
[13]	validation-rmse:5.46428                                                  
[14]	validation-rmse:5.45620                                                  
[15]	validation-rmse:5.45249                                                  
[16]	validation-rmse:5.44907                                                  
[17]	validation-rmse:5.44240                                                  
[18]	validation-rmse:5.43613                                                  
[19]	validation-rmse:5.43381                                                  
[20]	validation-rmse:5.42953                                                  
[21]	validation-rmse:5.38629                                                  
[22]	validation-rmse:5.38273                        




[0]	validation-rmse:5.41687                                                   
[1]	validation-rmse:5.25648                                                   
[2]	validation-rmse:5.23781                                                   
[3]	validation-rmse:5.23508                                                   
[4]	validation-rmse:5.23177                                                   
[5]	validation-rmse:5.22828                                                   
[6]	validation-rmse:5.20911                                                   
[7]	validation-rmse:5.20407                                                   
[8]	validation-rmse:5.20202                                                   
[9]	validation-rmse:5.20542                                                   
[10]	validation-rmse:5.20193                                                  
[11]	validation-rmse:5.20295                                                  
[12]	validation-rmse:5.20013                        




[0]	validation-rmse:7.31615                                                   
[1]	validation-rmse:6.27257                                                   
[2]	validation-rmse:5.76545                                                   
[3]	validation-rmse:5.51706                                                   
[4]	validation-rmse:5.39902                                                   
[5]	validation-rmse:5.33321                                                   
[6]	validation-rmse:5.29827                                                   
[7]	validation-rmse:5.27702                                                   
[8]	validation-rmse:5.26018                                                   
[9]	validation-rmse:5.25095                                                   
[10]	validation-rmse:5.24409                                                  
[11]	validation-rmse:5.24163                                                  
[12]	validation-rmse:5.23745                        




[1]	validation-rmse:6.87257                                                   
[2]	validation-rmse:6.27741                                                   
[3]	validation-rmse:5.91116                                                   
[4]	validation-rmse:5.69499                                                   
[5]	validation-rmse:5.56217                                                   
[6]	validation-rmse:5.47811                                                   
[7]	validation-rmse:5.42668                                                   
[8]	validation-rmse:5.39083                                                   
[9]	validation-rmse:5.36656                                                   
[10]	validation-rmse:5.35141                                                  
[11]	validation-rmse:5.34042                                                  
[12]	validation-rmse:5.32455                                                  
[13]	validation-rmse:5.31965                        




[0]	validation-rmse:5.79193                                                   
[1]	validation-rmse:5.34089                                                   
[2]	validation-rmse:5.26944                                                   
[3]	validation-rmse:5.25088                                                   
[4]	validation-rmse:5.23695                                                   
[5]	validation-rmse:5.22931                                                   
[6]	validation-rmse:5.22342                                                   
[7]	validation-rmse:5.21956                                                   
[8]	validation-rmse:5.21132                                                   
[9]	validation-rmse:5.20898                                                   
[10]	validation-rmse:5.19779                                                  
[11]	validation-rmse:5.19641                                                  
[12]	validation-rmse:5.19500                        




[1]	validation-rmse:5.94586                                                   
[2]	validation-rmse:5.62153                                                   
[3]	validation-rmse:5.50466                                                   
[4]	validation-rmse:5.44379                                                   
[5]	validation-rmse:5.41876                                                   
[6]	validation-rmse:5.40249                                                   
[7]	validation-rmse:5.38325                                                   
[8]	validation-rmse:5.37266                                                   
[9]	validation-rmse:5.36229                                                   
[10]	validation-rmse:5.35528                                                  
[11]	validation-rmse:5.35242                                                  
[12]	validation-rmse:5.34695                                                  
[13]	validation-rmse:5.34534                        




[1]	validation-rmse:6.75789                                                   
[2]	validation-rmse:6.17093                                                   
[3]	validation-rmse:5.82290                                                   
[4]	validation-rmse:5.62645                                                   
[5]	validation-rmse:5.50790                                                   
[6]	validation-rmse:5.43661                                                   
[7]	validation-rmse:5.38780                                                   
[8]	validation-rmse:5.36005                                                   
[9]	validation-rmse:5.34098                                                   
[10]	validation-rmse:5.32366                                                  
[11]	validation-rmse:5.31651                                                  
[12]	validation-rmse:5.31034                                                  
[13]	validation-rmse:5.30565                        




[0]	validation-rmse:9.01801                                                   
[1]	validation-rmse:8.73424                                                   
[2]	validation-rmse:8.47001                                                   
[3]	validation-rmse:8.22407                                                   
[4]	validation-rmse:7.99574                                                   
[5]	validation-rmse:7.78384                                                   
[6]	validation-rmse:7.58754                                                   
[7]	validation-rmse:7.40574                                                   
[8]	validation-rmse:7.23787                                                   
[9]	validation-rmse:7.08261                                                   
[10]	validation-rmse:6.93933                                                  
[11]	validation-rmse:6.80705                                                  
[12]	validation-rmse:6.68568                        




[0]	validation-rmse:7.91783                                                   
[1]	validation-rmse:6.97738                                                   
[2]	validation-rmse:6.36130                                                   
[3]	validation-rmse:5.98782                                                   
[4]	validation-rmse:5.72772                                                   
[5]	validation-rmse:5.58204                                                   
[6]	validation-rmse:5.49037                                                   
[7]	validation-rmse:5.43051                                                   
[8]	validation-rmse:5.39220                                                   
[9]	validation-rmse:5.35808                                                   
[10]	validation-rmse:5.34177                                                  
[11]	validation-rmse:5.33102                                                  
[12]	validation-rmse:5.32066                        




[1]	validation-rmse:5.45920                                                   
[2]	validation-rmse:5.35833                                                   
[3]	validation-rmse:5.33260                                                   
[4]	validation-rmse:5.31982                                                   
[5]	validation-rmse:5.30921                                                   
[6]	validation-rmse:5.30463                                                   
[7]	validation-rmse:5.29988                                                   
[8]	validation-rmse:5.29249                                                   
[9]	validation-rmse:5.28549                                                   
[10]	validation-rmse:5.27900                                                  
[11]	validation-rmse:5.27510                                                  
[12]	validation-rmse:5.27047                                                  
[13]	validation-rmse:5.26740                        




[0]	validation-rmse:8.34717                                                   
[1]	validation-rmse:7.58958                                                   
[2]	validation-rmse:7.00990                                                   
[3]	validation-rmse:6.57151                                                   
[4]	validation-rmse:6.24422                                                   
[5]	validation-rmse:6.00145                                                   
[6]	validation-rmse:5.82187                                                   
[7]	validation-rmse:5.68927                                                   
[8]	validation-rmse:5.59477                                                   
[9]	validation-rmse:5.52086                                                   
[10]	validation-rmse:5.46879                                                  
[11]	validation-rmse:5.42926                                                  
[12]	validation-rmse:5.39845                        




[0]	validation-rmse:7.52763                                                   
[1]	validation-rmse:6.52547                                                   
[2]	validation-rmse:5.99428                                                   
[3]	validation-rmse:5.72479                                                   
[4]	validation-rmse:5.58783                                                   
[5]	validation-rmse:5.51745                                                   
[6]	validation-rmse:5.46511                                                   
[7]	validation-rmse:5.43733                                                   
[8]	validation-rmse:5.42136                                                   
[9]	validation-rmse:5.40732                                                   
[10]	validation-rmse:5.39616                                                  
[11]	validation-rmse:5.38349                                                  
[12]	validation-rmse:5.36977                        




[0]	validation-rmse:8.06441                                                   
[1]	validation-rmse:7.16003                                                   
[2]	validation-rmse:6.52761                                                   
[3]	validation-rmse:6.09353                                                   
[4]	validation-rmse:5.79651                                                   
[5]	validation-rmse:5.59935                                                   
[6]	validation-rmse:5.46589                                                   
[7]	validation-rmse:5.37709                                                   
[8]	validation-rmse:5.31443                                                   
[9]	validation-rmse:5.27414                                                   
[10]	validation-rmse:5.24405                                                  
[11]	validation-rmse:5.22540                                                  
[12]	validation-rmse:5.21202                        




[0]	validation-rmse:8.75548                                                   
[1]	validation-rmse:8.25784                                                   
[2]	validation-rmse:7.82056                                                   
[3]	validation-rmse:7.43817                                                   
[4]	validation-rmse:7.10485                                                   
[5]	validation-rmse:6.81638                                                   
[6]	validation-rmse:6.56796                                                   
[7]	validation-rmse:6.35572                                                   
[8]	validation-rmse:6.17309                                                   
[9]	validation-rmse:6.01801                                                   
[10]	validation-rmse:5.88599                                                  
[11]	validation-rmse:5.77352                                                  
[12]	validation-rmse:5.67930                        




[10]	validation-rmse:5.81387                                                  
[11]	validation-rmse:5.75770                                                  
[12]	validation-rmse:5.71431                                                  
[13]	validation-rmse:5.67863                                                  
[14]	validation-rmse:5.64805                                                  
[15]	validation-rmse:5.62548                                                  
[16]	validation-rmse:5.61209                                                  
[17]	validation-rmse:5.59454                                                  
[18]	validation-rmse:5.58148                                                  
[19]	validation-rmse:5.57404                                                  
[20]	validation-rmse:5.56789                                                  
[21]	validation-rmse:5.56261                                                  
[22]	validation-rmse:5.55367                        




[0]	validation-rmse:6.94041                                                   
[1]	validation-rmse:5.91012                                                   
[2]	validation-rmse:5.50544                                                   
[3]	validation-rmse:5.34553                                                   
[4]	validation-rmse:5.27710                                                   
[5]	validation-rmse:5.24610                                                   
[6]	validation-rmse:5.22693                                                   
[7]	validation-rmse:5.21651                                                   
[8]	validation-rmse:5.20520                                                   
[9]	validation-rmse:5.20081                                                   
[10]	validation-rmse:5.19724                                                  
[11]	validation-rmse:5.19282                                                  
[12]	validation-rmse:5.18720                        




[2]	validation-rmse:6.91030                                                   
[3]	validation-rmse:6.48008                                                   
[4]	validation-rmse:6.17316                                                   
[5]	validation-rmse:5.95199                                                   
[6]	validation-rmse:5.79628                                                   
[7]	validation-rmse:5.68410                                                   
[8]	validation-rmse:5.60350                                                   
[9]	validation-rmse:5.54533                                                   
[10]	validation-rmse:5.49940                                                  
[11]	validation-rmse:5.46792                                                  
[12]	validation-rmse:5.44416                                                  
[13]	validation-rmse:5.42495                                                  
[14]	validation-rmse:5.40831                        




[0]	validation-rmse:5.90923                                                   
[1]	validation-rmse:5.32789                                                   
[2]	validation-rmse:5.23142                                                   
[3]	validation-rmse:5.20829                                                   
[4]	validation-rmse:5.19516                                                   
[5]	validation-rmse:5.18953                                                   
[6]	validation-rmse:5.18791                                                   
[7]	validation-rmse:5.18426                                                   
[8]	validation-rmse:5.18047                                                   
[9]	validation-rmse:5.18048                                                   
[10]	validation-rmse:5.18133                                                  
[11]	validation-rmse:5.18039                                                  
[12]	validation-rmse:5.17800                        




[0]	validation-rmse:6.48944                                                   
[1]	validation-rmse:5.56654                                                   
[2]	validation-rmse:5.30555                                                   
[3]	validation-rmse:5.22622                                                   
[4]	validation-rmse:5.20018                                                   
[5]	validation-rmse:5.19055                                                   
[6]	validation-rmse:5.18074                                                   
[7]	validation-rmse:5.17932                                                   
[8]	validation-rmse:5.17603                                                   
[9]	validation-rmse:5.17681                                                   
[10]	validation-rmse:5.17511                                                  
[11]	validation-rmse:5.17508                                                  
[12]	validation-rmse:5.17263                        




[0]	validation-rmse:8.67025                                                   
[1]	validation-rmse:8.11471                                                   
[2]	validation-rmse:7.64243                                                   
[3]	validation-rmse:7.24687                                                   
[4]	validation-rmse:6.91233                                                   
[5]	validation-rmse:6.63720                                                   
[6]	validation-rmse:6.40455                                                   
[7]	validation-rmse:6.21297                                                   
[8]	validation-rmse:6.05036                                                   
[9]	validation-rmse:5.91625                                                   
[10]	validation-rmse:5.80944                                                  
[11]	validation-rmse:5.71869                                                  
[12]	validation-rmse:5.64740                        




[0]	validation-rmse:7.32556                                                   
[1]	validation-rmse:6.31499                                                   
[2]	validation-rmse:5.83566                                                   
[3]	validation-rmse:5.61529                                                   
[4]	validation-rmse:5.50951                                                   
[5]	validation-rmse:5.45225                                                   
[6]	validation-rmse:5.41758                                                   
[7]	validation-rmse:5.39584                                                   
[8]	validation-rmse:5.37728                                                   
[9]	validation-rmse:5.36060                                                   
[10]	validation-rmse:5.35271                                                  
[11]	validation-rmse:5.34439                                                  
[12]	validation-rmse:5.33539                        




[0]	validation-rmse:8.46905                                                   
[1]	validation-rmse:7.78064                                                   
[2]	validation-rmse:7.22853                                                   
[3]	validation-rmse:6.79316                                                   
[4]	validation-rmse:6.44930                                                   
[5]	validation-rmse:6.18178                                                   
[6]	validation-rmse:5.97529                                                   
[7]	validation-rmse:5.81635                                                   
[8]	validation-rmse:5.69375                                                   
[9]	validation-rmse:5.59845                                                   
[10]	validation-rmse:5.52477                                                  
[11]	validation-rmse:5.46615                                                  
[12]	validation-rmse:5.42230                        




[0]	validation-rmse:8.83291                                                   
[1]	validation-rmse:8.39879                                                   
[2]	validation-rmse:8.01186                                                   
[3]	validation-rmse:7.66949                                                   
[4]	validation-rmse:7.36056                                                   
[5]	validation-rmse:7.09105                                                   
[6]	validation-rmse:6.85411                                                   
[7]	validation-rmse:6.64754                                                   
[8]	validation-rmse:6.46914                                                   
[9]	validation-rmse:6.31573                                                   
[10]	validation-rmse:6.17458                                                  
[11]	validation-rmse:6.05506                                                  
[12]	validation-rmse:5.95544                        




[0]	validation-rmse:8.54553                                                   
[1]	validation-rmse:7.89936                                                   
[2]	validation-rmse:7.36494                                                   
[3]	validation-rmse:6.92792                                                   
[4]	validation-rmse:6.57274                                                   
[5]	validation-rmse:6.28788                                                   
[6]	validation-rmse:6.05955                                                   
[7]	validation-rmse:5.87610                                                   
[8]	validation-rmse:5.72909                                                   
[9]	validation-rmse:5.61594                                                   
[10]	validation-rmse:5.52422                                                  
[11]	validation-rmse:5.45304                                                  
[12]	validation-rmse:5.39687                        




[0]	validation-rmse:6.21626                                                   
[1]	validation-rmse:5.45125                                                   
[2]	validation-rmse:5.28290                                                   
[3]	validation-rmse:5.22481                                                   
[4]	validation-rmse:5.20626                                                   
[5]	validation-rmse:5.19216                                                   
[6]	validation-rmse:5.18877                                                   
[7]	validation-rmse:5.18567                                                   
[8]	validation-rmse:5.18109                                                   
[9]	validation-rmse:5.17723                                                   
[10]	validation-rmse:5.17650                                                  
[11]	validation-rmse:5.16591                                                  
[12]	validation-rmse:5.16426                        




[0]	validation-rmse:8.90662                                                   
[1]	validation-rmse:8.52788                                                   
[2]	validation-rmse:8.18271                                                   
[3]	validation-rmse:7.86872                                                   
[4]	validation-rmse:7.58560                                                   
[5]	validation-rmse:7.32891                                                   
[6]	validation-rmse:7.09803                                                   
[7]	validation-rmse:6.88896                                                   
[8]	validation-rmse:6.70333                                                   
[9]	validation-rmse:6.53443                                                   
[10]	validation-rmse:6.38473                                                  
[11]	validation-rmse:6.25030                                                  
[12]	validation-rmse:6.13296                        




[6]	validation-rmse:6.78382                                                   
[7]	validation-rmse:6.58823                                                   
[8]	validation-rmse:6.42122                                                   
[9]	validation-rmse:6.27773                                                   
[10]	validation-rmse:6.15752                                                  
[11]	validation-rmse:6.05429                                                  
[12]	validation-rmse:5.96642                                                  
[13]	validation-rmse:5.89088                                                  
[14]	validation-rmse:5.82589                                                  
[15]	validation-rmse:5.77155                                                  
[16]	validation-rmse:5.72438                                                  
[17]	validation-rmse:5.68510                                                  
[18]	validation-rmse:5.65050                        




[0]	validation-rmse:5.44995                                                   
[1]	validation-rmse:5.25682                                                   
[2]	validation-rmse:5.24228                                                   
[3]	validation-rmse:5.23263                                                   
[4]	validation-rmse:5.22506                                                   
[5]	validation-rmse:5.21758                                                   
[6]	validation-rmse:5.20916                                                   
[7]	validation-rmse:5.20318                                                   
[8]	validation-rmse:5.20246                                                   
[9]	validation-rmse:5.20151                                                   
[10]	validation-rmse:5.20326                                                  
[11]	validation-rmse:5.20204                                                  
[12]	validation-rmse:5.19517                        




[0]	validation-rmse:8.04621                                                   
[1]	validation-rmse:7.14456                                                   
[2]	validation-rmse:6.53462                                                   
[3]	validation-rmse:6.09961                                                   
[4]	validation-rmse:5.82510                                                   
[5]	validation-rmse:5.65235                                                   
[6]	validation-rmse:5.52242                                                   
[7]	validation-rmse:5.44645                                                   
[8]	validation-rmse:5.39664                                                   
[9]	validation-rmse:5.36181                                                   
[10]	validation-rmse:5.33960                                                  
[11]	validation-rmse:5.32296                                                  
[12]	validation-rmse:5.30336                        




[0]	validation-rmse:9.00191                                                   
[1]	validation-rmse:8.70423                                                   
[2]	validation-rmse:8.42800                                                   
[3]	validation-rmse:8.17242                                                   
[4]	validation-rmse:7.93595                                                   
[5]	validation-rmse:7.71875                                                   
[6]	validation-rmse:7.51691                                                   
[7]	validation-rmse:7.33254                                                   
[8]	validation-rmse:7.16106                                                   
[9]	validation-rmse:7.00532                                                   
[10]	validation-rmse:6.86013                                                  
[11]	validation-rmse:6.72599                                                  
[12]	validation-rmse:6.60604                        

In [50]:
# Display the hyperparameter set held in `best_result`
# (presumably the outcome of the tuning search run in an earlier cell - confirm).
best_result

{'learning_rate': 0.3931914926020957,
 'max_depth': 5.0,
 'min_child_weight': 4.061140110039656,
 'reg_alpha': 0.0812130837749679,
 'reg_lambda': 0.010221480465528184}

In [49]:
# Switch MLflow's XGBoost autologging off (disable=True).
# NOTE(review): the following cell logs params, metric, preprocessor and model
# by hand, which is presumably why autologging is disabled here - confirm.
mlflow.xgboost.autolog(disable=True)

In [53]:
# Retrain XGBoost with the selected hyperparameters inside a single MLflow run
# and log everything by hand: params, validation RMSE, the fitted
# DictVectorizer (preprocessor) and the booster itself.
with mlflow.start_run():

    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    best_params = {
        "learning_rate": 0.39319149260209574,
        # The search result reports max_depth as a float (5.0); XGBoost wants an int.
        "max_depth": 5,
        "min_child_weight": 4.061140110039656,
        # "reg:linear" is the deprecated alias of squared-error regression;
        # use the supported name so the run keeps working on current XGBoost.
        "objective": "reg:squarederror",
        "reg_alpha": 0.0812130837749679,
        "reg_lambda": 0.010221480465528184,
        "seed": 42
    }

    mlflow.log_params(best_params)

    booster = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, "validation")],
        early_stopping_rounds=50
    )

    y_pred = booster.predict(valid)
    rmse = root_mean_squared_error(y_val, y_pred)
    mlflow.log_metric("rmse", rmse)

    # Persist the vectorizer next to the model so inference can rebuild the
    # exact feature space. Create the target directory first: on a fresh
    # checkout it may not exist, and open() would fail only after training.
    preprocessor_path = "../models/hw2/preprocessor.b"
    os.makedirs(os.path.dirname(preprocessor_path), exist_ok=True)
    with open(preprocessor_path, "wb") as f_out:
        pickle.dump(dv, f_out)
    mlflow.log_artifact(preprocessor_path, artifact_path="preprocessor")

    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")

[0]	validation-rmse:7.25975
[1]	validation-rmse:6.30296
[2]	validation-rmse:5.88128
[3]	validation-rmse:5.69972
[4]	validation-rmse:5.61293
[5]	validation-rmse:5.57003
[6]	validation-rmse:5.55271
[7]	validation-rmse:5.53824
[8]	validation-rmse:5.52051
[9]	validation-rmse:5.51369
[10]	validation-rmse:5.50921
[11]	validation-rmse:5.50488
[12]	validation-rmse:5.50072
[13]	validation-rmse:5.49646




[14]	validation-rmse:5.49269
[15]	validation-rmse:5.48622
[16]	validation-rmse:5.48216
[17]	validation-rmse:5.48065
[18]	validation-rmse:5.47605
[19]	validation-rmse:5.47092
[20]	validation-rmse:5.46623
[21]	validation-rmse:5.46316
[22]	validation-rmse:5.46070
[23]	validation-rmse:5.45559
[24]	validation-rmse:5.45100
[25]	validation-rmse:5.44850
[26]	validation-rmse:5.44589
[27]	validation-rmse:5.44262
[28]	validation-rmse:5.43623
[29]	validation-rmse:5.43418
[30]	validation-rmse:5.43235
[31]	validation-rmse:5.42945
[32]	validation-rmse:5.42746
[33]	validation-rmse:5.42397
[34]	validation-rmse:5.42134
[35]	validation-rmse:5.41909
[36]	validation-rmse:5.41600
[37]	validation-rmse:5.41364
[38]	validation-rmse:5.41261
[39]	validation-rmse:5.41059
[40]	validation-rmse:5.40892
[41]	validation-rmse:5.40653
[42]	validation-rmse:5.40553
[43]	validation-rmse:5.40336
[44]	validation-rmse:5.40022
[45]	validation-rmse:5.39516
[46]	validation-rmse:5.39284
[47]	validation-rmse:5.39178
[48]	validatio



In [56]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.svm import LinearSVR

# Enable MLflow autologging for scikit-learn estimators.
mlflow.sklearn.autolog()

# Values that do not change between iterations, hoisted out of the loop.
train_data_path = "../data/hw2/train/green_tripdata_2023-01.parquet"
valid_data_path = "../data/hw2/train/green_tripdata_2023-02.parquet"
preprocessor_path = "../models/hw2/preprocessor.b"
# Same seed as the XGBoost run; without it the logged RMSEs are not
# reproducible from one execution of this cell to the next.
random_state = 42

# Fit each candidate regressor in its own MLflow run and record its
# validation RMSE alongside the data paths and the shared preprocessor.
for model_class in (RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, LinearSVR):

    with mlflow.start_run():

        mlflow.log_param("train-data-path", train_data_path)
        mlflow.log_param("valid-data-path", valid_data_path)
        # Relies on the preprocessor file already having been written to disk.
        mlflow.log_artifact(preprocessor_path, artifact_path="preprocessor")

        # All four estimator classes accept `random_state`.
        mlmodel = model_class(random_state=random_state)
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        rmse = root_mean_squared_error(y_val, y_pred)
        mlflow.log_metric("rmse", rmse)
        

