## Breakdown Regression

We install three Python libraries:

catboost: for building the machine learning model.

featuretools: for automatically generating new features from the data.

optuna: for searching for the best model hyperparameters (controlled by the `SEARCH_BEST_PARAMS` flag defined below, which is set to `False` in this notebook).

Together, these libraries are used to build a model that predicts car breakdowns from the provided data.

In [1]:
!pip install catboost -q
!pip install featuretools -q
!pip install optuna -q

We suppress `FutureWarning` and `UserWarning` messages so they do not clutter the notebook output.

In [2]:
import warnings

warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", UserWarning)

In [3]:
import featuretools as ft
import numpy as np
import optuna
import pandas as pd
from catboost import CatBoostRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import KFold
from woodwork.logical_types import Age, Categorical, Datetime

In [4]:
optuna.logging.set_verbosity(optuna.logging.WARNING)

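This cell defines the constants used below: cross-validation and Optuna settings, the random seed, the categorical features, the target column, and the columns to exclude from the feature matrix. The model predicts a car's `target_reg` value from its other characteristics.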

In [5]:
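N_SPLITS = 3                # Number of cross-validation folds
N_TRIALS = 50               # Number of Optuna trials
RANDOM_SEED = 42            # Seed for reproducibility
SEARCH_BEST_PARAMS = False  # Whether to run the Optuna hyperparameter search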

CAT_FEATURES = ["model", "car_type", "fuel_type"]
TARGET_COL = "target_reg"

FEATURES_TO_DROP = ["car_id", "target_reg", "deviation_normal_count"]

## Loading the Data

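We load the data from five files: the training and test car tables (`car_train.csv`, `car_test.csv`), ride information (`rides_info.csv`), driver information (`driver_info.csv`), and repair information (`fix_info.csv`). The training and test samples come from separate files, so no split is needed here. We remember the car IDs of each sample and the training target, drop both target columns (`target_reg` and `target_class`) from the training table, and concatenate the training and test tables for joint processing.

As a result, the data is ready for feature engineering and for building the model that predicts car breakdowns.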
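The load-combine-split pattern can be illustrated on toy data. This is a minimal sketch, assuming hypothetical miniature versions of `car_train.csv` and `car_test.csv` (a few made-up rows and only some of the columns): the target is saved before it is dropped, and indexing by `car_id` after joint processing recovers each sample in its original order, which is what `all_data.loc[train_cars]` does later in the notebook.

```python
import pandas as pd

# Hypothetical miniature train/test car tables (made-up values)
car_info_train = pd.DataFrame({
    "car_id": ["a1", "b2", "c3"],
    "model": ["Kia Rio X-line", "VW Polo VI", "Skoda Rapid"],
    "target_reg": [40.1, 55.3, 38.7],
    "target_class": ["A", "B", "C"],
})
car_info_test = pd.DataFrame({
    "car_id": ["d4", "e5"],
    "model": ["Smart ForTwo", "Renault Sandero"],
})

# Remember IDs and the target before dropping the answer columns
train_cars = car_info_train["car_id"]
test_cars = car_info_test["car_id"]
y_train = car_info_train["target_reg"]

features_train = car_info_train.drop(columns=["target_reg", "target_class"])
all_data = pd.concat([features_train, car_info_test], ignore_index=True)

# ... joint feature engineering on all_data would happen here ...

# Split back by car_id, mirroring all_data.loc[train_cars] in the notebook
indexed = all_data.set_index("car_id")
train_data = indexed.loc[train_cars].reset_index()
test_data = indexed.loc[test_cars].reset_index()

print(train_data.shape, test_data.shape)
```

Processing train and test together guarantees that both get exactly the same generated columns; the saved ID lists are what make the split back possible.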

In [6]:
# Load the car tables (train includes the targets) and the related tables
car_info_train = pd.read_csv("car_train.csv")
car_info_test = pd.read_csv("car_test.csv")
rides_info = pd.read_csv("rides_info.csv")
driver_info = pd.read_csv("driver_info.csv")
fix_info = pd.read_csv("fix_info.csv")

# Remember car IDs and the training target before dropping it
train_cars = car_info_train["car_id"]
test_cars = car_info_test["car_id"]
y_train = car_info_train[TARGET_COL]

# Drop both target columns and combine train and test for joint feature engineering
car_info_train = car_info_train.drop(columns=[TARGET_COL, "target_class"])
all_data = pd.concat([car_info_train, car_info_test], ignore_index=True)

## Feature Engineering

We create new features from the data to improve the performance of the machine learning model that predicts car breakdowns.

What the code does:

1. Links the data: defines relationships between the tables (cars, rides, drivers, repairs) through shared key columns: `car_id` links cars to rides and repairs, and `user_id` links drivers to rides.
2. Generates new features: Deep Feature Synthesis (`ft.dfs`) automatically builds features by aggregating information from the related tables, for example the mean ride distance per car (`MEAN(rides.distance)`) or, at depth 2, the mean driver age across a car's rides (`MEAN(rides.drivers.age)`).
3. Cleans up the features: removes features that carry no useful information, namely those with only a single value.
4. Prepares the data for the model: splits the combined data back into training and test sets, drops unneeded columns, and extends the list of categorical features.

As a result, the code produces new, potentially useful features that can help the model predict car breakdowns more accurately.
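The aggregation in step 2 and the filtering in step 3 can be approximated by hand in pandas. This is an illustrative sketch on hypothetical toy tables, not how featuretools computes features internally: it reproduces the idea behind `COUNT(fixes)`, `MEAN(rides.distance)`, and the depth-2 feature `MEAN(rides.drivers.age)`, then drops single-value columns.

```python
import pandas as pd

# Hypothetical toy tables with the same key columns as the notebook's data
cars = pd.DataFrame({"car_id": ["a1", "b2"], "fuel_type": ["petrol", "petrol"]})
rides = pd.DataFrame({
    "car_id": ["a1", "a1", "b2"],
    "user_id": ["u1", "u2", "u1"],
    "distance": [100.0, 300.0, 50.0],
})
drivers = pd.DataFrame({"user_id": ["u1", "u2"], "age": [30, 50]})
fixes = pd.DataFrame({"car_id": ["a1", "b2", "b2"], "work_duration": [10, 20, 40]})

# Depth 1: aggregate a child table (rides, fixes) up to its parent (cars)
mean_distance = rides.groupby("car_id")["distance"].mean().rename("MEAN(rides.distance)")
count_fixes = fixes.groupby("car_id").size().rename("COUNT(fixes)")

# Depth 2: bring a driver attribute onto each ride, then aggregate per car
rides_with_age = rides.merge(drivers, on="user_id", how="left")
mean_driver_age = (
    rides_with_age.groupby("car_id")["age"].mean().rename("MEAN(rides.drivers.age)")
)

aggregates = pd.concat([mean_distance, count_fixes, mean_driver_age], axis=1)
features = cars.set_index("car_id").join(aggregates)

# Step 3 analogue: keep only columns with more than one distinct non-null value
features = features.loc[:, features.nunique() > 1]
print(features)
```

Here `fuel_type` is dropped because it is constant across cars, just as `remove_single_value_features` drops constant generated columns.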

Key points:

EntitySet: a container that holds the related tables and the relationships between them.

ft.dfs: the function that automatically generates new features.

CAT_FEATURES: the list of categorical features (extended with the generated `MODE(...)` features).

X_train, X_test: the feature matrices for training and testing the model.

In [7]:
# Build an EntitySet from the data sources and define their relationships
es = ft.EntitySet(id="car_data")

es = es.add_dataframe(
    dataframe_name="cars",
    dataframe=all_data,
    index="car_id",
    logical_types={
        "car_type": Categorical,
        "fuel_type": Categorical,
        "model": Categorical,
    },
)

es = es.add_dataframe(
    dataframe_name="rides",
    dataframe=rides_info.drop(["ride_id"], axis=1),
    index="index",
    time_index="ride_date",
)

es = es.add_dataframe(
    dataframe_name="drivers",
    dataframe=driver_info,
    index="user_id",
    logical_types={
        "sex": Categorical,
        "first_ride_date": Datetime,
        "age": Age,
    },
)

es = es.add_dataframe(
    dataframe_name="fixes",
    dataframe=fix_info,
    index="index",
    logical_types={
        "work_type": Categorical,
        "worker_id": Categorical,
    },
)

es = es.add_relationship("cars", "car_id", "rides", "car_id")
es = es.add_relationship("drivers", "user_id", "rides", "user_id")
es = es.add_relationship("cars", "car_id", "fixes", "car_id")

# Generate features with Deep Feature Synthesis
all_data, _ = ft.dfs(
    entityset=es,
    target_dataframe_name="cars",
    max_depth=2,
)

# Remove constant (single-value) features
all_data = ft.selection.remove_single_value_features(all_data)

train_data = all_data.loc[train_cars].reset_index()
test_data = all_data.loc[test_cars].reset_index()

X_train = train_data.drop(columns=FEATURES_TO_DROP, errors="ignore")
X_test = test_data.drop(columns=FEATURES_TO_DROP, errors="ignore")

# Extend the list of categorical features with the generated MODE(...) features
CAT_FEATURES += [col for col in X_train if col.startswith("MODE")]

In [8]:
X_train.head()

,model,car_type,fuel_type,car_rating,year_to_start,riders,year_to_work,MAX(rides.deviation_normal),MAX(rides.distance),MAX(rides.rating),MAX(rides.refueling),MAX(rides.ride_cost),MAX(rides.ride_duration),MAX(rides.speed_avg),MAX(rides.speed_max),MAX(rides.stop_times),MAX(rides.user_ride_quality),MEAN(rides.deviation_normal),MEAN(rides.distance),MEAN(rides.rating),MEAN(rides.refueling),MEAN(rides.ride_cost),MEAN(rides.ride_duration),MEAN(rides.speed_avg),MEAN(rides.speed_max),MEAN(rides.stop_times),MEAN(rides.user_ride_quality),MIN(rides.deviation_normal),MIN(rides.distance),MIN(rides.rating),MIN(rides.ride_cost),MIN(rides.ride_duration),MIN(rides.speed_avg),MIN(rides.speed_max),MIN(rides.user_ride_quality),SKEW(rides.deviation_normal),SKEW(rides.distance),SKEW(rides.rating),SKEW(rides.refueling),SKEW(rides.ride_cost),SKEW(rides.ride_duration),SKEW(rides.speed_avg),SKEW(rides.speed_max),SKEW(rides.stop_times),SKEW(rides.user_ride_quality),STD(rides.deviation_normal),STD(rides.distance),STD(rides.rating),STD(rides.refueling),STD(rides.ride_cost),...,SUM(rides.ride_cost),SUM(rides.ride_duration),SUM(rides.speed_avg),SUM(rides.speed_max),SUM(rides.stop_times),SUM(rides.user_ride_quality),COUNT(fixes),MAX(fixes.destroy_degree),MAX(fixes.work_duration),MEAN(fixes.destroy_degree),MEAN(fixes.work_duration),MIN(fixes.work_duration),MODE(fixes.work_type),MODE(fixes.worker_id),NUM_UNIQUE(fixes.work_type),NUM_UNIQUE(fixes.worker_id),SKEW(fixes.destroy_degree),SKEW(fixes.work_duration),STD(fixes.destroy_degree),STD(fixes.work_duration),SUM(fixes.destroy_degree),SUM(fixes.work_duration),MAX(rides.drivers.age),MAX(rides.drivers.user_rating),MAX(rides.drivers.user_rides),MAX(rides.drivers.user_time_accident),MEAN(rides.drivers.age),MEAN(rides.drivers.user_rating),MEAN(rides.drivers.user_rides),MEAN(rides.drivers.user_time_accident),MIN(rides.drivers.age),MIN(rides.drivers.user_rating),MIN(rides.drivers.user_rides),MIN(rides.drivers.user_time_accident),MODE(rides.DAY(ride_date)),MODE(rides.MONTH(ride_date)),MODE(rides.WEEKDAY(ride_date)),MODE(rides.drivers.sex),SKEW(rides.drivers.age),SKEW(rides.drivers.user_rating),SKEW(rides.drivers.user_rides),SKEW(rides.drivers.user_time_accident),STD(rides.drivers.age),STD(rides.drivers.user_rating),STD(rides.drivers.user_rides),STD(rides.drivers.user_time_accident),SUM(rides.drivers.age),SUM(rides.drivers.user_rating),SUM(rides.drivers.user_rides),SUM(rides.drivers.user_time_accident)
0,Kia Rio X-line,economy,petrol,3.78,2015,76163,2021,0.001,1849349.0,9.44,0.0,523483.0,37392.0,77.0,180.855726,20.0,11.035871,-0.120391,69777.646008,4.737759,0.0,20106.873563,1635.770115,44.66092,87.183965,3.62069,-0.90119,-9.0,71.606506,0.1,44.0,4.0,25.0,38.0,-10.501738,-8.207991,4.69761,0.129744,0.0,4.733686,4.508536,0.400501,0.845063,1.566092,-0.080287,0.86175,276803.751158,2.00167,0.0,77483.240255,...,3498596.0,284624.0,7771.0,14734.090155,630.0,-156.807066,35,9.0,56.0,3.048571,26.657143,7.0,reparking,LR,4,33,0.835907,0.826462,2.732847,10.171884,106.7,933.0,62.0,9.8,2626.0,77.0,33.511494,8.229885,828.034483,17.724138,18.0,6.2,5.0,0.0,1,1,2,1,0.268286,-0.170288,0.75113,1.813293,10.109652,0.611473,546.505545,17.294174,5831.0,1432.0,144078.0,2056.0
1,VW Polo VI,economy,petrol,3.9,2015,78218,2021,47.673,2762119.0,10.0,0.0,609068.0,38450.0,88.0,187.862734,3.0,32.610351,6.050011,103672.947347,4.480517,0.0,26813.614943,2223.178161,49.862069,89.474427,0.833333,14.018105,-25.088,152.315513,0.0,74.0,7.0,25.0,39.0,0.437053,0.415372,4.35128,0.337138,0.0,4.166274,3.739653,0.838893,0.801942,0.768856,0.387577,16.102649,379729.690043,2.829132,0.0,98135.97596,...,4665569.0,386833.0,8676.0,15568.550271,145.0,2439.150253,35,10.0,48.0,2.917143,24.942857,4.0,reparking,YH,5,34,0.997276,-0.296841,2.707233,8.574733,102.1,873.0,57.0,9.8,2821.0,23.0,34.988506,7.988506,924.804598,6.965517,18.0,6.7,7.0,0.0,1,1,2,0,-0.023846,0.23613,0.555751,0.74935,9.592259,0.574089,585.134416,4.91182,6088.0,1390.0,160916.0,1212.0
2,Renault Sandero,standart,petrol,6.3,2012,23340,2017,4.001,1744243.0,9.7,0.0,561786.0,39922.0,73.0,102.382857,3.0,11.766087,-2.223954,91285.399011,4.768391,0.0,23987.793103,2048.856322,44.005747,67.473599,0.804598,0.722771,-12.4,41.231207,0.1,17.0,2.0,25.0,36.0,-12.535368,-0.771747,3.849076,0.279173,0.0,4.155873,3.851109,0.393552,0.21992,0.792589,-0.400598,3.103029,334230.131905,1.78503,0.0,88190.939557,...,4173876.0,356501.0,7657.0,11740.406294,140.0,119.257195,35,10.0,59.0,3.74,26.142857,1.0,repair,AP,5,35,0.472628,0.671481,2.978077,13.040983,130.9,915.0,58.0,9.5,2617.0,25.0,32.83908,7.843103,940.04023,9.775862,18.0,6.2,1.0,0.0,1,1,2,1,0.173034,0.078351,0.374761,0.424371,9.277823,0.665963,531.541486,5.42647,5714.0,1364.7,163567.0,1701.0
3,Mercedes-Benz GLC,business,petrol,4.04,2011,1263,2020,48.956,2167931.0,10.0,0.0,1956795.0,39725.0,88.0,172.793237,3.0,3.93104,14.771948,94935.797502,3.88092,0.0,37839.086207,1943.511494,49.344828,86.661339,0.862069,-4.29037,-12.691,63.043831,0.1,73.0,6.0,25.0,37.0,-10.723544,0.527094,4.371611,0.189553,0.0,7.880247,4.15767,0.800698,0.821062,0.633564,0.688573,14.934988,372353.889033,2.326176,0.0,181004.405078,...,6584001.0,338171.0,8586.0,15079.073004,150.0,-746.524431,35,10.0,64.0,4.085714,28.771429,1.0,repair,LM,4,34,0.492743,0.63949,3.23775,14.764994,143.0,1007.0,57.0,10.0,2626.0,86.0,34.977011,8.524138,951.126437,19.991304,21.0,8.0,155.0,0.0,1,1,2,0,0.195056,0.661985,0.780447,1.680088,8.458641,0.431415,514.264373,20.782287,6086.0,1483.2,165496.0,2299.0
4,Renault Sandero,standart,petrol,4.7,2012,26428,2017,49.269,2167675.0,9.94,0.0,516203.0,36872.0,89.0,203.462289,3.0,3.344463,12.455678,80363.072784,4.181149,0.0,19888.431034,1597.793103,50.603448,86.263698,0.758621,-13.465342,-20.907,58.707324,0.1,25.0,3.0,26.0,31.0,-25.742137,0.021883,4.494934,0.169853,0.0,4.684331,4.284837,0.773263,1.036183,0.732646,0.31186,20.915945,331395.523759,2.359131,0.0,81541.27724,...,3460587.0,278016.0,8805.0,15009.883456,132.0,-2342.969445,35,10.0,65.0,3.88,28.028571,10.0,repair,CD,4,34,0.478043,1.341642,3.216758,12.659537,135.8,981.0,57.0,10.0,2481.0,72.0,34.298851,8.112069,896.229885,15.758333,18.0,6.5,5.0,0.0,1,1,2,0,0.014715,0.166042,0.448681,1.840095,10.327151,0.66864,570.284478,15.192322,5968.0,1411.5,155944.0,1891.0


In [9]:
X_test.head()

,model,car_type,fuel_type,car_rating,year_to_start,riders,year_to_work,MAX(rides.deviation_normal),MAX(rides.distance),MAX(rides.rating),MAX(rides.refueling),MAX(rides.ride_cost),MAX(rides.ride_duration),MAX(rides.speed_avg),MAX(rides.speed_max),MAX(rides.stop_times),MAX(rides.user_ride_quality),MEAN(rides.deviation_normal),MEAN(rides.distance),MEAN(rides.rating),MEAN(rides.refueling),MEAN(rides.ride_cost),MEAN(rides.ride_duration),MEAN(rides.speed_avg),MEAN(rides.speed_max),MEAN(rides.stop_times),MEAN(rides.user_ride_quality),MIN(rides.deviation_normal),MIN(rides.distance),MIN(rides.rating),MIN(rides.ride_cost),MIN(rides.ride_duration),MIN(rides.speed_avg),MIN(rides.speed_max),MIN(rides.user_ride_quality),SKEW(rides.deviation_normal),SKEW(rides.distance),SKEW(rides.rating),SKEW(rides.refueling),SKEW(rides.ride_cost),SKEW(rides.ride_duration),SKEW(rides.speed_avg),SKEW(rides.speed_max),SKEW(rides.stop_times),SKEW(rides.user_ride_quality),STD(rides.deviation_normal),STD(rides.distance),STD(rides.rating),STD(rides.refueling),STD(rides.ride_cost),...,SUM(rides.ride_cost),SUM(rides.ride_duration),SUM(rides.speed_avg),SUM(rides.speed_max),SUM(rides.stop_times),SUM(rides.user_ride_quality),COUNT(fixes),MAX(fixes.destroy_degree),MAX(fixes.work_duration),MEAN(fixes.destroy_degree),MEAN(fixes.work_duration),MIN(fixes.work_duration),MODE(fixes.work_type),MODE(fixes.worker_id),NUM_UNIQUE(fixes.work_type),NUM_UNIQUE(fixes.worker_id),SKEW(fixes.destroy_degree),SKEW(fixes.work_duration),STD(fixes.destroy_degree),STD(fixes.work_duration),SUM(fixes.destroy_degree),SUM(fixes.work_duration),MAX(rides.drivers.age),MAX(rides.drivers.user_rating),MAX(rides.drivers.user_rides),MAX(rides.drivers.user_time_accident),MEAN(rides.drivers.age),MEAN(rides.drivers.user_rating),MEAN(rides.drivers.user_rides),MEAN(rides.drivers.user_time_accident),MIN(rides.drivers.age),MIN(rides.drivers.user_rating),MIN(rides.drivers.user_rides),MIN(rides.drivers.user_time_accident),MODE(rides.DAY(ride_date)),MODE(rides.MONTH(ride_date)),MODE(rides.WEEKDAY(ride_date)),MODE(rides.drivers.sex),SKEW(rides.drivers.age),SKEW(rides.drivers.user_rating),SKEW(rides.drivers.user_rides),SKEW(rides.drivers.user_time_accident),STD(rides.drivers.age),STD(rides.drivers.user_rating),STD(rides.drivers.user_rides),STD(rides.drivers.user_time_accident),SUM(rides.drivers.age),SUM(rides.drivers.user_rating),SUM(rides.drivers.user_rides),SUM(rides.drivers.user_time_accident)
0,Skoda Rapid,economy,petrol,4.8,2013,42269,2019,36.661,3021921.0,8.91,0.0,336198.0,37356.0,100.0,195.454152,3.0,23.3624,16.664374,80893.068947,3.746207,0.0,14495.442529,1316.275862,51.482759,100.987907,0.827586,9.229144,-11.495,19.59009,0.1,19.0,3.0,25.0,38.446584,-10.297021,-0.593285,5.930494,0.184308,0.0,4.229911,4.641263,0.667228,0.453493,0.62451,-0.428845,12.162105,336344.827625,2.326693,0.0,51964.626097,...,2522207.0,229032.0,8958.0,17571.895793,144.0,1605.871135,35,8.0,49.0,2.765714,23.485714,4.0,reparking,BB,4,35,0.750094,0.312615,2.396682,8.125383,96.8,822.0,57.0,9.8,3287.0,25.0,33.706897,7.894253,849.212644,9.568966,18.0,5.8,10.0,0.0,1,1,2,1,0.166132,0.124766,0.735757,0.711919,10.280475,0.632886,580.542915,4.72276,5865.0,1373.6,147763.0,1665.0
1,Renault Sandero,standart,petrol,4.32,2015,90014,2016,0.085,2279221.0,10.0,0.0,629499.0,39644.0,87.0,181.538685,4.0,30.378596,-0.082115,113240.814175,4.318966,0.0,27256.063218,2215.45977,50.356322,92.576238,0.752874,12.871315,-7.925,38.730162,0.0,9.0,2.0,25.0,38.0,-1.288132,-8.796293,3.790663,0.398829,0.0,4.062327,3.697281,0.728297,0.453788,0.948945,0.579107,0.740807,400328.058724,2.512479,0.0,98130.403053,...,4742555.0,385490.0,8762.0,15737.960397,131.0,2239.60877,35,10.0,48.0,3.657143,25.428571,4.0,reparking,XX,3,34,0.640624,0.171686,2.939616,10.757448,128.0,890.0,62.0,9.8,3207.0,24.0,34.850575,7.998276,857.528736,6.896552,18.0,6.3,1.0,0.0,1,1,2,0,-0.002022,-0.1481,1.238131,0.932452,10.115551,0.617048,600.837062,4.87422,6064.0,1391.7,149210.0,1200.0
2,Smart ForTwo,economy,petrol,4.46,2015,82684,2017,3.567,1828386.0,10.0,0.0,439306.0,35985.0,79.0,118.440645,3.0,33.807138,-3.915529,53534.172552,5.134655,0.0,14099.586207,1207.988506,43.068966,64.670517,0.844828,17.404267,-25.434,48.442541,0.1,33.0,3.0,25.0,35.0,-2.728841,-0.842484,5.524373,0.058986,0.0,5.140408,5.261929,0.987392,0.515213,0.742176,0.08096,9.326006,256038.037148,1.964062,0.0,63868.149995,...,2453328.0,210190.0,7494.0,11252.669928,147.0,3028.34246,35,10.0,60.0,3.462857,28.342857,3.0,reparking,TF,4,34,0.82347,0.928644,3.168789,11.846923,121.2,992.0,60.0,9.6,3274.0,24.0,34.545977,8.027586,937.206897,7.557471,18.0,6.4,0.0,0.0,1,1,2,1,0.128273,-0.026184,0.670552,0.774293,10.16657,0.594423,596.865263,5.154786,6011.0,1396.8,163074.0,1315.0
3,Smart ForFour,economy,petrol,2.8,2014,68833,2021,0.002,1262004.0,9.68,0.0,359047.0,32641.0,90.0,112.829785,3.0,29.055693,-2.228115,53659.988217,4.617356,0.0,14594.12069,1269.683908,43.804598,67.004679,0.833333,18.699584,-9.999,45.220865,0.83,11.0,2.0,25.0,34.0,0.396298,-0.934398,4.180673,0.281985,0.0,4.323927,4.25492,1.045886,0.386085,0.511452,-0.746454,2.852085,206525.781584,1.9354,0.0,56369.354002,...,2539377.0,220925.0,7622.0,11658.814225,145.0,3141.530149,35,10.0,66.0,2.945714,25.885714,1.0,reparking,HV,4,32,1.224997,0.968681,2.882946,13.150199,103.1,906.0,59.0,9.6,3211.0,25.0,34.431034,7.840805,877.091954,9.064327,18.0,6.4,2.0,0.0,1,1,2,0,-0.125441,0.252493,0.988304,0.546549,9.414867,0.612586,528.869754,4.728072,5991.0,1364.3,152614.0,1550.0
4,Skoda Rapid,economy,petrol,6.56,2013,42442,2021,7.628,1950412.0,10.0,0.0,379361.0,29182.0,89.0,187.846088,4.0,25.107074,-22.83954,68749.984881,4.287471,0.0,15266.925287,1326.706897,52.574713,91.587136,0.87931,6.732205,-73.471,48.17461,0.0,27.0,3.0,25.0,39.0,-4.636368,-1.044706,4.665427,0.043945,0.0,4.395148,4.210903,0.720521,0.744545,0.61202,0.618211,19.169241,274258.068987,2.638532,0.0,59008.921743,...,2656445.0,230847.0,9148.0,15661.400199,153.0,1171.403606,35,8.0,44.0,3.697143,23.8,3.0,repair,TW,4,34,0.104028,-0.262908,2.644805,9.186947,129.4,833.0,63.0,9.5,2742.0,25.0,35.091954,8.001724,853.689655,7.068966,18.0,6.8,19.0,0.0,1,1,2,0,-0.027461,0.274267,0.740067,0.800319,9.535007,0.574252,574.605576,5.318292,6106.0,1392.3,148542.0,1230.0


In [16]:
# Prepare data for CatBoost: cast categorical columns to strings (CatBoost accepts only integer or string categories)
X_train_catboost = X_train.copy()
X_test_catboost = X_test.copy()
for cat_feature in CAT_FEATURES:
    X_train_catboost[cat_feature] = X_train_catboost[cat_feature].astype(str)
    X_test_catboost[cat_feature] = X_test_catboost[cat_feature].astype(str)


In [22]:
from sklearn.model_selection import train_test_split

# Split the training data into training and validation parts
# (note: this overwrites X_train and y_train with the training part of the split)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train_catboost,          # Features
    y_train,                   # Target
    test_size=0.2,             # Fraction held out for validation (20%)
    random_state=RANDOM_SEED,  # For reproducibility
)

print(f"Training data shape: {X_train.shape}")
print(f"Validation data shape: {X_valid.shape}")


Training data shape: (1869, 109)
Validation data shape: (468, 109)


In [23]:
from catboost import CatBoostRegressor
from sklearn.metrics import mean_absolute_error

# Model parameters
params = {
    "learning_rate": 0.1,          # Learning rate (step size)
    "depth": 6,                    # Tree depth
    "l2_leaf_reg": 3.0,            # L2 regularization of leaf values
    "random_seed": RANDOM_SEED,    # Fix randomness for reproducibility
    "eval_metric": "MAE",          # Metric for early stopping and best-iteration selection
    "verbose": 100,                # Log every 100 iterations
    "cat_features": CAT_FEATURES,  # Categorical features
    "od_type": "Iter",             # Overfitting detector type (early stopping)
    "od_wait": 50,                 # Iterations without improvement before stopping
}

# Initialize the model
model = CatBoostRegressor(**params)

# Train with early stopping on the validation set and keep the best iteration
model.fit(
    X_train,
    y_train,
    eval_set=(X_valid, y_valid),
    use_best_model=True,
)

# Predictions on the validation set
y_pred = model.predict(X_valid)

# Evaluate model quality
mae = mean_absolute_error(y_valid, y_pred)
print(f"Mean Absolute Error (MAE) on the validation set: {mae}")

# Save the model (optional)
model.save_model("catboost_regressor_model.cbm", format="cbm")


0:	learn: 12.5768374	test: 13.2367034	best: 13.2367034 (0)	total: 33.4ms	remaining: 33.4s
100:	learn: 4.1120709	test: 5.2284187	best: 5.2284187 (100)	total: 2.66s	remaining: 23.7s
Stopped by overfitting detector  (50 iterations wait)

bestTest = 5.205052981
bestIteration = 109

Shrink model to first 110 iterations.
Mean Absolute Error (MAE) on the validation set: 5.205053961135221
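The log above reflects CatBoost's `Iter` overfitting detector: training stops once the validation metric has not improved for `od_wait` = 50 consecutive iterations, and with `use_best_model=True` the model is shrunk to the best iteration (here 109, i.e. the first 110 trees). Below is a simplified pure-Python sketch of that stopping rule and of the MAE metric, MAE = (1/n) Σ |yᵢ − ŷᵢ|. The helper names `mae_manual` and `iter_early_stopping` are hypothetical, the score sequence is made up, and CatBoost's internal implementation may differ in details.

```python
from typing import Sequence, Tuple


def mae_manual(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error: average of |y_i - y_hat_i|."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def iter_early_stopping(eval_scores: Sequence[float], od_wait: int) -> Tuple[int, int]:
    """Return (best_iteration, stop_iteration) for a lower-is-better metric.

    Training stops at the first iteration where the best score is
    `od_wait` iterations old; if that never happens, it runs to the end.
    """
    best_iter = 0
    best_score = eval_scores[0]
    for i, score in enumerate(eval_scores):
        if score < best_score:
            best_score, best_iter = score, i
        elif i - best_iter >= od_wait:
            return best_iter, i
    return best_iter, len(eval_scores) - 1


# Made-up validation MAE curve: improves until iteration 3, then stalls
scores = [13.2, 9.0, 6.1, 5.2, 5.3, 5.25, 5.4, 5.3]
best, stopped = iter_early_stopping(scores, od_wait=3)
print(best, stopped)                                  # → 3 6
print(mae_manual([10.0, 20.0], [12.0, 17.0]))         # → 2.5
```

Because the same validation set both selects the best iteration and measures the reported MAE, that MAE is likely somewhat optimistic compared with performance on truly unseen data.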


In [25]:
from catboost import CatBoostRegressor

# Load the model from the file
loaded_model = CatBoostRegressor()
loaded_model.load_model("catboost_regressor_model.cbm")

# Predictions on the test data
y_pred_test = loaded_model.predict(X_test_catboost)

# Add the predictions to the test data
test_data["target_reg"] = y_pred_test

# Save the result to a file
test_data[["car_id", "target_reg"]].to_csv("submission_r.csv", index=False)

print("Predictions completed and saved to 'submission_r.csv'.")


Predictions completed and saved to 'submission_r.csv'.
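Before uploading, a quick consistency check of the submission file can catch mismatched IDs or missing predictions. This is a minimal sketch, assuming the two-column layout written above (`car_id`, `target_reg`); `check_submission` is a hypothetical helper, and the demo uses a toy frame instead of reading the real file.

```python
import pandas as pd


def check_submission(submission: pd.DataFrame, expected_ids) -> list:
    """Return a list of problems found in a car_id/target_reg submission."""
    problems = []
    if list(submission.columns) != ["car_id", "target_reg"]:
        problems.append(f"unexpected columns: {list(submission.columns)}")
        return problems
    if submission["car_id"].duplicated().any():
        problems.append("duplicate car_id values")
    if set(submission["car_id"]) != set(expected_ids):
        problems.append("car_id set differs from the test cars")
    if submission["target_reg"].isna().any():
        problems.append("missing predictions")
    return problems


# Toy example; in the notebook one could pass pd.read_csv("submission_r.csv") and test_cars
toy = pd.DataFrame({"car_id": ["d4", "e5"], "target_reg": [41.2, float("nan")]})
print(check_submission(toy, ["d4", "e5"]))  # → ['missing predictions']
```

An empty list means no problems were detected by these particular checks.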
