In [1]:
import os
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import LeaveOneOut, KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.metrics import classification_report

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

## New data

In [2]:
label_dict = {4:0, 5:1}

In [3]:
X_train = pd.read_csv("../data/HAPT/Train/X_train.txt", header=None, sep=" ")
y = pd.read_csv("../data/HAPT/Train/y_train.txt", header=None).rename(columns={0:"label"})
X_train["label"] = y.label
X_train = X_train.sample(frac=1.).reset_index(drop=True)

X_test = pd.read_csv("../data/HAPT/Test/X_test.txt", header=None, sep=" ")
y = pd.read_csv("../data/HAPT/Test/y_test.txt", header=None).rename(columns={0:"label"})
X_test["label"] = y.label 

In [4]:
X_train.shape, X_test.shape

((7767, 562), (3162, 562))

In [5]:
np.unique(X_train.label, return_counts=True)

(array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]),
 array([1226, 1073,  987, 1293, 1423, 1413,   47,   23,   75,   60,   90,
          57]))

In [6]:
train_idx = X_train.label.isin([4,5])
np.unique(X_train[train_idx].label, return_counts=True)

(array([4, 5]), array([1293, 1423]))

In [7]:
# Sitting vs standing
train_idx = X_train.label.isin([4,5])
X_train = X_train[train_idx].reset_index(drop=True)
X_train["label"] = X_train["label"].map(label_dict)

test_idx = X_test.label.isin([4,5])
X_test = X_test[test_idx].reset_index(drop=True)
X_test["label"] = X_test["label"].map(label_dict)
y_test = X_test["label"]

In [8]:
X_train.shape, X_test.shape

((2716, 562), (1064, 562))

In [9]:
np.unique(X_train.label, return_counts=True)

(array([0, 1]), array([1293, 1423]))

In [10]:
X_train.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335,336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351,352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367,368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,406,407,408,409,410,411,412,413,414,415,416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,431,432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447,448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463,464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479,480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495,496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511,512,513,514,515,516,517,518,519,520,521,522,523,524,525,526,527,528,529,530,531,532,533,534,535,536,537,538,539,540,541,542,543,544,545,546,547,548,549,550,551,552,553,554,555,556,557,558,559,560,label
0,0.03542,0.001244,-0.010663,-0.996965,-0.994525,-0.990779,-0.997072,-0.994399,-0.991379,-0.802993,-0.744936,-0.702598,0.839002,0.715419,0.681221,-0.992871,-0.999969,-0.999941,-0.999806,-0.996563,-0.994291,-0.990245,-0.904177,-0.543483,-0.349238,0.142745,-0.097043,0.100102,-0.118569,0.019684,-0.006496,0.131612,0.045949,0.471738,-0.288712,0.253717,-0.411336,-0.206082,-0.451335,0.076261,0.886118,0.215624,0.255853,-0.991231,-0.989879,-0.980523,-0.991377,-0.990075,-0.980368,0.818452,0.181203,0.251181,0.907625,0.232888,0.25161,0.150014,0.70225,-0.897752,-0.876327,-0.991703,-0.991181,-0.979701,-0.612051,-0.721574,-0.433234,-0.487969,0.49189,-0.496193,0.500874,0.052255,-0.033966,0.013231,0.008271,-0.241865,0.237972,-0.234041,0.229412,-0.986457,-0.998157,0.975032,0.07424,0.011114,-0.046336,-0.996211,-0.98959,-0.993607,-0.99604,-0.988517,-0.992395,-0.994844,-0.989503,-0.990397,0.997159,0.991751,0.993419,-0.99492,-0.999967,-0.999846,-0.999885,-0.99365,-0.990667,-0.987837,-0.862673,-0.814214,-0.766512,0.224222,-0.029104,0.334422,0.268444,0.06299,0.039067,0.173748,0.408439,0.497426,0.005516,0.532869,0.200108,-0.043368,0.192923,-0.16512,-0.02419,-0.021585,-0.063298,-0.996396,-0.989951,-0.994157,-0.996309,-0.989984,-0.994115,-0.909399,-0.950523,-0.763433,0.851441,0.892651,0.834727,-0.988358,-0.99998,-0.999925,-0.999815,-0.996048,-0.989814,-0.99331,-0.421608,-0.415358,-0.229107,0.295711,-0.319628,0.400343,-0.319997,-0.114132,0.127632,-0.073447,0.175588,0.075372,-0.079878,0.200359,-0.343324,0.330308,0.018889,0.235675,-0.100995,-0.042903,-0.054515,-0.995611,-0.994407,-0.996024,-0.994928,-0.994936,-0.996254,-0.997637,-0.994672,-0.99461,0.996294,0.993457,0.995877,-0.996076,-0.999974,-0.999971,-0.999968,-0.994042,-0.995711,-0.997245,-0.721776,-0.613829,-0.715049,0.316107,-0.060547,0.320592,0.325275,-0.024817,0.161394,0.113298,0.048386,0.28174,-0.065518,0.426458,0.297385,0.274669,-0.125117,0.081837,-0.992635,-0.991426,-0.991911,-0.98841,-0.994511,-0.992635,-0.999898,-0.992358,-0.620187,0.083632,-0.00062,0.013071,-0.067123,-0.992635,-0.991426,-0.991911,-0.98841,-0.994511,-0.992635,-0.999898,-0.992358,-0.620187,0.083632,-0.00062,0.013071,-0.067123,-0.995459,-0.99549,-0.995488,-0.992248,-0.986118,-0.995459,-0.999927,-0.99533,-0.889299,0.316468,-0.322731,-0.06574,-0.015007,-0.988727,-0.992831,-0.993004,-0.990934,-0.997683,-0.988727,-0.999894,-0.993834,-0.273966,-0.036579,-0.140001,0.269979,-0.245276,-0.995875,-0.995975,-0.996434,-0.994605,-0.987473,-0.995875,-0.999974,-0.996665,-0.73433,0.458161,-0.327315,-0.482512,0.363986,-0.996375,-0.990867,-0.990257,-0.996904,-0.995312,-0.990731,-0.996187,-0.991123,-0.989305,-0.997803,-0.997425,-0.990091,-0.999811,-0.997947,-0.987777,-0.995143,-0.999982,-0.999946,-0.999847,-0.997171,-0.992674,-0.989944,-1.0,-1.0,-0.868799,-0.866667,-0.466667,-1.0,0.008921,0.220599,0.273867,-0.573938,-0.897885,-0.675853,-0.907955,-0.244074,-0.550254,-0.999981,-0.999986,-0.999962,-0.999915,-0.999941,-0.999932,-0.99998,-1.0,-0.999982,-0.999948,-0.999951,-0.999987,-0.999983,-0.999923,-0.999981,-0.999853,-0.999878,-0.999882,-0.999764,-0.999935,-0.999886,-0.999991,-0.999954,-0.999853,-0.999837,-0.99993,-0.999949,-0.999883,-0.999847,-0.999925,-0.999884,-0.99976,-0.99986,-0.999695,-0.999614,-0.999906,-0.999852,-0.999869,-0.999843,-0.999701,-0.999851,-0.999802,-0.995908,-0.990851,-0.992758,-0.996955,-0.988714,-0.992251,-0.996021,-0.988253,-0.991522,-0.997066,-0.989793,-0.993088,-0.995579,-0.994608,-0.999414,-0.995075,-0.999967,-0.999846,-0.999885,-0.995286,-0.990702,-0.987349,-1.0,-1.0,-1.0,0.04,-0.16,-0.04,0.120674,-0.042636,0.327397,-0.544108,-0.850271,-0.49102,-0.878653,-0.615221,-0.920396,-0.999984,-0.999987,-0.999967,-0.999921,-0.99993,-0.99992,-0.999964,-0.999988,-0.999986,-0.99995,-0.99994,-0.999964,-0.999979,-0.99991,-0.999942,-0.999876,-0.999838,-0.99988,-0.999781,-0.99992,-0.999864,-0.999947,-0.999881,-0.999834,-0.999827,-0.999879,-0.999844,-0.999884,-0.999871,-0.999917,-0.999901,-0.999781,-0.999838,-0.999719,-0.999213,-0.999997,-0.999916,-0.999875,-0.999814,-0.999254,-0.999937,-0.999816,-0.995161,-0.991173,-0.993756,-0.996754,-0.989239,-0.994466,-0.99626,-0.990843,-0.992792,-0.997218,-0.990797,-0.99368,-0.99862,-0.99602,-0.999545,-0.993777,-0.999986,-0.999923,-0.999961,-0.996023,-0.992294,-0.991104,-0.920371,-0.764603,-0.918335,-0.666667,-1.0,-1.0,0.22992,-0.201507,-0.089774,-0.41123,-0.741018,-0.264591,-0.670335,-0.103213,-0.380194,-0.999989,-0.999989,-0.999956,-0.999989,-0.999974,-0.999969,-0.999977,-0.999991,-0.999989,-0.999961,-0.999971,-0.999983,-0.999987,-0.999985,-0.999915,-0.999987,-0.999985,-0.999978,-0.999979,-0.999973,-0.999982,-0.999997,-0.99992,-0.99998,-0.999978,-0.999986,-0.999919,-0.999977,-0.99997,-0.999977,-0.999954,-0.99997,-0.999988,-0.999976,-0.999967,-1.0,-0.999966,-0.999948,-0.999991,-0.999981,-0.999963,-0.999978,-0.989866,-0.991878,-0.987278,-0.993841,-0.992041,-0.989866,-0.999902,-0.986233,-0.907014,-1.0,0.118484,-0.471269,-0.736823,-0.995354,-0.994664,-0.994101,-0.995384,-0.988229,-0.995354,-0.999939,-0.994631,-1.0,-0.873016,0.418541,-0.541377,-0.857222,-0.993801,-0.993134,-0.991746,-0.994335,-0.99695,-0.993801,-0.999949,-0.989755,-0.872354,-1.0,-0.014357,-0.608602,-0.831991,-0.99598,-0.996044,-0.995122,-0.997514,-0.999088,-0.99598,-0.99998,-0.994012,-1.0,-0.904762,0.312585,-0.682602,-0.923821,0.033059,0.353255,-0.367092,0.66187,-0.66291,-0.062711,-0.15162,0
1,0.038503,-0.000789,-0.019411,-0.994402,-0.982502,-0.984675,-0.995733,-0.984795,-0.986861,-0.79994,-0.732582,-0.706772,0.832837,0.708327,0.667937,-0.991651,-0.999963,-0.999813,-0.999816,-0.996781,-0.9902,-0.988831,-0.72901,-0.516138,-0.5033,0.253187,-0.214929,0.397841,-0.025094,-0.087912,0.028365,0.171567,-0.122982,0.437685,-0.215445,0.425466,-0.213178,-0.263189,-0.527357,-0.047721,0.957914,0.071953,0.101605,-0.997536,-0.995399,-0.995815,-0.997509,-0.995939,-0.995661,0.887296,0.040492,0.092775,0.980489,0.093542,0.101875,-0.516426,0.886192,-0.985763,-0.981847,-0.997406,-0.997326,-0.995546,-1.0,-0.954192,-1.0,-0.463925,0.484668,-0.505578,0.526607,-0.521538,0.532148,-0.543447,0.554408,-0.061648,0.108646,-0.156138,0.203389,-0.302784,-0.898784,-0.006169,0.074009,0.02209,-0.042823,-0.992699,-0.983795,-0.978988,-0.993017,-0.983282,-0.976885,-0.992955,-0.987711,-0.978051,0.990258,0.985566,0.981832,-0.986558,-0.999924,-0.999713,-0.999472,-0.99265,-0.98523,-0.974122,-0.786949,-0.750298,-0.528008,0.171196,-0.316833,0.458083,0.354952,-0.001896,-0.017015,0.311631,0.06186,0.291209,-0.101619,0.544488,0.219357,-0.366857,-0.351201,0.38684,-0.027194,-0.036765,-0.062281,-0.992177,-0.982348,-0.978672,-0.992225,-0.981836,-0.976657,-0.906512,-0.955029,-0.751385,0.848167,0.883707,0.824727,-0.97857,-0.999959,-0.99971,-0.999529,-0.99388,-0.981471,-0.971514,-0.438845,-0.573072,-0.099185,0.21391,-0.283044,0.209551,0.168093,-0.31715,0.188866,0.059083,0.064369,0.111712,-0.196749,0.220005,-0.066963,0.279944,-0.07566,-0.262184,-0.098593,-0.03291,-0.072132,-0.994262,-0.992741,-0.991164,-0.994484,-0.993057,-0.991377,-0.996206,-0.994561,-0.988409,0.993849,0.994494,0.993968,-0.993898,-0.999962,-0.999956,-0.999907,-0.994355,-0.994118,-0.991641,-0.662544,-0.505138,-0.58854,0.159513,-0.135817,0.057017,0.035122,-0.207639,0.100108,0.032251,0.188145,0.238463,-0.275244,0.27619,0.207153,0.084627,0.009013,-0.097697,-0.99089,-0.986675,-0.988334,-0.984544,-0.998063,-0.99089,-0.999853,-0.98971,-0.553766,0.155735,-0.082536,0.149601,-0.178857,-0.99089,-0.986675,-0.988334,-0.984544,-0.998063,-0.99089,-0.999853,-0.98971,-0.553766,0.155735,-0.082536,0.149601,-0.178857,-0.986695,-0.983208,-0.985234,-0.978383,-0.997181,-0.986695,-0.999698,-0.989527,-0.672109,0.536245,-0.228893,-0.373288,-0.09413,-0.978615,-0.982489,-0.979787,-0.988254,-0.991196,-0.978615,-0.999676,-0.976499,-0.008496,-0.125536,-0.081584,0.279558,-0.265848,-0.993698,-0.993899,-0.993919,-0.995555,-0.992203,-0.993698,-0.999954,-0.994066,-0.63244,0.449203,-0.272854,-0.171778,-0.044645,-0.993512,-0.975837,-0.977176,-0.994448,-0.984559,-0.987776,-0.992683,-0.975716,-0.977336,-0.997276,-0.988983,-0.992713,-0.998011,-0.99403,-0.984248,-0.984646,-0.999959,-0.999724,-0.999699,-0.99219,-0.981661,-0.973691,-1.0,-0.812205,-0.73874,-0.066667,-1.0,-0.307692,0.028919,-0.031091,0.425218,-0.571304,-0.920549,-0.425989,-0.751104,-0.704424,-0.91797,-0.999979,-0.999906,-0.999953,-0.999943,-0.999949,-0.999912,-0.999925,-0.999995,-0.999958,-0.999948,-0.999949,-0.999949,-0.99996,-0.999942,-0.999779,-0.999808,-0.999784,-0.999691,-0.999808,-0.999731,-0.99973,-0.999821,-0.999744,-0.999707,-0.999791,-0.999763,-0.999732,-0.999727,-0.999902,-0.999112,-0.999385,-0.999546,-0.99976,-0.999508,-0.998875,-0.999788,-0.999776,-0.999473,-0.999713,-0.999139,-0.999721,-0.999612,-0.992874,-0.983793,-0.97652,-0.993121,-0.984952,-0.979256,-0.990764,-0.985561,-0.976882,-0.993934,-0.98758,-0.984019,-0.999335,-0.994662,-0.954616,-0.985061,-0.999924,-0.999714,-0.999472,-0.989395,-0.989953,-0.973556,-1.0,-1.0,-0.900162,-0.44,-0.4,0.0,-0.087992,-0.263403,0.213456,-0.516648,-0.842516,-0.587984,-0.904838,-0.524854,-0.909319,-0.999977,-0.999894,-0.999953,-0.999942,-0.999949,-0.999875,-0.999873,-0.999999,-0.999923,-0.999948,-0.999934,-0.999877,-0.999923,-0.999922,-0.999729,-0.999807,-0.999715,-0.999741,-0.999834,-0.99979,-0.999774,-0.999992,-0.999768,-0.999678,-0.999814,-0.999805,-0.999708,-0.999791,-0.999863,-0.999106,-0.999462,-0.999547,-0.999734,-0.999462,-0.998258,-0.999399,-0.99922,-0.999536,-0.999657,-0.998279,-0.999328,-0.999613,-0.991304,-0.98458,-0.975855,-0.992382,-0.981115,-0.980532,-0.99121,-0.983679,-0.977167,-0.993441,-0.983431,-0.979791,-0.999398,-0.993093,-0.982503,-0.985022,-0.999954,-0.999799,-0.999695,-0.993508,-0.987269,-0.988416,-0.757285,-0.583101,-0.671907,-1.0,-1.0,-0.935484,-0.221375,-0.250808,0.024835,-0.308831,-0.669244,-0.189812,-0.592793,0.074656,-0.257604,-0.999957,-0.999967,-0.999983,-0.999983,-0.999965,-0.999958,-0.999986,-1.0,-0.999955,-0.99998,-0.99996,-0.999992,-0.999955,-0.999977,-0.999784,-0.999937,-0.999975,-0.999982,-0.999981,-0.99995,-0.999893,-0.999895,-0.999775,-0.999971,-0.999975,-0.99988,-0.99978,-0.999978,-0.999733,-0.999902,-0.999897,-0.999919,-0.999862,-0.999735,-0.999463,-0.999376,-0.999712,-0.999868,-0.999834,-0.999425,-0.999704,-0.999894,-0.98236,-0.988557,-0.982077,-0.990743,-0.995215,-0.98236,-0.999816,-0.982431,-0.907014,-1.0,0.26854,-0.401452,-0.667169,-0.981783,-0.984235,-0.985723,-0.983222,-0.976423,-0.981783,-0.999669,-0.991703,-0.940965,-1.0,0.504597,-0.091535,-0.471509,-0.987541,-0.981959,-0.983286,-0.984844,-0.995211,-0.987541,-0.999785,-0.986306,-0.720683,-0.948718,-0.213456,-0.10444,-0.471535,-0.994605,-0.99312,-0.994573,-0.991065,-0.995912,-0.994605,-0.999963,-0.996378,-0.955696,-1.0,0.308836,0.402921,0.198318,0.138377,0.118956,-0.024701,0.031174,-0.876519,0.037477,-0.042464,0
2,0.040654,-0.00607,-0.020801,-0.99038,-0.984304,-0.980658,-0.991889,-0.984965,-0.983982,-0.793232,-0.743581,-0.703319,0.833689,0.706325,0.663834,-0.98922,-0.999917,-0.999832,-0.999741,-0.994768,-0.988544,-0.989937,-0.568276,-0.667365,-0.51208,0.181708,-0.116273,0.032316,-0.051826,0.087924,-0.095985,0.250302,-0.273406,0.272754,-0.112638,0.108686,-0.159575,-0.724445,-0.685584,0.599023,0.812746,0.121869,0.447933,-0.994663,-0.996169,-0.986977,-0.99527,-0.996238,-0.987872,0.745883,0.088074,0.437965,0.838241,0.142895,0.441891,0.266807,0.52401,-0.964296,-0.613477,-0.997288,-0.996474,-0.990671,-0.981083,-0.739143,-0.674898,-0.689338,0.698632,-0.708198,0.718017,-0.405246,0.417883,-0.431431,0.444725,-0.33406,0.336066,-0.338079,0.33937,-0.823958,-0.958774,0.658754,0.079842,-0.007984,-0.053403,-0.992129,-0.985321,-0.985697,-0.991761,-0.984598,-0.983257,-0.992209,-0.989075,-0.988704,0.992711,0.989831,0.986292,-0.988686,-0.999915,-0.999751,-0.999687,-0.990754,-0.987515,-0.978028,-0.7344,-0.725725,-0.664244,0.257212,0.016135,0.280037,0.303088,0.220945,-0.053435,0.496839,0.052726,0.272245,0.160121,0.195341,0.527108,-0.378371,-0.137973,0.143286,-0.023937,-0.030955,-0.069836,-0.993151,-0.981446,-0.984192,-0.994513,-0.984059,-0.986983,-0.903103,-0.955776,-0.743353,0.851131,0.877835,0.828112,-0.986114,-0.999958,-0.999748,-0.999775,-0.996537,-0.987331,-0.989992,-0.359483,-0.530308,-0.234732,0.356684,-0.367705,0.133649,0.338421,-0.153094,0.045096,0.003381,0.309517,0.013254,-0.068606,0.019757,0.181507,-0.604742,0.356322,-0.735743,-0.101374,-0.041089,-0.051476,-0.993749,-0.992837,-0.992745,-0.993851,-0.992809,-0.993234,-0.992937,-0.995801,-0.988167,0.99391,0.992728,0.994718,-0.994025,-0.999957,-0.999957,-0.99993,-0.993792,-0.99276,-0.995489,-0.662237,-0.527104,-0.600184,0.297288,-0.126428,-0.040564,0.023716,-0.090152,0.010868,-0.032526,0.130716,0.130271,-0.218188,0.251788,-0.128299,-0.182174,-0.08359,-0.299698,-0.989627,-0.981974,-0.984365,-0.978848,-0.995296,-0.989627,-0.999809,-0.987671,-0.512689,0.38373,-0.33061,0.340104,-0.272705,-0.989627,-0.981974,-0.984365,-0.978848,-0.995296,-0.989627,-0.999809,-0.987671,-0.512689,0.38373,-0.33061,0.340104,-0.272705,-0.989072,-0.990148,-0.989708,-0.990521,-0.981605,-0.989072,-0.999793,-0.989155,-0.800439,0.251652,0.028258,-0.591902,0.194353,-0.986868,-0.973845,-0.980245,-0.97811,-0.995523,-0.986868,-0.99977,-0.991412,-0.163391,-0.032848,-0.341886,0.391815,0.113748,-0.993858,-0.994422,-0.994474,-0.992964,-0.983091,-0.993858,-0.999956,-0.993349,-0.635585,0.499567,-0.20261,-0.257807,-0.120131,-0.989888,-0.974668,-0.976833,-0.990175,-0.987374,-0.982384,-0.989564,-0.979791,-0.975241,-0.992757,-0.991235,-0.986204,-0.996029,-0.980522,-0.97808,-0.982319,-0.999911,-0.999767,-0.999582,-0.992273,-0.986281,-0.981787,-0.844178,-0.83986,-0.760627,-0.866667,-0.933333,-1.0,-0.111512,0.236628,0.158195,-0.311807,-0.743992,-0.460578,-0.771823,-0.372953,-0.704217,-0.999895,-0.999974,-0.999901,-0.999915,-0.999829,-0.99985,-0.999942,-0.999973,-0.99991,-0.999895,-0.999851,-0.999952,-0.999912,-0.999875,-0.99982,-0.999884,-0.999691,-0.999535,-0.999625,-0.999715,-0.999198,-0.999323,-0.999802,-0.999572,-0.999651,-0.999228,-0.999782,-0.999568,-0.999678,-0.999539,-0.999551,-0.999869,-0.999575,-0.999473,-0.999768,-0.999784,-0.999622,-0.999697,-0.999555,-0.999771,-0.999585,-0.999797,-0.991747,-0.984584,-0.983261,-0.993329,-0.987409,-0.986053,-0.991046,-0.987822,-0.98598,-0.995023,-0.990333,-0.984469,-0.996055,-0.989817,-0.987015,-0.987395,-0.999915,-0.999752,-0.999688,-0.989269,-0.99149,-0.984098,-1.0,-1.0,-0.942216,-0.16,-0.32,-0.28,0.134694,-0.011571,0.292372,-0.648096,-0.892613,-0.611992,-0.915987,-0.453439,-0.806323,-0.999967,-0.999971,-0.9999,-0.999923,-0.999806,-0.999794,-0.99992,-0.999998,-0.999968,-0.999898,-0.999805,-0.999923,-0.999939,-0.999822,-0.999894,-0.999923,-0.999632,-0.999689,-0.999884,-0.999661,-0.999809,-0.999972,-0.999912,-0.999593,-0.9998,-0.999833,-0.999776,-0.999754,-0.999904,-0.999598,-0.999618,-0.999832,-0.999513,-0.99936,-0.99957,-0.999812,-0.999656,-0.999755,-0.999457,-0.999571,-0.999628,-0.999707,-0.991925,-0.986127,-0.983016,-0.993465,-0.979011,-0.985281,-0.993178,-0.984174,-0.980729,-0.994144,-0.984139,-0.990597,-0.995972,-0.999711,-0.991516,-0.987686,-0.999963,-0.99978,-0.999817,-0.995057,-0.992483,-0.991409,-0.776641,-0.626021,-0.727855,-0.733333,-1.0,-0.935484,0.091097,-0.414553,-0.287124,-0.242231,-0.620581,-0.131237,-0.595565,-0.314889,-0.721829,-0.999966,-0.99997,-0.999993,-0.999964,-0.999947,-0.999912,-0.99998,-0.999985,-0.999964,-0.999981,-0.999928,-0.999982,-0.999965,-0.999955,-0.999724,-0.999973,-0.999983,-0.999984,-0.999971,-0.99996,-0.999961,-1.0,-0.999748,-0.999979,-0.999969,-0.999973,-0.999759,-0.999978,-0.999833,-0.999953,-0.999953,-0.999952,-0.999931,-0.999963,-0.999876,-0.999962,-0.999824,-0.999938,-0.999946,-0.999914,-0.99982,-0.999952,-0.980706,-0.982544,-0.978051,-0.985658,-0.998103,-0.980706,-0.999709,-0.98213,-0.907014,-1.0,0.078245,-0.0725,-0.413015,-0.989624,-0.989822,-0.988635,-0.992775,-0.99642,-0.989624,-0.99984,-0.990304,-1.0,-0.84127,0.510651,-0.55211,-0.846664,-0.98293,-0.972527,-0.973344,-0.98093,-0.996281,-0.98293,-0.999567,-0.989476,-0.696828,-0.846154,-0.493275,-0.361804,-0.746416,-0.994039,-0.995226,-0.99388,-0.99643,-0.993336,-0.994039,-0.999968,-0.991946,-1.0,-1.0,0.310746,-0.658474,-0.900227,-0.040788,0.093908,-0.351106,0.380865,-0.53408,0.002688,-0.294332,0
3,0.040935,0.001744,-0.007872,-0.994284,-0.98115,-0.984947,-0.995181,-0.981656,-0.987317,-0.793059,-0.731085,-0.688507,0.835551,0.708344,0.681482,-0.987378,-0.999959,-0.999765,-0.999644,-0.996155,-0.984743,-0.990012,-0.601784,-0.412346,-0.243825,0.293067,-0.060975,0.19621,0.012734,0.055635,-0.123258,0.175932,0.156006,0.466763,-0.361396,0.331065,-0.116174,0.088209,-0.370099,0.34578,0.94885,-0.230767,-0.069948,-0.99622,-0.985302,-0.981059,-0.996207,-0.984952,-0.980859,0.879105,-0.253899,-0.071324,0.971366,-0.208182,-0.073448,-0.228373,0.862451,-0.908172,-0.988624,-0.996362,-0.98464,-0.980241,-1.0,-1.0,-1.0,-0.204667,0.228406,-0.252781,0.277743,-0.55897,0.568518,-0.578861,0.589027,-0.457191,0.463261,-0.469306,0.474515,0.941429,0.980364,0.984643,0.07586,-0.004762,-0.049723,-0.988752,-0.980325,-0.98115,-0.990335,-0.977605,-0.976691,-0.982952,-0.987343,-0.990058,0.987575,0.984587,0.980817,-0.983785,-0.99986,-0.999618,-0.999546,-0.990405,-0.979812,-0.969542,-0.713384,-0.691235,-0.541666,0.232343,0.049382,0.383307,0.52365,0.048552,-0.185731,0.154979,0.104439,0.385153,-0.173431,0.302589,0.377066,-0.317137,-0.43156,0.3373,-0.028051,-0.016027,-0.08308,-0.991918,-0.983715,-0.974937,-0.993223,-0.984885,-0.979555,-0.907524,-0.947943,-0.755984,0.841695,0.888063,0.794238,-0.983856,-0.999956,-0.999818,-0.999577,-0.994989,-0.987254,-0.985063,-0.500954,-0.195713,-0.443467,0.28603,-0.281165,0.145747,0.331861,-0.360774,0.317188,-0.176454,0.336883,-0.107569,-0.04568,0.258342,-0.058074,-0.353835,0.456072,-0.231534,-0.095557,-0.037328,-0.05102,-0.99124,-0.990806,-0.982499,-0.991697,-0.991867,-0.983912,-0.990794,-0.990532,-0.977302,0.987031,0.988198,0.9817,-0.990845,-0.999929,-0.999935,-0.999742,-0.991857,-0.993148,-0.987061,-0.601313,-0.46426,-0.396492,0.150774,-0.050611,-0.058841,0.230015,-0.351167,0.226139,-0.128388,0.152032,-0.026375,-0.329916,0.442169,-0.085798,-0.163437,0.527209,-0.227598,-0.986373,-0.984347,-0.986374,-0.977503,-0.99457,-0.986373,-0.999765,-0.988785,-0.506101,0.224146,-0.214736,0.133732,0.10768,-0.986373,-0.984347,-0.986374,-0.977503,-0.99457,-0.986373,-0.999765,-0.988785,-0.506101,0.224146,-0.214736,0.133732,0.10768,-0.984584,-0.985771,-0.98692,-0.982524,-0.987464,-0.984584,-0.999672,-0.988352,-0.696405,0.352491,-0.352626,-0.051052,-0.027066,-0.984515,-0.975747,-0.978137,-0.97921,-0.993992,-0.984515,-0.999747,-0.987129,-0.020976,0.019662,-0.279668,0.28685,0.023005,-0.991029,-0.988799,-0.989477,-0.98907,-0.983863,-0.991029,-0.999911,-0.989995,-0.447275,0.196204,-0.196459,0.227267,-0.190552,-0.992326,-0.97113,-0.977707,-0.994965,-0.98438,-0.987924,-0.993389,-0.972502,-0.979688,-0.996348,-0.990907,-0.990808,-0.993953,-0.981269,-0.990459,-0.982777,-0.999958,-0.99969,-0.999707,-0.992228,-0.979765,-0.978509,-1.0,-0.763544,-0.809888,-0.2,-1.0,-0.461538,0.288129,0.072821,0.465497,-0.594804,-0.901859,-0.564082,-0.85385,-0.600214,-0.838299,-0.999986,-0.999929,-0.999827,-0.999863,-0.999811,-0.999872,-0.999919,-0.999994,-0.99997,-0.999817,-0.999848,-0.999944,-0.999962,-0.999837,-0.999754,-0.999817,-0.999657,-0.999763,-0.999758,-0.999418,-0.999677,-0.999556,-0.999721,-0.999604,-0.999634,-0.999619,-0.999699,-0.999732,-0.999802,-0.999585,-0.999626,-0.999743,-0.999617,-0.998794,-0.999084,-0.999825,-0.999752,-0.999699,-0.999412,-0.999299,-0.999723,-0.999667,-0.988916,-0.979906,-0.977299,-0.989545,-0.982308,-0.983133,-0.986941,-0.982585,-0.98069,-0.99138,-0.987418,-0.987306,-0.996652,-0.99456,-0.992606,-0.982224,-0.99986,-0.999618,-0.999546,-0.985929,-0.985899,-0.970556,-1.0,-1.0,-1.0,-0.12,0.24,-0.72,0.196066,-0.214139,0.249145,-0.363521,-0.783573,-0.638292,-0.93046,-0.761859,-0.955743,-0.999988,-0.999907,-0.999825,-0.999858,-0.999791,-0.999815,-0.999858,-0.999968,-0.999936,-0.99981,-0.999803,-0.999856,-0.999884,-0.999772,-0.999205,-0.999871,-0.999615,-0.999725,-0.99982,-0.999315,-0.999834,-0.999988,-0.999703,-0.999597,-0.9996,-0.999857,-0.999616,-0.999699,-0.999084,-0.999495,-0.99966,-0.999759,-0.999625,-0.998543,-0.998455,-0.999984,-0.999276,-0.999741,-0.999296,-0.998532,-0.999489,-0.999599,-0.989544,-0.985182,-0.969656,-0.992615,-0.9829,-0.977717,-0.991341,-0.984047,-0.964681,-0.993464,-0.989203,-0.987286,-0.996882,-0.999661,-0.993581,-0.983159,-0.999952,-0.999825,-0.999595,-0.988947,-0.986476,-0.970083,-0.739371,-0.589162,-0.572635,-0.666667,-0.806452,-0.548387,0.052944,-0.291309,-0.339716,-0.326756,-0.671264,-0.413457,-0.811805,-0.520215,-0.851779,-0.999961,-0.999934,-0.999959,-0.999958,-0.999935,-0.999939,-0.999976,-0.999993,-0.999954,-0.999951,-0.999932,-0.999983,-0.999953,-0.999951,-0.999809,-0.999956,-0.99997,-0.99998,-0.999969,-0.999976,-0.999982,-0.999999,-0.99981,-0.999966,-0.999971,-0.999987,-0.99981,-0.999976,-0.999647,-0.999863,-0.999776,-0.999956,-0.999943,-0.999894,-0.999774,-0.999972,-0.999618,-0.999771,-0.999935,-0.999861,-0.999598,-0.999951,-0.978922,-0.986802,-0.977083,-0.992296,-0.99551,-0.978922,-0.999766,-0.98078,-0.843671,-1.0,0.138142,-0.64403,-0.88366,-0.983118,-0.989362,-0.988406,-0.990667,-0.98407,-0.983118,-0.999739,-0.992204,-1.0,-0.936508,0.44214,-0.532882,-0.826608,-0.982944,-0.975091,-0.976447,-0.980482,-0.992101,-0.982944,-0.999621,-0.989218,-0.668778,-1.0,-0.323567,-0.244994,-0.636489,-0.989163,-0.988823,-0.988385,-0.988464,-0.991157,-0.989163,-0.999903,-0.990197,-0.895847,-1.0,0.096396,-0.147704,-0.498977,0.010854,-0.133821,0.108071,-0.25942,-0.775786,0.244795,0.076339,0
4,0.039243,-0.002173,-0.019856,-0.998294,-0.993611,-0.994571,-0.998097,-0.992928,-0.995191,-0.802836,-0.749987,-0.718481,0.840266,0.715508,0.678425,-0.997709,-0.999993,-0.999958,-0.999951,-0.997444,-0.992602,-0.995776,-0.787241,-0.654151,-0.629414,0.156553,-0.006346,-0.042483,0.192961,0.194353,-0.178006,0.255051,-0.244694,0.226299,0.025947,-0.100921,0.16582,0.377159,0.071268,-0.01761,0.923137,-0.308424,-0.060548,-0.999155,-0.996801,-0.998611,-0.999233,-0.99672,-0.998643,0.852053,-0.33225,-0.068623,0.946861,-0.281651,-0.058973,-0.091163,0.795914,-0.830797,-0.991302,-0.999436,-0.996901,-0.998717,-1.0,-1.0,-1.0,-0.493111,0.511816,-0.530411,0.548846,-0.548545,0.556318,-0.564493,0.572076,-0.440564,0.456204,-0.471417,0.485358,0.878195,0.043122,-0.390797,0.074917,0.00917,-0.042443,-0.997817,-0.996562,-0.994396,-0.997815,-0.996354,-0.99287,-0.996954,-0.998259,-0.994479,0.995762,0.995167,0.991982,-0.998054,-0.999983,-0.999961,-0.999902,-0.996541,-0.995687,-0.987733,-0.901594,-0.921923,-0.786031,0.142663,0.010147,0.263648,0.050397,0.300654,-0.065104,0.449595,0.194131,0.060506,0.235088,-0.029035,0.371442,0.249309,0.412934,0.031734,-0.027316,-0.024034,-0.077564,-0.995506,-0.994292,-0.994869,-0.995558,-0.994325,-0.994401,-0.910077,-0.958733,-0.771712,0.849056,0.891669,0.827558,-0.994562,-0.999983,-0.999967,-0.999968,-0.995974,-0.99454,-0.99321,-0.546972,-0.641066,-0.591373,0.285979,-0.292388,0.251412,-0.144236,-0.08184,-0.020665,0.112863,0.140359,0.337301,-0.317363,0.355693,-0.236417,-0.595912,-0.034635,-0.374959,-0.094309,-0.045274,-0.05753,-0.997531,-0.99901,-0.99744,-0.997302,-0.999103,-0.997423,-0.997861,-0.997759,-0.995983,0.997363,0.998027,0.998018,-0.999087,-0.999988,-0.999997,-0.999981,-0.996793,-0.998789,-0.998173,-0.794637,-0.944673,-0.73515,0.290104,0.008753,0.184121,0.261203,0.010274,-0.016436,0.104685,0.128783,0.471061,-0.242864,0.430355,0.548681,-0.203208,-0.080287,0.206637,-0.99751,-0.997822,-0.997974,-0.99704,-0.999494,-0.99751,-0.999974,-0.996893,-0.890724,0.410168,-0.38771,0.31677,-0.074249,-0.99751,-0.997822,-0.997974,-0.99704,-0.999494,-0.99751,-0.999974,-0.996893,-0.890724,0.410168,-0.38771,0.31677,-0.074249,-0.998021,-0.997852,-0.997981,-0.99734,-0.98981,-0.998021,-0.99997,-0.997389,-0.894566,0.066369,0.139557,-0.598701,0.313852,-0.99477,-0.995684,-0.995489,-0.996014,-0.994713,-0.99477,-0.999964,-0.995905,-0.418765,0.086549,-0.387595,0.583157,-0.408718,-0.99912,-0.999184,-0.99906,-0.998915,-0.996485,-0.99912,-0.999996,-0.998901,-0.896688,0.601596,-0.566242,0.139388,-0.102166,-0.998163,-0.994538,-0.993245,-0.998003,-0.993014,-0.994681,-0.997598,-0.994006,-0.992144,-0.998416,-0.99039,-0.99458,-0.999851,-0.997921,-0.99125,-0.998194,-0.999991,-0.999934,-0.99992,-0.997221,-0.99501,-0.99001,-1.0,-0.94487,-0.94419,-1.0,-1.0,-0.846154,0.024762,0.01238,0.286769,-0.535402,-0.874656,0.463657,0.324596,-0.572777,-0.809017,-0.999991,-0.999982,-0.99999,-0.99993,-0.999984,-0.999963,-0.999987,-1.0,-0.999989,-0.999976,-0.999989,-0.999991,-0.999992,-0.999952,-0.99993,-0.999983,-0.999953,-0.999941,-0.999921,-0.999863,-0.999982,-0.999997,-0.999933,-0.999942,-0.999924,-0.999991,-0.999935,-0.999953,-0.999925,-0.999853,-0.999897,-0.999847,-0.999856,-0.999806,-0.999966,-0.999966,-0.999918,-0.999909,-0.999869,-0.99997,-0.999917,-0.999871,-0.997863,-0.99683,-0.994021,-0.997941,-0.996427,-0.992515,-0.997302,-0.997058,-0.991773,-0.998011,-0.995984,-0.995472,-0.993058,-0.996917,-0.994616,-0.998457,-0.999983,-0.999961,-0.999902,-0.995374,-0.998577,-0.987893,-1.0,-1.0,-1.0,0.16,-0.2,-0.48,0.136375,0.055474,0.073735,-0.468116,-0.842824,-0.522713,-0.870266,-0.638468,-0.941755,-0.999994,-0.999983,-0.999993,-0.999919,-0.999983,-0.999944,-0.999978,-0.999993,-0.999988,-0.99997,-0.999985,-0.999979,-0.99999,-0.99994,-0.99997,-0.999985,-0.999925,-0.999952,-0.999941,-0.999866,-0.999971,-0.999997,-0.999986,-0.999931,-0.999933,-0.999978,-0.999959,-0.99997,-0.999845,-0.999807,-0.999907,-0.999832,-0.999842,-0.999819,-0.999957,-0.999989,-0.999813,-0.999903,-0.999847,-0.999958,-0.999896,-0.99986,-0.994518,-0.995395,-0.995166,-0.995754,-0.993622,-0.994932,-0.995929,-0.995294,-0.994508,-0.996122,-0.993465,-0.99452,-0.998879,-0.994205,-0.994951,-0.995818,-0.999981,-0.999969,-0.999968,-0.996624,-0.997674,-0.995676,-0.866126,-0.84774,-0.918335,-1.0,-1.0,-1.0,0.115057,-0.11857,0.045532,-0.247043,-0.640112,-0.084173,-0.490173,-0.021969,-0.379231,-0.999982,-0.999994,-0.99998,-0.999989,-0.999968,-0.99996,-0.999977,-0.999983,-0.999982,-0.999981,-0.999963,-0.99998,-0.999982,-0.999983,-0.999961,-0.999996,-0.999996,-0.999998,-0.99999,-0.999984,-0.99997,-0.999971,-0.999965,-0.999996,-0.999989,-0.999967,-0.999966,-0.999996,-0.999972,-0.999987,-0.999962,-0.999994,-0.999984,-0.999943,-0.999954,-0.999988,-0.999971,-0.999967,-0.999979,-0.999969,-0.999968,-0.999991,-0.997708,-0.997622,-0.996884,-0.997845,-0.997672,-0.997708,-0.999982,-0.996512,-1.0,-0.578947,0.574853,-0.549423,-0.780728,-0.997638,-0.997222,-0.996046,-0.997695,-0.988506,-0.997638,-0.999973,-0.994692,-1.0,-0.777778,0.372216,-0.732348,-0.935462,-0.997381,-0.994996,-0.995003,-0.994562,-0.997679,-0.997381,-0.999975,-0.995413,-0.956057,-1.0,-0.112523,-0.219762,-0.489628,-0.999177,-0.999092,-0.998995,-0.998968,-0.999893,-0.999177,-0.999997,-0.998142,-1.0,-0.936508,0.469461,-0.657031,-0.84385,0.058502,0.094588,-0.305328,-0.611742,-0.702566,0.299004,0.069815,1


# Logreg

In [11]:
def get_predictions(x):
    return [1 if xi >= 0.5 else 0 for xi in x]

In [12]:
X, y = X_train[[col for col in X_train.columns if col != "label"]].values , X_train.label

In [13]:
X.shape, y.shape

((2716, 561), (2716,))

In [14]:
loo = KFold(n_splits=10)
preds = np.zeros(len(y))
for i, (train_index, test_index) in enumerate(loo.split(X)):
    _X_train, _X_test = X[test_index, :], X[train_index,:]
    _y_train = y[test_index].values
    clf = LogisticRegression(random_state=0, C=1., max_iter=500).fit(_X_train, _y_train)
    preds[train_index] = clf.predict_proba(_X_test)[:,1]

print(f"Models AUC score: {roc_auc_score(y, preds)}")
print(classification_report(y, get_predictions(preds)))

Models AUC score: 0.9746627469715028
              precision    recall  f1-score   support

           0       0.93      0.90      0.92      1293
           1       0.91      0.94      0.93      1423

    accuracy                           0.92      2716
   macro avg       0.92      0.92      0.92      2716
weighted avg       0.92      0.92      0.92      2716



In [15]:
# Logreg MLE
log_reg = LogisticRegression(random_state=0, C=1.0, max_iter = 500).fit(X[:561,:], y[:561])

In [16]:
_X_test = X_test[[col for col in X_test.columns if col != "label"]].values
y_pred = log_reg.predict_proba(_X_test)[:,1]

In [17]:
print(f"Models AUC score: {roc_auc_score(y_test, y_pred)}")

Models AUC score: 0.9780455163428313


# Gaussian logreg

## MAP l2 estimate C=0.3

In [18]:
loo = KFold(n_splits=10)
preds = np.zeros(len(y))
for i, (train_index, test_index) in enumerate(loo.split(X)):
    _X_train, _X_test = X[test_index, :], X[train_index,:]
    _y_train = y[test_index].values
    clf = LogisticRegression(random_state=0, C=.3, max_iter=500).fit(_X_train, _y_train)
    preds[train_index] = clf.predict_proba(_X_test)[:,1]

print(f"Models AUC score: {roc_auc_score(y, preds)}")
print(classification_report(y, get_predictions(preds)))

Models AUC score: 0.9673130467912251
              precision    recall  f1-score   support

           0       0.93      0.87      0.89      1293
           1       0.89      0.94      0.91      1423

    accuracy                           0.90      2716
   macro avg       0.91      0.90      0.90      2716
weighted avg       0.90      0.90      0.90      2716



In [19]:
log_reg = LogisticRegression(random_state=0, C=.3, penalty="l2", max_iter = 500).fit(X[:561,:], y[:561])

In [20]:
_X_test = X_test[[col for col in X_test.columns if col != "label"]].values
y_pred = log_reg.predict_proba(_X_test)[:,1]
print(f"Models AUC score: {roc_auc_score(y_test, y_pred)}")

Models AUC score: 0.9780101115957627


## Bayesian gaussian

In [22]:
import pymc3 as pm
import theano as tt
from scipy.special import expit

In [23]:
with pm.Model() as model:
    # Alpha is the interception
    alpha = pm.Normal("alpha", mu=0, sd=3)
    # The prior for the features varibles which are included
    beta = pm.Normal("beta", mu=0, sd=3, shape=X.shape[1])
    # Deterministic function
    p = pm.math.dot(X,beta)
    # Likelihood
    y_obs = pm.Bernoulli("y_obs", pm.invlogit(p + alpha),  observed=y)
    

In [24]:
with model:
    trace = pm.sample(2000, random_seed = 4816, cores = 1, progressbar = True, chains = 1)

  trace = pm.sample(2000, random_seed = 4816, cores = 1, progressbar = True, chains = 1)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [beta, alpha]


Sampling 1 chain for 1_000 tune and 2_000 draw iterations (1_000 + 2_000 draws total) took 4038 seconds.
There were 21 divergences after tuning. Increase `target_accept` or reparameterize.
Only one chain was sampled, this makes it impossible to run some convergence checks


In [26]:
results = pd.DataFrame({'var': np.arange(561), 
                       'beta':np.apply_along_axis(np.mean, 0, trace['beta']),
                        'alpha':np.apply_along_axis(np.mean, 0, trace['alpha'])
                       })

In [27]:
results.head(10).T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
var,0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0
beta,0.433055,0.661716,0.325327,0.494487,-1.051093,0.694642,0.34118,-0.997011,0.538365,0.584682
alpha,-0.077208,-0.077208,-0.077208,-0.077208,-0.077208,-0.077208,-0.077208,-0.077208,-0.077208,-0.077208


In [28]:
estimate = trace['beta']
preds = np.apply_along_axis(np.mean, 1, expit(trace['alpha'] + np.dot(X, np.transpose(estimate) )) )

In [29]:
_X_test = X_test[[col for col in X_test.columns if col != "label"]].values
y_pred = np.apply_along_axis(np.mean, 1, expit(trace['alpha'] + np.dot(_X_test, np.transpose(estimate) )) )
print(f"Models AUC score: {roc_auc_score(y_test, y_pred)}")

Models AUC score: 0.9865284937404407


# Slap-and-spike prior

In [38]:
with pm.Model() as model:
    # priors inclusion probability
    xi = pm.Bernoulli("xi", .75, shape=X.shape[1])
    # Alpha is the interception
    alpha = pm.Normal("alpha", mu=0, sd=3)
    # The prior for the features varibles which are included
    beta = pm.Normal("beta", mu=0, sd=1, shape=X.shape[1])
    # Deterministic function
    p = pm.math.dot(X,xi * beta) 
    # Likelihood
    y_obs = pm.Bernoulli("y_obs", pm.invlogit(p + alpha),  observed=y)
 

In [39]:
with model:
    trace = pm.sample(2000, random_seed = 4816, cores = 1, progressbar = True, chains = 1)
    

  trace = pm.sample(2000, random_seed = 4816, cores = 1, progressbar = True, chains = 1)
Sequential sampling (1 chains in 1 job)
CompoundStep
>BinaryGibbsMetropolis: [xi]
>NUTS: [beta, alpha]


Sampling 1 chain for 1_000 tune and 2_000 draw iterations (1_000 + 2_000 draws total) took 3493 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks


In [40]:
results = pd.DataFrame({'var': np.arange(561), 
                        'inclusion_probability':np.apply_along_axis(np.mean, 0, trace['xi']),
                       'beta':np.apply_along_axis(np.mean, 0, trace['beta']),
                       'beta_given_inclusion': np.apply_along_axis(np.sum, 0, trace['xi']*trace['beta'])
                            /np.apply_along_axis(np.sum, 0, trace['xi'])
                       })

In [41]:
results.sort_values('inclusion_probability', ascending = False).head(10)


Unnamed: 0,var,inclusion_probability,beta,beta_given_inclusion
445,445,1.0,3.950892,3.950892
158,158,1.0,-1.005632,-1.005632
197,197,1.0,-1.529335,-1.529335
451,451,1.0,-2.235666,-2.235666
186,186,1.0,-3.120934,-3.120934
50,50,1.0,-3.861003,-3.861003
182,182,1.0,5.43541,5.43541
142,142,1.0,1.764619,1.764619
41,41,0.9995,-3.832141,-3.834108
53,53,0.9995,-3.541421,-3.543288


## MAP estimate

In [42]:
from scipy.special import expit

In [43]:
map_estimate = results["inclusion_probability"] * results["beta"]

In [44]:
map_preds = expit(trace["alpha"].mean() + np.dot(_X_test, np.transpose(map_estimate)))

In [45]:
print(f"Models AUC score: {roc_auc_score(y_test, map_preds)}")

Models AUC score: 0.9786261541947545


## Bayesian

In [46]:
estimate = trace['beta'] * trace['xi'] 
preds = np.apply_along_axis(np.mean, 1, expit(trace['alpha'] + np.dot(_X_test, np.transpose(estimate) )) )


In [47]:
print(f"Models AUC score: {roc_auc_score(y_test, preds)}")

Models AUC score: 0.9816426386449895


In [48]:
preds[:10]

array([0.9993482 , 0.94805618, 0.86133167, 0.90214986, 0.99097552,
       0.97183736, 0.89049947, 0.8605957 , 0.83446152, 0.94632371])

In [49]:
map_preds[:10]

array([0.99943686, 0.96676848, 0.90967069, 0.94811593, 0.99064731,
       0.97518212, 0.90810685, 0.84526856, 0.79791429, 0.92390105])

In [50]:
(map_preds[map_preds > 0.5]).mean(), (map_preds[map_preds <= 0.5]).mean()

(0.941169250542443, 0.07475592802376103)

In [51]:
(preds[preds > 0.5]).mean(), (preds[preds <= 0.5]).mean()

(0.9338252093809313, 0.08906595699281039)