In [28]:
import pandas as pd
pd.set_option('display.max_columns', None)
from sklearn.metrics import mean_squared_error, r2_score, median_absolute_error, mean_absolute_error
import numpy as np
import torch
import torch.nn as nn
from pytorch_tabnet.tab_model import TabNetRegressor
from pytorch_tabnet.metrics import Metric
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.preprocessing import OneHotEncoder, LabelEncoder, StandardScaler
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.compose import TransformedTargetRegressor
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, InputLayer
from tensorflow.keras.callbacks import EarlyStopping
import time

import warnings
warnings.filterwarnings("ignore", message="'pin_memory' argument is set as true but not supported on MPS")

In [2]:
df = pd.read_csv('final_data.csv', low_memory=False)
df.salary_gross.fillna(False, inplace=True)
df.shape

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df.salary_gross.fillna(False, inplace=True)
  df.salary_gross.fillna(False, inplace=True)


(709524, 43)

In [3]:
def culc_metrics(y_test, y_pred):
    test_mse = mean_squared_error(y_test, y_pred)
    rmse = test_mse**0.5
    r2 = r2_score(y_test, y_pred)
    mae = mean_absolute_error(y_test, y_pred)

    def symmetric_mean_absolute_percentage_error(y_true, y_pred):
        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        smape = 100 * np.mean(2 * np.abs(y_true - y_pred) / (np.abs(y_true) + np.abs(y_pred)))
        return smape

    smape = symmetric_mean_absolute_percentage_error(y_test, y_pred)

    medae = median_absolute_error(y_test, y_pred)

    print(f'Корень из среднеквадратичной ошибки (RMSE): {rmse}')
    print(f"R² Score: {r2}")
    print(f"Средняя абсолютная ошибка (MAE): {mae}")
    print(f"Средняя абсолютная процентная ошибка (SMAPE): {smape:.2f}%")
    print(f"Медианная абсолютная ошибка (MedAE): {medae}")

In [4]:
cat_columns = ['premium', 'has_test', 'response_letter_required', 'area_name', 'salary_currency', 'salary_gross', 'type_name', 'address_city', 'address_metro_station_name', 'address_metro_line_name', 'address_metro_stations_0_line_name', 'archived', 'employer_name', 'employer_accredited_it_employer', 'employer_trusted', 'schedule_name', 'accept_temporary', 'professional_roles_0_name', 'accept_incomplete_resumes', 'experience_name', 'employment_name', 'address_metro_stations_3_station_name', 'address_metro_stations_3_line_name', 'working_time_intervals_0_name', 'working_time_modes_0_name', 'working_days_0_name', 'branding_type', 'branding_tariff', 'department_name', 'insider_interview_id', 'brand_snippet_logo', 'brand_snippet_picture', 'brand_snippet_background_color', 'brand_snippet_background_gradient_angle', 'brand_snippet_background_gradient_color_list_0_position', 'brand_snippet_background_gradient_color_list_1_position', 'category']
text_columns = ['name', 'snippet_requirement', 'snippet_responsibility']
num_columns = ['name_length', 'length']

In [21]:
scaler = StandardScaler()
num_df = pd.DataFrame(scaler.fit_transform(df[num_columns]), columns=num_columns)

In [22]:
label_columns = []
ohe_columns = []

for column in cat_columns:
    if df[column].nunique() > 10:
        label_columns.append(column)
    else:
        ohe_columns.append(column)

to_bool = list(df[cat_columns].select_dtypes(include=['bool']).columns)
df[['salary_gross', 'employer_accredited_it_employer']] = df[['salary_gross', 'employer_accredited_it_employer']].astype(bool).astype(int)
df[to_bool] = df[to_bool].astype(int)

ohe = OneHotEncoder(sparse_output=False, drop='first')
ohe_encoded = ohe.fit_transform(df[ohe_columns])
ohe_feature_names = ohe.get_feature_names_out(ohe_columns).tolist()
encoded_ohe_data = pd.DataFrame(ohe_encoded, columns=ohe_feature_names)

embedding_dim = 5
embeddings = {}

for col in label_columns:
    unique_values = df[col].unique()
    value_to_idx = {v: i for i, v in enumerate(unique_values)}
    df[col+'_idx'] = df[col].map(value_to_idx)

    num_embeddings = len(unique_values)
    embedding_layer = nn.Embedding(num_embeddings, embedding_dim)

    embeddings[col] = {
        'value_to_idx': value_to_idx,
        'embedding': embedding_layer,
        'num_embeddings': num_embeddings
    }

embedded_data = []
for col in label_columns:
    indices = torch.tensor(df[col+'_idx'].values, dtype=torch.long)
    embedded = embeddings[col]['embedding'](indices).detach().numpy()
    embedded_cols = [f"{col}_embed_{i}" for i in range(embedding_dim)]
    embedded_df = pd.DataFrame(embedded, columns=embedded_cols)
    embedded_data.append(embedded_df)

embedded_data = pd.concat(embedded_data, axis=1)
final_data = pd.concat([encoded_ohe_data, embedded_data], axis=1)

for col in label_columns:
    df.drop(col+'_idx', axis=1, inplace=True)

In [7]:
final_data.shape

(709524, 111)

In [8]:
final_data

Unnamed: 0,premium_1,has_test_1,response_letter_required_1,salary_currency_BYR,salary_currency_EUR,salary_currency_GEL,salary_currency_KGS,salary_currency_KZT,salary_currency_RUR,salary_currency_USD,salary_currency_UZS,salary_gross_1,type_name_Закрытая,type_name_Открытая,type_name_Рекламная,archived_1,employer_trusted_1,schedule_name_Гибкий график,schedule_name_Полный день,schedule_name_Сменный график,schedule_name_Удаленная работа,accept_temporary_1,accept_incomplete_resumes_1,experience_name_Нет опыта,experience_name_От 1 года до 3 лет,experience_name_От 3 до 6 лет,employment_name_Полная занятость,employment_name_Проектная работа,employment_name_Стажировка,employment_name_Частичная занятость,working_time_intervals_0_name_Можно сменами по 4-6 часов в день,working_time_modes_0_name_С началом дня после 16:00,working_days_0_name_По субботам и воскресеньям,branding_type_MAKEUP,branding_type_Unknown,branding_tariff_Unknown,insider_interview_id_1,brand_snippet_logo_Unknown,brand_snippet_picture_Unknown,brand_snippet_background_color_#EF3124,brand_snippet_background_color_#FF5B29,brand_snippet_background_color_Unknown,brand_snippet_background_gradient_angle_134.0,brand_snippet_background_gradient_angle_200.0,brand_snippet_background_gradient_angle_206.43,brand_snippet_background_gradient_angle_67.0,brand_snippet_background_gradient_angle_Unknown,brand_snippet_background_gradient_color_list_0_position_0.0,brand_snippet_background_gradient_color_list_0_position_0.52,brand_snippet_background_gradient_color_list_0_position_6.96,brand_snippet_background_gradient_color_list_0_position_Unknown,brand_snippet_background_gradient_color_list_1_position_40.0,brand_snippet_background_gradient_color_list_1_position_88.86,brand_snippet_background_gradient_color_list_1_position_90.95,brand_snippet_background_gradient_color_list_1_position_94.48,brand_snippet_background_gradient_color_list_1_position_Unknown,area_name_embed_0,area_name_embed_1,area_name_embed_2,area_name_embed_3,area_name_embed_4,address_city_embed_0,address_city_embed_1,address_city_embed_2,address_city_embed_3,address_city_embed_4,address_metro_station_name_embed_0,address_metro_station_name_embed_1,address_metro_station_name_embed_2,address_metro_station_name_embed_3,address_metro_station_name_embed_4,address_metro_line_name_embed_0,address_metro_line_name_embed_1,address_metro_line_name_embed_2,address_metro_line_name_embed_3,address_metro_line_name_embed_4,address_metro_stations_0_line_name_embed_0,address_metro_stations_0_line_name_embed_1,address_metro_stations_0_line_name_embed_2,address_metro_stations_0_line_name_embed_3,address_metro_stations_0_line_name_embed_4,employer_name_embed_0,employer_name_embed_1,employer_name_embed_2,employer_name_embed_3,employer_name_embed_4,professional_roles_0_name_embed_0,professional_roles_0_name_embed_1,professional_roles_0_name_embed_2,professional_roles_0_name_embed_3,professional_roles_0_name_embed_4,address_metro_stations_3_station_name_embed_0,address_metro_stations_3_station_name_embed_1,address_metro_stations_3_station_name_embed_2,address_metro_stations_3_station_name_embed_3,address_metro_stations_3_station_name_embed_4,address_metro_stations_3_line_name_embed_0,address_metro_stations_3_line_name_embed_1,address_metro_stations_3_line_name_embed_2,address_metro_stations_3_line_name_embed_3,address_metro_stations_3_line_name_embed_4,department_name_embed_0,department_name_embed_1,department_name_embed_2,department_name_embed_3,department_name_embed_4,category_embed_0,category_embed_1,category_embed_2,category_embed_3,category_embed_4
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,-0.581004,0.010968,-0.752548,0.446675,-1.263296,1.093199,0.043522,-1.025787,1.405245,-1.198341,0.088235,-0.244736,-1.194631,0.306374,0.661689,-0.680968,1.711364,0.694497,-0.831994,1.015268,1.713466,2.309882,-1.969667,1.802627,-0.092293,-0.762315,-0.947418,0.468338,0.164583,-0.806373,0.672523,0.222038,-0.798875,0.784567,0.921089,0.505244,0.986912,-0.358342,2.533038,0.621478,-0.326178,0.881995,0.380914,-2.32725,-0.347610,-1.270978,1.057618,-0.118360,0.472299,-0.503427,1.316130,-0.303547,-0.744560,0.161952,0.734747
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,-0.581004,0.010968,-0.752548,0.446675,-1.263296,1.093199,0.043522,-1.025787,1.405245,-1.198341,0.061260,0.223090,0.556677,1.389389,-0.284192,0.139569,-0.565316,-1.337491,2.620162,1.853694,-2.850189,-1.745484,-0.106043,-0.713799,-0.569794,-0.040465,0.692507,1.404474,0.099583,0.149277,2.387161,-1.145858,0.120309,1.738032,-1.040383,0.505244,0.986912,-0.358342,2.533038,0.621478,-0.326178,0.881995,0.380914,-2.32725,-0.347610,-1.270978,1.057618,-0.118360,0.472299,-0.503427,-0.737019,-0.424337,1.157046,0.006620,1.282567
2,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,-0.581004,0.010968,-0.752548,0.446675,-1.263296,1.093199,0.043522,-1.025787,1.405245,-1.198341,1.369008,-0.253909,0.001691,-0.054280,1.493907,-0.518892,-0.489204,0.589976,-0.774031,-1.433781,0.010874,0.225056,0.185233,-0.315569,-2.640739,0.227236,0.636689,0.340636,-0.228102,0.167382,2.387161,-1.145858,0.120309,1.738032,-1.040383,0.505244,0.986912,-0.358342,2.533038,0.621478,-0.326178,0.881995,0.380914,-2.32725,-0.347610,-1.270978,1.057618,-0.118360,0.472299,-0.503427,-0.032980,2.116684,-0.652107,0.048510,-1.477048
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.313661,0.522936,0.939813,-0.766602,0.097707,1.093199,0.043522,-1.025787,1.405245,-1.198341,-1.422027,-0.233865,1.569913,1.974113,-0.001567,-0.651363,1.266199,-0.385866,-2.035225,-0.034178,0.580323,1.867318,0.519559,0.391230,-1.793014,-1.449637,-0.377762,0.548368,-0.755356,-1.423038,2.387161,-1.145858,0.120309,1.738032,-1.040383,-0.766979,-0.494289,0.416491,1.155987,1.726749,-0.657395,0.005542,-0.288030,-0.63776,0.082974,-1.270978,1.057618,-0.118360,0.472299,-0.503427,1.316130,-0.303547,-0.744560,0.161952,0.734747
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,-0.254966,0.531895,-0.022933,1.375687,1.038304,1.093199,0.043522,-1.025787,1.405245,-1.198341,0.269800,0.069911,0.857201,-0.033175,0.731596,0.184655,-1.372097,-0.474219,0.659218,0.554109,-0.023980,0.575849,-0.435903,0.684621,-0.359004,0.349810,-0.594648,0.380679,0.939181,-2.074155,-1.661391,0.828505,0.502644,-0.550386,1.133081,0.505244,0.986912,-0.358342,2.533038,0.621478,-0.326178,0.881995,0.380914,-2.32725,-0.347610,-1.270978,1.057618,-0.118360,0.472299,-0.503427,-0.911008,-0.004136,-1.045973,-0.177252,0.914669
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
709519,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,-0.396277,0.436552,0.164031,0.285438,-0.865460,0.901887,0.456337,0.112733,-0.948612,0.652436,0.468437,1.023695,0.942781,0.694319,-1.342757,-0.528317,0.354969,1.256162,-0.416956,1.243720,-1.049524,0.168407,0.524139,0.445269,1.885846,-1.733976,-1.287089,-0.495999,-1.492536,0.487114,-0.281298,1.750975,-1.495718,1.325013,-0.683672,0.505244,0.986912,-0.358342,2.533038,0.621478,-0.326178,0.881995,0.380914,-2.32725,-0.347610,-0.137746,-0.548283,0.733579,0.642236,0.208316,0.934244,0.214669,-1.017204,-0.075515,-0.000091
709520,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,-0.396277,0.436552,0.164031,0.285438,-0.865460,0.901887,0.456337,0.112733,-0.948612,0.652436,-0.805078,-1.100066,-0.677999,-0.480609,0.118676,1.933394,2.410229,0.930200,-0.086898,1.439368,0.010260,0.886964,0.377089,-0.195276,0.296280,-0.653520,1.344557,0.808592,0.368274,-0.549407,-0.641724,-1.758996,-0.248562,0.783574,0.055397,0.505244,0.986912,-0.358342,2.533038,0.621478,-0.326178,0.881995,0.380914,-2.32725,-0.347610,-1.270978,1.057618,-0.118360,0.472299,-0.503427,-1.533051,-1.405677,0.049055,-1.356282,-0.311959
709521,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,-0.581004,0.010968,-0.752548,0.446675,-1.263296,1.093199,0.043522,-1.025787,1.405245,-1.198341,0.468437,1.023695,0.942781,0.694319,-1.342757,-0.528317,0.354969,1.256162,-0.416956,1.243720,-1.049524,0.168407,0.524139,0.445269,1.885846,1.245667,0.068925,-0.989368,-0.510114,0.986052,-0.281298,1.750975,-1.495718,1.325013,-0.683672,0.505244,0.986912,-0.358342,2.533038,0.621478,-0.326178,0.881995,0.380914,-2.32725,-0.347610,-1.270978,1.057618,-0.118360,0.472299,-0.503427,0.934244,0.214669,-1.017204,-0.075515,-0.000091
709522,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,-0.475951,-1.061715,-2.106523,-1.051932,-0.774172,1.254188,0.067204,-0.008181,1.377721,0.700979,-1.167074,-0.408475,0.079876,-0.288539,1.723166,1.436755,1.572393,-0.947230,0.143442,-0.255307,0.035718,-1.984588,-0.015106,-0.100081,2.646133,-0.849984,-0.623273,-0.210358,-1.856199,0.935517,0.672523,0.222038,-0.798875,0.784567,0.921089,0.505244,0.986912,-0.358342,2.533038,0.621478,-0.326178,0.881995,0.380914,-2.32725,-0.347610,-1.270978,1.057618,-0.118360,0.472299,-0.503427,1.316130,-0.303547,-0.744560,0.161952,0.734747


In [23]:
X_train, X_test_val, y_train, y_test_val, = train_test_split(final_data, df['salary'], test_size=0.4, random_state=12345)
X_test, X_val, y_test, y_val = train_test_split(X_test_val, y_test_val, test_size=0.5, random_state=12345)

print(f'Размеры выборок: Обучающая {X_train.shape}, Валидационная {X_test.shape}, Тестовая {X_val.shape}')

Размеры выборок: Обучающая (425714, 111), Валидационная (141905, 111), Тестовая (141905, 111)


### Случайный лес с эмбедингами

In [10]:
model_dtr = DecisionTreeRegressor(random_state=12345)

regressor = TransformedTargetRegressor(
    regressor=model_dtr,
    func=np.log,
    inverse_func=np.exp
)


param_grid = {
    'regressor__max_depth': [10, 11, 12, 13, 14, 15, 16],
    'regressor__min_samples_split': [2, 5],
    'regressor__min_samples_leaf': [1, 2]
}

grid_search = GridSearchCV(
    estimator=regressor,
    param_grid=param_grid,
    cv=3,
    scoring='neg_mean_squared_error',
    n_jobs=-1,
    verbose=2
)

grid_search.fit(X_train, y_train)

best_model = grid_search.best_estimator_
best_params = grid_search.best_params_
print(f'Лучшие параметры: {best_params}')


y_pred = best_model.predict(X_test)
culc_metrics(y_test, y_pred)

Fitting 3 folds for each of 28 candidates, totalling 84 fits
Лучшие параметры: {'regressor__max_depth': 15, 'regressor__min_samples_leaf': 2, 'regressor__min_samples_split': 2}
Корень из среднеквадратичной ошибки (RMSE): 48970.264067754295
R² Score: 0.5416319993914924
Средняя абсолютная ошибка (MAE): 25268.431367611895
Средняя абсолютная процентная ошибка (SMAPE): 28.27%
Медианная абсолютная ошибка (MedAE): 14970.763403016885


Случайному лесу создание эмбедингов не принесло никакой информации и никак не улучшило обобщающую способность. Продолжем использщовать изначальный DF.

#Полносвязная нейронная сеть

Создадим свою нейронную сеть основаную на **Sequentia**
И протестируем на разных вариантах архитектур

### Полносвязная нейронная сеть с эмбедингами

In [6]:
def build_and_train_model(architecture, X_train, y_train, X_test, y_test, epochs=100, batch_size=32):
    """
    Строит и обучает модель с заданной архитектурой

    Параметры:
    architecture - список, определяющий архитектуру сети (количество нейронов в каждом слое)
    X_train, y_train - обучающие данные
    X_test, y_test - тестовые данные
    epochs - количество эпох обучения
    batch_size - размер батча

    Возвращает:
    model - обученная модель
    history - история обучения
    metrics - словарь с метриками на тестовых данных
    train_time - время обучения
    """

    input_shape = X_train.shape[1]

    model = Sequential()

    model.add(InputLayer(shape=(input_shape,)))
    model.add(Dense(architecture[0], activation='relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))

    for neurons in architecture[1:]:
        model.add(Dense(neurons, activation='relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.2))

    model.add(Dense(1))

    model.compile(optimizer='adam', loss='mse', metrics=['mae'])

    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

    start_time = time.time()
    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=[early_stopping],
        verbose=0
    )
    train_time = time.time() - start_time

    y_pred = model.predict(X_test).flatten()

    metrics = {
        'MAE': mean_absolute_error(y_test, y_pred),
        'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)),
        'R2': r2_score(y_test, y_pred)
    }

    return model, history, y_pred, train_time, metrics

In [24]:
architectures = {
    'small': [64, 32],
    'medium': [128, 64, 32],
    'large': [256, 128, 64, 32],
    'wide': [512, 256],
    'deep': [64, 64, 64, 64, 64]
}

results = {}

for name, arch in architectures.items():
    print(f"\nTraining {name} architecture: {arch}")
    model, history, y_pred, train_time, metrics = build_and_train_model(
        arch, X_train, y_train, X_test, y_test
    )

    results[name] = {
        'architecture': arch,
        'train_time': train_time,
        'metrics': metrics,
        'epochs_trained': len(history.history['loss'])
    }

    print(f"Training time: {train_time:.2f}s")
    culc_metrics(y_test, y_pred)


Training small architecture: [64, 32]
[1m4435/4435[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 1ms/step
Training time: 1012.70s
Корень из среднеквадратичной ошибки (RMSE): 50941.297970951884
R² Score: 0.5039911649265851
Средняя абсолютная ошибка (MAE): 29103.112701003913
Средняя абсолютная процентная ошибка (SMAPE): 32.33%
Медианная абсолютная ошибка (MedAE): 20343.0625

Training medium architecture: [128, 64, 32]
[1m4435/4435[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 1ms/step
Training time: 1239.07s
Корень из среднеквадратичной ошибки (RMSE): 63997.72743782658
R² Score: 0.2171500930664486
Средняя абсолютная ошибка (MAE): 29133.000331506486
Средняя абсолютная процентная ошибка (SMAPE): 31.22%
Медианная абсолютная ошибка (MedAE): 18446.55078125

Training large architecture: [256, 128, 64, 32]
[1m4435/4435[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 2ms/step
Training time: 903.88s
Корень из среднеквадратичной ошибки (RMSE): 99553.35289036723
R² Score: 

### TabNet с эмбедингами

In [24]:
X_train = X_train.to_numpy()
X_val = X_val.to_numpy()
X_test = X_test.to_numpy()

y_train = y_train.to_numpy().reshape(-1, 1)
y_val = y_val.to_numpy().reshape(-1, 1)
y_test = y_test.to_numpy().reshape(-1, 1)

class SMAPE(Metric):
    def __init__(self):
        self._name = "smape"
        self._maximize = False

    def __call__(self, y_true, y_pred):
        return 100 * np.mean(2 * np.abs(y_true - y_pred) / (np.abs(y_true) + np.abs(y_pred)))

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tabnet_params = {
    "n_d": 8,
    "n_a": 8,
    "n_steps": 3,
    "gamma": 1.3,
    "lambda_sparse": 1e-3,
    "optimizer_fn": torch.optim.Adam,
    "optimizer_params": dict(lr=2e-2),
    "mask_type": "sparsemax",
    "scheduler_params": dict(
        mode="min",
        patience=5,
        min_lr=1e-5,
        factor=0.9,
    ),
    "scheduler_fn": torch.optim.lr_scheduler.ReduceLROnPlateau,
    "seed": 42,
    "verbose": 10
}

model = TabNetRegressor(**tabnet_params, device_name=device)

model.fit(
    X_train=X_train,
    y_train=y_train,
    eval_set=[(X_train, y_train), (X_val, y_val)],
    eval_name=['train', 'val'],
    eval_metric=['rmse', 'mae', SMAPE],
    max_epochs=50,
    patience=20,
    batch_size=1024,
    virtual_batch_size=128,
    num_workers=0,
    drop_last=False,
    loss_fn=torch.nn.functional.mse_loss,
)

y_pred = model.predict(X_test)

culc_metrics(y_test, y_pred)



epoch 0  | loss: 20437558003.24312| train_rmse: 141221.62524| train_mae: 85872.05342| train_smape: 184.71002| val_rmse: 115379.20418| val_mae: 85718.8759| val_smape: 184.79026|  0:00:26s
epoch 10 | loss: 10820886075.10296| train_rmse: 104044.71899| train_mae: 28615.76256| train_smape: 31.64873| val_rmse: 65343.81696| val_mae: 28682.34666| val_smape: 31.78737|  0:04:51s
epoch 20 | loss: 10630551067.8584| train_rmse: 101629.35374| train_mae: 31064.74035| train_smape: 34.09568| val_rmse: 60451.6293| val_mae: 31113.5129| val_smape: 34.13363|  0:09:12s
epoch 30 | loss: 10159129009.73706| train_rmse: 99112.52429| train_mae: 27355.95985| train_smape: 30.8259 | val_rmse: 56314.21829| val_mae: 27496.50355| val_smape: 30.97229|  0:13:24s
epoch 40 | loss: 9945627975.79911| train_rmse: 99790.60418| train_mae: 32431.17795| train_smape: 35.30642| val_rmse: 56739.63909| val_mae: 32498.53681| val_smape: 35.4043 |  0:17:36s
Stop training because you reached max_epochs = 50 with best_epoch = 36 and best



Корень из среднеквадратичной ошибки (RMSE): 49171.296497332296
R² Score: 0.5378608956464286
Средняя абсолютная ошибка (MAE): 26366.728262877266
Средняя абсолютная процентная ошибка (SMAPE): 29.62%
Медианная абсолютная ошибка (MedAE): 17047.4140625


#### Вывод
С эмбедингами лучше
На cpu быстрее чем на mps

### Оптимизация TabNet

In [None]:
param_grid = {
    'n_d': [8, 16, 32],
    'n_a': [8, 16, 32],
    'n_steps': [3, 5, 7],
    'gamma': [1.0, 1.3, 1.5],
    'lambda_sparse': [0, 1e-4, 1e-3],
}

from sklearn.base import BaseEstimator, RegressorMixin

class TabNetWrapper(BaseEstimator, RegressorMixin):
    def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, lambda_sparse=1e-3):
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps
        self.gamma = gamma
        self.lambda_sparse = lambda_sparse
        self.model = None
        
    def fit(self, X, y):
        self.model = TabNetRegressor(
            n_d=self.n_d,
            n_a=self.n_a,
            n_steps=self.n_steps,
            gamma=self.gamma,
            lambda_sparse=self.lambda_sparse
        )
        self.model.fit(X, y.reshape(-1, 1))
        return self
    
    def predict(self, X):
        return self.model.predict(X).flatten()

model = TabNetWrapper()

grid_search = GridSearchCV(
    estimator=model,
    param_grid=param_grid,
    cv=3,
    scoring='neg_mean_squared_error',
    verbose=2,
    n_jobs=-1
)

grid_search.fit(X_train, y_train)

best_model = grid_search.best_estimator_
best_params = grid_search.best_params_
print(f'Лучшие параметры: {best_params}')


y_pred = best_model.predict(X_test)
culc_metrics(y_test, y_pred)

print("Лучшие параметры:", grid_search.best_params_)

Fitting 3 folds for each of 243 candidates, totalling 729 fits




epoch 0  | loss: 22581979623.97112|  0:00:11s
epoch 0  | loss: 14762814863.24909|  0:00:11s
epoch 0  | loss: 24332797271.79783|  0:00:11s
epoch 0  | loss: 24249658364.30325|  0:00:12s
epoch 0  | loss: 14640936697.53069|  0:00:11s
epoch 0  | loss: 22490206370.65704|  0:00:12s
epoch 0  | loss: 14670376605.11191|  0:00:17s
epoch 0  | loss: 22508431393.27076|  0:00:17s
epoch 0  | loss: 24246481058.65704|  0:00:18s
epoch 1  | loss: 14061122340.04332|  0:00:23s
epoch 1  | loss: 21887270586.68592|  0:00:23s
epoch 1  | loss: 23618582446.67148|  0:00:23s
epoch 1  | loss: 13482588527.82672|  0:00:24s
epoch 1  | loss: 23115953244.41878|  0:00:24s
epoch 0  | loss: 22384062560.11553|  0:00:24s
epoch 1  | loss: 21404253389.16967|  0:00:24s
epoch 0  | loss: 14667839809.61733|  0:00:24s
epoch 0  | loss: 24160459716.85198|  0:00:24s
epoch 2  | loss: 20896716777.8195|  0:00:35s
epoch 2  | loss: 13079806330.91697|  0:00:35s
epoch 2  | loss: 22499209082.91696|  0:00:36s
epoch 1  | loss: 13584044762.10831|



epoch 46 | loss: 4899370672.98195|  0:19:48s
epoch 45 | loss: 14390037672.66426|  0:19:49s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=8, n_steps=3; total time=19.9min




epoch 47 | loss: 12551798131.98556|  0:19:52s
epoch 99 | loss: 13283057887.19133|  0:19:54s
epoch 65 | loss: 12094232988.18772|  0:19:54s
epoch 93 | loss: 11498048412.64982|  0:19:55s
epoch 96 | loss: 12688175664.05776|  0:19:58s
epoch 96 | loss: 3586361344.0|  0:20:00s
epoch 66 | loss: 13955593094.00721|  0:20:01s
epoch 64 | loss: 4573457279.5379|  0:20:02s
epoch 94 | loss: 11375431436.01444|  0:20:09s
epoch 0  | loss: 22291538748.0722|  0:00:22s
epoch 97 | loss: 12729321028.38989|  0:20:11s
epoch 66 | loss: 11992046149.31408|  0:20:13s
epoch 47 | loss: 4837376660.33213|  0:20:13s
epoch 97 | loss: 3591485041.67509|  0:20:14s
epoch 0  | loss: 14493371255.22022|  0:00:23s
epoch 46 | loss: 14360878840.6065|  0:20:15s
epoch 48 | loss: 12370225186.65704|  0:20:18s
epoch 67 | loss: 13907877127.39352|  0:20:21s
epoch 65 | loss: 4560010393.41516|  0:20:23s
epoch 95 | loss: 11417305137.44404|  0:20:24s
epoch 98 | loss: 12758090080.57762|  0:20:26s
epoch 98 | loss: 3504916915.29242|  0:20:28s
e



epoch 1  | loss: 20107830246.12274|  0:00:46s
epoch 96 | loss: 11343316465.213|  0:20:37s
epoch 99 | loss: 12355421622.06498|  0:20:38s
epoch 1  | loss: 12357648232.43321|  0:00:47s
epoch 48 | loss: 4856174479.2491|  0:20:39s
epoch 68 | loss: 13866798816.5776|  0:20:39s
epoch 99 | loss: 3511123769.29964|  0:20:40s
epoch 47 | loss: 14400066335.42239|  0:20:42s
epoch 66 | loss: 4556764404.90975|  0:20:42s
epoch 49 | loss: 12488125688.1444|  0:20:43s
epoch 97 | loss: 11418245084.41877|  0:20:52s
epoch 68 | loss: 12047979522.77256|  0:20:53s
epoch 2  | loss: 17606379203.9278|  0:01:13s
epoch 69 | loss: 13869250310.00722|  0:21:01s
epoch 0  | loss: 24071456745.8195|  0:00:27s
epoch 67 | loss: 4516913741.16968|  0:21:05s
epoch 2  | loss: 9667580956.64982|  0:01:16s
epoch 98 | loss: 11490644917.60289|  0:21:08s
epoch 49 | loss: 4841351264.11552|  0:21:09s
epoch 50 | loss: 12466414294.87364|  0:21:13s
epoch 48 | loss: 14319252637.11191|  0:21:13s
epoch 69 | loss: 12070515732.79422|  0:21:17s
[



epoch 70 | loss: 13864930127.94225|  0:21:23s
epoch 99 | loss: 11501809395.29242|  0:21:23s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=16, n_steps=3; total time=21.4min




epoch 68 | loss: 4556088915.1769|  0:21:27s
epoch 3  | loss: 15516035265.15523|  0:01:40s
epoch 1  | loss: 22044415231.07582|  0:00:55s
epoch 3  | loss: 7603782919.3935|  0:01:41s
epoch 70 | loss: 12149335022.90253|  0:21:36s
epoch 50 | loss: 4842266442.8592|  0:21:37s
epoch 51 | loss: 12485519238.00722|  0:21:39s
epoch 49 | loss: 14252826853.19856|  0:21:41s
epoch 71 | loss: 13845625660.99639|  0:21:42s
epoch 69 | loss: 4445009207.45126|  0:21:48s
epoch 4  | loss: 14235897260.36101|  0:02:05s
epoch 2  | loss: 19403291935.42238|  0:01:21s
epoch 0  | loss: 22108613521.09747|  0:00:34s
epoch 71 | loss: 12096583751.62455|  0:21:57s
epoch 0  | loss: 14321639989.60289|  0:00:34s
epoch 4  | loss: 6421664087.33574|  0:02:07s
epoch 72 | loss: 13727608373.14078|  0:22:03s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=16, n_steps=3; total time=22.1min
epoch 51 | loss: 4777521763.81227|  0:22:04s




epoch 52 | loss: 12452335344.28882|  0:22:06s
epoch 70 | loss: 4451029578.85921|  0:22:08s
epoch 50 | loss: 14263272211.87003|  0:22:08s
epoch 72 | loss: 12132923007.53791|  0:22:15s
epoch 5  | loss: 13595945926.23828|  0:02:27s
epoch 3  | loss: 17074140861.4585|  0:01:43s
epoch 73 | loss: 13765330141.34296|  0:22:20s
epoch 5  | loss: 5796659086.787|  0:02:29s
epoch 1  | loss: 19079795336.77979|  0:01:02s
epoch 71 | loss: 4354508590.67148|  0:22:26s
epoch 1  | loss: 11242683168.34657|  0:01:01s
epoch 52 | loss: 4735963520.92419|  0:22:29s
epoch 53 | loss: 12398853249.84837|  0:22:29s
epoch 73 | loss: 12091235364.04332|  0:22:32s
epoch 51 | loss: 14282764228.38991|  0:22:33s
epoch 0  | loss: 23847132403.98556|  0:00:29s
epoch 6  | loss: 13290185648.05776|  0:02:48s
epoch 74 | loss: 13752146722.19493|  0:22:38s
epoch 4  | loss: 15751752386.54151|  0:02:05s
epoch 6  | loss: 5495872969.93502|  0:02:49s
epoch 72 | loss: 4344261395.87004|  0:22:45s
epoch 74 | loss: 12065450768.63538|  0:22:5



epoch 19 | loss: 12744700819.87004|  0:10:03s
epoch 30 | loss: 12327820704.34657|  0:11:40s
epoch 19 | loss: 5065180895.65342|  0:10:06s
epoch 71 | loss: 14050433799.39348|  0:31:32s
epoch 99 | loss: 4761621729.5018|  0:31:33s
epoch 73 | loss: 4354429781.71841|  0:31:36s
epoch 18 | loss: 14759134490.80144|  0:09:32s
epoch 74 | loss: 12168643017.47292|  0:31:37s
epoch 30 | loss: 4614735975.97112|  0:11:51s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=8, n_steps=5; total time=31.8min




epoch 0  | loss: 22264018308.15884|  0:00:22s
epoch 28 | loss: 14027069721.87726|  0:11:18s
epoch 31 | loss: 12333350819.11914|  0:12:08s
epoch 72 | loss: 13951801568.57762|  0:32:04s
epoch 20 | loss: 12679423673.29964|  0:10:41s
epoch 75 | loss: 12215888624.7509|  0:32:06s
epoch 74 | loss: 4346025088.46209|  0:32:07s
epoch 0  | loss: 14459883904.46209|  0:00:22s
epoch 20 | loss: 5068879183.01805|  0:10:44s
epoch 1  | loss: 11456781999.59567|  0:00:46s
epoch 31 | loss: 4590384852.56317|  0:12:20s
epoch 19 | loss: 14846059771.37905|  0:10:10s
epoch 29 | loss: 13888542911.30686|  0:11:48s
epoch 32 | loss: 12360293394.48375|  0:12:36s
epoch 1  | loss: 12384929265.21299|  0:00:45s
epoch 2  | loss: 18000937645.74728|  0:01:09s
epoch 73 | loss: 13871042801.67509|  0:32:35s
epoch 76 | loss: 12146702390.98917|  0:32:36s
epoch 75 | loss: 4436456478.0361|  0:32:37s
epoch 32 | loss: 4533443109.42961|  0:12:47s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=8, n_steps=5; total time=32.7min




epoch 21 | loss: 12682641395.06137|  0:11:19s
epoch 21 | loss: 4989113045.48736|  0:11:20s
epoch 30 | loss: 13935186498.07942|  0:12:14s
epoch 33 | loss: 12323778326.1805|  0:13:01s
epoch 2  | loss: 10080956804.15884|  0:01:04s
epoch 20 | loss: 14692478509.74729|  0:10:46s
epoch 3  | loss: 16280001772.59203|  0:01:28s
epoch 0  | loss: 24027290631.39352|  0:00:18s
epoch 77 | loss: 12149674597.19856|  0:33:02s
epoch 74 | loss: 13966157338.33934|  0:33:02s
epoch 76 | loss: 4299366657.38628|  0:33:03s
epoch 33 | loss: 4604056473.64621|  0:13:11s
epoch 3  | loss: 8295899986.25271|  0:01:22s
epoch 4  | loss: 14945768807.50903|  0:01:46s
epoch 31 | loss: 13919974581.1408|  0:12:37s
epoch 34 | loss: 12268954846.72924|  0:13:24s
epoch 22 | loss: 12658085571.4657|  0:11:50s
epoch 22 | loss: 4969263451.95668|  0:11:51s
epoch 1  | loss: 22039966679.33574|  0:00:36s
epoch 21 | loss: 14640259637.14079|  0:11:16s
epoch 34 | loss: 4587138315.55234|  0:13:33s
epoch 4  | loss: 6958934805.25632|  0:01:39



epoch 37 | loss: 13201959188.33214|  0:11:54s
epoch 43 | loss: 12219199854.90253|  0:23:14s
epoch 42 | loss: 14092106883.69675|  0:22:34s
epoch 43 | loss: 4492264701.68953|  0:23:14s
epoch 98 | loss: 13663458976.34657|  0:44:39s
epoch 59 | loss: 13356182535.3935|  0:24:09s
epoch 41 | loss: 3540938756.15885|  0:13:06s
epoch 64 | loss: 11736577616.40433|  0:25:04s
epoch 40 | loss: 11764325881.06859|  0:13:27s
epoch 63 | loss: 4249343945.01083|  0:25:02s
epoch 38 | loss: 13330368130.31047|  0:12:15s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=8, n_steps=7; total time=45.0min




epoch 0  | loss: 21930150272.46208|  0:00:31s
epoch 99 | loss: 13543665139.06138|  0:45:08s
epoch 60 | loss: 13252068761.87726|  0:24:34s
epoch 42 | loss: 3437291164.64982|  0:13:23s
epoch 44 | loss: 12209212463.13358|  0:23:48s
epoch 41 | loss: 11699109802.51264|  0:13:46s
epoch 43 | loss: 13855363940.27435|  0:23:07s
epoch 44 | loss: 4486214899.98556|  0:23:47s
epoch 65 | loss: 11697671672.6065|  0:25:25s
epoch 39 | loss: 13789755730.71481|  0:12:33s
epoch 64 | loss: 4258062498.19495|  0:25:24s
epoch 0  | loss: 14223171983.2491|  0:00:30s
epoch 43 | loss: 3429820008.89531|  0:13:43s
epoch 42 | loss: 12018427748.73646|  0:14:07s
epoch 61 | loss: 13051982603.55234|  0:24:59s
epoch 40 | loss: 13531356995.92779|  0:12:55s
epoch 1  | loss: 18532156551.85559|  0:01:03s
epoch 66 | loss: 11585568300.82311|  0:25:48s
epoch 65 | loss: 4116629851.03249|  0:25:48s
epoch 45 | loss: 12166958981.54512|  0:24:24s
epoch 44 | loss: 13892609271.68232|  0:23:43s
epoch 45 | loss: 4595323828.6787|  0:24:2



epoch 3  | loss: 13775486520.37545|  0:02:11s
epoch 64 | loss: 13174821415.74008|  0:26:12s
epoch 69 | loss: 11542386845.11192|  0:27:00s
epoch 68 | loss: 4107201574.35379|  0:26:59s
epoch 47 | loss: 3311871908.73646|  0:15:05s
epoch 47 | loss: 12097849027.00361|  0:25:30s
epoch 47 | loss: 4553563413.25631|  0:25:30s
epoch 46 | loss: 14013028608.92417|  0:24:53s
epoch 46 | loss: 11731096548.73646|  0:15:33s
epoch 44 | loss: 13361711492.62093|  0:14:17s
epoch 3  | loss: 6199740372.10108|  0:02:06s
epoch 65 | loss: 13030406224.40434|  0:26:34s
epoch 48 | loss: 3273943171.23466|  0:15:23s
epoch 70 | loss: 11606543175.62454|  0:27:21s
epoch 0  | loss: 23685657267.29242|  0:00:28s
epoch 4  | loss: 13254540313.87726|  0:02:39s
epoch 69 | loss: 3940209744.40434|  0:27:21s
epoch 47 | loss: 11643301931.43682|  0:15:51s
epoch 45 | loss: 13147042448.63538|  0:14:36s
epoch 48 | loss: 12173440280.49098|  0:26:00s
epoch 48 | loss: 4594505966.90253|  0:26:01s
epoch 49 | loss: 3250339953.67509|  0:15:



epoch 84 | loss: 12526298452.10108|  0:26:32s
epoch 88 | loss: 2878004324.04332|  0:27:31s
epoch 97 | loss: 12952129123.35018|  0:38:43s
epoch 29 | loss: 12228299796.33213|  0:14:45s
epoch 85 | loss: 11074163707.61011|  0:27:59s
epoch 70 | loss: 13439419943.74007|  0:37:20s
epoch 29 | loss: 4409220931.00361|  0:14:29s
epoch 71 | loss: 11929109902.32491|  0:38:06s
epoch 71 | loss: 3956842323.1769|  0:38:04s
epoch 25 | loss: 14062859359.65343|  0:12:51s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=16, n_steps=5; total time=39.7min




epoch 85 | loss: 12304208778.62816|  0:26:53s
epoch 89 | loss: 2840021913.41516|  0:27:50s
epoch 98 | loss: 12945826506.39711|  0:39:06s
epoch 0  | loss: 21666640794.33935|  0:00:41s
epoch 86 | loss: 11046691522.77256|  0:28:19s
epoch 30 | loss: 12210502618.10831|  0:15:15s
epoch 86 | loss: 12375694209.84838|  0:27:11s
epoch 71 | loss: 13580312346.33935|  0:37:49s
epoch 90 | loss: 2790671902.26715|  0:28:08s
epoch 30 | loss: 4296749952.0|  0:14:56s
epoch 72 | loss: 11921570403.35018|  0:38:35s
epoch 72 | loss: 3999746049.84837|  0:38:33s
epoch 99 | loss: 13031296929.27076|  0:39:28s
epoch 26 | loss: 14027032520.54873|  0:13:20s
epoch 87 | loss: 11097461955.69676|  0:28:37s
epoch 87 | loss: 12314063475.75451|  0:27:29s
epoch 0  | loss: 13792008271.48015|  0:00:37s
epoch 91 | loss: 2849623205.66065|  0:28:26s
epoch 31 | loss: 12205201578.05054|  0:15:44s
epoch 88 | loss: 10986417867.09025|  0:28:58s
epoch 1  | loss: 17278981955.4657|  0:01:21s
epoch 72 | loss: 13611091395.4657|  0:38:20s



epoch 94 | loss: 2823205117.22744|  0:29:27s
epoch 90 | loss: 12125123471.2491|  0:28:33s
epoch 91 | loss: 11019572762.33935|  0:29:55s
epoch 33 | loss: 12303169716.6787|  0:16:48s
epoch 74 | loss: 13644907717.31407|  0:39:21s
epoch 33 | loss: 4314942613.25632|  0:16:30s
epoch 95 | loss: 2743508304.17328|  0:29:45s
epoch 75 | loss: 3716287998.61372|  0:40:07s
epoch 91 | loss: 11939043334.23826|  0:28:51s
epoch 75 | loss: 11906890721.03972|  0:40:10s
epoch 2  | loss: 6430376673.03971|  0:02:01s
epoch 92 | loss: 11332414152.54873|  0:30:12s
epoch 29 | loss: 13954645001.24188|  0:14:55s
epoch 3  | loss: 13619405531.95668|  0:02:41s
epoch 96 | loss: 2865601908.44765|  0:30:02s
epoch 34 | loss: 12240279039.07582|  0:17:15s
epoch 0  | loss: 23242441872.17328|  0:00:38s
epoch 92 | loss: 11875469157.66065|  0:29:09s
epoch 93 | loss: 11032061172.90975|  0:30:29s
epoch 34 | loss: 4264652536.1444|  0:16:56s
epoch 75 | loss: 13682225147.84116|  0:39:51s
epoch 76 | loss: 3945039631.71119|  0:40:35s



epoch 78 | loss: 13725224880.05776|  0:41:24s




epoch 5  | loss: 5648681788.0722|  0:03:57s
epoch 79 | loss: 3850760439.68231|  0:42:07s
epoch 79 | loss: 11998375025.67509|  0:42:13s
epoch 33 | loss: 13944825084.76534|  0:16:56s
epoch 6  | loss: 13058943791.13357|  0:04:37s
epoch 99 | loss: 10980059503.59567|  0:32:16s
epoch 98 | loss: 12252007954.48376|  0:31:01s
epoch 0  | loss: 22591392564.6787|  0:00:13s
epoch 38 | loss: 12129709542.58484|  0:19:13s
epoch 3  | loss: 15330437759.5379|  0:02:41s
epoch 38 | loss: 4023737660.5343|  0:18:57s
epoch 79 | loss: 13612359765.02527|  0:41:54s
epoch 1  | loss: 21939143757.63177|  0:00:30s
epoch 99 | loss: 12366289426.48376|  0:31:20s
epoch 80 | loss: 4149462040.02888|  0:42:39s
epoch 80 | loss: 11943981311.07582|  0:42:46s
epoch 34 | loss: 13844490785.73285|  0:17:31s
epoch 6  | loss: 5423172295.62455|  0:04:42s
epoch 2  | loss: 21011156339.52346|  0:00:48s
epoch 39 | loss: 12087786532.96751|  0:19:47s
epoch 7  | loss: 13029879640.72203|  0:05:24s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8,



epoch 39 | loss: 4006137454.44043|  0:19:34s
epoch 80 | loss: 13496683320.83754|  0:42:30s
epoch 3  | loss: 19987550610.94585|  0:01:04s
epoch 81 | loss: 4318148410.68592|  0:43:14s
epoch 4  | loss: 15167250102.06498|  0:03:30s
epoch 81 | loss: 11975070244.04332|  0:43:20s
epoch 0  | loss: 14771468321.27075|  0:00:19s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=32, n_steps=3; total time=32.1min




epoch 35 | loss: 13793788783.13357|  0:18:06s
epoch 40 | loss: 12031357321.70397|  0:20:17s
epoch 4  | loss: 19029700528.51986|  0:01:19s
epoch 7  | loss: 5449474029.05415|  0:05:25s
epoch 40 | loss: 3854981287.97112|  0:20:03s
epoch 1  | loss: 14112548782.67148|  0:00:35s
epoch 0  | loss: 24328516393.58844|  0:00:16s
epoch 81 | loss: 13264397984.34657|  0:42:59s
epoch 5  | loss: 18018005415.27798|  0:01:33s
epoch 8  | loss: 12934627756.82311|  0:06:03s
epoch 82 | loss: 4292957574.00722|  0:43:44s
epoch 82 | loss: 11906816697.76174|  0:43:50s
epoch 41 | loss: 12014452787.29242|  0:20:43s
epoch 36 | loss: 13767537615.48013|  0:18:35s
epoch 2  | loss: 13187879891.63899|  0:00:51s
epoch 6  | loss: 17126112845.63176|  0:01:47s
epoch 1  | loss: 23694230531.69675|  0:00:32s
epoch 5  | loss: 15000144410.33935|  0:04:08s
epoch 41 | loss: 3854385887.65343|  0:20:30s
epoch 7  | loss: 16252581919.42239|  0:02:01s
epoch 82 | loss: 13424923699.29242|  0:43:29s
epoch 3  | loss: 12206604546.77256|  0



epoch 64 | loss: 11296538409.3574|  0:31:28s
epoch 63 | loss: 3371712513.38628|  0:31:03s
epoch 39 | loss: 4374263280.28881|  0:11:40s
epoch 50 | loss: 12065714918.58484|  0:12:37s
epoch 38 | loss: 13562332993.15522|  0:11:24s
epoch 21 | loss: 14154428377.64622|  0:15:00s
epoch 58 | loss: 12971703552.46209|  0:29:32s
epoch 25 | loss: 12212832533.71841|  0:17:15s
epoch 0  | loss: 22502480385.84837|  0:00:27s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=16, n_steps=7; total time=55.0min
epoch 24 | loss: 4672484340.44766|  0:16:49s




epoch 40 | loss: 4351687542.75812|  0:11:57s
epoch 51 | loss: 11996712724.79422|  0:12:54s
epoch 39 | loss: 13464653535.65342|  0:11:40s
epoch 65 | loss: 11285785610.8592|  0:31:56s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=16, n_steps=7; total time=54.4min
epoch 64 | loss: 3236240142.787|  0:31:32s




epoch 41 | loss: 4354217961.81949|  0:12:11s
epoch 52 | loss: 11930784774.93141|  0:13:07s
epoch 59 | loss: 12888444710.81588|  0:29:58s
epoch 40 | loss: 13536756695.79783|  0:11:55s
epoch 1  | loss: 21311470606.78701|  0:00:49s
epoch 0  | loss: 14694227269.31408|  0:00:22s
epoch 22 | loss: 14051954452.79423|  0:15:36s
epoch 26 | loss: 12160071574.1805|  0:17:49s
epoch 53 | loss: 12010796976.05776|  0:13:20s
epoch 42 | loss: 4303107054.90253|  0:12:25s
epoch 66 | loss: 11873486722.77256|  0:32:21s
epoch 41 | loss: 13530565046.98917|  0:12:09s
epoch 65 | loss: 3025710789.31408|  0:31:58s
epoch 0  | loss: 24256721021.68952|  0:00:21s
epoch 25 | loss: 4672730008.25993|  0:17:23s
epoch 2  | loss: 19556618557.92058|  0:01:11s
epoch 60 | loss: 13158693630.61372|  0:30:24s
epoch 54 | loss: 12167458716.18772|  0:13:35s
epoch 43 | loss: 4299011365.89169|  0:12:40s
epoch 1  | loss: 13465100661.37184|  0:00:44s
epoch 42 | loss: 13620881384.43321|  0:12:23s
epoch 1  | loss: 23016715842.54152|  0:0



epoch 30 | loss: 4818221633.15524|  0:11:47s
epoch 45 | loss: 11620962581.71841|  0:29:11s
epoch 85 | loss: 12303327358.38268|  0:41:33s
epoch 89 | loss: 3555997626.22383|  0:23:49s
epoch 32 | loss: 12276939073.61733|  0:12:23s
epoch 88 | loss: 12820225436.64982|  0:23:36s
epoch 31 | loss: 14024943250.02166|  0:11:49s
epoch 44 | loss: 3813065888.11553|  0:28:52s
epoch 93 | loss: 11190654555.26354|  0:43:54s
epoch 91 | loss: 3118244103.16246|  0:43:30s
epoch 90 | loss: 3549306760.31769|  0:24:03s
epoch 41 | loss: 13395048623.13357|  0:27:22s
epoch 31 | loss: 4825956142.20939|  0:12:10s
epoch 89 | loss: 12883841502.26715|  0:23:51s
epoch 33 | loss: 12236025717.83393|  0:12:46s
epoch 86 | loss: 12501172685.63177|  0:42:00s
epoch 0  | loss: 22489709523.63898|  0:00:33s
epoch 91 | loss: 3523494597.31408|  0:24:18s
epoch 32 | loss: 13940088765.45849|  0:12:12s
epoch 46 | loss: 11419387553.73285|  0:29:48s
epoch 94 | loss: 11112901918.0361|  0:44:20s
epoch 90 | loss: 12889156691.17689|  0:24:



epoch 39 | loss: 13669287392.11552|  0:14:59s
epoch 5  | loss: 13585850915.58123|  0:03:26s
epoch 92 | loss: 12495511672.6065|  0:44:56s
epoch 49 | loss: 3757670533.31408|  0:32:11s
epoch 39 | loss: 4612172802.77256|  0:15:26s
epoch 98 | loss: 3217685345.5018|  0:46:55s
epoch 41 | loss: 12292700532.44766|  0:16:02s
epoch 46 | loss: 13319331524.38989|  0:30:48s
[CV] END gamma=1.0, lambda_sparse=0, n_a=16, n_d=8, n_steps=3; total time=27.2min




epoch 40 | loss: 13799399360.23103|  0:15:30s
epoch 93 | loss: 12461085176.6065|  0:45:30s
epoch 0  | loss: 14674713189.66065|  0:00:43s
epoch 51 | loss: 11503489544.08664|  0:33:13s
epoch 40 | loss: 4572219652.62094|  0:15:53s
epoch 6  | loss: 13296233973.83394|  0:04:06s
epoch 42 | loss: 12261449614.787|  0:16:27s
[CV] END gamma=1.0, lambda_sparse=0, n_a=8, n_d=32, n_steps=5; total time=47.8min
epoch 99 | loss: 2933921368.72202|  0:47:26s




epoch 50 | loss: 3544302567.04693|  0:32:54s




epoch 41 | loss: 13588484112.63538|  0:15:54s
epoch 0  | loss: 24210734693.66064|  0:00:36s
epoch 94 | loss: 12458126324.44765|  0:45:58s
epoch 47 | loss: 13315598654.38267|  0:31:28s
epoch 41 | loss: 4549027797.02528|  0:16:16s
epoch 43 | loss: 12163593158.70036|  0:16:50s
epoch 0  | loss: 22479891762.83033|  0:00:15s
epoch 1  | loss: 13104399901.574|  0:01:17s
epoch 7  | loss: 13152425752.95308|  0:04:39s
epoch 52 | loss: 11341508637.34295|  0:33:50s
epoch 42 | loss: 13584500212.90974|  0:16:17s
epoch 1  | loss: 21149315223.56679|  0:00:32s
epoch 42 | loss: 4584033778.13718|  0:16:41s
epoch 51 | loss: 3535823311.48015|  0:33:32s
epoch 95 | loss: 12466498538.28159|  0:46:26s
epoch 44 | loss: 12153886904.83755|  0:17:14s
epoch 1  | loss: 22384868159.76895|  0:01:14s
epoch 43 | loss: 13606141799.50903|  0:16:40s
epoch 2  | loss: 19502219020.01445|  0:00:48s
epoch 48 | loss: 13205182339.23466|  0:32:08s
epoch 2  | loss: 10623346586.33935|  0:01:53s
epoch 8  | loss: 13091715443.06138|  0:



epoch 44 | loss: 13743753852.30325|  0:17:02s
epoch 52 | loss: 3535636647.74007|  0:34:08s
epoch 2  | loss: 20038635636.44766|  0:01:46s
epoch 4  | loss: 16625765680.05776|  0:01:19s
epoch 44 | loss: 4656703955.63899|  0:17:27s
epoch 0  | loss: 14666366968.6065|  0:00:14s
epoch 46 | loss: 12070766598.00722|  0:17:59s
epoch 49 | loss: 13216652438.6426|  0:32:43s
epoch 9  | loss: 13012889470.15162|  0:05:44s
epoch 3  | loss: 8584878482.94585|  0:02:25s
epoch 45 | loss: 13567291592.54874|  0:17:23s
epoch 97 | loss: 12409783875.23466|  0:47:19s
epoch 54 | loss: 11369424801.73285|  0:35:01s
epoch 5  | loss: 15522068423.62454|  0:01:33s
epoch 1  | loss: 13429217265.213|  0:00:28s
epoch 45 | loss: 4704823606.52708|  0:17:48s
epoch 47 | loss: 12122912483.81227|  0:18:20s
epoch 53 | loss: 3321125282.65704|  0:34:42s
epoch 6  | loss: 14654360139.78338|  0:01:48s
epoch 3  | loss: 17974693912.02888|  0:02:17s
epoch 2  | loss: 11682443977.47293|  0:00:43s
epoch 46 | loss: 13956258426.91696|  0:17:4



epoch 13 | loss: 12832634029.7473|  0:07:54s




epoch 56 | loss: 3506851286.41155|  0:36:32s
epoch 13 | loss: 12662382438.58484|  0:03:36s
epoch 50 | loss: 4470489100.01444|  0:19:43s
epoch 52 | loss: 12013365586.7148|  0:20:13s
epoch 9  | loss: 5225540488.31769|  0:02:31s
epoch 7  | loss: 5573451071.30686|  0:04:39s
epoch 51 | loss: 13625260312.49098|  0:19:37s
epoch 0  | loss: 24241136525.40074|  0:00:14s
epoch 53 | loss: 13190556133.19856|  0:35:08s
epoch 14 | loss: 12587165084.18772|  0:03:51s
epoch 10 | loss: 4993452257.50181|  0:02:46s
epoch 58 | loss: 11500086872.25993|  0:37:25s
epoch 51 | loss: 4531527901.80506|  0:20:04s
epoch 53 | loss: 12202845743.13357|  0:20:35s
epoch 7  | loss: 14882408360.20218|  0:04:30s
epoch 1  | loss: 22906439421.22744|  0:00:29s
epoch 15 | loss: 12521417701.66065|  0:04:05s
epoch 14 | loss: 12795035800.95307|  0:08:24s
epoch 52 | loss: 13532859519.07581|  0:19:58s
epoch 11 | loss: 4884070330.68592|  0:03:00s
epoch 57 | loss: 3207538310.93141|  0:37:06s
epoch 8  | loss: 5434002162.13718|  0:05:11



epoch 86 | loss: 3033161709.51625|  0:54:53s
epoch 88 | loss: 11057190236.41877|  0:21:58s
epoch 99 | loss: 3902059323.14801|  0:38:11s
epoch 41 | loss: 13743319445.71842|  0:22:34s
epoch 73 | loss: 12928874536.66426|  0:18:31s
[CV] END gamma=1.0, lambda_sparse=0, n_a=16, n_d=8, n_steps=5; total time=38.0min




epoch 83 | loss: 3519418985.81949|  0:21:03s
epoch 88 | loss: 10882317535.65343|  0:55:39s
epoch 89 | loss: 11025265933.86282|  0:22:13s
epoch 42 | loss: 4386319368.31769|  0:23:14s
epoch 49 | loss: 14346957506.07942|  0:26:40s
epoch 74 | loss: 12791647633.09747|  0:18:46s
epoch 0  | loss: 22339788082.83033|  0:00:28s
epoch 84 | loss: 3036044933.54513|  0:21:19s
epoch 83 | loss: 12667409787.14802|  0:53:46s
epoch 90 | loss: 11141865485.40072|  0:22:30s
epoch 87 | loss: 2975725856.34657|  0:55:29s
epoch 0  | loss: 14565440109.05415|  0:00:29s
epoch 75 | loss: 12713133129.47293|  0:19:03s
epoch 42 | loss: 13688247085.2852|  0:23:08s
epoch 85 | loss: 3287875151.01805|  0:21:35s
epoch 91 | loss: 11096355412.10108|  0:22:46s
epoch 89 | loss: 11074147708.30325|  0:56:17s
epoch 43 | loss: 4273527711.42238|  0:23:49s
epoch 1  | loss: 20096331153.09747|  0:00:59s
epoch 76 | loss: 12812528109.74729|  0:19:20s
epoch 50 | loss: 14337110699.89891|  0:27:14s
epoch 86 | loss: 2934607789.74729|  0:21:



epoch 2  | loss: 17412679561.70397|  0:01:31s
epoch 44 | loss: 4356974259.29242|  0:24:25s
epoch 90 | loss: 10979190081.61733|  0:56:55s
epoch 51 | loss: 14337021817.0686|  0:27:47s
epoch 78 | loss: 12699316919.91336|  0:19:53s
epoch 88 | loss: 3175091745.03971|  0:22:25s
epoch 94 | loss: 11099976842.62816|  0:23:35s
epoch 2  | loss: 9923457304.95307|  0:01:29s
epoch 79 | loss: 13000166381.51625|  0:20:07s
epoch 0  | loss: 24115707109.19856|  0:00:22s
epoch 85 | loss: 12396636767.19134|  0:55:03s
epoch 89 | loss: 3090494510.44043|  0:22:39s
epoch 95 | loss: 11114026968.02888|  0:23:48s
epoch 89 | loss: 2907049172.10108|  0:56:44s
epoch 44 | loss: 13636051187.06137|  0:24:17s
epoch 3  | loss: 15406801146.45486|  0:01:56s
epoch 45 | loss: 4346810867.98556|  0:24:54s
epoch 80 | loss: 12729382944.34657|  0:20:21s
epoch 52 | loss: 14343105037.86281|  0:28:15s
epoch 91 | loss: 11166071212.8231|  0:57:27s
epoch 90 | loss: 3279209094.46931|  0:22:53s
epoch 3  | loss: 7813656959.5379|  0:01:53s



epoch 48 | loss: 4249476818.2527|  0:26:24s
epoch 96 | loss: 2821707977.70397|  0:24:19s
epoch 88 | loss: 12824650296.37546|  0:56:45s
epoch 92 | loss: 3239383325.80505|  0:58:27s
epoch 7  | loss: 12991853951.53791|  0:03:38s
epoch 87 | loss: 12678137976.83755|  0:21:59s
epoch 5  | loss: 15213836119.33575|  0:02:14s
epoch 97 | loss: 2838930764.24549|  0:24:32s
epoch 94 | loss: 10986057891.11913|  0:59:08s
epoch 7  | loss: 5244434658.42599|  0:03:32s
epoch 88 | loss: 12577083041.03971|  0:22:11s
epoch 48 | loss: 13529552225.9639|  0:26:16s
epoch 56 | loss: 14325319545.06859|  0:30:08s
epoch 49 | loss: 4213716726.75812|  0:26:50s
epoch 98 | loss: 3484953490.25271|  0:24:45s
epoch 0  | loss: 22101835794.48375|  0:00:30s
epoch 6  | loss: 14863351959.56678|  0:02:35s
epoch 8  | loss: 12872048121.06859|  0:04:01s
epoch 89 | loss: 12750210401.03971|  0:57:17s
epoch 89 | loss: 3693803158.87364|  0:22:24s
epoch 93 | loss: 2932556745.70397|  0:58:58s
epoch 8  | loss: 5111277885.45848|  0:03:55s




epoch 58 | loss: 14349486972.76534|  0:37:01s
epoch 96 | loss: 10980791356.76534|  1:06:10s
epoch 50 | loss: 13611514742.29602|  0:33:12s
epoch 10 | loss: 12762109282.42599|  0:26:10s
epoch 51 | loss: 4171339117.05415|  0:49:06s
epoch 10 | loss: 5101119791.13357|  0:26:03s
epoch 93 | loss: 12515161641.12636|  0:44:37s
epoch 2  | loss: 15896347574.98917|  0:22:52s
epoch 9  | loss: 14568743844.04332|  0:24:58s
epoch 91 | loss: 12524782551.33574|  1:19:40s
epoch 94 | loss: 12816037798.35378|  0:44:49s
epoch 95 | loss: 3366949481.3574|  1:21:20s
epoch 59 | loss: 14322214392.60649|  0:52:47s
epoch 11 | loss: 12731676360.08664|  0:26:34s
epoch 51 | loss: 13558083996.64982|  0:49:00s
epoch 0  | loss: 14324896148.79423|  0:15:51s
epoch 11 | loss: 5008227974.93141|  0:26:25s
epoch 97 | loss: 10873974657.38628|  1:22:02s
epoch 52 | loss: 4093154828.47653|  0:49:34s
epoch 95 | loss: 12797655148.12995|  0:45:02s
epoch 10 | loss: 14499174948.96751|  0:25:18s
epoch 3  | loss: 14155593466.91697|  0:2



epoch 14 | loss: 14359466229.83393|  0:27:01s
epoch 15 | loss: 12582948246.1805|  0:28:27s
epoch 3  | loss: 6650755059.06137|  0:17:42s
epoch 15 | loss: 4940703482.45487|  0:43:44s
epoch 63 | loss: 14336157493.60288|  1:10:16s
epoch 55 | loss: 13288918769.67509|  1:06:34s
epoch 6  | loss: 13048731655.3935|  0:40:42s
epoch 56 | loss: 4060951198.0361|  1:07:07s
epoch 15 | loss: 14304639788.8231|  0:42:50s
epoch 16 | loss: 12514233874.02166|  0:44:22s
epoch 95 | loss: 12255904115.98556|  1:37:36s
epoch 99 | loss: 3006206942.49819|  1:39:15s
epoch 16 | loss: 4889363462.00722|  0:44:12s
epoch 0  | loss: 23871816223.4224|  0:16:03s
epoch 64 | loss: 14317626868.90975|  1:10:47s
epoch 4  | loss: 5975949097.12636|  0:33:47s
epoch 56 | loss: 13342647306.62816|  1:22:32s
epoch 16 | loss: 14262911655.27799|  0:58:41s
epoch 57 | loss: 4051370905.87726|  1:23:04s
epoch 17 | loss: 12552709197.16968|  1:00:16s
epoch 7  | loss: 12979878494.26715|  0:56:48s
epoch 17 | loss: 4843844567.33574|  1:00:06s
[



epoch 96 | loss: 12565534516.2166|  1:53:39s
epoch 65 | loss: 14315210114.31047|  1:26:42s
epoch 17 | loss: 14229084205.28521|  0:59:06s
epoch 1  | loss: 21265900495.94224|  0:32:06s
epoch 5  | loss: 5779115576.83755|  0:49:51s
epoch 57 | loss: 13203755142.70036|  1:23:03s
epoch 18 | loss: 12506353146.91697|  1:00:39s
epoch 58 | loss: 4202269092.96751|  1:23:35s
epoch 18 | loss: 4829700684.24549|  1:00:30s
epoch 0  | loss: 22286188695.56679|  0:00:23s
epoch 8  | loss: 12925188264.66426|  0:57:24s
epoch 18 | loss: 14189418828.24549|  1:14:50s
epoch 66 | loss: 14321340598.06498|  1:42:33s
epoch 97 | loss: 12259733628.5343|  2:09:34s
epoch 19 | loss: 12489136908.47654|  1:16:25s
epoch 19 | loss: 4802378539.89892|  1:16:15s
epoch 2  | loss: 18468600040.8953|  0:48:02s
epoch 58 | loss: 13205109469.80505|  1:38:54s
epoch 1  | loss: 20158149506.31046|  0:16:08s
epoch 59 | loss: 4145824406.87364|  1:39:26s
epoch 6  | loss: 5669304495.13357|  1:05:48s
epoch 19 | loss: 14193897167.48014|  1:15:1



epoch 9  | loss: 12866586668.36101|  1:13:20s
epoch 20 | loss: 12457120384.92419|  1:16:48s
epoch 20 | loss: 4754508784.28881|  1:16:38s
epoch 2  | loss: 17771620498.02166|  0:16:29s
epoch 98 | loss: 12458538216.43321|  2:10:07s
epoch 59 | loss: 13241139979.55235|  1:39:22s
epoch 3  | loss: 16333873482.8592|  0:48:32s
epoch 60 | loss: 4170020870.93141|  1:39:53s
epoch 20 | loss: 14183187191.22022|  1:15:34s
epoch 7  | loss: 5593479261.80505|  1:20:55s
epoch 0  | loss: 14516266006.1805|  0:14:56s
epoch 21 | loss: 12509776664.49097|  1:31:47s
epoch 3  | loss: 15728200091.26354|  0:31:24s
epoch 68 | loss: 14340972747.3213|  1:58:03s
epoch 21 | loss: 4759621217.5018|  1:31:36s
epoch 10 | loss: 12810002642.2527|  1:28:27s
epoch 21 | loss: 14207185297.09748|  1:30:31s
epoch 60 | loss: 13223715212.47654|  1:54:25s
epoch 61 | loss: 4084599543.68231|  1:54:57s
epoch 99 | loss: 12289105171.63899|  2:25:17s
epoch 1  | loss: 12398353744.40434|  0:15:17s
epoch 4  | loss: 15455302400.46211|  1:03:38



epoch 72 | loss: 14333571878.81587|  1:59:59s
epoch 26 | loss: 12422436866.77256|  1:33:48s
epoch 7  | loss: 14920082796.12996|  1:05:21s
epoch 26 | loss: 4624109452.47653|  1:33:36s
epoch 26 | loss: 14132710218.85921|  1:32:28s
epoch 11 | loss: 5353816585.70397|  1:23:12s
epoch 9  | loss: 12719033883.26354|  0:33:35s
epoch 0  | loss: 24055791678.84476|  0:00:20s
epoch 64 | loss: 13203658964.56319|  1:56:25s
epoch 6  | loss: 5564823740.99639|  0:17:14s
epoch 65 | loss: 4021477799.74007|  1:56:56s
epoch 27 | loss: 12367706274.65704|  1:34:10s
epoch 27 | loss: 4679128101.4296|  1:33:58s
epoch 14 | loss: 12716134410.16606|  1:30:43s
epoch 73 | loss: 14326835393.15522|  2:00:25s
epoch 27 | loss: 14081210274.19495|  1:32:49s
epoch 10 | loss: 12615941060.3899|  0:33:55s
epoch 8  | loss: 14877650513.32852|  1:05:52s
epoch 1  | loss: 22125626214.58484|  0:00:40s
epoch 7  | loss: 5274239670.52707|  0:17:35s
epoch 12 | loss: 5302864332.24548|  1:23:42s
epoch 65 | loss: 13058330752.92419|  1:56:5



epoch 40 | loss: 12334260250.80144|  1:52:21s
epoch 39 | loss: 12819866565.54512|  0:22:14s
epoch 94 | loss: 3776224916.79422|  2:18:51s
epoch 46 | loss: 3242394733.97834|  0:39:14s
epoch 38 | loss: 4491533102.20939|  1:45:25s
epoch 93 | loss: 13116413430.29603|  2:18:35s
epoch 65 | loss: 13697610026.05054|  1:54:48s
epoch 51 | loss: 11374851799.79784|  0:55:53s
epoch 63 | loss: 11389802341.19856|  1:56:17s
epoch 63 | loss: 4233125622.75812|  1:56:05s
epoch 34 | loss: 14147862622.72924|  1:28:00s
epoch 40 | loss: 12891913730.31047|  0:22:50s
epoch 47 | loss: 3199858627.4657|  0:39:51s
epoch 0  | loss: 21940885818.22383|  0:00:53s
epoch 41 | loss: 12321810767.48015|  1:53:14s
epoch 95 | loss: 3479225400.83754|  2:19:40s
epoch 66 | loss: 14061545203.98555|  1:55:26s
epoch 52 | loss: 11275561090.31047|  0:56:29s
epoch 64 | loss: 11370915155.63899|  1:56:56s
epoch 64 | loss: 3975886633.58844|  1:56:44s
epoch 94 | loss: 13033575661.05415|  2:19:24s
epoch 39 | loss: 4342771631.59567|  1:46:1



epoch 75 | loss: 15191159400.43321|  2:01:23s
epoch 73 | loss: 3799922379.78339|  2:02:53s
epoch 41 | loss: 13923426325.25632|  1:34:40s
epoch 48 | loss: 12158400105.8195|  1:59:40s
epoch 73 | loss: 11666733112.37546|  2:03:12s
epoch 50 | loss: 12420451150.55596|  0:29:33s
epoch 57 | loss: 2866589537.73285|  0:46:32s
epoch 62 | loss: 11081612299.3213|  1:02:56s
epoch 7  | loss: 12892429482.97474|  0:07:34s
epoch 46 | loss: 4178684433.09747|  1:52:38s
epoch 76 | loss: 15173880733.11192|  2:02:04s
epoch 0  | loss: 14208260678.23826|  0:00:58s
epoch 74 | loss: 3699030642.13718|  2:03:34s
[CV] END gamma=1.0, lambda_sparse=0, n_a=16, n_d=8, n_steps=7; total time=146.2min




epoch 74 | loss: 11656833457.90614|  2:03:53s
epoch 51 | loss: 12621447241.47293|  0:30:14s
epoch 63 | loss: 11045840943.36462|  1:03:32s
epoch 58 | loss: 3055187262.84476|  0:47:11s
epoch 42 | loss: 13738057432.72202|  1:35:33s
epoch 49 | loss: 12123940774.81589|  2:00:33s
epoch 77 | loss: 14982878695.97112|  2:02:42s
epoch 8  | loss: 12829909681.90614|  0:08:29s
epoch 47 | loss: 4258687765.71841|  1:53:30s
epoch 75 | loss: 3637715196.30325|  2:04:12s
epoch 64 | loss: 11044027725.16967|  1:04:05s
epoch 52 | loss: 12550354344.66427|  0:30:50s
epoch 75 | loss: 11487511902.72924|  2:04:32s
epoch 59 | loss: 3012508694.41155|  0:47:46s
epoch 1  | loss: 11131303734.52708|  0:01:49s
epoch 0  | loss: 23738307726.32492|  0:00:53s
epoch 78 | loss: 14269900823.56679|  2:03:19s
epoch 43 | loss: 13821731719.8556|  1:36:25s
epoch 50 | loss: 12145705538.07942|  2:01:25s
epoch 65 | loss: 11081361255.27798|  1:04:38s
epoch 76 | loss: 3949971664.40433|  2:04:50s
epoch 9  | loss: 12805776171.43682|  0:0



epoch 98 | loss: 11210779779.23466|  2:20:28s
epoch 26 | loss: 12269427572.90976|  0:24:42s
epoch 60 | loss: 13275946384.17328|  1:52:06s
epoch 85 | loss: 2871108300.47653|  1:03:48s
epoch 67 | loss: 12087683061.83393|  2:17:10s
epoch 99 | loss: 3439366208.92419|  2:20:27s
epoch 92 | loss: 10903603077.31408|  1:20:17s
epoch 65 | loss: 3654630158.32491|  2:09:55s
epoch 20 | loss: 4514381427.75452|  0:17:57s
epoch 18 | loss: 14099208963.23466|  0:16:55s
epoch 78 | loss: 11852360492.8231|  0:47:07s
epoch 99 | loss: 11129270083.00361|  2:20:50s
epoch 86 | loss: 2675219894.06498|  1:04:08s
epoch 93 | loss: 10931701710.787|  1:20:36s
epoch 27 | loss: 12337288866.19495|  0:25:14s
epoch 61 | loss: 13922959163.14801|  1:52:39s
epoch 68 | loss: 12143742073.53068|  2:17:45s
epoch 0  | loss: 21536606076.76534|  0:00:48s
epoch 66 | loss: 3702360999.74007|  2:10:31s
epoch 79 | loss: 11864984609.50181|  0:47:38s
epoch 21 | loss: 4429138720.80866|  0:18:35s
epoch 87 | loss: 2681417680.17328|  1:04:33s



epoch 22 | loss: 4435898653.57401|  0:26:48s
epoch 89 | loss: 2591276124.64982|  1:12:50s
epoch 96 | loss: 11076174789.08303|  1:29:14s
epoch 20 | loss: 13969417349.08303|  0:25:47s
[CV] END gamma=1.0, lambda_sparse=0, n_a=16, n_d=16, n_steps=5; total time=149.7min




epoch 1  | loss: 17657989716.10109|  0:09:13s
epoch 81 | loss: 11802946069.48736|  0:56:03s
epoch 63 | loss: 13605396291.00362|  2:01:23s
epoch 29 | loss: 13089032877.74729|  0:34:04s
epoch 97 | loss: 10772014900.90975|  1:44:56s
epoch 90 | loss: 2712240430.67148|  1:28:33s
epoch 70 | loss: 12044853762.77256|  2:41:54s
epoch 68 | loss: 3703501474.65704|  2:34:39s
epoch 23 | loss: 4363263642.80144|  0:42:41s
epoch 82 | loss: 11908596517.66066|  1:11:47s
epoch 21 | loss: 13857129045.48736|  0:41:41s
epoch 98 | loss: 11056999060.10109|  1:45:12s
epoch 0  | loss: 13872389792.80866|  0:16:05s
epoch 91 | loss: 2862622264.6065|  1:28:50s
epoch 64 | loss: 13735871011.11913|  2:17:18s
epoch 30 | loss: 12829271526.12274|  0:49:59s
epoch 0  | loss: 23556607709.80504|  0:16:06s
epoch 2  | loss: 14495818695.16246|  0:25:19s
epoch 83 | loss: 11736952325.54512|  1:12:07s
epoch 99 | loss: 10849794092.59206|  1:45:29s
epoch 24 | loss: 4305328551.74008|  0:43:09s
epoch 71 | loss: 11902245173.14079|  2:4



epoch 86 | loss: 11702807914.74369|  1:13:07s
epoch 66 | loss: 13679929495.56678|  2:18:22s
epoch 32 | loss: 12820868733.22743|  0:51:03s
epoch 95 | loss: 2634081220.62094|  1:30:04s
epoch 26 | loss: 4180021594.1083|  0:44:11s
epoch 0  | loss: 22579770855.97112|  0:00:14s
epoch 73 | loss: 12084553744.63538|  2:43:30s
epoch 71 | loss: 3644564908.82311|  2:36:12s
epoch 2  | loss: 6836479775.88447|  0:17:28s
epoch 24 | loss: 13715000661.94944|  0:43:14s
epoch 87 | loss: 11731091961.76173|  1:13:24s
epoch 96 | loss: 2939294459.61011|  1:30:20s
epoch 2  | loss: 16441664561.90614|  0:17:28s
epoch 4  | loss: 13246607503.24909|  0:26:43s
epoch 1  | loss: 21935261548.12995|  0:00:29s
epoch 33 | loss: 12873523640.83754|  0:51:30s
epoch 67 | loss: 13621908903.74007|  2:18:51s
epoch 27 | loss: 4146562329.87726|  0:44:37s
epoch 88 | loss: 11781140963.81227|  1:13:41s
epoch 97 | loss: 2608443027.1769|  1:30:37s
epoch 2  | loss: 20821884314.33935|  0:00:43s
epoch 74 | loss: 12128551954.48376|  2:44:0



epoch 36 | loss: 12373359394.65704|  0:53:04s
epoch 7  | loss: 15453356874.85922|  0:02:11s
epoch 30 | loss: 4026449299.87004|  0:46:12s
epoch 70 | loss: 14073391390.49819|  2:20:32s
epoch 93 | loss: 11965362115.46571|  1:15:23s
epoch 5  | loss: 5645258748.30325|  0:19:37s
epoch 75 | loss: 3627445225.3574|  2:38:27s
epoch 8  | loss: 14729934010.68592|  0:02:30s
epoch 77 | loss: 12239896800.11552|  2:45:49s
epoch 0  | loss: 14780942687.19134|  0:00:29s
epoch 28 | loss: 13389409232.86642|  0:45:34s
epoch 37 | loss: 12365755001.99278|  0:53:55s
epoch 5  | loss: 15016881534.15163|  0:20:04s
epoch 94 | loss: 11666710363.95668|  1:16:11s
epoch 31 | loss: 3908056958.15162|  0:47:13s
epoch 7  | loss: 13029513663.30686|  0:29:29s
epoch 9  | loss: 14088320019.87002|  0:03:18s
epoch 71 | loss: 15251864152.72202|  2:21:50s
epoch 1  | loss: 14133923448.1444|  0:01:29s
epoch 95 | loss: 11663560249.53069|  1:17:06s
epoch 76 | loss: 3673935631.2491|  2:40:03s
epoch 10 | loss: 13610020351.5379|  0:04:0



epoch 42 | loss: 12176319689.01083|  0:58:32s
epoch 7  | loss: 7752660061.34296|  0:05:36s
epoch 33 | loss: 13214023516.88087|  0:50:41s
epoch 36 | loss: 3810291224.95307|  0:51:48s
epoch 80 | loss: 3628060579.58123|  2:43:49s
epoch 18 | loss: 12435649299.40794|  0:07:52s
epoch 9  | loss: 5352772698.5704|  0:25:06s
epoch 82 | loss: 11848493208.49098|  2:51:13s
epoch 0  | loss: 24263438945.96391|  0:00:35s
epoch 8  | loss: 6977976617.58845|  0:06:11s
epoch 19 | loss: 12380748788.44765|  0:08:18s
epoch 9  | loss: 14761036061.57401|  0:25:22s
epoch 43 | loss: 12166041742.787|  0:59:18s
epoch 76 | loss: 13444058900.33213|  2:26:38s
epoch 11 | loss: 12904959067.0325|  0:34:44s
epoch 37 | loss: 3852151625.93502|  0:52:32s
epoch 34 | loss: 13185086456.83755|  0:51:30s
epoch 81 | loss: 3832745339.14801|  2:44:39s
epoch 1  | loss: 23673386468.27436|  0:01:09s
epoch 20 | loss: 12329094165.25632|  0:08:43s
epoch 83 | loss: 11815079241.93502|  2:52:06s
epoch 9  | loss: 6374884163.4657|  0:06:47s
e



epoch 96 | loss: 13682980231.8556|  3:01:56s
epoch 25 | loss: 4780649567.65343|  1:00:52s
epoch 31 | loss: 13537187248.98196|  0:36:09s
epoch 55 | loss: 12645155130.22383|  1:26:35s
epoch 38 | loss: 3883681244.41877|  0:41:38s
epoch 60 | loss: 11407465040.86643|  0:43:50s
epoch 60 | loss: 2942895038.38267|  1:27:56s
epoch 66 | loss: 12652287575.33574|  1:34:55s
epoch 27 | loss: 12567995896.1444|  1:10:17s
epoch 32 | loss: 13704930385.79062|  0:36:31s
epoch 25 | loss: 14080936230.35379|  1:01:08s
epoch 61 | loss: 11437877954.77256|  0:44:08s
epoch 39 | loss: 3963797557.14079|  0:42:02s
epoch 0  | loss: 22529570583.1047|  0:00:33s
epoch 56 | loss: 12963585316.50542|  1:42:29s
epoch 97 | loss: 13361319565.40073|  3:17:54s
epoch 26 | loss: 4689952449.61733|  1:17:00s
epoch 62 | loss: 11375937859.46571|  0:59:50s
epoch 61 | loss: 3058847512.49097|  1:43:49s
epoch 33 | loss: 13676799039.76895|  0:52:18s
epoch 67 | loss: 12708752049.90614|  1:50:49s
[CV] END gamma=1.0, lambda_sparse=0, n_a=16



epoch 40 | loss: 3909732871.62455|  0:57:50s
epoch 1  | loss: 21465443777.15524|  0:16:27s
epoch 57 | loss: 12924064244.90974|  1:42:56s
epoch 63 | loss: 11448811716.85198|  1:00:03s
epoch 28 | loss: 12563303849.58844|  1:26:20s
epoch 98 | loss: 13176749053.68953|  3:18:24s
epoch 34 | loss: 13534248067.69675|  0:52:36s
epoch 26 | loss: 14023546603.66788|  1:17:12s
epoch 41 | loss: 3817448118.98917|  0:58:08s
epoch 62 | loss: 2943500224.46209|  1:44:13s
epoch 68 | loss: 12699912400.40433|  1:51:14s
epoch 64 | loss: 11364408552.43321|  1:00:17s
epoch 0  | loss: 14724094898.36823|  0:00:28s
epoch 27 | loss: 4642554542.20938|  1:17:36s
epoch 35 | loss: 13606846567.50902|  0:52:54s
epoch 58 | loss: 12689077217.9639|  1:43:21s
epoch 2  | loss: 19637630380.82311|  0:16:53s
epoch 65 | loss: 11602956989.45848|  1:00:32s
epoch 42 | loss: 3913534068.90975|  0:58:27s
epoch 99 | loss: 13116215749.31408|  3:18:53s
epoch 63 | loss: 3331026500.85199|  1:44:39s
epoch 29 | loss: 12441133593.87726|  1:26



epoch 40 | loss: 13354827677.574|  1:10:11s
epoch 30 | loss: 4553891969.38628|  1:35:00s
epoch 62 | loss: 12451871589.66065|  2:00:45s
epoch 72 | loss: 11554461355.43682|  1:17:55s
epoch 47 | loss: 3637890779.95668|  1:15:50s
epoch 67 | loss: 3342031239.3935|  2:01:58s
epoch 6  | loss: 14636022910.61372|  0:34:26s
epoch 41 | loss: 13410700869.77617|  1:10:29s
epoch 4  | loss: 9153628116.56317|  0:18:08s
epoch 73 | loss: 12721728040.66426|  2:09:02s
epoch 0  | loss: 24258675512.37546|  0:00:22s
epoch 32 | loss: 12418458214.12274|  1:44:23s
epoch 30 | loss: 13689004953.87726|  1:35:11s
epoch 73 | loss: 11502838430.26715|  1:18:09s
epoch 48 | loss: 3538451644.76534|  1:16:08s
epoch 63 | loss: 12837356701.57401|  2:01:10s
epoch 42 | loss: 13175349697.61732|  1:10:47s
epoch 74 | loss: 11536108383.19134|  1:18:23s
epoch 68 | loss: 2956204918.52708|  2:02:22s
epoch 31 | loss: 4468007093.1408|  1:35:35s
epoch 7  | loss: 14018838815.4224|  0:34:51s
epoch 1  | loss: 23022892449.73285|  0:00:46s




epoch 43 | loss: 12464521265.90614|  2:29:53s
epoch 19 | loss: 4965439113.70397|  1:03:48s
epoch 42 | loss: 4230579988.10108|  2:20:57s
epoch 22 | loss: 12383979346.7148|  1:20:11s
epoch 89 | loss: 12697063981.7473|  2:54:44s
epoch 70 | loss: 2973007055.48014|  2:01:44s
epoch 65 | loss: 12547043465.24187|  1:56:21s
epoch 84 | loss: 2904415098.91697|  2:47:54s
epoch 79 | loss: 12844984502.06498|  2:46:48s
epoch 18 | loss: 14275825423.71118|  0:46:13s
epoch 42 | loss: 13184592942.67148|  2:21:07s
epoch 71 | loss: 2912955996.41877|  2:02:03s
epoch 66 | loss: 12388066369.15522|  1:56:38s
epoch 20 | loss: 4923919034.68592|  1:04:16s
epoch 90 | loss: 12705295663.59567|  2:55:09s
epoch 23 | loss: 12372076008.43321|  1:20:37s
epoch 44 | loss: 12485155492.04333|  2:30:29s
epoch 0  | loss: 22462162400.57762|  0:00:35s
epoch 85 | loss: 2847942719.07581|  2:48:19s
epoch 43 | loss: 4255532787.98556|  2:21:32s
epoch 80 | loss: 12601789012.10108|  2:47:14s
epoch 19 | loss: 14247293189.08304|  0:46:37



epoch 44 | loss: 4287017247.88447|  2:22:10s
epoch 92 | loss: 12755017735.3935|  2:56:04s
epoch 69 | loss: 12373517877.83393|  1:57:34s
epoch 74 | loss: 2889761491.63899|  2:03:02s
epoch 25 | loss: 12600352377.99277|  1:21:32s
epoch 22 | loss: 4858221791.65343|  1:05:13s
epoch 87 | loss: 2881272098.88809|  2:49:12s
epoch 21 | loss: 14144396986.68593|  0:47:29s
epoch 44 | loss: 13135378390.41155|  2:22:24s
epoch 70 | loss: 12452864933.4296|  1:57:54s
epoch 2  | loss: 18580550511.8267|  0:01:49s
epoch 75 | loss: 2944294084.62094|  2:03:23s
epoch 46 | loss: 12475617432.02888|  2:31:50s
epoch 0  | loss: 14605915113.81949|  0:00:39s
epoch 93 | loss: 12767319914.74369|  2:56:33s
epoch 26 | loss: 12618140306.94586|  1:22:02s
epoch 45 | loss: 4007495015.04693|  2:22:52s
epoch 23 | loss: 4811727834.1083|  1:05:45s
epoch 22 | loss: 14170308987.37907|  0:47:57s
epoch 88 | loss: 2960079902.0361|  2:49:40s
epoch 71 | loss: 12365892773.66065|  1:58:14s
epoch 76 | loss: 2877586725.8917|  2:03:45s
epo



epoch 31 | loss: 13832126394.22383|  1:22:39s
epoch 51 | loss: 3931956973.97834|  2:57:36s
epoch 83 | loss: 12197581828.15884|  2:32:52s
epoch 31 | loss: 4503511555.69675|  1:40:34s
epoch 97 | loss: 3003955975.16245|  3:24:32s
epoch 88 | loss: 3034253037.05415|  2:38:29s
epoch 84 | loss: 12190242310.70036|  2:33:08s
epoch 7  | loss: 5637472091.49458|  0:35:46s
epoch 35 | loss: 12194773996.59206|  1:57:06s
epoch 51 | loss: 12890557815.22022|  2:57:44s
epoch 9  | loss: 12987975539.52346|  0:37:04s
epoch 32 | loss: 13861566117.4296|  1:23:02s
epoch 89 | loss: 3228668008.8953|  2:38:45s
epoch 53 | loss: 12467389858.19495|  3:07:09s
epoch 32 | loss: 4592580662.06498|  1:41:00s
epoch 98 | loss: 2756341386.8592|  3:24:55s
epoch 85 | loss: 12226439701.94946|  2:33:24s
epoch 52 | loss: 3866638510.44043|  2:58:10s
epoch 0  | loss: 24166281382.3538|  0:00:39s
epoch 36 | loss: 12345834240.0|  1:57:31s
epoch 90 | loss: 3186138158.90253|  2:39:02s
epoch 33 | loss: 13825391381.71841|  1:23:26s
epoch 



epoch 54 | loss: 12894016036.50541|  3:02:49s




epoch 39 | loss: 12231172890.80145|  2:02:13s
epoch 12 | loss: 12857714152.8953|  0:42:12s
epoch 95 | loss: 2640588691.63899|  2:43:48s
epoch 91 | loss: 12208529081.76174|  2:38:22s
epoch 0  | loss: 22495616469.48737|  0:00:14s
epoch 56 | loss: 12443616492.59206|  3:12:17s
epoch 55 | loss: 3679326064.05776|  3:03:16s
epoch 37 | loss: 13603290355.06137|  1:28:21s
epoch 36 | loss: 4442031343.82672|  2:01:40s
epoch 96 | loss: 2658311089.67509|  2:59:30s
epoch 92 | loss: 12109948117.48736|  2:54:03s
epoch 40 | loss: 12234122002.02166|  2:18:04s
epoch 1  | loss: 21308839832.49097|  0:15:55s
epoch 11 | loss: 5416475501.05415|  0:56:48s
epoch 55 | loss: 12670342925.86282|  3:18:48s
epoch 3  | loss: 18189180112.86643|  0:21:25s
epoch 13 | loss: 12801876908.82311|  0:58:11s
epoch 93 | loss: 11979478542.09386|  2:54:18s
epoch 97 | loss: 2606640854.87365|  2:59:47s
epoch 38 | loss: 13583954980.50541|  1:44:10s
epoch 2  | loss: 19806539217.79062|  0:16:10s
epoch 37 | loss: 4309468294.93141|  2:02:



epoch 41 | loss: 13589721096.31768|  1:45:27s
epoch 5  | loss: 16171813301.6029|  0:22:51s
epoch 98 | loss: 11869801360.17329|  2:55:45s
epoch 59 | loss: 12466148922.68592|  3:29:36s
epoch 7  | loss: 14407270245.66065|  0:17:35s
epoch 44 | loss: 12210530482.36823|  2:19:48s
epoch 58 | loss: 3583196158.61372|  3:20:35s
epoch 0  | loss: 14698542471.8556|  0:00:19s
epoch 40 | loss: 4183046604.24549|  2:03:35s
epoch 14 | loss: 5204850863.59567|  0:58:36s
epoch 99 | loss: 11976006273.61733|  2:56:01s
epoch 42 | loss: 13386662307.58122|  1:45:50s
epoch 58 | loss: 12446655213.74729|  3:20:40s
epoch 8  | loss: 13820434924.59206|  0:17:51s
epoch 16 | loss: 12717864249.76173|  1:00:01s
epoch 45 | loss: 12175229887.30686|  2:20:14s
epoch 1  | loss: 13612809759.42238|  0:00:41s
epoch 6  | loss: 15561421656.722|  0:23:33s
epoch 9  | loss: 13474634485.37185|  0:18:11s
epoch 60 | loss: 12467580453.4296|  3:30:17s
epoch 41 | loss: 4216577907.52346|  2:04:07s
epoch 43 | loss: 13565272210.48376|  1:46:1



epoch 59 | loss: 12672841761.27074|  3:36:46s
epoch 46 | loss: 12086889578.74368|  2:36:08s
epoch 17 | loss: 12682848229.19856|  1:16:10s
epoch 44 | loss: 13362935485.45848|  2:02:10s
epoch 42 | loss: 4026353401.06859|  2:20:02s
epoch 11 | loss: 12875923460.15886|  0:34:11s
epoch 3  | loss: 11038228973.51624|  0:16:51s
epoch 0  | loss: 24249889252.27437|  0:00:18s
epoch 7  | loss: 15083977372.18773|  0:39:42s
epoch 61 | loss: 12481385896.66426|  3:46:20s
epoch 60 | loss: 3567864040.66426|  3:37:17s
epoch 47 | loss: 12017149022.26716|  2:36:31s
epoch 16 | loss: 5111282849.73285|  1:15:17s
epoch 12 | loss: 12757844711.04693|  0:34:26s
epoch 45 | loss: 13303514843.03248|  2:02:33s
epoch 60 | loss: 12462542267.14802|  3:37:20s
epoch 4  | loss: 9721944246.98917|  0:17:11s
epoch 1  | loss: 23138496868.73646|  0:00:37s
epoch 18 | loss: 12605069424.7509|  1:16:43s
epoch 43 | loss: 4130493750.52708|  2:20:28s
epoch 13 | loss: 12620361319.04694|  0:34:42s
epoch 48 | loss: 11878245497.76173|  2:3

### Подготовка данных без эмбедингов

In [12]:
scaler = StandardScaler()
num_df = pd.DataFrame(scaler.fit_transform(df[num_columns]), columns=num_columns)

label_columns = []
ohe_columns = []

for column in cat_columns:
    if df[column].nunique() > 10:
        label_columns.append(column)
    else:
        ohe_columns.append(column)

to_bool = list(df[cat_columns].select_dtypes(include=['bool']).columns)
df[['salary_gross', 'employer_accredited_it_employer']] = df[['salary_gross', 'employer_accredited_it_employer']].astype(bool).astype(int)
df[to_bool] = df[to_bool].astype(int)

ohe = OneHotEncoder(sparse_output=False, drop='first')
ohe_encoded = ohe.fit_transform(df[ohe_columns])
ohe_feature_names = ohe.get_feature_names_out(ohe_columns).tolist()
encoded_ohe_data = pd.DataFrame(ohe_encoded, columns=ohe_feature_names)

label_encoder = LabelEncoder()
for col in label_columns:
    df[col] = label_encoder.fit_transform(df[col])

X = pd.concat([df[label_columns], encoded_ohe_data, num_df], axis=1)
y = df['salary']

X_2_train, X_2_test_val, y_2_train, y_2_test_val, = train_test_split(X, y, test_size=0.4, random_state=12345)
X_2_test, X_2_val, y_2_test, y_2_val = train_test_split(X_2_test_val, y_2_test_val, test_size=0.5, random_state=12345)

print(f'Размеры выборок: Обучающая {X_2_train.shape}, Валидационная {X_2_test.shape}, Тестовая {X_2_val.shape}')

Размеры выборок: Обучающая (425714, 69), Валидационная (141905, 69), Тестовая (141905, 69)


### Полносвязная нейронная сеть без эмбедингов

In [None]:
results_2 = {}

architectures = {
    'small': [64, 32],
    'medium': [128, 64, 32],
    'large': [256, 128, 64, 32],
    'wide': [512, 256],
    'deep': [64, 64, 64, 64, 64]
}

for name, arch in architectures.items():
    print(f"\nTraining {name} architecture: {arch}")
    model, history, y_pred, train_time, metrics = build_and_train_model(
        arch, X_2_train, y_2_train, X_2_test, y_2_test
    )

    results_2[name] = {
        'architecture': arch,
        'train_time': train_time,
        'metrics': metrics,
        'epochs_trained': len(history.history['loss'])
    }

    print(f"Training time: {train_time:.2f}s")
    culc_metrics(y_2_test, y_pred)


Training small architecture: [64, 32]
[1m4435/4435[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 1ms/step
Training time: 702.78s
Корень из среднеквадратичной ошибки (RMSE): 71249.22691526295
R² Score: 0.029691776178032314
Средняя абсолютная ошибка (MAE): 41372.06422014223
Средняя абсолютная процентная ошибка (SMAPE): 45.38%
Медианная абсолютная ошибка (MedAE): 33547.609375

Training medium architecture: [128, 64, 32]
[1m4435/4435[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 1ms/step
Training time: 1053.81s
Корень из среднеквадратичной ошибки (RMSE): 70796.97617631152
R² Score: 0.041970643708062916
Средняя абсолютная ошибка (MAE): 38817.90827763207
Средняя абсолютная процентная ошибка (SMAPE): 42.65%
Медианная абсолютная ошибка (MedAE): 28399.53125

Training large architecture: [256, 128, 64, 32]
[1m4435/4435[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 2ms/step
Training time: 1138.94s
Корень из среднеквадратичной ошибки (RMSE): 75041.47628239205
R² Score: 

### TabNet без эмбедингов

In [None]:
X_train = X_2_train.to_numpy()
X_val = X_2_val.to_numpy()
X_test = X_2_test.to_numpy()

y_train = y_2_train.to_numpy().reshape(-1, 1)
y_val = y_2_val.to_numpy().reshape(-1, 1)
y_test = y_2_test.to_numpy().reshape(-1, 1)

class SMAPE(Metric):
    def __init__(self):
        self._name = "smape"
        self._maximize = False

    def __call__(self, y_true, y_pred):
        return 100 * np.mean(2 * np.abs(y_true - y_pred) / (np.abs(y_true) + np.abs(y_pred)))

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tabnet_params = {
    "n_d": 8,
    "n_a": 8,
    "n_steps": 3,
    "gamma": 1.3,
    "lambda_sparse": 1e-3,
    "optimizer_fn": torch.optim.Adam,
    "optimizer_params": dict(lr=2e-2),
    "mask_type": "sparsemax",
    "scheduler_params": dict(
        mode="min",
        patience=5,
        min_lr=1e-5,
        factor=0.9,
    ),
    "scheduler_fn": torch.optim.lr_scheduler.ReduceLROnPlateau,
    "seed": 42,
    "verbose": 10,
    "device_name": device
}

model = TabNetRegressor(**tabnet_params)

model.fit(
    X_train=X_train,
    y_train=y_train,
    eval_set=[(X_train, y_train), (X_val, y_val)],
    eval_name=['train', 'val'],
    eval_metric=['rmse', 'mae', SMAPE],
    max_epochs=50,
    patience=20,
    batch_size=512,
    virtual_batch_size=128,
    num_workers=0,
    drop_last=False,
    loss_fn=torch.nn.functional.mse_loss,
    pin_memory=False
)

y_pred = model.predict(X_test)

culc_metrics(y_test, y_pred)

epoch 0  | loss: 19872069664.19348| train_rmse: 136917.37655| train_mae: 79243.44531| train_smape: 154.26673| val_rmse: 110076.88833| val_mae: 79087.0 | val_smape: 154.32972|  0:00:58s
epoch 10 | loss: 11264595244.16502| train_rmse: 106016.77754| train_mae: 29198.63867| train_smape: 32.39938| val_rmse: 68291.12645| val_mae: 29137.13281| val_smape: 32.44966|  0:10:57s
epoch 20 | loss: 11220968340.13476| train_rmse: 112126.81278| train_mae: 33202.75781| train_smape: 36.70336| val_rmse: 78464.1925| val_mae: 33103.00391| val_smape: 36.7279 |  0:20:47s
epoch 30 | loss: 11206899306.1923| train_rmse: 114423.6865| train_mae: 38514.47266| train_smape: 41.57988| val_rmse: 80557.57692| val_mae: 38373.82812| val_smape: 41.61062|  0:30:39s

Early stopping occurred at epoch 30 with best_epoch = 10 and best_val_smape = 32.44966




Корень из среднеквадратичной ошибки (RMSE): 61374.38162621274
R² Score: 0.28001469373703003
Средняя абсолютная ошибка (MAE): 29087.1015625
Средняя абсолютная процентная ошибка (SMAPE): 32.53%
Медианная абсолютная ошибка (MedAE): 19507.56640625


## Выводы