In [1]:
from generatefeatures import generate_features
import pandas as pd
import numpy as np
import os
import rebound
from collections import OrderedDict

First, we choose the system to analyze and load its data file; the features are then generated from each row of that data.

In [2]:
system = "KOI-0168"    # one of: KOI-0156, KOI-0168, KOI-2086
epoch = 780            # units: BJD - 2,454,900

# Per-system value handed to generate_features as its second argument.
# NOTE(review): presumably the stellar mass of each system -- confirm meaning and units.
Ms = {"KOI-0156": 0.56, "KOI-0168": 1.11, "KOI-2086": 1.04}

# Column names for the headerless data file: five quantities for each of the
# three numbered bodies (suffix 1-3). h, k, T = esinw, ecosw, epoch.
columns = [
    "m1", "T1", "P1", "h1", "k1",
    "m2", "T2", "P2", "h2", "k2",
    "m3", "T3", "P3", "h3", "k3",
]

# The separator is a regular expression, so it is written as a raw string:
# "\s" inside a plain string literal is an invalid escape sequence and
# triggers a warning on current Python versions.
data = pd.read_csv("systems/%s.dat" % system, names=columns, sep=r"\s+")

Main execution of code below. **This will take a moment.**

In [5]:
# One generate_features call per row of `data`; the row label returned by
# iterrows() is not needed, so it is discarded.
fullfeatures = [
    generate_features(row, Ms[system], epoch)
    for _, row in data.iterrows()
]

Load the trained model

In [7]:
import pickle

# NOTE(review): unpickling can execute arbitrary code, so only load model
# files that come from a trusted source.
# The context manager closes the file handle as soon as the model is read,
# instead of leaving it open for the lifetime of the kernel.
with open('OptimalXGBmodel.pkl', 'rb') as model_file:
    model = pickle.load(model_file)



The trained model only uses a subset of the generated features, so make a list of the ones it needs for evaluation.

In [8]:
# Names of the generated features that are selected for the trained model.
# The order of this list is reused as the column order of the `results`
# DataFrame built from it, so keep it unchanged.
features = [
    # avg_iH* features
    'avg_iH1', 'avg_iH2',
    # norm_*_a1 / a2 / a3 features
    'norm_std_a1', 'norm_max_a1', 'norm_std_window10_a1', 'norm_max_window10_a1',
    'norm_std_a2', 'norm_max_a2', 'norm_std_window10_a2', 'norm_max_window10_a2',
    'norm_std_a3', 'norm_max_a3', 'norm_std_window10_a3', 'norm_max_window10_a3',
    # *_ecross1 / ecross2 / ecross3 features
    'avg_ecross1', 'std_ecross1', 'max_ecross1', 'min_ecross1',
    'avg_ecross2', 'std_ecross2', 'max_ecross2', 'min_ecross2',
    'avg_ecross3', 'std_ecross3', 'max_ecross3', 'min_ecross3',
    # norm_a*_slope features
    'norm_a1_slope', 'norm_a2_slope', 'norm_a3_slope',
    # Lyapunov_time feature
    'Lyapunov_time',
]

There's probably a better way to do this, but it seems the XGBoost API wants a DataFrame passed to its predict method. So we need to turn our features into a DataFrame with one row per sample (one row for each entry of `fullfeatures`), keeping only the columns the model uses.

In [9]:
# Build the frame in a single constructor call instead of appending one row
# at a time with .loc, which reallocates the frame on every insertion.
# assumes each entry of fullfeatures supports selection by a list of labels
# (e.g. a pandas Series), as the previous per-row assignment also required.
results = pd.DataFrame(
    [sample[features] for sample in fullfeatures],  # take only the features the model uses
    columns=features,
).reset_index(drop=True)  # rows are labelled 0..n-1 by position in fullfeatures
results

Unnamed: 0,avg_iH1,avg_iH2,norm_std_a1,norm_max_a1,norm_std_window10_a1,norm_max_window10_a1,norm_std_a2,norm_max_a2,norm_std_window10_a2,norm_max_window10_a2,...,max_ecross2,min_ecross2,avg_ecross3,std_ecross3,max_ecross3,min_ecross3,norm_a1_slope,norm_a2_slope,norm_a3_slope,Lyapunov_time
0,0.0,0.0,2.343204e-08,5.916606e-08,1.94841e-08,3.320943e-08,1.090694e-08,2.457912e-08,9.054518e-09,1.641766e-08,...,0.119455,0.119455,0.0,0.0,0.205724,0.205724,1.571732e-15,1.767443e-16,-1.225163e-15,5856.398773
1,0.0,0.0,2.763839e-08,5.206889e-08,7.292978e-09,1.178618e-08,1.613849e-08,3.820029e-08,1.027954e-08,2.009626e-08,...,0.18063,0.180628,0.0,0.0,0.084341,0.084341,-4.366766e-15,3.573341e-15,2.868838e-15,-13367.644401
2,0.0,0.0,3.418902e-08,5.812196e-08,8.888849e-09,1.407974e-08,2.572818e-08,6.009075e-08,8.348391e-09,1.499519e-08,...,0.366329,0.366327,0.0,0.0,0.515061,0.51506,4.608189e-14,9.630678e-15,-1.294161e-13,-41120.225426
3,0.0,0.0,3.461788e-08,5.753931e-08,7.827917e-09,1.149875e-08,1.506889e-08,2.936603e-08,4.79766e-09,8.425422e-09,...,0.4865,0.486499,0.0,0.0,0.538874,0.538874,5.162266e-14,-1.350236e-14,-1.357255e-14,-169411.092574
4,0.0,0.0,3.309268e-08,7.062267e-08,9.793781e-09,1.275066e-08,2.369903e-08,5.930276e-08,1.534772e-08,2.437284e-08,...,0.275516,0.275515,0.0,0.0,0.314603,0.314602,1.505904e-14,-3.910017e-15,-1.496248e-15,-1272.755974
5,0.0,0.0,1.205663e-08,2.575868e-08,5.001357e-09,8.875852e-09,3.409926e-08,6.157301e-08,7.190338e-09,1.065366e-08,...,0.101314,0.101313,0.0,0.0,0.271504,0.271504,5.475466e-15,-1.062983e-14,7.366789e-15,183087.986262
6,0.0,0.0,1.183257e-08,3.218361e-08,4.754103e-09,9.889786e-09,1.102131e-08,2.445532e-08,1.039503e-08,1.599908e-08,...,0.140769,0.140769,0.0,0.0,0.167924,0.167923,-2.546009e-15,-6.942044e-16,2.033426e-15,232419.749619
7,0.0,0.0,7.810411e-08,1.11496e-07,1.173502e-08,1.66774e-08,5.104302e-08,9.33291e-08,2.079961e-08,4.392801e-08,...,0.366209,0.366208,0.0,0.0,0.374459,0.374458,2.507318e-13,-1.267431e-13,4.030899e-15,-44474.715807
8,0.0,0.0,2.716591e-08,5.051577e-08,4.197802e-09,7.819604e-09,3.798065e-08,6.959318e-08,1.828711e-08,3.069417e-08,...,0.245596,0.245596,0.0,0.0,0.294901,0.294901,5.145984e-14,-9.350986e-15,-1.158626e-14,4760.516251
9,0.0,0.0,2.164513e-08,4.782465e-08,1.578649e-08,2.552531e-08,1.090981e-08,2.355576e-08,7.213458e-09,1.265585e-08,...,0.340547,0.340547,0.0,0.0,0.326121,0.32612,1.249192e-14,4.089446e-15,-5.029096e-14,-62906.410805


Predict probability

In [10]:
# Pass only the model's input columns. Once this cell has run, `results` also
# holds the "predict_proba" column; selecting `features` explicitly keeps the
# cell safe to re-run without feeding that output column back into the model.
# [:, 1] keeps the second column of the array returned by predict_proba.
results["predict_proba"] = model.predict_proba(results[features])[:, 1]
results

Unnamed: 0,avg_iH1,avg_iH2,norm_std_a1,norm_max_a1,norm_std_window10_a1,norm_max_window10_a1,norm_std_a2,norm_max_a2,norm_std_window10_a2,norm_max_window10_a2,...,min_ecross2,avg_ecross3,std_ecross3,max_ecross3,min_ecross3,norm_a1_slope,norm_a2_slope,norm_a3_slope,Lyapunov_time,predict_proba
0,0.0,0.0,2.343204e-08,5.916606e-08,1.94841e-08,3.320943e-08,1.090694e-08,2.457912e-08,9.054518e-09,1.641766e-08,...,0.119455,0.0,0.0,0.205724,0.205724,1.571732e-15,1.767443e-16,-1.225163e-15,5856.398773,0.998868
1,0.0,0.0,2.763839e-08,5.206889e-08,7.292978e-09,1.178618e-08,1.613849e-08,3.820029e-08,1.027954e-08,2.009626e-08,...,0.180628,0.0,0.0,0.084341,0.084341,-4.366766e-15,3.573341e-15,2.868838e-15,-13367.644401,0.997549
2,0.0,0.0,3.418902e-08,5.812196e-08,8.888849e-09,1.407974e-08,2.572818e-08,6.009075e-08,8.348391e-09,1.499519e-08,...,0.366327,0.0,0.0,0.515061,0.51506,4.608189e-14,9.630678e-15,-1.294161e-13,-41120.225426,0.933303
3,0.0,0.0,3.461788e-08,5.753931e-08,7.827917e-09,1.149875e-08,1.506889e-08,2.936603e-08,4.79766e-09,8.425422e-09,...,0.486499,0.0,0.0,0.538874,0.538874,5.162266e-14,-1.350236e-14,-1.357255e-14,-169411.092574,0.539691
4,0.0,0.0,3.309268e-08,7.062267e-08,9.793781e-09,1.275066e-08,2.369903e-08,5.930276e-08,1.534772e-08,2.437284e-08,...,0.275515,0.0,0.0,0.314603,0.314602,1.505904e-14,-3.910017e-15,-1.496248e-15,-1272.755974,0.65321
5,0.0,0.0,1.205663e-08,2.575868e-08,5.001357e-09,8.875852e-09,3.409926e-08,6.157301e-08,7.190338e-09,1.065366e-08,...,0.101313,0.0,0.0,0.271504,0.271504,5.475466e-15,-1.062983e-14,7.366789e-15,183087.986262,0.997516
6,0.0,0.0,1.183257e-08,3.218361e-08,4.754103e-09,9.889786e-09,1.102131e-08,2.445532e-08,1.039503e-08,1.599908e-08,...,0.140769,0.0,0.0,0.167924,0.167923,-2.546009e-15,-6.942044e-16,2.033426e-15,232419.749619,0.996133
7,0.0,0.0,7.810411e-08,1.11496e-07,1.173502e-08,1.66774e-08,5.104302e-08,9.33291e-08,2.079961e-08,4.392801e-08,...,0.366208,0.0,0.0,0.374459,0.374458,2.507318e-13,-1.267431e-13,4.030899e-15,-44474.715807,0.939029
8,0.0,0.0,2.716591e-08,5.051577e-08,4.197802e-09,7.819604e-09,3.798065e-08,6.959318e-08,1.828711e-08,3.069417e-08,...,0.245596,0.0,0.0,0.294901,0.294901,5.145984e-14,-9.350986e-15,-1.158626e-14,4760.516251,0.975277
9,0.0,0.0,2.164513e-08,4.782465e-08,1.578649e-08,2.552531e-08,1.090981e-08,2.355576e-08,7.213458e-09,1.265585e-08,...,0.340547,0.0,0.0,0.326121,0.32612,1.249192e-14,4.089446e-15,-5.029096e-14,-62906.410805,0.931119


Save results

In [26]:
# Write both tables next to the input file, tagged with the system name:
# the raw input parameters ...
data.to_csv(path_or_buf="systems/%s_data.csv" % system)
# ... and the selected features together with the predicted probability.
results.to_csv(path_or_buf="systems/%s_results.csv" % system)