In [1]:
%%capture
!pip install sklearn
!pip install tqdm

In [2]:
import os, sys
sys.path.append('utils/')
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import cPickle as pkl
from sklearn.decomposition import PCA
from sklearn.cross_validation import train_test_split
from sklearn.model_selection import KFold
import pandas as pd

from flip_gradient import flip_gradient
from utils import *
from evaluation import compute_ks, compute_cvm, roc_auc_truncated
from tqdm import tqdm, trange

SEED = 5

pd.set_option('display.max_columns', 100)

%matplotlib inline



# Preparing data

In [3]:
label_prediction = pd.read_csv('../datasets/training.csv.zip').drop('id', 1)
label_prediction = pd.concat([label_prediction, label_prediction]).sample(frac=1., random_state=SEED)
features = label_prediction.drop(['signal', 'mass', 'min_ANNmuon', 'production'], 1).columns
print label_prediction.shape
label_prediction.head()

(135106, 50)


Unnamed: 0,LifeTime,dira,FlightDistance,FlightDistanceError,IP,IPSig,VertexChi2,pt,DOCAone,DOCAtwo,DOCAthree,IP_p0p2,IP_p1p2,isolationa,isolationb,isolationc,isolationd,isolatione,isolationf,iso,CDF1,CDF2,CDF3,ISO_SumBDT,p0_IsoBDT,p1_IsoBDT,p2_IsoBDT,p0_track_Chi2Dof,p1_track_Chi2Dof,p2_track_Chi2Dof,p0_IP,p1_IP,p2_IP,p0_IPSig,p1_IPSig,p2_IPSig,p0_pt,p1_pt,p2_pt,p0_p,p1_p,p2_p,p0_eta,p1_eta,p2_eta,SPDhits,production,signal,mass,min_ANNmuon
21840,0.001173,0.999793,7.309732,0.225475,0.14476,8.930744,4.965856,4945.793457,0.019389,0.015251,0.037411,0.204333,0.249398,1,5,3,1.0,2.0,2.0,5.0,1.0,0.643805,0.270755,-0.423511,-0.131208,-0.115879,-0.176424,1.225634,1.368773,1.647525,0.339441,0.228714,0.454121,10.899674,11.228765,24.274082,885.767273,1630.131714,2696.752686,8884.09668,12979.206055,17545.503906,2.996217,2.763868,2.559938,572,-99,0,1885.391968,0.464575
56053,0.001524,0.999998,12.253985,0.171148,0.0263,2.110847,7.534899,6609.697754,0.04167,0.008835,0.026718,0.257637,0.093278,0,2,1,0.0,0.0,0.0,0.0,0.867473,0.82389,0.82389,-0.697811,-0.311761,-0.177802,-0.208248,1.368423,1.121142,0.73451,0.154609,0.756436,0.364036,9.368566,44.252838,21.37154,1926.730835,2336.613525,2403.396729,15412.594727,12097.913086,20351.759766,2.768578,2.328018,2.825927,205,1,1,1783.964111,0.572177
45478,0.001228,1.0,52.646694,1.009517,0.008819,0.769089,8.150103,8629.933594,0.122356,0.008543,0.03806,0.286731,0.024846,0,2,0,0.0,0.0,0.0,0.0,0.934495,0.934495,0.719623,-1.020152,-0.269676,-0.300122,-0.450355,0.69103,1.401399,0.9645,0.575982,0.359241,0.33646,8.886332,30.627249,21.443039,402.388214,4380.829102,3906.171631,10589.078125,108756.5,133513.78125,3.962947,3.904615,4.22458,300,1,1,1769.629395,0.329092
2583,0.000641,0.999931,5.975876,0.298114,0.06988,3.800114,7.624316,3507.376465,0.158135,0.006316,0.158239,0.096123,0.072929,1,2,1,0.0,1.0,0.0,1.0,0.905173,0.562683,0.46512,-0.250856,-0.186523,-0.036454,-0.027879,1.146612,2.169724,1.156981,0.220572,0.4758,0.10654,3.659899,14.31028,4.953084,386.018738,945.19574,2178.675293,4852.797363,6461.723633,40627.386719,3.222986,2.610014,3.618153,186,-99,0,1669.694946,0.20614
29111,0.000361,0.999051,2.052216,0.13691,0.093032,6.564189,1.652357,6301.269531,0.015981,0.014536,0.004173,0.157119,0.071218,1,2,0,1.0,2.0,0.0,3.0,1.0,0.617369,0.617369,-0.389445,-0.019306,-0.089797,-0.280342,1.098584,0.679145,1.202928,0.164807,0.121844,0.14339,9.387158,7.297283,8.222139,2004.452393,2760.5271,1727.466309,11180.704102,12460.376953,10097.848633,2.403832,2.187777,2.451416,37,4,1,1773.92041,0.760365


In [4]:
# Full test set; the public/private evaluation subsets below are built by joining it on `id`.
test = pd.read_csv('../datasets/test.csv.zip')
print test.shape

(855819, 47)


In [5]:
# Join the test rows with private_eval.csv on `id` (pd.merge defaults to an inner join),
# then drop rows that are exact duplicates. The result carries a `signal` column.
private_dataset = pd.merge(test, pd.read_csv('../datasets/private_eval.csv'), on='id').drop_duplicates()
print private_dataset.shape
private_dataset.head()

(28497, 48)


Unnamed: 0,id,LifeTime,dira,FlightDistance,FlightDistanceError,IP,IPSig,VertexChi2,pt,DOCAone,DOCAtwo,DOCAthree,IP_p0p2,IP_p1p2,isolationa,isolationb,isolationc,isolationd,isolatione,isolationf,iso,CDF1,CDF2,CDF3,ISO_SumBDT,p0_IsoBDT,p1_IsoBDT,p2_IsoBDT,p0_track_Chi2Dof,p1_track_Chi2Dof,p2_track_Chi2Dof,p0_IP,p1_IP,p2_IP,p0_IPSig,p1_IPSig,p2_IPSig,p0_pt,p1_pt,p2_pt,p0_p,p1_p,p2_p,p0_eta,p1_eta,p2_eta,SPDhits,signal
0,4369817,0.002109,0.999998,14.665833,0.134259,0.029963,2.091192,0.446515,6850.274414,0.008145,0.008414,0.005803,0.478501,0.182662,1,5,2,0.0,0.0,0.0,0.0,0.963404,0.941582,0.832183,-1.271832,-0.431318,-0.429426,-0.411087,1.040657,1.374021,0.888281,0.541408,0.812435,0.449803,24.295046,40.26123,23.897257,1691.214111,3163.742188,2060.64917,11679.78125,14723.319336,14744.435547,2.620276,2.219076,2.656073,353,1
1,10213012,0.000886,0.999992,9.180307,0.472137,0.037329,2.02565,7.784008,4024.888916,0.070422,0.020018,0.078934,0.13011,0.185814,1,2,1,0.0,0.0,0.0,0.0,0.621012,0.621012,0.621012,-0.929209,-0.171253,-0.391201,-0.366756,1.097838,0.505729,0.969616,0.278849,0.215638,0.306633,9.881294,9.835187,9.932581,1413.249878,1814.38147,1084.320923,22417.171875,20466.382812,18283.640625,3.456087,3.114216,3.51732,52,1
2,17806160,0.001313,0.999966,17.477598,1.382421,0.157643,6.020585,1.043844,1584.708984,0.010628,0.009894,0.068503,0.375108,0.232583,5,3,3,1.0,1.0,1.0,3.0,0.499311,0.396624,0.396624,-0.385559,-0.084681,-0.124454,-0.176424,1.382239,1.274481,0.885738,0.280228,0.17973,1.388035,5.037262,5.983574,10.372314,493.309265,1422.916504,317.407227,12352.144531,59287.277344,6400.515137,3.913197,4.422689,3.69648,202,1
3,641074,0.000827,0.999993,5.873784,0.162675,0.02219,2.182303,4.583155,7745.390137,0.011685,0.012809,0.022599,0.157218,0.06667,2,0,6,0.0,0.0,0.0,0.0,0.915044,0.854287,0.72593,-1.448177,-0.524151,-0.539526,-0.384501,0.87012,1.070111,1.27676,0.25948,0.335885,0.128558,10.903188,24.863413,10.396156,1322.295776,2413.23584,4147.405762,6593.108887,15263.637695,20328.482422,2.289592,2.531343,2.272115,211,1
4,11383293,0.001149,0.999999,9.548421,0.38493,0.014353,0.766637,0.499251,2880.133545,0.017162,0.001223,0.03924,0.185238,0.066593,2,6,1,0.0,1.0,0.0,1.0,0.9215,0.844762,0.443685,-0.869419,-0.432826,-0.188276,-0.248317,1.071941,1.061785,0.953241,1.055305,0.234065,0.158303,18.587053,11.259723,4.993015,392.172211,1645.024414,1098.795898,3348.355713,19929.757812,25797.041016,2.834224,3.185898,3.848738,251,1


In [6]:
# Same construction as the private subset, using public_eval.csv instead.
public_dataset = pd.merge(test, pd.read_csv('../datasets/public_eval.csv'), on='id').drop_duplicates()
print public_dataset.shape
public_dataset.head()

(12302, 48)


Unnamed: 0,id,LifeTime,dira,FlightDistance,FlightDistanceError,IP,IPSig,VertexChi2,pt,DOCAone,DOCAtwo,DOCAthree,IP_p0p2,IP_p1p2,isolationa,isolationb,isolationc,isolationd,isolatione,isolationf,iso,CDF1,CDF2,CDF3,ISO_SumBDT,p0_IsoBDT,p1_IsoBDT,p2_IsoBDT,p0_track_Chi2Dof,p1_track_Chi2Dof,p2_track_Chi2Dof,p0_IP,p1_IP,p2_IP,p0_IPSig,p1_IPSig,p2_IPSig,p0_pt,p1_pt,p2_pt,p0_p,p1_p,p2_p,p0_eta,p1_eta,p2_eta,SPDhits,signal
0,12200521,0.001825,0.999989,8.118148,0.198013,0.038373,2.368293,3.995679,3088.533447,0.01694,0.050317,0.014406,0.177796,0.290612,4,5,1,1.0,0.0,0.0,1.0,0.601583,0.601583,0.252849,-1.112244,-0.232027,-0.425281,-0.454935,0.991786,0.519075,1.587669,0.504891,0.323959,0.729432,21.858988,19.89922,23.709482,1346.735229,1385.601074,612.74408,8325.249023,8831.709961,9389.864258,2.50815,2.539151,3.42152,433,1
1,2491177,0.000535,0.999934,8.240436,1.095218,0.083408,4.210687,8.162792,3235.502197,0.09586,0.070505,0.067769,0.054672,0.140425,1,5,6,0.0,1.0,1.0,2.0,0.386806,0.27846,0.257121,-0.775315,-0.284896,-0.229474,-0.260944,1.142997,0.980602,1.558769,0.186751,0.147826,0.30329,5.095261,6.19945,3.79709,1509.661865,1778.51355,352.61795,45959.210938,41672.800781,6638.411621,4.108746,3.846762,3.627684,369,0
2,14699668,0.00085,0.999996,28.220333,0.677022,0.07794,8.915274,0.035283,13247.209961,0.000482,0.002346,0.002073,0.174334,0.094599,2,2,2,2.0,3.0,3.0,8.0,0.633866,0.323418,0.316222,-0.364786,-0.124454,-0.115879,-0.124454,0.771244,0.681871,0.886693,0.079222,0.29262,0.273358,3.897521,18.936714,30.21509,2613.442871,4030.838379,6680.432617,38441.621094,60690.246094,88079.054688,3.380462,3.403851,3.270758,414,0
3,225705,0.001145,1.0,16.196651,0.395021,0.009561,0.717381,2.758028,5904.848633,0.000262,0.030797,0.008239,0.188568,0.084972,1,4,1,0.0,0.0,0.0,0.0,0.706804,0.619967,0.579491,-1.177279,-0.482754,-0.396888,-0.297637,1.010778,1.088703,0.871755,0.443124,0.364227,0.270769,16.68066,22.49106,15.464107,1185.56897,2531.938477,2261.802734,14468.316406,27498.492188,41622.230469,3.193202,3.076167,3.60488,149,1
4,4323792,0.001371,0.999999,18.812433,0.529717,0.029653,1.518456,4.937074,4474.023438,0.048851,0.03097,0.095518,0.278755,0.083573,0,1,0,0.0,0.0,0.0,0.0,0.865771,0.825148,0.750697,-1.041557,-0.42137,-0.198817,-0.42137,1.265123,0.93261,1.042837,0.45367,0.494659,0.305528,13.372378,24.646584,7.905024,555.046448,2260.256836,1716.283325,8806.927734,28214.091797,44662.679688,3.456395,3.215882,3.951755,163,1


In [7]:
# Two alternative samples for the domain classifier; each has a binary `domain` column
# and is shuffled reproducibly with SEED.
domain_adaptation_random = pd.read_csv('../datasets/domain_adaptation_random.csv').sample(frac=1., random_state=SEED)

domain_adaptation_high_weight = pd.read_csv('../datasets/domain_adaptation_high_weight.csv').sample(frac=1., random_state=SEED)

print domain_adaptation_random.shape, domain_adaptation_high_weight.shape
domain_adaptation_random.head()

(135106, 47) (135106, 47)


Unnamed: 0,LifeTime,dira,FlightDistance,FlightDistanceError,IP,IPSig,VertexChi2,pt,DOCAone,DOCAtwo,DOCAthree,IP_p0p2,IP_p1p2,isolationa,isolationb,isolationc,isolationd,isolatione,isolationf,iso,CDF1,CDF2,CDF3,ISO_SumBDT,p0_IsoBDT,p1_IsoBDT,p2_IsoBDT,p0_track_Chi2Dof,p1_track_Chi2Dof,p2_track_Chi2Dof,p0_IP,p1_IP,p2_IP,p0_IPSig,p1_IPSig,p2_IPSig,p0_pt,p1_pt,p2_pt,p0_p,p1_p,p2_p,p0_eta,p1_eta,p2_eta,SPDhits,domain
21840,0.001173,0.999793,7.309732,0.225475,0.14476,8.930744,4.965856,4945.793457,0.019389,0.015251,0.037411,0.204333,0.249398,1,5,3,1.0,2.0,2.0,5.0,1.0,0.643805,0.270755,-0.423511,-0.131208,-0.115879,-0.176424,1.225634,1.368773,1.647525,0.339441,0.228714,0.454121,10.899674,11.228765,24.274082,885.767273,1630.131714,2696.752686,8884.09668,12979.206055,17545.503906,2.996217,2.763868,2.559938,572,0
123606,0.002917,0.999991,13.179792,0.124931,0.056593,4.327537,7.19514,4935.144531,0.039205,0.037555,0.006955,0.48423,0.368756,0,0,3,0.0,0.0,0.0,0.0,0.816136,0.465712,0.422229,-1.106902,-0.238089,-0.367395,-0.501418,1.560585,1.122543,1.224513,1.53828,1.037786,0.150551,78.596626,32.999645,10.717737,1042.821045,1768.838745,2451.449219,4810.159668,10328.353516,13843.15918,2.209985,2.450302,2.416325,224,1
45478,0.001228,1.0,52.646694,1.009517,0.008819,0.769089,8.150103,8629.933594,0.122356,0.008543,0.03806,0.286731,0.024846,0,2,0,0.0,0.0,0.0,0.0,0.934495,0.934495,0.719623,-1.020152,-0.269676,-0.300122,-0.450355,0.69103,1.401399,0.9645,0.575982,0.359241,0.33646,8.886332,30.627249,21.443039,402.388214,4380.829102,3906.171631,10589.078125,108756.5,133513.78125,3.962947,3.904615,4.22458,300,0
70136,0.000888,0.999911,3.75188,0.205613,0.049413,3.308183,1.282461,3560.137695,0.024591,0.024395,0.014363,0.055511,0.039771,5,6,11,1.0,1.0,0.0,2.0,0.418602,0.334923,0.327201,-0.453168,-0.152291,-0.124454,-0.176424,0.773729,1.650806,1.349887,0.622369,0.198982,0.136246,13.01162,10.056454,6.481815,387.275085,1425.018677,2228.083008,3266.975586,9333.081055,15420.416016,2.8221,2.566648,2.622437,287,1
96664,0.000629,0.999974,17.316277,1.145542,0.151398,9.436165,3.849268,5634.192871,0.092609,0.020451,0.080179,0.139207,0.153073,5,9,8,3.0,2.0,0.0,5.0,0.5709,0.39012,0.358518,-0.184568,-0.050768,-0.0669,-0.0669,1.153638,0.860187,1.109907,0.262164,0.324961,0.237144,5.828904,3.597851,15.289275,2203.238281,490.422485,3105.210693,85277.132812,20857.15625,76515.773438,4.348958,4.443194,3.897151,402,1


In [8]:
# Sample used for the KS metric in train_and_evaluate; its `signal` and `weight`
# columns are read there.
check_agreement = pd.read_csv('../datasets/check_agreement.csv.zip')

print check_agreement.shape
check_agreement.head()

(331147, 49)


Unnamed: 0,id,LifeTime,dira,FlightDistance,FlightDistanceError,IP,IPSig,VertexChi2,pt,DOCAone,DOCAtwo,DOCAthree,IP_p0p2,IP_p1p2,isolationa,isolationb,isolationc,isolationd,isolatione,isolationf,iso,CDF1,CDF2,CDF3,ISO_SumBDT,p0_IsoBDT,p1_IsoBDT,p2_IsoBDT,p0_track_Chi2Dof,p1_track_Chi2Dof,p2_track_Chi2Dof,p0_IP,p1_IP,p2_IP,p0_IPSig,p1_IPSig,p2_IPSig,p0_pt,p1_pt,p2_pt,p0_p,p1_p,p2_p,p0_eta,p1_eta,p2_eta,SPDhits,signal,weight
0,15347063,0.001451,0.999964,6.94503,0.229196,0.058117,2.961298,7.953543,2251.611816,0.082219,0.084005,0.066887,0.185107,0.214719,8,6,1,2.0,1.0,1.0,4.0,0.732076,0.492269,0.179091,-0.207475,-0.019306,-0.089797,-0.098372,0.606178,0.862549,1.487057,0.483199,0.474925,0.426797,24.701061,10.732132,8.853514,1438.064697,468.645721,834.562378,10392.814453,6380.673828,15195.594727,2.666142,3.302978,3.594246,512,0,-0.307813
1,14383299,0.000679,0.999818,9.468235,0.517488,0.189683,14.41306,7.141451,10594.470703,0.007983,0.044154,0.001321,0.039357,0.217507,5,6,17,1.0,1.0,1.0,3.0,0.802508,0.605835,0.584701,-0.659644,-0.27833,-0.18637,-0.194944,1.900118,1.073474,1.336784,0.712242,0.260311,0.123877,11.312134,16.435398,7.737038,316.791351,7547.703613,2861.309814,3174.356934,64480.023438,23134.953125,2.995265,2.834816,2.779366,552,0,-0.331421
2,7382797,0.003027,0.999847,13.280714,0.219291,0.231709,11.973175,4.77888,2502.196289,0.045085,0.106614,0.00585,0.335788,0.88508,2,2,1,0.0,0.0,1.0,1.0,0.682607,0.682607,0.295038,-0.399239,-0.115879,-0.131069,-0.152291,0.660675,1.683084,0.798658,0.381544,1.163556,1.290409,16.435801,20.686119,44.521961,1887.477905,317.579529,932.128235,15219.761719,3921.181641,10180.791016,2.776633,3.204923,3.081832,318,0,-0.382215
3,6751065,0.00081,0.999998,5.166821,0.167886,0.011298,0.891142,5.528002,5097.813965,0.055115,0.038642,0.003864,0.076522,0.068347,4,4,3,0.0,0.0,0.0,0.0,0.533615,0.533615,0.533615,-0.821041,-0.208248,-0.177802,-0.434991,0.770563,1.093031,0.938619,0.56465,0.164411,0.166646,24.878387,7.873435,9.630725,975.041687,1650.837524,2617.248291,4365.08252,13221.149414,24291.875,2.179345,2.769762,2.918251,290,0,1.465194
4,9439580,0.000706,0.999896,10.897236,0.284975,0.160511,16.36755,8.670339,20388.097656,0.015587,0.020872,0.014612,0.249906,0.139937,0,1,0,0.0,0.0,0.0,0.0,0.92641,0.92641,0.92641,-1.116815,-0.328938,-0.443564,-0.344313,1.080559,1.471946,1.123868,0.373736,0.230584,0.11243,28.557213,18.738485,7.389726,6035.000977,9657.492188,4763.682617,27463.011719,46903.394531,24241.628906,2.196114,2.262732,2.310401,45,0,-0.477084


In [9]:
# Sample used for the CvM metric in train_and_evaluate; its `mass` column is read there.
check_correlation = pd.read_csv('../datasets/check_correlation.csv.zip')

# Set these two columns to a single constant each: the mean over the training sample.
# NOTE(review): neither column is part of `features`, so check_correlation[features]
# does not use them - confirm whether they are still needed.
check_correlation['min_ANNmuon'] = label_prediction['min_ANNmuon'].mean()
check_correlation['production'] = label_prediction['production'].mean()

print check_correlation.shape
check_correlation.head()

(5514, 50)


Unnamed: 0,id,LifeTime,dira,FlightDistance,FlightDistanceError,IP,IPSig,VertexChi2,pt,DOCAone,DOCAtwo,DOCAthree,IP_p0p2,IP_p1p2,isolationa,isolationb,isolationc,isolationd,isolatione,isolationf,iso,CDF1,CDF2,CDF3,ISO_SumBDT,p0_IsoBDT,p1_IsoBDT,p2_IsoBDT,p0_track_Chi2Dof,p1_track_Chi2Dof,p2_track_Chi2Dof,p0_IP,p1_IP,p2_IP,p0_IPSig,p1_IPSig,p2_IPSig,p0_pt,p1_pt,p2_pt,p0_p,p1_p,p2_p,p0_eta,p1_eta,p2_eta,SPDhits,mass,min_ANNmuon,production
0,11120335,0.000703,0.999715,2.927074,0.214014,0.081302,4.259793,1.066585,3108.189941,0.010767,0.024147,0.003066,0.100619,0.087596,2,10,5,2.0,1.0,1.0,4.0,0.661286,0.59228,0.59228,-0.229346,-0.089797,-0.049752,-0.089797,1.015762,0.918588,2.555963,0.336053,0.13852,0.239775,6.975648,7.199831,9.638602,1142.124512,1294.450928,1073.97644,9274.671875,7963.914062,6712.897949,2.783731,2.50331,2.519349,280,1723.887939,0.488508,-36.657765
1,11495369,0.000601,0.99995,15.849142,0.842973,0.182213,13.882857,5.780046,6858.264648,0.007574,0.02001,0.064194,0.058354,0.247012,3,3,4,0.0,1.0,1.0,2.0,0.74436,0.441476,0.209664,-0.364786,-0.115879,-0.124454,-0.124454,1.413934,0.696266,1.531379,0.282444,0.298322,0.168395,10.620588,17.345636,7.344349,1174.860229,3981.284912,1859.680542,22844.791016,109955.101562,37051.800781,3.660059,4.011287,3.684429,386,1926.284058,0.488508,-36.657765
2,7098902,0.002009,0.999984,43.358494,1.323199,0.200158,12.870687,3.460782,3604.347412,0.052849,0.127547,0.073476,0.365799,0.292673,0,1,0,0.0,0.0,0.0,0.0,0.699672,0.678214,0.437321,-1.175178,-0.351104,-0.414748,-0.409326,1.307686,1.081815,0.674668,0.173122,0.480514,1.041596,4.409735,30.87108,7.409998,1046.86731,2564.43042,354.095032,45214.070312,77265.429688,12087.007812,4.458619,4.098382,4.223254,433,1830.873047,0.488508,-36.657765
3,8103692,0.001268,0.99955,6.910733,0.198652,0.16773,8.559438,7.676139,3240.960449,0.101368,0.002148,0.037449,0.060147,0.162985,4,9,4,1.0,1.0,1.0,3.0,0.257655,0.184313,0.182857,-0.495822,-0.194944,-0.176424,-0.124454,1.019122,2.24854,1.243463,0.390963,0.627191,0.136959,11.687304,31.343948,4.312659,1051.328979,1178.332031,1285.70166,11245.551758,6770.969238,17003.119141,3.060873,2.434033,3.273807,245,1909.119019,0.488508,-36.657765
4,10160864,0.001937,0.999996,36.135208,0.508036,0.09717,9.295684,0.42634,6448.445312,0.007005,0.00851,0.009841,0.481759,0.455559,0,4,0,0.0,0.0,0.0,0.0,0.58802,0.537217,0.222616,-0.766796,-0.263475,-0.263475,-0.239847,0.766173,0.715839,0.847636,0.612008,0.498065,0.641609,38.172073,38.144756,14.099512,3081.089111,3140.512207,385.343475,48171.457031,43973.835938,7368.522949,3.441606,3.331079,3.6433,489,1600.925049,0.488508,-36.657765


In [10]:
# Convert every frame to a (features, target) pair of numpy arrays.
# Xd_*/yd_*: inputs and `domain` labels for the domain classifier
# (1 = high-weight sample, 2 = random sample; selected by cross_validation's domain_sample).
Xd_1, yd_1 = domain_adaptation_high_weight[features].values, domain_adaptation_high_weight.domain.values
Xd_2, yd_2 = domain_adaptation_random[features].values, domain_adaptation_random.domain.values
# Xt/yt: inputs and `signal` labels for the label predictor.
Xt, yt = label_prediction[features].values, label_prediction.signal.values
# Public/private evaluation subsets; read as module-level names inside train_and_evaluate.
X_public, y_public = public_dataset[features].values, public_dataset.signal.values
X_private, y_private = private_dataset[features].values, private_dataset.signal.values

print Xd_1.shape, yd_1.shape
print Xd_2.shape, yd_2.shape
print Xt.shape, yt.shape
print X_public.shape, y_public.shape
print X_private.shape, y_private.shape

(135106, 46) (135106,)
(135106, 46) (135106,)
(135106, 46) (135106,)
(12302, 46) (12302,)
(28497, 46) (28497,)


In [14]:
# Rows per training batch; also the slice length build_model uses when `train` is fed True.
batch_size = 50
# Width of the network input (one unit per feature column).
features_size = len(features)
# Units in the shared feature-extractor layer.
features_layer_1 = 100
# Units in the hidden layer of the non-shallow domain classifier.
domain_layer_1 = 20

def build_model(shallow_domain_classifier):
    """
    Build the label predictor and the domain classifier in the default TF graph.

    Nothing is returned; callers fetch tensors and ops by name. Names created here:
    placeholders 'X', 'train', 'l', 'dropout', 'Y_ind', 'D_ind'; tensors 'feature',
    'p', 'pred_loss', 'p_acc', 'domain_loss', 'd_acc'; train ops 'pred_train_op'
    and 'domain_train_op'.

    Reads the module-level constants features_size, features_layer_1,
    domain_layer_1 and batch_size.

    :param shallow_domain_classifier - if True the domain classifier is a single
    linear layer on top of the features; otherwise one hidden ReLU layer of
    domain_layer_1 units is inserted first.
    """
    X = tf.placeholder(tf.float32, [None, features_size], name='X')  # Input data
    train = tf.placeholder(tf.bool, [], name='train')                # Switch for routing data to class predictor
    l = tf.placeholder(tf.float32, [], name='l')                     # Gradient reversal scaler
    # Passed as the second argument of tf.nn.dropout, i.e. the KEEP probability in
    # TF 1.x; callers feed 1.0 to switch dropout off.
    dropout_prob = tf.placeholder(tf.float32, [], name='dropout')

    # Feature extractor - single layer. The tensor named 'feature' is the ReLU
    # output before dropout.
    W0 = weight_variable([features_size, features_layer_1])
    b0 = bias_variable([features_layer_1])
    F = tf.nn.relu(tf.matmul(X, W0) + b0, name='feature')
    F = tf.nn.dropout(F, dropout_prob)

    # Label predictor - single layer
    Y_ind = tf.placeholder(tf.int32, [None], name='Y_ind')  # Class index
    Y = tf.one_hot(Y_ind, 2)

    # In training mode only the first batch_size rows feed the label predictor;
    # in evaluation mode every row is used.
    f = tf.cond(train, lambda: tf.slice(F, [0, 0], [batch_size, -1]), lambda: F)
    y = tf.cond(train, lambda: tf.slice(Y, [0, 0], [batch_size, -1]), lambda: Y)

    W1 = weight_variable([features_layer_1, 2])
    b1 = bias_variable([2])
    p_logit = tf.matmul(f, W1) + b1
    # NOTE(review): dropout is applied to the output logits themselves, which is
    # unusual; kept as in the original design - confirm it is intended.
    p_logit = tf.nn.dropout(p_logit, dropout_prob)
    p = tf.nn.softmax(p_logit, name='p')
    p_loss = tf.nn.softmax_cross_entropy_with_logits(logits=p_logit, labels=y)

    # Optimization (loss is summed, not averaged, over the batch)
    pred_loss = tf.reduce_sum(p_loss, name='pred_loss')
    tf.train.AdamOptimizer(0.01).minimize(pred_loss, name='pred_train_op')

    # Evaluation
    tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(p, 1)), tf.float32), name='p_acc')

    # Domain predictor
    D_ind = tf.placeholder(tf.int32, [None], name='D_ind')  # Domain index
    D = tf.one_hot(D_ind, 2)

    # flip_gradient comes from the project's flip_gradient module.
    # Presumably identity in the forward pass with the gradient scaled by -l on the
    # way back - TODO confirm against that module.
    f_ = flip_gradient(F, l)

    if shallow_domain_classifier:
        W2 = weight_variable([features_layer_1, 2])
        b2 = bias_variable([2])
        d_logit = tf.matmul(f_, W2) + b2
        d_logit = tf.nn.dropout(d_logit, dropout_prob)
    else:
        W2 = weight_variable([features_layer_1, domain_layer_1])
        b2 = bias_variable([domain_layer_1])
        h2 = tf.nn.relu(tf.matmul(f_, W2) + b2)
        h2 = tf.nn.dropout(h2, dropout_prob)
        W3 = weight_variable([domain_layer_1, 2])
        b3 = bias_variable([2])
        d_logit = tf.matmul(h2, W3) + b3

    # Shared by both classifier variants.
    d = tf.nn.softmax(d_logit)
    d_loss = tf.nn.softmax_cross_entropy_with_logits(logits=d_logit, labels=D)

    # Optimization
    domain_loss = tf.reduce_sum(d_loss, name='domain_loss')
    tf.train.AdamOptimizer(0.01).minimize(domain_loss, name='domain_train_op')

    # Evaluation
    tf.reduce_mean(tf.cast(tf.equal(tf.argmax(D, 1), tf.argmax(d, 1)), tf.float32), name='d_acc')

In [17]:
# Initial session handle handed to cross_validation. train_and_evaluate closes the
# session it receives and opens its own after rebuilding the graph.
sess = tf.InteractiveSession()

def train_and_evaluate(sess, Xt, yt, X_holdout, y_holdout, 
                       Xd, yd, domain_adaptation=None, grad_scale=None, num_batches=10000, verbose=True, dropout=1.):
    '''
    :param sess - tensorflow session
    :param domain - high weights if domain == 1, random else
    :param Xt - train features for label predictor
    :param yt - train target for label predictor
    :param domain_adaptation - takes None, False or True value. 
    No domain adaptation if None, shallow domain adaptation if False and deep domain adaptation if True
    '''
    sess.close()
    tf.reset_default_graph()
    shallow = False
    if domain_adaptation != None:
        shallow = domain_adaptation
    build_model(shallow)
    sess = tf.InteractiveSession()

    # Create batch builders
    D_batches = batch_generator([Xd, yd], batch_size)
    T_batches = batch_generator([Xt, yt], batch_size)
    
    # Get output tensors and train op
    d_acc = sess.graph.get_tensor_by_name('d_acc:0')
    p_probs = sess.graph.get_tensor_by_name('p:0')
    p_acc = sess.graph.get_tensor_by_name('p_acc:0')
    domain_loss = sess.graph.get_tensor_by_name('domain_loss:0')
    domain_op = sess.graph.get_operation_by_name('domain_train_op')
    target_loss = sess.graph.get_tensor_by_name('pred_loss:0')
    target_op = sess.graph.get_operation_by_name('pred_train_op')
    
    sess.run(tf.global_variables_initializer())
    for i in range(num_batches):
        # If no grad_scale, use a schedule
        if grad_scale is None:
            p = float(i) / num_batches
            lp = 2. / (1. + np.exp(-10. * p)) - 1
        else:
            lp = grad_scale

        X0, y0 = D_batches.next()
        X1, y1 = T_batches.next()
        _, target_loss_value, pa, da = sess.run([target_op, target_loss, p_acc, d_acc],
                                   feed_dict={'X:0': X1, 'Y_ind:0': y1, 'D_ind:0': y0,
                                              'train:0': True, 'l:0': lp, 'dropout:0': dropout})
        if domain_adaptation:
            _, domain_loss_value, pa, da = sess.run([domain_op, domain_loss, p_acc, d_acc],
                                       feed_dict={'X:0': X0, 'Y_ind:0': y1, 'D_ind:0': y0,
                                                  'train:0': True, 'l:0': lp, 'dropout:0': dropout})
        

        if verbose and i % 200 == 0:
            print 'loss: %f, domain accuracy: %f, class accuracy: %f' % (target_loss_value, da, pa)
            if domain_adaptation:
                print 'loss: %f, domain accuracy: %f, class accuracy: %f' % (domain_loss_value, da, pa)
            

            
    # Get final accuracies on whole dataset
#     dad, pad = sess.run([d_acc, p_acc], feed_dict={'X:0': Xd, 'Y_ind:0': yd, 
#                             'D_ind:0': np.zeros(Xd.shape[0], dtype=np.int32), 'train:0': False, 'l:0': 1.0, 'dropout:0': 1.0})
    _, _, pt_holdout = sess.run([d_acc, p_acc, p_probs], feed_dict={'X:0': X_holdout, 'Y_ind:0': np.zeros_like(y_holdout, dtype=np.int32),
                            'D_ind:0': np.ones(X_holdout.shape[0], dtype=np.int32), 'train:0': False, 'l:0': 1.0, 'dropout:0': 1.0})
    _, _, pt_public = sess.run([d_acc, p_acc, p_probs], feed_dict={'X:0': X_public, 'Y_ind:0': np.zeros_like(y_public, dtype=np.int32),
                            'D_ind:0': np.ones(X_public.shape[0], dtype=np.int32), 'train:0': False, 'l:0': 1.0, 'dropout:0': 1.0})
    
    _, _, pt_private = sess.run([d_acc, p_acc, p_probs], feed_dict={'X:0': X_private, 'Y_ind:0': np.zeros_like(y_private, dtype=np.int32),
                            'D_ind:0': np.ones(X_private.shape[0], dtype=np.int32), 'train:0': False, 'l:0': 1.0, 'dropout:0': 1.0})
    
    agreement_probs = sess.run([p_probs], feed_dict={'X:0': check_agreement[features].values,
                                                     'Y_ind:0': np.zeros_like(check_agreement.signal, dtype=np.int32),
                                                     'D_ind:0': np.ones(check_agreement.shape[0], dtype=np.int32), 
                                                     'train:0': False, 'l:0': 1.0, 'dropout:0': 1.0})[0][:, 1]
    
    correlation_probs = sess.run([p_probs], feed_dict={'X:0': check_correlation[features].values,
                                                     'Y_ind:0': np.zeros(check_correlation.shape[0], dtype=np.int32),
                                                     'D_ind:0': np.ones(check_correlation.shape[0], dtype=np.int32), 
                                                     'train:0': False, 'l:0': 1.0, 'dropout:0': 1.0})[0][:, 1]

    ks = compute_ks(
        agreement_probs[check_agreement['signal'].values == 0],
        agreement_probs[check_agreement['signal'].values == 1],
        check_agreement[check_agreement['signal'] == 0]['weight'].values,
        check_agreement[check_agreement['signal'] == 1]['weight'].values)
    
    cvm = compute_cvm(correlation_probs, check_correlation.mass)
    
    auc_trunc_holdout = roc_auc_truncated(y_holdout, pt_holdout[:, 1])
    auc_trunc_public = roc_auc_truncated(y_public, pt_public[:, 1])
    auc_trunc_private = roc_auc_truncated(y_private, pt_private[:, 1])
    if verbose:
        print 'KS: ', ks
        print 'CvM: ', cvm
        print 'AUC_truncated (holdout): ', auc_trunc_holdout
        print 'AUC_truncated (public): ', auc_trunc_public
        print 'AUC_truncated (private): ', auc_trunc_private
        
    return auc_trunc_holdout, auc_trunc_public, auc_trunc_private, ks, cvm

def extract_and_plot_pca_feats(sess, feat_tensor_name='feature', Xd=None, yd=None):
    '''
    Scatter-plot a 2-component PCA of the feature-layer activations.

    Activations are computed for the domain sample and for the module-level
    label-prediction sample (Xt, yt), projected together, and drawn as two
    scatter layers coloured by yd and yt respectively.

    :param sess - open tensorflow session whose graph contains the feature tensor
    :param feat_tensor_name - name of the tensor to embed (without the ':0' suffix)
    :param Xd - domain sample features; defaults to the module-level Xd_1
    :param yd - colour values for the domain sample; defaults to the module-level yd_1
    '''
    if Xd is None:
        Xd, yd = Xd_1, yd_1
    elif yd is None:
        raise ValueError('yd must be given together with Xd')

    F = sess.graph.get_tensor_by_name(feat_tensor_name + ':0')
    emb_d = sess.run(F, feed_dict={'X:0': Xd})
    emb_t = sess.run(F, feed_dict={'X:0': Xt})
    emb_all = np.vstack([emb_d, emb_t])

    pca = PCA(n_components=2)
    pca_emb = pca.fit_transform(emb_all)

    # Split at the real boundary between the two stacked samples instead of assuming
    # two equal halves; this is also an integer under any division semantics.
    num = emb_d.shape[0]
    plt.scatter(pca_emb[:num,0], pca_emb[:num,1], c=yd, alpha=0.4)
    plt.scatter(pca_emb[num:,0], pca_emb[num:,1], c=yt, cmap='cool', alpha=0.4)
    
def cross_validation(sess, domain_sample=1, cv=5, train_size=0.7, verbose=True, params={}):
    aucs_holdout, aucs_public, aucs_private, kss, cvms = [], [], [], [], []
    if domain_sample == 1:
        Xd, yd = Xd_1, yd_1
    elif domain_sample == 2:
        Xd, yd = Xd_2, yd_2
    for i in xrange(cv):
        Xt_train, Xt_test, yt_train, yt_test, Xd_train, _, yd_train, _ = train_test_split(Xt, yt, Xd, yd, train_size=train_size)
        auc_trunc_holdout, auc_trunc_public, auc_trunc_private, ks, cvm = train_and_evaluate(sess, Xd_train, yd_train, Xt_train, yt_train, Xt_test, yt_test, **params)
        
        if verbose:
            print 'KS: ', ks
            print 'CvM: ', cvm
            print 'AUC_truncated (holdout): ', auc_trunc_holdout
            print 'AUC_truncated (public): ', auc_trunc_public
            print 'AUC_truncated (private): ', auc_trunc_private
            print '*****************'
        aucs_holdout.append(auc_trunc_holdout); aucs_public.append(auc_trunc_public); 
        aucs_private.append(auc_trunc_private)
        kss.append(ks); cvms.append(cvm)
    print 'Final AUC (holdout). Mean: %.4f. Std: %.4f' % (np.mean(aucs_holdout), np.std(aucs_holdout))
    print 'Final AUC (public). Mean: %.4f. Std: %.4f' % (np.mean(aucs_public), np.std(aucs_public))
    print 'Final AUC (private). Mean: %.4f. Std: %.4f' % (np.mean(aucs_private), np.std(aucs_private))
    print 'Final KS. Mean: %.4f. Std: %.4f' % (np.mean(kss), np.std(kss))
    print 'Final CvM. Mean: %.4f. Std: %.4f' % (np.mean(cvms), np.std(cvms))
    return aucs_holdout, aucs_public, aucs_private, kss, cvms

## Only the events with the highest weights were taken for domain adaptation

#### Classification with no domain adaptation

In [18]:
# 10 random splits, 3000 training iterations each, dropout placeholder fed 1.
# Ending the cell with `pass` keeps the returned metric lists from being echoed.
cross_validation(sess, cv=10, params={'domain_adaptation': None, 'verbose': False, 'num_batches': 3000, 'dropout': 1})
pass

KS:  0.0520976774647
CvM:  0.0105130965184
AUC_truncated (holdout):  0.527081529672
AUC_truncated (public):  0.54659363967
AUC_truncated (private):  0.544917140296
*****************
KS:  0.0528705896061
CvM:  0.0466860083203
AUC_truncated (holdout):  0.491017435361
AUC_truncated (public):  0.504722916148
AUC_truncated (private):  0.501964200776
*****************
KS:  0.0545851853528
CvM:  0.0357793245001
AUC_truncated (holdout):  0.539745798187
AUC_truncated (public):  0.564347792875
AUC_truncated (private):  0.556223757708
*****************
KS:  0.0541839215038
CvM:  0.0130105402296
AUC_truncated (holdout):  0.498898439451
AUC_truncated (public):  0.52019219244
AUC_truncated (private):  0.505170289165
*****************
KS:  0.0349773951386
CvM:  0.0459785359532
AUC_truncated (holdout):  0.484956529882
AUC_truncated (public):  0.486613535967
AUC_truncated (private):  0.487874306205
*****************
KS:  0.0600686963958
CvM:  0.03524147937
AUC_truncated (holdout):  0.538622326355
AUC_t

#### Classification with shallow domain adaptation

In [19]:
# 10 random splits, 3000 training iterations each, domain_adaptation=False.
# Ending the cell with `pass` keeps the returned metric lists from being echoed.
cross_validation(sess, cv=10, params={'domain_adaptation': False, 'verbose': False, 'num_batches': 3000})
pass

KS:  0.0501749425205
CvM:  0.013520568384
AUC_truncated (holdout):  0.521839421191
AUC_truncated (public):  0.543438644245
AUC_truncated (private):  0.539012187478
*****************
KS:  0.0299758795337
CvM:  0.0536448919239
AUC_truncated (holdout):  0.475619740067
AUC_truncated (public):  0.485449979133
AUC_truncated (private):  0.483551828417
*****************
KS:  0.0436636315832
CvM:  0.0748957354954
AUC_truncated (holdout):  0.483600954572
AUC_truncated (public):  0.496878716481
AUC_truncated (private):  0.494855825442
*****************
KS:  0.0539494869027
CvM:  0.0244360910859
AUC_truncated (holdout):  0.481028743957
AUC_truncated (public):  0.496779085559
AUC_truncated (private):  0.495111725588
*****************
KS:  0.0694519741679
CvM:  0.00191332355006
AUC_truncated (holdout):  0.539662124207
AUC_truncated (public):  0.572430288676
AUC_truncated (private):  0.572480528292
*****************
KS:  0.0608746634127
CvM:  0.0307576695538
AUC_truncated (holdout):  0.478272362428
A

#### Classification with deep domain adaptation

In [20]:
# 10 random splits, 5000 training iterations each, domain_adaptation=True.
# Ending the cell with `pass` keeps the returned metric lists from being echoed.
cross_validation(sess, cv=10, params={'domain_adaptation': True, 'verbose': False, 'num_batches': 5000})
pass

KS:  0.00314902405361
CvM:  0.150304778661
AUC_truncated (holdout):  0.506267044985
AUC_truncated (public):  0.507800063674
AUC_truncated (private):  0.508252625835
*****************
KS:  0.0339948522155
CvM:  0.153626871091
AUC_truncated (holdout):  0.499509225101
AUC_truncated (public):  0.499440602354
AUC_truncated (private):  0.499596830392
*****************
KS:  0.0425781184422
CvM:  0.153112707494
AUC_truncated (holdout):  0.49898711158
AUC_truncated (public):  0.498317376087
AUC_truncated (private):  0.499119773101
*****************
KS:  0.0144479029399
CvM:  0.121444502502
AUC_truncated (holdout):  0.503851778161
AUC_truncated (public):  0.504809902282
AUC_truncated (private):  0.504227806173
*****************
KS:  0.015726913511
CvM:  0.125014941195
AUC_truncated (holdout):  0.539271422658
AUC_truncated (public):  0.543935052531
AUC_truncated (private):  0.54781066703
*****************
KS:  0.02987572305
CvM:  0.0571942706172
AUC_truncated (holdout):  0.515543039842
AUC_trunca

#### Deep domain adaptation with more epochs

In [None]:
# Cross-validation run: cv=10, domain_adaptation=True, 15000 batches,
# dropout parameter set to 0.98.
# The trailing ';' keeps the notebook from echoing any return value.
cross_validation(
    sess,
    cv=10,
    params={
        'domain_adaptation': True,
        'verbose': False,
        'num_batches': 15000,
        'dropout': 0.98,
    },
);

KS:  0.012050254142
CvM:  0.123497030618
AUC_truncated (holdout):  0.538390187175
AUC_truncated (public):  0.545049347342
AUC_truncated (private):  0.544332287546
*****************
KS:  0.0194872508689
CvM:  0.103929183505
AUC_truncated (holdout):  0.536674597535
AUC_truncated (public):  0.537709253695
AUC_truncated (private):  0.540764314815
*****************
KS:  0.0224477574215
CvM:  0.059428998828
AUC_truncated (holdout):  0.55210173411
AUC_truncated (public):  0.560269974656
AUC_truncated (private):  0.551392914617
*****************
KS:  0.0266528563101
CvM:  0.128218395614
AUC_truncated (holdout):  0.514245206355
AUC_truncated (public):  0.512258650403
AUC_truncated (private):  0.513701257604
*****************
KS:  0.0392191424107
CvM:  0.147179025101
AUC_truncated (holdout):  0.501169781497
AUC_truncated (public):  0.5016818246
AUC_truncated (private):  0.500749600694
*****************
KS:  0.0471161556806
CvM:  0.152399497963
AUC_truncated (holdout):  0.502212707999
AUC_truncat

#### More epochs without domain adaptation

In [None]:
# Cross-validation run: cv=10, domain_adaptation=None, 15000 batches,
# dropout parameter set to 0.98.
# NOTE(review): the recorded output for this cell shows AUC_truncated of 0.5
# on every visible fold -- confirm the model actually trains with this setup.
# The trailing ';' keeps the notebook from echoing any return value.
cross_validation(
    sess,
    cv=10,
    params={
        'domain_adaptation': None,
        'verbose': False,
        'num_batches': 15000,
        'dropout': 0.98,
    },
);

KS:  0.000183434230583
CvM:  0.155842626513
AUC_truncated (holdout):  0.5
AUC_truncated (public):  0.5
AUC_truncated (private):  0.499931796481
*****************
KS:  7.22390255723e-05
CvM:  0.155842626513
AUC_truncated (holdout):  0.5
AUC_truncated (public):  0.5
AUC_truncated (private):  0.5
*****************
KS:  0.000137943203479
CvM:  0.155842626513
AUC_truncated (holdout):  0.5
AUC_truncated (public):  0.5
AUC_truncated (private):  0.5
*****************


## Random subsampling used for domain adaptation

#### Classification with no domain adaptation

In [59]:
# Cross-validation run: cv=10, domain_sample=2, domain_adaptation=None,
# 3000 batches, dropout parameter set to 1.
# The trailing ';' keeps the notebook from echoing any return value.
cross_validation(
    sess,
    cv=10,
    domain_sample=2,
    params={
        'domain_adaptation': None,
        'verbose': False,
        'num_batches': 3000,
        'dropout': 1,
    },
);

KS:  0.0525671994172
CvM:  0.0287879355449
AUC_truncated (holdout):  0.502442730603
AUC_truncated (public):  0.515701066711
AUC_truncated (private):  0.513469612468
*****************
KS:  0.0442551754737
CvM:  0.0299983781259
AUC_truncated (holdout):  0.4981890783
AUC_truncated (public):  0.514417918713
AUC_truncated (private):  0.509212307934
*****************
KS:  0.0469887434336
CvM:  0.074109288548
AUC_truncated (holdout):  0.483364623174
AUC_truncated (public):  0.493194820686
AUC_truncated (private):  0.486536310792
*****************
KS:  0.0535786257341
CvM:  0.0027729678667
AUC_truncated (holdout):  0.477784882795
AUC_truncated (public):  0.497579156426
AUC_truncated (private):  0.49836908262
*****************
KS:  0.0431042417538
CvM:  0.0786181228742
AUC_truncated (holdout):  0.477010899545
AUC_truncated (public):  0.486075493886
AUC_truncated (private):  0.485507599617
*****************
KS:  0.0512607443724
CvM:  0.069065554199
AUC_truncated (holdout):  0.477257969916
AUC_tr

#### Classification with shallow domain adaptation

In [60]:
# Cross-validation run: cv=10, domain_sample=2, domain_adaptation=False,
# 3000 batches.
# The trailing ';' keeps the notebook from echoing any return value.
cross_validation(
    sess,
    cv=10,
    domain_sample=2,
    params={
        'domain_adaptation': False,
        'verbose': False,
        'num_batches': 3000,
    },
);

KS:  0.0397040668144
CvM:  0.0351149133105
AUC_truncated (holdout):  0.478593775751
AUC_truncated (public):  0.495359430494
AUC_truncated (private):  0.486986100854
*****************
KS:  0.0424425205199
CvM:  0.06075129596
AUC_truncated (holdout):  0.47046811799
AUC_truncated (public):  0.482914957135
AUC_truncated (private):  0.481025230559
*****************
KS:  0.0393202786083
CvM:  0.0884793951588
AUC_truncated (holdout):  0.474549361795
AUC_truncated (public):  0.485448536894
AUC_truncated (private):  0.48121307219
*****************
KS:  0.034478067995
CvM:  0.0752129389579
AUC_truncated (holdout):  0.475955715237
AUC_truncated (public):  0.484691552486
AUC_truncated (private):  0.483858077235
*****************
KS:  0.0428402847664
CvM:  0.0156672055789
AUC_truncated (holdout):  0.505053306754
AUC_truncated (public):  0.515860823367
AUC_truncated (private):  0.513348780811
*****************
KS:  0.0439201826097
CvM:  0.0338681805559
AUC_truncated (holdout):  0.481023697327
AUC_tr

#### Classification with deep domain adaptation

In [61]:
# Cross-validation run: cv=10, domain_sample=2, domain_adaptation=True,
# 5000 batches.
# The trailing ';' keeps the notebook from echoing any return value.
cross_validation(
    sess,
    cv=10,
    domain_sample=2,
    params={
        'domain_adaptation': True,
        'verbose': False,
        'num_batches': 5000,
    },
);

KS:  0.0396732666466
CvM:  0.153104861732
AUC_truncated (holdout):  0.499401142638
AUC_truncated (public):  0.499012691739
AUC_truncated (private):  0.499533595682
*****************
KS:  0.000212636607928
CvM:  0.155604753566
AUC_truncated (holdout):  0.500220634877
AUC_truncated (public):  0.500159184973
AUC_truncated (private):  0.5
*****************
KS:  0.0234580308355
CvM:  0.155842626513
AUC_truncated (holdout):  0.5
AUC_truncated (public):  0.5
AUC_truncated (private):  0.5
*****************
KS:  0.00371014105753
CvM:  0.137609037249
AUC_truncated (holdout):  0.495525825059
AUC_truncated (public):  0.499101554522
AUC_truncated (private):  0.496060881028
*****************
KS:  0.000919621251974
CvM:  0.154895815042
AUC_truncated (holdout):  0.501489819566
AUC_truncated (public):  0.501591849729
AUC_truncated (private):  0.502318919656
*****************
KS:  0.0298288542462
CvM:  0.115255382037
AUC_truncated (holdout):  0.551711237356
AUC_truncated (public):  0.565720604708
AUC_tr

#### Deep domain adaptation with more epochs

In [62]:
# Cross-validation run: cv=10, domain_sample=2, domain_adaptation=True,
# 15000 batches, dropout parameter set to 0.98.
# The trailing ';' keeps the notebook from echoing any return value.
cross_validation(
    sess,
    cv=10,
    domain_sample=2,
    params={
        'domain_adaptation': True,
        'verbose': False,
        'num_batches': 15000,
        'dropout': 0.98,
    },
);

KS:  0.00106521237616
CvM:  0.0640017568228
AUC_truncated (holdout):  0.554734324971
AUC_truncated (public):  0.557425879073
AUC_truncated (private):  0.556609352811
*****************
KS:  0.0196739103874
CvM:  0.117992688257
AUC_truncated (holdout):  0.520678609562
AUC_truncated (public):  0.52138804153
AUC_truncated (private):  0.520305727318
*****************
KS:  0.0426694580211
CvM:  0.0405491284844
AUC_truncated (holdout):  0.603056415129
AUC_truncated (public):  0.621093962019
AUC_truncated (private):  0.615757602733
*****************
KS:  0.0416170107821
CvM:  0.0469154555113
AUC_truncated (holdout):  0.606413670548
AUC_truncated (public):  0.62402373877
AUC_truncated (private):  0.621800028854
*****************
KS:  0.00319203509915
CvM:  0.0846490101902
AUC_truncated (holdout):  0.530063505381
AUC_truncated (public):  0.532870871534
AUC_truncated (private):  0.531138138543
*****************
KS:  0.0270031154518
CvM:  0.136648318993
AUC_truncated (holdout):  0.502309753852
AUC