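# Train a CatBoost gradient boosting model

Predict the number of available bikes per Vélib station (`numbikesavailable`, renamed `target`) from the engineered feature table, using 3-fold cross-validation and tracking the runs with MLflow.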

# 1. Imports

## 1.1 Packages

In [11]:
import sys

import pandas as pd


## 1.2 Options

In [12]:
# Automatically re-import modules (e.g. the project package added from ../src
# below) before each cell runs, so edits there are picked up without a kernel
# restart. Note: re-running `%load_ext` on an already-loaded extension only
# prints an "already loaded" message.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Make the project package under ../src importable from this notebook.
# Guarded so re-running the cell does not keep appending duplicate entries
# to sys.path.
# NOTE(review): the relative path is resolved against the kernel's working
# directory, so this assumes the notebook is run from its own folder — confirm.
if "../src" not in sys.path:
    sys.path.append("../src")

from velib_prediction.pipelines.train_model.mlflow import (  # noqa: E402
    create_mlflow_experiment,
)
from velib_prediction.pipelines.train_model.nodes import (  # noqa: E402
    get_split_train_val_cv,
    train_model_cv_mlflow,
)


## 1.3 Datasets

In [14]:
# Load the training feature table (presumably written by the upstream
# feature-engineering step, given the 04_feature folder).
df_train = pd.read_parquet("../data/04_feature/df_feat_train.parquet")
# Fixed random_state so the preview shows the same rows on every Restart & Run All.
df_train.sample(2, random_state=0)

Unnamed: 0,idx,stationcode,is_installed,capacity,numdocksavailable,numbikesavailable,mechanical,ebike,is_renting,is_returning,duedate,coordonnees_geo,code_insee_commune,date,duedate_year,duedate_month,duedate_day,duedate_weekday,duedate_weekend
1,90201729821482,9020,1,21,20,1,0,1,1,1,2024-10-25 01:58:02+00:00,"{'lat': 48.87929591733507, 'lon': 2.3373600840...",75056,2024-10-25,2024,10,25,4,0
4,70021729666465,7002,1,35,27,7,7,0,1,1,2024-10-23 06:54:25+00:00,"{'lat': 48.848563233059, 'lon': 2.3204218259346}",75056,2024-10-23,2024,10,23,2,0


# 2. Prepare datasets

In [15]:
# Drop the raw timestamp/geo columns that are not fed to the model (the
# `duedate_*` columns appear to already encode `duedate`), and expose the
# prediction target under the name `target`.
# Reassigning (instead of `inplace=True`) with `errors="ignore"` keeps this cell
# safe to re-run: with the in-place version a second run raised KeyError because
# the columns had already been dropped.
# NOTE(review): in the sample rows, `mechanical` + `ebike` equals
# `numbikesavailable`, and `numdocksavailable` is closely tied to it via
# `capacity`. Unless the training helpers exclude them, these columns leak the
# target — confirm.
df_train = (
    df_train
    .drop(columns=["duedate", "coordonnees_geo", "date"], errors="ignore")
    .rename(columns={"numbikesavailable": "target"})
)

In [16]:
# Split the prepared frame into cross-validation train/validation folds.
n_splits_cv = 3  # number of CV folds
list_df = get_split_train_val_cv(df_train, n_splits=n_splits_cv)

In [17]:
# Sanity check: expect one entry per CV fold (the split above uses n_splits=3).
n_folds = len(list_df)
assert n_folds == 3, f"expected 3 CV folds, got {n_folds}"
n_folds

3

In [18]:
# Columns handed to the model as categorical features: the station status
# flags (0/1 in the sample rows), the commune INSEE code and the weekend flag.
feat_cat = (
    "is_installed is_renting is_returning "
    "code_insee_commune duedate_weekend"
).split()

# 3. Train model

In [19]:
# Set up the MLflow experiment stored under ../reports/mlflow/ and show its id.
mlflow_experiment_kwargs = dict(
    experiment_folder_path="../reports/mlflow/",
    experiment_name="velib_prediction",
)
experiment_id = create_mlflow_experiment(**mlflow_experiment_kwargs)
experiment_id

'232251782342021390'

In [20]:
# CatBoost hyperparameters, unpacked as keyword arguments into the training call.
params_catboost = dict(iterations=100, depth=7)

In [21]:
# Cross-validated training run tracked under the MLflow experiment above.
# `verbose=10` prints the learn/test loss every 10 iterations (see log below).
cv_run_kwargs = {
    "run_name": "Test_catboost",
    "experiment_id": experiment_id,
    "list_train_valid": list_df,
    "feat_cat": feat_cat,
    "verbose": 10,
}
train_model_cv_mlflow(**cv_run_kwargs, **params_catboost)

Learning rate set to 0.179928
0:	learn: 7.8937182	test: 8.5246575	best: 8.5246575 (0)	total: 57.2ms	remaining: 5.66s
10:	learn: 2.6720617	test: 3.4482533	best: 3.4482533 (10)	total: 60.8ms	remaining: 492ms
20:	learn: 1.2439592	test: 2.1852701	best: 2.1852701 (20)	total: 64.9ms	remaining: 244ms
30:	learn: 0.7914949	test: 1.8097632	best: 1.8097632 (30)	total: 67.9ms	remaining: 151ms
40:	learn: 0.6124834	test: 1.6922961	best: 1.6922961 (40)	total: 72.5ms	remaining: 104ms
50:	learn: 0.4875047	test: 1.6169685	best: 1.6169685 (50)	total: 76.7ms	remaining: 73.7ms
60:	learn: 0.3924059	test: 1.5613925	best: 1.5613925 (60)	total: 80.6ms	remaining: 51.6ms
70:	learn: 0.3330397	test: 1.5347338	best: 1.5347338 (70)	total: 128ms	remaining: 52.4ms
80:	learn: 0.2771502	test: 1.5094332	best: 1.5094332 (80)	total: 134ms	remaining: 31.3ms
90:	learn: 0.2303746	test: 1.4894493	best: 1.4894493 (90)	total: 139ms	remaining: 13.7ms
99:	learn: 0.2043448	test: 1.4816550	best: 1.4816550 (99)	total: 143ms	remaining



Downloading artifacts:   0%|          | 0/7 [00:00<?, ?it/s]



Learning rate set to 0.200535
0:	learn: 7.8998344	test: 9.1757075	best: 9.1757075 (0)	total: 429us	remaining: 42.5ms
10:	learn: 2.0744299	test: 2.9789295	best: 2.9789295 (10)	total: 6.21ms	remaining: 50.2ms
20:	learn: 0.9610734	test: 1.9268642	best: 1.9268642 (20)	total: 9.98ms	remaining: 37.6ms
30:	learn: 0.6671856	test: 1.6409085	best: 1.6409085 (30)	total: 13.9ms	remaining: 31ms
40:	learn: 0.5185374	test: 1.5291467	best: 1.5291467 (40)	total: 17.4ms	remaining: 25ms
50:	learn: 0.4220873	test: 1.4829977	best: 1.4829977 (50)	total: 23.3ms	remaining: 22.4ms
60:	learn: 0.3545840	test: 1.4474840	best: 1.4462133 (59)	total: 26.6ms	remaining: 17ms
70:	learn: 0.3131330	test: 1.4279304	best: 1.4279304 (70)	total: 30.7ms	remaining: 12.5ms
80:	learn: 0.2784170	test: 1.4122493	best: 1.4122493 (80)	total: 35ms	remaining: 8.2ms
90:	learn: 0.2445950	test: 1.3976580	best: 1.3976580 (90)	total: 38.5ms	remaining: 3.81ms
99:	learn: 0.2214155	test: 1.3884084	best: 1.3878648 (98)	total: 41.5ms	remaining:

Downloading artifacts:   0%|          | 0/7 [00:00<?, ?it/s]



Learning rate set to 0.213688
0:	learn: 8.0950644	test: 8.1451373	best: 8.1451373 (0)	total: 629us	remaining: 62.3ms
10:	learn: 1.7462678	test: 2.3358006	best: 2.3358006 (10)	total: 4.76ms	remaining: 38.5ms
20:	learn: 0.7706031	test: 1.4830840	best: 1.4830840 (20)	total: 9.41ms	remaining: 35.4ms
30:	learn: 0.5472353	test: 1.3399778	best: 1.3399778 (30)	total: 13.3ms	remaining: 29.7ms
40:	learn: 0.4424857	test: 1.2863818	best: 1.2863818 (40)	total: 17.4ms	remaining: 25ms
50:	learn: 0.3497947	test: 1.2408420	best: 1.2408420 (50)	total: 23.1ms	remaining: 22.2ms
60:	learn: 0.3024409	test: 1.2060134	best: 1.2060134 (60)	total: 27.3ms	remaining: 17.4ms
70:	learn: 0.2623036	test: 1.1866811	best: 1.1866811 (70)	total: 31.5ms	remaining: 12.9ms
80:	learn: 0.2348587	test: 1.1675150	best: 1.1675150 (80)	total: 35.5ms	remaining: 8.32ms
90:	learn: 0.2062557	test: 1.1525349	best: 1.1525349 (90)	total: 39.8ms	remaining: 3.93ms
99:	learn: 0.1934106	test: 1.1467537	best: 1.1467537 (99)	total: 43.6ms	rem

Downloading artifacts:   0%|          | 0/7 [00:00<?, ?it/s]