Baseline

In [1]:
import xgboost as xgb
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

auc_scores = []
f1_scores = []

params = {
    'max_depth': 6,
    'eta': 0.3,
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    "gamma": 0,
    "subsample": 1,
    "colsample_bytree": 1,
    "colsample_bylevel": 1,
    "scale_pos_weight": 1,
    "reg_alpha": 0,
    "reg_lambda": 1,
}

data_folder = './data_baseline_2/'

for event in range(14):

    # load data
    dtrain = xgb.DMatrix(data_folder + '{}_train.data'.format(event))
    dtest = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    deval = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    
    # train
    evallist = [(dtrain, 'train'), (deval, 'eval')]
    bst = xgb.train(params, dtrain, num_boost_round=1000, evals=[(dtest, 'eval')], early_stopping_rounds=200)
    # predict probabilities
    y_pred_proba = bst.predict(dtest)
    # Convert predicted probabilities to class labels
    y_pred_labels = (y_pred_proba > 0.5).astype(int)
    
    y_real = dtest.get_label()
    
    # Calculate AUC
    auc = roc_auc_score(y_real, y_pred_proba)
    auc_scores.append(auc)
    
    # Calculate F1 score
    f1 = f1_score(y_real, y_pred_labels, average='macro')
    f1_scores.append(f1)
    
# Average AUC and F1 scores across all groups
avg_auc_baseline = np.mean(auc_scores)
avg_f1_baseline = np.mean(f1_scores)

print(f"Avg AUC: {avg_auc_baseline}")
print(f"Avg F1 Score: {avg_f1_baseline}")

[0]	eval-auc:0.43047
[1]	eval-auc:0.41243
[2]	eval-auc:0.32544
[3]	eval-auc:0.33254
[4]	eval-auc:0.36775
[5]	eval-auc:0.36036
[6]	eval-auc:0.39527
[7]	eval-auc:0.41686
[8]	eval-auc:0.46183
[9]	eval-auc:0.47959
[10]	eval-auc:0.51036
[11]	eval-auc:0.48432
[12]	eval-auc:0.47426
[13]	eval-auc:0.47959
[14]	eval-auc:0.52604
[15]	eval-auc:0.49704
[16]	eval-auc:0.50769
[17]	eval-auc:0.48757
[18]	eval-auc:0.50828
[19]	eval-auc:0.52012
[20]	eval-auc:0.49704
[21]	eval-auc:0.51243
[22]	eval-auc:0.51243
[23]	eval-auc:0.52959
[24]	eval-auc:0.52071
[25]	eval-auc:0.51420
[26]	eval-auc:0.51893
[27]	eval-auc:0.51775
[28]	eval-auc:0.51243
[29]	eval-auc:0.52485
[30]	eval-auc:0.49822
[31]	eval-auc:0.48580
[32]	eval-auc:0.49704
[33]	eval-auc:0.48107
[34]	eval-auc:0.48876
[35]	eval-auc:0.48462
[36]	eval-auc:0.48521
[37]	eval-auc:0.48047
[38]	eval-auc:0.49231
[39]	eval-auc:0.49467
[40]	eval-auc:0.49586
[41]	eval-auc:0.49231
[42]	eval-auc:0.50000
[43]	eval-auc:0.49586
[44]	eval-auc:0.51420
[45]	eval-auc:0.5011

[361]	eval-auc:0.56036
[362]	eval-auc:0.55976
[363]	eval-auc:0.55858
[364]	eval-auc:0.56036
[365]	eval-auc:0.55917
[366]	eval-auc:0.55680
[367]	eval-auc:0.55740
[368]	eval-auc:0.55858
[369]	eval-auc:0.55680
[370]	eval-auc:0.55740
[371]	eval-auc:0.55976
[372]	eval-auc:0.55740
[373]	eval-auc:0.55621
[374]	eval-auc:0.55621
[375]	eval-auc:0.55799
[376]	eval-auc:0.55740
[377]	eval-auc:0.55503
[378]	eval-auc:0.55621
[379]	eval-auc:0.55799
[380]	eval-auc:0.56036
[381]	eval-auc:0.55917
[382]	eval-auc:0.55976
[383]	eval-auc:0.55976
[384]	eval-auc:0.55976
[385]	eval-auc:0.55917
[386]	eval-auc:0.55976
[387]	eval-auc:0.56036
[388]	eval-auc:0.55976
[389]	eval-auc:0.55976
[390]	eval-auc:0.56036
[391]	eval-auc:0.56036
[392]	eval-auc:0.56095
[393]	eval-auc:0.56213
[394]	eval-auc:0.56036
[395]	eval-auc:0.56213
[396]	eval-auc:0.56213
[397]	eval-auc:0.56213
[398]	eval-auc:0.56213
[399]	eval-auc:0.56391
[400]	eval-auc:0.56627
[401]	eval-auc:0.56568
[402]	eval-auc:0.56391
[403]	eval-auc:0.56391
[404]	eval-

[200]	eval-auc:0.58065
[201]	eval-auc:0.58193
[202]	eval-auc:0.58244
[203]	eval-auc:0.57988
[204]	eval-auc:0.57680
[205]	eval-auc:0.57552
[0]	eval-auc:0.56601
[1]	eval-auc:0.54394
[2]	eval-auc:0.54528
[3]	eval-auc:0.53589
[4]	eval-auc:0.49887
[5]	eval-auc:0.50598
[6]	eval-auc:0.53465
[7]	eval-auc:0.52743
[8]	eval-auc:0.52764
[9]	eval-auc:0.52517
[10]	eval-auc:0.55116
[11]	eval-auc:0.55363
[12]	eval-auc:0.55322
[13]	eval-auc:0.56095
[14]	eval-auc:0.56724
[15]	eval-auc:0.55476
[16]	eval-auc:0.54528
[17]	eval-auc:0.53208
[18]	eval-auc:0.54414
[19]	eval-auc:0.55281
[20]	eval-auc:0.55507
[21]	eval-auc:0.53929
[22]	eval-auc:0.54744
[23]	eval-auc:0.54476
[24]	eval-auc:0.55116
[25]	eval-auc:0.54641
[26]	eval-auc:0.55074
[27]	eval-auc:0.54600
[28]	eval-auc:0.53795
[29]	eval-auc:0.54167
[30]	eval-auc:0.55930
[31]	eval-auc:0.55930
[32]	eval-auc:0.55992
[33]	eval-auc:0.55538
[34]	eval-auc:0.55281
[35]	eval-auc:0.56580
[36]	eval-auc:0.56436
[37]	eval-auc:0.56002
[38]	eval-auc:0.55384
[39]	eval-auc:

[355]	eval-auc:0.60149
[356]	eval-auc:0.60190
[357]	eval-auc:0.60252
[358]	eval-auc:0.60355
[359]	eval-auc:0.60293
[360]	eval-auc:0.60314
[361]	eval-auc:0.60128
[362]	eval-auc:0.60252
[363]	eval-auc:0.60293
[364]	eval-auc:0.60334
[365]	eval-auc:0.60231
[366]	eval-auc:0.60272
[367]	eval-auc:0.60149
[368]	eval-auc:0.60107
[369]	eval-auc:0.60107
[370]	eval-auc:0.60066
[371]	eval-auc:0.60107
[372]	eval-auc:0.60190
[373]	eval-auc:0.60231
[374]	eval-auc:0.60107
[375]	eval-auc:0.60025
[376]	eval-auc:0.60128
[377]	eval-auc:0.59922
[378]	eval-auc:0.59901
[379]	eval-auc:0.60045
[380]	eval-auc:0.60087
[0]	eval-auc:0.43596
[1]	eval-auc:0.45139
[2]	eval-auc:0.42747
[3]	eval-auc:0.41146
[4]	eval-auc:0.48322
[5]	eval-auc:0.51331
[6]	eval-auc:0.50965
[7]	eval-auc:0.53260
[8]	eval-auc:0.56578
[9]	eval-auc:0.57832
[10]	eval-auc:0.55594
[11]	eval-auc:0.55941
[12]	eval-auc:0.54128
[13]	eval-auc:0.55478
[14]	eval-auc:0.56790
[15]	eval-auc:0.55131
[16]	eval-auc:0.55826
[17]	eval-auc:0.56674
[18]	eval-auc:0.

[104]	eval-auc:0.48589
[105]	eval-auc:0.48328
[106]	eval-auc:0.48380
[107]	eval-auc:0.48119
[108]	eval-auc:0.48015
[109]	eval-auc:0.47753
[110]	eval-auc:0.47544
[111]	eval-auc:0.47649
[112]	eval-auc:0.47701
[113]	eval-auc:0.47597
[114]	eval-auc:0.47806
[115]	eval-auc:0.48067
[116]	eval-auc:0.48276
[117]	eval-auc:0.48485
[118]	eval-auc:0.48224
[119]	eval-auc:0.48224
[120]	eval-auc:0.48276
[121]	eval-auc:0.48380
[122]	eval-auc:0.48485
[123]	eval-auc:0.48224
[124]	eval-auc:0.47858
[125]	eval-auc:0.48224
[126]	eval-auc:0.48224
[127]	eval-auc:0.48224
[128]	eval-auc:0.48171
[129]	eval-auc:0.47806
[130]	eval-auc:0.47858
[131]	eval-auc:0.47701
[132]	eval-auc:0.47910
[133]	eval-auc:0.48276
[134]	eval-auc:0.47962
[135]	eval-auc:0.47910
[136]	eval-auc:0.47701
[137]	eval-auc:0.47701
[138]	eval-auc:0.47231
[139]	eval-auc:0.47179
[140]	eval-auc:0.47074
[141]	eval-auc:0.47335
[142]	eval-auc:0.47544
[143]	eval-auc:0.47492
[144]	eval-auc:0.47597
[145]	eval-auc:0.47231
[146]	eval-auc:0.47126
[147]	eval-

[63]	eval-auc:0.43321
[64]	eval-auc:0.43516
[65]	eval-auc:0.44054
[66]	eval-auc:0.44127
[67]	eval-auc:0.43980
[68]	eval-auc:0.43101
[69]	eval-auc:0.43101
[70]	eval-auc:0.43272
[71]	eval-auc:0.43346
[72]	eval-auc:0.43761
[73]	eval-auc:0.43883
[74]	eval-auc:0.44298
[75]	eval-auc:0.44274
[76]	eval-auc:0.43907
[77]	eval-auc:0.44298
[78]	eval-auc:0.43468
[79]	eval-auc:0.43541
[80]	eval-auc:0.43614
[81]	eval-auc:0.43199
[82]	eval-auc:0.43028
[83]	eval-auc:0.42833
[84]	eval-auc:0.43516
[85]	eval-auc:0.42906
[86]	eval-auc:0.43175
[87]	eval-auc:0.43248
[88]	eval-auc:0.43394
[89]	eval-auc:0.43663
[90]	eval-auc:0.43516
[91]	eval-auc:0.44029
[92]	eval-auc:0.43687
[93]	eval-auc:0.43907
[94]	eval-auc:0.43492
[95]	eval-auc:0.43834
[96]	eval-auc:0.43956
[97]	eval-auc:0.43907
[98]	eval-auc:0.44591
[99]	eval-auc:0.44982
[100]	eval-auc:0.44860
[101]	eval-auc:0.44835
[102]	eval-auc:0.45275
[103]	eval-auc:0.45201
[104]	eval-auc:0.45495
[105]	eval-auc:0.45275
[106]	eval-auc:0.45201
[107]	eval-auc:0.45104
[1

[9]	eval-auc:0.46738
[10]	eval-auc:0.46397
[11]	eval-auc:0.44733
[12]	eval-auc:0.44904
[13]	eval-auc:0.46482
[14]	eval-auc:0.45373
[15]	eval-auc:0.47420
[16]	eval-auc:0.46226
[17]	eval-auc:0.45970
[18]	eval-auc:0.46951
[19]	eval-auc:0.47122
[20]	eval-auc:0.46141
[21]	eval-auc:0.48742
[22]	eval-auc:0.48529
[23]	eval-auc:0.49168
[24]	eval-auc:0.50064
[25]	eval-auc:0.51301
[26]	eval-auc:0.50917
[27]	eval-auc:0.50618
[28]	eval-auc:0.50576
[29]	eval-auc:0.50149
[30]	eval-auc:0.49510
[31]	eval-auc:0.50192
[32]	eval-auc:0.50107
[33]	eval-auc:0.49808
[34]	eval-auc:0.49765
[35]	eval-auc:0.49595
[36]	eval-auc:0.50789
[37]	eval-auc:0.51386
[38]	eval-auc:0.50661
[39]	eval-auc:0.49552
[40]	eval-auc:0.49552
[41]	eval-auc:0.49680
[42]	eval-auc:0.50064
[43]	eval-auc:0.50874
[44]	eval-auc:0.50917
[45]	eval-auc:0.52111
[46]	eval-auc:0.52921
[47]	eval-auc:0.52665
[48]	eval-auc:0.52154
[49]	eval-auc:0.52068
[50]	eval-auc:0.52580
[51]	eval-auc:0.52921
[52]	eval-auc:0.52921
[53]	eval-auc:0.53177
[54]	eval-a

[370]	eval-auc:0.56205
[371]	eval-auc:0.56247
[372]	eval-auc:0.56333
[373]	eval-auc:0.56333
[374]	eval-auc:0.56461
[375]	eval-auc:0.56503
[376]	eval-auc:0.56461
[377]	eval-auc:0.56461
[378]	eval-auc:0.56375
[379]	eval-auc:0.56461
[380]	eval-auc:0.56375
[381]	eval-auc:0.56333
[382]	eval-auc:0.56333
[383]	eval-auc:0.56375
[384]	eval-auc:0.56503
[385]	eval-auc:0.56546
[386]	eval-auc:0.56503
[387]	eval-auc:0.56375
[388]	eval-auc:0.56333
[389]	eval-auc:0.56290
[390]	eval-auc:0.56375
[391]	eval-auc:0.56333
[392]	eval-auc:0.56461
[393]	eval-auc:0.56247
[394]	eval-auc:0.56418
[395]	eval-auc:0.56375
[396]	eval-auc:0.56375
[397]	eval-auc:0.56418
[398]	eval-auc:0.56418
[399]	eval-auc:0.56461
[400]	eval-auc:0.56503
[401]	eval-auc:0.56375
[402]	eval-auc:0.56333
[403]	eval-auc:0.56418
[404]	eval-auc:0.56461
[405]	eval-auc:0.56333
[406]	eval-auc:0.56418
[407]	eval-auc:0.56290
[408]	eval-auc:0.56375
[409]	eval-auc:0.56290
[410]	eval-auc:0.56333
[411]	eval-auc:0.56375
[412]	eval-auc:0.56418
[413]	eval-

[59]	eval-auc:0.57955
[60]	eval-auc:0.58289
[61]	eval-auc:0.58556
[62]	eval-auc:0.58389
[63]	eval-auc:0.58189
[64]	eval-auc:0.58890
[65]	eval-auc:0.58957
[66]	eval-auc:0.58723
[67]	eval-auc:0.58389
[68]	eval-auc:0.58155
[69]	eval-auc:0.58356
[70]	eval-auc:0.58255
[71]	eval-auc:0.57988
[72]	eval-auc:0.57921
[73]	eval-auc:0.58289
[74]	eval-auc:0.58489
[75]	eval-auc:0.58656
[76]	eval-auc:0.57320
[77]	eval-auc:0.57654
[78]	eval-auc:0.56417
[79]	eval-auc:0.56217
[80]	eval-auc:0.56217
[81]	eval-auc:0.56283
[82]	eval-auc:0.55615
[83]	eval-auc:0.55348
[84]	eval-auc:0.56049
[85]	eval-auc:0.56283
[86]	eval-auc:0.56718
[87]	eval-auc:0.56952
[88]	eval-auc:0.57152
[89]	eval-auc:0.57754
[90]	eval-auc:0.58055
[91]	eval-auc:0.56718
[92]	eval-auc:0.57286
[93]	eval-auc:0.56918
[94]	eval-auc:0.56918
[95]	eval-auc:0.56484
[96]	eval-auc:0.56217
[97]	eval-auc:0.56317
[98]	eval-auc:0.56350
[99]	eval-auc:0.56384
[100]	eval-auc:0.56350
[101]	eval-auc:0.55916
[102]	eval-auc:0.55114
[103]	eval-auc:0.55047
[104]	

[157]	eval-auc:0.65087
[158]	eval-auc:0.65029
[159]	eval-auc:0.64913
[160]	eval-auc:0.64651
[161]	eval-auc:0.64767
[162]	eval-auc:0.64680
[163]	eval-auc:0.64593
[164]	eval-auc:0.64535
[165]	eval-auc:0.64738
[166]	eval-auc:0.64651
[167]	eval-auc:0.64738
[168]	eval-auc:0.64651
[169]	eval-auc:0.64593
[170]	eval-auc:0.64738
[171]	eval-auc:0.64942
[172]	eval-auc:0.64884
[173]	eval-auc:0.65000
[174]	eval-auc:0.65087
[175]	eval-auc:0.65320
[176]	eval-auc:0.65203
[177]	eval-auc:0.65320
[178]	eval-auc:0.65291
[179]	eval-auc:0.65233
[180]	eval-auc:0.65320
[181]	eval-auc:0.65174
[182]	eval-auc:0.65291
[183]	eval-auc:0.65116
[184]	eval-auc:0.64971
[185]	eval-auc:0.65203
[186]	eval-auc:0.65116
[187]	eval-auc:0.65233
[188]	eval-auc:0.65378
[189]	eval-auc:0.65640
[190]	eval-auc:0.65640
[191]	eval-auc:0.65291
[192]	eval-auc:0.65552
[193]	eval-auc:0.65552
[194]	eval-auc:0.65436
[195]	eval-auc:0.65523
[196]	eval-auc:0.65233
[197]	eval-auc:0.65203
[198]	eval-auc:0.65116
[199]	eval-auc:0.65203
[200]	eval-

[107]	eval-auc:0.59508
[108]	eval-auc:0.59992
[109]	eval-auc:0.59629
[110]	eval-auc:0.60355
[111]	eval-auc:0.60032
[112]	eval-auc:0.60113
[113]	eval-auc:0.59952
[114]	eval-auc:0.59992
[115]	eval-auc:0.59871
[116]	eval-auc:0.60113
[117]	eval-auc:0.59871
[118]	eval-auc:0.59710
[119]	eval-auc:0.59710
[120]	eval-auc:0.59831
[121]	eval-auc:0.59750
[122]	eval-auc:0.59670
[123]	eval-auc:0.60153
[124]	eval-auc:0.60596
[125]	eval-auc:0.60637
[126]	eval-auc:0.59790
[127]	eval-auc:0.59750
[128]	eval-auc:0.59508
[129]	eval-auc:0.59549
[130]	eval-auc:0.59549
[131]	eval-auc:0.59347
[132]	eval-auc:0.58904
[133]	eval-auc:0.59388
[134]	eval-auc:0.58824
[135]	eval-auc:0.58340
[136]	eval-auc:0.59025
[137]	eval-auc:0.58018
[138]	eval-auc:0.58501
[139]	eval-auc:0.58340
[140]	eval-auc:0.58662
[141]	eval-auc:0.57937
[142]	eval-auc:0.58259
[143]	eval-auc:0.58421
[144]	eval-auc:0.58824
[145]	eval-auc:0.58380
[146]	eval-auc:0.58219
[147]	eval-auc:0.57897
[148]	eval-auc:0.58219
[149]	eval-auc:0.58018
[150]	eval-

In [2]:
print(f"Avg AUC: {avg_auc_baseline}")
print(f"Avg F1 Score: {avg_f1_baseline}")

Avg AUC: 0.5228700731104651
Avg F1 Score: 0.46700555854104353
