In [1]:
import pandas as pd
import numpy as np
np.random.seed(0)

from IPython.display import display
import ipywidgets as widgets

from sklearn.model_selection import cross_val_score, cross_val_predict, KFold
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split, GridSearchCV

from sklearn.dummy import DummyRegressor
from sklearn.svm import SVC
from sklearn.naive_bayes import MultinomialNB
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_selection import RFE

import xgboost as xgb
import catboost as ctb
import lightgbm as lgb

import mlflow
import mlflow.sklearn
import mlflow.xgboost

import re
import sys
import os

import scikitplot as skplt

import seaborn as sns
import matplotlib.pyplot as plt

pd.pandas.set_option('display.max_columns', None)
%matplotlib inline

In [2]:
remote_server_uri = 'https://dagshub.com/adam.zabek/car_price_prediction.mlflow' ### insert url to remote server
mlflow.set_tracking_uri(remote_server_uri)

In [3]:
os.environ['MLFLOW_TRACKING_USERNAME'] = 'adam.zabek' ### insert name
os.environ['MLFLOW_TRACKING_PASSWORD'] = 'e3c5d562249c76008f52ba8e9f3e0c7b416512db' ### insert password

In [4]:
mlflow.set_experiment("car_price_prediction_full_data")

<Experiment: artifact_location='mlflow-artifacts:/e8dc41b9c8654cdaae28e552be07f5b6', creation_time=1695909119704, experiment_id='2', last_update_time=1695909119704, lifecycle_stage='active', name='car_price_prediction_full_data', tags={}>

In [5]:
train_set = pd.read_csv('./train_set_selected.csv')

In [6]:
train_set.head()

Unnamed: 0.1,Unnamed: 0,Norm Engine capacity,Norm Mileage,Norm Engine power,Norm Age,norm_Typ,norm_Color,norm_Region,norm_Company,Transmission_Na przednie koła,Damaged_Nie,Price
0,0,-0.950847,-1.56425,-0.438799,-1.662566,3,4,13,14,1,1,11.065075
1,1,-0.279087,-0.262598,-0.285932,-0.314186,6,4,6,57,1,1,10.643041
2,2,-0.649339,-0.403158,-0.936032,0.37577,1,4,14,59,1,1,9.795345
3,3,-0.00333,0.329471,0.489637,0.37577,7,4,7,4,1,1,10.545341
4,4,-0.853436,-0.500523,-1.166697,0.057435,0,2,7,21,1,1,9.994242


In [7]:
black_list = ['Price', 'Unnamed: 0']

In [8]:
def get_feats(df, black_list):
    feats = df.columns
    return [x for x in feats if x not in black_list]

In [9]:
def get_X(df, feats):
    
    X = df[feats]
    return X

def get_y(df, target_var):
    return df[target_var].values

In [10]:
feats = get_feats(train_set, black_list)

In [11]:
X = get_X(train_set, feats)

In [12]:
y = get_y(train_set, target_var = 'Price')

In [13]:
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size = .8, shuffle = True)

In [19]:
model_param = [
    {
        'model_type': 'DecisionTree',
        'params': {
            'criterion': 'absolute_error',
            'max_depth': 10,
            'random_state': 0
        }
    },
    {
        'model_type': 'RandomForest',
        'params': {
            'n_estimators': 1000,
            'max_depth': 20,
            'random_state': 0
        }
    },
    {
        'model_type': 'CatBoost',
        'params': {
            #'iterations': 2000,
            'depth': 10,
            'learning_rate': 0.01,
            'n_estimators': 2000,
            'loss_function': 'MAE',
            'early_stopping_rounds': 50,
            'random_state': 0
        }
    },
    {
        'model_type': 'LightGBM',
        'params': {
            'boosting_type': 'gbdt',
            'num_leaves': 30,
            'max_depth': 20,
            'n_estimators': 5000,
            'learning_rate': 0.05,
            'subsample': 0.9,
            'colsample_bytree': 0.5,
            'random_state': 0
        }
    },
    {
        'model_type': 'XGBoost',
        'params': {
            'max_depth': 20,
            'learning_rate': 0.05,
            'n_estimators': 5000,
            'subsample': 0.9,
            'colsample_bytree': 0.5,
            'random_state': 0
        }
    }
]

In [20]:
def train_models(models, model_param, X_train, X_test, y_train, y_test):
    mlflow.sklearn.autolog(disable=True)

    for model_name in models:
        model_params  = next((params for params in model_param if params['model_type'] == model_name), None)
        
        if model_params is None:
            print(f"Model '{model_name}' not found in the parameters list.")
            continue
        
        with mlflow.start_run(run_name=f'{model_name}_Model'):
            if model_name == 'DecisionTree':
                
                mlflow.set_tag("model_name", "DT")
                mlflow.log_params(model_params)
            
                model = DecisionTreeRegressor(**model_params['params'])
                model.fit(X_train, y_train)
            
                y_pred = model.predict(X_test)
                y_pred = np.exp(y_pred)
                y_test_exp = np.exp(y_test)
            
                mae = mean_absolute_error(y_test_exp, y_pred)
            
                mlflow.log_metric("test_mae", mae)
                mlflow.sklearn.log_model(model, f"{model_name}_Model")
            
            elif model_name == 'RandomForest':

                mlflow.set_tag("model_name", "RF")
                mlflow.log_params(model_params)
                        
                model = RandomForestRegressor(**model_params['params'])
                model.fit(X_train, y_train)
            
                y_pred = model.predict(X_test)
                y_pred = np.exp(y_pred)
                y_test_exp = np.exp(y_test)
            
                mae = mean_absolute_error(y_test_exp, y_pred)
            
                mlflow.log_metric("test_mae", mae)
                mlflow.sklearn.log_model(model, f"{model_name}_Model")
            
            elif model_name == 'CatBoost':

                mlflow.set_tag("model_name", "CB")
                mlflow.log_params(model_params)
                                    
                model = ctb.CatBoostRegressor(**model_params['params'])
                model.fit(X_train, y_train)
            
                y_pred = model.predict(X_test)
                y_pred = np.exp(y_pred)
                y_test_exp = np.exp(y_test)
                
                mae = mean_absolute_error(y_test_exp, y_pred)
            
                mlflow.log_metric("test_mae", mae)
                mlflow.sklearn.log_model(model, f"{model_name}_Model")
            
            elif model_name == 'LightGBM':

                mlflow.set_tag("model_name", "LGB")
                mlflow.log_params(model_params)
                                    
                model = lgb.LGBMRegressor(**model_params['params'])
                model.fit(X_train, y_train)
            
                y_pred = model.predict(X_test)
                y_pred = np.exp(y_pred)
                y_test_exp = np.exp(y_test)
            
                mae = mean_absolute_error(y_test_exp, y_pred)
            
                mlflow.log_metric("test_mae", mae)
                mlflow.sklearn.log_model(model, f"{model_name}_Model")
            
            elif model_name == 'XGBoost':

                mlflow.set_tag("model_name", "XGB")
                mlflow.log_params(model_params)
                                    
                model = xgb.XGBRegressor(**model_params['params'])
                model.fit(X_train, y_train)
            
                y_pred = model.predict(X_test)
                y_pred = np.exp(y_pred)
                y_test_exp = np.exp(y_test)
            
                mae = mean_absolute_error(y_test_exp, y_pred)
            
                mlflow.log_metric("test_mae", mae)
                mlflow.sklearn.log_model(model, f"{model_name}_Model")

In [21]:
models = ['DecisionTree','RandomForest','CatBoost','LightGBM','XGBoost']
train_models(models, model_param, X_train, X_test, y_train, y_test)

0:	learn: 0.8102215	total: 186ms	remaining: 6m 12s
1:	learn: 0.8033852	total: 211ms	remaining: 3m 30s
2:	learn: 0.7966934	total: 235ms	remaining: 2m 36s
3:	learn: 0.7898669	total: 271ms	remaining: 2m 15s
4:	learn: 0.7831105	total: 302ms	remaining: 2m
5:	learn: 0.7765006	total: 325ms	remaining: 1m 47s
6:	learn: 0.7700668	total: 347ms	remaining: 1m 38s
7:	learn: 0.7637574	total: 371ms	remaining: 1m 32s
8:	learn: 0.7574940	total: 393ms	remaining: 1m 26s
9:	learn: 0.7512218	total: 416ms	remaining: 1m 22s
10:	learn: 0.7449968	total: 440ms	remaining: 1m 19s
11:	learn: 0.7387936	total: 462ms	remaining: 1m 16s
12:	learn: 0.7330104	total: 484ms	remaining: 1m 13s
13:	learn: 0.7270028	total: 504ms	remaining: 1m 11s
14:	learn: 0.7211239	total: 525ms	remaining: 1m 9s
15:	learn: 0.7152805	total: 547ms	remaining: 1m 7s
16:	learn: 0.7095026	total: 567ms	remaining: 1m 6s
17:	learn: 0.7038234	total: 589ms	remaining: 1m 4s
18:	learn: 0.6982456	total: 611ms	remaining: 1m 3s
19:	learn: 0.6927223	total: 634

169:	learn: 0.2966991	total: 3.98s	remaining: 42.8s
170:	learn: 0.2957680	total: 4s	remaining: 42.8s
171:	learn: 0.2948568	total: 4.02s	remaining: 42.7s
172:	learn: 0.2939699	total: 4.04s	remaining: 42.7s
173:	learn: 0.2930863	total: 4.06s	remaining: 42.6s
174:	learn: 0.2922407	total: 4.08s	remaining: 42.6s
175:	learn: 0.2913276	total: 4.1s	remaining: 42.5s
176:	learn: 0.2905062	total: 4.12s	remaining: 42.5s
177:	learn: 0.2896269	total: 4.14s	remaining: 42.4s
178:	learn: 0.2887689	total: 4.17s	remaining: 42.4s
179:	learn: 0.2879409	total: 4.19s	remaining: 42.4s
180:	learn: 0.2871643	total: 4.21s	remaining: 42.3s
181:	learn: 0.2863283	total: 4.23s	remaining: 42.3s
182:	learn: 0.2854911	total: 4.25s	remaining: 42.2s
183:	learn: 0.2846900	total: 4.27s	remaining: 42.2s
184:	learn: 0.2839996	total: 4.29s	remaining: 42.1s
185:	learn: 0.2832514	total: 4.32s	remaining: 42.1s
186:	learn: 0.2825245	total: 4.33s	remaining: 42s
187:	learn: 0.2817391	total: 4.36s	remaining: 42s
188:	learn: 0.280944

331:	learn: 0.2292898	total: 7.53s	remaining: 37.8s
332:	learn: 0.2291504	total: 7.55s	remaining: 37.8s
333:	learn: 0.2289612	total: 7.57s	remaining: 37.8s
334:	learn: 0.2287939	total: 7.59s	remaining: 37.7s
335:	learn: 0.2286406	total: 7.61s	remaining: 37.7s
336:	learn: 0.2284896	total: 7.63s	remaining: 37.7s
337:	learn: 0.2283429	total: 7.65s	remaining: 37.6s
338:	learn: 0.2281554	total: 7.67s	remaining: 37.6s
339:	learn: 0.2280183	total: 7.7s	remaining: 37.6s
340:	learn: 0.2278036	total: 7.72s	remaining: 37.5s
341:	learn: 0.2276638	total: 7.74s	remaining: 37.5s
342:	learn: 0.2275262	total: 7.76s	remaining: 37.5s
343:	learn: 0.2273738	total: 7.79s	remaining: 37.5s
344:	learn: 0.2272363	total: 7.81s	remaining: 37.5s
345:	learn: 0.2271080	total: 7.83s	remaining: 37.4s
346:	learn: 0.2269338	total: 7.85s	remaining: 37.4s
347:	learn: 0.2267448	total: 7.87s	remaining: 37.4s
348:	learn: 0.2266195	total: 7.89s	remaining: 37.3s
349:	learn: 0.2264703	total: 7.92s	remaining: 37.3s
350:	learn: 0

495:	learn: 0.2119279	total: 11.1s	remaining: 33.8s
496:	learn: 0.2118682	total: 11.2s	remaining: 33.8s
497:	learn: 0.2118094	total: 11.2s	remaining: 33.7s
498:	learn: 0.2117560	total: 11.2s	remaining: 33.7s
499:	learn: 0.2117088	total: 11.2s	remaining: 33.7s
500:	learn: 0.2116403	total: 11.3s	remaining: 33.7s
501:	learn: 0.2115448	total: 11.3s	remaining: 33.6s
502:	learn: 0.2114897	total: 11.3s	remaining: 33.6s
503:	learn: 0.2114466	total: 11.3s	remaining: 33.6s
504:	learn: 0.2113983	total: 11.4s	remaining: 33.6s
505:	learn: 0.2113429	total: 11.4s	remaining: 33.6s
506:	learn: 0.2112742	total: 11.4s	remaining: 33.6s
507:	learn: 0.2112085	total: 11.4s	remaining: 33.6s
508:	learn: 0.2111524	total: 11.4s	remaining: 33.5s
509:	learn: 0.2111148	total: 11.5s	remaining: 33.5s
510:	learn: 0.2110496	total: 11.5s	remaining: 33.5s
511:	learn: 0.2109987	total: 11.5s	remaining: 33.5s
512:	learn: 0.2109479	total: 11.5s	remaining: 33.4s
513:	learn: 0.2108781	total: 11.6s	remaining: 33.4s
514:	learn: 

658:	learn: 0.2035010	total: 14.9s	remaining: 30.2s
659:	learn: 0.2034400	total: 14.9s	remaining: 30.2s
660:	learn: 0.2034060	total: 14.9s	remaining: 30.2s
661:	learn: 0.2033696	total: 14.9s	remaining: 30.2s
662:	learn: 0.2033385	total: 14.9s	remaining: 30.1s
663:	learn: 0.2033064	total: 15s	remaining: 30.1s
664:	learn: 0.2032578	total: 15s	remaining: 30.1s
665:	learn: 0.2032046	total: 15s	remaining: 30.1s
666:	learn: 0.2031722	total: 15s	remaining: 30s
667:	learn: 0.2031450	total: 15.1s	remaining: 30s
668:	learn: 0.2030913	total: 15.1s	remaining: 30s
669:	learn: 0.2030512	total: 15.1s	remaining: 30s
670:	learn: 0.2030059	total: 15.1s	remaining: 29.9s
671:	learn: 0.2029422	total: 15.1s	remaining: 29.9s
672:	learn: 0.2028869	total: 15.2s	remaining: 29.9s
673:	learn: 0.2028584	total: 15.2s	remaining: 29.9s
674:	learn: 0.2028209	total: 15.2s	remaining: 29.9s
675:	learn: 0.2027758	total: 15.2s	remaining: 29.8s
676:	learn: 0.2027312	total: 15.3s	remaining: 29.8s
677:	learn: 0.2026829	total:

824:	learn: 0.1977356	total: 18.7s	remaining: 26.6s
825:	learn: 0.1977049	total: 18.7s	remaining: 26.6s
826:	learn: 0.1976770	total: 18.7s	remaining: 26.6s
827:	learn: 0.1976489	total: 18.8s	remaining: 26.6s
828:	learn: 0.1976186	total: 18.8s	remaining: 26.5s
829:	learn: 0.1975938	total: 18.8s	remaining: 26.5s
830:	learn: 0.1975620	total: 18.8s	remaining: 26.5s
831:	learn: 0.1975290	total: 18.9s	remaining: 26.5s
832:	learn: 0.1975056	total: 18.9s	remaining: 26.5s
833:	learn: 0.1974731	total: 18.9s	remaining: 26.4s
834:	learn: 0.1974515	total: 18.9s	remaining: 26.4s
835:	learn: 0.1974336	total: 19s	remaining: 26.4s
836:	learn: 0.1974017	total: 19s	remaining: 26.4s
837:	learn: 0.1973631	total: 19s	remaining: 26.3s
838:	learn: 0.1973304	total: 19s	remaining: 26.3s
839:	learn: 0.1973078	total: 19s	remaining: 26.3s
840:	learn: 0.1972815	total: 19.1s	remaining: 26.3s
841:	learn: 0.1972610	total: 19.1s	remaining: 26.2s
842:	learn: 0.1972289	total: 19.1s	remaining: 26.2s
843:	learn: 0.1972083	

984:	learn: 0.1933171	total: 22.5s	remaining: 23.1s
985:	learn: 0.1933059	total: 22.5s	remaining: 23.1s
986:	learn: 0.1932751	total: 22.5s	remaining: 23.1s
987:	learn: 0.1932385	total: 22.5s	remaining: 23.1s
988:	learn: 0.1932207	total: 22.5s	remaining: 23s
989:	learn: 0.1931948	total: 22.6s	remaining: 23s
990:	learn: 0.1931773	total: 22.6s	remaining: 23s
991:	learn: 0.1931539	total: 22.6s	remaining: 23s
992:	learn: 0.1931303	total: 22.6s	remaining: 23s
993:	learn: 0.1931058	total: 22.7s	remaining: 22.9s
994:	learn: 0.1930899	total: 22.7s	remaining: 22.9s
995:	learn: 0.1930701	total: 22.7s	remaining: 22.9s
996:	learn: 0.1930462	total: 22.7s	remaining: 22.9s
997:	learn: 0.1930305	total: 22.7s	remaining: 22.8s
998:	learn: 0.1929946	total: 22.8s	remaining: 22.8s
999:	learn: 0.1929640	total: 22.8s	remaining: 22.8s
1000:	learn: 0.1929475	total: 22.8s	remaining: 22.8s
1001:	learn: 0.1929260	total: 22.8s	remaining: 22.8s
1002:	learn: 0.1929043	total: 22.9s	remaining: 22.7s
1003:	learn: 0.1928

1141:	learn: 0.1896665	total: 26.2s	remaining: 19.6s
1142:	learn: 0.1896460	total: 26.2s	remaining: 19.6s
1143:	learn: 0.1896245	total: 26.2s	remaining: 19.6s
1144:	learn: 0.1895990	total: 26.2s	remaining: 19.6s
1145:	learn: 0.1895754	total: 26.3s	remaining: 19.6s
1146:	learn: 0.1895595	total: 26.3s	remaining: 19.5s
1147:	learn: 0.1895375	total: 26.3s	remaining: 19.5s
1148:	learn: 0.1895238	total: 26.3s	remaining: 19.5s
1149:	learn: 0.1895004	total: 26.4s	remaining: 19.5s
1150:	learn: 0.1894790	total: 26.4s	remaining: 19.5s
1151:	learn: 0.1894435	total: 26.4s	remaining: 19.4s
1152:	learn: 0.1894287	total: 26.4s	remaining: 19.4s
1153:	learn: 0.1894073	total: 26.4s	remaining: 19.4s
1154:	learn: 0.1893954	total: 26.5s	remaining: 19.4s
1155:	learn: 0.1893740	total: 26.5s	remaining: 19.3s
1156:	learn: 0.1893559	total: 26.5s	remaining: 19.3s
1157:	learn: 0.1893319	total: 26.5s	remaining: 19.3s
1158:	learn: 0.1893138	total: 26.6s	remaining: 19.3s
1159:	learn: 0.1892983	total: 26.6s	remaining:

1299:	learn: 0.1864909	total: 29.9s	remaining: 16.1s
1300:	learn: 0.1864773	total: 29.9s	remaining: 16.1s
1301:	learn: 0.1864630	total: 29.9s	remaining: 16.1s
1302:	learn: 0.1864447	total: 30s	remaining: 16s
1303:	learn: 0.1864294	total: 30s	remaining: 16s
1304:	learn: 0.1864044	total: 30s	remaining: 16s
1305:	learn: 0.1863767	total: 30s	remaining: 16s
1306:	learn: 0.1863658	total: 30.1s	remaining: 15.9s
1307:	learn: 0.1863543	total: 30.1s	remaining: 15.9s
1308:	learn: 0.1863426	total: 30.1s	remaining: 15.9s
1309:	learn: 0.1863220	total: 30.2s	remaining: 15.9s
1310:	learn: 0.1862903	total: 30.2s	remaining: 15.9s
1311:	learn: 0.1862689	total: 30.2s	remaining: 15.8s
1312:	learn: 0.1862449	total: 30.2s	remaining: 15.8s
1313:	learn: 0.1862313	total: 30.2s	remaining: 15.8s
1314:	learn: 0.1862213	total: 30.3s	remaining: 15.8s
1315:	learn: 0.1861989	total: 30.3s	remaining: 15.7s
1316:	learn: 0.1861822	total: 30.3s	remaining: 15.7s
1317:	learn: 0.1861675	total: 30.3s	remaining: 15.7s
1318:	lea

1463:	learn: 0.1834946	total: 33.8s	remaining: 12.4s
1464:	learn: 0.1834787	total: 33.9s	remaining: 12.4s
1465:	learn: 0.1834640	total: 33.9s	remaining: 12.3s
1466:	learn: 0.1834451	total: 33.9s	remaining: 12.3s
1467:	learn: 0.1834232	total: 33.9s	remaining: 12.3s
1468:	learn: 0.1834108	total: 34s	remaining: 12.3s
1469:	learn: 0.1833916	total: 34s	remaining: 12.3s
1470:	learn: 0.1833811	total: 34s	remaining: 12.2s
1471:	learn: 0.1833596	total: 34s	remaining: 12.2s
1472:	learn: 0.1833406	total: 34.1s	remaining: 12.2s
1473:	learn: 0.1833244	total: 34.1s	remaining: 12.2s
1474:	learn: 0.1833081	total: 34.1s	remaining: 12.1s
1475:	learn: 0.1832935	total: 34.1s	remaining: 12.1s
1476:	learn: 0.1832742	total: 34.2s	remaining: 12.1s
1477:	learn: 0.1832530	total: 34.2s	remaining: 12.1s
1478:	learn: 0.1832356	total: 34.2s	remaining: 12.1s
1479:	learn: 0.1832122	total: 34.2s	remaining: 12s
1480:	learn: 0.1831969	total: 34.3s	remaining: 12s
1481:	learn: 0.1831794	total: 34.3s	remaining: 12s
1482:	l

1621:	learn: 0.1810624	total: 37.7s	remaining: 8.78s
1622:	learn: 0.1810524	total: 37.7s	remaining: 8.76s
1623:	learn: 0.1810342	total: 37.7s	remaining: 8.73s
1624:	learn: 0.1810228	total: 37.7s	remaining: 8.71s
1625:	learn: 0.1810127	total: 37.8s	remaining: 8.69s
1626:	learn: 0.1810053	total: 37.8s	remaining: 8.66s
1627:	learn: 0.1809915	total: 37.8s	remaining: 8.64s
1628:	learn: 0.1809814	total: 37.8s	remaining: 8.62s
1629:	learn: 0.1809669	total: 37.9s	remaining: 8.59s
1630:	learn: 0.1809496	total: 37.9s	remaining: 8.57s
1631:	learn: 0.1809352	total: 37.9s	remaining: 8.55s
1632:	learn: 0.1809226	total: 37.9s	remaining: 8.52s
1633:	learn: 0.1809077	total: 38s	remaining: 8.5s
1634:	learn: 0.1808954	total: 38s	remaining: 8.48s
1635:	learn: 0.1808833	total: 38s	remaining: 8.46s
1636:	learn: 0.1808720	total: 38s	remaining: 8.44s
1637:	learn: 0.1808592	total: 38.1s	remaining: 8.41s
1638:	learn: 0.1808364	total: 38.1s	remaining: 8.39s
1639:	learn: 0.1808261	total: 38.1s	remaining: 8.37s
16

1779:	learn: 0.1789561	total: 41.5s	remaining: 5.13s
1780:	learn: 0.1789333	total: 41.5s	remaining: 5.11s
1781:	learn: 0.1789215	total: 41.6s	remaining: 5.08s
1782:	learn: 0.1789092	total: 41.6s	remaining: 5.06s
1783:	learn: 0.1788982	total: 41.6s	remaining: 5.04s
1784:	learn: 0.1788851	total: 41.6s	remaining: 5.01s
1785:	learn: 0.1788689	total: 41.7s	remaining: 4.99s
1786:	learn: 0.1788530	total: 41.7s	remaining: 4.97s
1787:	learn: 0.1788434	total: 41.7s	remaining: 4.94s
1788:	learn: 0.1788310	total: 41.7s	remaining: 4.92s
1789:	learn: 0.1788169	total: 41.7s	remaining: 4.9s
1790:	learn: 0.1788054	total: 41.8s	remaining: 4.87s
1791:	learn: 0.1787939	total: 41.8s	remaining: 4.85s
1792:	learn: 0.1787822	total: 41.8s	remaining: 4.83s
1793:	learn: 0.1787683	total: 41.8s	remaining: 4.8s
1794:	learn: 0.1787571	total: 41.9s	remaining: 4.78s
1795:	learn: 0.1787404	total: 41.9s	remaining: 4.76s
1796:	learn: 0.1787267	total: 41.9s	remaining: 4.74s
1797:	learn: 0.1787128	total: 41.9s	remaining: 4

1942:	learn: 0.1769381	total: 45.5s	remaining: 1.33s
1943:	learn: 0.1769228	total: 45.5s	remaining: 1.31s
1944:	learn: 0.1769095	total: 45.5s	remaining: 1.29s
1945:	learn: 0.1768992	total: 45.6s	remaining: 1.26s
1946:	learn: 0.1768914	total: 45.6s	remaining: 1.24s
1947:	learn: 0.1768797	total: 45.6s	remaining: 1.22s
1948:	learn: 0.1768686	total: 45.6s	remaining: 1.19s
1949:	learn: 0.1768617	total: 45.6s	remaining: 1.17s
1950:	learn: 0.1768490	total: 45.7s	remaining: 1.15s
1951:	learn: 0.1768319	total: 45.7s	remaining: 1.12s
1952:	learn: 0.1768187	total: 45.7s	remaining: 1.1s
1953:	learn: 0.1768111	total: 45.7s	remaining: 1.08s
1954:	learn: 0.1768051	total: 45.8s	remaining: 1.05s
1955:	learn: 0.1767887	total: 45.8s	remaining: 1.03s
1956:	learn: 0.1767732	total: 45.8s	remaining: 1.01s
1957:	learn: 0.1767663	total: 45.8s	remaining: 983ms
1958:	learn: 0.1767563	total: 45.9s	remaining: 960ms
1959:	learn: 0.1767468	total: 45.9s	remaining: 937ms
1960:	learn: 0.1767366	total: 45.9s	remaining: 

In [22]:
def train_xgb_with_best_params(X_train, X_test, y_train, y_test):
    mlflow.sklearn.autolog(disable=True)

    param_grid = {
        'learning_rate': [0.01, 0.05],
        'max_depth': [10, 20],
        'n_estimators': [5000],
        'learning_rate': [0.01, 0.05, 0.1],
        'subsample': [0.5, 0.8, 1.0],
        'colsample_bytree': [0.5, 0.8, 1.0]
    }

    xgb_model = xgb.XGBRegressor(random_state=0)

    grid_search = GridSearchCV(estimator=xgb_model, param_grid=param_grid, scoring='neg_mean_absolute_error', cv=3)
    grid_search.fit(X_train, y_train)

    print("Best params:", grid_search.best_params_)

    best_params = grid_search.best_params_
    
    with mlflow.start_run(run_name="Final_XGB"):
    
        mlflow.set_tag("Final_XGB", "XGB")
        mlflow.log_params(best_params)
            
        final_xgb_model = xgb.XGBRegressor(**best_params)
        final_xgb_model.fit(X_train, y_train)
            
        y_pred = final_xgb_model.predict(X_test)
        y_pred = np.exp(y_pred)
        y_test_exp = np.exp(y_test)
    
        mae = mean_absolute_error(y_test_exp, y_pred)
            
        mlflow.log_metric("test_mae", mae)
        mlflow.sklearn.log_model(final_xgb_model, f"Final_XGB_Model")

In [23]:
train_xgb_with_best_params(X_train, X_test, y_train, y_test)

Best params: {'colsample_bytree': 0.5, 'learning_rate': 0.01, 'max_depth': 10, 'n_estimators': 5000, 'subsample': 1.0}
