# Hyperparamter optimization using Scikit-learn and optuna

In [1]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm import tqdm

* drugs.csv 파일 안에는 실제 FDA의 승인을 받은 약과 약물이 아닌 non-drug 들이 존재한다. 

In [2]:
import pandas as pd
import numpy as np

* 데이터를 읽어들이자!

In [3]:
drugs = pd.read_csv("drugs.csv")

In [4]:
drugs

Unnamed: 0,smiles,is_drug
0,BrC1=CC2=C(NC(=O)CN=C2C2=CC=CC=N2)C=C1,1
1,C#CCN[C@@H]1CCC2=CC=CC=C12,1
2,C1(CC[C@@]2([C@@H](CC(N)=O)[C@@]3([C@@]4([N+]5...,1
3,C1C2CNCC1C1=C2C=C2N=CC=NC2=C1,1
4,C1CN2C[C@@H](N=C2S1)C1=CC=CC=C1,1
...,...,...
1594,CC1=CC=CC=C1NC(=O)C2=C(C=CC(=C2)NC(=O)C(C)N3CC...,0
1595,CC1=C(C(=CC(=C1)Br)C(=O)NC(=S)NC2=C(C(=CC=C2)[...,0
1596,CC(C)(C)OC(=O)N1CCOC(C(C1)C2=CC(=C(C=C2)Cl)Cl)CO,0
1597,CCCCCCCCCCCC=COC(=O)C1=C(C(=CC=C1)S(=O)(=O)[O-...,0


## 실습의 목적은 Drug과 non-drug을 구분하는 모델을 만들어보자! 

* 그 후에, 이 모델의 성능을 최적화 시키는 연습을 해보자!

* Descriptor로 변환하자. 

In [6]:
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.GraphDescriptors import (BalabanJ, BertzCT, Chi0, Chi0n, Chi0v, Chi1,
                                         Chi1n, Chi1v, Chi2n, Chi2v, Chi3n, Chi3v, Chi4n, Chi4v,
                                         HallKierAlpha, Ipc, Kappa1, Kappa2, Kappa3)

from rdkit.Chem.EState.EState_VSA import (EState_VSA1, EState_VSA10, EState_VSA11, EState_VSA2, EState_VSA3,
                                          EState_VSA4, EState_VSA5, EState_VSA6, EState_VSA7, EState_VSA8, EState_VSA9,
                                          VSA_EState1, VSA_EState10, VSA_EState2, VSA_EState3, VSA_EState4, VSA_EState5,
                                          VSA_EState6, VSA_EState7, VSA_EState8, VSA_EState9,)

from rdkit.Chem.Descriptors import (ExactMolWt, MolWt, HeavyAtomMolWt, MaxAbsPartialCharge, MinPartialCharge,
                                    MaxPartialCharge, MinAbsPartialCharge, NumRadicalElectrons, NumValenceElectrons)

from rdkit.Chem.EState.EState import (MaxAbsEStateIndex, MaxEStateIndex, MinAbsEStateIndex, MinEStateIndex,)

from rdkit.Chem.Lipinski import (FractionCSP3, HeavyAtomCount, NHOHCount, NOCount, NumAliphaticCarbocycles,
                                 NumAliphaticHeterocycles, NumAliphaticRings, NumAromaticCarbocycles, NumAromaticHeterocycles,
                                 NumAromaticRings, NumHAcceptors, NumHDonors, NumHeteroatoms, RingCount,
                                 NumRotatableBonds, NumSaturatedCarbocycles, NumSaturatedHeterocycles, NumSaturatedRings,)

from rdkit.Chem.Crippen import (MolLogP, MolMR, )

from rdkit.Chem.MolSurf import (LabuteASA, PEOE_VSA1, PEOE_VSA10, PEOE_VSA11, PEOE_VSA12, PEOE_VSA13, PEOE_VSA14,
                                PEOE_VSA2, PEOE_VSA3,PEOE_VSA4, PEOE_VSA5, PEOE_VSA6, PEOE_VSA7, PEOE_VSA8, PEOE_VSA9,
                                SMR_VSA1, SMR_VSA10, SMR_VSA2, SMR_VSA3, SMR_VSA4, SMR_VSA5, SMR_VSA6,
                                SMR_VSA7, SMR_VSA8, SMR_VSA9, SlogP_VSA1, SlogP_VSA10, SlogP_VSA11, SlogP_VSA12,
                                SlogP_VSA2, SlogP_VSA3,SlogP_VSA4, SlogP_VSA5, SlogP_VSA6, SlogP_VSA7, SlogP_VSA8,
                                SlogP_VSA9, TPSA, )

from rdkit.Chem.Fragments import (fr_Al_COO, fr_Al_OH, fr_Al_OH_noTert, fr_ArN, fr_Ar_COO, fr_Ar_N, fr_Ar_NH,
 fr_Ar_OH, fr_COO, fr_COO2, fr_C_O, fr_C_O_noCOO, fr_C_S, fr_HOCCN, fr_Imine, fr_NH0, fr_NH1,
 fr_NH2, fr_N_O, fr_Ndealkylation1, fr_Ndealkylation2, fr_Nhpyrrole, fr_SH, fr_aldehyde, fr_alkyl_carbamate,
 fr_alkyl_halide, fr_allylic_oxid, fr_amide, fr_amidine, fr_aniline, fr_aryl_methyl, fr_azide, fr_azo, fr_barbitur,
 fr_benzene, fr_benzodiazepine, fr_bicyclic, fr_diazo, fr_dihydropyridine, fr_epoxide, fr_ester, fr_ether, fr_furan,
 fr_guanido, fr_halogen, fr_hdrzine, fr_hdrzone, fr_imidazole, fr_imide, fr_isocyan, fr_isothiocyan, fr_ketone,
 fr_ketone_Topliss, fr_lactam, fr_lactone, fr_methoxy, fr_morpholine, fr_nitrile, fr_nitro, fr_nitro_arom,
 fr_nitro_arom_nonortho, fr_nitroso, fr_oxazole, fr_oxime, fr_para_hydroxylation, fr_phenol,
 fr_phenol_noOrthoHbond, fr_phos_acid, fr_phos_ester, fr_piperdine, fr_piperzine, fr_priamide, fr_prisulfonamd,
 fr_pyridine, fr_quatN, fr_sulfide, fr_sulfonamd, fr_sulfone, fr_term_acetylene, fr_tetrazole, fr_thiazole, fr_thiocyan,
 fr_thiophene, fr_unbrch_alkane, fr_urea)

# Descriptor 계산 수행 함수. 
def calc_descriptors(mol):
    if mol is None:
        print("Molecule is None!")
        return None
    else:
        AllChem.ComputeGasteigerCharges(mol)
        finger = [
            BalabanJ(mol) , # 0
            BertzCT(mol) , # 1
            Chi0(mol) , # 2
            Chi0n(mol) , # 3
            Chi0v(mol) , # 4
            Chi1(mol) , # 5
            Chi1n(mol) , # 6
            Chi1v(mol) , # 7
            Chi2n(mol) ,
            Chi2v(mol) ,
            Chi3n(mol) ,
            Chi3v(mol) ,
            Chi4n(mol) ,
            Chi4v(mol) ,
            EState_VSA1(mol) ,
            EState_VSA10(mol) ,
            EState_VSA11(mol) ,
            EState_VSA2(mol) ,
            EState_VSA3(mol) ,
            EState_VSA4(mol) ,
            EState_VSA5(mol) ,
            EState_VSA6(mol) ,
            EState_VSA7(mol) ,
            EState_VSA8(mol) ,
                EState_VSA9(mol) ,
                ExactMolWt(mol) ,
                FractionCSP3(mol) ,
                HallKierAlpha(mol) ,
                HeavyAtomCount(mol) ,
                HeavyAtomMolWt(mol) ,
                # Ipc(mol) ,
                Kappa1(mol) ,
                Kappa2(mol) ,
                Kappa3(mol) ,
                LabuteASA(mol) ,
                MaxAbsEStateIndex(mol) ,
                MaxAbsPartialCharge(mol) ,
                MaxEStateIndex(mol) ,
                MaxPartialCharge(mol) ,
                MinAbsEStateIndex(mol) ,
                MinAbsPartialCharge(mol) ,
                MinEStateIndex(mol) ,
                MinPartialCharge(mol) ,
                MolLogP(mol) ,
                MolMR(mol) ,
                MolWt(mol) ,
                NHOHCount(mol) ,
                NOCount(mol) ,
                NumAliphaticCarbocycles(mol) ,
                NumAliphaticHeterocycles(mol) ,
                NumAliphaticRings(mol) ,
                NumAromaticCarbocycles(mol) ,
                NumAromaticHeterocycles(mol) ,
                NumAromaticRings(mol) ,
                NumHAcceptors(mol) ,
                NumHDonors(mol) ,
                NumHeteroatoms(mol) ,
                NumRadicalElectrons(mol) ,
                NumRotatableBonds(mol) ,
                NumSaturatedCarbocycles(mol) ,
                NumSaturatedHeterocycles(mol) ,
                NumSaturatedRings(mol) ,
                NumValenceElectrons(mol) ,
                PEOE_VSA1(mol) ,
                PEOE_VSA10(mol) ,
                PEOE_VSA11(mol) ,
                PEOE_VSA12(mol) ,
                PEOE_VSA13(mol) ,
                PEOE_VSA14(mol) ,
                PEOE_VSA2(mol) ,
                PEOE_VSA3(mol) ,
                PEOE_VSA4(mol) ,
                PEOE_VSA5(mol) ,
                PEOE_VSA6(mol) ,
                PEOE_VSA7(mol) ,
                PEOE_VSA8(mol) ,
                PEOE_VSA9(mol) ,
                RingCount(mol) ,
                SMR_VSA1(mol) ,
                SMR_VSA10(mol) ,
                SMR_VSA2(mol) ,
                SMR_VSA3(mol) ,
                SMR_VSA4(mol) ,
                SMR_VSA5(mol) ,
                SMR_VSA6(mol) ,
                SMR_VSA7(mol) ,
                SMR_VSA8(mol) ,
                SMR_VSA9(mol) ,
                SlogP_VSA1(mol) ,
                SlogP_VSA10(mol) ,
                SlogP_VSA11(mol) ,
                SlogP_VSA12(mol) ,
                SlogP_VSA2(mol) ,
                SlogP_VSA3(mol) ,
                SlogP_VSA4(mol) ,
                SlogP_VSA5(mol) ,
                SlogP_VSA6(mol) ,
                SlogP_VSA7(mol) ,
                SlogP_VSA8(mol) ,
                SlogP_VSA9(mol) ,
                TPSA(mol) ,
                VSA_EState1(mol) ,
                VSA_EState10(mol) ,
                VSA_EState2(mol) ,
                VSA_EState3(mol) ,
                VSA_EState4(mol) ,
                VSA_EState5(mol) ,
                VSA_EState6(mol) ,
                VSA_EState7(mol) ,
                VSA_EState8(mol) ,
                VSA_EState9(mol) ,
                fr_Al_COO(mol) ,
                fr_Al_OH(mol) ,
                fr_Al_OH_noTert(mol) ,
                fr_ArN(mol) ,
                fr_Ar_COO(mol) ,
                fr_Ar_N(mol) ,
                fr_Ar_NH(mol) ,
                fr_Ar_OH(mol) ,
                fr_COO(mol) ,
                fr_COO2(mol) ,
                fr_C_O(mol) ,
                fr_C_O_noCOO(mol) ,
                fr_C_S(mol) ,
                fr_HOCCN(mol) ,
                fr_Imine(mol) ,
                fr_NH0(mol) ,
                fr_NH1(mol) ,
                fr_NH2(mol) ,
                fr_N_O(mol) ,
                fr_Ndealkylation1(mol) ,
                fr_Ndealkylation2(mol) ,
                fr_Nhpyrrole(mol) ,
                fr_SH(mol) ,
                fr_aldehyde(mol) ,
                fr_alkyl_carbamate(mol) ,
                fr_alkyl_halide(mol) ,
                fr_allylic_oxid(mol) ,
                fr_amide(mol) ,
                fr_amidine(mol) ,
                fr_aniline(mol) ,
                fr_aryl_methyl(mol) ,
                fr_azide(mol) ,
                fr_azo(mol) ,
                fr_barbitur(mol) ,
                fr_benzene(mol) ,
                fr_benzodiazepine(mol) ,
                fr_bicyclic(mol) ,
                fr_diazo(mol) ,
                fr_dihydropyridine(mol) ,
                fr_epoxide(mol) ,
                fr_ester(mol) ,
                fr_ether(mol) ,
                fr_furan(mol) ,
                fr_guanido(mol) ,
                fr_halogen(mol) ,
                fr_hdrzine(mol) ,
                fr_hdrzone(mol) ,
                fr_imidazole(mol) ,
                fr_imide(mol) ,
                fr_isocyan(mol) ,
                fr_isothiocyan(mol) ,
                fr_ketone(mol) ,
                fr_ketone_Topliss(mol) ,
                fr_lactam(mol) ,
                fr_lactone(mol) ,
                fr_methoxy(mol) ,
                fr_morpholine(mol) ,
                fr_nitrile(mol) ,
                fr_nitro(mol) ,
                fr_nitro_arom(mol) ,
                fr_nitro_arom_nonortho(mol) ,
                fr_nitroso(mol) ,
                fr_oxazole(mol) ,
                fr_oxime(mol) ,
                fr_para_hydroxylation(mol) ,
                fr_phenol(mol) ,
                fr_phenol_noOrthoHbond(mol) ,
                fr_phos_acid(mol) ,
                fr_phos_ester(mol) ,
                fr_piperdine(mol) ,
                fr_piperzine(mol) ,
                fr_priamide(mol) ,
                fr_prisulfonamd(mol) ,
                fr_pyridine(mol) ,
                fr_quatN(mol) ,
                fr_sulfide(mol) ,
                fr_sulfonamd(mol) ,
                fr_sulfone(mol) ,
                fr_term_acetylene(mol) ,
                fr_tetrazole(mol) ,
                fr_thiazole(mol) ,
                fr_thiocyan(mol) ,
                fr_thiophene(mol),
                fr_unbrch_alkane(mol) ,
                fr_urea(mol) , #rdkit properties # 196
                ]
        return finger

In [20]:
import math

desc_list = [] # descriptor 리스트 
is_drug = [] # True/False 저장할 리스트
for smi, flag in zip(drugs["smiles"], drugs["is_drug"]):
    m = Chem.MolFromSmiles(smi)
    if m is None: # smiles => mol 변환시 에러가 나면 즉, smiles가 문제가 있을 때, skip!
        continue
    desc = calc_descriptors(m)

    # NaN 존재하는지 확인 후, NaN이 있으면 skip!
    if any([math.isnan(v) for v in desc]):
        print(f"{smi} caused NaN")
        continue
        
    desc_list.append(desc)
    if flag == 1:
        is_drug.append(True)
    else:
        is_drug.append(False)


C1(CC[C@@]2([C@@H](CC(N)=O)[C@@]3([C@@]4([N+]5=C([C@H]([C@@]4(CC(N)=O)C)CCC(N)=O)C(C)=C4[N+]6=C(C=C7[N+]8=C([C@H](C7(C)C)CCC(N)=O)C(C)=C2N3[Co-3]568([N+]2=CN([C@H]3O[C@@H]([C@@H](OP(O[C@@H](CN1)C)([O-])=O)[C@H]3O)CO)C1=CC(C)=C(C=C21)C)C)[C@H]([C@@]4(CC(N)=O)C)CCC(N)=O)C)[H])C)=O caused NaN
CCP(CC)(CC)=[Au]S[C@@H]1O[C@H](COC(C)=O)[C@@H](OC(C)=O)[C@H](OC(C)=O)[C@H]1OC(C)=O caused NaN
C1=CC=C(C=C1)P2(C3=CC=CC=C3C(O2)C4=CC=C(C=C4)Br)(C5=CC=CC=C5)C6=CC=CC=C6 caused NaN
CC(C)(C)OC(=O)N1CC(=O)CC1C(C(CO)[Se]C2=CC=CC=C2)OCOCCOC caused NaN




CCCC[Se]C#N caused NaN


In [21]:
len(desc_list)

1594

In [22]:
len(is_drug)

1594

In [11]:
desc_list

[[2.1951789491913214,
  667.6462238772621,
  13.242276208189782,
  9.929198619986938,
  11.515195159101642,
  9.237183443017882,
  5.795009029901175,
  6.588007299458526,
  4.117894526245813,
  5.033570055037502,
  2.7991137707034355,
  3.2923654094204746,
  1.9689695238602891,
  2.3300547843177695,
  0.0,
  4.794537184071822,
  0.0,
  12.45193613526408,
  0.0,
  27.129170279832135,
  0.0,
  6.196843571613076,
  36.39820241076966,
  31.223115755538558,
  0.0,
  315.00072404,
  0.07142857142857142,
  -2.01,
  19,
  306.07800000000003,
  12.045931158579348,
  4.9322599810202385,
  2.4161294526180095,
  118.2713948713214,
  11.657272022234821,
  0.3238397436703982,
  11.657272022234821,
  0.24557142127850587,
  0.10933201058201036,
  0.24557142127850587,
  -0.11620039682539662,
  -0.3238397436703982,
  2.6336000000000004,
  77.51470000000002,
  316.158,
  1,
  4,
  0,
  1,
  1,
  1,
  1,
  2,
  3,
  1,
  5,
  0,
  1,
  0,
  0,
  0,
  94,
  5.316788604006331,
  6.544756405912575,
  0.0,
  

In [23]:
X = np.array(desc_list)

In [24]:
y = np.array(is_drug)

In [25]:
from sklearn.model_selection import train_test_split

In [26]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state=42)

## RandomForestClassifier를 한 번 테스트 해보자!

In [27]:
from sklearn.ensemble import RandomForestClassifier as RFC

In [28]:
my_model = RFC()

In [29]:
my_model.fit(X_train, y_train)

RandomForestClassifier()

In [30]:
y_pred = my_model.predict(X_test)

* Accuracy 계산

In [34]:
from sklearn.metrics import accuracy_score

In [35]:
accuracy_score(y_test, y_pred)

0.7398119122257053

In [37]:
my_model.get_params()

{'bootstrap': True,
 'ccp_alpha': 0.0,
 'class_weight': None,
 'criterion': 'gini',
 'max_depth': None,
 'max_features': 'auto',
 'max_leaf_nodes': None,
 'max_samples': None,
 'min_impurity_decrease': 0.0,
 'min_impurity_split': None,
 'min_samples_leaf': 1,
 'min_samples_split': 2,
 'min_weight_fraction_leaf': 0.0,
 'n_estimators': 100,
 'n_jobs': None,
 'oob_score': False,
 'random_state': None,
 'verbose': 0,
 'warm_start': False}

In [39]:
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

In [40]:
params = {"n_estimators":[100, 150], 
          "min_samples_split": [2, 3, 4], 
          "max_features":[0.3, 0.5, 0.7]}

* GridSearch를 수행한다는 것은 2 X 3 X 3의 18개의 parameter 조합을 테스트 하겠다. 
* Default metrics 은 classification의 경우, accuracy, regression의 경우, mean_squered_error
* metrics 은 다른 metrics을 사용 가능

In [42]:
clf = RFC(n_jobs = -1)
grid_search = GridSearchCV(clf, param_grid=params)

In [43]:
grid_search.fit(X, y)

GridSearchCV(estimator=RandomForestClassifier(n_jobs=-1),
             param_grid={'max_features': [0.3, 0.5, 0.7],
                         'min_samples_split': [2, 3, 4],
                         'n_estimators': [100, 150]})

* 실제 grid_search의 결과 확인을 위해서 cv_results_ 라고 하는 attribute를 출력. 

In [44]:
grid_search.cv_results_

{'mean_fit_time': array([1.21606536, 1.09780893, 0.82537427, 1.24706044, 0.93272696,
        0.99437327, 0.84112811, 1.18829589, 0.79524379, 1.19081059,
        0.88842535, 1.29439955, 1.10393324, 1.6148355 , 1.14509001,
        1.55198679, 1.0554172 , 1.72842197]),
 'std_fit_time': array([1.05101156, 0.08946257, 0.04745746, 0.14285106, 0.09943324,
        0.09815457, 0.03383912, 0.07974902, 0.00616749, 0.0477885 ,
        0.0560411 , 0.06847485, 0.06188498, 0.0573029 , 0.08377219,
        0.07477711, 0.08128355, 0.19194055]),
 'mean_score_time': array([0.03349261, 0.05481219, 0.03742018, 0.06687193, 0.04259548,
        0.04527297, 0.02185359, 0.04795184, 0.03077407, 0.0450212 ,
        0.03421807, 0.05007782, 0.03138342, 0.04729209, 0.03596978,
        0.04058838, 0.02771602, 0.03725071]),
 'std_score_time': array([0.00449981, 0.00793232, 0.00405409, 0.01209828, 0.00265462,
        0.00956476, 0.00147635, 0.00802208, 0.0048704 , 0.00433452,
        0.00258637, 0.00228931, 0.00326983, 

In [45]:
# Utility function to report best scores
def report(results, n_top=3):
    for i in range(1, n_top + 1):
        candidates = np.flatnonzero(results['rank_test_score'] == i)
        for candidate in candidates:
            print("Model with rank: {0}".format(i))
            print("Mean validation score: {0:.3f} (std: {1:.3f})"
                  .format(results['mean_test_score'][candidate],
                          results['std_test_score'][candidate]))
            print("Parameters: {0}".format(results['params'][candidate]))
            print("")

In [46]:
report(grid_search.cv_results_)

Model with rank: 1
Mean validation score: 0.683 (std: 0.031)
Parameters: {'max_features': 0.3, 'min_samples_split': 2, 'n_estimators': 150}

Model with rank: 2
Mean validation score: 0.674 (std: 0.038)
Parameters: {'max_features': 0.5, 'min_samples_split': 3, 'n_estimators': 150}

Model with rank: 3
Mean validation score: 0.673 (std: 0.018)
Parameters: {'max_features': 0.3, 'min_samples_split': 4, 'n_estimators': 100}



### RandomizedSearch를 테스트 

In [48]:
from time import time
n_iter = 5 # 5번만 테스트. 
start = time()
random_search = RandomizedSearchCV(clf, param_distributions = params, n_iter = n_iter)
random_search.fit(X, y)
print("RandomizedSearchCV took %.2f seconds for %d candidates"
      " parameter settings." % ((time() - start), n_iter))

RandomizedSearchCV took 46.06 seconds for 5 candidates parameter settings.


In [49]:
report(random_search.cv_results_)

Model with rank: 1
Mean validation score: 0.673 (std: 0.029)
Parameters: {'n_estimators': 150, 'min_samples_split': 3, 'max_features': 0.5}

Model with rank: 2
Mean validation score: 0.664 (std: 0.038)
Parameters: {'n_estimators': 150, 'min_samples_split': 3, 'max_features': 0.7}

Model with rank: 3
Mean validation score: 0.662 (std: 0.025)
Parameters: {'n_estimators': 150, 'min_samples_split': 4, 'max_features': 0.7}



## Optuna 라고 하는 hyperparameter package를 테스트

* https://optuna.org/
* 베이지안 최적화 방법을 통해서 단순 무작위가 아니라 기존의 결과에 기반하여 효율적으로 hyperparameter를 찾는다. 

* !pip install optuna 이 명령을 이용해서 설치 가능!

In [52]:
import optuna
import sklearn

def objective(trial): # trial은 optuna의 class이다. 
    
    # RandomForest 모델을 몇번 split할 것인가? 정수 값을 받는 파라미터. 
    # 정수로 정의되는 parameter는 suggest_int 라고 하는 method를 사용. 
    rf_max_depth = trial.suggest_int('rf_max_depth', 2, 32)

    # n_estimator의 범위를 지정. 
    n_tree = trial.suggest_int('n_estimators', 50, 200)
    
    # min_samples_split이라고 하는 parameter는 2~5 사이에서 찾도록 지정. 
    min_samples_split = trial.suggest_int('min_samples_split', 2, 5)
    
    # max_features 는 실수로 주어지는 feature.
    # 그러므로 suggest_float 이라고하는 method를 사용. 
    max_features = trial.suggest_float('max_features', 0.3, 0.9)
    
    ## 범위 지정이 끝나면 model을 정의. 
    regressor_obj = RFC(n_estimators = n_tree, 
                        min_samples_split = min_samples_split, 
                        max_depth=rf_max_depth, 
                        max_features = max_features, 
                        n_jobs=-1)
    # 모델 fitting. 
    regressor_obj.fit(X_train, y_train)

    # 모델을 validation set으로 검증. 
    y_pred = regressor_obj.predict(X_test)
    
    ## optuna에서는 목적함수 값을 최소화 하는 방향으로 최적화를 수행한다. 
    error = 1.0 - sklearn.metrics.accuracy_score(y_test, y_pred)
    
    return error

In [53]:
study = optuna.create_study()
study.optimize(objective, n_trials=10)

[32m[I 2021-08-28 13:08:00,497][0m A new study created in memory with name: no-name-08b1aa06-7969-49b4-8d63-a8e81ff56fb3[0m
[32m[I 2021-08-28 13:08:02,070][0m Trial 0 finished with value: 0.2163009404388715 and parameters: {'rf_max_depth': 12, 'n_estimators': 116, 'min_samples_split': 4, 'max_features': 0.8153336371765685}. Best is trial 0 with value: 0.2163009404388715.[0m
[32m[I 2021-08-28 13:08:02,888][0m Trial 1 finished with value: 0.24764890282131657 and parameters: {'rf_max_depth': 12, 'n_estimators': 69, 'min_samples_split': 5, 'max_features': 0.6921385477303101}. Best is trial 0 with value: 0.2163009404388715.[0m
[32m[I 2021-08-28 13:08:04,646][0m Trial 2 finished with value: 0.2288401253918495 and parameters: {'rf_max_depth': 30, 'n_estimators': 154, 'min_samples_split': 5, 'max_features': 0.6297415320236064}. Best is trial 0 with value: 0.2163009404388715.[0m
[32m[I 2021-08-28 13:08:04,973][0m Trial 3 finished with value: 0.23510971786833856 and parameters: {'r

In [54]:
study.best_params 

{'rf_max_depth': 16,
 'n_estimators': 70,
 'min_samples_split': 5,
 'max_features': 0.6511023977953498}

### optuna 예제 링크 

* https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/002_configurations.html#sphx-glr-tutorial-10-key-features-002-configurations-py