Kfold logo

In [2]:
import xgboost as xgb
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

auc_scores = []
f1_scores = []

params = {
    'max_depth': 6,
    'eta': 0.3,
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    "gamma": 0,
    "subsample": 1,
    "colsample_bytree": 1,
    "colsample_bylevel": 1,
    "scale_pos_weight": 1,
    "reg_alpha": 0,
    "reg_lambda": 1,
}

data_folder = './data_2/'

for event in range(5):

    # load data
    dtrain = xgb.DMatrix(data_folder + '{}_train.data'.format(event))
    dtest = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    deval = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    
    # train
    evallist = [(dtrain, 'train'), (deval, 'eval')]
    bst = xgb.train(params, dtrain, num_boost_round=1000, evals=[(dtest, 'eval')], early_stopping_rounds=200)
    # predict probabilities
    y_pred_proba = bst.predict(dtest)
    # Convert predicted probabilities to class labels
    y_pred_labels = (y_pred_proba > 0.5).astype(int)
    
    y_real = dtest.get_label()
    
    # Calculate AUC
    auc = roc_auc_score(y_real, y_pred_proba)
    auc_scores.append(auc)
    
    # Calculate F1 score
    f1 = f1_score(y_real, y_pred_labels, average='macro')
    f1_scores.append(f1)
    
# Average AUC and F1 scores across all groups
avg_auc = np.mean(auc_scores)
avg_f1 = np.mean(f1_scores)

print(f"Avg AUC: {avg_auc}")
print(f"Avg F1 Score: {avg_f1}")

[0]	eval-auc:0.40978
[1]	eval-auc:0.45616
[2]	eval-auc:0.48337
[3]	eval-auc:0.53840
[4]	eval-auc:0.54359
[5]	eval-auc:0.54164
[6]	eval-auc:0.54060
[7]	eval-auc:0.52138
[8]	eval-auc:0.52207
[9]	eval-auc:0.52701
[10]	eval-auc:0.53184
[11]	eval-auc:0.51718
[12]	eval-auc:0.51592
[13]	eval-auc:0.50642
[14]	eval-auc:0.51054
[15]	eval-auc:0.49984
[16]	eval-auc:0.51576
[17]	eval-auc:0.51180
[18]	eval-auc:0.51641
[19]	eval-auc:0.52756
[20]	eval-auc:0.53733
[21]	eval-auc:0.53524
[22]	eval-auc:0.53568
[23]	eval-auc:0.51735
[24]	eval-auc:0.53530
[25]	eval-auc:0.53942
[26]	eval-auc:0.53777
[27]	eval-auc:0.53266
[28]	eval-auc:0.55704
[29]	eval-auc:0.55418
[30]	eval-auc:0.55336
[31]	eval-auc:0.55138
[32]	eval-auc:0.55424
[33]	eval-auc:0.55215
[34]	eval-auc:0.55281
[35]	eval-auc:0.54189
[36]	eval-auc:0.53129
[37]	eval-auc:0.52525
[38]	eval-auc:0.51570
[39]	eval-auc:0.51537
[40]	eval-auc:0.51718
[41]	eval-auc:0.51449
[42]	eval-auc:0.50950
[43]	eval-auc:0.51230
[44]	eval-auc:0.49715
[45]	eval-auc:0.5024

[137]	eval-auc:0.48036
[138]	eval-auc:0.48206
[139]	eval-auc:0.48285
[140]	eval-auc:0.48298
[141]	eval-auc:0.48276
[142]	eval-auc:0.48276
[143]	eval-auc:0.48171
[144]	eval-auc:0.48368
[145]	eval-auc:0.48324
[146]	eval-auc:0.48281
[147]	eval-auc:0.48211
[148]	eval-auc:0.48254
[149]	eval-auc:0.48241
[150]	eval-auc:0.48302
[151]	eval-auc:0.48211
[152]	eval-auc:0.48211
[153]	eval-auc:0.48311
[154]	eval-auc:0.48425
[155]	eval-auc:0.48285
[156]	eval-auc:0.48224
[157]	eval-auc:0.48193
[158]	eval-auc:0.48163
[159]	eval-auc:0.48019
[160]	eval-auc:0.48014
[161]	eval-auc:0.47923
[162]	eval-auc:0.47958
[163]	eval-auc:0.47866
[164]	eval-auc:0.48071
[165]	eval-auc:0.48049
[166]	eval-auc:0.48054
[167]	eval-auc:0.47979
[168]	eval-auc:0.47966
[169]	eval-auc:0.48027
[170]	eval-auc:0.48123
[171]	eval-auc:0.48267
[172]	eval-auc:0.48228
[173]	eval-auc:0.48123
[174]	eval-auc:0.48102
[175]	eval-auc:0.48219
[176]	eval-auc:0.48106
[177]	eval-auc:0.48176
[178]	eval-auc:0.48272
[179]	eval-auc:0.48062
[180]	eval-

[16]	eval-auc:0.52557
[17]	eval-auc:0.51938
[18]	eval-auc:0.51281
[19]	eval-auc:0.50533
[20]	eval-auc:0.51227
[21]	eval-auc:0.51206
[22]	eval-auc:0.51550
[23]	eval-auc:0.51297
[24]	eval-auc:0.51701
[25]	eval-auc:0.51276
[26]	eval-auc:0.52283
[27]	eval-auc:0.51852
[28]	eval-auc:0.52132
[29]	eval-auc:0.52245
[30]	eval-auc:0.52056
[31]	eval-auc:0.51852
[32]	eval-auc:0.51884
[33]	eval-auc:0.51599
[34]	eval-auc:0.52099
[35]	eval-auc:0.52213
[36]	eval-auc:0.52853
[37]	eval-auc:0.52697
[38]	eval-auc:0.52579
[39]	eval-auc:0.52369
[40]	eval-auc:0.52708
[41]	eval-auc:0.52385
[42]	eval-auc:0.52089
[43]	eval-auc:0.52471
[44]	eval-auc:0.52536
[45]	eval-auc:0.52702
[46]	eval-auc:0.52239
[47]	eval-auc:0.52202
[48]	eval-auc:0.52051
[49]	eval-auc:0.52223
[50]	eval-auc:0.51728
[51]	eval-auc:0.52116
[52]	eval-auc:0.52417
[53]	eval-auc:0.52358
[54]	eval-auc:0.52541
[55]	eval-auc:0.52347
[56]	eval-auc:0.52261
[57]	eval-auc:0.52218
[58]	eval-auc:0.52261
[59]	eval-auc:0.52972
[60]	eval-auc:0.53122
[61]	eval-

[169]	eval-auc:0.50766
[170]	eval-auc:0.50596
[171]	eval-auc:0.50651
[172]	eval-auc:0.50614
[173]	eval-auc:0.50681
[174]	eval-auc:0.50748
[175]	eval-auc:0.50718
[176]	eval-auc:0.50693
[177]	eval-auc:0.50766
[178]	eval-auc:0.50712
[179]	eval-auc:0.50730
[180]	eval-auc:0.50870
[181]	eval-auc:0.51113
[182]	eval-auc:0.51095
[183]	eval-auc:0.50949
[184]	eval-auc:0.51113
[185]	eval-auc:0.51168
[186]	eval-auc:0.51064
[187]	eval-auc:0.51064
[188]	eval-auc:0.51174
[189]	eval-auc:0.51277
[190]	eval-auc:0.51350
[191]	eval-auc:0.51442
[192]	eval-auc:0.51563
[193]	eval-auc:0.51582
[194]	eval-auc:0.51849
[195]	eval-auc:0.51618
[196]	eval-auc:0.51545
[197]	eval-auc:0.51685
[198]	eval-auc:0.51721
[199]	eval-auc:0.51618
[200]	eval-auc:0.51569
[201]	eval-auc:0.51484
Avg AUC: 0.5264193855959898
Avg F1 Score: 0.5013367112795095


Baseline

In [1]:
import xgboost as xgb
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

auc_scores = []
f1_scores = []

params = {
    'max_depth': 6,
    'eta': 0.3,
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    "gamma": 0,
    "subsample": 1,
    "colsample_bytree": 1,
    "colsample_bylevel": 1,
    "scale_pos_weight": 1,
    "reg_alpha": 0,
    "reg_lambda": 1,
}

data_folder = './data_baseline_2/'

for event in range(5):

    # load data
    dtrain = xgb.DMatrix(data_folder + '{}_train.data'.format(event))
    dtest = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    deval = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    
    # train
    evallist = [(dtrain, 'train'), (deval, 'eval')]
    bst = xgb.train(params, dtrain, num_boost_round=1000, evals=[(dtest, 'eval')], early_stopping_rounds=200)
    # predict probabilities
    y_pred_proba = bst.predict(dtest)
    # Convert predicted probabilities to class labels
    y_pred_labels = (y_pred_proba > 0.5).astype(int)
    
    y_real = dtest.get_label()
    
    # Calculate AUC
    auc = roc_auc_score(y_real, y_pred_proba)
    auc_scores.append(auc)
    
    # Calculate F1 score
    f1 = f1_score(y_real, y_pred_labels, average='macro')
    f1_scores.append(f1)
    
# Average AUC and F1 scores across all groups
baseline_avg_auc = np.mean(auc_scores)
baseline_avg_f1 = np.mean(f1_scores)

print(f"Avg AUC: {baseline_avg_auc}")
print(f"Avg F1 Score: {baseline_avg_f1}")

[0]	eval-auc:0.45434
[1]	eval-auc:0.46640
[2]	eval-auc:0.47823
[3]	eval-auc:0.48381
[4]	eval-auc:0.48001
[5]	eval-auc:0.54672
[6]	eval-auc:0.53952
[7]	eval-auc:0.53823
[8]	eval-auc:0.53849
[9]	eval-auc:0.51434
[10]	eval-auc:0.52943
[11]	eval-auc:0.51278
[12]	eval-auc:0.49996
[13]	eval-auc:0.51468
[14]	eval-auc:0.52962
[15]	eval-auc:0.53493
[16]	eval-auc:0.51855
[17]	eval-auc:0.53493
[18]	eval-auc:0.51858
[19]	eval-auc:0.51843
[20]	eval-auc:0.51312
[21]	eval-auc:0.52154
[22]	eval-auc:0.51980
[23]	eval-auc:0.52033
[24]	eval-auc:0.52761
[25]	eval-auc:0.54263
[26]	eval-auc:0.53800
[27]	eval-auc:0.53269
[28]	eval-auc:0.53421
[29]	eval-auc:0.53406
[30]	eval-auc:0.53830
[31]	eval-auc:0.54331
[32]	eval-auc:0.53140
[33]	eval-auc:0.53102
[34]	eval-auc:0.53133
[35]	eval-auc:0.51335
[36]	eval-auc:0.50546
[37]	eval-auc:0.49621
[38]	eval-auc:0.49803
[39]	eval-auc:0.50599
[40]	eval-auc:0.50850
[41]	eval-auc:0.51358
[42]	eval-auc:0.52306
[43]	eval-auc:0.52184
[44]	eval-auc:0.51798
[45]	eval-auc:0.5141

[160]	eval-auc:0.59186
[161]	eval-auc:0.59303
[162]	eval-auc:0.59231
[163]	eval-auc:0.59175
[164]	eval-auc:0.59206
[165]	eval-auc:0.59242
[166]	eval-auc:0.59288
[167]	eval-auc:0.59272
[168]	eval-auc:0.59359
[169]	eval-auc:0.59318
[170]	eval-auc:0.59364
[171]	eval-auc:0.59283
[172]	eval-auc:0.59405
[173]	eval-auc:0.59379
[174]	eval-auc:0.59303
[175]	eval-auc:0.59277
[176]	eval-auc:0.59303
[177]	eval-auc:0.59221
[178]	eval-auc:0.59068
[179]	eval-auc:0.59104
[180]	eval-auc:0.59201
[181]	eval-auc:0.59043
[182]	eval-auc:0.59032
[183]	eval-auc:0.59038
[184]	eval-auc:0.58946
[185]	eval-auc:0.58839
[186]	eval-auc:0.58685
[187]	eval-auc:0.58767
[188]	eval-auc:0.58752
[189]	eval-auc:0.58711
[190]	eval-auc:0.58736
[191]	eval-auc:0.58844
[192]	eval-auc:0.58925
[193]	eval-auc:0.58946
[194]	eval-auc:0.58946
[195]	eval-auc:0.58813
[196]	eval-auc:0.58798
[197]	eval-auc:0.58736
[198]	eval-auc:0.58711
[199]	eval-auc:0.58706
[200]	eval-auc:0.58731
[201]	eval-auc:0.58696
[202]	eval-auc:0.58655
[203]	eval-

[40]	eval-auc:0.52763
[41]	eval-auc:0.53673
[42]	eval-auc:0.53036
[43]	eval-auc:0.52532
[44]	eval-auc:0.52635
[45]	eval-auc:0.53248
[46]	eval-auc:0.54032
[47]	eval-auc:0.54742
[48]	eval-auc:0.54930
[49]	eval-auc:0.54851
[50]	eval-auc:0.54973
[51]	eval-auc:0.54906
[52]	eval-auc:0.54706
[53]	eval-auc:0.54438
[54]	eval-auc:0.54451
[55]	eval-auc:0.54147
[56]	eval-auc:0.54463
[57]	eval-auc:0.54226
[58]	eval-auc:0.53752
[59]	eval-auc:0.53345
[60]	eval-auc:0.53218
[61]	eval-auc:0.53534
[62]	eval-auc:0.53661
[63]	eval-auc:0.53588
[64]	eval-auc:0.53455
[65]	eval-auc:0.53461
[66]	eval-auc:0.53522
[67]	eval-auc:0.53388
[68]	eval-auc:0.53072
[69]	eval-auc:0.52708
[70]	eval-auc:0.52550
[71]	eval-auc:0.52471
[72]	eval-auc:0.52356
[73]	eval-auc:0.52271
[74]	eval-auc:0.51840
[75]	eval-auc:0.52307
[76]	eval-auc:0.52920
[77]	eval-auc:0.53212
[78]	eval-auc:0.53097
[79]	eval-auc:0.53139
[80]	eval-auc:0.53115
[81]	eval-auc:0.52860
[82]	eval-auc:0.53121
[83]	eval-auc:0.52920
[84]	eval-auc:0.53048
[85]	eval-

[202]	eval-auc:0.56153
[203]	eval-auc:0.56115
[204]	eval-auc:0.56033
[205]	eval-auc:0.55977
[206]	eval-auc:0.56159
[207]	eval-auc:0.56008
[208]	eval-auc:0.55995
[209]	eval-auc:0.56077
[210]	eval-auc:0.55958
[211]	eval-auc:0.55977
[212]	eval-auc:0.56052
[213]	eval-auc:0.55869
[214]	eval-auc:0.55850
Avg AUC: 0.5401897122910624
Avg F1 Score: 0.4773780088805301


In [2]:
print(f"Avg AUC: {baseline_avg_auc}")
print(f"Avg F1 Score: {baseline_avg_f1}")

Avg AUC: 0.5401897122910624
Avg F1 Score: 0.4773780088805301
