In [9]:
# Utils imports
import pandas as pd
import numpy as np
import os
import joblib
import random

#Optimization imports
import optuna as opt

#Evaluation imports
from sklearn.metrics import accuracy_score, mean_absolute_error, mean_squared_error, r2_score


In [10]:
#Model imports
from xgboost import XGBRegressor
from pytorch_tabnet.tab_model import TabNetRegressor
from sklearn.linear_model import LinearRegression, Ridge, Lasso, LogisticRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.svm import SVR
from sklearn.gaussian_process import GaussianProcessRegressor

In [11]:
# Preparing data

df = pd.read_csv("../dataset/smogn_syn_data.csv")
tar_col = "LC50 [-LOG(mol/L)]"
Syn = "isSyn"
isSyn = df[Syn]
X = df.drop([tar_col,Syn], axis=1)
Y = df[tar_col]



# PARAMS 
NUM_TRIALS = 15

In [12]:
# Creating the dict for model and trial values
models = {
    "XGBRegressor" : {
        "model" : XGBRegressor,
        "param" : {
                "n_estimators":  'trial.suggest_categorical("xgb_est",[4500,5000])',
                "learning_rate": 'trial.suggest_categorical("xgb_lr",[0.01,3e-4,0.1])',
                "booster" : 'trial.suggest_categorical("xgb_booster",["gbtree","gblinear","dart"])',
                "tree_method" : 'trial.suggest_categorical("xgb_treemethod",["gpu_hist"])',
                "predictor" : 'trial.suggest_categorical("xgp_predictor",["gpu_predictor"])'
        }
    }
}   

In [18]:
# main objective function for optuna 
def train_main(X,Y,train_fold,fold):
    train_index = train_fold[fold]["train"]
    test_index = train_fold[fold]["test"]
    X_main = X.iloc[train_index, :].to_numpy(dtype=np.float64)
    Y_main = Y[train_index].to_numpy(dtype=np.float64)
    X_test = X.iloc[test_index, :].to_numpy(dtype=np.float64)
    Y_test = Y[test_index].to_numpy(dtype=np.float64)
    out_data = {}
    def objective(trial):
        # XG-Boost Setup
        reg = XGBRegressor(
                n_estimators = trial.suggest_categorical("xgb_est",[4500,5000]),
                learning_rate =  trial.suggest_categorical("xgb_lr",[0.01,3e-4,0.1]),
                booster  = trial.suggest_categorical("xgb_booster",["gblinear","dart"]),
                tree_method  =  trial.suggest_categorical("xgb_treemethod",["gpu_hist"]),
                predictor = trial.suggest_categorical("xgp_predictor",["gpu_predictor"]))
        reg.fit(X_main, Y_main, 
                eval_set = [(X_test,Y_test)],
                eval_metric = "rmse",
                verbose=True)
        Y_pred = reg.predict(X_test)
        error = mean_squared_error(Y_pred, Y_test, squared=False)
        return error
    study = opt.create_study(direction='minimize')
    study.optimize(objective, n_trials = NUM_TRIALS)
    best_params = study.best_params
    trial_data = trial.get_trials()
    reg_main= XGBRegressor(**best_params)
    reg.fit(X_main, Y_main)
    Y_pred_main = reg.predict(X_test)
    error_metrics_all = {
        "mse_error" : mean_squared_error(Y_pred_main,Y_test),
        "mae_error" : mean_absolute_error(Y_pred_main,Y_test),
        "rmse_error" : mean_squared_error(Y_pred_main,Y_test, squared=False),
        "r2_score" : r2_score(Y_pred_main,Y_test)}
        
    out_data[model_name] = { "best_params"  : best_params,
                             "trial_data" : trial_data,
                            "error_metric_all": error_metrics_all,
                            "model" : reg_main}
        
    return out_data                                 

In [19]:
train_fold = joblib.load("../exports/train_test_fold_data.z")
tot_fold = 15
fold_data = {}
for fold in range(tot_fold):
    model_out = train_main(X,Y, train_fold,fold)
    fold_data[fold] = model_out

[32m[I 2023-05-14 14:01:45,007][0m A new study created in memory with name: no-name-2897601a-beaf-4892-aacf-500931bc4880[0m


[0]	validation_0-rmse:3.74655
[1]	validation_0-rmse:3.71212
[2]	validation_0-rmse:3.67807
[3]	validation_0-rmse:3.64437
[4]	validation_0-rmse:3.61104
[5]	validation_0-rmse:3.57804




[6]	validation_0-rmse:3.54500
[7]	validation_0-rmse:3.51260
[8]	validation_0-rmse:3.48012
[9]	validation_0-rmse:3.44810
[10]	validation_0-rmse:3.41682
[11]	validation_0-rmse:3.38514
[12]	validation_0-rmse:3.35408
[13]	validation_0-rmse:3.32328
[14]	validation_0-rmse:3.29272
[15]	validation_0-rmse:3.26280
[16]	validation_0-rmse:3.23341
[17]	validation_0-rmse:3.20402
[18]	validation_0-rmse:3.17503
[19]	validation_0-rmse:3.14703
[20]	validation_0-rmse:3.11917
[21]	validation_0-rmse:3.09106
[22]	validation_0-rmse:3.06404
[23]	validation_0-rmse:3.03620
[24]	validation_0-rmse:3.00926
[25]	validation_0-rmse:2.98282
[26]	validation_0-rmse:2.95683
[27]	validation_0-rmse:2.93024
[28]	validation_0-rmse:2.90510
[29]	validation_0-rmse:2.87999
[30]	validation_0-rmse:2.85488
[31]	validation_0-rmse:2.82992
[32]	validation_0-rmse:2.80591
[33]	validation_0-rmse:2.78225
[34]	validation_0-rmse:2.75751
[35]	validation_0-rmse:2.73319
[36]	validation_0-rmse:2.70976
[37]	validation_0-rmse:2.68581
[38]	validat

[33m[W 2023-05-14 15:49:20,952][0m Trial 0 failed with parameters: {'xgb_est': 5000, 'xgb_lr': 0.01, 'xgb_booster': 'dart', 'xgb_treemethod': 'gpu_hist', 'xgp_predictor': 'gpu_predictor'} because of the following error: KeyboardInterrupt().[0m
Traceback (most recent call last):
  File "C:\Users\avyar\AppData\Local\Programs\Python\Python310\lib\site-packages\optuna\study\_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "C:\Users\avyar\AppData\Local\Temp\ipykernel_3188\3943374982.py", line 18, in objective
    reg.fit(X_main, Y_main,
  File "C:\Users\avyar\AppData\Local\Programs\Python\Python310\lib\site-packages\xgboost\core.py", line 620, in inner_f
    return func(**kwargs)
  File "C:\Users\avyar\AppData\Local\Programs\Python\Python310\lib\site-packages\xgboost\sklearn.py", line 1051, in fit
    self._Booster = train(
  File "C:\Users\avyar\AppData\Local\Programs\Python\Python310\lib\site-packages\xgboost\core.py", line 620, in inner_f
    return func

KeyboardInterrupt: 

0    (1002, 6) (1002,) (72, 6) (72,)




[0]	validation_0-rmse:3.78029
[1]	validation_0-rmse:3.77925
[2]	validation_0-rmse:3.77821
[3]	validation_0-rmse:3.77716
[4]	validation_0-rmse:3.77612
[5]	validation_0-rmse:3.77508
[6]	validation_0-rmse:3.77403
[7]	validation_0-rmse:3.77299
[8]	validation_0-rmse:3.77195
[9]	validation_0-rmse:3.77091
[10]	validation_0-rmse:3.76987
[11]	validation_0-rmse:3.76883
[12]	validation_0-rmse:3.76779
[13]	validation_0-rmse:3.76675
[14]	validation_0-rmse:3.76571
[15]	validation_0-rmse:3.76467
[16]	validation_0-rmse:3.76363
[17]	validation_0-rmse:3.76259
[18]	validation_0-rmse:3.76156
[19]	validation_0-rmse:3.76052
[20]	validation_0-rmse:3.75948
[21]	validation_0-rmse:3.75844
[22]	validation_0-rmse:3.75741
[23]	validation_0-rmse:3.75637
[24]	validation_0-rmse:3.75533
[25]	validation_0-rmse:3.75430
[26]	validation_0-rmse:3.75326
[27]	validation_0-rmse:3.75223
[28]	validation_0-rmse:3.75119
[29]	validation_0-rmse:3.75016
[30]	validation_0-rmse:3.74912
[31]	validation_0-rmse:3.74809
[32]	validation_0-



[36]	validation_0-rmse:4.01676
[37]	validation_0-rmse:4.01569
[38]	validation_0-rmse:4.01461
[39]	validation_0-rmse:4.01353
[40]	validation_0-rmse:4.01245
[41]	validation_0-rmse:4.01137
[42]	validation_0-rmse:4.01030
[43]	validation_0-rmse:4.00922
[44]	validation_0-rmse:4.00815
[45]	validation_0-rmse:4.00707
[46]	validation_0-rmse:4.00600
[47]	validation_0-rmse:4.00492
[48]	validation_0-rmse:4.00385
[49]	validation_0-rmse:4.00277
[50]	validation_0-rmse:4.00170
[51]	validation_0-rmse:4.00063
[52]	validation_0-rmse:3.99955
[53]	validation_0-rmse:3.99848
[54]	validation_0-rmse:3.99741
[55]	validation_0-rmse:3.99633
[56]	validation_0-rmse:3.99527
[57]	validation_0-rmse:3.99419
[58]	validation_0-rmse:3.99312
[59]	validation_0-rmse:3.99205
[60]	validation_0-rmse:3.99098
[61]	validation_0-rmse:3.98991
[62]	validation_0-rmse:3.98884
[63]	validation_0-rmse:3.98777
[64]	validation_0-rmse:3.98670
[65]	validation_0-rmse:3.98563
[66]	validation_0-rmse:3.98456
[67]	validation_0-rmse:3.98350
[68]	val



[33]	validation_0-rmse:3.98943
[34]	validation_0-rmse:3.98833
[35]	validation_0-rmse:3.98724
[36]	validation_0-rmse:3.98614
[37]	validation_0-rmse:3.98505
[38]	validation_0-rmse:3.98395
[39]	validation_0-rmse:3.98286
[40]	validation_0-rmse:3.98177
[41]	validation_0-rmse:3.98067
[42]	validation_0-rmse:3.97958
[43]	validation_0-rmse:3.97849
[44]	validation_0-rmse:3.97740
[45]	validation_0-rmse:3.97631
[46]	validation_0-rmse:3.97522
[47]	validation_0-rmse:3.97413
[48]	validation_0-rmse:3.97304
[49]	validation_0-rmse:3.97195
[50]	validation_0-rmse:3.97086
[51]	validation_0-rmse:3.96977
[52]	validation_0-rmse:3.96868
[53]	validation_0-rmse:3.96759
[54]	validation_0-rmse:3.96650
[55]	validation_0-rmse:3.96541
[56]	validation_0-rmse:3.96433
[57]	validation_0-rmse:3.96324
[58]	validation_0-rmse:3.96215
[59]	validation_0-rmse:3.96106
[60]	validation_0-rmse:3.95998
[61]	validation_0-rmse:3.95889
[62]	validation_0-rmse:3.95781
[63]	validation_0-rmse:3.95672
[64]	validation_0-rmse:3.95564
[65]	val



[34]	validation_0-rmse:3.70750
[35]	validation_0-rmse:3.70647
[36]	validation_0-rmse:3.70544
[37]	validation_0-rmse:3.70442
[38]	validation_0-rmse:3.70339
[39]	validation_0-rmse:3.70236
[40]	validation_0-rmse:3.70134
[41]	validation_0-rmse:3.70031
[42]	validation_0-rmse:3.69929
[43]	validation_0-rmse:3.69826
[44]	validation_0-rmse:3.69724
[45]	validation_0-rmse:3.69621
[46]	validation_0-rmse:3.69519
[47]	validation_0-rmse:3.69417
[48]	validation_0-rmse:3.69314
[49]	validation_0-rmse:3.69212
[50]	validation_0-rmse:3.69110
[51]	validation_0-rmse:3.69008
[52]	validation_0-rmse:3.68906
[53]	validation_0-rmse:3.68803
[54]	validation_0-rmse:3.68701
[55]	validation_0-rmse:3.68599
[56]	validation_0-rmse:3.68497
[57]	validation_0-rmse:3.68395
[58]	validation_0-rmse:3.68293
[59]	validation_0-rmse:3.68191
[60]	validation_0-rmse:3.68089
[61]	validation_0-rmse:3.67987
[62]	validation_0-rmse:3.67886
[63]	validation_0-rmse:3.67784
[64]	validation_0-rmse:3.67682
[65]	validation_0-rmse:3.67580
[66]	val



[34]	validation_0-rmse:3.72653
[35]	validation_0-rmse:3.72542
[36]	validation_0-rmse:3.72430
[37]	validation_0-rmse:3.72319
[38]	validation_0-rmse:3.72208
[39]	validation_0-rmse:3.72096
[40]	validation_0-rmse:3.71985
[41]	validation_0-rmse:3.71874
[42]	validation_0-rmse:3.71761
[43]	validation_0-rmse:3.71651
[44]	validation_0-rmse:3.71538
[45]	validation_0-rmse:3.71427
[46]	validation_0-rmse:3.71317
[47]	validation_0-rmse:3.71204
[48]	validation_0-rmse:3.71094
[49]	validation_0-rmse:3.70982
[50]	validation_0-rmse:3.70871
[51]	validation_0-rmse:3.70760
[52]	validation_0-rmse:3.70649
[53]	validation_0-rmse:3.70538
[54]	validation_0-rmse:3.70426
[55]	validation_0-rmse:3.70315
[56]	validation_0-rmse:3.70203
[57]	validation_0-rmse:3.70093
[58]	validation_0-rmse:3.69982
[59]	validation_0-rmse:3.69871
[60]	validation_0-rmse:3.69760
[61]	validation_0-rmse:3.69649
[62]	validation_0-rmse:3.69538
[63]	validation_0-rmse:3.69428
[64]	validation_0-rmse:3.69317
[65]	validation_0-rmse:3.69206
[66]	val



[25]	validation_0-rmse:4.06234
[26]	validation_0-rmse:4.06125
[27]	validation_0-rmse:4.06016
[28]	validation_0-rmse:4.05907
[29]	validation_0-rmse:4.05798
[30]	validation_0-rmse:4.05689
[31]	validation_0-rmse:4.05581
[32]	validation_0-rmse:4.05472
[33]	validation_0-rmse:4.05363
[34]	validation_0-rmse:4.05254
[35]	validation_0-rmse:4.05146
[36]	validation_0-rmse:4.05037
[37]	validation_0-rmse:4.04929
[38]	validation_0-rmse:4.04820
[39]	validation_0-rmse:4.04712
[40]	validation_0-rmse:4.04603
[41]	validation_0-rmse:4.04495
[42]	validation_0-rmse:4.04386
[43]	validation_0-rmse:4.04278
[44]	validation_0-rmse:4.04169
[45]	validation_0-rmse:4.04061
[46]	validation_0-rmse:4.03953
[47]	validation_0-rmse:4.03844
[48]	validation_0-rmse:4.03736
[49]	validation_0-rmse:4.03628
[50]	validation_0-rmse:4.03519
[51]	validation_0-rmse:4.03411
[52]	validation_0-rmse:4.03303
[53]	validation_0-rmse:4.03195
[54]	validation_0-rmse:4.03087
[55]	validation_0-rmse:4.02978
[56]	validation_0-rmse:4.02870
[57]	val



[35]	validation_0-rmse:3.84521
[36]	validation_0-rmse:3.84416
[37]	validation_0-rmse:3.84312
[38]	validation_0-rmse:3.84207
[39]	validation_0-rmse:3.84103
[40]	validation_0-rmse:3.83998
[41]	validation_0-rmse:3.83894
[42]	validation_0-rmse:3.83790
[43]	validation_0-rmse:3.83685
[44]	validation_0-rmse:3.83581
[45]	validation_0-rmse:3.83477
[46]	validation_0-rmse:3.83373
[47]	validation_0-rmse:3.83269
[48]	validation_0-rmse:3.83165
[49]	validation_0-rmse:3.83061
[50]	validation_0-rmse:3.82957
[51]	validation_0-rmse:3.82852
[52]	validation_0-rmse:3.82748
[53]	validation_0-rmse:3.82645
[54]	validation_0-rmse:3.82541
[55]	validation_0-rmse:3.82437
[56]	validation_0-rmse:3.82333
[57]	validation_0-rmse:3.82229
[58]	validation_0-rmse:3.82125
[59]	validation_0-rmse:3.82021
[60]	validation_0-rmse:3.81918
[61]	validation_0-rmse:3.81814
[62]	validation_0-rmse:3.81711
[63]	validation_0-rmse:3.81607
[64]	validation_0-rmse:3.81503
[65]	validation_0-rmse:3.81400
[66]	validation_0-rmse:3.81296
[67]	val



[33]	validation_0-rmse:4.00179
[34]	validation_0-rmse:4.00071
[35]	validation_0-rmse:3.99963
[36]	validation_0-rmse:3.99855
[37]	validation_0-rmse:3.99747
[38]	validation_0-rmse:3.99639
[39]	validation_0-rmse:3.99532
[40]	validation_0-rmse:3.99424
[41]	validation_0-rmse:3.99316
[42]	validation_0-rmse:3.99209
[43]	validation_0-rmse:3.99101
[44]	validation_0-rmse:3.98993
[45]	validation_0-rmse:3.98886
[46]	validation_0-rmse:3.98778
[47]	validation_0-rmse:3.98671
[48]	validation_0-rmse:3.98563
[49]	validation_0-rmse:3.98456
[50]	validation_0-rmse:3.98348
[51]	validation_0-rmse:3.98241
[52]	validation_0-rmse:3.98134
[53]	validation_0-rmse:3.98026
[54]	validation_0-rmse:3.97919
[55]	validation_0-rmse:3.97812
[56]	validation_0-rmse:3.97705
[57]	validation_0-rmse:3.97598
[58]	validation_0-rmse:3.97490
[59]	validation_0-rmse:3.97384
[60]	validation_0-rmse:3.97276
[61]	validation_0-rmse:3.97170
[62]	validation_0-rmse:3.97062
[63]	validation_0-rmse:3.96956
[64]	validation_0-rmse:3.96848
[65]	val



[30]	validation_0-rmse:3.97271
[31]	validation_0-rmse:3.97160
[32]	validation_0-rmse:3.97050
[33]	validation_0-rmse:3.96939
[34]	validation_0-rmse:3.96830
[35]	validation_0-rmse:3.96720
[36]	validation_0-rmse:3.96609
[37]	validation_0-rmse:3.96499
[38]	validation_0-rmse:3.96389
[39]	validation_0-rmse:3.96280
[40]	validation_0-rmse:3.96170
[41]	validation_0-rmse:3.96059
[42]	validation_0-rmse:3.95949
[43]	validation_0-rmse:3.95840
[44]	validation_0-rmse:3.95730
[45]	validation_0-rmse:3.95620
[46]	validation_0-rmse:3.95510
[47]	validation_0-rmse:3.95400
[48]	validation_0-rmse:3.95292
[49]	validation_0-rmse:3.95182
[50]	validation_0-rmse:3.95072
[51]	validation_0-rmse:3.94962
[52]	validation_0-rmse:3.94853
[53]	validation_0-rmse:3.94744
[54]	validation_0-rmse:3.94634
[55]	validation_0-rmse:3.94524
[56]	validation_0-rmse:3.94415
[57]	validation_0-rmse:3.94306
[58]	validation_0-rmse:3.94196
[59]	validation_0-rmse:3.94087
[60]	validation_0-rmse:3.93977
[61]	validation_0-rmse:3.93869
[62]	val