In [1]:
import os
import sys
import numpy as np
import pandas as pd
from copy import deepcopy
from tqdm import tqdm

import matplotlib.pyplot as plt
import plotly
import plotly.graph_objs as go
import plotly.express as px
from plotly.subplots import make_subplots
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

import torch
from kan import KAN

from tqdm import tqdm

In [2]:
from raw_data_processing import get_x, get_y, get_wavelength
from tools import JSON_Read, plotly_multi_scatter, get_all_sqz_input, KAN_es

In [3]:
SCRIPT_DIR = os.path.abspath('')

## Loading data

In [4]:
d_config = JSON_Read("", "json_config.txt")

EXCITE_WAVE_LENGTH = d_config['EXCITE_WAVE_LENGTH']
PREDICT_IONS = d_config['PREDICT_IONS']
SPEC_FOLDER = d_config['SPEC_FOLDER']

TRAIN_TEST_RATIO = d_config['TRAIN_TEST_RATIO']
VALIDATION_TRAIN_RATIO = d_config['VALIDATION_TRAIN_RATIO']
N_ITER_NO_CHANGE = d_config['N_ITER_NO_CHANGE']

HIDDEN_LAYER_SIZES = d_config['HIDDEN_LAYER_SIZES']
ACTIVATION = d_config['ACTIVATION']
SOLVER = d_config['SOLVER']

In [5]:
x = get_x(wave_length=EXCITE_WAVE_LENGTH, spec_file=""+SPEC_FOLDER)
y = get_y(l_ions=PREDICT_IONS, spec_file=""+SPEC_FOLDER)

# Squeeze input data

In [6]:
l_wavelenth = get_wavelength(spec_file=""+SPEC_FOLDER)

In [7]:
x_matrix, y_matrix = np.broadcast_to(l_wavelenth, (len(x), len(l_wavelenth))), x.to_numpy()

x_sqz = get_all_sqz_input(x_matrix, y_matrix)

In [10]:
def alg_KAN_es(x, y, seed = None,
               K=3, GRID = 3,
               lamb=0., lamb_l1=1., lamb_entropy=2.,
               steps=200, tol=0.001, n_iter_no_change=10):
    
    x_train, x_test, y_train, y_test = train_test_split(x, y, 
                                                        train_size=TRAIN_TEST_RATIO,
                                                        random_state=seed)

    x_val, x_train, y_val, y_train = train_test_split(x_train, y_train, 
                                                      train_size=VALIDATION_TRAIN_RATIO, 
                                                      random_state=seed)
    scaler = StandardScaler()
    x_train = scaler.fit_transform(x_train)
    x_val = scaler.transform(x_val)
    x_test = scaler.transform(x_test)

    tc_x_train = torch.from_numpy(x_train)
    tc_y_train = torch.from_numpy(y_train.reshape([-1,1]))
    tc_x_val = torch.from_numpy(x_val)
    tc_y_val = torch.from_numpy(y_val).reshape([-1,1])
    tc_x_test = torch.from_numpy(x_test)
    tc_y_test = torch.from_numpy(y_test).reshape([-1,1])

    dataset_3 = {'train_input': tc_x_train,
                 'train_label': tc_y_train,
                 'val_input': tc_x_val,
                 'val_label': tc_y_val,
                 'test_input': tc_x_test,
                 'test_label': tc_y_test}
    
    INPUT_SHAPE = tc_x_test.shape[1]

    model_es = KAN_es(width=[INPUT_SHAPE, 1, 1], grid=GRID, k=K, seed=seed)
    result_es = model_es.train_es(dataset_3, 
                                  tol=tol, 
                                  n_iter_no_change=n_iter_no_change,
                                  opt="LBFGS", steps=steps, 
                                  lamb=lamb,
                                  lamb_l1=lamb_l1,
                                  lamb_entropy=lamb_entropy
                                  )
    
    pred_test = model_es(dataset_3['test_input']).cpu().detach().numpy().ravel()
    rmse = mean_squared_error(y_test, pred_test)
    r2 = r2_score(y_test, pred_test)
    mae = mean_absolute_error(y_test, pred_test)

    return [rmse, r2, mae]

In [8]:
def alg_skl_model(x, y, class_model, model_kwargs, seed = None):
    
    x_train, x_test, y_train, y_test = train_test_split(x, y, 
                                                        train_size=TRAIN_TEST_RATIO,
                                                        random_state=seed)
    
    scaler = StandardScaler()
    x_train = scaler.fit_transform(x_train)
    x_test = scaler.transform(x_test)

    #print(model_kwargs)
    model = class_model(random_state=seed, **model_kwargs)
    model.fit(x_train, y_train)

    pred_test = model.predict(x_test)
    rmse = mean_squared_error(y_test, pred_test)
    r2 = r2_score(y_test, pred_test)
    mae = mean_absolute_error(y_test, pred_test)

    return [rmse, r2, mae]

In [9]:
MLP_model_kwargs = {'hidden_layer_sizes': 16,
                  'activation': ACTIVATION,
                  'solver': SOLVER,
                  'early_stopping': True,
                  'validation_fraction': VALIDATION_TRAIN_RATIO,
                  'n_iter_no_change': N_ITER_NO_CHANGE,
                  'learning_rate_init': 0.001,
                  'learning_rate': 'adaptive'}

GB_model_kwargs = {'validation_fraction': VALIDATION_TRAIN_RATIO,
                   'n_iter_no_change': N_ITER_NO_CHANGE}

RF_model_kwargs = {}

In [11]:
def multi_exp(l_algos_names,
              l_algos,
              mult_X_Y,
              l_kwargs,
              l_metrics_names,
              num_iter):
    ''' Function, that process algos(X, Y) and returns df of their metrics. 
    '''
    res_list = []

    for alg, (x, y), kwargs, alg_name in zip(l_algos, mult_X_Y, l_kwargs, l_algos_names):
        print(f'--- Processing {alg_name}')

        for i in range(1, num_iter+1):
            print(f'iter: {i}')
            #print(kwargs)
            l_metrics = alg(x, y, seed=i, **kwargs)
            res_list.append([alg_name, i]+l_metrics)
        print('-------')

    return pd.DataFrame(res_list, columns=['alg_name', 'iter']+l_metrics_names)

In [15]:
l_algos_names=['500_KAN', '500_MLP', '500_RF', '500_GB',
               '5_KAN', '5_MLP', '5_RF', '5_GB']

l_algos=[alg_KAN_es, alg_skl_model, alg_skl_model, alg_skl_model,
         alg_KAN_es, alg_skl_model, alg_skl_model, alg_skl_model]

mult_X_Y=[(x, y), (x, y), (x, y), (x, y), 
          (x_sqz, y), (x_sqz, y), (x_sqz, y), (x_sqz, y)]

l_kwargs=[{},
          {'class_model': MLPRegressor,'model_kwargs': MLP_model_kwargs},
          {'class_model': RandomForestRegressor,'model_kwargs': RF_model_kwargs},
          {'class_model': GradientBoostingRegressor,'model_kwargs': GB_model_kwargs},
          {},
          {'class_model': MLPRegressor,'model_kwargs': MLP_model_kwargs},
          {'class_model': RandomForestRegressor,'model_kwargs': RF_model_kwargs},
          {'class_model': GradientBoostingRegressor,'model_kwargs': GB_model_kwargs},]

l_metrics_names=['rmse', 'r2', 'mae']

num_iter=25

In [16]:
'''
l_algos_names=['500_MLP', '500_RF', '500_GB',
               '5_MLP', '5_RF', '5_GB']

l_algos=[alg_skl_model, alg_skl_model, alg_skl_model,
         alg_skl_model, alg_skl_model, alg_skl_model]

mult_X_Y=[(x, y), (x, y), (x, y), 
          (x_sqz, y), (x_sqz, y), (x_sqz, y)]

l_kwargs=[{'class_model': MLPRegressor,'model_kwargs': MLP_model_kwargs},
          {'class_model': RandomForestRegressor,'model_kwargs': RF_model_kwargs},
          {'class_model': GradientBoostingRegressor,'model_kwargs': GB_model_kwargs},
          {'class_model': MLPRegressor,'model_kwargs': MLP_model_kwargs},
          {'class_model': RandomForestRegressor,'model_kwargs': RF_model_kwargs},
          {'class_model': GradientBoostingRegressor,'model_kwargs': GB_model_kwargs},]

l_metrics_names=['rmse', 'r2', 'mae']

num_iter=3
'''

"\nl_algos_names=['500_MLP', '500_RF', '500_GB',\n               '5_MLP', '5_RF', '5_GB']\n\nl_algos=[alg_skl_model, alg_skl_model, alg_skl_model,\n         alg_skl_model, alg_skl_model, alg_skl_model]\n\nmult_X_Y=[(x, y), (x, y), (x, y), \n          (x_sqz, y), (x_sqz, y), (x_sqz, y)]\n\nl_kwargs=[{'class_model': MLPRegressor,'model_kwargs': MLP_model_kwargs},\n          {'class_model': RandomForestRegressor,'model_kwargs': RF_model_kwargs},\n          {'class_model': GradientBoostingRegressor,'model_kwargs': GB_model_kwargs},\n          {'class_model': MLPRegressor,'model_kwargs': MLP_model_kwargs},\n          {'class_model': RandomForestRegressor,'model_kwargs': RF_model_kwargs},\n          {'class_model': GradientBoostingRegressor,'model_kwargs': GB_model_kwargs},]\n\nl_metrics_names=['rmse', 'r2', 'mae']\n\nnum_iter=3\n"

In [17]:
full_df = multi_exp(l_algos_names=l_algos_names,
                    l_algos=l_algos,
                    mult_X_Y=mult_X_Y,
                    l_kwargs=l_kwargs,
                    l_metrics_names=l_metrics_names,
                    num_iter=num_iter)

--- Processing 500_KAN
iter: 1


trn_ls: 1.60e-01 | vl_ls: 4.37e-01 | e_stop: 10/10 | tst_ls: 4.73e-01 | reg: 3.72e+01 :   7%|▎   | 14/200 [02:10<28:47,  9.29s/it]


Early stopping criteria raised
iter: 2


trn_ls: 1.38e-01 | vl_ls: 4.45e-01 | e_stop: 10/10 | tst_ls: 4.79e-01 | reg: 3.50e+01 :   7%|▎   | 14/200 [02:06<28:00,  9.04s/it]


Early stopping criteria raised
iter: 3


trn_ls: 3.03e-01 | vl_ls: 3.87e-01 | e_stop: 10/10 | tst_ls: 4.39e-01 | reg: 4.42e+01 :  13%|▌   | 26/200 [03:52<25:53,  8.93s/it]


Early stopping criteria raised
iter: 4


trn_ls: 1.26e-01 | vl_ls: 4.89e-01 | e_stop: 10/10 | tst_ls: 4.86e-01 | reg: 4.71e+01 :   8%|▎   | 16/200 [02:24<27:40,  9.03s/it]


Early stopping criteria raised
iter: 5


trn_ls: 1.15e-01 | vl_ls: 4.77e-01 | e_stop: 10/10 | tst_ls: 4.20e-01 | reg: 3.84e+01 :   7%|▎   | 14/200 [02:07<28:12,  9.10s/it]


Early stopping criteria raised
iter: 6


trn_ls: 1.48e-01 | vl_ls: 4.83e-01 | e_stop: 10/10 | tst_ls: 4.57e-01 | reg: 3.74e+01 :   8%|▎   | 15/200 [02:25<29:55,  9.70s/it]


Early stopping criteria raised
iter: 7


trn_ls: 1.01e-01 | vl_ls: 5.39e-01 | e_stop: 10/10 | tst_ls: 4.90e-01 | reg: 3.26e+01 :   8%|▎   | 16/200 [02:34<29:34,  9.64s/it]


Early stopping criteria raised
iter: 8


trn_ls: 1.99e-01 | vl_ls: 4.23e-01 | e_stop: 10/10 | tst_ls: 5.61e-01 | reg: 3.89e+01 :  10%|▍   | 19/200 [03:02<28:55,  9.59s/it]


Early stopping criteria raised
iter: 9


trn_ls: 2.62e-01 | vl_ls: 4.06e-01 | e_stop: 10/10 | tst_ls: 4.86e-01 | reg: 3.65e+01 :   9%|▎   | 18/200 [02:54<29:22,  9.69s/it]


Early stopping criteria raised
iter: 10


trn_ls: 1.21e-01 | vl_ls: 4.57e-01 | e_stop: 10/10 | tst_ls: 4.66e-01 | reg: 3.51e+01 :   8%|▎   | 17/200 [02:44<29:30,  9.68s/it]


Early stopping criteria raised
iter: 11


trn_ls: 1.29e-01 | vl_ls: 4.67e-01 | e_stop: 10/10 | tst_ls: 3.54e-01 | reg: 4.07e+01 :   9%|▎   | 18/200 [02:49<28:29,  9.39s/it]


Early stopping criteria raised
iter: 12


trn_ls: 1.56e-01 | vl_ls: 4.70e-01 | e_stop: 10/10 | tst_ls: 4.81e-01 | reg: 3.78e+01 :   7%|▎   | 14/200 [02:13<29:38,  9.56s/it]


Early stopping criteria raised
iter: 13


trn_ls: 1.42e-01 | vl_ls: 4.69e-01 | e_stop: 10/10 | tst_ls: 4.03e-01 | reg: 4.34e+01 :   7%|▎   | 14/200 [02:16<30:19,  9.78s/it]


Early stopping criteria raised
iter: 14


trn_ls: 1.45e-01 | vl_ls: 4.68e-01 | e_stop: 10/10 | tst_ls: 4.45e-01 | reg: 3.76e+01 :   6%|▎   | 13/200 [02:04<29:47,  9.56s/it]


Early stopping criteria raised
iter: 15


trn_ls: 1.09e-01 | vl_ls: 4.74e-01 | e_stop: 10/10 | tst_ls: 4.94e-01 | reg: 3.48e+01 :   8%|▎   | 16/200 [02:33<29:29,  9.62s/it]


Early stopping criteria raised
iter: 16


trn_ls: 2.58e-01 | vl_ls: 4.53e-01 | e_stop: 10/10 | tst_ls: 4.35e-01 | reg: 4.01e+01 :  13%|▌   | 26/200 [04:06<27:29,  9.48s/it]


Early stopping criteria raised
iter: 17


trn_ls: 1.02e-01 | vl_ls: 4.47e-01 | e_stop: 10/10 | tst_ls: 4.52e-01 | reg: 3.35e+01 :   6%|▎   | 13/200 [02:03<29:35,  9.50s/it]


Early stopping criteria raised
iter: 18


trn_ls: 9.17e-02 | vl_ls: 5.09e-01 | e_stop: 10/10 | tst_ls: 4.36e-01 | reg: 3.20e+01 :   6%|▎   | 13/200 [02:04<29:48,  9.56s/it]


Early stopping criteria raised
iter: 19


trn_ls: 1.51e-01 | vl_ls: 4.69e-01 | e_stop: 10/10 | tst_ls: 3.99e-01 | reg: 3.12e+01 :  12%|▍   | 24/200 [03:46<27:43,  9.45s/it]


Early stopping criteria raised
iter: 20


trn_ls: 3.22e-01 | vl_ls: 4.37e-01 | e_stop: 10/10 | tst_ls: 4.22e-01 | reg: 1.23e+02 :  17%|▋   | 34/200 [05:14<25:34,  9.25s/it]


Early stopping criteria raised
iter: 21


trn_ls: 2.40e-01 | vl_ls: 4.61e-01 | e_stop: 10/10 | tst_ls: 4.53e-01 | reg: 5.58e+01 :  15%|▌   | 30/200 [04:31<25:36,  9.04s/it]


Early stopping criteria raised
iter: 22


trn_ls: 9.34e-02 | vl_ls: 4.99e-01 | e_stop: 10/10 | tst_ls: 4.92e-01 | reg: 3.83e+01 :   8%|▎   | 17/200 [02:27<26:27,  8.67s/it]


Early stopping criteria raised
iter: 23


trn_ls: 1.06e-01 | vl_ls: 5.08e-01 | e_stop: 10/10 | tst_ls: 4.76e-01 | reg: 3.24e+01 :   8%|▎   | 16/200 [02:26<28:06,  9.16s/it]


Early stopping criteria raised
iter: 24


trn_ls: 1.77e-01 | vl_ls: 5.20e-01 | e_stop: 10/10 | tst_ls: 4.71e-01 | reg: 3.53e+01 :   8%|▎   | 16/200 [02:35<29:47,  9.72s/it]


Early stopping criteria raised
iter: 25


trn_ls: 3.16e-01 | vl_ls: 4.79e-01 | e_stop: 10/10 | tst_ls: 7.66e-01 | reg: 3.65e+01 :   9%|▎   | 18/200 [02:54<29:27,  9.71s/it]


Early stopping criteria raised
-------
--- Processing 500_MLP
iter: 1




iter: 2




iter: 3




iter: 4




iter: 5




iter: 6




iter: 7




iter: 8




iter: 9




iter: 10




iter: 11




iter: 12




iter: 13




iter: 14




iter: 15




iter: 16




iter: 17




iter: 18




iter: 19




iter: 20




iter: 21




iter: 22




iter: 23




iter: 24




iter: 25




-------
--- Processing 500_RF
iter: 1
iter: 2
iter: 3
iter: 4
iter: 5
iter: 6
iter: 7
iter: 8
iter: 9
iter: 10
iter: 11
iter: 12
iter: 13
iter: 14
iter: 15
iter: 16
iter: 17
iter: 18
iter: 19
iter: 20
iter: 21
iter: 22
iter: 23
iter: 24
iter: 25
-------
--- Processing 500_GB
iter: 1
iter: 2
iter: 3
iter: 4
iter: 5
iter: 6
iter: 7
iter: 8
iter: 9
iter: 10
iter: 11
iter: 12
iter: 13
iter: 14
iter: 15
iter: 16
iter: 17
iter: 18
iter: 19
iter: 20
iter: 21
iter: 22
iter: 23
iter: 24
iter: 25
-------
--- Processing 5_KAN
iter: 1


trn_ls: 3.17e-01 | vl_ls: 3.54e-01 | e_stop: 10/10 | tst_ls: 3.66e-01 | reg: 5.84e+00 :   6%|▎   | 13/200 [00:05<01:14,  2.50it/s]


Early stopping criteria raised
iter: 2


trn_ls: 3.35e-01 | vl_ls: 3.03e-01 | e_stop: 10/10 | tst_ls: 2.96e-01 | reg: 6.57e+00 :  12%|▌   | 25/200 [00:09<01:09,  2.52it/s]


Early stopping criteria raised
iter: 3


trn_ls: 3.15e-01 | vl_ls: 3.50e-01 | e_stop: 10/10 | tst_ls: 3.61e-01 | reg: 6.16e+00 :   9%|▎   | 18/200 [00:07<01:13,  2.47it/s]


Early stopping criteria raised
iter: 4


trn_ls: 3.10e-01 | vl_ls: 3.65e-01 | e_stop: 10/10 | tst_ls: 3.78e-01 | reg: 5.74e+00 :   8%|▎   | 17/200 [00:06<01:13,  2.50it/s]


Early stopping criteria raised
iter: 5


trn_ls: 3.32e-01 | vl_ls: 3.13e-01 | e_stop: 10/10 | tst_ls: 3.07e-01 | reg: 6.07e+00 :  18%|▋   | 36/200 [00:14<01:08,  2.41it/s]


Early stopping criteria raised
iter: 6


trn_ls: 3.26e-01 | vl_ls: 3.32e-01 | e_stop: 10/10 | tst_ls: 3.31e-01 | reg: 6.27e+00 :   8%|▎   | 17/200 [00:06<01:14,  2.45it/s]


Early stopping criteria raised
iter: 7


trn_ls: 3.24e-01 | vl_ls: 3.07e-01 | e_stop: 10/10 | tst_ls: 3.85e-01 | reg: 5.87e+00 :   8%|▎   | 17/200 [00:06<01:14,  2.47it/s]


Early stopping criteria raised
iter: 8


trn_ls: 3.11e-01 | vl_ls: 3.47e-01 | e_stop: 10/10 | tst_ls: 4.10e-01 | reg: 7.00e+00 :  10%|▍   | 19/200 [00:07<01:14,  2.44it/s]


Early stopping criteria raised
iter: 9


trn_ls: 3.19e-01 | vl_ls: 3.05e-01 | e_stop: 10/10 | tst_ls: 4.05e-01 | reg: 6.19e+00 :   8%|▎   | 17/200 [00:06<01:15,  2.43it/s]


Early stopping criteria raised
iter: 10


trn_ls: 3.22e-01 | vl_ls: 3.26e-01 | e_stop: 10/10 | tst_ls: 3.57e-01 | reg: 6.23e+00 :  10%|▍   | 20/200 [00:07<01:11,  2.53it/s]


Early stopping criteria raised
iter: 11


trn_ls: 3.10e-01 | vl_ls: 3.71e-01 | e_stop: 10/10 | tst_ls: 3.51e-01 | reg: 6.03e+00 :   7%|▎   | 14/200 [00:05<01:15,  2.45it/s]


Early stopping criteria raised
iter: 12


trn_ls: 3.18e-01 | vl_ls: 3.64e-01 | e_stop: 10/10 | tst_ls: 3.12e-01 | reg: 5.93e+00 :   7%|▎   | 14/200 [00:05<01:16,  2.44it/s]


Early stopping criteria raised
iter: 13


trn_ls: 3.12e-01 | vl_ls: 3.54e-01 | e_stop: 10/10 | tst_ls: 3.64e-01 | reg: 6.38e+00 :   8%|▎   | 17/200 [00:06<01:14,  2.47it/s]


Early stopping criteria raised
iter: 14


trn_ls: 3.29e-01 | vl_ls: 3.36e-01 | e_stop: 10/10 | tst_ls: 2.94e-01 | reg: 6.45e+00 :  12%|▍   | 23/200 [00:09<01:12,  2.44it/s]


Early stopping criteria raised
iter: 15


trn_ls: 3.24e-01 | vl_ls: 3.43e-01 | e_stop: 10/10 | tst_ls: 3.15e-01 | reg: 6.08e+00 :   7%|▎   | 14/200 [00:05<01:15,  2.47it/s]


Early stopping criteria raised
iter: 16


trn_ls: 3.09e-01 | vl_ls: 3.66e-01 | e_stop: 10/10 | tst_ls: 3.59e-01 | reg: 6.44e+00 :  10%|▍   | 20/200 [00:07<01:10,  2.56it/s]


Early stopping criteria raised
iter: 17


trn_ls: 3.16e-01 | vl_ls: 3.36e-01 | e_stop: 10/10 | tst_ls: 3.80e-01 | reg: 5.94e+00 :   8%|▎   | 17/200 [00:06<01:13,  2.48it/s]


Early stopping criteria raised
iter: 18


trn_ls: 3.23e-01 | vl_ls: 3.47e-01 | e_stop: 10/10 | tst_ls: 3.24e-01 | reg: 5.99e+00 :   7%|▎   | 14/200 [00:05<01:15,  2.45it/s]


Early stopping criteria raised
iter: 19


trn_ls: 3.19e-01 | vl_ls: 3.67e-01 | e_stop: 10/10 | tst_ls: 3.10e-01 | reg: 7.40e+00 :  11%|▍   | 22/200 [00:09<01:13,  2.42it/s]


Early stopping criteria raised
iter: 20


trn_ls: 3.20e-01 | vl_ls: 3.44e-01 | e_stop: 10/10 | tst_ls: 3.50e-01 | reg: 5.99e+00 :   7%|▎   | 14/200 [00:05<01:15,  2.47it/s]


Early stopping criteria raised
iter: 21


trn_ls: 3.17e-01 | vl_ls: 3.79e-01 | e_stop: 10/10 | tst_ls: 2.72e-01 | reg: 6.19e+00 :   8%|▎   | 16/200 [00:06<01:14,  2.46it/s]


Early stopping criteria raised
iter: 22


trn_ls: 3.19e-01 | vl_ls: 3.50e-01 | e_stop: 10/10 | tst_ls: 3.27e-01 | reg: 5.95e+00 :  10%|▍   | 21/200 [00:08<01:10,  2.53it/s]


Early stopping criteria raised
iter: 23


trn_ls: 3.26e-01 | vl_ls: 3.30e-01 | e_stop: 10/10 | tst_ls: 3.21e-01 | reg: 6.37e+00 :   9%|▎   | 18/200 [00:07<01:13,  2.47it/s]


Early stopping criteria raised
iter: 24


trn_ls: 3.15e-01 | vl_ls: 3.78e-01 | e_stop: 10/10 | tst_ls: 3.38e-01 | reg: 6.32e+00 :   8%|▎   | 16/200 [00:06<01:20,  2.30it/s]


Early stopping criteria raised
iter: 25


trn_ls: 3.19e-01 | vl_ls: 3.57e-01 | e_stop: 10/10 | tst_ls: 3.08e-01 | reg: 6.26e+00 :  11%|▍   | 22/200 [00:08<01:10,  2.52it/s]


Early stopping criteria raised
-------
--- Processing 5_MLP
iter: 1




iter: 2




iter: 3




iter: 4




iter: 5




iter: 6




iter: 7




iter: 8




iter: 9




iter: 10




iter: 11




iter: 12




iter: 13




iter: 14




iter: 15




iter: 16




iter: 17




iter: 18




iter: 19




iter: 20




iter: 21




iter: 22




iter: 23




iter: 24




iter: 25




-------
--- Processing 5_RF
iter: 1
iter: 2
iter: 3
iter: 4
iter: 5
iter: 6
iter: 7
iter: 8
iter: 9
iter: 10
iter: 11
iter: 12
iter: 13
iter: 14
iter: 15
iter: 16
iter: 17
iter: 18
iter: 19
iter: 20
iter: 21
iter: 22
iter: 23
iter: 24
iter: 25
-------
--- Processing 5_GB
iter: 1
iter: 2
iter: 3
iter: 4
iter: 5
iter: 6
iter: 7
iter: 8
iter: 9
iter: 10
iter: 11
iter: 12
iter: 13
iter: 14
iter: 15
iter: 16
iter: 17
iter: 18
iter: 19
iter: 20
iter: 21
iter: 22
iter: 23
iter: 24
iter: 25
-------


In [18]:
full_df

Unnamed: 0,alg_name,iter,rmse,r2,mae
0,500_KAN,1,0.170453,0.923544,0.308077
1,500_KAN,2,0.188008,0.924614,0.315091
2,500_KAN,3,0.195467,0.921781,0.332320
3,500_KAN,4,0.132584,0.949010,0.281474
4,500_KAN,5,0.144570,0.939638,0.275084
...,...,...,...,...,...
195,5_GB,21,0.122070,0.951131,0.272301
196,5_GB,22,0.127524,0.951078,0.266599
197,5_GB,23,0.163647,0.932081,0.295245
198,5_GB,24,0.181779,0.918202,0.316071


In [19]:
full_df.to_excel('full_metrics.xlsx')
#pd.read_excel('full_metrics.xlsx').drop('Unnamed: 0', axis=1)

In [20]:
aggr_df = full_df.groupby(['alg_name']).agg(["mean", "std"]).drop(['iter'], axis=1)
aggr_df.to_excel('aggr_metrics.xlsx')
aggr_df

Unnamed: 0_level_0,rmse,rmse,r2,r2,mae,mae
Unnamed: 0_level_1,mean,std,mean,std,mean,std
alg_name,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
500_GB,0.229024,0.03587,0.906359,0.015556,0.366623,0.028747
500_KAN,0.181186,0.02991,0.925875,0.013506,0.319783,0.02828
500_MLP,1.515338,0.394456,0.380907,0.166678,0.981519,0.123861
500_RF,0.223663,0.038748,0.908481,0.017274,0.355509,0.027921
5_GB,0.169189,0.032629,0.930912,0.013644,0.300406,0.019115
5_KAN,0.118371,0.024763,0.951499,0.011249,0.251473,0.025528
5_MLP,0.256118,0.03825,0.895525,0.015454,0.403326,0.035079
5_RF,0.18803,0.041634,0.923181,0.017504,0.292365,0.024603
