In [1]:
# Numpy and Pandas
import numpy as np
import pandas as pd

# Pre-processing and setup functions
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from tensorflow.keras.utils import to_categorical

# Algorithms
from sklearn.ensemble import RandomForestClassifier

# Report and model validation
from sklearn.metrics import classification_report

# Model persistence
from joblib import dump, load

pd.set_option('display.max_columns', 500)

In [2]:
exo2 = pd.read_csv('transformed_features.csv')

In [3]:
exo2.sample(10)

Unnamed: 0.1,Unnamed: 0,koi_disposition,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag,log_koi_period,log_koi_period_err1,log_koi_time0bk,log_koi_time0bk_err1,log_koi_impact,log_koi_impact_err1,log_koi_duration,log_koi_duration_err1,log_koi_depth,log_koi_depth_err1,log_koi_prad,log_koi_prad_err1,log_koi_teq,log_koi_insol,log_koi_insol_err1,log_koi_model_snr,log_koi_tce_plnt_num,log_koi_steff,log_koi_steff_err1,log_koi_slogg,log_koi_slogg_err1,log_koi_srad,log_koi_srad_err1,log_koi_srad_err2,log_ra,log_dec,log_koi_kepmag
3172,3172,CONFIRMED,0,0,0,0,3.917947,2.7e-05,-2.7e-05,132.86384,0.00497,-0.00497,0.337,0.126,-0.337,2.549,0.163,-0.163,116.5,10.1,-10.1,0.94,0.11,-0.04,1030,266.81,86.69,-35.37,13.5,1,5500,74,-82,4.547,0.019,-0.11,0.87,0.105,-0.035,291.25027,45.655071,14.75,1.365568,-10.530847,4.889325,-5.304335,-1.084709,-2.063568,0.935701,-1.807889,4.757891,2.312634,-0.061875,-2.198225,6.937314,5.586537,4.46235,2.60269,0.0,8.612503,4.304079,1.514468,-3.912023,-0.139262,-2.244316,-3.381395,5.674183,3.821115,2.691243
1282,1282,CANDIDATE,0,0,0,0,8.847843,3.4e-05,-3.4e-05,140.15085,0.00286,-0.00286,0.047,0.411,-0.047,2.0572,0.0995,-0.0995,460.2,25.8,-25.8,1.15,0.07,-0.07,479,12.48,2.69,-2.38,21.1,1,3867,77,-84,4.718,0.027,-0.033,0.552,0.033,-0.033,285.08655,40.898998,14.911,2.180174,-10.298013,4.942719,-5.856934,-3.036554,-0.886732,0.721346,-2.297598,6.131661,3.250413,0.139762,-2.645075,6.171701,2.524127,0.989913,3.049273,0.0,8.260234,4.343818,1.551385,-3.575551,-0.594207,-3.381395,-3.442019,5.652793,3.711106,2.702099
4375,4375,FALSE POSITIVE,1,0,0,0,438.81524,0.03031,-0.03031,307.9493,0.0267,-0.0267,0.108,0.3367,-0.1079,6.75,1.0,-1.0,332.8,55.4,-55.4,1.36,0.34,-0.13,202,0.39,0.31,-0.1,9.2,1,5324,173,-137,4.551,0.057,-0.153,0.755,0.189,-0.071,292.41116,42.039478,13.589,6.084078,-3.496278,5.729935,-3.623092,-2.216407,-1.085597,1.909543,0.001,5.807542,4.014598,0.307485,-1.075873,5.308268,-0.941609,-1.167962,2.219203,0.0,8.57998,5.153297,1.515347,-2.847312,-0.281038,-1.660731,-2.65926,5.678161,3.738609,2.609261
5079,5079,CANDIDATE,0,0,0,0,485.92297,0.01167,-0.01167,299.5015,0.015,-0.015,0.4516,0.0439,-0.4516,3.756,0.819,-0.819,122.9,27.1,-27.1,1.07,0.47,-0.16,225,0.61,0.79,-0.21,5.5,3,5551,159,-131,4.386,0.143,-0.273,0.957,0.412,-0.147,285.8627,39.20528,12.12,6.18605,-4.450734,5.702119,-4.199705,-0.792747,-3.103317,1.323355,-0.198451,4.811371,3.299571,0.067659,-0.752897,5.4161,-0.494296,-0.234457,1.704748,1.098612,8.621733,5.06891,1.478418,-1.937942,-0.043952,-0.884308,-1.924149,5.655512,3.668811,2.494857
1355,1355,CONFIRMED,0,0,0,0,8.508333,1e-05,-1e-05,134.785645,0.000965,-0.000965,0.004,0.413,-0.004,3.4559,0.0353,-0.0353,646.5,8.2,-8.2,2.6,0.37,-0.28,909,161.66,65.86,-42.15,93.4,1,5767,104,-115,4.396,0.09,-0.11,1.03,0.149,-0.108,295.95135,43.852051,14.176,2.141046,-11.542354,4.903686,-6.943382,-5.298317,-0.881889,1.240083,-3.315938,6.471573,2.104256,0.955511,-0.991553,6.812345,5.085495,4.187546,4.536891,0.0,8.659907,4.644401,1.480695,-2.396896,0.029559,-1.89712,-2.234926,5.690195,3.780821,2.65155
2017,2017,CONFIRMED,0,0,0,0,13.781095,1e-05,-1e-05,177.852785,0.00057,-0.00057,0.545,0.08,-0.406,2.3188,0.0267,-0.0267,756.0,7.8,-7.8,1.42,0.09,-0.13,402,6.17,1.39,-1.55,112.8,1,3846,77,-84,4.738,0.052,-0.024,0.503,0.032,-0.048,290.3815,43.292988,12.925,2.623298,-11.483367,5.180956,-7.469874,-0.605136,-2.513306,0.84105,-3.586323,6.628041,2.054252,0.350657,-2.396896,5.996452,1.819699,0.330023,4.725616,0.0,8.254789,4.343818,1.555615,-2.937463,-0.687165,-3.411248,-3.057608,5.671196,3.767991,2.559163
2615,2615,CANDIDATE,0,0,0,0,4.501582,3.4e-05,-3.4e-05,134.1153,0.00569,-0.00569,0.016,0.442,-0.016,1.855,0.158,-0.158,175.7,21.5,-21.5,0.73,0.04,-0.03,676,49.21,9.73,-7.37,9.7,2,4339,78,-87,4.711,0.02,-0.032,0.553,0.031,-0.022,284.56689,49.01255,15.142,1.504429,-10.286213,4.8987,-5.169045,-4.074542,-0.814186,0.617885,-1.838851,5.168778,3.068099,-0.314711,-3.194183,6.516193,3.896097,2.275317,2.272126,0.693147,8.375399,4.356722,1.5499,-3.863233,-0.592397,-3.442019,-3.863233,5.650968,3.892076,2.717472
5367,5367,FALSE POSITIVE,0,1,1,1,0.645851,2e-06,-2e-06,131.88911,0.0032,-0.0032,0.476,0.453,-0.273,1.2082,0.0931,-0.0931,54.9,4.0,-4.0,0.69,0.15,-0.11,1941,3366.81,2330.27,-1190.35,16.8,1,5381,159,-143,4.426,0.144,-0.192,0.911,0.207,-0.138,285.00354,45.887951,14.107,-0.437187,-13.031609,4.881961,-5.744604,-0.740239,-0.789658,0.189132,-2.363397,4.005513,1.386544,-0.371064,-1.890475,7.570959,8.121721,7.75374,2.821379,0.0,8.59063,5.06891,1.487496,-1.931022,-0.093212,-1.570217,-1.987774,5.652502,3.826203,2.646671
590,590,CONFIRMED,0,0,0,0,11.419261,2.3e-05,-2.3e-05,137.66894,0.00166,-0.00166,0.242,0.182,-0.242,3.9908,0.0545,-0.0545,915.6,18.1,-18.1,1.48,0.16,-0.23,419,7.32,3.38,-2.98,58.0,1,3823,167,-184,4.765,0.077,-0.056,0.492,0.054,-0.074,286.07523,44.664639,15.314,2.435302,-10.671358,4.924852,-6.400938,-1.414694,-1.698269,1.383992,-2.891372,6.81958,2.895967,0.392042,-1.826351,6.037871,1.99061,1.218172,4.060443,0.0,8.248791,5.118,1.561298,-2.551046,-0.709277,-2.900422,-2.617296,5.656255,3.799182,2.728767
378,378,FALSE POSITIVE,0,0,1,1,1.092068,2e-06,-2e-06,133.19403,0.00152,-0.00152,0.469,0.037,-0.469,1.8596,0.0585,-0.0585,219.2,6.1,-6.1,1.29,0.36,-0.11,1611,1588.89,1267.12,-425.97,45.5,1,5639,152,-169,4.551,0.033,-0.187,0.859,0.233,-0.078,295.51108,44.109261,14.495,0.088073,-13.163185,4.891807,-6.489045,-0.755023,-3.270169,0.620361,-2.821779,5.389985,1.808453,0.254642,-1.018877,7.38461,7.370791,7.144503,3.817712,0.0,8.637462,5.023887,1.515347,-3.381395,-0.151986,-1.452434,-2.56395,5.688706,3.78667,2.673804


In [4]:
exo2.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6991 entries, 0 to 6990
Data columns (total 69 columns):
Unnamed: 0               6991 non-null int64
koi_disposition          6991 non-null object
koi_fpflag_nt            6991 non-null int64
koi_fpflag_ss            6991 non-null int64
koi_fpflag_co            6991 non-null int64
koi_fpflag_ec            6991 non-null int64
koi_period               6991 non-null float64
koi_period_err1          6991 non-null float64
koi_period_err2          6991 non-null float64
koi_time0bk              6991 non-null float64
koi_time0bk_err1         6991 non-null float64
koi_time0bk_err2         6991 non-null float64
koi_impact               6991 non-null float64
koi_impact_err1          6991 non-null float64
koi_impact_err2          6991 non-null float64
koi_duration             6991 non-null float64
koi_duration_err1        6991 non-null float64
koi_duration_err2        6991 non-null float64
koi_depth                6991 non-null float64
koi_depth_e

In [5]:
y = exo2['koi_disposition']
exo2_features = exo2.drop(['koi_disposition', 'Unnamed: 0'], axis=1)

In [6]:
exo2_features.sample(5)

Unnamed: 0,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag,log_koi_period,log_koi_period_err1,log_koi_time0bk,log_koi_time0bk_err1,log_koi_impact,log_koi_impact_err1,log_koi_duration,log_koi_duration_err1,log_koi_depth,log_koi_depth_err1,log_koi_prad,log_koi_prad_err1,log_koi_teq,log_koi_insol,log_koi_insol_err1,log_koi_model_snr,log_koi_tce_plnt_num,log_koi_steff,log_koi_steff_err1,log_koi_slogg,log_koi_slogg_err1,log_koi_srad,log_koi_srad_err1,log_koi_srad_err2,log_ra,log_dec,log_koi_kepmag
6781,0,1,1,1,1.332584,1.6e-05,-1.6e-05,132.0362,0.0114,-0.0114,0.746,0.206,-0.505,4.081,0.294,-0.294,60.9,4.5,-4.5,0.84,0.27,-0.09,1727,2096.92,1939.87,-616.41,16.4,1,6108,171,-192,4.471,0.05,-0.2,0.997,0.312,-0.104,295.28735,46.75449,14.835,0.28712,-11.042922,4.883076,-4.474142,-0.29169,-1.575036,1.406342,-1.22078,4.109233,1.5043,-0.174353,-1.305636,7.454141,7.648225,7.570377,2.797281,0.0,8.717355,5.141669,1.497612,-2.97593,-0.003005,-1.161552,-2.273026,5.687949,3.84491,2.696989
3712,0,1,0,0,15.65129,1.5e-05,-1.5e-05,132.81543,0.001,-0.001,0.507,0.027,-0.03,6.7631,0.0258,-0.0258,398640.0,1438.0,-1438.0,78.3,26.62,-8.87,792,93.1,95.75,-31.25,359.7,1,6137,190,-253,4.418,0.056,-0.224,1.103,0.375,-0.125,293.23383,46.598251,16.666,2.750553,-11.100816,4.88896,-6.907755,-0.677274,-3.575551,1.911481,-3.619353,12.895814,7.271009,4.360548,3.2817,6.674561,4.533674,4.561751,5.88527,0.0,8.722091,5.247029,1.485687,-2.864704,0.098034,-0.978166,-2.087474,5.68097,3.841563,2.813371
3279,0,1,1,0,1.08523,6e-06,-6e-06,132.10036,0.00435,-0.00435,0.439,0.175,-0.439,0.652,0.224,-0.224,238.1,36.5,-36.5,1.37,0.37,-0.13,1632,1675.45,1379.29,-464.48,7.6,1,5713,169,-186,4.554,0.035,-0.184,0.851,0.233,-0.078,298.15717,47.331268,15.959,0.081792,-11.960776,4.883562,-5.437579,-0.820981,-1.737271,-0.427711,-1.491655,5.472691,3.59734,0.314811,-0.991553,7.397562,7.423837,7.229325,2.028148,0.0,8.6505,5.129905,1.516006,-3.324236,-0.161343,-1.452434,-2.56395,5.697621,3.857171,2.770023
6119,0,0,1,1,1.891144,3.3e-05,-3.3e-05,132.7545,0.0164,-0.0164,1.088,0.264,-0.094,8.724,0.852,-0.852,49.9,3.7,-3.7,6.57,0.47,-0.58,932,177.97,52.12,-46.84,16.7,1,4267,128,-128,4.637,0.049,-0.021,0.627,0.045,-0.055,290.29175,37.981079,13.814,0.637182,-10.309953,4.888502,-4.110474,0.08526,-1.328025,2.166078,-0.158996,3.910021,1.308603,1.882514,-0.752897,6.837333,5.181615,3.953568,2.815409,0.0,8.358666,4.852038,1.534068,-2.995732,-0.466809,-3.079114,-2.918771,5.670886,3.637088,2.625683
1267,0,0,0,0,15.942485,6e-05,-6e-05,146.33157,0.00298,-0.00298,0.022,0.423,-0.022,3.4894,0.0806,-0.0806,1093.7,34.7,-34.7,2.62,0.54,-0.22,616,34.06,22.94,-9.19,34.3,1,5376,186,-186,4.576,0.034,-0.136,0.806,0.168,-0.067,293.82431,42.779701,15.886,2.768988,-9.722834,4.985875,-5.815832,-3.772261,-0.858022,1.24973,-2.505926,6.997322,3.546769,0.963174,-0.614336,6.423247,3.528124,3.132926,3.535145,0.0,8.5897,5.225752,1.520825,-3.352407,-0.215672,-1.777857,-2.718101,5.682982,3.756064,2.765438


In [7]:
scaler = MinMaxScaler()
scaled_features = scaler.fit_transform(exo2_features)

  return self.partial_fit(X, y)


In [8]:
X_train, X_test, y_train, y_test = train_test_split(scaled_features, y, random_state=42, stratify=y)

In [9]:
forest = RandomForestClassifier()
forest_grid = {
    'n_estimators': [25, 50, 100, 200, 400, 800],
    'criterion': ['gini', 'entropy'],
    'bootstrap': [True, False]
}

In [10]:
grid = GridSearchCV(
    forest, 
    forest_grid, 
    scoring='accuracy', 
    cv=10, 
    n_jobs=-1, 
    verbose=3
)

grid.fit(X_train, y_train)

Fitting 10 folds for each of 24 candidates, totalling 240 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  16 tasks      | elapsed:    6.1s
[Parallel(n_jobs=-1)]: Done 112 tasks      | elapsed:  2.6min
[Parallel(n_jobs=-1)]: Done 240 out of 240 | elapsed:  7.9min finished


GridSearchCV(cv=10, error_score='raise-deprecating',
       estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators='warn', n_jobs=None,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False),
       fit_params=None, iid='warn', n_jobs=-1,
       param_grid={'n_estimators': [25, 50, 100, 200, 400, 800], 'criterion': ['gini', 'entropy'], 'bootstrap': [True, False]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring='accuracy', verbose=3)

In [11]:
print(grid.best_params_)

{'bootstrap': False, 'criterion': 'entropy', 'n_estimators': 400}


In [12]:
print(grid.best_score_)

0.8901392332633988


In [13]:
predictions = grid.predict(X_test)
print(classification_report(y_test, predictions))

                precision    recall  f1-score   support

     CANDIDATE       0.84      0.73      0.79       422
     CONFIRMED       0.80      0.86      0.83       450
FALSE POSITIVE       0.97      1.00      0.98       876

     micro avg       0.90      0.90      0.90      1748
     macro avg       0.87      0.86      0.87      1748
  weighted avg       0.90      0.90      0.90      1748



In [14]:
dump(grid, 'RF.joblib')

['RF.joblib']