In [1]:
import os
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import LeaveOneOut, KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.metrics import classification_report

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

## New data

In [2]:
label_dict = {4:0, 5:1}

In [3]:
X_train = pd.read_csv("../data/HAPT/Train/X_train.txt", header=None, sep=" ")
y = pd.read_csv("../data/HAPT/Train/y_train.txt", header=None).rename(columns={0:"label"})
X_train["label"] = y.label
X_train = X_train.sample(frac=1.).reset_index(drop=True)

X_test = pd.read_csv("../data/HAPT/Test/X_test.txt", header=None, sep=" ")
y = pd.read_csv("../data/HAPT/Test/y_test.txt", header=None).rename(columns={0:"label"})
X_test["label"] = y.label 

In [4]:
X_train.shape, X_test.shape

((7767, 562), (3162, 562))

In [5]:
np.unique(X_train.label, return_counts=True)

(array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]),
 array([1226, 1073,  987, 1293, 1423, 1413,   47,   23,   75,   60,   90,
          57]))

In [6]:
train_idx = X_train.label.isin([4,5])
np.unique(X_train[train_idx].label, return_counts=True)

(array([4, 5]), array([1293, 1423]))

In [7]:
# Sitting vs standing
train_idx = X_train.label.isin([4,5])
X_train = X_train[train_idx].reset_index(drop=True)
X_train["label"] = X_train["label"].map(label_dict)

test_idx = X_test.label.isin([4,5])
X_test = X_test[test_idx].reset_index(drop=True)
X_test["label"] = X_test["label"].map(label_dict)
y_test = X_test["label"]

In [8]:
X_train.shape, X_test.shape

((2716, 562), (1064, 562))

In [9]:
np.unique(X_train.label, return_counts=True)

(array([0, 1]), array([1293, 1423]))

In [10]:
X_train.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335,336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351,352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367,368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,406,407,408,409,410,411,412,413,414,415,416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,431,432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447,448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463,464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479,480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495,496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511,512,513,514,515,516,517,518,519,520,521,522,523,524,525,526,527,528,529,530,531,532,533,534,535,536,537,538,539,540,541,542,543,544,545,546,547,548,549,550,551,552,553,554,555,556,557,558,559,560,label
0,0.040243,-0.002684,-0.020638,-0.996505,-0.996423,-0.994309,-0.9969,-0.995687,-0.994611,-0.799694,-0.751905,-0.717516,0.837553,0.715361,0.678068,-0.997548,-0.99998,-0.999981,-0.999947,-0.996691,-0.995407,-0.994355,-0.705828,-0.715289,-0.640649,0.42827,-0.148116,0.278691,0.036127,0.422815,-0.283304,0.34806,0.042574,0.361628,-0.084547,0.170821,-0.038616,0.207362,-0.126003,-0.059313,0.864398,-0.344069,-0.243635,-0.99878,-0.999368,-0.999097,-0.998913,-0.999416,-0.999103,0.794177,-0.36775,-0.250326,0.889534,-0.316061,-0.241643,0.382838,0.648454,-0.787413,-0.877008,-0.999241,-0.999668,-0.999042,-1.0,-1.0,-1.0,-0.155429,0.199949,-0.24527,0.291296,0.178666,-0.130571,0.081504,-0.033415,-0.099502,0.140254,-0.180525,0.219565,-0.802173,-0.901742,0.651628,0.076309,0.014078,-0.032759,-0.992358,-0.991581,-0.991173,-0.992196,-0.989468,-0.990399,-0.994181,-0.995586,-0.994846,0.991855,0.993747,0.986682,-0.992835,-0.999919,-0.999884,-0.999831,-0.990833,-0.989467,-0.989358,-0.74495,-0.84773,-0.717519,0.331114,0.008559,0.43048,0.669208,0.395338,-0.183545,0.487499,0.256549,0.230083,0.092803,0.310735,0.364079,0.259515,0.099777,0.037689,-0.02696342,-0.019136,-0.079452,-0.994016,-0.995145,-0.996139,-0.994197,-0.995614,-0.996197,-0.90924,-0.956484,-0.771937,0.850448,0.892991,0.826309,-0.994638,-0.999973,-0.999974,-0.999976,-0.995429,-0.996488,-0.995577,-0.484168,-0.469555,-0.68932,0.231368,-0.225231,0.114994,0.142977,0.215015,-0.115741,0.164775,0.247057,0.518008,-0.411013,0.344292,0.062815,0.201294,-0.506591,-0.460108,-0.097991,-0.040614,-0.057046,-0.994474,-0.995201,-0.995705,-0.995326,-0.995594,-0.995629,-0.993477,-0.994591,-0.994698,0.993576,0.996055,0.995755,-0.996374,-0.999964,-0.999977,-0.999965,-0.996388,-0.995471,-0.995834,-0.702549,-0.663793,-0.688431,0.180963,0.028781,0.037117,0.093685,0.204951,0.082543,0.184627,0.485833,0.532718,-0.262837,0.374473,0.643151,0.365788,-0.329179,-0.261616,-0.99765,-0.997484,-0.997837,-0.996859,-0.99423,-0.99765,-0.999975,-0.997236,-0.882288,0.569237,-0.42283,0.272721,-0.216685,-0.99765,-0.997484,-0.997837,-0.996859,-0.99423,-0.99765,-0.999975,-0.997236,-0.882288,0.569237,-0.42283,0.272721,-0.216685,-0.993142,-0.994789,-0.994558,-0.99303,-0.989309,-0.993142,-0.999889,-0.994046,-0.868269,0.361886,-0.088881,-0.203014,-0.227797,-0.994885,-0.994867,-0.994592,-0.995866,-0.999198,-0.994885,-0.999963,-0.994614,-0.359337,0.592595,-0.675998,0.486471,-0.390939,-0.996217,-0.995749,-0.996148,-0.99476,-0.994261,-0.996217,-0.999976,-0.99668,-0.723035,0.248114,-0.253383,0.23496,-0.180132,-0.995326,-0.992794,-0.991751,-0.996789,-0.997219,-0.994976,-0.99573,-0.994002,-0.991147,-0.998231,-0.998712,-0.996036,-0.999816,-0.998348,-0.995214,-0.995704,-0.999978,-0.999968,-0.999915,-0.993744,-0.991444,-0.990013,-1.0,-1.0,-0.903572,-0.4,0.0,0.076923,0.273159,0.296917,0.405459,-0.654413,-0.938325,-0.745517,-0.931492,-0.700037,-0.902025,-0.999993,-0.999963,-0.999873,-0.99993,-0.999915,-0.99984,-0.999973,-1.0,-0.999986,-0.999876,-0.9999,-0.999982,-0.99998,-0.99991,-0.999983,-0.999927,-0.999931,-0.999848,-0.999825,-0.999776,-0.999895,-0.999998,-0.999973,-0.999894,-0.999821,-0.999938,-0.999972,-0.999853,-0.999955,-0.999718,-0.999781,-0.999856,-0.999836,-0.999751,-0.999901,-0.99999,-0.999926,-0.999839,-0.999838,-0.999932,-0.999913,-0.999869,-0.992834,-0.991854,-0.990159,-0.992444,-0.991802,-0.989967,-0.989671,-0.99207,-0.990682,-0.99483,-0.992543,-0.988413,-0.997123,-0.992703,-0.974338,-0.99297,-0.999919,-0.999884,-0.999831,-0.987428,-0.993777,-0.987873,-1.0,-1.0,-1.0,-0.12,-0.4,-0.44,0.154737,0.193864,0.21387,-0.516576,-0.881283,-0.516356,-0.89308,-0.356559,-0.764879,-0.999993,-0.999954,-0.999877,-0.999925,-0.999907,-0.999769,-0.99995,-0.999998,-0.999969,-0.99988,-0.999862,-0.999953,-0.99993,-0.999861,-0.999952,-0.999924,-0.999916,-0.99986,-0.999846,-0.999694,-0.999833,-0.99996,-0.999926,-0.999881,-0.999783,-0.999853,-0.999911,-0.999854,-0.999886,-0.999686,-0.999753,-0.999858,-0.99985,-0.99971,-0.999733,-0.999807,-0.999724,-0.999837,-0.999821,-0.999725,-0.999751,-0.999866,-0.992591,-0.994508,-0.994416,-0.994403,-0.995615,-0.996852,-0.993224,-0.995769,-0.99482,-0.995396,-0.996069,-0.998321,-0.99827,-0.998268,-0.998001,-0.994524,-0.99997,-0.999976,-0.999978,-0.992688,-0.993658,-0.994163,-0.802327,-0.894922,-0.862703,-1.0,-0.741935,-0.806452,-0.071182,0.185207,0.19374,-0.437143,-0.767755,-0.527453,-0.810588,-0.634384,-0.900503,-0.999975,-0.999969,-0.999968,-0.99998,-0.999956,-0.999972,-0.999978,-0.999999,-0.999972,-0.999967,-0.999959,-0.999987,-0.999971,-0.999975,-0.999982,-0.999993,-0.999986,-0.999991,-0.999976,-0.999944,-0.999972,-0.999999,-0.999981,-0.999984,-0.99997,-0.999981,-0.999976,-0.999984,-0.999987,-0.999974,-0.999959,-0.99999,-0.99997,-0.999902,-0.999929,-0.999983,-0.999982,-0.999962,-0.999957,-0.999953,-0.999979,-0.999982,-0.996134,-0.998115,-0.997104,-0.998193,-0.99411,-0.996134,-0.999979,-0.996345,-1.0,-0.578947,0.680752,-0.679646,-0.852221,-0.993194,-0.996718,-0.99554,-0.997575,-0.975486,-0.993194,-0.999927,-0.995123,-1.0,0.111111,0.588981,-0.711548,-0.941455,-0.994602,-0.995861,-0.996089,-0.995266,-0.989273,-0.994602,-0.999969,-0.996428,-0.924075,-0.948718,0.463655,-0.294981,-0.57344,-0.995422,-0.996392,-0.995524,-0.997527,-0.996915,-0.995422,-0.999978,-0.995321,-1.0,-0.936508,0.180928,-0.708249,-0.930926,-0.527569,0.176781,-0.160059,-0.273077,-0.58792,0.324633,0.197343,1
1,0.027705,-0.01972,-0.039006,-0.963102,-0.950293,-0.930569,-0.96761,-0.951663,-0.942195,-0.758632,-0.72002,-0.647611,0.813254,0.681866,0.639756,-0.945777,-0.999089,-0.998525,-0.997503,-0.974564,-0.960328,-0.951223,-0.57408,-0.609239,-0.578016,-0.235298,-0.017428,0.281536,-0.286037,0.095854,-0.378997,0.409105,0.093442,-0.065393,0.003585,-0.005194,0.113306,0.649457,0.928481,0.543064,0.808253,-0.226803,-0.449244,-0.948806,-0.95296,-0.932947,-0.953637,-0.952613,-0.937934,0.754187,-0.240019,-0.436706,0.792605,-0.21702,-0.491979,0.504867,0.513596,-0.911337,-0.590182,-0.96078,-0.952976,-0.945533,-0.126955,-1.0,-1.0,-0.97227,0.972866,-0.973481,0.974113,-0.947579,0.95432,-0.961284,0.96795,-0.959833,0.959487,-0.958844,0.956839,0.758814,0.994733,0.696375,0.056167,-0.010686,-0.083722,-0.988205,-0.983795,-0.989149,-0.988453,-0.981712,-0.989037,-0.988179,-0.990492,-0.989789,0.98542,0.990264,0.985928,-0.988571,-0.999849,-0.999712,-0.999778,-0.987362,-0.983462,-0.990531,-0.72265,-0.708124,-0.783283,-0.066122,-0.219349,0.177716,0.399835,0.16872,-0.621082,0.248293,-0.049381,-0.152278,-0.027624,-0.194127,0.239922,0.049789,0.329403,0.220006,2.050635e-07,0.032881,-0.11707,-0.966028,-0.978476,-0.958899,-0.966216,-0.978627,-0.961524,-0.887371,-0.935808,-0.754478,0.838565,0.896479,0.790216,-0.929533,-0.998926,-0.998593,-0.99792,-0.9701,-0.981289,-0.973141,0.291946,0.283903,-0.693353,0.143682,-0.422519,0.537309,-0.067971,-0.14701,-0.045239,0.343727,-0.143666,-0.042055,-0.15489,0.173258,0.163704,-0.198977,-0.691113,-0.279689,-0.093964,-0.052169,-0.054844,-0.985866,-0.99226,-0.98604,-0.985361,-0.993057,-0.987397,-0.989435,-0.990219,-0.989307,0.986686,0.992401,0.987283,-0.990485,-0.999847,-0.999951,-0.999819,-0.985452,-0.993737,-0.99162,-0.422141,-0.575594,-0.465645,0.1427,-0.472072,0.031169,0.26159,0.040893,-0.126873,0.262722,0.365272,0.080164,-0.401022,0.151647,0.166153,0.134166,-0.441451,-0.191008,-0.947314,-0.947147,-0.951382,-0.931442,-0.994171,-0.947314,-0.998044,-0.95767,-0.093554,-0.396161,0.183808,0.067003,0.217983,-0.947314,-0.947147,-0.951382,-0.931442,-0.994171,-0.947314,-0.998044,-0.95767,-0.093554,-0.396161,0.183808,0.067003,0.217983,-0.989181,-0.988955,-0.989324,-0.990424,-0.986552,-0.989181,-0.999788,-0.992025,-0.766312,-0.067902,0.153499,0.035089,-0.234197,-0.936621,-0.958405,-0.953729,-0.966635,-0.95245,-0.936621,-0.997819,-0.954761,0.300788,-0.353648,0.093177,0.113875,0.019273,-0.990569,-0.990631,-0.991,-0.990919,-0.984299,-0.990569,-0.999914,-0.991393,-0.500869,0.364035,-0.012899,-0.247087,-0.261214,-0.970672,-0.95025,-0.937402,-0.95987,-0.952243,-0.932807,-0.966075,-0.941171,-0.927697,-0.956881,-0.955755,-0.931439,-0.989719,-0.982581,-0.943356,-0.951669,-0.999158,-0.998421,-0.996782,-0.986428,-0.98571,-0.978908,-0.696642,-0.67894,-0.574538,-1.0,-1.0,-1.0,-0.364202,-0.263677,-0.074732,0.426034,0.192232,0.158685,-0.143285,0.451238,0.272747,-0.998927,-0.999793,-0.999828,-0.999808,-0.999739,-0.999676,-0.999644,-0.999551,-0.999084,-0.999801,-0.999729,-0.999613,-0.999139,-0.99974,-0.998513,-0.999535,-0.999938,-0.999676,-0.999879,-0.999037,-0.998712,-0.998761,-0.998426,-0.999851,-0.999577,-0.998694,-0.998434,-0.99965,-0.996996,-0.998775,-0.999538,-0.9997,-0.999589,-0.998617,-0.99822,-0.995312,-0.996831,-0.999627,-0.999343,-0.997223,-0.9968,-0.999617,-0.988187,-0.982828,-0.987333,-0.989278,-0.986318,-0.988839,-0.985879,-0.985646,-0.988684,-0.993361,-0.99009,-0.989969,-0.991877,-0.96876,-0.974937,-0.986718,-0.99985,-0.999713,-0.999781,-0.984837,-0.985624,-0.983165,-1.0,-1.0,-1.0,-0.32,-0.4,-0.48,-0.306025,-0.199094,-0.151631,-0.564015,-0.911607,-0.597232,-0.940525,-0.475381,-0.868581,-0.999898,-0.999841,-0.999872,-0.99996,-0.999844,-0.999891,-0.999847,-0.999991,-0.999854,-0.999892,-0.99987,-0.999849,-0.999837,-0.999893,-0.999204,-0.999784,-0.999988,-0.999776,-0.999919,-0.999416,-0.999616,-0.999624,-0.999624,-0.999894,-0.999724,-0.999621,-0.999723,-0.999779,-0.998968,-0.999448,-0.999899,-0.999906,-0.999849,-0.999673,-0.999867,-0.999591,-0.999195,-0.999935,-0.999808,-0.999827,-0.999616,-0.99989,-0.969903,-0.98212,-0.965455,-0.965405,-0.976519,-0.959896,-0.963051,-0.982009,-0.955515,-0.971633,-0.978077,-0.963531,-0.998546,-0.991121,-0.999345,-0.973538,-0.999358,-0.999713,-0.999007,-0.99027,-0.99229,-0.983172,-0.565461,-0.603871,-0.579095,-1.0,-1.0,-1.0,-0.623304,-0.215671,-0.602255,-0.195159,-0.614936,0.007814,-0.407566,0.060441,-0.346678,-0.999324,-0.99987,-0.999979,-0.99998,-0.999959,-0.99985,-0.999962,-0.999997,-0.999345,-0.999976,-0.999912,-0.999978,-0.999356,-0.999961,-0.999645,-0.999959,-0.999982,-0.999979,-0.999954,-0.99986,-0.999863,-0.999855,-0.999672,-0.999977,-0.999935,-0.999841,-0.999687,-0.999963,-0.999058,-0.999776,-0.999972,-0.999955,-0.99993,-0.999892,-0.999801,-0.999998,-0.999011,-0.999957,-0.999926,-0.999887,-0.99901,-0.999947,-0.93556,-0.953322,-0.934539,-0.962362,-0.90569,-0.93556,-0.998233,-0.968708,-0.624207,-1.0,-0.020428,-0.169942,-0.497064,-0.989966,-0.986296,-0.985318,-0.989298,-0.988385,-0.989966,-0.999814,-0.989398,-1.0,-1.0,0.030779,-0.322641,-0.73198,-0.978522,-0.954372,-0.965324,-0.94887,-0.995657,-0.978522,-0.998991,-0.984161,-0.636798,-1.0,-0.555452,0.493119,0.313147,-0.991298,-0.990112,-0.990312,-0.98915,-0.985564,-0.991298,-0.999928,-0.989877,-0.895847,-1.0,0.165352,-0.004286,-0.339527,-0.041803,0.027147,-0.366565,-0.342974,-0.497322,0.24192,0.347374,1
2,0.040181,-0.001511,-0.01354,-0.997022,-0.992382,-0.991907,-0.996929,-0.992441,-0.992986,-0.800077,-0.747659,-0.708102,0.839602,0.709361,0.678118,-0.995456,-0.999983,-0.999945,-0.999876,-0.996492,-0.993616,-0.99508,-0.680556,-0.611606,-0.410001,0.130097,0.111693,-0.226873,0.121198,0.422627,-0.239387,0.200146,0.230696,0.249419,-0.131726,0.315551,-0.420664,-0.218135,0.340959,0.140308,0.931432,-0.031449,-0.252802,-0.998501,-0.996467,-0.994996,-0.998495,-0.996669,-0.995623,0.860902,-0.061994,-0.256331,0.954987,-0.008922,-0.25158,-0.316514,0.817245,-0.999378,-0.867799,-0.998479,-0.997598,-0.996936,-1.0,-1.0,-1.0,-0.310714,0.320535,-0.330718,0.341244,0.00702,0.039045,-0.087846,0.137696,-0.173661,0.183286,-0.19311,0.202469,-0.628288,0.63285,0.131617,0.072758,0.012153,-0.032697,-0.994448,-0.983695,-0.991611,-0.99489,-0.981668,-0.992434,-0.995507,-0.99005,-0.988825,0.98997,0.987737,0.98396,-0.992494,-0.999947,-0.999711,-0.999841,-0.995591,-0.983488,-0.993714,-0.817601,-0.697146,-0.761523,0.146065,0.286187,0.199893,0.187942,0.333849,-0.038053,0.289951,0.378047,0.238718,0.047876,0.355942,0.45259,-0.283798,0.268694,0.076873,-0.02836834,-0.019838,-0.079018,-0.996945,-0.989924,-0.996407,-0.997092,-0.989801,-0.996758,-0.90947,-0.953488,-0.772193,0.852224,0.891141,0.825495,-0.993934,-0.999991,-0.999924,-0.999979,-0.997332,-0.990727,-0.996693,-0.674022,-0.356626,-0.695257,0.45491,-0.338869,0.262243,-0.086834,-0.085907,0.102876,0.063279,0.143095,0.394763,-0.168661,0.013886,0.243106,0.453164,-0.082355,-0.249219,-0.095813,-0.037281,-0.055345,-0.995443,-0.992413,-0.994768,-0.995025,-0.992197,-0.99468,-0.996113,-0.994909,-0.992848,0.994821,0.995085,0.995626,-0.994369,-0.999972,-0.999953,-0.999955,-0.995084,-0.991561,-0.9942,-0.715298,-0.480543,-0.637022,0.435761,0.036433,0.332242,0.18815,-0.048851,0.160781,0.084481,0.338249,0.420174,-0.038948,0.260298,0.253485,0.162421,-0.180275,-0.128953,-0.995003,-0.993199,-0.994075,-0.992301,-0.993955,-0.995003,-0.999935,-0.995396,-0.708863,0.472436,-0.319059,0.20637,-0.220544,-0.995003,-0.993199,-0.994075,-0.992301,-0.993955,-0.995003,-0.999935,-0.995396,-0.708863,0.472436,-0.319059,0.20637,-0.220544,-0.992271,-0.992681,-0.992969,-0.991814,-0.974303,-0.992271,-0.999864,-0.993792,-0.83282,0.283732,-0.13358,-0.509465,0.255178,-0.993657,-0.993581,-0.993297,-0.994858,-0.995727,-0.993657,-0.99995,-0.994627,-0.333311,0.364517,-0.507214,0.306031,-0.005912,-0.993976,-0.995515,-0.995734,-0.995939,-0.984997,-0.993976,-0.99996,-0.996274,-0.663511,0.708613,-0.235607,-0.482087,-0.124136,-0.995698,-0.985031,-0.990617,-0.99743,-0.994672,-0.992102,-0.997102,-0.988247,-0.989504,-0.997497,-0.996779,-0.992682,-0.989714,-0.998505,-0.995176,-0.993131,-0.999982,-0.999917,-0.99987,-0.997045,-0.984887,-0.989576,-1.0,-1.0,-0.837884,-1.0,-0.733333,-1.0,0.315896,0.347053,0.305802,-0.490277,-0.824164,-0.815527,-0.949478,-0.450978,-0.746485,-0.999984,-0.999988,-0.999953,-0.999872,-0.999919,-0.999896,-0.999972,-0.99998,-0.999985,-0.999927,-0.999923,-0.999974,-0.999985,-0.99988,-0.999958,-0.999902,-0.999777,-0.99966,-0.999565,-0.999494,-0.999785,-0.999999,-0.999942,-0.999691,-0.999523,-0.999871,-0.999929,-0.999621,-0.999909,-0.999803,-0.999639,-0.999863,-0.999911,-0.999716,-0.999898,-0.999914,-0.999894,-0.99975,-0.99989,-0.999905,-0.999866,-0.999889,-0.994382,-0.983453,-0.991819,-0.99502,-0.985189,-0.989092,-0.993696,-0.985423,-0.990402,-0.996204,-0.986797,-0.989546,-0.994418,-0.995686,-0.995504,-0.991548,-0.999947,-0.999711,-0.999841,-0.992556,-0.987275,-0.989059,-1.0,-1.0,-1.0,-0.08,0.08,-0.32,0.177891,0.168733,0.122841,-0.484829,-0.830084,-0.611799,-0.899451,-0.356016,-0.787747,-0.999986,-0.999987,-0.999944,-0.999838,-0.999935,-0.999897,-0.999969,-0.999994,-0.999987,-0.999895,-0.999933,-0.99997,-0.999971,-0.999845,-0.999873,-0.999909,-0.999728,-0.999635,-0.999558,-0.999375,-0.999757,-0.999995,-0.999894,-0.999638,-0.999413,-0.999791,-0.999805,-0.99957,-0.999927,-0.999783,-0.9997,-0.999859,-0.999887,-0.999671,-0.999892,-0.999995,-0.999822,-0.999811,-0.999837,-0.999898,-0.999758,-0.999873,-0.995126,-0.989712,-0.994496,-0.997584,-0.990097,-0.997201,-0.996749,-0.990202,-0.995307,-0.99781,-0.993395,-0.997892,-0.998528,-0.996099,-0.996772,-0.993272,-0.999989,-0.999923,-0.99998,-0.993735,-0.990005,-0.993069,-0.920371,-0.745176,-0.918335,-1.0,-1.0,-0.741935,0.316306,-0.002316,0.314012,-0.604961,-0.84151,-0.49165,-0.848914,-0.62386,-0.859784,-0.999992,-0.99999,-0.999971,-0.999968,-0.999955,-0.999936,-0.999994,-0.999989,-0.999992,-0.999965,-0.999944,-0.999992,-0.999991,-0.999962,-0.999941,-0.999963,-0.999978,-0.999967,-0.999977,-0.999927,-0.999935,-0.99998,-0.999926,-0.99997,-0.999967,-0.999946,-0.999921,-0.999965,-0.999987,-0.999986,-0.999963,-0.999968,-0.99992,-0.999945,-0.999981,-0.999994,-0.999985,-0.999956,-0.999933,-0.999987,-0.999983,-0.999959,-0.989661,-0.9949,-0.992128,-0.995633,-0.984771,-0.989661,-0.999928,-0.992429,-0.946182,-1.0,0.438739,-0.550284,-0.770261,-0.991954,-0.99282,-0.991674,-0.994305,-0.987984,-0.991954,-0.999891,-0.993905,-1.0,-0.936508,0.292383,-0.569614,-0.84412,-0.992668,-0.995449,-0.993941,-0.997104,-0.9884,-0.992668,-0.999957,-0.991491,-0.924075,-0.846154,0.381304,-0.848477,-0.9782,-0.994921,-0.996671,-0.995697,-0.997315,-0.994868,-0.994921,-0.999976,-0.994256,-1.0,-0.809524,0.613083,-0.780136,-0.932063,0.055645,0.570028,0.213908,-0.416739,-0.749341,0.108463,0.204921,0
3,0.041542,0.004941,-0.017991,-0.996213,-0.979937,-0.967548,-0.996299,-0.980631,-0.971401,-0.797031,-0.730051,-0.699154,0.839132,0.701355,0.652492,-0.983637,-0.999972,-0.999683,-0.999437,-0.995467,-0.983328,-0.971254,-0.638426,-0.335584,-0.389344,0.255029,-0.064427,0.106454,-0.135059,-0.145501,0.005704,0.299828,-0.265108,-0.033506,0.164625,-0.113274,-0.166229,-0.050865,-0.219672,0.57095,0.931584,-0.284899,0.107579,-0.996125,-0.9803,-0.978497,-0.996284,-0.979589,-0.978464,0.862167,-0.305161,0.102416,0.954349,-0.262503,0.100223,-0.044194,0.817638,-0.85669,-0.979466,-0.996533,-0.979128,-0.980023,-0.831857,-1.0,-0.373283,-0.317094,0.321411,-0.326062,0.331042,-0.320294,0.317466,-0.315499,0.313166,-0.470749,0.47018,-0.469465,0.467805,0.852063,0.542716,0.899092,0.076599,0.005335,-0.071154,-0.992309,-0.979574,-0.987197,-0.992094,-0.97821,-0.987479,-0.993508,-0.980169,-0.984915,0.992727,0.981206,0.967673,-0.988524,-0.999918,-0.999596,-0.999728,-0.98917,-0.977507,-0.985417,-0.769828,-0.636157,-0.711162,0.307175,0.182846,0.373525,0.579947,0.001287,-0.060012,0.250131,0.450803,-0.042387,0.289651,0.037231,0.268651,-0.135932,-0.153448,-0.029436,-0.02461614,-0.03318,-0.080377,-0.981917,-0.967591,-0.977449,-0.984297,-0.977792,-0.977628,-0.896625,-0.952506,-0.741327,0.844835,0.843119,0.81509,-0.975638,-0.999814,-0.999347,-0.999661,-0.988825,-0.986665,-0.978164,-0.21043,-0.43235,-0.342753,0.157169,-0.225881,0.159772,0.149067,-0.427036,0.420634,-0.240921,0.215161,-0.131129,0.155421,-0.049922,0.103048,0.119519,-0.259666,-0.438459,-0.099299,-0.043455,-0.037993,-0.984744,-0.979268,-0.977128,-0.985191,-0.985986,-0.978725,-0.985161,-0.962931,-0.972612,0.988864,0.973563,0.977095,-0.984877,-0.999826,-0.999735,-0.999601,-0.985249,-0.990575,-0.980501,-0.426063,-0.430155,-0.275815,0.095773,-0.043182,0.070745,-0.118485,-0.307764,0.242323,0.090064,-0.074543,-0.067054,-0.00976,0.265463,0.031345,0.045682,-0.444646,-0.042805,-0.982127,-0.980898,-0.981767,-0.972714,-0.991832,-0.982127,-0.999657,-0.980821,-0.431722,-0.151712,0.23982,-0.225611,0.099128,-0.982127,-0.980898,-0.981767,-0.972714,-0.991832,-0.982127,-0.999657,-0.980821,-0.431722,-0.151712,0.23982,-0.225611,0.099128,-0.989287,-0.986487,-0.991482,-0.97979,-0.978796,-0.989287,-0.999773,-0.996659,-0.827354,0.219232,-0.093029,-0.281261,0.053563,-0.975552,-0.959342,-0.967942,-0.945921,-0.992828,-0.975552,-0.999413,-0.976045,0.117668,-0.207116,0.059428,0.11535,-0.098875,-0.984176,-0.975736,-0.981729,-0.968172,-0.983768,-0.984176,-0.999738,-0.987937,-0.262471,0.083462,0.005637,-0.506154,0.382212,-0.994703,-0.972874,-0.963747,-0.996699,-0.9822,-0.970058,-0.995613,-0.97015,-0.961135,-0.997177,-0.984542,-0.974504,-0.996021,-0.985176,-0.974472,-0.979802,-0.999976,-0.999658,-0.999089,-0.993124,-0.97555,-0.98007,-1.0,-0.786886,-0.656337,-0.866667,-1.0,-1.0,0.284659,-0.070695,0.198646,-0.578454,-0.874104,-0.328007,-0.595432,-0.151643,-0.513321,-0.999987,-0.999968,-0.999887,-0.99994,-0.999851,-0.999869,-0.99998,-0.999998,-0.999982,-0.999891,-0.999871,-0.999986,-0.999978,-0.999903,-0.999752,-0.99964,-0.999726,-0.999781,-0.999894,-0.999659,-0.999734,-0.999939,-0.99968,-0.999677,-0.999826,-0.999815,-0.999664,-0.999805,-0.999218,-0.999426,-0.999502,-0.999805,-0.9994,-0.999119,-0.998558,-0.997889,-0.999146,-0.999642,-0.99932,-0.998298,-0.999105,-0.999685,-0.991927,-0.981724,-0.986124,-0.993492,-0.978366,-0.98601,-0.991913,-0.979302,-0.986659,-0.993828,-0.977726,-0.986796,-0.988598,-0.998032,-0.989098,-0.987633,-0.999918,-0.999596,-0.99973,-0.987763,-0.987947,-0.982858,-1.0,-0.93912,-1.0,-0.08,-0.4,-0.16,0.296145,-0.311289,0.010702,-0.46126,-0.784968,-0.336936,-0.746502,-0.416803,-0.815553,-0.999996,-0.999964,-0.999869,-0.99993,-0.999839,-0.999805,-0.999963,-0.999973,-0.999976,-0.999877,-0.999831,-0.99996,-0.999934,-0.999844,-0.999678,-0.999573,-0.999724,-0.999747,-0.999912,-0.999559,-0.999823,-0.999997,-0.999544,-0.999687,-0.999779,-0.999849,-0.999548,-0.999781,-0.999752,-0.999633,-0.999623,-0.999767,-0.999731,-0.99987,-0.999626,-0.99997,-0.999631,-0.999726,-0.999778,-0.999642,-0.99962,-0.999793,-0.979495,-0.968518,-0.968512,-0.982636,-0.967203,-0.981307,-0.979614,-0.96767,-0.969898,-0.98582,-0.975472,-0.985054,-0.998536,-0.994467,-0.98343,-0.971552,-0.999802,-0.999392,-0.999663,-0.983185,-0.963609,-0.968376,-0.615074,-0.379443,-0.559151,-0.733333,-0.935484,-0.935484,-0.277095,-0.271751,0.125233,-0.359854,-0.723008,-0.396106,-0.747064,-0.455876,-0.763543,-0.999804,-0.999939,-0.99988,-0.999928,-0.999814,-0.999899,-0.999942,-0.999999,-0.999807,-0.999876,-0.999831,-0.999968,-0.999804,-0.9999,-0.999474,-0.999714,-0.999854,-0.999882,-0.999973,-0.999877,-0.999938,-0.999992,-0.999366,-0.999828,-0.999954,-0.999953,-0.999342,-0.999896,-0.999805,-0.999564,-0.999719,-0.9998,-0.999763,-0.99986,-0.999502,-0.999516,-0.999704,-0.999641,-0.999794,-0.999509,-0.999679,-0.9998,-0.979433,-0.981591,-0.977312,-0.981681,-0.980557,-0.979433,-0.999682,-0.980845,-0.946182,-1.0,0.059545,0.184163,-0.00867,-0.987059,-0.984503,-0.982081,-0.988682,-0.990941,-0.987059,-0.999757,-0.983683,-1.0,-1.0,0.057649,-0.37217,-0.751008,-0.961774,-0.964417,-0.956735,-0.973006,-0.997668,-0.961774,-0.999033,-0.94502,-0.439676,-1.0,-0.147167,-0.616263,-0.838036,-0.977085,-0.975423,-0.974055,-0.975996,-0.973435,-0.977085,-0.999633,-0.97974,-0.704753,-1.0,-0.139874,-0.159273,-0.531876,-0.01788,0.016172,-0.380236,-0.047587,-0.71628,0.282041,-0.045632,1
4,0.043116,0.008285,-0.017163,-0.980197,-0.96535,-0.983871,-0.983203,-0.966786,-0.986774,-0.776009,-0.707561,-0.706783,0.822611,0.705043,0.665657,-0.980047,-0.999726,-0.999257,-0.999795,-0.987187,-0.970954,-0.988586,-0.387587,-0.215438,-0.43661,0.015868,-0.022078,0.052859,0.123321,0.023309,-0.13132,0.016676,0.516341,0.17586,0.041897,-0.016885,0.052488,-0.418668,-0.296465,0.460401,0.918522,-0.330221,0.095379,-0.992814,-0.979093,-0.993259,-0.992994,-0.97935,-0.992968,0.851083,-0.346034,0.087525,0.940933,-0.305973,0.094968,0.021665,0.784103,-0.804822,-0.984126,-0.993906,-0.981723,-0.992627,-0.734993,-1.0,-0.715091,-0.633587,0.675183,-0.717523,0.760516,-0.6391,0.666015,-0.694165,0.722638,-0.326646,0.362296,-0.398294,0.433799,0.997717,0.899462,0.917886,0.069862,0.028274,-0.035348,-0.978254,-0.971886,-0.982103,-0.978036,-0.970164,-0.980051,-0.983508,-0.979176,-0.989625,0.980763,0.977373,0.979972,-0.977677,-0.999617,-0.999336,-0.999578,-0.972541,-0.971738,-0.976085,-0.523778,-0.571674,-0.603376,0.047966,-0.162422,0.319483,-0.102376,-0.003819,-0.251542,-0.098901,0.139741,0.046804,0.165952,0.107867,0.239206,-0.338161,0.150452,0.223451,-0.006013976,-0.038291,-0.069084,-0.909738,-0.938369,-0.950992,-0.902715,-0.938017,-0.954334,-0.853917,-0.934955,-0.739715,0.827682,0.861972,0.795629,-0.911547,-0.995845,-0.997868,-0.998579,-0.896607,-0.922168,-0.952246,-0.073007,-0.223518,-0.025284,-0.194014,0.015161,0.063389,0.043257,-0.423088,0.244163,-0.118351,0.418713,0.000439,-0.23747,0.339175,0.133806,-0.767429,-0.091855,-0.374881,-0.091716,-0.030178,-0.071026,-0.980065,-0.982985,-0.974102,-0.978913,-0.983531,-0.976534,-0.988693,-0.985652,-0.975822,0.984422,0.988645,0.978004,-0.981479,-0.999726,-0.999814,-0.99951,-0.978047,-0.98559,-0.982621,-0.310032,-0.271727,-0.289432,-0.176399,0.031768,-0.027364,-0.144381,-0.352592,0.026737,-0.12368,-0.025864,0.031869,-0.529238,0.265547,0.123114,-0.287511,0.314641,-0.474171,-0.979125,-0.972891,-0.9744,-0.9696,-0.995258,-0.979125,-0.99953,-0.974128,-0.307858,0.029667,-0.072161,0.209425,-0.086198,-0.979125,-0.972891,-0.9744,-0.9696,-0.995258,-0.979125,-0.99953,-0.974128,-0.307858,0.029667,-0.072161,0.209425,-0.086198,-0.978603,-0.98111,-0.980619,-0.980569,-0.97152,-0.978603,-0.999491,-0.980209,-0.625859,0.363815,0.045279,-0.446618,-0.149252,-0.913771,-0.905063,-0.897906,-0.932885,-0.987159,-0.913771,-0.995314,-0.88969,0.533678,-0.396195,0.126883,0.124629,-0.010084,-0.981487,-0.983277,-0.984402,-0.98375,-0.97577,-0.981487,-0.999757,-0.985463,-0.264792,-0.091012,-0.081435,0.191549,0.024341,-0.977687,-0.956079,-0.97681,-0.980903,-0.968829,-0.986787,-0.977756,-0.954696,-0.975089,-0.985612,-0.975458,-0.993306,-0.990259,-0.981562,-0.999671,-0.970379,-0.999717,-0.999159,-0.999677,-0.97894,-0.983974,-0.974334,-0.748883,-0.688236,-0.718405,-0.733333,-0.733333,-0.615385,-0.030084,-0.062643,0.198715,-0.363477,-0.76931,-0.250417,-0.600847,-0.747441,-0.939358,-0.99975,-0.999745,-0.999658,-0.999534,-0.999596,-0.999623,-0.999771,-0.999936,-0.999727,-0.999573,-0.999619,-0.999826,-0.999725,-0.999492,-0.99927,-0.999583,-0.999846,-0.999393,-0.998762,-0.99912,-0.999563,-0.999688,-0.999188,-0.999681,-0.998792,-0.999605,-0.999184,-0.999196,-0.999786,-0.999552,-0.999473,-0.999752,-0.999621,-0.99962,-0.99976,-1.0,-0.999731,-0.999604,-0.99963,-0.999835,-0.999686,-0.999735,-0.978155,-0.971123,-0.979003,-0.980337,-0.974951,-0.983177,-0.974382,-0.975835,-0.981441,-0.984374,-0.973567,-0.986461,-0.99255,-0.976708,-0.97399,-0.974942,-0.999616,-0.999337,-0.999577,-0.969663,-0.980708,-0.975586,-0.943699,-0.93912,-1.0,-0.16,-0.84,-0.08,-0.103954,-0.160706,0.05447,-0.519706,-0.844347,-0.452245,-0.776904,-0.599683,-0.906144,-0.999751,-0.999818,-0.999634,-0.999551,-0.999647,-0.999562,-0.999625,-0.999993,-0.999773,-0.999523,-0.999602,-0.999632,-0.999681,-0.99941,-0.998053,-0.999717,-0.999845,-0.999525,-0.998997,-0.99927,-0.999584,-0.999879,-0.999291,-0.999671,-0.998911,-0.999626,-0.999418,-0.999307,-0.999346,-0.999523,-0.999442,-0.999749,-0.999611,-0.999587,-0.999546,-0.999802,-0.999393,-0.999625,-0.999601,-0.999547,-0.999391,-0.999713,-0.948247,-0.958872,-0.948793,-0.903749,-0.928753,-0.954404,-0.937608,-0.948775,-0.942131,-0.884662,-0.935385,-0.957625,-0.995734,-0.985004,-0.997925,-0.950913,-0.995758,-0.997945,-0.998623,-0.98159,-0.983216,-0.974295,-0.54318,-0.37144,-0.495041,-1.0,-1.0,-0.741935,-0.722547,-0.559069,-0.301725,0.666916,0.479767,0.088802,-0.344747,-0.028144,-0.394855,-0.995426,-0.99982,-0.99994,-0.999885,-0.999937,-0.999834,-0.999912,-0.999927,-0.995669,-0.999906,-0.99989,-0.999919,-0.99574,-0.999887,-0.997295,-0.999805,-0.999952,-0.999933,-0.999947,-0.999857,-0.999746,-0.999794,-0.997568,-0.999935,-0.999929,-0.999733,-0.99772,-0.999927,-0.998755,-0.999504,-0.999892,-0.999852,-0.999806,-0.999525,-0.99933,-0.999555,-0.998644,-0.999831,-0.999736,-0.999428,-0.998636,-0.999818,-0.966032,-0.976167,-0.962911,-0.983504,-0.977314,-0.966032,-0.99944,-0.969883,-0.716573,-1.0,-0.039168,-0.441373,-0.741449,-0.978705,-0.983826,-0.981024,-0.988203,-0.98006,-0.978705,-0.999608,-0.982739,-1.0,-0.936508,0.455709,-0.530259,-0.8566,-0.953993,-0.895911,-0.925481,-0.90424,-0.995107,-0.953993,-0.995175,-0.9713,-0.451458,-1.0,-0.687,0.295788,-0.099894,-0.98415,-0.983039,-0.979702,-0.988437,-0.979985,-0.98415,-0.99981,-0.979604,-0.747421,-0.555556,-0.119234,-0.483912,-0.844812,-0.016275,0.483361,-0.862908,-0.092707,-0.67675,0.31339,-0.037061,1


# Logreg

In [11]:
def get_predictions(x):
    return [1 if xi >= 0.5 else 0 for xi in x]

In [12]:
X, y = X_train[[col for col in X_train.columns if col != "label"]].values , X_train.label

In [13]:
X.shape, y.shape

((2716, 561), (2716,))

In [14]:
loo = KFold(n_splits=10)
preds = np.zeros(len(y))
for i, (train_index, test_index) in enumerate(loo.split(X)):
    _X_train, _X_test = X[train_index, :], X[test_index,:]
    _y_train = y[train_index].values
    clf = LogisticRegression(random_state=0, C=1., max_iter=500).fit(_X_train, _y_train)
    preds[test_index] = clf.predict_proba(_X_test)[:,1]

print(f"Models AUC score: {roc_auc_score(y, preds)}")
print(classification_report(y, get_predictions(preds)))

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Models AUC score: 0.9926888880555278
              precision    recall  f1-score   support

           0       0.96      0.96      0.96      1293
           1       0.97      0.96      0.96      1423

    accuracy                           0.96      2716
   macro avg       0.96      0.96      0.96      2716
weighted avg       0.96      0.96      0.96      2716



In [25]:
# Logreg MLE
log_reg = LogisticRegression(random_state=0, C=1.0, max_iter = 500).fit(X[:561,:], y[:561])

In [26]:
_X_test = X_test[[col for col in X_test.columns if col != "label"]].values
y_pred = log_reg.predict_proba(_X_test)[:,1]

In [27]:
print(f"Models AUC score: {roc_auc_score(y_test, y_pred)}")

Models AUC score: 0.9808955984818445


## Small logreg model

In [18]:
sort_idx = np.argsort(np.abs(log_reg.coef_))
sort_coef_vals = np.abs(log_reg.coef_)[0][sort_idx[0]]

In [19]:
len(sort_coef_vals[sort_coef_vals > 0.1])

197

In [20]:
feature_idx = sort_idx[0][-195:]

In [21]:
loo = KFold(n_splits=10)
preds = np.zeros(len(y))
for i, (train_index, test_index) in enumerate(loo.split(X)):
    _X_train, _X_test = X[:,feature_idx][train_index], X[:,feature_idx][test_index]
    _y_train = y[train_index]
    clf = LogisticRegression(random_state=0, C=1., max_iter=500).fit(_X_train, _y_train)
    preds[test_index] = clf.predict_proba(_X_test)[:,1]

print(f"Models AUC score: {roc_auc_score(y, preds)}")
print(classification_report(y, get_predictions(preds)))

Models AUC score: 0.9925252956755631
              precision    recall  f1-score   support

           0       0.96      0.96      0.96      1293
           1       0.96      0.96      0.96      1423

    accuracy                           0.96      2716
   macro avg       0.96      0.96      0.96      2716
weighted avg       0.96      0.96      0.96      2716



In [22]:
# Logreg MLE
log_reg = LogisticRegression(random_state=0, C=1.0, max_iter = 500).fit(X[:561,feature_idx], y[:561])

In [23]:
_X_test = X_test[[col for col in X_test.columns if col != "label"]].values
y_pred = log_reg.predict_proba(_X_test[:,feature_idx])[:,1]

In [24]:
print(f"Models AUC score: {roc_auc_score(y_test, y_pred)}")

Models AUC score: 0.9794652467002776


# Gaussian logreg

## MAP l2 estimate C=0.3

In [28]:
loo = KFold(n_splits=10)
preds = np.zeros(len(y))
for i, (train_index, test_index) in enumerate(loo.split(X)):
    _X_train, _X_test = X[train_index, :], X[test_index,:]
    _y_train = y[train_index].values
    clf = LogisticRegression(random_state=0, C=.3, max_iter=500).fit(_X_train, _y_train)
    preds[test_index] = clf.predict_proba(_X_test)[:,1]

print(f"Models AUC score: {roc_auc_score(y, preds)}")
print(classification_report(y, get_predictions(preds)))

Models AUC score: 0.9909078507493998
              precision    recall  f1-score   support

           0       0.95      0.95      0.95      1293
           1       0.96      0.96      0.96      1423

    accuracy                           0.96      2716
   macro avg       0.96      0.96      0.96      2716
weighted avg       0.96      0.96      0.96      2716



In [29]:
log_reg = LogisticRegression(random_state=0, C=.3, penalty="l2", max_iter = 500).fit(X[:561,:], y[:561])

In [30]:
_X_test = X_test[[col for col in X_test.columns if col != "label"]].values
y_pred = log_reg.predict_proba(_X_test)[:,1]
print(f"Models AUC score: {roc_auc_score(y_test, y_pred)}")

Models AUC score: 0.9823011669404633


## Small MAP logreg model

In [31]:
sort_idx = np.argsort(np.abs(log_reg.coef_))
sort_coef_vals = np.abs(log_reg.coef_)[0][sort_idx[0]]

In [20]:
feature_idx = sort_idx[0][-195:]

In [34]:
loo = KFold(n_splits=10)
preds = np.zeros(len(y))
for i, (train_index, test_index) in enumerate(loo.split(X)):
    _X_train, _X_test = X[:,feature_idx][train_index], X[:,feature_idx][test_index]
    _y_train = y[train_index]
    clf = LogisticRegression(random_state=0, C=.3, penalty="l2", max_iter=500).fit(_X_train, _y_train)
    preds[test_index] = clf.predict_proba(_X_test)[:,1]

print(f"Models AUC score: {roc_auc_score(y, preds)}")
print(classification_report(y, get_predictions(preds)))

Models AUC score: 0.990836109240578
              precision    recall  f1-score   support

           0       0.95      0.96      0.95      1293
           1       0.96      0.96      0.96      1423

    accuracy                           0.96      2716
   macro avg       0.96      0.96      0.96      2716
weighted avg       0.96      0.96      0.96      2716



In [35]:
# Logreg MLE
log_reg = LogisticRegression(random_state=0, C=.3, penalty="l2", max_iter=500).fit(X[:561,feature_idx], y[:561])

In [36]:
_X_test = X_test[[col for col in X_test.columns if col != "label"]].values
y_pred = log_reg.predict_proba(_X_test[:,feature_idx])[:,1]

In [37]:
print(f"Models AUC score: {roc_auc_score(y_test, y_pred)}")

Models AUC score: 0.9799538322098227


In [28]:
X_test.shape

(271, 195)

# Slap-and-spike prior

In [29]:
X, y = X_train[[col for col in X_train.columns if col != "label"]].values , X_train.label

In [30]:
import pymc3 as pm
import theano as tt
from scipy.special import expit
from scipy.stats import norm, bernoulli

This is our slab and spike model

$$a \sim \mathcal{N}(0, 3)$$
$$\gamma_i \sim Bernoulli(p=0.1)$$
$$\alpha_i|\sigma_\beta \sim \mathcal{N}(0, \sigma_\beta)$$
$$e \sim \mathcal{N}(0, \sigma^2_eI_n)$$
$$y \sim \frac{1}{1+exp(-(a + \sum_{i=1}^N \gamma_i \alpha_i x_i + e))}$$

The model parameters are $\theta = \{\gamma_i, \alpha_i\}_i^N $. The $\gamma_i$ and $\alpha_i$ are modelled IID.

In [31]:
prob = 0.1
a_mu = 0
a_var = 3
gamma_var = 1
with pm.Model() as model:
    # priors inclusion probability
    gamma_i = pm.Bernoulli("gamma_i", prob, shape=X.shape[1])
    # a is the interception
    a = pm.Normal("a", mu=a_mu, sd=a_var)
    # The prior for the features varibles which are included
    alpha = pm.Normal("alpha", mu=0, sd=gamma_var, shape=X.shape[1])
    # Deterministic function
    p = pm.math.dot(X,gamma_i * alpha) 
    # Likelihood
    y_obs = pm.Bernoulli("y_obs", pm.invlogit(p + a),  observed=y)
 

In [32]:
with model:
    trace = pm.sample(2000, random_seed = 4816, cores = 1, progressbar = True, chains = 1)
    

  trace = pm.sample(2000, random_seed = 4816, cores = 1, progressbar = True, chains = 1)
Sequential sampling (1 chains in 1 job)
CompoundStep
>BinaryGibbsMetropolis: [gamma_i]
>NUTS: [alpha, a]


Sampling 1 chain for 1_000 tune and 2_000 draw iterations (1_000 + 2_000 draws total) took 2142 seconds.
There were 24 divergences after tuning. Increase `target_accept` or reparameterize.
Only one chain was sampled, this makes it impossible to run some convergence checks


# Map estimate

The log loss for the spike and slap prior
$$p(\theta) = p(\{\gamma_i, \alpha_i\}_i^N) = \prod_{i=1}^N p(\gamma_i) p(\alpha_i)$$

$$\log p(\theta) = \sum_{i=1}^N \log Bernoulli(\gamma_i | p=0.1) + \log \mathcal{N}(\alpha_i | \mu=0, \sigma_\beta=3)$$

In [33]:
def spike_slab_log_prior(gamma: np.array, alpha: np.array, p, gamma_mu=0, sigma_beta=3):
    return (bernoulli.logpmf(gamma, p=p) + norm.logpdf(alpha, loc=gamma_mu, scale=sigma_beta)).sum()

The negative log likelihood function, i.e. cross entropy

$$E(\mathbf{w}) = -log(p(\pmb{\mathbf{t}}|\pmb{\mathbf{x}},\mathbf{w})) \nonumber$$

$$= - \sum_{n=1}^N \left( t_n \ln y(\mathbf{x}_n) + (1-t_n) \ln (1-y(\mathbf{x}_n)) \right)$$
  

In [34]:
def log_likelihood(a, gamma, alpha, X, T):
    y_x = expit(a + np.dot(X, np.transpose(gamma*alpha)))
    return (T*np.log(y_x) + ((1-T)*np.log(1-y_x))).sum()

In [35]:
prob = 0.1
gamma_mu = 0
sigma_beta = 3

def find_spike_slab_MAP(trace, X, y, prob, gamma_mu, sigma_beta):
    min_loss = np.inf
    cur_min = -1
    for i in range(len(trace)):
        tmp_trace = trace[i]
        tmp_spike_slab_log_prior = spike_slab_log_prior(tmp_trace["gamma_i"], tmp_trace["alpha"], p=prob, gamma_mu=gamma_mu, sigma_beta=sigma_beta)
        tmp_log_likelihood = log_likelihood(tmp_trace["a"], tmp_trace["gamma_i"], tmp_trace["alpha"], X, y)
        neq_loss = -(tmp_log_likelihood + tmp_spike_slab_log_prior)
        if neq_loss <= min_loss:
            min_loss = neq_loss
            cur_min = i
    return trace[cur_min]
    

In [36]:
map_trace = find_spike_slab_MAP(trace, X, y, prob, gamma_mu, sigma_beta)

In [37]:
map_estimate = map_trace["gamma_i"] * map_trace["alpha"]

_X_test = X_test[[col for col in X_test.columns if col != "label"]].values
map_preds = expit(map_trace["a"] + np.dot(_X_test, np.transpose(map_estimate)))

In [47]:
print(f"Models AUC score: {roc_auc_score(y_test, map_preds)}")

Models AUC score: 0.9642659887837761


# Full Bayesian

In [39]:
results = pd.DataFrame({'var': np.arange(561), 
                        'inclusion_probability':np.apply_along_axis(np.mean, 0, trace['gamma_i']),
                       'alpha':np.apply_along_axis(np.mean, 0, trace['alpha']),
                       'alpha_given_inclusion': np.apply_along_axis(np.sum, 0, trace['gamma_i']*trace['alpha'])
                            /np.apply_along_axis(np.sum, 0, trace['gamma_i'])
                       })

In [40]:
results.sort_values('inclusion_probability', ascending = False).head(10)


Unnamed: 0,var,inclusion_probability,alpha,alpha_given_inclusion
41,41,1.0,-3.88396,-3.88396
451,451,1.0,-2.267181,-2.267181
182,182,1.0,5.423177,5.423177
142,142,1.0,1.571083,1.571083
50,50,1.0,-4.035479,-4.035479
445,445,1.0,4.38021,4.38021
186,186,1.0,-3.663025,-3.663025
187,187,0.997,-2.154221,-2.158809
158,158,0.997,-0.996644,-1.002019
53,53,0.9915,-3.449169,-3.479308


## Bayesian inference

In [41]:
estimate = trace['alpha'] * trace['gamma_i'] 
_X_test = X_test[[col for col in X_test.columns if col != "label"]].values
preds = np.apply_along_axis(np.mean, 1, expit(trace['a'] + np.dot(_X_test, np.transpose(estimate) )) )


In [42]:
print(f"Models AUC score: {roc_auc_score(y_test, preds)}")

Models AUC score: 0.9731986064691553


In [43]:
preds[:10]

array([0.99728329, 0.95677686, 0.92627129, 0.96335403, 0.97437116,
       0.94798007, 0.89826115, 0.83813811, 0.84503151, 0.9424179 ])

In [44]:
map_preds[:10]

array([0.99377665, 0.9638406 , 0.86181986, 0.90449025, 0.98654351,
       0.9787043 , 0.94709745, 0.91942937, 0.89949213, 0.92160971])

In [45]:
(map_preds[map_preds > 0.5]).mean(), (map_preds[map_preds <= 0.5]).mean()

(0.9363551762776625, 0.09595969172622476)

In [46]:
(preds[preds > 0.5]).mean(), (preds[preds <= 0.5]).mean()

(0.9245190406644662, 0.09512387902157081)