# Train gradient boosting model

# 1. Imports

## 1.1 Packages

In [60]:
import sys

import pandas as pd


## 1.2 Options

In [61]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [62]:
# Make the project's `src` directory importable from this notebook.
# The membership check keeps the cell idempotent: re-running it does not keep
# appending duplicate entries to sys.path.
SRC_DIR = '../src'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# Project-local helpers; E402 is silenced because these imports must come
# after the sys.path tweak above.
from velib_prediction.pipelines.train_model.mlflow import (  # noqa: E402
    create_mlflow_experiment,
)
from velib_prediction.pipelines.train_model.nodes import (  # noqa: E402
    add_lags_sma,
    get_split_train_val_cv,
    train_model_cv_mlflow,
)


In [63]:
# Lags passed to add_lags_sma when building the lag features.
lags_to_try = [1]

In [64]:
# Name of the timestamp column; passed to add_lags_sma as `feat_date`.
feat_date = "duedate"

## 1.3 Datasets

In [65]:
# Load the training feature table and preview two rows.
# The sample is seeded so that re-running the notebook shows the same preview
# instead of a different pair of rows on every execution.
df_train = pd.read_parquet("../data/04_feature/df_feat_train.parquet")
df_train.sample(2, random_state=42)

Unnamed: 0,idx,stationcode,is_installed,capacity,numdocksavailable,numbikesavailable,mechanical,ebike,is_renting,is_returning,duedate,code_insee_commune,duedate_year,duedate_month,duedate_day,duedate_weekday,duedate_weekend
7,131011729778185,13101,1,34,30,3,2,1,1,1,2024-10-24 13:56:25+00:00,75056,2024,10,24,3,0
6,70021730501106,7002,1,35,8,26,22,4,1,1,2024-11-01 22:45:06+00:00,75056,2024,11,1,4,0


In [66]:
# Expose the available-bikes count under the generic column name "target",
# which is the name given to add_lags_sma as `feat_target`.
df_train = df_train.rename(columns={"numbikesavailable": "target"})

# 2. Prepare datasets

In [67]:
# Add the lag features for every lag listed in `lags_to_try`.
# NOTE(review): the exact columns produced and the meaning of `n_shift` are
# defined in nodes.add_lags_sma — confirm there.
df_train = add_lags_sma(
    df_train,
    lags_to_try,
    feat_id="stationcode",
    feat_date=feat_date,
    feat_target="target",
    n_shift=5,
)

In [68]:
# Order the rows by the timestamp column, then remove that column so the raw
# timestamp is not part of the frame handed to the split/training steps.
# `feat_date` is used instead of a hard-coded "duedate" so this cell stays in
# sync with the column name given to add_lags_sma.
# Note: once the column has been dropped this cell cannot be re-run without
# re-running the cells above it.
df_train = df_train.sort_values(by=feat_date).drop(columns=feat_date)

In [69]:
# Build the cross-validation splits (3 requested); the result is later passed
# to train_model_cv_mlflow as `list_train_valid`.
# NOTE(review): presumably a list of train/validation pairs — confirm in nodes.py.
list_df = get_split_train_val_cv(df_train, n_splits=3)

In [70]:
# Sanity check: number of splits returned by the previous cell.
len(list_df)

3

In [71]:
# Columns passed to train_model_cv_mlflow as categorical features (`feat_cat`).
feat_cat = [
    "is_installed", "is_renting", "is_returning",
    "code_insee_commune",
    "duedate_weekend",
]

# 3. Train model

In [72]:
# Set up the MLflow experiment used for this notebook's runs and display its id.
# The tracking folder lives under the project's data/06_models directory.
mlflow_experiment_kwargs = {
    "experiment_folder_path": "../data/06_models/mlruns",
    "experiment_name": "velib_prediction",
}
experiment_id = create_mlflow_experiment(**mlflow_experiment_kwargs)
experiment_id

'708543812054389333'

In [73]:
# Model hyperparameters; unpacked as keyword arguments into
# train_model_cv_mlflow in the training cell.
params_catboost = dict(iterations=100, depth=7)

In [74]:
# Train on the cross-validation splits under the MLflow experiment identified
# by `experiment_id`; `verbose=10` matches the every-10-iterations log lines
# in the recorded output.
train_model_cv_mlflow(
    experiment_id=experiment_id,
    run_name="Test_catboost",
    list_train_valid=list_df,
    feat_cat=feat_cat,
    verbose=10,
    **params_catboost,
)



Learning rate set to 0.204216
0:	learn: 8.2879177	test: 7.5947855	best: 7.5947855 (0)	total: 58.9ms	remaining: 5.83s
10:	learn: 1.9229949	test: 1.9464476	best: 1.9464476 (10)	total: 65.6ms	remaining: 531ms
20:	learn: 0.7498882	test: 1.2583431	best: 1.2583431 (20)	total: 70.3ms	remaining: 264ms
30:	learn: 0.5504552	test: 1.1492577	best: 1.1492577 (30)	total: 74.4ms	remaining: 166ms
40:	learn: 0.4386838	test: 1.0906112	best: 1.0906112 (40)	total: 79.4ms	remaining: 114ms
50:	learn: 0.3628740	test: 1.0448733	best: 1.0448733 (50)	total: 85.2ms	remaining: 81.9ms
60:	learn: 0.3082720	test: 1.0204172	best: 1.0204172 (60)	total: 91.5ms	remaining: 58.5ms
70:	learn: 0.2599096	test: 1.0031083	best: 1.0030558 (69)	total: 96ms	remaining: 39.2ms
80:	learn: 0.2211429	test: 0.9847377	best: 0.9847377 (80)	total: 101ms	remaining: 23.6ms
90:	learn: 0.2015777	test: 0.9738408	best: 0.9738408 (90)	total: 107ms	remaining: 10.6ms
99:	learn: 0.1803074	test: 0.9675743	best: 0.9675743 (99)	total: 112ms	remaining:

Downloading artifacts:   0%|          | 0/7 [00:00<?, ?it/s]



Learning rate set to 0.227694
0:	learn: 7.8645281	test: 8.1732645	best: 8.1732645 (0)	total: 611us	remaining: 60.5ms
10:	learn: 1.4827616	test: 2.9689809	best: 2.9689809 (10)	total: 7.68ms	remaining: 62.1ms
20:	learn: 0.7437128	test: 2.2049915	best: 2.2049915 (20)	total: 13.6ms	remaining: 51.2ms
30:	learn: 0.5685405	test: 2.0965980	best: 2.0965980 (30)	total: 18.4ms	remaining: 41ms
40:	learn: 0.4510357	test: 2.0338918	best: 2.0338918 (40)	total: 23.2ms	remaining: 33.4ms
50:	learn: 0.3764778	test: 1.9528864	best: 1.9519673 (49)	total: 29ms	remaining: 27.9ms
60:	learn: 0.3139373	test: 1.9338885	best: 1.9243468 (58)	total: 34.1ms	remaining: 21.8ms
70:	learn: 0.2666599	test: 1.9357011	best: 1.9243468 (58)	total: 39.1ms	remaining: 16ms
80:	learn: 0.2340987	test: 1.9181725	best: 1.9139059 (78)	total: 44.3ms	remaining: 10.4ms
90:	learn: 0.2071453	test: 1.8987260	best: 1.8987260 (90)	total: 49.6ms	remaining: 4.9ms
99:	learn: 0.1880245	test: 1.8957842	best: 1.8942637 (97)	total: 54ms	remaining:

Downloading artifacts:   0%|          | 0/7 [00:00<?, ?it/s]



Learning rate set to 0.24266
0:	learn: 7.6173193	test: 7.5281582	best: 7.5281582 (0)	total: 1.32ms	remaining: 131ms
10:	learn: 1.3328502	test: 2.3253323	best: 2.3253323 (10)	total: 7.32ms	remaining: 59.2ms
20:	learn: 0.7378684	test: 1.7040665	best: 1.7040665 (20)	total: 13.6ms	remaining: 51.1ms
30:	learn: 0.6228388	test: 1.5180303	best: 1.5180303 (30)	total: 19ms	remaining: 42.2ms
40:	learn: 0.5400639	test: 1.3907635	best: 1.3907635 (40)	total: 24.6ms	remaining: 35.3ms
50:	learn: 0.4704229	test: 1.3065026	best: 1.3065026 (50)	total: 72.6ms	remaining: 69.8ms
60:	learn: 0.3872695	test: 1.2378014	best: 1.2378014 (60)	total: 79ms	remaining: 50.5ms
70:	learn: 0.3272282	test: 1.1814235	best: 1.1814235 (70)	total: 85.1ms	remaining: 34.8ms
80:	learn: 0.2875514	test: 1.1551786	best: 1.1551786 (80)	total: 91.2ms	remaining: 21.4ms
90:	learn: 0.2532266	test: 1.1375581	best: 1.1375581 (90)	total: 97.3ms	remaining: 9.62ms
99:	learn: 0.2278953	test: 1.1260603	best: 1.1260603 (99)	total: 103ms	remaini

Downloading artifacts:   0%|          | 0/7 [00:00<?, ?it/s]