In [2]:
import os
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
import pickle

# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only run this cell against model files from a trusted source.
def _load_pickled_model(path):
    """Open `path` in binary mode and return the object unpickled from it."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Three previously-saved models, each read from its own pickle file.
model_gbdt = _load_pickled_model('./gdbt.pickle')
model_lr = _load_pickled_model('./lr.pickle')
model_rf = _load_pickled_model('./rf.pickle')

# Show the NumPy version in use when this notebook was run.
print(np.version.version)

1.17.3


In [2]:
# Read the test feature table that the loaded models are applied to below.
test_features_path = './data/test_features.csv'
df_test = pd.read_csv(test_features_path)

In [3]:
def logdf(cols, df=None):
    """Replace one column with its natural logarithm, in place.

    Strictly positive entries become ``np.log(value)``. Every other entry
    (zero, negative, or NaN) becomes ``0.0``.

    Parameters
    ----------
    cols : str
        Label of the single column to transform.
    df : pandas.DataFrame, optional
        Frame to modify. When omitted, the notebook-level ``df_test`` is
        used, so existing one-argument calls keep working unchanged.

    Returns
    -------
    None
        The target frame is modified in place.
    """
    frame = df_test if df is None else df

    def _log_or_zero(value):
        # NaN fails the `> 0` comparison, so missing values also map to 0.0.
        # A float fallback (rather than int 0) keeps the column float64 even
        # when no entry in it is positive.
        return np.log(value) if value > 0 else 0.0

    frame[cols] = frame[cols].map(_log_or_zero)

def dffillmed(cols, df=None):
    """Fill the missing values of one column with that column's median, in place.

    Parameters
    ----------
    cols : str
        Label of the single column to fill.
    df : pandas.DataFrame, optional
        Frame to modify. When omitted, the notebook-level ``df_test`` is
        used, so existing one-argument calls keep working unchanged.

    Returns
    -------
    None
        The target frame is modified in place.

    Notes
    -----
    The median is computed from the same frame that is being filled. If the
    column contains only NaN, its median is NaN and the column is left as is.
    NOTE(review): the fill value therefore comes from the data being scored,
    not from the training data; confirm this matches how the models were fit.
    """
    frame = df_test if df is None else df
    column = frame[cols]
    frame[cols] = column.fillna(column.median())

# Keep the 'name' column for the submission files, then remove 'name' and
# 'email_address' from the frame that is passed to the models.
ids = df_test['name']
df_test = df_test.drop(columns=['email_address', 'name'])


def _log_of_negated(value):
    """Return log(-value) when -value is positive; 0 otherwise (NaN included)."""
    if -value > 0:
        return np.log(-value)
    return 0


# deferred_income is log-transformed on its negated value, unlike the columns
# below. NOTE(review): presumably because it is stored as negative amounts.
df_test['deferred_income'] = df_test['deferred_income'].map(_log_of_negated)

# Columns log-transformed in place by logdf, in the original call order.
# NOTE(review): logdf maps non-positive values to 0; if
# restricted_stock_deferred holds negative amounts it collapses to all zeros.
# Confirm this matches the preprocessing used when the models were trained.
for column in (
    'bonus',
    'deferral_payments',
    'director_fees',
    'exercised_stock_options',
    'expenses',
    'long_term_incentive',
    'other',
    'restricted_stock',
    'salary',
    'total_payments',
    'total_stock_value',
    'loan_advances',
    'restricted_stock_deferred',
):
    logdf(column)

# Columns whose missing values are replaced by the column median.
for column in (
    'from_poi_to_this_person',
    'from_messages',
    'to_messages',
    'shared_receipt_with_poi',
    'from_this_person_to_poi',
):
    dffillmed(column)

df_test.head()

Unnamed: 0,bonus,deferral_payments,deferred_income,director_fees,exercised_stock_options,expenses,from_messages,from_poi_to_this_person,from_this_person_to_poi,loan_advances,long_term_incentive,other,restricted_stock,restricted_stock_deferred,salary,shared_receipt_with_poi,to_messages,total_payments,total_stock_value
0,15.473738,14.57819,14.66328,0.0,13.767513,9.761636,484.0,228.0,108.0,0.0,0.0,12.258181,11.967619,0.0,12.273727,5521.0,7991.0,15.520555,13.920506
1,14.115615,0.0,6.725034,0.0,0.0,11.096,27.0,140.0,15.0,0.0,13.789467,7.390799,12.437403,0.0,12.537536,1593.0,1858.0,14.797435,12.437403
2,14.220976,0.0,14.952385,0.0,15.527144,10.435262,32.0,32.0,21.0,0.0,14.29609,9.336973,13.65659,0.0,12.402022,1035.0,1045.0,12.573081,15.670411
3,14.914123,0.0,0.0,0.0,14.644548,11.364124,3069.0,66.0,609.0,0.0,14.074007,7.415175,14.095524,0.0,12.808099,2097.0,3093.0,15.37323,15.100398
4,13.815511,0.0,12.367341,0.0,0.0,10.331171,49.0,58.0,12.0,0.0,12.765688,12.637514,14.732626,0.0,12.936489,1585.0,1892.0,14.440785,14.732626


In [4]:
# Score every test row with the gradient-boosting model and keep column 1 of
# the predict_proba output (presumably the POI class; verify via classes_).
gbdt_class_proba = model_gbdt.predict_proba(df_test)
gdbt_predf = gbdt_class_proba[:, 1]

# Pair each score with the saved names and write the submission file.
sub = pd.DataFrame(data={'name': ids, 'poi': gdbt_predf})
sub.to_csv(path_or_buf='sub_gdbt.csv', index=False)

gdbt_predf

array([0.33433268, 0.22698633, 0.29432686, 0.22698633, 0.10844339,
       0.30682965, 0.77306486, 0.07970951, 0.06278301, 0.10844339,
       0.07456377, 0.04611947, 0.04611947, 0.07456377, 0.40980962,
       0.04611947, 0.08912824, 0.12727306, 0.15253194, 0.07456377,
       0.84222868, 0.31993754, 0.06264158, 0.10164758, 0.08342958,
       0.04940703, 0.0635813 , 0.08342958, 0.11333933, 0.26740638,
       0.04611947, 0.10735956, 0.43133653])

In [5]:
# Score every test row with the random-forest model and keep column 1 of the
# predict_proba output (presumably the POI class; verify via classes_).
rf_class_proba = model_rf.predict_proba(df_test)
rf_predf = rf_class_proba[:, 1]

# Pair each score with the saved names and write the submission file.
sub = pd.DataFrame(data={'name': ids, 'poi': rf_predf})
sub.to_csv(path_or_buf='sub_rf.csv', index=False)

rf_predf

array([0.35784715, 0.35818102, 0.34990397, 0.09901528, 0.37771693,
       0.20985959, 0.62185878, 0.13085743, 0.11014447, 0.22035581,
       0.11535088, 0.00393317, 0.04382032, 0.12085743, 0.40000039,
       0.02357603, 0.12841374, 0.19877603, 0.26897985, 0.12525253,
       0.76811111, 0.33047397, 0.08294966, 0.16188178, 0.25768386,
       0.09485551, 0.08495389, 0.13800029, 0.19362835, 0.34002592,
       0.0147304 , 0.19188885, 0.38724354])

In [6]:
# Score every test row with the model loaded from lr.pickle and keep column 1
# of the predict_proba output (presumably the POI class; verify via classes_).
lr_class_proba = model_lr.predict_proba(df_test)
lr_predf = lr_class_proba[:, 1]

# Pair each score with the saved names and write the submission file.
sub = pd.DataFrame(data={'name': ids, 'poi': lr_predf})
sub.to_csv(path_or_buf='sub_lr.csv', index=False)

lr_predf

array([9.99999999e-01, 9.68947182e-01, 9.93296523e-01, 1.00000000e+00,
       9.72378567e-01, 5.05968027e-03, 9.99794615e-01, 9.86511638e-01,
       5.71784820e-01, 7.54244757e-01, 3.03530317e-41, 6.16963325e-03,
       9.26577623e-01, 3.72027363e-01, 1.10998753e-25, 1.62404107e-02,
       4.53405239e-02, 9.97434666e-01, 1.40380354e-01, 3.90449763e-08,
       2.87210878e-01, 9.99980602e-01, 7.64997369e-03, 5.59335744e-01,
       6.51621730e-01, 1.24926467e-03, 1.38590848e-07, 4.17809902e-04,
       5.90087818e-01, 9.55743193e-02, 5.07257340e-12, 1.04399472e-01,
       9.99970166e-01])