In [1]:
!python -V

Python 3.9.19


In [2]:
import pandas as pd
import pickle
import seaborn as sns
import matplotlib.pyplot as plt

import sklearn
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error

print(f"sklearn version: {sklearn.__version__}")

import warnings
warnings.filterwarnings("ignore")

sklearn version: 1.4.2


In [3]:
import mlflow

mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_experiment("nyc-taxi-experiment")

<Experiment: artifact_location='/Users/gan-m2/MLOps/Module 2/mlruns/1', creation_time=1716865717048, experiment_id='1', last_update_time=1716865717048, lifecycle_stage='active', name='nyc-taxi-experiment', tags={}>

In [4]:
def read_dataframe(filename):
    df = pd.read_csv(filename)

    df.lpep_dropoff_datetime = pd.to_datetime(df.lpep_dropoff_datetime)
    df.lpep_pickup_datetime = pd.to_datetime(df.lpep_pickup_datetime)

    df['duration'] = df.lpep_dropoff_datetime - df.lpep_pickup_datetime
    df.duration = df.duration.apply(lambda td: td.total_seconds() / 60)

    df = df[(df.duration >= 1) & (df.duration <= 60)]

    categorical = ['PULocationID', 'DOLocationID']
    df[categorical] = df[categorical].astype(str)
    
    return df

In [5]:
df_train = read_dataframe('./data/green_tripdata_2021-01.csv')
df_val = read_dataframe('./data/green_tripdata_2021-02.csv')

In [6]:
len(df_train), len(df_val)

(73908, 61921)

In [7]:
df_train['PU_DO'] = df_train['PULocationID'] + '_' + df_train['DOLocationID']
df_val['PU_DO'] = df_val['PULocationID'] + '_' + df_val['DOLocationID']

In [8]:
categorical = ['PU_DO'] #'PULocationID', 'DOLocationID']
numerical = ['trip_distance']

dv = DictVectorizer()

train_dicts = df_train[categorical + numerical].to_dict(orient='records')
X_train = dv.fit_transform(train_dicts)

val_dicts = df_val[categorical + numerical].to_dict(orient='records')
X_val = dv.transform(val_dicts)

In [9]:
target = 'duration'
y_train = df_train[target].values
y_val = df_val[target].values

In [10]:
lr = LinearRegression()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_val)

mean_squared_error(y_val, y_pred, squared=False)

7.758715211934021

In [15]:
with open('models/lin_reg.bin', 'wb') as f_out:
    pickle.dump((dv, lr), f_out)

In [17]:
# 开启一个新的 MLflow 运行，在 with 代码块内的所有 MLflow 相关操作都会记录到这个运行中
with mlflow.start_run():

    # 标签 developer，值为 cristian
    mlflow.set_tag("developer", "cristian")

    # 记录训练数据和验证数据的路径。
    mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
    mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")

    # 记录超参数 alpha 的值。这里 alpha 是用于 Lasso 回归模型的正则化强度。
    alpha = 0.2
    mlflow.log_param("alpha", alpha)
    
    # 建模
    lr = Lasso(alpha)
    lr.fit(X_train, y_train)

    # 使用验证数据 X_val 进行预测，并计算预测结果的均方根误差（RMSE）。然后将 RMSE 作为指标记录下来
    y_pred = lr.predict(X_val)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)
    # 将本地路径为 models/lin_reg.bin 的模型文件记录为一个工件，存储到 models_pickle 目录中。工件可以是模型文件、数据文件等，便于日后下载和使用。

    mlflow.log_artifact(local_path="models/lin_reg.bin", artifact_path="models_pickle")



### 单个xgboost模型训练

In [18]:
import xgboost as xgb

In [19]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [22]:
train = xgb.DMatrix(X_train, label=y_train)
valid = xgb.DMatrix(X_val, label=y_val)

In [23]:
# param: 在超参数优化过程中，Hyperopt 会从这个搜索空间中采样参数组合，并将其传递给 objective 函数中的 params 参数。
# 因此，params 是从搜索空间中采样得到的一组具体参数，用于配置和训练 XGBoost 模型。
# 也即是下面的 search_space 
def objective(params):
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")
        mlflow.log_params(params)
        booster = xgb.train(
            
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        y_pred = booster.predict(valid)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)

    return {'loss': rmse, 'status': STATUS_OK}

### Hyper tuning and store logs in MLflow (it should run 2 hours + to complete)

In [24]:
# max_depth: 树的最大深度，取值范围为 4 到 100 之间的整数。
# learning_rate: 学习率，取值范围为 0.001 到 1 之间的对数均匀分布。
# reg_alpha: L1 正则化项，取值范围为 0.00001 到 0.1 之间的对数均匀分布。
# reg_lambda: L2 正则化项，取值范围为 0.000001 到 0.1 之间的对数均匀分布。
# min_child_weight: 最小子节点权重，取值范围为 0.1 到 20 之间的对数均匀分布。
# objective: 目标函数，这里设置为回归任务。
# seed: 随机种子，用于保证结果的可重复性。

search_space = {
    'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
    'learning_rate': hp.loguniform('learning_rate', -3, 0),
    'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
    'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
    'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
    'objective': 'reg:linear',
    'seed': 42
}

# fmin: 优化函数
# fn: 目标函数，也就是我们在上面定义的 function
# space: 参数空间，也就是我们要搜索的超参数范围，在上面定义的 search_space
# algo: 优化算法，这里使用 TPE 算法，也就是树形 Parzen Estimator
# max_evals: 最大迭代次数
# trials: 用于记录每次迭代的结果
best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

[0]	validation-rmse:7.18551                           
[1]	validation-rmse:6.74890                           
  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]




[2]	validation-rmse:6.68660                           
[3]	validation-rmse:6.67631                           
[4]	validation-rmse:6.66562                           
[5]	validation-rmse:6.66060                           
[6]	validation-rmse:6.65793                           
[7]	validation-rmse:6.65051                           
[8]	validation-rmse:6.63889                           
[9]	validation-rmse:6.63304                           
[10]	validation-rmse:6.62931                          
[11]	validation-rmse:6.62501                          
[12]	validation-rmse:6.61594                          
[13]	validation-rmse:6.60854                          
[14]	validation-rmse:6.60432                          
[15]	validation-rmse:6.60278                          
[16]	validation-rmse:6.59901                          
[17]	validation-rmse:6.59760                          
[18]	validation-rmse:6.59351                          
[19]	validation-rmse:6.58840                          
[20]	valid





[0]	validation-rmse:8.19682                                                    
[1]	validation-rmse:7.05841                                                    
[2]	validation-rmse:6.74699                                                    
[3]	validation-rmse:6.64427                                                    
[4]	validation-rmse:6.60503                                                    
[5]	validation-rmse:6.59068                                                    
[6]	validation-rmse:6.58134                                                    
[7]	validation-rmse:6.57321                                                    
[8]	validation-rmse:6.56684                                                    
[9]	validation-rmse:6.56086                                                    
[10]	validation-rmse:6.55716                                                   
[11]	validation-rmse:6.55255                                                   
[12]	validation-rmse:6.54896            





[0]	validation-rmse:11.30733                                                   
[1]	validation-rmse:10.53182                                                   
[2]	validation-rmse:9.87140                                                    
[3]	validation-rmse:9.31232                                                    
[4]	validation-rmse:8.83985                                                    
[5]	validation-rmse:8.44362                                                    
[6]	validation-rmse:8.11235                                                    
[7]	validation-rmse:7.83616                                                    
[8]	validation-rmse:7.60726                                                    
[9]	validation-rmse:7.41704                                                    
[10]	validation-rmse:7.26004                                                   
[11]	validation-rmse:7.13092                                                   
[12]	validation-rmse:7.02404            





[0]	validation-rmse:11.47406                                                   
[1]	validation-rmse:10.82118                                                   
[2]	validation-rmse:10.24628                                                   
[3]	validation-rmse:9.74228                                                    
[4]	validation-rmse:9.30200                                                    
[5]	validation-rmse:8.91784                                                    
[6]	validation-rmse:8.58424                                                    
[7]	validation-rmse:8.29613                                                    
[8]	validation-rmse:8.04614                                                    
[9]	validation-rmse:7.83150                                                    
[10]	validation-rmse:7.64578                                                   
[11]	validation-rmse:7.48640                                                   
[12]	validation-rmse:7.34990            





[0]	validation-rmse:11.38864                                                   
[1]	validation-rmse:10.67438                                                   
[2]	validation-rmse:10.05792                                                   
[3]	validation-rmse:9.52814                                                    
[4]	validation-rmse:9.07569                                                    
[5]	validation-rmse:8.68951                                                    
[6]	validation-rmse:8.36181                                                    
[7]	validation-rmse:8.08443                                                    
[8]	validation-rmse:7.85128                                                    
[9]	validation-rmse:7.65420                                                    
[10]	validation-rmse:7.48824                                                   
[11]	validation-rmse:7.34933                                                   
[12]	validation-rmse:7.23194            





[1]	validation-rmse:9.76581                                                    
[2]	validation-rmse:8.95332                                                    
[3]	validation-rmse:8.34289                                                    
[4]	validation-rmse:7.88484                                                    
[5]	validation-rmse:7.54799                                                    
[6]	validation-rmse:7.29898                                                    
[7]	validation-rmse:7.11590                                                    
[8]	validation-rmse:6.98206                                                    
[9]	validation-rmse:6.88202                                                    
[10]	validation-rmse:6.80910                                                   
[11]	validation-rmse:6.75354                                                   
[12]	validation-rmse:6.70988                                                   
[13]	validation-rmse:6.67697            





[0]	validation-rmse:8.45319                                                    
[1]	validation-rmse:7.16377                                                    
[2]	validation-rmse:6.75214                                                    
[3]	validation-rmse:6.61077                                                    
[4]	validation-rmse:6.54526                                                    
[5]	validation-rmse:6.51882                                                    
[6]	validation-rmse:6.50270                                                    
[7]	validation-rmse:6.49029                                                    
[8]	validation-rmse:6.48611                                                    
[9]	validation-rmse:6.48282                                                    
[10]	validation-rmse:6.47647                                                   
[11]	validation-rmse:6.47187                                                   
[12]	validation-rmse:6.46650            





[0]	validation-rmse:11.34240                                                   
[1]	validation-rmse:10.59431                                                   
[2]	validation-rmse:9.95561                                                    
[3]	validation-rmse:9.41168                                                    
[4]	validation-rmse:8.95009                                                    
[5]	validation-rmse:8.56256                                                    
[6]	validation-rmse:8.23742                                                    
[7]	validation-rmse:7.96314                                                    
[8]	validation-rmse:7.73700                                                    
[9]	validation-rmse:7.54602                                                    
[10]	validation-rmse:7.38776                                                   
[11]	validation-rmse:7.25413                                                   
[12]	validation-rmse:7.14476            





[0]	validation-rmse:11.52085                                                    
[1]	validation-rmse:10.90679                                                    
[2]	validation-rmse:10.36552                                                    
[3]	validation-rmse:9.88120                                                     
[4]	validation-rmse:9.45864                                                     
[5]	validation-rmse:9.08326                                                     
[6]	validation-rmse:8.75535                                                     
[7]	validation-rmse:8.46889                                                     
[8]	validation-rmse:8.21559                                                     
[9]	validation-rmse:7.99214                                                     
[10]	validation-rmse:7.81308                                                    
[11]	validation-rmse:7.64356                                                    
[12]	validation-rmse:7.50283





[0]	validation-rmse:11.63690                                                    
[1]	validation-rmse:11.11219                                                    
[2]	validation-rmse:10.63647                                                    
[3]	validation-rmse:10.20584                                                    
[4]	validation-rmse:9.81635                                                     
[5]	validation-rmse:9.46553                                                     
[6]	validation-rmse:9.14898                                                     
[7]	validation-rmse:8.86341                                                     
[8]	validation-rmse:8.60846                                                     
[9]	validation-rmse:8.38021                                                     
[10]	validation-rmse:8.17546                                                    
[11]	validation-rmse:7.99158                                                    
[12]	validation-rmse:7.82905





[0]	validation-rmse:8.06311                                                     
[1]	validation-rmse:7.00165                                                     
[2]	validation-rmse:6.73023                                                     
[3]	validation-rmse:6.64496                                                     
[4]	validation-rmse:6.60571                                                     
[5]	validation-rmse:6.59150                                                     
[6]	validation-rmse:6.58596                                                     
[7]	validation-rmse:6.57695                                                     
[8]	validation-rmse:6.56868                                                     
[9]	validation-rmse:6.56466                                                     
[10]	validation-rmse:6.55792                                                    
[11]	validation-rmse:6.55286                                                    
[12]	validation-rmse:6.54987





[1]	validation-rmse:6.68900                                                     
[2]	validation-rmse:6.66766                                                     
[3]	validation-rmse:6.64940                                                     
[4]	validation-rmse:6.64014                                                     
[5]	validation-rmse:6.62982                                                     
[6]	validation-rmse:6.61433                                                     
[7]	validation-rmse:6.61094                                                     
[8]	validation-rmse:6.60717                                                     
[9]	validation-rmse:6.60237                                                     
[10]	validation-rmse:6.59970                                                    
[11]	validation-rmse:6.59637                                                    
[12]	validation-rmse:6.59424                                                    
[13]	validation-rmse:6.58976





[0]	validation-rmse:11.15740                                                    
[1]	validation-rmse:10.28370                                                    
[2]	validation-rmse:9.56075                                                     
[3]	validation-rmse:8.97071                                                     
[4]	validation-rmse:8.49117                                                     
[5]	validation-rmse:8.10203                                                     
[6]	validation-rmse:7.79138                                                     
[7]	validation-rmse:7.54267                                                     
[8]	validation-rmse:7.34278                                                     
[9]	validation-rmse:7.18280                                                     
[10]	validation-rmse:7.05367                                                    
[11]	validation-rmse:6.95182                                                    
[12]	validation-rmse:6.86671





[0]	validation-rmse:10.29356                                                    
[1]	validation-rmse:8.97530                                                     
[2]	validation-rmse:8.10031                                                     
[3]	validation-rmse:7.53062                                                     
[4]	validation-rmse:7.16547                                                     
[5]	validation-rmse:6.92919                                                     
[6]	validation-rmse:6.77509                                                     
[7]	validation-rmse:6.67459                                                     
[8]	validation-rmse:6.60824                                                     
[9]	validation-rmse:6.56107                                                     
[10]	validation-rmse:6.52636                                                    
[11]	validation-rmse:6.49903                                                    
[12]	validation-rmse:6.47947





[0]	validation-rmse:11.23163                                                    
[1]	validation-rmse:10.40809                                                    
[2]	validation-rmse:9.72087                                                     
[3]	validation-rmse:9.15096                                                     
[4]	validation-rmse:8.68159                                                     
[5]	validation-rmse:8.29723                                                     
[6]	validation-rmse:7.98360                                                     
[7]	validation-rmse:7.72977                                                     
[8]	validation-rmse:7.52428                                                     
[9]	validation-rmse:7.35781                                                     
[10]	validation-rmse:7.22045                                                    
[11]	validation-rmse:7.11062                                                    
[12]	validation-rmse:7.02011





[0]	validation-rmse:11.67462                                                    
[1]	validation-rmse:11.18274                                                    
[2]	validation-rmse:10.73462                                                    
[3]	validation-rmse:10.32679                                                    
[4]	validation-rmse:9.95618                                                     
[5]	validation-rmse:9.61991                                                     
[6]	validation-rmse:9.31542                                                     
[7]	validation-rmse:9.04028                                                     
[8]	validation-rmse:8.79200                                                     
[9]	validation-rmse:8.56876                                                     
[10]	validation-rmse:8.36783                                                    
[11]	validation-rmse:8.18721                                                    
[12]	validation-rmse:8.02534





[1]	validation-rmse:11.38889                                                    
[2]	validation-rmse:11.01933                                                    
[3]	validation-rmse:10.67592                                                    
[4]	validation-rmse:10.35708                                                    
[5]	validation-rmse:10.06157                                                    
[6]	validation-rmse:9.78773                                                     
[7]	validation-rmse:9.53471                                                     
[8]	validation-rmse:9.30039                                                     
[9]	validation-rmse:9.08420                                                     
[10]	validation-rmse:8.88465                                                    
[11]	validation-rmse:8.70112                                                    
[12]	validation-rmse:8.53208                                                    
[13]	validation-rmse:8.37654





[0]	validation-rmse:6.76512                                                     
[1]	validation-rmse:6.60736                                                     
[2]	validation-rmse:6.59148                                                     
[3]	validation-rmse:6.58399                                                     
[4]	validation-rmse:6.57367                                                     
[5]	validation-rmse:6.56664                                                     
[6]	validation-rmse:6.55632                                                     
[7]	validation-rmse:6.54273                                                     
[8]	validation-rmse:6.53019                                                     
[9]	validation-rmse:6.52730                                                     
[10]	validation-rmse:6.51866                                                    
[11]	validation-rmse:6.51602                                                    
[12]	validation-rmse:6.51385





[1]	validation-rmse:9.06596                                                     
[2]	validation-rmse:8.20726                                                     
[3]	validation-rmse:7.64513                                                     
[4]	validation-rmse:7.28222                                                     
[5]	validation-rmse:7.04746                                                     
[6]	validation-rmse:6.89551                                                     
[7]	validation-rmse:6.79610                                                     
[8]	validation-rmse:6.72994                                                     
[9]	validation-rmse:6.68572                                                     
[10]	validation-rmse:6.65262                                                    
[11]	validation-rmse:6.62825                                                    
[12]	validation-rmse:6.61200                                                    
[13]	validation-rmse:6.59803





[1]	validation-rmse:9.39752                                                     
[2]	validation-rmse:8.54889                                                     
[3]	validation-rmse:7.95166                                                     
[4]	validation-rmse:7.53639                                                     
[5]	validation-rmse:7.25356                                                     
[6]	validation-rmse:7.05447                                                     
[7]	validation-rmse:6.91733                                                     
[8]	validation-rmse:6.82182                                                     
[9]	validation-rmse:6.75061                                                     
[10]	validation-rmse:6.70195                                                    
[11]	validation-rmse:6.66411                                                    
[12]	validation-rmse:6.63525                                                    
[13]	validation-rmse:6.61246





[0]	validation-rmse:9.78925                                                     
[1]	validation-rmse:8.37386                                                     
[2]	validation-rmse:7.57804                                                     
[3]	validation-rmse:7.13940                                                     
[4]	validation-rmse:6.89710                                                     
[5]	validation-rmse:6.76129                                                     
[6]	validation-rmse:6.68226                                                     
[7]	validation-rmse:6.62755                                                     
[8]	validation-rmse:6.59328                                                     
[9]	validation-rmse:6.56990                                                     
[10]	validation-rmse:6.55612                                                    
[11]	validation-rmse:6.54582                                                    
[12]	validation-rmse:6.53733





[5]	validation-rmse:6.93161                                                     
[6]	validation-rmse:6.86077                                                     
[7]	validation-rmse:6.81416                                                     
[8]	validation-rmse:6.78669                                                     
[9]	validation-rmse:6.77000                                                     
[10]	validation-rmse:6.75852                                                    
[11]	validation-rmse:6.75301                                                    
[12]	validation-rmse:6.74049                                                    
[13]	validation-rmse:6.73626                                                    
[14]	validation-rmse:6.73377                                                    
[15]	validation-rmse:6.72997                                                    
[16]	validation-rmse:6.72523                                                    
[17]	validation-rmse:6.72270





[0]	validation-rmse:11.00791                                                    
[1]	validation-rmse:10.04353                                                    
[2]	validation-rmse:9.27808                                                     
[3]	validation-rmse:8.68023                                                     
[4]	validation-rmse:8.21335                                                     
[5]	validation-rmse:7.84848                                                     
[6]	validation-rmse:7.57413                                                     
[7]	validation-rmse:7.35564                                                     
[8]	validation-rmse:7.19375                                                     
[9]	validation-rmse:7.06845                                                     
[10]	validation-rmse:6.96864                                                    
[11]	validation-rmse:6.89135                                                    
[12]	validation-rmse:6.82873





[0]	validation-rmse:9.67903                                                     
[1]	validation-rmse:8.25730                                                     
[2]	validation-rmse:7.47636                                                     
[3]	validation-rmse:7.05521                                                     
[4]	validation-rmse:6.84365                                                     
[5]	validation-rmse:6.71057                                                     
[6]	validation-rmse:6.64508                                                     
[7]	validation-rmse:6.59997                                                     
[8]	validation-rmse:6.56884                                                     
[9]	validation-rmse:6.55130                                                     
[10]	validation-rmse:6.54086                                                    
[11]	validation-rmse:6.53602                                                    
[12]	validation-rmse:6.52733





[1]	validation-rmse:7.79680                                                     
[2]	validation-rmse:7.19192                                                     
[3]	validation-rmse:6.93403                                                     
[4]	validation-rmse:6.81437                                                     
[5]	validation-rmse:6.75132                                                     
[6]	validation-rmse:6.72074                                                     
[7]	validation-rmse:6.70158                                                     
[8]	validation-rmse:6.68915                                                     
[9]	validation-rmse:6.67726                                                     
[10]	validation-rmse:6.67082                                                    
[11]	validation-rmse:6.66622                                                    
[12]	validation-rmse:6.66314                                                    
[13]	validation-rmse:6.66180





[0]	validation-rmse:11.04522                                                    
[1]	validation-rmse:10.09405                                                    
[2]	validation-rmse:9.32971                                                     
[3]	validation-rmse:8.72003                                                     
[4]	validation-rmse:8.23710                                                     
[5]	validation-rmse:7.85624                                                     
[6]	validation-rmse:7.55972                                                     
[7]	validation-rmse:7.33000                                                     
[8]	validation-rmse:7.14847                                                     
[9]	validation-rmse:7.00594                                                     
[10]	validation-rmse:6.89639                                                    
[11]	validation-rmse:6.80974                                                    
[12]	validation-rmse:6.73983





[1]	validation-rmse:9.98540                                                      
[2]	validation-rmse:9.20824                                                      
[3]	validation-rmse:8.60566                                                      
[4]	validation-rmse:8.13956                                                      
[5]	validation-rmse:7.78434                                                      
[6]	validation-rmse:7.51085                                                      
[7]	validation-rmse:7.30372                                                      
[8]	validation-rmse:7.14632                                                      
[9]	validation-rmse:7.02689                                                      
[10]	validation-rmse:6.93481                                                     
[11]	validation-rmse:6.86506                                                     
[12]	validation-rmse:6.81110                                                     
[13]	validation-





[0]	validation-rmse:11.18203                                                     
[1]	validation-rmse:10.32605                                                     
[2]	validation-rmse:9.62013                                                      
[3]	validation-rmse:9.04275                                                      
[4]	validation-rmse:8.57157                                                      
[5]	validation-rmse:8.19212                                                      
[6]	validation-rmse:7.88138                                                      
[7]	validation-rmse:7.63478                                                      
[8]	validation-rmse:7.43894                                                      
[9]	validation-rmse:7.27895                                                      
[10]	validation-rmse:7.15459                                                     
[11]	validation-rmse:7.04777                                                     
[12]	validation-





[0]	validation-rmse:10.71451                                                     
[1]	validation-rmse:9.58315                                                      
[2]	validation-rmse:8.74073                                                      
[3]	validation-rmse:8.12039                                                      
[4]	validation-rmse:7.67186                                                      
[5]	validation-rmse:7.35017                                                      
[6]	validation-rmse:7.11540                                                      
[7]	validation-rmse:6.94871                                                      
[8]	validation-rmse:6.82515                                                      
[9]	validation-rmse:6.73823                                                      
[10]	validation-rmse:6.67465                                                     
[11]	validation-rmse:6.62458                                                     
[12]	validation-





[2]	validation-rmse:10.91238                                                     
[3]	validation-rmse:10.54417                                                     
[4]	validation-rmse:10.20505                                                     
[5]	validation-rmse:9.89290                                                      
[6]	validation-rmse:9.60570                                                      
[7]	validation-rmse:9.34219                                                      
[8]	validation-rmse:9.10103                                                      
[9]	validation-rmse:8.88014                                                      
[10]	validation-rmse:8.67854                                                     
[11]	validation-rmse:8.49462                                                     
[12]	validation-rmse:8.32646                                                     
[13]	validation-rmse:8.17389                                                     
[14]	validation-





[0]	validation-rmse:11.56905                                                     
[1]	validation-rmse:10.99287                                                     
[2]	validation-rmse:10.47742                                                     
[3]	validation-rmse:10.01927                                                     
[4]	validation-rmse:9.61046                                                      
[5]	validation-rmse:9.24963                                                      
[6]	validation-rmse:8.92886                                                      
[7]	validation-rmse:8.64901                                                      
[8]	validation-rmse:8.39778                                                      
[9]	validation-rmse:8.17675                                                      
[10]	validation-rmse:7.98600                                                     
[11]	validation-rmse:7.81585                                                     
[12]	validation-





[0]	validation-rmse:11.06258                                                     
[1]	validation-rmse:10.12671                                                     
[2]	validation-rmse:9.37181                                                      
[3]	validation-rmse:8.76864                                                      
[4]	validation-rmse:8.29013                                                      
[5]	validation-rmse:7.91277                                                      
[6]	validation-rmse:7.61717                                                      
[7]	validation-rmse:7.38512                                                      
[8]	validation-rmse:7.20244                                                      
[9]	validation-rmse:7.05950                                                      
[10]	validation-rmse:6.94894                                                     
[11]	validation-rmse:6.86240                                                     
[12]	validation-





[0]	validation-rmse:10.03321                                                     
[1]	validation-rmse:8.67104                                                      
[2]	validation-rmse:7.82989                                                      
[3]	validation-rmse:7.33571                                                      
[4]	validation-rmse:7.03523                                                      
[5]	validation-rmse:6.85297                                                      
[6]	validation-rmse:6.74398                                                      
[7]	validation-rmse:6.67452                                                      
[8]	validation-rmse:6.62787                                                      
[9]	validation-rmse:6.59369                                                      
[10]	validation-rmse:6.57413                                                     
[11]	validation-rmse:6.55050                                                     
[12]	validation-





[0]	validation-rmse:11.39798                                                     
[1]	validation-rmse:10.68756                                                     
[2]	validation-rmse:10.07197                                                     
[3]	validation-rmse:9.54143                                                      
[4]	validation-rmse:9.08567                                                      
[5]	validation-rmse:8.69234                                                      
[6]	validation-rmse:8.35897                                                      
[7]	validation-rmse:8.07393                                                      
[8]	validation-rmse:7.83201                                                      
[9]	validation-rmse:7.62650                                                      
[10]	validation-rmse:7.45246                                                     
[11]	validation-rmse:7.30653                                                     
[12]	validation-





[0]	validation-rmse:10.59024                                                     
[1]	validation-rmse:9.41426                                                      
[2]	validation-rmse:8.57347                                                      
[3]	validation-rmse:7.97183                                                      
[4]	validation-rmse:7.55666                                                      
[5]	validation-rmse:7.26406                                                      
[6]	validation-rmse:7.05232                                                      
[7]	validation-rmse:6.90961                                                      
[8]	validation-rmse:6.81458                                                      
[9]	validation-rmse:6.74253                                                      
[10]	validation-rmse:6.68921                                                     
[11]	validation-rmse:6.65185                                                     
[12]	validation-





[0]	validation-rmse:9.18306                                                      
[1]	validation-rmse:7.72274                                                      
[2]	validation-rmse:7.06218                                                      
[3]	validation-rmse:6.76750                                                      
[4]	validation-rmse:6.62574                                                      
[5]	validation-rmse:6.55576                                                      
[6]	validation-rmse:6.51441                                                      
[7]	validation-rmse:6.49289                                                      
[8]	validation-rmse:6.47337                                                      
[9]	validation-rmse:6.46243                                                      
[10]	validation-rmse:6.45000                                                     
[11]	validation-rmse:6.44462                                                     
[12]	validation-





[0]	validation-rmse:11.24180                                                     
[1]	validation-rmse:10.42200                                                     
[2]	validation-rmse:9.73473                                                      
[3]	validation-rmse:9.16085                                                      
[4]	validation-rmse:8.68220                                                      
[5]	validation-rmse:8.28994                                                      
[6]	validation-rmse:7.96644                                                      
[7]	validation-rmse:7.70268                                                      
[8]	validation-rmse:7.48508                                                      
[9]	validation-rmse:7.30966                                                      
[10]	validation-rmse:7.16419                                                     
[11]	validation-rmse:7.04776                                                     
[12]	validation-





[0]	validation-rmse:11.80877                                                     
[1]	validation-rmse:11.43036                                                     
[2]	validation-rmse:11.07674                                                     
[3]	validation-rmse:10.74612                                                     
[4]	validation-rmse:10.43751                                                     
[5]	validation-rmse:10.14981                                                     
[6]	validation-rmse:9.88225                                                      
[7]	validation-rmse:9.63337                                                      
[8]	validation-rmse:9.40176                                                      
[9]	validation-rmse:9.18643                                                      
[10]	validation-rmse:8.98690                                                     
[11]	validation-rmse:8.80173                                                     
[12]	validation-





[1]	validation-rmse:9.77003                                                      
[2]	validation-rmse:8.96161                                                      
[3]	validation-rmse:8.35787                                                      
[4]	validation-rmse:7.90458                                                      
[5]	validation-rmse:7.57171                                                      
[6]	validation-rmse:7.32518                                                      
[7]	validation-rmse:7.14317                                                      
[8]	validation-rmse:7.01084                                                      
[9]	validation-rmse:6.91178                                                      
[10]	validation-rmse:6.83753                                                     
[11]	validation-rmse:6.78144                                                     
[12]	validation-rmse:6.73664                                                     
[13]	validation-





[0]	validation-rmse:11.43402                                                     
[1]	validation-rmse:10.75158                                                     
[2]	validation-rmse:10.15530                                                     
[3]	validation-rmse:9.63682                                                      
[4]	validation-rmse:9.18837                                                      
[5]	validation-rmse:8.80071                                                      
[6]	validation-rmse:8.46762                                                      
[7]	validation-rmse:8.18052                                                      
[8]	validation-rmse:7.93526                                                      
[9]	validation-rmse:7.72666                                                      
[10]	validation-rmse:7.54802                                                     
[11]	validation-rmse:7.39607                                                     
[12]	validation-





[0]	validation-rmse:11.56865                                                     
[1]	validation-rmse:10.99433                                                     
[2]	validation-rmse:10.47930                                                     
[3]	validation-rmse:10.02145                                                     
[4]	validation-rmse:9.61180                                                      
[5]	validation-rmse:9.24837                                                      
[6]	validation-rmse:8.93245                                                      
[7]	validation-rmse:8.64524                                                      
[8]	validation-rmse:8.39670                                                      
[9]	validation-rmse:8.17575                                                      
[10]	validation-rmse:7.97993                                                     
[11]	validation-rmse:7.81164                                                     
[12]	validation-





[0]	validation-rmse:11.69739                                                     
[1]	validation-rmse:11.22564                                                     
[2]	validation-rmse:10.79511                                                     
[3]	validation-rmse:10.39796                                                     
[4]	validation-rmse:10.03882                                                     
[5]	validation-rmse:9.71060                                                      
[6]	validation-rmse:9.41162                                                      
[7]	validation-rmse:9.13610                                                      
[8]	validation-rmse:8.88798                                                      
[9]	validation-rmse:8.66555                                                      
[10]	validation-rmse:8.46007                                                     
[11]	validation-rmse:8.27405                                                     
[12]	validation-





[0]	validation-rmse:8.34406                                                      
[1]	validation-rmse:7.15122                                                      
[2]	validation-rmse:6.77732                                                      
[3]	validation-rmse:6.65579                                                      
[4]	validation-rmse:6.58440                                                      
[5]	validation-rmse:6.56110                                                      
[6]	validation-rmse:6.54529                                                      
[7]	validation-rmse:6.53664                                                      
[8]	validation-rmse:6.53054                                                      
[9]	validation-rmse:6.52649                                                      
[10]	validation-rmse:6.52010                                                     
[11]	validation-rmse:6.51534                                                     
[12]	validation-





[6]	validation-rmse:7.86570                                                      
[7]	validation-rmse:7.63517                                                      
[8]	validation-rmse:7.45561                                                      
[9]	validation-rmse:7.31476                                                      
[10]	validation-rmse:7.20221                                                     
[11]	validation-rmse:7.11289                                                     
[12]	validation-rmse:7.04302                                                     
[13]	validation-rmse:6.98696                                                     
[14]	validation-rmse:6.94194                                                     
[15]	validation-rmse:6.90601                                                     
[16]	validation-rmse:6.87725                                                     
[17]	validation-rmse:6.85423                                                     
[18]	validation-





[0]	validation-rmse:11.35745                                                     
[1]	validation-rmse:10.61783                                                     
[2]	validation-rmse:9.98177                                                      
[3]	validation-rmse:9.43774                                                      
[4]	validation-rmse:8.97166                                                      
[5]	validation-rmse:8.57741                                                      
[6]	validation-rmse:8.24308                                                      
[7]	validation-rmse:7.96187                                                      
[8]	validation-rmse:7.72604                                                      
 88%|████████▊ | 44/50 [30:20<04:08, 41.38s/trial, best loss: 6.3050429894846625]


KeyboardInterrupt: 

### fit a model with optimal parameters selected above (with auto_log=True)

[MLflow Tracking](https://mlflow.org/docs/latest/tracking/autolog.html#supported-libraries)

The generic autolog function mlflow.autolog() enables autologging for each supported library you have installed as soon as you import it.

也就是说，只要你import了`mlflow`，然后调用[`mlflow.autolog()`](https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.autolog)，那么你在后面调用的任何支持的库，都会自动记录参数，模型，metrics等等。

The following libraries support autologging:
- Fastai
- Gluon
- Keras
- LightGBM
- PyTorch
- Scikit-learn
- Spark
- Statsmodels
- XGBoost


In [63]:
# 启用XGBoost自动记录
mlflow.xgboost.autolog()

with mlflow.start_run():
    
    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    best_params = {
        'learning_rate': 0.058550748798202254,
        'max_depth': 16,
        'min_child_weight': 3.616990741399196,
        'objective': 'reg:linear',
        'reg_alpha': 0.014691513505959162,
        'reg_lambda': 0.15590387336160477,
        'seed': 42
    }

    mlflow.log_params(best_params)

    booster = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, 'validation')],
        early_stopping_rounds=50
    )

    y_pred = booster.predict(valid)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    # 当auto_log=False时，需要手动调用mlflow.log_model
    # with open("models/preprocessor.b", "wb") as f_out:
    #     pickle.dump(dv, f_out)
    # mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")

[0]	validation-rmse:11.74481




[1]	validation-rmse:11.31211
[2]	validation-rmse:10.91238
[3]	validation-rmse:10.54417
[4]	validation-rmse:10.20505
[5]	validation-rmse:9.89290
[6]	validation-rmse:9.60570
[7]	validation-rmse:9.34219
[8]	validation-rmse:9.10103
[9]	validation-rmse:8.88014
[10]	validation-rmse:8.67854
[11]	validation-rmse:8.49462
[12]	validation-rmse:8.32646
[13]	validation-rmse:8.17389
[14]	validation-rmse:8.03496
[15]	validation-rmse:7.90781
[16]	validation-rmse:7.79301
[17]	validation-rmse:7.68817
[18]	validation-rmse:7.59198
[19]	validation-rmse:7.50505
[20]	validation-rmse:7.42664
[21]	validation-rmse:7.35457
[22]	validation-rmse:7.28971
[23]	validation-rmse:7.23062
[24]	validation-rmse:7.17669
[25]	validation-rmse:7.12811
[26]	validation-rmse:7.08421
[27]	validation-rmse:7.04428
[28]	validation-rmse:7.00817
[29]	validation-rmse:6.97509
[30]	validation-rmse:6.94465
[31]	validation-rmse:6.91717
[32]	validation-rmse:6.89209
[33]	validation-rmse:6.86916
[34]	validation-rmse:6.84837
[35]	validation-rms



### fit a model with optimal parameters selected above (with auto_log=False)


In [65]:
# 不启用XGBoost自动记录
mlflow.xgboost.autolog(disable=True)

with mlflow.start_run():
    
    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    best_params = {
        'learning_rate': 0.058550748798202254,
        'max_depth': 16,
        'min_child_weight': 3.616990741399196,
        'objective': 'reg:linear',
        'reg_alpha': 0.014691513505959162,
        'reg_lambda': 0.15590387336160477,
        'seed': 42
    }

    mlflow.log_params(best_params)

    booster = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, 'validation')],
        early_stopping_rounds=50
    )

    y_pred = booster.predict(valid)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    # 将上面的dv保存到models/preprocessor.b文件中，
    # 因为我们使用了X_train来fit了预处理器 dv = DictVectorizer(), 当我们再次使用模型时，需要使用相同的预处理器
    with open("models/preprocessor.b", "wb") as f_out:
        # 用pickle包保存dv，其中dv是DictVectorizer()的实例，而f_out是一个文件对象
        # 这俩共同保存在了models/preprocessor.b文件中
        pickle.dump(dv, f_out)
    
    # 当auto_log=False时，需要手动添加mlflow.artifact，从而记录模型文件
    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")

[0]	validation-rmse:11.74481
[1]	validation-rmse:11.31211
[2]	validation-rmse:10.91238
[3]	validation-rmse:10.54417
[4]	validation-rmse:10.20505
[5]	validation-rmse:9.89290
[6]	validation-rmse:9.60570
[7]	validation-rmse:9.34219
[8]	validation-rmse:9.10103
[9]	validation-rmse:8.88014
[10]	validation-rmse:8.67854
[11]	validation-rmse:8.49462
[12]	validation-rmse:8.32646
[13]	validation-rmse:8.17389
[14]	validation-rmse:8.03496
[15]	validation-rmse:7.90781
[16]	validation-rmse:7.79301
[17]	validation-rmse:7.68817
[18]	validation-rmse:7.59198
[19]	validation-rmse:7.50505
[20]	validation-rmse:7.42664
[21]	validation-rmse:7.35457
[22]	validation-rmse:7.28971
[23]	validation-rmse:7.23062
[24]	validation-rmse:7.17669
[25]	validation-rmse:7.12811
[26]	validation-rmse:7.08421
[27]	validation-rmse:7.04428
[28]	validation-rmse:7.00817
[29]	validation-rmse:6.97509
[30]	validation-rmse:6.94465
[31]	validation-rmse:6.91717
[32]	validation-rmse:6.89209
[33]	validation-rmse:6.86916
[34]	validation-rms

In [None]:
import pickle

# 加载预处理器
with open("models/preprocessor.b", "rb") as f_in:
    dv = pickle.load(f_in)

# 使用加载的预处理器进行数据转换
# transformed_data = dv.transform(new_data)

### 不用autolog的方法（原来的手动方法）

In [26]:
with mlflow.start_run():
    
    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_val, label=y_val)

    best_params = {
        'learning_rate': 0.058550748798202254,
        'max_depth': 16,
        'min_child_weight': 3.616990741399196,
        'objective': 'reg:linear',
        'reg_alpha': 0.014691513505959162,
        'reg_lambda': 0.15590387336160477,
        'seed': 42
    }

    mlflow.log_params(best_params)

    booster = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, 'validation')],
        early_stopping_rounds=50
    )

    y_pred = booster.predict(valid)
    rmse = mean_squared_error(y_val, y_pred, squared=False)
    mlflow.log_metric("rmse", rmse)

    with open("models/preprocessor.b", "wb") as f_out:
        pickle.dump(dv, f_out)
    mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

    # 不需要再次手动调用mlflow.xgboost.log_model
    mlflow.xgboost.log_model(booster, artifact_path="models_mlflow")

[0]	validation-rmse:11.74481




[1]	validation-rmse:11.31211
[2]	validation-rmse:10.91238
[3]	validation-rmse:10.54417
[4]	validation-rmse:10.20505
[5]	validation-rmse:9.89290
[6]	validation-rmse:9.60570
[7]	validation-rmse:9.34219
[8]	validation-rmse:9.10103
[9]	validation-rmse:8.88014
[10]	validation-rmse:8.67854
[11]	validation-rmse:8.49462
[12]	validation-rmse:8.32646
[13]	validation-rmse:8.17389
[14]	validation-rmse:8.03496
[15]	validation-rmse:7.90781
[16]	validation-rmse:7.79301
[17]	validation-rmse:7.68817
[18]	validation-rmse:7.59198
[19]	validation-rmse:7.50505
[20]	validation-rmse:7.42664
[21]	validation-rmse:7.35457
[22]	validation-rmse:7.28971
[23]	validation-rmse:7.23062
[24]	validation-rmse:7.17669
[25]	validation-rmse:7.12811
[26]	validation-rmse:7.08421
[27]	validation-rmse:7.04428
[28]	validation-rmse:7.00817
[29]	validation-rmse:6.97509
[30]	validation-rmse:6.94465
[31]	validation-rmse:6.91717
[32]	validation-rmse:6.89209
[33]	validation-rmse:6.86916
[34]	validation-rmse:6.84837
[35]	validation-rms



### Predict on a # Load model as a PyFuncModel （方法1）

In [56]:
# Load model as a PyFuncModel.
logged_model = 'runs:/c51562cf17cb433fbb1deb05b0709d49/model'

loaded_model = mlflow.pyfunc.load_model(logged_model)
loaded_model



mlflow.pyfunc.loaded_model:
  artifact_path: model
  flavor: mlflow.xgboost
  run_id: c51562cf17cb433fbb1deb05b0709d49

In [58]:
X_val_dict = df_val[categorical + numerical].to_dict(orient='records')
# preprocessed DataFrame with DictionaryVectorizer
X_val = dv.transform(X_val_dict)

In [60]:
# loaded_model.predict(valid) # 这个会报错，因为loaded_model是一个PyFuncModel，不是XGBoost模型
# error内容显示 TypeError: Not supported type for data.<class 'xgboost.core.DMatrix'>
loaded_model.predict(X_val)

array([15.085921,  7.206091, 16.032732, ..., 15.449408, 27.112658,
        8.331218], dtype=float32)

### Load model as XGBoost Model (方法2)

In [61]:
# Load model as XGBmodel
xgb_model = mlflow.xgboost.load_model(logged_model)
xgb_model



<xgboost.core.Booster at 0x32118e820>

In [62]:
# convert X_val into DMatrix type
valid = xgb.DMatrix(X_val, label=y_val)

# 当我们使用XGBoost模型时，我们需要使用刚才转换的DMatrix类型的数据，而不是原始的特征矩阵。
xgb_model.predict(valid)

array([15.085921,  7.206091, 16.032732, ..., 15.449408, 27.112658,
        8.331218], dtype=float32)

### 使用不同的模型

In [80]:
# x爱你在的MLflow仅支持 scikit-learn <= 1.4.2
# downgrade scikit-learn to 1.4.2
# 先卸载当前版本
!pip uninstall -y scikit-learn

Found existing installation: scikit-learn 1.4.2
Uninstalling scikit-learn-1.4.2:
  Successfully uninstalled scikit-learn-1.4.2


In [None]:
# !pip install scikit-learn==1.4.2

In [11]:
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor

from sklearn.svm import LinearSVR

In [17]:
# mlflow.sklearn.autolog()

# 因为tmd太慢了所以我没跑这个
# RandomForestRegressor, ExtraTreesRegressor 超级慢
# LinearSVR 我觉得这个也很慢

# for model_class in (GradientBoostingRegressor):

#     with mlflow.start_run():

#         mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
#         mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")
#         mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

#         mlmodel = model_class()
#         mlmodel.fit(X_train, y_train)

#         y_pred = mlmodel.predict(X_val)
#         rmse = mean_squared_error(y_val, y_pred, squared=False)
#         mlflow.log_metric("rmse", rmse)


In [15]:
from sklearn.ensemble import GradientBoostingRegressor
# import lasso and ElasticNet
from sklearn.linear_model import Lasso, ElasticNet

In [12]:
# 这个一定要设置，不然MLflow不会记录模型
# 这里我们重新安装了一遍sklearn==1.4.2，因为MLflow仅支持 scikit-learn <= 1.4.2 （2024-05-28）
mlflow.sklearn.autolog()

In [16]:
# 这三个跑得快，原本的跑的太jbd慢了
for model_class in (GradientBoostingRegressor, ElasticNet, Lasso):

    with mlflow.start_run():

        mlflow.log_param("train-data-path", "./data/green_tripdata_2021-01.csv")
        mlflow.log_param("valid-data-path", "./data/green_tripdata_2021-02.csv")
        mlflow.log_artifact("models/preprocessor.b", artifact_path="preprocessor")

        mlmodel = model_class()
        mlmodel.fit(X_train, y_train)

        y_pred = mlmodel.predict(X_val)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        mlflow.log_metric("rmse", rmse)


