Baseline

In [1]:
import xgboost as xgb
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

auc_scores = []
f1_scores = []

params = {
    'max_depth': 6,
    'eta': 0.3,
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    "gamma": 0,
    "subsample": 1,
    "colsample_bytree": 1,
    "colsample_bylevel": 1,
    "scale_pos_weight": 1,
    "reg_alpha": 0,
    "reg_lambda": 1,
}

data_folder = './data_baseline_2/'

for event in range(5):

    # load data
    dtrain = xgb.DMatrix(data_folder + '{}_train.data'.format(event))
    dtest = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    deval = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    
    # train
    evallist = [(dtrain, 'train'), (deval, 'eval')]
    bst = xgb.train(params, dtrain, num_boost_round=1000, evals=[(dtest, 'eval')], early_stopping_rounds=200)
    # predict probabilities
    y_pred_proba = bst.predict(dtest)
    # Convert predicted probabilities to class labels
    y_pred_labels = (y_pred_proba > 0.5).astype(int)
    
    y_real = dtest.get_label()
    
    # Calculate AUC
    auc = roc_auc_score(y_real, y_pred_proba)
    auc_scores.append(auc)
    
    # Calculate F1 score
    f1 = f1_score(y_real, y_pred_labels, average='macro')
    f1_scores.append(f1)
    
# Average AUC and F1 scores across all groups
baseline_avg_auc = np.mean(auc_scores)
baseline_avg_f1 = np.mean(f1_scores)

print(f"Avg AUC: {baseline_avg_auc}")
print(f"Avg F1 Score: {baseline_avg_f1}")

[0]	eval-auc:0.41879
[1]	eval-auc:0.52629
[2]	eval-auc:0.54798
[3]	eval-auc:0.56224
[4]	eval-auc:0.57145
[5]	eval-auc:0.57134
[6]	eval-auc:0.59392
[7]	eval-auc:0.59188
[8]	eval-auc:0.56890
[9]	eval-auc:0.55180
[10]	eval-auc:0.55438
[11]	eval-auc:0.55533
[12]	eval-auc:0.56216
[13]	eval-auc:0.55791
[14]	eval-auc:0.55343
[15]	eval-auc:0.55513
[16]	eval-auc:0.56557
[17]	eval-auc:0.56414
[18]	eval-auc:0.56804
[19]	eval-auc:0.57366
[20]	eval-auc:0.58276
[21]	eval-auc:0.58391
[22]	eval-auc:0.57668
[23]	eval-auc:0.58506
[24]	eval-auc:0.58417
[25]	eval-auc:0.58414
[26]	eval-auc:0.57880
[27]	eval-auc:0.57886
[28]	eval-auc:0.58419
[29]	eval-auc:0.57972
[30]	eval-auc:0.58023
[31]	eval-auc:0.58264
[32]	eval-auc:0.58328
[33]	eval-auc:0.58236
[34]	eval-auc:0.57622
[35]	eval-auc:0.57886
[36]	eval-auc:0.58121
[37]	eval-auc:0.58603
[38]	eval-auc:0.58701
[39]	eval-auc:0.59217
[40]	eval-auc:0.60176
[41]	eval-auc:0.59751
[42]	eval-auc:0.59659
[43]	eval-auc:0.59590
[44]	eval-auc:0.59688
[45]	eval-auc:0.5941

[114]	eval-auc:0.53416
[115]	eval-auc:0.53241
[116]	eval-auc:0.53254
[117]	eval-auc:0.53389
[118]	eval-auc:0.53375
[119]	eval-auc:0.53470
[120]	eval-auc:0.53618
[121]	eval-auc:0.53631
[122]	eval-auc:0.53490
[123]	eval-auc:0.53302
[124]	eval-auc:0.53597
[125]	eval-auc:0.53618
[126]	eval-auc:0.53456
[127]	eval-auc:0.53342
[128]	eval-auc:0.53423
[129]	eval-auc:0.53322
[130]	eval-auc:0.53261
[131]	eval-auc:0.53530
[132]	eval-auc:0.53765
[133]	eval-auc:0.53712
[134]	eval-auc:0.53873
[135]	eval-auc:0.53826
[136]	eval-auc:0.53880
[137]	eval-auc:0.54122
[138]	eval-auc:0.53967
[139]	eval-auc:0.53920
[140]	eval-auc:0.53987
[141]	eval-auc:0.54075
[142]	eval-auc:0.54223
[143]	eval-auc:0.53974
[144]	eval-auc:0.54115
[145]	eval-auc:0.53967
[146]	eval-auc:0.53947
[147]	eval-auc:0.53967
[148]	eval-auc:0.53940
[149]	eval-auc:0.54041
[150]	eval-auc:0.54068
[151]	eval-auc:0.54115
[152]	eval-auc:0.54075
[153]	eval-auc:0.54028
[154]	eval-auc:0.53994
[155]	eval-auc:0.54034
[156]	eval-auc:0.53960
[157]	eval-

[272]	eval-auc:0.58582
[273]	eval-auc:0.58690
[274]	eval-auc:0.58831
[275]	eval-auc:0.58447
[276]	eval-auc:0.58485
[277]	eval-auc:0.58333
[278]	eval-auc:0.58474
[279]	eval-auc:0.58382
[280]	eval-auc:0.58539
[281]	eval-auc:0.58501
[282]	eval-auc:0.58577
[283]	eval-auc:0.58755
[284]	eval-auc:0.58853
[285]	eval-auc:0.58826
[286]	eval-auc:0.58820
[287]	eval-auc:0.58945
[288]	eval-auc:0.58977
[289]	eval-auc:0.58966
[290]	eval-auc:0.58988
[291]	eval-auc:0.59237
[292]	eval-auc:0.59291
[293]	eval-auc:0.59394
[294]	eval-auc:0.59378
[295]	eval-auc:0.59351
[296]	eval-auc:0.59215
[297]	eval-auc:0.59199
[298]	eval-auc:0.59248
[299]	eval-auc:0.59269
[300]	eval-auc:0.59324
[301]	eval-auc:0.59269
[302]	eval-auc:0.59345
[303]	eval-auc:0.59291
[304]	eval-auc:0.59053
[305]	eval-auc:0.59291
[306]	eval-auc:0.59242
[307]	eval-auc:0.59210
[308]	eval-auc:0.59161
[309]	eval-auc:0.59259
[310]	eval-auc:0.59167
[311]	eval-auc:0.59161
[312]	eval-auc:0.59253
[313]	eval-auc:0.59150
[314]	eval-auc:0.59150
[315]	eval-

[82]	eval-auc:0.54203
[83]	eval-auc:0.53981
[84]	eval-auc:0.54002
[85]	eval-auc:0.54252
[86]	eval-auc:0.54487
[87]	eval-auc:0.54030
[88]	eval-auc:0.54404
[89]	eval-auc:0.54591
[90]	eval-auc:0.55070
[91]	eval-auc:0.54813
[92]	eval-auc:0.54758
[93]	eval-auc:0.55167
[94]	eval-auc:0.55424
[95]	eval-auc:0.55382
[96]	eval-auc:0.55181
[97]	eval-auc:0.55008
[98]	eval-auc:0.55202
[99]	eval-auc:0.55410
[100]	eval-auc:0.55438
[101]	eval-auc:0.55556
[102]	eval-auc:0.55119
[103]	eval-auc:0.55521
[104]	eval-auc:0.55465
[105]	eval-auc:0.55722
[106]	eval-auc:0.55729
[107]	eval-auc:0.55757
[108]	eval-auc:0.55514
[109]	eval-auc:0.55479
[110]	eval-auc:0.55465
[111]	eval-auc:0.55119
[112]	eval-auc:0.55431
[113]	eval-auc:0.55396
[114]	eval-auc:0.55271
[115]	eval-auc:0.55403
[116]	eval-auc:0.55243
[117]	eval-auc:0.55250
[118]	eval-auc:0.55167
[119]	eval-auc:0.55299
[120]	eval-auc:0.55479
[121]	eval-auc:0.55410
[122]	eval-auc:0.55431
[123]	eval-auc:0.55181
[124]	eval-auc:0.55119
[125]	eval-auc:0.55167
[126]	

[56]	eval-auc:0.53030
[57]	eval-auc:0.52790
[58]	eval-auc:0.52224
[59]	eval-auc:0.52295
[60]	eval-auc:0.52380
[61]	eval-auc:0.52400
[62]	eval-auc:0.52315
[63]	eval-auc:0.52185
[64]	eval-auc:0.52243
[65]	eval-auc:0.51970
[66]	eval-auc:0.52224
[67]	eval-auc:0.51931
[68]	eval-auc:0.52022
[69]	eval-auc:0.52003
[70]	eval-auc:0.52530
[71]	eval-auc:0.52361
[72]	eval-auc:0.52608
[73]	eval-auc:0.52510
[74]	eval-auc:0.52367
[75]	eval-auc:0.52530
[76]	eval-auc:0.52419
[77]	eval-auc:0.52484
[78]	eval-auc:0.52387
[79]	eval-auc:0.52217
[80]	eval-auc:0.52367
[81]	eval-auc:0.52204
[82]	eval-auc:0.52666
[83]	eval-auc:0.52718
[84]	eval-auc:0.52634
[85]	eval-auc:0.52341
[86]	eval-auc:0.52387
[87]	eval-auc:0.52829
[88]	eval-auc:0.52816
[89]	eval-auc:0.52881
[90]	eval-auc:0.52907
[91]	eval-auc:0.52991
[92]	eval-auc:0.53154
[93]	eval-auc:0.53011
[94]	eval-auc:0.53336
[95]	eval-auc:0.53180
[96]	eval-auc:0.53284
[97]	eval-auc:0.52965
[98]	eval-auc:0.53069
[99]	eval-auc:0.53258
[100]	eval-auc:0.53141
[101]	eva

In [2]:
print(f"Avg AUC: {baseline_avg_auc}")
print(f"Avg F1 Score: {baseline_avg_f1}")

Avg AUC: 0.555764504034606
Avg F1 Score: 0.49103436904002606
