In [39]:
import deepchem as dc
from sklearn.ensemble import RandomForestRegressor
import optuna
optuna.logging.set_verbosity(optuna.logging.CRITICAL)

import pandas as pd
import numpy as np

In [58]:
def load_split_dfs(input_dir, n_splits=5):
    """Load the pickled per-split DataFrames stored under ``input_dir``.

    The files are expected to be named ``test0``, ``test1``, ... (no file
    extension), one per split.

    Parameters
    ----------
    input_dir : str or path-like
        Directory that contains the pickled objects.
    n_splits : int, default 5
        How many split files to read, starting at index 0. The default keeps
        the original five-split behaviour.

    Returns
    -------
    list
        The unpickled objects in split-index order.

    Notes
    -----
    ``pd.read_pickle`` can execute arbitrary code while unpickling, so only
    point this at files you created yourself or otherwise trust.
    """
    return [pd.read_pickle(f'{input_dir}/test{i}') for i in range(n_splits)]

def ds_from_df_split(split_dfs, featurizer):
    """Featurize every split DataFrame and normalize the labels.

    Parameters
    ----------
    split_dfs : sequence of pandas.DataFrame
        One frame per split. Each frame must expose ``smiles`` and ``label``
        columns. Any number of splits is accepted (the count is taken from
        the sequence itself rather than assumed to be five).
    featurizer : object
        Anything with a ``featurize(smiles)`` method whose result can be used
        as the ``X`` array of a ``DiskDataset``.

    Returns
    -------
    all_dss
        All splits merged into one dataset. The transformer is only applied
        to the per-split datasets, not to this merged one.
    split_dss : list
        One dataset per input frame, in input order, after
        ``transformer.transform``.
    transformer
        The ``NormalizationTransformer`` (``transform_y=True``) that was fit
        on ``all_dss``.
    """
    split_dss = []
    for df in split_dfs:
        X = featurizer.featurize(df.smiles)
        # Stack the label column into a 2-D array (one row per entry when the
        # labels are scalars or 1-D arrays).
        y = np.vstack(df.label.to_numpy())
        split_dss.append(dc.data.DiskDataset.from_numpy(X=X, y=y, ids=df.smiles))
    all_dss = dc.data.DiskDataset.merge(split_dss)

    # NOTE(review): the normalization statistics are fit on every split
    # combined, so splits later used as held-out test sets also contribute to
    # them -- confirm this is intended.
    transformer = dc.trans.NormalizationTransformer(transform_y=True, dataset=all_dss)
    split_dss = [transformer.transform(ds) for ds in split_dss]

    return all_dss, split_dss, transformer

def get_kfold_from_ds_split(split_dss):
    """Build leave-one-split-out ``(train, test)`` pairs.

    For every position ``i``, ``split_dss[i]`` becomes the test dataset and
    all remaining datasets are merged (in their original order) into the
    training dataset.

    Parameters
    ----------
    split_dss : sequence
        The per-split datasets. Any number of splits is accepted; the input
        sequence is not modified.

    Returns
    -------
    list of tuple
        ``(merged_train, test)`` pairs, one per split, in split order.
    """
    kfold = []
    for held_out, test_ds in enumerate(split_dss):
        # Everything except the held-out split goes into the training merge.
        train_parts = [ds for pos, ds in enumerate(split_dss) if pos != held_out]
        kfold.append((dc.data.DiskDataset.merge(train_parts), test_ds))
    return kfold

In [66]:
def random_forest_model_from_trial(trial):
    """Sample random-forest hyperparameters from ``trial`` and build the model.

    Five integer hyperparameters are drawn with ``trial.suggest_int`` (all on
    a log scale except ``max_depth``), passed to ``RandomForestRegressor``
    together with ``n_jobs=-1``, and the regressor is wrapped in a DeepChem
    ``SklearnModel``.

    Parameters
    ----------
    trial : object
        Provides ``suggest_int(name, low, high, log=...)`` (an Optuna trial).

    Returns
    -------
    dc.models.SklearnModel
        The wrapped, not-yet-fitted regressor.
    """
    # (name, low, high, sample on log scale?) -- drawn in exactly this order.
    search_space = (
        ('n_estimators', 1, 500, True),
        ('max_depth', 1, 50, False),
        ('min_samples_split', 2, 1000, True),
        ('min_samples_leaf', 1, 1000, True),
        ('max_leaf_nodes', 2, 100, True),
    )
    param = {}
    for name, low, high, use_log in search_space:
        param[name] = trial.suggest_int(name, low, high, log=use_log)

    return dc.models.SklearnModel(RandomForestRegressor(**param, n_jobs=-1))

def random_forest_model_from_param(param):
    """Build a DeepChem-wrapped random-forest regressor from fixed settings.

    Parameters
    ----------
    param : mapping
        Keyword arguments forwarded to ``RandomForestRegressor``.
        ``n_jobs=-1`` is always supplied as well, so ``param`` must not
        contain ``n_jobs`` (a duplicate keyword raises ``TypeError``).

    Returns
    -------
    dc.models.SklearnModel
        The wrapped, not-yet-fitted regressor.
    """
    return dc.models.SklearnModel(RandomForestRegressor(n_jobs=-1, **param))

def random_forest_optuna(trial, kfold):
    """Objective function: mean validation MSE of the trial's forest.

    For each fold a fresh model is built with
    ``random_forest_model_from_trial(trial)``, fit on a shuffled copy of the
    fold's training dataset, and scored on the fold's validation dataset.

    Parameters
    ----------
    trial : object
        Passed straight through to ``random_forest_model_from_trial``.
    kfold : iterable
        Items indexable as ``fold[0]`` (training dataset) and ``fold[1]``
        (validation dataset).

    Returns
    -------
    float
        Arithmetic mean of the per-fold mean squared errors.
    """
    fold_errors = []
    for fold in kfold:
        model = random_forest_model_from_trial(trial)
        model.fit(fold[0].complete_shuffle())

        valid_ds = fold[1]
        predicted = model.predict(valid_ds)
        fold_errors.append(dc.metrics.mean_squared_error(valid_ds.y, predicted))

    return sum(fold_errors) / len(fold_errors)

In [57]:
def get_feedforward_NN_model(trial):
    """Placeholder: not implemented yet.

    TODO: presumably meant to build a feed-forward neural-network model from
    an Optuna ``trial`` (by analogy with ``random_forest_model_from_trial``).
    Currently ignores ``trial`` and returns ``None``.
    """
    return None

def feedforward_NN_optuna(trial):
    """Placeholder: not implemented yet.

    TODO: presumably meant to be the Optuna objective for a feed-forward
    neural network (by analogy with ``random_forest_optuna``). Currently
    ignores ``trial`` and returns ``None``.
    """
    return None

In [3]:
# Read the pre-made split DataFrames (pickles named test0, test1, ...) from
# the 'sma1_random_split' directory, relative to the working directory.
split_dfs = load_split_dfs('sma1_random_split')

In [4]:
# Circular-fingerprint featurizer (radius 2, size 2048, chirality enabled);
# it is applied to the SMILES column of every split further down.
featurizer = dc.feat.CircularFingerprint(radius=2, size=2048, chiral=True)

In [5]:
# Featurize each split and normalize its labels. Returns the merged dataset,
# the list of per-split (transformed) datasets, and the fitted transformer.
all_dss, split_dss, transformer = ds_from_df_split(split_dfs, featurizer)

In [21]:
# Outer cross-validation pairs: each split is held out once as the test set
# while the remaining splits are merged into the training set.
train_tests = get_kfold_from_ds_split(split_dss)

In [63]:
# Nested cross-validation for the random forest.
#   outer loop : one (train, test) pair per held-out split in `train_tests`
#   inner loop : 5-fold CV on the outer training set, used by Optuna to pick
#                hyperparameters (minimizing mean validation MSE)
# Each row appended to `output_info` is
#   (split_index, best validation MSE, best params as str,
#    mean test MSE over the refits, list of per-refit test MSEs).
output_info = []

for split_index, tt in enumerate(train_tests):
    splitter = dc.splits.RandomSplitter()
    kfold = splitter.k_fold_split(dataset=tt[0], k=5)
    study = optuna.create_study(direction='minimize')
    study.optimize(lambda trial: random_forest_optuna(trial, kfold), n_trials=300)
    # For a quick smoke run, swap `n_trials=300` above for `timeout=30`.

    best_params = study.best_params

    # Refit the best configuration five times on the full outer training set
    # and score each refit on the held-out split; the forest is not seeded,
    # so the refits differ and their MSEs are averaged.
    # The loop variable must NOT reuse the outer index name: doing so made
    # every recorded split_index equal to 4.
    test_mse = []
    for repeat in range(5):
        tuned_rf_model = random_forest_model_from_param(best_params)
        tuned_rf_model.fit(tt[0].complete_shuffle())
        y_pred = tuned_rf_model.predict(tt[1])
        y_meas = tt[1].y
        test_mse.append(dc.metrics.mean_squared_error(y_meas, y_pred))

    output_info.append(
        (split_index, study.best_value, str(best_params), sum(test_mse) / len(test_mse), test_mse)
    )

In [65]:
# Tabulate the collected results, one row per entry of `output_info`:
# split index, best inner-CV validation MSE, best hyperparameters (as a
# string), mean test MSE, and the individual test MSEs behind that mean.
out_df = pd.DataFrame(output_info, columns=['split_index', 'avg_valid_mse', 'best_params', 'avg_test_mse', 'test_mses'])
# Bare last expression so the notebook renders the frame.
out_df

Unnamed: 0,split_index,avg_valid_mse,best_params,avg_test_mse,test_mses
0,4,0.992522,"{'n_estimators': 68, 'max_depth': 27, 'min_sam...",0.797814,"[0.8055473928327958, 0.7891457813126309, 0.790..."
1,4,0.988161,"{'n_estimators': 1, 'max_depth': 40, 'min_samp...",1.048322,"[1.05021917445322, 1.0481754097235654, 1.04977..."
2,4,0.967374,"{'n_estimators': 37, 'max_depth': 17, 'min_sam...",0.860698,"[0.8661830998618297, 0.8588125426985327, 0.846..."
3,4,0.994325,"{'n_estimators': 1, 'max_depth': 13, 'min_samp...",1.132869,"[1.1345582014018643, 1.124664530203219, 1.1758..."
4,4,1.035705,"{'n_estimators': 1, 'max_depth': 18, 'min_samp...",1.094103,"[1.1290717457640633, 1.0864907610822976, 1.110..."
