In [1]:
import os
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import LeaveOneOut, KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.metrics import classification_report

# Disable pandas' display truncation: None removes the column/row limit, so
# displayed frames are rendered in full.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

## New data

In [2]:
# Recode the two retained activity ids (4 and 5) as the binary targets 0 and 1.
label_dict = {4:0, 5:1}

In [3]:
# Training split: space-separated feature file plus a one-column label file,
# joined row-by-row into a single frame with a "label" column.
X_train = pd.read_csv("../data/HAPT/Train/X_train.txt", header=None, sep=" ")
y = pd.read_csv("../data/HAPT/Train/y_train.txt", header=None).rename(columns={0:"label"})
X_train["label"] = y.label
# Shuffle the rows once. The fixed seed keeps the row order -- and therefore
# every later positional slice or unshuffled K-fold split of this frame --
# identical across kernel restarts.
X_train = X_train.sample(frac=1., random_state=0).reset_index(drop=True)

# Test split: same layout as the training files; left in file order.
X_test = pd.read_csv("../data/HAPT/Test/X_test.txt", header=None, sep=" ")
y = pd.read_csv("../data/HAPT/Test/y_test.txt", header=None).rename(columns={0:"label"})
X_test["label"] = y.label

In [4]:
X_train.shape, X_test.shape

((7767, 562), (3162, 562))

In [5]:
np.unique(X_train.label, return_counts=True)

(array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]),
 array([1226, 1073,  987, 1293, 1423, 1413,   47,   23,   75,   60,   90,
          57]))

In [6]:
# Class balance of the two activities (ids 4 and 5) within the training frame.
train_idx = X_train["label"].isin([4, 5])
np.unique(X_train.loc[train_idx, "label"], return_counts=True)

(array([4, 5]), array([1293, 1423]))

In [7]:
# Sitting vs standing
train_idx = X_train.label.isin([4,5])
X_train = X_train[train_idx].reset_index(drop=True)
X_train["label"] = X_train["label"].map(label_dict)

test_idx = X_test.label.isin([4,5])
X_test = X_test[test_idx].reset_index(drop=True)
X_test["label"] = X_test["label"].map(label_dict)
y_test = X_test["label"]

In [8]:
X_train.shape, X_test.shape

((2716, 562), (1064, 562))

In [9]:
np.unique(X_train.label, return_counts=True)

(array([0, 1]), array([1293, 1423]))

In [10]:
X_train.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335,336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351,352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367,368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,406,407,408,409,410,411,412,413,414,415,416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,431,432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447,448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463,464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479,480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495,496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511,512,513,514,515,516,517,518,519,520,521,522,523,524
,525,526,527,528,529,530,531,532,533,534,535,536,537,538,539,540,541,542,543,544,545,546,547,548,549,550,551,552,553,554,555,556,557,558,559,560,label
0,0.038216,-0.005712,-0.021914,-0.998348,-0.995431,-0.99139,-0.998184,-0.995155,-0.992187,-0.802836,-0.750945,-0.718481,0.840957,0.714987,0.676839,-0.996569,-0.999992,-0.999961,-0.999906,-0.997285,-0.995487,-0.991556,-0.867607,-0.79664,-0.642749,0.217886,-0.085428,-0.0186,0.231355,0.230877,-0.180679,0.130887,-0.000333,0.296732,-0.093821,0.079359,-0.121715,0.440605,0.375685,0.309766,0.923927,-0.307083,-0.059027,-0.9975,-0.995701,-0.991339,-0.997468,-0.995527,-0.991098,0.853679,-0.330423,-0.064355,0.947448,-0.280199,-0.058973,-0.097368,0.797939,-0.832331,-0.991688,-0.997583,-0.995815,-0.991202,-1.0,-1.0,-1.0,-0.570067,0.578585,-0.58714,0.595712,-0.607722,0.61387,-0.62042,0.626455,-0.247365,0.249265,-0.251013,0.251926,0.953591,0.998713,0.949214,0.07364,0.008935,-0.032472,-0.998118,-0.996145,-0.99244,-0.99862,-0.995296,-0.991194,-0.996954,-0.998805,-0.991227,0.995762,0.995167,0.992181,-0.9975,-0.999985,-0.999955,-0.99986,-0.998042,-0.993187,-0.988137,-0.875096,-0.982792,-0.758016,0.192313,-0.060831,0.239542,0.054918,0.294881,0.041307,0.245196,0.428557,0.217915,0.149598,0.174465,0.408847,0.251277,0.470362,-0.165974,-0.028714,-0.020631,-0.077878,-0.995705,-0.993694,-0.995973,-0.995554,-0.993493,-0.995635,-0.910959,-0.957273,-0.771712,0.849056,0.893793,0.827817,-0.994533,-0.999984,-0.999964,-0.999977,-0.99505,-0.993712,-0.994759,-0.633226,-0.464965,-0.625352,0.318084,-0.354852,0.321817,-0.121066,-0.01919,-0.021307,0.032425,0.154276,0.374829,-0.238999,0.140813,0.093786,-0.537437,-0.036864,-0.371105,-0.098837,-0.038925,-0.057068,-0.997669,-0.998675,-0.996074,-0.99756,-0.998607,-0.996214,-0.997861,-0.997759,-0.99397,0.99834,0.999236,0.996323,-0.998641,-0.999988,-0.999996,-0.999969,-0.997422,-0.997978,-0.996169,-0.826652,-0.866808,-0.718448,0.342422,-0.09485,0.211233,0.11551,0.083996,0.031662,0.152799,-0.005266,0.395262,-0.07953,0.254611,0.578913,-0.270831,0.005868,0.011237,-0.996281,-0.995804,-0.996206,-0.994078,-0.993777,-0.996281,-0.999957,-0.99582,-0.824778,0.353952,-0.387
992,0.361865,-0.055249,-0.996281,-0.995804,-0.996206,-0.994078,-0.993777,-0.996281,-0.999957,-0.99582,-0.824778,0.353952,-0.387992,0.361865,-0.055249,-0.997247,-0.995969,-0.995828,-0.995013,-0.99028,-0.997247,-0.999953,-0.995606,-0.843317,0.381988,-0.107956,-0.146863,-0.246328,-0.99476,-0.996747,-0.996112,-0.996641,-0.994713,-0.99476,-0.999966,-0.995421,-0.457856,0.257608,-0.493635,0.558145,-0.403671,-0.998728,-0.998649,-0.99852,-0.998469,-0.995725,-0.998728,-0.999994,-0.998467,-0.859921,0.622319,-0.459664,-0.051894,-0.109823,-0.997955,-0.994837,-0.988908,-0.998217,-0.995165,-0.992096,-0.998238,-0.995508,-0.989508,-0.998256,-0.993523,-0.990913,-0.992731,-0.998193,-0.994305,-0.996711,-0.999991,-0.999957,-0.99986,-0.997634,-0.996303,-0.989072,-1.0,-0.94487,-0.868799,-0.466667,-1.0,-1.0,0.123152,0.087196,0.286698,-0.473702,-0.825412,0.082181,-0.183024,-0.331158,-0.585411,-0.999991,-0.999979,-0.999999,-0.999947,-0.999966,-0.999958,-0.999975,-0.999993,-0.999989,-0.999989,-0.999977,-0.999981,-0.999992,-0.999958,-0.999955,-0.999972,-0.999962,-0.999932,-0.999935,-0.999857,-0.999974,-0.999999,-0.999956,-0.999948,-0.999933,-0.999987,-0.999958,-0.99995,-0.999883,-0.999782,-0.999835,-0.999845,-0.999831,-0.999708,-0.999916,-0.999874,-0.999865,-0.999869,-0.999823,-0.999905,-0.999858,-0.999857,-0.99811,-0.996075,-0.990957,-0.99829,-0.996512,-0.991808,-0.997282,-0.996072,-0.991414,-0.998539,-0.995803,-0.991076,-0.999299,-0.995418,-0.98754,-0.997141,-0.999985,-0.999955,-0.99986,-0.996563,-0.995061,-0.989124,-1.0,-1.0,-1.0,-0.68,-0.08,-0.24,0.096974,0.135552,0.212905,-0.67277,-0.911375,-0.693728,-0.938506,-0.550245,-0.881841,-0.999994,-0.999981,-1.0,-0.99994,-0.999969,-0.999936,-0.999975,-1.0,-0.999986,-0.999985,-0.999972,-0.999977,-0.999992,-0.999946,-0.999984,-0.999974,-0.999943,-0.999928,-0.999944,-0.999801,-0.999973,-0.999984,-0.99998,-0.999933,-0.999909,-0.999978,-0.999962,-0.999945,-0.99991,-0.999751,-0.999841,-0.999833,-0.999792,-0.999712,-0.999919,-0.999954,-0.999788,-0.99987
,-0.999776,-0.999919,-0.999839,-0.999833,-0.99613,-0.996442,-0.994383,-0.995511,-0.992243,-0.996634,-0.996076,-0.995496,-0.995456,-0.995025,-0.992198,-0.996398,-0.999309,-0.996392,-0.996892,-0.9967,-0.999982,-0.999964,-0.999977,-0.997182,-0.997789,-0.995202,-0.891655,-0.878352,-0.918335,-1.0,-0.935484,-1.0,-0.125002,-0.324116,0.219925,-0.014235,-0.382541,0.159229,-0.274553,-0.377166,-0.669624,-0.999983,-0.999988,-0.999991,-0.999994,-0.999992,-0.999984,-0.999984,-1.0,-0.999982,-0.999991,-0.999989,-0.999991,-0.999982,-0.999994,-0.999951,-0.999998,-0.999998,-0.999998,-0.999992,-0.999993,-0.999992,-0.999996,-0.999958,-0.999997,-0.999993,-0.999993,-0.99996,-0.999996,-0.999984,-0.99998,-0.999972,-0.999975,-0.999945,-0.999943,-0.999982,-0.999981,-0.999981,-0.999967,-0.999951,-0.999982,-0.999979,-0.999969,-0.995209,-0.995788,-0.992539,-0.997851,-0.997434,-0.995209,-0.999961,-0.992649,-1.0,-1.0,0.296943,-0.676905,-0.904737,-0.995217,-0.99634,-0.995621,-0.997002,-0.993801,-0.995217,-0.999946,-0.995151,-1.0,-0.777778,0.717954,-0.720295,-0.902178,-0.997715,-0.996363,-0.996274,-0.996357,-0.999331,-0.997715,-0.999983,-0.996651,-0.956057,-1.0,0.058193,-0.42486,-0.698161,-0.998788,-0.998319,-0.998275,-0.998515,-0.99581,-0.998788,-0.999995,-0.997004,-1.0,-1.0,0.361387,-0.447499,-0.775477,-0.012459,0.265771,0.887486,0.249724,-0.704252,0.298044,0.068766,1
1,0.040137,0.001698,-0.017681,-0.997485,-0.991863,-0.993166,-0.997438,-0.991411,-0.993787,-0.802697,-0.744286,-0.711083,0.840612,0.716274,0.680904,-0.995353,-0.999987,-0.999911,-0.999932,-0.99676,-0.992074,-0.993783,-0.731574,-0.522422,-0.537528,0.200689,0.016174,-0.027804,-0.095348,0.076919,-0.126469,0.235319,0.017128,0.53386,-0.280178,0.366596,-0.154321,0.422118,0.177119,0.35369,0.919337,-0.326731,0.103024,-0.997432,-0.991783,-0.996377,-0.997545,-0.991995,-0.996228,0.849275,-0.347784,0.094138,0.942725,-0.300708,0.103821,0.034169,0.786183,-0.809129,-0.981306,-0.997854,-0.992786,-0.995942,-1.0,-1.0,-0.721574,-0.366969,0.371484,-0.376273,0.381328,-0.317469,0.315191,-0.313969,0.312577,-0.142127,0.152235,-0.162051,0.170913,0.998976,0.862894,0.861366,0.077218,0.016225,-0.031922,-0.995332,-0.988954,-0.98956,-0.995328,-0.987422,-0.987765,-0.994493,-0.994878,-0.989649,0.993304,0.990181,0.987468,-0.992598,-0.999957,-0.999833,-0.999792,-0.993904,-0.986012,-0.982227,-0.817601,-0.788157,-0.676359,0.261717,0.166257,0.429172,0.144562,0.104857,-0.125465,0.210572,0.346777,0.380495,-0.063297,0.450543,0.339566,0.214111,0.168387,0.166846,-0.022651,-0.024049,-0.08219,-0.980022,-0.989233,-0.991901,-0.980649,-0.989672,-0.992183,-0.896685,-0.95599,-0.770249,0.848083,0.889652,0.823527,-0.984206,-0.999768,-0.999913,-0.999924,-0.984307,-0.991437,-0.993685,-0.198919,-0.458572,-0.687148,0.29338,-0.375288,0.276593,0.027343,-0.093129,0.052088,0.125931,-0.050349,0.306903,-0.353229,0.363482,-0.161536,-0.858705,0.381526,-0.300108,-0.095637,-0.042625,-0.053258,-0.989762,-0.995126,-0.996367,-0.989854,-0.995093,-0.996435,-0.990603,-0.99583,-0.996416,0.989645,0.995754,0.994534,-0.994794,-0.999909,-0.999976,-0.999972,-0.989894,-0.995475,-0.997144,-0.541582,-0.624969,-0.734245,0.320182,-0.214788,0.255255,-0.363656,0.033585,0.071718,0.270522,0.017365,0.458598,-0.404425,0.5031,0.000779,-0.472745,0.075194,-0.098915,-0.995758,-0.994151,-0.99473,-0.993948,-0.995991,-0.995758,-0.999946,-0.994822,-0.750764,0.2
26699,-0.340608,0.793773,-0.717892,-0.995758,-0.994151,-0.99473,-0.993948,-0.995991,-0.995758,-0.999946,-0.994822,-0.750764,0.226699,-0.340608,0.793773,-0.717892,-0.992862,-0.993759,-0.993789,-0.993457,-0.986835,-0.992862,-0.999879,-0.995177,-0.835282,0.350576,-0.398719,0.026882,-0.047042,-0.984238,-0.978386,-0.98057,-0.983678,-0.995277,-0.984238,-0.99976,-0.990128,-0.064101,0.125765,-0.398745,0.317886,0.071312,-0.994724,-0.995221,-0.99525,-0.995828,-0.986665,-0.994724,-0.999965,-0.99567,-0.665183,0.552579,-0.191515,-0.271303,-0.172374,-0.996909,-0.987113,-0.988836,-0.997434,-0.993072,-0.994696,-0.996494,-0.987228,-0.990821,-0.998176,-0.996198,-0.996838,-0.999206,-0.996152,-0.992883,-0.993785,-0.999986,-0.99991,-0.999895,-0.993527,-0.989816,-0.990221,-1.0,-0.904748,-0.903572,-1.0,-1.0,-1.0,0.125189,0.169107,0.539787,-0.606215,-0.909707,-0.623456,-0.883852,-0.720049,-0.933266,-0.999987,-0.999991,-0.999944,-0.999895,-0.999938,-0.999933,-0.999983,-1.0,-0.999989,-0.999926,-0.99995,-0.999989,-0.999988,-0.999909,-0.999944,-0.999864,-0.999854,-0.999842,-0.999765,-0.999808,-0.999922,-0.999971,-0.999921,-0.999818,-0.999789,-0.999943,-0.999914,-0.999838,-0.999946,-0.99969,-0.999743,-0.99983,-0.999828,-0.999604,-0.999519,-0.999699,-0.999913,-0.999805,-0.999794,-0.999567,-0.999896,-0.999838,-0.995907,-0.989145,-0.987214,-0.995027,-0.989478,-0.989911,-0.993898,-0.990116,-0.989992,-0.996525,-0.989998,-0.990755,-0.993511,-0.992637,-0.986166,-0.992372,-0.999957,-0.999833,-0.999792,-0.994785,-0.992093,-0.98849,-1.0,-1.0,-1.0,-0.08,-0.32,0.24,0.194946,-0.063237,0.341743,-0.363644,-0.817973,-0.521019,-0.869864,-0.501655,-0.884211,-0.999994,-0.999993,-0.999933,-0.999886,-0.999936,-0.999918,-0.999969,-0.999994,-0.999994,-0.999908,-0.999943,-0.999971,-0.999972,-0.999887,-0.999851,-0.999894,-0.999819,-0.999881,-0.999796,-0.999783,-0.999918,-0.999989,-0.999876,-0.99982,-0.99978,-0.999931,-0.999832,-0.999866,-0.999903,-0.999652,-0.999757,-0.99981,-0.999832,-0.999504,-0.999466,-0.999563,-0.9
99701,-0.999816,-0.999745,-0.999444,-0.999744,-0.999807,-0.981267,-0.991482,-0.991925,-0.97981,-0.987978,-0.992209,-0.981103,-0.990281,-0.990597,-0.984256,-0.988176,-0.991872,-0.998513,-0.997438,-0.99851,-0.988382,-0.999762,-0.999914,-0.999937,-0.986949,-0.993923,-0.993641,-0.714514,-0.7311,-0.79248,-1.0,-1.0,-0.935484,-0.346199,-0.28237,-0.216608,-0.101344,-0.54282,-0.072989,-0.457258,-0.101374,-0.433043,-0.999753,-0.99997,-0.999922,-0.999959,-0.999873,-0.999936,-0.9999,-0.999988,-0.999763,-0.999922,-0.999887,-0.999939,-0.999763,-0.999939,-0.999897,-0.99999,-0.999983,-0.999991,-0.999994,-0.999965,-0.999981,-0.999996,-0.999907,-0.999982,-0.999988,-0.999985,-0.999906,-0.99999,-0.999941,-0.999993,-0.999958,-0.999986,-0.999982,-0.999954,-0.999919,-0.999999,-0.999941,-0.999959,-0.999981,-0.999954,-0.999938,-0.999986,-0.99298,-0.994421,-0.990717,-0.996789,-0.987568,-0.99298,-0.999941,-0.990704,-1.0,-0.631579,0.344277,-0.600724,-0.861596,-0.992524,-0.994681,-0.993684,-0.995084,-0.991275,-0.992524,-0.999908,-0.994011,-1.0,-0.968254,0.629196,-0.643364,-0.865474,-0.986539,-0.9769,-0.98165,-0.979455,-0.994945,-0.986539,-0.999691,-0.990357,-0.751553,-1.0,-0.339433,-0.068383,-0.47479,-0.994861,-0.995945,-0.99496,-0.996437,-0.995678,-0.994861,-0.999974,-0.992127,-1.0,-1.0,0.490871,-0.658089,-0.872056,0.046541,-0.064867,-0.61908,-0.762678,-0.678091,0.310887,-0.042296,1
2,0.037051,-0.009297,-0.02251,-0.990393,-0.942027,-0.983337,-0.990804,-0.941414,-0.985294,-0.794692,-0.708248,-0.71144,0.832721,0.684123,0.660024,-0.975509,-0.999915,-0.998568,-0.999779,-0.990658,-0.947044,-0.987508,-0.661743,-0.372098,-0.571696,-0.155275,0.220172,-0.350097,0.316497,-0.077798,-0.090272,0.235847,0.023607,0.244782,-0.140568,0.074908,0.264005,0.63129,-0.303732,0.015101,0.926658,-0.326912,0.00995,-0.995397,-0.981858,-0.987664,-0.995549,-0.981922,-0.987426,0.857031,-0.346202,0.005146,0.948406,-0.304859,0.008758,-0.200278,0.804953,-0.808888,-0.999951,-0.995953,-0.982735,-0.986883,-0.742841,-1.0,-0.547853,-0.61582,0.621777,-0.627947,0.63432,-0.475026,0.478906,-0.48368,0.488291,-0.397544,0.413091,-0.428734,0.443675,0.996211,-0.794849,-0.820257,0.074625,-0.001004,-0.038955,-0.991834,-0.959933,-0.987329,-0.991372,-0.957488,-0.986012,-0.991894,-0.975872,-0.986491,0.992318,0.972135,0.97999,-0.98189,-0.999911,-0.998818,-0.999734,-0.989828,-0.958142,-0.983408,-0.714697,-0.458345,-0.655853,-0.116762,0.258462,-0.0635,0.163166,-0.014476,-0.218053,0.153183,0.209585,0.008017,-0.007349,-0.187093,0.614168,0.143976,0.142461,0.628584,-0.029931,-0.040893,-0.050363,-0.957664,-0.976355,-0.94222,-0.960962,-0.978897,-0.945563,-0.884004,-0.942973,-0.699562,0.828685,0.874456,0.791918,-0.949154,-0.999109,-0.999517,-0.997607,-0.971574,-0.982401,-0.954089,-0.230288,-0.591603,0.180165,-0.003819,-0.072642,-0.103036,0.352673,-0.011571,-0.038291,0.073471,0.421797,-0.355755,0.131203,0.235007,-0.213159,-0.125676,0.630501,-0.377391,-0.099378,-0.044196,-0.062085,-0.983577,-0.981642,-0.963452,-0.982747,-0.981762,-0.969521,-0.989739,-0.983261,-0.943246,0.983134,0.982328,0.975699,-0.980035,-0.999803,-0.999787,-0.999116,-0.981607,-0.980229,-0.979828,-0.391429,-0.20405,-0.334945,-0.042215,0.109139,-0.285176,0.031786,-0.054964,0.010865,-0.089355,0.481132,-0.225105,-0.111132,-0.005295,0.562204,0.181746,-0.14338,-0.559247,-0.972055,-0.976133,-0.978971,-0.96875,-0.994119,-0.972055,-0.999358,-0.9829
49,-0.314766,-0.201981,0.211208,-0.073775,0.020528,-0.972055,-0.976133,-0.978971,-0.96875,-0.994119,-0.972055,-0.999358,-0.982949,-0.314766,-0.201981,0.211208,-0.073775,0.020528,-0.982679,-0.979412,-0.981208,-0.982051,-0.990108,-0.982679,-0.999577,-0.985704,-0.661151,0.345402,-0.270661,0.071125,-0.250523,-0.950862,-0.944372,-0.940064,-0.960422,-0.98558,-0.950862,-0.998359,-0.944751,0.369612,0.006423,-0.242312,0.328938,-0.167584,-0.979974,-0.97918,-0.9824,-0.976452,-0.980166,-0.979974,-0.999702,-0.986117,-0.222754,0.088356,0.045112,0.249888,-0.485553,-0.99109,-0.938888,-0.980245,-0.989647,-0.945117,-0.984646,-0.990206,-0.93413,-0.975366,-0.989115,-0.946159,-0.989137,-0.997612,-0.975374,-0.996937,-0.973122,-0.999911,-0.997921,-0.999662,-0.989628,-0.955818,-0.98397,-0.907316,-0.655977,-0.73874,-1.0,-1.0,-1.0,-0.176002,-0.173289,0.035026,0.021647,-0.326857,0.286612,0.057165,-0.541663,-0.829363,-0.999904,-0.999951,-0.999909,-0.999831,-0.999892,-0.999965,-0.999979,-0.999999,-0.99991,-0.999877,-0.999933,-0.999986,-0.999913,-0.999855,-0.998185,-0.998978,-0.999405,-0.999596,-0.998822,-0.998953,-0.998933,-0.99973,-0.997975,-0.999313,-0.998772,-0.999238,-0.997941,-0.99934,-0.999767,-0.999272,-0.99982,-0.999951,-0.999587,-0.999625,-0.999858,-0.999978,-0.999667,-0.999898,-0.999604,-0.999898,-0.999659,-0.99987,-0.991993,-0.95977,-0.985477,-0.992354,-0.963038,-0.987027,-0.990208,-0.962111,-0.986173,-0.992837,-0.971196,-0.989474,-0.997673,-0.993454,-0.989296,-0.980137,-0.999911,-0.998818,-0.999733,-0.988834,-0.966594,-0.982431,-1.0,-0.85688,-1.0,0.24,-0.76,-0.6,-0.053127,-0.30097,-0.007807,-0.447746,-0.803,-0.607526,-0.913548,-0.57039,-0.918759,-0.999965,-0.999959,-0.99991,-0.9998,-0.999876,-0.999953,-0.999964,-0.999999,-0.999959,-0.999852,-0.999917,-0.999966,-0.999935,-0.999806,-0.997793,-0.999219,-0.999133,-0.99961,-0.999006,-0.998931,-0.99893,-0.999986,-0.998779,-0.999192,-0.998774,-0.999067,-0.998732,-0.999308,-0.999513,-0.999281,-0.999841,-0.999943,-0.99956,-0.999593,-0.999755
,-0.999843,-0.999246,-0.999923,-0.999563,-0.99975,-0.999599,-0.999816,-0.966837,-0.974993,-0.935695,-0.956084,-0.977384,-0.947262,-0.962689,-0.97601,-0.919987,-0.958363,-0.985369,-0.955458,-0.998515,-0.996134,-0.99054,-0.96197,-0.999023,-0.99966,-0.998124,-0.983505,-0.977583,-0.954717,-0.584144,-0.442095,-0.40959,-1.0,-0.483871,-1.0,-0.574011,-0.078301,-0.476109,0.208448,-0.196832,-0.610098,-0.906571,-0.283945,-0.622035,-0.998962,-0.999909,-0.999907,-0.999893,-0.999884,-0.999871,-0.999983,-0.999998,-0.99901,-0.999883,-0.999869,-0.99999,-0.999022,-0.999886,-0.999818,-0.99968,-0.999963,-0.999896,-0.999913,-0.999602,-0.999763,-0.999995,-0.999637,-0.999934,-0.999849,-0.999832,-0.999646,-0.999873,-0.9985,-0.998459,-0.999864,-0.999882,-0.999867,-0.999626,-0.999819,-0.999942,-0.998142,-0.999819,-0.999808,-0.999873,-0.998132,-0.999861,-0.970587,-0.978739,-0.96428,-0.985204,-0.986222,-0.970587,-0.999546,-0.971488,-0.728456,-0.947368,-0.007464,-0.505138,-0.775138,-0.977897,-0.980573,-0.978353,-0.982389,-0.987879,-0.977897,-0.999556,-0.981695,-0.940965,-1.0,0.264905,-0.323226,-0.684302,-0.963379,-0.942351,-0.950159,-0.944893,-0.985265,-0.963379,-0.998262,-0.969825,-0.485005,-1.0,-0.362198,0.018397,-0.328158,-0.980388,-0.978807,-0.976854,-0.980622,-0.985205,-0.980388,-0.999721,-0.973755,-0.691855,-1.0,-0.035044,-0.219877,-0.589726,0.005211,-0.095086,-0.050689,-0.004716,-0.692621,0.310886,0.021392,1
3,0.038887,-0.003339,-0.023784,-0.988175,-0.970379,-0.978091,-0.989349,-0.972932,-0.979569,-0.792247,-0.727862,-0.701384,0.827421,0.701815,0.667247,-0.98264,-0.99989,-0.999576,-0.999666,-0.990886,-0.979459,-0.977819,-0.588125,-0.451449,-0.52756,0.292545,-0.103881,0.126498,0.552032,-0.051083,0.098124,0.003004,0.481057,0.187473,-0.017717,0.01211,0.158893,0.422295,-0.022545,0.539373,0.955256,-0.208394,-0.066481,-0.999976,-0.992557,-0.984442,-0.999968,-0.992729,-0.985988,0.883774,-0.233742,-0.064617,0.978749,-0.184813,-0.066887,-0.286354,0.879214,-0.926125,-0.989651,-0.999963,-0.993792,-0.991926,-1.0,-1.0,-1.0,-0.077937,0.26184,-0.445088,0.627207,0.00055,0.074172,-0.152105,0.231507,-0.419036,0.425342,-0.431714,0.437361,0.537604,-0.384725,-0.903263,0.072418,0.008203,-0.021283,-0.983256,-0.959787,-0.977512,-0.982301,-0.959329,-0.975717,-0.988602,-0.974672,-0.986943,0.982465,0.968829,0.964568,-0.974788,-0.999747,-0.998811,-0.999418,-0.979036,-0.967388,-0.975653,-0.581342,-0.483507,-0.527571,-0.034485,-0.118734,0.13112,0.582812,-0.257393,0.074128,-0.076197,0.42922,0.026165,0.039878,0.095443,0.110919,0.354439,0.044896,0.547213,-0.026107,-0.013044,-0.074003,-0.996549,-0.980389,-0.97928,-0.996724,-0.982562,-0.978075,-0.910031,-0.942852,-0.755408,0.850326,0.884196,0.816176,-0.983715,-0.999987,-0.999735,-0.999702,-0.996548,-0.987015,-0.973191,-0.499257,-0.116951,-0.233057,0.475166,-0.383713,0.282825,0.249918,-0.171195,0.029884,0.108724,0.297526,0.222239,-0.300361,0.433355,0.004731,-0.156436,0.29973,-0.592153,-0.095949,-0.026996,-0.066638,-0.99454,-0.99087,-0.979938,-0.994263,-0.990727,-0.979748,-0.994157,-0.993011,-0.977362,0.995606,0.995685,0.987916,-0.989987,-0.999964,-0.999936,-0.999679,-0.993134,-0.99029,-0.980232,-0.63183,-0.397198,-0.358538,0.313649,-0.082167,0.106734,0.406373,-0.138469,-0.105701,0.072697,-0.040785,0.220982,-0.470787,0.399264,0.322541,-0.317452,0.583312,-0.457713,-0.982101,-0.983492,-0.984533,-0.978351,-0.989567,-0.982101,-0.999672,-0.985241,-0.477887,0.33
199,-0.254385,0.323602,-0.361522,-0.982101,-0.983492,-0.984533,-0.978351,-0.989567,-0.982101,-0.999672,-0.985241,-0.477887,0.33199,-0.254385,0.323602,-0.361522,-0.976135,-0.97434,-0.976282,-0.969679,-0.962986,-0.976135,-0.999358,-0.979551,-0.591898,0.274992,0.030394,-0.492374,0.04408,-0.983312,-0.981566,-0.98221,-0.982515,-0.994969,-0.983312,-0.999763,-0.983541,-0.016791,0.040565,-0.170783,0.075254,0.102933,-0.989941,-0.991025,-0.990508,-0.993049,-0.989375,-0.989941,-0.99991,-0.98861,-0.496432,0.417654,-0.254931,0.473163,-0.675234,-0.988146,-0.961442,-0.970994,-0.987745,-0.973557,-0.981212,-0.987533,-0.961159,-0.967348,-0.989246,-0.982035,-0.989167,-0.995003,-0.991031,-0.980493,-0.975399,-0.999878,-0.999355,-0.999499,-0.987942,-0.977354,-0.972226,-0.873892,-0.718723,-0.699473,-0.266667,-0.266667,-0.153846,0.15665,0.093041,0.268741,-0.185153,-0.602337,-0.361109,-0.711434,-0.588101,-0.86766,-0.999992,-0.999612,-0.999831,-0.999821,-0.999899,-0.999755,-0.999887,-0.999985,-0.999881,-0.999808,-0.999858,-0.99992,-0.99988,-0.999812,-0.999883,-0.997785,-0.999683,-0.999402,-0.999042,-0.999584,-0.999349,-0.999912,-0.999395,-0.999526,-0.999175,-0.999568,-0.999377,-0.999322,-0.999763,-0.998653,-0.99948,-0.999661,-0.99937,-0.999584,-0.999248,-0.999586,-0.999564,-0.999576,-0.999417,-0.999338,-0.999519,-0.99961,-0.985591,-0.965891,-0.976173,-0.982154,-0.955621,-0.976534,-0.979269,-0.962663,-0.976732,-0.981357,-0.953415,-0.974447,-0.999291,-0.965193,-0.982645,-0.97559,-0.999746,-0.998811,-0.999418,-0.979152,-0.975778,-0.979832,-0.943699,-0.85688,-0.900162,-0.56,-0.56,-0.56,-0.092337,-0.230309,0.019217,-0.052696,-0.477603,-0.012879,-0.50648,-0.334952,-0.745269,-0.999991,-0.999618,-0.999851,-0.999816,-0.999862,-0.999665,-0.999828,-0.999999,-0.999752,-0.999811,-0.999789,-0.999833,-0.999748,-0.999731,-0.999847,-0.998343,-0.999582,-0.999375,-0.999189,-0.999502,-0.999316,-0.999778,-0.998469,-0.999409,-0.999167,-0.999378,-0.998704,-0.999312,-0.999502,-0.998831,-0.999577,-0.999634,-0.99934,
-0.999588,-0.998861,-0.999894,-0.998856,-0.999637,-0.999394,-0.998908,-0.999245,-0.999562,-0.994951,-0.981651,-0.971455,-0.997063,-0.979709,-0.982684,-0.996115,-0.982501,-0.973224,-0.997554,-0.982659,-0.988911,-0.99793,-0.988269,-0.98227,-0.983955,-0.999987,-0.999757,-0.99971,-0.992575,-0.991987,-0.9792,-0.920371,-0.601838,-0.624752,-0.466667,-0.806452,-0.419355,0.279175,-0.054225,0.049215,-0.564036,-0.827539,-0.258191,-0.668823,-0.482556,-0.815692,-0.999997,-0.99994,-0.999978,-0.999992,-0.999925,-0.999986,-0.999953,-0.999997,-0.999989,-0.999981,-0.999944,-0.999973,-0.999988,-0.999979,-0.999751,-0.999905,-0.999987,-0.999965,-0.999945,-0.999844,-0.999744,-0.999708,-0.999727,-0.999977,-0.999924,-0.999694,-0.999741,-0.99995,-0.999869,-0.999402,-0.999877,-0.99992,-0.999852,-0.999575,-0.999433,-0.999735,-0.99973,-0.99985,-0.999783,-0.999565,-0.999719,-0.999879,-0.976853,-0.986596,-0.97895,-0.988761,-0.996192,-0.976853,-0.999746,-0.979889,-0.873482,-1.0,0.372363,-0.393132,-0.634572,-0.973359,-0.974644,-0.969976,-0.978094,-0.980775,-0.973359,-0.999383,-0.972283,-0.940965,-1.0,0.217847,-0.386842,-0.723848,-0.982361,-0.98398,-0.982259,-0.98911,-0.991969,-0.982361,-0.999766,-0.982008,-0.664203,-0.897436,0.090274,-0.561094,-0.84388,-0.991115,-0.991253,-0.991586,-0.989867,-0.99599,-0.991115,-0.999932,-0.992309,-0.923452,-1.0,0.421055,-0.001363,-0.276499,-0.015807,0.205988,-0.036548,-0.013271,-0.798049,0.229334,0.073964,0
4,0.038962,-0.005479,-0.019209,-0.995611,-0.99272,-0.988614,-0.995954,-0.992431,-0.990926,-0.798141,-0.750187,-0.697104,0.835059,0.711894,0.671851,-0.995027,-0.999974,-0.99994,-0.999877,-0.995885,-0.992593,-0.993535,-0.728058,-0.764401,-0.557118,0.126991,0.100894,-0.076227,0.1887,0.257897,-0.171075,0.237579,0.029195,0.369416,-0.123947,0.159765,-0.338084,-0.06757,-0.251453,-0.155579,0.95881,-0.071728,0.146778,-0.997152,-0.994818,-0.991758,-0.997166,-0.994738,-0.992288,0.888702,-0.100614,0.138637,0.981423,-0.048625,0.144319,-0.468517,0.888548,-0.993417,-0.960692,-0.997375,-0.994852,-0.993347,-1.0,-1.0,-0.613047,-0.17812,0.197272,-0.217034,0.23737,-0.059643,0.065543,-0.07293,0.080263,-0.333831,0.340696,-0.347633,0.353896,0.758838,-0.951832,-0.523641,0.07959,0.011081,-0.035821,-0.992176,-0.986022,-0.988206,-0.992584,-0.983773,-0.98947,-0.983847,-0.990228,-0.990242,0.991941,0.99133,0.974616,-0.99107,-0.999916,-0.999768,-0.999757,-0.990676,-0.982561,-0.992424,-0.75742,-0.721902,-0.752632,0.065407,0.253877,0.148656,0.465593,0.256238,-0.075766,0.36011,0.260713,0.357855,0.059246,0.551794,-0.042243,-0.130655,-0.125607,0.044861,-0.018895,-0.038716,-0.078729,-0.996934,-0.98691,-0.993709,-0.997182,-0.986189,-0.994191,-0.907671,-0.958519,-0.767542,0.852941,0.885376,0.827414,-0.985539,-0.999941,-0.999766,-0.999957,-0.997377,-0.984933,-0.995901,-0.339614,-0.810206,-0.621879,0.140747,-0.184168,0.186706,-0.162193,-0.167537,0.065578,0.122105,-0.013768,0.241817,-0.103331,0.039301,-0.024827,0.300743,-0.096284,0.188226,-0.098287,-0.036256,-0.053344,-0.995688,-0.99553,-0.994084,-0.995851,-0.995986,-0.99397,-0.996675,-0.996262,-0.993356,0.994287,0.994221,0.993418,-0.996349,-0.999974,-0.999979,-0.999947,-0.996522,-0.996062,-0.994529,-0.764986,-0.643546,-0.634232,0.136314,0.155453,-0.055506,0.563367,-0.0232,-0.006255,0.267196,-0.204766,0.396984,-0.088554,0.44267,-0.036227,0.239703,0.011393,-0.069689,-0.994806,-0.992262,-0.993459,-0.983474,-0.995547,-0.994806,-0.999929,-0.994064,-0.710246,0.3
22153,-0.226775,0.283449,-0.327384,-0.994806,-0.992262,-0.993459,-0.983474,-0.995547,-0.994806,-0.999929,-0.994064,-0.710246,0.322153,-0.226775,0.283449,-0.327384,-0.991515,-0.988639,-0.989951,-0.984811,-0.989768,-0.991515,-0.999827,-0.991166,-0.76931,0.349348,-0.362337,-0.177077,0.126515,-0.985023,-0.987536,-0.986573,-0.989835,-0.99324,-0.985023,-0.999821,-0.986472,-0.111356,0.044619,-0.203029,0.220512,-0.149758,-0.996206,-0.996312,-0.996622,-0.995245,-0.996236,-0.996206,-0.999977,-0.996725,-0.725538,0.542703,-0.274186,-0.21473,-0.100802,-0.993859,-0.985586,-0.986903,-0.996224,-0.994927,-0.989129,-0.994399,-0.988022,-0.987564,-0.99735,-0.997319,-0.988204,-0.999032,-0.998278,-0.987741,-0.991104,-0.999971,-0.999922,-0.999799,-0.990275,-0.98224,-0.987149,-1.0,-1.0,-0.903572,-1.0,-0.933333,-1.0,0.262045,0.222492,0.396869,-0.648259,-0.929518,-0.865906,-0.968984,-0.163063,-0.458454,-0.999987,-0.999948,-0.999918,-0.999873,-0.999831,-0.99991,-0.999964,-0.999975,-0.999976,-0.999897,-0.999874,-0.999968,-0.999974,-0.999856,-0.999962,-0.999872,-0.999818,-0.999696,-0.999859,-0.999626,-0.999881,-0.999998,-0.99994,-0.999741,-0.999788,-0.99993,-0.99993,-0.99973,-0.999832,-0.999851,-0.99976,-0.999642,-0.999849,-0.999663,-0.999549,-0.999929,-0.999825,-0.999747,-0.999827,-0.999662,-0.99981,-0.999713,-0.992434,-0.98561,-0.986306,-0.992542,-0.987589,-0.987979,-0.989393,-0.987638,-0.988925,-0.995725,-0.989678,-0.986963,-0.996984,-0.995588,-0.978744,-0.989229,-0.999916,-0.999768,-0.999757,-0.98549,-0.988121,-0.984809,-1.0,-1.0,-1.0,-0.04,0.2,0.04,0.096236,-0.005093,0.302471,-0.632709,-0.941766,-0.672327,-0.933566,-0.445589,-0.805809,-0.999996,-0.999938,-0.999915,-0.999866,-0.999832,-0.999896,-0.999966,-0.999997,-0.99996,-0.999885,-0.999864,-0.999968,-0.999938,-0.999819,-0.999911,-0.999884,-0.999775,-0.999667,-0.999864,-0.999526,-0.999856,-0.999988,-0.99988,-0.999687,-0.999726,-0.999877,-0.999816,-0.999711,-0.999756,-0.999828,-0.999736,-0.999633,-0.99983,-0.999649,-0.999334,-0.999869,-0.9
99799,-0.999718,-0.999787,-0.999353,-0.999772,-0.999716,-0.995781,-0.990422,-0.991929,-0.997276,-0.985046,-0.994495,-0.996043,-0.988802,-0.992634,-0.998197,-0.985608,-0.995383,-0.99927,-0.993877,-0.991972,-0.99322,-0.999989,-0.99988,-0.999957,-0.995121,-0.995078,-0.994169,-0.953913,-0.721074,-0.838993,-1.0,-1.0,-0.935484,0.08069,-0.325308,0.157125,-0.687051,-0.916252,-0.009593,-0.421749,-0.362886,-0.716845,-0.999994,-0.999966,-0.999987,-0.999985,-0.999978,-0.999971,-0.999997,-1.0,-0.99999,-0.999984,-0.999974,-0.999998,-0.99999,-0.999983,-0.999848,-0.999991,-0.999985,-0.999992,-0.999995,-0.999973,-0.999966,-0.999979,-0.999865,-0.999983,-0.999991,-0.999967,-0.999868,-0.999991,-0.999965,-0.999988,-0.999958,-0.999944,-0.999956,-0.999959,-0.999873,-0.999977,-0.999963,-0.999939,-0.999962,-0.999918,-0.99996,-0.999951,-0.989569,-0.993359,-0.989394,-0.995096,-0.99404,-0.989569,-0.999915,-0.98899,-0.946182,-0.947368,0.398141,-0.547444,-0.792655,-0.988337,-0.987825,-0.986423,-0.990571,-0.983401,-0.988337,-0.999806,-0.989891,-1.0,-1.0,0.270172,-0.431315,-0.776214,-0.992125,-0.986508,-0.989031,-0.984831,-0.995903,-0.992125,-0.999879,-0.990401,-0.828966,-1.0,-0.281105,0.220479,-0.022977,-0.996046,-0.996817,-0.99663,-0.996654,-0.992694,-0.996046,-0.999982,-0.995332,-1.0,-1.0,0.414314,-0.428002,-0.706576,0.006345,-0.76231,-0.443577,0.02178,-0.854093,0.136051,-0.073712,0


# Logreg

In [11]:
def get_predictions(x):
    """Convert an iterable of scores into hard 0/1 labels.

    A score greater than or equal to 0.5 becomes 1; anything else becomes 0.
    Returns a plain Python list with one label per input score.
    """
    labels = []
    for score in x:
        if score >= 0.5:
            labels.append(1)
        else:
            labels.append(0)
    return labels

In [12]:
# Feature matrix = every column except the target; y stays a pandas Series
feature_cols = [col for col in X_train.columns if col != "label"]
X = X_train[feature_cols].values
y = X_train.label

In [13]:
X.shape, y.shape

((2716, 561), (2716,))

In [15]:
# 10-fold cross-validation: each training row receives an out-of-fold probability
kfold = KFold(n_splits=10)
preds = np.zeros(len(y))
for train_index, test_index in kfold.split(X):
    _X_train = X[train_index, :]
    _X_test = X[test_index, :]
    _y_train = y[train_index].values
    clf = LogisticRegression(random_state=0, C=1., max_iter=500)
    clf.fit(_X_train, _y_train)
    # Column 1 of predict_proba is the probability of the positive class
    preds[test_index] = clf.predict_proba(_X_test)[:, 1]

cv_auc = roc_auc_score(y, preds)
print(f"Models AUC score: {cv_auc}")
print(classification_report(y, get_predictions(preds)))

Models AUC score: 0.9930524870661473
              precision    recall  f1-score   support

           0       0.96      0.96      0.96      1293
           1       0.96      0.96      0.96      1423

    accuracy                           0.96      2716
   macro avg       0.96      0.96      0.96      2716
weighted avg       0.96      0.96      0.96      2716



In [16]:
# Logistic regression fitted once and later scored on the held-out test set.
# NOTE(review): titled "MLE", but C=1.0 presumably leaves scikit-learn's default L2 penalty active — confirm.
# NOTE(review): X[:561,:] / y[:561] keep only the first 561 rows (samples), whereas the
# cross-validation above uses every row of X — confirm this subset is deliberate.
log_reg = LogisticRegression(random_state=0, C=1.0, max_iter = 500).fit(X[:561,:], y[:561])

In [17]:
# Score the held-out test rows with the fitted model (probability of class 1)
feature_cols = [col for col in X_test.columns if col != "label"]
_X_test = X_test[feature_cols].values
y_pred = log_reg.predict_proba(_X_test)[:, 1]

In [18]:
# ROC AUC of the predicted class-1 probabilities against the test labels
print(f"Models AUC score: {roc_auc_score(y_test, y_pred)}")

Models AUC score: 0.974823684359599


# Gaussian logreg

## MAP l2 estimate C=0.3

In [19]:
# 10-fold cross-validation with a stronger L2 penalty (C=.3)
kfold = KFold(n_splits=10)
preds = np.zeros(len(y))
for train_index, test_index in kfold.split(X):
    _X_train = X[train_index, :]
    _X_test = X[test_index, :]
    _y_train = y[train_index].values
    clf = LogisticRegression(random_state=0, C=.3, max_iter=500)
    clf.fit(_X_train, _y_train)
    # Column 1 of predict_proba is the probability of the positive class
    preds[test_index] = clf.predict_proba(_X_test)[:, 1]

cv_auc = roc_auc_score(y, preds)
print(f"Models AUC score: {cv_auc}")
print(classification_report(y, get_predictions(preds)))

Models AUC score: 0.9914823263162529
              precision    recall  f1-score   support

           0       0.95      0.96      0.95      1293
           1       0.96      0.96      0.96      1423

    accuracy                           0.96      2716
   macro avg       0.96      0.96      0.96      2716
weighted avg       0.96      0.96      0.96      2716



In [20]:
# L2-penalised logistic regression (C=.3) fitted once for scoring on the test set.
# NOTE(review): X[:561,:] / y[:561] keep only the first 561 rows (samples), whereas the
# cross-validation above uses every row of X — confirm this subset is deliberate.
log_reg = LogisticRegression(random_state=0, C=.3, penalty="l2", max_iter = 500).fit(X[:561,:], y[:561])

In [21]:
# Score the held-out test rows and report the ROC AUC
feature_cols = [col for col in X_test.columns if col != "label"]
_X_test = X_test[feature_cols].values
y_pred = log_reg.predict_proba(_X_test)[:, 1]
test_auc = roc_auc_score(y_test, y_pred)
print(f"Models AUC score: {test_auc}")

Models AUC score: 0.9766293264600917


# Spike-and-slab prior

In [29]:
# Rebuild the design matrix and target from the full filtered training frame
feature_cols = [col for col in X_train.columns if col != "label"]
X = X_train[feature_cols].values
y = X_train.label

In [30]:
import pymc3 as pm
import theano as tt
from scipy.special import expit
from scipy.stats import norm, bernoulli

This is our spike-and-slab model:

$$a \sim \mathcal{N}(0, 3)$$
$$\gamma_i \sim Bernoulli(p=0.1)$$
$$\alpha_i|\sigma_\beta \sim \mathcal{N}(0, \sigma_\beta)$$
$$e \sim \mathcal{N}(0, \sigma^2_eI_n)$$
$$y \sim Bernoulli\left(\frac{1}{1+\exp\left(-(a + \sum_{i=1}^N \gamma_i \alpha_i x_i + e)\right)}\right)$$

The model parameters are $\theta = \{\gamma_i, \alpha_i\}_{i=1}^N$. The $\gamma_i$ and the $\alpha_i$ are each modelled as independent and identically distributed (IID).

In [31]:
# Prior hyperparameters for the spike-and-slab logistic regression
prob = 0.1  # prior inclusion probability of each feature
a_mu = 0  # prior mean of the intercept
a_var = 3  # NOTE: passed as sd= below, i.e. used as a standard deviation despite the name
gamma_var = 1  # NOTE: passed as sd= for alpha below, i.e. used as a standard deviation despite the name
with pm.Model() as model:
    # Binary inclusion indicator, one per feature column of X
    gamma_i = pm.Bernoulli("gamma_i", prob, shape=X.shape[1])
    # a is the intercept
    a = pm.Normal("a", mu=a_mu, sd=a_var)
    # Normal prior on the coefficient of every feature; a coefficient only
    # contributes to the predictor when its gamma_i is 1
    alpha = pm.Normal("alpha", mu=0, sd=gamma_var, shape=X.shape[1])
    # Linear predictor without the intercept: X . (gamma_i * alpha)
    p = pm.math.dot(X,gamma_i * alpha) 
    # Bernoulli likelihood with the inverse-logit of (predictor + intercept) as success probability
    y_obs = pm.Bernoulli("y_obs", pm.invlogit(p + a),  observed=y)
 

In [32]:
# Draw 2000 posterior samples in a single chain on one core with a fixed seed.
# The recorded run took about 2142 seconds, reported 24 divergences after tuning, and
# warned that a single chain makes some convergence checks impossible.
with model:
    trace = pm.sample(2000, random_seed = 4816, cores = 1, progressbar = True, chains = 1)
    

  trace = pm.sample(2000, random_seed = 4816, cores = 1, progressbar = True, chains = 1)
Sequential sampling (1 chains in 1 job)
CompoundStep
>BinaryGibbsMetropolis: [gamma_i]
>NUTS: [alpha, a]


Sampling 1 chain for 1_000 tune and 2_000 draw iterations (1_000 + 2_000 draws total) took 2142 seconds.
There were 24 divergences after tuning. Increase `target_accept` or reparameterize.
Only one chain was sampled, this makes it impossible to run some convergence checks


# MAP estimate

The log prior of the spike-and-slab model
$$p(\theta) = p(\{\gamma_i, \alpha_i\}_{i=1}^N) = \prod_{i=1}^N p(\gamma_i) p(\alpha_i)$$

$$\log p(\theta) = \sum_{i=1}^N \left[ \log Bernoulli(\gamma_i | p=0.1) + \log \mathcal{N}(\alpha_i | \mu=0, \sigma_\beta=3) \right]$$

In [33]:
def spike_slab_log_prior(gamma: np.array, alpha: np.array, p, gamma_mu=0, sigma_beta=3):
    """Log prior of one (gamma, alpha) configuration, summed over all features.

    Parameters
    ----------
    gamma : inclusion indicators, scored under Bernoulli(p).
    alpha : coefficients, scored under Normal(loc=gamma_mu, scale=sigma_beta).
    p : prior inclusion probability.
    gamma_mu : mean of the Normal prior on alpha.
    sigma_beta : standard deviation of the Normal prior on alpha.

    Returns
    -------
    The sum of the element-wise Bernoulli log-pmf and Normal log-pdf terms.
    """
    inclusion_logp = bernoulli.logpmf(gamma, p=p)
    coefficient_logp = norm.logpdf(alpha, loc=gamma_mu, scale=sigma_beta)
    per_feature = inclusion_logp + coefficient_logp
    return per_feature.sum()

The negative log likelihood function, i.e. cross entropy

$$E(\mathbf{w}) = -\log p(\mathbf{t}|\mathbf{x},\mathbf{w})$$

$$= - \sum_{n=1}^N \left( t_n \ln y(\mathbf{x}_n) + (1-t_n) \ln (1-y(\mathbf{x}_n)) \right)$$
  

In [34]:
def log_likelihood(a, gamma, alpha, X, T):
    """Bernoulli log-likelihood of the labels T under the spike-and-slab logistic model.

    Parameters
    ----------
    a : intercept.
    gamma : inclusion indicators, multiplied element-wise with alpha.
    alpha : coefficients.
    X : design matrix; the linear predictor is a + X . (gamma * alpha).
    T : 0/1 targets, one per row of X.

    Returns
    -------
    sum_n [ T_n * log(y_n) + (1 - T_n) * log(1 - y_n) ] with y_n = sigmoid(predictor_n).

    The log-probabilities are computed directly from the linear predictor rather than
    as log(expit(z)): once the sigmoid saturates to exactly 0 or 1, log(0) = -inf and
    0 * -inf = nan, which would make the score of an otherwise valid draw unusable.
    """
    logits = a + np.dot(X, np.transpose(gamma * alpha))
    # log(sigmoid(z)) = -log(1 + exp(-z)) and log(1 - sigmoid(z)) = -log(1 + exp(z))
    log_p_one = -np.logaddexp(0, -logits)
    log_p_zero = -np.logaddexp(0, logits)
    return (T * log_p_one + ((1 - T) * log_p_zero)).sum()

In [35]:
# Hyperparameters of the prior used to score posterior draws in the MAP search
prob = 0.1  # prior inclusion probability (same value as in the model cell)
gamma_mu = 0  # forwarded as the mean of the Normal prior on alpha
# NOTE(review): the model cell samples alpha with sd=gamma_var=1, while the prior is
# scored here with scale 3 — confirm which value is intended.
sigma_beta = 3

def find_spike_slab_MAP(trace, X, y, prob, gamma_mu, sigma_beta):
    """Return the posterior draw with the smallest negative (log-likelihood + log-prior).

    Each draw in `trace` is scored with `log_likelihood` on (X, y) and with
    `spike_slab_log_prior` using `prob`, `gamma_mu` and `sigma_beta`. When several
    draws tie, the last one wins. The intercept "a" enters only through the
    likelihood; no prior term is added for it.
    """
    best_idx = -1
    best_loss = np.inf
    for idx in range(len(trace)):
        draw = trace[idx]
        gamma, alpha = draw["gamma_i"], draw["alpha"]
        log_prior = spike_slab_log_prior(gamma, alpha, p=prob, gamma_mu=gamma_mu, sigma_beta=sigma_beta)
        log_lik = log_likelihood(draw["a"], gamma, alpha, X, y)
        loss = -(log_lik + log_prior)
        # "<=" keeps the latest draw among equal scores
        if loss <= best_loss:
            best_idx, best_loss = idx, loss
    return trace[best_idx]
    

In [36]:
# Select the single posterior draw with the best likelihood-plus-prior score on the training data
map_trace = find_spike_slab_MAP(trace, X, y, prob, gamma_mu, sigma_beta)

In [37]:
# Effective weights of the selected draw: alpha is zeroed wherever gamma_i is 0
map_estimate = map_trace["gamma_i"] * map_trace["alpha"]

feature_cols = [col for col in X_test.columns if col != "label"]
_X_test = X_test[feature_cols].values
map_logits = map_trace["a"] + np.dot(_X_test, np.transpose(map_estimate))
map_preds = expit(map_logits)

In [47]:
# ROC AUC of the MAP-draw probabilities against the test labels
print(f"Models AUC score: {roc_auc_score(y_test, map_preds)}")

Models AUC score: 0.9642659887837761


# Full Bayesian

In [39]:
# Per-feature posterior summaries computed over all draws (axis 0 = draws, axis 1 = features)
gamma_draws = trace['gamma_i']
alpha_draws = trace['alpha']

inclusion_counts = gamma_draws.sum(axis=0)
included_alpha_sum = (gamma_draws * alpha_draws).sum(axis=0)
# A feature that is never included has no conditional mean: 0/0 yields NaN, and the
# warning is silenced because that outcome is expected rather than an error.
with np.errstate(divide='ignore', invalid='ignore'):
    alpha_given_inclusion = included_alpha_sum / inclusion_counts

results = pd.DataFrame({'var': np.arange(gamma_draws.shape[1]),  # feature index, taken from the trace instead of a hard-coded 561
                        'inclusion_probability': gamma_draws.mean(axis=0),
                        'alpha': alpha_draws.mean(axis=0),
                        'alpha_given_inclusion': alpha_given_inclusion
                        })

In [40]:
results.sort_values('inclusion_probability', ascending = False).head(10)


Unnamed: 0,var,inclusion_probability,alpha,alpha_given_inclusion
41,41,1.0,-3.88396,-3.88396
451,451,1.0,-2.267181,-2.267181
182,182,1.0,5.423177,5.423177
142,142,1.0,1.571083,1.571083
50,50,1.0,-4.035479,-4.035479
445,445,1.0,4.38021,4.38021
186,186,1.0,-3.663025,-3.663025
187,187,0.997,-2.154221,-2.158809
158,158,0.997,-0.996644,-1.002019
53,53,0.9915,-3.449169,-3.479308


## Bayesian inference

In [41]:
# Posterior-predictive mean: average the predicted probability over every posterior draw
feature_cols = [col for col in X_test.columns if col != "label"]
_X_test = X_test[feature_cols].values
estimate = trace['alpha'] * trace['gamma_i']
# One column of probabilities per draw; trace['a'] broadcasts across the rows
draw_probs = expit(trace['a'] + np.dot(_X_test, np.transpose(estimate)))
preds = np.apply_along_axis(np.mean, 1, draw_probs)


In [42]:
# ROC AUC of the draw-averaged probabilities against the test labels
print(f"Models AUC score: {roc_auc_score(y_test, preds)}")

Models AUC score: 0.9731986064691553


In [43]:
preds[:10]

array([0.99728329, 0.95677686, 0.92627129, 0.96335403, 0.97437116,
       0.94798007, 0.89826115, 0.83813811, 0.84503151, 0.9424179 ])

In [44]:
map_preds[:10]

array([0.99377665, 0.9638406 , 0.86181986, 0.90449025, 0.98654351,
       0.9787043 , 0.94709745, 0.91942937, 0.89949213, 0.92160971])

In [45]:
# Mean MAP probability among test rows scored above 0.5, and among those at or below 0.5
(map_preds[map_preds > 0.5]).mean(), (map_preds[map_preds <= 0.5]).mean()

(0.9363551762776625, 0.09595969172622476)

In [46]:
# Same split for the draw-averaged probabilities: mean above 0.5, and mean at or below 0.5
(preds[preds > 0.5]).mean(), (preds[preds <= 0.5]).mean()

(0.9245190406644662, 0.09512387902157081)