## ü™ê Exoplanet Identification

Given *data collected about objects* in space, let's try to predict whether a given object is an **exoplanet** or not.

We will use a variety of classification models to make our predictions.

Data source: https://www.kaggle.com/datasets/nasa/kepler-exoplanet-search-results

### Getting Started

In [1]:
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier

import warnings
warnings.filterwarnings(action='ignore')

In [2]:
data = pd.read_csv('archive/cumulative.csv')
data

Unnamed: 0,rowid,kepid,kepoi_name,kepler_name,koi_disposition,koi_pdisposition,koi_score,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_teq_err1,koi_teq_err2,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_tce_delivname,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag
0,1,10797460,K00752.01,Kepler-227 b,CONFIRMED,CANDIDATE,1.000,0,0,0,0,9.488036,2.775000e-05,-2.775000e-05,170.538750,0.002160,-0.002160,0.146,0.318,-0.146,2.95750,0.08190,-0.08190,615.8,19.5,-19.5,2.26,0.26,-0.15,793.0,,,93.59,29.45,-16.65,35.8,1.0,q1_q17_dr25_tce,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347
1,2,10797460,K00752.02,Kepler-227 c,CONFIRMED,CANDIDATE,0.969,0,0,0,0,54.418383,2.479000e-04,-2.479000e-04,162.513840,0.003520,-0.003520,0.586,0.059,-0.443,4.50700,0.11600,-0.11600,874.8,35.5,-35.5,2.83,0.32,-0.19,443.0,,,9.11,2.87,-1.62,25.8,2.0,q1_q17_dr25_tce,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347
2,3,10811496,K00753.01,,FALSE POSITIVE,FALSE POSITIVE,0.000,0,1,0,0,19.899140,1.494000e-05,-1.494000e-05,175.850252,0.000581,-0.000581,0.969,5.126,-0.077,1.78220,0.03410,-0.03410,10829.0,171.0,-171.0,14.60,3.92,-1.31,638.0,,,39.30,31.04,-10.49,76.3,1.0,q1_q17_dr25_tce,5853.0,158.0,-176.0,4.544,0.044,-0.176,0.868,0.233,-0.078,297.00482,48.134129,15.436
3,4,10848459,K00754.01,,FALSE POSITIVE,FALSE POSITIVE,0.000,0,1,0,0,1.736952,2.630000e-07,-2.630000e-07,170.307565,0.000115,-0.000115,1.276,0.115,-0.092,2.40641,0.00537,-0.00537,8079.2,12.8,-12.8,33.46,8.50,-2.83,1395.0,,,891.96,668.95,-230.35,505.6,1.0,q1_q17_dr25_tce,5805.0,157.0,-174.0,4.564,0.053,-0.168,0.791,0.201,-0.067,285.53461,48.285210,15.597
4,5,10854555,K00755.01,Kepler-664 b,CONFIRMED,CANDIDATE,1.000,0,0,0,0,2.525592,3.761000e-06,-3.761000e-06,171.595550,0.001130,-0.001130,0.701,0.235,-0.478,1.65450,0.04200,-0.04200,603.3,16.9,-16.9,2.75,0.88,-0.35,1406.0,,,926.16,874.33,-314.24,40.9,1.0,q1_q17_dr25_tce,6031.0,169.0,-211.0,4.438,0.070,-0.210,1.046,0.334,-0.133,288.75488,48.226200,15.509
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9559,9560,10031643,K07984.01,,FALSE POSITIVE,FALSE POSITIVE,0.000,0,0,0,1,8.589871,1.846000e-04,-1.846000e-04,132.016100,0.015700,-0.015700,0.765,0.023,-0.541,4.80600,0.63400,-0.63400,87.7,13.0,-13.0,1.11,0.32,-0.23,929.0,,,176.40,152.77,-77.60,8.4,1.0,q1_q17_dr25_tce,5638.0,169.0,-152.0,4.296,0.231,-0.189,1.088,0.313,-0.228,298.74921,46.973351,14.478
9560,9561,10090151,K07985.01,,FALSE POSITIVE,FALSE POSITIVE,0.000,0,1,1,0,0.527699,1.160000e-07,-1.160000e-07,131.705093,0.000170,-0.000170,1.252,0.051,-0.049,3.22210,0.01740,-0.01740,1579.2,4.6,-4.6,29.35,7.70,-2.57,2088.0,,,4500.53,3406.38,-1175.26,453.3,1.0,q1_q17_dr25_tce,5638.0,139.0,-166.0,4.529,0.035,-0.196,0.903,0.237,-0.079,297.18875,47.093819,14.082
9561,9562,10128825,K07986.01,,CANDIDATE,CANDIDATE,0.497,0,0,0,0,1.739849,1.780000e-05,-1.780000e-05,133.001270,0.007690,-0.007690,0.043,0.423,-0.043,3.11400,0.22900,-0.22900,48.5,5.4,-5.4,0.72,0.24,-0.08,1608.0,,,1585.81,1537.86,-502.22,10.6,1.0,q1_q17_dr25_tce,6119.0,165.0,-220.0,4.444,0.056,-0.224,1.031,0.341,-0.114,286.50937,47.163219,14.757
9562,9563,10147276,K07987.01,,FALSE POSITIVE,FALSE POSITIVE,0.021,0,0,1,0,0.681402,2.434000e-06,-2.434000e-06,132.181750,0.002850,-0.002850,0.147,0.309,-0.147,0.86500,0.16200,-0.16200,103.6,14.7,-14.7,1.07,0.36,-0.11,2218.0,,,5713.41,5675.74,-1836.94,12.3,1.0,q1_q17_dr25_tce,6173.0,193.0,-236.0,4.447,0.056,-0.224,1.041,0.341,-0.114,294.16489,47.176281,15.385


In [3]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9564 entries, 0 to 9563
Data columns (total 50 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   rowid              9564 non-null   int64  
 1   kepid              9564 non-null   int64  
 2   kepoi_name         9564 non-null   object 
 3   kepler_name        2294 non-null   object 
 4   koi_disposition    9564 non-null   object 
 5   koi_pdisposition   9564 non-null   object 
 6   koi_score          8054 non-null   float64
 7   koi_fpflag_nt      9564 non-null   int64  
 8   koi_fpflag_ss      9564 non-null   int64  
 9   koi_fpflag_co      9564 non-null   int64  
 10  koi_fpflag_ec      9564 non-null   int64  
 11  koi_period         9564 non-null   float64
 12  koi_period_err1    9110 non-null   float64
 13  koi_period_err2    9110 non-null   float64
 14  koi_time0bk        9564 non-null   float64
 15  koi_time0bk_err1   9110 non-null   float64
 16  koi_time0bk_err2   9110 

### Preprocessing

In [4]:
df = data.copy()

In [5]:
# Drop unused columns
df = df.drop(['rowid', 'kepid', 'kepoi_name', 'kepler_name', 'koi_score', 'koi_pdisposition'], axis=1)

In [6]:
df

Unnamed: 0,koi_disposition,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_teq_err1,koi_teq_err2,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_tce_delivname,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag
0,CONFIRMED,0,0,0,0,9.488036,2.775000e-05,-2.775000e-05,170.538750,0.002160,-0.002160,0.146,0.318,-0.146,2.95750,0.08190,-0.08190,615.8,19.5,-19.5,2.26,0.26,-0.15,793.0,,,93.59,29.45,-16.65,35.8,1.0,q1_q17_dr25_tce,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347
1,CONFIRMED,0,0,0,0,54.418383,2.479000e-04,-2.479000e-04,162.513840,0.003520,-0.003520,0.586,0.059,-0.443,4.50700,0.11600,-0.11600,874.8,35.5,-35.5,2.83,0.32,-0.19,443.0,,,9.11,2.87,-1.62,25.8,2.0,q1_q17_dr25_tce,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347
2,FALSE POSITIVE,0,1,0,0,19.899140,1.494000e-05,-1.494000e-05,175.850252,0.000581,-0.000581,0.969,5.126,-0.077,1.78220,0.03410,-0.03410,10829.0,171.0,-171.0,14.60,3.92,-1.31,638.0,,,39.30,31.04,-10.49,76.3,1.0,q1_q17_dr25_tce,5853.0,158.0,-176.0,4.544,0.044,-0.176,0.868,0.233,-0.078,297.00482,48.134129,15.436
3,FALSE POSITIVE,0,1,0,0,1.736952,2.630000e-07,-2.630000e-07,170.307565,0.000115,-0.000115,1.276,0.115,-0.092,2.40641,0.00537,-0.00537,8079.2,12.8,-12.8,33.46,8.50,-2.83,1395.0,,,891.96,668.95,-230.35,505.6,1.0,q1_q17_dr25_tce,5805.0,157.0,-174.0,4.564,0.053,-0.168,0.791,0.201,-0.067,285.53461,48.285210,15.597
4,CONFIRMED,0,0,0,0,2.525592,3.761000e-06,-3.761000e-06,171.595550,0.001130,-0.001130,0.701,0.235,-0.478,1.65450,0.04200,-0.04200,603.3,16.9,-16.9,2.75,0.88,-0.35,1406.0,,,926.16,874.33,-314.24,40.9,1.0,q1_q17_dr25_tce,6031.0,169.0,-211.0,4.438,0.070,-0.210,1.046,0.334,-0.133,288.75488,48.226200,15.509
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9559,FALSE POSITIVE,0,0,0,1,8.589871,1.846000e-04,-1.846000e-04,132.016100,0.015700,-0.015700,0.765,0.023,-0.541,4.80600,0.63400,-0.63400,87.7,13.0,-13.0,1.11,0.32,-0.23,929.0,,,176.40,152.77,-77.60,8.4,1.0,q1_q17_dr25_tce,5638.0,169.0,-152.0,4.296,0.231,-0.189,1.088,0.313,-0.228,298.74921,46.973351,14.478
9560,FALSE POSITIVE,0,1,1,0,0.527699,1.160000e-07,-1.160000e-07,131.705093,0.000170,-0.000170,1.252,0.051,-0.049,3.22210,0.01740,-0.01740,1579.2,4.6,-4.6,29.35,7.70,-2.57,2088.0,,,4500.53,3406.38,-1175.26,453.3,1.0,q1_q17_dr25_tce,5638.0,139.0,-166.0,4.529,0.035,-0.196,0.903,0.237,-0.079,297.18875,47.093819,14.082
9561,CANDIDATE,0,0,0,0,1.739849,1.780000e-05,-1.780000e-05,133.001270,0.007690,-0.007690,0.043,0.423,-0.043,3.11400,0.22900,-0.22900,48.5,5.4,-5.4,0.72,0.24,-0.08,1608.0,,,1585.81,1537.86,-502.22,10.6,1.0,q1_q17_dr25_tce,6119.0,165.0,-220.0,4.444,0.056,-0.224,1.031,0.341,-0.114,286.50937,47.163219,14.757
9562,FALSE POSITIVE,0,0,1,0,0.681402,2.434000e-06,-2.434000e-06,132.181750,0.002850,-0.002850,0.147,0.309,-0.147,0.86500,0.16200,-0.16200,103.6,14.7,-14.7,1.07,0.36,-0.11,2218.0,,,5713.41,5675.74,-1836.94,12.3,1.0,q1_q17_dr25_tce,6173.0,193.0,-236.0,4.447,0.056,-0.224,1.041,0.341,-0.114,294.16489,47.176281,15.385


In [7]:
df['koi_disposition'].value_counts()

koi_disposition
FALSE POSITIVE    5023
CONFIRMED         2293
CANDIDATE         2248
Name: count, dtype: int64

In [8]:
# Limit target values to CANDIDATE and CONFIRMED
false_positive_rows = df.query("koi_disposition == 'FALSE POSITIVE'").index

df = df.drop(false_positive_rows, axis=0).reset_index(drop=True)

In [9]:
df['koi_disposition'].value_counts()

koi_disposition
CONFIRMED    2293
CANDIDATE    2248
Name: count, dtype: int64

In [10]:
df.isna().mean() * 100

koi_disposition        0.000000
koi_fpflag_nt          0.000000
koi_fpflag_ss          0.000000
koi_fpflag_co          0.000000
koi_fpflag_ec          0.000000
koi_period             0.000000
koi_period_err1        1.717683
koi_period_err2        1.717683
koi_time0bk            0.000000
koi_time0bk_err1       1.717683
koi_time0bk_err2       1.717683
koi_impact             1.409381
koi_impact_err1        1.717683
koi_impact_err2        1.717683
koi_duration           0.000000
koi_duration_err1      1.717683
koi_duration_err2      1.717683
koi_depth              1.409381
koi_depth_err1         1.717683
koi_depth_err2         1.717683
koi_prad               1.409381
koi_prad_err1          1.409381
koi_prad_err2          1.409381
koi_teq                1.409381
koi_teq_err1         100.000000
koi_teq_err2         100.000000
koi_insol              1.387360
koi_insol_err1         1.387360
koi_insol_err2         1.387360
koi_model_snr          1.409381
koi_tce_plnt_num       1.651619
koi_tce_

In [11]:
# Drop columns with all missing values
df = df.drop(['koi_teq_err1', 'koi_teq_err2'], axis=1)

In [12]:
df

Unnamed: 0,koi_disposition,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_tce_delivname,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag
0,CONFIRMED,0,0,0,0,9.488036,0.000028,-0.000028,170.538750,0.002160,-0.002160,0.146,0.318,-0.146,2.9575,0.0819,-0.0819,615.8,19.5,-19.5,2.26,0.26,-0.15,793.0,93.59,29.45,-16.65,35.8,1.0,q1_q17_dr25_tce,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347
1,CONFIRMED,0,0,0,0,54.418383,0.000248,-0.000248,162.513840,0.003520,-0.003520,0.586,0.059,-0.443,4.5070,0.1160,-0.1160,874.8,35.5,-35.5,2.83,0.32,-0.19,443.0,9.11,2.87,-1.62,25.8,2.0,q1_q17_dr25_tce,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347
2,CONFIRMED,0,0,0,0,2.525592,0.000004,-0.000004,171.595550,0.001130,-0.001130,0.701,0.235,-0.478,1.6545,0.0420,-0.0420,603.3,16.9,-16.9,2.75,0.88,-0.35,1406.0,926.16,874.33,-314.24,40.9,1.0,q1_q17_dr25_tce,6031.0,169.0,-211.0,4.438,0.070,-0.210,1.046,0.334,-0.133,288.75488,48.226200,15.509
3,CONFIRMED,0,0,0,0,11.094321,0.000020,-0.000020,171.201160,0.001410,-0.001410,0.538,0.030,-0.428,4.5945,0.0610,-0.0610,1517.5,24.2,-24.2,3.90,1.27,-0.42,835.0,114.81,112.85,-36.70,66.5,1.0,q1_q17_dr25_tce,6046.0,189.0,-232.0,4.486,0.054,-0.229,0.972,0.315,-0.105,296.28613,48.224670,15.714
4,CONFIRMED,0,0,0,0,4.134435,0.000010,-0.000010,172.979370,0.001900,-0.001900,0.762,0.139,-0.532,3.1402,0.0673,-0.0673,686.0,18.7,-18.7,2.77,0.90,-0.30,1160.0,427.65,420.33,-136.70,40.2,2.0,q1_q17_dr25_tce,6046.0,189.0,-232.0,4.486,0.054,-0.229,0.972,0.315,-0.105,296.28613,48.224670,15.714
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4536,CANDIDATE,0,0,0,0,4.736816,0.000147,-0.000147,131.787600,0.025600,-0.025600,0.218,0.285,-0.218,2.8400,1.0000,-1.0000,35.3,12.6,-12.6,0.60,0.20,-0.06,1137.0,395.05,377.30,-120.72,6.9,1.0,q1_q17_dr25_tce,6088.0,165.0,-201.0,4.456,0.056,-0.224,1.011,0.329,-0.110,289.20331,44.505138,13.922
4537,CANDIDATE,0,0,0,0,130.235324,0.003030,-0.003030,218.271900,0.020100,-0.020100,0.075,0.387,-0.075,5.6780,0.5340,-0.5340,750.1,91.4,-91.4,2.44,0.68,-0.23,332.0,2.86,2.38,-0.80,9.7,1.0,q1_q17_dr25_tce,5616.0,166.0,-183.0,4.529,0.036,-0.192,0.903,0.251,-0.084,289.57452,44.519939,15.991
4538,CANDIDATE,0,0,0,0,8.870416,0.000009,-0.000009,137.481093,0.000869,-0.000869,1.206,70.610,-0.033,1.2864,0.0514,-0.0514,873.1,25.8,-25.8,39.46,11.10,-16.68,1151.0,414.26,360.89,-292.07,43.8,1.0,q1_q17_dr25_tce,6022.0,200.0,-181.0,4.027,0.434,-0.186,1.514,0.426,-0.640,290.14914,50.239178,13.579
4539,CANDIDATE,0,0,0,0,47.109631,0.000194,-0.000194,144.131720,0.003430,-0.003430,1.230,6.923,-0.605,5.7410,0.1720,-0.1720,752.2,22.2,-22.2,78.98,30.94,-57.45,751.0,75.40,89.11,-70.44,35.1,1.0,q1_q17_dr25_tce,5258.0,159.0,-159.0,3.597,0.968,-0.242,2.780,1.089,-2.022,296.15601,44.920090,13.731


In [13]:
# Fill remaining missing values
df['koi_tce_delivname'] = df['koi_tce_delivname'].fillna(df['koi_tce_delivname'].mode()[0])

In [14]:
for column in df.columns[df.isna().sum() > 0]:
    df[column] = df[column].fillna(df[column].mean())

In [15]:
df.isna().sum().sum()

np.int64(0)

In [16]:
df['koi_tce_delivname'].unique()

array(['q1_q17_dr25_tce', 'q1_q17_dr24_tce', 'q1_q16_tce'], dtype=object)

In [17]:
# One-hot encode koi_tce_delivname column
delivname_dummies = pd.get_dummies(df['koi_tce_delivname'], dtype=int, prefix='delivname')
df = pd.concat([df, delivname_dummies], axis=1)
df = df.drop('koi_tce_delivname', axis=1)

In [18]:
df

Unnamed: 0,koi_disposition,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag,delivname_q1_q16_tce,delivname_q1_q17_dr24_tce,delivname_q1_q17_dr25_tce
0,CONFIRMED,0,0,0,0,9.488036,0.000028,-0.000028,170.538750,0.002160,-0.002160,0.146,0.318,-0.146,2.9575,0.0819,-0.0819,615.8,19.5,-19.5,2.26,0.26,-0.15,793.0,93.59,29.45,-16.65,35.8,1.0,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347,0,0,1
1,CONFIRMED,0,0,0,0,54.418383,0.000248,-0.000248,162.513840,0.003520,-0.003520,0.586,0.059,-0.443,4.5070,0.1160,-0.1160,874.8,35.5,-35.5,2.83,0.32,-0.19,443.0,9.11,2.87,-1.62,25.8,2.0,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347,0,0,1
2,CONFIRMED,0,0,0,0,2.525592,0.000004,-0.000004,171.595550,0.001130,-0.001130,0.701,0.235,-0.478,1.6545,0.0420,-0.0420,603.3,16.9,-16.9,2.75,0.88,-0.35,1406.0,926.16,874.33,-314.24,40.9,1.0,6031.0,169.0,-211.0,4.438,0.070,-0.210,1.046,0.334,-0.133,288.75488,48.226200,15.509,0,0,1
3,CONFIRMED,0,0,0,0,11.094321,0.000020,-0.000020,171.201160,0.001410,-0.001410,0.538,0.030,-0.428,4.5945,0.0610,-0.0610,1517.5,24.2,-24.2,3.90,1.27,-0.42,835.0,114.81,112.85,-36.70,66.5,1.0,6046.0,189.0,-232.0,4.486,0.054,-0.229,0.972,0.315,-0.105,296.28613,48.224670,15.714,0,0,1
4,CONFIRMED,0,0,0,0,4.134435,0.000010,-0.000010,172.979370,0.001900,-0.001900,0.762,0.139,-0.532,3.1402,0.0673,-0.0673,686.0,18.7,-18.7,2.77,0.90,-0.30,1160.0,427.65,420.33,-136.70,40.2,2.0,6046.0,189.0,-232.0,4.486,0.054,-0.229,0.972,0.315,-0.105,296.28613,48.224670,15.714,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4536,CANDIDATE,0,0,0,0,4.736816,0.000147,-0.000147,131.787600,0.025600,-0.025600,0.218,0.285,-0.218,2.8400,1.0000,-1.0000,35.3,12.6,-12.6,0.60,0.20,-0.06,1137.0,395.05,377.30,-120.72,6.9,1.0,6088.0,165.0,-201.0,4.456,0.056,-0.224,1.011,0.329,-0.110,289.20331,44.505138,13.922,0,0,1
4537,CANDIDATE,0,0,0,0,130.235324,0.003030,-0.003030,218.271900,0.020100,-0.020100,0.075,0.387,-0.075,5.6780,0.5340,-0.5340,750.1,91.4,-91.4,2.44,0.68,-0.23,332.0,2.86,2.38,-0.80,9.7,1.0,5616.0,166.0,-183.0,4.529,0.036,-0.192,0.903,0.251,-0.084,289.57452,44.519939,15.991,0,0,1
4538,CANDIDATE,0,0,0,0,8.870416,0.000009,-0.000009,137.481093,0.000869,-0.000869,1.206,70.610,-0.033,1.2864,0.0514,-0.0514,873.1,25.8,-25.8,39.46,11.10,-16.68,1151.0,414.26,360.89,-292.07,43.8,1.0,6022.0,200.0,-181.0,4.027,0.434,-0.186,1.514,0.426,-0.640,290.14914,50.239178,13.579,0,0,1
4539,CANDIDATE,0,0,0,0,47.109631,0.000194,-0.000194,144.131720,0.003430,-0.003430,1.230,6.923,-0.605,5.7410,0.1720,-0.1720,752.2,22.2,-22.2,78.98,30.94,-57.45,751.0,75.40,89.11,-70.44,35.1,1.0,5258.0,159.0,-159.0,3.597,0.968,-0.242,2.780,1.089,-2.022,296.15601,44.920090,13.731,0,0,1


In [19]:
# Encoding target variable (CONFIRMED is 1 and CANDIDATE is 0)
df['koi_disposition'] = df['koi_disposition'].replace({'CONFIRMED': 1, 'CANDIDATE': 0})

In [20]:
df

Unnamed: 0,koi_disposition,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag,delivname_q1_q16_tce,delivname_q1_q17_dr24_tce,delivname_q1_q17_dr25_tce
0,1,0,0,0,0,9.488036,0.000028,-0.000028,170.538750,0.002160,-0.002160,0.146,0.318,-0.146,2.9575,0.0819,-0.0819,615.8,19.5,-19.5,2.26,0.26,-0.15,793.0,93.59,29.45,-16.65,35.8,1.0,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347,0,0,1
1,1,0,0,0,0,54.418383,0.000248,-0.000248,162.513840,0.003520,-0.003520,0.586,0.059,-0.443,4.5070,0.1160,-0.1160,874.8,35.5,-35.5,2.83,0.32,-0.19,443.0,9.11,2.87,-1.62,25.8,2.0,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347,0,0,1
2,1,0,0,0,0,2.525592,0.000004,-0.000004,171.595550,0.001130,-0.001130,0.701,0.235,-0.478,1.6545,0.0420,-0.0420,603.3,16.9,-16.9,2.75,0.88,-0.35,1406.0,926.16,874.33,-314.24,40.9,1.0,6031.0,169.0,-211.0,4.438,0.070,-0.210,1.046,0.334,-0.133,288.75488,48.226200,15.509,0,0,1
3,1,0,0,0,0,11.094321,0.000020,-0.000020,171.201160,0.001410,-0.001410,0.538,0.030,-0.428,4.5945,0.0610,-0.0610,1517.5,24.2,-24.2,3.90,1.27,-0.42,835.0,114.81,112.85,-36.70,66.5,1.0,6046.0,189.0,-232.0,4.486,0.054,-0.229,0.972,0.315,-0.105,296.28613,48.224670,15.714,0,0,1
4,1,0,0,0,0,4.134435,0.000010,-0.000010,172.979370,0.001900,-0.001900,0.762,0.139,-0.532,3.1402,0.0673,-0.0673,686.0,18.7,-18.7,2.77,0.90,-0.30,1160.0,427.65,420.33,-136.70,40.2,2.0,6046.0,189.0,-232.0,4.486,0.054,-0.229,0.972,0.315,-0.105,296.28613,48.224670,15.714,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4536,0,0,0,0,0,4.736816,0.000147,-0.000147,131.787600,0.025600,-0.025600,0.218,0.285,-0.218,2.8400,1.0000,-1.0000,35.3,12.6,-12.6,0.60,0.20,-0.06,1137.0,395.05,377.30,-120.72,6.9,1.0,6088.0,165.0,-201.0,4.456,0.056,-0.224,1.011,0.329,-0.110,289.20331,44.505138,13.922,0,0,1
4537,0,0,0,0,0,130.235324,0.003030,-0.003030,218.271900,0.020100,-0.020100,0.075,0.387,-0.075,5.6780,0.5340,-0.5340,750.1,91.4,-91.4,2.44,0.68,-0.23,332.0,2.86,2.38,-0.80,9.7,1.0,5616.0,166.0,-183.0,4.529,0.036,-0.192,0.903,0.251,-0.084,289.57452,44.519939,15.991,0,0,1
4538,0,0,0,0,0,8.870416,0.000009,-0.000009,137.481093,0.000869,-0.000869,1.206,70.610,-0.033,1.2864,0.0514,-0.0514,873.1,25.8,-25.8,39.46,11.10,-16.68,1151.0,414.26,360.89,-292.07,43.8,1.0,6022.0,200.0,-181.0,4.027,0.434,-0.186,1.514,0.426,-0.640,290.14914,50.239178,13.579,0,0,1
4539,0,0,0,0,0,47.109631,0.000194,-0.000194,144.131720,0.003430,-0.003430,1.230,6.923,-0.605,5.7410,0.1720,-0.1720,752.2,22.2,-22.2,78.98,30.94,-57.45,751.0,75.40,89.11,-70.44,35.1,1.0,5258.0,159.0,-159.0,3.597,0.968,-0.242,2.780,1.089,-2.022,296.15601,44.920090,13.731,0,0,1


In [21]:
# Split df into X and y
y = df['koi_disposition']
X = df.drop('koi_disposition', axis=1)

In [22]:
# Train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True, random_state=1)

In [23]:
X_train.shape, X_test.shape

((3178, 43), (1363, 43))

In [24]:
# Scale X
scaler = StandardScaler()
scaler.fit(X_train)

X_train = pd.DataFrame(scaler.transform(X_train), columns=X_train.columns, index=X_train.index)
X_test = pd.DataFrame(scaler.transform(X_test), columns=X_test.columns, index=X_test.index)

In [25]:
X_train

Unnamed: 0,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag,delivname_q1_q16_tce,delivname_q1_q17_dr24_tce,delivname_q1_q17_dr25_tce
1673,-0.071134,-0.117116,-0.0355,-0.017742,-0.037165,-0.239706,0.239706,-0.497521,-0.000669,0.000669,0.272742,-0.140167,-0.266272,-0.441041,-0.079904,0.079904,-0.155163,-0.279154,0.279154,-0.035572,-0.031034,0.027635,0.646748,-0.023607,-0.034229,0.020831,-0.215034,0.780456,0.833483,-0.138215,0.084673,-0.331827,0.332863,0.162643,-0.012212,-0.045685,0.033953,-1.053449,-1.878703,-0.325469,-0.242131,-0.185756,0.312051
1239,-0.071134,-0.117116,-0.0355,-0.017742,-0.004275,-0.054711,0.054711,0.202478,0.092227,-0.092227,-0.293536,-0.077331,0.298219,0.352551,-0.069836,0.069836,-0.083836,0.091315,-0.091315,-0.031127,-0.030643,0.026235,-0.844428,-0.027647,-0.041769,0.022613,-0.175524,-0.490423,-0.330867,-1.083094,1.023246,0.060167,0.171510,0.548779,-0.083330,-0.292382,0.077194,0.659043,0.429700,0.597714,-0.242131,-0.185756,0.312051
3589,-0.071134,-0.117116,-0.0355,-0.017742,0.174458,1.954787,-1.954787,-0.239901,0.177133,-0.177133,-0.156273,-0.110109,0.096010,0.318639,0.640932,-0.640932,-0.107305,0.157386,-0.157386,-0.013500,0.009134,0.012008,-0.907385,-0.027660,-0.041620,0.022611,-0.208317,-0.490423,-0.783183,-0.270058,0.408907,-2.906019,0.959291,-2.636846,0.476722,3.648441,-0.313997,-1.261863,-1.319194,-1.420679,4.129990,-0.185756,-3.204607
415,-0.071134,-0.117116,-0.0355,-0.017742,-0.037962,-0.253911,0.253911,0.110849,-0.411112,0.411112,-0.273988,-0.097572,0.269409,-0.502304,-0.487437,0.487437,0.128285,-0.162350,0.162350,-0.023041,-0.020479,0.024953,0.483060,-0.024695,-0.032533,0.021526,0.399749,-0.490423,0.273194,0.938508,-0.836835,0.455247,-0.511866,-0.947499,-0.096417,0.021787,0.102868,-1.024723,0.715114,1.253905,-0.242131,-0.185756,0.312051
2925,-0.071134,-0.117116,-0.0355,-0.017742,0.010035,0.559955,-0.559955,0.873931,0.966253,-0.966253,0.021674,-0.142886,-0.099716,2.651326,1.110078,-1.110078,-0.137152,-0.122235,0.122235,-0.032889,-0.030154,0.026002,-0.779673,-0.027628,-0.041722,0.022601,-0.205947,-0.490423,0.365117,-1.105068,1.040311,-0.470722,0.570146,0.210910,0.000876,-0.066771,0.012333,-0.261806,-1.248770,0.522987,-0.242131,-0.185756,0.312051
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,-0.071134,-0.117116,-0.0355,-0.017742,-0.035650,-0.206766,0.206766,-0.396396,0.736509,-0.736509,-0.263603,-0.088207,0.254104,-0.019645,0.832215,-0.832215,-0.142028,0.029964,-0.029964,-0.035572,-0.029177,0.028335,0.123307,-0.026313,-0.037054,0.022072,-0.228863,-0.490423,0.960423,1.026404,-1.365849,0.297832,-0.454918,-1.092300,-0.080613,0.139864,0.090031,1.143456,1.093772,0.255217,-0.242131,-0.185756,0.312051
2763,-0.071134,-0.117116,-0.0355,-0.017742,-0.031798,-0.076652,0.076652,-0.276621,-0.121035,0.121035,-0.011924,-0.145605,-0.115921,-0.804897,-0.226889,0.226889,0.005323,1.535436,-1.535436,-0.025877,-0.021847,0.024136,-0.216659,-0.027136,-0.039964,0.022373,-0.210688,-0.490423,0.937078,1.509831,-1.741278,0.245360,-0.303056,-1.019900,-0.081848,0.095585,0.075842,0.628206,-1.259209,0.958113,-0.242131,-0.185756,0.312051
905,-0.071134,-0.117116,-0.0355,-0.017742,-0.026434,-0.229689,0.229689,-0.280512,-0.276861,0.276861,-0.294758,-0.086545,0.300019,0.372565,-0.321524,0.321524,-0.054763,-0.034928,0.034928,-0.026912,-0.028004,0.022853,-0.410926,-0.027394,-0.041346,0.022490,-0.107566,0.780456,0.264440,-0.292032,0.596622,-0.276268,0.389811,0.259177,-0.031967,-0.146894,0.025846,0.606496,-0.141000,0.904403,-0.242131,-0.185756,0.312051
3980,-0.071134,-0.117116,-0.0355,-0.017742,-0.034204,-0.235459,0.235459,-0.328916,-0.176973,0.176973,-0.032083,-0.065851,0.090248,-0.601983,-0.130241,0.130241,-0.147475,-0.207184,0.207184,-0.037066,-0.032304,0.029734,-0.331780,-0.027303,-0.041362,0.022569,-0.208712,0.780456,-0.202467,-1.214938,0.989116,0.643527,-0.863045,0.247110,-0.121605,-0.324010,0.144081,-1.268001,1.283264,-0.464025,-0.242131,-0.185756,0.312051


In [26]:
y_train.value_counts()

koi_disposition
0    1589
1    1589
Name: count, dtype: int64

### Training

In [27]:
models = {
    "Logistic Regression": LogisticRegression(),
    "      Decision Tree": DecisionTreeClassifier(),
    "     Neural Network": MLPClassifier(),
    "      Random Forest": RandomForestClassifier(),
    "  Gradient Boosting": GradientBoostingClassifier(),
    "            XGBoost": XGBClassifier(),
    "           LightGBM": LGBMClassifier(),
    "           CatBoost": CatBoostClassifier(verbose=0)
}

In [28]:
for name, model in models.items():
    model.fit(X_train, y_train)
    print(name + " trained.")

Logistic Regression trained.
      Decision Tree trained.
     Neural Network trained.
      Random Forest trained.
  Gradient Boosting trained.
            XGBoost trained.
[LightGBM] [Info] Number of positive: 1589, number of negative: 1589
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000846 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 8583
[LightGBM] [Info] Number of data points in the train set: 3178, number of used features: 40
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
           LightGBM trained.
           CatBoost trained.


### Results

In [38]:
y_pred = model.predict(X_test)
y_pred

array([0, 1, 1, ..., 1, 1, 1])

In [40]:
y_test

4311    0
1808    1
993     1
1573    0
3566    0
       ..
642     1
2497    1
2223    1
2273    1
1999    0
Name: koi_disposition, Length: 1363, dtype: int64

In [41]:
def get_classifications(y_test, y_pred, positive_label=1):  # 1 means CONFIRMED
    tp = 0
    fn = 0
    fp = 0
    tn = 0

    for y_t, y_p in zip(y_test, y_pred):
        if y_t == positive_label:
            if y_p == positive_label:
                tp += 1
            else:
                fn += 1
        else:
            if y_p == positive_label:
                fp += 1
            else:
                tn += 1
    
    return tp, fn, fp, tn

In [42]:
get_classifications(y_test, y_pred, positive_label=1)

(615, 89, 187, 472)

In [43]:
def get_accuracy(tp, fn, fp, tn):
    acc = (tp + tn) / (tp + fn + fp + tn)
    return acc

In [44]:
get_accuracy(*get_classifications(y_test, y_pred, positive_label=1))

0.797505502567865

In [45]:
def get_precision(tp, fn, fp, tn):
    precision = tp / (tp + fp)
    return precision

def get_recall(tp, fn, fp, tn):
    recall = tp / (tp + fn)
    return recall

def get_f1_score(tp, fn, fp, tn):
    precision = get_precision(tp, fn, fp, tn)
    recall = get_recall(tp, fn, fp, tn)
    f1_score = (2 * precision * recall) / (precision + recall)
    return f1_score

In [50]:
for name, model in models.items():
    y_pred = model.predict(X_test)
    print(name + " Accuracy: {:.3f}%".format(get_accuracy(*get_classifications(y_test, y_pred, positive_label=1)) * 100))

Logistic Regression Accuracy: 79.751%
      Decision Tree Accuracy: 75.569%
     Neural Network Accuracy: 80.631%
      Random Forest Accuracy: 81.438%
  Gradient Boosting Accuracy: 82.025%
            XGBoost Accuracy: 81.511%
           LightGBM Accuracy: 81.731%
           CatBoost Accuracy: 82.392%


In [54]:
for name, model in models.items():
    y_pred = model.predict(X_test)
    print(name + " F1_Score: {:.5f}".format(get_f1_score(*get_classifications(y_test, y_pred, positive_label=1))))

Logistic Regression F1_Score: 0.81673
      Decision Tree F1_Score: 0.76400
     Neural Network F1_Score: 0.81356
      Random Forest F1_Score: 0.82394
  Gradient Boosting F1_Score: 0.83138
            XGBoost F1_Score: 0.82716
           LightGBM F1_Score: 0.82910
           CatBoost F1_Score: 0.83516
